├── README.md
├── LICENSE
├── main.py
├── batch_gen.py
├── eval.py
└── model.py


/README.md:
--------------------------------------------------------------------------------
# MS-TCN: Multi-Stage Temporal Convolutional Network for Action Segmentation
This repository provides a PyTorch implementation of the paper [MS-TCN: Multi-Stage Temporal Convolutional Network for Action Segmentation](https://arxiv.org/pdf/1903.01945.pdf).

An extended version, MS-TCN++, has been published in TPAMI ([link](https://github.com/sj-li/MS-TCN2)).

Tested with:
- PyTorch 0.4.1
- Python 2.7.12


### Qualitative Results:

[Figure: qualitative results]

### Training:

* Download the [data](https://mega.nz/#!O6wXlSTS!wcEoDT4Ctq5HRq_hV-aWeVF1_JB3cacQBQqOLjCIbc8) folder (~30GB), which contains the features and the ground-truth labels. If you cannot download the data from that link, try downloading it from [here](https://zenodo.org/record/3625992#.Xiv9jGhKhPY).
* Extract it so that the `data` folder is in the same directory as `main.py`.
* To train the model, run `python main.py --action=train --dataset=DS --split=SP`, where `DS` is `breakfast`, `50salads` or `gtea`, and `SP` is the split number: 1-5 for 50salads and 1-4 for the other datasets.

### Prediction:

Run `python main.py --action=predict --dataset=DS --split=SP`. This loads the checkpoint saved after the last training epoch and writes the frame-wise predictions to `./results/DS/split_SP/`.

### Evaluation:

Run `python eval.py --dataset=DS --split=SP`. This reports the frame-wise accuracy, the segmental edit score, and F1@{10, 25, 50}.

### Citation:

If you use the code, please cite:

Y. Abu Farha and J. Gall.
MS-TCN: Multi-Stage Temporal Convolutional Network for Action Segmentation.
In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2019

S. Li, Y. Abu Farha, Y. Liu, MM. Cheng, and J. Gall.
MS-TCN++: Multi-Stage Temporal Convolutional Network for Action Segmentation.
In IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI), 2020
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT+CC License

Copyright (c) 2019 yabufarha

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

"Commons Clause" License Condition v1.0

The Software is provided to you by the Licensor under the License,
as defined below, subject to the following condition.

Without limiting other conditions in the License, the grant of
rights under the License will not include, and the License does not
grant to you, the right to Sell the Software.

For purposes of the foregoing, "Sell" means practicing any or all
of the rights granted to you under the License to provide to third
parties, for a fee or other consideration (including without
limitation fees for hosting or consulting/ support services related
to the Software), a product or service whose value derives, entirely
or substantially, from the functionality of the Software. Any license
notice or attribution required by the License must also include
this Commons Clause License Condition notice.

Software: All ms-tcn associated files.
License: MIT
Licensor: yabufarha
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
#!/usr/bin/python2.7

import torch
from model import Trainer
from batch_gen import BatchGenerator
import os
import argparse
import random


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed = 1538574472
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True

parser = argparse.ArgumentParser()
parser.add_argument('--action', default='train')
parser.add_argument('--dataset', default="gtea")
parser.add_argument('--split', default='1')

args = parser.parse_args()

num_stages = 4
num_layers = 10
num_f_maps = 64
features_dim = 2048
bz = 1
lr = 0.0005
num_epochs = 50

# use the full temporal resolution @ 15fps
sample_rate = 1
# sample input features @ 15fps instead of 30 fps
# for 50salads, and up-sample the output to 30 fps
if args.dataset == "50salads":
    sample_rate = 2

vid_list_file = "./data/"+args.dataset+"/splits/train.split"+args.split+".bundle"
vid_list_file_tst = "./data/"+args.dataset+"/splits/test.split"+args.split+".bundle"
features_path = "./data/"+args.dataset+"/features/"
gt_path = "./data/"+args.dataset+"/groundTruth/"

mapping_file = "./data/"+args.dataset+"/mapping.txt"

model_dir = "./models/"+args.dataset+"/split_"+args.split
results_dir = "./results/"+args.dataset+"/split_"+args.split

if not os.path.exists(model_dir):
    os.makedirs(model_dir)
if not os.path.exists(results_dir):
    os.makedirs(results_dir)

file_ptr = open(mapping_file, 'r')
actions = file_ptr.read().split('\n')[:-1]
file_ptr.close()
actions_dict = dict()
for a in actions:
    actions_dict[a.split()[1]] = int(a.split()[0])

num_classes = len(actions_dict)

trainer = Trainer(num_stages, num_layers, num_f_maps, features_dim, num_classes)
if args.action == "train":
    batch_gen = BatchGenerator(num_classes, actions_dict, gt_path, features_path, sample_rate)
    batch_gen.read_data(vid_list_file)
    trainer.train(model_dir, batch_gen, num_epochs=num_epochs, batch_size=bz, learning_rate=lr, device=device)

if args.action == "predict":
    trainer.predict(model_dir, results_dir, features_path, vid_list_file_tst, num_epochs, actions_dict, device, sample_rate)
--------------------------------------------------------------------------------
/batch_gen.py:
--------------------------------------------------------------------------------
#!/usr/bin/python2.7

import torch
import numpy as np
import random


class BatchGenerator(object):
    def __init__(self, num_classes, actions_dict, gt_path, features_path, sample_rate):
        self.list_of_examples = list()
        self.index = 0
        self.num_classes = num_classes
        self.actions_dict = actions_dict
        self.gt_path = gt_path
        self.features_path = features_path
        self.sample_rate = sample_rate

    def reset(self):
        self.index = 0
        random.shuffle(self.list_of_examples)

    def has_next(self):
        if self.index < len(self.list_of_examples):
            return True
        return False

    def read_data(self, vid_list_file):
        file_ptr = open(vid_list_file, 'r')
        self.list_of_examples = file_ptr.read().split('\n')[:-1]
        file_ptr.close()
        random.shuffle(self.list_of_examples)

    def next_batch(self, batch_size):
        batch = self.list_of_examples[self.index:self.index + batch_size]
        self.index += batch_size

        batch_input = []
        batch_target = []
        for vid in batch:
            features = np.load(self.features_path + vid.split('.')[0] + '.npy')
            file_ptr = open(self.gt_path + vid, 'r')
            content = file_ptr.read().split('\n')[:-1]
            file_ptr.close()
            classes = np.zeros(min(np.shape(features)[1], len(content)))
            for i in range(len(classes)):
                classes[i] = self.actions_dict[content[i]]
            batch_input.append(features[:, ::self.sample_rate])
            batch_target.append(classes[::self.sample_rate])

        # pad every sequence to the longest one in the batch: padded targets are set to -100
        # (ignored by the cross-entropy loss in model.py) and the mask is 1 only on valid frames
        length_of_sequences = list(map(len, batch_target))
        batch_input_tensor = torch.zeros(len(batch_input), np.shape(batch_input[0])[0], max(length_of_sequences), dtype=torch.float)
        batch_target_tensor = torch.ones(len(batch_input), max(length_of_sequences), dtype=torch.long)*(-100)
        mask = torch.zeros(len(batch_input), self.num_classes, max(length_of_sequences), dtype=torch.float)
        for i in range(len(batch_input)):
            batch_input_tensor[i, :, :np.shape(batch_input[i])[1]] = torch.from_numpy(batch_input[i])
            batch_target_tensor[i, :np.shape(batch_target[i])[0]] = torch.from_numpy(batch_target[i])
            mask[i, :, :np.shape(batch_target[i])[0]] = torch.ones(self.num_classes, np.shape(batch_target[i])[0])

        return batch_input_tensor, batch_target_tensor, mask
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
#!/usr/bin/python2.7
# adapted from: https://github.com/colincsl/TemporalConvolutionalNetworks/blob/master/code/metrics.py

import numpy as np
import argparse


def read_file(path):
    with open(path, 'r') as f:
        content = f.read()
    return content


def get_labels_start_end_time(frame_wise_labels, bg_class=["background"]):
    # returns the label, start frame and (exclusive) end frame of every non-background segment
    labels = []
    starts = []
    ends = []
    last_label = frame_wise_labels[0]
    if frame_wise_labels[0] not in bg_class:
        labels.append(frame_wise_labels[0])
        starts.append(0)
    for i in range(len(frame_wise_labels)):
        if frame_wise_labels[i] != last_label:
            if frame_wise_labels[i] not in bg_class:
                labels.append(frame_wise_labels[i])
                starts.append(i)
            if last_label not in bg_class:
                ends.append(i)
            last_label = frame_wise_labels[i]
    if last_label not in bg_class:
        ends.append(i + 1)
    return labels, starts, ends


def levenstein(p, y, norm=False):
    # Levenshtein (edit) distance between two label sequences
    m_row = len(p)
    n_col = len(y)
    D = np.zeros([m_row+1, n_col+1], float)
    for i in range(m_row+1):
        D[i, 0] = i
    for i in range(n_col+1):
        D[0, i] = i

    for j in range(1, n_col+1):
        for i in range(1, m_row+1):
            if y[j-1] == p[i-1]:
                D[i, j] = D[i-1, j-1]
            else:
                D[i, j] = min(D[i-1, j] + 1,
                              D[i, j-1] + 1,
                              D[i-1, j-1] + 1)

    if norm:
        score = (1 - D[-1, -1]/max(m_row, n_col)) * 100
    else:
        score = D[-1, -1]

    return score


def edit_score(recognized, ground_truth, norm=True, bg_class=["background"]):
    P, _, _ = get_labels_start_end_time(recognized, bg_class)
    Y, _, _ = get_labels_start_end_time(ground_truth, bg_class)
    return levenstein(P, Y, norm)


def f_score(recognized, ground_truth, overlap, bg_class=["background"]):
    p_label, p_start, p_end = get_labels_start_end_time(recognized, bg_class)
    y_label, y_start, y_end = get_labels_start_end_time(ground_truth, bg_class)

    tp = 0
    fp = 0

    hits = np.zeros(len(y_label))

    for j in range(len(p_label)):
        intersection = np.minimum(p_end[j], y_end) - np.maximum(p_start[j], y_start)
        union = np.maximum(p_end[j], y_end) - np.minimum(p_start[j], y_start)
        IoU = (1.0*intersection / union)*([p_label[j] == y_label[x] for x in range(len(y_label))])
        # Get the best scoring segment
        idx = np.array(IoU).argmax()

        if IoU[idx] >= overlap and not hits[idx]:
            tp += 1
            hits[idx] = 1
        else:
            fp += 1
    fn = len(y_label) - sum(hits)
    return float(tp), float(fp), float(fn)


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset', default="gtea")
    parser.add_argument('--split', default='1')

    args = parser.parse_args()

    ground_truth_path = "./data/"+args.dataset+"/groundTruth/"
    recog_path = "./results/"+args.dataset+"/split_"+args.split+"/"
    file_list = "./data/"+args.dataset+"/splits/test.split"+args.split+".bundle"

    list_of_videos = read_file(file_list).split('\n')[:-1]

    overlap = [.1, .25, .5]
    tp, fp, fn = np.zeros(3), np.zeros(3), np.zeros(3)

    correct = 0
    total = 0
    edit = 0

    for vid in list_of_videos:
        gt_file = ground_truth_path + vid
        gt_content = read_file(gt_file).split('\n')[0:-1]

        recog_file = recog_path + vid.split('.')[0]
        recog_content = read_file(recog_file).split('\n')[1].split()

        for i in range(len(gt_content)):
            total += 1
            if gt_content[i] == recog_content[i]:
                correct += 1

        edit += edit_score(recog_content, gt_content)

        for s in range(len(overlap)):
            tp1, fp1, fn1 = f_score(recog_content, gt_content, overlap[s])
            tp[s] += tp1
            fp[s] += fp1
            fn[s] += fn1

    print("Acc: %.4f" % (100*float(correct)/total))
    print('Edit: %.4f' % ((1.0*edit)/len(list_of_videos)))
    for s in range(len(overlap)):
        precision = tp[s] / float(tp[s]+fp[s])
        recall = tp[s] / float(tp[s]+fn[s])

        f1 = 2.0 * (precision*recall) / (precision+recall)

        f1 = np.nan_to_num(f1)*100
        print('F1@%0.2f: %.4f' % (overlap[s], f1))


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
#!/usr/bin/python2.7

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import copy
import numpy as np


class MultiStageModel(nn.Module):
    def __init__(self, num_stages, num_layers, num_f_maps, dim, num_classes):
        super(MultiStageModel, self).__init__()
        self.stage1 = SingleStageModel(num_layers, num_f_maps, dim, num_classes)
        self.stages = nn.ModuleList([copy.deepcopy(SingleStageModel(num_layers, num_f_maps, num_classes, num_classes)) for s in range(num_stages-1)])

    def forward(self, x, mask):
        # stage 1 maps the input features to class scores; each later stage
        # refines the (masked) softmax output of the previous stage
        out = self.stage1(x, mask)
        outputs = out.unsqueeze(0)
        for s in self.stages:
            out = s(F.softmax(out, dim=1) * mask[:, 0:1, :], mask)
            outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0)
        return outputs


class SingleStageModel(nn.Module):
    def __init__(self, num_layers, num_f_maps, dim, num_classes):
        super(SingleStageModel, self).__init__()
        self.conv_1x1 = nn.Conv1d(dim, num_f_maps, 1)
        self.layers = nn.ModuleList([copy.deepcopy(DilatedResidualLayer(2 ** i, num_f_maps, num_f_maps)) for i in range(num_layers)])
        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)

    def forward(self, x, mask):
        out = self.conv_1x1(x)
        for layer in self.layers:
            out = layer(out, mask)
        out = self.conv_out(out) * mask[:, 0:1, :]
        return out


class DilatedResidualLayer(nn.Module):
    def __init__(self, dilation, in_channels, out_channels):
        super(DilatedResidualLayer, self).__init__()
        self.conv_dilated = nn.Conv1d(in_channels, out_channels, 3, padding=dilation, dilation=dilation)
        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
        self.dropout = nn.Dropout()

    def forward(self, x, mask):
        out = F.relu(self.conv_dilated(x))
        out = self.conv_1x1(out)
        out = self.dropout(out)
        return (x + out) * mask[:, 0:1, :]


class Trainer:
    def __init__(self, num_blocks, num_layers, num_f_maps, dim, num_classes):
        self.model = MultiStageModel(num_blocks, num_layers, num_f_maps, dim, num_classes)
        self.ce = nn.CrossEntropyLoss(ignore_index=-100)
        self.mse = nn.MSELoss(reduction='none')
        self.num_classes = num_classes

    def train(self, save_dir, batch_gen, num_epochs, batch_size, learning_rate, device):
        self.model.train()
        self.model.to(device)
        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        for epoch in range(num_epochs):
            epoch_loss = 0
            correct = 0
            total = 0
            while batch_gen.has_next():
                batch_input, batch_target, mask = batch_gen.next_batch(batch_size)
                batch_input, batch_target, mask = batch_input.to(device), batch_target.to(device), mask.to(device)
                optimizer.zero_grad()
                predictions = self.model(batch_input, mask)

                loss = 0
                for p in predictions:
                    # frame-wise classification loss (padded frames with target -100 are ignored)
                    loss += self.ce(p.transpose(2, 1).contiguous().view(-1, self.num_classes), batch_target.view(-1))
                    # smoothing loss: squared difference of the log-probabilities of adjacent frames,
                    # clamped at 16, masked, averaged, and weighted by 0.15
                    loss += 0.15*torch.mean(torch.clamp(self.mse(F.log_softmax(p[:, :, 1:], dim=1), F.log_softmax(p.detach()[:, :, :-1], dim=1)), min=0, max=16)*mask[:, :, 1:])

                epoch_loss += loss.item()
                loss.backward()
                optimizer.step()

                _, predicted = torch.max(predictions[-1].data, 1)
                correct += ((predicted == batch_target).float()*mask[:, 0, :].squeeze(1)).sum().item()
                total += torch.sum(mask[:, 0, :]).item()

            batch_gen.reset()
            torch.save(self.model.state_dict(), save_dir + "/epoch-" + str(epoch + 1) + ".model")
            torch.save(optimizer.state_dict(), save_dir + "/epoch-" + str(epoch + 1) + ".opt")
            print("[epoch %d]: epoch loss = %f, acc = %f" % (epoch + 1, epoch_loss / len(batch_gen.list_of_examples),
                                                               float(correct)/total))

    def predict(self, model_dir, results_dir, features_path, vid_list_file, epoch, actions_dict, device, sample_rate):
        self.model.eval()
        with torch.no_grad():
            self.model.to(device)
            self.model.load_state_dict(torch.load(model_dir + "/epoch-" + str(epoch) + ".model"))
            file_ptr = open(vid_list_file, 'r')
            list_of_vids = file_ptr.read().split('\n')[:-1]
            file_ptr.close()
            for vid in list_of_vids:
                print(vid)
                features = np.load(features_path + vid.split('.')[0] + '.npy')
                features = features[:, ::sample_rate]
                input_x = torch.tensor(features, dtype=torch.float)
                input_x.unsqueeze_(0)
                input_x = input_x.to(device)
                predictions = self.model(input_x, torch.ones(input_x.size(), device=device))
                _, predicted = torch.max(predictions[-1].data, 1)
                predicted = predicted.squeeze()
                recognition = []
                for i in range(len(predicted)):
                    # map the class index back to its action name and repeat it
                    # sample_rate times to restore the original frame rate
                    recognition = np.concatenate((recognition, [list(actions_dict.keys())[list(actions_dict.values()).index(predicted[i].item())]]*sample_rate))
                f_name = vid.split('/')[-1].split('.')[0]
                f_ptr = open(results_dir + "/" + f_name, "w")
                f_ptr.write("### Frame level recognition: ###\n")
                f_ptr.write(' '.join(recognition))
                f_ptr.close()
--------------------------------------------------------------------------------
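Each `SingleStageModel` in `model.py` stacks `num_layers` dilated residual layers whose kernel-size-3 convolutions use dilation `2 ** i`, so the temporal context grows exponentially with depth. The following is a minimal sketch of that arithmetic, not part of the repository (the helper `receptive_field` is hypothetical); it assumes the 1x1 convolutions add no temporal context and describes a single stage only.

```python
def receptive_field(num_layers, kernel_size=3):
    """Temporal receptive field (in frames) of one stage of dilated convolutions.

    Layer i uses dilation 2 ** i, as in SingleStageModel; a layer with kernel
    size k and dilation d widens the field by (k - 1) * d frames. The 1x1
    convolutions are assumed to add no temporal context.
    """
    field = 1
    for i in range(num_layers):
        field += (kernel_size - 1) * 2 ** i
    return field


# main.py sets num_layers = 10
frames = receptive_field(10)
print(frames)                   # 2047
print(round(frames / 15.0, 1))  # 136.5 (seconds, at the 15 fps mentioned in main.py)
```

Doubling the dilation at every layer is what lets a stage with only 10 layers see roughly two minutes of video without any temporal pooling.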
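`eval.py` scores predictions at the segment level. Frame-wise labels are first collapsed into non-background segments. The edit score is a normalized Levenshtein distance between the predicted and ground-truth segment label sequences, and F1@k counts a predicted segment as a true positive when the best-overlapping ground-truth segment with the same label reaches an IoU of k and has not been matched yet. The sketch below re-implements that logic in plain Python 3 on a toy example; all function names are hypothetical, and F1 is computed as `2*tp / (2*tp + fp + fn)`, which is algebraically equivalent to the precision/recall formula in `eval.py` (returning 0 when there is nothing to count).

```python
def segments(frame_labels, bg_class=("background",)):
    """Collapse frame-wise labels into (label, start, end) runs; end is exclusive.

    Background runs are dropped, as in get_labels_start_end_time.
    """
    result = []
    start = 0
    for i in range(1, len(frame_labels) + 1):
        if i == len(frame_labels) or frame_labels[i] != frame_labels[start]:
            if frame_labels[start] not in bg_class:
                result.append((frame_labels[start], start, i))
            start = i
    return result


def levenshtein(a, b):
    """Dynamic-programming edit distance between two sequences."""
    prev = list(range(len(b) + 1))
    for i in range(1, len(a) + 1):
        cur = [i] + [0] * len(b)
        for j in range(1, len(b) + 1):
            cost = 0 if a[i - 1] == b[j - 1] else 1
            cur[j] = min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + cost)
        prev = cur
    return prev[-1]


def edit(recognized, ground_truth):
    """Normalized segmental edit score in [0, 100]."""
    p = [label for label, _, _ in segments(recognized)]
    y = [label for label, _, _ in segments(ground_truth)]
    return (1 - levenshtein(p, y) / float(max(len(p), len(y)))) * 100


def f_counts(recognized, ground_truth, overlap):
    """(tp, fp, fn) at an IoU threshold, using the greedy matching of f_score."""
    pred, gt = segments(recognized), segments(ground_truth)
    hits = [False] * len(gt)
    tp = fp = 0
    for label, ps, pe in pred:
        best_iou, best_idx = 0.0, None
        for idx, (g_label, gs, ge) in enumerate(gt):
            if g_label != label:
                continue
            iou = (min(pe, ge) - max(ps, gs)) / float(max(pe, ge) - min(ps, gs))
            if best_idx is None or iou > best_iou:
                best_iou, best_idx = iou, idx
        if best_idx is not None and best_iou >= overlap and not hits[best_idx]:
            tp += 1
            hits[best_idx] = True
        else:
            fp += 1
    return tp, fp, len(gt) - sum(hits)


def f1(tp, fp, fn):
    """F1 in percent from raw counts."""
    denom = 2 * tp + fp + fn
    return 100.0 * 2 * tp / denom if denom else 0.0


gt = ["a", "a", "b", "b", "c"]
pred = ["a", "b", "a", "b", "c"]  # toy over-segmented prediction
print(round(edit(pred, gt), 2))   # 60.0
print(f_counts(pred, gt, 0.5))    # (3, 2, 0)
print(f1(*f_counts(pred, gt, 0.5)))  # 75.0
```

Note that `eval.py` sums tp, fp and fn over all test videos before computing F1 once per threshold, so calling these helpers on a single video is only for illustration.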
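Besides cross-entropy, `Trainer.train` adds a smoothing term for every stage: the squared difference between the log-probabilities of consecutive frames, clamped at 16 (that is, an absolute log-probability difference of 4), multiplied by the mask, averaged with `torch.mean`, and weighted by 0.15. The NumPy sketch below reproduces only the forward value of that term, under the assumption that logits and mask have shape (batch, classes, frames) as produced by `BatchGenerator`; `truncated_mse` is a hypothetical name, and the `detach()` on the previous frame in `model.py` only affects gradients, so it has no counterpart here.

```python
import numpy as np


def log_softmax(x, axis):
    """Numerically stable log-softmax along one axis."""
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def truncated_mse(logits, mask, clamp=16.0):
    """Smoothing term for logits and mask of shape (batch, classes, frames).

    Like torch.mean in model.py, the average runs over all entries,
    including padded (mask == 0) frames, which contribute zero.
    """
    logp = log_softmax(logits, axis=1)
    sq_diff = (logp[:, :, 1:] - logp[:, :, :-1]) ** 2
    return np.mean(np.clip(sq_diff, 0.0, clamp) * mask[:, :, 1:])


# constant predictions over time are not penalized
flat = np.zeros((1, 3, 6))
print(truncated_mse(flat, np.ones_like(flat)))  # 0.0

# an abrupt change between two frames: one class difference hits the clamp
jump = np.array([[[0.0, 100.0], [0.0, 0.0]]])  # 1 batch, 2 classes, 2 frames
print(truncated_mse(jump, np.ones_like(jump)))
```

The clamp keeps genuine action boundaries, where the log-probabilities change sharply, from dominating the loss while still discouraging frame-to-frame flicker.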
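For 50salads, `main.py` sets `sample_rate = 2`: `BatchGenerator` and `Trainer.predict` keep every second feature frame (`features[:, ::sample_rate]`), and `predict` writes each predicted label `sample_rate` times to return to the original frame rate. The sketch below, using hypothetical helper names and toy data, shows the length bookkeeping: when the number of frames is not a multiple of `sample_rate`, the upsampled output is longer than the input, while the accuracy loop in `eval.py` only iterates over the ground-truth length.

```python
import numpy as np


def downsample(features, sample_rate):
    """Keep every sample_rate-th frame of a (dim, frames) feature array."""
    return features[:, ::sample_rate]


def upsample_labels(labels, sample_rate):
    """Repeat every per-step label sample_rate times, as the predict loop does."""
    out = []
    for label in labels:
        out.extend([label] * sample_rate)
    return out


features = np.arange(10).reshape(2, 5)  # toy features: dim 2, 5 frames
steps = downsample(features, 2)
print(steps.shape)                      # (2, 3)

predicted = ["a", "b", "b"]             # hypothetical per-step predictions
frames = upsample_labels(predicted, 2)
print(frames)                           # ['a', 'a', 'b', 'b', 'b', 'b']
print(len(frames) - features.shape[1])  # 1
```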