├── run_scripts ├── eval_50salads_sample_script.sh └── train_50salads_sample_script.sh ├── LICENSE ├── .gitignore ├── README.md ├── postprocess.py ├── testtime_postprocess.py ├── testtime_dataset.py ├── model.py ├── utils.py ├── dataset.py ├── eval.py └── train.py /run_scripts/eval_50salads_sample_script.sh: -------------------------------------------------------------------------------- 1 | python eval.py --dataset_name 50salads --cudad 3 --base_dir base_data_dir -------------------------------------------------------------------------------- /run_scripts/train_50salads_sample_script.sh: -------------------------------------------------------------------------------- 1 | python train.py --dataset_name 50salads --split 1 --cudad 3 --base_dir base_data_dir 2 | python train.py --dataset_name 50salads --split 2 --cudad 3 --base_dir base_data_dir 3 | python train.py --dataset_name 50salads --split 3 --cudad 3 --base_dir base_data_dir 4 | python train.py --dataset_name 50salads --split 4 --cudad 3 --base_dir base_data_dir 5 | python train.py --dataset_name 50salads --split 5 --cudad 3 --base_dir base_data_dir 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 dipika singhania 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # C2F-TCN: Coarse to Fine Multi-Resolution Temporal Convolutional Network for Temporal Action Segmentation
2 | 
3 | Official implementation of Coarse to Fine Multi-Resolution Temporal Convolutional Network for Temporal Action Segmentation [link](https://arxiv.org/pdf/2105.10859.pdf)
4 | 
5 | Code for the fully-supervised version of ‘C2F-TCN: A Framework for Semi- and Fully-Supervised Temporal Action Segmentation’ [link](https://ieeexplore.ieee.org/abstract/document/10147035) published in TPAMI-2023.
6 | 
7 | Code for the semi-supervised version of the same is available at [link](https://github.com/dipika-singhania/ICC-Semi-Supervised-TAS).
8 | 
9 | 
10 | 
11 | ### Data download and directory structure:
12 | 
13 | The I3D features, ground-truth and test split files are the same as those used in [MSTCN++](https://github.com/yabufarha/ms-tcn).
14 | The additional files for mstcn_data, the checkpoints and the semi-supervised splits can be downloaded from [drive](https://drive.google.com/drive/folders/1ArYPctLZZKfjicEf5nl4LJrY9xxFc6wU?usp=sharing).
15 | Specifically, this drive link contains all necessary data in the required directory structure, except the breakfast I3D feature files, which can be downloaded from the MSTCN++ data directory.
16 | It also contains the checkpoint files for supervised C2FTCN.
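The label mapping is loaded from a mapping.csv file inside the dataset directory (label_id_csv = base_dir + "mapping.csv" in train.py and eval.py) and is expected to contain label_id and label_name columns, where the names match the per-frame labels in the groundTruth files. If you only have the space-separated mapping.txt shipped with the MSTCN++ data, a small conversion sketch such as the one below can generate it; the 50salads paths are illustrative placeholders.

```python
import pandas as pd

# Convert an MS-TCN style mapping.txt ("<id> <name>" per line) into the
# mapping.csv with label_id / label_name columns that dataset.py and
# postprocess.py expect. Paths below are examples only.
rows = []
with open("../mstcn_data/50salads/mapping.txt") as f:
    for line in f:
        if line.strip():
            label_id, label_name = line.strip().split(maxsplit=1)
            rows.append({"label_id": int(label_id), "label_name": label_name})

pd.DataFrame(rows).to_csv("../mstcn_data/50salads/mapping.csv", index=False)
```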
17 | 
18 | The data directory is arranged in the following structure:
19 | 
20 | - mstcn_data
21 |   - mapping.csv
22 |   - dataset_name
23 |     - groundTruth
24 |     - splits
25 |     - results
26 |       - supervised_C2FTCN
27 |         - split1
28 |           - check_pointfile
29 |         - split2
30 |         - 
31 | 
32 | ### Run Scripts
33 | Example commands for supervised training and for evaluation with or without test-time augmentation are given below.
34 | Change the dataset_name to run on a different dataset. A minimal programmatic inference sketch is also given at the end of this README.
35 | 
36 | #### Training C2FTCN for a particular split of a dataset
37 | ##### `python train.py --dataset_name <dataset_name> --cudad <device_number> --base_dir <data_base_dir> --split <split_number>`
38 | Example:
39 | python train.py --dataset_name 50salads --cudad 1 --base_dir ../mstcn_data/50salads/ --split 5
40 | 
41 | 
42 | #### Evaluate C2FTCN without test time augmentation, showing average results over all splits of the dataset
43 | ##### `python eval.py --dataset_name <dataset_name> --cudad <device_number> --base_dir <data_base_dir> --compile_result`
44 | Example:
45 | python eval.py --dataset_name 50salads --cudad 2 --base_dir ../mstcn_data/50salads/ --compile_result
46 | 
47 | #### Evaluate C2FTCN with test time augmentation, showing average results over all splits of the dataset
48 | ##### `python eval.py --dataset_name <dataset_name> --cudad <device_number> --base_dir <data_base_dir>`
49 | Example:
50 | python eval.py --dataset_name 50salads --cudad 2 --base_dir ../mstcn_data/50salads/
51 | 
52 | 
53 | 
54 | ### Citation:
55 | 
56 | If you use the code, please cite
57 | 
58 | D. Singhania, R. Rahaman and A. Yao, "C2F-TCN: A Framework for Semi- and Fully-Supervised Temporal Action Segmentation," in IEEE Transactions on Pattern Analysis and Machine Intelligence, doi: 10.1109/TPAMI.2023.3284080.
59 | 
60 | Singhania, D., Rahaman, R., & Yao, A. (2022, June). Iterative contrast-classify for semi-supervised temporal action segmentation. In Proceedings of the AAAI Conference on Artificial Intelligence (Vol. 36, No. 2, pp. 2262-2270).
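
### Minimal programmatic inference sketch

For quick programmatic use outside the provided scripts (referenced above in the Run Scripts section), the sketch below loads a trained split checkpoint and reproduces the coarse-to-fine ensemble used in train.py and eval.py. The feature file name, checkpoint path, chunk size, feature dimension (2048-d I3D) and number of classes (19) are illustrative 50salads placeholders, not fixed values; F.interpolate is used in place of the repository's deprecated F.upsample call and behaves identically here.

```python
import numpy as np
import torch
import torch.nn.functional as F

from model import C2F_TCN

# Illustrative 50salads defaults; all paths are placeholders.
features_path = "../mstcn_data/50salads/features/rgb-01-1.npy"
checkpoint = "../mstcn_data/50salads/results/supervised_C2FTCN/split1/best_50salads_unet.wt"
ensem_weights = [1, 1, 1, 1, 0, 0]   # default ensemble weights in train.py / eval.py
chunk_size = 20                      # default 50salads chunk size

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = C2F_TCN(n_channels=2048, n_classes=19).to(device)
model.load_state_dict(torch.load(checkpoint, map_location=device))
model.eval()

feats = np.load(features_path)       # (2048, num_frames) I3D features
# Max-pool features over non-overlapping chunks, as the dataloaders do.
pooled = [feats[:, i:i + chunk_size].max(axis=-1)
          for i in range(0, feats.shape[1], chunk_size)]
x = torch.tensor(np.stack(pooled, axis=1), dtype=torch.float32).unsqueeze(0).to(device)

with torch.no_grad():
    y, coarser_outputs, _ = model(x)  # finest logits plus coarser decoder logits
    probs = F.softmax(y, dim=1) * ensem_weights[0] / sum(ensem_weights)
    for w, out in zip(ensem_weights[1:], coarser_outputs):
        up = F.interpolate(out, size=y.shape[-1], mode="linear", align_corners=True)
        probs = probs + F.softmax(up, dim=1) * w / sum(ensem_weights)

# Chunk-level predictions; each predicted label covers chunk_size original frames.
pred = probs.argmax(dim=1).squeeze(0).cpu().numpy()
print(pred.shape, pred[:10])
```

For the reported numbers, use eval.py, which additionally performs the multi-chunk test-time augmentation and dumps per-video predictions used to compute the MoF, edit and F1 metrics.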
61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /postprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pandas as pd 4 | import os 5 | import numpy as np 6 | 7 | class PostProcess(nn.Module): 8 | def __init__(self, args): 9 | super().__init__() 10 | df_labels = pd.read_csv(args.label_id_csv) 11 | 12 | self.labels_dict_id2name = {} 13 | self.labels_dict_name2id = {} 14 | for i, val in df_labels.iterrows(): 15 | self.labels_dict_id2name[val.label_id] = val.label_name 16 | self.labels_dict_name2id[val.label_name] = val.label_id 17 | 18 | self.ignore_label = args.num_class 19 | self.results_dict = dict() 20 | self.threshold = args.iou_threshold 21 | self.chunk_size = args.chunk_size 22 | self.gd_path = args.ground_truth_files_dir 23 | self.results_json = None 24 | self.count = 0 25 | 26 | def start(self): 27 | self.results_dict = dict() 28 | self.count = 0 29 | 30 | def dump_to_directory(self, path): 31 | print("Number of cats =", self.count) 32 | if len(self.results_dict.items()) == 0: 33 | return 34 | for video_id, video_value in self.results_dict.items(): 35 | pred_value = video_value[0].detach().cpu().numpy() 36 | label_count = video_value[1].detach().cpu().numpy() 37 | label_send = video_value[2].detach().cpu().numpy() 38 | 39 | video_path = os.path.join(self.gd_path, video_id + ".txt") 40 | with open(video_path, 'r') as f: 41 | recog_content = f.read().split('\n')[0:-1] # framelevel recognition is in 6-th line of file 42 | f.close() 43 | 44 | recog_content = np.array([self.labels_dict_name2id[e] for e in recog_content]) 45 | 46 | label_name_arr = [self.labels_dict_id2name[i.item()] for i in pred_value[:label_count.item()]] 47 | new_label_name_expanded = [] # np.empty(len(recog_content), dtype=np.object_) 48 | for i, ele in enumerate(label_name_arr): 49 | st = i * self.chunk_size 50 | end = st + self.chunk_size 51 | if end > len(recog_content): 52 | end = len(recog_content) 53 | for j in range(st, end): 54 | new_label_name_expanded.append(ele) 55 | if len(new_label_name_expanded) >= len(recog_content): 56 | break 57 | 58 | out_path = os.path.join(path, video_id + ".txt") 59 | with open(out_path, "w") as fp: 60 | fp.write("\n".join(new_label_name_expanded)) 61 | fp.write("\n") 62 | 63 | @torch.no_grad() 64 | def forward(self, outputs, video_names, framewise_labels, counts): 65 | """ Perform the computation 66 | Parameters: 67 | :param outputs: raw outputs of the model 68 | :param start_frame: 69 | :param video_names: 70 | :param clip_length: 71 | """ 72 | for output, vn, framewise_label, count in zip(outputs, video_names, framewise_labels, counts): 73 | output_video = torch.argmax(output, 0) 74 | if vn in self.results_dict: 75 | self.count += 1 76 | 77 | prev_tensor, prev_count, prev_gt_labels = self.results_dict[vn] 78 | output_video = torch.cat([prev_tensor, output_video]) 79 | framewise_label = torch.cat([prev_gt_labels, framewise_label]) 80 | count = count + prev_count 81 | 82 | self.results_dict[vn] = [output_video, count, framewise_label] 83 | -------------------------------------------------------------------------------- /testtime_postprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import numpy as np 5 | import random 6 | import torch.nn.functional as F 7 | import pandas as pd 8 | from torchvision import transforms 9 | import os 10 | 
import torch.nn as nn 11 | 12 | class PostProcess(nn.Module): 13 | def __init__(self, args, weights): 14 | super().__init__() 15 | df_labels = pd.read_csv(args.label_id_csv) 16 | 17 | self.labels_dict_id2name = {} 18 | self.labels_dict_name2id = {} 19 | for i, val in df_labels.iterrows(): 20 | self.labels_dict_id2name[val.label_id] = val.label_name 21 | self.labels_dict_name2id[val.label_name] = val.label_id 22 | 23 | self.results_dict = dict() 24 | self.gd_path = args.ground_truth_files_dir 25 | self.results_json = None 26 | self.count = 0 27 | self.acc_dict = dict() 28 | self.weights = weights 29 | 30 | 31 | def start(self): 32 | self.results_dict = dict() 33 | self.count = 0 34 | 35 | def get_acc_dict(self): 36 | return self.acc_dict 37 | 38 | def upsample_video_value(self, predictions, video_len, chunk_size): 39 | new_label_name_expanded = [] 40 | prediction_swap = predictions.permute(1, 0) 41 | for i, ele in enumerate(prediction_swap): 42 | st = i * chunk_size 43 | end = st + chunk_size 44 | for j in range(st, end): 45 | new_label_name_expanded.append(ele) 46 | out_p = torch.stack(new_label_name_expanded).permute(1, 0)[:, :video_len] 47 | return out_p 48 | 49 | def accumulate_result(self, all_pred_value): 50 | sum_ac = 0 51 | for wt, pred_v in zip(self.weights, all_pred_value): 52 | sum_ac = sum_ac + (wt * pred_v) 53 | 54 | return torch.argmax(sum_ac/ sum(self.weights) , dim=0) 55 | 56 | def dump_to_directory(self, path): 57 | 58 | print("Number of cats =", self.count) 59 | if len(self.results_dict.items()) == 0: 60 | return 61 | prev_vid_id = None 62 | all_pred_value = None 63 | ne_dict = {} 64 | video_id = None 65 | for video_chunk_id, video_value in self.results_dict.items(): 66 | video_id, chunk_id = video_chunk_id.split("@")[0], video_chunk_id.split("@")[1] 67 | upped_pred_logit = self.upsample_video_value(video_value[0][:, :video_value[1]], 68 | video_value[4], video_value[3]).unsqueeze(0) 69 | if video_id == prev_vid_id: 70 | all_pred_value = torch.cat([all_pred_value, upped_pred_logit], dim=0) 71 | else: 72 | if all_pred_value is not None: 73 | ne_dict[prev_vid_id] = self.accumulate_result(all_pred_value) 74 | all_pred_value = None 75 | prev_vid_id = video_id 76 | all_pred_value = upped_pred_logit # With refinement softmax has to be added 77 | 78 | if all_pred_value is not None: 79 | ne_dict[video_id] = self.accumulate_result(all_pred_value) 80 | 81 | for video_id, video_value in ne_dict.items(): 82 | pred_value = video_value.detach().cpu().numpy() 83 | label_name_arr = [self.labels_dict_id2name[i.item()] for i in pred_value] 84 | 85 | out_path = os.path.join(path, video_id + ".txt") 86 | with open(out_path, "w") as fp: 87 | fp.write("\n".join(label_name_arr)) 88 | fp.write("\n") 89 | 90 | @torch.no_grad() 91 | def forward(self, outputs, video_names, framewise_labels, counts, chunk_size_arr, chunk_id_arr, vid_len_arr): 92 | for output, vn, framewise_label, count, chunk_size, chunk_id, vid_len in zip(outputs, video_names, framewise_labels, 93 | counts, chunk_size_arr, chunk_id_arr, vid_len_arr): 94 | # output_video = torch.argmax(output, 0) 95 | 96 | key = '{}@{}'.format(vn, chunk_id) 97 | 98 | if key in self.results_dict: 99 | self.count += 1 100 | 101 | prev_tensor, prev_count, prev_gt_labels, chunk_size, vid_len = self.results_dict[key] 102 | output = torch.cat([prev_tensor, output], dim=1) 103 | framewise_label = torch.cat([prev_gt_labels, framewise_label]) 104 | count = count + prev_count 105 | 106 | self.results_dict[key] = [output, count, framewise_label, chunk_size, 
vid_len] 107 | 108 | -------------------------------------------------------------------------------- /testtime_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import ast 4 | import numpy as np 5 | import h5py 6 | from torchvision import transforms 7 | import os 8 | from PIL import Image 9 | from collections import defaultdict 10 | from itertools import chain as chain 11 | import random 12 | 13 | 14 | def collate_fn_override(data): 15 | """ 16 | data: 17 | """ 18 | data = list(filter(lambda x: x is not None, data)) 19 | data_arr, count, labels, video_len, start, video_id, labels_present_arr, chunk_size, chunk_id = zip(*data) 20 | return torch.stack(data_arr), torch.tensor(count), torch.stack(labels), torch.tensor(video_len),\ 21 | torch.tensor(start), video_id, torch.stack(labels_present_arr), torch.tensor(chunk_size),\ 22 | torch.tensor(chunk_id) 23 | 24 | 25 | class AugmentDataset(torch.utils.data.Dataset): 26 | def __init__(self, args, fold, fold_file_name, chunk_size): 27 | 28 | self.fold = fold 29 | self.max_frames_per_video = args.max_frames_per_video 30 | self.feature_size = args.feature_size 31 | self.base_dir_name = args.features_file_name 32 | self.frames_format = "{}/{:06d}.jpg" 33 | self.ground_truth_files_dir = args.ground_truth_files_dir 34 | self.void_class = args.num_class 35 | self.num_class = args.num_class 36 | self.args = args 37 | self.chunk_size_arr = chunk_size 38 | self.data = self.make_data_set(fold_file_name) 39 | 40 | 41 | def make_data_set(self, fold_file_name): # Longer Videos 10 -- max_chunk_size # Shorter Videos = min(chunk size) - max 42 | df=pd.read_csv(self.args.label_id_csv) 43 | label_id_to_label_name = {} 44 | label_name_to_label_id_dict = {} 45 | for i, ele in df.iterrows(): 46 | label_id_to_label_name[ele.label_id] = ele.label_name 47 | label_name_to_label_id_dict[ele.label_name] = ele.label_id 48 | 49 | data = open(fold_file_name).read().split("\n")[:-1] 50 | data_arr = [] 51 | num_video_not_found = 0 52 | 53 | for i, video_id in enumerate(data): 54 | video_id = video_id.split(".txt")[0] 55 | filename = os.path.join(self.ground_truth_files_dir, video_id + ".txt") 56 | 57 | with open(filename, 'r') as f: 58 | recog_content = f.read().split('\n')[0:-1] # framelevel recognition is in 6-th line of file 59 | f.close() 60 | 61 | recog_content = [label_name_to_label_id_dict[e] for e in recog_content] 62 | 63 | total_frames = len(recog_content) 64 | 65 | if not os.path.exists(os.path.join(self.base_dir_name, video_id + ".npy")): 66 | print("Not found video with id", os.path.join(self.base_dir_name, video_id + ".npy")) 67 | num_video_not_found += 1 68 | continue 69 | 70 | len_video = len(recog_content) 71 | 72 | chunk_size_arr = self.chunk_size_arr 73 | for i, chunk_size in enumerate(chunk_size_arr): 74 | for st_frame in range(0, len_video, self.max_frames_per_video * chunk_size): 75 | max_end = st_frame + (self.max_frames_per_video * chunk_size) 76 | end_frame = max_end if max_end < len_video else len_video 77 | ele_dict = {'st_frame': st_frame, 'end_frame': end_frame, 'chunk_id': i, 'chunk_size': chunk_size, 78 | 'video_id': video_id, 'tot_frames': (end_frame - st_frame) // chunk_size} 79 | ele_dict["labels"] = np.array(recog_content, dtype=int) 80 | data_arr.append(ele_dict) 81 | print("Number of datapoints logged in {} fold is {}".format(self.fold, len(data_arr))) 82 | return data_arr 83 | 84 | def getitem(self, index): # Try to use this for debugging purpose 85 | 
ele_dict = self.data[index] 86 | count = 0 87 | st_frame = ele_dict['st_frame'] 88 | end_frame = ele_dict['end_frame'] 89 | aug_chunk_size = ele_dict['chunk_size'] 90 | 91 | data_arr = torch.zeros((self.max_frames_per_video, self.feature_size)) 92 | label_arr = torch.ones(self.max_frames_per_video, dtype=torch.long) * -100 93 | 94 | image_path = os.path.join(self.base_dir_name, ele_dict['video_id'] + ".npy") 95 | elements = np.load(image_path) 96 | # elements = np.loadtxt(image_path) 97 | count = 0 98 | labels_present_arr = torch.zeros(self.num_class, dtype=torch.float32) 99 | 100 | for i in range(st_frame, end_frame, aug_chunk_size): 101 | end = min(end_frame, i + aug_chunk_size) 102 | key = elements[:, i: end] 103 | values, counts = np.unique(ele_dict["labels"][i: end], return_counts=True) 104 | label_arr[count] = values[np.argmax(counts)] 105 | labels_present_arr[label_arr[count]] = 1 106 | data_arr[count, :] = torch.tensor(np.max(key, axis=-1), dtype=torch.float32) 107 | count += 1 108 | 109 | return data_arr, count, label_arr, elements.shape[1], st_frame, ele_dict['video_id'], labels_present_arr, \ 110 | aug_chunk_size, ele_dict['chunk_id'] 111 | 112 | def __getitem__(self, index): 113 | return self.getitem(index) 114 | 115 | def __len__(self): 116 | return len(self.data) 117 | 118 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import math 3 | import torch.nn as nn 4 | import torch 5 | from functools import partial 6 | import torchvision.models as mdels 7 | 8 | nonlinearity = partial(F.relu, inplace=True) 9 | 10 | class double_conv(nn.Module): 11 | def __init__(self, in_ch, out_ch): 12 | super(double_conv, self).__init__() 13 | self.conv = nn.Sequential( 14 | nn.Conv1d(in_ch, out_ch, kernel_size=5, padding=2), 15 | nn.BatchNorm1d(out_ch), 16 | nn.ReLU(inplace=True), 17 | nn.Conv1d(out_ch, out_ch, kernel_size=5, padding=2), 18 | nn.BatchNorm1d(out_ch), 19 | nn.ReLU(inplace=True), 20 | ) 21 | 22 | def forward(self, x): 23 | x = self.conv(x) 24 | return x 25 | 26 | 27 | class inconv(nn.Module): 28 | def __init__(self, in_ch, out_ch): 29 | super(inconv, self).__init__() 30 | self.conv = double_conv(in_ch, out_ch) 31 | 32 | def forward(self, x): 33 | x = self.conv(x) 34 | return x 35 | 36 | class outconv(nn.Module): 37 | def __init__(self, in_ch, out_ch): 38 | super(outconv, self).__init__() 39 | self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=1) 40 | 41 | def forward(self, x): 42 | x = self.conv(x) 43 | return x 44 | 45 | class down(nn.Module): 46 | def __init__(self, in_ch, out_ch): 47 | super(down, self).__init__() 48 | self.max_pool_conv = nn.Sequential( 49 | nn.MaxPool1d(2), double_conv(in_ch, out_ch)) 50 | 51 | def forward(self, x): 52 | x = self.max_pool_conv(x) 53 | return x 54 | 55 | class up(nn.Module): 56 | """Upscaling then double conv""" 57 | 58 | def __init__(self, in_channels, out_channels, bilinear=True): 59 | super().__init__() 60 | 61 | if bilinear: 62 | self.up = nn.Upsample( 63 | scale_factor=2, mode="linear", align_corners=True) 64 | else: 65 | self.up = nn.ConvTranspose1d( 66 | in_channels // 2, in_channels // 2, kernel_size=2, stride=2 67 | ) 68 | 69 | self.conv = double_conv(in_channels, out_channels) 70 | 71 | def forward(self, x1, x2): 72 | x1 = self.up(x1) 73 | # input is CHW 74 | diff = torch.tensor([x2.size()[2] - x1.size()[2]]) 75 | 76 | x1 = F.pad(x1, [diff // 2, diff - diff //2]) 77 | x = 
torch.cat([x2, x1], dim=1) 78 | return self.conv(x) 79 | 80 | 81 | class TPPblock(nn.Module): 82 | def __init__(self, in_channels): 83 | super(TPPblock, self).__init__() 84 | self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2) 85 | self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3) 86 | self.pool3 = nn.MaxPool1d(kernel_size=5, stride=5) 87 | self.pool4 = nn.MaxPool1d(kernel_size=6, stride=6) 88 | 89 | self.conv = nn.Conv1d( 90 | in_channels=in_channels, out_channels=1, kernel_size=1, padding=0 91 | ) 92 | 93 | def forward(self, x): 94 | self.in_channels, t = x.size(1), x.size(2) 95 | self.layer1 = F.upsample( 96 | self.conv(self.pool1(x)), size=t, mode="linear", align_corners=True 97 | ) 98 | self.layer2 = F.upsample( 99 | self.conv(self.pool2(x)), size=t, mode="linear", align_corners=True 100 | ) 101 | self.layer3 = F.upsample( 102 | self.conv(self.pool3(x)), size=t, mode="linear", align_corners=True 103 | ) 104 | self.layer4 = F.upsample( 105 | self.conv(self.pool4(x)), size=t, mode="linear", align_corners=True 106 | ) 107 | 108 | out = torch.cat([self.layer1, self.layer2, 109 | self.layer3, self.layer4, x], 1) 110 | 111 | return out 112 | 113 | 114 | class C2F_TCN(nn.Module): 115 | ''' 116 | Features are extracted at the last layer of decoder. 117 | ''' 118 | def __init__(self, n_channels, n_classes): 119 | super(C2F_TCN, self).__init__() 120 | self.inc = inconv(n_channels, 256) 121 | self.down1 = down(256, 256) 122 | self.down2 = down(256, 256) 123 | self.down3 = down(256, 128) 124 | self.down4 = down(128, 128) 125 | self.down5 = down(128, 128) 126 | self.down6 = down(128, 128) 127 | self.up = up(260, 128) 128 | self.outcc0 = outconv(128, n_classes) 129 | self.up0 = up(256, 128) 130 | self.outcc1 = outconv(128, n_classes) 131 | self.up1 = up(256, 128) 132 | self.outcc2 = outconv(128, n_classes) 133 | self.up2 = up(384, 128) 134 | self.outcc3 = outconv(128, n_classes) 135 | self.up3 = up(384, 128) 136 | self.outcc4 = outconv(128, n_classes) 137 | self.up4 = up(384, 128) 138 | self.outcc = outconv(128, n_classes) 139 | self.tpp = TPPblock(128) 140 | self.weights = torch.nn.Parameter(torch.ones(6)) 141 | 142 | def forward(self, x): 143 | x1 = self.inc(x) 144 | x2 = self.down1(x1) 145 | x3 = self.down2(x2) 146 | x4 = self.down3(x3) 147 | x5 = self.down4(x4) 148 | x6 = self.down5(x5) 149 | x7 = self.down6(x6) 150 | # x7 = self.dac(x7) 151 | x7 = self.tpp(x7) 152 | x = self.up(x7, x6) 153 | y1 = self.outcc0(F.relu(x)) 154 | # print("y1.shape=", y1.shape) 155 | x = self.up0(x, x5) 156 | y2 = self.outcc1(F.relu(x)) 157 | # print("y2.shape=", y2.shape) 158 | x = self.up1(x, x4) 159 | y3 = self.outcc2(F.relu(x)) 160 | # print("y3.shape=", y3.shape) 161 | x = self.up2(x, x3) 162 | y4 = self.outcc3(F.relu(x)) 163 | # print("y4.shape=", y4.shape) 164 | x = self.up3(x, x2) 165 | y5 = self.outcc4(F.relu(x)) 166 | # print("y5.shape=", y5.shape) 167 | x = self.up4(x, x1) 168 | y = self.outcc(x) 169 | # print("y.shape=", y.shape) 170 | return y, [y5, y4, y3, y2, y1], x 171 | 172 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | from torch import Tensor 5 | import numpy as np 6 | import torchvision 7 | import glob 8 | import re 9 | 10 | class dotdict(dict): 11 | """dot.notation access to dictionary attributes""" 12 | __getattr__ = dict.get 13 | __setattr__ = dict.__setitem__ 14 | __delattr__ = dict.__delitem__ 15 | 16 | 17 | def 
get_labels_start_end_time(frame_wise_labels, bg_class): 18 | labels = [] 19 | starts = [] 20 | ends = [] 21 | last_label = frame_wise_labels[0] 22 | if frame_wise_labels[0] not in bg_class: 23 | labels.append(frame_wise_labels[0]) 24 | starts.append(0) 25 | for i in range(len(frame_wise_labels)): 26 | if frame_wise_labels[i] != last_label: 27 | if frame_wise_labels[i] not in bg_class: 28 | labels.append(frame_wise_labels[i]) 29 | starts.append(i) 30 | if last_label not in bg_class: 31 | ends.append(i) 32 | last_label = frame_wise_labels[i] 33 | if last_label not in bg_class: 34 | ends.append(i + 1) 35 | return labels, starts, ends 36 | 37 | 38 | def levenstein(p, y, norm): 39 | m_row = len(p) 40 | n_col = len(y) 41 | D = np.zeros([m_row + 1, n_col + 1], np.float) 42 | for i in range(m_row + 1): 43 | D[i, 0] = i 44 | for i in range(n_col + 1): 45 | D[0, i] = i 46 | 47 | for j in range(1, n_col + 1): 48 | for i in range(1, m_row + 1): 49 | if y[j - 1] == p[i - 1]: 50 | D[i, j] = D[i - 1, j - 1] 51 | else: 52 | D[i, j] = min(D[i - 1, j] + 1, 53 | D[i, j - 1] + 1, 54 | D[i - 1, j - 1] + 1) 55 | 56 | if norm: 57 | score = (1 - D[-1, -1] / max(m_row, n_col)) * 100 58 | else: 59 | score = D[-1, -1] 60 | 61 | return score 62 | 63 | 64 | def edit_score(recognized, ground_truth, bg_class): 65 | norm = True 66 | P, _, _ = get_labels_start_end_time(recognized, bg_class) 67 | Y, _, _ = get_labels_start_end_time(ground_truth, bg_class) 68 | return levenstein(P, Y, norm) 69 | 70 | 71 | def f_score(recognized, ground_truth, overlap, bg_class): 72 | p_label, p_start, p_end = get_labels_start_end_time(recognized, bg_class) 73 | y_label, y_start, y_end = get_labels_start_end_time(ground_truth, bg_class) 74 | 75 | tp = 0 76 | fp = 0 77 | 78 | hits = np.zeros(len(y_label)) 79 | 80 | for j in range(len(p_label)): 81 | intersection = np.minimum(p_end[j], y_end) - np.maximum(p_start[j], y_start) 82 | union = np.maximum(p_end[j], y_end) - np.minimum(p_start[j], y_start) 83 | IoU = (1.0*intersection / union)*([p_label[j] == y_label[x] for x in range(len(y_label))]) 84 | # Get the best scoring segment 85 | idx = np.array(IoU).argmax() 86 | 87 | if IoU[idx] >= overlap and not hits[idx]: 88 | tp += 1 89 | hits[idx] = 1 90 | else: 91 | fp += 1 92 | fn = len(y_label) - sum(hits) 93 | return float(tp), float(fp), float(fn) 94 | 95 | 96 | def recog_file(filename, ground_truth_path, overlap, background_class_list): 97 | 98 | # read ground truth 99 | gt_file = ground_truth_path + re.sub('.*/', '/', filename) 100 | with open(gt_file, 'r') as f: 101 | gt_content = f.read().split('\n')[0:-1] 102 | f.close() 103 | # read recognized sequence 104 | with open(filename, 'r') as f: 105 | recog_content = f.read().split('\n')[0:-1] # framelevel recognition is in 6-th line of file 106 | f.close() 107 | 108 | n_frame_correct = 0 109 | for i in range(len(recog_content)): 110 | if recog_content[i] == gt_content[i]: 111 | n_frame_correct += 1 112 | 113 | edit_score_value = edit_score(recog_content, gt_content, background_class_list) 114 | 115 | tp_arr = [] 116 | fp_arr = [] 117 | fn_arr = [] 118 | for s in range(len(overlap)): 119 | tp1, fp1, fn1 = f_score(recog_content, gt_content, overlap[s], background_class_list) 120 | tp_arr.append(tp1) 121 | fp_arr.append(fp1) 122 | fn_arr.append(fn1) 123 | return n_frame_correct, len(recog_content), tp_arr, fp_arr, fn_arr, edit_score_value 124 | 125 | 126 | def calculate_mof(ground_truth_path_name, prediction_path, background_class): 127 | overlap = [.1, .25, .5] 128 | overlap_scores = np.zeros(3) 
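    # Accumulators: frame accuracy is computed over all frames of all videos pooled
    # together, the edit (Levenshtein) score is averaged per video, and tp/fp/fn are
    # accumulated per overlap threshold to produce the segmental F1@10/25/50 scores.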
129 | tp, fp, fn = np.zeros(3), np.zeros(3), np.zeros(3) 130 | edit = 0 131 | n_frames = 0 132 | n_correct = 0 133 | 134 | filelist = glob.glob(prediction_path + '/*txt') 135 | 136 | print('Evaluate %d video files...' % len(filelist)) 137 | if len(filelist) == 0: 138 | return 0, 0, overlap_scores 139 | # loop over all recognition files and evaluate the frame error 140 | for filename in filelist: 141 | correct, frames, tp_arr, fp_arr, fn_arr, edit_score_value = recog_file(filename, ground_truth_path_name, 142 | overlap, background_class) 143 | n_correct += correct 144 | n_frames += frames 145 | edit += edit_score_value 146 | 147 | for i in range(len(overlap)): 148 | tp[i] += tp_arr[i] 149 | fp[i] += fp_arr[i] 150 | fn[i] += fn_arr[i] 151 | 152 | if n_correct == 0 or n_frames == 0: 153 | acc = 0 154 | else: 155 | acc = float(n_correct) * 100.0 / n_frames 156 | 157 | print('frame accuracy: %0.4f' % acc) 158 | final_edit_score = ((1.0 * edit) / len(filelist)) 159 | print('Edit score: %0.4f' % final_edit_score) 160 | 161 | for s in range(len(overlap)): 162 | precision = tp[s] / float(tp[s] + fp[s]) 163 | recall = tp[s] / float(tp[s] + fn[s]) 164 | 165 | f1 = 2.0 * (precision * recall) / (precision + recall) 166 | 167 | f1 = np.nan_to_num(f1) * 100 168 | print('F1@%0.2f: %.4f' % (overlap[s], f1)) 169 | overlap_scores[s] = f1 170 | 171 | return final_edit_score, acc, overlap_scores 172 | 173 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import ast 4 | import numpy as np 5 | import h5py 6 | from torchvision import transforms 7 | import os 8 | from PIL import Image 9 | from collections import defaultdict 10 | from itertools import chain as chain 11 | import random 12 | 13 | def collate_fn_override(data): 14 | """ 15 | data: 16 | """ 17 | data = list(filter(lambda x: x is not None, data)) 18 | data_arr, count, labels, clip_length, start, video_id, labels_present_arr, aug_chunk_size, targets = zip(*data) 19 | 20 | return torch.stack(data_arr), torch.tensor(count), torch.stack(labels), torch.tensor(clip_length),\ 21 | torch.tensor(start), video_id, torch.stack(labels_present_arr), torch.tensor(aug_chunk_size, dtype=torch.int) 22 | 23 | 24 | class AugmentDataset(torch.utils.data.Dataset): 25 | def __init__(self, args, fold, fold_file_name, 26 | zoom_crop=(0.5, 2), smallest_cut=1.0): 27 | 28 | self.fold = fold 29 | self.max_frames_per_video = args.max_frames_per_video 30 | self.feature_size = args.feature_size 31 | self.base_dir_name = args.features_file_name 32 | self.frames_format = "{}/{:06d}.jpg" 33 | self.ground_truth_files_dir = args.ground_truth_files_dir 34 | self.chunk_size = args.chunk_size 35 | self.num_class = args.num_class 36 | self.zoom_crop = zoom_crop 37 | self.smallest_cut = smallest_cut 38 | self.validation = True if fold == 'val' else False 39 | self.args = args 40 | self.data = self.make_data_set(fold_file_name) 41 | 42 | 43 | def make_data_set(self, fold_file_name): 44 | df=pd.read_csv(self.args.label_id_csv) 45 | label_id_to_label_name = {} 46 | label_name_to_label_id_dict = {} 47 | for i, ele in df.iterrows(): 48 | label_id_to_label_name[ele.label_id] = ele.label_name 49 | label_name_to_label_id_dict[ele.label_name] = ele.label_id 50 | 51 | data = open(fold_file_name).read().split("\n")[:-1] 52 | data_arr = [] 53 | num_video_not_found = 0 54 | for i, video_id in enumerate(data): 55 | video_id = 
video_id.split(".txt")[0] 56 | filename = os.path.join(self.ground_truth_files_dir, video_id + ".txt") 57 | 58 | with open(filename, 'r') as f: 59 | recog_content = f.read().split('\n')[0:-1] # framelevel recognition is in 6-th line of file 60 | f.close() 61 | 62 | recog_content = [label_name_to_label_id_dict[e] for e in recog_content] 63 | 64 | total_frames = len(recog_content) 65 | 66 | if not os.path.exists(os.path.join(self.base_dir_name, video_id + ".npy")): 67 | print("Not found video with id", os.path.join(self.base_dir_name, video_id + ".npy")) 68 | num_video_not_found += 1 69 | continue 70 | 71 | start_frame_arr = [] 72 | end_frame_arr = [] 73 | for st in range(0, total_frames, self.max_frames_per_video * self.chunk_size): 74 | start_frame_arr.append(st) 75 | max_end = st + (self.max_frames_per_video * self.chunk_size) 76 | end_frame = max_end if max_end < total_frames else total_frames 77 | end_frame_arr.append(end_frame) 78 | 79 | for st_frame, end_frame in zip(start_frame_arr, end_frame_arr): 80 | 81 | ele_dict = {'st_frame': st_frame, 'end_frame': end_frame, 'video_id': video_id, 82 | 'tot_frames': (end_frame - st_frame)} 83 | 84 | ele_dict["labels"] = np.array(recog_content, dtype=int) 85 | 86 | data_arr.append(ele_dict) 87 | 88 | print("Number of videos logged in {} fold is {}".format(self.fold, len(data_arr))) 89 | print("Number of videos not found in {} fold is {}".format(self.fold, num_video_not_found)) 90 | return data_arr 91 | 92 | def getitem(self, index): # Try to use this for debugging purpose 93 | ele_dict = self.data[index] 94 | count = 0 95 | st_frame = ele_dict['st_frame'] 96 | end_frame = ele_dict['end_frame'] 97 | 98 | data_arr = torch.zeros((self.max_frames_per_video, self.feature_size)) 99 | label_arr = torch.ones(self.max_frames_per_video, dtype=torch.long) * -100 100 | 101 | image_path = os.path.join(self.base_dir_name, ele_dict['video_id'] + ".npy") 102 | elements = np.load(image_path) 103 | 104 | if self.args.feature_size == 256: 105 | elements = elements.T 106 | count = 0 107 | end_frame = min(end_frame, st_frame + (self.max_frames_per_video * self.chunk_size)) 108 | len_video = end_frame - st_frame 109 | num_original_frames = np.ceil(len_video/self.chunk_size) 110 | 111 | if np.random.randint(low=0, high=2)==0 and (not self.validation): 112 | aug_start = np.random.uniform(low=0.0, high=1-self.smallest_cut) 113 | aug_len = np.random.uniform(low=self.smallest_cut, high=1-aug_start) 114 | aug_end = aug_start + aug_len 115 | min_possible_chunk_size = np.ceil(len_video/self.max_frames_per_video) 116 | max_chunk_size = int(1.0*self.chunk_size/self.zoom_crop[0]) 117 | min_chunk_size = max(int(1.0*self.chunk_size/self.zoom_crop[1]), min_possible_chunk_size) 118 | aug_chunk_size = int(np.exp(np.random.uniform(low=np.log(min_chunk_size), high=np.log(max_chunk_size)))) 119 | num_aug_frames = np.ceil(int(aug_len*len_video)/aug_chunk_size) 120 | if num_aug_frames > self.max_frames_per_video: 121 | num_aug_frames = self.max_frames_per_video 122 | aug_chunk_size = int(np.ceil(aug_len*len_video/num_aug_frames)) 123 | 124 | aug_translate = 0 125 | aug_start_frame = st_frame + int(len_video*aug_start) 126 | aug_end_frame = st_frame + int(len_video*aug_end) 127 | else: 128 | aug_translate, aug_start_frame, aug_end_frame, aug_chunk_size = 0, st_frame, end_frame, self.chunk_size 129 | 130 | labels_present_arr = torch.zeros(self.num_class, dtype=torch.float32) 131 | for i in range(aug_start_frame, aug_end_frame, aug_chunk_size): 132 | end = min(aug_end_frame, i + 
aug_chunk_size) 133 | key = elements[:, i: end] 134 | values, counts = np.unique(ele_dict["labels"][i: end], return_counts=True) 135 | label_arr[count] = values[np.argmax(counts)] 136 | labels_present_arr[label_arr[count]] = 1 137 | data_arr[aug_translate+count, :] = torch.tensor(np.max(key, axis=-1), dtype=torch.float32) 138 | count += 1 139 | 140 | return data_arr, count, label_arr, ele_dict['tot_frames'], st_frame, ele_dict['video_id'], \ 141 | labels_present_arr, aug_chunk_size, {"labels": label_arr} 142 | 143 | def __getitem__(self, index): 144 | return self.getitem(index) 145 | 146 | def __len__(self): 147 | return len(self.data) 148 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import warnings 4 | warnings.filterwarnings('ignore') 5 | import argparse 6 | import random 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torchvision 11 | import torchvision.transforms as transforms 12 | from utils import dotdict 13 | from utils import calculate_mof 14 | from testtime_postprocess import PostProcess 15 | import torch.nn.functional as F 16 | from testtime_dataset import AugmentDataset, collate_fn_override 17 | 18 | 19 | 20 | my_parser = argparse.ArgumentParser() 21 | my_parser.add_argument('--dataset_name', type=str, default="breakfast", choices=['breakfast', '50salads', 'gtea']) 22 | my_parser.add_argument('--split', type=int, required=False, help="Comma seperated split number to run evaluation," + \ 23 | "default = 1,2,3,4 for breakfast and gtea, 1,2,3,4,5 for 50salads") 24 | my_parser.add_argument('--cudad', type=str, default='0', help="Cuda device number to run evaluation program in") 25 | my_parser.add_argument('--base_dir', type=str, help="Base directory containing groundTruth, features, splits directory of dataset") 26 | my_parser.add_argument('--chunk_size', type=int, required=False, help="Provide chunk size which as used for training," + \ 27 | "by default it is set for datase") 28 | my_parser.add_argument('--ensem_weights', type=str, required=False, 29 | help='Default = \"1,1,1,1,0,0\", provide in similar format comma-seperated 6 weights values if required to be changed') 30 | my_parser.add_argument('--ft_file', type=str, required=False, help="Provide feature file dir path if default is not base_dir/features") 31 | my_parser.add_argument('--ft_size', type=int, required=False, help="Default = 2048 for the I3D features, change if feature size changes") 32 | my_parser.add_argument('--model_path', type=str, default='model') 33 | my_parser.add_argument('--err_bar', type=int, required=False) 34 | my_parser.add_argument('--compile_result', action='store_true', help="To get results without test time augmentation use --compile_result") 35 | my_parser.add_argument('--num_workers', type=int, default=7, help="Number of workers to be used for data loading") 36 | my_parser.add_argument('--out_dir', required=False, help="Directory where output(checkpoints, logs, results) is to be dumped") 37 | my_parser.add_argument('--model_checkpoint', required=False, help="Checkpoint to pick up the model") 38 | args = my_parser.parse_args() 39 | 40 | 41 | seed = 42 42 | 43 | if args.err_bar: 44 | seed = args.err_bar #np.random.randint(0, 999999) 45 | 46 | if args.model_checkpoint: 47 | split_dir = args.model_checkpoint.split("/") 48 | args.out_dir = "/".join(split_dir[:-2]) 49 | print(f"With model checkpoint 
{args.model_checkpoint}, output directory is {args.out_dir}") 50 | 51 | # Ensure deterministic behavior 52 | def set_seed(): 53 | torch.backends.cudnn.benchmark = False 54 | torch.backends.cudnn.deterministic = True 55 | random.seed(seed) 56 | np.random.seed(seed) 57 | torch.manual_seed(seed) 58 | torch.cuda.manual_seed_all(seed) 59 | torch.cuda.manual_seed(seed) 60 | set_seed() 61 | 62 | # Device configuration 63 | os.environ['CUDA_VISIBLE_DEVICES']=args.cudad 64 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 65 | 66 | 67 | config = dotdict( 68 | epochs=500, 69 | dataset=args.dataset_name, 70 | feature_size=2048, 71 | gamma=0.5, 72 | step_size=500, 73 | model_path=args.model_path, 74 | base_dir =args.base_dir, 75 | aug=1, 76 | lps=0) 77 | 78 | 79 | config.ensem_weights = [1, 1, 1, 1, 0, 0] 80 | 81 | if args.dataset_name == "breakfast": 82 | config.chunk_size = 10 83 | config.max_frames_per_video = 1200 84 | config.learning_rate = 1e-4 85 | config.weight_decay = 3e-3 86 | config.batch_size = 100 87 | config.num_class = 48 88 | config.back_gd = ['SIL'] 89 | config.split = [1, 2, 3, 4] 90 | if not args.compile_result: 91 | config.chunk_size = list(range(7, 16)) 92 | config.weights = np.ones(len(config.chunk_size)) 93 | else: 94 | config.chunk_size = [10] 95 | config.weights = [1] 96 | config.eval_true = True 97 | 98 | elif args.dataset_name == "gtea": 99 | config.chunk_size = 4 100 | config.max_frames_per_video = 600 101 | config.learning_rate = 5e-4 102 | config.weight_decay = 3e-4 103 | config.batch_size = 11 104 | config.num_class = 11 105 | config.back_gd = ['background'] 106 | config.split = [1, 2, 3, 4] 107 | if not args.compile_result: 108 | config.chunk_size = [3, 4, 5] # list(range(20,40)) 109 | config.weights = [1, 3, 1] 110 | else: 111 | config.chunk_size = [4] 112 | config.weights = [1] 113 | 114 | else: # if args.dataset_name == "50salads": 115 | config.chunk_size = 20 116 | config.max_frames_per_video = 960 117 | config.learning_rate = 3e-4 118 | config.weight_decay = 1e-3 119 | config.batch_size = 20 120 | config.num_class = 19 121 | config.back_gd = ['action_start', 'action_end'] 122 | config.split = [1, 2, 3, 4, 5] 123 | if not args.compile_result: 124 | config.chunk_size = list(range(20,40)) 125 | config.weights = np.ones(len(config.chunk_size)) 126 | else: 127 | config.chunk_size = [20] 128 | config.weights = [1] 129 | config.eval_true = True 130 | 131 | 132 | if args.split is not None: 133 | try: 134 | args.split = int(args.split) 135 | config.split = [args.split] 136 | except: 137 | config.split = args.split.split(',') 138 | 139 | config.features_file_name = config.base_dir + "/features/" 140 | config.ground_truth_files_dir = config.base_dir + "/groundTruth/" 141 | config.label_id_csv = config.base_dir + "mapping.csv" 142 | 143 | 144 | def model_pipeline(config): 145 | acc_list = [] 146 | edit_list = [] 147 | f1_10_list = [] 148 | f1_25_list = [] 149 | f1_50_list = [] 150 | for ele in config.split: 151 | config.output_dir = config.base_dir + "results/supervised_C2FTCN/split{}".format(ele) #, onfig.model_path, ele, config.aug) 152 | # if args.wd is not None: 153 | # config.weight_decay = args.wd 154 | # config.output_dir=config.output_dir + "_wd{:.5f}".format(config.weight_decay) 155 | 156 | # if args.lr is not None: 157 | # config.learning_rate = args.lr 158 | # config.output_dir=config.output_dir + "_lr{:.6f}".format(config.learning_rate) 159 | 160 | if args.chunk_size is not None: 161 | config.chunk_size = args.chunk_size 162 | 
config.output_dir = config.output_dir + "_chunk{}".format(config.chunk_size) 163 | 164 | if args.ensem_weights is not None: 165 | config.output_dir = config.output_dir + "_wts{}".format(args.ensem_weights.replace(',', '-')) 166 | config.ensem_weights = list(map(int, args.ensem_weights.split(","))) 167 | print("Weights being used is ", config.ensem_weights) 168 | 169 | config.output_dir = config.output_dir + "/" 170 | if args.out_dir is not None: 171 | config.output_dir = args.out_dir + "/" 172 | 173 | print("printing getting the output from output dir = ", config.output_dir) 174 | config.project_name="{}-split{}".format(config.dataset, ele) 175 | config.test_split_file = config.base_dir + "splits/test.split{}.bundle".format(ele) 176 | # make the model, data, and optimization problem 177 | model, test_loader, postprocessor = make(config) 178 | model.load_state_dict(load_best_model(config)) 179 | prefix = '' 180 | 181 | # model.eval() 182 | 183 | 184 | correct, correct1, total = 0, 0, 0 185 | postprocessor.start() 186 | 187 | with torch.no_grad(): 188 | for i, item in enumerate(test_loader): 189 | samples = item[0].to(device).permute(0,2,1) 190 | count = item[1].to(device) 191 | labels = item[2].to(device) 192 | src_mask = torch.arange(labels.shape[1], device=labels.device)[None, :] < count[:, None] 193 | src_mask = src_mask.to(device) 194 | 195 | outplist = model(samples) 196 | ensembel_out = get_ensemble_out(outplist) 197 | 198 | pred = torch.argmax(ensembel_out, dim=1) 199 | correct += float(torch.sum((pred==labels)*src_mask).item()) 200 | total += float(torch.sum(src_mask).item()) 201 | 202 | # postprocessor(ensembel_out, item[5], labels, count) 203 | # 7 chunk size, 8 is chunk id 204 | postprocessor(ensembel_out, item[5], labels, count, item[7].to(device), item[8], item[3].to(device)) 205 | 206 | print(f'Accuracy: {100.0*correct/total: .2f}') 207 | # Add postprocessing and check the outcomes 208 | path = os.path.join(config.output_dir, prefix + "testtime_augmentation_split{}".format(ele)) 209 | if not os.path.exists(path): 210 | os.mkdir(path) 211 | print(f"Output files will be dumped in {path} directory") 212 | postprocessor.dump_to_directory(path) 213 | 214 | final_edit_score, map_v, overlap_scores = calculate_mof(config.ground_truth_files_dir, path, config.back_gd) 215 | acc_list.append(map_v) 216 | edit_list.append(final_edit_score) 217 | f1_10_list.append(overlap_scores[0]) 218 | f1_25_list.append(overlap_scores[1]) 219 | f1_50_list.append(overlap_scores[2]) 220 | 221 | print("Frame accuracy = ", np.mean(np.array(acc_list))) 222 | print("Edit Scores = ", np.mean(np.array(edit_list))) 223 | print("f1@10 Scores = ", np.mean(np.array(f1_10_list))) 224 | print("f1@25 Scores = ", np.mean(np.array(f1_25_list))) 225 | print("f1@50 Scores = ", np.mean(np.array(f1_50_list))) 226 | 227 | 228 | def load_best_model(config): 229 | if args.model_checkpoint is not None: 230 | print(f"Loading checkpoint from {args.model_checkpoint}") 231 | return torch.load(args.model_checkpoint) 232 | checkpoint_file = config.output_dir + '/best_' + config.dataset + '_unet.wt' 233 | print(f"Loading checkpoint from {checkpoint_file}") 234 | return torch.load(checkpoint_file) 235 | 236 | def load_avgbest_model(config): 237 | if args.model_checkpoint is not None: 238 | return torch.load(args.model_checkpoint) 239 | return torch.load(config.output_dir + '/avgbest_' + config.dataset + '_unet.wt') 240 | 241 | def make(config): 242 | # Make the data 243 | test = get_data(config, train=False) 244 | test_loader = 
make_loader(test, batch_size=config.batch_size, train=False) 245 | 246 | # Make the model 247 | model = get_model(config).to(device) 248 | 249 | num_params = sum([p.numel() for p in model.parameters()]) 250 | print("Number of parameters = ", num_params/1e6, " million") 251 | 252 | # postprocessor declaration 253 | postprocessor = PostProcess(config, config.weights) 254 | postprocessor = postprocessor.to(device) 255 | 256 | return model, test_loader, postprocessor 257 | 258 | 259 | def get_data(args, train=True): 260 | if train is True: 261 | fold='train' 262 | split_file_name = args.train_split_file 263 | else: 264 | fold='val' 265 | split_file_name = args.test_split_file 266 | dataset = AugmentDataset(args, fold=fold, fold_file_name=split_file_name, chunk_size=config.chunk_size) 267 | 268 | return dataset 269 | 270 | 271 | def make_loader(dataset, batch_size, train=True): 272 | def _init_fn(worker_id): 273 | np.random.seed(int(seed)) 274 | loader = torch.utils.data.DataLoader(dataset=dataset, 275 | batch_size=batch_size, 276 | shuffle=train, 277 | pin_memory=True, num_workers=args.num_workers, collate_fn=collate_fn_override, 278 | worker_init_fn=_init_fn) 279 | return loader 280 | 281 | 282 | def get_model(config): 283 | my_module = importlib.import_module(config.model_path) 284 | set_seed() 285 | return my_module.C2F_TCN(config.feature_size, config.num_class) 286 | 287 | 288 | def get_ensemble_out(outp): 289 | 290 | weights = config.ensem_weights 291 | ensemble_prob = F.softmax(outp[0], dim=1) * weights[0] / sum(weights) 292 | 293 | for i, outp_ele in enumerate(outp[1]): 294 | upped_logit = F.upsample(outp_ele, size=outp[0].shape[-1], mode='linear', align_corners=True) 295 | ensemble_prob = ensemble_prob + F.softmax(upped_logit, dim=1) * weights[i + 1] / sum(weights) 296 | 297 | return ensemble_prob 298 | 299 | model = model_pipeline(config) 300 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import warnings 4 | warnings.filterwarnings('ignore') 5 | import argparse 6 | import random 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torchvision 11 | import torchvision.transforms as transforms 12 | from utils import dotdict 13 | from utils import calculate_mof 14 | from postprocess import PostProcess 15 | import torch.nn.functional as F 16 | from dataset import AugmentDataset, collate_fn_override 17 | seed = 42 18 | 19 | my_parser = argparse.ArgumentParser() 20 | my_parser.add_argument('--dataset_name', type=str, default="breakfast", choices=['breakfast', '50salads', 'gtea']) 21 | my_parser.add_argument('--split', type=int, required=True, help="Split number of the dataset") 22 | my_parser.add_argument('--cudad', type=str, default='0', help="Cuda device number to run the program") 23 | my_parser.add_argument('--base_dir', type=str, help="Base directory containing groundTruth, features, splits, results directory of dataset") 24 | my_parser.add_argument('--model_path', type=str, default='model') 25 | my_parser.add_argument('--wd', type=float, required=False, help="Provide weigth decay if you want to change from default") 26 | my_parser.add_argument('--lr', type=float, required=False, help="Provide learning rate if you want to change from default") 27 | my_parser.add_argument('--chunk_size', type=int, required=False, help="Provide chunk size to be used if you want to change from default") 28 | 
my_parser.add_argument('--ensem_weights', type=str, required=False, 29 | help='Default = \"1,1,1,1,0,0\", provide in similar comma-seperated 6 weights values if required to be changed') 30 | my_parser.add_argument('--ft_file', type=str, required=False, help="Provide feature file dir path if default is not base_dir/features") 31 | my_parser.add_argument('--ft_size', type=int, required=False, help="Default=2048 for I3D features, change if feature size changes") 32 | my_parser.add_argument('--err_bar', type=int, required=False) 33 | my_parser.add_argument('--num_workers', type=int, default=7, help="Number of workers to be used for data loading") 34 | my_parser.add_argument('--out_dir', required=False, help="Directory where output(checkpoints, logs, results) is to be dumped") 35 | args = my_parser.parse_args() 36 | 37 | 38 | if args.err_bar: 39 | seed = args.err_bar #np.random.randint(0, 999999) 40 | 41 | # Ensure deterministic behavior 42 | def set_seed(): 43 | torch.backends.cudnn.benchmark = False 44 | torch.backends.cudnn.deterministic = True 45 | random.seed(seed) 46 | np.random.seed(seed) 47 | torch.manual_seed(seed) 48 | torch.cuda.manual_seed_all(seed) 49 | torch.cuda.manual_seed(seed) 50 | set_seed() 51 | 52 | # Device configuration 53 | os.environ['CUDA_VISIBLE_DEVICES']=args.cudad 54 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 55 | 56 | config = dotdict( 57 | epochs = 500, 58 | dataset = args.dataset_name, 59 | feature_size = 2048, 60 | gamma = 0.5, 61 | step_size = 500, 62 | split_number = args.split, 63 | model_path = args.model_path, 64 | base_dir = args.base_dir, 65 | aug=1, 66 | lps=0) 67 | 68 | if args.dataset_name == "breakfast": 69 | config.chunk_size = 10 70 | config.max_frames_per_video = 1200 71 | config.learning_rate = 1e-4 72 | config.weight_decay = 3e-3 73 | config.batch_size = 100 74 | config.num_class = 48 75 | config.back_gd = ['SIL'] 76 | config.ensem_weights = [1, 1, 1, 1, 0, 0] 77 | elif args.dataset_name == "gtea": 78 | config.chunk_size = 4 79 | config.max_frames_per_video = 600 80 | config.learning_rate = 5e-4 81 | config.weight_decay = 3e-4 82 | config.batch_size = 11 83 | config.num_class = 11 84 | config.back_gd = ['background'] 85 | config.ensem_weights = [1, 1, 1, 1, 0, 0] 86 | else: # args.dataset_name == "50salads": 87 | config.chunk_size = 20 88 | config.max_frames_per_video = 960 89 | config.learning_rate = 3e-4 90 | config.weight_decay = 1e-3 91 | config.batch_size = 20 92 | config.num_class = 19 93 | config.back_gd = ['action_start', 'action_end'] 94 | config.ensem_weights = [1, 1, 1, 1, 0, 0] 95 | 96 | config.output_dir = config.base_dir + "results/supervised_C2FTCN/" 97 | if not os.path.exists(config.output_dir): 98 | os.mkdir(config.output_dir) 99 | 100 | config.output_dir = config.output_dir + "split{}".format(config.split_number) 101 | 102 | if args.wd is not None: 103 | config.weight_decay = args.wd 104 | config.output_dir=config.output_dir + "_wd{:.5f}".format(config.weight_decay) 105 | 106 | if args.lr is not None: 107 | config.learning_rate = args.lr 108 | config.output_dir=config.output_dir + "_lr{:.6f}".format(config.learning_rate) 109 | 110 | if args.chunk_size is not None: 111 | config.chunk_size = args.chunk_size 112 | config.output_dir=config.output_dir + "_chunk{}".format(config.chunk_size) 113 | 114 | if args.ensem_weights is not None: 115 | config.output_dir=config.output_dir + "_wts{}".format(args.ensem_weights.replace(',','-')) 116 | config.ensem_weights = list(map(int, 
args.ensem_weights.split(","))) 117 | print("C2F Ensemble Weights being used is ", config.ensem_weights) 118 | 119 | 120 | print("printing in output dir = ", config.output_dir) 121 | config.project_name="{}-split{}".format(config.dataset, config.split_number) 122 | config.train_split_file = config.base_dir + "splits/train.split{}.bundle".format(config.split_number) 123 | config.test_split_file = config.base_dir + "splits/test.split{}.bundle".format(config.split_number) 124 | config.features_file_name = config.base_dir + "/features/" 125 | 126 | if args.ft_file is not None: 127 | config.features_file_name = os.path.join(config.base_dir, args.ft_file) 128 | config.output_dir = config.output_dir + "_ft_file{}".format(args.ft_file) 129 | 130 | if args.ft_size is not None: 131 | config.feature_size = args.ft_size 132 | config.output_dir = config.output_dir + "_ft_size{}".format(args.ft_file) 133 | 134 | config.ground_truth_files_dir = config.base_dir + "/groundTruth/" 135 | config.label_id_csv = config.base_dir + "mapping.csv" 136 | 137 | config.output_dir = config.output_dir + "/" 138 | 139 | if args.out_dir is not None: 140 | config.output_dir = args.out_dir + "/" 141 | 142 | def model_pipeline(config): 143 | if not os.path.exists(config.output_dir): 144 | os.mkdir(config.output_dir) 145 | 146 | # make the model, data, and optimization problem 147 | model, train_loader, test_loader, criterion, optimizer, scheduler, postprocessor = make(config) 148 | 149 | # and use them to train the model 150 | train(model, train_loader, criterion, optimizer, scheduler, config, test_loader, postprocessor) 151 | 152 | # and test its final performance 153 | model.load_state_dict(load_avgbest_model(config)) 154 | acc = test(model, test_loader, criterion, postprocessor, config, config.epochs, 'avg') 155 | 156 | model.load_state_dict(load_best_model(config)) 157 | acc = test(model, test_loader, criterion, postprocessor, config, config.epochs, '') 158 | 159 | return model 160 | 161 | def load_best_model(config): 162 | return torch.load(config.output_dir + '/best_' + config.dataset + '_unet.wt') 163 | 164 | def load_avgbest_model(config): 165 | return torch.load(config.output_dir + '/avgbest_' + config.dataset + '_unet.wt') 166 | 167 | def make(config): 168 | # Make the data 169 | train, test = get_data(config, train=True), get_data(config, train=False) 170 | train_loader = make_loader(train, batch_size=config.batch_size, train=True) 171 | test_loader = make_loader(test, batch_size=config.batch_size, train=False) 172 | 173 | # Make the model 174 | model = get_model(config).to(device) 175 | 176 | num_params = sum([p.numel() for p in model.parameters()]) 177 | print("Number of parameters = ", num_params/1e6, " million") 178 | 179 | # Make the loss and optimizer 180 | criterion = get_criterion(config) 181 | optimizer = torch.optim.Adam( 182 | model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay) 183 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.step_size, gamma=config.gamma) 184 | 185 | # postprocessor declaration 186 | postprocessor = PostProcess(config) 187 | postprocessor = postprocessor.to(device) 188 | 189 | return model, train_loader, test_loader, criterion, optimizer, scheduler, postprocessor 190 | 191 | 192 | class CriterionClass(nn.Module): 193 | def __init__(self, config): 194 | super().__init__() 195 | self.ce = nn.CrossEntropyLoss(ignore_index=-100) # Frame wise cross entropy loss 196 | self.mse = nn.MSELoss(reduction='none') # Migitating transistion loss 
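        # The ensemble output passed to forward() is already softmax-normalised
        # (see get_c2f_ensemble_output), so forward() applies log() before the
        # cross-entropy. The clamped MSE between log-probabilities of adjacent
        # frames is the truncated temporal smoothing (transition) term, weighted
        # by 0.15 and masked so that padded frames do not contribute.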
197 | 
198 |     def forward(self, outp, labels, src_mask, labels_present):
199 |         outp_wo_softmax = torch.log(outp + 1e-10)  # log is necessary because ensemble gives softmax output
200 |         ce_loss = self.ce(outp_wo_softmax, labels)
201 | 
202 |         mse_loss = 0.15 * torch.mean(torch.clamp(self.mse(outp_wo_softmax[:, :, 1:],
203 |                                                           outp_wo_softmax.detach()[:, :, :-1]),
204 |                                                  min=0, max=16) * src_mask[:, :, 1:])
205 | 
206 |         loss = ce_loss + mse_loss
207 |         return {'full_loss': loss, 'ce_loss': ce_loss, 'mse_loss': mse_loss}
208 | 
209 | def get_criterion(config):
210 |     return CriterionClass(config)
211 | 
212 | def get_data(args, train=True):
213 |     if train is True:
214 |         fold = 'train'
215 |         split_file_name = args.train_split_file
216 |     else:
217 |         fold = 'val'
218 |         split_file_name = args.test_split_file
219 | 
220 |     dataset = AugmentDataset(args, fold=fold, fold_file_name=split_file_name, zoom_crop=(0.5, 2))
221 |     return dataset
222 | 
223 | 
224 | def make_loader(dataset, batch_size, train=True):
225 |     def _init_fn(worker_id):
226 |         np.random.seed(int(seed))
227 |     loader = torch.utils.data.DataLoader(dataset=dataset,
228 |                                          batch_size=batch_size,
229 |                                          shuffle=train,
230 |                                          pin_memory=True, num_workers=args.num_workers, collate_fn=collate_fn_override,
231 |                                          worker_init_fn=_init_fn)
232 |     return loader
233 | 
234 | 
235 | def get_model(config):
236 |     my_module = importlib.import_module(config.model_path)
237 |     set_seed()
238 |     return my_module.C2F_TCN(config.feature_size, config.num_class)
239 | 
240 | 
241 | def get_c2f_ensemble_output(outp, weights):
242 |     # Weighted average of softmax probabilities over all decoder resolutions
243 |     ensemble_prob = F.softmax(outp[0], dim=1) * weights[0] / sum(weights)
244 | 
245 |     for i, outp_ele in enumerate(outp[1]):
246 |         upped_logit = F.upsample(outp_ele, size=outp[0].shape[-1], mode='linear', align_corners=True)
247 |         ensemble_prob = ensemble_prob + F.softmax(upped_logit, dim=1) * weights[i + 1] / sum(weights)
248 | 
249 |     return ensemble_prob
250 | 
251 | def train(model, loader, criterion, optimizer, scheduler, config, test_loader, postprocessor):
252 | 
253 |     best_acc = 0
254 |     avg_best_acc = 0
255 |     accs = []
256 | 
257 |     for epoch in range(config.epochs):
258 |         model.train()
259 |         for i, item in enumerate(loader):
260 |             samples = item[0].to(device).permute(0, 2, 1)
261 |             count = item[1].to(device)
262 |             labels = item[2].to(device)
263 |             src_mask = torch.arange(labels.shape[1], device=labels.device)[None, :] < count[:, None]
264 |             src_mask = src_mask.to(device)
265 | 
266 |             src_msk_send = src_mask.to(torch.float32).to(device).unsqueeze(1)
267 | 
268 |             # Forward pass ➡
269 |             outputs_list = model(samples)
270 |             outputs_ensemble = get_c2f_ensemble_output(outputs_list, config.ensem_weights)
271 | 
272 |             loss_dict = criterion(outputs_ensemble, labels, src_msk_send, item[6].to(device))
273 |             loss = loss_dict['full_loss']
274 | 
275 |             # Backward pass ⬅
276 |             optimizer.zero_grad()
277 |             loss.backward()
278 | 
279 |             # Step with optimizer
280 |             optimizer.step()
281 | 
282 |             if i % 10 == 0:
283 |                 print(f"Train loss after {epoch} epochs, {i} iterations is {loss_dict['full_loss']:.3f}")
284 | 
285 |         acc, avg_score = test(model, test_loader, criterion, postprocessor, config, epoch, '')
286 |         if avg_score > avg_best_acc:
287 |             avg_best_acc = avg_score
288 |             torch.save(model.state_dict(), config.output_dir + '/avgbest_' + config.dataset + '_unet.wt')
289 |         if acc > best_acc:
290 |             best_acc = acc
291 |             torch.save(model.state_dict(), config.output_dir + '/best_' + config.dataset + '_unet.wt')
292 | 
293 |         torch.save(model.state_dict(), config.output_dir + '/last_' + config.dataset + '_unet.wt')
294 |         accs.append(acc)
295 |         accs.sort(reverse=True)
296 |         scheduler.step()
297 |         print(f'Validation best accuracies till now -> {" ".join(["%.2f" % item for item in accs[:3]])}')
298 | 
299 | 
300 | def test(model, test_loader, criterion, postprocessors, args, epoch, dump_prefix):
301 |     model.eval()
302 | 
303 |     # Run the model on some test examples
304 |     with torch.no_grad():
305 |         correct, total = 0, 0
306 |         avg_loss = []
307 |         for i, item in enumerate(test_loader):
308 |             samples = item[0].to(device).permute(0, 2, 1)
309 |             count = item[1].to(device)
310 |             labels = item[2].to(device)
311 |             src_mask = torch.arange(labels.shape[1], device=labels.device)[None, :] < count[:, None]
312 |             src_mask = src_mask.to(device)
313 | 
314 |             src_msk_send = src_mask.to(torch.float32).to(device).unsqueeze(1)
315 | 
316 |             # Forward pass ➡
317 |             outputs_list = model(samples)
318 |             outputs_ensemble = get_c2f_ensemble_output(outputs_list, config.ensem_weights)
319 | 
320 |             loss_dict = criterion(outputs_ensemble, labels, src_msk_send, item[6].to(device))
321 |             loss = loss_dict['full_loss']
322 |             avg_loss.append(loss.item())
323 | 
324 |             pred = torch.argmax(outputs_ensemble, dim=1)
325 |             correct += float(torch.sum((pred == labels) * src_mask).item())
326 |             total += float(torch.sum(src_mask).item())
327 |             postprocessors(outputs_ensemble, item[5], labels, count)
328 | 
329 |     # Add postprocessing and check the outcomes
330 |     path = os.path.join(args.output_dir, dump_prefix + "predict_" + args.dataset)
331 |     if not os.path.exists(path):
332 |         os.mkdir(path)
333 |     postprocessors.dump_to_directory(path)
334 |     final_edit_score, map_v, overlap_scores = calculate_mof(args.ground_truth_files_dir, path, config.back_gd)
335 |     postprocessors.start()
336 |     acc = 100.0 * correct / total
337 | 
338 |     print(f"Validation loss = {np.mean(np.array(avg_loss)): .3f}, accuracy of the model after epoch {epoch} = {acc: .3f}%")
339 |     with open(config.output_dir + "/results_file.txt", "a+") as fp:
340 |         fp.write("{:.1f}, {:.1f}, {:.1f}, {:.1f}, {:.1f}\n".format(overlap_scores[0], overlap_scores[1],
341 |                                                                    overlap_scores[2], final_edit_score, map_v))
342 |     if epoch == config.epochs:
343 |         with open(config.output_dir + "/" + dump_prefix + "final_results_file.txt", "a+") as fp:
344 |             fp.write("{:.1f}, {:.1f}, {:.1f}, {:.1f}, {:.1f}\n".format(overlap_scores[0], overlap_scores[1],
345 |                                                                        overlap_scores[2], final_edit_score, map_v))
346 | 
347 | 
348 |     avg_score = (map_v + final_edit_score) / 2
349 |     return map_v, avg_score
350 | 
351 | import time
352 | start_time = time.time()
353 | model = model_pipeline(config)
354 | end_time = time.time()
355 | 
356 | duration = (end_time - start_time) / 60
357 | print(f"Total time taken = {duration:.2f} mins")
358 | 
--------------------------------------------------------------------------------
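
The following is an illustrative, self-contained sketch (not part of the repository) of what get_c2f_ensemble_output in train.py computes: the finest decoder output contributes its softmax probabilities directly, while each coarser output is linearly upsampled to the finest temporal length before its softmax probabilities are added with its normalized weight. The batch size, class count, sequence length, downsampling factors and weight values below are placeholders.

import torch
import torch.nn.functional as F

# Placeholder stand-ins for the model outputs: (finest_logits, [coarser_logits, ...]).
batch, num_classes, T = 2, 19, 1000
finest_logits = torch.randn(batch, num_classes, T)
coarser_logits = [torch.randn(batch, num_classes, T // s) for s in (2, 4, 8)]
weights = [1.0, 1.0, 1.0, 1.0]  # one weight per resolution (placeholder values)

# Weighted average of per-frame class probabilities across resolutions.
ensemble_prob = F.softmax(finest_logits, dim=1) * weights[0] / sum(weights)
for i, coarse in enumerate(coarser_logits):
    upsampled = F.interpolate(coarse, size=T, mode='linear', align_corners=True)
    ensemble_prob = ensemble_prob + F.softmax(upsampled, dim=1) * weights[i + 1] / sum(weights)

print(ensemble_prob.shape)              # torch.Size([2, 19, 1000])
print(ensemble_prob.sum(dim=1).mean())  # ≈ 1.0, since the resolution weights are normalized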
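
Similarly, a minimal sketch (again with placeholder shapes, and assuming the same 0.15 smoothing weight and clamp threshold of 16 as CriterionClass above) of how the training loss combines frame-wise cross entropy on the log of the ensembled probabilities with a truncated-MSE smoothing term between adjacent frames, masked to valid (non-padded) frames:

import torch
import torch.nn as nn

batch, num_classes, T = 2, 19, 1000
probs = torch.softmax(torch.randn(batch, num_classes, T), dim=1)  # ensembled softmax output
labels = torch.randint(0, num_classes, (batch, T))                # frame-wise labels
mask = torch.ones(batch, 1, T)                                    # 1 for valid frames, 0 for padding

log_probs = torch.log(probs + 1e-10)                 # criterion operates on log-probabilities
ce_loss = nn.CrossEntropyLoss(ignore_index=-100)(log_probs, labels)

# Truncated MSE between log-probabilities of adjacent frames discourages over-segmentation.
mse = nn.MSELoss(reduction='none')(log_probs[:, :, 1:], log_probs.detach()[:, :, :-1])
smooth_loss = 0.15 * torch.mean(torch.clamp(mse, min=0, max=16) * mask[:, :, 1:])

loss = ce_loss + smooth_loss
print(loss.item())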