├── .gitignore
├── .vscode
│   └── launch.json
├── LICENSE
├── README.md
├── imgs
│   ├── exp1.jpg
│   ├── exp2.jpg
│   ├── framework.png
│   ├── intro.pdf
│   └── introduction.png
├── main_retrieval.py
├── preprocess
│   └── compress_video.py
├── requirements.txt
├── test_activitynet.sh
├── test_didemo.sh
├── test_lsmdc.sh
├── test_msrvtt.sh
├── test_msvd.sh
├── train_activitynet.sh
├── train_didemo.sh
├── train_lsmdc.sh
├── train_msrvtt.sh
├── train_msvd.sh
└── tvr
    ├── __init__.py
    ├── __pycache__
    │   └── __init__.cpython-39.pyc
    ├── dataloaders
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-39.pyc
    │   │   ├── data_dataloaders.cpython-39.pyc
    │   │   ├── dataloader_activitynet_retrieval.cpython-39.pyc
    │   │   ├── dataloader_didemo_retrieval.cpython-39.pyc
    │   │   ├── dataloader_lsmdc_retrieval.cpython-39.pyc
    │   │   ├── dataloader_msrvtt_retrieval.cpython-39.pyc
    │   │   ├── dataloader_msvd_retrieval.cpython-39.pyc
    │   │   ├── dataloader_retrieval.cpython-39.pyc
    │   │   ├── rand_augment.cpython-39.pyc
    │   │   ├── random_erasing.cpython-39.pyc
    │   │   ├── rawvideo_util.cpython-39.pyc
    │   │   └── video_transforms.cpython-39.pyc
    │   ├── data_dataloaders.py
    │   ├── dataloader_activitynet_retrieval.py
    │   ├── dataloader_didemo_retrieval.py
    │   ├── dataloader_lsmdc_retrieval.py
    │   ├── dataloader_msrvtt_retrieval.py
    │   ├── dataloader_msvd_retrieval.py
    │   ├── dataloader_retrieval.py
    │   ├── rand_augment.py
    │   ├── random_erasing.py
    │   ├── rawvideo_util.py
    │   └── video_transforms.py
    ├── models
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-39.pyc
    │   │   ├── modeling.cpython-39.pyc
    │   │   ├── module_clip.cpython-39.pyc
    │   │   ├── module_cross.cpython-39.pyc
    │   │   ├── module_transformer.cpython-39.pyc
    │   │   ├── optimization.cpython-39.pyc
    │   │   ├── query_cross_att.cpython-39.pyc
    │   │   ├── tokenization_clip.cpython-39.pyc
    │   │   ├── transformer.cpython-39.pyc
    │   │   ├── transformer_block.cpython-39.pyc
    │   │   └── until_module.cpython-39.pyc
    │   ├── bpe_simple_vocab_16e6.txt.gz
    │   ├── modeling.py
    │   ├── modeling_old.py
    │   ├── module.py
    │   ├── module_clip.py
    │   ├── module_cross.py
    │   ├── module_transformer.py
    │   ├── optimization.py
    │   ├── tokenization_clip.py
    │   ├── transformer_block.py
    │   └── until_module.py
    └── utils
        ├── __init__.py
        ├── __pycache__
        │   ├── __init__.cpython-39.pyc
        │   ├── comm.cpython-39.pyc
        │   ├── logger.cpython-39.pyc
        │   ├── metric_logger.cpython-39.pyc
        │   └── metrics.cpython-39.pyc
        ├── comm.py
        ├── logger.py
        ├── metric_logger.py
        └── metrics.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /data
2 | /ckpt
3 | /PromptSwitch
4 | /figs
5 | /tvr/models/ViT-B-*.pt
6 | a.py
7 | draw.ipynb
8 | query_vis.py
9 | query_vis.sh
10 |
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 |     // Use IntelliSense to learn about possible attributes.
3 |     // Hover to view descriptions of existing attributes.
4 |     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Python: 当前文件",
9 | "type": "python",
10 | "request": "launch",
11 | "program": "/home/zhanghaonan/anaconda3/envs/ic/lib/python3.9/site-packages/torch/distributed/launch.py",
12 | "console": "integratedTerminal",
13 | "justMyCode": true,
14 | "env": {
15 | "CUDA_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7,8,9"
16 | },
17 | "args": [
18 | "--master_port=2505",
19 | "--nproc_per_node=4",
20 | "main_retrieval.py",
21 | "--do_train=1",
22 | "--workers=4",
23 | "--n_display=50",
24 | "--epochs=5",
25 | "--lr=1e-4",
26 | "--coef_lr=1e-3",
27 | "--batch_size=128",
28 | "--batch_size_val=64",
29 | "--anno_path=/mnt/nfs/CMG/zhanghaonan/datasets/MSVD/anns",
30 | "--video_path=/mnt/nfs/CMG/zhanghaonan/datasets/MSVD/MSVD_Videos",
31 | "--datatype=msvd",
32 | "--max_words=32",
33 | "--max_frames=12",
34 | "--video_framerate=1",
35 | "--output_dir=ckpt/msvd/main_exp/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_3cross_add_query_sim_query_shared_cross_att_shared_without_weight",
36 | "--center=1",
37 | "--temp=3",
38 | "--alpha=0.0001",
39 | "--beta=0.005",
40 | ]
41 | }
42 | ]
43 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # _GLSCL_: Text-Video Retrieval with Global-Local Semantic Consistent Learning (TIP 2025)
3 |
4 | Haonan Zhang, Pengpeng Zeng, Lianli Gao, Jingkuan Song, Yihang Duan, Xinyu Lyu, Heng Tao Shen
5 |
6 |
7 |
8 | This is the official code implementation of the paper "Text-Video Retrieval with Global-Local Semantic Consistent Learning". The checkpoints and features will be released soon.
9 |
10 | We are continuously refactoring the code; please stay tuned for the latest updates!
11 |
12 | ## 🔥 Updates
13 |
14 | - [ ] Release the pre-trained weights and datasets.
15 | - [x] Release the training and evaluation code.
16 |
17 | ## ✨ Overview
18 | Adapting large-scale image-text pre-training models, _e.g._, CLIP, to the video domain represents the current state-of-the-art for text-video retrieval. The primary approaches involve transferring text-video pairs to a common embedding space and leveraging cross-modal interactions on specific entities for semantic alignment. Though effective, these paradigms entail prohibitive computational costs, leading to inefficient retrieval. To address this, we propose a simple yet effective method, Global-Local Semantic Consistent Learning (GLSCL), which capitalizes on latent shared semantics across modalities for text-video retrieval. Specifically, we introduce a parameter-free global interaction module to explore coarse-grained alignment. Then, we devise a shared local interaction module that employs several learnable queries to capture latent semantic concepts for learning fine-grained alignment. Moreover, we propose an inter-consistency loss and an intra-diversity loss to ensure the similarity and diversity of these concepts across and within modalities, respectively.
19 |
20 |
21 | 
22 | Figure 1. Performance comparison of the retrieval results (R@1) and computational costs (FLOPs) for text-to-video retrieval models.
23 |
24 |
25 |
26 | ## 🍀 Method
27 |
28 | 
29 | Figure 2. Overview of the proposed GLSCL for Text-Video retrieval.
30 |
31 |
32 | ## ⚙️ Usage
33 | ### Requirements
34 | The GLSCL framework depends on the following main requirements:
35 | - torch==1.8.1+cu114
36 | - Transformers 4.6.1
37 | - OpenCV 4.5.3
38 | - tqdm
39 |
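All Python dependencies are listed in `requirements.txt` and can be installed with `pip install -r requirements.txt`.
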
40 | ### Datasets
41 | We train our model on the ```MSR-VTT-9k```, ```MSVD```, ```DiDeMo```, ```LSMDC```, and ```ActivityNet``` datasets. Please refer to this [repo](https://github.com/jpthu17/DiCoSA) for data preparation.
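
For raw videos, the repository also provides `preprocess/compress_video.py` (adapted from [CLIP4Clip](https://github.com/ArrowLuo/CLIP4Clip)), which rescales the shorter side of each video to 224 px and resamples it to 3 fps, e.g. `python preprocess/compress_video.py --input_root YOUR_RAW_VIDEO_PATH --output_root YOUR_COMPRESSED_VIDEO_PATH`.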
42 |
43 | ### How to Run (taking *MSR-VTT* as an example)
44 |
45 | For simple training on MSR-VTT-9k with default hyperparameters:
46 | ```
47 | bash train_msrvtt.sh
48 | ```
49 | or run in the terminal directly:
50 | ```
51 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
52 | python -m torch.distributed.launch \
53 | --master_port 2513 \
54 | --nproc_per_node=8 \
55 | main_retrieval.py \
56 | --do_train 1 \
57 | --workers 8 \
58 | --n_display 50 \
59 | --epochs 5 \
60 | --lr 1e-4 \
61 | --coef_lr 1e-3 \
62 | --batch_size 128 \
63 | --batch_size_val 64 \
64 | --anno_path ANNOTATION_PATH \
65 | --video_path YOUR_RAW_VIDEO_PATH \
66 | --datatype msrvtt \
67 | --max_words 32 \
68 | --max_frames 12 \
69 | --video_framerate 1 \
70 | --output_dir YOUR_SAVE_PATH \
71 | --center 1 \
72 | --temp 3 \
73 | --alpha 0.0001 \
74 | --beta 0.005 \
75 | --query_number 8 \
76 | --base_encoder ViT-B/32 \
77 | --cross_att_layer 3 \
78 | --query_share 1 \
79 | --cross_att_share 1 \
80 | --loss2_weight 0.5 \
81 | ```
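
Note that `--batch_size` is the global batch size: the dataloaders in `tvr/dataloaders/data_dataloaders.py` divide it by the distributed world size, so `--batch_size 128` with `--nproc_per_node=8` on a single node corresponds to 16 samples per GPU.
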
82 | ### How to Evaluate (taking *MSR-VTT* as an example)
83 |
84 | For simple testing on MSR-VTT-9k with default hyperparameters:
85 | ```
86 | bash test_msrvtt.sh
87 | ```
88 | or
89 | ```
90 | CUDA_VISIBLE_DEVICES=0,1 \
91 | python -m torch.distributed.launch \
92 | --master_port 2503 \
93 | --nproc_per_node=2 \
94 | main_retrieval.py \
95 | --do_eval 1 \
96 | --workers 8 \
97 | --n_display 50 \
98 | --epochs 5 \
99 | --lr 1e-4 \
100 | --coef_lr 1e-3 \
101 | --batch_size 64 \
102 | --anno_path ANNOTATION_PATH \
103 | --video_path YOUR_RAW_VIDEO_PATH \
104 | --datatype msrvtt \
105 | --max_words 32 \
106 | --max_frames 12 \
107 | --video_framerate 1 \
108 | --output_dir YOUR_SAVE_PATH \
109 | --center 1 \
110 | --temp 3 \
111 | --alpha 0.0001 \
112 | --beta 0.005 \
113 | --query_number 8 \
114 | --base_encoder ViT-B/32 \
115 | --cross_att_layer 3 \
116 | --query_share 1 \
117 | --cross_att_share 1 \
118 | --loss2_weight 0.5 \
119 | --init_model YOUR_CKPT_FILE
120 | ```
121 |
122 | ## 🧪 Experiments
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 | ## 📚 Citation
132 |
133 | ```bibtex
134 | @inproceedings{GLSCL,
135 | author = {Haonan Zhang and Pengpeng Zeng and Lianli Gao and Jingkuan Song and Yihang Duan and Xinyu Lyu and Heng Tao Shen},
136 | title = {Text-Video Retrieval with Global-Local Semantic Consistent Learning},
137 | year = {2024}
138 | }
139 | ```
140 |
--------------------------------------------------------------------------------
/imgs/exp1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/imgs/exp1.jpg
--------------------------------------------------------------------------------
/imgs/exp2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/imgs/exp2.jpg
--------------------------------------------------------------------------------
/imgs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/imgs/framework.png
--------------------------------------------------------------------------------
/imgs/intro.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/imgs/intro.pdf
--------------------------------------------------------------------------------
/imgs/introduction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/imgs/introduction.png
--------------------------------------------------------------------------------
/preprocess/compress_video.py:
--------------------------------------------------------------------------------
1 | """
2 | Used to compress video in: https://github.com/ArrowLuo/CLIP4Clip
3 | Author: ArrowLuo
4 | """
5 | import os
6 | import argparse
7 | import ffmpeg
8 | import subprocess
9 | import time
10 | import multiprocessing
11 | from multiprocessing import Pool
12 | import shutil
13 | try:
14 | from psutil import cpu_count
15 | except ImportError:
16 | from multiprocessing import cpu_count
17 | # multiprocessing.freeze_support()
18 |
19 | def compress(paras):
20 | input_video_path, output_video_path = paras
21 | try:
22 | command = ['ffmpeg',
23 | '-y', # (optional) overwrite output file if it exists
24 | '-i', input_video_path,
25 | '-filter:v',
26 | 'scale=\'if(gt(a,1),trunc(oh*a/2)*2,224)\':\'if(gt(a,1),224,trunc(ow*a/2)*2)\'', # scale to 224
27 | '-map', '0:v',
28 | '-r', '3', # frames per second
29 | output_video_path,
30 | ]
31 |         proc = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
32 |         out, err = proc.communicate()
33 |         retcode = proc.poll()
34 | # print something above for debug
35 | except Exception as e:
36 | raise e
37 |
38 | def prepare_input_output_pairs(input_root, output_root):
39 | input_video_path_list = []
40 | output_video_path_list = []
41 | for root, dirs, files in os.walk(input_root):
42 | for file_name in files:
43 | input_video_path = os.path.join(root, file_name)
44 | output_video_path = os.path.join(output_root, file_name)
45 | if os.path.exists(output_video_path) and os.path.getsize(output_video_path) > 0:
46 | pass
47 | else:
48 | input_video_path_list.append(input_video_path)
49 | output_video_path_list.append(output_video_path)
50 | return input_video_path_list, output_video_path_list
51 |
52 | if __name__ == "__main__":
53 | parser = argparse.ArgumentParser(description='Compress video for speed-up')
54 | parser.add_argument('--input_root', type=str, help='input root')
55 | parser.add_argument('--output_root', type=str, help='output root')
56 | args = parser.parse_args()
57 |
58 | input_root = args.input_root
59 | output_root = args.output_root
60 |
61 | assert input_root != output_root
62 |
63 | if not os.path.exists(output_root):
64 | os.makedirs(output_root, exist_ok=True)
65 |
66 | input_video_path_list, output_video_path_list = prepare_input_output_pairs(input_root, output_root)
67 |
68 | print("Total video need to process: {}".format(len(input_video_path_list)))
69 | num_works = cpu_count()
70 | print("Begin with {}-core logical processor.".format(num_works))
71 |
72 | pool = Pool(num_works)
73 | data_dict_list = pool.map(compress,
74 | [(input_video_path, output_video_path) for
75 | input_video_path, output_video_path in
76 | zip(input_video_path_list, output_video_path_list)])
77 | pool.close()
78 | pool.join()
79 |
80 | print("Compress finished, wait for checking files...")
81 | for input_video_path, output_video_path in zip(input_video_path_list, output_video_path_list):
82 | if os.path.exists(input_video_path):
83 | if os.path.exists(output_video_path) is False or os.path.getsize(output_video_path) < 1.:
84 | shutil.copyfile(input_video_path, output_video_path)
85 | print("Copy and replace file: {}".format(output_video_path))
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | decord
2 | pandas
3 | ftfy
4 | regex
5 | tqdm
6 | opencv-python
7 | functional
8 | timm
9 | # torch==1.8.1+cu102
10 | # torchvision==0.9.1+cu102
11 |
--------------------------------------------------------------------------------
/test_activitynet.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,8,9 \
2 | python -m torch.distributed.launch \
3 | --master_port 2503 \
4 | --nproc_per_node=4 \
5 | main_retrieval.py \
6 | --do_eval 1 \
7 | --workers 8 \
8 | --n_display 50 \
9 | --epochs 5 \
10 | --lr 1e-4 \
11 | --coef_lr 1e-3 \
12 | --batch_size 128 \
13 | --batch_size_val 128 \
14 | --anno_path /mnt/nfs/CMG/zhanghaonan/datasets/activitynet/anns \
15 | --video_path /mnt/nfs/CMG/zhanghaonan/datasets/activitynet/Videos/Activity_Videos \
16 | --datatype activity \
17 | --max_words 64 \
18 | --max_frames 64 \
19 | --video_framerate 1 \
20 | --output_dir ckpt/activitynet/main_exp/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_3cross_add_query_sim_query_shared_cross_att_shared_without_weight \
21 | --center 1 \
22 | --query_number 8 \
23 | --cross_att_layer 3 \
24 | --query_share 1 \
25 | --cross_att_share 1 \
26 | --add_query_score_for_eval 0 \
27 | --base_encoder ViT-B/32 \
28 | --temp 3 \
29 | --alpha 0.0001 \
30 | --beta 0.005 \
31 | --t2v_beta 50 \
32 | --v2t_beta 50 \
33 | --init_model ckpt/activitynet/main_exp/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_3cross_add_query_sim_query_shared_cross_att_shared_without_weight/pytorch_model.bin.step900.5
--------------------------------------------------------------------------------
/test_didemo.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=8,9 \
2 | python -m torch.distributed.launch \
3 | --master_port 2503 \
4 | --nproc_per_node=2 \
5 | main_retrieval.py \
6 | --do_eval 1 \
7 | --workers 8 \
8 | --n_display 50 \
9 | --epochs 5 \
10 | --lr 1e-4 \
11 | --coef_lr 1e-3 \
12 | --batch_size 32 \
13 | --batch_size_val 32 \
14 | --anno_path /mnt/nfs/CMG/zhanghaonan/datasets/Didemo/anns \
15 | --video_path /mnt/nfs/CMG/zhanghaonan/datasets/Didemo/videos \
16 | --datatype didemo \
17 | --max_words 64 \
18 | --max_frames 64 \
19 | --video_framerate 1 \
20 | --output_dir ckpt/didemo/main_exp/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_3cross_add_query_sim_query_shared_cross_att_shared_without_weight \
21 | --center 1 \
22 | --query_number 8 \
23 | --cross_att_layer 3 \
24 | --query_share 1 \
25 | --cross_att_share 1 \
26 | --add_query_score_for_eval 0 \
27 | --base_encoder ViT-B/32 \
28 | --temp 3 \
29 | --alpha 0.0001 \
30 | --beta 0.005 \
31 | --init_model ckpt/didemo/main_exp/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_3cross_add_query_sim_query_shared_cross_att_shared_without_weight/pytorch_model.bin.step1200.4
--------------------------------------------------------------------------------
/test_lsmdc.sh:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/test_lsmdc.sh
--------------------------------------------------------------------------------
/test_msrvtt.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=8,9 \
2 | python -m torch.distributed.launch \
3 | --master_port 2503 \
4 | --nproc_per_node=2 \
5 | main_retrieval.py \
6 | --do_eval 1 \
7 | --workers 8 \
8 | --n_display 50 \
9 | --epochs 5 \
10 | --lr 1e-4 \
11 | --coef_lr 1e-3 \
12 | --batch_size 64 \
13 | --batch_size_val 64 \
14 | --anno_path /mnt/nfs/CMG/zhanghaonan/datasets/MSR-VTT/anns \
15 | --video_path /mnt/nfs/CMG/zhanghaonan/datasets/MSR-VTT/MSRVTT_Videos \
16 | --datatype msrvtt \
17 | --max_words 32 \
18 | --max_frames 12 \
19 | --video_framerate 1 \
20 | --output_dir ckpt/msrvtt/ablation/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_3cross_add_query_sim_query_shared_cross_att_shared_without_weight \
21 | --center 1 \
22 | --query_number 8 \
23 | --cross_att_layer 3 \
24 | --query_share 0 \
25 | --cross_att_share 1 \
26 | --add_query_score_for_eval 1 \
27 | --base_encoder ViT-B/32 \
28 | --temp 3 \
29 | --alpha 0.0001 \
30 | --beta 0.005 \
31 | --init_model ckpt/msrvtt/ablation/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_3cross_add_query_sim_query_shared_cross_att_shared_without_weight/pytorch_model.bin.step2850.4 \
32 | --loss2_weight 0.5 \
33 |
--------------------------------------------------------------------------------
/test_msvd.sh:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/test_msvd.sh
--------------------------------------------------------------------------------
/train_activitynet.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
2 | python -m torch.distributed.launch \
3 | --master_port 2502 \
4 | --nproc_per_node=8 \
5 | main_retrieval.py \
6 | --do_train 1 \
7 | --workers 8 \
8 | --n_display 10 \
9 | --epochs 10 \
10 | --lr 1e-4 \
11 | --coef_lr 1e-3 \
12 | --batch_size 64 \
13 | --batch_size_val 64 \
14 | --anno_path /mnt/nfs/CMG/zhanghaonan/datasets/activitynet/anns \
15 | --video_path /mnt/nfs/CMG/zhanghaonan/datasets/activitynet/Videos/Activity_Videos \
16 | --datatype activity \
17 | --max_words 64 \
18 | --max_frames 64 \
19 | --video_framerate 1 \
20 | --output_dir ckpt/activitynet/main_exp/v2 \
21 | --center 1 \
22 | --temp 3 \
23 | --alpha 0.0001 \
24 | --beta 0.005 \
25 | --t2v_beta 50 \
26 | --v2t_beta 50 \
27 | --query_number 12 \
28 | --base_encoder ViT-B/32 \
29 | --cross_att_layer 3 \
30 | --query_share 1 \
31 | --cross_att_share 1 \
32 | --loss2_weight 0.5 \
--------------------------------------------------------------------------------
/train_didemo.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
2 | python -m torch.distributed.launch \
3 | --master_port 2502 \
4 | --nproc_per_node=4 \
5 | main_retrieval.py \
6 | --do_train 1 \
7 | --workers 8 \
8 | --n_display 50 \
9 | --epochs 5 \
10 | --lr 1e-4 \
11 | --coef_lr 1e-3 \
12 | --batch_size 32 \
13 | --batch_size_val 32 \
14 | --anno_path /mnt/nfs/CMG/zhanghaonan/datasets/Didemo/anns \
15 | --video_path /mnt/nfs/CMG/zhanghaonan/datasets/Didemo/videos \
16 | --datatype didemo \
17 | --max_words 64 \
18 | --max_frames 64 \
19 | --video_framerate 1 \
20 | --output_dir ckpt/didemo/main_exp/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_3cross_add_query_sim_query_shared_cross_att_shared_without_weight \
21 | --center 1 \
22 | --temp 3 \
23 | --alpha 0.0001 \
24 | --beta 0.005 \
25 | --query_number 8 \
26 | --base_encoder ViT-B/32 \
27 | --cross_att_layer 3 \
28 | --add_query_score_for_eval 0 \
29 | --query_share 1 \
30 | --cross_att_share 1 \
--------------------------------------------------------------------------------
/train_lsmdc.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
2 | python -m torch.distributed.launch \
3 | --master_port 2502 \
4 | --nproc_per_node=4 \
5 | main_retrieval.py \
6 | --do_train 1 \
7 | --workers 8 \
8 | --n_display 5 \
9 | --epochs 5 \
10 | --lr 1e-4 \
11 | --coef_lr 1e-3 \
12 | --batch_size 128 \
13 | --batch_size_val 64 \
14 | --anno_path /mnt/nfs/CMG/zhanghaonan/datasets/LSMDC/anns \
15 | --video_path /mnt/nfs/CMG/zhanghaonan/datasets/LSMDC \
16 | --datatype lsmdc \
17 | --max_words 32 \
18 | --max_frames 12 \
19 | --video_framerate 1 \
20 | --output_dir ckpt/lsmdc/main_exp/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_4cross_add_query_sim_query_shared_cross_att_shared_without_weight \
21 | --center 1 \
22 | --query_number 8 \
23 | --base_encoder ViT-B/32 \
24 | --cross_att_layer 3 \
25 | --add_query_score_for_eval 0 \
26 | --query_share 1 \
27 | --cross_att_share 1 \
28 | --loss2_weight 0.5 \
--------------------------------------------------------------------------------
/train_msrvtt.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=4,5,6,7,8,9 \
2 | python -m torch.distributed.launch \
3 | --master_port 2513 \
4 | --nproc_per_node=4 \
5 | main_retrieval.py \
6 | --do_train 1 \
7 | --workers 8 \
8 | --n_display 50 \
9 | --epochs 5 \
10 | --lr 1e-4 \
11 | --coef_lr 1e-3 \
12 | --batch_size 128 \
13 | --batch_size_val 64 \
14 | --anno_path /mnt/nfs/CMG/zhanghaonan/datasets/MSR-VTT/anns \
15 | --video_path /mnt/nfs/CMG/zhanghaonan/datasets/MSR-VTT/MSRVTT_Videos \
16 | --datatype msrvtt \
17 | --max_words 32 \
18 | --max_frames 12 \
19 | --video_framerate 1 \
20 | --output_dir ckpt/msrvtt/no_qbnorm/alpha0.0001_beta_0.02 \
21 | --center 1 \
22 | --temp 3 \
23 | --alpha 0.0001 \
24 | --beta 0.02 \
25 | --query_number 8 \
26 | --base_encoder ViT-B/32 \
27 | --cross_att_layer 3 \
28 | --query_share 1 \
29 | --cross_att_share 1 \
30 | --loss2_weight 0.5 \
--------------------------------------------------------------------------------
/train_msvd.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,8,9 \
2 | python -m torch.distributed.launch \
3 | --master_port 2512 \
4 | --nproc_per_node=4 \
5 | main_retrieval.py \
6 | --do_train 1 \
7 | --workers 8 \
8 | --n_display 50 \
9 | --epochs 5 \
10 | --lr 1e-4 \
11 | --coef_lr 1e-3 \
12 | --batch_size 128 \
13 | --batch_size_val 64 \
14 | --anno_path /mnt/nfs/CMG/zhanghaonan/datasets/MSVD/anns \
15 | --video_path /mnt/nfs/CMG/zhanghaonan/datasets/MSVD/MSVD_Videos \
16 | --datatype msvd \
17 | --max_words 32 \
18 | --max_frames 12 \
19 | --video_framerate 1 \
20 | --output_dir ckpt/msvd/main_exp/8query_intra_consistency_MSE_0.0001_inter_diversity_0.1margin_both_3cross_add_query_sim_query_shared_cross_att_shared_without_weight \
21 | --center 1 \
22 | --temp 3 \
23 | --alpha 0.0001 \
24 | --beta 0.005 \
25 | --query_number 8 \
26 | --base_encoder ViT-B/32 \
27 | --cross_att_layer 3 \
28 | --add_query_score_for_eval 0 \
29 | --query_share 1 \
30 | --cross_att_share 1 \
31 | --loss2_weight 0.5 \
--------------------------------------------------------------------------------
/tvr/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/__init__.py
--------------------------------------------------------------------------------
/tvr/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/dataloaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__init__.py
--------------------------------------------------------------------------------
/tvr/dataloaders/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/dataloaders/__pycache__/data_dataloaders.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/data_dataloaders.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/dataloaders/__pycache__/dataloader_activitynet_retrieval.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/dataloader_activitynet_retrieval.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/dataloaders/__pycache__/dataloader_didemo_retrieval.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/dataloader_didemo_retrieval.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/dataloaders/__pycache__/dataloader_lsmdc_retrieval.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/dataloader_lsmdc_retrieval.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/dataloaders/__pycache__/dataloader_msrvtt_retrieval.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/dataloader_msrvtt_retrieval.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/dataloaders/__pycache__/dataloader_msvd_retrieval.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/dataloader_msvd_retrieval.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/dataloaders/__pycache__/dataloader_retrieval.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/dataloader_retrieval.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/dataloaders/__pycache__/rand_augment.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/rand_augment.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/dataloaders/__pycache__/random_erasing.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/random_erasing.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/dataloaders/__pycache__/rawvideo_util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/rawvideo_util.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/dataloaders/__pycache__/video_transforms.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/dataloaders/__pycache__/video_transforms.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/dataloaders/data_dataloaders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from .dataloader_msrvtt_retrieval import MSRVTTDataset
4 | from .dataloader_activitynet_retrieval import ActivityNetDataset
5 | from .dataloader_didemo_retrieval import DiDeMoDataset
6 | from .dataloader_lsmdc_retrieval import LsmdcDataset
7 | from .dataloader_msvd_retrieval import MsvdDataset
8 |
9 |
10 | def dataloader_msrvtt_train(args, tokenizer):
11 | msrvtt_dataset = MSRVTTDataset(
12 | subset='train',
13 | anno_path=args.anno_path,
14 | video_path=args.video_path,
15 | max_words=args.max_words,
16 | tokenizer=tokenizer,
17 | max_frames=args.max_frames,
18 | video_framerate=args.video_framerate,
19 | config=args
20 | )
21 | try:
22 | train_sampler = torch.utils.data.distributed.DistributedSampler(msrvtt_dataset)
23 | except:
24 | train_sampler = None # cpu
25 | dataloader = DataLoader(
26 | msrvtt_dataset,
27 | batch_size=args.batch_size // args.world_size,
28 | num_workers=args.workers,
29 | pin_memory=False,
30 | shuffle=(train_sampler is None),
31 | sampler=train_sampler,
32 | drop_last=True,
33 | )
34 |
35 | return dataloader, len(msrvtt_dataset), train_sampler
36 |
37 |
38 | def dataloader_msrvtt_test(args, tokenizer, subset="test"):
39 | msrvtt_testset = MSRVTTDataset(
40 | subset=subset,
41 | anno_path=args.anno_path,
42 | video_path=args.video_path,
43 | max_words=args.max_words,
44 | tokenizer=tokenizer,
45 | max_frames=args.max_frames,
46 | video_framerate=args.video_framerate,
47 | config=args
48 | )
49 |
50 | try:
51 | test_sampler = torch.utils.data.distributed.DistributedSampler(msrvtt_testset)
52 | except:
53 | test_sampler = None # cpu
54 | dataloader_msrvtt = DataLoader(
55 | msrvtt_testset,
56 | batch_size=args.batch_size_val // args.world_size,
57 | num_workers=args.workers,
58 | shuffle=False,
59 | sampler=test_sampler,
60 | drop_last=False,
61 | )
62 | return dataloader_msrvtt, len(msrvtt_testset)
63 |
64 |
65 | def dataloader_msrvtt_train_test(args, tokenizer):
66 | msrvtt_dataset = MSRVTTDataset(
67 | subset='train_test',
68 | anno_path=args.anno_path,
69 | video_path=args.video_path,
70 | max_words=args.max_words,
71 | tokenizer=tokenizer,
72 | max_frames=args.max_frames,
73 | video_framerate=args.video_framerate,
74 | config=args
75 | )
76 | try:
77 | train_sampler = torch.utils.data.distributed.DistributedSampler(msrvtt_dataset)
78 | except:
79 | train_sampler = None # cpu
80 | dataloader = DataLoader(
81 | msrvtt_dataset,
82 | batch_size=args.batch_size // args.world_size,
83 | num_workers=args.workers,
84 | pin_memory=False,
85 | shuffle=(train_sampler is None),
86 | sampler=train_sampler,
87 | drop_last=True,
88 | )
89 |
90 | return dataloader, len(msrvtt_dataset), train_sampler
91 |
92 |
93 | def dataloader_lsmdc_train(args, tokenizer):
94 | lsmdc_dataset = LsmdcDataset(
95 | subset='train',
96 | anno_path=args.anno_path,
97 | video_path=args.video_path,
98 | max_words=args.max_words,
99 | tokenizer=tokenizer,
100 | max_frames=args.max_frames,
101 | video_framerate=args.video_framerate,
102 | config=args
103 | )
104 |
105 | train_sampler = torch.utils.data.distributed.DistributedSampler(lsmdc_dataset)
106 | dataloader = DataLoader(
107 | lsmdc_dataset,
108 | batch_size=args.batch_size // args.world_size,
109 | num_workers=args.workers,
110 | pin_memory=False,
111 | shuffle=(train_sampler is None),
112 | sampler=train_sampler,
113 | drop_last=True,
114 | )
115 |
116 | return dataloader, len(lsmdc_dataset), train_sampler
117 |
118 |
119 | def dataloader_lsmdc_train_test(args, tokenizer):
120 | lsmdc_dataset = LsmdcDataset(
121 | subset='train_test',
122 | anno_path=args.anno_path,
123 | video_path=args.video_path,
124 | max_words=args.max_words,
125 | tokenizer=tokenizer,
126 | max_frames=args.max_frames,
127 | video_framerate=args.video_framerate,
128 | config=args
129 | )
130 |
131 | train_sampler = torch.utils.data.distributed.DistributedSampler(lsmdc_dataset)
132 | dataloader = DataLoader(
133 | lsmdc_dataset,
134 | batch_size=args.batch_size // args.world_size,
135 | num_workers=args.workers,
136 | pin_memory=False,
137 | shuffle=(train_sampler is None),
138 | sampler=train_sampler,
139 | drop_last=True,
140 | )
141 |
142 | return dataloader, len(lsmdc_dataset), train_sampler
143 |
144 |
145 | def dataloader_lsmdc_test(args, tokenizer, subset="test"):
146 | lsmdc_testset = LsmdcDataset(
147 | subset=subset,
148 | anno_path=args.anno_path,
149 | video_path=args.video_path,
150 | max_words=args.max_words,
151 | tokenizer=tokenizer,
152 | max_frames=args.max_frames,
153 | video_framerate=args.video_framerate,
154 | config=args
155 | )
156 | try:
157 | test_sampler = torch.utils.data.distributed.DistributedSampler(lsmdc_testset)
158 | except:
159 | test_sampler = None # cpu
160 | dataloader_lsmdc = DataLoader(
161 | lsmdc_testset,
162 | batch_size=args.batch_size_val // args.world_size,
163 | num_workers=args.workers,
164 | shuffle=False,
165 | sampler=test_sampler,
166 | drop_last=False,
167 | )
168 | return dataloader_lsmdc, len(lsmdc_testset)
169 |
170 |
171 | def dataloader_activity_train(args, tokenizer):
172 | activity_dataset = ActivityNetDataset(
173 | subset="train",
174 | data_path=args.anno_path,
175 | features_path=args.video_path,
176 | max_words=args.max_words,
177 | feature_framerate=args.video_framerate,
178 | tokenizer=tokenizer,
179 | max_frames=args.max_frames
180 | )
181 |
182 | train_sampler = torch.utils.data.distributed.DistributedSampler(activity_dataset)
183 | dataloader = DataLoader(
184 | activity_dataset,
185 | batch_size=args.batch_size // args.world_size,
186 | num_workers=args.workers,
187 | pin_memory=False,
188 | shuffle=(train_sampler is None),
189 | sampler=train_sampler,
190 | drop_last=True,
191 | )
192 |
193 | return dataloader, len(activity_dataset), train_sampler
194 |
195 |
196 | def dataloader_activity_train_test(args, tokenizer):
197 | activity_dataset = ActivityNetDataset(
198 | subset="train_test",
199 | data_path=args.anno_path,
200 | features_path=args.video_path,
201 | max_words=args.max_words,
202 | feature_framerate=args.video_framerate,
203 | tokenizer=tokenizer,
204 | max_frames=args.max_frames
205 | )
206 |
207 | train_sampler = torch.utils.data.distributed.DistributedSampler(activity_dataset)
208 | dataloader = DataLoader(
209 | activity_dataset,
210 | batch_size=args.batch_size // args.world_size,
211 | num_workers=args.workers,
212 | pin_memory=False,
213 | shuffle=(train_sampler is None),
214 | sampler=train_sampler,
215 | drop_last=True,
216 | )
217 |
218 | return dataloader, len(activity_dataset), train_sampler
219 |
220 |
221 | def dataloader_activity_test(args, tokenizer, subset="test"):
222 | activity_testset = ActivityNetDataset(
223 | subset=subset,
224 | data_path=args.anno_path,
225 | features_path=args.video_path,
226 | max_words=args.max_words,
227 | feature_framerate=args.video_framerate,
228 | tokenizer=tokenizer,
229 | max_frames=args.max_frames
230 | )
231 | try:
232 | test_sampler = torch.utils.data.distributed.DistributedSampler(activity_testset)
233 | except:
234 | test_sampler = None # cpu
235 | dataloader_activity = DataLoader(
236 | activity_testset,
237 | batch_size=args.batch_size_val // args.world_size,
238 | num_workers=args.workers,
239 | shuffle=False,
240 | sampler=test_sampler,
241 | drop_last=False,
242 | )
243 | return dataloader_activity, len(activity_testset)
244 |
245 |
246 | def dataloader_msvd_train(args, tokenizer):
247 | msvd_dataset = MsvdDataset(
248 | subset="train",
249 | anno_path=args.anno_path,
250 | video_path=args.video_path,
251 | max_words=args.max_words,
252 | tokenizer=tokenizer,
253 | max_frames=args.max_frames,
254 | video_framerate=args.video_framerate,
255 | config=args
256 | )
257 |
258 | train_sampler = torch.utils.data.distributed.DistributedSampler(msvd_dataset)
259 | dataloader = DataLoader(
260 | msvd_dataset,
261 | batch_size=args.batch_size // args.world_size,
262 | num_workers=args.workers,
263 | pin_memory=False,
264 | shuffle=(train_sampler is None),
265 | sampler=train_sampler,
266 | drop_last=True,
267 | )
268 |
269 | return dataloader, len(msvd_dataset), train_sampler
270 |
271 |
272 | def dataloader_msvd_train_test(args, tokenizer):
273 | msvd_dataset = MsvdDataset(
274 | subset="train_test",
275 | anno_path=args.anno_path,
276 | video_path=args.video_path,
277 | max_words=args.max_words,
278 | tokenizer=tokenizer,
279 | max_frames=args.max_frames,
280 | video_framerate=args.video_framerate,
281 | config=args
282 | )
283 |
284 | train_sampler = torch.utils.data.distributed.DistributedSampler(msvd_dataset)
285 | dataloader = DataLoader(
286 | msvd_dataset,
287 | batch_size=args.batch_size // args.world_size,
288 | num_workers=args.workers,
289 | pin_memory=False,
290 | shuffle=(train_sampler is None),
291 | sampler=train_sampler,
292 | drop_last=True,
293 | )
294 |
295 | return dataloader, len(msvd_dataset), train_sampler
296 |
297 |
298 | def dataloader_msvd_test(args, tokenizer, subset="test"):
299 | msvd_testset = MsvdDataset(
300 | subset=subset,
301 | anno_path=args.anno_path,
302 | video_path=args.video_path,
303 | max_words=args.max_words,
304 | tokenizer=tokenizer,
305 | max_frames=args.max_frames,
306 | video_framerate=args.video_framerate,
307 | config=args
308 | )
309 | try:
310 | test_sampler = torch.utils.data.distributed.DistributedSampler(msvd_testset)
311 | except:
312 | test_sampler = None # cpu
313 | dataloader_msvd = DataLoader(
314 | msvd_testset,
315 | batch_size=args.batch_size_val // args.world_size,
316 | num_workers=args.workers,
317 | shuffle=False,
318 | sampler=test_sampler,
319 | drop_last=False,
320 | )
321 | return dataloader_msvd, len(msvd_testset)
322 |
323 |
324 | def dataloader_didemo_train(args, tokenizer):
325 | didemo_dataset = DiDeMoDataset(
326 | subset="train",
327 | data_path=args.anno_path,
328 | features_path=args.video_path,
329 | max_words=args.max_words,
330 | feature_framerate=args.video_framerate,
331 | tokenizer=tokenizer,
332 | max_frames=args.max_frames
333 | )
334 |
335 | train_sampler = torch.utils.data.distributed.DistributedSampler(didemo_dataset)
336 | dataloader = DataLoader(
337 | didemo_dataset,
338 | batch_size=args.batch_size // args.world_size,
339 | num_workers=args.workers,
340 | pin_memory=False,
341 | shuffle=(train_sampler is None),
342 | sampler=train_sampler,
343 | drop_last=True,
344 | )
345 |
346 | return dataloader, len(didemo_dataset), train_sampler
347 |
348 |
349 | def dataloader_didemo_train_test(args, tokenizer):
350 | didemo_dataset = DiDeMoDataset(
351 | subset="train_test",
352 | data_path=args.anno_path,
353 | features_path=args.video_path,
354 | max_words=args.max_words,
355 | feature_framerate=args.video_framerate,
356 | tokenizer=tokenizer,
357 | max_frames=args.max_frames
358 | )
359 |
360 | train_sampler = torch.utils.data.distributed.DistributedSampler(didemo_dataset)
361 | dataloader = DataLoader(
362 | didemo_dataset,
363 | batch_size=args.batch_size // args.world_size,
364 | num_workers=args.workers,
365 | pin_memory=False,
366 | shuffle=(train_sampler is None),
367 | sampler=train_sampler,
368 | drop_last=True,
369 | )
370 |
371 | return dataloader, len(didemo_dataset), train_sampler
372 |
373 |
374 | def dataloader_didemo_test(args, tokenizer, subset="test"):
375 | didemo_testset = DiDeMoDataset(
376 | subset=subset,
377 | data_path=args.anno_path,
378 | features_path=args.video_path,
379 | max_words=args.max_words,
380 | feature_framerate=args.video_framerate,
381 | tokenizer=tokenizer,
382 | max_frames=args.max_frames
383 | )
384 | try:
385 | test_sampler = torch.utils.data.distributed.DistributedSampler(didemo_testset)
386 | except:
387 | test_sampler = None # cpu
388 | dataloader_didemo = DataLoader(
389 | didemo_testset,
390 | batch_size=args.batch_size_val // args.world_size,
391 | num_workers=args.workers,
392 | shuffle=False,
393 | sampler=test_sampler,
394 | drop_last=False,
395 | )
396 | return dataloader_didemo, len(didemo_testset)
397 |
398 |
399 | DATALOADER_DICT = {}
400 | DATALOADER_DICT["msrvtt"] = {"train": dataloader_msrvtt_train,
401 | "val": dataloader_msrvtt_test,
402 | "test": None,
403 | "train_test": dataloader_msrvtt_train_test}
404 | DATALOADER_DICT["msvd"] = {"train": dataloader_msvd_train,
405 | "val": dataloader_msvd_test,
406 | "test": dataloader_msvd_test,
407 | "train_test": dataloader_msvd_train_test}
408 | DATALOADER_DICT["lsmdc"] = {"train": dataloader_lsmdc_train,
409 | "val": dataloader_lsmdc_test,
410 | "test": dataloader_lsmdc_test,
411 | "train_test": dataloader_lsmdc_train_test}
412 | DATALOADER_DICT["activity"] = {"train":dataloader_activity_train,
413 | "val":dataloader_activity_test,
414 | "test":None,
415 | "train_test": dataloader_activity_train_test}
416 | DATALOADER_DICT["didemo"] = {"train":dataloader_didemo_train,
417 | "val":None,
418 | "test":dataloader_didemo_test,
419 | "train_test":dataloader_didemo_train_test}
--------------------------------------------------------------------------------
/tvr/dataloaders/dataloader_activitynet_retrieval.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import unicode_literals
4 | from __future__ import print_function
5 |
6 | import os
7 | from torch.utils.data import Dataset
8 | import numpy as np
9 | import json
10 | import math
11 | from .rawvideo_util import RawVideoExtractor
12 |
13 |
14 | class ActivityNetDataset(Dataset):
15 | def __init__(
16 | self,
17 | subset,
18 | data_path,
19 | features_path,
20 | tokenizer,
21 | max_words=30,
22 | feature_framerate=1.0,
23 | max_frames=100,
24 | image_resolution=224,
25 | frame_order=0,
26 | slice_framepos=2,
27 | ):
28 | self.data_path = data_path
29 | self.features_path = features_path
30 | self.feature_framerate = feature_framerate
31 | self.max_words = max_words
32 | self.max_frames = max_frames
33 | self.tokenizer = tokenizer
34 | # 0: ordinary order; 1: reverse order; 2: random order.
35 | self.frame_order = frame_order
36 | assert self.frame_order in [0, 1, 2]
37 | # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.
38 | self.slice_framepos = slice_framepos
39 | assert self.slice_framepos in [0, 1, 2]
40 |
41 | self.subset = subset
42 | assert self.subset in ["train", "val", "train_test"]
43 |
44 | video_id_path_dict = {}
45 | video_id_path_dict["train"] = os.path.join(self.data_path, "train_ids.json")
46 | video_id_path_dict["train_test"] = os.path.join(self.data_path, "train_ids.json")
47 | video_id_path_dict["val"] = os.path.join(self.data_path, "val_ids.json")
48 |
49 | video_json_path_dict = {}
50 | video_json_path_dict["train"] = os.path.join(self.data_path, "train.json")
51 | video_json_path_dict["train_test"] = os.path.join(self.data_path, "train.json")
52 | video_json_path_dict["val"] = os.path.join(self.data_path, "val_1.json")
53 |
54 | pseudo_video_id_list, video_id_list = self._get_video_id_single(video_id_path_dict[self.subset])
55 | pseudo_caption_dict = self._get_captions_single(video_json_path_dict[self.subset])
56 |
57 | print("video id list: {}".format(len(video_id_list)))
58 | print("pseudo caption dict: {}".format(len(pseudo_caption_dict.keys())))
59 |
60 | video_dict = {}
61 | for root, dub_dir, video_files in os.walk(self.features_path):
62 | for video_file in video_files:
63 |                 video_id_ = ".".join(video_file.split(".")[:-1])[2:]  # drop the "v_" prefix so ids match video_id_list
64 | if video_id_ not in video_id_list:
65 | continue
66 | file_path_ = os.path.join(root, video_file)
67 | video_dict[video_id_] = file_path_
68 | self.video_dict = video_dict
69 | print("video dict: {}".format(len(video_dict)))
70 |
71 | self.pseudo_video_id_list = pseudo_video_id_list
72 | self.video_id_list = video_id_list
73 | self.pseudo_caption_dict = pseudo_caption_dict
74 |
75 | # Get iterator video ids
76 | self.video_id2idx_dict = {pseudo_video_id: id for id, pseudo_video_id in enumerate(self.pseudo_video_id_list)}
77 | # Get all captions
78 | self.iter2video_pairs_dict = {}
79 | for pseudo_video_id, video_id in zip(self.pseudo_video_id_list, self.video_id_list):
80 | if pseudo_video_id not in self.pseudo_caption_dict or video_id not in self.video_dict:
81 | continue
82 | caption = self.pseudo_caption_dict[pseudo_video_id]
83 | n_caption = len(caption['start'])
84 | for sub_id in range(n_caption):
85 | self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (pseudo_video_id, sub_id)
86 |
87 | self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution)
88 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>",
89 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"}
90 |
91 | def __len__(self):
92 | return len(self.iter2video_pairs_dict)
93 |
94 | def _get_video_id_from_pseduo(self, pseudo_video_id):
95 | video_id = pseudo_video_id[2:]
96 | return video_id
97 |
98 | def _get_video_id_single(self, path):
99 | pseudo_video_id_list = []
100 | video_id_list = []
101 | print('Loading json: {}'.format(path))
102 | with open(path, 'r') as f:
103 | json_data = json.load(f)
104 |
105 | for pseudo_video_id in json_data:
106 | if pseudo_video_id in pseudo_video_id_list:
107 |                 print("duplicate pseudo video id: {}".format(pseudo_video_id))
108 | else:
109 | video_id = self._get_video_id_from_pseduo(pseudo_video_id)
110 | pseudo_video_id_list.append(pseudo_video_id)
111 | video_id_list.append(video_id)
112 | return pseudo_video_id_list, video_id_list
113 |
114 | def _get_captions_single(self, path):
115 | pseudo_caption_dict = {}
116 | with open(path, 'r') as f:
117 | json_data = json.load(f)
118 |
119 | for pseudo_video_id, v_ in json_data.items():
120 | pseudo_caption_dict[pseudo_video_id] = {}
121 | duration = v_["duration"]
122 | pseudo_caption_dict[pseudo_video_id]["start"] = np.array([0], dtype=object)
123 | pseudo_caption_dict[pseudo_video_id]["end"] = np.array([int(math.ceil(float(duration)))], dtype=object)
124 | pseudo_caption_dict[pseudo_video_id]["text"] = np.array([" ".join(v_["sentences"])], dtype=object)
125 | return pseudo_caption_dict
126 |
127 | def _get_text(self, pseudo_video_id, sub_id):
128 | caption = self.pseudo_caption_dict[pseudo_video_id]
129 | k = 1
130 | r_ind = [sub_id]
131 |
132 |         starts = np.zeros(k, dtype=np.int64)
133 |         ends = np.zeros(k, dtype=np.int64)
134 |         pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
135 |         pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
136 |         pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)
137 |
138 | for i in range(k):
139 | ind = r_ind[i]
140 | start_, end_ = caption['start'][ind], caption['end'][ind]
141 | words = self.tokenizer.tokenize(caption['text'][ind])
142 | starts[i], ends[i] = start_, end_
143 |
144 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words
145 | total_length_with_CLS = self.max_words - 1
146 | if len(words) > total_length_with_CLS:
147 | words = words[:total_length_with_CLS]
148 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]]
149 |
150 | input_ids = self.tokenizer.convert_tokens_to_ids(words)
151 | input_mask = [1] * len(input_ids)
152 | segment_ids = [0] * len(input_ids)
153 | while len(input_ids) < self.max_words:
154 | input_ids.append(0)
155 | input_mask.append(0)
156 | segment_ids.append(0)
157 | assert len(input_ids) == self.max_words
158 | assert len(input_mask) == self.max_words
159 | assert len(segment_ids) == self.max_words
160 |
161 | pairs_text[i] = np.array(input_ids)
162 | pairs_mask[i] = np.array(input_mask)
163 | pairs_segment[i] = np.array(segment_ids)
164 |
165 | return pairs_text, pairs_mask, pairs_segment, starts, ends
166 |
167 | def _get_rawvideo(self, idx, s, e):
168 |         video_mask = np.zeros((len(s), self.max_frames), dtype=np.int64)
169 | max_video_length = [0] * len(s)
170 |
171 | # Pair x L x T x 3 x H x W
172 | video = np.zeros((len(s), self.max_frames, 1, 3,
173 |                           self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float64)
174 | video_path = self.video_dict[idx]
175 | try:
176 | for i in range(len(s)):
177 | start_time = int(s[i])
178 | end_time = int(e[i])
179 | start_time = start_time if start_time >= 0. else 0.
180 | end_time = end_time if end_time >= 0. else 0.
181 | if start_time > end_time:
182 | start_time, end_time = end_time, start_time
183 | elif start_time == end_time:
184 | end_time = end_time + 1
185 |
186 |                 # Could be optimized by gathering all requests for this video together
187 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path, start_time, end_time)
188 | raw_video_data = raw_video_data['video']
189 |
190 | if len(raw_video_data.shape) > 3:
191 | raw_video_data_clip = raw_video_data
192 | # L x T x 3 x H x W
193 | raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip)
194 | if self.max_frames < raw_video_slice.shape[0]:
195 | if self.slice_framepos == 0:
196 | video_slice = raw_video_slice[:self.max_frames, ...]
197 | elif self.slice_framepos == 1:
198 | video_slice = raw_video_slice[-self.max_frames:, ...]
199 | else:
200 | sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int)
201 | video_slice = raw_video_slice[sample_indx, ...]
202 | else:
203 | video_slice = raw_video_slice
204 |
205 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order)
206 |
207 | slice_len = video_slice.shape[0]
208 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len
209 | if slice_len < 1:
210 | pass
211 | else:
212 | video[i][:slice_len, ...] = video_slice
213 | else:
214 | print("video path: {} error. video id: {}, start: {}, end: {}".format(video_path, idx, start_time, end_time))
215 | except Exception as excep:
216 | print("video path: {} error. video id: {}, start: {}, end: {}, Error: {}".format(video_path, idx, s, e, excep))
217 | raise excep
218 |
219 | for i, v_length in enumerate(max_video_length):
220 | video_mask[i][:v_length] = [1] * v_length
221 |
222 | return video, video_mask
223 |
224 | def __getitem__(self, feature_idx):
225 | pseudo_video_id, sub_id = self.iter2video_pairs_dict[feature_idx]
226 | idx = self.video_id2idx_dict[pseudo_video_id]
227 |
228 | pairs_text, pairs_mask, pairs_segment, starts, ends = self._get_text(pseudo_video_id, sub_id)
229 | video, video_mask = self._get_rawvideo(self.video_id_list[idx], starts, ends)
230 | return pairs_text, pairs_mask, video, video_mask, feature_idx, hash(pseudo_video_id)
231 |
232 |
233 | def load_stopwords(path='data/english.txt'):
234 | with open(path, 'r', encoding='utf-8') as f:
235 |         lines = f.readlines()  # read all lines
236 |     return [line.strip() for line in lines]  # strip trailing newlines and whitespace
237 |
238 |
239 | def remove_stopwords(documents, stopwords):
240 | cleaned_documents = []
241 | for token in documents.split():
242 | if token not in stopwords:
243 | cleaned_documents.append(token)
244 |     return " ".join(cleaned_documents)
--------------------------------------------------------------------------------
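
For paragraph-to-video retrieval, _get_captions_single above collapses all sentences of an ActivityNet annotation into a single caption that spans the whole clip: start 0, end equal to the rounded-up duration. A toy illustration of that transformation (the duration and sentences are made up):

    import math

    duration, sentences = 12.3, ["A man runs onto the field.", "He kicks the ball."]
    entry = {
        "start": [0],
        "end": [int(math.ceil(float(duration)))],   # 13
        "text": [" ".join(sentences)],              # one concatenated paragraph
    }
    print(entry["end"], entry["text"])
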
/tvr/dataloaders/dataloader_didemo_retrieval.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import unicode_literals
4 | from __future__ import print_function
5 |
6 | import os
7 | from torch.utils.data import Dataset
8 | import numpy as np
9 | import json
10 | from .rawvideo_util import RawVideoExtractor
11 |
12 |
13 | class DiDeMoDataset(Dataset):
14 | def __init__(
15 | self,
16 | subset,
17 | data_path,
18 | features_path,
19 | tokenizer,
20 | max_words=30,
21 | feature_framerate=1.0,
22 | max_frames=100,
23 | image_resolution=224,
24 | frame_order=0,
25 | slice_framepos=2,
26 | ):
27 | self.data_path = data_path
28 | self.features_path = features_path
29 | self.feature_framerate = feature_framerate
30 | self.max_words = max_words
31 | self.max_frames = max_frames
32 | self.tokenizer = tokenizer
33 | # 0: ordinary order; 1: reverse order; 2: random order.
34 | self.frame_order = frame_order
35 | assert self.frame_order in [0, 1, 2]
36 | # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.
37 | self.slice_framepos = slice_framepos
38 | assert self.slice_framepos in [0, 1, 2]
39 |
40 | self.subset = subset
41 | assert self.subset in ["train", "val", "test", "train_test"]
42 |
43 | video_id_path_dict = {}
44 | video_id_path_dict["train"] = os.path.join(self.data_path, "train_list.txt")
45 | video_id_path_dict["train_test"] = os.path.join(self.data_path, "train_list.txt")
46 | video_id_path_dict["val"] = os.path.join(self.data_path, "val_list.txt")
47 | video_id_path_dict["test"] = os.path.join(self.data_path, "test_list.txt")
48 |
49 | video_json_path_dict = {}
50 | video_json_path_dict["train"] = os.path.join(self.data_path, "train_data.json")
51 | video_json_path_dict["train_test"] = os.path.join(self.data_path, "train_data.json")
52 | video_json_path_dict["val"] = os.path.join(self.data_path, "val_data.json")
53 | video_json_path_dict["test"] = os.path.join(self.data_path, "test_data.json")
54 |
55 |
56 | with open(video_id_path_dict[self.subset], 'r') as fp:
57 | video_ids = [itm.strip() for itm in fp.readlines()]
58 |
59 | caption_dict = {}
60 | with open(video_json_path_dict[self.subset], 'r') as f:
61 | json_data = json.load(f)
62 | for itm in json_data:
63 | description = itm["description"]
64 | times = itm["times"]
65 | video = itm["video"]
66 | if video not in video_ids:
67 | continue
68 | # each video is split into 5-second temporal chunks
69 | # average the points from each annotator
70 | start_ = np.mean([t_[0] for t_ in times]) * 5
71 | end_ = (np.mean([t_[1] for t_ in times]) + 1) * 5
72 | if video in caption_dict:
73 | caption_dict[video]["start"].append(start_)
74 | caption_dict[video]["end"].append(end_)
75 | caption_dict[video]["text"].append(description)
76 | else:
77 | caption_dict[video] = {}
78 | caption_dict[video]["start"] = [start_]
79 | caption_dict[video]["end"] = [end_]
80 | caption_dict[video]["text"] = [description]
81 |
82 | for k_ in caption_dict.keys():
83 | caption_dict[k_]["start"] = [0]
84 | # trick to save time on obtaining each video length
85 | # [https://github.com/LisaAnne/LocalizingMoments/blob/master/README.md]:
86 | # Some videos are longer than 30 seconds. These videos were truncated to 30 seconds during annotation.
87 | caption_dict[k_]["end"] = [31]
88 | caption_dict[k_]["text"] = [" ".join(caption_dict[k_]["text"])]
89 |
90 | video_dict = {}
91 | for root, dub_dir, video_files in os.walk(self.features_path):
92 | for video_file in video_files:
93 | video_id_ = os.path.splitext(video_file)[0]
94 | if video_id_ not in video_ids:
95 | continue
96 | file_path_ = os.path.join(root, video_file)
97 | video_dict[video_id_] = file_path_
98 |
99 | self.caption_dict = caption_dict
100 | self.video_dict = video_dict
101 | video_ids = list(set(video_ids) & set(self.caption_dict.keys()) & set(self.video_dict.keys()))
102 |
103 | # Get all captions
104 | self.iter2video_pairs_dict = {}
105 | for video_id in self.caption_dict.keys():
106 | if video_id not in video_ids:
107 | continue
108 | caption = self.caption_dict[video_id]
109 | n_caption = len(caption['start'])
110 | for sub_id in range(n_caption):
111 | self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (video_id, sub_id)
112 |
113 | self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution)
114 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>",
115 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"}
116 |
117 | def __len__(self):
118 | return len(self.iter2video_pairs_dict)
119 |
120 | def _get_text(self, video_id, sub_id):
121 | caption = self.caption_dict[video_id]
122 | k = 1
123 | r_ind = [sub_id]
124 |
125 |         starts = np.zeros(k, dtype=np.int64)
126 |         ends = np.zeros(k, dtype=np.int64)
127 |         pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
128 |         pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
129 |         pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)
130 |
131 | for i in range(k):
132 | ind = r_ind[i]
133 | start_, end_ = caption['start'][ind], caption['end'][ind]
134 | words = self.tokenizer.tokenize(caption['text'][ind])
135 | starts[i], ends[i] = start_, end_
136 |
137 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words
138 | total_length_with_CLS = self.max_words - 1
139 | if len(words) > total_length_with_CLS:
140 | words = words[:total_length_with_CLS]
141 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]]
142 |
143 | input_ids = self.tokenizer.convert_tokens_to_ids(words)
144 | input_mask = [1] * len(input_ids)
145 | segment_ids = [0] * len(input_ids)
146 | while len(input_ids) < self.max_words:
147 | input_ids.append(0)
148 | input_mask.append(0)
149 | segment_ids.append(0)
150 | assert len(input_ids) == self.max_words
151 | assert len(input_mask) == self.max_words
152 | assert len(segment_ids) == self.max_words
153 |
154 | pairs_text[i] = np.array(input_ids)
155 | pairs_mask[i] = np.array(input_mask)
156 | pairs_segment[i] = np.array(segment_ids)
157 |
158 | return pairs_text, pairs_mask, pairs_segment, starts, ends
159 |
160 | def _get_rawvideo(self, idx, s, e):
161 |         video_mask = np.zeros((len(s), self.max_frames), dtype=np.int64)
162 | max_video_length = [0] * len(s)
163 |
164 | # Pair x L x T x 3 x H x W
165 | video = np.zeros((len(s), self.max_frames, 1, 3,
166 |                           self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float64)
167 | video_path = self.video_dict[idx]
168 |
169 | try:
170 | for i in range(len(s)):
171 | start_time = int(s[i])
172 | end_time = int(e[i])
173 | start_time = start_time if start_time >= 0. else 0.
174 | end_time = end_time if end_time >= 0. else 0.
175 | if start_time > end_time:
176 | start_time, end_time = end_time, start_time
177 | elif start_time == end_time:
178 | end_time = end_time + 1
179 |
180 | cache_id = "{}_{}_{}".format(video_path, start_time, end_time)
181 |                 # Could be optimized by gathering all requests for this video together
182 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path, start_time, end_time)
183 | raw_video_data = raw_video_data['video']
184 |
185 | if len(raw_video_data.shape) > 3:
186 | raw_video_data_clip = raw_video_data
187 | # L x T x 3 x H x W
188 | raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip)
189 | if self.max_frames < raw_video_slice.shape[0]:
190 | if self.slice_framepos == 0:
191 | video_slice = raw_video_slice[:self.max_frames, ...]
192 | elif self.slice_framepos == 1:
193 | video_slice = raw_video_slice[-self.max_frames:, ...]
194 | else:
195 | sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int)
196 | video_slice = raw_video_slice[sample_indx, ...]
197 | else:
198 | video_slice = raw_video_slice
199 |
200 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order)
201 |
202 | slice_len = video_slice.shape[0]
203 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len
204 | if slice_len < 1:
205 | pass
206 | else:
207 | video[i][:slice_len, ...] = video_slice
208 | else:
209 | print("video path: {} error. video id: {}, start: {}, end: {}".format(video_path, idx, start_time, end_time))
210 | except Exception as excep:
211 | print("video path: {} error. video id: {}, start: {}, end: {}, Error: {}".format(video_path, idx, s, e, excep))
212 | pass
213 | # raise e
214 |
215 | for i, v_length in enumerate(max_video_length):
216 | video_mask[i][:v_length] = [1] * v_length
217 |
218 | return video, video_mask
219 |
220 | def __getitem__(self, feature_idx):
221 | video_id, sub_id = self.iter2video_pairs_dict[feature_idx]
222 |
223 | pairs_text, pairs_mask, pairs_segment, starts, ends = self._get_text(video_id, sub_id)
224 | video, video_mask = self._get_rawvideo(video_id, starts, ends)
225 | return pairs_text, pairs_mask, video, video_mask, feature_idx, hash(video_id)
--------------------------------------------------------------------------------
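
DiDeMo annotators mark moments as indices of 5-second chunks, so DiDeMoDataset first averages the chunk indices across annotators and converts them to seconds (start = mean * 5, end = (mean + 1) * 5); the constructor then overrides every caption to the full 0-31 s window for paragraph-level retrieval. A worked example with made-up annotator indices:

    import numpy as np

    times = [[0, 1], [1, 2], [0, 2]]                   # (start_chunk, end_chunk) from three annotators
    start_ = np.mean([t_[0] for t_ in times]) * 5      # ~1.67 seconds
    end_ = (np.mean([t_[1] for t_ in times]) + 1) * 5  # ~13.33 seconds
    print(round(start_, 2), round(end_, 2))
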
/tvr/dataloaders/dataloader_lsmdc_retrieval.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import unicode_literals
4 | from __future__ import print_function
5 |
6 | import json
7 | import tempfile
8 | import pandas as pd
9 | from os.path import join, splitext, exists
10 | from collections import OrderedDict
11 | from .dataloader_retrieval import RetrievalDataset
12 | import os
13 |
14 |
15 | class LsmdcDataset(RetrievalDataset):
16 | """LSMDC dataset."""
17 |
18 | def __init__(self, subset, anno_path, video_path, tokenizer, max_words=32,
19 | max_frames=12, video_framerate=1, image_resolution=224, mode='all', config=None):
20 | super(LsmdcDataset, self).__init__(subset, anno_path, video_path, tokenizer, max_words,
21 | max_frames, video_framerate, image_resolution, mode, config=config)
22 | pass
23 |
24 | def _get_anns(self, subset='train'):
25 | """
26 | video_dict: dict: video_id -> video_path
27 |         sentences_dict: dict: idx -> (video_id, (caption_text, start, end))
28 | """
29 | video_json_path_dict = {}
30 | video_json_path_dict["train"] = os.path.join(self.anno_path, "LSMDC16_annos_training.csv")
31 | #video_json_path_dict["train_test"] = os.path.join(self.anno_path, "LSMDC16_annos_training.csv")
32 | video_json_path_dict["train_test"] = os.path.join(self.anno_path, "LSMDC16_annos_val.csv")
33 | video_json_path_dict["val"] = os.path.join(self.anno_path, "LSMDC16_annos_val.csv")
34 | video_json_path_dict["test"] = os.path.join(self.anno_path, "LSMDC16_challenge_1000_publictect.csv")
35 |
36 |         # Each line: <CLIP_ID>\t<START_ALIGNED>\t<END_ALIGNED>\t<START_EXTRACTED>\t<END_EXTRACTED>\t<SENTENCE>
37 |         # <CLIP_ID> is not a unique identifier, i.e. the same clip can be associated with multiple sentences.
38 |         # However, LSMDC16_challenge_1000_publictect.csv has no repeated instances.
39 | video_id_list = []
40 | caption_dict = {}
41 | with open(video_json_path_dict[self.subset], 'r') as fp:
42 | for line in fp:
43 | line = line.strip()
44 | line_split = line.split("\t")
45 | assert len(line_split) == 6
46 | clip_id, start_aligned, end_aligned, start_extracted, end_extracted, sentence = line_split
47 | if clip_id not in ["0017_Pianist_00.23.28.872-00.23.34.843", "0017_Pianist_00.30.36.767-00.30.38.009",
48 | "3064_SPARKLE_2012_01.41.07.000-01.41.11.793"]:
49 | caption_dict[len(caption_dict)] = (clip_id, (sentence, None, None))
50 | if clip_id not in video_id_list: video_id_list.append(clip_id)
51 |
52 | video_dict = OrderedDict()
53 | sentences_dict = OrderedDict()
54 |
55 | for root, dub_dir, video_files in os.walk(self.video_path):
56 | for video_file in video_files:
57 | video_id_ = ".".join(video_file.split(".")[:-1])
58 | if video_id_ not in video_id_list:
59 | continue
60 | file_path_ = os.path.join(root, video_file)
61 | video_dict[video_id_] = file_path_
62 |
63 | # Get all captions
64 | for clip_id, sentence in caption_dict.values():
65 | if clip_id not in video_dict:
66 | continue
67 | sentences_dict[len(sentences_dict)] = (clip_id, sentence)
68 |
69 | unique_sentence = set([v[1][0] for v in sentences_dict.values()])
70 |         print('[{}] Unique sentences: {}, total pairs: {}'.format(subset, len(unique_sentence), len(sentences_dict)))
71 |
72 | return video_dict, sentences_dict
--------------------------------------------------------------------------------
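
Each LSMDC annotation row is one tab-separated line with six fields, which _get_anns above unpacks directly after splitting on "\t". A sketch with a fabricated clip id and sentence:

    line = "1004_SomeMovie_00.02.36.000-00.02.42.000\t36.0\t42.0\t36.0\t42.0\tSOMEONE opens the door."
    clip_id, start_aligned, end_aligned, start_extracted, end_extracted, sentence = line.split("\t")
    print(clip_id, sentence)
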
/tvr/dataloaders/dataloader_msrvtt_retrieval.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import unicode_literals
4 | from __future__ import print_function
5 |
6 | import json
7 | import tempfile
8 | import pandas as pd
9 | from os.path import join, splitext, exists
10 | from collections import OrderedDict
11 | from .dataloader_retrieval import RetrievalDataset
12 |
13 |
14 | class MSRVTTDataset(RetrievalDataset):
15 | """MSRVTT dataset."""
16 |
17 | def __init__(self, subset, anno_path, video_path, tokenizer, max_words=32,
18 | max_frames=12, video_framerate=1, image_resolution=224, mode='all', config=None):
19 | super(MSRVTTDataset, self).__init__(subset, anno_path, video_path, tokenizer, max_words,
20 | max_frames, video_framerate, image_resolution, mode, config=config)
21 | pass
22 |
23 | def _get_anns(self, subset='train'):
24 | """
25 | video_dict: dict: video_id -> video_path
26 |         sentences_dict: dict: idx -> (video_id, (caption_text, start, end))
27 | """
28 | csv_path = {'train': join(self.anno_path, 'MSRVTT_train.9k.csv'),
29 | 'val': join(self.anno_path, 'MSRVTT_JSFUSION_test.csv'),
30 | 'test': join(self.anno_path, 'MSRVTT_JSFUSION_test.csv'),
31 | 'train_test': join(self.anno_path, 'MSRVTT_train.9k.csv')}[subset]
32 | if exists(csv_path):
33 | csv = pd.read_csv(csv_path)
34 | else:
35 | raise FileNotFoundError
36 |
37 | video_id_list = list(csv['video_id'].values)
38 |
39 | video_dict = OrderedDict()
40 | sentences_dict = OrderedDict()
41 | if subset == 'train':
42 | anno_path = join(self.anno_path, 'MSRVTT_data.json')
43 | data = json.load(open(anno_path, 'r'))
44 | for itm in data['sentences']:
45 | if itm['video_id'] in video_id_list:
46 | sentences_dict[len(sentences_dict)] = (itm['video_id'], (itm['caption'], None, None))
47 | video_dict[itm['video_id']] = join(self.video_path, "{}.mp4".format(itm['video_id']))
48 | elif subset == 'train_test':
49 | anno_path = join(self.anno_path, 'MSRVTT_data.json')
50 | data = json.load(open(anno_path, 'r'))
51 | used = []
52 | for itm in data['sentences']:
53 | if itm['video_id'] in video_id_list and itm['video_id'] not in used:
54 | used.append(itm['video_id'])
55 | sentences_dict[len(sentences_dict)] = (itm['video_id'], (itm['caption'], None, None))
56 | video_dict[itm['video_id']] = join(self.video_path, "{}.mp4".format(itm['video_id']))
57 | else:
58 | for _, itm in csv.iterrows():
59 | sentences_dict[len(sentences_dict)] = (itm['video_id'], (itm['sentence'], None, None))
60 | video_dict[itm['video_id']] = join(self.video_path, "{}.mp4".format(itm['video_id']))
61 |
62 | unique_sentence = set([v[1][0] for v in sentences_dict.values()])
63 |         print('[{}] Unique sentences: {}, total pairs: {}'.format(subset, len(unique_sentence), len(sentences_dict)))
64 |
65 | return video_dict, sentences_dict
66 |
--------------------------------------------------------------------------------
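
The two structures returned by _get_anns are what the shared RetrievalDataset iterates over: video_dict maps a video id to its file path, and sentences_dict maps a running index to a (video_id, (caption, start, end)) tuple. A toy example with invented ids and paths:

    from collections import OrderedDict

    video_dict = OrderedDict({"video0": "/path/to/MSRVTT_Videos/video0.mp4"})   # hypothetical path
    sentences_dict = OrderedDict({0: ("video0", ("a cat is playing a piano", None, None))})
    video_id, (caption, start, end) = sentences_dict[0]
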
/tvr/dataloaders/dataloader_msvd_retrieval.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import unicode_literals
4 | from __future__ import print_function
5 |
6 | import json
7 | import tempfile
8 | import os
9 | import pickle
10 | import pandas as pd
11 | from os.path import join, splitext, exists
12 | from collections import OrderedDict
13 | from .dataloader_retrieval import RetrievalDataset
14 |
15 |
16 | class MsvdDataset(RetrievalDataset):
17 | """MSVD dataset loader."""
18 | def __init__(self, subset, anno_path, video_path, tokenizer, max_words=32,
19 | max_frames=12, video_framerate=1, image_resolution=224, mode='all', config=None):
20 | super(MsvdDataset, self).__init__(subset, anno_path, video_path, tokenizer, max_words,
21 | max_frames, video_framerate, image_resolution, mode, config=config)
22 | pass
23 |
24 | def _get_anns(self, subset='train'):
25 | self.sample_len = 0
26 | self.cut_off_points = []
27 | self.multi_sentence_per_video = True # !!! important tag for eval
28 |
29 | video_id_path_dict = {}
30 | video_id_path_dict["train"] = os.path.join(self.anno_path, "train_list.txt")
31 | video_id_path_dict["train_test"] = os.path.join(self.anno_path, "train_list.txt")
32 | video_id_path_dict["val"] = os.path.join(self.anno_path, "val_list.txt")
33 | video_id_path_dict["test"] = os.path.join(self.anno_path, "test_list.txt")
34 | caption_file = os.path.join(self.anno_path, "raw-captions.pkl")
35 |
36 | with open(video_id_path_dict[subset], 'r') as fp:
37 | video_ids = [itm.strip() for itm in fp.readlines()]
38 |
39 | with open(caption_file, 'rb') as f:
40 | captions = pickle.load(f)
41 |
42 | video_dict = OrderedDict()
43 | sentences_dict = OrderedDict()
44 |
45 | for root, dub_dir, video_files in os.walk(self.video_path):
46 | for video_file in video_files:
47 | video_id_ = ".".join(video_file.split(".")[:-1])
48 | if video_id_ not in video_ids:
49 | continue
50 | file_path_ = os.path.join(root, video_file)
51 | video_dict[video_id_] = file_path_
52 |
53 | for video_id in video_ids:
54 | assert video_id in captions
55 | for cap in captions[video_id]:
56 | cap_txt = " ".join(cap)
57 | sentences_dict[len(sentences_dict)] = (video_id, (cap_txt, None, None))
58 | self.cut_off_points.append(len(sentences_dict) - 1)
59 |
60 | if subset == "val" or subset == "test":
61 | self.sentence_num = len(sentences_dict)
62 | self.video_num = len(video_ids)
63 | assert len(self.cut_off_points) == self.video_num
64 | print("For {}, sentence number: {}".format(subset, self.sentence_num))
65 | print("For {}, video number: {}".format(subset, self.video_num))
66 |
67 | print("Video number: {}".format(len(video_dict)))
68 |         print("Total Pairs: {}".format(len(sentences_dict)))
69 |
70 | self.sample_len = len(sentences_dict)
71 |
72 | return video_dict, sentences_dict
--------------------------------------------------------------------------------
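
cut_off_points records the index of the last caption belonging to each video so that, together with multi_sentence_per_video, evaluation code can regroup sentence-level similarity scores by video. A small sketch of how those boundaries come out (toy captions):

    captions_per_video = {"vid1": ["a boy sings", "a kid is singing", "a child performs"],
                          "vid2": ["a dog barks", "a puppy is barking"]}
    sentences_dict, cut_off_points = {}, []
    for video_id, caps in captions_per_video.items():
        for cap in caps:
            sentences_dict[len(sentences_dict)] = (video_id, (cap, None, None))
        cut_off_points.append(len(sentences_dict) - 1)
    print(cut_off_points)   # [2, 4]: captions 0-2 belong to vid1, captions 3-4 to vid2
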
/tvr/dataloaders/dataloader_retrieval.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import unicode_literals
4 | from __future__ import print_function
5 |
6 | from os.path import exists
7 |
8 | import random
9 | import numpy as np
10 | from torch.utils.data import Dataset
11 |
12 | import torch
13 | from PIL import Image
14 | from decord import VideoReader, cpu
15 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode, ToPILImage, ColorJitter, RandomHorizontalFlip, RandomResizedCrop
16 | import tvr.dataloaders.video_transforms as video_transforms
17 | from .random_erasing import RandomErasing
18 |
19 |
20 | class RetrievalDataset(Dataset):
21 | """General dataset."""
22 |
23 | def __init__(
24 | self,
25 | subset,
26 | anno_path,
27 | video_path,
28 | tokenizer,
29 | max_words=30,
30 | max_frames=12,
31 | video_framerate=1,
32 | image_resolution=224,
33 | mode='all',
34 | config=None
35 | ):
36 | self.subset = subset
37 | self.anno_path = anno_path
38 | self.video_path = video_path
39 | self.tokenizer = tokenizer
40 | self.max_words = max_words
41 | self.max_frames = max_frames
42 | self.video_framerate = video_framerate
43 | self.image_resolution = image_resolution
44 |         self.mode = mode  # all/text/video
45 | self.config = config
46 |
47 | self.video_dict, self.sentences_dict = self._get_anns(self.subset)
48 |
49 | self.video_list = list(self.video_dict.keys())
50 | self.sample_len = 0
51 |
52 | print("Video number: {}".format(len(self.video_dict)))
53 | print("Total Pairs: {}".format(len(self.sentences_dict)))
54 |
55 | from .rawvideo_util import RawVideoExtractor
56 | self.rawVideoExtractor = RawVideoExtractor(framerate=video_framerate, size=image_resolution)
57 | self.transform = Compose([
58 | Resize(image_resolution, interpolation=InterpolationMode.BICUBIC),
59 | CenterCrop(image_resolution),
60 | lambda image: image.convert("RGB"),
61 | ToTensor(),
62 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
63 | # Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
64 | ])
65 | self.tsfm_dict = {
66 | 'clip_test': Compose([
67 | Resize(image_resolution, interpolation=InterpolationMode.BICUBIC),
68 | CenterCrop(image_resolution),
69 | lambda image: image.convert("RGB"),
70 | ToTensor(),
71 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
72 | ]),
73 | 'clip_train': Compose([
74 | RandomResizedCrop(image_resolution, scale=(0.5, 1.0)),
75 | RandomHorizontalFlip(),
76 | lambda image: image.convert("RGB"),
77 | ToTensor(),
78 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
79 | ])
80 | }
81 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>",
82 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"}
83 | self.image_resolution = image_resolution
84 | if self.mode in ['all', 'text']:
85 | self.sample_len = len(self.sentences_dict)
86 | else:
87 | self.sample_len = len(self.video_list)
88 | self.aug_transform = video_transforms.create_random_augment(
89 | input_size=(self.image_resolution, self.image_resolution),
90 | auto_augment='rand-m7-n4-mstd0.5-inc1',
91 | interpolation='bicubic',
92 | )
93 |
94 | def __len__(self):
95 | return self.sample_len
96 |
97 | def __aug_transform(self, buffer):
98 | _aug_transform = video_transforms.create_random_augment(
99 | input_size=(self.image_resolution, self.image_resolution),
100 | auto_augment='rand-m7-n4-mstd0.5-inc1',
101 | interpolation='bicubic',
102 | )
103 | buffer = _aug_transform(buffer)
104 |         return buffer  # NOTE: everything below this return is unreachable legacy augmentation code (spatial_sampling / RandomErasing)
105 | buffer = [ToTensor()(img) for img in buffer]
106 | buffer = torch.stack(buffer) # T C H W
107 | buffer = buffer.permute(1, 0, 2, 3) # T H W C -> C T H W.
108 | # Perform data augmentation.
109 | scl, asp = (
110 | [0.08, 1.0],
111 | [0.75, 1.3333],
112 | )
113 |
114 | buffer = spatial_sampling(
115 | buffer,
116 | spatial_idx=-1,
117 | min_scale=256,
118 | max_scale=320,
119 | crop_size=224,
120 | random_horizontal_flip=True,
121 | inverse_uniform_sampling=False,
122 | aspect_ratio=asp,
123 | scale=scl,
124 | motion_shift=False
125 | )
126 | buffer = buffer.permute(1, 0, 2, 3)
127 | buffer = [ToPILImage()(frame) for frame in buffer]
128 | return buffer
129 | erase_transform = RandomErasing(
130 | 0.25,
131 | mode='pixel',
132 | max_count=1,
133 | num_splits=1,
134 | device="cpu",
135 | )
136 | buffer = buffer.permute(1, 0, 2, 3)
137 | buffer = erase_transform(buffer)
138 | buffer = [ToPILImage()(frame) for frame in buffer]
139 | return buffer
140 |
141 | def _get_anns(self, subset='train'):
142 | raise NotImplementedError
143 |
144 | def _get_text(self, caption):
145 | if len(caption) == 3:
146 | _caption_text, s, e = caption
147 | else:
148 | raise NotImplementedError
149 |
150 | if isinstance(_caption_text, list):
151 | caption_text = random.choice(_caption_text)
152 | else:
153 | caption_text = _caption_text
154 |
155 | words = self.tokenizer.tokenize(caption_text)
156 |
157 |         if self.subset == "train" and 0:  # 'and 0' deliberately disables this random word-dropout augmentation
158 | if random.random() < 0.5:
159 | new_words = []
160 | for idx in range(len(words)):
161 | if random.random() < 0.8:
162 | new_words.append(words[idx])
163 | words = new_words
164 |
165 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words
166 | total_length_with_CLS = self.max_words - 1
167 | if len(words) > total_length_with_CLS:
168 | words = words[:total_length_with_CLS]
169 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]]
170 |
171 | input_ids = self.tokenizer.convert_tokens_to_ids(words)
172 | input_mask = [1] * len(input_ids)
173 |
174 | while len(input_ids) < self.max_words:
175 | input_ids.append(0)
176 | input_mask.append(0)
177 | assert len(input_ids) == self.max_words
178 | assert len(input_mask) == self.max_words
179 |
180 | input_ids = np.array(input_ids)
181 | input_mask = np.array(input_mask)
182 |
183 | return input_ids, input_mask, s, e
184 |
185 | def _get_rawvideo(self, video_id, s=None, e=None):
186 |         video_mask = np.zeros(self.max_frames, dtype=np.int64)
187 | max_video_length = 0
188 |
189 | # T x 3 x H x W
190 |         video = np.zeros((self.max_frames, 3, self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float64)
191 |
192 | if s is None:
193 | start_time, end_time = None, None
194 | else:
195 | start_time = int(s)
196 | end_time = int(e)
197 | start_time = start_time if start_time >= 0. else 0.
198 | end_time = end_time if end_time >= 0. else 0.
199 | if start_time > end_time:
200 | start_time, end_time = end_time, start_time
201 | elif start_time == end_time:
202 | end_time = end_time + 1
203 | video_path = self.video_dict[video_id]
204 |
205 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path, start_time, end_time)
206 | raw_video_data = raw_video_data['video']
207 |
208 | if len(raw_video_data.shape) > 3:
209 | # L x T x 3 x H x W
210 |
211 | if self.max_frames < raw_video_data.shape[0]:
212 | sample_indx = np.linspace(0, raw_video_data.shape[0] - 1, num=self.max_frames, dtype=int)
213 | video_slice = raw_video_data[sample_indx, ...]
214 | else:
215 | video_slice = raw_video_data
216 |
217 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=0)
218 |
219 | slice_len = video_slice.shape[0]
220 | max_video_length = max_video_length if max_video_length > slice_len else slice_len
221 | if slice_len < 1:
222 | pass
223 | else:
224 | video[:slice_len, ...] = video_slice
225 | else:
226 | print("video path: {} error. video id: {}".format(video_path, video_id))
227 |
228 | video_mask[:max_video_length] = [1] * max_video_length
229 |
230 | return video, video_mask
231 |
232 | def _get_rawvideo_dec(self, video_id, s=None, e=None):
233 | # speed up video decode via decord.
234 |         video_mask = np.zeros(self.max_frames, dtype=np.int64)
235 | max_video_length = 0
236 |
237 | # T x 3 x H x W
238 |         video = np.zeros((self.max_frames, 3, self.image_resolution, self.image_resolution), dtype=np.float64)
239 |
240 | if s is None:
241 | start_time, end_time = None, None
242 | else:
243 | start_time = int(s)
244 | end_time = int(e)
245 | start_time = start_time if start_time >= 0. else 0.
246 | end_time = end_time if end_time >= 0. else 0.
247 | if start_time > end_time:
248 | start_time, end_time = end_time, start_time
249 | elif start_time == end_time:
250 | end_time = start_time + 1
251 | video_path = self.video_dict[video_id]
252 |
253 | if exists(video_path):
254 | vreader = VideoReader(video_path, ctx=cpu(0))
255 | else:
256 | print(video_path)
257 | raise FileNotFoundError
258 |
259 | fps = vreader.get_avg_fps()
260 | f_start = 0 if start_time is None else int(start_time * fps)
261 | f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
262 | num_frames = f_end - f_start + 1
263 | if num_frames > 0:
264 | # T x 3 x H x W
265 | sample_fps = int(self.video_framerate)
266 | t_stride = int(round(float(fps) / sample_fps))
267 |
268 | all_pos = list(range(f_start, f_end + 1, t_stride))
269 | if len(all_pos) > self.max_frames:
270 | sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=self.max_frames, dtype=int)]
271 | else:
272 | sample_pos = all_pos
273 |
274 | patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
275 | if self.subset == "train":
276 | # for i in range(2):
277 | patch_images = self.aug_transform(patch_images)
278 |
279 | # if self.subset == "train":
280 | # patch_images = torch.stack([self.tsfm_dict["clip_train"](img) for img in patch_images])
281 | # else:
282 | # patch_images = torch.stack([self.tsfm_dict["clip_test"](img) for img in patch_images])
283 |
284 | patch_images = torch.stack([self.transform(img) for img in patch_images])
285 | slice_len = patch_images.shape[0]
286 | max_video_length = max_video_length if max_video_length > slice_len else slice_len
287 | if slice_len < 1:
288 | pass
289 | else:
290 | video[:slice_len, ...] = patch_images
291 | else:
292 | print("video path: {} error. video id: {}".format(video_path, video_id))
293 |
294 | video_mask[:max_video_length] = [1] * max_video_length
295 |
296 | return video, video_mask
297 |
298 | def __getitem__(self, idx):
299 |
300 | if self.mode == 'all':
301 | video_id, caption = self.sentences_dict[idx]
302 | text_ids, text_mask, s, e = self._get_text(caption)
303 | video, video_mask = self._get_rawvideo_dec(video_id, s, e)
304 | # video, video_mask = self._get_rawvideo(video_id, s, e)
305 | return text_ids, text_mask, video, video_mask, idx, hash(video_id.replace("video", ""))
306 | elif self.mode == 'text':
307 | video_id, caption = self.sentences_dict[idx]
308 | text_ids, text_mask, s, e = self._get_text(caption)
309 | return text_ids, text_mask, idx
310 | elif self.mode == 'video':
311 | video_id = self.video_list[idx]
312 | video, video_mask = self._get_rawvideo_dec(video_id)
313 | # video, video_mask = self._get_rawvideo(video_id)
314 | return video, video_mask, idx
315 |
316 | def get_text_len(self):
317 | return len(self.sentences_dict)
318 |
319 | def get_video_len(self):
320 | return len(self.video_list)
321 |
322 | def get_text_content(self, ind):
323 | return self.sentences_dict[ind][1]
324 |
325 | def get_data_name(self):
326 | return self.__class__.__name__ + "_" + self.subset
327 |
328 | def get_vis_info(self, idx):
329 | video_id, caption = self.sentences_dict[idx]
330 | video_path = self.video_dict[video_id]
331 | return caption, video_path
332 |
333 |
334 | def spatial_sampling(
335 | frames,
336 | spatial_idx=-1,
337 | min_scale=256,
338 | max_scale=320,
339 | crop_size=224,
340 | random_horizontal_flip=True,
341 | inverse_uniform_sampling=False,
342 | aspect_ratio=None,
343 | scale=None,
344 | motion_shift=False,
345 | ):
346 | """
347 | Perform spatial sampling on the given video frames. If spatial_idx is
348 | -1, perform random scale, random crop, and random flip on the given
349 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
350 | with the given spatial_idx.
351 | Args:
352 | frames (tensor): frames of images sampled from the video. The
353 | dimension is `num frames` x `height` x `width` x `channel`.
354 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
355 | or 2, perform left, center, right crop if width is larger than
356 |             height, and perform top, center, bottom crop if height is larger
357 | than width.
358 | min_scale (int): the minimal size of scaling.
359 | max_scale (int): the maximal size of scaling.
360 | crop_size (int): the size of height and width used to crop the
361 | frames.
362 | inverse_uniform_sampling (bool): if True, sample uniformly in
363 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
364 | scale. If False, take a uniform sample from [min_scale,
365 | max_scale].
366 | aspect_ratio (list): Aspect ratio range for resizing.
367 | scale (list): Scale range for resizing.
368 | motion_shift (bool): Whether to apply motion shift for resizing.
369 | Returns:
370 | frames (tensor): spatially sampled frames.
371 | """
372 | assert spatial_idx in [-1, 0, 1, 2]
373 | if spatial_idx == -1:
374 | if aspect_ratio is None and scale is None:
375 | frames, _ = video_transforms.random_short_side_scale_jitter(
376 | images=frames,
377 | min_size=min_scale,
378 | max_size=max_scale,
379 | inverse_uniform_sampling=inverse_uniform_sampling,
380 | )
381 | frames, _ = video_transforms.random_crop(frames, crop_size)
382 | else:
383 | transform_func = (
384 | video_transforms.random_resized_crop_with_shift
385 | if motion_shift
386 | else video_transforms.random_resized_crop
387 | )
388 | frames = transform_func(
389 | images=frames,
390 | target_height=crop_size,
391 | target_width=crop_size,
392 | scale=scale,
393 | ratio=aspect_ratio,
394 | )
395 | if random_horizontal_flip:
396 | frames, _ = video_transforms.horizontal_flip(0.5, frames)
397 | else:
398 | # The testing is deterministic and no jitter should be performed.
399 |         # min_scale, max_scale, and crop_size are expected to be the same.
400 | assert len({min_scale, max_scale, crop_size}) == 1
401 | frames, _ = video_transforms.random_short_side_scale_jitter(
402 | frames, min_scale, max_scale
403 | )
404 | frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx)
405 | return frames
--------------------------------------------------------------------------------
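
_get_rawvideo_dec above first strides through the decoded frames at roughly video_framerate fps and then, if more than max_frames candidates remain, keeps a uniform np.linspace subset. The sampling logic in isolation (the clip length and fps are made-up values):

    import numpy as np

    fps, video_framerate, max_frames = 30.0, 1, 12
    f_start, f_end = 0, 1799                             # ~60 s clip decoded at 30 fps
    t_stride = int(round(fps / video_framerate))         # 30: one candidate frame per second
    all_pos = list(range(f_start, f_end + 1, t_stride))  # 60 candidate frame indices
    if len(all_pos) > max_frames:
        sample_pos = [all_pos[i] for i in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)]
    else:
        sample_pos = all_pos
    print(len(sample_pos), sample_pos[:3])               # 12 frames spread uniformly over the clip
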
/tvr/dataloaders/rand_augment.py:
--------------------------------------------------------------------------------
1 | """
2 | This implementation is based on
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py
4 | published under an Apache License 2.0.
5 |
6 | COMMENT FROM ORIGINAL:
7 | AutoAugment, RandAugment, and AugMix for PyTorch
8 | This code implements the searched ImageNet policies with various tweaks and
9 | improvements and does not include any of the search code. AA and RA
10 | Implementation adapted from:
11 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
12 | AugMix adapted from:
13 | https://github.com/google-research/augmix
14 | Papers:
15 | AutoAugment: Learning Augmentation Policies from Data
16 | https://arxiv.org/abs/1805.09501
17 | Learning Data Augmentation Strategies for Object Detection
18 | https://arxiv.org/abs/1906.11172
19 | RandAugment: Practical automated data augmentation...
20 | https://arxiv.org/abs/1909.13719
21 | AugMix: A Simple Data Processing Method to Improve Robustness and
22 | Uncertainty https://arxiv.org/abs/1912.02781
23 |
24 | Hacked together by / Copyright 2020 Ross Wightman
25 | """
26 |
27 | import math
28 | import numpy as np
29 | import random
30 | import re
31 | import PIL
32 | from PIL import Image, ImageEnhance, ImageOps
33 |
34 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]])
35 |
36 | _FILL = (128, 128, 128)
37 |
38 | # This signifies the max integer that the controller RNN could predict for the
39 | # augmentation scheme.
40 | _MAX_LEVEL = 10.0
41 |
42 | _HPARAMS_DEFAULT = {
43 | "translate_const": 250,
44 | "img_mean": _FILL,
45 | }
46 |
47 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
48 |
49 |
50 | def _interpolation(kwargs):
51 | interpolation = kwargs.pop("resample", Image.BILINEAR)
52 | if isinstance(interpolation, (list, tuple)):
53 | return random.choice(interpolation)
54 | else:
55 | return interpolation
56 |
57 |
58 | def _check_args_tf(kwargs):
59 | if "fillcolor" in kwargs and _PIL_VER < (5, 0):
60 | kwargs.pop("fillcolor")
61 | kwargs["resample"] = _interpolation(kwargs)
62 |
63 |
64 | def shear_x(img, factor, **kwargs):
65 | _check_args_tf(kwargs)
66 | return img.transform(
67 | img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs
68 | )
69 |
70 |
71 | def shear_y(img, factor, **kwargs):
72 | _check_args_tf(kwargs)
73 | return img.transform(
74 | img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs
75 | )
76 |
77 |
78 | def translate_x_rel(img, pct, **kwargs):
79 | pixels = pct * img.size[0]
80 | _check_args_tf(kwargs)
81 | return img.transform(
82 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
83 | )
84 |
85 |
86 | def translate_y_rel(img, pct, **kwargs):
87 | pixels = pct * img.size[1]
88 | _check_args_tf(kwargs)
89 | return img.transform(
90 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
91 | )
92 |
93 |
94 | def translate_x_abs(img, pixels, **kwargs):
95 | _check_args_tf(kwargs)
96 | return img.transform(
97 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
98 | )
99 |
100 |
101 | def translate_y_abs(img, pixels, **kwargs):
102 | _check_args_tf(kwargs)
103 | return img.transform(
104 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
105 | )
106 |
107 |
108 | def rotate(img, degrees, **kwargs):
109 | _check_args_tf(kwargs)
110 | if _PIL_VER >= (5, 2):
111 | return img.rotate(degrees, **kwargs)
112 | elif _PIL_VER >= (5, 0):
113 | w, h = img.size
114 | post_trans = (0, 0)
115 | rotn_center = (w / 2.0, h / 2.0)
116 | angle = -math.radians(degrees)
117 | matrix = [
118 | round(math.cos(angle), 15),
119 | round(math.sin(angle), 15),
120 | 0.0,
121 | round(-math.sin(angle), 15),
122 | round(math.cos(angle), 15),
123 | 0.0,
124 | ]
125 |
126 | def transform(x, y, matrix):
127 | (a, b, c, d, e, f) = matrix
128 | return a * x + b * y + c, d * x + e * y + f
129 |
130 | matrix[2], matrix[5] = transform(
131 | -rotn_center[0] - post_trans[0],
132 | -rotn_center[1] - post_trans[1],
133 | matrix,
134 | )
135 | matrix[2] += rotn_center[0]
136 | matrix[5] += rotn_center[1]
137 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
138 | else:
139 | return img.rotate(degrees, resample=kwargs["resample"])
140 |
141 |
142 | def auto_contrast(img, **__):
143 | return ImageOps.autocontrast(img)
144 |
145 |
146 | def invert(img, **__):
147 | return ImageOps.invert(img)
148 |
149 |
150 | def equalize(img, **__):
151 | return ImageOps.equalize(img)
152 |
153 |
154 | def solarize(img, thresh, **__):
155 | return ImageOps.solarize(img, thresh)
156 |
157 |
158 | def solarize_add(img, add, thresh=128, **__):
159 | lut = []
160 | for i in range(256):
161 | if i < thresh:
162 | lut.append(min(255, i + add))
163 | else:
164 | lut.append(i)
165 | if img.mode in ("L", "RGB"):
166 | if img.mode == "RGB" and len(lut) == 256:
167 | lut = lut + lut + lut
168 | return img.point(lut)
169 | else:
170 | return img
171 |
172 |
173 | def posterize(img, bits_to_keep, **__):
174 | if bits_to_keep >= 8:
175 | return img
176 | return ImageOps.posterize(img, bits_to_keep)
177 |
178 |
179 | def contrast(img, factor, **__):
180 | return ImageEnhance.Contrast(img).enhance(factor)
181 |
182 |
183 | def color(img, factor, **__):
184 | return ImageEnhance.Color(img).enhance(factor)
185 |
186 |
187 | def brightness(img, factor, **__):
188 | return ImageEnhance.Brightness(img).enhance(factor)
189 |
190 |
191 | def sharpness(img, factor, **__):
192 | return ImageEnhance.Sharpness(img).enhance(factor)
193 |
194 |
195 | def _randomly_negate(v):
196 | """With 50% prob, negate the value"""
197 | return -v if random.random() > 0.5 else v
198 |
199 |
200 | def _rotate_level_to_arg(level, _hparams):
201 | # range [-30, 30]
202 | level = (level / _MAX_LEVEL) * 30.0
203 | level = _randomly_negate(level)
204 | return (level,)
205 |
206 |
207 | def _enhance_level_to_arg(level, _hparams):
208 | # range [0.1, 1.9]
209 | return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
210 |
211 |
212 | def _enhance_increasing_level_to_arg(level, _hparams):
213 | # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
214 | # range [0.1, 1.9]
215 | level = (level / _MAX_LEVEL) * 0.9
216 | level = 1.0 + _randomly_negate(level)
217 | return (level,)
218 |
219 |
220 | def _shear_level_to_arg(level, _hparams):
221 | # range [-0.3, 0.3]
222 | level = (level / _MAX_LEVEL) * 0.3
223 | level = _randomly_negate(level)
224 | return (level,)
225 |
226 |
227 | def _translate_abs_level_to_arg(level, hparams):
228 | translate_const = hparams["translate_const"]
229 | level = (level / _MAX_LEVEL) * float(translate_const)
230 | level = _randomly_negate(level)
231 | return (level,)
232 |
233 |
234 | def _translate_rel_level_to_arg(level, hparams):
235 | # default range [-0.45, 0.45]
236 | translate_pct = hparams.get("translate_pct", 0.45)
237 | level = (level / _MAX_LEVEL) * translate_pct
238 | level = _randomly_negate(level)
239 | return (level,)
240 |
241 |
242 | def _posterize_level_to_arg(level, _hparams):
243 | # As per Tensorflow TPU EfficientNet impl
244 | # range [0, 4], 'keep 0 up to 4 MSB of original image'
245 | # intensity/severity of augmentation decreases with level
246 | return (int((level / _MAX_LEVEL) * 4),)
247 |
248 |
249 | def _posterize_increasing_level_to_arg(level, hparams):
250 | # As per Tensorflow models research and UDA impl
251 | # range [4, 0], 'keep 4 down to 0 MSB of original image',
252 | # intensity/severity of augmentation increases with level
253 | return (4 - _posterize_level_to_arg(level, hparams)[0],)
254 |
255 |
256 | def _posterize_original_level_to_arg(level, _hparams):
257 | # As per original AutoAugment paper description
258 | # range [4, 8], 'keep 4 up to 8 MSB of image'
259 | # intensity/severity of augmentation decreases with level
260 | return (int((level / _MAX_LEVEL) * 4) + 4,)
261 |
262 |
263 | def _solarize_level_to_arg(level, _hparams):
264 | # range [0, 256]
265 | # intensity/severity of augmentation decreases with level
266 | return (int((level / _MAX_LEVEL) * 256),)
267 |
268 |
269 | def _solarize_increasing_level_to_arg(level, _hparams):
270 | # range [0, 256]
271 | # intensity/severity of augmentation increases with level
272 | return (256 - _solarize_level_to_arg(level, _hparams)[0],)
273 |
274 |
275 | def _solarize_add_level_to_arg(level, _hparams):
276 | # range [0, 110]
277 | return (int((level / _MAX_LEVEL) * 110),)
278 |
279 |
280 | LEVEL_TO_ARG = {
281 | "AutoContrast": None,
282 | "Equalize": None,
283 | "Invert": None,
284 | "Rotate": _rotate_level_to_arg,
285 | # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
286 | "Posterize": _posterize_level_to_arg,
287 | "PosterizeIncreasing": _posterize_increasing_level_to_arg,
288 | "PosterizeOriginal": _posterize_original_level_to_arg,
289 | "Solarize": _solarize_level_to_arg,
290 | "SolarizeIncreasing": _solarize_increasing_level_to_arg,
291 | "SolarizeAdd": _solarize_add_level_to_arg,
292 | "Color": _enhance_level_to_arg,
293 | "ColorIncreasing": _enhance_increasing_level_to_arg,
294 | "Contrast": _enhance_level_to_arg,
295 | "ContrastIncreasing": _enhance_increasing_level_to_arg,
296 | "Brightness": _enhance_level_to_arg,
297 | "BrightnessIncreasing": _enhance_increasing_level_to_arg,
298 | "Sharpness": _enhance_level_to_arg,
299 | "SharpnessIncreasing": _enhance_increasing_level_to_arg,
300 | "ShearX": _shear_level_to_arg,
301 | "ShearY": _shear_level_to_arg,
302 | "TranslateX": _translate_abs_level_to_arg,
303 | "TranslateY": _translate_abs_level_to_arg,
304 | "TranslateXRel": _translate_rel_level_to_arg,
305 | "TranslateYRel": _translate_rel_level_to_arg,
306 | }
307 |
308 |
309 | NAME_TO_OP = {
310 | "AutoContrast": auto_contrast,
311 | "Equalize": equalize,
312 | "Invert": invert,
313 | "Rotate": rotate,
314 | "Posterize": posterize,
315 | "PosterizeIncreasing": posterize,
316 | "PosterizeOriginal": posterize,
317 | "Solarize": solarize,
318 | "SolarizeIncreasing": solarize,
319 | "SolarizeAdd": solarize_add,
320 | "Color": color,
321 | "ColorIncreasing": color,
322 | "Contrast": contrast,
323 | "ContrastIncreasing": contrast,
324 | "Brightness": brightness,
325 | "BrightnessIncreasing": brightness,
326 | "Sharpness": sharpness,
327 | "SharpnessIncreasing": sharpness,
328 | "ShearX": shear_x,
329 | "ShearY": shear_y,
330 | "TranslateX": translate_x_abs,
331 | "TranslateY": translate_y_abs,
332 | "TranslateXRel": translate_x_rel,
333 | "TranslateYRel": translate_y_rel,
334 | }
335 |
336 |
337 | class AugmentOp:
338 | """
339 |     Apply a single augmentation op to a PIL image or to a list of video frames.
340 | """
341 |
342 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
343 | hparams = hparams or _HPARAMS_DEFAULT
344 | self.aug_fn = NAME_TO_OP[name]
345 | self.level_fn = LEVEL_TO_ARG[name]
346 | self.prob = prob
347 | self.magnitude = magnitude
348 | self.hparams = hparams.copy()
349 | self.kwargs = {
350 | "fillcolor": hparams["img_mean"]
351 | if "img_mean" in hparams
352 | else _FILL,
353 | "resample": hparams["interpolation"]
354 | if "interpolation" in hparams
355 | else _RANDOM_INTERPOLATION,
356 | }
357 |
358 | # If magnitude_std is > 0, we introduce some randomness
359 | # in the usually fixed policy and sample magnitude from a normal distribution
360 | # with mean `magnitude` and std-dev of `magnitude_std`.
361 | # NOTE This is my own hack, being tested, not in papers or reference impls.
362 | self.magnitude_std = self.hparams.get("magnitude_std", 0)
363 |
364 | def __call__(self, img_list):
365 | if self.prob < 1.0 and random.random() > self.prob:
366 | return img_list
367 | magnitude = self.magnitude
368 | if self.magnitude_std and self.magnitude_std > 0:
369 | magnitude = random.gauss(magnitude, self.magnitude_std)
370 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
371 | level_args = (
372 | self.level_fn(magnitude, self.hparams)
373 | if self.level_fn is not None
374 | else ()
375 | )
376 |
377 | if isinstance(img_list, list):
378 | return [
379 | self.aug_fn(img, *level_args, **self.kwargs) for img in img_list
380 | ]
381 | else:
382 | return self.aug_fn(img_list, *level_args, **self.kwargs)
383 |
384 |
385 | _RAND_TRANSFORMS = [
386 | "AutoContrast",
387 | "Equalize",
388 | "Invert",
389 | "Rotate",
390 | "Posterize",
391 | "Solarize",
392 | "SolarizeAdd",
393 | "Color",
394 | "Contrast",
395 | "Brightness",
396 | "Sharpness",
397 | "ShearX",
398 | "ShearY",
399 | "TranslateXRel",
400 | "TranslateYRel",
401 | ]
402 |
403 |
404 | _RAND_INCREASING_TRANSFORMS = [
405 | "AutoContrast",
406 | "Equalize",
407 | "Invert",
408 | "Rotate",
409 | "PosterizeIncreasing",
410 | "SolarizeIncreasing",
411 | "SolarizeAdd",
412 | "ColorIncreasing",
413 | "ContrastIncreasing",
414 | "BrightnessIncreasing",
415 | "SharpnessIncreasing",
416 | "ShearX",
417 | "ShearY",
418 | "TranslateXRel",
419 | "TranslateYRel",
420 | ]
421 |
422 |
423 | # These experimental weights are based loosely on the relative improvements mentioned in paper.
424 | # They may not result in increased performance, but could likely be tuned to do so.
425 | _RAND_CHOICE_WEIGHTS_0 = {
426 | "Rotate": 0.3,
427 | "ShearX": 0.2,
428 | "ShearY": 0.2,
429 | "TranslateXRel": 0.1,
430 | "TranslateYRel": 0.1,
431 | "Color": 0.025,
432 | "Sharpness": 0.025,
433 | "AutoContrast": 0.025,
434 | "Solarize": 0.005,
435 | "SolarizeAdd": 0.005,
436 | "Contrast": 0.005,
437 | "Brightness": 0.005,
438 | "Equalize": 0.005,
439 | "Posterize": 0,
440 | "Invert": 0,
441 | }
442 |
443 |
444 | def _select_rand_weights(weight_idx=0, transforms=None):
445 | transforms = transforms or _RAND_TRANSFORMS
446 | assert weight_idx == 0 # only one set of weights currently
447 | rand_weights = _RAND_CHOICE_WEIGHTS_0
448 | probs = [rand_weights[k] for k in transforms]
449 | probs /= np.sum(probs)
450 | return probs
451 |
452 |
453 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
454 | hparams = hparams or _HPARAMS_DEFAULT
455 | transforms = transforms or _RAND_TRANSFORMS
456 | return [
457 | AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams)
458 | for name in transforms
459 | ]
460 |
461 |
462 | class RandAugment:
463 | def __init__(self, ops, num_layers=2, choice_weights=None):
464 | self.ops = ops
465 | self.num_layers = num_layers
466 | self.choice_weights = choice_weights
467 |
468 | def __call__(self, img):
469 | # no replacement when using weighted choice
470 | ops = np.random.choice(
471 | self.ops,
472 | self.num_layers,
473 | replace=self.choice_weights is None,
474 | p=self.choice_weights,
475 | )
476 | for op in ops:
477 | img = op(img)
478 | return img
479 |
480 |
481 | def rand_augment_transform(config_str, hparams):
482 | """
483 | RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
484 |
485 | Create a RandAugment transform
486 | :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
487 | dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
488 |     sections, not order specific, determine:
489 | 'm' - integer magnitude of rand augment
490 | 'n' - integer num layers (number of transform ops selected per image)
491 |         'w' - integer probability weight index (index of a set of weights to influence choice of op)
492 | 'mstd' - float std deviation of magnitude noise applied
493 | 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
494 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
495 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
496 | :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
497 | :return: A PyTorch compatible Transform
498 | """
499 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
500 | num_layers = 2 # default to 2 ops per image
501 | weight_idx = None # default to no probability weights for op choice
502 | transforms = _RAND_TRANSFORMS
503 | config = config_str.split("-")
504 | assert config[0] == "rand"
505 | config = config[1:]
506 | for c in config:
507 | cs = re.split(r"(\d.*)", c)
508 | if len(cs) < 2:
509 | continue
510 | key, val = cs[:2]
511 | if key == "mstd":
512 | # noise param injected via hparams for now
513 | hparams.setdefault("magnitude_std", float(val))
514 | elif key == "inc":
515 | if bool(val):
516 | transforms = _RAND_INCREASING_TRANSFORMS
517 | elif key == "m":
518 | magnitude = int(val)
519 | elif key == "n":
520 | num_layers = int(val)
521 | elif key == "w":
522 | weight_idx = int(val)
523 | else:
524 |             raise NotImplementedError(f"Unknown RandAugment config section: {c}")
525 | ra_ops = rand_augment_ops(
526 | magnitude=magnitude, hparams=hparams, transforms=transforms
527 | )
528 | choice_weights = (
529 | None if weight_idx is None else _select_rand_weights(weight_idx)
530 | )
531 | return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
532 |
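533 |
534 | if __name__ == "__main__":
535 |     # Minimal usage sketch mirroring how rawvideo_util.py consumes this module:
536 |     # build the transform from a config string and apply it to a list of PIL
537 |     # frames. The hparams keys (translate_const / img_mean) follow the timm
538 |     # convention that the AugmentOp defaults earlier in this file are assumed
539 |     # to share; the values here are illustrative only.
540 |     from PIL import Image
541 |
542 |     aa_params = {"translate_const": 100, "img_mean": (124, 116, 104)}
543 |     tfm = rand_augment_transform("rand-m7-n4-mstd0.5-inc1", aa_params)
544 |     frames = [Image.new("RGB", (224, 224)) for _ in range(4)]
545 |     frames = tfm(frames)  # each sampled op is applied across the whole frame list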
--------------------------------------------------------------------------------
/tvr/dataloaders/random_erasing.py:
--------------------------------------------------------------------------------
1 | """
2 | This implementation is based on
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
4 | published under an Apache License 2.0.
5 | """
6 | import math
7 | import random
8 | import torch
9 |
10 |
11 | def _get_pixels(
12 | per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"
13 | ):
14 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
15 | # paths, flip the order so normal is run on CPU if this becomes a problem
16 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
17 | if per_pixel:
18 | return torch.empty(patch_size, dtype=dtype, device=device).normal_()
19 | elif rand_color:
20 | return torch.empty(
21 | (patch_size[0], 1, 1), dtype=dtype, device=device
22 | ).normal_()
23 | else:
24 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
25 |
26 |
27 | class RandomErasing:
28 | """Randomly selects a rectangle region in an image and erases its pixels.
29 | 'Random Erasing Data Augmentation' by Zhong et al.
30 | See https://arxiv.org/pdf/1708.04896.pdf
31 | This variant of RandomErasing is intended to be applied to either a batch
32 | or single image tensor after it has been normalized by dataset mean and std.
33 | Args:
34 | probability: Probability that the Random Erasing operation will be performed.
35 | min_area: Minimum percentage of erased area wrt input image area.
36 | max_area: Maximum percentage of erased area wrt input image area.
37 | min_aspect: Minimum aspect ratio of erased area.
38 | mode: pixel color mode, one of 'const', 'rand', or 'pixel'
39 | 'const' - erase block is constant color of 0 for all channels
40 | 'rand' - erase block is same per-channel random (normal) color
41 | 'pixel' - erase block is per-pixel random (normal) color
42 | max_count: maximum number of erasing blocks per image, area per box is scaled by count.
43 | per-image count is randomly chosen between 1 and this value.
44 | """
45 |
46 | def __init__(
47 | self,
48 | probability=0.5,
49 | min_area=0.02,
50 | max_area=1 / 3,
51 | min_aspect=0.3,
52 | max_aspect=None,
53 | mode="const",
54 | min_count=1,
55 | max_count=None,
56 | num_splits=0,
57 | device="cuda",
58 | cube=True,
59 | ):
60 | self.probability = probability
61 | self.min_area = min_area
62 | self.max_area = max_area
63 | max_aspect = max_aspect or 1 / min_aspect
64 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
65 | self.min_count = min_count
66 | self.max_count = max_count or min_count
67 | self.num_splits = num_splits
68 | mode = mode.lower()
69 | self.rand_color = False
70 | self.per_pixel = False
71 | self.cube = cube
72 | if mode == "rand":
73 | self.rand_color = True # per block random normal
74 | elif mode == "pixel":
75 | self.per_pixel = True # per pixel random normal
76 | else:
77 | assert not mode or mode == "const"
78 | self.device = device
79 |
80 | def _erase(self, img, chan, img_h, img_w, dtype):
81 | if random.random() > self.probability:
82 | return
83 | area = img_h * img_w
84 | count = (
85 | self.min_count
86 | if self.min_count == self.max_count
87 | else random.randint(self.min_count, self.max_count)
88 | )
89 | for _ in range(count):
90 | for _ in range(10):
91 | target_area = (
92 | random.uniform(self.min_area, self.max_area) * area / count
93 | )
94 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
95 | h = int(round(math.sqrt(target_area * aspect_ratio)))
96 | w = int(round(math.sqrt(target_area / aspect_ratio)))
97 | if w < img_w and h < img_h:
98 | top = random.randint(0, img_h - h)
99 | left = random.randint(0, img_w - w)
100 | img[:, top : top + h, left : left + w] = _get_pixels(
101 | self.per_pixel,
102 | self.rand_color,
103 | (chan, h, w),
104 | dtype=dtype,
105 | device=self.device,
106 | )
107 | break
108 |
109 | def _erase_cube(
110 | self,
111 | img,
112 | batch_start,
113 | batch_size,
114 | chan,
115 | img_h,
116 | img_w,
117 | dtype,
118 | ):
119 | if random.random() > self.probability:
120 | return
121 | area = img_h * img_w
122 | count = (
123 | self.min_count
124 | if self.min_count == self.max_count
125 | else random.randint(self.min_count, self.max_count)
126 | )
127 | for _ in range(count):
128 | for _ in range(100):
129 | target_area = (
130 | random.uniform(self.min_area, self.max_area) * area / count
131 | )
132 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
133 | h = int(round(math.sqrt(target_area * aspect_ratio)))
134 | w = int(round(math.sqrt(target_area / aspect_ratio)))
135 | if w < img_w and h < img_h:
136 | top = random.randint(0, img_h - h)
137 | left = random.randint(0, img_w - w)
138 | for i in range(batch_start, batch_size):
139 | img_instance = img[i]
140 | img_instance[
141 | :, top : top + h, left : left + w
142 | ] = _get_pixels(
143 | self.per_pixel,
144 | self.rand_color,
145 | (chan, h, w),
146 | dtype=dtype,
147 | device=self.device,
148 | )
149 | break
150 |
151 | def __call__(self, input):
152 | if len(input.size()) == 3:
153 | self._erase(input, *input.size(), input.dtype)
154 | else:
155 | batch_size, chan, img_h, img_w = input.size()
156 | # skip first slice of batch if num_splits is set (for clean portion of samples)
157 | batch_start = (
158 | batch_size // self.num_splits if self.num_splits > 1 else 0
159 | )
160 | if self.cube:
161 | self._erase_cube(
162 | input,
163 | batch_start,
164 | batch_size,
165 | chan,
166 | img_h,
167 | img_w,
168 | input.dtype,
169 | )
170 | else:
171 | for i in range(batch_start, batch_size):
172 | self._erase(input[i], chan, img_h, img_w, input.dtype)
173 | return input
174 |
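175 |
176 | if __name__ == "__main__":
177 |     # Minimal usage sketch: erase random patches from an already-normalized clip
178 |     # tensor. device="cpu" is passed so the generated fill values live on the same
179 |     # device as the example tensor (the constructor default is "cuda").
180 |     eraser = RandomErasing(probability=0.5, mode="pixel", max_count=2, device="cpu")
181 |     clip = torch.randn(8, 3, 224, 224)  # (frames, channels, height, width)
182 |     clip = eraser(clip)  # cube=True erases the same region in every frame
183 |     print(clip.shape)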
--------------------------------------------------------------------------------
/tvr/dataloaders/rawvideo_util.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 | import numpy as np
3 | from PIL import Image
4 | # pytorch=1.7.1
5 | # pip install opencv-python
6 | import cv2
7 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode, ToPILImage, ColorJitter, RandomHorizontalFlip, RandomResizedCrop
8 | import tvr.dataloaders.video_transforms as video_transforms
9 | from .random_erasing import RandomErasing
10 |
11 |
12 | class RawVideoExtractorCV2():
13 | def __init__(self, centercrop=False, size=224, framerate=-1, subset="test"):
14 | self.centercrop = centercrop
15 | self.size = size
16 | self.framerate = framerate
17 | self.transform = self._transform(self.size)
18 | self.subset = subset
19 | self.tsfm_dict = {
20 | 'clip_test': Compose([
21 | Resize(size, interpolation=InterpolationMode.BICUBIC),
22 | CenterCrop(size),
23 | lambda image: image.convert("RGB"),
24 | ToTensor(),
25 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
26 | ]),
27 | 'clip_train': Compose([
28 | RandomResizedCrop(size, scale=(0.5, 1.0)),
29 | RandomHorizontalFlip(),
30 | lambda image: image.convert("RGB"),
31 | ToTensor(),
32 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
33 | ])
34 | }
35 | self.aug_transform = video_transforms.create_random_augment(
36 | input_size=(size, size),
37 | auto_augment='rand-m7-n4-mstd0.5-inc1',
38 | interpolation='bicubic',
39 | )
40 |
41 | def _transform(self, n_px):
42 | return Compose([
43 | Resize(n_px, interpolation=InterpolationMode.BICUBIC),
44 | CenterCrop(n_px),
45 | lambda image: image.convert("RGB"),
46 | ToTensor(),
47 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
48 | # Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
49 | ])
50 |
51 | def video_to_tensor(self, video_file, preprocess, sample_fp=0, start_time=None, end_time=None, _no_process=False):
52 | if start_time is not None or end_time is not None:
53 | assert isinstance(start_time, int) and isinstance(end_time, int) \
54 | and start_time > -1 and end_time > start_time
55 | assert sample_fp > -1
56 |
57 |         # Sample sample_fp frames per second; if sample_fp <= 0, fall back to the native fps.
58 | cap = cv2.VideoCapture(video_file)
59 | frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
60 | fps = int(cap.get(cv2.CAP_PROP_FPS))
61 |
62 | if fps == 0:
63 | print((video_file + '\n') * 10)
64 | total_duration = (frameCount + fps - 1) // fps
65 | start_sec, end_sec = 0, total_duration
66 |
67 | if start_time is not None:
68 | start_sec, end_sec = start_time, end_time if end_time <= total_duration else total_duration
69 | cap.set(cv2.CAP_PROP_POS_FRAMES, int(start_time * fps))
70 |
71 | interval = 1
72 | if sample_fp > 0:
73 | interval = fps // sample_fp
74 | else:
75 | sample_fp = fps
76 | if interval == 0: interval = 1
77 |
78 | inds = [ind for ind in np.arange(0, fps, interval)]
79 | assert len(inds) >= sample_fp
80 | inds = inds[:sample_fp]
81 |
82 | ret = True
83 | images, included = [], []
84 |
85 | for sec in np.arange(start_sec, end_sec + 1):
86 | if not ret: break
87 | sec_base = int(sec * fps)
88 | for ind in inds:
89 | cap.set(cv2.CAP_PROP_POS_FRAMES, sec_base + ind)
90 | ret, frame = cap.read()
91 | if not ret: break
92 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
93 | if _no_process:
94 | images.append(Image.fromarray(frame_rgb).convert("RGB"))
95 | else:
96 | # images.append(preprocess(Image.fromarray(frame_rgb).convert("RGB")))
97 | images.append(Image.fromarray(frame_rgb))
98 |
99 | cap.release()
100 |
101 | if len(images) > 0:
102 | if _no_process:
103 | video_data = images
104 | else:
105 | if self.subset == "train":
106 | # for i in range(2):
107 | images = self.aug_transform(images)
108 |
109 | # if self.subset == "train":
110 | # patch_images = torch.stack([self.tsfm_dict["clip_train"](img) for img in patch_images])
111 | # else:
112 | # patch_images = torch.stack([self.tsfm_dict["clip_test"](img) for img in patch_images])
113 |
114 | video_data = th.stack([preprocess(img) for img in images])
115 | # video_data = th.tensor(np.stack(images))
116 | else:
117 | video_data = th.zeros(1)
118 | return {'video': video_data}
119 |
120 | def get_video_data(self, video_path, start_time=None, end_time=None, _no_process=False):
121 | image_input = self.video_to_tensor(video_path, self.transform, sample_fp=self.framerate, start_time=start_time,
122 | end_time=end_time, _no_process=_no_process)
123 | return image_input
124 |
125 | def process_raw_data(self, raw_video_data):
126 | tensor_size = raw_video_data.size()
127 | tensor = raw_video_data.view(-1, 1, tensor_size[-3], tensor_size[-2], tensor_size[-1])
128 | return tensor
129 |
130 | def process_frame_order(self, raw_video_data, frame_order=0):
131 | # 0: ordinary order; 1: reverse order; 2: random order.
132 | if frame_order == 0:
133 | pass
134 | elif frame_order == 1:
135 | reverse_order = np.arange(raw_video_data.size(0) - 1, -1, -1)
136 | raw_video_data = raw_video_data[reverse_order, ...]
137 | elif frame_order == 2:
138 | random_order = np.arange(raw_video_data.size(0))
139 | np.random.shuffle(random_order)
140 | raw_video_data = raw_video_data[random_order, ...]
141 |
142 | return raw_video_data
143 |
144 |
145 | # An ordinary video frame extractor based on OpenCV (cv2)
146 | RawVideoExtractor = RawVideoExtractorCV2
147 |
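148 |
149 | if __name__ == "__main__":
150 |     # Minimal usage sketch with a hypothetical video path (run as a module, e.g.
151 |     # `python -m tvr.dataloaders.rawvideo_util`, so the relative imports resolve).
152 |     # framerate=1 samples one frame per second; the result is a
153 |     # (num_frames, 3, 224, 224) tensor, or torch.zeros(1) if decoding fails.
154 |     extractor = RawVideoExtractor(size=224, framerate=1, subset="test")
155 |     out = extractor.get_video_data("/path/to/video.mp4")
156 |     print(out["video"].shape)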
--------------------------------------------------------------------------------
/tvr/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__init__.py
--------------------------------------------------------------------------------
/tvr/models/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/models/__pycache__/modeling.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/modeling.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/models/__pycache__/module_clip.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/module_clip.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/models/__pycache__/module_cross.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/module_cross.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/models/__pycache__/module_transformer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/module_transformer.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/models/__pycache__/optimization.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/optimization.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/models/__pycache__/query_cross_att.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/query_cross_att.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/models/__pycache__/tokenization_clip.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/tokenization_clip.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/models/__pycache__/transformer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/transformer.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/models/__pycache__/transformer_block.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/transformer_block.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/models/__pycache__/until_module.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/__pycache__/until_module.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/models/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/models/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/tvr/models/module.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # ------------------------------------------------------
3 | # -------------------- LOCATE Module -------------------
4 | # ------------------------------------------------------
5 | import torch
6 | import torch.nn as nn
7 | from models.attention import SoftAttention, GumbelAttention
8 |
9 | class LOCATE(nn.Module):
10 | def __init__(self, opt):
11 | super(LOCATE, self).__init__()
12 | # spatial soft attention module
13 | self.spatial_attn = SoftAttention(opt.region_projected_size, opt.hidden_size, opt.hidden_size)
14 |
15 | # temporal soft attention module
16 | feat_size = opt.region_projected_size + opt.hidden_size * 2
17 | self.temp_attn = SoftAttention(feat_size, opt.hidden_size, opt.hidden_size)
18 |
19 | def forward(self, frame_feats, object_feats, hidden_state):
20 | """
21 | :param frame_feats: (batch_size, max_frames, 2*hidden_size)
22 | :param object_feats: (batch_size, max_frames, num_boxes, region_projected_size)
23 | :param hidden_state: (batch_size, hidden_size)
24 | :return: loc_feat: (batch_size, feat_size)
25 | """
26 | # spatial attention
27 | bsz, max_frames, num_boxes, fsize = object_feats.size()
28 | feats = object_feats.reshape(bsz * max_frames, num_boxes, fsize)
29 | hidden = hidden_state.repeat(1, max_frames).reshape(bsz * max_frames, -1)
30 | object_feats_att, _ = self.spatial_attn(feats, hidden)
31 | object_feats_att = object_feats_att.reshape(bsz, max_frames, fsize)
32 |
33 | # temporal attention
34 | feat = torch.cat([object_feats_att, frame_feats], dim=-1)
35 | loc_feat, _ = self.temp_attn(feat, hidden_state)
36 | return loc_feat
37 |
38 |
39 | # ------------------------------------------------------
40 | # -------------------- RELATE Module -------------------
41 | # ------------------------------------------------------
42 |
43 | class RELATE(nn.Module):
44 | def __init__(self, opt):
45 | super(RELATE, self).__init__()
46 |
47 | # spatial soft attention module
48 | region_feat_size = opt.region_projected_size
49 | self.spatial_attn = SoftAttention(region_feat_size, opt.hidden_size, opt.hidden_size)
50 |
51 | # temporal soft attention module
52 | feat_size = region_feat_size + opt.hidden_size * 2
53 | self.relation_attn = SoftAttention(2*feat_size, opt.hidden_size, opt.hidden_size)
54 |
55 | def forward(self, i3d_feats, object_feats, hidden_state):
56 | '''
57 | :param i3d_feats: (batch_size, max_frames, 2*hidden_size)
58 | :param object_feats: (batch_size, max_frames, num_boxes, region_projected_size)
59 | :param hidden_state: (batch_size, hidden_size)
60 | :return: rel_feat
61 | '''
62 |         # spatial attention
63 | bsz, max_frames, num_boxes, fsize = object_feats.size()
64 | feats = object_feats.reshape(bsz * max_frames, num_boxes, fsize)
65 | hidden = hidden_state.repeat(1, max_frames).reshape(bsz * max_frames, -1)
66 | object_feats_att, _ = self.spatial_attn(feats, hidden)
67 | object_feats_att = object_feats_att.reshape(bsz, max_frames, fsize)
68 |
69 | # generate pair-wise feature
70 | feat = torch.cat([object_feats_att, i3d_feats], dim=-1)
71 | feat1 = feat.repeat(1, max_frames, 1)
72 | feat2 = feat.repeat(1, 1, max_frames).reshape(bsz, max_frames*max_frames, -1)
73 | pairwise_feat = torch.cat([feat1, feat2], dim=-1)
74 |
75 | # temporal attention
76 | rel_feat, _ = self.relation_attn(pairwise_feat, hidden_state)
77 | return rel_feat
78 |
79 |
80 | # ------------------------------------------------------
81 | # -------------------- FUNC Module ---------------------
82 | # ------------------------------------------------------
83 |
84 | class FUNC(nn.Module):
85 | def __init__(self, opt):
86 | super(FUNC, self).__init__()
87 | self.cell_attn = SoftAttention(opt.hidden_size, opt.hidden_size, opt.hidden_size)
88 |
89 | def forward(self, cells, hidden_state):
90 | '''
91 | :param cells: previous memory states of decoder LSTM
92 | :param hidden_state: (batch_size, hidden_size)
93 | :return: func_feat
94 | '''
95 | func_feat, _ = self.cell_attn(cells, hidden_state)
96 | return func_feat
97 |
98 |
99 |
100 | # ------------------------------------------------------
101 | # ------------------- Module Selector ------------------
102 | # ------------------------------------------------------
103 |
104 | class ModuleSelection(nn.Module):
105 | def __init__(self, opt):
106 | super(ModuleSelection, self).__init__()
107 | self.use_loc = opt.use_loc
108 | self.use_rel = opt.use_rel
109 | self.use_func = opt.use_func
110 |
111 | if opt.use_loc:
112 | loc_feat_size = opt.region_projected_size + opt.hidden_size * 2
113 | self.loc_fc = nn.Linear(loc_feat_size, opt.hidden_size)
114 | nn.init.xavier_normal_(self.loc_fc.weight)
115 |
116 | if opt.use_rel:
117 | rel_feat_size = 2 * (opt.region_projected_size + 2 * opt.hidden_size)
118 | self.rel_fc = nn.Linear(rel_feat_size, opt.hidden_size)
119 | nn.init.xavier_normal_(self.rel_fc.weight)
120 |
121 | if opt.use_func:
122 | func_size = opt.hidden_size
123 | self.func_fc = nn.Linear(func_size, opt.hidden_size)
124 | nn.init.xavier_normal_(self.func_fc.weight)
125 |
126 | if opt.use_loc and opt.use_rel and opt.use_func:
127 | if opt.attention == 'soft':
128 | self.module_attn = SoftAttention(opt.hidden_size, opt.hidden_size, opt.hidden_size)
129 | elif opt.attention == 'gumbel':
130 | self.module_attn = GumbelAttention(opt.hidden_size, opt.hidden_size, opt.hidden_size)
131 |
132 | def forward(self, loc_feats, rel_feats, func_feats, hidden_state):
133 | '''
134 | soft attention: Weighted sum of three features
135 | gumbel attention: Choose one of three features
136 | '''
137 | loc_feats = self.loc_fc(loc_feats) if self.use_loc else None
138 | rel_feats = self.rel_fc(rel_feats) if self.use_rel else None
139 | func_feats = self.func_fc(func_feats) if self.use_func else None
140 |
141 | if self.use_loc and self.use_rel and self.use_func:
142 | feats = torch.stack([loc_feats, rel_feats, func_feats], dim=1)
143 | feats, module_weight = self.module_attn(feats, hidden_state)
144 |
145 | elif self.use_loc and not self.use_rel:
146 | feats = loc_feats
147 | module_weight = torch.tensor([0.3, 0.3, 0.4]).cuda()
148 | elif self.use_rel and not self.use_loc:
149 | feats = rel_feats
150 | module_weight = torch.tensor([0.3, 0.3, 0.4]).cuda()
151 |
152 | return feats, module_weight
--------------------------------------------------------------------------------
/tvr/models/module_clip.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from: https://github.com/openai/CLIP/blob/main/clip/clip.py
3 | """
4 | from collections import OrderedDict
5 | from typing import Tuple, Union
6 |
7 | import hashlib
8 | import os
9 | import urllib
10 | import warnings
11 | from tqdm import tqdm
12 | from .module_transformer import Transformer as TransformerClip
13 | import torch
14 | import torch.nn.functional as F
15 | from torch import nn
16 |
17 | _MODELS = {
18 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
19 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
20 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
21 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
22 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
23 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
24 | }
25 | _PT_NAME = {
26 | "RN50": "RN50.pt",
27 | "RN101": "RN101.pt",
28 | "RN50x4": "RN50x4.pt",
29 | "RN50x16": "RN50x16.pt",
30 | "ViT-B/32": "ViT-B-32.pt",
31 | "ViT-B/16": "ViT-B-16.pt",
32 | "ViT-L/14": "ViT-L-14.pt",
33 | }
34 |
35 |
36 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
37 | os.makedirs(root, exist_ok=True)
38 | filename = os.path.basename(url)
39 |
40 | expected_sha256 = url.split("/")[-2]
41 | download_target = os.path.join(root, filename)
42 |
43 | if os.path.exists(download_target) and not os.path.isfile(download_target):
44 | raise RuntimeError(f"{download_target} exists and is not a regular file")
45 |
46 | if os.path.isfile(download_target):
47 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
48 | return download_target
49 | else:
50 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
51 |
52 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
53 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
54 | while True:
55 | buffer = source.read(8192)
56 | if not buffer:
57 | break
58 |
59 | output.write(buffer)
60 | loop.update(len(buffer))
61 |
62 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
63 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
64 |
65 | return download_target
66 |
67 |
68 | def available_models():
69 | """Returns the names of available CLIP models"""
70 | return list(_MODELS.keys())
71 |
72 |
73 | # =============================
74 |
75 | class Bottleneck(nn.Module):
76 | expansion = 4
77 |
78 | def __init__(self, inplanes, planes, stride=1):
79 | super(Bottleneck, self).__init__()
80 |
81 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
82 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
83 | self.bn1 = nn.BatchNorm2d(planes)
84 |
85 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
86 | self.bn2 = nn.BatchNorm2d(planes)
87 |
88 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
89 |
90 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
91 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
92 |
93 | self.relu = nn.ReLU(inplace=True)
94 | self.downsample = None
95 | self.stride = stride
96 |
97 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
98 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
99 | self.downsample = nn.Sequential(OrderedDict([
100 | ("-1", nn.AvgPool2d(stride)),
101 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
102 | ("1", nn.BatchNorm2d(planes * self.expansion))
103 | ]))
104 |
105 | def forward(self, x: torch.Tensor):
106 | identity = x
107 |
108 | out = self.relu(self.bn1(self.conv1(x)))
109 | out = self.relu(self.bn2(self.conv2(out)))
110 | out = self.avgpool(out)
111 | out = self.bn3(self.conv3(out))
112 |
113 | if self.downsample is not None:
114 | identity = self.downsample(x)
115 |
116 | out += identity
117 | out = self.relu(out)
118 | return out
119 |
120 |
121 | class AttentionPool2d(nn.Module):
122 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
123 | super(AttentionPool2d, self).__init__()
124 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
125 | self.k_proj = nn.Linear(embed_dim, embed_dim)
126 | self.q_proj = nn.Linear(embed_dim, embed_dim)
127 | self.v_proj = nn.Linear(embed_dim, embed_dim)
128 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
129 | self.num_heads = num_heads
130 |
131 | def forward(self, x):
132 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
133 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
134 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
135 | x, _ = F.multi_head_attention_forward(
136 | query=x, key=x, value=x,
137 | embed_dim_to_check=x.shape[-1],
138 | num_heads=self.num_heads,
139 | q_proj_weight=self.q_proj.weight,
140 | k_proj_weight=self.k_proj.weight,
141 | v_proj_weight=self.v_proj.weight,
142 | in_proj_weight=None,
143 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
144 | bias_k=None,
145 | bias_v=None,
146 | add_zero_attn=False,
147 | dropout_p=0,
148 | out_proj_weight=self.c_proj.weight,
149 | out_proj_bias=self.c_proj.bias,
150 | use_separate_proj_weight=True,
151 | training=self.training,
152 | need_weights=False
153 | )
154 |
155 | return x[0]
156 |
157 |
158 | class ModifiedResNet(nn.Module):
159 | """
160 | A ResNet class that is similar to torchvision's but contains the following changes:
161 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
162 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
163 | - The final pooling layer is a QKV attention instead of an average pool
164 | """
165 |
166 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
167 | super(ModifiedResNet, self).__init__()
168 | self.output_dim = output_dim
169 | self.input_resolution = input_resolution
170 |
171 | # the 3-layer stem
172 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
173 | self.bn1 = nn.BatchNorm2d(width // 2)
174 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
175 | self.bn2 = nn.BatchNorm2d(width // 2)
176 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
177 | self.bn3 = nn.BatchNorm2d(width)
178 | self.avgpool = nn.AvgPool2d(2)
179 | self.relu = nn.ReLU(inplace=True)
180 |
181 | # residual layers
182 | self._inplanes = width # this is a *mutable* variable used during construction
183 | self.layer1 = self._make_layer(width, layers[0])
184 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
185 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
186 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
187 |
188 | embed_dim = width * 32 # the ResNet feature dimension
189 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
190 |
191 | def _make_layer(self, planes, blocks, stride=1):
192 | layers = [Bottleneck(self._inplanes, planes, stride)]
193 |
194 | self._inplanes = planes * Bottleneck.expansion
195 | for _ in range(1, blocks):
196 | layers.append(Bottleneck(self._inplanes, planes))
197 |
198 | return nn.Sequential(*layers)
199 |
200 | def forward(self, x):
201 | def stem(x):
202 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
203 | x = self.relu(bn(conv(x)))
204 | x = self.avgpool(x)
205 | return x
206 |
207 | x = x.type(self.conv1.weight.dtype)
208 | x = stem(x)
209 | x = self.layer1(x)
210 | x = self.layer2(x)
211 | x = self.layer3(x)
212 | x = self.layer4(x)
213 | x = self.attnpool(x)
214 |
215 | return x
216 |
217 |
218 | class LayerNorm(nn.LayerNorm):
219 | """Subclass torch's LayerNorm to handle fp16."""
220 |
221 | def forward(self, x: torch.Tensor):
222 | orig_type = x.dtype
223 | ret = super().forward(x.type(torch.float32))
224 | return ret.type(orig_type)
225 |
226 |
227 | class QuickGELU(nn.Module):
228 | def forward(self, x: torch.Tensor):
229 | return x * torch.sigmoid(1.702 * x)
230 |
231 |
232 | class ResidualAttentionBlock(nn.Module):
233 | def __init__(self, d_model: int, n_head: int, attn_mask=None):
234 | super(ResidualAttentionBlock, self).__init__()
235 |
236 | self.attn = nn.MultiheadAttention(d_model, n_head)
237 | self.ln_1 = LayerNorm(d_model)
238 | self.mlp = nn.Sequential(OrderedDict([
239 | ("c_fc", nn.Linear(d_model, d_model * 4)),
240 | ("gelu", QuickGELU()),
241 | ("c_proj", nn.Linear(d_model * 4, d_model))
242 | ]))
243 | self.ln_2 = LayerNorm(d_model)
244 | self.attn_mask = attn_mask
245 |
246 | def attention(self, x: torch.Tensor):
247 | attn_mask_ = self.attn_mask
248 | if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'):
249 | attn_mask_ = self.attn_mask(x.size(0)) # LND
250 |
251 | attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None
252 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]
253 |
254 | def forward(self, x):
255 | x = x + self.attention(self.ln_1(x))
256 | x = x + self.mlp(self.ln_2(x))
257 | return x
258 |
259 |
260 | class Transformer(nn.Module):
261 | def __init__(self, width: int, layers: int, heads: int, attn_mask=None):
262 | super(Transformer, self).__init__()
263 | self.width = width
264 | self.layers = layers
265 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
266 |
267 | def forward(self, x: torch.Tensor):
268 | return self.resblocks(x)
269 |
270 |
271 | class VisualTransformer(nn.Module):
272 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
273 | super(VisualTransformer, self).__init__()
274 | self.input_resolution = input_resolution
275 | self.output_dim = output_dim
276 |
277 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
278 |
279 | scale = width ** -0.5
280 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
281 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
282 | self.ln_pre = LayerNorm(width)
283 |
284 | self.transformer = Transformer(width, layers, heads)
285 |
286 | self.ln_post = LayerNorm(width)
287 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
288 |
289 | for param in self.conv1.parameters():
290 | param.requires_grad = False # not update by gradient
291 |
292 | # self.class_embedding.requires_grad = False # not update by gradient
293 | # self.positional_embedding.requires_grad = False # not update by gradient
294 |
295 |
296 | def forward(self, x: torch.Tensor, mask=None):
297 |
298 | x = self.conv1(x) # shape = [*, width, grid, grid]
299 |
300 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
301 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
302 | x = torch.cat(
303 | [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
304 | x], dim=1) # shape = [*, grid ** 2 + 1, width]
305 |
306 | x = x + self.positional_embedding.to(x.dtype)
307 | x = self.ln_pre(x)
308 |
309 | # zero = torch.zeros((x.size(1), x.size(1))).repeat(x.size(0), 1, 1).to(x.device)
310 | # inf = torch.zeros((x.size(1), x.size(1))).fill_(float("-inf")).repeat(x.size(0), 1, 1).to(mask.device)
311 | # _mask = mask.view(-1).unsqueeze(1).unsqueeze(1).expand(-1, x.size(1), x.size(1))
312 | # attn_mask = torch.where(_mask>0, zero, inf)
313 | # attn_mask[:, 0, 0] = 0
314 |
315 | x = x.permute(1, 0, 2) # NLD -> LND
316 | x = self.transformer(x)
317 | x = x.permute(1, 0, 2) # LND -> NLD
318 |
319 | # mask = mask.view(-1).unsqueeze(1).unsqueeze(1).expand(-1, x.size(1), x.size(2))
320 | # zero = torch.zeros((x.size(0), x.size(1), x.size(2))).to(x.device).type(x.dtype)
321 | # x = torch.where(mask>0, x, zero)
322 | # Move the three lines below to `encode_image` for entire hidden sequence
323 | # x = self.ln_post(x[:, 0, :])
324 | # if self.proj is not None:
325 | # x = x @ self.proj
326 |
327 | return x
328 |
329 |
330 | class CLIP(nn.Module):
331 | def __init__(self,
332 | embed_dim: int,
333 | # vision
334 | image_resolution: int,
335 | vision_layers: Union[Tuple[int, int, int, int], int],
336 | vision_width: int,
337 | vision_patch_size: int,
338 | # text
339 | context_length: int,
340 | vocab_size: int,
341 | transformer_width: int,
342 | transformer_heads: int,
343 | transformer_layers: int
344 | ):
345 | super(CLIP, self).__init__()
346 |
347 | self.context_length = context_length
348 |
349 | if isinstance(vision_layers, (tuple, list)):
350 | vision_heads = vision_width * 32 // 64
351 | self.visual = ModifiedResNet(
352 | layers=vision_layers,
353 | output_dim=embed_dim,
354 | heads=vision_heads,
355 | input_resolution=image_resolution,
356 | width=vision_width
357 | )
358 | else:
359 | vision_heads = vision_width // 64
360 | self.visual = VisualTransformer(
361 | input_resolution=image_resolution,
362 | patch_size=vision_patch_size,
363 | width=vision_width,
364 | layers=vision_layers,
365 | heads=vision_heads,
366 | output_dim=embed_dim,
367 | )
368 |
369 | self.transformer = TransformerClip(
370 | width=transformer_width,
371 | layers=transformer_layers,
372 | heads=transformer_heads,
373 | # attn_mask=self.build_attention_mask
374 | )
375 |
376 | self.vocab_size = vocab_size
377 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
378 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
379 | self.ln_final = LayerNorm(transformer_width)
380 |
381 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
382 | self.logit_scale = nn.Parameter(torch.ones([]))
383 |
384 | self.initialize_parameters()
385 |
386 | # for param in self.transformer.parameters():
387 | # param.requires_grad = False # not update by gradient
388 | self.token_embedding.requires_grad = False # not update by gradient
389 | # self.positional_embedding.requires_grad = False # not update by gradient
390 |
391 | def initialize_parameters(self):
392 | nn.init.normal_(self.token_embedding.weight, std=0.02)
393 | nn.init.normal_(self.positional_embedding, std=0.01)
394 |
395 | if isinstance(self.visual, ModifiedResNet):
396 | if self.visual.attnpool is not None:
397 | std = self.visual.attnpool.c_proj.in_features ** -0.5
398 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
399 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
400 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
401 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
402 |
403 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
404 | for name, param in resnet_block.named_parameters():
405 | if name.endswith("bn3.weight"):
406 | nn.init.zeros_(param)
407 |
408 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
409 | attn_std = self.transformer.width ** -0.5
410 | fc_std = (2 * self.transformer.width) ** -0.5
411 | for block in self.transformer.resblocks:
412 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
413 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
414 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
415 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
416 |
417 | if self.text_projection is not None:
418 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
419 |
420 | @staticmethod
421 | def get_config(pretrained_clip_name="ViT-B/32"):
422 | model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViT-B-32.pt")
423 | if pretrained_clip_name in _MODELS and pretrained_clip_name in _PT_NAME:
424 | model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), _PT_NAME[pretrained_clip_name])
425 |
426 | if pretrained_clip_name in ["ViT-B/32", "ViT-B/16"] and os.path.exists(model_path):
427 | pass
428 | else:
429 | if pretrained_clip_name in _MODELS:
430 | model_path = _download(_MODELS[pretrained_clip_name])
431 | elif os.path.isfile(pretrained_clip_name):
432 | model_path = pretrained_clip_name
433 | else:
434 | raise RuntimeError(f"Model {pretrained_clip_name} not found; available models = {available_models()}")
435 |
436 | try:
437 | # loading JIT archive
438 | model = torch.jit.load(model_path, map_location="cpu").eval()
439 | state_dict = model.state_dict()
440 | except RuntimeError:
441 | state_dict = torch.load(model_path, map_location="cpu")
442 |
443 | return state_dict
444 |
445 | def build_attention_mask(self, context_length):
446 | # lazily create causal attention mask, with full attention between the vision tokens
447 | # pytorch uses additive attention mask; fill with -inf
448 | mask = torch.zeros(context_length, context_length)
449 | mask.fill_(float("-inf"))
450 | mask.triu_(1) # zero out the lower diagonal
451 | return mask
452 |
453 | @property
454 | def dtype(self):
455 | return self.visual.conv1.weight.dtype
456 |
457 | def encode_image(self, image, return_hidden=False, mask=None):
458 | hidden = self.visual(image.type(self.dtype))
459 | hidden = self.visual.ln_post(hidden) @ self.visual.proj
460 |
461 | x = hidden[:, 0, :]
462 |
463 | if return_hidden:
464 | return x, hidden
465 |
466 | return x
467 |
468 | def encode_text(self, text, return_hidden=False, mask=None):
469 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
470 |
471 | pos_emd = self.positional_embedding[:x.size(1), :].type(self.dtype)
472 |
473 | attn_mask = self.build_attention_mask(x.size(1)).repeat(x.size(0), 1, 1).to(mask.device)
474 | inf = torch.zeros((x.size(1), x.size(1))).fill_(float("-inf")).repeat(x.size(0), 1, 1).to(mask.device)
475 | mask = mask.unsqueeze(1).expand(-1, mask.size(1), -1)
476 | attn_mask = torch.where(mask>0, attn_mask, inf)
477 |
478 | x = x + pos_emd
479 | x = x.permute(1, 0, 2) # NLD -> LND
480 | x = self.transformer(x, attn_mask)
481 | x = x.permute(1, 0, 2) # LND -> NLD
482 |
483 | hidden = self.ln_final(x).type(self.dtype) @ self.text_projection
484 |
485 | # x.shape = [batch_size, n_ctx, transformer.width]
486 | # take features from the eot embedding (eot_token is the highest number in each sequence)
487 | x = hidden[torch.arange(hidden.shape[0]), text.argmax(dim=-1)]
488 |
489 | if return_hidden:
490 | return x, hidden
491 |
492 | return x
493 |
494 | def forward(self, image, text):
495 | image_features = self.encode_image(image)
496 | text_features = self.encode_text(text)
497 |
498 | # normalized features
499 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
500 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
501 |
502 | # cosine similarity as logits
503 | logit_scale = self.logit_scale.exp()
504 | logits_per_image = logit_scale * image_features @ text_features.t()
505 | logits_per_text = logit_scale * text_features @ image_features.t()
506 |
507 | # shape = [global_batch_size, global_batch_size]
508 | return logits_per_image, logits_per_text
509 |
510 |
511 | def convert_weights(model: nn.Module):
512 | """Convert applicable model parameters to fp16"""
513 |
514 | def _convert_weights_to_fp16(l):
515 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
516 | l.weight.data = l.weight.data.half()
517 | if l.bias is not None:
518 | l.bias.data = l.bias.data.half()
519 |
520 | if isinstance(l, nn.MultiheadAttention):
521 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
522 | tensor = getattr(l, attr)
523 | if tensor is not None:
524 | tensor.data = tensor.data.half()
525 |
526 | for name in ["text_projection", "proj"]:
527 | if hasattr(l, name):
528 | attr = getattr(l, name)
529 | if attr is not None:
530 | attr.data = attr.data.half()
531 |
532 | model.apply(_convert_weights_to_fp16)
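533 |
534 |
535 | if __name__ == "__main__":
536 |     # Minimal usage sketch (run as a module, e.g. `python -m tvr.models.module_clip`,
537 |     # so the relative imports resolve): list the released CLIP checkpoints and load
538 |     # the ViT-B/32 state dict, downloading it to ~/.cache/clip if no local copy exists.
539 |     print(available_models())
540 |     state_dict = CLIP.get_config(pretrained_clip_name="ViT-B/32")
541 |     print(state_dict["text_projection"].shape)  # torch.Size([512, 512]) for ViT-B/32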
--------------------------------------------------------------------------------
/tvr/models/module_cross.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import logging
6 | from timm.models.layers import drop_path
7 | import torch
8 | from torch import nn
9 | from .until_module import LayerNorm, ACT2FN
10 | from collections import OrderedDict
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 | PRETRAINED_MODEL_ARCHIVE_MAP = {}
15 | CONFIG_NAME = 'cross_config.json'
16 | WEIGHTS_NAME = 'cross_pytorch_model.bin'
17 |
18 |
19 | class QuickGELU(nn.Module):
20 | def forward(self, x: torch.Tensor):
21 | return x * torch.sigmoid(1.702 * x)
22 |
23 |
24 | class DropPath(nn.Module):
25 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
26 | """
27 |
28 | def __init__(self, drop_prob=None):
29 | super(DropPath, self).__init__()
30 | self.drop_prob = drop_prob
31 |
32 | def forward(self, x):
33 | return drop_path(x, self.drop_prob, self.training)
34 |
35 | def extra_repr(self) -> str:
36 | return 'p={}'.format(self.drop_prob)
37 |
38 |
39 | class ResidualAttentionBlock(nn.Module):
40 | def __init__(self, d_model: int, n_head: int, drop_path=0.0):
41 | super().__init__()
42 |
43 | self.attn = nn.MultiheadAttention(d_model, n_head)
44 | self.ln_1 = LayerNorm(d_model)
45 | self.mlp = nn.Sequential(OrderedDict([
46 | ("c_fc", nn.Linear(d_model, d_model * 4)),
47 | ("gelu", QuickGELU()),
48 | ("c_proj", nn.Linear(d_model * 4, d_model))
49 | ]))
50 | self.ln_2 = LayerNorm(d_model)
51 | self.n_head = n_head
52 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
53 |
54 | def attention(self, x: torch.Tensor, attn_mask: torch.Tensor):
55 | attn_mask_ = attn_mask.repeat_interleave(self.n_head, dim=0)
56 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]
57 |
58 | def forward(self, para_tuple: tuple):
59 | # x: torch.Tensor, attn_mask: torch.Tensor
60 | # print(para_tuple)
61 | x, attn_mask = para_tuple
62 | if self.training:
63 | x = x + self.drop_path(self.attention(self.ln_1(x), attn_mask))
64 | x = x + self.drop_path(self.mlp(self.ln_2(x)))
65 | else:
66 | x = x + self.attention(self.ln_1(x), attn_mask)
67 | x = x + self.mlp(self.ln_2(x))
68 | return (x, attn_mask)
69 |
70 |
71 | class Transformer(nn.Module):
72 | def __init__(self, width: int, layers: int, heads: int):
73 | super().__init__()
74 | self.width = width
75 | self.layers = layers
76 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads) for _ in range(layers)])
77 |
78 | def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
79 | return self.resblocks((x, attn_mask))[0]
80 |
81 |
82 | class CrossEmbeddings(nn.Module):
83 | """Construct the embeddings from word, position and token_type embeddings.
84 | """
85 |
86 | def __init__(self, config):
87 | super(CrossEmbeddings, self).__init__()
88 |
89 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
90 | # self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
91 | # self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
92 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
93 |
94 | def forward(self, concat_embeddings, concat_type=None):
95 | batch_size, seq_length = concat_embeddings.size(0), concat_embeddings.size(1)
96 | # if concat_type is None:
97 | # concat_type = torch.zeros(batch_size, concat_type).to(concat_embeddings.device)
98 |
99 | position_ids = torch.arange(seq_length, dtype=torch.long, device=concat_embeddings.device)
100 | position_ids = position_ids.unsqueeze(0).expand(concat_embeddings.size(0), -1)
101 |
102 | # token_type_embeddings = self.token_type_embeddings(concat_type)
103 | position_embeddings = self.position_embeddings(position_ids)
104 |
105 | embeddings = concat_embeddings + position_embeddings # + token_type_embeddings
106 | # embeddings = self.LayerNorm(embeddings)
107 | embeddings = self.dropout(embeddings)
108 | return embeddings
109 |
110 |
111 | class CrossPooler(nn.Module):
112 | def __init__(self, config):
113 | super(CrossPooler, self).__init__()
114 | self.ln_pool = LayerNorm(config.hidden_size)
115 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
116 | self.activation = QuickGELU()
117 |
118 | def forward(self, hidden_states, hidden_mask):
119 | # We "pool" the model by simply taking the hidden state corresponding
120 | # to the first token.
121 | hidden_states = self.ln_pool(hidden_states)
122 | pooled_output = hidden_states[:, 0]
123 | pooled_output = self.dense(pooled_output)
124 | pooled_output = self.activation(pooled_output)
125 | return pooled_output
126 |
127 |
128 | class CrossModel(nn.Module):
129 |
130 | def initialize_parameters(self):
131 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
132 | attn_std = self.transformer.width ** -0.5
133 | fc_std = (2 * self.transformer.width) ** -0.5
134 | for block in self.transformer.resblocks:
135 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
136 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
137 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
138 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
139 |
140 | def __init__(self, config):
141 | super(CrossModel, self).__init__()
142 | self.config = config
143 |
144 | self.embeddings = CrossEmbeddings(config)
145 |
146 | transformer_width = config.hidden_size
147 | transformer_layers = config.num_hidden_layers
148 | transformer_heads = config.num_attention_heads
149 | self.transformer = Transformer(width=transformer_width, layers=transformer_layers, heads=transformer_heads, )
150 | self.pooler = CrossPooler(config)
151 | self.apply(self.init_weights)
152 |
153 | def build_attention_mask(self, attention_mask):
154 | extended_attention_mask = attention_mask.unsqueeze(1)
155 | extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
156 | extended_attention_mask = (1.0 - extended_attention_mask) * -1000000.0
157 | extended_attention_mask = extended_attention_mask.expand(-1, attention_mask.size(1), -1)
158 | return extended_attention_mask
159 |
160 | def forward(self, concat_input, concat_type=None, attention_mask=None, output_all_encoded_layers=True):
161 |
162 | if attention_mask is None:
163 | attention_mask = torch.ones(concat_input.size(0), concat_input.size(1))
164 | if concat_type is None:
165 | concat_type = torch.zeros_like(attention_mask)
166 |
167 | extended_attention_mask = self.build_attention_mask(attention_mask)
168 |
169 | embedding_output = self.embeddings(concat_input, concat_type)
170 | embedding_output = embedding_output.permute(1, 0, 2) # NLD -> LND
171 | embedding_output = self.transformer(embedding_output, extended_attention_mask)
172 | embedding_output = embedding_output.permute(1, 0, 2) # LND -> NLD
173 |
174 | pooled_output = self.pooler(embedding_output, hidden_mask=attention_mask)
175 |
176 | return embedding_output, pooled_output
177 |
178 | @property
179 | def dtype(self):
180 | """
181 | :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
182 | """
183 | try:
184 | return next(self.parameters()).dtype
185 | except StopIteration:
186 | # For nn.DataParallel compatibility in PyTorch 1.5
187 | def find_tensor_attributes(module: nn.Module):
188 | tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
189 | return tuples
190 |
191 | gen = self._named_members(get_members_fn=find_tensor_attributes)
192 | first_tuple = next(gen)
193 | return first_tuple[1].dtype
194 |
195 | def init_weights(self, module):
196 | """ Initialize the weights.
197 | """
198 | if isinstance(module, (nn.Linear, nn.Embedding)):
199 | # Slightly different from the TF version which uses truncated_normal for initialization
200 | # cf https://github.com/pytorch/pytorch/pull/5617
201 | module.weight.data.normal_(mean=0.0, std=0.02)
202 | elif isinstance(module, LayerNorm):
203 | if 'beta' in dir(module) and 'gamma' in dir(module):
204 | module.beta.data.zero_()
205 | module.gamma.data.fill_(1.0)
206 | else:
207 | module.bias.data.zero_()
208 | module.weight.data.fill_(1.0)
209 | if isinstance(module, nn.Linear) and module.bias is not None:
210 | module.bias.data.zero_()
211 |
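212 |
213 | if __name__ == "__main__":
214 |     # Minimal usage sketch (run as a module, e.g. `python -m tvr.models.module_cross`,
215 |     # so the relative imports resolve). The config fields are exactly the attributes
216 |     # this module reads; the values are illustrative placeholders, not the ones used
217 |     # by the training scripts.
218 |     from types import SimpleNamespace
219 |
220 |     cfg = SimpleNamespace(hidden_size=512, max_position_embeddings=128,
221 |                           hidden_dropout_prob=0.1, num_hidden_layers=2,
222 |                           num_attention_heads=8)
223 |     model = CrossModel(cfg)
224 |     concat_input = torch.randn(2, 16, 512)  # e.g. concatenated text/video features
225 |     attention_mask = torch.ones(2, 16)
226 |     seq_out, pooled = model(concat_input, attention_mask=attention_mask)
227 |     print(seq_out.shape, pooled.shape)  # torch.Size([2, 16, 512]) torch.Size([2, 512])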
--------------------------------------------------------------------------------
/tvr/models/module_transformer.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import logging
6 | from timm.models.layers import drop_path
7 | import torch
8 | from torch import nn
9 | from .until_module import LayerNorm, ACT2FN
10 | from collections import OrderedDict
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 | PRETRAINED_MODEL_ARCHIVE_MAP = {}
15 | CONFIG_NAME = 'cross_config.json'
16 | WEIGHTS_NAME = 'cross_pytorch_model.bin'
17 |
18 |
19 | class LayerNorm(nn.LayerNorm):
20 | """Subclass torch's LayerNorm to handle fp16."""
21 |
22 | def forward(self, x: torch.Tensor):
23 | orig_type = x.dtype
24 | ret = super().forward(x.type(torch.float32))
25 | return ret.type(orig_type)
26 |
27 |
28 | class QuickGELU(nn.Module):
29 | def forward(self, x: torch.Tensor):
30 | return x * torch.sigmoid(1.702 * x)
31 |
32 |
33 | class ResidualAttentionBlock(nn.Module):
34 | def __init__(self, d_model: int, n_head: int, attn_mask=None):
35 | super(ResidualAttentionBlock, self).__init__()
36 |
37 | self.attn = nn.MultiheadAttention(d_model, n_head)
38 | self.ln_1 = LayerNorm(d_model)
39 | self.mlp = nn.Sequential(OrderedDict([
40 | ("c_fc", nn.Linear(d_model, d_model * 4)),
41 | ("gelu", QuickGELU()),
42 | ("c_proj", nn.Linear(d_model * 4, d_model))
43 | ]))
44 | self.ln_2 = LayerNorm(d_model)
45 | self.attn_mask = attn_mask
46 | self.n_head = n_head
47 |
48 | def attention(self, x: torch.Tensor, attn_mask_: torch.Tensor):
49 | attn_mask_ = attn_mask_.repeat_interleave(self.n_head, dim=0)
50 | attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None
51 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]
52 |
53 | def forward(self, para_tuple: tuple):
54 | x, attn_mask = para_tuple
55 | x = x + self.attention(self.ln_1(x), attn_mask)
56 | x = x + self.mlp(self.ln_2(x))
57 | return (x, attn_mask)
58 |
59 |
60 | class Transformer(nn.Module):
61 | def __init__(self, width: int, layers: int, heads: int, attn_mask=None):
62 | super(Transformer, self).__init__()
63 | self.width = width
64 | self.layers = layers
65 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads) for _ in range(layers)])
66 |
67 | def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
68 | return self.resblocks((x, attn_mask))[0]
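69 |
70 |
71 | if __name__ == "__main__":
72 |     # Minimal usage sketch (run as a module, e.g. `python -m tvr.models.module_transformer`,
73 |     # so the relative import resolves). Inputs are LND-ordered; the additive attention
74 |     # mask is given per sample and repeated across heads inside each block.
75 |     tx = Transformer(width=512, layers=2, heads=8)
76 |     x = torch.randn(16, 4, 512)  # (seq_len, batch, width)
77 |     attn_mask = torch.zeros(4, 16, 16)  # all zeros = attend everywhere; use -inf to mask
78 |     y = tx(x, attn_mask)
79 |     print(y.shape)  # torch.Size([16, 4, 512])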
--------------------------------------------------------------------------------
/tvr/models/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for BERT model."""
16 |
17 | import math
18 | import torch
19 | from torch.optim import Optimizer
20 | from torch.optim.optimizer import required
21 | from torch.nn.utils import clip_grad_norm_
22 | import logging
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | def warmup_cosine(x, warmup=0.002):
27 | if x < warmup:
28 | return x/warmup
29 | return 0.5 * (1.0 + math.cos(math.pi * x))
30 |
31 | def warmup_constant(x, warmup=0.002):
32 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
33 | Learning rate is 1. afterwards. """
34 | if x < warmup:
35 | return x/warmup
36 | return 1.0
37 |
38 | def warmup_linear(x, warmup=0.002):
39 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
40 | After `t_total`-th training step, learning rate is zero. """
41 | if x < warmup:
42 | return x/warmup
43 | return max((x-1.)/(warmup-1.), 0)
44 |
45 | SCHEDULES = {
46 | 'warmup_cosine': warmup_cosine,
47 | 'warmup_constant': warmup_constant,
48 | 'warmup_linear': warmup_linear,
49 | }
50 |
51 |
52 | class BertAdam(Optimizer):
53 | """Implements BERT version of Adam algorithm with weight decay fix.
54 | Params:
55 | lr: learning rate
56 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
57 | t_total: total number of training steps for the learning
58 | rate schedule, -1 means constant learning rate. Default: -1
59 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
60 |         b1: Adam's b1. Default: 0.9
61 |         b2: Adam's b2. Default: 0.999
62 |         e: Adam's epsilon. Default: 1e-6
63 | weight_decay: Weight decay. Default: 0.01
64 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
65 | """
66 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
67 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
68 | max_grad_norm=1.0):
69 | if lr is not required and lr < 0.0:
70 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
71 | if schedule not in SCHEDULES:
72 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
73 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
74 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
75 | if not 0.0 <= b1 < 1.0:
76 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
77 | if not 0.0 <= b2 < 1.0:
78 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
79 | if not e >= 0.0:
80 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
81 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
82 | b1=b1, b2=b2, e=e, weight_decay=weight_decay,
83 | max_grad_norm=max_grad_norm)
84 | super(BertAdam, self).__init__(params, defaults)
85 |
86 | def get_lr(self):
87 | lr = []
88 | for group in self.param_groups:
89 | for p in group['params']:
90 | if p.grad is None:
91 | continue
92 | state = self.state[p]
93 | if len(state) == 0:
94 | return [0]
95 | if group['t_total'] != -1:
96 | schedule_fct = SCHEDULES[group['schedule']]
97 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
98 | else:
99 | lr_scheduled = group['lr']
100 | lr.append(lr_scheduled)
101 | return lr
102 |
103 | def step(self, closure=None):
104 | """Performs a single optimization step.
105 | Arguments:
106 | closure (callable, optional): A closure that reevaluates the model
107 | and returns the loss.
108 | """
109 | loss = None
110 | if closure is not None:
111 | loss = closure()
112 |
113 | for group in self.param_groups:
114 | for p in group['params']:
115 | if p.grad is None:
116 | continue
117 | grad = p.grad.data
118 | if grad.is_sparse:
119 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
120 |
121 | state = self.state[p]
122 |
123 | # State initialization
124 | if len(state) == 0:
125 | state['step'] = 0
126 | # Exponential moving average of gradient values
127 | state['next_m'] = torch.zeros_like(p.data)
128 | # Exponential moving average of squared gradient values
129 | state['next_v'] = torch.zeros_like(p.data)
130 |
131 | next_m, next_v = state['next_m'], state['next_v']
132 | beta1, beta2 = group['b1'], group['b2']
133 |
134 | # Add grad clipping
135 | if group['max_grad_norm'] > 0:
136 | clip_grad_norm_(p, group['max_grad_norm'])
137 |
138 | # Decay the first and second moment running average coefficient
139 | # In-place operations to update the averages at the same time
140 | # next_m.mul_(beta1).add_(1 - beta1, grad) --> pytorch 1.7
141 | next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
142 | # next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) --> pytorch 1.7
143 | next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
144 | update = next_m / (next_v.sqrt() + group['e'])
145 |
146 | # Just adding the square of the weights to the loss function is *not*
147 | # the correct way of using L2 regularization/weight decay with Adam,
148 | # since that will interact with the m and v parameters in strange ways.
149 | #
150 | # Instead we want to decay the weights in a manner that doesn't interact
151 | # with the m/v parameters. This is equivalent to adding the square
152 | # of the weights to the loss with plain (non-momentum) SGD.
153 | if group['weight_decay'] > 0.0:
154 | update += group['weight_decay'] * p.data
155 |
156 | if group['t_total'] != -1:
157 | schedule_fct = SCHEDULES[group['schedule']]
158 | progress = state['step']/group['t_total']
159 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
160 | else:
161 | lr_scheduled = group['lr']
162 |
163 | update_with_lr = lr_scheduled * update
164 | p.data.add_(-update_with_lr)
165 |
166 | state['step'] += 1
167 |
168 | return loss
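169 |
170 |
171 | if __name__ == "__main__":
172 |     # Minimal usage sketch (illustrative only): one BertAdam step on a toy
173 |     # module, using the linear-warmup schedule documented in the class docstring.
174 |     model = torch.nn.Linear(4, 2)
175 |     optimizer = BertAdam(model.parameters(), lr=1e-4, warmup=0.1, t_total=1000,
176 |                          schedule='warmup_linear', weight_decay=0.01)
177 |     loss = model(torch.randn(3, 4)).sum()
178 |     loss.backward()
179 |     optimizer.step()
180 |     optimizer.zero_grad()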
--------------------------------------------------------------------------------
/tvr/models/tokenization_clip.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 |     This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 |         vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | self.vocab = self.encoder
81 |
82 | def bpe(self, token):
83 | if token in self.cache:
84 | return self.cache[token]
85 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
86 | pairs = get_pairs(word)
87 |
88 | if not pairs:
89 |             return token+'</w>'
90 |
91 | while True:
92 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
93 | if bigram not in self.bpe_ranks:
94 | break
95 | first, second = bigram
96 | new_word = []
97 | i = 0
98 | while i < len(word):
99 | try:
100 | j = word.index(first, i)
101 | new_word.extend(word[i:j])
102 | i = j
103 | except:
104 | new_word.extend(word[i:])
105 | break
106 |
107 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
108 | new_word.append(first+second)
109 | i += 2
110 | else:
111 | new_word.append(word[i])
112 | i += 1
113 | new_word = tuple(new_word)
114 | word = new_word
115 | if len(word) == 1:
116 | break
117 | else:
118 | pairs = get_pairs(word)
119 | word = ' '.join(word)
120 | self.cache[token] = word
121 | return word
122 |
123 | def encode(self, text):
124 | bpe_tokens = []
125 | text = whitespace_clean(basic_clean(text)).lower()
126 | for token in re.findall(self.pat, text):
127 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
128 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
129 | return bpe_tokens
130 |
131 | def decode(self, tokens):
132 | text = ''.join([self.decoder[token] for token in tokens])
133 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
134 | return text
135 |
136 | def tokenize(self, text):
137 | tokens = []
138 | text = whitespace_clean(basic_clean(text)).lower()
139 | for token in re.findall(self.pat, text):
140 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
141 | tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
142 | return tokens
143 |
144 | def convert_tokens_to_ids(self, tokens):
145 | return [self.encoder[bpe_token] for bpe_token in tokens]
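146 |
147 |
148 | if __name__ == "__main__":
149 |     # Minimal usage sketch (illustrative only): round-trip a caption through the
150 |     # BPE tokenizer; the ids depend on the bundled bpe_simple_vocab_16e6.txt.gz.
151 |     tokenizer = SimpleTokenizer()
152 |     ids = tokenizer.encode("a man is playing guitar")
153 |     print(ids)
154 |     print(tokenizer.decode(ids))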
--------------------------------------------------------------------------------
/tvr/models/transformer_block.py:
--------------------------------------------------------------------------------
1 | """
2 | Based on the implementation of https://github.com/jadore801120/attention-is-all-you-need-pytorch
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import numpy as np
8 |
9 | class ScaledDotProductAttention(nn.Module):
10 | ''' Scaled Dot-Product Attention '''
11 | def __init__(self, temperature, attn_dropout=0.1):
12 | super().__init__()
13 | self.temperature = temperature
14 | self.dropout = nn.Dropout(attn_dropout)
15 | self.softmax = nn.Softmax(dim=2)
16 |
17 | def forward(self, q, k, v, mask=None):
18 | """
19 | Args:
20 | q (bsz, len_q, dim_q)
21 | k (bsz, len_k, dim_k)
22 | v (bsz, len_v, dim_v)
23 | Note: len_k==len_v, and dim_q==dim_k
24 | Returns:
25 | output (bsz, len_q, dim_v)
26 | attn (bsz, len_q, len_k)
27 | """
28 | attn = torch.bmm(q, k.transpose(1, 2))
29 | attn = attn / self.temperature
30 |
31 | if mask is not None:
32 | attn = attn.masked_fill(mask, -np.inf)
33 |
34 | attn = self.softmax(attn)
35 | attn = self.dropout(attn)
36 | output = torch.bmm(attn, v)
37 |
38 | return output, attn
39 |
40 |
41 | class MultiHeadAttention(nn.Module):
42 | ''' Multi-Head Attention module '''
43 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
44 | super().__init__()
45 | self.n_head = n_head
46 | self.d_k = d_k
47 | self.d_v = d_v
48 |
49 | self.w_qs = nn.Linear(d_model, n_head * d_k)
50 | self.w_ks = nn.Linear(d_model, n_head * d_k)
51 | self.w_vs = nn.Linear(d_model, n_head * d_v)
52 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
53 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
54 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
55 |
56 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
57 | self.layer_norm = nn.LayerNorm(d_model)
58 |
59 | self.fc = nn.Linear(n_head * d_v, d_model)
60 | nn.init.xavier_normal_(self.fc.weight)
61 |
62 | self.dropout = nn.Dropout(dropout)
63 |
64 |
65 | def forward(self, q, k, v, mask=None):
66 | """
67 | Args:
68 | q (bsz, len_q, dim_q)
69 | k (bsz, len_k, dim_k)
70 | v (bsz, len_v, dim_v)
71 | Note: len_k==len_v, and dim_q==dim_k
72 | Returns:
73 | output (bsz, len_q, d_model)
74 | attn (bsz, len_q, len_k)
75 | """
76 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
77 |
78 | sz_b, len_q, _ = q.size()
79 | sz_b, len_k, _ = k.size()
80 | sz_b, len_v, _ = v.size() # len_k==len_v
81 |
82 | residual = q
83 |
84 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
85 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
86 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
87 |
88 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
89 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
90 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
91 |
92 | if mask is not None:
93 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
94 | output, attn = self.attention(q, k, v, mask=mask)
95 |
96 | output = output.view(n_head, sz_b, len_q, d_v)
97 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)
98 |
99 | output = self.dropout(self.fc(output))
100 | output = self.layer_norm(output + residual)
101 |
102 | return output
103 |
104 |
105 | class PositionwiseFeedForward(nn.Module):
106 | ''' A two-feed-forward-layer module '''
107 | def __init__(self, d_in, d_hid, dropout=0.1):
108 | super().__init__()
109 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise
110 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise
111 | self.layer_norm = nn.LayerNorm(d_in)
112 | self.dropout = nn.Dropout(dropout)
113 |
114 | def forward(self, x):
115 | """
116 | Merge adjacent information. Equal to linear layer if kernel size is 1
117 | Args:
118 | x (bsz, len, dim)
119 | Returns:
120 | output (bsz, len, dim)
121 | """
122 | residual = x
123 | output = x.transpose(1, 2)
124 | output = self.w_2(F.relu(self.w_1(output)))
125 | output = output.transpose(1, 2)
126 | output = self.dropout(output)
127 | output = self.layer_norm(output + residual)
128 | return output
129 |
130 |
131 | class EncoderLayer(nn.Module):
132 | ''' Compose with two layers '''
133 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
134 | super(EncoderLayer, self).__init__()
135 | self.slf_attn = MultiHeadAttention(
136 | n_head, d_model, d_k, d_v, dropout=dropout)
137 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
138 |
139 | def forward(self, Q, K, V, non_pad_mask=None, slf_attn_mask=None):
140 | enc_output = self.slf_attn(
141 | Q, K, V, mask=slf_attn_mask)
142 | # enc_output *= non_pad_mask.float() if non_pad_mask is not None else 1.
143 |
144 | enc_output = self.pos_ffn(enc_output)
145 | # enc_output *= non_pad_mask.float()
146 |
147 | return enc_output
148 |
149 |
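150 | if __name__ == "__main__":
151 |     # Minimal usage sketch (illustrative only): one EncoderLayer call with the
152 |     # (bsz, len, d_model) shapes documented in the docstrings above.
153 |     bsz, seq_len, d_model, n_head = 2, 5, 512, 8
154 |     layer = EncoderLayer(d_model=d_model, d_inner=2048, n_head=n_head,
155 |                          d_k=d_model // n_head, d_v=d_model // n_head)
156 |     x = torch.randn(bsz, seq_len, d_model)
157 |     print(layer(x, x, x).shape)  # torch.Size([2, 5, 512])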
--------------------------------------------------------------------------------
/tvr/models/until_module.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """PyTorch BERT model."""
17 |
18 | import logging
19 | import torch
20 | from torch import nn
21 | import torch.nn.functional as F
22 | import math
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | def gelu(x):
28 | """Implementation of the gelu activation function.
29 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
30 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
31 | """
32 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
33 |
34 |
35 | def swish(x):
36 | return x * torch.sigmoid(x)
37 |
38 |
39 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
40 |
41 |
42 | class LayerNorm(nn.Module):
43 | def __init__(self, hidden_size, eps=1e-12):
44 | """Construct a layernorm module in the TF style (epsilon inside the square root).
45 | """
46 | super(LayerNorm, self).__init__()
47 | self.weight = nn.Parameter(torch.ones(hidden_size))
48 | self.bias = nn.Parameter(torch.zeros(hidden_size))
49 | self.variance_epsilon = eps
50 |
51 | def forward(self, x):
52 | u = x.mean(-1, keepdim=True)
53 | s = (x - u).pow(2).mean(-1, keepdim=True)
54 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
55 | return self.weight * x + self.bias
56 |
57 |
58 | ##################################
59 | ###### LOSS FUNCTION #############
60 | ##################################
61 | class CrossEn(nn.Module):
62 | def __init__(self, config=None):
63 | super(CrossEn, self).__init__()
64 |
65 | def forward(self, sim_matrix):
66 | logpt = F.log_softmax(sim_matrix, dim=-1)
67 | logpt = torch.diag(logpt)
68 | nce_loss = -logpt
69 | sim_loss = nce_loss.mean()
70 | return sim_loss
71 |
72 |
73 | class ArcCrossEn(nn.Module):
74 | def __init__(self, margin=10):
75 | super(ArcCrossEn, self).__init__()
76 | self.cos_m = math.cos(margin)
77 | self.sin_m = math.sin(margin)
78 |
79 | def forward(self, sim_matrix, scale):
80 | cos = torch.diag(sim_matrix)
81 | sin = torch.sqrt(1.0 - torch.pow(cos, 2))
82 | pin = cos * self.cos_m - sin * self.sin_m
83 | sim_matrix = sim_matrix - torch.diag_embed(cos) + torch.diag_embed(pin)
84 | logpt = F.log_softmax(sim_matrix / scale, dim=-1)
85 | logpt = torch.diag(logpt)
86 | nce_loss = -logpt
87 | sim_loss = nce_loss.mean()
88 | return sim_loss
89 |
90 |
91 | class CrossEn0(nn.Module):
92 | def __init__(self, config=None):
93 | super(CrossEn0, self).__init__()
94 |
95 | def forward(self, sim_matrix, b):
96 | logpt = F.log_softmax(sim_matrix[:b, :], dim=-1)
97 | logpt = torch.diag(logpt[:, :b])
98 | nce_loss = -logpt
99 | sim_loss = nce_loss.mean()
100 | return sim_loss
101 |
102 |
103 | class ema_CrossEn(nn.Module):
104 | def __init__(self, config=None):
105 | super(ema_CrossEn, self).__init__()
106 |
107 | def forward(self, sim_matrix0, sim_matrix1):
108 | m, n = sim_matrix0.size()
109 | diag1 = torch.diag(sim_matrix1)
110 | diag1 = torch.diag_embed(diag1)
111 | sim_matrix1 = sim_matrix1 - diag1
112 | logpt = F.log_softmax(torch.cat([sim_matrix0, sim_matrix1], dim=-1), dim=-1)
113 | logpt = torch.diag(logpt[:, :n])
114 | nce_loss = -logpt
115 | sim_loss = nce_loss.mean()
116 | return sim_loss
117 |
118 |
119 | class DC_CrossEn(nn.Module):
120 | def __init__(self, config=None):
121 | super(DC_CrossEn, self).__init__()
122 |
123 | def forward(self, sim_matrix0, sim_matrix1, seta=0.8):
124 | diag0 = torch.diag(sim_matrix0)
125 | diag1 = torch.diag(sim_matrix1)
126 | sim_matrix0 = sim_matrix0 - diag0
127 | sim_matrix1 = sim_matrix1 - diag1
128 | m, n = sim_matrix0.size()
129 |
130 | sim_matrix = torch.where(sim_matrix1 < seta, sim_matrix0, torch.tensor(0.0).to(sim_matrix0.device))
131 | sim_matrix = sim_matrix + diag0
132 |
133 | logpt = F.log_softmax(sim_matrix, dim=-1)
134 | logpt = torch.diag(logpt)
135 | nce_loss = -logpt
136 | sim_loss = nce_loss.mean()
137 | return sim_loss
138 |
139 |
140 | class ema_CrossEn1(nn.Module):
141 | def __init__(self, config=None):
142 | super(ema_CrossEn1, self).__init__()
143 |
144 | def forward(self, sim_matrix0, sim_matrix1):
145 | logpt0 = F.log_softmax(sim_matrix0, dim=-1)
146 | logpt1 = F.softmax(sim_matrix1, dim=-1)
147 | sim_loss = - logpt0 * logpt1
148 | # diag = torch.diag(sim_loss)
149 | # sim_loss = sim_loss - diag
150 | sim_loss = sim_loss.mean()
151 | return sim_loss
152 |
153 |
154 | class ema_CrossEn2(nn.Module):
155 | def __init__(self, config=None):
156 | super(ema_CrossEn2, self).__init__()
157 |
158 | def forward(self, sim_matrix0, sim_matrix1, lambd=0.5):
159 | m, n = sim_matrix1.size()
160 |
161 | logpt0 = F.log_softmax(sim_matrix0, dim=-1)
162 | logpt1 = F.softmax(sim_matrix1, dim=-1)
163 | logpt1 = lambd * torch.eye(m).to(logpt1.device) + (1 - lambd) * logpt1
164 |
165 | sim_loss = - logpt0 * logpt1
166 | sim_loss = sim_loss.sum() / m
167 | return sim_loss
168 |
169 |
170 | class KL(nn.Module):
171 | def __init__(self, config=None):
172 | super(KL, self).__init__()
173 |
174 | def forward(self, sim_matrix0, sim_matrix1):
175 | logpt0 = F.log_softmax(sim_matrix0, dim=-1)
176 | logpt1 = F.softmax(sim_matrix1, dim=-1)
177 | kl = F.kl_div(logpt0, logpt1, reduction='mean')
178 | # kl = F.kl_div(logpt0, logpt1, reduction='sum')
179 | return kl
180 |
181 |
182 | def _batch_hard(mat_distance, mat_similarity, indice=False):
183 | sorted_mat_distance, positive_indices = torch.sort(mat_distance + (9999999.) * (1 - mat_similarity), dim=1,
184 | descending=False)
185 | hard_p = sorted_mat_distance[:, 0]
186 | hard_p_indice = positive_indices[:, 0]
187 | sorted_mat_distance, negative_indices = torch.sort(mat_distance + (-9999999.) * (mat_similarity), dim=1,
188 | descending=True)
189 | hard_n = sorted_mat_distance[:, 0]
190 | hard_n_indice = negative_indices[:, 0]
191 | if (indice):
192 | return hard_p, hard_n, hard_p_indice, hard_n_indice
193 | return hard_p, hard_n
194 |
195 |
196 | class SoftTripletLoss(nn.Module):
197 | def __init__(self, config=None):
198 | super(SoftTripletLoss, self).__init__()
199 |
200 | def forward(self, sim_matrix0, sim_matrix1):
201 | N = sim_matrix0.size(0)
202 | mat_sim = torch.eye(N).float().to(sim_matrix0.device)
203 | dist_ap, dist_an, ap_idx, an_idx = _batch_hard(sim_matrix0, mat_sim, indice=True)
204 | triple_dist = torch.stack((dist_ap, dist_an), dim=1)
205 | triple_dist = F.log_softmax(triple_dist, dim=1)
206 | dist_ap_ref = torch.gather(sim_matrix1, 1, ap_idx.view(N, 1).expand(N, N))[:, 0]
207 | dist_an_ref = torch.gather(sim_matrix1, 1, an_idx.view(N, 1).expand(N, N))[:, 0]
208 | triple_dist_ref = torch.stack((dist_ap_ref, dist_an_ref), dim=1)
209 | triple_dist_ref = F.softmax(triple_dist_ref, dim=1).detach()
210 | loss = (- triple_dist_ref * triple_dist).mean(0).sum()
211 | return loss
212 |
213 |
214 | class MSE(nn.Module):
215 | def __init__(self, config=None):
216 | super(MSE, self).__init__()
217 |
218 | def forward(self, sim_matrix0, sim_matrix1):
219 | logpt = (sim_matrix0 - sim_matrix1)
220 | loss = logpt * logpt
221 | return loss.mean()
222 |
223 |
224 | def euclidean_dist(x, y):
225 | m, n = x.size(0), y.size(0)
226 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
227 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
228 | dist = xx + yy
229 |     dist.addmm_(x, y.t(), beta=1, alpha=-2)  # i.e. dist = xx + yy - 2 * x @ y.t()
230 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
231 | return dist
232 |
233 |
234 | def uniformity_loss(x, y):
235 | input = torch.cat((x, y), dim=0)
236 | m = input.size(0)
237 | dist = euclidean_dist(input, input)
238 | return torch.logsumexp(torch.logsumexp(dist, dim=-1), dim=-1) - torch.log(torch.tensor(m * m - m))
239 |
240 |
241 | class AllGather(torch.autograd.Function):
242 | """An autograd function that performs allgather on a tensor."""
243 |
244 | @staticmethod
245 | def forward(ctx, tensor, args):
246 | if args.world_size == 1:
247 | ctx.rank = args.local_rank
248 | ctx.batch_size = tensor.shape[0]
249 | return tensor
250 | else:
251 | output = [torch.empty_like(tensor) for _ in range(args.world_size)]
252 | torch.distributed.all_gather(output, tensor)
253 | ctx.rank = args.local_rank
254 | ctx.batch_size = tensor.shape[0]
255 | return torch.cat(output, dim=0)
256 |
257 | @staticmethod
258 | def backward(ctx, grad_output):
259 | return (
260 | grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)],
261 | None,
262 | )
263 |
264 |
265 | class AllGather2(torch.autograd.Function):
266 | """An autograd function that performs allgather on a tensor."""
267 |
268 | # https://github.com/PyTorchLightning/lightning-bolts/blob/8d3fbf7782e3d3937ab8a1775a7092d7567f2933/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
269 | @staticmethod
270 | def forward(ctx, tensor, args):
271 | if args.world_size == 1:
272 | ctx.rank = args.local_rank
273 | ctx.batch_size = tensor.shape[0]
274 | return tensor
275 | else:
276 | output = [torch.empty_like(tensor) for _ in range(args.world_size)]
277 | torch.distributed.all_gather(output, tensor)
278 | ctx.rank = args.local_rank
279 | ctx.batch_size = tensor.shape[0]
280 | return torch.cat(output, dim=0)
281 |
282 | @staticmethod
283 | def backward(ctx, grad_output):
284 | grad_input = grad_output.clone()
285 | torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)
286 | return (grad_input[ctx.rank * ctx.batch_size:(ctx.rank + 1) * ctx.batch_size], None)
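287 |
288 |
289 | if __name__ == "__main__":
290 |     # Minimal usage sketch (illustrative only): CrossEn treats the diagonal of a
291 |     # batch text-video similarity matrix as the matched pairs; applying it to the
292 |     # matrix and its transpose gives the usual symmetric retrieval loss.
293 |     sim_matrix = torch.randn(8, 8)  # rows: texts, columns: videos
294 |     loss_fct = CrossEn()
295 |     loss = loss_fct(sim_matrix) + loss_fct(sim_matrix.t())
296 |     print(loss.item())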
--------------------------------------------------------------------------------
/tvr/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/utils/__init__.py
--------------------------------------------------------------------------------
/tvr/utils/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/utils/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/utils/__pycache__/comm.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/utils/__pycache__/comm.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/utils/__pycache__/logger.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/utils/__pycache__/logger.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/utils/__pycache__/metric_logger.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/utils/__pycache__/metric_logger.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/utils/__pycache__/metrics.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zchoi/GLSCL/06979f8ee1db6c9bf8ec10f6ca36531e0446f5a1/tvr/utils/__pycache__/metrics.cpython-39.pyc
--------------------------------------------------------------------------------
/tvr/utils/comm.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains primitives for multi-gpu communication.
3 | This is useful when doing distributed training.
4 | """
5 |
6 | import time
7 | import pickle
8 |
9 | import torch
10 | import torch.distributed as dist
11 |
12 |
13 | def get_world_size():
14 | if not dist.is_available():
15 | return 1
16 | if not dist.is_initialized():
17 | return 1
18 | return dist.get_world_size()
19 |
20 |
21 | def get_rank():
22 | if not dist.is_available():
23 | return 0
24 | if not dist.is_initialized():
25 | return 0
26 | return dist.get_rank()
27 |
28 |
29 | def is_main_process():
30 | return get_rank() == 0
31 |
32 |
33 | def synchronize():
34 | """
35 | Helper function to synchronize (barrier) among all processes when
36 | using distributed training
37 | """
38 | if not dist.is_available():
39 | return
40 | if not dist.is_initialized():
41 | return
42 | world_size = dist.get_world_size()
43 | if world_size == 1:
44 | return
45 | dist.barrier()
46 |
47 |
48 | def all_gather(data):
49 | """
50 | Run all_gather on arbitrary picklable data (not necessarily tensors)
51 | Args:
52 | data: any picklable object
53 | Returns:
54 | list[data]: list of data gathered from each rank
55 | """
56 | world_size = get_world_size()
57 | if world_size == 1:
58 | return [data]
59 |
60 | buffer = pickle.dumps(data)
61 | storage = torch.ByteStorage.from_buffer(buffer)
62 | tensor = torch.ByteTensor(storage).to("cuda")
63 | # obtain Tensor size of each rank
64 | local_size = torch.LongTensor([tensor.numel()]).to("cuda")
65 | size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
66 | dist.all_gather(size_list, local_size)
67 | size_list = [int(size.item()) for size in size_list]
68 | max_size = max(size_list)
69 |
70 | # receiving Tensor from all ranks
71 | # we pad the tensor because torch all_gather does not support
72 | # gathering tensors of different shapes
73 | tensor_list = []
74 | for _ in size_list:
75 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
76 | if local_size != max_size:
77 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
78 | tensor = torch.cat((tensor, padding), dim=0)
79 | dist.all_gather(tensor_list, tensor)
80 |
81 | data_list = []
82 | for size, tensor in zip(size_list, tensor_list):
83 | # print(get_rank(), size)
84 | buffer = tensor.cpu().numpy().tobytes()[:size]
85 | data_list.append(pickle.loads(buffer))
86 |
87 | return data_list
88 |
89 |
90 | def reduce_dict(input_dict, average=True):
91 | """
92 | Args:
93 | input_dict (dict): all the values will be reduced
94 | average (bool): whether to do average or sum
95 | Reduce the values in the dictionary from all processes so that process with rank
96 | 0 has the averaged results. Returns a dict with the same fields as
97 | input_dict, after reduction.
98 | """
99 | world_size = get_world_size()
100 | if world_size < 2:
101 | return input_dict
102 | with torch.no_grad():
103 | names = []
104 | values = []
105 | # sort the keys so that they are consistent across processes
106 | for k in sorted(input_dict.keys()):
107 | names.append(k)
108 | values.append(input_dict[k])
109 | values = torch.stack(values, dim=0)
110 | dist.reduce(values, dst=0)
111 | if dist.get_rank() == 0 and average:
112 | # only main process gets accumulated, so only divide by
113 | # world_size in this case
114 | values /= world_size
115 | reduced_dict = {k: v for k, v in zip(names, values)}
116 | return reduced_dict
117 |
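118 |
119 | if __name__ == "__main__":
120 |     # Minimal usage sketch (illustrative only): without an initialized process
121 |     # group these helpers fall back to single-process behaviour, so this runs
122 |     # standalone; under distributed launch they gather/reduce across ranks.
123 |     print(get_world_size(), get_rank(), is_main_process())
124 |     print(all_gather({"rank": get_rank(), "msg": "hello"}))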
--------------------------------------------------------------------------------
/tvr/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 |
5 |
6 | def setup_logger(name, save_dir, dist_rank, filename="log.txt"):
7 | logger = logging.getLogger(name)
8 | logger.setLevel(logging.ERROR)
9 | # don't log results for the non-master process
10 | if dist_rank > 0:
11 | return logger
12 | logger.setLevel(logging.DEBUG)
13 | ch = logging.StreamHandler(stream=sys.stdout)
14 | ch.setLevel(logging.DEBUG)
15 | # formatter = logging.Formatter(f"[{dist_rank}]"+"[%(asctime)s %(name)s %(lineno)s %(levelname)s]: %(message)s")
16 | formatter = logging.Formatter("[%(asctime)s %(name)s %(lineno)s %(levelname)s]: %(message)s")
17 | ch.setFormatter(formatter)
18 | logger.addHandler(ch)
19 | logger.propagate = False
20 |
21 | if save_dir:
22 | fh = logging.FileHandler(os.path.join(save_dir, filename))
23 | fh.setLevel(logging.DEBUG)
24 | fh.setFormatter(formatter)
25 | logger.addHandler(fh)
26 |
27 | return logger
28 |
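29 |
30 | if __name__ == "__main__":
31 |     # Minimal usage sketch (illustrative only): rank 0 logs to stdout (and to
32 |     # save_dir/log.txt when save_dir is set); non-zero ranks stay silent.
33 |     logger = setup_logger("tvr", save_dir="", dist_rank=0)
34 |     logger.info("logger is ready")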
--------------------------------------------------------------------------------
/tvr/utils/metric_logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | from collections import defaultdict
3 | from collections import deque
4 |
5 | import torch
6 |
7 |
8 | class SmoothedValue(object):
9 | """Track a series of values and provide access to smoothed values over a
10 | window or the global series average.
11 | """
12 |
13 | def __init__(self, window_size=20):
14 | self.deque = deque(maxlen=window_size)
15 | self.series = []
16 | self.total = 0.0
17 | self.count = 0
18 |
19 | def update(self, value):
20 | self.deque.append(value)
21 | self.series.append(value)
22 | self.count += 1
23 | self.total += value
24 |
25 | @property
26 | def median(self):
27 | d = torch.tensor(list(self.deque))
28 | return d.median().item()
29 |
30 | @property
31 | def avg(self):
32 | d = torch.tensor(list(self.deque))
33 | return d.mean().item()
34 |
35 | @property
36 | def global_avg(self):
37 | return self.total / self.count
38 |
39 |
40 | class MetricLogger(object):
41 | def __init__(self, delimiter="\t"):
42 | self.meters = defaultdict(SmoothedValue)
43 | self.delimiter = delimiter
44 |
45 | def update(self, **kwargs):
46 | for k, v in kwargs.items():
47 | if isinstance(v, torch.Tensor):
48 | v = v.item()
49 | assert isinstance(v, (float, int))
50 | self.meters[k].update(v)
51 |
52 | def __getattr__(self, attr):
53 | if attr in self.meters:
54 | return self.meters[attr]
55 | if attr in self.__dict__:
56 | return self.__dict__[attr]
57 | raise AttributeError("'{}' object has no attribute '{}'".format(
58 | type(self).__name__, attr))
59 |
60 | def __str__(self):
61 | loss_str = []
62 | for name, meter in self.meters.items():
63 | loss_str.append(
64 | "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg)
65 | )
66 | return self.delimiter.join(loss_str)
67 |
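68 |
69 | if __name__ == "__main__":
70 |     # Minimal usage sketch (illustrative only): each keyword passed to update()
71 |     # becomes a SmoothedValue meter reported as "median (global average)".
72 |     meters = MetricLogger(delimiter="  ")
73 |     for step in range(5):
74 |         meters.update(loss=0.5 / (step + 1), lr=1e-4)
75 |     print(str(meters))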
--------------------------------------------------------------------------------
/tvr/utils/metrics.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import unicode_literals
4 | from __future__ import print_function
5 |
6 | import numpy as np
7 | import torch
8 |
9 |
10 | def compute_metrics(x, re_ranking=False, sim=None):
11 |
12 |     # sort each row of the combined score (score1+2) matrix
13 |     # sx = np.sort(-x, axis=1) # np.sort is ascending by default; negating x gives a descending order
14 | sx = torch.sort(torch.Tensor(x), dim=1, descending=True)[0].numpy()
15 | # if re_ranking:
16 | # topk = 10
17 |     #     # take the top-k values from each row of score1
18 | # topk_values, topk_indices = torch.topk(torch.Tensor(sim), topk, dim=1)
19 |
20 |     #     # sort each row of score2 in descending order
21 |     #     sorted_score2, original_indice = torch.sort(torch.Tensor(x), dim=1, descending=True) # original_indice records the sort order: the value originally at column original_indice[i] is now ranked at position i
22 |
23 |     #     # find out where the values at topk_indices end up after score_2 is sorted
24 | # ind = torch.zeros_like(topk_indices)
25 | # for i in range(sorted_score2.shape[0]):
26 | # for j in range(topk_indices.shape[1]):
27 |     #             ind[i][j] = (original_indice[i] == topk_indices[i][j]).nonzero().item() # e.g. video 3 and video 8 end up at positions X1 and X2 after sorting
28 |
29 |     #     ind, _ = torch.sort(ind, dim=1, descending=False) # sort ind to recover the original rank positions
30 |
31 |     #     new_indices = original_indice.clone() # deep copy
32 |
33 | # for i in range(topk_indices.shape[0]):
34 | # for j in range(topk_indices.shape[1]):
35 | # new_indices[i][ind[i][j]] = topk_indices[i][j]
36 |
37 |     #     sx = torch.gather(torch.Tensor(x), 1, new_indices).numpy() # score2 after re-ranking
38 |
39 | d = np.diag(x)
40 | d = d[:, np.newaxis]
41 | ind = sx - d
42 | ind = np.where(ind == 0)
43 | ind = ind[1]
44 | metrics = {}
45 | metrics['R1'] = float(np.sum(ind == 0)) * 100 / len(ind)
46 | metrics['R5'] = float(np.sum(ind < 5)) * 100 / len(ind)
47 | metrics['R10'] = float(np.sum(ind < 10)) * 100 / len(ind)
48 | metrics['R50'] = float(np.sum(ind < 50)) * 100 / len(ind)
49 | metrics['MR'] = np.median(ind) + 1
50 | metrics["MedianR"] = metrics['MR']
51 | metrics["MeanR"] = np.mean(ind) + 1
52 | metrics["cols"] = [int(i) for i in list(ind)]
53 | return metrics
54 |
55 | def print_computed_metrics(metrics):
56 | r1 = metrics['R1']
57 | r5 = metrics['R5']
58 | r10 = metrics['R10']
59 | r50 = metrics['R50']
60 | mr = metrics['MR']
61 | meanr = metrics["MeanR"]
62 | print('R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - R@50: {:.4f} - Median R: {} - MeanR: {}'.format(r1, r5, r10, r50,
63 | mr, meanr))
64 |
65 |
66 | # below two functions directly come from: https://github.com/Deferf/Experiments
67 | def tensor_text_to_video_metrics(sim_tensor, top_k=[1, 5, 10, 50]):
68 | if not torch.is_tensor(sim_tensor):
69 | sim_tensor = torch.tensor(sim_tensor)
70 |
71 | # Permute sim_tensor so it represents a sequence of text-video similarity matrices.
72 | # Then obtain the double argsort to position the rank on the diagonal
73 | stacked_sim_matrices = sim_tensor.permute(1, 0, 2)
74 | first_argsort = torch.argsort(stacked_sim_matrices, dim=-1, descending=True)
75 | second_argsort = torch.argsort(first_argsort, dim=-1, descending=False)
76 |
77 | # Extracts ranks i.e diagonals
78 | ranks = torch.flatten(torch.diagonal(second_argsort, dim1=1, dim2=2))
79 |
80 | # Now we need to extract valid ranks, as some belong to inf padding values
81 | permuted_original_data = torch.flatten(torch.diagonal(sim_tensor, dim1=0, dim2=2))
82 | mask = ~ torch.logical_or(torch.isinf(permuted_original_data), torch.isnan(permuted_original_data))
83 | valid_ranks = ranks[mask]
84 | # A quick dimension check validates our results, there may be other correctness tests pending
85 | # Such as dot product localization, but that is for other time.
86 | # assert int(valid_ranks.shape[0]) == sum([len(text_dict[k]) for k in text_dict])
87 | if not torch.is_tensor(valid_ranks):
88 | valid_ranks = torch.tensor(valid_ranks)
89 | results = {f"R{k}": float(torch.sum(valid_ranks < k) * 100 / len(valid_ranks)) for k in top_k}
90 | results["MedianR"] = float(torch.median(valid_ranks + 1))
91 | results["MeanR"] = float(np.mean(valid_ranks.numpy() + 1))
92 | results["Std_Rank"] = float(np.std(valid_ranks.numpy() + 1))
93 | results['MR'] = results["MedianR"]
94 | return results
95 |
96 |
97 | def tensor_video_to_text_sim(sim_tensor):
98 | if not torch.is_tensor(sim_tensor):
99 | sim_tensor = torch.tensor(sim_tensor)
100 | # Code to avoid nans
101 | sim_tensor[sim_tensor != sim_tensor] = float('-inf')
102 | # Forms a similarity matrix for use with rank at k
103 | values, _ = torch.max(sim_tensor, dim=1, keepdim=True)
104 | return torch.squeeze(values).T
105 |
106 |
107 | if __name__ == '__main__':
108 | test_sim = np.random.rand(1000, 1000)
109 | metrics = compute_metrics(test_sim)
110 | print_computed_metrics(metrics)
111 |
--------------------------------------------------------------------------------