├── src ├── models │ ├── __init__.py │ ├── semimarkov │ │ ├── __init__.py │ │ ├── semimarkov_utils.py │ │ └── semimarkov.py │ ├── model.py │ ├── flow.py │ ├── framewise.py │ ├── test_semimarkov.py │ └── sequential.py ├── utils │ ├── __init__.py │ ├── utils.py │ └── logger.py ├── evaluation │ ├── __init__.py │ ├── f1.py │ └── accuracy.py ├── data │ ├── features.py │ ├── breakfast.py │ └── crosstask.py └── main.py ├── .gitignore ├── decode.sh ├── run_crosstask_no-bkg.sh ├── run_crosstask_i3d-resnet.sh ├── run_crosstask_i3d-resnet-audio.sh ├── decode_oracle.sh ├── run_crosstask_i3d-resnet-audio-narration.sh ├── run_crosstask_i3d-resnet_no-bkg.sh ├── run_crosstask_i3d-resnet-audio_no-bkg.sh ├── run_crosstask_i3d-resnet-audio-narration_no-bkg.sh ├── decode_constrained.sh ├── data └── breakfast │ └── mapping.txt ├── env.yml └── README.md /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/semimarkov/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | /.idea/ 3 | /act-recog.iml 4 | /data/ 5 | -------------------------------------------------------------------------------- /decode.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | expt_folder=$1 4 | 5 | line=$(grep "src/main.py" ${expt_folder}/log.txt | head -n1) 6 | 7 | if [[ -z $line ]] 8 | then 9 | echo "command not found in ${expt_folder}/log.txt" 10 | fi 11 | 12 | decode_line=${line/model_output_path/model_input_path} 13 | 14 | python -u $decode_line | tee ${expt_folder}/decode.out 15 | -------------------------------------------------------------------------------- /run_crosstask_no-bkg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | output_name=$1 4 | shift 5 | output_path="expts/crosstask_no-bkg/${output_name}" 6 | 7 | mkdir -p $output_path 8 | 9 | export PYTHONPATH="src/":$PYTHONPATH 10 | 11 | python -u src/main.py \ 12 | --dataset crosstask \ 13 | --model_output_path $output_path \ 14 | --remove_background \ 15 | $@ \ 16 | | tee ${output_path}/log.txt 17 | -------------------------------------------------------------------------------- /run_crosstask_i3d-resnet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | output_name=$1 4 | shift 5 | output_path="expts/crosstask_i3d-resnet/${output_name}" 6 | 7 | mkdir -p $output_path 8 | 9 | export PYTHONPATH="src/":$PYTHONPATH 10 | 11 | python -u src/main.py \ 12 | --dataset crosstask \ 13 | --crosstask_feature_groups i3d resnet \ 14 | --model_output_path $output_path \ 15 | $@ \ 16 | | tee ${output_path}/log.txt 17 | -------------------------------------------------------------------------------- /run_crosstask_i3d-resnet-audio.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | output_name=$1 4 | shift 5 | output_path="expts/crosstask_i3d-resnet-audio/${output_name}" 6 | 7 | mkdir -p $output_path 8 | 9 | export PYTHONPATH="src/":$PYTHONPATH 10 | 11 | python -u src/main.py \ 12 | --dataset crosstask \ 13 | --crosstask_feature_groups i3d resnet audio \ 14 | --model_output_path $output_path \ 15 | $@ \ 16 | | tee ${output_path}/log.txt 17 | -------------------------------------------------------------------------------- /decode_oracle.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | expt_folder=$1 4 | 5 | line=$(grep "src/main.py" ${expt_folder}/log.txt | head -n1) 6 | 7 | if [[ -z $line ]] 8 | then 9 | echo "command not found in ${expt_folder}/log.txt" 10 | exit 1; 11 | fi 12 | 13 | decode_line=${line/model_output_path/model_input_path} 14 | 15 | decode_line="$decode_line --force_optimal_assignment" 16 | 17 | python -u $decode_line | tee ${expt_folder}/decode-optimal-assignment.out 18 | -------------------------------------------------------------------------------- /run_crosstask_i3d-resnet-audio-narration.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | output_name=$1 4 | shift 5 | output_path="expts/crosstask_i3d-resnet-audio-narration/${output_name}" 6 | 7 | mkdir -p $output_path 8 | 9 | export PYTHONPATH="src/":$PYTHONPATH 10 | 11 | python -u src/main.py \ 12 | --dataset crosstask \ 13 | --crosstask_feature_groups i3d resnet audio narration \ 14 | --model_output_path $output_path \ 15 | $@ \ 16 | | tee ${output_path}/log.txt 17 | -------------------------------------------------------------------------------- /run_crosstask_i3d-resnet_no-bkg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | output_name=$1 4 | shift 5 | output_path="expts/crosstask_i3d-resnet_no-bkg/${output_name}" 6 | 7 | mkdir -p $output_path 8 | 9 | export PYTHONPATH="src/":$PYTHONPATH 10 | 11 | python -u src/main.py \ 12 | --dataset crosstask \ 13 | --crosstask_feature_groups i3d resnet \ 14 | --model_output_path $output_path \ 15 | --remove_background \ 16 | $@ \ 17 | | tee ${output_path}/log.txt 18 | -------------------------------------------------------------------------------- /run_crosstask_i3d-resnet-audio_no-bkg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | output_name=$1 4 | shift 5 | output_path="expts/crosstask_i3d-resnet-audio_no-bkg/${output_name}" 6 | 7 | mkdir -p $output_path 8 | 9 | export PYTHONPATH="src/":$PYTHONPATH 10 | 11 | python -u src/main.py \ 12 | --dataset crosstask \ 13 | --crosstask_feature_groups i3d resnet audio \ 14 | --model_output_path $output_path \ 15 | --remove_background \ 16 | $@ \ 17 | | tee ${output_path}/log.txt 18 | -------------------------------------------------------------------------------- /run_crosstask_i3d-resnet-audio-narration_no-bkg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | output_name=$1 4 | shift 5 | output_path="expts/crosstask_i3d-resnet-audio-narration_no-bkg/${output_name}" 6 | 7 | mkdir -p $output_path 8 | 9 | export PYTHONPATH="src/":$PYTHONPATH 10 | 11 | python -u src/main.py \ 12 | --dataset crosstask \ 13 | --crosstask_feature_groups i3d resnet audio narration \ 14 | --model_output_path $output_path \ 15 | 
--remove_background \ 16 | $@ \ 17 | | tee ${output_path}/log.txt 18 | -------------------------------------------------------------------------------- /decode_constrained.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | expt_folder=$1 4 | 5 | line=$(grep "src/main.py" ${expt_folder}/log.txt | head -n1) 6 | 7 | if [[ -z $line ]] 8 | then 9 | echo "command not found in ${expt_folder}/log.txt" 10 | exit 1; 11 | fi 12 | 13 | decode_line=${line/model_output_path/model_input_path} 14 | 15 | decode_line=${decode_line/--sm_constrain_with_narration train/} 16 | decode_line="$decode_line --sm_constrain_with_narration test" 17 | 18 | python -u $decode_line | tee ${expt_folder}/decode-constrain-test.out 19 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | 4 | def all_equal(xs): 5 | xs = list(xs) 6 | return all(x == xs[0] for x in xs[1:]) 7 | 8 | 9 | def nested_dict_map(nested_dict, value_map): 10 | """ 11 | Apply function value_map to each value inside a two-level nested dictionary 12 | 13 | :param nested_dict: {k1: {k2: v}} 14 | :param value_map: k1, k2, v -> v' 15 | :return: {k1: {k2: v'}} 16 | """ 17 | return { 18 | outer_key: { 19 | inner_key: value_map(outer_key, inner_key, value) 20 | for inner_key, value in inner_dict.items() 21 | } 22 | for outer_key, inner_dict in nested_dict.items() 23 | } 24 | 25 | 26 | def load_pickle(fname): 27 | with open(fname, 'rb') as f: 28 | return pickle.load(f) 29 | -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Logger parameters for the entire process. 
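Sets up a module-level 'basic' logger that streams DEBUG messages to stdout in a bare '%(message)s' format; path_logger(filename) additionally attaches a file handler so the same output is also written to filename.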
4 | """ 5 | 6 | __author__ = 'Anna Kukleva' 7 | __date__ = 'November 2018' 8 | 9 | import sys 10 | 11 | 12 | import logging 13 | 14 | logger = logging.getLogger('basic') 15 | logger.setLevel(logging.DEBUG) 16 | 17 | ch = logging.StreamHandler(sys.stdout) 18 | ch.setLevel(logging.DEBUG) 19 | 20 | # formatter = logging.Formatter('%(asctime)s - %(levelno)s - %(filename)s - ' 21 | # '%(funcName)s - %(message)s') 22 | formatter = logging.Formatter('%(message)s') 23 | ch.setFormatter(formatter) 24 | logger.addHandler(ch) 25 | 26 | def path_logger(filename): 27 | global logger 28 | fh = logging.FileHandler(filename, mode='w') 29 | fh.setLevel(logging.DEBUG) 30 | 31 | fh.setFormatter(formatter) 32 | logger.addHandler(fh) 33 | 34 | return logger 35 | -------------------------------------------------------------------------------- /data/breakfast/mapping.txt: -------------------------------------------------------------------------------- 1 | 0 SIL 2 | 1 pour_cereals 3 | 2 pour_milk 4 | 3 stir_cereals 5 | 4 take_bowl 6 | 5 pour_coffee 7 | 6 take_cup 8 | 7 spoon_sugar 9 | 8 stir_coffee 10 | 9 pour_sugar 11 | 10 pour_oil 12 | 11 crack_egg 13 | 12 add_saltnpepper 14 | 13 fry_egg 15 | 14 take_plate 16 | 15 put_egg2plate 17 | 16 take_eggs 18 | 17 butter_pan 19 | 18 take_knife 20 | 19 cut_orange 21 | 20 squeeze_orange 22 | 21 pour_juice 23 | 22 take_glass 24 | 23 take_squeezer 25 | 24 spoon_powder 26 | 25 stir_milk 27 | 26 spoon_flour 28 | 27 stir_dough 29 | 28 pour_dough2pan 30 | 29 fry_pancake 31 | 30 put_pancake2plate 32 | 31 pour_flour 33 | 32 cut_fruit 34 | 33 put_fruit2bowl 35 | 34 peel_fruit 36 | 35 stir_fruit 37 | 36 cut_bun 38 | 37 smear_butter 39 | 38 take_topping 40 | 39 put_toppingOnTop 41 | 40 put_bunTogether 42 | 41 take_butter 43 | 42 stir_egg 44 | 43 pour_egg2pan 45 | 44 stirfry_egg 46 | 45 add_teabag 47 | 46 pour_water 48 | 47 stir_tea 49 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: act-recog 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=4.5=1_gnu 8 | - blas=1.0=openblas 9 | - ca-certificates=2021.10.26=h06a4308_2 10 | - certifi=2021.5.30=py36h06a4308_0 11 | - cffi=1.14.6=py36h400218f_0 12 | - cudatoolkit=10.1.243=h6bb024c_0 13 | - editdistance=0.5.3=py36h2531618_0 14 | - intel-openmp=2022.0.1=h06a4308_3633 15 | - joblib=1.0.1=pyhd3eb1b0_0 16 | - ld_impl_linux-64=2.35.1=h7274673_9 17 | - libffi=3.3=he6710b0_2 18 | - libgcc-ng=9.3.0=h5101ec6_17 19 | - libgfortran-ng=7.5.0=ha8ba4b0_17 20 | - libgfortran4=7.5.0=ha8ba4b0_17 21 | - libgomp=9.3.0=h5101ec6_17 22 | - libopenblas=0.3.17=hf726d26_1 23 | - libstdcxx-ng=9.3.0=hd4cf53a_17 24 | - mkl=2022.0.1=h06a4308_117 25 | - ncurses=6.3=h7f8727e_2 26 | - ninja=1.10.2=h5e70eb0_2 27 | - numpy=1.17.0=py36h99e49ec_0 28 | - numpy-base=1.17.0=py36h2f8d375_0 29 | - openssl=1.1.1m=h7f8727e_0 30 | - pip=21.2.2=py36h06a4308_0 31 | - pycparser=2.21=pyhd3eb1b0_0 32 | - python=3.6.13=h12debd9_1 33 | - pytorch=1.3.1=py3.6_cuda10.1.243_cudnn7.6.3_0 34 | - readline=8.1.2=h7f8727e_1 35 | - scikit-learn=0.24.2=py36ha9443f7_0 36 | - scipy=1.5.2=py36habc2bb6_0 37 | - setuptools=58.0.4=py36h06a4308_0 38 | - sqlite=3.37.2=hc218d9a_0 39 | - threadpoolctl=2.2.0=pyh0d69192_0 40 | - tk=8.6.11=h1ccaba5_0 41 | - tqdm=4.62.3=pyhd3eb1b0_1 42 | - wheel=0.37.1=pyhd3eb1b0_0 43 | - xz=5.2.5=h7b6447c_0 44 | - zlib=1.2.11=h7f8727e_4 45 | - pip: 46 | - genbmm==0.1 47 | - 
torch-struct==0.2 48 | prefix: /private/home/dpf/.conda/envs/crosstask-pytorch1.3 49 | -------------------------------------------------------------------------------- /src/data/features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.decomposition import PCA 3 | 4 | from utils.logger import logger 5 | from utils.utils import all_equal 6 | 7 | def merge_grouped(grouped_features): 8 | # grouped_features: Dict[group_name: str, Dict[vid_name: str, np array]] 9 | merged = {} 10 | # should have the same vid_name s for each group_name 11 | assert all_equal(group_dict.keys() for group_dict in grouped_features.values()) 12 | for vid_name in next(iter(grouped_features.values())): 13 | values = [t[1][vid_name] for t in sorted(grouped_features.items(), key=lambda t: t[0])] 14 | merged[vid_name] = np.hstack(values) 15 | return merged 16 | 17 | 18 | def grouped_pca(grouped_features, n_components: int, pca_models_by_group=None): 19 | # grouped_features: Dict[group_name: str, Dict[vid_name: str, np array]] 20 | if pca_models_by_group is not None: 21 | assert set(grouped_features.keys()) == set(pca_models_by_group.keys()) 22 | else: 23 | pca_models_by_group = {} 24 | for group_name, vid_dict in grouped_features.items(): 25 | # rows should be data points, so all groups should have the same number of cols 26 | assert all_equal(v.shape[1] for v in vid_dict.values()) 27 | X_l = [] 28 | for vid, features in vid_dict.items(): 29 | X_l.append(features) 30 | X = np.vstack(X_l) 31 | pca = PCA(n_components=min(n_components, X.shape[1])) 32 | pca.fit(X) 33 | logger.debug("group {}: {} instances".format(group_name, len(X_l))) 34 | logger.debug("group {}: pca explained {} of the variance".format(group_name, pca.explained_variance_ratio_.sum())) 35 | pca_models_by_group[group_name] = pca 36 | transformed = { 37 | group_name: { 38 | vid_name: pca_models_by_group[group_name].transform(x) 39 | for vid_name, x in vid_dict.items() 40 | } 41 | for group_name, vid_dict in grouped_features.items() 42 | } 43 | return transformed, pca_models_by_group 44 | -------------------------------------------------------------------------------- /src/models/model.py: -------------------------------------------------------------------------------- 1 | import torch.optim 2 | from torch.utils.data import DataLoader 3 | 4 | from data.corpus import Datasplit 5 | 6 | 7 | def add_training_args(parser): 8 | parser.add_argument('--epochs', type=int, default=60) 9 | parser.add_argument('--batch_accumulation', type=int, default=1) 10 | parser.add_argument('--lr', type=float, default=5e-3) 11 | parser.add_argument('--workers', type=int, default=0) 12 | 13 | parser.add_argument('--max_grad_norm', type=float, default=10) 14 | 15 | parser.add_argument('--print_every', type=int, default=100) 16 | 17 | parser.add_argument('--no_reduce_plateau', action='store_true') 18 | parser.add_argument('--reduce_plateau_factor', type=float, default=0.2) 19 | parser.add_argument('--reduce_plateau_patience', type=float, default=1) 20 | parser.add_argument('--reduce_plateau_min_lr', type=float, default=1e-4) 21 | 22 | parser.add_argument('--train_limit', type=int) 23 | 24 | parser.add_argument('--dev_decode_frequency', type=int, default=1) 25 | 26 | 27 | def make_optimizer(args, parameters): 28 | opt = torch.optim.Adam(parameters, lr=args.lr) 29 | if not args.no_reduce_plateau: 30 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 31 | opt, factor=args.reduce_plateau_factor, 32 | 
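# note: min_lr is hardcoded to 1e-4 a few arguments below; the --reduce_plateau_min_lr flag defined in add_training_args appears to go unused here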
verbose=True, 33 | patience=args.reduce_plateau_patience, 34 | min_lr=1e-4, 35 | threshold=1e-5, 36 | ) 37 | else: 38 | scheduler = None 39 | return opt, scheduler 40 | 41 | 42 | def padding_colate(data_samples): 43 | data_samples = [samp for samp in data_samples if samp is not None] 44 | unpacked = { 45 | key: [samp[key] for samp in data_samples] 46 | for key in next(iter(data_samples)).keys() 47 | } 48 | 49 | lengths = [feats.size(0) for feats in unpacked['features']] 50 | # batch_size = len(lengths) 51 | # max_length = max(lengths) 52 | # lengths_t = torch.LongTensor(lengths) 53 | 54 | pad_keys = ['gt_single', 'features', 'constraints'] 55 | nopad_keys = ['task_name', 'video_name', 'task_indices', 'gt', 'gt_with_background'] 56 | data = {k: v for k, v in unpacked.items() if k in nopad_keys} 57 | data['lengths'] = torch.LongTensor(lengths) 58 | 59 | for key in pad_keys: 60 | if key in unpacked: 61 | data[key] = torch.nn.utils.rnn.pad_sequence(unpacked[key], batch_first=True, padding_value=0) 62 | 63 | return data 64 | 65 | 66 | def make_data_loader(args, datasplit: Datasplit, shuffle, batch_by_task, batch_size=1): 67 | # assert batch_size == 1, "other sizes not implemented" 68 | return DataLoader( 69 | datasplit, 70 | # batch_size=batch_size, 71 | num_workers=args.workers, 72 | # shuffle=shuffle, 73 | # drop_last=False, 74 | # collate_fn=lambda batch: batch, 75 | collate_fn=padding_colate, 76 | batch_sampler=datasplit.batch_sampler(batch_size, batch_by_task, shuffle) 77 | ) 78 | 79 | 80 | class Model(object): 81 | def fit(self, train_data: Datasplit, use_labels: bool, callback_fn=None): 82 | raise NotImplementedError() 83 | 84 | def predict(self, test_data): 85 | raise NotImplementedError() 86 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning to Segment Actions from Observation and Narration 2 | 3 | Code for the paper: 4 | [Learning to Segment Actions from Observation and Narration](https://arxiv.org/abs/2005.03684) 5 | Daniel Fried, Jean-Baptiste Alayrac, Phil Blunsom, Chris Dyer, Stephen Clark, and Aida Nematzadeh 6 | ACL, 2020 7 | 8 | ## Summary 9 | 10 | This repository provides a system for segmenting and labeling actions in a video, using a simple generative segmental (hidden semi-Markov) model of the video. This model can be used as a strong baseline for action segmentation on instructional video datasets such as [CrossTask](https://github.com/DmZhukov/CrossTask) ([Zhukov et al., CVPR 2019](https://arxiv.org/abs/1903.08225)), and can be trained fully supervised (with action labels for each frame in each video) or with weak supervision from narrative descriptions and "canonical" step orderings. Please see our paper for more details. 11 | 12 | ## Requirements 13 | 14 | * python 3.6 15 | * pytorch 1.3 16 | * sklearn 17 | * editdistance 18 | * tqdm 19 | * Particular commits of [genbmm](https://github.com/harvardnlp/genbmm) and [pytorch-struct](https://github.com/harvardnlp/pytorch-struct/). Newer versions may run out of memory on the long videos in the CrossTask dataset, due to changes to pytorch-struct that improve runtime complexity but increase memory usage. 
They can be installed via 20 | 21 | ```bash 22 | pip install -U git+https://github.com/harvardnlp/genbmm@bd42837ae0037a66803218d374c78fda72a9c9f4 23 | pip install -U git+https://github.com/harvardnlp/pytorch-struct@1c9b038a1bbece32fe8d2d46d9e3d7c09f4c08e7 24 | ``` 25 | 26 | See `env.yml` for a full list of other dependencies, which can be installed with conda. 27 | 28 | ## Setup 29 | 30 | 1. Download and unpack the CrossTask dataset of Zhukov et al.: 31 | 32 | ```bash 33 | cd data 34 | mkdir crosstask 35 | cd crosstask 36 | wget https://www.di.ens.fr/~dzhukov/crosstask/crosstask_release.zip 37 | wget https://www.di.ens.fr/~dzhukov/crosstask/crosstask_features.zip 38 | wget https://www.di.ens.fr/~dzhukov/crosstask/crosstask_constraints.zip 39 | unzip '*.zip' 40 | ``` 41 | 42 | 2. Preprocess the features with PCA. In the repository's root folder, run 43 | 44 | ```bash 45 | PYTHONPATH="src/":$PYTHONPATH python src/data/crosstask.py 46 | ``` 47 | 48 | This should generate the folder `data/crosstask/crosstask_processed/crosstask_primary_pca-200_with-bkg_by-task` 49 | 50 | ## Experiments 51 | 52 | Here are the commands to replicate key results from Table 2 in our [paper](https://arxiv.org/abs/2005.03684). Please contact Daniel Fried for others, or for any help or questions about the code. 53 | 54 | | Number | Name | Command | 55 | | ------ | ---- | ------- | 56 | | S6 | Supervised: SMM, generative | `./run_crosstask_i3d-resnet-audio.sh pca_semimarkov_sup --classifier semimarkov --training supervised --cuda` | 57 | | U7 | HSMM + Narr + Ord | `./run_crosstask_i3d-resnet-audio.sh pca_semimarkov_unsup_narration_ordering --classifier semimarkov --training unsupervised --mix_tasks --task_specific_steps --sm_constrain_transitions --annotate_background_with_previous --sm_constrain_with_narration train --sm_constrain_narration_weight=-1e4 --cuda` | 58 | 59 | ## Credits 60 | 61 | - Parts of the data loading and evaluation code are based on [this repo](https://github.com/Annusha/slim_mallow) from Anna Kukleva. 62 | - Code for invertible emission distributions are based on Junxian He's [structured flow code](https://github.com/jxhe/struct-learning-with-flow). (These didn't make it into the paper -- I wasn't able to get them to work consistently better than Gaussian emissions over the PCA features.) 63 | - Compound HSMM / VAE models are based on Yoon Kim's [Compound PCFG code](https://github.com/harvardnlp/compound-pcfg). (These also didn't make it into the paper, for the same reasons.) 
64 | -------------------------------------------------------------------------------- /src/evaluation/f1.py: -------------------------------------------------------------------------------- 1 | # modified from slim_mallow by Anna Kukleva, https://github.com/Annusha/slim_mallow 2 | 3 | import numpy as np 4 | 5 | from utils.logger import logger 6 | 7 | 8 | class F1Score: 9 | def __init__(self, K, n_videos, verbose=True): 10 | self.sampling_ratio = 15 # number of frames per segment to sample 11 | self.n_experiments = 50 12 | self._K = K # number of predicted segments per video 13 | self._n_videos = n_videos 14 | self._eps = 1e-8 15 | self._verbose = verbose 16 | 17 | # TODO: update this to allow multiple predicted segments per video 18 | 19 | self.gt = None 20 | self.gt_sampled = None 21 | self.pr = None 22 | self.pr_sampled = None 23 | self.gt2pr = None 24 | self.mask = None 25 | self.exclude = [] 26 | 27 | self.bound_masks = [] # list of masks for each segment 28 | 29 | self.f1_scores = [] 30 | self._return = {} 31 | self._n_true_seg_all = 0 32 | 33 | def set_gt(self, gt): 34 | assert isinstance(gt, list) and isinstance(gt[0], list) 35 | gt = [gt_t[0] for gt_t in gt] 36 | self.gt = np.asarray(gt) 37 | self.mask = np.zeros(self.gt.shape, dtype=bool) 38 | 39 | def set_pr(self, pr): 40 | self.pr = np.asarray(pr) 41 | 42 | def set_gt2pr(self, gt2pr): 43 | self.gt2pr = gt2pr 44 | 45 | def set_exclude(self, label): 46 | self.bound_masks = [] 47 | self.exclude.append(label) 48 | mask_exclude = self.gt != label 49 | self.gt = self.gt[mask_exclude] 50 | self.pr = self.pr[mask_exclude] 51 | self.mask = np.zeros(self.gt.shape, dtype=bool) 52 | 53 | def _finish_init(self): 54 | if self.gt is not None and \ 55 | self.pr is not None and \ 56 | self.gt2pr is not None: 57 | self._pr2gt_convert() 58 | self._set_boundaries() 59 | 60 | def _pr2gt_convert(self): 61 | new_pr = np.asarray(self.pr).copy() 62 | for gt_label, pr_label in self.gt2pr.items(): 63 | if len(pr_label) == 0: 64 | continue 65 | m = np.sum(self.pr == pr_label[0]) 66 | new_pr[self.pr == pr_label[0]] = gt_label 67 | self.pr = np.asarray(new_pr).copy() 68 | 69 | def _set_boundaries(self): 70 | """Define boundaries for each segment from where sample.""" 71 | cur_label = self.gt[0] 72 | mask = np.zeros(self.gt.shape, dtype=bool) 73 | for label_idx, label in enumerate(self.gt): 74 | if label == cur_label: 75 | mask[label_idx] = True 76 | else: 77 | self.bound_masks.append(mask) 78 | mask = np.zeros(self.gt.shape, dtype=bool) 79 | mask[label_idx] = True 80 | cur_label = label 81 | 82 | def _sampling(self): 83 | """Define mask for frames for which measure a score. 
And label if the segment defined correctly.""" 84 | n_correct_segments = 0 85 | for mask in self.bound_masks: 86 | where = np.where(mask)[0] 87 | low = np.min(where) 88 | high = np.max(where) 89 | sampled_idxs = np.random.random_integers(low, high, self.sampling_ratio) 90 | n_corr_frames = np.sum(self.gt[sampled_idxs] == self.pr[sampled_idxs]) 91 | n_correct_segments += n_corr_frames / self.sampling_ratio 92 | # if n_corr_frames > self.sampling_ratio / 2: 93 | # n_correct_segments += 1 94 | 95 | precision = n_correct_segments / (self._K * self._n_videos) 96 | recall = n_correct_segments / len(self.bound_masks) 97 | f1 = 2 * (precision * recall) / (precision + recall + self._eps) 98 | self.f1_scores.append(f1) 99 | 100 | self._n_true_seg_all += n_correct_segments 101 | 102 | self._return['precision'] = [n_correct_segments, (self._K * self._n_videos)] 103 | self._return['recall'] = [n_correct_segments, len(self.bound_masks)] 104 | 105 | def f1(self): 106 | self._finish_init() 107 | for iteration in range(self.n_experiments): 108 | self._sampling() 109 | f1_mean = np.mean(self.f1_scores) 110 | # TODO: fix f1 computation and output it 111 | # if self._verbose: 112 | # logger.debug('f1 score: %f' % f1_mean) 113 | self._n_true_seg_all /= self.n_experiments 114 | self._return['precision'] = [self._n_true_seg_all, (self._K * self._n_videos)] 115 | self._return['recall'] = [self._n_true_seg_all, len(self.bound_masks)] 116 | self._return['mean_f1'] = [f1_mean, 1] 117 | 118 | 119 | def stat(self): 120 | return self._return 121 | -------------------------------------------------------------------------------- /src/models/flow.py: -------------------------------------------------------------------------------- 1 | # code from Junxian He, https://github.com/jxhe/struct-learning-with-flow/blob/master/modules/projection.py 2 | from __future__ import print_function 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ReLUNet(nn.Module): 11 | @classmethod 12 | def add_args(cls, parser): 13 | parser.add_argument('--flow_hidden_layers', type=int, default=1) 14 | parser.add_argument('--flow_hidden_units', type=int, default=100) 15 | 16 | def __init__(self, args, in_features, out_features): 17 | super(ReLUNet, self).__init__() 18 | 19 | self.args = args 20 | 21 | self.in_layer = nn.Linear(in_features, self.args.flow_hidden_units, bias=True) 22 | self.out_layer = nn.Linear(self.args.flow_hidden_units, out_features, bias=True) 23 | for i in range(self.args.flow_hidden_layers): 24 | name = 'cell{}'.format(i) 25 | cell = nn.Linear(self.args.flow_hidden_units, self.args.flow_hidden_units, bias=True) 26 | setattr(self, name, cell) 27 | 28 | def reset_parameters(self): 29 | self.in_layer.reset_parameters() 30 | self.out_layer.reset_parameters() 31 | for i in range(self.args.flow_hidden_layers): 32 | name = 'cell{}'.format(i) 33 | getattr(self, name).reset_parameters() 34 | 35 | def init_identity(self): 36 | self.in_layer.weight.data.zero_() 37 | self.in_layer.bias.data.zero_() 38 | self.out_layer.weight.data.zero_() 39 | self.out_layer.bias.data.zero_() 40 | for i in range(self.args.flow_hidden_layers): 41 | name = 'cell{}'.format(i) 42 | getattr(self, name).weight.data.zero_() 43 | getattr(self, name).bias.data.zero_() 44 | 45 | def forward(self, input): 46 | """ 47 | input: (batch_size, seq_length, in_features) 48 | output: (batch_size, seq_length, out_features) 49 | """ 50 | h = self.in_layer(input) 51 | h = F.relu(h) 52 | for i in 
range(self.args.flow_hidden_layers): 53 | name = 'cell{}'.format(i) 54 | h = getattr(self, name)(h) 55 | h = F.relu(h) 56 | return self.out_layer(h) 57 | 58 | 59 | class NICETrans(nn.Module): 60 | @classmethod 61 | def add_args(cls, parser): 62 | ReLUNet.add_args(parser) 63 | parser.add_argument('--flow_couple_layers', type=int, default=4) 64 | parser.add_argument('--flow_scale', action='store_true') 65 | parser.add_argument('--flow_scale_no_zero', action='store_true') 66 | 67 | def __init__(self, 68 | args, 69 | features): 70 | super(NICETrans, self).__init__() 71 | 72 | self.args = args 73 | 74 | for i in range(self.args.flow_couple_layers): 75 | name = 'cell{}'.format(i) 76 | cell = ReLUNet(args, features//2, features//2) 77 | setattr(self, name, cell) 78 | if args.flow_scale: 79 | name = 'scale_cell{}'.format(i) 80 | cell = ReLUNet(args, features//2, features//2) 81 | if not args.flow_scale_no_zero: 82 | cell.init_identity() 83 | setattr(self, name, cell) 84 | 85 | def reset_parameters(self): 86 | for i in range(self.args.flow_couple_layers): 87 | name = 'cell{}'.format(i) 88 | getattr(self, name).reset_parameters() 89 | if self.args.flow_scale: 90 | name = 'scale_cell{}'.format(i) 91 | getattr(self, name).reset_parameters() 92 | 93 | 94 | def forward(self, input): 95 | """ 96 | input: (batch_size, seq_length, features) 97 | h: (batch_size, seq_length, features) 98 | """ 99 | 100 | # For NICE it is a constant 101 | jacobian_loss = torch.zeros(input.size(0), device=input.device, requires_grad=False) 102 | 103 | ep_size = input.size() 104 | features = ep_size[-1] 105 | # h = odd_input 106 | h = input 107 | for i in range(self.args.flow_couple_layers): 108 | name = 'cell{}'.format(i) 109 | h1, h2 = torch.split(h, features//2, dim=-1) 110 | if i%2 == 1: 111 | h1, h2 = h2, h1 112 | t = getattr(self, name)(h1) 113 | if self.args.flow_scale: 114 | s = getattr(self, 'scale_cell{}'.format(i))(h1) 115 | jacobian_loss += s.sum(dim=-1).sum(dim=-1) 116 | h2_p = torch.exp(s) * h2 + t 117 | else: 118 | h2_p = h2 + t 119 | if i%2 == 1: 120 | h1, h2_p = h2_p, h1 121 | h = torch.cat((h1, h2_p), dim=-1) 122 | # if i%2 == 0: 123 | # h = torch.cat((h1, h2 + getattr(self, name)(h1)), dim=-1) 124 | # else: 125 | # h = torch.cat((h1 + getattr(self, name)(h2), h2), dim=-1) 126 | return h, jacobian_loss 127 | -------------------------------------------------------------------------------- /src/models/semimarkov/semimarkov_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn.mixture import GaussianMixture 4 | 5 | 6 | def labels_to_spans(position_labels, max_k): 7 | # position_labels: b x N, LongTensor 8 | assert not (position_labels == -1).any(), "position_labels already appear span encoded (have -1)" 9 | b, N = position_labels.size() 10 | last = position_labels[:, 0] 11 | values = [last.unsqueeze(1)] 12 | lengths = torch.ones_like(last) 13 | for n in range(1, N): 14 | this = position_labels[:, n] 15 | same_symbol = (last == this) 16 | if max_k is not None: 17 | same_symbol = same_symbol & (lengths < max_k - 1) 18 | encoded = torch.where(same_symbol, torch.full([1], -1, device=same_symbol.device, dtype=torch.long), this) 19 | lengths = torch.where(same_symbol, lengths, torch.full([1], 0, device=same_symbol.device, dtype=torch.long)) 20 | lengths += 1 21 | values.append(encoded.unsqueeze(1)) 22 | last = this 23 | return torch.cat(values, dim=1) 24 | 25 | 26 | def rle_spans(spans, lengths): 27 | b, T = spans.size() 28 | 
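# run-length encode each row of the span tensor: for every sequence in the batch, walk the first lengths[i] positions and emit (symbol, run_length) pairs, treating -1 entries as continuations of the preceding symbol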
all_rle = [] 29 | for i in range(b): 30 | this_rle = [] 31 | this_spans = spans[i, :lengths[i]] 32 | current_symbol = None 33 | count = 0 34 | for symbol in this_spans: 35 | symbol = symbol.item() 36 | if current_symbol is None or symbol != -1: 37 | if current_symbol is not None: 38 | assert count > 0 39 | this_rle.append((current_symbol, count)) 40 | count = 0 41 | current_symbol = symbol 42 | count += 1 43 | if current_symbol is not None: 44 | assert count > 0 45 | this_rle.append((current_symbol, count)) 46 | assert sum(count for sym, count in this_rle) == lengths[i] 47 | all_rle.append(this_rle) 48 | return all_rle 49 | 50 | 51 | def spans_to_labels(spans): 52 | # spans: b x N, LongTensor 53 | # contains 0.. for the start of a span (B-*), and -1 for its continuation (I-*) 54 | b, N = spans.size() 55 | current_labels = spans[:, 0] 56 | assert (current_labels != -1).all() 57 | values = [current_labels.unsqueeze(1)] 58 | for n in range(1, N): 59 | this = spans[:, n] 60 | this_labels = torch.where(this == -1, current_labels, this) 61 | values.append(this_labels.unsqueeze(1)) 62 | current_labels = this_labels 63 | return torch.cat(values, dim=1) 64 | 65 | 66 | def get_diagonal_covariances(data): 67 | # data: num_points x feat_dim 68 | model = GaussianMixture(n_components=1, covariance_type='diag') 69 | responsibilities = np.ones((data.shape[0], 1)) 70 | model._initialize(data, responsibilities) 71 | return model.covariances_, model.precisions_cholesky_ 72 | 73 | 74 | def semimarkov_sufficient_stats(feature_list, label_list, covariance_type, n_classes, max_k=None): 75 | assert len(feature_list) == len(label_list) 76 | tied_diag = covariance_type == 'tied_diag' 77 | if tied_diag: 78 | emissions = GaussianMixture(n_classes, covariance_type='diag') 79 | else: 80 | emissions = GaussianMixture(n_classes, covariance_type=covariance_type) 81 | X_l = [] 82 | r_l = [] 83 | 84 | span_counts = np.zeros(n_classes, dtype=np.float32) 85 | span_lengths = np.zeros(n_classes, dtype=np.float32) 86 | span_start_counts = np.zeros(n_classes, dtype=np.float32) 87 | # to, from 88 | span_transition_counts = np.zeros((n_classes, n_classes), dtype=np.float32) 89 | 90 | instance_count = 0 91 | 92 | # for i in tqdm.tqdm(list(range(len(train_data))), ncols=80): 93 | for X, labels in zip(feature_list, label_list): 94 | X_l.append(X) 95 | r = np.zeros((X.shape[0], n_classes)) 96 | r[np.arange(X.shape[0]), labels] = 1 97 | assert r.sum() == X.shape[0] 98 | r_l.append(r) 99 | spans = labels_to_spans(labels.unsqueeze(0), max_k) 100 | # symbol, length 101 | spans = rle_spans(spans, torch.LongTensor([spans.size(1)]))[0] 102 | last_symbol = None 103 | for index, (symbol, length) in enumerate(spans): 104 | if index == 0: 105 | span_start_counts[symbol] += 1 106 | span_counts[symbol] += 1 107 | span_lengths[symbol] += length 108 | if last_symbol is not None: 109 | span_transition_counts[symbol, last_symbol] += 1 110 | last_symbol = symbol 111 | instance_count += 1 112 | 113 | X_arr = np.vstack(X_l) 114 | r_arr = np.vstack(r_l) 115 | emissions._initialize(X_arr, r_arr) 116 | if tied_diag: 117 | cov, prec_chol = get_diagonal_covariances(X_arr) 118 | emissions.covariances_[:] = np.copy(cov) 119 | emissions.precisions_cholesky_[:] = np.copy(prec_chol) 120 | return emissions, { 121 | 'span_counts': span_counts, 122 | 'span_lengths': span_lengths, 123 | 'span_start_counts': span_start_counts, 124 | 'span_transition_counts': span_transition_counts, 125 | 'instance_count': instance_count, 126 | } 127 | 
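# Minimal usage sketch of the span-encoding helpers defined above (the example data is hypothetical):
# labels_to_spans keeps each segment's label at the frame where the segment starts and writes -1 for
# continuation frames, rle_spans converts that encoding into per-sequence (symbol, run_length) lists,
# and spans_to_labels inverts the encoding. Guarded so it only runs when this module is executed directly.
if __name__ == "__main__":
    example_labels = torch.LongTensor([[0, 0, 1, 1, 1, 2]])  # one sequence of 6 frames
    example_spans = labels_to_spans(example_labels, max_k=10)
    assert example_spans.tolist() == [[0, -1, 1, -1, -1, 2]]
    assert rle_spans(example_spans, torch.LongTensor([6])) == [[(0, 2), (1, 3), (2, 1)]]
    assert (spans_to_labels(example_spans) == example_labels).all()
    print("span encoding round-trip example passed")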
-------------------------------------------------------------------------------- /src/models/framewise.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from models.model import Model, make_optimizer, make_data_loader 6 | from utils.utils import all_equal 7 | 8 | from models.semimarkov.semimarkov_utils import semimarkov_sufficient_stats 9 | 10 | from collections import Counter 11 | 12 | from data.corpus import Datasplit 13 | 14 | 15 | class FeedForward(nn.Module): 16 | @classmethod 17 | def add_args(cls, parser): 18 | parser.add_argument('--ff_dropout_p', type=float, default=0.1) 19 | parser.add_argument('--ff_hidden_layers', type=int, default=0) 20 | parser.add_argument('--ff_hidden_dim', type=int, default=200) 21 | 22 | def __init__(self, args, input_dim, output_dim): 23 | super(FeedForward, self).__init__() 24 | self.args = args 25 | layers = [nn.Dropout(p=args.ff_dropout_p)] 26 | layers.append(nn.Linear(input_dim,output_dim if args.ff_hidden_layers == 0 else args.ff_hidden_dim)) 27 | if args.ff_hidden_layers > 0: 28 | for l_ix in range(args.ff_hidden_layers): 29 | # TODO: consider adding dropout in here 30 | layers.append(nn.ReLU()) 31 | layers.append(nn.Linear(args.ff_hidden_dim, args.ff_hidden_dim if l_ix < args.ff_hidden_layers - 1 else output_dim)) 32 | self.layers = nn.Sequential(*layers) 33 | 34 | def forward(self, x, valid_classes_per_instance=None): 35 | batch_size = x.size(0) 36 | logits = self.layers(x) 37 | if valid_classes_per_instance is not None: 38 | assert all_equal(set(vc.detach().cpu().numpy()) for vc in 39 | valid_classes_per_instance), "must have same valid_classes for all instances in the batch" 40 | valid_classes = valid_classes_per_instance[0] 41 | mask = torch.full_like(logits, -float("inf")) 42 | mask[:,valid_classes] = 0 43 | logits = logits + mask 44 | return logits 45 | 46 | class FramewiseBaseline(Model): 47 | @classmethod 48 | def add_args(cls, parser): 49 | parser.add_argument("--framewise_baseline_type", choices=['majority_class', 'sample_class_distribution']) 50 | 51 | @classmethod 52 | def from_args(cls, args, train_data: Datasplit): 53 | return FramewiseBaseline(args, train_data) 54 | 55 | def __init__(self, args, train_data: Datasplit): 56 | self.args = args 57 | self.n_classes = train_data._corpus.n_classes 58 | self.class_histograms_by_task = {} 59 | 60 | def fit(self, train_data: Datasplit, use_labels: bool, callback_fn=None): 61 | assert use_labels 62 | loader = make_data_loader(self.args, train_data, batch_by_task=False, shuffle=True, batch_size=1) 63 | 64 | for batch in tqdm.tqdm(loader, ncols=80): 65 | tasks = batch['task_name'] 66 | assert len(tasks) == 1 67 | task = next(iter(tasks)) 68 | task_indices = next(iter(batch['task_indices'])) 69 | gt_single = batch['gt_single'].squeeze(0) 70 | assert all(ix in task_indices for ix in set(gt_single)) 71 | if task not in self.class_histograms_by_task: 72 | self.class_histograms_by_task[task] = Counter() 73 | 74 | self.class_histograms_by_task[task].update(gt_single.numpy()) 75 | 76 | def predict(self, test_data: Datasplit): 77 | predictions = {} 78 | loader = make_data_loader(self.args, test_data, batch_by_task=False, shuffle=False, batch_size=1) 79 | 80 | probs_by_task = {} 81 | classes_by_task = {} 82 | for task, task_distr in self.class_histograms_by_task.items(): 83 | classes, counts = zip(*task_distr.most_common()) 84 | classes_by_task[task] = classes 85 | probs_by_task[task] 
= np.array(counts, dtype=np.float) / sum(counts) 86 | 87 | for batch in tqdm.tqdm(loader, ncols=80): 88 | features = batch['features'].squeeze(0) 89 | num_timesteps = features.size(0) 90 | 91 | tasks = batch['task_name'] 92 | assert len(tasks) == 1 93 | task = next(iter(tasks)) 94 | task_indices = next(iter(batch['task_indices'])) 95 | videos = batch['video_name'] 96 | assert len(videos) == 1 97 | video = next(iter(videos)) 98 | 99 | task_distr = self.class_histograms_by_task[task] 100 | 101 | if self.args.framewise_baseline_type == 'majority_class': 102 | class_pred, _ = task_distr.most_common()[0] 103 | preds = np.full(num_timesteps, class_pred, dtype=np.long) 104 | else: 105 | assert self.args.framewise_baseline_type == 'sample_class_distribution' 106 | probs = probs_by_task[task] 107 | classes = classes_by_task[task] 108 | pred_indices = np.random.multinomial(1, probs, size=num_timesteps).argmax(axis=1) 109 | preds = np.array([classes[ix] for ix in pred_indices]) 110 | assert all(ix in task_indices for ix in set(preds)) 111 | predictions[video] = preds 112 | return predictions 113 | 114 | class FramewiseDiscriminative(Model): 115 | @classmethod 116 | def add_args(cls, parser): 117 | FeedForward.add_args(parser) 118 | 119 | @classmethod 120 | def from_args(cls, args, train_data: Datasplit): 121 | return FramewiseDiscriminative(args, train_data) 122 | 123 | def __init__(self, args, train_data: Datasplit): 124 | self.args = args 125 | #self.n_classes = sum(len(indices) for indices in train_data.groundtruth.indices_by_task.values()) 126 | self.n_classes = train_data._corpus.n_classes 127 | self.model = FeedForward(args, 128 | input_dim=train_data.feature_dim, 129 | output_dim=self.n_classes) 130 | if args.cuda: 131 | self.model.cuda() 132 | 133 | def fit(self, train_data: Datasplit, use_labels: bool, callback_fn=None): 134 | assert use_labels 135 | loss = nn.CrossEntropyLoss() 136 | optimizer, scheduler = make_optimizer(self.args, self.model.parameters()) 137 | loader = make_data_loader(self.args, train_data, batch_by_task=False, shuffle=True, batch_size=1) 138 | 139 | for epoch in range(self.args.epochs): 140 | # call here since we may set eval in callback_fn 141 | self.model.train() 142 | losses = [] 143 | for batch in tqdm.tqdm(loader, ncols=80): 144 | # for batch in loader: 145 | tasks = batch['task_name'] 146 | videos = batch['video_name'] 147 | features = batch['features'].squeeze(0) 148 | gt_single = batch['gt_single'].squeeze(0) 149 | task_indices = batch['task_indices'] 150 | if self.args.cuda: 151 | features = features.cuda() 152 | task_indices = [indx.cuda() for indx in task_indices] 153 | gt_single = gt_single.cuda() 154 | logits = self.model.forward(features, valid_classes_per_instance=task_indices) 155 | 156 | this_loss = loss(logits, gt_single) 157 | losses.append(this_loss.item()) 158 | this_loss.backward() 159 | 160 | optimizer.step() 161 | self.model.zero_grad() 162 | train_loss = np.mean(losses) 163 | callback_fn(epoch, {'train_loss': train_loss}) 164 | if scheduler is not None: 165 | scheduler.step(train_loss) 166 | # if evaluate_on_data_fn is not None: 167 | # train_mof = evaluate_on_data_fn(self, train_data, 'train') 168 | # dev_mof = evaluate_on_data_fn(self, dev_data, 'dev') 169 | # dev_mof_by_epoch[epoch] = dev_mof 170 | # log_str += ("\ttrain mof: {:.4f}".format(train_mof)) 171 | # log_str += ("\tdev mof: {:.4f}".format(dev_mof)) 172 | 173 | def predict(self, test_data: Datasplit): 174 | self.model.eval() 175 | predictions = {} 176 | loader = 
make_data_loader(self.args, test_data, batch_by_task=False, shuffle=False, batch_size=1) 177 | for batch in loader: 178 | features = batch['features'].squeeze(0) 179 | task_indices = batch['task_indices'] 180 | if self.args.cuda: 181 | features = features.cuda() 182 | task_indices = [indx.cuda() for indx in task_indices] 183 | videos = batch['video_name'] 184 | assert all_equal(videos) 185 | video = next(iter(videos)) 186 | logits = self.model.forward(features, valid_classes_per_instance=task_indices) 187 | preds = logits.max(dim=1)[1] 188 | # handle the edge case where there's only a single instance, in which case preds.size() <= 1 189 | if len(preds.size()) > 1: 190 | preds = preds.squeeze(-1) 191 | predictions[video] = preds.detach().cpu().numpy() 192 | return predictions 193 | 194 | 195 | class FramewiseGaussianMixture(Model): 196 | @classmethod 197 | def add_args(cls, parser): 198 | parser.add_argument('--gm_covariance', choices=['full', 'diag', 'tied', 'tied_diag'], default='tied_diag') 199 | 200 | @classmethod 201 | def from_args(cls, args, train_data): 202 | n_classes = train_data._corpus.n_classes 203 | feature_dim = train_data.feature_dim 204 | return FramewiseGaussianMixture(args, n_classes, feature_dim) 205 | 206 | def __init__(self, args, n_classes, feature_dim): 207 | self.args = args 208 | self.n_classes = n_classes 209 | self.feature_dim = feature_dim 210 | self.model = None 211 | 212 | def fit(self, train_data: Datasplit, use_labels: bool, callback_fn=None): 213 | loader = make_data_loader(self.args, train_data, batch_by_task=False, shuffle=False, batch_size=1) 214 | feature_list, label_list = [], [] 215 | for batch in loader: 216 | feature_list.append(batch['features'].squeeze(0)) 217 | label_list.append(batch['gt_single'].squeeze(0)) 218 | gmm, stats = semimarkov_sufficient_stats(feature_list, label_list, 219 | self.args.gm_covariance, 220 | self.n_classes, max_k=100) 221 | self.model = gmm 222 | 223 | def predict(self, test_data): 224 | assert self.model is not None 225 | predictions = {} 226 | # for i in tqdm.tqdm(list(range(len(test_data))), ncols=80): 227 | for i in range(len(test_data)): 228 | sample = test_data._get_by_index(i, wrap_torch=False) 229 | X = sample['features'] 230 | mask_indices = list(set(range(self.n_classes)) - set(sample['task_indices'])) 231 | if mask_indices: 232 | probs = self.model.predict_proba(X) 233 | probs[:, mask_indices] = 0 234 | probs /= probs.sum(axis=1)[:,None] 235 | preds = probs.argmax(axis=1) 236 | else: 237 | preds = self.model.predict(X) 238 | predictions[sample['video_name']] = preds 239 | return predictions 240 | -------------------------------------------------------------------------------- /src/models/test_semimarkov.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from scipy.optimize import linear_sum_assignment 6 | from torch.utils.data import Dataset, DataLoader 7 | from torch_struct import SemiMarkov, MaxSemiring 8 | 9 | from models.semimarkov.semimarkov_modules import SemiMarkovModule 10 | 11 | # device = torch.device("cuda") 12 | device = torch.device("cpu") 13 | 14 | sm_max = SemiMarkov(MaxSemiring) 15 | 16 | BIG_NEG = -1e9 17 | 18 | 19 | class ToyDataset(Dataset): 20 | def __init__(self, labels, features, lengths, valid_classes, max_k): 21 | self.labels = labels 22 | self.features = features 23 | self.lengths = lengths 24 | self.valid_classes = valid_classes 25 | self.max_k = max_k 26 | 27 | def __len__(self): 28 | 
return self.labels.size(0) 29 | 30 | def __getitem__(self, index): 31 | labels = self.labels[index] 32 | spans = SemiMarkovModule.labels_to_spans(labels.unsqueeze(0), max_k=self.max_k).squeeze(0) 33 | return { 34 | 'labels': self.labels[index], 35 | 'features': self.features[index], 36 | 'lengths': self.lengths[index], 37 | 'valid_classes': self.valid_classes[index], 38 | 'spans': spans, 39 | } 40 | 41 | 42 | def synthetic_data(num_data_points=200, C=3, N=100, K=5, num_classes_per_instance=None): 43 | def make_synthetic_features(class_labels, shift_constant=1.0): 44 | _batch_size, _N = class_labels.size() 45 | f = torch.randn((_batch_size, _N, C)) 46 | shift = torch.zeros_like(f) 47 | shift.scatter_(2, class_labels.unsqueeze(2), shift_constant) 48 | return shift + f 49 | 50 | labels_l = [] 51 | lengths = [] 52 | valid_classes = [] 53 | for i in range(num_data_points): 54 | if i == 0: 55 | length = N 56 | else: 57 | length = random.randint(K, N) 58 | lengths.append(length) 59 | lab = [] 60 | current_step = 0 61 | if num_classes_per_instance is not None: 62 | assert num_classes_per_instance <= C 63 | this_valid_classes = np.random.choice(list(range(C)), size=num_classes_per_instance, replace=False) 64 | else: 65 | this_valid_classes = list(range(C)) 66 | valid_classes.append(this_valid_classes) 67 | while len(lab) < N: 68 | step_length = random.randint(1, K - 1) 69 | this_label = this_valid_classes[current_step % len(this_valid_classes)] 70 | lab.extend([this_label] * step_length) 71 | current_step += 1 72 | lab = lab[:N] 73 | labels_l.append(lab) 74 | labels = torch.LongTensor(labels_l) 75 | features = make_synthetic_features(labels) 76 | lengths = torch.LongTensor(lengths) 77 | valid_classes = [torch.LongTensor(tvc) for tvc in valid_classes] 78 | 79 | return labels, features, lengths, valid_classes 80 | 81 | 82 | def partition_rows(arr, N): 83 | if isinstance(arr, list): 84 | assert N < len(list) 85 | else: 86 | assert N < arr.size(0) 87 | return arr[:N], arr[N:] 88 | 89 | 90 | def test_learn_synthetic(): 91 | C = 3 92 | MAX_K = 20 93 | K = 5 94 | N = 20 95 | N_train = 150 96 | N_test = 50 97 | 98 | closed_form_supervised = True 99 | 100 | supervised = True 101 | 102 | allow_self_transitions = True 103 | 104 | num_classes_per_instance = None 105 | 106 | epochs = 20 107 | 108 | batch_size = 10 109 | 110 | train_data = ToyDataset( 111 | *synthetic_data(num_data_points=N_train, C=C, N=N, K=K, num_classes_per_instance=num_classes_per_instance), 112 | max_k=MAX_K 113 | ) 114 | train_loader = DataLoader(train_data, batch_size=batch_size) 115 | test_data = ToyDataset( 116 | *synthetic_data(num_data_points=N_test, C=C, N=N, K=K, num_classes_per_instance=num_classes_per_instance), 117 | max_k=MAX_K 118 | ) 119 | test_loader = DataLoader(test_data, batch_size=batch_size) 120 | 121 | model = SemiMarkovModule(C, C, max_k=MAX_K, allow_self_transitions=allow_self_transitions) 122 | model.initialize_gaussian(train_data.features, train_data.lengths) 123 | 124 | model.train() 125 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-1) 126 | 127 | if supervised and closed_form_supervised: 128 | train_features = [] 129 | train_labels = [] 130 | for i in range(len(train_data)): 131 | sample = train_data[i] 132 | train_features.append(sample['features']) 133 | train_labels.append(sample['labels']) 134 | model.fit_supervised(train_features, train_labels) 135 | else: 136 | for epoch in range(epochs): 137 | losses = [] 138 | for batch in train_loader: 139 | # if self.args.cuda: 140 | # features = 
features.cuda() 141 | # task_indices = task_indices.cuda() 142 | # gt_single = gt_single.cuda() 143 | features = batch['features'] 144 | lengths = batch['lengths'] 145 | spans = batch['spans'] 146 | valid_classes = batch['valid_classes'] 147 | this_N = lengths.max().item() 148 | features = features[:, :this_N, :] 149 | spans = spans[:, :this_N] 150 | 151 | if supervised: 152 | spans_sup = spans 153 | else: 154 | spans_sup = None 155 | 156 | this_loss = -model.log_likelihood(features, lengths, valid_classes_per_instance=valid_classes, 157 | spans=spans_sup) 158 | this_loss.backward() 159 | 160 | losses.append(this_loss.item()) 161 | 162 | optimizer.step() 163 | model.zero_grad() 164 | train_acc, train_remap_acc, _ = predict_synthetic(model, train_loader) 165 | test_acc, test_remap_acc, _ = predict_synthetic(model, test_loader) 166 | # print(train_acc) 167 | # print(train_remap_acc) 168 | # print(test_acc) 169 | # print(test_remap_acc) 170 | print("epoch {} avg loss: {:.4f}\ttrain acc: {:.2f}\ttest acc: {:.2f}".format( 171 | epoch, 172 | np.mean(losses), 173 | train_acc if supervised else train_remap_acc, 174 | test_acc if supervised else test_remap_acc, 175 | )) 176 | 177 | return model, train_loader, test_loader 178 | 179 | 180 | def optimal_map(predicted_labels, gold_labels, possible_labels): 181 | assert all(lab in possible_labels for lab in predicted_labels) 182 | assert all(lab in possible_labels for lab in gold_labels) 183 | voting_table = np.zeros((len(possible_labels), len(possible_labels))) 184 | labs_numpy = possible_labels.detach().cpu().numpy() 185 | for idx_gt, label_gt in enumerate(labs_numpy): 186 | gold_mask = gold_labels == label_gt 187 | for idx_pr, label_pr in enumerate(labs_numpy): 188 | voting_table[idx_gt, idx_pr] = (predicted_labels[gold_mask] == label_pr).sum() 189 | 190 | best_gt, best_pr = linear_sum_assignment(-voting_table) 191 | mapping = { 192 | labs_numpy[pr]: labs_numpy[gt] 193 | for pr, gt in zip(best_pr, best_gt) 194 | } 195 | remapped = predicted_labels.clone() 196 | remapped.apply_(lambda lab: mapping[lab]) 197 | return remapped, mapping 198 | 199 | 200 | def predict_synthetic(model, dataloader): 201 | items = [] 202 | token_match = 0 203 | token_total = 0 204 | token_remap_match = 0 205 | for batch in dataloader: 206 | features = batch['features'] 207 | lengths = batch['lengths'] 208 | gold_spans = batch['spans'] 209 | valid_classes = batch['valid_classes'] 210 | 211 | batch_size = features.size(0) 212 | 213 | this_N = lengths.max().item() 214 | features = features[:, :this_N, :] 215 | gold_spans = gold_spans[:, :this_N] 216 | 217 | pred_spans = model.viterbi(features, lengths, valid_classes_per_instance=valid_classes, add_eos=True) 218 | gold_labels = model.spans_to_labels(gold_spans) 219 | pred_labels = model.spans_to_labels(pred_spans) 220 | 221 | gold_labels_trim = model.trim(gold_labels, lengths, check_eos=False) 222 | pred_labels_trim = model.trim(pred_labels, lengths, check_eos=True) 223 | 224 | assert len(gold_labels_trim) == batch_size 225 | assert len(pred_labels_trim) == batch_size 226 | 227 | for i in range(batch_size): 228 | this_valid_classes = valid_classes[i] 229 | pred_remapped, mapping = optimal_map(pred_labels_trim[i], gold_labels_trim[i], this_valid_classes) 230 | item = { 231 | 'length': lengths[i].item(), 232 | 'gold_spans': gold_spans[i], 233 | 'pred_spans': pred_spans[i], 234 | 'gold_labels': gold_labels[i], 235 | 'pred_labels': pred_labels[i], 236 | 'gold_labels_trim': gold_labels_trim[i], 237 | 'pred_labels_trim': 
pred_labels_trim[i], 238 | 'pred_labels_remap_trim': pred_remapped, 239 | 'mapping': mapping 240 | } 241 | items.append(item) 242 | token_match += (gold_labels_trim[i] == pred_labels_trim[i]).sum().item() 243 | token_remap_match += (gold_labels_trim[i] == pred_remapped).sum().item() 244 | token_total += pred_labels_trim[i].size(0) 245 | accuracy = 100.0 * token_match / token_total 246 | remapped_accuracy = 100.0 * token_remap_match / token_total 247 | return accuracy, remapped_accuracy, items 248 | 249 | 250 | def test_labels_and_spans(): 251 | position_labels = torch.LongTensor([[0, 1, 1, 2, 2, 2], [0, 1, 2, 3, 3, 4]]) 252 | spans = torch.LongTensor([[0, 1, -1, 2, -1, -1], [0, 1, 2, 3, -1, 4]]) 253 | rle = [[(0, 1), (1, 2), (2, 3)], [(0, 1), (1, 1), (2, 1), (3, 2), (4, 1)]] 254 | assert (SemiMarkovModule.labels_to_spans(position_labels, max_k=10) == spans).all() 255 | assert (SemiMarkovModule.spans_to_labels(spans) == position_labels).all() 256 | assert SemiMarkovModule.rle_spans(spans, lengths=torch.LongTensor([6, 6])) == rle 257 | trunc_lengths = torch.LongTensor([5, 6]) 258 | trunc_rle = [[(0, 1), (1, 2), (2, 2)], [(0, 1), (1, 1), (2, 1), (3, 2), (4, 1)]] 259 | assert SemiMarkovModule.rle_spans(spans, lengths=trunc_lengths) == trunc_rle 260 | 261 | rand_labels = torch.randint(low=0, high=3, size=(5, 20)) 262 | assert (SemiMarkovModule.spans_to_labels( 263 | SemiMarkovModule.labels_to_spans(rand_labels, max_k=5)) == rand_labels).all() 264 | 265 | 266 | def test_log_hsmm(): 267 | # b = 100 268 | # C = 7 269 | # N = 300 270 | # K = 50 271 | # step_length = 20 272 | 273 | # b = 10 274 | # C = 3 275 | # N = 10 276 | # K = 20 # K > N 277 | # step_length = 2 278 | 279 | b = 10 280 | C = 4 281 | N = 100 282 | K = 5 283 | step_length = 4 284 | 285 | add_eos = True 286 | 287 | padded_length = N + step_length * 2 288 | 289 | lengths_unpadded = torch.full((b,), N).long() 290 | lengths_unpadded[0] = padded_length 291 | lengths = lengths_unpadded + 1 292 | 293 | num_steps = N // step_length 294 | assert N % step_length == 0 # since we're fixing lengths, need to end perfectly 295 | 296 | # trans_scores = torch.from_numpy(np.array([[0,1,0],[0,0,1],[1,0,0]]).T).float().log() 297 | trans_scores = torch.zeros(C, C, device=device) 298 | init_scores = torch.full((C,), BIG_NEG, device=device) 299 | init_scores[0] = 0 300 | 301 | emission_scores = torch.full((b, padded_length, C), BIG_NEG, device=device) 302 | 303 | for n in range(padded_length): 304 | c = (n // step_length) % C 305 | emission_scores[:, n, c] = 1 306 | 307 | length_scores = torch.full((K, C), BIG_NEG, device=device) 308 | length_scores[step_length, :] = 0 309 | 310 | scores = SemiMarkovModule.log_hsmm(trans_scores, emission_scores, init_scores, length_scores, lengths_unpadded, 311 | add_eos=add_eos) 312 | marginals = sm_max.marginals(scores, lengths=lengths) 313 | 314 | sequence, extra = sm_max.from_parts(marginals) 315 | 316 | for step in range(num_steps): 317 | c = step % C 318 | assert torch.allclose(sequence[:, step_length * step], torch.full((1,), c).long()) 319 | 320 | # C == EOS 321 | if add_eos: 322 | batch_indices = torch.arange(0, b) 323 | assert torch.allclose(sequence[batch_indices, lengths - 1], torch.full((1,), C).long()) 324 | 325 | 326 | test_labels_and_spans() 327 | print("test_labels_and_spans passed") 328 | 329 | test_log_hsmm() 330 | print("test_log_hsmm passed") 331 | 332 | model, trainloader, testloader = test_learn_synthetic() 333 | train_acc, train_remap_accuracy, train_preds = predict_synthetic(model, trainloader) 
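# (the *_remap accuracies apply the optimal one-to-one label matching computed by optimal_map above, which is the meaningful score when the model was trained without supervision)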
334 | test_acc, test_remap_accuracy, test_preds = predict_synthetic(model, testloader) 335 | print("train acc: {:.2f}".format(train_acc)) 336 | print("train remap acc: {:.2f}".format(train_remap_accuracy)) 337 | print("test acc: {:.2f}".format(test_acc)) 338 | print("test remap acc: {:.2f}".format(test_remap_accuracy)) 339 | -------------------------------------------------------------------------------- /src/models/sequential.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from models.model import Model, make_optimizer, make_data_loader 6 | from utils.utils import all_equal 7 | 8 | from data.corpus import Datasplit 9 | 10 | 11 | class Encoder(nn.Module): 12 | @classmethod 13 | def add_args(cls, parser): 14 | parser.add_argument('--seq_num_layers', type=int, default=2) 15 | 16 | def __init__(self, args, input_dim, output_dim): 17 | super(Encoder, self).__init__() 18 | self.args = args 19 | assert output_dim % 2 == 0 20 | # TODO: dropout? 21 | self.encoder = nn.LSTM(input_dim, output_dim // 2, bidirectional=True, num_layers=args.seq_num_layers, batch_first=True) 22 | 23 | def flatten_parameters(self): 24 | self.encoder.flatten_parameters() 25 | 26 | def forward(self, features, lengths, output_padding_value=0): 27 | packed = nn.utils.rnn.pack_padded_sequence(features, lengths, batch_first=True, enforce_sorted=False) 28 | encoded_packed, _ = self.encoder(packed) 29 | encoded, _ = nn.utils.rnn.pad_packed_sequence(encoded_packed, batch_first=True, padding_value=output_padding_value) 30 | return encoded 31 | 32 | class SequentialPredictConstraints(Model): 33 | @classmethod 34 | def add_args(cls, parser): 35 | pass 36 | 37 | @classmethod 38 | def from_args(cls, args, train_data: Datasplit): 39 | return cls(args, train_data) 40 | 41 | 42 | def __init__(self, args, train_data: Datasplit): 43 | from data.crosstask import CrosstaskDatasplit 44 | assert isinstance(train_data, CrosstaskDatasplit) 45 | 46 | self.args = args 47 | self.n_classes = train_data._corpus.n_classes 48 | self.remove_background = train_data.remove_background 49 | 50 | self.ordered_nonbackground_indices_by_task = { 51 | task_id: [train_data.corpus._index(step) for step in task.steps] 52 | for task_id, task in train_data._tasks_by_id.items() 53 | } 54 | 55 | self.background_indices_by_task = { 56 | task_id: list(sorted(ix for ix in train_data.corpus.indices_by_task(task_id) 57 | if ix in set(train_data.corpus._background_indices))) 58 | for task_id in train_data._tasks_by_id.keys() 59 | } 60 | assert all(len(v) == 1 for v in self.background_indices_by_task.values()), self.background_indices_by_task 61 | 62 | if train_data.remove_background: 63 | self.canonical = SequentialCanonicalBaseline(args, train_data) 64 | else: 65 | self.canonical = None 66 | 67 | def fit(self, train_data: Datasplit, use_labels: bool, callback_fn=None): 68 | pass 69 | 70 | def predict(self, test_data: Datasplit): 71 | predictions = {} 72 | loader = make_data_loader(self.args, test_data, batch_by_task=False, shuffle=False, batch_size=1) 73 | 74 | 75 | for batch in loader: 76 | features = batch['features'].squeeze(0) 77 | num_timesteps = features.size(0) 78 | 79 | tasks = batch['task_name'] 80 | assert len(tasks) == 1 81 | task = next(iter(tasks)) 82 | videos = batch['video_name'] 83 | assert len(videos) == 1 84 | video = next(iter(videos)) 85 | 86 | # constraints: T x K 87 | constraints = batch['constraints'].squeeze(0) 88 | assert 
constraints.size(0) == num_timesteps 89 | 90 | step_indices = self.ordered_nonbackground_indices_by_task[task] 91 | background_indices = self.background_indices_by_task[task] 92 | 93 | active_step = constraints.argmax(dim=1) 94 | active_step.apply_(lambda ix: step_indices[ix]) 95 | if not test_data.remove_background: 96 | active_step[constraints.sum(dim=1) == 0] = background_indices[0] 97 | predictions[video] = active_step.cpu().numpy() 98 | else: 99 | preds = active_step.cpu().numpy() 100 | zero_indices = (constraints.sum(dim=1) == 0).nonzero().flatten() 101 | baseline_preds = self.canonical.predict_single(task, num_timesteps) 102 | for ix in zero_indices: 103 | preds[ix] = baseline_preds[ix] 104 | predictions[video] = preds 105 | # just arbitrarily choose a background index, they will get canonicalized anyway 106 | return predictions 107 | 108 | class SequentialGroundTruth(Model): 109 | @classmethod 110 | def add_args(cls, parser): 111 | pass 112 | 113 | @classmethod 114 | def from_args(cls, args, train_data: Datasplit): 115 | return cls(args, train_data) 116 | 117 | def __init__(self, args, train_data: Datasplit): 118 | from data.crosstask import CrosstaskDatasplit 119 | assert isinstance(train_data, CrosstaskDatasplit) 120 | self.args = args 121 | self.n_classes = train_data._corpus.n_classes 122 | self.remove_background = train_data.remove_background 123 | 124 | pass 125 | 126 | def fit(self, train_data: Datasplit, use_labels: bool, callback_fn=None): 127 | pass 128 | 129 | def predict(self, test_data: Datasplit): 130 | predictions = {} 131 | loader = make_data_loader(self.args, test_data, batch_by_task=False, shuffle=False, batch_size=1) 132 | 133 | for batch in loader: 134 | features = batch['features'].squeeze(0) 135 | # num_timesteps = features.size(0) 136 | 137 | tasks = batch['task_name'] 138 | assert len(tasks) == 1 139 | # task = next(iter(tasks)) 140 | videos = batch['video_name'] 141 | assert len(videos) == 1 142 | video = next(iter(videos)) 143 | 144 | predictions[video] = batch['gt_single'].squeeze(0).numpy().tolist() 145 | return predictions 146 | 147 | class SequentialCanonicalBaseline(Model): 148 | @classmethod 149 | def add_args(cls, parser): 150 | parser.add_argument('--canonical_baseline_background_fraction', type=float, default=0.0) 151 | 152 | @classmethod 153 | def from_args(cls, args, train_data: Datasplit): 154 | return cls(args, train_data) 155 | 156 | def __init__(self, args, train_data: Datasplit): 157 | from data.crosstask import CrosstaskDatasplit 158 | assert isinstance(train_data, CrosstaskDatasplit) 159 | self.args = args 160 | self.n_classes = train_data._corpus.n_classes 161 | self.remove_background = train_data.remove_background 162 | 163 | self.ordered_nonbackground_indices_by_task = { 164 | task_id: [train_data.corpus._index(step) for step in task.steps] 165 | for task_id, task in train_data._tasks_by_id.items() 166 | } 167 | 168 | self.background_indices_by_task = { 169 | task_id: list(sorted(ix for ix in train_data.corpus.indices_by_task(task_id) 170 | if ix in set(train_data.corpus._background_indices))) 171 | for task_id in train_data._tasks_by_id.keys() 172 | } 173 | assert all(len(v) == 1 for v in self.background_indices_by_task.values()), self.background_indices_by_task 174 | 175 | def fit(self, train_data: Datasplit, use_labels: bool, callback_fn=None): 176 | pass 177 | 178 | def predict_single(self, task_id, num_timesteps): 179 | if self.remove_background: 180 | num_background_frames = 0 181 | else: 182 | num_background_frames = 
int(num_timesteps * self.args.canonical_baseline_background_fraction) 183 | background_index = next(iter(self.background_indices_by_task[task_id])) 184 | 185 | nonbackground_indices = self.ordered_nonbackground_indices_by_task[task_id] 186 | 187 | # this fails if we've removed background b/c some videos are too short 188 | if not self.remove_background: 189 | assert num_timesteps >= len(nonbackground_indices) 190 | 191 | # expand the total nonbackground duration to fit all background frames 192 | num_nonbackground_frames = max(num_timesteps - num_background_frames, len(nonbackground_indices)) 193 | 194 | step_duration = num_nonbackground_frames // len(nonbackground_indices) 195 | assert step_duration >= 1 196 | 197 | if self.remove_background or num_background_frames == 0: 198 | background_duration = 0 199 | pad = nonbackground_indices[-1] 200 | else: 201 | background_duration = (num_timesteps - step_duration * len(nonbackground_indices)) // (len(nonbackground_indices) + 1) 202 | assert background_duration >= 0 203 | pad = background_index 204 | 205 | indices = [] 206 | for step_ix in nonbackground_indices: 207 | if not self.remove_background: 208 | indices.extend([background_index] * background_duration) 209 | indices.extend([step_ix] * step_duration) 210 | 211 | if not self.remove_background: 212 | assert len(indices) <= num_timesteps 213 | assert num_timesteps - len(indices) - background_duration <= len(nonbackground_indices) + 1 214 | indices.extend([pad] * (num_timesteps - len(indices))) 215 | # hack for remove_background case: some videos have e.g. only 6 frames for 8 steps 216 | indices = indices[:num_timesteps] 217 | return indices 218 | 219 | def predict(self, test_data: Datasplit): 220 | predictions = {} 221 | loader = make_data_loader(self.args, test_data, batch_by_task=False, shuffle=False, batch_size=1) 222 | 223 | for batch in loader: 224 | features = batch['features'].squeeze(0) 225 | num_timesteps = features.size(0) 226 | 227 | tasks = batch['task_name'] 228 | assert len(tasks) == 1 229 | task = next(iter(tasks)) 230 | videos = batch['video_name'] 231 | assert len(videos) == 1 232 | video = next(iter(videos)) 233 | 234 | predictions[video] = self.predict_single(task, num_timesteps) 235 | return predictions 236 | 237 | class SequentialPredictFrames(nn.Module): 238 | @classmethod 239 | def add_args(cls, parser): 240 | Encoder.add_args(parser) 241 | parser.add_argument('--seq_hidden_size', type=int, default=200) 242 | 243 | def __init__(self, args, input_dim, num_classes): 244 | super(SequentialPredictFrames, self).__init__() 245 | self.args = args 246 | self.input_dim = input_dim 247 | self.num_classes = num_classes 248 | self.encoder = Encoder(self.args, input_dim, args.seq_hidden_size) 249 | self.proj = nn.Linear(args.seq_hidden_size, num_classes) 250 | 251 | def forward(self, features, lengths, valid_classes_per_instance=None): 252 | # batch_size x max_len x seq_hidden_size 253 | encoded = self.encoder(features, lengths, output_padding_value=0) 254 | # batch_size x max_len x num_classes 255 | logits = self.proj(encoded) 256 | if valid_classes_per_instance is not None: 257 | assert all_equal(set(vc.detach().cpu().numpy()) for vc in 258 | valid_classes_per_instance), "must have same valid_classes for all instances in the batch" 259 | valid_classes = valid_classes_per_instance[0] 260 | mask = torch.full_like(logits, -float("inf")) 261 | mask[:,:,valid_classes] = 0 262 | logits = logits + mask 263 | return logits 264 | 265 | class SequentialDiscriminative(Model): 266 | 
@classmethod 267 | def add_args(cls, parser): 268 | SequentialPredictFrames.add_args(parser) 269 | 270 | @classmethod 271 | def from_args(cls, args, train_data: Datasplit): 272 | return cls(args, train_data) 273 | 274 | def __init__(self, args, train_data: Datasplit): 275 | self.args = args 276 | #self.n_classes = sum(len(indices) for indices in train_data.groundtruth.indices_by_task.values()) 277 | self.n_classes = train_data._corpus.n_classes 278 | self.model = SequentialPredictFrames(args, input_dim=train_data.feature_dim, num_classes=self.n_classes) 279 | if args.cuda: 280 | self.model.cuda() 281 | 282 | def fit(self, train_data: Datasplit, use_labels: bool, callback_fn=None): 283 | assert use_labels 284 | IGNORE = -100 285 | loss = nn.CrossEntropyLoss(ignore_index=IGNORE) 286 | optimizer, scheduler = make_optimizer(self.args, self.model.parameters()) 287 | loader = make_data_loader(self.args, train_data, batch_by_task=False, shuffle=True, batch_size=self.args.batch_size) 288 | 289 | for epoch in range(self.args.epochs): 290 | # call here since we may set eval in callback_fn 291 | self.model.train() 292 | losses = [] 293 | assert self.args.batch_accumulation <= 1 294 | for batch in tqdm.tqdm(loader, ncols=80): 295 | # for batch in loader: 296 | tasks = batch['task_name'] 297 | videos = batch['video_name'] 298 | features = batch['features'] 299 | gt_single = batch['gt_single'] 300 | task_indices = batch['task_indices'] 301 | max_len = features.size(1) 302 | lengths = batch['lengths'] 303 | invalid_mask = torch.arange(max_len).expand(len(lengths), max_len) >= lengths.unsqueeze(1) 304 | if self.args.cuda: 305 | features = features.cuda() 306 | lengths = lengths.cuda() 307 | task_indices = [indx.cuda() for indx in task_indices] 308 | gt_single = gt_single.cuda() 309 | invalid_mask = invalid_mask.cuda() 310 | gt_single.masked_fill_(invalid_mask, IGNORE) 311 | # batch_size x max_len x num_classes 312 | logits = self.model(features, lengths, valid_classes_per_instance=task_indices) 313 | 314 | this_loss = loss(logits.view(-1, logits.size(-1)), gt_single.flatten()) 315 | losses.append(this_loss.item()) 316 | this_loss.backward() 317 | 318 | optimizer.step() 319 | self.model.zero_grad() 320 | train_loss = np.mean(losses) 321 | if scheduler is not None: 322 | scheduler.step(train_loss) 323 | callback_fn(epoch, {'train_loss': train_loss}) 324 | 325 | # if evaluate_on_data_fn is not None: 326 | # train_mof = evaluate_on_data_fn(self, train_data, 'train') 327 | # dev_mof = evaluate_on_data_fn(self, dev_data, 'dev') 328 | # dev_mof_by_epoch[epoch] = dev_mof 329 | # log_str += ("\ttrain mof: {:.4f}".format(train_mof)) 330 | # log_str += ("\tdev mof: {:.4f}".format(dev_mof)) 331 | 332 | def predict(self, test_data: Datasplit): 333 | self.model.eval() 334 | predictions = {} 335 | loader = make_data_loader(self.args, test_data, batch_by_task=False, shuffle=False, batch_size=1) 336 | for batch in loader: 337 | features = batch['features'] 338 | lengths = batch['lengths'] 339 | task_indices = batch['task_indices'] 340 | if self.args.cuda: 341 | features = features.cuda() 342 | lengths = lengths.cuda() 343 | task_indices = [indx.cuda() for indx in task_indices] 344 | videos = batch['video_name'] 345 | assert all_equal(videos) 346 | video = next(iter(videos)) 347 | # batch_size x length x num_classes 348 | with torch.no_grad(): 349 | logits = self.model(features, lengths, valid_classes_per_instance=task_indices) 350 | preds = logits.max(dim=-1)[1] 351 | preds = preds.squeeze(0) 352 | assert preds.ndim == 1 
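# With batch_size 1, preds is a 1-D tensor of per-frame class indices; it is
# detached and converted to numpy below before being stored for this video.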
353 | predictions[video] = preds.detach().cpu().numpy() 354 | return predictions 355 | 356 | -------------------------------------------------------------------------------- /src/data/breakfast.py: -------------------------------------------------------------------------------- 1 | # modified from slim_mallow by Anna Kukleva, https://github.com/Annusha/slim_mallow 2 | 3 | import os 4 | import re 5 | from collections import Counter, defaultdict 6 | 7 | import numpy as np 8 | 9 | from data.features import grouped_pca 10 | from data.corpus import Corpus, GroundTruth, Video, Datasplit 11 | from utils.logger import logger 12 | from utils.utils import all_equal 13 | 14 | 15 | 16 | class BreakfastDatasplit(Datasplit): 17 | def __init__(self, corpus, remove_background, task_filter=None, splits=None, full=True, subsample=1, feature_downscale=1.0, 18 | feature_permutation_seed=None): 19 | if splits is None: 20 | splits = list(sorted(BreakfastCorpus.DATASPLITS.keys())) 21 | self._splits = splits 22 | self._tasks = BreakfastCorpus.TASKS[:] if task_filter is None else task_filter 23 | self._p_files = [] 24 | # split 25 | assert all(split in BreakfastCorpus.DATASPLITS for split in splits) 26 | 27 | for split, p_files in sorted(BreakfastCorpus.DATASPLITS.items()): 28 | if split in splits: 29 | assert len(set(p_files) & set(self._p_files)) == 0, "{} : {}".format(set(p_files), set(self._p_files)) 30 | self._p_files.extend(p_files) 31 | 32 | super(BreakfastDatasplit, self).__init__( 33 | corpus, 34 | remove_background=remove_background, 35 | full=full, 36 | subsample=subsample, 37 | feature_downscale=feature_downscale, 38 | feature_permutation_seed=feature_permutation_seed 39 | ) 40 | 41 | def _load_ground_truth_and_videos(self, remove_background): 42 | self.groundtruth = BreakfastGroundTruth( 43 | self._corpus, 44 | task_names=self._tasks, 45 | p_files=self._p_files, 46 | remove_background=remove_background, 47 | ) 48 | 49 | k_by_task = {} 50 | for task, gts in self.groundtruth.gt_by_task.items(): 51 | uniq_labels = set() 52 | for filename, labels in gts.items(): 53 | uniq_labels = uniq_labels.union(labels_t[0] for labels_t in labels) 54 | assert -1 not in uniq_labels 55 | # if -1 in uniq_labels: 56 | # k_by_task[task] = len(uniq_labels) - 1 57 | # else: 58 | # k_by_task[task] = len(uniq_labels) 59 | k_by_task[task] = len(uniq_labels) 60 | self._K_by_task = k_by_task 61 | self._init_videos() 62 | 63 | def _init_videos(self): 64 | # TODO: move to super class? 
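# _init_videos walks the feature root and pairs every *.npy file with its ground
# truth by parsing the filename stem. Illustrative parse (hypothetical filename;
# the exact naming depends on how the features were dumped):
#   re.match(r'(\w*)\.\w*', 'P03_cam01_P03_coffee.npy').group(1)  -> 'P03_cam01_P03_coffee'
#   'P03_cam01_P03_coffee'.split('_')[0]                          -> 'P03'  (participant id)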
65 | gt_stat = Counter() 66 | video_names = set() 67 | for root, dirs, files in os.walk(self._corpus._feature_root): 68 | if files: 69 | for filename in files: 70 | if not filename.endswith(".npy"): 71 | continue 72 | matching_tasks = [ 73 | task for task in self._tasks if task in filename 74 | ] 75 | assert len(matching_tasks) <= 1, "{} matched by {}".format(filename, matching_tasks) 76 | if not matching_tasks: 77 | continue 78 | task = matching_tasks[0] 79 | match = re.match(r'(\w*)\.\w*', filename) 80 | gt_name = match.group(1) 81 | p_name = gt_name.split('_')[0] 82 | if p_name not in self._p_files: 83 | continue 84 | if gt_name not in self.groundtruth.gt_by_task[task]: 85 | print("skipping video {} for which no ground truth found!".format(gt_name)) 86 | continue 87 | if not self._full and len(self._videos_by_task[task]) > 10: 88 | continue 89 | # use extracted features from pretrained on gt embedding 90 | # path = os.path.join(root, filename) 91 | if self._remove_background: 92 | nonbackground_timesteps = self.groundtruth.nonbackground_timesteps_by_task[task][gt_name] 93 | else: 94 | nonbackground_timesteps = None 95 | video = BreakfastVideo( 96 | # path, 97 | root, 98 | remove_background=self._remove_background, 99 | nonbackground_timesteps=nonbackground_timesteps, 100 | K=self._K_by_task[task], 101 | gt=self.groundtruth.gt_by_task[task][gt_name], 102 | gt_with_background=self.groundtruth.gt_with_background_by_task[task][gt_name], 103 | name=gt_name, 104 | cache_features=self._corpus._cache_features, 105 | feature_permutation_seed=self._feature_permutation_seed, 106 | ) 107 | # self._features = join_data(self._features, video.features(), 108 | # np.vstack) 109 | 110 | # video.reset() # to not store second time loaded features 111 | if task not in self._videos_by_task: 112 | self._videos_by_task[task] = {} 113 | assert video.name not in self._videos_by_task[task] 114 | self._videos_by_task[task][video.name] = video 115 | video_names.add(video.name) 116 | # accumulate statistic for inverse counts vector for each video 117 | gt_stat.update(labels_t[0] for labels_t in self.groundtruth.gt_by_task[task][gt_name]) 118 | 119 | logger.debug( 120 | "{} tasks found with tasks {}, p_files {}".format(len(self._videos_by_task), self._tasks, self._p_files)) 121 | logger.debug("{} videos found with tasks {}, p_files {}".format(len(video_names), self._tasks, self._p_files)) 122 | 123 | # # update global range within the current collection for each video 124 | # for video in self._videos: 125 | # video.update_indexes(len(self._features)) 126 | logger.debug('gt statistic: ' + str(gt_stat)) 127 | # FG_MASK 128 | # self._update_fg_mask() 129 | 130 | # FG_MASK 131 | # def _update_fg_mask(self): 132 | # logger.debug('.') 133 | # if self._with_bg: 134 | # self._total_fg_mask = np.zeros(len(self._features), dtype=bool) 135 | # for video in self._videos: 136 | # self._total_fg_mask[np.nonzero(video.global_range)[0][video.fg_mask]] = True 137 | # else: 138 | # self._total_fg_mask = np.ones(len(self._features), dtype=bool) 139 | 140 | 141 | 142 | class BreakfastCorpus(Corpus): 143 | BACKGROUND_LABELS = ["SIL"] 144 | 145 | TASKS = [ 146 | 'coffee', 'cereals', 'tea', 'milk', 'juice', 'sandwich', 'scrambledegg', 'friedegg', 'salat', 'pancake' 147 | ] 148 | 149 | DATASPLITS = { 150 | 's1': ["P{:02d}".format(d) for d in range(3, 16)], 151 | 's2': ["P{:02d}".format(d) for d in range(16, 29)], 152 | 's3': ["P{:02d}".format(d) for d in range(29, 42)], 153 | 's4': ["P{:02d}".format(d) for d in range(42, 55)], 154 | } 155 
| assert all_equal(len(v) for v in DATASPLITS.values()) 156 | 157 | def __init__(self, mapping_file, feature_root, label_root, task_specific_steps=False): 158 | self._mapping_file = mapping_file 159 | self._feature_root = feature_root 160 | self._label_root = label_root 161 | self._task_specific_steps = task_specific_steps 162 | assert not task_specific_steps 163 | self.annotate_background_with_previous = False 164 | 165 | super(BreakfastCorpus, self).__init__(background_labels=self.BACKGROUND_LABELS) 166 | 167 | def _get_components_for_label(self, label): 168 | return label.split('_') 169 | 170 | def _load_mapping(self): 171 | with open(self._mapping_file, 'r') as f: 172 | for line in f: 173 | index, label = line.strip().split() 174 | index = int(index) 175 | _index = self._index(label) 176 | if label in self._background_labels: 177 | assert index in self._background_indices 178 | if index in self._background_indices: 179 | assert label in self._background_labels 180 | assert _index == index 181 | 182 | def get_datasplit(self, remove_background, task_filter=None, splits=None, full=True, subsample=1, feature_downscale=1.0, 183 | feature_permutation_seed=None): 184 | return BreakfastDatasplit(self, remove_background, task_filter=task_filter, splits=splits, 185 | full=full, subsample=subsample, feature_downscale=feature_downscale, 186 | feature_permutation_seed=feature_permutation_seed) 187 | 188 | def datasets_by_task(mapping_file, feature_root, label_root, remove_background, 189 | task_ids=None, splits=BreakfastCorpus.DATASPLITS.keys(), full=True): 190 | if task_ids is None: 191 | task_ids = BreakfastCorpus.TASKS 192 | corpus = BreakfastCorpus(mapping_file, feature_root, label_root) 193 | return { 194 | task_id: corpus.get_datasplit(remove_background, [task_id], splits, full) 195 | for task_id in task_ids 196 | } 197 | 198 | class BreakfastGroundTruth(GroundTruth): 199 | 200 | def __init__(self, corpus, task_names, p_files, remove_background): 201 | self._p_files = set(p_files) 202 | super(BreakfastGroundTruth, self).__init__(corpus, task_names, remove_background) 203 | 204 | def _load_gt(self): 205 | annotation_count = 0 206 | for root, dirs, files in os.walk(self._corpus._label_root): 207 | for filename in files: 208 | if not filename.endswith(".txt"): 209 | continue 210 | p_file = filename.split('_')[0] 211 | if p_file not in self._p_files: 212 | continue 213 | matching_tasks = [ 214 | task for task in self._task_names if task in filename 215 | ] 216 | assert len(matching_tasks) <= 1, "{} matched by {}".format(filename, matching_tasks) 217 | if not matching_tasks: 218 | continue 219 | task = matching_tasks[0] 220 | 221 | # ** load labels ** 222 | gt = [] 223 | order = [] 224 | with open(os.path.join(root, filename), 'r') as f: 225 | for line in f: 226 | match = re.match(r'(\d*)-(\d*)\s*(\w*)', line) 227 | start = int(match.group(1)) 228 | end = int(match.group(2)) 229 | if end < start: 230 | assert match.group(3) == self._corpus.BACKGROUND_LABELS[0] 231 | continue 232 | assert start > len(gt) - 1 233 | label = match.group(3) 234 | label_idx = self._corpus._index(label) 235 | # gt should be a list of lists, since other corpora can have multiple labels per timestep 236 | gt += [[label_idx]] * (end - start + 1) 237 | order.append((label_idx, start, end)) 238 | 239 | annotation_count += 1 240 | 241 | # ** get vid_name to match feature names ** 242 | up_to_cam, cam_name = os.path.split(root) 243 | if cam_name == 'stereo': 244 | cam_name = 'stereo01' 245 | _, p_name = os.path.split(up_to_cam) 
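# Illustrative parses for the two filename patterns handled below (hypothetical
# names; some annotation files carry a _chN channel suffix, others do not):
#   re.match(r'(\w*)_ch(\d+)\.\w*', 'P03_coffee_ch1.txt').groups() -> ('P03_coffee', '1')
#   re.match(r'(\w*)\.\w*', 'P03_coffee.txt').group(1)             -> 'P03_coffee'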
246 | 247 | match = re.match(r'(\w*)_ch(\d+)\.\w*', filename) 248 | if match: 249 | gt_name = match.group(1) 250 | index = int(match.group(2)) 251 | else: 252 | match = re.match(r'(\w*)\.\w*', filename) 253 | gt_name = match.group(1) 254 | index = 0 255 | 256 | # skip videos for which the length of the features and the labels differ by more than 50 257 | # TODO: get the processed version of the data that fixes this! 258 | if (gt_name, cam_name) in [ 259 | ("P51_coffee", "webcam01"), 260 | ("P34_coffee", "cam01"), 261 | ("P34_juice", "cam01"), 262 | ("P52_sandwich", "stereo01"), 263 | ("P54_scrambledegg", "webcam01"), 264 | ("P34_scrambledegg", "cam01"), 265 | ("P34_friedegg", "cam01"), 266 | ("P54_pancake", "cam01"), 267 | ("P52_pancake", "webcam01"), 268 | ]: 269 | continue 270 | 271 | vid_name = "{}_{}_{}".format(p_name, cam_name, gt_name) 272 | 273 | if task not in self.order_by_task: 274 | self.order_by_task[task] = {} 275 | if task not in self.gt_by_task: 276 | self.gt_by_task[task] = {} 277 | 278 | self.gt_by_task[task][vid_name] = gt 279 | self.order_by_task[task][vid_name] = order 280 | logger.debug("{} annotation files found".format(annotation_count)) 281 | 282 | # def _load_gt(self): 283 | # self.gt, self.order = {}, {} 284 | # for filename in os.listdir(self.label_root): 285 | # if os.path.isdir(os.path.join(self.label_root, filename)): 286 | # continue 287 | # with open(os.path.join(self.label_root, filename), 'r') as f: 288 | # labels = [] 289 | # local_order = [] 290 | # curr_lab = -1 291 | # start, end = 0, 0 292 | # for line in f: 293 | # line = line.split()[0] 294 | # try: 295 | # labels.append(self.label2index[line]) 296 | # if curr_lab != labels[-1]: 297 | # if curr_lab != -1: 298 | # local_order.append([curr_lab, start, end]) 299 | # curr_lab = labels[-1] 300 | # start = end 301 | # end += 1 302 | # except KeyError: 303 | # break 304 | # else: 305 | # # executes every times when "for" wasn't interrupted by break 306 | # self.gt[filename] = np.array(labels) 307 | # # add last labels 308 | # 309 | # local_order.append([curr_lab, start, end]) 310 | # self.order[filename] = local_order 311 | 312 | 313 | class BreakfastVideo(Video): 314 | 315 | def load_features(self): 316 | # feats = _features = np.loadtxt(os.path.join(self._feature_root, "{}.txt".format(self.name))) 317 | feats = np.load(os.path.join(self._feature_root, "{}.npy".format(self.name))) 318 | feats = feats[1:, 1:] 319 | return feats 320 | 321 | def extract_feature_groups(corpus): 322 | group_indices = { 323 | 'reduced_64': (0, 64), 324 | } 325 | n_instances = len(corpus) 326 | grouped = defaultdict(dict) 327 | for idx in range(n_instances): 328 | instance = corpus._get_by_index(idx) 329 | video_name = instance['video_name'] 330 | features = instance['features'] 331 | for group, (start, end) in group_indices.items(): 332 | grouped[group][video_name] = features[:, start:end] 333 | return grouped 334 | 335 | def pca_and_serialize_features(mapping_file, feature_root, label_root, output_feature_root, remove_background, 336 | pca_components_per_group=300, by_task=True, task_ids=None): 337 | all_splits = BreakfastCorpus.DATASPLITS.keys() 338 | if by_task: 339 | grouped_datasets = datasets_by_task(mapping_file, feature_root, label_root, remove_background, 340 | task_ids=task_ids, splits=all_splits, full=True) 341 | else: 342 | corpus = BreakfastCorpus(mapping_file, feature_root, label_root) 343 | grouped_datasets = { 344 | 'all': corpus.get_datasplit(remove_background, splits=all_splits) 345 | } 346 | 347 | 
os.makedirs(output_feature_root, exist_ok=True) 348 | 349 | for corpora_group, dataset in grouped_datasets.items(): 350 | logger.debug("saving features for task: {}".format(corpora_group)) 351 | grouped_features = extract_feature_groups(dataset) 352 | transformed, pca_models = grouped_pca(grouped_features, pca_components_per_group, pca_models_by_group=None) 353 | for feature_group, vid_dict in transformed.items(): 354 | logger.debug("\tsaving features for feature group: {}".format(feature_group)) 355 | feature_group_dir = os.path.join(output_feature_root, feature_group) 356 | os.makedirs(feature_group_dir, exist_ok=True) 357 | for vid, features in vid_dict.items(): 358 | fname = os.path.join(feature_group_dir, '{}.npy'.format(vid)) 359 | np.save(fname, features) 360 | 361 | 362 | if __name__ == "__main__": 363 | _mapping_file = 'data/breakfast/mapping.txt' 364 | _feature_root = 'data/breakfast/reduced_fv_64' 365 | _label_root = 'data/breakfast/BreakfastII_15fps_qvga_sync' 366 | 367 | _components = 64 368 | for _remove_background in [False, True]: 369 | for _by_task in [True]: 370 | _output_feature_root = 'data/breakfast/breakfast_processed/breakfast_pca-{}_{}_{}'.format( 371 | _components, 372 | 'no-bkg' if _remove_background else 'with-bkg', 373 | 'by-task' if _by_task else 'all-tasks', 374 | ) 375 | 376 | pca_and_serialize_features(_mapping_file, _feature_root, _label_root, _output_feature_root, _remove_background, 377 | pca_components_per_group=_components, by_task=_by_task) 378 | -------------------------------------------------------------------------------- /src/models/semimarkov/semimarkov.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import tqdm 4 | import time 5 | 6 | import torch 7 | from data.corpus import Datasplit 8 | from models.model import Model, make_optimizer, make_data_loader 9 | from models.semimarkov.semimarkov_modules import SemiMarkovModule, ComponentSemiMarkovModule 10 | from models.semimarkov import semimarkov_utils 11 | 12 | from utils.utils import all_equal 13 | 14 | 15 | class SemiMarkovModel(Model): 16 | @classmethod 17 | def add_args(cls, parser): 18 | SemiMarkovModule.add_args(parser) 19 | ComponentSemiMarkovModule.add_args(parser) 20 | parser.add_argument('--sm_component_model', action='store_true') 21 | 22 | parser.add_argument('--sm_constrain_transitions', action='store_true') 23 | 24 | parser.add_argument('--sm_constrain_with_narration', choices=['train', 'test'], nargs='*', default=[]) 25 | parser.add_argument('--sm_constrain_narration_weight', type=float, default=-1e4) 26 | 27 | parser.add_argument('--sm_train_discriminatively', action='store_true') 28 | 29 | parser.add_argument('--sm_hidden_markov', action='store_true', help='train as hidden markov model (fix K=1) and length distribution') 30 | 31 | parser.add_argument('--sm_predict_single', action='store_true') 32 | 33 | @classmethod 34 | def from_args(cls, args, train_data): 35 | n_classes = train_data.corpus.n_classes 36 | feature_dim = train_data.feature_dim 37 | 38 | allow_self_transitions = True 39 | 40 | assert args.sm_max_span_length is not None 41 | if args.sm_constrain_transitions: 42 | # assert args.task_specific_steps, "will get bad results with --sm_constrain_transitions if you don't also pass --task_specific_steps, because of multiple exits" 43 | # TODO: figure out what I meant by multiple exits; this seems fine at least if you're using the ComponentSemiMarkovModule. maybe add a check for this? 
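# Sketch of the constraint structure consumed below (illustrative values only;
# the exact container types come from get_allowed_starts_and_transitions):
#   allowed_transitions = {3: {7}, 7: {2}}   # canonical successor step(s)
#   allowed_starts, allowed_ends             # canonical first / last step indices
# The loop that follows then adds a self-transition for every class, e.g. 3 -> {3, 7}.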
44 | # if not args.remove_background: 45 | # raise NotImplementedError("--sm_constrain_transitions without --remove_background ") 46 | 47 | ( 48 | allowed_starts, allowed_transitions, allowed_ends, ordered_indices_by_task 49 | ) = train_data.get_allowed_starts_and_transitions() 50 | if allow_self_transitions: 51 | for src in range(n_classes): 52 | if src not in allowed_transitions: 53 | allowed_transitions[src] = set() 54 | allowed_transitions[src].add(src) 55 | else: 56 | allowed_starts, allowed_transitions, allowed_ends, ordered_indices_by_task = None, None, None, None 57 | 58 | if args.annotate_background_with_previous and not args.no_merge_classes: 59 | merge_classes = {} 60 | for task, indices in train_data.corpus._indices_by_task.items(): 61 | background_indices = [ix for ix in indices if ix in train_data.corpus._background_indices] 62 | nonbackground_indices = [ix for ix in indices if ix not in train_data.corpus._background_indices] 63 | canon_bkg_ix = background_indices[0] 64 | for ix in background_indices: 65 | if ix in merge_classes: 66 | assert merge_classes[ix] == canon_bkg_ix 67 | else: 68 | merge_classes[ix] = canon_bkg_ix 69 | # assert ix not in merge_classes 70 | for ix in nonbackground_indices: 71 | if ix in merge_classes: 72 | assert merge_classes[ix] == ix 73 | else: 74 | merge_classes[ix] = ix 75 | # assert ix not in merge_classes 76 | # merge_classes[ix] = ix 77 | else: 78 | merge_classes = None 79 | 80 | if args.sm_component_model: 81 | if args.sm_component_decompose_steps: 82 | # assert not args.task_specific_steps, "can't decompose steps unless steps are across tasks; you should remove --task_specific_steps" 83 | n_components = train_data.corpus.n_components 84 | class_to_components = copy.copy(train_data.corpus.label_indices2component_indices) 85 | else: 86 | n_components = n_classes 87 | class_to_components = { 88 | cls: {cls} 89 | for cls in range(n_classes) 90 | } 91 | model = ComponentSemiMarkovModule( 92 | args, 93 | n_classes, 94 | n_components=n_components, 95 | class_to_components=class_to_components, 96 | feature_dim=feature_dim, 97 | allow_self_transitions=allow_self_transitions, 98 | allowed_starts=allowed_starts, 99 | allowed_transitions=allowed_transitions, 100 | allowed_ends=allowed_ends, 101 | merge_classes=merge_classes, 102 | ) 103 | else: 104 | model = SemiMarkovModule( 105 | args, 106 | n_classes, 107 | feature_dim, 108 | allow_self_transitions=allow_self_transitions, 109 | allowed_starts=allowed_starts, 110 | allowed_transitions=allowed_transitions, 111 | allowed_ends=allowed_ends, 112 | merge_classes=merge_classes, 113 | ) 114 | return SemiMarkovModel(args, n_classes, feature_dim, model, ordered_indices_by_task) 115 | 116 | def __init__(self, args, n_classes, feature_dim, model, ordered_indices_by_task=None): 117 | self.args = args 118 | self.n_classes = n_classes 119 | self.feature_dim = feature_dim 120 | self.model = model 121 | self.ordered_indices_by_task = ordered_indices_by_task 122 | if args.cuda: 123 | self.model.cuda() 124 | 125 | def fit_supervised(self, train_data: Datasplit): 126 | assert not self.args.sm_component_model 127 | assert not self.args.sm_constrain_transitions 128 | loader = make_data_loader(self.args, train_data, batch_by_task=False, shuffle=False, batch_size=1) 129 | features, labels = [], [] 130 | for batch in loader: 131 | features.append(batch['features'].squeeze(0)) 132 | labels.append(batch['gt_single'].squeeze(0)) 133 | self.model.fit_supervised(features, labels) 134 | 135 | def make_additional_allowed_ends(self, 
tasks, lengths): 136 | if self.ordered_indices_by_task is not None: 137 | addl_allowed_ends = [] 138 | for task, length in zip(tasks, lengths): 139 | ord_indices = self.ordered_indices_by_task[task] 140 | if length.item() < len(ord_indices): 141 | this_allowed_ends = [ord_indices[length.item()-1]] 142 | else: 143 | this_allowed_ends = [] 144 | addl_allowed_ends.append(this_allowed_ends) 145 | else: 146 | addl_allowed_ends = None 147 | return addl_allowed_ends 148 | 149 | def expand_constraints(self, datasplit, task, task_indices, constraints): 150 | task_indices = list(task_indices.cpu().numpy()) 151 | step_indices = datasplit.get_ordered_indices_no_background()[task] 152 | # constraints: batch_dim x T x K 153 | assert constraints.size(2) == len(step_indices) 154 | constraints_expanded = torch.zeros((constraints.size(0), constraints.size(1), len(task_indices))) 155 | for index, label in enumerate(step_indices): 156 | constraints_expanded[:,:,task_indices.index(label)] = constraints[:,:,index] 157 | return constraints_expanded 158 | 159 | def fit(self, train_data: Datasplit, use_labels: bool, callback_fn=None): 160 | self.model.train() 161 | self.model.flatten_parameters() 162 | if use_labels: 163 | assert not self.args.sm_constrain_transitions 164 | initialize = True 165 | if use_labels and self.args.sm_supervised_method in ['closed-form', 'closed-then-gradient']: 166 | self.fit_supervised(train_data) 167 | if self.args.sm_supervised_method == 'closed-then-gradient': 168 | initialize = False 169 | callback_fn(-1, {}) 170 | else: 171 | return 172 | if self.args.sm_init_non_projection_parameters_from: 173 | initialize = False 174 | if callback_fn: 175 | callback_fn(-1, {}) 176 | optimizer, scheduler = make_optimizer(self.args, self.model.parameters()) 177 | big_loader = make_data_loader(self.args, train_data, batch_by_task=False, shuffle=True, batch_size=100) 178 | samp = next(iter(big_loader)) 179 | big_features = samp['features'] 180 | big_lengths = samp['lengths'] 181 | if self.args.cuda: 182 | big_features = big_features.cuda() 183 | big_lengths = big_lengths.cuda() 184 | 185 | if initialize: 186 | self.model.initialize_gaussian(big_features, big_lengths) 187 | 188 | loader = make_data_loader(self.args, train_data, batch_by_task=True, shuffle=True, batch_size=self.args.batch_size) 189 | 190 | # print('{} videos in training data'.format(len(loader.dataset))) 191 | 192 | # all_features = [sample['features'] for batch in loader for sample in batch] 193 | # if self.args.cuda: 194 | # all_features = [feats.cuda() for feats in all_features] 195 | 196 | C = self.n_classes 197 | K = self.args.sm_max_span_length 198 | 199 | for epoch in range(self.args.epochs): 200 | start_time = time.time() 201 | # call here since we may set eval in callback_fn 202 | self.model.train() 203 | losses = [] 204 | multi_batch_losses = [] 205 | nlls = [] 206 | kls = [] 207 | log_dets = [] 208 | num_frames = 0 209 | num_videos = 0 210 | train_nll = 0 211 | train_kl = 0 212 | train_log_det = 0 213 | # for batch_ix, batch in enumerate(tqdm.tqdm(loader, ncols=80)): 214 | for batch_ix, batch in enumerate(loader): 215 | if self.args.train_limit and batch_ix >= self.args.train_limit: 216 | break 217 | # if self.args.cuda: 218 | # features = features.cuda() 219 | # task_indices = task_indices.cuda() 220 | # gt_single = gt_single.cuda() 221 | tasks = batch['task_name'] 222 | videos = batch['video_name'] 223 | features = batch['features'] 224 | task_indices = batch['task_indices'] 225 | lengths = batch['lengths'] 226 | 227 | if 
'train' in self.args.sm_constrain_with_narration: 228 | assert all_equal(tasks) 229 | constraints_expanded = self.expand_constraints( 230 | train_data, tasks[0], task_indices[0], 1 - batch['constraints'] 231 | ) 232 | constraints_expanded *= self.args.sm_constrain_narration_weight 233 | else: 234 | constraints_expanded = None 235 | 236 | num_frames += lengths.sum().item() 237 | num_videos += len(lengths) 238 | 239 | # assert len( task_indices) == self.n_classes, "remove_background and multi-task fit() not implemented" 240 | 241 | if self.args.cuda: 242 | features = features.cuda() 243 | lengths = lengths.cuda() 244 | if constraints_expanded is not None: 245 | constraints_expanded = constraints_expanded.cuda() 246 | 247 | if use_labels: 248 | labels = batch['gt_single'] 249 | if self.args.cuda: 250 | labels = labels.cuda() 251 | spans = semimarkov_utils.labels_to_spans(labels, max_k=K) 252 | use_mean_z = True 253 | else: 254 | spans = None 255 | use_mean_z = False 256 | 257 | addl_allowed_ends = self.make_additional_allowed_ends(tasks, lengths) 258 | 259 | ll, log_det = self.model.log_likelihood(features, 260 | lengths, 261 | valid_classes_per_instance=task_indices, 262 | spans=spans, 263 | add_eos=True, 264 | use_mean_z=use_mean_z, 265 | additional_allowed_ends_per_instance=addl_allowed_ends, 266 | constraints=constraints_expanded) 267 | nll = -ll 268 | kl = self.model.kl.mean() 269 | if use_labels: 270 | this_loss = nll - log_det 271 | else: 272 | this_loss = nll - log_det + kl 273 | multi_batch_losses.append(this_loss) 274 | nlls.append(nll.item()) 275 | kls.append(kl.item()) 276 | log_dets.append(log_det.item()) 277 | 278 | train_nll += (nll.item() * len(videos)) 279 | train_kl += (kl.item() * len(videos)) 280 | train_log_det += (log_det.item() * len(videos)) 281 | 282 | losses.append(this_loss.item()) 283 | 284 | if len(multi_batch_losses) >= self.args.batch_accumulation: 285 | loss = sum(multi_batch_losses) / len(multi_batch_losses) 286 | loss.backward() 287 | multi_batch_losses = [] 288 | 289 | if self.args.print_every and (batch_ix % self.args.print_every == 0): 290 | param_norm = sum([p.norm()**2 for p in self.model.parameters() 291 | if p.requires_grad]).item()**0.5 292 | gparam_norm = sum([p.grad.norm()**2 for p in self.model.parameters() 293 | if p.requires_grad and p.grad is not None]).item()**0.5 294 | log_str = 'Epoch: %02d, Batch: %03d/%03d, |Param|: %.6f, |GParam|: %.2f, lr: %.2E, ' + \ 295 | 'loss: %.4f, recon: %.4f, kl: %.4f, log_det: %.4f, recon_bound: %.2f, Throughput: %.2f vid / sec' 296 | print(log_str % 297 | (epoch, batch_ix, len(loader), param_norm, gparam_norm, 298 | optimizer.param_groups[0]["lr"], 299 | (train_nll + train_kl + train_log_det) / num_videos, # loss 300 | train_nll / num_frames, # recon 301 | train_kl / num_frames, # kl 302 | train_log_det / num_videos, # log_det 303 | (train_nll + train_kl) / num_frames, # recon_bound 304 | num_videos / (time.time() - start_time))) # Throughput 305 | if self.args.max_grad_norm is not None: 306 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm) 307 | 308 | optimizer.step() 309 | self.model.zero_grad() 310 | train_loss = np.mean(losses) 311 | if scheduler is not None: 312 | scheduler.step(train_loss) 313 | callback_fn(epoch, {'train_loss': train_loss, 314 | 'train_nll_frame_avg': train_nll / num_frames, 315 | 'train_kl_vid_avg': train_kl / num_videos, 316 | 'train_recon_bound': (train_nll + train_kl) / num_frames}) 317 | 318 | def predict(self, test_data): 319 | self.model.eval() 320 | 
self.model.flatten_parameters() 321 | predictions = {} 322 | loader = make_data_loader(self.args, test_data, shuffle=False, batch_by_task=True, batch_size=self.args.batch_size) 323 | # print('{} videos in prediction data'.format(len(loader.dataset))) 324 | # for batch in tqdm.tqdm(loader, ncols=80): 325 | for batch in loader: 326 | features = batch['features'] 327 | task_indices = batch['task_indices'] 328 | lengths = batch['lengths'] 329 | 330 | # add a batch dimension 331 | # lengths = torch.LongTensor([features.size(0)]).unsqueeze(0) 332 | # features = features.unsqueeze(0) 333 | # task_indices = task_indices.unsqueeze(0) 334 | 335 | videos = batch['video_name'] 336 | tasks = batch['task_name'] 337 | assert len(set(tasks)) == 1 338 | task = next(iter(tasks)) 339 | 340 | if 'test' in self.args.sm_constrain_with_narration: 341 | assert all_equal(tasks) 342 | constraints_expanded = self.expand_constraints( 343 | test_data, task, task_indices[0], 1 - batch['constraints'] 344 | ) 345 | constraints_expanded *= self.args.sm_constrain_narration_weight 346 | else: 347 | constraints_expanded = None 348 | 349 | if self.args.cuda: 350 | features = features.cuda() 351 | task_indices = [ti.cuda() for ti in task_indices] 352 | lengths = lengths.cuda() 353 | if constraints_expanded is not None: 354 | constraints_expanded = constraints_expanded.cuda() 355 | 356 | addl_allowed_ends = self.make_additional_allowed_ends(tasks, lengths) 357 | 358 | def predict(constraints): 359 | # TODO: figure out under which eval conditions use_mean_z should be False 360 | pred_spans, elp = self.model.viterbi(features, lengths, task_indices, add_eos=True, use_mean_z=True, 361 | additional_allowed_ends_per_instance=addl_allowed_ends, 362 | constraints=constraints, return_elp=True) 363 | pred_labels = semimarkov_utils.spans_to_labels(pred_spans) 364 | # if self.args.sm_predict_single: 365 | # # pred_spans: batch_size x T 366 | # pred_labels_single = torch.zeros_like(pred_labels) 367 | # for i in pred_labels.size(0): 368 | # for lab in torch.unique(pred_labels[i,:lengths[i]]): 369 | # #emission_scores: b x N x C 370 | # pred_labels 371 | # pass 372 | 373 | # if self.args.sm_constrain_transitions: 374 | # all_pred_span_indices = [ 375 | # [ix for ix, count in this_rle_spans] 376 | # for this_rle_spans in semimarkov_utils.rle_spans(pred_spans, lengths) 377 | # ] 378 | # for i, indices in enumerate(all_pred_span_indices): 379 | # remove_cons_dups = [ix for ix, group in itertools.groupby(indices) 380 | # if not ix in test_data.corpus._background_indices] 381 | # non_bg_indices = [ 382 | # ix for ix in test_data.corpus.indices_by_task(task) 383 | # if ix not in test_data.corpus._background_indices 384 | # ] 385 | # if len(remove_cons_dups) != len(non_bg_indices) and lengths[i].item() != len(remove_cons_dups): 386 | # print("deduped: {}, indices: {}, length {}".format( 387 | # remove_cons_dups, non_bg_indices, lengths[i].item() 388 | # )) 389 | # # assert lengths[i].item() < len(non_bg_indices) 390 | 391 | pred_labels_trim_s = self.model.trim(pred_labels, lengths, check_eos=True) 392 | return pred_labels_trim_s 393 | 394 | pred_labels_trim_s = predict(constraints_expanded) 395 | 396 | # assert len(pred_labels_trim_s) == 1, "batch size should be 1" 397 | for ix, (video, pred_labels_trim) in enumerate(zip(videos, pred_labels_trim_s)): 398 | preds = pred_labels_trim.numpy() 399 | predictions[video] = preds 400 | # if constraints_expanded is not None: 401 | # this_cons = batch['constraints'][ix] 402 | # if this_cons.sum() > 0: 403 | # 
step_indices = test_data.get_ordered_indices_no_background()[task] 404 | # for t, label in enumerate(preds): 405 | # if label in step_indices: 406 | # label_ix = step_indices.index(label) 407 | # assert batch['constraints'][ix,t,label_ix] == 1 408 | assert self.model.n_classes not in predictions[video], "predictions should not contain EOS: {}".format( 409 | predictions[video]) 410 | return predictions 411 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pickle 4 | import pprint 5 | import sys 6 | from collections import OrderedDict 7 | import json 8 | 9 | import numpy as np 10 | 11 | from data.breakfast import BreakfastCorpus 12 | from data.corpus import Datasplit 13 | from data.crosstask import CrosstaskCorpus 14 | from models.framewise import FramewiseGaussianMixture, FramewiseDiscriminative, FramewiseBaseline 15 | from models.sequential import SequentialDiscriminative, SequentialCanonicalBaseline, SequentialPredictConstraints, SequentialGroundTruth 16 | from models.model import Model, add_training_args 17 | from models.semimarkov.semimarkov import SemiMarkovModel 18 | from utils.logger import logger 19 | 20 | STAT_KEYS = [ 21 | 'mof', 'mof_non_bg', 'step_recall_non_bg', 'mean_normed_levenshtein', 22 | 'center_step_recall_non_bg', 'f1', 'f1_non_bg', 'pred_background', 'iou_multi_non_bg', 23 | 'predicted_label_types_per_video', 'predicted_label_types_non_bg_per_video', 24 | 'predicted_segments_per_video', 'predicted_segments_non_bg_per_video', 25 | 'multiple_gt_labels', 26 | ] 27 | DISPLAY_STAT_KEYS = [ 28 | 'f1', 'f1_non_bg', 'center_step_recall_non_bg', 'mean_normed_levenshtein', 29 | 'pred_background', 'iou_multi_non_bg', 30 | 'predicted_label_types_per_video', 'predicted_label_types_non_bg_per_video', 31 | 'predicted_segments_per_video', 'predicted_segments_non_bg_per_video', 32 | 'mof', 'mof_non_bg', 33 | 'multiple_gt_labels', 34 | ] 35 | 36 | CLASSIFIERS = { 37 | 'framewise_discriminative': FramewiseDiscriminative, 38 | 'framewise_gaussian_mixture': FramewiseGaussianMixture, 39 | 'framewise_baseline': FramewiseBaseline, 40 | 'semimarkov': SemiMarkovModel, 41 | 'sequential_discriminative': SequentialDiscriminative, 42 | 'sequential_canonical_baseline': SequentialCanonicalBaseline, 43 | 'sequential_predict_constraints': SequentialPredictConstraints, 44 | 'sequential_ground_truth': SequentialGroundTruth, 45 | } 46 | 47 | 48 | def add_serialization_args(parser): 49 | group = parser.add_argument_group('serialization') 50 | group.add_argument('--model_output_path') 51 | group.add_argument('--model_input_path') 52 | group.add_argument('--prediction_output_path') 53 | 54 | 55 | def add_misc_args(parser): 56 | group = parser.add_argument_group('miscellaneous') 57 | group.add_argument('--compare_to_prediction_folder', help='root folder containing *_pred.npy and *_true.npy prediction files (for comparison)') 58 | group.add_argument('--compare_only', 59 | action='store_true', 60 | help="skip everything to do with models and just evaluate these serialized predictions") 61 | group.add_argument('--compare_load_splits_from_predictions', action='store_true') 62 | 63 | 64 | def add_data_args(parser): 65 | group = parser.add_argument_group('data') 66 | group.add_argument('--dataset', choices=['crosstask', 'breakfast'], default='crosstask') 67 | group.add_argument('--features', choices=['raw', 'pca'], default='pca') 68 | 
group.add_argument('--feature_downscale', type=float, default=1.0) 69 | group.add_argument('--feature_permutation_seed', type=int) 70 | group.add_argument('--batch_size', type=int, default=5) 71 | group.add_argument('--remove_background', action='store_true') 72 | group.add_argument('--pca_components_per_group', type=int, default=100) 73 | group.add_argument('--pca_no_background', action='store_true') 74 | 75 | group.add_argument('--mix_tasks', action='store_true', help='train on all tasks simultaneously') 76 | 77 | group.add_argument('--frame_subsample', type=int, default=1, help="interval to subsample frames at (e.g. 10 takes every 10th frame)") 78 | 79 | group.add_argument('--task_specific_steps', action='store_true', help="") 80 | group.add_argument('--annotate_background_with_previous', action='store_true', help="") 81 | 82 | group.add_argument('--no_merge_classes', action='store_true', help="") 83 | 84 | group.add_argument('--force_optimal_assignment', action='store_true', help="force optimal assignment to maximize MoF") 85 | 86 | group.add_argument('--no_cache_features', action='store_true', help="") 87 | 88 | group.add_argument('--crosstask_feature_groups', 89 | choices=['i3d', 'resnet', 'audio', 'narration'], 90 | nargs='+', default=['i3d', 'resnet', 'audio']) 91 | group.add_argument('--crosstask_training_data', choices=['primary', 'related'], nargs='+', default=['primary']) 92 | 93 | group.add_argument('--crosstask_cross_validation', action='store_true') 94 | # group.add_argument('--crosstask_cross_validation_n_train', type=int, default=30) 95 | group.add_argument('--crosstask_cross_validation_seed', type=int) 96 | 97 | 98 | def add_classifier_args(parser): 99 | group = parser.add_argument_group('classifier') 100 | group.add_argument('--classifier', choices=CLASSIFIERS.keys(), required=True) 101 | group.add_argument('--training', choices=['supervised', 'unsupervised'], default='supervised') 102 | group.add_argument('--cuda', action='store_true') 103 | for name, cls in CLASSIFIERS.items(): 104 | cls.add_args(parser) 105 | 106 | def write_predictions(test_data, predictions_by_video, output_path): 107 | # TODO: unuglify this 108 | for video, pred in predictions_by_video.items(): 109 | labels = [] 110 | task = test_data._tasks_by_video[video] 111 | for index in pred: 112 | if index in test_data._corpus._background_indices: 113 | label = "" 114 | else: 115 | label = test_data._corpus.index2label[index].replace(' ', '_') 116 | labels.append('{}:{}'.format(task, label)) 117 | with open(os.path.join(output_path, video), 'w') as f: 118 | f.write('### Recognized sequence: ###\n') 119 | f.write('\n') # TODO 120 | f.write('### Score: ###\n') 121 | f.write('\n') # TODO 122 | f.write('### Frame level recognition: ###\n') 123 | f.write(' '.join(labels)) 124 | 125 | def test(args, model: Model, test_data: Datasplit, test_data_name: str, verbose=True, prediction_output_path=None): 126 | if args.training == 'supervised': 127 | optimal_assignment = False 128 | else: 129 | assert args.training == 'unsupervised' 130 | # if we're constraining the transitions to be the canonical order in the semimarkov, we don't need oracle reassignment 131 | optimal_assignment = not (args.classifier == 'semimarkov' and args.sm_constrain_transitions) 132 | if 'train' in args.sm_constrain_with_narration or 'test' in args.sm_constrain_with_narration: 133 | optimal_assignment = False 134 | if args.force_optimal_assignment: 135 | optimal_assignment = True 136 | if model is not None: 137 | predictions_by_video = 
model.predict(test_data) 138 | prediction_function = lambda video: predictions_by_video[video.name] 139 | else: 140 | prediction_function = None 141 | # print('prediction_output_path: {}'.format(prediction_output_path)) 142 | if prediction_output_path is not None: 143 | assert model is not None 144 | write_predictions(test_data, predictions_by_video, prediction_output_path) 145 | stats = test_data.accuracy_corpus( 146 | optimal_assignment, 147 | prediction_function, 148 | prefix=test_data_name, 149 | verbose=verbose, 150 | compare_to_folder=args.compare_to_prediction_folder if not test_data_name.startswith('train') else None 151 | ) 152 | return stats 153 | 154 | 155 | def make_model_path(path, split_name): 156 | if path.endswith('.pkl'): 157 | return path 158 | else: 159 | # is directory 160 | return os.path.join(path, '{}.pkl'.format(split_name)) 161 | 162 | 163 | def train(args, train_data: Datasplit, dev_data: Datasplit, split_name, verbose=False, train_sub_data=None): 164 | model = CLASSIFIERS[args.classifier].from_args(args, train_data) 165 | 166 | if args.training == 'supervised': 167 | use_labels = True 168 | early_stopping_on_dev = True 169 | else: 170 | assert args.training == 'unsupervised' 171 | use_labels = False 172 | early_stopping_on_dev = False 173 | 174 | def evaluate_on_data(data, name): 175 | stats_by_name = test(args, model, data, name, verbose=verbose) 176 | 177 | # all_mof = np.array([stats['mof'] for stats in stats_by_name.values()]) 178 | # sum_mof = all_mof.sum(axis=0) 179 | # 180 | # all_mof_non_bg = np.array([stats['mof_non_bg'] for stats in stats_by_name.values()]) 181 | # sum_mof_non_bg = all_mof_non_bg.sum(axis=0) 182 | # 183 | # all_step_recall_non_bg = np.array([stats['step_recall_non_bg'] for stats in stats_by_name.values()]) 184 | # sum_step_recall_non_bg = all_step_recall_non_bg.sum(axis=0) 185 | # 186 | # all_leven = np.array([stats['mean_normed_levenshtein'] for stats in stats_by_name.values()]) 187 | # sum_leven = all_leven.sum(axis=0) 188 | 189 | d = {} 190 | for key in STAT_KEYS: 191 | all_stats = np.array([stats[key] for stats in stats_by_name.values()]) 192 | sum_stats = all_stats.sum(axis=0) 193 | d['{}_{}'.format(name, key)] = float(sum_stats[0]) / sum_stats[1] 194 | return d 195 | 196 | # return { 197 | # '{}_mof'.format(name): float(sum_mof[0]) / sum_mof[1], 198 | # '{}_mof_non_bg'.format(name): float(sum_mof_non_bg[0]) / sum_mof_non_bg[1], 199 | # '{}_step_recall_non_bg'.format(name): float(sum_step_recall_non_bg[0]) / sum_step_recall_non_bg[1], 200 | # '{}_mean_normed_levenshtein'.format(name): float(sum_leven[0]) / sum_leven[1], 201 | # } 202 | 203 | models_by_epoch = {} 204 | dev_mof_by_epoch = {} 205 | stats_by_epoch = {} 206 | 207 | def callback_fn(epoch, stats): 208 | stats_by_epoch[epoch] = stats 209 | if train_sub_data is not None: 210 | train_name = 'train_subset' 211 | train_stats = evaluate_on_data(train_sub_data, train_name) 212 | else: 213 | train_name = 'train' 214 | train_stats = evaluate_on_data(train_data, train_name) 215 | split_stats = [train_stats] 216 | if epoch == -1 or epoch % args.dev_decode_frequency == 0: 217 | dev_stats = evaluate_on_data(dev_data, 'dev') 218 | split_stats.append(dev_stats) 219 | else: 220 | dev_stats = None 221 | log_str = '{}\tepoch {:2d}'.format(split_name, epoch) 222 | for stat, value in stats.items(): 223 | if isinstance(value, float): 224 | log_str += '\t{} {:.4f}'.format(stat, value) 225 | else: 226 | log_str += '\t{} {}'.format(stat, value) 227 | # log_str += '\t{} '.format(train_name) 228 | 
for stats in split_stats: 229 | log_str += '\n' 230 | for name, val in sorted(stats.items()): 231 | log_str += ' {} {:.4f}'.format(name, val) 232 | # log_str += '\t{} mof {:.4f}\tdev mof {:.4f}'.format(train_name, train_mof, dev_mof) 233 | logger.debug(log_str) 234 | models_by_epoch[epoch] = pickle.dumps(model) 235 | 236 | if dev_stats is not None: 237 | dev_mof_by_epoch[epoch] = dev_stats['dev_mof'] 238 | 239 | if args.model_output_path and epoch % 5 == 0: 240 | os.makedirs(args.model_output_path, exist_ok=True) 241 | model_fname = os.path.join(args.model_output_path, '{}_epoch-{}.pkl'.format(split_name, epoch)) 242 | print("writing model to {}".format(model_fname)) 243 | with open(model_fname, 'wb') as f: 244 | pickle.dump(model, f) 245 | 246 | model.fit(train_data, use_labels=use_labels, callback_fn=callback_fn) 247 | 248 | if early_stopping_on_dev and dev_mof_by_epoch: 249 | best_dev_epoch, best_dev_mof = max(dev_mof_by_epoch.items(), key=lambda t: t[1]) 250 | logger.debug("best dev mov {:.4f} in epoch {}".format(best_dev_mof, best_dev_epoch)) 251 | best_model = pickle.loads(models_by_epoch[best_dev_epoch]) 252 | elif stats_by_epoch and 'train_loss' in next(iter(stats_by_epoch.values())): 253 | best_epoch, best_train_stats = min(stats_by_epoch.items(), key=lambda t: t[1]['train_loss']) 254 | logger.debug("best train loss {:.4f} in epoch {}".format(best_train_stats['train_loss'], best_epoch)) 255 | best_model = pickle.loads(models_by_epoch[best_epoch]) 256 | else: 257 | best_model = model 258 | 259 | if args.model_output_path: 260 | os.makedirs(args.model_output_path, exist_ok=True) 261 | model_fname = make_model_path(args.model_output_path, split_name) 262 | print("writing model to {}".format(model_fname)) 263 | with open(model_fname, 'wb') as f: 264 | pickle.dump(best_model, f) 265 | 266 | return best_model 267 | 268 | 269 | def make_data_splits(args): 270 | # split_name -> (train_data, test_data) 271 | splits = OrderedDict() 272 | 273 | if args.dataset == 'crosstask': 274 | features_contain_background = True 275 | if args.features == 'pca': 276 | max_components = 200 277 | assert args.pca_components_per_group <= max_components 278 | features_contain_background = not args.pca_no_background 279 | feature_root = 'data/crosstask/crosstask_processed/crosstask_primary_pca-{}_{}-bkg_by-task'.format( 280 | max_components, 281 | "no" if args.pca_no_background else "with", 282 | ) 283 | dimensions_per_feature_group = { 284 | feature_group: args.pca_components_per_group 285 | for feature_group in args.crosstask_feature_groups 286 | } 287 | else: 288 | feature_root = 'data/crosstask/crosstask_features' 289 | dimensions_per_feature_group = None 290 | 291 | corpus = CrosstaskCorpus( 292 | release_root="data/crosstask/crosstask_release", 293 | feature_root=feature_root, 294 | dimensions_per_feature_group=dimensions_per_feature_group, 295 | features_contain_background=features_contain_background, 296 | task_specific_steps=args.task_specific_steps, 297 | annotate_background_with_previous=args.annotate_background_with_previous, 298 | use_secondary='related' in args.crosstask_training_data, 299 | constraints_root='data/crosstask/crosstask_constraints', 300 | load_constraints=True, 301 | ) 302 | corpus._cache_features = True 303 | if args.no_cache_features: 304 | corpus._cache_features = False 305 | train_task_sets = args.crosstask_training_data 306 | 307 | 308 | test_task_sets = ['primary'] 309 | task_ids = sorted([task_id for task_set in sorted(set(train_task_sets) | set(test_task_sets)) 310 | for 
task_id in CrosstaskCorpus.TASK_IDS_BY_SET[task_set]]) 311 | if args.crosstask_cross_validation: 312 | if train_task_sets != ['primary']: 313 | raise NotImplementedError("cross validation with related tasks") 314 | split_names_and_full = [ 315 | ('cv_train_{}'.format(args.crosstask_cross_validation_seed), True, train_task_sets), 316 | ('cv_train_{}'.format(args.crosstask_cross_validation_seed), False, train_task_sets), 317 | ('cv_test_{}'.format(args.crosstask_cross_validation_seed), True, train_task_sets), 318 | ] 319 | else: 320 | split_names_and_full = [ 321 | ('train', True, train_task_sets), 322 | ('train', False, test_task_sets), 323 | ('val', True, test_task_sets) 324 | ] 325 | if args.compare_load_splits_from_predictions: 326 | assert args.compare_to_prediction_folder 327 | assert args.compare_only 328 | assert not args.crosstask_cross_validation, "just pass --compare_to_prediction_folder, --compare_only, and --compare_load_splits_from_predictions" 329 | with open(os.path.join(args.compare_to_prediction_folder, 'y_pred.json'), 'rb') as f: 330 | preds_by_task_and_video = json.load(f) 331 | val_videos_override = [] 332 | for task, data in preds_by_task_and_video.items(): 333 | val_videos_override.extend(data.keys()) 334 | print("loaded predictions for {} videos; using as the validation set".format(len(val_videos_override))) 335 | else: 336 | val_videos_override = None 337 | 338 | # TODO: here 339 | if args.mix_tasks: 340 | splits['all'] = tuple( 341 | corpus.get_datasplit(remove_background=args.remove_background, 342 | task_sets=task_sets, 343 | task_ids=task_ids, 344 | split=split, 345 | full=full, 346 | subsample=args.frame_subsample, 347 | feature_downscale=args.feature_downscale, 348 | val_videos_override=val_videos_override, 349 | feature_permutation_seed=args.feature_permutation_seed, 350 | ) 351 | for split, full, task_sets in split_names_and_full 352 | ) 353 | train_videos = set(p[1] for p in splits['all'][0]._tasks_and_video_names) 354 | test_videos = set(p[1] for p in splits['all'][2]._tasks_and_video_names) 355 | assert not(train_videos & test_videos),\ 356 | "overlap in train and test videos: {}".format(train_videos & test_videos) 357 | else: 358 | for task_id in task_ids: 359 | splits['{}_val'.format(task_id)] = tuple( 360 | corpus.get_datasplit(remove_background=args.remove_background, 361 | task_sets=task_sets, 362 | task_ids=[task_id], 363 | split=split, 364 | full=full, 365 | subsample=args.frame_subsample, 366 | feature_downscale=args.feature_downscale, 367 | val_videos_override=val_videos_override, 368 | feature_permutation_seed=args.feature_permutation_seed, 369 | ) 370 | for split, full, task_sets in split_names_and_full 371 | ) 372 | elif args.dataset == 'breakfast': 373 | assert not args.annotate_background_with_previous 374 | if args.features == 'pca': 375 | max_components = 64 376 | assert args.pca_components_per_group == max_components 377 | features_contain_background = not args.pca_no_background 378 | assert features_contain_background # not implemented! 
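# When --pca_no_background is not set, this resolves to
# 'data/breakfast/breakfast_processed/breakfast_pca-64_with-bkg_by-task',
# matching the output root produced by the __main__ block in src/data/breakfast.py.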
379 | feature_root = 'data/breakfast/breakfast_processed/breakfast_pca-{}_{}-bkg_by-task'.format( 380 | max_components, 381 | "no" if args.pca_no_background else "with", 382 | ) 383 | else: 384 | feature_root = 'data/breakfast/reduced_fv_64' 385 | corpus = BreakfastCorpus(mapping_file='data/breakfast/mapping.txt', 386 | feature_root=feature_root, 387 | label_root='data/breakfast/BreakfastII_15fps_qvga_sync', 388 | task_specific_steps=args.task_specific_steps) 389 | corpus._cache_features = True 390 | 391 | all_splits = list(sorted(BreakfastCorpus.DATASPLITS.keys())) 392 | for heldout_split in all_splits: 393 | splits[heldout_split] = ( 394 | corpus.get_datasplit(remove_background=args.remove_background, 395 | splits=[sp for sp in all_splits if sp != heldout_split], 396 | full=True, 397 | subsample=args.frame_subsample, 398 | feature_downscale=args.feature_downscale, 399 | feature_permutation_seed=args.feature_permutation_seed, 400 | ), 401 | corpus.get_datasplit(remove_background=args.remove_background, 402 | splits=[sp for sp in all_splits if sp != heldout_split], 403 | full=True, 404 | subsample=args.frame_subsample, 405 | feature_downscale=args.feature_downscale, 406 | feature_permutation_seed=args.feature_permutation_seed, 407 | ), # has issue with some tasks being dropped if we pass full=False 408 | corpus.get_datasplit(remove_background=args.remove_background, 409 | splits=[heldout_split], 410 | full=True, 411 | subsample=args.frame_subsample, 412 | feature_downscale=args.feature_downscale, 413 | feature_permutation_seed=args.feature_permutation_seed, 414 | ), 415 | ) 416 | else: 417 | raise NotImplementedError("invalid dataset {}".format(args.dataset)) 418 | 419 | return splits 420 | 421 | 422 | if __name__ == "__main__": 423 | parser = argparse.ArgumentParser(fromfile_prefix_chars='@') 424 | add_serialization_args(parser) 425 | add_data_args(parser) 426 | add_classifier_args(parser) 427 | add_training_args(parser) 428 | add_misc_args(parser) 429 | args = parser.parse_args() 430 | 431 | print(' '.join(sys.argv)) 432 | 433 | pprint.pprint(vars(args)) 434 | 435 | stats_by_split_and_task = {} 436 | 437 | stats_by_split_by_task = {} 438 | 439 | for split_name, (train_data, train_sub_data, test_data) in make_data_splits(args).items(): 440 | print(split_name) 441 | if args.compare_only: 442 | assert args.compare_to_prediction_folder 443 | model = None 444 | else: 445 | if args.model_input_path: 446 | model_path = make_model_path(args.model_input_path, split_name) 447 | print("loading model from {}".format(model_path)) 448 | with open(model_path, 'rb') as f: 449 | model = pickle.load(f) 450 | if vars(args) != vars(model.args): 451 | print("warning: command line args and serialized model args differ:") 452 | cmd_d = vars(args) 453 | ser_d = vars(model.args) 454 | for key in set(cmd_d) | set(ser_d): 455 | if key == 'model_input_path' or key == 'model_output_path': 456 | continue 457 | if key not in ser_d or key not in cmd_d or ser_d[key] != cmd_d[key]: 458 | print("{}: {} != {}".format(key, cmd_d.get(key, ""), ser_d.get(key, ""))) 459 | 460 | print("setting model args to serialized args") 461 | model.args = args 462 | try: 463 | model.model.eval() 464 | if args.cuda: 465 | model.model.cuda() 466 | else: 467 | model.model.cpu() 468 | except Exception as e: 469 | print(e) 470 | 471 | else: 472 | model = train(args, train_data, test_data, split_name, train_sub_data=train_sub_data) 473 | 474 | print('split_name: {}'.format(split_name)) 475 | # prediction_output_path = args.prediction_output_path 
if 'val' in split_name else None 476 | prediction_output_path = args.prediction_output_path 477 | 478 | stats_by_task = test(args, model, test_data, split_name, prediction_output_path=prediction_output_path) 479 | stats_by_split_by_task[split_name] = {} 480 | for task, stats in stats_by_task.items(): 481 | stats_by_split_and_task["{}_{}".format(split_name, task)] = stats 482 | stats_by_split_by_task[split_name][task] = stats 483 | print() 484 | 485 | 486 | def divide(d): 487 | divided = {} 488 | for key, vals in d.items(): 489 | assert len(vals) == 2 490 | divided[key] = float(vals[0]) / vals[1] 491 | return divided 492 | 493 | 494 | print() 495 | pprint.pprint(stats_by_split_and_task) 496 | 497 | print() 498 | pprint.pprint({k: divide(d) for k, d in stats_by_split_and_task.items()}) 499 | 500 | summed_across_tasks = {} 501 | divided_averaged_across_tasks = {} 502 | 503 | sum_within_split_averaged_across_splits = {} 504 | 505 | for key in next(iter(stats_by_split_and_task.values())): 506 | arrs = np.array([d[key] for d in stats_by_split_and_task.values()]) 507 | summed_across_tasks[key] = np.sum(arrs, axis=0) 508 | 509 | divided_averaged_across_tasks[key] = np.mean([ 510 | divide(d)[key] for d in stats_by_split_and_task.values() 511 | ]) 512 | 513 | print() 514 | 515 | summed_across_tasks_divided = divide(summed_across_tasks) 516 | 517 | print("summed across tasks:") 518 | pprint.pprint(summed_across_tasks_divided) 519 | print() 520 | print("averaged across tasks:") 521 | pprint.pprint(divided_averaged_across_tasks) 522 | print() 523 | # print("averaged across splits:") 524 | # pprint.pprint(sum_within_split_averaged_across_splits) 525 | 526 | stat_dict = divided_averaged_across_tasks 527 | 528 | print(', '.join(STAT_KEYS)) 529 | print(', '.join('{:.4f}'.format(stat_dict[key]) for key in STAT_KEYS)) 530 | 531 | print(', '.join(DISPLAY_STAT_KEYS)) 532 | print(', '.join('{:.4f}'.format(stat_dict[key]) for key in DISPLAY_STAT_KEYS)) 533 | 534 | if any(stat.startswith('compare_') for stat in stat_dict): 535 | compare_keys = ['comparison_{}'.format(key) for key in DISPLAY_STAT_KEYS] 536 | print(', '.join(compare_keys)) 537 | print(', '.join('{:.4f}'.format(stat_dict[key]) for key in compare_keys)) 538 | -------------------------------------------------------------------------------- /src/evaluation/accuracy.py: -------------------------------------------------------------------------------- 1 | # modified from slim_mallow by Anna Kukleva, https://github.com/Annusha/slim_mallow 2 | 3 | import pprint 4 | from collections import defaultdict, Counter 5 | import editdistance 6 | 7 | 8 | import numpy as np 9 | from scipy.optimize import linear_sum_assignment 10 | 11 | from utils.logger import logger 12 | 13 | 14 | def singleton_lookup(dictionary, label): 15 | assert label in dictionary, "{} not in {}".format(label, dictionary) 16 | # this should be a singleton unless 'max' was used for optimization 17 | values = dictionary[label] 18 | assert len(values) == 1 19 | return next(iter(values)) 20 | 21 | def run_length_encode(labels): 22 | rle = [] 23 | current_label = None 24 | count = 0 25 | for label in labels: 26 | if current_label is None or label != current_label: 27 | if current_label is not None: 28 | assert count > 0 29 | rle.append((current_label, count)) 30 | count = 0 31 | current_label = label 32 | count += 1 33 | if current_label is not None: 34 | assert count > 0 35 | rle.append((current_label, count)) 36 | assert sum(count for sym, count in rle) == len(labels) 37 | return rle 38 | 39 | 
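# A minimal illustration (hypothetical labels) of run_length_encode above: consecutive
# repeats collapse into (label, run_length) pairs, and these per-video RLE sequences are
# what Accuracy.levenshtein() below compares.
#
#   run_length_encode(['bkg', 'bkg', 'pour', 'pour', 'pour', 'bkg'])
#   == [('bkg', 2), ('pour', 3), ('bkg', 1)]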
class Accuracy(object): 40 | """ Implementation of evaluation metrics for unsupervised learning. 41 | 42 | Since it's unsupervised learning relations between ground truth labels 43 | and output segmentation should be found. 44 | Hence the Hungarian method was used and labeling which gives us 45 | the best score is used as a result. 46 | """ 47 | def __init__(self, n_frames=1, verbose=True, corpus=None): 48 | """ 49 | Args: 50 | n_frames: frequency of sampling, 51 | in case of it's equal to 1 => dense sampling 52 | """ 53 | self._n_frames = n_frames 54 | self._reset() 55 | 56 | self._corpus = corpus 57 | 58 | 59 | self._predicted_rle_per_video = [] # : List[List[Tuple[label, length]]], one entry per video 60 | self._gt_rle_per_video = [] # : List[List[Tuple[label, length]]], one entry per video 61 | 62 | self._predicted_labels_per_video = [] 63 | self._gt_labels_per_video = [] 64 | self._gt_labels_multi_per_video = [] 65 | 66 | self._predicted_labels = None 67 | self._gt_labels_subset = None 68 | self._gt_labels = None 69 | self._gt_labels_multi = None 70 | self._boundaries = None 71 | # all frames used for alg without any subsampling technique 72 | self._indices = None 73 | 74 | self._frames_overall = 0 75 | 76 | self._true_background_frames = None 77 | self._pred_background_frames = None 78 | 79 | self._frames_true_pr = 0 80 | self._average_score = 0 81 | self._processed_number = 0 82 | # self._classes_precision = {} 83 | self._precision = None 84 | self._recall = None 85 | self._precision_without_bg = None 86 | self._recall_without_bg = None 87 | self._precision = None 88 | self._recall = None 89 | 90 | self._multiple_labels = None 91 | 92 | self._classes_recall = {} 93 | self._classes_MoF = {} 94 | self._classes_IoU = {} 95 | self._non_bg_IoU_multi = None 96 | # keys - gt, values - pr 97 | self.exclude = {} 98 | 99 | self._classes_levenshtein = {} 100 | self._classes_step_recall = {} 101 | 102 | self._logger = logger 103 | self._return = {} 104 | 105 | self._verbose = verbose 106 | 107 | def _reset(self): 108 | self._n_clusters = 0 109 | 110 | self._gt_label2index = {} 111 | self._gt_index2label = {} 112 | self._pr_label2index = {} 113 | self._pr_index2label = {} 114 | 115 | self._voting_table = [] 116 | self._gt2cluster = defaultdict(list) 117 | self._acc_per_gt_class = {} 118 | 119 | self.exclude = {} 120 | 121 | def _single_timestep_gt_labels(self, labels): 122 | # get a single label per timestep 123 | # should be nested list 124 | assert isinstance(labels, list) and isinstance(labels[0], list) 125 | # can have multiple GT labels per timestep, so we need to take just one per timestep 126 | return [lab_t[0] for lab_t in labels] 127 | 128 | def _add_labels(self, labels, is_predicted: bool): 129 | if is_predicted: 130 | rle = run_length_encode(labels) 131 | self._predicted_labels = None 132 | self._predicted_labels_per_video.append(labels) 133 | self._predicted_rle_per_video.append(rle) 134 | else: 135 | # ground truth can have multiple labels per timestep; deduplicate 136 | labels_single = self._single_timestep_gt_labels(labels) 137 | rle_single = run_length_encode(labels_single) 138 | self._gt_labels = None 139 | self._gt_labels_multi = None 140 | self._gt_labels_subset = None 141 | self._indices = None 142 | self._gt_labels_per_video.append(labels_single) 143 | self._gt_labels_multi_per_video.append(labels) 144 | self._gt_rle_per_video.append(rle_single) 145 | 146 | def add_gt_labels(self, labels): 147 | self._add_labels(labels, is_predicted=False) 148 | 149 | def 
add_predicted_labels(self, labels): 150 | self._add_labels(labels, is_predicted=True) 151 | 152 | # @property 153 | # def predicted_labels(self): 154 | # return self._predicted_labels 155 | # 156 | # @predicted_labels.setter 157 | # def predicted_labels(self, labels): 158 | # self._predicted_labels = np.array(labels) 159 | # self._reset() 160 | # 161 | # @property 162 | # def gt_labels(self): 163 | # return self._gt_labels_subset 164 | # 165 | # @gt_labels.setter 166 | # def gt_labels(self, labels): 167 | # # should be nested list 168 | # assert isinstance(labels, list) and isinstance(labels[0], list) 169 | # labels = [lab_t[0] for lab_t in labels] 170 | # self._gt_labels = np.array(labels) 171 | # self._gt_labels_subset = self._gt_labels[:] 172 | # self._indices = list(range(len(self._gt_labels))) 173 | 174 | def _set_gt_labels(self): 175 | labels = [x for xs in self._gt_labels_per_video for x in xs] 176 | labels_multi = [x for xs in self._gt_labels_multi_per_video for x in xs] 177 | self._gt_labels = np.array(labels) 178 | self._gt_labels_subset = self._gt_labels[:] 179 | self._gt_labels_multi = labels_multi 180 | assert len(labels) == len(labels_multi) 181 | self._indices = list(range(len(self._gt_labels))) 182 | 183 | def _set_predicted_labels(self): 184 | labels = [x for xs in self._predicted_labels_per_video for x in xs] 185 | self._predicted_labels = np.array(labels) 186 | 187 | @property 188 | def gt_labels(self): 189 | if self._gt_labels is None: 190 | self._set_gt_labels() 191 | return self._gt_labels 192 | 193 | @property 194 | def gt_labels_multi(self): 195 | if self._gt_labels_multi is None: 196 | self._set_gt_labels() 197 | return self._gt_labels_multi 198 | 199 | @property 200 | def gt_labels_subset(self): 201 | if self._gt_labels_subset is None: 202 | self._set_gt_labels() 203 | return self._gt_labels_subset 204 | 205 | @property 206 | def indices(self): 207 | if self._indices is None: 208 | self._set_gt_labels() 209 | return self._indices 210 | 211 | @property 212 | def predicted_labels(self): 213 | if self._predicted_labels is None: 214 | self._set_predicted_labels() 215 | return self._predicted_labels 216 | 217 | # @property 218 | # def params(self): 219 | # """ 220 | # boundaries: if frames samples from segments we need to know boundaries 221 | # of these segments to fulfill them after 222 | # indices: frames extracted for whatever and indeed evaluation 223 | # """ 224 | # return self._boundaries, self._indices 225 | # 226 | # @params.setter 227 | # def params(self, params): 228 | # self._boundaries = params[0] 229 | # self._indices = params[1] 230 | # self._gt_labels_subset = self._gt_labels[self._indices] 231 | 232 | def _create_voting_table(self): 233 | """Filling table with assignment scores. 234 | 235 | Create table which represents paired label assignments, i.e. 
each 236 | cell comprises score for corresponding label assignment""" 237 | size = max(len(np.unique(self.gt_labels_subset)), 238 | len(np.unique(self.predicted_labels))) 239 | self._voting_table = np.zeros((size, size)) 240 | 241 | for idx_gt, gt_label in enumerate(np.unique(self.gt_labels_subset)): 242 | self._gt_label2index[gt_label] = idx_gt 243 | self._gt_index2label[idx_gt] = gt_label 244 | 245 | if len(self._gt_label2index) < size: 246 | for idx_gt in range(len(np.unique(self.gt_labels_subset)), size): 247 | gt_label = idx_gt 248 | while gt_label in self._gt_label2index: 249 | gt_label += 1 250 | self._gt_label2index[gt_label] = idx_gt 251 | self._gt_index2label[idx_gt] = gt_label 252 | 253 | for idx_pr, pr_label in enumerate(np.unique(self.predicted_labels)): 254 | self._pr_label2index[pr_label] = idx_pr 255 | self._pr_index2label[idx_pr] = pr_label 256 | 257 | if len(self._pr_label2index) < size: 258 | for idx_pr in range(len(np.unique(self.predicted_labels)), size): 259 | pr_label = idx_pr 260 | while pr_label in self._pr_label2index: 261 | pr_label += 1 262 | self._pr_label2index[pr_label] = idx_pr 263 | self._pr_index2label[idx_pr] = pr_label 264 | 265 | for idx_gt, gt_label in enumerate(np.unique(self.gt_labels_subset)): 266 | if gt_label in list(self.exclude.keys()): 267 | continue 268 | gt_mask = self.gt_labels_subset == gt_label 269 | for idx_pr, pr_label in enumerate(np.unique(self.predicted_labels)): 270 | if pr_label in list(self.exclude.values()): 271 | continue 272 | self._voting_table[idx_gt, idx_pr] = \ 273 | np.sum(self.predicted_labels[gt_mask] == pr_label, dtype=float) 274 | for key, val in self.exclude.items(): 275 | # works only if one pair in exclude 276 | assert len(self.exclude) == 1 277 | try: 278 | self._voting_table[self._gt_label2index[key], self._pr_label2index[val[0]]] = size * np.max(self._voting_table) 279 | except KeyError: 280 | logger.debug('No background!') 281 | self._voting_table[self._gt_label2index[key], -1] = size * np.max(self._voting_table) 282 | self._pr_index2label[size - 1] = val[0] 283 | self._pr_label2index[val[0]] = size - 1 284 | 285 | def _create_correspondences(self, method='hungarian', optimization='max'): 286 | """ Find output labels which correspond to ground truth labels. 287 | 288 | Hungarian method finds one-to-one mapping: if there is squared matrix 289 | given, then for each output label -> gt label. If not, some labels will 290 | be without correspondences. 
291 | Args: 292 | method: hungarian or max 293 | optimization: for hungarian method usually min problem but here 294 | is max, hence convert to min 295 | where: if some actions are not in the video collection anymore 296 | """ 297 | if method == 'hungarian': 298 | try: 299 | assert self._voting_table.shape[0] == self._voting_table.shape[1] 300 | except AssertionError: 301 | self._logger.debug('voting table non squared') 302 | raise AssertionError('bum tss') 303 | if optimization == 'max': 304 | # convert max problem to minimization problem 305 | self._voting_table *= -1 306 | x, y = linear_sum_assignment(self._voting_table) 307 | for idx_gt, idx_pr in zip(x, y): 308 | self._gt2cluster[self._gt_index2label[idx_gt]] = [self._pr_index2label[idx_pr]] 309 | elif method == 'max': 310 | # maximum voting, won't create exactly one-to-one mapping 311 | max_responses = np.argmax(self._voting_table, axis=0) 312 | for idx, c in enumerate(max_responses): 313 | # c is index of gt label 314 | # idx is predicted cluster label 315 | self._gt2cluster[self._gt_index2label[c]].append(idx) 316 | elif method == 'identity': 317 | for label in np.unique(self.gt_labels_subset): 318 | self._gt2cluster[label] = [label] 319 | 320 | def _fulfill_segments_nondes(self, boundaries, predicted_labels, n_frames): 321 | full_predicted_labels = [] 322 | for idx, slice in enumerate(range(0, len(predicted_labels), n_frames)): 323 | start, end = boundaries[idx] 324 | label_counter = Counter(predicted_labels[slice: slice + n_frames]) 325 | win_label = label_counter.most_common(1)[0][0] 326 | full_predicted_labels += [win_label] * (end - start + 1) 327 | return np.asarray(full_predicted_labels) 328 | 329 | def _fulfill_segments(self): 330 | """If was used frame sampling then anyway we need to get assignment 331 | for each frame""" 332 | self._full_predicted_labels = self._fulfill_segments_nondes(self._boundaries, self.predicted_labels, self._n_frames) 333 | 334 | def compute_assignment(self, optimal_assignment: bool, optimization='max', possible_gt_labels=None): 335 | self._n_clusters = len(np.unique(self.predicted_labels)) 336 | if optimal_assignment: 337 | self._create_voting_table() 338 | self._create_correspondences(method='hungarian', optimization=optimization) 339 | else: 340 | self._create_correspondences(method='identity') 341 | 342 | if possible_gt_labels is None: 343 | possible_gt_labels = np.unique(self.gt_labels_subset) 344 | 345 | num_gt_labels = len(possible_gt_labels) 346 | num_pr_labels = len(np.unique(self.predicted_labels)) 347 | 348 | assert num_pr_labels <= num_gt_labels, "gt_labels: {}, pred_labels: {}".format( 349 | possible_gt_labels, 350 | np.unique(self.predicted_labels), 351 | ) 352 | 353 | if self._verbose: 354 | self._logger.debug('# gt_labels: %d # pr_labels: %d' % 355 | (num_gt_labels, 356 | num_pr_labels)) 357 | self._logger.debug('Correspondences: segmentation to gt : ' 358 | + str([('%d: %d' % (value[0], key)) for (key, value) in 359 | sorted(self._gt2cluster.items(), key=lambda x: x[-1]) 360 | if len(value) > 0 361 | ])) 362 | return 363 | 364 | def levenshtein(self, gt2cluster=None): 365 | if gt2cluster is None: 366 | gt2cluster = self._gt2cluster 367 | levenshteins = [] 368 | max_num_segments = [] 369 | 370 | predicted_segments = 0.0 371 | predicted_segments_non_bg = 0.0 372 | 373 | num_videos = 0 374 | 375 | assert len(self._predicted_labels_per_video) == len(self._gt_labels_per_video) 376 | background_remapped_labels = set(singleton_lookup(gt2cluster, label) 377 | for label in 
self._corpus._background_indices 378 | if len(gt2cluster[label]) > 0) 379 | for ix, (gt_rle, pred_rle) in enumerate(zip(self._gt_rle_per_video, self._predicted_rle_per_video)): 380 | num_videos += 1 381 | assert sum(length for _, length in gt_rle) == sum(length for _, length in pred_rle) 382 | gt_remapped_segments = [singleton_lookup(gt2cluster, label) for (label, length) in gt_rle] 383 | pred_segments = [label for (label, length) in pred_rle] 384 | # self._logger.debug("{}: \n\tpred: {}\n\tgold:{}".format(ix, pred_segments, gt_remapped_segments)) 385 | predicted_segments += len(pred_segments) 386 | predicted_segments_non_bg += len([seg_label for seg_label in pred_segments if seg_label not in background_remapped_labels]) 387 | levenshteins.append(editdistance.eval(gt_remapped_segments, pred_segments)) 388 | max_num_segments.append(max(len(gt_remapped_segments), len(pred_segments))) 389 | 390 | levenshteins = np.array(levenshteins) 391 | max_num_segments = np.array(max_num_segments) 392 | 393 | assert np.all(max_num_segments > 0) 394 | 395 | results = { 396 | 'mean_levenshtein': np.array([np.mean(levenshteins), 1.0]), 397 | 'mean_max_segments': np.array([np.mean(max_num_segments), 1.0]), 398 | 'total_levenshtein': np.array([np.sum(levenshteins), 1.0]), 399 | 'num_videos': np.array([len(levenshteins), 1.0]), 400 | 'mean_normed_levenshtein': np.array([np.mean(levenshteins / max_num_segments), 1.0]), 401 | 'predicted_segments_per_video': np.array([predicted_segments, num_videos]), 402 | 'predicted_segments_non_bg_per_video': np.array([predicted_segments_non_bg, num_videos]), 403 | } 404 | if self._verbose: 405 | logger.debug("Levenshtein stats") 406 | for k, v in results.items(): 407 | logger.debug("{}: {}".format(k, v)) 408 | self._return.update(results) 409 | 410 | def single_step_recall(self, gt2cluster=None): 411 | if gt2cluster is None: 412 | gt2cluster = self._gt2cluster 413 | 414 | step_match = 0.0 415 | step_total = 0.0 416 | non_background_step_match = 0.0 417 | non_background_step_total = 0.0 418 | 419 | center_step_match = 0.0 420 | non_background_center_step_match = 0.0 421 | 422 | predicted_label_types = 0.0 423 | predicted_label_types_non_bg = 0.0 424 | num_videos = 0.0 425 | 426 | assert len(self._predicted_labels_per_video) == len(self._gt_labels_per_video) 427 | background_remapped_labels = set(singleton_lookup(gt2cluster, label) 428 | for label in self._corpus._background_indices 429 | if len(gt2cluster[label]) > 0) 430 | 431 | for gt_labels, pred_labels in zip(self._gt_labels_per_video, self._predicted_labels_per_video): 432 | num_videos += 1 433 | pred_labels = np.asarray(pred_labels) 434 | background_timesteps = [lab in self._corpus._background_indices for lab in gt_labels] 435 | gt_labels_remapped = np.asarray([gt2cluster[gt_label] for gt_label in gt_labels]) 436 | 437 | for label in np.unique(pred_labels): 438 | predicted_label_types += 1 439 | if label not in background_remapped_labels: 440 | predicted_label_types_non_bg += 1 441 | 442 | for label in np.unique(gt_labels_remapped): 443 | step_total += 1 444 | if label not in background_remapped_labels: 445 | non_background_step_total += 1 446 | pred_indices = (pred_labels == label).nonzero()[0] 447 | if len(pred_indices) == 0: 448 | continue 449 | pred_index = np.random.choice(pred_indices) 450 | # center_index = pred_indices[len(pred_indices) // 2] 451 | center_index = min(pred_indices, key=lambda x:abs(x-(pred_indices[0]+pred_indices[-1])/2)) 452 | if gt_labels_remapped[pred_index] == label: 453 | step_match += 1 454 | 
if label not in background_remapped_labels: 455 | non_background_step_match += 1 456 | if gt_labels_remapped[center_index] == label: 457 | center_step_match += 1 458 | if label not in background_remapped_labels: 459 | non_background_center_step_match += 1 460 | results = ({ 461 | 'single_step_recall': np.array([step_match, step_total]), 462 | 'step_recall_non_bg': np.array([non_background_step_match, non_background_step_total]), 463 | 'center_step_recall': np.array([center_step_match, step_total]), 464 | 'center_step_recall_non_bg': np.array([non_background_center_step_match, non_background_step_total]), 465 | 'predicted_label_types_per_video': np.array([predicted_label_types, num_videos]), 466 | 'predicted_label_types_non_bg_per_video': np.array([predicted_label_types_non_bg, num_videos]), 467 | }) 468 | if self._verbose: 469 | logger.debug("Single step recall stats") 470 | for k, v in results.items(): 471 | logger.debug("{}: {}".format(k, v)) 472 | self._return.update(results) 473 | 474 | 475 | def mof(self, optimal_assignment: bool, with_segments=False, optimization='max', possible_gt_labels=None): 476 | """ Compute mean over frames (MoF) for current labeling. 477 | 478 | Args: 479 | optimal_assignment: use hungarian to maximize MoF? 480 | with_segments: if frame sampling was used 481 | optimization: inside hungarian method 482 | where: see _create_correspondences method 483 | 484 | Returns: 485 | 486 | """ 487 | self.compute_assignment(optimal_assignment, optimization=optimization, possible_gt_labels=possible_gt_labels) 488 | if with_segments: 489 | self._fulfill_segments() 490 | else: 491 | self._full_predicted_labels = self.predicted_labels 492 | 493 | background_clusters = [self._gt2cluster[label] for label in self._corpus._background_indices] 494 | 495 | self._classes_MoF = {} 496 | self._classes_IoU = {} 497 | excluded_total = 0 498 | if self._verbose: 499 | logger.debug("exclude: {}".format(self.exclude)) 500 | for gt_label in np.unique(self.gt_labels): 501 | true_defined_frame_n = 0. 502 | union = 0 503 | gt_mask = self.gt_labels == gt_label 504 | # no need the loop since only one label should be here 505 | # i.e. 
one-to-one mapping, but i'm lazy 506 | predicted = 0 507 | for cluster in self._gt2cluster[gt_label]: 508 | true_defined_frame_n += np.sum(self._full_predicted_labels[gt_mask] == cluster, 509 | dtype=float) 510 | pr_mask = self._full_predicted_labels == cluster 511 | union += np.sum(gt_mask | pr_mask) 512 | predicted += np.sum(pr_mask) 513 | 514 | self._classes_MoF[gt_label] = [true_defined_frame_n, np.sum(gt_mask)] 515 | self._classes_IoU[gt_label] = [true_defined_frame_n, union] 516 | # self._classes_precision[gt_label] = [true_defined_frame_n, predicted] 517 | 518 | if gt_label in self.exclude: 519 | excluded_total += np.sum(gt_mask) 520 | else: 521 | self._frames_true_pr += true_defined_frame_n 522 | 523 | assert len(self.gt_labels_multi) == len(self._full_predicted_labels) 524 | 525 | self._precision = np.zeros(2) 526 | self._recall = np.zeros(2) 527 | 528 | self._precision_without_bg = np.zeros(2) 529 | self._recall_without_bg = np.zeros(2) 530 | 531 | self._true_background_frames = np.zeros(2) 532 | self._pred_background_frames = np.zeros(2) 533 | 534 | self._non_bg_IoU_multi = np.zeros(2) 535 | 536 | self._multiple_labels = np.zeros(2) 537 | 538 | for gt_labels_t, pred_label_t in zip(self.gt_labels_multi, self._full_predicted_labels): 539 | self._multiple_labels[1] += 1 540 | if len(gt_labels_t) > 1: 541 | self._multiple_labels[0] += 1 542 | gt_clusters_t = [self._gt2cluster[gt_label] for gt_label in gt_labels_t] 543 | self._recall[1] += len(gt_labels_t) 544 | self._precision[1] += 1 545 | 546 | true_positive = pred_label_t in gt_clusters_t 547 | if true_positive: 548 | self._recall[0] += 1 549 | self._precision[0] += 1 550 | 551 | self._true_background_frames[1] += 1 552 | self._pred_background_frames[1] += 1 553 | 554 | pred_background = False 555 | if pred_label_t in background_clusters: 556 | self._pred_background_frames[0] += 1 557 | pred_background = True 558 | 559 | is_background = False 560 | if any(gt_label in self._corpus._background_indices for gt_label in gt_labels_t): 561 | is_background = True 562 | assert all(gt_label in self._corpus._background_indices for gt_label in gt_labels_t) 563 | 564 | if (not is_background) or (not pred_background): 565 | self._non_bg_IoU_multi[1] += 1 566 | if true_positive: 567 | self._non_bg_IoU_multi[0] += 1 568 | 569 | if is_background: 570 | self._true_background_frames[0] += 1 571 | else: 572 | self._recall_without_bg[1] += len(gt_labels_t) 573 | self._precision_without_bg[1] += 1 574 | if pred_label_t in gt_clusters_t: 575 | self._recall_without_bg[0] += 1 576 | self._precision_without_bg[0] += 1 577 | 578 | self._frames_overall = len(self.gt_labels) - excluded_total 579 | return self._frames_overall 580 | 581 | def mof_classes(self): 582 | average_class_mof = 0 583 | total_true = 0 584 | total = 0 585 | 586 | average_class_mof_non_bkg = 0 587 | total_true_non_bkg = 0 588 | total_non_bkg = 0 589 | non_bkg_classes = 0 590 | for key, val in self._classes_MoF.items(): 591 | true_frames, all_frames = val 592 | # tf_2, pred_frames = self._classes_precision[key] 593 | # assert tf_2 == true_frames 594 | if self._verbose: 595 | log_str = 'mof label %d: %f %d / %d' % (key, true_frames / all_frames, 596 | true_frames, all_frames) 597 | if self._corpus is not None: 598 | log_str += '\t[{}]'.format(self._corpus.index2label[key]) 599 | logger.debug(log_str) 600 | # log_str = 'prec label %d: %f %d / %d' % (key, tf_2 / pred_frames, tf_2, pred_frames) 601 | # if self._corpus is not None: 602 | # log_str += 
'\t[{}]'.format(self._corpus.index2label[key]) 603 | # logger.debug(log_str) 604 | average_class_mof += true_frames / all_frames 605 | total_true += true_frames 606 | total += all_frames 607 | if key not in self._corpus._background_indices: 608 | non_bkg_classes += 1 609 | average_class_mof_non_bkg += true_frames / all_frames 610 | total_true_non_bkg += true_frames 611 | total_non_bkg += all_frames 612 | average_class_mof /= len(self._classes_MoF) 613 | average_class_mof_non_bkg /= non_bkg_classes 614 | self._return['mof'] = [self._frames_true_pr, self._frames_overall] 615 | self._return['mof_bg'] = [total_true, total] 616 | self._return['mof_non_bg'] = [total_true_non_bkg, total_non_bkg] 617 | self._return['precision'] = self._precision 618 | self._return['recall'] = self._recall 619 | if self._precision[1] == 0: 620 | precision = 0. 621 | else: 622 | precision = float(self._precision[0]) / self._precision[1] 623 | if self._recall[1] == 0: 624 | recall = 0. 625 | else: 626 | recall = float(self._recall[0]) / self._recall[1] 627 | self._return['f1'] = np.array([(2 * precision * recall) / (precision + recall), 1.0]) 628 | 629 | self._return['precision_non_bg'] = self._precision_without_bg 630 | self._return['recall_non_bg'] = self._recall_without_bg 631 | if self._precision_without_bg[1] == 0: 632 | precision_no_bg = 0. 633 | else: 634 | precision_no_bg = float(self._precision_without_bg[0]) / self._precision_without_bg[1] 635 | if self._recall_without_bg[1] == 0: 636 | recall_no_bg = 0. 637 | else: 638 | recall_no_bg = float(self._recall_without_bg[0]) / self._recall_without_bg[1] 639 | 640 | if precision_no_bg == 0 and recall_no_bg == 0: 641 | f1_no_bg = 0 642 | else: 643 | f1_no_bg = (2 * precision_no_bg * recall_no_bg) / (precision_no_bg + recall_no_bg) 644 | self._return['f1_non_bg'] = np.array([f1_no_bg, 1.0]) 645 | 646 | self._return['true_background'] = self._true_background_frames 647 | self._return['pred_background'] = self._pred_background_frames 648 | 649 | self._return['iou_multi_non_bg'] = self._non_bg_IoU_multi 650 | 651 | self._return['multiple_gt_labels'] = self._multiple_labels 652 | 653 | if self._verbose: 654 | logger.debug('average class mof: %f' % average_class_mof) 655 | logger.debug('mof with bg: %f' % (total_true / total)) 656 | logger.debug('average class mof without bg: %f' % average_class_mof_non_bkg) 657 | logger.debug('mof without bg: %f' % (total_true_non_bkg / total_non_bkg)) 658 | logger.debug('\n') 659 | logger.debug('f1 with bg: %f' % self._return['f1'][0]) 660 | logger.debug('f1 without bg: %f' % self._return['f1_non_bg'][0]) 661 | 662 | def iou_classes(self): 663 | average_class_iou = 0 664 | excluded_iou = 0 665 | non_bg_iou = 0 666 | for key, val in self._classes_IoU.items(): 667 | true_frames, union = val 668 | if self._verbose: 669 | log_str = 'iou label %d: %f %d / %d' % (key, true_frames / union, true_frames, union) 670 | if self._corpus is not None: 671 | log_str += ' [{}]'.format(self._corpus.index2label[key]) 672 | logger.debug(log_str) 673 | if key not in self.exclude: 674 | average_class_iou += true_frames / union 675 | else: 676 | excluded_iou += true_frames / union 677 | if key not in self._corpus._background_indices: 678 | non_bg_iou += true_frames / union 679 | average_iou_without_exc = average_class_iou / \ 680 | (len(self._classes_IoU) - len(self.exclude)) 681 | average_iou_with_exc = (average_class_iou + excluded_iou) / \ 682 | len(self._classes_IoU) 683 | self._return['iou'] = [average_class_iou, 684 | len(self._classes_IoU) - 
len(self.exclude)] 685 | self._return['iou_bg'] = [average_class_iou + excluded_iou, 686 | len(self._classes_IoU) - len(self.exclude)] 687 | # TODO: non-bg IOU 688 | # figure out class averaging 689 | if self._verbose: 690 | logger.debug('average IoU: %f' % average_iou_without_exc) 691 | logger.debug('average IoU with bg: %f' % average_iou_with_exc) 692 | 693 | 694 | def mof_val(self): 695 | if self._verbose: 696 | self._logger.debug('frames true: %d\tframes overall : %d' % 697 | (self._frames_true_pr, self._frames_overall)) 698 | return float(self._frames_true_pr) / self._frames_overall 699 | 700 | def frames(self): 701 | return self._frames_true_pr 702 | 703 | def stat(self): 704 | return self._return 705 | -------------------------------------------------------------------------------- /src/data/crosstask.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import math 3 | import os 4 | import pickle 5 | from collections import namedtuple, defaultdict 6 | 7 | import numpy as np 8 | 9 | from data.corpus import Corpus, GroundTruth, Video, Datasplit 10 | from data.features import grouped_pca 11 | from utils.logger import logger 12 | from utils.utils import load_pickle 13 | import random 14 | 15 | CrosstaskTask = namedtuple("CrosstaskTask", ["index", "title", "url", "n_steps", "steps"]) 16 | 17 | 18 | def read_task_info(path): 19 | tasks = [] 20 | with open(path, 'r') as f: 21 | index = f.readline() 22 | while index is not '': 23 | index = int(index.strip()) 24 | title = f.readline().strip() 25 | url = f.readline().strip() 26 | n_steps = int(f.readline().strip()) 27 | steps = f.readline().strip().split(',') 28 | next(f) 29 | assert n_steps == len(steps) 30 | tasks.append(CrosstaskTask(index, title, url, n_steps, steps)) 31 | index = f.readline() 32 | return tasks 33 | 34 | 35 | def get_vids(path): 36 | task_vids = {} 37 | with open(path, 'r') as f: 38 | for line in f: 39 | task, vid, url = line.strip().split(',') 40 | task = int(task) 41 | if task not in task_vids: 42 | task_vids[task] = [] 43 | task_vids[task].append(vid) 44 | return task_vids 45 | 46 | 47 | def read_assignment(T, num_steps, path, include_background=False): 48 | if include_background: 49 | cols = num_steps + 1 50 | else: 51 | cols = num_steps 52 | Y = np.zeros([T, cols], dtype=np.uint8) 53 | with open(path, 'r') as f: 54 | for line in f: 55 | step, start, end = line.strip().split(',') 56 | step = int(step) 57 | start = int(math.floor(float(start))) 58 | end = int(math.ceil(float(end))) 59 | if not include_background: 60 | step = step - 1 61 | Y[start:end, step] = 1 62 | if include_background: 63 | # turn on the background class (col 0) for any row that has no entries 64 | Y[Y.sum(axis=1) == 0, 0] = 1 65 | return Y 66 | 67 | 68 | def read_assignment_list(T, num_steps, path): 69 | # T x (K + 1) 70 | Y = read_assignment(T, num_steps, path, include_background=True) 71 | indices = [list(row.nonzero()[0]) for row in Y] 72 | assert len(indices) == T 73 | assert max(max(indices_t) for indices_t in indices) <= num_steps 74 | return indices 75 | 76 | 77 | def random_split(task_vids, test_tasks, n_train): 78 | train_vids = {} 79 | test_vids = {} 80 | for task, vids in task_vids.items(): 81 | if task in test_tasks and len(vids) > n_train: 82 | train_vids[task] = np.random.choice(vids, n_train, replace=False).tolist() 83 | test_vids[task] = [vid for vid in vids if vid not in train_vids[task]] 84 | else: 85 | train_vids[task] = vids 86 | return train_vids, test_vids 87 | 88 | 89 | 
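# Sketch of the on-disk feature layout assumed by CrosstaskVideo.load_grouped_features
# below (paths are illustrative):
#   dimensions_per_feature_group is None  -> one array per video at
#       <feature_root>/<video_name>.npy
#   otherwise                             -> one subdirectory per feature group at
#       <feature_root>/<feature_group>/<video_name>.npy
#     where each group's array is truncated to its configured number of columns and the
#     groups are np.hstack-ed in sorted group-name order.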
class CrosstaskVideo(Video): 90 | 91 | def __init__(self, *args, dimensions_per_feature_group=None, **kwargs): 92 | self._dimensions_per_feature_group = dimensions_per_feature_group 93 | super(CrosstaskVideo, self).__init__(*args, **kwargs) 94 | 95 | @classmethod 96 | def load_grouped_features(cls, feature_root, dimensions_per_feature_group, video_name): 97 | if dimensions_per_feature_group is None: 98 | try: 99 | return np.load(os.path.join(feature_root, "{}.npy".format(video_name))) 100 | except Exception as e: 101 | print(e) 102 | print("video_name: {}".format(video_name)) 103 | print("feature path: {}".format(os.path.join(feature_root, "{}.npy".format(video_name)))) 104 | raise e 105 | else: 106 | all_feats = [] 107 | for feature_group, dimensions in sorted(dimensions_per_feature_group.items()): 108 | feat_path = os.path.join(feature_root, feature_group, "{}.npy".format(video_name)) 109 | feats = np.load(feat_path) 110 | feats = feats[:, :dimensions] 111 | all_feats.append(feats) 112 | return np.hstack(all_feats) 113 | 114 | def load_features(self): 115 | return CrosstaskVideo.load_grouped_features(self._feature_root, self._dimensions_per_feature_group, self.name) 116 | 117 | 118 | DATA_SPLITS = ['train', 'val', 'all'] 119 | 120 | def load_videos_by_task(release_root, split='train', cv_n_train=30): 121 | assert split in DATA_SPLITS or split.startswith('cv') 122 | 123 | all_videos_by_task = get_vids(os.path.join(release_root, "videos.csv")) 124 | if split == 'all': 125 | return all_videos_by_task 126 | val_videos_by_task = get_vids(os.path.join(release_root, "videos_val.csv")) 127 | if split == 'val': 128 | return val_videos_by_task 129 | 130 | val_videos = set(v for vids in val_videos_by_task.values() for v in vids) 131 | train_videos_by_task = { 132 | task_index: [v for v in vids if v not in val_videos] 133 | for task_index, vids in all_videos_by_task.items() 134 | } 135 | 136 | if split.startswith('cv'): 137 | # cv_{train|test}_{split_seed} 138 | cv, cv_split, split_seed = split.split('_') 139 | assert cv == 'cv' 140 | assert cv_split in ['train', 'test'] 141 | 142 | vids_by_task = {} 143 | for task in train_videos_by_task: 144 | state = random.Random(int(split_seed)) 145 | vids = sorted(train_videos_by_task[task]) 146 | state.shuffle(vids) 147 | if cv_split == 'train': 148 | # take first n 149 | vids_by_task[task] = vids[:cv_n_train] 150 | else: 151 | # take remainder 152 | vids_by_task[task] = vids[cv_n_train:] 153 | return vids_by_task 154 | 155 | assert split == 'train' 156 | return train_videos_by_task 157 | 158 | 159 | def datasets_by_task(release_root, feature_root, constraints_root, remove_background, task_sets=None, split='train', task_ids=None, full=True): 160 | if task_sets is None: 161 | task_sets = list(CrosstaskCorpus.TASK_SET_PATHS.keys()) 162 | if task_ids is None: 163 | task_ids = [ 164 | task_id for task_set in task_sets 165 | for task_id in CrosstaskCorpus.TASK_IDS_BY_SET[task_set] 166 | ] 167 | corpus = CrosstaskCorpus(release_root, feature_root, use_secondary='related' in task_sets, 168 | load_constraints=True, constraints_root=constraints_root) 169 | if not os.path.exists(os.path.join(corpus._release_root, "frame_counts.pkl")): 170 | # get_datasplit generates frame counts but only for the passed task_ids, so we need to call this for its side effect of writing frame_counts.pkl 171 | corpus.get_datasplit(remove_background, task_sets=CrosstaskCorpus.TASK_SET_PATHS.keys(), split='all', task_ids=None, full=full) 172 | return { 173 | task_id: 
corpus.get_datasplit(remove_background, task_sets=task_sets, split=split, task_ids=[task_id], 174 | full=full) 175 | for task_id in task_ids 176 | } 177 | 178 | 179 | class CrosstaskDatasplit(Datasplit): 180 | def __init__(self, corpus, remove_background, task_sets=None, 181 | split='train', task_ids=None, full=True, subsample=1, feature_downscale=1.0, 182 | val_videos_override=None, 183 | feature_permutation_seed=None, 184 | ): 185 | self.full = full 186 | self._tasks_to_load = [] 187 | 188 | if task_sets is None: 189 | task_sets = list(sorted(CrosstaskCorpus.TASK_SET_PATHS.keys())) 190 | 191 | assert all(ts in CrosstaskCorpus.TASK_SET_PATHS.keys() for ts in task_sets) 192 | 193 | for ts in task_sets: 194 | tasks = read_task_info(os.path.join(corpus._release_root, CrosstaskCorpus.TASK_SET_PATHS[ts])) 195 | for task in tasks: 196 | if task_ids is None or task.index in task_ids: 197 | self._tasks_to_load.append(task) 198 | 199 | task_indices_to_load = list(sorted(set([task.index for task in self._tasks_to_load]))) 200 | 201 | self._tasks_by_id = { 202 | task.index: task for task in self._tasks_to_load 203 | } 204 | 205 | if val_videos_override is not None: 206 | def use_video(video): 207 | if split == 'train': 208 | return video not in val_videos_override 209 | else: 210 | assert split == 'val' 211 | return video in val_videos_override 212 | self._video_names_by_task = { 213 | task_index: [video for video in videos if use_video(video)] 214 | for task_index, videos in load_videos_by_task(corpus._release_root, split='all').items() 215 | if task_index in task_indices_to_load 216 | } 217 | else: 218 | self._video_names_by_task = { 219 | task_index: videos 220 | for task_index, videos in load_videos_by_task(corpus._release_root, split=split).items() 221 | if task_index in task_indices_to_load 222 | } 223 | 224 | if not self.full: 225 | self._video_names_by_task = { 226 | task_index: videos[:10] 227 | for task_index, videos in self._video_names_by_task.items() 228 | } 229 | 230 | self._tasks_by_video = { 231 | video: task 232 | for task, videos in self._video_names_by_task.items() 233 | for video in videos 234 | } 235 | 236 | assert len( 237 | self._video_names_by_task) != 0, "no tasks found with task_sets {}, task_ids {}, split {}, and release_directory {}".format( 238 | task_sets, task_ids, split, corpus._release_root 239 | ) 240 | 241 | video_names = list(sorted(set(video for videos in self._video_names_by_task.values() for video in videos))) 242 | assert len(video_names) != 0, "no videos found with task_sets {}, task_ids {}, split {}, and release_directory {}".format( 243 | task_sets, task_ids, split, corpus._release_root 244 | ) 245 | 246 | # logger.debug( 247 | # "{} tasks found with task_sets {}, task_ids {}, split {}".format(len(self._video_names_by_task), task_sets, task_ids, split)) 248 | # logger.debug("{} videos found with task_sets {}, task_ids {}, split {}".format(len(video_names), task_sets, task_ids, split)) 249 | 250 | self._save_frame_counts = (split == 'all' and set(corpus.TASK_SET_PATHS.keys()) == set(task_sets)) 251 | 252 | super(CrosstaskDatasplit, self).__init__( 253 | corpus, remove_background, subsample=subsample, feature_downscale=feature_downscale, 254 | feature_permutation_seed=feature_permutation_seed 255 | ) 256 | 257 | def _load_ground_truth_and_videos(self, remove_background): 258 | # features_by_task_and_video = {} 259 | 260 | t_by_video_path = os.path.join(self._corpus._release_root, "frame_counts.pkl") 261 | 262 | if os.path.exists(t_by_video_path): 263 | with 
open(t_by_video_path, 'rb') as f: 264 | t_by_video = pickle.load(f) 265 | else: 266 | logger.debug("creating frame counts") 267 | t_by_video = {} 268 | 269 | for task_name in self._tasks_by_id: 270 | logger.debug(task_name) 271 | for video in self._video_names_by_task[task_name]: 272 | feats = CrosstaskVideo.load_grouped_features( 273 | self._corpus._feature_root, self._corpus._dimensions_per_feature_group, video 274 | ) 275 | 276 | # features_by_task_and_video[(task_name, video)] = feats 277 | 278 | T = feats.shape[0] 279 | if video in t_by_video: 280 | assert t_by_video[ 281 | video] == T, "mismatch in timesteps from features for video {}. stored: {}; new {}".format( 282 | video, t_by_video[video], T) 283 | t_by_video[video] = T 284 | if self._save_frame_counts: 285 | logger.debug("saving to {}".format(t_by_video_path)) 286 | with open(t_by_video_path, 'wb') as f: 287 | pickle.dump(t_by_video, f) 288 | 289 | self.groundtruth = CrosstaskGroundTruth(self._corpus, self._tasks_by_id, t_by_video, self._remove_background) 290 | self._K_by_task = self.groundtruth._K_by_task 291 | 292 | for task_name in self._tasks_by_id: 293 | if task_name not in self._videos_by_task: 294 | self._videos_by_task[task_name] = {} 295 | for video in self._video_names_by_task[task_name]: 296 | assert video not in self._videos_by_task[task_name] 297 | has_label = task_name in self.groundtruth.gt_by_task 298 | 299 | nonbackground_timesteps = self.groundtruth.nonbackground_timesteps_by_task[task_name][video] if ( 300 | has_label and self._remove_background) else None 301 | self._videos_by_task[task_name][video] = CrosstaskVideo( 302 | feature_root=self._corpus._feature_root, 303 | dimensions_per_feature_group=self._corpus._dimensions_per_feature_group, 304 | remove_background=self._remove_background, 305 | nonbackground_timesteps=nonbackground_timesteps, 306 | K=self._K_by_task[task_name], 307 | gt=self.groundtruth.gt_by_task[task_name][video] if has_label else None, 308 | gt_with_background=self.groundtruth.gt_with_background_by_task[task_name][ 309 | video] if has_label else None, 310 | name=video, 311 | has_label=has_label, 312 | cache_features=self._corpus._cache_features, 313 | features_contain_background=self._corpus._features_contain_background, 314 | constraints=self.groundtruth.constraints_by_task[task_name][video], 315 | feature_permutation_seed=self._feature_permutation_seed, 316 | ) 317 | 318 | def get_ordered_indices_no_background(self): 319 | ordered_indices_by_task = {} 320 | for task in self._corpus._all_tasks: 321 | indices = [ 322 | self._corpus._index(self._corpus.get_label(task.index, step)) 323 | for step in task.steps 324 | ] 325 | ordered_indices_by_task[task.index] = indices 326 | return ordered_indices_by_task 327 | 328 | def get_allowed_starts_and_transitions(self): 329 | 330 | allowed_starts = set() 331 | allowed_transitions = {} 332 | allowed_ends = set() 333 | 334 | ordered_indices_by_task = {} 335 | 336 | for task in self._corpus._all_tasks: 337 | if self.remove_background: 338 | indices = self.get_ordered_indices_no_background()[task.index] 339 | ordered_indices_by_task[task.index] = indices 340 | 341 | for src, tgt in zip(indices, indices[1:]): 342 | if src not in allowed_transitions: 343 | allowed_transitions[src] = set() 344 | allowed_transitions[src].add(tgt) 345 | 346 | allowed_starts.add(indices[0]) 347 | allowed_ends.add(indices[-1]) 348 | else: 349 | step_indices = [ 350 | self._corpus._index(self._corpus.get_label(task.index, step)) 351 | for step in task.steps 352 | ] 353 | 
background_indices = [ 354 | self._corpus._index(lbl) 355 | for lbl in self._corpus.BACKGROUND_LABELS_BY_TASK[task.index] 356 | ] 357 | assert len(background_indices) == len(step_indices) + 1 358 | indices = [] 359 | for ix in range(len(step_indices)): 360 | indices.append(background_indices[ix]) 361 | indices.append(step_indices[ix]) 362 | indices.append(background_indices[-1]) 363 | assert len(indices) == 2 * len(step_indices) + 1 364 | 365 | STEP_TO_STEP = False 366 | 367 | if STEP_TO_STEP: 368 | for i, step_ix in enumerate(step_indices): 369 | s = {background_indices[i+1]} 370 | if i < len(step_indices) - 1: 371 | s.add(step_indices[i+1]) 372 | allowed_transitions[step_ix] = s 373 | 374 | for i, bg_ix in enumerate(background_indices[:-1]): 375 | s = {step_indices[i]} 376 | s.add(background_indices[i+1]) 377 | allowed_transitions[bg_ix] = s 378 | else: 379 | for src, tgt in zip(indices, indices[1:]): 380 | if src not in allowed_transitions: 381 | allowed_transitions[src] = set() 382 | allowed_transitions[src].add(tgt) 383 | 384 | allowed_starts.add(indices[0]) 385 | allowed_ends.add(indices[-1]) 386 | ordered_indices_by_task[task.index] = indices 387 | 388 | return allowed_starts, allowed_transitions, allowed_ends, ordered_indices_by_task 389 | 390 | 391 | class CrosstaskCorpus(Corpus): 392 | TASK_SET_PATHS = { 393 | 'primary': 'tasks_primary.txt', 394 | 'related': 'tasks_related.txt', 395 | } 396 | 397 | TASK_IDS_BY_SET = { 398 | 'primary': [16815, 23521, 40567, 44047, 44789, 53193, 59684, 71781, 76400, 77721, 87706, 91515, 94276, 95603, 399 | 105222, 105253, 109972, 113766], 400 | 'related': [1373, 11138, 14133, 16136, 16323, 20880, 20898, 23524, 26618, 29477, 30744, 31438, 34938, 34967, 401 | 40566, 40570, 40596, 40610, 41718, 41773, 41950, 42901, 44043, 50348, 51659, 53195, 53204, 57396, 402 | 67160, 68268, 72954, 75501, 76407, 76412, 77194, 81790, 83956, 85159, 89899, 91518, 91537, 91586, 403 | 93376, 93400, 96127, 96366, 97633, 100901, 101028, 103832, 105209, 105259, 105762, 106568, 106686, 404 | 108098, 109761, 110266, 113764, 114508, 118421, 118779, 118780, 118819, 118831], 405 | } 406 | 407 | def __init__(self, release_root, feature_root, dimensions_per_feature_group=None, 408 | features_contain_background=True, task_specific_steps=True, use_secondary=False, 409 | annotate_background_with_previous=False, load_constraints=False, constraints_root=None): 410 | print("feature root: {}".format(feature_root)) 411 | 412 | self._release_root = release_root 413 | self._feature_root = feature_root 414 | self._dimensions_per_feature_group = dimensions_per_feature_group 415 | self._features_contain_background = features_contain_background 416 | 417 | self.use_secondary = use_secondary 418 | if use_secondary: 419 | all_task_sets = list(sorted(CrosstaskCorpus.TASK_SET_PATHS.keys())) 420 | else: 421 | all_task_sets = ['primary'] 422 | 423 | self._all_tasks = [ 424 | task for ts in all_task_sets 425 | for task in read_task_info(os.path.join(release_root, CrosstaskCorpus.TASK_SET_PATHS[ts])) 426 | ] 427 | 428 | self.task_specific_steps = task_specific_steps 429 | 430 | self.annotate_background_with_previous = annotate_background_with_previous 431 | 432 | if load_constraints: 433 | assert constraints_root is not None 434 | self._constraints_root = constraints_root 435 | self.load_constraints = load_constraints 436 | 437 | if annotate_background_with_previous: 438 | # assert task_specific_steps 439 | self.BACKGROUND_LABELS_BY_TASK = { 440 | task.index: [self.get_label(task.index, 
"BKG_{}".format(step)) 441 | for step in ["FIRST"] + task.steps] 442 | for task in self._all_tasks 443 | } 444 | else: 445 | self.BACKGROUND_LABELS_BY_TASK = { 446 | task.index: [self.get_label(task.index, "BKG")] 447 | for task in self._all_tasks 448 | } 449 | 450 | # if not self.task_specific_steps, BACKGROUND_LABELS_BY_TASK will map everything to BKG, so we need to deduplicate 451 | 452 | # self.BACKGROUND_LABELS = list(sorted(set(self.BACKGROUND_LABELS_BY_TASK.values()))) 453 | self.BACKGROUND_LABELS = list(sorted(set( 454 | lbl for task_labels in self.BACKGROUND_LABELS_BY_TASK.values() 455 | for lbl in task_labels 456 | ))) 457 | 458 | super(CrosstaskCorpus, self).__init__(background_labels=self.BACKGROUND_LABELS) 459 | 460 | def get_label(self, task, step): 461 | if self.task_specific_steps: 462 | return "{} {}".format(task, step) 463 | else: 464 | return step 465 | 466 | def _get_components_for_label(self, label): 467 | # if self.task_specific_steps: 468 | # step_words = label.split(':')[1] 469 | # else: 470 | # step_words = label 471 | # return step_words.split() 472 | return label.split() 473 | 474 | def _load_mapping(self): 475 | for task in self._all_tasks: 476 | indices = [self._index(lbl) for lbl in self.BACKGROUND_LABELS_BY_TASK[task.index]] 477 | indices += [self._index(self.get_label(task.index, step)) for step in task.steps] 478 | self.update_indices_by_task(task.index, indices) 479 | 480 | 481 | def get_datasplit(self, remove_background, task_sets=None, split='train', task_ids=None, full=True, subsample=1, 482 | feature_downscale=1.0, val_videos_override=None, feature_permutation_seed=None): 483 | return CrosstaskDatasplit( 484 | self, remove_background, task_sets=task_sets, split=split, task_ids=task_ids, 485 | full=full, subsample=subsample, feature_downscale=feature_downscale, 486 | val_videos_override=val_videos_override, feature_permutation_seed=feature_permutation_seed, 487 | ) 488 | 489 | 490 | class CrosstaskGroundTruth(GroundTruth): 491 | 492 | def __init__(self, corpus: CrosstaskCorpus, tasks_by_id, t_by_video, remove_background): 493 | self._tasks_by_id = tasks_by_id 494 | self._K_by_task = { 495 | task_id: len(task.steps) + (0 if remove_background else 1) 496 | for task_id, task in tasks_by_id.items() 497 | } 498 | self._t_by_video = t_by_video 499 | task_names = list(sorted(set(self._tasks_by_id))) 500 | self._task_names = task_names 501 | 502 | self.constraints_by_task = {} 503 | 504 | super(CrosstaskGroundTruth, self).__init__(corpus, task_names, remove_background) 505 | 506 | def _load_gt_single(self, task, T, num_steps, filename): 507 | gt = read_assignment_list(T, num_steps, filename) 508 | # turn these step indices into global indices 509 | global_gt = [] 510 | 511 | previous_step_ix = 0 512 | for gt_t in gt: 513 | new_gt_t = [] 514 | for ix in gt_t: 515 | if ix == 0: 516 | if self._corpus.annotate_background_with_previous: 517 | label_idx = self._corpus.label2index[self._corpus.BACKGROUND_LABELS_BY_TASK[task][previous_step_ix]] 518 | else: 519 | assert len(self._corpus.BACKGROUND_LABELS_BY_TASK[task]) == 1 520 | label_idx = self._corpus.label2index[self._corpus.BACKGROUND_LABELS_BY_TASK[task][0]] 521 | else: 522 | label_idx = self._corpus._index(self._corpus.get_label(task, self._tasks_by_id[task].steps[ix - 1])) 523 | previous_step_ix = ix 524 | new_gt_t.append(label_idx) 525 | global_gt.append(new_gt_t) 526 | return global_gt 527 | 528 | def _load_gt(self): 529 | glob_path = os.path.join(self._corpus._release_root, "annotations", "*.csv") 530 | 
filenames = glob.glob(glob_path) 531 | assert filenames, "no filenames found for glob path {}".format(glob_path) 532 | # logger.debug("{} annotation files found".format(len(filenames))) 533 | 534 | def get_T(filename): 535 | file = os.path.split(filename)[1] 536 | file_no_ext = os.path.splitext(file)[0] 537 | splits = file_no_ext.split('_') 538 | task = int(splits[0]) 539 | video = '_'.join(splits[1:]) 540 | T = self._t_by_video[video] 541 | num_steps = self._K_by_task.get(task, None) 542 | return task, video, T, num_steps 543 | 544 | for filename in filenames: 545 | task, video, T, num_steps = get_T(filename) 546 | if task not in self._task_names: 547 | continue 548 | global_gt = self._load_gt_single(task, T, num_steps, filename) 549 | if task not in self.gt_by_task: 550 | self.gt_by_task[task] = {} 551 | self.gt_by_task[task][video] = global_gt 552 | 553 | if self._corpus.load_constraints: 554 | glob_path = os.path.join(self._corpus._constraints_root, "*.csv") 555 | filenames = glob.glob(glob_path) 556 | assert filenames, "no filenames found for glob path {}".format(glob_path) 557 | for constraint_fname in filenames: 558 | # constraint_fname = os.path.join(self._corpus._constraints_root, os.path.split(filename)[1]) 559 | task, video, T, num_steps = get_T(constraint_fname) 560 | if task not in self._task_names: 561 | continue 562 | constraint_mat = read_assignment( 563 | T, (num_steps if self._remove_background else num_steps-1), constraint_fname, include_background=False 564 | ) 565 | if task not in self.constraints_by_task: 566 | self.constraints_by_task[task] = {} 567 | self.constraints_by_task[task][video] = constraint_mat 568 | 569 | # def _load_mapping(self): 570 | # super(CrosstaskGroundTruth, self)._load_mapping() 571 | # 572 | # # augment indices_by_task with indices for related tasks, if they're present 573 | # for task_id, task in self._tasks_by_id.items(): 574 | # if task_id not in self._indices_by_task: 575 | # this_indices = set() 576 | # if not self._remove_background: 577 | # label_index = self._corpus.label2index[self._corpus.BACKGROUND_LABELS_BY_TASK[task]] 578 | # this_indices.add(label_index) 579 | # for step in task.steps: 580 | # label_index = self._corpus._index(self._corpus.get_label(task, step)) 581 | # this_indices.add(label_index) 582 | # self._indices_by_task[task_id] = list(sorted(this_indices)) 583 | 584 | 585 | 586 | def extract_feature_groups(corpus, narration_feature_dirs=None): 587 | group_indices = { 588 | 'i3d': (0, 1024), 589 | 'resnet': (1024, 3072), 590 | 'audio': (3072, 3200), 591 | } 592 | n_instances = len(corpus) 593 | grouped = defaultdict(dict) 594 | last_task = None 595 | task_feats = None 596 | for idx in range(n_instances): 597 | instance = corpus._get_by_index(idx) 598 | video_name = instance['video_name'] 599 | features = instance['features'] 600 | for group, (start, end) in group_indices.items(): 601 | grouped[group][video_name] = features[:, start:end] 602 | if narration_feature_dirs is not None: 603 | task = instance['task_name'] 604 | if last_task != task: 605 | task_data = [ 606 | load_pickle(os.path.join(dir, 'crosstask_narr_{}.pkl'.format(task))) 607 | for dir in narration_feature_dirs 608 | ] 609 | task_feats = { 610 | datum['video']: datum['narration'] 611 | for data in task_data 612 | for datum in data 613 | } 614 | grouped['narration'][video_name] = task_feats[video_name] 615 | last_task = task 616 | return grouped 617 | 618 | 619 | def pca_and_serialize_features(release_root, raw_feature_root, output_feature_root, 
constraints_root, remove_background, 620 | pca_components_per_group=300, by_task=True, task_sets=None, narration_feature_dirs=None): 621 | if by_task: 622 | grouped_datasets = datasets_by_task(release_root, raw_feature_root, constraints_root, 623 | remove_background, split='all', 624 | task_sets=task_sets, full=True) 625 | else: 626 | corpus = CrosstaskCorpus( 627 | release_root, 628 | raw_feature_root, 629 | use_secondary='related' in task_sets, 630 | load_constraints=True, 631 | constraints_root=constraints_root, 632 | ) 633 | grouped_datasets = { 634 | 'all': corpus.get_datasplit(remove_background, split='all', task_sets=task_sets) 635 | } 636 | 637 | os.makedirs(output_feature_root, exist_ok=True) 638 | 639 | for corpora_group, dataset in grouped_datasets.items(): 640 | logger.debug("saving features for task: {}".format(corpora_group)) 641 | grouped_features = extract_feature_groups(dataset, narration_feature_dirs) 642 | transformed, pca_models = grouped_pca(grouped_features, pca_components_per_group, pca_models_by_group=None) 643 | for feature_group, vid_dict in transformed.items(): 644 | logger.debug("\tsaving features for feature group: {}".format(feature_group)) 645 | feature_group_dir = os.path.join(output_feature_root, feature_group) 646 | os.makedirs(feature_group_dir, exist_ok=True) 647 | for vid, features in vid_dict.items(): 648 | fname = os.path.join(feature_group_dir, '{}.npy'.format(vid)) 649 | np.save(fname, features) 650 | 651 | 652 | if __name__ == "__main__": 653 | _release_root = 'data/crosstask/crosstask_release' 654 | _raw_feature_root = 'data/crosstask/crosstask_features' 655 | _constraints_root = 'data/crosstask/crosstask_constraints' 656 | _components = 200 657 | 658 | # _narration_dirs = ['data/crosstask/narration', 'data/crosstask/narration_test'] 659 | _narration_dirs = None 660 | 661 | _task_sets = ['primary'] 662 | for _remove_background in [False]: 663 | for _by_task in [True]: 664 | _output_feature_root = 'data/crosstask/crosstask_processed/crosstask_{}_pca-{}_{}_{}'.format( 665 | '+'.join(_task_sets), 666 | _components, 667 | 'no-bkg' if _remove_background else 'with-bkg', 668 | 'by-task' if _by_task else 'all-tasks', 669 | ) 670 | 671 | pca_and_serialize_features( 672 | _release_root, _raw_feature_root, _output_feature_root, _constraints_root, _remove_background, 673 | pca_components_per_group=_components, by_task=_by_task, task_sets=_task_sets, 674 | narration_feature_dirs=_narration_dirs 675 | ) 676 | 677 | # _task_sets = ['related'] 678 | # for _remove_background in [False]: 679 | # for _by_task in [True]: 680 | # #_output_feature_root = 'data/crosstask/crosstask_processed/crosstask_{}_pca-{}_{}_{}'.format( 681 | # # put related feats in the same directory as primary so we can load them all simultaneously 682 | # _output_feature_root = 'data/crosstask/crosstask_processed/crosstask_primary_pca-{}_{}_{}'.format( 683 | # # '+'.join(_task_sets), 684 | # _components, 685 | # 'no-bkg' if _remove_background else 'with-bkg', 686 | # 'by-task' if _by_task else 'all-tasks', 687 | # ) 688 | 689 | # pca_and_serialize_features( 690 | # _release_root, _raw_feature_root, _output_feature_root, _constraints_root, _remove_background, 691 | # pca_components_per_group=_components, by_task=_by_task, task_sets=_task_sets 692 | # ) 693 | --------------------------------------------------------------------------------