├── requirements_hf.txt ├── step_recognition ├── model │ ├── rnn │ │ ├── __init__.py │ │ └── rnn.py │ ├── __init__.py │ ├── transformer_models │ │ ├── __init__.py │ │ ├── PositionalEncoding.py │ │ ├── decoder.py │ │ ├── Transformer.py │ │ ├── Attention.py │ │ ├── BiT.py │ │ ├── ViT.py │ │ ├── attn.py │ │ ├── HybridViT.py │ │ └── AxialNet.py │ ├── model_builder.py │ └── weights_init.py ├── criterions │ ├── __init__.py │ ├── loss_builder.py │ └── loss.py ├── datasets │ ├── __init__.py │ ├── dataset_builder.py │ └── dataset.py ├── trainer │ ├── __init__.py │ ├── eval_builder.py │ ├── train_builder.py │ ├── train.py │ └── eval.py ├── utils │ ├── __init__.py │ ├── logger.py │ ├── registry.py │ ├── util.py │ ├── postprocessing.py │ ├── lr_scheduler.py │ └── metrics.py ├── configs │ ├── miniroad_epic-tent-O.yaml │ └── miniroad_assembly101-O.yaml ├── data_info │ └── tvseries_length_24fps.json └── main.py ├── assets └── teaser.png ├── step_anticipation ├── data │ ├── idx2action.pkl │ ├── context_prompt │ │ ├── context_prompt.json │ │ └── epictent_context_prompt_train.json │ ├── utils │ │ ├── emoji.txt │ │ ├── toys.txt │ │ ├── toys.json │ │ └── toy2class.json │ ├── predictions │ │ └── output_miniROAD_Epic-tent-O.json │ └── idx2emoji.json ├── llama │ ├── __init__.py │ └── tokenizer.py ├── scripts │ └── anticipation.sh └── src │ ├── utils │ ├── metrics.py │ └── parser.py │ ├── data │ ├── assemblyLabelDataset.py │ ├── frequentist_baseline.py │ └── assembly_text.py │ └── models │ ├── llm_ollama.py │ ├── llm_hf.py │ └── llama_meta.py ├── requirements.txt ├── run.sh ├── LICENSE ├── utils └── aggregate.py ├── data └── output │ └── aggregated_data.json └── README.md /requirements_hf.txt: -------------------------------------------------------------------------------- 1 | ollama 2 | fire 3 | numpy 4 | transformers -------------------------------------------------------------------------------- /step_recognition/model/rnn/__init__.py: -------------------------------------------------------------------------------- 1 | from .rnn import MROAD, MROADA -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aleflabo/PREGO/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /step_recognition/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss import OadLoss 2 | from .loss_builder import build_criterion -------------------------------------------------------------------------------- /step_anticipation/data/idx2action.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aleflabo/PREGO/HEAD/step_anticipation/data/idx2action.pkl -------------------------------------------------------------------------------- /step_recognition/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer_models import ViTEnc 2 | from .rnn import MROAD 3 | from .model_builder import build_model -------------------------------------------------------------------------------- /step_recognition/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_builder import build_dataset, build_data_loader 2 | from .dataset import THUMOSDataset, FINEACTIONDataset 
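
A minimal usage sketch for the builders exported here (illustrative only: the config path and the 'test' mode string are assumptions; per dataset_builder.py, any non-'train' mode selects cfg["test_batch_size"] and disables shuffling):

    import yaml
    from datasets import build_data_loader

    with open('step_recognition/configs/miniroad_epic-tent-O.yaml') as f:
        cfg = yaml.safe_load(f)
    train_loader = build_data_loader(cfg, mode='train')  # shuffled, batch_size=cfg["batch_size"]
    test_loader = build_data_loader(cfg, mode='test')    # unshuffled, batch_size=cfg["test_batch_size"]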
-------------------------------------------------------------------------------- /step_recognition/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .train_builder import build_trainer 2 | from .train import train_one_epoch 3 | from .eval_builder import build_eval 4 | from .eval import Evaluate -------------------------------------------------------------------------------- /step_recognition/model/transformer_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .ViT import ViTEnc 2 | from .HybridViT import ResNetHybridViT, AxialNetHybridViT 3 | 4 | __all__ = ['ResNetHybridViT', 'AxialNetHybridViT', 'VisionTransformer_v3'] 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # torch 2 | fairscale 3 | fire 4 | sentencepiece 5 | gdown 6 | numpy 7 | tensorboard 8 | PyYAML 9 | scikit-learn 10 | torch>=2.0.0 11 | torchvision 12 | tqdm 13 | git+https://github.com/meta-llama/llama.git -------------------------------------------------------------------------------- /step_recognition/trainer/eval_builder.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'build_eval' 3 | ] 4 | 5 | from utils import Registry 6 | 7 | EVAL = Registry() 8 | 9 | def build_eval(cfg): 10 | eval = EVAL[cfg["task"]](cfg) 11 | return eval 12 | -------------------------------------------------------------------------------- /step_recognition/trainer/train_builder.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'build_trainer' 3 | ] 4 | 5 | from utils import Registry 6 | 7 | TRAINER = Registry() 8 | 9 | def build_trainer(cfg): 10 | trainer = TRAINER[cfg["task"]] 11 | return trainer 12 | -------------------------------------------------------------------------------- /step_recognition/model/model_builder.py: -------------------------------------------------------------------------------- 1 | __all__ = ['build_model'] 2 | 3 | from utils import Registry 4 | 5 | META_ARCHITECTURES = Registry() 6 | 7 | def build_model(cfg, device=None): 8 | model = META_ARCHITECTURES[cfg["model"]](cfg) 9 | return model.to(device) -------------------------------------------------------------------------------- /step_recognition/criterions/loss_builder.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'build_criterion' 3 | ] 4 | 5 | from utils import Registry 6 | 7 | CRITERIONS = Registry() 8 | 9 | def build_criterion(cfg, device=None): 10 | criterion = CRITERIONS[cfg["loss"]](cfg) 11 | return criterion.to(device) -------------------------------------------------------------------------------- /step_anticipation/llama/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
3 | 4 | from .generation import Llama, Dialog 5 | from .model import ModelArgs, Transformer 6 | from .tokenizer import Tokenizer 7 | -------------------------------------------------------------------------------- /step_recognition/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import perframe_average_precision, perstage_average_precision 2 | from .postprocessing import thumos_postprocessing 3 | from .util import * 4 | from .group_transforms import * 5 | from .lr_scheduler import build_lr_scheduler 6 | from .logger import get_logger 7 | from .registry import Registry -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | PYTHONPATH=. python ./step_anticipation/src/models/chimera/llm_hf.py --model_name=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --max_seq_len=2048 --max_batch_size=6 --temperature=0.6 --max_gen_len=5 --dataset=epictent --type_prompt=num --num_samples=5 4 | 5 | # --use_gt 6 | # --model_name=unsloth/Llama-3.2-1B 7 | # --model_name=meta-llama/Llama-3.2-1B 8 | # --model_name=meta-llama/Llama-2-7B-hf 9 | -------------------------------------------------------------------------------- /step_recognition/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | def get_logger(output_path): 5 | logger = logging.getLogger(__name__) 6 | logger.setLevel(logging.DEBUG) 7 | # formatter = logging.Formatter('%(asctime)s:%(message)s ', '%Y-%m-%d %H:%M:%S') 8 | console_handler = logging.StreamHandler() 9 | console_handler.setLevel(logging.INFO) 10 | # console_handler.setFormatter(formatter) 11 | logger.addHandler(console_handler) 12 | log_file = os.path.join(output_path, 'log.txt') 13 | fh = logging.FileHandler(log_file) 14 | fh.setLevel(logging.INFO) 15 | logger.addHandler(fh) 16 | 17 | return logger -------------------------------------------------------------------------------- /step_recognition/utils/registry.py: -------------------------------------------------------------------------------- 1 | def _register_generic(module_dict, module_name, module): 2 | assert module_name not in module_dict 3 | module_dict[module_name] = module 4 | 5 | 6 | class Registry(dict): 7 | 8 | def __init__(self, *args, **kwargs): 9 | super(Registry, self).__init__(*args, **kwargs) 10 | 11 | def register(self, module_name, module=None): 12 | if module is not None: 13 | _register_generic(self, module_name, module) 14 | return 15 | 16 | def register_fn(fn): 17 | _register_generic(self, module_name, fn) 18 | return fn 19 | 20 | return register_fn -------------------------------------------------------------------------------- /step_anticipation/data/context_prompt/context_prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "default": { 3 | "init": "Sequence type:", 4 | "input": "Input Sequence:", 5 | "output": "Next Symbol:" 6 | }, 7 | "unreferenced": { 8 | "init": "Context:", 9 | "input": "Input:", 10 | "output": "Output:" 11 | }, 12 | "elaborate": { 13 | "init": "Given the sequences of the following:", 14 | "input": "Complete the following sequence:", 15 | "output": "Sequence is completed with:" 16 | }, 17 | "no-context": { 18 | "init": "Sequence type:", 19 | "input": "", 20 | "output": "" 21 | } 22 | } 
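
A sketch of how these templates can be composed into an LLM prompt (illustrative only: the actual prompt assembly presumably lives in the scripts under step_anticipation/src/models, and `context`, `observed_steps`, and the joining format below are placeholders):

    import json

    with open('step_anticipation/data/context_prompt/context_prompt.json') as f:
        templates = json.load(f)

    t = templates['default']          # or 'unreferenced', 'elaborate', 'no-context'
    context = 'toy assembly'          # placeholder task description
    observed_steps = '3, 7, 12'       # placeholder sequence of recognized step ids
    prompt = f"{t['init']} {context}\n{t['input']} {observed_steps}\n{t['output']}"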
--------------------------------------------------------------------------------
/step_recognition/datasets/dataset_builder.py:
--------------------------------------------------------------------------------
1 | __all__ = [
2 |     'build_dataset',
3 |     'build_data_loader',
4 | ]
5 | 
6 | import torch.utils.data as data
7 | from utils import Registry
8 | 
9 | DATA_LAYERS = Registry()
10 | 
11 | def build_dataset(cfg):
12 |     data_layer = DATA_LAYERS[f'{cfg["data_name"]}']
13 |     return data_layer
14 | 
15 | def build_data_loader(cfg, mode):
16 |     data_layer = build_dataset(cfg)
17 |     data_loader = data.DataLoader(
18 |         dataset=data_layer(cfg, mode),
19 |         batch_size=cfg["batch_size"] if mode == 'train' else cfg["test_batch_size"],
20 |         shuffle=True if mode == 'train' else False,
21 |         num_workers=cfg["num_workers"],
22 |         pin_memory=False,
23 |     )
24 |     return data_loader
--------------------------------------------------------------------------------
/step_recognition/configs/miniroad_epic-tent-O.yaml:
--------------------------------------------------------------------------------
1 | model: 'MiniROAD'
2 | data_name: 'EPIC-TENT-O'
3 | task: 'OAD'
4 | loss: 'NONUNIFORM'
5 | metric: 'AP'
6 | optimizer: 'AdamW'
7 | device: 'cuda:0'
8 | feature_pretrained: 'kinetics'
9 | root_path: 'Epic-tent-O'
10 | rgb_type: 'rgb_anet_resnet50'
11 | flow_type: 'flow_anet_resnet50'
12 | annotation_type: 'target_perframe'
13 | video_list_path: 'step_recognition/data_info/video_list.json'
14 | output_path: 'step_recognition/checkpoint_miniROAD/Epic-tent-O'
15 | window_size: 128 #128
16 | batch_size: 16 #! 16 # 16 for train and 1 for test
17 | test_batch_size: 1
18 | num_epoch: 10
19 | lr: 0.0001
20 | weight_decay: 0.05
21 | num_workers: 4
22 | dropout: 0.20
23 | num_classes: 12 # including background 86
24 | embedding_dim: 2048
25 | hidden_dim: 1024
26 | num_layers: 1
27 | stride: 4
--------------------------------------------------------------------------------
/step_recognition/configs/miniroad_assembly101-O.yaml:
--------------------------------------------------------------------------------
1 | model: 'MiniROAD'
2 | data_name: 'ASSEMBLY101-O'
3 | task: 'OAD'
4 | loss: 'NONUNIFORM'
5 | metric: 'AP'
6 | optimizer: 'AdamW'
7 | device: 'cuda:0'
8 | feature_pretrained: 'kinetics'
9 | root_path: 'Assembly101-O'
10 | rgb_type: 'rgb_anet_resnet50'
11 | flow_type: 'flow_anet_resnet50' #'flow_nv_kinetics_bninception' #'flow_anet_resnet50'
12 | annotation_type: 'target_perframe'
13 | video_list_path: 'step_recognition/data_info/video_list.json'
14 | output_path: 'step_recognition/checkpoint_miniROAD/Assembly101-O'
15 | window_size: 128 #128
16 | batch_size: 16 #16 # 16 for train and 1 for test
17 | test_batch_size: 1
18 | num_epoch: 10
19 | lr: 0.0001
20 | weight_decay: 0.05
21 | num_workers: 4
22 | dropout: 0.20
23 | num_classes: 86 # including background
24 | embedding_dim: 2048
25 | hidden_dim: 1024
26 | num_layers: 1
27 | stride: 4
--------------------------------------------------------------------------------
/step_anticipation/scripts/anticipation.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | # PYTHONPATH=. torchrun --nproc_per_node 2 ./src/models/llama_meta.py --ckpt_dir=/media/ssd/usr/edo/llama/llama-2-13b --tokenizer_path=/media/ssd/usr/edo/llama/tokenizer.model --max_seq_len=2048 --max_batch_size=6 --temperature=0 --num_samples=5 --max_gen_len=4 --use_gt --dataset=assembly --type_prompt=num
4 | PYTHONPATH=. 
torchrun --nproc_per_node 1 ./src/models/llama_meta.py --ckpt_dir=/media/ssd/usr/edo/llama/llama-2-7b --tokenizer_path=/media/ssd/usr/edo/llama/tokenizer.model --max_seq_len=2048 --max_batch_size=6 --temperature=0.6 --num_samples=5 --max_gen_len=8 --dataset=assembly --type_prompt=emoji 5 | # --use_gt 6 | PYTHONPATH=. torchrun --nproc_per_node 2 ./src/models/llama_meta.py --ckpt_dir=/media/ssd/usr/edo/llama/llama-2-14b --tokenizer_path=/media/ssd/usr/edo/llama/tokenizer.model --max_seq_len=2048 --max_batch_size=6 --temperature=0.6 --num_samples=5 --max_gen_len=8 --dataset=assembly --type_prompt=emoji -------------------------------------------------------------------------------- /step_anticipation/data/utils/emoji.txt: -------------------------------------------------------------------------------- 1 | 0 ❄️ 2 | 1 🏚️ 3 | 2 🛌🏻 4 | 3 🚓 5 | 4 🦘 6 | 5 🛕 7 | 6 📡 8 | 7 🪤 9 | 8 🌲 10 | 9 🤌🏽 11 | 10 🍸 12 | 11 🌚 13 | 12 🧎🏼 14 | 13 🍉 15 | 14 🤏🏿 16 | 15 👐🏻 17 | 16 👆🏼 18 | 17 😚 19 | 18 🚌 20 | 19 💇🏻 21 | 20 🛵 22 | 21 🌫️ 23 | 22 🕘 24 | 23 🍺 25 | 24 ⏯️ 26 | 25 💁🏽 27 | 26 🕺🏽 28 | 27 🏌🏼 29 | 28 🥷🏽 30 | 29 🍨 31 | 30 ♐ 32 | 31 ⤴️ 33 | 32 ♑ 34 | 33 🧝🏿 35 | 34 🗄️ 36 | 35 👉 37 | 36 🏌🏼 38 | 37 🕜 39 | 38 🕶️ 40 | 39 🦵🏿 41 | 40 🪂 42 | 41 🍘 43 | 42 🪖 44 | 43 🛬 45 | 44 🙆 46 | 45 💄 47 | 46 🔐 48 | 47 🕖 49 | 48 🎎 50 | 49 🥉 51 | 50 🕴🏾 52 | 51 💚 53 | 52 🗜️ 54 | 53 👦🏽 55 | 54 👜 56 | 55 ⚗️ 57 | 56 🪛 58 | 57 ⛳ 59 | 58 🫓 60 | 59 👩 61 | 60 🕠 62 | 61 🗿 63 | 62 👐🏾 64 | 63 🔤 65 | 64 🐖 66 | 65 ♏ 67 | 66 😒 68 | 67 🍁 69 | 68 🆔 70 | 69 🤚🏻 71 | 70 🧔🏼 72 | 71 💣 73 | 72 🐳 74 | 73 🆘 75 | 74 🃏 76 | 75 ℹ️ 77 | 76 🔼 78 | 77 🧰 79 | 78 🧛🏽 80 | 79 👡 81 | 80 🥯 82 | 81 🍽️ 83 | 82 ✋🏼 84 | 83 🔒 85 | 84 🏌🏿 86 | 85 🥒 87 | -------------------------------------------------------------------------------- /step_recognition/data_info/tvseries_length_24fps.json: -------------------------------------------------------------------------------- 1 | { 2 | "24_ep1": 58991, 3 | "24_ep2": 59447, 4 | "24_ep3": 60887, 5 | "24_ep4": 59831, 6 | "Breaking_Bad_ep1": 80255, 7 | "Breaking_Bad_ep2": 66599, 8 | "Breaking_Bad_ep3": 66575, 9 | "How_I_Met_Your_Mother_ep1": 30468, 10 | "How_I_Met_Your_Mother_ep2": 30396, 11 | "How_I_Met_Your_Mother_ep3": 29471, 12 | "How_I_Met_Your_Mother_ep4": 30407, 13 | "How_I_Met_Your_Mother_ep5": 30404, 14 | "How_I_Met_Your_Mother_ep6": 30527, 15 | "How_I_Met_Your_Mother_ep7": 30431, 16 | "How_I_Met_Your_Mother_ep8": 30432, 17 | "Mad_Men_ep1": 70033, 18 | "Mad_Men_ep2": 67705, 19 | "Mad_Men_ep3": 63649, 20 | "Modern_Family_ep1": 31751, 21 | "Modern_Family_ep2": 29807, 22 | "Modern_Family_ep3": 29327, 23 | "Modern_Family_ep4": 29927, 24 | "Modern_Family_ep5": 28991, 25 | "Modern_Family_ep6": 28751, 26 | "Sons_of_Anarchy_ep1": 78983, 27 | "Sons_of_Anarchy_ep2": 64079, 28 | "Sons_of_Anarchy_ep3": 65687 29 | } -------------------------------------------------------------------------------- /step_recognition/utils/util.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import os.path as osp 4 | import random 5 | import numpy as np 6 | import torch 7 | 8 | def dump_pickle(lst, file_path, file_name): 9 | with open(osp.join(file_path, file_name + '.pkl'), 'wb') as f: 10 | pickle.dump(lst, f) 11 | 12 | def create_dir(dir_path): 13 | if not osp.exists(dir_path): 14 | os.makedirs(dir_path) 15 | 16 | def create_outdir(result_path): 17 | i = 1 18 | new_result_path = result_path 19 | while osp.exists(new_result_path): 20 | new_result_path = f'{result_path}_{i}' 21 | i += 1 22 | 
create_dir(osp.join(new_result_path, 'ckpts')) 23 | create_dir(osp.join(new_result_path, 'runs')) 24 | return new_result_path 25 | 26 | def set_seed(seed): 27 | # os.environ['PYTHONHASHSEED'] = str(seed) 28 | random.seed(seed) 29 | np.random.seed(seed) 30 | torch.manual_seed(seed) 31 | if torch.cuda.is_available(): 32 | torch.cuda.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) 34 | torch.backends.cudnn.benchmark = False 35 | torch.backends.cudnn.deterministic = True -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 wangxiang1230 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /step_recognition/model/weights_init.py: -------------------------------------------------------------------------------- 1 | # FROM LSTR 2 | import math 3 | import torch.nn as nn 4 | 5 | def weights_init(m): 6 | if isinstance(m, nn.Linear): 7 | nn.init.kaiming_uniform_(m.weight.data, a=math.sqrt(5)) 8 | elif isinstance(m, nn.Conv1d): 9 | nn.init.normal_(m.weight.data) 10 | if m.bias is not None: 11 | nn.init.normal_(m.bias.data) 12 | elif isinstance(m, nn.Conv2d): 13 | nn.init.xavier_normal_(m.weight.data) 14 | if m.bias is not None: 15 | nn.init.normal_(m.bias.data) 16 | elif isinstance(m, nn.ConvTranspose1d): 17 | nn.init.normal_(m.weight.data) 18 | if m.bias is not None: 19 | nn.init.normal_(m.bias.data) 20 | elif isinstance(m, nn.ConvTranspose2d): 21 | nn.init.xavier_normal_(m.weight.data) 22 | if m.bias is not None: 23 | nn.init.normal_(m.bias.data) 24 | elif isinstance(m, nn.BatchNorm1d): 25 | nn.init.normal_(m.weight.data, mean=1, std=0.02) 26 | nn.init.constant_(m.bias.data, 0) 27 | elif isinstance(m, nn.BatchNorm2d): 28 | nn.init.normal_(m.weight.data, mean=1, std=0.02) 29 | nn.init.constant_(m.bias.data, 0) 30 | elif isinstance(m, nn.GRUCell): 31 | for param in m.parameters(): 32 | if len(param.shape) >= 2: 33 | nn.init.orthogonal_(param.data) 34 | else: 35 | nn.init.normal_(param.data) -------------------------------------------------------------------------------- /step_anticipation/src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | import torch 4 | from torchmetrics import Metric 5 | 6 | 7 | class MultiHotAccuracy(Metric): 8 | is_differentiable = True 9 | higher_is_better = True 10 | full_state_update = True 11 | 12 | def __init__(self, threshold: float = 0.5): 13 | super().__init__() 14 | self.add_state("correct", default=torch.tensor(0, dtype=torch.long)) 15 | self.add_state("total", default=torch.tensor(0, dtype=torch.long)) 16 | 17 | self.threshold = threshold 18 | 19 | @torch.inference_mode() 20 | def update(self, preds: torch.Tensor, target: torch.Tensor): 21 | assert preds.shape == target.shape 22 | 23 | # Convert to binary 24 | preds = (preds > self.threshold).int() 25 | target = target.int() 26 | 27 | # # ! we only consider the case where the target is 1 28 | # condition = ((target == 1) & (preds == target)).int() 29 | 30 | # ! we consider the case in which the whole vector is correct 31 | # ? maybe is too harsh 32 | preds = preds.view(-1, preds.shape[-1]) 33 | target = target.view(-1, target.shape[-1]) 34 | condition = preds == target 35 | 36 | self.correct = self.correct + torch.sum(torch.prod(condition, dim=-1)) 37 | self.total = self.total + target.shape[0] 38 | 39 | @torch.inference_mode() 40 | def compute(self): 41 | return self.correct / self.total 42 | -------------------------------------------------------------------------------- /step_recognition/utils/postprocessing.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | def thumos_postprocessing(ground_truth, prediction, smooth=False, switch=False): 5 | """ 6 | We follow (Shou et al., 2017) and adopt their perframe postprocessing method on THUMOS'14 datset. 
7 | Source: https://bitbucket.org/columbiadvmm/cdc/src/master/THUMOS14/eval/PreFrameLabeling/compute_framelevel_mAP.m 8 | """ 9 | 10 | # Simple temporal smoothing via NMS of 5-frames window 11 | if smooth: 12 | prob = np.copy(prediction) 13 | prob1 = prob.reshape(1, prob.shape[0], prob.shape[1]) 14 | prob2 = np.append(prob[0, :].reshape(1, -1), prob[0: -1, :], axis=0).reshape(1, prob.shape[0], prob.shape[1]) 15 | prob3 = np.append(prob[1:, :], prob[-1, :].reshape(1, -1), axis=0).reshape(1, prob.shape[0], prob.shape[1]) 16 | prob4 = np.append(prob[0: 2, :], prob[0: -2, :], axis=0).reshape(1, prob.shape[0], prob.shape[1]) 17 | prob5 = np.append(prob[2:, :], prob[-2:, :], axis=0).reshape(1, prob.shape[0], prob.shape[1]) 18 | probsmooth = np.squeeze(np.max(np.concatenate((prob1, prob2, prob3, prob4, prob5), axis=0), axis=0)) 19 | prediction = np.copy(probsmooth) 20 | 21 | # Assign cliff diving (5) as diving (8) 22 | if switch: 23 | switch_index = np.where(prediction[:, 5] > prediction[:, 8])[0] 24 | prediction[switch_index, 8] = prediction[switch_index, 5] 25 | 26 | # Remove ambiguous (21) 27 | valid_index = np.where(ground_truth[:, 21] != 1)[0] 28 | 29 | return ground_truth[valid_index], prediction[valid_index] -------------------------------------------------------------------------------- /step_recognition/model/transformer_models/PositionalEncoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class FixedPositionalEncoding(nn.Module): 6 | def __init__(self, embedding_dim, max_length=5000): 7 | super(FixedPositionalEncoding, self).__init__() 8 | 9 | pe = torch.zeros(max_length, embedding_dim) 10 | position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1) 11 | div_term = torch.exp( 12 | torch.arange(0, embedding_dim, 2).float() 13 | * (-torch.log(torch.tensor(10000.0)) / embedding_dim) 14 | ) 15 | pe[:, 0::2] = torch.sin(position * div_term) 16 | pe[:, 1::2] = torch.cos(position * div_term) 17 | pe = pe.unsqueeze(0).transpose(0, 1) 18 | self.register_buffer('pe', pe) 19 | 20 | def forward(self, x): 21 | x = x + self.pe[: x.size(0), :] 22 | return x 23 | 24 | 25 | class LearnedPositionalEncoding(nn.Module): 26 | def __init__(self, max_position_embeddings, embedding_dim, seq_length): 27 | super(LearnedPositionalEncoding, self).__init__() 28 | self.pe = nn.Embedding(max_position_embeddings, embedding_dim) 29 | self.seq_length = seq_length 30 | 31 | self.register_buffer( 32 | "position_ids", 33 | torch.arange(max_position_embeddings).expand((1, -1)), 34 | ) 35 | 36 | def forward(self, x, position_ids=None): 37 | if position_ids is None: 38 | position_ids = self.position_ids[:, : self.seq_length] 39 | 40 | position_embeddings = self.pe(position_ids) 41 | return x + position_embeddings 42 | -------------------------------------------------------------------------------- /step_anticipation/src/data/assemblyLabelDataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, Dataset 2 | 3 | from src.data.dataset_utils import get_OH_data 4 | 5 | 6 | class AssemblyLabelDataset(Dataset): 7 | def __init__(self, csv_path, split="correct"): 8 | assert split in [ 9 | "correct", 10 | "mistake", 11 | "all", 12 | ], "split must be 'correct', 'mistake' or 'all', not '{}'".format(split) 13 | ( 14 | self.oh_samplelist, 15 | self.oh_labellist, 16 | self.metadata, 17 | self.all_keysteps, 18 | ) = get_OH_data(csv_path, split) 19 | 20 | 
def __len__(self): 21 | return len(self.oh_samplelist) 22 | 23 | def __getitem__(self, idx): 24 | sample = { 25 | "oh_sample": self.oh_samplelist[idx], 26 | "oh_label": self.oh_labellist[idx], 27 | "metadata": self.metadata[idx], 28 | } 29 | 30 | return sample 31 | 32 | 33 | def collate_fn(batch): 34 | print("collate_fn") 35 | 36 | return batch 37 | 38 | 39 | if __name__ == "__main__": 40 | path_to_csv_A101 = "./mistake_labels/" 41 | 42 | # train_dataset = AssemblyLabelDataset(path_to_csv_A101, is_train=True) 43 | # test_dataset = AssemblyLabelDataset(path_to_csv_A101, is_train=False) 44 | 45 | train_dataset = AssemblyLabelDataset(path_to_csv_A101, split="all") 46 | test_dataset = AssemblyLabelDataset(path_to_csv_A101, split="mistake") 47 | batch_size = 32 # Adjust batch size as needed 48 | train_dataloader = DataLoader( 49 | train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn 50 | ) 51 | 52 | for batch in train_dataloader: 53 | print(batch) 54 | break 55 | 56 | print(len(train_dataset)) 57 | print(len(test_dataset)) 58 | -------------------------------------------------------------------------------- /step_recognition/model/transformer_models/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class DecoderLayer(nn.Module): 7 | def __init__(self, self_attention, cross_attention, d_model, d_ff=None, 8 | dropout=0.1, activation="relu"): 9 | super(DecoderLayer, self).__init__() 10 | d_ff = d_ff or 4*d_model 11 | self.self_attention = self_attention 12 | self.cross_attention = cross_attention 13 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 14 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 15 | self.norm1 = nn.LayerNorm(d_model) 16 | self.norm2 = nn.LayerNorm(d_model) 17 | self.norm3 = nn.LayerNorm(d_model) 18 | self.dropout = nn.Dropout(dropout) 19 | self.activation = F.relu if activation == "relu" else F.gelu 20 | 21 | def forward(self, x, cross, x_mask=None, cross_mask=None): 22 | x = x + self.dropout(self.self_attention( 23 | x, x, x, 24 | attn_mask=x_mask 25 | )) 26 | x = self.norm1(x) 27 | 28 | x = x + self.dropout(self.cross_attention( 29 | x, cross, cross, 30 | attn_mask=cross_mask 31 | )) 32 | 33 | y = x = self.norm2(x) 34 | y = self.dropout(self.activation(self.conv1(y.transpose(-1,1)))) 35 | y = self.dropout(self.conv2(y).transpose(-1,1)) 36 | 37 | return self.norm3(x+y) 38 | 39 | 40 | class Decoder(nn.Module): 41 | def __init__(self, layers, norm_layer=None): 42 | super(Decoder, self).__init__() 43 | self.layers = nn.ModuleList(layers) 44 | self.norm = norm_layer 45 | 46 | def forward(self, x, cross, x_mask=None, cross_mask=None): 47 | for layer in self.layers: 48 | x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) 49 | 50 | if self.norm is not None: 51 | x = self.norm(x) 52 | 53 | return x -------------------------------------------------------------------------------- /step_anticipation/data/utils/toys.txt: -------------------------------------------------------------------------------- 1 | 0 a01 excavator 2 | 1 a02 bulldozer 3 | 2 a03 bulldozer 4 | 3 a06 clamp 5 | 4 a07 bulldozer 6 | 5 a08 crane 7 | 6 a09 clamp 8 | 7 a10 garbage_truck 9 | 8 a11 bulldozer 10 | 9 a12 dumper 11 | 10 a13 dumper 12 | 11 a14 transporter 13 | 12 a15 garbage_truck 14 | 13 a16 ladder_truck 15 | 14 a17 fire_truck 16 | 15 a18 fire_truck 17 | 16 a19 ladder_truck 18 | 17 a20 
fire_truck 19 | 18 a21 dumper 20 | 19 a23 car 21 | 20 a24 car 22 | 21 a26 car 23 | 22 a27 suv 24 | 23 a28 suv 25 | 24 a29 suv 26 | 25 a30 suv 27 | 26 a31 car 28 | 27 b01a bulldozer 29 | 28 b01b roller 30 | 29 b02a dumper 31 | 30 b02b clamp 32 | 31 b03a ladder_truck 33 | 32 b03b fire_truck 34 | 33 b04a excavator 35 | 34 b04b bulldozer 36 | 35 b04c jackhammer 37 | 36 b04d roller 38 | 37 b05a excavator 39 | 38 b05b crane 40 | 39 b05c cement_mixer 41 | 40 b05d dumper 42 | 41 b06a dumper 43 | 42 b06b crane 44 | 43 b06c excavator 45 | 44 b06d cement_mixer 46 | 45 b08a clamp 47 | 46 b08b bulldozer 48 | 47 b08c bulldozer 49 | 48 b08d dumper 50 | 49 c01a excavator 51 | 50 c01b dumper 52 | 51 c01c crane 53 | 52 c01d cement_mixer 54 | 53 c02a excavator 55 | 54 c02b bulldozer 56 | 55 c02c roller 57 | 56 c03a excavator 58 | 57 c03b bulldozer 59 | 58 c03c cement_mixer 60 | 59 c03d dumper 61 | 60 c03e jackhammer 62 | 61 c03f roller 63 | 62 c04a excavator 64 | 63 c04b jackhammer 65 | 64 c04c roller 66 | 65 c04d excavator 67 | 66 c05a excavator 68 | 67 c05b bulldozer 69 | 68 c06a dumper 70 | 69 c06b cement_mixer 71 | 70 c06c crane 72 | 71 c06d water_tanker 73 | 72 c06e excavator 74 | 73 c06f ladder_truck 75 | 74 c07a excavator 76 | 75 c07b ladder_truck 77 | 76 c07c garbage_truck 78 | 77 c08a crane 79 | 78 c08b garbage_truck 80 | 79 c08c transporter 81 | 80 c09a dumper 82 | 81 c09b water_tanker 83 | 82 c09c transporter 84 | 83 c10a garbage_truck 85 | 84 c10b dumper 86 | 85 c10c water_tanker 87 | 86 c11a car 88 | 87 c11b suv 89 | 88 c12a crane 90 | 89 c12b water_tanker 91 | 90 c12c excavator 92 | 91 c12d ladder_truck 93 | 92 c12e dumper 94 | 93 c13a roller 95 | 94 c13b jackhammer 96 | 95 c13c excavator 97 | 96 c13d bulldozer 98 | 97 c13e dumper 99 | 98 c13f water_tanker 100 | 99 c14a excavator 101 | 100 c14b clamp 102 | -------------------------------------------------------------------------------- /step_recognition/model/transformer_models/Transformer.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .Attention import SelfAttention 3 | 4 | 5 | class Residual(nn.Module): 6 | def __init__(self, fn): 7 | super().__init__() 8 | self.fn = fn 9 | 10 | def forward(self, x): 11 | return self.fn(x) + x 12 | 13 | 14 | class PreNorm(nn.Module): 15 | def __init__(self, dim, fn): 16 | super().__init__() 17 | self.norm = nn.LayerNorm(dim) 18 | self.fn = fn 19 | 20 | def forward(self, x): 21 | return self.fn(self.norm(x)) 22 | 23 | 24 | class PreNormDrop(nn.Module): 25 | def __init__(self, dim, dropout_rate, fn): 26 | super().__init__() 27 | self.norm = nn.LayerNorm(dim) 28 | self.dropout = nn.Dropout(p=dropout_rate) 29 | self.fn = fn 30 | 31 | def forward(self, x): 32 | return self.dropout(self.fn(self.norm(x))) 33 | 34 | 35 | class FeedForward(nn.Module): 36 | def __init__(self, dim, hidden_dim, dropout_rate): 37 | super().__init__() 38 | self.net = nn.Sequential( 39 | nn.Linear(dim, hidden_dim), 40 | nn.GELU(), 41 | nn.Dropout(p=dropout_rate), 42 | nn.Linear(hidden_dim, dim), 43 | nn.Dropout(p=dropout_rate), 44 | ) 45 | 46 | def forward(self, x): 47 | return self.net(x) 48 | 49 | 50 | class TransformerModel(nn.Module): 51 | def __init__( 52 | self, 53 | dim, 54 | depth, 55 | heads, 56 | mlp_dim, 57 | dropout_rate=0.1, 58 | attn_dropout_rate=0.1, 59 | ): 60 | super().__init__() 61 | layers = [] 62 | for _ in range(depth): 63 | layers.extend( 64 | [ 65 | Residual( 66 | PreNormDrop( 67 | dim, 68 | dropout_rate, 69 | SelfAttention( 70 | dim, heads=heads, 
dropout_rate=attn_dropout_rate 71 | ), 72 | ) 73 | ), 74 | Residual( 75 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate)) 76 | ), 77 | ] 78 | ) 79 | self.net = nn.Sequential(*layers) 80 | 81 | def forward(self, x): 82 | return self.net(x) 83 | -------------------------------------------------------------------------------- /step_anticipation/llama/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import os 5 | from logging import getLogger 6 | from typing import List 7 | 8 | from sentencepiece import SentencePieceProcessor 9 | 10 | logger = getLogger() 11 | 12 | 13 | class Tokenizer: 14 | """tokenizing and encoding/decoding text using SentencePiece.""" 15 | 16 | def __init__(self, model_path: str): 17 | """ 18 | Initializes the Tokenizer with a SentencePiece model. 19 | 20 | Args: 21 | model_path (str): The path to the SentencePiece model file. 22 | """ 23 | # reload tokenizer 24 | assert os.path.isfile(model_path), model_path 25 | self.sp_model = SentencePieceProcessor(model_file=model_path) 26 | logger.info(f"Reloaded SentencePiece model from {model_path}") 27 | 28 | # BOS / EOS token IDs 29 | self.n_words: int = self.sp_model.vocab_size() 30 | self.bos_id: int = self.sp_model.bos_id() 31 | self.eos_id: int = self.sp_model.eos_id() 32 | self.pad_id: int = self.sp_model.pad_id() 33 | logger.info( 34 | f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" 35 | ) 36 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() 37 | 38 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]: 39 | """ 40 | Encodes a string into a list of token IDs. 41 | 42 | Args: 43 | s (str): The input string to be encoded. 44 | bos (bool): Whether to prepend the beginning-of-sequence token. 45 | eos (bool): Whether to append the end-of-sequence token. 46 | 47 | Returns: 48 | List[int]: A list of token IDs. 49 | """ 50 | assert type(s) is str 51 | t = self.sp_model.encode(s) 52 | if bos: 53 | t = [self.bos_id] + t 54 | if eos: 55 | t = t + [self.eos_id] 56 | return t 57 | 58 | def decode(self, t: List[int]) -> str: 59 | """ 60 | Decodes a list of token IDs into a string. 61 | 62 | Args: 63 | t (List[int]): The list of token IDs to be decoded. 64 | 65 | Returns: 66 | str: The decoded string. 
67 | """ 68 | return self.sp_model.decode(t) 69 | -------------------------------------------------------------------------------- /step_anticipation/data/utils/toys.json: -------------------------------------------------------------------------------- 1 | { 2 | "excavator": [ 3 | "a01", 4 | "b04a", 5 | "b05a", 6 | "b06c", 7 | "c01a", 8 | "c02a", 9 | "c03a", 10 | "c04a", 11 | "c04d", 12 | "c05a", 13 | "c06e", 14 | "c07a", 15 | "c12c", 16 | "c13c", 17 | "c14a" 18 | ], 19 | "bulldozer": [ 20 | "a02", 21 | "a03", 22 | "a07", 23 | "a11", 24 | "b01a", 25 | "b04b", 26 | "b08b", 27 | "b08c", 28 | "c02b", 29 | "c03b", 30 | "c05b", 31 | "c13d" 32 | ], 33 | "clamp": [ 34 | "a06", 35 | "a09", 36 | "b02b", 37 | "b08a", 38 | "c14b" 39 | ], 40 | "crane": [ 41 | "a08", 42 | "b05b", 43 | "b06b", 44 | "c01c", 45 | "c06c", 46 | "c08a", 47 | "c12a" 48 | ], 49 | "garbage_truck": [ 50 | "a10", 51 | "a15", 52 | "c07c", 53 | "c08b", 54 | "c10a" 55 | ], 56 | "dumper": [ 57 | "a12", 58 | "a13", 59 | "a21", 60 | "b02a", 61 | "b05d", 62 | "b06a", 63 | "b08d", 64 | "c01b", 65 | "c03d", 66 | "c06a", 67 | "c09a", 68 | "c10b", 69 | "c12e", 70 | "c13e" 71 | ], 72 | "transporter": [ 73 | "a14", 74 | "c08c", 75 | "c09c" 76 | ], 77 | "ladder_truck": [ 78 | "a16", 79 | "a19", 80 | "b03a", 81 | "c06f", 82 | "c07b", 83 | "c12d" 84 | ], 85 | "fire_truck": [ 86 | "a17", 87 | "a18", 88 | "a20", 89 | "b03b" 90 | ], 91 | "car": [ 92 | "a23", 93 | "a24", 94 | "a26", 95 | "a31", 96 | "c11a" 97 | ], 98 | "suv": [ 99 | "a27", 100 | "a28", 101 | "a29", 102 | "a30", 103 | "c11b" 104 | ], 105 | "roller": [ 106 | "b01b", 107 | "b04d", 108 | "c02c", 109 | "c03f", 110 | "c04c", 111 | "c13a" 112 | ], 113 | "jackhammer": [ 114 | "b04c", 115 | "c03e", 116 | "c04b", 117 | "c13b" 118 | ], 119 | "cement_mixer": [ 120 | "b05c", 121 | "b06d", 122 | "c01d", 123 | "c03c", 124 | "c06b" 125 | ], 126 | "water_tanker": [ 127 | "c06d", 128 | "c09b", 129 | "c10c", 130 | "c12b", 131 | "c13f" 132 | ] 133 | } -------------------------------------------------------------------------------- /step_recognition/trainer/train.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch 3 | from trainer.train_builder import TRAINER 4 | 5 | @TRAINER.register("OAD") 6 | def train_one_epoch(trainloader, model, criterion, optimizer, scaler, epoch, device, writer=None, scheduler=None): 7 | epoch_loss = 0 8 | for it, (rgb_input, flow_input, target, vid, start, end) in enumerate(tqdm(trainloader, desc=f'Epoch:{epoch} Training', postfix=f'lr: {optimizer.param_groups[0]["lr"]:.7f}')): 9 | rgb_input, flow_input, target = rgb_input.to(device), flow_input.to(device), target.to(device) #rgb_input.cuda(), flow_input.cuda(), target.cuda() 10 | model.train() 11 | if scaler != None: 12 | with torch.cuda.amp.autocast(): 13 | out_dict = model(rgb_input, flow_input) 14 | loss = criterion(out_dict, target) 15 | optimizer.zero_grad(set_to_none=True) 16 | scaler.scale(loss).backward() 17 | scaler.step(optimizer) 18 | scaler.update() 19 | else: 20 | out_dict = model(rgb_input, flow_input) 21 | loss = criterion(out_dict, target) 22 | optimizer.zero_grad(set_to_none=True) 23 | loss.backward() 24 | optimizer.step() 25 | 26 | epoch_loss += loss.item() 27 | if writer != None: 28 | writer.add_scalar("Train Loss", loss.item(), it+epoch*len(trainloader)) 29 | return epoch_loss 30 | 31 | @TRAINER.register("ANTICIPATION") 32 | def ant_train_one_epoch(trainloader, model, criterion, optimizer, scaler, epoch, writer=None, 
scheduler=None): 33 | epoch_loss = 0 34 | for it, (rgb_input, flow_input, target, ant_target) in enumerate(tqdm(trainloader, desc=f'Epoch:{epoch} Training', postfix=f'lr: {optimizer.param_groups[0]["lr"]:.7f}')): 35 | rgb_input, flow_input, target, ant_target = rgb_input.cuda(), flow_input.cuda(), target.cuda(), ant_target.cuda() 36 | model.train() 37 | if scaler != None: 38 | with torch.cuda.amp.autocast(): 39 | out_dict = model(rgb_input, flow_input) 40 | loss = criterion(out_dict, target, ant_target) 41 | optimizer.zero_grad(set_to_none=True) 42 | scaler.scale(loss).backward() 43 | scaler.step(optimizer) 44 | scaler.update() 45 | else: 46 | out_dict = model(rgb_input, flow_input) 47 | loss = criterion(out_dict, target, ant_target) 48 | optimizer.zero_grad(set_to_none=True) 49 | loss.backward() 50 | optimizer.step() 51 | epoch_loss += loss.item() 52 | if writer != None: 53 | writer.add_scalar("Train Loss", loss.item(), it+epoch*len(trainloader)) 54 | return epoch_loss -------------------------------------------------------------------------------- /step_anticipation/data/utils/toy2class.json: -------------------------------------------------------------------------------- 1 | { 2 | "a01": "excavator", 3 | "a02": "bulldozer", 4 | "a03": "bulldozer", 5 | "a06": "clamp", 6 | "a07": "bulldozer", 7 | "a08": "crane", 8 | "a09": "clamp", 9 | "a10": "garbage_truck", 10 | "a11": "bulldozer", 11 | "a12": "dumper", 12 | "a13": "dumper", 13 | "a14": "transporter", 14 | "a15": "garbage_truck", 15 | "a16": "ladder_truck", 16 | "a17": "fire_truck", 17 | "a18": "fire_truck", 18 | "a19": "ladder_truck", 19 | "a20": "fire_truck", 20 | "a21": "dumper", 21 | "a23": "car", 22 | "a24": "car", 23 | "a26": "car", 24 | "a27": "suv", 25 | "a28": "suv", 26 | "a29": "suv", 27 | "a30": "suv", 28 | "a31": "car", 29 | "b01a": "bulldozer", 30 | "b01b": "roller", 31 | "b02a": "dumper", 32 | "b02b": "clamp", 33 | "b03a": "ladder_truck", 34 | "b03b": "fire_truck", 35 | "b04a": "excavator", 36 | "b04b": "bulldozer", 37 | "b04c": "jackhammer", 38 | "b04d": "roller", 39 | "b05a": "excavator", 40 | "b05b": "crane", 41 | "b05c": "cement_mixer", 42 | "b05d": "dumper", 43 | "b06a": "dumper", 44 | "b06b": "crane", 45 | "b06c": "excavator", 46 | "b06d": "cement_mixer", 47 | "b08a": "clamp", 48 | "b08b": "bulldozer", 49 | "b08c": "bulldozer", 50 | "b08d": "dumper", 51 | "c01a": "excavator", 52 | "c01b": "dumper", 53 | "c01c": "crane", 54 | "c01d": "cement_mixer", 55 | "c02a": "excavator", 56 | "c02b": "bulldozer", 57 | "c02c": "roller", 58 | "c03a": "excavator", 59 | "c03b": "bulldozer", 60 | "c03c": "cement_mixer", 61 | "c03d": "dumper", 62 | "c03e": "jackhammer", 63 | "c03f": "roller", 64 | "c04a": "excavator", 65 | "c04b": "jackhammer", 66 | "c04c": "roller", 67 | "c04d": "excavator", 68 | "c05a": "excavator", 69 | "c05b": "bulldozer", 70 | "c06a": "dumper", 71 | "c06b": "cement_mixer", 72 | "c06c": "crane", 73 | "c06d": "water_tanker", 74 | "c06e": "excavator", 75 | "c06f": "ladder_truck", 76 | "c07a": "excavator", 77 | "c07b": "ladder_truck", 78 | "c07c": "garbage_truck", 79 | "c08a": "crane", 80 | "c08b": "garbage_truck", 81 | "c08c": "transporter", 82 | "c09a": "dumper", 83 | "c09b": "water_tanker", 84 | "c09c": "transporter", 85 | "c10a": "garbage_truck", 86 | "c10b": "dumper", 87 | "c10c": "water_tanker", 88 | "c11a": "car", 89 | "c11b": "suv", 90 | "c12a": "crane", 91 | "c12b": "water_tanker", 92 | "c12c": "excavator", 93 | "c12d": "ladder_truck", 94 | "c12e": "dumper", 95 | "c13a": "roller", 96 | "c13b": "jackhammer", 97 
| "c13c": "excavator", 98 | "c13d": "bulldozer", 99 | "c13e": "dumper", 100 | "c13f": "water_tanker", 101 | "c14a": "excavator", 102 | "c14b": "clamp" 103 | } -------------------------------------------------------------------------------- /step_recognition/criterions/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from criterions.loss_builder import CRITERIONS 5 | 6 | @CRITERIONS.register('NONUNIFORM') 7 | class OadLoss(nn.Module): 8 | 9 | def __init__(self, cfg, reduction='mean'): 10 | super(OadLoss, self).__init__() 11 | self.reduction = reduction 12 | self.num_classes = cfg['num_classes'] 13 | self.loss = self.end_loss 14 | 15 | def end_loss(self, out_dict, target): 16 | # logits: (B, seq, K) target: (B, seq, K) 17 | logits = out_dict['logits'] 18 | logits = logits[:,-1,:].contiguous() 19 | target = target[:,-1,:].contiguous() 20 | ce_loss = self.mlce_loss(logits, target) 21 | return ce_loss 22 | 23 | def mlce_loss(self, logits, target): 24 | ''' 25 | multi label cross entropy loss. 26 | logits: (B, K) target: (B, K) 27 | ''' 28 | logsoftmax = nn.LogSoftmax(dim=-1).to(logits.device) 29 | output = torch.sum(-F.normalize(target) * logsoftmax(logits), dim=1) # B 30 | if self.reduction == 'mean': 31 | loss = torch.mean(output) 32 | elif self.reduction == 'sum': 33 | loss = torch.sum(output) 34 | return loss 35 | 36 | def forward(self, out_dict, target): 37 | return self.loss(out_dict, target) 38 | 39 | 40 | @CRITERIONS.register('ANTICIPATION') 41 | class OadAntLoss(nn.Module): 42 | 43 | def __init__(self, cfg, reduction='sum'): 44 | super(OadAntLoss, self).__init__() 45 | self.reduction = reduction 46 | self.loss = self.anticipation_loss 47 | self.num_classes = cfg['num_classes'] 48 | 49 | def anticipation_loss(self, out_dict, target, ant_target): 50 | anticipation_logits = out_dict['anticipation_logits'] 51 | pred_anticipation_logits = anticipation_logits[:,-1,:,:].contiguous().view(-1, self.num_classes) 52 | anticipation_logit_targets = ant_target.view(-1, self.num_classes) 53 | ant_loss = self.mlce_loss(pred_anticipation_logits, anticipation_logit_targets) 54 | return ant_loss 55 | 56 | def ce_loss(self, out_dict, target): 57 | # logits: (B, seq, K) target: (B, seq, K) 58 | logits = out_dict['logits'] 59 | logits = logits[:,-1,:].contiguous() 60 | target = target[:,-1,:].contiguous() 61 | ce_loss = self.mlce_loss(logits, target) 62 | return ce_loss 63 | 64 | def mlce_loss(self, logits, target): 65 | ''' 66 | multi label cross entropy loss. 
67 | logits: (B, K) target: (B, K) 68 | ''' 69 | logsoftmax = nn.LogSoftmax(dim=-1).to(logits.device) 70 | output = torch.sum(-F.normalize(target) * logsoftmax(logits), dim=1) # B 71 | if self.reduction == 'mean': 72 | loss = torch.mean(output) 73 | elif self.reduction == 'sum': 74 | loss = torch.sum(output) 75 | 76 | return loss 77 | 78 | def forward(self, out_dict, target, ant_target): 79 | return self.loss(out_dict, target, ant_target) 80 | -------------------------------------------------------------------------------- /step_anticipation/src/utils/parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | # TODO: Add your name here 4 | parser = argparse.ArgumentParser(description="PUT YOUR NAME HERE") 5 | 6 | parser.add_argument( 7 | "--cfg", 8 | type=str, 9 | default="configs/default.yaml", 10 | help="YAML configuration file", 11 | ) 12 | 13 | parser.add_argument("--debug", action="store_true", help="Debug mode") 14 | 15 | # * WandB 16 | parser.add_argument("--wandb-mode", type=str, default="disabled", help="WandB mode") 17 | parser.add_argument("--wandb-group", type=str, default=None, help="WandB group") 18 | parser.add_argument( 19 | "--wandb-name", type=str, default=None, required=True, help="WandB name" 20 | ) 21 | parser.add_argument("--wandb-tags", type=str, default=None, help="WandB tags") 22 | parser.add_argument("--wandb-notes", type=str, default=None, help="WandB notes") 23 | 24 | # * TaskGraph 25 | parser.add_argument("--hold-print", action="store_true", help="Hold print") 26 | parser.add_argument( 27 | "--clustering-th", type=float, default=1.0, help="Clustering distance threshold" 28 | ) 29 | parser.add_argument( 30 | "--match-th", type=float, default=0.46, help="Matching distance threshold" 31 | ) 32 | parser.add_argument( 33 | "--beam-search-th", type=float, default=0.30, help="Beam search distance threshold" 34 | ) 35 | parser.add_argument( 36 | "--dataset", 37 | type=str, 38 | choices=["coin", "crosstask", "assembly-label"], 39 | default="coin", 40 | help="Dataset to use", 41 | ) 42 | parser.add_argument( 43 | "--dataset-path", 44 | type=str, 45 | default="/media/hdd/data/assembly101/data/annotations/", 46 | help="Dataset path", 47 | ) 48 | 49 | 50 | parser.add_argument( 51 | "--eval-mode", type=str, choices=["text"], default="text", help="Evaluation mode" 52 | ) 53 | parser.add_argument( 54 | "--graph-type", type=str, choices=["overall"], default="overall", help="Graph type" 55 | ) 56 | parser.add_argument("--use-clusters", action="store_true", help="Use clusters") 57 | parser.add_argument( 58 | "--method", 59 | type=str, 60 | choices=["beam-search-with-cluster", "baseline-with-cluster"], 61 | default="beam-search-with-cluster", 62 | help="Method to use", 63 | ) 64 | parser.add_argument("--prune-keysteps", action="store_true", help="Prune keysteps") 65 | parser.add_argument("--keysteps-th", type=float, default=0.0, help="Keysteps threshold") 66 | 67 | # * BERT 68 | parser.add_argument("--lm", type=str, default="bert", help="Language model") 69 | parser.add_argument( 70 | "--mask-mode", 71 | type=str, 72 | default="none", 73 | choices=["none", "end", "prob"], 74 | help="Tokenize mode", 75 | ) 76 | parser.add_argument("--batch-size", type=int, default=1, help="Batch size") 77 | parser.add_argument("--tokenize-prob", type=float, default=0.15, help="Tokenize prob") 78 | parser.add_argument("--epochs", type=int, default=100, help="Epochs") 79 | parser.add_argument("--validate-every", type=int, default=10, 
help="Validate every") 80 | 81 | # * Misc 82 | parser.add_argument( 83 | "--device", 84 | type=str, 85 | default="cuda", 86 | help="The GPU or CPU to use, standard PyTorch rules apply", 87 | ) 88 | 89 | args = parser.parse_args() 90 | -------------------------------------------------------------------------------- /step_recognition/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # code from detectron2 3 | 4 | import math 5 | from typing import List 6 | import torch 7 | 8 | try: 9 | from torch.optim.lr_scheduler import LRScheduler 10 | except ImportError: 11 | from torch.optim.lr_scheduler import _LRScheduler as LRScheduler 12 | 13 | class WarmupCosineLR(LRScheduler): 14 | def __init__( 15 | self, 16 | optimizer: torch.optim.Optimizer, 17 | max_iters: int, 18 | warmup_factor: float = 0.001, 19 | warmup_iters: int = 1000, 20 | warmup_method: str = "linear", 21 | last_epoch: int = -1, 22 | ): 23 | self.max_iters = max_iters 24 | self.warmup_factor = warmup_factor 25 | self.warmup_iters = warmup_iters 26 | self.warmup_method = warmup_method 27 | super().__init__(optimizer, last_epoch) 28 | 29 | def get_lr(self) -> List[float]: 30 | warmup_factor = _get_warmup_factor_at_iter( 31 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor 32 | ) 33 | # Different definitions of half-cosine with warmup are possible. For 34 | # simplicity we multiply the standard half-cosine schedule by the warmup 35 | # factor. An alternative is to start the period of the cosine at warmup_iters 36 | # instead of at 0. In the case that warmup_iters << max_iters the two are 37 | # very close to each other. 38 | return [ 39 | base_lr 40 | * warmup_factor 41 | * 0.5 42 | * (1.0 + math.cos(math.pi * self.last_epoch / self.max_iters)) 43 | for base_lr in self.base_lrs 44 | ] 45 | 46 | def _compute_values(self) -> List[float]: 47 | # The new interface 48 | return self.get_lr() 49 | 50 | 51 | def _get_warmup_factor_at_iter( 52 | method: str, iter: int, warmup_iters: int, warmup_factor: float 53 | ) -> float: 54 | """ 55 | Return the learning rate warmup factor at a specific iteration. 56 | See :paper:`ImageNet in 1h` for more details. 57 | 58 | Args: 59 | method (str): warmup method; either "constant" or "linear". 60 | iter (int): iteration at which to calculate the warmup factor. 61 | warmup_iters (int): the number of warmup iterations. 62 | warmup_factor (float): the base warmup factor (the meaning changes according 63 | to the method used). 64 | 65 | Returns: 66 | float: the effective warmup factor at the given iteration. 
67 | """ 68 | if iter >= warmup_iters: 69 | return 1.0 70 | 71 | if method == "constant": 72 | return warmup_factor 73 | elif method == "linear": 74 | alpha = iter / warmup_iters 75 | return warmup_factor * (1 - alpha) + alpha 76 | else: 77 | raise ValueError("Unknown warmup method: {}".format(method)) 78 | 79 | def build_lr_scheduler(cfg, optimizer, iters_per_epoch): 80 | max_iters = cfg['num_epoch'] * iters_per_epoch 81 | warmup_iters = int(cfg['warmup_epoch'] * iters_per_epoch) 82 | # warmup_iters = int(max_iters * 1/6) 83 | scheduler = WarmupCosineLR( 84 | optimizer, 85 | max_iters=max_iters, 86 | warmup_factor=cfg['warmup_factor'], 87 | warmup_iters=warmup_iters, 88 | warmup_method=cfg['warmup_method'], 89 | last_epoch=-1) 90 | return scheduler -------------------------------------------------------------------------------- /step_anticipation/src/data/frequentist_baseline.py: -------------------------------------------------------------------------------- 1 | # First we import the necessary libraries 2 | import torch 3 | from src.data.assemblyLabelDataset import AssemblyLabelDataset 4 | 5 | correct_dataset = AssemblyLabelDataset("./mistake_labels/", split="correct") 6 | mistake_dataset = AssemblyLabelDataset("./mistake_labels/", split="mistake") 7 | 8 | sample_len = 67 9 | initial_padding = tuple([0]*sample_len) 10 | final_padding = tuple([1]*sample_len) 11 | 12 | all_samples = set() 13 | for sample in correct_dataset: 14 | for array in sample["oh_sample"]: 15 | if tuple(array.tolist()) == final_padding: 16 | continue 17 | all_samples.add(tuple(array.tolist())) 18 | for sample in mistake_dataset: 19 | for array in sample["oh_sample"]: 20 | if tuple(array.tolist()) == final_padding: 21 | continue 22 | all_samples.add(tuple(array.tolist())) 23 | 24 | 25 | all_samples = tuple(all_samples) 26 | all_samples = (initial_padding,) + all_samples 27 | 28 | A = torch.zeros((len(all_samples), len(all_samples))) 29 | for sample in correct_dataset: 30 | prev_step = initial_padding 31 | for n, array in enumerate(sample["oh_sample"]): 32 | if tuple(array.tolist()) == final_padding: 33 | continue 34 | A[all_samples.index(prev_step)][ 35 | all_samples.index(tuple(array.tolist())) 36 | ] += 1 37 | prev_step = tuple(array.tolist()) 38 | 39 | threshold = 1/len(all_samples) 40 | 41 | 42 | 43 | for n, line in enumerate(A): 44 | tot = sum(line) 45 | if tot > 0: 46 | A[n] = line / tot 47 | else: 48 | A[n] = torch.zeros(len(all_samples)) + threshold 49 | 50 | labels = [] 51 | gt_labels = [] 52 | 53 | for sample in mistake_dataset: 54 | prev_step = initial_padding 55 | for n, array in enumerate(sample["oh_sample"]): 56 | if tuple(array.tolist()) == final_padding: 57 | continue 58 | p = A[all_samples.index(prev_step)][all_samples.index(tuple(array.tolist()))] 59 | labels.append(0) if p < threshold else labels.append(1) 60 | if int(tuple(sample["oh_label"].tolist()[n])[0]) != 1: 61 | gt_labels.append(0) 62 | elif int(tuple(sample["oh_label"].tolist()[n])[1]) != 1: 63 | gt_labels.append(1) 64 | else: 65 | gt_labels.append(1) 66 | prev_step = tuple(array.tolist()) 67 | 68 | # Calculate accuracy, precision, recall, F1 69 | TP = 0 70 | FP = 0 71 | FN = 0 72 | TN = 0 73 | for n, label in enumerate(labels): 74 | if label == 1 and gt_labels[n] == 1: 75 | TP += 1 76 | elif label == 1 and gt_labels[n] == 0: 77 | FP += 1 78 | elif label == 0 and gt_labels[n] == 1: 79 | FN += 1 80 | elif label == 0 and gt_labels[n] == 0: 81 | TN += 1 82 | 83 | accuracy = (TP + TN) / (TP + FP + FN + TN) 84 | precision = TP / (TP + FP) 85 | 
recall = TP / (TP + FN) 86 | F1 = 2 * (precision * recall) / (precision + recall) 87 | 88 | # Print results 89 | print("Accuracy: {}".format(accuracy)) 90 | print("Precision: {}".format(precision)) 91 | print("Recall: {}".format(recall)) 92 | print("F1: {}".format(F1)) 93 | 94 | print("TP: {}".format(TP)) 95 | print("FP: {}".format(FP)) 96 | print("FN: {}".format(FN)) 97 | print("TN: {}".format(TN)) 98 | 99 | # Accuracy: 0.675739247311828 100 | # Precision: 0.7571277719112989 101 | # Recall: 0.739556472408458 102 | # F1: 0.7482389773023741 103 | # TP: 1434 104 | # FP: 460 105 | # FN: 505 106 | # TN: 577 -------------------------------------------------------------------------------- /utils/aggregate.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Dict, List 3 | 4 | import numpy as np 5 | 6 | 7 | def eliminate_consecutive_duplicates(arr: np.ndarray) -> np.ndarray: 8 | """ 9 | Eliminate consecutive duplicate values in a numpy array. 10 | 11 | Args: 12 | arr (np.ndarray): Input array. 13 | 14 | Returns: 15 | np.ndarray: Array with consecutive duplicates removed. 16 | """ 17 | # Initialize result with the first element 18 | result = [arr[0]] 19 | # Iterate through the array and add elements that are not duplicates 20 | for i in range(1, len(arr)): 21 | if arr[i] != arr[i - 1]: 22 | result.append(arr[i]) 23 | return np.array(result) 24 | 25 | 26 | def find_changes(arr: np.ndarray) -> List[int]: 27 | """ 28 | Find indices where changes occur in the array. 29 | 30 | Args: 31 | arr (np.ndarray): Input array. 32 | 33 | Returns: 34 | List[int]: List of indices where changes occur. 35 | """ 36 | # Initialize result list 37 | result = [] 38 | # Iterate through the array and record indices where changes occur 39 | for i in range(1, len(arr)): 40 | if arr[i] != arr[i - 1]: 41 | result.append(i) 42 | result.append(len(arr)) 43 | return result 44 | 45 | 46 | def aggregate(data: Dict[str, Dict[str, Any]], output_path: str) -> None: 47 | """ 48 | Aggregate predictions and ground truth data, and save the results to a JSON file. 49 | 50 | Args: 51 | data (Dict[str, Dict[str, Any]]): Input data containing predictions and ground truth. 52 | output_path (str): Path to save the aggregated data as a JSON file. 
53 | """ 54 | aggregated_data = {} 55 | window_size = 200 56 | # Iterate through each key-value pair in the data 57 | for key, value in data.items(): 58 | predictions = value["pred"] 59 | ground_truth = value["gt"] 60 | new_predictions = np.zeros_like(predictions) 61 | start_indices = [] 62 | end_indices = [] 63 | 64 | # Process predictions in windows 65 | for start in range(0, len(predictions), window_size): 66 | end = start + window_size 67 | if end > len(predictions): 68 | end = len(predictions) 69 | counts = np.bincount(predictions[start:end]) 70 | new_predictions[start:end] = np.argmax(counts) 71 | start_indices.append(start) 72 | end_indices.append(end) 73 | 74 | # Find changes and eliminate consecutive duplicates 75 | changes_predictions = find_changes(new_predictions) 76 | changes_ground_truth = find_changes(ground_truth) 77 | new_predictions = eliminate_consecutive_duplicates(new_predictions) 78 | ground_truth = eliminate_consecutive_duplicates(ground_truth) 79 | 80 | # Store the results in the aggregated data dictionary 81 | aggregated_data[key] = { 82 | "pred": new_predictions.tolist(), 83 | "gt": ground_truth.tolist(), 84 | "changes_pred": changes_predictions, 85 | "changes_gt": changes_ground_truth, 86 | } 87 | 88 | # Save the aggregated data as a JSON file 89 | with open(output_path, "w") as fp: 90 | json.dump(aggregated_data, fp) 91 | 92 | 93 | if __name__ == "__main__": 94 | import argparse 95 | 96 | parser = argparse.ArgumentParser( 97 | description="Aggregate predictions and ground truth data." 98 | ) 99 | parser.add_argument("input_path", type=str, help="Path to the input JSON file.") 100 | parser.add_argument( 101 | "output_path", type=str, help="Path to save the aggregated JSON file." 102 | ) 103 | args = parser.parse_args() 104 | 105 | # Load the input data from a JSON file 106 | with open(args.input_path, "r") as fp: 107 | data = json.load(fp) 108 | # Call the aggregate function with the loaded data and output path 109 | aggregate(data, args.output_path) 110 | -------------------------------------------------------------------------------- /data/output/aggregated_data.json: -------------------------------------------------------------------------------- 1 | {"annotations_1": {"pred": [4, 9, 10, 4, 10, 4, 10, 9, 10, 9, 2, 10, 9, 10, 9, 10, 2, 10, 2, 10, 2, 10, 2, 3, 10], "gt": [4, 10, 0, 4, 2, 3, 0, 2, 3], "changes_pred": [200, 400, 2000, 2800, 3200, 3400, 4200, 4600, 5800, 6200, 7200, 8400, 9600, 10800, 11000, 12200, 12400, 12800, 13000, 13200, 14800, 15600, 15800, 16200, 17280], "changes_gt": [48, 1407, 1984, 2781, 7209, 8407, 9402, 15820, 17280]}, "annotations_2": {"pred": [4, 9, 7, 10, 9, 10, 4, 10, 4, 0, 4, 0, 4, 2, 4, 0, 2, 0], "gt": [4, 7, 10, 6, 10, 6, 0, 10, 2, 3], "changes_pred": [1400, 2000, 2200, 2400, 2600, 3800, 4200, 4400, 4600, 5000, 5200, 5800, 6000, 6800, 7200, 7600, 7800, 9015], "changes_gt": [1466, 2161, 2345, 2369, 3697, 4591, 5677, 5761, 8744, 9015]}, "annotations_3": {"pred": [4, 7, 4, 0, 2, 4, 2, 4, 6, 4, 0, 4, 2, 4, 8, 2, 0, 2, 4, 2, 3, 4, 2, 3, 4], "gt": [4, 7, 10, 4, 7, 6, 4, 0, 4, 10, 2, 0, 2, 4, 3, 4, 3, 5], "changes_pred": [2400, 2600, 2800, 4000, 4200, 4400, 4600, 6400, 6800, 7600, 8400, 9800, 10000, 10600, 11000, 20600, 20800, 24600, 26000, 27400, 27600, 28400, 29400, 30800, 31114], "changes_gt": [2187, 2700, 4523, 5765, 5869, 6769, 7215, 8321, 9571, 9998, 19386, 20846, 24554, 25970, 27539, 28358, 30263, 31114]}, "annotations_4": {"pred": [4, 9, 4, 9, 10, 2, 4, 2], "gt": [4, 7, 10, 6, 4, 0, 2, 3], "changes_pred": [1200, 1400, 1600, 
1800, 2600, 4200, 4400, 10046], "changes_gt": [433, 862, 2612, 3925, 4422, 6558, 9602, 10046]}, "annotations_5": {"pred": [4, 10, 2, 4, 2], "gt": [4, 7, 4, 10, 4, 6, 5, 4, 0, 4, 0, 4, 2, 3], "changes_pred": [1400, 2600, 2800, 7000, 9874], "changes_gt": [710, 732, 1167, 2629, 3090, 3352, 3520, 3721, 4698, 4810, 5941, 6610, 9544, 9874]}, "annotations_6": {"pred": [10, 0, 2, 3, 2, 3, 8], "gt": [10, 5, 10, 6, 0, 2, 0, 2, 0, 3, 8], "changes_pred": [5200, 5600, 7200, 7400, 8200, 10800, 11744], "changes_gt": [102, 178, 1216, 1857, 2340, 4567, 5791, 7131, 8230, 9800, 11744]}, "annotations_7": {"pred": [4, 10, 4, 10, 4], "gt": [4, 7, 6, 0], "changes_pred": [2400, 2600, 3800, 4400, 5157], "changes_gt": [1170, 3035, 4274, 5157]}, "annotations_10": {"pred": [4, 10, 4, 0, 7, 4, 0, 4, 10, 2, 10, 2, 10, 7, 3, 4, 2, 10, 4, 2, 7, 0, 7], "gt": [4, 7, 4, 10, 4, 6, 4, 5, 0, 4, 2, 4, 5, 1], "changes_pred": [1200, 3000, 3800, 4400, 4800, 5000, 5200, 6000, 7000, 7200, 7600, 8800, 9400, 9600, 10000, 10200, 11000, 11400, 12000, 12200, 12400, 12800, 12807], "changes_gt": [314, 368, 896, 2932, 3038, 3536, 3714, 3788, 5200, 5490, 11377, 11679, 12273, 12807]}, "annotations_16": {"pred": [4, 7, 4, 10, 4, 6, 4, 2, 4, 0, 4, 10], "gt": [4, 7, 4, 10, 4, 6, 0, 4, 0, 4, 2, 10, 2, 4, 10, 3], "changes_pred": [1400, 1800, 2200, 2800, 4000, 4600, 4800, 5000, 5200, 5400, 12200, 13001], "changes_gt": [1273, 1739, 2103, 2758, 3753, 4884, 4898, 5106, 6475, 9077, 9320, 10092, 10385, 12104, 12181, 13001]}, "annotations_17": {"pred": [4, 8, 1, 4, 10, 4, 10, 6, 0, 10, 6, 2], "gt": [4, 7, 10, 6, 8, 6, 10, 0, 2], "changes_pred": [2600, 3200, 3400, 3600, 5600, 5800, 6000, 6200, 7600, 8400, 9600, 11734], "changes_gt": [2478, 3243, 3800, 3836, 3845, 7639, 8326, 9669, 11734]}, "annotations_18": {"pred": [4, 7, 10, 6, 0, 2, 0, 2], "gt": [4, 7, 10, 6, 0, 10, 2, 3, 5], "changes_pred": [2600, 3000, 5200, 6200, 8400, 8800, 9000, 21803], "changes_gt": [2677, 3687, 4995, 6393, 8951, 9145, 21085, 21353, 21803]}, "annotations_20": {"pred": [4, 10, 4, 10, 6, 4, 6, 4, 6, 4, 2, 4], "gt": [4, 7, 10, 4, 6, 4, 0, 4, 2, 4, 2, 4, 11], "changes_pred": [3000, 3200, 4600, 4800, 6000, 6800, 7000, 7600, 7800, 12200, 12400, 12976], "changes_gt": [646, 1489, 3290, 3879, 4133, 4773, 8039, 8500, 8702, 8847, 9086, 11338, 12976]}, "annotations_23": {"pred": [4, 7, 8, 10], "gt": [4, 7, 10], "changes_pred": [2600, 2800, 3200, 3702], "changes_gt": [2375, 2703, 3702]}, "annotations_25": {"pred": [4, 10], "gt": [4, 7, 4, 7, 4, 10, 6], "changes_pred": [6000, 6971], "changes_gt": [2526, 2670, 2954, 5293, 5802, 6623, 6971]}, "annotations_28": {"pred": [2, 4, 10, 4, 2, 4, 2, 10, 2, 4, 2, 4, 2], "gt": [4, 7, 4, 10, 4, 10, 8, 4, 6, 5, 4, 10, 6, 10, 9, 10, 9, 4, 6, 4, 0, 4, 0], "changes_pred": [1000, 2800, 3800, 4000, 4400, 5800, 6000, 6400, 7400, 8600, 9400, 10200, 10735], "changes_gt": [14, 804, 2626, 3165, 3217, 3757, 3876, 4060, 4284, 4452, 5567, 5758, 5858, 6468, 6746, 7157, 7259, 8428, 8796, 8834, 9437, 10207, 10735]}} -------------------------------------------------------------------------------- /step_recognition/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.tensorboard import SummaryWriter 3 | import argparse 4 | import yaml 5 | import os 6 | import os.path as osp 7 | from utils import get_logger 8 | from model import build_model 9 | from datasets import build_data_loader 10 | from criterions import build_criterion 11 | from trainer import build_trainer, build_eval 12 | from utils import * 
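# Example invocation (illustrative; flags and config path are just one possible choice):
#   python main.py --config ./configs/miniroad_epic-tent-O.yaml --lr_scheduler --tensorboard
# The YAML config is loaded first, then opt.update(vars(args)) overwrites any key that
# shares its name with a command-line flag, so CLI values take precedence.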
13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | "--config", type=str, default="./configs/miniroad_thumos_kinetics.yaml" 18 | ) 19 | parser.add_argument("--eval", type=str, default=None) 20 | parser.add_argument("--amp", action="store_true") 21 | parser.add_argument("--tensorboard", action="store_true") 22 | parser.add_argument("--lr_scheduler", action="store_true") 23 | parser.add_argument("--no_rgb", action="store_true") 24 | parser.add_argument("--no_flow", action="store_true") 25 | args = parser.parse_args() 26 | 27 | # combine argparse and yaml 28 | opt = yaml.load(open(args.config), Loader=yaml.FullLoader) 29 | opt.update(vars(args)) 30 | cfg = opt 31 | 32 | set_seed(20) 33 | device = "cuda:1" if torch.cuda.is_available() else "cpu" 34 | # print gpu type 35 | print(torch.cuda.get_device_name(0)) 36 | 37 | identifier = f'{cfg["model"]}_{cfg["data_name"]}_{cfg["feature_pretrained"]}_flow{not cfg["no_flow"]}' 38 | result_path = create_outdir(osp.join(cfg["output_path"], identifier)) 39 | logger = get_logger(result_path) 40 | logger.info(cfg) 41 | 42 | testloader = build_data_loader(cfg, mode="test") 43 | trainloader = build_data_loader(cfg, mode="train") 44 | model = build_model(cfg, device) 45 | evaluate = build_eval(cfg) 46 | 47 | if args.eval != None: 48 | model.load_state_dict(torch.load(args.eval)) 49 | mAP = evaluate( 50 | model, 51 | testloader, #! pass trainloader here to evaluate on the training split 52 | # testloader 53 | logger, 54 | device, 55 | ) 56 | logger.info(f'{cfg["task"]} result: {mAP*100:.2f} m{cfg["metric"]}') 57 | exit() 58 | 59 | trainloader = build_data_loader(cfg, mode="train") 60 | criterion = build_criterion(cfg, device) 61 | train_one_epoch = build_trainer(cfg) 62 | optim = torch.optim.AdamW if cfg["optimizer"] == "AdamW" else torch.optim.Adam 63 | optimizer = optim( 64 | [{"params": model.parameters(), "initial_lr": cfg["lr"]}], 65 | lr=cfg["lr"], 66 | weight_decay=cfg["weight_decay"], 67 | ) 68 | 69 | scheduler = ( 70 | build_lr_scheduler(cfg, optimizer, len(trainloader)) 71 | if args.lr_scheduler 72 | else None 73 | ) 74 | writer = SummaryWriter(osp.join(result_path, "runs")) if args.tensorboard else None 75 | scaler = torch.cuda.amp.GradScaler() if args.amp else None 76 | total_params = sum(p.numel() for p in model.parameters()) 77 | 78 | logger.info(f'Dataset: {cfg["data_name"]}, Model: {cfg["model"]}') 79 | logger.info( 80 | f'lr:{cfg["lr"]} | Weight Decay:{cfg["weight_decay"]} | Window Size:{cfg["window_size"]} | Batch Size:{cfg["batch_size"]}' 81 | ) 82 | logger.info( 83 | f'Total epoch:{cfg["num_epoch"]} | Total Params:{total_params/1e6:.1f} M | Optimizer: {cfg["optimizer"]}' 84 | ) 85 | logger.info(f"Output Path:{result_path}") 86 | 87 | best_mAP, best_epoch = 0, 0 88 | for epoch in range(1, cfg["num_epoch"] + 1): 89 | epoch_loss = train_one_epoch( 90 | trainloader, 91 | model, 92 | criterion, 93 | optimizer, 94 | scaler, 95 | epoch, 96 | device, 97 | writer, 98 | scheduler=scheduler, 99 | ) 100 | trainloader.dataset._init_features() 101 | mAP = evaluate(model, testloader, logger, device) 102 | print("Current mAP:", mAP) 103 | if mAP > best_mAP: 104 | best_mAP = mAP 105 | best_epoch = epoch 106 | print("Checkpoint Saved at ", osp.join(result_path, "ckpts", "best.pth")) 107 | torch.save(model.state_dict(), osp.join(result_path, "ckpts", "best.pth")) 108 | logger.info( 109 | f'Epoch {epoch} mAP: {mAP*100:.2f} | Best mAP: {best_mAP*100:.2f} at epoch {best_epoch}, iter 
{epoch*cfg["batch_size"]*len(trainloader)} | train_loss: {epoch_loss/len(trainloader):.4f}, lr: {optimizer.param_groups[0]["lr"]:.7f}' 110 | ) 111 | 112 | os.rename( 113 | osp.join(result_path, "ckpts", "best.pth"), 114 | osp.join(result_path, "ckpts", f"best_{best_mAP*100:.2f}.pth"), 115 | ) 116 | -------------------------------------------------------------------------------- /step_recognition/utils/metrics.py: -------------------------------------------------------------------------------- 1 | # Evaluation code from LSTR 2 | 3 | from multiprocessing import Pool 4 | from collections import OrderedDict 5 | 6 | import numpy as np 7 | from sklearn.metrics import average_precision_score 8 | 9 | 10 | def calibrated_average_precision_score(y_true, y_score): 11 | """Compute calibrated average precision (cAP), which is particularly 12 | proposed for the TVSeries dataset. 13 | """ 14 | y_true_sorted = y_true[np.argsort(-y_score)] 15 | tp = y_true_sorted.astype(float) 16 | fp = np.abs(y_true_sorted.astype(float) - 1) 17 | tps = np.cumsum(tp) 18 | fps = np.cumsum(fp) 19 | ratio = np.sum(tp == 0) / np.sum(tp) 20 | cprec = tps / (tps + fps / (ratio + np.finfo(float).eps) + np.finfo(float).eps) 21 | cap = np.sum(cprec[tp == 1]) / np.sum(tp) 22 | return cap 23 | 24 | 25 | def perframe_average_precision(prediction, ground_truth, class_names, 26 | postprocessing=None, metrics='AP'): 27 | """Compute (frame-level) average precision between ground truth and 28 | predictions data frames. 29 | """ 30 | result = OrderedDict() 31 | ground_truth = np.array(ground_truth) 32 | prediction = np.array(prediction) 33 | 34 | # Postprocessing 35 | if postprocessing is not None: 36 | ground_truth, prediction = postprocessing(ground_truth, prediction) 37 | 38 | # Build metrics 39 | if metrics == 'AP': 40 | compute_score = average_precision_score 41 | elif metrics == 'cAP': 42 | # print('cAP') 43 | compute_score = calibrated_average_precision_score 44 | else: 45 | raise RuntimeError('Unknown metrics: {}'.format(metrics)) 46 | 47 | # Ignore backgroud class 48 | ignore_index = set([0]) 49 | 50 | # Compute average precision 51 | result['per_class_AP'] = OrderedDict() 52 | result['num'] = OrderedDict() 53 | print(f"NUM FRAMES: {np.sum(ground_truth[:, 1:])}") 54 | for idx, class_name in enumerate(class_names): 55 | if idx not in ignore_index: 56 | if np.any(ground_truth[:, idx]): 57 | ap_score = compute_score(ground_truth[:, idx], prediction[:, idx]) 58 | result['per_class_AP'][class_name] = ap_score 59 | result['num'][class_name] = f'[true: {int(np.sum(ground_truth[:, idx]))}, pred:{int(np.sum(prediction[:,idx]))}, AP:{ap_score*100:.1f}]' 60 | result['mean_AP'] = np.mean(list(result['per_class_AP'].values())) 61 | 62 | return result 63 | 64 | def get_stage_pred_scores(gt_targets, pred_scores, perc_s, perc_e): 65 | starts = [] 66 | ends = [] 67 | stage_gt_targets = [] 68 | stage_pred_scores = [] 69 | for i in range(len(gt_targets)): 70 | if gt_targets[i] == 0: 71 | stage_gt_targets.append(gt_targets[i]) 72 | stage_pred_scores.append(pred_scores[i]) 73 | else: 74 | if i == 0 or gt_targets[i - 1] == 0: 75 | starts.append(i) 76 | if i == len(gt_targets) - 1 or gt_targets[i + 1] == 0: 77 | ends.append(i) 78 | if len(starts) != len(ends): 79 | raise ValueError('starts and ends cannot pair!') 80 | 81 | action_lens = [ends[i] - starts[i] for i in range(len(starts))] 82 | stage_starts = [starts[i] + int(action_lens[i] * perc_s) for i in range(len(starts))] 83 | stage_ends = [max(stage_starts[i] + 1, starts[i] + int(action_lens[i] * 
perc_e)) for i in range(len(starts))] 84 | for i in range(len(starts)): 85 | stage_gt_targets.extend(gt_targets[stage_starts[i]: stage_ends[i]]) 86 | stage_pred_scores.extend(pred_scores[stage_starts[i]: stage_ends[i]]) 87 | return np.array(stage_gt_targets), np.array(stage_pred_scores) 88 | 89 | 90 | def perstage_average_precision(prediction, ground_truth, 91 | class_names, postprocessing, 92 | metrics='cAP'): 93 | result = OrderedDict() 94 | ground_truth = np.array(ground_truth) 95 | prediction = np.array(prediction) 96 | 97 | # Postprocessing 98 | if postprocessing is not None: 99 | ground_truth, prediction = postprocessing(ground_truth, prediction) 100 | 101 | # Build metrics 102 | if metrics == 'AP': 103 | compute_score = average_precision_score 104 | elif metrics == 'cAP': 105 | compute_score = calibrated_average_precision_score 106 | else: 107 | raise RuntimeError('Unknown metrics: {}'.format(metrics)) 108 | 109 | # Ignore backgroud class 110 | ignore_index = set([0]) 111 | 112 | # Compute average precision 113 | for perc_s in range(10): 114 | perc_e = perc_s + 1 115 | stage_name = '{:2}%_{:3}%'.format(perc_s * 10, perc_e * 10) 116 | result[stage_name] = OrderedDict({'per_class_AP': OrderedDict()}) 117 | for idx, class_name in enumerate(class_names): 118 | if idx not in ignore_index: 119 | stage_gt_targets, stage_pred_scores = get_stage_pred_scores( 120 | (ground_truth[:, idx] == 1).astype(int), 121 | prediction[:, idx], 122 | perc_s / 10, 123 | perc_e / 10, 124 | ) 125 | result[stage_name]['per_class_AP'][class_name] = \ 126 | compute_score(stage_gt_targets, stage_pred_scores) 127 | result[stage_name]['mean_AP'] = \ 128 | np.mean(list(result[stage_name]['per_class_AP'].values())) 129 | 130 | return result 131 | -------------------------------------------------------------------------------- /step_recognition/model/rnn/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from model.model_builder import META_ARCHITECTURES 5 | 6 | FEATURE_SIZES = { 7 | 'rgb_anet_resnet50': 2048, 8 | 'flow_anet_resnet50': 2048, 9 | 'rgb_kinetics_bninception': 1024, 10 | 'flow_kinetics_bninception': 1024, 11 | 'rgb_kinetics_resnet50': 2048, 12 | 'flow_kinetics_resnet50': 2048, 13 | 'flow_nv_kinetics_bninception': 1024, 14 | 'rgb_kinetics_i3d': 2048, 15 | 'flow_kinetics_i3d': 2048 16 | } 17 | 18 | @META_ARCHITECTURES.register("MiniROAD") 19 | class MROAD(nn.Module): 20 | 21 | def __init__(self, cfg): 22 | super(MROAD, self).__init__() 23 | self.use_flow = not cfg['no_flow'] 24 | self.use_rgb = not cfg['no_rgb'] 25 | self.input_dim = 0 26 | if self.use_rgb: 27 | self.input_dim += FEATURE_SIZES[cfg['rgb_type']] 28 | if self.use_flow: 29 | self.input_dim += FEATURE_SIZES[cfg['flow_type']] 30 | 31 | self.hidden_dim = cfg['hidden_dim'] 32 | self.num_layers = cfg['num_layers'] 33 | self.out_dim = cfg['num_classes'] 34 | self.window_size = cfg['window_size'] 35 | 36 | self.relu = nn.ReLU() 37 | self.embedding_dim = cfg['embedding_dim'] 38 | self.gru = nn.GRU(self.embedding_dim, self.hidden_dim, self.num_layers, batch_first=True) 39 | self.layer1 = nn.Sequential( 40 | nn.Linear(self.input_dim, self.embedding_dim), 41 | nn.LayerNorm(self.embedding_dim), 42 | nn.ReLU(), 43 | nn.Dropout(p=cfg['dropout']), 44 | ) 45 | self.f_classification = nn.Sequential( 46 | nn.Linear(self.hidden_dim, self.out_dim) 47 | ) 48 | # self.h0 = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) 49 | self.h0 = 
torch.zeros(self.num_layers, 1, self.hidden_dim) 50 | 51 | def forward(self, rgb_input, flow_input): 52 | if self.use_rgb and self.use_flow: 53 | x = torch.cat((rgb_input, flow_input), 2) 54 | elif self.use_rgb: 55 | x = rgb_input 56 | elif self.use_flow: 57 | x = flow_input 58 | x = self.layer1(x) 59 | B, _, _ = x.shape 60 | h0 = self.h0.expand(-1, B, -1).to(x.device) 61 | ht, _ = self.gru(x, h0) 62 | ht = self.relu(ht) 63 | # ht = self.relu(ht + x) 64 | logits = self.f_classification(ht) 65 | out_dict = {} 66 | if self.training: 67 | out_dict['logits'] = logits 68 | else: 69 | pred_scores = F.softmax(logits, dim=-1) 70 | out_dict['logits'] = pred_scores 71 | return out_dict 72 | 73 | @META_ARCHITECTURES.register("MiniROADA") 74 | class MROADA(nn.Module): 75 | 76 | def __init__(self, cfg): 77 | super(MROADA, self).__init__() 78 | self.use_flow = not cfg['no_flow'] 79 | self.use_rgb = not cfg['no_rgb'] 80 | self.input_dim = 0 81 | if self.use_rgb: 82 | self.input_dim += FEATURE_SIZES[cfg['rgb_type']] 83 | if self.use_flow: 84 | self.input_dim += FEATURE_SIZES[cfg['flow_type']] 85 | 86 | self.embedding_dim = cfg['embedding_dim'] 87 | self.hidden_dim = cfg['hidden_dim'] 88 | self.num_layers = cfg['num_layers'] 89 | self.anticipation_length = cfg["anticipation_length"] 90 | self.out_dim = cfg['num_classes'] 91 | 92 | self.layer1 = nn.Sequential( 93 | nn.Linear(self.input_dim, self.embedding_dim), 94 | nn.LayerNorm(self.embedding_dim), 95 | nn.ReLU(), 96 | nn.Dropout(p=cfg['dropout']) 97 | ) 98 | self.actionness = cfg['actionness'] 99 | if self.actionness: 100 | self.f_actionness = nn.Sequential( 101 | nn.Linear(self.hidden_dim, 1), 102 | ) 103 | self.relu = nn.ReLU() 104 | self.gru = nn.GRU(self.embedding_dim, self.hidden_dim, self.num_layers, batch_first=True) 105 | self.f_classification = nn.Sequential( 106 | nn.Linear(self.hidden_dim, self.out_dim) 107 | ) 108 | self.anticipation_layer = nn.Sequential( 109 | nn.Linear(self.hidden_dim, self.anticipation_length*self.hidden_dim), 110 | ) 111 | 112 | 113 | def forward(self, rgb_input, flow_input): 114 | if self.use_rgb and self.use_flow: 115 | x = torch.cat((rgb_input, flow_input), 2) 116 | elif self.use_rgb: 117 | x = rgb_input 118 | elif self.use_flow: 119 | x = flow_input 120 | B, S, _ = x.shape 121 | x = self.layer1(x) 122 | h0 = torch.zeros(1, B, self.hidden_dim).to(x.device) 123 | ht, _ = self.gru(x, h0) 124 | logits = self.f_classification(self.relu(ht)) 125 | anticipation_ht = self.anticipation_layer(self.relu(ht)).view(B, S, self.anticipation_length, self.hidden_dim) 126 | anticipation_logits = self.f_classification(self.relu(anticipation_ht)) 127 | out_dict = {} 128 | if self.training: 129 | out_dict['logits'] = logits 130 | out_dict['anticipation_logits'] = anticipation_logits 131 | else: 132 | pred_scores = F.softmax(logits, dim=-1) 133 | pred_anticipation_scores = F.softmax(anticipation_logits, dim=-1) 134 | out_dict['logits'] = pred_scores 135 | out_dict['anticipation_logits'] = pred_anticipation_scores 136 | 137 | return out_dict -------------------------------------------------------------------------------- /step_recognition/model/transformer_models/Attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class SelfAttention(nn.Module): 8 | def __init__( 9 | self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0 10 | ): 11 | super().__init__() 12 | self.num_heads = heads 
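        # Note: standard multi-head scaled dot-product attention. The qkv projection
        # maps an input of shape (B, N, C) to 3 * C channels, split into `heads` heads
        # of head_dim = C // heads each; scores are scaled by head_dim ** -0.5 unless
        # qk_scale overrides it.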
13 | head_dim = dim // heads 14 | self.scale = qk_scale or head_dim ** -0.5 15 | 16 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 17 | self.attn_drop = nn.Dropout(dropout_rate) 18 | self.proj = nn.Linear(dim, dim) 19 | self.proj_drop = nn.Dropout(dropout_rate) 20 | 21 | def forward(self, x): 22 | B, N, C = x.shape 23 | qkv = ( 24 | self.qkv(x) 25 | .reshape(B, N, 3, self.num_heads, C // self.num_heads) 26 | .permute(2, 0, 3, 1, 4) 27 | ) 28 | q, k, v = ( 29 | qkv[0], 30 | qkv[1], 31 | qkv[2], 32 | ) # make torchscript happy (cannot use tensor as tuple) 33 | 34 | attn = (q @ k.transpose(-2, -1)) * self.scale 35 | attn = attn.softmax(dim=-1) 36 | attn = self.attn_drop(attn) 37 | 38 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 39 | x = self.proj(x) 40 | x = self.proj_drop(x) 41 | return x 42 | 43 | 44 | class AxialAttention(nn.Module): 45 | def __init__( 46 | self, 47 | in_planes, 48 | out_planes, 49 | groups=8, 50 | kernel_size=56, 51 | stride=1, 52 | bias=False, 53 | width=False, 54 | ): 55 | assert (in_planes % groups == 0) and (out_planes % groups == 0) 56 | super(AxialAttention, self).__init__() 57 | self.in_planes = in_planes 58 | self.out_planes = out_planes 59 | self.groups = groups 60 | self.group_planes = out_planes // groups 61 | self.kernel_size = kernel_size 62 | self.stride = stride 63 | self.bias = bias 64 | self.width = width 65 | 66 | # Multi-head self attention 67 | self.qkv_transform = nn.Conv1d( 68 | in_planes, 69 | out_planes * 2, 70 | kernel_size=1, 71 | stride=1, 72 | padding=0, 73 | bias=False, 74 | ) 75 | self.bn_qkv = nn.BatchNorm1d(out_planes * 2) 76 | self.bn_similarity = nn.BatchNorm2d(groups * 3) 77 | self.bn_output = nn.BatchNorm1d(out_planes * 2) 78 | 79 | # Position embedding 80 | self.relative = nn.Parameter( 81 | torch.randn(self.group_planes * 2, kernel_size * 2 - 1), 82 | requires_grad=True, 83 | ) 84 | query_index = torch.arange(kernel_size).unsqueeze(0) 85 | key_index = torch.arange(kernel_size).unsqueeze(1) 86 | relative_index = key_index - query_index + kernel_size - 1 87 | self.register_buffer('flatten_index', relative_index.view(-1)) 88 | if stride > 1: 89 | self.pooling = nn.AvgPool2d(stride, stride=stride) 90 | 91 | self.reset_parameters() 92 | 93 | def forward(self, x): 94 | if self.width: 95 | x = x.permute(0, 2, 1, 3) 96 | else: 97 | x = x.permute(0, 3, 1, 2) # N, W, C, H 98 | N, W, C, H = x.shape 99 | x = x.contiguous().view(N * W, C, H) 100 | 101 | # Transformations 102 | qkv = self.bn_qkv(self.qkv_transform(x)) 103 | q, k, v = torch.split( 104 | qkv.reshape(N * W, self.groups, self.group_planes * 2, H), 105 | [self.group_planes // 2, self.group_planes // 2, self.group_planes], 106 | dim=2, 107 | ) 108 | 109 | # Calculate position embedding 110 | all_embeddings = torch.index_select( 111 | self.relative, 1, self.flatten_index 112 | ).view(self.group_planes * 2, self.kernel_size, self.kernel_size) 113 | q_embedding, k_embedding, v_embedding = torch.split( 114 | all_embeddings, 115 | [self.group_planes // 2, self.group_planes // 2, self.group_planes], 116 | dim=0, 117 | ) 118 | qr = torch.einsum('bgci,cij->bgij', q, q_embedding) 119 | kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3) 120 | qk = torch.einsum('bgci, bgcj->bgij', q, k) 121 | stacked_similarity = torch.cat([qk, qr, kr], dim=1) 122 | stacked_similarity = ( 123 | self.bn_similarity(stacked_similarity) 124 | .view(N * W, 3, self.groups, H, H) 125 | .sum(dim=1) 126 | ) 127 | 128 | # (N, groups, H, H, W) 129 | similarity = F.softmax(stacked_similarity, dim=3) 
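        # The softmax weights are applied twice below: `sv` attends over the value
        # tensor and `sve` over the relative positional value embedding; the two are
        # concatenated, batch-normalized and summed into the output.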
130 | sv = torch.einsum('bgij,bgcj->bgci', similarity, v) 131 | sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding) 132 | stacked_output = torch.cat([sv, sve], dim=-1).view( 133 | N * W, self.out_planes * 2, H 134 | ) 135 | output = ( 136 | self.bn_output(stacked_output) 137 | .view(N, W, self.out_planes, 2, H) 138 | .sum(dim=-2) 139 | ) 140 | 141 | if self.width: 142 | output = output.permute(0, 2, 1, 3) 143 | else: 144 | output = output.permute(0, 2, 3, 1) 145 | 146 | if self.stride > 1: 147 | output = self.pooling(output) 148 | 149 | return output 150 | 151 | def reset_parameters(self): 152 | self.qkv_transform.weight.data.normal_( 153 | 0, math.sqrt(1.0 / self.in_planes) 154 | ) 155 | nn.init.normal_(self.relative, 0.0, math.sqrt(1.0 / self.group_planes)) 156 | -------------------------------------------------------------------------------- /step_recognition/model/transformer_models/BiT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | 6 | 7 | class StdConv2d(nn.Conv2d): 8 | def forward(self, x): 9 | w = self.weight 10 | v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) 11 | w = (w - m) / torch.sqrt(v + 1e-10) 12 | return F.conv2d( 13 | x, 14 | w, 15 | self.bias, 16 | self.stride, 17 | self.padding, 18 | self.dilation, 19 | self.groups, 20 | ) 21 | 22 | 23 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 24 | """3x3 convolution with padding""" 25 | return StdConv2d( 26 | in_planes, 27 | out_planes, 28 | kernel_size=3, 29 | stride=stride, 30 | padding=dilation, 31 | groups=groups, 32 | bias=False, 33 | dilation=dilation, 34 | ) 35 | 36 | 37 | def conv1x1(in_planes, out_planes, stride=1): 38 | """1x1 convolution""" 39 | return StdConv2d( 40 | in_planes, out_planes, kernel_size=1, stride=stride, bias=False 41 | ) 42 | 43 | 44 | class PreActBottleneck(nn.Module): 45 | """Pre-activation (v2) bottleneck block. 46 | Follows the implementation of "Identity Mappings in Deep Residual Networks": 47 | https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua 48 | Except it puts the stride on 3x3 conv when available. 49 | """ 50 | 51 | def __init__(self, in_planes, out_planes=None, mid_planes=None, stride=1): 52 | super(PreActBottleneck, self).__init__() 53 | out_planes = out_planes or in_planes 54 | mid_planes = mid_planes or out_planes // 4 55 | 56 | self.gn1 = nn.GroupNorm(32, in_planes) 57 | self.conv1 = conv1x1(in_planes, mid_planes) 58 | self.gn2 = nn.GroupNorm(32, mid_planes) 59 | self.conv2 = conv3x3(mid_planes, mid_planes, stride) 60 | self.gn3 = nn.GroupNorm(32, mid_planes) 61 | self.conv3 = conv1x1(mid_planes, out_planes) 62 | self.relu = nn.ReLU(inplace=True) 63 | 64 | if stride != 1 or in_planes != out_planes: 65 | # Projection also with pre-activation according to paper. 
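            # conv1x1 matches both the channel count and the stride of the main
            # branch so the residual can be added elementwise in forward().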
66 | self.downsample = conv1x1(in_planes, out_planes, stride) 67 | 68 | def forward(self, x): 69 | out = self.relu(self.gn1(x)) 70 | 71 | # Residual branch 72 | residual = x 73 | if hasattr(self, 'downsample'): 74 | residual = self.downsample(out) 75 | 76 | # Unit's branch 77 | out = self.conv1(out) 78 | out = self.conv2(self.relu(self.gn2(out))) 79 | out = self.conv3(self.relu(self.gn3(out))) 80 | 81 | return out + residual 82 | 83 | 84 | class ResNetV2Model(nn.Module): 85 | """Implementation of Pre-activation (v2) ResNet mode.""" 86 | 87 | def __init__(self, block_units, width_factor, head_size=21843): 88 | super(ResNetV2Model, self).__init__() 89 | wf = width_factor # shortcut 'cause we'll use it a lot. 90 | 91 | # The following will be unreadable if we split lines. 92 | # pylint: disable=line-too-long 93 | self.conv1 = nn.Sequential(OrderedDict([ 94 | ('conv', StdConv2d(3, 64*wf, kernel_size=7, stride=2, padding=3, bias=False)), 95 | ('pad', nn.ConstantPad2d(1, 0)), 96 | ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)), 97 | ])) 98 | 99 | self.conv2 = nn.Sequential(OrderedDict( 100 | [('unit01', PreActBottleneck(in_planes=64*wf, out_planes=256*wf, mid_planes=64*wf))] + 101 | [(f'unit{i:02d}', PreActBottleneck(in_planes=256*wf, out_planes=256*wf, mid_planes=64*wf)) for i in range(2, block_units[0] + 1)], 102 | )) 103 | self.conv3 = nn.Sequential(OrderedDict( 104 | [('unit01', PreActBottleneck(in_planes=256*wf, out_planes=512*wf, mid_planes=128*wf, stride=2))] + 105 | [(f'unit{i:02d}', PreActBottleneck(in_planes=512*wf, out_planes=512*wf, mid_planes=128*wf)) for i in range(2, block_units[1] + 1)], 106 | )) 107 | self.conv4 = nn.Sequential(OrderedDict( 108 | [('unit01', PreActBottleneck(in_planes=512*wf, out_planes=1024*wf, mid_planes=256*wf, stride=2))] + 109 | [(f'unit{i:02d}', PreActBottleneck(in_planes=1024*wf, out_planes=1024*wf, mid_planes=256*wf)) for i in range(2, block_units[2] + 1)], 110 | )) 111 | self.conv5 = nn.Sequential(OrderedDict( 112 | [('unit01', PreActBottleneck(in_planes=1024*wf, out_planes=2048*wf, mid_planes=512*wf, stride=2))] + 113 | [(f'unit{i:02d}', PreActBottleneck(in_planes=2048*wf, out_planes=2048*wf, mid_planes=512*wf)) for i in range(2, block_units[3] + 1)], 114 | )) 115 | # pylint: enable=line-too-long 116 | 117 | self.head = nn.Sequential(OrderedDict([ 118 | ('gn', nn.GroupNorm(32, 2048*wf)), 119 | ('relu', nn.ReLU(inplace=True)), 120 | ('avg', nn.AdaptiveAvgPool2d(output_size=1)), 121 | ('conv', nn.Conv2d(2048*wf, head_size, kernel_size=1, bias=True)), 122 | ])) 123 | 124 | def forward(self, x, include_conv5=False, include_top=False): 125 | x = self.conv1(x) 126 | x = self.conv2(x) 127 | x = self.conv3(x) 128 | x = self.conv4(x) 129 | if include_conv5: 130 | x = self.conv5(x) 131 | if include_top: 132 | x = self.head(x) 133 | 134 | if include_top and include_conv5: 135 | assert x.shape[-2:] == (1,1,) 136 | return x[..., 0, 0] 137 | 138 | return x 139 | -------------------------------------------------------------------------------- /step_recognition/model/transformer_models/ViT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .decoder import Decoder, DecoderLayer 5 | from .attn import FullAttention, ProbAttention, AttentionLayer 6 | from .Transformer import TransformerModel 7 | from .PositionalEncoding import ( 8 | FixedPositionalEncoding, 9 | LearnedPositionalEncoding, 10 | ) 11 | from model.model_builder import 
META_ARCHITECTURES as registry 12 | 13 | FEATURE_SIZES = { 14 | 'rgb_anet_resnet50': 2048, 15 | 'flow_anet_resnet50': 2048, 16 | 'rgb_kinetics_bninception': 1024, 17 | 'flow_kinetics_bninception': 1024, 18 | 'rgb_kinetics_resnet50': 2048, 19 | 'flow_kinetics_resnet50': 2048, 20 | 'flow_nv_kinetics_bninception': 1024, 21 | 'rgb_kinetics_i3d': 2048, 22 | 'flow_kinetics_i3d': 2048 23 | } 24 | 25 | @registry.register('Transformer') 26 | class ViTEnc(nn.Module): 27 | def __init__(self, cfg, use_representation=True, 28 | conv_patch_representation=False, 29 | positional_encoding_type="learned" 30 | ): 31 | super(ViTEnc, self).__init__() 32 | self.img_dim = cfg["window_size"] 33 | self.out_dim = cfg["num_classes"] 34 | self.embedding_dim = cfg["embedding_dim"] 35 | self.patch_dim = cfg["patch_dim"] 36 | self.num_heads = cfg["num_heads"] 37 | self.num_layers = cfg["num_layers"] 38 | self.hidden_dim = cfg["hidden_dim"] 39 | self.dropout_rate = cfg["dropout"] 40 | self.use_flow = not cfg['no_flow'] 41 | self.use_rgb = not cfg['no_rgb'] 42 | self.num_channels= 0 43 | if self.use_rgb: 44 | self.num_channels += FEATURE_SIZES[cfg['rgb_type']] 45 | if self.use_flow: 46 | self.num_channels += FEATURE_SIZES[cfg['flow_type']] 47 | self.attn_dropout_rate = cfg["attn_dropout_rate"] 48 | assert self.embedding_dim % self.num_heads == 0 49 | assert self.img_dim % self.patch_dim == 0 50 | self.conv_patch_representation = conv_patch_representation 51 | self.num_patches = int(self.img_dim // self.patch_dim) 52 | self.seq_length = self.num_patches + 1 53 | # self.seq_length = self.num_patches + self.img_dim 54 | self.flatten_dim = self.patch_dim * self.patch_dim * self.num_channels 55 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embedding_dim)) 56 | # self.cls_token = nn.Parameter(torch.zeros(1, self.img_dim, self.embedding_dim)) 57 | 58 | self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim) 59 | if positional_encoding_type == "learned": 60 | self.position_encoding = LearnedPositionalEncoding( 61 | self.seq_length, self.embedding_dim, self.seq_length 62 | ) 63 | elif positional_encoding_type == "fixed": 64 | self.position_encoding = FixedPositionalEncoding( 65 | self.embedding_dim, 66 | ) 67 | print('position encoding :', positional_encoding_type) 68 | 69 | self.pe_dropout = nn.Dropout(p=self.dropout_rate) 70 | 71 | self.encoder = TransformerModel( 72 | self.embedding_dim, 73 | self.num_layers, 74 | self.num_heads, 75 | self.hidden_dim, 76 | self.dropout_rate, 77 | self.attn_dropout_rate, 78 | ) 79 | self.pre_head_ln = nn.LayerNorm(self.embedding_dim) 80 | 81 | use_representation = False # False 82 | if use_representation: 83 | self.mlp_head = nn.Sequential( 84 | nn.Linear(self.embedding_dim , self.hidden_dim//2), 85 | # nn.Tanh(), 86 | nn.ReLU(), 87 | nn.Linear(self.hidden_dim//2, self.out_dim), 88 | ) 89 | else: 90 | self.mlp_head = nn.Linear(self.embedding_dim, self.out_dim) 91 | 92 | if self.conv_patch_representation: 93 | # self.conv_x = nn.Conv2d( 94 | # self.num_channels, 95 | # self.embedding_dim, 96 | # kernel_size=(self.patch_dim, self.patch_dim), 97 | # stride=(self.patch_dim, self.patch_dim), 98 | # padding=self._get_padding( 99 | # 'VALID', (self.patch_dim, self.patch_dim), 100 | # ), 101 | # ) 102 | self.conv_x = nn.Conv1d( 103 | self.num_channels, 104 | self.embedding_dim, 105 | kernel_size=self.patch_dim, 106 | stride=self.patch_dim, 107 | padding=self._get_padding( 108 | 'VALID', (self.patch_dim), 109 | ), 110 | ) 111 | else: 112 | self.conv_x = None 113 | 114 | 
self.to_cls_token = nn.Identity() 115 | 116 | 117 | def forward(self, sequence_input_rgb, sequence_input_flow): 118 | if self.use_rgb and self.use_flow: 119 | x = torch.cat((sequence_input_rgb, sequence_input_flow), 2) 120 | elif self.use_rgb: 121 | x = sequence_input_rgb 122 | elif self.use_flow: 123 | x = sequence_input_flow 124 | 125 | x = self.linear_encoding(x) 126 | cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) # B, 1, 1024 127 | # x = torch.cat((cls_tokens, x), dim=1) 128 | x = torch.cat((x, cls_tokens), dim=1) # B, seq+1, 1024 129 | x = self.position_encoding(x) # B, seq+1, 1024 130 | x = self.pe_dropout(x) # not delete 131 | 132 | # apply transformer 133 | x = self.encoder(x) 134 | x = self.pre_head_ln(x) # B, seq+1, 1024 135 | 136 | x = self.to_cls_token(x[:, 0]) # B, 1024 137 | # x = self.to_cls_token(x[:,0:self.img_dim]) # B, 1024 138 | x = self.mlp_head(x) 139 | # x = F.log_softmax(x, dim=-1) 140 | out_dict = {} 141 | out_dict['logits'] = x.unsqueeze(1) 142 | # out_dict['logits'] = x #x.unsqueeze(1) 143 | return out_dict 144 | 145 | def _get_padding(self, padding_type, kernel_size): 146 | assert padding_type in ['SAME', 'VALID'] 147 | if padding_type == 'SAME': 148 | _list = [(k - 1) // 2 for k in kernel_size] 149 | return tuple(_list) 150 | return tuple(0 for _ in kernel_size) 151 | 152 | -------------------------------------------------------------------------------- /step_recognition/model/transformer_models/attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | from math import sqrt 8 | 9 | 10 | class TriangularCausalMask(): 11 | def __init__(self, B, L, device="cpu"): 12 | mask_shape = [B, 1, L, L] 13 | with torch.no_grad(): 14 | self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) 15 | 16 | @property 17 | def mask(self): 18 | return self._mask 19 | 20 | 21 | class ProbMask(): 22 | def __init__(self, B, H, L, index, scores, device="cpu"): 23 | _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1) 24 | _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1]) 25 | indicator = _mask_ex[torch.arange(B)[:, None, None], 26 | torch.arange(H)[None, :, None], 27 | index, :].to(device) 28 | self._mask = indicator.view(scores.shape).to(device) 29 | 30 | @property 31 | def mask(self): 32 | return self._mask 33 | 34 | 35 | class FullAttention(nn.Module): 36 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1): 37 | super(FullAttention, self).__init__() 38 | self.scale = scale 39 | self.mask_flag = mask_flag 40 | self.dropout = nn.Dropout(attention_dropout) 41 | 42 | def forward(self, queries, keys, values, attn_mask): 43 | B, L, H, E = queries.shape 44 | _, S, _, D = values.shape 45 | scale = self.scale or 1. 
/ sqrt(E) 46 | 47 | scores = torch.einsum("blhe,bshe->bhls", queries, keys) 48 | if self.mask_flag: 49 | if attn_mask is None: 50 | attn_mask = TriangularCausalMask(B, L, device=queries.device) 51 | 52 | scores.masked_fill_(attn_mask.mask, -np.inf) 53 | 54 | A = self.dropout(torch.softmax(scale * scores, dim=-1)) 55 | V = torch.einsum("bhls,bshd->blhd", A, values) 56 | 57 | return V.contiguous() 58 | 59 | 60 | class ProbAttention(nn.Module): 61 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1): 62 | super(ProbAttention, self).__init__() 63 | self.factor = factor 64 | self.scale = scale 65 | self.mask_flag = mask_flag 66 | self.dropout = nn.Dropout(attention_dropout) 67 | 68 | def _prob_QK(self, Q, K, sample_k, n_top): 69 | # Q [B, H, L, D] 70 | B, H, L, E = K.shape 71 | _, _, S, _ = Q.shape 72 | 73 | # calculate the sampled Q_K 74 | K_expand = K.unsqueeze(-3).expand(B, H, S, L, E) 75 | indx_sample = torch.randint(L, (S, sample_k)) 76 | K_sample = K_expand[:, :, torch.arange(S).unsqueeze(1), indx_sample, :] 77 | Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze() 78 | 79 | # find the Top_k query with sparisty measurement 80 | M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L) 81 | M_top = M.topk(n_top, sorted=False)[1] 82 | 83 | # use the reduced Q to calculate Q_K 84 | Q_reduce = Q[torch.arange(B)[:, None, None], 85 | torch.arange(H)[None, :, None], 86 | M_top, :] 87 | Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) 88 | 89 | return Q_K, M_top 90 | 91 | def _get_initial_context(self, V, L_Q): 92 | B, H, L_V, D = V.shape 93 | if not self.mask_flag: 94 | V_sum = V.sum(dim=-2) 95 | contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone() 96 | else: # use mask 97 | assert (L_Q == L_V) # requires that L_Q == L_V, i.e. for self-attention only 98 | contex = V.cumsum(dim=-1) 99 | return contex 100 | 101 | def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): 102 | B, H, L_V, D = V.shape 103 | 104 | if self.mask_flag: 105 | attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device) 106 | scores.masked_fill_(attn_mask.mask, -np.inf) 107 | 108 | attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores) 109 | 110 | context_in[torch.arange(B)[:, None, None], 111 | torch.arange(H)[None, :, None], 112 | index, :] = torch.matmul(attn, V) 113 | return context_in 114 | 115 | def forward(self, queries, keys, values, attn_mask): 116 | B, L, H, D = queries.shape 117 | _, S, _, _ = keys.shape 118 | 119 | queries = queries.view(B, H, L, -1) 120 | keys = keys.view(B, H, S, -1) 121 | values = values.view(B, H, S, -1) 122 | 123 | U = self.factor * np.ceil(np.log(S)).astype('int').item() 124 | u = self.factor * np.ceil(np.log(L)).astype('int').item() 125 | 126 | scores_top, index = self._prob_QK(queries, keys, u, U) 127 | # add scale factor 128 | scale = self.scale or 1. 
/ sqrt(D) 129 | if scale is not None: 130 | scores_top = scores_top * scale 131 | # get the context 132 | context = self._get_initial_context(values, L) 133 | # update the context with selected top_k queries 134 | context = self._update_context(context, values, scores_top, index, L, attn_mask) 135 | 136 | return context.contiguous() 137 | 138 | 139 | class AttentionLayer(nn.Module): 140 | def __init__(self, attention, d_model, n_heads, d_keys=None, 141 | d_values=None): 142 | super(AttentionLayer, self).__init__() 143 | 144 | d_keys = d_keys or (d_model // n_heads) 145 | d_values = d_values or (d_model // n_heads) 146 | 147 | self.inner_attention = attention 148 | self.query_projection = nn.Linear(d_model, d_keys * n_heads) 149 | self.key_projection = nn.Linear(d_model, d_keys * n_heads) 150 | self.value_projection = nn.Linear(d_model, d_values * n_heads) 151 | self.out_projection = nn.Linear(d_values * n_heads, d_model) 152 | self.n_heads = n_heads 153 | 154 | def forward(self, queries, keys, values, attn_mask): 155 | B, L, _ = queries.shape 156 | _, S, _ = keys.shape 157 | H = self.n_heads 158 | 159 | queries = self.query_projection(queries).view(B, L, H, -1) 160 | keys = self.key_projection(keys).view(B, S, H, -1) 161 | values = self.value_projection(values).view(B, S, H, -1) 162 | 163 | out = self.inner_attention( 164 | queries, 165 | keys, 166 | values, 167 | attn_mask 168 | ).view(B, L, -1) 169 | 170 | return self.out_projection(out) -------------------------------------------------------------------------------- /step_recognition/trainer/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from tqdm import tqdm 4 | import time 5 | from utils import thumos_postprocessing 6 | from utils import * 7 | import json 8 | from trainer.eval_builder import EVAL 9 | from utils import thumos_postprocessing, perframe_average_precision 10 | import pickle 11 | import numpy as np 12 | import os 13 | 14 | 15 | @EVAL.register("OAD") 16 | class Evaluate(nn.Module): 17 | 18 | def __init__(self, cfg): 19 | super(Evaluate, self).__init__() 20 | self.data_processing = ( 21 | thumos_postprocessing if "THUMOS" in cfg["data_name"] else None 22 | ) 23 | self.metric = cfg["metric"] 24 | self.eval_method = perframe_average_precision 25 | self.cfg = cfg 26 | self.all_class_names = json.load(open(cfg["video_list_path"]))[ 27 | cfg["data_name"].split("_")[0] 28 | ]["class_index"] 29 | 30 | def eval(self, model, dataloader, logger, device): 31 | model.eval() 32 | output = {} 33 | with torch.no_grad(): 34 | pred_scores, gt_targets = [], [] 35 | start = time.time() 36 | for rgb_input, flow_input, target, vid, start, end in tqdm( 37 | dataloader, desc="Evaluation:", leave=False 38 | ): 39 | rgb_input, flow_input, target = ( 40 | rgb_input.to(device), 41 | flow_input.to(device), 42 | target.to(device), 43 | ) 44 | out_dict = model(rgb_input, flow_input) 45 | pred_logit = out_dict["logits"] 46 | prob_val = pred_logit.squeeze().cpu().numpy() 47 | target_batch = target.squeeze().cpu().numpy() 48 | pred_scores += list(prob_val) 49 | gt_targets += list(target_batch) 50 | #! 
To save the prediction scores and ground truth 51 | if self.cfg['eval'] != None: 52 | video_name = vid[0] 53 | pred = np.argmax(prob_val, axis=1) 54 | gt = np.argmax(target_batch, axis=1) 55 | sample = {"pred": pred, "gt": gt} 56 | output[video_name] = sample 57 | 58 | # save output as json file 59 | if self.cfg['eval'] != None: 60 | os.makedirs("output_miniRoad", exist_ok=True) 61 | # save as json file 62 | for k, v in output.items(): 63 | output[k] = {"pred": v["pred"].tolist(), "gt": v["gt"].tolist()} 64 | with open("output_miniRoad/output_miniROAD.json", "w") as file: 65 | json.dump(output, file) 66 | 67 | 68 | end = time.time() 69 | num_frames = len(gt_targets) 70 | result = self.eval_method( 71 | pred_scores, 72 | gt_targets, 73 | self.all_class_names, 74 | self.data_processing, 75 | self.metric, 76 | ) 77 | time_taken = (end - start).item() 78 | logger.info( 79 | f"Processed {num_frames} frames in {time_taken:.1f} seconds ({num_frames / time_taken :.1f} FPS)" 80 | ) 81 | return result["mean_AP"] 82 | 83 | def forward(self, model, dataloader, logger, device): 84 | return self.eval(model, dataloader, logger, device) 85 | 86 | 87 | @EVAL.register("ANTICIPATION") 88 | class ANT_Evaluate(nn.Module): 89 | 90 | def __init__(self, cfg): 91 | super(ANT_Evaluate, self).__init__() 92 | data_name = cfg["data_name"].split("_")[0] 93 | self.data_processing = thumos_postprocessing if data_name == "THUMOS" else None 94 | self.metric = cfg["metric"] 95 | self.eval_method = perframe_average_precision 96 | self.all_class_names = json.load(open(cfg["video_list_path"]))[data_name][ 97 | "class_index" 98 | ] 99 | 100 | def eval(self, model, dataloader, logger): 101 | device = "cuda:0" 102 | model.eval() 103 | with torch.no_grad(): 104 | pred_scores, gt_targets, ant_pred_scores, ant_gt_targets = [], [], [], [] 105 | start = time.time() 106 | anticipation_mAPs = [] 107 | for rgb_input, flow_input, target, ant_target in tqdm( 108 | dataloader, desc="Evaluation:", leave=False 109 | ): 110 | rgb_input, flow_input, target, ant_target = ( 111 | rgb_input.to(device), 112 | flow_input.to(device), 113 | target.to(device), 114 | ant_target.to(device), 115 | ) 116 | out_dict = model(rgb_input, flow_input) 117 | pred_logit = out_dict["logits"] 118 | ant_pred_logit = out_dict["anticipation_logits"] 119 | prob_val = pred_logit.squeeze().cpu().numpy() 120 | target_batch = target.squeeze().cpu().numpy() 121 | ant_prob_val = ant_pred_logit.squeeze().cpu().numpy() 122 | ant_target_batch = ant_target.squeeze().cpu().numpy() 123 | pred_scores += list(prob_val) 124 | gt_targets += list(target_batch) 125 | ant_pred_scores += list(ant_prob_val) 126 | ant_gt_targets += list(ant_target_batch) 127 | end = time.time() 128 | num_frames = len(gt_targets) 129 | result = self.eval_method( 130 | pred_scores, 131 | gt_targets, 132 | self.all_class_names, 133 | self.data_processing, 134 | self.metric, 135 | ) 136 | ant_pred_scores = np.array(ant_pred_scores) 137 | ant_gt_targets = np.array(ant_gt_targets) 138 | logger.info(f'OAD mAP: {result["mean_AP"]*100:.2f}') 139 | for step in range(ant_gt_targets.shape[1]): 140 | result[f"anticipation_{step+1}"] = self.eval_method( 141 | ant_pred_scores[:, step, :], 142 | ant_gt_targets[:, step, :], 143 | self.all_class_names, 144 | self.data_processing, 145 | self.metric, 146 | ) 147 | anticipation_mAPs.append(result[f"anticipation_{step+1}"]["mean_AP"]) 148 | logger.info( 149 | f"Anticipation at step {step+1}: {result[f'anticipation_{step+1}']['mean_AP']*100:.2f}" 150 | ) 151 | logger.info(f"Mean 
Anticipation mAP: {np.mean(anticipation_mAPs)*100:.2f}") 152 | 153 | time_taken = end - start 154 | logger.info( 155 | f"Processed {num_frames} frames in {time_taken:.1f} seconds ({num_frames / time_taken :.1f} FPS)" 156 | ) 157 | 158 | return np.mean(anticipation_mAPs) 159 | 160 | def forward(self, model, dataloader, logger): 161 | return self.eval(model, dataloader, logger) 162 | -------------------------------------------------------------------------------- /step_anticipation/src/data/assembly_text.py: -------------------------------------------------------------------------------- 1 | import re 2 | from pathlib import Path 3 | from typing import List, Union 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from torch.utils.data import DataLoader, Dataset 9 | from transformers import ( 10 | AutoModelForMaskedLM, 11 | AutoTokenizer, 12 | GPT2LMHeadModel, 13 | GPT2Tokenizer, 14 | OpenAIGPTModel, 15 | OpenAIGPTTokenizer, 16 | pipeline, 17 | set_seed, 18 | ) 19 | 20 | from src.utils.variables import CORRECT, WRONG 21 | 22 | 23 | class AssemblyTextDataset(Dataset): 24 | def __init__(self, path: str, split: str = "train") -> None: 25 | self.dir = Path(path) 26 | self.word2idx = {} 27 | self.idx2word = [ 28 | "[CLS]", 29 | "[SEP]", 30 | "[MASK]", 31 | "[PAD]", 32 | "[UNK]", 33 | ] 34 | 35 | if split == "train": 36 | files = CORRECT 37 | elif split == "test": 38 | files = WRONG 39 | else: 40 | raise ValueError("split should be either train or test") 41 | 42 | self.files = [self.dir / f for f in files] 43 | self.__init_data__() 44 | 45 | def __init_data__(self) -> None: 46 | self.data = {} 47 | for f in self.files: 48 | df = pd.read_csv(f) 49 | self.data[f.stem] = ( 50 | df[["verb", "this", "that"]] 51 | .apply( 52 | # ! Changed " " to "-" 53 | lambda x: "-".join(map(lambda y: y.replace(" ", ""), x)).strip(), 54 | axis=1, 55 | ) 56 | .tolist() 57 | ) 58 | 59 | # # * Hexadecimal encoding 60 | # # iterate over the rows of the dataframe i.e. actions, each one is a triple of words 61 | # val = [] 62 | # for action in df[["verb", "this", "that"]].values: 63 | # count = 0 64 | # # iterate over the words of the action 65 | # for word in action: 66 | # # 1. each word is a string, iterate over the characters 67 | # # 2. get the sum of their ascii values 68 | # # 3. convert them to hex 69 | # # 4. 
append them to a list 70 | # count += sum([int(ord(c)) for c in list(word)]) 71 | # val.append(hex(count).split("x")[-1]) 72 | # self.data[f.stem] = val 73 | 74 | # def get_vocab(self) -> list: 75 | # for f in self.files: 76 | # df = pd.read_csv(f) 77 | # words = df[["verb", "this", "that"]].values.flatten() 78 | # for word in words: 79 | # word = word.replace(" ", "") 80 | # if word not in self.idx2word: 81 | # self.idx2word.append(word) 82 | 83 | # for i, word in enumerate(self.idx2word): 84 | # self.word2idx[word] = i 85 | 86 | # def get_vocab_action(self): 87 | # vocab = [] 88 | # for f in self.files: 89 | # df = pd.read_csv(f) 90 | # actions = df[["verb", "this", "that"]] 91 | # actions = actions.apply(lambda x: " ".join(x).strip(), axis=1).tolist() 92 | # vocab.extend(actions) 93 | # return list(set(vocab)) 94 | 95 | def __len__(self): 96 | return len(self.data) 97 | 98 | def __getitem__(self, index) -> list: 99 | key = self.files[index].stem 100 | val = self.data[key] 101 | return val 102 | 103 | def collate_fn(self, batch) -> dict: 104 | min_n = min([len(x) for x in batch]) - 1 105 | # get a random number between 0 and min_len - 1 106 | n = np.random.randint(1, min_n) 107 | hist, gt = [], [] 108 | for x in batch: 109 | hist.append(x[:n]) 110 | gt.append(x[n]) 111 | 112 | out = {"hist": hist, "gt": gt} 113 | return out 114 | 115 | # ! Done for Assembly101 and not our version 116 | # def __getitem__(self, index) -> dict: 117 | # id = self.files[index].split("_", 1)[-1].split(".")[0] 118 | # return self.data[id] 119 | 120 | 121 | if __name__ == "__main__": 122 | # * Dataset 123 | dataset = AssemblyTextDataset("data/mistake_labels", split="test") 124 | vocab = dataset.get_vocab() 125 | vocab = dataset.get_vocab_action() 126 | 127 | # * Model 128 | model_checkpoint = "distilbert-base-uncased" 129 | model = AutoModelForMaskedLM.from_pretrained(model_checkpoint) 130 | tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) 131 | 132 | # * Dataloader 133 | # TODO don't know why but batch size higher than 2 gives problem to tokenization 134 | dl = DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn) 135 | for batch in dl: 136 | B = len(batch) 137 | hist = batch["hist"] 138 | gt = batch["gt"] 139 | 140 | # * BERT 141 | texts = [] 142 | for entry in hist: 143 | text = " ".join(entry) 144 | text = text + 3 * " [MASK]" 145 | texts.append(text) 146 | 147 | # TODO maybe we should make it autoregressive to avoid predicting always the same word 148 | inputs = tokenizer(texts, return_tensors="pt") 149 | token_logits = model(**inputs).logits 150 | # Find the location of [MASK] and extract its logits 151 | mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id) 152 | 153 | mask_token_logits = token_logits[mask_token_index[0], mask_token_index[1], :] 154 | N, D = mask_token_logits.shape 155 | mask_token_logits = mask_token_logits.reshape(B, N // B, D) 156 | # Pick the [MASK] candidates with the highest logits 157 | top_k_tokens = torch.topk(mask_token_logits, 1, dim=-1).indices.transpose(1, 2) 158 | 159 | for i, (text, k_tokens) in enumerate(zip(texts, top_k_tokens)): 160 | print(f"{i}: text") 161 | for tokens in k_tokens: 162 | print( 163 | f'{i}: {text.replace("[MASK] [MASK] [MASK]", tokenizer.decode(tokens))}' 164 | ) 165 | 166 | # TODO get only the predicted tokens and compare them to the gt 167 | # make_autoregressive(texts, n=3, model=model, tokenizer=tokenizer) 168 | 169 | # * GPT2 170 | # texts = [" ".join(entry) for entry in hist] 171 | 172 | # tokenizer 
= GPT2Tokenizer.from_pretrained("gpt2") 173 | # model = GPT2LMHeadModel.from_pretrained("gpt2") 174 | 175 | # prompt = texts[0] 176 | # print(prompt) 177 | # input_ids = tokenizer.encode(prompt, return_tensors="pt") 178 | 179 | # attention_mask = torch.ones(input_ids.shape, dtype=torch.long) 180 | # pad_token_id = tokenizer.eos_token_id 181 | 182 | # sample_output = model.generate( 183 | # input_ids, 184 | # attention_mask=attention_mask, 185 | # pad_token_id=pad_token_id, 186 | # do_sample=True, 187 | # max_length=len(prompt.split()) + 3, 188 | # top_k=len(prompt.split()), 189 | # top_p=0.95, 190 | # num_return_sequences=10, 191 | # ) 192 | 193 | # for i, sample in enumerate(sample_output): 194 | # generated_text = tokenizer.decode(sample, skip_special_tokens=True) 195 | # generated_words = generated_text.split()[-3:] 196 | # print(f"Variant {i+1}: {' '.join(generated_words)}") 197 | 198 | break 199 | -------------------------------------------------------------------------------- /step_recognition/model/transformer_models/HybridViT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .BiT import ResNetV2Model 5 | from .AxialNet import AxialAttentionNet 6 | from .Transformer import TransformerModel 7 | from .PositionalEncoding import ( 8 | FixedPositionalEncoding, 9 | LearnedPositionalEncoding, 10 | ) 11 | 12 | 13 | class HybridVisionTransformer(nn.Module): 14 | def __init__( 15 | self, 16 | img_dim, 17 | out_dim, 18 | num_channels, 19 | embedding_dim, 20 | num_heads, 21 | num_layers, 22 | hidden_dim, 23 | include_conv5, 24 | dropout_rate, 25 | positional_encoding_type, 26 | backbone=None, 27 | ): 28 | super(HybridVisionTransformer, self).__init__() 29 | 30 | assert embedding_dim % num_heads == 0 31 | 32 | self.embedding_dim = embedding_dim 33 | self.num_heads = num_heads 34 | self.out_dim = out_dim 35 | self.num_channels = num_channels 36 | self.include_conv5 = include_conv5 37 | self.backbone = backbone 38 | 39 | self.backbone_model, self.flatten_dim = self.configure_backbone() 40 | 41 | self.projection_encoding = nn.Linear(self.flatten_dim, embedding_dim) 42 | self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim)) 43 | 44 | self.decoder_dim = int(img_dim / 16.0) ** 2 45 | if self.include_conv5: 46 | self.decoder_dim = int(img_dim / 32.0) ** 2 47 | 48 | self.decoder_dim += 1 # for the cls token 49 | 50 | if positional_encoding_type == "learned": 51 | self.position_encoding = LearnedPositionalEncoding( 52 | self.decoder_dim, self.embedding_dim, self.decoder_dim 53 | ) 54 | elif positional_encoding_type == "fixed": 55 | self.position_encoding = FixedPositionalEncoding( 56 | self.embedding_dim, 57 | ) 58 | 59 | self.transformer = TransformerModel( 60 | embedding_dim, num_layers, num_heads, hidden_dim 61 | ) 62 | self.mlp_head = nn.Sequential( 63 | nn.Linear(embedding_dim, hidden_dim), 64 | nn.GELU(), 65 | nn.Linear(hidden_dim, out_dim), 66 | ) 67 | self.to_cls_token = nn.Identity() 68 | 69 | def forward(self, x): 70 | # apply bit backbone 71 | x = self.backbone_model(x, include_conv5=self.include_conv5) 72 | x = x.view(x.size(0), -1, self.flatten_dim) 73 | 74 | x = self.projection_encoding(x) 75 | cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) 76 | x = torch.cat((cls_tokens, x), dim=1) 77 | x = self.position_encoding(x) 78 | 79 | # apply transformer 80 | x = self.transformer(x) 81 | x = self.to_cls_token(x[:, 0]) 82 | x = self.mlp_head(x) 83 | x = 
F.log_softmax(x, dim=-1) 84 | 85 | return x 86 | 87 | def configure_backbone(self): 88 | raise NotImplementedError("Method to be called in child class!!") 89 | 90 | 91 | class ResNetHybridViT(HybridVisionTransformer): 92 | def __init__( 93 | self, 94 | img_dim, 95 | out_dim, 96 | num_channels, 97 | embedding_dim, 98 | num_heads, 99 | num_layers, 100 | hidden_dim, 101 | include_conv5=False, 102 | dropout_rate=0.1, 103 | positional_encoding_type="learned", 104 | backbone='r50x1', 105 | ): 106 | super(ResNetHybridViT, self).__init__( 107 | img_dim=img_dim, 108 | out_dim=out_dim, 109 | num_channels=num_channels, 110 | embedding_dim=embedding_dim, 111 | num_heads=num_heads, 112 | num_layers=num_layers, 113 | hidden_dim=hidden_dim, 114 | include_conv5=include_conv5, 115 | dropout_rate=dropout_rate, 116 | positional_encoding_type=positional_encoding_type, 117 | backbone=backbone, 118 | ) 119 | 120 | def configure_backbone(self): 121 | """ 122 | Current support offered for all BiT models 123 | KNOWN_MODELS in https://github.com/google-research/big_transfer/blob/master/bit_pytorch/models.py 124 | 125 | expects model name of style 'r{depth}x{width}' 126 | where depth in [50, 101, 152] 127 | where width in [1,2,3,4] 128 | """ 129 | backbone = self.backbone 130 | out_dim = self.out_dim 131 | 132 | splits = backbone.split('x') 133 | model_name = splits[0] 134 | width_factor = int(splits[1]) 135 | 136 | if model_name in ['r50', 'r101'] and width_factor in [2, 4]: 137 | return ValueError( 138 | "Invalid Configuration of models -- expect 50x1, 50x3, 101x1, 101x3" 139 | ) 140 | elif model_name == 'r152' and width_factor in [1, 3]: 141 | return ValueError( 142 | "Invalid Configuration of models -- expect 152x2, 152x4" 143 | ) 144 | 145 | block_units_dict = { 146 | 'r50': [3, 4, 6, 3], 147 | 'r101': [3, 4, 23, 3], 148 | 'r152': [3, 8, 36, 3], 149 | } 150 | block_units = block_units_dict.get(model_name, [3, 4, 6, 3]) 151 | model = ResNetV2Model(block_units, width_factor, head_size=out_dim) 152 | 153 | if self.num_channels == 3: 154 | flatten_dim = 1024 * width_factor 155 | if self.include_conv5: 156 | flatten_dim *= 2 157 | 158 | return model, flatten_dim 159 | 160 | 161 | class AxialNetHybridViT(HybridVisionTransformer): 162 | def __init__( 163 | self, 164 | img_dim, 165 | out_dim, 166 | num_channels, 167 | embedding_dim, 168 | num_heads, 169 | num_layers, 170 | hidden_dim, 171 | include_conv5=False, 172 | dropout_rate=0.1, 173 | positional_encoding_type="learned", 174 | backbone='a50m', 175 | ): 176 | super(AxialNetHybridViT, self).__init__( 177 | img_dim=img_dim, 178 | out_dim=out_dim, 179 | num_channels=num_channels, 180 | embedding_dim=embedding_dim, 181 | num_heads=num_heads, 182 | num_layers=num_layers, 183 | hidden_dim=hidden_dim, 184 | include_conv5=include_conv5, 185 | dropout_rate=dropout_rate, 186 | positional_encoding_type=positional_encoding_type, 187 | backbone=backbone, 188 | ) 189 | 190 | def configure_backbone(self): 191 | """ 192 | Current support offered for all BiT models 193 | models from https://github.com/csrhddlam/axial-deeplab/blob/master/lib/models/axialnet.py 194 | 195 | expects model name of style 'a{depth}{width}' 196 | where depth in [26, 50, 101] 197 | where width in [s, m, l] 198 | """ 199 | backbone = self.backbone 200 | out_dim = self.out_dim 201 | 202 | model_name = backbone[:3] 203 | width = backbone[-1] 204 | 205 | block_units_dict = { 206 | 'a26': [1, 2, 4, 1], 207 | 'a50': [3, 4, 6, 3], 208 | 'a101': [3, 4, 23, 3], 209 | } 210 | block_units = 
block_units_dict.get(model_name, [3, 4, 6, 3]) 211 | 212 | scale_factor_dict = {'s': 0.5, 'm': 0.75, 'l': 1.0} 213 | scale_factor = scale_factor_dict.get(width, 0.75) 214 | model = AxialAttentionNet( 215 | block_units, s=scale_factor, num_classes=out_dim 216 | ) 217 | 218 | if self.num_channels == 3: 219 | flatten_dim = int(512 * float(scale_factor / 0.5)) 220 | if self.include_conv5: 221 | flatten_dim *= 2 222 | 223 | return model, flatten_dim 224 | -------------------------------------------------------------------------------- /step_anticipation/data/predictions/output_miniROAD_Epic-tent-O.json: -------------------------------------------------------------------------------- 1 | { 2 | "annotations_1": { 3 | "pred": [ 4 | 4, 5 | 9, 6 | 4, 7 | 9, 8 | 4, 9 | 9, 10 | 4, 11 | 9, 12 | 4, 13 | 9, 14 | 4, 15 | 9, 16 | 4, 17 | 9, 18 | 4, 19 | 9, 20 | 4, 21 | 9, 22 | 4 23 | ], 24 | "gt": [ 25 | 4, 26 | 10, 27 | 0, 28 | 4, 29 | 2, 30 | 3, 31 | 0, 32 | 2, 33 | 3 34 | ] 35 | }, 36 | "annotations_2": { 37 | "pred": [ 38 | 4, 39 | 9, 40 | 4, 41 | 9, 42 | 4, 43 | 9 44 | ], 45 | "gt": [ 46 | 4, 47 | 7, 48 | 10, 49 | 6, 50 | 10, 51 | 6, 52 | 0, 53 | 10, 54 | 2, 55 | 3 56 | ] 57 | }, 58 | "annotations_3": { 59 | "pred": [ 60 | 9, 61 | 4, 62 | 9, 63 | 4, 64 | 9, 65 | 4, 66 | 9, 67 | 4, 68 | 9, 69 | 4, 70 | 9, 71 | 4, 72 | 9, 73 | 4, 74 | 9, 75 | 8, 76 | 9, 77 | 8, 78 | 9, 79 | 4, 80 | 9, 81 | 4, 82 | 9, 83 | 4, 84 | 9 85 | ], 86 | "gt": [ 87 | 4, 88 | 7, 89 | 10, 90 | 4, 91 | 7, 92 | 6, 93 | 4, 94 | 0, 95 | 4, 96 | 10, 97 | 2, 98 | 0, 99 | 2, 100 | 4, 101 | 3, 102 | 4, 103 | 3, 104 | 5 105 | ] 106 | }, 107 | "annotations_4": { 108 | "pred": [ 109 | 4, 110 | 9, 111 | 4, 112 | 9, 113 | 4, 114 | 9, 115 | 4, 116 | 9, 117 | 4, 118 | 9, 119 | 4, 120 | 9 121 | ], 122 | "gt": [ 123 | 4, 124 | 7, 125 | 10, 126 | 6, 127 | 4, 128 | 0, 129 | 2, 130 | 3 131 | ] 132 | }, 133 | "annotations_5": { 134 | "pred": [ 135 | 9, 136 | 4, 137 | 9, 138 | 4, 139 | 9, 140 | 4, 141 | 9 142 | ], 143 | "gt": [ 144 | 4, 145 | 7, 146 | 4, 147 | 10, 148 | 4, 149 | 6, 150 | 5, 151 | 4, 152 | 0, 153 | 4, 154 | 0, 155 | 4, 156 | 2, 157 | 3 158 | ] 159 | }, 160 | "annotations_6": { 161 | "pred": [ 162 | 9, 163 | 6, 164 | 9, 165 | 8, 166 | 9, 167 | 8, 168 | 4, 169 | 9, 170 | 4 171 | ], 172 | "gt": [ 173 | 10, 174 | 5, 175 | 10, 176 | 6, 177 | 0, 178 | 2, 179 | 0, 180 | 2, 181 | 0, 182 | 3, 183 | 8 184 | ] 185 | }, 186 | "annotations_7": { 187 | "pred": [ 188 | 9, 189 | 4, 190 | 9, 191 | 4, 192 | 9 193 | ], 194 | "gt": [ 195 | 4, 196 | 7, 197 | 6, 198 | 0 199 | ] 200 | }, 201 | "annotations_10": { 202 | "pred": [ 203 | 4, 204 | 9, 205 | 4, 206 | 9, 207 | 4, 208 | 9, 209 | 4, 210 | 9, 211 | 8, 212 | 9, 213 | 4, 214 | 9, 215 | 8, 216 | 9, 217 | 4, 218 | 9, 219 | 4, 220 | 9, 221 | 6, 222 | 4 223 | ], 224 | "gt": [ 225 | 4, 226 | 7, 227 | 4, 228 | 10, 229 | 4, 230 | 6, 231 | 4, 232 | 5, 233 | 0, 234 | 4, 235 | 2, 236 | 4, 237 | 5, 238 | 1 239 | ] 240 | }, 241 | "annotations_16": { 242 | "pred": [ 243 | 9, 244 | 4, 245 | 9, 246 | 4, 247 | 9, 248 | 4, 249 | 9, 250 | 4, 251 | 9, 252 | 4, 253 | 9 254 | ], 255 | "gt": [ 256 | 4, 257 | 7, 258 | 4, 259 | 10, 260 | 4, 261 | 6, 262 | 0, 263 | 4, 264 | 0, 265 | 4, 266 | 2, 267 | 10, 268 | 2, 269 | 4, 270 | 10, 271 | 3 272 | ] 273 | }, 274 | "annotations_17": { 275 | "pred": [ 276 | 9, 277 | 4, 278 | 9, 279 | 8, 280 | 9 281 | ], 282 | "gt": [ 283 | 4, 284 | 7, 285 | 10, 286 | 6, 287 | 8, 288 | 6, 289 | 10, 290 | 0, 291 | 2 292 | ] 293 | }, 294 | "annotations_18": { 295 | "pred": [ 296 | 9, 297 | 4, 298 | 9, 
299 | 8, 300 | 9 301 | ], 302 | "gt": [ 303 | 4, 304 | 7, 305 | 10, 306 | 6, 307 | 0, 308 | 10, 309 | 2, 310 | 3, 311 | 5 312 | ] 313 | }, 314 | "annotations_20": { 315 | "pred": [ 316 | 9, 317 | 4, 318 | 9, 319 | 8, 320 | 4, 321 | 8, 322 | 9, 323 | 4, 324 | 9, 325 | 4, 326 | 8 327 | ], 328 | "gt": [ 329 | 4, 330 | 7, 331 | 10, 332 | 4, 333 | 6, 334 | 4, 335 | 0, 336 | 4, 337 | 2, 338 | 4, 339 | 2, 340 | 4, 341 | 11 342 | ] 343 | }, 344 | "annotations_23": { 345 | "pred": [ 346 | 9, 347 | 4, 348 | 9 349 | ], 350 | "gt": [ 351 | 4, 352 | 7, 353 | 10 354 | ] 355 | }, 356 | "annotations_25": { 357 | "pred": [ 358 | 9, 359 | 4, 360 | 9, 361 | 4, 362 | 9 363 | ], 364 | "gt": [ 365 | 4, 366 | 7, 367 | 4, 368 | 7, 369 | 4, 370 | 10, 371 | 6 372 | ] 373 | }, 374 | "annotations_28": { 375 | "pred": [ 376 | 9, 377 | 4, 378 | 8, 379 | 4, 380 | 9, 381 | 4, 382 | 9, 383 | 4, 384 | 8, 385 | 4, 386 | 8, 387 | 4, 388 | 8, 389 | 9, 390 | 8, 391 | 4, 392 | 9, 393 | 4, 394 | 9 395 | ], 396 | "gt": [ 397 | 4, 398 | 7, 399 | 4, 400 | 10, 401 | 4, 402 | 10, 403 | 8, 404 | 4, 405 | 6, 406 | 5, 407 | 4, 408 | 10, 409 | 6, 410 | 10, 411 | 9, 412 | 10, 413 | 9, 414 | 4, 415 | 6, 416 | 4, 417 | 0, 418 | 4, 419 | 0 420 | ] 421 | } 422 | } -------------------------------------------------------------------------------- /step_recognition/model/transformer_models/AxialNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .Attention import AxialAttention 4 | 5 | 6 | def conv1x1(in_planes, out_planes, stride=1): 7 | """1x1 convolution""" 8 | return nn.Conv2d( 9 | in_planes, out_planes, kernel_size=1, stride=stride, bias=False 10 | ) 11 | 12 | 13 | class AxialBlock(nn.Module): 14 | expansion = 2 15 | 16 | def __init__( 17 | self, 18 | inplanes, 19 | planes, 20 | stride=1, 21 | downsample=None, 22 | groups=1, 23 | base_width=64, 24 | dilation=1, 25 | norm_layer=None, 26 | kernel_size=56, 27 | ): 28 | super(AxialBlock, self).__init__() 29 | if norm_layer is None: 30 | norm_layer = nn.BatchNorm2d 31 | width = int(planes * (base_width / 64.0)) 32 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 33 | self.conv_down = conv1x1(inplanes, width) 34 | self.bn1 = norm_layer(width) 35 | self.hight_block = AxialAttention( 36 | width, width, groups=groups, kernel_size=kernel_size 37 | ) 38 | self.width_block = AxialAttention( 39 | width, 40 | width, 41 | groups=groups, 42 | kernel_size=kernel_size, 43 | stride=stride, 44 | width=True, 45 | ) 46 | self.conv_up = conv1x1(width, planes * self.expansion) 47 | self.bn2 = norm_layer(planes * self.expansion) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.downsample = downsample 50 | self.stride = stride 51 | 52 | def forward(self, x): 53 | identity = x 54 | 55 | out = self.conv_down(x) 56 | out = self.bn1(out) 57 | out = self.relu(out) 58 | 59 | out = self.hight_block(out) 60 | out = self.width_block(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv_up(out) 64 | out = self.bn2(out) 65 | 66 | if self.downsample is not None: 67 | identity = self.downsample(x) 68 | 69 | out += identity 70 | out = self.relu(out) 71 | 72 | return out 73 | 74 | 75 | class AxialAttentionNet(nn.Module): 76 | def __init__( 77 | self, 78 | layers, 79 | num_classes=1000, 80 | zero_init_residual=True, 81 | groups=8, 82 | width_per_group=64, 83 | replace_stride_with_dilation=None, 84 | norm_layer=None, 85 | s=0.5, 86 | ): 87 | super(AxialAttentionNet, self).__init__() 88 | block = AxialBlock 
89 | if norm_layer is None: 90 | norm_layer = nn.BatchNorm2d 91 | self._norm_layer = norm_layer 92 | 93 | self.inplanes = int(64 * s) 94 | self.dilation = 1 95 | if replace_stride_with_dilation is None: 96 | # each element in the tuple indicates if we should replace 97 | # the 2x2 stride with a dilated convolution instead 98 | replace_stride_with_dilation = [False, False, False] 99 | if len(replace_stride_with_dilation) != 3: 100 | raise ValueError( 101 | "replace_stride_with_dilation should be None " 102 | "or a 3-element tuple, got {}".format( 103 | replace_stride_with_dilation 104 | ) 105 | ) 106 | self.groups = groups 107 | self.base_width = width_per_group 108 | self.conv1 = nn.Conv2d( 109 | 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False 110 | ) 111 | self.bn1 = norm_layer(self.inplanes) 112 | self.relu = nn.ReLU(inplace=True) 113 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 114 | self.layer1 = self._make_layer( 115 | block, int(128 * s), layers[0], kernel_size=56 116 | ) 117 | self.layer2 = self._make_layer( 118 | block, 119 | int(256 * s), 120 | layers[1], 121 | stride=2, 122 | kernel_size=56, 123 | dilate=replace_stride_with_dilation[0], 124 | ) 125 | self.layer3 = self._make_layer( 126 | block, 127 | int(512 * s), 128 | layers[2], 129 | stride=2, 130 | kernel_size=28, 131 | dilate=replace_stride_with_dilation[1], 132 | ) 133 | self.layer4 = self._make_layer( 134 | block, 135 | int(1024 * s), 136 | layers[3], 137 | stride=2, 138 | kernel_size=14, 139 | dilate=replace_stride_with_dilation[2], 140 | ) 141 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 142 | self.fc = nn.Linear(int(1024 * block.expansion * s), num_classes) 143 | 144 | for m in self.modules(): 145 | if isinstance(m, (nn.Conv2d, nn.Conv1d)): 146 | if isinstance(m, nn.Conv1d): 147 | pass 148 | else: 149 | nn.init.kaiming_normal_( 150 | m.weight, mode='fan_out', nonlinearity='relu' 151 | ) 152 | elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d, nn.GroupNorm)): 153 | nn.init.constant_(m.weight, 1) 154 | nn.init.constant_(m.bias, 0) 155 | 156 | # Zero-initialize the last BN in each residual branch, 157 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
158 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 159 | if zero_init_residual: 160 | for m in self.modules(): 161 | if isinstance(m, AxialBlock): 162 | nn.init.constant_(m.bn2.weight, 0) 163 | 164 | def _make_layer( 165 | self, block, planes, blocks, kernel_size=56, stride=1, dilate=False 166 | ): 167 | norm_layer = self._norm_layer 168 | downsample = None 169 | previous_dilation = self.dilation 170 | if dilate: 171 | self.dilation *= stride 172 | stride = 1 173 | if stride != 1 or self.inplanes != planes * block.expansion: 174 | downsample = nn.Sequential( 175 | conv1x1(self.inplanes, planes * block.expansion, stride), 176 | norm_layer(planes * block.expansion), 177 | ) 178 | 179 | layers = [] 180 | layers.append( 181 | block( 182 | self.inplanes, 183 | planes, 184 | stride, 185 | downsample, 186 | groups=self.groups, 187 | base_width=self.base_width, 188 | dilation=previous_dilation, 189 | norm_layer=norm_layer, 190 | kernel_size=kernel_size, 191 | ) 192 | ) 193 | self.inplanes = planes * block.expansion 194 | if stride != 1: 195 | kernel_size = kernel_size // 2 196 | 197 | for _ in range(1, blocks): 198 | layers.append( 199 | block( 200 | self.inplanes, 201 | planes, 202 | groups=self.groups, 203 | base_width=self.base_width, 204 | dilation=self.dilation, 205 | norm_layer=norm_layer, 206 | kernel_size=kernel_size, 207 | ) 208 | ) 209 | 210 | return nn.Sequential(*layers) 211 | 212 | def _forward_impl(self, x, include_conv5=False, include_top=False): 213 | # See note [TorchScript super()] 214 | x = self.conv1(x) 215 | x = self.bn1(x) 216 | x = self.relu(x) 217 | x = self.maxpool(x) 218 | 219 | x = self.layer1(x) 220 | x = self.layer2(x) 221 | x = self.layer3(x) 222 | if include_conv5: 223 | x = self.layer4(x) 224 | 225 | if include_top: 226 | x = self.avgpool(x) 227 | x = torch.flatten(x, 1) 228 | x = self.fc(x) 229 | 230 | return x 231 | 232 | def forward(self, x, include_conv5=False, include_top=False): 233 | return self._forward_impl(x, include_conv5, include_top) 234 | 235 | 236 | if __name__ == "__main__": 237 | model = AxialAttentionNet([1, 2, 4, 1], s=0.5) 238 | # model = AxialAttentionNet([3, 4, 6, 3], s=0.25) 239 | # print(model) 240 | 241 | x = torch.randn([8, 3, 224, 224]) 242 | print(model(x).shape) 243 | print(model(x, include_conv5=True).shape) 244 | print(model(x, include_conv5=True, include_top=True).shape) 245 | -------------------------------------------------------------------------------- /step_anticipation/data/idx2emoji.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": { 3 | "utf-8": "\\u2744\\ufe0f", 4 | "escape": "\u2744\ufe0f" 5 | }, 6 | "1": { 7 | "utf-8": "\\U0001f3da\\ufe0f", 8 | "escape": "\ud83c\udfda\ufe0f" 9 | }, 10 | "2": { 11 | "utf-8": "\\U0001f6cc\\U0001f3fb", 12 | "escape": "\ud83d\udecc\ud83c\udffb" 13 | }, 14 | "3": { 15 | "utf-8": "\\U0001f693", 16 | "escape": "\ud83d\ude93" 17 | }, 18 | "4": { 19 | "utf-8": "\\U0001f998", 20 | "escape": "\ud83e\udd98" 21 | }, 22 | "5": { 23 | "utf-8": "\\U0001f6d5", 24 | "escape": "\ud83d\uded5" 25 | }, 26 | "6": { 27 | "utf-8": "\\U0001f4e1", 28 | "escape": "\ud83d\udce1" 29 | }, 30 | "7": { 31 | "utf-8": "\\U0001faa4", 32 | "escape": "\ud83e\udea4" 33 | }, 34 | "8": { 35 | "utf-8": "\\U0001f332", 36 | "escape": "\ud83c\udf32" 37 | }, 38 | "9": { 39 | "utf-8": "\\U0001f90c\\U0001f3fd", 40 | "escape": "\ud83e\udd0c\ud83c\udffd" 41 | }, 42 | "10": { 43 | "utf-8": "\\U0001f378", 44 | "escape": "\ud83c\udf78" 45 | }, 46 | 
"11": { 47 | "utf-8": "\\U0001f31a", 48 | "escape": "\ud83c\udf1a" 49 | }, 50 | "12": { 51 | "utf-8": "\\U0001f9ce\\U0001f3fc", 52 | "escape": "\ud83e\uddce\ud83c\udffc" 53 | }, 54 | "13": { 55 | "utf-8": "\\U0001f349", 56 | "escape": "\ud83c\udf49" 57 | }, 58 | "14": { 59 | "utf-8": "\\U0001f90f\\U0001f3ff", 60 | "escape": "\ud83e\udd0f\ud83c\udfff" 61 | }, 62 | "15": { 63 | "utf-8": "\\U0001f450\\U0001f3fb", 64 | "escape": "\ud83d\udc50\ud83c\udffb" 65 | }, 66 | "16": { 67 | "utf-8": "\\U0001f446\\U0001f3fc", 68 | "escape": "\ud83d\udc46\ud83c\udffc" 69 | }, 70 | "17": { 71 | "utf-8": "\\U0001f61a", 72 | "escape": "\ud83d\ude1a" 73 | }, 74 | "18": { 75 | "utf-8": "\\U0001f68c", 76 | "escape": "\ud83d\ude8c" 77 | }, 78 | "19": { 79 | "utf-8": "\\U0001f487\\U0001f3fb", 80 | "escape": "\ud83d\udc87\ud83c\udffb" 81 | }, 82 | "20": { 83 | "utf-8": "\\U0001f6f5", 84 | "escape": "\ud83d\udef5" 85 | }, 86 | "21": { 87 | "utf-8": "\\U0001f32b\\ufe0f", 88 | "escape": "\ud83c\udf2b\ufe0f" 89 | }, 90 | "22": { 91 | "utf-8": "\\U0001f558", 92 | "escape": "\ud83d\udd58" 93 | }, 94 | "23": { 95 | "utf-8": "\\U0001f37a", 96 | "escape": "\ud83c\udf7a" 97 | }, 98 | "24": { 99 | "utf-8": "\\u23ef\\ufe0f", 100 | "escape": "\u23ef\ufe0f" 101 | }, 102 | "25": { 103 | "utf-8": "\\U0001f481\\U0001f3fd", 104 | "escape": "\ud83d\udc81\ud83c\udffd" 105 | }, 106 | "26": { 107 | "utf-8": "\\U0001f57a\\U0001f3fd", 108 | "escape": "\ud83d\udd7a\ud83c\udffd" 109 | }, 110 | "27": { 111 | "utf-8": "\\U0001f3cc\\U0001f3fc", 112 | "escape": "\ud83c\udfcc\ud83c\udffc" 113 | }, 114 | "28": { 115 | "utf-8": "\\U0001f977\\U0001f3fd", 116 | "escape": "\ud83e\udd77\ud83c\udffd" 117 | }, 118 | "29": { 119 | "utf-8": "\\U0001f368", 120 | "escape": "\ud83c\udf68" 121 | }, 122 | "30": { 123 | "utf-8": "\\u2650", 124 | "escape": "\u2650" 125 | }, 126 | "31": { 127 | "utf-8": "\\u2934\\ufe0f", 128 | "escape": "\u2934\ufe0f" 129 | }, 130 | "32": { 131 | "utf-8": "\\u2651", 132 | "escape": "\u2651" 133 | }, 134 | "33": { 135 | "utf-8": "\\U0001f9dd\\U0001f3ff", 136 | "escape": "\ud83e\udddd\ud83c\udfff" 137 | }, 138 | "34": { 139 | "utf-8": "\\U0001f5c4\\ufe0f", 140 | "escape": "\ud83d\uddc4\ufe0f" 141 | }, 142 | "35": { 143 | "utf-8": "\\U0001f449", 144 | "escape": "\ud83d\udc49" 145 | }, 146 | "36": { 147 | "utf-8": "\\U0001f3cc\\U0001f3fc", 148 | "escape": "\ud83c\udfcc\ud83c\udffc" 149 | }, 150 | "37": { 151 | "utf-8": "\\U0001f55c", 152 | "escape": "\ud83d\udd5c" 153 | }, 154 | "38": { 155 | "utf-8": "\\U0001f576\\ufe0f", 156 | "escape": "\ud83d\udd76\ufe0f" 157 | }, 158 | "39": { 159 | "utf-8": "\\U0001f9b5\\U0001f3ff", 160 | "escape": "\ud83e\uddb5\ud83c\udfff" 161 | }, 162 | "40": { 163 | "utf-8": "\\U0001fa82", 164 | "escape": "\ud83e\ude82" 165 | }, 166 | "41": { 167 | "utf-8": "\\U0001f358", 168 | "escape": "\ud83c\udf58" 169 | }, 170 | "42": { 171 | "utf-8": "\\U0001fa96", 172 | "escape": "\ud83e\ude96" 173 | }, 174 | "43": { 175 | "utf-8": "\\U0001f6ec", 176 | "escape": "\ud83d\udeec" 177 | }, 178 | "44": { 179 | "utf-8": "\\U0001f646", 180 | "escape": "\ud83d\ude46" 181 | }, 182 | "45": { 183 | "utf-8": "\\U0001f484", 184 | "escape": "\ud83d\udc84" 185 | }, 186 | "46": { 187 | "utf-8": "\\U0001f510", 188 | "escape": "\ud83d\udd10" 189 | }, 190 | "47": { 191 | "utf-8": "\\U0001f556", 192 | "escape": "\ud83d\udd56" 193 | }, 194 | "48": { 195 | "utf-8": "\\U0001f38e", 196 | "escape": "\ud83c\udf8e" 197 | }, 198 | "49": { 199 | "utf-8": "\\U0001f949", 200 | "escape": "\ud83e\udd49" 201 | }, 202 | "50": { 203 | "utf-8": 
"\\U0001f574\\U0001f3fe", 204 | "escape": "\ud83d\udd74\ud83c\udffe" 205 | }, 206 | "51": { 207 | "utf-8": "\\U0001f49a", 208 | "escape": "\ud83d\udc9a" 209 | }, 210 | "52": { 211 | "utf-8": "\\U0001f5dc\\ufe0f", 212 | "escape": "\ud83d\udddc\ufe0f" 213 | }, 214 | "53": { 215 | "utf-8": "\\U0001f466\\U0001f3fd", 216 | "escape": "\ud83d\udc66\ud83c\udffd" 217 | }, 218 | "54": { 219 | "utf-8": "\\U0001f45c", 220 | "escape": "\ud83d\udc5c" 221 | }, 222 | "55": { 223 | "utf-8": "\\u2697\\ufe0f", 224 | "escape": "\u2697\ufe0f" 225 | }, 226 | "56": { 227 | "utf-8": "\\U0001fa9b", 228 | "escape": "\ud83e\ude9b" 229 | }, 230 | "57": { 231 | "utf-8": "\\u26f3", 232 | "escape": "\u26f3" 233 | }, 234 | "58": { 235 | "utf-8": "\\U0001fad3", 236 | "escape": "\ud83e\uded3" 237 | }, 238 | "59": { 239 | "utf-8": "\\U0001f469", 240 | "escape": "\ud83d\udc69" 241 | }, 242 | "60": { 243 | "utf-8": "\\U0001f560", 244 | "escape": "\ud83d\udd60" 245 | }, 246 | "61": { 247 | "utf-8": "\\U0001f5ff", 248 | "escape": "\ud83d\uddff" 249 | }, 250 | "62": { 251 | "utf-8": "\\U0001f450\\U0001f3fe", 252 | "escape": "\ud83d\udc50\ud83c\udffe" 253 | }, 254 | "63": { 255 | "utf-8": "\\U0001f524", 256 | "escape": "\ud83d\udd24" 257 | }, 258 | "64": { 259 | "utf-8": "\\U0001f416", 260 | "escape": "\ud83d\udc16" 261 | }, 262 | "65": { 263 | "utf-8": "\\u264f", 264 | "escape": "\u264f" 265 | }, 266 | "66": { 267 | "utf-8": "\\U0001f612", 268 | "escape": "\ud83d\ude12" 269 | }, 270 | "67": { 271 | "utf-8": "\\U0001f341", 272 | "escape": "\ud83c\udf41" 273 | }, 274 | "68": { 275 | "utf-8": "\\U0001f194", 276 | "escape": "\ud83c\udd94" 277 | }, 278 | "69": { 279 | "utf-8": "\\U0001f91a\\U0001f3fb", 280 | "escape": "\ud83e\udd1a\ud83c\udffb" 281 | }, 282 | "70": { 283 | "utf-8": "\\U0001f9d4\\U0001f3fc", 284 | "escape": "\ud83e\uddd4\ud83c\udffc" 285 | }, 286 | "71": { 287 | "utf-8": "\\U0001f4a3", 288 | "escape": "\ud83d\udca3" 289 | }, 290 | "72": { 291 | "utf-8": "\\U0001f433", 292 | "escape": "\ud83d\udc33" 293 | }, 294 | "73": { 295 | "utf-8": "\\U0001f198", 296 | "escape": "\ud83c\udd98" 297 | }, 298 | "74": { 299 | "utf-8": "\\U0001f0cf", 300 | "escape": "\ud83c\udccf" 301 | }, 302 | "75": { 303 | "utf-8": "\\u2139\\ufe0f", 304 | "escape": "\u2139\ufe0f" 305 | }, 306 | "76": { 307 | "utf-8": "\\U0001f53c", 308 | "escape": "\ud83d\udd3c" 309 | }, 310 | "77": { 311 | "utf-8": "\\U0001f9f0", 312 | "escape": "\ud83e\uddf0" 313 | }, 314 | "78": { 315 | "utf-8": "\\U0001f9db\\U0001f3fd", 316 | "escape": "\ud83e\udddb\ud83c\udffd" 317 | }, 318 | "79": { 319 | "utf-8": "\\U0001f461", 320 | "escape": "\ud83d\udc61" 321 | }, 322 | "80": { 323 | "utf-8": "\\U0001f96f", 324 | "escape": "\ud83e\udd6f" 325 | }, 326 | "81": { 327 | "utf-8": "\\U0001f37d\\ufe0f", 328 | "escape": "\ud83c\udf7d\ufe0f" 329 | }, 330 | "82": { 331 | "utf-8": "\\u270b\\U0001f3fc", 332 | "escape": "\u270b\ud83c\udffc" 333 | }, 334 | "83": { 335 | "utf-8": "\\U0001f512", 336 | "escape": "\ud83d\udd12" 337 | }, 338 | "84": { 339 | "utf-8": "\\U0001f3cc\\U0001f3ff", 340 | "escape": "\ud83c\udfcc\ud83c\udfff" 341 | }, 342 | "85": { 343 | "utf-8": "\\U0001f952", 344 | "escape": "\ud83e\udd52" 345 | } 346 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PREGO: online mistake detection in PRocedural EGOcentric videos (CVPR 2024) 2 | | 3 | **[PREGO paper [CVPR 
2024]](https://openaccess.thecvf.com/content/CVPR2024/html/Flaborea_PREGO_Online_Mistake_Detection_in_PRocedural_EGOcentric_Videos_CVPR_2024_paper.html)** 4 | | 5 | **[TI-PREGO paper [arXiv]](https://arxiv.org/abs/2411.02570)** 6 | 7 | 8 | ## Index 9 | 10 | 1. [Introduction](#introduction) 11 | 2. [News](#news) 12 | 3. [Preparation](#preparation) 13 | - [Data](#data) 14 | - [LLAMA](#llama) 15 | - [Environment](#environment) 16 | 4. [Usage](#usage) 17 | - [Step Recognition](#step-recognition) 18 | - [Data Aggregation](#data-aggregation) 19 | - [Step Anticipation](#step-anticipation) 20 | - [Data Preparation](#data-preparation) 21 | - [Parameters](#parameters) 22 | - [Run](#run) 23 | 5. [Reference](#reference) 24 | 25 | 26 | ## Introduction 27 | This repo hosts the official PyTorch implementations of the *IEEE/CVF Computer Vision and Pattern Recognition (CVPR) '24* paper **PREGO: online mistake detection in PRocedural EGOcentric videos** and of the follow-up paper **TI-PREGO: Chain of Thought and In-Context Learning for Online Mistake Detection in PRocedural EGOcentric Videos**. 28 | 29 | PREGO is the first online one-class classification model for mistake detection in procedural egocentric videos. It uses an online action recognition component to model current actions and a symbolic reasoning module to predict next actions, detecting mistakes by comparing the recognized current action with the expected future one. We evaluate this on two adapted datasets, *Assembly101-O* and *Epic-tent-O*, for online benchmarking of procedural mistake detection. 30 | 31 | ![teaser_image](assets/teaser.png) 32 | 33 | ## News 34 | **[2024-12-01]** Uploaded the recognition branch. 35 | 36 | **[2024-11-12]** Uploaded the script for the prediction aggregation strategy described in [[TI-PREGO]](https://arxiv.org/abs/2411.02570). 37 | 38 | **[2024-11-12]** Uploaded the TSN features for Assembly101-O and Epic-tent-O [[GDrive]](https://drive.google.com/drive/u/1/folders/1gcOIEXhwysCE2o8-5C4vQnTShJ7p3CKH). 39 | 40 | **[2024-11-04]** Published the follow-up paper [[TI-PREGO]](https://arxiv.org/abs/2411.02570). 41 | 42 | **[2024-06-20]** Presented PREGO at #CVPR2024. 43 | 44 | **[2024-06-16]** Uploaded the anticipation branch. 45 | 46 | ## Preparation 47 | ### Data 48 | The TSN features of the Assembly101-O and Epic-tent-O datasets can be downloaded here: [[GDrive]](https://drive.google.com/drive/u/1/folders/1gcOIEXhwysCE2o8-5C4vQnTShJ7p3CKH). 49 | 50 | To download the data using `gdown`, run the following command: 51 | ```bash 52 | gdown --folder --remaining-ok https://drive.google.com/drive/u/1/folders/1gcOIEXhwysCE2o8-5C4vQnTShJ7p3CKH 53 | ``` 54 | 55 | The folder follows the structure described in [MiniROAD](https://github.com/jbistanbul/MiniROAD): 56 | ``` 57 | PREGO 58 | | 59 | |__________ Assembly101-O 60 | | | 61 | | |__________ rgb_anet_resnet50 62 | | | | 63 | | | |_________nusar-2021_action_both_9011-b06b_9011_user_id_2021-02-01_154253.npy 64 | | | |_________... 65 | | |__________ rgb_as_flow 66 | | | | 67 | | | |_________nusar-2021_action_both_9011-b06b_9011_user_id_2021-02-01_154253.npy 68 | | | |_________... 69 | | |__________ target_perframe 70 | | | 71 | | |_________nusar-2021_action_both_9011-b06b_9011_user_id_2021-02-01_154253.npy 72 | | |_________... 73 | |__________ Epic-tent-O 74 | | 75 | |__________ rgb_anet_resnet50 76 | | | 77 | | |_________annotations_1.npy 78 | | |_________... 79 | |__________ rgb_as_flow 80 | | | 81 | | |_________annotations_1.npy 82 | | |_________... 
83 |            |__________ target_perframe
84 |            |              |
85 |            |              |_________annotations_1.npy
86 |            |              |_________...
87 | ```
88 | 
89 | 
90 | ### LLAMA
91 | To run our anticipation step with LLAMA, you must be granted access to the models by Meta [here](https://www.llama.com/llama-downloads/).
92 | Place them wherever you like, and remember to update the paths where necessary, as in `step_anticipation/scripts/anticipation.sh`.
93 | 
94 | ### Environment
95 | You can create either a `conda` or a `virtualenv` environment, as you prefer:
96 | ```bash
97 | # conda
98 | conda create -n prego python=3.10
99 | conda activate prego
100 | 
101 | # virtualenv
102 | python3.10 -m venv .venv
103 | source .venv/bin/activate
104 | ```
105 | Then, install the requirements:
106 | ```bash
107 | pip install -r requirements.txt
108 | ```
109 | Install `unsloth` following the instructions [here](https://docs.unsloth.ai/get-started/installation/pip-install).
110 | 
111 | ## Usage
112 | 
113 | ### Step Recognition
114 | For more details regarding the Step Recognition branch, refer to the official implementation of MiniROAD [here](https://github.com/jbistanbul/MiniROAD).
115 | 
116 | To run training on Assembly101-O, for example, use the command
117 | ```bash
118 | python step_recognition/main.py --config step_recognition/configs/miniroad_assembly101-O.yaml
119 | ```
120 | which saves the checkpoints in the folder `step_recognition/checkpoint/miniROAD/Assembly101-O`.
121 | 
122 | You can then use the checkpoint for evaluation, which saves frame-by-frame predictions as a JSON file in the folder `output_miniROAD`, with the command
123 | 
124 | ```bash
125 | python step_recognition/main.py --config step_recognition/configs/miniroad_assembly101-O.yaml --eval
126 | ```
127 | 
128 | ### Data Aggregation
129 | The `utils/aggregate.py` script handles the data aggregation process.
130 | It aggregates predictions and ground-truth data and saves the results to a JSON file.
131 | 
132 | To run the data aggregation script, use the following command, passing as input the JSON created in the Step Recognition section:
133 | 
134 | ```bash
135 | python utils/aggregate.py <input_path> <output_path>
136 | ```
137 | 
138 | - `<input_path>`: Path to the input JSON file containing the data.
139 | - `<output_path>`: Path to save the aggregated JSON file.
140 | 
141 | ### Example
142 | 
143 | ```bash
144 | python utils/aggregate.py data/input.json data/output/aggregated_data.json
145 | ```
146 | 
147 | ### Step Anticipation
148 | 
149 | #### Data Preparation
150 | These are the steps needed to prepare the data for the Step Anticipation branch.
151 | 
152 | Step Recognition predictions:
153 | - place the predictions of the Step Recognizer (after aggregation) in the `step_anticipation/data/predictions` folder
154 | - the file should have the following structure (a sanity-check snippet is given at the end of this subsection):
155 | ```json
156 | {
157 |     "nusar-2021_action_both_9044-a08_9044_user_id_2021-02-05_154403": {
158 |         "pred": [
159 |             39,
160 |             37,
161 |             74,
162 |             39,
163 |             37
164 |         ],
165 |         "gt": [
166 |             37,
167 |             80,
168 |             39,
169 |             29,
170 |             85
171 |         ]
172 |     },
173 |     ...
174 | }
175 | ```
176 | Context prompt:
177 | - `step_anticipation/data/context_prompt/assembly_context_prompt_train.json` and `step_anticipation/data/context_prompt/epictents_context_prompt_train.json` contain the context used for the in-context learning prompt.
178 | - `step_anticipation/data/context_prompt/context_prompt.json` contains the strings that fill the context prompt.
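For a quick sanity check, the following minimal snippet (illustrative only, not part of the repository; the file path is just an example) verifies that an aggregated predictions file follows the `pred`/`gt` structure shown above before it is fed to the anticipation scripts:

```python
import json

# Example path: point this to your aggregated Step Recognition output.
path = "step_anticipation/data/predictions/output_miniROAD_Assembly101-O.json"

with open(path) as f:
    data = json.load(f)

for video_id, entry in data.items():
    # Every video entry must expose integer "pred" and "gt" step sequences.
    assert isinstance(entry["pred"], list) and isinstance(entry["gt"], list), video_id
    assert all(isinstance(x, int) for x in entry["pred"] + entry["gt"]), video_id

print(f"OK: {len(data)} videos with valid 'pred'/'gt' sequences")
```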
179 | 180 | #### Parameters 181 | Description of the parameters that can be added to the `step_anticipation/scripts/anticipation.sh` script. 182 | 183 | - `ckpt_dir=/path/to//llama/llama-2-7b` 184 | - `tokenizer_path=/path/to/tokenizer/llama/tokenizer.model` 185 | - `max_seq_len=2048` Maximum sequence length for input text 186 | - `max_batch_size` Maximum batch size for generating sequences 187 | - `temperature` Temperature value for controlling randomness in generation 188 | - `max_gen_len` Maximum length of the generated text sequence. 189 | - `num_samples` How many generations per each input context 190 | - `use_gt` Select if gt or predictions from Step Recognizer are used as input context 191 | - `dataset` Select the dataset to use. ['assembly', 'epictent'] 192 | - `type_prompt` Select which type of context to be passed. ['num', 'alpha', 'emoji'] 193 | - `toy_class_context` For the assembly dataset only. If True, the input context has all the examples from the same class of toys 194 | - `recognition_model` If not use_gt, select which Step Recognizer predictions to use. ['miniROAD', 'OadTR'] 195 | - `prompt_context` Select how the prompt context is structured. ['default', 'unreferenced','elaborate','no-context'] 196 | 197 | 198 | #### Run 199 | ```bash 200 | cd step_anticipation 201 | ./scripts/anticipation.sh 202 | ``` 203 | 204 | ## Reference 205 | If you find our code or paper to be helpful, please consider citing: 206 | ``` 207 | @InProceedings{Flaborea_2024_CVPR, 208 | author = {Flaborea, Alessandro and di Melendugno, Guido Maria D'Amely and Plini, Leonardo and Scofano, Luca and De Matteis, Edoardo and Furnari, Antonino and Farinella, Giovanni Maria and Galasso, Fabio}, 209 | title = {PREGO: Online Mistake Detection in PRocedural EGOcentric Videos}, 210 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 211 | month = {June}, 212 | year = {2024}, 213 | pages = {18483-18492} 214 | } 215 | ``` 216 | ``` 217 | @misc{plini2024tipregochainthoughtincontext, 218 | title={TI-PREGO: Chain of Thought and In-Context Learning for Online Mistake Detection in PRocedural EGOcentric Videos}, 219 | author={Leonardo Plini and Luca Scofano and Edoardo De Matteis and Guido Maria D'Amely di Melendugno and Alessandro Flaborea and Andrea Sanchietti and Giovanni Maria Farinella and Fabio Galasso and Antonino Furnari}, 220 | year={2024}, 221 | eprint={2411.02570}, 222 | archivePrefix={arXiv}, 223 | primaryClass={cs.CV}, 224 | url={https://arxiv.org/abs/2411.02570}, 225 | } 226 | ``` 227 | -------------------------------------------------------------------------------- /step_anticipation/src/models/llm_ollama.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | import re 5 | from typing import Optional 6 | 7 | import fire 8 | import numpy as np 9 | from ollama import ChatResponse, chat 10 | 11 | BASE_PATH = "step_anticipation/data" 12 | CONTEXT_PROMPT_PATH = f"{BASE_PATH}/context_prompt" 13 | PREDICTIONS_PATH = f"{BASE_PATH}/predictions" 14 | 15 | 16 | def get_metrics(preds, gts): 17 | tp, fp, fn, tn = 0, 0, 0, 0 18 | count, samples = 0, 0 19 | for k in gts.keys(): 20 | gt = gts[k] 21 | pred = preds[k] 22 | matches = np.array([g in p for g, p in zip(gt, pred)]) 23 | 24 | count += np.sum(matches) 25 | samples += len(matches) 26 | correct = matches[:-1] 27 | mistake = matches[-1] 28 | 29 | tn += np.sum(correct) 30 | fp += np.sum(~correct) 31 | tp += int(not mistake) 32 | fn 
+= int(mistake) 33 | 34 | acc = (tp + tn) / (tp + tn + fp + fn) 35 | precision = tp / (tp + fp) 36 | recall = tp / (tp + fn) 37 | f1 = 2 * precision * recall / (precision + recall) 38 | ratio = count / samples 39 | 40 | return { 41 | "tp": tp, 42 | "fp": fp, 43 | "fn": fn, 44 | "tn": tn, 45 | "accuracy": acc, 46 | "precision": precision, 47 | "recall": recall, 48 | "f1": f1, 49 | "ratio": ratio, 50 | "count": count, 51 | "samples": samples, 52 | } 53 | 54 | 55 | def get_toy(name: str) -> str: 56 | toy = name.split("-")[2].split("_")[0] 57 | return toy 58 | 59 | 60 | def load_data(path: str) -> dict: 61 | data = json.load(open(path, "r")) 62 | return data 63 | 64 | 65 | def remove_sequenceInput(prompt, toy_class): 66 | new_prompt = "" 67 | start = 0 68 | for m in re.finditer(r"Sequence type: [a-zA-Z0-9]{3,}\n", prompt): 69 | new_prompt += prompt[start : m.start()] 70 | new_prompt += f"Sequence type: {toy_class}\n" 71 | start = m.end() 72 | new_prompt += prompt[start:] 73 | return new_prompt.replace("Symbol", "Sequence") 74 | 75 | 76 | def anticipation( 77 | ollama_model, 78 | seq: list, 79 | prompt: str, 80 | toy: Optional[str], 81 | toy_class: Optional[str], 82 | num_samples: int, 83 | clean_prediction: bool, 84 | type_prompt="num", 85 | prompt_context="default", 86 | ): 87 | preds, gts = [], [] 88 | 89 | if type_prompt == "emoji": 90 | prompt = prompt.replace("-1", "👉") 91 | 92 | if toy_class: 93 | remove_toySequence = True 94 | prompt = remove_sequenceInput(prompt, toy_class) 95 | else: 96 | remove_toySequence = False 97 | 98 | for i in range(len(seq)): 99 | prompt_builder = load_data(f"{CONTEXT_PROMPT_PATH}/context_prompt.json") 100 | init = prompt_builder[prompt_context]["init"] 101 | 102 | if remove_toySequence: 103 | prompt_ = f"{prompt}{init} {toy_class}\n" 104 | else: 105 | prompt_ = f"{prompt}{init} {toy}\n" 106 | 107 | if type_prompt == "emoji": 108 | hist, action = ["👉"] + seq[:i], seq[i] 109 | else: 110 | hist, action = [-1] + seq[:i], seq[i] 111 | 112 | input_builder = prompt_builder[prompt_context]["input"] 113 | prompt_ += f"{input_builder}\n {', '.join(map(str, hist))}\n" 114 | 115 | output_builder = prompt_builder[prompt_context]["output"] 116 | prompt_ += f"{output_builder}\n" 117 | 118 | pred = set() 119 | 120 | messages = [ 121 | { 122 | "role": "system", 123 | "content": "Always provide only the final output, consisting in one and only one number. 
Never output anything different from a single number.", 124 | }, 125 | { 126 | "role": "user", 127 | "content": prompt_, 128 | }, 129 | ] 130 | for _ in range(num_samples): 131 | response: ChatResponse = chat(model=ollama_model, messages=messages) 132 | # result = ollama_model.generate( 133 | # prompt=prompt_, 134 | # max_tokens=max_gen_len, 135 | # temperature=temperature, 136 | # top_p=top_p, 137 | # ) 138 | 139 | result = response.message.content 140 | print(f"> {result}") 141 | v = result.strip().strip("_") 142 | # v = result.text.strip().strip("_") 143 | 144 | if type_prompt == "num": 145 | v = re.sub(r"^[^0-9]*|[^0-9]*$", "", v) 146 | try: 147 | v = int(v) 148 | except: 149 | pass 150 | 151 | pred.add(v if type_prompt != "emoji" else v[0] if v else "") 152 | 153 | gts.append(action) 154 | preds.append(pred) 155 | 156 | return preds, gts 157 | 158 | 159 | # def anticipation( 160 | # seq: list, 161 | # prompt: str, 162 | # toy: Optional[str], 163 | # toy_class: Optional[str], 164 | # ollama_model, 165 | # max_gen_len: Optional[int], 166 | # temperature: float, 167 | # top_p: float, 168 | # num_samples: int, 169 | # clean_prediction: bool, 170 | # type_prompt="num", 171 | # prompt_context="default", 172 | # ): 173 | # preds, gts = [], [] 174 | 175 | # if type_prompt == "emoji": 176 | # prompt = prompt.replace("-1", "👉") 177 | 178 | # if toy_class: 179 | # remove_toySequence = True 180 | # prompt = remove_sequenceInput(prompt, toy_class) 181 | # else: 182 | # remove_toySequence = False 183 | 184 | # for i in range(len(seq)): 185 | # prompt_builder = load_data(f"{CONTEXT_PROMPT_PATH}/context_prompt.json") 186 | # init = prompt_builder[prompt_context]["init"] 187 | 188 | # if remove_toySequence: 189 | # prompt_ = f"{prompt}{init} {toy_class}\n" 190 | # else: 191 | # prompt_ = f"{prompt}{init} {toy}\n" 192 | 193 | # if type_prompt == "emoji": 194 | # hist, action = ["👉"] + seq[:i], seq[i] 195 | # else: 196 | # hist, action = [-1] + seq[:i], seq[i] 197 | 198 | # input_builder = prompt_builder[prompt_context]["input"] 199 | # prompt_ += f"{input_builder}\n {', '.join(map(str, hist))}\n" 200 | 201 | # output_builder = prompt_builder[prompt_context]["output"] 202 | # prompt_ += f"{output_builder}\n" 203 | 204 | # pred = set() 205 | # for _ in range(num_samples): 206 | # result = ollama_model.generate( 207 | # prompt=prompt_, 208 | # max_tokens=max_gen_len, 209 | # temperature=temperature, 210 | # top_p=top_p, 211 | # ) 212 | 213 | # v = result.text.strip().strip("_") 214 | 215 | # if type_prompt == "num": 216 | # v = re.sub(r"^[^0-9]*|[^0-9]*$", "", v) 217 | # try: 218 | # v = int(v) 219 | # except: 220 | # pass 221 | 222 | # pred.add(v if type_prompt != "emoji" else v[0] if v else "") 223 | 224 | # gts.append(action) 225 | # preds.append(pred) 226 | 227 | # return preds, gts 228 | 229 | 230 | def main( 231 | ollama_model_name: str, 232 | # max_seq_len: int = 512, 233 | # max_gen_len: Optional[int] = None, 234 | # temperature: float = 0.6, 235 | # top_p: float = 0.9, 236 | ollama_model="llama3.2:3b", 237 | num_samples: int = 1, 238 | use_gt: bool = False, 239 | type_prompt: str = "num", 240 | clean_prediction: bool = False, 241 | eval_metrics: bool = True, 242 | dataset: str = "assembly", 243 | toy_class_context: bool = False, 244 | recognition_model: str = "miniROAD", 245 | prompt_context: str = "default", 246 | ): 247 | if dataset == "assembly": 248 | if toy_class_context: 249 | toy2class = json.load(open("assets/toy2class.json", "r")) 250 | contexts = load_data( 251 | 
f"{CONTEXT_PROMPT_PATH}/assembly_context_prompt_train.json" 252 | ) 253 | else: 254 | contexts = load_data( 255 | f"{CONTEXT_PROMPT_PATH}/supplementary/assembly_context_prompt_train_onlyToy.json" 256 | ) 257 | 258 | if recognition_model == "OadTR": 259 | seqs = load_data(f"{PREDICTIONS_PATH}/output_OadTR_Assembly101-O.json") 260 | elif recognition_model == "miniROAD": 261 | seqs = load_data(f"{PREDICTIONS_PATH}/output_miniROAD_Assembly101-O.json") 262 | 263 | elif dataset == "epictent": 264 | contexts = load_data( 265 | f"{CONTEXT_PROMPT_PATH}/epictent_context_prompt_train.json" 266 | ) 267 | 268 | if recognition_model == "OadTR": 269 | seqs = load_data(f"{PREDICTIONS_PATH}/output_OadTR_Epic-Tent-O.json") 270 | elif recognition_model == "miniROAD": 271 | seqs = load_data(f"{PREDICTIONS_PATH}/output_miniROAD_Epic-Tent-O_edo.json") 272 | else: 273 | raise ValueError(f"Dataset {dataset} not supported") 274 | 275 | preds, gts = {}, {} 276 | 277 | for i, (k, v) in enumerate(seqs.items()): 278 | if dataset == "assembly": 279 | toy = get_toy(k) 280 | 281 | if toy_class_context: 282 | toy_class = toy2class[toy] 283 | prompt = contexts[toy_class][type_prompt] 284 | else: 285 | toy_class = None 286 | prompt = contexts.get(toy, {}).get(type_prompt, "") 287 | 288 | elif dataset == "epictent": 289 | toy = None 290 | toy_class = None 291 | prompt = contexts[type_prompt] 292 | else: 293 | raise ValueError(f"Dataset {dataset} not supported") 294 | 295 | seq = v["gt"] if use_gt else v["pred"] 296 | 297 | pred, gt = anticipation( 298 | seq=seq, 299 | prompt=prompt, 300 | toy=toy, 301 | toy_class=toy_class, 302 | ollama_model=ollama_model, 303 | # max_gen_len=max_gen_len, 304 | # temperature=temperature, 305 | # top_p=top_p, 306 | num_samples=num_samples, 307 | clean_prediction=clean_prediction, 308 | type_prompt=type_prompt, 309 | prompt_context=prompt_context, 310 | ) 311 | 312 | preds[k] = pred 313 | gts[k] = gt 314 | 315 | if eval_metrics: 316 | metrics = get_metrics(preds=preds, gts=gts) 317 | print(f"[INFO] {metrics}") 318 | 319 | model_name = ollama_model_name.replace("/", "_") 320 | save_folder = f"results/{model_name}_{dataset}_{prompt_context}" 321 | 322 | os.makedirs(save_folder, exist_ok=True) 323 | pickle.dump(gts, open(f"{save_folder}/gts.pkl", "wb")) 324 | pickle.dump(preds, open(f"{save_folder}/preds.pkl", "wb")) 325 | 326 | 327 | if __name__ == "__main__": 328 | fire.Fire(main) 329 | -------------------------------------------------------------------------------- /step_anticipation/data/context_prompt/epictent_context_prompt_train.json: -------------------------------------------------------------------------------- 1 | { 2 | "num": "Input Sequence:\n -1, 4, 7, 4, 6, 4, 10, 4, 4, 0, 0, 4, 2, 2, 4, 4, 4, 3, 4, 3, 3, 3, 9, 4, 5, 4, 4, 1, 1, 1, 1, 4, 9, 9, 9, 9, 11, 4\nNext Symbol:\n 8\n---\nInput Sequence:\n -1, 4, 7, 10, 6, 0, 0, 2, 2, 3, 3, 3, 3, 3, 8, 11, 8, 5, 1, 1, 1, 1, 5, 9, 9, 9\nNext Symbol:\n 9\n---\nInput Sequence:\n -1, 4, 7, 10, 6, 5, 4, 0, 0, 2, 2, 4, 3, 3, 3, 3, 4, 3, 3, 3, 3, 4, 8, 4, 4, 1, 1, 1, 1, 4, 9, 9, 9, 9, 4, 8, 4, 4, 4, 11\nNext Symbol:\n 4\n---\nInput Sequence:\n -1, 4, 7, 4, 6, 10, 6, 0, 0, 4, 2, 2, 4, 3, 3, 3, 3, 3, 3, 3, 4, 5, 1, 4, 1, 1, 1, 4, 9, 4, 9, 9, 9, 4, 4, 4, 8, 11\nNext Symbol:\n 4\n---\nInput Sequence:\n -1, 4, 7, 10, 4, 6, 0, 2, 3, 3, 0, 2, 3, 0, 3, 4, 5, 1, 1, 1, 1, 9, 9, 9, 9, 4\nNext Symbol:\n 8\n---\nInput Sequence:\n -1, 4, 7, 10, 7, 10, 7, 6, 0, 2, 0, 2, 3, 2, 2, 3, 2, 3, 2, 3, 9, 7, 5, 1, 10, 1, 1, 1, 5, 9, 9, 9, 9, 8, 11, 8\nNext Symbol:\n 
6\n---\nInput Sequence:\n -1, 7, 10, 4, 6, 0, 2, 0, 0, 2, 5, 1, 1, 1, 1, 3, 0, 3, 3, 1, 3, 1, 1, 4, 9, 4, 8, 11, 8, 9, 9, 8, 9, 1\nNext Symbol:\n 9\n---\nInput Sequence:\n -1, 4, 7, 4, 4, 4, 10, 4, 6, 0, 0, 4, 4, 5, 4, 0, 2, 3, 3, 2, 3, 3, 3, 3, 4, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 4, 8\nNext Symbol:\n 4\n---\nInput Sequence:\n -1, 7, 4, 10, 4, 6, 0, 0, 4, 4, 2, 4, 4, 2, 4, 4, 3, 3, 3, 3, 3, 4, 5, 4, 1, 1, 5, 1, 1, 4, 9, 9, 9, 9, 9, 4\nNext Symbol:\n 8\n---\nInput Sequence:\n -1, 4, 7, 4, 10, 6, 0, 2, 0, 2, 3, 3, 3, 3, 5, 1, 9, 1, 1, 1, 9, 9, 9\nNext Symbol:\n 9\n---\nInput Sequence:\n -1, 4, 4, 4, 7, 5, 4, 4, 6, 10, 4, 8, 10, 4, 4, 4, 6, 0, 4, 0, 4, 2, 4, 3, 4, 3, 4, 3, 2, 3, 3, 3, 4, 5, 4, 1, 1, 1, 1, 4, 9, 4, 4, 8, 4, 9, 4, 4, 4, 8, 9, 4, 4, 9, 9, 9\nNext Symbol:\n 9\n---\nInput Sequence:\n -1, 4, 7, 4, 10, 4, 4, 4, 6, 0, 4, 4, 4, 4, 0, 4, 4, 2, 4, 2, 3, 2, 3, 3, 2, 3, 2, 3, 2, 3, 3, 4, 5, 4, 4, 4, 1, 4, 1, 1, 1, 4, 9, 4, 9, 9, 9, 4, 8, 4, 4, 8, 4, 6, 5, 4, 11, 8\nNext Symbol:\n 8\n---\nInput Sequence:\n -1, 4, 7, 10, 6, 5, 8, 4, 10, 4, 0, 4, 2, 4, 2, 4, 2, 0, 2, 3, 2, 4, 3, 3, 2, 3, 2, 3, 1, 10, 1, 10, 1, 10, 1, 9, 11, 8, 4, 4, 8, 4, 9, 4, 9, 4, 9, 4, 9, 4, 9, 9\nNext Symbol:\n 9\n---\nInput Sequence:\n -1, 4, 7, 6, 10, 6, 0, 0, 2, 0, 3, 3, 2, 0, 0, 3, 3, 11, 5, 1, 10, 1, 1, 10, 1, 9, 9, 9, 9\nNext Symbol:\n 8\n---\n", 3 | "alpha": "Input Sequence:\n -1, instruction, pickup/open_tentbag, instruction, pickup/open_supportbag, instruction, spread_tent, instruction, instruction, assemble_support, assemble_support, instruction, insert_support, insert_support, instruction, instruction, instruction, insert_support_tab, instruction, insert_support_tab, insert_support_tab, insert_support_tab, place_guyline, instruction, pickup/open_stakebag, instruction, instruction, insert_stake, insert_stake, insert_stake, insert_stake, instruction, place_guyline, place_guyline, place_guyline, place_guyline, tie_top, instruction\nNext Symbol:\n pickup/place_ventcover\n---\nInput Sequence:\n -1, instruction, pickup/open_tentbag, spread_tent, pickup/open_supportbag, assemble_support, assemble_support, insert_support, insert_support, insert_support_tab, insert_support_tab, insert_support_tab, insert_support_tab, insert_support_tab, pickup/place_ventcover, tie_top, pickup/place_ventcover, pickup/open_stakebag, insert_stake, insert_stake, insert_stake, insert_stake, pickup/open_stakebag, place_guyline, place_guyline, place_guyline\nNext Symbol:\n place_guyline\n---\nInput Sequence:\n -1, instruction, pickup/open_tentbag, spread_tent, pickup/open_supportbag, pickup/open_stakebag, instruction, assemble_support, assemble_support, insert_support, insert_support, instruction, insert_support_tab, insert_support_tab, insert_support_tab, insert_support_tab, instruction, insert_support_tab, insert_support_tab, insert_support_tab, insert_support_tab, instruction, pickup/place_ventcover, instruction, instruction, insert_stake, insert_stake, insert_stake, insert_stake, instruction, place_guyline, place_guyline, place_guyline, place_guyline, instruction, pickup/place_ventcover, instruction, instruction, instruction, tie_top\nNext Symbol:\n instruction\n---\nInput Sequence:\n -1, instruction, pickup/open_tentbag, instruction, pickup/open_supportbag, spread_tent, pickup/open_supportbag, assemble_support, assemble_support, instruction, insert_support, insert_support, instruction, insert_support_tab, insert_support_tab, insert_support_tab, insert_support_tab, insert_support_tab, insert_support_tab, insert_support_tab, instruction, 
pickup/open_stakebag, insert_stake, instruction, insert_stake, insert_stake, insert_stake, instruction, place_guyline, instruction, place_guyline, place_guyline, place_guyline, instruction, instruction, instruction, pickup/place_ventcover, tie_top\nNext Symbol:\n instruction\n---\nInput Sequence:\n -1, instruction, pickup/open_tentbag, spread_tent, instruction, pickup/open_supportbag, assemble_support, insert_support, insert_support_tab, insert_support_tab, assemble_support, insert_support, insert_support_tab, assemble_support, insert_support_tab, instruction, pickup/open_stakebag, insert_stake, insert_stake, insert_stake, insert_stake, place_guyline, place_guyline, place_guyline, place_guyline, instruction\nNext Symbol:\n pickup/place_ventcover\n---\nInput Sequence:\n -1, instruction, pickup/open_tentbag, spread_tent, pickup/open_tentbag, spread_tent, pickup/open_tentbag, pickup/open_supportbag, assemble_support, insert_support, assemble_support, insert_support, insert_support_tab, insert_support, insert_support, insert_support_tab, insert_support, insert_support_tab, insert_support, insert_support_tab, place_guyline, pickup/open_tentbag, pickup/open_stakebag, insert_stake, spread_tent, insert_stake, insert_stake, insert_stake, pickup/open_stakebag, place_guyline, place_guyline, place_guyline, place_guyline, pickup/place_ventcover, tie_top, pickup/place_ventcover\nNext Symbol:\n pickup/open_supportbag\n---\nInput Sequence:\n -1, pickup/open_tentbag, spread_tent, instruction, pickup/open_supportbag, assemble_support, insert_support, assemble_support, assemble_support, insert_support, pickup/open_stakebag, insert_stake, insert_stake, insert_stake, insert_stake, insert_support_tab, assemble_support, insert_support_tab, insert_support_tab, insert_stake, insert_support_tab, insert_stake, insert_stake, instruction, place_guyline, instruction, pickup/place_ventcover, tie_top, pickup/place_ventcover, place_guyline, place_guyline, pickup/place_ventcover, place_guyline, insert_stake\nNext Symbol:\n place_guyline\n---\nInput Sequence:\n -1, instruction, pickup/open_tentbag, instruction, instruction, instruction, spread_tent, instruction, pickup/open_supportbag, assemble_support, assemble_support, instruction, instruction, pickup/open_stakebag, instruction, assemble_support, insert_support, insert_support_tab, insert_support_tab, insert_support, insert_support_tab, insert_support_tab, insert_support_tab, insert_support_tab, instruction, insert_stake, insert_stake, insert_stake, insert_stake, insert_stake, insert_stake, place_guyline, place_guyline, place_guyline, place_guyline, instruction, pickup/place_ventcover\nNext Symbol:\n instruction\n---\nInput Sequence:\n -1, pickup/open_tentbag, instruction, spread_tent, instruction, pickup/open_supportbag, assemble_support, assemble_support, instruction, instruction, insert_support, instruction, instruction, insert_support, instruction, instruction, insert_support_tab, insert_support_tab, insert_support_tab, insert_support_tab, insert_support_tab, instruction, pickup/open_stakebag, instruction, insert_stake, insert_stake, pickup/open_stakebag, insert_stake, insert_stake, instruction, place_guyline, place_guyline, place_guyline, place_guyline, place_guyline, instruction\nNext Symbol:\n pickup/place_ventcover\n---\nInput Sequence:\n -1, instruction, pickup/open_tentbag, instruction, spread_tent, pickup/open_supportbag, assemble_support, insert_support, assemble_support, insert_support, insert_support_tab, insert_support_tab, insert_support_tab, 
insert_support_tab, pickup/open_stakebag, insert_stake, place_guyline, insert_stake, insert_stake, insert_stake, place_guyline, place_guyline, place_guyline\nNext Symbol:\n place_guyline\n---\nInput Sequence:\n -1, instruction, instruction, instruction, pickup/open_tentbag, pickup/open_stakebag, instruction, instruction, pickup/open_supportbag, spread_tent, instruction, pickup/place_ventcover, spread_tent, instruction, instruction, instruction, pickup/open_supportbag, assemble_support, instruction, assemble_support, instruction, insert_support, instruction, insert_support_tab, instruction, insert_support_tab, instruction, insert_support_tab, insert_support, insert_support_tab, insert_support_tab, insert_support_tab, instruction, pickup/open_stakebag, instruction, insert_stake, insert_stake, insert_stake, insert_stake, instruction, place_guyline, instruction, instruction, pickup/place_ventcover, instruction, place_guyline, instruction, instruction, instruction, pickup/place_ventcover, place_guyline, instruction, instruction, place_guyline, place_guyline, place_guyline\nNext Symbol:\n place_guyline\n---\nInput Sequence:\n -1, instruction, pickup/open_tentbag, instruction, spread_tent, instruction, instruction, instruction, pickup/open_supportbag, assemble_support, instruction, instruction, instruction, instruction, assemble_support, instruction, instruction, insert_support, instruction, insert_support, insert_support_tab, insert_support, insert_support_tab, insert_support_tab, insert_support, insert_support_tab, insert_support, insert_support_tab, insert_support, insert_support_tab, insert_support_tab, instruction, pickup/open_stakebag, instruction, instruction, instruction, insert_stake, instruction, insert_stake, insert_stake, insert_stake, instruction, place_guyline, instruction, place_guyline, place_guyline, place_guyline, instruction, pickup/place_ventcover, instruction, instruction, pickup/place_ventcover, instruction, pickup/open_supportbag, pickup/open_stakebag, instruction, tie_top, pickup/place_ventcover\nNext Symbol:\n pickup/place_ventcover\n---\nInput Sequence:\n -1, instruction, pickup/open_tentbag, spread_tent, pickup/open_supportbag, pickup/open_stakebag, pickup/place_ventcover, instruction, spread_tent, instruction, assemble_support, instruction, insert_support, instruction, insert_support, instruction, insert_support, assemble_support, insert_support, insert_support_tab, insert_support, instruction, insert_support_tab, insert_support_tab, insert_support, insert_support_tab, insert_support, insert_support_tab, insert_stake, spread_tent, insert_stake, spread_tent, insert_stake, spread_tent, insert_stake, place_guyline, tie_top, pickup/place_ventcover, instruction, instruction, pickup/place_ventcover, instruction, place_guyline, instruction, place_guyline, instruction, place_guyline, instruction, place_guyline, instruction, place_guyline, place_guyline\nNext Symbol:\n place_guyline\n---\nInput Sequence:\n -1, instruction, pickup/open_tentbag, pickup/open_supportbag, spread_tent, pickup/open_supportbag, assemble_support, assemble_support, insert_support, assemble_support, insert_support_tab, insert_support_tab, insert_support, assemble_support, assemble_support, insert_support_tab, insert_support_tab, tie_top, pickup/open_stakebag, insert_stake, spread_tent, insert_stake, insert_stake, spread_tent, insert_stake, place_guyline, place_guyline, place_guyline, place_guyline\nNext Symbol:\n pickup/place_ventcover\n---\n" 4 | } 
-------------------------------------------------------------------------------- /step_recognition/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | import os.path as osp 4 | 5 | import ipdb 6 | import numpy as np 7 | import torch 8 | import torch.utils.data as data 9 | from datasets.dataset_builder import DATA_LAYERS 10 | 11 | FEATURE_SIZES = { 12 | "rgb_anet_resnet50": 2048, 13 | "flow_anet_resnet50": 2048, 14 | "rgb_kinetics_bninception": 1024, 15 | "flow_kinetics_bninception": 1024, 16 | "rgb_kinetics_resnet50": 2048, 17 | "flow_kinetics_resnet50": 2048, 18 | "flow_nv_kinetics_bninception": 1024, 19 | "rgb_kinetics_i3d": 2048, 20 | "flow_kinetics_i3d": 2048, 21 | } 22 | 23 | 24 | @DATA_LAYERS.register("THUMOS") 25 | @DATA_LAYERS.register("TVSERIES") 26 | @DATA_LAYERS.register("ASSEMBLY101-O") 27 | @DATA_LAYERS.register("EPIC-TENT-O") 28 | class THUMOSDataset(data.Dataset): 29 | 30 | def __init__(self, cfg, mode="train"): 31 | self.root_path = cfg["root_path"] 32 | self.mode = mode 33 | self.training = mode == "train" 34 | self.window_size = cfg["window_size"] 35 | self.stride = cfg["stride"] 36 | data_name = cfg["data_name"] 37 | self.vids = json.load(open(cfg["video_list_path"]))[data_name][ 38 | mode + "_session_set" 39 | ] # [:5] # list of video names 40 | self.num_classes = cfg["num_classes"] 41 | self.inputs = [] 42 | self._load_features(cfg) 43 | self._init_features() 44 | 45 | def _load_features(self, cfg): 46 | self.removed = 0 47 | self.annotation_type = cfg["annotation_type"] 48 | self.rgb_type = cfg["rgb_type"] 49 | self.flow_type = cfg["flow_type"] 50 | self.target_all = {} 51 | self.rgb_inputs = {} 52 | self.flow_inputs = {} 53 | dummy_target = np.zeros((self.window_size - 1, self.num_classes)) 54 | dummy_rgb = np.zeros((self.window_size - 1, FEATURE_SIZES[cfg["rgb_type"]])) 55 | dummy_flow = np.zeros((self.window_size - 1, FEATURE_SIZES[cfg["flow_type"]])) 56 | for vid in self.vids: 57 | try: 58 | target = np.load( 59 | osp.join(self.root_path, self.annotation_type, vid + ".npy") 60 | ) 61 | rgb = np.load(osp.join(self.root_path, self.rgb_type, vid + ".npy")) 62 | # checkpoint con rgb al posto del flow 63 | if cfg["flow_type"] == "flow_anet_resnet50": 64 | flow = np.load( 65 | osp.join( 66 | self.root_path + "/rgb_as_flow", self.rgb_type, vid + ".npy" 67 | ) 68 | ) 69 | flow = np.zeros(flow.shape) 70 | else: # checkpoint con flow 71 | my_string = ( 72 | "assembly_optical_flow_BNInception/" + vid + "/assembling.npy" 73 | ) 74 | flow = np.load(osp.join(self.root_path, self.flow_type, my_string)) 75 | # # concatting dummy target at the front 76 | # ! 
LEO to save train data leave these lines (within the if block) with the comment
77 |                 if self.training:
78 |                     self.target_all[vid] = np.concatenate(
79 |                         (dummy_target, target), axis=0
80 |                     )
81 |                     self.rgb_inputs[vid] = np.concatenate((dummy_rgb, rgb), axis=0)
82 |                     self.flow_inputs[vid] = np.concatenate((dummy_flow, flow), axis=0)
83 |                 else:
84 |                     self.target_all[vid] = target
85 |                     self.rgb_inputs[vid] = rgb
86 |                     self.flow_inputs[vid] = flow
87 |             except Exception as e:
88 |                 # print the exception
89 |                 print("---- Exception in loading video ", e)
90 |                 # remove the video from the list if it does not have the required features
91 |                 self.vids.remove(vid)
92 |                 self.removed += 1
93 |                 print("---- Removed video ", vid)
94 |                 print("---- Removed videos ", self.removed)
95 | 
96 |     def _init_features(self):
97 |         # del self.inputs
98 |         # gc.collect()
99 |         self.inputs = []
100 |         # remove 'nusar-2021_action_both_9056-b08a_9056_user_id_2021-02-22_141934'
101 |         if (
102 |             "nusar-2021_action_both_9056-b08a_9056_user_id_2021-02-22_141934"
103 |             in self.vids
104 |         ):
105 |             self.vids.remove(
106 |                 "nusar-2021_action_both_9056-b08a_9056_user_id_2021-02-22_141934"
107 |             )
108 |         # ipdb.set_trace()  # debug breakpoint left over; disabled so the loader can run unattended
109 |         for vid in self.vids:
110 |             target = self.target_all[vid]
111 |             #!Leo to save train data leave these lines (within the if block) with the comment
112 |             # ipdb.set_trace()  # debug breakpoint left over; disabled so the loader can run unattended
113 |             if self.training:
114 |                 seed = np.random.randint(self.stride)
115 |                 for start, end in zip(
116 |                     range(seed, target.shape[0], self.stride),
117 |                     range(seed + self.window_size, target.shape[0] + 1, self.stride),
118 |                 ):
119 |                     self.inputs.append([vid, start, end, target[start:end]])
120 |             else:
121 |                 start = 0
122 |                 end = target.shape[0]
123 |                 self.inputs.append([vid, start, end, target[start:end]])
124 | 
125 |     def __getitem__(self, index):
126 |         vid, start, end, target = self.inputs[index]
127 |         rgb_input = self.rgb_inputs[vid][start:end]
128 |         flow_input = self.flow_inputs[vid][start:end]
129 |         rgb_input = torch.tensor(rgb_input.astype(np.float32))
130 |         flow_input = torch.tensor(flow_input.astype(np.float32))
131 |         target = torch.tensor(target.astype(np.float32))
132 |         return rgb_input, flow_input, target, vid, start, end  #! 
vid added by Leo 133 | 134 | def __len__(self): 135 | return len(self.inputs) 136 | 137 | 138 | @DATA_LAYERS.register("THUMOS_ANTICIPATION") 139 | @DATA_LAYERS.register("TVSERIES_ANTICIPATION") 140 | class THUMOSDataset(data.Dataset): 141 | 142 | def __init__(self, cfg, mode="train"): 143 | self.root_path = cfg["root_path"] 144 | self.mode = mode 145 | self.training = mode == "train" 146 | self.window_size = cfg["window_size"] 147 | self.stride = cfg["stride"] 148 | self.anticipation_length = cfg["anticipation_length"] 149 | data_name = cfg["data_name"].split("_")[0] 150 | self.vids = json.load(open(cfg["video_list_path"]))[data_name][ 151 | mode + "_session_set" 152 | ] # list of video names 153 | self.num_classes = cfg["num_classes"] 154 | self.inputs = [] 155 | self._load_features(cfg) 156 | self._init_features() 157 | 158 | def _load_features(self, cfg): 159 | self.annotation_type = cfg["annotation_type"] 160 | self.rgb_type = cfg["rgb_type"] 161 | self.flow_type = cfg["flow_type"] 162 | self.target_all = {} 163 | self.rgb_inputs = {} 164 | self.flow_inputs = {} 165 | dummy_target = np.zeros((self.window_size - 1, self.num_classes)) 166 | dummy_rgb = np.zeros((self.window_size - 1, FEATURE_SIZES[cfg["rgb_type"]])) 167 | dummy_flow = np.zeros((self.window_size - 1, FEATURE_SIZES[cfg["flow_type"]])) 168 | for vid in self.vids: 169 | target = np.load( 170 | osp.join(self.root_path, self.annotation_type, vid + ".npy") 171 | ) 172 | rgb = np.load(osp.join(self.root_path, self.rgb_type, vid + ".npy")) 173 | flow = np.load(osp.join(self.root_path, self.flow_type, vid + ".npy")) 174 | if self.training: 175 | self.target_all[vid] = np.concatenate((dummy_target, target), axis=0) 176 | self.rgb_inputs[vid] = np.concatenate((dummy_rgb, rgb), axis=0) 177 | self.flow_inputs[vid] = np.concatenate((dummy_flow, flow), axis=0) 178 | else: 179 | self.target_all[vid] = target 180 | self.rgb_inputs[vid] = rgb 181 | self.flow_inputs[vid] = flow 182 | 183 | def _init_features(self): 184 | del self.inputs 185 | gc.collect() 186 | self.inputs = [] 187 | 188 | for vid in self.vids: 189 | target = self.target_all[vid] 190 | if self.training: 191 | seed = np.random.randint(self.stride) 192 | for start, end in zip( 193 | range(seed, target.shape[0], self.stride), 194 | range( 195 | seed + self.window_size, 196 | target.shape[0] - self.anticipation_length, 197 | self.stride, 198 | ), 199 | ): 200 | self.inputs.append( 201 | [ 202 | vid, 203 | start, 204 | end, 205 | target[start:end], 206 | target[end : end + self.anticipation_length], 207 | ] 208 | ) 209 | else: 210 | start = 0 211 | end = target.shape[0] - self.anticipation_length 212 | ant_target = [] 213 | for s in range(0, target.shape[0] - self.anticipation_length): 214 | ant_target.append(target[s : s + self.anticipation_length]) 215 | 216 | self.inputs.append( 217 | [vid, start, end, target[start:end], np.array(ant_target)] 218 | ) 219 | 220 | def __getitem__(self, index): 221 | vid, start, end, target, ant_target = self.inputs[index] 222 | rgb_input = self.rgb_inputs[vid][start:end] 223 | flow_input = self.flow_inputs[vid][start:end] 224 | rgb_input = torch.tensor(rgb_input.astype(np.float32)) 225 | flow_input = torch.tensor(flow_input.astype(np.float32)) 226 | target = torch.tensor(target.astype(np.float32)) 227 | ant_target = torch.tensor(ant_target.astype(np.float32)) 228 | return rgb_input, flow_input, target, ant_target 229 | 230 | def __len__(self): 231 | return len(self.inputs) 232 | 233 | 234 | @DATA_LAYERS.register("FINEACTION") 235 | class 
FINEACTIONDataset(data.Dataset): 236 | 237 | def __init__(self, cfg, mode="train"): 238 | self.root_path = cfg["root_path"] 239 | self.mode = mode 240 | self.training = mode == "train" 241 | self.window_size = cfg["window_size"] 242 | self.stride = cfg["stride"] 243 | data_name = cfg["data_name"] 244 | self.vids = json.load(open(cfg["video_list_path"]))[data_name][ 245 | mode + "_session_set" 246 | ] # list of video names 247 | self.num_classes = cfg["num_classes"] 248 | self.inputs = [] 249 | self._load_features(cfg) 250 | self._init_features() 251 | 252 | def _load_features(self, cfg): 253 | self.annotation_type = cfg["annotation_type"] 254 | self.rgb_type = cfg["rgb_type"] 255 | self.flow_type = cfg["flow_type"] 256 | 257 | def _init_features(self, seed=0): 258 | # self.inputs = [] 259 | del self.inputs 260 | gc.collect() 261 | self.inputs = [] 262 | for vid in self.vids: 263 | target = np.load( 264 | osp.join(self.root_path, self.annotation_type, vid + ".npy") 265 | ) 266 | if self.training: 267 | seed = np.random.randint(self.stride) 268 | for start, end in zip( 269 | range(seed, target.shape[0], self.stride), 270 | range(seed + self.window_size, target.shape[0] + 1, self.stride), 271 | ): 272 | self.inputs.append([vid, start, end]) 273 | else: 274 | start = 0 275 | end = target.shape[0] 276 | self.inputs.append([vid, start, end]) 277 | 278 | def __getitem__(self, index): 279 | vid, start, end = self.inputs[index] 280 | rgb_input = np.load( 281 | osp.join(self.root_path, self.rgb_type, vid + ".npy"), mmap_mode="r" 282 | )[start:end] 283 | flow_input = np.load( 284 | osp.join(self.root_path, self.flow_type, vid + ".npy"), mmap_mode="r" 285 | )[start:end] 286 | target = np.load( 287 | osp.join(self.root_path, self.annotation_type, vid + ".npy"), mmap_mode="r" 288 | )[start:end] 289 | rgb_input = torch.tensor(rgb_input.astype(np.float32)) 290 | flow_input = torch.tensor(flow_input.astype(np.float32)) 291 | target = torch.tensor(target.astype(np.float32)) 292 | return rgb_input, flow_input, target 293 | 294 | def __len__(self): 295 | return len(self.inputs) 296 | -------------------------------------------------------------------------------- /step_anticipation/src/models/llm_hf.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | import re 5 | import time 6 | from typing import Optional 7 | 8 | import transformers 9 | 10 | transformers.logging.set_verbosity_error() 11 | import fire 12 | import ipdb 13 | import numpy as np 14 | import torch 15 | from transformers import pipeline 16 | 17 | BASE_PATH = "step_anticipation/data" 18 | CONTEXT_PROMPT_PATH = f"{BASE_PATH}/context_prompt" 19 | PREDICTIONS_PATH = f"{BASE_PATH}/predictions" 20 | 21 | TIME_CNT = [] 22 | 23 | 24 | class HFModel: 25 | def __init__( 26 | self, model_name: str, max_seq_len: int = 512, max_batch_size: int = 8 27 | ): 28 | self.pipe = pipeline( 29 | "text-generation", 30 | model=model_name, 31 | tokenizer=model_name, 32 | # max_batch_size=max_batch_size, 33 | # max_seq_len=max_seq_len, 34 | device_map="auto", 35 | ) 36 | 37 | def text_completion( 38 | self, prompts, max_gen_len: Optional[int], temperature: float, top_p: float 39 | ): 40 | generate_kwargs = { 41 | "max_new_tokens": max_gen_len, 42 | # "max_gen_len": max_gen_len, 43 | "temperature": temperature, 44 | "top_p": top_p, 45 | } 46 | 47 | start_time = time.time() 48 | outputs = self.pipe(prompts, **generate_kwargs) 49 | TIME_CNT.append(time.time() - start_time) 50 | 51 | # Flatten 
the outputs (each prompt returns a list of outputs) 52 | flattened = [] 53 | for res in outputs: 54 | if isinstance(res, list): 55 | flattened.append(res[0]) 56 | else: 57 | flattened.append(res) 58 | return flattened 59 | 60 | 61 | def get_metrics(preds, gts): 62 | tp, fp, fn, tn = 0, 0, 0, 0 63 | count, samples = 0, 0 64 | for k in gts.keys(): 65 | gt = gts[k] 66 | pred = preds[k] 67 | matches = np.array([g in p for g, p in zip(gt, pred)]) 68 | 69 | count += np.sum(matches) 70 | samples += len(matches) 71 | # the last one is a mistake, a mismatch is expected 72 | correct = matches[:-1] 73 | mistake = matches[-1] 74 | 75 | tn += np.sum(correct) 76 | fp += np.sum(~correct) 77 | tp += int(not mistake) 78 | fn += int(mistake) 79 | 80 | acc = (tp + tn) / (tp + tn + fp + fn) 81 | precision = tp / (tp + fp) 82 | recall = tp / (tp + fn) 83 | f1 = 2 * precision * recall / (precision + recall) 84 | ratio = count / samples 85 | 86 | return { 87 | "tp": tp, 88 | "fp": fp, 89 | "fn": fn, 90 | "tn": tn, 91 | "accuracy": acc, 92 | "precision": precision, 93 | "recall": recall, 94 | "f1": f1, 95 | "ratio": ratio, 96 | "count": count, 97 | "samples": samples, 98 | } 99 | 100 | 101 | def get_toy(name: str) -> str: 102 | toy = name.split("-")[2].split("_")[0] 103 | return toy 104 | 105 | 106 | def load_data(path: str) -> dict: 107 | data = json.load(open(path, "r")) 108 | return data 109 | 110 | 111 | def remove_sequenceInput(prompt, toy_class): 112 | new_prompt = "" 113 | start = 0 114 | for m in re.finditer(r"Sequence type: [a-zA-Z0-9]{3,}\n", prompt): 115 | new_prompt += prompt[start : m.start()] 116 | new_prompt += f"Sequence type: {toy_class}\n" 117 | start = m.end() 118 | new_prompt += prompt[start:] 119 | return new_prompt.replace("Symbol", "Sequence") 120 | 121 | 122 | def anticipation( 123 | seq: list, 124 | prompt: str, 125 | toy: Optional[str], 126 | toy_class: Optional[str], 127 | llm, 128 | max_gen_len: Optional[int], 129 | temperature: float, 130 | top_p: float, 131 | num_samples: int, 132 | clean_prediction: bool, 133 | type_prompt="num", 134 | prompt_context="default", 135 | ): 136 | preds, gts = [], [] 137 | 138 | if type_prompt == "emoji": 139 | prompt = prompt.replace("-1", "👉") 140 | 141 | if toy_class: 142 | remove_toySequence = True 143 | prompt = remove_sequenceInput(prompt, toy_class) 144 | else: 145 | remove_toySequence = False 146 | 147 | # ! 
Don't know why it was inside the loop 148 | # >>> 149 | prompt_builder = load_data(f"{CONTEXT_PROMPT_PATH}/context_prompt.json") 150 | init = prompt_builder[prompt_context]["init"] 151 | 152 | if remove_toySequence: 153 | prompt_ = f"{prompt}{init} {toy_class}\n" 154 | else: 155 | prompt_ = f"{prompt}{init} {toy}\n" 156 | # <<< 157 | 158 | for i in range(len(seq)): 159 | if type_prompt == "emoji": 160 | hist, action = ["👉"] + seq[:i], seq[i] 161 | else: 162 | hist, action = [-1] + seq[:i], seq[i] 163 | 164 | # initialize history if empty 165 | if type_prompt == "emoji": 166 | hist = ["👉"] if len(hist) == 0 else hist 167 | else: 168 | hist = [-1] if len(hist) == 0 else hist 169 | 170 | print(f"[INFO] >>> {hist} -> {action}") 171 | 172 | input_builder = prompt_builder[prompt_context]["input"] 173 | prompt_ += f"{input_builder}\n {', '.join(map(str, hist))}\n" 174 | 175 | output_builder = prompt_builder[prompt_context]["output"] 176 | prompt_ += f"{output_builder}\n" 177 | 178 | pred = set() 179 | for _ in range(num_samples): 180 | prompts = [prompt_] * num_samples 181 | 182 | # NEW: Using our HFModel text_completion method (which calls the Hugging Face pipeline) 183 | results = llm.text_completion( 184 | prompts, 185 | max_gen_len=max_gen_len, 186 | temperature=temperature, 187 | top_p=top_p, 188 | ) 189 | for res in results: 190 | res = res["generated_text"].replace(prompt_, "") # Remove the prompt 191 | # v = re.sub(r"^[ \n\.,;:]+|[ \n\.,;:]+$", "", res) 192 | v = re.sub(r"[ \n\.,;:]+", "", res) 193 | v = v.strip("_") 194 | 195 | if type_prompt == "num": 196 | v = re.sub(r"^[^0-9]*|[^0-9]*$", "", v) 197 | try: 198 | v = int(v) 199 | except: 200 | pass 201 | 202 | if len(hist) in out_plot: 203 | out_plot[len(hist)]["sum"] += len(pred) 204 | out_plot[len(hist)]["count"] += 1 205 | else: 206 | out_plot[len(hist)] = {"sum": len(pred), "count": 1} 207 | 208 | if type_prompt == "num": 209 | pred.add(v) 210 | elif type_prompt == "emoji": 211 | try: 212 | pred.add(v[0]) 213 | except: 214 | pred.add("") 215 | else: 216 | pred.add(v[: v.find("\n")]) 217 | gts.append(action) 218 | preds.append(pred) 219 | print(f"[INFO] >>>> {action} in {pred} ---> {action in pred}") 220 | 221 | return preds, gts 222 | 223 | 224 | def main( 225 | model_name: str, 226 | max_seq_len: int = 512, 227 | max_batch_size: int = 8, 228 | max_gen_len: Optional[int] = 20, 229 | temperature: float = 0.6, 230 | top_p: float = 0.9, 231 | num_samples: int = 1, 232 | use_gt: bool = False, 233 | type_prompt: str = "num", 234 | clean_prediction: bool = False, 235 | eval_metrics: bool = True, 236 | dataset: str = "assembly", 237 | toy_class_context: bool = False, 238 | recognition_model: str = "miniROAD", 239 | prompt_context: str = "default", 240 | ): 241 | if dataset == "assembly": 242 | if toy_class_context: 243 | toy2class = json.load(open("assets/toy2class.json", "r")) 244 | contexts = load_data( 245 | f"{CONTEXT_PROMPT_PATH}/assembly_context_prompt_train.json" 246 | ) 247 | else: 248 | contexts = load_data( 249 | f"{CONTEXT_PROMPT_PATH}/supplementary/assembly_context_prompt_train_onlyToy.json" 250 | ) 251 | 252 | if recognition_model == "OadTR": 253 | seqs = load_data(f"{PREDICTIONS_PATH}/output_OadTR_Assembly101-O.json") 254 | elif recognition_model == "miniROAD": 255 | seqs = load_data(f"{PREDICTIONS_PATH}/output_miniROAD_Assembly101-O.json") 256 | 257 | if type_prompt == "alpha": 258 | # load the idx2action mapping 259 | idx2action = pickle.load(open(f"{BASE_PATH}/idx2action.pkl", "rb")) 260 | elif type_prompt == "emoji": 261 
| # load the idx2action mapping 262 | idx2emoji = json.load(open(f"{BASE_PATH}/idx2emoji.json", "r")) 263 | 264 | elif dataset == "epictent": 265 | contexts = load_data( 266 | f"{CONTEXT_PROMPT_PATH}/epictent_context_prompt_train.json" 267 | ) 268 | 269 | if recognition_model == "OadTR": 270 | seqs = load_data(f"{PREDICTIONS_PATH}/output_OadTR_Epic-Tent-O.json") 271 | elif recognition_model == "miniROAD": 272 | seqs = load_data(f"{PREDICTIONS_PATH}/output_miniROAD_Epic-tent-O.json") 273 | else: 274 | raise ValueError(f"Dataset {dataset} not supported") 275 | 276 | preds, gts = {}, {} 277 | global out_plot 278 | out_plot = {} 279 | 280 | llm = HFModel( 281 | model_name=model_name, max_seq_len=max_seq_len, max_batch_size=max_batch_size 282 | ) 283 | 284 | for i, (k, v) in enumerate(seqs.items()): 285 | if dataset == "assembly": 286 | toy = get_toy(k) 287 | print(f"[INFO] > {i}/{len(seqs)}: {toy}") 288 | 289 | if toy_class_context: 290 | toy_class = toy2class[toy] 291 | prompt = contexts[toy_class][type_prompt] 292 | else: 293 | toy_class = None 294 | try: 295 | prompt = contexts[toy][type_prompt] 296 | except: 297 | prompt = "" 298 | elif dataset == "epictent": 299 | toy = None 300 | toy_class = None 301 | prompt = contexts[type_prompt] 302 | print(f"[INFO] > {i}/{len(seqs)}") 303 | else: 304 | raise ValueError(f"Dataset {dataset} not supported") 305 | 306 | seq = v["gt"] if use_gt else v["pred"] 307 | 308 | print(f"[INFO] >> {seq}") 309 | 310 | if type_prompt == "alpha" and dataset == "assembly": 311 | seq = [idx2action[s] for s in seq] 312 | elif type_prompt == "emoji": 313 | seq = [idx2emoji[str(s)]["escape"] for s in seq] 314 | 315 | pred, gt = anticipation( 316 | seq=seq, 317 | prompt=prompt, 318 | toy=toy, 319 | toy_class=toy_class, 320 | llm=llm, 321 | max_gen_len=max_gen_len, 322 | temperature=temperature, 323 | top_p=top_p, 324 | num_samples=num_samples, 325 | clean_prediction=clean_prediction, 326 | type_prompt=type_prompt, 327 | prompt_context=prompt_context, 328 | ) 329 | 330 | preds[k] = pred 331 | gts[k] = gt 332 | 333 | matches = [int(g in p) for p, g in zip(pred, gt)] 334 | 335 | model_identifier = model_name.split("/")[-1] # Use model name for folder naming 336 | save_folder = "{}_{:d}_{}_{:d}_{:d}_{:.2f}_{}_{}".format( 337 | model_identifier, 338 | use_gt, 339 | type_prompt, 340 | int(clean_prediction), 341 | num_samples, 342 | temperature, 343 | dataset, 344 | prompt_context, 345 | ) 346 | 347 | if not os.path.exists(f"results/{save_folder}"): 348 | os.makedirs(f"results/{save_folder}") 349 | 350 | if eval_metrics: 351 | metrics = get_metrics(preds=preds, gts=gts) 352 | print(f"[INFO] {metrics}") 353 | print( 354 | "Ratio: {:.3f}\t({:d}/{:d})".format( 355 | metrics["ratio"], metrics["count"], metrics["samples"] 356 | ) 357 | ) 358 | print( 359 | "TP: {:d}, FP: {:d}, FN: {:d}, TN: {:d}".format( 360 | metrics["tp"], metrics["fp"], metrics["fn"], metrics["tn"] 361 | ) 362 | ) 363 | print( 364 | "Accuracy: {:.3f}, Precision: {:.3f}, Recall: {:.3f}, F1: {:.3f}".format( 365 | metrics["accuracy"], 366 | metrics["precision"], 367 | metrics["recall"], 368 | metrics["f1"], 369 | ) 370 | ) 371 | print(f"Average time: {sum(TIME_CNT) / len(TIME_CNT)}") 372 | 373 | pickle.dump(gts, open(f"results/{save_folder}/hf_gts.pkl", "wb")) 374 | pickle.dump(preds, open(f"results/{save_folder}/hf_preds.pkl", "wb")) 375 | pickle.dump(out_plot, open(f"results/{save_folder}/plot.pkl", "wb")) 376 | 377 | 378 | if __name__ == "__main__": 379 | fire.Fire(main) 380 | 
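For reference, below is a minimal standalone sketch of how the HFModel wrapper defined above can be exercised outside of main(). The import path and the model identifier "meta-llama/Llama-2-7b-hf" are assumptions for illustration only (any locally available Hugging Face text-generation checkpoint should behave the same), and the prompt string is a made-up example in the Sequence type / Input / Output spirit of the prompts built by anticipation().

# Minimal usage sketch (assumptions: run from the repository root so the module path
# resolves; the model id below is a placeholder, not a repository requirement).
from step_anticipation.src.models.llm_hf import HFModel

llm = HFModel(model_name="meta-llama/Llama-2-7b-hf")
prompt = "Sequence type: a01\nInput:\n -1, 12, 37\nOutput:\n"
results = llm.text_completion([prompt], max_gen_len=20, temperature=0.6, top_p=0.9)
# Each result is a dict produced by the transformers pipeline; stripping the prompt
# leaves only the newly generated next-step prediction.
print(results[0]["generated_text"].replace(prompt, ""))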
-------------------------------------------------------------------------------- /step_anticipation/src/models/llama_meta.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | import re 5 | 6 | from typing import Optional 7 | 8 | import fire 9 | import numpy as np 10 | from llama import Llama 11 | 12 | 13 | 14 | def get_metrics(preds, gts): 15 | tp, fp, fn, tn = 0, 0, 0, 0 16 | count, samples = 0, 0 17 | for k in gts.keys(): 18 | gt = gts[k] 19 | pred = preds[k] 20 | matches = np.array([g in p for g, p in zip(gt, pred)]) 21 | 22 | count += np.sum(matches) 23 | samples += len(matches) 24 | # the last one is a mistake, a mismatch is expected 25 | # all actions except the last one follow the correct procedure 26 | correct = matches[:-1] 27 | mistake = matches[-1] 28 | 29 | # tn: correct seen as correct 30 | tn += np.sum(correct) 31 | # fp: correct seen as mistake 32 | fp += np.sum(~correct) 33 | # tp: mistake seen as mistake 34 | tp += int(not mistake) 35 | # fn: mistake seen as correct 36 | fn += int(mistake) 37 | 38 | # * metrics 39 | # accuracy 40 | acc = (tp + tn) / (tp + tn + fp + fn) 41 | precision = tp / (tp + fp) 42 | recall = tp / (tp + fn) 43 | f1 = 2 * precision * recall / (precision + recall) 44 | ratio = count / samples 45 | 46 | return { 47 | "tp": tp, 48 | "fp": fp, 49 | "fn": fn, 50 | "tn": tn, 51 | "accuracy": acc, 52 | "precision": precision, 53 | "recall": recall, 54 | "f1": f1, 55 | "ratio": ratio, 56 | "count": count, 57 | "samples": samples, 58 | } 59 | 60 | 61 | def get_toy(name: str) -> str: 62 | """Get the toy name from the file name 63 | 64 | Args: 65 | name (str): file name 66 | 67 | Returns: 68 | str: toy name 69 | """ 70 | toy = name.split("-")[2].split("_")[0] 71 | return toy 72 | 73 | 74 | def load_data(path: str) -> dict: 75 | """ 76 | Load the data from the json file 77 | 78 | Args: 79 | path (str): path to the json file 80 | 81 | Returns: 82 | dict: dictionary with the data 83 | """ 84 | data = json.load(open(path, "r")) 85 | return data 86 | 87 | 88 | def remove_sequenceInput(prompt, toy_class): 89 | # if toy_class is not None, remove the reference to the single toy class, i.e.
"dumper" 90 | new_prompt = "" 91 | start = 0 92 | count = 1 93 | for m in re.finditer(r"Sequence type: [a-zA-Z0-9]{3,}\n", prompt): 94 | new_prompt += prompt[start : m.start()] 95 | new_prompt += f"Sequence type: {toy_class}\n" 96 | count += 1 97 | start = m.end() 98 | new_prompt += prompt[start:] 99 | return new_prompt.replace("Symbol", "Sequence") 100 | 101 | 102 | def anticipation( 103 | seq: list, 104 | prompt: str, 105 | toy: Optional[str], 106 | toy_class: Optional[str], 107 | llm, 108 | max_gen_len: Optional[int], 109 | temperature: float, 110 | top_p: float, 111 | num_samples: int, 112 | clean_prediction: bool, 113 | type_prompt="num", 114 | prompt_context="default", 115 | ): 116 | preds, gts = [], [] 117 | 118 | if type_prompt == "emoji": 119 | # replacing the start of the sequence "-1" with an emoji 120 | prompt = prompt.replace("-1", "👉") 121 | 122 | if toy_class: 123 | remove_toySequence = True 124 | prompt = remove_sequenceInput(prompt, toy_class) 125 | else: 126 | remove_toySequence = False 127 | 128 | # iterate over the sequence 129 | for i in range(len(seq)): 130 | prompt_builder = load_data("data/context_prompt/context_prompt.json") 131 | init = prompt_builder[prompt_context]["init"] 132 | 133 | if remove_toySequence: 134 | prompt_ = f"{prompt}{init} {toy_class}\n" 135 | else: 136 | prompt_ = f"{prompt}{init} {toy}\n" 137 | 138 | if type_prompt == "emoji": 139 | hist, action = ["👉"] + seq[:i], seq[i] 140 | else: 141 | hist, action = [-1] + seq[:i], seq[i] 142 | 143 | # hist, action = seq[-2:-1], seq[-1] # !LS 144 | 145 | # initialize the history with a starting num or emoji 146 | if type_prompt == "emoji": 147 | hist = ["👉"] if len(hist) == 0 else hist 148 | else: 149 | hist = [-1] if len(hist) == 0 else hist 150 | 151 | print(f"[INFO] >>> {hist} -> {action}") 152 | 153 | # Add the history 154 | input_builder = prompt_builder[prompt_context]["input"] 155 | prompt_ += f"{input_builder}\n {', '.join(map(str,hist))}\n" 156 | 157 | # Add the action 158 | output_builder = prompt_builder[prompt_context]["output"] 159 | prompt_ += f"{output_builder}\n" 160 | 161 | # LLM 162 | pred = set() 163 | for sample in range(num_samples): 164 | 165 | # if needed multiple predictions for the same prompt 166 | prompts = [prompt_] * num_samples 167 | 168 | # predict the next symbol with LLAMA2 169 | results = llm.text_completion( 170 | prompts, 171 | max_gen_len=max_gen_len, 172 | temperature=temperature, 173 | top_p=top_p, 174 | ) 175 | # pdb.set_trace() 176 | for res in results: 177 | # replace the following patterns at the beginning and end of the string: 178 | # - whitespaces 179 | # - newlines 180 | # - puntuaction 181 | v = re.sub(r"^[ \n\.,;:]+|[ \n\.,;:]+$", "", res["generation"]) 182 | v = res["generation"].strip("_") 183 | 184 | if type_prompt == "num": 185 | # remove non numeric characters from left and right 186 | v = re.sub(r"^[^0-9]*|[^0-9]*$", "", v) 187 | try: 188 | v = int(v) 189 | except: 190 | pass 191 | 192 | if len(hist) in out_plot: 193 | out_plot[len(hist)]["sum"] += len(pred) 194 | out_plot[len(hist)]["count"] += 1 195 | else: 196 | out_plot[len(hist)] = {"sum": len(pred), "count": 1} 197 | 198 | if type_prompt == "num": 199 | pred.add(v) 200 | elif type_prompt == "emoji": 201 | try: 202 | pred.add(v[0]) 203 | except: 204 | pred.add("") 205 | else: 206 | pred.add(v[: v.find("\n")]) 207 | # for p in re.findall(r'\d+', v): 208 | # pred.add(int(p)) 209 | 210 | # for p in re.findall(r"[^\w\s,]", v): 211 | # pred.add(p) 212 | 213 | gts.append(action) 214 | preds.append(pred) 
215 | print(f"[INFO] >>>> {action} in {pred} ---> {action in pred}") 216 | 217 | return preds, gts 218 | 219 | 220 | def main( 221 | ckpt_dir: str, 222 | tokenizer_path: str, 223 | max_seq_len: int = 512, 224 | max_batch_size: int = 8, 225 | max_gen_len: Optional[int] = None, 226 | temperature: float = 0.6, 227 | top_p: float = 0.9, 228 | num_samples: int = 1, 229 | use_gt: bool = False, 230 | type_prompt: str = "num", 231 | clean_prediction: bool = False, 232 | eval_metrics: bool = True, 233 | dataset: str = "assembly", 234 | toy_class_context: bool = False, 235 | recognition_model: str = "miniROAD", # select which recognition model to use. ["OadTR", "miniROAD"] 236 | prompt_context: str = "default", # select which context prompt to use. ["default", "unreferenced", "elaborate", "no-context"] 237 | ): 238 | 239 | if dataset == "assembly": 240 | 241 | if toy_class_context: 242 | # load the same toy_class, i.e. "a01": "excavator", as context 243 | toy2class = json.load(open("assets/toy2class.json", "r")) 244 | contexts = load_data( 245 | "data/context_prompt/assembly_context_prompt_train.json" 246 | ) 247 | else: 248 | # load only the same toy examples as context 249 | contexts = load_data( 250 | "data/context_prompt/supplementary/assembly_context_prompt_train_onlyToy.json" 251 | ) 252 | 253 | if recognition_model == "OadTR": 254 | # using the predictions from the OadTR recognition model 255 | seqs = load_data("data/predictions/output_OadTR_Assembly101-O.json") 256 | elif recognition_model == "miniROAD": 257 | # using the predictions from the miniROAD recognition model 258 | seqs = load_data("data/predictions/output_miniROAD_Assembly101-O.json") 259 | 260 | if type_prompt == "alpha": 261 | # load the idx2action mapping 262 | idx2action = pickle.load(open("data/idx2action.pkl", "rb")) 263 | elif type_prompt == "emoji": 264 | # load the idx2action mapping 265 | idx2emoji = json.load(open("data/idx2emoji.json", "r")) 266 | 267 | elif dataset == "epictent": 268 | 269 | contexts = load_data("data/context_prompt/epictent_context_prompt_train.json") 270 | 271 | if recognition_model == "OadTR": 272 | # using the predictions from the OadTR recognition model 273 | seqs = load_data("data/predictions/output_OadTR_Epic-Tent-O.json") 274 | elif recognition_model == "miniROAD": 275 | # using the predictions from the miniROAD recognition model 276 | seqs = load_data("data/output_miniROAD_Epic-Tent-O_edo.json") 277 | 278 | if type_prompt == "emoji": 279 | idx2emoji = json.load(open("data/idx2emoji.json", "r")) 280 | 281 | else: 282 | raise ValueError(f"Dataset {dataset} not supported") 283 | 284 | # * Global variables 285 | preds, gts = {}, {} 286 | global out_plot 287 | out_plot = {} 288 | 289 | llm = Llama.build( 290 | ckpt_dir=ckpt_dir, 291 | tokenizer_path=tokenizer_path, 292 | max_seq_len=max_seq_len, 293 | max_batch_size=max_batch_size, 294 | ) 295 | 296 | # * Split size 297 | # Due to memory constraints we cannot have a global context, 298 | # so we split it in `context_splits` parts. 
299 | for i, (k, v) in enumerate(seqs.items()): 300 | if dataset == "assembly": 301 | toy = get_toy(k) 302 | print(f"[INFO] > {i}/{len(seqs)}: {toy}") 303 | 304 | if toy_class_context: 305 | toy_class = toy2class[toy] 306 | prompt = contexts[toy_class][type_prompt] 307 | else: 308 | toy_class = None 309 | try: 310 | prompt = contexts[toy][type_prompt] 311 | except: 312 | prompt = "" 313 | 314 | elif dataset == "epictent": 315 | toy = None 316 | toy_class = None 317 | prompt = contexts[type_prompt] 318 | print(f"[INFO] > {i}/{len(seqs)}") 319 | else: 320 | raise ValueError(f"Dataset {dataset} not supported") 321 | 322 | seq = v["gt"] if use_gt else v["pred"] 323 | 324 | print(f"[INFO] >> {seq}") 325 | 326 | # convert action numbers to string or emoji if requested 327 | if type_prompt == "alpha" and dataset == "assembly": 328 | seq = [idx2action[s] for s in seq] 329 | elif type_prompt == "emoji": 330 | seq = [idx2emoji[str(s)]["escape"] for s in seq] 331 | 332 | pred, gt = anticipation( 333 | seq=seq, 334 | prompt=prompt, 335 | toy=toy, 336 | toy_class=toy_class, 337 | llm=llm, 338 | max_gen_len=max_gen_len, 339 | temperature=temperature, 340 | top_p=top_p, 341 | num_samples=num_samples, 342 | clean_prediction=clean_prediction, 343 | type_prompt=type_prompt, 344 | prompt_context=prompt_context, 345 | ) 346 | 347 | preds[k] = pred 348 | gts[k] = gt 349 | 350 | matches = [int(g in p) for p, g in zip(pred, gt)] 351 | 352 | # save preds and gts in pickle 353 | model = os.path.basename(ckpt_dir).split("-")[-1] 354 | save_folder = "{}_{:d}_{}_{:d}_{:d}_{:.2f}_{}_{}".format( 355 | model, 356 | use_gt, 357 | type_prompt, 358 | clean_prediction, 359 | num_samples, 360 | temperature, 361 | dataset, 362 | prompt_context, 363 | ) 364 | 365 | if not os.path.exists(f"results/{save_folder}"): 366 | os.makedirs(f"results/{save_folder}", exist_ok=True) 367 | 368 | if eval_metrics: 369 | metrics = get_metrics(preds=preds, gts=gts) 370 | print(f"[INFO] {metrics}") 371 | print( 372 | "Ratio: {:.3f}\t({:d}/{:d})".format( 373 | metrics["ratio"], metrics["count"], metrics["samples"] 374 | ) 375 | ) 376 | print( 377 | "TP: {:d}, FP: {:d}, FN: {:d}, TN: {:d}".format( 378 | metrics["tp"], metrics["fp"], metrics["fn"], metrics["tn"] 379 | ) 380 | ) 381 | print( 382 | "Accuracy: {:.3f}, Precision: {:.3f}, Recall: {:.3f}, F1: {:.3f}".format( 383 | metrics["accuracy"], 384 | metrics["precision"], 385 | metrics["recall"], 386 | metrics["f1"], 387 | ) 388 | ) 389 | pickle.dump(gts, open(f"results/{save_folder}/llama_gts.pkl", "wb")) 390 | pickle.dump(preds, open(f"results/{save_folder}/llama_preds.pkl", "wb")) 391 | pickle.dump(out_plot, open(f"results/{save_folder}/plot.pkl", "wb")) 392 | 393 | 394 | if __name__ == "__main__": 395 | fire.Fire(main) 396 | --------------------------------------------------------------------------------
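A closing note on the evaluation logic shared by llm_hf.py and llama_meta.py: each sequence is assumed to follow the correct procedure up to its final step, so the last step is the mistake, and a step the LLM fails to anticipate counts as a detected mistake only at that last position. The short, self-contained sketch below reproduces that bookkeeping on invented values (the video name, steps, and predicted sets are illustrative only, not taken from the repository data).

# Worked example of the per-video counts accumulated by get_metrics above.
# Assumption: only the last step of each ground-truth sequence is a mistake.
import numpy as np

gts = {"video_a": [3, 5, 9]}             # observed next steps; 9 is the mistake
preds = {"video_a": [{3}, {5, 6}, {2}]}  # per-step candidate sets from the LLM

matches = np.array([g in p for g, p in zip(gts["video_a"], preds["video_a"])])
tn = int(np.sum(matches[:-1]))   # correct steps that were anticipated  -> true negatives
fp = int(np.sum(~matches[:-1]))  # correct steps that were missed       -> false positives
tp = int(not matches[-1])        # the mistake was not anticipated      -> true positive
fn = int(matches[-1])            # the mistake was anticipated anyway   -> false negative
print(tp, fp, fn, tn)            # -> 1 0 0 2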