├── .gitignore ├── .vscode └── launch.json ├── LICENSE ├── README.md ├── imgs ├── exp1.jpg ├── exp2.jpg ├── framework.png ├── intro.pdf └── introduction.png ├── main_retrieval.py ├── preprocess └── compress_video.py ├── requirements.txt ├── test_activitynet.sh ├── test_didemo.sh ├── test_lsmdc.sh ├── test_msrvtt.sh ├── test_msvd.sh ├── train_activitynet.sh ├── train_didemo.sh ├── train_lsmdc.sh ├── train_msrvtt.sh ├── train_msvd.sh └── tvr ├── __init__.py ├── __pycache__ └── __init__.cpython-39.pyc ├── dataloaders ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── data_dataloaders.cpython-39.pyc │ ├── dataloader_activitynet_retrieval.cpython-39.pyc │ ├── dataloader_didemo_retrieval.cpython-39.pyc │ ├── dataloader_lsmdc_retrieval.cpython-39.pyc │ ├── dataloader_msrvtt_retrieval.cpython-39.pyc │ ├── dataloader_msvd_retrieval.cpython-39.pyc │ ├── dataloader_retrieval.cpython-39.pyc │ ├── rand_augment.cpython-39.pyc │ ├── random_erasing.cpython-39.pyc │ ├── rawvideo_util.cpython-39.pyc │ └── video_transforms.cpython-39.pyc ├── data_dataloaders.py ├── dataloader_activitynet_retrieval.py ├── dataloader_didemo_retrieval.py ├── dataloader_lsmdc_retrieval.py ├── dataloader_msrvtt_retrieval.py ├── dataloader_msvd_retrieval.py ├── dataloader_retrieval.py ├── rand_augment.py ├── random_erasing.py ├── rawvideo_util.py └── video_transforms.py ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── modeling.cpython-39.pyc │ ├── module_clip.cpython-39.pyc │ ├── module_cross.cpython-39.pyc │ ├── module_transformer.cpython-39.pyc │ ├── optimization.cpython-39.pyc │ ├── query_cross_att.cpython-39.pyc │ ├── tokenization_clip.cpython-39.pyc │ ├── transformer.cpython-39.pyc │ ├── transformer_block.cpython-39.pyc │ └── until_module.cpython-39.pyc ├── bpe_simple_vocab_16e6.txt.gz ├── modeling.py ├── modeling_old.py ├── module.py ├── module_clip.py ├── module_cross.py ├── module_transformer.py ├── optimization.py ├── tokenization_clip.py ├── transformer_block.py └── until_module.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-39.pyc ├── comm.cpython-39.pyc ├── logger.cpython-39.pyc ├── metric_logger.cpython-39.pyc └── metrics.cpython-39.pyc ├── comm.py ├── logger.py ├── metric_logger.py └── metrics.py /.gitignore: -------------------------------------------------------------------------------- 1 | /data 2 | /ckpt 3 | /PromptSwitch 4 | /figs 5 | /tvr/models/ViT-B-*.pt 6 | a.py 7 | draw.ipynb 8 | query_vis.py 9 | query_vis.sh 10 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "/home/zhanghaonan/anaconda3/envs/ic/lib/python3.9/site-packages/torch/distributed/launch.py", 12 | "console": "integratedTerminal", 13 | "justMyCode": true, 14 | "env": { 15 | "CUDA_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7,8,9" 16 | }, 17 | "args": [ 18 | "--master_port=2505", 19 | "--nproc_per_node=4", 20 | "main_retrieval.py", 21 | "--do_train=1", 22 | "--workers=4", 23 | "--n_display=50", 24 | "--epochs=5", 25 | "--lr=1e-4", 26 | "--coef_lr=1e-3", 27 | "--batch_size=128", 28 | "--batch_size_val=64", 29 | "--anno_path=/mnt/nfs/CMG/zhanghaonan/datasets/MSVD/anns", 30 |
"--video_path=/mnt/nfs/CMG/zhanghaonan/datasets/MSVD/MSVD_Videos", 31 | "--datatype=msvd", 32 | "--max_words=32", 33 | "--max_frames=12", 34 | "--video_framerate=1", 35 | "--output_dir=ckpt/msvd/main_exp/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_3cross_add_query_sim_query_shared_cross_att_shared_without_weight", 36 | "--center=1", 37 | "--temp=3", 38 | "--alpha=0.0001", 39 | "--beta=0.005", 40 | ] 41 | } 42 | ] 43 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # _GLSCL_: Text-Video Retrieval with Global-Local Semantic Consistent Learning (TIP 2025) 3 | 4 | Haonan Zhang, Pengpeng Zeng, Lianli Gao, Jingkuan Song, Yihang Duan, Xinyu Lyu, Heng Tao Shen 5 | 6 | 7 | 8 | This is the official code implementation of the paper "Text-Video Retrieval with Global-Local Semantic Consistent Learning", the checkpoint and feature will be released soon. 
9 | 10 | We are continuously refactoring our code; please be patient and wait for the latest updates! 11 | 12 | ## 🔥 Updates 13 | 14 | - [ ] Release the pre-trained weights and datasets. 15 | - [x] Release the training and evaluation code. 16 | 17 | ## ✨ Overview 18 | Adapting large-scale image-text pre-training models, _e.g._, CLIP, to the video domain represents the current state-of-the-art for text-video retrieval. The primary approaches involve transferring text-video pairs to a common embedding space and leveraging cross-modal interactions on specific entities for semantic alignment. Though effective, these paradigms entail prohibitive computational costs, leading to inefficient retrieval. To address this, we propose a simple yet effective method, Global-Local Semantic Consistent Learning (GLSCL), which capitalizes on latent shared semantics across modalities for text-video retrieval. Specifically, we introduce a parameter-free global interaction module to explore coarse-grained alignment. Then, we devise a shared local interaction module that employs several learnable queries to capture latent semantic concepts for learning fine-grained alignment. Moreover, we propose an inter-consistency loss and an intra-diversity loss to ensure the similarity and diversity of these concepts across and within modalities, respectively. 19 | 20 |
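A minimal PyTorch sketch of the two auxiliary objectives is given below. It is illustrative only — the function names, tensor shapes, and exact formulations are assumptions rather than the actual implementation in `tvr/models/modeling.py` — but it conveys the idea: the inter-consistency term pulls the concept features of a matched text-video pair together (here with an MSE penalty, weighted by the `--alpha` flag used in the training scripts), while the intra-diversity term pushes the concepts produced by different learnable queries within one modality apart (here with a margin on their pairwise cosine similarity, weighted by `--beta`).

```python
import torch
import torch.nn.functional as F

def inter_consistency_loss(text_concepts, video_concepts):
    """Pull matched text/video concept features together (illustrative).

    Both tensors are assumed to have shape (batch, num_queries, dim).
    """
    return F.mse_loss(text_concepts, video_concepts)

def intra_diversity_loss(concepts, margin=0.1):
    """Push the concepts captured by different queries of one sample apart (illustrative)."""
    c = F.normalize(concepts, dim=-1)                   # (B, Q, D)
    sim = torch.einsum("bqd,bkd->bqk", c, c)            # (B, Q, Q) pairwise cosine similarity
    sim = sim - torch.eye(c.size(1), device=c.device)   # zero out the self-similarity diagonal
    return F.relu(sim - margin).mean()                  # penalize query pairs that are too similar

# Sketch of the overall objective; alpha / beta correspond to the --alpha / --beta flags:
# loss = retrieval_loss + alpha * inter_consistency + beta * (diversity_text + diversity_video)
```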

21 | [Performance comparison figure: imgs/introduction.png] 22 | Figure 1. Performance comparison of the retrieval results (R@1) and computational costs (FLOPs) for text-to-video retrieval models. 23 |

24 | 25 | 26 | ## 🍀 Method 27 | [Framework figure: imgs/framework.png] 28 | 29 | Figure 2. Overview of the proposed GLSCL for text-video retrieval. 30 |
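The parameter-free global interaction shown in Figure 2 can be pictured with the following sketch. It is a simplification under stated assumptions (masked mean pooling of CLIP frame features plus cosine similarity), not the repository's actual module, but it illustrates why the coarse-grained branch introduces no extra parameters.

```python
import torch
import torch.nn.functional as F

def global_interaction(text_emb, frame_embs, video_mask):
    """Parameter-free coarse-grained text-video alignment (illustrative sketch).

    text_emb:   (B_t, D)    sentence-level CLIP text features
    frame_embs: (B_v, T, D) frame-level CLIP visual features
    video_mask: (B_v, T)    1 for valid frames, 0 for padded frames
    Returns a (B_t, B_v) similarity matrix.
    """
    mask = video_mask.unsqueeze(-1).float()
    # Masked mean pooling over frames -> one embedding per video.
    video_emb = (frame_embs * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-6)
    text_emb = F.normalize(text_emb, dim=-1)
    video_emb = F.normalize(video_emb, dim=-1)
    return text_emb @ video_emb.t()  # cosine similarities, no learnable weights
```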

31 | 32 | ## ⚙️ Usage 33 | ### Requirements 34 | The GLSCL framework depends on the following main packages: 35 | - torch==1.8.1+cu114 36 | - Transformers 4.6.1 37 | - OpenCV 4.5.3 38 | - tqdm 39 | 40 | ### Datasets 41 | We train our model on the ```MSR-VTT-9k```, ```MSVD```, ```DiDeMo```, ```LSMDC```, and ```ActivityNet``` datasets. Please refer to this [repo](https://github.com/jpthu17/DiCoSA) for data preparation. 42 | 43 | ### How to Run (take *MSR-VTT* for example) 44 | 45 | For simple training on MSR-VTT-9k with default hyperparameters: 46 | ``` 47 | bash train_msrvtt.sh 48 | ``` 49 | or run in the terminal directly: 50 | ``` 51 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ 52 | python -m torch.distributed.launch \ 53 | --master_port 2513 \ 54 | --nproc_per_node=8 \ 55 | main_retrieval.py \ 56 | --do_train 1 \ 57 | --workers 8 \ 58 | --n_display 50 \ 59 | --epochs 5 \ 60 | --lr 1e-4 \ 61 | --coef_lr 1e-3 \ 62 | --batch_size 128 \ 63 | --batch_size_val 64 \ 64 | --anno_path ANNOTATION_PATH \ 65 | --video_path YOUR_RAW_VIDEO_PATH \ 66 | --datatype msrvtt \ 67 | --max_words 32 \ 68 | --max_frames 12 \ 69 | --video_framerate 1 \ 70 | --output_dir YOUR_SAVE_PATH \ 71 | --center 1 \ 72 | --temp 3 \ 73 | --alpha 0.0001 \ 74 | --beta 0.005 \ 75 | --query_number 8 \ 76 | --base_encoder ViT-B/32 \ 77 | --cross_att_layer 3 \ 78 | --query_share 1 \ 79 | --cross_att_share 1 \ 80 | --loss2_weight 0.5 \ 81 | ``` 82 | ### How to Evaluate (take *MSR-VTT* for example) 83 | 84 | For simple testing on MSR-VTT-9k with default hyperparameters: 85 | ``` 86 | bash test_msrvtt.sh 87 | ``` 88 | or run in the terminal directly: 89 | ``` 90 | CUDA_VISIBLE_DEVICES=0,1 \ 91 | python -m torch.distributed.launch \ 92 | --master_port 2503 \ 93 | --nproc_per_node=2 \ 94 | main_retrieval.py \ 95 | --do_eval 1 \ 96 | --workers 8 \ 97 | --n_display 50 \ 98 | --epochs 5 \ 99 | --lr 1e-4 \ 100 | --coef_lr 1e-3 \ 101 | --batch_size 64 \ 102 | --anno_path ANNOTATION_PATH \ 103 | --video_path YOUR_RAW_VIDEO_PATH \ 104 | --datatype msrvtt \ 105 | --max_words 32 \ 106 | --max_frames 12 \ 107 | --video_framerate 1 \ 108 | --output_dir YOUR_SAVE_PATH \ 109 | --center 1 \ 110 | --temp 3 \ 111 | --alpha 0.0001 \ 112 | --beta 0.005 \ 113 | --query_number 8 \ 114 | --base_encoder ViT-B/32 \ 115 | --cross_att_layer 3 \ 116 | --query_share 1 \ 117 | --cross_att_share 1 \ 118 | --loss2_weight 0.5 \ 119 | --init_model YOUR_CKPT_FILE 120 | ``` 121 | 122 | ## 🧪 Experiments 123 |

124 | [Experiment results: see imgs/exp1.jpg and imgs/exp2.jpg] 125 | 126 | 127 | 128 | 129 |
130 | 131 | ## 📚 Citation 132 | 133 | ```bibtex 134 | @inproceedings{GLSCL, 135 | author = {Haonan Zhang and Pengpeng Zeng and Lianli Gao and Jingkuan Song and Yihang Duan and Xinyu Lyu and Heng Tao Shen}, 136 | title = {Text-Video Retrieval with Global-Local Semantic Consistent Learning}, 137 | year = {2024} 138 | } 139 | ``` 140 | -------------------------------------------------------------------------------- /imgs/exp1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/imgs/exp1.jpg -------------------------------------------------------------------------------- /imgs/exp2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/imgs/exp2.jpg -------------------------------------------------------------------------------- /imgs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/imgs/framework.png -------------------------------------------------------------------------------- /imgs/intro.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/imgs/intro.pdf -------------------------------------------------------------------------------- /imgs/introduction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/imgs/introduction.png -------------------------------------------------------------------------------- /preprocess/compress_video.py: -------------------------------------------------------------------------------- 1 | """ 2 | Used to compress video in: https://github.com/ArrowLuo/CLIP4Clip 3 | Author: ArrowLuo 4 | """ 5 | import os 6 | import argparse 7 | import ffmpeg 8 | import subprocess 9 | import time 10 | import multiprocessing 11 | from multiprocessing import Pool 12 | import shutil 13 | try: 14 | from psutil import cpu_count 15 | except: 16 | from multiprocessing import cpu_count 17 | # multiprocessing.freeze_support() 18 | 19 | def compress(paras): 20 | input_video_path, output_video_path = paras 21 | try: 22 | command = ['ffmpeg', 23 | '-y', # (optional) overwrite output file if it exists 24 | '-i', input_video_path, 25 | '-filter:v', 26 | 'scale=\'if(gt(a,1),trunc(oh*a/2)*2,224)\':\'if(gt(a,1),224,trunc(ow*a/2)*2)\'', # scale to 224 27 | '-map', '0:v', 28 | '-r', '3', # frames per second 29 | output_video_path, 30 | ] 31 | ffmpeg = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 32 | out, err = ffmpeg.communicate() 33 | retcode = ffmpeg.poll() 34 | # print something above for debug 35 | except Exception as e: 36 | raise e 37 | 38 | def prepare_input_output_pairs(input_root, output_root): 39 | input_video_path_list = [] 40 | output_video_path_list = [] 41 | for root, dirs, files in os.walk(input_root): 42 | for file_name in files: 43 | input_video_path = os.path.join(root, file_name) 44 | output_video_path = os.path.join(output_root, file_name) 45 | if os.path.exists(output_video_path) and os.path.getsize(output_video_path) > 0: 46 | pass 47 | else: 48 | input_video_path_list.append(input_video_path) 49 |
output_video_path_list.append(output_video_path) 50 | return input_video_path_list, output_video_path_list 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser(description='Compress video for speed-up') 54 | parser.add_argument('--input_root', type=str, help='input root') 55 | parser.add_argument('--output_root', type=str, help='output root') 56 | args = parser.parse_args() 57 | 58 | input_root = args.input_root 59 | output_root = args.output_root 60 | 61 | assert input_root != output_root 62 | 63 | if not os.path.exists(output_root): 64 | os.makedirs(output_root, exist_ok=True) 65 | 66 | input_video_path_list, output_video_path_list = prepare_input_output_pairs(input_root, output_root) 67 | 68 | print("Total video need to process: {}".format(len(input_video_path_list))) 69 | num_works = cpu_count() 70 | print("Begin with {}-core logical processor.".format(num_works)) 71 | 72 | pool = Pool(num_works) 73 | data_dict_list = pool.map(compress, 74 | [(input_video_path, output_video_path) for 75 | input_video_path, output_video_path in 76 | zip(input_video_path_list, output_video_path_list)]) 77 | pool.close() 78 | pool.join() 79 | 80 | print("Compress finished, wait for checking files...") 81 | for input_video_path, output_video_path in zip(input_video_path_list, output_video_path_list): 82 | if os.path.exists(input_video_path): 83 | if os.path.exists(output_video_path) is False or os.path.getsize(output_video_path) < 1.: 84 | shutil.copyfile(input_video_path, output_video_path) 85 | print("Copy and replace file: {}".format(output_video_path)) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | decord 2 | pandas 3 | ftfy 4 | regex 5 | tqdm 6 | opencv-python 7 | functional 8 | timm 9 | # torch==1.8.1+cu102 10 | # torchvision==0.9.1+cu102 11 | -------------------------------------------------------------------------------- /test_activitynet.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,8,9 \ 2 | python -m torch.distributed.launch \ 3 | --master_port 2503 \ 4 | --nproc_per_node=4 \ 5 | main_retrieval.py \ 6 | --do_eval 1 \ 7 | --workers 8 \ 8 | --n_display 50 \ 9 | --epochs 5 \ 10 | --lr 1e-4 \ 11 | --coef_lr 1e-3 \ 12 | --batch_size 128 \ 13 | --batch_size_val 128 \ 14 | --anno_path /mnt/nfs/CMG/zhanghaonan/datasets/activitynet/anns \ 15 | --video_path /mnt/nfs/CMG/zhanghaonan/datasets/activitynet/Videos/Activity_Videos \ 16 | --datatype activity \ 17 | --max_words 64 \ 18 | --max_frames 64 \ 19 | --video_framerate 1 \ 20 | --output_dir ckpt/activitynet/main_exp/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_3cross_add_query_sim_query_shared_cross_att_shared_without_weight \ 21 | --center 1 \ 22 | --query_number 8 \ 23 | --cross_att_layer 3 \ 24 | --query_share 1 \ 25 | --cross_att_share 1 \ 26 | --add_query_score_for_eval 0 \ 27 | --base_encoder ViT-B/32 \ 28 | --temp 3 \ 29 | --alpha 0.0001 \ 30 | --beta 0.005 \ 31 | --t2v_beta 50 \ 32 | --v2t_beta 50 \ 33 | --init_model ckpt/activitynet/main_exp/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_3cross_add_query_sim_query_shared_cross_att_shared_without_weight/pytorch_model.bin.step900.5 -------------------------------------------------------------------------------- /test_didemo.sh: -------------------------------------------------------------------------------- 1 | 
CUDA_VISIBLE_DEVICES=8,9 \ 2 | python -m torch.distributed.launch \ 3 | --master_port 2503 \ 4 | --nproc_per_node=2 \ 5 | main_retrieval.py \ 6 | --do_eval 1 \ 7 | --workers 8 \ 8 | --n_display 50 \ 9 | --epochs 5 \ 10 | --lr 1e-4 \ 11 | --coef_lr 1e-3 \ 12 | --batch_size 32 \ 13 | --batch_size_val 32 \ 14 | --anno_path /mnt/nfs/CMG/zhanghaonan/datasets/Didemo/anns \ 15 | --video_path /mnt/nfs/CMG/zhanghaonan/datasets/Didemo/videos \ 16 | --datatype didemo \ 17 | --max_words 64 \ 18 | --max_frames 64 \ 19 | --video_framerate 1 \ 20 | --output_dir ckpt/didemo/main_exp/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_3cross_add_query_sim_query_shared_cross_att_shared_without_weight \ 21 | --center 1 \ 22 | --query_number 8 \ 23 | --cross_att_layer 3 \ 24 | --query_share 1 \ 25 | --cross_att_share 1 \ 26 | --add_query_score_for_eval 0 \ 27 | --base_encoder ViT-B/32 \ 28 | --temp 3 \ 29 | --alpha 0.0001 \ 30 | --beta 0.005 \ 31 | --init_model ckpt/didemo/main_exp/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_3cross_add_query_sim_query_shared_cross_att_shared_without_weight/pytorch_model.bin.step1200.4 -------------------------------------------------------------------------------- /test_lsmdc.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/test_lsmdc.sh -------------------------------------------------------------------------------- /test_msrvtt.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=8,9 \ 2 | python -m torch.distributed.launch \ 3 | --master_port 2503 \ 4 | --nproc_per_node=2 \ 5 | main_retrieval.py \ 6 | --do_eval 1 \ 7 | --workers 8 \ 8 | --n_display 50 \ 9 | --epochs 5 \ 10 | --lr 1e-4 \ 11 | --coef_lr 1e-3 \ 12 | --batch_size 64 \ 13 | --batch_size_val 64 \ 14 | --anno_path /mnt/nfs/CMG/zhanghaonan/datasets/MSR-VTT/anns \ 15 | --video_path /mnt/nfs/CMG/zhanghaonan/datasets/MSR-VTT/MSRVTT_Videos \ 16 | --datatype msrvtt \ 17 | --max_words 32 \ 18 | --max_frames 12 \ 19 | --video_framerate 1 \ 20 | --output_dir ckpt/msrvtt/ablation/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_3cross_add_query_sim_query_shared_cross_att_shared_without_weight \ 21 | --center 1 \ 22 | --query_number 8 \ 23 | --cross_att_layer 3 \ 24 | --query_share 0 \ 25 | --cross_att_share 1 \ 26 | --add_query_score_for_eval 1 \ 27 | --base_encoder ViT-B/32 \ 28 | --temp 3 \ 29 | --alpha 0.0001 \ 30 | --beta 0.005 \ 31 | --init_model ckpt/msrvtt/ablation/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_3cross_add_query_sim_query_shared_cross_att_shared_without_weight/pytorch_model.bin.step2850.4 \ 32 | --loss2_weight 0.5 \ 33 | -------------------------------------------------------------------------------- /test_msvd.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/test_msvd.sh -------------------------------------------------------------------------------- /train_activitynet.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ 2 | python -m torch.distributed.launch \ 3 | --master_port 2502 \ 4 | --nproc_per_node=8 \ 5 | main_retrieval.py \ 6 | --do_train 1 \ 7 | --workers 8 \ 8 | --n_display 10 \ 9 | --epochs 10 \ 10 | --lr 1e-4 \ 11 | --coef_lr 1e-3 \ 12 | 
--batch_size 64 \ 13 | --batch_size_val 64 \ 14 | --anno_path /mnt/nfs/CMG/zhanghaonan/datasets/activitynet/anns \ 15 | --video_path /mnt/nfs/CMG/zhanghaonan/datasets/activitynet/Videos/Activity_Videos \ 16 | --datatype activity \ 17 | --max_words 64 \ 18 | --max_frames 64 \ 19 | --video_framerate 1 \ 20 | --output_dir ckpt/activitynet/main_exp/v2 \ 21 | --center 1 \ 22 | --temp 3 \ 23 | --alpha 0.0001 \ 24 | --beta 0.005 \ 25 | --t2v_beta 50 \ 26 | --v2t_beta 50 \ 27 | --query_number 12 \ 28 | --base_encoder ViT-B/32 \ 29 | --cross_att_layer 3 \ 30 | --query_share 1 \ 31 | --cross_att_share 1 \ 32 | --loss2_weight 0.5 \ -------------------------------------------------------------------------------- /train_didemo.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ 2 | python -m torch.distributed.launch \ 3 | --master_port 2502 \ 4 | --nproc_per_node=4 \ 5 | main_retrieval.py \ 6 | --do_train 1 \ 7 | --workers 8 \ 8 | --n_display 50 \ 9 | --epochs 5 \ 10 | --lr 1e-4 \ 11 | --coef_lr 1e-3 \ 12 | --batch_size 32 \ 13 | --batch_size_val 32 \ 14 | --anno_path /mnt/nfs/CMG/zhanghaonan/datasets/Didemo/anns \ 15 | --video_path /mnt/nfs/CMG/zhanghaonan/datasets/Didemo/videos \ 16 | --datatype didemo \ 17 | --max_words 64 \ 18 | --max_frames 64 \ 19 | --video_framerate 1 \ 20 | --output_dir ckpt/didemo/main_exp/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_3cross_add_query_sim_query_shared_cross_att_shared_without_weight \ 21 | --center 1 \ 22 | --temp 3 \ 23 | --alpha 0.0001 \ 24 | --beta 0.005 \ 25 | --query_number 8 \ 26 | --base_encoder ViT-B/32 \ 27 | --cross_att_layer 3 \ 28 | --add_query_score_for_eval 0 \ 29 | --query_share 1 \ 30 | --cross_att_share 1 \ -------------------------------------------------------------------------------- /train_lsmdc.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ 2 | python -m torch.distributed.launch \ 3 | --master_port 2502 \ 4 | --nproc_per_node=4 \ 5 | main_retrieval.py \ 6 | --do_train 1 \ 7 | --workers 8 \ 8 | --n_display 5 \ 9 | --epochs 5 \ 10 | --lr 1e-4 \ 11 | --coef_lr 1e-3 \ 12 | --batch_size 128 \ 13 | --batch_size_val 64 \ 14 | --anno_path /mnt/nfs/CMG/zhanghaonan/datasets/LSMDC/anns \ 15 | --video_path /mnt/nfs/CMG/zhanghaonan/datasets/LSMDC \ 16 | --datatype lsmdc \ 17 | --max_words 32 \ 18 | --max_frames 12 \ 19 | --video_framerate 1 \ 20 | --output_dir ckpt/lsmdc/main_exp/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_4cross_add_query_sim_query_shared_cross_att_shared_without_weight \ 21 | --center 1 \ 22 | --query_number 8 \ 23 | --base_encoder ViT-B/32 \ 24 | --cross_att_layer 3 \ 25 | --add_query_score_for_eval 0 \ 26 | --query_share 1 \ 27 | --cross_att_share 1 \ 28 | --loss2_weight 0.5 \ -------------------------------------------------------------------------------- /train_msrvtt.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=4,5,6,7,8,9 \ 2 | python -m torch.distributed.launch \ 3 | --master_port 2513 \ 4 | --nproc_per_node=4 \ 5 | main_retrieval.py \ 6 | --do_train 1 \ 7 | --workers 8 \ 8 | --n_display 50 \ 9 | --epochs 5 \ 10 | --lr 1e-4 \ 11 | --coef_lr 1e-3 \ 12 | --batch_size 128 \ 13 | --batch_size_val 64 \ 14 | --anno_path /mnt/nfs/CMG/zhanghaonan/datasets/MSR-VTT/anns \ 15 | --video_path /mnt/nfs/CMG/zhanghaonan/datasets/MSR-VTT/MSRVTT_Videos \ 16 | --datatype msrvtt \ 17 
| --max_words 32 \ 18 | --max_frames 12 \ 19 | --video_framerate 1 \ 20 | --output_dir ckpt/msrvtt/no_qbnorm/alpha0.0001_beta_0.02 \ 21 | --center 1 \ 22 | --temp 3 \ 23 | --alpha 0.0001 \ 24 | --beta 0.02 \ 25 | --query_number 8 \ 26 | --base_encoder ViT-B/32 \ 27 | --cross_att_layer 3 \ 28 | --query_share 1 \ 29 | --cross_att_share 1 \ 30 | --loss2_weight 0.5 \ -------------------------------------------------------------------------------- /train_msvd.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,8,9 \ 2 | python -m torch.distributed.launch \ 3 | --master_port 2512 \ 4 | --nproc_per_node=4 \ 5 | main_retrieval.py \ 6 | --do_train 1 \ 7 | --workers 8 \ 8 | --n_display 50 \ 9 | --epochs 5 \ 10 | --lr 1e-4 \ 11 | --coef_lr 1e-3 \ 12 | --batch_size 128 \ 13 | --batch_size_val 64 \ 14 | --anno_path /mnt/nfs/CMG/zhanghaonan/datasets/MSVD/anns \ 15 | --video_path /mnt/nfs/CMG/zhanghaonan/datasets/MSVD/MSVD_Videos \ 16 | --datatype msvd \ 17 | --max_words 32 \ 18 | --max_frames 12 \ 19 | --video_framerate 1 \ 20 | --output_dir ckpt/msvd/main_exp/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_3cross_add_query_sim_query_shared_cross_att_shared_without_weight \ 21 | --center 1 \ 22 | --temp 3 \ 23 | --alpha 0.0001 \ 24 | --beta 0.005 \ 25 | --query_number 8 \ 26 | --base_encoder ViT-B/32 \ 27 | --cross_att_layer 3 \ 28 | --add_query_score_for_eval 0 \ 29 | --query_share 1 \ 30 | --cross_att_share 1 \ 31 | --loss2_weight 0.5 \ -------------------------------------------------------------------------------- /tvr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/__init__.py -------------------------------------------------------------------------------- /tvr/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__init__.py -------------------------------------------------------------------------------- /tvr/dataloaders/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/dataloaders/__pycache__/data_dataloaders.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/data_dataloaders.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/dataloaders/__pycache__/dataloader_activitynet_retrieval.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/dataloader_activitynet_retrieval.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/dataloaders/__pycache__/dataloader_didemo_retrieval.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/dataloader_didemo_retrieval.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/dataloaders/__pycache__/dataloader_lsmdc_retrieval.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/dataloader_lsmdc_retrieval.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/dataloaders/__pycache__/dataloader_msrvtt_retrieval.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/dataloader_msrvtt_retrieval.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/dataloaders/__pycache__/dataloader_msvd_retrieval.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/dataloader_msvd_retrieval.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/dataloaders/__pycache__/dataloader_retrieval.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/dataloader_retrieval.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/dataloaders/__pycache__/rand_augment.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/rand_augment.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/dataloaders/__pycache__/random_erasing.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/random_erasing.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/dataloaders/__pycache__/rawvideo_util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/rawvideo_util.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/dataloaders/__pycache__/video_transforms.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/video_transforms.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/dataloaders/data_dataloaders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from .dataloader_msrvtt_retrieval import MSRVTTDataset 4 | from .dataloader_activitynet_retrieval import ActivityNetDataset 5 | from .dataloader_didemo_retrieval import DiDeMoDataset 6 | from .dataloader_lsmdc_retrieval import LsmdcDataset 7 | from .dataloader_msvd_retrieval import MsvdDataset 8 | 9 | 10 | def dataloader_msrvtt_train(args, tokenizer): 11 | msrvtt_dataset = MSRVTTDataset( 12 | subset='train', 13 | anno_path=args.anno_path, 14 | video_path=args.video_path, 15 | max_words=args.max_words, 16 | tokenizer=tokenizer, 17 | max_frames=args.max_frames, 18 | video_framerate=args.video_framerate, 19 | config=args 20 | ) 21 | try: 22 | train_sampler = torch.utils.data.distributed.DistributedSampler(msrvtt_dataset) 23 | except: 24 | train_sampler = None # cpu 25 | dataloader = DataLoader( 26 | msrvtt_dataset, 27 | batch_size=args.batch_size // args.world_size, 28 | num_workers=args.workers, 29 | pin_memory=False, 30 | shuffle=(train_sampler is None), 31 | sampler=train_sampler, 32 | drop_last=True, 33 | ) 34 | 35 | return dataloader, len(msrvtt_dataset), train_sampler 36 | 37 | 38 | def dataloader_msrvtt_test(args, tokenizer, subset="test"): 39 | msrvtt_testset = MSRVTTDataset( 40 | subset=subset, 41 | anno_path=args.anno_path, 42 | video_path=args.video_path, 43 | max_words=args.max_words, 44 | tokenizer=tokenizer, 45 | max_frames=args.max_frames, 46 | video_framerate=args.video_framerate, 47 | config=args 48 | ) 49 | 50 | try: 51 | test_sampler = torch.utils.data.distributed.DistributedSampler(msrvtt_testset) 52 | except: 53 | test_sampler = None # cpu 54 | dataloader_msrvtt = DataLoader( 55 | msrvtt_testset, 56 | batch_size=args.batch_size_val // args.world_size, 57 | num_workers=args.workers, 58 | shuffle=False, 59 | sampler=test_sampler, 60 | drop_last=False, 61 | ) 62 | return dataloader_msrvtt, len(msrvtt_testset) 63 | 64 | 65 | def dataloader_msrvtt_train_test(args, tokenizer): 66 | msrvtt_dataset = MSRVTTDataset( 67 | subset='train_test', 68 | anno_path=args.anno_path, 69 | video_path=args.video_path, 70 | max_words=args.max_words, 71 | tokenizer=tokenizer, 72 | max_frames=args.max_frames, 73 | video_framerate=args.video_framerate, 74 | config=args 75 | ) 76 | try: 77 | train_sampler = torch.utils.data.distributed.DistributedSampler(msrvtt_dataset) 78 | except: 79 | train_sampler = None # cpu 80 | dataloader = DataLoader( 81 | msrvtt_dataset, 82 | batch_size=args.batch_size // args.world_size, 83 | num_workers=args.workers, 84 | pin_memory=False, 85 | shuffle=(train_sampler is None), 86 | sampler=train_sampler, 87 | drop_last=True, 88 | ) 89 | 90 | return dataloader, len(msrvtt_dataset), train_sampler 91 | 92 | 93 | def dataloader_lsmdc_train(args, tokenizer): 94 | lsmdc_dataset = LsmdcDataset( 95 | subset='train', 96 | anno_path=args.anno_path, 97 | video_path=args.video_path, 98 | max_words=args.max_words, 99 | tokenizer=tokenizer, 100 | max_frames=args.max_frames, 101 | video_framerate=args.video_framerate, 102 | config=args 103 | ) 104 | 105 | train_sampler = torch.utils.data.distributed.DistributedSampler(lsmdc_dataset) 106 | dataloader = DataLoader( 107 | 
lsmdc_dataset, 108 | batch_size=args.batch_size // args.world_size, 109 | num_workers=args.workers, 110 | pin_memory=False, 111 | shuffle=(train_sampler is None), 112 | sampler=train_sampler, 113 | drop_last=True, 114 | ) 115 | 116 | return dataloader, len(lsmdc_dataset), train_sampler 117 | 118 | 119 | def dataloader_lsmdc_train_test(args, tokenizer): 120 | lsmdc_dataset = LsmdcDataset( 121 | subset='train_test', 122 | anno_path=args.anno_path, 123 | video_path=args.video_path, 124 | max_words=args.max_words, 125 | tokenizer=tokenizer, 126 | max_frames=args.max_frames, 127 | video_framerate=args.video_framerate, 128 | config=args 129 | ) 130 | 131 | train_sampler = torch.utils.data.distributed.DistributedSampler(lsmdc_dataset) 132 | dataloader = DataLoader( 133 | lsmdc_dataset, 134 | batch_size=args.batch_size // args.world_size, 135 | num_workers=args.workers, 136 | pin_memory=False, 137 | shuffle=(train_sampler is None), 138 | sampler=train_sampler, 139 | drop_last=True, 140 | ) 141 | 142 | return dataloader, len(lsmdc_dataset), train_sampler 143 | 144 | 145 | def dataloader_lsmdc_test(args, tokenizer, subset="test"): 146 | lsmdc_testset = LsmdcDataset( 147 | subset=subset, 148 | anno_path=args.anno_path, 149 | video_path=args.video_path, 150 | max_words=args.max_words, 151 | tokenizer=tokenizer, 152 | max_frames=args.max_frames, 153 | video_framerate=args.video_framerate, 154 | config=args 155 | ) 156 | try: 157 | test_sampler = torch.utils.data.distributed.DistributedSampler(lsmdc_testset) 158 | except: 159 | test_sampler = None # cpu 160 | dataloader_lsmdc = DataLoader( 161 | lsmdc_testset, 162 | batch_size=args.batch_size_val // args.world_size, 163 | num_workers=args.workers, 164 | shuffle=False, 165 | sampler=test_sampler, 166 | drop_last=False, 167 | ) 168 | return dataloader_lsmdc, len(lsmdc_testset) 169 | 170 | 171 | def dataloader_activity_train(args, tokenizer): 172 | activity_dataset = ActivityNetDataset( 173 | subset="train", 174 | data_path=args.anno_path, 175 | features_path=args.video_path, 176 | max_words=args.max_words, 177 | feature_framerate=args.video_framerate, 178 | tokenizer=tokenizer, 179 | max_frames=args.max_frames 180 | ) 181 | 182 | train_sampler = torch.utils.data.distributed.DistributedSampler(activity_dataset) 183 | dataloader = DataLoader( 184 | activity_dataset, 185 | batch_size=args.batch_size // args.world_size, 186 | num_workers=args.workers, 187 | pin_memory=False, 188 | shuffle=(train_sampler is None), 189 | sampler=train_sampler, 190 | drop_last=True, 191 | ) 192 | 193 | return dataloader, len(activity_dataset), train_sampler 194 | 195 | 196 | def dataloader_activity_train_test(args, tokenizer): 197 | activity_dataset = ActivityNetDataset( 198 | subset="train_test", 199 | data_path=args.anno_path, 200 | features_path=args.video_path, 201 | max_words=args.max_words, 202 | feature_framerate=args.video_framerate, 203 | tokenizer=tokenizer, 204 | max_frames=args.max_frames 205 | ) 206 | 207 | train_sampler = torch.utils.data.distributed.DistributedSampler(activity_dataset) 208 | dataloader = DataLoader( 209 | activity_dataset, 210 | batch_size=args.batch_size // args.world_size, 211 | num_workers=args.workers, 212 | pin_memory=False, 213 | shuffle=(train_sampler is None), 214 | sampler=train_sampler, 215 | drop_last=True, 216 | ) 217 | 218 | return dataloader, len(activity_dataset), train_sampler 219 | 220 | 221 | def dataloader_activity_test(args, tokenizer, subset="test"): 222 | activity_testset = ActivityNetDataset( 223 | subset=subset, 224 | 
data_path=args.anno_path, 225 | features_path=args.video_path, 226 | max_words=args.max_words, 227 | feature_framerate=args.video_framerate, 228 | tokenizer=tokenizer, 229 | max_frames=args.max_frames 230 | ) 231 | try: 232 | test_sampler = torch.utils.data.distributed.DistributedSampler(activity_testset) 233 | except: 234 | test_sampler = None # cpu 235 | dataloader_activity = DataLoader( 236 | activity_testset, 237 | batch_size=args.batch_size_val // args.world_size, 238 | num_workers=args.workers, 239 | shuffle=False, 240 | sampler=test_sampler, 241 | drop_last=False, 242 | ) 243 | return dataloader_activity, len(activity_testset) 244 | 245 | 246 | def dataloader_msvd_train(args, tokenizer): 247 | msvd_dataset = MsvdDataset( 248 | subset="train", 249 | anno_path=args.anno_path, 250 | video_path=args.video_path, 251 | max_words=args.max_words, 252 | tokenizer=tokenizer, 253 | max_frames=args.max_frames, 254 | video_framerate=args.video_framerate, 255 | config=args 256 | ) 257 | 258 | train_sampler = torch.utils.data.distributed.DistributedSampler(msvd_dataset) 259 | dataloader = DataLoader( 260 | msvd_dataset, 261 | batch_size=args.batch_size // args.world_size, 262 | num_workers=args.workers, 263 | pin_memory=False, 264 | shuffle=(train_sampler is None), 265 | sampler=train_sampler, 266 | drop_last=True, 267 | ) 268 | 269 | return dataloader, len(msvd_dataset), train_sampler 270 | 271 | 272 | def dataloader_msvd_train_test(args, tokenizer): 273 | msvd_dataset = MsvdDataset( 274 | subset="train_test", 275 | anno_path=args.anno_path, 276 | video_path=args.video_path, 277 | max_words=args.max_words, 278 | tokenizer=tokenizer, 279 | max_frames=args.max_frames, 280 | video_framerate=args.video_framerate, 281 | config=args 282 | ) 283 | 284 | train_sampler = torch.utils.data.distributed.DistributedSampler(msvd_dataset) 285 | dataloader = DataLoader( 286 | msvd_dataset, 287 | batch_size=args.batch_size // args.world_size, 288 | num_workers=args.workers, 289 | pin_memory=False, 290 | shuffle=(train_sampler is None), 291 | sampler=train_sampler, 292 | drop_last=True, 293 | ) 294 | 295 | return dataloader, len(msvd_dataset), train_sampler 296 | 297 | 298 | def dataloader_msvd_test(args, tokenizer, subset="test"): 299 | msvd_testset = MsvdDataset( 300 | subset=subset, 301 | anno_path=args.anno_path, 302 | video_path=args.video_path, 303 | max_words=args.max_words, 304 | tokenizer=tokenizer, 305 | max_frames=args.max_frames, 306 | video_framerate=args.video_framerate, 307 | config=args 308 | ) 309 | try: 310 | test_sampler = torch.utils.data.distributed.DistributedSampler(msvd_testset) 311 | except: 312 | test_sampler = None # cpu 313 | dataloader_msvd = DataLoader( 314 | msvd_testset, 315 | batch_size=args.batch_size_val // args.world_size, 316 | num_workers=args.workers, 317 | shuffle=False, 318 | sampler=test_sampler, 319 | drop_last=False, 320 | ) 321 | return dataloader_msvd, len(msvd_testset) 322 | 323 | 324 | def dataloader_didemo_train(args, tokenizer): 325 | didemo_dataset = DiDeMoDataset( 326 | subset="train", 327 | data_path=args.anno_path, 328 | features_path=args.video_path, 329 | max_words=args.max_words, 330 | feature_framerate=args.video_framerate, 331 | tokenizer=tokenizer, 332 | max_frames=args.max_frames 333 | ) 334 | 335 | train_sampler = torch.utils.data.distributed.DistributedSampler(didemo_dataset) 336 | dataloader = DataLoader( 337 | didemo_dataset, 338 | batch_size=args.batch_size // args.world_size, 339 | num_workers=args.workers, 340 | pin_memory=False, 341 | 
shuffle=(train_sampler is None), 342 | sampler=train_sampler, 343 | drop_last=True, 344 | ) 345 | 346 | return dataloader, len(didemo_dataset), train_sampler 347 | 348 | 349 | def dataloader_didemo_train_test(args, tokenizer): 350 | didemo_dataset = DiDeMoDataset( 351 | subset="train_test", 352 | data_path=args.anno_path, 353 | features_path=args.video_path, 354 | max_words=args.max_words, 355 | feature_framerate=args.video_framerate, 356 | tokenizer=tokenizer, 357 | max_frames=args.max_frames 358 | ) 359 | 360 | train_sampler = torch.utils.data.distributed.DistributedSampler(didemo_dataset) 361 | dataloader = DataLoader( 362 | didemo_dataset, 363 | batch_size=args.batch_size // args.world_size, 364 | num_workers=args.workers, 365 | pin_memory=False, 366 | shuffle=(train_sampler is None), 367 | sampler=train_sampler, 368 | drop_last=True, 369 | ) 370 | 371 | return dataloader, len(didemo_dataset), train_sampler 372 | 373 | 374 | def dataloader_didemo_test(args, tokenizer, subset="test"): 375 | didemo_testset = DiDeMoDataset( 376 | subset=subset, 377 | data_path=args.anno_path, 378 | features_path=args.video_path, 379 | max_words=args.max_words, 380 | feature_framerate=args.video_framerate, 381 | tokenizer=tokenizer, 382 | max_frames=args.max_frames 383 | ) 384 | try: 385 | test_sampler = torch.utils.data.distributed.DistributedSampler(didemo_testset) 386 | except: 387 | test_sampler = None # cpu 388 | dataloader_didemo = DataLoader( 389 | didemo_testset, 390 | batch_size=args.batch_size_val // args.world_size, 391 | num_workers=args.workers, 392 | shuffle=False, 393 | sampler=test_sampler, 394 | drop_last=False, 395 | ) 396 | return dataloader_didemo, len(didemo_testset) 397 | 398 | 399 | DATALOADER_DICT = {} 400 | DATALOADER_DICT["msrvtt"] = {"train": dataloader_msrvtt_train, 401 | "val": dataloader_msrvtt_test, 402 | "test": None, 403 | "train_test": dataloader_msrvtt_train_test} 404 | DATALOADER_DICT["msvd"] = {"train": dataloader_msvd_train, 405 | "val": dataloader_msvd_test, 406 | "test": dataloader_msvd_test, 407 | "train_test": dataloader_msvd_train_test} 408 | DATALOADER_DICT["lsmdc"] = {"train": dataloader_lsmdc_train, 409 | "val": dataloader_lsmdc_test, 410 | "test": dataloader_lsmdc_test, 411 | "train_test": dataloader_lsmdc_train_test} 412 | DATALOADER_DICT["activity"] = {"train":dataloader_activity_train, 413 | "val":dataloader_activity_test, 414 | "test":None, 415 | "train_test": dataloader_activity_train_test} 416 | DATALOADER_DICT["didemo"] = {"train":dataloader_didemo_train, 417 | "val":None, 418 | "test":dataloader_didemo_test, 419 | "train_test":dataloader_didemo_train_test} -------------------------------------------------------------------------------- /tvr/dataloaders/dataloader_activitynet_retrieval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import os 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | import json 10 | import math 11 | from .rawvideo_util import RawVideoExtractor 12 | 13 | 14 | class ActivityNetDataset(Dataset): 15 | def __init__( 16 | self, 17 | subset, 18 | data_path, 19 | features_path, 20 | tokenizer, 21 | max_words=30, 22 | feature_framerate=1.0, 23 | max_frames=100, 24 | image_resolution=224, 25 | frame_order=0, 26 | slice_framepos=2, 27 | ): 28 | self.data_path = data_path 29 | self.features_path = features_path 30 | 
self.feature_framerate = feature_framerate 31 | self.max_words = max_words 32 | self.max_frames = max_frames 33 | self.tokenizer = tokenizer 34 | # 0: ordinary order; 1: reverse order; 2: random order. 35 | self.frame_order = frame_order 36 | assert self.frame_order in [0, 1, 2] 37 | # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. 38 | self.slice_framepos = slice_framepos 39 | assert self.slice_framepos in [0, 1, 2] 40 | 41 | self.subset = subset 42 | assert self.subset in ["train", "val", "train_test"] 43 | 44 | video_id_path_dict = {} 45 | video_id_path_dict["train"] = os.path.join(self.data_path, "train_ids.json") 46 | video_id_path_dict["train_test"] = os.path.join(self.data_path, "train_ids.json") 47 | video_id_path_dict["val"] = os.path.join(self.data_path, "val_ids.json") 48 | 49 | video_json_path_dict = {} 50 | video_json_path_dict["train"] = os.path.join(self.data_path, "train.json") 51 | video_json_path_dict["train_test"] = os.path.join(self.data_path, "train.json") 52 | video_json_path_dict["val"] = os.path.join(self.data_path, "val_1.json") 53 | 54 | pseudo_video_id_list, video_id_list = self._get_video_id_single(video_id_path_dict[self.subset]) 55 | pseudo_caption_dict = self._get_captions_single(video_json_path_dict[self.subset]) 56 | 57 | print("video id list: {}".format(len(video_id_list))) 58 | print("pseudo caption dict: {}".format(len(pseudo_caption_dict.keys()))) 59 | 60 | video_dict = {} 61 | for root, dub_dir, video_files in os.walk(self.features_path): 62 | for video_file in video_files: 63 | video_id_ = ".".join(video_file.split(".")[:-1])[2:] 64 | if video_id_ not in video_id_list: 65 | continue 66 | file_path_ = os.path.join(root, video_file) 67 | video_dict[video_id_] = file_path_ 68 | self.video_dict = video_dict 69 | print("video dict: {}".format(len(video_dict))) 70 | 71 | self.pseudo_video_id_list = pseudo_video_id_list 72 | self.video_id_list = video_id_list 73 | self.pseudo_caption_dict = pseudo_caption_dict 74 | 75 | # Get iterator video ids 76 | self.video_id2idx_dict = {pseudo_video_id: id for id, pseudo_video_id in enumerate(self.pseudo_video_id_list)} 77 | # Get all captions 78 | self.iter2video_pairs_dict = {} 79 | for pseudo_video_id, video_id in zip(self.pseudo_video_id_list, self.video_id_list): 80 | if pseudo_video_id not in self.pseudo_caption_dict or video_id not in self.video_dict: 81 | continue 82 | caption = self.pseudo_caption_dict[pseudo_video_id] 83 | n_caption = len(caption['start']) 84 | for sub_id in range(n_caption): 85 | self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (pseudo_video_id, sub_id) 86 | 87 | self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) 88 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", 89 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} 90 | 91 | def __len__(self): 92 | return len(self.iter2video_pairs_dict) 93 | 94 | def _get_video_id_from_pseduo(self, pseudo_video_id): 95 | video_id = pseudo_video_id[2:] 96 | return video_id 97 | 98 | def _get_video_id_single(self, path): 99 | pseudo_video_id_list = [] 100 | video_id_list = [] 101 | print('Loading json: {}'.format(path)) 102 | with open(path, 'r') as f: 103 | json_data = json.load(f) 104 | 105 | for pseudo_video_id in json_data: 106 | if pseudo_video_id in pseudo_video_id_list: 107 | print("reduplicate.") 108 | else: 109 | video_id = self._get_video_id_from_pseduo(pseudo_video_id) 110 | 
pseudo_video_id_list.append(pseudo_video_id) 111 | video_id_list.append(video_id) 112 | return pseudo_video_id_list, video_id_list 113 | 114 | def _get_captions_single(self, path): 115 | pseudo_caption_dict = {} 116 | with open(path, 'r') as f: 117 | json_data = json.load(f) 118 | 119 | for pseudo_video_id, v_ in json_data.items(): 120 | pseudo_caption_dict[pseudo_video_id] = {} 121 | duration = v_["duration"] 122 | pseudo_caption_dict[pseudo_video_id]["start"] = np.array([0], dtype=object) 123 | pseudo_caption_dict[pseudo_video_id]["end"] = np.array([int(math.ceil(float(duration)))], dtype=object) 124 | pseudo_caption_dict[pseudo_video_id]["text"] = np.array([" ".join(v_["sentences"])], dtype=object) 125 | return pseudo_caption_dict 126 | 127 | def _get_text(self, pseudo_video_id, sub_id): 128 | caption = self.pseudo_caption_dict[pseudo_video_id] 129 | k = 1 130 | r_ind = [sub_id] 131 | 132 | starts = np.zeros(k, dtype=np.long) 133 | ends = np.zeros(k, dtype=np.long) 134 | pairs_text = np.zeros((k, self.max_words), dtype=np.long) 135 | pairs_mask = np.zeros((k, self.max_words), dtype=np.long) 136 | pairs_segment = np.zeros((k, self.max_words), dtype=np.long) 137 | 138 | for i in range(k): 139 | ind = r_ind[i] 140 | start_, end_ = caption['start'][ind], caption['end'][ind] 141 | words = self.tokenizer.tokenize(caption['text'][ind]) 142 | starts[i], ends[i] = start_, end_ 143 | 144 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words 145 | total_length_with_CLS = self.max_words - 1 146 | if len(words) > total_length_with_CLS: 147 | words = words[:total_length_with_CLS] 148 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] 149 | 150 | input_ids = self.tokenizer.convert_tokens_to_ids(words) 151 | input_mask = [1] * len(input_ids) 152 | segment_ids = [0] * len(input_ids) 153 | while len(input_ids) < self.max_words: 154 | input_ids.append(0) 155 | input_mask.append(0) 156 | segment_ids.append(0) 157 | assert len(input_ids) == self.max_words 158 | assert len(input_mask) == self.max_words 159 | assert len(segment_ids) == self.max_words 160 | 161 | pairs_text[i] = np.array(input_ids) 162 | pairs_mask[i] = np.array(input_mask) 163 | pairs_segment[i] = np.array(segment_ids) 164 | 165 | return pairs_text, pairs_mask, pairs_segment, starts, ends 166 | 167 | def _get_rawvideo(self, idx, s, e): 168 | video_mask = np.zeros((len(s), self.max_frames), dtype=np.long) 169 | max_video_length = [0] * len(s) 170 | 171 | # Pair x L x T x 3 x H x W 172 | video = np.zeros((len(s), self.max_frames, 1, 3, 173 | self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float) 174 | video_path = self.video_dict[idx] 175 | try: 176 | for i in range(len(s)): 177 | start_time = int(s[i]) 178 | end_time = int(e[i]) 179 | start_time = start_time if start_time >= 0. else 0. 180 | end_time = end_time if end_time >= 0. else 0. 
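                # Sanitize the annotated segment before decoding: negative timestamps were
                # clamped to 0 above, and the block below swaps reversed (start, end) pairs
                # and widens zero-length segments to one second.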
181 | if start_time > end_time: 182 | start_time, end_time = end_time, start_time 183 | elif start_time == end_time: 184 | end_time = end_time + 1 185 | 186 | # Should be optimized by gathering all requests for this video 187 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path, start_time, end_time) 188 | raw_video_data = raw_video_data['video'] 189 | 190 | if len(raw_video_data.shape) > 3: 191 | raw_video_data_clip = raw_video_data 192 | # L x T x 3 x H x W 193 | raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) 194 | if self.max_frames < raw_video_slice.shape[0]: 195 | if self.slice_framepos == 0: 196 | video_slice = raw_video_slice[:self.max_frames, ...] 197 | elif self.slice_framepos == 1: 198 | video_slice = raw_video_slice[-self.max_frames:, ...] 199 | else: 200 | sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) 201 | video_slice = raw_video_slice[sample_indx, ...] 202 | else: 203 | video_slice = raw_video_slice 204 | 205 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) 206 | 207 | slice_len = video_slice.shape[0] 208 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len 209 | if slice_len < 1: 210 | pass 211 | else: 212 | video[i][:slice_len, ...] = video_slice 213 | else: 214 | print("video path: {} error. video id: {}, start: {}, end: {}".format(video_path, idx, start_time, end_time)) 215 | except Exception as excep: 216 | print("video path: {} error. video id: {}, start: {}, end: {}, Error: {}".format(video_path, idx, s, e, excep)) 217 | raise excep 218 | 219 | for i, v_length in enumerate(max_video_length): 220 | video_mask[i][:v_length] = [1] * v_length 221 | 222 | return video, video_mask 223 | 224 | def __getitem__(self, feature_idx): 225 | pseudo_video_id, sub_id = self.iter2video_pairs_dict[feature_idx] 226 | idx = self.video_id2idx_dict[pseudo_video_id] 227 | 228 | pairs_text, pairs_mask, pairs_segment, starts, ends = self._get_text(pseudo_video_id, sub_id) 229 | video, video_mask = self._get_rawvideo(self.video_id_list[idx], starts, ends) 230 | return pairs_text, pairs_mask, video, video_mask, feature_idx, hash(pseudo_video_id) 231 | 232 | 233 | def load_stopwords(path='data/english.txt'): 234 | with open(path, 'r', encoding='utf-8') as f: 235 | lines = f.readlines() # read all lines 236 | return [line.strip() for line in lines] 237 | 238 | 239 | def remove_stopwords(documents, stopwords): 240 | cleaned_documents = [] 241 | for token in documents.split(): 242 | if token not in stopwords: 243 | cleaned_documents.append(token) 244 | return " ".join(cleaned_documents) -------------------------------------------------------------------------------- /tvr/dataloaders/dataloader_didemo_retrieval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import os 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | import json 10 | from .rawvideo_util import RawVideoExtractor 11 | 12 | 13 | class DiDeMoDataset(Dataset): 14 | def __init__( 15 | self, 16 | subset, 17 | data_path, 18 | features_path, 19 | tokenizer, 20 | max_words=30, 21 | feature_framerate=1.0, 22 | max_frames=100, 23 | image_resolution=224, 24 | frame_order=0, 25 | slice_framepos=2, 26 | ): 27 | self.data_path = 
data_path 28 | self.features_path = features_path 29 | self.feature_framerate = feature_framerate 30 | self.max_words = max_words 31 | self.max_frames = max_frames 32 | self.tokenizer = tokenizer 33 | # 0: ordinary order; 1: reverse order; 2: random order. 34 | self.frame_order = frame_order 35 | assert self.frame_order in [0, 1, 2] 36 | # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. 37 | self.slice_framepos = slice_framepos 38 | assert self.slice_framepos in [0, 1, 2] 39 | 40 | self.subset = subset 41 | assert self.subset in ["train", "val", "test", "train_test"] 42 | 43 | video_id_path_dict = {} 44 | video_id_path_dict["train"] = os.path.join(self.data_path, "train_list.txt") 45 | video_id_path_dict["train_test"] = os.path.join(self.data_path, "train_list.txt") 46 | video_id_path_dict["val"] = os.path.join(self.data_path, "val_list.txt") 47 | video_id_path_dict["test"] = os.path.join(self.data_path, "test_list.txt") 48 | 49 | video_json_path_dict = {} 50 | video_json_path_dict["train"] = os.path.join(self.data_path, "train_data.json") 51 | video_json_path_dict["train_test"] = os.path.join(self.data_path, "train_data.json") 52 | video_json_path_dict["val"] = os.path.join(self.data_path, "val_data.json") 53 | video_json_path_dict["test"] = os.path.join(self.data_path, "test_data.json") 54 | 55 | 56 | with open(video_id_path_dict[self.subset], 'r') as fp: 57 | video_ids = [itm.strip() for itm in fp.readlines()] 58 | 59 | caption_dict = {} 60 | with open(video_json_path_dict[self.subset], 'r') as f: 61 | json_data = json.load(f) 62 | for itm in json_data: 63 | description = itm["description"] 64 | times = itm["times"] 65 | video = itm["video"] 66 | if video not in video_ids: 67 | continue 68 | # each video is split into 5-second temporal chunks 69 | # average the points from each annotator 70 | start_ = np.mean([t_[0] for t_ in times]) * 5 71 | end_ = (np.mean([t_[1] for t_ in times]) + 1) * 5 72 | if video in caption_dict: 73 | caption_dict[video]["start"].append(start_) 74 | caption_dict[video]["end"].append(end_) 75 | caption_dict[video]["text"].append(description) 76 | else: 77 | caption_dict[video] = {} 78 | caption_dict[video]["start"] = [start_] 79 | caption_dict[video]["end"] = [end_] 80 | caption_dict[video]["text"] = [description] 81 | 82 | for k_ in caption_dict.keys(): 83 | caption_dict[k_]["start"] = [0] 84 | # trick to save time on obtaining each video length 85 | # [https://github.com/LisaAnne/LocalizingMoments/blob/master/README.md]: 86 | # Some videos are longer than 30 seconds. These videos were truncated to 30 seconds during annotation. 
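            # With start reset to 0 above and end fixed to 31 s below, each DiDeMo video is
            # paired with the concatenation of all of its descriptions, i.e. retrieval is
            # done with one paragraph-style caption per video.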
87 | caption_dict[k_]["end"] = [31] 88 | caption_dict[k_]["text"] = [" ".join(caption_dict[k_]["text"])] 89 | 90 | video_dict = {} 91 | for root, dub_dir, video_files in os.walk(self.features_path): 92 | for video_file in video_files: 93 | video_id_ = os.path.splitext(video_file)[0] 94 | if video_id_ not in video_ids: 95 | continue 96 | file_path_ = os.path.join(root, video_file) 97 | video_dict[video_id_] = file_path_ 98 | 99 | self.caption_dict = caption_dict 100 | self.video_dict = video_dict 101 | video_ids = list(set(video_ids) & set(self.caption_dict.keys()) & set(self.video_dict.keys())) 102 | 103 | # Get all captions 104 | self.iter2video_pairs_dict = {} 105 | for video_id in self.caption_dict.keys(): 106 | if video_id not in video_ids: 107 | continue 108 | caption = self.caption_dict[video_id] 109 | n_caption = len(caption['start']) 110 | for sub_id in range(n_caption): 111 | self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (video_id, sub_id) 112 | 113 | self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) 114 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", 115 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} 116 | 117 | def __len__(self): 118 | return len(self.iter2video_pairs_dict) 119 | 120 | def _get_text(self, video_id, sub_id): 121 | caption = self.caption_dict[video_id] 122 | k = 1 123 | r_ind = [sub_id] 124 | 125 | starts = np.zeros(k, dtype=np.long) 126 | ends = np.zeros(k, dtype=np.long) 127 | pairs_text = np.zeros((k, self.max_words), dtype=np.long) 128 | pairs_mask = np.zeros((k, self.max_words), dtype=np.long) 129 | pairs_segment = np.zeros((k, self.max_words), dtype=np.long) 130 | 131 | for i in range(k): 132 | ind = r_ind[i] 133 | start_, end_ = caption['start'][ind], caption['end'][ind] 134 | words = self.tokenizer.tokenize(caption['text'][ind]) 135 | starts[i], ends[i] = start_, end_ 136 | 137 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words 138 | total_length_with_CLS = self.max_words - 1 139 | if len(words) > total_length_with_CLS: 140 | words = words[:total_length_with_CLS] 141 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] 142 | 143 | input_ids = self.tokenizer.convert_tokens_to_ids(words) 144 | input_mask = [1] * len(input_ids) 145 | segment_ids = [0] * len(input_ids) 146 | while len(input_ids) < self.max_words: 147 | input_ids.append(0) 148 | input_mask.append(0) 149 | segment_ids.append(0) 150 | assert len(input_ids) == self.max_words 151 | assert len(input_mask) == self.max_words 152 | assert len(segment_ids) == self.max_words 153 | 154 | pairs_text[i] = np.array(input_ids) 155 | pairs_mask[i] = np.array(input_mask) 156 | pairs_segment[i] = np.array(segment_ids) 157 | 158 | return pairs_text, pairs_mask, pairs_segment, starts, ends 159 | 160 | def _get_rawvideo(self, idx, s, e): 161 | video_mask = np.zeros((len(s), self.max_frames), dtype=np.long) 162 | max_video_length = [0] * len(s) 163 | 164 | # Pair x L x T x 3 x H x W 165 | video = np.zeros((len(s), self.max_frames, 1, 3, 166 | self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float) 167 | video_path = self.video_dict[idx] 168 | 169 | try: 170 | for i in range(len(s)): 171 | start_time = int(s[i]) 172 | end_time = int(e[i]) 173 | start_time = start_time if start_time >= 0. else 0. 174 | end_time = end_time if end_time >= 0. else 0. 
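                # For DiDeMo the (start, end) coming from _get_text is always (0, 31) by
                # construction, so the clamping and swapping below is mainly defensive.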
175 | if start_time > end_time: 176 | start_time, end_time = end_time, start_time 177 | elif start_time == end_time: 178 | end_time = end_time + 1 179 | 180 | cache_id = "{}_{}_{}".format(video_path, start_time, end_time) 181 | # Should be optimized by gathering all asking of this video 182 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path, start_time, end_time) 183 | raw_video_data = raw_video_data['video'] 184 | 185 | if len(raw_video_data.shape) > 3: 186 | raw_video_data_clip = raw_video_data 187 | # L x T x 3 x H x W 188 | raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) 189 | if self.max_frames < raw_video_slice.shape[0]: 190 | if self.slice_framepos == 0: 191 | video_slice = raw_video_slice[:self.max_frames, ...] 192 | elif self.slice_framepos == 1: 193 | video_slice = raw_video_slice[-self.max_frames:, ...] 194 | else: 195 | sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) 196 | video_slice = raw_video_slice[sample_indx, ...] 197 | else: 198 | video_slice = raw_video_slice 199 | 200 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) 201 | 202 | slice_len = video_slice.shape[0] 203 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len 204 | if slice_len < 1: 205 | pass 206 | else: 207 | video[i][:slice_len, ...] = video_slice 208 | else: 209 | print("video path: {} error. video id: {}, start: {}, end: {}".format(video_path, idx, start_time, end_time)) 210 | except Exception as excep: 211 | print("video path: {} error. video id: {}, start: {}, end: {}, Error: {}".format(video_path, idx, s, e, excep)) 212 | pass 213 | # raise e 214 | 215 | for i, v_length in enumerate(max_video_length): 216 | video_mask[i][:v_length] = [1] * v_length 217 | 218 | return video, video_mask 219 | 220 | def __getitem__(self, feature_idx): 221 | video_id, sub_id = self.iter2video_pairs_dict[feature_idx] 222 | 223 | pairs_text, pairs_mask, pairs_segment, starts, ends = self._get_text(video_id, sub_id) 224 | video, video_mask = self._get_rawvideo(video_id, starts, ends) 225 | return pairs_text, pairs_mask, video, video_mask, feature_idx, hash(video_id) -------------------------------------------------------------------------------- /tvr/dataloaders/dataloader_lsmdc_retrieval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import json 7 | import tempfile 8 | import pandas as pd 9 | from os.path import join, splitext, exists 10 | from collections import OrderedDict 11 | from .dataloader_retrieval import RetrievalDataset 12 | import os 13 | 14 | 15 | class LsmdcDataset(RetrievalDataset): 16 | """LSMDC dataset.""" 17 | 18 | def __init__(self, subset, anno_path, video_path, tokenizer, max_words=32, 19 | max_frames=12, video_framerate=1, image_resolution=224, mode='all', config=None): 20 | super(LsmdcDataset, self).__init__(subset, anno_path, video_path, tokenizer, max_words, 21 | max_frames, video_framerate, image_resolution, mode, config=config) 22 | pass 23 | 24 | def _get_anns(self, subset='train'): 25 | """ 26 | video_dict: dict: video_id -> video_path 27 | sentences_dict: list: [(video_id, caption)] , caption (list: [text:, start, end]) 28 | """ 29 | video_json_path_dict = {} 30 | video_json_path_dict["train"] = 
os.path.join(self.anno_path, "LSMDC16_annos_training.csv") 31 | #video_json_path_dict["train_test"] = os.path.join(self.anno_path, "LSMDC16_annos_training.csv") 32 | video_json_path_dict["train_test"] = os.path.join(self.anno_path, "LSMDC16_annos_val.csv") 33 | video_json_path_dict["val"] = os.path.join(self.anno_path, "LSMDC16_annos_val.csv") 34 | video_json_path_dict["test"] = os.path.join(self.anno_path, "LSMDC16_challenge_1000_publictect.csv") 35 | 36 | # Each line has 6 tab-separated fields: clip_id, start_aligned, end_aligned, start_extracted, end_extracted, sentence 37 | # clip_id is not a unique identifier, i.e. the same clip_id can be associated with multiple sentences. 38 | # However, LSMDC16_challenge_1000_publictect.csv has no repeat instances 39 | video_id_list = [] 40 | caption_dict = {} 41 | with open(video_json_path_dict[self.subset], 'r') as fp: 42 | for line in fp: 43 | line = line.strip() 44 | line_split = line.split("\t") 45 | assert len(line_split) == 6 46 | clip_id, start_aligned, end_aligned, start_extracted, end_extracted, sentence = line_split 47 | if clip_id not in ["0017_Pianist_00.23.28.872-00.23.34.843", "0017_Pianist_00.30.36.767-00.30.38.009", 48 | "3064_SPARKLE_2012_01.41.07.000-01.41.11.793"]: 49 | caption_dict[len(caption_dict)] = (clip_id, (sentence, None, None)) 50 | if clip_id not in video_id_list: video_id_list.append(clip_id) 51 | 52 | video_dict = OrderedDict() 53 | sentences_dict = OrderedDict() 54 | 55 | for root, dub_dir, video_files in os.walk(self.video_path): 56 | for video_file in video_files: 57 | video_id_ = ".".join(video_file.split(".")[:-1]) 58 | if video_id_ not in video_id_list: 59 | continue 60 | file_path_ = os.path.join(root, video_file) 61 | video_dict[video_id_] = file_path_ 62 | 63 | # Get all captions 64 | for clip_id, sentence in caption_dict.values(): 65 | if clip_id not in video_dict: 66 | continue 67 | sentences_dict[len(sentences_dict)] = (clip_id, sentence) 68 | 69 | unique_sentence = set([v[1][0] for v in sentences_dict.values()]) 70 | print('[{}] Unique sentence is {} , all num is {}'.format(subset, len(unique_sentence), len(sentences_dict))) 71 | 72 | return video_dict, sentences_dict -------------------------------------------------------------------------------- /tvr/dataloaders/dataloader_msrvtt_retrieval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import json 7 | import tempfile 8 | import pandas as pd 9 | from os.path import join, splitext, exists 10 | from collections import OrderedDict 11 | from .dataloader_retrieval import RetrievalDataset 12 | 13 | 14 | class MSRVTTDataset(RetrievalDataset): 15 | """MSRVTT dataset.""" 16 | 17 | def __init__(self, subset, anno_path, video_path, tokenizer, max_words=32, 18 | max_frames=12, video_framerate=1, image_resolution=224, mode='all', config=None): 19 | super(MSRVTTDataset, self).__init__(subset, anno_path, video_path, tokenizer, max_words, 20 | max_frames, video_framerate, image_resolution, mode, config=config) 21 | pass 22 | 23 | def _get_anns(self, subset='train'): 24 | """ 25 | video_dict: dict: video_id -> video_path 26 | sentences_dict: list: [(video_id, caption)] , caption (list: [text:, start, end]) 27 | """ 28 | csv_path = {'train': join(self.anno_path, 'MSRVTT_train.9k.csv'), 29 | 'val': join(self.anno_path, 'MSRVTT_JSFUSION_test.csv'), 30 | 'test': join(self.anno_path, 'MSRVTT_JSFUSION_test.csv'), 31 | 'train_test': join(self.anno_path, 'MSRVTT_train.9k.csv')}[subset] 32 
| if exists(csv_path): 33 | csv = pd.read_csv(csv_path) 34 | else: 35 | raise FileNotFoundError 36 | 37 | video_id_list = list(csv['video_id'].values) 38 | 39 | video_dict = OrderedDict() 40 | sentences_dict = OrderedDict() 41 | if subset == 'train': 42 | anno_path = join(self.anno_path, 'MSRVTT_data.json') 43 | data = json.load(open(anno_path, 'r')) 44 | for itm in data['sentences']: 45 | if itm['video_id'] in video_id_list: 46 | sentences_dict[len(sentences_dict)] = (itm['video_id'], (itm['caption'], None, None)) 47 | video_dict[itm['video_id']] = join(self.video_path, "{}.mp4".format(itm['video_id'])) 48 | elif subset == 'train_test': 49 | anno_path = join(self.anno_path, 'MSRVTT_data.json') 50 | data = json.load(open(anno_path, 'r')) 51 | used = [] 52 | for itm in data['sentences']: 53 | if itm['video_id'] in video_id_list and itm['video_id'] not in used: 54 | used.append(itm['video_id']) 55 | sentences_dict[len(sentences_dict)] = (itm['video_id'], (itm['caption'], None, None)) 56 | video_dict[itm['video_id']] = join(self.video_path, "{}.mp4".format(itm['video_id'])) 57 | else: 58 | for _, itm in csv.iterrows(): 59 | sentences_dict[len(sentences_dict)] = (itm['video_id'], (itm['sentence'], None, None)) 60 | video_dict[itm['video_id']] = join(self.video_path, "{}.mp4".format(itm['video_id'])) 61 | 62 | unique_sentence = set([v[1][0] for v in sentences_dict.values()]) 63 | print('[{}] Unique sentence is {} , all num is {}'.format(subset, len(unique_sentence), len(sentences_dict))) 64 | 65 | return video_dict, sentences_dict 66 | -------------------------------------------------------------------------------- /tvr/dataloaders/dataloader_msvd_retrieval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import json 7 | import tempfile 8 | import os 9 | import pickle 10 | import pandas as pd 11 | from os.path import join, splitext, exists 12 | from collections import OrderedDict 13 | from .dataloader_retrieval import RetrievalDataset 14 | 15 | 16 | class MsvdDataset(RetrievalDataset): 17 | """MSVD dataset loader.""" 18 | def __init__(self, subset, anno_path, video_path, tokenizer, max_words=32, 19 | max_frames=12, video_framerate=1, image_resolution=224, mode='all', config=None): 20 | super(MsvdDataset, self).__init__(subset, anno_path, video_path, tokenizer, max_words, 21 | max_frames, video_framerate, image_resolution, mode, config=config) 22 | pass 23 | 24 | def _get_anns(self, subset='train'): 25 | self.sample_len = 0 26 | self.cut_off_points = [] 27 | self.multi_sentence_per_video = True # !!! 
important tag for eval 28 | 29 | video_id_path_dict = {} 30 | video_id_path_dict["train"] = os.path.join(self.anno_path, "train_list.txt") 31 | video_id_path_dict["train_test"] = os.path.join(self.anno_path, "train_list.txt") 32 | video_id_path_dict["val"] = os.path.join(self.anno_path, "val_list.txt") 33 | video_id_path_dict["test"] = os.path.join(self.anno_path, "test_list.txt") 34 | caption_file = os.path.join(self.anno_path, "raw-captions.pkl") 35 | 36 | with open(video_id_path_dict[subset], 'r') as fp: 37 | video_ids = [itm.strip() for itm in fp.readlines()] 38 | 39 | with open(caption_file, 'rb') as f: 40 | captions = pickle.load(f) 41 | 42 | video_dict = OrderedDict() 43 | sentences_dict = OrderedDict() 44 | 45 | for root, dub_dir, video_files in os.walk(self.video_path): 46 | for video_file in video_files: 47 | video_id_ = ".".join(video_file.split(".")[:-1]) 48 | if video_id_ not in video_ids: 49 | continue 50 | file_path_ = os.path.join(root, video_file) 51 | video_dict[video_id_] = file_path_ 52 | 53 | for video_id in video_ids: 54 | assert video_id in captions 55 | for cap in captions[video_id]: 56 | cap_txt = " ".join(cap) 57 | sentences_dict[len(sentences_dict)] = (video_id, (cap_txt, None, None)) 58 | self.cut_off_points.append(len(sentences_dict) - 1) 59 | 60 | if subset == "val" or subset == "test": 61 | self.sentence_num = len(sentences_dict) 62 | self.video_num = len(video_ids) 63 | assert len(self.cut_off_points) == self.video_num 64 | print("For {}, sentence number: {}".format(subset, self.sentence_num)) 65 | print("For {}, video number: {}".format(subset, self.video_num)) 66 | 67 | print("Video number: {}".format(len(video_dict))) 68 | print("Total Pairs: {}".format(len(sentences_dict))) 69 | 70 | self.sample_len = len(sentences_dict) 71 | 72 | return video_dict, sentences_dict -------------------------------------------------------------------------------- /tvr/dataloaders/dataloader_retrieval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | from os.path import exists 7 | 8 | import random 9 | import numpy as np 10 | from torch.utils.data import Dataset 11 | 12 | import torch 13 | from PIL import Image 14 | from decord import VideoReader, cpu 15 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode, ToPILImage, ColorJitter, RandomHorizontalFlip, RandomResizedCrop 16 | import tvr.dataloaders.video_transforms as video_transforms 17 | from .random_erasing import RandomErasing 18 | 19 | 20 | class RetrievalDataset(Dataset): 21 | """General dataset.""" 22 | 23 | def __init__( 24 | self, 25 | subset, 26 | anno_path, 27 | video_path, 28 | tokenizer, 29 | max_words=30, 30 | max_frames=12, 31 | video_framerate=1, 32 | image_resolution=224, 33 | mode='all', 34 | config=None 35 | ): 36 | self.subset = subset 37 | self.anno_path = anno_path 38 | self.video_path = video_path 39 | self.tokenizer = tokenizer 40 | self.max_words = max_words 41 | self.max_frames = max_frames 42 | self.video_framerate = video_framerate 43 | self.image_resolution = image_resolution 44 | self.mode = mode # all/text/video 45 | self.config = config 46 | 47 | self.video_dict, self.sentences_dict = self._get_anns(self.subset) 48 | 49 | self.video_list = list(self.video_dict.keys()) 50 | self.sample_len = 0 51 | 52 | print("Video number: 
{}".format(len(self.video_dict))) 53 | print("Total Pairs: {}".format(len(self.sentences_dict))) 54 | 55 | from .rawvideo_util import RawVideoExtractor 56 | self.rawVideoExtractor = RawVideoExtractor(framerate=video_framerate, size=image_resolution) 57 | self.transform = Compose([ 58 | Resize(image_resolution, interpolation=InterpolationMode.BICUBIC), 59 | CenterCrop(image_resolution), 60 | lambda image: image.convert("RGB"), 61 | ToTensor(), 62 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 63 | # Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 64 | ]) 65 | self.tsfm_dict = { 66 | 'clip_test': Compose([ 67 | Resize(image_resolution, interpolation=InterpolationMode.BICUBIC), 68 | CenterCrop(image_resolution), 69 | lambda image: image.convert("RGB"), 70 | ToTensor(), 71 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 72 | ]), 73 | 'clip_train': Compose([ 74 | RandomResizedCrop(image_resolution, scale=(0.5, 1.0)), 75 | RandomHorizontalFlip(), 76 | lambda image: image.convert("RGB"), 77 | ToTensor(), 78 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 79 | ]) 80 | } 81 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", 82 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} 83 | self.image_resolution = image_resolution 84 | if self.mode in ['all', 'text']: 85 | self.sample_len = len(self.sentences_dict) 86 | else: 87 | self.sample_len = len(self.video_list) 88 | self.aug_transform = video_transforms.create_random_augment( 89 | input_size=(self.image_resolution, self.image_resolution), 90 | auto_augment='rand-m7-n4-mstd0.5-inc1', 91 | interpolation='bicubic', 92 | ) 93 | 94 | def __len__(self): 95 | return self.sample_len 96 | 97 | def __aug_transform(self, buffer): 98 | _aug_transform = video_transforms.create_random_augment( 99 | input_size=(self.image_resolution, self.image_resolution), 100 | auto_augment='rand-m7-n4-mstd0.5-inc1', 101 | interpolation='bicubic', 102 | ) 103 | buffer = _aug_transform(buffer) 104 | return buffer 105 | buffer = [ToTensor()(img) for img in buffer] 106 | buffer = torch.stack(buffer) # T C H W 107 | buffer = buffer.permute(1, 0, 2, 3) # T H W C -> C T H W. 108 | # Perform data augmentation. 
109 | scl, asp = ( 110 | [0.08, 1.0], 111 | [0.75, 1.3333], 112 | ) 113 | 114 | buffer = spatial_sampling( 115 | buffer, 116 | spatial_idx=-1, 117 | min_scale=256, 118 | max_scale=320, 119 | crop_size=224, 120 | random_horizontal_flip=True, 121 | inverse_uniform_sampling=False, 122 | aspect_ratio=asp, 123 | scale=scl, 124 | motion_shift=False 125 | ) 126 | buffer = buffer.permute(1, 0, 2, 3) 127 | buffer = [ToPILImage()(frame) for frame in buffer] 128 | return buffer 129 | erase_transform = RandomErasing( 130 | 0.25, 131 | mode='pixel', 132 | max_count=1, 133 | num_splits=1, 134 | device="cpu", 135 | ) 136 | buffer = buffer.permute(1, 0, 2, 3) 137 | buffer = erase_transform(buffer) 138 | buffer = [ToPILImage()(frame) for frame in buffer] 139 | return buffer 140 | 141 | def _get_anns(self, subset='train'): 142 | raise NotImplementedError 143 | 144 | def _get_text(self, caption): 145 | if len(caption) == 3: 146 | _caption_text, s, e = caption 147 | else: 148 | raise NotImplementedError 149 | 150 | if isinstance(_caption_text, list): 151 | caption_text = random.choice(_caption_text) 152 | else: 153 | caption_text = _caption_text 154 | 155 | words = self.tokenizer.tokenize(caption_text) 156 | 157 | if self.subset == "train" and 0: 158 | if random.random() < 0.5: 159 | new_words = [] 160 | for idx in range(len(words)): 161 | if random.random() < 0.8: 162 | new_words.append(words[idx]) 163 | words = new_words 164 | 165 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words 166 | total_length_with_CLS = self.max_words - 1 167 | if len(words) > total_length_with_CLS: 168 | words = words[:total_length_with_CLS] 169 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] 170 | 171 | input_ids = self.tokenizer.convert_tokens_to_ids(words) 172 | input_mask = [1] * len(input_ids) 173 | 174 | while len(input_ids) < self.max_words: 175 | input_ids.append(0) 176 | input_mask.append(0) 177 | assert len(input_ids) == self.max_words 178 | assert len(input_mask) == self.max_words 179 | 180 | input_ids = np.array(input_ids) 181 | input_mask = np.array(input_mask) 182 | 183 | return input_ids, input_mask, s, e 184 | 185 | def _get_rawvideo(self, video_id, s=None, e=None): 186 | video_mask = np.zeros(self.max_frames, dtype=np.long) 187 | max_video_length = 0 188 | 189 | # T x 3 x H x W 190 | video = np.zeros((self.max_frames, 3, self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float) 191 | 192 | if s is None: 193 | start_time, end_time = None, None 194 | else: 195 | start_time = int(s) 196 | end_time = int(e) 197 | start_time = start_time if start_time >= 0. else 0. 198 | end_time = end_time if end_time >= 0. else 0. 199 | if start_time > end_time: 200 | start_time, end_time = end_time, start_time 201 | elif start_time == end_time: 202 | end_time = end_time + 1 203 | video_path = self.video_dict[video_id] 204 | 205 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path, start_time, end_time) 206 | raw_video_data = raw_video_data['video'] 207 | 208 | if len(raw_video_data.shape) > 3: 209 | # L x T x 3 x H x W 210 | 211 | if self.max_frames < raw_video_data.shape[0]: 212 | sample_indx = np.linspace(0, raw_video_data.shape[0] - 1, num=self.max_frames, dtype=int) 213 | video_slice = raw_video_data[sample_indx, ...] 
214 | else: 215 | video_slice = raw_video_data 216 | 217 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=0) 218 | 219 | slice_len = video_slice.shape[0] 220 | max_video_length = max_video_length if max_video_length > slice_len else slice_len 221 | if slice_len < 1: 222 | pass 223 | else: 224 | video[:slice_len, ...] = video_slice 225 | else: 226 | print("video path: {} error. video id: {}".format(video_path, video_id)) 227 | 228 | video_mask[:max_video_length] = [1] * max_video_length 229 | 230 | return video, video_mask 231 | 232 | def _get_rawvideo_dec(self, video_id, s=None, e=None): 233 | # speed up video decode via decord. 234 | video_mask = np.zeros(self.max_frames, dtype=np.long) 235 | max_video_length = 0 236 | 237 | # T x 3 x H x W 238 | video = np.zeros((self.max_frames, 3, self.image_resolution, self.image_resolution), dtype=np.float) 239 | 240 | if s is None: 241 | start_time, end_time = None, None 242 | else: 243 | start_time = int(s) 244 | end_time = int(e) 245 | start_time = start_time if start_time >= 0. else 0. 246 | end_time = end_time if end_time >= 0. else 0. 247 | if start_time > end_time: 248 | start_time, end_time = end_time, start_time 249 | elif start_time == end_time: 250 | end_time = start_time + 1 251 | video_path = self.video_dict[video_id] 252 | 253 | if exists(video_path): 254 | vreader = VideoReader(video_path, ctx=cpu(0)) 255 | else: 256 | print(video_path) 257 | raise FileNotFoundError 258 | 259 | fps = vreader.get_avg_fps() 260 | f_start = 0 if start_time is None else int(start_time * fps) 261 | f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1)) 262 | num_frames = f_end - f_start + 1 263 | if num_frames > 0: 264 | # T x 3 x H x W 265 | sample_fps = int(self.video_framerate) 266 | t_stride = int(round(float(fps) / sample_fps)) 267 | 268 | all_pos = list(range(f_start, f_end + 1, t_stride)) 269 | if len(all_pos) > self.max_frames: 270 | sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=self.max_frames, dtype=int)] 271 | else: 272 | sample_pos = all_pos 273 | 274 | patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()] 275 | if self.subset == "train": 276 | # for i in range(2): 277 | patch_images = self.aug_transform(patch_images) 278 | 279 | # if self.subset == "train": 280 | # patch_images = torch.stack([self.tsfm_dict["clip_train"](img) for img in patch_images]) 281 | # else: 282 | # patch_images = torch.stack([self.tsfm_dict["clip_test"](img) for img in patch_images]) 283 | 284 | patch_images = torch.stack([self.transform(img) for img in patch_images]) 285 | slice_len = patch_images.shape[0] 286 | max_video_length = max_video_length if max_video_length > slice_len else slice_len 287 | if slice_len < 1: 288 | pass 289 | else: 290 | video[:slice_len, ...] = patch_images 291 | else: 292 | print("video path: {} error. 
video id: {}".format(video_path, video_id)) 293 | 294 | video_mask[:max_video_length] = [1] * max_video_length 295 | 296 | return video, video_mask 297 | 298 | def __getitem__(self, idx): 299 | 300 | if self.mode == 'all': 301 | video_id, caption = self.sentences_dict[idx] 302 | text_ids, text_mask, s, e = self._get_text(caption) 303 | video, video_mask = self._get_rawvideo_dec(video_id, s, e) 304 | # video, video_mask = self._get_rawvideo(video_id, s, e) 305 | return text_ids, text_mask, video, video_mask, idx, hash(video_id.replace("video", "")) 306 | elif self.mode == 'text': 307 | video_id, caption = self.sentences_dict[idx] 308 | text_ids, text_mask, s, e = self._get_text(caption) 309 | return text_ids, text_mask, idx 310 | elif self.mode == 'video': 311 | video_id = self.video_list[idx] 312 | video, video_mask = self._get_rawvideo_dec(video_id) 313 | # video, video_mask = self._get_rawvideo(video_id) 314 | return video, video_mask, idx 315 | 316 | def get_text_len(self): 317 | return len(self.sentences_dict) 318 | 319 | def get_video_len(self): 320 | return len(self.video_list) 321 | 322 | def get_text_content(self, ind): 323 | return self.sentences_dict[ind][1] 324 | 325 | def get_data_name(self): 326 | return self.__class__.__name__ + "_" + self.subset 327 | 328 | def get_vis_info(self, idx): 329 | video_id, caption = self.sentences_dict[idx] 330 | video_path = self.video_dict[video_id] 331 | return caption, video_path 332 | 333 | 334 | def spatial_sampling( 335 | frames, 336 | spatial_idx=-1, 337 | min_scale=256, 338 | max_scale=320, 339 | crop_size=224, 340 | random_horizontal_flip=True, 341 | inverse_uniform_sampling=False, 342 | aspect_ratio=None, 343 | scale=None, 344 | motion_shift=False, 345 | ): 346 | """ 347 | Perform spatial sampling on the given video frames. If spatial_idx is 348 | -1, perform random scale, random crop, and random flip on the given 349 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling 350 | with the given spatial_idx. 351 | Args: 352 | frames (tensor): frames of images sampled from the video. The 353 | dimension is `num frames` x `height` x `width` x `channel`. 354 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1, 355 | or 2, perform left, center, right crop if width is larger than 356 | height, and perform top, center, buttom crop if height is larger 357 | than width. 358 | min_scale (int): the minimal size of scaling. 359 | max_scale (int): the maximal size of scaling. 360 | crop_size (int): the size of height and width used to crop the 361 | frames. 362 | inverse_uniform_sampling (bool): if True, sample uniformly in 363 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the 364 | scale. If False, take a uniform sample from [min_scale, 365 | max_scale]. 366 | aspect_ratio (list): Aspect ratio range for resizing. 367 | scale (list): Scale range for resizing. 368 | motion_shift (bool): Whether to apply motion shift for resizing. 369 | Returns: 370 | frames (tensor): spatially sampled frames. 
371 | """ 372 | assert spatial_idx in [-1, 0, 1, 2] 373 | if spatial_idx == -1: 374 | if aspect_ratio is None and scale is None: 375 | frames, _ = video_transforms.random_short_side_scale_jitter( 376 | images=frames, 377 | min_size=min_scale, 378 | max_size=max_scale, 379 | inverse_uniform_sampling=inverse_uniform_sampling, 380 | ) 381 | frames, _ = video_transforms.random_crop(frames, crop_size) 382 | else: 383 | transform_func = ( 384 | video_transforms.random_resized_crop_with_shift 385 | if motion_shift 386 | else video_transforms.random_resized_crop 387 | ) 388 | frames = transform_func( 389 | images=frames, 390 | target_height=crop_size, 391 | target_width=crop_size, 392 | scale=scale, 393 | ratio=aspect_ratio, 394 | ) 395 | if random_horizontal_flip: 396 | frames, _ = video_transforms.horizontal_flip(0.5, frames) 397 | else: 398 | # The testing is deterministic and no jitter should be performed. 399 | # min_scale, max_scale, and crop_size are expect to be the same. 400 | assert len({min_scale, max_scale, crop_size}) == 1 401 | frames, _ = video_transforms.random_short_side_scale_jitter( 402 | frames, min_scale, max_scale 403 | ) 404 | frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx) 405 | return frames -------------------------------------------------------------------------------- /tvr/dataloaders/rand_augment.py: -------------------------------------------------------------------------------- 1 | """ 2 | This implementation is based on 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py 4 | pulished under an Apache License 2.0. 5 | 6 | COMMENT FROM ORIGINAL: 7 | AutoAugment, RandAugment, and AugMix for PyTorch 8 | This code implements the searched ImageNet policies with various tweaks and 9 | improvements and does not include any of the search code. AA and RA 10 | Implementation adapted from: 11 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 12 | AugMix adapted from: 13 | https://github.com/google-research/augmix 14 | Papers: 15 | AutoAugment: Learning Augmentation Policies from Data 16 | https://arxiv.org/abs/1805.09501 17 | Learning Data Augmentation Strategies for Object Detection 18 | https://arxiv.org/abs/1906.11172 19 | RandAugment: Practical automated data augmentation... 20 | https://arxiv.org/abs/1909.13719 21 | AugMix: A Simple Data Processing Method to Improve Robustness and 22 | Uncertainty https://arxiv.org/abs/1912.02781 23 | 24 | Hacked together by / Copyright 2020 Ross Wightman 25 | """ 26 | 27 | import math 28 | import numpy as np 29 | import random 30 | import re 31 | import PIL 32 | from PIL import Image, ImageEnhance, ImageOps 33 | 34 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]]) 35 | 36 | _FILL = (128, 128, 128) 37 | 38 | # This signifies the max integer that the controller RNN could predict for the 39 | # augmentation scheme. 
40 | _MAX_LEVEL = 10.0 41 | 42 | _HPARAMS_DEFAULT = { 43 | "translate_const": 250, 44 | "img_mean": _FILL, 45 | } 46 | 47 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 48 | 49 | 50 | def _interpolation(kwargs): 51 | interpolation = kwargs.pop("resample", Image.BILINEAR) 52 | if isinstance(interpolation, (list, tuple)): 53 | return random.choice(interpolation) 54 | else: 55 | return interpolation 56 | 57 | 58 | def _check_args_tf(kwargs): 59 | if "fillcolor" in kwargs and _PIL_VER < (5, 0): 60 | kwargs.pop("fillcolor") 61 | kwargs["resample"] = _interpolation(kwargs) 62 | 63 | 64 | def shear_x(img, factor, **kwargs): 65 | _check_args_tf(kwargs) 66 | return img.transform( 67 | img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs 68 | ) 69 | 70 | 71 | def shear_y(img, factor, **kwargs): 72 | _check_args_tf(kwargs) 73 | return img.transform( 74 | img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs 75 | ) 76 | 77 | 78 | def translate_x_rel(img, pct, **kwargs): 79 | pixels = pct * img.size[0] 80 | _check_args_tf(kwargs) 81 | return img.transform( 82 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs 83 | ) 84 | 85 | 86 | def translate_y_rel(img, pct, **kwargs): 87 | pixels = pct * img.size[1] 88 | _check_args_tf(kwargs) 89 | return img.transform( 90 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs 91 | ) 92 | 93 | 94 | def translate_x_abs(img, pixels, **kwargs): 95 | _check_args_tf(kwargs) 96 | return img.transform( 97 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs 98 | ) 99 | 100 | 101 | def translate_y_abs(img, pixels, **kwargs): 102 | _check_args_tf(kwargs) 103 | return img.transform( 104 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs 105 | ) 106 | 107 | 108 | def rotate(img, degrees, **kwargs): 109 | _check_args_tf(kwargs) 110 | if _PIL_VER >= (5, 2): 111 | return img.rotate(degrees, **kwargs) 112 | elif _PIL_VER >= (5, 0): 113 | w, h = img.size 114 | post_trans = (0, 0) 115 | rotn_center = (w / 2.0, h / 2.0) 116 | angle = -math.radians(degrees) 117 | matrix = [ 118 | round(math.cos(angle), 15), 119 | round(math.sin(angle), 15), 120 | 0.0, 121 | round(-math.sin(angle), 15), 122 | round(math.cos(angle), 15), 123 | 0.0, 124 | ] 125 | 126 | def transform(x, y, matrix): 127 | (a, b, c, d, e, f) = matrix 128 | return a * x + b * y + c, d * x + e * y + f 129 | 130 | matrix[2], matrix[5] = transform( 131 | -rotn_center[0] - post_trans[0], 132 | -rotn_center[1] - post_trans[1], 133 | matrix, 134 | ) 135 | matrix[2] += rotn_center[0] 136 | matrix[5] += rotn_center[1] 137 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs) 138 | else: 139 | return img.rotate(degrees, resample=kwargs["resample"]) 140 | 141 | 142 | def auto_contrast(img, **__): 143 | return ImageOps.autocontrast(img) 144 | 145 | 146 | def invert(img, **__): 147 | return ImageOps.invert(img) 148 | 149 | 150 | def equalize(img, **__): 151 | return ImageOps.equalize(img) 152 | 153 | 154 | def solarize(img, thresh, **__): 155 | return ImageOps.solarize(img, thresh) 156 | 157 | 158 | def solarize_add(img, add, thresh=128, **__): 159 | lut = [] 160 | for i in range(256): 161 | if i < thresh: 162 | lut.append(min(255, i + add)) 163 | else: 164 | lut.append(i) 165 | if img.mode in ("L", "RGB"): 166 | if img.mode == "RGB" and len(lut) == 256: 167 | lut = lut + lut + lut 168 | return img.point(lut) 169 | else: 170 | return img 171 | 172 | 173 | def posterize(img, bits_to_keep, **__): 174 | if bits_to_keep >= 8: 175 | return img 176 | return ImageOps.posterize(img, 
bits_to_keep) 177 | 178 | 179 | def contrast(img, factor, **__): 180 | return ImageEnhance.Contrast(img).enhance(factor) 181 | 182 | 183 | def color(img, factor, **__): 184 | return ImageEnhance.Color(img).enhance(factor) 185 | 186 | 187 | def brightness(img, factor, **__): 188 | return ImageEnhance.Brightness(img).enhance(factor) 189 | 190 | 191 | def sharpness(img, factor, **__): 192 | return ImageEnhance.Sharpness(img).enhance(factor) 193 | 194 | 195 | def _randomly_negate(v): 196 | """With 50% prob, negate the value""" 197 | return -v if random.random() > 0.5 else v 198 | 199 | 200 | def _rotate_level_to_arg(level, _hparams): 201 | # range [-30, 30] 202 | level = (level / _MAX_LEVEL) * 30.0 203 | level = _randomly_negate(level) 204 | return (level,) 205 | 206 | 207 | def _enhance_level_to_arg(level, _hparams): 208 | # range [0.1, 1.9] 209 | return ((level / _MAX_LEVEL) * 1.8 + 0.1,) 210 | 211 | 212 | def _enhance_increasing_level_to_arg(level, _hparams): 213 | # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend 214 | # range [0.1, 1.9] 215 | level = (level / _MAX_LEVEL) * 0.9 216 | level = 1.0 + _randomly_negate(level) 217 | return (level,) 218 | 219 | 220 | def _shear_level_to_arg(level, _hparams): 221 | # range [-0.3, 0.3] 222 | level = (level / _MAX_LEVEL) * 0.3 223 | level = _randomly_negate(level) 224 | return (level,) 225 | 226 | 227 | def _translate_abs_level_to_arg(level, hparams): 228 | translate_const = hparams["translate_const"] 229 | level = (level / _MAX_LEVEL) * float(translate_const) 230 | level = _randomly_negate(level) 231 | return (level,) 232 | 233 | 234 | def _translate_rel_level_to_arg(level, hparams): 235 | # default range [-0.45, 0.45] 236 | translate_pct = hparams.get("translate_pct", 0.45) 237 | level = (level / _MAX_LEVEL) * translate_pct 238 | level = _randomly_negate(level) 239 | return (level,) 240 | 241 | 242 | def _posterize_level_to_arg(level, _hparams): 243 | # As per Tensorflow TPU EfficientNet impl 244 | # range [0, 4], 'keep 0 up to 4 MSB of original image' 245 | # intensity/severity of augmentation decreases with level 246 | return (int((level / _MAX_LEVEL) * 4),) 247 | 248 | 249 | def _posterize_increasing_level_to_arg(level, hparams): 250 | # As per Tensorflow models research and UDA impl 251 | # range [4, 0], 'keep 4 down to 0 MSB of original image', 252 | # intensity/severity of augmentation increases with level 253 | return (4 - _posterize_level_to_arg(level, hparams)[0],) 254 | 255 | 256 | def _posterize_original_level_to_arg(level, _hparams): 257 | # As per original AutoAugment paper description 258 | # range [4, 8], 'keep 4 up to 8 MSB of image' 259 | # intensity/severity of augmentation decreases with level 260 | return (int((level / _MAX_LEVEL) * 4) + 4,) 261 | 262 | 263 | def _solarize_level_to_arg(level, _hparams): 264 | # range [0, 256] 265 | # intensity/severity of augmentation decreases with level 266 | return (int((level / _MAX_LEVEL) * 256),) 267 | 268 | 269 | def _solarize_increasing_level_to_arg(level, _hparams): 270 | # range [0, 256] 271 | # intensity/severity of augmentation increases with level 272 | return (256 - _solarize_level_to_arg(level, _hparams)[0],) 273 | 274 | 275 | def _solarize_add_level_to_arg(level, _hparams): 276 | # range [0, 110] 277 | return (int((level / _MAX_LEVEL) * 110),) 278 | 279 | 280 | LEVEL_TO_ARG = { 281 | "AutoContrast": None, 282 | "Equalize": None, 283 | "Invert": None, 284 | "Rotate": _rotate_level_to_arg, 285 | # There are several variations 
of the posterize level scaling in various Tensorflow/Google repositories/papers 286 | "Posterize": _posterize_level_to_arg, 287 | "PosterizeIncreasing": _posterize_increasing_level_to_arg, 288 | "PosterizeOriginal": _posterize_original_level_to_arg, 289 | "Solarize": _solarize_level_to_arg, 290 | "SolarizeIncreasing": _solarize_increasing_level_to_arg, 291 | "SolarizeAdd": _solarize_add_level_to_arg, 292 | "Color": _enhance_level_to_arg, 293 | "ColorIncreasing": _enhance_increasing_level_to_arg, 294 | "Contrast": _enhance_level_to_arg, 295 | "ContrastIncreasing": _enhance_increasing_level_to_arg, 296 | "Brightness": _enhance_level_to_arg, 297 | "BrightnessIncreasing": _enhance_increasing_level_to_arg, 298 | "Sharpness": _enhance_level_to_arg, 299 | "SharpnessIncreasing": _enhance_increasing_level_to_arg, 300 | "ShearX": _shear_level_to_arg, 301 | "ShearY": _shear_level_to_arg, 302 | "TranslateX": _translate_abs_level_to_arg, 303 | "TranslateY": _translate_abs_level_to_arg, 304 | "TranslateXRel": _translate_rel_level_to_arg, 305 | "TranslateYRel": _translate_rel_level_to_arg, 306 | } 307 | 308 | 309 | NAME_TO_OP = { 310 | "AutoContrast": auto_contrast, 311 | "Equalize": equalize, 312 | "Invert": invert, 313 | "Rotate": rotate, 314 | "Posterize": posterize, 315 | "PosterizeIncreasing": posterize, 316 | "PosterizeOriginal": posterize, 317 | "Solarize": solarize, 318 | "SolarizeIncreasing": solarize, 319 | "SolarizeAdd": solarize_add, 320 | "Color": color, 321 | "ColorIncreasing": color, 322 | "Contrast": contrast, 323 | "ContrastIncreasing": contrast, 324 | "Brightness": brightness, 325 | "BrightnessIncreasing": brightness, 326 | "Sharpness": sharpness, 327 | "SharpnessIncreasing": sharpness, 328 | "ShearX": shear_x, 329 | "ShearY": shear_y, 330 | "TranslateX": translate_x_abs, 331 | "TranslateY": translate_y_abs, 332 | "TranslateXRel": translate_x_rel, 333 | "TranslateYRel": translate_y_rel, 334 | } 335 | 336 | 337 | class AugmentOp: 338 | """ 339 | Apply for video. 340 | """ 341 | 342 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None): 343 | hparams = hparams or _HPARAMS_DEFAULT 344 | self.aug_fn = NAME_TO_OP[name] 345 | self.level_fn = LEVEL_TO_ARG[name] 346 | self.prob = prob 347 | self.magnitude = magnitude 348 | self.hparams = hparams.copy() 349 | self.kwargs = { 350 | "fillcolor": hparams["img_mean"] 351 | if "img_mean" in hparams 352 | else _FILL, 353 | "resample": hparams["interpolation"] 354 | if "interpolation" in hparams 355 | else _RANDOM_INTERPOLATION, 356 | } 357 | 358 | # If magnitude_std is > 0, we introduce some randomness 359 | # in the usually fixed policy and sample magnitude from a normal distribution 360 | # with mean `magnitude` and std-dev of `magnitude_std`. 361 | # NOTE This is my own hack, being tested, not in papers or reference impls. 
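        # magnitude_std is populated from the 'mstd' field of the policy string (e.g.
        # 'rand-m7-n4-mstd0.5-inc1' used by the dataloaders); when it is non-zero, __call__
        # below re-samples the magnitude from a Gaussian and clips it back to [0, _MAX_LEVEL].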
362 | self.magnitude_std = self.hparams.get("magnitude_std", 0) 363 | 364 | def __call__(self, img_list): 365 | if self.prob < 1.0 and random.random() > self.prob: 366 | return img_list 367 | magnitude = self.magnitude 368 | if self.magnitude_std and self.magnitude_std > 0: 369 | magnitude = random.gauss(magnitude, self.magnitude_std) 370 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range 371 | level_args = ( 372 | self.level_fn(magnitude, self.hparams) 373 | if self.level_fn is not None 374 | else () 375 | ) 376 | 377 | if isinstance(img_list, list): 378 | return [ 379 | self.aug_fn(img, *level_args, **self.kwargs) for img in img_list 380 | ] 381 | else: 382 | return self.aug_fn(img_list, *level_args, **self.kwargs) 383 | 384 | 385 | _RAND_TRANSFORMS = [ 386 | "AutoContrast", 387 | "Equalize", 388 | "Invert", 389 | "Rotate", 390 | "Posterize", 391 | "Solarize", 392 | "SolarizeAdd", 393 | "Color", 394 | "Contrast", 395 | "Brightness", 396 | "Sharpness", 397 | "ShearX", 398 | "ShearY", 399 | "TranslateXRel", 400 | "TranslateYRel", 401 | ] 402 | 403 | 404 | _RAND_INCREASING_TRANSFORMS = [ 405 | "AutoContrast", 406 | "Equalize", 407 | "Invert", 408 | "Rotate", 409 | "PosterizeIncreasing", 410 | "SolarizeIncreasing", 411 | "SolarizeAdd", 412 | "ColorIncreasing", 413 | "ContrastIncreasing", 414 | "BrightnessIncreasing", 415 | "SharpnessIncreasing", 416 | "ShearX", 417 | "ShearY", 418 | "TranslateXRel", 419 | "TranslateYRel", 420 | ] 421 | 422 | 423 | # These experimental weights are based loosely on the relative improvements mentioned in paper. 424 | # They may not result in increased performance, but could likely be tuned to so. 425 | _RAND_CHOICE_WEIGHTS_0 = { 426 | "Rotate": 0.3, 427 | "ShearX": 0.2, 428 | "ShearY": 0.2, 429 | "TranslateXRel": 0.1, 430 | "TranslateYRel": 0.1, 431 | "Color": 0.025, 432 | "Sharpness": 0.025, 433 | "AutoContrast": 0.025, 434 | "Solarize": 0.005, 435 | "SolarizeAdd": 0.005, 436 | "Contrast": 0.005, 437 | "Brightness": 0.005, 438 | "Equalize": 0.005, 439 | "Posterize": 0, 440 | "Invert": 0, 441 | } 442 | 443 | 444 | def _select_rand_weights(weight_idx=0, transforms=None): 445 | transforms = transforms or _RAND_TRANSFORMS 446 | assert weight_idx == 0 # only one set of weights currently 447 | rand_weights = _RAND_CHOICE_WEIGHTS_0 448 | probs = [rand_weights[k] for k in transforms] 449 | probs /= np.sum(probs) 450 | return probs 451 | 452 | 453 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None): 454 | hparams = hparams or _HPARAMS_DEFAULT 455 | transforms = transforms or _RAND_TRANSFORMS 456 | return [ 457 | AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) 458 | for name in transforms 459 | ] 460 | 461 | 462 | class RandAugment: 463 | def __init__(self, ops, num_layers=2, choice_weights=None): 464 | self.ops = ops 465 | self.num_layers = num_layers 466 | self.choice_weights = choice_weights 467 | 468 | def __call__(self, img): 469 | # no replacement when using weighted choice 470 | ops = np.random.choice( 471 | self.ops, 472 | self.num_layers, 473 | replace=self.choice_weights is None, 474 | p=self.choice_weights, 475 | ) 476 | for op in ops: 477 | img = op(img) 478 | return img 479 | 480 | 481 | def rand_augment_transform(config_str, hparams): 482 | """ 483 | RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 484 | 485 | Create a RandAugment transform 486 | :param config_str: String defining configuration of random augmentation. 
Consists of multiple sections separated by 487 | dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining 488 | sections, which are not order specific, determine 489 | 'm' - integer magnitude of rand augment 490 | 'n' - integer num layers (number of transform ops selected per image) 491 | 'w' - integer probability weight index (index of a set of weights to influence choice of op) 492 | 'mstd' - float std deviation of magnitude noise applied 493 | 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) 494 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 495 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 496 | :param hparams: Other hparams (kwargs) for the RandAugmentation scheme 497 | :return: A PyTorch compatible Transform 498 | """ 499 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) 500 | num_layers = 2 # default to 2 ops per image 501 | weight_idx = None # default to no probability weights for op choice 502 | transforms = _RAND_TRANSFORMS 503 | config = config_str.split("-") 504 | assert config[0] == "rand" 505 | config = config[1:] 506 | for c in config: 507 | cs = re.split(r"(\d.*)", c) 508 | if len(cs) < 2: 509 | continue 510 | key, val = cs[:2] 511 | if key == "mstd": 512 | # noise param injected via hparams for now 513 | hparams.setdefault("magnitude_std", float(val)) 514 | elif key == "inc": 515 | if bool(val): 516 | transforms = _RAND_INCREASING_TRANSFORMS 517 | elif key == "m": 518 | magnitude = int(val) 519 | elif key == "n": 520 | num_layers = int(val) 521 | elif key == "w": 522 | weight_idx = int(val) 523 | else: 524 | raise NotImplementedError 525 | ra_ops = rand_augment_ops( 526 | magnitude=magnitude, hparams=hparams, transforms=transforms 527 | ) 528 | choice_weights = ( 529 | None if weight_idx is None else _select_rand_weights(weight_idx) 530 | ) 531 | return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) 532 | -------------------------------------------------------------------------------- /tvr/dataloaders/random_erasing.py: -------------------------------------------------------------------------------- 1 | """ 2 | This implementation is based on 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py 4 | published under an Apache License 2.0. 5 | """ 6 | import math 7 | import random 8 | import torch 9 | 10 | 11 | def _get_pixels( 12 | per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda" 13 | ): 14 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() 15 | # paths, flip the order so normal is run on CPU if this becomes a problem 16 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 17 | if per_pixel: 18 | return torch.empty(patch_size, dtype=dtype, device=device).normal_() 19 | elif rand_color: 20 | return torch.empty( 21 | (patch_size[0], 1, 1), dtype=dtype, device=device 22 | ).normal_() 23 | else: 24 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) 25 | 26 | 27 | class RandomErasing: 28 | """Randomly selects a rectangle region in an image and erases its pixels. 29 | 'Random Erasing Data Augmentation' by Zhong et al. 
30 | See https://arxiv.org/pdf/1708.04896.pdf 31 | This variant of RandomErasing is intended to be applied to either a batch 32 | or single image tensor after it has been normalized by dataset mean and std. 33 | Args: 34 | probability: Probability that the Random Erasing operation will be performed. 35 | min_area: Minimum percentage of erased area wrt input image area. 36 | max_area: Maximum percentage of erased area wrt input image area. 37 | min_aspect: Minimum aspect ratio of erased area. 38 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 39 | 'const' - erase block is constant color of 0 for all channels 40 | 'rand' - erase block is same per-channel random (normal) color 41 | 'pixel' - erase block is per-pixel random (normal) color 42 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 43 | per-image count is randomly chosen between 1 and this value. 44 | """ 45 | 46 | def __init__( 47 | self, 48 | probability=0.5, 49 | min_area=0.02, 50 | max_area=1 / 3, 51 | min_aspect=0.3, 52 | max_aspect=None, 53 | mode="const", 54 | min_count=1, 55 | max_count=None, 56 | num_splits=0, 57 | device="cuda", 58 | cube=True, 59 | ): 60 | self.probability = probability 61 | self.min_area = min_area 62 | self.max_area = max_area 63 | max_aspect = max_aspect or 1 / min_aspect 64 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 65 | self.min_count = min_count 66 | self.max_count = max_count or min_count 67 | self.num_splits = num_splits 68 | mode = mode.lower() 69 | self.rand_color = False 70 | self.per_pixel = False 71 | self.cube = cube 72 | if mode == "rand": 73 | self.rand_color = True # per block random normal 74 | elif mode == "pixel": 75 | self.per_pixel = True # per pixel random normal 76 | else: 77 | assert not mode or mode == "const" 78 | self.device = device 79 | 80 | def _erase(self, img, chan, img_h, img_w, dtype): 81 | if random.random() > self.probability: 82 | return 83 | area = img_h * img_w 84 | count = ( 85 | self.min_count 86 | if self.min_count == self.max_count 87 | else random.randint(self.min_count, self.max_count) 88 | ) 89 | for _ in range(count): 90 | for _ in range(10): 91 | target_area = ( 92 | random.uniform(self.min_area, self.max_area) * area / count 93 | ) 94 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 95 | h = int(round(math.sqrt(target_area * aspect_ratio))) 96 | w = int(round(math.sqrt(target_area / aspect_ratio))) 97 | if w < img_w and h < img_h: 98 | top = random.randint(0, img_h - h) 99 | left = random.randint(0, img_w - w) 100 | img[:, top : top + h, left : left + w] = _get_pixels( 101 | self.per_pixel, 102 | self.rand_color, 103 | (chan, h, w), 104 | dtype=dtype, 105 | device=self.device, 106 | ) 107 | break 108 | 109 | def _erase_cube( 110 | self, 111 | img, 112 | batch_start, 113 | batch_size, 114 | chan, 115 | img_h, 116 | img_w, 117 | dtype, 118 | ): 119 | if random.random() > self.probability: 120 | return 121 | area = img_h * img_w 122 | count = ( 123 | self.min_count 124 | if self.min_count == self.max_count 125 | else random.randint(self.min_count, self.max_count) 126 | ) 127 | for _ in range(count): 128 | for _ in range(100): 129 | target_area = ( 130 | random.uniform(self.min_area, self.max_area) * area / count 131 | ) 132 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 133 | h = int(round(math.sqrt(target_area * aspect_ratio))) 134 | w = int(round(math.sqrt(target_area / aspect_ratio))) 135 | if w < img_w and h < img_h: 136 | top = 
random.randint(0, img_h - h) 137 | left = random.randint(0, img_w - w) 138 | for i in range(batch_start, batch_size): 139 | img_instance = img[i] 140 | img_instance[ 141 | :, top : top + h, left : left + w 142 | ] = _get_pixels( 143 | self.per_pixel, 144 | self.rand_color, 145 | (chan, h, w), 146 | dtype=dtype, 147 | device=self.device, 148 | ) 149 | break 150 | 151 | def __call__(self, input): 152 | if len(input.size()) == 3: 153 | self._erase(input, *input.size(), input.dtype) 154 | else: 155 | batch_size, chan, img_h, img_w = input.size() 156 | # skip first slice of batch if num_splits is set (for clean portion of samples) 157 | batch_start = ( 158 | batch_size // self.num_splits if self.num_splits > 1 else 0 159 | ) 160 | if self.cube: 161 | self._erase_cube( 162 | input, 163 | batch_start, 164 | batch_size, 165 | chan, 166 | img_h, 167 | img_w, 168 | input.dtype, 169 | ) 170 | else: 171 | for i in range(batch_start, batch_size): 172 | self._erase(input[i], chan, img_h, img_w, input.dtype) 173 | return input 174 | -------------------------------------------------------------------------------- /tvr/dataloaders/rawvideo_util.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | from PIL import Image 4 | # pytorch=1.7.1 5 | # pip install opencv-python 6 | import cv2 7 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode, ToPILImage, ColorJitter, RandomHorizontalFlip, RandomResizedCrop 8 | import tvr.dataloaders.video_transforms as video_transforms 9 | from .random_erasing import RandomErasing 10 | 11 | 12 | class RawVideoExtractorCV2(): 13 | def __init__(self, centercrop=False, size=224, framerate=-1, subset="test"): 14 | self.centercrop = centercrop 15 | self.size = size 16 | self.framerate = framerate 17 | self.transform = self._transform(self.size) 18 | self.subset = subset 19 | self.tsfm_dict = { 20 | 'clip_test': Compose([ 21 | Resize(size, interpolation=InterpolationMode.BICUBIC), 22 | CenterCrop(size), 23 | lambda image: image.convert("RGB"), 24 | ToTensor(), 25 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 26 | ]), 27 | 'clip_train': Compose([ 28 | RandomResizedCrop(size, scale=(0.5, 1.0)), 29 | RandomHorizontalFlip(), 30 | lambda image: image.convert("RGB"), 31 | ToTensor(), 32 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 33 | ]) 34 | } 35 | self.aug_transform = video_transforms.create_random_augment( 36 | input_size=(size, size), 37 | auto_augment='rand-m7-n4-mstd0.5-inc1', 38 | interpolation='bicubic', 39 | ) 40 | 41 | def _transform(self, n_px): 42 | return Compose([ 43 | Resize(n_px, interpolation=InterpolationMode.BICUBIC), 44 | CenterCrop(n_px), 45 | lambda image: image.convert("RGB"), 46 | ToTensor(), 47 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 48 | # Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 49 | ]) 50 | 51 | def video_to_tensor(self, video_file, preprocess, sample_fp=0, start_time=None, end_time=None, _no_process=False): 52 | if start_time is not None or end_time is not None: 53 | assert isinstance(start_time, int) and isinstance(end_time, int) \ 54 | and start_time > -1 and end_time > start_time 55 | assert sample_fp > -1 56 | 57 | # Samples a frame sample_fp X frames. 
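        # Sampling logic for the loop below: when sample_fp > 0, interval = fps // sample_fp,
        # so `inds` holds sample_fp evenly spaced frame offsets within each second of video
        # (e.g. fps=30, sample_fp=3 gives interval=10 and inds=[0, 10, 20]); when sample_fp <= 0,
        # sample_fp is set to fps and every frame of each second is kept.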
58 | cap = cv2.VideoCapture(video_file) 59 | frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 60 | fps = int(cap.get(cv2.CAP_PROP_FPS)) 61 | 62 | if fps == 0: 63 | print((video_file + '\n') * 10) 64 | total_duration = (frameCount + fps - 1) // fps 65 | start_sec, end_sec = 0, total_duration 66 | 67 | if start_time is not None: 68 | start_sec, end_sec = start_time, end_time if end_time <= total_duration else total_duration 69 | cap.set(cv2.CAP_PROP_POS_FRAMES, int(start_time * fps)) 70 | 71 | interval = 1 72 | if sample_fp > 0: 73 | interval = fps // sample_fp 74 | else: 75 | sample_fp = fps 76 | if interval == 0: interval = 1 77 | 78 | inds = [ind for ind in np.arange(0, fps, interval)] 79 | assert len(inds) >= sample_fp 80 | inds = inds[:sample_fp] 81 | 82 | ret = True 83 | images, included = [], [] 84 | 85 | for sec in np.arange(start_sec, end_sec + 1): 86 | if not ret: break 87 | sec_base = int(sec * fps) 88 | for ind in inds: 89 | cap.set(cv2.CAP_PROP_POS_FRAMES, sec_base + ind) 90 | ret, frame = cap.read() 91 | if not ret: break 92 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 93 | if _no_process: 94 | images.append(Image.fromarray(frame_rgb).convert("RGB")) 95 | else: 96 | # images.append(preprocess(Image.fromarray(frame_rgb).convert("RGB"))) 97 | images.append(Image.fromarray(frame_rgb)) 98 | 99 | cap.release() 100 | 101 | if len(images) > 0: 102 | if _no_process: 103 | video_data = images 104 | else: 105 | if self.subset == "train": 106 | # for i in range(2): 107 | images = self.aug_transform(images) 108 | 109 | # if self.subset == "train": 110 | # patch_images = torch.stack([self.tsfm_dict["clip_train"](img) for img in patch_images]) 111 | # else: 112 | # patch_images = torch.stack([self.tsfm_dict["clip_test"](img) for img in patch_images]) 113 | 114 | video_data = th.stack([preprocess(img) for img in images]) 115 | # video_data = th.tensor(np.stack(images)) 116 | else: 117 | video_data = th.zeros(1) 118 | return {'video': video_data} 119 | 120 | def get_video_data(self, video_path, start_time=None, end_time=None, _no_process=False): 121 | image_input = self.video_to_tensor(video_path, self.transform, sample_fp=self.framerate, start_time=start_time, 122 | end_time=end_time, _no_process=_no_process) 123 | return image_input 124 | 125 | def process_raw_data(self, raw_video_data): 126 | tensor_size = raw_video_data.size() 127 | tensor = raw_video_data.view(-1, 1, tensor_size[-3], tensor_size[-2], tensor_size[-1]) 128 | return tensor 129 | 130 | def process_frame_order(self, raw_video_data, frame_order=0): 131 | # 0: ordinary order; 1: reverse order; 2: random order. 132 | if frame_order == 0: 133 | pass 134 | elif frame_order == 1: 135 | reverse_order = np.arange(raw_video_data.size(0) - 1, -1, -1) 136 | raw_video_data = raw_video_data[reverse_order, ...] 137 | elif frame_order == 2: 138 | random_order = np.arange(raw_video_data.size(0)) 139 | np.random.shuffle(random_order) 140 | raw_video_data = raw_video_data[random_order, ...] 
141 | 142 | return raw_video_data 143 | 144 | 145 | # An ordinary video frame extractor based CV2 146 | RawVideoExtractor = RawVideoExtractorCV2 147 | -------------------------------------------------------------------------------- /tvr/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__init__.py -------------------------------------------------------------------------------- /tvr/models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/models/__pycache__/modeling.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/modeling.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/models/__pycache__/module_clip.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/module_clip.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/models/__pycache__/module_cross.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/module_cross.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/models/__pycache__/module_transformer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/module_transformer.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/models/__pycache__/optimization.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/optimization.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/models/__pycache__/query_cross_att.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/query_cross_att.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/models/__pycache__/tokenization_clip.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/tokenization_clip.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/models/__pycache__/transformer.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/transformer.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/models/__pycache__/transformer_block.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/transformer_block.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/models/__pycache__/until_module.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/until_module.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/models/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /tvr/models/module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # ------------------------------------------------------ 3 | # -------------------- LOCATE Module ------------------- 4 | # ------------------------------------------------------ 5 | import torch 6 | import torch.nn as nn 7 | from models.attention import SoftAttention, GumbelAttention 8 | 9 | class LOCATE(nn.Module): 10 | def __init__(self, opt): 11 | super(LOCATE, self).__init__() 12 | # spatial soft attention module 13 | self.spatial_attn = SoftAttention(opt.region_projected_size, opt.hidden_size, opt.hidden_size) 14 | 15 | # temporal soft attention module 16 | feat_size = opt.region_projected_size + opt.hidden_size * 2 17 | self.temp_attn = SoftAttention(feat_size, opt.hidden_size, opt.hidden_size) 18 | 19 | def forward(self, frame_feats, object_feats, hidden_state): 20 | """ 21 | :param frame_feats: (batch_size, max_frames, 2*hidden_size) 22 | :param object_feats: (batch_size, max_frames, num_boxes, region_projected_size) 23 | :param hidden_state: (batch_size, hidden_size) 24 | :return: loc_feat: (batch_size, feat_size) 25 | """ 26 | # spatial attention 27 | bsz, max_frames, num_boxes, fsize = object_feats.size() 28 | feats = object_feats.reshape(bsz * max_frames, num_boxes, fsize) 29 | hidden = hidden_state.repeat(1, max_frames).reshape(bsz * max_frames, -1) 30 | object_feats_att, _ = self.spatial_attn(feats, hidden) 31 | object_feats_att = object_feats_att.reshape(bsz, max_frames, fsize) 32 | 33 | # temporal attention 34 | feat = torch.cat([object_feats_att, frame_feats], dim=-1) 35 | loc_feat, _ = self.temp_attn(feat, hidden_state) 36 | return loc_feat 37 | 38 | 39 | # ------------------------------------------------------ 40 | # -------------------- RELATE Module ------------------- 41 | # ------------------------------------------------------ 42 | 43 | class RELATE(nn.Module): 44 | def __init__(self, opt): 45 | super(RELATE, self).__init__() 46 | 47 | # spatial soft attention module 48 | region_feat_size = opt.region_projected_size 49 | self.spatial_attn = SoftAttention(region_feat_size, opt.hidden_size, opt.hidden_size) 50 | 51 | # temporal soft attention module 52 | feat_size = region_feat_size + 
opt.hidden_size * 2 53 | self.relation_attn = SoftAttention(2*feat_size, opt.hidden_size, opt.hidden_size) 54 | 55 | def forward(self, i3d_feats, object_feats, hidden_state): 56 | ''' 57 | :param i3d_feats: (batch_size, max_frames, 2*hidden_size) 58 | :param object_feats: (batch_size, max_frames, num_boxes, region_projected_size) 59 | :param hidden_state: (batch_size, hidden_size) 60 | :return: rel_feat 61 | ''' 62 | # spatial atttention 63 | bsz, max_frames, num_boxes, fsize = object_feats.size() 64 | feats = object_feats.reshape(bsz * max_frames, num_boxes, fsize) 65 | hidden = hidden_state.repeat(1, max_frames).reshape(bsz * max_frames, -1) 66 | object_feats_att, _ = self.spatial_attn(feats, hidden) 67 | object_feats_att = object_feats_att.reshape(bsz, max_frames, fsize) 68 | 69 | # generate pair-wise feature 70 | feat = torch.cat([object_feats_att, i3d_feats], dim=-1) 71 | feat1 = feat.repeat(1, max_frames, 1) 72 | feat2 = feat.repeat(1, 1, max_frames).reshape(bsz, max_frames*max_frames, -1) 73 | pairwise_feat = torch.cat([feat1, feat2], dim=-1) 74 | 75 | # temporal attention 76 | rel_feat, _ = self.relation_attn(pairwise_feat, hidden_state) 77 | return rel_feat 78 | 79 | 80 | # ------------------------------------------------------ 81 | # -------------------- FUNC Module --------------------- 82 | # ------------------------------------------------------ 83 | 84 | class FUNC(nn.Module): 85 | def __init__(self, opt): 86 | super(FUNC, self).__init__() 87 | self.cell_attn = SoftAttention(opt.hidden_size, opt.hidden_size, opt.hidden_size) 88 | 89 | def forward(self, cells, hidden_state): 90 | ''' 91 | :param cells: previous memory states of decoder LSTM 92 | :param hidden_state: (batch_size, hidden_size) 93 | :return: func_feat 94 | ''' 95 | func_feat, _ = self.cell_attn(cells, hidden_state) 96 | return func_feat 97 | 98 | 99 | 100 | # ------------------------------------------------------ 101 | # ------------------- Module Selector ------------------ 102 | # ------------------------------------------------------ 103 | 104 | class ModuleSelection(nn.Module): 105 | def __init__(self, opt): 106 | super(ModuleSelection, self).__init__() 107 | self.use_loc = opt.use_loc 108 | self.use_rel = opt.use_rel 109 | self.use_func = opt.use_func 110 | 111 | if opt.use_loc: 112 | loc_feat_size = opt.region_projected_size + opt.hidden_size * 2 113 | self.loc_fc = nn.Linear(loc_feat_size, opt.hidden_size) 114 | nn.init.xavier_normal_(self.loc_fc.weight) 115 | 116 | if opt.use_rel: 117 | rel_feat_size = 2 * (opt.region_projected_size + 2 * opt.hidden_size) 118 | self.rel_fc = nn.Linear(rel_feat_size, opt.hidden_size) 119 | nn.init.xavier_normal_(self.rel_fc.weight) 120 | 121 | if opt.use_func: 122 | func_size = opt.hidden_size 123 | self.func_fc = nn.Linear(func_size, opt.hidden_size) 124 | nn.init.xavier_normal_(self.func_fc.weight) 125 | 126 | if opt.use_loc and opt.use_rel and opt.use_func: 127 | if opt.attention == 'soft': 128 | self.module_attn = SoftAttention(opt.hidden_size, opt.hidden_size, opt.hidden_size) 129 | elif opt.attention == 'gumbel': 130 | self.module_attn = GumbelAttention(opt.hidden_size, opt.hidden_size, opt.hidden_size) 131 | 132 | def forward(self, loc_feats, rel_feats, func_feats, hidden_state): 133 | ''' 134 | soft attention: Weighted sum of three features 135 | gumbel attention: Choose one of three features 136 | ''' 137 | loc_feats = self.loc_fc(loc_feats) if self.use_loc else None 138 | rel_feats = self.rel_fc(rel_feats) if self.use_rel else None 139 | func_feats = 
self.func_fc(func_feats) if self.use_func else None 140 | 141 | if self.use_loc and self.use_rel and self.use_func: 142 | feats = torch.stack([loc_feats, rel_feats, func_feats], dim=1) 143 | feats, module_weight = self.module_attn(feats, hidden_state) 144 | 145 | elif self.use_loc and not self.use_rel: 146 | feats = loc_feats 147 | module_weight = torch.tensor([0.3, 0.3, 0.4]).cuda() 148 | elif self.use_rel and not self.use_loc: 149 | feats = rel_feats 150 | module_weight = torch.tensor([0.3, 0.3, 0.4]).cuda() 151 | 152 | return feats, module_weight -------------------------------------------------------------------------------- /tvr/models/module_clip.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from: https://github.com/openai/CLIP/blob/main/clip/clip.py 3 | """ 4 | from collections import OrderedDict 5 | from typing import Tuple, Union 6 | 7 | import hashlib 8 | import os 9 | import urllib 10 | import warnings 11 | from tqdm import tqdm 12 | from .module_transformer import Transformer as TransformerClip 13 | import torch 14 | import torch.nn.functional as F 15 | from torch import nn 16 | 17 | _MODELS = { 18 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 19 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 20 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 21 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 22 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 23 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 24 | } 25 | _PT_NAME = { 26 | "RN50": "RN50.pt", 27 | "RN101": "RN101.pt", 28 | "RN50x4": "RN50x4.pt", 29 | "RN50x16": "RN50x16.pt", 30 | "ViT-B/32": "ViT-B-32.pt", 31 | "ViT-B/16": "ViT-B-16.pt", 32 | "ViT-L/14": "ViT-L-14.pt", 33 | } 34 | 35 | 36 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 37 | os.makedirs(root, exist_ok=True) 38 | filename = os.path.basename(url) 39 | 40 | expected_sha256 = url.split("/")[-2] 41 | download_target = os.path.join(root, filename) 42 | 43 | if os.path.exists(download_target) and not os.path.isfile(download_target): 44 | raise RuntimeError(f"{download_target} exists and is not a regular file") 45 | 46 | if os.path.isfile(download_target): 47 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 48 | return download_target 49 | else: 50 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 51 | 52 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 53 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 54 | while True: 55 | buffer = source.read(8192) 56 | if not buffer: 57 | break 58 | 59 | output.write(buffer) 60 | loop.update(len(buffer)) 61 | 62 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 63 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 64 | 65 | return download_target 66 
| 67 | 68 | def available_models(): 69 | """Returns the names of available CLIP models""" 70 | return list(_MODELS.keys()) 71 | 72 | 73 | # ============================= 74 | 75 | class Bottleneck(nn.Module): 76 | expansion = 4 77 | 78 | def __init__(self, inplanes, planes, stride=1): 79 | super(Bottleneck, self).__init__() 80 | 81 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 82 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 83 | self.bn1 = nn.BatchNorm2d(planes) 84 | 85 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 86 | self.bn2 = nn.BatchNorm2d(planes) 87 | 88 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 89 | 90 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 91 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 92 | 93 | self.relu = nn.ReLU(inplace=True) 94 | self.downsample = None 95 | self.stride = stride 96 | 97 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 98 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 99 | self.downsample = nn.Sequential(OrderedDict([ 100 | ("-1", nn.AvgPool2d(stride)), 101 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 102 | ("1", nn.BatchNorm2d(planes * self.expansion)) 103 | ])) 104 | 105 | def forward(self, x: torch.Tensor): 106 | identity = x 107 | 108 | out = self.relu(self.bn1(self.conv1(x))) 109 | out = self.relu(self.bn2(self.conv2(out))) 110 | out = self.avgpool(out) 111 | out = self.bn3(self.conv3(out)) 112 | 113 | if self.downsample is not None: 114 | identity = self.downsample(x) 115 | 116 | out += identity 117 | out = self.relu(out) 118 | return out 119 | 120 | 121 | class AttentionPool2d(nn.Module): 122 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 123 | super(AttentionPool2d, self).__init__() 124 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 125 | self.k_proj = nn.Linear(embed_dim, embed_dim) 126 | self.q_proj = nn.Linear(embed_dim, embed_dim) 127 | self.v_proj = nn.Linear(embed_dim, embed_dim) 128 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 129 | self.num_heads = num_heads 130 | 131 | def forward(self, x): 132 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 133 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 134 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 135 | x, _ = F.multi_head_attention_forward( 136 | query=x, key=x, value=x, 137 | embed_dim_to_check=x.shape[-1], 138 | num_heads=self.num_heads, 139 | q_proj_weight=self.q_proj.weight, 140 | k_proj_weight=self.k_proj.weight, 141 | v_proj_weight=self.v_proj.weight, 142 | in_proj_weight=None, 143 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 144 | bias_k=None, 145 | bias_v=None, 146 | add_zero_attn=False, 147 | dropout_p=0, 148 | out_proj_weight=self.c_proj.weight, 149 | out_proj_bias=self.c_proj.bias, 150 | use_separate_proj_weight=True, 151 | training=self.training, 152 | need_weights=False 153 | ) 154 | 155 | return x[0] 156 | 157 | 158 | class ModifiedResNet(nn.Module): 159 | """ 160 | A ResNet class that is similar to torchvision's but contains the following changes: 161 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
162 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 163 | - The final pooling layer is a QKV attention instead of an average pool 164 | """ 165 | 166 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 167 | super(ModifiedResNet, self).__init__() 168 | self.output_dim = output_dim 169 | self.input_resolution = input_resolution 170 | 171 | # the 3-layer stem 172 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 173 | self.bn1 = nn.BatchNorm2d(width // 2) 174 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 175 | self.bn2 = nn.BatchNorm2d(width // 2) 176 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 177 | self.bn3 = nn.BatchNorm2d(width) 178 | self.avgpool = nn.AvgPool2d(2) 179 | self.relu = nn.ReLU(inplace=True) 180 | 181 | # residual layers 182 | self._inplanes = width # this is a *mutable* variable used during construction 183 | self.layer1 = self._make_layer(width, layers[0]) 184 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 185 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 186 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 187 | 188 | embed_dim = width * 32 # the ResNet feature dimension 189 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 190 | 191 | def _make_layer(self, planes, blocks, stride=1): 192 | layers = [Bottleneck(self._inplanes, planes, stride)] 193 | 194 | self._inplanes = planes * Bottleneck.expansion 195 | for _ in range(1, blocks): 196 | layers.append(Bottleneck(self._inplanes, planes)) 197 | 198 | return nn.Sequential(*layers) 199 | 200 | def forward(self, x): 201 | def stem(x): 202 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 203 | x = self.relu(bn(conv(x))) 204 | x = self.avgpool(x) 205 | return x 206 | 207 | x = x.type(self.conv1.weight.dtype) 208 | x = stem(x) 209 | x = self.layer1(x) 210 | x = self.layer2(x) 211 | x = self.layer3(x) 212 | x = self.layer4(x) 213 | x = self.attnpool(x) 214 | 215 | return x 216 | 217 | 218 | class LayerNorm(nn.LayerNorm): 219 | """Subclass torch's LayerNorm to handle fp16.""" 220 | 221 | def forward(self, x: torch.Tensor): 222 | orig_type = x.dtype 223 | ret = super().forward(x.type(torch.float32)) 224 | return ret.type(orig_type) 225 | 226 | 227 | class QuickGELU(nn.Module): 228 | def forward(self, x: torch.Tensor): 229 | return x * torch.sigmoid(1.702 * x) 230 | 231 | 232 | class ResidualAttentionBlock(nn.Module): 233 | def __init__(self, d_model: int, n_head: int, attn_mask=None): 234 | super(ResidualAttentionBlock, self).__init__() 235 | 236 | self.attn = nn.MultiheadAttention(d_model, n_head) 237 | self.ln_1 = LayerNorm(d_model) 238 | self.mlp = nn.Sequential(OrderedDict([ 239 | ("c_fc", nn.Linear(d_model, d_model * 4)), 240 | ("gelu", QuickGELU()), 241 | ("c_proj", nn.Linear(d_model * 4, d_model)) 242 | ])) 243 | self.ln_2 = LayerNorm(d_model) 244 | self.attn_mask = attn_mask 245 | 246 | def attention(self, x: torch.Tensor): 247 | attn_mask_ = self.attn_mask 248 | if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'): 249 | attn_mask_ = self.attn_mask(x.size(0)) # LND 250 | 251 | attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None 252 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] 253 | 254 | def 
forward(self, x): 255 | x = x + self.attention(self.ln_1(x)) 256 | x = x + self.mlp(self.ln_2(x)) 257 | return x 258 | 259 | 260 | class Transformer(nn.Module): 261 | def __init__(self, width: int, layers: int, heads: int, attn_mask=None): 262 | super(Transformer, self).__init__() 263 | self.width = width 264 | self.layers = layers 265 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 266 | 267 | def forward(self, x: torch.Tensor): 268 | return self.resblocks(x) 269 | 270 | 271 | class VisualTransformer(nn.Module): 272 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 273 | super(VisualTransformer, self).__init__() 274 | self.input_resolution = input_resolution 275 | self.output_dim = output_dim 276 | 277 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 278 | 279 | scale = width ** -0.5 280 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 281 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 282 | self.ln_pre = LayerNorm(width) 283 | 284 | self.transformer = Transformer(width, layers, heads) 285 | 286 | self.ln_post = LayerNorm(width) 287 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 288 | 289 | for param in self.conv1.parameters(): 290 | param.requires_grad = False # not update by gradient 291 | 292 | # self.class_embedding.requires_grad = False # not update by gradient 293 | # self.positional_embedding.requires_grad = False # not update by gradient 294 | 295 | 296 | def forward(self, x: torch.Tensor, mask=None): 297 | 298 | x = self.conv1(x) # shape = [*, width, grid, grid] 299 | 300 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 301 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 302 | x = torch.cat( 303 | [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 304 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 305 | 306 | x = x + self.positional_embedding.to(x.dtype) 307 | x = self.ln_pre(x) 308 | 309 | # zero = torch.zeros((x.size(1), x.size(1))).repeat(x.size(0), 1, 1).to(x.device) 310 | # inf = torch.zeros((x.size(1), x.size(1))).fill_(float("-inf")).repeat(x.size(0), 1, 1).to(mask.device) 311 | # _mask = mask.view(-1).unsqueeze(1).unsqueeze(1).expand(-1, x.size(1), x.size(1)) 312 | # attn_mask = torch.where(_mask>0, zero, inf) 313 | # attn_mask[:, 0, 0] = 0 314 | 315 | x = x.permute(1, 0, 2) # NLD -> LND 316 | x = self.transformer(x) 317 | x = x.permute(1, 0, 2) # LND -> NLD 318 | 319 | # mask = mask.view(-1).unsqueeze(1).unsqueeze(1).expand(-1, x.size(1), x.size(2)) 320 | # zero = torch.zeros((x.size(0), x.size(1), x.size(2))).to(x.device).type(x.dtype) 321 | # x = torch.where(mask>0, x, zero) 322 | # Move the three lines below to `encode_image` for entire hidden sequence 323 | # x = self.ln_post(x[:, 0, :]) 324 | # if self.proj is not None: 325 | # x = x @ self.proj 326 | 327 | return x 328 | 329 | 330 | class CLIP(nn.Module): 331 | def __init__(self, 332 | embed_dim: int, 333 | # vision 334 | image_resolution: int, 335 | vision_layers: Union[Tuple[int, int, int, int], int], 336 | vision_width: int, 337 | vision_patch_size: int, 338 | # text 339 | context_length: int, 340 | vocab_size: int, 341 | transformer_width: int, 342 | transformer_heads: int, 343 | transformer_layers: int 344 | ): 345 | 
super(CLIP, self).__init__() 346 | 347 | self.context_length = context_length 348 | 349 | if isinstance(vision_layers, (tuple, list)): 350 | vision_heads = vision_width * 32 // 64 351 | self.visual = ModifiedResNet( 352 | layers=vision_layers, 353 | output_dim=embed_dim, 354 | heads=vision_heads, 355 | input_resolution=image_resolution, 356 | width=vision_width 357 | ) 358 | else: 359 | vision_heads = vision_width // 64 360 | self.visual = VisualTransformer( 361 | input_resolution=image_resolution, 362 | patch_size=vision_patch_size, 363 | width=vision_width, 364 | layers=vision_layers, 365 | heads=vision_heads, 366 | output_dim=embed_dim, 367 | ) 368 | 369 | self.transformer = TransformerClip( 370 | width=transformer_width, 371 | layers=transformer_layers, 372 | heads=transformer_heads, 373 | # attn_mask=self.build_attention_mask 374 | ) 375 | 376 | self.vocab_size = vocab_size 377 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 378 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 379 | self.ln_final = LayerNorm(transformer_width) 380 | 381 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 382 | self.logit_scale = nn.Parameter(torch.ones([])) 383 | 384 | self.initialize_parameters() 385 | 386 | # for param in self.transformer.parameters(): 387 | # param.requires_grad = False # not update by gradient 388 | self.token_embedding.requires_grad = False # not update by gradient 389 | # self.positional_embedding.requires_grad = False # not update by gradient 390 | 391 | def initialize_parameters(self): 392 | nn.init.normal_(self.token_embedding.weight, std=0.02) 393 | nn.init.normal_(self.positional_embedding, std=0.01) 394 | 395 | if isinstance(self.visual, ModifiedResNet): 396 | if self.visual.attnpool is not None: 397 | std = self.visual.attnpool.c_proj.in_features ** -0.5 398 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 399 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 400 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 401 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 402 | 403 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 404 | for name, param in resnet_block.named_parameters(): 405 | if name.endswith("bn3.weight"): 406 | nn.init.zeros_(param) 407 | 408 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 409 | attn_std = self.transformer.width ** -0.5 410 | fc_std = (2 * self.transformer.width) ** -0.5 411 | for block in self.transformer.resblocks: 412 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 413 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 414 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 415 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 416 | 417 | if self.text_projection is not None: 418 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 419 | 420 | @staticmethod 421 | def get_config(pretrained_clip_name="ViT-B/32"): 422 | model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViT-B-32.pt") 423 | if pretrained_clip_name in _MODELS and pretrained_clip_name in _PT_NAME: 424 | model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), _PT_NAME[pretrained_clip_name]) 425 | 426 | if pretrained_clip_name in ["ViT-B/32", "ViT-B/16"] and os.path.exists(model_path): 427 | pass 428 | else: 429 | if pretrained_clip_name in _MODELS: 430 | 
model_path = _download(_MODELS[pretrained_clip_name]) 431 | elif os.path.isfile(pretrained_clip_name): 432 | model_path = pretrained_clip_name 433 | else: 434 | raise RuntimeError(f"Model {pretrained_clip_name} not found; available models = {available_models()}") 435 | 436 | try: 437 | # loading JIT archive 438 | model = torch.jit.load(model_path, map_location="cpu").eval() 439 | state_dict = model.state_dict() 440 | except RuntimeError: 441 | state_dict = torch.load(model_path, map_location="cpu") 442 | 443 | return state_dict 444 | 445 | def build_attention_mask(self, context_length): 446 | # lazily create causal attention mask, with full attention between the vision tokens 447 | # pytorch uses additive attention mask; fill with -inf 448 | mask = torch.zeros(context_length, context_length) 449 | mask.fill_(float("-inf")) 450 | mask.triu_(1) # zero out the lower diagonal 451 | return mask 452 | 453 | @property 454 | def dtype(self): 455 | return self.visual.conv1.weight.dtype 456 | 457 | def encode_image(self, image, return_hidden=False, mask=None): 458 | hidden = self.visual(image.type(self.dtype)) 459 | hidden = self.visual.ln_post(hidden) @ self.visual.proj 460 | 461 | x = hidden[:, 0, :] 462 | 463 | if return_hidden: 464 | return x, hidden 465 | 466 | return x 467 | 468 | def encode_text(self, text, return_hidden=False, mask=None): 469 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 470 | 471 | pos_emd = self.positional_embedding[:x.size(1), :].type(self.dtype) 472 | 473 | attn_mask = self.build_attention_mask(x.size(1)).repeat(x.size(0), 1, 1).to(mask.device) 474 | inf = torch.zeros((x.size(1), x.size(1))).fill_(float("-inf")).repeat(x.size(0), 1, 1).to(mask.device) 475 | mask = mask.unsqueeze(1).expand(-1, mask.size(1), -1) 476 | attn_mask = torch.where(mask>0, attn_mask, inf) 477 | 478 | x = x + pos_emd 479 | x = x.permute(1, 0, 2) # NLD -> LND 480 | x = self.transformer(x, attn_mask) 481 | x = x.permute(1, 0, 2) # LND -> NLD 482 | 483 | hidden = self.ln_final(x).type(self.dtype) @ self.text_projection 484 | 485 | # x.shape = [batch_size, n_ctx, transformer.width] 486 | # take features from the eot embedding (eot_token is the highest number in each sequence) 487 | x = hidden[torch.arange(hidden.shape[0]), text.argmax(dim=-1)] 488 | 489 | if return_hidden: 490 | return x, hidden 491 | 492 | return x 493 | 494 | def forward(self, image, text): 495 | image_features = self.encode_image(image) 496 | text_features = self.encode_text(text) 497 | 498 | # normalized features 499 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 500 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 501 | 502 | # cosine similarity as logits 503 | logit_scale = self.logit_scale.exp() 504 | logits_per_image = logit_scale * image_features @ text_features.t() 505 | logits_per_text = logit_scale * text_features @ image_features.t() 506 | 507 | # shape = [global_batch_size, global_batch_size] 508 | return logits_per_image, logits_per_text 509 | 510 | 511 | def convert_weights(model: nn.Module): 512 | """Convert applicable model parameters to fp16""" 513 | 514 | def _convert_weights_to_fp16(l): 515 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)): 516 | l.weight.data = l.weight.data.half() 517 | if l.bias is not None: 518 | l.bias.data = l.bias.data.half() 519 | 520 | if isinstance(l, nn.MultiheadAttention): 521 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 
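                # These attention projection tensors live on nn.MultiheadAttention as plain
                # nn.Parameter attributes rather than Linear submodules, so they are cast to
                # fp16 explicitly here (out_proj is an nn.Linear and is converted when
                # model.apply() visits it).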
522 | tensor = getattr(l, attr) 523 | if tensor is not None: 524 | tensor.data = tensor.data.half() 525 | 526 | for name in ["text_projection", "proj"]: 527 | if hasattr(l, name): 528 | attr = getattr(l, name) 529 | if attr is not None: 530 | attr.data = attr.data.half() 531 | 532 | model.apply(_convert_weights_to_fp16) -------------------------------------------------------------------------------- /tvr/models/module_cross.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import logging 6 | from timm.models.layers import drop_path 7 | import torch 8 | from torch import nn 9 | from .until_module import LayerNorm, ACT2FN 10 | from collections import OrderedDict 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | PRETRAINED_MODEL_ARCHIVE_MAP = {} 15 | CONFIG_NAME = 'cross_config.json' 16 | WEIGHTS_NAME = 'cross_pytorch_model.bin' 17 | 18 | 19 | class QuickGELU(nn.Module): 20 | def forward(self, x: torch.Tensor): 21 | return x * torch.sigmoid(1.702 * x) 22 | 23 | 24 | class DropPath(nn.Module): 25 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 26 | """ 27 | 28 | def __init__(self, drop_prob=None): 29 | super(DropPath, self).__init__() 30 | self.drop_prob = drop_prob 31 | 32 | def forward(self, x): 33 | return drop_path(x, self.drop_prob, self.training) 34 | 35 | def extra_repr(self) -> str: 36 | return 'p={}'.format(self.drop_prob) 37 | 38 | 39 | class ResidualAttentionBlock(nn.Module): 40 | def __init__(self, d_model: int, n_head: int, drop_path=0.0): 41 | super().__init__() 42 | 43 | self.attn = nn.MultiheadAttention(d_model, n_head) 44 | self.ln_1 = LayerNorm(d_model) 45 | self.mlp = nn.Sequential(OrderedDict([ 46 | ("c_fc", nn.Linear(d_model, d_model * 4)), 47 | ("gelu", QuickGELU()), 48 | ("c_proj", nn.Linear(d_model * 4, d_model)) 49 | ])) 50 | self.ln_2 = LayerNorm(d_model) 51 | self.n_head = n_head 52 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 53 | 54 | def attention(self, x: torch.Tensor, attn_mask: torch.Tensor): 55 | attn_mask_ = attn_mask.repeat_interleave(self.n_head, dim=0) 56 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] 57 | 58 | def forward(self, para_tuple: tuple): 59 | # x: torch.Tensor, attn_mask: torch.Tensor 60 | # print(para_tuple) 61 | x, attn_mask = para_tuple 62 | if self.training: 63 | x = x + self.drop_path(self.attention(self.ln_1(x), attn_mask)) 64 | x = x + self.drop_path(self.mlp(self.ln_2(x))) 65 | else: 66 | x = x + self.attention(self.ln_1(x), attn_mask) 67 | x = x + self.mlp(self.ln_2(x)) 68 | return (x, attn_mask) 69 | 70 | 71 | class Transformer(nn.Module): 72 | def __init__(self, width: int, layers: int, heads: int): 73 | super().__init__() 74 | self.width = width 75 | self.layers = layers 76 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads) for _ in range(layers)]) 77 | 78 | def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): 79 | return self.resblocks((x, attn_mask))[0] 80 | 81 | 82 | class CrossEmbeddings(nn.Module): 83 | """Construct the embeddings from word, position and token_type embeddings. 
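    In this implementation only the learned position embeddings are actually added to the
    input (the token_type embeddings and LayerNorm are left commented out below); the sum
    is then passed through dropout.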
84 | """ 85 | 86 | def __init__(self, config): 87 | super(CrossEmbeddings, self).__init__() 88 | 89 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 90 | # self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 91 | # self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12) 92 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 93 | 94 | def forward(self, concat_embeddings, concat_type=None): 95 | batch_size, seq_length = concat_embeddings.size(0), concat_embeddings.size(1) 96 | # if concat_type is None: 97 | # concat_type = torch.zeros(batch_size, concat_type).to(concat_embeddings.device) 98 | 99 | position_ids = torch.arange(seq_length, dtype=torch.long, device=concat_embeddings.device) 100 | position_ids = position_ids.unsqueeze(0).expand(concat_embeddings.size(0), -1) 101 | 102 | # token_type_embeddings = self.token_type_embeddings(concat_type) 103 | position_embeddings = self.position_embeddings(position_ids) 104 | 105 | embeddings = concat_embeddings + position_embeddings # + token_type_embeddings 106 | # embeddings = self.LayerNorm(embeddings) 107 | embeddings = self.dropout(embeddings) 108 | return embeddings 109 | 110 | 111 | class CrossPooler(nn.Module): 112 | def __init__(self, config): 113 | super(CrossPooler, self).__init__() 114 | self.ln_pool = LayerNorm(config.hidden_size) 115 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 116 | self.activation = QuickGELU() 117 | 118 | def forward(self, hidden_states, hidden_mask): 119 | # We "pool" the model by simply taking the hidden state corresponding 120 | # to the first token. 121 | hidden_states = self.ln_pool(hidden_states) 122 | pooled_output = hidden_states[:, 0] 123 | pooled_output = self.dense(pooled_output) 124 | pooled_output = self.activation(pooled_output) 125 | return pooled_output 126 | 127 | 128 | class CrossModel(nn.Module): 129 | 130 | def initialize_parameters(self): 131 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 132 | attn_std = self.transformer.width ** -0.5 133 | fc_std = (2 * self.transformer.width) ** -0.5 134 | for block in self.transformer.resblocks: 135 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 136 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 137 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 138 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 139 | 140 | def __init__(self, config): 141 | super(CrossModel, self).__init__() 142 | self.config = config 143 | 144 | self.embeddings = CrossEmbeddings(config) 145 | 146 | transformer_width = config.hidden_size 147 | transformer_layers = config.num_hidden_layers 148 | transformer_heads = config.num_attention_heads 149 | self.transformer = Transformer(width=transformer_width, layers=transformer_layers, heads=transformer_heads, ) 150 | self.pooler = CrossPooler(config) 151 | self.apply(self.init_weights) 152 | 153 | def build_attention_mask(self, attention_mask): 154 | extended_attention_mask = attention_mask.unsqueeze(1) 155 | extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility 156 | extended_attention_mask = (1.0 - extended_attention_mask) * -1000000.0 157 | extended_attention_mask = extended_attention_mask.expand(-1, attention_mask.size(1), -1) 158 | return extended_attention_mask 159 | 160 | def forward(self, concat_input, concat_type=None, attention_mask=None, output_all_encoded_layers=True): 161 | 162 | if attention_mask is None: 163 | 
attention_mask = torch.ones(concat_input.size(0), concat_input.size(1)) 164 | if concat_type is None: 165 | concat_type = torch.zeros_like(attention_mask) 166 | 167 | extended_attention_mask = self.build_attention_mask(attention_mask) 168 | 169 | embedding_output = self.embeddings(concat_input, concat_type) 170 | embedding_output = embedding_output.permute(1, 0, 2) # NLD -> LND 171 | embedding_output = self.transformer(embedding_output, extended_attention_mask) 172 | embedding_output = embedding_output.permute(1, 0, 2) # LND -> NLD 173 | 174 | pooled_output = self.pooler(embedding_output, hidden_mask=attention_mask) 175 | 176 | return embedding_output, pooled_output 177 | 178 | @property 179 | def dtype(self): 180 | """ 181 | :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). 182 | """ 183 | try: 184 | return next(self.parameters()).dtype 185 | except StopIteration: 186 | # For nn.DataParallel compatibility in PyTorch 1.5 187 | def find_tensor_attributes(module: nn.Module): 188 | tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] 189 | return tuples 190 | 191 | gen = self._named_members(get_members_fn=find_tensor_attributes) 192 | first_tuple = next(gen) 193 | return first_tuple[1].dtype 194 | 195 | def init_weights(self, module): 196 | """ Initialize the weights. 197 | """ 198 | if isinstance(module, (nn.Linear, nn.Embedding)): 199 | # Slightly different from the TF version which uses truncated_normal for initialization 200 | # cf https://github.com/pytorch/pytorch/pull/5617 201 | module.weight.data.normal_(mean=0.0, std=0.02) 202 | elif isinstance(module, LayerNorm): 203 | if 'beta' in dir(module) and 'gamma' in dir(module): 204 | module.beta.data.zero_() 205 | module.gamma.data.fill_(1.0) 206 | else: 207 | module.bias.data.zero_() 208 | module.weight.data.fill_(1.0) 209 | if isinstance(module, nn.Linear) and module.bias is not None: 210 | module.bias.data.zero_() 211 | -------------------------------------------------------------------------------- /tvr/models/module_transformer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import logging 6 | from timm.models.layers import drop_path 7 | import torch 8 | from torch import nn 9 | from .until_module import LayerNorm, ACT2FN 10 | from collections import OrderedDict 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | PRETRAINED_MODEL_ARCHIVE_MAP = {} 15 | CONFIG_NAME = 'cross_config.json' 16 | WEIGHTS_NAME = 'cross_pytorch_model.bin' 17 | 18 | 19 | class LayerNorm(nn.LayerNorm): 20 | """Subclass torch's LayerNorm to handle fp16.""" 21 | 22 | def forward(self, x: torch.Tensor): 23 | orig_type = x.dtype 24 | ret = super().forward(x.type(torch.float32)) 25 | return ret.type(orig_type) 26 | 27 | 28 | class QuickGELU(nn.Module): 29 | def forward(self, x: torch.Tensor): 30 | return x * torch.sigmoid(1.702 * x) 31 | 32 | 33 | class ResidualAttentionBlock(nn.Module): 34 | def __init__(self, d_model: int, n_head: int, attn_mask=None): 35 | super(ResidualAttentionBlock, self).__init__() 36 | 37 | self.attn = nn.MultiheadAttention(d_model, n_head) 38 | self.ln_1 = LayerNorm(d_model) 39 | self.mlp = nn.Sequential(OrderedDict([ 40 | ("c_fc", nn.Linear(d_model, d_model * 4)), 41 | ("gelu", QuickGELU()), 42 | ("c_proj", nn.Linear(d_model * 4, d_model)) 43 | ])) 44 | self.ln_2 = LayerNorm(d_model) 45 | 
self.attn_mask = attn_mask 46 | self.n_head = n_head 47 | 48 | def attention(self, x: torch.Tensor, attn_mask_: torch.Tensor): 49 | attn_mask_ = attn_mask_.repeat_interleave(self.n_head, dim=0) 50 | attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None 51 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] 52 | 53 | def forward(self, para_tuple: tuple): 54 | x, attn_mask = para_tuple 55 | x = x + self.attention(self.ln_1(x), attn_mask) 56 | x = x + self.mlp(self.ln_2(x)) 57 | return (x, attn_mask) 58 | 59 | 60 | class Transformer(nn.Module): 61 | def __init__(self, width: int, layers: int, heads: int, attn_mask=None): 62 | super(Transformer, self).__init__() 63 | self.width = width 64 | self.layers = layers 65 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads) for _ in range(layers)]) 66 | 67 | def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): 68 | return self.resblocks((x, attn_mask))[0] -------------------------------------------------------------------------------- /tvr/models/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | def warmup_cosine(x, warmup=0.002): 27 | if x < warmup: 28 | return x/warmup 29 | return 0.5 * (1.0 + math.cos(math.pi * x)) 30 | 31 | def warmup_constant(x, warmup=0.002): 32 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. 33 | Learning rate is 1. afterwards. """ 34 | if x < warmup: 35 | return x/warmup 36 | return 1.0 37 | 38 | def warmup_linear(x, warmup=0.002): 39 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. 40 | After `t_total`-th training step, learning rate is zero. """ 41 | if x < warmup: 42 | return x/warmup 43 | return max((x-1.)/(warmup-1.), 0) 44 | 45 | SCHEDULES = { 46 | 'warmup_cosine': warmup_cosine, 47 | 'warmup_constant': warmup_constant, 48 | 'warmup_linear': warmup_linear, 49 | } 50 | 51 | 52 | class BertAdam(Optimizer): 53 | """Implements BERT version of Adam algorithm with weight decay fix. 54 | Params: 55 | lr: learning rate 56 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 57 | t_total: total number of training steps for the learning 58 | rate schedule, -1 means constant learning rate. Default: -1 59 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 60 | b1: Adams b1. Default: 0.9 61 | b2: Adams b2. 
Default: 0.999 62 | e: Adams epsilon. Default: 1e-6 63 | weight_decay: Weight decay. Default: 0.01 64 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 65 | """ 66 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 67 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 68 | max_grad_norm=1.0): 69 | if lr is not required and lr < 0.0: 70 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 71 | if schedule not in SCHEDULES: 72 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 73 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 74 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 75 | if not 0.0 <= b1 < 1.0: 76 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 77 | if not 0.0 <= b2 < 1.0: 78 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 79 | if not e >= 0.0: 80 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 81 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 82 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 83 | max_grad_norm=max_grad_norm) 84 | super(BertAdam, self).__init__(params, defaults) 85 | 86 | def get_lr(self): 87 | lr = [] 88 | for group in self.param_groups: 89 | for p in group['params']: 90 | if p.grad is None: 91 | continue 92 | state = self.state[p] 93 | if len(state) == 0: 94 | return [0] 95 | if group['t_total'] != -1: 96 | schedule_fct = SCHEDULES[group['schedule']] 97 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 98 | else: 99 | lr_scheduled = group['lr'] 100 | lr.append(lr_scheduled) 101 | return lr 102 | 103 | def step(self, closure=None): 104 | """Performs a single optimization step. 105 | Arguments: 106 | closure (callable, optional): A closure that reevaluates the model 107 | and returns the loss. 
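        Note: learning-rate warmup/decay (via SCHEDULES) and gradient clipping
        (max_grad_norm) are applied inside this method itself, so no external LR scheduler
        is needed. Illustrative usage, assuming `optimizer` is a BertAdam instance:
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()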
108 | """ 109 | loss = None 110 | if closure is not None: 111 | loss = closure() 112 | 113 | for group in self.param_groups: 114 | for p in group['params']: 115 | if p.grad is None: 116 | continue 117 | grad = p.grad.data 118 | if grad.is_sparse: 119 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 120 | 121 | state = self.state[p] 122 | 123 | # State initialization 124 | if len(state) == 0: 125 | state['step'] = 0 126 | # Exponential moving average of gradient values 127 | state['next_m'] = torch.zeros_like(p.data) 128 | # Exponential moving average of squared gradient values 129 | state['next_v'] = torch.zeros_like(p.data) 130 | 131 | next_m, next_v = state['next_m'], state['next_v'] 132 | beta1, beta2 = group['b1'], group['b2'] 133 | 134 | # Add grad clipping 135 | if group['max_grad_norm'] > 0: 136 | clip_grad_norm_(p, group['max_grad_norm']) 137 | 138 | # Decay the first and second moment running average coefficient 139 | # In-place operations to update the averages at the same time 140 | # next_m.mul_(beta1).add_(1 - beta1, grad) --> pytorch 1.7 141 | next_m.mul_(beta1).add_(grad, alpha=1 - beta1) 142 | # next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) --> pytorch 1.7 143 | next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 144 | update = next_m / (next_v.sqrt() + group['e']) 145 | 146 | # Just adding the square of the weights to the loss function is *not* 147 | # the correct way of using L2 regularization/weight decay with Adam, 148 | # since that will interact with the m and v parameters in strange ways. 149 | # 150 | # Instead we want to decay the weights in a manner that doesn't interact 151 | # with the m/v parameters. This is equivalent to adding the square 152 | # of the weights to the loss with plain (non-momentum) SGD. 153 | if group['weight_decay'] > 0.0: 154 | update += group['weight_decay'] * p.data 155 | 156 | if group['t_total'] != -1: 157 | schedule_fct = SCHEDULES[group['schedule']] 158 | progress = state['step']/group['t_total'] 159 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) 160 | else: 161 | lr_scheduled = group['lr'] 162 | 163 | update_with_lr = lr_scheduled * update 164 | p.data.add_(-update_with_lr) 165 | 166 | state['step'] += 1 167 | 168 | return loss -------------------------------------------------------------------------------- /tvr/models/tokenization_clip.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 
25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | self.vocab = self.encoder 81 | 82 | def bpe(self, token): 83 | if token in self.cache: 84 | return self.cache[token] 85 | word = tuple(token[:-1]) + ( token[-1] + '',) 86 | pairs = get_pairs(word) 87 | 88 | if not pairs: 89 | return token+'' 90 | 91 | while True: 92 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 93 | if bigram not in self.bpe_ranks: 94 | break 95 | first, second = bigram 96 | new_word = [] 97 | i = 0 98 | while i < len(word): 99 | try: 100 | j = word.index(first, i) 101 | new_word.extend(word[i:j]) 102 | i = j 103 | except: 104 | new_word.extend(word[i:]) 105 | break 106 | 107 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 108 | new_word.append(first+second) 109 | i += 2 110 | else: 111 | new_word.append(word[i]) 112 | i += 1 113 | new_word = tuple(new_word) 114 | word = new_word 115 | if len(word) == 1: 116 | break 117 | else: 118 | pairs = get_pairs(word) 119 | word = ' '.join(word) 120 | self.cache[token] = word 121 | return word 122 | 123 | def encode(self, text): 124 | bpe_tokens = [] 125 | text = whitespace_clean(basic_clean(text)).lower() 126 | for token in re.findall(self.pat, text): 127 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 128 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 129 | return bpe_tokens 130 | 131 | def decode(self, tokens): 132 | text = ''.join([self.decoder[token] for token in tokens]) 133 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 134 | return text 135 | 136 | def tokenize(self, text): 137 | 
tokens = [] 138 | text = whitespace_clean(basic_clean(text)).lower() 139 | for token in re.findall(self.pat, text): 140 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 141 | tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) 142 | return tokens 143 | 144 | def convert_tokens_to_ids(self, tokens): 145 | return [self.encoder[bpe_token] for bpe_token in tokens] -------------------------------------------------------------------------------- /tvr/models/transformer_block.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on the implementation of https://github.com/jadore801120/attention-is-all-you-need-pytorch 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | 9 | class ScaledDotProductAttention(nn.Module): 10 | ''' Scaled Dot-Product Attention ''' 11 | def __init__(self, temperature, attn_dropout=0.1): 12 | super().__init__() 13 | self.temperature = temperature 14 | self.dropout = nn.Dropout(attn_dropout) 15 | self.softmax = nn.Softmax(dim=2) 16 | 17 | def forward(self, q, k, v, mask=None): 18 | """ 19 | Args: 20 | q (bsz, len_q, dim_q) 21 | k (bsz, len_k, dim_k) 22 | v (bsz, len_v, dim_v) 23 | Note: len_k==len_v, and dim_q==dim_k 24 | Returns: 25 | output (bsz, len_q, dim_v) 26 | attn (bsz, len_q, len_k) 27 | """ 28 | attn = torch.bmm(q, k.transpose(1, 2)) 29 | attn = attn / self.temperature 30 | 31 | if mask is not None: 32 | attn = attn.masked_fill(mask, -np.inf) 33 | 34 | attn = self.softmax(attn) 35 | attn = self.dropout(attn) 36 | output = torch.bmm(attn, v) 37 | 38 | return output, attn 39 | 40 | 41 | class MultiHeadAttention(nn.Module): 42 | ''' Multi-Head Attention module ''' 43 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 44 | super().__init__() 45 | self.n_head = n_head 46 | self.d_k = d_k 47 | self.d_v = d_v 48 | 49 | self.w_qs = nn.Linear(d_model, n_head * d_k) 50 | self.w_ks = nn.Linear(d_model, n_head * d_k) 51 | self.w_vs = nn.Linear(d_model, n_head * d_v) 52 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 53 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 54 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) 55 | 56 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 57 | self.layer_norm = nn.LayerNorm(d_model) 58 | 59 | self.fc = nn.Linear(n_head * d_v, d_model) 60 | nn.init.xavier_normal_(self.fc.weight) 61 | 62 | self.dropout = nn.Dropout(dropout) 63 | 64 | 65 | def forward(self, q, k, v, mask=None): 66 | """ 67 | Args: 68 | q (bsz, len_q, dim_q) 69 | k (bsz, len_k, dim_k) 70 | v (bsz, len_v, dim_v) 71 | Note: len_k==len_v, and dim_q==dim_k 72 | Returns: 73 | output (bsz, len_q, d_model) 74 | attn (bsz, len_q, len_k) 75 | """ 76 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 77 | 78 | sz_b, len_q, _ = q.size() 79 | sz_b, len_k, _ = k.size() 80 | sz_b, len_v, _ = v.size() # len_k==len_v 81 | 82 | residual = q 83 | 84 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 85 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 86 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 87 | 88 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 89 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 90 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 91 | 92 | if mask is not None: 93 | mask = mask.repeat(n_head, 
1, 1) # (n*b) x .. x .. 94 | output, attn = self.attention(q, k, v, mask=mask) 95 | 96 | output = output.view(n_head, sz_b, len_q, d_v) 97 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 98 | 99 | output = self.dropout(self.fc(output)) 100 | output = self.layer_norm(output + residual) 101 | 102 | return output 103 | 104 | 105 | class PositionwiseFeedForward(nn.Module): 106 | ''' A two-feed-forward-layer module ''' 107 | def __init__(self, d_in, d_hid, dropout=0.1): 108 | super().__init__() 109 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise 110 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise 111 | self.layer_norm = nn.LayerNorm(d_in) 112 | self.dropout = nn.Dropout(dropout) 113 | 114 | def forward(self, x): 115 | """ 116 | Merge adjacent information. Equal to linear layer if kernel size is 1 117 | Args: 118 | x (bsz, len, dim) 119 | Returns: 120 | output (bsz, len, dim) 121 | """ 122 | residual = x 123 | output = x.transpose(1, 2) 124 | output = self.w_2(F.relu(self.w_1(output))) 125 | output = output.transpose(1, 2) 126 | output = self.dropout(output) 127 | output = self.layer_norm(output + residual) 128 | return output 129 | 130 | 131 | class EncoderLayer(nn.Module): 132 | ''' Compose with two layers ''' 133 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 134 | super(EncoderLayer, self).__init__() 135 | self.slf_attn = MultiHeadAttention( 136 | n_head, d_model, d_k, d_v, dropout=dropout) 137 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 138 | 139 | def forward(self, Q, K, V, non_pad_mask=None, slf_attn_mask=None): 140 | enc_output = self.slf_attn( 141 | Q, K, V, mask=slf_attn_mask) 142 | # enc_output *= non_pad_mask.float() if non_pad_mask is not None else 1. 143 | 144 | enc_output = self.pos_ffn(enc_output) 145 | # enc_output *= non_pad_mask.float() 146 | 147 | return enc_output 148 | 149 | -------------------------------------------------------------------------------- /tvr/models/until_module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model.""" 17 | 18 | import logging 19 | import torch 20 | from torch import nn 21 | import torch.nn.functional as F 22 | import math 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def gelu(x): 28 | """Implementation of the gelu activation function. 
29 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 30 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 31 | """ 32 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 33 | 34 | 35 | def swish(x): 36 | return x * torch.sigmoid(x) 37 | 38 | 39 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 40 | 41 | 42 | class LayerNorm(nn.Module): 43 | def __init__(self, hidden_size, eps=1e-12): 44 | """Construct a layernorm module in the TF style (epsilon inside the square root). 45 | """ 46 | super(LayerNorm, self).__init__() 47 | self.weight = nn.Parameter(torch.ones(hidden_size)) 48 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 49 | self.variance_epsilon = eps 50 | 51 | def forward(self, x): 52 | u = x.mean(-1, keepdim=True) 53 | s = (x - u).pow(2).mean(-1, keepdim=True) 54 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 55 | return self.weight * x + self.bias 56 | 57 | 58 | ################################## 59 | ###### LOSS FUNCTION ############# 60 | ################################## 61 | class CrossEn(nn.Module): 62 | def __init__(self, config=None): 63 | super(CrossEn, self).__init__() 64 | 65 | def forward(self, sim_matrix): 66 | logpt = F.log_softmax(sim_matrix, dim=-1) 67 | logpt = torch.diag(logpt) 68 | nce_loss = -logpt 69 | sim_loss = nce_loss.mean() 70 | return sim_loss 71 | 72 | 73 | class ArcCrossEn(nn.Module): 74 | def __init__(self, margin=10): 75 | super(ArcCrossEn, self).__init__() 76 | self.cos_m = math.cos(margin) 77 | self.sin_m = math.sin(margin) 78 | 79 | def forward(self, sim_matrix, scale): 80 | cos = torch.diag(sim_matrix) 81 | sin = torch.sqrt(1.0 - torch.pow(cos, 2)) 82 | pin = cos * self.cos_m - sin * self.sin_m 83 | sim_matrix = sim_matrix - torch.diag_embed(cos) + torch.diag_embed(pin) 84 | logpt = F.log_softmax(sim_matrix / scale, dim=-1) 85 | logpt = torch.diag(logpt) 86 | nce_loss = -logpt 87 | sim_loss = nce_loss.mean() 88 | return sim_loss 89 | 90 | 91 | class CrossEn0(nn.Module): 92 | def __init__(self, config=None): 93 | super(CrossEn0, self).__init__() 94 | 95 | def forward(self, sim_matrix, b): 96 | logpt = F.log_softmax(sim_matrix[:b, :], dim=-1) 97 | logpt = torch.diag(logpt[:, :b]) 98 | nce_loss = -logpt 99 | sim_loss = nce_loss.mean() 100 | return sim_loss 101 | 102 | 103 | class ema_CrossEn(nn.Module): 104 | def __init__(self, config=None): 105 | super(ema_CrossEn, self).__init__() 106 | 107 | def forward(self, sim_matrix0, sim_matrix1): 108 | m, n = sim_matrix0.size() 109 | diag1 = torch.diag(sim_matrix1) 110 | diag1 = torch.diag_embed(diag1) 111 | sim_matrix1 = sim_matrix1 - diag1 112 | logpt = F.log_softmax(torch.cat([sim_matrix0, sim_matrix1], dim=-1), dim=-1) 113 | logpt = torch.diag(logpt[:, :n]) 114 | nce_loss = -logpt 115 | sim_loss = nce_loss.mean() 116 | return sim_loss 117 | 118 | 119 | class DC_CrossEn(nn.Module): 120 | def __init__(self, config=None): 121 | super(DC_CrossEn, self).__init__() 122 | 123 | def forward(self, sim_matrix0, sim_matrix1, seta=0.8): 124 | diag0 = torch.diag(sim_matrix0) 125 | diag1 = torch.diag(sim_matrix1) 126 | sim_matrix0 = sim_matrix0 - diag0 127 | sim_matrix1 = sim_matrix1 - diag1 128 | m, n = sim_matrix0.size() 129 | 130 | sim_matrix = torch.where(sim_matrix1 < seta, sim_matrix0, torch.tensor(0.0).to(sim_matrix0.device)) 131 | sim_matrix = sim_matrix + diag0 132 | 133 | logpt = F.log_softmax(sim_matrix, dim=-1) 134 | logpt = torch.diag(logpt) 135 | nce_loss = -logpt 136 | sim_loss 
= nce_loss.mean() 137 | return sim_loss 138 | 139 | 140 | class ema_CrossEn1(nn.Module): 141 | def __init__(self, config=None): 142 | super(ema_CrossEn1, self).__init__() 143 | 144 | def forward(self, sim_matrix0, sim_matrix1): 145 | logpt0 = F.log_softmax(sim_matrix0, dim=-1) 146 | logpt1 = F.softmax(sim_matrix1, dim=-1) 147 | sim_loss = - logpt0 * logpt1 148 | # diag = torch.diag(sim_loss) 149 | # sim_loss = sim_loss - diag 150 | sim_loss = sim_loss.mean() 151 | return sim_loss 152 | 153 | 154 | class ema_CrossEn2(nn.Module): 155 | def __init__(self, config=None): 156 | super(ema_CrossEn2, self).__init__() 157 | 158 | def forward(self, sim_matrix0, sim_matrix1, lambd=0.5): 159 | m, n = sim_matrix1.size() 160 | 161 | logpt0 = F.log_softmax(sim_matrix0, dim=-1) 162 | logpt1 = F.softmax(sim_matrix1, dim=-1) 163 | logpt1 = lambd * torch.eye(m).to(logpt1.device) + (1 - lambd) * logpt1 164 | 165 | sim_loss = - logpt0 * logpt1 166 | sim_loss = sim_loss.sum() / m 167 | return sim_loss 168 | 169 | 170 | class KL(nn.Module): 171 | def __init__(self, config=None): 172 | super(KL, self).__init__() 173 | 174 | def forward(self, sim_matrix0, sim_matrix1): 175 | logpt0 = F.log_softmax(sim_matrix0, dim=-1) 176 | logpt1 = F.softmax(sim_matrix1, dim=-1) 177 | kl = F.kl_div(logpt0, logpt1, reduction='mean') 178 | # kl = F.kl_div(logpt0, logpt1, reduction='sum') 179 | return kl 180 | 181 | 182 | def _batch_hard(mat_distance, mat_similarity, indice=False): 183 | sorted_mat_distance, positive_indices = torch.sort(mat_distance + (9999999.) * (1 - mat_similarity), dim=1, 184 | descending=False) 185 | hard_p = sorted_mat_distance[:, 0] 186 | hard_p_indice = positive_indices[:, 0] 187 | sorted_mat_distance, negative_indices = torch.sort(mat_distance + (-9999999.) * (mat_similarity), dim=1, 188 | descending=True) 189 | hard_n = sorted_mat_distance[:, 0] 190 | hard_n_indice = negative_indices[:, 0] 191 | if (indice): 192 | return hard_p, hard_n, hard_p_indice, hard_n_indice 193 | return hard_p, hard_n 194 | 195 | 196 | class SoftTripletLoss(nn.Module): 197 | def __init__(self, config=None): 198 | super(SoftTripletLoss, self).__init__() 199 | 200 | def forward(self, sim_matrix0, sim_matrix1): 201 | N = sim_matrix0.size(0) 202 | mat_sim = torch.eye(N).float().to(sim_matrix0.device) 203 | dist_ap, dist_an, ap_idx, an_idx = _batch_hard(sim_matrix0, mat_sim, indice=True) 204 | triple_dist = torch.stack((dist_ap, dist_an), dim=1) 205 | triple_dist = F.log_softmax(triple_dist, dim=1) 206 | dist_ap_ref = torch.gather(sim_matrix1, 1, ap_idx.view(N, 1).expand(N, N))[:, 0] 207 | dist_an_ref = torch.gather(sim_matrix1, 1, an_idx.view(N, 1).expand(N, N))[:, 0] 208 | triple_dist_ref = torch.stack((dist_ap_ref, dist_an_ref), dim=1) 209 | triple_dist_ref = F.softmax(triple_dist_ref, dim=1).detach() 210 | loss = (- triple_dist_ref * triple_dist).mean(0).sum() 211 | return loss 212 | 213 | 214 | class MSE(nn.Module): 215 | def __init__(self, config=None): 216 | super(MSE, self).__init__() 217 | 218 | def forward(self, sim_matrix0, sim_matrix1): 219 | logpt = (sim_matrix0 - sim_matrix1) 220 | loss = logpt * logpt 221 | return loss.mean() 222 | 223 | 224 | def euclidean_dist(x, y): 225 | m, n = x.size(0), y.size(0) 226 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 227 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 228 | dist = xx + yy 229 | dist.addmm_(1, -2, x, y.t()) 230 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 231 | return dist 232 | 233 | 234 | def uniformity_loss(x, y): 235 | input 
= torch.cat((x, y), dim=0) 236 | m = input.size(0) 237 | dist = euclidean_dist(input, input) 238 | return torch.logsumexp(torch.logsumexp(dist, dim=-1), dim=-1) - torch.log(torch.tensor(m * m - m)) 239 | 240 | 241 | class AllGather(torch.autograd.Function): 242 | """An autograd function that performs allgather on a tensor.""" 243 | 244 | @staticmethod 245 | def forward(ctx, tensor, args): 246 | if args.world_size == 1: 247 | ctx.rank = args.local_rank 248 | ctx.batch_size = tensor.shape[0] 249 | return tensor 250 | else: 251 | output = [torch.empty_like(tensor) for _ in range(args.world_size)] 252 | torch.distributed.all_gather(output, tensor) 253 | ctx.rank = args.local_rank 254 | ctx.batch_size = tensor.shape[0] 255 | return torch.cat(output, dim=0) 256 | 257 | @staticmethod 258 | def backward(ctx, grad_output): 259 | return ( 260 | grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)], 261 | None, 262 | ) 263 | 264 | 265 | class AllGather2(torch.autograd.Function): 266 | """An autograd function that performs allgather on a tensor.""" 267 | 268 | # https://github.com/PyTorchLightning/lightning-bolts/blob/8d3fbf7782e3d3937ab8a1775a7092d7567f2933/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20 269 | @staticmethod 270 | def forward(ctx, tensor, args): 271 | if args.world_size == 1: 272 | ctx.rank = args.local_rank 273 | ctx.batch_size = tensor.shape[0] 274 | return tensor 275 | else: 276 | output = [torch.empty_like(tensor) for _ in range(args.world_size)] 277 | torch.distributed.all_gather(output, tensor) 278 | ctx.rank = args.local_rank 279 | ctx.batch_size = tensor.shape[0] 280 | return torch.cat(output, dim=0) 281 | 282 | @staticmethod 283 | def backward(ctx, grad_output): 284 | grad_input = grad_output.clone() 285 | torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False) 286 | return (grad_input[ctx.rank * ctx.batch_size:(ctx.rank + 1) * ctx.batch_size], None) -------------------------------------------------------------------------------- /tvr/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/utils/__init__.py -------------------------------------------------------------------------------- /tvr/utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/utils/__pycache__/comm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/utils/__pycache__/comm.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/utils/__pycache__/logger.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/utils/__pycache__/logger.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/utils/__pycache__/metric_logger.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/utils/__pycache__/metric_logger.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/utils/__pycache__/metrics.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/utils/__pycache__/metrics.cpython-39.pyc -------------------------------------------------------------------------------- /tvr/utils/comm.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains primitives for multi-gpu communication. 3 | This is useful when doing distributed training. 4 | """ 5 | 6 | import time 7 | import pickle 8 | 9 | import torch 10 | import torch.distributed as dist 11 | 12 | 13 | def get_world_size(): 14 | if not dist.is_available(): 15 | return 1 16 | if not dist.is_initialized(): 17 | return 1 18 | return dist.get_world_size() 19 | 20 | 21 | def get_rank(): 22 | if not dist.is_available(): 23 | return 0 24 | if not dist.is_initialized(): 25 | return 0 26 | return dist.get_rank() 27 | 28 | 29 | def is_main_process(): 30 | return get_rank() == 0 31 | 32 | 33 | def synchronize(): 34 | """ 35 | Helper function to synchronize (barrier) among all processes when 36 | using distributed training 37 | """ 38 | if not dist.is_available(): 39 | return 40 | if not dist.is_initialized(): 41 | return 42 | world_size = dist.get_world_size() 43 | if world_size == 1: 44 | return 45 | dist.barrier() 46 | 47 | 48 | def all_gather(data): 49 | """ 50 | Run all_gather on arbitrary picklable data (not necessarily tensors) 51 | Args: 52 | data: any picklable object 53 | Returns: 54 | list[data]: list of data gathered from each rank 55 | """ 56 | world_size = get_world_size() 57 | if world_size == 1: 58 | return [data] 59 | 60 | buffer = pickle.dumps(data) 61 | storage = torch.ByteStorage.from_buffer(buffer) 62 | tensor = torch.ByteTensor(storage).to("cuda") 63 | # obtain Tensor size of each rank 64 | local_size = torch.LongTensor([tensor.numel()]).to("cuda") 65 | size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)] 66 | dist.all_gather(size_list, local_size) 67 | size_list = [int(size.item()) for size in size_list] 68 | max_size = max(size_list) 69 | 70 | # receiving Tensor from all ranks 71 | # we pad the tensor because torch all_gather does not support 72 | # gathering tensors of different shapes 73 | tensor_list = [] 74 | for _ in size_list: 75 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 76 | if local_size != max_size: 77 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 78 | tensor = torch.cat((tensor, padding), dim=0) 79 | dist.all_gather(tensor_list, tensor) 80 | 81 | data_list = [] 82 | for size, tensor in zip(size_list, tensor_list): 83 | # print(get_rank(), size) 84 | buffer = tensor.cpu().numpy().tobytes()[:size] 85 | data_list.append(pickle.loads(buffer)) 86 | 87 | return data_list 88 | 89 | 90 | def reduce_dict(input_dict, average=True): 91 | """ 92 | Args: 93 | input_dict (dict): all the values will be reduced 94 | average (bool): whether to do average or sum 95 | Reduce the values in the dictionary from all processes so that process with rank 96 | 0 has the averaged results. Returns a dict with the same fields as 97 | input_dict, after reduction. 
98 | """ 99 | world_size = get_world_size() 100 | if world_size < 2: 101 | return input_dict 102 | with torch.no_grad(): 103 | names = [] 104 | values = [] 105 | # sort the keys so that they are consistent across processes 106 | for k in sorted(input_dict.keys()): 107 | names.append(k) 108 | values.append(input_dict[k]) 109 | values = torch.stack(values, dim=0) 110 | dist.reduce(values, dst=0) 111 | if dist.get_rank() == 0 and average: 112 | # only main process gets accumulated, so only divide by 113 | # world_size in this case 114 | values /= world_size 115 | reduced_dict = {k: v for k, v in zip(names, values)} 116 | return reduced_dict 117 | -------------------------------------------------------------------------------- /tvr/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | 6 | def setup_logger(name, save_dir, dist_rank, filename="log.txt"): 7 | logger = logging.getLogger(name) 8 | logger.setLevel(logging.ERROR) 9 | # don't log results for the non-master process 10 | if dist_rank > 0: 11 | return logger 12 | logger.setLevel(logging.DEBUG) 13 | ch = logging.StreamHandler(stream=sys.stdout) 14 | ch.setLevel(logging.DEBUG) 15 | # formatter = logging.Formatter(f"[{dist_rank}]"+"[%(asctime)s %(name)s %(lineno)s %(levelname)s]: %(message)s") 16 | formatter = logging.Formatter("[%(asctime)s %(name)s %(lineno)s %(levelname)s]: %(message)s") 17 | ch.setFormatter(formatter) 18 | logger.addHandler(ch) 19 | logger.propagate = False 20 | 21 | if save_dir: 22 | fh = logging.FileHandler(os.path.join(save_dir, filename)) 23 | fh.setLevel(logging.DEBUG) 24 | fh.setFormatter(formatter) 25 | logger.addHandler(fh) 26 | 27 | return logger 28 | -------------------------------------------------------------------------------- /tvr/utils/metric_logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from collections import defaultdict 3 | from collections import deque 4 | 5 | import torch 6 | 7 | 8 | class SmoothedValue(object): 9 | """Track a series of values and provide access to smoothed values over a 10 | window or the global series average. 
11 | """ 12 | 13 | def __init__(self, window_size=20): 14 | self.deque = deque(maxlen=window_size) 15 | self.series = [] 16 | self.total = 0.0 17 | self.count = 0 18 | 19 | def update(self, value): 20 | self.deque.append(value) 21 | self.series.append(value) 22 | self.count += 1 23 | self.total += value 24 | 25 | @property 26 | def median(self): 27 | d = torch.tensor(list(self.deque)) 28 | return d.median().item() 29 | 30 | @property 31 | def avg(self): 32 | d = torch.tensor(list(self.deque)) 33 | return d.mean().item() 34 | 35 | @property 36 | def global_avg(self): 37 | return self.total / self.count 38 | 39 | 40 | class MetricLogger(object): 41 | def __init__(self, delimiter="\t"): 42 | self.meters = defaultdict(SmoothedValue) 43 | self.delimiter = delimiter 44 | 45 | def update(self, **kwargs): 46 | for k, v in kwargs.items(): 47 | if isinstance(v, torch.Tensor): 48 | v = v.item() 49 | assert isinstance(v, (float, int)) 50 | self.meters[k].update(v) 51 | 52 | def __getattr__(self, attr): 53 | if attr in self.meters: 54 | return self.meters[attr] 55 | if attr in self.__dict__: 56 | return self.__dict__[attr] 57 | raise AttributeError("'{}' object has no attribute '{}'".format( 58 | type(self).__name__, attr)) 59 | 60 | def __str__(self): 61 | loss_str = [] 62 | for name, meter in self.meters.items(): 63 | loss_str.append( 64 | "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg) 65 | ) 66 | return self.delimiter.join(loss_str) 67 | -------------------------------------------------------------------------------- /tvr/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | import torch 8 | 9 | 10 | def compute_metrics(x, re_ranking=False, sim=None): 11 | 12 | # 在score1+2中对每行进行排序 13 | # sx = np.sort(-x, axis=1) # 默认从小到达,加-1从大到小 14 | sx = torch.sort(torch.Tensor(x), dim=1, descending=True)[0].numpy() 15 | # if re_ranking: 16 | # topk = 10 17 | # # 从 score1 每行选择前 topk 个值 18 | # topk_values, topk_indices = torch.topk(torch.Tensor(sim), topk, dim=1) 19 | 20 | # # 在 score2 中对每行进行排序(降序) 21 | # sorted_score2, original_indice = torch.sort(torch.Tensor(x), dim=1, descending=True) # original_indice为排序后的索引, original_indice[i]: 第original_indice[i]下标对应的数值现在排在第i个位置 22 | 23 | # # 找出排序后 score_2 对应的 topk_indices 的数值目前都被排在哪些位置 24 | # ind = torch.zeros_like(topk_indices) 25 | # for i in range(sorted_score2.shape[0]): 26 | # for j in range(topk_indices.shape[1]): 27 | # ind[i][j] = (original_indice[i] == topk_indices[i][j]).nonzero().item() # video 3 和 video 8 排序后分别排在X1位和X2位 28 | 29 | # ind, _ = torch.sort(ind, dim=1, descending=False) # 对ind进行排序,得到原始的排序索引 30 | 31 | # new_indices = original_indice.clone() # 深拷贝 32 | 33 | # for i in range(topk_indices.shape[0]): 34 | # for j in range(topk_indices.shape[1]): 35 | # new_indices[i][ind[i][j]] = topk_indices[i][j] 36 | 37 | # sx = torch.gather(torch.Tensor(x), 1, new_indices).numpy() # re-ranking 后的score2 38 | 39 | d = np.diag(x) 40 | d = d[:, np.newaxis] 41 | ind = sx - d 42 | ind = np.where(ind == 0) 43 | ind = ind[1] 44 | metrics = {} 45 | metrics['R1'] = float(np.sum(ind == 0)) * 100 / len(ind) 46 | metrics['R5'] = float(np.sum(ind < 5)) * 100 / len(ind) 47 | metrics['R10'] = float(np.sum(ind < 10)) * 100 / len(ind) 48 | metrics['R50'] = float(np.sum(ind < 50)) * 100 / len(ind) 49 | metrics['MR'] = np.median(ind) 
+ 1 50 | metrics["MedianR"] = metrics['MR'] 51 | metrics["MeanR"] = np.mean(ind) + 1 52 | metrics["cols"] = [int(i) for i in list(ind)] 53 | return metrics 54 | 55 | def print_computed_metrics(metrics): 56 | r1 = metrics['R1'] 57 | r5 = metrics['R5'] 58 | r10 = metrics['R10'] 59 | r50 = metrics['R50'] 60 | mr = metrics['MR'] 61 | meanr = metrics["MeanR"] 62 | print('R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - R@50: {:.4f} - Median R: {} - MeanR: {}'.format(r1, r5, r10, r50, 63 | mr, meanr)) 64 | 65 | 66 | # below two functions directly come from: https://github.com/Deferf/Experiments 67 | def tensor_text_to_video_metrics(sim_tensor, top_k=[1, 5, 10, 50]): 68 | if not torch.is_tensor(sim_tensor): 69 | sim_tensor = torch.tensor(sim_tensor) 70 | 71 | # Permute sim_tensor so it represents a sequence of text-video similarity matrices. 72 | # Then obtain the double argsort to position the rank on the diagonal 73 | stacked_sim_matrices = sim_tensor.permute(1, 0, 2) 74 | first_argsort = torch.argsort(stacked_sim_matrices, dim=-1, descending=True) 75 | second_argsort = torch.argsort(first_argsort, dim=-1, descending=False) 76 | 77 | # Extracts ranks i.e diagonals 78 | ranks = torch.flatten(torch.diagonal(second_argsort, dim1=1, dim2=2)) 79 | 80 | # Now we need to extract valid ranks, as some belong to inf padding values 81 | permuted_original_data = torch.flatten(torch.diagonal(sim_tensor, dim1=0, dim2=2)) 82 | mask = ~ torch.logical_or(torch.isinf(permuted_original_data), torch.isnan(permuted_original_data)) 83 | valid_ranks = ranks[mask] 84 | # A quick dimension check validates our results, there may be other correctness tests pending 85 | # Such as dot product localization, but that is for other time. 86 | # assert int(valid_ranks.shape[0]) == sum([len(text_dict[k]) for k in text_dict]) 87 | if not torch.is_tensor(valid_ranks): 88 | valid_ranks = torch.tensor(valid_ranks) 89 | results = {f"R{k}": float(torch.sum(valid_ranks < k) * 100 / len(valid_ranks)) for k in top_k} 90 | results["MedianR"] = float(torch.median(valid_ranks + 1)) 91 | results["MeanR"] = float(np.mean(valid_ranks.numpy() + 1)) 92 | results["Std_Rank"] = float(np.std(valid_ranks.numpy() + 1)) 93 | results['MR'] = results["MedianR"] 94 | return results 95 | 96 | 97 | def tensor_video_to_text_sim(sim_tensor): 98 | if not torch.is_tensor(sim_tensor): 99 | sim_tensor = torch.tensor(sim_tensor) 100 | # Code to avoid nans 101 | sim_tensor[sim_tensor != sim_tensor] = float('-inf') 102 | # Forms a similarity matrix for use with rank at k 103 | values, _ = torch.max(sim_tensor, dim=1, keepdim=True) 104 | return torch.squeeze(values).T 105 | 106 | 107 | if __name__ == '__main__': 108 | test_sim = np.random.rand(1000, 1000) 109 | metrics = compute_metrics(test_sim) 110 | print_computed_metrics(metrics) 111 | --------------------------------------------------------------------------------
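
A few illustrative usage sketches for the modules above follow; shapes, hyper-parameters, variable names and file paths in them are placeholders rather than values taken from the training scripts. First, the warmup functions at the top of optimization.py multiply the base learning rate by a factor that depends on training progress, and BertAdam looks that factor up inside step() instead of going through a separate scheduler object. A minimal sketch with a stand-in model:

# Sketch: wiring up BertAdam with its linear-warmup schedule.
# The tiny linear model, step count and hyper-parameters are placeholders.
import torch
from tvr.models.optimization import BertAdam

model = torch.nn.Linear(512, 512)      # stand-in for the real retrieval model
num_train_steps = 1000                 # in practice: epochs * batches_per_epoch

optimizer = BertAdam(model.parameters(),
                     lr=1e-4,
                     warmup=0.1,       # ramp the lr linearly over the first 10% of steps
                     t_total=num_train_steps,
                     schedule='warmup_linear',
                     weight_decay=0.01,
                     max_grad_norm=1.0)

x = torch.randn(8, 512)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()                       # lr is rescaled inside step() via warmup_linear(step / t_total)
optimizer.zero_grad()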
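
SimpleTokenizer reproduces CLIP's byte-level BPE: tokenize() yields sub-word strings (word-final pieces carry the </w> marker), convert_tokens_to_ids() maps them through the BPE vocabulary, and encode() fuses the two steps. A small sketch; wrapping the ids in <|startoftext|>/<|endoftext|> follows the usual CLIP convention, while padding and truncation to max_words are assumed to live in the dataloaders:

# Sketch: caption -> CLIP BPE ids with SimpleTokenizer.
from tvr.models.tokenization_clip import SimpleTokenizer

tokenizer = SimpleTokenizer()              # loads bpe_simple_vocab_16e6.txt.gz next to the module
caption = "A man is playing a guitar on stage."

tokens = tokenizer.tokenize(caption)       # BPE sub-word strings; word-final pieces end with '</w>'
ids = tokenizer.convert_tokens_to_ids(tokens)
assert ids == tokenizer.encode(caption)    # encode() does tokenize() + convert_tokens_to_ids()

sot = tokenizer.encoder["<|startoftext|>"]
eot = tokenizer.encoder["<|endoftext|>"]
input_ids = [sot] + ids + [eot]            # CLIP-style wrapping (convention, not enforced here)
print(len(input_ids), tokenizer.decode(ids))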
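
transformer_block.py is a compact encoder-style stack: ScaledDotProductAttention computes softmax(QK^T / temperature)V, MultiHeadAttention adds the head split/merge plus residual and LayerNorm, and EncoderLayer appends the position-wise feed-forward. A shape-only sketch with arbitrary sizes (the real query/frame dimensions may differ):

# Sketch: shape check for MultiHeadAttention and EncoderLayer.
import torch
from tvr.models.transformer_block import MultiHeadAttention, EncoderLayer

bsz, len_q, len_k, d_model = 4, 12, 32, 512
attn = MultiHeadAttention(n_head=8, d_model=d_model, d_k=64, d_v=64, dropout=0.1)

q = torch.randn(bsz, len_q, d_model)   # e.g. a set of learnable query tokens (illustrative)
kv = torch.randn(bsz, len_k, d_model)  # e.g. per-frame video features (illustrative)

out = attn(q, kv, kv)                  # cross-attention; residual + LayerNorm applied inside
print(out.shape)                       # torch.Size([4, 12, 512])

layer = EncoderLayer(d_model=d_model, d_inner=2048, n_head=8, d_k=64, d_v=64)
enc = layer(q, kv, kv)                 # attention followed by the position-wise FFN
print(enc.shape)                       # torch.Size([4, 12, 512])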
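
CrossEn is the InfoNCE-style ingredient of the retrieval loss: given a similarity matrix whose diagonal holds the matched text-video pairs, it returns the mean cross-entropy of each row against its diagonal entry. Applying it to both the matrix and its transpose is a common symmetric formulation, assumed here rather than read from modeling.py, and the logit scale is a CLIP-style placeholder:

# Sketch: turning a similarity matrix into a bidirectional contrastive loss with CrossEn.
import torch
import torch.nn.functional as F
from tvr.models.until_module import CrossEn

loss_fct = CrossEn()

text_feat = F.normalize(torch.randn(8, 512), dim=-1)
video_feat = F.normalize(torch.randn(8, 512), dim=-1)
logit_scale = 100.0                                   # temperature placeholder

sim = logit_scale * text_feat @ video_feat.t()        # row i, column i is the matched pair
loss = (loss_fct(sim) + loss_fct(sim.t())) / 2        # text->video and video->text directions
print(loss.item())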
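
AllGather widens the negative set under multi-GPU training: every rank contributes its features before the similarity matrix is formed, and backward() routes only the local slice of the gradient back. The sketch below fakes the args namespace with just the two fields the function reads (world_size, local_rank); with world_size=1 the gather is a no-op, so it also runs without torch.distributed being initialized:

# Sketch: gathering features across ranks before the contrastive loss.
import argparse
import torch
import torch.nn.functional as F
from tvr.models.until_module import AllGather, CrossEn

allgather = AllGather.apply

def contrastive_step(text_feat, video_feat, args, logit_scale=100.0):
    # Gather the per-rank batches so the loss sees the global batch of negatives.
    text_feat = allgather(text_feat, args)
    video_feat = allgather(video_feat, args)
    sim = logit_scale * text_feat @ video_feat.t()
    loss_fct = CrossEn()
    return (loss_fct(sim) + loss_fct(sim.t())) / 2

args = argparse.Namespace(world_size=1, local_rank=0)   # placeholder args object
t = F.normalize(torch.randn(8, 512), dim=-1)
v = F.normalize(torch.randn(8, 512), dim=-1)
print(contrastive_step(t, v, args).item())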
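
Finally, comm.reduce_dict pairs naturally with MetricLogger for logging: each rank passes a dict of scalar tensors and rank 0 receives their average for display. The loss names below are placeholders; in a single-process run the dict comes back unchanged:

# Sketch: averaging per-rank scalar losses for logging.
import torch
from tvr.utils.comm import reduce_dict, is_main_process

loss_dict = {"loss_retrieval": torch.tensor(1.23),    # placeholder loss names and values
             "loss_diversity": torch.tensor(0.45)}

reduced = reduce_dict(loss_dict, average=True)        # only rank 0 holds the averaged values
if is_main_process():
    print({k: v.item() for k, v in reduced.items()})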