├── src ├── SCT │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ ├── breakfast.py │ │ ├── utils.py │ │ └── general_dataset.py │ ├── models │ │ ├── data_classes.py │ │ ├── fc_model.py │ │ ├── fl_model.py │ │ ├── sct.py │ │ ├── fs_model.py │ │ ├── fer_model.py │ │ ├── general_model.py │ │ ├── __init__.py │ │ ├── temporal_modules.py │ │ ├── fu_model.py │ │ └── losses.py │ ├── evaluators │ │ ├── __init__.py │ │ ├── metrics.py │ │ └── general_evaluator.py │ ├── experiment │ │ ├── __init__.py │ │ └── general_experiment.py │ ├── config.py │ └── utils.py └── main.py ├── requirements.txt ├── README.md └── .gitignore /src/SCT/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.16.4,<2 2 | torch>=1.6.0 3 | torchvision>=0.3.0 4 | tqdm>=4.32.0,<5 5 | yacs>=0.1.6,<0.2 6 | tensorboard>=1.14,<2 7 | future>=0.17,<1.0 8 | edit-distance>=1.0.3,<2 9 | scipy 10 | -------------------------------------------------------------------------------- /src/SCT/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode 2 | 3 | from SCT.datasets.breakfast import create_breakfast_dataset 4 | from SCT.datasets.general_dataset import GeneralDataset 5 | 6 | 7 | def make_db(cfg: CfgNode, train: bool): 8 | dataset_name = cfg.dataset.name 9 | 10 | if dataset_name == breakfast.DATASET_NAME: 11 | db = create_breakfast_dataset(cfg, train) 12 | else: 13 | raise Exception("dataset not found. (name: %s)" % dataset_name) 14 | 15 | return db 16 | -------------------------------------------------------------------------------- /src/SCT/models/data_classes.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from torch import Tensor 3 | 4 | 5 | @dataclass 6 | class ForwardOut: 7 | S: Tensor # [1 x C x T ] 8 | Y: Tensor # [1 x C x T] 9 | Z: Tensor # [1 x D' x T'] 10 | A: Tensor # [1 x C x K] 11 | L: Tensor # [1 x 1 x K] 12 | V: Tensor # [1 x M x C] 13 | 14 | @dataclass 15 | class LossOut: 16 | total_loss: Tensor # [] this is going to be used for backpropagation 17 | set_loss: Tensor # [] 18 | region_loss: Tensor # [] 19 | sct_loss: Tensor # [] 20 | temporal_consistency_loss: Tensor # [] 21 | length_loss: Tensor # [] 22 | inv_sparsity_loss: Tensor # [] 23 | -------------------------------------------------------------------------------- /src/SCT/models/fc_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch import Tensor 3 | from yacs.config import CfgNode 4 | 5 | 6 | class Fc(nn.Module): 7 | """ 8 | Abstract class for Fc(Z'): predicts the action probabilities for every region r over Z 9 | """ 10 | def forward(self, z_prime: Tensor) -> Tensor: 11 | raise NotImplementedError 12 | 13 | 14 | class Conv(Fc): 15 | def __init__(self, cfg: CfgNode, num_classes): 16 | super().__init__() 17 | self.cfg = cfg 18 | self.classifier = nn.Conv1d(self.cfg.model.fer.hidden_size, num_classes, 1) 19 | 20 | def forward(self, z_prime: Tensor): 21 | out = self.classifier(z_prime) 22 | return out 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SCT: Set Constrained Temporal Transformer for Set Supervised Action Segmentation 2 | 3 | Source code of the CVPR 2020 paper: ["SCT: Set Constrained Temporal Transformer for Set Supervised Action Segmentation"](https://arxiv.org/pdf/2003.14266.pdf). 4 | 5 | 6 | 7 | ```latex 8 | @inproceedings{sct2020, 9 | Author = {Fayyaz, Mohsen and Gall, Juergen}, 10 | Title = {{SCT: Set Constrained Temporal Transformer for Set Supervised Action Segmentation}}, 11 | Booktitle = {{CVPR}}, 12 | Year = {2020} 13 | } 14 | ``` 15 | 16 | ## Requirements 17 | 18 | The main dependencies is: 19 | 20 | * Python 3.7 21 | * PyTorch 1.6 22 | 23 | Other dependencies are listed in the `requirements.txt` file. 24 | -------------------------------------------------------------------------------- /src/SCT/models/fl_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch import Tensor 3 | from yacs.config import CfgNode 4 | 5 | 6 | class Fl(nn.Module): 7 | """ 8 | Abstract class for Fl(Z'): estimates the temporal length of each region r over Z 9 | """ 10 | def forward(self, z_prime: Tensor) -> Tensor: 11 | raise NotImplementedError 12 | 13 | 14 | class Conv(Fl): 15 | def __init__(self, cfg: CfgNode): 16 | super().__init__() 17 | self.cfg = cfg 18 | D_prime = self.cfg.model.fer.hidden_size 19 | self.conv1 = nn.Conv1d(D_prime, int(D_prime/2), 1) 20 | self.activation_1 = nn.ReLU() 21 | self.conv2 = nn.Conv1d(int(D_prime/2), 1, 1) 22 | self.activation_2 = nn.ReLU() 23 | 24 | def forward(self, z_prime: Tensor): 25 | out = self.conv1(z_prime) 26 | # out = self.activation_1(out) 27 | out = self.conv2(out) 28 | # out = self.activation_2(out) 29 | 30 | return out 31 | -------------------------------------------------------------------------------- /src/SCT/models/sct.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | 4 | # noinspection PyPep8Naming 5 | def get_masks(Y: Tensor, A_hat: Tensor) -> Tensor: 6 | """ 7 | 8 | :param Y: upsampled predicted probabilities [1 x C x T] 9 | :param A_hat: A^ set of actions in the video [M] each item between 0 and C-1 10 | :return: W set of temporal masks [1 x M x T] 11 | """ 12 | # indices = torch.Tensor(list(A_hat)).to(device=Y.device).long() 13 | W = Y.index_select(dim=1, index=A_hat) # [1 x M x T] 14 | return W 15 | 16 | 17 | # noinspection PyPep8Naming 18 | def SCT(W: Tensor, S: Tensor, T: int) -> Tensor: 19 | """ 20 | Set Constrained Temporal Transformation 21 | :param W: masks [1 x M x T] 22 | :param S: S intermediate temporal representation over input video X [1 x C x T] 23 | :param T: T input video temporal length [1] 24 | :return: V set of predicted actions [1 x M x C] 25 | """ 26 | S = S.permute(0, 2, 1) # [1 x T x C] 27 | V = (W @ S) / T 28 | return V 29 | -------------------------------------------------------------------------------- /src/SCT/evaluators/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | from .general_evaluator import GeneralEvaluator 4 | 5 | 6 | def make_evaluators( 7 | cfg, train_db, test_db, model 8 | ) -> Tuple[Optional[GeneralEvaluator], Optional[GeneralEvaluator]]: 9 | 10 | train_evaluator = make_train_evaluator(cfg, train_db, model) 11 | val_evaluator = make_val_evaluator(cfg, test_db, model) 12 | 13 | return train_evaluator, val_evaluator 14 | 15 | 16 | def make_train_evaluator(cfg, train_db, model) -> Optional[GeneralEvaluator]: 17 | if cfg.training.evaluators.eval_train: 18 | return GeneralEvaluator(cfg=cfg, model=model, dataset=train_db) 19 | else: 20 | return None 21 | 22 | 23 | def make_val_evaluator(cfg, val_db, model) -> Optional[GeneralEvaluator]: 24 | return GeneralEvaluator(cfg=cfg, model=model, dataset=val_db) 25 | 26 | 27 | def make_evaluator_final(cfg, test_db, model) -> GeneralEvaluator: 28 | return GeneralEvaluator(cfg=cfg, model=model, dataset=test_db) 29 | -------------------------------------------------------------------------------- /src/SCT/experiment/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torch import Tensor 4 | from yacs.config import CfgNode 5 | 6 | from SCT.datasets import GeneralDataset 7 | from SCT.evaluators import GeneralEvaluator 8 | from SCT.experiment.general_experiment import GeneralExperiment 9 | from SCT.models import GeneralModel 10 | 11 | 12 | def make_experiment( 13 | cfg: CfgNode, 14 | dataset: GeneralDataset, 15 | model: GeneralModel, 16 | loss_weights: Tensor, 17 | val_evaluator: Optional[GeneralEvaluator], 18 | train_evaluator: Optional[GeneralEvaluator], 19 | ) -> GeneralExperiment: 20 | training_name = cfg.training.name 21 | if training_name == "normal": 22 | return GeneralExperiment( 23 | cfg, 24 | dataset, 25 | model, 26 | loss_weights, 27 | val_evaluator=val_evaluator, 28 | train_evaluator=train_evaluator, 29 | ) 30 | else: 31 | raise Exception("Invalid training name (%s)" % training_name) 32 | -------------------------------------------------------------------------------- /src/SCT/models/fs_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch import Tensor 3 | import torch.nn.functional as F 4 | from yacs.config import CfgNode 5 | 6 | 7 | class Fs(nn.Module): 8 | """ 9 | Fs(Z) 10 | """ 11 | # noinspection PyPep8Naming 12 | def forward(self, Z: Tensor, T: int) -> Tensor: 13 | """ 14 | Predicts the temporal action probabilities S using the intermediate representation Z 15 | :param Z: intermediate representation [1 x D' x T'] 16 | :param T: input video X temporal length [1] 17 | :return: S [1 x C x T] 18 | """ 19 | raise NotImplementedError 20 | 21 | 22 | class Conv(Fs): 23 | def __init__(self, cfg: CfgNode, num_classes): 24 | super().__init__() 25 | self.cfg = cfg 26 | self.classifier = nn.Conv1d(self.cfg.model.fs.hidden_size, num_classes, 1) 27 | 28 | # noinspection PyPep8Naming 29 | def forward(self, Z: Tensor, T: int) -> Tensor: 30 | out = self.classifier(Z) # [1 x C x T'] 31 | out = F.interpolate(out, T) # [1 x C x T] 32 | return out # [1 x C x T] 33 | -------------------------------------------------------------------------------- /src/SCT/models/fer_model.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch.nn as nn 4 | from torch import Tensor 5 | from yacs.config import CfgNode 6 | 7 | from .temporal_modules import WaveNetBlock 8 | 9 | 10 | class Fer(nn.Module): 11 | """ 12 | fe(X) & fr(Z) 13 | """ 14 | def forward(self, x: Tensor) -> List[Tensor]: 15 | raise NotImplementedError 16 | 17 | 18 | class WaveNet(Fer): 19 | def __init__(self, cfg: CfgNode, pooling_levels: List[int]): 20 | super().__init__() 21 | self.cfg = cfg.model.fer 22 | self.drop_on_frames = nn.Dropout(p=self.cfg.dropout_on_x) 23 | self.WaveNet = WaveNetBlock( 24 | in_channels=self.cfg.input_size, output_levels=self.cfg.output_levels, 25 | out_dims=self.cfg.hidden_size, pooling_levels=pooling_levels 26 | ) 27 | self.group_norm = nn.GroupNorm( 28 | num_groups=self.cfg.gn_num_groups, num_channels=self.cfg.hidden_size 29 | ) 30 | 31 | def forward(self, x: Tensor) -> List[Tensor]: 32 | """ 33 | 34 | :param x: 1 x D x T 35 | :return: 1 x D' x T' 36 | """ 37 | # x = x.permute(0, 2, 1) # [1 x T x D] 38 | x = self.drop_on_frames(x) 39 | # x = x.permute(0, 2, 1) # [1 x D x T] 40 | outputs = self.WaveNet(x) 41 | outputs[-1] = self.group_norm(outputs[-1]) 42 | return outputs 43 | -------------------------------------------------------------------------------- /src/SCT/datasets/breakfast.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from yacs.config import CfgNode 4 | 5 | from SCT.datasets.general_dataset import GeneralDataset 6 | from SCT.datasets.utils import MAPPING_FILE_NAME, FEAT_DIM_MAPPING 7 | 8 | POSSIBLE_SPLITS = [0, 1, 2, 3, 4] 9 | MAX_TRANSCRIPT_LENGTH = 25 10 | DATASET_NAME = "breakfast" 11 | NUM_CLASSES = 48 12 | 13 | 14 | def create_breakfast_dataset(cfg: CfgNode, train: bool = True) -> GeneralDataset: 15 | split = cfg.dataset.split 16 | feat_name = cfg.dataset.feat_name 17 | root = Path(cfg.dataset.root) 18 | assert split in POSSIBLE_SPLITS 19 | db_path = root / "datasets" / DATASET_NAME 20 | rnd_flip = False 21 | rnd_cat = False 22 | rnd_cat_n_vid = 0 23 | if train: 24 | rnd_flip = cfg.training.random_flip 25 | rnd_cat = cfg.training.random_concat.active 26 | rnd_cat_n_vid = cfg.training.random_concat.num_videos 27 | feat_list = db_path / "split{sn}_{tt}_feats_{fn}.txt".format( 28 | sn=split, tt="train" if train else "test", fn=feat_name 29 | ) 30 | 31 | gt_list = db_path / "split{sn}_{tt}_labels.txt".format( 32 | sn=split, tt="train" if train else "test" 33 | ) 34 | 35 | mapping = db_path / MAPPING_FILE_NAME 36 | 37 | db = GeneralDataset( 38 | root=root, 39 | feat_list=feat_list, 40 | gt_list=gt_list, 41 | mapping_file=mapping, 42 | feat_dim=FEAT_DIM_MAPPING[feat_name], 43 | num_classes= NUM_CLASSES, 44 | rnd_flip=rnd_flip, 45 | rnd_cat=rnd_cat, 46 | rnd_cat_n_vid=rnd_cat_n_vid 47 | ) 48 | db.end_class_id = 0 49 | db.mof_eval_ignore_classes = [] 50 | db.background_class_ids = [0] 51 | db.convenient_name = DATASET_NAME 52 | db.split = split 53 | db.max_transcript_length = MAX_TRANSCRIPT_LENGTH 54 | 55 | return db 56 | -------------------------------------------------------------------------------- /src/SCT/models/general_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torch import Tensor, argmax 4 | from yacs.config import CfgNode 5 | 6 | from SCT.models import Fs 7 | from SCT.models.data_classes import ForwardOut 8 | from SCT.models.fc_model import Fc 9 | from SCT.models.fl_model import Fl 10 | from .fu_model import Fu 11 | from .fer_model import Fer 12 | from torch.nn.functional import softmax 13 | from ..datasets.general_dataset import BatchItem 14 | 15 | 16 | class GeneralModel(nn.Module): 17 | def __init__( 18 | self, 19 | cfg: CfgNode, 20 | fer: Fer, 21 | fs: Fs, 22 | fc: Fc, 23 | fl: Fl, 24 | fu: Fu, 25 | sct, 26 | get_masks, 27 | ): 28 | super().__init__() 29 | self.cfg = cfg 30 | self.fer = fer 31 | self.fs = fs 32 | self.fc = fc 33 | self.fl = fl 34 | self.fu = fu 35 | self.sct = sct 36 | self.get_masks = get_masks 37 | 38 | # noinspection PyPep8Naming 39 | def forward(self, batch: BatchItem) -> ForwardOut: 40 | T = batch.T 41 | X = batch.feats # .squeeze(dim=1) 42 | fer_outputs = self.fer.forward(X) # [[1 x D' x T'], [1 x D', K]] 43 | Z = fer_outputs[0] # [1 x D' x T'] 44 | Z_prime = fer_outputs[-1] # [1 x D' x K] 45 | S = self.fs(Z, T) # [1 x C x T] 46 | A = self.fc(Z_prime) # [1 x C x K] 47 | A = softmax(A / self.cfg.model.fc.softmax_temp, dim=1) 48 | L = self.fl(Z_prime) # [1 x 1 x K] 49 | Y = self.fu(A=A, L=L, T=T) # [1 x C x T] 50 | W = self.get_masks(Y=Y, A_hat=batch.A_hat) # [1 x M x T] 51 | V = self.sct(W=W, S=S, T=T) # [1 x M x C] 52 | # _res = F.interpolate(A, T) 53 | # Y = _res + Y 54 | return ForwardOut(S=S, Y=Y, Z=Z, A=A, L=L, V=V) 55 | -------------------------------------------------------------------------------- /src/SCT/datasets/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | 3 | from SCT.datasets import breakfast 4 | 5 | KINETICS_FEAT_NAME = "i3d" 6 | IDT_FEAT_NAME = "idt" 7 | FEAT_DIM_MAPPING = {KINETICS_FEAT_NAME: 2048, IDT_FEAT_NAME: 64} 8 | MAPPING_FILE_NAME = "mapping.txt" 9 | LIST_OF_DATASETS = ["breakfast"] 10 | 11 | BOS_I = 0 12 | BOS_S = "_bos_" 13 | EOS_I = 1 14 | EOS_S = "_eos_" 15 | BASE_STOI = {BOS_S: BOS_I, EOS_S: EOS_I} 16 | LEN_EXTRA_WORDS = len(BASE_STOI) 17 | 18 | 19 | 20 | def create_tf_input(transcript: List[int]) -> List[int]: 21 | return [BOS_I] + [x + LEN_EXTRA_WORDS for x in transcript] 22 | 23 | 24 | def create_tf_target(transcript: List[int]) -> List[int]: 25 | return [x + LEN_EXTRA_WORDS for x in transcript] + [EOS_I] 26 | 27 | 28 | def create_tf_set_target(transcript: List[int], num_classes: int) -> List[int]: 29 | set_target = [0]*num_classes 30 | for t in transcript: 31 | set_target[t] = 1 32 | return set_target 33 | 34 | 35 | def summarize_list(the_list: List[int]) -> Tuple[List[int], List[int]]: 36 | """ 37 | Given a list of items, it summarizes them in a way that no two neighboring values are the same. 38 | It also returns the size of each section. 39 | e.g. [4, 5, 5, 6] -> [4, 5, 6], [1, 2, 1] 40 | """ 41 | summary = [] 42 | lens = [] 43 | if len(the_list) > 0: 44 | current = the_list[0] 45 | summary.append(current) 46 | lens.append(1) 47 | for item in the_list[1:]: 48 | if item != current: 49 | current = item 50 | summary.append(item) 51 | lens.append(1) 52 | else: 53 | lens[-1] += 1 54 | return summary, lens 55 | 56 | 57 | def unsummarize_list(labels: List[int], lengths: List[int]) -> List[int]: 58 | """ 59 | Does the reverse of summarize list. You give it a list of segment labels and their lengths and it returns the full 60 | labels for the full sequence. 61 | e.g. ([4, 5, 6], [1, 2, 1]) -> [4, 5, 5, 6] 62 | """ 63 | assert len(labels) == len(lengths) 64 | 65 | the_sequence = [] 66 | for label, length in zip(labels, lengths): 67 | the_sequence.extend([label] * length) 68 | 69 | return the_sequence 70 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | detectron2_repo/ 3 | logs.txt 4 | .idea 5 | experiments/ 6 | pretrain/ 7 | # Byte-compiled / optimized / DLL files 8 | # 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # pipenv 86 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 87 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 88 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 89 | # install all needed dependencies. 90 | #Pipfile.lock 91 | 92 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 93 | __pypackages__/ 94 | 95 | # Celery stuff 96 | celerybeat-schedule 97 | celerybeat.pid 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # Environments 103 | .env 104 | .venv 105 | env/ 106 | venv/ 107 | ENV/ 108 | env.bak/ 109 | venv.bak/ 110 | 111 | # Spyder project settings 112 | .spyderproject 113 | .spyproject 114 | 115 | # Rope project settings 116 | .ropeproject 117 | 118 | # mkdocs documentation 119 | /site 120 | 121 | # mypy 122 | .mypy_cache/ 123 | .dmypy.json 124 | dmypy.json 125 | 126 | # Pyre type checker 127 | .pyre/ 128 | 129 | # Queueing Tool logs 130 | /q.log/ 131 | # Pycharm Files 132 | *.pyc 133 | .nfs* 134 | # Logs 135 | src/log/ 136 | .idea/* 137 | /.idea/ -------------------------------------------------------------------------------- /src/SCT/models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Dict 3 | from torch import Tensor, nn 4 | from .fu_model import Fu 5 | from yacs.config import CfgNode 6 | from SCT.models.fc_model import Fc 7 | from SCT.models.fl_model import Fl 8 | from SCT.models.fs_model import Fs 9 | from .fer_model import WaveNet, Fer 10 | from .general_model import GeneralModel 11 | from SCT.models.sct import SCT, get_masks 12 | 13 | 14 | def _make_fer(cfg: CfgNode) -> Fer: 15 | if cfg.model.fer.name == "wavenet": 16 | if cfg.model.fer.wavenet_pooling_levels is []: 17 | pooling_levels = [1, 2, 4, 8, 16] 18 | else: 19 | pooling_levels = cfg.model.fer.wavenet_pooling_levels 20 | return WaveNet(cfg, pooling_levels=pooling_levels) 21 | else: 22 | raise Exception("Invalid fer name") 23 | 24 | 25 | def _make_fs(cfg: CfgNode, num_classes: int) -> Fs: 26 | if cfg.model.fs.name == "conv": 27 | return fs_model.Conv( 28 | cfg, 29 | num_classes=num_classes, 30 | ) 31 | else: 32 | raise Exception("Invalid Fs name") 33 | 34 | 35 | def _make_fc(cfg: CfgNode, num_classes: int) -> Fc: 36 | if cfg.model.fc.name == "conv": 37 | return fc_model.Conv(cfg, num_classes=num_classes) 38 | else: 39 | raise Exception("Invalid fc name") 40 | 41 | 42 | def _make_fl(cfg: CfgNode) -> Fl: 43 | if cfg.model.fl.name == "conv": 44 | return fl_model.Conv(cfg) 45 | else: 46 | raise Exception("Invalid fl name") 47 | 48 | 49 | def _make_fu(cfg: CfgNode) -> Fu: 50 | if cfg.model.fu.name == "TemporalSampling": 51 | return fu_model.TemporalSamplingUpSampler() 52 | else: 53 | raise Exception("Invalid fu name") 54 | 55 | 56 | def make_loss_weights(num_classes: int, weights: Dict) -> Tensor: 57 | loss_weights = torch.ones(num_classes) 58 | for (class_idx, weight) in weights: 59 | loss_weights[class_idx] = weight 60 | loss_weights = nn.Parameter(loss_weights, requires_grad=False) 61 | return loss_weights 62 | 63 | 64 | def make_model(cfg: CfgNode, num_classes: int) -> GeneralModel: 65 | fer = _make_fer(cfg) 66 | fs = _make_fs(cfg, num_classes=num_classes) 67 | fc = _make_fc(cfg, num_classes=num_classes) 68 | fl = _make_fl(cfg) 69 | fu = _make_fu(cfg=cfg) 70 | sct_func = SCT 71 | get_masks_func = get_masks 72 | model = GeneralModel( 73 | cfg=cfg, 74 | fer=fer, 75 | fs=fs, 76 | fc=fc, 77 | fl=fl, 78 | fu=fu, 79 | sct=sct_func, 80 | get_masks=get_masks_func, 81 | ) 82 | return model 83 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from SCT.config import get_cfg_defaults 4 | from SCT.datasets import make_db 5 | from SCT.evaluators import make_evaluators, make_evaluator_final 6 | from SCT.experiment import make_experiment 7 | from SCT.models import make_model, make_loss_weights 8 | from SCT.utils import set_seed, change_multiprocess_strategy, print_with_time 9 | from yacs.config import CfgNode 10 | 11 | 12 | def parse_args(): 13 | """ 14 | Parse input arguments 15 | """ 16 | parser = argparse.ArgumentParser(description="Train") 17 | parser.add_argument( 18 | "--cfg", dest="cfg_file", help="optional config file", default=None, type=str 19 | ) 20 | parser.add_argument( 21 | "--set", 22 | dest="set_cfgs", 23 | help="set config keys", 24 | default=None, 25 | nargs=argparse.REMAINDER, 26 | ) 27 | 28 | args = parser.parse_args() 29 | return args 30 | 31 | 32 | def create_cfg() -> CfgNode: 33 | args = parse_args() 34 | cfg = get_cfg_defaults() 35 | if args.cfg_file is not None: 36 | cfg.merge_from_file(args.cfg_file) 37 | if args.set_cfgs is not None: 38 | cfg.merge_from_list(args.set_cfgs) 39 | cfg.freeze() 40 | return cfg 41 | 42 | 43 | def main(): 44 | cfg = create_cfg() 45 | # set_seed(cfg.system.seed) 46 | change_multiprocess_strategy() 47 | 48 | train_db = make_db(cfg, train=True) 49 | 50 | if cfg.training.overfit: 51 | test_db = train_db 52 | else: 53 | test_db = make_db(cfg, train=False) 54 | model = make_model( 55 | cfg, 56 | num_classes=train_db.num_classes, 57 | ) 58 | loss_weights = make_loss_weights( 59 | num_classes=train_db.num_classes, weights=cfg.loss.class_weight 60 | ) 61 | train_evaluator, val_evaluator = make_evaluators(cfg, train_db, test_db, model) 62 | experiment = make_experiment( 63 | cfg, 64 | train_db, 65 | model, 66 | loss_weights, 67 | val_evaluator=val_evaluator, 68 | train_evaluator=train_evaluator, 69 | ) 70 | 71 | if not cfg.training.only_test: 72 | if cfg.training.pretrained and cfg.training.resume: 73 | raise ValueError( 74 | "training.pretrained and training.resume" 75 | " flags cannot be True at the same time" 76 | ) 77 | elif cfg.training.pretrained: 78 | experiment.init_from_pretrain() 79 | elif cfg.training.resume: 80 | experiment.resume() 81 | experiment.train() 82 | else: 83 | experiment.load_model_for_test() 84 | 85 | final_evaluator = make_evaluator_final(cfg, test_db, model) 86 | final_eval_result = final_evaluator.evaluate() 87 | 88 | print_with_time("Final Evaluation Result ...") 89 | print(final_eval_result) 90 | if not cfg.training.only_test: 91 | print_with_time("Saving final model ...") 92 | experiment.save() 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /src/SCT/evaluators/metrics.py: -------------------------------------------------------------------------------- 1 | from difflib import SequenceMatcher 2 | from typing import List, Set 3 | 4 | import numpy as np 5 | import torch 6 | from torch.nn.functional import interpolate 7 | 8 | from SCT.utils import tensor_to_numpy 9 | 10 | 11 | def make_same_size( 12 | prediction: np.ndarray, target: np.ndarray, background: int = 0 13 | ) -> np.ndarray: 14 | """ 15 | Tries to use some heuristic to make the prediction the same size as the target. 16 | If the prediction is shorter, it will add background class at the end. 17 | If the prediction is longer, it will crop to the size of the target. 18 | :returns predictions. It will return the updated predictions file. 19 | """ 20 | 21 | t_len = len(target) 22 | p_len = len(prediction) 23 | 24 | if p_len == t_len: 25 | return prediction 26 | elif p_len > t_len: 27 | new_predictions = prediction.copy() 28 | extra_len = p_len - t_len 29 | new_predictions = new_predictions[:-extra_len] 30 | else: # p_len < t_len 31 | new_predictions = prediction.copy() 32 | remaining_len = t_len - p_len 33 | bg = np.full(remaining_len, fill_value=background) 34 | new_predictions = np.concatenate((new_predictions, bg), axis=0) 35 | return new_predictions 36 | 37 | 38 | def make_same_size_interpolate( 39 | prediction: np.ndarray, target: np.ndarray 40 | ) -> np.ndarray: 41 | """ 42 | It will use nearest neighbor interpolation to make the prediction the same size as the target. 43 | """ 44 | t_len = len(target) 45 | 46 | prediction_tensor = torch.tensor(prediction).float() 47 | prediction_tensor_resized = interpolate( 48 | prediction_tensor, size=t_len, mode="nearest" 49 | ) 50 | 51 | return tensor_to_numpy(prediction_tensor_resized.long()) 52 | 53 | 54 | # noinspection PyPep8Naming 55 | def MoF( 56 | predictions: List[np.ndarray], 57 | targets: List[np.ndarray], 58 | ignore_ids: List[int] = None, 59 | ) -> float: 60 | """ 61 | Calculates the Mean over frames segmentation metric. 62 | :param predictions: List of numpy arrays. Each array is assumed to do 1D. It should contain the id of the predicted 63 | frame label. 64 | :param targets: List of numpy arrays. 65 | :param ignore_ids: The list of ids that have to be ignored during evaluation. 66 | :return: the mean over frame metric. It is between 0 and 1. 67 | """ 68 | if ignore_ids is None: 69 | ignore_ids = [] 70 | if type(ignore_ids) == int: 71 | ignore_ids = [ignore_ids] 72 | 73 | assert len(predictions) == len(targets) 74 | total = 0 75 | correct = 0 76 | for i in range(len(predictions)): 77 | p = predictions[i] 78 | t = targets[i] 79 | 80 | assert len(p) == len(t) 81 | 82 | where_to_consider = np.ones(len(p)) 83 | for iid in ignore_ids: 84 | where_to_consider[np.where(t == iid)] = 0 85 | 86 | where_to_consider = np.where(where_to_consider) 87 | 88 | total += len(p[where_to_consider]) 89 | correct += (p[where_to_consider] == t[where_to_consider]).sum() 90 | 91 | return float(correct) / total 92 | 93 | def matching_score(gt_transcript: List[int], predicted_transcript: List[int]) -> float: 94 | return SequenceMatcher(a=gt_transcript, b=predicted_transcript).ratio() -------------------------------------------------------------------------------- /src/SCT/config.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _C = CN() 4 | 5 | _C.system = CN() 6 | _C.system.device = "cuda" 7 | _C.system.num_workers = 2 8 | _C.system.seed = 1 9 | 10 | _C.dataset = CN() 11 | _C.dataset.root = "/Data/SCT" 12 | _C.dataset.name = "breakfast" 13 | _C.dataset.feat_name = "i3d" 14 | _C.dataset.split = 1 15 | 16 | _C.experiment = CN() 17 | _C.experiment.name = "default" 18 | _C.experiment.root = "/Data/ActionSet" 19 | _C.experiment.tb_root = "/Data/ActionSet" 20 | _C.experiment.run_number = -1 # -1 is the default 21 | _C.experiment.track_training_metrics_per_iter = False 22 | _C.experiment.log_inference_output = False 23 | 24 | _C.training = CN() 25 | _C.training.name = "normal" 26 | _C.training.save_every = 1 27 | _C.training.overfit = False 28 | _C.training.overfit_indices = [0] 29 | _C.training.clip_grad_norm = True 30 | _C.training.clip_grad_norm_value = 10 31 | _C.training.num_epochs = 120 32 | _C.training.optimizer = "SGD" 33 | _C.training.learning_rate = 0.005 34 | _C.training.momentum = 0.009 35 | _C.training.weight_decay = 0.000 36 | _C.training.random_flip = False 37 | _C.training.random_concat = CN() 38 | _C.training.random_concat.active = False 39 | _C.training.random_concat.num_videos = 1 40 | _C.training.scheduler = CN() 41 | _C.training.scheduler.name = "step" # can be 'none', 'plateau', 'step' 42 | 43 | _C.training.scheduler.multi_step = CN() 44 | _C.training.scheduler.multi_step.steps = [100] 45 | 46 | # below are the settings for plateau lr scheduler. 47 | _C.training.scheduler.plateau = CN() 48 | _C.training.scheduler.plateau.mode = "max" 49 | _C.training.scheduler.plateau.factor = 0.9 50 | _C.training.scheduler.plateau.verbose = True 51 | _C.training.scheduler.plateau.patience = 40 52 | 53 | _C.training.pretrained = False 54 | _C.training.pretrained_weight = "None" 55 | _C.training.skip_modules = [] 56 | 57 | _C.training.only_test = False 58 | _C.training.resume = False 59 | _C.training.resume_from = -1 # resume from the latest 60 | 61 | _C.training.evaluators = CN() 62 | _C.training.evaluators.eval_train = False 63 | _C.training.evaluators.ignore_classes = [0] 64 | 65 | _C.model = CN() 66 | _C.model.fer = CN() 67 | _C.model.fer.name = "wavenet" 68 | _C.model.fer.hidden_size = 128 69 | _C.model.fer.input_size = 2048 70 | _C.model.fer.dropout_on_x = 0.05 71 | _C.model.fer.gn_num_groups = 32 72 | _C.model.fer.wavenet_pooling_levels = [1, 2, 4, 8, 10] 73 | _C.model.fer.output_levels = [4] 74 | 75 | _C.model.fs = CN() 76 | _C.model.fs.name = "conv" 77 | _C.model.fs.hidden_size = 128 78 | _C.model.fs.dropout = 0.2 79 | _C.model.fs.set_pred_thr = 0.5 80 | _C.model.fs.length_loss_width = 1.0 81 | 82 | _C.model.fc = CN() 83 | _C.model.fc.name = "conv" 84 | _C.model.fc.softmax_temp = 0.1 85 | 86 | _C.model.fl = CN() 87 | _C.model.fl.name = "conv" 88 | 89 | _C.model.fu = CN() 90 | _C.model.fu.name = "TemporalSampling" 91 | 92 | 93 | _C.loss = CN() 94 | _C.loss.sct_loss_mul = 1.0 95 | _C.loss.region_loss_mul = 0.1 96 | _C.loss.set_loss_mul = 1.0 97 | _C.loss.length_loss_mul = 0.1 98 | _C.loss.temporal_consistency_loss_mul = 0.1 99 | _C.loss.inv_sparsity_loss_mul = 0.1 100 | _C.loss.inv_sparsity_loss_activation = 1.0 101 | _C.loss.inv_sparsity_loss_type = "L1" 102 | _C.loss.length_loss_width = 1.0 103 | _C.loss.interpolate_before_sct = True 104 | _C.loss.class_weight = [(0, 0.001)] 105 | 106 | def get_cfg_defaults(): 107 | return _C.clone() 108 | -------------------------------------------------------------------------------- /src/SCT/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import random 3 | import subprocess 4 | from typing import Tuple, Union, Set, List 5 | 6 | import numpy as np 7 | import torch 8 | from torch import Tensor 9 | from torch.utils.data import Sampler 10 | 11 | 12 | def set_seed(seed: int, fully_deterministic: bool = False): 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | if fully_deterministic: 17 | torch.backends.cudnn.deterministic = True 18 | torch.backends.cudnn.benchmark = False 19 | 20 | 21 | 22 | def send_to_device( 23 | items: Union[Tensor, Tuple[Tensor, ...], List[Tensor]], device 24 | ) -> Union[Tensor, Tuple[Tensor, ...], List[Tensor]]: 25 | if type(items) is tuple: 26 | return tuple(map(lambda x: x.to(device), items)) 27 | elif type(items) is list: 28 | return list(map(lambda x: x.to(device), items)) 29 | else: 30 | return items.to(device) 31 | 32 | 33 | def get_git_commit_hash() -> str: 34 | try: 35 | return ( 36 | subprocess.check_output(["git", "rev-parse", "HEAD"]) 37 | .decode("utf-8") 38 | .strip() 39 | ) 40 | except subprocess.CalledProcessError: 41 | # this is probably not in a git repo or git is not installed. 42 | return "" 43 | 44 | 45 | def tensor_to_numpy(x: Tensor) -> np.ndarray: 46 | return x.detach().cpu().numpy() 47 | 48 | 49 | def tensors_to_numpys( 50 | x: Union[Tensor, Tuple[Tensor, ...]] 51 | ) -> Union[np.ndarray, Tuple[np.ndarray, ...]]: 52 | if type(x) is not tuple: 53 | return tensor_to_numpy(x) 54 | else: 55 | return tuple(map(lambda i: tensor_to_numpy(i), x)) 56 | 57 | 58 | def change_multiprocess_strategy(): 59 | torch.multiprocessing.set_sharing_strategy("file_system") 60 | 61 | 62 | def print_with_time(the_thing: str): 63 | print("[{}] {}".format(str(datetime.datetime.now()), the_thing)) 64 | 65 | 66 | class OverfitSampler(Sampler): 67 | # TODO: Mohsen: Better to make different dataset objects for overfitting 68 | # to have right value for len() in tqdm (used in train and evaluate functions) 69 | def __init__(self, main_source, indices, num_iter=0): 70 | """ 71 | 72 | :param main_source: 73 | :param indices: 74 | :param num_iter: 75 | 0: how_many=main_source_len/len(self.indices), 76 | otherwise: how_many=1 77 | """ 78 | super().__init__(main_source) 79 | self.main_source = main_source 80 | self.indices = indices 81 | 82 | if num_iter == 0: 83 | main_source_len = len(self.main_source) 84 | how_many = int(round(main_source_len / len(self.indices))) 85 | else: 86 | how_many = 1 87 | 88 | self.to_iter_from = [] 89 | for _ in range(how_many): 90 | self.to_iter_from.extend(self.indices) 91 | 92 | def __iter__(self): 93 | return iter(self.to_iter_from) 94 | 95 | def __len__(self): 96 | return len(self.main_source) 97 | 98 | 99 | def tensor_to_set(input_tensor: Tensor, thr: float) -> Set[int]: 100 | """ 101 | converts predicted tensor of sets to python built-in set 102 | :param input_tensor: [C] 103 | :param thr: 0.0 <= thr <= 1.0 104 | :return: \set\=M' 105 | """ 106 | s = input_tensor.detach().cpu().numpy() 107 | s = np.argwhere(s >= thr) 108 | return set(s.flatten()) 109 | 110 | 111 | def set_to_tensor(s: set) -> Tensor: 112 | """ 113 | 114 | :param s: 115 | :return: 116 | """ 117 | return torch.Tensor(sorted(list(s))).long() 118 | -------------------------------------------------------------------------------- /src/SCT/models/temporal_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | from typing import List 6 | 7 | 8 | class WaveNetLayer(nn.Module): 9 | def __init__( 10 | self, num_channels: int, kernel_size: int, dilation: int, drop: float = 0.25 11 | ): 12 | super().__init__() 13 | self.num_channels = num_channels 14 | self.kernel_size = kernel_size 15 | self.dilation = dilation 16 | self.dilated_conv = nn.Conv1d( 17 | in_channels=self.num_channels, 18 | out_channels=self.num_channels, 19 | kernel_size=self.kernel_size, 20 | dilation=self.dilation, 21 | padding=self.dilation, 22 | ) 23 | self.conv_1x1 = nn.Conv1d( 24 | in_channels=self.num_channels, out_channels=self.num_channels, kernel_size=1 25 | ) 26 | self.drop = nn.Dropout(drop) 27 | 28 | @staticmethod 29 | def apply_non_lin(y: Tensor) -> Tensor: 30 | return F.relu(y) 31 | 32 | def forward(self, x: Tensor) -> Tensor: 33 | """ 34 | :param x: [B x num_channels x T] 35 | :return: [B x num_channels x T] 36 | """ 37 | y = self.dilated_conv.forward(x) 38 | y = self.apply_non_lin(y) # non-linearity 39 | y = self.conv_1x1.forward(y) 40 | y = self.drop.forward(y) # dropout 41 | y += x # residual connection 42 | return y 43 | 44 | 45 | class WaveNetBlock(nn.Module): 46 | def __init__( 47 | self, 48 | in_channels: int, 49 | stages: List[int] = (1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024), 50 | pooling_levels: List[int] = [1, 2, 4, 8, 16], 51 | output_levels: List[int] = [], 52 | out_dims: int = 64, 53 | kernel_size: int = 3, 54 | pooling=True, 55 | ): 56 | super().__init__() 57 | self.in_channels = in_channels 58 | self.num_stages = len(stages) 59 | self.stages = stages 60 | self.out_dims = out_dims 61 | self.kernel_size = kernel_size 62 | self.layers = [] 63 | self.pooling = pooling 64 | self.pooling_levels = pooling_levels 65 | self.output_levels = output_levels 66 | 67 | self.first_conv = nn.Conv1d( 68 | in_channels=self.in_channels, out_channels=self.out_dims, kernel_size=1 69 | ) 70 | self.last_conv = nn.Conv1d( 71 | in_channels=self.out_dims, out_channels=self.out_dims, kernel_size=1 72 | ) 73 | 74 | for i in range(self.num_stages): 75 | stage = self.stages[i] 76 | layer = WaveNetLayer( 77 | self.out_dims, kernel_size=self.kernel_size, dilation=stage 78 | ) 79 | self.layers.append(layer) 80 | self.add_module("l_{}".format(i), layer) 81 | 82 | def forward(self, x: Tensor) -> List[Tensor]: 83 | """ 84 | :param x: [B x in_channels x T] 85 | :return: [B x out_dims x T] 86 | """ 87 | outputs = [] 88 | x = F.relu(self.first_conv.forward(x)) 89 | # fixme: clean the code 90 | pooling_levels = self.pooling_levels 91 | # if x.shape[2] > 10000: 92 | # pooling_levels = self.pooling_levels 93 | for i, l in enumerate(self.layers): 94 | x = l.forward(x) 95 | if i in pooling_levels and self.pooling: 96 | # fixme: clean the code 97 | # print("\n{}".format(x.shape)) 98 | x = F.max_pool1d(x, kernel_size=2) 99 | # fixme: clean the code 100 | # print(" {}".format(x.shape)) 101 | if i in self.output_levels: 102 | outputs.append(x) 103 | x = F.relu(x) 104 | x = self.last_conv.forward(x) 105 | outputs.append(x) 106 | 107 | return outputs 108 | -------------------------------------------------------------------------------- /src/SCT/models/fu_model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | 7 | # noinspection PyPep8Naming 8 | def project_lengths_softmax(T: int, L: Tensor) -> Tensor: 9 | """ 10 | 11 | :param T: 1:int 12 | :param L: [1 x T']:float 13 | :return: [1 x T']:float 14 | """ 15 | return T * F.softmax(L, dim=0) 16 | 17 | 18 | class Fu(nn.Module): 19 | """ 20 | Abstract class for the fu(A,L): upsamples the action probabilities A w.r.t. 21 | estimated actions' temporal lengths L 22 | """ 23 | 24 | # noinspection PyPep8Naming 25 | def forward(self, A: Tensor, L: Tensor, T: int) -> Tensor: 26 | raise NotImplementedError 27 | 28 | 29 | class TemporalSamplingUpSampler(Fu): 30 | def __init__(self): 31 | super().__init__() 32 | self.temp_width = 100 33 | 34 | # noinspection PyPep8Naming 35 | @staticmethod 36 | def _normalize_location(T: int, pis: Tensor, sis: Tensor) -> Tensor: 37 | """ 38 | Normalizes the absolute value of z_where to the range that is appropriate for the network. 39 | :param T: 40 | :param pis: 41 | :param sis: unnormalized z_size 42 | :return: 43 | """ 44 | x = pis.clone() 45 | x += sis / 2 46 | x -= T / 2 47 | x /= -(sis / 2) 48 | 49 | return x 50 | 51 | @staticmethod 52 | def _create_params_matrix(sis: Tensor, pis: Tensor) -> Tensor: 53 | n = sis.size(0) 54 | theta = sis.new_zeros(torch.Size([n, 3])) 55 | 56 | s = sis.clone() 57 | x = pis.clone() 58 | # y = 0 59 | 60 | theta[:, 0] = s.view(-1) 61 | theta[:, 1] = x.view(-1) 62 | theta[:, 2] = 0 63 | return theta.float() 64 | 65 | @staticmethod 66 | def _create_theta(params: Tensor) -> Tensor: 67 | # Takes 3-dimensional vectors, and massages them into 2x3 matrices with elements like so: 68 | # [s,x,y] -> [[s,0,x], 69 | # [0,s,y]] 70 | n = params.size(0) 71 | expansion_indices = torch.LongTensor([1, 0, 2, 0, 1, 3]).to(params.device) 72 | out = torch.cat((params.new_zeros([1, 1]).expand(n, 1), params), 1) 73 | return torch.index_select(out, 1, expansion_indices).view(n, 2, 3) 74 | 75 | # noinspection PyPep8Naming 76 | @staticmethod 77 | def _normalize_scale(T: int, sis: Tensor) -> Tensor: 78 | return T / sis 79 | 80 | # noinspection PyPep8Naming 81 | def forward(self, A: Tensor, L: Tensor, T: int) -> Tensor: 82 | """ 83 | Given a set of predicted actions probabilities A_{i}s, upsamples them w.r.t. the given projected L_{i}s. 84 | :param L: [K] The projected lengths. 85 | :param A: [K x C] The predicted actions' probabilities. 86 | :return: [1 x C x ~T] Upsampled A_{k}s. 87 | """ 88 | 89 | A = A.squeeze().permute(1, 0) # [K x C] 90 | L = L.squeeze() # [K] 91 | L_prime = project_lengths_softmax(T=T, L=L) 92 | K = A.shape[0] 93 | C = A.shape[1] 94 | l_max = int(L_prime.max() + 0.5) # round to the nearest int 95 | pis = torch.zeros_like(L_prime) # [K] 96 | 97 | normalized_l = self._normalize_scale(l_max, L_prime) 98 | normalized_p = self._normalize_location(l_max, pis, L_prime) 99 | 100 | params_mat = self._create_params_matrix(normalized_l, normalized_p) # [K x 3] 101 | theta = self._create_theta(params_mat) # [K x 2 x 3] 102 | 103 | grid = F.affine_grid(theta, torch.Size((K, C, 1, l_max))) 104 | 105 | temp_A = A.view(K, C, 1, 1).expand(-1, -1, -1, self.temp_width) 106 | upsampled_probs = F.grid_sample(temp_A, grid, mode="bilinear") 107 | upsampled_probs = upsampled_probs.view(K, C, l_max) # [K x C x l_max] 108 | upsampled_cropped = [] 109 | for i, prob in enumerate(upsampled_probs): 110 | prob_cropped = prob[:, 0 : round(L_prime[i].item())] 111 | upsampled_cropped.append(prob_cropped) 112 | 113 | out = torch.cat(upsampled_cropped, dim=1).unsqueeze(dim=0) # [1 x C x ~T] 114 | out = F.interpolate(input=out, size=T) # [1 x C x T] 115 | return out # [1 x C x T] 116 | -------------------------------------------------------------------------------- /src/SCT/models/losses.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | # noinspection PyPep8Naming 7 | from torch.nn.functional import binary_cross_entropy, pad, one_hot, relu, cross_entropy 8 | from torch.nn.functional import adaptive_max_pool1d 9 | 10 | 11 | # noinspection PyPep8Naming 12 | from yacs.config import CfgNode 13 | 14 | from SCT.datasets.general_dataset import BatchItem 15 | from SCT.models.data_classes import ForwardOut, LossOut 16 | 17 | 18 | # OldName: OverlapLoss 19 | # noinspection PyPep8Naming 20 | def RegionLoss(A: Tensor, A_hat: Tensor) -> Tensor: 21 | """ 22 | 23 | :param A: [1 x C x K] 24 | :param A_hat: 25 | :return: 26 | """ 27 | indices = A_hat 28 | A = A.index_select(dim=1, index=indices) # [1 x |A_hat| x K] 29 | A = A.permute([0, 2, 1]) # [1 x K x |A_hat|] 30 | max_classes = adaptive_max_pool1d(input=A, output_size=1).squeeze() # [K] 31 | # max_classes = max_classes.view(-1, 1) # [K x 1] 32 | target = torch.ones_like(max_classes) 33 | loss = binary_cross_entropy(max_classes, target) 34 | return loss 35 | 36 | 37 | # noinspection PyPep8Naming 38 | def SetLoss(A: Tensor, A_hat: Tensor, weight: Tensor) -> Tensor: 39 | """ 40 | 41 | :param A: 42 | :param A_hat: 43 | :param weight: 44 | :return: 45 | """ 46 | indices = A_hat 47 | A = A.index_select(dim=1, index=indices) # [1 x |A_hat| x K] 48 | classes = adaptive_max_pool1d(input=A, output_size=1) # [1 x |A_hat| x 1] 49 | classes = classes.view(-1, 1) # [|A_hat| x 1] 50 | target = classes.new_ones(classes.shape[0]) 51 | weight = weight.index_select(dim=0, index=indices) 52 | loss = binary_cross_entropy(classes.view(-1), target, weight=weight) 53 | return loss 54 | 55 | 56 | # noinspection PyPep8Naming 57 | def TemporalConsistencyLoss(A: Tensor, A_hat: Tensor) -> Tensor: 58 | """ 59 | 60 | :param A: 61 | :param A_hat: 62 | :return: 63 | """ 64 | A = A.index_select(dim=1, index=A_hat) # [1 x |S| x T'] 65 | shifted_right = pad(A, (1, 0), mode="replicate") 66 | shifted_left = pad(A, (0, 1), mode="replicate") 67 | loss = torch.abs(shifted_right - shifted_left).sum() 68 | return loss 69 | 70 | 71 | def kl_div(p: Tensor, q: Tensor) -> Tensor: 72 | """ 73 | args: 74 | :param p: Tensor same size as q 75 | :param q: Tensor same size as p 76 | :returns: kl divergence between the `p` and `q` 77 | """ 78 | 79 | s1 = torch.sum(p * torch.log(p / q)) 80 | s2 = torch.sum((1 - p) * torch.log((1 - p) / (1 - q))) 81 | return s1 + s2 82 | 83 | 84 | # noinspection PyPep8Naming 85 | def InverseSparsityLoss( 86 | A: Tensor, A_hat: Tensor, weight: Tensor, loss_type: str, activation: float = 1.0 87 | ) -> Tensor: 88 | """ 89 | 90 | :param A: [1 x C x K] 91 | :param A_hat: 92 | :param weight: 93 | :param loss_type: kl_div or L1 94 | :param activation: 95 | :return: 96 | """ 97 | target_indices = A_hat 98 | target_A = A.index_select(dim=1, index=target_indices) # [1 x |A_hat| x K] 99 | w = weight.index_select(dim=0, index=target_indices) 100 | 101 | if loss_type == "kl_div": 102 | a = ( 103 | torch.ones( 104 | size=[1, A_hat.shape[0]], 105 | dtype=target_A.dtype, 106 | device=target_A.device, 107 | ) 108 | * activation 109 | ) 110 | loss = (kl_div(target_A.mean(dim=2), a) * w).mean() 111 | elif loss_type == "L1": 112 | loss = (torch.abs(activation - target_A.mean(dim=2)) * w).mean() 113 | 114 | return loss 115 | 116 | 117 | # noinspection PyPep8Naming 118 | def un_normalized_length_regularizer(length_loss_width: float, L: Tensor) -> Tensor: 119 | """ 120 | relu(s - w) + relu(- w - s) 121 | """ 122 | y_right = relu(L - length_loss_width) 123 | y_left = relu(-length_loss_width - L) 124 | 125 | return (y_right + y_left).sum() 126 | 127 | 128 | # noinspection PyPep8Naming 129 | def loss_func( 130 | A: Tensor, 131 | L: Tensor, 132 | V: Tensor, 133 | A_hat: Tensor, 134 | cfg: CfgNode, 135 | weight: Tensor, 136 | ) -> LossOut: 137 | """ 138 | 139 | :param A: 140 | :param L: 141 | :param V: 142 | :param A_hat: 143 | :param cfg: 144 | :param weight: 145 | :return: 146 | """ 147 | 148 | # Regularizer over L -------- 149 | length_loss = ( 150 | un_normalized_length_regularizer(length_loss_width=cfg.length_loss_width, L=L) 151 | * cfg.length_loss_mul 152 | ) 153 | 154 | # Losses over A -------- 155 | set_loss = SetLoss(A=A, A_hat=A_hat, weight=weight) * cfg.set_loss_mul 156 | temporal_consistency_loss = ( 157 | TemporalConsistencyLoss(A=A, A_hat=A_hat) * cfg.temporal_consistency_loss_mul 158 | ) 159 | region_loss = RegionLoss(A=A, A_hat=A_hat) * cfg.region_loss_mul 160 | inv_sparsity_loss = ( 161 | InverseSparsityLoss( 162 | A=A, 163 | A_hat=A_hat, 164 | weight=weight, 165 | loss_type=cfg.inv_sparsity_loss_type, 166 | activation=cfg.inv_sparsity_loss_activation, 167 | ) 168 | * cfg.inv_sparsity_loss_mul 169 | ) 170 | 171 | sct_loss = cross_entropy(V.squeeze(dim=0), A_hat, weight=weight) * cfg.sct_loss_mul 172 | 173 | total_loss = ( 174 | length_loss 175 | + set_loss 176 | + temporal_consistency_loss 177 | + region_loss 178 | + inv_sparsity_loss 179 | + sct_loss 180 | ) 181 | return LossOut( 182 | total_loss=total_loss, 183 | length_loss=length_loss, 184 | set_loss=set_loss, 185 | temporal_consistency_loss=temporal_consistency_loss, 186 | region_loss=region_loss, 187 | inv_sparsity_loss=inv_sparsity_loss, 188 | sct_loss=sct_loss, 189 | ) 190 | -------------------------------------------------------------------------------- /src/SCT/evaluators/general_evaluator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from os import mkdir 3 | from os.path import exists 4 | from typing import List, Dict 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | from yacs.config import CfgNode 11 | import edit_distance 12 | 13 | from SCT.datasets import GeneralDataset 14 | from SCT.datasets.general_dataset import BatchItem 15 | from SCT.datasets.utils import unsummarize_list, summarize_list 16 | from SCT.evaluators.metrics import ( 17 | make_same_size_interpolate, 18 | matching_score, 19 | MoF, 20 | ) 21 | from SCT.models import GeneralModel 22 | from SCT.utils import print_with_time, tensor_to_numpy, OverfitSampler 23 | 24 | 25 | @dataclass 26 | class EvalResult: 27 | mof: float 28 | 29 | 30 | @dataclass 31 | class FinalEvalResult(EvalResult): 32 | viterbi_mof: float 33 | viterbi_mof_rnn: float 34 | 35 | 36 | class GeneralEvaluator(object): 37 | def __init__( 38 | self, 39 | cfg: CfgNode, 40 | model: GeneralModel, 41 | dataset: GeneralDataset, 42 | teacher_forcing: bool = False, 43 | ): 44 | self.cfg = cfg 45 | self.model = model 46 | self.dataset = dataset 47 | self.dataloader = self.create_eval_dataloader( 48 | cfg=self.cfg, dataset=self.dataset 49 | ) 50 | self.device = torch.device(self.cfg.system.device) 51 | self.ignore_classes = cfg.training.evaluators.ignore_classes 52 | 53 | self.predicted_sets = [] 54 | self.target_sets = [] 55 | self.target_segmentations = [] 56 | self.predicted_segmentations = [] 57 | self.masks_prediction_segmentations = [] 58 | self.lengths_prediction_segmentations = [] 59 | 60 | self.target_sentences = [] 61 | self.predicted_sentences = [] 62 | 63 | self.IoU_values = [] 64 | 65 | def convert_seg_to_text(self, seg_pred: torch.Tensor) -> str: 66 | return "\n".join([self.dataset.mapping[x] for x in seg_pred]) 67 | 68 | def log_inference( 69 | self, video_name, predictions, target_segmentation, target_transcript 70 | ): 71 | pred_trn = summarize_list(predictions)[0] 72 | ms = matching_score( 73 | gt_transcript=target_transcript.tolist(), predicted_transcript=pred_trn 74 | ) 75 | if ( 76 | matching_score( 77 | gt_transcript=target_transcript.tolist(), 78 | predicted_transcript=[0, 24, 2, 25, 0], 79 | ) 80 | == 1 81 | ): 82 | print(video_name) 83 | ms *= 100 84 | ms = str(ms) # [:2] 85 | 86 | if not exists("./log/" + ms): 87 | mkdir("./log/" + ms) 88 | 89 | file = open("./log/" + ms + "/" + video_name + ".txt", "w") 90 | text = self.convert_seg_to_text(predictions) 91 | file.write(text) 92 | file.close() 93 | 94 | file = open("./log/" + ms + "/" + video_name + "_target.txt", "w") 95 | text = self.convert_seg_to_text(target_segmentation) 96 | file.write(text) 97 | file.close() 98 | 99 | @staticmethod 100 | def create_eval_dataloader(cfg: CfgNode, dataset: GeneralDataset) -> DataLoader: 101 | """ 102 | 103 | :param cfg: 104 | :param dataset: 105 | :return: 106 | """ 107 | if cfg.training.overfit: 108 | sampler = OverfitSampler( 109 | main_source=dataset, indices=cfg.training.overfit_indices, num_iter=1 110 | ) 111 | 112 | return DataLoader( 113 | dataset, 114 | batch_size=1, 115 | shuffle=False, 116 | sampler=sampler, 117 | num_workers=cfg.system.num_workers, 118 | collate_fn=dataset.collate_fn, 119 | pin_memory=True, 120 | ) 121 | else: 122 | return DataLoader( 123 | dataset, 124 | batch_size=1, 125 | num_workers=cfg.system.num_workers, 126 | collate_fn=dataset.collate_fn, 127 | shuffle=False, 128 | pin_memory=True, 129 | ) 130 | 131 | # noinspection PyPep8Naming 132 | def eval_1_batch(self, batch: BatchItem): 133 | batch.to(self.device) 134 | forward_out = self.model.forward(batch) 135 | 136 | _, Y_pred = forward_out.Y.topk(1, dim=1) 137 | _, A_pred = forward_out.A.topk(1, dim=1) 138 | target_segmentation = tensor_to_numpy(batch.gt_label) 139 | predicted_segmentation = make_same_size_interpolate( 140 | tensor_to_numpy(Y_pred), target=target_segmentation 141 | ).reshape(-1) 142 | masks_prediction_segmentation = make_same_size_interpolate( 143 | tensor_to_numpy(A_pred), target=target_segmentation 144 | ).reshape(-1) 145 | 146 | self.target_segmentations.append(target_segmentation) 147 | self.predicted_segmentations.append(predicted_segmentation) 148 | self.masks_prediction_segmentations.append(masks_prediction_segmentation) 149 | if self.cfg.experiment.log_inference_output: 150 | self.log_inference(batch.video_name, target_segmentation, batch.transcript) 151 | 152 | def reset(self): 153 | self.predicted_sets = [] 154 | self.target_sets = [] 155 | self.target_segmentations = [] 156 | self.predicted_segmentations = [] 157 | self.masks_prediction_segmentations = [] 158 | self.lengths_prediction_segmentations = [] 159 | 160 | self.target_sentences = [] 161 | self.predicted_sentences = [] 162 | 163 | self.IoU_values = [] 164 | 165 | def compute_metrics(self, ignore_classes: List[int] = None) -> EvalResult: 166 | mof = MoF( 167 | predictions=self.predicted_segmentations, 168 | targets=self.target_segmentations, 169 | ignore_ids=ignore_classes, 170 | ) 171 | return EvalResult( 172 | mof=mof, 173 | ) 174 | 175 | def evaluate(self) -> Dict[str, EvalResult]: 176 | print_with_time("Evaluating ...") 177 | self.reset() 178 | self.model.to(self.device) 179 | self.model.eval() 180 | 181 | with torch.no_grad(): 182 | for batch in tqdm(self.dataloader): 183 | self.eval_1_batch(batch) 184 | 185 | result = {"All": self.compute_metrics()} 186 | 187 | if not self.ignore_classes == []: 188 | result["W/O Ignored Classes"] = self.compute_metrics(self.ignore_classes) 189 | 190 | return result 191 | -------------------------------------------------------------------------------- /src/SCT/datasets/general_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from random import randint 4 | from typing import List, Set, Dict 5 | 6 | import numpy as np 7 | import torch 8 | from torch import Tensor 9 | from torch.utils.data import Dataset 10 | 11 | from SCT.utils import send_to_device, set_to_tensor 12 | from .utils import ( 13 | summarize_list, 14 | create_tf_input, 15 | create_tf_target, 16 | create_tf_set_target, 17 | ) 18 | 19 | 20 | @dataclass 21 | class BatchItem: 22 | """ 23 | T: the video length 24 | D: the feat dim 25 | M: number of actions in the video. 26 | """ 27 | 28 | video_name: str 29 | feats: Tensor # [1 x D x T] float 30 | gt_label: Tensor # [T] long 31 | transcript: Tensor # [M] long 32 | action_lengths: Tensor # [M] int 33 | tf_input: Tensor # [M + 1] long: equal to BOS + transcript 34 | tf_target: Tensor # [M + 1] long: equal to transcript + EOS 35 | tf_set_target: Tensor # [C] long 36 | gt_set: Set[int] # |gt_set|=M 37 | gt_set_list: List[Set[int]] # |set(gt_set)|=M 38 | A_hat: Tensor 39 | target_indices_list: List[Tensor] 40 | not_target_indices: Tensor 41 | not_target_indices_list: List[Tensor] 42 | not_target_set: Set[int] # |not_target_set|=|U| - M 43 | not_target_set_list: List[Set[int]] # |not_target_set|=|U| - M 44 | target_indices_mapping: Dict 45 | t_videos: List[int] # lengths of the input videos 46 | T: int # video length 47 | 48 | def to(self, device): 49 | self.feats = send_to_device(self.feats, device) 50 | self.gt_label = send_to_device(self.gt_label, device) 51 | self.transcript = send_to_device(self.transcript, device) 52 | self.action_lengths = send_to_device(self.action_lengths, device) 53 | self.tf_input = send_to_device(self.tf_input, device) 54 | self.tf_target = send_to_device(self.tf_target, device) 55 | self.tf_set_target = send_to_device(self.tf_set_target, device) 56 | self.A_hat = send_to_device(self.A_hat, device) 57 | self.target_indices_list = send_to_device(self.target_indices_list, device) 58 | self.not_target_indices = send_to_device(self.not_target_indices, device) 59 | self.not_target_indices_list = send_to_device( 60 | self.not_target_indices_list, device 61 | ) 62 | 63 | 64 | class GeneralDataset(Dataset): 65 | def __init__( 66 | self, 67 | root: str, 68 | feat_list: str = None, 69 | gt_list: str = None, 70 | mapping_file: str = None, 71 | feat_dim: int = -1, 72 | num_classes: int = -1, 73 | rnd_flip: bool = False, 74 | rnd_cat: bool = False, 75 | rnd_cat_n_vid: int = 0, 76 | ): 77 | """ 78 | feat_list: a file containing the relative path to numpy files with video features in them, separated by new line 79 | video features should have the shape: (feat_dim, n_frames) 80 | gt_list: a file containing the relative path to txt files with framewise labels in them, separated by new line. 81 | mapping_file: a file containing the mapping from integers to gt_labels. 82 | """ 83 | self.root = root 84 | self.feat_list = feat_list 85 | self.gt_list = gt_list 86 | self.mapping_file = mapping_file 87 | self.end_class_id = 0 88 | self.mof_eval_ignore_classes = [] 89 | self.n_classes = num_classes 90 | self.background_class_ids = [0] 91 | # following are defaults, should be set 92 | self.feat_dim = feat_dim 93 | self.convenient_name = None 94 | self.split = -1 95 | self.max_transcript_length = 100 96 | self.rnd_flip = rnd_flip 97 | self.rnd_cat = rnd_cat 98 | self.rnd_cat_n_vid = rnd_cat_n_vid 99 | 100 | if self.feat_list is not None: 101 | with open(self.feat_list) as f: 102 | self.feat_file_paths = [x.strip() for x in f] 103 | else: 104 | self.feat_file_paths = [] 105 | 106 | if self.gt_list is not None: 107 | with open(self.gt_list) as f: 108 | self.gt_file_paths = [x.strip() for x in f] 109 | else: 110 | self.gt_file_paths = [] 111 | 112 | self.mapping = {} 113 | self.inverse_mapping = {} 114 | if self.mapping_file is not None: 115 | with open(self.mapping_file) as f: 116 | the_mapping = [tuple(x.strip().split()) for x in f] 117 | 118 | for (i, l) in the_mapping: 119 | self.mapping[int(i)] = l 120 | self.inverse_mapping[l] = int(i) 121 | 122 | assert len(self.feat_file_paths) == len(self.gt_file_paths) 123 | 124 | @property 125 | def num_classes(self) -> int: 126 | return len(self.mapping) 127 | 128 | def __len__(self) -> int: 129 | return len(self.feat_file_paths) 130 | 131 | def __getitem__(self, item: int, no_features: bool = False) -> BatchItem: 132 | """ 133 | parameters: 134 | no_features: if set to True, this method will *not* actually load_model the features from disk. 135 | This is useful if we want to quickly run some code which doesn't need the features. 136 | """ 137 | feat_file_path = os.path.join(self.root, self.feat_file_paths[item]) 138 | gt_file_path = os.path.join(self.root, self.gt_file_paths[item]) 139 | 140 | if not no_features: 141 | vid_feats = torch.tensor(torch.load(feat_file_path)).float() 142 | else: 143 | vid_feats = torch.tensor([0]) 144 | 145 | # vid_feats.t_() 146 | 147 | with open(gt_file_path) as f: 148 | gt_label_names = [ 149 | x.strip() for x in f.read().split("\n") if len(x.strip()) > 0 150 | ] 151 | 152 | gt_label_ids = [self.inverse_mapping[x] for x in gt_label_names] 153 | weak_label_ids, weak_label_lens = summarize_list(gt_label_ids) 154 | 155 | gt_labels = torch.tensor(gt_label_ids).long() 156 | gt_action_lengths = torch.tensor(weak_label_lens).int() 157 | weak_labels = torch.tensor(weak_label_ids).long() 158 | weak_label_tf = torch.tensor(create_tf_input(weak_label_ids)).long() 159 | weak_label_target = torch.tensor(create_tf_target(weak_label_ids)).long() 160 | set_label_target = torch.tensor( 161 | create_tf_set_target(weak_label_ids, self.n_classes) 162 | ).float() 163 | gt_set_list = [set(weak_label_ids)] 164 | gt_set = set(weak_label_ids) 165 | target_indices_list = [set_to_tensor(gt_set)] 166 | not_target_set_list = [set(range(self.num_classes)).difference(set(gt_set))] 167 | not_target_indices_list = [set_to_tensor(not_target_set_list[0])] 168 | # fixme: make it object oriented 169 | if self.rnd_flip: 170 | if randint(0, 1) == 1: 171 | vid_feats = vid_feats.flip(1) 172 | t_vids = [vid_feats.shape[1]] 173 | 174 | # fixme: make it object oriented 175 | feats = [vid_feats] 176 | if self.rnd_cat: 177 | if randint(0, 1) == 1: 178 | for i in range(self.rnd_cat_n_vid): 179 | item = randint(0, self.__len__() - 1) 180 | feat_file_path = os.path.join(self.root, self.feat_file_paths[item]) 181 | gt_file_path = os.path.join(self.root, self.gt_file_paths[item]) 182 | f = torch.tensor(np.load(feat_file_path)).float().t_() 183 | # fixme: make it object oriented 184 | if self.rnd_flip: 185 | if randint(0, 1) == 1: 186 | f = f.flip(1) 187 | t_vids.append(f.shape[1]) 188 | feats.append(f) 189 | with open(gt_file_path) as f: 190 | gt_label_names = [ 191 | x.strip() 192 | for x in f.read().split("\n") 193 | if len(x.strip()) > 0 194 | ] 195 | gt_label_ids = [self.inverse_mapping[x] for x in gt_label_names] 196 | weak_label_ids, weak_label_lens = summarize_list(gt_label_ids) 197 | gt_set_list.append(set(weak_label_ids)) 198 | target_indices_list.append(set_to_tensor(gt_set_list[-1])) 199 | not_target_set_list.append( 200 | set(range(self.num_classes)).difference(set(weak_label_ids)) 201 | ) 202 | not_target_indices_list.append( 203 | set_to_tensor(not_target_set_list[-1]) 204 | ) 205 | gt_set = gt_set.union(set(weak_label_ids)) 206 | 207 | set_label_target = torch.zeros(self.n_classes) 208 | set_label_target[list(gt_set)] = 1.0 209 | vid_feats = torch.cat(feats, dim=1) 210 | 211 | vid_feats.unsqueeze_(0) 212 | not_target_set = set(range(self.num_classes)).difference(set(gt_set)) 213 | not_target_indices = set_to_tensor(not_target_set) 214 | target_indices = set_to_tensor(gt_set) 215 | target_indices_mapping = { 216 | target_indices[i].item(): i for i in range(0, target_indices.shape[0]) 217 | } 218 | T = vid_feats.shape[2] 219 | 220 | return BatchItem( 221 | video_name=feat_file_path.split("/")[-1][:-4], 222 | feats=vid_feats, 223 | gt_label=gt_labels, 224 | transcript=weak_labels, 225 | action_lengths=gt_action_lengths, 226 | tf_input=weak_label_tf, 227 | tf_target=weak_label_target, 228 | tf_set_target=set_label_target, 229 | gt_set=gt_set, 230 | gt_set_list=gt_set_list, 231 | A_hat=target_indices, 232 | target_indices_list=target_indices_list, 233 | not_target_indices=not_target_indices, 234 | not_target_indices_list=not_target_indices_list, 235 | not_target_set=not_target_set, 236 | not_target_set_list=not_target_set_list, 237 | target_indices_mapping=target_indices_mapping, 238 | t_videos=t_vids, 239 | T=T, 240 | ) 241 | 242 | @staticmethod 243 | def concat_videos(items: List[BatchItem]) -> BatchItem: 244 | output = BatchItem 245 | features = [] 246 | sets = set() 247 | for batch in items: 248 | features.append(batch.feats) 249 | sets = sets.union(batch.gt_set) 250 | num_classes = items[0].tf_set_target.shape[0] 251 | set_target = torch.zeros(num_classes) 252 | set_target[list(sets)] = 1.0 253 | output.feats = torch.cat(features, dim=2) 254 | output.gt_set = sets 255 | output.tf_set_target = set_target 256 | return output 257 | 258 | @staticmethod 259 | def collate_fn(items: List[BatchItem]) -> BatchItem: 260 | """ 261 | We assume batch_size = 1 262 | """ 263 | return items[0] 264 | -------------------------------------------------------------------------------- /src/SCT/experiment/general_experiment.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import sys 3 | from collections import defaultdict, OrderedDict 4 | from pathlib import Path 5 | from typing import Optional, Union, Dict, Iterable, List 6 | 7 | import torch 8 | import torch.optim as optim 9 | from torch.nn import Parameter 10 | from torch.nn.utils import clip_grad_norm_ 11 | from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR 12 | from torch.optim.optimizer import Optimizer 13 | from torch.utils.data import DataLoader 14 | from torch.utils.tensorboard import SummaryWriter 15 | from tqdm import tqdm 16 | from yacs.config import CfgNode 17 | 18 | from SCT.datasets import GeneralDataset 19 | from SCT.datasets.general_dataset import BatchItem 20 | from SCT.evaluators import GeneralEvaluator 21 | from SCT.evaluators.general_evaluator import EvalResult 22 | from SCT.models import GeneralModel 23 | from SCT.models.data_classes import LossOut 24 | from SCT.models.losses import loss_func 25 | from SCT.utils import get_git_commit_hash, print_with_time, OverfitSampler 26 | 27 | Scheduler = Union[ReduceLROnPlateau, MultiStepLR] 28 | 29 | 30 | RUN_INFO_FORMAT = """Time: {time} 31 | Command: {command} 32 | Git hash: {hash} 33 | ----------------------------------------- 34 | {config} 35 | """ 36 | 37 | 38 | class ScalarMetric: 39 | def __init__(self, writer: SummaryWriter, name: str, report_average: bool = True): 40 | self.writer = writer 41 | self.name = name 42 | self.report_average = report_average 43 | self.values = [] 44 | self.average_tag = "training_average/%s" % self.name 45 | 46 | def add_value(self, value: float, step: int, add_to_writer: bool = True): 47 | if add_to_writer: 48 | self.writer.add_scalar(tag=self.name, scalar_value=value, global_step=step) 49 | self.values.append(value) 50 | 51 | def epoch_finished(self, epoch_num): 52 | average_value = self.average_value() 53 | if self.report_average: 54 | self.writer.add_scalar( 55 | tag=self.average_tag, 56 | scalar_value=average_value, 57 | global_step=epoch_num + 1, 58 | ) 59 | print_with_time("%s: %f" % (self.average_tag, average_value)) 60 | self.reset_values() 61 | 62 | def reset_values(self): 63 | self.values.clear() 64 | 65 | def average_value(self) -> float: 66 | return sum(self.values) / len(self.values) 67 | 68 | 69 | def create_optimizer(cfg: CfgNode, parameters: Iterable[Parameter]) -> Optimizer: 70 | learning_rate = cfg.training.learning_rate 71 | momentum = cfg.training.momentum 72 | optimizer_name = cfg.training.optimizer 73 | weight_decay = cfg.training.weight_decay 74 | 75 | if optimizer_name == "SGD": 76 | return optim.SGD( 77 | params=parameters, 78 | lr=learning_rate, 79 | weight_decay=weight_decay, 80 | momentum=momentum, 81 | ) 82 | elif optimizer_name == "Adam": 83 | return optim.Adam( 84 | params=parameters, 85 | lr=learning_rate, 86 | weight_decay=weight_decay, 87 | ) 88 | else: 89 | raise Exception("Invalid optimizer name (%s)" % optimizer_name) 90 | 91 | 92 | def create_scheduler(cfg: CfgNode, optimizer: Optimizer) -> Optional[Scheduler]: 93 | scheduler_name = cfg.training.scheduler.name 94 | if scheduler_name == "none": 95 | return None 96 | elif scheduler_name == "plateau": 97 | return ReduceLROnPlateau( 98 | optimizer, 99 | mode=cfg.training.scheduler.plateau.mode, 100 | factor=cfg.training.scheduler.plateau.factor, 101 | verbose=cfg.training.scheduler.plateau.verbose, 102 | patience=cfg.training.scheduler.plateau.patience, 103 | ) 104 | elif scheduler_name == "step": 105 | steps = cfg.training.scheduler.multi_step.steps 106 | return MultiStepLR(optimizer=optimizer, milestones=steps) 107 | else: 108 | raise Exception("Invalid scheduler name (%s)" % scheduler_name) 109 | 110 | 111 | def create_metrics(cfg: CfgNode, writer: SummaryWriter) -> Dict[str, ScalarMetric]: 112 | metric_names_with_average = GeneralExperiment.metric_names_with_average 113 | metric_names_each_epoch = GeneralExperiment.metric_names_each_epoch_testing 114 | 115 | if cfg.training.evaluators.eval_train: 116 | metric_names_each_epoch.extend( 117 | GeneralExperiment.metric_names_each_epoch_training 118 | ) 119 | 120 | metrics = {} 121 | for mn in metric_names_with_average: 122 | metrics[mn] = ScalarMetric(writer, mn, report_average=True) 123 | 124 | for mn in metric_names_each_epoch: 125 | metrics[mn] = ScalarMetric(writer, mn, report_average=False) 126 | 127 | return metrics 128 | 129 | 130 | def create_train_dataloader(cfg: CfgNode, dataset: GeneralDataset) -> DataLoader: 131 | if cfg.training.overfit: 132 | sampler = OverfitSampler( 133 | main_source=dataset, indices=cfg.training.overfit_indices 134 | ) 135 | 136 | return DataLoader( 137 | dataset, 138 | batch_size=1, 139 | shuffle=False, 140 | sampler=sampler, 141 | num_workers=cfg.system.num_workers, 142 | collate_fn=dataset.collate_fn, 143 | pin_memory=True, 144 | ) 145 | else: 146 | return DataLoader( 147 | dataset, 148 | batch_size=1, 149 | shuffle=True, 150 | num_workers=cfg.system.num_workers, 151 | collate_fn=dataset.collate_fn, 152 | pin_memory=True, 153 | ) 154 | 155 | 156 | class GeneralExperiment(object): 157 | metric_names_with_average = [ 158 | "loss/Total", 159 | "loss/Set", 160 | "loss/Region", 161 | "loss/SCT", 162 | "loss/Temporal_Consistency", 163 | "loss/Length", 164 | "loss/Inverse_Sparsity", 165 | "optimization/learning_rate", 166 | ] 167 | metric_names_each_epoch_testing = [ 168 | "average_testing/mof", 169 | ] 170 | metric_names_each_epoch_training = [ 171 | "average_training/mof", 172 | ] 173 | 174 | model_filename = "model.pkl" 175 | optimizer_filename = "optimizer.pkl" 176 | scheduler_filename = "scheduler.pkl" 177 | 178 | def __init__( 179 | self, 180 | cfg: CfgNode, 181 | dataset: GeneralDataset, 182 | model: GeneralModel, 183 | loss_weights: torch.Tensor, 184 | val_evaluator: Optional[GeneralEvaluator], 185 | train_evaluator: Optional[GeneralEvaluator], 186 | ): 187 | self.cfg = cfg 188 | self.dataset = dataset 189 | self.model = model 190 | self.loss_weights = loss_weights 191 | self.val_evaluator = val_evaluator 192 | self.train_evaluator = train_evaluator 193 | 194 | self.device = torch.device(self.cfg.system.device) 195 | 196 | if not cfg.training.evaluators.ignore_classes == []: 197 | self.metric_names_each_epoch_training.extend( 198 | [ 199 | "average_training/mof_w/o_ignored_cls", 200 | ] 201 | ) 202 | self.metric_names_each_epoch_testing.extend( 203 | [ 204 | "average_testing/mof_w/o_ignored_cls", 205 | ] 206 | ) 207 | 208 | self.experiment_name = ( 209 | Path(self.cfg.experiment.name) 210 | / self.cfg.dataset.name 211 | / str(self.cfg.dataset.split) 212 | ) 213 | 214 | self.experiment_folder = Path(self.cfg.experiment.root) / self.experiment_name 215 | self.experiment_folder.mkdir(exist_ok=True, parents=True) 216 | self.iter_number = 0 217 | self.run_number = self.cfg.experiment.run_number 218 | if self.cfg.experiment.run_number == -1: 219 | self.run_number = self._figure_run_number() 220 | 221 | self.run_folder = self.experiment_folder / str(self.run_number) 222 | 223 | self.tb_folder = ( 224 | Path(self.cfg.experiment.tb_root) 225 | / self.experiment_name 226 | / Path(str(self.run_number)) 227 | ) 228 | self.tb_writer = SummaryWriter(self.tb_folder) 229 | 230 | self.metrics = create_metrics(self.cfg, self.tb_writer) 231 | 232 | self.clip_grad_norm = self.cfg.training.clip_grad_norm 233 | self.clip_grad_norm_value = self.cfg.training.clip_grad_norm_value 234 | 235 | self.epoch_number = 0 236 | if self.cfg.training.only_test and self.cfg.training.resume_from == -1: 237 | self.epoch_number = self._figure_epoch_number() 238 | elif self.cfg.training.only_test and not self.cfg.training.resume_from == -1: 239 | self.epoch_number = self.cfg.training.resume_from 240 | self.iter_num = 0 241 | self.epoch_losses = [] 242 | 243 | self.optimizer = create_optimizer(self.cfg, self.model.parameters()) 244 | self.scheduler = create_scheduler(self.cfg, self.optimizer) 245 | self.scheduler_type_plateau = ( 246 | True if self.cfg.training.scheduler.name == "plateau" else False 247 | ) 248 | self.loss = loss_func 249 | 250 | def _figure_run_number(self) -> int: 251 | # fixme: this is not thread safe! 252 | max_run = 0 253 | for f in self.experiment_folder.iterdir(): 254 | if f.is_dir(): 255 | try: 256 | f = int(str(f.name)) 257 | except ValueError: 258 | continue 259 | if f > max_run: 260 | max_run = f 261 | 262 | if self.cfg.training.only_test: 263 | return max_run 264 | return max_run + 1 265 | 266 | def _figure_epoch_number(self) -> int: 267 | max_epoch = 0 268 | for f in self.run_folder.iterdir(): 269 | if f.is_dir(): 270 | try: 271 | f = int(str(f.name)) 272 | except ValueError: 273 | continue 274 | if f > max_epoch: 275 | max_epoch = f 276 | 277 | if self.cfg.training.only_test: 278 | return max_epoch 279 | return max_epoch 280 | 281 | def generate_run_info(self) -> str: 282 | config_dump = self.cfg.dump() 283 | 284 | return RUN_INFO_FORMAT.format( 285 | time=str(datetime.datetime.now()), 286 | command=" ".join(sys.argv), 287 | hash=get_git_commit_hash(), 288 | config=config_dump, 289 | ) 290 | 291 | def _mark_the_run(self): 292 | self.run_folder.mkdir(exist_ok=True, parents=True) 293 | run_info = self.generate_run_info() 294 | with open(self.run_folder / "info.txt", "w") as f: 295 | f.write(run_info) 296 | 297 | # noinspection PyUnusedLocal 298 | def train_1_batch(self, iter_number: int, batch: BatchItem) -> LossOut: 299 | self.optimizer.zero_grad() 300 | batch.to(self.device) 301 | prediction = self.model.forward(batch) 302 | loss = self.loss( 303 | A=prediction.A, 304 | L=prediction.L, 305 | V=prediction.V, 306 | A_hat=batch.A_hat, 307 | cfg=self.cfg.loss, 308 | weight=self.loss_weights, 309 | ) 310 | loss.total_loss.backward() 311 | if self.clip_grad_norm: 312 | clip_grad_norm_(self.model.parameters(), max_norm=self.clip_grad_norm_value) 313 | # clip_grad_norm_(self.model.fl.parameters(), max_norm=self.clip_grad_norm_value) 314 | self.optimizer.step() 315 | return loss 316 | 317 | def train_1_epoch(self, epoch_number: int, dataloader: DataLoader): 318 | print_with_time("Training epoch %d ...." % (epoch_number + 1)) 319 | self.model.train() 320 | for batch in tqdm(dataloader): 321 | self.on_start_batch(self.iter_num) 322 | 323 | batch_loss = self.train_1_batch(self.iter_num, batch) 324 | self.track_training_metrics(batch, batch_loss, self.iter_num) 325 | 326 | self.on_end_batch(self.iter_num) 327 | self.iter_num += 1 328 | 329 | for n in self.metric_names_with_average: 330 | self.metrics[n].epoch_finished(epoch_number) 331 | 332 | def train(self): 333 | self._mark_the_run() 334 | num_epochs = self.cfg.training.num_epochs 335 | print_with_time(self.cfg.dump()) 336 | print_with_time("Training for run number: {:d}".format(self.run_number)) 337 | epoch_range = range(0, num_epochs) 338 | train_dataloader = create_train_dataloader(self.cfg, self.dataset) 339 | self.model.to(self.device) 340 | self.loss_weights = self.loss_weights.to(self.device) 341 | 342 | self.on_start_training() 343 | for epoch_num in epoch_range: 344 | self.epoch_number = epoch_num 345 | 346 | # resetting metrics 347 | for n, m in self.metrics.items(): 348 | m.reset_values() 349 | 350 | # callback 351 | self.on_start_epoch(epoch_num) 352 | 353 | # train for 1 epoch 354 | # with torch.autograd.set_detect_anomaly(True): 355 | self.train_1_epoch(epoch_num, train_dataloader) 356 | 357 | # save 358 | if (epoch_num + 1) % self.cfg.training.save_every == 0: 359 | self.save() 360 | 361 | # end of epoch evaluations 362 | if self.train_evaluator is not None: 363 | train_eval_result = self.train_evaluator.evaluate() 364 | print_with_time("Evaluation result on train set ...") 365 | print(train_eval_result) 366 | self.update_epoch_metrics_train_eval(train_eval_result, epoch_num) 367 | val_eval_result = self.val_evaluator.evaluate() 368 | print_with_time("Evaluation result on test set ...") 369 | print(val_eval_result) 370 | self.update_epoch_metrics_val_eval(val_eval_result, epoch_num) 371 | 372 | if self.scheduler is not None: 373 | # plateau scheduler 374 | if self.scheduler_type_plateau: 375 | self.scheduler.step( 376 | metrics=self._prepare_plateau_scheduler_input(val_eval_result), 377 | epoch=epoch_num, 378 | ) 379 | # step scheduler 380 | else: 381 | self.scheduler.step() 382 | 383 | # callback 384 | self.on_end_epoch(epoch_num) 385 | 386 | @staticmethod 387 | def _prepare_plateau_scheduler_input(eval_result: EvalResult) -> float: 388 | return eval_result["All"].mof 389 | 390 | def current_lr(self): 391 | if self.scheduler is not None: 392 | return list(self.scheduler.optimizer.param_groups)[0]["lr"] 393 | else: 394 | return self.optimizer.defaults["lr"] 395 | 396 | def on_start_epoch(self, epoch_num: int): 397 | 398 | print_with_time("Epoch {}, LR: {}".format(epoch_num + 1, self.current_lr())) 399 | 400 | def on_end_epoch(self, epoch_num: int): 401 | pass 402 | 403 | def on_start_batch(self, iter_num: int): 404 | pass 405 | 406 | def on_end_batch(self, iter_num: int): 407 | pass 408 | 409 | def on_start_training(self): 410 | pass 411 | 412 | def on_end_training(self): 413 | pass 414 | 415 | def init_from_pretrain(self): 416 | print_with_time("Initializing from pretrained weights...") 417 | model_file = self.cfg.training.pretrained_weight 418 | self.load_model(model_file, self.cfg.training.skip_modules) 419 | 420 | def resume(self): 421 | print_with_time("Resuming the experiment...") 422 | # TODO 423 | raise NotImplementedError("I am lazy!") 424 | 425 | def load_model_for_test(self): 426 | epoch_folder = self.run_folder / str(self.epoch_number) 427 | model_file = epoch_folder / self.model_filename 428 | self.load_model(model_file) 429 | 430 | def load_model(self, model_file: str, skip_modules: List[str] = []): 431 | print_with_time("Loading Model: {}".format(model_file)) 432 | input_model_dict = torch.load(model_file, map_location=torch.device("cpu")) 433 | filtered_model_dict = OrderedDict() 434 | for key, val in input_model_dict.items(): 435 | if key.split(".")[0] not in skip_modules: 436 | filtered_model_dict[key] = val 437 | else: 438 | print("Skipping: {}".format(key)) 439 | 440 | self.model.load_state_dict(filtered_model_dict, strict=False) 441 | 442 | def load_optimizer(self): 443 | epoch_folder = self.run_folder / str(self.epoch_number) 444 | optimizer_file = epoch_folder / self.model_filename 445 | print_with_time("Loading Optimizer: {}".format(optimizer_file)) 446 | self.model.load_state_dict(torch.load(optimizer_file)) 447 | 448 | def load_scheduler(self): 449 | epoch_folder = self.run_folder / str(self.epoch_number) 450 | scheduler_file = epoch_folder / self.model_filename 451 | print_with_time("Loading Scheduler: {}".format(scheduler_file)) 452 | self.model.load_state_dict(torch.load(scheduler_file)) 453 | 454 | def save(self): 455 | epoch_folder = self.run_folder / str(self.epoch_number + 1) 456 | epoch_folder.mkdir(exist_ok=True, parents=True) 457 | 458 | model_file = epoch_folder / self.model_filename 459 | optimizer_file = epoch_folder / self.optimizer_filename 460 | scheduler_file = epoch_folder / self.scheduler_filename 461 | 462 | print_with_time("Saving model ...") 463 | torch.save(self.model.cpu().state_dict(), model_file) 464 | print_with_time("Saving Optimizer ...") 465 | torch.save(self.optimizer.state_dict(), optimizer_file) 466 | if self.scheduler is not None: 467 | print_with_time("Saving Scheduler ...") 468 | torch.save(self.scheduler, scheduler_file) 469 | 470 | def update_epoch_metrics_train_eval( 471 | self, train_eval_result: EvalResult, epoch_num: int 472 | ): 473 | names = self.metric_names_each_epoch_training 474 | values = [ 475 | train_eval_result["All"].mof, 476 | ] 477 | if not self.cfg.training.evaluators.ignore_classes == []: 478 | values.extend( 479 | [ 480 | train_eval_result["W/O Ignored Classes"].mof, 481 | ] 482 | ) 483 | 484 | for n, v in zip(names, values): 485 | self.metrics[n].add_value(v, step=epoch_num + 1) 486 | 487 | def update_epoch_metrics_val_eval( 488 | self, val_eval_result: EvalResult, epoch_num: int 489 | ): 490 | names = self.metric_names_each_epoch_testing 491 | values = [ 492 | val_eval_result["All"].mof, 493 | ] 494 | if not self.cfg.training.evaluators.ignore_classes == []: 495 | values.extend( 496 | [ 497 | val_eval_result["W/O Ignored Classes"].mof, 498 | ] 499 | ) 500 | 501 | for n, v in zip(names, values): 502 | self.metrics[n].add_value(v, step=epoch_num + 1) 503 | 504 | # noinspection PyUnusedLocal 505 | def track_training_metrics( 506 | self, batch: BatchItem, batch_loss: LossOut, iter_num: int 507 | ): 508 | metric_names = self.metric_names_with_average 509 | values = [ 510 | batch_loss.total_loss.item(), 511 | batch_loss.set_loss.item(), 512 | batch_loss.region_loss.item(), 513 | batch_loss.sct_loss.item(), 514 | batch_loss.temporal_consistency_loss.item(), 515 | batch_loss.length_loss.item(), 516 | batch_loss.inv_sparsity_loss.item(), 517 | self.current_lr(), 518 | ] 519 | add_to_writer = self.cfg.experiment.track_training_metrics_per_iter 520 | for n, v in zip(metric_names, values): 521 | self.metrics[n].add_value(v, step=iter_num, add_to_writer=add_to_writer) 522 | --------------------------------------------------------------------------------