├── code ├── .gitignore ├── exp.py ├── model │ ├── decorator.py │ ├── __init__.py │ ├── config.py │ ├── pretrained.py │ ├── attention.py │ ├── pooler.py │ ├── input.py │ └── lavit.py ├── data │ ├── merge.py │ ├── tokenizer.py │ ├── dataloader.py │ ├── load.py │ ├── README.md │ ├── qa.py │ ├── video.py │ └── dataset.py ├── utils.py ├── cli.py ├── optimizer │ ├── schedulers.py │ ├── bert_adam.py │ └── __init__.py ├── ckpt.py ├── metrics │ └── logger.py ├── common.py ├── evaluate.py ├── config.py ├── args.py └── train.py ├── data └── README.md ├── assets └── data.png ├── LICENSE ├── .gitignore ├── README.md └── environment.yml /code/.gitignore: -------------------------------------------------------------------------------- 1 | demo 2 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # README 2 | -------------------------------------------------------------------------------- /assets/data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HS-YN/PanoAVQA/HEAD/assets/data.png -------------------------------------------------------------------------------- /code/exp.py: -------------------------------------------------------------------------------- 1 | import sacred 2 | 3 | exp_name = 'pano_avqa' 4 | ex = sacred.Experiment(exp_name) 5 | 6 | use_mongodb = False -------------------------------------------------------------------------------- /code/model/decorator.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def full_model(target): 4 | target.is_full_model = True 5 | target.run_pretrain = True 6 | 7 | return target -------------------------------------------------------------------------------- /code/data/merge.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | 3 | import numpy as np 4 | 5 | from exp import ex 6 | 7 | 8 | def merge_data(qas, tokenizer): 9 | # Not needed 10 | return qas -------------------------------------------------------------------------------- /code/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def merge_dict(list_of_dict): 6 | res = {} 7 | if len(list_of_dict) > 0: 8 | keys = list_of_dict[0].keys() 9 | for key in keys: 10 | res[key] = [d[key] for d in list_of_dict] 11 | return res 12 | 13 | 14 | def one_hot_vectorize(target, element_list): 15 | vector = [1 if ele == target else 0 for ele in element_list] 16 | 17 | if sum(vector) == 0: ## target is not in the list 18 | vector += [1] 19 | else: 20 | vector += [0] 21 | 22 | return np.array(vector) 23 | -------------------------------------------------------------------------------- /code/cli.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from exp import ex 4 | from args import get_args 5 | 6 | from train import _train 7 | from evaluate import _eval 8 | 9 | @ex.command 10 | def train(_config): 11 | res = _train() 12 | print("Training complete") 13 | 14 | return 0 15 | 16 | 17 | @ex.command 18 | def eval(_config): 19 | res = _eval() 20 | print("Evaluation complete") 21 | 22 | return 0 23 | 24 | 25 | @ex.option_hook 26 | def update_args(options): 27 | args = get_args(options) 28 | print(json.dumps({k: str(v) for k, v in sorted(args.items())}, indent=4)) 
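# option_hook runs before any sacred command; pushing the resolved args into the
# experiment config below is what lets @ex.capture'd functions across the codebase
# receive them as named parameters.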
29 | ex.add_config(args) 30 | return options 31 | 32 | 33 | @ex.automain 34 | def run(): 35 | train() -------------------------------------------------------------------------------- /code/optimizer/schedulers.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import LambdaLR 2 | from transformers import get_linear_schedule_with_warmup 3 | 4 | from exp import ex 5 | 6 | 7 | def get_no_scheduler(optimizer, num_warmup_steps, num_training_steps): 8 | def lr_lambda(current_step): 9 | return 1 10 | 11 | return LambdaLR(optimizer, lr_lambda) 12 | 13 | 14 | sched_dict = { 15 | 'linear': get_linear_schedule_with_warmup, 16 | 'none': get_no_scheduler 17 | } 18 | 19 | 20 | @ex.capture() 21 | def get_scheduler(optimizer, t_total, warmup, scheduler_name, grad_acc_steps): 22 | warmup_steps = int(t_total * warmup) 23 | scheduler = sched_dict[scheduler_name](optimizer, warmup_steps, t_total) 24 | scheduler.accumulated = 0 25 | scheduler.grad_acc_steps = grad_acc_steps 26 | return scheduler -------------------------------------------------------------------------------- /code/data/tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer 3 | 4 | from exp import ex 5 | 6 | 7 | token_dict = { 8 | 'bert': 'bert-base-uncased' 9 | } 10 | 11 | 12 | @ex.capture() 13 | def get_tokenizer(cache_path, transformer_name, rebuild_cache): 14 | # Provide tokenizer with caching 15 | tokenizer_path = cache_path / 'tokenizer' 16 | tokenizer_path.mkdir(parents=True, exist_ok=True) 17 | tokenizer_file = f"{transformer_name}.pkl" 18 | path = tokenizer_path / tokenizer_file 19 | 20 | if rebuild_cache and path.is_file(): 21 | path.unlink() 22 | if path.is_file(): 23 | tokenizer = torch.load(path) 24 | else: 25 | tokenizer = AutoTokenizer.from_pretrained(token_dict[transformer_name], 26 | do_lower_case=True) 27 | torch.save(tokenizer, path) 28 | return tokenizer -------------------------------------------------------------------------------- /code/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | from munch import Munch 5 | from torch.utils.data import DataLoader 6 | 7 | from exp import ex 8 | from utils import merge_dict 9 | from .dataset import get_dataset 10 | 11 | @ex.capture() 12 | def get_dataloaders(batch_size, grad_acc_steps, num_workers, max_epochs, modes=['train', 'val', 'test']): 13 | dataset, video, tokenizer = get_dataset(modes=modes) 14 | outputs = {} 15 | 16 | for mode, ds in dataset.items(): 17 | dataloader = DataLoader(ds, 18 | batch_size=batch_size, 19 | collate_fn=ds.collate_fn, 20 | shuffle=(mode == 'train' or mode == 'pretrain'), 21 | num_workers=num_workers) 22 | dataloader.dataset.t_total = math.ceil(len(ds) * max_epochs / (batch_size * grad_acc_steps)) 23 | outputs[mode] = dataloader 24 | return outputs, video, tokenizer 25 |
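# Worked example for t_total (illustrative numbers, not from the paper): with the defaults
# in code/config.py (batch_size=32, grad_acc_steps=4, max_epochs=10), a split of 10,000 QA
# pairs gives t_total = ceil(10000 * 10 / (32 * 4)) = 782, i.e. the number of optimizer
# updates that the warmup/decay schedule in optimizer/schedulers.py is stretched over.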
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Heeseung Yun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /code/data/load.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from exp import ex 4 | from .qa import get_qa 5 | from .video import get_video 6 | from .tokenizer import get_tokenizer 7 | 8 | 9 | @ex.capture() 10 | def load(modes, cache_path, transformer_name, rebuild_cache): 11 | if isinstance(modes, str): 12 | modes = [modes] 13 | modes = sorted(list(modes)) 14 | cache_path.mkdir(parents=True, exist_ok=True) 15 | cache_files = [f"{mode}_{transformer_name}.pkl" for mode in modes] 16 | 17 | data = {} 18 | tokenizer = get_tokenizer() 19 | print(f'[LOG] Loading cached QA from {cache_path}...', end='', flush=True) 20 | for mode, cache_file in zip(modes, cache_files): 21 | path = cache_path / cache_file 22 | if rebuild_cache and path.is_file(): 23 | path.unlink() 24 | if path.is_file(): 25 | data[mode] = torch.load(path) 26 | else: 27 | qa = get_qa(tokenizer, data=None, mode=mode) 28 | torch.save(qa, path) 29 | data[mode] = qa 30 | print('Complete!') 31 | video = get_video(mode=mode) 32 | 33 | return data, video, tokenizer -------------------------------------------------------------------------------- /code/optimizer/bert_adam.py: -------------------------------------------------------------------------------- 1 | from torch.optim import Adam as TorchAdam 2 | from transformers import AdamW 3 | 4 | from exp import ex 5 | 6 | 7 | class BertAdam(AdamW): 8 | @ex.capture() 9 | def __init__(self, model, learning_rate, transformer_learning_rate, weight_decay): 10 | options = {} 11 | options['lr'] = learning_rate 12 | options['weight_decay'] = weight_decay 13 | 14 | params = [] 15 | for name, child in model.named_children(): 16 | if name == 'transformer': 17 | lr = transformer_learning_rate if transformer_learning_rate is not None else learning_rate 18 | params.append({'params': child.parameters(), 'lr': lr}) 19 | else: 20 | params.append({'params': child.parameters(), 'lr': learning_rate}) 21 | 22 | super().__init__(params, **options) 23 | 24 | 25 | class Adam(TorchAdam): # registered as 'adam'; follows the optim_dict calling convention (model, learning_rate) 26 | @ex.capture() 27 | def __init__(self, model, learning_rate, weight_decay): 28 | options = { 29 | 'lr': learning_rate, 30 | 'weight_decay': weight_decay 31 | } 32 | super().__init__(model.parameters(), **options) -------------------------------------------------------------------------------- /code/data/README.md: -------------------------------------------------------------------------------- 1 | # Dataset 2 | 3 | ## Annotation Format 4 | 5 | We follow the JSON schema of VQA 2.0, while adding some annotations relevant to our domain. 6 | 7 | * "info": {"year", "version", "description", "contributor", "url", "date_created"} 8 | * "license": {"name", "url"} 9 | * "data_type": str(train, val or test) 10 | * "questions": dict 11 | * "question_id": {"question", "video_id", "question_type"} 12 | * "question_type": str(a or s) 13 | * "annotations": dict 14 | * "question_id": {"video_id", "answer", "ground_l", "ground_c"} 15 | * "ground_l": [keyword1(, keyword2)] 16 | * "ground_c": [coordinate1(, coordinate2)] 17 | * coordinate: [x1, y1, x2, y2, pov] 18 | * "videos": dict 19 | * "video_id": {"begin", "end", "width", "height", "keyword"} 20 | * "answers": list 21 | 22 | 23 | ## Data structure 24 | 25 | All features should be stored under the `/data/feat` directory. 26 | 27 | * `/data/feat/video/(feature_type)/(video_id).pkl` 28 | * `data.keys()` would return [feat, cls, (four different coordinates)] 29 | * `/data/feat/audio/(feature_type)/(video_id).pkl` 30 | * `data.keys()` would return [feat, cls, harmonics, (time)]
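A minimal sketch of reading these files (paths follow the defaults in `code/config.py` and the directory names used by `code/data/video.py`; the feature type and video id below are only placeholders):

```python
import json
import pickle as pkl

# Q&A annotations, following the schema above.
labels = json.load(open('data/feat/label/trainval.json', 'r'))

qid = next(iter(labels['questions']))                # any question id
question = labels['questions'][qid]                  # {"question", "video_id", "question_type"}
answer = labels['annotations'][qid]['answer']
video_meta = labels['videos'][question['video_id']]  # {"begin", "end", "width", "height", "keyword"}

# Per-video visual feature; 'rcnn_all' is one of the feature types listed in code/config.py.
feat = pkl.load(open(f"data/feat/visual/rcnn_all/{question['video_id']}.pkl", 'rb'))
print(feat.keys())   # e.g. feat, cls and the per-geometry coordinates
```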
-------------------------------------------------------------------------------- /code/ckpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from munch import Munch 3 | 4 | from exp import ex 5 | 6 | 7 | @ex.capture 8 | def save_ckpt(epoch, loss, model, log_path, config_dir, _config): 9 | print(f'[LOG] Saving epoch {epoch:02d}') 10 | ckpt = { 11 | 'args': _config, 12 | 'epoch': epoch, 13 | 'loss': loss, 14 | 'model': model.state_dict() 15 | } 16 | 17 | ckpt_path = log_path / config_dir 18 | ckpt_path.mkdir(parents=True, exist_ok=True) 19 | torch.save(ckpt, ckpt_path / f"{epoch:02d}_ckpt.pth") 20 | 21 | 22 | @ex.capture 23 | def load_ckpt(model, ckpt_name, ckpt_path, model_config, cache_path): 24 | if ckpt_name is not None: 25 | name = f'{ckpt_name}*' if not ckpt_name.endswith('*') else f'{ckpt_name}' 26 | ckpt_path = sorted(ckpt_path.glob(name), reverse=False) 27 | assert len(ckpt_path) > 0, \ 28 | "[ERROR] No checkpoint candidate for {}.".format(ckpt_name) 29 | ckpt_path = ckpt_path[0] 30 | print(f'[LOG] Loading checkpoint {ckpt_path}') 31 | data = torch.load(ckpt_path) 32 | if isinstance(data, dict) and 'model' in data: 33 | model.load_state_dict(data['model']) 34 | else: 35 | model.load_state_dict(data) 36 | return model -------------------------------------------------------------------------------- /code/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from torch import optim 5 | from inflection import underscore 6 | 7 | from exp import ex 8 | from .schedulers import get_scheduler 9 | 10 | 11 | optim_dict = {} 12 | 13 | 14 | def add_optims(): 15 | path = Path(os.path.dirname(__file__)) 16 | 17 | for p in path.glob('*.py'): 18 | name = p.stem 19 | parent = p.parent.stem 20 | 21 | if name != "__init__": 22 | __import__("{}.{}".format(parent, name)) 23 | module = eval(name) 24 | for member in dir(module): 25 | member = getattr(module, member) 26 | if hasattr(member, '__bases__') and \ 27 | (optim.Optimizer in member.__bases__ or \ 28 | optim.lr_scheduler._LRScheduler in member.__bases__ or \ 29 | optim.Optimizer in member.__bases__[0].__bases__ or \ 30 | optim.lr_scheduler._LRScheduler in member.__bases__[0].__bases__): 31 | optim_dict[underscore(str(member.__name__))] = member 32 | 33 | 34 | @ex.capture() 35 | def get_optimizer(model, t_total, optimizer_name, learning_rate): 36 | optim = optim_dict[optimizer_name](model, learning_rate) 37 | optim.zero_grad() 38 | scheduler = get_scheduler(optim, t_total) 39 | return optim, scheduler 40 | 41 | 42 | add_optims()
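# Registry keys are the snake_case class names produced by inflection.underscore,
# e.g. underscore('BertAdam') == 'bert_adam' and underscore('Adam') == 'adam', which is
# what 'optimizer_name' in code/config.py selects (default: 'bert_adam').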
-------------------------------------------------------------------------------- /code/metrics/logger.py: -------------------------------------------------------------------------------- 1 | # Simple tb logger 2 | import torch 3 | 4 | from exp import ex 5 | 6 | # Normalizing constants for each coordinate parameterization (value ranges noted inline) 7 | geometry_normalizer = { 8 | 'cartesian': 4, # [0,1]x[0,1]x[0,1]x[0,1] 9 | 'angular': 98.696, # [-pi,pi]x[-.5pi,.5pi]x[0,2pi]x[0,pi] 10 | 'spherical': 61.348, # [-1,1]x[-1,1]x[-1,1]x[0,2pi]x[0,pi] 11 | 'quaternion': 17 # [0,1]x[-1,1]x[-1,1]x[0,2]x[0,2] 12 | } 13 | 14 | 15 | def write_logs(logger, timestamp, lr, stat, meta, mode="train"): 16 | if mode == "train": 17 | logger.add_scalar('Train/lr', lr, timestamp) 18 | 19 | for k, v in stat.items(): 20 | if type(v) == torch.Tensor and v.dim() == 0: 21 | logger.add_scalar(f'Train/{k}', v.item(), timestamp) 22 | elif type(v) == str: 23 | logger.add_text(f'Train/{k}', v, timestamp) 24 | 25 | else: 26 | for k, v in stat.items(): 27 | if type(v) in [int, float]: 28 | logger.add_scalar(f'{mode.capitalize()}/{k}', v, timestamp) 29 | elif type(v) == torch.Tensor and v.dim() == 0: 30 | logger.add_scalar(f'{mode.capitalize()}/{k}', v.item(), timestamp) 31 | elif type(v) == str: 32 | logger.add_text(f'{mode.capitalize()}/{k}', v, timestamp) 33 | #logger.add_image('Eval/image', img, timestamp) 34 | 35 | 36 | @ex.capture() 37 | def adjust_grounding_error(error, geometry): 38 | return error * geometry_normalizer[geometry] -------------------------------------------------------------------------------- /code/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from torch import nn 5 | from munch import Munch 6 | from inspect import getmro 7 | from inflection import underscore 8 | 9 | from exp import ex 10 | 11 | 12 | model_dict = {} 13 | 14 | 15 | def add_models(): 16 | path = Path(os.path.dirname(__file__)) 17 | 18 | for p in path.glob('*.py'): 19 | name = p.stem 20 | parent = p.parent.stem 21 | if name != '__init__': 22 | __import__(f"{parent}.{name}") 23 | module = eval(name) 24 | for member in dir(module): 25 | member = getattr(module, member) 26 | if hasattr(member, '__mro__') and \ 27 | nn.Module in getmro(member) and \ 28 | hasattr(member, 'is_full_model'): 29 | model_dict[underscore(str(member.__name__))] = member 30 | 31 | 32 | def get_model_class(model): 33 | if not model_dict: 34 | add_models() 35 | 36 | assert model in model_dict.keys(), "[ERROR] Provided model \'{}\' does not exist.
Possible candidates are: \n{}".format(model, str(model_dict.keys())) 37 | model = model_dict[model] 38 | return model 39 | 40 | 41 | @ex.capture() 42 | def get_model(tokenizer, model_name, model_config, cache_path): 43 | print("[LOG] Using model {}".format(model_name)) 44 | model = get_model_class(model_name) 45 | return model(Munch(model_config), cache_path) -------------------------------------------------------------------------------- /code/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | 5 | from exp import ex 6 | from ckpt import load_ckpt 7 | from model import get_model 8 | from data.dataloader import get_dataloaders 9 | 10 | 11 | @ex.capture() 12 | def prepare_batch(batch, device): 13 | 14 | data, label, meta = batch #zip(*batch) 15 | 16 | for key, value in data.items(): 17 | if isinstance(value, list): 18 | data[key] = [convert(v, device) for v in value] 19 | elif isinstance(value, dict): 20 | data[key] = {k: convert(v, device) for k, v in value.items()} 21 | else: 22 | data[key] = convert(value, device) 23 | 24 | for key, value in label.items(): 25 | if isinstance(value, list): 26 | label[key] = [convert(v, device) for v in value] 27 | elif isinstance(value, dict): 28 | label[key] = {k: convert(v, device) for k, v in value.items()} 29 | else: 30 | label[key] = convert(value, device) 31 | 32 | return data, label, meta 33 | 34 | 35 | def convert(value, device): 36 | if isinstance(value, np.ndarray): 37 | value = torch.from_numpy(value) 38 | if torch.is_tensor(value): 39 | value = value.to(device) 40 | return value 41 | 42 | 43 | @ex.capture() 44 | def get_all(data_modes, device): 45 | dataloaders, video, tokenizer = get_dataloaders(modes=data_modes) 46 | model = get_model(tokenizer) 47 | model = load_ckpt(model).to(device) 48 | criterion = get_criterion() 49 | 50 | return dataloaders, video, tokenizer, model, criterion 51 | 52 | 53 | def get_criterion(): 54 | criterion = nn.CrossEntropyLoss() 55 | return criterion 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pano-AVQA 2 | 3 | Official repository of PanoAVQA: Grounded Audio-Visual Question Answering in 360° Videos (ICCV 2021) 4 | 5 | ![Data_fig](https://hs-yn.github.io/assets/img/panoavqa_data.png) 6 | 7 | ### [[Paper]](https://openaccess.thecvf.com/content/ICCV2021/html/Yun_Pano-AVQA_Grounded_Audio-Visual_Question_Answering_on_360deg_Videos_ICCV_2021_paper.html) [[Poster]](https://hs-yn.github.io/assets/pdf/2021iccv_panoavqa_poster.pdf) [Video] 8 | 9 | 10 | ## Getting Started 11 | 12 | This code is based on following libraries: 13 | 14 | * `python=3.8` 15 | * `pytorch=1.7.0` (with cuda 10.2) 16 | 17 | To create virtual environment with all necessary libraries: 18 | 19 | ```bash 20 | conda env create -f environment.yml 21 | ``` 22 | 23 | By default data should be saved under `data/feat/{audio,label,visual}` directory and logs (w/ cache, checkpoint) are saved under `data/{cache,ckpt,log}` directory. Using symbolic link is recommended: 24 | 25 | ```bash 26 | ln -s {path_to_your_data_directory} data 27 | ``` 28 | 29 | We use single TITAN RTX for training, but GPUs with less memory are still doable with smaller batch size (provided precomputed features). 30 | 31 | 32 | ## Dataset 33 | 34 | We plan to release the Pano-AVQA dataset public within this year, including Q&A annotation, precomputed features, etc. Please stay tuned! 35 | 36 | 37 | ## Model 38 | 39 | ### Training 40 | 41 | Default configuration is provided in `code/config.py`. 
To run with this configuration: 42 | 43 | ```bash 44 | python cli.py 45 | ``` 46 | 47 | To run with custom configuration, either modify `code/config.py` or execute: 48 | 49 | ```bash 50 | python cli.py with {{flags_at_your_disposal}} 51 | ``` 52 | 53 | ### Inference 54 | 55 | Model weight is saved under `./data/log` directory. To run inference only: 56 | 57 | ```bash 58 | python cli.py eval with ckpt_file=../data/log/{experiment}/{ckpt}.pth 59 | ``` 60 | 61 | 62 | ## Citation 63 | 64 | If you find our work useful in your research, please consider citing: 65 | 66 | ```tex 67 | @InProceedings{Yun2021PanoAVQA, 68 | author = {Yun, Heeseung and Yu, Youngjae and Yang, Wonsuk and Lee, Kangil and Kim, Gunhee}, 69 | title = {Pano-AVQA: Grounded Audio-Visual Question Answering on 360$^\circ$ Videos}, 70 | booktitle = {ICCV}, 71 | year = {2021} 72 | } 73 | ``` 74 | 75 | 76 | ## Contact 77 | 78 | If you have any inquiries, please don't hesitate to contact us via heeseung.yun at vision.snu.ac.kr. 79 | -------------------------------------------------------------------------------- /code/evaluate.py: -------------------------------------------------------------------------------- 1 | import json 2 | from functools import partial 3 | 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from exp import ex 9 | from ckpt import load_ckpt 10 | from metrics.logger import write_logs 11 | from optimizer import get_optimizer 12 | from metrics.logger import write_logs 13 | from common import prepare_batch, get_all 14 | 15 | 16 | grounding_error = torch.nn.MSELoss(reduction='none') 17 | 18 | 19 | def get_accs(gt, prop, qtype): 20 | retval = {} 21 | correct = (gt == prop).float() 22 | retval['acc_total'] = [k.item() for k in correct] 23 | is_av = torch.Tensor(np.array(qtype) == 'a').float() 24 | retval['acc_av'] = [k.item() for i, k in enumerate(correct) if is_av[i] == 1] 25 | retval['acc_sp'] = [k.item() for i, k in enumerate(correct) if is_av[i] == 0] 26 | 27 | return retval 28 | 29 | 30 | def get_errors(gt, prop, qtype): 31 | retval = {} 32 | errors = grounding_error(prop, gt).sum(1) 33 | retval['mse_total'] = [k.item() for k in errors] 34 | is_av = torch.Tensor(np.array(qtype) == 'a').float() 35 | retval['mse_av'] = [k.item() for i, k in enumerate(errors) if is_av[i] == 1] 36 | retval['mse_sp'] = [k.item() for i, k in enumerate(errors) if is_av[i] == 0] 37 | 38 | return retval 39 | 40 | 41 | 42 | @ex.capture() 43 | def _eval(log_path, ckpt_path, config_dir, max_epochs, pretrain_epochs, answer_path, learning_rate, ckpt_file, 44 | pretrain_learning_rate, pretrain_types, split_train, model_config, _config): 45 | dataloaders, _, tokenizer, model, criterion = get_all(data_modes=['test']) 46 | 47 | answer_dict = json.load(open(answer_path, 'r')) 48 | 49 | # print(model) 50 | # PRETRAIN 51 | model.load_state_dict(torch.load(ckpt_file)['model']) 52 | model.eval() 53 | qas = {} 54 | 55 | for _batch in tqdm(dataloaders['test'], total=len(dataloaders['test']), desc="Test"): 56 | batch, label, meta = prepare_batch(_batch) 57 | 58 | with torch.no_grad(): 59 | stats = model(batch, label, ['qa', 'ground']) 60 | label['qa'] = label['qa'].cpu() 61 | label['ground'] = label['ground'].cpu() 62 | if 'ground_pred' in stats.keys(): 63 | for i in range(len(label['ground'])): 64 | qas[meta['question_id'][i]] = { 65 | "video_id": meta["video_id"][i], 66 | "question": meta["question"][i], 67 | "ans_gt": answer_dict[label['qa'][i].item()], 68 | "ans_pr": answer_dict[stats['answer_pred'][i].item()], 69 | "grnd_gt": 
label['ground'][i].numpy().tolist(), 70 | "grnd_pr": stats['ground_pred'][i].numpy().tolist(), 71 | "grnd_err": grounding_error(label['ground'][i], stats['ground_pred'][i]).sum().item() 72 | } 73 | else: 74 | assert False, "No grounding available" 75 | 76 | json.dump(qas, open('./{}.json'.format(ckpt_file.split('/')[-2]),'w'), indent=2) 77 | -------------------------------------------------------------------------------- /code/model/config.py: -------------------------------------------------------------------------------- 1 | 2 | from transformers.configuration_utils import PretrainedConfig 3 | 4 | from exp import ex 5 | 6 | 7 | ''' 8 | Refrain from directly modifying this configuration! 9 | Please use /code/config.py for such purpose 10 | ''' 11 | class ModelConfig(PretrainedConfig): 12 | def __init__(self, **kwargs): 13 | super().__init__() 14 | self.__dict__.update(kwargs) 15 | 16 | self.pretrain_task_list = ['mask_lm', 'visual_feat', 'visual_coord', 'visual_label', 17 | 'audio_feat', 'audio_coord', 'audio_label', 'vl_match', 18 | 'al_match', 'qa', 'ground', 19 | 'visual_vilbert', 'audio_vilbert'] 20 | 21 | self.pretrain_types = [] 22 | for task in self.pretrain_task_list: 23 | if hasattr(self, task) and getattr(self, task): 24 | self.pretrain_types.append(task) 25 | 26 | self.finetune_types = ['qa'] 27 | if hasattr(self, 'use_grounding') and getattr(self, 'use_grounding'): 28 | self.finetune_types.append('grounding') 29 | 30 | self.audio_encoder = 'stereo' if self.use_stereo_audio else 'mono' 31 | if self.audio_coord_dim == 3: 32 | self.audio_encoder += '_st' 33 | elif self.audio_coord_dim == 2: 34 | self.audio_encoder += '_t' 35 | elif self.audio_coord_dim == 1: 36 | self.audio_encoder += '_s' 37 | 38 | if self.geometry in ['quaternion', 'spherical']: 39 | self.visual_coord_dim = 6 40 | elif self.geometry in ['angular', 'cartesian']: 41 | self.visual_coord_dim = 5 42 | else: 43 | self.visual_coord_dim = 0 44 | 45 | 46 | self.pretrain_loss_config = { 47 | 'mask_lm': ((-1, self.vocab_size), 'ce', (-1,), 1), 48 | 'visual_feat': ((-1, self.visual_feat_dim), 'l2', (-1, self.visual_feat_dim), self.loss_normalizer), 49 | 'visual_coord': ((-1, self.visual_coord_dim), 'l2', (-1, self.visual_coord_dim), self.loss_normalizer), 50 | 'visual_label': ((-1, self.visual_label_dim), 'ce_no_reduction', (-1,), self.loss_normalizer), 51 | 'audio_feat': ((-1, 2, self.audio_feat_dim) if self.use_stereo_audio else (-1, self.audio_feat_dim), 52 | 'l2', 53 | (-1, 2, self.audio_feat_dim) if self.use_stereo_audio else (-1, self.audio_feat_dim), 54 | self.loss_normalizer), 55 | 'audio_harmonics': ((-1, 1), 'l2', (-1, 1), self.loss_normalizer), 56 | 'audio_label': ((-1, self.audio_label_dim), 'ce_no_reduction', (-1,), self.loss_normalizer), 57 | 'audio_coord': ((-1, self.audio_coord_dim), 'l2', (-1, self.audio_coord_dim), self.loss_normalizer), 58 | 'vl_match': ((-1, 2), 'ce', (-1,), 1), 59 | 'al_match': ((-1, 2), 'ce', (-1,), 1), 60 | 'qa': ((-1, self.num_answers), 'ce_no_reduction', (-1,), 1), 61 | 'ground': ((-1, max(0, self.visual_coord_dim-1)), 'l2_reduction', (-1, max(0, self.visual_coord_dim-1)), self.lambda_ground), 62 | 'visual_vilbert': ((-1, self.visual_label_dim), 'kl', (-1,), self.loss_normalizer), 63 | 'audio_vilbert': ((-1, self.audio_label_dim), 'kl', (-1,), self.loss_normalizer) 64 | } 65 | -------------------------------------------------------------------------------- /code/data/qa.py: -------------------------------------------------------------------------------- 1 | import json 2 | import 
numpy as np 3 | from tqdm import tqdm 4 | 5 | from exp import ex 6 | from misc.geometry import conversion 7 | 8 | ''' 9 | Data structure 10 | 11 | - annotation (json) 12 | - info: info(year,version,description,contributor,url,date_created) 13 | - license: license(name,url) 14 | - data_type: str(train,val,test) 15 | - questions: {'question_id': {'question', 'video_id','question_type'}} 16 | - annotations: {'question_id': {'video_id','answer','ground_l{1,2}','ground_c{1,2}'} 17 | - videos: {'video_id':{'begin','end','width','height','keyword'} 18 | ''' 19 | 20 | 21 | @ex.capture() 22 | def get_qa(tokenizer, data_path, max_length_qa, data=None, mode='train'): 23 | if data is None: 24 | data = json.load(open(data_path[mode], 'r')) 25 | outputs = {} 26 | 27 | minval = [10000 for _ in range(5)] 28 | maxval = [-10000 for _ in range(5)] 29 | 30 | for qid in tqdm(data['questions'].keys(), desc=f"Load {mode} split"): 31 | question = data['questions'][qid] 32 | annotation = data['annotations'][qid] 33 | #video = data['videos'][question['video_id']] 34 | 35 | encoded_question = tokenizer.encode(question['question'], 36 | truncation='longest_first', 37 | max_length=max_length_qa) 38 | 39 | xtl, ytl, xbr, ybr, pov = annotation['ground_c'][0] 40 | if type(pov) != int: 41 | # For some reason, few questions contain invalid grounding pov info 42 | # As a quickfix, we enforce pov to be 18 (ER) in such cases 43 | pov = 18 44 | geo_out = {} 45 | for geo in ['cartesian', 'angular', 'spherical', 'quaternion']: 46 | grounding = conversion(xtl, ytl, xbr-xtl, ybr-ytl, pov, geo) 47 | 48 | if geo == "cartesian": 49 | g = grounding 50 | grounding = [g[0]/1920., 51 | g[1]/1080., 52 | g[2]/1920., 53 | g[3]/1080.] 54 | # Tackling discontinuity 55 | if grounding[0] > 1: 56 | grounding[0] -= 1 57 | elif grounding[0] < 0: 58 | grounding[0] += 1 59 | else: 60 | #debug = annotation['ground_c'][0] 61 | grounding = list(grounding) 62 | 63 | if geo == "angular": 64 | if grounding[0] > np.pi: 65 | grounding[0] -= 2 * np.pi 66 | elif grounding[0] < -np.pi: 67 | grounding[0] += 2 * np.pi 68 | #if grounding[0] > np.pi: 69 | # print(debug, [f"{x:.4f}" for x in grounding]) 70 | geo_out[geo] = [2.] 
+ grounding 71 | 72 | ''' Checking representation bound 73 | for i, x in enumerate(grounding): 74 | if x < minval[i]: 75 | minval[i] = x 76 | if x > maxval[i]: 77 | maxval[i] = x 78 | ''' # timestamp is fixed at 2 79 | output = { 80 | 'question_id': qid, 81 | 'question_type': question['question_type'], 82 | 'video_id': question['video_id'], 83 | 'question': encoded_question, 84 | 'answer': annotation['answer'] 85 | } 86 | output.update(geo_out) 87 | outputs[qid] = output 88 | 89 | ''' Checking representation bound 90 | print(geometry) 91 | print([f"{x:.4f}" for x in minval]) 92 | print([f"{x:.4f}" for x in maxval]) 93 | assert False 94 | ''' 95 | 96 | return outputs 97 | -------------------------------------------------------------------------------- /code/model/pretrained.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import subprocess as sp 3 | from pathlib import Path 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from exp import ex 9 | 10 | 11 | class BertPreTrainedModel(nn.Module): 12 | def __init__(self, model_config, cache_path, *inputs, **kwargs): 13 | super().__init__() 14 | self.config = model_config 15 | self.cache_path = cache_path 16 | 17 | def init_bert_weights(self, module): 18 | if isinstance(module, (nn.Linear, nn.Embedding)): 19 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 20 | elif isinstance(module, nn.LayerNorm): 21 | module.bias.data.zero_() 22 | module.weight.data.fill_(1.0) 23 | if isinstance(module, nn.Linear) and module.bias is not None: 24 | module.bias.data.zero_() 25 | 26 | @classmethod 27 | def from_pretrained(cls, model_config, cache_path, pretrained_model='bert-base-uncased', 28 | state_dict=None, from_tf=False, *inputs, **kwargs): 29 | # Assume bert-base-uncased 30 | assert pretrained_model == 'bert-base-uncased', f"[ERROR] {pretrained_model} is not supported." 31 | 32 | pretrained_model_path = cache_path / pretrained_model 33 | if not pretrained_model_path.is_dir(): 34 | pretrained_model_path.mkdir(parents=True, exist_ok=True) 35 | 36 | pretrained_model = pretrained_model_path / 'pytorch_model.bin' 37 | if not Path(pretrained_model).exists(): 38 | print("[LOG] Downloading BERT pretrained weight. 
It will take somewhere around 10 minutes, depending on your internet connection.") 39 | sp.call(["wget", "-P", f"{pretrained_model_path}", 40 | "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz"]) 41 | sp.call(["tar", "-xvf", "{}".format(pretrained_model_path / 'bert-base-uncased.tar.gz')]) 42 | (pretrained_model_path / 'bert-base-uncased.tar.gz').unlink() 43 | 44 | model = cls(model_config, cache_path, *inputs, **kwargs) 45 | 46 | state_dict = torch.load(pretrained_model, map_location='cpu' if not torch.cuda.is_available() else None) 47 | 48 | old_keys = [] 49 | new_keys = [] 50 | for key in state_dict.keys(): 51 | new_key = None 52 | if "gamma" in key: 53 | new_key = key.replace('gamma', 'weight') 54 | if "beta" in key: 55 | new_key = key.replace('beta', 'bias') 56 | if new_key: 57 | old_keys.append(key) 58 | new_keys.append(new_key) 59 | for old_key, new_key in zip(old_keys, new_keys): 60 | state_dict[new_key] = state_dict.pop(old_key) 61 | 62 | missing_keys = [] 63 | unexpected_keys = [] 64 | error_msgs = [] 65 | metadata = getattr(state_dict, '_metadata', None) 66 | state_dict = state_dict.copy() 67 | if metadata is not None: 68 | state_dict._metadata = metadata 69 | 70 | def load(module, prefix=''): 71 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 72 | module._load_from_state_dict(state_dict, prefix, local_metadata, True, 73 | missing_keys, unexpected_keys, error_msgs) 74 | for name, child in module._modules.items(): 75 | if child is not None: 76 | load(child, prefix + name + '.') 77 | start_prefix = '' 78 | if not hasattr(model, 'bert') and any(s.startwith('bert.') for s in state_dict.keys()): 79 | start_prefix = 'bert.' 80 | load(model, prefix=start_prefix) 81 | 82 | # print('\n'.join(['{} {}'.format(k, v.size()) for k,v in state_dict.items()])) 83 | # print("Missing Keys: ", '\n'.join(missing_keys)) 84 | # print("Unexpected Keys: ", '\n'.join(unexpected_keys)) 85 | 86 | if len(error_msgs) > 0: 87 | raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( 88 | model.__class__.__name__, "\n\t".join(error_msgs))) 89 | return model -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: lavit 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - blas=1.0=mkl 8 | - bzip2=1.0.8=h7b6447c_0 9 | - ca-certificates=2021.1.19=h06a4308_0 10 | - certifi=2020.12.5=py38h06a4308_0 11 | - cudatoolkit=10.2.89=hfd86e86_1 12 | - ffmpeg=4.2.2=h20bf706_0 13 | - freetype=2.10.4=h5ab3b9f_0 14 | - gmp=6.1.2=h6c8ec71_1 15 | - gnutls=3.6.5=h71b1129_1002 16 | - ignite=0.4.3=py_0 17 | - intel-openmp=2020.2=254 18 | - jpeg=9b=h024ee3a_2 19 | - lame=3.100=h7b6447c_0 20 | - lcms2=2.11=h396b838_0 21 | - ld_impl_linux-64=2.33.1=h53a641e_7 22 | - libedit=3.1.20191231=h14c3975_1 23 | - libffi=3.3=he6710b0_2 24 | - libgcc-ng=9.1.0=hdf63c60_0 25 | - libopus=1.3.1=h7b6447c_0 26 | - libpng=1.6.37=hbc83047_0 27 | - libstdcxx-ng=9.1.0=hdf63c60_0 28 | - libtiff=4.1.0=h2733197_1 29 | - libuv=1.40.0=h7b6447c_0 30 | - libvpx=1.7.0=h439df22_0 31 | - lz4-c=1.9.2=heb0550a_3 32 | - mkl=2020.2=256 33 | - mkl-service=2.3.0=py38he904b0f_0 34 | - mkl_fft=1.2.0=py38h23d657b_0 35 | - mkl_random=1.1.1=py38h0573a6f_0 36 | - ncurses=6.2=he6710b0_1 37 | - nettle=3.4.1=hbb512f6_0 38 | - ninja=1.10.1=py38hfd86e86_0 39 | - numpy=1.19.2=py38h54aff64_0 40 | - numpy-base=1.19.2=py38hfa32c7d_0 41 | - 
olefile=0.46=py_0 42 | - openh264=2.1.0=hd408876_0 43 | - openssl=1.1.1i=h27cfd23_0 44 | - pillow=8.0.1=py38he98fc37_0 45 | - pip=20.2.4=py38h06a4308_0 46 | - python=3.8.5=h7579374_1 47 | - pytorch=1.7.0=py3.8_cuda10.2.89_cudnn7.6.5_0 48 | - readline=8.0=h7b6447c_0 49 | - setuptools=50.3.0=py38h06a4308_1 50 | - six=1.15.0=py_0 51 | - sqlite=3.33.0=h62c20be_0 52 | - tk=8.6.10=hbc83047_0 53 | - torchaudio=0.7.0=py38 54 | - torchvision=0.8.1=py38_cu102 55 | - typing_extensions=3.7.4.3=py_0 56 | - wheel=0.35.1=py_0 57 | - x264=1!157.20191217=h7b6447c_0 58 | - xz=5.2.5=h7b6447c_0 59 | - zlib=1.2.11=h7b6447c_3 60 | - zstd=1.4.5=h9ceee32_0 61 | - pip: 62 | - absl-py==0.11.0 63 | - audioread==2.1.9 64 | - backcall==0.2.0 65 | - cachetools==4.2.1 66 | - cffi==1.14.3 67 | - chardet==3.0.4 68 | - click==7.1.2 69 | - colorama==0.4.4 70 | - cycler==0.10.0 71 | - decorator==4.4.2 72 | - docopt==0.6.2 73 | - filelock==3.0.12 74 | - gitdb==4.0.5 75 | - gitpython==3.1.13 76 | - google-auth==1.27.1 77 | - google-auth-oauthlib==0.4.3 78 | - grpcio==1.36.1 79 | - h5py==3.1.0 80 | - idna==2.10 81 | - inflection==0.5.1 82 | - ipdb==0.13.4 83 | - ipython==7.20.0 84 | - ipython-genutils==0.2.0 85 | - jedi==0.18.0 86 | - joblib==0.17.0 87 | - jsonpickle==1.5.2 88 | - kiwisolver==1.3.1 89 | - librosa==0.6.3 90 | - llvmlite==0.31.0 91 | - markdown==3.3.4 92 | - matplotlib==3.3.2 93 | - munch==2.5.0 94 | - nms==0.1.6 95 | - numba==0.48.0 96 | - oauthlib==3.1.0 97 | - opencv-python==4.4.0.46 98 | - packaging==20.4 99 | - pandas==1.1.4 100 | - parso==0.8.1 101 | - pexpect==4.8.0 102 | - pickleshare==0.7.5 103 | - prompt-toolkit==3.0.16 104 | - protobuf==3.13.0 105 | - ptyprocess==0.7.0 106 | - py-cpuinfo==7.0.0 107 | - pyasn1==0.4.8 108 | - pyasn1-modules==0.2.8 109 | - pycparser==2.20 110 | - pygments==2.8.0 111 | - pylab-sdk==1.3.2 112 | - pymongo==3.11.3 113 | - pyparsing==2.4.7 114 | - python-dateutil==2.8.1 115 | - pytz==2020.4 116 | - regex==2020.10.28 117 | - requests==2.24.0 118 | - requests-oauthlib==1.3.0 119 | - resampy==0.2.2 120 | - rsa==4.7.2 121 | - sacred==0.8.2 122 | - sacremoses==0.0.43 123 | - scikit-learn==0.23.2 124 | - scipy==1.5.4 125 | - sentencepiece==0.1.94 126 | - smmap==3.0.5 127 | - soundfile==0.10.3.post1 128 | - tensorboard==2.4.1 129 | - tensorboard-plugin-wit==1.8.0 130 | - tensorboardx==2.1 131 | - threadpoolctl==2.1.0 132 | - tokenizers==0.9.2 133 | - torchlibrosa==0.0.4 134 | - tqdm==4.51.0 135 | - traitlets==5.0.5 136 | - transformers==3.4.0 137 | - urllib3==1.25.11 138 | - wcwidth==0.2.5 139 | - werkzeug==1.0.1 140 | - wrapt==1.12.1 141 | prefix: /home2/heeseung.yun/anaconda3/envs/lavit 142 | 143 | -------------------------------------------------------------------------------- /code/config.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | default_args = { 4 | # Logging and general configuration 5 | 'debug': True, 6 | 'num_workers': 40, 7 | 'random_seed': 1234, 8 | 'log_path': 'data/log', 9 | 'ckpt_path': 'data/ckpt', 10 | 'ckpt_name': None, 11 | 'log_tag': '', # brief tagging for readibility 12 | 'log_keys': [ 13 | 'log_tag', 14 | 'model_name' 15 | ], 16 | 'feat_path': 'data/feat', 17 | 'pretrain_path': 'data/feat/label/trainval.json', 18 | 'train_path': 'data/feat/label/trainval.json', 19 | 'preval_path': 'data/feat/label/test.json', 20 | 'val_path': 'data/feat/label/test.json', 21 | 'test_path': 'data/feat/label/test.json', 22 | 'answer_path': 'data/feat/label/answer_2020.json', 23 | 'cache_path': 
'data/cache', 24 | 'output_path': 'data/output', 25 | 'rebuild_cache': False, 26 | 27 | 'model_name': 'lavit', 28 | 'transformer_name': 'bert', 29 | 'num_objects': 36, 30 | 31 | # Learning configuration 32 | 'pretrain_epochs': 3, 33 | 'max_epochs': 10, 34 | 35 | 'batch_size': 32, 36 | 'grad_acc_steps': 4, 37 | 38 | 'split_train': False, 39 | 'optimizer_name': 'bert_adam', 40 | 'scheduler_name': 'linear', 41 | 'weight_decay': 1e-2, 42 | 'learning_rate': 5e-5, 43 | 'pretrain_learning_rate': 1e-4, 44 | 'transformer_learning_rate': 1e-5, 45 | 'warmup': 0.1, 46 | 'feature_mask_rate': 0.15, 47 | 'device': 'cuda', 48 | 49 | 'max_length_qa': 60, 50 | 51 | 'model_config': { 52 | # Input/output features 53 | 'audio_feature': 'top_3', # None,orig,pool_{2,4,all},top_{3,5} 54 | 'visual_feature': 'rcnn_all', # None,rcnn_all,rcnn_center,rcnn_cube,rcnn_er,rcnn_nfov,i3d_center,i3d_er 55 | 'geometry': 'quaternion', # cartesian,angular,spherical,quaternion 56 | 57 | # Input/output dimension 58 | 'vocab_size': 30522, 59 | 'num_answers': 2020, 60 | 'visual_feat_dim': 2048, 61 | 'visual_label_dim': 200, 62 | 'visual_coord_dim': 6, 63 | 'audio_feat_dim': 2048, 64 | 'audio_label_dim': 527, 65 | 'audio_coord_dim': 3, 66 | 67 | # Model embedding dimension 68 | 'hidden_size': 768, 69 | 'num_attention_heads': 12, 70 | 'intermediate_size': 3072, 71 | 'max_position_embeddings': 512, 72 | 'type_vocab_size': 2, 73 | 'pad_token_id': 0, 74 | 75 | # Probability / activation 76 | 'hidden_act': 'gelu', 77 | 'hidden_dropout_prob': 0.1, 78 | 'attention_probs_dropout_prob': 0.1, 79 | 'initializer_range': 0.02, 80 | 'layer_norm_eps': 1e-12, 81 | 'loss_normalizer': 6.67, # 1 / 0.15 82 | 83 | # Model structure 84 | 'use_concat_encoder': False, 85 | 'use_concat_decoder': True, 86 | 'use_stereo_audio': True, 87 | 'use_cls_token': True, 88 | 'use_grounding': True, 89 | 'l_layers': 9, 90 | 'v_layers': 5, 91 | 'a_layers': 5, 92 | 'x_layers': 5, 93 | 94 | # Pretrain task toggle 95 | 'mask_lm': True, 96 | 'visual_feat': True, 97 | 'visual_coord': True, 98 | 'visual_label': True, 99 | 'audio_feat': True, 100 | 'audio_coord': True, 101 | 'audio_label': True, 102 | 'vl_match': True, 103 | 'al_match': True, 104 | 'qa': True, 105 | 'ground': True, 106 | 107 | # Hyperparameter 108 | 'lambda_ground': 0.2 109 | }, 110 | } 111 | 112 | lxmert_args = deepcopy(default_args) 113 | lxmert_args['model_config'].update({ 114 | # Pretrain task toggle 115 | 'mask_lm': True, 116 | 'visual_feat': True, 117 | 'visual_coord': False, 118 | 'visual_label': True, 119 | 'audio_feat': False, 120 | 'audio_coord': False, 121 | 'audio_label': False, 122 | 'vl_match': True, 123 | 'al_match': False, 124 | 'qa': True, 125 | 'ground': False, 126 | }) 127 | lxmert_args.update({ 128 | 'use_cls_token': False, 129 | 'use_concat_decoder': False, 130 | 'use_grounding': False, 131 | }) 132 | 133 | 134 | args_dict = { 135 | 'lxmert': lxmert_args, 136 | 'lavit': default_args, 137 | 'bert': default_args 138 | } -------------------------------------------------------------------------------- /code/args.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import random 4 | from datetime import datetime 5 | from pathlib import Path, PosixPath 6 | 7 | import torch 8 | import numpy as np 9 | from munch import Munch 10 | from sacred.arg_parser import get_config_updates 11 | 12 | from config import args_dict 13 | from model.config import * 14 | 15 | 16 | def get_args(options, fixed_args={}): 17 | '''Processes arguments''' 18 
| updated_args = {} 19 | updated_args.update(get_new_args(options)) 20 | updated_args.update(fixed_args) 21 | 22 | default_args = get_default_config(options, fixed_args) 23 | #import json; print(json.dumps(current_args, indent=2)) 24 | #assert False 25 | 26 | args = Munch(default_args) 27 | args.update(Munch(updated_args)) 28 | if args.ckpt_name is not None: 29 | args.update(Munch(load_args(args))) 30 | args.update(Munch(updated_args)) 31 | 32 | args.update(fix_seed(args)) 33 | args.update(resolve_paths(args)) 34 | args = update_data_path(args) 35 | 36 | args.config_dir = get_config_dir(args) 37 | args.model_config = get_model_config(args.model_name, args.model_config) 38 | args = args.toDict() 39 | 40 | # Primary assertions 41 | if args['device'] == 'cuda': 42 | assert torch.cuda.is_available(), "GPU device is not available" 43 | 44 | return args 45 | 46 | 47 | def get_default_config(options, fixed_args): 48 | updated_args = get_config_updates(options['UPDATE'])[0] 49 | if 'model_name' in updated_args.keys() and updated_args['model_name'] in args_dict.keys(): 50 | return args_dict[updated_args['model_name']] 51 | else: 52 | return default_args 53 | 54 | 55 | def get_model_config(model_name, args): 56 | model_config = ModelConfig(**args) 57 | # Sacred is not capable of capturing classmethods as variable 58 | return {k:v for k,v in vars(model_config).items()} 59 | 60 | 61 | def get_new_args(options): 62 | '''Fetch updated arguments that deviate from default settings''' 63 | if 'UPDATE' in options: 64 | new_args, _ = get_config_updates(options['UPDATE']) 65 | else: 66 | new_args = options 67 | return new_args 68 | 69 | 70 | def load_args(args): 71 | '''Load arguments of previous experiment''' 72 | root = Path('../').resolve() 73 | 74 | if str(root) not in str(args.ckpt_path): 75 | args.ckpt_path = root / args.ckpt_path 76 | args_path = sorted(args.ckpt_path.glob(f'{args.ckpt_name}*')) 77 | if args.ckpt_name is None or len(args_path) <= 0: 78 | return {} 79 | args_path = args_path[0] / 'args.json' 80 | ckpt_args = {} 81 | if args_path.is_file(): 82 | ckpt_args = json.load(open(args_path, 'r'))['args'] 83 | # update non-string arguments (and data_path) 84 | eval_keys = [k for k, v in default_args.items() if not isinstance(v, str)] 85 | eval_keys.append('data_path') 86 | # ckpt_args = {k: eval(v) if k in eval_keys else v for k, v in ckpt_args.items()} 87 | ckpt_args = {k: v for k, v in ckpt_args.items() if not k.endswith('_path')} 88 | return ckpt_args 89 | 90 | 91 | def resolve_paths(args): 92 | '''Convert strings into paths if applicable''' 93 | path_list = [k for k in args.keys() if k.endswith('_path') and k != 'data_path'] 94 | res_args = {} 95 | res_args['root'] = Path('../').resolve() 96 | for path in path_list: 97 | if args[path] is not None: 98 | if isinstance(args[path], list): 99 | res_args[path] = [res_args['root'] / Path(v) for v in args[path]] 100 | elif isinstance(args[path], dict): 101 | res_args[path] = {k: res_args['root'] / Path(v) for k, v in args[path].items()} 102 | else: 103 | res_args[path] = res_args['root'] / Path(args[path]) 104 | return res_args 105 | 106 | 107 | def update_data_path(args): 108 | '''Update dataset path''' 109 | if 'data_path' not in args: 110 | args['data_path'] = {} 111 | for k in ['pretrain', 'train', 'preval', 'val', 'test']: 112 | if f"{k}_path" in args: 113 | args['data_path'][k] = args[f"{k}_path"] 114 | del args[f"{k}_path"] 115 | for k, path in args.data_path.items(): 116 | path = Path(path).resolve() if str(path).startswith('/') else 
args.root / path 117 | args.data_path[k] = path 118 | return args 119 | 120 | 121 | 122 | def fix_seed(args): 123 | '''Fix random seeds at once''' 124 | if 'random_seed' not in args or not isinstance(args['random_seed'], int): 125 | args['random_seed'] = args['seed'] if 'seed' in args else 0 126 | args['seed'] = args['random_seed'] # for sacred 127 | 128 | random.seed(args['random_seed']) 129 | np.random.seed(args['random_seed']) 130 | torch.manual_seed(args['random_seed']) 131 | torch.cuda.manual_seed_all(args['random_seed']) 132 | # torch.backends.cudnn.benchmark = False 133 | # torch.set_deterministic(True) 134 | torch.multiprocessing.set_sharing_strategy('file_system') 135 | 136 | return args 137 | 138 | 139 | def get_config_dir(args): 140 | '''Generate directory name for logging''' 141 | now = datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 142 | tags = [re.sub('[,\W-]+', '_', str(args[key])) for key in args['log_keys']] 143 | dirname = '_'.join(tags)[:100] # Avoid too long paths 144 | return f"{dirname}_{now}" 145 | -------------------------------------------------------------------------------- /code/model/attention.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model. 
""" 17 | import math 18 | 19 | import torch 20 | from torch import nn 21 | 22 | 23 | def gelu(x): 24 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 25 | 26 | def swish(x): 27 | return x * torch.sigmoid(x) 28 | 29 | ACT2FN = {"gelu": gelu, "relu": nn.functional.relu, "swish": swish} 30 | 31 | 32 | class BertSelfAttention(nn.Module): 33 | def __init__(self, model_config): 34 | super().__init__() 35 | if model_config.hidden_size % model_config.num_attention_heads != 0: 36 | raise ValueError( 37 | "The hidden size (%d) is not a multiple of the number of attention " 38 | "heads (%d)" % (model_config.hidden_size, model_config.num_attention_heads) 39 | ) 40 | self.num_attention_heads = model_config.num_attention_heads 41 | self.attention_head_size = int(model_config.hidden_size / model_config.num_attention_heads) 42 | self.all_head_size = self.num_attention_heads * self.attention_head_size 43 | 44 | self.query = nn.Linear(model_config.hidden_size, self.all_head_size) 45 | self.key = nn.Linear(model_config.hidden_size, self.all_head_size) 46 | self.value = nn.Linear(model_config.hidden_size, self.all_head_size) 47 | 48 | self.dropout = nn.Dropout(model_config.attention_probs_dropout_prob) 49 | 50 | def transpose_for_scores(self, x): 51 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 52 | x = x.view(*new_x_shape) 53 | return x.permute(0, 2, 1, 3) 54 | 55 | def forward( 56 | self, 57 | hidden_states, 58 | context, 59 | attention_mask=None, 60 | ): 61 | query_layer = self.transpose_for_scores(self.query(hidden_states)) 62 | key_layer = self.transpose_for_scores(self.key(context)) 63 | value_layer = self.transpose_for_scores(self.value(context)) 64 | 65 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 66 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 67 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 68 | if attention_mask is not None: 69 | attention_scores = attention_scores + attention_mask 70 | 71 | # Normalize the attention scores to probabilities. 72 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 73 | 74 | # This is actually dropping out entire tokens to attend to, which might 75 | # seem a bit unusual, but is taken from the original Transformer paper. 
76 | attention_probs = self.dropout(attention_probs) 77 | 78 | context_layer = torch.matmul(attention_probs, value_layer) 79 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 80 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 81 | context_layer = context_layer.view(*new_context_layer_shape) 82 | 83 | return context_layer 84 | 85 | 86 | class BertSelfOutput(nn.Module): 87 | def __init__(self, model_config): 88 | super().__init__() 89 | self.dense = nn.Linear(model_config.hidden_size, model_config.hidden_size) 90 | self.LayerNorm = nn.LayerNorm(model_config.hidden_size, eps=model_config.layer_norm_eps) 91 | self.dropout = nn.Dropout(model_config.hidden_dropout_prob) 92 | 93 | def forward(self, hidden_states, input_tensor): 94 | hidden_states = self.dense(hidden_states) 95 | hidden_states = self.dropout(hidden_states) 96 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 97 | return hidden_states 98 | 99 | 100 | class BertAttention(nn.Module): 101 | def __init__(self, model_config): 102 | super().__init__() 103 | self.self = BertSelfAttention(model_config) 104 | self.output = BertSelfOutput(model_config) 105 | 106 | def forward(self, hidden_states, attention_mask=None): 107 | self_output = self.self(hidden_states, hidden_states, attention_mask) 108 | attention_output = self.output(self_output, hidden_states) 109 | return attention_output 110 | 111 | 112 | class BertCrossAttention(nn.Module): 113 | def __init__(self, model_config): 114 | super().__init__() 115 | self.self = BertSelfAttention(model_config) 116 | self.output = BertSelfOutput(model_config) 117 | 118 | def forward(self, hidden_states, context_states, context_mask=None): 119 | self_outputs = self.self(hidden_states, context_states, context_mask) 120 | attention_output = self.output(self_outputs, hidden_states) 121 | return attention_output 122 | 123 | 124 | class BertIntermediate(nn.Module): 125 | def __init__(self, model_config): 126 | super().__init__() 127 | self.dense = nn.Linear(model_config.hidden_size, model_config.intermediate_size) 128 | if isinstance(model_config.hidden_act, str): 129 | self.intermediate_act_fn = ACT2FN[model_config.hidden_act] 130 | else: 131 | self.intermediate_act_fn = model_config.hidden_act 132 | 133 | def forward(self, hidden_states): 134 | hidden_states = self.dense(hidden_states) 135 | hidden_states = self.intermediate_act_fn(hidden_states) 136 | return hidden_states 137 | 138 | 139 | class BertOutput(nn.Module): 140 | def __init__(self, model_config): 141 | super().__init__() 142 | self.dense = nn.Linear(model_config.intermediate_size, model_config.hidden_size) 143 | self.LayerNorm = nn.LayerNorm(model_config.hidden_size, eps=model_config.layer_norm_eps) 144 | self.dropout = nn.Dropout(model_config.hidden_dropout_prob) 145 | 146 | def forward(self, hidden_states, input_tensor): 147 | hidden_states = self.dense(hidden_states) 148 | hidden_states = self.dropout(hidden_states) 149 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 150 | return hidden_states 151 | 152 | 153 | class BertLayer(nn.Module): 154 | def __init__(self, model_config): 155 | super(BertLayer, self).__init__() 156 | self.attention = BertAttention(model_config) 157 | self.intermediate = BertIntermediate(model_config) 158 | self.output = BertOutput(model_config) 159 | 160 | def forward(self, hidden_states, attention_mask): 161 | attention_output = self.attention(hidden_states, attention_mask) 162 | intermediate_output = self.intermediate(attention_output) 163 | 
layer_output = self.output(intermediate_output, attention_output) 164 | return layer_output 165 | -------------------------------------------------------------------------------- /code/data/video.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle as pkl 3 | from tqdm import tqdm 4 | 5 | import torch 6 | import numpy as np 7 | 8 | from exp import ex 9 | 10 | 11 | # List of available (implemented) features 12 | geometry_list = [None, 'cartesian', 'angular', 'spherical', 'quaternion'] 13 | visual_list = [None, 'rcnn_all', 'rcnn_center', 'rcnn_cube', 'rcnn_er', 'rcnn_nfov', 'i3d_center', 'i3d_er'] 14 | audio_list = [None, 'orig', 'pool_2', 'pool_4', 'pool_all', 'top_3', 'top_5'] 15 | stereo_list = ['bin', 'reg', 'raw'] # binary, regression, raw 16 | 17 | audio_unit = 0.32 18 | multiplier = { 19 | "orig": 1, 20 | "pool_2": 2, 21 | "pool_4": 4 22 | } 23 | 24 | @ex.capture() 25 | def get_video(cache_path, rebuild_cache, feat_path, mode='train'): 26 | cache_path.mkdir(parents=True, exist_ok=True) 27 | 28 | visual, coordinate = get_visual(mode=mode) 29 | audio = get_audio() 30 | video = { 31 | "visual": visual, 32 | "coordinate": coordinate, 33 | "audio": audio 34 | } 35 | 36 | return video 37 | 38 | 39 | @ex.capture() 40 | def get_visual(feat_path, cache_path, rebuild_cache, data_path, model_config, mode='train'): 41 | visual_feature = model_config['visual_feature'] 42 | geometry = model_config['geometry'] 43 | 44 | visual_cache_file = f"{visual_feature}_{geometry}.pkl" 45 | #coordinate_cache_file = f"{geometry}.pkl" 46 | visual_cache_path = cache_path / visual_cache_file 47 | #coordinate_cache_path = cache_path / coordinate_cache_file 48 | 49 | assert visual_feature is None or (feat_path / 'visual' / visual_feature) in (feat_path / 'visual').glob('*'), \ 50 | "[ERROR] video feature {} does not exist.".format(visual_feature) 51 | assert geometry in geometry_list, \ 52 | "[ERROR] Geometry {} does not exist.".format(geometry) 53 | 54 | if rebuild_cache: 55 | if visual_cache_path.is_file(): 56 | visual_cache_path.unlink() 57 | #if coordinate_cache_path.is_file(): 58 | # coordinate_cache_path.unlink() 59 | if visual_cache_path.is_file(): #and coordinate_cache_path.is_file(): 60 | print(f'[LOG] Loading cached visual feautre from {visual_cache_path}...', end='', flush=True) 61 | visual, coordinate = torch.load(visual_cache_path) 62 | print('Complete!') 63 | #coordinate = torch.load(coordinate_cache_path) 64 | else: 65 | '''Reloading is inevitable unless both features exist''' 66 | visual, coordinate = _get_visual(visual_feature, geometry, feat_path, data_path, mode) 67 | torch.save((visual, coordinate), visual_cache_path) 68 | #torch.save(coordinate, coordinate_cache_path) 69 | 70 | return visual, coordinate 71 | 72 | 73 | @ex.capture() 74 | def get_audio(feat_path, cache_path, rebuild_cache, model_config): 75 | audio_feature = model_config['audio_feature'] 76 | 77 | cache_file = f"{audio_feature}.pkl" 78 | path = cache_path / cache_file 79 | 80 | assert audio_feature is None or (feat_path / 'audio' / f'{audio_feature}_feat') in (feat_path / 'audio').glob('*'), \ 81 | "[ERROR] audio feature {} does not exist.".format(audio_feature) 82 | 83 | if rebuild_cache and path.is_file(): 84 | path.unlink() 85 | if path.is_file(): 86 | print(f'[LOG] Loading cached audio feautre from {path}...', end='', flush=True) 87 | audio = torch.load(path) 88 | print('Complete!') 89 | else: 90 | audio = _get_audio(audio_feature, feat_path) 91 | 
torch.save(audio, path) 92 | 93 | return audio 94 | 95 | 96 | def _get_visual(visual_feature, geometry, feat_path, data_path, mode): 97 | '''Return requested visual feature''' 98 | visual_path = feat_path / 'visual' 99 | 100 | visual = {} 101 | coordinate = {} if geometry is not None else None 102 | 103 | if visual_feature is None: 104 | '''No visual feature''' 105 | return None, None 106 | 107 | metadata = json.load(open(data_path[mode], 'r'))['videos'] 108 | 109 | if 'rcnn' in visual_feature: 110 | '''Visual features from RCNN family''' 111 | visual_path = visual_path / visual_feature 112 | 113 | for vid in tqdm(visual_path.glob('*'), 114 | desc="Loading visual feature", 115 | total=len(list(visual_path.glob('*')))): 116 | feats = pkl.load(open(vid, 'rb')) 117 | video_id = vid.stem.split('.')[0] 118 | 119 | embedding = feats['feat'] 120 | score = feats['score'] 121 | classes = feats['cls'] 122 | if geometry == "cartesian": 123 | # Scale down cartesian coordinate w.r.t. width and height 124 | # RCNN utilized fixed width and height, thus we use 125 | w = 480. if visual_feature == "rcnn_center" else 1920.# float(metadata[video_id]['width']) 126 | h = 320. if visual_feature == "rcnn_center" else 1080.# float(metadata[video_id]['height']) 127 | feats[geometry] = [[g[0], g[1]/w, g[2]/h, g[3]/w, g[4]/h] for g in feats[geometry]] 128 | if len(feats[geometry]) == 0: 129 | feats[geometry] = np.zeros((0, 5)) 130 | elif geometry == 'spherical': 131 | feats[geometry][:,-2:] = feats['angular'][:,-2:] 132 | 133 | visual[video_id] = { 134 | "embedding": embedding, 135 | "score": np.array(score), 136 | "classes": np.array(classes), 137 | "coordinate": np.array(feats[geometry]) 138 | } 139 | 140 | if geometry is not None: 141 | coordinate[video_id] = np.array(feats[geometry]) 142 | elif 'i3d' in visual_feature: 143 | '''Visual features from I3D family''' 144 | visual_path = visual_path / visual_feature 145 | coordinate = None 146 | 147 | for vid in tqdm(visual_path.glob('*')): 148 | feats = pkl.load(open(vid, 'rb')) 149 | visual[vid.stem.split('.')[0]] = { 150 | "embedding": feats 151 | } 152 | else: 153 | assert False, "[ERROR] Unimplemented feature request in visual modality." 154 | 155 | return visual, coordinate 156 | 157 | 158 | def _get_audio(audio_feature, feat_path): 159 | audio_path = feat_path / 'audio' 160 | 161 | audio = {} 162 | 163 | if audio_feature is None: 164 | return None 165 | 166 | stereo_path = audio_path / 'harmonics' / f'{audio_feature}_reg.json' 167 | stereo_feat = json.load(open(stereo_path, 'r')) 168 | 169 | if 'top' in audio_feature: 170 | time_feat = json.load(open(audio_path / f'{audio_feature}_time.json', 'r')) 171 | else: 172 | time_feat = {} 173 | 174 | for aud in tqdm((audio_path / f'{audio_feature}_feat').glob('*'), 175 | desc="Loading audio feature", 176 | total=len(list((audio_path / f'{audio_feature}_feat').glob('*')))): 177 | feats = pkl.load(open(aud, 'rb')) 178 | video_id = aud.stem.split('.')[0] 179 | label = pkl.load(open(audio_path / f'{audio_feature}_label' / f'{video_id}.pkl', 'rb')) 180 | harmonics = stereo_feat[video_id][:len(label)] 181 | audio_len = pkl.load(open(audio_path / f'orig_label' / f'{video_id}.pkl', 'rb')).shape[0] * audio_unit 182 | 183 | score = torch.sigmoid(torch.from_numpy(label)).numpy() 184 | label = np.argmax(score, axis=1) 185 | score = np.max(score, axis=1) 186 | 187 | if 'top' not in audio_feature: 188 | a_start = np.array([audio_unit * multiplier[audio_feature] * i for i in range(label.shape[0]+1)]) # uniform segment start times 189 | else: 190 | a_start = np.array([0.]
+ time_feat[video_id]) 191 | a_duration = a_start[1:] - a_start[:-1] 192 | a_start = a_start[:-1] 193 | a_coord = np.array([[a_start[i], a_duration[i], harmonics[i]] for i in range(a_start.shape[0])]) 194 | 195 | audio[video_id] = { 196 | "embedding": feats, 197 | "score": score, 198 | "classes": label, 199 | "harmonics": np.array(harmonics), 200 | "coordinate": a_coord 201 | } 202 | 203 | return audio 204 | -------------------------------------------------------------------------------- /code/model/pooler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | poolerLoss = { 8 | 'ce': nn.CrossEntropyLoss(ignore_index=-1), 9 | 'l2': nn.SmoothL1Loss(reduction='none'), 10 | 'l2_reduction': nn.SmoothL1Loss(), 11 | 'ce_no_reduction': nn.CrossEntropyLoss(ignore_index=-1, reduction='none'), 12 | 'kl': lambda x,y: nn.KLDivLoss(reduction='none')(F.log_softmax(x, dim=2), y) 13 | } 14 | 15 | def gelu(x): 16 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 17 | 18 | ACT2FN = {"gelu": gelu, "relu": nn.functional.relu} 19 | 20 | 21 | class GeLU(nn.Module): 22 | def __init__(self): 23 | # torch.nn.functional.gelu is not a Module subclass 24 | super().__init__() 25 | 26 | def forward(self, x): 27 | return gelu(x) 28 | 29 | 30 | class BertPooler(nn.Module): 31 | def __init__(self, model_config): 32 | super().__init__() 33 | self.dense = nn.Linear(model_config.hidden_size, model_config.hidden_size) 34 | self.activation = nn.Tanh() 35 | 36 | def forward(self, hidden_states): 37 | cls_tensor = hidden_states[:, 0] 38 | output = self.activation(self.dense(cls_tensor)) 39 | return output 40 | 41 | 42 | class BertPredictionHeadTransform(nn.Module): 43 | def __init__(self, model_config): 44 | super().__init__() 45 | self.dense = nn.Linear(model_config.hidden_size, model_config.hidden_size) 46 | if isinstance(model_config.hidden_act, str): 47 | self.transform_act_fn = ACT2FN[model_config.hidden_act] 48 | else: 49 | self.transform_act_fn = model_config.hidden_act 50 | self.LayerNorm = nn.LayerNorm(model_config.hidden_size, eps=model_config.layer_norm_eps) 51 | 52 | def forward(self, hidden_states): 53 | hidden_states = self.dense(hidden_states) 54 | hidden_states = self.transform_act_fn(hidden_states) 55 | hidden_states = self.LayerNorm(hidden_states) 56 | return hidden_states 57 | 58 | 59 | class BertLMPredictionHead(nn.Module): 60 | def __init__(self, model_config, bert_weights=None): 61 | super().__init__() 62 | self.transform = BertPredictionHeadTransform(model_config) 63 | 64 | if bert_weights is None: 65 | self.decoder = nn.Linear(model_config.hidden_size, model_config.vocab_size, bias=False) 66 | self.bias = nn.Parameter(torch.zeros(model_config.vocab_size)) 67 | else: 68 | self.decoder = nn.Linear(bert_weights.size(1), bert_weights.size(0), bias=False) 69 | self.decoder.weight = bert_weights 70 | self.bias = nn.Parameter(torch.zeros(bert_weights.size(0))) 71 | 72 | def forward(self, hidden_states): 73 | hidden_states = self.transform(hidden_states) 74 | output = self.decoder(hidden_states) + self.bias 75 | return output 76 | 77 | 78 | class LanguageHead(nn.Module): 79 | def __init__(self, model_config, bert_weights=None): 80 | super().__init__() 81 | self.predictions = BertLMPredictionHead(model_config, bert_weights) 82 | 83 | def forward(self, hidden_states): 84 | return self.predictions(hidden_states) 85 | 86 | 87 | class VisualHead(nn.Module): 88 | def 
__init__(self, model_config): 89 | super().__init__() 90 | self.transform = BertPredictionHeadTransform(model_config) 91 | self.tasks = {} 92 | 93 | if 'visual_feat' in model_config.pretrain_types: 94 | self.feat_decoder = nn.Linear(model_config.hidden_size, model_config.visual_feat_dim) 95 | self.tasks["visual_feat"] = self.feat_decoder 96 | if 'visual_coord' in model_config.pretrain_types: 97 | self.coord_decoder = nn.Linear(model_config.hidden_size, model_config.visual_coord_dim) 98 | self.tasks["visual_coord"] = self.coord_decoder 99 | if 'visual_label' in model_config.pretrain_types: 100 | self.label_decoder = nn.Linear(model_config.hidden_size, model_config.visual_label_dim) 101 | self.tasks["visual_label"] = self.label_decoder 102 | 103 | def forward(self, hidden_states): 104 | hidden_states = self.transform(hidden_states) 105 | output = {} 106 | for task, decoder in self.tasks.items(): 107 | output[task] = decoder(hidden_states) 108 | return output 109 | 110 | 111 | class AudioHead(nn.Module): 112 | def __init__(self, model_config): 113 | super().__init__() 114 | self.transform = BertPredictionHeadTransform(model_config) 115 | self.tasks = {} 116 | self.feat_dim = model_config.audio_feat_dim * 2 if model_config.use_stereo_audio else model_config.audio_feat_dim 117 | 118 | if 'audio_feat' in model_config.pretrain_types: 119 | self.feat_decoder = nn.Linear(model_config.hidden_size, self.feat_dim) 120 | self.tasks["audio_feat"] = self.feat_decoder 121 | if 'audio_harmonics' in model_config.pretrain_types: 122 | self.harm_decoder = nn.Linear(model_config.hidden_size, 1) 123 | self.tasks["audio_harmonics"] = self.harm_decoder # regression 124 | if 'audio_harmonics_reg' in model_config.pretrain_types: 125 | self.harm_decoder = nn.Linear(model_config.hidden_size, 1) 126 | self.tasks["audio_harmonics_reg"] = self.harm_decoder 127 | if 'audio_harmonics_bin' in model_config.pretrain_types: 128 | self.harm_decoder = nn.Linear(model_config.hidden_size, 3) # -1, 0, 1 129 | self.tasks["audio_harmonics_bin"] = self.harm_decoder 130 | if 'audio_label' in model_config.pretrain_types: 131 | self.label_decoder = nn.Linear(model_config.hidden_size, model_config.audio_label_dim) 132 | self.tasks["audio_label"] = self.label_decoder 133 | if 'audio_coord' in model_config.pretrain_types: 134 | self.coord_decoder = nn.Linear(model_config.hidden_size, model_config.audio_coord_dim) 135 | self.tasks["audio_coord"] = self.coord_decoder 136 | 137 | def forward(self, hidden_states): 138 | hidden_states = self.transform(hidden_states) 139 | output = {} 140 | for task, decoder in self.tasks.items(): 141 | output[task] = decoder(hidden_states).squeeze() 142 | return output 143 | 144 | 145 | class AnswerHead(nn.Module): 146 | def __init__(self, model_config, num_modality=1): 147 | super().__init__() 148 | in_dim = model_config.hidden_size 149 | hid_dim = 2 * in_dim 150 | if model_config.use_concat_decoder and num_modality > 1: 151 | in_dim = num_modality * in_dim 152 | hid_dim = 2 * in_dim 153 | num_answers = model_config.num_answers 154 | 155 | self.logit_fc = nn.Sequential( 156 | nn.Linear(in_dim, hid_dim), 157 | GeLU(), 158 | nn.LayerNorm(hid_dim, eps=model_config.layer_norm_eps), 159 | nn.Linear(hid_dim, num_answers) 160 | ) 161 | 162 | def forward(self, x): 163 | if type(x) == list: 164 | return self.logit_fc(torch.cat(x, 1)) 165 | else: 166 | return self.logit_fc(x) 167 | 168 | 169 | class GroundHead(nn.Module): 170 | def __init__(self, model_config, num_modality=1): 171 | super().__init__() 172 | in_dim =
model_config.hidden_size 173 | hid_dim = 2 * in_dim 174 | if model_config.use_concat_decoder and num_modality > 1: 175 | in_dim = num_modality * in_dim 176 | hid_dim = 2 * in_dim 177 | ground_dim = model_config.visual_coord_dim - 1 178 | 179 | self.logit_fc = nn.Sequential( 180 | nn.Linear(in_dim, hid_dim), 181 | GeLU(), 182 | nn.LayerNorm(hid_dim, eps=model_config.layer_norm_eps), 183 | nn.Linear(hid_dim, ground_dim) 184 | ) 185 | 186 | def forward(self, x): 187 | if type(x) == list: 188 | return self.logit_fc(torch.cat(x, 1)) 189 | else: 190 | return self.logit_fc(x) 191 | 192 | 193 | class MatchHead(nn.Module): 194 | def __init__(self, model_config, use_concat_decoder=None): 195 | super().__init__() 196 | self.use_concat_decoder = use_concat_decoder if use_concat_decoder is not None else model_config.use_concat_decoder 197 | if self.use_concat_decoder: 198 | self.match_head = nn.Linear(2 * model_config.hidden_size, 2) 199 | else: 200 | self.match_head = nn.Linear(model_config.hidden_size, 2) 201 | 202 | def forward(self, x): 203 | if type(x) == list: 204 | return self.match_head(torch.cat(x, 1)) 205 | else: 206 | return self.match_head(x) 207 | 208 | 209 | class DummyHead(nn.Module): 210 | def __init__(self): 211 | super().__init__() 212 | 213 | def forward(self, x): 214 | return None -------------------------------------------------------------------------------- /code/model/input.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class BertEmbeddings(nn.Module): 6 | """Construct the embeddings from word, position and token_type embeddings. 7 | """ 8 | def __init__(self, model_config): 9 | super(BertEmbeddings, self).__init__() 10 | self.word_embeddings = nn.Embedding(model_config.vocab_size, model_config.hidden_size, padding_idx=0) 11 | self.position_embeddings = nn.Embedding(model_config.max_position_embeddings, model_config.hidden_size, padding_idx=0) 12 | self.token_type_embeddings = nn.Embedding(model_config.type_vocab_size, model_config.hidden_size, padding_idx=0) 13 | 14 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 15 | # any TensorFlow checkpoint file 16 | self.LayerNorm = nn.LayerNorm(model_config.hidden_size, eps=model_config.layer_norm_eps) 17 | self.dropout = nn.Dropout(model_config.hidden_dropout_prob) 18 | 19 | def forward(self, input_ids, token_type_ids=None): 20 | seq_length = input_ids.size(1) 21 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 22 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 23 | if token_type_ids is None: 24 | token_type_ids = torch.zeros_like(input_ids) 25 | 26 | words_embeddings = self.word_embeddings(input_ids) 27 | position_embeddings = self.position_embeddings(position_ids) 28 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 29 | 30 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 31 | embeddings = self.LayerNorm(embeddings) 32 | embeddings = self.dropout(embeddings) 33 | return embeddings 34 | 35 | 36 | class VisualFeatEncoder(nn.Module): 37 | def __init__(self, model_config): 38 | super().__init__() 39 | 40 | self.feature_fc = nn.Linear(model_config.visual_feat_dim, model_config.hidden_size) 41 | self.feature_ln = nn.LayerNorm(model_config.hidden_size, eps=model_config.layer_norm_eps) 42 | 43 | self.coord_fc = nn.Linear(model_config.visual_coord_dim, model_config.hidden_size) 44 | self.coord_in = 
nn.LayerNorm(model_config.hidden_size, eps=model_config.layer_norm_eps) 45 | 46 | self.dropout = nn.Dropout(model_config.hidden_dropout_prob) 47 | 48 | def forward(self, v_input): 49 | feats, boxes = v_input 50 | 51 | feature_out = self.feature_ln(self.feature_fc(feats)) 52 | coord_out = self.coord_in(self.coord_fc(boxes)) 53 | output = self.dropout((feature_out + coord_out) / 2) 54 | 55 | return output 56 | 57 | 58 | class VisualFeatNoCoordEncoder(nn.Module): 59 | def __init__(self, model_config): 60 | super().__init__() 61 | 62 | self.feat_fc = nn.Linear(model_config.visual_feat_dim, model_config.hidden_size) 63 | self.feat_ln = nn.LayerNorm(model_config.hidden_size, eps=model_config.layer_norm_eps) 64 | self.dropout = nn.Dropout(model_config.hidden_dropout_prob) 65 | 66 | def forward(self, v_input): 67 | feats, _ = v_input 68 | 69 | return self.dropout(self.feat_ln(self.feat_fc(feats))) 70 | 71 | 72 | class VisualFeatConcatEncoder(nn.Module): 73 | def __init__(self, model_config): 74 | super().__init__() 75 | # AFAIK it is identical to VisualFeatEncoder, since it is mere decoupling of two fc 76 | # (Wx+a) + (Vy+b) = [W:V][x:y] + (a+b) 77 | # -> In fact, they are slightly different from VisualFeatEncoder due to nonlinearities 78 | self.feature = nn.Linear(model_config.visual_feat_dim + model_config.visual_coord_dim, model_config.hidden_size) 79 | self.layernorm = nn.LayerNorm(model_config.hidden_size, eps=model_config.layer_norm_eps) 80 | self.dropout = nn.Dropout(model_config.hidden_dropout_prob) 81 | 82 | def forward(self, v_input): 83 | v_input = torch.cat(v_input, -1) 84 | 85 | output = self.dropout(self.layernorm(self.feature(v_input))) 86 | return output 87 | 88 | 89 | class AudioMonoEncoder(nn.Module): 90 | def __init__(self, model_config): 91 | super().__init__() 92 | f_dim = model_config.audio_feat_dim 93 | h_dim = model_config.hidden_size 94 | dropout_rate = model_config.hidden_dropout_prob 95 | eps = model_config.layer_norm_eps 96 | 97 | self.feat_fc = nn.Linear(f_dim, h_dim) 98 | self.feat_ln = nn.LayerNorm(h_dim, eps=eps) 99 | self.dropout = nn.Dropout(dropout_rate) 100 | 101 | 102 | def forward(self, a_input): 103 | a_feat, _ = a_input 104 | output = self.dropout(self.feat_ln(self.feat_fc(a_feat))) 105 | return output 106 | 107 | 108 | class AudioMonoTEncoder(nn.Module): 109 | def __init__(self, model_config): 110 | super().__init__() 111 | f_dim = model_config.audio_feat_dim 112 | h_dim = model_config.hidden_size 113 | dropout_rate = model_config.hidden_dropout_prob 114 | eps = model_config.layer_norm_eps 115 | 116 | self.feat_fc = nn.Linear(f_dim, h_dim) 117 | self.cord_fc = nn.Linear(2, h_dim) 118 | self.feat_ln = nn.LayerNorm(h_dim, eps=eps) 119 | self.cord_ln = nn.LayerNorm(h_dim, eps=eps) 120 | self.dropout = nn.Dropout(dropout_rate) 121 | 122 | 123 | def forward(self, a_input): 124 | a_feat, a_cord = a_input 125 | feat_out = self.feat_ln(self.feat_fc(a_feat)) 126 | cord_out = self.cord_ln(self.cord_fc(a_cord[:,:2])) 127 | 128 | return self.dropout((feat_out + cord_out) / 2) 129 | 130 | 131 | class AudioMonoSEncoder(nn.Module): 132 | def __init__(self, model_config): 133 | super().__init__() 134 | f_dim = model_config.audio_feat_dim 135 | h_dim = model_config.hidden_size 136 | dropout_rate = model_config.hidden_dropout_prob 137 | eps = model_config.layer_norm_eps 138 | 139 | self.feat_fc = nn.Linear(f_dim, h_dim) 140 | self.cord_fc = nn.Linear(1, h_dim) 141 | self.feat_ln = nn.LayerNorm(h_dim, eps=eps) 142 | self.cord_ln = nn.LayerNorm(h_dim, eps=eps) 143 | 
self.dropout = nn.Dropout(dropout_rate) 144 | 145 | 146 | def forward(self, a_input): 147 | a_feat, a_cord = a_input 148 | feat_out = self.feat_ln(self.feat_fc(a_feat)) 149 | cord_out = self.cord_ln(self.cord_fc(a_cord[:,2])) 150 | 151 | return self.dropout((feat_out + cord_out) / 2) 152 | 153 | 154 | class AudioMonoSTEncoder(nn.Module): 155 | def __init__(self, model_config): 156 | super().__init__() 157 | f_dim = model_config.audio_feat_dim 158 | h_dim = model_config.hidden_size 159 | dropout_rate = model_config.hidden_dropout_prob 160 | eps = model_config.layer_norm_eps 161 | 162 | self.feat_fc = nn.Linear(f_dim, h_dim) 163 | self.cord_fc = nn.Linear(3, h_dim) 164 | self.feat_ln = nn.LayerNorm(h_dim, eps=eps) 165 | self.cord_ln = nn.LayerNorm(h_dim, eps=eps) 166 | self.dropout = nn.Dropout(dropout_rate) 167 | 168 | 169 | def forward(self, a_input): 170 | a_feat, a_cord = a_input 171 | feat_out = self.feat_ln(self.feat_fc(a_feat)) 172 | cord_out = self.cord_ln(self.cord_fc(a_cord)) 173 | 174 | return self.dropout((feat_out + cord_out) / 2) 175 | 176 | 177 | class AudioStereoEncoder(nn.Module): 178 | def __init__(self, model_config): 179 | super().__init__() 180 | f_dim = model_config.audio_feat_dim 181 | h_dim = model_config.hidden_size 182 | dropout_rate = model_config.hidden_dropout_prob 183 | eps = model_config.layer_norm_eps 184 | 185 | self.left_fc = nn.Linear(f_dim, h_dim) 186 | self.righ_fc = nn.Linear(f_dim, h_dim) 187 | self.left_ln = nn.LayerNorm(h_dim, eps=eps) 188 | self.righ_ln = nn.LayerNorm(h_dim, eps=eps) 189 | self.dropout = nn.Dropout(dropout_rate) 190 | 191 | 192 | def forward(self, a_input): 193 | a_feat, _ = a_input 194 | a_left = a_feat[:,0,:,:] 195 | a_righ = a_feat[:,1,:,:] 196 | 197 | left_out = self.left_ln(self.left_fc(a_left)) 198 | righ_out = self.righ_ln(self.righ_fc(a_righ)) 199 | return self.dropout((left_out + righ_out) / 2) 200 | 201 | 202 | class AudioStereoSEncoder(nn.Module): 203 | def __init__(self, model_config): 204 | super().__init__() 205 | f_dim = model_config.audio_feat_dim 206 | h_dim = model_config.hidden_size 207 | dropout_rate = model_config.hidden_dropout_prob 208 | eps = model_config.layer_norm_eps 209 | 210 | self.left_fc = nn.Linear(f_dim, h_dim) 211 | self.righ_fc = nn.Linear(f_dim, h_dim) 212 | self.cord_fc = nn.Linear(1, h_dim) 213 | self.left_ln = nn.LayerNorm(h_dim, eps=eps) 214 | self.righ_ln = nn.LayerNorm(h_dim, eps=eps) 215 | self.cord_ln = nn.LayerNorm(h_dim, eps=eps) 216 | self.dropout = nn.Dropout(dropout_rate) 217 | 218 | def forward(self, a_input): 219 | a_feat, a_cord = a_input 220 | a_left = a_feat[:,0,:,:] 221 | a_righ = a_feat[:,1,:,:] 222 | 223 | left_out = self.left_ln(self.left_fc(a_left)) 224 | righ_out = self.righ_ln(self.righ_fc(a_righ)) 225 | cord_out = self.cord_ln(self.cord_fc(a_cord[:,-1])) 226 | return self.dropout((left_out + righ_out + cord_out) / 3) 227 | 228 | 229 | class AudioStereoTEncoder(nn.Module): 230 | def __init__(self, model_config): 231 | super().__init__() 232 | f_dim = model_config.audio_feat_dim 233 | h_dim = model_config.hidden_size 234 | dropout_rate = model_config.hidden_dropout_prob 235 | eps = model_config.layer_norm_eps 236 | 237 | self.left_fc = nn.Linear(f_dim, h_dim) 238 | self.righ_fc = nn.Linear(f_dim, h_dim) 239 | self.cord_fc = nn.Linear(2, h_dim) 240 | self.left_ln = nn.LayerNorm(h_dim, eps=eps) 241 | self.righ_ln = nn.LayerNorm(h_dim, eps=eps) 242 | self.cord_ln = nn.LayerNorm(h_dim, eps=eps) 243 | self.dropout = nn.Dropout(dropout_rate) 244 | 245 | def forward(self, a_input): 
246 | a_feat, a_cord = a_input 247 | a_left = a_feat[:,0,:,:] 248 | a_righ = a_feat[:,1,:,:] 249 | 250 | left_out = self.left_ln(self.left_fc(a_left)) 251 | righ_out = self.righ_ln(self.righ_fc(a_righ)) 252 | cord_out = self.cord_ln(self.cord_fc(a_cord[:,:-1])) 253 | return self.dropout((left_out + righ_out + cord_out) / 3) 254 | 255 | 256 | class AudioStereoSTEncoder(nn.Module): 257 | def __init__(self, model_config): 258 | super().__init__() 259 | f_dim = model_config.audio_feat_dim 260 | h_dim = model_config.hidden_size 261 | dropout_rate = model_config.hidden_dropout_prob 262 | eps = model_config.layer_norm_eps 263 | 264 | self.left_fc = nn.Linear(f_dim, h_dim) 265 | self.righ_fc = nn.Linear(f_dim, h_dim) 266 | self.cord_fc = nn.Linear(3, h_dim) 267 | self.left_ln = nn.LayerNorm(h_dim, eps=eps) 268 | self.righ_ln = nn.LayerNorm(h_dim, eps=eps) 269 | self.cord_ln = nn.LayerNorm(h_dim, eps=eps) 270 | self.dropout = nn.Dropout(dropout_rate) 271 | 272 | def forward(self, a_input): 273 | a_feat, a_cord = a_input 274 | a_left = a_feat[:,0,:,:] 275 | a_righ = a_feat[:,1,:,:] 276 | 277 | left_out = self.left_ln(self.left_fc(a_left)) 278 | righ_out = self.righ_ln(self.righ_fc(a_righ)) 279 | cord_out = self.cord_ln(self.cord_fc(a_cord)) 280 | return self.dropout((left_out + righ_out + cord_out) / 3) 281 | -------------------------------------------------------------------------------- /code/train.py: -------------------------------------------------------------------------------- 1 | import json 2 | from functools import partial 3 | 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from exp import ex 9 | from ckpt import save_ckpt 10 | from metrics.logger import write_logs 11 | from optimizer import get_optimizer 12 | from metrics.logger import write_logs 13 | from common import prepare_batch, get_all 14 | 15 | 16 | grounding_error = torch.nn.MSELoss(reduction='none') 17 | 18 | 19 | @ex.capture() 20 | def get_pretrain_task(split, epoch, pretrain_epochs, model_name, pretrain_types, model_config): 21 | if split != 'pretrain': 22 | return model_config.finetune_types 23 | else: 24 | if model_name in ["bert", "bert_scratch"]: 25 | return ["mask_lm", "ground"] 26 | 27 | elif model_name == 'lxmert': 28 | if epoch < (pretrain_epochs/2): 29 | return ["mask_lm", "vl_match", "visual_feat", "visual_label"] 30 | else: 31 | return ["mask_lm", "vl_match", "visual_feat", "visual_label", "qa"] 32 | else: #lavit 33 | return pretrain_types 34 | 35 | 36 | def get_accs(gt, prop, qtype): 37 | retval = {} 38 | correct = (gt == prop).float() 39 | retval['acc_total'] = [k.item() for k in correct] 40 | is_av = torch.Tensor(np.array(qtype) == 'a').float() 41 | retval['acc_av'] = [k.item() for i, k in enumerate(correct) if is_av[i] == 1] 42 | retval['acc_sp'] = [k.item() for i, k in enumerate(correct) if is_av[i] == 0] 43 | 44 | return retval 45 | 46 | 47 | def get_errors(gt, prop, qtype): 48 | retval = {} 49 | errors = grounding_error(prop, gt).sum(1) 50 | retval['mse_total'] = [k.item() for k in errors] 51 | is_av = torch.Tensor(np.array(qtype) == 'a').float() 52 | retval['mse_av'] = [k.item() for i, k in enumerate(errors) if is_av[i] == 1] 53 | retval['mse_sp'] = [k.item() for i, k in enumerate(errors) if is_av[i] == 0] 54 | 55 | return retval 56 | 57 | 58 | 59 | @ex.capture() 60 | def _train(log_path, ckpt_path, config_dir, max_epochs, pretrain_epochs, answer_path, learning_rate, 61 | pretrain_learning_rate, pretrain_types, split_train, model_config, _config): 62 | dataloaders, _, tokenizer, 
model, criterion = get_all(data_modes=['pretrain', 'train', 'preval', 'val', 'test']) 63 | print("[LOG] Logging to {}".format(log_path / config_dir)) 64 | logger = torch.utils.tensorboard.SummaryWriter(log_path / config_dir) 65 | 66 | answer_dict = json.load(open(answer_path, 'r')) 67 | 68 | # print(model) 69 | # PRETRAIN 70 | it = 0 71 | if pretrain_epochs > 0 and hasattr(model, 'run_pretrain'): 72 | if split_train: 73 | model.fix_transformer(True) 74 | 75 | optimizer, scheduler = get_optimizer(model, 76 | dataloaders['pretrain'].dataset.t_total, 77 | learning_rate=pretrain_learning_rate) 78 | optimizer.zero_grad() 79 | 80 | 81 | for epoch in range(pretrain_epochs): 82 | 83 | model.train() 84 | 85 | pretrain_tasks = get_pretrain_task('pretrain', epoch) 86 | print(pretrain_tasks) 87 | 88 | for _batch in tqdm(dataloaders['pretrain'], total=len(dataloaders['pretrain']), desc=f"Pretrain e{epoch}"): 89 | batch, label, meta = prepare_batch(_batch) 90 | 91 | stats = model(batch, label, pretrain_tasks) 92 | stats['total_loss'].backward() 93 | # if debug: 94 | # with torch.autograd.set_detect_anomaly(True): 95 | # stats = model(batch, label) 96 | 97 | scheduler.accumulated += 1 98 | if scheduler.accumulated >= scheduler.grad_acc_steps: 99 | optimizer.step() 100 | scheduler.step() 101 | scheduler.accumulated = 0 102 | optimizer.zero_grad() 103 | 104 | if it % 20 == 0: 105 | write_logs(logger, it, optimizer.param_groups[0]['lr'], stats, meta, 'train') 106 | 107 | it += 1 108 | 109 | # Evaluate 110 | model.eval() 111 | scheduler.accumulated = 0 112 | eval_stats = {} 113 | acc = {'acc_total': [], 'acc_sp': [], 'acc_av': []} 114 | mse = {'mse_total': [], 'mse_sp': [], 'mse_av': []} 115 | qas = [] 116 | for _batch in tqdm(dataloaders['preval'], total=len(dataloaders['preval']), desc="Valid"): 117 | batch, label, meta = prepare_batch(_batch) 118 | 119 | with torch.no_grad(): 120 | stats = model(batch, label, pretrain_tasks) 121 | for k, v in get_accs(label['qa'].cpu(), stats['answer_pred'], meta['question_type']).items(): 122 | acc[k].extend(v) 123 | if 'ground_pred' in stats.keys(): 124 | for k, v in get_errors(label['ground'].cpu(), stats['ground_pred'], meta['question_type']).items(): 125 | mse[k].extend(v) 126 | qas.append("[{}] {} / (GT) {} (PROP) {} / (GT) {} (PROP) {}".format( 127 | meta['video_id'][0], meta['question'][0], answer_dict[label['qa'][0].item()], answer_dict[stats['answer_pred'][0].item()], 128 | str([f"{i:.3f}" for i in label['ground'][0]]), str([f"{i:.3f}" for i in stats['ground_pred'][0]]) 129 | )) 130 | else: 131 | qas.append("[{}] {} / (GT) {} (PROP) {}".format( 132 | meta['video_id'][0], meta['question'][0], answer_dict[label['qa'][0].item()], answer_dict[stats['answer_pred'][0].item()] 133 | )) 134 | 135 | for k, v in stats.items(): 136 | if type(v) == torch.Tensor and v.dim() == 0: 137 | eval_stats[k] = v.item() if k not in eval_stats.keys() else eval_stats[k] + v.item() 138 | for k, v in acc.items(): 139 | eval_stats[k] = sum(v) / len(v) 140 | if 'ground_pred' in stats.keys(): 141 | for k, v in mse.items(): 142 | eval_stats[k] = sum(v) / len(v) 143 | 144 | print([f"{k}: {v:.4f}" for k,v in eval_stats.items()]) 145 | eval_stats['example'] = '\n\n'.join(qas) 146 | write_logs(logger, epoch, None, eval_stats, meta, 'eval') 147 | save_ckpt(epoch, stats['total_loss'].item(), model) 148 | 149 | # TRAIN 150 | if split_train: 151 | model.fix_transformer(False) 152 | optimizer, scheduler = get_optimizer(model, dataloaders['train'].dataset.t_total, learning_rate=learning_rate) 153 | 
optimizer.zero_grad() 154 | 155 | for epoch in range(pretrain_epochs, max_epochs): 156 | 157 | train_tasks = get_pretrain_task('train', epoch) 158 | print(train_tasks) 159 | 160 | model.train() 161 | for _batch in tqdm(dataloaders['train'], total=len(dataloaders['train']), desc=f"Train e{epoch}"): 162 | batch, label, meta = prepare_batch(_batch) 163 | 164 | stats = model(batch, label, train_tasks) 165 | stats['total_loss'].backward() 166 | 167 | scheduler.accumulated += 1 168 | if scheduler.accumulated >= scheduler.grad_acc_steps: 169 | optimizer.step() 170 | scheduler.step() 171 | scheduler.accumulated = 0 172 | optimizer.zero_grad() 173 | 174 | if it % 20 == 0: 175 | write_logs(logger, it, optimizer.param_groups[0]['lr'], stats, meta, 'train') 176 | it += 1 177 | 178 | # Evaluate 179 | model.eval() 180 | scheduler.accumulated = 0 181 | eval_stats = {} 182 | acc = {'acc_total': [], 'acc_sp': [], 'acc_av': []} 183 | mse = {'mse_total': [], 'mse_sp': [], 'mse_av': []} 184 | qas = [] 185 | for _batch in tqdm(dataloaders['val'], total=len(dataloaders['val']), desc="Valid"): 186 | batch, label, meta = prepare_batch(_batch) 187 | 188 | with torch.no_grad(): 189 | stats = model(batch, label, train_tasks) 190 | for k, v in get_accs(label['qa'].cpu(), stats['answer_pred'], meta['question_type']).items(): 191 | acc[k].extend(v) 192 | if 'ground_pred' in stats.keys(): 193 | for k, v in get_errors(label['ground'].cpu(), stats['ground_pred'], meta['question_type']).items(): 194 | mse[k].extend(v) 195 | qas.append("[{}] {} / (GT) {} (PROP) {} / (GT) {} (PROP) {}".format( 196 | meta['video_id'][0], meta['question'][0], answer_dict[label['qa'][0].item()], answer_dict[stats['answer_pred'][0].item()], 197 | str([f"{i:.3f}" for i in label['ground'][0]]), str([f"{i:.3f}" for i in stats['ground_pred'][0]]) 198 | )) 199 | else: 200 | qas.append("[{}] {} / (GT) {} (PROP) {}".format( 201 | meta['video_id'][0], meta['question'][0], answer_dict[label['qa'][0].item()], answer_dict[stats['answer_pred'][0].item()] 202 | )) 203 | 204 | for k, v in stats.items(): 205 | if type(v) == torch.Tensor and v.dim() == 0: 206 | eval_stats[k] = v.item() if k not in eval_stats.keys() else eval_stats[k] + v.item() 207 | 208 | for k, v in acc.items(): 209 | eval_stats[k] = sum(v) / len(v) 210 | if 'ground_pred' in stats.keys(): 211 | for k, v in mse.items(): 212 | eval_stats[k] = sum(v) / len(v) 213 | 214 | print([f"{k}: {v:.4f}" for k,v in eval_stats.items()]) 215 | eval_stats['example'] = '\n\n'.join(qas) 216 | write_logs(logger, epoch, None, eval_stats, meta, 'eval') 217 | save_ckpt(epoch, stats['total_loss'].item(), model) 218 | 219 | # Test 220 | model.eval() 221 | scheduler.accumulated = 0 222 | eval_stats = {} 223 | acc = {'acc_total': [], 'acc_sp': [], 'acc_av': []} 224 | mse = {'mse_total': [], 'mse_sp': [], 'mse_av': []} 225 | qas = [] 226 | for _batch in tqdm(dataloaders['test'], total=len(dataloaders['test']), desc="Test"): 227 | batch, label, meta = prepare_batch(_batch) 228 | 229 | with torch.no_grad(): 230 | stats = model(batch, label, train_tasks) 231 | for k, v in get_accs(label['qa'].cpu(), stats['answer_pred'], meta['question_type']).items(): 232 | acc[k].extend(v) 233 | if 'ground_pred' in stats.keys(): 234 | for k, v in get_errors(label['ground'].cpu(), stats['ground_pred'], meta['question_type']).items(): 235 | mse[k].extend(v) 236 | qas.extend(["[{}] {} / (GT) {} (PROP) {} / (GT) {} (PROP) {}".format( 237 | meta['video_id'][i], meta['question'][i], answer_dict[label['qa'][i].item()], 
answer_dict[stats['answer_pred'][i].item()], 238 | str([f"{i:.3f}" for i in label['ground'][i]]), str([f"{i:.3f}" for i in stats['ground_pred'][i]]) 239 | ) for i in range(len(_batch))]) 240 | else: 241 | qas.extend(["[{}] {} / (GT) {} (PROP) {}".format( 242 | meta['video_id'][i], meta['question'][i], answer_dict[label['qa'][i].item()], answer_dict[stats['answer_pred'][i].item()] 243 | ) for i in range(len(_batch))]) 244 | 245 | for k, v in stats.items(): 246 | if type(v) == torch.Tensor and v.dim() == 0: 247 | eval_stats[k] = v.item() if k not in eval_stats.keys() else eval_stats[k] + v.item() 248 | 249 | for k, v in acc.items(): 250 | eval_stats[k] = sum(v) / len(v) 251 | if 'ground_pred' in stats.keys(): 252 | for k, v in mse.items(): 253 | eval_stats[k] = sum(v) / len(v) 254 | 255 | print([f"{k}: {v:.4f}" for k,v in eval_stats.items()]) 256 | eval_stats['example'] = '\n\n'.join(qas) 257 | write_logs(logger, 0, None, eval_stats, meta, 'test') 258 | print("GPU Allocation: ", torch.cuda.memory_stats()["allocated_bytes.all.peak"] // 2**20, "MB") 259 | -------------------------------------------------------------------------------- /code/model/lavit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | from transformers.activations import ACT2FN, gelu 5 | 6 | from .decorator import full_model 7 | from .pretrained import BertPreTrainedModel 8 | from .input import * 9 | from .attention import BertAttention, BertCrossAttention, \ 10 | BertIntermediate, BertOutput, BertLayer 11 | from .pooler import LanguageHead, MatchHead, VisualHead, GroundHead, \ 12 | AudioHead, AnswerHead, BertPooler, poolerLoss, DummyHead 13 | 14 | from exp import ex 15 | 16 | ''' 17 | As we follow the naming convention of the official huggingface transformers repo, 18 | there are a few nn modules with a different naming convention. 19 | This dictionary tracks the class naming differences between lxmert and ours.
20 | ''' 21 | pretrain_state_dict_mapper = { 22 | # 'huggingface_convention': 'lxmert_convention' 23 | 'BertSelfAttention': 'BertAttention', 24 | 'BertSelfOutput': 'BertAttOutput', 25 | 'BertAttention': 'BertSelfattLayer', 26 | 'BertCrossAttention': 'BertCrossattLayer', 27 | 'LxmertCrossLayer': 'LXRTXLayer' 28 | } 29 | 30 | 31 | class LavitCrossLayer(nn.Module): 32 | def __init__(self, model_config): 33 | super().__init__() 34 | 35 | self.la_attention = BertCrossAttention(model_config) 36 | self.av_attention = BertCrossAttention(model_config) 37 | self.vl_attention = BertCrossAttention(model_config) 38 | 39 | self.l_self_att = BertAttention(model_config) 40 | self.v_self_att = BertAttention(model_config) 41 | self.a_self_att = BertAttention(model_config) 42 | 43 | self.l_inter = BertIntermediate(model_config) 44 | self.l_output = BertOutput(model_config) 45 | self.v_inter = BertIntermediate(model_config) 46 | self.v_output = BertOutput(model_config) 47 | self.a_inter = BertIntermediate(model_config) 48 | self.a_output = BertOutput(model_config) 49 | 50 | def cross_att(self, l_feat, l_mask, a_feat, a_mask, v_feat, v_mask): 51 | l_out = self.la_attention(l_feat, a_feat, a_mask) 52 | a_out = self.av_attention(a_feat, v_feat, v_mask) 53 | v_out = self.vl_attention(v_feat, l_feat, l_mask) 54 | return l_out, a_out, v_out 55 | 56 | def self_att(self, l_feat, l_mask, a_feat, a_mask, v_feat, v_mask): 57 | l_feat = self.l_self_att(l_feat, l_mask) 58 | a_feat = self.a_self_att(a_feat, a_mask) 59 | v_feat = self.v_self_att(v_feat, v_mask) 60 | 61 | l_out = self.l_output(self.l_inter(l_feat), l_feat) 62 | a_out = self.a_output(self.a_inter(a_feat), a_feat) 63 | v_out = self.v_output(self.v_inter(v_feat), v_feat) 64 | return l_out, a_out, v_out 65 | 66 | def forward(self, l_feat, l_mask, a_feat, a_mask, v_feat, v_mask): 67 | l_out, a_out, v_out = self.cross_att(l_feat, l_mask, a_feat, a_mask, v_feat, v_mask) 68 | l_out, a_out, v_out = self.self_att(l_out, l_mask, a_out, a_mask, v_out, v_mask) 69 | return l_out, a_out, v_out 70 | 71 | 72 | class LavitEncoder(nn.Module): 73 | def __init__(self, model_config): 74 | super().__init__() 75 | 76 | if not model_config.no_coord: 77 | self.v_fc = VisualFeatEncoder(model_config) 78 | else: 79 | self.v_fc = VisualFeatNoCoordEncoder(model_config) 80 | 81 | if model_config.audio_encoder == 'mono': 82 | self.a_fc = AudioMonoEncoder(model_config) 83 | elif model_config.audio_encoder == 'mono_s': 84 | self.a_fc = AudioMonoSEncoder(model_config) 85 | elif model_config.audio_encoder == 'mono_t': 86 | self.a_fc = AudioMonoTEncoder(model_config) 87 | elif model_config.audio_encoder == 'mono_st': 88 | self.a_fc = AudioMonoSTEncoder(model_config) 89 | elif model_config.audio_encoder == 'stereo': 90 | self.a_fc = AudioStereoEncoder(model_config) 91 | elif model_config.audio_encoder == 'stereo_s': 92 | self.a_fc = AudioStereoSEncoder(model_config) 93 | elif model_config.audio_encoder == 'stereo_t': 94 | self.a_fc = AudioStereoTEncoder(model_config) 95 | elif model_config.audio_encoder == 'stereo_st': 96 | self.a_fc = AudioStereoSTEncoder(model_config) 97 | 98 | self.layer = nn.ModuleList( 99 | [BertLayer(model_config) for _ in range(model_config.l_layers)] 100 | ) 101 | self.x_layers = nn.ModuleList( 102 | [LavitCrossLayer(model_config) for _ in range(model_config.x_layers)] 103 | ) 104 | self.v_layers = nn.ModuleList( 105 | [BertLayer(model_config) for _ in range(model_config.v_layers)] 106 | ) 107 | self.a_layers = nn.ModuleList( 108 | [BertLayer(model_config) for _ in 
range(model_config.a_layers)] 109 | ) 110 | 111 | def forward(self, l_feat, l_mask, a_feat, a_mask=None, v_feat=None, v_mask=None): 112 | v_feat = self.v_fc(v_feat) 113 | a_feat = self.a_fc(a_feat) 114 | 115 | for layer_module in self.layer: 116 | l_feat = layer_module(l_feat, l_mask) 117 | for layer_module in self.v_layers: 118 | v_feat = layer_module(v_feat, v_mask) 119 | for layer_module in self.a_layers: 120 | a_feat = layer_module(a_feat, a_mask) 121 | for layer_module in self.x_layers: 122 | l_feat, a_feat, v_feat = layer_module(l_feat, l_mask, a_feat, a_mask, v_feat, v_mask) 123 | 124 | return l_feat, a_feat, v_feat 125 | 126 | 127 | class LavitModel(BertPreTrainedModel): 128 | def __init__(self, model_config, cache_path): 129 | super().__init__(model_config, cache_path) 130 | self.config = model_config 131 | 132 | self.embeddings = BertEmbeddings(model_config) 133 | self.encoder = LavitEncoder(model_config) 134 | self.pooler = BertPooler(model_config) 135 | self.a_pooler = BertPooler(model_config) 136 | self.v_pooler = BertPooler(model_config) 137 | self.apply(self.init_bert_weights) 138 | 139 | def forward(self, input_ids, token_type_ids=None, l_mask=None, a_feat=None, a_mask=None, 140 | v_feat=None, v_mask=None): 141 | if l_mask is None: 142 | l_mask = torch.ones_like(input_ids) 143 | if token_type_ids is None: 144 | token_type_ids = torch.zeros_like(input_ids) 145 | 146 | e_l_mask = l_mask.unsqueeze(1).unsqueeze(2) 147 | e_l_mask = e_l_mask.to(dtype=next(self.parameters()).dtype) 148 | e_l_mask = (1.0 - e_l_mask) * -10000.0 149 | 150 | if v_mask is not None: 151 | e_v_mask = v_mask.unsqueeze(1).unsqueeze(2) 152 | e_v_mask = e_v_mask.to(dtype=next(self.parameters()).dtype) 153 | e_v_mask = (1.0 - e_v_mask) * -10000.0 154 | else: 155 | e_v_mask = None 156 | 157 | if a_mask is not None: 158 | e_a_mask = a_mask.unsqueeze(1).unsqueeze(2) 159 | e_a_mask = e_a_mask.to(dtype=next(self.parameters()).dtype) 160 | e_a_mask = (1.0 - e_a_mask) * -10000.0 161 | else: 162 | e_a_mask = None 163 | 164 | l_feat = self.embeddings(input_ids, token_type_ids) 165 | l_feat, a_feat, v_feat = self.encoder(l_feat, e_l_mask, a_feat, e_a_mask, v_feat, e_v_mask) 166 | l_pool = self.pooler(l_feat) 167 | a_pool = self.a_pooler(a_feat) 168 | v_pool = self.v_pooler(v_feat) 169 | 170 | return (l_feat, a_feat, v_feat), (l_pool, a_pool, v_pool) 171 | 172 | 173 | class LavitPretraining(BertPreTrainedModel): 174 | def __init__(self, model_config, cache_path): 175 | super().__init__(model_config, cache_path) 176 | self.config = model_config 177 | self.num_modality = 1 178 | 179 | self.bert = LavitModel(model_config, cache_path) 180 | 181 | # mask_lm 182 | self.cls = LanguageHead(model_config, bert_weights=self.bert.embeddings.word_embeddings.weight) 183 | 184 | if 'visual' in [x[:6] for x in model_config.pretrain_types]: 185 | self.v_head = VisualHead(model_config) 186 | self.num_modality += 1 187 | else: 188 | self.v_head = DummyHead() 189 | 190 | if 'audio' in [x[:5] for x in model_config.pretrain_types]: 191 | self.a_head = AudioHead(model_config) 192 | self.num_modality += 1 193 | else: 194 | self.a_head = DummyHead() 195 | 196 | if 'vl_match' in model_config.pretrain_types: 197 | self.vl_head = MatchHead(model_config) 198 | else: 199 | self.vl_head = DummyHead() 200 | 201 | if 'al_match' in model_config.pretrain_types: 202 | self.al_head = MatchHead(model_config) 203 | else: 204 | self.al_head = DummyHead() 205 | 206 | self.answer_head = AnswerHead(model_config, num_modality=self.num_modality) 207 | 208 | if 
'ground' in model_config.pretrain_types: 209 | self.ground = GroundHead(model_config, num_modality=self.num_modality) 210 | else: 211 | self.ground = DummyHead() 212 | 213 | self.apply(self.init_bert_weights) 214 | 215 | def forward(self, input_ids, token_type_ids=None, l_mask=None, a_feat=None, a_mask=None, 216 | v_feat=None, v_mask=None): 217 | (l_out, a_out, v_out), (l_head, a_head, v_head) = self.bert( 218 | input_ids, token_type_ids, l_mask, 219 | a_feat=a_feat, a_mask=None, v_feat=v_feat, v_mask=None 220 | ) 221 | 222 | pred = {} 223 | pred['mask_lm'] = self.cls(l_out) 224 | pred.update(self.a_head(a_out)) 225 | pred.update(self.v_head(v_out)) 226 | pred['vl_match'] = self.vl_head([v_head, l_head]) 227 | pred['al_match'] = self.al_head([a_head, l_head]) 228 | pred['qa'] = self.answer_head([l_head, a_head, v_head]) 229 | pred['ground'] = self.ground([l_head, a_head, v_head]) 230 | 231 | return pred 232 | 233 | 234 | @full_model 235 | class Lavit(nn.Module): 236 | def __init__(self, model_config, cache_path): 237 | super().__init__() 238 | self.config = model_config 239 | self.pretrain_loss_config = model_config.pretrain_loss_config 240 | self.audio_encoder = model_config.audio_encoder 241 | self.model = LavitPretraining.from_pretrained(model_config, cache_path) 242 | 243 | def forward(self, batch, label, tasks): 244 | input_ids = batch.get('l_feat') 245 | token_type_ids = None 246 | l_mask = batch.get('l_mask') 247 | v_feat = (batch.get('v_feat'), batch.get('v_coord')) 248 | v_mask = batch.get('v_mask') 249 | a_feat = (batch.get('a_feat'), batch.get('a_coord')) 250 | a_mask = batch.get('a_mask') 251 | 252 | pred = self.model(input_ids, token_type_ids, l_mask, 253 | a_feat=a_feat, a_mask=None, v_feat=v_feat, v_mask=None) 254 | 255 | mask = { 256 | "visual_feat": batch.get('v_mask'), 257 | "visual_label": label.get('visual_score'), 258 | "visual_coord": batch.get('v_mask'), 259 | "audio_feat": batch.get('a_mask'), 260 | "audio_label": label.get("audio_score"), 261 | "audio_coord": batch.get("a_mask"), 262 | "qa": label.get('qa_valid') 263 | } 264 | if mask['visual_label'] is not None: 265 | mask['visual_label'] *= mask['visual_feat'] 266 | if mask["audio_label"] is not None: 267 | mask['audio_label'] *= mask['audio_feat'] 268 | 269 | if 'audio_coord' in label.keys() and self.config.audio_coord_dim != label['audio_coord'].shape[-1]: 270 | if self.config.audio_coord_dim == 1: 271 | label['audio_coord'] = label['audio_coord'][:, :, -1] 272 | elif self.config.audio_coord_dim == 2: 273 | label['audio_coord'] = label['audio_coord'][:, :, :-1] 274 | 275 | total_loss = 0. 
276 | loss = {} 277 | for task in tasks: 278 | output_shape, loss_type, label_shape, weight = self.pretrain_loss_config[task] 279 | loss_func = poolerLoss[loss_type] 280 | 281 | if task in label.keys() and label[task] is not None: 282 | task_loss = loss_func( 283 | pred[task].view(*output_shape), 284 | label[task].view(*label_shape) 285 | ) 286 | 287 | if task_loss.dim() > 1: 288 | task_loss = task_loss.mean(1) 289 | if task_loss.dim() > 1: 290 | task_loss = task_loss.mean(1) 291 | task_loss = (task_loss * mask[task].view(-1)).mean() 292 | elif task_loss.dim() == 1: 293 | task_loss = (task_loss * mask[task].view(-1)).mean() 294 | 295 | task_loss = task_loss * weight 296 | total_loss += task_loss 297 | loss[f'loss_{task}'] = task_loss.detach() 298 | 299 | 300 | answer_pred = np.argmax(pred['qa'].detach().cpu(), 1) if 'qa' in pred.keys() else None 301 | ground_pred = pred['ground'].detach().cpu() if 'ground' in pred.keys() and pred['ground'] is not None else None 302 | 303 | return {'total_loss': total_loss, 304 | 'answer_pred': answer_pred, 305 | 'ground_pred': ground_pred, 306 | **loss} 307 | -------------------------------------------------------------------------------- /code/data/dataset.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | from itertools import chain 4 | 5 | import torch 6 | import numpy as np 7 | import random 8 | from munch import Munch 9 | 10 | from exp import ex 11 | from .load import load 12 | from utils import merge_dict, one_hot_vectorize 13 | #from .utils import pad 14 | 15 | 16 | audio_pad_dict = { 17 | 'orig': 18, 18 | 'pool_2': 9, 19 | 'pool_4': 6, 20 | 'pool_all': 1, 21 | 'top_3': 18, 22 | 'top_5': 18 23 | } 24 | 25 | def get_dataset(modes=[]): 26 | data, video, tokenizer = load(modes) 27 | outputs = {} 28 | for mode in sorted(list(data.keys())): 29 | print(f"[LOG] Loading {mode} split... 
", end='') 30 | mode_feat = {} 31 | mode_ids = set([x['video_id'] for x in data[mode].values()]) 32 | for modality, feature in video.items(): 33 | mode_feat[modality] = {k: v for k, v in feature.items() if k in mode_ids} 34 | print("({} video features)".format(len(mode_ids))) 35 | outputs[mode] = Dataset(data=data[mode], mode=mode, video=mode_feat, tokenizer=tokenizer) 36 | return outputs, video, tokenizer 37 | 38 | 39 | class Dataset(torch.utils.data.Dataset): 40 | @ex.capture() 41 | def __init__(self, data, mode, video, tokenizer, model_config, device, feature_mask_rate, 42 | answer_path, num_objects): 43 | self.data = data 44 | self.video = video 45 | self.ids = list(self.data.keys()) 46 | self.tokenizer = tokenizer 47 | self.device = device 48 | self.mode = mode # pretrain, train, val, test 49 | self.feature_mask_rate = feature_mask_rate 50 | self.model_config = Munch(model_config) 51 | self.pretrain_types = self.model_config.pretrain_types if (self.mode == "pretrain" or self.mode == "preval") else ['qa'] 52 | 53 | self.answer_label = json.load(open(answer_path, 'r')) 54 | 55 | self.use_cls_token = self.model_config.use_cls_token 56 | self.num_objects = num_objects 57 | self.audio_objects = audio_pad_dict[self.model_config.audio_feature] 58 | self.geometry = self.model_config.geometry 59 | 60 | self.feature_names = [] 61 | if self.model_config.visual_feature is not None: 62 | self.feature_names.append('visual') 63 | if self.model_config.audio_feature is not None: 64 | self.feature_names.append('audio') 65 | 66 | def __len__(self): 67 | return len(self.ids) 68 | 69 | def __getitem__(self, idx): 70 | qid = self.ids[idx] 71 | datum = self.data[qid].copy() 72 | id = datum['video_id'] 73 | 74 | # apply masking 75 | # check the number of features with norm > 0 76 | # probabilistically mask and generate masking 77 | if 'visual' in self.feature_names: 78 | video_feat = self.video['visual'][id] 79 | 80 | datum['v_feat'] = video_feat['embedding'] 81 | datum['v_coord'] = self.video['coordinate'][id] 82 | 83 | if 'audio' in self.feature_names: 84 | audio_feat = self.video['audio'][id] 85 | ''' 86 | if self.model_config.use_stereo_audio: 87 | audio_feat_embedding = audio_feat['embedding'][1:, :, :] 88 | else: 89 | audio_feat_embedding = audio_feat['embedding'][0, :, :] 90 | ''' 91 | datum['a_feat'] = audio_feat['embedding'] 92 | datum['a_coord'] = audio_feat['coordinate'] 93 | 94 | return datum 95 | 96 | def prepare_pretrain(self, datum, batch_question_tokens): 97 | 98 | label = {} 99 | metadata = {} 100 | 101 | id = datum['video_id'] 102 | video_feat = datum['v_feat'] 103 | video_coord = datum['v_coord'] 104 | video_class = self.video['visual'][id]['classes'] 105 | video_score = self.video['visual'][id]['score'] 106 | 107 | if 'visual' in self.feature_names: 108 | # Select random visual feature in batch 109 | rand_id = self.data[random.choice(self.ids)]['video_id'] 110 | while rand_id == id: 111 | rand_id = self.data[random.choice(self.ids)]['video_id'] 112 | rand_feat = self.video['visual'][rand_id]['embedding'] 113 | if rand_feat.shape[0] == 0: 114 | rand_feat = np.zeros_like(video_feat) 115 | 116 | if 'vl_match' in self.pretrain_types: 117 | if random.random() > 0.5: 118 | video_feat, rand_feat = rand_feat, video_feat 119 | video_coord = self.video['visual'][rand_id]['coordinate'] 120 | video_class = self.video['visual'][rand_id]['classes'] 121 | video_score = self.video['visual'][rand_id]['score'] 122 | label['vl_match'] = 1 123 | else: 124 | label['vl_match'] = 0 125 | 126 | # CLS_v token 127 
| if self.use_cls_token: 128 | video_cls = np.expand_dims(np.mean(video_feat, axis=0), axis=0) 129 | coord_cls = np.expand_dims(np.mean(video_coord, axis=0), axis=0) 130 | mask_loop_range = (1, min(len(video_feat) + 1, self.num_objects)) 131 | assert video_cls.shape == (1, 2048) 132 | 133 | video_feat = np.concatenate([video_cls, video_feat], axis=0) 134 | video_coord = np.concatenate([coord_cls, video_coord], axis=0) 135 | video_class = np.concatenate([[0], video_class]) 136 | video_score = np.concatenate([[0], video_score]) 137 | else: 138 | mask_loop_range = (0, min(len(video_feat), self.num_objects)) 139 | 140 | # zero-pad 141 | video_feat = video_feat[:self.num_objects] 142 | video_coord = video_coord[:self.num_objects] 143 | video_class = video_class[:self.num_objects] 144 | video_score = video_score[:self.num_objects] 145 | if video_feat.shape[0] < self.num_objects: 146 | surplus = self.num_objects - video_feat.shape[0] 147 | video_feat = np.concatenate([video_feat, np.zeros((surplus, video_feat.shape[1]))], axis=0) 148 | video_coord = np.concatenate([video_coord, np.zeros((surplus, video_coord.shape[1]))], axis=0) 149 | video_class = np.concatenate([video_class, np.zeros(surplus)]) 150 | video_score = np.concatenate([video_score, np.zeros(surplus)]) 151 | 152 | # Apply masking for visual feature 153 | if 'visual' in [x[:6] for x in self.pretrain_types]: 154 | mask_feat = video_feat.copy() 155 | visual_mask = [0. for _ in range(len(mask_feat))] 156 | 157 | for i in range(*mask_loop_range): 158 | prob = random.random() 159 | 160 | if prob < self.feature_mask_rate: 161 | prob /= self.feature_mask_rate 162 | 163 | if prob < 0.8: 164 | mask_feat[i, :] = 0. 165 | 166 | elif prob < 0.9: 167 | rand_idx = random.choice(range(rand_feat.shape[0])) 168 | mask_feat[i, :] = rand_feat[rand_idx, :] 169 | visual_mask[i] = 1. 
170 | 171 | datum['v_feat'] = mask_feat 172 | datum['v_coord'] = video_coord 173 | datum['v_mask'] = visual_mask 174 | 175 | label['visual_feat'] = video_feat 176 | label['visual_coord'] = video_coord 177 | label['visual_label'] = video_class.astype(np.long) 178 | label['visual_score'] = video_score 179 | 180 | else: 181 | datum['v_feat'] = video_feat 182 | datum['v_coord'] = video_coord 183 | 184 | 185 | audio_feat = datum['a_feat'] # embedding 186 | audio_coord = datum['a_coord'] 187 | audio_score = self.video['audio'][id]['score'] 188 | audio_class = self.video['audio'][id]['classes'] 189 | audio_harmo = self.video['audio'][id]['harmonics'] 190 | 191 | if 'audio' in self.feature_names: 192 | rand_id = self.data[random.choice(self.ids)]['video_id'] 193 | while rand_id == id: 194 | rand_id = self.data[random.choice(self.ids)]['video_id'] 195 | 196 | if self.model_config.use_stereo_audio: 197 | audio_feat = audio_feat[1:, :, :] 198 | rand_feat = self.video['audio'][rand_id]['embedding'][1:, :, :] 199 | assert audio_feat.shape[0] == rand_feat.shape[0] 200 | else: 201 | audio_feat = np.expand_dims(audio_feat[0, :, :], axis=0) 202 | rand_feat = np.expand_dims(self.video['audio'][rand_id]['embedding'][0, :, :], axis=0) 203 | 204 | if 'al_match' in self.pretrain_types: 205 | if random.random() > 0.5: 206 | audio_feat, rand_feat = rand_feat, audio_feat 207 | audio_coord = self.video['audio'][rand_id]['coordinate'] 208 | audio_score = self.video['audio'][rand_id]['score'] 209 | audio_class = self.video['audio'][rand_id]['classes'] 210 | audio_harmo = self.video['audio'][rand_id]['harmonics'] 211 | label['al_match'] = 1 212 | else: 213 | label['al_match'] = 0 214 | 215 | # CLS_v token 216 | if self.use_cls_token: 217 | #rint(">>>" , audio_feat.shape) 218 | audio_cls = np.expand_dims(np.mean(audio_feat, axis=1), axis=1) 219 | coord_cls = np.expand_dims(np.mean(audio_coord, axis=0), axis=0) 220 | mask_loop_range = (1, min(len(audio_feat[0]) + 1, self.audio_objects)) 221 | 222 | audio_feat = np.concatenate([audio_cls, audio_feat], axis=1) 223 | audio_coord = np.concatenate([coord_cls, audio_coord], axis=0) 224 | audio_score = np.concatenate([[0], audio_score]) 225 | audio_class = np.concatenate([[0], audio_class]) 226 | audio_harmo = np.concatenate([[0], audio_harmo]) 227 | else: 228 | mask_loop_range = (0, min(len(audio_feat[0]), self.audio_objects)) 229 | 230 | # zero-pad 231 | audio_feat = audio_feat[:self.audio_objects] 232 | audio_coord = audio_coord[:self.audio_objects] 233 | audio_class = audio_class[:self.audio_objects] 234 | audio_score = audio_score[:self.audio_objects] 235 | audio_harmo = audio_harmo[:self.audio_objects] 236 | if audio_feat.shape[1] < self.audio_objects: 237 | surplus = self.audio_objects - audio_feat.shape[1] 238 | audio_feat = np.concatenate([audio_feat, np.zeros((audio_feat.shape[0], surplus, audio_feat.shape[2]))], axis=1) 239 | audio_coord = np.concatenate([audio_coord, np.zeros((surplus, audio_coord.shape[1]))], axis=0) 240 | audio_class = np.concatenate([audio_class, np.zeros(surplus)]) 241 | audio_score = np.concatenate([audio_score, np.zeros(surplus)]) 242 | audio_harmo = np.concatenate([audio_harmo, np.zeros(surplus)]) 243 | 244 | 245 | # Do masking 246 | if 'audio' in [x[:5] for x in self.pretrain_types]: 247 | mask_feat = audio_feat.copy() 248 | audio_mask = [0. 
for _ in range(len(mask_feat[0]))] 249 | 250 | for i in range(*mask_loop_range): 251 | prob = random.random() 252 | 253 | if prob < self.feature_mask_rate: 254 | prob /= self.feature_mask_rate 255 | 256 | if prob < 0.8: 257 | mask_feat[:, i, :] = 0. 258 | 259 | elif prob < 0.9: 260 | rand_idx = random.choice(range(rand_feat.shape[1])) 261 | mask_feat[:, i, :] = rand_feat[:, rand_idx, :] 262 | 263 | audio_mask[i] = 1. 264 | 265 | datum['a_feat'] = mask_feat.squeeze() 266 | datum['a_coord'] = audio_coord 267 | datum['a_mask'] = audio_mask 268 | 269 | label['audio_feat'] = audio_feat 270 | label['audio_label'] = audio_class.astype(np.long) 271 | label['audio_score'] = audio_score 272 | label['audio_coord'] = audio_coord 273 | label['audio_harmonics'] = audio_harmo 274 | else: 275 | datum['a_feat'] = audio_feat.squeeze() 276 | datum['a_coord'] = audio_coord 277 | 278 | 279 | if 'mask_lm' in self.pretrain_types: 280 | question = datum['question'] 281 | mask_feat = question.copy() 282 | ques_mask = [0. for _ in range(len(mask_feat))] 283 | 284 | for i, token in enumerate(question): 285 | prob = random.random() 286 | 287 | if token in (self.tokenizer.cls_token_id, self.tokenizer.sep_token_id): 288 | continue 289 | 290 | if prob < self.feature_mask_rate: 291 | prob /= self.feature_mask_rate 292 | 293 | if prob < 0.8: 294 | mask_feat[i] = self.tokenizer.mask_token_id 295 | elif prob < 0.9: 296 | random_token = random.choice(batch_question_tokens) 297 | 298 | # random token should not be CLS, SEP or MASK 299 | while random_token in (self.tokenizer.cls_token_id, 300 | self.tokenizer.sep_token_id, 301 | self.tokenizer.mask_token_id, 302 | question[i]): 303 | random_token = random.choice(batch_question_tokens) 304 | mask_feat[i] = random_token 305 | 306 | ques_mask[i] = 1. 307 | 308 | datum['l_feat'] = mask_feat 309 | datum['l_mask'] = ques_mask 310 | 311 | label['mask_lm'] = question 312 | else: 313 | datum['l_feat'] = datum['question'] 314 | 315 | if 'qa' in self.pretrain_types: 316 | metadata['answer_str'] = datum['answer'] 317 | # label['qa'] = one_hot_vectorize(datum['answer'], self.answer_label) 318 | if datum['answer'] in self.answer_label: 319 | label['qa'] = self.answer_label.index(datum['answer']) 320 | label['qa_valid'] = 1. 321 | else: 322 | label['qa'] = 0 323 | label['qa_valid'] = 0. 
324 | 325 | datum.pop('answer') 326 | 327 | # Move metadata to label dict 328 | metadata['question'] = ' '.join(self.tokenizer.convert_ids_to_tokens(datum.pop('question'))) 329 | metadata['question_id'] = datum.pop('question_id') 330 | metadata['question_type'] = datum.pop('question_type') 331 | metadata['video_id'] = datum.pop('video_id') 332 | for geo in ['cartesian', 'angular', 'spherical', 'quaternion']: 333 | if geo == self.geometry: 334 | label['ground'] = datum.pop(geo)[1:] 335 | else: 336 | datum.pop(geo) 337 | 338 | #print([f"{x:.3f}" for x in label['ground']], np.array(label['ground']).shape, np.array(datum['v_coord']).shape) 339 | 340 | return datum, label, metadata 341 | 342 | def collate_fn(self, batch): 343 | labels = [] 344 | meta = [] 345 | 346 | max_question_length = 0 347 | 348 | # follow the token distribution of current batch 349 | batch_question_tokens = list(chain(*[datum['question'] for datum in batch])) 350 | 351 | for i, datum in enumerate(batch): 352 | max_question_length = max(max_question_length, len(datum['question'])) 353 | batch[i], label, metadata = self.prepare_pretrain(datum, batch_question_tokens) 354 | labels.append(label) 355 | meta.append(metadata) 356 | 357 | _batch = merge_dict(batch) 358 | _labels = merge_dict(labels) 359 | _meta = merge_dict(meta) 360 | 361 | # Debugging tokens 362 | # print('\n'.join(str(self.tokenizer.convert_ids_to_tokens(x)) for x in _batch['question'])) 363 | for i, question in enumerate(_batch['l_feat']): 364 | surplus = max_question_length - len(question) 365 | 366 | _batch['l_feat'][i] += [self.tokenizer.pad_token_id for _ in range(surplus)] 367 | if "mask_lm" in self.pretrain_types: 368 | _batch['l_mask'][i] += [self.tokenizer.pad_token_id for _ in range(surplus)] 369 | 370 | _labels['mask_lm'][i] += [self.tokenizer.pad_token_id for _ in range(surplus)] 371 | # mlm = _labels['mask_lm'][i] + [self.tokenizer.pad_token_id for _ in range(surplus)] 372 | # _labels['mask_lm'][i] = np.zeros((len(mlm), self.model_config.vocab_size)) 373 | # _labels['mask_lm'][i][np.arange(len(mlm)), mlm] = 1 374 | 375 | for k in _batch.keys(): 376 | _batch[k] = np.array(_batch[k]) 377 | if _batch[k].dtype == np.float64: 378 | _batch[k] = np.array(_batch[k], dtype=np.float32) 379 | 380 | for k in _labels.keys(): 381 | _labels[k] = np.array(_labels[k]) 382 | if _labels[k].dtype == np.float64: 383 | _labels[k] = np.array(_labels[k], dtype=np.float32) 384 | 385 | return _batch, _labels, _meta 386 | --------------------------------------------------------------------------------
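A note on the algebraic comment in VisualFeatConcatEncoder (code/model/input.py): a single Linear over the concatenated [feature : coordinate] vector is the same map as the sum of two separate Linears, (Wx + a) + (Vy + b) = [W : V][x ; y] + (a + b), so the two encoder variants only diverge once the per-branch LayerNorm and the /2 averaging of VisualFeatEncoder enter. The snippet below is an illustrative numerical check of the linear identity only; the dimensions and weights are arbitrary and not taken from the repository.

import torch
import torch.nn as nn

feat_dim, coord_dim, hidden = 8, 5, 4
fc_x = nn.Linear(feat_dim, hidden)               # feature branch, W x + a
fc_y = nn.Linear(coord_dim, hidden)              # coordinate branch, V y + b

# single Linear over the concatenation with weight [W : V] and bias a + b
fc_cat = nn.Linear(feat_dim + coord_dim, hidden)
with torch.no_grad():
    fc_cat.weight.copy_(torch.cat([fc_x.weight, fc_y.weight], dim=1))
    fc_cat.bias.copy_(fc_x.bias + fc_y.bias)

x, y = torch.randn(2, feat_dim), torch.randn(2, coord_dim)
assert torch.allclose(fc_x(x) + fc_y(y),
                      fc_cat(torch.cat([x, y], dim=-1)), atol=1e-5)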
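The masking loops in Dataset.prepare_pretrain (code/data/dataset.py) for visual features, audio features, and question tokens all follow the same BERT-style rule: a position is selected with probability feature_mask_rate, and a selected position is zeroed (or replaced with [MASK] for tokens) 80% of the time, swapped with a random feature or token 10% of the time, and left unchanged otherwise, with its mask flag set to 1 so reconstruction losses are computed only on selected positions. Below is a minimal standalone sketch of that rule; mask_positions and the toy arrays are illustrative and not part of the repository.

import random
import numpy as np

def mask_positions(feat, rand_feat, mask_rate=0.15):
    """80/10/10 masking rule used by Dataset.prepare_pretrain (sketch)."""
    masked = feat.copy()
    flags = np.zeros(len(feat))
    for i in range(len(feat)):
        prob = random.random()
        if prob < mask_rate:          # position selected for masking
            prob /= mask_rate         # re-normalize to [0, 1)
            if prob < 0.8:            # 80%: zero out (tokens use [MASK] instead)
                masked[i] = 0.
            elif prob < 0.9:          # 10%: swap in a feature from another video
                masked[i] = rand_feat[random.choice(range(len(rand_feat)))]
            # remaining 10%: keep the original value unchanged
            flags[i] = 1.             # loss is computed only where the flag is 1
    return masked, flags

feat = np.arange(1., 7.)              # toy 1-D "features"
rand = np.full(6, 99.)                # stand-in for another video's features
masked, flags = mask_positions(feat, rand)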
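In Lavit.forward (code/model/lavit.py), each pretraining task looks up an (output_shape, loss_type, label_shape, weight) tuple in model_config.pretrain_loss_config, applies the matching criterion from poolerLoss, averages any unreduced loss over trailing dimensions, multiplies by the per-position mask, and adds the weighted mean to total_loss. The snippet below is a toy sketch of that reduction for a single unreduced task; the config tuple, shapes, and tensors are illustrative, not the repository's actual configuration.

import torch
import torch.nn as nn

# illustrative entry: task -> (output_shape, loss_type, label_shape, weight)
pretrain_loss_config = {'visual_feat': ((-1, 36, 2048), 'l2', (-1, 36, 2048), 1.0)}
poolerLoss = {'l2': nn.SmoothL1Loss(reduction='none')}

pred = torch.randn(4, 36, 2048)               # predicted region features (B, N, D)
label = torch.randn(4, 36, 2048)              # reconstruction targets
mask = torch.randint(0, 2, (4, 36)).float()   # 1 where a region was actually masked

out_shape, loss_type, lbl_shape, weight = pretrain_loss_config['visual_feat']
task_loss = poolerLoss[loss_type](pred.view(*out_shape), label.view(*lbl_shape))
task_loss = task_loss.mean(-1)                # reduce the feature dimension -> (B, N)
task_loss = (task_loss.view(-1) * mask.view(-1)).mean() * weight
total_loss = task_loss                        # Lavit.forward sums this over all tasks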