├── extracted_feats ├── msrvtt │ └── .gitkeep └── msvd │ └── .gitkeep ├── feature_extractor ├── modules │ ├── .gitkeep │ ├── cross-base │ │ └── cross_config.json │ ├── until_config.py │ ├── optimization.py │ ├── file_utils.py │ ├── module_cross.py │ └── until_module.py ├── pretrained_clip4clip │ ├── msvd │ │ └── .gitkeep │ └── msrvtt │ │ └── .gitkeep ├── utility │ ├── dataset.py │ ├── vocabulary.py │ └── util.py ├── README.md ├── util.py └── clip_feature_extractor.py ├── .gitignore ├── dataset └── MSVD │ ├── raw-captions_mapped.pkl │ ├── val_list_mapping.txt │ ├── test_list_mapping.txt │ └── train_list_mapping.txt ├── modules ├── visual-base │ └── visual_config.json ├── decoder-base │ └── decoder_config.json ├── beam.py ├── until_config.py ├── optimization.py ├── file_utils.py ├── modeling.py ├── until_module.py ├── tokenization.py └── module_decoder.py ├── scripts ├── train_msvd.sh ├── train_msrvtt.sh └── eval_msrvtt.sh ├── README.md └── dataloaders ├── dataloader_msrvtt_raw.py ├── dataloader_msvd_raw.py ├── rawvideo_util.py ├── dataloader_msrvtt_feats.py └── dataloader_msvd_feats.py /extracted_feats/msrvtt/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /extracted_feats/msvd/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /feature_extractor/modules/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /feature_extractor/pretrained_clip4clip/msvd/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /feature_extractor/pretrained_clip4clip/msrvtt/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ckpts/ 2 | dataset/MSRVTT/raw/ 3 | dataset/MSVD/raw/ 4 | *.ipynb_checkpoints* 5 | */__pycache__/* -------------------------------------------------------------------------------- /dataset/MSVD/raw-captions_mapped.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liupeng0606/clip4caption/HEAD/dataset/MSVD/raw-captions_mapped.pkl -------------------------------------------------------------------------------- /modules/visual-base/visual_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 1, 11 | "vocab_size": 512 12 | } 13 | -------------------------------------------------------------------------------- /feature_extractor/modules/cross-base/cross_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 512, 6 | 
"initializer_range": 0.02, 7 | "intermediate_size": 2048, 8 | "max_position_embeddings": 128, 9 | "num_attention_heads": 8, 10 | "num_hidden_layers": 4, 11 | "vocab_size": 512 12 | } -------------------------------------------------------------------------------- /modules/decoder-base/decoder_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "num_attention_heads": 12, 9 | "num_hidden_layers": 12, 10 | "type_vocab_size": 2, 11 | "vocab_size": 30522, 12 | "num_decoder_layers": 1, 13 | "max_target_embeddings": 512 14 | } 15 | -------------------------------------------------------------------------------- /feature_extractor/utility/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | # we do not used this class 3 | class CustomDataset(Dataset): 4 | def __init__(self, fo_input, stgraph, target): 5 | 6 | self.fo_input = fo_input 7 | self.stgraph = stgraph 8 | self.target = target 9 | self.n_samples = len(fo_input) 10 | 11 | def __getitem__(self, index): 12 | return self.fo_input[index], self.stgraph[index], self.target[index] 13 | 14 | def __len__(self): 15 | return self.n_samples 16 | -------------------------------------------------------------------------------- /scripts/train_msvd.sh: -------------------------------------------------------------------------------- 1 | # Setup 2 | DATATYPE=msvd 3 | N_GPU=2 4 | N_THREAD=16 5 | 6 | DATA_PATH=../dataset/MSVD 7 | CKPT_ROOT=../ckpts 8 | INIT_MODEL_PATH=../weight/univl.pretrained.bin 9 | # Path to the features you extracted from CLIP4Clip 10 | FEATURES_PATH=../extracted_feats/msvd/MSVD_CLIP4Clip_features.pickle 11 | # Tuning Params 12 | LEARNING_RATE=(7e-6 1e-5 3e-5 7e-5 1e-4 3e-4 7e-4) 13 | 14 | for lr in "${LEARNING_RATE[@]}" 15 | do 16 | python -m torch.distributed.launch --nproc_per_node=${N_GPU} \ 17 | ../train.py --do_train --num_thread_reader=${N_THREAD} \ 18 | --epochs=50 --batch_size=128 --n_display=50 --gradient_accumulation_steps 1 \ 19 | --data_path ${DATA_PATH} --features_path ${FEATURES_PATH} \ 20 | --output_dir ${CKPT_ROOT}/${DATATYPE}_lr${lr} \ 21 | --bert_model bert-base-uncased --do_lower_case \ 22 | --lr ${lr} --max_words 48 --max_frames 20 --batch_size_val 16 \ 23 | --visual_num_hidden_layers 2 --decoder_num_hidden_layers 2 \ 24 | --datatype ${DATATYPE} --init_model ${INIT_MODEL_PATH} \ 25 | --d_model 512 --video_dim 512 26 | done 27 | -------------------------------------------------------------------------------- /scripts/train_msrvtt.sh: -------------------------------------------------------------------------------- 1 | # Setup 2 | DATATYPE=msrvtt 3 | N_GPU=2 4 | N_THREAD=16 5 | 6 | DATA_PATH=../dataset/MSRVTT/MSRVTT_data.json 7 | CKPT_ROOT=../ckpts 8 | INIT_MODEL_PATH=../weight/univl.pretrained.bin 9 | # Path to the features you extracted from CLIP4Clip 10 | FEATURES_PATH=../extracted_feats/msrvtt/MSRVTT_CLIP4Clip_features.pickle 11 | # Tuning Params 12 | LEARNING_RATE=(7e-6 1e-5 3e-5 7e-5 1e-4 3e-4 7e-4) 13 | 14 | for lr in "${LEARNING_RATE[@]}" 15 | do 16 | python -m torch.distributed.launch --nproc_per_node=${N_GPU} \ 17 | ../train.py --do_train --num_thread_reader=${N_THREAD} \ 18 | --epochs=50 --batch_size=1024 --n_display=50 --gradient_accumulation_steps 2 \ 19 | --data_path ${DATA_PATH} 
--features_path ${FEATURES_PATH} \ 20 | --output_dir ${CKPT_ROOT}/${DATATYPE}_lr${lr} \ 21 | --bert_model bert-base-uncased --do_lower_case \ 22 | --lr ${lr} --max_words 48 --max_frames 20 --batch_size_val 128 \ 23 | --visual_num_hidden_layers 2 --decoder_num_hidden_layers 2 \ 24 | --datatype ${DATATYPE} --init_model ${INIT_MODEL_PATH} \ 25 | --d_model 512 --video_dim 512 26 | done 27 | -------------------------------------------------------------------------------- /scripts/eval_msrvtt.sh: -------------------------------------------------------------------------------- 1 | # make sure to change the value of INIT_MODEL_PATH 2 | 3 | # Setup 4 | DATATYPE=msrvtt 5 | N_GPU=1 6 | N_THREAD=16 7 | 8 | DATA_PATH=../dataset/MSRVTT/MSRVTT_data.json 9 | CKPT_ROOT=../checkpoints 10 | # init with the desired model 11 | INIT_MODEL_PATH= 12 | # Path to the features you extracted from CLIP4Clip 13 | FEATURES_PATH=../extracted_feats/msrvtt/MSRVTT_CLIP4Clip_features.pickle 14 | # Tuning Params 15 | LEARNING_RATE=7e-5 16 | 17 | python -m torch.distributed.launch --nproc_per_node=${N_GPU} \ 18 | ../train.py --do_eval --num_thread_reader=${N_THREAD} \ 19 | --epochs=50 --batch_size=1024 --n_display=50 --gradient_accumulation_steps 2 \ 20 | --data_path ${DATA_PATH} --features_path ${FEATURES_PATH} \ 21 | --output_dir ${CKPT_ROOT}/${DATATYPE}_lr${LEARNING_RATE}_eval \ 22 | --bert_model bert-base-uncased --do_lower_case \ 23 | --lr ${LEARNING_RATE} --max_words 48 --max_frames 20 --batch_size_val 128 \ 24 | --visual_num_hidden_layers 2 --decoder_num_hidden_layers 2 \ 25 | --datatype ${DATATYPE} --init_model ${INIT_MODEL_PATH} \ 26 | --d_model 512 --video_dim 512 27 | -------------------------------------------------------------------------------- /dataset/MSVD/val_list_mapping.txt: -------------------------------------------------------------------------------- 1 | vid1201 2 | vid1202 3 | vid1203 4 | vid1204 5 | vid1205 6 | vid1206 7 | vid1207 8 | vid1208 9 | vid1209 10 | vid1210 11 | vid1211 12 | vid1212 13 | vid1213 14 | vid1214 15 | vid1215 16 | vid1216 17 | vid1217 18 | vid1218 19 | vid1219 20 | vid1220 21 | vid1221 22 | vid1222 23 | vid1223 24 | vid1224 25 | vid1225 26 | vid1226 27 | vid1227 28 | vid1228 29 | vid1229 30 | vid1230 31 | vid1231 32 | vid1232 33 | vid1233 34 | vid1234 35 | vid1235 36 | vid1236 37 | vid1237 38 | vid1238 39 | vid1239 40 | vid1240 41 | vid1241 42 | vid1242 43 | vid1243 44 | vid1244 45 | vid1245 46 | vid1246 47 | vid1247 48 | vid1248 49 | vid1249 50 | vid1250 51 | vid1251 52 | vid1252 53 | vid1253 54 | vid1254 55 | vid1255 56 | vid1256 57 | vid1257 58 | vid1258 59 | vid1259 60 | vid1260 61 | vid1261 62 | vid1262 63 | vid1263 64 | vid1264 65 | vid1265 66 | vid1266 67 | vid1267 68 | vid1268 69 | vid1269 70 | vid1270 71 | vid1271 72 | vid1272 73 | vid1273 74 | vid1274 75 | vid1275 76 | vid1276 77 | vid1277 78 | vid1278 79 | vid1279 80 | vid1280 81 | vid1281 82 | vid1282 83 | vid1283 84 | vid1284 85 | vid1285 86 | vid1286 87 | vid1287 88 | vid1288 89 | vid1289 90 | vid1290 91 | vid1291 92 | vid1292 93 | vid1293 94 | vid1294 95 | vid1295 96 | vid1296 97 | vid1297 98 | vid1298 99 | vid1299 100 | vid1300 -------------------------------------------------------------------------------- /feature_extractor/README.md: -------------------------------------------------------------------------------- 1 | # How to extract the video features from CLIP4Clip 2 | ## Prepare the Pretrained CLIP4Clip Model 3 | Train CLIP4Clip model using https://github.com/ArrowLuo/CLIP4Clip , and put the pretrained model 
with the name `pytorch_model.bin` in the `pretrained_clip4clip/msvd` or `pretrained_clip4clip/msrvtt` folder. Alternatively, you can download the pretrained CLIP4Clip (ViT-B/32) model from this [link](https://drive.google.com/drive/folders/16BlJLtGMrGmIty56wAO4POh-1Uw_AJ8o?usp=sharing).
4 | 
5 | ## Prepare the Dataset
6 | Download the raw videos of MSVD ([link](https://www.cs.utexas.edu/users/ml/clamp/videoDescription/)) and MSR-VTT ([link](https://github.com/VisionLearningGroup/caption-guided-saliency/issues/6)), and put the videos in the `raw` folder as shown below:
7 | ```bash
8 | ├── feature_extractor
9 | ├── dataset
10 | │   ├── MSVD
11 | │   │   ├── raw # put the 1970 raw videos in here
12 | │   │   ├── captions
13 | │   │   ├── raw-captions_mapped.pkl # mapping between video ids and captions
14 | │   │   ├── train_list_mapping.txt
15 | │   │   ├── val_list_mapping.txt
16 | │   │   ├── test_list_mapping.txt
17 | │   ├── MSRVTT
18 | │   │   ├── raw # put the 10000 raw videos in here
19 | │   │   ├── msrvtt.csv # list of video ids in the MSR-VTT dataset
20 | │   │   ├── MSRVTT_data.json # metadata of the MSR-VTT dataset, which includes video url, video id, and caption
21 | ```
22 | ## Extract the Features
23 | Execute the following command to extract the video features and save them in the `../extracted_feats` folder:
24 | ### MSR-VTT
25 | ```bash
26 | python clip_feature_extractor.py --dataset_type msrvtt --save_dir ../extracted_feats --dataset_dir ../dataset
27 | ```
28 | ### MSVD
29 | ```bash
30 | python clip_feature_extractor.py --dataset_type msvd --save_dir ../extracted_feats --dataset_dir ../dataset
31 | ```
32 | 
33 | Note that you may need to modify the arguments as per your needs.
34 | 
-------------------------------------------------------------------------------- /feature_extractor/util.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import threading
4 | from torch._utils import ExceptionWrapper
5 | import logging
6 | 
7 | def get_a_var(obj):
8 |     if isinstance(obj, torch.Tensor):
9 |         return obj
10 | 
11 |     if isinstance(obj, list) or isinstance(obj, tuple):
12 |         for result in map(get_a_var, obj):
13 |             if isinstance(result, torch.Tensor):
14 |                 return result
15 |     if isinstance(obj, dict):
16 |         for result in map(get_a_var, obj.items()):
17 |             if isinstance(result, torch.Tensor):
18 |                 return result
19 |     return None
20 | 
21 | def parallel_apply(fct, model, inputs, device_ids):
22 |     modules = nn.parallel.replicate(model, device_ids)
23 |     assert len(modules) == len(inputs)
24 |     lock = threading.Lock()
25 |     results = {}
26 |     grad_enabled = torch.is_grad_enabled()
27 | 
28 |     def _worker(i, module, input):
29 |         torch.set_grad_enabled(grad_enabled)
30 |         device = get_a_var(input).get_device()
31 |         try:
32 |             with torch.cuda.device(device):
33 |                 # this also avoids accidental slicing of `input` if it is a Tensor
34 |                 if not isinstance(input, (list, tuple)):
35 |                     input = (input,)
36 |                 output = fct(module, *input)
37 |             with lock:
38 |                 results[i] = output
39 |         except Exception:
40 |             with lock:
41 |                 results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device))
42 | 
43 |     if len(modules) > 1:
44 |         threads = [threading.Thread(target=_worker, args=(i, module, input))
45 |                    for i, (module, input) in enumerate(zip(modules, inputs))]
46 | 
47 |         for thread in threads:
48 |             thread.start()
49 |         for thread in threads:
50 |             thread.join()
51 |     else:
52 |         _worker(0, modules[0], inputs[0])
53 | 
54 |     outputs = []
55 |     for i in range(len(inputs)):
56 |         output = 
results[i] 57 | if isinstance(output, ExceptionWrapper): 58 | output.reraise() 59 | outputs.append(output) 60 | return outputs 61 | 62 | def get_logger(filename=None): 63 | logger = logging.getLogger('logger') 64 | logger.setLevel(logging.DEBUG) 65 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', 66 | datefmt='%m/%d/%Y %H:%M:%S', 67 | level=logging.INFO) 68 | if filename is not None: 69 | handler = logging.FileHandler(filename) 70 | handler.setLevel(logging.DEBUG) 71 | handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 72 | logging.getLogger().addHandler(handler) 73 | return logger -------------------------------------------------------------------------------- /feature_extractor/utility/vocabulary.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | # we do not used this vocabulary class 3 | class Vocabulary: 4 | PAD_token = 0 # Used for padding short sentences 5 | BOS_token = 1 # Beginning-of-sentence token 6 | EOS_token = 2 # End-of-sentence token 7 | UNK_token = 3 # Unknown word token 8 | 9 | def __init__(self): 10 | self.word2index = {} 11 | self.word2count = {} 12 | self.index2word = {self.PAD_token: "", self.BOS_token: "", self.EOS_token: "", self.UNK_token: ""} 13 | self.num_words = 4 14 | self.num_sentences = 0 15 | self.longest_sentence = 0 16 | self.tokenizer = spacy.load('en_core_web_sm') 17 | 18 | def add_word(self, word): 19 | if word not in self.word2index: 20 | # First entry of word into vocabulary 21 | self.word2index[word] = self.num_words 22 | self.word2count[word] = 1 23 | self.index2word[self.num_words] = word 24 | self.num_words += 1 25 | else: 26 | # Word exists; increase word count 27 | self.word2count[word] += 1 28 | 29 | def add_sentence(self, sentence): 30 | sentence_len = 0 31 | for word in self.tokenizer(sentence): 32 | sentence_len += 1 33 | self.add_word(str(word)) 34 | if sentence_len > self.longest_sentence: 35 | # This is the longest sentence 36 | self.longest_sentence = sentence_len 37 | # Count the number of sentences 38 | self.num_sentences += 1 39 | 40 | def generate_vector(self, sentence="Hello", longest_sentence=None): 41 | # Validation data/test data may have longer sentence, so a parameter longest sentence provided 42 | if longest_sentence is None: 43 | longest_sentence = self.longest_sentence 44 | 45 | vector = [self.BOS_token] 46 | sentence_len = 0 47 | for word in self.tokenizer(sentence): 48 | vector.append(self.to_index(str(word))) 49 | sentence_len += 1 50 | vector.append(self.EOS_token) 51 | 52 | # Add token if needed 53 | if sentence_len < longest_sentence: 54 | for i in range(sentence_len, longest_sentence): 55 | vector.append(self.PAD_token) 56 | 57 | return vector 58 | 59 | def to_word(self, index): 60 | return self.index2word[index] 61 | 62 | def to_index(self, word): 63 | if word not in self.word2index: 64 | return self.UNK_token 65 | 66 | return self.word2index[word] 67 | 68 | def filter_vocab(self, min_word_count=0): 69 | word2count = self.word2count 70 | self.num_words = 4 71 | self.word2index = {} 72 | self.word2count = {} 73 | self.index2word = {self.PAD_token: "", self.BOS_token: "", self.EOS_token: "", self.UNK_token: ""} 74 | for word, count in word2count.items(): 75 | if count>=min_word_count: 76 | self.word2index[word] = self.num_words 77 | self.word2count[word] = count 78 | self.index2word[self.num_words] = word 79 | self.num_words += 1 80 | 
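For reference, since the comment at the top of this file notes that the `Vocabulary` class is not used in the pipeline, here is a minimal usage sketch for illustration only (it assumes spaCy and its `en_core_web_sm` model are installed, and that the import path is relative to `feature_extractor/`):

```python
from utility.vocabulary import Vocabulary  # assumed import path

vocab = Vocabulary()
vocab.add_sentence("a man is playing a guitar")   # tokenizes and registers words
vocab.add_sentence("a woman is cooking")

# [BOS] + word ids + [EOS], padded up to the longest sentence seen so far
vector = vocab.generate_vector("a man is cooking")
print(vector)
print([vocab.to_word(i) for i in vector])
```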
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Reproducing CLIP4Caption
2 | This is the first unofficial implementation of the CLIP4Caption method (ACMMM 2021), which was the state-of-the-art method for video captioning at the time this project was implemented.
3 | 
4 | ![image](https://user-images.githubusercontent.com/5786636/210189414-aef876e0-38ab-4026-8ece-4a1f803a1005.png)
5 | 
6 | **Note**: The provided extracted features and the reproduced results were not obtained with TSN sampling as in the CLIP4Caption paper. However, even without TSN sampling, i.e., using only the original sampling method of CLIP4Clip, similar (and even slightly better) results are obtained than those reported in the CLIP4Caption paper. While reproducing the results, it was observed that TSN sampling could not match the performance reported in the paper.
7 | 
8 | **Paper**: Mingkang Tang, Zhanyu Wang, Zhenhua Liu, Fengyun Rao, Dian Li, and Xiu Li. 2021. CLIP4Caption: CLIP for Video Caption. In Proceedings of the 29th ACM International Conference on Multimedia (MM '21). Association for Computing Machinery, New York, NY, USA, 4858–4862. https://dl.acm.org/doi/10.1145/3474085.3479207
9 | 
10 | ## Reproduced Results
11 | ### MSRVTT
12 | | Method | BLEU@4 | METEOR | ROUGE-L | CIDEr | Checkpoint |
13 | | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |
14 | | CLIP4Caption* (ViT-B/32) | 46.1 | 30.7 | 63.7 | 57.7 | - |
15 | | CLIP4Caption (ViT-B/32) | 46.41 | 30.72 | 63.90 | 58.42 | [link](https://drive.google.com/file/d/17p476sL5_KZoQ2h4e1TU7-qWH4VqzReT/view?usp=sharing) |
16 | | CLIP4Caption (ViT-B/16) | 48.37 | 31.53 | 65.20 | 61.34 | [link](https://drive.google.com/file/d/1WyLktrWLrHMR_ymt_yQHWd4zyvnWSc71/view?usp=sharing) |
17 | 
18 | (*) Original results from the paper
19 | 
20 | ## Setup
21 | Execute the scripts below in the main folder to avoid download conflicts during distributed pretraining.
22 | 
23 | ```bash
24 | mkdir modules/bert-base-uncased
25 | cd modules/bert-base-uncased/
26 | wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt
27 | mv bert-base-uncased-vocab.txt vocab.txt
28 | wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz
29 | tar -xvf bert-base-uncased.tar.gz
30 | rm bert-base-uncased.tar.gz
31 | cd ../../
32 | ```
33 | 
34 | Prepare the conda environment:
35 | ```bash
36 | conda create -n clip4caption python=3.6.9 tqdm boto3 requests pandas
37 | conda activate clip4caption
38 | pip install torch==1.10.2 torchvision --extra-index-url https://download.pytorch.org/whl/cu113
39 | pip install git+https://github.com/Maluuba/nlg-eval.git@master
40 | pip install pycocoevalcap
41 | pip install pickle5
42 | pip install opencv-python==4.5.5.62
43 | ```
44 | 
45 | Download the pretrained weights of UniVL:
46 | ```bash
47 | mkdir -p ./weight
48 | wget -P ./weight https://github.com/microsoft/UniVL/releases/download/v0/univl.pretrained.bin
49 | ```
50 | 
51 | ## Extract the Video Features
52 | The extracted features (ViT-B/32) are provided [here](https://drive.google.com/drive/folders/1GHpAKDNU3qZxzIk6zqdatMjQoGc6Q3_9?usp=sharing).
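If you use the provided features, a quick sanity check of the downloaded pickle might look like the following (a minimal sketch; the file maps each video id to a `(num_sampled_frames, 512)` float array for ViT-B/32, as written by `feature_extractor/clip_feature_extractor.py`; the path below is an assumption):

```python
import pickle

# assumed location; adjust to wherever you placed the downloaded file
with open("extracted_feats/msrvtt/MSRVTT_CLIP4Clip_features.pickle", "rb") as f:
    feats = pickle.load(f)

print("number of videos:", len(feats))
video_id, arr = next(iter(feats.items()))
print(video_id, arr.shape)  # e.g. (20, 512) with max_frames=20 and ViT-B/32
```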
However, if you want to extract the features yourself, follow the instructions written [here](https://github.com/willyfh/clip4caption/tree/main/feature_extractor).
53 | 
54 | ## Training & Evaluation
55 | The shell scripts to train and to evaluate the model are provided [here](https://github.com/willyfh/clip4caption/tree/main/scripts). You may need to modify the scripts as per your needs.
56 | 
57 | ## References
58 | This repository is implemented based on [UniVL](https://github.com/microsoft/UniVL) and [CLIP4Clip](https://github.com/ArrowLuo/CLIP4Clip).
59 | 
-------------------------------------------------------------------------------- /modules/beam.py: --------------------------------------------------------------------------------
1 | """
2 | Manage beam search info structure.
3 | Heavily borrowed from OpenNMT-py.
4 | For code in OpenNMT-py, please check the following link (maybe in oldest version):
5 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py
6 | """
7 | 
8 | import torch
9 | 
10 | class Constants():
11 |     """ Default constants for tokenizer """
12 |     def __init__(self):
13 |         self.PAD = 0
14 |         self.UNK = 1
15 |         self.BOS = 2
16 |         self.EOS = 3
17 |         self.PAD_WORD = '[PAD]'
18 |         self.UNK_WORD = '[UNK]'
19 |         self.BOS_WORD = '[CLS]'
20 |         self.EOS_WORD = '[SEP]'
21 | 
22 |     @classmethod
23 |     def from_tokenizer(cls, tokenizer):
24 |         instance = cls()
25 |         instance.PAD = tokenizer.vocab[instance.PAD_WORD]
26 |         instance.UNK = tokenizer.vocab[instance.UNK_WORD]
27 |         instance.BOS = tokenizer.vocab[instance.BOS_WORD]
28 |         instance.EOS = tokenizer.vocab[instance.EOS_WORD]
29 |         return instance
30 | 
31 | class Beam():
32 |     '''Implementation of the beam search from the "Beam Search Strategies for Neural Machine Translation"
33 |     paper.
34 |     Params:
35 |         size: beam search width.
36 |         device: device for running the algorithm.
37 |         tokenizer: optional tokenizer; if provided, the special token ids are taken from its vocab, otherwise the defaults above are used.
38 |     '''
39 | 
40 |     def __init__(self, size, device=False, tokenizer=None):
41 |         if tokenizer is None:
42 |             self.constants = Constants()
43 |         else:
44 |             self.constants = Constants.from_tokenizer(tokenizer)
45 | 
46 |         self.size = size
47 |         self._done = False
48 |         # The score for each hypothesis on the beam.
49 |         self.scores = torch.zeros((size,), dtype=torch.float, device=device)
50 |         self.all_scores = []
51 | 
52 |         # The backpointers at each time-step.
53 |         self.prev_ks = []
54 | 
55 |         # The outputs at each time-step.
56 |         self.next_ys = [torch.full((size,), self.constants.BOS, dtype=torch.long, device=device)]
57 | 
58 |     def get_current_state(self):
59 |         "Get the outputs for the current timestep."
60 |         return self.get_tentative_hypothesis()
61 | 
62 |     def get_current_origin(self):
63 |         "Get the backpointers for the current timestep."
64 |         return self.prev_ks[-1]
65 | 
66 |     @property
67 |     def done(self):
68 |         return self._done
69 | 
70 |     def advance(self, word_prob, word_length=None):
71 | 
72 |         "Update beam status and check if finished or not."
73 |         num_words = word_prob.size(1)
74 |         # Sum the previous scores.
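        # At the first step there is no history, so only row 0 of `word_prob`
        # seeds the beam; afterwards each beam's running score is broadcast-added
        # to the vocabulary log-probs, the flattened (beam x vocab) scores are
        # top-k'd, and each flat index is decomposed into a backpointer (`prev_k`)
        # and the next token id.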
75 | if len(self.prev_ks) > 0: 76 | beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob) 77 | else: 78 | beam_lk = word_prob[0] 79 | flat_beam_lk = beam_lk.view(-1) 80 | best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 1st sort 81 | self.all_scores.append(self.scores) 82 | self.scores = best_scores 83 | # bestScoresId is flattened as a (beam x word) array, 84 | # so we need to calculate which word and beam each score came from 85 | prev_k = best_scores_id // num_words 86 | self.prev_ks.append(prev_k) 87 | self.next_ys.append(best_scores_id - prev_k * num_words) 88 | # End condition is when top-of-beam is EOS. 89 | if self.next_ys[-1][0].item() == self.constants.EOS: 90 | self._done = True 91 | 92 | return self._done 93 | 94 | def sort_scores(self): 95 | "Sort the scores." 96 | return torch.sort(self.scores, 0, True) 97 | 98 | def get_the_best_score_and_idx(self): 99 | "Get the score of the best in the beam." 100 | scores, ids = self.sort_scores() 101 | return scores[1], ids[1] 102 | 103 | def get_tentative_hypothesis(self): 104 | "Get the decoded sequence for the current timestep." 105 | 106 | if len(self.next_ys) == 1: 107 | dec_seq = self.next_ys[0].unsqueeze(1) 108 | else: 109 | _, keys = self.sort_scores() 110 | hyps = [self.get_hypothesis(k) for k in keys] 111 | hyps = [[self.constants.BOS] + h for h in hyps] 112 | dec_seq = torch.LongTensor(hyps) 113 | 114 | return dec_seq 115 | 116 | def get_hypothesis(self, k): 117 | """ Walk back to construct the full hypothesis. """ 118 | hyp = [] 119 | for j in range(len(self.prev_ks) - 1, -1, -1): 120 | hyp.append(self.next_ys[j+1][k]) 121 | k = self.prev_ks[j][k] 122 | 123 | return list(map(lambda x: x.item(), hyp[::-1])) 124 | -------------------------------------------------------------------------------- /feature_extractor/modules/until_config.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """PyTorch BERT model.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import copy 24 | import json 25 | import logging 26 | import tarfile 27 | import tempfile 28 | import shutil 29 | import torch 30 | from .file_utils import cached_path 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | class PretrainedConfig(object): 35 | 36 | pretrained_model_archive_map = {} 37 | config_name = "" 38 | weights_name = "" 39 | 40 | @classmethod 41 | def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None): 42 | archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name) 43 | if os.path.exists(archive_file) is False: 44 | if pretrained_model_name in cls.pretrained_model_archive_map: 45 | archive_file = cls.pretrained_model_archive_map[pretrained_model_name] 46 | else: 47 | archive_file = pretrained_model_name 48 | 49 | # redirect to the cache, if necessary 50 | try: 51 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) 52 | except FileNotFoundError: 53 | if task_config is None or task_config.local_rank == 0: 54 | logger.error( 55 | "Model name '{}' was not found in model name list. " 56 | "We assumed '{}' was a path or url but couldn't find any file " 57 | "associated to this path or url.".format( 58 | pretrained_model_name, 59 | archive_file)) 60 | return None 61 | if resolved_archive_file == archive_file: 62 | if task_config is None or task_config.local_rank == 0: 63 | logger.info("loading archive file {}".format(archive_file)) 64 | else: 65 | if task_config is None or task_config.local_rank == 0: 66 | logger.info("loading archive file {} from cache at {}".format( 67 | archive_file, resolved_archive_file)) 68 | tempdir = None 69 | if os.path.isdir(resolved_archive_file): 70 | serialization_dir = resolved_archive_file 71 | else: 72 | # Extract archive to temp dir 73 | tempdir = tempfile.mkdtemp() 74 | if task_config is None or task_config.local_rank == 0: 75 | logger.info("extracting archive file {} to temp dir {}".format( 76 | resolved_archive_file, tempdir)) 77 | with tarfile.open(resolved_archive_file, 'r:gz') as archive: 78 | archive.extractall(tempdir) 79 | serialization_dir = tempdir 80 | # Load config 81 | config_file = os.path.join(serialization_dir, cls.config_name) 82 | config = cls.from_json_file(config_file) 83 | config.type_vocab_size = type_vocab_size 84 | if task_config is None or task_config.local_rank == 0: 85 | logger.info("Model config {}".format(config)) 86 | 87 | if state_dict is None: 88 | weights_path = os.path.join(serialization_dir, cls.weights_name) 89 | if os.path.exists(weights_path): 90 | state_dict = torch.load(weights_path, map_location='cpu') 91 | else: 92 | if task_config is None or task_config.local_rank == 0: 93 | logger.info("Weight doesn't exsits. 
{}".format(weights_path)) 94 | 95 | if tempdir: 96 | # Clean up temp dir 97 | shutil.rmtree(tempdir) 98 | 99 | return config, state_dict 100 | 101 | @classmethod 102 | def from_dict(cls, json_object): 103 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 104 | config = cls(vocab_size_or_config_json_file=-1) 105 | for key, value in json_object.items(): 106 | config.__dict__[key] = value 107 | return config 108 | 109 | @classmethod 110 | def from_json_file(cls, json_file): 111 | """Constructs a `BertConfig` from a json file of parameters.""" 112 | with open(json_file, "r", encoding='utf-8') as reader: 113 | text = reader.read() 114 | return cls.from_dict(json.loads(text)) 115 | 116 | def __repr__(self): 117 | return str(self.to_json_string()) 118 | 119 | def to_dict(self): 120 | """Serializes this instance to a Python dictionary.""" 121 | output = copy.deepcopy(self.__dict__) 122 | return output 123 | 124 | def to_json_string(self): 125 | """Serializes this instance to a JSON string.""" 126 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" -------------------------------------------------------------------------------- /modules/until_config.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model. """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import copy 24 | import json 25 | import logging 26 | import tarfile 27 | import tempfile 28 | import shutil 29 | import torch 30 | from .file_utils import cached_path 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | class PretrainedConfig(object): 35 | 36 | pretrained_model_archive_map = {} 37 | config_name = "" 38 | weights_name = "" 39 | 40 | @classmethod 41 | def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None): 42 | archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name) 43 | if os.path.exists(archive_file) is False: 44 | if pretrained_model_name in cls.pretrained_model_archive_map: 45 | archive_file = cls.pretrained_model_archive_map[pretrained_model_name] 46 | else: 47 | archive_file = pretrained_model_name 48 | 49 | # redirect to the cache, if necessary 50 | try: 51 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) 52 | except FileNotFoundError: 53 | if task_config is None or task_config.local_rank == 0: 54 | logger.error( 55 | "Model name '{}' was not found in model name list. 
" 56 | "We assumed '{}' was a path or url but couldn't find any file " 57 | "associated to this path or url.".format( 58 | pretrained_model_name, 59 | archive_file)) 60 | return None 61 | if resolved_archive_file == archive_file: 62 | if task_config is None or task_config.local_rank == 0: 63 | logger.info("loading archive file {}".format(archive_file)) 64 | else: 65 | if task_config is None or task_config.local_rank == 0: 66 | logger.info("loading archive file {} from cache at {}".format( 67 | archive_file, resolved_archive_file)) 68 | tempdir = None 69 | if os.path.isdir(resolved_archive_file): 70 | serialization_dir = resolved_archive_file 71 | else: 72 | # Extract archive to temp dir 73 | tempdir = tempfile.mkdtemp() 74 | if task_config is None or task_config.local_rank == 0: 75 | logger.info("extracting archive file {} to temp dir {}".format( 76 | resolved_archive_file, tempdir)) 77 | with tarfile.open(resolved_archive_file, 'r:gz') as archive: 78 | archive.extractall(tempdir) 79 | serialization_dir = tempdir 80 | # Load config 81 | config_file = os.path.join(serialization_dir, cls.config_name) 82 | config = cls.from_json_file(config_file) 83 | config.type_vocab_size = type_vocab_size 84 | if task_config is None or task_config.local_rank == 0: 85 | logger.info("Model config {}".format(config)) 86 | 87 | if state_dict is None: 88 | weights_path = os.path.join(serialization_dir, cls.weights_name) 89 | if os.path.exists(weights_path): 90 | state_dict = torch.load(weights_path, map_location='cpu') 91 | else: 92 | if task_config is None or task_config.local_rank == 0: 93 | logger.info("Weight doesn't exsits. {}".format(weights_path)) 94 | 95 | if tempdir: 96 | # Clean up temp dir 97 | shutil.rmtree(tempdir) 98 | 99 | return config, state_dict 100 | 101 | @classmethod 102 | def from_dict(cls, json_object): 103 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 104 | config = cls(vocab_size_or_config_json_file=-1) 105 | for key, value in json_object.items(): 106 | config.__dict__[key] = value 107 | return config 108 | 109 | @classmethod 110 | def from_json_file(cls, json_file): 111 | """Constructs a `BertConfig` from a json file of parameters.""" 112 | with open(json_file, "r", encoding='utf-8') as reader: 113 | text = reader.read() 114 | return cls.from_dict(json.loads(text)) 115 | 116 | def __repr__(self): 117 | return str(self.to_json_string()) 118 | 119 | def to_dict(self): 120 | """Serializes this instance to a Python dictionary.""" 121 | output = copy.deepcopy(self.__dict__) 122 | return output 123 | 124 | def to_json_string(self): 125 | """Serializes this instance to a JSON string.""" 126 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" -------------------------------------------------------------------------------- /dataloaders/dataloader_msrvtt_raw.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import os 7 | import torch 8 | from torch.utils.data import Dataset 9 | import numpy as np 10 | import pandas as pd 11 | from collections import defaultdict 12 | import json 13 | import cv2 14 | import random 15 | from rawvideo_util import RawVideoExtractor 16 | from PIL import Image 17 | 18 | # Based on https://github.com/ArrowLuo/CLIP4Clip 19 | 20 | class MSRVTT_Raw_DataLoader(Dataset): 21 | """Implementation of the dataloader for MSRVTT. 
Mainly used in the feature extraction process. 22 | Params: 23 | csv_path: Path to the msrvtt.csv file. 24 | videos_path: Path to the video files. 25 | max_words: Max word length retained. Any more than the value will be truncated. Default: 73 26 | feature_framerate: sampling rate in second. Default: 1.0 27 | max_frames: Max frame sampled. Any more than the value will be ignored. Default: 20 28 | image_resolution: Processed image's width and height, in pixel. If param transform_type = 0 and 29 | the original image is greater than this value, it will be resized and center cropped. Default: 224 30 | frame_order: 0: ordinary order; 1: reverse order; 2: random order. Default: 0 31 | slice_framepos: 0: sample from the first frames; 1: sample from the last frames; 32 | 2: sample uniformly. Default: 0 33 | transform_type: 0: default transformation; 1: transformation for objects, iou, temporal, action; 34 | 2: transformation for i3d;. Default: 0 35 | """ 36 | def __init__( 37 | self, 38 | csv_path, 39 | videos_path, 40 | max_words=73, 41 | feature_framerate=1, 42 | max_frames=20, 43 | image_resolution=224, 44 | frame_order=0, 45 | slice_framepos=0, 46 | transform_type =0, 47 | ): 48 | self.data = pd.read_csv(csv_path) 49 | self.videos_path = videos_path 50 | self.feature_framerate = feature_framerate 51 | self.max_words = max_words 52 | self.max_frames = max_frames 53 | self.transform_type = transform_type 54 | 55 | # 0: ordinary order; 1: reverse order; 2: random order. 56 | self.frame_order = frame_order 57 | assert self.frame_order in [0, 1, 2] 58 | # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. 59 | self.slice_framepos = slice_framepos 60 | assert self.slice_framepos in [0, 1, 2] 61 | 62 | self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution,type = self.transform_type) 63 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", 64 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} 65 | 66 | def __len__(self): 67 | return len(self.data) 68 | 69 | def _get_rawvideo(self, choice_video_ids): 70 | video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.long) 71 | max_video_length = [0] * len(choice_video_ids) 72 | 73 | # Pair x L x T x 3 x H x W 74 | # video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3, 75 | # self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float) 76 | 77 | for i, video_id in enumerate(choice_video_ids): 78 | # Individual for YoucokII dataset, due to it video format 79 | video_path = os.path.join(self.videos_path, "{}.mp4".format(video_id)) 80 | if os.path.exists(video_path) is False: 81 | video_path = video_path.replace(".mp4", ".webm") 82 | 83 | raw_video_data,shapes = self.rawVideoExtractor.get_video_data(video_path) 84 | raw_video_data = raw_video_data['video'] 85 | # Pair x L x T x 3 x H x W 86 | video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3, 87 | shapes[2], shapes[3]), dtype=np.float) 88 | if len(raw_video_data.shape) > 3: 89 | raw_video_data_clip = raw_video_data 90 | # L x T x 3 x H x W 91 | raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) 92 | if self.max_frames < raw_video_slice.shape[0]: 93 | if self.slice_framepos == 0: 94 | video_slice = raw_video_slice[:self.max_frames, ...] 95 | elif self.slice_framepos == 1: 96 | video_slice = raw_video_slice[-self.max_frames:, ...] 
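                    # slice_framepos == 2 is handled by the else branch below:
                    # np.linspace selects self.max_frames evenly spaced frame
                    # indices over the whole clip (uniform sampling).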
97 | else: 98 | sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) 99 | video_slice = raw_video_slice[sample_indx, ...] 100 | else: 101 | video_slice = raw_video_slice 102 | 103 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) 104 | 105 | slice_len = video_slice.shape[0] 106 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len 107 | if slice_len < 1: 108 | pass 109 | else: 110 | video[i][:slice_len, ...] = video_slice 111 | else: 112 | print("video path: {} error. video id: {}".format(video_path, video_id)) 113 | 114 | for i, v_length in enumerate(max_video_length): 115 | video_mask[i][:v_length] = [1] * v_length 116 | 117 | return video, video_mask 118 | 119 | def __getitem__(self, idx): 120 | video_id = self.data['video_id'].values[idx] 121 | choice_video_ids = [video_id] 122 | 123 | video, video_mask = self._get_rawvideo(choice_video_ids) 124 | return video_id,video, video_mask 125 | -------------------------------------------------------------------------------- /feature_extractor/clip_feature_extractor.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import cv2 4 | import numpy as np 5 | from numpy import dot 6 | from numpy.linalg import norm 7 | import sys 8 | import glob 9 | import json 10 | import math 11 | from tqdm import tqdm 12 | import torch 13 | 14 | from modules.until_module import PreTrainedModel, AllGather, CrossEn 15 | from modules.module_cross import CrossModel, CrossConfig, Transformer as TransformerClip 16 | 17 | from modules.module_clip import CLIP, convert_weights 18 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 19 | import pickle 20 | import pathlib 21 | 22 | from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE 23 | from modules.modeling import CLIP4Clip 24 | from modules.optimization import BertAdam 25 | from util import parallel_apply, get_logger 26 | 27 | sys.path.append("../dataloaders/") 28 | from dataloader_msvd_raw import MSVD_Raw_DataLoader 29 | from dataloader_msrvtt_raw import MSRVTT_Raw_DataLoader 30 | 31 | 32 | 33 | # Argument 34 | class args: 35 | msvd = True # or msvd = False for MSR-VTT 36 | max_frames = 20 37 | pretrined_clip4clip_dir='pretrained' 38 | 39 | def get_args(): 40 | parser = argparse.ArgumentParser(description="CLIP Feature Extractor") 41 | parser.add_argument('--dataset_type', choices=['msvd', 'msrvtt'], default='msvd', type=str, help='msvd or msrvtt') 42 | parser.add_argument('--dataset_dir', type=str, default='../dataset', help='should be pointed to the location where the MSVD and MSRVTT dataset located') 43 | parser.add_argument('--save_dir', type=str, default='../extracted_feats', help='location of the extracted features') 44 | parser.add_argument('--slice_framepos', choices=[0,1,2], type=int, default=2, 45 | help='0: sample from the first frames; 1: sample from the last frames; 2: sample uniformly.') 46 | parser.add_argument('--max_frames', type=int, default=20, help='max sampled frames') 47 | parser.add_argument('--frame_order', type=int, choices=[0,1,2], default=0, help='0: normal order; 1: reverse order; 2: random order.') 48 | parser.add_argument('--pretrained_clip4clip_dir', type=str, default='pretrained_clip4clip/', help='path to the pretrained CLIP4Clip model') 49 | parser.add_argument('--device', choices=["cpu", "cuda"], type=str, default='cuda', help='cuda or cpu') 50 | 
parser.add_argument('--pretrained_clip_name', type=str, choices=["ViT-B/32", "ViT-B/16"], default="ViT-B/32") 51 | 52 | args = parser.parse_args() 53 | 54 | if args.device == "cuda": 55 | args.device = torch.device('cuda') 56 | 57 | if args.dataset_type=="msvd": 58 | dset_path = os.path.join(args.dataset_dir,'MSVD') 59 | args.videos_path = os.path.join(dset_path,'raw') # video .avi 60 | 61 | args.data_path =os.path.join(os.path.join(dset_path,'captions','youtube_mapping.txt')) 62 | args.max_words = 30 63 | 64 | save_dir = os.path.join(args.save_dir, "msvd") 65 | pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True) 66 | args.save_file = os.path.join(save_dir,'MSVD_CLIP4Clip_features.pickle') 67 | 68 | args.pretrained_clip4clip_path = os.path.join(args.pretrained_clip4clip_dir, 'msvd','pytorch_model.bin') 69 | 70 | elif args.dataset_type=="msrvtt": 71 | dset_path = os.path.join(args.dataset_dir,'MSRVTT') 72 | args.videos_path = os.path.join(dset_path,'raw') 73 | 74 | args.data_path=os.path.join(dset_path,'MSRVTT_data.json') 75 | args.max_words = 73 76 | args.csv_path = os.path.join(dset_path,'msrvtt.csv') 77 | 78 | save_dir = os.path.join(args.save_dir, "msrvtt") 79 | pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True) 80 | args.save_file = os.path.join(save_dir,'MSRVTT_CLIP4Clip_features.pickle') 81 | 82 | args.pretrained_clip4clip_path = os.path.join(args.pretrained_clip4clip_dir, 'msrvtt','pytorch_model.bin') 83 | 84 | return args 85 | 86 | def get_dataloader(args): 87 | 88 | dataloader = None 89 | if args.dataset_type=="msvd": 90 | dataloader = MSVD_Raw_DataLoader( 91 | data_path=args.data_path, 92 | videos_path=args.videos_path, 93 | max_frames=args.max_frames, 94 | frame_order=args.frame_order, 95 | slice_framepos=args.slice_framepos, 96 | transform_type = 0, 97 | ) 98 | elif args.dataset_type=="msrvtt": 99 | dataloader = MSRVTT_Raw_DataLoader( 100 | csv_path=args.csv_path, 101 | videos_path=args.videos_path, 102 | max_frames=args.max_frames, 103 | frame_order=args.frame_order, 104 | slice_framepos=args.slice_framepos, 105 | transform_type = 0, 106 | ) 107 | return dataloader 108 | 109 | def load_model(args): 110 | model_state_dict = torch.load(args.pretrained_clip4clip_path, map_location='cpu') 111 | cache_dir = os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed') 112 | model = CLIP4Clip.from_pretrained('cross-base', cache_dir=cache_dir, state_dict=model_state_dict, task_config=args) 113 | clip = model.clip.to(args.device) 114 | return clip 115 | 116 | 117 | def main(): 118 | args = get_args() 119 | dataloader = get_dataloader(args) 120 | model = load_model(args) 121 | model.eval() 122 | 123 | with torch.no_grad(): 124 | data ={} 125 | stop = False 126 | with open(args.save_file, 'wb') as handle: 127 | 128 | for i in tqdm(range(len(dataloader))): 129 | video_id,video,video_mask = dataloader[i] 130 | 131 | tensor = video[0] 132 | tensor = tensor[video_mask[0]==1,:] 133 | tensor = torch.as_tensor(tensor).float() 134 | video_frame,num,channel,h,w = tensor.shape 135 | tensor = tensor.view(video_frame*num, channel, h, w) 136 | 137 | video_frame,channel,h,w = tensor.shape 138 | 139 | output = model.encode_image(tensor.to(args.device), video_frame=video_frame).float().to(args.device) 140 | output = output.detach().cpu().numpy() 141 | data[video_id]=output 142 | 143 | del output 144 | pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) 145 | 146 | if __name__ == "__main__": 147 | main() 
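As noted in the main README, the features here are extracted with the dataloaders' uniform sampling (`slice_framepos == 2`) rather than the TSN-style sampling described in the CLIP4Caption paper. For clarity, a sketch contrasting the two index-selection schemes; this is illustrative only (one common formulation of TSN sampling), not code from this repository:

```python
import numpy as np

def uniform_indices(num_frames: int, max_frames: int) -> np.ndarray:
    # Uniform sampling as used by the dataloaders in this repo
    # (slice_framepos == 2): evenly spaced indices over the whole clip.
    return np.linspace(0, num_frames - 1, num=max_frames, dtype=int)

def tsn_indices(num_frames: int, max_frames: int, rng=None) -> np.ndarray:
    # TSN-style sampling: split the clip into max_frames equal segments
    # and draw one random index from each segment.
    rng = np.random.default_rng() if rng is None else rng
    edges = np.linspace(0, num_frames, num=max_frames + 1, dtype=int)
    picks = [int(rng.integers(lo, hi)) if hi > lo else min(lo, num_frames - 1)
             for lo, hi in zip(edges[:-1], edges[1:])]
    return np.array(picks)

print(uniform_indices(300, 20))
print(tsn_indices(300, 20))
```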
-------------------------------------------------------------------------------- /dataloaders/dataloader_msvd_raw.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import os 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | import pickle 10 | from rawvideo_util import RawVideoExtractor 11 | 12 | # Based on https://github.com/ArrowLuo/CLIP4Clip 13 | 14 | class MSVD_Raw_DataLoader(Dataset): 15 | """This dataloader is mainly used in the feature extraction process. It will 16 | Params: 17 | data_path: Path to the MSVD folder. 18 | videos_path: Path to the video files. 19 | max_words: Max word length retained. Any more than the value will be truncated. Default: 30 20 | feature_framerate: sampling rate in second. Default: 1.0 21 | max_frames: Max frame sampled. Any more than the value will be ignored. Default: 100 22 | image_resolution: Processed image's width and height, in pixel. If param transform_type = 0 and 23 | the original image is greater than this value, it will be resized and center cropped. Default: 224 24 | frame_order: 0: ordinary order; 1: reverse order; 2: random order. Default: 0 25 | slice_framepos: 0: sample from the first frames; 1: sample from the last frames; 26 | 2: sample uniformly. Default: 0 27 | transform_type: 0: default transformation; 1: transformation for objects, iou, temporal, action; 28 | 2: transformation for i3d;. Default: 0 29 | """ 30 | 31 | def __init__( 32 | self, 33 | data_path, 34 | videos_path, 35 | max_words=30, 36 | feature_framerate=1, 37 | max_frames=100, 38 | image_resolution=224, 39 | frame_order=0, 40 | slice_framepos=0, 41 | transform_type =0, 42 | ): 43 | self.data_path = data_path 44 | self.videos_path = videos_path 45 | self.feature_framerate = feature_framerate 46 | self.max_words = max_words 47 | self.max_frames = max_frames 48 | self.id_dict = {} 49 | self.transform_type = transform_type 50 | 51 | 52 | # 0: ordinary order; 1: reverse order; 2: random order. 53 | self.frame_order = frame_order 54 | assert self.frame_order in [0, 1, 2] 55 | # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. 
56 | self.slice_framepos = slice_framepos 57 | assert self.slice_framepos in [0, 1, 2] 58 | 59 | url2id = {} 60 | for line in open(data_path, 'r').readlines(): 61 | url2id[line.strip().split(' ')[0]] = line.strip().split(' ')[-1] 62 | 63 | video_dict = {} 64 | for root, dub_dir, video_files in os.walk(self.videos_path): 65 | for video_file in video_files: 66 | video_id_ = url2id[".".join(video_file.split(".")[:-1])] 67 | 68 | self.id_dict[len(self.id_dict)] = video_id_ 69 | 70 | file_path_ = os.path.join(root, video_file) 71 | video_dict[video_id_] = file_path_ 72 | self.video_dict = video_dict 73 | 74 | self.sample_len = 0 75 | self.sentences_dict = {} 76 | self.cut_off_points = [] 77 | 78 | 79 | print("Video number: {}".format(len(self.video_dict))) 80 | print("Id number: {}".format(len(self.id_dict))) 81 | 82 | 83 | 84 | self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution,type = self.transform_type) 85 | self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", 86 | "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} 87 | 88 | def __len__(self): 89 | return len(self.id_dict) 90 | 91 | def _get_rawvideo(self, choice_video_ids): 92 | video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.long) 93 | max_video_length = [0] * len(choice_video_ids) 94 | 95 | for i, video_id in enumerate(choice_video_ids): 96 | 97 | video_path = self.video_dict[video_id] 98 | 99 | raw_video_data,shapes = self.rawVideoExtractor.get_video_data(video_path) 100 | raw_video_data = raw_video_data['video'] 101 | 102 | # Pair x L x T x 3 x H x W 103 | video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3, 104 | shapes[2], shapes[3]), dtype=np.float) 105 | 106 | # Pair x L x T x 3 x H x W 107 | video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3, 108 | shapes[2], shapes[3]), dtype=np.float) 109 | 110 | if len(raw_video_data.shape) > 3: 111 | raw_video_data_clip = raw_video_data 112 | # L x T x 3 x H x W 113 | raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) 114 | if self.max_frames < raw_video_slice.shape[0]: 115 | if self.slice_framepos == 0: 116 | video_slice = raw_video_slice[:self.max_frames, ...] 117 | elif self.slice_framepos == 1: 118 | video_slice = raw_video_slice[-self.max_frames:, ...] 119 | else: 120 | sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) 121 | video_slice = raw_video_slice[sample_indx, ...] 122 | else: 123 | video_slice = raw_video_slice 124 | 125 | video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) 126 | 127 | slice_len = video_slice.shape[0] 128 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len 129 | if slice_len < 1: 130 | pass 131 | else: 132 | video[i][:slice_len, ...] = video_slice 133 | 134 | else: 135 | print("video path: {} error. 
video id: {}".format(video_path, video_id)) 136 | 137 | for i, v_length in enumerate(max_video_length): 138 | video_mask[i][:v_length] = [1] * v_length 139 | 140 | return video, video_mask 141 | 142 | def __getitem__(self, idx): 143 | #print(idx) 144 | if idx >=len(self):raise IndexError 145 | video_id = self.id_dict[idx] 146 | choice_video_ids = [video_id] 147 | video, video_mask = self._get_rawvideo(choice_video_ids) 148 | return video_id, video, video_mask 149 | -------------------------------------------------------------------------------- /dataloaders/rawvideo_util.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | from PIL import Image 4 | # pytorch=1.7.1 5 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 6 | # pip install opencv-python 7 | import cv2 8 | import torchvision.transforms as transforms 9 | # Based on https://github.com/ArrowLuo/CLIP4Clip 10 | class RawVideoExtractorCV2(): 11 | """Implementation of the raw video preprocessing. 12 | Params: 13 | size: Processed image's width and height, in pixel. If param transform_type = 0 and 14 | the original image is greater than this value, it will be resized and center cropped. Default: 224 15 | framerate: sampling rate in second. Default: 1.0 16 | type: 0: default transformation; 1: transformation for objects, iou, temporal, action; 17 | 2: transformation for i3d;. Default: 0 18 | """ 19 | def __init__(self, size=224, framerate=-1, type=0): 20 | self.size = size 21 | self.framerate = framerate 22 | self.type = type 23 | self.transform = self._transform(self.size) 24 | 25 | 26 | def _transform(self, n_px): 27 | if self.type == 0: 28 | return Compose([ 29 | Resize(n_px, interpolation=Image.BICUBIC), 30 | CenterCrop(n_px), 31 | lambda image: image.convert("RGB"), 32 | ToTensor(), 33 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 34 | ]) 35 | #objects, iou, temporal, action 36 | elif self.type == 1: 37 | return Compose([transforms.ToTensor()]) 38 | # i3d 39 | elif self.type == 2: 40 | mean = [0.5, 0.5, 0.5] 41 | std = [0.5, 0.5, 0.5] 42 | 43 | return Compose([ 44 | transforms.ToPILImage(), 45 | transforms.Resize(256), 46 | transforms.CenterCrop(224), 47 | transforms.ToTensor(), 48 | transforms.Normalize(mean, std)]) 49 | 50 | 51 | def video_to_tensor(self, video_file, preprocess, sample_fp=0, start_time=None, end_time=None, patch=0, overlapped=0): 52 | if start_time is not None or end_time is not None: 53 | assert isinstance(start_time, int) and isinstance(end_time, int) \ 54 | and start_time > -1 and end_time > start_time 55 | assert sample_fp > -1 56 | 57 | # Samples a frame sample_fp X frames. 
58 | cap = cv2.VideoCapture(video_file) 59 | frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 60 | fps = int(cap.get(cv2.CAP_PROP_FPS)) 61 | 62 | total_duration = (frameCount + fps - 1) // fps 63 | start_sec, end_sec = 0, total_duration 64 | 65 | if start_time is not None: 66 | start_sec, end_sec = start_time, end_time if end_time <= total_duration else total_duration 67 | cap.set(cv2.CAP_PROP_POS_FRAMES, int(start_time * fps)) 68 | 69 | interval = 1 70 | if sample_fp > 0: 71 | interval = fps // sample_fp 72 | else: 73 | sample_fp = fps 74 | if interval == 0: interval = 1 75 | 76 | inds = [ind for ind in np.arange(0, fps, interval)] 77 | assert len(inds) >= sample_fp 78 | inds = inds[:sample_fp] 79 | 80 | ret = True 81 | images, included = [], [] 82 | 83 | for sec in np.arange(start_sec, end_sec + 1): 84 | if not ret: break 85 | sec_base = int(sec * fps) 86 | for ind in inds: 87 | cap.set(cv2.CAP_PROP_POS_FRAMES, sec_base + ind) 88 | ret, frame = cap.read() 89 | if not ret: break 90 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 91 | 92 | smaller_edge = min(frame_rgb.shape[0], frame_rgb.shape[1]) 93 | frame_rgb = np.array(CenterCrop(smaller_edge)(Image.fromarray(frame_rgb).convert("RGB"))) 94 | patches = [] 95 | 96 | 97 | if patch>1: 98 | for i in range(patch): 99 | x_crop_start = i * int(frame_rgb.shape[0]/patch) 100 | if overlapped>0 and i!=0: 101 | x_crop_start-= int((frame_rgb.shape[0]/(patch))*overlapped) 102 | 103 | x_crop_end = (i+1)*int(frame_rgb.shape[0]/patch) 104 | if overlapped>0: 105 | x_crop_end+= int((frame_rgb.shape[0]/(patch))*overlapped) 106 | if i==patch-1: 107 | x_crop_end = frame_rgb.shape[0] 108 | 109 | for j in range(patch): 110 | y_crop_start = j*int(frame_rgb.shape[1]/patch) 111 | if overlapped>0 and j!=0: 112 | y_crop_start -= int((frame_rgb.shape[1]/(patch))*overlapped) 113 | 114 | y_crop_end = (j+1)*int(frame_rgb.shape[1]/patch) 115 | if overlapped>0: 116 | y_crop_end += int((frame_rgb.shape[1]/(patch))*overlapped) 117 | if j==patch-1: 118 | y_crop_end = frame_rgb.shape[1] 119 | 120 | cropped_frame = frame_rgb[x_crop_start:x_crop_end, y_crop_start:y_crop_end, :] 121 | patches.append(preprocess(Image.fromarray(cropped_frame).convert("RGB"))) 122 | images.append(np.stack(patches)) 123 | else: 124 | images.append(preprocess(Image.fromarray(frame_rgb).convert("RGB"))) 125 | 126 | cap.release() 127 | 128 | if len(images) > 0: 129 | video_data = th.tensor(np.stack(images)) 130 | else: 131 | video_data = th.zeros(1) 132 | return {'video': video_data},video_data.shape 133 | 134 | def get_video_data(self, video_path, start_time=None, end_time=None, patch=0, overlapped=0): 135 | image_input,shapes = self.video_to_tensor(video_path, self.transform, sample_fp=self.framerate, start_time=start_time, end_time=end_time, patch=patch, overlapped=overlapped) 136 | return image_input,shapes 137 | 138 | def process_raw_data(self, raw_video_data): 139 | tensor_size = raw_video_data.size() 140 | tensor = raw_video_data.view(-1, 1, tensor_size[-3], tensor_size[-2], tensor_size[-1]) 141 | return tensor 142 | 143 | def process_frame_order(self, raw_video_data, frame_order=0): 144 | # 0: ordinary order; 1: reverse order; 2: random order. 145 | if frame_order == 0: 146 | pass 147 | elif frame_order == 1: 148 | reverse_order = np.arange(raw_video_data.size(0) - 1, -1, -1) 149 | raw_video_data = raw_video_data[reverse_order, ...] 
150 | elif frame_order == 2: 151 | random_order = np.arange(raw_video_data.size(0)) 152 | np.random.shuffle(random_order) 153 | raw_video_data = raw_video_data[random_order, ...] 154 | 155 | return raw_video_data 156 | 157 | # An ordinary video frame extractor based CV2 158 | RawVideoExtractor = RawVideoExtractorCV2 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | -------------------------------------------------------------------------------- /feature_extractor/modules/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | def warmup_cosine(x, warmup=0.002): 27 | if x < warmup: 28 | return x/warmup 29 | return 0.5 * (1.0 + math.cos(math.pi * x)) 30 | 31 | def warmup_constant(x, warmup=0.002): 32 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. 33 | Learning rate is 1. afterwards. """ 34 | if x < warmup: 35 | return x/warmup 36 | return 1.0 37 | 38 | def warmup_linear(x, warmup=0.002): 39 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. 40 | After `t_total`-th training step, learning rate is zero. """ 41 | if x < warmup: 42 | return x/warmup 43 | return max((x-1.)/(warmup-1.), 0) 44 | 45 | SCHEDULES = { 46 | 'warmup_cosine': warmup_cosine, 47 | 'warmup_constant': warmup_constant, 48 | 'warmup_linear': warmup_linear, 49 | } 50 | 51 | 52 | class BertAdam(Optimizer): 53 | """Implements BERT version of Adam algorithm with weight decay fix. 54 | Params: 55 | lr: learning rate 56 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 57 | t_total: total number of training steps for the learning 58 | rate schedule, -1 means constant learning rate. Default: -1 59 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 60 | b1: Adams b1. Default: 0.9 61 | b2: Adams b2. Default: 0.999 62 | e: Adams epsilon. Default: 1e-6 63 | weight_decay: Weight decay. Default: 0.01 64 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 65 | """ 66 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 67 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 68 | max_grad_norm=1.0): 69 | if lr is not required and lr < 0.0: 70 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 71 | if schedule not in SCHEDULES: 72 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 73 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 74 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 75 | if not 0.0 <= b1 < 1.0: 76 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 77 | if not 0.0 <= b2 < 1.0: 78 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 79 | if not e >= 0.0: 80 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 81 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 82 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 83 | max_grad_norm=max_grad_norm) 84 | super(BertAdam, self).__init__(params, defaults) 85 | 86 | def get_lr(self): 87 | lr = [] 88 | for group in self.param_groups: 89 | for p in group['params']: 90 | if p.grad is None: 91 | continue 92 | state = self.state[p] 93 | if len(state) == 0: 94 | return [0] 95 | if group['t_total'] != -1: 96 | schedule_fct = SCHEDULES[group['schedule']] 97 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 98 | else: 99 | lr_scheduled = group['lr'] 100 | lr.append(lr_scheduled) 101 | return lr 102 | 103 | def step(self, closure=None): 104 | """Performs a single optimization step. 105 | Arguments: 106 | closure (callable, optional): A closure that reevaluates the model 107 | and returns the loss. 108 | """ 109 | loss = None 110 | if closure is not None: 111 | loss = closure() 112 | 113 | for group in self.param_groups: 114 | for p in group['params']: 115 | if p.grad is None: 116 | continue 117 | grad = p.grad.data 118 | if grad.is_sparse: 119 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 120 | 121 | state = self.state[p] 122 | 123 | # State initialization 124 | if len(state) == 0: 125 | state['step'] = 0 126 | # Exponential moving average of gradient values 127 | state['next_m'] = torch.zeros_like(p.data) 128 | # Exponential moving average of squared gradient values 129 | state['next_v'] = torch.zeros_like(p.data) 130 | 131 | next_m, next_v = state['next_m'], state['next_v'] 132 | beta1, beta2 = group['b1'], group['b2'] 133 | 134 | # Add grad clipping 135 | if group['max_grad_norm'] > 0: 136 | clip_grad_norm_(p, group['max_grad_norm']) 137 | 138 | # Decay the first and second moment running average coefficient 139 | # In-place operations to update the averages at the same time 140 | # next_m.mul_(beta1).add_(1 - beta1, grad) --> pytorch 1.7 141 | next_m.mul_(beta1).add_(grad, alpha=1 - beta1) 142 | # next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) --> pytorch 1.7 143 | next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 144 | update = next_m / (next_v.sqrt() + group['e']) 145 | 146 | # Just adding the square of the weights to the loss function is *not* 147 | # the correct way of using L2 regularization/weight decay with Adam, 148 | # since that will interact with the m and v parameters in strange ways. 149 | # 150 | # Instead we want to decay the weights in a manner that doesn't interact 151 | # with the m/v parameters. 
This is equivalent to adding the square 152 | # of the weights to the loss with plain (non-momentum) SGD. 153 | if group['weight_decay'] > 0.0: 154 | update += group['weight_decay'] * p.data 155 | 156 | if group['t_total'] != -1: 157 | schedule_fct = SCHEDULES[group['schedule']] 158 | progress = state['step']/group['t_total'] 159 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) 160 | else: 161 | lr_scheduled = group['lr'] 162 | 163 | update_with_lr = lr_scheduled * update 164 | p.data.add_(-update_with_lr) 165 | 166 | state['step'] += 1 167 | 168 | return loss -------------------------------------------------------------------------------- /modules/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model. """ 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | def warmup_cosine(x, warmup=0.002): 27 | if x < warmup: 28 | return x/warmup 29 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 30 | 31 | def warmup_constant(x, warmup=0.002): 32 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. 33 | Learning rate is 1. afterwards. """ 34 | if x < warmup: 35 | return x/warmup 36 | return 1.0 37 | 38 | def warmup_linear(x, warmup=0.002): 39 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. 40 | After `t_total`-th training step, learning rate is zero. """ 41 | if x < warmup: 42 | return x/warmup 43 | return max((x-1.)/(warmup-1.), 0) 44 | 45 | SCHEDULES = { 46 | 'warmup_cosine': warmup_cosine, 47 | 'warmup_constant': warmup_constant, 48 | 'warmup_linear': warmup_linear, 49 | } 50 | 51 | 52 | class BertAdam(Optimizer): 53 | """Implements BERT version of Adam algorithm with weight decay fix. 54 | Params: 55 | lr: learning rate 56 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 57 | t_total: total number of training steps for the learning 58 | rate schedule, -1 means constant learning rate. Default: -1 59 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 60 | b1: Adams b1. Default: 0.9 61 | b2: Adams b2. Default: 0.999 62 | e: Adams epsilon. Default: 1e-6 63 | weight_decay: Weight decay. Default: 0.01 64 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 65 | """ 66 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 67 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 68 | max_grad_norm=1.0): 69 | if lr is not required and lr < 0.0: 70 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 71 | if schedule not in SCHEDULES: 72 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 73 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 74 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 75 | if not 0.0 <= b1 < 1.0: 76 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 77 | if not 0.0 <= b2 < 1.0: 78 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 79 | if not e >= 0.0: 80 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 81 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 82 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 83 | max_grad_norm=max_grad_norm) 84 | super(BertAdam, self).__init__(params, defaults) 85 | 86 | def get_lr(self): 87 | lr = [] 88 | for group in self.param_groups: 89 | for p in group['params']: 90 | if p.grad is None: 91 | continue 92 | state = self.state[p] 93 | if len(state) == 0: 94 | return [0] 95 | if group['t_total'] != -1: 96 | schedule_fct = SCHEDULES[group['schedule']] 97 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 98 | else: 99 | lr_scheduled = group['lr'] 100 | lr.append(lr_scheduled) 101 | return lr 102 | 103 | def step(self, closure=None): 104 | """Performs a single optimization step. 105 | Arguments: 106 | closure (callable, optional): A closure that reevaluates the model 107 | and returns the loss. 108 | """ 109 | loss = None 110 | if closure is not None: 111 | loss = closure() 112 | 113 | for group in self.param_groups: 114 | for p in group['params']: 115 | if p.grad is None: 116 | continue 117 | grad = p.grad.data 118 | if grad.is_sparse: 119 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 120 | 121 | state = self.state[p] 122 | 123 | # State initialization 124 | if len(state) == 0: 125 | state['step'] = 0 126 | # Exponential moving average of gradient values 127 | state['next_m'] = torch.zeros_like(p.data) 128 | # Exponential moving average of squared gradient values 129 | state['next_v'] = torch.zeros_like(p.data) 130 | 131 | next_m, next_v = state['next_m'], state['next_v'] 132 | beta1, beta2 = group['b1'], group['b2'] 133 | 134 | # Add grad clipping 135 | if group['max_grad_norm'] > 0: 136 | clip_grad_norm_(p, group['max_grad_norm']) 137 | 138 | # Decay the first and second moment running average coefficient 139 | # In-place operations to update the averages at the same time 140 | # next_m.mul_(beta1).add_(1 - beta1, grad) --> pytorch 1.7 141 | next_m.mul_(beta1).add_(grad, alpha=1 - beta1) 142 | # next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) --> pytorch 1.7 143 | next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 144 | update = next_m / (next_v.sqrt() + group['e']) 145 | 146 | # Just adding the square of the weights to the loss function is *not* 147 | # the correct way of using L2 regularization/weight decay with Adam, 148 | # since that will interact with the m and v parameters in strange ways. 149 | # 150 | # Instead we want to decay the weights in a manner that doesn't interact 151 | # with the m/v parameters. 
This is equivalent to adding the square 152 | # of the weights to the loss with plain (non-momentum) SGD. 153 | if group['weight_decay'] > 0.0: 154 | update += group['weight_decay'] * p.data 155 | 156 | if group['t_total'] != -1: 157 | schedule_fct = SCHEDULES[group['schedule']] 158 | progress = state['step']/group['t_total'] 159 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) 160 | else: 161 | lr_scheduled = group['lr'] 162 | 163 | update_with_lr = lr_scheduled * update 164 | p.data.add_(-update_with_lr) 165 | 166 | state['step'] += 1 167 | 168 | return loss -------------------------------------------------------------------------------- /dataloaders/dataloader_msrvtt_feats.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import os 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | import pickle5 as pickle 10 | import pandas as pd 11 | from collections import defaultdict 12 | import json 13 | import random 14 | from tqdm import tqdm 15 | from scipy import sparse 16 | import glob 17 | 18 | class MSRVTT_Feats_DataLoader(Dataset): 19 | """Implementation of the dataloader for MSRVTT. Mainly used in the model training and evaluation. 20 | Params: 21 | json_path: Path to the MSRVTT_data.json file. 22 | features_path: Path to the extracted feature file. 23 | tokenizer: Tokenizer used for tokenizing the caption. 24 | max_words: Max word length retained. Any more than the value will be truncated. Default: 30 25 | feature_framerate: sampling rate in second. Default: 1.0 26 | max_frames: Max frame sampled. Any more than the value will be ignored. Default: 100 27 | split_type: Either "train", "val", or "test". 
Default: "" 28 | """ 29 | def __init__( 30 | self, 31 | json_path, 32 | features_path, 33 | tokenizer, 34 | max_words=30, 35 | feature_framerate=1.0, 36 | max_frames=100, 37 | split_type="", 38 | ): 39 | self.data = json.load(open(json_path, 'r')) 40 | self.feature_dict = pickle.load(open(features_path, 'rb')) 41 | self.feature_framerate = feature_framerate 42 | self.max_words = max_words 43 | self.max_frames = max_frames 44 | self.tokenizer = tokenizer 45 | 46 | self.feature_size = self.feature_dict[next(iter(self.feature_dict))].shape[-1] 47 | 48 | assert split_type in ["train", "val", "test"] 49 | # Train: video0 : video6512 (6513) 50 | # Val: video6513 : video7009 (497) 51 | # Test: video7010 : video9999 (2990) 52 | video_ids = [self.data['videos'][idx]['video_id'] for idx in range(len(self.data['videos']))] 53 | split_dict = {"train": video_ids[:6513], "val": video_ids[6513:6513 + 497], "test": video_ids[6513 + 497:]} 54 | choiced_video_ids = split_dict[split_type] 55 | 56 | self.sample_len = 0 57 | self.sentences_dict = {} 58 | self.video_sentences_dict = defaultdict(list) 59 | if split_type == "train": # expand all sentence to train 60 | for itm in self.data['sentences']: 61 | if itm['video_id'] in choiced_video_ids: 62 | self.sentences_dict[len(self.sentences_dict)] = (itm['video_id'], itm['caption']) 63 | self.video_sentences_dict[itm['video_id']].append(itm['caption']) 64 | elif split_type == "val" or split_type == "test": 65 | for itm in self.data['sentences']: 66 | if itm['video_id'] in choiced_video_ids: 67 | self.video_sentences_dict[itm['video_id']].append(itm['caption']) 68 | for vid in choiced_video_ids: 69 | self.sentences_dict[len(self.sentences_dict)] = (vid, self.video_sentences_dict[vid][0]) 70 | else: 71 | raise NotImplementedError 72 | 73 | self.sample_len = len(self.sentences_dict) 74 | 75 | 76 | 77 | def __len__(self): 78 | return self.sample_len 79 | 80 | def _get_text(self, video_id, caption=None): 81 | k = 1 82 | choice_video_ids = [video_id] 83 | pairs_text = np.zeros((k, self.max_words), dtype=np.long) 84 | 85 | pairs_input_caption_ids = np.zeros((k, self.max_words), dtype=np.long) 86 | pairs_output_caption_ids = np.zeros((k, self.max_words), dtype=np.long) 87 | pairs_decoder_mask = np.zeros((k, self.max_words), dtype=np.long) 88 | 89 | for i, video_id in enumerate(choice_video_ids): 90 | words = [] 91 | words = ["[CLS]"] + words 92 | total_length_with_CLS = self.max_words - 1 93 | if len(words) > total_length_with_CLS: 94 | words = words[:total_length_with_CLS] 95 | words = words + ["[SEP]"] 96 | 97 | 98 | input_ids = self.tokenizer.convert_tokens_to_ids(words) 99 | while len(input_ids) < self.max_words: 100 | input_ids.append(0) 101 | assert len(input_ids) == self.max_words 102 | 103 | pairs_text[i] = np.array(input_ids) 104 | 105 | # For generate captions 106 | if caption is not None: 107 | caption_words = self.tokenizer.tokenize(caption) 108 | else: 109 | caption_words = self._get_single_text(video_id) 110 | if len(caption_words) > total_length_with_CLS: 111 | caption_words = caption_words[:total_length_with_CLS] 112 | input_caption_words = ["[CLS]"] + caption_words 113 | output_caption_words = caption_words + ["[SEP]"] 114 | 115 | # For generate captions 116 | input_caption_ids = self.tokenizer.convert_tokens_to_ids(input_caption_words) 117 | output_caption_ids = self.tokenizer.convert_tokens_to_ids(output_caption_words) 118 | decoder_mask = [1] * len(input_caption_ids) 119 | while len(input_caption_ids) < self.max_words: 120 | 
input_caption_ids.append(0) 121 | output_caption_ids.append(0) 122 | decoder_mask.append(0) 123 | assert len(input_caption_ids) == self.max_words 124 | assert len(output_caption_ids) == self.max_words 125 | assert len(decoder_mask) == self.max_words 126 | 127 | pairs_input_caption_ids[i] = np.array(input_caption_ids) 128 | pairs_output_caption_ids[i] = np.array(output_caption_ids) 129 | pairs_decoder_mask[i] = np.array(decoder_mask) 130 | 131 | return pairs_text, np.array([]), np.array([]), np.array([]), np.array([]), \ 132 | pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids, choice_video_ids 133 | 134 | def _get_single_text(self, video_id): 135 | rind = random.randint(0, len(self.sentences[video_id]) - 1) 136 | caption = self.sentences[video_id][rind] 137 | words = self.tokenizer.tokenize(caption) 138 | return words 139 | 140 | def _get_video(self, choice_video_ids): 141 | video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.long) 142 | 143 | max_video_length = [0] * len(choice_video_ids) 144 | 145 | video = np.zeros((len(choice_video_ids), self.max_frames, self.feature_size), dtype=np.float) 146 | for i, video_id in enumerate(choice_video_ids): 147 | video_slice = self.feature_dict[video_id] 148 | 149 | if self.max_frames < video_slice.shape[0]: 150 | video_slice = video_slice[:self.max_frames] 151 | 152 | slice_shape = video_slice.shape 153 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_shape[0] else slice_shape[0] 154 | if len(video_slice) < 1: 155 | print("video_id: {}".format(video_id)) 156 | else: 157 | video[i][:slice_shape[0]] = video_slice 158 | 159 | return video, video_mask, np.array([]), np.array([]) 160 | 161 | def __getitem__(self, idx): 162 | video_id, caption = self.sentences_dict[idx] 163 | pairs_text, pairs_mask, pairs_segment, \ 164 | pairs_masked_text, pairs_token_labels, \ 165 | pairs_input_caption_ids, pairs_decoder_mask, \ 166 | pairs_output_caption_ids, choice_video_ids = self._get_text(video_id, caption) 167 | 168 | video, video_mask, masked_video, video_labels_index = self._get_video(choice_video_ids) 169 | 170 | 171 | pairs_mask, pairs_segment, pairs_masked_text, pairs_token_labels, masked_video, video_labels_index = np.array([]),np.array([]),np.array([]),np.array([]),np.array([]),np.array([]) 172 | 173 | 174 | return pairs_text, pairs_mask, pairs_segment, video, video_mask, \ 175 | pairs_masked_text, pairs_token_labels, masked_video, video_labels_index, \ 176 | pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids -------------------------------------------------------------------------------- /dataloaders/dataloader_msvd_feats.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import os 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | import pickle5 as pickle 10 | import pandas as pd 11 | from collections import defaultdict 12 | import json 13 | import random 14 | from tqdm import tqdm 15 | import glob 16 | from scipy import sparse 17 | 18 | 19 | class MSVD_Feats_DataLoader(Dataset): 20 | """Implementation of the dataloader for MSRVTT. Mainly used in the model training and evaluation. 21 | Params: 22 | data_path: Path to the MSVD folder. 23 | features_path: Path to the extracted feature file. 24 | tokenizer: Tokenizer used for tokenizing the caption. 
25 | max_words: Max word length retained. Any more than the value will be truncated. Default: 30 26 | feature_framerate: sampling rate in second. Default: 1.0 27 | max_frames: Max frame sampled. Any more than the value will be ignored. Default: 100 28 | split_type: Either "train", "val", or "test". Default: "" 29 | """ 30 | 31 | def __init__( 32 | self, 33 | data_path, 34 | features_path, 35 | tokenizer, 36 | max_words=30, 37 | feature_framerate=1.0, 38 | max_frames=100, 39 | split_type="" 40 | ): 41 | self.data_path = data_path 42 | self.features_path = features_path 43 | self.feature_dict = pickle.load(open(features_path, 'rb')) 44 | self.feature_framerate = feature_framerate 45 | self.max_words = max_words 46 | self.max_frames = max_frames 47 | self.tokenizer = tokenizer 48 | 49 | assert split_type in ["train", "val", "test"] 50 | 51 | split_dict = {} 52 | # video_ids = [self.data['videos'][idx]['video_id'] for idx in range(len(self.data['videos']))] 53 | split_dict["train"] = os.path.join(self.data_path, "train_list_mapping.txt") 54 | split_dict["val"] = os.path.join(self.data_path, "val_list_mapping.txt") 55 | split_dict["test"] = os.path.join(self.data_path, "test_list_mapping.txt") 56 | caption_file = os.path.join(self.data_path, "raw-captions_mapped.pkl") 57 | self.feature_size = self.feature_dict['vid1'].shape[-1] 58 | with open(caption_file, 'rb') as f: 59 | captions = pickle.load(f) 60 | 61 | with open(split_dict[split_type], 'r') as fp: 62 | choiced_video_ids = [itm.strip() for itm in fp.readlines()] 63 | # choiced_video_ids = split_dict[split_type] 64 | 65 | self.sample_len = 0 66 | self.sentences_dict = {} 67 | self.video_sentences_dict = defaultdict(list) 68 | if split_type == "train": # expand all sentence to train 69 | for video_id in captions: 70 | if video_id in choiced_video_ids: 71 | for cap in captions[video_id]: 72 | cap_txt = " ".join(cap) 73 | self.sentences_dict[len(self.sentences_dict)] = (video_id, cap_txt) 74 | self.video_sentences_dict[video_id].append(cap_txt) 75 | elif split_type == "val" or split_type == "test": 76 | for itm in captions: 77 | if itm in choiced_video_ids: 78 | for cap in captions[itm]: 79 | cap_txt = " ".join(cap) 80 | self.video_sentences_dict[itm].append(cap_txt) 81 | for vid in choiced_video_ids: 82 | self.sentences_dict[len(self.sentences_dict)] = (vid, self.video_sentences_dict[vid][0]) 83 | else: 84 | raise NotImplementedError 85 | 86 | self.sample_len = len(self.sentences_dict) 87 | 88 | 89 | def __len__(self): 90 | return self.sample_len 91 | 92 | def _get_text(self, video_id, caption=None): 93 | k = 1 94 | choice_video_ids = [video_id] 95 | pairs_text = np.zeros((k, self.max_words), dtype=np.long) 96 | 97 | pairs_input_caption_ids = np.zeros((k, self.max_words), dtype=np.long) 98 | pairs_output_caption_ids = np.zeros((k, self.max_words), dtype=np.long) 99 | pairs_decoder_mask = np.zeros((k, self.max_words), dtype=np.long) 100 | 101 | for i, video_id in enumerate(choice_video_ids): 102 | words = [] 103 | words = ["[CLS]"] + words 104 | total_length_with_CLS = self.max_words - 1 105 | if len(words) > total_length_with_CLS: 106 | words = words[:total_length_with_CLS] 107 | words = words + ["[SEP]"] 108 | 109 | input_ids = self.tokenizer.convert_tokens_to_ids(words) 110 | while len(input_ids) < self.max_words: 111 | input_ids.append(0) 112 | assert len(input_ids) == self.max_words 113 | 114 | pairs_text[i] = np.array(input_ids) 115 | 116 | # For generate captions 117 | if caption is not None: 118 | caption_words = 
self.tokenizer.tokenize(caption) 119 | else: 120 | caption_words = self._get_single_text(video_id) 121 | if len(caption_words) > total_length_with_CLS: 122 | caption_words = caption_words[:total_length_with_CLS] 123 | input_caption_words = ["[CLS]"] + caption_words 124 | output_caption_words = caption_words + ["[SEP]"] 125 | 126 | # For generate captions 127 | input_caption_ids = self.tokenizer.convert_tokens_to_ids(input_caption_words) 128 | output_caption_ids = self.tokenizer.convert_tokens_to_ids(output_caption_words) 129 | decoder_mask = [1] * len(input_caption_ids) 130 | while len(input_caption_ids) < self.max_words: 131 | input_caption_ids.append(0) 132 | output_caption_ids.append(0) 133 | decoder_mask.append(0) 134 | assert len(input_caption_ids) == self.max_words 135 | assert len(output_caption_ids) == self.max_words 136 | assert len(decoder_mask) == self.max_words 137 | 138 | pairs_input_caption_ids[i] = np.array(input_caption_ids) 139 | pairs_output_caption_ids[i] = np.array(output_caption_ids) 140 | pairs_decoder_mask[i] = np.array(decoder_mask) 141 | 142 | return pairs_text, np.array([]), np.array([]), np.array([]), np.array([]), \ 143 | pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids, choice_video_ids 144 | 145 | def _get_single_text(self, video_id): 146 | rind = random.randint(0, len(self.sentences[video_id]) - 1) 147 | caption = self.sentences[video_id][rind] 148 | words = self.tokenizer.tokenize(caption) 149 | return words 150 | 151 | def _get_video(self, choice_video_ids): 152 | # print(choice_video_ids) 153 | video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.long) 154 | max_video_length = [0] * len(choice_video_ids) 155 | 156 | video = np.zeros((len(choice_video_ids), self.max_frames, self.feature_size), dtype=np.float) 157 | for i, video_id in enumerate(choice_video_ids): 158 | video_slice = self.feature_dict[video_id] 159 | 160 | if self.max_frames < video_slice.shape[0]: 161 | video_slice = video_slice[:self.max_frames] 162 | 163 | slice_shape = video_slice.shape 164 | max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_shape[0] else slice_shape[0] 165 | if len(video_slice) < 1: 166 | print("video_id: {}".format(video_id)) 167 | else: 168 | video[i][:slice_shape[0]] = video_slice 169 | 170 | return video, video_mask, np.array([]), np.array([]) 171 | 172 | def __getitem__(self, idx): 173 | video_id, caption = self.sentences_dict[idx] 174 | pairs_text, pairs_mask, pairs_segment, \ 175 | pairs_masked_text, pairs_token_labels, \ 176 | pairs_input_caption_ids, pairs_decoder_mask, \ 177 | pairs_output_caption_ids, choice_video_ids = self._get_text(video_id, caption) 178 | 179 | video, video_mask, masked_video, video_labels_index = self._get_video(choice_video_ids) 180 | 181 | pairs_mask, pairs_segment, pairs_masked_text, pairs_token_labels, masked_video, video_labels_index = np.array( 182 | []), np.array([]), np.array([]), np.array([]), np.array([]), np.array([]) 183 | 184 | return pairs_text, pairs_mask, pairs_segment, video, video_mask, \ 185 | pairs_masked_text, pairs_token_labels, masked_video, video_labels_index, \ 186 | pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids -------------------------------------------------------------------------------- /dataset/MSVD/test_list_mapping.txt: -------------------------------------------------------------------------------- 1 | vid1301 2 | vid1302 3 | vid1303 4 | vid1304 5 | vid1305 6 | vid1306 7 | vid1307 8 | vid1308 9 | vid1309 10 | vid1310 
11 | vid1311 12 | vid1312 13 | vid1313 14 | vid1314 15 | vid1315 16 | vid1316 17 | vid1317 18 | vid1318 19 | vid1319 20 | vid1320 21 | vid1321 22 | vid1322 23 | vid1323 24 | vid1324 25 | vid1325 26 | vid1326 27 | vid1327 28 | vid1328 29 | vid1329 30 | vid1330 31 | vid1331 32 | vid1332 33 | vid1333 34 | vid1334 35 | vid1335 36 | vid1336 37 | vid1337 38 | vid1338 39 | vid1339 40 | vid1340 41 | vid1341 42 | vid1342 43 | vid1343 44 | vid1344 45 | vid1345 46 | vid1346 47 | vid1347 48 | vid1348 49 | vid1349 50 | vid1350 51 | vid1351 52 | vid1352 53 | vid1353 54 | vid1354 55 | vid1355 56 | vid1356 57 | vid1357 58 | vid1358 59 | vid1359 60 | vid1360 61 | vid1361 62 | vid1362 63 | vid1363 64 | vid1364 65 | vid1365 66 | vid1366 67 | vid1367 68 | vid1368 69 | vid1369 70 | vid1370 71 | vid1371 72 | vid1372 73 | vid1373 74 | vid1374 75 | vid1375 76 | vid1376 77 | vid1377 78 | vid1378 79 | vid1379 80 | vid1380 81 | vid1381 82 | vid1382 83 | vid1383 84 | vid1384 85 | vid1385 86 | vid1386 87 | vid1387 88 | vid1388 89 | vid1389 90 | vid1390 91 | vid1391 92 | vid1392 93 | vid1393 94 | vid1394 95 | vid1395 96 | vid1396 97 | vid1397 98 | vid1398 99 | vid1399 100 | vid1400 101 | vid1401 102 | vid1402 103 | vid1403 104 | vid1404 105 | vid1405 106 | vid1406 107 | vid1407 108 | vid1408 109 | vid1409 110 | vid1410 111 | vid1411 112 | vid1412 113 | vid1413 114 | vid1414 115 | vid1415 116 | vid1416 117 | vid1417 118 | vid1418 119 | vid1419 120 | vid1420 121 | vid1421 122 | vid1422 123 | vid1423 124 | vid1424 125 | vid1425 126 | vid1426 127 | vid1427 128 | vid1428 129 | vid1429 130 | vid1430 131 | vid1431 132 | vid1432 133 | vid1433 134 | vid1434 135 | vid1435 136 | vid1436 137 | vid1437 138 | vid1438 139 | vid1439 140 | vid1440 141 | vid1441 142 | vid1442 143 | vid1443 144 | vid1444 145 | vid1445 146 | vid1446 147 | vid1447 148 | vid1448 149 | vid1449 150 | vid1450 151 | vid1451 152 | vid1452 153 | vid1453 154 | vid1454 155 | vid1455 156 | vid1456 157 | vid1457 158 | vid1458 159 | vid1459 160 | vid1460 161 | vid1461 162 | vid1462 163 | vid1463 164 | vid1464 165 | vid1465 166 | vid1466 167 | vid1467 168 | vid1468 169 | vid1469 170 | vid1470 171 | vid1471 172 | vid1472 173 | vid1473 174 | vid1474 175 | vid1475 176 | vid1476 177 | vid1477 178 | vid1478 179 | vid1479 180 | vid1480 181 | vid1481 182 | vid1482 183 | vid1483 184 | vid1484 185 | vid1485 186 | vid1486 187 | vid1487 188 | vid1488 189 | vid1489 190 | vid1490 191 | vid1491 192 | vid1492 193 | vid1493 194 | vid1494 195 | vid1495 196 | vid1496 197 | vid1497 198 | vid1498 199 | vid1499 200 | vid1500 201 | vid1501 202 | vid1502 203 | vid1503 204 | vid1504 205 | vid1505 206 | vid1506 207 | vid1507 208 | vid1508 209 | vid1509 210 | vid1510 211 | vid1511 212 | vid1512 213 | vid1513 214 | vid1514 215 | vid1515 216 | vid1516 217 | vid1517 218 | vid1518 219 | vid1519 220 | vid1520 221 | vid1521 222 | vid1522 223 | vid1523 224 | vid1524 225 | vid1525 226 | vid1526 227 | vid1527 228 | vid1528 229 | vid1529 230 | vid1530 231 | vid1531 232 | vid1532 233 | vid1533 234 | vid1534 235 | vid1535 236 | vid1536 237 | vid1537 238 | vid1538 239 | vid1539 240 | vid1540 241 | vid1541 242 | vid1542 243 | vid1543 244 | vid1544 245 | vid1545 246 | vid1546 247 | vid1547 248 | vid1548 249 | vid1549 250 | vid1550 251 | vid1551 252 | vid1552 253 | vid1553 254 | vid1554 255 | vid1555 256 | vid1556 257 | vid1557 258 | vid1558 259 | vid1559 260 | vid1560 261 | vid1561 262 | vid1562 263 | vid1563 264 | vid1564 265 | vid1565 266 | vid1566 267 | vid1567 268 | vid1568 269 | vid1569 270 | vid1570 271 
| vid1571 272 | vid1572 273 | vid1573 274 | vid1574 275 | vid1575 276 | vid1576 277 | vid1577 278 | vid1578 279 | vid1579 280 | vid1580 281 | vid1581 282 | vid1582 283 | vid1583 284 | vid1584 285 | vid1585 286 | vid1586 287 | vid1587 288 | vid1588 289 | vid1589 290 | vid1590 291 | vid1591 292 | vid1592 293 | vid1593 294 | vid1594 295 | vid1595 296 | vid1596 297 | vid1597 298 | vid1598 299 | vid1599 300 | vid1600 301 | vid1601 302 | vid1602 303 | vid1603 304 | vid1604 305 | vid1605 306 | vid1606 307 | vid1607 308 | vid1608 309 | vid1609 310 | vid1610 311 | vid1611 312 | vid1612 313 | vid1613 314 | vid1614 315 | vid1615 316 | vid1616 317 | vid1617 318 | vid1618 319 | vid1619 320 | vid1620 321 | vid1621 322 | vid1622 323 | vid1623 324 | vid1624 325 | vid1625 326 | vid1626 327 | vid1627 328 | vid1628 329 | vid1629 330 | vid1630 331 | vid1631 332 | vid1632 333 | vid1633 334 | vid1634 335 | vid1635 336 | vid1636 337 | vid1637 338 | vid1638 339 | vid1639 340 | vid1640 341 | vid1641 342 | vid1642 343 | vid1643 344 | vid1644 345 | vid1645 346 | vid1646 347 | vid1647 348 | vid1648 349 | vid1649 350 | vid1650 351 | vid1651 352 | vid1652 353 | vid1653 354 | vid1654 355 | vid1655 356 | vid1656 357 | vid1657 358 | vid1658 359 | vid1659 360 | vid1660 361 | vid1661 362 | vid1662 363 | vid1663 364 | vid1664 365 | vid1665 366 | vid1666 367 | vid1667 368 | vid1668 369 | vid1669 370 | vid1670 371 | vid1671 372 | vid1672 373 | vid1673 374 | vid1674 375 | vid1675 376 | vid1676 377 | vid1677 378 | vid1678 379 | vid1679 380 | vid1680 381 | vid1681 382 | vid1682 383 | vid1683 384 | vid1684 385 | vid1685 386 | vid1686 387 | vid1687 388 | vid1688 389 | vid1689 390 | vid1690 391 | vid1691 392 | vid1692 393 | vid1693 394 | vid1694 395 | vid1695 396 | vid1696 397 | vid1697 398 | vid1698 399 | vid1699 400 | vid1700 401 | vid1701 402 | vid1702 403 | vid1703 404 | vid1704 405 | vid1705 406 | vid1706 407 | vid1707 408 | vid1708 409 | vid1709 410 | vid1710 411 | vid1711 412 | vid1712 413 | vid1713 414 | vid1714 415 | vid1715 416 | vid1716 417 | vid1717 418 | vid1718 419 | vid1719 420 | vid1720 421 | vid1721 422 | vid1722 423 | vid1723 424 | vid1724 425 | vid1725 426 | vid1726 427 | vid1727 428 | vid1728 429 | vid1729 430 | vid1730 431 | vid1731 432 | vid1732 433 | vid1733 434 | vid1734 435 | vid1735 436 | vid1736 437 | vid1737 438 | vid1738 439 | vid1739 440 | vid1740 441 | vid1741 442 | vid1742 443 | vid1743 444 | vid1744 445 | vid1745 446 | vid1746 447 | vid1747 448 | vid1748 449 | vid1749 450 | vid1750 451 | vid1751 452 | vid1752 453 | vid1753 454 | vid1754 455 | vid1755 456 | vid1756 457 | vid1757 458 | vid1758 459 | vid1759 460 | vid1760 461 | vid1761 462 | vid1762 463 | vid1763 464 | vid1764 465 | vid1765 466 | vid1766 467 | vid1767 468 | vid1768 469 | vid1769 470 | vid1770 471 | vid1771 472 | vid1772 473 | vid1773 474 | vid1774 475 | vid1775 476 | vid1776 477 | vid1777 478 | vid1778 479 | vid1779 480 | vid1780 481 | vid1781 482 | vid1782 483 | vid1783 484 | vid1784 485 | vid1785 486 | vid1786 487 | vid1787 488 | vid1788 489 | vid1789 490 | vid1790 491 | vid1791 492 | vid1792 493 | vid1793 494 | vid1794 495 | vid1795 496 | vid1796 497 | vid1797 498 | vid1798 499 | vid1799 500 | vid1800 501 | vid1801 502 | vid1802 503 | vid1803 504 | vid1804 505 | vid1805 506 | vid1806 507 | vid1807 508 | vid1808 509 | vid1809 510 | vid1810 511 | vid1811 512 | vid1812 513 | vid1813 514 | vid1814 515 | vid1815 516 | vid1816 517 | vid1817 518 | vid1818 519 | vid1819 520 | vid1820 521 | vid1821 522 | vid1822 523 | vid1823 524 | vid1824 
525 | vid1825 526 | vid1826 527 | vid1827 528 | vid1828 529 | vid1829 530 | vid1830 531 | vid1831 532 | vid1832 533 | vid1833 534 | vid1834 535 | vid1835 536 | vid1836 537 | vid1837 538 | vid1838 539 | vid1839 540 | vid1840 541 | vid1841 542 | vid1842 543 | vid1843 544 | vid1844 545 | vid1845 546 | vid1846 547 | vid1847 548 | vid1848 549 | vid1849 550 | vid1850 551 | vid1851 552 | vid1852 553 | vid1853 554 | vid1854 555 | vid1855 556 | vid1856 557 | vid1857 558 | vid1858 559 | vid1859 560 | vid1860 561 | vid1861 562 | vid1862 563 | vid1863 564 | vid1864 565 | vid1865 566 | vid1866 567 | vid1867 568 | vid1868 569 | vid1869 570 | vid1870 571 | vid1871 572 | vid1872 573 | vid1873 574 | vid1874 575 | vid1875 576 | vid1876 577 | vid1877 578 | vid1878 579 | vid1879 580 | vid1880 581 | vid1881 582 | vid1882 583 | vid1883 584 | vid1884 585 | vid1885 586 | vid1886 587 | vid1887 588 | vid1888 589 | vid1889 590 | vid1890 591 | vid1891 592 | vid1892 593 | vid1893 594 | vid1894 595 | vid1895 596 | vid1896 597 | vid1897 598 | vid1898 599 | vid1899 600 | vid1900 601 | vid1901 602 | vid1902 603 | vid1903 604 | vid1904 605 | vid1905 606 | vid1906 607 | vid1907 608 | vid1908 609 | vid1909 610 | vid1910 611 | vid1911 612 | vid1912 613 | vid1913 614 | vid1914 615 | vid1915 616 | vid1916 617 | vid1917 618 | vid1918 619 | vid1919 620 | vid1920 621 | vid1921 622 | vid1922 623 | vid1923 624 | vid1924 625 | vid1925 626 | vid1926 627 | vid1927 628 | vid1928 629 | vid1929 630 | vid1930 631 | vid1931 632 | vid1932 633 | vid1933 634 | vid1934 635 | vid1935 636 | vid1936 637 | vid1937 638 | vid1938 639 | vid1939 640 | vid1940 641 | vid1941 642 | vid1942 643 | vid1943 644 | vid1944 645 | vid1945 646 | vid1946 647 | vid1947 648 | vid1948 649 | vid1949 650 | vid1950 651 | vid1951 652 | vid1952 653 | vid1953 654 | vid1954 655 | vid1955 656 | vid1956 657 | vid1957 658 | vid1958 659 | vid1959 660 | vid1960 661 | vid1961 662 | vid1962 663 | vid1963 664 | vid1964 665 | vid1965 666 | vid1966 667 | vid1967 668 | vid1968 669 | vid1969 670 | vid1970 -------------------------------------------------------------------------------- /feature_extractor/modules/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | 7 | import os 8 | import logging 9 | import shutil 10 | import tempfile 11 | import json 12 | from urllib.parse import urlparse 13 | from pathlib import Path 14 | from typing import Optional, Tuple, Union, IO, Callable, Set 15 | from hashlib import sha256 16 | from functools import wraps 17 | 18 | from tqdm import tqdm 19 | 20 | import boto3 21 | from botocore.exceptions import ClientError 22 | import requests 23 | 24 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 25 | 26 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 27 | Path.home() / '.pytorch_pretrained_bert')) 28 | 29 | 30 | def url_to_filename(url: str, etag: str = None) -> str: 31 | """ 32 | Convert `url` into a hashed filename in a repeatable way. 33 | If `etag` is specified, append its hash to the url's, delimited 34 | by a period. 35 | """ 36 | url_bytes = url.encode('utf-8') 37 | url_hash = sha256(url_bytes) 38 | filename = url_hash.hexdigest() 39 | 40 | if etag: 41 | etag_bytes = etag.encode('utf-8') 42 | etag_hash = sha256(etag_bytes) 43 | filename += '.' 
+ etag_hash.hexdigest() 44 | 45 | return filename 46 | 47 | 48 | def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]: 49 | """ 50 | Return the url and etag (which may be ``None``) stored for `filename`. 51 | Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist. 52 | """ 53 | if cache_dir is None: 54 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 55 | if isinstance(cache_dir, Path): 56 | cache_dir = str(cache_dir) 57 | 58 | cache_path = os.path.join(cache_dir, filename) 59 | if not os.path.exists(cache_path): 60 | raise FileNotFoundError("file {} not found".format(cache_path)) 61 | 62 | meta_path = cache_path + '.json' 63 | if not os.path.exists(meta_path): 64 | raise FileNotFoundError("file {} not found".format(meta_path)) 65 | 66 | with open(meta_path) as meta_file: 67 | metadata = json.load(meta_file) 68 | url = metadata['url'] 69 | etag = metadata['etag'] 70 | 71 | return url, etag 72 | 73 | 74 | def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str: 75 | """ 76 | Given something that might be a URL (or might be a local path), 77 | determine which. If it's a URL, download the file and cache it, and 78 | return the path to the cached file. If it's already a local path, 79 | make sure the file exists and then return the path. 80 | """ 81 | if cache_dir is None: 82 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 83 | if isinstance(url_or_filename, Path): 84 | url_or_filename = str(url_or_filename) 85 | if isinstance(cache_dir, Path): 86 | cache_dir = str(cache_dir) 87 | 88 | parsed = urlparse(url_or_filename) 89 | 90 | if parsed.scheme in ('http', 'https', 's3'): 91 | # URL, so get it from the cache (downloading if necessary) 92 | return get_from_cache(url_or_filename, cache_dir) 93 | elif os.path.exists(url_or_filename): 94 | # File, and it exists. 95 | return url_or_filename 96 | elif parsed.scheme == '': 97 | # File, but it doesn't exist. 98 | raise FileNotFoundError("file {} not found".format(url_or_filename)) 99 | else: 100 | # Something unknown 101 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 102 | 103 | 104 | def split_s3_path(url: str) -> Tuple[str, str]: 105 | """Split a full s3 path into the bucket name and path.""" 106 | parsed = urlparse(url) 107 | if not parsed.netloc or not parsed.path: 108 | raise ValueError("bad s3 path {}".format(url)) 109 | bucket_name = parsed.netloc 110 | s3_path = parsed.path 111 | # Remove '/' at beginning of path. 112 | if s3_path.startswith("/"): 113 | s3_path = s3_path[1:] 114 | return bucket_name, s3_path 115 | 116 | 117 | def s3_request(func: Callable): 118 | """ 119 | Wrapper function for s3 requests in order to create more helpful error 120 | messages. 
121 | """ 122 | 123 | @wraps(func) 124 | def wrapper(url: str, *args, **kwargs): 125 | try: 126 | return func(url, *args, **kwargs) 127 | except ClientError as exc: 128 | if int(exc.response["Error"]["Code"]) == 404: 129 | raise FileNotFoundError("file {} not found".format(url)) 130 | else: 131 | raise 132 | 133 | return wrapper 134 | 135 | 136 | @s3_request 137 | def s3_etag(url: str) -> Optional[str]: 138 | """Check ETag on S3 object.""" 139 | s3_resource = boto3.resource("s3") 140 | bucket_name, s3_path = split_s3_path(url) 141 | s3_object = s3_resource.Object(bucket_name, s3_path) 142 | return s3_object.e_tag 143 | 144 | 145 | @s3_request 146 | def s3_get(url: str, temp_file: IO) -> None: 147 | """Pull a file directly from S3.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 151 | 152 | 153 | def http_get(url: str, temp_file: IO) -> None: 154 | req = requests.get(url, stream=True) 155 | content_length = req.headers.get('Content-Length') 156 | total = int(content_length) if content_length is not None else None 157 | progress = tqdm(unit="B", total=total) 158 | for chunk in req.iter_content(chunk_size=1024): 159 | if chunk: # filter out keep-alive new chunks 160 | progress.update(len(chunk)) 161 | temp_file.write(chunk) 162 | progress.close() 163 | 164 | 165 | def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str: 166 | """ 167 | Given a URL, look for the corresponding dataset in the local cache. 168 | If it's not there, download it. Then return the path to the cached file. 169 | """ 170 | if cache_dir is None: 171 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 172 | if isinstance(cache_dir, Path): 173 | cache_dir = str(cache_dir) 174 | 175 | os.makedirs(cache_dir, exist_ok=True) 176 | 177 | # Get eTag to add to filename, if it exists. 178 | if url.startswith("s3://"): 179 | etag = s3_etag(url) 180 | else: 181 | response = requests.head(url, allow_redirects=True) 182 | if response.status_code != 200: 183 | raise IOError("HEAD request failed for url {} with status code {}" 184 | .format(url, response.status_code)) 185 | etag = response.headers.get("ETag") 186 | 187 | filename = url_to_filename(url, etag) 188 | 189 | # get cache path to put the file 190 | cache_path = os.path.join(cache_dir, filename) 191 | 192 | if not os.path.exists(cache_path): 193 | # Download to temporary file, then copy to cache dir once finished. 194 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
195 | with tempfile.NamedTemporaryFile() as temp_file: 196 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 197 | 198 | # GET file object 199 | if url.startswith("s3://"): 200 | s3_get(url, temp_file) 201 | else: 202 | http_get(url, temp_file) 203 | 204 | # we are copying the file before closing it, so flush to avoid truncation 205 | temp_file.flush() 206 | # shutil.copyfileobj() starts at the current position, so go to the start 207 | temp_file.seek(0) 208 | 209 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 210 | with open(cache_path, 'wb') as cache_file: 211 | shutil.copyfileobj(temp_file, cache_file) 212 | 213 | logger.info("creating metadata file for %s", cache_path) 214 | meta = {'url': url, 'etag': etag} 215 | meta_path = cache_path + '.json' 216 | with open(meta_path, 'w') as meta_file: 217 | json.dump(meta, meta_file) 218 | 219 | logger.info("removing temp file %s", temp_file.name) 220 | 221 | return cache_path 222 | 223 | 224 | def read_set_from_file(filename: str) -> Set[str]: 225 | ''' 226 | Extract a de-duped collection (set) of text from a file. 227 | Expected file format is one item per line. 228 | ''' 229 | collection = set() 230 | with open(filename, 'r', encoding='utf-8') as file_: 231 | for line in file_: 232 | collection.add(line.rstrip()) 233 | return collection 234 | 235 | 236 | def get_file_extension(path: str, dot=True, lower: bool = True): 237 | ext = os.path.splitext(path)[1] 238 | ext = ext if dot else ext[1:] 239 | return ext.lower() if lower else ext 240 | -------------------------------------------------------------------------------- /modules/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | 7 | import os 8 | import logging 9 | import shutil 10 | import tempfile 11 | import json 12 | from urllib.parse import urlparse 13 | from pathlib import Path 14 | from typing import Optional, Tuple, Union, IO, Callable, Set 15 | from hashlib import sha256 16 | from functools import wraps 17 | 18 | from tqdm import tqdm 19 | 20 | import boto3 21 | from botocore.exceptions import ClientError 22 | import requests 23 | 24 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 25 | 26 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 27 | Path.home() / '.pytorch_pretrained_bert')) 28 | 29 | 30 | def url_to_filename(url: str, etag: str = None) -> str: 31 | """ 32 | Convert `url` into a hashed filename in a repeatable way. 33 | If `etag` is specified, append its hash to the url's, delimited 34 | by a period. 35 | """ 36 | url_bytes = url.encode('utf-8') 37 | url_hash = sha256(url_bytes) 38 | filename = url_hash.hexdigest() 39 | 40 | if etag: 41 | etag_bytes = etag.encode('utf-8') 42 | etag_hash = sha256(etag_bytes) 43 | filename += '.' + etag_hash.hexdigest() 44 | 45 | return filename 46 | 47 | 48 | def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]: 49 | """ 50 | Return the url and etag (which may be ``None``) stored for `filename`. 51 | Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist. 
52 | """ 53 | if cache_dir is None: 54 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 55 | if isinstance(cache_dir, Path): 56 | cache_dir = str(cache_dir) 57 | 58 | cache_path = os.path.join(cache_dir, filename) 59 | if not os.path.exists(cache_path): 60 | raise FileNotFoundError("file {} not found".format(cache_path)) 61 | 62 | meta_path = cache_path + '.json' 63 | if not os.path.exists(meta_path): 64 | raise FileNotFoundError("file {} not found".format(meta_path)) 65 | 66 | with open(meta_path) as meta_file: 67 | metadata = json.load(meta_file) 68 | url = metadata['url'] 69 | etag = metadata['etag'] 70 | 71 | return url, etag 72 | 73 | 74 | def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str: 75 | """ 76 | Given something that might be a URL (or might be a local path), 77 | determine which. If it's a URL, download the file and cache it, and 78 | return the path to the cached file. If it's already a local path, 79 | make sure the file exists and then return the path. 80 | """ 81 | if cache_dir is None: 82 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 83 | if isinstance(url_or_filename, Path): 84 | url_or_filename = str(url_or_filename) 85 | if isinstance(cache_dir, Path): 86 | cache_dir = str(cache_dir) 87 | 88 | parsed = urlparse(url_or_filename) 89 | 90 | if parsed.scheme in ('http', 'https', 's3'): 91 | # URL, so get it from the cache (downloading if necessary) 92 | return get_from_cache(url_or_filename, cache_dir) 93 | elif os.path.exists(url_or_filename): 94 | # File, and it exists. 95 | return url_or_filename 96 | elif parsed.scheme == '': 97 | # File, but it doesn't exist. 98 | raise FileNotFoundError("file {} not found".format(url_or_filename)) 99 | else: 100 | # Something unknown 101 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 102 | 103 | 104 | def split_s3_path(url: str) -> Tuple[str, str]: 105 | """Split a full s3 path into the bucket name and path.""" 106 | parsed = urlparse(url) 107 | if not parsed.netloc or not parsed.path: 108 | raise ValueError("bad s3 path {}".format(url)) 109 | bucket_name = parsed.netloc 110 | s3_path = parsed.path 111 | # Remove '/' at beginning of path. 112 | if s3_path.startswith("/"): 113 | s3_path = s3_path[1:] 114 | return bucket_name, s3_path 115 | 116 | 117 | def s3_request(func: Callable): 118 | """ 119 | Wrapper function for s3 requests in order to create more helpful error 120 | messages. 
121 | """ 122 | 123 | @wraps(func) 124 | def wrapper(url: str, *args, **kwargs): 125 | try: 126 | return func(url, *args, **kwargs) 127 | except ClientError as exc: 128 | if int(exc.response["Error"]["Code"]) == 404: 129 | raise FileNotFoundError("file {} not found".format(url)) 130 | else: 131 | raise 132 | 133 | return wrapper 134 | 135 | 136 | @s3_request 137 | def s3_etag(url: str) -> Optional[str]: 138 | """Check ETag on S3 object.""" 139 | s3_resource = boto3.resource("s3") 140 | bucket_name, s3_path = split_s3_path(url) 141 | s3_object = s3_resource.Object(bucket_name, s3_path) 142 | return s3_object.e_tag 143 | 144 | 145 | @s3_request 146 | def s3_get(url: str, temp_file: IO) -> None: 147 | """Pull a file directly from S3.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 151 | 152 | 153 | def http_get(url: str, temp_file: IO) -> None: 154 | """Get method.""" 155 | req = requests.get(url, stream=True) 156 | content_length = req.headers.get('Content-Length') 157 | total = int(content_length) if content_length is not None else None 158 | progress = tqdm(unit="B", total=total) 159 | for chunk in req.iter_content(chunk_size=1024): 160 | if chunk: # filter out keep-alive new chunks 161 | progress.update(len(chunk)) 162 | temp_file.write(chunk) 163 | progress.close() 164 | 165 | 166 | def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str: 167 | """ 168 | Given a URL, look for the corresponding dataset in the local cache. 169 | If it's not there, download it. Then return the path to the cached file. 170 | """ 171 | if cache_dir is None: 172 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 173 | if isinstance(cache_dir, Path): 174 | cache_dir = str(cache_dir) 175 | 176 | os.makedirs(cache_dir, exist_ok=True) 177 | 178 | # Get eTag to add to filename, if it exists. 179 | if url.startswith("s3://"): 180 | etag = s3_etag(url) 181 | else: 182 | response = requests.head(url, allow_redirects=True) 183 | if response.status_code != 200: 184 | raise IOError("HEAD request failed for url {} with status code {}" 185 | .format(url, response.status_code)) 186 | etag = response.headers.get("ETag") 187 | 188 | filename = url_to_filename(url, etag) 189 | 190 | # get cache path to put the file 191 | cache_path = os.path.join(cache_dir, filename) 192 | 193 | if not os.path.exists(cache_path): 194 | # Download to temporary file, then copy to cache dir once finished. 195 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
196 | with tempfile.NamedTemporaryFile() as temp_file: 197 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 198 | 199 | # GET file object 200 | if url.startswith("s3://"): 201 | s3_get(url, temp_file) 202 | else: 203 | http_get(url, temp_file) 204 | 205 | # we are copying the file before closing it, so flush to avoid truncation 206 | temp_file.flush() 207 | # shutil.copyfileobj() starts at the current position, so go to the start 208 | temp_file.seek(0) 209 | 210 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 211 | with open(cache_path, 'wb') as cache_file: 212 | shutil.copyfileobj(temp_file, cache_file) 213 | 214 | logger.info("creating metadata file for %s", cache_path) 215 | meta = {'url': url, 'etag': etag} 216 | meta_path = cache_path + '.json' 217 | with open(meta_path, 'w') as meta_file: 218 | json.dump(meta, meta_file) 219 | 220 | logger.info("removing temp file %s", temp_file.name) 221 | 222 | return cache_path 223 | 224 | 225 | def read_set_from_file(filename: str) -> Set[str]: 226 | """ 227 | Extract a de-duped collection (set) of text from a file. 228 | Expected file format is one item per line. 229 | """ 230 | collection = set() 231 | with open(filename, 'r', encoding='utf-8') as file_: 232 | for line in file_: 233 | collection.add(line.rstrip()) 234 | return collection 235 | 236 | 237 | def get_file_extension(path: str, dot=True, lower: bool = True): 238 | """Return fie extension.""" 239 | ext = os.path.splitext(path)[1] 240 | ext = ext if dot else ext[1:] 241 | return ext.lower() if lower else ext 242 | -------------------------------------------------------------------------------- /feature_extractor/utility/util.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import h5py 3 | import os 4 | import sys 5 | import logging 6 | import re 7 | import spacy 8 | import torch 9 | 10 | import pandas as pd 11 | import numpy as np 12 | 13 | from torch.utils.tensorboard import SummaryWriter 14 | from tqdm import tqdm 15 | from utility.vocabulary import * 16 | 17 | from pycocoevalcap.bleu.bleu import Bleu 18 | from pycocoevalcap.rouge.rouge import Rouge 19 | from pycocoevalcap.cider.cider import Cider 20 | from pycocoevalcap.meteor.meteor import Meteor 21 | from sklearn.metrics import precision_score, f1_score, recall_score 22 | 23 | tqdm.pandas() 24 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | 26 | # Set seed for reproducible result 27 | def init_seed(seed=1, use_cuda=False): 28 | np.random.seed(seed) 29 | torch.manual_seed(seed) 30 | if use_cuda: 31 | torch.cuda.manual_seed(seed) 32 | 33 | 34 | # Initialize file handler object 35 | def init_log(save_dir='saved/log/', filename='log.txt', log_format='%(message)s'): 36 | logger = logging.getLogger(__name__) 37 | if not logger.hasHandlers(): 38 | create_folder(save_dir) 39 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format) 40 | fh = logging.FileHandler(os.path.join(save_dir, filename)) 41 | fh.setFormatter(logging.Formatter(log_format)) 42 | logging.getLogger().addHandler(fh) 43 | 44 | 45 | def init_tensorboard(save_dir='saved/tensorboard/'): 46 | create_folder(save_dir) 47 | writer = SummaryWriter(save_dir) 48 | return writer 49 | 50 | 51 | def create_folder(save_dir): 52 | if not os.path.exists(save_dir): 53 | os.makedirs(save_dir) 54 | 55 | def convert_batched_graph_feat(features, adj): 56 | graphs = [] 57 | weights = [] 58 | for k in range(len(adj)): 59 | u = [] 60 | 
v = [] 61 | w = [] 62 | for i in range(100): 63 | for j in range(100): 64 | u.append(i) 65 | v.append(j) 66 | w.append(adj[k][i][j]) 67 | g = dgl.graph((u, v)) 68 | graphs.append(g) 69 | weights.append(torch.stack(w)) 70 | 71 | weights = torch.stack(weights) 72 | weights = torch.flatten(weights) 73 | features = torch.flatten(features, end_dim=1) 74 | graphs = dgl.batch(graphs) 75 | return graphs, features, weights 76 | 77 | def calculate_longest_sentence(series): 78 | tokenizer = spacy.load('en_core_web_sm') 79 | longest_sentence = 0 80 | for sentence in tqdm(series): 81 | sentence_len = 0 82 | for word in tokenizer(sentence): 83 | sentence_len += 1 84 | if sentence_len > longest_sentence: 85 | longest_sentence = sentence_len 86 | 87 | return longest_sentence 88 | 89 | def generate_caption_data(data='msvd_train', n_video=5, vocab=None, device="cuda",path="dataset/MSVD/captions/sents_%s_lc_nopunc.txt", 90 | min_word_count=0): 91 | 92 | if 'msvd' in data: 93 | used_captions = pd.read_csv(path % data.split("_")[1],\ 94 | sep='\t', header=None, names=["vid_id", "caption"]) 95 | 96 | # Start index for MSVD data 97 | start_index = {"msvd_train": 1, "msvd_val": 1201, "msvd_test": 1301} 98 | 99 | # Create a video_id query 100 | chosen_keys = ["vid%s" % x for x in range(start_index[data], start_index[data]+n_video)] 101 | used_captions = used_captions[used_captions['vid_id'].isin(chosen_keys)] 102 | 103 | if vocab is None: 104 | # Instantiate new vocabulary 105 | vocab = Vocabulary() 106 | 107 | # Populate vocabulary 108 | print("Populating vocab with %s..." % data) 109 | for caption in tqdm(used_captions['caption']): 110 | vocab.add_sentence(caption) 111 | 112 | print("Original number of words:",vocab.num_words) 113 | if min_word_count>0: 114 | vocab.filter_vocab(min_word_count) 115 | print("Filtered number of words:",vocab.num_words) 116 | 117 | # Create vector caption 118 | print("Converting sentences to indexes...") 119 | used_captions['vector'] = used_captions['caption'].progress_apply(lambda x: vocab.generate_vector(x)) 120 | longest_sentence = vocab.longest_sentence 121 | 122 | else: 123 | # If using val_data/test_data 124 | longest_sentence = calculate_longest_sentence(used_captions['caption']) 125 | used_captions['vector'] = used_captions['caption'].progress_apply(lambda x: vocab.generate_vector(x, longest_sentence)) 126 | 127 | flatten_captions = torch.tensor(used_captions['vector']).to(device=device) 128 | captions_vector = used_captions.groupby("vid_id", sort=False)['vector'].sum()\ 129 | .apply(lambda x: torch.tensor(x).reshape(-1, longest_sentence+2)\ 130 | .to(device=device)).to_dict() 131 | 132 | return captions_vector, flatten_captions, vocab, used_captions 133 | 134 | def generate_2d_3d_features(data='msvd_train', n_video=5, 135 | f2d_path="MSVD-2D.hdf5", f3d_path="MSVD-3D.hdf5", device="cuda"): 136 | scn_2d = h5py.File(f2d_path, "r") 137 | scn_3d = h5py.File(f3d_path, "r") 138 | 139 | # Start index for MSVD data 140 | start_index = {"msvd_train": 1, "msvd_val": 1201, "msvd_test": 1301} 141 | 142 | # Create a video_id query 143 | chosen_keys = ["vid%s" % x for x in range(start_index[data], start_index[data]+n_video)] 144 | 145 | scn_2d_src, scn_3d_src = [], [] 146 | for key in chosen_keys: 147 | scn_2d_src.append(scn_2d.get(key)) 148 | scn_3d_src.append(scn_3d.get(key)) 149 | 150 | return torch.tensor(scn_2d_src).to(device=device), torch.tensor(scn_3d_src).to(device=device) 151 | 152 | def generate_node_features(data="msvd_train", n_video=5, 153 | 
fo_path="MSVD_FO_FASTERRCNN_RESNET50.hdf5", 154 | stgraph_path="MSVD_IOU_STG_FASTERRCNN_RESNET50.hdf5", device="cuda", generate_fo=True): 155 | 156 | if generate_fo: 157 | fo_file = stack_node_features(fo_path) 158 | stgraph_file = h5py.File(stgraph_path, "r") 159 | 160 | # Start index for MSVD data 161 | start_index = {"msvd_train": 1, "msvd_val": 1201, "msvd_test": 1301} 162 | 163 | # Create a video_id query 164 | excluded_keys = [] 165 | for vid in fo_file.keys(): 166 | if len(fo_file[vid]) != 100: 167 | excluded_keys.append(vid) 168 | chosen_keys = ["vid%s" % x for x in range(start_index[data], start_index[data]+n_video) if "vid%s" % x not in excluded_keys] 169 | 170 | 171 | fo_input, stgraph = [], [] 172 | for key in chosen_keys: 173 | if generate_fo: 174 | fo_input.append(fo_file.get(key)) 175 | stgraph.append(stgraph_file.get(key)) 176 | 177 | if generate_fo: 178 | return torch.tensor(fo_input).to(device=device), torch.tensor(stgraph).to(device=device), excluded_keys 179 | return torch.tensor(stgraph).to(device=device) 180 | 181 | def stack_node_features(pathfile): 182 | 183 | fo_input = h5py.File(pathfile, "r") 184 | fo_list = {} 185 | for i,key in tqdm(enumerate(fo_input.keys()), total=len(fo_input.keys())): 186 | a = key.split('-') 187 | 188 | if a[0] not in fo_list: 189 | fo_list[a[0]] = {} 190 | fo_list[a[0]][int(a[1])] = fo_input[key][:] 191 | 192 | fo_stacked = {} 193 | for key in fo_list.keys(): 194 | stacked = [] 195 | for k_fr in sorted(fo_list[key].keys()): 196 | stacked.append(fo_list[key][k_fr]) 197 | fo_stacked[key] = np.vstack(stacked) 198 | 199 | return fo_stacked 200 | 201 | def score(ref, hypo, metrics=[]): 202 | """ 203 | ref, dictionary of reference sentences (id, sentence) 204 | hypo, dictionary of hypothesis sentences (id, sentence) 205 | score, dictionary of scores 206 | refers: https://github.com/zhegan27/SCN_for_video_captioning/blob/master/SCN_evaluation.py 207 | 208 | metrics, eg. 
['bleu', 'meteor','rouge_l','cider'] 209 | """ 210 | scorers = { 211 | "bleu" : (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 212 | "meteor" : (Meteor(),"METEOR"), 213 | "rouge_l" : (Rouge(), "ROUGE_L"), 214 | "cider" : (Cider(), "CIDEr") 215 | } 216 | final_scores = {} 217 | for key in metrics: 218 | scorer, method = scorers[key] 219 | score, scores = scorer.compute_score(ref, hypo) 220 | if type(score) == list: 221 | for m, s in zip(method, score): 222 | final_scores[m] = s 223 | else: 224 | final_scores[method] = score 225 | return final_scores 226 | 227 | def calculate_metrics(pred, target, threshold=0.5): 228 | pred = np.array(pred > threshold, dtype=float) 229 | return { 230 | 'micro/precision': precision_score(y_true=target, y_pred=pred, average='micro'), 231 | 'micro/recall': recall_score(y_true=target, y_pred=pred, average='micro'), 232 | 'micro/f1': f1_score(y_true=target, y_pred=pred, average='micro'), 233 | 234 | 'macro/precision': precision_score(y_true=target, y_pred=pred, average='macro'), 235 | 'macro/recall': recall_score(y_true=target, y_pred=pred, average='macro'), 236 | 'macro/f1': f1_score(y_true=target, y_pred=pred, average='macro'), 237 | 238 | 'samples/precision': precision_score(y_true=target, y_pred=pred, average='samples'), 239 | 'samples/recall': recall_score(y_true=target, y_pred=pred, average='samples'), 240 | 'samples/f1': f1_score(y_true=target, y_pred=pred, average='samples'), 241 | } -------------------------------------------------------------------------------- /modules/modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model. """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import logging 23 | import numpy as np 24 | 25 | import torch 26 | from torch import nn 27 | import torch.nn.functional as F 28 | from torch.nn import CrossEntropyLoss, MSELoss 29 | 30 | from modules.until_module import PreTrainedModel, LayerNorm, CrossEn 31 | from modules.module_bert import BertModel, BertConfig 32 | from modules.module_visual import VisualModel, VisualConfig, VisualOnlyMLMHead 33 | from modules.module_decoder import DecoderModel, DecoderConfig 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | 38 | class CaptionGeneratorPreTrainedModel(PreTrainedModel, nn.Module): 39 | """ An abstract class to handle weights initialization and 40 | a simple interface for dowloading and loading pretrained models. 
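        A minimal usage sketch (assuming the "visual-base"/"decoder-base" config
        names under modules/ and a parsed `args` namespace supplying `task_config`;
        these names are illustrative, not fixed by this class):

            >>> model = CaptionGenerator.from_pretrained(
            ...     "bert-base-uncased", "visual-base", "decoder-base",
            ...     cache_dir=None, type_vocab_size=2, task_config=args)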
41 | """ 42 | def __init__(self, bert_config, visual_config, decoder_config, *inputs, **kwargs): 43 | # utilize bert config as base config 44 | super(CaptionGeneratorPreTrainedModel, self).__init__(bert_config) 45 | self.bert_config = bert_config 46 | self.visual_config = visual_config 47 | self.decoder_config = decoder_config 48 | 49 | self.visual = None 50 | self.decoder = None 51 | self.lp = None 52 | 53 | @classmethod 54 | def from_pretrained(cls, pretrained_bert_name, visual_model_name, decoder_model_name, 55 | state_dict=None, cache_dir=None, type_vocab_size=2, *inputs, **kwargs): 56 | 57 | task_config = None 58 | if "task_config" in kwargs.keys(): 59 | task_config = kwargs["task_config"] 60 | if not hasattr(task_config, "local_rank"): 61 | task_config.__dict__["local_rank"] = 0 62 | elif task_config.local_rank == -1: 63 | task_config.local_rank = 0 64 | 65 | bert_config, state_dict = BertConfig.get_config(pretrained_bert_name, cache_dir, type_vocab_size, state_dict, task_config=task_config) 66 | visual_config, _ = VisualConfig.get_config(visual_model_name, cache_dir, type_vocab_size, state_dict=None, task_config=task_config) 67 | decoder_config, _ = DecoderConfig.get_config(decoder_model_name, cache_dir, type_vocab_size, state_dict=None, task_config=task_config) 68 | 69 | model = cls(bert_config, visual_config, decoder_config, *inputs, **kwargs) 70 | 71 | # assert model.bert is not None 72 | assert model.visual is not None 73 | 74 | if state_dict is not None: 75 | model = cls.init_preweight(model, state_dict, task_config=task_config) 76 | 77 | return model 78 | 79 | class NormalizeVideo(nn.Module): 80 | def __init__(self, task_config): 81 | super(NormalizeVideo, self).__init__() 82 | self.visual_norm2d = LayerNorm(task_config.video_dim) 83 | 84 | def forward(self, video): 85 | video = torch.as_tensor(video).float() 86 | video = video.view(-1, video.shape[-2], video.shape[-1]) 87 | video = self.visual_norm2d(video) 88 | return video 89 | 90 | def show_log(task_config, info): 91 | if task_config is None or task_config.local_rank == 0: 92 | logger.warning(info) 93 | 94 | def update_attr(target_name, target_config, target_attr_name, source_config, source_attr_name, default_value=None): 95 | if hasattr(source_config, source_attr_name): 96 | if default_value is None or getattr(source_config, source_attr_name) != default_value: 97 | setattr(target_config, target_attr_name, getattr(source_config, source_attr_name)) 98 | show_log(source_config, "Set {}.{}: {}.".format(target_name, 99 | target_attr_name, getattr(target_config, target_attr_name))) 100 | return target_config 101 | 102 | def check_attr(target_name, task_config): 103 | return hasattr(task_config, target_name) and task_config.__dict__[target_name] 104 | 105 | class CaptionGenerator(CaptionGeneratorPreTrainedModel): 106 | def __init__(self, bert_config, visual_config, decoder_config, task_config): 107 | super(CaptionGenerator, self).__init__(bert_config, visual_config, decoder_config) 108 | self.task_config = task_config 109 | self.ignore_video_index = -1 110 | 111 | assert self.task_config.max_words <= bert_config.max_position_embeddings 112 | assert self.task_config.max_words <= decoder_config.max_target_embeddings 113 | assert self.task_config.max_frames <= visual_config.max_position_embeddings 114 | 115 | # Text Encoder ===> 116 | bert_config = update_attr("bert_config", bert_config, "num_hidden_layers", 117 | self.task_config, "text_num_hidden_layers") 118 | bert = BertModel(bert_config) 119 | bert_word_embeddings_weight = 
bert.embeddings.word_embeddings.weight 120 | bert_position_embeddings_weight = bert.embeddings.position_embeddings.weight 121 | # <=== End of Text Encoder 122 | 123 | # Video Encoder ===> 124 | visual_config = update_attr("visual_config", visual_config, "num_hidden_layers", 125 | self.task_config, "visual_num_hidden_layers") 126 | self.visual = VisualModel(visual_config) 127 | visual_word_embeddings_weight = self.visual.embeddings.word_embeddings.weight 128 | # <=== End of Video Encoder 129 | 130 | 131 | # Decoder ===> 132 | decoder_config = update_attr("decoder_config", decoder_config, "num_decoder_layers", 133 | self.task_config, "decoder_num_hidden_layers") 134 | self.decoder = DecoderModel(decoder_config, bert_word_embeddings_weight, bert_position_embeddings_weight) 135 | # <=== End of Decoder 136 | 137 | self.decoder_loss_fct = CrossEntropyLoss(ignore_index=-1) 138 | 139 | self.normalize_video = NormalizeVideo(task_config) 140 | 141 | self.apply(self.init_weights) 142 | 143 | def forward(self, video, video_mask=None, 144 | input_caption_ids=None, decoder_mask=None): 145 | 146 | video_mask = video_mask.view(-1, video_mask.shape[-1]) 147 | video = self.normalize_video(video) 148 | 149 | if input_caption_ids is not None: 150 | input_caption_ids = input_caption_ids.view(-1, input_caption_ids.shape[-1]) 151 | decoder_mask = decoder_mask.view(-1, decoder_mask.shape[-1]) 152 | 153 | visual_output = self.get_visual_output(video, video_mask, shaped=True) 154 | 155 | if self.training: 156 | loss = 0. 157 | 158 | if (input_caption_ids is not None): 159 | decoder_scores, res_tuples = self._get_decoder_score(visual_output, video_mask, 160 | input_caption_ids, decoder_mask, shaped=True) 161 | # output_caption_ids = output_caption_ids.view(-1, output_caption_ids.shape[-1]) 162 | # decoder_loss = self.decoder_loss_fct(decoder_scores.view(-1, self.bert_config.vocab_size), output_caption_ids.view(-1)) 163 | # loss += decoder_loss 164 | 165 | return decoder_scores 166 | else: 167 | return None 168 | 169 | def get_visual_output(self, video, video_mask, shaped=False): 170 | if shaped is False: 171 | video_mask = video_mask.view(-1, video_mask.shape[-1]) 172 | video = self.normalize_video(video) 173 | 174 | visual_layers, _ = self.visual(video, video_mask, output_all_encoded_layers=True) 175 | visual_output = visual_layers[-1] 176 | 177 | return visual_output 178 | 179 | 180 | def _get_decoder_score(self, visual_output, video_mask, input_caption_ids, decoder_mask, shaped=False): 181 | 182 | if shaped is False: 183 | video_mask = video_mask.view(-1, video_mask.shape[-1]) 184 | 185 | input_caption_ids = input_caption_ids.view(-1, input_caption_ids.shape[-1]) 186 | decoder_mask = decoder_mask.view(-1, decoder_mask.shape[-1]) 187 | 188 | res_tuples = () 189 | decoder_scores = self.decoder(input_caption_ids, encoder_outs=visual_output, answer_mask=decoder_mask, encoder_mask=video_mask) 190 | 191 | return decoder_scores, res_tuples 192 | 193 | def decoder_caption(self, visual_output, video_mask, input_caption_ids, decoder_mask, 194 | shaped=False, get_logits=False): 195 | if shaped is False: 196 | video_mask = video_mask.view(-1, video_mask.shape[-1]) 197 | 198 | input_caption_ids = input_caption_ids.view(-1, input_caption_ids.shape[-1]) 199 | decoder_mask = decoder_mask.view(-1, decoder_mask.shape[-1]) 200 | 201 | decoder_scores, _ = self._get_decoder_score(visual_output, 202 | video_mask, 203 | input_caption_ids, decoder_mask, shaped=True) 204 | 205 | if get_logits: 206 | return decoder_scores 207 | 208 | _, 
decoder_scores_result = torch.max(decoder_scores, -1) 209 | 210 | return decoder_scores_result -------------------------------------------------------------------------------- /feature_extractor/modules/module_cross.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import copy 7 | import json 8 | import math 9 | import logging 10 | import tarfile 11 | import tempfile 12 | import shutil 13 | 14 | import torch 15 | from torch import nn 16 | import torch.nn.functional as F 17 | from .file_utils import cached_path 18 | from .until_config import PretrainedConfig 19 | from .until_module import PreTrainedModel, LayerNorm, ACT2FN 20 | from collections import OrderedDict 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | PRETRAINED_MODEL_ARCHIVE_MAP = {} 25 | CONFIG_NAME = 'cross_config.json' 26 | WEIGHTS_NAME = 'cross_pytorch_model.bin' 27 | 28 | 29 | class CrossConfig(PretrainedConfig): 30 | """Configuration class to store the configuration of a `CrossModel`. 31 | """ 32 | pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP 33 | config_name = CONFIG_NAME 34 | weights_name = WEIGHTS_NAME 35 | def __init__(self, 36 | vocab_size_or_config_json_file, 37 | hidden_size=768, 38 | num_hidden_layers=12, 39 | num_attention_heads=12, 40 | intermediate_size=3072, 41 | hidden_act="gelu", 42 | hidden_dropout_prob=0.1, 43 | attention_probs_dropout_prob=0.1, 44 | max_position_embeddings=512, 45 | type_vocab_size=2, 46 | initializer_range=0.02): 47 | """Constructs CrossConfig. 48 | 49 | Args: 50 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CrossModel`. 51 | hidden_size: Size of the encoder layers and the pooler layer. 52 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 53 | num_attention_heads: Number of attention heads for each attention layer in 54 | the Transformer encoder. 55 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 56 | layer in the Transformer encoder. 57 | hidden_act: The non-linear activation function (function or string) in the 58 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 59 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 60 | layers in the embeddings, encoder, and pooler. 61 | attention_probs_dropout_prob: The dropout ratio for the attention 62 | probabilities. 63 | max_position_embeddings: The maximum sequence length that this model might 64 | ever be used with. Typically set this to something large just in case 65 | (e.g., 512 or 1024 or 2048). 66 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 67 | `CrossModel`. 68 | initializer_range: The sttdev of the truncated_normal_initializer for 69 | initializing all weight matrices. 
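        Example (an illustrative sketch; the import assumes the feature_extractor
        directory is the working directory, and the JSON path assumes the bundled
        cross-base config):

            >>> from modules.module_cross import CrossConfig
            >>> # from an explicit vocabulary size plus keyword overrides
            >>> config = CrossConfig(512, hidden_size=512, num_hidden_layers=4,
            ...                      num_attention_heads=8, intermediate_size=2048,
            ...                      max_position_embeddings=128)
            >>> # or directly from a JSON config file
            >>> config = CrossConfig("modules/cross-base/cross_config.json")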
70 | """ 71 | if isinstance(vocab_size_or_config_json_file, str): 72 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 73 | json_config = json.loads(reader.read()) 74 | for key, value in json_config.items(): 75 | self.__dict__[key] = value 76 | elif isinstance(vocab_size_or_config_json_file, int): 77 | self.vocab_size = vocab_size_or_config_json_file 78 | self.hidden_size = hidden_size 79 | self.num_hidden_layers = num_hidden_layers 80 | self.num_attention_heads = num_attention_heads 81 | self.hidden_act = hidden_act 82 | self.intermediate_size = intermediate_size 83 | self.hidden_dropout_prob = hidden_dropout_prob 84 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 85 | self.max_position_embeddings = max_position_embeddings 86 | self.type_vocab_size = type_vocab_size 87 | self.initializer_range = initializer_range 88 | else: 89 | raise ValueError("First argument must be either a vocabulary size (int)" 90 | "or the path to a pretrained model config file (str)") 91 | 92 | class QuickGELU(nn.Module): 93 | def forward(self, x: torch.Tensor): 94 | return x * torch.sigmoid(1.702 * x) 95 | 96 | class ResidualAttentionBlock(nn.Module): 97 | def __init__(self, d_model: int, n_head: int): 98 | super().__init__() 99 | 100 | self.attn = nn.MultiheadAttention(d_model, n_head) 101 | self.ln_1 = LayerNorm(d_model) 102 | self.mlp = nn.Sequential(OrderedDict([ 103 | ("c_fc", nn.Linear(d_model, d_model * 4)), 104 | ("gelu", QuickGELU()), 105 | ("c_proj", nn.Linear(d_model * 4, d_model)) 106 | ])) 107 | self.ln_2 = LayerNorm(d_model) 108 | self.n_head = n_head 109 | 110 | def attention(self, x: torch.Tensor, attn_mask: torch.Tensor): 111 | attn_mask_ = attn_mask.repeat_interleave(self.n_head, dim=0) 112 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] 113 | 114 | def forward(self, para_tuple: tuple): 115 | # x: torch.Tensor, attn_mask: torch.Tensor 116 | # print(para_tuple) 117 | x, attn_mask = para_tuple 118 | x = x + self.attention(self.ln_1(x), attn_mask) 119 | x = x + self.mlp(self.ln_2(x)) 120 | return (x, attn_mask) 121 | 122 | class Transformer(nn.Module): 123 | def __init__(self, width: int, layers: int, heads: int): 124 | super().__init__() 125 | self.width = width 126 | self.layers = layers 127 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads) for _ in range(layers)]) 128 | 129 | def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): 130 | return self.resblocks((x, attn_mask))[0] 131 | 132 | class CrossEmbeddings(nn.Module): 133 | """Construct the embeddings from word, position and token_type embeddings. 
134 | """ 135 | def __init__(self, config): 136 | super(CrossEmbeddings, self).__init__() 137 | 138 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 139 | # self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 140 | # self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12) 141 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 142 | 143 | def forward(self, concat_embeddings, concat_type=None): 144 | 145 | batch_size, seq_length = concat_embeddings.size(0), concat_embeddings.size(1) 146 | # if concat_type is None: 147 | # concat_type = torch.zeros(batch_size, concat_type).to(concat_embeddings.device) 148 | 149 | position_ids = torch.arange(seq_length, dtype=torch.long, device=concat_embeddings.device) 150 | position_ids = position_ids.unsqueeze(0).expand(concat_embeddings.size(0), -1) 151 | 152 | # token_type_embeddings = self.token_type_embeddings(concat_type) 153 | position_embeddings = self.position_embeddings(position_ids) 154 | 155 | embeddings = concat_embeddings + position_embeddings # + token_type_embeddings 156 | # embeddings = self.LayerNorm(embeddings) 157 | embeddings = self.dropout(embeddings) 158 | return embeddings 159 | 160 | class CrossPooler(nn.Module): 161 | def __init__(self, config): 162 | super(CrossPooler, self).__init__() 163 | self.ln_pool = LayerNorm(config.hidden_size) 164 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 165 | self.activation = QuickGELU() 166 | 167 | def forward(self, hidden_states, hidden_mask): 168 | # We "pool" the model by simply taking the hidden state corresponding 169 | # to the first token. 170 | hidden_states = self.ln_pool(hidden_states) 171 | pooled_output = hidden_states[:, 0] 172 | pooled_output = self.dense(pooled_output) 173 | pooled_output = self.activation(pooled_output) 174 | return pooled_output 175 | 176 | class CrossModel(PreTrainedModel): 177 | 178 | def initialize_parameters(self): 179 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 180 | attn_std = self.transformer.width ** -0.5 181 | fc_std = (2 * self.transformer.width) ** -0.5 182 | for block in self.transformer.resblocks: 183 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 184 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 185 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 186 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 187 | 188 | def __init__(self, config): 189 | super(CrossModel, self).__init__(config) 190 | 191 | self.embeddings = CrossEmbeddings(config) 192 | 193 | transformer_width = config.hidden_size 194 | transformer_layers = config.num_hidden_layers 195 | transformer_heads = config.num_attention_heads 196 | self.transformer = Transformer(width=transformer_width, layers=transformer_layers, heads=transformer_heads,) 197 | self.pooler = CrossPooler(config) 198 | self.apply(self.init_weights) 199 | 200 | def build_attention_mask(self, attention_mask): 201 | extended_attention_mask = attention_mask.unsqueeze(1) 202 | extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility 203 | extended_attention_mask = (1.0 - extended_attention_mask) * -1000000.0 204 | extended_attention_mask = extended_attention_mask.expand(-1, attention_mask.size(1), -1) 205 | return extended_attention_mask 206 | 207 | def forward(self, concat_input, concat_type=None, attention_mask=None, output_all_encoded_layers=True): 208 | 209 | if attention_mask is None: 210 | 
attention_mask = torch.ones(concat_input.size(0), concat_input.size(1)) 211 | if concat_type is None: 212 | concat_type = torch.zeros_like(attention_mask) 213 | 214 | extended_attention_mask = self.build_attention_mask(attention_mask) 215 | 216 | embedding_output = self.embeddings(concat_input, concat_type) 217 | embedding_output = embedding_output.permute(1, 0, 2) # NLD -> LND 218 | embedding_output = self.transformer(embedding_output, extended_attention_mask) 219 | embedding_output = embedding_output.permute(1, 0, 2) # LND -> NLD 220 | 221 | pooled_output = self.pooler(embedding_output, hidden_mask=attention_mask) 222 | 223 | return embedding_output, pooled_output 224 | -------------------------------------------------------------------------------- /modules/until_module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model. """ 17 | 18 | import logging 19 | import numpy as np 20 | import torch 21 | from torch import nn 22 | import torch.nn.functional as F 23 | import math 24 | from modules.until_config import PretrainedConfig 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | def gelu(x): 29 | """Implementation of the gelu activation function. 30 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 31 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 32 | """ 33 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 34 | 35 | def swish(x): 36 | return x * torch.sigmoid(x) 37 | 38 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 39 | 40 | class LayerNorm(nn.Module): 41 | def __init__(self, hidden_size, eps=1e-12): 42 | """Construct a layernorm module in the TF style (epsilon inside the square root). 43 | """ 44 | super(LayerNorm, self).__init__() 45 | self.weight = nn.Parameter(torch.ones(hidden_size)) 46 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 47 | self.variance_epsilon = eps 48 | 49 | def forward(self, x): 50 | u = x.mean(-1, keepdim=True) 51 | s = (x - u).pow(2).mean(-1, keepdim=True) 52 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 53 | return self.weight * x + self.bias 54 | 55 | class PreTrainedModel(nn.Module): 56 | """ An abstract class to handle weights initialization and 57 | a simple interface for dowloading and loading pretrained models. 58 | """ 59 | def __init__(self, config, *inputs, **kwargs): 60 | super(PreTrainedModel, self).__init__() 61 | if not isinstance(config, PretrainedConfig): 62 | raise ValueError( 63 | "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. 
" 64 | "To create a model from a Google pretrained model use " 65 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 66 | self.__class__.__name__, self.__class__.__name__ 67 | )) 68 | self.config = config 69 | 70 | def init_weights(self, module): 71 | """ Initialize the weights. 72 | """ 73 | if isinstance(module, (nn.Linear, nn.Embedding)): 74 | # Slightly different from the TF version which uses truncated_normal for initialization 75 | # cf https://github.com/pytorch/pytorch/pull/5617 76 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 77 | elif isinstance(module, LayerNorm): 78 | if 'beta' in dir(module) and 'gamma' in dir(module): 79 | module.beta.data.zero_() 80 | module.gamma.data.fill_(1.0) 81 | else: 82 | module.bias.data.zero_() 83 | module.weight.data.fill_(1.0) 84 | if isinstance(module, nn.Linear) and module.bias is not None: 85 | module.bias.data.zero_() 86 | 87 | def resize_token_embeddings(self, new_num_tokens=None): 88 | raise NotImplementedError 89 | 90 | @classmethod 91 | def init_preweight(cls, model, state_dict, prefix=None, task_config=None): 92 | old_keys = [] 93 | new_keys = [] 94 | for key in state_dict.keys(): 95 | new_key = None 96 | if 'gamma' in key: 97 | new_key = key.replace('gamma', 'weight') 98 | if 'beta' in key: 99 | new_key = key.replace('beta', 'bias') 100 | if new_key: 101 | old_keys.append(key) 102 | new_keys.append(new_key) 103 | for old_key, new_key in zip(old_keys, new_keys): 104 | state_dict[new_key] = state_dict.pop(old_key) 105 | 106 | if prefix is not None: 107 | old_keys = [] 108 | new_keys = [] 109 | for key in state_dict.keys(): 110 | old_keys.append(key) 111 | new_keys.append(prefix + key) 112 | for old_key, new_key in zip(old_keys, new_keys): 113 | state_dict[new_key] = state_dict.pop(old_key) 114 | 115 | missing_keys = [] 116 | unexpected_keys = [] 117 | error_msgs = [] 118 | # copy state_dict so _load_from_state_dict can modify it 119 | metadata = getattr(state_dict, '_metadata', None) 120 | state_dict = state_dict.copy() 121 | if metadata is not None: 122 | state_dict._metadata = metadata 123 | 124 | def load(module, prefix=''): 125 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 126 | module._load_from_state_dict( 127 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 128 | for name, child in module._modules.items(): 129 | if child is not None: 130 | load(child, prefix + name + '.') 131 | 132 | load(model, prefix='') 133 | 134 | if prefix is None and (task_config is None or task_config.local_rank == 0): 135 | logger.info("-" * 20) 136 | if len(missing_keys) > 0: 137 | logger.info("Weights of {} not initialized from pretrained model: {}" 138 | .format(model.__class__.__name__, "\n " + "\n ".join(missing_keys))) 139 | if len(unexpected_keys) > 0: 140 | logger.info("Weights from pretrained model not used in {}: {}" 141 | .format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys))) 142 | if len(error_msgs) > 0: 143 | logger.error("Weights from pretrained model cause errors in {}: {}" 144 | .format(model.__class__.__name__, "\n " + "\n ".join(error_msgs))) 145 | 146 | return model 147 | 148 | @property 149 | def dtype(self): 150 | """ 151 | :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). 
152 | """ 153 | try: 154 | return next(self.parameters()).dtype 155 | except StopIteration: 156 | # For nn.DataParallel compatibility in PyTorch 1.5 157 | def find_tensor_attributes(module: nn.Module): 158 | tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] 159 | return tuples 160 | 161 | gen = self._named_members(get_members_fn=find_tensor_attributes) 162 | first_tuple = next(gen) 163 | return first_tuple[1].dtype 164 | 165 | @classmethod 166 | def from_pretrained(cls, config, state_dict=None, *inputs, **kwargs): 167 | """ 168 | Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict. 169 | Download and cache the pre-trained model file if needed. 170 | """ 171 | # Instantiate model. 172 | model = cls(config, *inputs, **kwargs) 173 | if state_dict is None: 174 | return model 175 | model = cls.init_preweight(model, state_dict) 176 | 177 | return model 178 | 179 | ################################## 180 | ###### LOSS FUNCTION ############# 181 | ################################## 182 | class CrossEn(nn.Module): 183 | """ 184 | Implementation of cross entropy loss over similarity score matrix, used for calculating 185 | symmetric cross entropy loss . 186 | """ 187 | def __init__(self,): 188 | super(CrossEn, self).__init__() 189 | 190 | def forward(self, sim_matrix): 191 | logpt = F.log_softmax(sim_matrix, dim=-1) 192 | logpt = torch.diag(logpt) 193 | nce_loss = -logpt 194 | sim_loss = nce_loss.mean() 195 | return sim_loss 196 | 197 | class MILNCELoss(nn.Module): 198 | """ 199 | Implementation of MIL-NCE Loss 200 | """ 201 | def __init__(self, batch_size=1, n_pair=1,): 202 | super(MILNCELoss, self).__init__() 203 | self.batch_size = batch_size 204 | self.n_pair = n_pair 205 | torch_v = float(".".join(torch.__version__.split(".")[:2])) 206 | self.bool_dtype = torch.bool if torch_v >= 1.3 else torch.uint8 207 | 208 | def forward(self, sim_matrix): 209 | mm_mask = np.eye(self.batch_size) 210 | mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair))) 211 | mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device) 212 | 213 | from_text_matrix = sim_matrix + mm_mask * -1e12 214 | from_video_matrix = sim_matrix.transpose(1, 0) 215 | 216 | new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1) 217 | logpt = F.log_softmax(new_sim_matrix, dim=-1) 218 | 219 | mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1) 220 | masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12 221 | 222 | new_logpt = -torch.logsumexp(masked_logpt, dim=-1) 223 | 224 | logpt_choice = torch.zeros_like(new_logpt) 225 | mark_ind = torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair//2) 226 | logpt_choice[mark_ind] = 1 227 | sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=self.bool_dtype)).mean() 228 | return sim_loss 229 | 230 | class MaxMarginRankingLoss(nn.Module): 231 | """ 232 | Implementation of max margin ranking loss 233 | """ 234 | def __init__(self, 235 | margin=1.0, 236 | negative_weighting=False, 237 | batch_size=1, 238 | n_pair=1, 239 | hard_negative_rate=0.5, 240 | ): 241 | super(MaxMarginRankingLoss, self).__init__() 242 | self.margin = margin 243 | self.n_pair = n_pair 244 | self.batch_size = batch_size 245 | easy_negative_rate = 1 - hard_negative_rate 246 | self.easy_negative_rate = easy_negative_rate 247 | self.negative_weighting = negative_weighting 248 | if n_pair > 1 and batch_size > 1: 249 | alpha = easy_negative_rate / ((batch_size - 1) * (1 - 
easy_negative_rate)) 250 | mm_mask = (1 - alpha) * np.eye(self.batch_size) + alpha 251 | mm_mask = np.kron(mm_mask, np.ones((n_pair, n_pair))) 252 | mm_mask = torch.tensor(mm_mask) * (batch_size * (1 - easy_negative_rate)) 253 | self.mm_mask = mm_mask.float() 254 | 255 | def forward(self, x): 256 | d = torch.diag(x) 257 | max_margin = F.relu(self.margin + x - d.view(-1, 1)) + \ 258 | F.relu(self.margin + x - d.view(1, -1)) 259 | if self.negative_weighting and self.n_pair > 1 and self.batch_size > 1: 260 | max_margin = max_margin * self.mm_mask.to(max_margin.device) 261 | return max_margin.mean() 262 | -------------------------------------------------------------------------------- /feature_extractor/modules/until_module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model.""" 17 | 18 | import logging 19 | import numpy as np 20 | import torch 21 | from torch import nn 22 | import torch.nn.functional as F 23 | import math 24 | from modules.until_config import PretrainedConfig 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | def gelu(x): 29 | """Implementation of the gelu activation function. 30 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 31 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 32 | """ 33 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 34 | 35 | def swish(x): 36 | return x * torch.sigmoid(x) 37 | 38 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 39 | 40 | class LayerNorm(nn.Module): 41 | def __init__(self, hidden_size, eps=1e-12): 42 | """Construct a layernorm module in the TF style (epsilon inside the square root). 43 | """ 44 | super(LayerNorm, self).__init__() 45 | self.weight = nn.Parameter(torch.ones(hidden_size)) 46 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 47 | self.variance_epsilon = eps 48 | 49 | def forward(self, x): 50 | u = x.mean(-1, keepdim=True) 51 | s = (x - u).pow(2).mean(-1, keepdim=True) 52 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 53 | return self.weight * x + self.bias 54 | 55 | class PreTrainedModel(nn.Module): 56 | """ An abstract class to handle weights initialization and 57 | a simple interface for dowloading and loading pretrained models. 58 | """ 59 | def __init__(self, config, *inputs, **kwargs): 60 | super(PreTrainedModel, self).__init__() 61 | if not isinstance(config, PretrainedConfig): 62 | raise ValueError( 63 | "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. 
" 64 | "To create a model from a Google pretrained model use " 65 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 66 | self.__class__.__name__, self.__class__.__name__ 67 | )) 68 | self.config = config 69 | 70 | def init_weights(self, module): 71 | """ Initialize the weights. 72 | """ 73 | if isinstance(module, (nn.Linear, nn.Embedding)): 74 | # Slightly different from the TF version which uses truncated_normal for initialization 75 | # cf https://github.com/pytorch/pytorch/pull/5617 76 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 77 | elif isinstance(module, LayerNorm): 78 | if 'beta' in dir(module) and 'gamma' in dir(module): 79 | module.beta.data.zero_() 80 | module.gamma.data.fill_(1.0) 81 | else: 82 | module.bias.data.zero_() 83 | module.weight.data.fill_(1.0) 84 | if isinstance(module, nn.Linear) and module.bias is not None: 85 | module.bias.data.zero_() 86 | 87 | def resize_token_embeddings(self, new_num_tokens=None): 88 | raise NotImplementedError 89 | 90 | @classmethod 91 | def init_preweight(cls, model, state_dict, prefix=None, task_config=None): 92 | old_keys = [] 93 | new_keys = [] 94 | for key in state_dict.keys(): 95 | new_key = None 96 | if 'gamma' in key: 97 | new_key = key.replace('gamma', 'weight') 98 | if 'beta' in key: 99 | new_key = key.replace('beta', 'bias') 100 | if new_key: 101 | old_keys.append(key) 102 | new_keys.append(new_key) 103 | for old_key, new_key in zip(old_keys, new_keys): 104 | state_dict[new_key] = state_dict.pop(old_key) 105 | 106 | if prefix is not None: 107 | old_keys = [] 108 | new_keys = [] 109 | for key in state_dict.keys(): 110 | old_keys.append(key) 111 | new_keys.append(prefix + key) 112 | for old_key, new_key in zip(old_keys, new_keys): 113 | state_dict[new_key] = state_dict.pop(old_key) 114 | 115 | missing_keys = [] 116 | unexpected_keys = [] 117 | error_msgs = [] 118 | # copy state_dict so _load_from_state_dict can modify it 119 | metadata = getattr(state_dict, '_metadata', None) 120 | state_dict = state_dict.copy() 121 | if metadata is not None: 122 | state_dict._metadata = metadata 123 | 124 | def load(module, prefix=''): 125 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 126 | module._load_from_state_dict( 127 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 128 | for name, child in module._modules.items(): 129 | if child is not None: 130 | load(child, prefix + name + '.') 131 | 132 | load(model, prefix='') 133 | 134 | if prefix is None and (task_config is None or task_config.local_rank == 0): 135 | logger.info("-" * 20) 136 | if len(missing_keys) > 0: 137 | logger.info("Weights of {} not initialized from pretrained model: {}" 138 | .format(model.__class__.__name__, "\n " + "\n ".join(missing_keys))) 139 | if len(unexpected_keys) > 0: 140 | logger.info("Weights from pretrained model not used in {}: {}" 141 | .format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys))) 142 | if len(error_msgs) > 0: 143 | logger.error("Weights from pretrained model cause errors in {}: {}" 144 | .format(model.__class__.__name__, "\n " + "\n ".join(error_msgs))) 145 | 146 | return model 147 | 148 | @property 149 | def dtype(self): 150 | """ 151 | :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). 
152 | """ 153 | try: 154 | return next(self.parameters()).dtype 155 | except StopIteration: 156 | # For nn.DataParallel compatibility in PyTorch 1.5 157 | def find_tensor_attributes(module: nn.Module): 158 | tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] 159 | return tuples 160 | 161 | gen = self._named_members(get_members_fn=find_tensor_attributes) 162 | first_tuple = next(gen) 163 | return first_tuple[1].dtype 164 | 165 | @classmethod 166 | def from_pretrained(cls, config, state_dict=None, *inputs, **kwargs): 167 | """ 168 | Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict. 169 | Download and cache the pre-trained model file if needed. 170 | """ 171 | # Instantiate model. 172 | model = cls(config, *inputs, **kwargs) 173 | if state_dict is None: 174 | return model 175 | model = cls.init_preweight(model, state_dict) 176 | 177 | return model 178 | 179 | ################################## 180 | ###### LOSS FUNCTION ############# 181 | ################################## 182 | class CrossEn(nn.Module): 183 | def __init__(self,): 184 | super(CrossEn, self).__init__() 185 | 186 | def forward(self, sim_matrix): 187 | logpt = F.log_softmax(sim_matrix, dim=-1) 188 | logpt = torch.diag(logpt) 189 | nce_loss = -logpt 190 | sim_loss = nce_loss.mean() 191 | return sim_loss 192 | 193 | class MILNCELoss(nn.Module): 194 | def __init__(self, batch_size=1, n_pair=1,): 195 | super(MILNCELoss, self).__init__() 196 | self.batch_size = batch_size 197 | self.n_pair = n_pair 198 | torch_v = float(".".join(torch.__version__.split(".")[:2])) 199 | self.bool_dtype = torch.bool if torch_v >= 1.3 else torch.uint8 200 | 201 | def forward(self, sim_matrix): 202 | mm_mask = np.eye(self.batch_size) 203 | mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair))) 204 | mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device) 205 | 206 | from_text_matrix = sim_matrix + mm_mask * -1e12 207 | from_video_matrix = sim_matrix.transpose(1, 0) 208 | 209 | new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1) 210 | logpt = F.log_softmax(new_sim_matrix, dim=-1) 211 | 212 | mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1) 213 | masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12 214 | 215 | new_logpt = -torch.logsumexp(masked_logpt, dim=-1) 216 | 217 | logpt_choice = torch.zeros_like(new_logpt) 218 | mark_ind = torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair//2) 219 | logpt_choice[mark_ind] = 1 220 | sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=self.bool_dtype)).mean() 221 | return sim_loss 222 | 223 | class MaxMarginRankingLoss(nn.Module): 224 | def __init__(self, 225 | margin=1.0, 226 | negative_weighting=False, 227 | batch_size=1, 228 | n_pair=1, 229 | hard_negative_rate=0.5, 230 | ): 231 | super(MaxMarginRankingLoss, self).__init__() 232 | self.margin = margin 233 | self.n_pair = n_pair 234 | self.batch_size = batch_size 235 | easy_negative_rate = 1 - hard_negative_rate 236 | self.easy_negative_rate = easy_negative_rate 237 | self.negative_weighting = negative_weighting 238 | if n_pair > 1 and batch_size > 1: 239 | alpha = easy_negative_rate / ((batch_size - 1) * (1 - easy_negative_rate)) 240 | mm_mask = (1 - alpha) * np.eye(self.batch_size) + alpha 241 | mm_mask = np.kron(mm_mask, np.ones((n_pair, n_pair))) 242 | mm_mask = torch.tensor(mm_mask) * (batch_size * (1 - easy_negative_rate)) 243 | self.mm_mask = mm_mask.float() 244 | 245 | def 
forward(self, x): 246 | d = torch.diag(x) 247 | max_margin = F.relu(self.margin + x - d.view(-1, 1)) + \ 248 | F.relu(self.margin + x - d.view(1, -1)) 249 | if self.negative_weighting and self.n_pair > 1 and self.batch_size > 1: 250 | max_margin = max_margin * self.mm_mask.to(max_margin.device) 251 | return max_margin.mean() 252 | 253 | class AllGather(torch.autograd.Function): 254 | """An autograd function that performs allgather on a tensor.""" 255 | 256 | @staticmethod 257 | def forward(ctx, tensor, args): 258 | output = [torch.empty_like(tensor) for _ in range(args.world_size)] 259 | torch.distributed.all_gather(output, tensor) 260 | ctx.rank = args.rank 261 | ctx.batch_size = tensor.shape[0] 262 | return torch.cat(output, dim=0) 263 | 264 | @staticmethod 265 | def backward(ctx, grad_output): 266 | return ( 267 | grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)], 268 | None, 269 | ) 270 | -------------------------------------------------------------------------------- /dataset/MSVD/train_list_mapping.txt: -------------------------------------------------------------------------------- 1 | vid1 2 | vid2 3 | vid3 4 | vid4 5 | vid5 6 | vid6 7 | vid7 8 | vid8 9 | vid9 10 | vid10 11 | vid11 12 | vid12 13 | vid13 14 | vid14 15 | vid15 16 | vid16 17 | vid17 18 | vid18 19 | vid19 20 | vid20 21 | vid21 22 | vid22 23 | vid23 24 | vid24 25 | vid25 26 | vid26 27 | vid27 28 | vid28 29 | vid29 30 | vid30 31 | vid31 32 | vid32 33 | vid33 34 | vid34 35 | vid35 36 | vid36 37 | vid37 38 | vid38 39 | vid39 40 | vid40 41 | vid41 42 | vid42 43 | vid43 44 | vid44 45 | vid45 46 | vid46 47 | vid47 48 | vid48 49 | vid49 50 | vid50 51 | vid51 52 | vid52 53 | vid53 54 | vid54 55 | vid55 56 | vid56 57 | vid57 58 | vid58 59 | vid59 60 | vid60 61 | vid61 62 | vid62 63 | vid63 64 | vid64 65 | vid65 66 | vid66 67 | vid67 68 | vid68 69 | vid69 70 | vid70 71 | vid71 72 | vid72 73 | vid73 74 | vid74 75 | vid75 76 | vid76 77 | vid77 78 | vid78 79 | vid79 80 | vid80 81 | vid81 82 | vid82 83 | vid83 84 | vid84 85 | vid85 86 | vid86 87 | vid87 88 | vid88 89 | vid89 90 | vid90 91 | vid91 92 | vid92 93 | vid93 94 | vid94 95 | vid95 96 | vid96 97 | vid97 98 | vid98 99 | vid99 100 | vid100 101 | vid101 102 | vid102 103 | vid103 104 | vid104 105 | vid105 106 | vid106 107 | vid107 108 | vid108 109 | vid109 110 | vid110 111 | vid111 112 | vid112 113 | vid113 114 | vid114 115 | vid115 116 | vid116 117 | vid117 118 | vid118 119 | vid119 120 | vid120 121 | vid121 122 | vid122 123 | vid123 124 | vid124 125 | vid125 126 | vid126 127 | vid127 128 | vid128 129 | vid129 130 | vid130 131 | vid131 132 | vid132 133 | vid133 134 | vid134 135 | vid135 136 | vid136 137 | vid137 138 | vid138 139 | vid139 140 | vid140 141 | vid141 142 | vid142 143 | vid143 144 | vid144 145 | vid145 146 | vid146 147 | vid147 148 | vid148 149 | vid149 150 | vid150 151 | vid151 152 | vid152 153 | vid153 154 | vid154 155 | vid155 156 | vid156 157 | vid157 158 | vid158 159 | vid159 160 | vid160 161 | vid161 162 | vid162 163 | vid163 164 | vid164 165 | vid165 166 | vid166 167 | vid167 168 | vid168 169 | vid169 170 | vid170 171 | vid171 172 | vid172 173 | vid173 174 | vid174 175 | vid175 176 | vid176 177 | vid177 178 | vid178 179 | vid179 180 | vid180 181 | vid181 182 | vid182 183 | vid183 184 | vid184 185 | vid185 186 | vid186 187 | vid187 188 | vid188 189 | vid189 190 | vid190 191 | vid191 192 | vid192 193 | vid193 194 | vid194 195 | vid195 196 | vid196 197 | vid197 198 | vid198 199 | vid199 200 | vid200 201 | vid201 202 | vid202 203 | vid203 
204 | vid204 205 | vid205 206 | vid206 207 | vid207 208 | vid208 209 | vid209 210 | vid210 211 | vid211 212 | vid212 213 | vid213 214 | vid214 215 | vid215 216 | vid216 217 | vid217 218 | vid218 219 | vid219 220 | vid220 221 | vid221 222 | vid222 223 | vid223 224 | vid224 225 | vid225 226 | vid226 227 | vid227 228 | vid228 229 | vid229 230 | vid230 231 | vid231 232 | vid232 233 | vid233 234 | vid234 235 | vid235 236 | vid236 237 | vid237 238 | vid238 239 | vid239 240 | vid240 241 | vid241 242 | vid242 243 | vid243 244 | vid244 245 | vid245 246 | vid246 247 | vid247 248 | vid248 249 | vid249 250 | vid250 251 | vid251 252 | vid252 253 | vid253 254 | vid254 255 | vid255 256 | vid256 257 | vid257 258 | vid258 259 | vid259 260 | vid260 261 | vid261 262 | vid262 263 | vid263 264 | vid264 265 | vid265 266 | vid266 267 | vid267 268 | vid268 269 | vid269 270 | vid270 271 | vid271 272 | vid272 273 | vid273 274 | vid274 275 | vid275 276 | vid276 277 | vid277 278 | vid278 279 | vid279 280 | vid280 281 | vid281 282 | vid282 283 | vid283 284 | vid284 285 | vid285 286 | vid286 287 | vid287 288 | vid288 289 | vid289 290 | vid290 291 | vid291 292 | vid292 293 | vid293 294 | vid294 295 | vid295 296 | vid296 297 | vid297 298 | vid298 299 | vid299 300 | vid300 301 | vid301 302 | vid302 303 | vid303 304 | vid304 305 | vid305 306 | vid306 307 | vid307 308 | vid308 309 | vid309 310 | vid310 311 | vid311 312 | vid312 313 | vid313 314 | vid314 315 | vid315 316 | vid316 317 | vid317 318 | vid318 319 | vid319 320 | vid320 321 | vid321 322 | vid322 323 | vid323 324 | vid324 325 | vid325 326 | vid326 327 | vid327 328 | vid328 329 | vid329 330 | vid330 331 | vid331 332 | vid332 333 | vid333 334 | vid334 335 | vid335 336 | vid336 337 | vid337 338 | vid338 339 | vid339 340 | vid340 341 | vid341 342 | vid342 343 | vid343 344 | vid344 345 | vid345 346 | vid346 347 | vid347 348 | vid348 349 | vid349 350 | vid350 351 | vid351 352 | vid352 353 | vid353 354 | vid354 355 | vid355 356 | vid356 357 | vid357 358 | vid358 359 | vid359 360 | vid360 361 | vid361 362 | vid362 363 | vid363 364 | vid364 365 | vid365 366 | vid366 367 | vid367 368 | vid368 369 | vid369 370 | vid370 371 | vid371 372 | vid372 373 | vid373 374 | vid374 375 | vid375 376 | vid376 377 | vid377 378 | vid378 379 | vid379 380 | vid380 381 | vid381 382 | vid382 383 | vid383 384 | vid384 385 | vid385 386 | vid386 387 | vid387 388 | vid388 389 | vid389 390 | vid390 391 | vid391 392 | vid392 393 | vid393 394 | vid394 395 | vid395 396 | vid396 397 | vid397 398 | vid398 399 | vid399 400 | vid400 401 | vid401 402 | vid402 403 | vid403 404 | vid404 405 | vid405 406 | vid406 407 | vid407 408 | vid408 409 | vid409 410 | vid410 411 | vid411 412 | vid412 413 | vid413 414 | vid414 415 | vid415 416 | vid416 417 | vid417 418 | vid418 419 | vid419 420 | vid420 421 | vid421 422 | vid422 423 | vid423 424 | vid424 425 | vid425 426 | vid426 427 | vid427 428 | vid428 429 | vid429 430 | vid430 431 | vid431 432 | vid432 433 | vid433 434 | vid434 435 | vid435 436 | vid436 437 | vid437 438 | vid438 439 | vid439 440 | vid440 441 | vid441 442 | vid442 443 | vid443 444 | vid444 445 | vid445 446 | vid446 447 | vid447 448 | vid448 449 | vid449 450 | vid450 451 | vid451 452 | vid452 453 | vid453 454 | vid454 455 | vid455 456 | vid456 457 | vid457 458 | vid458 459 | vid459 460 | vid460 461 | vid461 462 | vid462 463 | vid463 464 | vid464 465 | vid465 466 | vid466 467 | vid467 468 | vid468 469 | vid469 470 | vid470 471 | vid471 472 | vid472 473 | vid473 474 | vid474 475 | vid475 476 | vid476 477 | 
vid477 478 | vid478 479 | vid479 480 | vid480 481 | vid481 482 | vid482 483 | vid483 484 | vid484 485 | vid485 486 | vid486 487 | vid487 488 | vid488 489 | vid489 490 | vid490 491 | vid491 492 | vid492 493 | vid493 494 | vid494 495 | vid495 496 | vid496 497 | vid497 498 | vid498 499 | vid499 500 | vid500 501 | vid501 502 | vid502 503 | vid503 504 | vid504 505 | vid505 506 | vid506 507 | vid507 508 | vid508 509 | vid509 510 | vid510 511 | vid511 512 | vid512 513 | vid513 514 | vid514 515 | vid515 516 | vid516 517 | vid517 518 | vid518 519 | vid519 520 | vid520 521 | vid521 522 | vid522 523 | vid523 524 | vid524 525 | vid525 526 | vid526 527 | vid527 528 | vid528 529 | vid529 530 | vid530 531 | vid531 532 | vid532 533 | vid533 534 | vid534 535 | vid535 536 | vid536 537 | vid537 538 | vid538 539 | vid539 540 | vid540 541 | vid541 542 | vid542 543 | vid543 544 | vid544 545 | vid545 546 | vid546 547 | vid547 548 | vid548 549 | vid549 550 | vid550 551 | vid551 552 | vid552 553 | vid553 554 | vid554 555 | vid555 556 | vid556 557 | vid557 558 | vid558 559 | vid559 560 | vid560 561 | vid561 562 | vid562 563 | vid563 564 | vid564 565 | vid565 566 | vid566 567 | vid567 568 | vid568 569 | vid569 570 | vid570 571 | vid571 572 | vid572 573 | vid573 574 | vid574 575 | vid575 576 | vid576 577 | vid577 578 | vid578 579 | vid579 580 | vid580 581 | vid581 582 | vid582 583 | vid583 584 | vid584 585 | vid585 586 | vid586 587 | vid587 588 | vid588 589 | vid589 590 | vid590 591 | vid591 592 | vid592 593 | vid593 594 | vid594 595 | vid595 596 | vid596 597 | vid597 598 | vid598 599 | vid599 600 | vid600 601 | vid601 602 | vid602 603 | vid603 604 | vid604 605 | vid605 606 | vid606 607 | vid607 608 | vid608 609 | vid609 610 | vid610 611 | vid611 612 | vid612 613 | vid613 614 | vid614 615 | vid615 616 | vid616 617 | vid617 618 | vid618 619 | vid619 620 | vid620 621 | vid621 622 | vid622 623 | vid623 624 | vid624 625 | vid625 626 | vid626 627 | vid627 628 | vid628 629 | vid629 630 | vid630 631 | vid631 632 | vid632 633 | vid633 634 | vid634 635 | vid635 636 | vid636 637 | vid637 638 | vid638 639 | vid639 640 | vid640 641 | vid641 642 | vid642 643 | vid643 644 | vid644 645 | vid645 646 | vid646 647 | vid647 648 | vid648 649 | vid649 650 | vid650 651 | vid651 652 | vid652 653 | vid653 654 | vid654 655 | vid655 656 | vid656 657 | vid657 658 | vid658 659 | vid659 660 | vid660 661 | vid661 662 | vid662 663 | vid663 664 | vid664 665 | vid665 666 | vid666 667 | vid667 668 | vid668 669 | vid669 670 | vid670 671 | vid671 672 | vid672 673 | vid673 674 | vid674 675 | vid675 676 | vid676 677 | vid677 678 | vid678 679 | vid679 680 | vid680 681 | vid681 682 | vid682 683 | vid683 684 | vid684 685 | vid685 686 | vid686 687 | vid687 688 | vid688 689 | vid689 690 | vid690 691 | vid691 692 | vid692 693 | vid693 694 | vid694 695 | vid695 696 | vid696 697 | vid697 698 | vid698 699 | vid699 700 | vid700 701 | vid701 702 | vid702 703 | vid703 704 | vid704 705 | vid705 706 | vid706 707 | vid707 708 | vid708 709 | vid709 710 | vid710 711 | vid711 712 | vid712 713 | vid713 714 | vid714 715 | vid715 716 | vid716 717 | vid717 718 | vid718 719 | vid719 720 | vid720 721 | vid721 722 | vid722 723 | vid723 724 | vid724 725 | vid725 726 | vid726 727 | vid727 728 | vid728 729 | vid729 730 | vid730 731 | vid731 732 | vid732 733 | vid733 734 | vid734 735 | vid735 736 | vid736 737 | vid737 738 | vid738 739 | vid739 740 | vid740 741 | vid741 742 | vid742 743 | vid743 744 | vid744 745 | vid745 746 | vid746 747 | vid747 748 | vid748 749 | vid749 750 | 
vid750 751 | vid751 752 | vid752 753 | vid753 754 | vid754 755 | vid755 756 | vid756 757 | vid757 758 | vid758 759 | vid759 760 | vid760 761 | vid761 762 | vid762 763 | vid763 764 | vid764 765 | vid765 766 | vid766 767 | vid767 768 | vid768 769 | vid769 770 | vid770 771 | vid771 772 | vid772 773 | vid773 774 | vid774 775 | vid775 776 | vid776 777 | vid777 778 | vid778 779 | vid779 780 | vid780 781 | vid781 782 | vid782 783 | vid783 784 | vid784 785 | vid785 786 | vid786 787 | vid787 788 | vid788 789 | vid789 790 | vid790 791 | vid791 792 | vid792 793 | vid793 794 | vid794 795 | vid795 796 | vid796 797 | vid797 798 | vid798 799 | vid799 800 | vid800 801 | vid801 802 | vid802 803 | vid803 804 | vid804 805 | vid805 806 | vid806 807 | vid807 808 | vid808 809 | vid809 810 | vid810 811 | vid811 812 | vid812 813 | vid813 814 | vid814 815 | vid815 816 | vid816 817 | vid817 818 | vid818 819 | vid819 820 | vid820 821 | vid821 822 | vid822 823 | vid823 824 | vid824 825 | vid825 826 | vid826 827 | vid827 828 | vid828 829 | vid829 830 | vid830 831 | vid831 832 | vid832 833 | vid833 834 | vid834 835 | vid835 836 | vid836 837 | vid837 838 | vid838 839 | vid839 840 | vid840 841 | vid841 842 | vid842 843 | vid843 844 | vid844 845 | vid845 846 | vid846 847 | vid847 848 | vid848 849 | vid849 850 | vid850 851 | vid851 852 | vid852 853 | vid853 854 | vid854 855 | vid855 856 | vid856 857 | vid857 858 | vid858 859 | vid859 860 | vid860 861 | vid861 862 | vid862 863 | vid863 864 | vid864 865 | vid865 866 | vid866 867 | vid867 868 | vid868 869 | vid869 870 | vid870 871 | vid871 872 | vid872 873 | vid873 874 | vid874 875 | vid875 876 | vid876 877 | vid877 878 | vid878 879 | vid879 880 | vid880 881 | vid881 882 | vid882 883 | vid883 884 | vid884 885 | vid885 886 | vid886 887 | vid887 888 | vid888 889 | vid889 890 | vid890 891 | vid891 892 | vid892 893 | vid893 894 | vid894 895 | vid895 896 | vid896 897 | vid897 898 | vid898 899 | vid899 900 | vid900 901 | vid901 902 | vid902 903 | vid903 904 | vid904 905 | vid905 906 | vid906 907 | vid907 908 | vid908 909 | vid909 910 | vid910 911 | vid911 912 | vid912 913 | vid913 914 | vid914 915 | vid915 916 | vid916 917 | vid917 918 | vid918 919 | vid919 920 | vid920 921 | vid921 922 | vid922 923 | vid923 924 | vid924 925 | vid925 926 | vid926 927 | vid927 928 | vid928 929 | vid929 930 | vid930 931 | vid931 932 | vid932 933 | vid933 934 | vid934 935 | vid935 936 | vid936 937 | vid937 938 | vid938 939 | vid939 940 | vid940 941 | vid941 942 | vid942 943 | vid943 944 | vid944 945 | vid945 946 | vid946 947 | vid947 948 | vid948 949 | vid949 950 | vid950 951 | vid951 952 | vid952 953 | vid953 954 | vid954 955 | vid955 956 | vid956 957 | vid957 958 | vid958 959 | vid959 960 | vid960 961 | vid961 962 | vid962 963 | vid963 964 | vid964 965 | vid965 966 | vid966 967 | vid967 968 | vid968 969 | vid969 970 | vid970 971 | vid971 972 | vid972 973 | vid973 974 | vid974 975 | vid975 976 | vid976 977 | vid977 978 | vid978 979 | vid979 980 | vid980 981 | vid981 982 | vid982 983 | vid983 984 | vid984 985 | vid985 986 | vid986 987 | vid987 988 | vid988 989 | vid989 990 | vid990 991 | vid991 992 | vid992 993 | vid993 994 | vid994 995 | vid995 996 | vid996 997 | vid997 998 | vid998 999 | vid999 1000 | vid1000 1001 | vid1001 1002 | vid1002 1003 | vid1003 1004 | vid1004 1005 | vid1005 1006 | vid1006 1007 | vid1007 1008 | vid1008 1009 | vid1009 1010 | vid1010 1011 | vid1011 1012 | vid1012 1013 | vid1013 1014 | vid1014 1015 | vid1015 1016 | vid1016 1017 | vid1017 1018 | vid1018 1019 | vid1019 1020 | 
vid1020 1021 | vid1021 1022 | vid1022 1023 | vid1023 1024 | vid1024 1025 | vid1025 1026 | vid1026 1027 | vid1027 1028 | vid1028 1029 | vid1029 1030 | vid1030 1031 | vid1031 1032 | vid1032 1033 | vid1033 1034 | vid1034 1035 | vid1035 1036 | vid1036 1037 | vid1037 1038 | vid1038 1039 | vid1039 1040 | vid1040 1041 | vid1041 1042 | vid1042 1043 | vid1043 1044 | vid1044 1045 | vid1045 1046 | vid1046 1047 | vid1047 1048 | vid1048 1049 | vid1049 1050 | vid1050 1051 | vid1051 1052 | vid1052 1053 | vid1053 1054 | vid1054 1055 | vid1055 1056 | vid1056 1057 | vid1057 1058 | vid1058 1059 | vid1059 1060 | vid1060 1061 | vid1061 1062 | vid1062 1063 | vid1063 1064 | vid1064 1065 | vid1065 1066 | vid1066 1067 | vid1067 1068 | vid1068 1069 | vid1069 1070 | vid1070 1071 | vid1071 1072 | vid1072 1073 | vid1073 1074 | vid1074 1075 | vid1075 1076 | vid1076 1077 | vid1077 1078 | vid1078 1079 | vid1079 1080 | vid1080 1081 | vid1081 1082 | vid1082 1083 | vid1083 1084 | vid1084 1085 | vid1085 1086 | vid1086 1087 | vid1087 1088 | vid1088 1089 | vid1089 1090 | vid1090 1091 | vid1091 1092 | vid1092 1093 | vid1093 1094 | vid1094 1095 | vid1095 1096 | vid1096 1097 | vid1097 1098 | vid1098 1099 | vid1099 1100 | vid1100 1101 | vid1101 1102 | vid1102 1103 | vid1103 1104 | vid1104 1105 | vid1105 1106 | vid1106 1107 | vid1107 1108 | vid1108 1109 | vid1109 1110 | vid1110 1111 | vid1111 1112 | vid1112 1113 | vid1113 1114 | vid1114 1115 | vid1115 1116 | vid1116 1117 | vid1117 1118 | vid1118 1119 | vid1119 1120 | vid1120 1121 | vid1121 1122 | vid1122 1123 | vid1123 1124 | vid1124 1125 | vid1125 1126 | vid1126 1127 | vid1127 1128 | vid1128 1129 | vid1129 1130 | vid1130 1131 | vid1131 1132 | vid1132 1133 | vid1133 1134 | vid1134 1135 | vid1135 1136 | vid1136 1137 | vid1137 1138 | vid1138 1139 | vid1139 1140 | vid1140 1141 | vid1141 1142 | vid1142 1143 | vid1143 1144 | vid1144 1145 | vid1145 1146 | vid1146 1147 | vid1147 1148 | vid1148 1149 | vid1149 1150 | vid1150 1151 | vid1151 1152 | vid1152 1153 | vid1153 1154 | vid1154 1155 | vid1155 1156 | vid1156 1157 | vid1157 1158 | vid1158 1159 | vid1159 1160 | vid1160 1161 | vid1161 1162 | vid1162 1163 | vid1163 1164 | vid1164 1165 | vid1165 1166 | vid1166 1167 | vid1167 1168 | vid1168 1169 | vid1169 1170 | vid1170 1171 | vid1171 1172 | vid1172 1173 | vid1173 1174 | vid1174 1175 | vid1175 1176 | vid1176 1177 | vid1177 1178 | vid1178 1179 | vid1179 1180 | vid1180 1181 | vid1181 1182 | vid1182 1183 | vid1183 1184 | vid1184 1185 | vid1185 1186 | vid1186 1187 | vid1187 1188 | vid1188 1189 | vid1189 1190 | vid1190 1191 | vid1191 1192 | vid1192 1193 | vid1193 1194 | vid1194 1195 | vid1195 1196 | vid1196 1197 | vid1197 1198 | vid1198 1199 | vid1199 1200 | vid1200 -------------------------------------------------------------------------------- /modules/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import os 24 | import sys 25 | import logging 26 | 27 | from .file_utils import cached_path 28 | 29 | logger = logging.getLogger(__name__) 30 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 31 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 32 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 33 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 34 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 35 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 36 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 37 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 38 | } 39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 40 | 'bert-base-uncased': 512, 41 | 'bert-large-uncased': 512, 42 | 'bert-base-cased': 512, 43 | 'bert-large-cased': 512, 44 | 'bert-base-multilingual-uncased': 512, 45 | 'bert-base-multilingual-cased': 512, 46 | 'bert-base-chinese': 512, 47 | } 48 | VOCAB_NAME = 'vocab.txt' 49 | 50 | 51 | def load_vocab(vocab_file): 52 | """Loads a vocabulary file into a dictionary.""" 53 | vocab = collections.OrderedDict() 54 | index = 0 55 | with open(vocab_file, "r", encoding="utf-8") as reader: 56 | while True: 57 | token = reader.readline() 58 | if not token: 59 | break 60 | token = token.strip() 61 | vocab[token] = index 62 | index += 1 63 | return vocab 64 | 65 | 66 | def whitespace_tokenize(text): 67 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 68 | text = text.strip() 69 | if not text: 70 | return [] 71 | tokens = text.split() 72 | return tokens 73 | 74 | 75 | class BertTokenizer(object): 76 | """Runs end-to-end tokenization: punctuation splitting""" 77 | 78 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, never_split=("[UNK]", "[SEP]", "[MASK]", "[CLS]")): 79 | if not os.path.isfile(vocab_file): 80 | raise ValueError( 81 | "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " 82 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 83 | self.vocab = load_vocab(vocab_file) 84 | self.ids_to_tokens = collections.OrderedDict( 85 | [(ids, tok) for tok, ids in self.vocab.items()]) 86 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, never_split=never_split) 87 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 88 | self.max_len = max_len if max_len is not None else int(1e12) 89 | 90 | def tokenize(self, text): 91 | split_tokens = [] 92 | for token in self.basic_tokenizer.tokenize(text): 93 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 94 | split_tokens.append(sub_token) 95 | return split_tokens 96 | 97 | def convert_tokens_to_ids(self, tokens): 98 | """Converts a sequence of tokens into ids using the vocab.""" 99 | ids = [] 100 | for token in tokens: 101 | if token not in self.vocab: 102 | ids.append(self.vocab["[UNK]"]) 103 | logger.error("Cannot find token '{}' in vocab. Using [UNK] insetad".format(token)) 104 | else: 105 | ids.append(self.vocab[token]) 106 | if len(ids) > self.max_len: 107 | raise ValueError( 108 | "Token indices sequence length is longer than the specified maximum " 109 | " sequence length for this BERT model ({} > {}). Running this" 110 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) 111 | ) 112 | return ids 113 | 114 | def convert_ids_to_tokens(self, ids): 115 | """Converts a sequence of ids in tokens using the vocab.""" 116 | tokens = [] 117 | for i in ids: 118 | tokens.append(self.ids_to_tokens[i]) 119 | return tokens 120 | 121 | @classmethod 122 | def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs): 123 | """ 124 | Instantiate a PreTrainedBertModel from a pre-trained model file. 125 | Download and cache the pre-trained model file if needed. 126 | """ 127 | vocab_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name) 128 | if os.path.exists(vocab_file) is False: 129 | if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP: 130 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name] 131 | else: 132 | vocab_file = pretrained_model_name 133 | if os.path.isdir(vocab_file): 134 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 135 | # redirect to the cache, if necessary 136 | print(vocab_file) 137 | try: 138 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 139 | except FileNotFoundError: 140 | logger.error( 141 | "Model name '{}' was not found. " 142 | "We assumed '{}' was a path or url but couldn't find any file " 143 | "associated to this path or url.".format( 144 | pretrained_model_name, 145 | vocab_file)) 146 | return None 147 | if resolved_vocab_file == vocab_file: 148 | logger.info("loading vocabulary file {}".format(vocab_file)) 149 | else: 150 | logger.info("loading vocabulary file {} from cache at {}".format( 151 | vocab_file, resolved_vocab_file)) 152 | if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 153 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 154 | # than the number of positional embeddings 155 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name] 156 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 157 | kwargs['never_split'] = ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]") 158 | 159 | # Instantiate tokenizer. 
160 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 161 | 162 | return tokenizer 163 | 164 | def add_tokens(self, new_tokens, model): 165 | """ 166 | Add a list of new tokens to the tokenizer class. If the new tokens are not in the 167 | vocabulary, they are added to it with indices starting from length of the current vocabulary. 168 | Args: 169 | new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them). 170 | Returns: 171 | Number of tokens added to the vocabulary. 172 | Examples:: 173 | # Let's see how to increase the vocabulary of Bert model and tokenizer 174 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 175 | model = BertModel.from_pretrained('bert-base-uncased') 176 | num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2']) 177 | print('We have added', num_added_toks, 'tokens') 178 | model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer. 179 | """ 180 | 181 | to_add_tokens = [] 182 | for token in new_tokens: 183 | assert isinstance(token, str) 184 | to_add_tokens.append(token) 185 | # logger.info("Adding %s to the vocabulary", token) 186 | 187 | vocab = collections.OrderedDict() 188 | for token in self.vocab.keys(): 189 | vocab[token] = self.vocab[token] 190 | for token in to_add_tokens: 191 | vocab[token] = len(vocab) 192 | self.vocab = self.wordpiece_tokenizer.vocab = vocab 193 | self.ids_to_tokens = collections.OrderedDict( 194 | [(ids, tok) for tok, ids in self.vocab.items()]) 195 | 196 | model.resize_token_embeddings(new_num_tokens=len(vocab)) 197 | 198 | class BasicTokenizer(object): 199 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 200 | 201 | def __init__(self, do_lower_case=True, never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 202 | """Constructs a BasicTokenizer. 203 | 204 | Args: 205 | do_lower_case: Whether to lower case the input. 206 | """ 207 | self.do_lower_case = do_lower_case 208 | self.never_split = never_split 209 | 210 | def tokenize(self, text): 211 | """Tokenizes a piece of text.""" 212 | text = self._clean_text(text) 213 | # This was added on November 1st, 2018 for the multilingual and Chinese 214 | # models. This is also applied to the English models now, but it doesn't 215 | # matter since the English models were not trained on any Chinese data 216 | # and generally don't have any Chinese data in them (there are Chinese 217 | # characters in the vocabulary because Wikipedia does have some Chinese 218 | # words in the English Wikipedia.). 
219 | text = self._tokenize_chinese_chars(text) 220 | orig_tokens = whitespace_tokenize(text) 221 | split_tokens = [] 222 | for token in orig_tokens: 223 | if self.do_lower_case and token not in self.never_split: 224 | token = token.lower() 225 | token = self._run_strip_accents(token) 226 | split_tokens.extend(self._run_split_on_punc(token)) 227 | 228 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 229 | return output_tokens 230 | 231 | def _run_strip_accents(self, text): 232 | """Strips accents from a piece of text.""" 233 | text = unicodedata.normalize("NFD", text) 234 | output = [] 235 | for char in text: 236 | cat = unicodedata.category(char) 237 | if cat == "Mn": 238 | continue 239 | output.append(char) 240 | return "".join(output) 241 | 242 | def _run_split_on_punc(self, text): 243 | """Splits punctuation on a piece of text.""" 244 | if text in self.never_split: 245 | return [text] 246 | chars = list(text) 247 | i = 0 248 | start_new_word = True 249 | output = [] 250 | while i < len(chars): 251 | char = chars[i] 252 | if _is_punctuation(char): 253 | output.append([char]) 254 | start_new_word = True 255 | else: 256 | if start_new_word: 257 | output.append([]) 258 | start_new_word = False 259 | output[-1].append(char) 260 | i += 1 261 | 262 | return ["".join(x) for x in output] 263 | 264 | def _tokenize_chinese_chars(self, text): 265 | """Adds whitespace around any CJK character.""" 266 | output = [] 267 | for char in text: 268 | cp = ord(char) 269 | if self._is_chinese_char(cp): 270 | output.append(" ") 271 | output.append(char) 272 | output.append(" ") 273 | else: 274 | output.append(char) 275 | return "".join(output) 276 | 277 | def _is_chinese_char(self, cp): 278 | """Checks whether CP is the codepoint of a CJK character.""" 279 | # This defines a "chinese character" as anything in the CJK Unicode block: 280 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 281 | # 282 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 283 | # despite its name. The modern Korean Hangul alphabet is a different block, 284 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 285 | # space-separated words, so they are not treated specially and handled 286 | # like the all of the other languages. 287 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 288 | (cp >= 0x3400 and cp <= 0x4DBF) or # 289 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 290 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 291 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 292 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 293 | (cp >= 0xF900 and cp <= 0xFAFF) or # 294 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 295 | return True 296 | 297 | return False 298 | 299 | def _clean_text(self, text): 300 | """Performs invalid character removal and whitespace cleanup on text.""" 301 | output = [] 302 | for char in text: 303 | cp = ord(char) 304 | if cp == 0 or cp == 0xfffd or _is_control(char): 305 | continue 306 | if _is_whitespace(char): 307 | output.append(" ") 308 | else: 309 | output.append(char) 310 | return "".join(output) 311 | 312 | class WordpieceTokenizer(object): 313 | """Runs WordPiece tokenization.""" 314 | 315 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 316 | self.vocab = vocab 317 | self.unk_token = unk_token 318 | self.max_input_chars_per_word = max_input_chars_per_word 319 | 320 | def tokenize(self, text): 321 | """Tokenizes a piece of text into its word pieces. 
322 | 323 | This uses a greedy longest-match-first algorithm to perform tokenization 324 | using the given vocabulary. 325 | 326 | For example: 327 | input = "unaffable" 328 | output = ["un", "##aff", "##able"] 329 | 330 | Args: 331 | text: A single token or whitespace separated tokens. This should have 332 | already been passed through `BasicTokenizer`. 333 | 334 | Returns: 335 | A list of wordpiece tokens. 336 | """ 337 | 338 | output_tokens = [] 339 | for token in whitespace_tokenize(text): 340 | chars = list(token) 341 | if len(chars) > self.max_input_chars_per_word: 342 | output_tokens.append(self.unk_token) 343 | continue 344 | 345 | is_bad = False 346 | start = 0 347 | sub_tokens = [] 348 | while start < len(chars): 349 | end = len(chars) 350 | cur_substr = None 351 | while start < end: 352 | substr = "".join(chars[start:end]) 353 | if start > 0: 354 | substr = "##" + substr 355 | if substr in self.vocab: 356 | cur_substr = substr 357 | break 358 | end -= 1 359 | if cur_substr is None: 360 | is_bad = True 361 | break 362 | sub_tokens.append(cur_substr) 363 | start = end 364 | 365 | if is_bad: 366 | output_tokens.append(self.unk_token) 367 | else: 368 | output_tokens.extend(sub_tokens) 369 | return output_tokens 370 | 371 | def _is_whitespace(char): 372 | """Checks whether `chars` is a whitespace character.""" 373 | # \t, \n, and \r are technically contorl characters but we treat them 374 | # as whitespace since they are generally considered as such. 375 | if char == " " or char == "\t" or char == "\n" or char == "\r": 376 | return True 377 | cat = unicodedata.category(char) 378 | if cat == "Zs": 379 | return True 380 | return False 381 | 382 | 383 | def _is_control(char): 384 | """Checks whether `chars` is a control character.""" 385 | # These are technically control characters but we count them as whitespace 386 | # characters. 387 | if char == "\t" or char == "\n" or char == "\r": 388 | return False 389 | cat = unicodedata.category(char) 390 | if cat.startswith("C"): 391 | return True 392 | return False 393 | 394 | 395 | def _is_punctuation(char): 396 | """Checks whether `chars` is a punctuation character.""" 397 | cp = ord(char) 398 | # We treat all non-letter/number ASCII as punctuation. 399 | # Characters such as "^", "$", and "`" are not in the Unicode 400 | # Punctuation class but we treat them as punctuation anyways, for 401 | # consistency. 402 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 403 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 404 | return True 405 | cat = unicodedata.category(char) 406 | if cat.startswith("P"): 407 | return True 408 | return False 409 | -------------------------------------------------------------------------------- /modules/module_decoder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
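Note on modules/tokenization.py above: it reproduces the original BERT WordPiece pipeline — BasicTokenizer cleans the text and splits on whitespace and punctuation, WordpieceTokenizer then runs greedy longest-match-first lookup against the vocabulary, and BertTokenizer.from_pretrained resolves either a local vocab.txt or one of the hosted vocabulary files via cached_path. A minimal usage sketch, assuming the repo root is on PYTHONPATH, the bert-base-uncased vocabulary is downloadable or already cached, and using a made-up caption string (not taken from the repo):

    # Hedged sketch of how caption text could be tokenized with this class.
    from modules.tokenization import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

    caption = "a man is playing a guitar"                 # illustrative MSVD-style caption
    tokens = ["[CLS]"] + tokenizer.tokenize(caption) + ["[SEP]"]
    ids = tokenizer.convert_tokens_to_ids(tokens)         # vocabulary indices
    back = tokenizer.convert_ids_to_tokens(ids)           # round-trips to the same tokens

The listing of modules/module_decoder.py continues below.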
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model. """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import copy 24 | import json 25 | import math 26 | import logging 27 | import tarfile 28 | import tempfile 29 | import shutil 30 | import numpy as np 31 | 32 | import torch 33 | from torch import nn 34 | from .file_utils import cached_path 35 | from .until_config import PretrainedConfig 36 | from .until_module import PreTrainedModel, LayerNorm, ACT2FN 37 | 38 | logger = logging.getLogger(__name__) 39 | 40 | PRETRAINED_MODEL_ARCHIVE_MAP = {} 41 | CONFIG_NAME = 'decoder_config.json' 42 | WEIGHTS_NAME = 'decoder_pytorch_model.bin' 43 | 44 | 45 | class DecoderConfig(PretrainedConfig): 46 | """Configuration class to store the configuration of a `DecoderModel`. 47 | """ 48 | pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP 49 | config_name = CONFIG_NAME 50 | weights_name = WEIGHTS_NAME 51 | def __init__(self, 52 | vocab_size_or_config_json_file, 53 | hidden_size=768, 54 | num_hidden_layers=12, 55 | num_attention_heads=12, 56 | intermediate_size=3072, 57 | hidden_act="gelu", 58 | hidden_dropout_prob=0.1, 59 | attention_probs_dropout_prob=0.1, 60 | type_vocab_size=2, 61 | initializer_range=0.02, 62 | max_target_embeddings=128, 63 | num_decoder_layers=1): 64 | """Constructs DecoderConfig. 65 | 66 | Args: 67 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `DecoderModel`. 68 | hidden_size: Size of the encoder layers and the pooler layer. 69 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 70 | num_attention_heads: Number of attention heads for each attention layer in 71 | the Transformer encoder. 72 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 73 | layer in the Transformer encoder. 74 | hidden_act: The non-linear activation function (function or string) in the 75 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 76 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 77 | layers in the embeddings, encoder, and pooler. 78 | attention_probs_dropout_prob: The dropout ratio for the attention 79 | probabilities. 80 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 81 | `DecoderModel`. 82 | initializer_range: The sttdev of the truncated_normal_initializer for 83 | initializing all weight matrices. 84 | max_target_embeddings: The maximum sequence length that this model might 85 | ever be used with. Typically set this to something large just in case 86 | (e.g., 512 or 1024 or 2048). 
87 | num_decoder_layers: 88 | """ 89 | if isinstance(vocab_size_or_config_json_file, str): 90 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 91 | json_config = json.loads(reader.read()) 92 | for key, value in json_config.items(): 93 | self.__dict__[key] = value 94 | elif isinstance(vocab_size_or_config_json_file, int): 95 | self.vocab_size = vocab_size_or_config_json_file 96 | self.hidden_size = hidden_size 97 | self.num_hidden_layers = num_hidden_layers 98 | self.num_attention_heads = num_attention_heads 99 | self.hidden_act = hidden_act 100 | self.intermediate_size = intermediate_size 101 | self.hidden_dropout_prob = hidden_dropout_prob 102 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 103 | self.type_vocab_size = type_vocab_size 104 | self.initializer_range = initializer_range 105 | self.max_target_embeddings = max_target_embeddings 106 | self.num_decoder_layers = num_decoder_layers 107 | else: 108 | raise ValueError("First argument must be either a vocabulary size (int)" 109 | "or the path to a pretrained model config file (str)") 110 | 111 | 112 | class BertSelfOutput(nn.Module): 113 | def __init__(self, config): 114 | super(BertSelfOutput, self).__init__() 115 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 116 | self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12) 117 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 118 | 119 | def forward(self, hidden_states, input_tensor): 120 | hidden_states = self.dense(hidden_states) 121 | hidden_states = self.dropout(hidden_states) 122 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 123 | return hidden_states 124 | 125 | class BertIntermediate(nn.Module): 126 | def __init__(self, config): 127 | super(BertIntermediate, self).__init__() 128 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 129 | self.intermediate_act_fn = ACT2FN[config.hidden_act] \ 130 | if isinstance(config.hidden_act, str) else config.hidden_act 131 | 132 | def forward(self, hidden_states): 133 | hidden_states = self.dense(hidden_states) 134 | hidden_states = self.intermediate_act_fn(hidden_states) 135 | return hidden_states 136 | 137 | 138 | class BertOutput(nn.Module): 139 | def __init__(self, config): 140 | super(BertOutput, self).__init__() 141 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 142 | self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12) 143 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 144 | 145 | def forward(self, hidden_states, input_tensor): 146 | hidden_states = self.dense(hidden_states) 147 | hidden_states = self.dropout(hidden_states) 148 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 149 | return hidden_states 150 | 151 | 152 | class BertPredictionHeadTransform(nn.Module): 153 | def __init__(self, config): 154 | super(BertPredictionHeadTransform, self).__init__() 155 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 156 | self.transform_act_fn = ACT2FN[config.hidden_act] \ 157 | if isinstance(config.hidden_act, str) else config.hidden_act 158 | self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12) 159 | 160 | def forward(self, hidden_states): 161 | hidden_states = self.dense(hidden_states) 162 | hidden_states = self.transform_act_fn(hidden_states) 163 | hidden_states = self.LayerNorm(hidden_states) 164 | return hidden_states 165 | 166 | 167 | class BertLMPredictionHead(nn.Module): 168 | def __init__(self, config, decoder_model_embedding_weights): 169 | 
super(BertLMPredictionHead, self).__init__() 170 | self.transform = BertPredictionHeadTransform(config) 171 | 172 | # The output weights are the same as the input embeddings, but there is 173 | # an output-only bias for each token. 174 | self.decoder = nn.Linear(decoder_model_embedding_weights.size(1), 175 | decoder_model_embedding_weights.size(0), 176 | bias=False) 177 | self.decoder.weight = decoder_model_embedding_weights 178 | self.bias = nn.Parameter(torch.zeros(decoder_model_embedding_weights.size(0))) 179 | 180 | def forward(self, hidden_states): 181 | hidden_states = self.transform(hidden_states) 182 | hidden_states = self.decoder(hidden_states) + self.bias 183 | return hidden_states 184 | 185 | 186 | class BertOnlyMLMHead(nn.Module): 187 | def __init__(self, config, decoder_model_embedding_weights): 188 | super(BertOnlyMLMHead, self).__init__() 189 | self.predictions = BertLMPredictionHead(config, decoder_model_embedding_weights) 190 | 191 | def forward(self, sequence_output): 192 | prediction_scores = self.predictions(sequence_output) 193 | return prediction_scores 194 | 195 | class MultiHeadAttention(nn.Module): 196 | ''' Multi-Head Attention module ''' 197 | 198 | def __init__(self, config): 199 | super(MultiHeadAttention, self).__init__() 200 | 201 | if config.hidden_size % config.num_attention_heads != 0: 202 | raise ValueError( 203 | "The hidden size (%d) is not a multiple of the number of attention " 204 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 205 | self.num_attention_heads = config.num_attention_heads 206 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 207 | self.all_head_size = self.num_attention_heads * self.attention_head_size 208 | 209 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 210 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 211 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 212 | 213 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 214 | 215 | def transpose_for_scores(self, x): 216 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 217 | x = x.view(*new_x_shape) 218 | return x.permute(0, 2, 1, 3) 219 | 220 | def forward(self, q, k, v, attention_mask): 221 | mixed_query_layer = self.query(q) 222 | mixed_key_layer = self.key(k) 223 | mixed_value_layer = self.value(v) 224 | 225 | query_layer = self.transpose_for_scores(mixed_query_layer) 226 | key_layer = self.transpose_for_scores(mixed_key_layer) 227 | value_layer = self.transpose_for_scores(mixed_value_layer) 228 | 229 | # Take the dot product between "query" and "key" to get the raw attention scores. 230 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 231 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 232 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 233 | attention_scores = attention_scores + attention_mask 234 | 235 | # Normalize the attention scores to probabilities. 236 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 237 | 238 | # This is actually dropping out entire tokens to attend to, which might 239 | # seem a bit unusual, but is taken from the original Transformer paper. 
240 | attention_probs = self.dropout(attention_probs) 241 | 242 | context_layer = torch.matmul(attention_probs, value_layer) 243 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 244 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 245 | context_layer = context_layer.view(*new_context_layer_shape) 246 | 247 | return context_layer, attention_scores 248 | 249 | class PositionwiseFeedForward(nn.Module): 250 | ''' A two-feed-forward-layer module ''' 251 | 252 | def __init__(self, d_in, d_hid, dropout=0.1): 253 | super().__init__() 254 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise 255 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise 256 | self.layer_norm = nn.LayerNorm(d_in) 257 | self.dropout = nn.Dropout(dropout) 258 | 259 | def forward(self, x): 260 | residual = x 261 | output = x.transpose(1, 2) 262 | output = self.w_2(ACT2FN["gelu"](self.w_1(output))) 263 | output = output.transpose(1, 2) 264 | output = self.dropout(output) 265 | output = self.layer_norm(output + residual) 266 | return output 267 | 268 | class DecoderAttention(nn.Module): 269 | def __init__(self, config): 270 | super(DecoderAttention, self).__init__() 271 | self.att = MultiHeadAttention(config) 272 | self.output = BertSelfOutput(config) 273 | 274 | def forward(self, q, k, v, attention_mask): 275 | att_output, attention_probs = self.att(q, k, v, attention_mask) 276 | attention_output = self.output(att_output, q) 277 | return attention_output, attention_probs 278 | 279 | class DecoderLayer(nn.Module): 280 | def __init__(self, config): 281 | super(DecoderLayer, self).__init__() 282 | self.slf_attn = DecoderAttention(config) 283 | self.enc_attn = DecoderAttention(config) 284 | self.intermediate = BertIntermediate(config) 285 | self.output = BertOutput(config) 286 | 287 | def forward(self, dec_input, enc_output, slf_attn_mask=None, dec_enc_attn_mask=None): 288 | slf_output, _ = self.slf_attn(dec_input, dec_input, dec_input, slf_attn_mask) 289 | dec_output, dec_att_scores = self.enc_attn(slf_output, enc_output, enc_output, dec_enc_attn_mask) 290 | intermediate_output = self.intermediate(dec_output) 291 | dec_output = self.output(intermediate_output, dec_output) 292 | return dec_output, dec_att_scores 293 | 294 | class DecoderEmbeddings(nn.Module): 295 | """Construct the embeddings from word, position and token_type embeddings. 
296 | """ 297 | def __init__(self, config, decoder_word_embeddings_weight, decoder_position_embeddings_weight): 298 | super(DecoderEmbeddings, self).__init__() 299 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 300 | self.position_embeddings = nn.Embedding(config.max_target_embeddings, config.hidden_size) 301 | self.word_embeddings.weight = decoder_word_embeddings_weight 302 | self.position_embeddings.weight = decoder_position_embeddings_weight 303 | 304 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 305 | # any TensorFlow checkpoint file 306 | self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12) 307 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 308 | 309 | def forward(self, input_ids): 310 | seq_length = input_ids.size(1) 311 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 312 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 313 | 314 | words_embeddings = self.word_embeddings(input_ids) 315 | position_embeddings = self.position_embeddings(position_ids) 316 | 317 | embeddings = words_embeddings + position_embeddings 318 | embeddings = self.LayerNorm(embeddings) 319 | embeddings = self.dropout(embeddings) 320 | return embeddings 321 | 322 | class Decoder(nn.Module): 323 | def __init__(self, config): 324 | super(Decoder, self).__init__() 325 | layer = DecoderLayer(config) 326 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_decoder_layers)]) 327 | 328 | def forward(self, hidden_states, encoder_outs, self_attn_mask, attention_mask, output_all_encoded_layers=False): 329 | dec_att_scores = None 330 | all_encoder_layers = [] 331 | all_dec_att_probs = [] 332 | for layer_module in self.layer: 333 | hidden_states, dec_att_scores = layer_module(hidden_states, encoder_outs, self_attn_mask, attention_mask) 334 | if output_all_encoded_layers: 335 | all_encoder_layers.append(hidden_states) 336 | all_dec_att_probs.append(dec_att_scores) 337 | if not output_all_encoded_layers: 338 | all_encoder_layers.append(hidden_states) 339 | all_dec_att_probs.append(dec_att_scores) 340 | return all_encoder_layers, all_dec_att_probs 341 | 342 | class DecoderClassifier(nn.Module): 343 | def __init__(self, config, embedding_weights): 344 | super(DecoderClassifier, self).__init__() 345 | self.cls = BertOnlyMLMHead(config, embedding_weights) 346 | 347 | def forward(self, hidden_states): 348 | cls_scores = self.cls(hidden_states) 349 | return cls_scores 350 | 351 | class DecoderModel(PreTrainedModel): 352 | 353 | """ 354 | Transformer decoder consisting of *args.decoder_layers* layers. Each layer 355 | is a :class:`TransformerDecoderLayer`. 356 | 357 | Args: 358 | args (argparse.Namespace): parsed command-line arguments 359 | final_norm (bool, optional): apply layer norm to the output of the 360 | final decoder layer (default: True). 
361 | """ 362 | 363 | def __init__(self, config, decoder_word_embeddings_weight, decoder_position_embeddings_weight): 364 | super(DecoderModel, self).__init__(config) 365 | self.config = config 366 | self.max_target_length = config.max_target_embeddings 367 | self.embeddings = DecoderEmbeddings(config, decoder_word_embeddings_weight, decoder_position_embeddings_weight) 368 | self.decoder = Decoder(config) 369 | self.classifier = DecoderClassifier(config, decoder_word_embeddings_weight) 370 | self.apply(self.init_weights) 371 | 372 | def forward(self, input_ids, encoder_outs=None, answer_mask=None, encoder_mask=None): 373 | """ 374 | Args: 375 | input_ids (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for input feeding/teacher forcing 376 | encoder_outs (Tensor, optional): output from the encoder, used for encoder-side attention 377 | 378 | Returns: 379 | tuple: 380 | - the last decoder layer's output of shape `(batch, tgt_len, vocab)` 381 | - the last decoder layer's attention weights of shape `(batch, tgt_len, src_len)` 382 | """ 383 | embedding_output = self.embeddings(input_ids) 384 | 385 | extended_encoder_mask = encoder_mask.unsqueeze(1).unsqueeze(2) # b x 1 x 1 x ls 386 | extended_encoder_mask = extended_encoder_mask.to(dtype=self.dtype) # fp16 compatibility 387 | extended_encoder_mask = (1.0 - extended_encoder_mask) * -10000.0 388 | 389 | extended_answer_mask = answer_mask.unsqueeze(1).unsqueeze(2) 390 | extended_answer_mask = extended_answer_mask.to(dtype=self.dtype) # fp16 compatibility 391 | 392 | sz_b, len_s, _ = embedding_output.size() 393 | subsequent_mask = torch.triu(torch.ones((len_s, len_s), device=embedding_output.device, dtype=embedding_output.dtype), diagonal=1) 394 | self_attn_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1).unsqueeze(1) # b x 1 x ls x ls 395 | slf_attn_mask = ((1.0 - extended_answer_mask) + self_attn_mask).gt(0).to(dtype=self.dtype) 396 | self_attn_mask = slf_attn_mask * -10000.0 397 | 398 | decoded_layers, dec_att_scores = self.decoder(embedding_output, 399 | encoder_outs, 400 | self_attn_mask, 401 | extended_encoder_mask, 402 | ) 403 | sequence_output = decoded_layers[-1] 404 | cls_scores = self.classifier(sequence_output) 405 | 406 | return cls_scores 407 | --------------------------------------------------------------------------------
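DecoderModel.forward above builds two additive attention masks before running the decoder stack: the encoder padding mask becomes a (batch, 1, 1, src_len) bias of -10000.0 on padded positions, and the caption's answer_mask is combined with an upper-triangular "subsequent" mask so that, under teacher forcing, position t can only attend to positions up to t. A small standalone sketch of that causal-mask construction, assuming torch is installed and using an illustrative batch of size 2 with target length 4 (none of the tensors below come from the repo):

    # Standalone sketch of the causal self-attention mask built in DecoderModel.forward.
    # answer_mask is (batch, tgt_len): 1 for real caption tokens, 0 for padding.
    import torch

    batch, tgt_len = 2, 4
    answer_mask = torch.tensor([[1., 1., 1., 0.],
                                [1., 1., 1., 1.]])

    extended_answer_mask = answer_mask.unsqueeze(1).unsqueeze(2)            # (b, 1, 1, t)
    subsequent_mask = torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1)  # future positions
    self_attn_mask = subsequent_mask.unsqueeze(0).expand(batch, -1, -1).unsqueeze(1)

    blocked = ((1.0 - extended_answer_mask) + self_attn_mask).gt(0).float() # (b, 1, t, t)
    additive_mask = blocked * -10000.0    # added to attention scores: 0 where allowed

    print(additive_mask[0, 0])  # row t is 0.0 only for non-padding columns <= t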