├── .gitignore ├── LICENSE ├── README.md ├── attention.py ├── configs ├── aln_uni_config.yml └── aln_uni_y8_config.yml ├── custom_callbacks.py ├── datasets.py ├── environment.yml ├── evaluation.py ├── knapsack.py ├── lit_models ├── lit_aln_uni.py └── lit_aln_uni_y8.py ├── main_ablations.py ├── main_y8.py ├── modules.py ├── run_ablation.sh ├── run_y8.sh ├── utils.py └── vsum_tools.py /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | logs 3 | wandb 4 | __pycache__ 5 | lightning_logs 6 | .ipynb_checkpoints 7 | wandb -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 pangzss 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-CTVSUM 2 | 3 | This repository contains the official Pytorch implementation of the paper 4 | > [**Contrastive Losses Are Natural Criteria for Unsupervised Video Summarization**](https://arxiv.org/abs/2211.10056) 5 | > 6 | > Zongshang Pang, Yuta Nakashima, Mayu Otani, Hajime Nagahara 7 | > 8 | > In WACV2023 9 | 10 | ## Installation 11 | ```shell 12 | git clone https://github.com/pangzss/pytorch-CTVSUM.git 13 | cd pytorch-CTVSUM 14 | conda env create -f environment.yml 15 | conda activate ctvsum 16 | ``` 17 | 18 | ## Dataset preparation 19 | We use three datasets in our paper: [**TVSum**](https://github.com/yalesong/tvsum), [**SumMe**](https://gyglim.github.io/me/vsum/index.html), and a random subset of [**Youtube8M**](https://research.google.com/youtube8m/). 20 | 21 | TVSum and SumMe are used for training and evaluation, and Youtube8M is only used for training. 22 | 23 | To prepare the datasets, 24 | 25 | 1. Download raw videos from TVSum and SumMe and put them in ./data/raw 26 | 2. Download the extracted features from [**GoogleDrive**](https://drive.google.com/drive/folders/1ruIbB8LoJ1sbF_q_yihLuolE4JpEgK8G?usp=sharing) (GoogLeNet features for TVSum and SumMe, kindly provided by the authors of [**DRDSN**](https://github.com/KaiyangZhou/pytorch-vsumm-reinforce), and quantized Inception features for Youtube8M). 27 | 3. 
Put eccv_* files in ./data/interim, and unzip selected_features.zip in ./data/interim/youtube8M/ 28 | 29 | ## Training and Evaluation 30 | ### Evaluation with only pretrained features 31 | 1. In ./configs/aln_unif_config.yml, modify the dataset and evaluation settings, e.g. 32 | ```yaml 33 | data: 34 | name: tvsum # summe 35 | setting: Canonical # Augmented/Transfer 36 | ``` 37 | 2. Set 38 | ```yaml 39 | is_raw: True 40 | ``` 41 | 3. Set use Global Consistency or not 42 | ```yaml 43 | use_unif: True # False 44 | ``` 45 | 4. For Youtube8M features (quantized Inception), in ./configs/aln_unif_y8_config.yml, set 46 | ```yaml 47 | is_raw: True 48 | hparams: 49 | use_unif: True # False 50 | ``` 51 | 5. Run 52 | ```shell 53 | ./run_ablation.sh 54 | ./run_y8.sh 55 | ``` 56 | ### Contrastive refinement and evaluation 57 | 1. For TVSum and SumMe, set training/evaluation setting in ./configs/aln_unif_config.yml, and decide whether to use global consistency or uniqueness filter 58 | ```yaml 59 | is_raw: False 60 | use_unif: True # False 61 | use_unq: True # False 62 | ``` 63 | The code will run 5-fold cross validation by default. 64 | 65 | 2. For Youtube8M, similarly in ./configs/aln_unif_y8_config.yml, 66 | ```yaml 67 | is_raw: False 68 | hparams: 69 | use_unif: True # False 70 | use_unq: True # False 71 | ``` 72 | 3. Run 73 | ```bash 74 | ./run_ablation.sh 75 | ./run_y8.sh 76 | ``` 77 | ## Acknowledgement 78 | We would like to thank [**DRDSN**](https://github.com/KaiyangZhou/pytorch-vsumm-reinforce), which provides the extracted features and the evaluation code for TVSum and SumMe. Moreover, we are thankful to the insightful work [**Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere**](https://arxiv.org/abs/2005.10242), which inspired our work. 79 | 80 | ## Citation 81 | ```bibtex 82 | @inproceedings{pang2023contrastive, 83 | title={Contrastive Losses Are Natural Criteria for Unsupervised Video Summarization}, 84 | author={Pang, Zongshang and Nakashima, Yuta and Otani, Mayu and Nagahara, Hajime}, 85 | booktitle={WACV}, 86 | year={2023} 87 | } 88 | ``` 89 | -------------------------------------------------------------------------------- /attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | from modules import * 7 | 8 | class TransformerEncoder(nn.Module): 9 | ''' A encoder model with self attention mechanism. 
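    It projects d_inp-dimensional frame features to d_model, refines them with
    n_layers self-attention encoder layers, and additionally predicts per-frame
    uniqueness scores in [0, 1] from a detached copy of the refined features.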
''' 10 | 11 | def __init__( 12 | self, d_inp=1024, n_layers=4, n_head=1, d_k=64, d_v=64, 13 | d_model=256, d_inner=512, dropout=0., num_patches=300): 14 | 15 | super().__init__() 16 | self.n_layers = n_layers 17 | 18 | self.proj = nn.Linear(d_inp, d_model) 19 | self.layer_stack = nn.ModuleList([ 20 | EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 21 | for _ in range(n_layers)]) 22 | 23 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 24 | 25 | self.unq_est = EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 26 | self.score = nn.Sequential( 27 | nn.Linear(d_model, d_model), 28 | nn.ReLU(), 29 | nn.Linear(d_model, 1), 30 | nn.Sigmoid() 31 | ) 32 | def forward(self, src_seq): 33 | # -- Forward 34 | enc_output = self.proj(src_seq) 35 | enc_output = self.layer_norm(enc_output) 36 | for i, enc_layer in enumerate(self.layer_stack): 37 | enc_output, _ = enc_layer(enc_output) 38 | scores = self.score(self.unq_est(enc_output.detach())[0]) 39 | return enc_output, scores.squeeze(-1) 40 | 41 | if __name__ == '__main__': 42 | model = TransformerEncoder() 43 | inp = torch.rand(1,300,1024) 44 | enc_output = model(inp) 45 | print(enc_output.shape) 46 | -------------------------------------------------------------------------------- /configs/aln_uni_config.yml: -------------------------------------------------------------------------------- 1 | model: 2 | d_inp: 1024 3 | d_model: 128 4 | data: 5 | num_workers: 2 6 | name: tvsum 7 | setting: Canonical 8 | split: 0 9 | num_frames: 200 10 | paths: 11 | interim: 'data/interim/' 12 | youtube8M: 'data/interim/youtube8M/selected_features' 13 | 14 | hparams: 15 | lr: 0.0001 16 | weight_decay: 0.0001 17 | ratio_s: 0 18 | ratio_k1: 0.1 19 | alpha: 0.5 20 | num_frames: 200 21 | batch_size: 32 22 | n_layer: 4 23 | n_head: 1 24 | setup: 25 | is_logger: False 26 | wandb_name: video-sum 27 | 28 | lightning: 29 | trainer: 30 | max_epochs: 40 31 | log_every_n_steps: 100 32 | 33 | is_raw: True 34 | use_unq: True 35 | use_unif: True 36 | -------------------------------------------------------------------------------- /configs/aln_uni_y8_config.yml: -------------------------------------------------------------------------------- 1 | model: 2 | d_inp: 1024 3 | d_model: 128 4 | data: 5 | num_workers: 2 6 | paths: 7 | interim: 'data/interim/' 8 | youtube8M: 'data/interim/youtube8M/selected_features' 9 | 10 | hparams: 11 | lr: 0.0001 12 | weight_decay: 0.0005 13 | ratio_s: 0 14 | ratio_k1: 0.1 15 | alpha: 0.5 16 | num_frames: 200 17 | batch_size: 128 18 | n_layer: 4 19 | n_head: 8 20 | use_unq: True 21 | use_unif: False 22 | setup: 23 | wandb_name: video-sum 24 | 25 | lightning: 26 | trainer: 27 | max_epochs: 40 28 | log_every_n_steps: 100 29 | 30 | is_raw: False 31 | -------------------------------------------------------------------------------- /custom_callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | import wandb 7 | #import matplotlib.pyplot as plt 8 | 9 | from omegaconf import OmegaConf 10 | 11 | from pytorch_lightning.callbacks import Callback 12 | from pytorch_lightning.utilities.distributed import rank_zero_only 13 | 14 | class SetupCallback(Callback): 15 | def __init__(self, resume, now, logdir, ckptdir, cfgdir, config): 16 | super().__init__() 17 | self.resume = resume 18 | self.now = now 19 | self.logdir = logdir 20 | self.ckptdir = ckptdir 21 | self.cfgdir = cfgdir 22 | self.config = config 
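    # Before the pre-train routine starts, rank 0 creates the log/checkpoint/config
    # directories and dumps the merged OmegaConf config; other ranks move any
    # directory created by ModelCheckpoint into a child_runs folder instead.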
23 | 24 | def on_pretrain_routine_start(self, trainer, pl_module): 25 | if trainer.global_rank == 0: 26 | # Create logdirs and save configs 27 | os.makedirs(self.logdir, exist_ok=True) 28 | os.makedirs(self.ckptdir, exist_ok=True) 29 | os.makedirs(self.cfgdir, exist_ok=True) 30 | 31 | OmegaConf.save(self.config, 32 | os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) 33 | else: 34 | # ModelCheckpoint callback created log directory --- remove it 35 | if not self.resume and os.path.exists(self.logdir): 36 | dst, name = os.path.split(self.logdir) 37 | dst = os.path.join(dst, "child_runs", name) 38 | os.makedirs(os.path.split(dst)[0], exist_ok=True) 39 | try: 40 | os.rename(self.logdir, dst) 41 | except FileNotFoundError: 42 | pass 43 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | import torch.nn.functional as F 6 | 7 | import h5py 8 | import numpy as np 9 | import json 10 | from utils import * 11 | from random import shuffle 12 | 13 | class DPLDatasetRand(Dataset): 14 | def __init__(self, cfg, mode): 15 | assert mode in ['train','test'] 16 | assert cfg.setting in ['Augmented','Transfer','Canonical'] 17 | 18 | self.mode = mode 19 | self.num_frames = cfg.num_frames 20 | self.train_data = [] 21 | self.test_data = [] 22 | dataset = cfg.name 23 | 24 | data_folder = cfg.paths.interim 25 | 26 | self.data = {} 27 | names = ['tvsum','summe','youtube','ovp'] 28 | for n in names: 29 | self.data[n] = h5py.File(data_folder+'/eccv16_dataset_{}_google_pool5.h5'.format(n), 'r') 30 | 31 | with open(data_folder+'/splits/{}_splits.json'.format(dataset), 'r') as fp: 32 | splits = json.load(fp) 33 | 34 | test_videos = splits[cfg.split]['test_keys'] 35 | train_videos = splits[cfg.split]['train_keys'] 36 | 37 | if cfg.setting == 'Augmented': 38 | for video in test_videos: 39 | self.test_data.append(self.data[dataset][video]) 40 | 41 | for n in names: 42 | if n == dataset: 43 | for video in train_videos: 44 | self.train_data.append(self.data[n][video]) 45 | else: 46 | for video in list(self.data[n].keys()): 47 | self.train_data.append(self.data[n][video]) 48 | elif cfg.setting == 'Transfer': 49 | for video in self.data[dataset].keys(): 50 | self.test_data.append(self.data[dataset][video]) 51 | 52 | for n in names: 53 | if n != dataset: 54 | for video in list(self.data[n].keys()): 55 | self.train_data.append(self.data[n][video]) 56 | else: 57 | for video in test_videos: 58 | self.test_data.append(self.data[dataset][video]) 59 | for video in train_videos: 60 | self.train_data.append(self.data[dataset][video]) 61 | 62 | def __len__(self): 63 | if self.mode == 'train': 64 | self.len = len(self.train_data) 65 | else: 66 | self.len = len(self.test_data) 67 | return self.len 68 | 69 | def __getitem__(self, index): 70 | if self.mode == 'train': 71 | feats = torch.Tensor(self.train_data[index]['features'][...]) 72 | length = len(feats) 73 | if length >= self.num_frames: 74 | ids = torch.randperm(length)[:self.num_frames] 75 | ids = torch.sort(ids)[0] 76 | else: 77 | ids = torch.arange(length).view(1,1,-1).float() 78 | ids = F.interpolate(ids,size=self.num_frames, mode='nearest').long().flatten() 79 | return feats[ids] 80 | else: 81 | video = self.test_data[index] 82 | return video 83 | 84 | class Youtube8M(Dataset): 85 | def __init__(self, cfg, num_frames): 86 | self.num_frames = num_frames 87 
| 88 | self.dirpath = cfg.paths.youtube8M 89 | self.fname = os.listdir(self.dirpath) 90 | def dequantize(self,feat_vector, max_quantized_value=2, min_quantized_value=-2): 91 | ''' Dequantize the feature from the byte format to the float format. ''' 92 | assert max_quantized_value > min_quantized_value 93 | quantized_range = max_quantized_value - min_quantized_value 94 | scalar = quantized_range / 255.0 95 | bias = (quantized_range / 512.0) + min_quantized_value 96 | return feat_vector * scalar + bias 97 | 98 | def __len__(self): 99 | return len(self.fname) 100 | def __getitem__(self, index): 101 | fp = os.path.join(self.dirpath, self.fname[index]) 102 | feature = np.load(fp) 103 | deq_feature = torch.tensor(self.dequantize(feature)).float() 104 | length = len(deq_feature) 105 | if length >= self.num_frames: 106 | ids = torch.randperm(length)[:self.num_frames] 107 | ids = torch.sort(ids)[0] 108 | else: 109 | ids = torch.arange(length).view(1,1,-1).float() 110 | ids = F.interpolate(ids,size=self.num_frames, mode='nearest').long().flatten() 111 | ret_features = deq_feature[ids] 112 | ret_features = F.normalize(ret_features,p=2, dim=1) 113 | return ret_features -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ctvsum 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=4.5=1_gnu 9 | - _tflow_select=2.3.0=mkl 10 | - anyio=3.4.0=py38h578d9bd_0 11 | - astor=0.8.1=pyh9f0ad1d_0 12 | - astunparse=1.6.3=pyhd8ed1ab_0 13 | - async_generator=1.10=py_0 14 | - attrs=21.4.0=pyhd8ed1ab_0 15 | - babel=2.9.1=pyh44b312d_0 16 | - backcall=0.2.0=pyh9f0ad1d_0 17 | - backports=1.0=py_2 18 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 19 | - blas=1.0=mkl 20 | - bleach=4.1.0=pyhd8ed1ab_0 21 | - brotli=1.0.9=h7f98852_5 22 | - brotli-bin=1.0.9=h7f98852_5 23 | - brotlipy=0.7.0=py38h497a2fe_1001 24 | - bzip2=1.0.8=h7b6447c_0 25 | - c-ares=1.18.1=h7f8727e_0 26 | - ca-certificates=2022.5.18.1=ha878542_0 27 | - cached-property=1.5.2=hd8ed1ab_1 28 | - cached_property=1.5.2=pyha770c72_1 29 | - certifi=2022.5.18.1=py38h578d9bd_0 30 | - chardet=4.0.0=py38h578d9bd_3 31 | - charset-normalizer=2.0.9=pyhd8ed1ab_0 32 | - colorama=0.4.4=pyh9f0ad1d_0 33 | - cryptography=35.0.0=py38ha5dfef3_0 34 | - cudatoolkit=11.3.1=h2bc3f7f_2 35 | - cycler=0.11.0=pyhd8ed1ab_0 36 | - dbus=1.13.18=hb2f20db_0 37 | - decorator=5.1.0=pyhd8ed1ab_0 38 | - defusedxml=0.7.1=pyhd8ed1ab_0 39 | - entrypoints=0.3=pyhd8ed1ab_1003 40 | - expat=2.2.10=h9c3ff4c_0 41 | - faiss=1.6.5=py38hf4212ac_1_cuda 42 | - faiss-gpu=1.6.5=hf05f184_1 43 | - ffmpeg=4.3=hf484d3e_0 44 | - fontconfig=2.13.1=hba837de_1005 45 | - fonttools=4.25.0=pyhd3eb1b0_0 46 | - freetype=2.11.0=h70c0345_0 47 | - future=0.18.2=py38h578d9bd_5 48 | - gast=0.4.0=pyh9f0ad1d_0 49 | - gflags=2.2.2=he1b5a44_1004 50 | - giflib=5.2.1=h7b6447c_0 51 | - glib=2.69.1=h5202010_0 52 | - gmp=6.2.1=h2531618_2 53 | - gnutls=3.6.15=he1e5248_0 54 | - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0 55 | - google-pasta=0.2.0=pyh8c360ce_0 56 | - gst-plugins-base=1.14.0=hbbd80ab_1 57 | - gstreamer=1.14.0=h28cd5cc_2 58 | - hdf5=1.10.6=nompi_h3c11f04_101 59 | - icu=58.2=hf484d3e_1000 60 | - importlib-metadata=4.10.0=py38h578d9bd_0 61 | - importlib_resources=5.4.0=pyhd8ed1ab_0 62 | - intel-openmp=2021.4.0=h06a4308_3561 63 | - ipython=7.30.1=py38h578d9bd_0 64 | - ipython_genutils=0.2.0=py_1 65 | - 
jedi=0.18.1=py38h578d9bd_0 66 | - jinja2=3.0.3=pyhd8ed1ab_0 67 | - jpeg=9d=h7f8727e_0 68 | - json5=0.9.5=pyh9f0ad1d_0 69 | - jupyter_client=7.1.0=pyhd8ed1ab_0 70 | - jupyter_core=4.9.1=py38h578d9bd_1 71 | - jupyter_server=1.13.1=pyhd8ed1ab_0 72 | - jupyterlab=3.2.5=pyhd8ed1ab_0 73 | - jupyterlab_pygments=0.1.2=pyh9f0ad1d_0 74 | - jupyterlab_server=2.10.2=pyhd8ed1ab_0 75 | - keras-preprocessing=1.1.2=pyhd8ed1ab_0 76 | - kiwisolver=1.3.1=py38h2531618_0 77 | - lame=3.100=h7b6447c_0 78 | - lcms2=2.12=h3be6417_0 79 | - ld_impl_linux-64=2.35.1=h7274673_9 80 | - libblas=3.9.0=1_h6e990d7_netlib 81 | - libbrotlicommon=1.0.9=h7f98852_5 82 | - libbrotlidec=1.0.9=h7f98852_5 83 | - libbrotlienc=1.0.9=h7f98852_5 84 | - libfaiss=1.6.5=h5bea7ad_1_cuda 85 | - libffi=3.3=he6710b0_2 86 | - libgcc-ng=9.3.0=h5101ec6_17 87 | - libgfortran-ng=7.5.0=h14aa051_19 88 | - libgfortran4=7.5.0=h14aa051_19 89 | - libgomp=9.3.0=h5101ec6_17 90 | - libiconv=1.15=h63c8f33_5 91 | - libidn2=2.3.2=h7f8727e_0 92 | - liblapack=3.9.0=3_h893e4fe_netlib 93 | - libpng=1.6.37=hbc83047_0 94 | - libprotobuf=3.15.8=h780b84a_0 95 | - libsodium=1.0.18=h36c2ea0_1 96 | - libstdcxx-ng=9.3.0=hd4cf53a_17 97 | - libtasn1=4.16.0=h27cfd23_0 98 | - libtiff=4.2.0=h85742a9_0 99 | - libunistring=0.9.10=h27cfd23_0 100 | - libuuid=2.32.1=h7f98852_1000 101 | - libuv=1.40.0=h7b6447c_0 102 | - libwebp=1.2.0=h89dd481_0 103 | - libwebp-base=1.2.0=h27cfd23_0 104 | - libxcb=1.13=h7f98852_1003 105 | - libxml2=2.9.12=h03d6c58_0 106 | - lz4-c=1.9.3=h295c915_1 107 | - markupsafe=2.0.1=py38h497a2fe_0 108 | - matplotlib=3.4.3=py38h578d9bd_1 109 | - matplotlib-base=3.4.3=py38hbbc1b5f_0 110 | - matplotlib-inline=0.1.3=pyhd8ed1ab_0 111 | - mistune=0.8.4=py38h497a2fe_1004 112 | - mkl=2020.2=256 113 | - mkl-service=2.3.0=py38he904b0f_0 114 | - mkl_fft=1.3.0=py38h54f3939_0 115 | - mkl_random=1.1.1=py38h0573a6f_0 116 | - multidict=5.2.0=py38h7f8727e_2 117 | - munkres=1.1.4=pyh9f0ad1d_0 118 | - nbclassic=0.3.4=pyhd8ed1ab_0 119 | - nbclient=0.5.9=pyhd8ed1ab_0 120 | - nbconvert=6.3.0=py38h578d9bd_1 121 | - nbformat=5.1.3=pyhd8ed1ab_0 122 | - ncurses=6.3=h7f8727e_2 123 | - nest-asyncio=1.5.4=pyhd8ed1ab_0 124 | - nettle=3.7.3=hbbd107a_1 125 | - ninja=1.10.2=py38hd09550d_3 126 | - notebook=6.4.6=pyha770c72_0 127 | - numpy=1.19.2=py38h54aff64_0 128 | - numpy-base=1.19.2=py38hfa32c7d_0 129 | - olefile=0.46=pyhd3eb1b0_0 130 | - openh264=2.1.1=h4ff587b_0 131 | - openssl=1.1.1o=h7f8727e_0 132 | - opt_einsum=3.3.0=pyhd8ed1ab_1 133 | - packaging=21.3=pyhd8ed1ab_0 134 | - pandoc=2.16.2=h7f98852_0 135 | - pandocfilters=1.5.0=pyhd8ed1ab_0 136 | - parso=0.8.3=pyhd8ed1ab_0 137 | - pcre=8.45=h9c3ff4c_0 138 | - pexpect=4.8.0=pyh9f0ad1d_2 139 | - pickleshare=0.7.5=py_1003 140 | - pillow=8.4.0=py38h5aabda8_0 141 | - pip=20.3.3=py38h06a4308_0 142 | - prometheus_client=0.12.0=pyhd8ed1ab_0 143 | - prompt-toolkit=3.0.24=pyha770c72_0 144 | - pthread-stubs=0.4=h36c2ea0_1001 145 | - ptyprocess=0.7.0=pyhd3deb0d_0 146 | - pyasn1=0.4.8=py_0 147 | - pycparser=2.21=pyhd8ed1ab_0 148 | - pyjwt=2.4.0=pyhd8ed1ab_0 149 | - pyopenssl=21.0.0=pyhd8ed1ab_0 150 | - pyparsing=3.0.6=pyhd8ed1ab_0 151 | - pyqt=5.9.2=py38h05f1152_4 152 | - pyrsistent=0.18.0=py38heee7806_0 153 | - pysocks=1.7.1=py38h578d9bd_4 154 | - python=3.8.5=h7579374_1 155 | - python-dateutil=2.8.2=pyhd8ed1ab_0 156 | - python-flatbuffers=2.0=pyhd8ed1ab_0 157 | - python_abi=3.8=2_cp38 158 | - pytorch=1.10.1=py3.8_cuda11.3_cudnn8.2.0_0 159 | - pytorch-mutex=1.0=cuda 160 | - pytz=2021.3=pyhd8ed1ab_0 161 | - pyu2f=0.1.5=pyhd8ed1ab_0 162 | - 
qt=5.9.7=h5867ecd_1 163 | - readline=8.1=h27cfd23_0 164 | - requests=2.26.0=pyhd8ed1ab_1 165 | - rsa=4.8=pyhd8ed1ab_0 166 | - send2trash=1.8.0=pyhd8ed1ab_0 167 | - sip=4.19.13=py38he6710b0_0 168 | - six=1.16.0=pyhd3eb1b0_0 169 | - sniffio=1.2.0=py38h578d9bd_2 170 | - sqlite=3.37.0=hc218d9a_0 171 | - tensorflow=2.4.1=mkl_py38hb2083e0_0 172 | - tensorflow-base=2.4.1=mkl_py38h43e0292_0 173 | - tensorflow-estimator=2.5.0=pyh81a9013_1 174 | - terminado=0.12.1=py38h578d9bd_1 175 | - testpath=0.5.0=pyhd8ed1ab_0 176 | - timm=0.4.12=pyhd8ed1ab_0 177 | - tk=8.6.11=h1ccaba5_0 178 | - torchaudio=0.10.1=py38_cu113 179 | - torchvision=0.11.2=py38_cu113 180 | - tornado=6.1=py38h497a2fe_1 181 | - traitlets=5.1.1=pyhd8ed1ab_0 182 | - tsnecuda=3.0.0=cuda112py38hdfc3e5f_2 183 | - typing_extensions=3.10.0.2=pyh06a4308_0 184 | - urllib3=1.26.7=pyhd8ed1ab_0 185 | - wcwidth=0.2.5=pyh9f0ad1d_2 186 | - websocket-client=1.2.3=pyhd8ed1ab_0 187 | - wheel=0.37.0=pyhd3eb1b0_1 188 | - wrapt=1.12.1=py38h497a2fe_3 189 | - xorg-libxau=1.0.9=h7f98852_0 190 | - xorg-libxdmcp=1.1.3=h7f98852_0 191 | - xz=5.2.5=h7b6447c_0 192 | - yaml=0.2.5=h516909a_0 193 | - zeromq=4.3.4=h9c3ff4c_0 194 | - zlib=1.2.11=h7f8727e_4 195 | - zstd=1.4.9=haebb681_0 196 | - pip: 197 | - absl-py==1.0.0 198 | - aiohttp==3.8.1 199 | - aiosignal==1.2.0 200 | - albumentations==0.4.3 201 | - altair==4.2.0 202 | - antlr4-python3-runtime==4.8 203 | - argon2-cffi==21.3.0 204 | - argon2-cffi-bindings==21.2.0 205 | - async-timeout==4.0.2 206 | - backports-zoneinfo==0.2.1 207 | - base58==2.1.1 208 | - blinker==1.4 209 | - build==0.7.0 210 | - cachetools==4.2.4 211 | - cffi==1.15.0 212 | - click==7.1.2 213 | - cmake==3.18.4 214 | - colorspacious==1.1.2 215 | - configparser==5.2.0 216 | - cox==0.1.post3 217 | - dataclasses==0.6 218 | - debugpy==1.5.1 219 | - dill==0.3.4 220 | - docker-pycreds==0.4.0 221 | - einops==0.3.0 222 | - ffmpeg-python==0.2.0 223 | - filelock==3.4.2 224 | - frozenlist==1.2.0 225 | - fsspec==2021.11.1 226 | - geomloss==0.2.4 227 | - gitdb==4.0.9 228 | - gitpython==3.1.24 229 | - google-auth==2.3.3 230 | - gputil==1.4.0 231 | - grpcio==1.43.0 232 | - h5py==3.6.0 233 | - huggingface-hub==0.8.1 234 | - hydra-core==1.1.1 235 | - idna==3.3 236 | - imageio==2.9.0 237 | - imageio-ffmpeg==0.4.2 238 | - imgaug==0.2.6 239 | - ipykernel==6.6.0 240 | - ipython-genutils==0.2.0 241 | - ipywidgets==7.6.5 242 | - joblib==1.1.0 243 | - jsonschema==4.3.2 244 | - jupyterlab-widgets==1.0.2 245 | - kaggle==1.5.12 246 | - kmeans-pytorch==0.3 247 | - lightning-bolts==0.5.0 248 | - llvmlite==0.38.0 249 | - lmdb==1.3.0 250 | - markdown==3.3.6 251 | - more-itertools==8.12.0 252 | - multicoretsne==0.1 253 | - networkx==2.6.3 254 | - numba==0.55.1 255 | - numexpr==2.8.1 256 | - oauthlib==3.1.1 257 | - omegaconf==2.1.1 258 | - opencv-python==4.1.2.30 259 | - opencv-python-headless==4.5.5.62 260 | - ortools==9.2.9972 261 | - pandas==1.3.5 262 | - pathtools==0.1.2 263 | - pep517==0.12.0 264 | - promise==2.3 265 | - protobuf==3.19.1 266 | - psutil==5.9.0 267 | - pudb==2019.2 268 | - py3nvml==0.2.7 269 | - pyarrow==6.0.1 270 | - pyasn1-modules==0.2.8 271 | - pydeck==0.7.1 272 | - pydeprecate==0.3.1 273 | - pydot==1.4.2 274 | - pygments==2.11.0 275 | - pygraphviz==1.9 276 | - pympler==1.0.1 277 | - python-graphviz==0.20.1 278 | - python-slugify==5.0.2 279 | - pytorch-lightning==1.6.3 280 | - pytz-deprecation-shim==0.1.0.post0 281 | - pywavelets==1.2.0 282 | - pyyaml==6.0 283 | - pyzmq==22.3.0 284 | - regex==2021.11.10 285 | - requests-oauthlib==1.3.0 286 | - 
robustness==1.2.1.post2 287 | - sacremoses==0.0.46 288 | - scikit-image==0.19.1 289 | - scikit-learn==0.22.2 290 | - scipy==1.7.3 291 | - seaborn==0.11.2 292 | - sentry-sdk==1.5.1 293 | - setuptools==59.5.0 294 | - shortuuid==1.0.8 295 | - smmap==5.0.0 296 | - streamlit==1.3.1 297 | - subprocess32==3.5.4 298 | - tables==3.7.0 299 | - tensorboard==2.7.0 300 | - tensorboard-data-server==0.6.1 301 | - tensorboard-plugin-wit==1.8.0 302 | - tensorboardx==2.4.1 303 | - termcolor==1.1.0 304 | - test-tube==0.7.5 305 | - text-unidecode==1.3 306 | - threadpoolctl==3.0.0 307 | - tifffile==2021.11.2 308 | - tokenizers==0.12.1 309 | - toml==0.10.2 310 | - tomli==2.0.0 311 | - toolz==0.11.2 312 | - torchmetrics==0.6.2 313 | - tqdm==4.62.3 314 | - transformers==4.21.1 315 | - typing-extensions==4.2.0 316 | - tzdata==2021.5 317 | - tzlocal==4.1 318 | - urwid==2.1.2 319 | - validators==0.18.2 320 | - wandb==0.12.10 321 | - watchdog==2.1.6 322 | - webencodings==0.5.1 323 | - werkzeug==2.0.2 324 | - wget==3.2 325 | - widgetsnbextension==3.5.2 326 | - xmltodict==0.12.0 327 | - yarl==1.7.2 328 | - yaspin==2.1.0 329 | - zipp==3.7.0 330 | prefix: /opt/conda/envs/ctvsum 331 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import torch 4 | 5 | import h5py 6 | import json 7 | from scipy.stats import kendalltau, spearmanr 8 | from scipy.stats import rankdata 9 | 10 | import vsum_tools 11 | 12 | def get_rc_func(metric): 13 | if metric == 'kendalltau': 14 | f = lambda x, y: kendalltau(rankdata(-x), rankdata(-y)) 15 | elif metric == 'spearmanr': 16 | f = lambda x, y: spearmanr(x, y) 17 | else: 18 | raise RuntimeError 19 | return f 20 | 21 | kendall = get_rc_func('kendalltau') 22 | spearman = get_rc_func('spearmanr') 23 | 24 | def evaluate(model, test_loader, dataset, split, log_dir, save_h5=False): 25 | print("==> Test") 26 | with torch.no_grad(): 27 | model.eval() 28 | fms = [] 29 | eval_metric = 'avg' if dataset == 'tvsum' else 'max' 30 | if save_h5: 31 | experiment = '{}_split_{}.h5'.format(dataset, split) 32 | h5_res = h5py.File(osp.join(log_dir,experiment), 'w') 33 | 34 | tau_list = [] 35 | r_list = [] 36 | for data in test_loader: 37 | video = data 38 | video_name = str(data['video_name'][...].astype('U8')) 39 | cps = video['change_points'][...] 40 | video_feats = torch.Tensor(video['features'][...]) 41 | video_feats = video_feats.to(model.device) 42 | 43 | _, sum_attns = model(video_feats.unsqueeze(0)) 44 | scores = torch.mean(sum_attns[0],dim=0).cpu().numpy() 45 | 46 | if vars['dataset'] == 'tvsum': 47 | user_scores = video['user_scores'][...] 48 | num_users = user_scores.shape[0] 49 | tau = 0. 50 | r = 0. 51 | for us in user_scores: 52 | tau += kendall(us,scores)[0] 53 | r += spearman(us,scores)[0] 54 | tau_list.append(tau/num_users) 55 | r_list.append(r/num_users) 56 | else: 57 | user_scores = video['user_summary'][...] 58 | num_users = user_scores.shape[0] 59 | tau = 0. 60 | r = 0. 61 | for us in user_scores: 62 | us = us[::15] 63 | tau += kendall(us,scores)[0] 64 | r += spearman(us,scores)[0] 65 | tau_list.append(tau/num_users) 66 | r_list.append(r/num_users) 67 | 68 | num_frames = video['n_frames'][()] 69 | 70 | nfps = video['n_frame_per_seg'][...].tolist() 71 | positions = video['picks'][...] 72 | user_summary = video['user_summary'][...] 
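        # Following DRDSN's evaluation protocol (vsum_tools): upsample the sampled
        # frame scores back to the full video via `picks`, aggregate them over the
        # change-point segments, pick segments with a 0/1 knapsack under a
        # summary-length budget, and score the resulting keyshot summary against
        # the user summaries ('avg' overlap for TVSum, 'max' for SumMe).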
73 | 74 | machine_summary = vsum_tools.generate_summary(scores, cps, num_frames, nfps, positions) 75 | fm, _, _,machine_summary,user_summary = vsum_tools.evaluate_summary(machine_summary, user_summary, eval_metric) 76 | fms.append(fm) 77 | 78 | if save_h5: 79 | h5_res.create_dataset(video_name + '/score', data=scores) 80 | h5_res.create_dataset(video_name + '/machine_summary', data=machine_summary) 81 | h5_res.create_dataset(video_name + '/gtscore', data=video['gtscore'][...]) 82 | h5_res.create_dataset(video_name + '/fm', data=fm) 83 | h5_res.create_dataset(video_name + '/tau', data=tau_list[-1]) 84 | h5_res.create_dataset(video_name + '/r', data=r_list[-1]) 85 | h5_res.create_dataset(video_name + '/user_summary', data=user_summary) 86 | h5_res.create_dataset(video_name + '/change_points', data=cps) 87 | 88 | mean_tau = np.mean(tau_list) 89 | mean_r = np.mean(r_list) 90 | if save_h5: 91 | h5_res.close() 92 | mean_fm = np.mean(fms) 93 | print("Average F-score {:.1%}. Tau {:.5f}, R {:.5f}".format(mean_fm,mean_tau,mean_r)) 94 | 95 | return mean_fm,mean_tau, mean_r 96 | -------------------------------------------------------------------------------- /knapsack.py: -------------------------------------------------------------------------------- 1 | # A Dynamic Programming based Python Program for 0-1 Knapsack problem 2 | # Returns the maximum value that can be put in a knapsack of capacity W 3 | import numpy as np 4 | from ortools.algorithms import pywrapknapsack_solver 5 | 6 | 7 | def knapsack(W, wt, val, n): 8 | K = [[0 for x in range(W+1)] for x in range(n+1)] 9 | 10 | # Build table K[][] in bottom up manner 11 | for i in range(n+1): 12 | for w in range(W+1): 13 | if i==0 or w==0: 14 | K[i][w] = 0 15 | elif wt[i-1] <= w: 16 | K[i][w] = max(val[i-1] + K[i-1][w-wt[i-1]], K[i-1][w]) 17 | else: 18 | K[i][w] = K[i-1][w] 19 | 20 | 21 | best = K[n][W] 22 | 23 | amount = np.zeros(n) 24 | a = best 25 | j = n 26 | Y = W 27 | 28 | # j = j + 1; 29 | # 30 | # amount(j) = 1; 31 | # Y = Y - weights(j); 32 | # j = j - 1; 33 | # a = A(j + 1, Y + 1); 34 | 35 | while a > 0: 36 | while K[j][Y] == a: 37 | j = j - 1 38 | 39 | j = j + 1 40 | amount[j-1] = 1 41 | Y = Y - wt[j-1] 42 | j = j - 1 43 | a = K[j][Y] 44 | 45 | return amount 46 | 47 | 48 | def test_knapsack(): 49 | weights = [1 ,1 ,1, 1 ,2 ,2 ,3] 50 | values = [1 ,1 ,2 ,3, 1, 3 ,5] 51 | best = 13 52 | print(knapsack(7, weights, values, 7)) 53 | 54 | #=========================================== 55 | ''' 56 | ------------------------------------------------ 57 | Use dynamic programming (DP) to solve 0/1 knapsack problem 58 | Time complexity: O(nW), where n is number of items and W is capacity 59 | Author: Kaiyang Zhou 60 | Website: https://kaiyangzhou.github.io/ 61 | ------------------------------------------------ 62 | knapsack_dp(values,weights,n_items,capacity,return_all=False) 63 | Input arguments: 64 | 1. values: a list of numbers in either int or float, specifying the values of items 65 | 2. weights: a list of int numbers specifying weights of items 66 | 3. n_items: an int number indicating number of items 67 | 4. capacity: an int number indicating the knapsack capacity 68 | 5. return_all: whether return all info, defaulty is False (optional) 69 | Return: 70 | 1. picks: a list of numbers storing the positions of selected items 71 | 2. 
max_val: maximum value (optional) 72 | ------------------------------------------------ 73 | ''' 74 | def knapsack_dp(values,weights,n_items,capacity,return_all=False): 75 | check_inputs(values,weights,n_items,capacity) 76 | 77 | table = np.zeros((n_items+1,capacity+1),dtype=np.float32) 78 | keep = np.zeros((n_items+1,capacity+1),dtype=np.float32) 79 | 80 | for i in range(1,n_items+1): 81 | for w in range(0,capacity+1): 82 | wi = weights[i-1] # weight of current item 83 | vi = values[i-1] # value of current item 84 | if (wi <= w) and (vi + table[i-1,w-wi] > table[i-1,w]): 85 | table[i,w] = vi + table[i-1,w-wi] 86 | keep[i,w] = 1 87 | else: 88 | table[i,w] = table[i-1,w] 89 | 90 | picks = [] 91 | K = capacity 92 | 93 | for i in range(n_items,0,-1): 94 | if keep[i,K] == 1: 95 | picks.append(i) 96 | K -= weights[i-1] 97 | 98 | picks.sort() 99 | picks = [x-1 for x in picks] # change to 0-index 100 | 101 | if return_all: 102 | max_val = table[n_items,capacity] 103 | return picks,max_val 104 | return picks 105 | 106 | def check_inputs(values,weights,n_items,capacity): 107 | # check variable type 108 | assert(isinstance(values,list)) 109 | assert(isinstance(weights,list)) 110 | assert(isinstance(n_items,int)) 111 | assert(isinstance(capacity,int)) 112 | # check value type 113 | assert(all(isinstance(val,int) or isinstance(val,float) for val in values)) 114 | assert(all(isinstance(val,int) for val in weights)) 115 | # check validity of value 116 | assert(all(val >= 0 for val in weights)) 117 | assert(n_items > 0) 118 | assert(capacity > 0) 119 | 120 | def test_knapsack_dp(): 121 | values = [2,3,4] 122 | weights = [1,2,3] 123 | n_items = 3 124 | capacity = 3 125 | picks = knapsack_dp(values,weights,n_items,capacity) 126 | print (picks) 127 | 128 | 129 | 130 | osolver = pywrapknapsack_solver.KnapsackSolver( 131 | # pywrapknapsack_solver.KnapsackSolver.KNAPSACK_MULTIDIMENSION_BRANCH_AND_BOUND_SOLVER, 132 | pywrapknapsack_solver.KnapsackSolver.KNAPSACK_DYNAMIC_PROGRAMMING_SOLVER, 133 | 'test') 134 | 135 | def knapsack_ortools(values, weights, items, capacity ): 136 | scale = 1000 137 | values = np.array(values) 138 | weights = np.array(weights) 139 | values = (values * scale).astype(np.int) 140 | weights = (weights).astype(np.int) 141 | capacity = capacity 142 | 143 | osolver.Init(values.tolist(), [weights.tolist()], [capacity]) 144 | computed_value = osolver.Solve() 145 | packed_items = [x for x in range(0, len(weights)) 146 | if osolver.BestSolutionContains(x)] 147 | 148 | return packed_items 149 | 150 | 151 | if __name__ == "__main__": 152 | test_knapsack_dp() 153 | test_knapsack() -------------------------------------------------------------------------------- /lit_models/lit_aln_uni.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import pytorch_lightning as pl 7 | import matplotlib.pyplot as plt 8 | import seaborn as sns 9 | 10 | import vsum_tools 11 | 12 | from torch.utils.data import DataLoader 13 | from scipy.stats import kendalltau, spearmanr 14 | from scipy.stats import rankdata 15 | from pytorch_lightning.utilities.distributed import rank_zero_only 16 | from scipy.ndimage import gaussian_filter1d 17 | from einops import rearrange 18 | 19 | from datasets import DPLDatasetRand 20 | from utils import * 21 | from attention import TransformerEncoder 22 | def get_rc_func(metric): 23 | if metric == 'kendalltau': 24 | f = lambda x, y: 
kendalltau(rankdata(-x), rankdata(-y)) 25 | elif metric == 'spearmanr': 26 | f = lambda x, y: spearmanr(x, y) 27 | else: 28 | raise RuntimeError 29 | return f 30 | 31 | kendall = get_rc_func('kendalltau') 32 | spearman = get_rc_func('spearmanr') 33 | 34 | class LitModel(pl.LightningModule): 35 | def __init__(self, cfg, hpms): 36 | super().__init__() 37 | self.is_raw = cfg.is_raw 38 | self.use_unq = cfg.use_unq 39 | self.use_unif = cfg.use_unif 40 | 41 | self.model_cfg = cfg.model 42 | self.data_cfg = cfg.data 43 | self.hpms = hpms 44 | self.setup_cfg = cfg.setup 45 | self.lit_cfg = cfg.lightning 46 | self.model = TransformerEncoder(n_layers=self.hpms.n_layer, n_head=self.hpms.n_head, num_patches=self.hpms.num_frames) 47 | self.bce = nn.BCELoss() 48 | 49 | def forward(self,x): 50 | out, scores = self.model(x) 51 | return out, scores 52 | 53 | def get_values(self, feats, proj, scores, train=True): 54 | assert (len(feats.shape) == 3) and (len(proj.shape) == 3) 55 | with torch.no_grad(): 56 | norm_raw = F.normalize(feats, p=2, dim=-1) 57 | xy_raw = torch.einsum('bmc, bnc -> bmn', norm_raw, norm_raw) 58 | norm_proj = F.normalize(proj, p=2, dim=-1) 59 | xy = torch.einsum('bmc, bnc -> bmn', norm_proj, norm_proj) 60 | sort_ids = torch.argsort(xy_raw, -1, descending=True) 61 | 62 | diff_mat = 2 - 2 * xy 63 | L = feats.shape[1] 64 | S = int(L * self.hpms.ratio_s) 65 | K1 = int(L * self.hpms.ratio_k1) 66 | 67 | pos = torch.gather(diff_mat, -1, sort_ids[:,:,S:S+K1]) 68 | 69 | laln = pos.mean(dim=-1) 70 | 71 | lunif = diff_mat.mul(-2).exp().mean(dim=-1).log() 72 | 73 | if train: 74 | s = 20 75 | seg = rearrange(proj, 'b (s l) c -> (b s) l c', s=s) 76 | seg_feats = F.normalize(seg.mean(dim=1), dim=1) # (b s) c 77 | fv_xy = torch.einsum("bmc, nc -> bmn", norm_proj, seg_feats) # b m (b s) 78 | 79 | mask = torch.ones_like(fv_xy) 80 | # for i in range(len(mask)): 81 | # mask[i,:,i * s : (i+1) * s] = 0 82 | lunif_fv = (fv_xy.mul(4).exp() * mask).sum(dim=-1) / mask.sum(dim=-1) 83 | lunif_fv = lunif_fv.log() 84 | unq_target = (lunif_fv - lunif_fv.min(dim=-1, keepdim=True)[0]) / (lunif_fv.max(dim=-1, keepdim=True)[0] - lunif_fv.min(dim=-1, keepdim=True)[0]).add(1e-9) 85 | unq_target = unq_target * 0.5 + 0.25 86 | lunq = self.bce(scores, 1 - unq_target.detach()) 87 | return laln, lunif, lunif_fv, lunq 88 | else: 89 | return laln, lunif 90 | 91 | def training_step(self, batch, batch_idx): 92 | out, scores = self(batch) 93 | 94 | laln, lunif, lunif_fv, lunq = self.get_values(batch, out, scores) 95 | 96 | if self.use_unq: 97 | loss = laln.mean() + self.hpms.alpha * lunif.mean() + 0.1 * lunif_fv.mean() + 0.1 * lunq 98 | else: 99 | loss = laln.mean() + self.hpms.alpha * lunif.mean() 100 | 101 | return loss 102 | 103 | def on_test_start(self): 104 | self.fms = [] 105 | self.pres = [] 106 | self.recs = [] 107 | self.tau_list = [] 108 | self.r_list = [] 109 | self.eval_metric = 'avg' if self.data_cfg.name == 'tvsum' else 'max' 110 | # if self.setup_cfg.save_h5: 111 | # experiment = '{}_split_{}.h5'.format(self.data_cfg.name, self.data_cfg.split) 112 | # self.h5_res = h5py.File(osp.join(self.setup_cfg.logdir,experiment), 'w') 113 | def test_step(self,batch, batch_ids, eps=0.01): 114 | video = batch 115 | # video_name = str(video['video_name'][...].astype('U8')) 116 | cps = video['change_points'][...] 
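        # Test-time scoring: alignment (laln) and uniformity (lunif) terms are
        # computed either on the raw pretrained features (is_raw=True) or on the
        # refined projections, min-max normalised per video, and optionally
        # combined with the learned uniqueness scores by element-wise product.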
117 | video_feats = torch.Tensor(video['features'][...]).to(self.device) 118 | 119 | feats = video_feats.unsqueeze(0) 120 | # proj = self(feats.transpose(2,1)) 121 | proj, unq_scores = self(feats) 122 | 123 | laln_raw, lunif_raw = self.get_values(feats, feats, unq_scores, train=False) 124 | laln_raw = laln_raw.flatten().cpu() 125 | lunif_raw = lunif_raw.flatten().cpu() 126 | 127 | laln_raw = (laln_raw - laln_raw.min()) / (laln_raw.max() - laln_raw.min()) 128 | lunif_raw = (lunif_raw - lunif_raw.min()) / (lunif_raw.max() - lunif_raw.min()) 129 | 130 | laln, lunif = self.get_values(feats, proj, unq_scores, train=False) 131 | laln = laln.flatten().cpu() 132 | lunif = lunif.flatten().cpu() 133 | 134 | laln = (laln - laln.min()) / (laln.max() - laln.min()) 135 | lunif = (lunif - lunif.min()) / (lunif.max() - lunif.min()) 136 | 137 | unq_scores = unq_scores.cpu().flatten() 138 | unq_scores = (unq_scores - unq_scores.min()) / (unq_scores.max() - unq_scores.min()) 139 | if not self.is_raw: 140 | if self.use_unif: 141 | scores = laln * lunif 142 | else: 143 | scores = laln 144 | 145 | if self.use_unq: 146 | scores = scores * unq_scores 147 | else: 148 | if self.use_unif: 149 | scores = laln_raw * lunif_raw 150 | else: 151 | scores = laln_raw 152 | if self.data_cfg.name == 'tvsum': 153 | scores = np.exp(scores - 1) 154 | elif self.data_cfg.name == 'summe': 155 | scores = scores + 0.05 156 | else: 157 | raise NotImplementedError 158 | 159 | scores = gaussian_filter1d(scores, 1) 160 | 161 | if self.data_cfg.name == 'tvsum': 162 | user_scores = video['user_scores'][...] 163 | num_users = user_scores.shape[0] 164 | tau = 0. 165 | r = 0. 166 | for us in user_scores: 167 | tau += kendall(us,scores)[0] 168 | r += spearman(us,scores)[0] 169 | self.tau_list.append(tau/num_users) 170 | self.r_list.append(r/num_users) 171 | else: 172 | user_scores = video['user_summary'][...] 173 | num_users = user_scores.shape[0] 174 | tau = 0. 175 | r = 0. 176 | for us in user_scores: 177 | us = us[::15] 178 | tau += kendall(us,scores)[0] 179 | r += spearman(us,scores)[0] 180 | self.tau_list.append(tau/num_users) 181 | self.r_list.append(r/num_users) 182 | 183 | num_frames = video['n_frames'][()] 184 | 185 | nfps = video['n_frame_per_seg'][...].tolist() 186 | positions = video['picks'][...] 187 | user_summary = video['user_summary'][...] 
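        # Per-video F1, precision and recall (together with Kendall's tau and
        # Spearman's rho computed above) are accumulated here and averaged over
        # the split in test_epoch_end.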
188 | 189 | machine_summary = vsum_tools.generate_summary(scores, cps, num_frames, nfps, positions) 190 | fm, pre, rec,machine_summary,user_summary = vsum_tools.evaluate_summary(machine_summary, user_summary, self.eval_metric) 191 | self.fms.append(fm) 192 | self.pres.append(pre) 193 | self.recs.append(rec) 194 | 195 | @rank_zero_only 196 | def test_epoch_end(self, outputs): 197 | mean_tau = torch.mean(torch.tensor(self.all_gather(self.tau_list))) 198 | mean_r = torch.mean(torch.tensor(self.all_gather(self.r_list))) 199 | # if self.setup_cfg.save_h5: 200 | # self.h5_res.close() 201 | mean_fm = torch.mean(torch.tensor(self.all_gather(self.fms))) 202 | mean_pre = torch.mean(torch.tensor(self.all_gather(self.pres))) 203 | mean_rec = torch.mean(torch.tensor(self.all_gather(self.recs))) 204 | 205 | self.log_dict({"F1": mean_fm, 206 | "tau": mean_tau, 207 | "rho": mean_r}) 208 | 209 | def train_dataloader(self): 210 | dataset = DPLDatasetRand(self.data_cfg,mode='train') 211 | print(len(dataset)) 212 | dataloader = DataLoader(dataset, 213 | num_workers=self.data_cfg.num_workers, 214 | batch_size=self.hpms.batch_size, 215 | shuffle=True, 216 | pin_memory=True) 217 | self.dataloader_len = len(dataloader) 218 | print("Number of training videos: {}".format(len(dataset))) 219 | return dataloader 220 | 221 | def test_dataloader(self): 222 | dataset = DPLDatasetRand(self.data_cfg, mode='test') 223 | dataloader = DataLoader(dataset, 224 | num_workers=0, 225 | batch_size=1, 226 | shuffle=False, 227 | collate_fn=test_collate, 228 | pin_memory=True) 229 | print("Number of test videos: {}".format(len(dataset))) 230 | return dataloader 231 | 232 | def configure_optimizers(self): 233 | optimizer = torch.optim.Adam(list(self.model.parameters()), 234 | lr=self.hpms.lr, 235 | weight_decay=self.hpms.weight_decay) 236 | return optimizer 237 | 238 | @torch.no_grad() 239 | def log_hist(self,name,inp_): 240 | inp = inp_.detach().cpu().flatten() 241 | fig = plt.figure() 242 | ax = fig.gca() 243 | sns.histplot(inp,bins=50) 244 | plt.tight_layout() 245 | 246 | add_plot(self.logger.experiment, name, self.global_step) -------------------------------------------------------------------------------- /lit_models/lit_aln_uni_y8.py: -------------------------------------------------------------------------------- 1 | from re import L 2 | import h5py 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import pytorch_lightning as pl 9 | import vsum_tools 10 | 11 | from torch.utils.data import DataLoader 12 | 13 | from scipy.stats import kendalltau, spearmanr 14 | from scipy.stats import rankdata 15 | from pytorch_lightning.utilities.distributed import rank_zero_only 16 | from einops import rearrange 17 | from scipy.ndimage import gaussian_filter1d 18 | 19 | from datasets import Youtube8M 20 | from utils import * 21 | from attention import TransformerEncoder 22 | 23 | def get_rc_func(metric): 24 | if metric == 'kendalltau': 25 | f = lambda x, y: kendalltau(rankdata(-x), rankdata(-y)) 26 | elif metric == 'spearmanr': 27 | f = lambda x, y: spearmanr(x, y) 28 | else: 29 | raise RuntimeError 30 | return f 31 | 32 | kendall = get_rc_func('kendalltau') 33 | spearman = get_rc_func('spearmanr') 34 | 35 | class LitModel(pl.LightningModule): 36 | def __init__(self, cfg, hpms): 37 | super().__init__() 38 | # self.save_hyperparameters() 39 | self.is_raw = cfg.is_raw 40 | self.use_unq = hpms.use_unq 41 | self.use_unif = hpms.use_unif 42 | 43 | self.model_cfg = cfg.model 44 | 
self.data_cfg = cfg.data 45 | self.hpms = hpms 46 | self.setup_cfg = cfg.setup 47 | self.lit_cfg = cfg.lightning 48 | self.model = TransformerEncoder(n_layers=self.hpms.n_layer, n_head=self.hpms.n_head, num_patches=self.hpms.num_frames) 49 | self.bce = nn.BCELoss() 50 | 51 | def forward(self,x): 52 | out, scores = self.model(x) 53 | return out, scores 54 | 55 | def get_values(self, feats, proj, scores, train=True): 56 | 57 | assert (len(feats.shape) == 3) and (len(proj.shape) == 3) 58 | with torch.no_grad(): 59 | norm_raw = F.normalize(feats, p=2, dim=-1) 60 | xy_raw = torch.einsum('bmc, bnc -> bmn', norm_raw, norm_raw) 61 | norm_proj = F.normalize(proj, p=2, dim=-1) 62 | xy = torch.einsum('bmc, bnc -> bmn', norm_proj, norm_proj) 63 | sort_ids = torch.argsort(xy_raw, -1, descending=True) 64 | 65 | diff_mat = 2 - 2 * xy 66 | L = feats.shape[1] 67 | S = int(L * self.hpms.ratio_s) 68 | K1 = int(L * self.hpms.ratio_k1) 69 | pos = torch.gather(diff_mat, -1, sort_ids[:,:,S:S+K1]) 70 | 71 | laln = pos.mean(dim=-1) 72 | if train: 73 | lunif = diff_mat.mul(-2).exp().mean(dim=-1).log() 74 | else: 75 | lunif = diff_mat.mul(-2).exp().mean(dim=-1).log() 76 | if train: 77 | s = 20 78 | seg = rearrange(proj, 'b (s l) c -> (b s) l c', s=s) 79 | seg_feats = F.normalize(seg.mean(dim=1), dim=1) # (b s) c 80 | fv_xy = torch.einsum("bmc, nc -> bmn", norm_proj, seg_feats) # b m (b s) 81 | 82 | mask = torch.ones_like(fv_xy) 83 | # for i in range(len(mask)): 84 | # mask[i,:,i * s : (i+1) * s] = 0 85 | lunif_fv = (fv_xy.mul(4).exp() * mask).sum(dim=-1) / mask.sum(dim=-1) 86 | lunif_fv = lunif_fv.log() 87 | unq_target = (lunif_fv - lunif_fv.min(dim=-1, keepdim=True)[0]) / (lunif_fv.max(dim=-1, keepdim=True)[0] - lunif_fv.min(dim=-1, keepdim=True)[0]).add(1e-9) 88 | unq_target = unq_target * 0.5 + 0.25 89 | lunq = self.bce(scores, 1 - unq_target.detach()) 90 | return laln, lunif, lunif_fv, lunq 91 | else: 92 | return laln, lunif 93 | 94 | def training_step(self, batch, batch_idx): 95 | out, scores = self(batch) 96 | 97 | laln, lunif, lunif_fv, lunq = self.get_values(batch, out, scores) 98 | 99 | if self.use_unq: 100 | loss = laln.mean() + self.hpms.alpha * lunif.mean() + 0.1 * lunif_fv.mean() + 0.1 * lunq 101 | else: 102 | loss = laln.mean() + self.hpms.alpha * lunif.mean() 103 | 104 | return loss 105 | 106 | def dequantize(self,feat_vector, max_quantized_value=2, min_quantized_value=-2): 107 | 108 | ''' Dequantize the feature from the byte format to the float format. ''' 109 | assert max_quantized_value > min_quantized_value 110 | quantized_range = max_quantized_value - min_quantized_value 111 | scalar = quantized_range / 255.0 112 | bias = (quantized_range / 512.0) + min_quantized_value 113 | return feat_vector * scalar + bias 114 | 115 | @rank_zero_only 116 | def on_train_epoch_end(self): 117 | with torch.no_grad(): 118 | for name in ['tvsum', 'summe']: 119 | dataset = h5py.File(self.data_cfg.paths.interim+'/{}_inception_v3.h5'.format(name), 'r') 120 | fms = [] 121 | pres = [] 122 | recs = [] 123 | tau_list = [] 124 | r_list = [] 125 | eval_metric = 'avg' if name == 'tvsum' else 'max' 126 | 127 | for it, vid in enumerate(dataset): 128 | video = dataset[vid] 129 | cps = video['change_points'][...] 
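                # The TVSum/SumMe Inception-v3 features are stored in the same
                # quantized byte format as the YouTube-8M features, so dequantize
                # them back to floats before scoring.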
130 | video_feats = self.dequantize(video['features'][...]) 131 | video_feats = torch.Tensor(video_feats).float().to(self.device) 132 | 133 | feats = video_feats.unsqueeze(0) 134 | proj, unq_scores = self(feats) 135 | 136 | laln_raw, lunif_raw = self.get_values(feats, feats, unq_scores, train=False) 137 | laln_raw = laln_raw.flatten().cpu() 138 | lunif_raw = lunif_raw.flatten().cpu() 139 | 140 | laln_raw = (laln_raw - laln_raw.min()) / (laln_raw.max() - laln_raw.min()) 141 | lunif_raw = (lunif_raw - lunif_raw.min()) / (lunif_raw.max() - lunif_raw.min()) 142 | 143 | laln, lunif = self.get_values(feats, proj, unq_scores, train=False) 144 | laln = laln.flatten().cpu() 145 | lunif = lunif.flatten().cpu() 146 | 147 | laln = (laln - laln.min()) / (laln.max() - laln.min()) 148 | lunif = (lunif - lunif.min()) / (lunif.max() - lunif.min()) 149 | 150 | unq_scores = unq_scores.cpu().flatten() 151 | unq_scores = (unq_scores - unq_scores.min()) / (unq_scores.max() - unq_scores.min()) 152 | 153 | if not self.is_raw: 154 | if self.use_unif: 155 | scores = laln * lunif 156 | else: 157 | scores = laln 158 | 159 | if self.use_unq: 160 | scores = scores * unq_scores 161 | else: 162 | if self.use_unif: 163 | scores = laln_raw * lunif_raw 164 | else: 165 | scores = laln_raw 166 | 167 | if name == 'tvsum': 168 | scores = np.exp(scores - 1) 169 | elif name == 'summe': 170 | scores = scores + 0.05 171 | else: 172 | raise NotImplementedError 173 | scores = gaussian_filter1d(scores, 1) 174 | 175 | if name == 'tvsum': 176 | user_scores = video['user_scores'][...] 177 | num_users = user_scores.shape[0] 178 | tau = 0. 179 | r = 0. 180 | for us in user_scores: 181 | tau += kendall(us,scores)[0] 182 | r += spearman(us,scores)[0] 183 | tau_list.append(tau/num_users) 184 | r_list.append(r/num_users) 185 | else: 186 | user_scores = video['user_summary'][...] 187 | num_users = user_scores.shape[0] 188 | tau = 0. 189 | r = 0. 190 | for us in user_scores: 191 | us = us[:len(us)-(len(us) % 15)] 192 | us = us.reshape(-1, 15) 193 | us = us.mean(-1)#us[::15] 194 | if len(us) > len(scores): 195 | us = us[:len(scores)] 196 | elif len(us) < len(scores): 197 | scores = scores[:len(us)] 198 | tau += kendall(us,scores)[0] 199 | r += spearman(us,scores)[0] 200 | tau_list.append(tau/num_users) 201 | r_list.append(r/num_users) 202 | 203 | num_frames = video['n_frames'][()] 204 | 205 | nfps = video['n_frame_per_seg'][...].tolist() 206 | positions = video['picks'][...] 207 | user_summary = video['user_summary'][...] 
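                # YouTube-8M carries no summary annotations, so F1/tau/rho are
                # monitored on TVSum and SumMe at the end of every training epoch
                # rather than in a separate test loop.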
208 | 209 | machine_summary = vsum_tools.generate_summary(scores, cps, num_frames, nfps, positions) 210 | fm, pre, rec,machine_summary,user_summary = vsum_tools.evaluate_summary(machine_summary, user_summary, eval_metric) 211 | fms.append(fm) 212 | pres.append(pre) 213 | recs.append(rec) 214 | 215 | 216 | mean_tau = torch.mean(torch.tensor(self.all_gather(tau_list))) 217 | mean_r = torch.mean(torch.tensor(self.all_gather(r_list))) 218 | 219 | mean_fm = torch.mean(torch.tensor(self.all_gather(fms))) 220 | print("{}: Average F-score {:.2%}, Tau {:.4f}, R {:.4f}".format(name, mean_fm, mean_tau, mean_r)) 221 | 222 | self.log(f"{name}/fm", mean_fm) 223 | self.log(f"{name}/tau", mean_tau) 224 | self.log(f"{name}/r", mean_r) 225 | 226 | def train_dataloader(self): 227 | dataset = Youtube8M(self.data_cfg, self.hpms.num_frames) 228 | dataloader = DataLoader(dataset, 229 | num_workers=self.data_cfg.num_workers, 230 | batch_size=self.hpms.batch_size, 231 | shuffle=True, 232 | drop_last=True, 233 | pin_memory=True) 234 | self.dataloader_len = len(dataloader) 235 | print("Number of training videos: {}".format(len(dataset))) 236 | return dataloader 237 | def configure_optimizers(self): 238 | optimizer = torch.optim.Adam(list(self.model.parameters()), 239 | lr=self.hpms.lr, 240 | weight_decay=self.hpms.weight_decay) 241 | return optimizer -------------------------------------------------------------------------------- /main_ablations.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, datetime, glob 2 | import numpy as np 3 | 4 | from omegaconf import OmegaConf 5 | 6 | import pytorch_lightning as pl 7 | from pytorch_lightning.plugins import DDPPlugin 8 | from pytorch_lightning import seed_everything 9 | from pytorch_lightning.trainer import Trainer 10 | from pytorch_lightning.callbacks import ModelCheckpoint 11 | from pytorch_lightning.loggers import WandbLogger, CSVLogger 12 | 13 | 14 | from lit_models.lit_aln_uni import LitModel 15 | from custom_callbacks import SetupCallback 16 | 17 | def get_parser(**parser_kwargs): 18 | def str2bool(v): 19 | if isinstance(v, bool): 20 | return v 21 | if v.lower() in ("yes", "true", "t", "y", "1"): 22 | return True 23 | elif v.lower() in ("no", "false", "f", "n", "0"): 24 | return False 25 | else: 26 | raise argparse.ArgumentTypeError("Boolean value expected.") 27 | 28 | parser = argparse.ArgumentParser(**parser_kwargs) 29 | parser.add_argument( 30 | "-n", 31 | "--name", 32 | type=str, 33 | const=True, 34 | nargs="?", 35 | help="postfix for logdir", 36 | ) 37 | 38 | parser.add_argument( 39 | "-b", 40 | "--base", 41 | nargs="*", 42 | metavar="base_config.yaml", 43 | help="paths to base configs. Loaded from left-to-right. " 44 | "Parameters can be overwritten or added with command-line options of the form `--key value`.", 45 | default=list(), 46 | ) 47 | 48 | return parser 49 | 50 | def nondefault_trainer_args(opt): 51 | 52 | parser = argparse.ArgumentParser() 53 | parser = Trainer.add_argparse_args(parser) 54 | args = parser.parse_args([]) 55 | 56 | return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) 57 | 58 | if __name__ == "__main__": 59 | sys.path.append(os.getcwd()) 60 | 61 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 62 | 63 | parser = get_parser() 64 | parser = Trainer.add_argparse_args(parser) 65 | 66 | opt, unknown = parser.parse_known_args() 67 | if opt.name and opt.resume: 68 | raise ValueError( 69 | "-n/--name and -r/--resume cannot be specified both." 
70 | "If you want to resume training in a new log folder, " 71 | "use -n/--name in combination with --resume_from_checkpoint" 72 | ) 73 | if opt.resume: 74 | if not os.path.exists(opt.resume): 75 | raise ValueError("Cannot find {}".format(opt.resume)) 76 | if os.path.isfile(opt.resume): 77 | paths = opt.resume.split("/") 78 | idx = len(paths)-paths[::-1].index("logs")+1 79 | logdir = "/".join(paths[:idx]) 80 | ckpt = opt.resume 81 | else: 82 | assert os.path.isdir(opt.resume), opt.resume 83 | logdir = opt.resume.rstrip("/") 84 | ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") 85 | 86 | opt.resume_from_checkpoint = ckpt 87 | 88 | base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) 89 | opt.base = base_configs+opt.base 90 | _tmp = logdir.split("/") 91 | name = _tmp[_tmp.index("logs")+1] 92 | nowname = _tmp[_tmp.index("logs")+2] 93 | 94 | else: 95 | if opt.name: 96 | name = opt.name 97 | elif opt.base: 98 | cfg_fname = os.path.split(opt.base[0])[-1] 99 | cfg_name = os.path.splitext(cfg_fname)[0] 100 | name = cfg_name.split('_')[0] 101 | else: 102 | name = "no_name" 103 | nowname = now 104 | 105 | configs = [OmegaConf.load(cfg) for cfg in opt.base] 106 | seed_everything(0) 107 | # merge configs 108 | cli = OmegaConf.from_dotlist(unknown) 109 | configs = OmegaConf.merge(*configs, cli) 110 | 111 | # define log directories 112 | logdir = os.path.join("logs",name, nowname) 113 | ckptdir = os.path.join(logdir, "checkpoints") 114 | cfgdir = os.path.join(logdir, "configs") 115 | configs.setup.logdir = logdir 116 | configs.setup.ckptdir = ckptdir 117 | configs.setup.cfgdir = cfgdir 118 | #### set up important trainer flags 119 | # merge trainer cli with config 120 | trainer_cfg = configs.lightning.get("trainer", OmegaConf.create()) 121 | for k in nondefault_trainer_args(opt): # this incorporates cli into trainer configs 122 | trainer_cfg[k] = getattr(opt, k) 123 | 124 | if not "gpus" in trainer_cfg: 125 | cpu = True 126 | else: 127 | gpuinfo = trainer_cfg["gpus"] 128 | print(f"Running on GPUs {gpuinfo}") 129 | cpu = False 130 | 131 | trainer_opt = argparse.Namespace(**trainer_cfg) 132 | configs.lightning.trainer = trainer_cfg 133 | 134 | #### configure learning rate 135 | lr = configs.hparams.lr 136 | if not cpu: 137 | ngpu = len(configs.lightning.trainer.gpus.strip(",").split(',')) 138 | else: 139 | ngpu = 0 140 | 141 | 142 | 143 | trainer_kwargs = dict() 144 | ### configure callbacks 145 | checkpoint_callback = ModelCheckpoint(dirpath=ckptdir, 146 | filename="{epoch:06}", 147 | verbose=True, 148 | save_last=True) 149 | setup_callback = SetupCallback(resume=opt.resume, 150 | now=now, 151 | logdir=logdir, 152 | ckptdir=ckptdir, 153 | cfgdir=cfgdir, 154 | config=configs) 155 | 156 | trainer_kwargs["callbacks"] = [checkpoint_callback, 157 | setup_callback] 158 | # ### configure logger 159 | # if configs.is_raw: 160 | # logname = name + '_raw' 161 | # logname = name + '_align' 162 | # if configs.use_unif: 163 | # logname += 'align_unif' 164 | # if (not confgs_is_raw) and (configs.use_unq): 165 | # logname += '_unq' 166 | 167 | # logger = CSVLogger(logdir, name=logname) 168 | 169 | # trainer_kwargs["logger"] = logger 170 | 171 | if configs.is_raw: 172 | configs.data.setting = 'Transfer' 173 | if configs.data.name == 'summe': 174 | configs.hparams.batch_size = 8 175 | ### initialize trainer 176 | if configs.data.setting != 'Transfer': 177 | results = {'F1':[], 'tau':[], 'rho':[]} 178 | for split in range(5): 179 | configs.data.split = split 180 | 181 | trainer = 
Trainer.from_argparse_args(trainer_opt, 182 | **trainer_kwargs, 183 | plugins=DDPPlugin(find_unused_parameters=True)) 184 | 185 | model = LitModel(configs, configs.hparams) 186 | if not configs.is_raw: 187 | trainer.fit(model) 188 | results_split = trainer.test(model)[0] 189 | for key in results_split: 190 | results[key].append(results_split[key]) 191 | print("Average F-score {:.2%}, Tau {:.4f}, R {:.4f}".format(np.mean(results['F1']), 192 | np.mean(results['tau']), 193 | np.mean(results['rho']))) 194 | print(f'{configs.data.name}_{configs.data.setting}_trained_lunif_{configs.use_unif}_unq_{configs.use_unq}') 195 | else: 196 | trainer = Trainer.from_argparse_args(trainer_opt, 197 | **trainer_kwargs, 198 | plugins=DDPPlugin(find_unused_parameters=True)) 199 | model = LitModel(configs, configs.hparams) 200 | if not configs.is_raw: 201 | trainer.fit(model) 202 | results = trainer.test(model)[0] 203 | print("Average F-score {:.2%}, Tau {:.4f}, R {:.4f}".format(results['F1'], results['tau'], results['rho'])) 204 | if configs.is_raw: 205 | print(f'{configs.data.name}_raw_lunif_{configs.use_unif}') 206 | else: 207 | print(f'{configs.data.name}_lunif_{configs.use_unif}_unq_{configs.use_unq}') -------------------------------------------------------------------------------- /main_y8.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, datetime, glob 2 | 3 | from omegaconf import OmegaConf 4 | 5 | import pytorch_lightning as pl 6 | from pytorch_lightning.plugins import DDPPlugin 7 | from pytorch_lightning import seed_everything 8 | from pytorch_lightning.trainer import Trainer 9 | from pytorch_lightning.callbacks import ModelCheckpoint 10 | from pytorch_lightning.loggers import WandbLogger 11 | 12 | 13 | from lit_models.lit_aln_uni_y8 import LitModel 14 | from custom_callbacks import SetupCallback 15 | 16 | def get_parser(**parser_kwargs): 17 | def str2bool(v): 18 | if isinstance(v, bool): 19 | return v 20 | if v.lower() in ("yes", "true", "t", "y", "1"): 21 | return True 22 | elif v.lower() in ("no", "false", "f", "n", "0"): 23 | return False 24 | else: 25 | raise argparse.ArgumentTypeError("Boolean value expected.") 26 | 27 | parser = argparse.ArgumentParser(**parser_kwargs) 28 | parser.add_argument( 29 | "-n", 30 | "--name", 31 | type=str, 32 | const=True, 33 | nargs="?", 34 | help="postfix for logdir", 35 | ) 36 | 37 | parser.add_argument( 38 | "-b", 39 | "--base", 40 | nargs="*", 41 | metavar="base_config.yaml", 42 | help="paths to base configs. Loaded from left-to-right. " 43 | "Parameters can be overwritten or added with command-line options of the form `--key value`.", 44 | default=list(), 45 | ) 46 | 47 | parser.add_argument( 48 | "--seed", 49 | type=int, 50 | const=True, 51 | default=0, 52 | nargs="?", 53 | help="random seed", 54 | ) 55 | return parser 56 | 57 | def nondefault_trainer_args(opt): 58 | 59 | parser = argparse.ArgumentParser() 60 | parser = Trainer.add_argparse_args(parser) 61 | args = parser.parse_args([]) 62 | 63 | return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) 64 | 65 | if __name__ == "__main__": 66 | sys.path.append(os.getcwd()) 67 | 68 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 69 | 70 | parser = get_parser() 71 | parser = Trainer.add_argparse_args(parser) 72 | 73 | opt, unknown = parser.parse_known_args() 74 | if opt.name and opt.resume: 75 | raise ValueError( 76 | "-n/--name and -r/--resume cannot be specified both." 
77 | "If you want to resume training in a new log folder, " 78 | "use -n/--name in combination with --resume_from_checkpoint" 79 | ) 80 | if opt.resume: 81 | if not os.path.exists(opt.resume): 82 | raise ValueError("Cannot find {}".format(opt.resume)) 83 | if os.path.isfile(opt.resume): 84 | paths = opt.resume.split("/") 85 | idx = len(paths)-paths[::-1].index("logs")+1 86 | logdir = "/".join(paths[:idx]) 87 | ckpt = opt.resume 88 | else: 89 | assert os.path.isdir(opt.resume), opt.resume 90 | logdir = opt.resume.rstrip("/") 91 | ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") 92 | 93 | opt.resume_from_checkpoint = ckpt 94 | 95 | base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) 96 | opt.base = base_configs+opt.base 97 | _tmp = logdir.split("/") 98 | name = _tmp[_tmp.index("logs")+1] 99 | nowname = _tmp[_tmp.index("logs")+2] 100 | 101 | else: 102 | if opt.name: 103 | name = opt.name 104 | elif opt.base: 105 | cfg_fname = os.path.split(opt.base[0])[-1] 106 | cfg_name = os.path.splitext(cfg_fname)[0] 107 | name = cfg_name.split('_')[0] 108 | else: 109 | name = "no_name" 110 | nowname = now 111 | 112 | configs = [OmegaConf.load(cfg) for cfg in opt.base] 113 | 114 | seed_everything(0) 115 | # merge configs 116 | cli = OmegaConf.from_dotlist(unknown) 117 | configs = OmegaConf.merge(*configs, cli) 118 | # define log directories 119 | logdir = os.path.join("logs",name, nowname) 120 | ckptdir = os.path.join(logdir, "checkpoints") 121 | cfgdir = os.path.join(logdir, "configs") 122 | configs.setup.logdir = logdir 123 | configs.setup.ckptdir = ckptdir 124 | configs.setup.cfgdir = cfgdir 125 | #### set up important trainer flags 126 | if configs.is_raw: 127 | configs.lightning.trainer.max_epochs = 1 128 | # merge trainer cli with config 129 | trainer_cfg = configs.lightning.get("trainer", OmegaConf.create()) 130 | for k in nondefault_trainer_args(opt): # this incorporates cli into trainer configs 131 | trainer_cfg[k] = getattr(opt, k) 132 | 133 | if not "gpus" in trainer_cfg: 134 | cpu = True 135 | else: 136 | gpuinfo = trainer_cfg["gpus"] 137 | print(f"Running on GPUs {gpuinfo}") 138 | cpu = False 139 | 140 | trainer_opt = argparse.Namespace(**trainer_cfg) 141 | configs.lightning.trainer = trainer_cfg 142 | 143 | #### configure learning rate 144 | lr = configs.hparams.lr 145 | if not cpu: 146 | ngpu = len(configs.lightning.trainer.gpus.strip(",").split(',')) 147 | else: 148 | ngpu = 0 149 | 150 | ### initialize pl model 151 | # model = LitModel(configs, configs.hparams) 152 | 153 | trainer_kwargs = dict() 154 | ### configure callbacks 155 | checkpoint_callback = ModelCheckpoint(dirpath=ckptdir, 156 | filename="{epoch:06}", 157 | verbose=True, 158 | save_last=True) 159 | setup_callback = SetupCallback(resume=opt.resume, 160 | now=now, 161 | logdir=logdir, 162 | ckptdir=ckptdir, 163 | cfgdir=cfgdir, 164 | config=configs) 165 | 166 | trainer_kwargs["callbacks"] = [checkpoint_callback, 167 | setup_callback] 168 | ### configure logger 169 | # logger = WandbLogger(project=configs.setup.wandb_name, 170 | # name=name+"_"+nowname, 171 | # id=nowname) 172 | 173 | # trainer_kwargs["logger"] = logger 174 | ### initialize trainer 175 | # trainer = Trainer.from_argparse_args(trainer_opt, 176 | # **trainer_kwargs, 177 | # accelerator="gpu", 178 | # devices=-1, 179 | # plugins=DDPPlugin(find_unused_parameters=True)) 180 | 181 | # trainer.fit(model) 182 | 183 | 184 | trainer = Trainer.from_argparse_args(trainer_opt, 185 | **trainer_kwargs, 186 | accelerator="gpu", 187 | devices=-1, 188 | 
plugins=DDPPlugin(find_unused_parameters=True)) 189 | model = LitModel(configs, configs.hparams) 190 | 191 | trainer.fit(model) -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | class EncoderLayer(nn.Module): 7 | ''' Compose with two layers ''' 8 | 9 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 10 | super(EncoderLayer, self).__init__() 11 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 12 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 13 | 14 | def forward(self, enc_input, slf_attn_mask=None): 15 | enc_output, enc_slf_attn = self.slf_attn( 16 | enc_input, enc_input, enc_input, mask=slf_attn_mask) 17 | enc_output = self.pos_ffn(enc_output) 18 | return enc_output, enc_slf_attn 19 | 20 | class ScaledDotProductAttention(nn.Module): 21 | ''' Scaled Dot-Product Attention ''' 22 | 23 | def __init__(self, temperature, attn_dropout=0.1): 24 | super().__init__() 25 | self.temperature = temperature 26 | self.dropout = nn.Dropout(attn_dropout) 27 | 28 | def forward(self, q, k, v, mask=None): 29 | 30 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) 31 | 32 | if mask is not None: 33 | attn = attn.masked_fill(mask == 0, -1e9) 34 | 35 | attn = self.dropout(F.softmax(attn, dim=-1)) 36 | output = torch.matmul(attn, v) 37 | 38 | return output, attn 39 | 40 | 41 | class MultiHeadAttention(nn.Module): 42 | ''' Multi-Head Attention module ''' 43 | 44 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 45 | super().__init__() 46 | 47 | self.n_head = n_head 48 | self.d_k = d_k 49 | self.d_v = d_v 50 | 51 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) 52 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) 53 | self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) 54 | self.fc = nn.Linear(n_head * d_v, d_model, bias=False) 55 | 56 | self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5) 57 | 58 | self.dropout = nn.Dropout(dropout) 59 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 60 | 61 | 62 | def forward(self, q, k, v, mask=None): 63 | 64 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 65 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) 66 | 67 | residual = q 68 | 69 | # Pass through the pre-attention projection: b x lq x (n*dv) 70 | # Separate different heads: b x lq x n x dv 71 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 72 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 73 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 74 | 75 | # Transpose for attention dot product: b x n x lq x dv 76 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 77 | 78 | if mask is not None: 79 | mask = mask.unsqueeze(1) # For head axis broadcasting. 
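            # (the unsqueezed mask broadcasts over the head axis; ScaledDotProductAttention fills
            #  masked positions with -1e9 before the softmax, so they receive near-zero attention weight)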
80 | 
81 |         q, attn = self.attention(q, k, v, mask=mask)
82 | 
83 |         # Transpose to move the head dimension back: b x lq x n x dv
84 |         # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
85 |         q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
86 |         q = self.dropout(self.fc(q))
87 |         q += residual
88 | 
89 |         q = self.layer_norm(q)
90 | 
91 |         return q, attn
92 | 
93 | 
94 | class PositionwiseFeedForward(nn.Module):
95 |     ''' A two-feed-forward-layer module '''
96 | 
97 |     def __init__(self, d_in, d_hid, dropout=0.1):
98 |         super().__init__()
99 |         self.w_1 = nn.Linear(d_in, d_hid) # position-wise
100 |         self.w_2 = nn.Linear(d_hid, d_in) # position-wise
101 |         self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
102 |         self.dropout = nn.Dropout(dropout)
103 | 
104 |     def forward(self, x):
105 | 
106 |         residual = x
107 | 
108 |         x = self.w_2(F.relu(self.w_1(x)))
109 |         x = self.dropout(x)
110 |         x += residual
111 | 
112 |         x = self.layer_norm(x)
113 | 
114 |         return x
--------------------------------------------------------------------------------
/run_ablation.sh:
--------------------------------------------------------------------------------
1 | python main_ablations.py \
2 |         -t True \
3 |         --base configs/aln_uni_config.yml \
4 |         --gpus 0,
5 | 
--------------------------------------------------------------------------------
/run_y8.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main_y8.py \
2 |                        --base configs/aln_uni_y8_config.yml \
3 |                        --seed 0
4 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import io
2 | import json
3 | import warnings  # used by _no_grad_trunc_normal_ below
4 | import numpy as np
5 | import torch
6 | import math
7 | import PIL.Image
8 | import matplotlib.pyplot as plt
9 | import wandb
10 | 
11 | from torchvision.transforms import ToTensor
12 | 
13 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
14 |     warmup_schedule = np.array([])
15 |     warmup_iters = warmup_epochs * niter_per_ep
16 |     if warmup_epochs > 0:
17 |         warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
18 | 
19 |     iters = np.arange(epochs * niter_per_ep - warmup_iters)
20 |     schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
21 | 
22 |     schedule = np.concatenate((warmup_schedule, schedule))
23 |     assert len(schedule) == epochs * niter_per_ep
24 |     return schedule
25 | 
26 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
27 |     # Cut & paste from PyTorch official master until it's in a few official releases - RW
28 |     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
29 |     def norm_cdf(x):
30 |         # Computes standard normal cumulative distribution function
31 |         return (1. + math.erf(x / math.sqrt(2.))) / 2.
32 | 
33 |     if (mean < a - 2 * std) or (mean > b + 2 * std):
34 |         warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
35 |                       "The distribution of values may be incorrect.",
36 |                       stacklevel=2)
37 | 
38 |     with torch.no_grad():
39 |         # Values are generated by using a truncated uniform distribution and
40 |         # then using the inverse CDF for the normal distribution.
41 |         # Get upper and lower cdf values
42 |         l = norm_cdf((a - mean) / std)
43 |         u = norm_cdf((b - mean) / std)
44 | 
45 |         # Uniformly fill tensor with values from [l, u], then translate to
46 |         # [2l-1, 2u-1].
47 | tensor.uniform_(2 * l - 1, 2 * u - 1) 48 | 49 | # Use inverse cdf transform for normal distribution to get truncated 50 | # standard normal 51 | tensor.erfinv_() 52 | 53 | # Transform to proper mean, std 54 | tensor.mul_(std * math.sqrt(2.)) 55 | tensor.add_(mean) 56 | 57 | # Clamp to ensure it's in the proper range 58 | tensor.clamp_(min=a, max=b) 59 | return tensor 60 | 61 | 62 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 63 | # type: (Tensor, float, float, float, float) -> Tensor 64 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 65 | 66 | def test_collate(input): 67 | return input[0] 68 | 69 | def train_collate(batch): 70 | return batch 71 | 72 | def read_json(fpath): 73 | with open(fpath, 'r') as f: 74 | obj = json.load(f) 75 | return obj 76 | 77 | def add_plot(logger, name, step): 78 | buf = io.BytesIO() 79 | plt.savefig(buf, format='jpeg') 80 | buf.seek(0) 81 | image = PIL.Image.open(buf) 82 | image = ToTensor()(image) 83 | logger.log({name:wandb.Image(image)}) 84 | plt.clf() 85 | plt.close() 86 | -------------------------------------------------------------------------------- /vsum_tools.py: -------------------------------------------------------------------------------- 1 | '''' 2 | Courtesy of KaiyangZhou 3 | https://github.com/KaiyangZhou/pytorch-vsumm-reinforce 4 | @article{zhou2017reinforcevsumm, 5 | title={Deep Reinforcement Learning for Unsupervised Video Summarization with Diversity-Representativeness Reward}, 6 | author={Zhou, Kaiyang and Qiao, Yu and Xiang, Tao}, 7 | journal={arXiv:1801.00054}, 8 | year={2017} 9 | } 10 | Modifications by Jiri Fajtl 11 | - knapsack replaced with knapsack_ortools 12 | - added evaluate_user_summaries() for user summaries ground truth evaluation 13 | ''' 14 | 15 | import numpy as np 16 | from knapsack import knapsack_dp 17 | from knapsack import knapsack_ortools 18 | import math 19 | 20 | 21 | def generate_summary(ypred, cps, n_frames, nfps, positions, proportion=0.15, method='knapsack'): 22 | """Generate keyshot-based video summary i.e. a binary vector. 23 | Args: 24 | --------------------------------------------- 25 | - ypred: predicted importance scores. 26 | - cps: change points, 2D matrix, each row contains a segment. 27 | - n_frames: original number of frames. 28 | - nfps: number of frames per segment. 29 | - positions: positions of subsampled frames in the original video. 30 | - proportion: length of video summary (compared to original video length). 31 | - method: defines how shots are selected, ['knapsack', 'rank']. 
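    - returns: a binary keyshot vector of length n_frames (1 marks frames inside selected shots).
      With method='knapsack', shots are knapsack items whose weights are the shot lengths (nfps) and
      whose values are the mean predicted frame scores; the knapsack capacity is floor(n_frames * proportion).
      With method='rank', the highest-scoring shots are greedily added while they still fit the budget.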
32 | """ 33 | n_segs = cps.shape[0] 34 | frame_scores = np.zeros((n_frames), dtype=np.float32) 35 | if positions.dtype != int: 36 | positions = positions.astype(np.int32) 37 | if positions[-1] != n_frames: 38 | positions = np.concatenate([positions, [n_frames]]) 39 | for i in range(len(positions) - 1): 40 | pos_left, pos_right = positions[i], positions[i+1] 41 | if i == len(ypred): 42 | frame_scores[pos_left:pos_right] = 0 43 | else: 44 | frame_scores[pos_left:pos_right] = ypred[i] 45 | 46 | seg_score = [] 47 | for seg_idx in range(n_segs): 48 | start, end = int(cps[seg_idx,0]), int(cps[seg_idx,1]+1) 49 | scores = frame_scores[start:end] 50 | seg_score.append(float(scores.mean())) 51 | 52 | limits = int(math.floor(n_frames * proportion)) 53 | 54 | if method == 'knapsack': 55 | # picks = knapsack_dp(seg_score, nfps, n_segs, limits) 56 | picks = knapsack_ortools(seg_score, nfps, n_segs, limits) 57 | elif method == 'rank': 58 | order = np.argsort(seg_score)[::-1].tolist() 59 | picks = [] 60 | total_len = 0 61 | for i in order: 62 | if total_len + nfps[i] < limits: 63 | picks.append(i) 64 | total_len += nfps[i] 65 | else: 66 | raise KeyError("Unknown method {}".format(method)) 67 | 68 | summary = np.zeros((1), dtype=np.float32) # this element should be deleted 69 | for seg_idx in range(n_segs): 70 | nf = nfps[seg_idx] 71 | if seg_idx in picks: 72 | tmp = np.ones((nf), dtype=np.float32) 73 | else: 74 | tmp = np.zeros((nf), dtype=np.float32) 75 | summary = np.concatenate((summary, tmp)) 76 | 77 | summary = np.delete(summary, 0) # delete the first element 78 | return summary 79 | 80 | 81 | def evaluate_summary(machine_summary, user_summary, eval_metric='avg'): 82 | """Compare machine summary with user summary (keyshot-based). 83 | Args: 84 | -------------------------------- 85 | machine_summary and user_summary should be binary vectors of ndarray type. 86 | eval_metric = {'avg', 'max'} 87 | 'avg' averages results of comparing multiple human summaries. 88 | 'max' takes the maximum (best) out of multiple comparisons. 89 | """ 90 | machine_summary = machine_summary.astype(np.float32) 91 | user_summary = user_summary.astype(np.float32) 92 | n_users,n_frames = user_summary.shape 93 | 94 | # binarization 95 | machine_summary[machine_summary > 0] = 1 96 | user_summary[user_summary > 0] = 1 97 | 98 | if len(machine_summary) > n_frames: 99 | machine_summary = machine_summary[:n_frames] 100 | elif len(machine_summary) < n_frames: 101 | zero_padding = np.zeros((n_frames - len(machine_summary))) 102 | machine_summary = np.concatenate([machine_summary, zero_padding]) 103 | 104 | f_scores = [] 105 | prec_arr = [] 106 | rec_arr = [] 107 | 108 | for user_idx in range(n_users): 109 | gt_summary = user_summary[user_idx,:] 110 | overlap_duration = (machine_summary * gt_summary).sum() 111 | precision = overlap_duration / (machine_summary.sum() + 1e-8) 112 | recall = overlap_duration / (gt_summary.sum() + 1e-8) 113 | if precision == 0 and recall == 0: 114 | f_score = 0. 
115 |         else:
116 |             f_score = (2 * precision * recall) / (precision + recall)
117 |         f_scores.append(f_score)
118 |         prec_arr.append(precision)
119 |         rec_arr.append(recall)
120 | 
121 |     if eval_metric == 'avg':
122 |         final_f_score = np.mean(f_scores)
123 |         final_prec = np.mean(prec_arr)
124 |         final_rec = np.mean(rec_arr)
125 |     elif eval_metric == 'max':
126 |         final_f_score = np.max(f_scores)
127 |         max_idx = np.argmax(f_scores)
128 |         final_prec = prec_arr[max_idx]
129 |         final_rec = rec_arr[max_idx]
130 | 
131 |     return final_f_score, final_prec, final_rec, machine_summary, user_summary
132 | 
133 | 
134 | def evaluate_user_summaries(user_summary, eval_metric='avg'):
135 |     """Compare user summaries with each other, pairwise (keyshot-based).
136 |     Args:
137 |     --------------------------------
138 |     user_summary should be a binary ndarray of shape (n_users, n_frames).
139 |     eval_metric = {'avg', 'max'}
140 |     'avg' averages the results of all pairwise comparisons between human summaries.
141 |     'max' takes the maximum (best) out of the pairwise comparisons.
142 |     """
143 |     user_summary = user_summary.astype(np.float32)
144 |     n_users, n_frames = user_summary.shape
145 | 
146 |     # binarization
147 |     user_summary[user_summary > 0] = 1
148 | 
149 |     f_scores = []
150 |     prec_arr = []
151 |     rec_arr = []
152 | 
153 |     for user_idx in range(n_users):
154 |         gt_summary = user_summary[user_idx, :]
155 |         for other_user_idx in range(user_idx+1, n_users):
156 |             other_gt_summary = user_summary[other_user_idx, :]
157 |             overlap_duration = (other_gt_summary * gt_summary).sum()
158 |             precision = overlap_duration / (other_gt_summary.sum() + 1e-8)
159 |             recall = overlap_duration / (gt_summary.sum() + 1e-8)
160 |             if precision == 0 and recall == 0:
161 |                 f_score = 0.
162 |             else:
163 |                 f_score = (2 * precision * recall) / (precision + recall)
164 |             f_scores.append(f_score)
165 |             prec_arr.append(precision)
166 |             rec_arr.append(recall)
167 | 
168 | 
169 |     if eval_metric == 'avg':
170 |         final_f_score = np.mean(f_scores)
171 |         final_prec = np.mean(prec_arr)
172 |         final_rec = np.mean(rec_arr)
173 |     elif eval_metric == 'max':
174 |         final_f_score = np.max(f_scores)
175 |         max_idx = np.argmax(f_scores)
176 |         final_prec = prec_arr[max_idx]
177 |         final_rec = rec_arr[max_idx]
178 | 
179 |     return final_f_score, final_prec, final_rec
--------------------------------------------------------------------------------
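A minimal usage sketch for the two evaluation helpers above (not part of the repository): every input below (`pred_scores`, `cps`, `nfps`, `positions`, `user_summary`) is a hypothetical toy value, and running it assumes the repo root is on `PYTHONPATH` and that the `ortools` backend used by `knapsack_ortools` is installed.

```python
import numpy as np

from vsum_tools import generate_summary, evaluate_summary

# Toy setup: a 300-frame video segmented into 10 shots of 30 frames each,
# with importance scores predicted for every 15th frame.
n_frames = 300
cps = np.array([[i * 30, i * 30 + 29] for i in range(10)])  # shot boundaries (inclusive)
nfps = np.array([30] * 10)                                   # frames per shot
positions = np.arange(0, n_frames, 15)                       # subsampled frame indices
rng = np.random.default_rng(0)
pred_scores = rng.random(len(positions))                     # stand-in for model scores

# Keyshot summary under the default 15% length budget (knapsack shot selection).
machine_summary = generate_summary(pred_scores, cps, n_frames, nfps, positions)

# Two hypothetical annotators, each marking one shot as their summary.
user_summary = np.zeros((2, n_frames))
user_summary[0, 0:30] = 1
user_summary[1, 60:90] = 1

f1, prec, rec, _, _ = evaluate_summary(machine_summary, user_summary, eval_metric='avg')
print(f"F1 = {f1:.4f}, precision = {prec:.4f}, recall = {rec:.4f}")
```

Passing `eval_metric='avg'` averages the scores over all annotators, while `'max'` keeps only the best-matching annotator, as described in the docstring of `evaluate_summary`.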