├── data ├── .gitkeep └── assets │ └── teaser.png ├── configs ├── .gitkeep ├── something-else-detections.yaml ├── something-detections.yaml ├── something-video.yaml └── something-else-video.yaml ├── experiments └── .gitkeep ├── .gitignore ├── .flake8 ├── LICENSE ├── src ├── modelling │ ├── losses.py │ ├── dataset_layout.py │ ├── dataset_audio.py │ ├── datasets.py │ ├── models.py │ ├── distiller.py │ ├── resnets3d.py │ ├── dataset_proto.py │ ├── dataset_video.py │ ├── hand_models.py │ └── swin.py ├── utils │ ├── train_utils.py │ ├── calibration.py │ ├── setup.py │ ├── samplers.py │ ├── evaluation.py │ └── data_utils.py ├── pack_video_frames_to_hdf5.py ├── inference.py ├── train.py └── patient_distill.py └── README.md /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | data/ 3 | notebooks/ 4 | __pycache__ 5 | logs/ 6 | .idea/ 7 | experiments/ 8 | .DS_Store -------------------------------------------------------------------------------- /data/assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gorjanradevski/multimodal-distillation/HEAD/data/assets/teaser.png -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203, E266, E501, W503 3 | max-line-length = 88 4 | max-complexity = 18 5 | select = B,C,E,F,W,T4,B9 6 | exclude = notebooks/ -------------------------------------------------------------------------------- /configs/something-else-detections.yaml: -------------------------------------------------------------------------------- 1 | ACTION_WEIGHTS: 2 | ACTION: 1.0 3 | LOG_TO_FILE: True 4 | BATCH_SIZE: 32 5 | DATASET_TYPE: layout 6 | EXPERIMENT_PATH: experiments/something-else-detections-stlt 7 | LABELS_PATH: data/comp_split_detect/something-something-v2-labels.json 8 | LOG_TO_FILE: False 9 | MODEL_NAME: stlt 10 | TOTAL_ACTIONS: 11 | ACTION: 174 12 | TRAIN_DATASET_NAME: something-something 13 | TRAIN_DATASET_PATH: data/something-something/something_else_detections/train_dataset.json 14 | VAL_DATASET_NAME: something-something 15 | VAL_DATASET_PATH: data/something-something/something_else_detections/val_dataset.json 16 | BATCH_SIZE: 128 17 | FRAME_UNIFORM: True 18 | NUM_FRAMES: 16 19 | WEIGHT_DECAY: 0.02 20 | LEARNING_RATE: 0.0001 21 | EPOCHS: 30 22 | WARMUP_EPOCHS: 3 -------------------------------------------------------------------------------- /configs/something-detections.yaml: -------------------------------------------------------------------------------- 1 | ACTION_WEIGHTS: 2 | ACTION: 1.0 3 | LOG_TO_FILE: True 4 | BATCH_SIZE: 32 5 | DATASET_TYPE: layout 6 | EXPERIMENT_PATH: experiments/something-something-detections/ 7 | LABELS_PATH: data/something_something_detections/something-something-v2-labels.json 8 | LOG_TO_FILE: False 9 | MODEL_NAME: stlt 10 | 
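# Note: these YAML files are merged over the defaults declared in src/utils/setup.py
# (every key must already exist there), e.g. as a sketch:
#   from utils.setup import get_cfg_defaults
#   cfg = get_cfg_defaults(); cfg.merge_from_file("configs/something-detections.yaml")
# LOG_TO_FILE and BATCH_SIZE are listed twice in the two detection configs; PyYAML keeps
# the later occurrence, so the effective values are LOG_TO_FILE: False and BATCH_SIZE: 128.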
TOTAL_ACTIONS: 11 | ACTION: 174 12 | TRAIN_DATASET_NAME: something-something 13 | VAL_DATASET_NAME: something-something 14 | TRAIN_DATASET_PATH: data/something-something/something_something_detections/train_dataset.json 15 | VAL_DATASET_PATH: data/something-something/something_something_detections/val_dataset.json 16 | BATCH_SIZE: 128 17 | FRAME_UNIFORM: True 18 | NUM_FRAMES: 16 19 | WEIGHT_DECAY: 0.02 20 | LEARNING_RATE: 0.0001 21 | EPOCHS: 30 22 | WARMUP_EPOCHS: 3 -------------------------------------------------------------------------------- /configs/something-video.yaml: -------------------------------------------------------------------------------- 1 | TRAIN_DATASET_PATH: data/something-something/something-something-v2-train.json 2 | VAL_DATASET_PATH: data/something-something/something-something-v2-validation.json 3 | TRAIN_DATASET_NAME: something-something 4 | VAL_DATASET_NAME: something-something 5 | LABELS_PATH: data/something-something/something-something-v2-labels.json 6 | VIDEOS_PATH: data/something-something/dataset.hdf5 7 | EXPERIMENT_PATH: experiments/something-swin 8 | DATASET_TYPE: video 9 | TOTAL_ACTIONS: 10 | ACTION: 174 11 | ACTION_WEIGHTS: 12 | ACTION: 1.0 13 | AUGMENTATIONS: 14 | VIDEO: ["VideoResize", "VideoRandomCrop", "VideoColorJitter"] 15 | NUM_WORKERS: 5 16 | MODEL_NAME: swin 17 | BACKBONE_MODEL_PATH: data/pretrained-backbones/swin_tiny_patch244_window877_kinetics400_1k.pth 18 | BATCH_SIZE: 16 19 | FRAME_UNIFORM: True 20 | NUM_FRAMES: 16 21 | WEIGHT_DECAY: 0.02 # As per SWIN paper (Section 4.1) 22 | LEARNING_RATE: 0.0001 23 | EPOCHS: 30 24 | WARMUP_EPOCHS: 3 -------------------------------------------------------------------------------- /configs/something-else-video.yaml: -------------------------------------------------------------------------------- 1 | TRAIN_DATASET_PATH: data/something-something/something_else_detections/train_dataset.json 2 | VAL_DATASET_PATH: data/something-something/something_else_detections/val_dataset.json 3 | TRAIN_DATASET_NAME: something-something 4 | VAL_DATASET_NAME: something-something 5 | LABELS_PATH: data/something-something/something_else_detections/something-something-v2-labels.json 6 | VIDEOS_PATH: data/something-something/dataset.hdf5 7 | EXPERIMENT_PATH: experiments/something-else-swin 8 | DATASET_TYPE: video 9 | TOTAL_ACTIONS: 10 | ACTION: 174 11 | ACTION_WEIGHTS: 12 | ACTION: 1.0 13 | AUGMENTATIONS: 14 | VIDEO: ["VideoResize", "VideoRandomCrop", "VideoColorJitter"] 15 | NUM_WORKERS: 5 16 | MODEL_NAME: swin 17 | BACKBONE_MODEL_PATH: data/pretrained-backbones/swin_tiny_patch244_window877_kinetics400_1k.pth 18 | BATCH_SIZE: 16 19 | FRAME_UNIFORM: True 20 | NUM_FRAMES: 16 21 | WEIGHT_DECAY: 0.02 # As per SWIN paper (Section 4.1) 22 | LEARNING_RATE: 0.0001 23 | EPOCHS: 30 24 | WARMUP_EPOCHS: 3 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Gorjan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission 
notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/modelling/losses.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from yacs.config import CfgNode 3 | 4 | 5 | class LossesModule: 6 | def __init__(self, cfg: CfgNode): 7 | self.cfg = cfg 8 | self.criterions = { 9 | "charades-ego": nn.BCEWithLogitsLoss(), 10 | "something-something": nn.CrossEntropyLoss(), 11 | "egtea-gaze": nn.CrossEntropyLoss(), 12 | "EPIC-KITCHENS": nn.CrossEntropyLoss(), 13 | } 14 | self.action_names = [ 15 | action 16 | for action in self.cfg.ACTION_WEIGHTS.keys() 17 | if self.cfg.ACTION_WEIGHTS[action] 18 | ] 19 | 20 | def __call__(self, model_output, batch): 21 | total_loss = 0 22 | # Aggregate losses 23 | for action_name in self.action_names: 24 | loss = ( 25 | self.criterions[self.cfg.TRAIN_DATASET_NAME]( 26 | model_output[action_name], 27 | batch["labels"][action_name], 28 | ) 29 | * self.cfg.ACTION_WEIGHTS[action_name] 30 | ) 31 | total_loss += loss 32 | 33 | return total_loss / len(self.action_names) 34 | -------------------------------------------------------------------------------- /src/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | from torch import optim 5 | 6 | 7 | def get_linear_schedule_with_warmup( 8 | optimizer: optim.Optimizer, num_warmup_steps: int, num_training_steps: int 9 | ): 10 | # https://huggingface.co/transformers/_modules/transformers/optimization.html#get_linear_schedule_with_warmup 11 | def lr_lambda(current_step: int): 12 | if current_step < num_warmup_steps: 13 | return float(current_step) / float(max(1, num_warmup_steps)) 14 | return max( 15 | 0.0, 16 | float(num_training_steps - current_step) 17 | / float(max(1, num_training_steps - num_warmup_steps)), 18 | ) 19 | 20 | return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 21 | 22 | 23 | def move_batch_to_device(batch, device): 24 | tmp_batch = {} 25 | for k, v in batch.items(): 26 | if isinstance(v, torch.Tensor): 27 | tmp_batch[k] = v.to(device) 28 | elif isinstance(v, dict): 29 | tmp_batch[k] = {} 30 | for inside_k, inside_v in v.items(): 31 | tmp_batch[k][inside_k] = inside_v.to(device) 32 | else: 33 | tmp_batch[k] = v 34 | 35 | return tmp_batch 36 | 37 | 38 | class EpochHandler: 39 | def __init__(self, epoch: int = 0): 40 | self.epoch = epoch 41 | 42 | def save_state(self, path: str): 43 | json.dump({"epoch": self.epoch}, open(path, "w")) 44 | 45 | def set_epoch(self, epoch: int): 46 | self.epoch = epoch 47 | 48 | def load_state(self, path: str): 49 | checkpoint = json.load(open(path)) 50 | self.epoch = checkpoint["epoch"] 51 | -------------------------------------------------------------------------------- /src/pack_video_frames_to_hdf5.py: -------------------------------------------------------------------------------- 1 | # This script packs a dataset of 
video frames into a HDF5 file. The script assumes that 2 | # the video frames are stored in an image format (.png, .jpg, etc.) and that the frames 3 | # for a single video are stored in a directory named after the video id. 4 | 5 | import argparse 6 | import os 7 | 8 | import h5py 9 | import numpy as np 10 | from natsort import natsorted 11 | from tqdm import tqdm 12 | 13 | 14 | def pack_video_frames_to_hdf5(args): 15 | with h5py.File(args.save_hdf5_path, "a") as hdf5_file: 16 | # Iterate over all video frames 17 | for video_id in tqdm(natsorted(os.listdir(args.all_video_frames_path))): 18 | video_frames_path = os.path.join(args.all_video_frames_path, video_id) 19 | # Iterate over all frames in a video 20 | frames = [] 21 | for frame_name in os.listdir(video_frames_path): 22 | frame_path = os.path.join(video_frames_path, frame_name) 23 | # Load frames 24 | with open(frame_path, "rb") as img_f: 25 | binary_data = img_f.read() 26 | # Covert to numpy as aggregate 27 | frames.append(np.asarray(binary_data)) 28 | frames = np.concatenate(frames, axis=0) 29 | hdf5_file.create_dataset(video_id, data=frames) 30 | 31 | 32 | def main(): 33 | parser = argparse.ArgumentParser(description="Pack video frames in HDF5.") 34 | parser.add_argument( 35 | "--all_video_frames_path", 36 | type=str, 37 | default="data/extracted_videos", 38 | help="From where to load the video frames.", 39 | ) 40 | parser.add_argument( 41 | "--save_hdf5_path", 42 | type=str, 43 | default="data/dataset.hdf5", 44 | help="Where to save the HDF5 file.", 45 | ) 46 | args = parser.parse_args() 47 | pack_video_frames_to_hdf5(args) 48 | 49 | 50 | if __name__ == "__main__": 51 | main() 52 | -------------------------------------------------------------------------------- /src/utils/calibration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from yacs.config import CfgNode 5 | 6 | 7 | class _ECELoss(nn.Module): 8 | """ 9 | https://github.com/gpleiss/temperature_scaling/blob/master/temperature_scaling.py#L78 10 | Calculates the Expected Calibration Error of a model. 11 | (This isn't necessary for temperature scaling, just a cool metric). 12 | The input to this loss is the logits of a model, NOT the softmax scores. 13 | This divides the confidence outputs into equally-sized interval bins. 14 | In each bin, we compute the confidence gap: 15 | bin_gap = | avg_confidence_in_bin - accuracy_in_bin | 16 | We then return a weighted average of the gaps, based on the number 17 | of samples in each bin 18 | See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht. 19 | "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI. 20 | 2015. 
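    Formally, with N samples split into confidence bins B_1, ..., B_M:
        ECE = sum_m (|B_m| / N) * |acc(B_m) - conf(B_m)|
    i.e. exactly the weighted sum accumulated in `forward` below.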
21 | """ 22 | 23 | def __init__(self, n_bins=15): 24 | """ 25 | n_bins (int): number of confidence interval bins 26 | """ 27 | super(_ECELoss, self).__init__() 28 | bin_boundaries = torch.linspace(0, 1, n_bins + 1) 29 | self.bin_lowers = bin_boundaries[:-1] 30 | self.bin_uppers = bin_boundaries[1:] 31 | 32 | def forward(self, logits, labels): 33 | softmaxes = F.softmax(logits, dim=1) 34 | confidences, predictions = torch.max(softmaxes, 1) 35 | accuracies = predictions.eq(labels) 36 | 37 | ece = torch.zeros(1, device=logits.device) 38 | for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers): 39 | # Calculated |confidence - accuracy| in each bin 40 | in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item()) 41 | prop_in_bin = in_bin.float().mean() 42 | if prop_in_bin.item() > 0: 43 | accuracy_in_bin = accuracies[in_bin].float().mean() 44 | avg_confidence_in_bin = confidences[in_bin].mean() 45 | ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin 46 | 47 | return ece 48 | 49 | 50 | class CalibrationEvaluator: 51 | def __init__(self, cfg: CfgNode): 52 | self.cfg = cfg 53 | # Get action names to have only one evaluator 54 | self.action_names = [ 55 | action 56 | for action in self.cfg.ACTION_WEIGHTS.keys() 57 | if self.cfg.ACTION_WEIGHTS[action] 58 | ] 59 | # Aggregators 60 | self.logits = {} 61 | self.labels = {} 62 | for action_name in self.action_names: 63 | self.logits[action_name] = [] 64 | self.labels[action_name] = [] 65 | # Evaluation criterion 66 | self.ece_criterion = _ECELoss() 67 | 68 | def process(self, model_output, labels): 69 | for action_name in self.action_names: 70 | self.logits[action_name].append(model_output[action_name]) 71 | self.labels[action_name].append(labels[action_name]) 72 | 73 | def evaluate(self): 74 | calibration_metrics = {} 75 | for action_name in self.action_names: 76 | logits = torch.cat(self.logits[action_name], dim=0) 77 | # Because of the number of temporal clips x spatial crops 78 | logits = logits.mean(1) 79 | labels = torch.cat(self.labels[action_name], dim=0) 80 | calibration_metrics[action_name] = ( 81 | self.ece_criterion(logits=logits, labels=labels).item() * 100.0 82 | ) 83 | 84 | return calibration_metrics 85 | -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from collections import OrderedDict 4 | from os.path import join as pjoin 5 | 6 | import torch 7 | from accelerate import Accelerator 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | from yacs.config import CfgNode 11 | 12 | from modelling.datasets import dataset_factory 13 | from modelling.models import model_factory 14 | from utils.calibration import CalibrationEvaluator 15 | from utils.data_utils import separator 16 | from utils.evaluation import evaluators_factory 17 | from utils.setup import get_cfg_defaults 18 | 19 | 20 | def unwrap_compiled_checkpoint(checkpoint): 21 | # If Pytorch 2.0, checkpoint is wrapped with _orig_mod, so we remove it 22 | new_checkpoint = OrderedDict() 23 | for key in checkpoint.keys(): 24 | new_key = key[10:] if key.startswith("_orig_mod") else key 25 | new_checkpoint[new_key] = checkpoint[key] 26 | 27 | return new_checkpoint 28 | 29 | 30 | @torch.no_grad() 31 | def inference(cfg: CfgNode): 32 | logging.basicConfig(level=logging.INFO) 33 | accelerator = Accelerator() 34 | # Prepare datasets 35 | if 
accelerator.is_main_process: 36 | logging.info("Preparing datasets...") 37 | # Prepare validation dataset 38 | if accelerator.is_main_process: 39 | logging.info(separator) 40 | logging.info(f"The config is:\n{cfg}") 41 | logging.info(separator) 42 | val_dataset = dataset_factory[cfg.DATASET_TYPE](cfg, train=False) 43 | if accelerator.is_main_process: 44 | logging.info(f"Validating on {len(val_dataset)}") 45 | # Prepare loaders 46 | val_loader = DataLoader( 47 | val_dataset, 48 | batch_size=cfg.BATCH_SIZE, 49 | num_workers=cfg.NUM_WORKERS, 50 | pin_memory=True if cfg.NUM_WORKERS else False, 51 | ) 52 | if accelerator.is_main_process: 53 | logging.info("Preparing model...") 54 | # Prepare model 55 | model = model_factory[cfg.MODEL_NAME](cfg) 56 | checkpoint = torch.load( 57 | pjoin(cfg.EXPERIMENT_PATH, "model_checkpoint.pt"), map_location="cpu" 58 | ) 59 | checkpoint = unwrap_compiled_checkpoint(checkpoint) 60 | model.load_state_dict(checkpoint) 61 | # Prepare evaluators 62 | evaluator = evaluators_factory[cfg.VAL_DATASET_NAME](len(val_dataset), cfg) 63 | calibration_evaluator = CalibrationEvaluator(cfg) 64 | if accelerator.is_main_process: 65 | logging.info("Starting inference...") 66 | # Accelerate 67 | model, val_loader = accelerator.prepare(model, val_loader) 68 | model.train(False) 69 | evaluator.reset() 70 | # Inference 71 | for batch in tqdm(val_loader, disable=not accelerator.is_main_process): 72 | # Obtain outputs: [b * n_clips, n_actions] 73 | model_output = model(batch) 74 | # Gather 75 | all_outputs = accelerator.gather(model_output) 76 | all_labels = accelerator.gather(batch["labels"]) 77 | # Reshape outputs and put on cpu 78 | for key in all_outputs.keys(): 79 | num_classes = all_outputs[key].size(-1) 80 | # Reshape 81 | all_outputs[key] = all_outputs[key].reshape( 82 | -1, cfg.NUM_TEST_CLIPS * cfg.NUM_TEST_CROPS, num_classes 83 | ) 84 | # Move on CPU 85 | all_outputs[key] = all_outputs[key].cpu() 86 | # Put labels on cpu 87 | for key in all_labels.keys(): 88 | all_labels[key] = all_labels[key].cpu() 89 | # Evaluate 90 | evaluator.process(all_outputs, all_labels) 91 | calibration_evaluator.process(all_outputs, all_labels) 92 | # Metrics 93 | if accelerator.is_main_process: 94 | metrics = evaluator.evaluate_verbose() 95 | for m in metrics.keys(): 96 | logging.info(f"{m}: {metrics[m]}") 97 | # Calibration Metrics 98 | calibration_metrics = calibration_evaluator.evaluate() 99 | logging.info("=============== Calibration Metrics (ECE) ===============") 100 | for m in calibration_metrics.keys(): 101 | logging.info(f"{m}: {calibration_metrics[m]}") 102 | 103 | 104 | def main(): 105 | parser = argparse.ArgumentParser(description="Inference.") 106 | parser.add_argument( 107 | "--experiment_path", required=True, help="Path to the experiment." 
108 | ) 109 | parser.add_argument("--opts", nargs=argparse.REMAINDER) 110 | args = parser.parse_args() 111 | cfg = get_cfg_defaults() 112 | cfg.merge_from_file(pjoin(args.experiment_path, "config.yaml")) 113 | if args.opts: 114 | cfg.merge_from_list(args.opts) 115 | # Set backbone_model_path to None as not important 116 | cfg.BACKBONE_MODEL_PATH = None 117 | # Freeze the config 118 | cfg.freeze() 119 | inference(cfg) 120 | 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /src/modelling/dataset_layout.py: -------------------------------------------------------------------------------- 1 | import ast 2 | 3 | import h5py 4 | import numpy as np 5 | import torch 6 | from yacs.config import CfgNode 7 | 8 | from modelling.dataset_proto import ProtoDataset 9 | from utils.data_utils import fix_box 10 | from utils.samplers import get_sampler 11 | 12 | 13 | class LayoutDataset(ProtoDataset): 14 | def __init__(self, cfg: CfgNode, train: bool = False): 15 | self.cfg = cfg 16 | self.train = train 17 | self.dataset_name = ( 18 | cfg.TRAIN_DATASET_NAME if self.train else cfg.VAL_DATASET_NAME 19 | ) 20 | self.dataset_path = ( 21 | cfg.TRAIN_DATASET_PATH if self.train else cfg.VAL_DATASET_PATH 22 | ) 23 | self.create_dataset() 24 | self.cat2index = {"pad": 0, "object": 1, "hand": 2} 25 | self.max_objects = 14 26 | self.sampler = get_sampler[self.cfg.MODEL_NAME](cfg=cfg, train=train) 27 | 28 | def open_videos(self): 29 | raise AttributeError("Layout dataset does not use videos.") 30 | 31 | def open_resource(self): 32 | assert self.cfg.DETECTIONS_PATH, "Path to detections must be provided!" 33 | self.resource = h5py.File( 34 | self.cfg.DETECTIONS_PATH, "r", libver="latest", swmr=True 35 | ) 36 | 37 | def get_detections(self, dataset_element, index: int, **kwargs): 38 | boxes, class_labels, scores, states, sides = [], [], [], [], [] 39 | if self.dataset_name == "EPIC-KITCHENS": 40 | if "resource" in kwargs: 41 | resource = kwargs.pop("resource") 42 | else: 43 | resource = self.resource 44 | video_name = dataset_element["id"] 45 | detections = ast.literal_eval( 46 | str(np.array(resource[video_name][str(index)]))[2:-1] 47 | ) 48 | for e in detections: 49 | boxes.append(e["box"]) 50 | class_labels.append(self.cat2index[e["category"]]) 51 | scores.append(e["score"]) 52 | state = e["state"] + 1 if "state" in e else 0 53 | states.append(state) 54 | side = e["side"] + 1 if "side" in e else 0 55 | sides.append(side) 56 | elif self.dataset_name == "something-something": 57 | w, h = dataset_element["size"] 58 | frame_objects = dataset_element["frames"][index] 59 | for e in frame_objects: 60 | box = fix_box([e["x1"], e["y1"], e["x2"], e["y2"]], video_size=(h, w)) 61 | boxes.append(box) 62 | class_label = ( 63 | self.cat2index["hand"] 64 | if "hand" in e["category"] 65 | else self.cat2index["object"] 66 | ) 67 | class_labels.append(class_label) 68 | scores.append(e["score"]) 69 | sides = [0 for _ in range(len(class_labels))] 70 | states = [0 for _ in range(len(class_labels))] 71 | else: 72 | raise ValueError(f"{self.dataset_name} not available!") 73 | 74 | return boxes, class_labels, scores, sides, states 75 | 76 | def get_video_length(self, sample): # Reimplemented 77 | return len(sample["frames"]) 78 | 79 | def __getitem__(self, idx: int): 80 | if not hasattr(self, "resource") and self.dataset_name == "EPIC-KITCHENS": 81 | self.open_resource() 82 | output = { 83 | "bboxes": [], 84 | "class_labels": [], 85 | "scores": [], 86 | 
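            # "sides"/"states" carry the per-detection "side" and "state" fields of the
            # EPIC-KITCHENS hand/object detections, shifted by +1 in get_detections so
            # that 0 can double as the padding value; for something-something they are
            # simply zero-filled.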
"sides": [], 87 | "states": [], 88 | "src_key_padding_mask_boxes": [], 89 | } 90 | if not hasattr(self, "indices"): 91 | indices = self.sampler( 92 | video_length=self.get_video_length(self.dataset[idx]) 93 | ) 94 | else: 95 | indices = self.indices 96 | output["indices"] = indices 97 | output["start_frame"] = 0 98 | 99 | for index in indices: 100 | bboxes, class_labels, scores, sides, states = self.get_detections( 101 | self.dataset[idx], index 102 | ) 103 | # Perform padding to max objects 104 | while len(bboxes) < self.max_objects: 105 | class_labels.append(self.cat2index["pad"]) 106 | bboxes.append([0, 0, 1e-9, 1e-9]) 107 | scores.append(0.0) 108 | sides.append(0) 109 | states.append(0) 110 | # Add boxes 111 | bboxes = torch.tensor(bboxes, dtype=torch.float32) 112 | if self.dataset_name == "something-something": 113 | w, h = self.dataset[idx]["size"] 114 | bboxes = bboxes / torch.tensor([w, h, w, h]) 115 | output["bboxes"].append(bboxes) 116 | # Add class labels 117 | output["class_labels"].append(torch.tensor(class_labels)) 118 | # Add scores 119 | output["scores"].append(torch.tensor(scores)) 120 | # Add sides 121 | output["sides"].append(torch.tensor(sides)) 122 | # Add states 123 | output["states"].append(torch.tensor(states)) 124 | # Generate mask 125 | output["src_key_padding_mask_boxes"].append( 126 | output["class_labels"][-1] == self.cat2index["pad"] 127 | ) 128 | # Convert to tensors 129 | output = { 130 | key: torch.stack(val, dim=0) 131 | if key not in ["indices", "start_frame"] 132 | else val 133 | for key, val in output.items() 134 | } 135 | output["labels"] = self.get_actions(self.dataset[idx]) 136 | 137 | return output 138 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multimodal Distillation for Egocentric Action Recognition 2 | 3 | This repository contains the implementation of the paper [Multimodal Distillation for Egocentric Action Recognition](https://arxiv.org/abs/2307.07483), published at ICCV 2023. 4 | 5 | ![Teaser](data/assets/teaser.png) 6 | 7 | ## Reproducing the virtual environment 8 | 9 | The main dependencies that you need to install to reproduce the virtual environment are [PyTorch](https://pytorch.org/), and: 10 | 11 | ```shell 12 | pip install accelerate tqdm h5py yacs timm einops natsort 13 | ``` 14 | 15 | ## Downloading the pre-trained Swin-T model 16 | 17 | Create a directory `./data/pretrained-backbones/` and download Swin-T from [here](https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_tiny_patch244_window877_kinetics400_1k.pth): 18 | 19 | ```bash 20 | wget https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_tiny_patch244_window877_kinetics400_1k.pth -O ./data/pretrained-backbones/ 21 | ``` 22 | 23 | ## Preparing the Epic-Kitchens and the Something-Something/Else datasets 24 | 25 | We store all data (video frames, optical flow frames, audios, etc.) is an efficient HDF5 file where each video represents a dataset within the HDF5 file, and the n-th element of the dataset contains the bytes for the n-th frame of the video. You can download the Something-Something and Something-Else datasets from [this link](https://filesender.belnet.be/?s=download&token=cd0a85df-66ff-4716-80fc-734c0bae85ce), and the Epic-Kitchens dataset from [this link](https://filesender.belnet.be/?s=download&token=b8427f41-e8d7-4615-94d1-d0ef3eb5bbf1). 
This includes all the modalities we use for each dataset. 26 | 27 | Please download and place the datasets inside `./data/` - `./data/something-something/` and `./data/EPIC-KITCHENS`. Otherwise, feel free to store the data wherever you see fit, just do not forget to modify the `config.yaml` files with the appropriate location. In this README.md, we assume that all data is placed inside `./data/`, and all experiments are placed inside `./experiments/`. 28 | 29 | ## Model ZOO 30 | 31 | | Dataset | Model Type | Model architecture | Training modalities | Download Link | 32 | |----------|----------------|----------------|---------------|----------------------| 33 | | Something-Something | Distilled student | Swin-T | RGB frames + Optical Flow + Object Detections | [Download](https://drive.google.com/drive/folders/1qAZTKYjt-D2Y95BlTsjw2Ny9YWMWGnQC?usp=sharing) | 34 | | Something-Else | Distilled student | Swin-T | RGB frames + Optical Flow + Object Detections | [Download](https://drive.google.com/drive/folders/1zK_dEGJP21xtgrZgc_gOPBPvWg-DKL2j?usp=sharing) | 35 | | Epic-Kitchens | Distilled student | Swin-T | RGB frames + Optical Flow + Audio | [Download](https://drive.google.com/drive/folders/1KUBiwGodTLqtuoRoxJVbpwBZ8uieJCtm?usp=sharing) | 36 | | Something-Something | Unimodal | Swin-T | RGB Frames | [Download](https://drive.google.com/drive/folders/1YIcO65zWdMW1Cm11JF392b7uE4M1-XDB?usp=sharing) | 37 | | Something-Something | Unimodal | Swin-T | Optical Flow | [Download](https://drive.google.com/drive/folders/1GVMrpGtkv6fC6FgpWmEMykBb6HHCeQEE?usp=sharing) | 38 | | Something-Something | Unimodal | STLT | Object Detections | [Download](https://drive.google.com/drive/folders/1RbRUEpYFE4AqTrIJXNtfFFezfqal86Cp?usp=sharing) | 39 | | Something-Else | Unimodal | Swin-T | RGB frames | [Download](https://drive.google.com/drive/folders/1jNO-OBb6rmA2Gl0MS5x-gZmkG7lN5ogM?usp=sharing) | 40 | | Something-Else | Unimodal | Swin-T | Optical Flow | [Download](https://drive.google.com/drive/folders/1GSt-ZbAoVvGI8JWIXMaOphQm4FMWlbLe?usp=sharing) | 41 | | Something-Else | Unimodal | STLT | Object Detections | [Download](https://drive.google.com/drive/folders/16hna0e1RnzcQ750FAD213tm5clD5e4X3?usp=sharing) | 42 | | Epic-Kitchens | Unimodal | Swin-T | RGB frames | [Download](https://drive.google.com/drive/folders/101kHwBAQTDbL8IpaODxZ8jG_aMurCioW?usp=sharing) | 43 | | Epic-Kitchens | Unimodal | Swin-T | Optical Flow | [Download](https://drive.google.com/drive/folders/1DBmfQo5-8AmRmrqt9ZmFL7B4EjoGLQsA?usp=sharing) | 44 | | Epic-Kitchens | Unimodal | Swin-T | Audio | [Download](https://drive.google.com/drive/folders/1yS9apZIUPHWUpQiXi0bCHXb5sfFsvY84?usp=sharing) | 45 | 46 | 47 | 48 | 49 | ## Inference on Epic-Kitchens 50 | 51 | 1. Download our Epic-Kitchens distilled model from the Model ZOO, and place it in `./experiments/`. 52 | 2. Run inference as: 53 | 54 | ```python 55 | python src/inference.py --experiment_path "experiments/epic-kitchens-swint-distill-flow-audio" --opts DATASET_TYPE "video" 56 | ``` 57 | 58 | ## Inference on Something-Something & Something-Else 59 | 60 | 1. Download our Something-Else distilled model or the Something-Something distilled model from the Model ZOO, and place it in `./experiments/`. 61 | 2. 
Run inference as: 62 | 63 | ```python 64 | python src/inference.py --experiment_path "experiments/something-swint-distill-layout-flow" --opts DATASET_TYPE "video" 65 | ``` 66 | 67 | for Something-Something, and 68 | 69 | ```python 70 | python src/inference.py --experiment_path "experiments/something-else-swint-distill-layout-flow" --opts DATASET_TYPE "video" 71 | ``` 72 | 73 | for Something-Else. 74 | 75 | ## Distilling from Multimodal Teachers 76 | 77 | To reproduce the experiments (i.e., using the identical hyperparameters, where only the random seed will vary): 78 | 79 | ```python 80 | python src/patient_distill.py --config "experiments/something-else-swint-distill-layout-flow/config.yaml" --opts EXPERIMENT_PATH "experiments/experiments/reproducing-the-something-else-experiment" 81 | ``` 82 | 83 | note that this assumes access to the datasets for all modalities (video, optical flow, audio, object detections), as well as the individual (unimodal) models which constitute the multimodal ensemble teacher. 84 | 85 | ## TODOs 86 | 87 | - [ ] Release Something-Something pretrained teachers for each modality. 88 | - [ ] Test the codebase. 89 | - [x] Structure the Model ZOO part of the codebase. 90 | 91 | ## Citation 92 | 93 | If you find our code useful for your own research, please use the following BibTeX entry: 94 | 95 | ```tex 96 | @inproceedings{radevski2023multimodal, 97 | title={Multimodal Distillation for Egocentric Action Recognition}, 98 | author={Radevski, Gorjan and Grujicic, Dusan and Blaschko, Matthew and Moens, Marie-Francine and Tuytelaars, Tinne}, 99 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 100 | pages={5213--5224}, 101 | year={2023} 102 | } 103 | ``` -------------------------------------------------------------------------------- /src/modelling/dataset_audio.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | import torch 4 | import torchaudio.functional as F 5 | import torchaudio.transforms as T 6 | from yacs.config import CfgNode 7 | 8 | from modelling.dataset_proto import ProtoDataset 9 | from utils.data_utils import ( 10 | compile_transforms, 11 | extract_audio_segments, 12 | get_audio_transforms, 13 | video2audio_indices, 14 | ) 15 | from utils.samplers import get_sampler 16 | 17 | 18 | class AudioDataset(ProtoDataset): 19 | def __init__(self, cfg: CfgNode, train: bool = False): 20 | self.cfg = cfg 21 | self.train = train 22 | self.dataset_type = "audio" 23 | self.dataset_name = ( 24 | cfg.TRAIN_DATASET_NAME if self.train else cfg.VAL_DATASET_NAME 25 | ) 26 | assert ( 27 | self.dataset_name == "EPIC-KITCHENS" 28 | ), "Audio only defined for 'EPIC-KITCHENS'!" 
29 | self.dataset_path = ( 30 | cfg.TRAIN_DATASET_PATH if self.train else cfg.VAL_DATASET_PATH 31 | ) 32 | self.create_dataset() 33 | # Define transform 34 | self.spectrogram = T.MelSpectrogram( 35 | n_fft=1024, 36 | win_length=160, 37 | hop_length=80, 38 | center=True, 39 | pad_mode="reflect", 40 | power=2.0, 41 | ) 42 | self.sampler = get_sampler[self.cfg.MODEL_NAME](cfg=cfg, train=train) 43 | 44 | def get_difference_seconds(self, start_timestamp, stop_timestamp): 45 | start_timestamp = datetime.strptime(start_timestamp, "%H:%M:%S.%f") 46 | stop_timestamp = datetime.strptime(stop_timestamp, "%H:%M:%S.%f") 47 | return (stop_timestamp - start_timestamp).total_seconds() 48 | 49 | def get_audio_length(self, sample): 50 | # Returns the video length in seconds 51 | start_timestamp = sample["start_timestamp"] 52 | stop_timestamp = sample["stop_timestamp"] 53 | return self.get_difference_seconds(start_timestamp, stop_timestamp) 54 | 55 | def get_timestamp_seconds(self, timestamp): 56 | # Returns the timestamp total number of seconds 57 | timestamp = datetime.strptime(timestamp, "%H:%M:%S.%f") 58 | reference_timestamp = datetime.strptime("00:00:00.00", "%H:%M:%S.%f") 59 | return (timestamp - reference_timestamp).total_seconds() 60 | 61 | def get_audio_spectrograms(self, **kwargs): 62 | indices = kwargs.pop("indices") 63 | video_id = kwargs.pop("video_id") 64 | start_timestamp = kwargs.pop("start_timestamp") 65 | stop_timestamp = kwargs.pop("stop_timestamp") 66 | 67 | # Fix for omnivore so we don't need to re-implement method 68 | if "resource" in kwargs: 69 | resource = kwargs.pop("resource") 70 | else: 71 | resource = self.resource 72 | 73 | audio_sample_rate = self.cfg.AUDIO_SAMPLE_RATE 74 | frame_offset = int( 75 | self.get_timestamp_seconds(start_timestamp) * audio_sample_rate 76 | ) 77 | num_frames = int( 78 | self.get_difference_seconds(start_timestamp, stop_timestamp) 79 | * audio_sample_rate 80 | ) 81 | # Extract audio 82 | audio = torch.tensor( 83 | resource[video_id][frame_offset : frame_offset + num_frames] 84 | ).unsqueeze(0) 85 | if self.cfg.AUDIO_RESAMPLE_RATE: 86 | audio = F.resample(audio, audio_sample_rate, self.cfg.AUDIO_RESAMPLE_RATE) 87 | audio_sample_rate = self.cfg.AUDIO_RESAMPLE_RATE 88 | 89 | # Extract audio segments 90 | audio_indices = video2audio_indices( 91 | indices, 92 | video_fps=self.cfg.VIDEO_FPS, 93 | audio_sample_rate=audio_sample_rate, 94 | ) 95 | segment_length = int(self.cfg.AUDIO_SEGMENT_LENGTH * audio_sample_rate) 96 | audio_segments = extract_audio_segments( 97 | audio, 98 | segment_length=segment_length, 99 | audio_indices=audio_indices, 100 | ) 101 | spectrogram = self.spectrogram( 102 | audio_segments 103 | ) # [Eval_Clips x Frames, Height (Freqs), Width (Timesteps)] 104 | return spectrogram 105 | 106 | def __getitem__(self, idx: int): 107 | output = {self.dataset_type: []} 108 | if not hasattr(self, "resource"): 109 | self.open_resource() 110 | output["id"] = self.dataset[idx]["id"] 111 | if not hasattr(self, "indices"): 112 | indices = self.sampler( 113 | video_length=self.get_video_length(self.dataset[idx]) 114 | ) 115 | else: 116 | indices = self.indices 117 | output["indices"] = indices 118 | output["start_frame"] = int( 119 | self.get_timestamp_seconds(self.dataset[idx]["start_timestamp"]) 120 | * self.cfg.VIDEO_FPS 121 | ) 122 | # Check for existing transforms 123 | if not hasattr(self, "existing_transforms"): 124 | existing_transforms = {} 125 | else: 126 | existing_transforms = self.existing_transforms 127 | # Obtain the spectograms 128 | 
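        # Spectrogram width is roughly segment_samples / hop_length (+1, since center=True):
        # e.g. a 1.116 s segment of a 16 kHz waveform -> 1.116 * 16000 / 80 ~ 224 columns,
        # presumably the "224 width" mentioned next to AUDIO_SEGMENT_LENGTH in utils/setup.py.
        # The single-channel output is repeated to 3 channels below so the same (RGB) video
        # backbone can be applied to audio unchanged.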
output[self.dataset_type] = self.get_audio_spectrograms( 129 | indices=indices, 130 | video_id=self.dataset[idx]["id"], 131 | start_timestamp=self.dataset[idx]["start_timestamp"], 132 | stop_timestamp=self.dataset[idx]["stop_timestamp"], 133 | ).unsqueeze(0) 134 | # Get augmentations 135 | audio_transforms_dict = get_audio_transforms( 136 | augmentations_list=self.cfg.AUGMENTATIONS.AUDIO, 137 | cfg=self.cfg, 138 | train=self.train, 139 | existing_transforms=existing_transforms, 140 | ) 141 | self.enforced_transforms = audio_transforms_dict 142 | audio_transforms = compile_transforms(audio_transforms_dict) 143 | output[self.dataset_type] = audio_transforms(output[self.dataset_type]) 144 | # [Eval_Clips x Frames, 1, Spatial size, Spatial size] 145 | output[self.dataset_type] = output[self.dataset_type].reshape( 146 | -1, 147 | self.cfg.NUM_FRAMES, 148 | 1, 149 | output[self.dataset_type].shape[-2], 150 | output[self.dataset_type].shape[-1], 151 | ) # [Eval_Clips, Frames, 1, Spatial size, Spatial size] 152 | output[self.dataset_type] = output[self.dataset_type].repeat(1, 1, 3, 1, 1) 153 | # [Eval_Clips, Frames, 3, Spatial size, Spatial size] 154 | # Obtain video labels 155 | output["labels"] = self.get_actions(self.dataset[idx]) 156 | return output 157 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from os.path import join as pjoin 4 | 5 | import torch 6 | from accelerate import Accelerator 7 | from torch import optim 8 | from torch.utils.data import DataLoader, Subset 9 | from tqdm import tqdm 10 | from yacs.config import CfgNode 11 | 12 | from modelling.datasets import dataset_factory 13 | from modelling.losses import LossesModule 14 | from modelling.models import model_factory 15 | from utils.data_utils import separator 16 | from utils.evaluation import evaluators_factory 17 | from utils.setup import train_setup 18 | from utils.train_utils import get_linear_schedule_with_warmup 19 | 20 | 21 | def train(cfg: CfgNode, accelerator: Accelerator): 22 | if cfg.LOG_TO_FILE: 23 | if accelerator.is_main_process: 24 | logging.basicConfig( 25 | level=logging.INFO, 26 | filename=pjoin(cfg.EXPERIMENT_PATH, "experiment_log.log"), 27 | filemode="w", 28 | ) 29 | else: 30 | logging.basicConfig(level=logging.INFO) 31 | if accelerator.is_main_process: 32 | logging.info(separator) 33 | logging.info(f"The config file is:\n {cfg}") 34 | logging.info(separator) 35 | # Prepare datasets 36 | if accelerator.is_main_process: 37 | logging.info("Preparing datasets...") 38 | # Prepare train dataset 39 | train_dataset = dataset_factory[cfg.DATASET_TYPE](cfg, train=True) 40 | # Prepare validation dataset 41 | val_dataset = dataset_factory[cfg.DATASET_TYPE](cfg, train=False) 42 | num_training_samples = len(train_dataset) 43 | if cfg.VAL_SUBSET: 44 | val_indices = random.sample(range(len(val_dataset)), cfg.VAL_SUBSET) 45 | val_dataset = Subset(val_dataset, val_indices) 46 | num_validation_samples = len(val_dataset) 47 | if accelerator.is_main_process: 48 | logging.info(f"Training on {num_training_samples}") 49 | logging.info(f"Validating on {num_validation_samples}") 50 | # Prepare loaders 51 | train_loader = DataLoader( 52 | train_dataset, 53 | batch_size=cfg.BATCH_SIZE, 54 | shuffle=True, 55 | num_workers=cfg.NUM_WORKERS, 56 | pin_memory=True if cfg.NUM_WORKERS else False, 57 | ) 58 | val_loader = DataLoader( 59 | val_dataset, 60 | 
batch_size=cfg.BATCH_SIZE, 61 | num_workers=cfg.NUM_WORKERS, 62 | pin_memory=True if cfg.NUM_WORKERS else False, 63 | ) 64 | if accelerator.is_main_process: 65 | logging.info("Preparing model...") 66 | # Prepare model 67 | model = model_factory[cfg.MODEL_NAME](cfg) 68 | # If PyTorch 2.0, compile the model 69 | if hasattr(torch, "compile"): 70 | if accelerator.is_main_process: 71 | logging.info("Compile model...") 72 | model = torch.compile(model) 73 | # Optimizer, scheduler and similar... 74 | optimizer = optim.AdamW( 75 | model.parameters(), 76 | lr=cfg.LEARNING_RATE, 77 | weight_decay=cfg.WEIGHT_DECAY, 78 | ) 79 | num_batches = num_training_samples // cfg.BATCH_SIZE 80 | scheduler = get_linear_schedule_with_warmup( 81 | optimizer, 82 | num_warmup_steps=cfg.WARMUP_EPOCHS * num_batches, 83 | num_training_steps=cfg.EPOCHS * num_batches, 84 | ) 85 | evaluator = evaluators_factory[cfg.VAL_DATASET_NAME](num_validation_samples, cfg) 86 | # Accelerate 87 | model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare( 88 | model, optimizer, train_loader, val_loader, scheduler 89 | ) 90 | # Create loss 91 | criterion = LossesModule(cfg) 92 | if accelerator.is_main_process: 93 | logging.info("Starting training...") 94 | for epoch in range(cfg.EPOCHS): 95 | # Training loop 96 | model.train(True) 97 | with tqdm( 98 | total=len(train_loader), disable=not accelerator.is_main_process 99 | ) as pbar: 100 | for batch in train_loader: 101 | # Remove past gradients 102 | optimizer.zero_grad() 103 | # Obtain outputs: [b * n_clips, n_actions] 104 | model_output = model(batch) 105 | # Measure loss and update weights 106 | loss = criterion(model_output, batch) 107 | accelerator.backward(loss) 108 | accelerator.clip_grad_norm_(model.parameters(), cfg.CLIP_VAL) 109 | optimizer.step() 110 | # Update the scheduler 111 | scheduler.step() 112 | # Update progress bar 113 | pbar.update(1) 114 | pbar.set_postfix({"Loss": loss.item()}) 115 | # Validation loop 116 | model.train(False) 117 | evaluator.reset() 118 | for batch in tqdm(val_loader, disable=not accelerator.is_main_process): 119 | with torch.no_grad(): 120 | # Obtain outputs: [b * n_clips, n_actions] 121 | model_output = model(batch) 122 | all_outputs = accelerator.gather(model_output) 123 | all_labels = accelerator.gather(batch["labels"]) 124 | # Reshape outputs and put on cpu 125 | for key in all_outputs.keys(): 126 | num_classes = all_outputs[key].size(-1) 127 | # Reshape 128 | all_outputs[key] = all_outputs[key].reshape( 129 | -1, cfg.NUM_TEST_CLIPS * cfg.NUM_TEST_CROPS, num_classes 130 | ) 131 | # Move on CPU 132 | all_outputs[key] = all_outputs[key].cpu() 133 | # Put labels on cpu 134 | for key in all_labels.keys(): 135 | all_labels[key] = all_labels[key].cpu() 136 | # Pass to evaluator 137 | evaluator.process(all_outputs, all_labels) 138 | # Evaluate & save model 139 | accelerator.wait_for_everyone() 140 | if accelerator.is_main_process: 141 | metrics = evaluator.evaluate() 142 | if evaluator.is_best(): 143 | logging.info(separator) 144 | logging.info(f"Found new best on epoch {epoch+1}!") 145 | logging.info(separator) 146 | unwrapped_model = accelerator.unwrap_model(model) 147 | accelerator.save( 148 | unwrapped_model.state_dict(), 149 | pjoin(cfg.EXPERIMENT_PATH, "model_checkpoint.pt"), 150 | ) 151 | for m in metrics.keys(): 152 | logging.info(f"{m}: {metrics[m]}") 153 | 154 | 155 | def main(): 156 | cfg, accelerator = train_setup("Trains an action recognition model.") 157 | train(cfg, accelerator) 158 | 159 | 160 | if __name__ == 
"__main__": 161 | main() 162 | -------------------------------------------------------------------------------- /src/utils/setup.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from accelerate import Accelerator, DistributedDataParallelKwargs 5 | from yacs.config import CfgNode as CN 6 | 7 | _C = CN() 8 | # Warm restarting training 9 | _C.WARM_RESTART = False 10 | # Whether to log to a file 11 | _C.LOG_TO_FILE = False 12 | # The path to the config 13 | _C.CONFIG_PATH = None 14 | # The name of the model to be used 15 | _C.MODEL_NAME = None 16 | # IF Resnet3D, then model depth too 17 | _C.RESNET3D_DEPTH = None 18 | # Path to the train dataset - usually a json file 19 | _C.TRAIN_DATASET_PATH = None 20 | # Path to the val dataset - usually a json file 21 | _C.VAL_DATASET_PATH = None 22 | # The name of the training dataset 23 | _C.TRAIN_DATASET_NAME = None 24 | # The name of the validation dataset 25 | _C.VAL_DATASET_NAME = None 26 | # Path to the labels 27 | _C.LABELS_PATH = None 28 | # Version of the dataset: Relevant for Epic Kitchens 29 | _C.DATASET_VERSION = None 30 | # Type of the dataset: Can be video or layout 31 | _C.DATASET_TYPE = None 32 | # Train subset 33 | _C.TRAIN_SUBSET = None 34 | # Validation subset 35 | _C.VAL_SUBSET = None 36 | # Whether the data during training is omnivore and model 37 | _C.OMNIVORE = None 38 | # Path to the videos 39 | _C.VIDEOS_PATH = None 40 | # Meant to modify the participants list (unseen list) 41 | _C.EPIC_PARTICIPANTS = None 42 | # Path to the flow 43 | _C.FLOW_PATH = None 44 | # Path to the depth 45 | _C.DEPTH_PATH = None 46 | # Path to the skeleton 47 | _C.SKELETON_PATH = None 48 | # Path to the segmentations 49 | _C.SEGMENTATION_PATH = None 50 | # Path to frame mapping dictionary (used for the Visor dataset in Epic-Kitchens) 51 | _C.SEGMENTATION_FRAME_MAPPING_PATH = None 52 | # Path to the detections (only used when Epic-Kitchens) 53 | _C.DETECTIONS_PATH = None 54 | # Number of frames to be sampled from the video 55 | _C.NUM_FRAMES = 8 56 | # Number of frames the student will use 57 | _C.STUDENT_SUBSAMPLE_RATE = None 58 | # Whether to use uniform sampling of frames 59 | _C.FRAME_UNIFORM = False 60 | # The batch size 61 | _C.BATCH_SIZE = 12 62 | # The learning rate 63 | _C.LEARNING_RATE = 0.0001 64 | # The weight decay 65 | _C.WEIGHT_DECAY = 0.0001 66 | # The gradient clipping value 67 | _C.CLIP_VAL = 5.0 68 | # The number of workers - threads in the dataloader 69 | _C.NUM_WORKERS = 0 70 | # The number of epochs for training 71 | _C.EPOCHS = 20 72 | # The number of warmup epochs for the warmup optimizer 73 | _C.WARMUP_EPOCHS = 2 74 | # Path to a pre-trained backbone 75 | _C.BACKBONE_MODEL_PATH = None 76 | # Path to existing checkpoint 77 | _C.CHECKPOINT_PATH = None 78 | # Device used for training: Cuda or Cpu 79 | _C.DEVICE = "cuda" 80 | # Where to save the experiment 81 | _C.EXPERIMENT_PATH = "experiments/default" 82 | # Whether to freeze the backbone during training 83 | _C.FREEZE_BACKBONE = False 84 | # Number of testing clips when evaluating 85 | _C.NUM_TEST_CLIPS = 1 86 | # How many crops during inference: Either 1 or 3 87 | _C.NUM_TEST_CROPS = 1 88 | # Stride for getting frames 89 | _C.STRIDE = 8 90 | # The augmentations per dataset 91 | _C.AUGMENTATIONS = CN() 92 | _C.AUGMENTATIONS.VIDEO = ["IdentityTransform"] 93 | _C.AUGMENTATIONS.FLOW = ["IdentityTransform"] 94 | _C.AUGMENTATIONS.DEPTH = ["IdentityTransform"] 95 | _C.AUGMENTATIONS.SKELETON = 
["IdentityTransform"] 96 | _C.AUGMENTATIONS.LAYOUT = ["IdentityTransform"] 97 | _C.AUGMENTATIONS.AUDIO = ["IdentityTransform"] 98 | _C.AUGMENTATIONS.SEGMENTATION = ["IdentityTransform"] 99 | # The number of actions, for each action type 100 | _C.TOTAL_ACTIONS = CN() 101 | _C.TOTAL_ACTIONS.ACTION = None 102 | _C.TOTAL_ACTIONS.NOUN = None 103 | _C.TOTAL_ACTIONS.VERB = None 104 | # The action weights, for each action type 105 | _C.ACTION_WEIGHTS = CN() 106 | _C.ACTION_WEIGHTS.ACTION = None 107 | _C.ACTION_WEIGHTS.NOUN = None 108 | _C.ACTION_WEIGHTS.VERB = None 109 | # The losses weights, ONLY when distillation 110 | _C.LOSS_WEIGHTS = CN() 111 | _C.LOSS_WEIGHTS.DISTILLATION = 1.0 112 | _C.LOSS_WEIGHTS.GROUND_TRUTH = 0.0 113 | # Path to the teacher experiment path 114 | _C.TEACHERS = CN() 115 | _C.TEACHERS.OBJ_TEACHER_EXPERIMENT_PATH = None 116 | _C.TEACHERS.RGB_TEACHER_EXPERIMENT_PATH = None 117 | _C.TEACHERS.FLOW_TEACHER_EXPERIMENT_PATH = None 118 | _C.TEACHERS.DEPTH_TEACHER_EXPERIMENT_PATH = None 119 | _C.TEACHERS.SKELETON_TEACHER_EXPERIMENT_PATH = None 120 | _C.TEACHERS.AUDIO_TEACHER_EXPERIMENT_PATH = None 121 | _C.TEACHERS.SEGMENTATION_TEACHER_EXPERIMENT_PATH = None 122 | _C.CALIBRATE_TEACHER = False 123 | # Can be None, per-sample, full-dataset 124 | _C.DISTILLATION_WEIGHTING_SCHEME = None 125 | # The temperature (used during distilation) 126 | _C.TEMPERATURE = 1.0 127 | # The weighting temperature (used during distillation) 128 | _C.WEIGHTS_TEMPERATURE = 1.0 129 | # Audio stuff 130 | _C.AUDIO_PATH = None 131 | # Video FPS 132 | _C.VIDEO_FPS = 60 133 | # Audio sample rate 134 | _C.AUDIO_SAMPLE_RATE = 24000 135 | # Length of audio segments extracted from action sequence (in seconds) 136 | _C.AUDIO_SEGMENT_LENGTH = 1.116 # to get 224 width of spectrogram 137 | # Audio resample rate 138 | _C.AUDIO_RESAMPLE_RATE = None 139 | # Evaluate for tail class indices (on EPIC) 140 | _C.EPIC_TAIL_NOUNS_PATH = None 141 | _C.EPIC_TAIL_VERBS_PATH = None 142 | _C.TEST_SET_INFERENCE = False 143 | 144 | 145 | def get_cfg_defaults(): 146 | """Get a yacs CfgNode object with default values for my_project.""" 147 | # Return a clone so that the defaults will not be altered 148 | # This is for the "local variable" use pattern 149 | return _C.clone() 150 | 151 | 152 | def train_setup(description: str): 153 | parser = argparse.ArgumentParser(description=description) 154 | parser.add_argument( 155 | "--config_path", type=str, required=True, help="Path to the config file." 156 | ) 157 | parser.add_argument("--opts", nargs=argparse.REMAINDER) 158 | args = parser.parse_args() 159 | cfg = get_cfg_defaults() 160 | cfg.merge_from_file(args.config_path) 161 | cfg.CONFIG_PATH = args.config_path 162 | if args.opts: 163 | cfg.merge_from_list(args.opts) 164 | cfg.freeze() 165 | # Prepare accelerator 166 | accelerator = Accelerator( 167 | kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)], 168 | cpu=cfg.DEVICE == "cpu", 169 | ) 170 | # Saving logic 171 | if accelerator.is_main_process: 172 | # If we have experiment already & not restarting --> Error! 173 | if os.path.exists(cfg.EXPERIMENT_PATH) and not cfg.WARM_RESTART: 174 | raise ValueError( 175 | f"{cfg.EXPERIMENT_PATH} exists & WARM_RESTART is False!\n" 176 | f"Please delete {cfg.EXPERIMENT_PATH} and run again!" 177 | ) 178 | # If we are restarting, we have to have experiment! 179 | if cfg.WARM_RESTART: 180 | assert os.path.exists( 181 | cfg.EXPERIMENT_PATH 182 | ), f"There is no {cfg.EXPERIMENT_PATH} to restart from!" 
183 | else: 184 | # Otherwise, we create the experiment directory 185 | os.makedirs(cfg.EXPERIMENT_PATH, exist_ok=False) 186 | with open(os.path.join(cfg.EXPERIMENT_PATH, "config.yaml"), "w") as f: 187 | f.write(cfg.dump()) 188 | 189 | return cfg, accelerator 190 | -------------------------------------------------------------------------------- /src/modelling/datasets.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from yacs.config import CfgNode 3 | 4 | from modelling.dataset_audio import AudioDataset 5 | from modelling.dataset_layout import LayoutDataset 6 | from modelling.dataset_video import VideoDataset 7 | 8 | 9 | class VideoFlow(Dataset): 10 | # FIXME: Hacky 11 | def __init__(self, cfg: CfgNode, train: bool = False): 12 | # Flow dataset 13 | flow_cfg = cfg.clone() 14 | flow_cfg.defrost() 15 | flow_cfg.DATASET_TYPE = "flow" 16 | flow_cfg.freeze() 17 | self.flow_dataset = VideoDataset(flow_cfg, train=train) 18 | # video dataset 19 | video_cfg = cfg.clone() 20 | video_cfg.defrost() 21 | video_cfg.DATASET_TYPE = "video" 22 | video_cfg.freeze() 23 | self.video_dataset = VideoDataset(video_cfg, train=train) 24 | 25 | def __len__(self): 26 | return self.flow_dataset.__len__() 27 | 28 | def __getitem__(self, idx: int): 29 | flow_dict = self.flow_dataset[idx] 30 | self.video_dataset.set_indices(flow_dict["indices"]) 31 | self.video_dataset.set_existing_transforms( 32 | self.flow_dataset.enforced_transforms 33 | ) 34 | video_dict = self.video_dataset[idx] 35 | # Gather from both dicts 36 | output = {} 37 | for c_dict in [flow_dict, video_dict]: 38 | for key in c_dict.keys(): 39 | output[key] = c_dict[key] 40 | 41 | return output 42 | 43 | 44 | class VideoAudio(Dataset): 45 | # FIXME: Hacky 46 | def __init__(self, cfg: CfgNode, train: bool = False): 47 | # Audio dataset 48 | audio_cfg = cfg.clone() 49 | audio_cfg.defrost() 50 | audio_cfg.DATASET_TYPE = "audio" 51 | audio_cfg.freeze() 52 | self.audio_dataset = AudioDataset(audio_cfg, train=train) 53 | # video dataset 54 | video_cfg = cfg.clone() 55 | video_cfg.defrost() 56 | video_cfg.DATASET_TYPE = "video" 57 | video_cfg.freeze() 58 | self.video_dataset = VideoDataset(video_cfg, train=train) 59 | 60 | def __len__(self): 61 | return self.audio_dataset.__len__() 62 | 63 | def __getitem__(self, idx: int): 64 | audio_dict = self.audio_dataset[idx] 65 | self.video_dataset.set_indices(audio_dict["indices"]) 66 | video_dict = self.video_dataset[idx] 67 | # Gather from both dicts 68 | output = {} 69 | for c_dict in [audio_dict, video_dict]: 70 | for key in c_dict.keys(): 71 | output[key] = c_dict[key] 72 | 73 | return output 74 | 75 | 76 | class VideoFlowAudio(Dataset): 77 | # FIXME: Hacky 78 | def __init__(self, cfg: CfgNode, train: bool = False): 79 | # Audio dataset 80 | audio_cfg = cfg.clone() 81 | audio_cfg.defrost() 82 | audio_cfg.DATASET_TYPE = "audio" 83 | audio_cfg.freeze() 84 | self.audio_dataset = AudioDataset(audio_cfg, train=train) 85 | # Flow dataset 86 | flow_cfg = cfg.clone() 87 | flow_cfg.defrost() 88 | flow_cfg.DATASET_TYPE = "flow" 89 | flow_cfg.freeze() 90 | self.flow_dataset = VideoDataset(flow_cfg, train=train) 91 | # video dataset 92 | video_cfg = cfg.clone() 93 | video_cfg.defrost() 94 | video_cfg.DATASET_TYPE = "video" 95 | video_cfg.freeze() 96 | self.video_dataset = VideoDataset(video_cfg, train=train) 97 | 98 | def set_weighted(self): 99 | from copy import deepcopy 100 | 101 | copy_self = deepcopy(self) 102 | copy_self.audio_dataset = 
self.audio_dataset.set_weighted() 103 | copy_self.video_dataset = self.video_dataset.set_weighted() 104 | copy_self.flow_dataset = self.flow_dataset.set_weighted() 105 | 106 | return copy_self 107 | 108 | def __len__(self): 109 | return self.flow_dataset.__len__() 110 | 111 | def __getitem__(self, idx: int): 112 | audio_dict = self.audio_dataset[idx] 113 | self.flow_dataset.set_indices(audio_dict["indices"]) 114 | flow_dict = self.flow_dataset[idx] 115 | self.video_dataset.set_indices(audio_dict["indices"]) 116 | self.video_dataset.set_existing_transforms( 117 | self.flow_dataset.enforced_transforms 118 | ) 119 | video_dict = self.video_dataset[idx] 120 | # Gather from both dicts 121 | output = {} 122 | for c_dict in [flow_dict, video_dict, audio_dict]: 123 | for key in c_dict.keys(): 124 | output[key] = c_dict[key] 125 | 126 | return output 127 | 128 | 129 | class VideoLayoutDataset(Dataset): 130 | # FIXME: Hacky 131 | def __init__(self, cfg: CfgNode, train: bool = False): 132 | self.layout_dataset = LayoutDataset(cfg, train=train) 133 | video_cfg = cfg.clone() 134 | video_cfg.defrost() 135 | video_cfg.DATASET_TYPE = "video" 136 | video_cfg.freeze() 137 | self.video_dataset = VideoDataset(video_cfg, train=train) 138 | 139 | def __len__(self): 140 | return self.layout_dataset.__len__() 141 | 142 | def __getitem__(self, idx: int): 143 | layout_dict = self.layout_dataset[idx] 144 | self.video_dataset.set_indices(layout_dict["indices"]) 145 | video_dict = self.video_dataset[idx] 146 | # Gather from both dicts 147 | output = {} 148 | for c_dict in [layout_dict, video_dict]: 149 | for key in c_dict.keys(): 150 | output[key] = c_dict[key] 151 | 152 | return output 153 | 154 | 155 | class VideoLayoutFlow(Dataset): 156 | def __init__(self, cfg: CfgNode, train: bool = False): 157 | # FIXME: Hacky 158 | # Layout dataset 159 | self.layout_dataset = LayoutDataset(cfg, train=train) 160 | # Flow dataset 161 | flow_cfg = cfg.clone() 162 | flow_cfg.defrost() 163 | flow_cfg.DATASET_TYPE = "flow" 164 | flow_cfg.freeze() 165 | self.flow_dataset = VideoDataset(flow_cfg, train=train) 166 | # video dataset 167 | video_cfg = cfg.clone() 168 | video_cfg.defrost() 169 | video_cfg.DATASET_TYPE = "video" 170 | video_cfg.freeze() 171 | self.video_dataset = VideoDataset(video_cfg, train=train) 172 | 173 | def __len__(self): 174 | return self.flow_dataset.__len__() 175 | 176 | def __getitem__(self, idx: int): 177 | flow_dict = self.flow_dataset[idx] 178 | self.video_dataset.set_indices(flow_dict["indices"]) 179 | self.video_dataset.set_existing_transforms( 180 | self.flow_dataset.enforced_transforms 181 | ) 182 | video_dict = self.video_dataset[idx] 183 | self.layout_dataset.set_indices(flow_dict["indices"]) 184 | layout_dict = self.layout_dataset[idx] 185 | # Gather from both dicts 186 | output = {} 187 | for c_dict in [flow_dict, video_dict, layout_dict]: 188 | for key in c_dict.keys(): 189 | output[key] = c_dict[key] 190 | 191 | return output 192 | 193 | 194 | dataset_factory = { 195 | "video": VideoDataset, 196 | "flow": VideoDataset, 197 | "depth": VideoDataset, 198 | "layout": LayoutDataset, 199 | "video_layout": VideoLayoutDataset, 200 | "video_flow": VideoFlow, 201 | "video_layout_flow": VideoLayoutFlow, 202 | "video_flow_audio": VideoFlowAudio, 203 | "audio": AudioDataset, 204 | "video_audio": VideoAudio, 205 | } 206 | -------------------------------------------------------------------------------- /src/utils/samplers.py: -------------------------------------------------------------------------------- 1 | 
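# Frame-index samplers used by the dataset classes (selected via get_sampler):
# - StltSampler: splits a video into NUM_FRAMES equal segments and takes one index per
#   segment (random within the segment at train time, the segment midpoint at eval).
# - GeneralSampler: clip-based sampling adapted from Video-Swin-Transformer, controlled
#   by NUM_FRAMES, STRIDE, NUM_TEST_CLIPS and FRAME_UNIFORM.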
import random 2 | 3 | import numpy as np 4 | from yacs.config import CfgNode 5 | 6 | 7 | class StltSampler: 8 | def __init__(self, cfg: CfgNode, train: bool = False): 9 | self.cfg = cfg 10 | self.train = train 11 | 12 | def __call__(self, video_length: int): 13 | seg_size = float(video_length - 1) / self.cfg.NUM_FRAMES 14 | seq = [] 15 | for i in range(self.cfg.NUM_FRAMES): 16 | start = int(np.round(seg_size * i)) 17 | end = int(np.round(seg_size * (i + 1))) 18 | if self.train: 19 | seq.append(random.randint(start, end)) 20 | else: 21 | seq.append((start + end) // 2) 22 | 23 | seq = np.array(seq) 24 | 25 | return seq 26 | 27 | 28 | # Adapted from https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/datasets/pipelines/loading.py#L79 29 | class GeneralSampler: 30 | def __init__(self, cfg: CfgNode, train: bool = False): 31 | self.cfg = cfg 32 | self.clip_len = self.cfg.NUM_FRAMES 33 | self.frame_interval = self.cfg.STRIDE 34 | self.num_clips = self.cfg.NUM_TEST_CLIPS 35 | self.temporal_jitter = False 36 | self.twice_sample = False 37 | self.out_of_bound_opt = "repeat_last" 38 | self.train = train 39 | self.frame_uniform = cfg.FRAME_UNIFORM 40 | assert self.out_of_bound_opt in ["loop", "repeat_last"] 41 | 42 | def _get_train_clips(self, video_length): 43 | """Get clip offsets in train mode. 44 | It will calculate the average interval for selected frames, 45 | and randomly shift them within offsets between [0, avg_interval]. 46 | If the total number of frames is smaller than clips num or origin 47 | frames length, it will return all zero indices. 48 | Args: 49 | video_length (int): Total number of frame in the video. 50 | Returns: 51 | np.ndarray: Sampled frame indices in train mode. 52 | """ 53 | ori_clip_len = self.clip_len * self.frame_interval 54 | avg_interval = (video_length - ori_clip_len + 1) // self.num_clips 55 | 56 | if avg_interval > 0: 57 | base_offsets = np.arange(self.num_clips) * avg_interval 58 | clip_offsets = base_offsets + np.random.randint( 59 | avg_interval, size=self.num_clips 60 | ) 61 | elif video_length > max(self.num_clips, ori_clip_len): 62 | clip_offsets = np.sort( 63 | np.random.randint(video_length - ori_clip_len + 1, size=self.num_clips) 64 | ) 65 | elif avg_interval == 0: 66 | ratio = (video_length - ori_clip_len + 1.0) / self.num_clips 67 | clip_offsets = np.around(np.arange(self.num_clips) * ratio) 68 | else: 69 | clip_offsets = np.zeros((self.num_clips,), dtype=int) 70 | 71 | return clip_offsets 72 | 73 | def _get_test_clips(self, video_length): 74 | """Get clip offsets in test mode. 75 | Calculate the average interval for selected frames, and shift them 76 | fixedly by avg_interval/2. If set twice_sample True, it will sample 77 | frames together without fixed shift. If the total number of frames is 78 | not enough, it will return all zero indices. 79 | Args: 80 | video_length (int): Total number of frame in the video. 81 | Returns: 82 | np.ndarray: Sampled frame indices in test mode. 
83 | """ 84 | ori_clip_len = self.clip_len * self.frame_interval 85 | avg_interval = (video_length - ori_clip_len + 1) / float(self.num_clips) 86 | if video_length > ori_clip_len - 1: 87 | base_offsets = np.arange(self.num_clips) * avg_interval 88 | clip_offsets = (base_offsets + avg_interval / 2.0).astype(int) 89 | if self.twice_sample: 90 | clip_offsets = np.concatenate([clip_offsets, base_offsets]) 91 | else: 92 | clip_offsets = np.zeros((self.num_clips,), dtype=int) 93 | return clip_offsets 94 | 95 | def _sample_clips(self, video_length): 96 | """Choose clip offsets for the video in a given mode. 97 | Args: 98 | video_length (int): Total number of frame in the video. 99 | Returns: 100 | np.ndarray: Sampled frame indices. 101 | """ 102 | if self.train: 103 | clip_offsets = self._get_train_clips(video_length) 104 | else: 105 | clip_offsets = self._get_test_clips(video_length) 106 | 107 | return clip_offsets 108 | 109 | def get_seq_frames(self, video_length): 110 | """ 111 | Modified from https://github.com/facebookresearch/SlowFast/blob/64abcc90ccfdcbb11cf91d6e525bed60e92a8796/slowfast/datasets/ssv2.py#L159 112 | Given the video index, return the list of sampled frame indexes. 113 | Args: 114 | video_length (int): Total number of frame in the video. 115 | Returns: 116 | seq (list): the indexes of frames of sampled from the video. 117 | """ 118 | # Update desired clip_len based on num_clips 119 | clip_len = self.clip_len * self.num_clips 120 | # Proceed to rest 121 | seg_size = float(video_length - 1) / clip_len 122 | seq = [] 123 | for i in range(clip_len): 124 | start = int(np.round(seg_size * i)) 125 | end = int(np.round(seg_size * (i + 1))) 126 | if self.train: 127 | seq.append(random.randint(start, end)) 128 | else: 129 | seq.append((start + end) // 2) 130 | 131 | return np.array(seq) 132 | 133 | def __call__(self, video_length: int): 134 | """Perform the SampleFrames loading. 135 | Args: 136 | results (dict): The resulting dict to be modified and passed 137 | to the next transform in pipeline. 
138 | """ 139 | if self.frame_uniform: # sthv2 sampling strategy 140 | frame_inds = self.get_seq_frames(video_length) 141 | else: 142 | clip_offsets = self._sample_clips(video_length) 143 | frame_inds = ( 144 | clip_offsets[:, None] 145 | + np.arange(self.clip_len)[None, :] * self.frame_interval 146 | ) 147 | frame_inds = np.concatenate(frame_inds) 148 | 149 | if self.temporal_jitter: 150 | perframe_offsets = np.random.randint( 151 | self.frame_interval, size=len(frame_inds) 152 | ) 153 | frame_inds += perframe_offsets 154 | 155 | frame_inds = frame_inds.reshape((-1, self.clip_len)) 156 | if self.out_of_bound_opt == "loop": 157 | frame_inds = np.mod(frame_inds, video_length) 158 | elif self.out_of_bound_opt == "repeat_last": 159 | safe_inds = frame_inds < video_length 160 | unsafe_inds = 1 - safe_inds 161 | last_ind = np.max(safe_inds * frame_inds, axis=1) 162 | new_inds = safe_inds * frame_inds + (unsafe_inds.T * last_ind).T 163 | frame_inds = new_inds 164 | else: 165 | raise ValueError("Illegal out_of_bound option.") 166 | 167 | frame_inds = np.concatenate(frame_inds) 168 | 169 | return frame_inds.astype(int) 170 | 171 | 172 | # TODO: Find better naming for these 173 | get_sampler = { 174 | "swin": GeneralSampler, 175 | "stlt": StltSampler, 176 | "resnet3d": GeneralSampler, 177 | } 178 | -------------------------------------------------------------------------------- /src/modelling/models.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Dict 3 | 4 | import torch 5 | from torch import nn 6 | from torchvision.transforms import functional as TF 7 | from yacs.config import CfgNode 8 | 9 | from modelling.hand_models import Stlt 10 | from modelling.resnets3d import generate_model 11 | from modelling.swin import SwinTransformer3D 12 | from utils.data_utils import get_normalizer 13 | 14 | 15 | class R3d(nn.Module): 16 | def __init__(self, cfg: CfgNode): 17 | super(R3d, self).__init__() 18 | self.cfg = cfg 19 | resnet = generate_model(model_depth=self.cfg.RESNET3D_DEPTH, n_classes=700) 20 | if self.cfg.BACKBONE_MODEL_PATH: 21 | checkpoint = torch.load(cfg.BACKBONE_MODEL_PATH, map_location="cpu") 22 | resnet.load_state_dict(checkpoint["state_dict"]) 23 | # Freeze the BatchNorm3D layers 24 | for module in resnet.modules(): 25 | if isinstance(module, nn.BatchNorm3d): 26 | module.weight.requires_grad = False 27 | module.bias.requires_grad = False 28 | # Strip the last two layers (pooling & classifier) 29 | self.resnet = torch.nn.Sequential(*(list(resnet.children())[:-2])) 30 | self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) 31 | if self.cfg.RESNET3D_DEPTH < 50: 32 | self.projector = nn.Sequential(nn.Linear(512, 2048), nn.ReLU()) 33 | # Build classifier 34 | self.classifiers = nn.ModuleDict( 35 | { 36 | actions_name: nn.Linear(2048, actions_num) 37 | for actions_name, actions_num in self.cfg.TOTAL_ACTIONS.items() 38 | if actions_num is not None 39 | } 40 | ) 41 | # Load existing checkpoint, if any 42 | if cfg.CHECKPOINT_PATH: 43 | self.load_state_dict(torch.load(cfg.CHECKPOINT_PATH, map_location="cpu")) 44 | 45 | def train(self, mode: bool): 46 | super(R3d, self).train(mode) 47 | if self.cfg.BACKBONE_MODEL_PATH: 48 | for module in self.resnet.modules(): 49 | if isinstance(module, nn.BatchNorm3d): 50 | module.train(False) 51 | 52 | def get_modality(self): 53 | mapping = { 54 | "video": "video", 55 | "video_flow": "video", 56 | "flow": "flow", 57 | "depth": "depth", 58 | "video_layout": "video", 59 | 
"video_depth": "video", 60 | "layout": "layout", 61 | "video_layout_flow": "video", 62 | "omnivore": "omnivore", 63 | "audio": "audio", 64 | "video_audio": "video", 65 | "video_flow_audio": "video", 66 | "segmentation": "segmentation", 67 | "video_segmentation": "video", 68 | } 69 | 70 | return mapping[self.cfg.DATASET_TYPE] 71 | 72 | def forward(self, batch: Dict[str, torch.Tensor]): 73 | # Obtain the modality 74 | modality = self.get_modality() 75 | # Get the video frames and prepare 76 | video_frames = batch[modality] 77 | # Normalize video frames 78 | normalizer = get_normalizer(input_type=modality, model_name="resnet3d") 79 | video_frames = normalizer(video_frames) 80 | b, n_clips, n_frames, c, s, s = video_frames.size() 81 | # print(video_frames.size()) 82 | video_frames = video_frames.reshape(b * n_clips, n_frames, c, s, s) 83 | # HACK: Resize the video frames to 112 in case they're not already 84 | if s > 112: 85 | video_frames = video_frames.view(-1, c, s, s) 86 | video_frames = TF.resize(video_frames, size=(112, 112), antialias=True) 87 | video_frames = video_frames.view(b * n_clips, n_frames, c, 112, 112) 88 | video_frames = video_frames.permute(0, 2, 1, 3, 4) 89 | # Extract features 90 | output = {} 91 | features = self.avgpool(self.resnet(video_frames)).flatten(1) 92 | features = features.contiguous() 93 | if self.cfg.RESNET3D_DEPTH < 50: 94 | features = self.projector(features) 95 | # Classify 96 | for actions_name in self.classifiers.keys(): 97 | output[actions_name] = self.classifiers[actions_name](features) 98 | 99 | return output 100 | 101 | 102 | class SwinModel(nn.Module): 103 | def __init__(self, cfg: CfgNode): 104 | super(SwinModel, self).__init__() 105 | self.cfg = cfg 106 | # Create backbone 107 | self.backbone = SwinTransformer3D( 108 | patch_size=(2, 4, 4), 109 | embed_dim=96, 110 | depths=[2, 2, 6, 2], 111 | num_heads=[3, 6, 12, 24], 112 | window_size=(8, 7, 7), 113 | mlp_ratio=4.0, 114 | qkv_bias=True, 115 | qk_scale=None, 116 | drop_rate=0.0, 117 | attn_drop_rate=0.0, 118 | drop_path_rate=0.2, 119 | patch_norm=True, 120 | ) 121 | if self.cfg.BACKBONE_MODEL_PATH: 122 | checkpoint = torch.load(self.cfg.BACKBONE_MODEL_PATH, map_location="cpu") 123 | new_state_dict = OrderedDict() 124 | for k, v in checkpoint["state_dict"].items(): 125 | if "backbone" in k: 126 | name = k[9:] 127 | new_state_dict[name] = v 128 | self.backbone.load_state_dict(new_state_dict) 129 | # Build classifier 130 | self.classifiers = nn.ModuleDict( 131 | { 132 | actions_name: nn.Linear(768, actions_num) 133 | for actions_name, actions_num in self.cfg.TOTAL_ACTIONS.items() 134 | if actions_num is not None 135 | } 136 | ) 137 | # Load existing checkpoint, if any 138 | if cfg.CHECKPOINT_PATH: 139 | self.load_state_dict(torch.load(cfg.CHECKPOINT_PATH, map_location="cpu")) 140 | 141 | def get_modality(self): 142 | mapping = { 143 | "video": "video", 144 | "video_flow": "video", 145 | "flow": "flow", 146 | "depth": "depth", 147 | "video_layout": "video", 148 | "video_depth": "video", 149 | "layout": "layout", 150 | "video_layout_flow": "video", 151 | "omnivore": "omnivore", 152 | "audio": "audio", 153 | "video_audio": "video", 154 | "video_flow_audio": "video", 155 | "segmentation": "segmentation", 156 | "video_segmentation": "video", 157 | } 158 | 159 | return mapping[self.cfg.DATASET_TYPE] 160 | 161 | def forward(self, batch: Dict[str, torch.Tensor]): 162 | # Obtain the modality 163 | modality = self.get_modality() 164 | # Get the video frames and prepare 165 | video_frames = batch[modality] 166 
| # Normalize video frames 167 | normalizer = get_normalizer(input_type=modality, model_name="swin") 168 | video_frames = normalizer(video_frames) 169 | b, n_clips, n_frames, c, s, s = video_frames.size() 170 | video_frames = video_frames.reshape(b * n_clips, n_frames, c, s, s) 171 | video_frames = video_frames.permute(0, 2, 1, 3, 4) 172 | # Extract features 173 | output = {} 174 | features = self.backbone(video_frames) 175 | features = features.mean(dim=[2, 3, 4]) 176 | # Classify 177 | for actions_name in self.classifiers.keys(): 178 | output[actions_name] = self.classifiers[actions_name](features) 179 | 180 | return output 181 | 182 | 183 | model_factory = { 184 | "resnet3d": R3d, 185 | "stlt": Stlt, 186 | "swin": SwinModel, 187 | } 188 | -------------------------------------------------------------------------------- /src/modelling/distiller.py: -------------------------------------------------------------------------------- 1 | from os.path import join as pjoin 2 | from typing import Dict, List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from yacs.config import CfgNode 8 | 9 | from modelling.models import model_factory 10 | from utils.setup import get_cfg_defaults 11 | 12 | 13 | class DistillationCriterion: 14 | def __init__(self, cfg): 15 | self.cfg = cfg 16 | self.gt_criterion = nn.CrossEntropyLoss() 17 | self.action_names = [ 18 | action 19 | for action in self.cfg.ACTION_WEIGHTS.keys() 20 | if self.cfg.ACTION_WEIGHTS[action] 21 | ] 22 | 23 | def measure_temp_scaled_kl_loss(self, student_logits, teacher_logits): 24 | loss = 0 25 | for action_name in self.action_names: 26 | target = F.log_softmax( 27 | teacher_logits[action_name] / self.cfg.TEMPERATURE, dim=-1 28 | ) 29 | pred = F.log_softmax( 30 | student_logits[action_name] / self.cfg.TEMPERATURE, dim=-1 31 | ) 32 | cur_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)( 33 | pred, target 34 | ) 35 | cur_loss = cur_loss * (self.cfg.TEMPERATURE ** 2) 36 | loss += cur_loss 37 | # Average loss 38 | loss = loss / len(self.action_names) 39 | 40 | return loss 41 | 42 | def measure_gt_loss(self, student_logits, gt_labels): 43 | total_loss = 0 44 | # Aggregate losses 45 | for action_name in self.action_names: 46 | loss = ( 47 | self.gt_criterion(student_logits[action_name], gt_labels[action_name]) 48 | * self.cfg.ACTION_WEIGHTS[action_name] 49 | ) 50 | total_loss += loss 51 | 52 | return total_loss / len(self.action_names) 53 | 54 | def __call__(self, student_logits, teacher_logits, batch): 55 | # Measure losses 56 | distillation_loss = self.measure_temp_scaled_kl_loss( 57 | student_logits, teacher_logits 58 | ) 59 | gt_loss = self.measure_gt_loss(student_logits, batch["labels"]) 60 | # Compute total loss 61 | total_loss = ( 62 | distillation_loss * self.cfg.LOSS_WEIGHTS.DISTILLATION 63 | + gt_loss * self.cfg.LOSS_WEIGHTS.GROUND_TRUTH 64 | ) 65 | 66 | return total_loss 67 | 68 | 69 | class TeacherEnsemble(nn.Module): 70 | def __init__(self, cfg: CfgNode): 71 | super(TeacherEnsemble, self).__init__() 72 | # Prepare teachers 73 | self.cfg = cfg 74 | teachers = {} 75 | for teacher_name in self.cfg.TEACHERS.keys(): 76 | # Skip empty teachers 77 | if not self.cfg.TEACHERS.get(teacher_name): 78 | continue 79 | teacher_cfg = get_cfg_defaults() 80 | teacher_cfg.merge_from_file( 81 | pjoin(self.cfg.TEACHERS.get(teacher_name), "config.yaml") 82 | ) 83 | teacher = model_factory[teacher_cfg.MODEL_NAME](teacher_cfg) 84 | checkpoint = torch.load( 85 | pjoin(self.cfg.TEACHERS.get(teacher_name), 
"model_checkpoint.pt"), 86 | map_location="cpu", 87 | ) 88 | from collections import OrderedDict 89 | 90 | unwrapped_checkpoint = OrderedDict() 91 | prefix = "_orig_mod." 92 | for key in checkpoint.keys(): 93 | if key.startswith(prefix): 94 | unwrapped_checkpoint[key[len(prefix) :]] = checkpoint[key] 95 | else: 96 | unwrapped_checkpoint[key] = checkpoint[key] 97 | teacher.load_state_dict(unwrapped_checkpoint) 98 | teacher.train(False) 99 | teachers[teacher_name] = teacher 100 | self.teachers = nn.ModuleDict(teachers) 101 | # Establish initial weights 102 | assert self.cfg.DISTILLATION_WEIGHTING_SCHEME in [ 103 | None, 104 | "per-sample", 105 | "full-dataset", 106 | ] 107 | weights = {} 108 | for action_name, action_num in self.cfg.TOTAL_ACTIONS.items(): 109 | if action_num is not None: 110 | weights[action_name] = nn.Parameter( 111 | torch.full( 112 | size=(len(teachers),), fill_value=1 / len(self.teachers) 113 | ), 114 | requires_grad=False, 115 | ) 116 | self.weights = nn.ParameterDict(weights) 117 | 118 | def get_losses_wrt_labels(self, logits: torch.Tensor, labels: torch.Tensor): 119 | b, n_t, c = logits.size() 120 | # Prepare logits & labels 121 | logits = logits.reshape(-1, c) 122 | # BE CAREFULL, repeat would be wrong here! 123 | labels = labels.repeat_interleave(n_t) 124 | # Compute weights: [Batch_Size, Num_Teachers] 125 | per_teacher_losses = F.cross_entropy(logits, labels, reduction="none") 126 | # Reshape back in original shape 127 | per_teacher_losses = per_teacher_losses.reshape(b, n_t) 128 | 129 | return per_teacher_losses 130 | 131 | def get_per_sample_weights(self, logits: torch.Tensor, labels: torch.Tensor): 132 | per_teacher_losses = self.get_losses_wrt_labels(logits, labels) 133 | return F.softmin(per_teacher_losses / self.cfg.WEIGHTS_TEMPERATURE, dim=-1) 134 | 135 | def get_teacher_logits(self, batch): 136 | teacher_outputs: Dict[str, Dict[str, torch.Tensor]] = {} 137 | for teacher_name in self.cfg.TEACHERS.keys(): 138 | # Skip empty teachers 139 | if not self.cfg.TEACHERS.get(teacher_name): 140 | continue 141 | teacher_outputs[teacher_name] = self.teachers[teacher_name](batch) 142 | # Gather teacher logits 143 | teacher_logits: Dict[str, List[torch.Tensor]] = {} 144 | for teacher_output in teacher_outputs.values(): 145 | for action_name in teacher_output.keys(): 146 | if action_name not in teacher_logits: 147 | teacher_logits[action_name] = [] 148 | teacher_logits[action_name].append(teacher_output[action_name]) 149 | 150 | return teacher_logits 151 | 152 | def get_weights( 153 | self, action_name: str, logits: torch.Tensor, labels: Dict[str, torch.Tensor] 154 | ): 155 | # Weights obtained dynamically per-sample 156 | if self.cfg.DISTILLATION_WEIGHTING_SCHEME == "per-sample": 157 | weights = self.get_per_sample_weights(logits, labels[action_name]) 158 | # Weights were updated during set_teacher_weights in patient distill 159 | elif self.cfg.DISTILLATION_WEIGHTING_SCHEME == "full-dataset": 160 | weights = self.weights[action_name] 161 | # The initial weights are set-up as 1 / num_teachers 162 | elif self.cfg.DISTILLATION_WEIGHTING_SCHEME is None: 163 | weights = self.weights[action_name] 164 | 165 | return weights 166 | 167 | @torch.no_grad() 168 | def forward(self, batch): 169 | teacher_logits = self.get_teacher_logits(batch) 170 | # (Weighed) average of teacher logits 171 | for action_name in teacher_logits: 172 | # [Batch_Size, Num_Teachers, Num_Classes] 173 | logits = torch.cat( 174 | [t_o.unsqueeze(1) for t_o in teacher_logits[action_name]], dim=1 175 | ) 176 | 
weights = self.get_weights(action_name, logits, batch["labels"]) 177 | # [Batch_Size, Num_Teachers] 178 | logits *= weights.unsqueeze(-1) 179 | # Average teacher predictions 180 | teacher_logits[action_name] = logits.sum(1) 181 | 182 | return teacher_logits 183 | -------------------------------------------------------------------------------- /src/modelling/resnets3d.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def get_inplanes(): 9 | return [64, 128, 256, 512] 10 | 11 | 12 | def conv3x3x3(in_planes, out_planes, stride=1): 13 | return nn.Conv3d( 14 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 15 | ) 16 | 17 | 18 | def conv1x1x1(in_planes, out_planes, stride=1): 19 | return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, in_planes, planes, stride=1, downsample=None): 26 | super().__init__() 27 | 28 | self.conv1 = conv3x3x3(in_planes, planes, stride) 29 | self.bn1 = nn.BatchNorm3d(planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.conv2 = conv3x3x3(planes, planes) 32 | self.bn2 = nn.BatchNorm3d(planes) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | residual = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | out = self.relu(out) 51 | 52 | return out 53 | 54 | 55 | class Bottleneck(nn.Module): 56 | expansion = 4 57 | 58 | def __init__(self, in_planes, planes, stride=1, downsample=None): 59 | super().__init__() 60 | self.conv1 = conv1x1x1(in_planes, planes) 61 | self.bn1 = nn.BatchNorm3d(planes) 62 | self.conv2 = conv3x3x3(planes, planes, stride) 63 | self.bn2 = nn.BatchNorm3d(planes) 64 | self.conv3 = conv1x1x1(planes, planes * self.expansion) 65 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.downsample = downsample 68 | self.stride = stride 69 | 70 | def forward(self, x): 71 | residual = x 72 | 73 | out = self.conv1(x) 74 | out = self.bn1(out) 75 | out = self.relu(out) 76 | 77 | out = self.conv2(out) 78 | out = self.bn2(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv3(out) 82 | out = self.bn3(out) 83 | 84 | if self.downsample is not None: 85 | residual = self.downsample(x) 86 | 87 | out += residual 88 | out = self.relu(out) 89 | 90 | return out 91 | 92 | 93 | class ResNet(nn.Module): 94 | def __init__( 95 | self, 96 | block, 97 | layers, 98 | block_inplanes, 99 | n_input_channels=3, 100 | conv1_t_size=7, 101 | conv1_t_stride=1, 102 | no_max_pool=False, 103 | shortcut_type="B", 104 | widen_factor=1.0, 105 | n_classes=400, 106 | ): 107 | super().__init__() 108 | 109 | block_inplanes = [int(x * widen_factor) for x in block_inplanes] 110 | 111 | self.in_planes = block_inplanes[0] 112 | self.no_max_pool = no_max_pool 113 | 114 | self.conv1 = nn.Conv3d( 115 | n_input_channels, 116 | self.in_planes, 117 | kernel_size=(conv1_t_size, 7, 7), 118 | stride=(conv1_t_stride, 2, 2), 119 | padding=(conv1_t_size // 2, 3, 3), 120 | bias=False, 121 | ) 122 | self.bn1 = nn.BatchNorm3d(self.in_planes) 123 | self.relu = nn.ReLU(inplace=True) 124 | self.maxpool = 
nn.MaxPool3d(kernel_size=3, stride=2, padding=1) 125 | self.layer1 = self._make_layer( 126 | block, block_inplanes[0], layers[0], shortcut_type 127 | ) 128 | self.layer2 = self._make_layer( 129 | block, 130 | block_inplanes[1], 131 | layers[1], 132 | shortcut_type, 133 | stride=2, 134 | ) 135 | self.layer3 = self._make_layer( 136 | block, block_inplanes[2], layers[2], shortcut_type, stride=2 137 | ) 138 | self.layer4 = self._make_layer( 139 | block, block_inplanes[3], layers[3], shortcut_type, stride=2 140 | ) 141 | 142 | self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) 143 | self.fc = nn.Linear(block_inplanes[3] * block.expansion, n_classes) 144 | 145 | for m in self.modules(): 146 | if isinstance(m, nn.Conv3d): 147 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 148 | elif isinstance(m, nn.BatchNorm3d): 149 | nn.init.constant_(m.weight, 1) 150 | nn.init.constant_(m.bias, 0) 151 | 152 | def _downsample_basic_block(self, x, planes, stride): 153 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 154 | zero_pads = torch.zeros( 155 | out.size(0), planes - out.size(1), out.size(2), out.size(3), out.size(4) 156 | ) 157 | if isinstance(out.data, torch.cuda.FloatTensor): 158 | zero_pads = zero_pads.cuda() 159 | 160 | out = torch.cat([out.data, zero_pads], dim=1) 161 | 162 | return out 163 | 164 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1): 165 | downsample = None 166 | if stride != 1 or self.in_planes != planes * block.expansion: 167 | if shortcut_type == "A": 168 | downsample = partial( 169 | self._downsample_basic_block, 170 | planes=planes * block.expansion, 171 | stride=stride, 172 | ) 173 | else: 174 | downsample = nn.Sequential( 175 | conv1x1x1(self.in_planes, planes * block.expansion, stride), 176 | nn.BatchNorm3d(planes * block.expansion), 177 | ) 178 | 179 | layers = [] 180 | layers.append( 181 | block( 182 | in_planes=self.in_planes, 183 | planes=planes, 184 | stride=stride, 185 | downsample=downsample, 186 | ) 187 | ) 188 | self.in_planes = planes * block.expansion 189 | for i in range(1, blocks): 190 | layers.append(block(self.in_planes, planes)) 191 | 192 | return nn.Sequential(*layers) 193 | 194 | def forward(self, x): 195 | x = self.conv1(x) 196 | x = self.bn1(x) 197 | x = self.relu(x) 198 | if not self.no_max_pool: 199 | x = self.maxpool(x) 200 | 201 | x = self.layer1(x) 202 | x = self.layer2(x) 203 | x = self.layer3(x) 204 | x = self.layer4(x) 205 | 206 | x = self.avgpool(x) 207 | 208 | x = x.view(x.size(0), -1) 209 | x = self.fc(x) 210 | 211 | return x 212 | 213 | 214 | def generate_model(model_depth, **kwargs): 215 | assert model_depth in [10, 18, 34, 50, 101, 152, 200] 216 | 217 | if model_depth == 10: 218 | model = ResNet(BasicBlock, [1, 1, 1, 1], get_inplanes(), **kwargs) 219 | elif model_depth == 18: 220 | model = ResNet(BasicBlock, [2, 2, 2, 2], get_inplanes(), **kwargs) 221 | elif model_depth == 34: 222 | model = ResNet(BasicBlock, [3, 4, 6, 3], get_inplanes(), **kwargs) 223 | elif model_depth == 50: 224 | model = ResNet(Bottleneck, [3, 4, 6, 3], get_inplanes(), **kwargs) 225 | elif model_depth == 101: 226 | model = ResNet(Bottleneck, [3, 4, 23, 3], get_inplanes(), **kwargs) 227 | elif model_depth == 152: 228 | model = ResNet(Bottleneck, [3, 8, 36, 3], get_inplanes(), **kwargs) 229 | elif model_depth == 200: 230 | model = ResNet(Bottleneck, [3, 24, 36, 3], get_inplanes(), **kwargs) 231 | 232 | return model 233 | -------------------------------------------------------------------------------- 
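A minimal usage sketch for the 3D ResNet factory defined above, assuming a 16-frame RGB clip at the 112x112 resolution that R3d resizes larger inputs to; the depth (18) and class count (174, as in the Something-Something configs) are illustrative choices, not values fixed by the repository.

import torch

from modelling.resnets3d import generate_model

# Build a ResNet3D-18 with a 174-way classifier head.
model = generate_model(model_depth=18, n_classes=174)
model.eval()

# Two clips shaped [batch, channels, frames, height, width].
clips = torch.randn(2, 3, 16, 112, 112)
with torch.no_grad():
    logits = model(clips)

print(logits.shape)  # torch.Size([2, 174])

Note that the R3d wrapper above discards this built-in classifier head (it strips the last two layers and pools backbone features itself); the sketch only exercises the standalone network.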
/src/modelling/dataset_proto.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | import pickle 4 | import re 5 | from typing import Dict 6 | 7 | import h5py 8 | import torch 9 | from torch.utils.data import Dataset 10 | from yacs.config import CfgNode 11 | 12 | from utils.samplers import get_sampler 13 | 14 | 15 | class ProtoDataset(Dataset): 16 | def __init__(self, cfg: CfgNode, train: bool = False): 17 | self.cfg = cfg 18 | self.train = train 19 | self.dataset_name = ( 20 | cfg.TRAIN_DATASET_NAME if self.train else cfg.VAL_DATASET_NAME 21 | ) 22 | self.dataset_path = ( 23 | cfg.TRAIN_DATASET_PATH if self.train else cfg.VAL_DATASET_PATH 24 | ) 25 | self.dataset_type = self.cfg.DATASET_TYPE 26 | self.create_dataset() 27 | self.sampler = get_sampler[self.cfg.MODEL_NAME](cfg=cfg, train=train) 28 | 29 | def create_dataset(self): 30 | self.dataset = [] 31 | if self.dataset_name == "something-something": 32 | assert ( 33 | self.cfg.LABELS_PATH 34 | ), "If something-something, path to labels required" 35 | self.dataset = json.load(open(self.dataset_path)) 36 | self.labels = json.load(open(self.cfg.LABELS_PATH)) 37 | elif self.dataset_name == "charades" or self.dataset_name == "charades-ego": 38 | with open(self.dataset_path, newline="") as csvfile: 39 | for row in csv.DictReader(csvfile): 40 | if len(row["actions"]) == 0: 41 | continue 42 | if row["id"] in {"93MIK", "8NJQ3", "8U2AR"}: 43 | continue 44 | actions = [int(a[1:4]) for a in row["actions"].split(";")] 45 | self.dataset.append({"id": row["id"], "actions": actions}) 46 | elif self.dataset_name == "egtea-gaze": 47 | with open(self.dataset_path, "r") as text_file: 48 | for line in text_file.readlines(): 49 | video_id, action_id = line.split()[:2] 50 | # Action start from 1, subtracting 1 51 | self.dataset.append({"id": video_id, "actions": int(action_id) - 1}) 52 | elif self.dataset_name == "montalbano": 53 | self.dataset = json.load(open(self.dataset_path)) 54 | elif self.dataset_name == "EPIC-KITCHENS": 55 | # Flow: https://github.com/epic-kitchens/epic-kitchens-download-scripts/issues/17#issuecomment-1222288006 56 | assert ( 57 | self.cfg.DATASET_VERSION == 55 or self.cfg.DATASET_VERSION == 100 58 | ), "If EPIC-KITCHENS, dataset version must be provided (55 or 100)" 59 | if self.cfg.DATASET_VERSION == 55: 60 | data_file = pickle.load(open(self.dataset_path, "rb")) 61 | # FIXME: Removing two indices which are bad 62 | bad = {36787, 36788} 63 | for index in data_file.index.to_list(): 64 | if index in bad: 65 | continue 66 | self.dataset.append( 67 | { 68 | "id": data_file["video_id"][index], 69 | "narration_id": data_file["narration_id"][index], 70 | "start_frame": data_file["start_frame"][index], 71 | "stop_frame": data_file["stop_frame"][index], 72 | "start_timestamp": data_file["start_timestamp"][index], 73 | "stop_timestamp": data_file["stop_timestamp"][index], 74 | # Noun starts from 1, subtracting 1 75 | "noun_class": data_file["noun_class"][index] - 1, 76 | "verb_class": data_file["verb_class"][index], 77 | } 78 | ) 79 | elif self.cfg.DATASET_VERSION == 100: 80 | with open(self.dataset_path, newline="") as csvfile: 81 | reader = csv.DictReader(csvfile) 82 | for row in reader: 83 | if self.cfg.EPIC_PARTICIPANTS: 84 | if row["participant_id"] not in self.cfg.EPIC_PARTICIPANTS: 85 | continue 86 | self.dataset.append( 87 | { 88 | "id": row["video_id"], 89 | "narration_id": row["narration_id"], 90 | "start_frame": int(row["start_frame"]), 91 | "stop_frame": 
int(row["stop_frame"]), 92 | "start_timestamp": row["start_timestamp"], 93 | "stop_timestamp": row["stop_timestamp"], 94 | "noun_class": int(row["noun_class"]), 95 | "verb_class": int(row["verb_class"]), 96 | } 97 | ) 98 | else: 99 | raise ValueError(f"{self.dataset_name} does not exist!") 100 | 101 | def get_actions(self, sample) -> Dict[str, torch.Tensor]: 102 | if self.dataset_name == "something-something": 103 | actions = { 104 | "ACTION": torch.tensor( 105 | int(self.labels[re.sub("[\[\]]", "", sample["template"])]) 106 | ) 107 | } 108 | elif self.dataset_name == "charades" or self.dataset_name == "charades-ego": 109 | actions = torch.zeros(self.cfg.TOTAL_ACTIONS.ACTION, dtype=torch.float) 110 | actions[sample["ACTION"]] = 1.0 111 | actions = {"ACTION": actions} 112 | elif self.dataset_name == "egtea-gaze": 113 | actions = {"ACTION": torch.tensor(sample["actions"])} 114 | elif self.dataset_name == "montalbano": 115 | actions = {"ACTION": torch.tensor(sample["gesture_class"] - 1)} 116 | elif self.dataset_name == "EPIC-KITCHENS": 117 | actions = { 118 | "NOUN": torch.tensor(sample["noun_class"]), 119 | "VERB": torch.tensor(sample["verb_class"]), 120 | } 121 | else: 122 | raise ValueError(f"{self.dataset_name} does not exist!") 123 | 124 | return actions 125 | 126 | def __len__(self): 127 | return len(self.dataset) 128 | 129 | def open_resource(self): 130 | if self.dataset_type == "video": 131 | self.resource = h5py.File( 132 | self.cfg.VIDEOS_PATH, "r", libver="latest", swmr=True 133 | ) 134 | elif self.dataset_type == "flow": 135 | self.resource = h5py.File( 136 | self.cfg.FLOW_PATH, "r", libver="latest", swmr=True 137 | ) 138 | elif self.dataset_type == "audio": 139 | self.resource = h5py.File( 140 | self.cfg.AUDIO_PATH, "r", libver="latest", swmr=True 141 | ) 142 | elif self.dataset_type == "segmentation": 143 | self.resource = h5py.File( 144 | self.cfg.SEGMENTATION_PATH, "r", libver="latest", swmr=True 145 | ) 146 | else: 147 | raise ValueError( 148 | f"{self.dataset_type} cannot load anything with this dataset!" 
149 | ) 150 | 151 | def get_video_length(self, sample): 152 | # EPIC (covers audio too - hack) 153 | if self.dataset_name == "EPIC-KITCHENS" or self.dataset_name == "montalbano": 154 | return sample["stop_frame"] - sample["start_frame"] + 1 155 | # All else 156 | return len(self.resource[sample["id"]]) 157 | 158 | def set_indices(self, indices): 159 | self.indices = indices 160 | 161 | def set_existing_transforms(self, transforms): 162 | self.existing_transforms = transforms 163 | 164 | def __getitem__(self, idx: int): 165 | raise NotImplementedError("Subclasses must implement '__getitem__'.") 166 | -------------------------------------------------------------------------------- /src/modelling/dataset_video.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | from yacs.config import CfgNode 7 | 8 | from modelling.dataset_proto import ProtoDataset 9 | from utils.data_utils import compile_transforms, get_video_transforms 10 | from utils.samplers import get_sampler 11 | 12 | 13 | class VideoDataset(ProtoDataset): 14 | def __init__(self, cfg: CfgNode, train: bool = False): 15 | self.cfg = cfg 16 | self.train = train 17 | self.dataset_name = ( 18 | cfg.TRAIN_DATASET_NAME if self.train else cfg.VAL_DATASET_NAME 19 | ) 20 | self.dataset_path = ( 21 | cfg.TRAIN_DATASET_PATH if self.train else cfg.VAL_DATASET_PATH 22 | ) 23 | self.dataset_type = self.cfg.DATASET_TYPE 24 | self.create_dataset() 25 | self.sampler = get_sampler[self.cfg.MODEL_NAME](cfg=cfg, train=train) 26 | 27 | def get_video_frames(self, **kwargs): 28 | indices = kwargs.pop("indices") 29 | video_id = kwargs.pop("video_id") 30 | # Fix for omnivore so we don't need to re-implement method 31 | if "resource" in kwargs: 32 | resource = kwargs.pop("resource") 33 | else: 34 | resource = self.resource 35 | # Epic-Kitchens 36 | if self.dataset_name == "EPIC-KITCHENS": 37 | narration_id = kwargs.pop("narration_id") 38 | unique_indices, inv_indices = np.unique(indices, return_inverse=True) 39 | frames = resource[video_id][narration_id][unique_indices] 40 | frames = [ 41 | Image.open(io.BytesIO(frames[index])) 42 | for index in range(len(unique_indices)) 43 | ] 44 | frames = [frames[index] for index in inv_indices] 45 | elif self.dataset_name == "something-something": 46 | unique_indices, inv_indices = np.unique(indices, return_inverse=True) 47 | frames = resource[video_id][unique_indices] 48 | frames = [ 49 | Image.open(io.BytesIO(frames[index])) 50 | for index in range(len(unique_indices)) 51 | ] 52 | frames = [frames[index] for index in inv_indices] 53 | else: 54 | raise ValueError( 55 | f"{self.dataset_type} cannot load anything with this dataset!" 
56 | ) 57 | return frames 58 | 59 | def get_flow_frames(self, **kwargs): 60 | indices = kwargs.pop("indices") 61 | video_id = kwargs.pop("video_id") 62 | # Fix for omnivore so we don't need to re-implement method 63 | if "resource" in kwargs: 64 | resource = kwargs.pop("resource") 65 | else: 66 | resource = self.resource 67 | 68 | if self.dataset_name == "EPIC-KITCHENS": 69 | narration_id = kwargs.pop("narration_id") 70 | 71 | indices_u = 2 * indices 72 | indices_v = indices_u + 1 73 | 74 | indices = np.empty( 75 | (indices_u.size + indices_v.size,), dtype=indices_u.dtype 76 | ) 77 | indices[0::2] = indices_u 78 | indices[1::2] = indices_v 79 | 80 | unique_indices, inv_indices = np.unique(indices, return_inverse=True) 81 | 82 | frames = resource[video_id][narration_id][unique_indices] 83 | frames = [ 84 | Image.open(io.BytesIO(frames[index])) 85 | for index in range(len(unique_indices)) 86 | ] 87 | frames = [frames[index] for index in inv_indices] 88 | frames_u = frames[0::2] 89 | frames_v = frames[1::2] 90 | 91 | frames = [ 92 | Image.merge( 93 | "RGB", [frame_u, frame_v, Image.new("L", size=frame_v.size)] 94 | ) 95 | for frame_u, frame_v in zip(frames_u, frames_v) 96 | ] 97 | elif self.dataset_name == "something-something": 98 | unique_indices, inv_indices = np.unique(indices, return_inverse=True) 99 | frames = resource[video_id][unique_indices] 100 | frames = [ 101 | Image.open(io.BytesIO(frames[index])) 102 | for index in range(len(unique_indices)) 103 | ] 104 | frames = [frames[index] for index in inv_indices] 105 | else: 106 | raise ValueError( 107 | f"{self.dataset_type} cannot load anything with this dataset!" 108 | ) 109 | return frames 110 | 111 | def get_frames(self, **kwargs): 112 | if self.dataset_type == "video": 113 | return self.get_video_frames(**kwargs) 114 | elif self.dataset_type == "flow": 115 | return self.get_flow_frames(**kwargs) 116 | 117 | def __getitem__(self, idx: int): 118 | output = {self.dataset_type: []} 119 | if not hasattr(self, "resource"): 120 | self.open_resource() 121 | output["video_id"] = self.dataset[idx]["id"] 122 | if not hasattr(self, "indices"): 123 | indices = self.sampler( 124 | video_length=self.get_video_length(self.dataset[idx]) 125 | ) 126 | else: 127 | indices = self.indices 128 | # Check for existing transforms to achieve consistent teaching 129 | if not hasattr(self, "existing_transforms"): 130 | existing_transforms = {} 131 | else: 132 | existing_transforms = self.existing_transforms 133 | # Pass indices to the output object, so that the other datasets can use it 134 | output["indices"] = indices 135 | # If EPIC, we need the start frame, otherwise it is 0 136 | if self.dataset_name == "EPIC-KITCHENS": 137 | output["start_frame"] = self.dataset[idx]["start_frame"] 138 | output["narration_id"] = self.dataset[idx]["narration_id"] 139 | else: 140 | output["start_frame"] = 0 141 | output["narration_id"] = -1 142 | # Load all frames 143 | frames = self.get_frames( 144 | indices=indices, 145 | video_id=output["video_id"], 146 | narration_id=output["narration_id"], 147 | ) 148 | # Aggregate frame 149 | output[self.dataset_type] = frames 150 | # Get the video transformation dictionary 151 | video_transforms_dict = get_video_transforms( 152 | augmentations_list=self.cfg.AUGMENTATIONS.get(self.dataset_type.upper()), 153 | train=self.train, 154 | cfg=self.cfg, 155 | existing_transforms=existing_transforms, 156 | ) 157 | self.enforced_transforms = video_transforms_dict 158 | video_transforms = compile_transforms(video_transforms_dict) 159 | # 
Augment frames - (Eval_Clips x Frames) times 160 | for i in range(len(output[self.dataset_type])): 161 | # [Channels, Spatial size, Spatial size] 162 | frame = output[self.dataset_type][i] 163 | # [Channels, Spatial size, Spatial size] 164 | # or [Eval_Crops, Channels, Spatial size, Spatial size] 165 | output[self.dataset_type][i] = video_transforms(frame) 166 | # Stack the frames 167 | output[self.dataset_type] = torch.stack(output[self.dataset_type], dim=0) 168 | # [Eval_Clips x Frames, Channels, Spatial size, Spatial size] 169 | # or 170 | # [Eval_Clips x Frames, Eval_Crops, Channels, Spatial size, Spatial size] 171 | if len(output[self.dataset_type].shape) > 4: 172 | output[self.dataset_type] = output[self.dataset_type].permute(1, 0, 2, 3, 4) 173 | # [Eval_Crops, Eval_Clips x Frames, Channels, Spatial size, Spatial size] 174 | # Get spatial size 175 | spatial_size = output[self.dataset_type].size(-1) 176 | # Reshape 177 | output[self.dataset_type] = output[self.dataset_type].reshape( 178 | -1, # Eval_Clips x Eval_Crops 179 | self.cfg.NUM_FRAMES, 180 | 3, 181 | spatial_size, 182 | spatial_size, 183 | ) 184 | # Obtain video labels 185 | output["labels"] = self.get_actions(self.dataset[idx]) 186 | 187 | return output 188 | -------------------------------------------------------------------------------- /src/modelling/hand_models.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | from torch import nn 5 | from yacs.config import CfgNode 6 | 7 | 8 | def generate_square_subsequent_mask(sz: int) -> torch.Tensor: 9 | # https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#Transformer.generate_square_subsequent_mask 10 | mask = ~(torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 11 | return mask 12 | 13 | 14 | class CategoryBoxEmbeddings(nn.Module): 15 | def __init__(self): 16 | super(CategoryBoxEmbeddings, self).__init__() 17 | # Fixing params 18 | self.hidden_size = 768 19 | self.hidden_dropout_prob = 0.1 20 | self.layer_norm_eps = 0.1 21 | # [Hand, Object, Padding] 22 | self.category_embeddings = nn.Embedding( 23 | embedding_dim=self.hidden_size, num_embeddings=3, padding_idx=0 24 | ) 25 | self.box_embedding = nn.Linear(4, self.hidden_size) 26 | self.score_embedding = nn.Linear(1, self.hidden_size) 27 | # [Pad, Left, Right] 28 | self.side_embedding = nn.Embedding( 29 | num_embeddings=3, embedding_dim=self.hidden_size, padding_idx=0 30 | ) 31 | # [Pad, No contact, Self contact, Another person, Portable obj., Stationary obj.] 
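        # The state embedding encodes the per-hand contact state listed above;
        # index 0 doubles as the padding index, so padded boxes contribute a
        # zero vector.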
32 | self.state_embedding = nn.Embedding( 33 | num_embeddings=6, embedding_dim=self.hidden_size, padding_idx=0 34 | ) 35 | self.layer_norm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) 36 | self.dropout = nn.Dropout(self.hidden_dropout_prob) 37 | 38 | def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: 39 | category_embeddings = self.category_embeddings(batch["class_labels"]) 40 | boxes_embeddings = self.box_embedding(batch["bboxes"]) 41 | score_embeddings = self.score_embedding(batch["scores"].unsqueeze(-1)) 42 | sides_embedding = self.side_embedding(batch["sides"]) 43 | states_embedding = self.state_embedding(batch["states"]) 44 | embeddings = ( 45 | category_embeddings 46 | + boxes_embeddings 47 | + score_embeddings 48 | + sides_embedding 49 | + states_embedding 50 | ) 51 | embeddings = self.layer_norm(embeddings) 52 | embeddings = self.dropout(embeddings) 53 | 54 | return embeddings 55 | 56 | 57 | class SpatialTransformer(nn.Module): 58 | def __init__(self): 59 | super(SpatialTransformer, self).__init__() 60 | # Fixing params 61 | self.hidden_size = 768 62 | self.hidden_dropout_prob = 0.1 63 | self.num_attention_heads = 8 64 | self.num_layers = 6 65 | # Rest 66 | self.category_box_embeddings = CategoryBoxEmbeddings() 67 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size)) 68 | self.transformer = nn.TransformerEncoder( 69 | encoder_layer=nn.TransformerEncoderLayer( 70 | d_model=self.hidden_size, 71 | nhead=self.num_attention_heads, 72 | dim_feedforward=self.hidden_size * 4, 73 | dropout=self.hidden_dropout_prob, 74 | activation="gelu", 75 | ), 76 | num_layers=self.num_layers, 77 | ) 78 | 79 | def forward(self, batch: Dict[str, torch.Tensor]): 80 | # [Batch size, Num. frames, Num. boxes, Hidden size] 81 | cb_embeddings = self.category_box_embeddings(batch) 82 | bs, nf, nb, hs = cb_embeddings.size() 83 | # Add CLS token 84 | cb_embeddings = torch.cat( 85 | (self.cls_token.expand(bs, nf, -1, -1), cb_embeddings), dim=2 86 | ) 87 | src_key_padding_mask_boxes = torch.cat( 88 | ( 89 | torch.zeros(bs, nf, 1, dtype=torch.bool, device=cb_embeddings.device), 90 | batch["src_key_padding_mask_boxes"], 91 | ), 92 | dim=2, 93 | ) 94 | # [Batch size * Num. frames, Num. boxes, Hidden size] 95 | cb_embeddings = cb_embeddings.flatten(0, 1) 96 | src_key_padding_mask_boxes = src_key_padding_mask_boxes.flatten(0, 1) 97 | # [Num. boxes, Batch size * Num. frames, Hidden size] 98 | cb_embeddings = cb_embeddings.transpose(0, 1) 99 | # [Num. boxes, Batch size * Num. frames, Hidden size] 100 | layout_embeddings = self.transformer( 101 | src=cb_embeddings, 102 | src_key_padding_mask=src_key_padding_mask_boxes, 103 | ) 104 | # [Batch size * Num. frames, Num. boxes, Hidden size] 105 | layout_embeddings = layout_embeddings.transpose(0, 1) 106 | # [Batch size, Num. frames, Num. boxes, Hidden size] 107 | layout_embeddings = layout_embeddings.view(bs, nf, nb + 1, hs) 108 | # [Batch size, Num. 
frames, Hidden size] 109 | layout_embeddings = layout_embeddings[:, :, 0, :] 110 | 111 | return layout_embeddings 112 | 113 | 114 | class TemporalTransformer(nn.Module): 115 | def __init__(self, cfg: CfgNode): 116 | super(TemporalTransformer, self).__init__() 117 | # Fixing params 118 | self.hidden_size = 768 119 | self.hidden_dropout_prob = 0.1 120 | self.layer_norm_eps = 0.1 121 | self.num_attention_heads = 8 122 | self.num_layers = 6 123 | self.num_frames = cfg.NUM_FRAMES 124 | # Rest 125 | self.layout_embedding = SpatialTransformer() 126 | self.position_embeddings = nn.Embedding(self.num_frames, self.hidden_size) 127 | self.layer_norm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) 128 | self.dropout = nn.Dropout(self.hidden_dropout_prob) 129 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size)) 130 | # Temporal Transformer 131 | self.transformer = nn.TransformerEncoder( 132 | encoder_layer=nn.TransformerEncoderLayer( 133 | d_model=self.hidden_size, 134 | nhead=self.num_attention_heads, 135 | dim_feedforward=self.hidden_size * 4, 136 | dropout=self.hidden_dropout_prob, 137 | activation="gelu", 138 | ), 139 | num_layers=self.num_layers, 140 | ) 141 | 142 | def forward(self, batch: Dict[str, torch.Tensor]): 143 | # [Batch size, Num. frames, Hidden size] 144 | layout_embeddings = self.layout_embedding(batch) 145 | bs, nf, _ = layout_embeddings.size() 146 | position_embeddings = self.position_embeddings( 147 | torch.arange(nf, device=layout_embeddings.device).expand(1, -1) 148 | ) 149 | # Preparing everything together 150 | embeddings = layout_embeddings + position_embeddings 151 | embeddings = self.dropout(self.layer_norm(embeddings)) 152 | # Concatenate with CLS token 153 | embeddings = torch.cat((embeddings, self.cls_token.expand(bs, -1, -1)), dim=1) 154 | # [Num. frames, Batch size, Hidden size] 155 | embeddings = embeddings.transpose(0, 1) 156 | # [Num. 
frames, Batch size, Hidden size] 157 | causal_mask = generate_square_subsequent_mask(embeddings.size(0)).to( 158 | embeddings.device 159 | ) 160 | layout_embeddings = self.transformer(src=embeddings, mask=causal_mask) 161 | # [Batch size, Hidden size] 162 | layout_embeddings = layout_embeddings[-1, :, :] 163 | # Make contiguous 164 | layout_embeddings = layout_embeddings.contiguous() 165 | 166 | return layout_embeddings 167 | 168 | 169 | class Stlt(nn.Module): 170 | def __init__(self, cfg: CfgNode): 171 | super(Stlt, self).__init__() 172 | # Fixing params 173 | self.cfg = cfg 174 | self.hidden_size = 768 175 | # Rest 176 | self.temporal_transformer = TemporalTransformer(self.cfg) 177 | # Build classifier 178 | self.classifiers = nn.ModuleDict( 179 | { 180 | actions_name: nn.Linear(self.hidden_size, actions_num) 181 | for actions_name, actions_num in self.cfg.TOTAL_ACTIONS.items() 182 | if actions_num is not None 183 | } 184 | ) 185 | 186 | def forward(self, batch: Dict[str, torch.Tensor]): 187 | output = {} 188 | # Get features 189 | features = self.temporal_transformer(batch) 190 | for actions_name in self.classifiers.keys(): 191 | output[actions_name] = self.classifiers[actions_name](features) 192 | 193 | return output 194 | -------------------------------------------------------------------------------- /src/patient_distill.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from os.path import join as pjoin 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from accelerate import Accelerator 8 | from torch import nn, optim 9 | from torch.utils.data import DataLoader, Subset 10 | from tqdm import tqdm 11 | from yacs.config import CfgNode 12 | 13 | from modelling.datasets import dataset_factory 14 | from modelling.distiller import DistillationCriterion, TeacherEnsemble 15 | from modelling.models import model_factory 16 | from utils.data_utils import separator 17 | from utils.evaluation import evaluators_factory 18 | from utils.setup import train_setup 19 | from utils.train_utils import ( 20 | EpochHandler, 21 | get_linear_schedule_with_warmup, 22 | move_batch_to_device, 23 | ) 24 | 25 | 26 | @torch.no_grad() 27 | def set_teacher_weights( 28 | cfg: CfgNode, teacher: nn.Module, loader: DataLoader, accelerator: Accelerator 29 | ): 30 | teacher = teacher.to(cfg.DEVICE) 31 | total = 0 32 | for batch in tqdm(loader, disable=not accelerator.is_main_process): 33 | batch = move_batch_to_device(batch, device=cfg.DEVICE) 34 | teacher_logits = teacher.get_teacher_logits(batch) 35 | for action_name in teacher_logits.keys(): 36 | # [Batch_Size, Num_Teachers, Num_Classes] 37 | logits = torch.cat( 38 | [t_o.unsqueeze(1) for t_o in teacher_logits[action_name]], dim=1 39 | ) 40 | losses = teacher.get_losses_wrt_labels(logits, batch["labels"][action_name]) 41 | teacher.weights[action_name] = teacher.weights[action_name] + losses.sum(0) 42 | # Measure number of batches 43 | total += losses.size(0) 44 | # Find average weights 45 | for action_name in teacher.weights.keys(): 46 | teacher.weights[action_name] = teacher.weights[action_name] / total 47 | # Convert to weights - normalize 48 | teacher.weights[action_name] = F.softmin( 49 | teacher.weights[action_name] / cfg.WEIGHTS_TEMPERATURE, dim=-1 50 | ) 51 | # Return to CPU 52 | teacher = teacher.cpu() 53 | 54 | return teacher 55 | 56 | 57 | def distill(cfg: CfgNode, accelerator: Accelerator): 58 | if cfg.LOG_TO_FILE: 59 | if accelerator.is_main_process: 60 | 
logging.basicConfig( 61 | level=logging.INFO, 62 | filename=pjoin(cfg.EXPERIMENT_PATH, "experiment_log.log"), 63 | filemode="a", 64 | ) 65 | else: 66 | logging.basicConfig(level=logging.INFO) 67 | if accelerator.is_main_process: 68 | logging.info(separator) 69 | logging.info(f"The config file is:\n {cfg}") 70 | logging.info(separator) 71 | # Prepare datasets 72 | if accelerator.is_main_process: 73 | logging.info("Preparing datasets...") 74 | # Prepare train dataset 75 | train_dataset = dataset_factory[cfg.DATASET_TYPE](cfg, train=True) 76 | num_training_samples = len(train_dataset) 77 | # Prepare validation dataset 78 | val_dataset = dataset_factory[cfg.DATASET_TYPE](cfg, train=False) 79 | if cfg.VAL_SUBSET: 80 | val_indices = random.sample(range(len(val_dataset)), cfg.VAL_SUBSET) 81 | val_dataset = Subset(val_dataset, val_indices) 82 | num_validation_samples = len(val_dataset) 83 | if accelerator.is_main_process: 84 | logging.info(f"Training on {num_training_samples}") 85 | logging.info(f"Validating on {num_validation_samples}") 86 | # Prepare loaders 87 | train_loader = DataLoader( 88 | train_dataset, 89 | batch_size=cfg.BATCH_SIZE, 90 | shuffle=True, 91 | num_workers=cfg.NUM_WORKERS, 92 | pin_memory=True if cfg.NUM_WORKERS else False, 93 | ) 94 | val_loader = DataLoader( 95 | val_dataset, 96 | batch_size=cfg.BATCH_SIZE, 97 | shuffle=False, 98 | num_workers=cfg.NUM_WORKERS, 99 | pin_memory=True if cfg.NUM_WORKERS else False, 100 | ) 101 | if accelerator.is_main_process: 102 | logging.info("Preparing teacher...") 103 | teacher = TeacherEnsemble(cfg) 104 | # Check if weighted distillation 105 | if cfg.DISTILLATION_WEIGHTING_SCHEME == "full-dataset": 106 | if accelerator.is_main_process: 107 | logging.info("Weighted distillation, setting weights...") 108 | import json 109 | 110 | train_indices_path = pjoin( 111 | cfg.TEACHERS.RGB_TEACHER_EXPERIMENT_PATH, "train_indices.json" 112 | ) 113 | train_indices = set(json.load(open(train_indices_path, "r"))) 114 | reverse_indices = [ 115 | i for i in range(len(train_dataset)) if i not in train_indices 116 | ] 117 | weighting_dataset = Subset(train_dataset.set_weighted(), reverse_indices) 118 | weighting_loader = DataLoader( 119 | weighting_dataset, 120 | shuffle=False, 121 | batch_size=cfg.BATCH_SIZE, 122 | num_workers=cfg.NUM_WORKERS, 123 | pin_memory=True if cfg.NUM_WORKERS else False, 124 | ) 125 | teacher = set_teacher_weights( 126 | cfg=cfg, teacher=teacher, loader=weighting_loader, accelerator=accelerator 127 | ) 128 | if accelerator.is_main_process: 129 | for action_name in teacher.weights.keys(): 130 | weights = teacher.weights[action_name] 131 | logging.info(f"For action {action_name}, the weights are: {weights}") 132 | # Wait for all devices to calibrate... 
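        # When the full-dataset scheme is active, teacher.weights now holds, per
        # action head, a softmin (with temperature cfg.WEIGHTS_TEMPERATURE) over
        # each teacher's average cross-entropy loss on the weighting subset, so
        # lower-loss teachers receive larger distillation weights.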
133 | accelerator.wait_for_everyone() 134 | # Preparing student 135 | if accelerator.is_main_process: 136 | logging.info("Preparing student...") 137 | # Prepare model 138 | student = model_factory[cfg.MODEL_NAME](cfg) 139 | criterion = DistillationCriterion(cfg) 140 | # Optimizer, scheduler, evaluator & loss 141 | optimizer = optim.AdamW( 142 | student.parameters(), 143 | lr=cfg.LEARNING_RATE, 144 | weight_decay=cfg.WEIGHT_DECAY, 145 | ) 146 | num_batches = num_training_samples // cfg.BATCH_SIZE 147 | scheduler = get_linear_schedule_with_warmup( 148 | optimizer, 149 | num_warmup_steps=cfg.WARMUP_EPOCHS * num_batches, 150 | num_training_steps=cfg.EPOCHS * num_batches, 151 | ) 152 | evaluator = evaluators_factory[cfg.VAL_DATASET_NAME](num_validation_samples, cfg) 153 | # Accelerate 154 | ( 155 | student, 156 | teacher, 157 | optimizer, 158 | train_loader, 159 | val_loader, 160 | scheduler, 161 | ) = accelerator.prepare( 162 | student, teacher, optimizer, train_loader, val_loader, scheduler 163 | ) 164 | # https://huggingface.co/docs/accelerate/quicktour#savingloading-entire-states 165 | # Check for starting from existing checkpoint 166 | epoch_handler = EpochHandler() 167 | if cfg.WARM_RESTART: 168 | if accelerator.is_main_process: 169 | logging.info("Performing Warm Restart!") 170 | accelerator.load_state(pjoin(cfg.EXPERIMENT_PATH, "full-checkpoint")) 171 | epoch_handler.load_state(pjoin(cfg.EXPERIMENT_PATH, "last-epoch")) 172 | # Starting training 173 | if accelerator.is_main_process: 174 | logging.info("Starting training...") 175 | for epoch in range(epoch_handler.epoch, cfg.EPOCHS): 176 | # Training loop 177 | student.train(True) 178 | with tqdm( 179 | total=len(train_loader), disable=not accelerator.is_main_process 180 | ) as pbar: 181 | for batch in train_loader: 182 | # Remove past gradients 183 | optimizer.zero_grad() 184 | # Get outputs 185 | student_logits = student(batch) 186 | teacher_logits = teacher(batch) 187 | # Measure loss 188 | loss = criterion(student_logits, teacher_logits, batch) 189 | # Backpropagate 190 | accelerator.backward(loss) 191 | accelerator.clip_grad_norm_(student.parameters(), cfg.CLIP_VAL) 192 | optimizer.step() 193 | # Update the scheduler 194 | scheduler.step() 195 | # Update progress bar 196 | pbar.update(1) 197 | pbar.set_postfix({"Loss": loss.item()}) 198 | # Validation loop 199 | student.train(False) 200 | evaluator.reset() 201 | for batch in tqdm(val_loader, disable=not accelerator.is_main_process): 202 | with torch.no_grad(): 203 | # Obtain outputs: [b * n_clips, n_actions] 204 | student_output = student(batch) 205 | all_outputs = accelerator.gather(student_output) 206 | all_labels = accelerator.gather(batch["labels"]) 207 | # Reshape outputs & put on cpu 208 | for key in all_outputs.keys(): 209 | num_classes = all_outputs[key].size(-1) 210 | # Reshape 211 | all_outputs[key] = all_outputs[key].reshape( 212 | -1, cfg.NUM_TEST_CLIPS * cfg.NUM_TEST_CROPS, num_classes 213 | ) 214 | # Move on CPU 215 | all_outputs[key] = all_outputs[key].cpu() 216 | # Put labels on cpu 217 | for key in all_labels.keys(): 218 | all_labels[key] = all_labels[key].cpu() 219 | # Pass to evaluator 220 | evaluator.process(all_outputs, all_labels) 221 | # Evaluate & save model 222 | accelerator.wait_for_everyone() 223 | if accelerator.is_main_process: 224 | metrics = evaluator.evaluate() 225 | if evaluator.is_best(): 226 | logging.info(separator) 227 | logging.info(f"Found new best on epoch {epoch+1}!") 228 | logging.info(separator) 229 | unwrapped_student = 
accelerator.unwrap_model(student) 230 | accelerator.save( 231 | unwrapped_student.state_dict(), 232 | pjoin(cfg.EXPERIMENT_PATH, "model_checkpoint.pt"), 233 | ) 234 | for m in metrics.keys(): 235 | logging.info(f"{m}: {metrics[m]}") 236 | # Update epoch handler 237 | epoch_handler.set_epoch(epoch + 1) 238 | # Save/Overwrite full checkpoint 239 | logging.info(f"Saving/Overwriting full checkpoint at epoch {epoch+1}") 240 | accelerator.save_state(pjoin(cfg.EXPERIMENT_PATH, "full-checkpoint")) 241 | epoch_handler.save_state(pjoin(cfg.EXPERIMENT_PATH, "last-epoch")) 242 | 243 | 244 | def main(): 245 | cfg, accelerator = train_setup("Performs multimodal knowledge distillation.") 246 | distill(cfg, accelerator) 247 | 248 | 249 | if __name__ == "__main__": 250 | main() 251 | -------------------------------------------------------------------------------- /src/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from typing import Dict 3 | 4 | import numpy as np 5 | import torch 6 | from yacs.config import CfgNode 7 | 8 | 9 | class EvaluatorAccuracy: 10 | def __init__(self, total_instances: int): 11 | self.total_instances = total_instances 12 | # FIXME: Hacky solution because of multi-GPU evaluation 13 | self.data = { 14 | "top1_corr": np.zeros((int(self.total_instances * 1.1))), 15 | "top5_corr": np.zeros((int(self.total_instances * 1.1))), 16 | } 17 | self.best_acc = 0.0 18 | self.processed_instances = 0 19 | 20 | def reset(self): 21 | self.data = {} 22 | # FIXME: Hacky solution because of multi-GPU evaluation 23 | self.data = { 24 | "top1_corr": np.zeros((int(self.total_instances * 1.1))), 25 | "top5_corr": np.zeros((int(self.total_instances * 1.1))), 26 | } 27 | self.processed_instances = 0 28 | 29 | def process(self, logits: torch.Tensor, labels: torch.Tensor): 30 | assert ( 31 | len(logits.shape) == 3 32 | ), "The shape of logits must be in format [bs, num_test_clips * num_test_crops, total_classes]" 33 | num_instances = logits.shape[0] 34 | logits = logits.mean(1) 35 | self.data["top1_corr"][ 36 | self.processed_instances : self.processed_instances + num_instances 37 | ] = (logits.argmax(-1) == labels).int() 38 | self.data["top5_corr"][ 39 | self.processed_instances : self.processed_instances + num_instances 40 | ] = ((logits.topk(k=5).indices == labels.unsqueeze(1)).any(dim=1)).int() 41 | self.processed_instances += num_instances 42 | 43 | def evaluate(self): 44 | top1_acc = ( 45 | self.data["top1_corr"].sum() / self.processed_instances 46 | if self.processed_instances 47 | else 0.0 48 | ) 49 | top5_acc = ( 50 | self.data["top5_corr"].sum() / self.processed_instances 51 | if self.processed_instances 52 | else 0.0 53 | ) 54 | metrics = { 55 | "top1_acc": round(top1_acc * 100, 2), 56 | "top5_acc": round(top5_acc * 100, 2), 57 | } 58 | 59 | return metrics 60 | 61 | def evaluate_verbose(self): 62 | metrics = {} 63 | # Get metrics 64 | for m_name, m_value in self.evaluate().items(): 65 | conf = ( 66 | (m_value * (100 - m_value) / self.processed_instances) ** 0.5 67 | if self.processed_instances 68 | else 0.0 69 | ) 70 | conf = round(conf, 2) 71 | metrics[m_name] = f"{m_value} +/- {conf}" 72 | 73 | return metrics 74 | 75 | def is_best(self): 76 | metrics = self.evaluate() 77 | # Get currect accuracy 78 | cur_accuracy = sum( 79 | [metrics[accuracy_type] for accuracy_type in metrics.keys()] 80 | ) / len(metrics) 81 | # Validate whether it's the best model 82 | if cur_accuracy > self.best_acc: 83 | self.best_acc = cur_accuracy 84 | 
return True 85 | return False 86 | 87 | 88 | class EvaluatorAND: 89 | def __init__(self, evaluator1: EvaluatorAccuracy, evaluator2: EvaluatorAccuracy): 90 | self.evaluator1 = evaluator1 91 | self.evaluator2 = evaluator2 92 | assert ( 93 | self.evaluator1.total_instances == self.evaluator2.total_instances 94 | ), "Two evaluators need to have the same number of total instances." 95 | 96 | def evaluate(self): 97 | assert ( 98 | self.evaluator1.processed_instances == self.evaluator2.processed_instances 99 | ), "Two evaluators need to have processed the same number of instances." 100 | processed_instances = self.evaluator1.processed_instances 101 | corrects1 = self.evaluator1.data["top1_corr"] 102 | corrects2 = self.evaluator2.data["top1_corr"] 103 | 104 | corrects = ((corrects1 + corrects2) / 2 == 1).astype(int) 105 | accuracy = corrects.sum() / processed_instances 106 | metrics = {"top1_acc": round(accuracy * 100, 2)} 107 | 108 | return metrics 109 | 110 | def evaluate_verbose(self): 111 | assert ( 112 | self.evaluator1.processed_instances == self.evaluator2.processed_instances 113 | ), "Two evaluators need to have processed the same number of instances." 114 | processed_instances = self.evaluator1.processed_instances 115 | m_value = self.evaluate()["top1_acc"] 116 | conf = ( 117 | (m_value * (100 - m_value) / processed_instances) ** 0.5 118 | if processed_instances 119 | else 0.0 120 | ) 121 | conf = round(conf, 2) 122 | metrics = {"top1_acc": f"{m_value} +/- {conf}"} 123 | 124 | return metrics 125 | 126 | 127 | class EvaluatorSomething(EvaluatorAccuracy): 128 | def __init__(self, total_instances, cfg: CfgNode): 129 | self.cfg = cfg 130 | super().__init__( 131 | total_instances, 132 | ) 133 | 134 | def process( 135 | self, model_output: Dict[str, torch.Tensor], labels: Dict[str, torch.Tensor] 136 | ): 137 | # Prepare logits & labels 138 | logits = model_output["ACTION"] 139 | labels = labels["ACTION"] 140 | super().process(logits, labels) 141 | 142 | 143 | class EvaluatorEpic: 144 | def __init__(self, total_instances: int, cfg: CfgNode): 145 | self.cfg = cfg 146 | self.noun_evaluator = EvaluatorAccuracy( 147 | total_instances, 148 | ) 149 | self.verb_evaluator = EvaluatorAccuracy( 150 | total_instances, 151 | ) 152 | self.action_evaluator = EvaluatorAND(self.noun_evaluator, self.verb_evaluator) 153 | self.best_acc = 0.0 154 | # Prepare EPIC tail nouns, if provided 155 | if self.cfg.EPIC_TAIL_NOUNS_PATH is not None: 156 | assert ( 157 | cfg.DATASET_VERSION == 100 158 | ), "Tail classes only available on 'EPIC-KITCHENS 100'" 159 | with open(cfg.EPIC_TAIL_NOUNS_PATH, newline="") as csvfile: 160 | reader = csv.reader(csvfile, delimiter=" ", quotechar="|") 161 | tail_nouns = list(reader)[1:] 162 | self.tail_nouns = torch.tensor( 163 | [int(item[0]) for item in tail_nouns] 164 | ).unsqueeze(0) 165 | self.tail_noun_evaluator = EvaluatorAccuracy( 166 | total_instances, 167 | ) 168 | # Prepare EPIC tail verbs, if provided 169 | if self.cfg.EPIC_TAIL_VERBS_PATH is not None: 170 | assert ( 171 | cfg.DATASET_VERSION == 100 172 | ), "Tail classes only available on 'EPIC-KITCHENS 100'" 173 | with open(cfg.EPIC_TAIL_VERBS_PATH, newline="") as csvfile: 174 | reader = csv.reader(csvfile, delimiter=" ", quotechar="|") 175 | tail_verbs = list(reader)[1:] 176 | self.tail_verbs = torch.tensor( 177 | [int(item[0]) for item in tail_verbs] 178 | ).unsqueeze(0) 179 | self.tail_verb_evaluator = EvaluatorAccuracy( 180 | total_instances, 181 | ) 182 | 183 | if hasattr(self, "tail_noun_evaluator") and hasattr( 184 | 
self, "tail_verb_evaluator" 185 | ): 186 | self.tail_noun_evaluator_aux_AND = EvaluatorAccuracy( 187 | total_instances, 188 | ) 189 | self.tail_verb_evaluator_aux_AND = EvaluatorAccuracy( 190 | total_instances, 191 | ) 192 | self.tail_action_evaluator_AND = EvaluatorAND( 193 | self.tail_noun_evaluator_aux_AND, self.tail_verb_evaluator_aux_AND 194 | ) 195 | self.tail_noun_evaluator_aux_OR = EvaluatorAccuracy( 196 | total_instances, 197 | ) 198 | self.tail_verb_evaluator_aux_OR = EvaluatorAccuracy( 199 | total_instances, 200 | ) 201 | self.tail_action_evaluator_OR = EvaluatorAND( 202 | self.tail_noun_evaluator_aux_OR, self.tail_verb_evaluator_aux_OR 203 | ) 204 | 205 | def reset(self): 206 | self.noun_evaluator.reset() 207 | self.verb_evaluator.reset() 208 | if hasattr(self, "tail_noun_evaluator"): 209 | self.tail_noun_evaluator.reset() 210 | if hasattr(self, "tail_verb_evaluator"): 211 | self.tail_verb_evaluator.reset() 212 | if hasattr(self, "tail_action_evaluator_AND"): 213 | self.tail_noun_evaluator_aux_AND.reset() 214 | self.tail_verb_evaluator_aux_AND.reset() 215 | if hasattr(self, "tail_action_evaluator_OR"): 216 | self.tail_noun_evaluator_aux_OR.reset() 217 | self.tail_verb_evaluator_aux_OR.reset() 218 | 219 | def process( 220 | self, model_output: Dict[str, torch.Tensor], labels: Dict[str, torch.Tensor] 221 | ): 222 | self.noun_evaluator.process(model_output["NOUN"], labels["NOUN"]) 223 | self.verb_evaluator.process(model_output["VERB"], labels["VERB"]) 224 | if hasattr(self, "tail_noun_evaluator"): 225 | labels_noun = labels["NOUN"].unsqueeze(1) 226 | active_batch_indices_noun = (labels_noun == self.tail_nouns).any(dim=1) 227 | self.tail_noun_evaluator.process( 228 | model_output["NOUN"][active_batch_indices_noun], 229 | labels["NOUN"][active_batch_indices_noun], 230 | ) 231 | if hasattr(self, "tail_verb_evaluator"): 232 | labels_verb = labels["VERB"].unsqueeze(1) 233 | active_batch_indices_verb = (labels_verb == self.tail_verbs).any(dim=1) 234 | self.tail_verb_evaluator.process( 235 | model_output["VERB"][active_batch_indices_verb], 236 | labels["VERB"][active_batch_indices_verb], 237 | ) 238 | if hasattr(self, "tail_action_evaluator_AND"): 239 | actve_batch_indices_action_AND = torch.logical_and( 240 | active_batch_indices_noun, active_batch_indices_verb 241 | ) 242 | self.tail_noun_evaluator_aux_AND.process( 243 | model_output["NOUN"][actve_batch_indices_action_AND], 244 | labels["NOUN"][actve_batch_indices_action_AND], 245 | ) 246 | self.tail_verb_evaluator_aux_AND.process( 247 | model_output["VERB"][actve_batch_indices_action_AND], 248 | labels["VERB"][actve_batch_indices_action_AND], 249 | ) 250 | 251 | if hasattr(self, "tail_action_evaluator_OR"): 252 | actve_batch_indices_action_OR = torch.logical_or( 253 | active_batch_indices_noun, active_batch_indices_verb 254 | ) 255 | self.tail_noun_evaluator_aux_OR.process( 256 | model_output["NOUN"][actve_batch_indices_action_OR], 257 | labels["NOUN"][actve_batch_indices_action_OR], 258 | ) 259 | self.tail_verb_evaluator_aux_OR.process( 260 | model_output["VERB"][actve_batch_indices_action_OR], 261 | labels["VERB"][actve_batch_indices_action_OR], 262 | ) 263 | 264 | def evaluate(self): 265 | noun_metrics = self.noun_evaluator.evaluate() 266 | verb_metrics = self.verb_evaluator.evaluate() 267 | action_metrics = self.action_evaluator.evaluate() 268 | if hasattr(self, "tail_noun_evaluator"): 269 | tail_noun_metrics = self.tail_noun_evaluator.evaluate() 270 | else: 271 | tail_noun_metrics = {"top1_acc": "N/A", "top5_acc": "N/A"} 272 | if 
hasattr(self, "tail_verb_evaluator"): 273 | tail_verb_metrics = self.tail_verb_evaluator.evaluate() 274 | else: 275 | tail_verb_metrics = {"top1_acc": "N/A", "top5_acc": "N/A"} 276 | if hasattr(self, "tail_action_evaluator_AND"): 277 | tail_action_metrics_AND = self.tail_action_evaluator_AND.evaluate() 278 | else: 279 | tail_action_metrics_AND = {"top1_acc": "N/A"} 280 | if hasattr(self, "tail_action_evaluator_OR"): 281 | tail_action_metrics_OR = self.tail_action_evaluator_OR.evaluate() 282 | else: 283 | tail_action_metrics_OR = {"top1_acc": "N/A"} 284 | 285 | return { 286 | "noun_acc": noun_metrics["top1_acc"], 287 | "verb_acc": verb_metrics["top1_acc"], 288 | "action_acc": action_metrics["top1_acc"], 289 | "tail_noun_acc": tail_noun_metrics["top1_acc"], 290 | "tail_verb_acc": tail_verb_metrics["top1_acc"], 291 | "tail_action_AND_acc": tail_action_metrics_AND["top1_acc"], 292 | "tail_action_OR_acc": tail_action_metrics_OR["top1_acc"], 293 | } 294 | 295 | def evaluate_verbose(self): 296 | noun_metrics = self.noun_evaluator.evaluate_verbose() 297 | verb_metrics = self.verb_evaluator.evaluate_verbose() 298 | action_metrics = self.action_evaluator.evaluate_verbose() 299 | if hasattr(self, "tail_noun_evaluator"): 300 | tail_noun_metrics = self.tail_noun_evaluator.evaluate_verbose() 301 | else: 302 | tail_noun_metrics = {"top1_acc": "N/A", "top5_acc": "N/A"} 303 | if hasattr(self, "tail_verb_evaluator"): 304 | tail_verb_metrics = self.tail_verb_evaluator.evaluate_verbose() 305 | else: 306 | tail_verb_metrics = {"top1_acc": "N/A", "top5_acc": "N/A"} 307 | if hasattr(self, "tail_action_evaluator_AND"): 308 | tail_action_metrics_AND = self.tail_action_evaluator_AND.evaluate_verbose() 309 | else: 310 | tail_action_metrics_AND = {"top1_acc": "N/A"} 311 | if hasattr(self, "tail_action_evaluator_OR"): 312 | tail_action_metrics_OR = self.tail_action_evaluator_OR.evaluate_verbose() 313 | else: 314 | tail_action_metrics_OR = {"top1_acc": "N/A"} 315 | 316 | return { 317 | "noun_acc": noun_metrics["top1_acc"], 318 | "verb_acc": verb_metrics["top1_acc"], 319 | "action_acc": action_metrics["top1_acc"], 320 | "tail_noun_acc": tail_noun_metrics["top1_acc"], 321 | "tail_verb_acc": tail_verb_metrics["top1_acc"], 322 | "tail_action_AND_acc": tail_action_metrics_AND["top1_acc"], 323 | "tail_action_OR_acc": tail_action_metrics_OR["top1_acc"], 324 | } 325 | 326 | def is_best(self): 327 | metrics = self.evaluate() 328 | # Validate whether it's the best model 329 | if metrics["action_acc"] > self.best_acc: 330 | self.best_acc = metrics["action_acc"] 331 | return True 332 | return False 333 | 334 | 335 | evaluators_factory = { 336 | "something-something": EvaluatorSomething, 337 | "egtea-gaze": EvaluatorSomething, 338 | "EPIC-KITCHENS": EvaluatorEpic, 339 | "montalbano": EvaluatorSomething, 340 | } 341 | -------------------------------------------------------------------------------- /src/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | from copy import deepcopy 4 | from typing import List, Tuple 5 | 6 | import ffmpeg 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | import torchaudio.transforms as TAT 11 | from PIL import Image 12 | from torch.nn.modules.utils import _pair 13 | from torchaudio.transforms import AmplitudeToDB 14 | from torchvision.transforms import ColorJitter, Compose, Normalize, RandomCrop, Resize 15 | from torchvision.transforms import functional as TF 16 | from yacs.config 
import CfgNode 17 | 18 | spatial_sizes = { 19 | "resnet3d": 224, # HACK, because we resize inside the model to 112 20 | "video_mae": 224, 21 | "swin": 224, 22 | } 23 | 24 | 25 | def load_video(in_filepath: str): 26 | """Loads a video from a filepath.""" 27 | probe = ffmpeg.probe(in_filepath) 28 | video_stream = next( 29 | (stream for stream in probe["streams"] if stream["codec_type"] == "video"), 30 | None, 31 | ) 32 | width = int(video_stream["width"]) 33 | height = int(video_stream["height"]) 34 | out, _ = ( 35 | ffmpeg.input(in_filepath) 36 | .output("pipe:", format="rawvideo", pix_fmt="rgb24") 37 | # https://github.com/kkroening/ffmpeg-python/issues/68#issuecomment-443752014 38 | .global_args("-loglevel", "error") 39 | .run(capture_stdout=True) 40 | ) 41 | video = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3]) 42 | 43 | return video 44 | 45 | 46 | class IdentityTransform: 47 | def __init__(self, **kwargs): 48 | pass 49 | 50 | def __call__(self, img): 51 | return img 52 | 53 | 54 | class ToTensor: 55 | def __init__(self, **kwargs): 56 | pass 57 | 58 | def __call__(self, img): 59 | return TF.to_tensor(img) 60 | 61 | 62 | class VideoNormalize: 63 | def __init__(self, model_name: str, **kwargs): 64 | normalizations = { 65 | # Orig: mean: (123.675, 116.28, 103.53); 66 | # std: (58.395, 57.12, 57.375), we divide by 67 | # 255 because our images are normalized by 255 68 | "swin": {"mean": (0.4850, 0.4560, 0.4060), "std": (0.2290, 0.2240, 0.2250)}, 69 | "resnet3d": {"mean": (0.5, 0.5, 0.5), "std": (0.5, 0.5, 0.5)}, 70 | # https://rwightman.github.io/pytorch-image-models/models/vision-transformer/ 71 | "video_mae": {"mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)}, 72 | } 73 | self.normalize = Normalize( 74 | mean=normalizations[model_name]["mean"], 75 | std=normalizations[model_name]["std"], 76 | ) 77 | 78 | def __call__(self, img: torch.Tensor): 79 | return self.normalize(img) 80 | 81 | 82 | class VideoColorJitter: 83 | # Adapted from: https://github.com/pytorch/vision/blob/main/torchvision/transforms/transforms.py#L1140 84 | def __init__(self, **kwargs): 85 | ( 86 | self.fn_idx, 87 | self.brightness_factor, 88 | self.contrast_factor, 89 | self.saturation_factor, 90 | self.hue_factor, 91 | ) = ColorJitter.get_params( 92 | brightness=(0.75, 1.25), 93 | contrast=(0.75, 1.25), 94 | saturation=(0.75, 1.25), 95 | hue=(-0.1, 0.1), 96 | ) 97 | 98 | def __call__(self, img: Image): 99 | for fn_id in self.fn_idx: 100 | if fn_id == 0 and self.brightness_factor is not None: 101 | img = TF.adjust_brightness(img, self.brightness_factor) 102 | elif fn_id == 1 and self.contrast_factor is not None: 103 | img = TF.adjust_contrast(img, self.contrast_factor) 104 | elif fn_id == 2 and self.saturation_factor is not None: 105 | img = TF.adjust_saturation(img, self.saturation_factor) 106 | elif fn_id == 3 and self.hue_factor is not None: 107 | img = TF.adjust_hue(img, self.hue_factor) 108 | 109 | return img 110 | 111 | 112 | class VideoRandomHorizontalFlip: 113 | def __init__(self, p: float, **kwargs): 114 | self.flip = torch.rand(1) < p 115 | 116 | def __call__(self, img): 117 | if self.flip: 118 | img = TF.hflip(img) 119 | return img 120 | 121 | 122 | class VideoRandomCrop: 123 | def __init__(self, spatial_size: int, **kwargs): 124 | self.spatial_size = (spatial_size, spatial_size) 125 | 126 | def __call__(self, frame: Image): 127 | if ( 128 | not hasattr(self, "top") 129 | and not hasattr(self, "left") 130 | and not hasattr(self, "height") 131 | and not hasattr(self, "width") 132 | ): 
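            # The crop parameters are sampled only on the first call and cached on the
            # transform, so every frame of the same clip receives an identical crop.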
133 | self.top, self.left, self.height, self.width = RandomCrop.get_params( 134 | frame, self.spatial_size 135 | ) 136 | 137 | return TF.crop(frame, self.top, self.left, self.height, self.width) 138 | 139 | 140 | class VideoResize: 141 | def __init__(self, spatial_size: int, train: bool = False, **kwargs): 142 | frame_size = ( 143 | int(spatial_size * random.uniform(1.15, 1.43)) if train else spatial_size 144 | ) 145 | self.resize = Resize(frame_size, antialias=True) 146 | 147 | def __call__(self, img: Image): 148 | return self.resize(img) 149 | 150 | 151 | class VideoInferenceCrop: 152 | def __init__(self, spatial_size: int, num_test_crops: int, **kwargs): 153 | self.spatial_size = spatial_size 154 | self.num_test_crops = num_test_crops 155 | assert self.num_test_crops in [1, 3], "Can be only 1 or 3" 156 | 157 | def __call__(self, img: torch.Tensor): 158 | if self.num_test_crops == 1: 159 | return TF.center_crop(img, self.spatial_size) 160 | # Three-crop inference. New dimension prepended and images are stacked along it. 161 | crop_size = _pair(self.spatial_size) 162 | img_h, img_w = img.shape[-2:] 163 | crop_w, crop_h = crop_size 164 | assert crop_h == img_h or crop_w == img_w 165 | 166 | if crop_h == img_h: 167 | w_step = (img_w - crop_w) // 2 168 | offsets = [ 169 | (0, 0), # left 170 | (2 * w_step, 0), # right 171 | (w_step, 0), # middle 172 | ] 173 | elif crop_w == img_w: 174 | h_step = (img_h - crop_h) // 2 175 | offsets = [ 176 | (0, 0), # top 177 | (0, 2 * h_step), # down 178 | (0, h_step), # middle 179 | ] 180 | 181 | img_cropped = [] 182 | for x_offset, y_offset in offsets: 183 | crop = img[..., y_offset : y_offset + crop_h, x_offset : x_offset + crop_w] 184 | img_cropped.append(crop) 185 | img_cropped = torch.stack(img_cropped, dim=0) 186 | return img_cropped 187 | 188 | 189 | class AudioResize: 190 | def __init__(self, spatial_size, **kwargs): 191 | self.spatial_size = spatial_size 192 | 193 | def __call__(self, img: torch.tensor): 194 | return F.interpolate(img, size=self.spatial_size) 195 | 196 | 197 | class AudioTimeStretch: 198 | def __init__( 199 | self, p: float = 0.5, spatial_size: int = 224, train: bool = False, **kwargs 200 | ): 201 | self.spatial_size = spatial_size 202 | self.train = train 203 | self.p = p 204 | self.transform = TAT.TimeStretch() if self.train else IdentityTransform() 205 | 206 | def __call__(self, img: Image): 207 | rate = random.uniform(0.85, 1.15) 208 | if random.uniform(0, 1) < self.p: 209 | return self.transform(img, rate) 210 | return img 211 | 212 | 213 | class AudioTimeMasking: 214 | def __init__( 215 | self, p: float = 0.5, time_mask_param: int = 80, train: bool = False, **kwargs 216 | ): 217 | self.train = train 218 | self.p = p 219 | self.transform = ( 220 | TAT.TimeMasking(time_mask_param=time_mask_param, iid_masks=True) 221 | if self.train 222 | else IdentityTransform() 223 | ) 224 | 225 | def __call__(self, img: torch.Tensor): 226 | if random.uniform(0, 1) < self.p: 227 | return self.transform(img) 228 | return img 229 | 230 | 231 | class AudioFrequencyMasking: 232 | def __init__( 233 | self, p: float = 0.5, freq_mask_param: int = 80, train: bool = False, **kwargs 234 | ): 235 | self.train = train 236 | self.p = p 237 | self.transform = ( 238 | TAT.FrequencyMasking(freq_mask_param=freq_mask_param, iid_masks=True) 239 | if self.train 240 | else IdentityTransform() 241 | ) 242 | 243 | def __call__(self, img: torch.Tensor): 244 | if random.uniform(0, 1) < self.p: 245 | return self.transform(img) 246 | return img 247 | 248 | 249 | class 
AudioNormalize: 250 | def __init__(self, **kwargs): 251 | self.mean = -28.1125 252 | self.std = 16.5627 253 | 254 | def __call__(self, audio: torch.Tensor): 255 | return (audio - self.mean) / self.std 256 | 257 | 258 | class AudioAmplitudeToDB: 259 | def __init__(self, **kwargs): 260 | self.amplitude_to_db = AmplitudeToDB() 261 | 262 | def __call__(self, audio: torch.Tensor): 263 | return self.amplitude_to_db(audio) 264 | 265 | 266 | def compile_transforms(transforms): 267 | return Compose(transforms.values()) 268 | 269 | 270 | augname2aug = { 271 | "AudioTimeStretch": AudioTimeStretch, 272 | "AudioTimeMasking": AudioTimeMasking, 273 | "AudioFrequencyMasking": AudioFrequencyMasking, 274 | "AudioResize": AudioResize, 275 | "AudioAmplitudeToDB": AudioAmplitudeToDB, 276 | "AudioNormalize": AudioNormalize, 277 | "IdentityTransform": IdentityTransform, 278 | "VideoColorJitter": VideoColorJitter, 279 | "VideoRandomHorizontalFlip": VideoRandomHorizontalFlip, 280 | "VideoRandomCrop": VideoRandomCrop, 281 | "VideoInferenceCrop": VideoInferenceCrop, 282 | "VideoResize": VideoResize, 283 | "ToTensor": ToTensor, 284 | "VideoNormalize": VideoNormalize, 285 | } 286 | 287 | 288 | def get_video_transforms( 289 | augmentations_list: List[str], cfg: CfgNode, train: bool = False, **kwargs 290 | ): 291 | aug_list = deepcopy(augmentations_list) 292 | existing_transforms = kwargs.pop("existing_transforms", {}) 293 | # Testing 294 | if not train: 295 | return { 296 | "VideoResize": VideoResize( 297 | spatial_size=spatial_sizes[cfg.MODEL_NAME], train=False 298 | ), 299 | "ToTensor": ToTensor(), 300 | "VideoInferenceCrop": VideoInferenceCrop( 301 | spatial_size=spatial_sizes[cfg.MODEL_NAME], 302 | num_test_crops=cfg.NUM_TEST_CROPS, 303 | ), 304 | } 305 | # During training, always add ToTensor 306 | aug_list.append("ToTensor") 307 | 308 | return { 309 | aug_name: augname2aug[aug_name]( 310 | spatial_size=spatial_sizes[cfg.MODEL_NAME], 311 | train=train, 312 | p=0.5, 313 | model_name=cfg.MODEL_NAME, 314 | **kwargs, 315 | ) 316 | if aug_name not in existing_transforms 317 | else existing_transforms[aug_name] 318 | for aug_name in aug_list 319 | } 320 | 321 | 322 | def get_audio_transforms( 323 | augmentations_list: List[str], cfg: CfgNode, train: bool = False, **kwargs 324 | ): 325 | aug_list = deepcopy(augmentations_list) 326 | existing_transforms = kwargs.pop("existing_transforms", {}) 327 | spatial_size = spatial_sizes[cfg.MODEL_NAME] 328 | # Testing 329 | if not train: 330 | return { 331 | "AudioResize": AudioResize( 332 | spatial_size=(spatial_size, spatial_size), train=False 333 | ), 334 | "AudioAmplitudeToDB": AudioAmplitudeToDB(), 335 | } 336 | # During training, always add AudioNormalize to the list 337 | aug_list.append("AudioAmplitudeToDB") 338 | return { 339 | aug_name: augname2aug[aug_name]( 340 | spatial_size=spatial_size, train=train, **kwargs 341 | ) 342 | if aug_name not in existing_transforms 343 | else existing_transforms[aug_name] 344 | for aug_name in aug_list 345 | } 346 | 347 | 348 | def get_normalizer(input_type: str, model_name: str): 349 | if input_type == "audio": 350 | return AudioNormalize(model_name=model_name) 351 | elif input_type == "video" or input_type == "flow": 352 | return VideoNormalize(model_name=model_name) 353 | 354 | raise ValueError(f"{input_type} not recognized!") 355 | 356 | 357 | def video2audio_indices(indices, video_fps, audio_sample_rate): 358 | return [int(index * audio_sample_rate / video_fps) for index in indices] 359 | 360 | 361 | def audio2video_indices(indices, 
video_fps, audio_sample_rate): 362 | return [int(index * video_fps / audio_sample_rate) for index in indices] 363 | 364 | 365 | def extract_audio_segments(audio_frames, segment_length, audio_indices): 366 | _, num_frames = audio_frames.shape 367 | 368 | audio_segments = [] 369 | for audio_index in audio_indices: 370 | centre_frame = audio_index 371 | left_frame = centre_frame - math.floor(segment_length / 2) 372 | right_frame = centre_frame + math.ceil(segment_length / 2) 373 | if left_frame < 0 and right_frame > num_frames: 374 | samples = torch.nn.functional.pad( 375 | audio_frames, pad=(abs(left_frame), right_frame - num_frames) 376 | ) 377 | elif left_frame < 0: 378 | samples = torch.nn.functional.pad(audio_frames, pad=(abs(left_frame), 0))[ 379 | :, :segment_length 380 | ] 381 | elif right_frame > num_frames: 382 | samples = torch.nn.functional.pad( 383 | audio_frames, pad=(0, right_frame - num_frames) 384 | )[:, -segment_length:] 385 | else: 386 | samples = audio_frames[:, left_frame:right_frame] 387 | audio_segments.append(samples) 388 | audio_segments = torch.cat(audio_segments, dim=0) 389 | return audio_segments 390 | 391 | 392 | def fix_box(box: List[int], video_size: Tuple[int, int]): 393 | # Cast box elements to integers 394 | box = [max(0, int(b)) for b in box] 395 | # If x1 > x2 or y1 > y2 switch (Hack) 396 | if box[0] > box[2]: 397 | box[0], box[2] = box[2], box[0] 398 | if box[1] > box[3]: 399 | box[1], box[3] = box[3], box[1] 400 | # Clamp to max size (Hack) 401 | if box[0] >= video_size[1]: 402 | box[0] = video_size[1] - 1 403 | if box[1] >= video_size[0]: 404 | box[1] = video_size[0] - 1 405 | if box[2] >= video_size[1]: 406 | box[2] = video_size[1] - 1 407 | if box[3] >= video_size[0]: 408 | box[3] = video_size[0] - 1 409 | # Fix if equal (Hack) 410 | if box[0] == box[2] and box[0] == 0: 411 | box[2] = 1 412 | if box[1] == box[3] and box[1] == 0: 413 | box[3] = 1 414 | if box[0] == box[2]: 415 | box[0] -= 1 416 | if box[1] == box[3]: 417 | box[1] -= 1 418 | return box 419 | 420 | 421 | separator = "=" * 40 422 | 423 | 424 | def rgb_to_flow_index(rgb_index: int, stride: int = 1): 425 | # https://github.com/epic-kitchens/epic-kitchens-download-scripts/issues/17#issuecomment-1222288006 426 | # https://github.com/epic-kitchens/epic-kitchens-55-lib/blob/7f2499aff5fdb62a66e6da92322e5c060ea4a414/epic_kitchens/video.py#L67 427 | # https://github.com/epic-kitchens/C1-Action-Recognition-TSN-TRN-TSM/blob/master/src/convert_rgb_to_flow_frame_idxs.py#L24 428 | return int(np.ceil(rgb_index / stride)) 429 | -------------------------------------------------------------------------------- /src/modelling/swin.py: -------------------------------------------------------------------------------- 1 | """ 2 | Credit to the official implementation: https://github.com/SwinTransformer/Video-Swin-Transformer 3 | """ 4 | 5 | 6 | from functools import lru_cache, reduce 7 | from operator import mul 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.utils.checkpoint as checkpoint 14 | from einops import rearrange 15 | from timm.models.layers import DropPath, trunc_normal_ 16 | 17 | 18 | class Mlp(nn.Module): 19 | """Multilayer perceptron.""" 20 | 21 | def __init__( 22 | self, 23 | in_features, 24 | hidden_features=None, 25 | out_features=None, 26 | act_layer=nn.GELU, 27 | drop=0.0, 28 | ): 29 | super().__init__() 30 | out_features = out_features or in_features 31 | hidden_features = hidden_features or in_features 32 | 
self.fc1 = nn.Linear(in_features, hidden_features) 33 | self.act = act_layer() 34 | self.fc2 = nn.Linear(hidden_features, out_features) 35 | self.drop = nn.Dropout(drop) 36 | 37 | def forward(self, x): 38 | x = self.fc1(x) 39 | x = self.act(x) 40 | x = self.drop(x) 41 | x = self.fc2(x) 42 | x = self.drop(x) 43 | return x 44 | 45 | 46 | def window_partition(x, window_size): 47 | """ 48 | Args: 49 | x: (B, D, H, W, C) 50 | window_size (tuple[int]): window size 51 | Returns: 52 | windows: (B*num_windows, window_size*window_size, C) 53 | """ 54 | B, D, H, W, C = x.shape 55 | x = x.view( 56 | B, 57 | D // window_size[0], 58 | window_size[0], 59 | H // window_size[1], 60 | window_size[1], 61 | W // window_size[2], 62 | window_size[2], 63 | C, 64 | ) 65 | windows = ( 66 | x.permute(0, 1, 3, 5, 2, 4, 6, 7) 67 | .contiguous() 68 | .view(-1, reduce(mul, window_size), C) 69 | ) 70 | return windows 71 | 72 | 73 | def window_reverse(windows, window_size, B, D, H, W): 74 | """ 75 | Args: 76 | windows: (B*num_windows, window_size, window_size, C) 77 | window_size (tuple[int]): Window size 78 | H (int): Height of image 79 | W (int): Width of image 80 | Returns: 81 | x: (B, D, H, W, C) 82 | """ 83 | x = windows.view( 84 | B, 85 | D // window_size[0], 86 | H // window_size[1], 87 | W // window_size[2], 88 | window_size[0], 89 | window_size[1], 90 | window_size[2], 91 | -1, 92 | ) 93 | x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1) 94 | return x 95 | 96 | 97 | def get_window_size(x_size, window_size, shift_size=None): 98 | use_window_size = list(window_size) 99 | if shift_size is not None: 100 | use_shift_size = list(shift_size) 101 | for i in range(len(x_size)): 102 | if x_size[i] <= window_size[i]: 103 | use_window_size[i] = x_size[i] 104 | if shift_size is not None: 105 | use_shift_size[i] = 0 106 | 107 | if shift_size is None: 108 | return tuple(use_window_size) 109 | else: 110 | return tuple(use_window_size), tuple(use_shift_size) 111 | 112 | 113 | class WindowAttention3D(nn.Module): 114 | """Window based multi-head self attention (W-MSA) module with relative position bias. 115 | It supports both of shifted and non-shifted window. 116 | Args: 117 | dim (int): Number of input channels. 118 | window_size (tuple[int]): The temporal length, height and width of the window. 119 | num_heads (int): Number of attention heads. 120 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 121 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 122 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 123 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 124 | """ 125 | 126 | def __init__( 127 | self, 128 | dim, 129 | window_size, 130 | num_heads, 131 | qkv_bias=False, 132 | qk_scale=None, 133 | attn_drop=0.0, 134 | proj_drop=0.0, 135 | ): 136 | 137 | super().__init__() 138 | self.dim = dim 139 | self.window_size = window_size # Wd, Wh, Ww 140 | self.num_heads = num_heads 141 | head_dim = dim // num_heads 142 | self.scale = qk_scale or head_dim ** -0.5 143 | 144 | # define a parameter table of relative position bias 145 | self.relative_position_bias_table = nn.Parameter( 146 | torch.zeros( 147 | (2 * window_size[0] - 1) 148 | * (2 * window_size[1] - 1) 149 | * (2 * window_size[2] - 1), 150 | num_heads, 151 | ) 152 | ) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH 153 | 154 | # get pair-wise relative position index for each token inside the window 155 | coords_d = torch.arange(self.window_size[0]) 156 | coords_h = torch.arange(self.window_size[1]) 157 | coords_w = torch.arange(self.window_size[2]) 158 | coords = torch.stack( 159 | torch.meshgrid(coords_d, coords_h, coords_w) 160 | ) # 3, Wd, Wh, Ww 161 | coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww 162 | relative_coords = ( 163 | coords_flatten[:, :, None] - coords_flatten[:, None, :] 164 | ) # 3, Wd*Wh*Ww, Wd*Wh*Ww 165 | relative_coords = relative_coords.permute( 166 | 1, 2, 0 167 | ).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 168 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 169 | relative_coords[:, :, 1] += self.window_size[1] - 1 170 | relative_coords[:, :, 2] += self.window_size[2] - 1 171 | 172 | relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * ( 173 | 2 * self.window_size[2] - 1 174 | ) 175 | relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 176 | relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww 177 | self.register_buffer("relative_position_index", relative_position_index) 178 | 179 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 180 | self.attn_drop = nn.Dropout(attn_drop) 181 | self.proj = nn.Linear(dim, dim) 182 | self.proj_drop = nn.Dropout(proj_drop) 183 | 184 | trunc_normal_(self.relative_position_bias_table, std=0.02) 185 | self.softmax = nn.Softmax(dim=-1) 186 | 187 | def forward(self, x, mask=None): 188 | """Forward function. 
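        Window-based multi-head self-attention with a learned relative position bias;
        when a mask is provided it is added to the attention logits to realize the
        shifted-window (SW-MSA) scheme.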
189 | Args: 190 | x: input features with shape of (num_windows*B, N, C) 191 | mask: (0/-inf) mask with shape of (num_windows, N, N) or None 192 | """ 193 | B_, N, C = x.shape 194 | qkv = ( 195 | self.qkv(x) 196 | .reshape(B_, N, 3, self.num_heads, C // self.num_heads) 197 | .permute(2, 0, 3, 1, 4) 198 | ) 199 | q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C 200 | 201 | q = q * self.scale 202 | attn = q @ k.transpose(-2, -1) 203 | 204 | relative_position_bias = self.relative_position_bias_table[ 205 | self.relative_position_index[:N, :N].reshape(-1) 206 | ].reshape( 207 | N, N, -1 208 | ) # Wd*Wh*Ww,Wd*Wh*Ww,nH 209 | relative_position_bias = relative_position_bias.permute( 210 | 2, 0, 1 211 | ).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww 212 | attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N 213 | 214 | if mask is not None: 215 | nW = mask.shape[0] 216 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( 217 | 1 218 | ).unsqueeze(0) 219 | attn = attn.view(-1, self.num_heads, N, N) 220 | attn = self.softmax(attn) 221 | else: 222 | attn = self.softmax(attn) 223 | 224 | attn = self.attn_drop(attn) 225 | 226 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 227 | x = self.proj(x) 228 | x = self.proj_drop(x) 229 | return x 230 | 231 | 232 | class SwinTransformerBlock3D(nn.Module): 233 | """Swin Transformer Block. 234 | Args: 235 | dim (int): Number of input channels. 236 | num_heads (int): Number of attention heads. 237 | window_size (tuple[int]): Window size. 238 | shift_size (tuple[int]): Shift size for SW-MSA. 239 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 240 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 241 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 242 | drop (float, optional): Dropout rate. Default: 0.0 243 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 244 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 245 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 246 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 247 | """ 248 | 249 | def __init__( 250 | self, 251 | dim, 252 | num_heads, 253 | window_size=(2, 7, 7), 254 | shift_size=(0, 0, 0), 255 | mlp_ratio=4.0, 256 | qkv_bias=True, 257 | qk_scale=None, 258 | drop=0.0, 259 | attn_drop=0.0, 260 | drop_path=0.0, 261 | act_layer=nn.GELU, 262 | norm_layer=nn.LayerNorm, 263 | use_checkpoint=False, 264 | ): 265 | super().__init__() 266 | self.dim = dim 267 | self.num_heads = num_heads 268 | self.window_size = window_size 269 | self.shift_size = shift_size 270 | self.mlp_ratio = mlp_ratio 271 | self.use_checkpoint = use_checkpoint 272 | 273 | assert ( 274 | 0 <= self.shift_size[0] < self.window_size[0] 275 | ), "shift_size must in 0-window_size" 276 | assert ( 277 | 0 <= self.shift_size[1] < self.window_size[1] 278 | ), "shift_size must in 0-window_size" 279 | assert ( 280 | 0 <= self.shift_size[2] < self.window_size[2] 281 | ), "shift_size must in 0-window_size" 282 | 283 | self.norm1 = norm_layer(dim) 284 | self.attn = WindowAttention3D( 285 | dim, 286 | window_size=self.window_size, 287 | num_heads=num_heads, 288 | qkv_bias=qkv_bias, 289 | qk_scale=qk_scale, 290 | attn_drop=attn_drop, 291 | proj_drop=drop, 292 | ) 293 | 294 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 295 | self.norm2 = norm_layer(dim) 296 | mlp_hidden_dim = int(dim * mlp_ratio) 297 | self.mlp = Mlp( 298 | in_features=dim, 299 | hidden_features=mlp_hidden_dim, 300 | act_layer=act_layer, 301 | drop=drop, 302 | ) 303 | 304 | def forward_part1(self, x, mask_matrix): 305 | B, D, H, W, C = x.shape 306 | window_size, shift_size = get_window_size( 307 | (D, H, W), self.window_size, self.shift_size 308 | ) 309 | 310 | x = self.norm1(x) 311 | # pad feature maps to multiples of window size 312 | pad_l = pad_t = pad_d0 = 0 313 | pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0] 314 | pad_b = (window_size[1] - H % window_size[1]) % window_size[1] 315 | pad_r = (window_size[2] - W % window_size[2]) % window_size[2] 316 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) 317 | _, Dp, Hp, Wp, _ = x.shape 318 | # cyclic shift 319 | if any(i > 0 for i in shift_size): 320 | shifted_x = torch.roll( 321 | x, 322 | shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), 323 | dims=(1, 2, 3), 324 | ) 325 | attn_mask = mask_matrix 326 | else: 327 | shifted_x = x 328 | attn_mask = None 329 | # partition windows 330 | x_windows = window_partition(shifted_x, window_size) # B*nW, Wd*Wh*Ww, C 331 | # W-MSA/SW-MSA 332 | attn_windows = self.attn(x_windows, mask=attn_mask) # B*nW, Wd*Wh*Ww, C 333 | # merge windows 334 | attn_windows = attn_windows.view(-1, *(window_size + (C,))) 335 | shifted_x = window_reverse( 336 | attn_windows, window_size, B, Dp, Hp, Wp 337 | ) # B D' H' W' C 338 | # reverse cyclic shift 339 | if any(i > 0 for i in shift_size): 340 | x = torch.roll( 341 | shifted_x, 342 | shifts=(shift_size[0], shift_size[1], shift_size[2]), 343 | dims=(1, 2, 3), 344 | ) 345 | else: 346 | x = shifted_x 347 | 348 | if pad_d1 > 0 or pad_r > 0 or pad_b > 0: 349 | x = x[:, :D, :H, :W, :].contiguous() 350 | return x 351 | 352 | def forward_part2(self, x): 353 | return self.drop_path(self.mlp(self.norm2(x))) 354 | 355 | def forward(self, x, mask_matrix): 356 | """Forward function. 357 | Args: 358 | x: Input feature, tensor size (B, D, H, W, C). 359 | mask_matrix: Attention mask for cyclic shift. 
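        Returns:
            Output feature of the same shape as the input, i.e. (B, D, H, W, C).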
360 | """ 361 | 362 | shortcut = x 363 | if self.use_checkpoint: 364 | x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix) 365 | else: 366 | x = self.forward_part1(x, mask_matrix) 367 | x = shortcut + self.drop_path(x) 368 | 369 | if self.use_checkpoint: 370 | x = x + checkpoint.checkpoint(self.forward_part2, x) 371 | else: 372 | x = x + self.forward_part2(x) 373 | 374 | return x 375 | 376 | 377 | class PatchMerging(nn.Module): 378 | """Patch Merging Layer 379 | Args: 380 | dim (int): Number of input channels. 381 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 382 | """ 383 | 384 | def __init__(self, dim, norm_layer=nn.LayerNorm): 385 | super().__init__() 386 | self.dim = dim 387 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 388 | self.norm = norm_layer(4 * dim) 389 | 390 | def forward(self, x): 391 | """Forward function. 392 | Args: 393 | x: Input feature, tensor size (B, D, H, W, C). 394 | """ 395 | B, D, H, W, C = x.shape 396 | 397 | # padding 398 | pad_input = (H % 2 == 1) or (W % 2 == 1) 399 | if pad_input: 400 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) 401 | 402 | x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C 403 | x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C 404 | x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C 405 | x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C 406 | x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C 407 | 408 | x = self.norm(x) 409 | x = self.reduction(x) 410 | 411 | return x 412 | 413 | 414 | # cache each stage results 415 | @lru_cache() 416 | def compute_mask(D, H, W, window_size, shift_size, device): 417 | img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1 418 | cnt = 0 419 | for d in ( 420 | slice(-window_size[0]), 421 | slice(-window_size[0], -shift_size[0]), 422 | slice(-shift_size[0], None), 423 | ): 424 | for h in ( 425 | slice(-window_size[1]), 426 | slice(-window_size[1], -shift_size[1]), 427 | slice(-shift_size[1], None), 428 | ): 429 | for w in ( 430 | slice(-window_size[2]), 431 | slice(-window_size[2], -shift_size[2]), 432 | slice(-shift_size[2], None), 433 | ): 434 | img_mask[:, d, h, w, :] = cnt 435 | cnt += 1 436 | mask_windows = window_partition(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1 437 | mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2] 438 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 439 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( 440 | attn_mask == 0, float(0.0) 441 | ) 442 | return attn_mask 443 | 444 | 445 | class BasicLayer(nn.Module): 446 | """A basic Swin Transformer layer for one stage. 447 | Args: 448 | dim (int): Number of feature channels 449 | depth (int): Depths of this stage. 450 | num_heads (int): Number of attention head. 451 | window_size (tuple[int]): Local window size. Default: (1,7,7). 452 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 453 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 454 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 455 | drop (float, optional): Dropout rate. Default: 0.0 456 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 457 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 458 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 459 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None 460 | """ 461 | 462 | def __init__( 463 | self, 464 | dim, 465 | depth, 466 | num_heads, 467 | window_size=(1, 7, 7), 468 | mlp_ratio=4.0, 469 | qkv_bias=False, 470 | qk_scale=None, 471 | drop=0.0, 472 | attn_drop=0.0, 473 | drop_path=0.0, 474 | norm_layer=nn.LayerNorm, 475 | downsample=None, 476 | use_checkpoint=False, 477 | ): 478 | super().__init__() 479 | self.window_size = window_size 480 | self.shift_size = tuple(i // 2 for i in window_size) 481 | self.depth = depth 482 | self.use_checkpoint = use_checkpoint 483 | 484 | # build blocks 485 | self.blocks = nn.ModuleList( 486 | [ 487 | SwinTransformerBlock3D( 488 | dim=dim, 489 | num_heads=num_heads, 490 | window_size=window_size, 491 | shift_size=(0, 0, 0) if (i % 2 == 0) else self.shift_size, 492 | mlp_ratio=mlp_ratio, 493 | qkv_bias=qkv_bias, 494 | qk_scale=qk_scale, 495 | drop=drop, 496 | attn_drop=attn_drop, 497 | drop_path=drop_path[i] 498 | if isinstance(drop_path, list) 499 | else drop_path, 500 | norm_layer=norm_layer, 501 | use_checkpoint=use_checkpoint, 502 | ) 503 | for i in range(depth) 504 | ] 505 | ) 506 | 507 | self.downsample = downsample 508 | if self.downsample is not None: 509 | self.downsample = downsample(dim=dim, norm_layer=norm_layer) 510 | 511 | def forward(self, x): 512 | """Forward function. 513 | Args: 514 | x: Input feature, tensor size (B, C, D, H, W). 515 | """ 516 | # calculate attention mask for SW-MSA 517 | B, C, D, H, W = x.shape 518 | window_size, shift_size = get_window_size( 519 | (D, H, W), self.window_size, self.shift_size 520 | ) 521 | x = rearrange(x, "b c d h w -> b d h w c") 522 | Dp = int(np.ceil(D / window_size[0])) * window_size[0] 523 | Hp = int(np.ceil(H / window_size[1])) * window_size[1] 524 | Wp = int(np.ceil(W / window_size[2])) * window_size[2] 525 | attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) 526 | for blk in self.blocks: 527 | x = blk(x, attn_mask) 528 | x = x.view(B, D, H, W, -1) 529 | 530 | if self.downsample is not None: 531 | x = self.downsample(x) 532 | x = rearrange(x, "b d h w c -> b c d h w") 533 | return x 534 | 535 | 536 | class PatchEmbed3D(nn.Module): 537 | """Video to Patch Embedding. 538 | Args: 539 | patch_size (int): Patch token size. Default: (2,4,4). 540 | in_chans (int): Number of input video channels. Default: 3. 541 | embed_dim (int): Number of linear projection output channels. Default: 96. 542 | norm_layer (nn.Module, optional): Normalization layer. 
Default: None 543 | """ 544 | 545 | def __init__(self, patch_size=(2, 4, 4), in_chans=3, embed_dim=96, norm_layer=None): 546 | super().__init__() 547 | self.patch_size = patch_size 548 | 549 | self.in_chans = in_chans 550 | self.embed_dim = embed_dim 551 | 552 | self.proj = nn.Conv3d( 553 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size 554 | ) 555 | if norm_layer is not None: 556 | self.norm = norm_layer(embed_dim) 557 | else: 558 | self.norm = None 559 | 560 | def forward(self, x): 561 | """Forward function.""" 562 | # padding 563 | _, _, D, H, W = x.size() 564 | if W % self.patch_size[2] != 0: 565 | x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) 566 | if H % self.patch_size[1] != 0: 567 | x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) 568 | if D % self.patch_size[0] != 0: 569 | x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) 570 | 571 | x = self.proj(x) # B C D Wh Ww 572 | if self.norm is not None: 573 | D, Wh, Ww = x.size(2), x.size(3), x.size(4) 574 | x = x.flatten(2).transpose(1, 2) 575 | x = self.norm(x) 576 | x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) 577 | 578 | return x 579 | 580 | 581 | class SwinTransformer3D(nn.Module): 582 | """Swin Transformer backbone. 583 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 584 | https://arxiv.org/pdf/2103.14030 585 | Args: 586 | patch_size (int | tuple(int)): Patch size. Default: (4,4,4). 587 | in_chans (int): Number of input image channels. Default: 3. 588 | embed_dim (int): Number of linear projection output channels. Default: 96. 589 | depths (tuple[int]): Depths of each Swin Transformer stage. 590 | num_heads (tuple[int]): Number of attention head of each stage. 591 | window_size (int): Window size. Default: 7. 592 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 593 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truee 594 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. 595 | drop_rate (float): Dropout rate. 596 | attn_drop_rate (float): Attention dropout rate. Default: 0. 597 | drop_path_rate (float): Stochastic depth rate. Default: 0.2. 598 | norm_layer: Normalization layer. Default: nn.LayerNorm. 599 | patch_norm (bool): If True, add normalization after patch embedding. Default: False. 600 | frozen_stages (int): Stages to be frozen (stop grad and set eval mode). 601 | -1 means not freezing any parameters. 
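        use_checkpoint (bool): If True, use torch.utils.checkpoint inside each block
            to trade compute for memory. Default: False.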
602 | """ 603 | 604 | def __init__( 605 | self, 606 | pretrained=None, 607 | pretrained2d=True, 608 | patch_size=(4, 4, 4), 609 | in_chans=3, 610 | embed_dim=96, 611 | depths=[2, 2, 6, 2], 612 | num_heads=[3, 6, 12, 24], 613 | window_size=(2, 7, 7), 614 | mlp_ratio=4.0, 615 | qkv_bias=True, 616 | qk_scale=None, 617 | drop_rate=0.0, 618 | attn_drop_rate=0.0, 619 | drop_path_rate=0.2, 620 | norm_layer=nn.LayerNorm, 621 | patch_norm=False, 622 | frozen_stages=-1, 623 | use_checkpoint=False, 624 | ): 625 | super().__init__() 626 | 627 | self.pretrained = pretrained 628 | self.pretrained2d = pretrained2d 629 | self.num_layers = len(depths) 630 | self.embed_dim = embed_dim 631 | self.patch_norm = patch_norm 632 | self.frozen_stages = frozen_stages 633 | self.window_size = window_size 634 | self.patch_size = patch_size 635 | 636 | # split image into non-overlapping patches 637 | self.patch_embed = PatchEmbed3D( 638 | patch_size=patch_size, 639 | in_chans=in_chans, 640 | embed_dim=embed_dim, 641 | norm_layer=norm_layer if self.patch_norm else None, 642 | ) 643 | 644 | self.pos_drop = nn.Dropout(p=drop_rate) 645 | 646 | # stochastic depth 647 | dpr = [ 648 | x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) 649 | ] # stochastic depth decay rule 650 | 651 | # build layers 652 | self.layers = nn.ModuleList() 653 | for i_layer in range(self.num_layers): 654 | layer = BasicLayer( 655 | dim=int(embed_dim * 2 ** i_layer), 656 | depth=depths[i_layer], 657 | num_heads=num_heads[i_layer], 658 | window_size=window_size, 659 | mlp_ratio=mlp_ratio, 660 | qkv_bias=qkv_bias, 661 | qk_scale=qk_scale, 662 | drop=drop_rate, 663 | attn_drop=attn_drop_rate, 664 | drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], 665 | norm_layer=norm_layer, 666 | downsample=PatchMerging if i_layer < self.num_layers - 1 else None, 667 | use_checkpoint=use_checkpoint, 668 | ) 669 | self.layers.append(layer) 670 | 671 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 672 | 673 | # add a norm layer for each output 674 | self.norm = norm_layer(self.num_features) 675 | 676 | self._freeze_stages() 677 | 678 | def _freeze_stages(self): 679 | if self.frozen_stages >= 0: 680 | self.patch_embed.eval() 681 | for param in self.patch_embed.parameters(): 682 | param.requires_grad = False 683 | 684 | if self.frozen_stages >= 1: 685 | self.pos_drop.eval() 686 | for i in range(0, self.frozen_stages): 687 | m = self.layers[i] 688 | m.eval() 689 | for param in m.parameters(): 690 | param.requires_grad = False 691 | 692 | def inflate_weights(self): 693 | """Inflate the swin2d parameters to swin3d. 694 | The differences between swin3d and swin2d mainly lie in an extra 695 | axis. To utilize the pretrained parameters in 2d model, 696 | the weight of swin2d models should be inflated to fit in the shapes of 697 | the 3d counterpart. 698 | Args: 699 | logger (logging.Logger): The logger used to print 700 | debugging infomation. 
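        Note:
            The 2D patch-embedding kernel is repeated along the new temporal axis and
            divided by the temporal patch size, and the relative position bias tables
            are bicubically interpolated when the pretrained window size does not match.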
701 | """ 702 | checkpoint = torch.load(self.pretrained, map_location="cpu") 703 | state_dict = checkpoint["model"] 704 | 705 | # delete relative_position_index since we always re-init it 706 | relative_position_index_keys = [ 707 | k for k in state_dict.keys() if "relative_position_index" in k 708 | ] 709 | for k in relative_position_index_keys: 710 | del state_dict[k] 711 | 712 | # delete attn_mask since we always re-init it 713 | attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k] 714 | for k in attn_mask_keys: 715 | del state_dict[k] 716 | 717 | state_dict["patch_embed.proj.weight"] = ( 718 | state_dict["patch_embed.proj.weight"] 719 | .unsqueeze(2) 720 | .repeat(1, 1, self.patch_size[0], 1, 1) 721 | / self.patch_size[0] 722 | ) 723 | 724 | # bicubic interpolate relative_position_bias_table if not match 725 | relative_position_bias_table_keys = [ 726 | k for k in state_dict.keys() if "relative_position_bias_table" in k 727 | ] 728 | for k in relative_position_bias_table_keys: 729 | relative_position_bias_table_pretrained = state_dict[k] 730 | relative_position_bias_table_current = self.state_dict()[k] 731 | L1, nH1 = relative_position_bias_table_pretrained.size() 732 | L2, nH2 = relative_position_bias_table_current.size() 733 | L2 = (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) 734 | wd = self.window_size[0] 735 | if nH1 != nH2: 736 | print(f"Error in loading {k}, passing") 737 | else: 738 | if L1 != L2: 739 | S1 = int(L1 ** 0.5) 740 | relative_position_bias_table_pretrained_resized = F.interpolate( 741 | relative_position_bias_table_pretrained.permute(1, 0).view( 742 | 1, nH1, S1, S1 743 | ), 744 | size=( 745 | 2 * self.window_size[1] - 1, 746 | 2 * self.window_size[2] - 1, 747 | ), 748 | mode="bicubic", 749 | ) 750 | relative_position_bias_table_pretrained = ( 751 | relative_position_bias_table_pretrained_resized.view( 752 | nH2, L2 753 | ).permute(1, 0) 754 | ) 755 | state_dict[k] = relative_position_bias_table_pretrained.repeat( 756 | 2 * wd - 1, 1 757 | ) 758 | 759 | self.load_state_dict(state_dict, strict=False) 760 | print(f"=> loaded successfully '{self.pretrained}'") 761 | del checkpoint 762 | torch.cuda.empty_cache() 763 | 764 | def init_weights(self, pretrained=None): 765 | """Initialize the weights in backbone. 766 | Args: 767 | pretrained (str, optional): Path to pre-trained weights. 768 | Defaults to None. 769 | """ 770 | 771 | def _init_weights(m): 772 | if isinstance(m, nn.Linear): 773 | trunc_normal_(m.weight, std=0.02) 774 | if isinstance(m, nn.Linear) and m.bias is not None: 775 | nn.init.constant_(m.bias, 0) 776 | elif isinstance(m, nn.LayerNorm): 777 | nn.init.constant_(m.bias, 0) 778 | nn.init.constant_(m.weight, 1.0) 779 | 780 | if pretrained: 781 | self.pretrained = pretrained 782 | if isinstance(self.pretrained, str): 783 | self.apply(_init_weights) 784 | print(f"load model from: {self.pretrained}") 785 | 786 | if self.pretrained2d: 787 | # Inflate 2D model into 3D model. 
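                # `inflate_weights` expects a 2D Swin checkpoint whose parameters are
                # stored under the "model" key (see the method above).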
788 | self.inflate_weights() 789 | else: 790 | raise ValueError("Functionality not available!") 791 | elif self.pretrained is None: 792 | self.apply(_init_weights) 793 | else: 794 | raise TypeError("pretrained must be a str or None") 795 | 796 | def forward(self, x): 797 | """Forward function.""" 798 | x = self.patch_embed(x) 799 | 800 | x = self.pos_drop(x) 801 | 802 | for layer in self.layers: 803 | x = layer(x.contiguous()) 804 | 805 | x = rearrange(x, "n c d h w -> n d h w c") 806 | x = self.norm(x) 807 | x = rearrange(x, "n d h w c -> n c d h w") 808 | 809 | return x 810 | 811 | def train(self, mode=True): 812 | """Convert the model into training mode while keep layers freezed.""" 813 | super(SwinTransformer3D, self).train(mode) 814 | self._freeze_stages() 815 | --------------------------------------------------------------------------------
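Usage sketch (not part of the repository): the snippet below shows one way to exercise the `SwinTransformer3D` backbone above on a dummy clip. The constructor arguments follow the "tiny" configuration suggested by the checkpoint name referenced in the configs (`swin_tiny_patch244_window877`, i.e. patch size (2, 4, 4) and window size (8, 7, 7)); those values, the import path, and the input resolution are assumptions rather than anything fixed by this file.

import torch

# Import path assumed from the repository layout (src/modelling/swin.py).
from src.modelling.swin import SwinTransformer3D

# Assumed "tiny" hyper-parameters; adjust them to match the actual pretrained backbone.
backbone = SwinTransformer3D(
    patch_size=(2, 4, 4),
    embed_dim=96,
    depths=[2, 2, 6, 2],
    num_heads=[3, 6, 12, 24],
    window_size=(8, 7, 7),
    patch_norm=True,
)
backbone.eval()

# One 16-frame RGB clip at 224 x 224: (batch, channels, frames, height, width).
clip = torch.randn(1, 3, 16, 224, 224)
with torch.no_grad():
    features = backbone(clip)

# Patch embedding halves the temporal axis and downsamples space by 4; the three
# PatchMerging stages halve the spatial axes again, so the expected output is a
# (1, 768, 8, 7, 7) feature map.
print(features.shape)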