├── .gitignore ├── README.md ├── setup.py └── src ├── callbacks └── callbacks.py ├── config.yaml ├── data_processing ├── contrastive │ ├── create_mit_contrastive.py │ └── create_mmx_contrastive.py ├── labels │ └── moments_categories.csv ├── temporal │ ├── create_mit_temporal.py │ ├── create_mmx_frames.py │ └── create_mmx_temporal.py ├── tools │ ├── admin.py │ ├── nearest_neighbour.py │ └── test.ann └── transforms │ ├── audio_transforms.py │ ├── img_transforms.py │ └── spatio_cut.py ├── dataloaders ├── mit │ ├── MIT_Contrastive_dl.py │ └── MIT_Temporal_dl.py └── mmx │ ├── MMX_Contrastive_dl.py │ ├── MMX_Frame_dl.py │ ├── MMX_Light_dl.py │ └── MMX_Temporal_dl.py ├── frame.png ├── main.py ├── models ├── .contrastivemodel.py.swp ├── LSTM.py ├── TPN.py ├── basicmlp.py ├── collabgating.py ├── contrastivemodel.py ├── custom_resnet.py ├── frame_transformer.py ├── losses │ └── ntxent.py ├── pretrained │ └── models.py ├── transformer.py └── vit.py ├── test.png └── tests ├── test_dataloaders.py ├── test_tensors.py └── test_transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | /lightninglogs 2 | /wandb 3 | **/__pycache__ 4 | */mmx_tensors* 5 | /data 6 | /runs 7 | *.pkl 8 | /trained_models/ 9 | .vscode/ 10 | 11 | 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Data-efficient video transformers for video classification 2 | 3 | Includes a spatial-temporal pyramid network, multi-modal distillation, and multi-modal cross attention. -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="src", 5 | version="0.0.1", 6 | description="Frame stacked transformer test", 7 | author="Ed Fish", 8 | author_email="edward.fish@surrey.ac.uk", 9 | url="https://github.com/ed-fish/self-supervised-video", 10 | install_requires=["pytorch-lightning"], 11 | packages=find_packages(), 12 | ) -------------------------------------------------------------------------------- /src/callbacks/callbacks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import device, Tensor 3 | import pickle 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from torch.optim import Optimizer 7 | from torchmetrics.functional import auroc 8 | import pytorch_lightning as pl 9 | from pytorch_lightning.callbacks import Callback 10 | # from torchmetrics import ConfusionMatrix 11 | import pandas as pd 12 | import torchmetrics 13 | import sklearn 14 | from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix, classification_report 15 | import numpy as np 16 | import wandb 17 | import matplotlib.pyplot as plt 18 | import seaborn as sns 19 | import pickle 20 | 21 | from sklearn.metrics import f1_score, recall_score, average_precision_score, precision_score 22 | from torchmetrics.functional import f1, auroc 23 | #from torchmetrics.functional import accuracy 24 | #from pytorch_lightning.metrics.functional import accuracy 25 | 26 | 27 | class TransformerEval(Callback): 28 | 29 | def on_validation_epoch_end(self, trainer, pl_module): 30 | 31 | target_names = ['Action', 'Animation', 'Adventure', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Family', 32 | 'Fantasy', 'History', 'Horror', 'Music',
'Romance', 'Mystery', 'TVMovie', 'ScienceFiction', 'Thriller', 'War', 'Western'] 33 | state = "val" 34 | running_labels = torch.cat(pl_module.running_labels).cpu() 35 | running_logits = torch.cat(pl_module.running_logits).cpu() 36 | t = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] 37 | for threshold in t: 38 | accuracy = f1_score(running_labels.to(int), (running_logits > threshold).to( 39 | int), average="samples", zero_division=0) 40 | # recall = recall_score(running_labels.to(int), (running_logits > threshold).to(int), average="weighted", zero_division=1) 41 | # precision = precision_score(running_labels.to(int), (running_logits > threshold).to(int), average="weighted", zero_division=1) 42 | # avg_precision = average_precision_score(running_labels.to(int), (running_logits > threshold).to(int), average="weighted") 43 | 44 | pl_module.log(f"{state}/online/f1@{str(threshold)}", accuracy) 45 | #pl_module.log(f"{state}/online/recall@{str(threshold)}", recall, on_epoch=True) 46 | #pl_module.log(f"{state}/online/precision@{str(threshold)}", precision, on_epoch=True) 47 | #pl_module.log(f"{state}/online/avg_precision@{str(threshold)}", avg_precision, on_epoch=True) 48 | 49 | aprc_samples = average_precision_score(running_labels.to( 50 | int), running_logits, average="samples") 51 | pl_module.log("sklearn apr", aprc_samples) 52 | 53 | aprc_weight = average_precision_score(running_labels.to( 54 | int), running_logits, average="weighted") 55 | pl_module.log("sklearn apr weighted", aprc_weight) 56 | 57 | print(aprc_samples) 58 | print(aprc_weight) 59 | print(classification_report(running_labels.to(int), (running_logits > 0.3).to(int), target_names=target_names)) 60 | 61 | pl_module.running_labels = [] 62 | pl_module.running_logits = [] 63 | 64 | label_str = [] 65 | target_str = [] 66 | 67 | def on_test_epoch_end(self, trainer, pl_module): 68 | 69 | target_names = ['Action', 'Animation', 'Adventure', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Family', 70 | 'Fantasy', 'History', 'Horror', 'Music', 'Romance', 'Mystery', 'TVMovie', 'ScienceFiction', 'Thriller', 'War', 'Western'] 71 | 72 | 73 | state = "val" 74 | running_labels = torch.cat(pl_module.running_labels).cpu() 75 | running_logits = torch.cat(pl_module.running_logits).cpu() 76 | with open("labels", "wb") as fp: 77 | pickle.dump(running_labels, fp) 78 | with open("logits", "wb") as fp: 79 | pickle.dump(running_logits, fp) 80 | running_labels = running_labels.to(int) 81 | running_logits = (running_logits > 0.3).to(int) 82 | print(classification_report(running_labels, running_logits, target_names=target_names)) 83 | 84 | 85 | class MITEval(Callback): 86 | def __init__(self): 87 | self.best_acc = 0 88 | 89 | def on_validation_epoch_end(self, trainer, pl_module): 90 | running_labels = torch.cat(pl_module.running_labels) 91 | running_logits = torch.cat(pl_module.running_logits) 92 | acc = torch.sum(running_logits == running_labels).item() / \ 93 | (len(running_labels) * 1.0) 94 | pl_module.log("val/accuracy/epoch", acc, on_step=False, on_epoch=True) 95 | print( 96 | f"acc:{acc} len_S:{len(running_labels)} ex:{running_labels[0]} : {running_logits[0]}") 97 | pl_module.running_labels = [] 98 | pl_module.running_logits = [] 99 | 100 | # if acc > self.best_acc: 101 | # trainer.save_checkpoint("mit_location_acc.ckkpt") 102 | # self.best_acc = acc 103 | 104 | 105 | class DisplayResults(Callback): 106 | def on_test_end(self, trainer, pl_module): 107 | 108 | cache_dict = {} 109 | 110 | running_logits = pl_module.running_logits 111 | running_labels = 
pl_module.running_labels 112 | running_embeds = pl_module.running_embeds 113 | running_paths = pl_module.running_paths 114 | 115 | # 52 10 15 116 | 117 | running_labels = torch.cat(running_labels).cpu().numpy() 118 | running_embeds = torch.cat(running_embeds).cpu().numpy() 119 | running_logits = torch.cat(running_logits).cpu().numpy() 120 | running_paths = [x for i in running_paths for x in i] 121 | running_logits = np.where(running_logits > 0.3, 1, 0) 122 | running_labels = running_labels.astype(int) 123 | 124 | for x, i in enumerate(running_labels): 125 | print("paths", x, running_paths[x]) 126 | print("actual", x, ":", self.n_to_labels(i)) 127 | print("predicted", x, ":", self.n_to_labels(running_logits[x])) 128 | print("embedding", x, ":", running_embeds[x]) 129 | cache_dict[x] = {"path": running_paths[x], "embedding": running_embeds[x], 130 | "predicted": self.n_to_labels(running_logits[x]), "actual": self.n_to_labels(i)} 131 | with open("embed_dict", "wb") as file: 132 | pickle.dump(cache_dict, file) 133 | 134 | # print(running_logits) 135 | 136 | def n_to_labels(self, vector): 137 | labels = [] 138 | 139 | target_names = ['Action', 'Adventure', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Family', 140 | 'Fantasy', 'History', 'Horror', 'Music', 'Mystery', 'Science Fiction', 'Thriller', 'War'] 141 | for i, x in enumerate(vector): 142 | if x: 143 | labels.append(target_names[i]) 144 | return labels 145 | 146 | 147 | class SSLOnlineEval(Callback): 148 | """ 149 | Attached mlp for fine tuning - edited version from ligtning sslonlineval docs 150 | """ 151 | 152 | def __init__(self, drop_p=0.1, z_dim=None, num_classes=None, model="MIT"): 153 | super().__init__() 154 | 155 | self.drop_p = drop_p 156 | self.optimizer: Optimizer 157 | 158 | self.z_dim = z_dim 159 | self.num_classes = num_classes 160 | self.loss = nn.BCELoss() 161 | 162 | def on_pretrain_routine_start(self, trainer, pl_module): 163 | from pl_bolts.models.self_supervised.evaluator import SSLEvaluator 164 | pl_module.non_linear_evaluator = SSLEvaluator( 165 | n_input=self.z_dim, n_classes=self.num_classes, p=self.drop_p).to(pl_module.device) 166 | self.optimizer = torch.optim.SGD( 167 | pl_module.non_linear_evaluator.parameters(), lr=0.005) 168 | 169 | def get_representations(self, pl_module, x): 170 | x = x.squeeze() 171 | representations, _ = pl_module(x) 172 | return representations 173 | 174 | def to_device(self, batch, device): 175 | 176 | x_i_experts = batch["x_i_experts"] 177 | label = batch["label"] 178 | 179 | x_i_experts = [torch.cat(x, dim=-1) for x in x_i_experts] 180 | x_i_input = torch.stack(x_i_experts) 181 | labels = torch.stack(label) 182 | 183 | x_i_input = x_i_input.to(device) 184 | labels = labels.to(device) 185 | 186 | return x_i_input, labels 187 | 188 | def on_train_batch_end( 189 | self, 190 | trainer, 191 | pl_module, 192 | outputs, 193 | batch, 194 | batch_idx, 195 | data_loader_idx 196 | ): 197 | 198 | x, labels = self.to_device(batch, pl_module.device) 199 | 200 | with torch.no_grad(): 201 | representations = self.get_representations(pl_module, x) 202 | 203 | representations = representations.detach() 204 | logits = pl_module.non_linear_evaluator(representations) 205 | logits = torch.sigmoid(logits) 206 | 207 | # pl_module.running_logits.append(logits) 208 | # pl_module.running_labels.append(labels) 209 | 210 | mlp_loss = self.loss(logits, labels) 211 | pl_module.log("train/online/loss", mlp_loss) 212 | mlp_loss.backward() 213 | self.optimizer.step() 214 | 215 | def on_validation_batch_end( 216 | self, 
217 | trainer, 218 | pl_module, 219 | outputs, 220 | batch, 221 | batch_idx, 222 | data_loader_idx 223 | ): 224 | 225 | x, labels = self.to_device(batch, pl_module.device) 226 | 227 | with torch.no_grad(): 228 | representations = self.get_representations(pl_module, x) 229 | 230 | representations = representations.detach() 231 | 232 | logits = pl_module.non_linear_evaluator(representations) 233 | logits = torch.sigmoid(logits) 234 | 235 | mlp_loss = self.loss(logits, labels) 236 | pl_module.log("val/online/loss", mlp_loss) 237 | logits = logits.cpu() 238 | labels = labels.cpu() 239 | 240 | pl_module.running_logits.append(logits) 241 | pl_module.running_labels.append(labels) 242 | 243 | def on_validation_epoch_end(self, trainer, pl_module): 244 | self.on_shared_end(pl_module, "val") 245 | 246 | # def on_train_epoch_end(self, trainer, pl_module): 247 | # self.on_shared_end(pl_module, "train") 248 | 249 | def on_shared_end(self, pl_module, state): 250 | 251 | target_names = ['Action', 'Adventure', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Family', 252 | 'Fantasy', 'History', 'Horror', 'Music', 'Mystery', 'Science Fiction', 'Thriller', 'War'] 253 | running_labels = torch.cat(pl_module.running_labels) 254 | running_logits = torch.cat(pl_module.running_logits) 255 | # running_logits = F.sigmoid(running_logits) 256 | thresholds = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5] 257 | for t in thresholds: 258 | accuracy = f1_score(running_labels.to(int), (running_logits > t).to( 259 | int), average="weighted", zero_division=1) 260 | recall = recall_score(running_labels.to(int), (running_logits > t).to( 261 | int), average="weighted", zero_division=1) 262 | precision = precision_score(running_labels.to( 263 | int), (running_logits > t).to(int), average="weighted", zero_division=1) 264 | avg_precision = average_precision_score(running_labels.to( 265 | int), (running_logits > t).to(int), average="weighted") 266 | 267 | pl_module.log(f"{state}/online/f1@{str(t)}", 268 | accuracy, on_epoch=True) 269 | pl_module.log(f"{state}/online/recall@{str(t)}", 270 | recall, on_epoch=True) 271 | pl_module.log(f"{state}/online/precision@{str(t)}", 272 | precision, on_epoch=True) 273 | pl_module.log(f"{state}/online/avg_precision@{str(t)}", 274 | avg_precision, on_epoch=True) 275 | 276 | running_labels = running_labels.to(int).numpy() 277 | running_logits = (running_logits > 0.3).to(int).numpy() 278 | 279 | pl_module.running_labels = [] 280 | pl_module.running_logits = [] 281 | 282 | label_str = [] 283 | target_str = [] 284 | 285 | test_table = wandb.Table(columns=["truth", "guess"]) 286 | 287 | for i in range(0, 20): 288 | test_table.add_data(self.translate_labels( 289 | running_labels[i]), self.translate_labels(running_logits[i])) 290 | 291 | pl_module.logger.experiment.log({"table": test_table}) 292 | 293 | def translate_labels(self, label_vec): 294 | target_names = ['Action', 'Adventure', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Family', 295 | 'Fantasy', 'History', 'Horror', 'Music', 'Mystery', 'Science Fiction', 'Thriller', 'War'] 296 | labels = [] 297 | for i, l in enumerate(label_vec): 298 | if l: 299 | labels.append(target_names[i]) 300 | return labels 301 | -------------------------------------------------------------------------------- /src/config.yaml: -------------------------------------------------------------------------------- 1 | # General params 2 | batch_size: 2 3 | learning_rate: 0.000005 4 | epochs: 500 5 | seq_len: 13 6 | frame_len: 12 7 | test: False 8 | 9 | # Optimisation 10 | dropout: 0.5 11 | momentum: 
0.005 12 | weight_decay: 0.09 13 | scheduling: True 14 | warm_up: 2 15 | n_classes: 15 16 | opt: "adamW" 17 | 18 | # num_samples: 50000 19 | # Architecure optimisation 20 | 21 | input_dimension: 2048 22 | nhead: 8 23 | token_embedding: 305 24 | nlayers: 8 25 | nhid: 2048 26 | projection_size: 305 27 | data_set: "mmx-frame" 28 | 29 | # double_transformer, single_transformer, lstm, frame_transformer, sum, frame, vid, pre_modal, sum_residual 30 | 31 | model: "vid" 32 | logger: "double_transformer" 33 | name: "mmx-frame-test" 34 | 35 | #experts: ["test-video-embeddings", "test-location-embeddings", "test-img-embeddings", "audio-embeddings"] 36 | experts: ["img-embeddings", "location-embeddings", "video-embeddings"] 37 | # experts: ["location-embeddings", "img-embeddings", "video-embeddings", "audio-embeddings"] 38 | # pool or None 39 | cls: 1 40 | 41 | # Multi modal settings 42 | mixing_method: "double_trans" 43 | 44 | device: 1 45 | save_path: "/trained_models/mit/transformer/" 46 | -------------------------------------------------------------------------------- /src/data_processing/contrastive/create_mit_contrastive.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import random 3 | import os 4 | import pandas as pd 5 | import glob 6 | import tqdm 7 | import multiprocessing as mp 8 | import numpy as np 9 | import torch 10 | import pickle 11 | import resource 12 | 13 | 14 | def create_dictionary(filepath): 15 | experts = ["audio-embeddings", "location-embeddings", "img-embeddings", "video-embeddings"] 16 | 17 | #orig_dir = filepath.replace("/mnt/bigelow/scratch/mmx_aug", "/mnt/fvpbignas/datasets/mmx_raw") 18 | 19 | label = filepath.split("/")[-3] 20 | 21 | #dirs = os.listdir(orig_dir) 22 | #meta_data = os.path.join(orig_dir, "meta.pkl") 23 | #with open(meta_data, "rb") as pickly: 24 | #label = os.path.basename(filepath) 25 | chunk_dict = dict() 26 | for chunk in glob.glob(filepath + "/*/"): 27 | expert_dict = {} 28 | for expert_dir in experts: 29 | tens_dir = os.path.join(chunk, expert_dir) 30 | try: 31 | if len(os.listdir(tens_dir)) > 1: 32 | tensor_list = [] 33 | for tensor in glob.glob(tens_dir + "/*.pt"): 34 | # tensor_list.append(torch.load(tensor, map_location="cpu")) 35 | tensor_list.append(tensor) 36 | expert_dict[expert_dir] = tensor_list 37 | elif len(os.listdir(tens_dir)) == 1: 38 | expert_tensor = glob.glob(tens_dir + "/*.pt") 39 | #expert_tensor = torch.load(expert_tensor[0], map_location="cpu") 40 | expert_dict[expert_dir] = expert_tensor 41 | #expert_list.append(expert_tensor[0]) 42 | else: 43 | # No audio embedding available 44 | continue 45 | except: 46 | continue 47 | chunk_str = chunk.split("/")[-2] 48 | chunk_dict[os.path.basename(chunk_str)] = expert_dict 49 | master_dict = {"path":filepath, "label":label, "data":chunk_dict} 50 | return master_dict 51 | 52 | 53 | def squish_folders(input_dir): 54 | all_files = [] 55 | for labels in glob.glob(input_dir + "/*/"): 56 | for video in glob.glob(labels + "/*/"): 57 | all_files.append(video) 58 | print(labels, "complete") 59 | print("length of files", len(all_files)) 60 | with open("mit_train_cache.pkl", "wb") as cache: 61 | pickle.dump(all_files, cache) 62 | 63 | 64 | def mp_handler(): 65 | p = mp.Pool(40) 66 | data_list = [] 67 | count = 0 68 | 69 | squish_folders("/mnt/fvpbignas/datasets/moments_in_time/Moments_in_Time_Aug/training") 70 | with open("mit_train_cache.pkl", 'rb') as cache: 71 | data = pickle.load(cache) 72 | random.shuffle(data) 73 | 74 | with 
open("mit_tensors_train.pkl", 'ab') as pkly: 75 | for result in p.imap(create_dictionary, tqdm.tqdm(data, total=len(data))): 76 | if result: 77 | pickle.dump(result, pkly) 78 | 79 | 80 | if __name__ == "__main__": 81 | torch.multiprocessing.set_sharing_strategy('file_system') 82 | # rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 83 | # resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) 84 | mp_handler() 85 | # squish_folders(input_dir) 86 | -------------------------------------------------------------------------------- /src/data_processing/contrastive/create_mmx_contrastive.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import random 3 | import os 4 | import pandas as pd 5 | import glob 6 | import tqdm 7 | import multiprocessing as mp 8 | import numpy as np 9 | import torch 10 | import pickle 11 | import resource 12 | 13 | 14 | input_dir = "/mnt/bigelow/scratch/mmx_aug/" 15 | 16 | def create_embedding_dict(filepath): 17 | genre_name = filepath.split("/")[-3:-1] 18 | orig_dir = filepath.replace("/mnt/bigelow/scratch/mmx_aug", "/mnt/fvpbignas/datasets/mmx_raw") 19 | print(orig_dir) 20 | dirs = os.listdir(orig_dir) 21 | meta_data = os.path.join(orig_dir, dirs[0], "meta.pkl") 22 | with open(meta_data, "rb") as pickly: 23 | label = pickle.load(pickly) 24 | 25 | # subdirs = [000,001,002] 26 | scenes = glob.glob(filepath + "/*/") 27 | 28 | # if len(subdirs) < 2: 29 | # return False 30 | 31 | experts = ["location-embeddings", "img-embeddings", "video-embeddings", "audio-embeddings"] 32 | out_dict = dict() 33 | scene_dict = dict() 34 | 35 | for scene in scenes: 36 | # chunks is a list of filepaths [001/001, 001/002] 37 | chunks = glob.glob(scene + "/*/") 38 | chunk_dict = dict() 39 | 40 | for chunk in chunks: 41 | expert_dic = dict() 42 | for expert_dir in experts: 43 | tens_dir = os.path.join(chunk, expert_dir) 44 | if len(os.listdir(tens_dir)) > 1: 45 | tensor_list = [] 46 | for tensor in glob.glob(tens_dir + "/*.pt"): 47 | t = torch.load(tensor, map_location="cpu") 48 | t = t.cpu().detach().numpy() 49 | tensor_list.append(t) 50 | expert_list.append(tensor_list) 51 | elif len(os.listdir(tens_dir)) == 1: 52 | expert_tensor = glob.glob(tens_dir + "/*.pt") 53 | expert_tensor = torch.load(expert_tensor[0], map_location="cpu") 54 | expert_tensor = expert_tensor.cpu().detach().numpy() 55 | expert_dict[expert_dir] = expert_tensor 56 | else: 57 | print("no audio") 58 | # No audio embedding available 59 | continue 60 | chunk_str = chunk.split("/")[-2] 61 | chunk_dict[os.path.basename(chunk_str)] = expert_dict 62 | 63 | scene_str = scene.split("/")[-2] 64 | scene_dict[os.path.basename(scene_str)] = chunk_dict 65 | out_dict = {"label": label, "name": name, "scenes": scene_dict} 66 | return out_dict 67 | 68 | 69 | def create_scene_dict_train(filepath): 70 | experts = ["location-embeddings", "img-embeddings", "video-embeddings", "audio-embeddings"] 71 | 72 | orig_dir = filepath.replace("/mnt/bigelow/scratch/mmx_aug", "/mnt/fvpbignas/datasets/mmx_raw") 73 | 74 | scene = orig_dir.split("/")[-2] 75 | 76 | dirs = os.listdir(orig_dir) 77 | meta_data = os.path.join(orig_dir, "meta.pkl") 78 | with open(meta_data, "rb") as pickly: 79 | label = pickle.load(pickly) 80 | chunk_dict = dict() 81 | for chunk in glob.glob(filepath + "/*/"): 82 | expert_dict = dict() 83 | for expert_dir in experts: 84 | tens_dir = os.path.join(chunk, expert_dir) 85 | try: 86 | if len(os.listdir(tens_dir)) > 1: 87 | tensor_list = [] 88 | for tensor in glob.glob(tens_dir 
+ "/*.pt"): 89 | # tensor_list.append(torch.load(tensor, map_location="cpu")) 90 | tensor_list.append(tensor) 91 | expert_dict[expert_dir] = tensor_list 92 | elif len(os.listdir(tens_dir)) == 1: 93 | expert_tensor = glob.glob(tens_dir + "/*.pt") 94 | #expert_tensor = torch.load(expert_tensor[0], map_location="cpu") 95 | expert_dict[expert_dir] = expert_tensor[0] 96 | else: 97 | # No audio embedding available 98 | continue 99 | except: 100 | continue 101 | chunk_str = chunk.split("/")[-2] 102 | chunk_dict[os.path.basename(chunk_str)] = expert_dict 103 | scene_dict = {"path":orig_dir, "scene":scene, "label":label, "data":chunk_dict} 104 | return scene_dict 105 | 106 | def create_scene_dict_test(filepath): 107 | experts = ["test-location-embeddings", "test-img-embeddings", "test-video-embeddings", "audio-embeddings"] 108 | 109 | orig_dir = filepath.replace("/mnt/bigelow/scratch/mmx_aug", "/mnt/fvpbignas/datasets/mmx_raw") 110 | 111 | scene = orig_dir.split("/")[-2] 112 | 113 | dirs = os.listdir(orig_dir) 114 | meta_data = os.path.join(orig_dir, "meta.pkl") 115 | with open(meta_data, "rb") as pickly: 116 | label = pickle.load(pickly) 117 | chunk_dict = dict() 118 | for chunk in glob.glob(filepath + "/*/"): 119 | expert_dict = dict() 120 | for expert_dir in experts: 121 | tens_dir = os.path.join(chunk, expert_dir) 122 | try: 123 | if len(os.listdir(tens_dir)) > 1: 124 | tensor_list = [] 125 | for tensor in glob.glob(tens_dir + "/*.pt"): 126 | # tensor_list.append(torch.load(tensor, map_location="cpu")) 127 | tensor_list.append(tensor) 128 | expert_dict[expert_dir] = tensor_list 129 | elif len(os.listdir(tens_dir)) == 1: 130 | expert_tensor = glob.glob(tens_dir + "/*.pt") 131 | #expert_tensor = torch.load(expert_tensor[0], map_location="cpu") 132 | expert_dict[expert_dir] = expert_tensor[0] 133 | else: 134 | # No audio embedding available 135 | continue 136 | except: 137 | continue 138 | chunk_str = chunk.split("/")[-2] 139 | chunk_dict[os.path.basename(chunk_str)] = expert_dict 140 | scene_dict = {"path":orig_dir, "scene":scene, "label":label, "data":chunk_dict} 141 | return scene_dict 142 | 143 | 144 | def squish_folders(input_dir): 145 | all_files = [] 146 | for genres in tqdm.tqdm(glob.glob(input_dir + "/*/")): 147 | for movies in os.listdir(genres): 148 | path = os.path.join(genres, movies) 149 | for scene in glob.glob(path + "/*/"): 150 | all_files.append(scene) 151 | print("length of files", len(all_files)) 152 | with open("cache.pkl", "wb") as cache: 153 | pickle.dump(all_files, cache) 154 | 155 | 156 | def mp_handler(): 157 | p = mp.Pool(30) 158 | data_list = [] 159 | count = 0 160 | with open("cache.pkl", 'rb') as cache: 161 | data = pickle.load(cache) 162 | random.shuffle(data) 163 | train_data = data[:int((len(data)+1)*.90)] #Remaining 80% to training set 164 | test_data = data[int((len(data)+1)*.90):] #Splits 20% data to test set 165 | print("training_data", len(train_data)) 166 | print("testing_data", len(test_data)) 167 | big_list = [] 168 | 169 | 170 | with open("mmx_tensors_train.pkl", 'ab') as pkly: 171 | for result in p.imap(create_scene_dict_train, tqdm.tqdm(train_data, total=len(train_data))): 172 | if result: 173 | pickle.dump(result, pkly) 174 | 175 | with open("mmx_tensors_val.pkl", 'ab') as pkly: 176 | for result in p.imap(create_scene_dict_test, tqdm.tqdm(test_data, total=len(test_data))): 177 | if result: 178 | pickle.dump(result, pkly) 179 | #pickle.dump(result, pkly) 180 | # data_list.append(result) 181 | # if len(data_list) + count == 50000: 182 | # with 
open("master_tensors1.pkl", "wb") as csv_file: 183 | # pickle.dump(data_list, csv_file) 184 | # print("dumped 500000") 185 | # data_list = [] 186 | # count = 50001 187 | 188 | # with open("mmx_tensors_test.pkl", 'ab') as pkly: 189 | # for result in p.imap(create_scene_dict, tqdm.tqdm(test_data, total=len(working_dirs))): 190 | # if result: 191 | # pickle.dump(result, pkly) 192 | # data_list.append(result) 193 | # if len(data_list) + count == 50000: 194 | # with open("master_tensors1.pkl", "wb") as csv_file: 195 | # pickle.dump(data_list, csv_file) 196 | # print("dumped 500000") 197 | # data_list = [] 198 | # count = 50001 199 | 200 | if __name__ == "__main__": 201 | torch.multiprocessing.set_sharing_strategy('file_system') 202 | # rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 203 | # resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) 204 | 205 | squish_folders(input_dir) 206 | mp_handler() 207 | -------------------------------------------------------------------------------- /src/data_processing/labels/moments_categories.csv: -------------------------------------------------------------------------------- 1 | label,id 2 | clapping,0 3 | dropping,1 4 | burying,2 5 | covering,3 6 | flooding,4 7 | leaping,5 8 | drinking,6 9 | raining,7 10 | stitching,8 11 | spraying,9 12 | twisting,10 13 | coaching,11 14 | submerging,12 15 | breaking,13 16 | boarding,14 17 | running,15 18 | destroying,16 19 | competing,17 20 | giggling,18 21 | shoveling,19 22 | chasing,20 23 | flicking,21 24 | pouring,22 25 | hammering,23 26 | carrying,24 27 | surfing,25 28 | pulling,26 29 | squatting,27 30 | crouching,28 31 | tapping,29 32 | skipping,30 33 | washing,31 34 | winking,32 35 | queuing,33 36 | locking,34 37 | stopping,35 38 | sneezing,36 39 | flipping,37 40 | sewing,38 41 | clipping,39 42 | working,40 43 | rocking,41 44 | asking,42 45 | playing+fun,43 46 | camping,44 47 | plugging,45 48 | pedaling,46 49 | constructing,47 50 | slipping,48 51 | sweeping,49 52 | screwing,50 53 | shrugging,51 54 | hitchhiking,52 55 | cracking,53 56 | scratching,54 57 | trimming,55 58 | selling,56 59 | stirring,57 60 | jumping,58 61 | starting,59 62 | clinging,60 63 | socializing,61 64 | picking,62 65 | splashing,63 66 | licking,64 67 | kicking,65 68 | sliding,66 69 | filming,67 70 | driving,68 71 | handwriting,69 72 | steering,70 73 | filling,71 74 | pressing,72 75 | shouting,73 76 | hiking,74 77 | vacuuming,75 78 | pointing,76 79 | giving,77 80 | diving,78 81 | hugging,79 82 | building,80 83 | dining,81 84 | floating,82 85 | leaning,83 86 | sailing,84 87 | singing,85 88 | playing,86 89 | bubbling,87 90 | joining,88 91 | raising,89 92 | sitting,90 93 | drawing,91 94 | rinsing,92 95 | coughing,93 96 | slicing,94 97 | balancing,95 98 | rafting,96 99 | kneeling,97 100 | dunking,98 101 | brushing,99 102 | crushing,100 103 | watering,101 104 | playing+music,102 105 | removing,103 106 | tearing,104 107 | imitating,105 108 | teaching,106 109 | cooking,107 110 | reaching,108 111 | studying,109 112 | serving,110 113 | bulldozing,111 114 | shaking,112 115 | discussing,113 116 | dragging,114 117 | gardening,115 118 | performing,116 119 | officiating,117 120 | photographing,118 121 | sowing,119 122 | dripping,120 123 | writing,121 124 | clawing,122 125 | bending,123 126 | boxing,124 127 | mopping,125 128 | gripping,126 129 | flowing,127 130 | digging,128 131 | tripping,129 132 | cheering,130 133 | buying,131 134 | bicycling,132 135 | feeding,133 136 | emptying,134 137 | unpacking,135 138 | sketching,136 139 | standing,137 140 | 
weeding,138 141 | stacking,139 142 | drying,140 143 | crying,141 144 | spinning,142 145 | frying,143 146 | cutting,144 147 | paying,145 148 | eating,146 149 | lecturing,147 150 | dancing,148 151 | adult+female+speaking,149 152 | boiling,150 153 | peeling,151 154 | wrapping,152 155 | wetting,153 156 | welding,154 157 | putting,155 158 | swinging,156 159 | carving,157 160 | walking,158 161 | inflating,159 162 | climbing,160 163 | shredding,161 164 | reading,162 165 | sanding,163 166 | frowning,164 167 | closing,165 168 | hunting,166 169 | clearing,167 170 | launching,168 171 | packaging,169 172 | fishing,170 173 | spilling,171 174 | leaking,172 175 | knitting,173 176 | boating,174 177 | sprinkling,175 178 | playing+sports,176 179 | rolling,177 180 | spitting,178 181 | dipping,179 182 | riding,180 183 | chopping,181 184 | extinguishing,182 185 | applauding,183 186 | calling,184 187 | talking,185 188 | adult+male+speaking,186 189 | snowing,187 190 | shaving,188 191 | marrying,189 192 | rising,190 193 | laughing,191 194 | crawling,192 195 | flying,193 196 | assembling,194 197 | injecting,195 198 | landing,196 199 | operating,197 200 | packing,198 201 | descending,199 202 | falling,200 203 | entering,201 204 | pushing,202 205 | sawing,203 206 | smelling,204 207 | overflowing,205 208 | waking,206 209 | barbecuing,207 210 | skating,208 211 | painting,209 212 | drilling,210 213 | tying,211 214 | manicuring,212 215 | plunging,213 216 | grilling,214 217 | pitching,215 218 | towing,216 219 | telephoning,217 220 | crafting,218 221 | knocking,219 222 | playing+videogames,220 223 | storming,221 224 | placing,222 225 | turning,223 226 | barking,224 227 | child+singing,225 228 | opening,226 229 | juggling,227 230 | mowing,228 231 | sniffing,229 232 | interviewing,230 233 | stomping,231 234 | chewing,232 235 | grooming,233 236 | rowing,234 237 | bowing,235 238 | gambling,236 239 | saluting,237 240 | fueling,238 241 | autographing,239 242 | throwing,240 243 | drenching,241 244 | waving,242 245 | signing,243 246 | repairing,244 247 | baking,245 248 | smoking,246 249 | skiing,247 250 | drumming,248 251 | child+speaking,249 252 | blowing,250 253 | cleaning,251 254 | combing,252 255 | spreading,253 256 | racing,254 257 | combusting,255 258 | adult+female+singing,256 259 | swimming,257 260 | adult+male+singing,258 261 | shopping,259 262 | bouncing,260 263 | dusting,261 264 | stroking,262 265 | snapping,263 266 | biting,264 267 | roaring,265 268 | guarding,266 269 | unloading,267 270 | lifting,268 271 | instructing,269 272 | folding,270 273 | measuring,271 274 | whistling,272 275 | exiting,273 276 | stretching,274 277 | taping,275 278 | squinting,276 279 | catching,277 280 | draining,278 281 | scrubbing,279 282 | celebrating,280 283 | jogging,281 284 | bowling,282 285 | resting,283 286 | blocking,284 287 | smiling,285 288 | tattooing,286 289 | erupting,287 290 | howling,288 291 | grinning,289 292 | sprinting,290 293 | hanging,291 294 | planting,292 295 | speaking,293 296 | ascending,294 297 | yawning,295 298 | cramming,296 299 | burning,297 300 | wrestling,298 301 | poking,299 302 | tickling,300 303 | exercising,301 304 | loading,302 305 | piloting,303 306 | typing,304 307 | -------------------------------------------------------------------------------- /src/data_processing/temporal/create_mit_temporal.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import random 3 | import os 4 | import pandas as pd 5 | import glob 6 | import tqdm 7 | import multiprocessing 
as mp 8 | import numpy as np 9 | import torch 10 | import pickle 11 | import resource 12 | 13 | 14 | def load_labels(label_root): 15 | label_df = pd.read_csv(label_root) 16 | label_df.set_index('label', inplace=True) 17 | print("len of labels = ", len(label_df)) 18 | return label_df 19 | 20 | 21 | def collect_labels(label_df, label): 22 | index = label_df.loc[label]["id"] 23 | return index 24 | 25 | 26 | def create_dictionary(filepath): 27 | experts = ["audio-embeddings", "location-embeddings", 28 | "img-embeddings", "video-embeddings", "test-location-embeddings", "test-img-embeddings", "test-video-embeddings"] 29 | 30 | #orig_dir = filepath.replace("/mnt/bigelow/scratch/mmx_aug", "/mnt/fvpbignas/datasets/mmx_raw") 31 | 32 | label = filepath.split("/")[-3] 33 | label = collect_labels(label_df, label) 34 | 35 | #dirs = os.listdir(orig_dir) 36 | #meta_data = os.path.join(orig_dir, "meta.pkl") 37 | # with open(meta_data, "rb") as pickly: 38 | #label = os.path.basename(filepath) 39 | chunk_dict = dict() 40 | for chunk in glob.glob(filepath + "/*/"): 41 | expert_dict = {} 42 | for expert_dir in experts: 43 | tens_dir = os.path.join(chunk, expert_dir) 44 | try: 45 | if len(os.listdir(tens_dir)) > 1: 46 | tensor_list = [] 47 | for tensor in glob.glob(tens_dir + "/*.pt"): 48 | # tensor_list.append(torch.load(tensor, map_location="cpu")) 49 | tensor_list.append(tensor) 50 | expert_dict[expert_dir] = tensor_list 51 | elif len(os.listdir(tens_dir)) == 1: 52 | expert_tensor = glob.glob(tens_dir + "/*.pt") 53 | #expert_tensor = torch.load(expert_tensor[0], map_location="cpu") 54 | expert_dict[expert_dir] = expert_tensor 55 | # expert_list.append(expert_tensor[0]) 56 | else: 57 | # No audio embedding available 58 | continue 59 | except: 60 | continue 61 | chunk_str = chunk.split("/")[-2] 62 | chunk_dict[os.path.basename(chunk_str)] = expert_dict 63 | master_dict = {"path": filepath, "label": label, "data": chunk_dict} 64 | return master_dict 65 | 66 | 67 | def squish_folders(input_dir): 68 | all_files = [] 69 | for labels in glob.glob(input_dir + "/*/"): 70 | for video in glob.glob(labels + "/*/"): 71 | all_files.append(video) 72 | print(labels, "complete") 73 | print("length of files", len(all_files)) 74 | with open("MIT_validation_cache.pkl", "wb") as cache: 75 | pickle.dump(all_files, cache) 76 | 77 | 78 | def mp_handler(): 79 | p = mp.Pool(40) 80 | 81 | squish_folders("/mnt/bigelow/scratch/mit_no_crop/validation") 82 | with open("MIT_validation_cache.pkl", 'rb') as cache: 83 | data = pickle.load(cache) 84 | # random.shuffle(data) 85 | 86 | with open("MIT_validation_temporal.pkl", 'ab') as pkly: 87 | for result in p.imap(create_dictionary, tqdm.tqdm(data, total=len(data))): 88 | if result: 89 | pickle.dump(result, pkly) 90 | 91 | 92 | if __name__ == "__main__": 93 | 94 | label_df = load_labels( 95 | "/home/ed/self-supervised-video/data_processing/moments_categories.csv") 96 | torch.multiprocessing.set_sharing_strategy('file_system') 97 | # rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 98 | # resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) 99 | mp_handler() 100 | # squish_folders(input_dir) 101 | -------------------------------------------------------------------------------- /src/data_processing/temporal/create_mmx_frames.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import pickle 3 | import glob 4 | import os 5 | import re 6 | import multiprocessing as mp 7 | import numpy as np 8 | import random 9 | 10 | 11 | from 
collections import OrderedDict 12 | 13 | # first create a list of all the filepaths to go through and collect - do only on one thread. 14 | 15 | def collect_labels(label): 16 | target_names = ['Action', 'Adventure', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Family', 17 | 'Fantasy', 'History', 'Horror', 'Music', 'Mystery', 'Science Fiction', 'Thriller', 'War'] 18 | label_list = np.zeros(15) 19 | 20 | for i, genre in enumerate(target_names): 21 | if genre == "Sci-Fi" or genre == "ScienceFiction": 22 | genre = "Science Fiction" 23 | if genre in label: 24 | label_list[i] = 1 25 | if np.sum(label_list) == 0: 26 | label_list[5] = 1 27 | 28 | return label_list 29 | 30 | def label_tidy(label): 31 | if len(label) == 2: 32 | return collect_labels(label[0]) 33 | else: 34 | return collect_labels(label) 35 | 36 | def squish_folders(input_dir): 37 | all_files = [] 38 | for genres in tqdm.tqdm(glob.glob(input_dir + "/*/")): 39 | for movies in os.listdir(genres): 40 | path = os.path.join(genres, movies) 41 | # use movies rather than individual scenes 42 | all_files.append(path) 43 | print("length of files", len(all_files)) 44 | with open("cache.pkl", "wb") as cache: 45 | pickle.dump(all_files, cache) 46 | 47 | def create_frame_path_dict(filepath): 48 | genre_name = filepath.split("/")[-3:-1] 49 | orig_dir = filepath.replace( 50 | "/mnt/bigelow/scratch/mmx_aug", "/mnt/fvpbignas/datasets/mmx_raw") 51 | 52 | if not os.path.exists(orig_dir): 53 | return False 54 | 55 | dirs = os.listdir(orig_dir) 56 | meta_data = os.path.join(orig_dir, dirs[0], "meta.pkl") 57 | 58 | with open(meta_data, "rb") as pickly: 59 | label = pickle.load(pickly) 60 | year = label[1] 61 | label = label[0] 62 | label = label_tidy(label) 63 | 64 | 65 | # subdirs = [000,001,002] 66 | scenes = glob.glob(filepath + "/*/") 67 | 68 | # this ensures all scenes are in order - some are "000" and "1" so regex checks this. 
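# (illustrative example, not repo data) scene folders named ["10", "000", "2"] sort numerically
# to ["000", "2", "10"] with the key below, whereas a plain string sort would give ["000", "10", "2"].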
69 | scenes = sorted(scenes, key=lambda x: ( 70 | int(re.findall("[0-9]+", x.split("/")[-2])[0]))) 71 | if len(scenes) < 1: 72 | return False 73 | 74 | out_dict = OrderedDict() 75 | scene_dict = OrderedDict() 76 | 77 | for i, scene in enumerate(scenes): 78 | # chunks is a list of filepaths [001/001, 001/002] 79 | clips = glob.glob(scene + "/*/") 80 | if len(clips) < 1: 81 | continue 82 | 83 | clips = sorted(clips, key=lambda x: ( 84 | int(re.findall("[0-9]+", x.split("/")[-2])[0]))) 85 | clip_dict = OrderedDict() 86 | for j, clip in enumerate(clips): 87 | img_list = [] 88 | img_dir = os.path.join(clip, "imgs") 89 | img_paths = glob.glob(img_dir + "/*") 90 | if len(img_paths) < 10: 91 | continue 92 | img_paths = sorted(img_paths, key=lambda i: int(os.path.splitext(os.path.basename(i))[0])) 93 | while len(img_paths) < 16: 94 | img_paths.append(img_paths[-1]) 95 | clip_dict[j] = img_paths 96 | scene_dict[i] = clip_dict 97 | out_dict = {"label": label,"year":year, "path":filepath, "scenes":scene_dict} 98 | return out_dict 99 | 100 | def mp_handler(): 101 | p = mp.Pool(40) 102 | data_list = [] 103 | 104 | with open("cache.pkl", 'rb') as cache: 105 | data = pickle.load(cache) 106 | random.shuffle(data) 107 | # Remaining 80% to training set 108 | train_data = data[:int((len(data)+1)*.90)] 109 | # Splits 20% data to test set 110 | test_data = data[int((len(data)+1)*.90):] 111 | print("training_data", len(train_data)) 112 | print("testing_data", len(test_data)) 113 | 114 | #squish_folders(input_dir) 115 | 116 | with open("mmx_train_temporal.pkl", 'ab') as pkly: 117 | for result in p.imap(create_frame_path_dict, tqdm.tqdm(train_data, total=len(train_data))): 118 | if result: 119 | pickle.dump(result, pkly) 120 | 121 | with open("mmx_val_temporal.pkl", 'ab') as pkly: 122 | for result in p.imap(create_frame_path_dict, tqdm.tqdm(test_data, total=len(test_data))): 123 | if result: 124 | pickle.dump(result, pkly) 125 | 126 | if __name__ == "__main__": 127 | #rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 128 | #resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) 129 | input_dir = "/mnt/bigelow/scratch/mmx_aug/" 130 | 131 | squish_folders(input_dir) 132 | mp_handler() 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /src/data_processing/temporal/create_mmx_temporal.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import random 3 | import os 4 | import pandas as pd 5 | import glob 6 | import tqdm 7 | import multiprocessing as mp 8 | import numpy as np 9 | import torch 10 | import pickle 11 | import re 12 | import resource 13 | 14 | from collections import OrderedDict 15 | 16 | 17 | input_dir = "/mnt/bigelow/scratch/mmx_aug/" 18 | 19 | 20 | def create_embedding_dict(filepath): 21 | genre_name = filepath.split("/")[-3:-1] 22 | orig_dir = filepath.replace( 23 | "/mnt/bigelow/scratch/mmx_aug", "/mnt/fvpbignas/datasets/mmx_raw") 24 | 25 | dirs = os.listdir(orig_dir) 26 | meta_data = os.path.join(orig_dir, dirs[0], "meta.pkl") 27 | with open(meta_data, "rb") as pickly: 28 | label = pickle.load(pickly) 29 | 30 | # subdirs = [000,001,002] 31 | scenes = glob.glob(filepath + "/*/") 32 | scenes = sorted(scenes, key=lambda x: ( 33 | int(re.findall("[0-9]+", x.split("/")[-2])[0]))) 34 | # if len(subdirs) < 2: 35 | # return False 36 | 37 | experts = ["img-embeddings", "location-embeddings", "motion-embeddings", 38 | "test-location-embeddings", "test-img-embeddings", "test-video-embeddings", 
"audio-embeddings"] 39 | out_dict = OrderedDict() 40 | scene_dict = OrderedDict() 41 | 42 | for scene in scenes: 43 | # chunks is a list of filepaths [001/001, 001/002] 44 | chunks = glob.glob(scene + "/*/") 45 | chunk_dict = OrderedDict() 46 | 47 | for chunk in chunks: 48 | #expert_list = [] 49 | expert_dict = OrderedDict() 50 | for expert_dir in experts: 51 | tens_dir = os.path.join(chunk, expert_dir) 52 | try: 53 | if len(os.listdir(tens_dir)) > 1: 54 | tensor_list = [] 55 | for tensor in glob.glob(tens_dir + "/*.pt"): 56 | #t = torch.load(tensor, map_location="cpu") 57 | #t = t.cpu().detach().numpy() 58 | tensor_list.append(tensor) 59 | expert_dict[expert_dir] = tensor_list 60 | # expert_list.append(tensor_list) 61 | elif len(os.listdir(tens_dir)) == 1: 62 | 63 | expert_tensor = glob.glob(tens_dir + "/*.pt") 64 | #expert_tensor = torch.load(expert_tensor[0], map_location="cpu") 65 | #expert_tensor = expert_tensor.cpu().detach().numpy() 66 | # expert_list.append(expert_tensor[0]) 67 | expert_dict[expert_dir] = expert_tensor[0] 68 | else: 69 | # print("no audio") 70 | # No audio embedding available 71 | continue 72 | except FileNotFoundError: 73 | continue 74 | chunk_str = chunk.split("/")[-2] 75 | # chunk_dict[os.path.basename(chunk_str)] = expert_list 76 | chunk_dict[os.path.basename(chunk_str)] = expert_dict 77 | 78 | scene_str = scene.split("/")[-2] 79 | scene_dict[os.path.basename(scene_str)] = chunk_dict 80 | out_dict = {"label": label, "path": orig_dir, "scenes": scene_dict} 81 | return out_dict 82 | 83 | 84 | def create_scene_dict_train(filepath): 85 | experts = ["location-embeddings", "img-embeddings", "video-embeddings"] 86 | 87 | orig_dir = filepath.replace( 88 | "/mnt/bigelow/scratch/mmx_aug", "/mnt/fvpbignas/datasets/mmx_raw") 89 | 90 | scene = orig_dir.split("/")[-2] 91 | 92 | dirs = os.listdir(orig_dir) 93 | meta_data = os.path.join(orig_dir, "meta.pkl") 94 | with open(meta_data, "rb") as pickly: 95 | label = pickle.load(pickly) 96 | chunk_dict = dict() 97 | for chunk in glob.glob(filepath + "/*/"): 98 | expert_list = [] 99 | for expert_dir in experts: 100 | tens_dir = os.path.join(chunk, expert_dir) 101 | try: 102 | if len(os.listdir(tens_dir)) > 1: 103 | tensor_list = [] 104 | for tensor in glob.glob(tens_dir + "/*.pt"): 105 | # tensor_list.append(torch.load(tensor, map_location="cpu")) 106 | tensor_list.append(tensor) 107 | expert_list.append(tensor_list) 108 | elif len(os.listdir(tens_dir)) == 1: 109 | expert_tensor = glob.glob(tens_dir + "/*.pt") 110 | #expert_tensor = torch.load(expert_tensor[0], map_location="cpu") 111 | expert_list.append(expert_tensor[0]) 112 | else: 113 | # No audio embedding available 114 | continue 115 | except: 116 | continue 117 | chunk_str = chunk.split("/")[-2] 118 | chunk_dict[os.path.basename(chunk_str)] = expert_list 119 | scene_dict = {"path": orig_dir, "scene": scene, 120 | "label": label, "data": chunk_dict} 121 | return scene_dict 122 | 123 | 124 | def create_scene_dict_test(filepath): 125 | experts = ["test-location-embeddings", 126 | "test-img-embeddings", "test-video-embeddings"] 127 | 128 | orig_dir = filepath.replace( 129 | "/mnt/bigelow/scratch/mmx_aug", "/mnt/fvpbignas/datasets/mmx_raw") 130 | 131 | scene = orig_dir.split("/")[-2] 132 | 133 | dirs = os.listdir(orig_dir) 134 | meta_data = os.path.join(orig_dir, "meta.pkl") 135 | with open(meta_data, "rb") as pickly: 136 | label = pickle.load(pickly) 137 | chunk_dict = dict() 138 | for chunk in glob.glob(filepath + "/*/"): 139 | expert_list = [] 140 | for expert_dir in experts: 
141 | tens_dir = os.path.join(chunk, expert_dir) 142 | try: 143 | if len(os.listdir(tens_dir)) > 1: 144 | tensor_list = [] 145 | for tensor in glob.glob(tens_dir + "/*.pt"): 146 | # tensor_list.append(torch.load(tensor, map_location="cpu")) 147 | tensor_list.append(tensor) 148 | expert_list.append(tensor_list) 149 | elif len(os.listdir(tens_dir)) == 1: 150 | expert_tensor = glob.glob(tens_dir + "/*.pt") 151 | #expert_tensor = torch.load(expert_tensor[0], map_location="cpu") 152 | expert_list.append(expert_tensor[0]) 153 | else: 154 | # No audio embedding available 155 | continue 156 | except: 157 | continue 158 | chunk_str = chunk.split("/")[-2] 159 | chunk_dict[os.path.basename(chunk_str)] = expert_list 160 | scene_dict = {"path": orig_dir, "scene": scene, 161 | "label": label, "data": chunk_dict} 162 | return scene_dict 163 | 164 | 165 | def squish_folders(input_dir): 166 | all_files = [] 167 | for genres in tqdm.tqdm(glob.glob(input_dir + "/*/")): 168 | for movies in os.listdir(genres): 169 | path = os.path.join(genres, movies) 170 | # use movies rather than individual scenes 171 | all_files.append(path) 172 | print("length of files", len(all_files)) 173 | with open("cache.pkl", "wb") as cache: 174 | pickle.dump(all_files, cache) 175 | 176 | 177 | def mp_handler(): 178 | p = mp.Pool(30) 179 | data_list = [] 180 | count = 0 181 | with open("cache.pkl", 'rb') as cache: 182 | data = pickle.load(cache) 183 | random.shuffle(data) 184 | # Remaining 80% to training set 185 | train_data = data[:int((len(data)+1)*.90)] 186 | # Splits 20% data to test set 187 | test_data = data[int((len(data)+1)*.90):] 188 | print("training_data", len(train_data)) 189 | print("testing_data", len(test_data)) 190 | big_list = [] 191 | 192 | # append to pkl rather than write to pkl 193 | 194 | # with open("mmx_tensors_train.pkl", 'ab') as pkly: 195 | # for result in p.imap(create_embedding_dict, tqdm.tqdm(train_data, total=len(train_data))): 196 | # if result: 197 | # pickle.dump(result, pkly) 198 | 199 | with open("mmx_train_temporal.pkl", 'ab') as pkly: 200 | for result in p.imap(create_embedding_dict, tqdm.tqdm(train_data, total=len(train_data))): 201 | if result: 202 | pickle.dump(result, pkly) 203 | 204 | with open("mmx_val_temporal.pkl", 'ab') as pkly: 205 | for result in p.imap(create_embedding_dict, tqdm.tqdm(test_data, total=len(test_data))): 206 | if result: 207 | pickle.dump(result, pkly) 208 | 209 | if __name__ == "__main__": 210 | torch.multiprocessing.set_sharing_strategy('file_system') 211 | # rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 212 | # resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) 213 | 214 | squish_folders(input_dir) 215 | mp_handler() 216 | -------------------------------------------------------------------------------- /src/data_processing/tools/admin.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | import re 3 | import fnmatch 4 | 5 | def pickleLoader(pklFile): 6 | try: 7 | while True: 8 | yield pkl.load(pklFile) 9 | except EOFError: 10 | pass 11 | 12 | with open("mmx_tensors_train.pkl", "rb") as pkly: 13 | for entry in pickleLoader(pkly): 14 | if "Horror/TheWolfman" in entry["path"]: 15 | print(entry["path"]) 16 | pass 17 | else: 18 | with open("mmx_tensors_train_3.pkl", 'ab') as pkly: 19 | pkl.dump(entry, pkly) 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /src/data_processing/tools/nearest_neighbour.py: 
-------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import re 4 | import random 5 | import numpy as np 6 | from annoy import AnnoyIndex 7 | import pandas as pd 8 | import streamlit as st 9 | from youtubesearchpython import VideosSearch 10 | 11 | with open("embed_dict.pkl", "rb") as file: 12 | data_base = pickle.load(file) 13 | 14 | st.title('Movie trailer recommendation') 15 | st.text("Approximate nearest neighbours using only visual features (no metadata!)") 16 | st.text("Embeddings extracted from a custom video transformer encoder.") 17 | 18 | def annoy_processor(random_choice=False, id_n=0): 19 | # # 4096 20 | # from annoy import annoyindex 21 | recall = {} 22 | 23 | f = 15 24 | t = AnnoyIndex(f, 'euclidean') 25 | print(len(data_base.keys())) 26 | for i in range(len(data_base.keys())): 27 | v = data_base[i]["embedding"] 28 | t.add_item(i, v) 29 | 30 | t.build(750) 31 | t.save('./test.ann') 32 | u = AnnoyIndex(15, 'euclidean') 33 | u.load("./test.ann") 34 | if random_choice: 35 | results = u.get_nns_by_item(random.randrange(299), 10) 36 | else: 37 | results = u.get_nns_by_item(id_n, 10) 38 | for i, x in enumerate(results): 39 | recall[i] = {"path":data_base[x]["path"], "name": os.path.basename( 40 | os.path.normpath(data_base[x]["path"])), "actual": data_base[x]["actual"], "predicted": data_base[x]["predicted"]} 41 | #recall = pd.DataFrame.from_dict(recall, orient="index") 42 | return recall 43 | 44 | def load_data(nrows): 45 | data = pd.DataFrame.from_dict(data_base, orient="index") 46 | return data 47 | 48 | def retrieve_movies(random_choice=False, id_n=0): 49 | st.subheader("10 similar movies") 50 | data = annoy_processor(random_choice, id_n) 51 | # st.write(annoy_processor()) 52 | col1, col2 = st.columns(2) 53 | cols = [col1, col2] 54 | 55 | for i in range(10): 56 | # name = re.sub(r"(\w)([A-Z])", r"\1 \2",data[i]["name"]) 57 | cols[i%len(cols)].write(data[i]["path"]) 58 | cols[i%len(cols)].image(data[i]["path"]) 59 | # video_search = VideosSearch(name + " movie trailer", limit = 1) 60 | # result = video_search.result() 61 | # try: 62 | # url = result["result"][0]['link'] 63 | # cols[i%len(cols)].video(url) 64 | # except IndexError: 65 | # cols[i%len(cols)].write("no video") 66 | 67 | cols[i%len(cols)].caption("Actual genre:" + str(data[i]["actual"])) 68 | # cols[i%len(cols)].write(str(data[i]["actual"])) 69 | cols[i%len(cols)].caption("Predicted genre:" + str(data[i]["predicted"])) 70 | # cols[i%len(cols)].write(str(data[i]["predicted"])) 71 | 72 | def tsne_projection(): 73 | writer = SummaryWriter(log_dir="summaries") 74 | embedding = np.stack([data_base[i]["embedding"] 75 | for i in range(len(data_base.keys()))]) 76 | print(embedding.shape) 77 | keys = [data_base[i]["path"] for i in range(len(data_base.keys()))] 78 | print(len(keys)) 79 | writer.add_embedding(embedding, metadata=keys) 80 | 81 | data_load_state = st.text("loading data") 82 | data = load_data(len(data_base.keys())) 83 | data_load_state.text("loading data... 
done!") 84 | 85 | # st.subheader('Raw data') 86 | # st.write(data) 87 | 88 | option = st.selectbox("pick a trailer from the drop down", data) 89 | if st.button("generate random cluster"): 90 | retrieve_movies(random_choice=True, id_n=0) 91 | if st.button("search with selected"): 92 | id_n = data.index[data["path"] == option].tolist()[0] 93 | print(id_n) 94 | retrieve_movies(random_choice=False, id_n=id_n) -------------------------------------------------------------------------------- /src/data_processing/tools/test.ann: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ed-fish/data-efficient-video-transformers/0a7d1b40563e244df14f4f33376cad413b2ba558/src/data_processing/tools/test.ann -------------------------------------------------------------------------------- /src/data_processing/transforms/audio_transforms.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import librosa 4 | from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift 5 | from scipy.io.wavfile import read 6 | import librosa 7 | from scipy.io.wavfile import write 8 | 9 | 10 | class AudioTransforms: 11 | 12 | def __init__(self): 13 | 14 | self.augment = Compose([ 15 | AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5), 16 | # TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5), 17 | PitchShift(min_semitones=-4, max_semitones=4, p=0.5), 18 | # Shift(min_fraction=-0.5, max_fraction=0.5, p=0.5), 19 | ]) 20 | 21 | def extract_audio(self, input_file, output_dir): 22 | output_file = os.path.join(output_dir, "audio.wav") 23 | subprocess.call(['ffmpeg', '-hwaccel', 'cuda', '-i', input_file, 24 | '-codec:a', 'pcm_s16le', '-ac', '1', '-to', '1', 25 | output_file, '-loglevel', 'quiet']) 26 | 27 | audio, sample_rate = librosa.load(output_file) 28 | audio = self.augment(audio, sample_rate=sample_rate) 29 | return audio 30 | 31 | -------------------------------------------------------------------------------- /src/data_processing/transforms/img_transforms.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import random 3 | import cv2 4 | import numpy as np 5 | from torchvision import transforms 6 | 7 | class ImgTransform: 8 | 9 | ''' Augmentation selection for spatio temporal crops 10 | Default config is as follows: 11 | transform_prob: 0.5 12 | noise_multiplier: 10 13 | flip: 0.5 14 | jitter: 0.8 15 | gray: 0.2 16 | ) noise: 0.3 ''' 17 | 18 | 19 | def __init__(self, img, config): 20 | self.config = config 21 | self.img = img 22 | self.width = self.img.shape[1] 23 | self.height = self.img.shape[0] 24 | 25 | random.seed(random.random()) 26 | min(self.height, self.width) 27 | self.crop_size = random.randrange(30, min(self.height, self.width) - 10) 28 | if self.crop_size < 30: 29 | self.crop_size = 30 30 | if self.width <= self.crop_size or self.height <= self.crop_size: 31 | self.x = 0 32 | self.y = 0 33 | else: 34 | self.x = random.randrange(1, self.width - self.crop_size) 35 | self.y = random.randrange(1, self.height - self.crop_size) 36 | self.noise_multiply = random.randrange(0, config["noise_multiplier"].get()) 37 | self.make_gray = True if random.random() < config["gray"].get() else False 38 | self.flip_val = random.randrange(-1, 1) 39 | self.jitter_val = random.randrange(1, config["jitter"].get()) 40 | 41 | def crop(self, img): 42 | img = img[self.y:self.y + self.crop_size, 43 | self.x:self.x + 
self.crop_size] 44 | return img 45 | 46 | def noise(self, img, amount): 47 | gaussian_noise = np.zeros_like(img) 48 | gaussian_noise = cv2.randn(gaussian_noise, 0, amount) 49 | img = cv2.add(img, gaussian_noise, dtype=cv2.CV_8UC3) 50 | return img 51 | 52 | def gray(self, img): 53 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 54 | return img 55 | 56 | def flip(self, img): 57 | 58 | img = cv2.flip(img, self.flip_val) 59 | return img 60 | 61 | def blur(self, img): 62 | img = cv2.GaussianBlur(img, (5, 5), 0) 63 | return img 64 | 65 | def gen_hash(self, tbh): 66 | hash_object = hashlib.md5(tbh.encode()) 67 | return hash_object 68 | 69 | def colour_jitter(self, img): 70 | H, W, C = img.shape 71 | noise = np.random.randint(0, self.jitter_val, (H, W)) 72 | zitter = np.zeros_like(img) 73 | zitter[:, :, 1] = noise 74 | img = cv2.add(img, zitter) 75 | return img 76 | 77 | 78 | def debug(self): 79 | print("new crop") 80 | print(f"gray {self.make_gray}") 81 | print(f"jitter {self.jitter_val}") 82 | print(f"flip {self.flip_val}") 83 | 84 | def transform_with_prob(self, img): 85 | img = self.crop(img) 86 | img = self.flip(img) 87 | img = self.blur(img) 88 | img = self.noise(img, self.noise_multiply) 89 | img = self.colour_jitter(img) 90 | if self.make_gray: 91 | img = self.gray(img) 92 | #self.debug() 93 | return img 94 | 95 | 96 | class Normaliser: 97 | 98 | '''Converts images to PIL, resizes and normalises for appropriate model''' 99 | 100 | def __init__(self, config): 101 | self.config = config 102 | 103 | def rgb(self, img): 104 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 105 | return img 106 | 107 | def norm(self, img, mean, std): 108 | tensorfy = transforms.ToTensor() 109 | img = tensorfy(img) 110 | normarfy = transforms.Normalize(mean, std) 111 | img = normarfy(img).unsqueeze(0) 112 | return img 113 | 114 | # Imagenet ResNet 50/18 115 | def img_model(self, img): 116 | img = self.rgb(img) 117 | img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_AREA) 118 | img = self.norm(img, self.config["image_norm"]["mean"].get( 119 | ), self.config["image_norm"]["std"].get()) 120 | return img 121 | 122 | def location_model(self, img): 123 | img = self.rgb(img) 124 | img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_AREA) 125 | img = self.norm(img, self.config["location_norm"]["mean"].get( 126 | ), self.config["location_norm"]["std"].get()) 127 | return img 128 | 129 | def video_model(self, img): 130 | img = self.rgb(img) 131 | img = cv2.resize(img, (112, 112), interpolation=cv2.INTER_AREA) 132 | img = self.norm(img, self.config["video_norm"]["mean"].get( 133 | ), self.config["video_norm"]["std"].get()) 134 | return img 135 | 136 | def depth_model(self, img): 137 | img = self.rgb(img) 138 | img = cv2.resize(img, (384, 384), interpolation=cv2.INTER_AREA) 139 | img = self.norm(img, self.config["depth_norm"]["mean"].get( 140 | ), self.config["depth_norm"]["std"].get()) 141 | return img 142 | -------------------------------------------------------------------------------- /src/data_processing/transforms/spatio_cut.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import subprocess 3 | import os 4 | import glob 5 | import shutil 6 | import tempfile 7 | 8 | 9 | class SpatioCut: 10 | 11 | def convert_framerate(self, video_file, output_file, fps): 12 | subprocess.call(['ffmpeg', '-i', video_file, '-filter:v', f'fps={fps}', 13 | output_file, '-loglevel', 'quiet']) 14 | 15 | def split_video(self, video_file, output_dir): 16 | subprocess.call(['ffmpeg', 
'-hwaccel', 'cuda', '-i', video_file, 17 | '-c:v', 'libx264', '-crf', '22', '-map', '0', 18 | '-segment_time', '1', '-reset_timestamps', '1', '-g', 19 | '16', '-sc_threshold', '0', '-force_key_frames', 20 | "expr:gte(t, n_forced*16)", '-f', 'segment', 21 | os.path.join(output_dir, '%03d.mp4'), '-loglevel', 22 | 'quiet']) 23 | 24 | def split_frames(self, video_file): 25 | frame_list = [] 26 | vidcap = cv2.VideoCapture(video_file) 27 | success, image = vidcap.read() 28 | frame_list.append(image) 29 | while success: 30 | success, image = vidcap.read() 31 | if success: 32 | frame_list.append(image) 33 | return frame_list 34 | 35 | # Returns a 2d array of [n_chunks x n_frames] 36 | def cut_vid(self, video_file, frame_rate): 37 | output = [] 38 | tempdir = tempfile.mkdtemp() 39 | tempvid = os.path.join(tempdir, "vid.mp4") 40 | self.convert_framerate(str(video_file), tempvid, frame_rate) 41 | self.split_video(tempvid, tempdir) 42 | if os.path.exists(tempvid): 43 | os.remove(tempvid) 44 | for vid in glob.glob(tempdir + "/*.mp4"): 45 | output.append(self.split_frames(vid)) 46 | shutil.rmtree(tempdir) 47 | return output 48 | -------------------------------------------------------------------------------- /src/dataloaders/mit/MIT_Contrastive_dl.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import pandas as pd 3 | import ast 4 | import random 5 | import csv 6 | import torch 7 | import torch.nn.functional as F 8 | import _pickle as pickle 9 | import os 10 | import numpy as np 11 | from collections import defaultdict 12 | from torch.utils.data import Dataset, random_split, DataLoader 13 | import pytorch_lightning as pl 14 | import json 15 | from sklearn.model_selection import train_test_split 16 | 17 | 18 | class MITDataModule(pl.LightningDataModule): 19 | 20 | def __init__(self, train_data,val_data, config): 21 | super().__init__() 22 | self.train_data = train_data 23 | self.val_data = val_data 24 | self.config = config 25 | self.bs = self.config["batch_size"].get() 26 | 27 | def custom_collater(self, batch): 28 | 29 | return { 30 | 'label':[x['label'] for x in batch], 31 | 'x_i_experts':[x['x_i_experts'] for x in batch], 32 | 'x_j_experts':[x['x_j_experts'] for x in batch], 33 | 'path':[x['path'] for x in batch] 34 | } 35 | 36 | # def prepare_data(self): 37 | # data = self.load_data(self.pickle_file) 38 | # self.data = self.clean_data(data) 39 | 40 | def clean_data(self, data_frame): 41 | 42 | print("cleaning data") 43 | print(len(data_frame)) 44 | for i in range(len(data_frame)): 45 | 46 | data = data_frame.at[i, "data"] 47 | drop = False 48 | for d in data.values(): 49 | print(d.keys()) 50 | if len(d.keys()) < 2: 51 | drop = True 52 | if not "img-embeddings" in d.keys(): 53 | drop = True 54 | if drop: 55 | print("dropping missing experts") 56 | data_frame = data_frame.drop(i) 57 | continue 58 | 59 | data_chunk = list(data.values()) 60 | 61 | if len(data_chunk) < 2: 62 | print("dropping index with no data", i, len(data_chunk)) 63 | data_frame = data_frame.drop(i) 64 | continue 65 | 66 | # x = [len(data) for data in data_chunk] 67 | # if sum(x) < len(x) * 3: 68 | # print("dropping index with incomplete data", i, len(data)) 69 | # data_frame = data_frame.drop(i) 70 | # continue 71 | 72 | # test = [] 73 | # for f in data[1]: # data1 == img_embeddings, data2 == motion?, data0=location 74 | # print(f) 75 | # f = torch.load(f) 76 | # f = f.squeeze() 77 | # test.append(f) 78 | # print(f.dim) 79 | # if f.dim() > 0: 80 | # test.append(f) 81 | #else: 
82 | # data_frame = data_frame.drop(i) 83 | # continue 84 | # try: 85 | # test = torch.cat(test, dim=-1) 86 | # except: 87 | # data_frame = data_frame.drop(i) 88 | # print("dropping", i) 89 | # continue 90 | #print(test.shape[0]) 91 | #if test.shape[0] != 2560: 92 | # print("dropping", i) 93 | # data_frame = data_frame.drop(i) 94 | # continue 95 | 96 | data_frame = data_frame.reset_index(drop=True) 97 | print(len(data_frame)) 98 | 99 | return data_frame 100 | 101 | def load_data(self, db): 102 | print("loading data") 103 | data = [] 104 | with open(db, "rb") as pkly: 105 | while 1: 106 | try: 107 | # append if data serialised with open file 108 | data.append(pickle.load(pkly)) 109 | # else data not streamed 110 | #data = pickle.load(pkly) 111 | except EOFError: 112 | break 113 | 114 | data_frame = pd.DataFrame(data) 115 | print("data loaded") 116 | print("length", len(data_frame)) 117 | # data_frame = data_frame.head(10000) 118 | return data_frame 119 | 120 | def setup(self, stage): 121 | 122 | self.train_data = self.load_data(self.train_data) 123 | self.train_data = self.clean_data(self.train_data) 124 | 125 | self.val_data = self.load_data(self.val_data) 126 | self.val_data = self.clean_data(self.val_data) 127 | 128 | def train_dataloader(self): 129 | return DataLoader(MITDataset(self.train_data, self.config), self.bs, shuffle=True, collate_fn=self.custom_collater, num_workers=0, drop_last=True) 130 | 131 | def val_dataloader(self): 132 | return DataLoader(MITDataset(self.val_data, self.config), self.bs, shuffle=False, collate_fn=self.custom_collater, num_workers=0, drop_last=True) 133 | # For now use validation until proper test split obtained 134 | def test_dataloader(self): 135 | return DataLoader(MITDataset(self.train_data, self.config), 1, shuffle=False, collate_fn=self.custom_collater, num_workers=0) 136 | 137 | 138 | class MITDataset(Dataset): 139 | def __init__(self, data, config): 140 | super().__init__() 141 | 142 | self.config = config 143 | self.data_frame = data 144 | self.aggregation = self.config["aggregation"].get() 145 | self.label_df = self.load_labels("/home/ed/self-supervised-video/data_processing/moments_categories.csv") 146 | 147 | def __len__(self): 148 | return len(self.data_frame) 149 | 150 | def collect_one_hot_labels(self, label): 151 | label_array = np.zeros(305) 152 | index = self.label_df.loc[label]["id"] 153 | label_array[index] = 1 154 | label_array = torch.LongTensor(label_array) 155 | print(label_array) 156 | return label_array 157 | 158 | def collect_labels(self, label): 159 | index = self.label_df.loc[label]["id"] 160 | return index 161 | 162 | def load_labels(self, label_root): 163 | label_df = pd.read_csv(label_root) 164 | label_df.set_index('label', inplace=True) 165 | print("len of labels = ", len(label_df)) 166 | return label_df 167 | 168 | def load_tensor(self, tensor): 169 | tensor = torch.load(tensor, map_location=torch.device('cpu')) 170 | return tensor 171 | 172 | def __getitem__(self, idx): 173 | 174 | label = self.data_frame.at[idx, "label"] 175 | label = self.collect_labels(label) 176 | data = self.data_frame.at[idx, "data"] 177 | path = self.data_frame.at[idx, "path"] 178 | 179 | experts_xi = [] 180 | experts_xj = [] 181 | 182 | 183 | x_i, x_j = random.sample(list(data.values()), 2) 184 | x_i = x_i["img-embeddings"][0] 185 | x_j = x_j["img-embeddings"][0] 186 | 187 | experts_xi = self.load_tensor(x_i) 188 | experts_xj = self.load_tensor(x_j) 189 | 190 | #for index, i in enumerate(x_i): 191 | # print(i) 192 | # t = torch.load(i) 193 | # 
experts_xi.append(t.squeeze()) 194 | 195 | #for index, i in enumerate(x_j): 196 | # print(i) 197 | # t = torch.load(i) 198 | # experts_xj.append(t.squeeze()) 199 | 200 | if self.aggregation == "debugging": 201 | experts_xi = torch.cat(experts_xi, dim=-1) 202 | experts_xj = torch.cat(experts_xj, dim=-1) 203 | 204 | return {"label":label, "path":path, "x_i_experts":experts_xi, "x_j_experts":experts_xj} 205 | 206 | 207 | class MIT_RAW_Dataset(Dataset): 208 | def __init__(self, config, pre_computed=True): 209 | super().__init__() 210 | self.config = config 211 | self.pre_computed = pre_computed 212 | self.chunk_size = config['data_size'].get() 213 | self.data_frame = self.load_data() 214 | # self.ee = EmbeddingExtractor(self.config) 215 | 216 | def load_data(self): 217 | train_data_frame = pd.read_csv(self.config['train_csv'].get()) 218 | # val_data_frame = pd.read_csv(self.config['val.csv'].get()) 219 | return train_data_frame 220 | 221 | def __len__(self): 222 | return len(self.data_frame) 223 | 224 | def stack_and_permute_vid(self, img_list): 225 | img_list = torch.stack(img_list) 226 | img_list = img_list.squeeze(1) 227 | img_list = img_list.permute(1, 0, 2, 3) 228 | return img_list 229 | 230 | def open_pt_return_list(self, folder_path): 231 | items = glob.glob(folder_path + "/*.pt") 232 | tensor_list = [] 233 | if len(items) > 1: 234 | for i in items: 235 | with torch.no_grad(): 236 | x = torch.load(i, map_location="cuda:3") 237 | x = x.detach() 238 | tensor_list.append(x) 239 | return tensor_list 240 | else: 241 | with torch.no_grad(): 242 | x = torch.load(items[0], map_location="cuda:3") 243 | x = x.detach() 244 | return x 245 | 246 | # For precomputed embeddings that need to be loaded 247 | def collect_pre_computed_embeddings(self, video, config, label): 248 | sample_dict = defaultdict(dict) 249 | # video_name = os.path.basename(video).replace(".mp4", "") 250 | # root_dir = os.path.join(config["train_root"].get()) 251 | dirs = glob.glob(video + "/*/") 252 | x_i_folder = dirs.pop(random.randrange(len(dirs))) 253 | x_j_folder = dirs.pop(random.randrange(len(dirs))) 254 | for s in ["x_i", "x_j"]: 255 | if s == "x_i": 256 | x_folder = x_i_folder 257 | else: 258 | x_folder = x_j_folder 259 | 260 | sample_dict[s]["video"] = self.open_pt_return_list(os.path.join(x_folder, "video-embeddings")) 261 | sample_dict[s]["location"] = self.open_pt_return_list(os.path.join(x_folder, "location-embeddings")) 262 | sample_dict[s]["image"] = self.open_pt_return_list(os.path.join(x_folder, "img-embeddings")) 263 | return sample_dict 264 | 265 | """ def collect_embedding(self, video, config): 266 | norm = Normaliser(config) 267 | sc = SpatioCut() 268 | video_imgs = sc.cut_vid(video, 16) 269 | if len(video_imgs) < 16: 270 | return 0 271 | 272 | # Take two groups of frames randomly - may want to make this 273 | # a temporal distance in the future as per the spatio-temporal 274 | # paper. 
275 | 276 | x_i = video_imgs.pop(random.randrange(0, len(video_imgs))) 277 | x_j = video_imgs.pop(random.randrange(0, len(video_imgs))) 278 | augment_i = ImgTransform(x_i[0], config) 279 | augment_j = ImgTransform(x_j[0], config) 280 | sample_dict = defaultdict(dict) 281 | i_3d, i_loc, i_obj = [], [], [], [] 282 | j_3d, j_loc, j_obj = [], [], [], [] 283 | 284 | for img in x_i: 285 | t_img = augment_i.transform_with_prob(img) 286 | i_3d.append(norm.video_model(t_img)) 287 | i_loc.append(norm.location_model(t_img)) 288 | # i_dep.append(norm.depth_model(t_img)) 289 | i_obj.append(norm.img_model(t_img)) 290 | 291 | i_3d = self.stack_and_permute_vid(i_3d) 292 | sample_dict["x_i"]["video"] = i_3d 293 | sample_dict["x_i"]["location"] = i_loc 294 | # sample_dict["x_i"]["depth"] = i_dep 295 | sample_dict["x_i"]["image"] = i_obj 296 | 297 | for img in x_j: 298 | t_img = augment_j.transform_with_prob(img) 299 | j_3d.append(norm.video_model(t_img)) 300 | j_loc.append(norm.location_model(t_img)) 301 | # j_dep.append(norm.depth_model(t_img)) 302 | j_obj.append(norm.img_model(t_img)) 303 | 304 | j_3d = self.stack_and_permute_vid(j_3d) 305 | sample_dict["x_j"]["video"] = j_3d 306 | sample_dict["x_j"]["location"] = j_loc 307 | # sample_dict["x_j"]["depth"] = i_dep 308 | sample_dict["x_j"]["image"] = j_obj 309 | 310 | return sample_dict """ 311 | 312 | 313 | def return_expert_for_key_pretrained(self, key, raw_tensor): 314 | 315 | if key == "image": 316 | if len(raw_tensor) > 1: 317 | output = torch.stack(raw_tensor) 318 | output = output.transpose(0, 2) 319 | output = F.adaptive_avg_pool1d(output, 1) 320 | output = output.transpose(1, 0).squeeze(2) 321 | output = output.squeeze(1) 322 | else: 323 | output = raw_tensor[0].unsqueeze(0) 324 | 325 | if key == "motion" or key == "video": 326 | output = raw_tensor[0].unsqueeze(0) 327 | 328 | if key == "location": 329 | if len(raw_tensor) > 1: 330 | output = torch.stack(raw_tensor) 331 | output = output.transpose(0, 2) 332 | output = F.adaptive_avg_pool1d(output, 1) 333 | output = output.transpose(1, 0).squeeze(2) 334 | output = output.squeeze(1) 335 | else: 336 | output = raw_tensor[0].unsqueeze(0) 337 | 338 | return output 339 | 340 | def __getitem__(self, idx): 341 | label = self.data_frame.at[idx, "label"] 342 | path = self.data_frame.at[idx, "path"] 343 | 344 | if self.pre_computed: 345 | embedding_dict = self.collect_pre_computed_embeddings(path, self.config, label) 346 | 347 | x_i = embedding_dict["x_i"] 348 | x_j = embedding_dict["x_j"] 349 | 350 | for key, value in x_i.items(): 351 | x_i[key] = self.return_expert_for_key_pretrained(key, value) 352 | 353 | for key, value in x_j.items(): 354 | x_j[key] = self.return_expert_for_key_pretrained(key, value) 355 | 356 | return {'label': embedding_dict['label'], 'x_i': x_i, 'x_j': x_j} 357 | 358 | # else: 359 | # embedding_dict = self.collect_embedding(embed_dict["video"], self.config) 360 | # x_i = embedding_dict["x_i"] 361 | # x_j = embedding_dict["x_j"] 362 | 363 | # for key, value in x_i.items(): 364 | # x_i[key] = self.ee.return_expert_for_key(key, value) 365 | 366 | # for key, value in x_j.items(): 367 | # x_j[key] = self.ee.return_expert_for_key(key, value) 368 | 369 | # return { 'label': embed_dict['label'], 'x_i': x_i, 'x_j': x_j } 370 | 371 | 372 | class CustomDataset(Dataset): 373 | def __init__(self, config): 374 | 375 | self.config = config 376 | self.data_frame = self.load_data() 377 | 378 | def load_data(self): 379 | data_frame = pd.read_csv(self.config['input_csv'].get(), 
chunksize=self.config['data_size'].get()) 380 | return data_frame 381 | 382 | # def stack(self): 383 | 384 | def __len__(self): 385 | return len(self.data_frame) 386 | 387 | def collect_embeddings(self, data_type, idx): 388 | embedding_stack = [] 389 | data_path = self.data_frame.at[idx, data_type] 390 | if len(os.listdir(data_path)) > 1: 391 | for embed in os.listdir(data_path): 392 | embed_path = os.path.join(data_path, embed) 393 | embedding_stack.append(torch.load(embed_path)) 394 | data = torch.stack(embedding_stack) 395 | else: 396 | for embed in os.listdir(data_path): 397 | embed_path = os.path.join(data_path, embed) 398 | data = torch.load(embed_path) 399 | 400 | return data 401 | 402 | 403 | def __getitem__(self, idx): 404 | embed_dict = dict() 405 | embed_dict["label"] = self.data_frame.at[idx, LABEL] 406 | embed_dict["chunk"] = self.data_frame.at[idx, CHUNK] 407 | expert_list = self.config['experts'].get() 408 | if "image" in expert_list: 409 | embed_dict["image"] = self.collect_embeddings(IMG_EMBED, idx) 410 | if "location" in expert_list: 411 | embed_dict["location"] = self.collect_embeddings(LOC_EMBED, idx) 412 | if "depth" in expert_list: 413 | embed_dict["depth"] = self.collect_embeddings(DEPTH_EMBED, idx) 414 | if "motion" in expert_list: 415 | embed_dict["motion"] = self.collect_embeddings(VID_EMBED, idx) 416 | if "audio" in expert_list: 417 | embed_dict["audio"] = self.collect_embeddings(AUDIO_EMBED, idx) 418 | 419 | return embed_dict 420 | 421 | -------------------------------------------------------------------------------- /src/dataloaders/mit/MIT_Temporal_dl.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import pandas as pd 3 | import ast 4 | import random 5 | import csv 6 | import torch 7 | import torch.nn.functional as F 8 | import torch.nn as nn 9 | import pickle 10 | import os 11 | import numpy as np 12 | from collections import defaultdict 13 | from torch.utils.data import Dataset, random_split, DataLoader 14 | import pytorch_lightning as pl 15 | import json 16 | from sklearn.model_selection import train_test_split 17 | from torch.utils.data.sampler import WeightedRandomSampler 18 | 19 | class MITDataModule(pl.LightningDataModule): 20 | 21 | def __init__(self, train_data, val_data, config): 22 | super().__init__() 23 | self.train_data = train_data 24 | self.val_data = val_data 25 | self.config = config 26 | self.bs = self.config["batch_size"] 27 | # self.label_df = self.load_labels( 28 | # "/home/ed/self-supervised-video/data_processing/moments_categories.csv") 29 | 30 | # def load_labels(self, label_root): 31 | # label_df = pd.read_csv(label_root) 32 | # label_df.set_index('label', inplace=True) 33 | # print("len of labels = ", len(label_df)) 34 | # return label_df 35 | 36 | # def collect_labels(self, label): 37 | # index = self.label_df.loc[label]["id"] 38 | # return index 39 | 40 | def custom_collater(self, batch): 41 | 42 | return { 43 | 'label': [x['label'] for x in batch], 44 | 'experts': [x['expert_list'] for x in batch], 45 | 'path': [x['path'] for x in batch] 46 | } 47 | 48 | # def prepare_data(self): 49 | # data = self.load_data(self.pickle_file) 50 | # self.data = self.clean_data(data) 51 | 52 | def clean_data(self, data_frame, train=False): 53 | 54 | print("cleaning data") 55 | print(len(data_frame)) 56 | for i in range(len(data_frame)): 57 | 58 | data = data_frame.at[i, "data"] 59 | # label = data_frame.at[i, "label"] 60 | # label = self.collect_labels(label) 61 | # data_frame.at[i, "label"] 
= label 62 | 63 | drop = False 64 | # if len(data.values()) > 3: 65 | # print(data.values()) 66 | for key, value in data.items(): 67 | if len(value.keys()) < 2: 68 | drop = True 69 | # if train: 70 | # if not "img-embeddings" in value.keys(): 71 | # del data[key] 72 | # else: 73 | # if not "test-img-embeddings" in value.keys(): 74 | # del data[key] 75 | if drop: 76 | print("dropping missing experts") 77 | data_frame = data_frame.drop(i) 78 | continue 79 | 80 | data_chunk = list(data.values()) 81 | 82 | if len(data_chunk) < 2: 83 | print("dropping index with no data", i, len(data_chunk)) 84 | data_frame = data_frame.drop(i) 85 | continue 86 | 87 | # x = [len(data) for data in data_chunk] 88 | # if sum(x) < len(x) * 3: 89 | # print("dropping index with incomplete data", i, len(data)) 90 | # data_frame = data_frame.drop(i) 91 | # continue 92 | 93 | # test = [] 94 | # for f in data[1]: # data1 == img_embeddings, data2 == motion?, data0=location 95 | # print(f) 96 | # f = torch.load(f) 97 | # f = f.squeeze() 98 | # test.append(f) 99 | # print(f.dim) 100 | # if f.dim() > 0: 101 | # test.append(f) 102 | # else: 103 | # data_frame = data_frame.drop(i) 104 | # continue 105 | # try: 106 | # test = torch.cat(test, dim=-1) 107 | # except: 108 | # data_frame = data_frame.drop(i) 109 | # print("dropping", i) 110 | # continue 111 | # print(test.shape[0]) 112 | # if test.shape[0] != 2560: 113 | # print("dropping", i) 114 | # data_frame = data_frame.drop(i) 115 | # continue 116 | 117 | data_frame = data_frame.reset_index(drop=True) 118 | print(len(data_frame)) 119 | 120 | return data_frame 121 | 122 | def load_data(self, db): 123 | print("loading data") 124 | data = [] 125 | with open(db, "rb") as pkly: 126 | while 1: 127 | try: 128 | # append if data serialised with open file 129 | data.append(pickle.load(pkly)) 130 | # else data not streamed 131 | # data = pickle.load(pkly) 132 | except EOFError: 133 | break 134 | 135 | data_frame = pd.DataFrame(data) 136 | print("data loaded") 137 | print("length", len(data_frame)) 138 | 139 | # TODO remove - 64 Bx2 testing only 140 | data_frame = data_frame.head(10000) 141 | 142 | return data_frame 143 | 144 | def create_sampler(self, df): 145 | print("balancing data") 146 | labels_unique, counts = np.unique(df['label'], return_counts=True) 147 | sample_weights = [0] * len(df) 148 | class_list = [0] * 305 149 | 150 | print(f"unique labels :{labels_unique}{counts}") 151 | # class_weights = [sum(counts) / c for c in counts] 152 | class_weights = 1./torch.tensor(counts, dtype=torch.float) 153 | for n, i in enumerate(labels_unique): 154 | class_list[i] = class_weights[n] 155 | 156 | for n, e in enumerate(df['label']): 157 | class_weight = class_list[e] 158 | sample_weights[n] = class_weight.detach() 159 | sampler = WeightedRandomSampler( 160 | sample_weights, len(df['label']), replacement=True) 161 | return sampler 162 | 163 | def setup(self, stage): 164 | 165 | self.train_data = self.load_data(self.train_data) 166 | self.train_data = self.clean_data(self.train_data, train=True) 167 | self.weighted_sampler = self.create_sampler(self.train_data) 168 | self.val_data = self.load_data(self.val_data) 169 | self.val_data = self.clean_data(self.val_data, train=False) 170 | 171 | def train_dataloader(self): 172 | print("Loading train dataloader") 173 | return DataLoader(MITDataset(self.train_data, self.config, train=True), self.bs, sampler=self.weighted_sampler, collate_fn=self.custom_collater, num_workers=0, drop_last=True) 174 | 175 | def val_dataloader(self): 176 | return 
DataLoader(MITDataset(self.val_data, self.config, train=False), self.bs, shuffle=False, collate_fn=self.custom_collater, num_workers=0, drop_last=True) 177 | # For now use validation until proper test split obtained 178 | 179 | def test_dataloader(self): 180 | return DataLoader(MITDataset(self.train_data, self.config, train=False), 1, shuffle=False, collate_fn=self.custom_collater, num_workers=0) 181 | 182 | 183 | class MITDataset(Dataset): 184 | def __init__(self, data, config, train=True): 185 | super().__init__() 186 | 187 | self.config = config 188 | self.data_frame = data 189 | self.train = train 190 | self.label_df = self.load_labels( 191 | "/home/ed/self-supervised-video/data_processing/moments_categories.csv") 192 | 193 | def __len__(self): 194 | return len(self.data_frame) 195 | 196 | def collect_one_hot_labels(self, label): 197 | label_array = np.zeros(305) 198 | index = self.label_df.loc[label]["id"] 199 | label_array[index] = 1 200 | label_array = torch.LongTensor(label_array) 201 | print(label_array) 202 | return label_array 203 | 204 | def collect_labels(self, label): 205 | index = self.label_df.loc[label]["id"] 206 | return index 207 | 208 | def load_labels(self, label_root): 209 | label_df = pd.read_csv(label_root) 210 | label_df.set_index('label', inplace=True) 211 | print("len of labels = ", len(label_df)) 212 | return label_df 213 | 214 | def load_tensor(self, tensor): 215 | tensor = torch.load(tensor, map_location=torch.device('cpu')) 216 | if tensor.shape[-1] != 2048: 217 | tensor = nn.ConstantPad1d((0, 2048 - tensor.shape[-1]), 0)(tensor) 218 | # tensor = torch.load(tensor).detach() 219 | # tensor = torch.load(tensor, map_location=torch.device('cpu')) 220 | return tensor 221 | 222 | def __getitem__(self, idx): 223 | 224 | label = self.data_frame.at[idx, "label"] 225 | # label = self.collect_labels(label) 226 | data = self.data_frame.at[idx, "data"] 227 | path = self.data_frame.at[idx, "path"] 228 | 229 | # x_i, x_j = random.sample(list(data.values()), 2) 230 | expert_list = [] 231 | target_len = 3 232 | if self.config["cls"]: 233 | target_len += 1 234 | 235 | if self.config["mixing_method"] == "double_trans": 236 | 237 | for expert in self.config["experts"]: 238 | 239 | expert_t_list = [] 240 | if self.config["cls"]: 241 | expert_t_list.append(torch.rand(1, 2048)) 242 | # use test experts 243 | if not self.train: 244 | expert = "test-" + expert 245 | 246 | # not ordered dictionary so need to sort 247 | temp_list = [] 248 | # if sample is of len 4 remove the last one 249 | for i, d in enumerate(data.values()): 250 | try: 251 | temp_list.append(d[expert][0]) 252 | except KeyError: 253 | continue 254 | 255 | temp_list = sorted(temp_list) 256 | for i, d in enumerate(temp_list): 257 | if i < target_len: 258 | expert_t_list.append(self.load_tensor(temp_list[i])) 259 | while len(expert_t_list) < target_len: 260 | expert_t_list.append(expert_t_list[0]) 261 | 262 | print(len(expert_t_list)) 263 | 264 | assert(len(expert_t_list) == target_len) 265 | t_tens = torch.stack(expert_t_list) 266 | t_tens = t_tens.unsqueeze(0) 267 | expert_list.append(t_tens) 268 | else: 269 | # use test experts 270 | if not self.train: 271 | expert = "test-" + expert 272 | 273 | # not ordered dictionary so need to sort 274 | # if sample is of len 4 remove the last one 275 | for i, d in enumerate(data.values()): 276 | try: 277 | expert_list.append(d[expert][0]) 278 | except KeyError: 279 | continue 280 | 281 | temp_list = sorted(temp_list) 282 | for i, d in enumerate(temp_list): 283 | if i < 3: 284 | 
expert_list.append(self.load_tensor(temp_list[i]).unsqueeze(0)) 285 | while len(expert_list) < 3: 286 | expert_list.append(expert_list[0]) 287 | 288 | assert(len(expert_list) == 3) 289 | 290 | expert_list = torch.cat(expert_list, dim=0) 291 | expert_list = expert_list.squeeze() 292 | expert_list = expert_list.squeeze() 293 | label = torch.tensor([label]) 294 | 295 | # for index, i in enumerate(x_i): 296 | # print(i) 297 | # t = torch.load(i) 298 | # experts_xi.append(t.squeeze()) 299 | 300 | # for index, i in enumerate(x_j): 301 | # print(i) 302 | # t = torch.load(i) 303 | # experts_xj.append(t.squeeze()) 304 | 305 | 306 | return {"label": label, "path": path, "expert_list": expert_list} 307 | 308 | -------------------------------------------------------------------------------- /src/dataloaders/mmx/MMX_Contrastive_dl.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import pandas as pd 3 | import ast 4 | import random 5 | import csv 6 | import torch 7 | import torch.nn.functional as F 8 | import _pickle as pickle 9 | import os 10 | import numpy as np 11 | from collections import defaultdict 12 | from torch.utils.data import Dataset, random_split, DataLoader 13 | import pytorch_lightning as pl 14 | import json 15 | from sklearn.model_selection import train_test_split 16 | 17 | 18 | class MMXDataModule(pl.LightningDataModule): 19 | 20 | def __init__(self, train_data,val_data, config): 21 | super().__init__() 22 | self.train_data = train_data 23 | self.val_data = val_data 24 | self.config = config 25 | self.bs = self.config["batch_size"] 26 | 27 | def custom_collater(self, batch): 28 | 29 | return { 30 | 'label':[x['label'] for x in batch], 31 | 'x_i_experts':[x['x_i_experts'] for x in batch], 32 | 'x_j_experts':[x['x_j_experts'] for x in batch], 33 | } 34 | 35 | # def prepare_data(self): 36 | # data = self.load_data(self.pickle_file) 37 | # self.data = self.clean_data(data) 38 | 39 | def clean_data(self, data_frame, train=True): 40 | target_names = ['Action' ,'Adventure' ,'Comedy' ,'Crime' ,'Documentary' ,'Drama' ,'Family' , 'Fantasy' ,'History' ,'Horror' ,'Music' , 'Mystery' ,'Science Fiction' , 'Thriller', 'War'] 41 | 42 | 43 | print("cleaning data") 44 | print(len(data_frame)) 45 | 46 | data_frame = data_frame.reset_index(drop=True) 47 | 48 | longest_seq = 0 49 | for i in range(len(data_frame)): 50 | data = data_frame.at[i, "scene"] 51 | label = data_frame.at[i, "label"] 52 | n_labels = 0 53 | for l in label[0]: 54 | if l not in target_names: 55 | n_labels += 1 56 | if n_labels == 6: 57 | data_frame = data_frame.drop(i) 58 | continue 59 | data = data_frame.at[i, "data"] 60 | data_chunk = list(data.values()) 61 | 62 | if len(data_chunk) == 0: 63 | print("dropping index with no data", i, len(data_chunk)) 64 | data_frame = data_frame.drop(i) 65 | continue 66 | 67 | x = [len(data) for data in data_chunk] 68 | if sum(x) < len(x) * 3: 69 | print("dropping index with incomplete data", i, len(data)) 70 | data_frame = data_frame.drop(i) 71 | continue 72 | 73 | if train: 74 | experts = self.config["train_experts"] 75 | else: 76 | experts = self.config["test_experts"] 77 | for p in range(len(data_chunk)): 78 | for e in experts: 79 | if e not in data_chunk[p].keys(): 80 | print(data_chunk[p].keys()) 81 | print(i) 82 | data_frame = data_frame.drop(i) 83 | continue 84 | 85 | 86 | # test = [] 87 | # for f in data[1]: # data1 == img_embeddings, data2 == motion?, data0=location 88 | # print(f) 89 | # f = torch.load(f) 90 | # f = f.squeeze() 91 | # 
test.append(f) 92 | # print(f.dim) 93 | # if f.dim() > 0: 94 | # test.append(f) 95 | #else: 96 | # data_frame = data_frame.drop(i) 97 | # continue 98 | # try: 99 | # test = torch.cat(test, dim=-1) 100 | # except: 101 | # data_frame = data_frame.drop(i) 102 | # print("dropping", i) 103 | # continue 104 | #print(test.shape[0]) 105 | #if test.shape[0] != 2560: 106 | # print("dropping", i) 107 | # data_frame = data_frame.drop(i) 108 | # continue 109 | 110 | data_frame = data_frame.reset_index(drop=True) 111 | print(len(data_frame)) 112 | 113 | return data_frame 114 | 115 | def load_data(self, db): 116 | print("loading data") 117 | data = [] 118 | with open(db, "rb") as pkly: 119 | while 1: 120 | try: 121 | # append if data serialised with open file 122 | data.append(pickle.load(pkly)) 123 | # else data not streamed 124 | #data = pickle.load(pkly) 125 | except EOFError: 126 | break 127 | 128 | data_frame = pd.DataFrame(data) 129 | print("data loaded") 130 | print("length", len(data_frame)) 131 | # data_frame = data_frame.head(1000) 132 | return data_frame 133 | 134 | def setup(self, stage): 135 | 136 | self.train_data = self.load_data(self.train_data) 137 | self.train_data = self.clean_data(self.train_data, train=True) 138 | 139 | self.val_data = self.load_data(self.val_data) 140 | self.val_data = self.clean_data(self.val_data, train=False) 141 | 142 | def train_dataloader(self): 143 | return DataLoader(MMX_Dataset(self.train_data, self.config, train=True), self.bs, shuffle=True, collate_fn=self.custom_collater, num_workers=10, drop_last=True) 144 | 145 | def val_dataloader(self): 146 | return DataLoader(MMX_Dataset(self.val_data, self.config, train=False), self.bs, shuffle=False, collate_fn=self.custom_collater, num_workers=10, drop_last=True) 147 | # For now use validation until proper test split obtained 148 | def test_dataloader(self): 149 | return DataLoader(MMX_Dataset(self.train_data, self.config), 1, shuffle=False, collate_fn=self.custom_collater, num_workers=30) 150 | 151 | 152 | 153 | class MMX_Dataset(Dataset): 154 | def __init__(self, data, config, train=True): 155 | super().__init__() 156 | 157 | self.config = config 158 | self.data_frame = data 159 | self.aggregation = self.config["aggregation"] 160 | self.train = train 161 | 162 | 163 | def __len__(self): 164 | return len(self.data_frame) 165 | 166 | def collect_labels(self, label): 167 | 168 | target_names = ['Action' ,'Adventure' ,'Comedy' ,'Crime' ,'Documentary' ,'Drama' ,'Family' , 'Fantasy' ,'History' ,'Horror' ,'Music' , 'Mystery' ,'Science Fiction' , 'Thriller', 'War'] 169 | 170 | label_list = np.zeros(15) 171 | 172 | for i, genre in enumerate(target_names): 173 | if genre == "Sci-Fi" or genre == "ScienceFiction": 174 | genre = "Science Fiction" 175 | if genre in label: 176 | label_list[i] = 1 177 | 178 | return label_list 179 | 180 | 181 | def __getitem__(self, idx): 182 | 183 | label = self.data_frame.at[idx, "label"] 184 | label = self.collect_labels(label[0]) 185 | label = torch.FloatTensor(label) 186 | data = self.data_frame.at[idx, "data"] 187 | # path = self.data_frame.at[idx, "path"] 188 | # path = path.replace("/mnt/fvpbignas/datasets/mmx_raw", "/mnt/bigelow/scratch/mmx_aug") 189 | # try: 190 | # path = glob.glob(path + "/*/")[0] 191 | # path = os.path.join(path, "imgs") 192 | # path = glob.glob(path + "/*")[1] 193 | # except: 194 | # path = "None" 195 | scene = self.data_frame.at[idx, "scene"] 196 | 197 | experts_xi = [] 198 | experts_xj = [] 199 | 200 | # apply mix-up if less than 2 samples 201 | # keys here are 
the scenes ["000", "001", "002"] etc 202 | # if there are less than 2 scenes in the sample we need to do something else. 203 | 204 | if self.train: 205 | experts = self.config["train_experts"] 206 | else: 207 | experts = self.config["test_experts"] 208 | if len(data) < 2: 209 | data = list(data.values())[0] # take the first index as its the only valid one 210 | #data = list(data.values()) 211 | if idx == 0: 212 | idmx = idx + 1 213 | else: 214 | idmx = idx - 1 215 | mix_up_data = self.data_frame.at[idmx , "data"] 216 | mix_up_data = list(mix_up_data.values())[0] # take the first index of the sample either before or after e.g ["000"] 217 | 218 | # Now we have data = ["000"][experts] and mix_up_data = ["000"][experts] 219 | 220 | for expert in experts: 221 | try: 222 | expert_vec = data[expert] 223 | except: 224 | continue 225 | if not isinstance(expert_vec, str): 226 | expert_vec = random.choice(expert_vec) 227 | expert_vec = torch.load(expert_vec) 228 | if len(expert_vec.shape) < 2: 229 | expert_vec = expert_vec.unsqueeze(0) 230 | experts_xi.append(expert_vec) 231 | 232 | expert_vec = mix_up_data[expert] 233 | if not isinstance(expert_vec, str): 234 | expert_vec = random.choice(expert_vec) 235 | expert_vec = torch.load(expert_vec) 236 | if len(expert_vec.shape) < 2: 237 | expert_vec = expert_vec.unsqueeze(0) 238 | experts_xj.append(expert_vec) 239 | 240 | else: 241 | # select two random scenes as a positive pair 242 | x_i, x_j = random.sample(list(data.values()), 2) 243 | 244 | # x_i["test-location"] etc are valid keys 245 | 246 | for expert in experts: 247 | expert_vec = x_i[expert] 248 | if not isinstance(expert_vec, str): 249 | expert_vec = random.choice(expert_vec) 250 | expert_vec = torch.load(expert_vec) 251 | if len(expert_vec.shape) < 2: 252 | expert_vec = expert_vec.unsqueeze(0) 253 | experts_xi.append(expert_vec) 254 | 255 | expert_vec = x_j[expert] 256 | if not isinstance(expert_vec, str): 257 | expert_vec = random.choice(expert_vec) 258 | expert_vec = torch.load(expert_vec) 259 | if len(expert_vec.shape) < 2: 260 | expert_vec = expert_vec.unsqueeze(0) 261 | experts_xj.append(expert_vec) 262 | 263 | 264 | if self.aggregation == "debugging": 265 | experts_xi = torch.cat(experts_xi, dim=-1) 266 | experts_xj = torch.cat(experts_xj, dim=-1) 267 | 268 | return {"label":label, "scene":scene, "x_i_experts":experts_xi, "x_j_experts":experts_xj} 269 | 270 | 271 | class MIT_RAW_Dataset(Dataset): 272 | def __init__(self, config, pre_computed=True): 273 | super().__init__() 274 | self.config = config 275 | self.pre_computed = pre_computed 276 | self.chunk_size = config['data_size'].get() 277 | self.data_frame = self.load_data() 278 | # self.ee = EmbeddingExtractor(self.config) 279 | 280 | def load_data(self): 281 | train_data_frame = pd.read_csv(self.config['train_csv'].get()) 282 | # val_data_frame = pd.read_csv(self.config['val.csv'].get()) 283 | return train_data_frame 284 | 285 | def __len__(self): 286 | return len(self.data_frame) 287 | 288 | def stack_and_permute_vid(self, img_list): 289 | img_list = torch.stack(img_list) 290 | img_list = img_list.squeeze(1) 291 | img_list = img_list.permute(1, 0, 2, 3) 292 | return img_list 293 | 294 | def open_pt_return_list(self, folder_path): 295 | items = glob.glob(folder_path + "/*.pt") 296 | tensor_list = [] 297 | if len(items) > 1: 298 | for i in items: 299 | with torch.no_grad(): 300 | x = torch.load(i, map_location="cuda:3") 301 | x = x.detach() 302 | tensor_list.append(x) 303 | return tensor_list 304 | else: 305 | with torch.no_grad(): 306 
| x = torch.load(items[0], map_location="cuda:3") 307 | x = x.detach() 308 | return x 309 | 310 | # For precomputed embeddings that need to be loaded 311 | def collect_pre_computed_embeddings(self, video, config, label): 312 | sample_dict = defaultdict(dict) 313 | # video_name = os.path.basename(video).replace(".mp4", "") 314 | # root_dir = os.path.join(config["train_root"].get()) 315 | dirs = glob.glob(video + "/*/") 316 | x_i_folder = dirs.pop(random.randrange(len(dirs))) 317 | x_j_folder = dirs.pop(random.randrange(len(dirs))) 318 | for s in ["x_i", "x_j"]: 319 | if s == "x_i": 320 | x_folder = x_i_folder 321 | else: 322 | x_folder = x_j_folder 323 | 324 | sample_dict[s]["video"] = self.open_pt_return_list(os.path.join(x_folder, "video-embeddings")) 325 | sample_dict[s]["location"] = self.open_pt_return_list(os.path.join(x_folder, "location-embeddings")) 326 | sample_dict[s]["image"] = self.open_pt_return_list(os.path.join(x_folder, "img-embeddings")) 327 | return sample_dict 328 | 329 | """ def collect_embedding(self, video, config): 330 | norm = Normaliser(config) 331 | sc = SpatioCut() 332 | video_imgs = sc.cut_vid(video, 16) 333 | if len(video_imgs) < 16: 334 | return 0 335 | 336 | # Take two groups of frames randomly - may want to make this 337 | # a temporal distance in the future as per the spatio-temporal 338 | # paper. 339 | 340 | x_i = video_imgs.pop(random.randrange(0, len(video_imgs))) 341 | x_j = video_imgs.pop(random.randrange(0, len(video_imgs))) 342 | augment_i = ImgTransform(x_i[0], config) 343 | augment_j = ImgTransform(x_j[0], config) 344 | sample_dict = defaultdict(dict) 345 | i_3d, i_loc, i_obj = [], [], [], [] 346 | j_3d, j_loc, j_obj = [], [], [], [] 347 | 348 | for img in x_i: 349 | t_img = augment_i.transform_with_prob(img) 350 | i_3d.append(norm.video_model(t_img)) 351 | i_loc.append(norm.location_model(t_img)) 352 | # i_dep.append(norm.depth_model(t_img)) 353 | i_obj.append(norm.img_model(t_img)) 354 | 355 | i_3d = self.stack_and_permute_vid(i_3d) 356 | sample_dict["x_i"]["video"] = i_3d 357 | sample_dict["x_i"]["location"] = i_loc 358 | # sample_dict["x_i"]["depth"] = i_dep 359 | sample_dict["x_i"]["image"] = i_obj 360 | 361 | for img in x_j: 362 | t_img = augment_j.transform_with_prob(img) 363 | j_3d.append(norm.video_model(t_img)) 364 | j_loc.append(norm.location_model(t_img)) 365 | # j_dep.append(norm.depth_model(t_img)) 366 | j_obj.append(norm.img_model(t_img)) 367 | 368 | j_3d = self.stack_and_permute_vid(j_3d) 369 | sample_dict["x_j"]["video"] = j_3d 370 | sample_dict["x_j"]["location"] = j_loc 371 | # sample_dict["x_j"]["depth"] = i_dep 372 | sample_dict["x_j"]["image"] = j_obj 373 | 374 | return sample_dict """ 375 | 376 | 377 | def return_expert_for_key_pretrained(self, key, raw_tensor): 378 | 379 | if key == "image": 380 | if len(raw_tensor) > 1: 381 | output = torch.stack(raw_tensor) 382 | output = output.transpose(0, 2) 383 | output = F.adaptive_avg_pool1d(output, 1) 384 | output = output.transpose(1, 0).squeeze(2) 385 | output = output.squeeze(1) 386 | else: 387 | output = raw_tensor[0].unsqueeze(0) 388 | 389 | if key == "motion" or key == "video": 390 | output = raw_tensor[0].unsqueeze(0) 391 | 392 | if key == "location": 393 | if len(raw_tensor) > 1: 394 | output = torch.stack(raw_tensor) 395 | output = output.transpose(0, 2) 396 | output = F.adaptive_avg_pool1d(output, 1) 397 | output = output.transpose(1, 0).squeeze(2) 398 | output = output.squeeze(1) 399 | else: 400 | output = raw_tensor[0].unsqueeze(0) 401 | 402 | return output 403 | 404 | 
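    # Note on return_expert_for_key_pretrained (above): assuming each loaded
    # embedding is shaped (1, 2048), a list of N "image" or "location" tensors
    # is stacked to (N, 1, 2048), transposed to (2048, 1, N), mean-pooled by
    # adaptive_avg_pool1d to (2048, 1, 1), and reshaped back to a single
    # (1, 2048) vector; "motion"/"video" experts keep only the first embedding.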
def __getitem__(self, idx): 405 | label = self.data_frame.at[idx, "label"] 406 | path = self.data_frame.at[idx, "path"] 407 | 408 | if self.pre_computed: 409 | embedding_dict = self.collect_pre_computed_embeddings(path, self.config, label) 410 | 411 | x_i = embedding_dict["x_i"] 412 | x_j = embedding_dict["x_j"] 413 | 414 | for key, value in x_i.items(): 415 | x_i[key] = self.return_expert_for_key_pretrained(key, value) 416 | 417 | for key, value in x_j.items(): 418 | x_j[key] = self.return_expert_for_key_pretrained(key, value) 419 | 420 | return {'label': embedding_dict['label'], 'x_i': x_i, 'x_j': x_j} 421 | 422 | # else: 423 | # embedding_dict = self.collect_embedding(embed_dict["video"], self.config) 424 | # x_i = embedding_dict["x_i"] 425 | # x_j = embedding_dict["x_j"] 426 | 427 | # for key, value in x_i.items(): 428 | # x_i[key] = self.ee.return_expert_for_key(key, value) 429 | 430 | # for key, value in x_j.items(): 431 | # x_j[key] = self.ee.return_expert_for_key(key, value) 432 | 433 | # return { 'label': embed_dict['label'], 'x_i': x_i, 'x_j': x_j } 434 | 435 | 436 | class CustomDataset(Dataset): 437 | def __init__(self, config): 438 | 439 | self.config = config 440 | self.data_frame = self.load_data() 441 | 442 | def load_data(self): 443 | data_frame = pd.read_csv(self.config['input_csv'].get(), chunksize=self.config['data_size'].get()) 444 | return data_frame 445 | 446 | # def stack(self): 447 | 448 | def __len__(self): 449 | return len(self.data_frame) 450 | 451 | def collect_embeddings(self, data_type, idx): 452 | embedding_stack = [] 453 | data_path = self.data_frame.at[idx, data_type] 454 | if len(os.listdir(data_path)) > 1: 455 | for embed in os.listdir(data_path): 456 | embed_path = os.path.join(data_path, embed) 457 | embedding_stack.append(torch.load(embed_path)) 458 | data = torch.stack(embedding_stack) 459 | else: 460 | for embed in os.listdir(data_path): 461 | embed_path = os.path.join(data_path, embed) 462 | data = torch.load(embed_path) 463 | 464 | return data 465 | 466 | 467 | def __getitem__(self, idx): 468 | embed_dict = dict() 469 | embed_dict["label"] = self.data_frame.at[idx, LABEL] 470 | embed_dict["chunk"] = self.data_frame.at[idx, CHUNK] 471 | expert_list = self.config['experts'].get() 472 | if "image" in expert_list: 473 | embed_dict["image"] = self.collect_embeddings(IMG_EMBED, idx) 474 | if "location" in expert_list: 475 | embed_dict["location"] = self.collect_embeddings(LOC_EMBED, idx) 476 | if "depth" in expert_list: 477 | embed_dict["depth"] = self.collect_embeddings(DEPTH_EMBED, idx) 478 | if "motion" in expert_list: 479 | embed_dict["motion"] = self.collect_embeddings(VID_EMBED, idx) 480 | if "audio" in expert_list: 481 | embed_dict["audio"] = self.collect_embeddings(AUDIO_EMBED, idx) 482 | 483 | return embed_dict 484 | 485 | -------------------------------------------------------------------------------- /src/dataloaders/mmx/MMX_Frame_dl.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | import _pickle as pickle 4 | from torch.utils.data import Dataset, DataLoader 5 | import pytorch_lightning as pl 6 | from PIL import Image 7 | from torchvision import transforms 8 | from random import randint 9 | 10 | 11 | class MMXFrameDataModule(pl.LightningDataModule): 12 | 13 | def __init__(self, train_data, val_data, config): 14 | super().__init__() 15 | self.train_data = train_data 16 | self.val_data = val_data 17 | self.config = config 18 | self.bs = config["batch_size"] 19 | 
self.seq_len = config["seq_len"] 20 | 21 | def load_data(self, db): 22 | data = [] 23 | with open(db, "rb") as pkly: 24 | while 1: 25 | try: 26 | # append if data serialised with open file 27 | data.append(pickle.load(pkly)) 28 | # else data not streamed 29 | # data = pickle.load(pkly) 30 | except EOFError: 31 | break 32 | 33 | data_frame = pd.DataFrame(data) 34 | data_frame = data_frame.reset_index(drop=True) 35 | # data_frame = data_frame.head(2000) 36 | print("length of data", len(data_frame)) 37 | return data_frame 38 | 39 | def setup(self, stage): 40 | self.train_data = self.load_data(self.train_data) 41 | self.val_data = self.load_data(self.val_data) 42 | 43 | def train_dataloader(self): 44 | return DataLoader(MMXFrameDataset(self.train_data, self.config, state="train"), self.bs, shuffle=True, num_workers=1, drop_last=True) 45 | 46 | def val_dataloader(self): 47 | return DataLoader(MMXFrameDataset(self.val_data, self.config, state="val"), self.bs, shuffle=False, num_workers=1, drop_last=True) 48 | 49 | def test_dataloader(self): 50 | return DataLoader(MMXFrameDataset(self.val_data, self.config, state="test"), self.bs, shuffle=False,drop_last=True, num_workers=5) 51 | 52 | 53 | class MMXFrameDataset(Dataset): 54 | def __init__(self, data, config, state="train"): 55 | super().__init__() 56 | 57 | self.config = config 58 | self.data_frame = data 59 | self.seq_len = self.config["seq_len"] 60 | self.state = state 61 | self.max_len = self.config["seq_len"] 62 | 63 | self.train_transform = transforms.Compose([ 64 | transforms.RandomResizedCrop(224), 65 | transforms.RandomHorizontalFlip(p=0.3), 66 | transforms.RandomVerticalFlip(p=0.3), 67 | transforms.AutoAugment(), 68 | transforms.ToTensor(), 69 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 70 | std=[0.229, 0.224, 0.225]), 71 | ]) 72 | 73 | self.val_transform = transforms.Compose([ 74 | transforms.Resize(230), 75 | transforms.CenterCrop(224), 76 | transforms.ToTensor(), 77 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 78 | std=[0.229, 0.224, 0.225]), 79 | ]) 80 | 81 | self.train_vid = transforms.Compose([ 82 | transforms.Resize(120), 83 | transforms.CenterCrop(112), 84 | transforms.ToTensor(), 85 | transforms.Normalize( 86 | mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]), 87 | transforms.RandomErasing(), 88 | ]) 89 | 90 | self.val_vid = transforms.Compose([ 91 | transforms.Resize(120), 92 | transforms.CenterCrop(112), 93 | transforms.ToTensor(), 94 | transforms.Normalize( 95 | mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]), 96 | ]) 97 | 98 | def __len__(self): 99 | return len(self.data_frame) 100 | 101 | def pil_loader(self, path): 102 | img = Image.open(path) 103 | img = img.convert('RGB') 104 | return img 105 | 106 | def img_trans(self, img): 107 | if self.state == "train": 108 | img = self.train_transform(img) 109 | else: 110 | img = self.val_transform(img) 111 | #img_tensor = img_tensor.float() 112 | return img 113 | 114 | def vid_trans(self, vid): 115 | if self.state == "train": 116 | vid = self.train_vid(vid) 117 | else: 118 | vid = self.val_vid(vid) 119 | return vid 120 | 121 | def __getitem__(self, idx): 122 | 123 | label = self.data_frame.at[idx, "label"] 124 | scenes = self.data_frame.at[idx, "scenes"] 125 | x = torch.empty([self.max_len, 3, 224, 224]) 126 | v = torch.empty([self.max_len, 12, 3, 112, 112]) 127 | img_list = torch.full_like(x, 0) 128 | vid = torch.full_like(v, 0) 129 | num_collected = 0 130 | for j, s in enumerate(scenes.values()): 131 | if num_collected == 
self.max_len: 132 | break 133 | try: 134 | clip = s[0] 135 | except KeyError: 136 | try: 137 | clip = s["000"] 138 | except KeyError: 139 | try: 140 | clip = s["0"] 141 | except: 142 | continue 143 | 144 | if self.config["model"] == "sum" or self.config["model"] == "distil" or self.config["model"] == "vid" or self.config["model"] == "pre_modal" or self.config["model"] == "sum_residual": 145 | if self.state == "train": 146 | start_slice = randint(0, len(clip) - 13) 147 | clip_slice = clip[start_slice:start_slice + 12] 148 | else: 149 | start_slice = 0 150 | clip_slice = clip[0:12] 151 | for i in range(12): 152 | vid[num_collected][i] = self.vid_trans( 153 | self.pil_loader(clip_slice[i])) 154 | img_t = self.img_trans(self.pil_loader(clip[randint(0, len(clip) -1)])) 155 | img_list[num_collected] = img_t 156 | #img_list = [] 157 | num_collected += 1 158 | # vid = vid.permute(0, 2, 1, 3, 4) 159 | if self.config["model"] == "sum" or self.config["model"] == "distil" or self.config["model"] == "pre_modal" or self.config["model"] == "sum_residual": 160 | return label, img_list, vid 161 | if self.config["model"] == "frame": 162 | return label, img_list 163 | if self.config["model"] == "vid": 164 | return label, vid 165 | 166 | -------------------------------------------------------------------------------- /src/dataloaders/mmx/MMX_Light_dl.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | import pandas as pd 3 | from simplejson import OrderedDict 4 | import torch 5 | import pickle 6 | import numpy as np 7 | from torch.utils.data import Dataset, DataLoader, IterableDataset 8 | import pytorch_lightning as pl 9 | from PIL import Image 10 | from torchvision import transforms 11 | from random import randint 12 | from collections import OrderedDict 13 | import glob 14 | from sklearn.utils import shuffle 15 | import nvidia.dali as dali 16 | from nvidia.dali import pipeline_def 17 | import nvidia.dali.fn as fn 18 | import nvidia.dali.types as types 19 | from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy 20 | from nvidia.dali.pipeline import Pipeline 21 | import nvidia.dali.ops as ops 22 | import nvidia.dali.types as types 23 | from nvidia.dali.plugin.pytorch import DALIGenericIterator 24 | 25 | 26 | class InputIterator(object): 27 | def __init__(self, data, batch_size_total, batch_size, seq_len, frame_len): 28 | super().__init__() 29 | self.data = data 30 | self.batch_size_total = batch_size_total 31 | self.batch_size = batch_size 32 | self.indices = list(range(len(self.data))) 33 | self.seq_len = seq_len 34 | self.frame_len = frame_len 35 | 36 | def __iter__(self): 37 | self.i = 0 38 | self.n = len(self.data) 39 | return self 40 | 41 | def __next__(self): 42 | batch = [] 43 | targets = [] 44 | for _ in range(self.batch_size): 45 | img_root = self.data["img_root"].iloc[self.i] 46 | labels = [] 47 | vids = [] 48 | for i in range(1, 7): 49 | labels.append(self.data[f"g{i}"].iloc[self.i]) 50 | target = self.collect_labels(labels) 51 | scenes = sorted(glob.glob(img_root + "/*")) 52 | scene_len = len(scenes) 53 | 54 | for scene in range(self.seq_len): 55 | imgs = sorted(glob.glob(scenes[scene % scene_len] + "/*")) 56 | img_len = len(imgs) 57 | print("img len", img_len) 58 | for img in range(self.frame_len): 59 | print("img img_len", img % img_len) 60 | f = open(imgs[img % img_len], 'rb') 61 | vid = np.frombuffer(f.read(), dtype=np.uint8) 62 | batch.append(vid) 63 | targets.append(target) 64 | self.i = (self.i 
+ 1) % self.n 65 | print("BATHC", len(batch)) 66 | return (batch, targets) 67 | 68 | def collect_labels(self, label): 69 | target_names = ['Action', 'Animation', 'Adventure', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Family', 70 | 'Fantasy', 'History', 'Horror', 'Music', 'Romance', 'Mystery', 'TVMovie', 'ScienceFiction', 'Thriller', 'War', 'Western'] 71 | label_list = np.zeros(19) 72 | for i, genre in enumerate(target_names): 73 | if genre in label: 74 | label_list[i] = 1 75 | if np.sum(label_list) == 0: 76 | label_list[6] = 1 77 | return label_list 78 | 79 | 80 | class SimplePipeline(Pipeline): 81 | def __init__(self, batch_size, eii, num_threads=2, device_id=0, resolution=256, crop=224, is_train=True): 82 | super(SimplePipeline, self).__init__( 83 | batch_size, num_threads, device_id, seed=12) 84 | self.source = ops.ExternalSource(source=eii, num_outputs=2) 85 | 86 | def define_graph(self): 87 | images, labels = self.source() 88 | images = fn.decoders.image( 89 | images, device="mixed", output_type=types.RGB) 90 | images = fn.resize( 91 | images, 92 | resize_shorter=fn.random.uniform(range=(120, 200)), 93 | interp_type=types.INTERP_LINEAR) 94 | images = fn.crop_mirror_normalize( 95 | images, 96 | crop_pos_x=fn.random.uniform(range=(0.0, 1.0)), 97 | crop_pos_y=fn.random.uniform(range=(0.0, 1.0)), 98 | dtype=types.FLOAT, 99 | crop=(112, 112), 100 | mean=[128., 128., 128.], 101 | std=[1., 1., 1.]) 102 | return images, labels 103 | 104 | 105 | class DALIClassificationLoader(DALIClassificationIterator): 106 | def __init__( 107 | self, pipelines, size=-1, reader_name=None, auto_reset=False, fill_last_batch=False, dynamic_shape=False, last_batch_padded=False): 108 | super().__init__(pipelines, 109 | size, 110 | reader_name, 111 | auto_reset, 112 | fill_last_batch, 113 | dynamic_shape, 114 | last_batch_padded) 115 | 116 | def __len__(self): 117 | batch_count = self._size // (self._num_gpus * self.batch_size) 118 | last_batch = 1 if self._fill_last_batch else 0 119 | print("COUNTER", batch_count) 120 | return batch_count + last_batch 121 | 122 | 123 | class MMXLightDataModule(pl.LightningDataModule): 124 | def __init__(self, csv_path, config): 125 | super().__init__() 126 | self.csv_path = csv_path 127 | self.config = config 128 | self.bs = config["batch_size"] 129 | self.seq_len = config["seq_len"] 130 | self.frame_len = config["frame_len"] 131 | self.bs_total = self.bs * self.seq_len * self.frame_len 132 | 133 | def setup(self, stage=None): 134 | 135 | self.data_frame = pd.read_csv(self.csv_path) 136 | self.data_frame = shuffle(self.data_frame) 137 | self.train_data = self.data_frame.iloc[:6047, :] 138 | self.train_data.reset_index(drop=True) 139 | self.val_data = self.data_frame.iloc[6047:6700, :] 140 | self.val_data.reset_index(drop=True) 141 | print(len(self.train_data)) 142 | print(len(self.val_data)) 143 | # device_id = self.local_rank 144 | # shard_id = self.global_rank 145 | # train_dataset = InputIterator( 146 | # self.train_data, self.bs_total, self.bs, self.seq_len, self.frame_len) 147 | # val_dataset = InputIterator( 148 | # self.val_data, self.bs_total, self.bs, self.seq_len, self.frame_len) 149 | 150 | # pipe_train = SimplePipeline( 151 | # batch_size=self.bs_total, eii=train_dataset, num_threads=2, device_id=0) 152 | # pipe_train.build() 153 | # self.train_loader = DALIClassificationLoader( 154 | # pipe_train, len(self.train_data), auto_reset=True) 155 | 156 | # pipe_val = SimplePipeline( 157 | # batch_size=self.bs_total, eii=val_dataset, num_threads=2, device_id=0) 158 | # 
pipe_val.build() 159 | # self.val_loader = DALIClassificationLoader( 160 | # pipe_train, len(self.val_data), auto_reset=True) 161 | self.train_loader = DataLoader(MMXLightDataset(self.train_data, self.config, state="train"), batch_size=self.bs, drop_last=True, num_workers=10, pin_memory=True) 162 | self.val_loader = DataLoader(MMXLightDataset(self.val_data, self.config, state="val"), batch_size=self.bs, drop_last=True, num_workers=10, pin_memory=True) 163 | 164 | def train_dataloader(self): 165 | return self.train_loader 166 | 167 | def val_dataloader(self): 168 | return self.val_loader 169 | 170 | def test_dataloader(self): 171 | return self.val_loader 172 | 173 | 174 | class MMXLightDataset(Dataset): 175 | def __init__(self, data, config, state="train"): 176 | super().__init__() 177 | 178 | self.config = config 179 | self.data_frame = data 180 | self.seq_len = self.config["seq_len"] 181 | self.state = state 182 | self.max_len = self.config["seq_len"] 183 | 184 | self.train_transform = transforms.Compose([ 185 | transforms.RandomResizedCrop(224), 186 | transforms.RandomHorizontalFlip(p=0.3), 187 | transforms.RandomVerticalFlip(p=0.3), 188 | transforms.AutoAugment(), 189 | transforms.ToTensor(), 190 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 191 | std=[0.229, 0.224, 0.225]), 192 | ]) 193 | 194 | self.val_transform = transforms.Compose([ 195 | transforms.Resize(230), 196 | transforms.CenterCrop(224), 197 | transforms.ToTensor(), 198 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 199 | std=[0.229, 0.224, 0.225]), 200 | ]) 201 | 202 | self.train_vid = transforms.Compose([ 203 | transforms.Resize(120), 204 | transforms.CenterCrop(112), 205 | transforms.ToTensor(), 206 | transforms.Normalize( 207 | mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]), 208 | #transforms.RandomErasing(), 209 | ]) 210 | 211 | self.val_vid = transforms.Compose([ 212 | transforms.Resize(112), 213 | transforms.CenterCrop(112), 214 | transforms.ToTensor(), 215 | transforms.Normalize( 216 | mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]), 217 | ]) 218 | 219 | def __len__(self): 220 | return len(self.data_frame) 221 | 222 | def pil_loader(self, path): 223 | img = Image.open(path) 224 | img = img.convert('RGB') 225 | return img 226 | 227 | def img_trans(self, img): 228 | if self.state == "train": 229 | img = self.train_transform(img) 230 | else: 231 | img = self.val_transform(img) 232 | # img_tensor = img_tensor.float() 233 | return img 234 | 235 | def collect_labels(self, label): 236 | 237 | target_names = ['Action', 'Animation', 'Adventure', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Family', 238 | 'Fantasy', 'History', 'Horror', 'Music', 'Romance', 'Mystery', 'TVMovie', 'ScienceFiction', 'Thriller', 'War', 'Western'] 239 | label_list = np.zeros(19) 240 | for i, genre in enumerate(target_names): 241 | if genre in label: 242 | label_list[i] = 1 243 | if np.sum(label_list) == 0: 244 | label_list[6] = 1 245 | return label_list 246 | 247 | def vid_trans(self, vid): 248 | if self.state == "train": 249 | vid = self.train_vid(vid) 250 | else: 251 | vid = self.val_vid(vid) 252 | return vid 253 | 254 | def __getitem__(self, idx): 255 | 256 | x = torch.empty([self.max_len, 3, 224, 224]) 257 | v = torch.empty([self.max_len, 12, 3, 112, 112]) 258 | img_list = torch.full_like(x, 0) 259 | vid = torch.full_like(v, 0) 260 | row = self.data_frame.iloc[idx] 261 | labels = [] 262 | for i in range(1, 6): 263 | labels.append(row[f"g{i}"]) 264 | img_root = row["img_root"] 265 | target = 
self.collect_labels(labels) 266 | scenes = sorted(glob.glob(img_root + "/*")) 267 | frame_dict = OrderedDict() 268 | 269 | for scene, img_dir in enumerate(scenes): 270 | frame_list = sorted(glob.glob(img_dir + "/*.png")) 271 | frame_dict[scene] = frame_list 272 | scene_len = len(scenes) 273 | i = 0 274 | for j in range(self.max_len): 275 | k = 0 276 | imgs = frame_dict[i] 277 | img_len = len(imgs) 278 | for x in range(12): 279 | vid[i][k] = self.vid_trans(self.pil_loader(imgs[k])) 280 | k += 1 281 | k = k % img_len 282 | #img_list[i] = self.img_trans(self.pil_loader(imgs[3])) 283 | i += 1 284 | i = i % scene_len 285 | 286 | return target, img_list, vid 287 | 288 | # Add looping for out of index samples 289 | # grab img as well as video frames or find model with same dim 290 | # set scenes etc as a parameter - not hardcoded 291 | -------------------------------------------------------------------------------- /src/dataloaders/mmx/MMX_Temporal_dl.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import pandas as pd 3 | import ast 4 | import random 5 | import csv 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import _pickle as pickle 10 | import os 11 | import numpy as np 12 | from collections import defaultdict 13 | from torch.utils.data import Dataset, random_split, DataLoader 14 | import pytorch_lightning as pl 15 | import json 16 | from sklearn.model_selection import train_test_split 17 | import random 18 | 19 | 20 | class MMXDataModule(pl.LightningDataModule): 21 | 22 | def __init__(self, train_data, val_data, config): 23 | super().__init__() 24 | self.train_data = train_data 25 | self.val_data = val_data 26 | self.config = config 27 | self.bs = config["batch_size"] 28 | self.seq_len = config["seq_len"] 29 | 30 | def custom_collater(self, batch): 31 | 32 | return { 33 | 'label': [x['label'] for x in batch], 34 | 'experts': [x['experts'] for x in batch], 35 | 'path': [x["path"] for x in batch] 36 | } 37 | 38 | # def prepare_data(self): 39 | # data = self.load_data(self.pickle_file) 40 | # self.data = self.clean_data(data) 41 | 42 | def clean_data(self, data_frame): 43 | target_names = ['Action', 'Adventure', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Family', 44 | 'Fantasy', 'History', 'Horror', 'Music', 'Mystery', 'Science Fiction', 'Thriller', 'War'] 45 | 46 | print("cleaning data") 47 | print(data_frame.describe()) 48 | 49 | longest_seq = 0 50 | for i in range(len(data_frame)): 51 | data = data_frame.at[i, "scenes"] 52 | label = data_frame.at[i, "label"] 53 | n_labels = 0 54 | for l in label[0]: 55 | if l not in target_names: 56 | n_labels += 1 57 | if n_labels == 6: 58 | data_frame = data_frame.drop(i) 59 | continue 60 | data_chunk = list(data.values()) 61 | if len(data_chunk) > longest_seq: 62 | longest_seq = len(data_chunk) 63 | if len(data_chunk) < 5: 64 | data_frame = data_frame.drop(i) 65 | continue 66 | 67 | data_frame = data_frame.reset_index(drop=True) 68 | return data_frame 69 | 70 | def load_data(self, db): 71 | data = [] 72 | with open(db, "rb") as pkly: 73 | while 1: 74 | try: 75 | # append if data serialised with open file 76 | data.append(pickle.load(pkly)) 77 | # else data not streamed 78 | # data = pickle.load(pkly) 79 | except EOFError: 80 | break 81 | 82 | data_frame = pd.DataFrame(data) 83 | #data_frame = data_frame.head(2000) 84 | print("data loaded") 85 | print("length", len(data_frame)) 86 | return data_frame 87 | 88 | def setup(self, stage): 89 | 90 | self.train_data = 
self.load_data(self.train_data) 91 | self.train_data = self.clean_data(self.train_data) 92 | self.val_data = self.load_data(self.val_data) 93 | self.val_data = self.clean_data(self.val_data) 94 | 95 | def train_dataloader(self): 96 | return DataLoader(MMXDataset(self.train_data, self.config, state="train"), self.bs, shuffle=True, num_workers=2, drop_last=True) 97 | 98 | def val_dataloader(self): 99 | return DataLoader(MMXDataset(self.val_data, self.config, state="val"), self.bs, shuffle=False, num_workers=2, drop_last=True) 100 | 101 | def test_dataloader(self): 102 | return DataLoader(MMXDataset(self.val_data, self.config, state="test"), self.bs, shuffle=False, collate_fn=self.custom_collater, drop_last=True) 103 | 104 | 105 | class MMXDataset(Dataset): 106 | def __init__(self, data, config, state="train"): 107 | super().__init__() 108 | 109 | self.config = config 110 | self.data_frame = data 111 | self.aggregation = None 112 | self.seq_len = self.config["seq_len"] 113 | self.state = state 114 | 115 | def __len__(self): 116 | return len(self.data_frame) 117 | 118 | def collect_labels(self, label): 119 | 120 | target_names = ['Action', 'Adventure', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Family', 121 | 'Fantasy', 'History', 'Horror', 'Music', 'Mystery', 'Science Fiction', 'Thriller', 'War'] 122 | label_list = np.zeros(15) 123 | 124 | for i, genre in enumerate(target_names): 125 | if genre == "Sci-Fi" or genre == "ScienceFiction": 126 | genre = "Science Fiction" 127 | if genre in label: 128 | label_list[i] = 1 129 | if np.sum(label_list) == 0: 130 | label_list[5] = 1 131 | 132 | return label_list 133 | 134 | def load_tensor(self, tensor): 135 | tensor = torch.load(tensor, map_location=torch.device('cpu')) 136 | return tensor 137 | 138 | def return_expert_path(self, path, expert): 139 | if self.state == "val": 140 | expert = "test-" + expert 141 | 142 | try: 143 | scene_list = path[list(path.keys())[0]][expert] 144 | except KeyError: 145 | try: 146 | scene_list = path[list(path.keys())[0]][ex] 147 | except KeyError: 148 | scene_list = False 149 | except IndexError: 150 | scene_list = False 151 | except FileNotFoundError: 152 | scene_list = False 153 | return scene_list 154 | 155 | def retrieve_tensors(self, path, expert): 156 | tensor_paths = self.return_expert_path(path, expert) 157 | 158 | if tensor_paths: 159 | if expert == "img-embeddings" or expert == "location-embeddings": 160 | tensor_paths = tensor_paths[-1] 161 | try: 162 | t = self.load_tensor(tensor_paths) 163 | except FileNotFoundError: 164 | t = torch.zeros((1, 2048)) 165 | if expert == "audio-embeddings": 166 | t = t.unsqueeze(0) 167 | if t.shape[-1] != 2048: 168 | # zero pad dimensions. 
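# ConstantPad1d((0, n), 0) appends n zeros to the right of the last dimension, so an
# expert embedding narrower than 2048 (e.g. a hypothetical (1, 128) audio feature)
# becomes (1, 2048) and every expert shares a common width before mixing.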
169 | t = nn.ConstantPad1d((0, 2048 - t.shape[-1]), 0)(t) 170 | else: 171 | t = torch.zeros((1, 2048)) 172 | if self.state == "train": 173 | t = self.add_transforms(t) 174 | return t 175 | 176 | def add_transforms(self, x): 177 | if random.random() < 0.3: 178 | x = torch.zeros((1, 2048)) 179 | if random.random() < 0.3: 180 | x = x + (0.1**0.5)*torch.randn(1, 2048) 181 | return x 182 | 183 | def label_tidy(self, label): 184 | if len(label) == 2: 185 | return self.collect_labels(label[0]) 186 | else: 187 | return self.collect_labels(label) 188 | 189 | def multi_model_item_collection(self, scene_path): 190 | expert_tensor_list = [] 191 | for expert in self.config["experts"]: 192 | if self.config["mixing_method"] == "concat-norm": 193 | t = F.normalize(self.retrieve_tensors( 194 | scene_path, expert), p=2, dim=-1) 195 | else: 196 | t = self.retrieve_tensors(scene_path, expert) 197 | # Retrieve the tensors for each expert. 198 | expert_tensor_list.append(t) 199 | if self.config["mixing_method"] == "concat": 200 | # concat experts for pre model 201 | cat_experts = torch.cat(expert_tensor_list, dim=-1) 202 | # expert_list.append(cat_experts) 203 | if self.config["cat_norm"] == True: 204 | cat_experts = F.normalize( 205 | cat_experts, p=2, dim=-1) 206 | if self.config["cat_softmax"] == True: 207 | cat_experts = F.softmax(cat_experts, dim=-1) 208 | expert_list.append(cat_experts) 209 | elif self.config["mixing_method"] == "collab" or self.config["mixing_method"] == "post_collab": 210 | expert_list.append(torch.stack(expert_tensor_list)) 211 | 212 | def __getitem__(self, idx): 213 | 214 | # retrieve labels 215 | label = self.data_frame.at[idx, "label"] 216 | label = self.label_tidy(label) 217 | path = self.data_frame.at[idx, "path"] 218 | label = torch.tensor(label).unsqueeze(0) # Covert label to tensor 219 | scenes = self.data_frame.at[idx, "scenes"] 220 | expert_list = [] 221 | 222 | # iterate through the scenes for the trailer 223 | 224 | for d in scenes.values(): 225 | if len(expert_list) < self.seq_len: # collect tensors until sequence length 226 | expert_tensor_list = [] 227 | # otherwise return one expert 228 | try: 229 | tensor = self.retrieve_tensors( 230 | d, self.config["experts"][0]) 231 | except IndexError: 232 | continue 233 | except KeyError: 234 | print("key error", d) 235 | continue 236 | except IsADirectoryError: 237 | continue 238 | expert_list.append(tensor) 239 | 240 | if self.config["mixing_method"] == "collab" or self.config["mixing_method"] == "post_collab": 241 | while len(expert_list) < self.seq_len: 242 | pad_list = [] 243 | for i in range(len(self.config["experts"])): 244 | pad_list.append(torch.zeros_like(expert_list[0][0])) 245 | expert_list.append(torch.stack(pad_list)) 246 | if self.config["mixing_method"] == "post_collab": 247 | expert_list = torch.stack(expert_list) 248 | expert_list = expert_list.squeeze() 249 | else: 250 | while len(expert_list) < self.seq_len: 251 | expert_list.append(torch.zeros_like(expert_list[0])) 252 | 253 | expert_list = torch.cat(expert_list, dim=0) # scenes 254 | expert_list = expert_list.unsqueeze(0) 255 | 256 | return {"label": label, "path": path, "experts": expert_list} 257 | -------------------------------------------------------------------------------- /src/frame.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ed-fish/data-efficient-video-transformers/0a7d1b40563e244df14f4f33376cad413b2ba558/src/frame.png 
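MMXDataset above reduces each trailer to a fixed-length sequence of per-scene expert embeddings: tensors narrower than 2048 are zero-padded to a common width, trailers with fewer scenes than seq_len are padded with zero tensors, and the genre labels become a 15-way multi-hot vector with a Drama fallback. The snippet below is a minimal, self-contained sketch of that contract rather than code from this repository; the helper names encode_genres and pad_scene_sequence, and the seq_len value, are illustrative assumptions.

import torch
import torch.nn as nn

GENRES = ['Action', 'Adventure', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Family',
          'Fantasy', 'History', 'Horror', 'Music', 'Mystery', 'Science Fiction', 'Thriller', 'War']

def encode_genres(labels, fallback="Drama"):
    # Multi-hot vector over the 15 target genres; a trailer with no recognised
    # genre falls back to Drama, mirroring collect_labels() in MMXDataset.
    vec = torch.zeros(len(GENRES))
    for i, genre in enumerate(GENRES):
        if genre in labels:
            vec[i] = 1.0
    if vec.sum() == 0:
        vec[GENRES.index(fallback)] = 1.0
    return vec

def pad_scene_sequence(scene_embeddings, seq_len, dim=2048):
    # Right-pad each expert embedding to `dim` and zero-pad the sequence to
    # `seq_len`, so every trailer yields a (seq_len, dim) tensor.
    padded = []
    for t in scene_embeddings[:seq_len]:
        if t.shape[-1] != dim:
            t = nn.ConstantPad1d((0, dim - t.shape[-1]), 0)(t)
        padded.append(t)
    while len(padded) < seq_len:
        padded.append(torch.zeros(1, dim))
    return torch.cat(padded, dim=0)

# Example: a trailer with three scenes, one expert narrower than 2048.
scenes = [torch.randn(1, 2048), torch.randn(1, 128), torch.randn(1, 2048)]
experts = pad_scene_sequence(scenes, seq_len=36)
label = encode_genres(["Comedy", "Biography"])
print(experts.shape, label.nonzero().squeeze(-1))  # torch.Size([36, 2048]) tensor([2])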
-------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch.nn as nn 3 | import wandb 4 | import yaml 5 | import torch 6 | from pytorch_lightning.loggers import WandbLogger 7 | from dataloaders.mmx.MMX_Temporal_dl import MMXDataModule 8 | from dataloaders.mmx.MMX_Frame_dl import MMXFrameDataset, MMXFrameDataModule 9 | from dataloaders.mmx.MMX_Light_dl import MMXLightDataset, MMXLightDataModule 10 | from dataloaders.mit.MIT_Temporal_dl import MITDataModule, MITDataset 11 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 12 | from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR 13 | from models.LSTM import LSTMRegressor 14 | from models.transformer import SimpleTransformer 15 | from models.frame_transformer import FrameTransformer 16 | from callbacks.callbacks import TransformerEval, DisplayResults, MITEval 17 | from yaml.loader import SafeLoader 18 | import torch.nn as nn 19 | 20 | from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad 21 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget 22 | from pytorch_grad_cam.utils.image import show_cam_on_image 23 | 24 | if __name__ == "__main__": 25 | torch.manual_seed(1130) 26 | callbacks = [] 27 | with open('config.yaml') as f: 28 | data = yaml.load(f, Loader=SafeLoader) 29 | wandb.init(project="transformer-frame-video", name="seed_test_1130", 30 | config=data) 31 | config = wandb.config 32 | print(config["data_set"]) 33 | 34 | wandb_logger = WandbLogger( 35 | project=config["logger"]) 36 | 37 | if config["model"] == "ptn" or config["model"] == "ptn_shared": 38 | model = SimpleTransformer(**config) 39 | elif config["model"] == "lstm": 40 | model = LSTMRegressor(seq_len=200, batch_size=64, 41 | criterion=nn.BCELoss(), n_features=4608, hidden_size=512, num_layers=4, 42 | dropout=0.2, learning_rate=0.00005) 43 | elif config["model"] == "frame_transformer" or config["model"] == "distil" or config["model"] == "sum" or config["model"] == "frame" or config["model"] == "vid" or config["model"] == "pre_modal" or config["model"] == "sum_residual": 44 | model = FrameTransformer(**config) 45 | 46 | if config["data_set"] == "mit": 47 | miteval = MITEval() 48 | dm = MITDataModule("data/mit/MIT_train_temporal.pkl", 49 | "data/mit/MIT_validation_temporal.pkl", config) 50 | callbacks = [miteval] 51 | 52 | elif config["data_set"] == "mmx": 53 | transformer_callback = TransformerEval() 54 | dm = MMXDataModule("data/mmx/mmx_train_temporal.pkl", 55 | "data/mmx/mmx_val_temporal.pkl", config) 56 | callbacks = [transformer_callback] 57 | # checkpoint = ModelCheckpoint( 58 | # save_top_k=-1, dirpath="trained_models/mmx/double", filename="double-{epoch:02d}") 59 | # display = DisplayResults() 60 | 61 | elif config["data_set"] == "mmx-frame": 62 | transformer_callback = TransformerEval() 63 | dm = MMXLightDataModule("data/mmx/light/out.csv", config) 64 | if config["test"]: 65 | display = DisplayResults() 66 | callbacks = [transformer_callback] 67 | else: 68 | callbacks = [transformer_callback] 69 | else: 70 | assert( 71 | "No dataset selected, please update the configuration \n mit, mmx, mmx-frame") 72 | 73 | def weights_init_normal(m): 74 | '''Takes in a module and initializes all linear layers with weight 75 | values taken from a normal distribution.''' 76 | classname = 
m.__class__.__name__ 77 | # for every Linear layer in a model 78 | if classname.find('Linear') != -1: 79 | y = m.in_features 80 | # m.weight.data shoud be taken from a normal distribution 81 | m.weight.data.normal_(0.0, 1/np.sqrt(y)) 82 | # m.bias.data should be 0 83 | m.bias.data.fill_(0) 84 | # weights_init_normal(model) 85 | #trainer = pl.Trainer(gpus=1, logger=wandb_logger, callbacks=callbacks, accumulate_grad_batches=8, precision=16, max_epochs=50) 86 | 87 | trainer = pl.Trainer(gpus=1, logger=wandb_logger, 88 | callbacks=callbacks, max_epochs=1000) 89 | model = model.load_from_checkpoint("transformer-frame-video/2wxq6ed1/checkpoints/epoch=32-step=24947.ckpt") 90 | 91 | #trainer.fit(model, datamodule=dm) 92 | # dm.setup() 93 | # loader = dm.val_dataloader() 94 | # target, img, vid, frame_list = next(iter(loader)) 95 | # model = model.vid_model.backbone 96 | # target_layers = [model.layer4[-1]] 97 | 98 | # vid = vid.view(-1, 12, 3, 112, 112) 99 | # vid = vid.permute(0, 2, 1, 3, 4) 100 | # vid = vid[0].unsqueeze(0) 101 | # print(vid.shape) 102 | 103 | # cam = GradCAM(model=model, target_layers=target_layers) 104 | # grayscale_cam = cam(input_tensor=vid) 105 | # print(grayscale_cam.shape) 106 | # grayscale_cam = grayscale_cam[0, :] 107 | # print(grayscale_cam.shape) 108 | # visualization = show_cam_on_image(frame_list[0], grayscale_cam, use_rgb=True) 109 | 110 | 111 | trainer.test(model, datamodule=dm, ckpt_path="transformer-frame-video/2wxq6ed1/checkpoints/epoch=32-step=24947.ckpt") 112 | -------------------------------------------------------------------------------- /src/models/.contrastivemodel.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ed-fish/data-efficient-video-transformers/0a7d1b40563e244df14f4f33376cad413b2ba558/src/models/.contrastivemodel.py.swp -------------------------------------------------------------------------------- /src/models/LSTM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pytorch_lightning as pl 4 | 5 | 6 | class LSTMRegressor(pl.LightningModule): 7 | ''' 8 | Standard PyTorch Lightning module: 9 | https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html 10 | ''' 11 | 12 | def __init__(self, 13 | n_features, 14 | hidden_size, 15 | seq_len, 16 | batch_size, 17 | num_layers, 18 | dropout, 19 | learning_rate, 20 | criterion): 21 | super(LSTMRegressor, self).__init__() 22 | self.n_features = n_features 23 | self.hidden_size = hidden_size 24 | self.seq_len = seq_len 25 | self.batch_size = batch_size 26 | self.num_layers = num_layers 27 | self.dropout = dropout 28 | self.criterion = criterion 29 | self.running_logits = [] 30 | self.running_labels = [] 31 | self.learning_rate = learning_rate 32 | self.lstm = nn.LSTM(input_size=n_features, 33 | hidden_size=hidden_size, 34 | batch_first=True, 35 | num_layers=num_layers, 36 | dropout=dropout) 37 | self.linear = nn.Linear(hidden_size, 15) 38 | 39 | def forward(self, x): 40 | # lstm_out = (batch_size, seq_len, hidden_size) 41 | # print(x.shape) 42 | lstm_out, _ = self.lstm(x) 43 | y_pred = self.linear(lstm_out[:, -1]) 44 | return y_pred 45 | 46 | def configure_optimizers(self): 47 | return torch.optim.Adam(self.parameters(), lr=self.learning_rate) 48 | 49 | def training_step(self, batch, batch_idx): 50 | x = batch["experts"] 51 | x = torch.stack(x).squeeze(1) 52 | y = batch["label"] 53 | y = torch.cat(y).squeeze(1) 54 | y = y.float() 55 | 
y_hat = self(x) 56 | y_hat = torch.sigmoid(y_hat) 57 | loss = self.criterion(y_hat, y) 58 | #result = pl.TrainResult(loss) 59 | self.log('train_loss', loss) 60 | return loss 61 | 62 | def validation_step(self, batch, batch_idx): 63 | x = batch["experts"] 64 | x = torch.stack(x, dim=0).squeeze(1) 65 | y = batch["label"] 66 | y = torch.cat(y).squeeze(1) 67 | y = y.float() 68 | y_hat = self(x) 69 | y_hat = torch.sigmoid(y_hat) 70 | loss = self.criterion(y_hat, y) 71 | print("-" * 20) 72 | print(y[0]) 73 | print(y_hat[0]) 74 | print("-" * 20) 75 | self.running_labels.append(y) 76 | self.running_logits.append(y_hat) 77 | #result = pl.EvalResult(checkpoint_on=loss) 78 | self.log('val_loss', loss) 79 | return loss 80 | 81 | def test_step(self, batch, batch_idx): 82 | x, y = batch 83 | y_hat = self(x) 84 | loss = self.criterion(y_hat, y) 85 | self.running_labels.append(y) 86 | self.running_logits.append(y_hat) 87 | result.log('test_loss', loss) 88 | return result 89 | -------------------------------------------------------------------------------- /src/models/TPN.py: -------------------------------------------------------------------------------- 1 | 2 | class Feature_Pyramid_Mid(nn.Module): 3 | def __init__(self): 4 | super(Feature_Pyramid_Mid, self).__init__() 5 | self.pool_branch = nn.Sequential( 6 | nn.AvgPool2d(kernel_size=14), 7 | ) 8 | self.channels_reduce = nn.Conv2d(256, 256, kernel_size=1) 9 | 10 | def forward(self, mid): 11 | pyramid = self.pool_branch(mid) 12 | output = self.channels_reduce(pyramid) 13 | return output 14 | 15 | 16 | class Feature_Pyramid_High(nn.Module): 17 | def __init__(self): 18 | super(Feature_Pyramid_High, self).__init__() 19 | self.pool_branch = nn.Sequential( 20 | nn.AvgPool2d(kernel_size=7), 21 | ) 22 | self.channels_reduce = nn.Conv2d(512, 512, kernel_size=1) 23 | 24 | def forward(self, high): 25 | pyramid = self.pool_branch(high) 26 | return pyramid 27 | 28 | 29 | class Feature_Pyramid_low(nn.Module): 30 | def __init__(self): 31 | super(Feature_Pyramid_low, self).__init__() 32 | self.pool_branch = nn.Sequential( 33 | nn.AvgPool2d(kernel_size=28), 34 | ) 35 | self.channels_reduce = nn.Conv2d(128, 128, kernel_size=1) 36 | 37 | def forward(self, low): 38 | pyramid = self.pool_branch(low) 39 | output = self.channels_reduce(pyramid) 40 | return output 41 | 42 | 43 | class TPN(pl.LightningModule): 44 | def __init__(self): 45 | super(TPN, self).__init__() 46 | self.net = custom_resnet.resnet34(True) 47 | self.pyramid_low = Feature_Pyramid_low() 48 | self.pyramid_mid = Feature_Pyramid_Mid() 49 | self.pyramid_high = Feature_Pyramid_High() 50 | self.reason = Reasoning() 51 | # self.fusion = nn.Sequential(nn.ReLU(), nn.Linear()) 52 | 53 | def forward(self, x): 54 | low, mid, high = self.net(x) 55 | low_0 = self.pyramid_low(low).squeeze() 56 | mid_0 = self.pyramid_mid(mid).squeeze() 57 | high_0 = self.pyramid_high(high).squeeze() 58 | cnn_out = torch.cat((high_0, mid_0, low_0), dim=-1).unsqueeze(0) 59 | frame_feature = cnn_out.view(-1, 4 * 5, 896) 60 | output = self.reason(cnn_out) 61 | return output 62 | 63 | 64 | def sum_group(x, groups=2): 65 | batch, pics, vector = x.size() 66 | concatenation = [] 67 | for group_num in range(int(pics / groups)): 68 | segments = x[:, groups*group_num: groups*(group_num+1), :] 69 | segments = torch.sum(segments, dim=1) 70 | concatenation.append(segments) 71 | concatenation = torch.cat(concatenation, dim=1) 72 | return concatenation 73 | 74 | 75 | class Reasoning(nn.Module): 76 | def __init__(self, num_segments=4, num_frames=5, 
num_class=15, img_dim=896, max_group=4, start=2): 77 | super(Reasoning, self).__init__() 78 | self.num_segments = num_segments 79 | self.num_frames = num_frames 80 | self.num_class = num_class 81 | self.img_feature_dim = img_dim 82 | self.num_groups = max_group 83 | self.start = start 84 | self.relation = nn.ModuleList() 85 | self.classifier_scales = nn.ModuleList() 86 | num_bottleneck = 512 87 | for scales in range(self.start, self.num_groups+1): 88 | fc_fusion = nn.Sequential( 89 | nn.ReLU(), 90 | nn.Linear(self.img_feature_dim * int(self.num_segments * 91 | self.num_frames/scales), num_bottleneck), 92 | nn.ReLU(), 93 | nn.Dropout(p=0.6), 94 | nn.Linear(num_bottleneck, num_bottleneck), 95 | nn.ReLU(), 96 | nn.Dropout(p=0.5), 97 | nn.Linear(num_bottleneck, self.num_class), 98 | nn.Sigmoid(), 99 | # nn.ReLU(), 100 | # nn.Dropout(p=0.6), 101 | ) 102 | self.relation += [fc_fusion] 103 | # classifier = nn.Linear(num_bottleneck, self.num_class) 104 | # self.classifier_scales += [classifier] 105 | 106 | def forward(self, x): 107 | prediction = 0 108 | for segment_group in range(self.start, self.num_groups+1): 109 | segments = sum_group(x, groups=segment_group) 110 | segments = self.relation[segment_group-self.start](segments) 111 | prediction = prediction + segments 112 | return prediction / (self.num_groups-self.start+1) 113 | -------------------------------------------------------------------------------- /src/models/basicmlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchmetrics 3 | import torch 4 | import torch.nn.functional as F 5 | import pytorch_lightning as pl 6 | import math 7 | # from models.pretrained.models import EmbeddingExtractor 8 | from models.losses.ntxent import ContrastiveLoss 9 | 10 | class BasicMLP(pl.LightningModule): 11 | def __init__(self, config): 12 | super(BasicMLP, self).__init__() 13 | self.input_layer_size = config["input_shape"].get() 14 | self.bottleneck_size = config["bottle_neck"].get() 15 | self.output_layer_size = config["output_shape"].get() 16 | self.batch_size = config["batch_size"].get() 17 | self.config = config 18 | self.softmax = nn.LogSoftmax(dim=-1) 19 | # self.ee = EmbeddingExtractor(self.config) 20 | self.f1 = torchmetrics.F1(num_classes=22) 21 | 22 | self.fc1 = nn.Linear(self.input_layer_size, self.input_layer_size) 23 | self.batchnorm = nn.BatchNorm1d(1024) 24 | self.fc2 = nn.Linear(self.input_layer_size, self.bottleneck_size) 25 | self.fc3 = nn.Linear(self.bottleneck_size, self.bottleneck_size) 26 | self.fc4 = nn.Linear(self.bottleneck_size, 305) 27 | #self.loss = nn.BCEWithLogitsLoss() 28 | self.loss = nn.CrossEntropyLoss() 29 | self.acc = torchmetrics.Accuracy() 30 | 31 | def forward(self, tensor): 32 | output = F.relu(self.fc1(tensor)) 33 | output = self.batchnorm(F.relu(self.fc2(output))) 34 | embedding = F.relu(self.fc3(output)) 35 | output = self.fc4(embedding) 36 | return output 37 | 38 | def configure_optimizers(self): 39 | optimizer = torch.optim.Adam(self.parameters(), 40 | lr=self.config["learning_rate"].get()) 41 | return optimizer 42 | 43 | def expert_aggregation(self, expert_list): 44 | agg = self.config["aggregation"].get() 45 | 46 | if agg == "avg_pool": 47 | expert_list = torch.cat(expert_list, dim=-1) 48 | expert_list = F.adaptive_avg_pool2d(expert_list, self.input_layer_size) 49 | 50 | if agg == "mean_pool": 51 | expert_list = torch.cat(expert_list, dim=-1) 52 | expert_list = F.adaptive_max_pool2d(expert_list, size) 53 | 54 | if agg == "concat": 
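# "concat" joins the experts along the feature dimension, e.g. two (batch, 2048)
# experts become a single (batch, 4096) input vector for the MLP.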
55 | expert_list = torch.cat(expert_list, dim=-1) 56 | 57 | return expert_list 58 | 59 | def debug(self, x_i, x_j): 60 | for keys, values in x_i.items(): 61 | print(keys, values.shape) 62 | 63 | def training_step(self, batch, batch_idx): 64 | x_i_experts = batch["x_i_experts"] 65 | labels = batch["label"] 66 | 67 | #x_i_experts = [self.expert_aggregation(x) for x in x_i_experts] 68 | #x_j_input = self.expert_aggregation(x_j_experts).squeeze(1) 69 | 70 | x_i_input = torch.stack(x_i_experts) 71 | labels = torch.tensor(labels) 72 | labels = labels.to(self.device) 73 | x_i_input = x_i_input.squeeze() 74 | 75 | output = self(x_i_input) 76 | 77 | output = output.squeeze() 78 | loss = self.loss(output, labels) 79 | self.log("training loss", loss, on_step=True, on_epoch=True) 80 | return loss 81 | 82 | def validation_step(self, batch, batch_idx): 83 | x_i_experts = batch["x_i_experts"] 84 | 85 | labels = batch["label"] 86 | labels = torch.tensor(labels) 87 | labels = labels.to(self.device) 88 | 89 | #x_i_experts = [self.expert_aggregation(x) for x in x_i_experts] 90 | #x_j_input = self.expert_aggregation(x_j_experts).squeeze(1) 91 | 92 | x_i_input = torch.stack(x_i_experts) 93 | #label = torch.tensor(label) 94 | x_i_input = x_i_input.squeeze() 95 | 96 | output = self(x_i_input) 97 | output = output.squeeze() 98 | print("out", output) 99 | print("label", labels) 100 | loss = self.loss(output, labels) 101 | self.log("validation loss", loss, on_step=True, on_epoch=True) 102 | #accuracy = self.f1(output, labels) 103 | #accuracy = self.acc(output, labels) 104 | #self.log("f1 score", accuracy, on_step=True, on_epoch=True) 105 | #self.log("val acc", accuracy, on_step=True, on_epoch=True) 106 | return loss 107 | 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /src/models/collabgating.py: -------------------------------------------------------------------------------- 1 | 2 | class CollaborativeGating(pl.LightningModule): 3 | def __init__(self): 4 | super(CollaborativeGating, self).__init__() 5 | self.proj_input = 2048 6 | self.proj_embedding_size = 2048 7 | self.projection = nn.Linear(self.proj_input, self.proj_embedding_size) 8 | self.cg = ContextGating(self.proj_input) 9 | self.geu = GatedEmbeddingUnit(self.proj_input, 1024, False) 10 | 11 | def pad(self, tensor): 12 | tensor = tensor.unsqueeze(0) 13 | curr_expert = F.interpolate(tensor, 2048) 14 | curr_expert = curr_expert.squeeze(0) 15 | return curr_expert 16 | 17 | def forward(self, batch): 18 | batch_list = [] 19 | for scenes in batch: # this will be batches 20 | scene_list = [] 21 | # first expert popped off 22 | for experts in scenes: 23 | expert_attention_vec = [] 24 | for i in range(len(experts)): 25 | curr_expert = experts.pop(0) 26 | if curr_expert.shape[1] != 2048: 27 | curr_expert = self.pad(curr_expert) 28 | 29 | # compare with all other experts 30 | curr_expert = self.projection(curr_expert) 31 | t_i_list = [] 32 | for c_expert in experts: 33 | # through g0 to get feature embedding t_i 34 | if c_expert.shape[1] != 2048: 35 | c_expert = self.pad(c_expert) 36 | c_expert = self.projection(c_expert) 37 | t_i = curr_expert + c_expert # t_i maps y1 to y2 38 | t_i_list.append(t_i) 39 | t_i_summed = torch.stack(t_i_list, dim=0).sum( 40 | dim=0) # all other features 41 | # attention vector for all comparrisons 42 | expert_attention = self.projection(t_i_summed) 43 | expert_attention_comp = self.cg( 44 | curr_expert, expert_attention) # gated version 45 | 
expert_attention_vec.append(expert_attention_comp) 46 | experts.append(curr_expert) 47 | expert_attention_vec = torch.stack(expert_attention_vec, dim=0).sum( 48 | dim=0) # concat all attention vectors 49 | # apply gated embedding 50 | expert_vector = self.geu(expert_attention_vec) 51 | scene_list.append(expert_vector) 52 | scene_stack = torch.stack(scene_list) 53 | batch_list.append(scene_stack) 54 | batch = torch.stack(batch_list, dim=0) 55 | batch = batch.squeeze(2) 56 | return batch 57 | 58 | 59 | class GatedEmbeddingUnit(nn.Module): 60 | def __init__(self, input_dimension, output_dimension, use_bn): 61 | super(GatedEmbeddingUnit, self).__init__() 62 | 63 | self.fc = nn.Linear(input_dimension, output_dimension) 64 | # self.cg = ContextGating(output_dimension, add_batch_norm=use_bn) 65 | 66 | def forward(self, x): 67 | x = self.fc(x) 68 | # x = self.cg(x) 69 | x = F.normalize(x) 70 | return x 71 | 72 | 73 | class ContextGating(nn.Module): 74 | def __init__(self, dimension, add_batch_norm=True): 75 | super(ContextGating, self).__init__() 76 | # self.add_batch_norm = add_batch_norm 77 | # self.batch_norm = nn.BatchNorm1d(dimension) 78 | # self.batch_norm2 = nn.BatchNorm1d(dimension) 79 | 80 | def forward(self, x, x1): 81 | 82 | # if self.add_batch_norm: 83 | # x = self.batch_norm(x) 84 | # x1 = self.batch_norm2(x1) 85 | t = x + x1 86 | x = torch.cat((x, t), -1) 87 | return F.glu(x, -1) 88 | -------------------------------------------------------------------------------- /src/models/contrastivemodel.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import pytorch_lightning as pl 5 | import math 6 | from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR 7 | # from pl_bolts.optimizers.lars_scheduling import LARSWrapper 8 | from pl_bolts.optimizers.lars import LARS 9 | # from models.pretrained.models import EmbeddingExtractor 10 | from models.losses.ntxent import ContrastiveLoss 11 | 12 | class SpatioTemporalContrastiveModel(pl.LightningModule): 13 | def __init__(self, config): 14 | super().__init__() 15 | self.input_layer_size = config["input_shape"] 16 | self.hidden_layer_size = config["hidden_layer"] 17 | self.projection_size = config["projection_size"] 18 | self.output_layer_size = config["output_shape"] 19 | self.batch_size = config["batch_size"] 20 | self.num_samples = config["num_samples"] 21 | self.config = config 22 | self.train_iters_per_epoch = self.num_samples // self.batch_size 23 | self.running_logits = [] 24 | self.running_labels = [] 25 | # self.ee = EmbeddingExtractor(self.config) 26 | 27 | self.encoder_net = nn.Sequential( 28 | nn.Linear(self.input_layer_size, self.hidden_layer_size, bias=False), 29 | nn.ReLU(inplace=True), 30 | nn.BatchNorm1d(self.hidden_layer_size), 31 | nn.Linear(self.hidden_layer_size, self.hidden_layer_size, bias=False), 32 | nn.ReLU(inplace=True), 33 | nn.Linear(self.hidden_layer_size, self.projection_size), 34 | ) 35 | 36 | self.projector_net = nn.Sequential( 37 | nn.ReLU(inplace=True), 38 | nn.Linear(self.projection_size, self.projection_size), 39 | nn.ReLU(inplace=True), 40 | nn.Dropout(p=0.1), 41 | nn.Linear(self.projection_size, self.output_layer_size), 42 | ) 43 | 44 | self.loss = ContrastiveLoss(self.batch_size) 45 | 46 | self.proj_list = [] 47 | self.label_list = [] 48 | 49 | def forward(self, tensor): 50 | # if self.config["aggregation"].get() == "collab": 51 | embedding = self.encoder_net(tensor) 52 | # 
embedding = F.normalize(embedding) 53 | output = self.projector_net(embedding) 54 | 55 | return embedding, output 56 | 57 | def configure_optimizers(self): 58 | # parameters = self.exclude_from_wt_decay( 59 | # self.named_parameters(), 60 | # weight_decay=self.config["weight_decay"] 61 | # ) 62 | 63 | optimizer = torch.optim.Adam(self.parameters(), lr=self.config["learning_rate"], weight_decay=self.config["weight_decay"]) 64 | # optimizer = LARS( 65 | # parameters, 66 | # lr=self.config["learning_rate"], 67 | # momentum=self.config["momentum"], 68 | # weight_decay=self.config["weight_decay"], 69 | # trust_coefficient=0.0001, 70 | # ) 71 | 72 | # Trick 2 (after each step) 73 | # self.hparams.warmup_epochs = self.config["warm_up"] * self.train_iters_per_epoch 74 | # max_epochs = self.trainer.max_epochs * self.train_iters_per_epoch 75 | 76 | # linear_warmup_cosine_decay = LinearWarmupCosineAnnealingLR( 77 | # optimizer, 78 | # warmup_epochs=2, 79 | # max_epochs=10, 80 | # warmup_start_lr=0, 81 | # eta_min=0 82 | # ) 83 | 84 | scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=self.config["epochs"] // 10, max_epochs=self.config["epochs"]) 85 | 86 | # scheduler = { 87 | # 'scheduler': linear_warmup_cosine_decay, 88 | # 'interval': 'step', 89 | # 'frequency': 1 90 | # } 91 | 92 | return [optimizer], [scheduler] 93 | # optimizer = torch.optim.Adam(self.parameters(), 94 | # lr=self.config["learning_rate"]) 95 | 96 | # return optimizer 97 | 98 | 99 | def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=['bias', 'bn']): 100 | params = [] 101 | excluded_params = [] 102 | 103 | for name, param in named_params: 104 | if not param.requires_grad: 105 | continue 106 | elif any(layer_name in name for layer_name in skip_list): 107 | excluded_params.append(param) 108 | else: 109 | params.append(param) 110 | 111 | return [ 112 | {'params': params, 'weight_decay': weight_decay}, 113 | {'params': excluded_params, 'weight_decay': 0.} 114 | ] 115 | 116 | def expert_aggregation(self, expert_list): 117 | agg = self.config["aggregation"] 118 | 119 | # in the case that there is just one expert 120 | if agg == "none": 121 | expert_list = expert_list[0] 122 | 123 | if agg == "avg_pool": 124 | expert_list = torch.cat(expert_list, dim=-1) 125 | expert_list = F.adaptive_avg_pool2d(expert_list, self.input_layer_size) 126 | 127 | if agg == "mean_pool": 128 | expert_list = torch.cat(expert_list, dim=-1) 129 | expert_list = F.adaptive_max_pool2d(expert_list, size) 130 | 131 | if agg == "concat": 132 | expert_list = torch.cat(expert_list, dim=-1) 133 | 134 | if agg == "collab_gate": 135 | pass 136 | 137 | return expert_list 138 | 139 | def debug(self, x_i, x_j): 140 | for keys, values in x_i.items(): 141 | print(keys, values.shape) 142 | 143 | def training_step(self, batch, batch_idx): 144 | x_i_experts = batch["x_i_experts"] 145 | x_j_experts = batch["x_j_experts"] 146 | label = batch["label"] 147 | 148 | #x_i_input = self.expert_aggregation(x_i_experts).squeeze(1) 149 | #x_j_input = self.expert_aggregation(x_j_experts).squeeze(1) 150 | 151 | x_i_experts = [self.expert_aggregation(x) for x in x_i_experts] 152 | x_j_experts = [self.expert_aggregation(x) for x in x_j_experts] 153 | 154 | x_i_experts = torch.stack(x_i_experts) 155 | x_j_experts = torch.stack(x_j_experts) 156 | 157 | x_i_experts = x_i_experts.squeeze() 158 | x_j_experts = x_j_experts.squeeze() 159 | 160 | x_i_embedding, x_i_out = self(x_i_experts) 161 | x_j_embedding, x_j_out = self(x_j_experts) 162 | 163 | x_i_out = 
F.normalize(x_i_out.squeeze()) 164 | x_j_out = F.normalize(x_j_out.squeeze()) 165 | 166 | loss = self.loss(x_i_out, x_j_out) 167 | self.log("train/contrastive/loss", loss) 168 | return loss 169 | 170 | def validation_step(self, batch, batch_idx): 171 | 172 | x_i_experts = batch["x_i_experts"] 173 | x_j_experts = batch["x_j_experts"] 174 | label = batch["label"] 175 | 176 | #x_i_input = self.expert_aggregation(x_i_experts).squeeze(1) 177 | #x_j_input = self.expert_aggregation(x_j_experts).squeeze(1) 178 | 179 | x_i_experts = [self.expert_aggregation(x) for x in x_i_experts] 180 | x_j_experts = [self.expert_aggregation(x) for x in x_j_experts] 181 | 182 | x_i_experts = torch.stack(x_i_experts) 183 | x_j_experts = torch.stack(x_j_experts) 184 | 185 | 186 | x_i_experts = x_i_experts.squeeze() 187 | x_j_experts = x_j_experts.squeeze() 188 | 189 | 190 | x_i_embedding, x_i_out = self(x_i_experts) 191 | x_j_embedding, x_j_out = self(x_j_experts) 192 | 193 | x_i_out = x_i_out.squeeze() 194 | x_j_out = x_j_out.squeeze() 195 | 196 | loss = self.loss(x_i_out, x_j_out) 197 | 198 | self.log("val/contrastive/loss", loss) 199 | return {"loss":loss, "val_outputs":x_i_embedding} 200 | 201 | 202 | 203 | def test_step(self, batch, batch_idx): 204 | x_i_experts = batch["x_i_experts"] 205 | x_i_experts = [self.expert_aggregation(x) for x in x_i_experts] 206 | x_i_experts = torch.stack(x_i_experts) 207 | 208 | label = batch["label"] 209 | x_i_embeddings, _ = self(x_i_experts) 210 | x_i_out = x_i_embeddings.squeeze() 211 | self.proj_list.append(x_i_out) 212 | self.label_list.append(label) 213 | list_len = len(self.proj_list) 214 | return {"length": list_len} 215 | 216 | 217 | # def validation_epoch_end(self, val_step_outputs): 218 | # print(val_step_outputs[0]["val_outputs"].shape) 219 | 220 | 221 | -------------------------------------------------------------------------------- /src/models/custom_resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152'] 8 | 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | """3x3 convolution with padding""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None): 29 | super(BasicBlock, self).__init__() 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += 
residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class Bottleneck(nn.Module): 58 | expansion = 4 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None): 61 | super(Bottleneck, self).__init__() 62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 63 | self.bn1 = nn.BatchNorm2d(planes) 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 65 | padding=1, bias=False) 66 | self.bn2 = nn.BatchNorm2d(planes) 67 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 68 | self.bn3 = nn.BatchNorm2d(planes * 4) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x): 74 | residual = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv3(out) 85 | out = self.bn3(out) 86 | 87 | if self.downsample is not None: 88 | residual = self.downsample(x) 89 | 90 | out += residual 91 | out = self.relu(out) 92 | 93 | return out 94 | 95 | 96 | class ResNet(nn.Module): 97 | 98 | def __init__(self, block, layers, num_classes=1000): 99 | self.inplanes = 64 100 | super(ResNet, self).__init__() 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 102 | bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 110 | self.avgpool = nn.AvgPool2d(7, stride=1) 111 | self.fc = nn.Linear(512 * block.expansion, num_classes) 112 | 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 116 | m.weight.data.normal_(0, math.sqrt(2. / n)) 117 | elif isinstance(m, nn.BatchNorm2d): 118 | m.weight.data.fill_(1) 119 | m.bias.data.zero_() 120 | 121 | def _make_layer(self, block, planes, blocks, stride=1): 122 | downsample = None 123 | if stride != 1 or self.inplanes != planes * block.expansion: 124 | downsample = nn.Sequential( 125 | nn.Conv2d(self.inplanes, planes * block.expansion, 126 | kernel_size=1, stride=stride, bias=False), 127 | nn.BatchNorm2d(planes * block.expansion), 128 | ) 129 | 130 | layers = [] 131 | layers.append(block(self.inplanes, planes, stride, downsample)) 132 | self.inplanes = planes * block.expansion 133 | for i in range(1, blocks): 134 | layers.append(block(self.inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | x = self.conv1(x) 140 | x = self.bn1(x) 141 | x = self.relu(x) 142 | x = self.maxpool(x) 143 | 144 | x = self.layer1(x) 145 | x2 = self.layer2(x) 146 | x3 = self.layer3(x2) 147 | x4 = self.layer4(x3) 148 | 149 | x = self.avgpool(x4) 150 | x = x.view(x.size(0), -1) 151 | x = self.fc(x) 152 | 153 | return x2, x3, x4 154 | 155 | 156 | def resnet18(pretrained=False, **kwargs): 157 | """Constructs a ResNet-18 model. 
158 | Args: 159 | pretrained (bool): If True, returns a model pre-trained on ImageNet 160 | """ 161 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 162 | if pretrained: 163 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 164 | return model 165 | 166 | 167 | def resnet34(pretrained=False, **kwargs): 168 | """Constructs a ResNet-34 model. 169 | Args: 170 | pretrained (bool): If True, returns a model pre-trained on ImageNet 171 | """ 172 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 173 | if pretrained: 174 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 175 | return model 176 | 177 | 178 | def resnet50(pretrained=False, **kwargs): 179 | """Constructs a ResNet-50 model. 180 | Args: 181 | pretrained (bool): If True, returns a model pre-trained on ImageNet 182 | """ 183 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 184 | if pretrained: 185 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 186 | return model 187 | 188 | 189 | def resnet101(pretrained=False, **kwargs): 190 | """Constructs a ResNet-101 model. 191 | Args: 192 | pretrained (bool): If True, returns a model pre-trained on ImageNet 193 | """ 194 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 195 | if pretrained: 196 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 197 | return model 198 | 199 | 200 | def resnet152(pretrained=False, **kwargs): 201 | """Constructs a ResNet-152 model. 202 | Args: 203 | pretrained (bool): If True, returns a model pre-trained on ImageNet 204 | """ 205 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 206 | if pretrained: 207 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 208 | return model -------------------------------------------------------------------------------- /src/models/frame_transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pytorch_lightning as pl 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoderLayer, TransformerDecoder 8 | from torchmetrics import AUROC, F1, AveragePrecision 9 | from einops import rearrange 10 | import wandb 11 | from models import custom_resnet 12 | from torchvision.utils import make_grid, save_image 13 | import pickle as pkl 14 | from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad 15 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget 16 | from pytorch_grad_cam.utils.image import show_cam_on_image 17 | 18 | 19 | class PositionalEncoding(pl.LightningModule): 20 | def __init__(self, d_model, dropout=0.1, max_len=4): 21 | super(PositionalEncoding, self).__init__() 22 | self.dropout = nn.Dropout(p=dropout) 23 | pe = torch.zeros(max_len, d_model) 24 | position = torch.arange(0, max_len).unsqueeze(1) 25 | div_term = torch.exp(torch.arange( 26 | 0, d_model, 2).float() * (-math.log(1000.0) / d_model)) 27 | pe[:, 0::2] = torch.sin(position * div_term) 28 | pe[:, 1::2] = torch.cos(position * div_term) 29 | pe = pe.unsqueeze(0).transpose(0, 1) 30 | self.register_buffer('pe', pe) 31 | 32 | def forward(self, x): 33 | x = x + self.pe[:x.size(0), :] 34 | return self.dropout(x) 35 | 36 | 37 | class TransformerBase(pl.LightningModule): 38 | def __init__(self, input_dimension, output_dimension, nhead, nhid, 39 | nlayers, dropout): 40 | super(TransformerBase, 
self).__init__() 41 | encoder_layer = TransformerEncoderLayer( 42 | input_dimension, nhead, nhid, dropout) 43 | nlayers = nlayers 44 | self.transformer = TransformerEncoder(encoder_layer, nlayers) 45 | 46 | def forward(self, x): 47 | return self.transformer(x) 48 | 49 | 50 | class ImgResNet(pl.LightningModule): 51 | def __init__(self): 52 | super(ImgResNet, self).__init__() 53 | self.backbone = models.resnet18(pretrained=True) 54 | num_filters = self.backbone.fc.in_features 55 | self.backbone.fc = nn.Sequential(nn.Linear(num_filters, 896)) 56 | 57 | def forward(self, x): 58 | # self.feature_extractor.eval() 59 | with torch.no_grad(): 60 | representations = self.backbone(x) 61 | return representations 62 | 63 | 64 | class VidResNet(pl.LightningModule): 65 | def __init__(self): 66 | super(VidResNet, self).__init__() 67 | self.backbone = models.video.r2plus1d_18(pretrained=True) 68 | num_filters = self.backbone.fc.in_features 69 | self.backbone.fc = nn.Sequential(nn.Linear(num_filters, 896)) 70 | 71 | def forward(self, x): 72 | #with torch.no_grad(): 73 | representations = self.backbone(x) 74 | return representations 75 | 76 | 77 | # class LocationResNet(pl.LightningDataModule): 78 | # def __init__(self): 79 | # super(LocationResNet, self).__init__() 80 | # self.backbone = 81 | 82 | 83 | class FrameTransformer(pl.LightningModule): 84 | def __init__(self, **kwargs): 85 | super(FrameTransformer, self).__init__() 86 | self.save_hyperparameters() 87 | if self.hparams.cls: 88 | self.hparams.seq_len += 1 89 | self.criterion = nn.BCEWithLogitsLoss() 90 | self.distil_criterion = nn.CrossEntropyLoss() 91 | self.position_encoder = PositionalEncoding( 92 | 896, 0.5, 93 | max_len=14) 94 | #self.img_model = ImgResNet() 95 | self.vid_model = VidResNet() 96 | # self.cls_token = nn.Parameter( 97 | # torch.randn(1, 1, self.hparams.input_dimension)) 98 | #self.scene_transformer = TransformerBase(896, 896, 4, 896, 4, 0.5) 99 | self.distil_transformer = TransformerBase(896, 128, 2, 512, 4, 0.5) 100 | self.running_labels = [] 101 | self.running_logits = [] 102 | self.running_paths = [] 103 | self.running_embeds = [] 104 | #self.img_cls = nn.Parameter(torch.rand(1, 3, 224, 224)) 105 | self.vid_cls = nn.Parameter(torch.rand(1, 12, 3, 112, 112)) 106 | self.img_mlp_head = nn.Sequential(nn.Linear(896, 512), nn.GELU(), nn.Linear(512, 128),nn.GELU(), nn.Linear(128, 19)) 107 | #self.vid_mlp_head = nn.Sequential( 108 | #nn.LayerNorm(896), nn.Linear(896, 19)) 109 | # self.decoder = nn.Sequential(nn.Linear(75, 32), nn.GELU(), nn.Dropout( 110 | # 0.5), nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 15)) 111 | # self.encoder = nn.Sequential(nn.Linear(256, 256), nn.Dropout(0.5)) 112 | self.running_logits = [] 113 | self.running_labels = [] 114 | #self.val_auroc = AUROC(num_classes=19) 115 | #self.train_auroc = AUROC(num_classes=19) 116 | self.train_aprc = AveragePrecision(num_classes=19) 117 | self.norm = nn.LayerNorm(896) 118 | # self.tpn = TPN() 119 | self.val_aprc = AveragePrecision(num_classes=19) 120 | # self.pool = nn.AdaptiveAvgPool2d((1, 15)) 121 | self.cos = nn.CosineSimilarity(dim=1) 122 | 123 | def configure_optimizers(self): 124 | if self.hparams.opt == "sgd": 125 | optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.learning_rate, 126 | momentum=self.hparams.momentum, weight_decay=self.hparams.weight_decay) 127 | elif self.hparams.opt == "adamW": 128 | optimizer = torch.optim.AdamW(self.parameters( 129 | ), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) 130 | 131 | elif 
self.hparams.opt == "adagrad": 132 | optimizer = torch.optim.Adagrad(self.parameters( 133 | ), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) 134 | return optimizer 135 | 136 | def forward(self, img, vid): 137 | total = [] 138 | # out_img = self.img_step(img) 139 | if self.hparams.model == "distil": 140 | img, vid = self.distillation_step(img, vid) 141 | return img, vid 142 | 143 | if self.hparams.model == "sum": 144 | img, vid = self.distillation_step(img, vid) 145 | output = img + vid 146 | output = self.img_mlp_head(output) 147 | return output 148 | 149 | if self.hparams.model == "sum_residual": 150 | vid_cls = self.vid_step(vid) 151 | img_cls, seq = self.img_step(img, vid_cls) 152 | embed_list = [] 153 | # for s in seq: 154 | # s = self.img_mlp_head(s) 155 | # embed_list.append(s) 156 | # return embed_list 157 | img_cls = F.normalize(img_cls, p=2.0, dim=-1) 158 | vid_cls = F.normalize(img_cls, p=2.0, dim=-1) 159 | embed = img_cls + vid_cls 160 | output = self.img_mlp_head(embed) 161 | return output 162 | 163 | if self.hparams.model == "post_sum": 164 | img, vid, vid_cls = self.distillation_step(img, vid) 165 | output = img + vid_cls 166 | output = self.img_mlp_head(output) 167 | return output 168 | 169 | if self.hparams.model == "frame": 170 | img_cls = self.img_step(img, None) 171 | return img_cls 172 | 173 | if self.hparams.model == "pre_modal": 174 | img_cls = self.pre_modal(img, vid) 175 | return img_cls 176 | 177 | if self.hparams.model == "vid": 178 | vid_cls = self.vid_step(vid) 179 | vid_cls = self.img_mlp_head(vid_cls) 180 | return vid_cls 181 | 182 | def distillation_step(self, img, vid): 183 | vid_cls = self.vid_step(vid) 184 | img_cls, vid_tkn = self.img_step(img, vid_cls) 185 | return img_cls, vid_tkn 186 | 187 | def pre_modal(self, img, vid): 188 | vid = self.vid_step 189 | img_cls = self.img_step(img, vid) 190 | return img_cls 191 | 192 | def vid_step(self, data): 193 | total = [] 194 | for d in range(len(data)): 195 | cls_d = torch.cat((self.vid_cls, data[d]), dim=0) 196 | total.append(cls_d) 197 | data = torch.stack(total) 198 | data = data.view(-1, 12, 3, 112, 112) 199 | data = data.permute(0, 2, 1, 3, 4) 200 | data = self.vid_model(data) 201 | 202 | if self.hparams.model == "pre-modal": 203 | return data 204 | data = data.view(self.hparams.batch_size, 14, 896) 205 | data = data.permute(1, 0, 2) 206 | data = self.position_encoder(data) 207 | data = self.distil_transformer(data) 208 | data = data.permute(1, 0, 2) 209 | vid_cls = data[:, 0] 210 | return vid_cls 211 | 212 | def img_step(self, data, distil_inject): 213 | total = [] 214 | for d in range(len(data)): 215 | cls_d = torch.cat((self.img_cls, data[d]), dim=0) 216 | total.append(cls_d) 217 | data = torch.stack(total) 218 | data = data.view(-1, 3, 224, 224) 219 | data = self.img_model(data) 220 | if self.hparams.model == "pre-modal": 221 | data = data + distil_inject 222 | # data = [batch + cls, dim] 223 | data = data.view(self.hparams.batch_size, self.hparams.seq_len, -1) 224 | data = data.permute(1, 0, 2) 225 | if self.hparams.model == "sum": 226 | data = torch.cat((data, distil_inject)) 227 | data = self.position_encoder(data) 228 | #img_seq = data.permute(1, 0, 2) 229 | #img_seq = self.norm(img_seq) 230 | #img_seq = img_seq.permute(1, 0, 2) 231 | img_seq = self.scene_transformer(data) 232 | img_seq = img_seq.permute(1, 0, 2) 233 | cls = img_seq[:, 0] 234 | if self.hparams.model == "distil": 235 | dis_tkn = img_seq[:, -1] 236 | return cls, dis_tkn 237 | if self.hparams.model == "sum": 
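# In the "sum" variant the video CLS token is appended to the frame sequence before the
# transformer, so the last position of img_seq holds the video summary; it is returned
# alongside the frame CLS so forward() can add the two before the MLP head.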
238 | dis_tkn = img_seq[:, -1] 239 | return cls, dis_tkn 240 | if self.hparams.model == "sum_residual": 241 | return cls, img_seq 242 | else: 243 | cls = self.img_mlp_head(cls) 244 | return cls 245 | 246 | def training_step(self, batch, batch_idx): 247 | if self.hparams.model == "distil": 248 | target, img, vid = batch 249 | img, vid = self(img, vid) 250 | distil_loss = self.distil_criterion(img, torch.argmax(vid, dim=-1)) 251 | base_loss = self.criterion(img, target) 252 | loss = base_loss + distil_loss 253 | self.log("train/distilloss", distil_loss, 254 | on_step=True, on_epoch=True) 255 | self.log("train/bass_loss", base_loss, 256 | on_step=True, on_epoch=True) 257 | self.log("train/cossim", self.cos(img, vid)[0], 258 | on_step=True, on_epoch=True) 259 | data = img 260 | if self.hparams.model == "sum" or self.hparams.model == "pre_modal" or self.hparams.model == "sum_residual": 261 | target, img, vid = batch 262 | data = self(img, vid) 263 | loss = self.criterion(data, target) 264 | if self.hparams.model == "frame": 265 | target, img, vid = batch 266 | data = self(img, None) 267 | target = target.float() 268 | loss = self.criterion(data, target) 269 | if self.hparams.model == "vid": 270 | target, img, vid = batch 271 | data = self(None, vid) 272 | target = target.float() 273 | loss = self.criterion(data, target) 274 | 275 | target = target.int() 276 | #self.train_auroc(data, target) 277 | self.train_aprc(data, target) 278 | self.log("train/loss", loss, on_step=False, on_epoch=True) 279 | 280 | #self.log("train/auroc", self.train_auroc, on_step=True, on_epoch=True) 281 | self.log("train/aprc", self.train_aprc, on_step=False, on_epoch=True) 282 | return loss 283 | 284 | def translate_labels(self, label_vec): 285 | target_names = ['Action', 'Adventure', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Family', 286 | 'Fantasy', 'History', 'Horror', 'Music', 'Mystery', 'Science Fiction', 'Thriller', 'War'] 287 | labels = [] 288 | for i, l in enumerate(label_vec): 289 | if l: 290 | labels.append(target_names[i]) 291 | return labels 292 | 293 | def validation_step(self, batch, batch_idx): 294 | if self.hparams.model == "distil": 295 | img, vid = self(img, vid) 296 | distil_loss = self.distil_criterion(img, torch.argmax(vid, dim=-1)) 297 | base_loss = self.criterion(img, target) 298 | self.log("val/distilloss", distil_loss, 299 | on_step=True, on_epoch=True) 300 | self.log("val/base_loss", base_loss, on_step=True, on_epoch=True) 301 | self.log("val/cossim", self.cos(img, vid)[0], 302 | on_step=True, on_epoch=True) 303 | loss = base_loss + distil_loss 304 | data = img 305 | elif self.hparams.model == "sum" or self.hparams.model == "pre_modal" or self.hparams.model == "sum_residual": 306 | target = batch[0]['label'].reshape(self.hparams.batch_size,-1, 19) 307 | target = target[:, 0, :] 308 | vid = batch[0]['data'] 309 | vid = vid.reshape(self.hparams.batch_size, self.hparams.seq_len -1, self.hparams.frame_len, 3, 112, 112) 310 | 311 | data = self(img, vid) 312 | loss = self.criterion(data, target) 313 | elif self.hparams.model == "frame": 314 | target, img, vid = batch 315 | data = self(img, None) 316 | target = target.float() 317 | loss = self.criterion(data, target) 318 | elif self.hparams.model == "vid": 319 | target, img, vid = batch 320 | #target = batch[0]['label'].reshape(self.hparams.batch_size,-1, 19) 321 | #target = target[:, 0, :] 322 | #vid = batch[0]['data'] 323 | #print(vid.shape) 324 | #print(target.shape) 325 | #vid = vid.reshape(self.hparams.batch_size, self.hparams.seq_len - 1, 
self.hparams.frame_len, 3, 112, 112) 326 | data = self(None, vid) 327 | target = target.float() 328 | loss = self.criterion(data, target) 329 | 330 | target = target.int() 331 | sig_data = F.sigmoid(data) 332 | self.running_logits.append(sig_data) 333 | self.running_labels.append(target) 334 | #format_target = self.translate_labels(target[0]) 335 | #format_logits = self.translate_labels((sig_data[0] > 0.2).to(int)) 336 | # images = wandb.Image( 337 | # grid, caption=f"predicted: {format_logits}, actual {format_target}") 338 | # self.logger.experiment.log({"examples": images}) 339 | #self.val_auroc(data, target) 340 | self.val_aprc(data, target) 341 | # self.val_f1_2(data, target) 342 | self.log("val/loss", loss, on_epoch=True) 343 | #self.log("val/auroc", self.val_auroc, on_step=True, on_epoch=True) 344 | self.log("val/aprc", self.val_aprc, on_step=False, on_epoch=True) 345 | return loss 346 | 347 | def test_step(self, batch, batch_idx): 348 | if self.hparams.model == "distil": 349 | img, vid, path = self(img, vid) 350 | data = img 351 | elif self.hparams.model == "sum" or self.hparams.model == "pre_modal" or self.hparams.model == "sum_residual": 352 | target, img, vid = batch 353 | embed = self(img, vid) 354 | elif self.hparams.model == "frame": 355 | target, img, path = batch 356 | data = self(img, None) 357 | target = target.float() 358 | elif self.hparams.model == "vid": 359 | target, img, vid = batch 360 | data = self(None, vid) 361 | target = target.float() 362 | 363 | target = target.int() 364 | self.running_logits.append(F.sigmoid(data)) 365 | # self.running_embeds.append(data) 366 | self.running_labels.append(target) 367 | # self.running_paths.append(path) 368 | 369 | -------------------------------------------------------------------------------- /src/models/losses/ntxent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class NT_Xent(nn.Module): 6 | def __init__(self, batch_size, temperature, world_size): 7 | super(NT_Xent, self).__init__() 8 | self.batch_size = batch_size 9 | self.temperature = temperature 10 | self.world_size = world_size 11 | self.mask = self.mask_correlated_samples(batch_size, world_size) 12 | self.criterion = nn.CrossEntropyLoss(reduction="sum") 13 | self.similarity_f = nn.CosineSimilarity(dim=2) 14 | 15 | def mask_correlated_samples(self, batch_size, world_size): 16 | N = 2 * batch_size * world_size 17 | mask = torch.ones((N, N), dtype=bool) 18 | mask = mask.fill_diagonal_(0) 19 | for i in range(batch_size * world_size): 20 | mask[i, batch_size + i] = 0 21 | mask[batch_size + i, i] = 0 22 | return mask 23 | 24 | def forward(self, z_i, z_j): 25 | N = 2 * self.batch_size * self.world_size 26 | 27 | z = torch.cat((z_i, z_j), dim=0) 28 | 29 | sim = self.similarity_f(z.unsqueeze( 30 | 1), z.unsqueeze(0)) / self.temperature 31 | 32 | sim_i_j = torch.diag(sim, self.batch_size * self.world_size) 33 | sim_j_i = torch.diag(sim, -self.batch_size * self.world_size) 34 | 35 | positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1) 36 | negative_samples = sim[self.mask].reshape(N, -1) 37 | 38 | labels = torch.zeros(N).to(positive_samples.device).long() 39 | logits = torch.cat((positive_samples, negative_samples), dim=1) 40 | loss = self.criterion(logits, labels) 41 | loss /= N 42 | 43 | 44 | class ContrastiveLoss(nn.Module): 45 | def __init__(self, batch_size, temperature=0.5): 46 | super().__init__() 47 | self.batch_size = batch_size 
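# temperature is stored as a buffer so it follows the module between devices;
# negatives_mask is ~I over the 2 * batch_size augmented views, removing only each
# view's similarity with itself from the NT-Xent denominator, as in SimCLR.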
48 | self.register_buffer("temperature", torch.tensor(temperature)) 49 | self.register_buffer("negatives_mask", (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool)).float()) 50 | 51 | 52 | def forward(self, emb_i, emb_j): 53 | """ 54 | emb_i and emb_j are batches of embeddings, where corresponding indices are pairs 55 | z_i, z_j as per SimCLR paper 56 | """ 57 | #z_i = F.normalize(emb_i, dim=1) 58 | #z_j = F.normalize(emb_j, dim=1) 59 | 60 | z_i = emb_i 61 | z_j = emb_j 62 | 63 | representations = torch.cat([z_i, z_j], dim=0) 64 | similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2) 65 | 66 | sim_ij = torch.diag(similarity_matrix, self.batch_size) 67 | sim_ji = torch.diag(similarity_matrix, -self.batch_size) 68 | positives = torch.cat([sim_ij, sim_ji], dim=0) 69 | 70 | nominator = torch.exp(positives / self.temperature) 71 | denominator = self.negatives_mask * torch.exp(similarity_matrix / self.temperature) 72 | 73 | loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1)) 74 | loss = torch.sum(loss_partial) / (2 * self.batch_size) 75 | return loss 76 | -------------------------------------------------------------------------------- /src/models/pretrained/models.py: -------------------------------------------------------------------------------- 1 | import torchvision.models as torch_models 2 | import torch.utils.model_zoo as model_zoo 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch 6 | 7 | 8 | class EmbeddingExtractor: 9 | def __init__(self, config): 10 | self.image_net = torch_models.resnet50(pretrained=True) 11 | self.video_net = torch_models.video.r3d_18(pretrained=True) 12 | self.location_net = torch_models.resnet50(pretrained=False) 13 | # self.audio_net = torch.hub.load('harritaylor/torchvggish', 'vggish') 14 | location_weights = 'https://download.pytorch.org/models/resnet50-19c8e357.pth' 15 | self.location_net.load_state_dict(model_zoo.load_url(location_weights)) 16 | # self.depth_net = torch.hub.load("intel-isl/MiDas", "MiDaS") 17 | self.location_net.fc = Identity() 18 | # self.depth_net.scratch = Identity() 19 | self.image_net.fc = Identity() 20 | self.video_net.fc = Identity() 21 | self.device = torch.device(config["gpu"].get(int)) 22 | 23 | def init_models(self, m): 24 | m = m.to(self.device) 25 | m = m.eval() 26 | 27 | def forward_img(self, tensor): 28 | self.init_models(self.image_net) 29 | with torch.no_grad(): 30 | tensor = tensor.to(self.device) 31 | output = self.image_net.forward(tensor).cpu() 32 | return output 33 | 34 | def forward_location(self, tensor): 35 | self.init_models(self.location_net) 36 | with torch.no_grad(): 37 | tensor = tensor.to(self.device) 38 | output = self.location_net.forward(tensor).cpu() 39 | return output 40 | 41 | def forward_depth(self, tensor): 42 | self.init_models(self.depth_net) 43 | with torch.no_grad(): 44 | tensor = tensor.to(self.device) 45 | output = self.depth_net.forward(tensor).cpu() 46 | return output 47 | 48 | def forward_video(self, tensor_stack): 49 | self.init_models(self.video_net) 50 | with torch.no_grad(): 51 | tensor_stack = tensor_stack.to(self.device) 52 | output = self.video_net.forward(tensor_stack).cpu() 53 | return output 54 | 55 | def forward_audio(self, audio_sample): 56 | output = self.audio_net.forward(audio_sample) 57 | return output 58 | 59 | def depth_network_pool(self, depth_output): 60 | 61 | with torch.no_grad(): 62 | depth_output = torch.flatten(depth_output, start_dim=1).unsqueeze(0) 63 | pool = ((1, 
2048)).to(self.device) 64 | depth_output = pool(depth_output) 65 | depth_output = depth_output.squeeze(0) 66 | output = depth_output.cpu() 67 | return output 68 | 69 | def return_expert_for_key(self, key, raw_tensor): 70 | 71 | output = [] 72 | if key == "image": 73 | for img in raw_tensor: 74 | img = img.squeeze(1) 75 | output.append(self.forward_img(img).to('cpu')) 76 | output = torch.stack(output) 77 | output = output.transpose(0, 2) 78 | output = F.adaptive_avg_pool1d(output, 1) 79 | output = output.transpose(1, 0).squeeze(2) 80 | 81 | if key == "motion" or key == "video": 82 | with torch.no_grad(): 83 | raw_tensor = raw_tensor.unsqueeze(0) 84 | output = self.forward_video(raw_tensor) 85 | 86 | if key == "location": 87 | for img in raw_tensor: 88 | with torch.no_grad(): 89 | img = img.squeeze(1) 90 | output.append(self.forward_location(img).cpu()) 91 | output = torch.stack(output) 92 | output = output.transpose(0, 2) 93 | output = F.adaptive_avg_pool1d(output, 1) 94 | output = output.transpose(1, 0).squeeze(2) 95 | 96 | return output 97 | 98 | 99 | def return_expert_for_key_pretrained(self, key, raw_tensor): 100 | 101 | output = [] 102 | if key == "image": 103 | output = torch.stack(raw_tensor) 104 | output = output.transpose(0, 2) 105 | output = F.adaptive_avg_pool1d(output, 1) 106 | output = output.transpose(1, 0).squeeze(2) 107 | output= output.squeeze(1) 108 | print(output.shape) 109 | 110 | if key == "motion" or key == "video": 111 | output = raw_tensor[0].unsqueeze(0) 112 | print(output.shape) 113 | 114 | if key == "location": 115 | output = torch.stack(raw_tensor) 116 | output = output.transpose(0, 2) 117 | output = F.adaptive_avg_pool1d(output, 1) 118 | output = output.transpose(1, 0).squeeze(2) 119 | output = output.squeeze(1) 120 | print(output.shape) 121 | 122 | return output 123 | 124 | 125 | 126 | class Identity(nn.Module): 127 | def forward(self, x): 128 | return x 129 | -------------------------------------------------------------------------------- /src/models/transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pytorch_lightning as pl 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoderLayer, TransformerDecoder 7 | from einops import rearrange 8 | 9 | 10 | class PositionalEncoding(pl.LightningModule): 11 | def __init__(self, d_model, dropout=0.1, max_len=4): 12 | super(PositionalEncoding, self).__init__() 13 | self.dropout = nn.Dropout(p=dropout) 14 | pe = torch.zeros(max_len, d_model) 15 | position = torch.arange(0, max_len).unsqueeze(1) 16 | div_term = torch.exp(torch.arange( 17 | 0, d_model, 2).float() * (-math.log(1000.0) / d_model)) 18 | pe[:, 0::2] = torch.sin(position * div_term) 19 | pe[:, 1::2] = torch.cos(position * div_term) 20 | pe = pe.unsqueeze(0).transpose(0, 1) 21 | self.register_buffer('pe', pe) 22 | 23 | def forward(self, x): 24 | x = x + self.pe[:x.size(0), :] 25 | return self.dropout(x) 26 | 27 | 28 | class SimpleTransformer(pl.LightningModule): 29 | def __init__(self, **kwargs): 30 | super(SimpleTransformer, self).__init__() 31 | 32 | self.save_hyperparameters() 33 | if self.hparams.cls: 34 | self.hparams.seq_len += 1 35 | self.criterion = nn.BCEWithLogitsLoss() 36 | self.position_encoder = PositionalEncoding(2048, self.hparams.dropout, 37 | max_len=self.hparams.seq_len) 38 | 39 | self.encoder_layers0 = TransformerEncoderLayer( 40 | self.hparams.input_dimension, 
self.hparams.nhead, self.hparams.nhid, self.hparams.dropout) 41 | self.transformer_encoder0 = TransformerEncoder( 42 | self.encoder_layers0, self.hparams.nlayers) 43 | 44 | self.encoder_layers1 = TransformerEncoderLayer( 45 | self.hparams.input_dimension, self.hparams.nhead, self.hparams.nhid, self.hparams.dropout) 46 | self.transformer_encoder1 = TransformerEncoder( 47 | self.encoder_layers1, self.hparams.nlayers) 48 | 49 | self.norm = nn.LayerNorm(2048) 50 | self.running_labels = [] 51 | self.running_logits = [] 52 | self.cls = nn.Parameter(torch.rand( 53 | 1, self.hparams.batch_size, 2048)) 54 | self.mlp_head = nn.Sequential(nn.LayerNorm(2048), nn.Linear(2048, 15)) 55 | self.mlp_encoder = nn.Sequential( 56 | nn.LayerNorm(2048), nn.Linear(2048, 1024)) 57 | 58 | def configure_optimizers(self): 59 | optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.learning_rate, 60 | momentum=self.hparams.momentum, weight_decay=self.hparams.weight_decay) 61 | 62 | # optimizer = torch.optim.AdamW(self.parameters( 63 | # ), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) 64 | return optimizer 65 | 66 | def forward(self, src): 67 | src = self.expert_encoder(src) 68 | src = src * math.sqrt(self.hparams.input_dimension//2) 69 | src = self.position_encoder(src) 70 | # src = self.norm(src) 71 | output = self.transformer_encoder(src) 72 | return output 73 | 74 | def add_pos_cls(self, data): 75 | # input is expert, batch, sequence, dimension 76 | data = rearrange(data, 'b s d -> s b d') 77 | data = torch.cat((self.cls, data)) 78 | data = self.position_encoder(data) 79 | data = rearrange(data, 's b d -> b s d') 80 | data = self.norm(data) 81 | data = rearrange(data, 'b s d -> s b d') 82 | return data 83 | 84 | def ptn_shared(self, data): 85 | # data input (BATCH, SEQ, EXPERTS, DIM) 86 | # experts batch sequence dimension 87 | expert_array = [] 88 | data = rearrange(data, 'b s e d -> e b s d') 89 | for expert in data: 90 | # experts sequence batch dimension 91 | e = self.add_pos_cls(expert) # s b d (s > seq len) 92 | e = self(e) 93 | e = rearrange(e, 's b d -> b s d') 94 | e = e[:, 0] 95 | print("e", e.shape) 96 | expert_array.append(e) 97 | expert_array = torch.stack(expert_array) # elen b d 98 | expert_array = rearrange(expert_array, 's b d -> b s d') 99 | expert_array = self.add_pos_cls(expert_array) 100 | ptn_out = self(expert_array) 101 | ptn_out = rearrange(ptn_out, 's b d -> b s d') 102 | ptn_out = ptn_out[:, 0] 103 | ptn_out = self.mlp_head(ptn_out) 104 | return ptn_out 105 | 106 | def ptn(self, data): 107 | # data input (BATCH, SEQ, EXPERTS, DIM) 108 | # experts batch sequence dimension 109 | expert_array = [] 110 | data = rearrange(data, 'b s e d -> e b s d') 111 | for i, expert in enumerate(data): 112 | # experts sequence batch dimension 113 | e = self.add_pos_cls(expert) # s b d (s > seq len) 114 | print("e", e.shape) 115 | if i == 0: 116 | e = self.transformer_encoder0(e) 117 | elif i == 1: 118 | e = self.transformer_encoder1(e) 119 | print("e1", e.shape) 120 | e = rearrange(e, 's b d -> b s d') 121 | 122 | print("e2", e.shape) 123 | e = e[:, 0, :] 124 | print("e3", e.shape) 125 | expert_array.append(e) 126 | 127 | ptn_out = torch.stack(expert_array) # elen b d 128 | ptn_out = rearrange(ptn_out, "e b d -> b e d") 129 | print("ptn", ptn_out.shape) 130 | ptn_out = torch.sum(ptn_out, dim=1) 131 | print("ptn out", ptn_out.shape) 132 | ptn_out = self.mlp_head(ptn_out) 133 | return ptn_out 134 | 135 | def training_step(self, batch, batch_idx): 136 | 137 | data = 
batch["experts"] 138 | target = batch["label"] 139 | 140 | data = self.shared_step(data) 141 | #target = self.format_target(target) 142 | loss = self.criterion(data, target) 143 | self.log("train/loss", loss, on_step=True, on_epoch=True) 144 | return loss 145 | 146 | def validation_step(self, batch, batch_idx): 147 | data = batch["experts"] 148 | target = batch["label"] 149 | data = self.shared_step(data) 150 | #target = self.format_target(target) 151 | loss = self.criterion(data, target) 152 | self.log("val/loss", loss, on_step=True, on_epoch=True) 153 | target = target.int() 154 | sig_data = F.sigmoid(data) 155 | self.running_logits.append(sig_data) 156 | 157 | self.running_labels.append(target) 158 | self.running_logits.append(data) 159 | self.log("val/loss", loss, on_step=False, on_epoch=True) 160 | return loss 161 | 162 | def shared_step(self, data): 163 | if self.hparams.model == "ptn": 164 | data = self.ptn(data) 165 | return data 166 | if self.hparams.model == "ptn_shared": 167 | data = self.ptn(data) 168 | return data 169 | 170 | 171 | def format_target(self, target): 172 | target = torch.cat(target, dim=0) 173 | target = target.squeeze() 174 | return target 175 | 176 | -------------------------------------------------------------------------------- /src/models/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | from einops import rearrange, repeat 5 | from einops.layers.torch import Rearrange 6 | import numpy as np 7 | 8 | class PreNorm(nn.Module): 9 | def __init__(self, dim, fn): 10 | super().__init__() 11 | self.norm = nn.LayerNorm(dim) 12 | self.fn = fn 13 | def forward(self, x, **kwargs): 14 | return self.fn(self.norm(x), **kwargs) 15 | 16 | 17 | class FeedForward(nn.Module): 18 | def __init__(self, dim, hidden_dim, dropout = 0.): 19 | super().__init__() 20 | self.net = nn.Sequential( 21 | nn.Linear(dim, hidden_dim), 22 | nn.GELU(), 23 | nn.Dropout(dropout), 24 | nn.Linear(hidden_dim, dim), 25 | nn.Dropout(dropout) 26 | ) 27 | def forward(self, x): 28 | return self.net(x) 29 | 30 | class Attention(nn.Module): 31 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 32 | super().__init__() 33 | inner_dim = dim_head * heads 34 | project_out = not (heads == 1 and dim_head == dim) 35 | 36 | self.heads = heads 37 | self.scale = dim_head ** -0.5 38 | 39 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 40 | 41 | self.to_out = nn.Sequential( 42 | nn.Linear(inner_dim, dim), 43 | nn.Dropout(dropout) 44 | ) if project_out else nn.Identity() 45 | 46 | def forward(self, x): 47 | b, n, _, h = *x.shape, self.heads 48 | qkv = self.to_qkv(x).chunk(3, dim = -1) 49 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 50 | 51 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 52 | 53 | attn = dots.softmax(dim=-1) 54 | 55 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 56 | out = rearrange(out, 'b h n d -> b n (h d)') 57 | out = self.to_out(out) 58 | return out 59 | 60 | class Transformer(nn.Module): 61 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 62 | super().__init__() 63 | self.layers = nn.ModuleList([]) 64 | self.norm = nn.LayerNorm(dim) 65 | for _ in range(depth): 66 | self.layers.append(nn.ModuleList([ 67 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 68 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 69 | ])) 70 | 71 
| def forward(self, x): 72 | for attn, ff in self.layers: 73 | x = attn(x) + x 74 | x = ff(x) + x 75 | return self.norm(x) 76 | 77 | 78 | 79 | class ViViT(nn.Module): 80 | def __init__(self, image_size, patch_size, num_classes, num_frames, dim = 192, depth = 4, heads = 3, pool = 'cls', in_channels = 3, dim_head = 64, dropout = 0., 81 | emb_dropout = 0., scale_dim = 4, ): 82 | super().__init__() 83 | 84 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 85 | 86 | assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 87 | num_patches = (image_size // patch_size) ** 2 88 | patch_dim = in_channels * patch_size ** 2 89 | self.to_patch_embedding = nn.Sequential( 90 | Rearrange('b t c (h p1) (w p2) -> b t (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size), 91 | nn.Linear(patch_dim, dim), 92 | ) 93 | 94 | self.pos_embedding = nn.Parameter(torch.randn(1, num_frames, num_patches + 1, dim)) # 1, num_frames, dim) 95 | self.space_token = nn.Parameter(torch.randn(1, 1, dim)) # 96 | self.space_transformer = Transformer(dim, depth, heads, dim_head, dim*scale_dim, dropout) 97 | 98 | self.temporal_token = nn.Parameter(torch.randn(1, 1, dim)) 99 | self.temporal_transformer = Transformer(dim, depth, heads, dim_head, dim*scale_dim, dropout) 100 | 101 | self.dropout = nn.Dropout(emb_dropout) 102 | self.pool = pool 103 | 104 | self.mlp_head = nn.Sequential( 105 | nn.LayerNorm(dim), 106 | nn.Linear(dim, num_classes) 107 | ) 108 | 109 | def forward(self, x): 110 | x = self.to_patch_embedding(x) 111 | b, t, n, _ = x.shape # batch, sequence, frames, height, width 112 | 113 | cls_space_tokens = repeat(self.space_token, '() n d -> b t n d', b = b, t=t) # repeat space token over batch and sequence 114 | x = torch.cat((cls_space_tokens, x), dim=2) 115 | x += self.pos_embedding[:, :, :(n + 1)] 116 | x = self.dropout(x) 117 | 118 | x = rearrange(x, 'b t n d -> (b t) n d') 119 | x = self.space_transformer(x) 120 | x = rearrange(x[:, 0], '(b t) ... 
-> b t ...', b=b)
121 |
122 | cls_temporal_tokens = repeat(self.temporal_token, '() n d -> b n d', b=b)
123 | x = torch.cat((cls_temporal_tokens, x), dim=1)
124 |
125 | x = self.temporal_transformer(x)
126 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
127 |
128 | return self.mlp_head(x)
129 |
130 |
131 |
132 |
133 | if __name__ == "__main__":
134 |
135 | img = torch.ones([1, 16, 3, 224, 224]).cuda()
136 |
137 | model = ViViT(224, 16, 100, 16).cuda()
138 | parameters = filter(lambda p: p.requires_grad, model.parameters())
139 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
140 | print('Trainable Parameters: %.3fM' % parameters)
141 |
142 | out = model(img)
143 |
144 | print("Shape of out :", out.shape) # [B, num_classes]
145 |
146 |
--------------------------------------------------------------------------------
/src/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ed-fish/data-efficient-video-transformers/0a7d1b40563e244df14f4f33376cad413b2ba558/src/test.png
--------------------------------------------------------------------------------
/src/tests/test_dataloaders.py:
--------------------------------------------------------------------------------
1 | from dataloaders.mmx.MMX_Frame_dl import MMXFrameDataset, MMXFrameDatamodule
2 | data_path = "data_processing/temporal/mmx_train_temporal.pkl"
3 |
4 | def dataloader_test(path):
5 | # stub: constructs the frame dataset from the temporal pickle (constructor arguments assumed)
6 | return MMXFrameDataset(path)
--------------------------------------------------------------------------------
/src/tests/test_tensors.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ed-fish/data-efficient-video-transformers/0a7d1b40563e244df14f4f33376cad413b2ba558/src/tests/test_tensors.py
--------------------------------------------------------------------------------
/src/tests/test_transforms.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import sys
3 |
4 | sys.path.insert(1, '/home/ed/PhD/mmodal-moments-in-time')
5 |
6 | from transforms.spatio_cut import SpatioCut
7 | from transforms.img_transforms import ImgTransform
8 |
9 |
10 |
11 | test_vid = "/home/ed/PhD/mmodal-moments-in-time/input/juggling/juggling.mp4"
12 |
13 | class TestRandCrop(unittest.TestCase):
14 | def test_cutvid(self):
15 |
16 | sp = SpatioCut()
17 | output = sp.cut_vid(test_vid, 16)
18 | self.assertEqual(len(output), 3)
19 | self.assertEqual(len(output[0]), 16)
20 | self.assertEqual(len(output[1]), 16)
21 | self.assertEqual(len(output[2]), 16)
22 |
23 |
24 | if __name__ == '__main__':
25 | unittest.main()
26 |
--------------------------------------------------------------------------------