├── modules ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── cross-base │ └── cross_config.json ├── until_config.py ├── tokenization_clip.py ├── optimization.py ├── file_utils.py ├── module_cross.py ├── until_module.py ├── module_clip.py └── modeling.py ├── .gitignore ├── CLIP4Clip.png ├── LICENSE ├── util.py ├── metrics.py ├── preprocess └── compress_video.py ├── dataloaders ├── rawvideo_util.py ├── dataloader_msvd_retrieval.py ├── data_dataloaders.py ├── dataloader_lsmdc_retrieval.py ├── dataloader_didemo_retrieval.py ├── dataloader_activitynet_retrieval.py └── dataloader_msrvtt_retrieval.py ├── README.md └── main_task_retrieval.py /modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .git 2 | .idea -------------------------------------------------------------------------------- /CLIP4Clip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArrowLuo/CLIP4Clip/HEAD/CLIP4Clip.png -------------------------------------------------------------------------------- /modules/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArrowLuo/CLIP4Clip/HEAD/modules/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /modules/cross-base/cross_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 512, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 2048, 8 | "max_position_embeddings": 128, 9 | "num_attention_heads": 8, 10 | "num_hidden_layers": 4, 11 | "vocab_size": 512 12 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 ArrowLuo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import threading 4 | from torch._utils import ExceptionWrapper 5 | import logging 6 | 7 | def get_a_var(obj): 8 | if isinstance(obj, torch.Tensor): 9 | return obj 10 | 11 | if isinstance(obj, list) or isinstance(obj, tuple): 12 | for result in map(get_a_var, obj): 13 | if isinstance(result, torch.Tensor): 14 | return result 15 | if isinstance(obj, dict): 16 | for result in map(get_a_var, obj.items()): 17 | if isinstance(result, torch.Tensor): 18 | return result 19 | return None 20 | 21 | def parallel_apply(fct, model, inputs, device_ids): 22 | modules = nn.parallel.replicate(model, device_ids) 23 | assert len(modules) == len(inputs) 24 | lock = threading.Lock() 25 | results = {} 26 | grad_enabled = torch.is_grad_enabled() 27 | 28 | def _worker(i, module, input): 29 | torch.set_grad_enabled(grad_enabled) 30 | device = get_a_var(input).get_device() 31 | try: 32 | with torch.cuda.device(device): 33 | # this also avoids accidental slicing of `input` if it is a Tensor 34 | if not isinstance(input, (list, tuple)): 35 | input = (input,) 36 | output = fct(module, *input) 37 | with lock: 38 | results[i] = output 39 | except Exception: 40 | with lock: 41 | results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device)) 42 | 43 | if len(modules) > 1: 44 | threads = [threading.Thread(target=_worker, args=(i, module, input)) 45 | for i, (module, input) in enumerate(zip(modules, inputs))] 46 | 47 | for thread in threads: 48 | thread.start() 49 | for thread in threads: 50 | thread.join() 51 | else: 52 | _worker(0, modules[0], inputs[0]) 53 | 54 | outputs = [] 55 | for i in range(len(inputs)): 56 | output = results[i] 57 | if isinstance(output, ExceptionWrapper): 58 | output.reraise() 59 | outputs.append(output) 60 | return outputs 61 | 62 | def get_logger(filename=None): 63 | logger = logging.getLogger('logger') 64 | logger.setLevel(logging.DEBUG) 65 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', 66 | datefmt='%m/%d/%Y %H:%M:%S', 67 | level=logging.INFO) 68 | if filename is not None: 69 | handler = logging.FileHandler(filename) 70 | handler.setLevel(logging.DEBUG) 71 | handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 72 | logging.getLogger().addHandler(handler) 73 | return logger -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | import torch 8 | 9 | def compute_metrics(x): 10 | sx = np.sort(-x, axis=1) 11 | d = np.diag(-x) 12 | d = d[:, np.newaxis] 13 | ind = sx - d 14 | ind = np.where(ind == 0) 15 | ind = ind[1] 16 | metrics = {} 17 | metrics['R1'] = float(np.sum(ind == 0)) * 100 / len(ind) 18 | metrics['R5'] = float(np.sum(ind < 5)) * 100 / len(ind) 19 | metrics['R10'] = float(np.sum(ind < 10)) * 100 / len(ind) 20 | metrics['MR'] = np.median(ind) + 1 21 | metrics["MedianR"] = metrics['MR'] 22 | metrics["MeanR"] = np.mean(ind) + 1 23 | metrics["cols"] = [int(i) for i in list(ind)] 24 | return metrics 25 | 26 | def print_computed_metrics(metrics): 27 | r1 = metrics['R1'] 28 | r5 = metrics['R5'] 29 | r10 = metrics['R10'] 30 | mr = metrics['MR'] 31 | print('R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'.format(r1, r5, r10, mr)) 32 | 33 | # below two functions directly come from: https://github.com/Deferf/Experiments 34 | def tensor_text_to_video_metrics(sim_tensor, top_k = [1,5,10]): 35 | if not torch.is_tensor(sim_tensor): 36 | sim_tensor = torch.tensor(sim_tensor) 37 | 38 | # Permute sim_tensor so it represents a sequence of text-video similarity matrices. 39 | # Then obtain the double argsort to position the rank on the diagonal 40 | stacked_sim_matrices = sim_tensor.permute(1, 0, 2) 41 | first_argsort = torch.argsort(stacked_sim_matrices, dim = -1, descending= True) 42 | second_argsort = torch.argsort(first_argsort, dim = -1, descending= False) 43 | 44 | # Extracts ranks i.e diagonals 45 | ranks = torch.flatten(torch.diagonal(second_argsort, dim1 = 1, dim2 = 2)) 46 | 47 | # Now we need to extract valid ranks, as some belong to inf padding values 48 | permuted_original_data = torch.flatten(torch.diagonal(sim_tensor, dim1 = 0, dim2 = 2)) 49 | mask = ~ torch.logical_or(torch.isinf(permuted_original_data), torch.isnan(permuted_original_data)) 50 | valid_ranks = ranks[mask] 51 | # A quick dimension check validates our results, there may be other correctness tests pending 52 | # Such as dot product localization, but that is for other time. 53 | #assert int(valid_ranks.shape[0]) == sum([len(text_dict[k]) for k in text_dict]) 54 | if not torch.is_tensor(valid_ranks): 55 | valid_ranks = torch.tensor(valid_ranks) 56 | results = {f"R{k}": float(torch.sum(valid_ranks < k) * 100 / len(valid_ranks)) for k in top_k} 57 | results["MedianR"] = float(torch.median(valid_ranks + 1)) 58 | results["MeanR"] = float(np.mean(valid_ranks.numpy() + 1)) 59 | results["Std_Rank"] = float(np.std(valid_ranks.numpy() + 1)) 60 | results['MR'] = results["MedianR"] 61 | return results 62 | 63 | def tensor_video_to_text_sim(sim_tensor): 64 | if not torch.is_tensor(sim_tensor): 65 | sim_tensor = torch.tensor(sim_tensor) 66 | # Code to avoid nans 67 | sim_tensor[sim_tensor != sim_tensor] = float('-inf') 68 | # Forms a similarity matrix for use with rank at k 69 | values, _ = torch.max(sim_tensor, dim=1, keepdim=True) 70 | return torch.squeeze(values).T 71 | -------------------------------------------------------------------------------- /preprocess/compress_video.py: -------------------------------------------------------------------------------- 1 | """ 2 | Used to compress video in: https://github.com/ArrowLuo/CLIP4Clip 3 | Author: ArrowLuo 4 | """ 5 | import os 6 | import argparse 7 | import ffmpeg 8 | import subprocess 9 | import time 10 | import multiprocessing 11 | from multiprocessing import Pool 12 | import shutil 13 | try: 14 | from psutil import cpu_count 15 | except: 16 | from multiprocessing import cpu_count 17 | # multiprocessing.freeze_support() 18 | 19 | def compress(paras): 20 | input_video_path, output_video_path = paras 21 | try: 22 | command = ['ffmpeg', 23 | '-y', # (optional) overwrite output file if it exists 24 | '-i', input_video_path, 25 | '-filter:v', 26 | 'scale=\'if(gt(a,1),trunc(oh*a/2)*2,224)\':\'if(gt(a,1),224,trunc(ow*a/2)*2)\'', # scale to 224 27 | '-map', '0:v', 28 | '-r', '3', # frames per second 29 | output_video_path, 30 | ] 31 | ffmpeg = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 32 | out, err = ffmpeg.communicate() 33 | retcode = ffmpeg.poll() 34 | # print something above for debug 35 | except Exception as e: 36 | raise e 37 | 38 | def prepare_input_output_pairs(input_root, output_root): 39 | input_video_path_list = [] 40 | output_video_path_list = [] 41 | for root, dirs, files in os.walk(input_root): 42 | for file_name in files: 43 | input_video_path = os.path.join(root, file_name) 44 | output_video_path = os.path.join(output_root, file_name) 45 | if os.path.exists(output_video_path) and os.path.getsize(output_video_path) > 0: 46 | pass 47 | else: 48 | input_video_path_list.append(input_video_path) 49 | output_video_path_list.append(output_video_path) 50 | return input_video_path_list, output_video_path_list 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser(description='Compress video for speed-up') 54 | parser.add_argument('--input_root', type=str, help='input root') 55 | parser.add_argument('--output_root', type=str, help='output root') 56 | args = parser.parse_args() 57 | 58 | input_root = args.input_root 59 | output_root = args.output_root 60 | 61 | assert input_root != output_root 62 | 63 | if not os.path.exists(output_root): 64 | os.makedirs(output_root, exist_ok=True) 65 | 66 | input_video_path_list, output_video_path_list = prepare_input_output_pairs(input_root, output_root) 67 | 68 | print("Total video need to process: {}".format(len(input_video_path_list))) 69 | num_works = cpu_count() 70 | print("Begin with {}-core logical processor.".format(num_works)) 71 | 72 | pool = Pool(num_works) 73 | data_dict_list = pool.map(compress, 74 | [(input_video_path, output_video_path) for 75 | input_video_path, output_video_path in 76 | zip(input_video_path_list, output_video_path_list)]) 77 | pool.close() 78 | pool.join() 79 | 80 | print("Compress finished, wait for checking files...") 81 | for input_video_path, output_video_path in zip(input_video_path_list, output_video_path_list): 82 | if os.path.exists(input_video_path): 83 | if os.path.exists(output_video_path) is False or os.path.getsize(output_video_path) < 1.: 84 | shutil.copyfile(input_video_path, output_video_path) 85 | print("Copy and replace file: {}".format(output_video_path)) -------------------------------------------------------------------------------- /dataloaders/rawvideo_util.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | from PIL import Image 4 | # pytorch=1.7.1 5 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 6 | # pip install opencv-python 7 | import cv2 8 | 9 | class RawVideoExtractorCV2(): 10 | def __init__(self, centercrop=False, size=224, framerate=-1, ): 11 | self.centercrop = centercrop 12 | self.size = size 13 | self.framerate = framerate 14 | self.transform = self._transform(self.size) 15 | 16 | def _transform(self, n_px): 17 | return Compose([ 18 | Resize(n_px, interpolation=Image.BICUBIC), 19 | CenterCrop(n_px), 20 | lambda image: image.convert("RGB"), 21 | ToTensor(), 22 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 23 | ]) 24 | 25 | def video_to_tensor(self, video_file, preprocess, sample_fp=0, start_time=None, end_time=None): 26 | if start_time is not None or end_time is not None: 27 | assert isinstance(start_time, int) and isinstance(end_time, int) \ 28 | and start_time > -1 and end_time > start_time 29 | assert sample_fp > -1 30 | 31 | # Samples a frame sample_fp X frames. 32 | cap = cv2.VideoCapture(video_file) 33 | frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 34 | fps = int(cap.get(cv2.CAP_PROP_FPS)) 35 | 36 | total_duration = (frameCount + fps - 1) // fps 37 | start_sec, end_sec = 0, total_duration 38 | 39 | if start_time is not None: 40 | start_sec, end_sec = start_time, end_time if end_time <= total_duration else total_duration 41 | cap.set(cv2.CAP_PROP_POS_FRAMES, int(start_time * fps)) 42 | 43 | interval = 1 44 | if sample_fp > 0: 45 | interval = fps // sample_fp 46 | else: 47 | sample_fp = fps 48 | if interval == 0: interval = 1 49 | 50 | inds = [ind for ind in np.arange(0, fps, interval)] 51 | assert len(inds) >= sample_fp 52 | inds = inds[:sample_fp] 53 | 54 | ret = True 55 | images, included = [], [] 56 | 57 | for sec in np.arange(start_sec, end_sec + 1): 58 | if not ret: break 59 | sec_base = int(sec * fps) 60 | for ind in inds: 61 | cap.set(cv2.CAP_PROP_POS_FRAMES, sec_base + ind) 62 | ret, frame = cap.read() 63 | if not ret: break 64 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 65 | images.append(preprocess(Image.fromarray(frame_rgb).convert("RGB"))) 66 | 67 | cap.release() 68 | 69 | if len(images) > 0: 70 | video_data = th.tensor(np.stack(images)) 71 | else: 72 | video_data = th.zeros(1) 73 | return {'video': video_data} 74 | 75 | def get_video_data(self, video_path, start_time=None, end_time=None): 76 | image_input = self.video_to_tensor(video_path, self.transform, sample_fp=self.framerate, start_time=start_time, end_time=end_time) 77 | return image_input 78 | 79 | def process_raw_data(self, raw_video_data): 80 | tensor_size = raw_video_data.size() 81 | tensor = raw_video_data.view(-1, 1, tensor_size[-3], tensor_size[-2], tensor_size[-1]) 82 | return tensor 83 | 84 | def process_frame_order(self, raw_video_data, frame_order=0): 85 | # 0: ordinary order; 1: reverse order; 2: random order. 86 | if frame_order == 0: 87 | pass 88 | elif frame_order == 1: 89 | reverse_order = np.arange(raw_video_data.size(0) - 1, -1, -1) 90 | raw_video_data = raw_video_data[reverse_order, ...] 91 | elif frame_order == 2: 92 | random_order = np.arange(raw_video_data.size(0)) 93 | np.random.shuffle(random_order) 94 | raw_video_data = raw_video_data[random_order, ...] 95 | 96 | return raw_video_data 97 | 98 | # An ordinary video frame extractor based CV2 99 | RawVideoExtractor = RawVideoExtractorCV2 -------------------------------------------------------------------------------- /modules/until_config.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import copy 24 | import json 25 | import logging 26 | import tarfile 27 | import tempfile 28 | import shutil 29 | import torch 30 | from .file_utils import cached_path 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | class PretrainedConfig(object): 35 | 36 | pretrained_model_archive_map = {} 37 | config_name = "" 38 | weights_name = "" 39 | 40 | @classmethod 41 | def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None): 42 | archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name) 43 | if os.path.exists(archive_file) is False: 44 | if pretrained_model_name in cls.pretrained_model_archive_map: 45 | archive_file = cls.pretrained_model_archive_map[pretrained_model_name] 46 | else: 47 | archive_file = pretrained_model_name 48 | 49 | # redirect to the cache, if necessary 50 | try: 51 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) 52 | except FileNotFoundError: 53 | if task_config is None or task_config.local_rank == 0: 54 | logger.error( 55 | "Model name '{}' was not found in model name list. " 56 | "We assumed '{}' was a path or url but couldn't find any file " 57 | "associated to this path or url.".format( 58 | pretrained_model_name, 59 | archive_file)) 60 | return None 61 | if resolved_archive_file == archive_file: 62 | if task_config is None or task_config.local_rank == 0: 63 | logger.info("loading archive file {}".format(archive_file)) 64 | else: 65 | if task_config is None or task_config.local_rank == 0: 66 | logger.info("loading archive file {} from cache at {}".format( 67 | archive_file, resolved_archive_file)) 68 | tempdir = None 69 | if os.path.isdir(resolved_archive_file): 70 | serialization_dir = resolved_archive_file 71 | else: 72 | # Extract archive to temp dir 73 | tempdir = tempfile.mkdtemp() 74 | if task_config is None or task_config.local_rank == 0: 75 | logger.info("extracting archive file {} to temp dir {}".format( 76 | resolved_archive_file, tempdir)) 77 | with tarfile.open(resolved_archive_file, 'r:gz') as archive: 78 | archive.extractall(tempdir) 79 | serialization_dir = tempdir 80 | # Load config 81 | config_file = os.path.join(serialization_dir, cls.config_name) 82 | config = cls.from_json_file(config_file) 83 | config.type_vocab_size = type_vocab_size 84 | if task_config is None or task_config.local_rank == 0: 85 | logger.info("Model config {}".format(config)) 86 | 87 | if state_dict is None: 88 | weights_path = os.path.join(serialization_dir, cls.weights_name) 89 | if os.path.exists(weights_path): 90 | state_dict = torch.load(weights_path, map_location='cpu') 91 | else: 92 | if task_config is None or task_config.local_rank == 0: 93 | logger.info("Weight doesn't exsits. {}".format(weights_path)) 94 | 95 | if tempdir: 96 | # Clean up temp dir 97 | shutil.rmtree(tempdir) 98 | 99 | return config, state_dict 100 | 101 | @classmethod 102 | def from_dict(cls, json_object): 103 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 104 | config = cls(vocab_size_or_config_json_file=-1) 105 | for key, value in json_object.items(): 106 | config.__dict__[key] = value 107 | return config 108 | 109 | @classmethod 110 | def from_json_file(cls, json_file): 111 | """Constructs a `BertConfig` from a json file of parameters.""" 112 | with open(json_file, "r", encoding='utf-8') as reader: 113 | text = reader.read() 114 | return cls.from_dict(json.loads(text)) 115 | 116 | def __repr__(self): 117 | return str(self.to_json_string()) 118 | 119 | def to_dict(self): 120 | """Serializes this instance to a Python dictionary.""" 121 | output = copy.deepcopy(self.__dict__) 122 | return output 123 | 124 | def to_json_string(self): 125 | """Serializes this instance to a JSON string.""" 126 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" -------------------------------------------------------------------------------- /modules/tokenization_clip.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | self.vocab = self.encoder 81 | 82 | def bpe(self, token): 83 | if token in self.cache: 84 | return self.cache[token] 85 | word = tuple(token[:-1]) + ( token[-1] + '',) 86 | pairs = get_pairs(word) 87 | 88 | if not pairs: 89 | return token+'' 90 | 91 | while True: 92 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 93 | if bigram not in self.bpe_ranks: 94 | break 95 | first, second = bigram 96 | new_word = [] 97 | i = 0 98 | while i < len(word): 99 | try: 100 | j = word.index(first, i) 101 | new_word.extend(word[i:j]) 102 | i = j 103 | except: 104 | new_word.extend(word[i:]) 105 | break 106 | 107 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 108 | new_word.append(first+second) 109 | i += 2 110 | else: 111 | new_word.append(word[i]) 112 | i += 1 113 | new_word = tuple(new_word) 114 | word = new_word 115 | if len(word) == 1: 116 | break 117 | else: 118 | pairs = get_pairs(word) 119 | word = ' '.join(word) 120 | self.cache[token] = word 121 | return word 122 | 123 | def encode(self, text): 124 | bpe_tokens = [] 125 | text = whitespace_clean(basic_clean(text)).lower() 126 | for token in re.findall(self.pat, text): 127 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 128 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 129 | return bpe_tokens 130 | 131 | def decode(self, tokens): 132 | text = ''.join([self.decoder[token] for token in tokens]) 133 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 134 | return text 135 | 136 | def tokenize(self, text): 137 | tokens = [] 138 | text = whitespace_clean(basic_clean(text)).lower() 139 | for token in re.findall(self.pat, text): 140 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 141 | tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) 142 | return tokens 143 | 144 | def convert_tokens_to_ids(self, tokens): 145 | return [self.encoder[bpe_token] for bpe_token in tokens] -------------------------------------------------------------------------------- /modules/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | def warmup_cosine(x, warmup=0.002): 27 | if x < warmup: 28 | return x/warmup 29 | return 0.5 * (1.0 + math.cos(math.pi * x)) 30 | 31 | def warmup_constant(x, warmup=0.002): 32 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. 33 | Learning rate is 1. afterwards. """ 34 | if x < warmup: 35 | return x/warmup 36 | return 1.0 37 | 38 | def warmup_linear(x, warmup=0.002): 39 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. 40 | After `t_total`-th training step, learning rate is zero. """ 41 | if x < warmup: 42 | return x/warmup 43 | return max((x-1.)/(warmup-1.), 0) 44 | 45 | SCHEDULES = { 46 | 'warmup_cosine': warmup_cosine, 47 | 'warmup_constant': warmup_constant, 48 | 'warmup_linear': warmup_linear, 49 | } 50 | 51 | 52 | class BertAdam(Optimizer): 53 | """Implements BERT version of Adam algorithm with weight decay fix. 54 | Params: 55 | lr: learning rate 56 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 57 | t_total: total number of training steps for the learning 58 | rate schedule, -1 means constant learning rate. Default: -1 59 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 60 | b1: Adams b1. Default: 0.9 61 | b2: Adams b2. Default: 0.999 62 | e: Adams epsilon. Default: 1e-6 63 | weight_decay: Weight decay. Default: 0.01 64 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 65 | """ 66 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 67 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 68 | max_grad_norm=1.0): 69 | if lr is not required and lr < 0.0: 70 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 71 | if schedule not in SCHEDULES: 72 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 73 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 74 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 75 | if not 0.0 <= b1 < 1.0: 76 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 77 | if not 0.0 <= b2 < 1.0: 78 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 79 | if not e >= 0.0: 80 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 81 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 82 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 83 | max_grad_norm=max_grad_norm) 84 | super(BertAdam, self).__init__(params, defaults) 85 | 86 | def get_lr(self): 87 | lr = [] 88 | for group in self.param_groups: 89 | for p in group['params']: 90 | if p.grad is None: 91 | continue 92 | state = self.state[p] 93 | if len(state) == 0: 94 | return [0] 95 | if group['t_total'] != -1: 96 | schedule_fct = SCHEDULES[group['schedule']] 97 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 98 | else: 99 | lr_scheduled = group['lr'] 100 | lr.append(lr_scheduled) 101 | return lr 102 | 103 | def step(self, closure=None): 104 | """Performs a single optimization step. 105 | Arguments: 106 | closure (callable, optional): A closure that reevaluates the model 107 | and returns the loss. 108 | """ 109 | loss = None 110 | if closure is not None: 111 | loss = closure() 112 | 113 | for group in self.param_groups: 114 | for p in group['params']: 115 | if p.grad is None: 116 | continue 117 | grad = p.grad.data 118 | if grad.is_sparse: 119 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 120 | 121 | state = self.state[p] 122 | 123 | # State initialization 124 | if len(state) == 0: 125 | state['step'] = 0 126 | # Exponential moving average of gradient values 127 | state['next_m'] = torch.zeros_like(p.data) 128 | # Exponential moving average of squared gradient values 129 | state['next_v'] = torch.zeros_like(p.data) 130 | 131 | next_m, next_v = state['next_m'], state['next_v'] 132 | beta1, beta2 = group['b1'], group['b2'] 133 | 134 | # Add grad clipping 135 | if group['max_grad_norm'] > 0: 136 | clip_grad_norm_(p, group['max_grad_norm']) 137 | 138 | # Decay the first and second moment running average coefficient 139 | # In-place operations to update the averages at the same time 140 | # next_m.mul_(beta1).add_(1 - beta1, grad) --> pytorch 1.7 141 | next_m.mul_(beta1).add_(grad, alpha=1 - beta1) 142 | # next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) --> pytorch 1.7 143 | next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 144 | update = next_m / (next_v.sqrt() + group['e']) 145 | 146 | # Just adding the square of the weights to the loss function is *not* 147 | # the correct way of using L2 regularization/weight decay with Adam, 148 | # since that will interact with the m and v parameters in strange ways. 149 | # 150 | # Instead we want to decay the weights in a manner that doesn't interact 151 | # with the m/v parameters. This is equivalent to adding the square 152 | # of the weights to the loss with plain (non-momentum) SGD. 153 | if group['weight_decay'] > 0.0: 154 | update += group['weight_decay'] * p.data 155 | 156 | if group['t_total'] != -1: 157 | schedule_fct = SCHEDULES[group['schedule']] 158 | progress = state['step']/group['t_total'] 159 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) 160 | else: 161 | lr_scheduled = group['lr'] 162 | 163 | update_with_lr = lr_scheduled * update 164 | p.data.add_(-update_with_lr) 165 | 166 | state['step'] += 1 167 | 168 | return loss -------------------------------------------------------------------------------- /dataloaders/dataloader_msvd_retrieval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import os 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | import pickle 10 | from dataloaders.rawvideo_util import RawVideoExtractor 11 | 12 | class MSVD_DataLoader(Dataset): 13 | """MSVD dataset loader.""" 14 | def __init__( 15 | self, 16 | subset, 17 | data_path, 18 | features_path, 19 | tokenizer, 20 | max_words=30, 21 | feature_framerate=1.0, 22 | max_frames=100, 23 | image_resolution=224, 24 | frame_order=0, 25 | slice_framepos=0, 26 | ): 27 | self.data_path = data_path 28 | self.features_path = features_path 29 | self.feature_framerate = feature_framerate 30 | self.max_words = max_words 31 | self.max_frames = max_frames 32 | self.tokenizer = tokenizer 33 | # 0: ordinary order; 1: reverse order; 2: random order. 34 | self.frame_order = frame_order 35 | assert self.frame_order in [0, 1, 2] 36 | # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. 37 | self.slice_framepos = slice_framepos 38 | assert self.slice_framepos in [0, 1, 2] 39 | 40 | self.subset = subset 41 | assert self.subset in ["train", "val", "test"] 42 | video_id_path_dict = {} 43 | video_id_path_dict["train"] = os.path.join(self.data_path, "train_list.txt") 44 | video_id_path_dict["val"] = os.path.join(self.data_path, "val_list.txt") 45 | video_id_path_dict["test"] = os.path.join(self.data_path, "test_list.txt") 46 | caption_file = os.path.join(self.data_path, "raw-captions.pkl") 47 | 48 | with open(video_id_path_dict[self.subset], 'r') as fp: 49 | video_ids = [itm.strip() for itm in fp.readlines()] 50 | 51 | with open(caption_file, 'rb') as f: 52 | captions = pickle.load(f) 53 | 54 | video_dict = {} 55 | for root, dub_dir, video_files in os.walk(self.features_path): 56 | for video_file in video_files: 57 | video_id_ = ".".join(video_file.split(".")[:-1]) 58 | if video_id_ not in video_ids: 59 | continue 60 | file_path_ = os.path.join(root, video_file) 61 | video_dict[video_id_] = file_path_ 62 | self.video_dict = video_dict 63 | 64 | self.sample_len = 0 65 | self.sentences_dict = {} 66 | self.cut_off_points = [] 67 | for video_id in video_ids: 68 | assert video_id in captions 69 | for cap in captions[video_id]: 70 | cap_txt = " ".join(cap) 71 | self.sentences_dict[len(self.sentences_dict)] = (video_id, cap_txt) 72 | self.cut_off_points.append(len(self.sentences_dict)) 73 | 74 | ## below variables are used to multi-sentences retrieval 75 | # self.cut_off_points: used to tag the label when calculate the metric 76 | # self.sentence_num: used to cut the sentence representation 77 | # self.video_num: used to cut the video representation 78 | self.multi_sentence_per_video = True # !!! important tag for eval 79 | if self.subset == "val" or self.subset == "test": 80 | self.sentence_num = len(self.sentences_dict) 81 | self.video_num = len(video_ids) 82 | assert len(self.cut_off_points) == self.video_num 83 | print("For {}, sentence number: {}".format(self.subset, self.sentence_num)) 84 | print("For {}, video number: {}".format(self.subset, self.video_num)) 85 | 86 | print("Video number: {}".format(len(self.video_dict))) 87 | print("Total Paire: {}".format(len(self.sentences_dict))) 88 | 89 | self.sample_len = len(self.sentences_dict) 90 | self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) 91 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", 92 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} 93 | 94 | def __len__(self): 95 | return self.sample_len 96 | 97 | def _get_text(self, video_id, caption): 98 | k = 1 99 | choice_video_ids = [video_id] 100 | pairs_text = np.zeros((k, self.max_words), dtype=np.long) 101 | pairs_mask = np.zeros((k, self.max_words), dtype=np.long) 102 | pairs_segment = np.zeros((k, self.max_words), dtype=np.long) 103 | 104 | for i, video_id in enumerate(choice_video_ids): 105 | words = self.tokenizer.tokenize(caption) 106 | 107 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words 108 | total_length_with_CLS = self.max_words - 1 109 | if len(words) > total_length_with_CLS: 110 | words = words[:total_length_with_CLS] 111 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] 112 | 113 | input_ids = self.tokenizer.convert_tokens_to_ids(words) 114 | input_mask = [1] * len(input_ids) 115 | segment_ids = [0] * len(input_ids) 116 | while len(input_ids) < self.max_words: 117 | input_ids.append(0) 118 | input_mask.append(0) 119 | segment_ids.append(0) 120 | assert len(input_ids) == self.max_words 121 | assert len(input_mask) == self.max_words 122 | assert len(segment_ids) == self.max_words 123 | 124 | pairs_text[i] = np.array(input_ids) 125 | pairs_mask[i] = np.array(input_mask) 126 | pairs_segment[i] = np.array(segment_ids) 127 | 128 | return pairs_text, pairs_mask, pairs_segment, choice_video_ids 129 | 130 | def _get_rawvideo(self, choice_video_ids): 131 | video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.long) 132 | max_video_length = [0] * len(choice_video_ids) 133 | 134 | # Pair x L x T x 3 x H x W 135 | video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3, 136 | self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float) 137 | 138 | for i, video_id in enumerate(choice_video_ids): 139 | video_path = self.video_dict[video_id] 140 | 141 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path) 142 | raw_video_data = raw_video_data['video'] 143 | 144 | if len(raw_video_data.shape) > 3: 145 | raw_video_data_clip = raw_video_data 146 | # L x T x 3 x H x W 147 | raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) 148 | if self.max_frames < raw_video_slice.shape[0]: 149 | if self.slice_framepos == 0: 150 | video_slice = raw_video_slice[:self.max_frames, ...] 151 | elif self.slice_framepos == 1: 152 | video_slice = raw_video_slice[-self.max_frames:, ...] 153 | else: 154 | sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) 155 | video_slice = raw_video_slice[sample_indx, ...] 156 | else: 157 | video_slice = raw_video_slice 158 | 159 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) 160 | 161 | slice_len = video_slice.shape[0] 162 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len 163 | if slice_len < 1: 164 | pass 165 | else: 166 | video[i][:slice_len, ...] = video_slice 167 | else: 168 | print("video path: {} error. video id: {}".format(video_path, video_id)) 169 | 170 | for i, v_length in enumerate(max_video_length): 171 | video_mask[i][:v_length] = [1] * v_length 172 | 173 | return video, video_mask 174 | 175 | def __getitem__(self, idx): 176 | video_id, caption = self.sentences_dict[idx] 177 | 178 | pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(video_id, caption) 179 | video, video_mask = self._get_rawvideo(choice_video_ids) 180 | return pairs_text, pairs_mask, pairs_segment, video, video_mask 181 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CLIP4Clip: An Empirical Study of CLIP for End to End Video Clip Retrieval 2 | 3 | (**July 28, 2021**) Add ViT-B/16 with an extra `--pretrained_clip_name` 4 | 5 | (**Apr. 22, 2021**) First version 6 | 7 | The implementation of paper [**CLIP4Clip: An Empirical Study of CLIP for End to End Video Clip Retrieval**](https://arxiv.org/abs/2104.08860). 8 | 9 | CLIP4Clip is a video-text retrieval model based on [CLIP (ViT-B)](https://github.com/openai/CLIP). We investigate three similarity calculation approaches: parameter-free type, sequential type, and tight type, in this work. The model achieve SOTA results on MSR-VTT, MSVD, LSMDC, ActivityNet, and DiDeMo. 10 | 11 | ![CLIP4Clip](CLIP4Clip.png) 12 | 13 | ## Requirement 14 | ```sh 15 | # From CLIP 16 | conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0 17 | pip install ftfy regex tqdm 18 | pip install opencv-python boto3 requests pandas 19 | ``` 20 | 21 | ## Data Preparing 22 | 23 | **For MSRVTT** 24 | 25 | The official data and video links can be found in [link](http://ms-multimedia-challenge.com/2017/dataset). 26 | 27 | For the convenience, you can also download the splits and captions by, 28 | ```sh 29 | wget https://github.com/ArrowLuo/CLIP4Clip/releases/download/v0.0/msrvtt_data.zip 30 | ``` 31 | 32 | Besides, the raw videos can be found in [sharing](https://github.com/m-bain/frozen-in-time#-finetuning-benchmarks-msr-vtt) from *Frozen️ in Time*, i.e., 33 | ```sh 34 | wget https://www.robots.ox.ac.uk/~maxbain/frozen-in-time/data/MSRVTT.zip 35 | ``` 36 | 37 | **For MSVD** 38 | 39 | Raw videos can be download from [link](https://www.cs.utexas.edu/users/ml/clamp/videoDescription/). 40 | 41 | The splits and `raw_captions` can be found in the wonderful job [collaborative-experts](https://github.com/albanie/collaborative-experts/blob/master/misc/datasets/msvd/README.md). For the convenience, you can also download them by, 42 | ```sh 43 | wget https://github.com/ArrowLuo/CLIP4Clip/releases/download/v0.0/msvd_data.zip 44 | ``` 45 | 46 | **For LSMDC** 47 | 48 | You must obtain permission from MPII to download and use the data. The download link is [here](https://sites.google.com/site/describingmovies/download). 49 | The 1000 test clips data is [link](http://www.google.com/url?q=http%3A%2F%2Fdatasets.d2.mpi-inf.mpg.de%2FmovieDescription%2Fprotected%2Flsmdc2016%2FLSMDC16_challenge_1000_publictect.csv&sa=D&sntz=1&usg=AFQjCNGIaGVhCeb6zNfUs2UL1zNzoEtaSg). Read our paper and the [dataloader](./dataloaders/dataloader_lsmdc_retrieval.py) for more information. 50 | 51 | **For ActivityNet** 52 | 53 | The official websit has made the full dataset available on Google and Baidu drives, see more information at [here](http://activity-net.org/download.html) . The splits can be found in the job [collaborative-experts](https://github.com/albanie/collaborative-experts/tree/master/misc/datasets/activity-net). 54 | 55 | **For DiDeMo** 56 | 57 | Raw videos can be download from [LisaAnne/LocalizingMoments](https://github.com/LisaAnne/LocalizingMoments). The splits can be found in the job [collaborative-experts](https://github.com/albanie/collaborative-experts/tree/master/misc/datasets/didemo/README.md). 58 | 59 | 60 | ## Compress Video for Speed-up (optional) 61 | ```sh 62 | python preprocess/compress_video.py --input_root [raw_video_path] --output_root [compressed_video_path] 63 | ``` 64 | This script will compress the video to *3fps* with width *224* (or height *224*). Modify the variables for your customization. 65 | 66 | ## How to Run 67 | 68 | >`--features_path` is the video root path 69 | > 70 | >`--linear_patch` can be set with `2d` or `3d` 71 | > 72 | > `--sim_header` can be set with `meanP`, `seqLSTM`, `seqTransf`, or `tightTransf` 73 | > 74 | > `--pretrained_clip_name` can be set with `ViT-B/32` or `ViT-B/16` 75 | > 76 | > `--resume_model` can be used to reload the saved optimizer state to continuely train the model, **Note**: need to set the corresponding chechpoint via `--init_model` simultaneously. 77 | 78 | read our paper for more details on `--linear_patch` and `--sim_header`. Test more hyperparameters for better performance. 79 | 80 | Download CLIP (ViT-B/32) weight, 81 | ```sh 82 | wget -P ./modules https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt 83 | ``` 84 | or, download CLIP (ViT-B/16) weight, 85 | ```sh 86 | wget -P ./modules https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt 87 | ``` 88 | 89 | Then, run 90 | 91 | 92 | *The CLIP (ViT-B/32) is the default setting in the paper, replacing with the ViT-B/16 for better performance.* 93 | 94 | ### MSRVTT 95 | 96 | ```sh 97 | DATA_PATH=[Your MSRVTT data and videos path] 98 | python -m torch.distributed.launch --nproc_per_node=4 \ 99 | main_task_retrieval.py --do_train --num_thread_reader=0 \ 100 | --epochs=5 --batch_size=128 --n_display=50 \ 101 | --train_csv ${DATA_PATH}/MSRVTT_train.9k.csv \ 102 | --val_csv ${DATA_PATH}/MSRVTT_JSFUSION_test.csv \ 103 | --data_path ${DATA_PATH}/MSRVTT_data.json \ 104 | --features_path ${DATA_PATH}/MSRVTT_Videos \ 105 | --output_dir ckpts/ckpt_msrvtt_retrieval_looseType \ 106 | --lr 1e-4 --max_words 32 --max_frames 12 --batch_size_val 16 \ 107 | --datatype msrvtt --expand_msrvtt_sentences \ 108 | --feature_framerate 1 --coef_lr 1e-3 \ 109 | --freeze_layer_num 0 --slice_framepos 2 \ 110 | --loose_type --linear_patch 2d --sim_header meanP \ 111 | --pretrained_clip_name ViT-B/32 112 | ``` 113 | 114 | ### MSVD 115 | ```sh 116 | DATA_PATH=[Your MSVD data and videos path] 117 | python -m torch.distributed.launch --nproc_per_node=4 \ 118 | main_task_retrieval.py --do_train --num_thread_reader=2 \ 119 | --epochs=5 --batch_size=128 --n_display=50 \ 120 | --data_path ${DATA_PATH} \ 121 | --features_path ${DATA_PATH}/MSVD_Videos \ 122 | --output_dir ckpts/ckpt_msvd_retrieval_looseType \ 123 | --lr 1e-4 --max_words 32 --max_frames 12 --batch_size_val 16 \ 124 | --datatype msvd \ 125 | --feature_framerate 1 --coef_lr 1e-3 \ 126 | --freeze_layer_num 0 --slice_framepos 2 \ 127 | --loose_type --linear_patch 2d --sim_header meanP \ 128 | --pretrained_clip_name ViT-B/32 129 | ``` 130 | 131 | ### LSMDC 132 | ```sh 133 | DATA_PATH=[Your LSMDC data and videos path] 134 | python -m torch.distributed.launch --nproc_per_node=4 \ 135 | main_task_retrieval.py --do_train --num_thread_reader=2 \ 136 | --epochs=5 --batch_size=128 --n_display=50 \ 137 | --data_path ${DATA_PATH} \ 138 | --features_path ${DATA_PATH}/LSMDC_Videos \ 139 | --output_dir ckpts/ckpt_lsmdc_retrieval_looseType \ 140 | --lr 1e-4 --max_words 32 --max_frames 12 --batch_size_val 16 \ 141 | --datatype lsmdc --feature_framerate 1 --coef_lr 1e-3 \ 142 | --freeze_layer_num 0 --slice_framepos 2 \ 143 | --loose_type --linear_patch 2d --sim_header meanP \ 144 | --pretrained_clip_name ViT-B/32 145 | ``` 146 | 147 | ### ActivityNet 148 | ActivityNet is regarded as video-paragraph retrieval in our setting, thus, need more GPUs (or run with multi-node). 149 | ```sh 150 | DATA_PATH=[Your ActivityNet data and videos path] 151 | python -m torch.distributed.launch --nproc_per_node=8 \ 152 | main_task_retrieval.py --do_train --num_thread_reader=2 \ 153 | --epochs=5 --batch_size=128 --n_display=50 \ 154 | --data_path ${DATA_PATH} \ 155 | --features_path ${DATA_PATH}/Activity_Videos \ 156 | --output_dir ckpts/ckpt_activity_retrieval_looseType \ 157 | --lr 1e-4 --max_words 64 --max_frames 64 --batch_size_val 16 \ 158 | --datatype activity --feature_framerate 1 --coef_lr 1e-3 \ 159 | --freeze_layer_num 0 --slice_framepos 2 \ 160 | --loose_type --linear_patch 2d --sim_header meanP \ 161 | --pretrained_clip_name ViT-B/32 162 | ``` 163 | 164 | ### DiDeMo 165 | DiDeMo is regarded as video-paragraph retrieval in our setting, thus, need more GPUs (or run with multi-node). 166 | ```sh 167 | DATA_PATH=[Your DiDeMo data and videos path] 168 | python -m torch.distributed.launch --nproc_per_node=8 \ 169 | main_task_retrieval.py --do_train --num_thread_reader=2 \ 170 | --epochs=5 --batch_size=128 --n_display=50 \ 171 | --data_path ${DATA_PATH} \ 172 | --features_path ${DATA_PATH}/DiDeMo_Videos \ 173 | --output_dir ckpts/ckpt_didemo_retrieval_looseType \ 174 | --lr 1e-4 --max_words 64 --max_frames 64 --batch_size_val 16 \ 175 | --datatype didemo --feature_framerate 1 --coef_lr 1e-3 \ 176 | --freeze_layer_num 0 --slice_framepos 2 \ 177 | --loose_type --linear_patch 2d --sim_header meanP \ 178 | --pretrained_clip_name ViT-B/32 179 | ``` 180 | 181 | # Citation 182 | If you find CLIP4Clip useful in your work, you can cite the following paper: 183 | ```bibtex 184 | @Article{Luo2021CLIP4Clip, 185 | author = {Huaishao Luo and Lei Ji and Ming Zhong and Yang Chen and Wen Lei and Nan Duan and Tianrui Li}, 186 | title = {{CLIP4Clip}: An Empirical Study of CLIP for End to End Video Clip Retrieval}, 187 | journal = {arXiv preprint arXiv:2104.08860}, 188 | year = {2021}, 189 | } 190 | ``` 191 | 192 | # Acknowledgments 193 | Our code is based on [CLIP](https://github.com/openai/CLIP) and [UniVL](https://github.com/microsoft/UniVL). 194 | -------------------------------------------------------------------------------- /modules/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | 7 | import os 8 | import logging 9 | import shutil 10 | import tempfile 11 | import json 12 | from urllib.parse import urlparse 13 | from pathlib import Path 14 | from typing import Optional, Tuple, Union, IO, Callable, Set 15 | from hashlib import sha256 16 | from functools import wraps 17 | 18 | from tqdm import tqdm 19 | 20 | import boto3 21 | from botocore.exceptions import ClientError 22 | import requests 23 | 24 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 25 | 26 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 27 | Path.home() / '.pytorch_pretrained_bert')) 28 | 29 | 30 | def url_to_filename(url: str, etag: str = None) -> str: 31 | """ 32 | Convert `url` into a hashed filename in a repeatable way. 33 | If `etag` is specified, append its hash to the url's, delimited 34 | by a period. 35 | """ 36 | url_bytes = url.encode('utf-8') 37 | url_hash = sha256(url_bytes) 38 | filename = url_hash.hexdigest() 39 | 40 | if etag: 41 | etag_bytes = etag.encode('utf-8') 42 | etag_hash = sha256(etag_bytes) 43 | filename += '.' + etag_hash.hexdigest() 44 | 45 | return filename 46 | 47 | 48 | def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]: 49 | """ 50 | Return the url and etag (which may be ``None``) stored for `filename`. 51 | Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist. 52 | """ 53 | if cache_dir is None: 54 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 55 | if isinstance(cache_dir, Path): 56 | cache_dir = str(cache_dir) 57 | 58 | cache_path = os.path.join(cache_dir, filename) 59 | if not os.path.exists(cache_path): 60 | raise FileNotFoundError("file {} not found".format(cache_path)) 61 | 62 | meta_path = cache_path + '.json' 63 | if not os.path.exists(meta_path): 64 | raise FileNotFoundError("file {} not found".format(meta_path)) 65 | 66 | with open(meta_path) as meta_file: 67 | metadata = json.load(meta_file) 68 | url = metadata['url'] 69 | etag = metadata['etag'] 70 | 71 | return url, etag 72 | 73 | 74 | def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str: 75 | """ 76 | Given something that might be a URL (or might be a local path), 77 | determine which. If it's a URL, download the file and cache it, and 78 | return the path to the cached file. If it's already a local path, 79 | make sure the file exists and then return the path. 80 | """ 81 | if cache_dir is None: 82 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 83 | if isinstance(url_or_filename, Path): 84 | url_or_filename = str(url_or_filename) 85 | if isinstance(cache_dir, Path): 86 | cache_dir = str(cache_dir) 87 | 88 | parsed = urlparse(url_or_filename) 89 | 90 | if parsed.scheme in ('http', 'https', 's3'): 91 | # URL, so get it from the cache (downloading if necessary) 92 | return get_from_cache(url_or_filename, cache_dir) 93 | elif os.path.exists(url_or_filename): 94 | # File, and it exists. 95 | return url_or_filename 96 | elif parsed.scheme == '': 97 | # File, but it doesn't exist. 98 | raise FileNotFoundError("file {} not found".format(url_or_filename)) 99 | else: 100 | # Something unknown 101 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 102 | 103 | 104 | def split_s3_path(url: str) -> Tuple[str, str]: 105 | """Split a full s3 path into the bucket name and path.""" 106 | parsed = urlparse(url) 107 | if not parsed.netloc or not parsed.path: 108 | raise ValueError("bad s3 path {}".format(url)) 109 | bucket_name = parsed.netloc 110 | s3_path = parsed.path 111 | # Remove '/' at beginning of path. 112 | if s3_path.startswith("/"): 113 | s3_path = s3_path[1:] 114 | return bucket_name, s3_path 115 | 116 | 117 | def s3_request(func: Callable): 118 | """ 119 | Wrapper function for s3 requests in order to create more helpful error 120 | messages. 121 | """ 122 | 123 | @wraps(func) 124 | def wrapper(url: str, *args, **kwargs): 125 | try: 126 | return func(url, *args, **kwargs) 127 | except ClientError as exc: 128 | if int(exc.response["Error"]["Code"]) == 404: 129 | raise FileNotFoundError("file {} not found".format(url)) 130 | else: 131 | raise 132 | 133 | return wrapper 134 | 135 | 136 | @s3_request 137 | def s3_etag(url: str) -> Optional[str]: 138 | """Check ETag on S3 object.""" 139 | s3_resource = boto3.resource("s3") 140 | bucket_name, s3_path = split_s3_path(url) 141 | s3_object = s3_resource.Object(bucket_name, s3_path) 142 | return s3_object.e_tag 143 | 144 | 145 | @s3_request 146 | def s3_get(url: str, temp_file: IO) -> None: 147 | """Pull a file directly from S3.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 151 | 152 | 153 | def http_get(url: str, temp_file: IO) -> None: 154 | req = requests.get(url, stream=True) 155 | content_length = req.headers.get('Content-Length') 156 | total = int(content_length) if content_length is not None else None 157 | progress = tqdm(unit="B", total=total) 158 | for chunk in req.iter_content(chunk_size=1024): 159 | if chunk: # filter out keep-alive new chunks 160 | progress.update(len(chunk)) 161 | temp_file.write(chunk) 162 | progress.close() 163 | 164 | 165 | def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str: 166 | """ 167 | Given a URL, look for the corresponding dataset in the local cache. 168 | If it's not there, download it. Then return the path to the cached file. 169 | """ 170 | if cache_dir is None: 171 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 172 | if isinstance(cache_dir, Path): 173 | cache_dir = str(cache_dir) 174 | 175 | os.makedirs(cache_dir, exist_ok=True) 176 | 177 | # Get eTag to add to filename, if it exists. 178 | if url.startswith("s3://"): 179 | etag = s3_etag(url) 180 | else: 181 | response = requests.head(url, allow_redirects=True) 182 | if response.status_code != 200: 183 | raise IOError("HEAD request failed for url {} with status code {}" 184 | .format(url, response.status_code)) 185 | etag = response.headers.get("ETag") 186 | 187 | filename = url_to_filename(url, etag) 188 | 189 | # get cache path to put the file 190 | cache_path = os.path.join(cache_dir, filename) 191 | 192 | if not os.path.exists(cache_path): 193 | # Download to temporary file, then copy to cache dir once finished. 194 | # Otherwise you get corrupt cache entries if the download gets interrupted. 195 | with tempfile.NamedTemporaryFile() as temp_file: 196 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 197 | 198 | # GET file object 199 | if url.startswith("s3://"): 200 | s3_get(url, temp_file) 201 | else: 202 | http_get(url, temp_file) 203 | 204 | # we are copying the file before closing it, so flush to avoid truncation 205 | temp_file.flush() 206 | # shutil.copyfileobj() starts at the current position, so go to the start 207 | temp_file.seek(0) 208 | 209 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 210 | with open(cache_path, 'wb') as cache_file: 211 | shutil.copyfileobj(temp_file, cache_file) 212 | 213 | logger.info("creating metadata file for %s", cache_path) 214 | meta = {'url': url, 'etag': etag} 215 | meta_path = cache_path + '.json' 216 | with open(meta_path, 'w') as meta_file: 217 | json.dump(meta, meta_file) 218 | 219 | logger.info("removing temp file %s", temp_file.name) 220 | 221 | return cache_path 222 | 223 | 224 | def read_set_from_file(filename: str) -> Set[str]: 225 | ''' 226 | Extract a de-duped collection (set) of text from a file. 227 | Expected file format is one item per line. 228 | ''' 229 | collection = set() 230 | with open(filename, 'r', encoding='utf-8') as file_: 231 | for line in file_: 232 | collection.add(line.rstrip()) 233 | return collection 234 | 235 | 236 | def get_file_extension(path: str, dot=True, lower: bool = True): 237 | ext = os.path.splitext(path)[1] 238 | ext = ext if dot else ext[1:] 239 | return ext.lower() if lower else ext 240 | -------------------------------------------------------------------------------- /dataloaders/data_dataloaders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from dataloaders.dataloader_msrvtt_retrieval import MSRVTT_DataLoader 4 | from dataloaders.dataloader_msrvtt_retrieval import MSRVTT_TrainDataLoader 5 | from dataloaders.dataloader_msvd_retrieval import MSVD_DataLoader 6 | from dataloaders.dataloader_lsmdc_retrieval import LSMDC_DataLoader 7 | from dataloaders.dataloader_activitynet_retrieval import ActivityNet_DataLoader 8 | from dataloaders.dataloader_didemo_retrieval import DiDeMo_DataLoader 9 | 10 | def dataloader_msrvtt_train(args, tokenizer): 11 | msrvtt_dataset = MSRVTT_TrainDataLoader( 12 | csv_path=args.train_csv, 13 | json_path=args.data_path, 14 | features_path=args.features_path, 15 | max_words=args.max_words, 16 | feature_framerate=args.feature_framerate, 17 | tokenizer=tokenizer, 18 | max_frames=args.max_frames, 19 | unfold_sentences=args.expand_msrvtt_sentences, 20 | frame_order=args.train_frame_order, 21 | slice_framepos=args.slice_framepos, 22 | ) 23 | 24 | train_sampler = torch.utils.data.distributed.DistributedSampler(msrvtt_dataset) 25 | dataloader = DataLoader( 26 | msrvtt_dataset, 27 | batch_size=args.batch_size // args.n_gpu, 28 | num_workers=args.num_thread_reader, 29 | pin_memory=False, 30 | shuffle=(train_sampler is None), 31 | sampler=train_sampler, 32 | drop_last=True, 33 | ) 34 | 35 | return dataloader, len(msrvtt_dataset), train_sampler 36 | 37 | def dataloader_msrvtt_test(args, tokenizer, subset="test"): 38 | msrvtt_testset = MSRVTT_DataLoader( 39 | csv_path=args.val_csv, 40 | features_path=args.features_path, 41 | max_words=args.max_words, 42 | feature_framerate=args.feature_framerate, 43 | tokenizer=tokenizer, 44 | max_frames=args.max_frames, 45 | frame_order=args.eval_frame_order, 46 | slice_framepos=args.slice_framepos, 47 | ) 48 | dataloader_msrvtt = DataLoader( 49 | msrvtt_testset, 50 | batch_size=args.batch_size_val, 51 | num_workers=args.num_thread_reader, 52 | shuffle=False, 53 | drop_last=False, 54 | ) 55 | return dataloader_msrvtt, len(msrvtt_testset) 56 | 57 | 58 | def dataloader_msvd_train(args, tokenizer): 59 | msvd_dataset = MSVD_DataLoader( 60 | subset="train", 61 | data_path=args.data_path, 62 | features_path=args.features_path, 63 | max_words=args.max_words, 64 | feature_framerate=args.feature_framerate, 65 | tokenizer=tokenizer, 66 | max_frames=args.max_frames, 67 | frame_order=args.train_frame_order, 68 | slice_framepos=args.slice_framepos, 69 | ) 70 | 71 | train_sampler = torch.utils.data.distributed.DistributedSampler(msvd_dataset) 72 | dataloader = DataLoader( 73 | msvd_dataset, 74 | batch_size=args.batch_size // args.n_gpu, 75 | num_workers=args.num_thread_reader, 76 | pin_memory=False, 77 | shuffle=(train_sampler is None), 78 | sampler=train_sampler, 79 | drop_last=True, 80 | ) 81 | 82 | return dataloader, len(msvd_dataset), train_sampler 83 | 84 | def dataloader_msvd_test(args, tokenizer, subset="test"): 85 | msvd_testset = MSVD_DataLoader( 86 | subset=subset, 87 | data_path=args.data_path, 88 | features_path=args.features_path, 89 | max_words=args.max_words, 90 | feature_framerate=args.feature_framerate, 91 | tokenizer=tokenizer, 92 | max_frames=args.max_frames, 93 | frame_order=args.eval_frame_order, 94 | slice_framepos=args.slice_framepos, 95 | ) 96 | dataloader_msrvtt = DataLoader( 97 | msvd_testset, 98 | batch_size=args.batch_size_val, 99 | num_workers=args.num_thread_reader, 100 | shuffle=False, 101 | drop_last=False, 102 | ) 103 | return dataloader_msrvtt, len(msvd_testset) 104 | 105 | 106 | def dataloader_lsmdc_train(args, tokenizer): 107 | lsmdc_dataset = LSMDC_DataLoader( 108 | subset="train", 109 | data_path=args.data_path, 110 | features_path=args.features_path, 111 | max_words=args.max_words, 112 | feature_framerate=args.feature_framerate, 113 | tokenizer=tokenizer, 114 | max_frames=args.max_frames, 115 | frame_order=args.train_frame_order, 116 | slice_framepos=args.slice_framepos, 117 | ) 118 | 119 | train_sampler = torch.utils.data.distributed.DistributedSampler(lsmdc_dataset) 120 | dataloader = DataLoader( 121 | lsmdc_dataset, 122 | batch_size=args.batch_size // args.n_gpu, 123 | num_workers=args.num_thread_reader, 124 | pin_memory=False, 125 | shuffle=(train_sampler is None), 126 | sampler=train_sampler, 127 | drop_last=True, 128 | ) 129 | 130 | return dataloader, len(lsmdc_dataset), train_sampler 131 | 132 | def dataloader_lsmdc_test(args, tokenizer, subset="test"): 133 | lsmdc_testset = LSMDC_DataLoader( 134 | subset=subset, 135 | data_path=args.data_path, 136 | features_path=args.features_path, 137 | max_words=args.max_words, 138 | feature_framerate=args.feature_framerate, 139 | tokenizer=tokenizer, 140 | max_frames=args.max_frames, 141 | frame_order=args.eval_frame_order, 142 | slice_framepos=args.slice_framepos, 143 | ) 144 | dataloader_msrvtt = DataLoader( 145 | lsmdc_testset, 146 | batch_size=args.batch_size_val, 147 | num_workers=args.num_thread_reader, 148 | shuffle=False, 149 | drop_last=False, 150 | ) 151 | return dataloader_msrvtt, len(lsmdc_testset) 152 | 153 | 154 | def dataloader_activity_train(args, tokenizer): 155 | activity_dataset = ActivityNet_DataLoader( 156 | subset="train", 157 | data_path=args.data_path, 158 | features_path=args.features_path, 159 | max_words=args.max_words, 160 | feature_framerate=args.feature_framerate, 161 | tokenizer=tokenizer, 162 | max_frames=args.max_frames, 163 | frame_order=args.train_frame_order, 164 | slice_framepos=args.slice_framepos, 165 | ) 166 | 167 | train_sampler = torch.utils.data.distributed.DistributedSampler(activity_dataset) 168 | dataloader = DataLoader( 169 | activity_dataset, 170 | batch_size=args.batch_size // args.n_gpu, 171 | num_workers=args.num_thread_reader, 172 | pin_memory=False, 173 | shuffle=(train_sampler is None), 174 | sampler=train_sampler, 175 | drop_last=True, 176 | ) 177 | 178 | return dataloader, len(activity_dataset), train_sampler 179 | 180 | def dataloader_activity_test(args, tokenizer, subset="test"): 181 | activity_testset = ActivityNet_DataLoader( 182 | subset=subset, 183 | data_path=args.data_path, 184 | features_path=args.features_path, 185 | max_words=args.max_words, 186 | feature_framerate=args.feature_framerate, 187 | tokenizer=tokenizer, 188 | max_frames=args.max_frames, 189 | frame_order=args.eval_frame_order, 190 | slice_framepos=args.slice_framepos, 191 | ) 192 | dataloader_msrvtt = DataLoader( 193 | activity_testset, 194 | batch_size=args.batch_size_val, 195 | num_workers=args.num_thread_reader, 196 | shuffle=False, 197 | drop_last=False, 198 | ) 199 | return dataloader_msrvtt, len(activity_testset) 200 | 201 | 202 | def dataloader_didemo_train(args, tokenizer): 203 | didemo_dataset = DiDeMo_DataLoader( 204 | subset="train", 205 | data_path=args.data_path, 206 | features_path=args.features_path, 207 | max_words=args.max_words, 208 | feature_framerate=args.feature_framerate, 209 | tokenizer=tokenizer, 210 | max_frames=args.max_frames, 211 | frame_order=args.train_frame_order, 212 | slice_framepos=args.slice_framepos, 213 | ) 214 | 215 | train_sampler = torch.utils.data.distributed.DistributedSampler(didemo_dataset) 216 | dataloader = DataLoader( 217 | didemo_dataset, 218 | batch_size=args.batch_size // args.n_gpu, 219 | num_workers=args.num_thread_reader, 220 | pin_memory=False, 221 | shuffle=(train_sampler is None), 222 | sampler=train_sampler, 223 | drop_last=True, 224 | ) 225 | 226 | return dataloader, len(didemo_dataset), train_sampler 227 | 228 | def dataloader_didemo_test(args, tokenizer, subset="test"): 229 | didemo_testset = DiDeMo_DataLoader( 230 | subset=subset, 231 | data_path=args.data_path, 232 | features_path=args.features_path, 233 | max_words=args.max_words, 234 | feature_framerate=args.feature_framerate, 235 | tokenizer=tokenizer, 236 | max_frames=args.max_frames, 237 | frame_order=args.eval_frame_order, 238 | slice_framepos=args.slice_framepos, 239 | ) 240 | dataloader_didemo = DataLoader( 241 | didemo_testset, 242 | batch_size=args.batch_size_val, 243 | num_workers=args.num_thread_reader, 244 | shuffle=False, 245 | drop_last=False, 246 | ) 247 | return dataloader_didemo, len(didemo_testset) 248 | 249 | 250 | DATALOADER_DICT = {} 251 | DATALOADER_DICT["msrvtt"] = {"train":dataloader_msrvtt_train, "val":dataloader_msrvtt_test, "test":None} 252 | DATALOADER_DICT["msvd"] = {"train":dataloader_msvd_train, "val":dataloader_msvd_test, "test":dataloader_msvd_test} 253 | DATALOADER_DICT["lsmdc"] = {"train":dataloader_lsmdc_train, "val":dataloader_lsmdc_test, "test":dataloader_lsmdc_test} 254 | DATALOADER_DICT["activity"] = {"train":dataloader_activity_train, "val":dataloader_activity_test, "test":None} 255 | DATALOADER_DICT["didemo"] = {"train":dataloader_didemo_train, "val":dataloader_didemo_test, "test":dataloader_didemo_test} 256 | -------------------------------------------------------------------------------- /dataloaders/dataloader_lsmdc_retrieval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import os 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | import json 10 | import math 11 | from dataloaders.rawvideo_util import RawVideoExtractor 12 | 13 | class LSMDC_DataLoader(Dataset): 14 | """LSMDC dataset loader.""" 15 | def __init__( 16 | self, 17 | subset, 18 | data_path, 19 | features_path, 20 | tokenizer, 21 | max_words=30, 22 | feature_framerate=1.0, 23 | max_frames=100, 24 | image_resolution=224, 25 | frame_order=0, 26 | slice_framepos=0, 27 | ): 28 | self.data_path = data_path 29 | self.features_path = features_path 30 | self.feature_framerate = feature_framerate 31 | self.max_words = max_words 32 | self.max_frames = max_frames 33 | self.tokenizer = tokenizer 34 | # 0: ordinary order; 1: reverse order; 2: random order. 35 | self.frame_order = frame_order 36 | assert self.frame_order in [0, 1, 2] 37 | # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. 38 | self.slice_framepos = slice_framepos 39 | assert self.slice_framepos in [0, 1, 2] 40 | 41 | self.subset = subset 42 | assert self.subset in ["train", "val", "test"] 43 | 44 | video_json_path_dict = {} 45 | video_json_path_dict["train"] = os.path.join(self.data_path, "LSMDC16_annos_training.csv") 46 | video_json_path_dict["val"] = os.path.join(self.data_path, "LSMDC16_annos_val.csv") 47 | video_json_path_dict["test"] = os.path.join(self.data_path, "LSMDC16_challenge_1000_publictect.csv") 48 | 49 | # \t\t\t\t\t 50 | # is not a unique identifier, i.e. the same can be associated with multiple sentences. 51 | # However, LSMDC16_challenge_1000_publictect.csv has no repeat instances 52 | video_id_list = [] 53 | caption_dict = {} 54 | with open(video_json_path_dict[self.subset], 'r') as fp: 55 | for line in fp: 56 | line = line.strip() 57 | line_split = line.split("\t") 58 | assert len(line_split) == 6 59 | clip_id, start_aligned, end_aligned, start_extracted, end_extracted, sentence = line_split 60 | caption_dict[len(caption_dict)] = (clip_id, sentence) 61 | if clip_id not in video_id_list: video_id_list.append(clip_id) 62 | 63 | video_dict = {} 64 | for root, dub_dir, video_files in os.walk(self.features_path): 65 | for video_file in video_files: 66 | video_id_ = ".".join(video_file.split(".")[:-1]) 67 | if video_id_ not in video_id_list: 68 | continue 69 | file_path_ = os.path.join(root, video_file) 70 | video_dict[video_id_] = file_path_ 71 | 72 | self.video_dict = video_dict 73 | 74 | # Get all captions 75 | self.iter2video_pairs_dict = {} 76 | for clip_id, sentence in caption_dict.values(): 77 | if clip_id not in self.video_dict: 78 | continue 79 | self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (clip_id, sentence) 80 | 81 | self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) 82 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", 83 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} 84 | 85 | def __len__(self): 86 | return len(self.iter2video_pairs_dict) 87 | 88 | def _get_video_id_from_pseduo(self, pseudo_video_id): 89 | video_id = pseudo_video_id[2:] 90 | return video_id 91 | 92 | def _get_video_id_single(self, path): 93 | pseudo_video_id_list = [] 94 | video_id_list = [] 95 | print('Loading json: {}'.format(path)) 96 | with open(path, 'r') as f: 97 | json_data = json.load(f) 98 | 99 | for pseudo_video_id in json_data: 100 | if pseudo_video_id in pseudo_video_id_list: 101 | print("reduplicate.") 102 | else: 103 | video_id = self._get_video_id_from_pseduo(pseudo_video_id) 104 | pseudo_video_id_list.append(pseudo_video_id) 105 | video_id_list.append(video_id) 106 | return pseudo_video_id_list, video_id_list 107 | 108 | def _get_captions_single(self, path): 109 | pseudo_caption_dict = {} 110 | with open(path, 'r') as f: 111 | json_data = json.load(f) 112 | 113 | for pseudo_video_id, v_ in json_data.items(): 114 | pseudo_caption_dict[pseudo_video_id] = {} 115 | timestamps = v_["timestamps"] 116 | pseudo_caption_dict[pseudo_video_id]["start"] = \ 117 | np.array([int(math.floor(float(itm[0]))) for itm in timestamps], dtype=object) 118 | pseudo_caption_dict[pseudo_video_id]["end"] = \ 119 | np.array([int(math.ceil(float(itm[1]))) for itm in timestamps], dtype=object) 120 | pseudo_caption_dict[pseudo_video_id]["text"] = np.array(v_["sentences"], dtype=object) 121 | return pseudo_caption_dict 122 | 123 | def _get_text(self, video_id, caption): 124 | k = 1 125 | choice_video_ids = [video_id] 126 | pairs_text = np.zeros((k, self.max_words), dtype=np.long) 127 | pairs_mask = np.zeros((k, self.max_words), dtype=np.long) 128 | pairs_segment = np.zeros((k, self.max_words), dtype=np.long) 129 | 130 | for i, video_id in enumerate(choice_video_ids): 131 | words = self.tokenizer.tokenize(caption) 132 | 133 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words 134 | total_length_with_CLS = self.max_words - 1 135 | if len(words) > total_length_with_CLS: 136 | words = words[:total_length_with_CLS] 137 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] 138 | 139 | input_ids = self.tokenizer.convert_tokens_to_ids(words) 140 | input_mask = [1] * len(input_ids) 141 | segment_ids = [0] * len(input_ids) 142 | while len(input_ids) < self.max_words: 143 | input_ids.append(0) 144 | input_mask.append(0) 145 | segment_ids.append(0) 146 | assert len(input_ids) == self.max_words 147 | assert len(input_mask) == self.max_words 148 | assert len(segment_ids) == self.max_words 149 | 150 | pairs_text[i] = np.array(input_ids) 151 | pairs_mask[i] = np.array(input_mask) 152 | pairs_segment[i] = np.array(segment_ids) 153 | 154 | return pairs_text, pairs_mask, pairs_segment, choice_video_ids 155 | 156 | def _get_rawvideo(self, choice_video_ids): 157 | video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.long) 158 | max_video_length = [0] * len(choice_video_ids) 159 | 160 | # Pair x L x T x 3 x H x W 161 | video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3, 162 | self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float) 163 | 164 | try: 165 | for i, video_id in enumerate(choice_video_ids): 166 | video_path = self.video_dict[video_id] 167 | 168 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path) 169 | raw_video_data = raw_video_data['video'] 170 | 171 | if len(raw_video_data.shape) > 3: 172 | raw_video_data_clip = raw_video_data 173 | # L x T x 3 x H x W 174 | raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) 175 | if self.max_frames < raw_video_slice.shape[0]: 176 | if self.slice_framepos == 0: 177 | video_slice = raw_video_slice[:self.max_frames, ...] 178 | elif self.slice_framepos == 1: 179 | video_slice = raw_video_slice[-self.max_frames:, ...] 180 | else: 181 | sample_indx = np.linspace(0, raw_video_slice.shape[0]-1, num=self.max_frames, dtype=int) 182 | video_slice = raw_video_slice[sample_indx, ...] 183 | else: 184 | video_slice = raw_video_slice 185 | 186 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) 187 | 188 | slice_len = video_slice.shape[0] 189 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len 190 | if slice_len < 1: 191 | pass 192 | else: 193 | video[i][:slice_len, ...] = video_slice 194 | else: 195 | print("video path: {} error. video id: {}".format(video_path, video_id)) 196 | except Exception as excep: 197 | print("Video ids: {}".format(choice_video_ids)) 198 | raise excep 199 | 200 | for i, v_length in enumerate(max_video_length): 201 | video_mask[i][:v_length] = [1] * v_length 202 | return video, video_mask 203 | 204 | def __getitem__(self, feature_idx): 205 | clip_id, sentence = self.iter2video_pairs_dict[feature_idx] 206 | pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(clip_id, sentence) 207 | video, video_mask = self._get_rawvideo(choice_video_ids) 208 | return pairs_text, pairs_mask, pairs_segment, video, video_mask 209 | -------------------------------------------------------------------------------- /dataloaders/dataloader_didemo_retrieval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import os 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | import json 10 | from dataloaders.rawvideo_util import RawVideoExtractor 11 | 12 | class DiDeMo_DataLoader(Dataset): 13 | def __init__( 14 | self, 15 | subset, 16 | data_path, 17 | features_path, 18 | tokenizer, 19 | max_words=30, 20 | feature_framerate=1.0, 21 | max_frames=100, 22 | image_resolution=224, 23 | frame_order=0, 24 | slice_framepos=0, 25 | ): 26 | self.data_path = data_path 27 | self.features_path = features_path 28 | self.feature_framerate = feature_framerate 29 | self.max_words = max_words 30 | self.max_frames = max_frames 31 | self.tokenizer = tokenizer 32 | # 0: ordinary order; 1: reverse order; 2: random order. 33 | self.frame_order = frame_order 34 | assert self.frame_order in [0, 1, 2] 35 | # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. 36 | self.slice_framepos = slice_framepos 37 | assert self.slice_framepos in [0, 1, 2] 38 | 39 | self.subset = subset 40 | assert self.subset in ["train", "val", "test"] 41 | 42 | video_id_path_dict = {} 43 | video_id_path_dict["train"] = os.path.join(self.data_path, "train_list.txt") 44 | video_id_path_dict["val"] = os.path.join(self.data_path, "val_list.txt") 45 | video_id_path_dict["test"] = os.path.join(self.data_path, "test_list.txt") 46 | 47 | video_json_path_dict = {} 48 | video_json_path_dict["train"] = os.path.join(self.data_path, "train_data.json") 49 | video_json_path_dict["val"] = os.path.join(self.data_path, "val_data.json") 50 | video_json_path_dict["test"] = os.path.join(self.data_path, "test_data.json") 51 | 52 | with open(video_id_path_dict[self.subset], 'r') as fp: 53 | video_ids = [itm.strip() for itm in fp.readlines()] 54 | 55 | caption_dict = {} 56 | with open(video_json_path_dict[self.subset], 'r') as f: 57 | json_data = json.load(f) 58 | for itm in json_data: 59 | description = itm["description"] 60 | times = itm["times"] 61 | video = itm["video"] 62 | if video not in video_ids: 63 | continue 64 | 65 | # each video is split into 5-second temporal chunks 66 | # average the points from each annotator 67 | start_ = np.mean([t_[0] for t_ in times]) * 5 68 | end_ = (np.mean([t_[1] for t_ in times]) + 1) * 5 69 | if video in caption_dict: 70 | caption_dict[video]["start"].append(start_) 71 | caption_dict[video]["end"].append(end_) 72 | caption_dict[video]["text"].append(description) 73 | else: 74 | caption_dict[video] = {} 75 | caption_dict[video]["start"] = [start_] 76 | caption_dict[video]["end"] = [end_] 77 | caption_dict[video]["text"] = [description] 78 | 79 | for k_ in caption_dict.keys(): 80 | caption_dict[k_]["start"] = [0] 81 | # trick to save time on obtaining each video length 82 | # [https://github.com/LisaAnne/LocalizingMoments/blob/master/README.md]: 83 | # Some videos are longer than 30 seconds. These videos were truncated to 30 seconds during annotation. 84 | caption_dict[k_]["end"] = [31] 85 | caption_dict[k_]["text"] = [" ".join(caption_dict[k_]["text"])] 86 | 87 | video_dict = {} 88 | for root, dub_dir, video_files in os.walk(self.features_path): 89 | for video_file in video_files: 90 | video_id_ = video_file 91 | if video_id_ not in video_ids: 92 | continue 93 | file_path_ = os.path.join(root, video_file) 94 | video_dict[video_id_] = file_path_ 95 | 96 | self.caption_dict = caption_dict 97 | self.video_dict = video_dict 98 | video_ids = list(set(video_ids) & set(self.caption_dict.keys()) & set(self.video_dict.keys())) 99 | 100 | # Get all captions 101 | self.iter2video_pairs_dict = {} 102 | for video_id in self.caption_dict.keys(): 103 | if video_id not in video_ids: 104 | continue 105 | caption = self.caption_dict[video_id] 106 | n_caption = len(caption['start']) 107 | for sub_id in range(n_caption): 108 | self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (video_id, sub_id) 109 | 110 | self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) 111 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", 112 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} 113 | 114 | def __len__(self): 115 | return len(self.iter2video_pairs_dict) 116 | 117 | def _get_text(self, video_id, sub_id): 118 | caption = self.caption_dict[video_id] 119 | k = 1 120 | r_ind = [sub_id] 121 | 122 | starts = np.zeros(k, dtype=np.long) 123 | ends = np.zeros(k, dtype=np.long) 124 | pairs_text = np.zeros((k, self.max_words), dtype=np.long) 125 | pairs_mask = np.zeros((k, self.max_words), dtype=np.long) 126 | pairs_segment = np.zeros((k, self.max_words), dtype=np.long) 127 | 128 | for i in range(k): 129 | ind = r_ind[i] 130 | start_, end_ = caption['start'][ind], caption['end'][ind] 131 | words = self.tokenizer.tokenize(caption['text'][ind]) 132 | starts[i], ends[i] = start_, end_ 133 | 134 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words 135 | total_length_with_CLS = self.max_words - 1 136 | if len(words) > total_length_with_CLS: 137 | words = words[:total_length_with_CLS] 138 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] 139 | 140 | input_ids = self.tokenizer.convert_tokens_to_ids(words) 141 | input_mask = [1] * len(input_ids) 142 | segment_ids = [0] * len(input_ids) 143 | while len(input_ids) < self.max_words: 144 | input_ids.append(0) 145 | input_mask.append(0) 146 | segment_ids.append(0) 147 | assert len(input_ids) == self.max_words 148 | assert len(input_mask) == self.max_words 149 | assert len(segment_ids) == self.max_words 150 | 151 | pairs_text[i] = np.array(input_ids) 152 | pairs_mask[i] = np.array(input_mask) 153 | pairs_segment[i] = np.array(segment_ids) 154 | 155 | return pairs_text, pairs_mask, pairs_segment, starts, ends 156 | 157 | def _get_rawvideo(self, idx, s, e): 158 | video_mask = np.zeros((len(s), self.max_frames), dtype=np.long) 159 | max_video_length = [0] * len(s) 160 | 161 | # Pair x L x T x 3 x H x W 162 | video = np.zeros((len(s), self.max_frames, 1, 3, 163 | self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float) 164 | video_path = self.video_dict[idx] 165 | 166 | try: 167 | for i in range(len(s)): 168 | start_time = int(s[i]) 169 | end_time = int(e[i]) 170 | start_time = start_time if start_time >= 0. else 0. 171 | end_time = end_time if end_time >= 0. else 0. 172 | if start_time > end_time: 173 | start_time, end_time = end_time, start_time 174 | elif start_time == end_time: 175 | end_time = end_time + 1 176 | 177 | cache_id = "{}_{}_{}".format(video_path, start_time, end_time) 178 | # Should be optimized by gathering all asking of this video 179 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path, start_time, end_time) 180 | raw_video_data = raw_video_data['video'] 181 | 182 | if len(raw_video_data.shape) > 3: 183 | raw_video_data_clip = raw_video_data 184 | # L x T x 3 x H x W 185 | raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) 186 | if self.max_frames < raw_video_slice.shape[0]: 187 | if self.slice_framepos == 0: 188 | video_slice = raw_video_slice[:self.max_frames, ...] 189 | elif self.slice_framepos == 1: 190 | video_slice = raw_video_slice[-self.max_frames:, ...] 191 | else: 192 | sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) 193 | video_slice = raw_video_slice[sample_indx, ...] 194 | else: 195 | video_slice = raw_video_slice 196 | 197 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) 198 | 199 | slice_len = video_slice.shape[0] 200 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len 201 | if slice_len < 1: 202 | pass 203 | else: 204 | video[i][:slice_len, ...] = video_slice 205 | else: 206 | print("video path: {} error. video id: {}, start: {}, end: {}".format(video_path, idx, start_time, end_time)) 207 | except Exception as excep: 208 | print("video path: {} error. video id: {}, start: {}, end: {}, Error: {}".format(video_path, idx, s, e, excep)) 209 | pass 210 | # raise e 211 | 212 | for i, v_length in enumerate(max_video_length): 213 | video_mask[i][:v_length] = [1] * v_length 214 | 215 | return video, video_mask 216 | 217 | def __getitem__(self, feature_idx): 218 | video_id, sub_id = self.iter2video_pairs_dict[feature_idx] 219 | 220 | pairs_text, pairs_mask, pairs_segment, starts, ends = self._get_text(video_id, sub_id) 221 | video, video_mask = self._get_rawvideo(video_id, starts, ends) 222 | return pairs_text, pairs_mask, pairs_segment, video, video_mask 223 | -------------------------------------------------------------------------------- /modules/module_cross.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import copy 7 | import json 8 | import math 9 | import logging 10 | import tarfile 11 | import tempfile 12 | import shutil 13 | 14 | import torch 15 | from torch import nn 16 | import torch.nn.functional as F 17 | from .file_utils import cached_path 18 | from .until_config import PretrainedConfig 19 | from .until_module import PreTrainedModel, LayerNorm, ACT2FN 20 | from collections import OrderedDict 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | PRETRAINED_MODEL_ARCHIVE_MAP = {} 25 | CONFIG_NAME = 'cross_config.json' 26 | WEIGHTS_NAME = 'cross_pytorch_model.bin' 27 | 28 | 29 | class CrossConfig(PretrainedConfig): 30 | """Configuration class to store the configuration of a `CrossModel`. 31 | """ 32 | pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP 33 | config_name = CONFIG_NAME 34 | weights_name = WEIGHTS_NAME 35 | def __init__(self, 36 | vocab_size_or_config_json_file, 37 | hidden_size=768, 38 | num_hidden_layers=12, 39 | num_attention_heads=12, 40 | intermediate_size=3072, 41 | hidden_act="gelu", 42 | hidden_dropout_prob=0.1, 43 | attention_probs_dropout_prob=0.1, 44 | max_position_embeddings=512, 45 | type_vocab_size=2, 46 | initializer_range=0.02): 47 | """Constructs CrossConfig. 48 | 49 | Args: 50 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CrossModel`. 51 | hidden_size: Size of the encoder layers and the pooler layer. 52 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 53 | num_attention_heads: Number of attention heads for each attention layer in 54 | the Transformer encoder. 55 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 56 | layer in the Transformer encoder. 57 | hidden_act: The non-linear activation function (function or string) in the 58 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 59 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 60 | layers in the embeddings, encoder, and pooler. 61 | attention_probs_dropout_prob: The dropout ratio for the attention 62 | probabilities. 63 | max_position_embeddings: The maximum sequence length that this model might 64 | ever be used with. Typically set this to something large just in case 65 | (e.g., 512 or 1024 or 2048). 66 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 67 | `CrossModel`. 68 | initializer_range: The sttdev of the truncated_normal_initializer for 69 | initializing all weight matrices. 70 | """ 71 | if isinstance(vocab_size_or_config_json_file, str): 72 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 73 | json_config = json.loads(reader.read()) 74 | for key, value in json_config.items(): 75 | self.__dict__[key] = value 76 | elif isinstance(vocab_size_or_config_json_file, int): 77 | self.vocab_size = vocab_size_or_config_json_file 78 | self.hidden_size = hidden_size 79 | self.num_hidden_layers = num_hidden_layers 80 | self.num_attention_heads = num_attention_heads 81 | self.hidden_act = hidden_act 82 | self.intermediate_size = intermediate_size 83 | self.hidden_dropout_prob = hidden_dropout_prob 84 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 85 | self.max_position_embeddings = max_position_embeddings 86 | self.type_vocab_size = type_vocab_size 87 | self.initializer_range = initializer_range 88 | else: 89 | raise ValueError("First argument must be either a vocabulary size (int)" 90 | "or the path to a pretrained model config file (str)") 91 | 92 | class QuickGELU(nn.Module): 93 | def forward(self, x: torch.Tensor): 94 | return x * torch.sigmoid(1.702 * x) 95 | 96 | class ResidualAttentionBlock(nn.Module): 97 | def __init__(self, d_model: int, n_head: int): 98 | super().__init__() 99 | 100 | self.attn = nn.MultiheadAttention(d_model, n_head) 101 | self.ln_1 = LayerNorm(d_model) 102 | self.mlp = nn.Sequential(OrderedDict([ 103 | ("c_fc", nn.Linear(d_model, d_model * 4)), 104 | ("gelu", QuickGELU()), 105 | ("c_proj", nn.Linear(d_model * 4, d_model)) 106 | ])) 107 | self.ln_2 = LayerNorm(d_model) 108 | self.n_head = n_head 109 | 110 | def attention(self, x: torch.Tensor, attn_mask: torch.Tensor): 111 | attn_mask_ = attn_mask.repeat_interleave(self.n_head, dim=0) 112 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] 113 | 114 | def forward(self, para_tuple: tuple): 115 | # x: torch.Tensor, attn_mask: torch.Tensor 116 | # print(para_tuple) 117 | x, attn_mask = para_tuple 118 | x = x + self.attention(self.ln_1(x), attn_mask) 119 | x = x + self.mlp(self.ln_2(x)) 120 | return (x, attn_mask) 121 | 122 | class Transformer(nn.Module): 123 | def __init__(self, width: int, layers: int, heads: int): 124 | super().__init__() 125 | self.width = width 126 | self.layers = layers 127 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads) for _ in range(layers)]) 128 | 129 | def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): 130 | return self.resblocks((x, attn_mask))[0] 131 | 132 | class CrossEmbeddings(nn.Module): 133 | """Construct the embeddings from word, position and token_type embeddings. 134 | """ 135 | def __init__(self, config): 136 | super(CrossEmbeddings, self).__init__() 137 | 138 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 139 | # self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 140 | # self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12) 141 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 142 | 143 | def forward(self, concat_embeddings, concat_type=None): 144 | 145 | batch_size, seq_length = concat_embeddings.size(0), concat_embeddings.size(1) 146 | # if concat_type is None: 147 | # concat_type = torch.zeros(batch_size, concat_type).to(concat_embeddings.device) 148 | 149 | position_ids = torch.arange(seq_length, dtype=torch.long, device=concat_embeddings.device) 150 | position_ids = position_ids.unsqueeze(0).expand(concat_embeddings.size(0), -1) 151 | 152 | # token_type_embeddings = self.token_type_embeddings(concat_type) 153 | position_embeddings = self.position_embeddings(position_ids) 154 | 155 | embeddings = concat_embeddings + position_embeddings # + token_type_embeddings 156 | # embeddings = self.LayerNorm(embeddings) 157 | embeddings = self.dropout(embeddings) 158 | return embeddings 159 | 160 | class CrossPooler(nn.Module): 161 | def __init__(self, config): 162 | super(CrossPooler, self).__init__() 163 | self.ln_pool = LayerNorm(config.hidden_size) 164 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 165 | self.activation = QuickGELU() 166 | 167 | def forward(self, hidden_states, hidden_mask): 168 | # We "pool" the model by simply taking the hidden state corresponding 169 | # to the first token. 170 | hidden_states = self.ln_pool(hidden_states) 171 | pooled_output = hidden_states[:, 0] 172 | pooled_output = self.dense(pooled_output) 173 | pooled_output = self.activation(pooled_output) 174 | return pooled_output 175 | 176 | class CrossModel(PreTrainedModel): 177 | 178 | def initialize_parameters(self): 179 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 180 | attn_std = self.transformer.width ** -0.5 181 | fc_std = (2 * self.transformer.width) ** -0.5 182 | for block in self.transformer.resblocks: 183 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 184 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 185 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 186 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 187 | 188 | def __init__(self, config): 189 | super(CrossModel, self).__init__(config) 190 | 191 | self.embeddings = CrossEmbeddings(config) 192 | 193 | transformer_width = config.hidden_size 194 | transformer_layers = config.num_hidden_layers 195 | transformer_heads = config.num_attention_heads 196 | self.transformer = Transformer(width=transformer_width, layers=transformer_layers, heads=transformer_heads,) 197 | self.pooler = CrossPooler(config) 198 | self.apply(self.init_weights) 199 | 200 | def build_attention_mask(self, attention_mask): 201 | extended_attention_mask = attention_mask.unsqueeze(1) 202 | extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility 203 | extended_attention_mask = (1.0 - extended_attention_mask) * -1000000.0 204 | extended_attention_mask = extended_attention_mask.expand(-1, attention_mask.size(1), -1) 205 | return extended_attention_mask 206 | 207 | def forward(self, concat_input, concat_type=None, attention_mask=None, output_all_encoded_layers=True): 208 | 209 | if attention_mask is None: 210 | attention_mask = torch.ones(concat_input.size(0), concat_input.size(1)) 211 | if concat_type is None: 212 | concat_type = torch.zeros_like(attention_mask) 213 | 214 | extended_attention_mask = self.build_attention_mask(attention_mask) 215 | 216 | embedding_output = self.embeddings(concat_input, concat_type) 217 | embedding_output = embedding_output.permute(1, 0, 2) # NLD -> LND 218 | embedding_output = self.transformer(embedding_output, extended_attention_mask) 219 | embedding_output = embedding_output.permute(1, 0, 2) # LND -> NLD 220 | 221 | pooled_output = self.pooler(embedding_output, hidden_mask=attention_mask) 222 | 223 | return embedding_output, pooled_output 224 | -------------------------------------------------------------------------------- /dataloaders/dataloader_activitynet_retrieval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import os 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | import json 10 | import math 11 | from dataloaders.rawvideo_util import RawVideoExtractor 12 | 13 | class ActivityNet_DataLoader(Dataset): 14 | def __init__( 15 | self, 16 | subset, 17 | data_path, 18 | features_path, 19 | tokenizer, 20 | max_words=30, 21 | feature_framerate=1.0, 22 | max_frames=100, 23 | image_resolution=224, 24 | frame_order=0, 25 | slice_framepos=0, 26 | ): 27 | self.data_path = data_path 28 | self.features_path = features_path 29 | self.feature_framerate = feature_framerate 30 | self.max_words = max_words 31 | self.max_frames = max_frames 32 | self.tokenizer = tokenizer 33 | # 0: ordinary order; 1: reverse order; 2: random order. 34 | self.frame_order = frame_order 35 | assert self.frame_order in [0, 1, 2] 36 | # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. 37 | self.slice_framepos = slice_framepos 38 | assert self.slice_framepos in [0, 1, 2] 39 | 40 | self.subset = subset 41 | assert self.subset in ["train", "val"] 42 | 43 | video_id_path_dict = {} 44 | video_id_path_dict["train"] = os.path.join(self.data_path, "train_ids.json") 45 | video_id_path_dict["val"] = os.path.join(self.data_path, "val_ids.json") 46 | 47 | video_json_path_dict = {} 48 | video_json_path_dict["train"] = os.path.join(self.data_path, "train.json") 49 | video_json_path_dict["val"] = os.path.join(self.data_path, "val_1.json") 50 | 51 | pseudo_video_id_list, video_id_list = self._get_video_id_single(video_id_path_dict[self.subset]) 52 | pseudo_caption_dict = self._get_captions_single(video_json_path_dict[self.subset]) 53 | 54 | print("video id list: {}".format(len(video_id_list))) 55 | print("pseudo caption dict: {}".format(len(pseudo_caption_dict.keys()))) 56 | 57 | video_dict = {} 58 | for root, dub_dir, video_files in os.walk(self.features_path): 59 | for video_file in video_files: 60 | video_id_ = ".".join(video_file.split(".")[:-1]) 61 | if video_id_ not in video_id_list: 62 | continue 63 | file_path_ = os.path.join(root, video_file) 64 | video_dict[video_id_] = file_path_ 65 | self.video_dict = video_dict 66 | print("video dict: {}".format(len(video_dict))) 67 | 68 | self.pseudo_video_id_list = pseudo_video_id_list 69 | self.video_id_list = video_id_list 70 | self.pseudo_caption_dict = pseudo_caption_dict 71 | 72 | # Get iterator video ids 73 | self.video_id2idx_dict = {pseudo_video_id: id for id, pseudo_video_id in enumerate(self.pseudo_video_id_list)} 74 | # Get all captions 75 | self.iter2video_pairs_dict = {} 76 | for pseudo_video_id, video_id in zip(self.pseudo_video_id_list, self.video_id_list): 77 | if pseudo_video_id not in self.pseudo_caption_dict or video_id not in self.video_dict: 78 | continue 79 | caption = self.pseudo_caption_dict[pseudo_video_id] 80 | n_caption = len(caption['start']) 81 | for sub_id in range(n_caption): 82 | self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (pseudo_video_id, sub_id) 83 | 84 | self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) 85 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", 86 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} 87 | 88 | def __len__(self): 89 | return len(self.iter2video_pairs_dict) 90 | 91 | def _get_video_id_from_pseduo(self, pseudo_video_id): 92 | video_id = pseudo_video_id[2:] 93 | return video_id 94 | 95 | def _get_video_id_single(self, path): 96 | pseudo_video_id_list = [] 97 | video_id_list = [] 98 | print('Loading json: {}'.format(path)) 99 | with open(path, 'r') as f: 100 | json_data = json.load(f) 101 | 102 | for pseudo_video_id in json_data: 103 | if pseudo_video_id in pseudo_video_id_list: 104 | print("reduplicate.") 105 | else: 106 | video_id = self._get_video_id_from_pseduo(pseudo_video_id) 107 | pseudo_video_id_list.append(pseudo_video_id) 108 | video_id_list.append(video_id) 109 | return pseudo_video_id_list, video_id_list 110 | 111 | def _get_captions_single(self, path): 112 | pseudo_caption_dict = {} 113 | with open(path, 'r') as f: 114 | json_data = json.load(f) 115 | 116 | for pseudo_video_id, v_ in json_data.items(): 117 | pseudo_caption_dict[pseudo_video_id] = {} 118 | duration = v_["duration"] 119 | pseudo_caption_dict[pseudo_video_id]["start"] = np.array([0], dtype=object) 120 | pseudo_caption_dict[pseudo_video_id]["end"] = np.array([int(math.ceil(float(duration)))], dtype=object) 121 | pseudo_caption_dict[pseudo_video_id]["text"] = np.array([" ".join(v_["sentences"])], dtype=object) 122 | return pseudo_caption_dict 123 | 124 | def _get_text(self, pseudo_video_id, sub_id): 125 | caption = self.pseudo_caption_dict[pseudo_video_id] 126 | k = 1 127 | r_ind = [sub_id] 128 | 129 | starts = np.zeros(k, dtype=np.long) 130 | ends = np.zeros(k, dtype=np.long) 131 | pairs_text = np.zeros((k, self.max_words), dtype=np.long) 132 | pairs_mask = np.zeros((k, self.max_words), dtype=np.long) 133 | pairs_segment = np.zeros((k, self.max_words), dtype=np.long) 134 | 135 | for i in range(k): 136 | ind = r_ind[i] 137 | start_, end_ = caption['start'][ind], caption['end'][ind] 138 | words = self.tokenizer.tokenize(caption['text'][ind]) 139 | starts[i], ends[i] = start_, end_ 140 | 141 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words 142 | total_length_with_CLS = self.max_words - 1 143 | if len(words) > total_length_with_CLS: 144 | words = words[:total_length_with_CLS] 145 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] 146 | 147 | input_ids = self.tokenizer.convert_tokens_to_ids(words) 148 | input_mask = [1] * len(input_ids) 149 | segment_ids = [0] * len(input_ids) 150 | while len(input_ids) < self.max_words: 151 | input_ids.append(0) 152 | input_mask.append(0) 153 | segment_ids.append(0) 154 | assert len(input_ids) == self.max_words 155 | assert len(input_mask) == self.max_words 156 | assert len(segment_ids) == self.max_words 157 | 158 | pairs_text[i] = np.array(input_ids) 159 | pairs_mask[i] = np.array(input_mask) 160 | pairs_segment[i] = np.array(segment_ids) 161 | 162 | return pairs_text, pairs_mask, pairs_segment, starts, ends 163 | 164 | def _get_rawvideo(self, idx, s, e): 165 | video_mask = np.zeros((len(s), self.max_frames), dtype=np.long) 166 | max_video_length = [0] * len(s) 167 | 168 | # Pair x L x T x 3 x H x W 169 | video = np.zeros((len(s), self.max_frames, 1, 3, 170 | self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float) 171 | video_path = self.video_dict[idx] 172 | try: 173 | for i in range(len(s)): 174 | start_time = int(s[i]) 175 | end_time = int(e[i]) 176 | start_time = start_time if start_time >= 0. else 0. 177 | end_time = end_time if end_time >= 0. else 0. 178 | if start_time > end_time: 179 | start_time, end_time = end_time, start_time 180 | elif start_time == end_time: 181 | end_time = end_time + 1 182 | 183 | # Should be optimized by gathering all asking of this video 184 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path, start_time, end_time) 185 | raw_video_data = raw_video_data['video'] 186 | 187 | if len(raw_video_data.shape) > 3: 188 | raw_video_data_clip = raw_video_data 189 | # L x T x 3 x H x W 190 | raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) 191 | if self.max_frames < raw_video_slice.shape[0]: 192 | if self.slice_framepos == 0: 193 | video_slice = raw_video_slice[:self.max_frames, ...] 194 | elif self.slice_framepos == 1: 195 | video_slice = raw_video_slice[-self.max_frames:, ...] 196 | else: 197 | sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) 198 | video_slice = raw_video_slice[sample_indx, ...] 199 | else: 200 | video_slice = raw_video_slice 201 | 202 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) 203 | 204 | slice_len = video_slice.shape[0] 205 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len 206 | if slice_len < 1: 207 | pass 208 | else: 209 | video[i][:slice_len, ...] = video_slice 210 | else: 211 | print("video path: {} error. video id: {}, start: {}, end: {}".format(video_path, idx, start_time, end_time)) 212 | except Exception as excep: 213 | print("video path: {} error. video id: {}, start: {}, end: {}, Error: {}".format(video_path, idx, s, e, excep)) 214 | raise excep 215 | 216 | for i, v_length in enumerate(max_video_length): 217 | video_mask[i][:v_length] = [1] * v_length 218 | 219 | return video, video_mask 220 | 221 | def __getitem__(self, feature_idx): 222 | pseudo_video_id, sub_id = self.iter2video_pairs_dict[feature_idx] 223 | idx = self.video_id2idx_dict[pseudo_video_id] 224 | 225 | pairs_text, pairs_mask, pairs_segment, starts, ends = self._get_text(pseudo_video_id, sub_id) 226 | video, video_mask = self._get_rawvideo(self.video_id_list[idx], starts, ends) 227 | return pairs_text, pairs_mask, pairs_segment, video, video_mask 228 | -------------------------------------------------------------------------------- /modules/until_module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model.""" 17 | 18 | import logging 19 | import numpy as np 20 | import torch 21 | from torch import nn 22 | import torch.nn.functional as F 23 | import math 24 | from modules.until_config import PretrainedConfig 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | def gelu(x): 29 | """Implementation of the gelu activation function. 30 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 31 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 32 | """ 33 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 34 | 35 | def swish(x): 36 | return x * torch.sigmoid(x) 37 | 38 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 39 | 40 | class LayerNorm(nn.Module): 41 | def __init__(self, hidden_size, eps=1e-12): 42 | """Construct a layernorm module in the TF style (epsilon inside the square root). 43 | """ 44 | super(LayerNorm, self).__init__() 45 | self.weight = nn.Parameter(torch.ones(hidden_size)) 46 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 47 | self.variance_epsilon = eps 48 | 49 | def forward(self, x): 50 | u = x.mean(-1, keepdim=True) 51 | s = (x - u).pow(2).mean(-1, keepdim=True) 52 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 53 | return self.weight * x + self.bias 54 | 55 | class PreTrainedModel(nn.Module): 56 | """ An abstract class to handle weights initialization and 57 | a simple interface for dowloading and loading pretrained models. 58 | """ 59 | def __init__(self, config, *inputs, **kwargs): 60 | super(PreTrainedModel, self).__init__() 61 | if not isinstance(config, PretrainedConfig): 62 | raise ValueError( 63 | "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. " 64 | "To create a model from a Google pretrained model use " 65 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 66 | self.__class__.__name__, self.__class__.__name__ 67 | )) 68 | self.config = config 69 | 70 | def init_weights(self, module): 71 | """ Initialize the weights. 72 | """ 73 | if isinstance(module, (nn.Linear, nn.Embedding)): 74 | # Slightly different from the TF version which uses truncated_normal for initialization 75 | # cf https://github.com/pytorch/pytorch/pull/5617 76 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 77 | elif isinstance(module, LayerNorm): 78 | if 'beta' in dir(module) and 'gamma' in dir(module): 79 | module.beta.data.zero_() 80 | module.gamma.data.fill_(1.0) 81 | else: 82 | module.bias.data.zero_() 83 | module.weight.data.fill_(1.0) 84 | if isinstance(module, nn.Linear) and module.bias is not None: 85 | module.bias.data.zero_() 86 | 87 | def resize_token_embeddings(self, new_num_tokens=None): 88 | raise NotImplementedError 89 | 90 | @classmethod 91 | def init_preweight(cls, model, state_dict, prefix=None, task_config=None): 92 | old_keys = [] 93 | new_keys = [] 94 | for key in state_dict.keys(): 95 | new_key = None 96 | if 'gamma' in key: 97 | new_key = key.replace('gamma', 'weight') 98 | if 'beta' in key: 99 | new_key = key.replace('beta', 'bias') 100 | if new_key: 101 | old_keys.append(key) 102 | new_keys.append(new_key) 103 | for old_key, new_key in zip(old_keys, new_keys): 104 | state_dict[new_key] = state_dict.pop(old_key) 105 | 106 | if prefix is not None: 107 | old_keys = [] 108 | new_keys = [] 109 | for key in state_dict.keys(): 110 | old_keys.append(key) 111 | new_keys.append(prefix + key) 112 | for old_key, new_key in zip(old_keys, new_keys): 113 | state_dict[new_key] = state_dict.pop(old_key) 114 | 115 | missing_keys = [] 116 | unexpected_keys = [] 117 | error_msgs = [] 118 | # copy state_dict so _load_from_state_dict can modify it 119 | metadata = getattr(state_dict, '_metadata', None) 120 | state_dict = state_dict.copy() 121 | if metadata is not None: 122 | state_dict._metadata = metadata 123 | 124 | def load(module, prefix=''): 125 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 126 | module._load_from_state_dict( 127 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 128 | for name, child in module._modules.items(): 129 | if child is not None: 130 | load(child, prefix + name + '.') 131 | 132 | load(model, prefix='') 133 | 134 | if prefix is None and (task_config is None or task_config.local_rank == 0): 135 | logger.info("-" * 20) 136 | if len(missing_keys) > 0: 137 | logger.info("Weights of {} not initialized from pretrained model: {}" 138 | .format(model.__class__.__name__, "\n " + "\n ".join(missing_keys))) 139 | if len(unexpected_keys) > 0: 140 | logger.info("Weights from pretrained model not used in {}: {}" 141 | .format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys))) 142 | if len(error_msgs) > 0: 143 | logger.error("Weights from pretrained model cause errors in {}: {}" 144 | .format(model.__class__.__name__, "\n " + "\n ".join(error_msgs))) 145 | 146 | return model 147 | 148 | @property 149 | def dtype(self): 150 | """ 151 | :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). 152 | """ 153 | try: 154 | return next(self.parameters()).dtype 155 | except StopIteration: 156 | # For nn.DataParallel compatibility in PyTorch 1.5 157 | def find_tensor_attributes(module: nn.Module): 158 | tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] 159 | return tuples 160 | 161 | gen = self._named_members(get_members_fn=find_tensor_attributes) 162 | first_tuple = next(gen) 163 | return first_tuple[1].dtype 164 | 165 | @classmethod 166 | def from_pretrained(cls, config, state_dict=None, *inputs, **kwargs): 167 | """ 168 | Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict. 169 | Download and cache the pre-trained model file if needed. 170 | """ 171 | # Instantiate model. 172 | model = cls(config, *inputs, **kwargs) 173 | if state_dict is None: 174 | return model 175 | model = cls.init_preweight(model, state_dict) 176 | 177 | return model 178 | 179 | ################################## 180 | ###### LOSS FUNCTION ############# 181 | ################################## 182 | class CrossEn(nn.Module): 183 | def __init__(self,): 184 | super(CrossEn, self).__init__() 185 | 186 | def forward(self, sim_matrix): 187 | logpt = F.log_softmax(sim_matrix, dim=-1) 188 | logpt = torch.diag(logpt) 189 | nce_loss = -logpt 190 | sim_loss = nce_loss.mean() 191 | return sim_loss 192 | 193 | class MILNCELoss(nn.Module): 194 | def __init__(self, batch_size=1, n_pair=1,): 195 | super(MILNCELoss, self).__init__() 196 | self.batch_size = batch_size 197 | self.n_pair = n_pair 198 | torch_v = float(".".join(torch.__version__.split(".")[:2])) 199 | self.bool_dtype = torch.bool if torch_v >= 1.3 else torch.uint8 200 | 201 | def forward(self, sim_matrix): 202 | mm_mask = np.eye(self.batch_size) 203 | mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair))) 204 | mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device) 205 | 206 | from_text_matrix = sim_matrix + mm_mask * -1e12 207 | from_video_matrix = sim_matrix.transpose(1, 0) 208 | 209 | new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1) 210 | logpt = F.log_softmax(new_sim_matrix, dim=-1) 211 | 212 | mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1) 213 | masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12 214 | 215 | new_logpt = -torch.logsumexp(masked_logpt, dim=-1) 216 | 217 | logpt_choice = torch.zeros_like(new_logpt) 218 | mark_ind = torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair//2) 219 | logpt_choice[mark_ind] = 1 220 | sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=self.bool_dtype)).mean() 221 | return sim_loss 222 | 223 | class MaxMarginRankingLoss(nn.Module): 224 | def __init__(self, 225 | margin=1.0, 226 | negative_weighting=False, 227 | batch_size=1, 228 | n_pair=1, 229 | hard_negative_rate=0.5, 230 | ): 231 | super(MaxMarginRankingLoss, self).__init__() 232 | self.margin = margin 233 | self.n_pair = n_pair 234 | self.batch_size = batch_size 235 | easy_negative_rate = 1 - hard_negative_rate 236 | self.easy_negative_rate = easy_negative_rate 237 | self.negative_weighting = negative_weighting 238 | if n_pair > 1 and batch_size > 1: 239 | alpha = easy_negative_rate / ((batch_size - 1) * (1 - easy_negative_rate)) 240 | mm_mask = (1 - alpha) * np.eye(self.batch_size) + alpha 241 | mm_mask = np.kron(mm_mask, np.ones((n_pair, n_pair))) 242 | mm_mask = torch.tensor(mm_mask) * (batch_size * (1 - easy_negative_rate)) 243 | self.mm_mask = mm_mask.float() 244 | 245 | def forward(self, x): 246 | d = torch.diag(x) 247 | max_margin = F.relu(self.margin + x - d.view(-1, 1)) + \ 248 | F.relu(self.margin + x - d.view(1, -1)) 249 | if self.negative_weighting and self.n_pair > 1 and self.batch_size > 1: 250 | max_margin = max_margin * self.mm_mask.to(max_margin.device) 251 | return max_margin.mean() 252 | 253 | class AllGather(torch.autograd.Function): 254 | """An autograd function that performs allgather on a tensor.""" 255 | 256 | @staticmethod 257 | def forward(ctx, tensor, args): 258 | output = [torch.empty_like(tensor) for _ in range(args.world_size)] 259 | torch.distributed.all_gather(output, tensor) 260 | ctx.rank = args.rank 261 | ctx.batch_size = tensor.shape[0] 262 | return torch.cat(output, dim=0) 263 | 264 | @staticmethod 265 | def backward(ctx, grad_output): 266 | return ( 267 | grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)], 268 | None, 269 | ) 270 | -------------------------------------------------------------------------------- /dataloaders/dataloader_msrvtt_retrieval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import os 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | import pandas as pd 10 | from collections import defaultdict 11 | import json 12 | import random 13 | from dataloaders.rawvideo_util import RawVideoExtractor 14 | 15 | class MSRVTT_DataLoader(Dataset): 16 | """MSRVTT dataset loader.""" 17 | def __init__( 18 | self, 19 | csv_path, 20 | features_path, 21 | tokenizer, 22 | max_words=30, 23 | feature_framerate=1.0, 24 | max_frames=100, 25 | image_resolution=224, 26 | frame_order=0, 27 | slice_framepos=0, 28 | ): 29 | self.data = pd.read_csv(csv_path) 30 | self.features_path = features_path 31 | self.feature_framerate = feature_framerate 32 | self.max_words = max_words 33 | self.max_frames = max_frames 34 | self.tokenizer = tokenizer 35 | # 0: ordinary order; 1: reverse order; 2: random order. 36 | self.frame_order = frame_order 37 | assert self.frame_order in [0, 1, 2] 38 | # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. 39 | self.slice_framepos = slice_framepos 40 | assert self.slice_framepos in [0, 1, 2] 41 | 42 | self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) 43 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", 44 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} 45 | 46 | def __len__(self): 47 | return len(self.data) 48 | 49 | def _get_text(self, video_id, sentence): 50 | choice_video_ids = [video_id] 51 | n_caption = len(choice_video_ids) 52 | 53 | k = n_caption 54 | pairs_text = np.zeros((k, self.max_words), dtype=np.long) 55 | pairs_mask = np.zeros((k, self.max_words), dtype=np.long) 56 | pairs_segment = np.zeros((k, self.max_words), dtype=np.long) 57 | 58 | for i, video_id in enumerate(choice_video_ids): 59 | words = self.tokenizer.tokenize(sentence) 60 | 61 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words 62 | total_length_with_CLS = self.max_words - 1 63 | if len(words) > total_length_with_CLS: 64 | words = words[:total_length_with_CLS] 65 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] 66 | 67 | input_ids = self.tokenizer.convert_tokens_to_ids(words) 68 | input_mask = [1] * len(input_ids) 69 | segment_ids = [0] * len(input_ids) 70 | while len(input_ids) < self.max_words: 71 | input_ids.append(0) 72 | input_mask.append(0) 73 | segment_ids.append(0) 74 | assert len(input_ids) == self.max_words 75 | assert len(input_mask) == self.max_words 76 | assert len(segment_ids) == self.max_words 77 | 78 | pairs_text[i] = np.array(input_ids) 79 | pairs_mask[i] = np.array(input_mask) 80 | pairs_segment[i] = np.array(segment_ids) 81 | 82 | return pairs_text, pairs_mask, pairs_segment, choice_video_ids 83 | 84 | def _get_rawvideo(self, choice_video_ids): 85 | video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.long) 86 | max_video_length = [0] * len(choice_video_ids) 87 | 88 | # Pair x L x T x 3 x H x W 89 | video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3, 90 | self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float) 91 | 92 | for i, video_id in enumerate(choice_video_ids): 93 | # Individual for YoucokII dataset, due to it video format 94 | video_path = os.path.join(self.features_path, "{}.mp4".format(video_id)) 95 | if os.path.exists(video_path) is False: 96 | video_path = video_path.replace(".mp4", ".webm") 97 | 98 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path) 99 | raw_video_data = raw_video_data['video'] 100 | if len(raw_video_data.shape) > 3: 101 | raw_video_data_clip = raw_video_data 102 | # L x T x 3 x H x W 103 | raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) 104 | if self.max_frames < raw_video_slice.shape[0]: 105 | if self.slice_framepos == 0: 106 | video_slice = raw_video_slice[:self.max_frames, ...] 107 | elif self.slice_framepos == 1: 108 | video_slice = raw_video_slice[-self.max_frames:, ...] 109 | else: 110 | sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) 111 | video_slice = raw_video_slice[sample_indx, ...] 112 | else: 113 | video_slice = raw_video_slice 114 | 115 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) 116 | 117 | slice_len = video_slice.shape[0] 118 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len 119 | if slice_len < 1: 120 | pass 121 | else: 122 | video[i][:slice_len, ...] = video_slice 123 | else: 124 | print("video path: {} error. video id: {}".format(video_path, video_id)) 125 | 126 | for i, v_length in enumerate(max_video_length): 127 | video_mask[i][:v_length] = [1] * v_length 128 | 129 | return video, video_mask 130 | 131 | def __getitem__(self, idx): 132 | video_id = self.data['video_id'].values[idx] 133 | sentence = self.data['sentence'].values[idx] 134 | 135 | pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(video_id, sentence) 136 | video, video_mask = self._get_rawvideo(choice_video_ids) 137 | return pairs_text, pairs_mask, pairs_segment, video, video_mask 138 | 139 | class MSRVTT_TrainDataLoader(Dataset): 140 | """MSRVTT train dataset loader.""" 141 | def __init__( 142 | self, 143 | csv_path, 144 | json_path, 145 | features_path, 146 | tokenizer, 147 | max_words=30, 148 | feature_framerate=1.0, 149 | max_frames=100, 150 | unfold_sentences=False, 151 | image_resolution=224, 152 | frame_order=0, 153 | slice_framepos=0, 154 | ): 155 | self.csv = pd.read_csv(csv_path) 156 | self.data = json.load(open(json_path, 'r')) 157 | self.features_path = features_path 158 | self.feature_framerate = feature_framerate 159 | self.max_words = max_words 160 | self.max_frames = max_frames 161 | self.tokenizer = tokenizer 162 | # 0: ordinary order; 1: reverse order; 2: random order. 163 | self.frame_order = frame_order 164 | assert self.frame_order in [0, 1, 2] 165 | # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. 166 | self.slice_framepos = slice_framepos 167 | assert self.slice_framepos in [0, 1, 2] 168 | 169 | self.unfold_sentences = unfold_sentences 170 | self.sample_len = 0 171 | if self.unfold_sentences: 172 | train_video_ids = list(self.csv['video_id'].values) 173 | self.sentences_dict = {} 174 | for itm in self.data['sentences']: 175 | if itm['video_id'] in train_video_ids: 176 | self.sentences_dict[len(self.sentences_dict)] = (itm['video_id'], itm['caption']) 177 | self.sample_len = len(self.sentences_dict) 178 | else: 179 | num_sentences = 0 180 | self.sentences = defaultdict(list) 181 | s_video_id_set = set() 182 | for itm in self.data['sentences']: 183 | self.sentences[itm['video_id']].append(itm['caption']) 184 | num_sentences += 1 185 | s_video_id_set.add(itm['video_id']) 186 | 187 | # Use to find the clips in the same video 188 | self.parent_ids = {} 189 | self.children_video_ids = defaultdict(list) 190 | for itm in self.data['videos']: 191 | vid = itm["video_id"] 192 | url_posfix = itm["url"].split("?v=")[-1] 193 | self.parent_ids[vid] = url_posfix 194 | self.children_video_ids[url_posfix].append(vid) 195 | self.sample_len = len(self.csv) 196 | 197 | self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) 198 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", 199 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} 200 | 201 | def __len__(self): 202 | return self.sample_len 203 | 204 | def _get_text(self, video_id, caption=None): 205 | k = 1 206 | choice_video_ids = [video_id] 207 | pairs_text = np.zeros((k, self.max_words), dtype=np.long) 208 | pairs_mask = np.zeros((k, self.max_words), dtype=np.long) 209 | pairs_segment = np.zeros((k, self.max_words), dtype=np.long) 210 | 211 | for i, video_id in enumerate(choice_video_ids): 212 | if caption is not None: 213 | words = self.tokenizer.tokenize(caption) 214 | else: 215 | words = self._get_single_text(video_id) 216 | 217 | words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words 218 | total_length_with_CLS = self.max_words - 1 219 | if len(words) > total_length_with_CLS: 220 | words = words[:total_length_with_CLS] 221 | words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] 222 | 223 | input_ids = self.tokenizer.convert_tokens_to_ids(words) 224 | input_mask = [1] * len(input_ids) 225 | segment_ids = [0] * len(input_ids) 226 | while len(input_ids) < self.max_words: 227 | input_ids.append(0) 228 | input_mask.append(0) 229 | segment_ids.append(0) 230 | assert len(input_ids) == self.max_words 231 | assert len(input_mask) == self.max_words 232 | assert len(segment_ids) == self.max_words 233 | 234 | pairs_text[i] = np.array(input_ids) 235 | pairs_mask[i] = np.array(input_mask) 236 | pairs_segment[i] = np.array(segment_ids) 237 | 238 | return pairs_text, pairs_mask, pairs_segment, choice_video_ids 239 | 240 | def _get_single_text(self, video_id): 241 | rind = random.randint(0, len(self.sentences[video_id]) - 1) 242 | caption = self.sentences[video_id][rind] 243 | words = self.tokenizer.tokenize(caption) 244 | return words 245 | 246 | def _get_rawvideo(self, choice_video_ids): 247 | video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.long) 248 | max_video_length = [0] * len(choice_video_ids) 249 | 250 | # Pair x L x T x 3 x H x W 251 | video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3, 252 | self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float) 253 | 254 | for i, video_id in enumerate(choice_video_ids): 255 | # Individual for YoucokII dataset, due to it video format 256 | video_path = os.path.join(self.features_path, "{}.mp4".format(video_id)) 257 | if os.path.exists(video_path) is False: 258 | video_path = video_path.replace(".mp4", ".webm") 259 | 260 | raw_video_data = self.rawVideoExtractor.get_video_data(video_path) 261 | raw_video_data = raw_video_data['video'] 262 | if len(raw_video_data.shape) > 3: 263 | raw_video_data_clip = raw_video_data 264 | # L x T x 3 x H x W 265 | raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) 266 | if self.max_frames < raw_video_slice.shape[0]: 267 | if self.slice_framepos == 0: 268 | video_slice = raw_video_slice[:self.max_frames, ...] 269 | elif self.slice_framepos == 1: 270 | video_slice = raw_video_slice[-self.max_frames:, ...] 271 | else: 272 | sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) 273 | video_slice = raw_video_slice[sample_indx, ...] 274 | else: 275 | video_slice = raw_video_slice 276 | 277 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) 278 | 279 | slice_len = video_slice.shape[0] 280 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len 281 | if slice_len < 1: 282 | pass 283 | else: 284 | video[i][:slice_len, ...] = video_slice 285 | else: 286 | print("video path: {} error. video id: {}".format(video_path, video_id)) 287 | 288 | for i, v_length in enumerate(max_video_length): 289 | video_mask[i][:v_length] = [1] * v_length 290 | 291 | return video, video_mask 292 | 293 | def __getitem__(self, idx): 294 | if self.unfold_sentences: 295 | video_id, caption = self.sentences_dict[idx] 296 | else: 297 | video_id, caption = self.csv['video_id'].values[idx], None 298 | pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(video_id, caption) 299 | video, video_mask = self._get_rawvideo(choice_video_ids) 300 | return pairs_text, pairs_mask, pairs_segment, video, video_mask 301 | -------------------------------------------------------------------------------- /modules/module_clip.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from: https://github.com/openai/CLIP/blob/main/clip/clip.py 3 | """ 4 | from collections import OrderedDict 5 | from typing import Tuple, Union 6 | 7 | import hashlib 8 | import os 9 | import urllib 10 | import warnings 11 | from tqdm import tqdm 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch import nn 16 | 17 | 18 | _MODELS = { 19 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 20 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 21 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 22 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 23 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 24 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 25 | } 26 | _PT_NAME = { 27 | "RN50": "RN50.pt", 28 | "RN101": "RN101.pt", 29 | "RN50x4": "RN50x4.pt", 30 | "RN50x16": "RN50x16.pt", 31 | "ViT-B/32": "ViT-B-32.pt", 32 | "ViT-B/16": "ViT-B-16.pt", 33 | } 34 | 35 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 36 | os.makedirs(root, exist_ok=True) 37 | filename = os.path.basename(url) 38 | 39 | expected_sha256 = url.split("/")[-2] 40 | download_target = os.path.join(root, filename) 41 | 42 | if os.path.exists(download_target) and not os.path.isfile(download_target): 43 | raise RuntimeError(f"{download_target} exists and is not a regular file") 44 | 45 | if os.path.isfile(download_target): 46 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 47 | return download_target 48 | else: 49 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 50 | 51 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 52 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 53 | while True: 54 | buffer = source.read(8192) 55 | if not buffer: 56 | break 57 | 58 | output.write(buffer) 59 | loop.update(len(buffer)) 60 | 61 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 62 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 63 | 64 | return download_target 65 | 66 | def available_models(): 67 | """Returns the names of available CLIP models""" 68 | return list(_MODELS.keys()) 69 | 70 | # ============================= 71 | 72 | class Bottleneck(nn.Module): 73 | expansion = 4 74 | 75 | def __init__(self, inplanes, planes, stride=1): 76 | super().__init__() 77 | 78 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 79 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 80 | self.bn1 = nn.BatchNorm2d(planes) 81 | 82 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 83 | self.bn2 = nn.BatchNorm2d(planes) 84 | 85 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 86 | 87 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 88 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 89 | 90 | self.relu = nn.ReLU(inplace=True) 91 | self.downsample = None 92 | self.stride = stride 93 | 94 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 95 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 96 | self.downsample = nn.Sequential(OrderedDict([ 97 | ("-1", nn.AvgPool2d(stride)), 98 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 99 | ("1", nn.BatchNorm2d(planes * self.expansion)) 100 | ])) 101 | 102 | def forward(self, x: torch.Tensor): 103 | identity = x 104 | 105 | out = self.relu(self.bn1(self.conv1(x))) 106 | out = self.relu(self.bn2(self.conv2(out))) 107 | out = self.avgpool(out) 108 | out = self.bn3(self.conv3(out)) 109 | 110 | if self.downsample is not None: 111 | identity = self.downsample(x) 112 | 113 | out += identity 114 | out = self.relu(out) 115 | return out 116 | 117 | 118 | class AttentionPool2d(nn.Module): 119 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 120 | super().__init__() 121 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 122 | self.k_proj = nn.Linear(embed_dim, embed_dim) 123 | self.q_proj = nn.Linear(embed_dim, embed_dim) 124 | self.v_proj = nn.Linear(embed_dim, embed_dim) 125 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 126 | self.num_heads = num_heads 127 | 128 | def forward(self, x): 129 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 130 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 131 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 132 | x, _ = F.multi_head_attention_forward( 133 | query=x, key=x, value=x, 134 | embed_dim_to_check=x.shape[-1], 135 | num_heads=self.num_heads, 136 | q_proj_weight=self.q_proj.weight, 137 | k_proj_weight=self.k_proj.weight, 138 | v_proj_weight=self.v_proj.weight, 139 | in_proj_weight=None, 140 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 141 | bias_k=None, 142 | bias_v=None, 143 | add_zero_attn=False, 144 | dropout_p=0, 145 | out_proj_weight=self.c_proj.weight, 146 | out_proj_bias=self.c_proj.bias, 147 | use_separate_proj_weight=True, 148 | training=self.training, 149 | need_weights=False 150 | ) 151 | 152 | return x[0] 153 | 154 | 155 | class ModifiedResNet(nn.Module): 156 | """ 157 | A ResNet class that is similar to torchvision's but contains the following changes: 158 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 159 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 160 | - The final pooling layer is a QKV attention instead of an average pool 161 | """ 162 | 163 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 164 | super().__init__() 165 | self.output_dim = output_dim 166 | self.input_resolution = input_resolution 167 | 168 | # the 3-layer stem 169 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 170 | self.bn1 = nn.BatchNorm2d(width // 2) 171 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 172 | self.bn2 = nn.BatchNorm2d(width // 2) 173 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 174 | self.bn3 = nn.BatchNorm2d(width) 175 | self.avgpool = nn.AvgPool2d(2) 176 | self.relu = nn.ReLU(inplace=True) 177 | 178 | # residual layers 179 | self._inplanes = width # this is a *mutable* variable used during construction 180 | self.layer1 = self._make_layer(width, layers[0]) 181 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 182 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 183 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 184 | 185 | embed_dim = width * 32 # the ResNet feature dimension 186 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 187 | 188 | def _make_layer(self, planes, blocks, stride=1): 189 | layers = [Bottleneck(self._inplanes, planes, stride)] 190 | 191 | self._inplanes = planes * Bottleneck.expansion 192 | for _ in range(1, blocks): 193 | layers.append(Bottleneck(self._inplanes, planes)) 194 | 195 | return nn.Sequential(*layers) 196 | 197 | def forward(self, x): 198 | def stem(x): 199 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 200 | x = self.relu(bn(conv(x))) 201 | x = self.avgpool(x) 202 | return x 203 | 204 | x = x.type(self.conv1.weight.dtype) 205 | x = stem(x) 206 | x = self.layer1(x) 207 | x = self.layer2(x) 208 | x = self.layer3(x) 209 | x = self.layer4(x) 210 | x = self.attnpool(x) 211 | 212 | return x 213 | 214 | 215 | class LayerNorm(nn.LayerNorm): 216 | """Subclass torch's LayerNorm to handle fp16.""" 217 | 218 | def forward(self, x: torch.Tensor): 219 | orig_type = x.dtype 220 | ret = super().forward(x.type(torch.float32)) 221 | return ret.type(orig_type) 222 | 223 | 224 | class QuickGELU(nn.Module): 225 | def forward(self, x: torch.Tensor): 226 | return x * torch.sigmoid(1.702 * x) 227 | 228 | 229 | class ResidualAttentionBlock(nn.Module): 230 | def __init__(self, d_model: int, n_head: int, attn_mask=None): 231 | super().__init__() 232 | 233 | self.attn = nn.MultiheadAttention(d_model, n_head) 234 | self.ln_1 = LayerNorm(d_model) 235 | self.mlp = nn.Sequential(OrderedDict([ 236 | ("c_fc", nn.Linear(d_model, d_model * 4)), 237 | ("gelu", QuickGELU()), 238 | ("c_proj", nn.Linear(d_model * 4, d_model)) 239 | ])) 240 | self.ln_2 = LayerNorm(d_model) 241 | self.attn_mask = attn_mask 242 | 243 | def attention(self, x: torch.Tensor): 244 | attn_mask_ = self.attn_mask 245 | if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'): 246 | attn_mask_ = self.attn_mask(x.size(0)) # LND 247 | 248 | attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None 249 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] 250 | 251 | def forward(self, x_tuple:tuple): 252 | x, video_frame = x_tuple 253 | x = x + self.attention(self.ln_1(x)) 254 | x = x + self.mlp(self.ln_2(x)) 255 | return (x, video_frame) 256 | 257 | 258 | class Transformer(nn.Module): 259 | def __init__(self, width: int, layers: int, heads: int, attn_mask = None): 260 | super().__init__() 261 | self.width = width 262 | self.layers = layers 263 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 264 | 265 | def forward(self, x: torch.Tensor, video_frame=-1): 266 | return self.resblocks((x, video_frame))[0] 267 | 268 | 269 | class VisualTransformer(nn.Module): 270 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, 271 | linear_patch: str = '2d',): 272 | super().__init__() 273 | self.input_resolution = input_resolution 274 | self.output_dim = output_dim 275 | 276 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 277 | 278 | scale = width ** -0.5 279 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 280 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 281 | self.ln_pre = LayerNorm(width) 282 | 283 | self.transformer = Transformer(width, layers, heads) 284 | 285 | self.ln_post = LayerNorm(width) 286 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 287 | 288 | # For 3D 289 | assert linear_patch in ['2d', '3d'] 290 | self.linear_patch = linear_patch 291 | if self.linear_patch == '3d': 292 | self.conv2 = nn.Conv3d(in_channels=3, out_channels=width, kernel_size=(3, patch_size, patch_size), 293 | stride=(1, patch_size, patch_size), padding=(1, 0, 0), bias=False) 294 | 295 | def forward(self, x: torch.Tensor, video_frame=-1): 296 | 297 | if self.linear_patch == '3d': 298 | assert video_frame != -1 299 | x_3d = x.reshape(-1, video_frame, x.shape[-3], x.shape[-2], x.shape[-1]) 300 | x_3d = x_3d.permute(0, 2, 1, 3, 4) 301 | x_3d = self.conv2(x_3d) # shape = [*, width, frame, grid, grid] 302 | x_3d = x_3d.permute(0, 2, 1, 3, 4) # shape = [*, frame, width, grid, grid] 303 | x = x_3d.reshape(-1, x_3d.shape[-3], x_3d.shape[-2], x_3d.shape[-1]).contiguous() # shape = [*, width, grid, grid] 304 | else: 305 | x = self.conv1(x) # shape = [*, width, grid, grid] 306 | 307 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 308 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 309 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 310 | x = x + self.positional_embedding.to(x.dtype) 311 | x = self.ln_pre(x) 312 | 313 | x = x.permute(1, 0, 2) # NLD -> LND 314 | x = self.transformer(x, video_frame=video_frame) 315 | x = x.permute(1, 0, 2) # LND -> NLD 316 | 317 | # Move the three lines below to `encode_image` for entire hidden sequence 318 | # x = self.ln_post(x[:, 0, :]) 319 | # if self.proj is not None: 320 | # x = x @ self.proj 321 | 322 | return x 323 | 324 | 325 | class CLIP(nn.Module): 326 | def __init__(self, 327 | embed_dim: int, 328 | # vision 329 | image_resolution: int, 330 | vision_layers: Union[Tuple[int, int, int, int], int], 331 | vision_width: int, 332 | vision_patch_size: int, 333 | # text 334 | context_length: int, 335 | vocab_size: int, 336 | transformer_width: int, 337 | transformer_heads: int, 338 | transformer_layers: int, 339 | # vision linear of patch 340 | linear_patch: str = '2d', 341 | ): 342 | super().__init__() 343 | 344 | self.context_length = context_length 345 | 346 | if isinstance(vision_layers, (tuple, list)): 347 | vision_heads = vision_width * 32 // 64 348 | self.visual = ModifiedResNet( 349 | layers=vision_layers, 350 | output_dim=embed_dim, 351 | heads=vision_heads, 352 | input_resolution=image_resolution, 353 | width=vision_width 354 | ) 355 | else: 356 | vision_heads = vision_width // 64 357 | self.visual = VisualTransformer( 358 | input_resolution=image_resolution, 359 | patch_size=vision_patch_size, 360 | width=vision_width, 361 | layers=vision_layers, 362 | heads=vision_heads, 363 | output_dim=embed_dim, 364 | linear_patch=linear_patch 365 | ) 366 | 367 | self.transformer = Transformer( 368 | width=transformer_width, 369 | layers=transformer_layers, 370 | heads=transformer_heads, 371 | attn_mask=self.build_attention_mask 372 | ) 373 | 374 | self.vocab_size = vocab_size 375 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 376 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 377 | self.ln_final = LayerNorm(transformer_width) 378 | 379 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 380 | self.logit_scale = nn.Parameter(torch.ones([])) 381 | 382 | self.initialize_parameters() 383 | 384 | def initialize_parameters(self): 385 | nn.init.normal_(self.token_embedding.weight, std=0.02) 386 | nn.init.normal_(self.positional_embedding, std=0.01) 387 | 388 | if isinstance(self.visual, ModifiedResNet): 389 | if self.visual.attnpool is not None: 390 | std = self.visual.attnpool.c_proj.in_features ** -0.5 391 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 392 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 393 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 394 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 395 | 396 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 397 | for name, param in resnet_block.named_parameters(): 398 | if name.endswith("bn3.weight"): 399 | nn.init.zeros_(param) 400 | 401 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 402 | attn_std = self.transformer.width ** -0.5 403 | fc_std = (2 * self.transformer.width) ** -0.5 404 | for block in self.transformer.resblocks: 405 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 406 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 407 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 408 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 409 | 410 | if self.text_projection is not None: 411 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 412 | 413 | @staticmethod 414 | def get_config(pretrained_clip_name="ViT-B/32"): 415 | model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViT-B-32.pt") 416 | if pretrained_clip_name in _MODELS and pretrained_clip_name in _PT_NAME: 417 | model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), _PT_NAME[pretrained_clip_name]) 418 | 419 | if pretrained_clip_name in ["ViT-B/32", "ViT-B/16"] and os.path.exists(model_path): 420 | pass 421 | else: 422 | if pretrained_clip_name in _MODELS: 423 | model_path = _download(_MODELS[pretrained_clip_name]) 424 | elif os.path.isfile(pretrained_clip_name): 425 | model_path = pretrained_clip_name 426 | else: 427 | raise RuntimeError(f"Model {pretrained_clip_name} not found; available models = {available_models()}") 428 | 429 | try: 430 | # loading JIT archive 431 | model = torch.jit.load(model_path, map_location="cpu").eval() 432 | state_dict = model.state_dict() 433 | except RuntimeError: 434 | state_dict = torch.load(model_path, map_location="cpu") 435 | 436 | return state_dict 437 | 438 | def build_attention_mask(self, context_length): 439 | # lazily create causal attention mask, with full attention between the vision tokens 440 | # pytorch uses additive attention mask; fill with -inf 441 | mask = torch.zeros(context_length, context_length) 442 | mask.fill_(float("-inf")) 443 | mask.triu_(1) # zero out the lower diagonal 444 | return mask 445 | 446 | @property 447 | def dtype(self): 448 | return self.visual.conv1.weight.dtype 449 | 450 | def encode_image(self, image, return_hidden=False, video_frame=-1): 451 | hidden = self.visual(image.type(self.dtype), video_frame=video_frame) 452 | hidden = self.visual.ln_post(hidden) @ self.visual.proj 453 | 454 | x = hidden[:, 0, :] 455 | 456 | if return_hidden: 457 | return x, hidden 458 | 459 | return x 460 | 461 | def encode_text(self, text, return_hidden=False): 462 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 463 | 464 | pos_emd = self.positional_embedding[:x.size(1), :].type(self.dtype) 465 | x = x + pos_emd 466 | x = x.permute(1, 0, 2) # NLD -> LND 467 | x = self.transformer(x) 468 | x = x.permute(1, 0, 2) # LND -> NLD 469 | 470 | hidden = self.ln_final(x).type(self.dtype) @ self.text_projection 471 | 472 | # x.shape = [batch_size, n_ctx, transformer.width] 473 | # take features from the eot embedding (eot_token is the highest number in each sequence) 474 | x = hidden[torch.arange(hidden.shape[0]), text.argmax(dim=-1)] 475 | 476 | if return_hidden: 477 | return x, hidden 478 | 479 | return x 480 | 481 | def forward(self, image, text): 482 | image_features = self.encode_image(image) 483 | text_features = self.encode_text(text) 484 | 485 | # normalized features 486 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 487 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 488 | 489 | # cosine similarity as logits 490 | logit_scale = self.logit_scale.exp() 491 | logits_per_image = logit_scale * image_features @ text_features.t() 492 | logits_per_text = logit_scale * text_features @ image_features.t() 493 | 494 | # shape = [global_batch_size, global_batch_size] 495 | return logits_per_image, logits_per_text 496 | 497 | 498 | def convert_weights(model: nn.Module): 499 | """Convert applicable model parameters to fp16""" 500 | 501 | def _convert_weights_to_fp16(l): 502 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)): 503 | l.weight.data = l.weight.data.half() 504 | if l.bias is not None: 505 | l.bias.data = l.bias.data.half() 506 | 507 | if isinstance(l, nn.MultiheadAttention): 508 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 509 | tensor = getattr(l, attr) 510 | if tensor is not None: 511 | tensor.data = tensor.data.half() 512 | 513 | for name in ["text_projection", "proj"]: 514 | if hasattr(l, name): 515 | attr = getattr(l, name) 516 | if attr is not None: 517 | attr.data = attr.data.half() 518 | 519 | model.apply(_convert_weights_to_fp16) 520 | 521 | 522 | def build_model(state_dict: dict): 523 | vit = "visual.proj" in state_dict 524 | 525 | if vit: 526 | vision_width = state_dict["visual.conv1.weight"].shape[0] 527 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 528 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 529 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 530 | image_resolution = vision_patch_size * grid_size 531 | else: 532 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 533 | vision_layers = tuple(counts) 534 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 535 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 536 | vision_patch_size = None 537 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 538 | image_resolution = output_width * 32 539 | 540 | embed_dim = state_dict["text_projection"].shape[1] 541 | context_length = state_dict["positional_embedding"].shape[0] 542 | vocab_size = state_dict["token_embedding.weight"].shape[0] 543 | transformer_width = state_dict["ln_final.weight"].shape[0] 544 | transformer_heads = transformer_width // 64 545 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 546 | 547 | model = CLIP( 548 | embed_dim, 549 | image_resolution, vision_layers, vision_width, vision_patch_size, 550 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 551 | ) 552 | 553 | for key in ["input_resolution", "context_length", "vocab_size"]: 554 | if key in state_dict: 555 | del state_dict[key] 556 | 557 | convert_weights(model) 558 | model.load_state_dict(state_dict) 559 | return model.eval() 560 | -------------------------------------------------------------------------------- /modules/modeling.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import logging 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from modules.until_module import PreTrainedModel, AllGather, CrossEn 11 | from modules.module_cross import CrossModel, CrossConfig, Transformer as TransformerClip 12 | 13 | from modules.module_clip import CLIP, convert_weights 14 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 15 | 16 | logger = logging.getLogger(__name__) 17 | allgather = AllGather.apply 18 | 19 | class CLIP4ClipPreTrainedModel(PreTrainedModel, nn.Module): 20 | """ An abstract class to handle weights initialization and 21 | a simple interface for dowloading and loading pretrained models. 22 | """ 23 | def __init__(self, cross_config, *inputs, **kwargs): 24 | super(CLIP4ClipPreTrainedModel, self).__init__(cross_config) 25 | self.cross_config = cross_config 26 | self.clip = None 27 | self.cross = None 28 | 29 | @classmethod 30 | def from_pretrained(cls, cross_model_name, state_dict=None, cache_dir=None, type_vocab_size=2, *inputs, **kwargs): 31 | 32 | task_config = None 33 | if "task_config" in kwargs.keys(): 34 | task_config = kwargs["task_config"] 35 | if not hasattr(task_config, "local_rank"): 36 | task_config.__dict__["local_rank"] = 0 37 | elif task_config.local_rank == -1: 38 | task_config.local_rank = 0 39 | 40 | if state_dict is None: state_dict = {} 41 | pretrained_clip_name = "ViT-B/32" 42 | if hasattr(task_config, 'pretrained_clip_name'): 43 | pretrained_clip_name = task_config.pretrained_clip_name 44 | clip_state_dict = CLIP.get_config(pretrained_clip_name=pretrained_clip_name) 45 | for key, val in clip_state_dict.items(): 46 | new_key = "clip." + key 47 | if new_key not in state_dict: 48 | state_dict[new_key] = val.clone() 49 | 50 | cross_config, _ = CrossConfig.get_config(cross_model_name, cache_dir, type_vocab_size, state_dict=None, task_config=task_config) 51 | 52 | model = cls(cross_config, clip_state_dict, *inputs, **kwargs) 53 | 54 | ## ===> Initialization trick [HARD CODE] 55 | if model.linear_patch == "3d": 56 | contain_conv2 = False 57 | for key in state_dict.keys(): 58 | if key.find("visual.conv2.weight") > -1: 59 | contain_conv2 = True 60 | break 61 | if contain_conv2 is False and hasattr(model.clip.visual, "conv2"): 62 | cp_weight = state_dict["clip.visual.conv1.weight"].clone() 63 | kernel_size = model.clip.visual.conv2.weight.size(2) 64 | conv2_size = model.clip.visual.conv2.weight.size() 65 | conv2_size = list(conv2_size) 66 | 67 | left_conv2_size = conv2_size.copy() 68 | right_conv2_size = conv2_size.copy() 69 | left_conv2_size[2] = (kernel_size - 1) // 2 70 | right_conv2_size[2] = kernel_size - 1 - left_conv2_size[2] 71 | 72 | left_zeros, right_zeros = None, None 73 | if left_conv2_size[2] > 0: 74 | left_zeros = torch.zeros(*tuple(left_conv2_size), dtype=cp_weight.dtype, device=cp_weight.device) 75 | if right_conv2_size[2] > 0: 76 | right_zeros = torch.zeros(*tuple(right_conv2_size), dtype=cp_weight.dtype, device=cp_weight.device) 77 | 78 | cat_list = [] 79 | if left_zeros != None: cat_list.append(left_zeros) 80 | cat_list.append(cp_weight.unsqueeze(2)) 81 | if right_zeros != None: cat_list.append(right_zeros) 82 | cp_weight = torch.cat(cat_list, dim=2) 83 | 84 | state_dict["clip.visual.conv2.weight"] = cp_weight 85 | 86 | if model.sim_header == 'tightTransf': 87 | contain_cross = False 88 | for key in state_dict.keys(): 89 | if key.find("cross.transformer") > -1: 90 | contain_cross = True 91 | break 92 | if contain_cross is False: 93 | for key, val in clip_state_dict.items(): 94 | if key == "positional_embedding": 95 | state_dict["cross.embeddings.position_embeddings.weight"] = val.clone() 96 | continue 97 | if key.find("transformer.resblocks") == 0: 98 | num_layer = int(key.split(".")[2]) 99 | 100 | # cut from beginning 101 | if num_layer < task_config.cross_num_hidden_layers: 102 | state_dict["cross."+key] = val.clone() 103 | continue 104 | 105 | if model.sim_header == "seqLSTM" or model.sim_header == "seqTransf": 106 | contain_frame_position = False 107 | for key in state_dict.keys(): 108 | if key.find("frame_position_embeddings") > -1: 109 | contain_frame_position = True 110 | break 111 | if contain_frame_position is False: 112 | for key, val in clip_state_dict.items(): 113 | if key == "positional_embedding": 114 | state_dict["frame_position_embeddings.weight"] = val.clone() 115 | continue 116 | if model.sim_header == "seqTransf" and key.find("transformer.resblocks") == 0: 117 | num_layer = int(key.split(".")[2]) 118 | # cut from beginning 119 | if num_layer < task_config.cross_num_hidden_layers: 120 | state_dict[key.replace("transformer.", "transformerClip.")] = val.clone() 121 | continue 122 | ## <=== End of initialization trick 123 | 124 | if state_dict is not None: 125 | model = cls.init_preweight(model, state_dict, task_config=task_config) 126 | 127 | return model 128 | 129 | def show_log(task_config, info): 130 | if task_config is None or task_config.local_rank == 0: 131 | logger.warning(info) 132 | 133 | def update_attr(target_name, target_config, target_attr_name, source_config, source_attr_name, default_value=None): 134 | if hasattr(source_config, source_attr_name): 135 | if default_value is None or getattr(source_config, source_attr_name) != default_value: 136 | setattr(target_config, target_attr_name, getattr(source_config, source_attr_name)) 137 | show_log(source_config, "Set {}.{}: {}.".format(target_name, 138 | target_attr_name, getattr(target_config, target_attr_name))) 139 | return target_config 140 | 141 | def check_attr(target_name, task_config): 142 | return hasattr(task_config, target_name) and task_config.__dict__[target_name] 143 | 144 | class CLIP4Clip(CLIP4ClipPreTrainedModel): 145 | def __init__(self, cross_config, clip_state_dict, task_config): 146 | super(CLIP4Clip, self).__init__(cross_config) 147 | self.task_config = task_config 148 | self.ignore_video_index = -1 149 | 150 | assert self.task_config.max_words + self.task_config.max_frames <= cross_config.max_position_embeddings 151 | 152 | self._stage_one = True 153 | self._stage_two = False 154 | 155 | show_log(task_config, "Stage-One:{}, Stage-Two:{}".format(self._stage_one, self._stage_two)) 156 | 157 | self.loose_type = False 158 | if self._stage_one and check_attr('loose_type', self.task_config): 159 | self.loose_type = True 160 | show_log(task_config, "Test retrieval by loose type.") 161 | 162 | # CLIP Encoders: From OpenAI: CLIP [https://github.com/openai/CLIP] ===> 163 | vit = "visual.proj" in clip_state_dict 164 | assert vit 165 | if vit: 166 | vision_width = clip_state_dict["visual.conv1.weight"].shape[0] 167 | vision_layers = len( 168 | [k for k in clip_state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 169 | vision_patch_size = clip_state_dict["visual.conv1.weight"].shape[-1] 170 | grid_size = round((clip_state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 171 | image_resolution = vision_patch_size * grid_size 172 | else: 173 | counts: list = [len(set(k.split(".")[2] for k in clip_state_dict if k.startswith(f"visual.layer{b}"))) for b in 174 | [1, 2, 3, 4]] 175 | vision_layers = tuple(counts) 176 | vision_width = clip_state_dict["visual.layer1.0.conv1.weight"].shape[0] 177 | output_width = round((clip_state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 178 | vision_patch_size = None 179 | assert output_width ** 2 + 1 == clip_state_dict["visual.attnpool.positional_embedding"].shape[0] 180 | image_resolution = output_width * 32 181 | 182 | embed_dim = clip_state_dict["text_projection"].shape[1] 183 | context_length = clip_state_dict["positional_embedding"].shape[0] 184 | vocab_size = clip_state_dict["token_embedding.weight"].shape[0] 185 | transformer_width = clip_state_dict["ln_final.weight"].shape[0] 186 | transformer_heads = transformer_width // 64 187 | transformer_layers = len(set(k.split(".")[2] for k in clip_state_dict if k.startswith(f"transformer.resblocks"))) 188 | 189 | show_log(task_config, "\t embed_dim: {}".format(embed_dim)) 190 | show_log(task_config, "\t image_resolution: {}".format(image_resolution)) 191 | show_log(task_config, "\t vision_layers: {}".format(vision_layers)) 192 | show_log(task_config, "\t vision_width: {}".format(vision_width)) 193 | show_log(task_config, "\t vision_patch_size: {}".format(vision_patch_size)) 194 | show_log(task_config, "\t context_length: {}".format(context_length)) 195 | show_log(task_config, "\t vocab_size: {}".format(vocab_size)) 196 | show_log(task_config, "\t transformer_width: {}".format(transformer_width)) 197 | show_log(task_config, "\t transformer_heads: {}".format(transformer_heads)) 198 | show_log(task_config, "\t transformer_layers: {}".format(transformer_layers)) 199 | 200 | self.linear_patch = '2d' 201 | if hasattr(task_config, "linear_patch"): 202 | self.linear_patch = task_config.linear_patch 203 | show_log(task_config, "\t\t linear_patch: {}".format(self.linear_patch)) 204 | 205 | # use .float() to avoid overflow/underflow from fp16 weight. https://github.com/openai/CLIP/issues/40 206 | cut_top_layer = 0 207 | show_log(task_config, "\t cut_top_layer: {}".format(cut_top_layer)) 208 | self.clip = CLIP( 209 | embed_dim, 210 | image_resolution, vision_layers-cut_top_layer, vision_width, vision_patch_size, 211 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers-cut_top_layer, 212 | linear_patch=self.linear_patch 213 | ).float() 214 | 215 | for key in ["input_resolution", "context_length", "vocab_size"]: 216 | if key in clip_state_dict: 217 | del clip_state_dict[key] 218 | 219 | convert_weights(self.clip) 220 | # <=== End of CLIP Encoders 221 | 222 | self.sim_header = 'meanP' 223 | if hasattr(task_config, "sim_header"): 224 | self.sim_header = task_config.sim_header 225 | show_log(task_config, "\t sim_header: {}".format(self.sim_header)) 226 | if self.sim_header == "tightTransf": assert self.loose_type is False 227 | 228 | cross_config.max_position_embeddings = context_length 229 | if self.loose_type is False: 230 | # Cross Encoder ===> 231 | cross_config = update_attr("cross_config", cross_config, "num_hidden_layers", self.task_config, "cross_num_hidden_layers") 232 | self.cross = CrossModel(cross_config) 233 | # <=== End of Cross Encoder 234 | self.similarity_dense = nn.Linear(cross_config.hidden_size, 1) 235 | 236 | if self.sim_header == "seqLSTM" or self.sim_header == "seqTransf": 237 | self.frame_position_embeddings = nn.Embedding(cross_config.max_position_embeddings, cross_config.hidden_size) 238 | if self.sim_header == "seqTransf": 239 | self.transformerClip = TransformerClip(width=transformer_width, layers=self.task_config.cross_num_hidden_layers, 240 | heads=transformer_heads, ) 241 | if self.sim_header == "seqLSTM": 242 | self.lstm_visual = nn.LSTM(input_size=cross_config.hidden_size, hidden_size=cross_config.hidden_size, 243 | batch_first=True, bidirectional=False, num_layers=1) 244 | 245 | self.loss_fct = CrossEn() 246 | 247 | self.apply(self.init_weights) 248 | 249 | def forward(self, input_ids, token_type_ids, attention_mask, video, video_mask=None): 250 | input_ids = input_ids.view(-1, input_ids.shape[-1]) 251 | token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1]) 252 | attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) 253 | video_mask = video_mask.view(-1, video_mask.shape[-1]) 254 | 255 | # T x 3 x H x W 256 | video = torch.as_tensor(video).float() 257 | b, pair, bs, ts, channel, h, w = video.shape 258 | video = video.view(b * pair * bs * ts, channel, h, w) 259 | video_frame = bs * ts 260 | 261 | sequence_output, visual_output = self.get_sequence_visual_output(input_ids, token_type_ids, attention_mask, 262 | video, video_mask, shaped=True, video_frame=video_frame) 263 | 264 | if self.training: 265 | loss = 0. 266 | sim_matrix, *_tmp = self.get_similarity_logits(sequence_output, visual_output, attention_mask, video_mask, 267 | shaped=True, loose_type=self.loose_type) 268 | sim_loss1 = self.loss_fct(sim_matrix) 269 | sim_loss2 = self.loss_fct(sim_matrix.T) 270 | sim_loss = (sim_loss1 + sim_loss2) / 2 271 | loss += sim_loss 272 | 273 | return loss 274 | else: 275 | return None 276 | 277 | def get_sequence_output(self, input_ids, token_type_ids, attention_mask, shaped=False): 278 | if shaped is False: 279 | input_ids = input_ids.view(-1, input_ids.shape[-1]) 280 | token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1]) 281 | attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) 282 | 283 | bs_pair = input_ids.size(0) 284 | sequence_hidden = self.clip.encode_text(input_ids).float() 285 | sequence_hidden = sequence_hidden.view(bs_pair, -1, sequence_hidden.size(-1)) 286 | 287 | return sequence_hidden 288 | 289 | def get_visual_output(self, video, video_mask, shaped=False, video_frame=-1): 290 | if shaped is False: 291 | video_mask = video_mask.view(-1, video_mask.shape[-1]) 292 | video = torch.as_tensor(video).float() 293 | b, pair, bs, ts, channel, h, w = video.shape 294 | video = video.view(b * pair * bs * ts, channel, h, w) 295 | video_frame = bs * ts 296 | 297 | bs_pair = video_mask.size(0) 298 | visual_hidden = self.clip.encode_image(video, video_frame=video_frame).float() 299 | visual_hidden = visual_hidden.view(bs_pair, -1, visual_hidden.size(-1)) 300 | 301 | return visual_hidden 302 | 303 | def get_sequence_visual_output(self, input_ids, token_type_ids, attention_mask, video, video_mask, shaped=False, video_frame=-1): 304 | if shaped is False: 305 | input_ids = input_ids.view(-1, input_ids.shape[-1]) 306 | token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1]) 307 | attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) 308 | video_mask = video_mask.view(-1, video_mask.shape[-1]) 309 | 310 | video = torch.as_tensor(video).float() 311 | b, pair, bs, ts, channel, h, w = video.shape 312 | video = video.view(b * pair * bs * ts, channel, h, w) 313 | video_frame = bs * ts 314 | 315 | sequence_output = self.get_sequence_output(input_ids, token_type_ids, attention_mask, shaped=True) 316 | visual_output = self.get_visual_output(video, video_mask, shaped=True, video_frame=video_frame) 317 | 318 | return sequence_output, visual_output 319 | 320 | def _get_cross_output(self, sequence_output, visual_output, attention_mask, video_mask): 321 | 322 | concat_features = torch.cat((sequence_output, visual_output), dim=1) # concatnate tokens and frames 323 | concat_mask = torch.cat((attention_mask, video_mask), dim=1) 324 | text_type_ = torch.zeros_like(attention_mask) 325 | video_type_ = torch.ones_like(video_mask) 326 | concat_type = torch.cat((text_type_, video_type_), dim=1) 327 | 328 | cross_layers, pooled_output = self.cross(concat_features, concat_type, concat_mask, output_all_encoded_layers=True) 329 | cross_output = cross_layers[-1] 330 | 331 | return cross_output, pooled_output, concat_mask 332 | 333 | def _mean_pooling_for_similarity_sequence(self, sequence_output, attention_mask): 334 | attention_mask_un = attention_mask.to(dtype=torch.float).unsqueeze(-1) 335 | attention_mask_un[:, 0, :] = 0. 336 | sequence_output = sequence_output * attention_mask_un 337 | text_out = torch.sum(sequence_output, dim=1) / torch.sum(attention_mask_un, dim=1, dtype=torch.float) 338 | return text_out 339 | 340 | def _mean_pooling_for_similarity_visual(self, visual_output, video_mask,): 341 | video_mask_un = video_mask.to(dtype=torch.float).unsqueeze(-1) 342 | visual_output = visual_output * video_mask_un 343 | video_mask_un_sum = torch.sum(video_mask_un, dim=1, dtype=torch.float) 344 | video_mask_un_sum[video_mask_un_sum == 0.] = 1. 345 | video_out = torch.sum(visual_output, dim=1) / video_mask_un_sum 346 | return video_out 347 | 348 | def _mean_pooling_for_similarity(self, sequence_output, visual_output, attention_mask, video_mask,): 349 | text_out = self._mean_pooling_for_similarity_sequence(sequence_output, attention_mask) 350 | video_out = self._mean_pooling_for_similarity_visual(visual_output, video_mask) 351 | 352 | return text_out, video_out 353 | 354 | def _loose_similarity(self, sequence_output, visual_output, attention_mask, video_mask, sim_header="meanP"): 355 | sequence_output, visual_output = sequence_output.contiguous(), visual_output.contiguous() 356 | 357 | if sim_header == "meanP": 358 | # Default: Parameter-free type 359 | pass 360 | elif sim_header == "seqLSTM": 361 | # Sequential type: LSTM 362 | visual_output_original = visual_output 363 | visual_output = pack_padded_sequence(visual_output, torch.sum(video_mask, dim=-1).cpu(), 364 | batch_first=True, enforce_sorted=False) 365 | visual_output, _ = self.lstm_visual(visual_output) 366 | if self.training: self.lstm_visual.flatten_parameters() 367 | visual_output, _ = pad_packed_sequence(visual_output, batch_first=True) 368 | visual_output = torch.cat((visual_output, visual_output_original[:, visual_output.size(1):, ...].contiguous()), dim=1) 369 | visual_output = visual_output + visual_output_original 370 | elif sim_header == "seqTransf": 371 | # Sequential type: Transformer Encoder 372 | visual_output_original = visual_output 373 | seq_length = visual_output.size(1) 374 | position_ids = torch.arange(seq_length, dtype=torch.long, device=visual_output.device) 375 | position_ids = position_ids.unsqueeze(0).expand(visual_output.size(0), -1) 376 | frame_position_embeddings = self.frame_position_embeddings(position_ids) 377 | visual_output = visual_output + frame_position_embeddings 378 | 379 | extended_video_mask = (1.0 - video_mask.unsqueeze(1)) * -1000000.0 380 | extended_video_mask = extended_video_mask.expand(-1, video_mask.size(1), -1) 381 | visual_output = visual_output.permute(1, 0, 2) # NLD -> LND 382 | visual_output = self.transformerClip(visual_output, extended_video_mask) 383 | visual_output = visual_output.permute(1, 0, 2) # LND -> NLD 384 | visual_output = visual_output + visual_output_original 385 | 386 | if self.training: 387 | visual_output = allgather(visual_output, self.task_config) 388 | video_mask = allgather(video_mask, self.task_config) 389 | sequence_output = allgather(sequence_output, self.task_config) 390 | torch.distributed.barrier() 391 | 392 | visual_output = visual_output / visual_output.norm(dim=-1, keepdim=True) 393 | visual_output = self._mean_pooling_for_similarity_visual(visual_output, video_mask) 394 | visual_output = visual_output / visual_output.norm(dim=-1, keepdim=True) 395 | 396 | sequence_output = sequence_output.squeeze(1) 397 | sequence_output = sequence_output / sequence_output.norm(dim=-1, keepdim=True) 398 | 399 | logit_scale = self.clip.logit_scale.exp() 400 | retrieve_logits = logit_scale * torch.matmul(sequence_output, visual_output.t()) 401 | return retrieve_logits 402 | 403 | def _cross_similarity(self, sequence_output, visual_output, attention_mask, video_mask): 404 | sequence_output, visual_output = sequence_output.contiguous(), visual_output.contiguous() 405 | 406 | b_text, s_text, h_text = sequence_output.size() 407 | b_visual, s_visual, h_visual = visual_output.size() 408 | 409 | retrieve_logits_list = [] 410 | 411 | step_size = b_text # set smaller to reduce memory cost 412 | split_size = [step_size] * (b_text // step_size) 413 | release_size = b_text - sum(split_size) 414 | if release_size > 0: 415 | split_size += [release_size] 416 | 417 | # due to clip text branch retrun the last hidden 418 | attention_mask = torch.ones(sequence_output.size(0), 1)\ 419 | .to(device=attention_mask.device, dtype=attention_mask.dtype) 420 | 421 | sequence_output_splits = torch.split(sequence_output, split_size, dim=0) 422 | attention_mask_splits = torch.split(attention_mask, split_size, dim=0) 423 | for i in range(len(split_size)): 424 | sequence_output_row = sequence_output_splits[i] 425 | attention_mask_row = attention_mask_splits[i] 426 | sequence_output_l = sequence_output_row.unsqueeze(1).repeat(1, b_visual, 1, 1) 427 | sequence_output_l = sequence_output_l.view(-1, s_text, h_text) 428 | attention_mask_l = attention_mask_row.unsqueeze(1).repeat(1, b_visual, 1) 429 | attention_mask_l = attention_mask_l.view(-1, s_text) 430 | 431 | step_truth = sequence_output_row.size(0) 432 | visual_output_r = visual_output.unsqueeze(0).repeat(step_truth, 1, 1, 1) 433 | visual_output_r = visual_output_r.view(-1, s_visual, h_visual) 434 | video_mask_r = video_mask.unsqueeze(0).repeat(step_truth, 1, 1) 435 | video_mask_r = video_mask_r.view(-1, s_visual) 436 | 437 | cross_output, pooled_output, concat_mask = \ 438 | self._get_cross_output(sequence_output_l, visual_output_r, attention_mask_l, video_mask_r) 439 | retrieve_logits_row = self.similarity_dense(pooled_output).squeeze(-1).view(step_truth, b_visual) 440 | 441 | retrieve_logits_list.append(retrieve_logits_row) 442 | 443 | retrieve_logits = torch.cat(retrieve_logits_list, dim=0) 444 | return retrieve_logits 445 | 446 | def get_similarity_logits(self, sequence_output, visual_output, attention_mask, video_mask, shaped=False, loose_type=False): 447 | if shaped is False: 448 | attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) 449 | video_mask = video_mask.view(-1, video_mask.shape[-1]) 450 | 451 | contrastive_direction = () 452 | if loose_type: 453 | assert self.sim_header in ["meanP", "seqLSTM", "seqTransf"] 454 | retrieve_logits = self._loose_similarity(sequence_output, visual_output, attention_mask, video_mask, sim_header=self.sim_header) 455 | else: 456 | assert self.sim_header in ["tightTransf"] 457 | retrieve_logits = self._cross_similarity(sequence_output, visual_output, attention_mask, video_mask, ) 458 | 459 | return retrieve_logits, contrastive_direction 460 | -------------------------------------------------------------------------------- /main_task_retrieval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import torch 7 | import numpy as np 8 | import random 9 | import os 10 | from metrics import compute_metrics, tensor_text_to_video_metrics, tensor_video_to_text_sim 11 | import time 12 | import argparse 13 | from modules.tokenization_clip import SimpleTokenizer as ClipTokenizer 14 | from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE 15 | from modules.modeling import CLIP4Clip 16 | from modules.optimization import BertAdam 17 | 18 | from util import parallel_apply, get_logger 19 | from dataloaders.data_dataloaders import DATALOADER_DICT 20 | 21 | torch.distributed.init_process_group(backend="nccl") 22 | 23 | global logger 24 | 25 | def get_args(description='CLIP4Clip on Retrieval Task'): 26 | parser = argparse.ArgumentParser(description=description) 27 | parser.add_argument("--do_pretrain", action='store_true', help="Whether to run training.") 28 | parser.add_argument("--do_train", action='store_true', help="Whether to run training.") 29 | parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.") 30 | 31 | parser.add_argument('--train_csv', type=str, default='data/.train.csv', help='') 32 | parser.add_argument('--val_csv', type=str, default='data/.val.csv', help='') 33 | parser.add_argument('--data_path', type=str, default='data/caption.pickle', help='data pickle file path') 34 | parser.add_argument('--features_path', type=str, default='data/videos_feature.pickle', help='feature path') 35 | 36 | parser.add_argument('--num_thread_reader', type=int, default=1, help='') 37 | parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate') 38 | parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit') 39 | parser.add_argument('--batch_size', type=int, default=256, help='batch size') 40 | parser.add_argument('--batch_size_val', type=int, default=3500, help='batch size eval') 41 | parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate exp epoch decay') 42 | parser.add_argument('--n_display', type=int, default=100, help='Information display frequence') 43 | parser.add_argument('--video_dim', type=int, default=1024, help='video feature dimension') 44 | parser.add_argument('--seed', type=int, default=42, help='random seed') 45 | parser.add_argument('--max_words', type=int, default=20, help='') 46 | parser.add_argument('--max_frames', type=int, default=100, help='') 47 | parser.add_argument('--feature_framerate', type=int, default=1, help='') 48 | parser.add_argument('--margin', type=float, default=0.1, help='margin for loss') 49 | parser.add_argument('--hard_negative_rate', type=float, default=0.5, help='rate of intra negative sample') 50 | parser.add_argument('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative') 51 | parser.add_argument('--n_pair', type=int, default=1, help='Num of pair to output from data loader') 52 | 53 | parser.add_argument("--output_dir", default=None, type=str, required=True, 54 | help="The output directory where the model predictions and checkpoints will be written.") 55 | parser.add_argument("--cross_model", default="cross-base", type=str, required=False, help="Cross module") 56 | parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.") 57 | parser.add_argument("--resume_model", default=None, type=str, required=False, help="Resume train model.") 58 | parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") 59 | parser.add_argument("--warmup_proportion", default=0.1, type=float, 60 | help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.") 61 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 62 | help="Number of updates steps to accumulate before performing a backward/update pass.") 63 | parser.add_argument('--n_gpu', type=int, default=1, help="Changed in the execute process.") 64 | 65 | parser.add_argument("--cache_dir", default="", type=str, 66 | help="Where do you want to store the pre-trained models downloaded from s3") 67 | 68 | parser.add_argument('--fp16', action='store_true', 69 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") 70 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 71 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 72 | "See details at https://nvidia.github.io/apex/amp.html") 73 | 74 | parser.add_argument("--task_type", default="retrieval", type=str, help="Point the task `retrieval` to finetune.") 75 | parser.add_argument("--datatype", default="msrvtt", type=str, help="Point the dataset to finetune.") 76 | 77 | parser.add_argument("--world_size", default=0, type=int, help="distribted training") 78 | parser.add_argument("--local_rank", default=0, type=int, help="distribted training") 79 | parser.add_argument("--rank", default=0, type=int, help="distribted training") 80 | parser.add_argument('--coef_lr', type=float, default=1., help='coefficient for bert branch.') 81 | parser.add_argument('--use_mil', action='store_true', help="Whether use MIL as Miech et. al. (2020).") 82 | parser.add_argument('--sampled_use_mil', action='store_true', help="Whether MIL, has a high priority than use_mil.") 83 | 84 | parser.add_argument('--text_num_hidden_layers', type=int, default=12, help="Layer NO. of text.") 85 | parser.add_argument('--visual_num_hidden_layers', type=int, default=12, help="Layer NO. of visual.") 86 | parser.add_argument('--cross_num_hidden_layers', type=int, default=4, help="Layer NO. of cross.") 87 | 88 | parser.add_argument('--loose_type', action='store_true', help="Default using tight type for retrieval.") 89 | parser.add_argument('--expand_msrvtt_sentences', action='store_true', help="") 90 | 91 | parser.add_argument('--train_frame_order', type=int, default=0, choices=[0, 1, 2], 92 | help="Frame order, 0: ordinary order; 1: reverse order; 2: random order.") 93 | parser.add_argument('--eval_frame_order', type=int, default=0, choices=[0, 1, 2], 94 | help="Frame order, 0: ordinary order; 1: reverse order; 2: random order.") 95 | 96 | parser.add_argument('--freeze_layer_num', type=int, default=0, help="Layer NO. of CLIP need to freeze.") 97 | parser.add_argument('--slice_framepos', type=int, default=0, choices=[0, 1, 2], 98 | help="0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.") 99 | parser.add_argument('--linear_patch', type=str, default="2d", choices=["2d", "3d"], 100 | help="linear projection of flattened patches.") 101 | parser.add_argument('--sim_header', type=str, default="meanP", 102 | choices=["meanP", "seqLSTM", "seqTransf", "tightTransf"], 103 | help="choice a similarity header.") 104 | 105 | parser.add_argument("--pretrained_clip_name", default="ViT-B/32", type=str, help="Choose a CLIP version") 106 | 107 | args = parser.parse_args() 108 | 109 | if args.sim_header == "tightTransf": 110 | args.loose_type = False 111 | 112 | # Check paramenters 113 | if args.gradient_accumulation_steps < 1: 114 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 115 | args.gradient_accumulation_steps)) 116 | if not args.do_train and not args.do_eval: 117 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 118 | 119 | args.batch_size = int(args.batch_size / args.gradient_accumulation_steps) 120 | 121 | return args 122 | 123 | def set_seed_logger(args): 124 | global logger 125 | # predefining random initial seeds 126 | random.seed(args.seed) 127 | os.environ['PYTHONHASHSEED'] = str(args.seed) 128 | np.random.seed(args.seed) 129 | torch.manual_seed(args.seed) 130 | torch.cuda.manual_seed(args.seed) 131 | torch.cuda.manual_seed_all(args.seed) # if you are using multi-GPU. 132 | torch.backends.cudnn.benchmark = False 133 | torch.backends.cudnn.deterministic = True 134 | 135 | world_size = torch.distributed.get_world_size() 136 | torch.cuda.set_device(args.local_rank) 137 | args.world_size = world_size 138 | rank = torch.distributed.get_rank() 139 | args.rank = rank 140 | 141 | if not os.path.exists(args.output_dir): 142 | os.makedirs(args.output_dir, exist_ok=True) 143 | 144 | logger = get_logger(os.path.join(args.output_dir, "log.txt")) 145 | 146 | if args.local_rank == 0: 147 | logger.info("Effective parameters:") 148 | for key in sorted(args.__dict__): 149 | logger.info(" <<< {}: {}".format(key, args.__dict__[key])) 150 | 151 | return args 152 | 153 | def init_device(args, local_rank): 154 | global logger 155 | 156 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu", local_rank) 157 | 158 | n_gpu = torch.cuda.device_count() 159 | logger.info("device: {} n_gpu: {}".format(device, n_gpu)) 160 | args.n_gpu = n_gpu 161 | 162 | if args.batch_size % args.n_gpu != 0 or args.batch_size_val % args.n_gpu != 0: 163 | raise ValueError("Invalid batch_size/batch_size_val and n_gpu parameter: {}%{} and {}%{}, should be == 0".format( 164 | args.batch_size, args.n_gpu, args.batch_size_val, args.n_gpu)) 165 | 166 | return device, n_gpu 167 | 168 | def init_model(args, device, n_gpu, local_rank): 169 | 170 | if args.init_model: 171 | model_state_dict = torch.load(args.init_model, map_location='cpu') 172 | else: 173 | model_state_dict = None 174 | 175 | # Prepare model 176 | cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed') 177 | model = CLIP4Clip.from_pretrained(args.cross_model, cache_dir=cache_dir, state_dict=model_state_dict, task_config=args) 178 | 179 | model.to(device) 180 | 181 | return model 182 | 183 | def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=1.): 184 | 185 | if hasattr(model, 'module'): 186 | model = model.module 187 | 188 | param_optimizer = list(model.named_parameters()) 189 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 190 | 191 | decay_param_tp = [(n, p) for n, p in param_optimizer if not any(nd in n for nd in no_decay)] 192 | no_decay_param_tp = [(n, p) for n, p in param_optimizer if any(nd in n for nd in no_decay)] 193 | 194 | decay_clip_param_tp = [(n, p) for n, p in decay_param_tp if "clip." in n] 195 | decay_noclip_param_tp = [(n, p) for n, p in decay_param_tp if "clip." not in n] 196 | 197 | no_decay_clip_param_tp = [(n, p) for n, p in no_decay_param_tp if "clip." in n] 198 | no_decay_noclip_param_tp = [(n, p) for n, p in no_decay_param_tp if "clip." not in n] 199 | 200 | weight_decay = 0.2 201 | optimizer_grouped_parameters = [ 202 | {'params': [p for n, p in decay_clip_param_tp], 'weight_decay': weight_decay, 'lr': args.lr * coef_lr}, 203 | {'params': [p for n, p in decay_noclip_param_tp], 'weight_decay': weight_decay}, 204 | {'params': [p for n, p in no_decay_clip_param_tp], 'weight_decay': 0.0, 'lr': args.lr * coef_lr}, 205 | {'params': [p for n, p in no_decay_noclip_param_tp], 'weight_decay': 0.0} 206 | ] 207 | 208 | scheduler = None 209 | optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, warmup=args.warmup_proportion, 210 | schedule='warmup_cosine', b1=0.9, b2=0.98, e=1e-6, 211 | t_total=num_train_optimization_steps, weight_decay=weight_decay, 212 | max_grad_norm=1.0) 213 | 214 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], 215 | output_device=local_rank, find_unused_parameters=True) 216 | 217 | return optimizer, scheduler, model 218 | 219 | def save_model(epoch, args, model, optimizer, tr_loss, type_name=""): 220 | # Only save the model it-self 221 | model_to_save = model.module if hasattr(model, 'module') else model 222 | output_model_file = os.path.join( 223 | args.output_dir, "pytorch_model.bin.{}{}".format("" if type_name=="" else type_name+".", epoch)) 224 | optimizer_state_file = os.path.join( 225 | args.output_dir, "pytorch_opt.bin.{}{}".format("" if type_name=="" else type_name+".", epoch)) 226 | torch.save(model_to_save.state_dict(), output_model_file) 227 | torch.save({ 228 | 'epoch': epoch, 229 | 'optimizer_state_dict': optimizer.state_dict(), 230 | 'loss': tr_loss, 231 | }, optimizer_state_file) 232 | logger.info("Model saved to %s", output_model_file) 233 | logger.info("Optimizer saved to %s", optimizer_state_file) 234 | return output_model_file 235 | 236 | def load_model(epoch, args, n_gpu, device, model_file=None): 237 | if model_file is None or len(model_file) == 0: 238 | model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch)) 239 | if os.path.exists(model_file): 240 | model_state_dict = torch.load(model_file, map_location='cpu') 241 | if args.local_rank == 0: 242 | logger.info("Model loaded from %s", model_file) 243 | # Prepare model 244 | cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed') 245 | model = CLIP4Clip.from_pretrained(args.cross_model, cache_dir=cache_dir, state_dict=model_state_dict, task_config=args) 246 | 247 | model.to(device) 248 | else: 249 | model = None 250 | return model 251 | 252 | def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, scheduler, global_step, local_rank=0): 253 | global logger 254 | torch.cuda.empty_cache() 255 | model.train() 256 | log_step = args.n_display 257 | start_time = time.time() 258 | total_loss = 0 259 | 260 | for step, batch in enumerate(train_dataloader): 261 | if n_gpu == 1: 262 | # multi-gpu does scattering it-self 263 | batch = tuple(t.to(device=device, non_blocking=True) for t in batch) 264 | 265 | input_ids, input_mask, segment_ids, video, video_mask = batch 266 | loss = model(input_ids, segment_ids, input_mask, video, video_mask) 267 | 268 | if n_gpu > 1: 269 | loss = loss.mean() # mean() to average on multi-gpu. 270 | if args.gradient_accumulation_steps > 1: 271 | loss = loss / args.gradient_accumulation_steps 272 | 273 | loss.backward() 274 | 275 | total_loss += float(loss) 276 | if (step + 1) % args.gradient_accumulation_steps == 0: 277 | 278 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 279 | 280 | if scheduler is not None: 281 | scheduler.step() # Update learning rate schedule 282 | 283 | optimizer.step() 284 | optimizer.zero_grad() 285 | 286 | # https://github.com/openai/CLIP/issues/46 287 | if hasattr(model, 'module'): 288 | torch.clamp_(model.module.clip.logit_scale.data, max=np.log(100)) 289 | else: 290 | torch.clamp_(model.clip.logit_scale.data, max=np.log(100)) 291 | 292 | global_step += 1 293 | if global_step % log_step == 0 and local_rank == 0: 294 | logger.info("Epoch: %d/%s, Step: %d/%d, Lr: %s, Loss: %f, Time/step: %f", epoch + 1, 295 | args.epochs, step + 1, 296 | len(train_dataloader), "-".join([str('%.9f'%itm) for itm in sorted(list(set(optimizer.get_lr())))]), 297 | float(loss), 298 | (time.time() - start_time) / (log_step * args.gradient_accumulation_steps)) 299 | start_time = time.time() 300 | 301 | total_loss = total_loss / len(train_dataloader) 302 | return total_loss, global_step 303 | 304 | def _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_sequence_output_list, batch_visual_output_list): 305 | sim_matrix = [] 306 | for idx1, b1 in enumerate(batch_list_t): 307 | input_mask, segment_ids, *_tmp = b1 308 | sequence_output = batch_sequence_output_list[idx1] 309 | each_row = [] 310 | for idx2, b2 in enumerate(batch_list_v): 311 | video_mask, *_tmp = b2 312 | visual_output = batch_visual_output_list[idx2] 313 | b1b2_logits, *_tmp = model.get_similarity_logits(sequence_output, visual_output, input_mask, video_mask, 314 | loose_type=model.loose_type) 315 | b1b2_logits = b1b2_logits.cpu().detach().numpy() 316 | each_row.append(b1b2_logits) 317 | each_row = np.concatenate(tuple(each_row), axis=-1) 318 | sim_matrix.append(each_row) 319 | return sim_matrix 320 | 321 | def eval_epoch(args, model, test_dataloader, device, n_gpu): 322 | 323 | if hasattr(model, 'module'): 324 | model = model.module.to(device) 325 | else: 326 | model = model.to(device) 327 | 328 | # ################################################################# 329 | ## below variables are used to multi-sentences retrieval 330 | # multi_sentence_: important tag for eval 331 | # cut_off_points: used to tag the label when calculate the metric 332 | # sentence_num: used to cut the sentence representation 333 | # video_num: used to cut the video representation 334 | # ################################################################# 335 | multi_sentence_ = False 336 | cut_off_points_, sentence_num_, video_num_ = [], -1, -1 337 | if hasattr(test_dataloader.dataset, 'multi_sentence_per_video') \ 338 | and test_dataloader.dataset.multi_sentence_per_video: 339 | multi_sentence_ = True 340 | cut_off_points_ = test_dataloader.dataset.cut_off_points 341 | sentence_num_ = test_dataloader.dataset.sentence_num 342 | video_num_ = test_dataloader.dataset.video_num 343 | cut_off_points_ = [itm - 1 for itm in cut_off_points_] 344 | 345 | if multi_sentence_: 346 | logger.warning("Eval under the multi-sentence per video clip setting.") 347 | logger.warning("sentence num: {}, video num: {}".format(sentence_num_, video_num_)) 348 | 349 | model.eval() 350 | with torch.no_grad(): 351 | batch_list_t = [] 352 | batch_list_v = [] 353 | batch_sequence_output_list, batch_visual_output_list = [], [] 354 | total_video_num = 0 355 | 356 | # ---------------------------- 357 | # 1. cache the features 358 | # ---------------------------- 359 | for bid, batch in enumerate(test_dataloader): 360 | batch = tuple(t.to(device) for t in batch) 361 | input_ids, input_mask, segment_ids, video, video_mask = batch 362 | 363 | if multi_sentence_: 364 | # multi-sentences retrieval means: one clip has two or more descriptions. 365 | b, *_t = video.shape 366 | sequence_output = model.get_sequence_output(input_ids, segment_ids, input_mask) 367 | batch_sequence_output_list.append(sequence_output) 368 | batch_list_t.append((input_mask, segment_ids,)) 369 | 370 | s_, e_ = total_video_num, total_video_num + b 371 | filter_inds = [itm - s_ for itm in cut_off_points_ if itm >= s_ and itm < e_] 372 | 373 | if len(filter_inds) > 0: 374 | video, video_mask = video[filter_inds, ...], video_mask[filter_inds, ...] 375 | visual_output = model.get_visual_output(video, video_mask) 376 | batch_visual_output_list.append(visual_output) 377 | batch_list_v.append((video_mask,)) 378 | total_video_num += b 379 | else: 380 | sequence_output, visual_output = model.get_sequence_visual_output(input_ids, segment_ids, input_mask, video, video_mask) 381 | 382 | batch_sequence_output_list.append(sequence_output) 383 | batch_list_t.append((input_mask, segment_ids,)) 384 | 385 | batch_visual_output_list.append(visual_output) 386 | batch_list_v.append((video_mask,)) 387 | 388 | print("{}/{}\r".format(bid, len(test_dataloader)), end="") 389 | 390 | # ---------------------------------- 391 | # 2. calculate the similarity 392 | # ---------------------------------- 393 | if n_gpu > 1: 394 | device_ids = list(range(n_gpu)) 395 | batch_list_t_splits = [] 396 | batch_list_v_splits = [] 397 | batch_t_output_splits = [] 398 | batch_v_output_splits = [] 399 | bacth_len = len(batch_list_t) 400 | split_len = (bacth_len + n_gpu - 1) // n_gpu 401 | for dev_id in device_ids: 402 | s_, e_ = dev_id * split_len, (dev_id + 1) * split_len 403 | if dev_id == 0: 404 | batch_list_t_splits.append(batch_list_t[s_:e_]) 405 | batch_list_v_splits.append(batch_list_v) 406 | 407 | batch_t_output_splits.append(batch_sequence_output_list[s_:e_]) 408 | batch_v_output_splits.append(batch_visual_output_list) 409 | else: 410 | devc = torch.device('cuda:{}'.format(str(dev_id))) 411 | devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list_t[s_:e_]] 412 | batch_list_t_splits.append(devc_batch_list) 413 | devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list_v] 414 | batch_list_v_splits.append(devc_batch_list) 415 | 416 | devc_batch_list = [b.to(devc) for b in batch_sequence_output_list[s_:e_]] 417 | batch_t_output_splits.append(devc_batch_list) 418 | devc_batch_list = [b.to(devc) for b in batch_visual_output_list] 419 | batch_v_output_splits.append(devc_batch_list) 420 | 421 | parameters_tuple_list = [(batch_list_t_splits[dev_id], batch_list_v_splits[dev_id], 422 | batch_t_output_splits[dev_id], batch_v_output_splits[dev_id]) for dev_id in device_ids] 423 | parallel_outputs = parallel_apply(_run_on_single_gpu, model, parameters_tuple_list, device_ids) 424 | sim_matrix = [] 425 | for idx in range(len(parallel_outputs)): 426 | sim_matrix += parallel_outputs[idx] 427 | sim_matrix = np.concatenate(tuple(sim_matrix), axis=0) 428 | else: 429 | sim_matrix = _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_sequence_output_list, batch_visual_output_list) 430 | sim_matrix = np.concatenate(tuple(sim_matrix), axis=0) 431 | 432 | if multi_sentence_: 433 | logger.info("before reshape, sim matrix size: {} x {}".format(sim_matrix.shape[0], sim_matrix.shape[1])) 434 | cut_off_points2len_ = [itm + 1 for itm in cut_off_points_] 435 | max_length = max([e_-s_ for s_, e_ in zip([0]+cut_off_points2len_[:-1], cut_off_points2len_)]) 436 | sim_matrix_new = [] 437 | for s_, e_ in zip([0] + cut_off_points2len_[:-1], cut_off_points2len_): 438 | sim_matrix_new.append(np.concatenate((sim_matrix[s_:e_], 439 | np.full((max_length-e_+s_, sim_matrix.shape[1]), -np.inf)), axis=0)) 440 | sim_matrix = np.stack(tuple(sim_matrix_new), axis=0) 441 | logger.info("after reshape, sim matrix size: {} x {} x {}". 442 | format(sim_matrix.shape[0], sim_matrix.shape[1], sim_matrix.shape[2])) 443 | 444 | tv_metrics = tensor_text_to_video_metrics(sim_matrix) 445 | vt_metrics = compute_metrics(tensor_video_to_text_sim(sim_matrix)) 446 | else: 447 | logger.info("sim matrix size: {}, {}".format(sim_matrix.shape[0], sim_matrix.shape[1])) 448 | tv_metrics = compute_metrics(sim_matrix) 449 | vt_metrics = compute_metrics(sim_matrix.T) 450 | logger.info('\t Length-T: {}, Length-V:{}'.format(len(sim_matrix), len(sim_matrix[0]))) 451 | 452 | logger.info("Text-to-Video:") 453 | logger.info('\t>>> R@1: {:.1f} - R@5: {:.1f} - R@10: {:.1f} - Median R: {:.1f} - Mean R: {:.1f}'. 454 | format(tv_metrics['R1'], tv_metrics['R5'], tv_metrics['R10'], tv_metrics['MR'], tv_metrics['MeanR'])) 455 | logger.info("Video-to-Text:") 456 | logger.info('\t>>> V2T$R@1: {:.1f} - V2T$R@5: {:.1f} - V2T$R@10: {:.1f} - V2T$Median R: {:.1f} - V2T$Mean R: {:.1f}'. 457 | format(vt_metrics['R1'], vt_metrics['R5'], vt_metrics['R10'], vt_metrics['MR'], vt_metrics['MeanR'])) 458 | 459 | R1 = tv_metrics['R1'] 460 | return R1 461 | 462 | def main(): 463 | global logger 464 | args = get_args() 465 | args = set_seed_logger(args) 466 | device, n_gpu = init_device(args, args.local_rank) 467 | 468 | tokenizer = ClipTokenizer() 469 | 470 | assert args.task_type == "retrieval" 471 | model = init_model(args, device, n_gpu, args.local_rank) 472 | 473 | ## #################################### 474 | # freeze testing 475 | ## #################################### 476 | assert args.freeze_layer_num <= 12 and args.freeze_layer_num >= -1 477 | if hasattr(model, "clip") and args.freeze_layer_num > -1: 478 | for name, param in model.clip.named_parameters(): 479 | # top layers always need to train 480 | if name.find("ln_final.") == 0 or name.find("text_projection") == 0 or name.find("logit_scale") == 0 \ 481 | or name.find("visual.ln_post.") == 0 or name.find("visual.proj") == 0: 482 | continue # need to train 483 | elif name.find("visual.transformer.resblocks.") == 0 or name.find("transformer.resblocks.") == 0: 484 | layer_num = int(name.split(".resblocks.")[1].split(".")[0]) 485 | if layer_num >= args.freeze_layer_num: 486 | continue # need to train 487 | 488 | if args.linear_patch == "3d" and name.find("conv2."): 489 | continue 490 | else: 491 | # paramenters which < freeze_layer_num will be freezed 492 | param.requires_grad = False 493 | 494 | ## #################################### 495 | # dataloader loading 496 | ## #################################### 497 | assert args.datatype in DATALOADER_DICT 498 | 499 | assert DATALOADER_DICT[args.datatype]["test"] is not None \ 500 | or DATALOADER_DICT[args.datatype]["val"] is not None 501 | 502 | test_dataloader, test_length = None, 0 503 | if DATALOADER_DICT[args.datatype]["test"] is not None: 504 | test_dataloader, test_length = DATALOADER_DICT[args.datatype]["test"](args, tokenizer) 505 | 506 | if DATALOADER_DICT[args.datatype]["val"] is not None: 507 | val_dataloader, val_length = DATALOADER_DICT[args.datatype]["val"](args, tokenizer, subset="val") 508 | else: 509 | val_dataloader, val_length = test_dataloader, test_length 510 | 511 | ## report validation results if the ["test"] is None 512 | if test_dataloader is None: 513 | test_dataloader, test_length = val_dataloader, val_length 514 | 515 | if args.local_rank == 0: 516 | logger.info("***** Running test *****") 517 | logger.info(" Num examples = %d", test_length) 518 | logger.info(" Batch size = %d", args.batch_size_val) 519 | logger.info(" Num steps = %d", len(test_dataloader)) 520 | logger.info("***** Running val *****") 521 | logger.info(" Num examples = %d", val_length) 522 | 523 | ## #################################### 524 | # train and eval 525 | ## #################################### 526 | if args.do_train: 527 | train_dataloader, train_length, train_sampler = DATALOADER_DICT[args.datatype]["train"](args, tokenizer) 528 | num_train_optimization_steps = (int(len(train_dataloader) + args.gradient_accumulation_steps - 1) 529 | / args.gradient_accumulation_steps) * args.epochs 530 | 531 | coef_lr = args.coef_lr 532 | optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, args.local_rank, coef_lr=coef_lr) 533 | 534 | if args.local_rank == 0: 535 | logger.info("***** Running training *****") 536 | logger.info(" Num examples = %d", train_length) 537 | logger.info(" Batch size = %d", args.batch_size) 538 | logger.info(" Num steps = %d", num_train_optimization_steps * args.gradient_accumulation_steps) 539 | 540 | best_score = 0.00001 541 | best_output_model_file = "None" 542 | ## ############################################################## 543 | # resume optimizer state besides loss to continue train 544 | ## ############################################################## 545 | resumed_epoch = 0 546 | if args.resume_model: 547 | checkpoint = torch.load(args.resume_model, map_location='cpu') 548 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 549 | resumed_epoch = checkpoint['epoch']+1 550 | resumed_loss = checkpoint['loss'] 551 | 552 | global_step = 0 553 | for epoch in range(resumed_epoch, args.epochs): 554 | train_sampler.set_epoch(epoch) 555 | tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, 556 | scheduler, global_step, local_rank=args.local_rank) 557 | if args.local_rank == 0: 558 | logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss) 559 | 560 | output_model_file = save_model(epoch, args, model, optimizer, tr_loss, type_name="") 561 | 562 | ## Run on val dataset, this process is *TIME-consuming*. 563 | # logger.info("Eval on val dataset") 564 | # R1 = eval_epoch(args, model, val_dataloader, device, n_gpu) 565 | 566 | R1 = eval_epoch(args, model, test_dataloader, device, n_gpu) 567 | if best_score <= R1: 568 | best_score = R1 569 | best_output_model_file = output_model_file 570 | logger.info("The best model is: {}, the R1 is: {:.4f}".format(best_output_model_file, best_score)) 571 | 572 | ## Uncomment if want to test on the best checkpoint 573 | # if args.local_rank == 0: 574 | # model = load_model(-1, args, n_gpu, device, model_file=best_output_model_file) 575 | # eval_epoch(args, model, test_dataloader, device, n_gpu) 576 | 577 | elif args.do_eval: 578 | if args.local_rank == 0: 579 | eval_epoch(args, model, test_dataloader, device, n_gpu) 580 | 581 | if __name__ == "__main__": 582 | main() 583 | --------------------------------------------------------------------------------