├── .gitignore ├── LICENSE ├── README.md ├── annotation └── egocvr │ ├── egocvr_annotations.csv │ └── egocvr_data.csv ├── assets ├── benchmarks.png └── concept.png ├── egocvr_retrieval.py └── model ├── blip ├── __init__.py ├── base.py ├── med.py ├── med_config.json ├── model.py ├── transforms.py └── vit.py ├── egovlpv2 ├── EgoNCE_MLM_ITM_Config.yaml ├── base │ ├── __init__.py │ └── base_model.py ├── heads.py ├── model.py ├── parse_config.py ├── roberta.py ├── util.py ├── video_transformer.py └── video_utils.py ├── models.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | cache_dir/ 165 | annotation/egocvr/egocvr_annotations_gallery.csv 166 | checkpoints/ 167 | data/ 168 | embeddings/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 EML Tübingen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EgoCVR: An Egocentric Benchmark for Fine-Grained Composed Video Retrieval [ECCV 2024] 2 | 3 | __Authors__: Thomas Hummel*, Shyamgopal Karthik*, Mariana-Iuliana Georgescu, Zeynep Akata 4 | 5 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2407.16658) 6 | 7 | 8 | ### Abstract 9 |
10 | 11 | > In Composed Video Retrieval, a video and a textual description which modifies the video content are provided as inputs to the model. The aim is to retrieve the relevant video with the modified content from a database of videos. 12 | In this challenging task, the first step is to acquire large-scale training datasets and collect high-quality benchmarks for evaluation. In this work, we introduce EgoCVR, a new evaluation benchmark for fine-grained Composed Video Retrieval using large-scale egocentric video datasets. EgoCVR consists of 2,295 queries that specifically focus on high-quality temporal video understanding. We find that existing Composed Video Retrieval frameworks may not achieve the necessary high-quality temporal video understanding for this task. 13 | To address this shortcoming, we adapt a simple training-free method, propose a generic re-ranking framework for Composed Video Retrieval, and demonstrate that this achieves strong results on EgoCVR. 14 |
15 | 16 | ![](assets/concept.png "Composed Video Retrieval with EgoCVR") 17 | 18 | 19 | ## The EgoCVR benchmark 20 | We propose EgoCVR, a benchmark with 2,295 queries, to evaluate vision-language models for the task of Composed Video Retrieval. The videos and corresponding annotations were collected from the [Ego4D FHO task](https://ego4d-data.org/docs/tutorials/FHO_Overview/). 21 | 22 | 23 | ### Comparison to existing benchmarks 24 | ![](assets/benchmarks.png "Benchmark Comparison WebVid-CoVR-Test vs. EgoCVR") 25 | Compared to the previously existing WebVid-CoVR-Test benchmark, EgoCVR focuses to a significantly greater extent on temporal and action-related modifications (blue) rather than object-centred modifications (orange). 26 | 27 | ## Dataset structure 28 | 29 | ### Annotations 30 | The annotations for the EgoCVR benchmark are stored in ```annotation/egocvr/egocvr_annotations.csv```. Each row in the CSV file corresponds to a single query. The columns are as follows: 31 | 32 | - ```video_clip_id```: The unique identifier of the query video clip. 33 | - ```target_clip_ids```: The unique identifiers of the target video clips. 34 | - ```video_clip_narration```: The narration of the query video clip. 35 | - ```target_clip_narration```: The narration of the target video clips. 36 | - ```instruction```: The textual video modification instruction for the query. 37 | - ```modified_captions```: Our TF-CVR modified captions for retrieving the target video clips (see Paper). 38 | 39 | ### Evaluation 40 | We consider two possible evaluation settings for EgoCVR: 41 | - ```global```: The standard composed image/video retrieval setting, where the gallery comprises a long list of videos. In the global setting, the query is searched in the pool of videos, which contains all the other video queries along with their video distractors. 42 | - ```local```: The local search restricts the gallery to clips from the same video sequence. This setting simulates searching a long video for a specific moment. 43 | 44 | The gallery information for the EgoCVR benchmark is stored in ```annotation/egocvr/egocvr_annotations_gallery.csv```. In addition to the columns from the query annotations, the gallery annotations contain the following columns: 45 | - ```global_idx```: Indices of the gallery videos for the global evaluation. 46 | - ```local_idx```: Indices of the gallery videos for the local evaluation. 47 | 48 | Please follow the instructions below for downloading the gallery information. A minimal Python sketch for loading these annotation files is shown below, after the evaluation arguments. 49 | 50 | ## Setup 51 | 1. Install the required packages 52 | 2. Download the EgoCVR gallery information 53 | 3. Download either the EgoCVR videos or the pre-computed model embeddings 54 | 4. Download the model weights 55 | 56 | ### 1. Installing Required Packages 57 | > Instructions coming soon. 58 | 59 | ### 2. Downloading the EgoCVR Gallery Information 60 | The gallery information for the EgoCVR benchmark is stored in ```annotation/egocvr/egocvr_annotations_gallery.csv```. The gallery information can be downloaded from the following link: 61 | - EgoCVR Gallery Information: [Download](https://mlcloud.uni-tuebingen.de:7443/eccvdatasets/egocvr/egocvr_annotations_gallery.csv) 62 | 63 | 64 | ### 3. Downloading the EgoCVR Videos or Pre-computed Model Embeddings 65 | 66 | #### EgoCVR Videos 67 | 68 | We provide the video clips from the EgoCVR benchmark for download, both at the original full scale and downscaled so that the short side is 256px.
For all models, we use the full-scale video clips as input, except for EgoVLPv2, for which we follow the model's recommendation to downscale first ([more information](https://github.com/facebookresearch/EgoVLPv2/blob/main/EgoVLPv2/README.md)). 69 | 70 | - Full scale video clips: [Download](https://mlcloud.uni-tuebingen.de:7443/eccvdatasets/egocvr/egocvr_clips.zip) 71 | - Downscaled video clips: [Download](https://mlcloud.uni-tuebingen.de:7443/eccvdatasets/egocvr/egocvr_clips_256.zip) 72 | 73 | After downloading, please extract the zip files and place the video clips in the ```data/``` directory. 74 | ```bash 75 | unzip egocvr_clips.zip -d data/ 76 | unzip egocvr_clips_256.zip -d data/ 77 | ``` 78 | 79 | #### Pre-computed Model Embeddings 80 | We also provide pre-computed model embeddings for the EgoCVR benchmark. 81 | - EgoVLPv2 Embeddings: [Download](https://mlcloud.uni-tuebingen.de:7443/eccvdatasets/egocvr/egocvr_embeddings_egovlpv2.zip) 82 | - LanguageBind Embeddings: [Download](https://mlcloud.uni-tuebingen.de:7443/eccvdatasets/egocvr/egocvr_embeddings_languagebind.zip) 83 | - BLIP Embeddings: [Download](https://mlcloud.uni-tuebingen.de:7443/eccvdatasets/egocvr/egocvr_embeddings_blip.zip) 84 | - CLIP Embeddings: [Download](https://mlcloud.uni-tuebingen.de:7443/eccvdatasets/egocvr/egocvr_embeddings_clip.zip) 85 | 86 | After downloading, please extract the zip file and place the model embeddings in the ```embeddings/``` directory. 87 | ```bash 88 | unzip egocvr_embeddings_egovlpv2.zip -d embeddings/ 89 | ``` 90 | 91 | ### 4. Downloading the Model Weights 92 | 93 | - EgoVLPv2 Model Weights: [Download](http://www.cis.jhu.edu/~shraman/EgoVLPv2/ckpts/Pre-trained/EgoVLPv2.pth) (from the [official repository](https://github.com/facebookresearch/EgoVLPv2/blob/main/EgoVLPv2/README.md)) 94 | - BLIPCoVR Model Weights: [Download](https://huggingface.co/lucas-ventura/CoVR/resolve/main/webvid-covr.ckpt) (from the [official repository](https://github.com/lucas-ventura/CoVR/)) 95 | 96 | The model weights should be placed in the ```checkpoints/``` directory. 97 | 98 | 99 | ## Evaluation 100 | To evaluate different methods on the EgoCVR benchmark, please run one of the following commands: 101 | ```bash 102 | # Evaluation in the global setting 103 | python egocvr_retrieval.py --evaluation global 104 | # Evaluation in the local setting 105 | python egocvr_retrieval.py --evaluation local 106 | ``` 107 | You can specify the model and the modalities to evaluate by using the following arguments: 108 | - ```--model```: The model to evaluate. Possible values are ```egovlpv2```, ```languagebind```, ```blip```, and ```clip```. 109 | - For 2-Stage retrieval, you can use up to two models separated by a space. For example, ```--model languagebind egovlpv2```. 110 | - ```--modalities```: The query modalities to use. Possible values are ```visual```, ```text```, and ```visual-text```. 111 | - For 2-Stage retrieval, you can use up to two modalities separated by a space. For example, ```--modalities visual text```. 112 | - ```--text```: The source of the query text to use. Possible values are: 113 | - ```instruction```: The original instruction from the query annotation. 114 | - ```tfcvr```: The TF-CVR modified captions from the query annotation. These captions are created by captioning the query video and then modifying the caption according to the instruction using an LLM (see Paper). 115 | - ```gt```: The ground-truth narration from the target video clip.
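The list-valued columns in the annotation files (```target_clip_ids```, ```global_idx```, ```local_idx```) are stored as Python-literal strings and are parsed with ```literal_eval``` inside ```egocvr_retrieval.py```. Below is a minimal loading sketch for quick inspection of the annotations; it is not part of the evaluation pipeline and assumes ```pandas``` is installed and the gallery CSV from Step 2 has been downloaded.

```python
from ast import literal_eval

import pandas as pd

# Query annotations: one row per EgoCVR query.
queries = pd.read_csv("annotation/egocvr/egocvr_annotations.csv")
print(f"{len(queries)} queries")

# Gallery annotations: the query columns plus the global/local gallery indices.
gallery = pd.read_csv("annotation/egocvr/egocvr_annotations_gallery.csv")

# List-valued columns are serialized as Python-literal strings; parse them first.
for col in ["target_clip_ids", "global_idx", "local_idx"]:
    gallery[col] = gallery[col].apply(literal_eval)

row = gallery.iloc[0]
print(row["video_clip_id"], "->", row["instruction"])
print(f"{len(row['global_idx'])} global vs. {len(row['local_idx'])} local gallery candidates")
```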
116 | 117 | ### TFR-CVR: Training-Free Re-ranking for Composed Video Retrieval 118 | In the Paper, we propose a simple training-free re-ranking method for Composed Video Retrieval. To evaluate the TFR-CVR and TF-CVR method, please run the following command: 119 | ```bash 120 | # 2-Stage retrieval in the global setting 121 | python egocvr_retrieval.py --evaluation global --model languagebind egovlpv2 --modalities visual text --text tfcvr 122 | 123 | # 1-Stage TF-CVR in the local setting 124 | python egocvr_retrieval.py --evaluation local --model egovlpv2 --modalities text --text tfcvr 125 | ``` 126 | 127 | ### Additional Examples 128 | ```bash 129 | # CLIP 130 | python egocvr_retrieval.py --evaluation global --model clip --modalities visual-text --text instruction 131 | # BLIP 132 | python egocvr_retrieval.py --evaluation global --model blip --modalities visual-text --text instruction 133 | # LanguageBind 134 | python egocvr_retrieval.py --evaluation global --model languagebind --modalities visual-text --text instruction 135 | # BLIP_CoVR 136 | python egocvr_retrieval.py --evaluation global --model blip --modalities visual-text --text instruction --fusion crossattn --finetuned 137 | ``` 138 | 139 | 140 | ### Citation 141 | ```bibtex 142 | @article{hummel2024egocvr, 143 | title={EgoCVR: An Egocentric Benchmark for Fine-Grained Composed Video Retrieval}, 144 | author={Thomas Hummel and Shyamgopal Karthik and Mariana-Iuliana Georgescu and Zeynep Akata}, 145 | journal={European Conference on Computer Vision (ECCV)}, 146 | year={2024} 147 | } 148 | ``` 149 | -------------------------------------------------------------------------------- /assets/benchmarks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExplainableML/EgoCVR/fc08f95aac692b75ab4e2282bb90acd3d8658075/assets/benchmarks.png -------------------------------------------------------------------------------- /assets/concept.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExplainableML/EgoCVR/fc08f95aac692b75ab4e2282bb90acd3d8658075/assets/concept.png -------------------------------------------------------------------------------- /egocvr_retrieval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import sys 4 | from ast import literal_eval 5 | from collections import defaultdict 6 | from functools import partial 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import torch 12 | import torch.nn.functional as F 13 | from tqdm import tqdm 14 | 15 | ROOT_DIR = Path(__file__).parent 16 | sys.path.insert(0, ROOT_DIR.as_posix()) 17 | 18 | from model.models import ( 19 | forward_blip, 20 | forward_blip_text, 21 | forward_clip, 22 | forward_clip_text, 23 | forward_egovlpv2, 24 | forward_egovlpv2_text, 25 | forward_languagebind, 26 | forward_languagebind_text, 27 | init_BLIP, 28 | init_CLIP, 29 | init_EgoVLPv2, 30 | init_languagebind, 31 | ) 32 | 33 | CUDA_DEVICE = "cuda:0" 34 | EMBEDDING_DIR = "./embeddings" 35 | VIDEO_DIR = "./data" 36 | 37 | parser = argparse.ArgumentParser( 38 | "Script to perform Composed Video Retrieval on EgoCVR dataset" 39 | ) 40 | 41 | parser.add_argument( 42 | "--models", 43 | nargs="*", 44 | default=["languagebind", "egovlpv2"], 45 | type=str, 46 | help="Which models to use for retrieval.", 47 | ) 48 | parser.add_argument( 49 | "--modalities", 50 | default=["visual", 
"text"], 51 | nargs="*", 52 | type=str, 53 | help="Query modalities to use for retrieval.", 54 | ) 55 | parser.add_argument( 56 | "--evaluation", 57 | default="global", 58 | choices=[ 59 | "local", 60 | "global", 61 | ], 62 | type=str, 63 | help="Type of evaluation. Local: within the same video, Global: across all videos", 64 | ) 65 | parser.add_argument( 66 | "--finetuned", 67 | action="store_true", 68 | help="Use finetuned CVR model if available (only BLIP).", 69 | ) 70 | parser.add_argument( 71 | "--query_frames", default=15, type=int, help="Number of video query frames." 72 | ) 73 | parser.add_argument( 74 | "--target_frames", default=15, type=int, help="Number of video target frames." 75 | ) 76 | parser.add_argument( 77 | "--text", 78 | default="tfcvr", 79 | choices=["instruction", "tfcvr", "gt"], 80 | type=str, 81 | help="Type of query text to use for retrieval. instruction: instruction text, tfcvr: modified captions, gt: target clip narration", 82 | ) 83 | parser.add_argument( 84 | "--fusion", 85 | default="avg", 86 | choices=["crossattn", "avg"], 87 | type=str, 88 | help="Query fusion strategy when using visual-text modality.", 89 | ) 90 | parser.add_argument( 91 | "--min_gallery_size", default=2, type=int, help="Minimum gallery size. default=2" 92 | ) 93 | parser.add_argument( 94 | "--no_precomputed", action="store_true", help="Do not use precomputed embeddings." 95 | ) 96 | parser.add_argument( 97 | "--neighbors", 98 | default=15, 99 | type=int, 100 | help="Number of neighbors to use for the first stage of 2-stage retrieval.", 101 | ) 102 | 103 | args = parser.parse_args() 104 | 105 | ##################### 106 | ###### CONFIG ####### 107 | ##################### 108 | config = { 109 | "blip": { 110 | "annotations": f"{ROOT_DIR}/annotation/egocvr/egocvr_annotations_gallery.csv", 111 | "embedding_path": f"{EMBEDDING_DIR}/EgoCVR_blip-large.csv", 112 | "ckpt_path_finetuned": "./checkpoints/webvid-covr.ckpt", 113 | "ckpt_path_notfinetuned": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth", 114 | "video_folder": f"{VIDEO_DIR}/egocvr_clips", 115 | }, 116 | "egovlpv2": { 117 | "annotations": f"{ROOT_DIR}/annotation/egocvr/egocvr_annotations_gallery.csv", 118 | "embedding_path": f"{EMBEDDING_DIR}/EgoCVR_EgoVLPv2.csv", 119 | "ckpt_path": "./checkpoints/EgoVLPv2.pth", 120 | "video_folder": f"{VIDEO_DIR}/egocvr_clips_256", 121 | }, 122 | "clip": { 123 | "annotations": f"{ROOT_DIR}/annotation/egocvr/egocvr_annotations_gallery.csv", 124 | "embedding_path": f"{EMBEDDING_DIR}/EgoCVR_ViT-L-14_datacomp_xl_s13b_b90k.csv", 125 | "video_folder": f"{VIDEO_DIR}/egocvr_clips", 126 | }, 127 | "languagebind": { 128 | "annotations": f"{ROOT_DIR}/annotation/egocvr/egocvr_annotations_gallery.csv", 129 | "embedding_path": f"{EMBEDDING_DIR}/EgoCVR_LanguageBind.csv", 130 | "video_folder": f"{VIDEO_DIR}/egocvr_clips", 131 | }, 132 | } 133 | 134 | modalities = args.modalities 135 | assert len(modalities) <= 2, "We implemented only 2 stages" 136 | evaluation = args.evaluation 137 | finetuned = args.finetuned 138 | num_query_frames = args.query_frames 139 | num_target_frames = args.target_frames 140 | fusion = args.fusion 141 | text_variant = args.text 142 | min_gallery_size = args.min_gallery_size 143 | no_precomputed = args.no_precomputed 144 | num_neighbors = args.neighbors 145 | 146 | # Recalls 147 | recalls = [1, 5, 10] if not evaluation == "local" else [1, 2, 3] 148 | 149 | if "blip" in args.models: 150 | config["blip"]["ckpt_path"] = ( 151 | 
config["blip"]["ckpt_path_finetuned"] 152 | if finetuned 153 | else config["blip"]["ckpt_path_notfinetuned"] 154 | ) 155 | 156 | query_frame_method = "middle" if num_query_frames == 1 else "sample" 157 | if text_variant == "tfcvr": 158 | text_variant = "modified_captions" 159 | elif text_variant == "gt": 160 | text_variant = "target_clip_narration" 161 | else: 162 | text_variant = "instruction" 163 | 164 | for _, config_ in config.items(): 165 | config_["embedding_path_raw"] = ( 166 | config_["embedding_path"].replace(".csv", ".pt") 167 | if Path(config_["embedding_path"].replace(".csv", ".pt")).exists() 168 | else None 169 | ) 170 | 171 | assert len(args.models) == len(args.modalities) 172 | 173 | 174 | def seed_everything(seed=42): 175 | # Set Python seed 176 | random.seed(seed) 177 | 178 | # Set NumPy seed 179 | np.random.seed(seed) 180 | 181 | # Set PyTorch seed for CPU 182 | torch.manual_seed(seed) 183 | 184 | # Set PyTorch seed for GPU, if available 185 | if torch.cuda.is_available(): 186 | torch.cuda.manual_seed(seed) 187 | torch.cuda.manual_seed_all(seed) 188 | 189 | 190 | def load_embeddings(path, emb_path=None): 191 | df = pd.read_csv(path) 192 | if emb_path: 193 | embeddings = torch.load(emb_path) 194 | embeddings = embeddings.to(CUDA_DEVICE) 195 | else: 196 | embeddings = df["clip_embeddings"].apply( 197 | lambda emb: np.array(literal_eval(emb)) 198 | ) 199 | embeddings = np.stack(embeddings) 200 | embeddings = torch.tensor(embeddings, device=CUDA_DEVICE, dtype=torch.float32) 201 | return df, embeddings 202 | 203 | 204 | def dump_embeddings(path, emb_path=None): 205 | # dump embeddings to .pt file if not already present 206 | if not emb_path: 207 | dump_path = path.replace(".csv", ".pt") 208 | print( 209 | f"Dumping embeddings to {dump_path}. Improves loading time for future runs." 
210 | ) 211 | df = pd.read_csv(path) 212 | embeddings = df["clip_embeddings"].apply( 213 | lambda emb: np.array(literal_eval(emb)) 214 | ) 215 | embeddings = np.stack(embeddings) 216 | embeddings = torch.tensor(embeddings, device=CUDA_DEVICE, dtype=torch.float32) 217 | torch.save(embeddings, dump_path) 218 | return dump_path 219 | return emb_path 220 | 221 | 222 | def nearest_neighbors( 223 | candidate_embeddings, 224 | query, 225 | k, 226 | normalize=True, 227 | return_distances=True, 228 | ): 229 | if query.ndim == 1: 230 | query = query.unsqueeze(0) 231 | if normalize: 232 | candidate_embeddings = F.normalize(candidate_embeddings, dim=-1) 233 | query = F.normalize(query, dim=-1) 234 | 235 | similarities = torch.matmul(query, candidate_embeddings.T) 236 | topk_values, topk_indices = torch.topk(similarities, k, largest=True) 237 | 238 | if return_distances: 239 | return topk_values, topk_indices 240 | else: 241 | return topk_indices 242 | 243 | 244 | def compute_recall_at_k( 245 | query_embeddings, 246 | candidate_embeddings, 247 | ground_truth, 248 | k, 249 | gallery, 250 | min_gallery_size, 251 | modalities, 252 | num_neighbors=None, 253 | ): 254 | 255 | _, indices = nearest_neighbors( 256 | candidate_embeddings[args.models[0]], 257 | torch.stack(query_embeddings[args.models[0]][modalities[0]]), 258 | (len(candidate_embeddings[args.models[0]])), 259 | ) 260 | 261 | total_relevant = 0 262 | total_retrieved_relevant = 0 263 | 264 | num_queries = len(ground_truth) 265 | for i in range(num_queries): 266 | 267 | relevant_items = set(ground_truth[i]) 268 | filtered_indices = torch.tensor(list(set(gallery[i])), device=CUDA_DEVICE) 269 | 270 | filter_mask = torch.isin(indices[i], filtered_indices) 271 | filtered_indices = indices[i][filter_mask] 272 | 273 | if len(filtered_indices) < min_gallery_size: 274 | # skip this query 275 | continue 276 | 277 | if len(modalities) > 1: 278 | new_query = query_embeddings[args.models[1]][modalities[1]][i] 279 | if new_query.ndim < 2: 280 | new_query = new_query.unsqueeze(0) 281 | 282 | new_candidates_indices = filtered_indices[:num_neighbors] 283 | new_candidates = candidate_embeddings[args.models[1]][ 284 | new_candidates_indices 285 | ] 286 | _, new_indices = nearest_neighbors( 287 | new_candidates, 288 | new_query, 289 | k=(len(new_candidates)), 290 | ) 291 | 292 | filtered_indices = new_candidates_indices.cpu().numpy()[ 293 | new_indices.cpu().numpy()[0] 294 | ] 295 | 296 | retrieved_items = set(filtered_indices[:k].tolist()) 297 | relevant_retrieved = relevant_items.intersection(retrieved_items) 298 | 299 | total_relevant += 1 300 | total_retrieved_relevant += min(len(relevant_retrieved), 1) 301 | 302 | recall_at_k = total_retrieved_relevant / total_relevant if total_relevant > 0 else 0 303 | return recall_at_k 304 | 305 | 306 | def main(): 307 | print( 308 | f"Running {args.models} retrieval with {modalities} using {evaluation} evaluation." 
309 | ) 310 | seed_everything(123) 311 | 312 | tqdm.pandas() 313 | models = {} 314 | frame_loaders = {} 315 | tokenizers = {} 316 | model_forwards = {} 317 | text_forwards = {} 318 | if "blip" in args.models: 319 | model_blip, frame_loader_blip, tokenizer_blip = init_BLIP( 320 | checkpoint_path=config["blip"]["ckpt_path"], 321 | query_frame_method=query_frame_method, 322 | num_query_frames=num_query_frames, 323 | device=CUDA_DEVICE, 324 | ) 325 | models["blip"] = model_blip 326 | frame_loaders["blip"] = frame_loader_blip 327 | tokenizers["blip"] = tokenizer_blip 328 | model_forwards["blip"] = forward_blip 329 | text_forwards["blip"] = forward_blip_text 330 | 331 | if "egovlpv2" in args.models: 332 | model_egovlpv2, frame_loader_egovlpv2, tokenizer_egovlpv2 = init_EgoVLPv2( 333 | checkpoint_path=config["egovlpv2"]["ckpt_path"], device=CUDA_DEVICE 334 | ) 335 | models["egovlpv2"] = model_egovlpv2 336 | frame_loaders["egovlpv2"] = frame_loader_egovlpv2 337 | tokenizers["egovlpv2"] = tokenizer_egovlpv2 338 | model_forwards["egovlpv2"] = forward_egovlpv2 339 | text_forwards["egovlpv2"] = forward_egovlpv2_text 340 | 341 | if "clip" in args.models: 342 | model_clip, frame_loader_clip, tokenizer_clip = init_CLIP( 343 | query_frame_method=query_frame_method, 344 | num_query_frames=num_query_frames, 345 | device=CUDA_DEVICE, 346 | ) 347 | models["clip"] = model_clip 348 | frame_loaders["clip"] = frame_loader_clip 349 | tokenizers["clip"] = tokenizer_clip 350 | model_forwards["clip"] = forward_clip 351 | text_forwards["clip"] = partial(forward_clip_text, tokenizer=tokenizer_clip) 352 | 353 | if "languagebind" in args.models: 354 | model_languagebind, frame_loader_languagebind, tokenizer_languagebind = ( 355 | init_languagebind(device=CUDA_DEVICE) 356 | ) 357 | models["languagebind"] = model_languagebind 358 | frame_loaders["languagebind"] = frame_loader_languagebind 359 | tokenizers["languagebind"] = tokenizer_languagebind 360 | model_forwards["languagebind"] = forward_languagebind 361 | text_forwards["languagebind"] = forward_languagebind_text 362 | 363 | df_dict = {} 364 | model_embeddings_dict = {} 365 | for model in set(args.models): 366 | dump_embeddings( 367 | config[model]["embedding_path"], config[model]["embedding_path_raw"] 368 | ) 369 | df_dict[model], model_embeddings_dict[model] = load_embeddings( 370 | config[model]["embedding_path"], config[model]["embedding_path_raw"] 371 | ) 372 | 373 | # Fix for LanguageBind embeddings due to possible unnecessary extra dimension 374 | if "languagebind" in args.models: 375 | model_embeddings_dict["languagebind"] = model_embeddings_dict[ 376 | "languagebind" 377 | ].squeeze(1) 378 | 379 | annotation_df = pd.read_csv(config[args.models[0]]["annotations"]) 380 | 381 | all_targets = annotation_df["target_clip_ids"].apply(literal_eval) 382 | 383 | gallery = ( 384 | annotation_df[f"{args.evaluation}_idx"].progress_apply(literal_eval).tolist() 385 | ) 386 | 387 | query_embeddings = {} 388 | for model in set(args.models): 389 | query_embeddings[model] = defaultdict(list) 390 | candidate_embeddings = [] 391 | ground_truth = [] 392 | 393 | index_mapping = {} 394 | for model in set(args.models): 395 | for i in range(len(model_embeddings_dict[model])): 396 | clip_id = df_dict[model].iloc[i]["clip_name"] 397 | index_mapping[clip_id] = i 398 | 399 | print(f"Generating {args.models} {modalities} embeddings") 400 | for i in tqdm(range(len(annotation_df))): 401 | 402 | modifier_text = annotation_df.iloc[i][text_variant] 403 | 404 | video_uid = 
annotation_df.iloc[i]["video_clip_id"].split("_")[0] 405 | clip_name = annotation_df.iloc[i]["video_clip_id"] 406 | 407 | with torch.no_grad(): 408 | for modality, model in zip(modalities, args.models): 409 | video_path = ( 410 | Path(config[model]["video_folder"]) / video_uid / f"{clip_name}.mp4" 411 | ) 412 | query_video = model_embeddings_dict[model][index_mapping[clip_name]] 413 | query_caption = modifier_text 414 | 415 | query_embedding = model_forwards[model]( 416 | modality, 417 | models[model], 418 | tokenizers[model], 419 | query_video, 420 | query_caption, 421 | video_path, 422 | frame_loaders[model], 423 | fusion, 424 | num_query_frames, 425 | query_frame_method, 426 | use_precomputed=(not no_precomputed), 427 | ) 428 | 429 | query_embeddings[model][modality].append(query_embedding) 430 | 431 | all_gts = [] 432 | for entry in all_targets[i]: 433 | all_gts.append(index_mapping[entry]) 434 | ground_truth.append(all_gts) 435 | 436 | candidate_embeddings = model_embeddings_dict 437 | if num_target_frames == 1: 438 | # use only the middle frame for target clips 439 | for model in set(args.models): 440 | if "languagebind" not in model and "egovlpv2" not in model: 441 | temporal_mid = candidate_embeddings[model].shape[1] // 2 442 | candidate_embeddings[model] = candidate_embeddings[model][ 443 | :, temporal_mid, : 444 | ] 445 | 446 | for model in set(args.models): 447 | if candidate_embeddings[model].ndim > 2: 448 | candidate_embeddings[model] = candidate_embeddings[model].mean(1) 449 | 450 | recall_results = [] 451 | for k in recalls: 452 | recall = compute_recall_at_k( 453 | query_embeddings, 454 | candidate_embeddings, 455 | ground_truth, 456 | k, 457 | gallery, 458 | min_gallery_size, 459 | modalities, 460 | num_neighbors, 461 | ) 462 | recall_results.append(recall) 463 | print( 464 | f"Recall@{','.join([str(r) for r in recalls])}: {' & '.join([str('{0:.3f}'.format(res)) for res in recall_results])} \\\\" 465 | ) 466 | 467 | 468 | if __name__ == "__main__": 469 | main() 470 | -------------------------------------------------------------------------------- /model/blip/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExplainableML/EgoCVR/fc08f95aac692b75ab4e2282bb90acd3d8658075/model/blip/__init__.py -------------------------------------------------------------------------------- /model/blip/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | """ 8 | import warnings 9 | 10 | warnings.filterwarnings("ignore") 11 | 12 | import os 13 | from urllib.parse import urlparse 14 | 15 | import torch 16 | from hydra.utils import get_original_cwd 17 | from timm.models.hub import download_cached_file 18 | from torch import nn 19 | from transformers import BertTokenizer 20 | 21 | from model.blip.med import BertConfig, BertLMHeadModel, BertModel 22 | from model.blip.vit import VisionTransformer, interpolate_pos_embed 23 | 24 | 25 | class BLIP_Base(nn.Module): 26 | def __init__( 27 | self, 28 | med_config="configs/med_config.json", 29 | image_size=224, 30 | vit="base", 31 | vit_grad_ckpt=False, 32 | vit_ckpt_layer=0, 33 | ): 34 | """ 35 | Args: 36 | med_config (str): path for the mixture of encoder-decoder model's configuration file 37 | image_size (int): input image size 38 | vit (str): model size of vision transformer 39 | """ 40 | super().__init__() 41 | 42 | self.visual_encoder, vision_width = create_vit( 43 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer 44 | ) 45 | self.tokenizer = init_tokenizer() 46 | med_config = BertConfig.from_json_file(med_config) 47 | med_config.encoder_width = vision_width 48 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 49 | 50 | def forward(self, image, caption, mode): 51 | assert mode in [ 52 | "image", 53 | "text", 54 | "multimodal", 55 | ], "mode parameter must be image, text, or multimodal" 56 | text = self.tokenizer(caption, return_tensors="pt").to(image.device) 57 | 58 | if mode == "image": 59 | # return image features 60 | image_embeds = self.visual_encoder(image) 61 | return image_embeds 62 | 63 | elif mode == "text": 64 | # return text features 65 | text_output = self.text_encoder( 66 | text.input_ids, 67 | attention_mask=text.attention_mask, 68 | return_dict=True, 69 | mode="text", 70 | ) 71 | return text_output.last_hidden_state 72 | 73 | elif mode == "multimodal": 74 | # return multimodel features 75 | image_embeds = self.visual_encoder(image) 76 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( 77 | image.device 78 | ) 79 | 80 | text.input_ids[:, 0] = self.tokenizer.enc_token_id 81 | output = self.text_encoder( 82 | text.input_ids, 83 | attention_mask=text.attention_mask, 84 | encoder_hidden_states=image_embeds, 85 | encoder_attention_mask=image_atts, 86 | return_dict=True, 87 | ) 88 | return output.last_hidden_state 89 | 90 | 91 | class BLIP_Decoder(nn.Module): 92 | def __init__( 93 | self, 94 | med_config="configs/med_config.json", 95 | image_size=384, 96 | vit="base", 97 | vit_grad_ckpt=False, 98 | vit_ckpt_layer=0, 99 | prompt="a picture of ", 100 | ): 101 | """ 102 | Args: 103 | med_config (str): path for the mixture of encoder-decoder model's configuration file 104 | image_size (int): input image size 105 | vit (str): model size of vision transformer 106 | """ 107 | super().__init__() 108 | 109 | self.visual_encoder, vision_width = create_vit( 110 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer 111 | ) 112 | self.tokenizer = init_tokenizer() 113 | med_config = BertConfig.from_json_file(med_config) 114 | med_config.encoder_width = vision_width 115 | self.text_decoder = BertLMHeadModel(config=med_config) 116 | 117 | self.prompt = prompt 118 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 119 | 120 | def forward(self, image, caption): 121 | 
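# Training forward pass: encode the image, tokenize the caption, and return the captioning language-modeling loss.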
image_embeds = self.visual_encoder(image) 122 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( 123 | image.device 124 | ) 125 | 126 | text = self.tokenizer( 127 | caption, 128 | padding="longest", 129 | truncation=True, 130 | max_length=40, 131 | return_tensors="pt", 132 | ).to(image.device) 133 | 134 | text.input_ids[:, 0] = self.tokenizer.bos_token_id 135 | 136 | decoder_targets = text.input_ids.masked_fill( 137 | text.input_ids == self.tokenizer.pad_token_id, -100 138 | ) 139 | decoder_targets[:, : self.prompt_length] = -100 140 | 141 | decoder_output = self.text_decoder( 142 | text.input_ids, 143 | attention_mask=text.attention_mask, 144 | encoder_hidden_states=image_embeds, 145 | encoder_attention_mask=image_atts, 146 | labels=decoder_targets, 147 | return_dict=True, 148 | ) 149 | loss_lm = decoder_output.loss 150 | 151 | return loss_lm 152 | 153 | def generate( 154 | self, 155 | image, 156 | sample=False, 157 | num_beams=3, 158 | max_length=30, 159 | min_length=10, 160 | top_p=0.9, 161 | repetition_penalty=1.0, 162 | ): 163 | image_embeds = self.visual_encoder(image) 164 | 165 | if not sample: 166 | image_embeds = image_embeds.repeat_interleave(num_beams, dim=0) 167 | 168 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( 169 | image.device 170 | ) 171 | model_kwargs = { 172 | "encoder_hidden_states": image_embeds, 173 | "encoder_attention_mask": image_atts, 174 | } 175 | 176 | prompt = [self.prompt] * image.size(0) 177 | input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to( 178 | image.device 179 | ) 180 | input_ids[:, 0] = self.tokenizer.bos_token_id 181 | input_ids = input_ids[:, :-1] 182 | 183 | if sample: 184 | # nucleus sampling 185 | outputs = self.text_decoder.generate( 186 | input_ids=input_ids, 187 | max_length=max_length, 188 | min_length=min_length, 189 | do_sample=True, 190 | top_p=top_p, 191 | num_return_sequences=1, 192 | eos_token_id=self.tokenizer.sep_token_id, 193 | pad_token_id=self.tokenizer.pad_token_id, 194 | repetition_penalty=1.1, 195 | **model_kwargs, 196 | ) 197 | else: 198 | # beam search 199 | outputs = self.text_decoder.generate( 200 | input_ids=input_ids, 201 | max_length=max_length, 202 | min_length=min_length, 203 | num_beams=num_beams, 204 | eos_token_id=self.tokenizer.sep_token_id, 205 | pad_token_id=self.tokenizer.pad_token_id, 206 | repetition_penalty=repetition_penalty, 207 | **model_kwargs, 208 | ) 209 | 210 | captions = [] 211 | for output in outputs: 212 | caption = self.tokenizer.decode(output, skip_special_tokens=True) 213 | captions.append(caption[len(self.prompt) :]) 214 | return captions 215 | 216 | 217 | def blip_decoder(pretrained="", **kwargs): 218 | model = BLIP_Decoder(**kwargs) 219 | if pretrained: 220 | model, msg = load_checkpoint(model, pretrained) 221 | assert len(msg.missing_keys) == 0 222 | return model 223 | 224 | 225 | def blip_feature_extractor(pretrained="", **kwargs): 226 | model = BLIP_Base(**kwargs) 227 | if pretrained: 228 | model, msg = load_checkpoint(model, pretrained) 229 | assert len(msg.missing_keys) == 0 230 | return model 231 | 232 | 233 | def init_tokenizer(): 234 | try: 235 | bert_pth = os.path.join(get_original_cwd(), "bert-base-uncased") 236 | tokenizer = BertTokenizer.from_pretrained(bert_pth) 237 | except: 238 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 239 | 240 | tokenizer.add_special_tokens({"bos_token": "[DEC]"}) 241 | tokenizer.add_special_tokens({"additional_special_tokens": ["[ENC]"]}) 242 | tokenizer.enc_token_id = 
tokenizer.additional_special_tokens_ids[0] 243 | return tokenizer 244 | 245 | 246 | def create_vit( 247 | vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0 248 | ): 249 | assert vit in ["base", "large"], "vit parameter must be base or large" 250 | if vit == "base": 251 | vision_width = 768 252 | visual_encoder = VisionTransformer( 253 | img_size=image_size, 254 | patch_size=16, 255 | embed_dim=vision_width, 256 | depth=12, 257 | num_heads=12, 258 | use_grad_checkpointing=use_grad_checkpointing, 259 | ckpt_layer=ckpt_layer, 260 | drop_path_rate=0 or drop_path_rate, 261 | ) 262 | elif vit == "large": 263 | vision_width = 1024 264 | visual_encoder = VisionTransformer( 265 | img_size=image_size, 266 | patch_size=16, 267 | embed_dim=vision_width, 268 | depth=24, 269 | num_heads=16, 270 | use_grad_checkpointing=use_grad_checkpointing, 271 | ckpt_layer=ckpt_layer, 272 | drop_path_rate=0.1 or drop_path_rate, 273 | ) 274 | else: 275 | raise NotImplementedError 276 | return visual_encoder, vision_width 277 | 278 | 279 | def is_url(url_or_filename): 280 | parsed = urlparse(url_or_filename) 281 | return parsed.scheme in ("http", "https") 282 | 283 | 284 | def load_checkpoint(model, url_or_filename): 285 | if is_url(url_or_filename): 286 | cached_file = download_cached_file( 287 | url_or_filename, check_hash=False, progress=True 288 | ) 289 | checkpoint = torch.load(cached_file, map_location="cpu") 290 | elif os.path.isfile(url_or_filename): 291 | checkpoint = torch.load(url_or_filename, map_location="cpu") 292 | else: 293 | raise RuntimeError(f"checkpoint {url_or_filename} is invalid") 294 | 295 | state_dict = checkpoint["model"] 296 | state_dict = remove_module(state_dict) 297 | 298 | state_dict["visual_encoder.pos_embed"] = interpolate_pos_embed( 299 | state_dict["visual_encoder.pos_embed"], model.visual_encoder 300 | ) 301 | if "visual_encoder_m.pos_embed" in model.state_dict().keys(): 302 | state_dict["visual_encoder_m.pos_embed"] = interpolate_pos_embed( 303 | state_dict["visual_encoder_m.pos_embed"], model.visual_encoder_m 304 | ) 305 | for key in model.state_dict().keys(): 306 | if key in state_dict.keys(): 307 | if state_dict[key].shape != model.state_dict()[key].shape: 308 | del state_dict[key] 309 | 310 | msg = model.load_state_dict(state_dict, strict=False) 311 | print("load checkpoint from %s" % url_or_filename) 312 | return model, msg 313 | 314 | 315 | def remove_module(state_dict): 316 | new_state_dict = {} 317 | for key in state_dict.keys(): 318 | if key.startswith("module."): 319 | new_state_dict[key[7:]] = state_dict[key] 320 | else: 321 | new_state_dict[key] = state_dict[key] 322 | return new_state_dict 323 | -------------------------------------------------------------------------------- /model/blip/med.py: -------------------------------------------------------------------------------- 1 | """ 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | * Based on huggingface code base 8 | * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert 9 | """ 10 | 11 | import math 12 | from typing import Tuple 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | import torch.utils.checkpoint 17 | from torch import Tensor, device, dtype, nn 18 | from torch.nn import CrossEntropyLoss 19 | from transformers.activations import ACT2FN 20 | from transformers.file_utils import ModelOutput 21 | from transformers.modeling_outputs import ( 22 | BaseModelOutputWithPastAndCrossAttentions, 23 | BaseModelOutputWithPoolingAndCrossAttentions, 24 | CausalLMOutputWithCrossAttentions, 25 | ) 26 | from transformers.modeling_utils import ( 27 | PreTrainedModel, 28 | apply_chunking_to_forward, 29 | find_pruneable_heads_and_indices, 30 | prune_linear_layer, 31 | ) 32 | from transformers.models.bert.configuration_bert import BertConfig 33 | from transformers.utils import logging 34 | 35 | logger = logging.get_logger(__name__) 36 | 37 | 38 | class BertEmbeddings(nn.Module): 39 | """Construct the embeddings from word and position embeddings.""" 40 | 41 | def __init__(self, config): 42 | super().__init__() 43 | self.word_embeddings = nn.Embedding( 44 | config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id 45 | ) 46 | self.position_embeddings = nn.Embedding( 47 | config.max_position_embeddings, config.hidden_size 48 | ) 49 | 50 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 51 | # any TensorFlow checkpoint file 52 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 53 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 54 | 55 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized 56 | self.register_buffer( 57 | "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) 58 | ) 59 | self.position_embedding_type = getattr( 60 | config, "position_embedding_type", "absolute" 61 | ) 62 | 63 | self.config = config 64 | 65 | def forward( 66 | self, 67 | input_ids=None, 68 | position_ids=None, 69 | inputs_embeds=None, 70 | past_key_values_length=0, 71 | ): 72 | if input_ids is not None: 73 | input_shape = input_ids.size() 74 | else: 75 | input_shape = inputs_embeds.size()[:-1] 76 | 77 | seq_length = input_shape[1] 78 | 79 | if position_ids is None: 80 | position_ids = self.position_ids[ 81 | :, past_key_values_length : seq_length + past_key_values_length 82 | ] 83 | 84 | if inputs_embeds is None: 85 | inputs_embeds = self.word_embeddings(input_ids) 86 | 87 | embeddings = inputs_embeds 88 | 89 | if self.position_embedding_type == "absolute": 90 | position_embeddings = self.position_embeddings(position_ids) 91 | embeddings += position_embeddings 92 | embeddings = self.LayerNorm(embeddings) 93 | embeddings = self.dropout(embeddings) 94 | return embeddings 95 | 96 | 97 | class BertSelfAttention(nn.Module): 98 | def __init__(self, config, is_cross_attention): 99 | super().__init__() 100 | self.config = config 101 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr( 102 | config, "embedding_size" 103 | ): 104 | raise ValueError( 105 | "The hidden size (%d) is not a multiple of the number of attention " 106 | "heads (%d)" % (config.hidden_size, config.num_attention_heads) 107 | ) 108 | 
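# Split the hidden dimension across attention heads: all_head_size = num_attention_heads * attention_head_size.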
109 | self.num_attention_heads = config.num_attention_heads 110 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 111 | self.all_head_size = self.num_attention_heads * self.attention_head_size 112 | 113 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 114 | if is_cross_attention: 115 | self.key = nn.Linear(config.encoder_width, self.all_head_size) 116 | self.value = nn.Linear(config.encoder_width, self.all_head_size) 117 | else: 118 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 119 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 120 | 121 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 122 | self.position_embedding_type = getattr( 123 | config, "position_embedding_type", "absolute" 124 | ) 125 | if ( 126 | self.position_embedding_type == "relative_key" 127 | or self.position_embedding_type == "relative_key_query" 128 | ): 129 | self.max_position_embeddings = config.max_position_embeddings 130 | self.distance_embedding = nn.Embedding( 131 | 2 * config.max_position_embeddings - 1, self.attention_head_size 132 | ) 133 | self.save_attention = False 134 | 135 | def save_attn_gradients(self, attn_gradients): 136 | self.attn_gradients = attn_gradients 137 | 138 | def get_attn_gradients(self): 139 | return self.attn_gradients 140 | 141 | def save_attention_map(self, attention_map): 142 | self.attention_map = attention_map 143 | 144 | def get_attention_map(self): 145 | return self.attention_map 146 | 147 | def transpose_for_scores(self, x): 148 | new_x_shape = x.size()[:-1] + ( 149 | self.num_attention_heads, 150 | self.attention_head_size, 151 | ) 152 | x = x.view(*new_x_shape) 153 | return x.permute(0, 2, 1, 3) 154 | 155 | def forward( 156 | self, 157 | hidden_states, 158 | attention_mask=None, 159 | head_mask=None, 160 | encoder_hidden_states=None, 161 | encoder_attention_mask=None, 162 | past_key_value=None, 163 | output_attentions=False, 164 | ): 165 | mixed_query_layer = self.query(hidden_states) 166 | 167 | # If this is instantiated as a cross-attention module, the keys 168 | # and values come from an encoder; the attention mask needs to be 169 | # such that the encoder's padding tokens are not attended to. 170 | is_cross_attention = encoder_hidden_states is not None 171 | 172 | if is_cross_attention: 173 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) 174 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) 175 | attention_mask = encoder_attention_mask 176 | elif past_key_value is not None: 177 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 178 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 179 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2) 180 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2) 181 | else: 182 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 183 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 184 | 185 | query_layer = self.transpose_for_scores(mixed_query_layer) 186 | 187 | past_key_value = (key_layer, value_layer) 188 | 189 | # Take the dot product between "query" and "key" to get the raw attention scores. 
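# attention_scores has shape (batch_size, num_heads, query_len, key_len).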
190 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 191 | 192 | if ( 193 | self.position_embedding_type == "relative_key" 194 | or self.position_embedding_type == "relative_key_query" 195 | ): 196 | seq_length = hidden_states.size()[1] 197 | position_ids_l = torch.arange( 198 | seq_length, dtype=torch.long, device=hidden_states.device 199 | ).view(-1, 1) 200 | position_ids_r = torch.arange( 201 | seq_length, dtype=torch.long, device=hidden_states.device 202 | ).view(1, -1) 203 | distance = position_ids_l - position_ids_r 204 | positional_embedding = self.distance_embedding( 205 | distance + self.max_position_embeddings - 1 206 | ) 207 | positional_embedding = positional_embedding.to( 208 | dtype=query_layer.dtype 209 | ) # fp16 compatibility 210 | 211 | if self.position_embedding_type == "relative_key": 212 | relative_position_scores = torch.einsum( 213 | "bhld,lrd->bhlr", query_layer, positional_embedding 214 | ) 215 | attention_scores = attention_scores + relative_position_scores 216 | elif self.position_embedding_type == "relative_key_query": 217 | relative_position_scores_query = torch.einsum( 218 | "bhld,lrd->bhlr", query_layer, positional_embedding 219 | ) 220 | relative_position_scores_key = torch.einsum( 221 | "bhrd,lrd->bhlr", key_layer, positional_embedding 222 | ) 223 | attention_scores = ( 224 | attention_scores 225 | + relative_position_scores_query 226 | + relative_position_scores_key 227 | ) 228 | 229 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 230 | if attention_mask is not None: 231 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 232 | attention_scores = attention_scores + attention_mask 233 | 234 | # Normalize the attention scores to probabilities. 235 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 236 | 237 | if is_cross_attention and self.save_attention: 238 | self.save_attention_map(attention_probs) 239 | attention_probs.register_hook(self.save_attn_gradients) 240 | 241 | # This is actually dropping out entire tokens to attend to, which might 242 | # seem a bit unusual, but is taken from the original Transformer paper. 
243 | attention_probs_dropped = self.dropout(attention_probs) 244 | 245 | # Mask heads if we want to 246 | if head_mask is not None: 247 | attention_probs_dropped = attention_probs_dropped * head_mask 248 | 249 | context_layer = torch.matmul(attention_probs_dropped, value_layer) 250 | 251 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 252 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 253 | context_layer = context_layer.view(*new_context_layer_shape) 254 | 255 | outputs = ( 256 | (context_layer, attention_probs) if output_attentions else (context_layer,) 257 | ) 258 | 259 | outputs = outputs + (past_key_value,) 260 | return outputs 261 | 262 | 263 | class BertSelfOutput(nn.Module): 264 | def __init__(self, config): 265 | super().__init__() 266 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 267 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 268 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 269 | 270 | def forward(self, hidden_states, input_tensor): 271 | hidden_states = self.dense(hidden_states) 272 | hidden_states = self.dropout(hidden_states) 273 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 274 | return hidden_states 275 | 276 | 277 | class BertAttention(nn.Module): 278 | def __init__(self, config, is_cross_attention=False): 279 | super().__init__() 280 | self.self = BertSelfAttention(config, is_cross_attention) 281 | self.output = BertSelfOutput(config) 282 | self.pruned_heads = set() 283 | 284 | def prune_heads(self, heads): 285 | if len(heads) == 0: 286 | return 287 | heads, index = find_pruneable_heads_and_indices( 288 | heads, 289 | self.self.num_attention_heads, 290 | self.self.attention_head_size, 291 | self.pruned_heads, 292 | ) 293 | 294 | # Prune linear layers 295 | self.self.query = prune_linear_layer(self.self.query, index) 296 | self.self.key = prune_linear_layer(self.self.key, index) 297 | self.self.value = prune_linear_layer(self.self.value, index) 298 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) 299 | 300 | # Update hyper params and store pruned heads 301 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads) 302 | self.self.all_head_size = ( 303 | self.self.attention_head_size * self.self.num_attention_heads 304 | ) 305 | self.pruned_heads = self.pruned_heads.union(heads) 306 | 307 | def forward( 308 | self, 309 | hidden_states, 310 | attention_mask=None, 311 | head_mask=None, 312 | encoder_hidden_states=None, 313 | encoder_attention_mask=None, 314 | past_key_value=None, 315 | output_attentions=False, 316 | ): 317 | self_outputs = self.self( 318 | hidden_states, 319 | attention_mask, 320 | head_mask, 321 | encoder_hidden_states, 322 | encoder_attention_mask, 323 | past_key_value, 324 | output_attentions, 325 | ) 326 | attention_output = self.output(self_outputs[0], hidden_states) 327 | outputs = (attention_output,) + self_outputs[ 328 | 1: 329 | ] # add attentions if we output them 330 | return outputs 331 | 332 | 333 | class BertIntermediate(nn.Module): 334 | def __init__(self, config): 335 | super().__init__() 336 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 337 | if isinstance(config.hidden_act, str): 338 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 339 | else: 340 | self.intermediate_act_fn = config.hidden_act 341 | 342 | def forward(self, hidden_states): 343 | hidden_states = self.dense(hidden_states) 344 | hidden_states = 
self.intermediate_act_fn(hidden_states) 345 | return hidden_states 346 | 347 | 348 | class BertOutput(nn.Module): 349 | def __init__(self, config): 350 | super().__init__() 351 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 352 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 353 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 354 | 355 | def forward(self, hidden_states, input_tensor): 356 | hidden_states = self.dense(hidden_states) 357 | hidden_states = self.dropout(hidden_states) 358 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 359 | return hidden_states 360 | 361 | 362 | class BertLayer(nn.Module): 363 | def __init__(self, config, layer_num): 364 | super().__init__() 365 | self.config = config 366 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 367 | self.seq_len_dim = 1 368 | self.attention = BertAttention(config) 369 | self.layer_num = layer_num 370 | if self.config.add_cross_attention: 371 | self.crossattention = BertAttention( 372 | config, is_cross_attention=self.config.add_cross_attention 373 | ) 374 | self.intermediate = BertIntermediate(config) 375 | self.output = BertOutput(config) 376 | 377 | def forward( 378 | self, 379 | hidden_states, 380 | attention_mask=None, 381 | head_mask=None, 382 | encoder_hidden_states=None, 383 | encoder_attention_mask=None, 384 | past_key_value=None, 385 | output_attentions=False, 386 | mode=None, 387 | ): 388 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 389 | self_attn_past_key_value = ( 390 | past_key_value[:2] if past_key_value is not None else None 391 | ) 392 | self_attention_outputs = self.attention( 393 | hidden_states, 394 | attention_mask, 395 | head_mask, 396 | output_attentions=output_attentions, 397 | past_key_value=self_attn_past_key_value, 398 | ) 399 | attention_output = self_attention_outputs[0] 400 | 401 | outputs = self_attention_outputs[1:-1] 402 | present_key_value = self_attention_outputs[-1] 403 | 404 | if mode == "multimodal": 405 | assert ( 406 | encoder_hidden_states is not None 407 | ), "encoder_hidden_states must be given for cross-attention layers" 408 | 409 | cross_attention_outputs = self.crossattention( 410 | attention_output, 411 | attention_mask, 412 | head_mask, 413 | encoder_hidden_states, 414 | encoder_attention_mask, 415 | output_attentions=output_attentions, 416 | ) 417 | attention_output = cross_attention_outputs[0] 418 | outputs = ( 419 | outputs + cross_attention_outputs[1:-1] 420 | ) # add cross attentions if we output attention weights 421 | layer_output = apply_chunking_to_forward( 422 | self.feed_forward_chunk, 423 | self.chunk_size_feed_forward, 424 | self.seq_len_dim, 425 | attention_output, 426 | ) 427 | outputs = (layer_output,) + outputs 428 | 429 | outputs = outputs + (present_key_value,) 430 | 431 | return outputs 432 | 433 | def feed_forward_chunk(self, attention_output): 434 | intermediate_output = self.intermediate(attention_output) 435 | layer_output = self.output(intermediate_output, attention_output) 436 | return layer_output 437 | 438 | 439 | class BertEncoder(nn.Module): 440 | def __init__(self, config): 441 | super().__init__() 442 | self.config = config 443 | self.layer = nn.ModuleList( 444 | [BertLayer(config, i) for i in range(config.num_hidden_layers)] 445 | ) 446 | self.gradient_checkpointing = False 447 | 448 | def forward( 449 | self, 450 | hidden_states, 451 | attention_mask=None, 452 | head_mask=None, 453 | encoder_hidden_states=None, 454 | 
encoder_attention_mask=None, 455 | past_key_values=None, 456 | use_cache=None, 457 | output_attentions=False, 458 | output_hidden_states=False, 459 | return_dict=True, 460 | mode="multimodal", 461 | ): 462 | all_hidden_states = () if output_hidden_states else None 463 | all_self_attentions = () if output_attentions else None 464 | all_cross_attentions = ( 465 | () if output_attentions and self.config.add_cross_attention else None 466 | ) 467 | 468 | next_decoder_cache = () if use_cache else None 469 | 470 | for i in range(self.config.num_hidden_layers): 471 | layer_module = self.layer[i] 472 | if output_hidden_states: 473 | all_hidden_states = all_hidden_states + (hidden_states,) 474 | 475 | layer_head_mask = head_mask[i] if head_mask is not None else None 476 | past_key_value = past_key_values[i] if past_key_values is not None else None 477 | 478 | if self.gradient_checkpointing and self.training: 479 | if use_cache: 480 | logger.warn( 481 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 482 | ) 483 | use_cache = False 484 | 485 | def create_custom_forward(module): 486 | def custom_forward(*inputs): 487 | return module(*inputs, past_key_value, output_attentions) 488 | 489 | return custom_forward 490 | 491 | layer_outputs = torch.utils.checkpoint.checkpoint( 492 | create_custom_forward(layer_module), 493 | hidden_states, 494 | attention_mask, 495 | layer_head_mask, 496 | encoder_hidden_states, 497 | encoder_attention_mask, 498 | mode=mode, 499 | ) 500 | else: 501 | layer_outputs = layer_module( 502 | hidden_states, 503 | attention_mask, 504 | layer_head_mask, 505 | encoder_hidden_states, 506 | encoder_attention_mask, 507 | past_key_value, 508 | output_attentions, 509 | mode=mode, 510 | ) 511 | 512 | hidden_states = layer_outputs[0] 513 | if use_cache: 514 | next_decoder_cache += (layer_outputs[-1],) 515 | if output_attentions: 516 | all_self_attentions = all_self_attentions + (layer_outputs[1],) 517 | 518 | if output_hidden_states: 519 | all_hidden_states = all_hidden_states + (hidden_states,) 520 | 521 | if not return_dict: 522 | return tuple( 523 | v 524 | for v in [ 525 | hidden_states, 526 | next_decoder_cache, 527 | all_hidden_states, 528 | all_self_attentions, 529 | all_cross_attentions, 530 | ] 531 | if v is not None 532 | ) 533 | return BaseModelOutputWithPastAndCrossAttentions( 534 | last_hidden_state=hidden_states, 535 | past_key_values=next_decoder_cache, 536 | hidden_states=all_hidden_states, 537 | attentions=all_self_attentions, 538 | cross_attentions=all_cross_attentions, 539 | ) 540 | 541 | 542 | class BertPooler(nn.Module): 543 | def __init__(self, config): 544 | super().__init__() 545 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 546 | self.activation = nn.Tanh() 547 | 548 | def forward(self, hidden_states): 549 | # We "pool" the model by simply taking the hidden state corresponding 550 | # to the first token. 
551 | first_token_tensor = hidden_states[:, 0] 552 | pooled_output = self.dense(first_token_tensor) 553 | pooled_output = self.activation(pooled_output) 554 | return pooled_output 555 | 556 | 557 | class BertPredictionHeadTransform(nn.Module): 558 | def __init__(self, config): 559 | super().__init__() 560 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 561 | if isinstance(config.hidden_act, str): 562 | self.transform_act_fn = ACT2FN[config.hidden_act] 563 | else: 564 | self.transform_act_fn = config.hidden_act 565 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 566 | 567 | def forward(self, hidden_states): 568 | hidden_states = self.dense(hidden_states) 569 | hidden_states = self.transform_act_fn(hidden_states) 570 | hidden_states = self.LayerNorm(hidden_states) 571 | return hidden_states 572 | 573 | 574 | class BertLMPredictionHead(nn.Module): 575 | def __init__(self, config): 576 | super().__init__() 577 | self.transform = BertPredictionHeadTransform(config) 578 | 579 | # The output weights are the same as the input embeddings, but there is 580 | # an output-only bias for each token. 581 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 582 | 583 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 584 | 585 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` 586 | self.decoder.bias = self.bias 587 | 588 | def forward(self, hidden_states): 589 | hidden_states = self.transform(hidden_states) 590 | hidden_states = self.decoder(hidden_states) 591 | return hidden_states 592 | 593 | 594 | class BertOnlyMLMHead(nn.Module): 595 | def __init__(self, config): 596 | super().__init__() 597 | self.predictions = BertLMPredictionHead(config) 598 | 599 | def forward(self, sequence_output): 600 | prediction_scores = self.predictions(sequence_output) 601 | return prediction_scores 602 | 603 | 604 | class BertPreTrainedModel(PreTrainedModel): 605 | """ 606 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 607 | models. 608 | """ 609 | 610 | config_class = BertConfig 611 | base_model_prefix = "bert" 612 | _keys_to_ignore_on_load_missing = [r"position_ids"] 613 | 614 | def _init_weights(self, module): 615 | """Initialize the weights""" 616 | if isinstance(module, (nn.Linear, nn.Embedding)): 617 | # Slightly different from the TF version which uses truncated_normal for initialization 618 | # cf https://github.com/pytorch/pytorch/pull/5617 619 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 620 | elif isinstance(module, nn.LayerNorm): 621 | module.bias.data.zero_() 622 | module.weight.data.fill_(1.0) 623 | if isinstance(module, nn.Linear) and module.bias is not None: 624 | module.bias.data.zero_() 625 | 626 | 627 | class BertModel(BertPreTrainedModel): 628 | """ 629 | The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of 630 | cross-attention is added between the self-attention layers, following the architecture described in `Attention is 631 | all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, 632 | Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. 633 | argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an 634 | input to the forward pass. 
635 | """ 636 | 637 | def __init__(self, config, add_pooling_layer=True): 638 | super().__init__(config) 639 | self.config = config 640 | 641 | self.embeddings = BertEmbeddings(config) 642 | 643 | self.encoder = BertEncoder(config) 644 | 645 | self.pooler = BertPooler(config) if add_pooling_layer else None 646 | 647 | self.init_weights() 648 | 649 | def get_input_embeddings(self): 650 | return self.embeddings.word_embeddings 651 | 652 | def set_input_embeddings(self, value): 653 | self.embeddings.word_embeddings = value 654 | 655 | def _prune_heads(self, heads_to_prune): 656 | """ 657 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 658 | class PreTrainedModel 659 | """ 660 | for layer, heads in heads_to_prune.items(): 661 | self.encoder.layer[layer].attention.prune_heads(heads) 662 | 663 | def get_extended_attention_mask( 664 | self, 665 | attention_mask: Tensor, 666 | input_shape: Tuple[int], 667 | device: device, 668 | is_decoder: bool, 669 | ) -> Tensor: 670 | """ 671 | Makes broadcastable attention and causal masks so that future and masked tokens are ignored. 672 | 673 | Arguments: 674 | attention_mask (:obj:`torch.Tensor`): 675 | Mask with ones indicating tokens to attend to, zeros for tokens to ignore. 676 | input_shape (:obj:`Tuple[int]`): 677 | The shape of the input to the model. 678 | device: (:obj:`torch.device`): 679 | The device of the input to the model. 680 | 681 | Returns: 682 | :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. 683 | """ 684 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 685 | # ourselves in which case we just need to make it broadcastable to all heads. 
686 | if attention_mask.dim() == 3: 687 | extended_attention_mask = attention_mask[:, None, :, :] 688 | elif attention_mask.dim() == 2: 689 | # Provided a padding mask of dimensions [batch_size, seq_length] 690 | # - if the model is a decoder, apply a causal mask in addition to the padding mask 691 | # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] 692 | if is_decoder: 693 | batch_size, seq_length = input_shape 694 | 695 | seq_ids = torch.arange(seq_length, device=device) 696 | causal_mask = ( 697 | seq_ids[None, None, :].repeat(batch_size, seq_length, 1) 698 | <= seq_ids[None, :, None] 699 | ) 700 | # in case past_key_values are used we need to add a prefix ones mask to the causal mask 701 | # causal and attention masks must have same type with pytorch version < 1.3 702 | causal_mask = causal_mask.to(attention_mask.dtype) 703 | 704 | if causal_mask.shape[1] < attention_mask.shape[1]: 705 | prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] 706 | causal_mask = torch.cat( 707 | [ 708 | torch.ones( 709 | (batch_size, seq_length, prefix_seq_len), 710 | device=device, 711 | dtype=causal_mask.dtype, 712 | ), 713 | causal_mask, 714 | ], 715 | axis=-1, 716 | ) 717 | 718 | extended_attention_mask = ( 719 | causal_mask[:, None, :, :] * attention_mask[:, None, None, :] 720 | ) 721 | else: 722 | extended_attention_mask = attention_mask[:, None, None, :] 723 | else: 724 | raise ValueError( 725 | "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( 726 | input_shape, attention_mask.shape 727 | ) 728 | ) 729 | 730 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 731 | # masked positions, this operation will create a tensor which is 0.0 for 732 | # positions we want to attend and -10000.0 for masked positions. 733 | # Since we are adding it to the raw scores before the softmax, this is 734 | # effectively the same as removing these entirely. 735 | extended_attention_mask = extended_attention_mask.to( 736 | dtype=self.dtype 737 | ) # fp16 compatibility 738 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 739 | return extended_attention_mask 740 | 741 | def forward( 742 | self, 743 | input_ids=None, 744 | attention_mask=None, 745 | position_ids=None, 746 | head_mask=None, 747 | inputs_embeds=None, 748 | encoder_embeds=None, 749 | encoder_hidden_states=None, 750 | encoder_attention_mask=None, 751 | past_key_values=None, 752 | use_cache=None, 753 | output_attentions=None, 754 | output_hidden_states=None, 755 | return_dict=None, 756 | is_decoder=False, 757 | mode="multimodal", 758 | ): 759 | r""" 760 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 761 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 762 | the model is configured as a decoder. 763 | encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 764 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in 765 | the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: 766 | - 1 for tokens that are **not masked**, 767 | - 0 for tokens that are **masked**. 
768 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 769 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 770 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` 771 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` 772 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. 773 | use_cache (:obj:`bool`, `optional`): 774 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 775 | decoding (see :obj:`past_key_values`). 776 | """ 777 | output_attentions = ( 778 | output_attentions 779 | if output_attentions is not None 780 | else self.config.output_attentions 781 | ) 782 | output_hidden_states = ( 783 | output_hidden_states 784 | if output_hidden_states is not None 785 | else self.config.output_hidden_states 786 | ) 787 | return_dict = ( 788 | return_dict if return_dict is not None else self.config.use_return_dict 789 | ) 790 | 791 | if is_decoder: 792 | use_cache = use_cache if use_cache is not None else self.config.use_cache 793 | else: 794 | use_cache = False 795 | 796 | if input_ids is not None and inputs_embeds is not None: 797 | raise ValueError( 798 | "You cannot specify both input_ids and inputs_embeds at the same time" 799 | ) 800 | elif input_ids is not None: 801 | input_shape = input_ids.size() 802 | batch_size, seq_length = input_shape 803 | device = input_ids.device 804 | elif inputs_embeds is not None: 805 | input_shape = inputs_embeds.size()[:-1] 806 | batch_size, seq_length = input_shape 807 | device = inputs_embeds.device 808 | elif encoder_embeds is not None: 809 | input_shape = encoder_embeds.size()[:-1] 810 | batch_size, seq_length = input_shape 811 | device = encoder_embeds.device 812 | else: 813 | raise ValueError( 814 | "You have to specify either input_ids or inputs_embeds or encoder_embeds" 815 | ) 816 | 817 | # past_key_values_length 818 | past_key_values_length = ( 819 | past_key_values[0][0].shape[2] if past_key_values is not None else 0 820 | ) 821 | 822 | if attention_mask is None: 823 | attention_mask = torch.ones( 824 | ((batch_size, seq_length + past_key_values_length)), device=device 825 | ) 826 | 827 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 828 | # ourselves in which case we just need to make it broadcastable to all heads. 
829 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( 830 | attention_mask, input_shape, device, is_decoder 831 | ) 832 | 833 | # If a 2D or 3D attention mask is provided for the cross-attention 834 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 835 | if encoder_hidden_states is not None: 836 | if type(encoder_hidden_states) == list: 837 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ 838 | 0 839 | ].size() 840 | else: 841 | ( 842 | encoder_batch_size, 843 | encoder_sequence_length, 844 | _, 845 | ) = encoder_hidden_states.size() 846 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 847 | 848 | if type(encoder_attention_mask) == list: 849 | encoder_extended_attention_mask = [ 850 | self.invert_attention_mask(mask) for mask in encoder_attention_mask 851 | ] 852 | elif encoder_attention_mask is None: 853 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 854 | encoder_extended_attention_mask = self.invert_attention_mask( 855 | encoder_attention_mask 856 | ) 857 | else: 858 | encoder_extended_attention_mask = self.invert_attention_mask( 859 | encoder_attention_mask 860 | ) 861 | else: 862 | encoder_extended_attention_mask = None 863 | 864 | # Prepare head mask if needed 865 | # 1.0 in head_mask indicate we keep the head 866 | # attention_probs has shape bsz x n_heads x N x N 867 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 868 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 869 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 870 | 871 | if encoder_embeds is None: 872 | embedding_output = self.embeddings( 873 | input_ids=input_ids, 874 | position_ids=position_ids, 875 | inputs_embeds=inputs_embeds, 876 | past_key_values_length=past_key_values_length, 877 | ) 878 | else: 879 | embedding_output = encoder_embeds 880 | 881 | encoder_outputs = self.encoder( 882 | embedding_output, 883 | attention_mask=extended_attention_mask, 884 | head_mask=head_mask, 885 | encoder_hidden_states=encoder_hidden_states, 886 | encoder_attention_mask=encoder_extended_attention_mask, 887 | past_key_values=past_key_values, 888 | use_cache=use_cache, 889 | output_attentions=output_attentions, 890 | output_hidden_states=output_hidden_states, 891 | return_dict=return_dict, 892 | mode=mode, 893 | ) 894 | sequence_output = encoder_outputs[0] 895 | pooled_output = ( 896 | self.pooler(sequence_output) if self.pooler is not None else None 897 | ) 898 | 899 | if not return_dict: 900 | return (sequence_output, pooled_output) + encoder_outputs[1:] 901 | 902 | return BaseModelOutputWithPoolingAndCrossAttentions( 903 | last_hidden_state=sequence_output, 904 | pooler_output=pooled_output, 905 | past_key_values=encoder_outputs.past_key_values, 906 | hidden_states=encoder_outputs.hidden_states, 907 | attentions=encoder_outputs.attentions, 908 | cross_attentions=encoder_outputs.cross_attentions, 909 | ) 910 | 911 | 912 | class BertLMHeadModel(BertPreTrainedModel): 913 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 914 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] 915 | 916 | def __init__(self, config): 917 | super().__init__(config) 918 | 919 | self.bert = BertModel(config, add_pooling_layer=False) 920 | self.cls = BertOnlyMLMHead(config) 921 | 922 | self.init_weights() 923 | 924 | def get_output_embeddings(self): 925 | return 
self.cls.predictions.decoder 926 | 927 | def set_output_embeddings(self, new_embeddings): 928 | self.cls.predictions.decoder = new_embeddings 929 | 930 | def forward( 931 | self, 932 | input_ids=None, 933 | attention_mask=None, 934 | position_ids=None, 935 | head_mask=None, 936 | inputs_embeds=None, 937 | encoder_hidden_states=None, 938 | encoder_attention_mask=None, 939 | labels=None, 940 | past_key_values=None, 941 | use_cache=None, 942 | output_attentions=None, 943 | output_hidden_states=None, 944 | return_dict=None, 945 | return_logits=False, 946 | is_decoder=True, 947 | reduction="mean", 948 | mode="multimodal", 949 | ): 950 | r""" 951 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 952 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 953 | the model is configured as a decoder. 954 | encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 955 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in 956 | the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: 957 | - 1 for tokens that are **not masked**, 958 | - 0 for tokens that are **masked**. 959 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 960 | Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in 961 | ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are 962 | ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` 963 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 964 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 965 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` 966 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` 967 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. 968 | use_cache (:obj:`bool`, `optional`): 969 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 970 | decoding (see :obj:`past_key_values`). 
971 | Returns: 972 | Example:: 973 | >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig 974 | >>> import torch 975 | >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') 976 | >>> config = BertConfig.from_pretrained("bert-base-cased") 977 | >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) 978 | >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") 979 | >>> outputs = model(**inputs) 980 | >>> prediction_logits = outputs.logits 981 | """ 982 | return_dict = ( 983 | return_dict if return_dict is not None else self.config.use_return_dict 984 | ) 985 | if labels is not None: 986 | use_cache = False 987 | 988 | outputs = self.bert( 989 | input_ids, 990 | attention_mask=attention_mask, 991 | position_ids=position_ids, 992 | head_mask=head_mask, 993 | inputs_embeds=inputs_embeds, 994 | encoder_hidden_states=encoder_hidden_states, 995 | encoder_attention_mask=encoder_attention_mask, 996 | past_key_values=past_key_values, 997 | use_cache=use_cache, 998 | output_attentions=output_attentions, 999 | output_hidden_states=output_hidden_states, 1000 | return_dict=return_dict, 1001 | is_decoder=is_decoder, 1002 | mode=mode, 1003 | ) 1004 | 1005 | sequence_output = outputs[0] 1006 | prediction_scores = self.cls(sequence_output) 1007 | 1008 | if return_logits: 1009 | return prediction_scores[:, :-1, :].contiguous() 1010 | 1011 | lm_loss = None 1012 | if labels is not None: 1013 | # we are doing next-token prediction; shift prediction scores and input ids by one 1014 | shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() 1015 | labels = labels[:, 1:].contiguous() 1016 | loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) 1017 | lm_loss = loss_fct( 1018 | shifted_prediction_scores.view(-1, self.config.vocab_size), 1019 | labels.view(-1), 1020 | ) 1021 | if reduction == "none": 1022 | lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) 1023 | 1024 | if not return_dict: 1025 | output = (prediction_scores,) + outputs[2:] 1026 | return ((lm_loss,) + output) if lm_loss is not None else output 1027 | 1028 | return CausalLMOutputWithCrossAttentions( 1029 | loss=lm_loss, 1030 | logits=prediction_scores, 1031 | past_key_values=outputs.past_key_values, 1032 | hidden_states=outputs.hidden_states, 1033 | attentions=outputs.attentions, 1034 | cross_attentions=outputs.cross_attentions, 1035 | ) 1036 | 1037 | def prepare_inputs_for_generation( 1038 | self, input_ids, past=None, attention_mask=None, **model_kwargs 1039 | ): 1040 | input_shape = input_ids.shape 1041 | # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly 1042 | if attention_mask is None: 1043 | attention_mask = input_ids.new_ones(input_shape) 1044 | 1045 | # cut decoder_input_ids if past is used 1046 | if past is not None: 1047 | input_ids = input_ids[:, -1:] 1048 | 1049 | return { 1050 | "input_ids": input_ids, 1051 | "attention_mask": attention_mask, 1052 | "past_key_values": past, 1053 | "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), 1054 | "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), 1055 | "is_decoder": True, 1056 | } 1057 | 1058 | def _reorder_cache(self, past, beam_idx): 1059 | reordered_past = () 1060 | for layer_past in past: 1061 | reordered_past += ( 1062 | tuple( 1063 | past_state.index_select(0, beam_idx) for past_state in layer_past 1064 | ), 1065 | ) 1066 | return reordered_past 1067 | 
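Note on usage: the `BertModel` defined above is BLIP's "mixture of encoder-decoder" text backbone. In the default `mode="multimodal"`, every `BertLayer` cross-attends to whatever is passed as `encoder_hidden_states`, which is how image or video-frame patch embeddings are fused with a caption. The snippet below is a minimal, hypothetical sketch of that call pattern — not code from this repository — assuming the Hugging Face `bert-base-uncased` tokenizer and the `med_config.json` that follows; the 577-token random tensor merely stands in for real ViT patch embeddings, and BLIP's own `init_tokenizer` (which additionally registers special encoder/decoder tokens) is skipped for brevity.

```python
import torch
from transformers import BertTokenizer
from transformers.models.bert.configuration_bert import BertConfig

from model.blip.med import BertModel

# Build the text encoder from the config shown below (med_config.json).
med_config = BertConfig.from_json_file("model/blip/med_config.json")
text_encoder = BertModel(config=med_config, add_pooling_layer=False)

# Plain BERT tokenizer; vocab ids stay below med_config's vocab_size of 30524.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
text = tokenizer(
    ["open the drawer instead"],
    padding="max_length", truncation=True, max_length=35, return_tensors="pt",
)

# Dummy visual tokens: (batch, 1 CLS + 24*24 patches, encoder_width = 768).
image_embs = torch.randn(1, 577, 768)
image_atts = torch.ones(image_embs.size()[:-1], dtype=torch.long)

out = text_encoder(
    text.input_ids,
    attention_mask=text.attention_mask,
    encoder_hidden_states=image_embs,      # cross-attended in "multimodal" mode
    encoder_attention_mask=image_atts,
    return_dict=True,
)
fused_cls = out.last_hidden_state[:, 0, :]  # joint text+image representation
```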
-------------------------------------------------------------------------------- /model/blip/med_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30524, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /model/blip/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2023 Lucas Ventura 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | 25 | from typing import Any 26 | 27 | import einops 28 | import torch 29 | import torch.nn.functional as F 30 | from torch import nn 31 | from transformers.models.bert.configuration_bert import BertConfig 32 | 33 | from model.blip.base import create_vit, init_tokenizer, load_checkpoint 34 | from model.blip.med import BertModel 35 | 36 | 37 | class BLIPCir(nn.Module): 38 | def __init__( 39 | self, 40 | loss: Any, 41 | med_config="configs/med_config.json", 42 | image_size=384, 43 | vit="large", 44 | vit_grad_ckpt=True, 45 | vit_ckpt_layer=12, 46 | embed_dim=256, 47 | train_vit=False, 48 | ): 49 | """ 50 | Args: 51 | med_config (str): path for the mixture of encoder-decoder model's configuration file 52 | image_size (int): input image size 53 | vit (str): model size of vision transformer 54 | """ 55 | super().__init__() 56 | 57 | self.loss = loss 58 | 59 | self.visual_encoder, vision_width = create_vit( 60 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer 61 | ) 62 | self.tokenizer = init_tokenizer() 63 | med_config = BertConfig.from_json_file(med_config) 64 | med_config.encoder_width = vision_width 65 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 66 | 67 | text_width = self.text_encoder.config.hidden_size 68 | 69 | self.vision_proj = nn.Linear(vision_width, embed_dim) 70 | self.text_proj = nn.Linear(text_width, embed_dim) 71 | 72 | self.train_vit = train_vit 73 | if not self.train_vit: 74 | # Do not train visual encoder 75 | for p in self.visual_encoder.parameters(): 76 | p.requires_grad = False 77 | 78 | for p in self.vision_proj.parameters(): 79 | p.requires_grad = False 80 | 81 | self.temp = 0.07 82 | 83 | def forward(self, batch, fabric): 84 | ref_img, tar_feat, caption, _ = batch 85 | 86 | device = ref_img.device 87 | 88 | if self.train_vit: 89 | ref_img_embs = self.visual_encoder(ref_img) 90 | else: 91 | with torch.no_grad(): 92 | ref_img_embs = self.visual_encoder(ref_img) 93 | 94 | # Encode the target image 95 | tar_feat = tar_feat.to(device) 96 | tar_img_feat = F.normalize(tar_feat, dim=-1) 97 | 98 | # Encode the reference image 99 | ref_img_atts = torch.ones(ref_img_embs.size()[:-1], dtype=torch.long).to(device) 100 | 101 | text = self.tokenizer( 102 | caption, 103 | padding="max_length", 104 | truncation=True, 105 | max_length=35, 106 | return_tensors="pt", 107 | ).to(device) 108 | 109 | # Shift encoder 110 | encoder_input_ids = text.input_ids.clone() 111 | encoder_input_ids[:, 0] = self.tokenizer.enc_token_id 112 | query_embs = self.text_encoder( 113 | encoder_input_ids, 114 | attention_mask=text.attention_mask, 115 | encoder_hidden_states=ref_img_embs, 116 | encoder_attention_mask=ref_img_atts, 117 | return_dict=True, 118 | ) 119 | query_feat = query_embs.last_hidden_state[:, 0, :] 120 | query_feat = F.normalize(self.text_proj(query_feat), dim=-1) 121 | 122 | if fabric.world_size > 1: 123 | # d: devices, b: batch size, e: embedding dim 124 | query_feat = fabric.all_gather(query_feat, sync_grads=True) 125 | query_feat = einops.rearrange(query_feat, "d b e -> (d b) e") 126 | 127 | tar_img_feat = fabric.all_gather(tar_img_feat, sync_grads=True) 128 | tar_img_feat = einops.rearrange(tar_img_feat, "d b e -> (d b) e") 129 | 130 | return self.loss(query_feat, tar_img_feat, self.temp) 131 | 132 | 133 | def blip_cir(model, ckpt_path, **kwargs): 134 | if ckpt_path: 135 | model, msg = load_checkpoint(model, ckpt_path) 136 | print("missing keys:") 137 | print(msg.missing_keys) 138 | return model 139 | 
-------------------------------------------------------------------------------- /model/blip/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2023 Lucas Ventura 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | from torchvision import transforms 26 | from torchvision.transforms.functional import InterpolationMode 27 | 28 | normalize = transforms.Normalize( 29 | (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711) 30 | ) 31 | 32 | 33 | class transform_test(transforms.Compose): 34 | def __init__(self, image_size=384): 35 | self.transform = transforms.Compose( 36 | [ 37 | transforms.Resize( 38 | (image_size, image_size), 39 | interpolation=InterpolationMode.BICUBIC, 40 | ), 41 | transforms.ToTensor(), 42 | normalize, 43 | ] 44 | ) 45 | 46 | def __call__(self, img): 47 | return self.transform(img) 48 | -------------------------------------------------------------------------------- /model/blip/vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | * Based on timm code base 8 | * https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | """ 10 | 11 | from functools import partial 12 | 13 | import torch 14 | import torch.nn as nn 15 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper 16 | from timm.models.helpers import adapt_input_conv 17 | from timm.models.layers import DropPath, trunc_normal_ 18 | from timm.models.vision_transformer import PatchEmbed 19 | 20 | 21 | class Mlp(nn.Module): 22 | """MLP as used in Vision Transformer, MLP-Mixer and related networks""" 23 | 24 | def __init__( 25 | self, 26 | in_features, 27 | hidden_features=None, 28 | out_features=None, 29 | act_layer=nn.GELU, 30 | drop=0.0, 31 | ): 32 | super().__init__() 33 | out_features = out_features or in_features 34 | hidden_features = hidden_features or in_features 35 | self.fc1 = nn.Linear(in_features, hidden_features) 36 | self.act = act_layer() 37 | self.fc2 = nn.Linear(hidden_features, out_features) 38 | self.drop = nn.Dropout(drop) 39 | 40 | def forward(self, x): 41 | x = self.fc1(x) 42 | x = self.act(x) 43 | x = self.drop(x) 44 | x = self.fc2(x) 45 | x = self.drop(x) 46 | return x 47 | 48 | 49 | class Attention(nn.Module): 50 | def __init__( 51 | self, 52 | dim, 53 | num_heads=8, 54 | qkv_bias=False, 55 | qk_scale=None, 56 | attn_drop=0.0, 57 | proj_drop=0.0, 58 | ): 59 | super().__init__() 60 | self.num_heads = num_heads 61 | head_dim = dim // num_heads 62 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 63 | self.scale = qk_scale or head_dim**-0.5 64 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 65 | self.attn_drop = nn.Dropout(attn_drop) 66 | self.proj = nn.Linear(dim, dim) 67 | self.proj_drop = nn.Dropout(proj_drop) 68 | self.attn_gradients = None 69 | self.attention_map = None 70 | 71 | def save_attn_gradients(self, attn_gradients): 72 | self.attn_gradients = attn_gradients 73 | 74 | def get_attn_gradients(self): 75 | return self.attn_gradients 76 | 77 | def save_attention_map(self, attention_map): 78 | self.attention_map = attention_map 79 | 80 | def get_attention_map(self): 81 | return self.attention_map 82 | 83 | def forward(self, x, register_hook=False): 84 | B, N, C = x.shape 85 | qkv = ( 86 | self.qkv(x) 87 | .reshape(B, N, 3, self.num_heads, C // self.num_heads) 88 | .permute(2, 0, 3, 1, 4) 89 | ) 90 | q, k, v = ( 91 | qkv[0], 92 | qkv[1], 93 | qkv[2], 94 | ) # make torchscript happy (cannot use tensor as tuple) 95 | 96 | attn = (q @ k.transpose(-2, -1)) * self.scale 97 | attn = attn.softmax(dim=-1) 98 | attn = self.attn_drop(attn) 99 | 100 | if register_hook: 101 | self.save_attention_map(attn) 102 | attn.register_hook(self.save_attn_gradients) 103 | 104 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 105 | x = self.proj(x) 106 | x = self.proj_drop(x) 107 | return x 108 | 109 | 110 | class Block(nn.Module): 111 | def __init__( 112 | self, 113 | dim, 114 | num_heads, 115 | mlp_ratio=4.0, 116 | qkv_bias=False, 117 | qk_scale=None, 118 | drop=0.0, 119 | attn_drop=0.0, 120 | drop_path=0.0, 121 | act_layer=nn.GELU, 122 | norm_layer=nn.LayerNorm, 123 | use_grad_checkpointing=False, 124 | ): 125 | super().__init__() 126 | self.norm1 = norm_layer(dim) 127 | self.attn = Attention( 128 | dim, 129 | num_heads=num_heads, 130 | qkv_bias=qkv_bias, 131 | qk_scale=qk_scale, 132 | 
attn_drop=attn_drop, 133 | proj_drop=drop, 134 | ) 135 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 136 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 137 | self.norm2 = norm_layer(dim) 138 | mlp_hidden_dim = int(dim * mlp_ratio) 139 | self.mlp = Mlp( 140 | in_features=dim, 141 | hidden_features=mlp_hidden_dim, 142 | act_layer=act_layer, 143 | drop=drop, 144 | ) 145 | 146 | if use_grad_checkpointing: 147 | self.attn = checkpoint_wrapper(self.attn) 148 | self.mlp = checkpoint_wrapper(self.mlp) 149 | 150 | def forward(self, x, register_hook=False): 151 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) 152 | x = x + self.drop_path(self.mlp(self.norm2(x))) 153 | return x 154 | 155 | 156 | class VisionTransformer(nn.Module): 157 | """Vision Transformer 158 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 159 | https://arxiv.org/abs/2010.11929 160 | """ 161 | 162 | def __init__( 163 | self, 164 | img_size=224, 165 | patch_size=16, 166 | in_chans=3, 167 | num_classes=1000, 168 | embed_dim=768, 169 | depth=12, 170 | num_heads=12, 171 | mlp_ratio=4.0, 172 | qkv_bias=True, 173 | qk_scale=None, 174 | representation_size=None, 175 | drop_rate=0.0, 176 | attn_drop_rate=0.0, 177 | drop_path_rate=0.0, 178 | norm_layer=None, 179 | use_grad_checkpointing=False, 180 | ckpt_layer=0, 181 | ): 182 | """ 183 | Args: 184 | img_size (int, tuple): input image size 185 | patch_size (int, tuple): patch size 186 | in_chans (int): number of input channels 187 | num_classes (int): number of classes for classification head 188 | embed_dim (int): embedding dimension 189 | depth (int): depth of transformer 190 | num_heads (int): number of attention heads 191 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 192 | qkv_bias (bool): enable bias for qkv if True 193 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 194 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 195 | drop_rate (float): dropout rate 196 | attn_drop_rate (float): attention dropout rate 197 | drop_path_rate (float): stochastic depth rate 198 | norm_layer: (nn.Module): normalization layer 199 | """ 200 | super().__init__() 201 | self.num_features = ( 202 | self.embed_dim 203 | ) = embed_dim # num_features for consistency with other models 204 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 205 | 206 | self.patch_embed = PatchEmbed( 207 | img_size=img_size, 208 | patch_size=patch_size, 209 | in_chans=in_chans, 210 | embed_dim=embed_dim, 211 | ) 212 | 213 | num_patches = self.patch_embed.num_patches 214 | 215 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 216 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 217 | self.pos_drop = nn.Dropout(p=drop_rate) 218 | 219 | dpr = [ 220 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 221 | ] # stochastic depth decay rule 222 | self.blocks = nn.ModuleList( 223 | [ 224 | Block( 225 | dim=embed_dim, 226 | num_heads=num_heads, 227 | mlp_ratio=mlp_ratio, 228 | qkv_bias=qkv_bias, 229 | qk_scale=qk_scale, 230 | drop=drop_rate, 231 | attn_drop=attn_drop_rate, 232 | drop_path=dpr[i], 233 | norm_layer=norm_layer, 234 | use_grad_checkpointing=( 235 | use_grad_checkpointing and i >= depth - ckpt_layer 236 | ), 237 | ) 238 | for i in range(depth) 239 | ] 240 | ) 241 | self.norm = norm_layer(embed_dim) 242 | 243 | 
trunc_normal_(self.pos_embed, std=0.02) 244 | trunc_normal_(self.cls_token, std=0.02) 245 | self.apply(self._init_weights) 246 | 247 | def _init_weights(self, m): 248 | if isinstance(m, nn.Linear): 249 | trunc_normal_(m.weight, std=0.02) 250 | if isinstance(m, nn.Linear) and m.bias is not None: 251 | nn.init.constant_(m.bias, 0) 252 | elif isinstance(m, nn.LayerNorm): 253 | nn.init.constant_(m.bias, 0) 254 | nn.init.constant_(m.weight, 1.0) 255 | 256 | @torch.jit.ignore 257 | def no_weight_decay(self): 258 | return {"pos_embed", "cls_token"} 259 | 260 | def forward(self, x, register_blk=-1): 261 | B = x.shape[0] 262 | x = self.patch_embed(x) 263 | 264 | cls_tokens = self.cls_token.expand( 265 | B, -1, -1 266 | ) # stole cls_tokens impl from Phil Wang, thanks 267 | x = torch.cat((cls_tokens, x), dim=1) 268 | 269 | x = x + self.pos_embed[:, : x.size(1), :] 270 | x = self.pos_drop(x) 271 | 272 | for i, blk in enumerate(self.blocks): 273 | x = blk(x, register_blk == i) 274 | x = self.norm(x) 275 | 276 | return x 277 | 278 | @torch.jit.ignore() 279 | def load_pretrained(self, checkpoint_path, prefix=""): 280 | _load_weights(self, checkpoint_path, prefix) 281 | 282 | 283 | @torch.no_grad() 284 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ""): 285 | """Load weights from .npz checkpoints for official Google Brain Flax implementation""" 286 | import numpy as np 287 | 288 | def _n2p(w, t=True): 289 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 290 | w = w.flatten() 291 | if t: 292 | if w.ndim == 4: 293 | w = w.transpose([3, 2, 0, 1]) 294 | elif w.ndim == 3: 295 | w = w.transpose([2, 0, 1]) 296 | elif w.ndim == 2: 297 | w = w.transpose([1, 0]) 298 | return torch.from_numpy(w) 299 | 300 | w = np.load(checkpoint_path) 301 | if not prefix and "opt/target/embedding/kernel" in w: 302 | prefix = "opt/target/" 303 | 304 | if hasattr(model.patch_embed, "backbone"): 305 | # hybrid 306 | backbone = model.patch_embed.backbone 307 | stem_only = not hasattr(backbone, "stem") 308 | stem = backbone if stem_only else backbone.stem 309 | stem.conv.weight.copy_( 310 | adapt_input_conv( 311 | stem.conv.weight.shape[1], _n2p(w[f"{prefix}conv_root/kernel"]) 312 | ) 313 | ) 314 | stem.norm.weight.copy_(_n2p(w[f"{prefix}gn_root/scale"])) 315 | stem.norm.bias.copy_(_n2p(w[f"{prefix}gn_root/bias"])) 316 | if not stem_only: 317 | for i, stage in enumerate(backbone.stages): 318 | for j, block in enumerate(stage.blocks): 319 | bp = f"{prefix}block{i + 1}/unit{j + 1}/" 320 | for r in range(3): 321 | getattr(block, f"conv{r + 1}").weight.copy_( 322 | _n2p(w[f"{bp}conv{r + 1}/kernel"]) 323 | ) 324 | getattr(block, f"norm{r + 1}").weight.copy_( 325 | _n2p(w[f"{bp}gn{r + 1}/scale"]) 326 | ) 327 | getattr(block, f"norm{r + 1}").bias.copy_( 328 | _n2p(w[f"{bp}gn{r + 1}/bias"]) 329 | ) 330 | if block.downsample is not None: 331 | block.downsample.conv.weight.copy_( 332 | _n2p(w[f"{bp}conv_proj/kernel"]) 333 | ) 334 | block.downsample.norm.weight.copy_( 335 | _n2p(w[f"{bp}gn_proj/scale"]) 336 | ) 337 | block.downsample.norm.bias.copy_(_n2p(w[f"{bp}gn_proj/bias"])) 338 | embed_conv_w = _n2p(w[f"{prefix}embedding/kernel"]) 339 | else: 340 | embed_conv_w = adapt_input_conv( 341 | model.patch_embed.proj.weight.shape[1], _n2p(w[f"{prefix}embedding/kernel"]) 342 | ) 343 | model.patch_embed.proj.weight.copy_(embed_conv_w) 344 | model.patch_embed.proj.bias.copy_(_n2p(w[f"{prefix}embedding/bias"])) 345 | model.cls_token.copy_(_n2p(w[f"{prefix}cls"], t=False)) 346 | pos_embed_w = 
_n2p(w[f"{prefix}Transformer/posembed_input/pos_embedding"], t=False) 347 | if pos_embed_w.shape != model.pos_embed.shape: 348 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 349 | pos_embed_w, 350 | model.pos_embed, 351 | getattr(model, "num_tokens", 1), 352 | model.patch_embed.grid_size, 353 | ) 354 | model.pos_embed.copy_(pos_embed_w) 355 | model.norm.weight.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/scale"])) 356 | model.norm.bias.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/bias"])) 357 | # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 358 | # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 359 | # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 360 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 361 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 362 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 363 | for i, block in enumerate(model.blocks.children()): 364 | block_prefix = f"{prefix}Transformer/encoderblock_{i}/" 365 | mha_prefix = block_prefix + "MultiHeadDotProductAttention_1/" 366 | block.norm1.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/scale"])) 367 | block.norm1.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/bias"])) 368 | block.attn.qkv.weight.copy_( 369 | torch.cat( 370 | [ 371 | _n2p(w[f"{mha_prefix}{n}/kernel"], t=False).flatten(1).T 372 | for n in ("query", "key", "value") 373 | ] 374 | ) 375 | ) 376 | block.attn.qkv.bias.copy_( 377 | torch.cat( 378 | [ 379 | _n2p(w[f"{mha_prefix}{n}/bias"], t=False).reshape(-1) 380 | for n in ("query", "key", "value") 381 | ] 382 | ) 383 | ) 384 | block.attn.proj.weight.copy_(_n2p(w[f"{mha_prefix}out/kernel"]).flatten(1)) 385 | block.attn.proj.bias.copy_(_n2p(w[f"{mha_prefix}out/bias"])) 386 | for r in range(2): 387 | getattr(block.mlp, f"fc{r + 1}").weight.copy_( 388 | _n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/kernel"]) 389 | ) 390 | getattr(block.mlp, f"fc{r + 1}").bias.copy_( 391 | _n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/bias"]) 392 | ) 393 | block.norm2.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/scale"])) 394 | block.norm2.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/bias"])) 395 | 396 | 397 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): 398 | # interpolate position embedding 399 | embedding_size = pos_embed_checkpoint.shape[-1] 400 | num_patches = visual_encoder.patch_embed.num_patches 401 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches 402 | # height (== width) for the checkpoint position embedding 403 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 404 | # height (== width) for the new position embedding 405 | new_size = int(num_patches**0.5) 406 | 407 | if orig_size != new_size: 408 | # class_token and dist_token are kept unchanged 409 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 410 | # only the position tokens are interpolated 411 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 412 | pos_tokens = pos_tokens.reshape( 413 | -1, orig_size, orig_size, embedding_size 414 | ).permute(0, 3, 1, 2) 415 | pos_tokens = torch.nn.functional.interpolate( 416 | pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False 417 | ) 418 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 419 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 420 | print( 421 | 
"reshape position embedding from %d to %d" % (orig_size**2, new_size**2) 422 | ) 423 | 424 | return new_pos_embed 425 | else: 426 | return pos_embed_checkpoint 427 | -------------------------------------------------------------------------------- /model/egovlpv2/EgoNCE_MLM_ITM_Config.yaml: -------------------------------------------------------------------------------- 1 | # Image setting 2 | input_image_embed_size: 768 3 | 4 | # Text Setting 5 | #tokenizer = "roberta-base" 6 | vocab_size: 50265 7 | mlm_prob: 0.15 8 | input_text_embed_size: 768 9 | 10 | # Transformer Setting 11 | hidden_size: 768 12 | num_heads: 12 13 | num_layers: 12 14 | mlp_ratio: 4 15 | drop_rate: 0.1 16 | num_fuse_block: 6 17 | 18 | 19 | # Gradient Checkpoint 20 | use_checkpoint: True 21 | 22 | # lr_scheduler 23 | decay_power: "cosine" 24 | end_lr: 0.0000001 25 | warmup_steps: 0.1 # This is a floating point indicating % of max_steps 26 | -------------------------------------------------------------------------------- /model/egovlpv2/base/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) Meta Platforms, Inc. and affiliates. 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | from .base_model import * 26 | -------------------------------------------------------------------------------- /model/egovlpv2/base/base_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) Meta Platforms, Inc. and affiliates. 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | import torch.nn as nn 26 | import numpy as np 27 | from abc import abstractmethod 28 | 29 | 30 | class BaseModel(nn.Module): 31 | """ 32 | Base class for all models 33 | """ 34 | 35 | @abstractmethod 36 | def forward(self, *inputs): 37 | """ 38 | Forward pass logic 39 | 40 | :return: Model output 41 | """ 42 | raise NotImplementedError 43 | 44 | def __str__(self): 45 | """ 46 | Model prints with number of trainable parameters 47 | """ 48 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 49 | params = sum(np.prod(p.size()) for p in model_parameters) 50 | return super().__str__() + "\nTrainable parameters: {}".format(params) 51 | -------------------------------------------------------------------------------- /model/egovlpv2/heads.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) Meta Platforms, Inc. and affiliates. 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | 25 | import torch 26 | import torch.nn as nn 27 | import torch.nn.functional as F 28 | 29 | from transformers.models.bert.modeling_bert import BertPredictionHeadTransform 30 | 31 | 32 | class Pooler(nn.Module): 33 | def __init__(self, hidden_size): 34 | super().__init__() 35 | self.dense = nn.Linear(hidden_size, hidden_size) 36 | self.activation = nn.Tanh() 37 | 38 | def forward(self, hidden_states): 39 | first_token_tensor = hidden_states # [:, 0] 40 | pooled_output = self.dense(first_token_tensor) 41 | pooled_output = self.activation(pooled_output) 42 | return pooled_output 43 | 44 | 45 | class ITMHead(nn.Module): 46 | def __init__(self, hidden_size): 47 | super().__init__() 48 | self.fc = nn.Linear(hidden_size, 2) 49 | 50 | def forward(self, x): 51 | x = self.fc(x) 52 | return x 53 | 54 | 55 | class MLMHead(nn.Module): 56 | def __init__(self, config, weight=None): 57 | super().__init__() 58 | self.transform = BertPredictionHeadTransform(config) 59 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 60 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 61 | if weight is not None: 62 | self.decoder.weight = weight 63 | 64 | def forward(self, x): 65 | x = self.transform(x) 66 | x = self.decoder(x) + self.bias 67 | return x 68 | -------------------------------------------------------------------------------- /model/egovlpv2/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) Meta Platforms, Inc. and affiliates. 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | 25 | import os 26 | import sys 27 | import torch 28 | import yaml 29 | import torch.nn as nn 30 | import numpy as np 31 | import torch.nn.functional as F 32 | 33 | from transformers import RobertaConfig 34 | from functools import partial 35 | import copy 36 | import torch.distributed as dist 37 | 38 | from model.egovlpv2.base import BaseModel 39 | from model.egovlpv2.video_transformer import SpaceTimeTransformer 40 | from model.egovlpv2.util import state_dict_data_parallel_fix 41 | 42 | from model.egovlpv2 import roberta 43 | from model.egovlpv2.roberta import RobertaModel 44 | from model.egovlpv2 import heads 45 | 46 | with open(os.path.join(os.path.dirname(__file__), "EgoNCE_MLM_ITM_Config.yaml")) as f: 47 | config = yaml.load(f, Loader=yaml.FullLoader) 48 | 49 | 50 | def init_weights(module): 51 | if isinstance(module, (nn.Linear, nn.Embedding)): 52 | module.weight.data.normal_(mean=0.0, std=0.02) 53 | elif isinstance(module, nn.LayerNorm): 54 | module.bias.data.zero_() 55 | module.weight.data.fill_(1.0) 56 | 57 | if isinstance(module, nn.Linear) and module.bias is not None: 58 | module.bias.data.zero_() 59 | 60 | 61 | class FrozenInTime(BaseModel): 62 | def __init__( 63 | self, 64 | video_params, 65 | text_params, 66 | projection_dim=4096, 67 | load_checkpoint=None, 68 | projection="minimal", 69 | load_temporal_fix="bilinear", 70 | config=config, 71 | task_names="EgoNCE_ITM_MLM", 72 | norm_layer=None, 73 | embed_dim=768, 74 | ): 75 | super().__init__() 76 | 77 | self.video_params = video_params 78 | self.text_params = text_params 79 | self.load_temporal_fix = load_temporal_fix 80 | self.config = config 81 | self.task_names = task_names 82 | if not text_params["pretrained"]: 83 | raise NotImplementedError( 84 | "Huggingface text models require pretrained init." 
85 | ) 86 | 87 | if self.text_params["model"].startswith("roberta"): 88 | self.text_model = RobertaModel.from_pretrained("roberta-base") 89 | self.text_model.train() 90 | 91 | pretrained = video_params["pretrained"] 92 | if video_params["model"] == "SpaceTimeTransformer": 93 | self.num_frames = video_params["num_frames"] 94 | time_init = "zeros" 95 | attention_style = "frozen-in-time" 96 | arch_config = "base_patch16_224" 97 | vit_init = "imagenet-21k" 98 | if arch_config == "base_patch16_224": 99 | model = SpaceTimeTransformer( 100 | num_frames=self.num_frames, 101 | time_init=time_init, 102 | attention_style=attention_style, 103 | ) 104 | else: 105 | raise NotImplementedError 106 | 107 | model.head = nn.Identity() 108 | model.pre_logits = nn.Identity() 109 | ftr_dim = model.embed_dim 110 | 111 | self.video_model = model 112 | else: 113 | raise NotImplementedError(f"{video_params['model']} not implemented") 114 | 115 | # for backwards compatibility (old models) 116 | self.video_model.fc = nn.Identity() 117 | 118 | # Project to a common embedding 119 | if projection == "small": 120 | 121 | txt_proj = nn.Sequential( 122 | nn.ReLU(), 123 | nn.Linear(self.text_model.config.hidden_size, 256), 124 | ) 125 | vid_proj = nn.Sequential(nn.Linear(ftr_dim, 256)) 126 | elif projection == "default" or projection == "minimal": 127 | txt_proj = nn.Sequential( 128 | nn.Linear( 129 | self.text_model.config.hidden_size, projection_dim, bias=False 130 | ), 131 | nn.ReLU(inplace=True), 132 | nn.Linear(projection_dim, projection_dim, bias=True), 133 | nn.ReLU(inplace=True), 134 | nn.Linear(projection_dim, projection_dim, bias=True), 135 | ) 136 | 137 | vid_proj = nn.Sequential( 138 | nn.Linear(ftr_dim, projection_dim, bias=False), 139 | nn.ReLU(inplace=True), 140 | nn.Linear(projection_dim, projection_dim, bias=True), 141 | nn.ReLU(inplace=True), 142 | nn.Linear(projection_dim, projection_dim, bias=True), 143 | ) 144 | 145 | elif projection == "": 146 | txt_proj = nn.Identity() 147 | vid_proj = nn.Identity() 148 | else: 149 | raise NotImplementedError 150 | self.txt_proj = txt_proj 151 | self.vid_proj = vid_proj 152 | 153 | if "MLM" in self.task_names or "ITM" in self.task_names: 154 | # for FIBER-like cross-attention 155 | 156 | bert_config = RobertaConfig( 157 | vocab_size=self.config["vocab_size"], 158 | hidden_size=self.config["hidden_size"], 159 | num_hidden_layers=self.config["num_layers"], 160 | num_attention_heads=self.config["num_heads"], 161 | intermediate_size=self.config["hidden_size"] * self.config["mlp_ratio"], 162 | # max_position_embeddings=maxlen, [was used in BTGOT script] 163 | hidden_dropout_prob=self.config["drop_rate"], 164 | attention_probs_dropout_prob=self.config["drop_rate"], 165 | ) 166 | 167 | self.num_fuse_block = self.config["num_fuse_block"] 168 | self.num_text_layer = self.config["num_layers"] 169 | roberta.NUM_FUSE_BLOCK = self.video_model.NUM_FUSE_BLOCK = ( 170 | self.num_fuse_block 171 | ) 172 | roberta.DIM_IMG = self.config["input_image_embed_size"] 173 | self.video_model.DIM_TXT = self.config["input_text_embed_size"] 174 | 175 | self.cross_modal_text_transform = nn.Linear( 176 | self.config["input_text_embed_size"], self.config["hidden_size"] 177 | ) 178 | self.cross_modal_text_transform.apply(init_weights) 179 | self.cross_modal_video_transform = nn.Linear( 180 | self.config["input_image_embed_size"], self.config["hidden_size"] 181 | ) 182 | self.cross_modal_video_transform.apply(init_weights) 183 | 184 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 185 | 
186 | self.num_patches = self.video_model.patch_embed.num_patches 187 | self.patches_per_frame = self.num_patches // self.num_frames 188 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 189 | self.norm = norm_layer(embed_dim) 190 | self.pre_logits = nn.Identity() 191 | 192 | self.avgpool = nn.AdaptiveAvgPool1d(1) 193 | self.cross_modal_video_pooler = heads.Pooler(self.config["hidden_size"]) 194 | self.cross_modal_video_pooler.apply(init_weights) 195 | self.cross_modal_text_pooler = heads.Pooler(self.config["hidden_size"]) 196 | self.cross_modal_text_pooler.apply(init_weights) 197 | 198 | ## einops transformations 199 | self.einops_from_space = "b (f n) d" 200 | self.einops_to_space = "(b f) n d" 201 | self.einops_from_time = "b (f n) d" 202 | self.einops_to_time = "(b n) f d" 203 | 204 | if "MLM" in self.task_names: 205 | self.mlm_score = heads.MLMHead(bert_config) 206 | self.mlm_score.apply(init_weights) 207 | 208 | if "ITM" in self.task_names: 209 | self.itm_score = heads.ITMHead(self.config["hidden_size"] * 2) 210 | self.itm_score.apply(init_weights) 211 | 212 | if load_checkpoint not in ["", None]: 213 | print("loading checkpoint from ", load_checkpoint) 214 | # Fix for loading parse_config from non-root, required to load checkpoint 215 | src_dir = os.path.dirname(__file__) 216 | sys.path.insert(0, src_dir) # Modify path temporarily 217 | 218 | checkpoint = torch.load(load_checkpoint, map_location="cpu") 219 | state_dict = checkpoint["state_dict"] 220 | new_state_dict = state_dict_data_parallel_fix(state_dict, self.state_dict()) 221 | new_state_dict = self._inflate_positional_embeds(new_state_dict) 222 | self.load_state_dict(new_state_dict, strict=False) 223 | 224 | def set_device(self, device): 225 | self.device = device 226 | 227 | def infer( 228 | self, data, video_only=False, return_embeds=True, task_names=None, ret={} 229 | ): 230 | 231 | text_data = data["text"] 232 | video_data = data["video"] 233 | 234 | if task_names is not None: 235 | self.task_names = task_names 236 | 237 | if "EgoNCE" in self.task_names: 238 | 239 | text_embeddings = self.compute_text(text_data) 240 | video_embeddings = self.compute_video(video_data) 241 | 242 | if return_embeds: 243 | ret.update( 244 | {"text_embeds": text_embeddings, "video_embeds": video_embeddings} 245 | ) 246 | 247 | if "ITM" in self.task_names: 248 | 249 | b, curr_frames, channels, _, _ = video_data.shape 250 | video_data_itm = self.video_model.patch_embed(video_data) 251 | video_data_itm = video_data_itm.flatten(2).transpose(2, 1) 252 | video_data_itm = video_data_itm.reshape( 253 | b, -1, self.video_model.patch_embed.embed_dim 254 | ) 255 | 256 | BF = video_data_itm.shape[0] 257 | cls_tokens = self.cls_token.expand( 258 | BF, -1, -1 259 | ) # stole cls_tokens impl from Phil Wang, thanks 260 | video_data_itm = torch.cat((cls_tokens, video_data_itm), dim=1) 261 | # positional embed needs to be tiled for each frame (this does [1,2,3] --> [1,2,3,1,2,3]...) 262 | cls_embed = self.video_model.pos_embed[:, 0, :].unsqueeze(1) 263 | tile_pos_embed = self.video_model.pos_embed[:, 1:, :].repeat( 264 | 1, self.num_frames, 1 265 | ) 266 | # temporal embed needs to be repeated within each frame (this does [1,2,3] --> [1,1,1,2,2,2,3,3,3]...) 
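# Illustrative shapes (assuming 3 patches per frame, 2 frames, embed dim D):
#   pos_embed[:, 1:, :]: (1, 3, D) --repeat over frames-->   (1, 6, D)  [p1 p2 p3 p1 p2 p3]
#   temporal_embed:      (1, 2, D) --repeat_interleave(3)--> (1, 6, D)  [t1 t1 t1 t2 t2 t2]
# Their sum gives one positional+temporal offset per patch token; the CLS embedding is
# concatenated back in front below.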
267 | tile_temporal_embed = self.video_model.temporal_embed.repeat_interleave( 268 | self.patches_per_frame, 1 269 | ) 270 | total_pos_embed = tile_pos_embed + tile_temporal_embed 271 | total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1) 272 | 273 | n = self.patches_per_frame 274 | f = curr_frames 275 | 276 | curr_patches = video_data_itm.shape[1] 277 | video_data_itm = video_data_itm + total_pos_embed[:, :curr_patches] 278 | video_data_itm = self.video_model.pos_drop(video_data_itm) 279 | 280 | unfused_blocks = self.num_text_layer - self.num_fuse_block 281 | 282 | for blk_i, blk in enumerate(self.video_model.blocks[:unfused_blocks]): 283 | if self.config["use_checkpoint"]: 284 | video_data_itm = torch.utils.checkpoint.checkpoint( 285 | blk, 286 | video_data_itm, 287 | self.einops_from_space, 288 | self.einops_to_space, 289 | self.einops_from_time, 290 | self.einops_to_time, 291 | n, 292 | f, 293 | ) 294 | else: 295 | video_data_itm = blk( 296 | video_data_itm, 297 | self.einops_from_space, 298 | self.einops_to_space, 299 | self.einops_from_time, 300 | self.einops_to_time, 301 | time_n=n, 302 | space_f=f, 303 | ) 304 | 305 | text_embeds = self.text_model.embeddings( 306 | input_ids=text_data["input_ids"] 307 | ) # before it was input_ids=text_ids 308 | device = text_embeds.device 309 | text_masks = text_data["attention_mask"] 310 | input_shape = text_masks.size() 311 | extend_text_masks = self.text_model.get_extended_attention_mask( 312 | text_masks, input_shape, device 313 | ) 314 | for layer_i, layer in enumerate( 315 | self.text_model.encoder.layer[:unfused_blocks] 316 | ): 317 | if self.config["use_checkpoint"]: 318 | text_embeds = torch.utils.checkpoint.checkpoint( 319 | layer, text_embeds, extend_text_masks 320 | )[0] 321 | else: 322 | text_embeds = layer(text_embeds, extend_text_masks)[0] 323 | 324 | for blk_i, blk in enumerate( 325 | self.video_model.blocks[unfused_blocks : self.num_text_layer] 326 | ): 327 | if self.config["use_checkpoint"]: 328 | 329 | fuse_video_data = torch.utils.checkpoint.checkpoint( 330 | blk, 331 | video_data_itm, 332 | self.einops_from_space, 333 | self.einops_to_space, 334 | self.einops_from_time, 335 | self.einops_to_time, 336 | n, 337 | f, 338 | text_embeds, 339 | extend_text_masks, 340 | ) 341 | text_embeds = torch.utils.checkpoint.checkpoint( 342 | self.text_model.encoder.layer[blk_i + unfused_blocks], 343 | text_embeds, 344 | extend_text_masks, 345 | None, 346 | (video_data_itm), 347 | None, 348 | None, 349 | False, 350 | True, 351 | )[0] 352 | else: 353 | fuse_video_data = blk( 354 | video_data_itm, 355 | self.einops_from_space, 356 | self.einops_to_space, 357 | self.einops_from_time, 358 | self.einops_to_time, 359 | y=text_embeds, 360 | y_mask=extend_text_masks, 361 | time_n=n, 362 | space_f=f, 363 | ) 364 | text_embeds = self.text_model.encoder.layer[blk_i + unfused_blocks]( 365 | text_embeds, 366 | extend_text_masks, 367 | encoder_hidden_states=(video_data_itm), 368 | last_norm=True, 369 | )[0] 370 | video_data_itm = fuse_video_data 371 | 372 | # print("Shape of model output", video_data.shape) 373 | video_data_itm = self.norm(video_data_itm)[:, 0] 374 | video_data_itm = self.pre_logits(video_data_itm) 375 | 376 | text_embeds = text_embeds[:, 0] 377 | text_embeds = self.cross_modal_text_transform(text_embeds) 378 | video_embeds = self.cross_modal_video_transform(video_data_itm) 379 | 380 | cls_feats_text = self.cross_modal_text_pooler(text_embeds) 381 | 382 | cls_feats_video = self.cross_modal_video_pooler(video_embeds) 383 | 384 | 
cls_feats = torch.cat([cls_feats_text, cls_feats_video], dim=-1) 385 | 386 | ret.update({"cross_attn_itm_logits": self.itm_score(cls_feats)}) 387 | 388 | if "MLM" in self.task_names: 389 | 390 | b, curr_frames, channels, _, _ = video_data.shape 391 | video_data_mlm = self.video_model.patch_embed(video_data) 392 | video_data_mlm = video_data_mlm.flatten(2).transpose(2, 1) 393 | video_data_mlm = video_data_mlm.reshape( 394 | b, -1, self.video_model.patch_embed.embed_dim 395 | ) 396 | 397 | BF = video_data_mlm.shape[0] 398 | cls_tokens = self.cls_token.expand( 399 | BF, -1, -1 400 | ) # stole cls_tokens impl from Phil Wang, thanks 401 | video_data_mlm = torch.cat((cls_tokens, video_data_mlm), dim=1) 402 | # positional embed needs to be tiled for each frame (this does [1,2,3] --> [1,2,3,1,2,3]...) 403 | cls_embed = self.video_model.pos_embed[:, 0, :].unsqueeze(1) 404 | tile_pos_embed = self.video_model.pos_embed[:, 1:, :].repeat( 405 | 1, self.num_frames, 1 406 | ) 407 | # temporal embed needs to be repeated within each frame (this does [1,2,3] --> [1,1,1,2,2,2,3,3,3]...) 408 | tile_temporal_embed = self.video_model.temporal_embed.repeat_interleave( 409 | self.patches_per_frame, 1 410 | ) 411 | total_pos_embed = tile_pos_embed + tile_temporal_embed 412 | total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1) 413 | 414 | # print("total_pos_embed shape: ", total_pos_embed.shape) 415 | 416 | n = self.patches_per_frame 417 | f = curr_frames 418 | 419 | curr_patches = video_data_mlm.shape[1] 420 | video_data_mlm = video_data_mlm + total_pos_embed[:, :curr_patches] 421 | video_data_mlm = self.video_model.pos_drop(video_data_mlm) 422 | 423 | # print("video_data_mlm shape: ", video_data_mlm.shape) 424 | 425 | unfused_blocks = self.num_text_layer - self.num_fuse_block 426 | 427 | for blk_i, blk in enumerate(self.video_model.blocks[:unfused_blocks]): 428 | if self.config["use_checkpoint"]: 429 | video_data_mlm = torch.utils.checkpoint.checkpoint( 430 | blk, 431 | video_data_mlm, 432 | self.einops_from_space, 433 | self.einops_to_space, 434 | self.einops_from_time, 435 | self.einops_to_time, 436 | n, 437 | f, 438 | ) 439 | else: 440 | video_data_mlm = blk( 441 | video_data_mlm, 442 | self.einops_from_space, 443 | self.einops_to_space, 444 | self.einops_from_time, 445 | self.einops_to_time, 446 | time_n=n, 447 | space_f=f, 448 | ) 449 | 450 | text_embeds = self.text_model.embeddings( 451 | input_ids=data["text_mlm_ids"] 452 | ) # before it was input_ids=text_ids 453 | device = text_embeds.device 454 | text_masks = text_data["attention_mask"] 455 | input_shape = text_masks.size() 456 | extend_text_masks = self.text_model.get_extended_attention_mask( 457 | text_masks, input_shape, device 458 | ) 459 | 460 | for layer_i, layer in enumerate( 461 | self.text_model.encoder.layer[:unfused_blocks] 462 | ): 463 | if self.config["use_checkpoint"]: 464 | text_embeds = torch.utils.checkpoint.checkpoint( 465 | layer, text_embeds, extend_text_masks 466 | )[0] 467 | else: 468 | text_embeds = layer(text_embeds, extend_text_masks)[0] 469 | 470 | for blk_i, blk in enumerate( 471 | self.video_model.blocks[unfused_blocks : self.num_text_layer] 472 | ): 473 | if self.config["use_checkpoint"]: 474 | 475 | fuse_video_data = torch.utils.checkpoint.checkpoint( 476 | blk, 477 | video_data_mlm, 478 | self.einops_from_space, 479 | self.einops_to_space, 480 | self.einops_from_time, 481 | self.einops_to_time, 482 | n, 483 | f, 484 | text_embeds, 485 | extend_text_masks, 486 | ) 487 | text_embeds = 
torch.utils.checkpoint.checkpoint( 488 | self.text_model.encoder.layer[blk_i + unfused_blocks], 489 | text_embeds, 490 | extend_text_masks, 491 | None, 492 | (video_data_mlm), 493 | None, 494 | None, 495 | False, 496 | True, 497 | )[0] 498 | else: 499 | fuse_video_data = blk( 500 | video_data_mlm, 501 | self.einops_from_space, 502 | self.einops_to_space, 503 | self.einops_from_time, 504 | self.einops_to_time, 505 | y=text_embeds, 506 | y_mask=extend_text_masks, 507 | time_n=n, 508 | space_f=f, 509 | ) 510 | text_embeds = self.text_model.encoder.layer[blk_i + unfused_blocks]( 511 | text_embeds, 512 | extend_text_masks, 513 | encoder_hidden_states=(video_data_mlm), 514 | last_norm=True, 515 | )[0] 516 | video_data_mlm = fuse_video_data 517 | 518 | text_embeds = text_embeds # [:, 0] 519 | text_embeds = self.cross_modal_text_transform(text_embeds) 520 | 521 | ret.update({"cross_attn_mlm_logits": self.mlm_score(text_embeds)}) 522 | 523 | return ret 524 | 525 | def forward( 526 | self, 527 | data, 528 | n_embeds, 529 | v_embeds, 530 | allgather, 531 | n_gpu, 532 | args, 533 | config, 534 | loss_egonce, 535 | gpu, 536 | return_embeds=True, 537 | task_names="EgoNCE_ITM_MLM", 538 | ): 539 | 540 | ret = {} 541 | loss_dict = {} 542 | 543 | if "Feature_Extraction" in task_names: 544 | video_embeddings = self.compute_video(data["video"]) 545 | return video_embeddings 546 | 547 | if "EgoNCE" in task_names: 548 | 549 | ret = self.infer(data, task_names="EgoNCE") 550 | video_embeds = ret["video_embeds"] 551 | text_embeds = ret["text_embeds"] 552 | video_embeds = allgather(video_embeds, n_gpu, args) 553 | text_embeds = allgather(text_embeds, n_gpu, args) 554 | n_embeds = allgather(n_embeds, n_gpu, args) 555 | v_embeds = allgather(v_embeds, n_gpu, args) 556 | output = sim_matrix(text_embeds, video_embeds) 557 | 558 | if config["loss"]["type"] == "EgoNCE": 559 | sim_v = sim_matrix(v_embeds, v_embeds) 560 | sim_n = sim_matrix(n_embeds, n_embeds) 561 | loss, mask_bool, temp = loss_egonce(output, sim_v, sim_n) 562 | else: 563 | loss, mask_bool, temp = loss_egonce(output) 564 | 565 | ret.update( 566 | { 567 | "sim_v2t": output, 568 | "sim_t2v": output.t(), 569 | } 570 | ) 571 | 572 | loss_dict.update({"EgoNCE": loss}) 573 | 574 | # MLM 575 | if "MLM" in task_names: 576 | 577 | ret = self.infer(data, task_names="MLM", ret=ret) 578 | 579 | mlm_logits = ret["cross_attn_mlm_logits"].view(-1, 50265) 580 | mlm_labels = data["text_mlm_labels"].view(-1) 581 | 582 | mlm_logits = allgather(mlm_logits, n_gpu, args) 583 | mlm_labels = allgather(mlm_labels, n_gpu, args) 584 | 585 | loss_mlm = torch.nn.functional.cross_entropy( 586 | mlm_logits, 587 | mlm_labels, 588 | ignore_index=-100, 589 | ).mean() 590 | 591 | loss = loss + loss_mlm 592 | 593 | loss_dict.update({"loss_mlm": loss_mlm}) 594 | 595 | # ITM 596 | if "ITM" in task_names: 597 | 598 | rank = dist.get_rank() 599 | 600 | all_video = allgather(data["video"], n_gpu, args) 601 | all_text_ids = allgather(data["text"]["input_ids"], n_gpu, args) 602 | all_text_masks = allgather(data["text"]["attention_mask"], n_gpu, args) 603 | 604 | pos_len = data["video"].size(0) // 2 605 | neg_len = data["video"].size(0) - pos_len 606 | itm_labels = torch.cat([torch.ones(pos_len), torch.zeros(neg_len)]).cuda( 607 | gpu, non_blocking=True 608 | ) 609 | 610 | itm_labels = itm_labels[torch.randperm(itm_labels.size(0))] 611 | 612 | batch_size = len(itm_labels) 613 | 614 | with torch.no_grad(): 615 | weights_v2t = F.softmax( 616 | ret["sim_v2t"][batch_size * rank : batch_size * (rank + 1), 
:] 617 | / temp, 618 | dim=1, 619 | ) 620 | weights_t2v = F.softmax( 621 | ret["sim_t2v"][batch_size * rank : batch_size * (rank + 1), :] 622 | / temp, 623 | dim=1, 624 | ) 625 | 626 | weights_v2t.masked_fill_( 627 | mask_bool[batch_size * rank : batch_size * (rank + 1), :], 0 628 | ) 629 | weights_t2v.masked_fill_( 630 | mask_bool[batch_size * rank : batch_size * (rank + 1), :], 0 631 | ) 632 | 633 | data_itm = copy.deepcopy(data) 634 | 635 | for idx in range(len(itm_labels)): 636 | if itm_labels[idx] == 1: 637 | data_itm["video"][idx, :] = all_video[rank * batch_size + idx, :] 638 | data_itm["text"]["input_ids"][idx, :] = all_text_ids[ 639 | rank * batch_size + idx, : 640 | ] 641 | data_itm["text"]["attention_mask"][idx, :] = all_text_masks[ 642 | rank * batch_size + idx, : 643 | ] 644 | 645 | else: 646 | if np.random.rand() > 0.5: 647 | neg_idx = torch.multinomial(weights_t2v[idx] + 1e-9, 1).item() 648 | data_itm["video"][idx, :] = all_video[neg_idx, :] 649 | data_itm["text"]["input_ids"][idx, :] = all_text_ids[ 650 | rank * batch_size + idx, : 651 | ] 652 | data_itm["text"]["attention_mask"][idx, :] = all_text_masks[ 653 | rank * batch_size + idx, : 654 | ] 655 | else: 656 | neg_idx = torch.multinomial(weights_v2t[idx] + 1e-9, 1).item() 657 | data_itm["video"][idx, :] = all_video[ 658 | rank * batch_size + idx, : 659 | ] 660 | data_itm["text"]["input_ids"][idx, :] = all_text_ids[neg_idx, :] 661 | data_itm["text"]["attention_mask"][idx, :] = all_text_masks[ 662 | neg_idx, : 663 | ] 664 | 665 | ret = self.infer(data_itm, task_names="ITM", ret=ret) 666 | 667 | itm_logits = ret["cross_attn_itm_logits"] 668 | 669 | itm_logits = allgather(itm_logits, n_gpu, args) 670 | itm_labels = allgather(itm_labels, n_gpu, args) 671 | 672 | loss_itm = torch.nn.functional.cross_entropy( 673 | itm_logits, itm_labels.long() 674 | ).mean() 675 | 676 | loss = loss + 2 * loss_itm 677 | 678 | # print("ITM loss: ", loss_itm) 679 | loss_dict.update({"loss_itm": loss_itm}) 680 | 681 | loss_dict.update({"loss_total": loss}) 682 | 683 | return loss, loss_dict, ret 684 | 685 | def compute_text(self, text_data): 686 | if self.text_params["model"].startswith("bert"): 687 | text_embeddings = self.text_model( 688 | text_data["input_ids"], attention_mask=text_data["attention_mask"] 689 | )["pooler_output"] 690 | elif self.text_params["model"].startswith("distilbert"): 691 | text_embeddings = self.text_model(**text_data).last_hidden_state[:, 0, :] 692 | elif self.text_params["model"].startswith("roberta"): 693 | text_embeddings = self.text_model(**text_data).last_hidden_state[:, 0, :] 694 | else: 695 | raise NotImplementedError 696 | if self.config["use_checkpoint"]: 697 | text_embeddings = torch.utils.checkpoint.checkpoint( 698 | self.txt_proj, text_embeddings 699 | ) 700 | else: 701 | text_embeddings = self.txt_proj(text_embeddings) 702 | return text_embeddings 703 | 704 | def compute_text_tokens(self, text_data): 705 | if self.text_params["model"].startswith("bert"): 706 | text_embeddings = self.text_model( 707 | text_data["input_ids"], attention_mask=text_data["attention_mask"] 708 | )[ 709 | "pooler_output" 710 | ] # not implement for bert 711 | elif self.text_params["model"].startswith("distilbert"): 712 | text_embeddings = self.text_model(**text_data).last_hidden_state 713 | elif self.text_params["model"].startswith("roberta"): 714 | text_embeddings = self.text_model(**text_data).last_hidden_state 715 | else: 716 | raise NotImplementedError 717 | 718 | if self.config["use_checkpoint"]: 719 | text_embeddings = 
torch.utils.checkpoint.checkpoint( 720 | self.txt_proj, text_embeddings 721 | ) 722 | else: 723 | text_embeddings = self.txt_proj(text_embeddings) 724 | return text_embeddings 725 | 726 | def compute_video(self, video_data): 727 | video_embeddings = self.video_model(video_data) 728 | if self.config["use_checkpoint"]: 729 | video_embeddings = torch.utils.checkpoint.checkpoint( 730 | self.vid_proj, video_embeddings 731 | ) 732 | else: 733 | video_embeddings = self.vid_proj(video_embeddings) 734 | return video_embeddings 735 | 736 | def _inflate_positional_embeds(self, new_state_dict): 737 | # allow loading of timesformer with fewer num_frames 738 | curr_keys = list(self.state_dict().keys()) 739 | if ( 740 | "video_model.temporal_embed" in new_state_dict 741 | and "video_model.temporal_embed" in curr_keys 742 | ): 743 | load_temporal_embed = new_state_dict["video_model.temporal_embed"] 744 | load_num_frames = load_temporal_embed.shape[1] 745 | curr_num_frames = self.video_params["num_frames"] 746 | embed_dim = load_temporal_embed.shape[2] 747 | 748 | if load_num_frames != curr_num_frames: 749 | if load_num_frames > curr_num_frames: 750 | print( 751 | f'### loaded {self.video_params["model"]} model has MORE frames than current...' 752 | f"### loading weights, filling in the extras via {self.load_temporal_fix}" 753 | ) 754 | new_temporal_embed = load_temporal_embed[:, :curr_num_frames, :] 755 | else: 756 | print( 757 | f'### loaded {self.video_params["model"]} model has FEWER frames than current...' 758 | f"### loading weights, filling in the extras via {self.load_temporal_fix}" 759 | ) 760 | if self.load_temporal_fix == "zeros": 761 | new_temporal_embed = torch.zeros( 762 | [load_temporal_embed.shape[0], curr_num_frames, embed_dim] 763 | ) 764 | new_temporal_embed[:, :load_num_frames] = load_temporal_embed 765 | elif self.load_temporal_fix in ["interp", "bilinear"]: 766 | # interpolate 767 | # unsqueeze so pytorch thinks its an image 768 | mode = "nearest" 769 | if self.load_temporal_fix == "bilinear": 770 | mode = "bilinear" 771 | load_temporal_embed = load_temporal_embed.unsqueeze(0) 772 | new_temporal_embed = F.interpolate( 773 | load_temporal_embed, 774 | (curr_num_frames, embed_dim), 775 | mode=mode, 776 | align_corners=True, 777 | ).squeeze(0) 778 | else: 779 | raise NotImplementedError 780 | new_state_dict["video_model.temporal_embed"] = new_temporal_embed 781 | # allow loading with smaller spatial patches. assumes custom border crop, to append the 782 | # border patches to the input sequence 783 | if ( 784 | "video_model.pos_embed" in new_state_dict 785 | and "video_model.pos_embed" in curr_keys 786 | ): 787 | load_pos_embed = new_state_dict["video_model.pos_embed"] 788 | load_num_patches = load_pos_embed.shape[1] 789 | curr_pos_embed = self.state_dict()["video_model.pos_embed"] 790 | if load_num_patches != curr_pos_embed.shape[1]: 791 | raise NotImplementedError( 792 | "Loading models with different spatial resolution / patch number not yet implemented, sorry." 
793 | ) 794 | 795 | return new_state_dict 796 | 797 | 798 | def sim_matrix(a, b, eps=1e-8): 799 | """ 800 | added eps for numerical stability 801 | """ 802 | a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None] 803 | a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n)) 804 | b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n)) 805 | sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1)) 806 | return sim_mt 807 | 808 | 809 | def sim_matrix_batch_val(a, b, eps=1e-8): 810 | """ 811 | added eps for numerical stability 812 | """ 813 | a_n, b_n = a.norm(dim=-1).unsqueeze(-1), b.norm(dim=-1).unsqueeze(-1) 814 | a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n)) 815 | b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n)) 816 | sim_mt = torch.bmm(a_norm, b_norm.transpose(1, 2)) 817 | return sim_mt 818 | 819 | 820 | if __name__ == "__main__": 821 | pass 822 | -------------------------------------------------------------------------------- /model/egovlpv2/parse_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) Meta Platforms, Inc. and affiliates. 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | import os 26 | import logging 27 | import pdb 28 | from pathlib import Path 29 | from functools import reduce 30 | from operator import getitem 31 | from datetime import datetime 32 | 33 | # import pdb; pdb.set_trace() 34 | # from logger import setup_logging 35 | from model.egovlpv2.util import read_json, write_json 36 | import time 37 | import inspect 38 | 39 | 40 | class ConfigParser: 41 | def __init__(self, args, options="", timestamp=True, test=False, eval_mode=None): 42 | # parse default and custom cli options 43 | for opt in options: 44 | args.add_argument(*opt.flags, default=None, type=opt.type) 45 | args = args.parse_args() 46 | self.args = args 47 | if args.device: 48 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 49 | if args.resume is None: 50 | msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example." 
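# A fresh (non-resumed) run must be given a config explicitly; resumed runs are handled in
# the else-branch below, where eval_mode can swap in a preset such as configs/eval/epic.json
# before any explicitly passed config is overlaid on top.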
51 | assert args.config is not None, msg_no_cfg 52 | self.cfg_fname = Path(args.config) 53 | config = read_json(self.cfg_fname) 54 | self.resume = None 55 | else: 56 | self.resume = Path(args.resume) 57 | resume_cfg_fname = Path(args.config) 58 | if eval_mode == "epic": 59 | resume_cfg_fname = Path("configs/eval/epic.json") 60 | if eval_mode == "charades": 61 | resume_cfg_fname = Path("configs/eval/charades.json") 62 | if eval_mode == "nlq": 63 | resume_cfg_fname = Path("configs/eval/nlq.json") 64 | if eval_mode == "mq": 65 | resume_cfg_fname = Path("configs/eval/mq.json") 66 | 67 | config = read_json(resume_cfg_fname) 68 | if args.config is not None: 69 | config.update(read_json(Path(args.config))) 70 | 71 | # load config file and apply custom cli options 72 | self._config = _update_config(config, options, args) 73 | 74 | # set save_dir where trained model and log will be saved. 75 | # save_dir = Path(self.config['trainer']['save_dir']) 76 | save_dir = Path(args.save_dir) 77 | # timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else '' 78 | timestamp = datetime.now().strftime(r"%m%d_%H") if timestamp else "" 79 | 80 | exper_name = self.config["name"] 81 | 82 | self._save_dir = save_dir / "models" / timestamp 83 | self._web_log_dir = save_dir / "web" / timestamp 84 | self._log_dir = save_dir / "log" / timestamp 85 | self._tf_dir = save_dir / "tf" / timestamp 86 | 87 | if not test: 88 | self.save_dir.mkdir(parents=True, exist_ok=True) 89 | self.log_dir.mkdir(parents=True, exist_ok=True) 90 | self._tf_dir.mkdir(parents=True, exist_ok=True) 91 | 92 | # if set, remove all previous experiments with the current config 93 | if vars(args).get("purge_exp_dir", False): 94 | for dirpath in (self._save_dir, self._log_dir, self._web_log_dir): 95 | config_dir = dirpath.parent 96 | existing = list(config_dir.glob("*")) 97 | print(f"purging {len(existing)} directories from config_dir...") 98 | tic = time.time() 99 | os.system(f"rm -rf {config_dir}") 100 | print(f"Finished purge in {time.time() - tic:.3f}s") 101 | 102 | # save updated config file to the checkpoint dir 103 | if not test: 104 | write_json(self.config, self.save_dir / "config.json") 105 | 106 | # configure logging module 107 | # setup_logging(self.log_dir) 108 | self.log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} 109 | 110 | def initialize(self, name, module, *args, index=None, **kwargs): 111 | """ 112 | finds a function handle with the name given as 'type' in config, and returns the 113 | instance initialized with corresponding keyword args given as 'args'. 114 | """ 115 | if index is None: 116 | module_name = self[name]["type"] 117 | module_args = dict(self[name]["args"]) 118 | assert all( 119 | [k not in module_args for k in kwargs] 120 | ), "Overwriting kwargs given in config file is not allowed" 121 | module_args.update(kwargs) 122 | else: 123 | module_name = self[name][index]["type"] 124 | module_args = dict(self[name][index]["args"]) 125 | 126 | # if parameter not in config subdict, then check if it's in global config. 
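# Illustrative (hypothetical) layout of the config consumed here:
#   {"arch": {"type": "FrozenInTime", "args": {"projection_dim": 4096, ...}}}
# so that config.initialize("arch", model_module) resolves model_module.FrozenInTime and
# instantiates it with those args, plus any matching global entries or the special-cased
# `args` parameters handled below.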
127 | signature = inspect.signature(getattr(module, module_name).__init__) 128 | print(module_name) 129 | for param in signature.parameters.keys(): 130 | if param not in module_args and param in self.config: 131 | module_args[param] = self[param] 132 | if module_name == "FrozenInTime" and param == "args": 133 | module_args[param] = self.args 134 | if module_name == "MultiDistTextVideoDataLoader" and param == "args": 135 | module_args[param] = self.args 136 | 137 | return getattr(module, module_name)(*args, **module_args) 138 | 139 | def __getitem__(self, name): 140 | return self.config[name] 141 | 142 | def get_logger(self, name, verbosity=2): 143 | msg_verbosity = "verbosity option {} is invalid. Valid options are {}.".format( 144 | verbosity, self.log_levels.keys() 145 | ) 146 | assert verbosity in self.log_levels, msg_verbosity 147 | logger = logging.getLogger(name) 148 | logger.setLevel(self.log_levels[verbosity]) 149 | return logger 150 | 151 | # setting read-only attributes 152 | @property 153 | def config(self): 154 | return self._config 155 | 156 | @property 157 | def save_dir(self): 158 | return self._save_dir 159 | 160 | @property 161 | def log_dir(self): 162 | return self._log_dir 163 | 164 | @property 165 | def tf_dir(self): 166 | return self._tf_dir 167 | 168 | 169 | # helper functions used to update config dict with custom cli options 170 | def _update_config(config, options, args): 171 | for opt in options: 172 | value = getattr(args, _get_opt_name(opt.flags)) 173 | if value is not None: 174 | _set_by_path(config, opt.target, value) 175 | return config 176 | 177 | 178 | def _get_opt_name(flags): 179 | for flg in flags: 180 | if flg.startswith("--"): 181 | return flg.replace("--", "") 182 | return flags[0].replace("--", "") 183 | 184 | 185 | def _set_by_path(tree, keys, value): 186 | """Set a value in a nested object in tree by sequence of keys.""" 187 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 188 | 189 | 190 | def _get_by_path(tree, keys): 191 | """Access a nested object in tree by sequence of keys.""" 192 | return reduce(getitem, keys, tree) 193 | -------------------------------------------------------------------------------- /model/egovlpv2/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) Meta Platforms, Inc. and affiliates. 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | 25 | import json 26 | from collections import OrderedDict 27 | 28 | 29 | def state_dict_data_parallel_fix(load_state_dict, curr_state_dict): 30 | load_keys = list(load_state_dict.keys()) 31 | curr_keys = list(curr_state_dict.keys()) 32 | 33 | redo_dp = False 34 | undo_dp = False 35 | if not curr_keys[0].startswith("module.") and load_keys[0].startswith( 36 | "module." 37 | ): # this 38 | undo_dp = True 39 | elif curr_keys[0].startswith("module.") and not load_keys[0].startswith("module."): 40 | redo_dp = True 41 | 42 | if undo_dp: # this 43 | from collections import OrderedDict 44 | 45 | new_state_dict = OrderedDict() 46 | for k, v in load_state_dict.items(): 47 | name = k[7:] # remove `module.` 48 | new_state_dict[name] = v 49 | # load params 50 | elif redo_dp: 51 | from collections import OrderedDict 52 | 53 | new_state_dict = OrderedDict() 54 | for k, v in load_state_dict.items(): 55 | name = "module." + k # remove `module.` 56 | new_state_dict[name] = v 57 | else: 58 | new_state_dict = load_state_dict 59 | return new_state_dict 60 | 61 | 62 | def read_json(fname): 63 | with fname.open("rt") as handle: 64 | return json.load(handle, object_hook=OrderedDict) 65 | 66 | 67 | def write_json(content, fname): 68 | with fname.open("wt") as handle: 69 | json.dump(content, handle, indent=4, sort_keys=False) 70 | -------------------------------------------------------------------------------- /model/egovlpv2/video_transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementations of Video Transformers in PyTorch 3 | 4 | A PyTorch implementation of space-time transformer as described in 5 | 'Frozen in Time: A Joint Image and Video Encoder for End-to-End Retrieval' - https://arxiv.org/abs/2104.00650 6 | 7 | A PyTorch implementation of timesformer as described in 8 | 'Is Space-Time Attention All You Need for Video Understanding?' 
- https://arxiv.org/abs/2102.05095 9 | 10 | Acknowledgments: 11 | - This code builds on Ross Wightman's vision_transformer code in pytorch-image-models: 12 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 13 | 14 | - It is also inspired by lucidrains timesformer implementation: 15 | https://github.com/lucidrains/TimeSformer-pytorch 16 | 17 | Hacked together by Max Bain 18 | """ 19 | 20 | import os 21 | from collections import OrderedDict 22 | from functools import partial 23 | import yaml 24 | 25 | import torch 26 | from einops import rearrange, repeat 27 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 28 | from torch import einsum, nn 29 | 30 | with open(os.path.join(os.path.dirname(__file__), "EgoNCE_MLM_ITM_Config.yaml")) as f: 31 | config_yaml = yaml.load(f, Loader=yaml.FullLoader) 32 | 33 | NUM_FUSE_BLOCK = config_yaml["num_fuse_block"] 34 | DIM_TEXT = 768 35 | 36 | 37 | def attn(q, k, v): 38 | sim = einsum("b i d, b j d -> b i j", q, k) 39 | attn = sim.softmax(dim=-1) 40 | out = einsum("b i j, b j d -> b i d", attn, v) 41 | return out 42 | 43 | 44 | class Mlp(nn.Module): 45 | def __init__( 46 | self, 47 | in_features, 48 | hidden_features=None, 49 | out_features=None, 50 | act_layer=nn.GELU, 51 | drop=0.0, 52 | ): 53 | super().__init__() 54 | out_features = out_features or in_features 55 | hidden_features = hidden_features or in_features 56 | self.fc1 = nn.Linear(in_features, hidden_features) 57 | self.act = act_layer() 58 | self.fc2 = nn.Linear(hidden_features, out_features) 59 | self.drop = nn.Dropout(drop) 60 | 61 | def forward(self, x): 62 | x = self.fc1(x) 63 | x = self.act(x) 64 | x = self.drop(x) 65 | x = self.fc2(x) 66 | x = self.drop(x) 67 | return x 68 | 69 | 70 | class VideoPatchEmbed(nn.Module): 71 | """Video to Patch Embedding""" 72 | 73 | def __init__( 74 | self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=8 75 | ): 76 | super().__init__() 77 | img_size = to_2tuple(img_size) 78 | patch_size = to_2tuple(patch_size) 79 | num_patches = ( 80 | (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * num_frames 81 | ) 82 | self.img_size = img_size 83 | self.patch_size = patch_size 84 | self.num_patches = num_patches 85 | self.num_frames = num_frames 86 | self.embed_dim = embed_dim 87 | self.proj = nn.Conv2d( 88 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size 89 | ) 90 | 91 | def forward(self, x): 92 | B, F, C, H, W = x.shape 93 | assert F == self.num_frames, print(F, self.num_frames) 94 | x = x.view(-1, C, H, W) 95 | x = self.proj(x) 96 | return x 97 | 98 | 99 | class VarAttention(nn.Module): 100 | def __init__( 101 | self, 102 | dim, 103 | num_heads=8, 104 | qkv_bias=False, 105 | qk_scale=None, 106 | attn_drop=0.0, 107 | proj_drop=0.0, 108 | initialize="random", 109 | dim_text=None, 110 | norm_layer=nn.LayerNorm, 111 | space_attn=True, 112 | ): 113 | super().__init__() 114 | self.num_heads = num_heads 115 | head_dim = dim // num_heads 116 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 117 | self.scale = qk_scale or head_dim**-0.5 118 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 119 | self.proj = nn.Linear(dim, dim) 120 | if initialize == "zeros": 121 | self.qkv.weight.data.fill_(0) 122 | self.qkv.bias.data.fill_(0) 123 | # fill proj weight with 1 here to improve training dynamics. Otherwise temporal attention inputs 124 | # are multiplied by 0*0, which is hard for the model to move out of. 
125 | self.proj.weight.data.fill_(1) 126 | self.proj.bias.data.fill_(0) 127 | self.attn_drop = nn.Dropout(attn_drop) 128 | self.proj_drop = nn.Dropout(proj_drop) 129 | 130 | self.softmax = nn.Softmax(dim=-1) 131 | 132 | if dim_text is not None and space_attn: 133 | self.qkv_text_i2t = nn.Linear(dim_text, dim * 2, bias=qkv_bias) 134 | self.qkv_i2t = nn.Linear(dim, dim, bias=qkv_bias) 135 | self.attn_drop_i2t = nn.Dropout(attn_drop) 136 | self.proj_i2t = nn.Linear(dim, dim) 137 | self.proj_drop_i2t = nn.Dropout(proj_drop) 138 | self.alpha_i2t = nn.Parameter(torch.Tensor([0])) 139 | self.norm_i2t_i = norm_layer(dim) 140 | 141 | def forward(self, x, einops_from, einops_to, y=None, y_mask=None, **einops_dims): 142 | h = self.num_heads 143 | # project x to q, k, v vaalues 144 | q, k, v = self.qkv(x).chunk(3, dim=-1) 145 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 146 | 147 | q = q * self.scale 148 | 149 | # splice out CLS token at index 1 150 | (cls_q, q_), (cls_k, k_), (cls_v, v_) = map( 151 | lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v) 152 | ) 153 | 154 | # let CLS token attend to key / values of all patches across time and space 155 | cls_out = attn(cls_q, k, v) 156 | # rearrange across time or space 157 | q_, k_, v_ = map( 158 | lambda t: rearrange(t, f"{einops_from} -> {einops_to}", **einops_dims), 159 | (q_, k_, v_), 160 | ) 161 | 162 | # expand cls token keys and values across time or space and concat 163 | r = q_.shape[0] // cls_k.shape[0] 164 | cls_k, cls_v = map( 165 | lambda t: repeat(t, "b () d -> (b r) () d", r=r), (cls_k, cls_v) 166 | ) 167 | 168 | k_ = torch.cat((cls_k, k_), dim=1) 169 | v_ = torch.cat((cls_v, v_), dim=1) 170 | 171 | # attention 172 | out = attn(q_, k_, v_) 173 | 174 | # merge back time or space 175 | out = rearrange(out, f"{einops_to} -> {einops_from}", **einops_dims) 176 | 177 | # concat back the cls token 178 | out = torch.cat((cls_out, out), dim=1) 179 | 180 | # merge back the heads 181 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 182 | ## to out 183 | x = self.proj(out) 184 | x = self.proj_drop(x) 185 | 186 | if y is not None: 187 | B_, N, C = x.shape 188 | B_text, N_text, C_text = y.shape 189 | 190 | kv_text = ( 191 | self.qkv_text_i2t(y) 192 | .reshape(B_text, N_text, 2, self.num_heads, C // self.num_heads) 193 | .permute(2, 0, 3, 1, 4) 194 | ) 195 | k_text, v_text = kv_text[0], kv_text[1] 196 | 197 | q_i2t = self.qkv_i2t(self.norm_i2t_i(x)) 198 | q_i2t = q_i2t.reshape( 199 | B_, N, 1, self.num_heads, C // self.num_heads 200 | ).permute(2, 0, 3, 1, 4) 201 | q_i2t = q_i2t[0] 202 | 203 | # image to text attention 204 | text_scale = k_text.size(-1) ** -0.5 205 | q_i2t = q_i2t * text_scale 206 | attn_i2t = q_i2t @ k_text.transpose(-2, -1) # B_, nH, N, N_text 207 | 208 | # add image to text bias and text_mask 209 | if y_mask is not None: 210 | mask_and_i2t_bias = y_mask.view(B_text, 1, 1, N_text) 211 | attn_i2t = attn_i2t + mask_and_i2t_bias 212 | 213 | attn_i2t = self.softmax(attn_i2t) 214 | attn_i2t = self.attn_drop_i2t(attn_i2t) 215 | y = (attn_i2t @ v_text).transpose(1, 2).reshape(B_, N, C) 216 | y = self.proj_i2t(y) 217 | y = self.proj_drop_i2t(y) 218 | x = x + self.alpha_i2t * y 219 | 220 | return x 221 | 222 | 223 | class SpaceTimeBlock(nn.Module): 224 | 225 | def __init__( 226 | self, 227 | dim, 228 | num_heads, 229 | mlp_ratio=4.0, 230 | qkv_bias=False, 231 | qk_scale=None, 232 | drop=0.0, 233 | attn_drop=0.0, 234 | drop_path=0.0, 235 | act_layer=nn.GELU, 236 | norm_layer=nn.LayerNorm, 237 | 
time_init="zeros", 238 | attention_style="frozen-in-time", 239 | dim_text=None, 240 | ): 241 | super().__init__() 242 | self.norm1 = norm_layer(dim) 243 | self.attn = VarAttention( 244 | dim, 245 | num_heads=num_heads, 246 | qkv_bias=qkv_bias, 247 | qk_scale=qk_scale, 248 | attn_drop=attn_drop, 249 | proj_drop=drop, 250 | dim_text=dim_text, 251 | norm_layer=norm_layer, 252 | space_attn=True, 253 | ) 254 | 255 | self.timeattn = VarAttention( 256 | dim, 257 | num_heads=num_heads, 258 | qkv_bias=qkv_bias, 259 | qk_scale=qk_scale, 260 | attn_drop=attn_drop, 261 | proj_drop=drop, 262 | initialize=time_init, 263 | dim_text=dim_text, 264 | norm_layer=norm_layer, 265 | space_attn=False, 266 | ) 267 | 268 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 269 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 270 | self.norm2 = norm_layer(dim) 271 | mlp_hidden_dim = int(dim * mlp_ratio) 272 | self.mlp = Mlp( 273 | in_features=dim, 274 | hidden_features=mlp_hidden_dim, 275 | act_layer=act_layer, 276 | drop=drop, 277 | ) 278 | self.norm3 = norm_layer(dim) 279 | 280 | self.attention_style = attention_style 281 | 282 | def forward( 283 | self, 284 | x, 285 | einops_from_space, 286 | einops_to_space, 287 | einops_from_time, 288 | einops_to_time, 289 | time_n, 290 | space_f, 291 | y=None, 292 | y_mask=None, 293 | ): 294 | 295 | time_output = self.timeattn( 296 | self.norm3(x), 297 | einops_from_time, 298 | einops_to_time, 299 | n=time_n, 300 | y=None, 301 | y_mask=None, 302 | ) 303 | time_residual = x + time_output 304 | space_output = self.attn( 305 | self.norm1(time_residual), 306 | einops_from_space, 307 | einops_to_space, 308 | f=space_f, 309 | y=y, 310 | y_mask=y_mask, 311 | ) 312 | if self.attention_style == "frozen-in-time": 313 | space_residual = x + self.drop_path(space_output) 314 | else: 315 | raise NotImplementedError 316 | 317 | x = space_residual + self.drop_path(self.mlp(self.norm2(space_residual))) 318 | 319 | return x 320 | 321 | 322 | class SpaceTimeTransformer(nn.Module): 323 | """Vision Transformer 324 | 325 | A PyTorch impl of : `Space-Time Transformer` from Frozen-in-time - by Max Bain. 326 | https://arxiv.org/abs/2104.00650 327 | 328 | Based off: 329 | - ViT implementation from the timm library [https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py] 330 | lucidrains timesformer implementation [https://github.com/lucidrains/TimeSformer-pytorch]. 
331 | 332 | Notable differences: 333 | - allows for variable length input frames (<= num_frames) 334 | - allows for variable length input resolution (<= (img_size, img_size)) [UNTESTED] 335 | - different attention block mechanism 336 | """ 337 | 338 | def __init__( 339 | self, 340 | img_size=224, 341 | patch_size=16, 342 | in_chans=3, 343 | num_classes=1000, 344 | embed_dim=768, 345 | depth=12, 346 | num_heads=12, 347 | mlp_ratio=4.0, 348 | qkv_bias=True, 349 | qk_scale=None, 350 | representation_size=None, 351 | drop_rate=0.0, 352 | attn_drop_rate=0.0, 353 | drop_path_rate=0.0, 354 | hybrid_backbone=None, 355 | num_frames=8, 356 | time_init="rand", 357 | attention_style="frozen-in-time", 358 | norm_layer=nn.LayerNorm, 359 | dim_text=None, 360 | ): 361 | """ 362 | Args: 363 | img_size (int, tuple): input image size 364 | patch_size (int, tuple): patch size 365 | in_chans (int): number of input channels 366 | num_classes (int): number of classes for classification head 367 | embed_dim (int): embedding dimension 368 | depth (int): depth of transformer 369 | num_heads (int): number of attention heads 370 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 371 | qkv_bias (bool): enable bias for qkv if True 372 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 373 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 374 | drop_rate (float): dropout rate 375 | attn_drop_rate (float): attention dropout rate 376 | drop_path_rate (float): stochastic depth rate 377 | hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module 378 | norm_layer: (nn.Module): normalization layer 379 | num_frames: (int) maximum number of frames expected as input 380 | time_init: (str) how to initialise the time attention layer, 'zeros' allows for the timesformer to start off 381 | as ViT. 382 | attention_style: (str) how to attend to space and time. 
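        Example (illustrative; mirrors the __main__ check at the bottom of this file):
            model = SpaceTimeTransformer(num_frames=4)
            video = torch.rand(3, 4, 3, 224, 224)  # (batch, frames, channels, height, width)
            out = model(video)                      # (3, num_classes) with the default classifier head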
383 | """ 384 | super().__init__() 385 | self.num_classes = num_classes 386 | self.num_features = self.embed_dim = ( 387 | embed_dim # num_features for consistency with other models 388 | ) 389 | self.num_frames = num_frames 390 | self.embed_dim = embed_dim 391 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 392 | print("######USING ATTENTION STYLE: ", attention_style) 393 | if hybrid_backbone is not None: 394 | raise NotImplementedError("hybrid backbone not implemented") 395 | else: 396 | self.patch_embed = VideoPatchEmbed( 397 | img_size=img_size, 398 | patch_size=patch_size, 399 | in_chans=in_chans, 400 | embed_dim=embed_dim, 401 | num_frames=num_frames, 402 | ) 403 | num_patches = self.patch_embed.num_patches 404 | self.patches_per_frame = num_patches // num_frames 405 | 406 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 407 | self.pos_embed = nn.Parameter( 408 | torch.zeros(1, self.patches_per_frame + 1, embed_dim) 409 | ) # remember to take pos_embed[1:] for tiling over time 410 | self.temporal_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim)) 411 | 412 | self.pos_drop = nn.Dropout(p=drop_rate) 413 | 414 | dpr = [ 415 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 416 | ] # stochastic depth decay rule 417 | self.blocks = nn.ModuleList( 418 | [ 419 | SpaceTimeBlock( 420 | dim=embed_dim, 421 | num_heads=num_heads, 422 | mlp_ratio=mlp_ratio, 423 | qkv_bias=qkv_bias, 424 | qk_scale=qk_scale, 425 | drop=drop_rate, 426 | attn_drop=attn_drop_rate, 427 | drop_path=dpr[i], 428 | norm_layer=norm_layer, 429 | time_init=time_init, 430 | attention_style=attention_style, 431 | dim_text=None if i < 6 else DIM_TEXT, 432 | ) 433 | for i in range(depth) 434 | ] 435 | ) 436 | self.norm = norm_layer(embed_dim) 437 | 438 | # Representation layer 439 | if representation_size: 440 | self.num_features = representation_size 441 | self.pre_logits = nn.Sequential( 442 | OrderedDict( 443 | [ 444 | ("fc", nn.Linear(embed_dim, representation_size)), 445 | ("act", nn.Tanh()), 446 | ] 447 | ) 448 | ) 449 | else: 450 | self.pre_logits = nn.Identity() 451 | 452 | # Classifier head 453 | self.head = ( 454 | nn.Linear(self.num_features, num_classes) 455 | if num_classes > 0 456 | else nn.Identity() 457 | ) 458 | 459 | trunc_normal_(self.pos_embed, std=0.02) 460 | trunc_normal_(self.cls_token, std=0.02) 461 | 462 | # if num_frames > 1, then we perform ViT inflation and initialise time attention to zero so not necessary. 
463 | if num_frames == 1: 464 | self.apply(self._init_weights) 465 | 466 | ## einops transformations 467 | self.einops_from_space = "b (f n) d" 468 | self.einops_to_space = "(b f) n d" 469 | self.einops_from_time = "b (f n) d" 470 | self.einops_to_time = "(b n) f d" 471 | 472 | def _init_weights(self, m): 473 | if isinstance(m, nn.Linear): 474 | trunc_normal_(m.weight, std=0.02) 475 | if isinstance(m, nn.Linear) and m.bias is not None: 476 | nn.init.constant_(m.bias, 0) 477 | elif isinstance(m, nn.LayerNorm): 478 | nn.init.constant_(m.bias, 0) 479 | nn.init.constant_(m.weight, 1.0) 480 | 481 | @torch.jit.ignore 482 | def no_weight_decay(self): 483 | return {"pos_embed", "cls_token"} 484 | 485 | def get_classifier(self): 486 | return self.head 487 | 488 | def reset_classifier(self, num_classes, global_pool=""): 489 | self.num_classes = num_classes 490 | self.head = ( 491 | nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 492 | ) 493 | 494 | def forward_features(self, x): 495 | b, curr_frames, channels, _, _ = x.shape 496 | x = self.patch_embed(x) 497 | x = x.flatten(2).transpose(2, 1) 498 | x = x.reshape(b, -1, self.patch_embed.embed_dim) 499 | 500 | BF = x.shape[0] 501 | cls_tokens = self.cls_token.expand( 502 | BF, -1, -1 503 | ) # stole cls_tokens impl from Phil Wang, thanks 504 | x = torch.cat((cls_tokens, x), dim=1) 505 | # positional embed needs to be tiled for each frame (this does [1,2,3] --> [1,2,3,1,2,3]...) 506 | cls_embed = self.pos_embed[:, 0, :].unsqueeze(1) 507 | tile_pos_embed = self.pos_embed[:, 1:, :].repeat(1, self.num_frames, 1) 508 | # temporal embed needs to be repeated within each frame (this does [1,2,3] --> [1,1,1,2,2,2,3,3,3]...) 509 | tile_temporal_embed = self.temporal_embed.repeat_interleave( 510 | self.patches_per_frame, 1 511 | ) 512 | total_pos_embed = tile_pos_embed + tile_temporal_embed 513 | total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1) 514 | 515 | curr_patches = x.shape[1] 516 | x = x + total_pos_embed[:, :curr_patches] 517 | x = self.pos_drop(x) 518 | n = self.patches_per_frame 519 | f = curr_frames 520 | 521 | for blk in self.blocks: 522 | if config_yaml["use_checkpoint"]: 523 | 524 | def create_custom_forward(module): 525 | def custom_forward(*inputs): 526 | return module(*inputs, time_n=n, space_f=f) 527 | 528 | return custom_forward 529 | 530 | x = torch.utils.checkpoint.checkpoint( 531 | create_custom_forward(blk), 532 | x, 533 | self.einops_from_space, 534 | self.einops_to_space, 535 | self.einops_from_time, 536 | self.einops_to_time, 537 | ) 538 | else: 539 | x = blk( 540 | x, 541 | self.einops_from_space, 542 | self.einops_to_space, 543 | self.einops_from_time, 544 | self.einops_to_time, 545 | time_n=n, 546 | space_f=f, 547 | ) 548 | 549 | x = self.norm(x)[:, 0] 550 | x = self.pre_logits(x) 551 | 552 | return x 553 | 554 | def forward(self, x): 555 | x = self.forward_features(x) 556 | x = self.head(x) 557 | return x 558 | 559 | 560 | if __name__ == "__main__": 561 | network = SpaceTimeTransformer(num_frames=4) 562 | data = torch.rand((3, 4, 3, 224, 224)) 563 | network(data) 564 | -------------------------------------------------------------------------------- /model/egovlpv2/video_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | import torch 5 | 6 | from torchvision import transforms 7 | from torchvision.transforms._transforms_video import NormalizeVideo 8 | 9 | 10 | def sample_frames(num_frames, 
vlen, sample="rand", fix_start=None): 11 | acc_samples = min(num_frames, vlen) 12 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int) 13 | ranges = [] 14 | for idx, interv in enumerate(intervals[:-1]): 15 | ranges.append((interv, intervals[idx + 1] - 1)) 16 | if sample == "rand": 17 | frame_idxs = [random.choice(range(x[0], x[1])) for x in ranges] 18 | elif fix_start is not None: 19 | frame_idxs = [x[0] + fix_start for x in ranges] 20 | elif sample == "uniform": 21 | frame_idxs = [(x[0] + x[1]) // 2 for x in ranges] 22 | elif sample == "middle_repeat": 23 | frame_idxs = [vlen // 2] * num_frames 24 | else: 25 | raise NotImplementedError 26 | 27 | return frame_idxs 28 | 29 | 30 | def read_frames_cv2(video_path, num_frames, sample="rand", fix_start=None): 31 | cap = cv2.VideoCapture(video_path) 32 | assert cap.isOpened() 33 | vlen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 34 | # get indexes of sampled frames 35 | frame_idxs = sample_frames(num_frames, vlen, sample=sample, fix_start=fix_start) 36 | frames = [] 37 | success_idxs = [] 38 | for index in frame_idxs: 39 | cap.set(cv2.CAP_PROP_POS_FRAMES, index - 1) 40 | ret, frame = cap.read() 41 | if ret: 42 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 43 | frame = torch.from_numpy(frame) 44 | # (H x W x C) to (C x H x W) 45 | frame = frame.permute(2, 0, 1) 46 | frames.append(frame) 47 | success_idxs.append(index) 48 | else: 49 | pass 50 | print(frame_idxs, " fail ", index, f" (vlen {vlen})") 51 | 52 | frames = torch.stack(frames).float() / 255 53 | cap.release() 54 | return frames, success_idxs 55 | 56 | 57 | class FrameLoader: 58 | def __init__(self, num_frames, method="rand", fix_start=None): 59 | self.num_frames = num_frames 60 | self.method = method 61 | self.fix_start = fix_start 62 | 63 | normalize = NormalizeVideo( 64 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) 65 | ) 66 | self.transforms = transforms.Compose( 67 | [ 68 | transforms.Resize(256), 69 | transforms.CenterCrop(256), 70 | transforms.Resize(224), 71 | normalize, 72 | ] 73 | ) 74 | 75 | def __call__(self, video_path): 76 | frames, success_idxs = read_frames_cv2( 77 | video_path, self.num_frames, sample=self.method, fix_start=self.fix_start 78 | ) 79 | 80 | if self.num_frames > 1: 81 | frames = frames.transpose(0, 1) # [T, C, H, W] ---> [C, T, H, W] 82 | frames = self.transforms(frames) 83 | frames = frames.transpose(0, 1) # recover 84 | else: 85 | frames = self.transforms(frames) 86 | 87 | return frames 88 | -------------------------------------------------------------------------------- /model/models.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import importlib 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from model.utils import FrameLoader, pre_caption 9 | 10 | TEXT_MAX_WORDS = 30 11 | 12 | 13 | def temporal_sample(embeddings, num_frames): 14 | if embeddings.ndim == 1: 15 | return embeddings 16 | len_emb = len(embeddings) 17 | if num_frames == 1: 18 | # take the middle frame 19 | return embeddings[len_emb // 2] 20 | else: 21 | # for now just averaging, could be improved to sample num_frames frames 22 | return embeddings.mean(0) 23 | 24 | 25 | def forward_blip( 26 | modality, 27 | model, 28 | tokenizer, 29 | ref_img, 30 | caption, 31 | video_path, 32 | frame_loader, 33 | modality_fusion, 34 | num_query_frames, 35 | query_frame_method, 36 | use_precomputed=True, 37 | device="cuda", 38 | ): 39 | if not use_precomputed or ( 40 | 
modality == "visual-text" and modality_fusion == "crossattn" 41 | ): 42 | ref_img = frame_loader(video_path.as_posix()).to(device) 43 | if query_frame_method == "sample": 44 | cross_embed = [] 45 | for i in range(len(ref_img)): 46 | embed = forward_blip_crossattn(model, ref_img[i], caption) 47 | cross_embed.append(embed) 48 | cross_embed = torch.stack(cross_embed, dim=0) 49 | cross_embed = cross_embed.mean(0) 50 | return cross_embed 51 | 52 | if modality == "visual": 53 | return ( 54 | temporal_sample(ref_img, num_query_frames) 55 | if use_precomputed 56 | else forward_blip_visual(model, ref_img) 57 | ) 58 | elif modality == "visual-text": 59 | if modality_fusion == "crossattn": 60 | # can't use pre-extracted embeddings for crossattn fusion 61 | return forward_blip_crossattn(model, ref_img, caption) 62 | elif modality_fusion == "avg": 63 | if use_precomputed: 64 | if isinstance(caption, torch.Tensor): 65 | text_embed = caption 66 | else: 67 | text_embed = forward_blip_text(model, caption) 68 | ref_img = temporal_sample(ref_img, num_query_frames) 69 | return (ref_img + text_embed) / 2 70 | else: 71 | return forward_blip_visual_text_avg(model, ref_img, caption) 72 | else: 73 | raise NotImplementedError(f"Fusion {modality_fusion} not implemented") 74 | elif modality == "text": 75 | if use_precomputed and isinstance(caption, torch.Tensor): 76 | return caption 77 | return forward_blip_text(model, caption) 78 | else: 79 | raise NotImplementedError(f"Modality {modality} not implemented") 80 | 81 | 82 | def forward_blip_visual(model, ref_img): 83 | ref_img = ref_img.unsqueeze(0) 84 | 85 | model.eval() 86 | ref_img_embs = model.visual_encoder(ref_img) 87 | query_feat = F.normalize(model.vision_proj(ref_img_embs[:, 0, :]), dim=-1) 88 | 89 | query_feat = query_feat.squeeze(0) 90 | return query_feat 91 | 92 | 93 | def forward_blip_text(model, caption, device="cuda"): 94 | caption = pre_caption(caption, TEXT_MAX_WORDS) 95 | text = model.tokenizer( 96 | caption, 97 | padding="longest", 98 | truncation=True, 99 | max_length=64, 100 | return_tensors="pt", 101 | ).to(device) 102 | 103 | # Shift encoder 104 | query_embs = model.text_encoder( 105 | text.input_ids, 106 | attention_mask=text.attention_mask, 107 | return_dict=True, 108 | mode="text", 109 | ) 110 | query_feat = query_embs.last_hidden_state[:, 0, :] 111 | query_feat = F.normalize(model.text_proj(query_feat), dim=-1) 112 | 113 | query_feat = query_feat.squeeze(0) 114 | return query_feat 115 | 116 | 117 | def forward_blip_crossattn(model, ref_img, caption): 118 | ref_img = ref_img.unsqueeze(0) 119 | 120 | model.eval() 121 | device = ref_img.device 122 | 123 | ref_img_embs = model.visual_encoder(ref_img) 124 | ref_img_atts = torch.ones(ref_img_embs.size()[:-1], dtype=torch.long).to(device) 125 | 126 | caption = pre_caption(caption, TEXT_MAX_WORDS) 127 | text = model.tokenizer( 128 | caption, 129 | padding="longest", 130 | truncation=True, 131 | max_length=64, 132 | return_tensors="pt", 133 | ).to(device) 134 | 135 | # Shift encoder 136 | encoder_input_ids = text.input_ids.clone() 137 | encoder_input_ids[:, 0] = model.tokenizer.enc_token_id 138 | query_embs = model.text_encoder( 139 | encoder_input_ids, 140 | attention_mask=text.attention_mask, 141 | encoder_hidden_states=ref_img_embs, 142 | encoder_attention_mask=ref_img_atts, 143 | return_dict=True, 144 | ) 145 | query_feat = query_embs.last_hidden_state[:, 0, :] 146 | query_feat = F.normalize(model.text_proj(query_feat), dim=-1) 147 | 148 | query_feat = query_feat.squeeze(0) 149 | return 
query_feat 150 | 151 | 152 | def forward_blip_visual_text_avg(model, ref_img, caption): 153 | ref_img = ref_img.unsqueeze(0) 154 | 155 | model.eval() 156 | device = ref_img.device 157 | 158 | # visual forward 159 | ref_img_embs = model.visual_encoder(ref_img) 160 | query_feat_vis = F.normalize(model.vision_proj(ref_img_embs[:, 0, :]), dim=-1) 161 | 162 | caption = pre_caption(caption, TEXT_MAX_WORDS) 163 | text = model.tokenizer( 164 | caption, 165 | padding="longest", 166 | truncation=True, 167 | max_length=64, 168 | return_tensors="pt", 169 | ).to(device) 170 | 171 | # Shift encoder 172 | query_text_embs = model.text_encoder( 173 | text.input_ids, 174 | attention_mask=text.attention_mask, 175 | return_dict=True, 176 | mode="text", 177 | ) 178 | query_feat_txt = query_text_embs.last_hidden_state[:, 0, :] 179 | query_feat_txt = F.normalize(model.text_proj(query_feat_txt), dim=-1) 180 | 181 | query_feat = (query_feat_vis + query_feat_txt) / 2 182 | query_feat = query_feat.squeeze(0) 183 | return query_feat 184 | 185 | 186 | def forward_egovlpv2( 187 | modality, 188 | model, 189 | tokenizer, 190 | ref_img, 191 | caption, 192 | video_path, 193 | frame_loader, 194 | modality_fusion, 195 | num_query_frames, 196 | query_frame_method, 197 | use_precomputed=True, 198 | device="cuda", 199 | ): 200 | if not use_precomputed or ( 201 | modality == "visual-text" and modality_fusion == "crossattn" 202 | ): 203 | ref_img = frame_loader(video_path.as_posix()).to(device) 204 | 205 | if modality == "visual": 206 | return ref_img if use_precomputed else forward_egovlpv2_visual(model, ref_img) 207 | 208 | elif modality == "visual-text": 209 | if modality_fusion == "crossattn": 210 | # can't use pre-extracted embeddings for crossattn fusion 211 | raise NotImplementedError("Crossattn fusion not implemented for EgoVLPv2") 212 | elif modality_fusion == "avg": 213 | if use_precomputed: 214 | if isinstance(caption, torch.Tensor): 215 | text_embed = caption 216 | else: 217 | text_embed = forward_egovlpv2_text(model, tokenizer, caption) 218 | return (ref_img + text_embed) / 2 219 | else: 220 | text_embed = forward_egovlpv2_text(model, tokenizer, caption) 221 | video_embed = forward_egovlpv2_visual(model, ref_img) 222 | return (video_embed + text_embed) / 2 223 | else: 224 | raise NotImplementedError(f"Fusion {modality_fusion} not implemented") 225 | elif modality == "text": 226 | if use_precomputed and isinstance(caption, torch.Tensor): 227 | return caption 228 | return forward_egovlpv2_text(model, tokenizer, caption) 229 | else: 230 | raise NotImplementedError(f"Modality {modality} not implemented") 231 | 232 | 233 | def forward_egovlpv2_text(model, tokenizer, caption, device="cuda"): 234 | text = tokenizer(caption, return_tensors="pt", padding=True, truncation=True) 235 | text = {key: val.cuda(device) for key, val in text.items()} 236 | text_embed = model.compute_text(text) 237 | text_embed /= text_embed.norm(dim=-1, keepdim=True) 238 | return text_embed 239 | 240 | 241 | def forward_egovlpv2_visual(model, ref_img): 242 | ref_img = ref_img.unsqueeze(0) 243 | video_embed = model.compute_video(ref_img) 244 | video_embed /= video_embed.norm(dim=-1, keepdim=True) 245 | video_embed = video_embed.squeeze(0) 246 | return video_embed 247 | 248 | 249 | def forward_clip( 250 | modality, 251 | model, 252 | tokenizer, 253 | ref_img, 254 | caption, 255 | video_path, 256 | frame_loader, 257 | modality_fusion, 258 | num_query_frames, 259 | query_frame_method, 260 | use_precomputed=True, 261 | device="cuda", 262 | ): 263 | if not 
use_precomputed: 264 | ref_img = frame_loader(video_path.as_posix()).to(device) 265 | if query_frame_method == "sample": 266 | raise NotImplementedError 267 | 268 | if modality == "visual": 269 | return ( 270 | temporal_sample(ref_img, num_query_frames) 271 | if use_precomputed 272 | else forward_clip_visual(model, ref_img) 273 | ) 274 | elif modality == "visual-text": 275 | if modality_fusion == "crossattn": 276 | # can't use pre-extracted embeddings for crossattn fusion 277 | raise NotImplementedError("Crossattn fusion not implemented for CLIP") 278 | elif modality_fusion == "avg": 279 | if use_precomputed: 280 | if isinstance(caption, torch.Tensor): 281 | text_embed = caption 282 | else: 283 | text_embed = forward_clip_text(model, tokenizer, caption, device) 284 | ref_img = temporal_sample(ref_img, num_query_frames) 285 | return (ref_img + text_embed) / 2 286 | else: 287 | text_embed = forward_clip_text(model, tokenizer, caption, device) 288 | video_embed = forward_clip_visual(model, ref_img) 289 | return (video_embed + text_embed) / 2 290 | else: 291 | raise NotImplementedError(f"Fusion {modality_fusion} not implemented") 292 | 293 | elif modality == "text": 294 | if use_precomputed and isinstance(caption, torch.Tensor): 295 | return caption 296 | return forward_clip_text(model, tokenizer, caption, device) 297 | else: 298 | raise NotImplementedError(f"Modality {modality} not implemented") 299 | 300 | 301 | def forward_clip_text(model, tokenizer, caption, device="cuda"): 302 | text = tokenizer(caption).to(device) 303 | with torch.cuda.amp.autocast(): 304 | text_embed = model.encode_text(text) 305 | text_embed /= text_embed.norm(dim=-1, keepdim=True) 306 | 307 | text_embed = text_embed.squeeze(0) 308 | text_embed = text_embed.float() 309 | return text_embed 310 | 311 | 312 | def forward_clip_visual(model, ref_img): 313 | with torch.cuda.amp.autocast(): 314 | video_embed = model.encode_image(ref_img) 315 | video_embed /= video_embed.norm(dim=-1, keepdim=True) 316 | video_embed = video_embed.float() 317 | 318 | return video_embed 319 | 320 | 321 | def init_EgoVLPv2(checkpoint_path, device="cuda", no_temporal=False, small_proj=False): 322 | from model.egovlpv2.model import FrozenInTime 323 | from model.egovlpv2.video_utils import FrameLoader as EgoFrameLoader 324 | import transformers 325 | 326 | video_params = { 327 | "model": "SpaceTimeTransformer", 328 | "arch_config": "base_patch16_224", 329 | "num_frames": 16, 330 | "pretrained": True, 331 | "time_init": "zeros", 332 | } 333 | text_params = {"model": "roberta-base", "pretrained": True, "input": "text"} 334 | 335 | tokenizer = transformers.AutoTokenizer.from_pretrained( 336 | "roberta-base", TOKENIZERS_PARALLELISM=False 337 | ) 338 | projection = "small" if small_proj else "default" 339 | 340 | model = FrozenInTime( 341 | video_params, 342 | text_params, 343 | projection_dim=4096, 344 | load_checkpoint=checkpoint_path, 345 | projection=projection, 346 | load_temporal_fix="bilinear", 347 | task_names="EgoNCE_ITM_MLM", 348 | norm_layer=None, 349 | embed_dim=768, 350 | ) 351 | model = model.to(device) 352 | 353 | if no_temporal: 354 | frame_method = "middle_repeat" 355 | else: 356 | frame_method = "uniform" 357 | frame_loader = EgoFrameLoader(16, method=frame_method) 358 | model.eval() 359 | return model, frame_loader, tokenizer 360 | 361 | 362 | class SimpleEgoVLPDataset(torch.utils.data.Dataset): 363 | def __init__(self, video_paths, frame_loader, transform=None): 364 | self.video_paths = video_paths 365 | self.frame_loader = 
frame_loader 366 | self.transform = transform 367 | # remove extension and only filename 368 | self.video_ids = [Path(p).stem for p in video_paths] 369 | 370 | def __len__(self): 371 | return len(self.video_paths) 372 | 373 | def __getitem__(self, idx): 374 | video_path = self.video_paths[idx] 375 | video_id = self.video_ids[idx] 376 | frames = self.frame_loader(video_path) 377 | if self.transform: 378 | frames = self.transform(frames) 379 | return video_id, idx, frames 380 | 381 | 382 | def init_BLIP(checkpoint_path, query_frame_method, num_query_frames, device="cuda"): 383 | from model.blip.model import blip_cir, BLIPCir 384 | from model.blip.transforms import transform_test 385 | 386 | config_path = os.path.join(os.path.dirname(__file__), "blip", "med_config.json") 387 | 388 | model = BLIPCir( 389 | med_config=config_path, 390 | image_size=384, 391 | vit="large", 392 | vit_grad_ckpt=True, 393 | vit_ckpt_layer=12, 394 | embed_dim=256, 395 | train_vit=False, 396 | loss=None, 397 | ) 398 | model = blip_cir(model, checkpoint_path) 399 | model = model.to(device) 400 | 401 | transform = transform_test(384) 402 | # frame loader for query videos. "middle" or "sample" 403 | frame_loader = FrameLoader( 404 | transform=transform, method=query_frame_method, frames_video=num_query_frames 405 | ) 406 | model.eval() 407 | return model, frame_loader, None 408 | 409 | 410 | def init_CLIP(query_frame_method, num_query_frames, device="cuda"): 411 | import open_clip 412 | 413 | model, _, preprocess = open_clip.create_model_and_transforms( 414 | "ViT-L-14", pretrained="datacomp_xl_s13b_b90k" 415 | ) 416 | tokenizer = open_clip.get_tokenizer("ViT-L-14") 417 | 418 | frame_loader = FrameLoader( 419 | transform=preprocess, method=query_frame_method, frames_video=num_query_frames 420 | ) 421 | 422 | model = model.to(device) 423 | model.eval() 424 | return model, frame_loader, tokenizer 425 | 426 | 427 | class LanguageBindFrameLoader: 428 | def __init__(self, transform): 429 | self.transform = transform 430 | 431 | def __call__(self, video_path): 432 | if isinstance(video_path, Path): 433 | video_path = video_path.as_posix() 434 | return self.transform([video_path]) 435 | 436 | 437 | class LanguageBindTensorDivider(torch.nn.Module): 438 | def __init__(self, value=255.0): 439 | super().__init__() 440 | self.value = value 441 | 442 | def forward(self, x: torch.Tensor) -> torch.Tensor: 443 | return x / self.value 444 | 445 | 446 | def languagebind_middle_frame_processor( 447 | video_path, 448 | transform, 449 | video_decode_backend="opencv", 450 | clip_start_sec=0.0, 451 | clip_end_sec=None, 452 | num_frames=8, 453 | ): 454 | # Use repeated middle frame 455 | cv2 = importlib.import_module("cv2") 456 | cv2_vr = cv2.VideoCapture(video_path) 457 | duration = int(cv2_vr.get(cv2.CAP_PROP_FRAME_COUNT)) 458 | frame_idx = duration // 2 459 | 460 | video_data = [] 461 | cv2_vr.set(1, frame_idx) 462 | _, frame = cv2_vr.read() 463 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 464 | cv2_vr.release() 465 | 466 | frame = torch.from_numpy(frame).permute(2, 0, 1) 467 | video_data = [frame] * num_frames 468 | video_data = torch.stack(video_data, dim=1) 469 | video_outputs = transform(video_data) 470 | 471 | return video_outputs 472 | 473 | 474 | def init_languagebind( 475 | device="cuda", variant="LanguageBind_Video_FT", no_temporal=False 476 | ): 477 | from languagebind import ( 478 | LanguageBind, 479 | transform_dict, 480 | LanguageBindVideoTokenizer, 481 | ) 482 | 483 | tokenizer = 
LanguageBindVideoTokenizer.from_pretrained(f"LanguageBind/{variant}") 484 | model = LanguageBind( 485 | clip_type={"video": variant}, 486 | ).to(device) 487 | model.eval() 488 | video_transform = transform_dict["video"](model.modality_config["video"]) 489 | # patch the video transform to not use random horizontal flipping 490 | video_transform.transform.transforms = video_transform.transform.transforms[:4] 491 | video_transform.transform.transforms[0] = LanguageBindTensorDivider(255.0) 492 | 493 | if no_temporal: 494 | video_transform.image_processor = languagebind_middle_frame_processor 495 | frame_loader = LanguageBindFrameLoader(video_transform) 496 | 497 | return model, frame_loader, tokenizer 498 | 499 | 500 | def forward_languagebind_text(model, tokenizer, caption, device="cuda"): 501 | to_device = getattr(importlib.import_module("languagebind"), "to_device") 502 | 503 | data = { 504 | "language": to_device( 505 | tokenizer( 506 | [caption], 507 | max_length=77, 508 | padding="max_length", 509 | truncation=True, 510 | return_tensors="pt", 511 | ), 512 | device, 513 | ) 514 | } 515 | embeddings = model(data) 516 | return embeddings["language"][0] 517 | 518 | 519 | def forward_languagebind_visual(model, ref_img, device="cuda"): 520 | to_device = getattr(importlib.import_module("languagebind"), "to_device") 521 | 522 | if ref_img["pixel_values"].ndim == 6: 523 | ref_img["pixel_values"] = ref_img["pixel_values"].squeeze(0) 524 | 525 | data = {"video": to_device(ref_img, device)} 526 | embeddings = model(data) 527 | return embeddings["video"][0] 528 | 529 | 530 | def forward_languagebind( 531 | modality, 532 | model, 533 | tokenizer, 534 | ref_img, 535 | caption, 536 | video_path, 537 | frame_loader, 538 | modality_fusion, 539 | num_query_frames, 540 | query_frame_method, 541 | use_precomputed=True, 542 | device="cuda", 543 | ): 544 | 545 | if not use_precomputed: 546 | ref_img = frame_loader(video_path.as_posix()) 547 | 548 | if modality == "visual": 549 | return ( 550 | ref_img 551 | if use_precomputed 552 | else forward_languagebind_visual(model, ref_img, device=device) 553 | ) 554 | 555 | elif modality == "visual-text": 556 | if modality_fusion == "crossattn": 557 | raise NotImplementedError( 558 | "Crossattn fusion not implemented for LanguageBind" 559 | ) 560 | elif modality_fusion == "avg": 561 | if use_precomputed: 562 | if isinstance(caption, torch.Tensor): 563 | text_embed = caption 564 | else: 565 | text_embed = forward_languagebind_text( 566 | model, tokenizer, caption, device=device 567 | ) 568 | if ref_img.ndim == 2: 569 | ref_img = ref_img.squeeze(0) 570 | return (ref_img + text_embed) / 2 571 | else: 572 | text_embed = forward_languagebind_text( 573 | model, tokenizer, caption, device=device 574 | ) 575 | video_embed = forward_languagebind_visual(model, ref_img, device=device) 576 | return (video_embed + text_embed) / 2 577 | else: 578 | raise NotImplementedError(f"Fusion {modality_fusion} not implemented") 579 | elif modality == "text": 580 | if use_precomputed and isinstance(caption, torch.Tensor): 581 | return caption 582 | return forward_languagebind_text(model, tokenizer, caption, device=device) 583 | else: 584 | raise NotImplementedError(f"Modality {modality} not implemented") 585 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from: https://github.com/lucas-ventura/CoVR 3 | 4 | MIT License 5 | 6 | Copyright (c) 
2023 Lucas Ventura 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. 25 | """ 26 | 27 | import re 28 | from typing import Union 29 | 30 | import torch 31 | 32 | 33 | def pre_caption(caption, max_words=50): 34 | caption = re.sub( 35 | r"([.!\"()*#:;~])", 36 | " ", 37 | caption.lower(), 38 | ) 39 | caption = re.sub( 40 | r"\s{2,}", 41 | " ", 42 | caption, 43 | ) 44 | caption = caption.rstrip("\n") 45 | caption = caption.strip(" ") 46 | 47 | # truncate caption 48 | caption_words = caption.split(" ") 49 | if len(caption_words) > max_words: 50 | caption = " ".join(caption_words[:max_words]) 51 | 52 | return caption 53 | 54 | 55 | def remove_non_digits(string, sub: str = ""): 56 | return int(re.sub(r"\D", sub, string)) 57 | 58 | 59 | def get_middle_frame(reference_vid_pth): 60 | from pathlib import Path 61 | 62 | import cv2 63 | import numpy as np 64 | from PIL import Image 65 | 66 | reference_vid_pth = str(reference_vid_pth) 67 | 68 | if not Path(reference_vid_pth).exists(): 69 | print(f"Video {reference_vid_pth} does not exist") 70 | return Image.fromarray(np.zeros((384, 384, 3)).astype(np.uint8)) 71 | 72 | # use OpenCV to read the video 73 | cap = cv2.VideoCapture(reference_vid_pth) 74 | 75 | # get the total number of frames in the video 76 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 77 | 78 | # calculate the index of the middle frame 79 | middle_frame_index = total_frames // 2 80 | 81 | # set the current frame index to the middle frame index 82 | cap.set(cv2.CAP_PROP_POS_FRAMES, middle_frame_index) 83 | 84 | # read the middle frame 85 | ret, frame = cap.read() 86 | 87 | if not ret or frame is None: 88 | print(f"Video {reference_vid_pth} is corrupted") 89 | return Image.fromarray(np.zeros((384, 384, 3)).astype(np.uint8)) 90 | 91 | # convert the frame from BGR to RGB using OpenCV 92 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 93 | 94 | # create a PIL Image object from the middle frame 95 | pil_image = Image.fromarray(frame) 96 | 97 | return pil_image 98 | 99 | 100 | def get_random_frame(reference_vid_pth): 101 | from pathlib import Path 102 | 103 | import cv2 104 | import numpy as np 105 | from PIL import Image 106 | 107 | reference_vid_pth = str(reference_vid_pth) 108 | 109 | if not Path(reference_vid_pth).exists(): 110 | print(f"Video {reference_vid_pth} does not exist") 111 | return Image.fromarray(np.zeros((384, 384, 3)).astype(np.uint8)) 112 | 113 | # use OpenCV to read the video 114 | cap = cv2.VideoCapture(reference_vid_pth) 115 | 
116 | # get the total number of frames in the video 117 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 118 | 119 | # calculate the index of random frame 120 | random_frame_index = np.random.randint(0, total_frames) 121 | 122 | # set the current frame index to the random frame index 123 | cap.set(cv2.CAP_PROP_POS_FRAMES, random_frame_index) 124 | 125 | # read the frame 126 | ret, frame = cap.read() 127 | 128 | if not ret or frame is None: 129 | print(f"Video {reference_vid_pth} is corrupted") 130 | return Image.fromarray(np.zeros((384, 384, 3)).astype(np.uint8)) 131 | 132 | # convert the frame from BGR to RGB using OpenCV 133 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 134 | 135 | # create a PIL Image object from the middle frame 136 | pil_image = Image.fromarray(frame) 137 | 138 | return pil_image 139 | 140 | 141 | def sample_frames(frames_videos, vlen): 142 | import numpy as np 143 | 144 | acc_samples = min(frames_videos, vlen) 145 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int) 146 | ranges = [] 147 | for idx, interv in enumerate(intervals[:-1]): 148 | ranges.append((interv, intervals[idx + 1] - 1)) 149 | 150 | frame_idxs = [(x[0] + x[1]) // 2 for x in ranges] 151 | 152 | return frame_idxs 153 | 154 | 155 | class FrameLoader: 156 | def __init__(self, transform, frames_video=1, method="middle"): 157 | self.transform = transform 158 | self.method = method 159 | 160 | if method == "middle": 161 | self.get_frame = get_middle_frame 162 | assert frames_video == 1, "frames_video must be 1 for middle frame method" 163 | elif method == "random": 164 | self.get_frame = get_random_frame 165 | assert frames_video == 1, "frames_video must be 1 for random frame method" 166 | elif method == "sample": 167 | assert frames_video > 1, "frames_video must be > 1 for sample frame method" 168 | self.frames_video = frames_video 169 | else: 170 | raise ValueError(f"Invalid method: {method}") 171 | 172 | def __call__(self, video_pth: str): 173 | if self.method == "sample": 174 | frames = self.get_video_frames(video_pth, 0.0, None) 175 | return torch.stack(frames) 176 | else: 177 | return self.transform(self.get_frame(video_pth)) 178 | 179 | def get_video_frames( 180 | self, 181 | video_pth: str, 182 | start_time: float = 0.0, 183 | end_time: Union[float, None] = None, 184 | ) -> list: 185 | import cv2 186 | from PIL import Image 187 | 188 | cap = cv2.VideoCapture(video_pth) 189 | 190 | fps = cap.get(cv2.CAP_PROP_FPS) 191 | if end_time is not None: 192 | start_frame = int(fps * start_time) 193 | end_frame = int(fps * end_time) 194 | vlen = end_frame - start_frame 195 | else: 196 | start_frame = 0 197 | vlen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 198 | 199 | frame_idxs = sample_frames(self.frames_video, vlen) 200 | frame_idxs = [frame_idx + start_frame for frame_idx in frame_idxs] 201 | if self.frames_video != len(frame_idxs): 202 | frame_idxs = (frame_idxs * self.frames_video)[: self.frames_video] 203 | print(f"Video {video_pth} has less than {self.frames_video} frames") 204 | 205 | frames = [] 206 | for index in frame_idxs: 207 | cap.set(cv2.CAP_PROP_POS_FRAMES, index - 1) 208 | ret, frame = cap.read() 209 | if not ret: 210 | break 211 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 212 | frames.append(Image.fromarray(frame_rgb).convert("RGB")) 213 | 214 | cap.release() 215 | 216 | if len(frames) > 0: 217 | video_data = [self.transform(frame) for frame in frames] 218 | return video_data 219 | else: 220 | raise ValueError(f"video path: {video_pth} error.") 221 | 
--------------------------------------------------------------------------------
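
For orientation, the sketch below shows how the `init_CLIP` / `forward_clip_*` helpers in `model/models.py` and the `FrameLoader` in `model/utils.py` are intended to compose into a single composed-video-retrieval query embedding. It is a minimal illustration, not part of the repository: the video path, caption, and random gallery embeddings are placeholders, and the explicit batch dimension added before `forward_clip_visual` is an assumption about the expected input shape.

```python
# Illustrative usage sketch (not part of the repository).
# Assumes open_clip is installed, a CUDA device is available, and that
# "data/clips/query.mp4" is a placeholder path to a local query clip.
from pathlib import Path

import torch
import torch.nn.functional as F

from model.models import init_CLIP, forward_clip_text, forward_clip_visual

device = "cuda"

# Middle-frame query representation (query_frame_method="middle", one frame).
model, frame_loader, tokenizer = init_CLIP(
    query_frame_method="middle", num_query_frames=1, device=device
)

video_path = Path("data/clips/query.mp4")         # placeholder query clip
caption = "cut the onion instead of the tomato"   # placeholder modification text

# FrameLoader with method="middle" returns one preprocessed frame of shape (3, H, W).
frame = frame_loader(video_path.as_posix()).to(device)

with torch.no_grad():
    # Assumption: encode_image expects a batched tensor, so a batch
    # dimension is added here before calling the helper.
    video_embed = forward_clip_visual(model, frame.unsqueeze(0)).squeeze(0)
    text_embed = forward_clip_text(model, tokenizer, caption, device=device)

# Simple "avg" fusion of the two modalities, mirroring
# forward_clip(..., modality="visual-text", modality_fusion="avg").
query_embed = F.normalize((video_embed + text_embed) / 2, dim=-1)

# Retrieval is a cosine-similarity ranking against L2-normalized gallery
# embeddings; the random tensor below is only a stand-in for precomputed
# gallery embeddings of shape (num_gallery_videos, embed_dim).
gallery_embeds = F.normalize(
    torch.randn(100, query_embed.shape[-1], device=device), dim=-1
)
ranking = (gallery_embeds @ query_embed).argsort(descending=True)
```

The same pattern applies to the other backbones: `init_EgoVLPv2`, `init_BLIP`, and `init_languagebind` each return a `(model, frame_loader, tokenizer)` triple, and the corresponding `forward_*` function handles the `visual`, `text`, and `visual-text` (avg-fusion) query modalities.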