├── LICENSE
├── README.md
├── config
│   ├── bert-base-uncased_weight_name.json
│   ├── bert-large-uncased_weight_name.json
│   └── bert_base_6layer_6conect_capture_itp3va.json
├── dataloaders
│   ├── __init__.py
│   ├── classification_dataset_ITP3VA.py
│   ├── data_utils.py
│   ├── pretrain_dataset_ITP3VA.py
│   └── test.py
├── evaluate_unit_v2.py
├── examples
│   ├── README.md
│   └── SCALE
│       ├── eval
│       │   ├── eval_gallery1_1.sh
│       │   ├── eval_gallery1_cls.sh
│       │   ├── eval_gallery1_cls_fg.sh
│       │   ├── extract_features.py
│       │   └── extract_features_cls.py
│       ├── pretrain_task.py
│       ├── run_pretrain_task.sh
│       ├── run_train_cls.sh
│       ├── train_cls.py
│       └── utils_args.py
├── m5.yaml
├── model
│   ├── SCALE.py
│   ├── __init__.py
│   ├── cross-base
│   │   ├── cross_config.json
│   │   └── cross_config_temp.json
│   ├── file_utils.py
│   ├── module_cross.py
│   ├── optimization.py
│   ├── until_config.py
│   ├── until_module.py
│   └── utils.py
├── pytorch_pretrained_bert
│   ├── __init__.py
│   ├── __main__.py
│   ├── convert_gpt2_checkpoint_to_pytorch.py
│   ├── convert_openai_checkpoint_to_pytorch.py
│   ├── convert_tf_checkpoint_to_pytorch.py
│   ├── convert_transfo_xl_checkpoint_to_pytorch.py
│   ├── file_utils.py
│   ├── modeling.py
│   ├── modeling_gpt2.py
│   ├── modeling_openai.py
│   ├── modeling_transfo_xl.py
│   ├── modeling_transfo_xl_utilities.py
│   ├── optimization.py
│   ├── optimization_openai.py
│   ├── tokenization.py
│   ├── tokenization_gpt2.py
│   ├── tokenization_openai.py
│   └── tokenization_transfo_xl.py
├── retrieval_unit_id_list_v2.py
└── tools
    ├── VideoFeatureExtractor
    │   ├── .gitignore
    │   ├── LICENSE
    │   ├── README.md
    │   ├── convert_video_feature_to_pickle.py
    │   ├── demo
    │   │   └── run_extract_subset.sh
    │   ├── extract.py
    │   ├── input.csv
    │   ├── input.pickle
    │   ├── model.py
    │   ├── preprocess_generate_csv.py
    │   ├── preprocess_generate_csv2.py
    │   ├── preprocess_generate_csv3.py
    │   ├── preprocessing.py
    │   ├── random_sequence_shuffler.py
    │   ├── run_gen_csv2.sh
    │   ├── video_loader.py
    │   └── videocnn
    │       └── models
    │           ├── resnext.py
    │           └── s3dg.py
    ├── audio_process
    │   ├── extract_audio_feature.py
    │   ├── get_audio_from_video_v4.py
    │   └── save_audio_feature.py
    └── bp_feature
        ├── convert
        │   ├── convert_gallery_c.py
        │   ├── convert_gallery_fg.py
        │   ├── convert_query_c.py
        │   ├── convert_query_fg.py
        │   ├── convert_subset_test.py
        │   ├── convert_subset_train.py
        │   ├── convert_subset_v4_test.py
        │   ├── convert_subset_v4_train.py
        │   ├── convert_train_all.py
        │   └── get_all_tsv_filename.py
        └── extract
            ├── generate_subset_tsv.py
            ├── run_subset_v3_1.sh
            └── utils.py
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 XiaoDong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SCALE_code 2 | M5Product: Self-harmonized Contrastive Learning for E-commercial Multi-modal Pretraining CVPR 2022 3 | -------------------------------------------------------------------------------- /config/bert-base-uncased_weight_name.json: -------------------------------------------------------------------------------- 1 | ["embeddings.word_embeddings.weight", "embeddings.position_embeddings.weight", "embeddings.token_type_embeddings.weight", "embeddings.LayerNorm.weight", "embeddings.LayerNorm.bias", "encoder.layer.0.attention.self.query.weight", "encoder.layer.0.attention.self.query.bias", "encoder.layer.0.attention.self.key.weight", "encoder.layer.0.attention.self.key.bias", "encoder.layer.0.attention.self.value.weight", "encoder.layer.0.attention.self.value.bias", "encoder.layer.0.attention.output.dense.weight", "encoder.layer.0.attention.output.dense.bias", "encoder.layer.0.attention.output.LayerNorm.weight", "encoder.layer.0.attention.output.LayerNorm.bias", "encoder.layer.0.intermediate.dense.weight", "encoder.layer.0.intermediate.dense.bias", "encoder.layer.0.output.dense.weight", "encoder.layer.0.output.dense.bias", "encoder.layer.0.output.LayerNorm.weight", "encoder.layer.0.output.LayerNorm.bias", "encoder.layer.1.attention.self.query.weight", "encoder.layer.1.attention.self.query.bias", "encoder.layer.1.attention.self.key.weight", "encoder.layer.1.attention.self.key.bias", "encoder.layer.1.attention.self.value.weight", "encoder.layer.1.attention.self.value.bias", "encoder.layer.1.attention.output.dense.weight", "encoder.layer.1.attention.output.dense.bias", "encoder.layer.1.attention.output.LayerNorm.weight", "encoder.layer.1.attention.output.LayerNorm.bias", "encoder.layer.1.intermediate.dense.weight", "encoder.layer.1.intermediate.dense.bias", "encoder.layer.1.output.dense.weight", "encoder.layer.1.output.dense.bias", "encoder.layer.1.output.LayerNorm.weight", "encoder.layer.1.output.LayerNorm.bias", "encoder.layer.2.attention.self.query.weight", "encoder.layer.2.attention.self.query.bias", "encoder.layer.2.attention.self.key.weight", "encoder.layer.2.attention.self.key.bias", "encoder.layer.2.attention.self.value.weight", "encoder.layer.2.attention.self.value.bias", "encoder.layer.2.attention.output.dense.weight", "encoder.layer.2.attention.output.dense.bias", "encoder.layer.2.attention.output.LayerNorm.weight", "encoder.layer.2.attention.output.LayerNorm.bias", "encoder.layer.2.intermediate.dense.weight", "encoder.layer.2.intermediate.dense.bias", "encoder.layer.2.output.dense.weight", "encoder.layer.2.output.dense.bias", "encoder.layer.2.output.LayerNorm.weight", "encoder.layer.2.output.LayerNorm.bias", "encoder.layer.3.attention.self.query.weight", "encoder.layer.3.attention.self.query.bias", "encoder.layer.3.attention.self.key.weight", "encoder.layer.3.attention.self.key.bias", "encoder.layer.3.attention.self.value.weight", "encoder.layer.3.attention.self.value.bias", "encoder.layer.3.attention.output.dense.weight", "encoder.layer.3.attention.output.dense.bias", "encoder.layer.3.attention.output.LayerNorm.weight", 
"encoder.layer.3.attention.output.LayerNorm.bias", "encoder.layer.3.intermediate.dense.weight", "encoder.layer.3.intermediate.dense.bias", "encoder.layer.3.output.dense.weight", "encoder.layer.3.output.dense.bias", "encoder.layer.3.output.LayerNorm.weight", "encoder.layer.3.output.LayerNorm.bias", "encoder.layer.4.attention.self.query.weight", "encoder.layer.4.attention.self.query.bias", "encoder.layer.4.attention.self.key.weight", "encoder.layer.4.attention.self.key.bias", "encoder.layer.4.attention.self.value.weight", "encoder.layer.4.attention.self.value.bias", "encoder.layer.4.attention.output.dense.weight", "encoder.layer.4.attention.output.dense.bias", "encoder.layer.4.attention.output.LayerNorm.weight", "encoder.layer.4.attention.output.LayerNorm.bias", "encoder.layer.4.intermediate.dense.weight", "encoder.layer.4.intermediate.dense.bias", "encoder.layer.4.output.dense.weight", "encoder.layer.4.output.dense.bias", "encoder.layer.4.output.LayerNorm.weight", "encoder.layer.4.output.LayerNorm.bias", "encoder.layer.5.attention.self.query.weight", "encoder.layer.5.attention.self.query.bias", "encoder.layer.5.attention.self.key.weight", "encoder.layer.5.attention.self.key.bias", "encoder.layer.5.attention.self.value.weight", "encoder.layer.5.attention.self.value.bias", "encoder.layer.5.attention.output.dense.weight", "encoder.layer.5.attention.output.dense.bias", "encoder.layer.5.attention.output.LayerNorm.weight", "encoder.layer.5.attention.output.LayerNorm.bias", "encoder.layer.5.intermediate.dense.weight", "encoder.layer.5.intermediate.dense.bias", "encoder.layer.5.output.dense.weight", "encoder.layer.5.output.dense.bias", "encoder.layer.5.output.LayerNorm.weight", "encoder.layer.5.output.LayerNorm.bias", "encoder.layer.6.attention.self.query.weight", "encoder.layer.6.attention.self.query.bias", "encoder.layer.6.attention.self.key.weight", "encoder.layer.6.attention.self.key.bias", "encoder.layer.6.attention.self.value.weight", "encoder.layer.6.attention.self.value.bias", "encoder.layer.6.attention.output.dense.weight", "encoder.layer.6.attention.output.dense.bias", "encoder.layer.6.attention.output.LayerNorm.weight", "encoder.layer.6.attention.output.LayerNorm.bias", "encoder.layer.6.intermediate.dense.weight", "encoder.layer.6.intermediate.dense.bias", "encoder.layer.6.output.dense.weight", "encoder.layer.6.output.dense.bias", "encoder.layer.6.output.LayerNorm.weight", "encoder.layer.6.output.LayerNorm.bias", "encoder.layer.7.attention.self.query.weight", "encoder.layer.7.attention.self.query.bias", "encoder.layer.7.attention.self.key.weight", "encoder.layer.7.attention.self.key.bias", "encoder.layer.7.attention.self.value.weight", "encoder.layer.7.attention.self.value.bias", "encoder.layer.7.attention.output.dense.weight", "encoder.layer.7.attention.output.dense.bias", "encoder.layer.7.attention.output.LayerNorm.weight", "encoder.layer.7.attention.output.LayerNorm.bias", "encoder.layer.7.intermediate.dense.weight", "encoder.layer.7.intermediate.dense.bias", "encoder.layer.7.output.dense.weight", "encoder.layer.7.output.dense.bias", "encoder.layer.7.output.LayerNorm.weight", "encoder.layer.7.output.LayerNorm.bias", "encoder.layer.8.attention.self.query.weight", "encoder.layer.8.attention.self.query.bias", "encoder.layer.8.attention.self.key.weight", "encoder.layer.8.attention.self.key.bias", "encoder.layer.8.attention.self.value.weight", "encoder.layer.8.attention.self.value.bias", "encoder.layer.8.attention.output.dense.weight", "encoder.layer.8.attention.output.dense.bias", 
"encoder.layer.8.attention.output.LayerNorm.weight", "encoder.layer.8.attention.output.LayerNorm.bias", "encoder.layer.8.intermediate.dense.weight", "encoder.layer.8.intermediate.dense.bias", "encoder.layer.8.output.dense.weight", "encoder.layer.8.output.dense.bias", "encoder.layer.8.output.LayerNorm.weight", "encoder.layer.8.output.LayerNorm.bias", "encoder.layer.9.attention.self.query.weight", "encoder.layer.9.attention.self.query.bias", "encoder.layer.9.attention.self.key.weight", "encoder.layer.9.attention.self.key.bias", "encoder.layer.9.attention.self.value.weight", "encoder.layer.9.attention.self.value.bias", "encoder.layer.9.attention.output.dense.weight", "encoder.layer.9.attention.output.dense.bias", "encoder.layer.9.attention.output.LayerNorm.weight", "encoder.layer.9.attention.output.LayerNorm.bias", "encoder.layer.9.intermediate.dense.weight", "encoder.layer.9.intermediate.dense.bias", "encoder.layer.9.output.dense.weight", "encoder.layer.9.output.dense.bias", "encoder.layer.9.output.LayerNorm.weight", "encoder.layer.9.output.LayerNorm.bias", "encoder.layer.10.attention.self.query.weight", "encoder.layer.10.attention.self.query.bias", "encoder.layer.10.attention.self.key.weight", "encoder.layer.10.attention.self.key.bias", "encoder.layer.10.attention.self.value.weight", "encoder.layer.10.attention.self.value.bias", "encoder.layer.10.attention.output.dense.weight", "encoder.layer.10.attention.output.dense.bias", "encoder.layer.10.attention.output.LayerNorm.weight", "encoder.layer.10.attention.output.LayerNorm.bias", "encoder.layer.10.intermediate.dense.weight", "encoder.layer.10.intermediate.dense.bias", "encoder.layer.10.output.dense.weight", "encoder.layer.10.output.dense.bias", "encoder.layer.10.output.LayerNorm.weight", "encoder.layer.10.output.LayerNorm.bias", "encoder.layer.11.attention.self.query.weight", "encoder.layer.11.attention.self.query.bias", "encoder.layer.11.attention.self.key.weight", "encoder.layer.11.attention.self.key.bias", "encoder.layer.11.attention.self.value.weight", "encoder.layer.11.attention.self.value.bias", "encoder.layer.11.attention.output.dense.weight", "encoder.layer.11.attention.output.dense.bias", "encoder.layer.11.attention.output.LayerNorm.weight", "encoder.layer.11.attention.output.LayerNorm.bias", "encoder.layer.11.intermediate.dense.weight", "encoder.layer.11.intermediate.dense.bias", "encoder.layer.11.output.dense.weight", "encoder.layer.11.output.dense.bias", "encoder.layer.11.output.LayerNorm.weight", "encoder.layer.11.output.LayerNorm.bias"] -------------------------------------------------------------------------------- /config/bert_base_6layer_6conect_capture_itp3va.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "max_video_len": 50, 10 | "max_audio_len": 100, 11 | "num_attention_heads": 12, 12 | "num_hidden_layers": 6, 13 | "type_vocab_size": 2, 14 | "vocab_size": 21128, 15 | "v_feature_size": 2048, 16 | "video_feature_size": 1024, 17 | "audio_feature_size": 751, 18 | "v_target_size": 1601, 19 | "video_target_size": 1024, 20 | "audio_target_size": 751, 21 | "v_hidden_size": 768, 22 | "v_num_hidden_layers":6, 23 | "v_num_attention_heads":8, 24 | "v_intermediate_size":1024, 25 | "bi_hidden_size":1024, 26 | "bi_num_attention_heads":8, 27 | "bi_intermediate_size": 
1024, 28 | "bi_attention_type":1, 29 | "co_num_layers": 6, 30 | "v_attention_probs_dropout_prob":0.1, 31 | "v_hidden_act":"gelu", 32 | "v_hidden_dropout_prob":0.1, 33 | "v_initializer_range":0.02, 34 | "v_biattention_id":[0, 1, 2, 3, 4, 5], 35 | "t_biattention_id":[6, 7, 8, 9, 10, 11], 36 | "pooling_method": "mul", 37 | "num_classes": 1805 38 | } 39 | -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaodongsuper/SCALE_code/227ae0886a0bda598495e60c4624ca894a7418bf/dataloaders/__init__.py -------------------------------------------------------------------------------- /dataloaders/data_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import jsonlines 3 | import pickle 4 | import csv 5 | import re 6 | 7 | class IOProcessor(): 8 | def read_jsonline(self,file): 9 | file=open(file,"r",encoding="utf-8") 10 | data=[json.loads(line) for line in file.readlines()] 11 | return data 12 | 13 | def write_jsonline(self,file,data): 14 | f=jsonlines.open(file,"w") 15 | for each in data: 16 | jsonlines.Writer.write(f,each) 17 | return 18 | 19 | def read_json(self,file): 20 | f=open(file,"r",encoding="utf-8").read() 21 | return json.loads(f) 22 | 23 | def write_json(self,file,data): 24 | f=open(file,"w",encoding="utf-8") 25 | json.dump(data,f,indent=2,ensure_ascii=False) 26 | return 27 | 28 | def read_pickle(self,filename): 29 | return pickle.loads(open(filename,"rb").read()) 30 | 31 | def write_pickle(self,filename,data): 32 | open(filename,"wb").write(pickle.dumps(data)) 33 | return 34 | 35 | 36 | def read_csv(self,filename): 37 | csv_data = csv.reader(open(filename, "r", encoding="utf-8")) 38 | csv_data=[each for each in csv_data] 39 | return csv_data 40 | 41 | 42 | if __name__ == '__main__': 43 | pass 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /dataloaders/test.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import numpy as np 5 | 6 | a=[1,2,3,4,5] 7 | b=np.array(a) 8 | print(b) 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /evaluate_unit_v2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import division 3 | import sys 4 | import os 5 | import io 6 | import os.path 7 | import numpy as np 8 | import json 9 | import shutil 10 | import argparse 11 | 12 | 13 | def parse_args(): 14 | parser=argparse.ArgumentParser() 15 | 16 | parser.add_argument("--output_metric_dir",type=str) 17 | parser.add_argument("--retrieval_result_dir",type=str) 18 | 19 | parser.add_argument("--GT_file",type=str) 20 | 21 | parser.add_argument("--t",action="store_true") 22 | parser.add_argument("--p",action="store_true") 23 | parser.add_argument("--i",action="store_true") 24 | parser.add_argument("--v",action="store_true") 25 | parser.add_argument("--a",action="store_true") 26 | 27 | parser.add_argument("--tp",action="store_true") 28 | parser.add_argument("--ti", action="store_true") 29 | parser.add_argument("--tv", action="store_true") 30 | parser.add_argument("--pi", 
action="store_true") 31 | parser.add_argument("--pv", action="store_true") 32 | parser.add_argument("--iv", action="store_true") 33 | parser.add_argument("--ta", action="store_true") 34 | parser.add_argument("--pa", action="store_true") 35 | parser.add_argument("--ia", action="store_true") 36 | parser.add_argument("--va", action="store_true") 37 | 38 | 39 | parser.add_argument("--tpi", action="store_true") 40 | parser.add_argument("--tpv", action="store_true") 41 | parser.add_argument("--tiv", action="store_true") 42 | parser.add_argument("--piv", action="store_true") 43 | 44 | parser.add_argument("--tpiv", action="store_true") 45 | parser.add_argument("--tpiva", action="store_true") 46 | parser.add_argument("--dense", action="store_true") 47 | 48 | return parser.parse_args() 49 | 50 | 51 | # 52 | # if not os.path.isdir(submit_dir): 53 | # print ("%s doesn't exist" % submit_dir) 54 | # 55 | # if os.path.isdir(submit_dir) and os.path.isdir(truth_dir): 56 | # if not os.path.exists(output_dir): 57 | # os.makedirs(output_dir) 58 | 59 | def read_json(file): 60 | f=io.open(file,"r",encoding="utf-8").read() 61 | f=f.encode("utf-8") 62 | return json.loads(f) 63 | 64 | def write_json(file,data): 65 | f=open(file,"w") 66 | json.dump(data,f,indent=2,ensure_ascii=False) 67 | return 68 | 69 | def compute_p(rank_list,pos_set,topk): 70 | intersect_size = 0 71 | for i in range(topk): 72 | if rank_list[i] in pos_set: 73 | intersect_size+=1 74 | 75 | p=float(intersect_size/topk) 76 | 77 | return p 78 | 79 | def compute_ap(rank_list,pos_set,topk): 80 | ''' 81 | rank_list: 82 | pos_list: 83 | rank_list=["a","d","b","c"] 84 | pos_set=["b","c"] 85 | ap=compute_ap(rank_list,pos_set) 86 | print("ap: ",ap) 87 | ''' 88 | intersect_size=0 89 | ap=0 90 | 91 | for i in range(topk): 92 | if rank_list[i] in pos_set: 93 | intersect_size += 1 94 | precision = intersect_size / (i+1) 95 | ap+=precision 96 | if intersect_size==0: 97 | return 0 98 | ap/=intersect_size 99 | 100 | return ap 101 | 102 | def compute_HitRate(rank_label_set,query_label_set): 103 | return len(rank_label_set.intersection(query_label_set))/len(query_label_set) 104 | 105 | 106 | 107 | def main(): 108 | args = parse_args() 109 | 110 | if not os.path.exists(args.output_metric_dir): 111 | os.makedirs(args.output_metric_dir) 112 | 113 | feature_type = [] 114 | if args.t: feature_type.append("t") 115 | if args.p: feature_type.append("p") 116 | if args.i: feature_type.append("i") 117 | if args.v: feature_type.append("v") 118 | if args.a: feature_type.append("a") 119 | 120 | if args.tp: feature_type.append("tp") 121 | if args.ti: feature_type.append("ti") 122 | if args.tv: feature_type.append("tv") 123 | if args.pi: feature_type.append("pi") 124 | if args.pv: feature_type.append("pv") 125 | if args.iv: feature_type.append("iv") 126 | if args.ta: feature_type.append("ta") 127 | if args.pa: feature_type.append("pa") 128 | if args.ia: feature_type.append("ia") 129 | if args.va: feature_type.append("va") 130 | 131 | if args.tpi: feature_type.append("tpi") 132 | if args.tpv: feature_type.append("tpv") 133 | if args.tiv: feature_type.append("tiv") 134 | if args.piv: feature_type.append("piv") 135 | 136 | if args.tpiv: feature_type.append("tpiv") 137 | if args.tpiva: feature_type.append("tpiva") 138 | if args.dense: feature_type.append("dense") 139 | 140 | # gallery_unit_id_label_txt=open("{}/gallery_unit_id_label.txt".format(args.GT_dir)).readlines() 141 | # test_query_suit_id_label_txt = open("{}/test_query_suit_id_label.txt".format(args.GT_dir)).readlines() 142 | 
143 | all_id_label_temp=open("{}".format(args.GT_file),"r",encoding='utf-8').read() 144 | all_id_label_temp=json.loads(all_id_label_temp) 145 | all_id_label={} 146 | 147 | 148 | for id,info in all_id_label_temp.items(): 149 | all_id_label[id]={ 150 | "label":[info["label"]] 151 | } 152 | 153 | # print("all_id_label: ",all_id_label) 154 | 155 | # gallery_unit_id_label={} 156 | # for line in gallery_unit_id_label_txt: 157 | # line=line.strip() 158 | # line_split=line.split("#####") 159 | # item_id=line_split[0] 160 | # label_list=line_split[1].split("#;#") 161 | # gallery_unit_id_label[item_id]={ 162 | # "label":label_list 163 | # } 164 | # 165 | # test_query_suit_id_label={} 166 | # for line in test_query_suit_id_label_txt: 167 | # line=line.strip() 168 | # line_split=line.split("#####") 169 | # item_id=line_split[0] 170 | # label_list=line_split[1].split("#;#") 171 | # test_query_suit_id_label[item_id]={ 172 | # "label":label_list 173 | # } 174 | 175 | 176 | all_label_id={} 177 | for item_id,info in all_id_label.items(): 178 | label=info["label"][0] 179 | if label not in all_label_id: 180 | all_label_id[label]=[item_id] 181 | else: 182 | all_label_id[label]+=[item_id] 183 | 184 | 185 | results={} 186 | for each_feature_type in feature_type: 187 | results[each_feature_type]={} 188 | 189 | retrieval_results=open("{}/{}_feature_retrieval_id_list.txt" 190 | .format(args.retrieval_result_dir,each_feature_type),"r").readlines() 191 | 192 | topk_list=[1,5,10] 193 | for topk in topk_list: 194 | topk_temp=topk 195 | mAP=0 196 | mHitRate=0 197 | mP=0 198 | cnt=0 199 | for index,each in enumerate(retrieval_results): 200 | each=each.strip() 201 | each_split=each.split(",") 202 | query_id=each_split[0] 203 | rank_id_list=each_split[1:] 204 | pos_set=[] 205 | 206 | # try: 207 | cnt+=1 208 | query_suit_labels=all_id_label[query_id]["label"] 209 | for label in query_suit_labels: 210 | pos_set+=all_label_id[label] 211 | 212 | topk = min(topk_temp, len(pos_set),len(rank_id_list)) 213 | 214 | # if topk<10: 215 | # print() 216 | # print("query_suit_labels: ",query_suit_labels) 217 | # print("topk in: ",topk) 218 | # print("pos set: ",len(pos_set)) 219 | # print("rank id list: ",len(rank_id_list)) 220 | # print() 221 | 222 | ap=compute_ap(rank_id_list,pos_set,topk) 223 | p = compute_p(rank_id_list, pos_set, topk) 224 | 225 | # print("ap: ",ap) 226 | # print() 227 | 228 | 229 | mAP+=ap 230 | mP += p 231 | 232 | # # hit rate 233 | # query_suit_label_set = set(gallery_unit_id_label[query_id]["label"]) 234 | # rank_label_set = set([gallery_unit_id_label[item_id]["label"][0] for item_id in rank_id_list[:topk]]) 235 | 236 | # hit_rate=compute_HitRate(rank_label_set,query_suit_label_set) 237 | # mHitRate+=hit_rate 238 | 239 | # except Exception as e: 240 | # print(e) 241 | # print(query_id) 242 | 243 | # continue 244 | # if index==100: 245 | # break 246 | 247 | 248 | mAP/=cnt 249 | mHitRate/=cnt 250 | mP /= cnt 251 | 252 | 253 | # print("topk: ",topk) 254 | # print("topk_temp: ",topk_temp) 255 | print("topk: {} mAP: {} ".format(topk_temp,mAP)) 256 | 257 | results[each_feature_type]["top{}".format(topk_temp)]={ 258 | "mAP": mAP*100, 259 | "mHitRate": mHitRate*100, 260 | "Prec": mP*100, 261 | "average": 100*(mAP + mHitRate + mP) / 3 262 | } 263 | 264 | write_json("{}/metric_results.json".format(args.output_metric_dir),results) 265 | 266 | return 267 | 268 | if __name__ == '__main__': 269 | main() 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 
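A minimal sanity check of the retrieval metrics defined in evaluate_unit_v2.py above, reusing the toy rank_list/pos_set values from the compute_ap docstring (run it from the repository root so the import resolves; numpy from m5.yaml must be installed):

from evaluate_unit_v2 import compute_ap, compute_p

rank_list = ["a", "d", "b", "c"]   # gallery ids ordered by decreasing similarity to the query
pos_set = ["b", "c"]               # gallery ids sharing the query's label

# Hits fall at ranks 3 and 4, so AP = (1/3 + 2/4) / 2 = 5/12
print(compute_ap(rank_list, pos_set, 4))   # ~0.4167
# Precision@4 counts 2 hits among the 4 retrieved ids
print(compute_p(rank_list, pos_set, 4))    # 0.5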
| -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | ITP2V is the version in which pv and title are concatenated 2 | 3 | ITP3V_capture is the version in which PV and title are encoded separately, but the problem with this version is that the 4 | contrastive-learning loss code is buggy 5 | 6 | ITP3V_capture_v2 is the version with the corrected contrastive learning, but its performance seems somewhat low 7 | 8 | 9 | ITP3V_capture_v2_dymask builds on the version above and adds dynamic weights for the mask tasks 10 | 11 | 12 | ITP3V_capture_v3 adjusts the contrastive-learning loss code of ITP3V_capture: 13 | instead of the n_views formulation it pairs modalities two by two, which is still the previous code 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /examples/SCALE/eval/eval_gallery1_1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | TRAIN_TYPE=ITP3VA_capture_v3_dyctr_dymask_per_sample 3 | MODEL_TYPE=capture_subset_v2_MLM_MRM_MEM_MFM_MAM_CLR 4 | QUERY_FEATURE=subset_v2_feature 5 | GALLERY_FEATURE=subset_v2_feature 6 | QUERY_FEATURE_DIR=examples/${TRAIN_TYPE}/eval/feature_data/${MODEL_TYPE}/${QUERY_FEATURE}/return_hidden 7 | GALLERY_FEATURE_DIR=examples/${TRAIN_TYPE}/eval/feature_data/${MODEL_TYPE}/${GALLERY_FEATURE}/return_hidden 8 | RETRIEVAL_RESULTS_DIR=examples/${TRAIN_TYPE}/eval/retrieval_id_list/${MODEL_TYPE}/${QUERY_FEATURE}/return_hidden 9 | 10 | 11 | GALLERY_FILE=/multi_modal/data/train.txt 12 | QUERY_FILE=/multi_modal/data/query.txt 13 | IMAGE_PATH=/multi_modal/data/images 14 | VIDEO_PATH=/multi_modal/data/videos 15 | 16 | 17 | # remember to change the dir and filename 18 | # gallery 19 | CUDA_VISIBLE_DEVICES=0 python extract_features.py \ 20 | --bert_model bert-base-chinese \ 21 | --from_pretrained ../save/${MODEL_TYPE}/pytorch_model_9.bin \ 22 | --config_file ../../../config/bert_base_6layer_6conect_capture_itp3va.json \ 23 | --predict_feature \ 24 | --video_feature_dir /multi_modal/data/video_feature \ 25 | --audio_file_dir /multi_modal/data/audios \ 26 | --split test \ 27 | --train_batch_size 32 \ 28 | --max_seq_length 36 \ 29 | --zero_shot \ 30 | --video_len 12 \ 31 | --pv_seq_len 64 \ 32 | --audio_len 12 \ 33 | --lmdb_file /multi_modal/data/lmdb_features/${GALLERY_FEATURE}.lmdb \ 34 | --caption_path /multi_modal/data/product5m_v2/subset_v2_id_label.json \ 35 | --feature_dir ./feature_data/${MODEL_TYPE}/${GALLERY_FEATURE}/return_hidden \ 36 | --return_hidden 37 | 38 | 39 | 40 | # query 41 | #CUDA_VISIBLE_DEVICES=7 python extract_features.py \ 42 | # --bert_model bert-base-chinese \ 43 | # --from_pretrained ../save/${MODEL_TYPE}/pytorch_model_9.bin \ 44 | # --config_file ../../../config/bert_base_6layer_6conect_capture_itpv.json \ 45 | # --predict_feature \ 46 | # --split test \ 47 | # --train_batch_size 32 \ 48 | # --max_seq_length 100 \ 49 | # --zero_shot \ 50 | # --lmdb_file /multi_modal/data/lmdb_features/${QUERY_FEATURE}.lmdb \ 51 | # --caption_path /multi_modal/data/product5m_v2/product1m_product5m_test_id_label.json \ 52 | # --feature_dir ./feature_data/${MODEL_TYPE}/${QUERY_FEATURE}/return_hidden \ 53 | # --return_hidden 54 | 55 | 56 | cd ../../..
57 | 58 | 59 | #python retrieval_unit_id_list_v2.py \ 60 | # --query_feature_path ${QUERY_FEATURE_DIR} \ 61 | # --gallery_feature_path ${GALLERY_FEATURE_DIR} \ 62 | # --retrieval_results_path ${RETRIEVAL_RESULTS_DIR} \ 63 | # --max_topk 10 \ 64 | # --t \ 65 | # --p \ 66 | # --i \ 67 | # --v \ 68 | # --a \ 69 | # --tp \ 70 | # --ti \ 71 | # --tv \ 72 | # --pi \ 73 | # --pv \ 74 | # --iv \ 75 | # --ta \ 76 | # --pa \ 77 | # --ia \ 78 | # --va \ 79 | # --tpi \ 80 | # --tpv \ 81 | # --tiv \ 82 | # --piv \ 83 | # --tpiv \ 84 | # --tpiva 85 | 86 | 87 | #GT_file=/multi_modal/data/product5m_v2/subset_v2_id_label.json 88 | OUTPUT_METRIC_DIR=examples/${TRAIN_TYPE}/eval/retrieval_metric/${MODEL_TYPE}/${QUERY_FEATURE}/return_hidden 89 | #python evaluate_unit_v2.py \ 90 | # --retrieval_result_dir ${RETRIEVAL_RESULTS_DIR} \ 91 | # --GT_file ${GT_file} \ 92 | # --output_metric_dir ${OUTPUT_METRIC_DIR} \ 93 | # --t \ 94 | # --p \ 95 | # --i \ 96 | # --v \ 97 | # --a \ 98 | # --tp \ 99 | # --ti \ 100 | # --tv \ 101 | # --pi \ 102 | # --pv \ 103 | # --iv \ 104 | # --ta \ 105 | # --pa \ 106 | # --ia \ 107 | # --va \ 108 | # --tpi \ 109 | # --tpv \ 110 | # --tiv \ 111 | # --piv \ 112 | # --tpiv \ 113 | # --tpiva 114 | 115 | 116 | python eval_spearman.py \ 117 | --query_feature_path ${QUERY_FEATURE_DIR} \ 118 | --gallery_feature_path ${GALLERY_FEATURE_DIR} \ 119 | --retrieval_results_path ${OUTPUT_METRIC_DIR} 120 | 121 | 122 | ## cross-modal 123 | #python retrieval_unit_id_list_cross_modal.py \ 124 | # --query_feature_path ${QUERY_FEATURE_DIR} \ 125 | # --gallery_feature_path ${GALLERY_FEATURE_DIR} \ 126 | # --retrieval_results_path ${RETRIEVAL_RESULTS_DIR} \ 127 | # --max_topk 10 \ 128 | # --cross_modal vt 129 | # 130 | # 131 | #GT_file=/multi_modal/data/product5m_v2/product1m_product5m_id_label.json 132 | #OUTPUT_METRIC_DIR=examples/${TRAIN_TYPE}/eval/retrieval_metric/${MODEL_TYPE}/${QUERY_FEATURE}/return_hidden 133 | #python evaluate_unit_cross_modal.py \ 134 | # --retrieval_result_dir ${RETRIEVAL_RESULTS_DIR} \ 135 | # --GT_file ${GT_file} \ 136 | # --output_metric_dir ${OUTPUT_METRIC_DIR} \ 137 | # --cross_modal vt 138 | 139 | 140 | #RETRIEVAL_IMAGES_DIR=examples/${TRAIN_TYPE}/eval/retrieval_images/${MODEL_TYPE}/${QUERY_FEATURE} 141 | #python retrieval_unit_images.py \ 142 | # --retrieval_ids_path ${RETRIEVAL_RESULTS_DIR} \ 143 | # --retrieval_images_path ${RETRIEVAL_IMAGES_DIR} \ 144 | # --query_image_prefix /multi_modal/data/images \ 145 | # --gallery_image_prefix /multi_modal/data/images \ 146 | # --t \ 147 | # --v 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | -------------------------------------------------------------------------------- /examples/SCALE/eval/eval_gallery1_cls.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | TRAIN_TYPE=ITP3VA_capture_v3_dyctr_dymask_per_sample 3 | MODEL_TYPE=CLS 4 | QUERY_FEATURE=subset_v2_test_feature 5 | GALLERY_FEATURE=subset_v2_test_feature 6 | QUERY_FEATURE_DIR=examples/${TRAIN_TYPE}/eval/feature_data/${MODEL_TYPE}/${QUERY_FEATURE}/return_hidden 7 | GALLERY_FEATURE_DIR=examples/${TRAIN_TYPE}/eval/feature_data/${MODEL_TYPE}/${GALLERY_FEATURE}/return_hidden 8 | RETRIEVAL_RESULTS_DIR=examples/${TRAIN_TYPE}/eval/retrieval_id_list/${MODEL_TYPE}/${QUERY_FEATURE}/return_hidden 9 | 10 | 11 | GALLERY_FILE=/multi_modal/data/train.txt 12 | QUERY_FILE=/multi_modal/data/query.txt 13 | IMAGE_PATH=/multi_modal/data/images 14 | VIDEO_PATH=/multi_modal/data/videos 15 | 16 | 
17 | # remenber to change the dir and filename 18 | # gallery 19 | CUDA_VISIBLE_DEVICES=7 python extract_features_cls.py \ 20 | --bert_model bert-base-chinese \ 21 | --from_pretrained ../save/${MODEL_TYPE}/pytorch_model_9.bin \ 22 | --config_file ../../../config/bert_base_6layer_6conect_capture_itp3va.json \ 23 | --predict_feature \ 24 | --video_feature_dir /multi_modal/data/video_feature \ 25 | --split test \ 26 | --train_batch_size 32 \ 27 | --max_seq_length 36 \ 28 | --zero_shot \ 29 | --video_len 12 \ 30 | --pv_seq_len 64 \ 31 | --lmdb_file /multi_modal/data/lmdb_features/${GALLERY_FEATURE}.lmdb \ 32 | --caption_path /multi_modal/data/product5m_v2/subset_v2_id_label.json \ 33 | --feature_dir ./feature_data/${MODEL_TYPE}/${GALLERY_FEATURE}/return_hidden \ 34 | --return_hidden 35 | 36 | 37 | cd ../../.. 38 | 39 | 40 | python retrieval_unit_id_list_v2.py \ 41 | --query_feature_path ${QUERY_FEATURE_DIR} \ 42 | --gallery_feature_path ${GALLERY_FEATURE_DIR} \ 43 | --retrieval_results_path ${RETRIEVAL_RESULTS_DIR} \ 44 | --max_topk 10 \ 45 | --dense 46 | 47 | 48 | GT_file=/multi_modal/data/product5m_v2/subset_v2_id_label.json 49 | OUTPUT_METRIC_DIR=examples/${TRAIN_TYPE}/eval/retrieval_metric/${MODEL_TYPE}/${QUERY_FEATURE}/return_hidden 50 | python evaluate_unit_v2.py \ 51 | --retrieval_result_dir ${RETRIEVAL_RESULTS_DIR} \ 52 | --GT_file ${GT_file} \ 53 | --output_metric_dir ${OUTPUT_METRIC_DIR} \ 54 | --dense 55 | 56 | # 57 | #python eval_spearman.py \ 58 | # --query_feature_path ${QUERY_FEATURE_DIR} \ 59 | # --gallery_feature_path ${GALLERY_FEATURE_DIR} \ 60 | # --retrieval_results_path ${RETRIEVAL_RESULTS_DIR} 61 | 62 | 63 | ## cross-modal 64 | #python retrieval_unit_id_list_cross_modal.py \ 65 | # --query_feature_path ${QUERY_FEATURE_DIR} \ 66 | # --gallery_feature_path ${GALLERY_FEATURE_DIR} \ 67 | # --retrieval_results_path ${RETRIEVAL_RESULTS_DIR} \ 68 | # --max_topk 10 \ 69 | # --cross_modal vt 70 | # 71 | # 72 | #GT_file=/multi_modal/data/product5m_v2/product1m_product5m_id_label.json 73 | #OUTPUT_METRIC_DIR=examples/${TRAIN_TYPE}/eval/retrieval_metric/${MODEL_TYPE}/${QUERY_FEATURE}/return_hidden 74 | #python evaluate_unit_cross_modal.py \ 75 | # --retrieval_result_dir ${RETRIEVAL_RESULTS_DIR} \ 76 | # --GT_file ${GT_file} \ 77 | # --output_metric_dir ${OUTPUT_METRIC_DIR} \ 78 | # --cross_modal vt 79 | 80 | 81 | #RETRIEVAL_IMAGES_DIR=examples/${TRAIN_TYPE}/eval/retrieval_images/${MODEL_TYPE}/${QUERY_FEATURE} 82 | #python retrieval_unit_images.py \ 83 | # --retrieval_ids_path ${RETRIEVAL_RESULTS_DIR} \ 84 | # --retrieval_images_path ${RETRIEVAL_IMAGES_DIR} \ 85 | # --query_image_prefix /multi_modal/data/images \ 86 | # --gallery_image_prefix /multi_modal/data/images \ 87 | # --t \ 88 | # --v 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /examples/SCALE/eval/eval_gallery1_cls_fg.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | TRAIN_TYPE=ITP3VA_capture_v3_dyctr_dymask_per_sample 3 | MODEL_TYPE=CLS 4 | QUERY_FEATURE=query_fg_feature 5 | GALLERY_FEATURE=query_fg_feature 6 | QUERY_FEATURE_DIR=examples/${TRAIN_TYPE}/eval/feature_data/${MODEL_TYPE}/${QUERY_FEATURE}/return_hidden 7 | GALLERY_FEATURE_DIR=examples/${TRAIN_TYPE}/eval/feature_data/${MODEL_TYPE}/${GALLERY_FEATURE}/return_hidden 8 | RETRIEVAL_RESULTS_DIR=examples/${TRAIN_TYPE}/eval/retrieval_id_list/${MODEL_TYPE}/${QUERY_FEATURE}/return_hidden 9 | 10 | 11 
| GALLERY_FILE=/multi_modal/data/train.txt 12 | QUERY_FILE=/multi_modal/data/query.txt 13 | IMAGE_PATH=/multi_modal/data/images 14 | VIDEO_PATH=/multi_modal/data/videos 15 | 16 | 17 | # remenber to change the dir and filename 18 | # gallery 19 | CUDA_VISIBLE_DEVICES=7 python extract_features_cls.py \ 20 | --bert_model bert-base-chinese \ 21 | --from_pretrained ../save/${MODEL_TYPE}/pytorch_model_9.bin \ 22 | --config_file ../../../config/bert_base_6layer_6conect_capture_itp3va.json \ 23 | --predict_feature \ 24 | --video_feature_dir /multi_modal/data/video_feature \ 25 | --split test \ 26 | --train_batch_size 32 \ 27 | --max_seq_length 36 \ 28 | --zero_shot \ 29 | --video_len 12 \ 30 | --pv_seq_len 64 \ 31 | --lmdb_file /multi_modal/data/lmdb_features/${GALLERY_FEATURE}.lmdb \ 32 | --caption_path /multi_modal/data/product5m_v2/product1m_product5m_test_id_label.json \ 33 | --feature_dir ./feature_data/${MODEL_TYPE}/${GALLERY_FEATURE}/return_hidden \ 34 | --return_hidden 35 | 36 | 37 | cd ../../.. 38 | 39 | 40 | python retrieval_unit_id_list_v2.py \ 41 | --query_feature_path ${QUERY_FEATURE_DIR} \ 42 | --gallery_feature_path ${GALLERY_FEATURE_DIR} \ 43 | --retrieval_results_path ${RETRIEVAL_RESULTS_DIR} \ 44 | --max_topk 10 \ 45 | --dense 46 | 47 | 48 | GT_file=/multi_modal/data/product5m_v2/product1m_product5m_test_id_label.json 49 | OUTPUT_METRIC_DIR=examples/${TRAIN_TYPE}/eval/retrieval_metric/${MODEL_TYPE}/${QUERY_FEATURE}/return_hidden 50 | python evaluate_unit_v2.py \ 51 | --retrieval_result_dir ${RETRIEVAL_RESULTS_DIR} \ 52 | --GT_file ${GT_file} \ 53 | --output_metric_dir ${OUTPUT_METRIC_DIR} \ 54 | --dense 55 | 56 | # 57 | #python eval_spearman.py \ 58 | # --query_feature_path ${QUERY_FEATURE_DIR} \ 59 | # --gallery_feature_path ${GALLERY_FEATURE_DIR} \ 60 | # --retrieval_results_path ${RETRIEVAL_RESULTS_DIR} 61 | 62 | 63 | ## cross-modal 64 | #python retrieval_unit_id_list_cross_modal.py \ 65 | # --query_feature_path ${QUERY_FEATURE_DIR} \ 66 | # --gallery_feature_path ${GALLERY_FEATURE_DIR} \ 67 | # --retrieval_results_path ${RETRIEVAL_RESULTS_DIR} \ 68 | # --max_topk 10 \ 69 | # --cross_modal vt 70 | # 71 | # 72 | #GT_file=/multi_modal/data/product5m_v2/product1m_product5m_id_label.json 73 | #OUTPUT_METRIC_DIR=examples/${TRAIN_TYPE}/eval/retrieval_metric/${MODEL_TYPE}/${QUERY_FEATURE}/return_hidden 74 | #python evaluate_unit_cross_modal.py \ 75 | # --retrieval_result_dir ${RETRIEVAL_RESULTS_DIR} \ 76 | # --GT_file ${GT_file} \ 77 | # --output_metric_dir ${OUTPUT_METRIC_DIR} \ 78 | # --cross_modal vt 79 | 80 | 81 | #RETRIEVAL_IMAGES_DIR=examples/${TRAIN_TYPE}/eval/retrieval_images/${MODEL_TYPE}/${QUERY_FEATURE} 82 | #python retrieval_unit_images.py \ 83 | # --retrieval_ids_path ${RETRIEVAL_RESULTS_DIR} \ 84 | # --retrieval_images_path ${RETRIEVAL_IMAGES_DIR} \ 85 | # --query_image_prefix /multi_modal/data/images \ 86 | # --gallery_image_prefix /multi_modal/data/images \ 87 | # --t \ 88 | # --v 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /examples/SCALE/eval/extract_features_cls.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import random 6 | from io import open 7 | import numpy as np 8 | 9 | import sys 10 | sys.path.append("../../../") 11 | 12 | from tensorboardX import SummaryWriter 13 | from tqdm import tqdm 14 | from bisect import bisect 15 | import yaml 
16 | from easydict import EasyDict as edict 17 | import sys 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | import torch.nn as nn 22 | import pickle 23 | from torch.utils.data import DataLoader, Dataset, RandomSampler 24 | 25 | 26 | from pytorch_pretrained_bert.tokenization import BertTokenizer 27 | 28 | from dataloaders.pretrain_dataset_ITP3VA import Pretrain_DataSet_Train 29 | # from model.capture_ITP3V import BertForMultiModalPreTraining, BertConfig 30 | from model.capture_ITP3VA_v3_dyctr_dymask_per_sample import Capture_ITPV_ForClassification, BertConfig 31 | 32 | 33 | from utils_args import get_args 34 | 35 | import torch.distributed as dist 36 | 37 | logging.basicConfig( 38 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 39 | datefmt="%m/%d/%Y %H:%M:%S", 40 | level=logging.INFO, 41 | ) 42 | logger = logging.getLogger(__name__) 43 | 44 | 45 | def read_pickle(filename): 46 | return pickle.loads(open(filename,"rb").read()) 47 | 48 | def write_pickle(filename,data): 49 | open(filename,"wb").write(pickle.dumps(data)) 50 | return 51 | 52 | 53 | def main(): 54 | args = get_args() 55 | 56 | random.seed(args.seed) 57 | np.random.seed(args.seed) 58 | torch.manual_seed(args.seed) 59 | 60 | 61 | # timeStamp = '-'.join(task_names) + '_' + args.config_file.split('/')[1].split('.')[0] 62 | if '/' in args.from_pretrained: 63 | timeStamp = args.from_pretrained.split('/')[1] 64 | else: 65 | timeStamp = args.from_pretrained 66 | 67 | savePath = os.path.join(args.output_dir, timeStamp) 68 | 69 | config = BertConfig.from_json_file(args.config_file) 70 | # bert_weight_name = json.load(open("config/" + args.bert_model + "_weight_name.json", "r")) 71 | 72 | if args.local_rank == -1 or args.no_cuda: 73 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 74 | n_gpu = torch.cuda.device_count() 75 | else: 76 | torch.cuda.set_device(args.local_rank) 77 | device = torch.device("cuda", args.local_rank) 78 | n_gpu = 1 79 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 80 | torch.distributed.init_process_group(backend="nccl") 81 | 82 | logger.info( 83 | "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 84 | device, n_gpu, bool(args.local_rank != -1), args.fp16 85 | ) 86 | ) 87 | 88 | default_gpu = False 89 | if dist.is_available() and args.local_rank != -1: 90 | rank = dist.get_rank() 91 | if rank == 0: 92 | default_gpu = True 93 | else: 94 | default_gpu = True 95 | 96 | tokenizer = BertTokenizer.from_pretrained( 97 | args.bert_model, do_lower_case=args.do_lower_case 98 | ) 99 | 100 | train_dataset = Pretrain_DataSet_Train( 101 | tokenizer, 102 | seq_len=args.max_seq_length, 103 | batch_size=args.train_batch_size, 104 | predict_feature=args.predict_feature, 105 | num_workers=args.num_workers, 106 | lmdb_file=args.lmdb_file, # '/train/training_feat_all_v2.lmdb' 107 | caption_path=args.caption_path ,# "/id_info_dict.json" 108 | video_feature_dir=args.video_feature_dir, 109 | video_len=args.video_len, 110 | pv_seq_len=args.pv_seq_len, 111 | audio_file_dir=args.audio_file_dir, 112 | audio_len=args.audio_len, 113 | MLM=args.MLM, 114 | MRM=args.MRM, 115 | MEM=args.MEM, 116 | ITM=args.ITM, 117 | MFM=args.MFM, 118 | MAM=args.MAM 119 | ) 120 | 121 | print("all image batch num: ", len(train_dataset)) 122 | 123 | config.fast_mode = True 124 | if args.predict_feature: 125 | print("predict_feature") 126 | config.v_target_size = 2048 127 | config.predict_feature = True 128 | else: 129 | 
print("no predict_feature") 130 | config.v_target_size = 1601 131 | config.predict_feature = False 132 | 133 | model = Capture_ITPV_ForClassification.from_pretrained(args.from_pretrained, config) 134 | 135 | model.to(device) 136 | if args.local_rank != -1: 137 | try: 138 | from apex import DistributedDataParallel as DDP 139 | except ImportError: 140 | raise ImportError( 141 | "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training." 142 | ) 143 | model = DDP(model, deay_allreduce=True) 144 | 145 | elif n_gpu > 1: 146 | model = nn.DataParallel(model) 147 | 148 | no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] 149 | 150 | print("Prepare to generate feature! ready!") 151 | model.eval() 152 | 153 | # lib 154 | lib_vil_id=[] 155 | dense_feature_list=[] 156 | 157 | for step, batch in enumerate(tqdm(train_dataset)): 158 | image_id = batch[-1] 159 | batch = batch[:-1] 160 | batch = tuple(t.cuda(device=device, non_blocking=True) for t in batch) 161 | 162 | input_ids, input_mask, segment_ids, lm_label_ids, is_next, \ 163 | pv_input_ids, pv_input_mask, pv_segment_ids, em_label_ids, \ 164 | image_feat, image_loc, image_target, image_label, image_mask, \ 165 | video_feat, video_target, video_label, video_mask, \ 166 | audio_feat, audio_target, audio_label, audio_mask= ( 167 | batch 168 | ) 169 | 170 | # lib_vil_id.append(image_ids) 171 | lib_vil_id+=list(image_id) 172 | 173 | with torch.no_grad(): 174 | _,_,pooled_output_dense=model( 175 | input_ids, 176 | pv_input_ids, 177 | image_feat, 178 | video_feat, 179 | audio_feat, 180 | image_loc, 181 | segment_ids, 182 | pv_segment_ids, 183 | input_mask, 184 | pv_input_mask, 185 | image_mask, 186 | video_mask, 187 | audio_mask, 188 | ) 189 | 190 | # ########################## 191 | pooled_output_dense = pooled_output_dense.detach().cpu().numpy() 192 | ##############################33 193 | dense_feature_list.append(pooled_output_dense) 194 | 195 | dense_feature_np=np.vstack(dense_feature_list) 196 | print("dense_feature_np: ",dense_feature_np.shape) 197 | if not os.path.exists(args.feature_dir): 198 | os.makedirs(args.feature_dir) 199 | np.save("{}/dense_feature_np.npy".format(args.feature_dir),dense_feature_np) 200 | np.save("{}/id.npy".format(args.feature_dir),lib_vil_id) 201 | 202 | 203 | if __name__ == "__main__": 204 | main() 205 | -------------------------------------------------------------------------------- /examples/SCALE/run_pretrain_task.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=7 python pretrain_task.py \ 3 | --from_pretrained /bert_model/bert_base_chinese \ 4 | --bert_model bert-base-chinese \ 5 | --config_file ../../config/bert_base_6layer_6conect_capture_itp3va.json\ 6 | --predict_feature \ 7 | --learning_rate 1e-4 \ 8 | --video_feature_dir video_path \ 9 | --audio_file_dir audio_path \ 10 | --train_batch_size 64 \ 11 | --max_seq_length 36 \ 12 | --video_len 12 \ 13 | --pv_seq_len 64 \ 14 | --audio_len 12 \ 15 | --lmdb_file image_path \ 16 | --caption_path caption_path \ 17 | --save_name capture_subset_v2_MLM_MRM_MEM_MFM_MAM_CLR \ 18 | --MLM \ 19 | --MRM \ 20 | --MFM \ 21 | --MEM \ 22 | --MAM \ 23 | --CLR -------------------------------------------------------------------------------- /examples/SCALE/run_train_cls.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | TRAIN_TYPE=ITP3VA_capture_v3_dyctr_dymask_per_sample 3 | 
MODEL_TYPE=capture_subset_v2_MLM_MRM_MEM_MFM_MAM_CLR 4 | 5 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train_cls.py \ 6 | --from_pretrained /multi_modal/CAPTURE/examples/${TRAIN_TYPE}/save/${MODEL_TYPE}/pytorch_model_9.bin \ 7 | --bert_model bert-base-chinese \ 8 | --config_file ../../config/bert_base_6layer_6conect_capture_itp3va.json\ 9 | --predict_feature \ 10 | --learning_rate 1e-4 \ 11 | --video_feature_dir video_path \ 12 | --audio_file_dir audio_path \ 13 | --train_batch_size 64 \ 14 | --max_seq_length 36 \ 15 | --video_len 12 \ 16 | --pv_seq_len 64 \ 17 | --audio_len 12 \ 18 | --train_lmdb_file /data/lmdb_features/subset_v2_train_feature.lmdb \ 19 | --test_lmdb_file /data/lmdb_features/subset_v2_test_feature.lmdb \ 20 | --caption_path /data/product5m_v2/subset_v2_id_label.json \ 21 | --label_list_file /data/product5m_v2/subset_v2_label_list.json \ 22 | --save_name CLS -------------------------------------------------------------------------------- /examples/SCALE/utils_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | def get_args(): 3 | parser = argparse.ArgumentParser() 4 | 5 | # Required parameters 6 | parser.add_argument( 7 | "--train_file", 8 | default="data/conceptual_caption/training", 9 | type=str, 10 | # required=True, 11 | help="The input train corpus.", 12 | ) 13 | parser.add_argument( 14 | "--validation_file", 15 | default="data/conceptual_caption/validation", 16 | type=str, 17 | # required=True, 18 | help="The input train corpus.", 19 | ) 20 | parser.add_argument( 21 | "--from_pretrained", 22 | default="", 23 | type=str, 24 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 25 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.", 26 | ) 27 | parser.add_argument( 28 | "--bert_model", 29 | default="bert-base-uncased", 30 | type=str, 31 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 32 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.", 33 | ) 34 | parser.add_argument( 35 | "--output_dir", 36 | default="save", 37 | type=str, 38 | # required=True, 39 | help="The output directory where the model checkpoints will be written.", 40 | ) 41 | 42 | parser.add_argument( 43 | "--config_file", 44 | default="config/bert_config.json", 45 | type=str, 46 | # required=True, 47 | help="The config file which specified the model details.", 48 | ) 49 | ## Other parameters 50 | parser.add_argument( 51 | "--max_seq_length", 52 | default=36, 53 | type=int, 54 | help="The maximum total input sequence length after WordPiece tokenization. 
\n" 55 | "Sequences longer than this will be truncated, and sequences shorter \n" 56 | "than this will be padded.", 57 | ) 58 | parser.add_argument("--predict_feature", action="store_true", help="visual target.") 59 | 60 | parser.add_argument( 61 | "--train_batch_size", 62 | default=512, 63 | type=int, 64 | help="Total batch size for training.", 65 | ) 66 | parser.add_argument( 67 | "--learning_rate", 68 | default=1e-4, 69 | type=float, 70 | help="The initial learning rate for Adam.", 71 | ) 72 | parser.add_argument( 73 | "--num_train_epochs", 74 | default=20.0, 75 | type=float, 76 | help="Total number of training epochs to perform.", 77 | ) 78 | parser.add_argument( 79 | "--start_epoch", 80 | default=0, 81 | type=float, 82 | help="Total number of training epochs to perform.", 83 | ) 84 | parser.add_argument( 85 | "--warmup_proportion", 86 | default=0.1, 87 | type=float, 88 | help="Proportion of training to perform linear learning rate warmup for. " 89 | "E.g., 0.1 = 10%% of training.", 90 | ) 91 | parser.add_argument("--caption_path", 92 | type=str) 93 | parser.add_argument("--lmdb_file", type=str) 94 | parser.add_argument("--train_lmdb_file",type=str) 95 | parser.add_argument("--test_lmdb_file",type=str) 96 | parser.add_argument( 97 | "--img_weight", default=1, type=float, help="weight for image loss" 98 | ) 99 | parser.add_argument( 100 | "--no_cuda", action="store_true", help="Whether not to use CUDA when available" 101 | ) 102 | parser.add_argument( 103 | "--on_memory", 104 | action="store_true", 105 | help="Whether to load train samples into memory or use disk", 106 | ) 107 | parser.add_argument( 108 | "--do_lower_case", 109 | type=bool, 110 | default=True, 111 | help="Whether to lower case the input text. True for uncased models, False for cased models.", 112 | ) 113 | parser.add_argument( 114 | "--local_rank", 115 | type=int, 116 | default=-1, 117 | help="local_rank for distributed training on gpus", 118 | ) 119 | parser.add_argument( 120 | "--seed", type=int, default=42, help="random seed for initialization" 121 | ) 122 | parser.add_argument( 123 | "--gradient_accumulation_steps", 124 | type=int, 125 | default=1, 126 | help="Number of updates steps to accumualte before performing a backward/update pass.", 127 | ) 128 | parser.add_argument( 129 | "--fp16", 130 | action="store_true", 131 | help="Whether to use 16-bit float precision instead of 32-bit", 132 | ) 133 | parser.add_argument( 134 | "--loss_scale", 135 | type=float, 136 | default=0, 137 | help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" 138 | "0 (default value): dynamic loss scaling.\n" 139 | "Positive power of 2: static loss scaling value.\n", 140 | ) 141 | parser.add_argument( 142 | "--num_workers", 143 | type=int, 144 | default=3, 145 | help="Number of workers in the dataloader.", 146 | ) 147 | 148 | parser.add_argument( 149 | "--save_name", 150 | default='', 151 | type=str, 152 | help="save name for training.", 153 | ) 154 | parser.add_argument( 155 | "--baseline", action="store_true", help="Wheter to use the baseline model (single bert)." 156 | ) 157 | parser.add_argument( 158 | "--freeze", default=-1, type=int, 159 | help="till which layer of textual stream of vilbert need to fixed." 160 | ) 161 | parser.add_argument( 162 | "--use_chuncks", default=0, type=float, help="whether use chunck for parallel training." 163 | ) 164 | parser.add_argument( 165 | "--distributed", action="store_true", help="whether use chunck for parallel training." 
166 | ) 167 | parser.add_argument( 168 | "--without_coattention", action="store_true", help="whether pair loss." 169 | ) 170 | 171 | parser.add_argument( 172 | "--video_feature_dir",type=str 173 | ) 174 | parser.add_argument( 175 | "--video_len",type=int 176 | ) 177 | 178 | parser.add_argument( 179 | "--label_list_file",type=str 180 | ) 181 | 182 | parser.add_argument( 183 | "--pv_seq_len",type=int,default=64 184 | ) 185 | parser.add_argument( 186 | "--audio_file_dir",type=str 187 | ) 188 | parser.add_argument( 189 | "--audio_len",type=int 190 | ) 191 | 192 | parser.add_argument( 193 | "--MLM",action="store_true" 194 | ) 195 | parser.add_argument( 196 | "--MRM",action="store_true" 197 | ) 198 | parser.add_argument( 199 | "--MEM",action="store_true" 200 | ) 201 | parser.add_argument( 202 | "--ITM",action="store_true" 203 | ) 204 | parser.add_argument( 205 | "--CLR",action="store_true" 206 | ) 207 | parser.add_argument( 208 | "--MFM",action="store_true" 209 | ) 210 | parser.add_argument( 211 | "--MAM",action="store_true" 212 | ) 213 | 214 | 215 | 216 | 217 | args = parser.parse_args() 218 | 219 | return args 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | -------------------------------------------------------------------------------- /m5.yaml: -------------------------------------------------------------------------------- 1 | name: m5 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - https://repo.anaconda.com/pkgs/main 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=4.5=1_gnu 10 | - blas=1.0=mkl 11 | - bzip2=1.0.8=h7f98852_4 12 | - ca-certificates=2021.10.8=ha878542_0 13 | - certifi=2021.10.8=py37h89c1867_1 14 | - cudatoolkit=11.1.1=h6406543_8 15 | - ffmpeg=4.3=hf484d3e_0 16 | - freetype=2.10.4=h5ab3b9f_0 17 | - giflib=5.2.1=h7b6447c_0 18 | - gmp=6.2.1=h58526e2_0 19 | - gnutls=3.6.13=h85f3911_1 20 | - intel-openmp=2021.3.0=h06a4308_3350 21 | - jpeg=9b=h024ee3a_2 22 | - lame=3.100=h7f98852_1001 23 | - lcms2=2.12=h3be6417_0 24 | - ld_impl_linux-64=2.35.1=h7274673_9 25 | - libffi=3.3=he6710b0_2 26 | - libgcc-ng=9.3.0=h5101ec6_17 27 | - libgomp=9.3.0=h5101ec6_17 28 | - libiconv=1.16=h516909a_0 29 | - libpng=1.6.37=hbc83047_0 30 | - libstdcxx-ng=9.3.0=hd4cf53a_17 31 | - libtiff=4.2.0=h85742a9_0 32 | - libuv=1.40.0=h7b6447c_0 33 | - libwebp=1.2.0=h89dd481_0 34 | - libwebp-base=1.2.0=h27cfd23_0 35 | - lz4-c=1.9.3=h295c915_1 36 | - mkl=2021.3.0=h06a4308_520 37 | - mkl-service=2.4.0=py37h7f8727e_0 38 | - mkl_fft=1.3.0=py37h42c9631_2 39 | - mkl_random=1.2.2=py37h51133e4_0 40 | - ncurses=6.2=he6710b0_1 41 | - nettle=3.6=he412f7d_0 42 | - ninja=1.10.2=hff7bd54_1 43 | - olefile=0.46=py37_0 44 | - openh264=2.1.1=h780b84a_0 45 | - openjpeg=2.4.0=h3ad879b_0 46 | - openssl=1.1.1l=h7f8727e_0 47 | - pip=21.0.1=py37h06a4308_0 48 | - python=3.7.11=h12debd9_0 49 | - python_abi=3.7=2_cp37m 50 | - pytorch=1.8.0=py3.7_cuda11.1_cudnn8.0.5_0 51 | - readline=8.1=h27cfd23_0 52 | - setuptools=58.0.4=py37h06a4308_0 53 | - sqlite=3.36.0=hc218d9a_0 54 | - tk=8.6.10=hbc83047_0 55 | - torchaudio=0.8.0=py37 56 | - torchvision=0.9.0=py37_cu111 57 | - typing_extensions=3.10.0.2=pyh06a4308_0 58 | - wheel=0.37.0=pyhd3eb1b0_1 59 | - xz=5.2.5=h7b6447c_0 60 | - zlib=1.2.11=h7b6447c_3 61 | - zstd=1.4.9=haebb681_0 62 | - pip: 63 | - attrs==21.4.0 64 | - blessings==1.7 65 | - boto3==1.18.63 66 | - botocore==1.21.63 67 | - charset-normalizer==2.0.7 68 | - click==8.0.3 69 | - cycler==0.11.0 70 | - decorator==4.4.2 71 | - easydict==1.9 72 | - ffmpeg-python==0.2.0 73 | - filelock==3.4.0 74 | - 
ftfy==6.0.3 75 | - future==0.18.2 76 | - gpustat==0.6.0 77 | - huggingface-hub==0.1.2 78 | - idna==3.3 79 | - imageio==2.10.3 80 | - imageio-ffmpeg==0.4.5 81 | - importlib-metadata==4.8.2 82 | - iniconfig==1.1.1 83 | - jmespath==0.10.0 84 | - joblib==1.1.0 85 | - jsonlines==2.0.0 86 | - kiwisolver==1.3.2 87 | - lmdb==0.94 88 | - matplotlib==3.4.3 89 | - moviepy==1.0.3 90 | - msgpack==1.0.0 91 | - msgpack-numpy==0.4.4.3 92 | - munkres==1.1.4 93 | - nltk==3.6.5 94 | - numpy==1.21.4 95 | - nvidia-ml-py3==7.352.0 96 | - opencv-python==3.4.4.19 97 | - packaging==21.3 98 | - pandas==1.3.4 99 | - pillow==8.4.0 100 | - pluggy==1.0.0 101 | - proglog==0.1.9 102 | - protobuf==3.19.1 103 | - psutil==5.7.3 104 | - py==1.11.0 105 | - pyparsing==3.0.6 106 | - pytest==6.2.5 107 | - python-dateutil==2.8.2 108 | - pytz==2021.3 109 | - pyyaml==6.0 110 | - pyzmq==19.0.0 111 | - regex==2021.11.2 112 | - requests==2.26.0 113 | - s3transfer==0.5.0 114 | - sacremoses==0.0.46 115 | - scikit-learn==1.0.1 116 | - scipy==1.7.1 117 | - seaborn==0.11.2 118 | - six==1.14.0 119 | - sklearn==0.0 120 | - tabulate==0.8.6 121 | - tensorboard-logger==0.1.0 122 | - tensorboardx==2.4 123 | - tensorpack==0.9.4 124 | - termcolor==1.1.0 125 | - threadpoolctl==3.0.0 126 | - timm==0.4.12 127 | - tokenizers==0.10.3 128 | - toml==0.10.2 129 | - tqdm==4.62.3 130 | - transformers==4.12.5 131 | - urllib3==1.26.7 132 | - wcwidth==0.2.5 133 | - zipp==3.6.0 134 | prefix: /home/server/miniconda3/envs/m5 135 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaodongsuper/SCALE_code/227ae0886a0bda598495e60c4624ca894a7418bf/model/__init__.py -------------------------------------------------------------------------------- /model/cross-base/cross_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 2048, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 8, 10 | "num_hidden_layers": 6, 11 | "vocab_size": 512 12 | } -------------------------------------------------------------------------------- /model/cross-base/cross_config_temp.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 2048, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 8, 10 | "num_hidden_layers": 4 11 | } -------------------------------------------------------------------------------- /model/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 
5 | """ 6 | 7 | import os 8 | import logging 9 | import shutil 10 | import tempfile 11 | import json 12 | from urllib.parse import urlparse 13 | from pathlib import Path 14 | from typing import Optional, Tuple, Union, IO, Callable, Set 15 | from hashlib import sha256 16 | from functools import wraps 17 | 18 | from tqdm import tqdm 19 | 20 | import boto3 21 | from botocore.exceptions import ClientError 22 | import requests 23 | 24 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 25 | 26 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 27 | Path.home() / '.pytorch_pretrained_bert')) 28 | 29 | 30 | def url_to_filename(url: str, etag: str = None) -> str: 31 | """ 32 | Convert `url` into a hashed filename in a repeatable way. 33 | If `etag` is specified, append its hash to the url's, delimited 34 | by a period. 35 | """ 36 | url_bytes = url.encode('utf-8') 37 | url_hash = sha256(url_bytes) 38 | filename = url_hash.hexdigest() 39 | 40 | if etag: 41 | etag_bytes = etag.encode('utf-8') 42 | etag_hash = sha256(etag_bytes) 43 | filename += '.' + etag_hash.hexdigest() 44 | 45 | return filename 46 | 47 | 48 | def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]: 49 | """ 50 | Return the url and etag (which may be ``None``) stored for `filename`. 51 | Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist. 52 | """ 53 | if cache_dir is None: 54 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 55 | if isinstance(cache_dir, Path): 56 | cache_dir = str(cache_dir) 57 | 58 | cache_path = os.path.join(cache_dir, filename) 59 | if not os.path.exists(cache_path): 60 | raise FileNotFoundError("file {} not found".format(cache_path)) 61 | 62 | meta_path = cache_path + '.json' 63 | if not os.path.exists(meta_path): 64 | raise FileNotFoundError("file {} not found".format(meta_path)) 65 | 66 | with open(meta_path) as meta_file: 67 | metadata = json.load(meta_file) 68 | url = metadata['url'] 69 | etag = metadata['etag'] 70 | 71 | return url, etag 72 | 73 | 74 | def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str: 75 | """ 76 | Given something that might be a URL (or might be a local path), 77 | determine which. If it's a URL, download the file and cache it, and 78 | return the path to the cached file. If it's already a local path, 79 | make sure the file exists and then return the path. 80 | """ 81 | if cache_dir is None: 82 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 83 | if isinstance(url_or_filename, Path): 84 | url_or_filename = str(url_or_filename) 85 | if isinstance(cache_dir, Path): 86 | cache_dir = str(cache_dir) 87 | 88 | parsed = urlparse(url_or_filename) 89 | 90 | if parsed.scheme in ('http', 'https', 's3'): 91 | # URL, so get it from the cache (downloading if necessary) 92 | return get_from_cache(url_or_filename, cache_dir) 93 | elif os.path.exists(url_or_filename): 94 | # File, and it exists. 95 | return url_or_filename 96 | elif parsed.scheme == '': 97 | # File, but it doesn't exist. 
98 | raise FileNotFoundError("file {} not found".format(url_or_filename)) 99 | else: 100 | # Something unknown 101 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 102 | 103 | 104 | def split_s3_path(url: str) -> Tuple[str, str]: 105 | """Split a full s3 path into the bucket name and path.""" 106 | parsed = urlparse(url) 107 | if not parsed.netloc or not parsed.path: 108 | raise ValueError("bad s3 path {}".format(url)) 109 | bucket_name = parsed.netloc 110 | s3_path = parsed.path 111 | # Remove '/' at beginning of path. 112 | if s3_path.startswith("/"): 113 | s3_path = s3_path[1:] 114 | return bucket_name, s3_path 115 | 116 | 117 | def s3_request(func: Callable): 118 | """ 119 | Wrapper function for s3 requests in order to create more helpful error 120 | messages. 121 | """ 122 | 123 | @wraps(func) 124 | def wrapper(url: str, *args, **kwargs): 125 | try: 126 | return func(url, *args, **kwargs) 127 | except ClientError as exc: 128 | if int(exc.response["Error"]["Code"]) == 404: 129 | raise FileNotFoundError("file {} not found".format(url)) 130 | else: 131 | raise 132 | 133 | return wrapper 134 | 135 | 136 | @s3_request 137 | def s3_etag(url: str) -> Optional[str]: 138 | """Check ETag on S3 object.""" 139 | s3_resource = boto3.resource("s3") 140 | bucket_name, s3_path = split_s3_path(url) 141 | s3_object = s3_resource.Object(bucket_name, s3_path) 142 | return s3_object.e_tag 143 | 144 | 145 | @s3_request 146 | def s3_get(url: str, temp_file: IO) -> None: 147 | """Pull a file directly from S3.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 151 | 152 | 153 | def http_get(url: str, temp_file: IO) -> None: 154 | req = requests.get(url, stream=True) 155 | content_length = req.headers.get('Content-Length') 156 | total = int(content_length) if content_length is not None else None 157 | progress = tqdm(unit="B", total=total) 158 | for chunk in req.iter_content(chunk_size=1024): 159 | if chunk: # filter out keep-alive new chunks 160 | progress.update(len(chunk)) 161 | temp_file.write(chunk) 162 | progress.close() 163 | 164 | 165 | def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str: 166 | """ 167 | Given a URL, look for the corresponding dataset in the local cache. 168 | If it's not there, download it. Then return the path to the cached file. 169 | """ 170 | if cache_dir is None: 171 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 172 | if isinstance(cache_dir, Path): 173 | cache_dir = str(cache_dir) 174 | 175 | os.makedirs(cache_dir, exist_ok=True) 176 | 177 | # Get eTag to add to filename, if it exists. 178 | if url.startswith("s3://"): 179 | etag = s3_etag(url) 180 | else: 181 | response = requests.head(url, allow_redirects=True) 182 | if response.status_code != 200: 183 | raise IOError("HEAD request failed for url {} with status code {}" 184 | .format(url, response.status_code)) 185 | etag = response.headers.get("ETag") 186 | 187 | filename = url_to_filename(url, etag) 188 | 189 | # get cache path to put the file 190 | cache_path = os.path.join(cache_dir, filename) 191 | 192 | if not os.path.exists(cache_path): 193 | # Download to temporary file, then copy to cache dir once finished. 194 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
195 | with tempfile.NamedTemporaryFile() as temp_file: 196 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 197 | 198 | # GET file object 199 | if url.startswith("s3://"): 200 | s3_get(url, temp_file) 201 | else: 202 | http_get(url, temp_file) 203 | 204 | # we are copying the file before closing it, so flush to avoid truncation 205 | temp_file.flush() 206 | # shutil.copyfileobj() starts at the current position, so go to the start 207 | temp_file.seek(0) 208 | 209 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 210 | with open(cache_path, 'wb') as cache_file: 211 | shutil.copyfileobj(temp_file, cache_file) 212 | 213 | logger.info("creating metadata file for %s", cache_path) 214 | meta = {'url': url, 'etag': etag} 215 | meta_path = cache_path + '.json' 216 | with open(meta_path, 'w') as meta_file: 217 | json.dump(meta, meta_file) 218 | 219 | logger.info("removing temp file %s", temp_file.name) 220 | 221 | return cache_path 222 | 223 | 224 | def read_set_from_file(filename: str) -> Set[str]: 225 | ''' 226 | Extract a de-duped collection (set) of text from a file. 227 | Expected file format is one item per line. 228 | ''' 229 | collection = set() 230 | with open(filename, 'r', encoding='utf-8') as file_: 231 | for line in file_: 232 | collection.add(line.rstrip()) 233 | return collection 234 | 235 | 236 | def get_file_extension(path: str, dot=True, lower: bool = True): 237 | ext = os.path.splitext(path)[1] 238 | ext = ext if dot else ext[1:] 239 | return ext.lower() if lower else ext 240 | -------------------------------------------------------------------------------- /model/module_cross.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import copy 7 | import json 8 | import math 9 | import logging 10 | import tarfile 11 | import tempfile 12 | import shutil 13 | 14 | import torch 15 | from torch import nn 16 | import torch.nn.functional as F 17 | from .file_utils import cached_path 18 | from .until_config import PretrainedConfig 19 | from .until_module import PreTrainedModel, LayerNorm, ACT2FN 20 | from collections import OrderedDict 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | PRETRAINED_MODEL_ARCHIVE_MAP = {} 25 | CONFIG_NAME = 'cross_config.json' 26 | WEIGHTS_NAME = 'cross_pytorch_model.bin' 27 | 28 | 29 | class CrossConfig(PretrainedConfig): 30 | """Configuration class to store the configuration of a `CrossModel`. 31 | """ 32 | pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP 33 | config_name = CONFIG_NAME 34 | weights_name = WEIGHTS_NAME 35 | def __init__(self, 36 | vocab_size_or_config_json_file, 37 | hidden_size=768, 38 | num_hidden_layers=12, 39 | num_attention_heads=12, 40 | intermediate_size=3072, 41 | hidden_act="gelu", 42 | hidden_dropout_prob=0.1, 43 | attention_probs_dropout_prob=0.1, 44 | max_position_embeddings=512, 45 | type_vocab_size=2, 46 | initializer_range=0.02): 47 | """Constructs CrossConfig. 48 | 49 | Args: 50 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CrossModel`. 51 | hidden_size: Size of the encoder layers and the pooler layer. 52 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 53 | num_attention_heads: Number of attention heads for each attention layer in 54 | the Transformer encoder. 
55 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 56 | layer in the Transformer encoder. 57 | hidden_act: The non-linear activation function (function or string) in the 58 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 59 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 60 | layers in the embeddings, encoder, and pooler. 61 | attention_probs_dropout_prob: The dropout ratio for the attention 62 | probabilities. 63 | max_position_embeddings: The maximum sequence length that this model might 64 | ever be used with. Typically set this to something large just in case 65 | (e.g., 512 or 1024 or 2048). 66 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 67 | `CrossModel`. 68 | initializer_range: The sttdev of the truncated_normal_initializer for 69 | initializing all weight matrices. 70 | """ 71 | if isinstance(vocab_size_or_config_json_file, str): 72 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 73 | json_config = json.loads(reader.read()) 74 | for key, value in json_config.items(): 75 | self.__dict__[key] = value 76 | elif isinstance(vocab_size_or_config_json_file, int): 77 | self.vocab_size = vocab_size_or_config_json_file 78 | self.hidden_size = hidden_size 79 | self.num_hidden_layers = num_hidden_layers 80 | self.num_attention_heads = num_attention_heads 81 | self.hidden_act = hidden_act 82 | self.intermediate_size = intermediate_size 83 | self.hidden_dropout_prob = hidden_dropout_prob 84 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 85 | self.max_position_embeddings = max_position_embeddings 86 | self.type_vocab_size = type_vocab_size 87 | self.initializer_range = initializer_range 88 | else: 89 | raise ValueError("First argument must be either a vocabulary size (int)" 90 | "or the path to a pretrained model config file (str)") 91 | 92 | class QuickGELU(nn.Module): 93 | def forward(self, x: torch.Tensor): 94 | return x * torch.sigmoid(1.702 * x) 95 | 96 | class ResidualAttentionBlock(nn.Module): 97 | def __init__(self, d_model: int, n_head: int): 98 | super().__init__() 99 | 100 | self.attn = nn.MultiheadAttention(d_model, n_head) 101 | self.ln_1 = LayerNorm(d_model) 102 | self.mlp = nn.Sequential(OrderedDict([ 103 | ("c_fc", nn.Linear(d_model, d_model * 4)), 104 | ("gelu", QuickGELU()), 105 | ("c_proj", nn.Linear(d_model * 4, d_model)) 106 | ])) 107 | self.ln_2 = LayerNorm(d_model) 108 | self.n_head = n_head 109 | 110 | def attention(self, x: torch.Tensor, attn_mask: torch.Tensor): 111 | attn_mask_ = attn_mask.repeat(self.n_head, 1, 1) 112 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] 113 | 114 | def forward(self, para_tuple: tuple): 115 | # x: torch.Tensor, attn_mask: torch.Tensor 116 | # print(para_tuple) 117 | x, attn_mask = para_tuple 118 | x = x + self.attention(self.ln_1(x), attn_mask) 119 | x = x + self.mlp(self.ln_2(x)) 120 | return (x, attn_mask) 121 | 122 | class Transformer(nn.Module): 123 | def __init__(self, width: int, layers: int, heads: int): 124 | super().__init__() 125 | self.width = width 126 | self.layers = layers 127 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads) for _ in range(layers)]) 128 | 129 | def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): 130 | return self.resblocks((x, attn_mask))[0] 131 | 132 | class CrossEmbeddings(nn.Module): 133 | """Construct the embeddings from word, position and token_type embeddings. 
134 | """ 135 | def __init__(self, config): 136 | super(CrossEmbeddings, self).__init__() 137 | 138 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 139 | # self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 140 | # self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12) 141 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 142 | 143 | def forward(self, concat_embeddings, concat_type=None): 144 | 145 | batch_size, seq_length = concat_embeddings.size(0), concat_embeddings.size(1) 146 | # if concat_type is None: 147 | # concat_type = torch.zeros(batch_size, concat_type).to(concat_embeddings.device) 148 | 149 | position_ids = torch.arange(seq_length, dtype=torch.long, device=concat_embeddings.device) 150 | position_ids = position_ids.unsqueeze(0).expand(concat_embeddings.size(0), -1) 151 | 152 | # token_type_embeddings = self.token_type_embeddings(concat_type) 153 | position_embeddings = self.position_embeddings(position_ids) 154 | 155 | embeddings = concat_embeddings + position_embeddings # + token_type_embeddings 156 | # embeddings = self.LayerNorm(embeddings) 157 | embeddings = self.dropout(embeddings) 158 | return embeddings 159 | 160 | class CrossPooler(nn.Module): 161 | def __init__(self, config): 162 | super(CrossPooler, self).__init__() 163 | self.ln_pool = LayerNorm(config.hidden_size) 164 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 165 | self.activation = QuickGELU() 166 | 167 | def forward(self, hidden_states, hidden_mask): 168 | # We "pool" the model by simply taking the hidden state corresponding 169 | # to the first token. 170 | hidden_states = self.ln_pool(hidden_states) 171 | pooled_output = hidden_states[:, 0] 172 | pooled_output = self.dense(pooled_output) 173 | pooled_output = self.activation(pooled_output) 174 | return pooled_output 175 | 176 | class CrossModel(PreTrainedModel): 177 | 178 | def initialize_parameters(self): 179 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 180 | attn_std = self.transformer.width ** -0.5 181 | fc_std = (2 * self.transformer.width) ** -0.5 182 | for block in self.transformer.resblocks: 183 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 184 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 185 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 186 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 187 | 188 | def __init__(self, config): 189 | super(CrossModel, self).__init__(config) 190 | 191 | self.embeddings = CrossEmbeddings(config) 192 | 193 | transformer_width = config.hidden_size 194 | transformer_layers = config.num_hidden_layers 195 | transformer_heads = config.num_attention_heads 196 | self.transformer = Transformer(width=transformer_width, layers=transformer_layers, heads=transformer_heads,) 197 | self.pooler = CrossPooler(config) 198 | self.apply(self.init_weights) 199 | 200 | def build_attention_mask(self, attention_mask): 201 | extended_attention_mask = attention_mask.unsqueeze(1) 202 | extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility 203 | extended_attention_mask = (1.0 - extended_attention_mask) * -1000000.0 204 | extended_attention_mask = extended_attention_mask.expand(-1, attention_mask.size(1), -1) 205 | return extended_attention_mask 206 | 207 | def forward(self, concat_input, concat_type=None, attention_mask=None): 208 | 209 | if attention_mask is None: 210 | attention_mask = 
torch.ones(concat_input.size(0), concat_input.size(1)) 211 | if concat_type is None: 212 | concat_type = torch.zeros_like(attention_mask) 213 | 214 | extended_attention_mask = self.build_attention_mask(attention_mask) 215 | 216 | embedding_output = self.embeddings(concat_input, concat_type) 217 | embedding_output = embedding_output.permute(1, 0, 2) # NLD -> LND 218 | embedding_output = self.transformer(embedding_output, extended_attention_mask) 219 | embedding_output = embedding_output.permute(1, 0, 2) # LND -> NLD 220 | 221 | pooled_output = self.pooler(embedding_output, hidden_mask=attention_mask) 222 | 223 | return embedding_output, pooled_output 224 | -------------------------------------------------------------------------------- /model/until_config.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch BERT model.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import copy 24 | import json 25 | import logging 26 | import tarfile 27 | import tempfile 28 | import shutil 29 | import torch 30 | from .file_utils import cached_path 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | class PretrainedConfig(object): 35 | 36 | pretrained_model_archive_map = {} 37 | config_name = "" 38 | weights_name = "" 39 | 40 | @classmethod 41 | def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None): 42 | archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name) 43 | if os.path.exists(archive_file) is False: 44 | if pretrained_model_name in cls.pretrained_model_archive_map: 45 | archive_file = cls.pretrained_model_archive_map[pretrained_model_name] 46 | else: 47 | archive_file = pretrained_model_name 48 | 49 | print("archive_file: ",archive_file) 50 | # redirect to the cache, if necessary 51 | try: 52 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) 53 | except FileNotFoundError: 54 | if task_config is None or task_config.local_rank == 0: 55 | logger.error( 56 | "Model name '{}' was not found in model name list. 
" 57 | "We assumed '{}' was a path or url but couldn't find any file " 58 | "associated to this path or url.".format( 59 | pretrained_model_name, 60 | archive_file)) 61 | return None 62 | if resolved_archive_file == archive_file: 63 | if task_config is None or task_config.local_rank == 0: 64 | logger.info("loading archive file {}".format(archive_file)) 65 | else: 66 | if task_config is None or task_config.local_rank == 0: 67 | logger.info("loading archive file {} from cache at {}".format( 68 | archive_file, resolved_archive_file)) 69 | tempdir = None 70 | if os.path.isdir(resolved_archive_file): 71 | serialization_dir = resolved_archive_file 72 | else: 73 | # Extract archive to temp dir 74 | tempdir = tempfile.mkdtemp() 75 | if task_config is None or task_config.local_rank == 0: 76 | logger.info("extracting archive file {} to temp dir {}".format( 77 | resolved_archive_file, tempdir)) 78 | with tarfile.open(resolved_archive_file, 'r:gz') as archive: 79 | archive.extractall(tempdir) 80 | serialization_dir = tempdir 81 | # Load config 82 | config_file = os.path.join(serialization_dir, cls.config_name) 83 | config = cls.from_json_file(config_file) 84 | config.type_vocab_size = type_vocab_size 85 | if task_config is None or task_config.local_rank == 0: 86 | logger.info("Model config {}".format(config)) 87 | 88 | if state_dict is None: 89 | weights_path = os.path.join(serialization_dir, cls.weights_name) 90 | if os.path.exists(weights_path): 91 | state_dict = torch.load(weights_path, map_location='cpu') 92 | else: 93 | if task_config is None or task_config.local_rank == 0: 94 | logger.info("Weight doesn't exsits. {}".format(weights_path)) 95 | 96 | if tempdir: 97 | # Clean up temp dir 98 | shutil.rmtree(tempdir) 99 | 100 | return config, state_dict 101 | 102 | @classmethod 103 | def from_dict(cls, json_object): 104 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 105 | # config = cls(vocab_size_or_config_json_file=-1) 106 | config = cls() 107 | for key, value in json_object.items(): 108 | config.__dict__[key] = value 109 | return config 110 | 111 | @classmethod 112 | def from_json_file(cls, json_file): 113 | """Constructs a `BertConfig` from a json file of parameters.""" 114 | with open(json_file, "r", encoding='utf-8') as reader: 115 | text = reader.read() 116 | return cls.from_dict(json.loads(text)) 117 | 118 | def __repr__(self): 119 | return str(self.to_json_string()) 120 | 121 | def to_dict(self): 122 | """Serializes this instance to a Python dictionary.""" 123 | output = copy.deepcopy(self.__dict__) 124 | return output 125 | 126 | def to_json_string(self): 127 | """Serializes this instance to a JSON string.""" 128 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.6.2" 2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer 3 | from .tokenization_openai import OpenAIGPTTokenizer 4 | from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) 5 | from .tokenization_gpt2 import GPT2Tokenizer 6 | 7 | from .modeling import (BertConfig, BertModel, BertForPreTraining, 8 | BertForMaskedLM, BertForNextSentencePrediction, 9 | BertForSequenceClassification, BertForMultipleChoice, 10 | BertForTokenClassification, BertForQuestionAnswering, 11 | load_tf_weights_in_bert) 12 | from .modeling_openai import 
(OpenAIGPTConfig, OpenAIGPTModel, 13 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, 14 | load_tf_weights_in_openai_gpt) 15 | from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel, 16 | load_tf_weights_in_transfo_xl) 17 | from .modeling_gpt2 import (GPT2Config, GPT2Model, 18 | GPT2LMHeadModel, GPT2DoubleHeadsModel, 19 | load_tf_weights_in_gpt2) 20 | 21 | from .optimization import BertAdam 22 | from .optimization_openai import OpenAIAdam 23 | 24 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME 25 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [ 5 | "convert_tf_checkpoint_to_pytorch", 6 | "convert_openai_checkpoint", 7 | "convert_transfo_xl_checkpoint", 8 | "convert_gpt2_checkpoint", 9 | ]: 10 | print( 11 | "Should be used as one of: \n" 12 | ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n" 13 | ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n" 14 | ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n" 15 | ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`") 16 | else: 17 | if sys.argv[1] == "convert_tf_checkpoint_to_pytorch": 18 | try: 19 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 20 | except ImportError: 21 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 22 | "In that case, it requires TensorFlow to be installed. Please see " 23 | "https://www.tensorflow.org/install/ for installation instructions.") 24 | raise 25 | 26 | if len(sys.argv) != 5: 27 | # pylint: disable=line-too-long 28 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 29 | else: 30 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 31 | TF_CONFIG = sys.argv.pop() 32 | TF_CHECKPOINT = sys.argv.pop() 33 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 34 | elif sys.argv[1] == "convert_openai_checkpoint": 35 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch 36 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] 37 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 38 | if len(sys.argv) == 5: 39 | OPENAI_GPT_CONFIG = sys.argv[4] 40 | else: 41 | OPENAI_GPT_CONFIG = "" 42 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, 43 | OPENAI_GPT_CONFIG, 44 | PYTORCH_DUMP_OUTPUT) 45 | elif sys.argv[1] == "convert_transfo_xl_checkpoint": 46 | try: 47 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch 48 | except ImportError: 49 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 50 | "In that case, it requires TensorFlow to be installed. 
Please see " 51 | "https://www.tensorflow.org/install/ for installation instructions.") 52 | raise 53 | 54 | if 'ckpt' in sys.argv[2].lower(): 55 | TF_CHECKPOINT = sys.argv[2] 56 | TF_DATASET_FILE = "" 57 | else: 58 | TF_DATASET_FILE = sys.argv[2] 59 | TF_CHECKPOINT = "" 60 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 61 | if len(sys.argv) == 5: 62 | TF_CONFIG = sys.argv[4] 63 | else: 64 | TF_CONFIG = "" 65 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) 66 | else: 67 | try: 68 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch 69 | except ImportError: 70 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 71 | "In that case, it requires TensorFlow to be installed. Please see " 72 | "https://www.tensorflow.org/install/ for installation instructions.") 73 | raise 74 | 75 | TF_CHECKPOINT = sys.argv[2] 76 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 77 | if len(sys.argv) == 5: 78 | TF_CONFIG = sys.argv[4] 79 | else: 80 | TF_CONFIG = "" 81 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME, 25 | GPT2Config, 26 | GPT2Model, 27 | load_tf_weights_in_gpt2) 28 | 29 | 30 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if gpt2_config_file == "": 33 | config = GPT2Config() 34 | else: 35 | config = GPT2Config(gpt2_config_file) 36 | model = GPT2Model(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_gpt2(model, gpt2_checkpoint_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--gpt2_checkpoint_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--gpt2_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, 71 | args.gpt2_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME, 25 | OpenAIGPTConfig, 26 | OpenAIGPTModel, 27 | load_tf_weights_in_openai_gpt) 28 | 29 | 30 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if openai_config_file == "": 33 | config = OpenAIGPTConfig() 34 | else: 35 | config = OpenAIGPTConfig(openai_config_file) 36 | model = OpenAIGPTModel(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--openai_checkpoint_folder_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--openai_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, 71 | args.openai_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import re 23 | import argparse 24 | import tensorflow as tf 25 | import torch 26 | import numpy as np 27 | 28 | from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert 29 | 30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 31 | # Initialise PyTorch model 32 | config = BertConfig.from_json_file(bert_config_file) 33 | print("Building PyTorch model from configuration: {}".format(str(config))) 34 | model = BertForPreTraining(config) 35 | 36 | # Load weights from tf checkpoint 37 | load_tf_weights_in_bert(model, tf_checkpoint_path) 38 | 39 | # Save pytorch-model 40 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 41 | torch.save(model.state_dict(), pytorch_dump_path) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | ## Required parameters 47 | parser.add_argument("--tf_checkpoint_path", 48 | default = None, 49 | type = str, 50 | required = True, 51 | help = "Path the TensorFlow checkpoint path.") 52 | parser.add_argument("--bert_config_file", 53 | default = None, 54 | type = str, 55 | required = True, 56 | help = "The config json file corresponding to the pre-trained BERT model. \n" 57 | "This specifies the model architecture.") 58 | parser.add_argument("--pytorch_dump_path", 59 | default = None, 60 | type = str, 61 | required = True, 62 | help = "Path to the output PyTorch model.") 63 | args = parser.parse_args() 64 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 65 | args.bert_config_file, 66 | args.pytorch_dump_path) 67 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils 27 | from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME, 28 | WEIGHTS_NAME, 29 | TransfoXLConfig, 30 | TransfoXLLMHeadModel, 31 | load_tf_weights_in_transfo_xl) 32 | from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME, 33 | VOCAB_NAME) 34 | 35 | if sys.version_info[0] == 2: 36 | import cPickle as pickle 37 | else: 38 | import pickle 39 | 40 | # We do this to be able to load python 2 datasets pickles 41 | # See e.g. 
https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 42 | data_utils.Vocab = data_utils.TransfoXLTokenizer 43 | data_utils.Corpus = data_utils.TransfoXLCorpus 44 | sys.modules['data_utils'] = data_utils 45 | sys.modules['vocabulary'] = data_utils 46 | 47 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, 48 | transfo_xl_config_file, 49 | pytorch_dump_folder_path, 50 | transfo_xl_dataset_file): 51 | if transfo_xl_dataset_file: 52 | # Convert a pre-processed corpus (see original TensorFlow repo) 53 | with open(transfo_xl_dataset_file, "rb") as fp: 54 | corpus = pickle.load(fp, encoding="latin1") 55 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 56 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME 57 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 58 | corpus_vocab_dict = corpus.vocab.__dict__ 59 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 60 | 61 | corpus_dict_no_vocab = corpus.__dict__ 62 | corpus_dict_no_vocab.pop('vocab', None) 63 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME 64 | print("Save dataset to {}".format(pytorch_dataset_dump_path)) 65 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 66 | 67 | if tf_checkpoint_path: 68 | # Convert a pre-trained TensorFlow model 69 | config_path = os.path.abspath(transfo_xl_config_file) 70 | tf_path = os.path.abspath(tf_checkpoint_path) 71 | 72 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 73 | # Initialise PyTorch model 74 | if transfo_xl_config_file == "": 75 | config = TransfoXLConfig() 76 | else: 77 | config = TransfoXLConfig(transfo_xl_config_file) 78 | print("Building PyTorch model from configuration: {}".format(str(config))) 79 | model = TransfoXLLMHeadModel(config) 80 | 81 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 82 | # Save pytorch-model 83 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 84 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 85 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 86 | torch.save(model.state_dict(), pytorch_weights_dump_path) 87 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 88 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 89 | f.write(config.to_json_string()) 90 | 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("--pytorch_dump_folder_path", 95 | default = None, 96 | type = str, 97 | required = True, 98 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 99 | parser.add_argument("--tf_checkpoint_path", 100 | default = "", 101 | type = str, 102 | help = "An optional path to a TensorFlow checkpoint path to be converted.") 103 | parser.add_argument("--transfo_xl_config_file", 104 | default = "", 105 | type = str, 106 | help = "An optional config json file corresponding to the pre-trained BERT model. 
\n" 107 | "This specifies the model architecture.") 108 | parser.add_argument("--transfo_xl_dataset_file", 109 | default = "", 110 | type = str, 111 | help = "An optional dataset file to be converted in a vocabulary.") 112 | args = parser.parse_args() 113 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 114 | args.transfo_xl_config_file, 115 | args.pytorch_dump_folder_path, 116 | args.transfo_xl_dataset_file) 117 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import sys 9 | import json 10 | import logging 11 | import os 12 | import shutil 13 | import tempfile 14 | import fnmatch 15 | from functools import wraps 16 | from hashlib import sha256 17 | import sys 18 | from io import open 19 | 20 | import boto3 21 | import requests 22 | from botocore.exceptions import ClientError 23 | from tqdm import tqdm 24 | 25 | try: 26 | from urllib.parse import urlparse 27 | except ImportError: 28 | from urlparse import urlparse 29 | 30 | try: 31 | from pathlib import Path 32 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 33 | Path.home() / '.pytorch_pretrained_bert')) 34 | except (AttributeError, ImportError): 35 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 36 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) 37 | 38 | CONFIG_NAME = "config.json" 39 | WEIGHTS_NAME = "pytorch_model.bin" 40 | 41 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 42 | 43 | 44 | def url_to_filename(url, etag=None): 45 | """ 46 | Convert `url` into a hashed filename in a repeatable way. 47 | If `etag` is specified, append its hash to the url's, delimited 48 | by a period. 49 | """ 50 | url_bytes = url.encode('utf-8') 51 | url_hash = sha256(url_bytes) 52 | filename = url_hash.hexdigest() 53 | 54 | if etag: 55 | etag_bytes = etag.encode('utf-8') 56 | etag_hash = sha256(etag_bytes) 57 | filename += '.' + etag_hash.hexdigest() 58 | 59 | return filename 60 | 61 | 62 | def filename_to_url(filename, cache_dir=None): 63 | """ 64 | Return the url and etag (which may be ``None``) stored for `filename`. 65 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 66 | """ 67 | if cache_dir is None: 68 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 69 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 70 | cache_dir = str(cache_dir) 71 | 72 | cache_path = os.path.join(cache_dir, filename) 73 | if not os.path.exists(cache_path): 74 | raise EnvironmentError("file {} not found".format(cache_path)) 75 | 76 | meta_path = cache_path + '.json' 77 | if not os.path.exists(meta_path): 78 | raise EnvironmentError("file {} not found".format(meta_path)) 79 | 80 | with open(meta_path, encoding="utf-8") as meta_file: 81 | metadata = json.load(meta_file) 82 | url = metadata['url'] 83 | etag = metadata['etag'] 84 | 85 | return url, etag 86 | 87 | 88 | def cached_path(url_or_filename, cache_dir=None): 89 | """ 90 | Given something that might be a URL (or might be a local path), 91 | determine which. 
If it's a URL, download the file and cache it, and 92 | return the path to the cached file. If it's already a local path, 93 | make sure the file exists and then return the path. 94 | """ 95 | if cache_dir is None: 96 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 97 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 98 | url_or_filename = str(url_or_filename) 99 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 100 | cache_dir = str(cache_dir) 101 | 102 | parsed = urlparse(url_or_filename) 103 | 104 | if parsed.scheme in ('http', 'https', 's3'): 105 | # URL, so get it from the cache (downloading if necessary) 106 | return get_from_cache(url_or_filename, cache_dir) 107 | elif os.path.exists(url_or_filename): 108 | # File, and it exists. 109 | return url_or_filename 110 | elif parsed.scheme == '': 111 | # File, but it doesn't exist. 112 | raise EnvironmentError("file {} not found".format(url_or_filename)) 113 | else: 114 | # Something unknown 115 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 116 | 117 | 118 | def split_s3_path(url): 119 | """Split a full s3 path into the bucket name and path.""" 120 | parsed = urlparse(url) 121 | if not parsed.netloc or not parsed.path: 122 | raise ValueError("bad s3 path {}".format(url)) 123 | bucket_name = parsed.netloc 124 | s3_path = parsed.path 125 | # Remove '/' at beginning of path. 126 | if s3_path.startswith("/"): 127 | s3_path = s3_path[1:] 128 | return bucket_name, s3_path 129 | 130 | 131 | def s3_request(func): 132 | """ 133 | Wrapper function for s3 requests in order to create more helpful error 134 | messages. 135 | """ 136 | 137 | @wraps(func) 138 | def wrapper(url, *args, **kwargs): 139 | try: 140 | return func(url, *args, **kwargs) 141 | except ClientError as exc: 142 | if int(exc.response["Error"]["Code"]) == 404: 143 | raise EnvironmentError("file {} not found".format(url)) 144 | else: 145 | raise 146 | 147 | return wrapper 148 | 149 | 150 | @s3_request 151 | def s3_etag(url): 152 | """Check ETag on S3 object.""" 153 | s3_resource = boto3.resource("s3") 154 | bucket_name, s3_path = split_s3_path(url) 155 | s3_object = s3_resource.Object(bucket_name, s3_path) 156 | return s3_object.e_tag 157 | 158 | 159 | @s3_request 160 | def s3_get(url, temp_file): 161 | """Pull a file directly from S3.""" 162 | s3_resource = boto3.resource("s3") 163 | bucket_name, s3_path = split_s3_path(url) 164 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 165 | 166 | 167 | def http_get(url, temp_file): 168 | req = requests.get(url, stream=True) 169 | content_length = req.headers.get('Content-Length') 170 | total = int(content_length) if content_length is not None else None 171 | progress = tqdm(unit="B", total=total) 172 | for chunk in req.iter_content(chunk_size=1024): 173 | if chunk: # filter out keep-alive new chunks 174 | progress.update(len(chunk)) 175 | temp_file.write(chunk) 176 | progress.close() 177 | 178 | 179 | def get_from_cache(url, cache_dir=None): 180 | """ 181 | Given a URL, look for the corresponding dataset in the local cache. 182 | If it's not there, download it. Then return the path to the cached file. 183 | """ 184 | if cache_dir is None: 185 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 186 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 187 | cache_dir = str(cache_dir) 188 | 189 | if not os.path.exists(cache_dir): 190 | os.makedirs(cache_dir) 191 | 192 | # Get eTag to add to filename, if it exists. 
193 | if url.startswith("s3://"): 194 | etag = s3_etag(url) 195 | else: 196 | try: 197 | response = requests.head(url, allow_redirects=True) 198 | if response.status_code != 200: 199 | etag = None 200 | else: 201 | etag = response.headers.get("ETag") 202 | except EnvironmentError: 203 | etag = None 204 | 205 | if sys.version_info[0] == 2 and etag is not None: 206 | etag = etag.decode('utf-8') 207 | filename = url_to_filename(url, etag) 208 | 209 | # get cache path to put the file 210 | cache_path = os.path.join(cache_dir, filename) 211 | 212 | # If we don't have a connection (etag is None) and can't identify the file 213 | # try to get the last downloaded one 214 | if not os.path.exists(cache_path) and etag is None: 215 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') 216 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) 217 | if matching_files: 218 | cache_path = os.path.join(cache_dir, matching_files[-1]) 219 | 220 | if not os.path.exists(cache_path): 221 | # Download to temporary file, then copy to cache dir once finished. 222 | # Otherwise you get corrupt cache entries if the download gets interrupted. 223 | with tempfile.NamedTemporaryFile() as temp_file: 224 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 225 | 226 | # GET file object 227 | if url.startswith("s3://"): 228 | s3_get(url, temp_file) 229 | else: 230 | http_get(url, temp_file) 231 | 232 | # we are copying the file before closing it, so flush to avoid truncation 233 | temp_file.flush() 234 | # shutil.copyfileobj() starts at the current position, so go to the start 235 | temp_file.seek(0) 236 | 237 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 238 | with open(cache_path, 'wb') as cache_file: 239 | shutil.copyfileobj(temp_file, cache_file) 240 | 241 | logger.info("creating metadata file for %s", cache_path) 242 | meta = {'url': url, 'etag': etag} 243 | meta_path = cache_path + '.json' 244 | with open(meta_path, 'w') as meta_file: 245 | output_string = json.dumps(meta) 246 | if sys.version_info[0] == 2 and isinstance(output_string, str): 247 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2 248 | meta_file.write(output_string) 249 | 250 | logger.info("removing temp file %s", temp_file.name) 251 | 252 | return cache_path 253 | 254 | 255 | def read_set_from_file(filename): 256 | ''' 257 | Extract a de-duped collection (set) of text from a file. 258 | Expected file format is one item per line. 259 | ''' 260 | collection = set() 261 | with open(filename, 'r', encoding='utf-8') as file_: 262 | for line in file_: 263 | collection.add(line.rstrip()) 264 | return collection 265 | 266 | 267 | def get_file_extension(path, dot=True, lower=True): 268 | ext = os.path.splitext(path)[1] 269 | ext = ext if dot else ext[1:] 270 | return ext.lower() if lower else ext 271 | 272 | 273 | 274 | 275 | 276 | 277 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/optimization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for OpenAI GPT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \ 24 | WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class OpenAIAdam(Optimizer): 30 | """Implements Open AI version of Adam algorithm with weight decay fix. 31 | """ 32 | def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1, 33 | b1=0.9, b2=0.999, e=1e-8, weight_decay=0, 34 | vector_l2=False, max_grad_norm=-1, **kwargs): 35 | if lr is not required and lr < 0.0: 36 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 37 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: 38 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 39 | if not 0.0 <= b1 < 1.0: 40 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 41 | if not 0.0 <= b2 < 1.0: 42 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 43 | if not e >= 0.0: 44 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 45 | # initialize schedule object 46 | if not isinstance(schedule, _LRSchedule): 47 | schedule_type = SCHEDULES[schedule] 48 | schedule = schedule_type(warmup=warmup, t_total=t_total) 49 | else: 50 | if warmup != -1 or t_total != -1: 51 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. " 52 | "Please specify custom warmup and t_total in _LRSchedule object.") 53 | defaults = dict(lr=lr, schedule=schedule, 54 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2, 55 | max_grad_norm=max_grad_norm) 56 | super(OpenAIAdam, self).__init__(params, defaults) 57 | 58 | def get_lr(self): 59 | lr = [] 60 | for group in self.param_groups: 61 | for p in group['params']: 62 | state = self.state[p] 63 | if len(state) == 0: 64 | return [0] 65 | lr_scheduled = group['lr'] 66 | lr_scheduled *= group['schedule'].get_lr(state['step']) 67 | lr.append(lr_scheduled) 68 | return lr 69 | 70 | def step(self, closure=None): 71 | """Performs a single optimization step. 72 | 73 | Arguments: 74 | closure (callable, optional): A closure that reevaluates the model 75 | and returns the loss. 
76 | """ 77 | loss = None 78 | if closure is not None: 79 | loss = closure() 80 | 81 | for group in self.param_groups: 82 | for p in group['params']: 83 | if p.grad is None: 84 | continue 85 | grad = p.grad.data 86 | if grad.is_sparse: 87 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 88 | 89 | state = self.state[p] 90 | 91 | # State initialization 92 | if len(state) == 0: 93 | state['step'] = 0 94 | # Exponential moving average of gradient values 95 | state['exp_avg'] = torch.zeros_like(p.data) 96 | # Exponential moving average of squared gradient values 97 | state['exp_avg_sq'] = torch.zeros_like(p.data) 98 | 99 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 100 | beta1, beta2 = group['b1'], group['b2'] 101 | 102 | state['step'] += 1 103 | 104 | # Add grad clipping 105 | if group['max_grad_norm'] > 0: 106 | clip_grad_norm_(p, group['max_grad_norm']) 107 | 108 | # Decay the first and second moment running average coefficient 109 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 110 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 111 | denom = exp_avg_sq.sqrt().add_(group['e']) 112 | 113 | bias_correction1 = 1 - beta1 ** state['step'] 114 | bias_correction2 = 1 - beta2 ** state['step'] 115 | 116 | lr_scheduled = group['lr'] 117 | lr_scheduled *= group['schedule'].get_lr(state['step']) 118 | 119 | step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 120 | 121 | p.data.addcdiv_(-step_size, exp_avg, denom) 122 | 123 | # Add weight decay at the end (fixed version) 124 | if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0: 125 | p.data.add_(-lr_scheduled * group['weight_decay'], p.data) 126 | 127 | return loss 128 | -------------------------------------------------------------------------------- /retrieval_unit_id_list_v2.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ITP3V可以使用这个 3 | ''' 4 | import numpy as np 5 | import os 6 | import heapq 7 | from tqdm import tqdm 8 | import argparse 9 | import pickle 10 | import json 11 | 12 | def read_json(file): 13 | f = open(file, "r", encoding="utf-8").read() 14 | return json.loads(f) 15 | 16 | 17 | def write_json(file, data): 18 | f = open(file, "w", encoding="utf-8") 19 | json.dump(data, f, indent=2, ensure_ascii=False) 20 | return 21 | 22 | def read_pickle(filename): 23 | return pickle.loads(open(filename,"rb").read()) 24 | 25 | def write_pickle(filename,data): 26 | open(filename,"wb").write(pickle.dumps(data)) 27 | return 28 | 29 | 30 | def parse_args(): 31 | 32 | parser = argparse.ArgumentParser() 33 | 34 | parser.add_argument( 35 | "--query_feature_path",type=str 36 | ) 37 | parser.add_argument( 38 | "--gallery_feature_path",type=str 39 | ) 40 | parser.add_argument("--retrieval_results_path",type=str) 41 | parser.add_argument("--t",action="store_true") 42 | parser.add_argument("--p",action="store_true") 43 | parser.add_argument("--i",action="store_true") 44 | parser.add_argument("--v",action="store_true") 45 | parser.add_argument("--a",action="store_true") 46 | 47 | parser.add_argument("--tp",action="store_true") 48 | parser.add_argument("--ti", action="store_true") 49 | parser.add_argument("--tv", action="store_true") 50 | parser.add_argument("--pi", action="store_true") 51 | parser.add_argument("--pv", action="store_true") 52 | parser.add_argument("--iv", action="store_true") 53 | parser.add_argument("--ta", action="store_true") 54 | parser.add_argument("--pa", action="store_true") 55 
| parser.add_argument("--ia", action="store_true") 56 | parser.add_argument("--va", action="store_true") 57 | 58 | parser.add_argument("--tpi", action="store_true") 59 | parser.add_argument("--tpv", action="store_true") 60 | parser.add_argument("--tiv", action="store_true") 61 | parser.add_argument("--piv", action="store_true") 62 | 63 | parser.add_argument("--tpiv", action="store_true") 64 | parser.add_argument("--tpiva", action="store_true") 65 | parser.add_argument("--dense",action="store_true") 66 | 67 | parser.add_argument( 68 | "--max_topk",type=int,default=110 69 | ) 70 | 71 | return parser.parse_args() 72 | 73 | 74 | def read_feature(query_feature_txt): 75 | query_id = [] 76 | query_feature=[] 77 | for each in tqdm(query_feature_txt): 78 | each_split = each.split(",") 79 | item_id = each_split[0] 80 | each_feature = [float(i) for i in each_split[1:]] 81 | 82 | query_id.append(item_id) 83 | query_feature.append(each_feature) 84 | return query_id,query_feature 85 | 86 | 87 | if __name__ == '__main__': 88 | print() 89 | 90 | args=parse_args() 91 | 92 | save_path=args.retrieval_results_path 93 | if not os.path.exists(save_path): 94 | os.makedirs(save_path) 95 | 96 | feature_type = [] 97 | if args.t: feature_type.append("t") 98 | if args.p: feature_type.append("p") 99 | if args.i: feature_type.append("i") 100 | if args.v: feature_type.append("v") 101 | if args.a: feature_type.append("a") 102 | 103 | if args.tp: feature_type.append("tp") 104 | if args.ti: feature_type.append("ti") 105 | if args.tv: feature_type.append("tv") 106 | if args.pi: feature_type.append("pi") 107 | if args.pv: feature_type.append("pv") 108 | if args.iv: feature_type.append("iv") 109 | if args.ta: feature_type.append("ta") 110 | if args.pa: feature_type.append("pa") 111 | if args.ia: feature_type.append("ia") 112 | if args.va: feature_type.append("va") 113 | 114 | if args.tpi: feature_type.append("tpi") 115 | if args.tpv: feature_type.append("tpv") 116 | if args.tiv: feature_type.append("tiv") 117 | if args.piv: feature_type.append("piv") 118 | 119 | if args.tpiv: feature_type.append("tpiv") 120 | if args.tpiva: feature_type.append("tpiva") 121 | if args.dense: feature_type.append("dense") 122 | 123 | 124 | for each_feature_type in feature_type: 125 | 126 | query_dir = args.query_feature_path 127 | gallery_dir = args.gallery_feature_path 128 | 129 | save_file=open("{}/{}_feature_retrieval_id_list.txt".format(save_path,each_feature_type),"w") 130 | 131 | # new 132 | gallery_ids = np.load("{}/id.npy".format(gallery_dir)) 133 | gallery_ids = np.hstack(gallery_ids) 134 | 135 | query_ids = np.load("{}/id.npy".format(query_dir)) 136 | query_ids = np.hstack(query_ids) 137 | 138 | gallery_feature_np = np.load("{}/{}_feature_np.npy".format(gallery_dir, each_feature_type)) 139 | print(gallery_feature_np.shape) 140 | 141 | query_feature_np = np.load("{}/{}_feature_np.npy".format(query_dir, each_feature_type)) 142 | print(query_feature_np.shape) 143 | 144 | # query_id=query_id[:100] 145 | # query_feature_np=query_feature_np[:100] 146 | 147 | score_matrix = query_feature_np.dot(gallery_feature_np.T) 148 | max_topk = args.max_topk 149 | for q,each_score in tqdm(zip(query_ids,score_matrix)): 150 | max_index = heapq.nlargest(max_topk, range(len(each_score)), each_score.take) 151 | topk_item_id = gallery_ids[max_index] 152 | topk_item_id=[each_item_id for each_item_id in topk_item_id if each_item_id!=q] 153 | 154 | topk_item_id_str = ",".join(topk_item_id) 155 | save_file.write("{},{}\n".format(q, topk_item_id_str)) 156 | 157 | 
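# Note: if the saved query/gallery features are L2-normalized, the dot-product scores
# above are cosine similarities. A minimal sketch of a faster top-k selection (an
# alternative, not used by this script; assumes max_topk < gallery size), since
# np.argpartition is usually much quicker than heapq.nlargest on large galleries:
#
#     part = np.argpartition(-each_score, max_topk)[:max_topk]
#     max_index = part[np.argsort(-each_score[part])].tolist()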
158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | -------------------------------------------------------------------------------- /tools/VideoFeatureExtractor/.gitignore: -------------------------------------------------------------------------------- 1 | /model/* 2 | -------------------------------------------------------------------------------- /tools/VideoFeatureExtractor/README.md: -------------------------------------------------------------------------------- 1 | 2 | This repo is forked from [video_feature_extractor](https://github.com/antoine77340/video_feature_extractor) to extract S3D features ([S3D_HowTo100M](https://github.com/antoine77340/S3D_HowTo100M)) pretrained on HowTo100M. See [video_feature_extractor](https://github.com/antoine77340/video_feature_extractor) for more details. 3 | 4 | This repo is also used as a preprocessing step for the video-language pretraining model [UniVL](https://github.com/microsoft/UniVL). 5 | 6 | ## Requirements 7 | 8 | IMPORTANT: The video decoding process depends on FFmpeg (https://www.ffmpeg.org/download.html); install it first and make sure the `ffmpeg` and `ffprobe` commands run directly from the shell. 9 | 10 | - Python 3 11 | - PyTorch (>= 1.0) 12 | - python-ffmpeg (https://github.com/kkroening/ffmpeg-python) 13 | 14 | ## Downloading pretrained models 15 | This will download the pretrained S3D model: 16 | 17 | ```sh 18 | mkdir -p model 19 | cd model 20 | wget https://www.rocq.inria.fr/cluster-willow/amiech/howto100m/s3d_howto100m.pth 21 | cd .. 22 | ``` 23 | 24 | ## Extract S3D Feature 25 | 26 | First of all, you need to generate a CSV containing the list of videos you 27 | want to process. For instance, if you have absolute_path_video1.mp4 and absolute_path_video2.webm to process, 28 | you will need to generate a CSV of this form: 29 | 30 | ```sh 31 | video_path,feature_path 32 | absolute_path_video1.mp4,absolute_path_of_video1_features.npy 33 | absolute_path_video2.webm,absolute_path_of_video2_features.npy 34 | ``` 35 | 36 | Refer to the command below to generate such a CSV file: 37 | ```sh 38 | python preprocess_generate_csv.py --csv=input.csv --video_root_path [VIDEO_PATH] --feature_root_path [FEATURE_PATH] --csv_save_path . 39 | ``` 40 | *Note: each video file should have a suffix; modify the code to fit your own setup.* 41 | 42 | 43 | And then simply run: 44 | 45 | ```sh 46 | python extract.py --csv=./input.csv --type=s3dg --batch_size=64 --num_decoding_thread=4 47 | ``` 48 | This command extracts S3D-G video features and saves them as numpy arrays. 49 | 50 | If you want to pickle all generated npy files: 51 | ```sh 52 | python convert_video_feature_to_pickle.py --feature_root_path [FEATURE_PATH] --pickle_root_path . --pickle_name input.pickle 53 | ``` 54 | *In the pickle file, each key is the video name.* 55 | 56 | ## Acknowledgements 57 | The code reuses the 3D CNN implementation from https://github.com/kenshohara/3D-ResNets-PyTorch 58 | and is modified from https://github.com/antoine77340/video_feature_extractor.
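For a quick sanity check of an extracted feature file, a minimal sketch (with `--type=s3dg` each file holds one 1024-d row per 16-frame clip, stored as float16 when `--half_precision=1`, the default):

```python
import numpy as np

# feature_path taken from the CSV generated above
feat = np.load("absolute_path_of_video1_features.npy")
print(feat.shape, feat.dtype)  # (num_clips, 1024), float16 for --type=s3dg
```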
59 | -------------------------------------------------------------------------------- /tools/VideoFeatureExtractor/convert_video_feature_to_pickle.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import pickle 5 | 6 | if __name__ == "__main__": 7 | 8 | parser = argparse.ArgumentParser(description='Package video feature') 9 | 10 | parser.add_argument('--feature_root_path', type=str, help='feature path') 11 | parser.add_argument('--pickle_root_path', type=str, help='pickle path', default='.') 12 | parser.add_argument('--pickle_name', type=str, help='pickle name') 13 | args = parser.parse_args() 14 | 15 | save_file = os.path.join(args.pickle_root_path, args.pickle_name) 16 | 17 | feature_dict = {} 18 | for root, dirs, files in os.walk(args.feature_root_path): 19 | for file_name in files: 20 | if file_name.find(".npy") > 0: 21 | file_name_split = file_name.split(".") 22 | if len(file_name_split) == 2: 23 | key = file_name_split[0] 24 | feature_file = os.path.join(root, file_name) 25 | features = np.load(feature_file) 26 | print("features: ",features.shape) 27 | feature_dict[key] = features 28 | else: 29 | print("{} is error.".format(file_name)) 30 | print("Total num: {}".format(len(feature_dict))) 31 | pickle.dump(feature_dict, open(save_file, 'wb')) 32 | print("pickle is saved in: {}".format(save_file)) 33 | -------------------------------------------------------------------------------- /tools/VideoFeatureExtractor/demo/run_extract_subset.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATA_FILE=/multi_modal/data/product5m_v2/product1m_product5m_train_id_label.json 4 | IMAGE_PATH=/multi_modal/data/images 5 | VIDEO_PATH=/multi_modal/data/videos 6 | VIDEO_FEATURE_PATH=/multi_modal/data/video_feature 7 | CSV_SAVE_PATH=/multi_modal/data/product5m_v2 8 | 9 | cd ../ 10 | 11 | python preprocess_generate_csv3.py \ 12 | --ids_file /multi_modal/data/product5m_v2/subset_train_id_label.json \ 13 | --csv=input.csv \ 14 | --video_root_path ${VIDEO_PATH} \ 15 | --feature_root_path ${VIDEO_FEATURE_PATH} \ 16 | --csv_save_path ${CSV_SAVE_PATH} 17 | 18 | 19 | CUDA_VISIBLE_DEVICES=6 python extract.py \ 20 | --csv=${CSV_SAVE_PATH}/input.csv \ 21 | --type=s3dg \ 22 | --batch_size=16 \ 23 | --num_decoding_thread=4 24 | 25 | 26 | #python convert_video_feature_to_pickle.py \ 27 | #--feature_root_path ${VIDEO_FEATURE_PATH} \ 28 | #--pickle_root_path . 
\ 29 | #--pickle_name input.pickle 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /tools/VideoFeatureExtractor/extract.py: -------------------------------------------------------------------------------- 1 | import os 2 | current_dir_path = os.path.dirname(os.path.realpath(__file__)) 3 | print(current_dir_path) 4 | 5 | import pdb 6 | import torch as th 7 | import math 8 | import numpy as np 9 | from video_loader import VideoLoader 10 | from torch.utils.data import DataLoader 11 | import argparse 12 | from model import get_model 13 | from preprocessing import Preprocessing 14 | from random_sequence_shuffler import RandomSequenceSampler 15 | import torch.nn.functional as F 16 | from tqdm import tqdm 17 | 18 | FRAMERATE_DICT = {'2d':1, '3d':24, 's3dg':16, 'raw_data':16} 19 | SIZE_DICT = {'2d':224, '3d':112, 's3dg':224, 'raw_data':224} 20 | CENTERCROP_DICT = {'2d':False, '3d':True, 's3dg':True, 'raw_data':True} 21 | FEATURE_LENGTH = {'2d':2048, '3d':2048, 's3dg':1024, 'raw_data':1024} 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser(description='Easy video feature extractor') 25 | 26 | parser.add_argument('--csv', type=str, help='input csv with video input path') 27 | parser.add_argument('--batch_size', type=int, default=64, help='batch size') 28 | parser.add_argument('--type', type=str, default='2d', help='CNN type', choices=['2d','3d','s3dg','raw_data']) 29 | parser.add_argument('--half_precision', type=int, default=1, help='output half precision float') 30 | parser.add_argument('--num_decoding_thread', type=int, default=4, help='Num parallel thread for video decoding') 31 | parser.add_argument('--l2_normalize', type=int, default=1, help='l2 normalize feature') 32 | parser.add_argument('--resnext101_model_path', type=str, default='model/resnext101.pth', help='Resnext model path') 33 | parser.add_argument('--s3d_model_path', type=str, default='model/s3d_howto100m.pth', help='S3GD model path') 34 | parser.add_argument('--datastore_base', type=str, default='.') 35 | args = parser.parse_args() 36 | 37 | 38 | dataset = VideoLoader( 39 | args.csv, # args.datastore_base, 40 | framerate=FRAMERATE_DICT[args.type], 41 | size=SIZE_DICT[args.type], 42 | centercrop=CENTERCROP_DICT[args.type] 43 | ) 44 | n_dataset = len(dataset) 45 | sampler = RandomSequenceSampler(n_dataset, 10) 46 | loader = DataLoader( 47 | dataset, 48 | batch_size=1, 49 | shuffle=False, 50 | num_workers=args.num_decoding_thread, 51 | sampler=sampler if n_dataset > 10 else None, 52 | ) 53 | preprocess = Preprocessing(args.type, FRAMERATE_DICT) 54 | 55 | if args.type == "raw_data": 56 | model = None 57 | else: 58 | model = get_model(args) 59 | 60 | with th.no_grad(): 61 | k = 0 62 | for data in tqdm(loader): 63 | k += 1 64 | input_file = data['input'][0] 65 | output_file = data['output'][0] 66 | 67 | if os.path.exists(output_file): 68 | continue 69 | # 70 | # pdb.set_trace() 71 | if len(data['video'].shape) > 3: 72 | print('Computing features of video {}/{}: {}'.format(k + 1, n_dataset, input_file)) 73 | video = data['video'].squeeze() 74 | if len(video.shape) == 4: 75 | video = preprocess(video) 76 | # Batch x 3 x T x H x W 77 | if args.type == "raw_data": 78 | features = video 79 | else: 80 | n_chunk = len(video) 81 | features = th.cuda.FloatTensor(n_chunk, FEATURE_LENGTH[args.type]).fill_(0) 82 | n_iter = int(math.ceil(n_chunk / float(args.batch_size))) 83 | for i in range(n_iter): 84 | min_ind = i * args.batch_size 85 | max_ind = (i + 1) * 
args.batch_size 86 | video_batch = video[min_ind:max_ind].cuda() 87 | batch_features = model(video_batch) 88 | if args.l2_normalize: 89 | batch_features = F.normalize(batch_features, dim=1) 90 | features[min_ind:max_ind] = batch_features 91 | features = features.cpu().numpy() 92 | if args.half_precision: 93 | features = features.astype('float16') 94 | os.makedirs('/'.join(output_file.split('/')[:-1]), exist_ok=True) 95 | np.save(output_file, features) 96 | else: 97 | print("data['video'].shape: ",data['video'].shape) 98 | print('Video {} already processed.'.format(input_file)) -------------------------------------------------------------------------------- /tools/VideoFeatureExtractor/input.csv: -------------------------------------------------------------------------------- 1 | video_path,feature_path 2 | /multi_modal/data/videos_temp/622651207543.mp4,/multi_modal/data/video_feature/622651207543.npy 3 | /multi_modal/data/videos_temp/568529764107.mp4,/multi_modal/data/video_feature/568529764107.npy 4 | /multi_modal/data/videos_temp/599827533459.mp4,/multi_modal/data/video_feature/599827533459.npy 5 | /multi_modal/data/videos_temp/558391226899.mp4,/multi_modal/data/video_feature/558391226899.npy 6 | /multi_modal/data/videos_temp/594273981721.mp4,/multi_modal/data/video_feature/594273981721.npy 7 | /multi_modal/data/videos_temp/567773355453.mp4,/multi_modal/data/video_feature/567773355453.npy 8 | /multi_modal/data/videos_temp/565582596512.mp4,/multi_modal/data/video_feature/565582596512.npy 9 | /multi_modal/data/videos_temp/540606452330.mp4,/multi_modal/data/video_feature/540606452330.npy 10 | /multi_modal/data/videos_temp/616895968589.mp4,/multi_modal/data/video_feature/616895968589.npy 11 | /multi_modal/data/videos_temp/561833188203.mp4,/multi_modal/data/video_feature/561833188203.npy 12 | /multi_modal/data/videos_temp/521158940564.mp4,/multi_modal/data/video_feature/521158940564.npy 13 | /multi_modal/data/videos_temp/624481586289.mp4,/multi_modal/data/video_feature/624481586289.npy 14 | /multi_modal/data/videos_temp/545861362742.mp4,/multi_modal/data/video_feature/545861362742.npy 15 | /multi_modal/data/videos_temp/608680499881.mp4,/multi_modal/data/video_feature/608680499881.npy 16 | /multi_modal/data/videos_temp/620111609197.mp4,/multi_modal/data/video_feature/620111609197.npy 17 | /multi_modal/data/videos_temp/45603295971.mp4,/multi_modal/data/video_feature/45603295971.npy 18 | /multi_modal/data/videos_temp/535540938556.mp4,/multi_modal/data/video_feature/535540938556.npy 19 | /multi_modal/data/videos_temp/611001046660.mp4,/multi_modal/data/video_feature/611001046660.npy 20 | /multi_modal/data/videos_temp/613193045962.mp4,/multi_modal/data/video_feature/613193045962.npy 21 | -------------------------------------------------------------------------------- /tools/VideoFeatureExtractor/input.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaodongsuper/SCALE_code/227ae0886a0bda598495e60c4624ca894a7418bf/tools/VideoFeatureExtractor/input.pickle -------------------------------------------------------------------------------- /tools/VideoFeatureExtractor/model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch as th 3 | import torchvision.models as models 4 | from videocnn.models import resnext, s3dg 5 | from torch import nn 6 | 7 | 8 | class GlobalAvgPool(nn.Module): 9 | def __init__(self): 10 | super(GlobalAvgPool, self).__init__() 11 | 12 | def 
forward(self, x): 13 | return th.mean(x, dim=[-2, -1]) 14 | 15 | def init_weight(model, state_dict, should_omit="s3dg."): 16 | old_keys = [] 17 | new_keys = [] 18 | for key in state_dict.keys(): 19 | new_key = None 20 | if 'gamma' in key: 21 | new_key = key.replace('gamma', 'weight') 22 | if 'beta' in key: 23 | new_key = key.replace('beta', 'bias') 24 | if new_key: 25 | old_keys.append(key) 26 | new_keys.append(new_key) 27 | for old_key, new_key in zip(old_keys, new_keys): 28 | state_dict[new_key] = state_dict.pop(old_key) 29 | 30 | missing_keys = [] 31 | unexpected_keys = [] 32 | error_msgs = [] 33 | # copy state_dict so _load_from_state_dict can modify it 34 | metadata = getattr(state_dict, '_metadata', None) 35 | state_dict = state_dict.copy() 36 | if metadata is not None: 37 | state_dict._metadata = metadata 38 | 39 | if should_omit is not None: 40 | old_keys = [] 41 | new_keys = [] 42 | for key in state_dict.keys(): 43 | if key.find(should_omit) == 0: 44 | old_keys.append(key) 45 | new_key = key[len(should_omit):] 46 | new_keys.append(new_key) 47 | for old_key, new_key in zip(old_keys, new_keys): 48 | state_dict[new_key] = state_dict.pop(old_key) 49 | 50 | def load(module, prefix=''): 51 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 52 | module._load_from_state_dict( 53 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 54 | for name, child in module._modules.items(): 55 | if child is not None: 56 | load(child, prefix + name + '.') 57 | 58 | load(model, prefix='') 59 | 60 | if len(missing_keys) > 0: 61 | print("Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, "\n " + "\n ".join(missing_keys))) 62 | if len(unexpected_keys) > 0: 63 | print("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys))) 64 | if len(error_msgs) > 0: 65 | print("Weights from pretrained model cause errors in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(error_msgs))) 66 | 67 | return model 68 | 69 | def get_model(args): 70 | assert args.type in ['2d', '3d', 's3dg'] 71 | if args.type == '2d': 72 | print('Loading 2D-ResNet-152 ...') 73 | model = models.resnet152(pretrained=True) 74 | model = nn.Sequential(*list(model.children())[:-2], GlobalAvgPool()) 75 | model = model.cuda() 76 | elif args.type == '3d': 77 | print('Loading 3D-ResneXt-101 ...') 78 | model = resnext.resnet101( 79 | num_classes=400, 80 | shortcut_type='B', 81 | cardinality=32, 82 | sample_size=112, 83 | sample_duration=16, 84 | last_fc=False) 85 | model = model.cuda() 86 | model_data = th.load(args.resnext101_model_path) 87 | model.load_state_dict(model_data) 88 | elif args.type == 's3dg': 89 | print('Loading S3DG ...') 90 | model = s3dg.S3D(last_fc=False) 91 | model = model.cuda() 92 | model_data = th.load(args.s3d_model_path) 93 | model = init_weight(model, model_data) 94 | 95 | 96 | model.eval() 97 | print('loaded') 98 | return model 99 | 100 | if __name__ == "__main__": 101 | model = resnext.resnet101( 102 | num_classes=400, 103 | shortcut_type='B', 104 | cardinality=32, 105 | sample_size=112, 106 | sample_duration=16, 107 | last_fc=False) 108 | print(model) 109 | -------------------------------------------------------------------------------- /tools/VideoFeatureExtractor/preprocess_generate_csv.py: -------------------------------------------------------------------------------- 1 | ''' 2 | First of all you need to generate a csv containing the list of videos you 
want to process. For instance, if you have video1.mp4 and video2.webm to process, you will need to generate a csv of this form: 3 | 4 | video_path, feature_path 5 | video1.mp4, path_of_video1_features.npy 6 | video2.webm, path_of_video2_features.npy 7 | 8 | And then just simply run: 9 | 10 | python extract.py --csv=input.csv --type=s3dg --batch_size=64 --num_decoding_thread=4 11 | ''' 12 | 13 | import os 14 | import argparse 15 | 16 | if __name__ == "__main__": 17 | 18 | parser = argparse.ArgumentParser(description='Generate CSV') 19 | 20 | parser.add_argument('--csv', type=str, help='input csv with video input path') 21 | parser.add_argument('--video_root_path', type=str, help='video path') 22 | parser.add_argument('--feature_root_path', type=str, help='feature path') 23 | parser.add_argument('--csv_save_path', type=str, help='csv path', default='.') 24 | args = parser.parse_args() 25 | 26 | video_root_path = args.video_root_path 27 | feature_root_path = args.feature_root_path 28 | 29 | csv_save_path = os.path.join(args.csv_save_path, args.csv) 30 | fp_wt = open(csv_save_path, 'w') 31 | line = "video_path,feature_path" 32 | fp_wt.write(line + "\n") 33 | 34 | all_files = os.walk(video_root_path) 35 | for path, d, filelist in all_files: 36 | for file_name in filelist: 37 | video_path = os.path.join(path, file_name) 38 | video_id = video_path.replace("\\","/").split("/")[-1].split(".")[0] 39 | feature_path = os.path.join(feature_root_path, "{}.npy".format(video_id)) 40 | line = ",".join([video_path, feature_path]) 41 | fp_wt.write(line + "\n") 42 | 43 | fp_wt.close() 44 | print("csv is saved in: {}".format(csv_save_path)) -------------------------------------------------------------------------------- /tools/VideoFeatureExtractor/preprocess_generate_csv2.py: -------------------------------------------------------------------------------- 1 | ''' 2 | First of all you need to generate a csv containing the list of videos you want to process. 
For instance, if you have video1.mp4 and video2.webm to process, you will need to generate a csv of this form: 3 | 4 | video_path, feature_path 5 | video1.mp4, path_of_video1_features.npy 6 | video2.webm, path_of_video2_features.npy 7 | 8 | And then just simply run: 9 | 10 | python extract.py --csv=input.csv --type=s3dg --batch_size=64 --num_decoding_thread=4 11 | ''' 12 | 13 | import os 14 | import argparse 15 | 16 | 17 | import json 18 | import jsonlines 19 | import pickle 20 | import csv 21 | import re 22 | import os.path 23 | import random 24 | from tqdm import tqdm 25 | 26 | 27 | class IOProcessor(): 28 | def read_jsonline(self,file): 29 | file=open(file,"r",encoding="utf-8") 30 | data=[json.loads(line) for line in file.readlines()] 31 | return data 32 | 33 | def write_jsonline(self,file,data): 34 | f=jsonlines.open(file,"w") 35 | for each in data: 36 | jsonlines.Writer.write(f,each) 37 | return 38 | 39 | def read_json(self,file): 40 | f=open(file,"r",encoding="utf-8").read() 41 | return json.loads(f) 42 | 43 | def write_json(self,file,data): 44 | f=open(file,"w",encoding="utf-8") 45 | json.dump(data,f,indent=2,ensure_ascii=False) 46 | return 47 | 48 | def read_pickle(self,filename): 49 | return pickle.loads(open(filename,"rb").read()) 50 | 51 | def write_pickle(self,filename,data): 52 | open(filename,"wb").write(pickle.dumps(data)) 53 | return 54 | 55 | 56 | 57 | 58 | if __name__ == "__main__": 59 | 60 | io_process=IOProcessor() 61 | 62 | parser = argparse.ArgumentParser(description='Generate CSV') 63 | 64 | parser.add_argument("--ids_file",type=str) 65 | parser.add_argument('--csv', type=str, help='input csv with video input path') 66 | parser.add_argument('--video_root_path', type=str, help='video path') 67 | parser.add_argument('--feature_root_path', type=str, help='feature path') 68 | parser.add_argument('--csv_save_path', type=str, help='csv path', default='.') 69 | args = parser.parse_args() 70 | 71 | video_root_path = args.video_root_path 72 | feature_root_path = args.feature_root_path 73 | 74 | csv_save_path = os.path.join(args.csv_save_path, args.csv) 75 | fp_wt = open(csv_save_path, 'w') 76 | line = "video_path,feature_path" 77 | fp_wt.write(line + "\n") 78 | 79 | ids=io_process.read_json("{}".format(args.ids_file)) 80 | 81 | # all_files = os.walk(video_root_path) 82 | # for path, d, filelist in all_files: 83 | # for file_name in filelist: 84 | # video_path = os.path.join(path, file_name) 85 | # video_id = video_path.replace("\\","/").split("/")[-1].split(".")[0] 86 | # feature_path = os.path.join(feature_root_path, "{}.npy".format(video_id)) 87 | # line = ",".join([video_path, feature_path]) 88 | # fp_wt.write(line + "\n") 89 | 90 | for id in ids: 91 | video_path = os.path.join(video_root_path, "{}.mp4".format(id)) 92 | video_id =id 93 | feature_path = os.path.join(feature_root_path, "{}.npy".format(video_id)) 94 | line = ",".join([video_path, feature_path]) 95 | fp_wt.write(line + "\n") 96 | 97 | 98 | fp_wt.close() 99 | print("csv is saved in: {}".format(csv_save_path)) -------------------------------------------------------------------------------- /tools/VideoFeatureExtractor/preprocess_generate_csv3.py: -------------------------------------------------------------------------------- 1 | ''' 2 | First of all you need to generate a csv containing the list of videos you want to process. 
For instance, if you have video1.mp4 and video2.webm to process, you will need to generate a csv of this form: 3 | 4 | video_path, feature_path 5 | video1.mp4, path_of_video1_features.npy 6 | video2.webm, path_of_video2_features.npy 7 | 8 | And then just simply run: 9 | 10 | python extract.py --csv=input.csv --type=s3dg --batch_size=64 --num_decoding_thread=4 11 | ''' 12 | 13 | import os 14 | import argparse 15 | 16 | 17 | import json 18 | import jsonlines 19 | import pickle 20 | import csv 21 | import re 22 | import os.path 23 | import random 24 | from tqdm import tqdm 25 | 26 | 27 | class IOProcessor(): 28 | def read_jsonline(self,file): 29 | file=open(file,"r",encoding="utf-8") 30 | data=[json.loads(line) for line in file.readlines()] 31 | return data 32 | 33 | def write_jsonline(self,file,data): 34 | f=jsonlines.open(file,"w") 35 | for each in data: 36 | jsonlines.Writer.write(f,each) 37 | return 38 | 39 | def read_json(self,file): 40 | f=open(file,"r",encoding="utf-8").read() 41 | return json.loads(f) 42 | 43 | def write_json(self,file,data): 44 | f=open(file,"w",encoding="utf-8") 45 | json.dump(data,f,indent=2,ensure_ascii=False) 46 | return 47 | 48 | def read_pickle(self,filename): 49 | return pickle.loads(open(filename,"rb").read()) 50 | 51 | def write_pickle(self,filename,data): 52 | open(filename,"wb").write(pickle.dumps(data)) 53 | return 54 | 55 | 56 | 57 | 58 | if __name__ == "__main__": 59 | 60 | io_process=IOProcessor() 61 | 62 | parser = argparse.ArgumentParser(description='Generate CSV') 63 | 64 | parser.add_argument("--ids_file",type=str) 65 | parser.add_argument('--csv', type=str, help='input csv with video input path') 66 | parser.add_argument('--video_root_path', type=str, help='video path') 67 | parser.add_argument('--feature_root_path', type=str, help='feature path') 68 | parser.add_argument('--csv_save_path', type=str, help='csv path', default='.') 69 | args = parser.parse_args() 70 | 71 | video_root_path = args.video_root_path 72 | feature_root_path = args.feature_root_path 73 | 74 | csv_save_path = os.path.join(args.csv_save_path, args.csv) 75 | fp_wt = open(csv_save_path, 'w') 76 | line = "video_path,feature_path" 77 | fp_wt.write(line + "\n") 78 | 79 | ids=set(list(io_process.read_json("{}".format(args.ids_file)).keys())) # TODO: the difference 80 | ids=[each_id.split("#")[0] for each_id in ids] 81 | # all_files = os.walk(video_root_path) 82 | # for path, d, filelist in all_files: 83 | # for file_name in filelist: 84 | # video_path = os.path.join(path, file_name) 85 | # video_id = video_path.replace("\\","/").split("/")[-1].split(".")[0] 86 | # feature_path = os.path.join(feature_root_path, "{}.npy".format(video_id)) 87 | # line = ",".join([video_path, feature_path]) 88 | # fp_wt.write(line + "\n") 89 | 90 | for id in ids: 91 | video_path = os.path.join(video_root_path, "{}.avi".format(id)) 92 | video_id =id 93 | feature_path = os.path.join(feature_root_path, "{}.npy".format(video_id)) 94 | line = ",".join([video_path, feature_path]) 95 | fp_wt.write(line + "\n") 96 | 97 | 98 | fp_wt.close() 99 | print("csv is saved in: {}".format(csv_save_path)) -------------------------------------------------------------------------------- /tools/VideoFeatureExtractor/preprocessing.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | class Normalize(object): 4 | 5 | def __init__(self, mean, std): 6 | self.mean = th.FloatTensor(mean).view(1, 3, 1, 1) 7 | self.std = th.FloatTensor(std).view(1, 3, 1, 1) 8 | 9 | def 
__call__(self, tensor): 10 | tensor = (tensor - self.mean) / (self.std + 1e-8) 11 | return tensor 12 | 13 | class Preprocessing(object): 14 | 15 | def __init__(self, type, FRAMERATE_DICT): 16 | self.type = type 17 | self.FRAMERATE_DICT = FRAMERATE_DICT 18 | if type == '2d': 19 | self.norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 20 | elif type == '3d': 21 | self.norm = Normalize(mean=[110.6, 103.2, 96.3], std=[1.0, 1.0, 1.0]) 22 | elif type == 's3dg': 23 | pass 24 | elif type == 'raw_data': 25 | pass 26 | 27 | def _zero_pad(self, tensor, size): 28 | n = size - len(tensor) % size 29 | if n == size: 30 | return tensor 31 | else: 32 | z = th.zeros(n, tensor.shape[1], tensor.shape[2], tensor.shape[3]) 33 | return th.cat((tensor, z), 0) 34 | 35 | def __call__(self, tensor): 36 | if self.type == '2d': 37 | tensor = tensor / 255.0 38 | tensor = self.norm(tensor) 39 | elif self.type == '3d': 40 | tensor = self._zero_pad(tensor, 16) 41 | tensor = self.norm(tensor) 42 | tensor = tensor.view(-1, 16, 3, 112, 112) 43 | tensor = tensor.transpose(1, 2) 44 | elif self.type == 's3dg': 45 | tensor = tensor / 255.0 46 | tensor = self._zero_pad(tensor, self.FRAMERATE_DICT[self.type]) 47 | # To Batch= T x 3 x H x W 48 | tensor_size = tensor.size() 49 | tensor = tensor.view(-1, self.FRAMERATE_DICT[self.type], 3, tensor_size[-2], tensor_size[-1]) 50 | # To Batch x 3 x T x H x W 51 | tensor = tensor.transpose(1, 2) 52 | elif self.type == 'raw_data': 53 | tensor = tensor / 255.0 54 | tensor = self._zero_pad(tensor, self.FRAMERATE_DICT[self.type]) 55 | # To Batch= T x 3 x H x W 56 | tensor_size = tensor.size() 57 | tensor = tensor.view(-1, self.FRAMERATE_DICT[self.type], 3, tensor_size[-2], tensor_size[-1]) 58 | # To Batch x 3 x T x H x W 59 | tensor = tensor.transpose(1, 2) 60 | 61 | return tensor 62 | 63 | -------------------------------------------------------------------------------- /tools/VideoFeatureExtractor/random_sequence_shuffler.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch.utils.data.sampler import Sampler 3 | import numpy as np 4 | 5 | class RandomSequenceSampler(Sampler): 6 | 7 | def __init__(self, n_sample, seq_len): 8 | self.n_sample = n_sample 9 | self.seq_len = seq_len 10 | 11 | def _pad_ind(self, ind): 12 | zeros = np.zeros(self.seq_len - self.n_sample % self.seq_len) 13 | ind = np.concatenate((ind, zeros)) 14 | return ind 15 | 16 | def __iter__(self): 17 | idx = np.arange(self.n_sample) 18 | if self.n_sample % self.seq_len != 0: 19 | idx = self._pad_ind(idx) 20 | idx = np.reshape(idx, (-1, self.seq_len)) 21 | np.random.shuffle(idx) 22 | idx = np.reshape(idx, (-1)) 23 | return iter(idx.astype(int)) 24 | 25 | def __len__(self): 26 | return self.n_sample + (self.seq_len - self.n_sample % self.seq_len) 27 | -------------------------------------------------------------------------------- /tools/VideoFeatureExtractor/run_gen_csv2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATA_FILE=/multi_modal/data/product5m_v2/product1m_product5m_train_id_label.json 4 | IMAGE_PATH=/multi_modal/data/images 5 | VIDEO_PATH=/multi_modal/data/videos 6 | VIDEO_FEATURE_PATH=/multi_modal/data/video_feature 7 | CSV_SAVE_PATH=/multi_modal/data/product5m_v2 8 | 9 | 10 | python preprocess_generate_csv2.py \ 11 | --ids_file /multi_modal/data/product5m_v2/ava_video_ids.json \ 12 | --csv=input.csv \ 13 | --video_root_path ${VIDEO_PATH} \ 14 | 
--feature_root_path ${VIDEO_FEATURE_PATH} \ 15 | --csv_save_path ${CSV_SAVE_PATH} 16 | 17 | 18 | CUDA_VISIBLE_DEVICES=6 python extract.py \ 19 | --csv=${CSV_SAVE_PATH}/input.csv \ 20 | --type=s3dg \ 21 | --batch_size=32 \ 22 | --num_decoding_thread=16 23 | 24 | 25 | #python convert_video_feature_to_pickle.py \ 26 | #--feature_root_path ${VIDEO_FEATURE_PATH} \ 27 | #--pickle_root_path . \ 28 | #--pickle_name input.pickle 29 | 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /tools/VideoFeatureExtractor/video_loader.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch.utils.data import Dataset 3 | import pandas as pd 4 | import os 5 | import numpy as np 6 | import ffmpeg 7 | 8 | 9 | class VideoLoader(Dataset): 10 | """Pytorch video loader.""" 11 | 12 | def __init__( 13 | self, 14 | csv, 15 | framerate=1, 16 | size=112, 17 | centercrop=False, 18 | ): 19 | """ 20 | Args: 21 | """ 22 | self.csv = pd.read_csv(csv) 23 | self.centercrop = centercrop 24 | self.size = size 25 | self.framerate = framerate 26 | 27 | def __len__(self): 28 | return len(self.csv) 29 | 30 | def _get_video_dim(self, video_path): 31 | probe = ffmpeg.probe(video_path) 32 | video_stream = next((stream for stream in probe['streams'] 33 | if stream['codec_type'] == 'video'), None) 34 | width = int(video_stream['width']) 35 | height = int(video_stream['height']) 36 | return height, width 37 | 38 | def _get_output_dim(self, h, w): 39 | if isinstance(self.size, tuple) and len(self.size) == 2: 40 | return self.size 41 | elif h >= w: 42 | return int(h * self.size / w), self.size 43 | else: 44 | return self.size, int(w * self.size / h) 45 | 46 | def __getitem__(self, idx): 47 | video_path = self.csv['video_path'].values[idx] 48 | output_file = self.csv['feature_path'].values[idx] 49 | video = th.zeros(1) 50 | 51 | if not (os.path.isfile(output_file)) and os.path.isfile(video_path): 52 | print('Decoding video: {}'.format(video_path)) 53 | try: 54 | h, w = self._get_video_dim(video_path) 55 | 56 | height, width = self._get_output_dim(h, w) 57 | cmd = ( 58 | ffmpeg 59 | .input(video_path) 60 | .filter('fps', fps=self.framerate) 61 | .filter('scale', width, height) 62 | ) 63 | if self.centercrop: 64 | x = int((width - self.size) / 2.0) 65 | y = int((height - self.size) / 2.0) 66 | cmd = cmd.crop(x, y, self.size, self.size) 67 | out, _ = ( 68 | cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24') 69 | .run(capture_stdout=True, quiet=True) 70 | ) 71 | if self.centercrop and isinstance(self.size, int): 72 | height, width = self.size, self.size 73 | video = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3]) 74 | video = th.from_numpy(video.astype('float32')) 75 | video = video.permute(0, 3, 1, 2) 76 | except: 77 | print('ffprobe failed at: {}'.format(video_path)) 78 | 79 | return {'video': video, 'input': video_path, 'output': output_file} 80 | 81 | -------------------------------------------------------------------------------- /tools/VideoFeatureExtractor/videocnn/models/resnext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | from functools import partial 7 | 8 | __all__ = ['ResNeXt', 'resnet50', 'resnet101'] 9 | 10 | 11 | def conv3x3x3(in_planes, out_planes, stride=1): 12 | # 3x3x3 convolution with padding 13 | return 
nn.Conv3d(in_planes, out_planes, kernel_size=3, 14 | stride=stride, padding=1, bias=False) 15 | 16 | 17 | def downsample_basic_block(x, planes, stride): 18 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 19 | zero_pads = torch.Tensor(out.size(0), planes - out.size(1), 20 | out.size(2), out.size(3), 21 | out.size(4)).zero_() 22 | if isinstance(out.data, torch.cuda.FloatTensor): 23 | zero_pads = zero_pads.cuda() 24 | 25 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 26 | 27 | return out 28 | 29 | 30 | class ResNeXtBottleneck(nn.Module): 31 | expansion = 2 32 | 33 | def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None): 34 | super(ResNeXtBottleneck, self).__init__() 35 | mid_planes = cardinality * int(planes / 32) 36 | self.conv1 = nn.Conv3d(inplanes, mid_planes, kernel_size=1, bias=False) 37 | self.bn1 = nn.BatchNorm3d(mid_planes) 38 | self.conv2 = nn.Conv3d(mid_planes, mid_planes, kernel_size=3, stride=stride, 39 | padding=1, groups=cardinality, bias=False) 40 | self.bn2 = nn.BatchNorm3d(mid_planes) 41 | self.conv3 = nn.Conv3d(mid_planes, planes * self.expansion, kernel_size=1, bias=False) 42 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 43 | self.relu = nn.ReLU(inplace=True) 44 | self.downsample = downsample 45 | self.stride = stride 46 | 47 | def forward(self, x): 48 | residual = x 49 | 50 | out = self.conv1(x) 51 | out = self.bn1(out) 52 | out = self.relu(out) 53 | 54 | out = self.conv2(out) 55 | out = self.bn2(out) 56 | out = self.relu(out) 57 | 58 | out = self.conv3(out) 59 | out = self.bn3(out) 60 | 61 | if self.downsample is not None: 62 | residual = self.downsample(x) 63 | 64 | out += residual 65 | out = self.relu(out) 66 | 67 | return out 68 | 69 | 70 | class ResNeXt(nn.Module): 71 | 72 | def __init__(self, block, layers, sample_size, sample_duration, shortcut_type='B', cardinality=32, num_classes=400, last_fc=True): 73 | self.last_fc = last_fc 74 | 75 | self.inplanes = 64 76 | super(ResNeXt, self).__init__() 77 | self.conv1 = nn.Conv3d(3, 64, kernel_size=7, stride=(1, 2, 2), 78 | padding=(3, 3, 3), bias=False) 79 | self.bn1 = nn.BatchNorm3d(64) 80 | self.relu = nn.ReLU(inplace=True) 81 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) 82 | self.layer1 = self._make_layer(block, 128, layers[0], shortcut_type, cardinality) 83 | self.layer2 = self._make_layer(block, 256, layers[1], shortcut_type, cardinality, stride=2) 84 | self.layer3 = self._make_layer(block, 512, layers[2], shortcut_type, cardinality, stride=2) 85 | self.layer4 = self._make_layer(block, 1024, layers[3], shortcut_type, cardinality, stride=2) 86 | last_duration = math.ceil(sample_duration / 16) 87 | last_size = math.ceil(sample_size / 32) 88 | self.avgpool = nn.AvgPool3d((last_duration, last_size, last_size), stride=1) 89 | self.fc = nn.Linear(cardinality * 32 * block.expansion, num_classes) 90 | 91 | for m in self.modules(): 92 | if isinstance(m, nn.Conv3d): 93 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 94 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 95 | elif isinstance(m, nn.BatchNorm3d): 96 | m.weight.data.fill_(1) 97 | m.bias.data.zero_() 98 | 99 | def _make_layer(self, block, planes, blocks, shortcut_type, cardinality, stride=1): 100 | downsample = None 101 | if stride != 1 or self.inplanes != planes * block.expansion: 102 | if shortcut_type == 'A': 103 | downsample = partial(downsample_basic_block, 104 | planes=planes * block.expansion, 105 | stride=stride) 106 | else: 107 | downsample = nn.Sequential( 108 | nn.Conv3d(self.inplanes, planes * block.expansion, 109 | kernel_size=1, stride=stride, bias=False), 110 | nn.BatchNorm3d(planes * block.expansion) 111 | ) 112 | 113 | layers = [] 114 | layers.append(block(self.inplanes, planes, cardinality, stride, downsample)) 115 | self.inplanes = planes * block.expansion 116 | for i in range(1, blocks): 117 | layers.append(block(self.inplanes, planes, cardinality)) 118 | 119 | return nn.Sequential(*layers) 120 | 121 | def forward(self, x): 122 | x = self.conv1(x) 123 | x = self.bn1(x) 124 | x = self.relu(x) 125 | x = self.maxpool(x) 126 | 127 | x = self.layer1(x) 128 | x = self.layer2(x) 129 | x = self.layer3(x) 130 | x = self.layer4(x) 131 | x = self.avgpool(x) 132 | 133 | x = x.view(x.size(0), -1) 134 | if self.last_fc: 135 | x = self.fc(x) 136 | 137 | return x 138 | 139 | def get_fine_tuning_parameters(model, ft_begin_index): 140 | if ft_begin_index == 0: 141 | return model.parameters() 142 | 143 | ft_module_names = [] 144 | for i in range(ft_begin_index, 5): 145 | ft_module_names.append('layer{}'.format(ft_begin_index)) 146 | ft_module_names.append('fc') 147 | 148 | parameters = [] 149 | for k, v in model.named_parameters(): 150 | for ft_module in ft_module_names: 151 | if ft_module in k: 152 | parameters.append({'params': v}) 153 | break 154 | else: 155 | parameters.append({'params': v, 'lr': 0.0}) 156 | 157 | return parameters 158 | 159 | def resnet50(**kwargs): 160 | """Constructs a ResNet-50 model. 161 | """ 162 | model = ResNeXt(ResNeXtBottleneck, [3, 4, 6, 3], **kwargs) 163 | return model 164 | 165 | def resnet101(**kwargs): 166 | """Constructs a ResNet-101 model. 167 | """ 168 | model = ResNeXt(ResNeXtBottleneck, [3, 4, 23, 3], **kwargs) 169 | return model 170 | 171 | def resnet152(**kwargs): 172 | """Constructs a ResNet-101 model. 173 | """ 174 | model = ResNeXt(ResNeXtBottleneck, [3, 8, 36, 3], **kwargs) 175 | return model 176 | -------------------------------------------------------------------------------- /tools/VideoFeatureExtractor/videocnn/models/s3dg.py: -------------------------------------------------------------------------------- 1 | """Contains a PyTorch definition for Gated Separable 3D network (S3D-G) 2 | with a text module for computing joint text-video embedding from raw text 3 | and video input. The following code will enable you to load the HowTo100M 4 | pretrained S3D Text-Video model from: 5 | A. Miech, J.-B. Alayrac, L. Smaira, I. Laptev, J. Sivic and A. Zisserman, 6 | End-to-End Learning of Visual Representations from Uncurated Instructional Videos. 7 | https://arxiv.org/abs/1912.06430. 8 | 9 | S3D-G was proposed by: 10 | S. Xie, C. Sun, J. Huang, Z. Tu and K. Murphy, 11 | Rethinking Spatiotemporal Feature Learning For Video Understanding. 12 | https://arxiv.org/abs/1712.04851. 13 | Tensorflow code: https://github.com/tensorflow/models/blob/master/research/slim/nets/s3dg.py 14 | 15 | The S3D architecture was slightly modified with a space to depth trick for TPU 16 | optimization. 
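In this repository the model is built with last_fc=False (see model.py), so the forward
pass returns the 1024-dimensional spatio-temporally averaged Mixed_5c features rather
than class logits.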
17 | """ 18 | 19 | import torch as th 20 | import torch.nn.functional as F 21 | import torch.nn as nn 22 | import os 23 | import numpy as np 24 | import re 25 | 26 | 27 | class InceptionBlock(nn.Module): 28 | def __init__( 29 | self, 30 | input_dim, 31 | num_outputs_0_0a, 32 | num_outputs_1_0a, 33 | num_outputs_1_0b, 34 | num_outputs_2_0a, 35 | num_outputs_2_0b, 36 | num_outputs_3_0b, 37 | gating=True, 38 | ): 39 | super(InceptionBlock, self).__init__() 40 | self.conv_b0 = STConv3D(input_dim, num_outputs_0_0a, [1, 1, 1]) 41 | self.conv_b1_a = STConv3D(input_dim, num_outputs_1_0a, [1, 1, 1]) 42 | self.conv_b1_b = STConv3D( 43 | num_outputs_1_0a, num_outputs_1_0b, [3, 3, 3], padding=1, separable=True 44 | ) 45 | self.conv_b2_a = STConv3D(input_dim, num_outputs_2_0a, [1, 1, 1]) 46 | self.conv_b2_b = STConv3D( 47 | num_outputs_2_0a, num_outputs_2_0b, [3, 3, 3], padding=1, separable=True 48 | ) 49 | self.maxpool_b3 = th.nn.MaxPool3d((3, 3, 3), stride=1, padding=1) 50 | self.conv_b3_b = STConv3D(input_dim, num_outputs_3_0b, [1, 1, 1]) 51 | self.gating = gating 52 | self.output_dim = ( 53 | num_outputs_0_0a + num_outputs_1_0b + num_outputs_2_0b + num_outputs_3_0b 54 | ) 55 | if gating: 56 | self.gating_b0 = SelfGating(num_outputs_0_0a) 57 | self.gating_b1 = SelfGating(num_outputs_1_0b) 58 | self.gating_b2 = SelfGating(num_outputs_2_0b) 59 | self.gating_b3 = SelfGating(num_outputs_3_0b) 60 | 61 | def forward(self, input): 62 | """Inception block 63 | """ 64 | b0 = self.conv_b0(input) 65 | b1 = self.conv_b1_a(input) 66 | b1 = self.conv_b1_b(b1) 67 | b2 = self.conv_b2_a(input) 68 | b2 = self.conv_b2_b(b2) 69 | b3 = self.maxpool_b3(input) 70 | b3 = self.conv_b3_b(b3) 71 | if self.gating: 72 | b0 = self.gating_b0(b0) 73 | b1 = self.gating_b1(b1) 74 | b2 = self.gating_b2(b2) 75 | b3 = self.gating_b3(b3) 76 | return th.cat((b0, b1, b2, b3), dim=1) 77 | 78 | 79 | class SelfGating(nn.Module): 80 | def __init__(self, input_dim): 81 | super(SelfGating, self).__init__() 82 | self.fc = nn.Linear(input_dim, input_dim) 83 | 84 | def forward(self, input_tensor): 85 | """Feature gating as used in S3D-G. 
86 | """ 87 | spatiotemporal_average = th.mean(input_tensor, dim=[2, 3, 4]) 88 | weights = self.fc(spatiotemporal_average) 89 | weights = th.sigmoid(weights) 90 | return weights[:, :, None, None, None] * input_tensor 91 | 92 | 93 | class STConv3D(nn.Module): 94 | def __init__( 95 | self, input_dim, output_dim, kernel_size, stride=1, padding=0, separable=False 96 | ): 97 | super(STConv3D, self).__init__() 98 | self.separable = separable 99 | self.relu = nn.ReLU(inplace=True) 100 | assert len(kernel_size) == 3 101 | if separable and kernel_size[0] != 1: 102 | spatial_kernel_size = [1, kernel_size[1], kernel_size[2]] 103 | temporal_kernel_size = [kernel_size[0], 1, 1] 104 | if isinstance(stride, list) and len(stride) == 3: 105 | spatial_stride = [1, stride[1], stride[2]] 106 | temporal_stride = [stride[0], 1, 1] 107 | else: 108 | spatial_stride = [1, stride, stride] 109 | temporal_stride = [stride, 1, 1] 110 | if isinstance(padding, list) and len(padding) == 3: 111 | spatial_padding = [0, padding[1], padding[2]] 112 | temporal_padding = [padding[0], 0, 0] 113 | else: 114 | spatial_padding = [0, padding, padding] 115 | temporal_padding = [padding, 0, 0] 116 | if separable: 117 | self.conv1 = nn.Conv3d( 118 | input_dim, 119 | output_dim, 120 | kernel_size=spatial_kernel_size, 121 | stride=spatial_stride, 122 | padding=spatial_padding, 123 | bias=False, 124 | ) 125 | self.bn1 = nn.BatchNorm3d(output_dim) 126 | self.conv2 = nn.Conv3d( 127 | output_dim, 128 | output_dim, 129 | kernel_size=temporal_kernel_size, 130 | stride=temporal_stride, 131 | padding=temporal_padding, 132 | bias=False, 133 | ) 134 | self.bn2 = nn.BatchNorm3d(output_dim) 135 | else: 136 | self.conv1 = nn.Conv3d( 137 | input_dim, 138 | output_dim, 139 | kernel_size=kernel_size, 140 | stride=stride, 141 | padding=padding, 142 | bias=False, 143 | ) 144 | self.bn1 = nn.BatchNorm3d(output_dim) 145 | 146 | def forward(self, input): 147 | out = self.relu(self.bn1(self.conv1(input))) 148 | if self.separable: 149 | out = self.relu(self.bn2(self.conv2(out))) 150 | return out 151 | 152 | 153 | class MaxPool3dTFPadding(th.nn.Module): 154 | def __init__(self, kernel_size, stride=None, padding="SAME"): 155 | super(MaxPool3dTFPadding, self).__init__() 156 | if padding == "SAME": 157 | padding_shape = self._get_padding_shape(kernel_size, stride) 158 | self.padding_shape = padding_shape 159 | self.pad = th.nn.ConstantPad3d(padding_shape, 0) 160 | self.pool = th.nn.MaxPool3d(kernel_size, stride, ceil_mode=True) 161 | 162 | def _get_padding_shape(self, filter_shape, stride): 163 | def _pad_top_bottom(filter_dim, stride_val): 164 | pad_along = max(filter_dim - stride_val, 0) 165 | pad_top = pad_along // 2 166 | pad_bottom = pad_along - pad_top 167 | return pad_top, pad_bottom 168 | 169 | padding_shape = [] 170 | for filter_dim, stride_val in zip(filter_shape, stride): 171 | pad_top, pad_bottom = _pad_top_bottom(filter_dim, stride_val) 172 | padding_shape.append(pad_top) 173 | padding_shape.append(pad_bottom) 174 | depth_top = padding_shape.pop(0) 175 | depth_bottom = padding_shape.pop(0) 176 | padding_shape.append(depth_top) 177 | padding_shape.append(depth_bottom) 178 | return tuple(padding_shape) 179 | 180 | def forward(self, inp): 181 | inp = self.pad(inp) 182 | out = self.pool(inp) 183 | return out 184 | 185 | class S3D(nn.Module): 186 | def __init__(self, num_classes=512, gating=True, space_to_depth=True, last_fc=False): 187 | super(S3D, self).__init__() 188 | self.last_fc = last_fc 189 | self.num_classes = num_classes 190 | self.gating = 
gating 191 | self.space_to_depth = space_to_depth 192 | if space_to_depth: 193 | self.conv1 = STConv3D( 194 | 24, 64, [2, 4, 4], stride=1, padding=(1, 2, 2), separable=False 195 | ) 196 | else: 197 | self.conv1 = STConv3D( 198 | 3, 64, [3, 7, 7], stride=2, padding=(1, 3, 3), separable=False 199 | ) 200 | self.conv_2b = STConv3D(64, 64, [1, 1, 1], separable=False) 201 | self.conv_2c = STConv3D(64, 192, [3, 3, 3], padding=1, separable=True) 202 | self.gating = SelfGating(192) 203 | self.maxpool_2a = MaxPool3dTFPadding( 204 | kernel_size=(1, 3, 3), stride=(1, 2, 2), padding="SAME" 205 | ) 206 | self.maxpool_3a = MaxPool3dTFPadding( 207 | kernel_size=(1, 3, 3), stride=(1, 2, 2), padding="SAME" 208 | ) 209 | self.mixed_3b = InceptionBlock(192, 64, 96, 128, 16, 32, 32) 210 | self.mixed_3c = InceptionBlock( 211 | self.mixed_3b.output_dim, 128, 128, 192, 32, 96, 64 212 | ) 213 | self.maxpool_4a = MaxPool3dTFPadding( 214 | kernel_size=(3, 3, 3), stride=(2, 2, 2), padding="SAME" 215 | ) 216 | self.mixed_4b = InceptionBlock( 217 | self.mixed_3c.output_dim, 192, 96, 208, 16, 48, 64 218 | ) 219 | self.mixed_4c = InceptionBlock( 220 | self.mixed_4b.output_dim, 160, 112, 224, 24, 64, 64 221 | ) 222 | self.mixed_4d = InceptionBlock( 223 | self.mixed_4c.output_dim, 128, 128, 256, 24, 64, 64 224 | ) 225 | self.mixed_4e = InceptionBlock( 226 | self.mixed_4d.output_dim, 112, 144, 288, 32, 64, 64 227 | ) 228 | self.mixed_4f = InceptionBlock( 229 | self.mixed_4e.output_dim, 256, 160, 320, 32, 128, 128 230 | ) 231 | self.maxpool_5a = self.maxPool3d_5a_2x2 = MaxPool3dTFPadding( 232 | kernel_size=(2, 2, 2), stride=(2, 2, 2), padding="SAME" 233 | ) 234 | self.mixed_5b = InceptionBlock( 235 | self.mixed_4f.output_dim, 256, 160, 320, 32, 128, 128 236 | ) 237 | self.mixed_5c = InceptionBlock( 238 | self.mixed_5b.output_dim, 384, 192, 384, 48, 128, 128 239 | ) 240 | self.fc = nn.Linear(self.mixed_5c.output_dim, num_classes) 241 | 242 | def _space_to_depth(self, input): 243 | """3D space to depth trick for TPU optimization. 244 | """ 245 | B, C, T, H, W = input.shape 246 | input = input.view(B, C, T // 2, 2, H // 2, 2, W // 2, 2) 247 | input = input.permute(0, 3, 5, 7, 1, 2, 4, 6) 248 | input = input.contiguous().view(B, 8 * C, T // 2, H // 2, W // 2) 249 | return input 250 | 251 | def forward(self, inputs): 252 | """Defines the S3DG base architecture. 
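Expects clips of shape (B, 3, T, H, W). With space_to_depth enabled, frames are first
rearranged to (B, 24, T/2, H/2, W/2); the output is the (B, 1024) global average of the
Mixed_5c features, passed through the final fc only when last_fc=True.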
253 | """ 254 | if self.space_to_depth: 255 | inputs = self._space_to_depth(inputs) 256 | net = self.conv1(inputs) 257 | if self.space_to_depth: 258 | # we need to replicate 'SAME' tensorflow padding 259 | net = net[:, :, 1:, 1:, 1:] 260 | net = self.maxpool_2a(net) 261 | net = self.conv_2b(net) 262 | net = self.conv_2c(net) 263 | if self.gating: 264 | net = self.gating(net) 265 | net = self.maxpool_3a(net) 266 | net = self.mixed_3b(net) 267 | net = self.mixed_3c(net) 268 | net = self.maxpool_4a(net) 269 | net = self.mixed_4b(net) 270 | net = self.mixed_4c(net) 271 | net = self.mixed_4d(net) 272 | net = self.mixed_4e(net) 273 | net = self.mixed_4f(net) 274 | net = self.maxpool_5a(net) 275 | net = self.mixed_5b(net) 276 | net = self.mixed_5c(net) 277 | net = th.mean(net, dim=[2, 3, 4]) 278 | 279 | if self.last_fc: 280 | net = self.fc(net) 281 | 282 | return net 283 | -------------------------------------------------------------------------------- /tools/audio_process/extract_audio_feature.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset,DataLoader 2 | import torch 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /tools/audio_process/get_audio_from_video_v4.py: -------------------------------------------------------------------------------- 1 | from moviepy.editor import * 2 | import os 3 | from tqdm import tqdm 4 | from multiprocessing import Pool 5 | import requests 6 | import os 7 | from _md5 import md5 8 | from tqdm import tqdm 9 | from multiprocessing import Pool 10 | import random 11 | import json 12 | import json 13 | import jsonlines 14 | import pickle 15 | import csv 16 | import re 17 | import os.path 18 | import random 19 | from tqdm import tqdm 20 | 21 | 22 | class IOProcessor(): 23 | def read_jsonline(self,file): 24 | file=open(file,"r",encoding="utf-8") 25 | data=[json.loads(line) for line in file.readlines()] 26 | return data 27 | 28 | def write_jsonline(self,file,data): 29 | f=jsonlines.open(file,"w") 30 | for each in data: 31 | jsonlines.Writer.write(f,each) 32 | return 33 | 34 | def read_json(self,file): 35 | f=open(file,"r",encoding="utf-8").read() 36 | return json.loads(f) 37 | 38 | def write_json(self,file,data): 39 | f=open(file,"w",encoding="utf-8") 40 | json.dump(data,f,indent=2,ensure_ascii=False) 41 | return 42 | 43 | def read_pickle(self,filename): 44 | return pickle.loads(open(filename,"rb").read()) 45 | 46 | def write_pickle(self,filename,data): 47 | open(filename,"wb").write(pickle.dumps(data)) 48 | return 49 | 50 | def convert(file_list): 51 | for each in tqdm(file_list): 52 | try: 53 | id = each 54 | if id == "544053141067": 55 | print(id) 56 | if os.path.exists('/multi_modal/data/audios/{}.mp3'.format(id)): 57 | continue 58 | video = VideoFileClip("/multi_modal/data/videos/{}.mp4".format(each)) 59 | audio = video.audio 60 | audio.write_audiofile('/multi_modal/data/audios/{}.mp3'.format(id)) 61 | except Exception as e: 62 | print("error: ", e) 63 | 64 | return 65 | 66 | 67 | import random 68 | if __name__ == '__main__': 69 | print() 70 | io_pro=IOProcessor() 71 | 72 | ava_video_ids=list(io_pro.read_json("/multi_modal/data/product5m_v2/subset_v2_id_label.json").keys()) 73 | 74 | filter_data=ava_video_ids 75 | random.shuffle(filter_data) 76 | 77 | thread_num = 10 78 | chunk_size = int(len(filter_data) / thread_num) 79 | chunk_data = [filter_data[i:i + chunk_size] for i in 
range(0, len(filter_data), chunk_size)] 80 | 81 | pool = Pool(thread_num) 82 | # download_all_data_images(data) 83 | multi_data = pool.map(convert, chunk_data) 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /tools/audio_process/save_audio_feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | import requests 4 | from io import BytesIO 5 | from moviepy.editor import * 6 | import time 7 | from multiprocessing import Pool 8 | from tqdm import tqdm 9 | import json 10 | import random 11 | import pdb 12 | import numpy as np 13 | 14 | 15 | def read_json(file): 16 | f=open(file,"r",encoding="utf-8").read() 17 | return json.loads(f) 18 | 19 | def write_json(file,data): 20 | f=open(file,"w",encoding="utf-8") 21 | json.dump(data,f,indent=2,ensure_ascii=False) 22 | return 23 | 24 | 25 | 26 | audio_len=12 27 | devices=torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | audio_dir = "/multi_modal/data/audios" 29 | audio_feature_dir= "/multi_modal/data/audio_feature" 30 | def extract_audio_feature(item_id): 31 | # try: 32 | if not os.path.exists("{}/{}.mp3".format(audio_dir,item_id)): 33 | return None 34 | 35 | if os.path.exists("{}/{}.npy".format(audio_feature_dir,item_id)): 36 | return None 37 | 38 | try: 39 | audios,sample_rate = torchaudio.load("{}/{}.mp3".format(audio_dir,item_id)) 40 | 41 | audios=audios.to(devices) 42 | resample = torchaudio.transforms.Resample(sample_rate, 16000).to(devices) 43 | audios=resample(audios) 44 | audio_data = torch.sum(torch.as_tensor(audios), dim=0) / 2 45 | 46 | 47 | if (len(audio_data) / 16000 < audio_len): 48 | new_audio_data = torch.zeros([audio_len * 16000]).to(devices) 49 | new_audio_data[0:len(audio_data)] = audio_data 50 | audio_data = new_audio_data.to(devices) 51 | else: 52 | # audio_data = torch.as_tensor(audio_data.cpu().numpy())[:16000 * audio_len].to(devices) 53 | audio_data=audio_data[:16000 * audio_len] 54 | 55 | transform=torchaudio.transforms.MelSpectrogram(n_mels=80, 56 | n_fft=1024, 57 | win_length=1024, 58 | hop_length=256).to(devices) 59 | 60 | audio_feature = transform(audio_data) 61 | # (channel, n_mels, time) 62 | print("audio_feature1: ",audio_feature.size()) # 80 x 751 63 | cur_mean, cur_std = audio_feature.mean(dim=0), audio_feature.std(dim=0) 64 | audio_feature = (audio_feature - cur_mean) / (cur_std + 1e-9) 65 | # audio_feature = audio_feature.permute(1, 0) 66 | # num_audio = audio_feature.shape[0] 67 | # print("audio_feature2: ",audio_feature.size()) 68 | audio_feature_np=audio_feature.cpu().numpy() 69 | print("audio_feature_np: ",audio_feature_np.shape) 70 | 71 | np.save("{}/{}.npy".format(audio_feature_dir,item_id),audio_feature_np) 72 | 73 | except Exception as e: 74 | print(e) 75 | 76 | 77 | def extract_audios_feature(item_id_list): 78 | for item_id in tqdm(item_id_list): 79 | extract_audio_feature(item_id) 80 | # break 81 | 82 | 83 | if __name__ == '__main__': 84 | print() 85 | 86 | # item_id="123456" 87 | # video_url="http://cloud.video.taobao.com/play/u/1745634433/p/1/e/6/t/1/50152114431.mp4" 88 | # audio_processor.get_audio_feature(item_id,video_url) 89 | 90 | # one thread 91 | subset_v2_id_list=list(read_json("/multi_modal/data/product5m_v2/subset_v2_id_label.json").keys()) 92 | extract_audios_feature(subset_v2_id_list) 93 | 94 | # multi thread 95 | # subset_v2_id_list=list(read_json("/multi_modal/data/product5m_v2/subset_v2_id_label.json").keys()) 96 | 
# filter_data=subset_v2_id_list 97 | # random.shuffle(filter_data) 98 | # 99 | # thread_num = 10 100 | # chunk_size = int(len(filter_data) / thread_num) 101 | # chunk_data = [filter_data[i:i + chunk_size] for i in range(0, len(filter_data), chunk_size)] 102 | # 103 | # pool = Pool(thread_num) 104 | # # download_all_data_images(data) 105 | # multi_data = pool.map(extract_audios_feature, chunk_data) 106 | # 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /tools/bp_feature/convert/convert_gallery_c.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | # from tensorpack.dataflow import RNGDataFlow, PrefetchDataZMQ 4 | from tensorpack.dataflow import * 5 | import lmdb 6 | import json 7 | import pdb 8 | import csv 9 | FIELDNAMES = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features',"title"] 10 | import sys 11 | import pandas as pd 12 | import zlib 13 | import base64 14 | 15 | csv.field_size_limit(sys.maxsize) 16 | 17 | 18 | def read_json(file): 19 | f=open(file,"r",encoding="utf-8").read() 20 | return json.loads(f) 21 | 22 | def write_json(file,data): 23 | f=open(file,"w",encoding="utf-8") 24 | json.dump(data,f,indent=2,ensure_ascii=False) 25 | return 26 | 27 | 28 | def open_tsv(fname, folder): 29 | print("Opening %s Data File..." % fname) 30 | df = pd.read_csv(fname, sep=',', names=["caption", "url"], usecols=range(1, 3)) 31 | df['folder'] = folder 32 | print("Processing", len(df), " Images:") 33 | return df 34 | 35 | 36 | def _file_name(row): 37 | return "%s/%s" % (row['folder'], (zlib.crc32(row['url'].encode('utf-8')) & 0xffffffff)) 38 | 39 | class Conceptual_Caption(RNGDataFlow): 40 | """ 41 | """ 42 | def __init__(self, corpus_path, shuffle=False): 43 | """ 44 | Same as in :class:`ILSVRC12`. 
45 | """ 46 | self.shuffle = shuffle 47 | self.num_file = 30 48 | # self.name = os.path.join(corpus_path, 'conceptual_caption_trainsubset_resnet101_faster_rcnn_genome.tsv.%d') 49 | 50 | file_name_list=[ 51 | "gallery_c_0_15.tsv", 52 | "gallery_c_15_30.tsv", 53 | "gallery_c_30_45.tsv", 54 | "gallery_c_45_60.tsv", 55 | "gallery_c_60_75.tsv", 56 | "gallery_c_75_90.tsv", 57 | "gallery_c_90_105.tsv", 58 | "gallery_c_105_120.tsv" 59 | ] 60 | 61 | self.infiles = ["{}/{}".format("/multi_modal/data/tsv_features/gallery_c",i) for i in file_name_list] 62 | self.counts = [] 63 | # self.num_caps = 3136161 64 | # self.num_caps = 15027 + 113970 # TODO: modify each time 65 | self.num_caps = 1179613 66 | 67 | 68 | 69 | 70 | 71 | 72 | def __len__(self): 73 | return self.num_caps 74 | 75 | def __iter__(self): 76 | cnt = 0 77 | for infile in self.infiles: 78 | count = 0 79 | with open(infile) as tsv_in_file: 80 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES) 81 | for item in reader: 82 | cnt+=1 83 | 84 | # if cnt>521820 and cnt<521830: 85 | # continue 86 | 87 | try: 88 | image_id = item['image_id'] 89 | # if image_id not in self.id_list: continue 90 | 91 | image_h = int(item['image_h']) 92 | image_w = int(item['image_w']) 93 | num_boxes = item['num_boxes'] 94 | boxes = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape( 95 | int(num_boxes), 4) 96 | features = np.frombuffer(base64.b64decode(item['features']), dtype=np.float32).reshape( 97 | int(num_boxes), 2048) 98 | # print("image_id: ",image_id) 99 | caption = item["title"] 100 | except: 101 | print("error: ",cnt,image_id) 102 | continue 103 | yield [features, boxes, num_boxes, image_h, image_w, image_id, caption] 104 | 105 | 106 | 107 | 108 | import time 109 | if __name__ == '__main__': 110 | import time 111 | # time.sleep(3600*3) 112 | corpus_path = '' 113 | 114 | # time.sleep(25200) 115 | ds = Conceptual_Caption(corpus_path) 116 | 117 | # for each in ds: 118 | # print(each) 119 | # break 120 | 121 | ds1 = PrefetchDataZMQ(ds) 122 | LMDBSerializer.save(ds1, '/multi_modal/data/lmdb_features/gallery_c_feature.lmdb') 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | -------------------------------------------------------------------------------- /tools/bp_feature/convert/convert_gallery_fg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | # from tensorpack.dataflow import RNGDataFlow, PrefetchDataZMQ 4 | from tensorpack.dataflow import * 5 | import lmdb 6 | import json 7 | import pdb 8 | import csv 9 | FIELDNAMES = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features',"title"] 10 | import sys 11 | import pandas as pd 12 | import zlib 13 | import base64 14 | 15 | csv.field_size_limit(sys.maxsize) 16 | 17 | 18 | def read_json(file): 19 | f=open(file,"r",encoding="utf-8").read() 20 | return json.loads(f) 21 | 22 | def write_json(file,data): 23 | f=open(file,"w",encoding="utf-8") 24 | json.dump(data,f,indent=2,ensure_ascii=False) 25 | return 26 | 27 | 28 | def open_tsv(fname, folder): 29 | print("Opening %s Data File..." 
% fname) 30 | df = pd.read_csv(fname, sep=',', names=["caption", "url"], usecols=range(1, 3)) 31 | df['folder'] = folder 32 | print("Processing", len(df), " Images:") 33 | return df 34 | 35 | 36 | def _file_name(row): 37 | return "%s/%s" % (row['folder'], (zlib.crc32(row['url'].encode('utf-8')) & 0xffffffff)) 38 | 39 | class Conceptual_Caption(RNGDataFlow): 40 | """ 41 | """ 42 | def __init__(self, corpus_path, shuffle=False): 43 | """ 44 | Same as in :class:`ILSVRC12`. 45 | """ 46 | self.shuffle = shuffle 47 | self.num_file = 30 48 | # self.name = os.path.join(corpus_path, 'conceptual_caption_trainsubset_resnet101_faster_rcnn_genome.tsv.%d') 49 | 50 | file_name_list=[ 51 | "gallery_fg_0_15.tsv", 52 | ] 53 | 54 | self.infiles = ["{}/{}".format("/multi_modal/data/tsv_features",i) for i in file_name_list] 55 | self.counts = [] 56 | # self.num_caps = 3136161 57 | # self.num_caps = 15027 + 113970 # TODO: modify each time 58 | self.num_caps = 117791 59 | 60 | 61 | 62 | 63 | 64 | 65 | def __len__(self): 66 | return self.num_caps 67 | 68 | def __iter__(self): 69 | cnt = 0 70 | for infile in self.infiles: 71 | count = 0 72 | with open(infile) as tsv_in_file: 73 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES) 74 | for item in reader: 75 | cnt+=1 76 | 77 | # if cnt>521820 and cnt<521830: 78 | # continue 79 | 80 | try: 81 | image_id = item['image_id'] 82 | # if image_id not in self.id_list: continue 83 | 84 | image_h = int(item['image_h']) 85 | image_w = int(item['image_w']) 86 | num_boxes = item['num_boxes'] 87 | boxes = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape( 88 | int(num_boxes), 4) 89 | features = np.frombuffer(base64.b64decode(item['features']), dtype=np.float32).reshape( 90 | int(num_boxes), 2048) 91 | # print("image_id: ",image_id) 92 | caption = item["title"] 93 | except: 94 | print("error: ",cnt,image_id) 95 | continue 96 | yield [features, boxes, num_boxes, image_h, image_w, image_id, caption] 97 | 98 | 99 | 100 | 101 | import time 102 | if __name__ == '__main__': 103 | import time 104 | # time.sleep(3600*3) 105 | corpus_path = '' 106 | 107 | # time.sleep(25200) 108 | ds = Conceptual_Caption(corpus_path) 109 | 110 | # for each in ds: 111 | # print(each) 112 | # break 113 | 114 | ds1 = PrefetchDataZMQ(ds) 115 | LMDBSerializer.save(ds1, '/multi_modal/data/lmdb_features/gallery_fg_feature.lmdb') 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /tools/bp_feature/convert/convert_query_c.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | # from tensorpack.dataflow import RNGDataFlow, PrefetchDataZMQ 4 | from tensorpack.dataflow import * 5 | import lmdb 6 | import json 7 | import pdb 8 | import csv 9 | FIELDNAMES = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features',"title"] 10 | import sys 11 | import pandas as pd 12 | import zlib 13 | import base64 14 | 15 | csv.field_size_limit(sys.maxsize) 16 | 17 | 18 | def read_json(file): 19 | f=open(file,"r",encoding="utf-8").read() 20 | return json.loads(f) 21 | 22 | def write_json(file,data): 23 | f=open(file,"w",encoding="utf-8") 24 | json.dump(data,f,indent=2,ensure_ascii=False) 25 | return 26 | 27 | 28 | def open_tsv(fname, folder): 29 | print("Opening %s Data File..." 
% fname) 30 | df = pd.read_csv(fname, sep=',', names=["caption", "url"], usecols=range(1, 3)) 31 | df['folder'] = folder 32 | print("Processing", len(df), " Images:") 33 | return df 34 | 35 | 36 | def _file_name(row): 37 | return "%s/%s" % (row['folder'], (zlib.crc32(row['url'].encode('utf-8')) & 0xffffffff)) 38 | 39 | class Conceptual_Caption(RNGDataFlow): 40 | """ 41 | """ 42 | def __init__(self, corpus_path, shuffle=False): 43 | """ 44 | Same as in :class:`ILSVRC12`. 45 | """ 46 | self.shuffle = shuffle 47 | self.num_file = 30 48 | # self.name = os.path.join(corpus_path, 'conceptual_caption_trainsubset_resnet101_faster_rcnn_genome.tsv.%d') 49 | 50 | file_name_list=[ 51 | "query_c_0_15.tsv", 52 | ] 53 | 54 | self.infiles = ["{}/{}".format("/multi_modal/data/tsv_features",i) for i in file_name_list] 55 | self.counts = [] 56 | # self.num_caps = 3136161 57 | # self.num_caps = 15027 + 113970 # TODO: modify each time 58 | self.num_caps = 24054 59 | 60 | 61 | 62 | 63 | 64 | 65 | def __len__(self): 66 | return self.num_caps 67 | 68 | def __iter__(self): 69 | cnt = 0 70 | for infile in self.infiles: 71 | count = 0 72 | with open(infile) as tsv_in_file: 73 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES) 74 | for item in reader: 75 | cnt+=1 76 | 77 | # if cnt>521820 and cnt<521830: 78 | # continue 79 | 80 | try: 81 | image_id = item['image_id'] 82 | # if image_id not in self.id_list: continue 83 | 84 | image_h = int(item['image_h']) 85 | image_w = int(item['image_w']) 86 | num_boxes = item['num_boxes'] 87 | boxes = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape( 88 | int(num_boxes), 4) 89 | features = np.frombuffer(base64.b64decode(item['features']), dtype=np.float32).reshape( 90 | int(num_boxes), 2048) 91 | # print("image_id: ",image_id) 92 | caption = item["title"] 93 | except: 94 | print("error: ",cnt,image_id) 95 | continue 96 | yield [features, boxes, num_boxes, image_h, image_w, image_id, caption] 97 | 98 | 99 | 100 | 101 | import time 102 | if __name__ == '__main__': 103 | import time 104 | # time.sleep(3600*3) 105 | corpus_path = '' 106 | 107 | # time.sleep(25200) 108 | ds = Conceptual_Caption(corpus_path) 109 | 110 | # for each in ds: 111 | # print(each) 112 | # break 113 | 114 | ds1 = PrefetchDataZMQ(ds) 115 | LMDBSerializer.save(ds1, '/multi_modal/data/lmdb_features/query_c_feature.lmdb') 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /tools/bp_feature/convert/convert_query_fg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | # from tensorpack.dataflow import RNGDataFlow, PrefetchDataZMQ 4 | from tensorpack.dataflow import * 5 | import lmdb 6 | import json 7 | import pdb 8 | import csv 9 | FIELDNAMES = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features',"title"] 10 | import sys 11 | import pandas as pd 12 | import zlib 13 | import base64 14 | 15 | csv.field_size_limit(sys.maxsize) 16 | 17 | 18 | def read_json(file): 19 | f=open(file,"r",encoding="utf-8").read() 20 | return json.loads(f) 21 | 22 | def write_json(file,data): 23 | f=open(file,"w",encoding="utf-8") 24 | json.dump(data,f,indent=2,ensure_ascii=False) 25 | return 26 | 27 | 28 | def open_tsv(fname, folder): 29 | print("Opening %s Data File..." 
% fname) 30 | df = pd.read_csv(fname, sep=',', names=["caption", "url"], usecols=range(1, 3)) 31 | df['folder'] = folder 32 | print("Processing", len(df), " Images:") 33 | return df 34 | 35 | 36 | def _file_name(row): 37 | return "%s/%s" % (row['folder'], (zlib.crc32(row['url'].encode('utf-8')) & 0xffffffff)) 38 | 39 | class Conceptual_Caption(RNGDataFlow): 40 | """ 41 | """ 42 | def __init__(self, corpus_path, shuffle=False): 43 | """ 44 | Same as in :class:`ILSVRC12`. 45 | """ 46 | self.shuffle = shuffle 47 | self.num_file = 30 48 | # self.name = os.path.join(corpus_path, 'conceptual_caption_trainsubset_resnet101_faster_rcnn_genome.tsv.%d') 49 | 50 | file_name_list=[ 51 | "query_fg_0_15.tsv", 52 | ] 53 | 54 | self.infiles = ["{}/{}".format("/multi_modal/data/tsv_features",i) for i in file_name_list] 55 | self.counts = [] 56 | # self.num_caps = 3136161 57 | # self.num_caps = 15027 + 113970 # TODO: modify each time 58 | self.num_caps = 1991 59 | 60 | 61 | 62 | 63 | 64 | 65 | def __len__(self): 66 | return self.num_caps 67 | 68 | def __iter__(self): 69 | cnt = 0 70 | for infile in self.infiles: 71 | count = 0 72 | with open(infile) as tsv_in_file: 73 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES) 74 | for item in reader: 75 | cnt+=1 76 | 77 | # if cnt>521820 and cnt<521830: 78 | # continue 79 | 80 | try: 81 | image_id = item['image_id'] 82 | # if image_id not in self.id_list: continue 83 | 84 | image_h = int(item['image_h']) 85 | image_w = int(item['image_w']) 86 | num_boxes = item['num_boxes'] 87 | boxes = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape( 88 | int(num_boxes), 4) 89 | features = np.frombuffer(base64.b64decode(item['features']), dtype=np.float32).reshape( 90 | int(num_boxes), 2048) 91 | # print("image_id: ",image_id) 92 | caption = item["title"] 93 | except: 94 | print("error: ",cnt,image_id) 95 | continue 96 | yield [features, boxes, num_boxes, image_h, image_w, image_id, caption] 97 | 98 | 99 | 100 | 101 | import time 102 | if __name__ == '__main__': 103 | import time 104 | # time.sleep(3600*3) 105 | corpus_path = '' 106 | 107 | # time.sleep(25200) 108 | ds = Conceptual_Caption(corpus_path) 109 | 110 | # for each in ds: 111 | # print(each) 112 | # break 113 | 114 | ds1 = PrefetchDataZMQ(ds) 115 | LMDBSerializer.save(ds1, '/multi_modal/data/lmdb_features/query_fg_feature.lmdb') 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /tools/bp_feature/convert/convert_subset_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | # from tensorpack.dataflow import RNGDataFlow, PrefetchDataZMQ 4 | from tensorpack.dataflow import * 5 | import lmdb 6 | import json 7 | import pdb 8 | import csv 9 | FIELDNAMES = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features',"title"] 10 | import sys 11 | import pandas as pd 12 | import zlib 13 | import base64 14 | 15 | csv.field_size_limit(sys.maxsize) 16 | 17 | 18 | def read_json(file): 19 | f=open(file,"r",encoding="utf-8").read() 20 | return json.loads(f) 21 | 22 | def write_json(file,data): 23 | f=open(file,"w",encoding="utf-8") 24 | json.dump(data,f,indent=2,ensure_ascii=False) 25 | return 26 | 27 | 28 | def open_tsv(fname, folder): 29 | print("Opening %s Data File..." 
% fname) 30 | df = pd.read_csv(fname, sep=',', names=["caption", "url"], usecols=range(1, 3)) 31 | df['folder'] = folder 32 | print("Processing", len(df), " Images:") 33 | return df 34 | 35 | 36 | def _file_name(row): 37 | return "%s/%s" % (row['folder'], (zlib.crc32(row['url'].encode('utf-8')) & 0xffffffff)) 38 | 39 | class Conceptual_Caption(RNGDataFlow): 40 | """ 41 | """ 42 | def __init__(self, corpus_path, shuffle=False): 43 | """ 44 | Same as in :class:`ILSVRC12`. 45 | """ 46 | self.shuffle = shuffle 47 | self.num_file = 30 48 | # self.name = os.path.join(corpus_path, 'conceptual_caption_trainsubset_resnet101_faster_rcnn_genome.tsv.%d') 49 | 50 | file_name_list=[ 51 | "query_c_0_15.tsv", 52 | ] 53 | 54 | self.ava_ids=read_json("/multi_modal/data/product5m_v2/subset_test_ids.json") 55 | 56 | self.infiles = ["{}/{}".format("/multi_modal/data/tsv_features",i) for i in file_name_list] 57 | self.counts = [] 58 | # self.num_caps = 3136161 59 | # self.num_caps = 15027 + 113970 # TODO: modify each time 60 | self.num_caps = len(self.ava_ids) 61 | 62 | 63 | 64 | def __len__(self): 65 | return self.num_caps 66 | 67 | def __iter__(self): 68 | cnt = 0 69 | for infile in self.infiles: 70 | count = 0 71 | with open(infile) as tsv_in_file: 72 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES) 73 | for item in reader: 74 | cnt+=1 75 | 76 | 77 | # if cnt>521820 and cnt<521830: 78 | # continue 79 | 80 | try: 81 | image_id = item['image_id'] 82 | if image_id not in self.ava_ids: continue 83 | 84 | image_h = int(item['image_h']) 85 | image_w = int(item['image_w']) 86 | num_boxes = item['num_boxes'] 87 | boxes = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape( 88 | int(num_boxes), 4) 89 | features = np.frombuffer(base64.b64decode(item['features']), dtype=np.float32).reshape( 90 | int(num_boxes), 2048) 91 | # print("image_id: ",image_id) 92 | caption = item["title"] 93 | except: 94 | print("error: ",cnt,image_id) 95 | continue 96 | yield [features, boxes, num_boxes, image_h, image_w, image_id, caption] 97 | 98 | 99 | import time 100 | if __name__ == '__main__': 101 | import time 102 | # time.sleep(3600*3) 103 | corpus_path = '' 104 | 105 | # time.sleep(25200) 106 | ds = Conceptual_Caption(corpus_path) 107 | 108 | # for each in ds: 109 | # print(each) 110 | # break 111 | 112 | ds1 = PrefetchDataZMQ(ds) 113 | LMDBSerializer.save(ds1, '/multi_modal/data/lmdb_features/subset_test_feature.lmdb') 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /tools/bp_feature/convert/convert_subset_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | # from tensorpack.dataflow import RNGDataFlow, PrefetchDataZMQ 4 | from tensorpack.dataflow import * 5 | import lmdb 6 | import json 7 | import pdb 8 | import csv 9 | FIELDNAMES = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features',"title"] 10 | import sys 11 | import pandas as pd 12 | import zlib 13 | import base64 14 | 15 | csv.field_size_limit(sys.maxsize) 16 | 17 | 18 | def read_json(file): 19 | f=open(file,"r",encoding="utf-8").read() 20 | return json.loads(f) 21 | 22 | def write_json(file,data): 23 | f=open(file,"w",encoding="utf-8") 24 | json.dump(data,f,indent=2,ensure_ascii=False) 25 | return 26 | 27 | 28 | def open_tsv(fname, folder): 29 | print("Opening %s Data File..." 
% fname) 30 | df = pd.read_csv(fname, sep=',', names=["caption", "url"], usecols=range(1, 3)) 31 | df['folder'] = folder 32 | print("Processing", len(df), " Images:") 33 | return df 34 | 35 | 36 | def _file_name(row): 37 | return "%s/%s" % (row['folder'], (zlib.crc32(row['url'].encode('utf-8')) & 0xffffffff)) 38 | 39 | class Conceptual_Caption(RNGDataFlow): 40 | """ 41 | """ 42 | def __init__(self, corpus_path, shuffle=False): 43 | """ 44 | Same as in :class:`ILSVRC12`. 45 | """ 46 | self.shuffle = shuffle 47 | self.num_file = 30 48 | # self.name = os.path.join(corpus_path, 'conceptual_caption_trainsubset_resnet101_faster_rcnn_genome.tsv.%d') 49 | 50 | file_name_list=[ 51 | "query_c_0_15.tsv", 52 | ] 53 | 54 | self.ava_ids=read_json("/multi_modal/data/product5m_v2/subset_train_ids.json") 55 | 56 | self.infiles = ["{}/{}".format("/multi_modal/data/tsv_features",i) for i in file_name_list] 57 | self.counts = [] 58 | # self.num_caps = 3136161 59 | # self.num_caps = 15027 + 113970 # TODO: modify each time 60 | self.num_caps = len(self.ava_ids) 61 | 62 | 63 | 64 | def __len__(self): 65 | return self.num_caps 66 | 67 | def __iter__(self): 68 | cnt = 0 69 | for infile in self.infiles: 70 | count = 0 71 | with open(infile) as tsv_in_file: 72 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES) 73 | for item in reader: 74 | cnt+=1 75 | 76 | 77 | # if cnt>521820 and cnt<521830: 78 | # continue 79 | 80 | try: 81 | image_id = item['image_id'] 82 | if image_id not in self.ava_ids: continue 83 | 84 | image_h = int(item['image_h']) 85 | image_w = int(item['image_w']) 86 | num_boxes = item['num_boxes'] 87 | boxes = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape( 88 | int(num_boxes), 4) 89 | features = np.frombuffer(base64.b64decode(item['features']), dtype=np.float32).reshape( 90 | int(num_boxes), 2048) 91 | # print("image_id: ",image_id) 92 | caption = item["title"] 93 | except: 94 | print("error: ",cnt,image_id) 95 | continue 96 | yield [features, boxes, num_boxes, image_h, image_w, image_id, caption] 97 | 98 | 99 | import time 100 | if __name__ == '__main__': 101 | import time 102 | # time.sleep(3600*3) 103 | corpus_path = '' 104 | 105 | # time.sleep(25200) 106 | ds = Conceptual_Caption(corpus_path) 107 | 108 | # for each in ds: 109 | # print(each) 110 | # break 111 | 112 | ds1 = PrefetchDataZMQ(ds) 113 | LMDBSerializer.save(ds1, '/multi_modal/data/lmdb_features/subset_train_feature.lmdb') 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /tools/bp_feature/convert/convert_subset_v4_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | # from tensorpack.dataflow import RNGDataFlow, PrefetchDataZMQ 4 | from tensorpack.dataflow import * 5 | import lmdb 6 | import json 7 | import pdb 8 | import csv 9 | FIELDNAMES = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features',"title"] 10 | import sys 11 | import pandas as pd 12 | import zlib 13 | import base64 14 | 15 | csv.field_size_limit(sys.maxsize) 16 | 17 | 18 | def read_json(file): 19 | f=open(file,"r",encoding="utf-8").read() 20 | return json.loads(f) 21 | 22 | def write_json(file,data): 23 | f=open(file,"w",encoding="utf-8") 24 | json.dump(data,f,indent=2,ensure_ascii=False) 25 | return 26 | 27 | 28 | def open_tsv(fname, folder): 29 | print("Opening %s Data File..." 
% fname) 30 | df = pd.read_csv(fname, sep=',', names=["caption", "url"], usecols=range(1, 3)) 31 | df['folder'] = folder 32 | print("Processing", len(df), " Images:") 33 | return df 34 | 35 | 36 | def _file_name(row): 37 | return "%s/%s" % (row['folder'], (zlib.crc32(row['url'].encode('utf-8')) & 0xffffffff)) 38 | 39 | class Conceptual_Caption(RNGDataFlow): 40 | """ 41 | """ 42 | def __init__(self, corpus_path, shuffle=False): 43 | """ 44 | Same as in :class:`ILSVRC12`. 45 | """ 46 | self.shuffle = shuffle 47 | self.num_file = 30 48 | # self.name = os.path.join(corpus_path, 'conceptual_caption_trainsubset_resnet101_faster_rcnn_genome.tsv.%d') 49 | 50 | file_name_list=[ 51 | "subset_v2_test.tsv", 52 | ] 53 | 54 | self.ava_ids=list(read_json("/multi_modal/data/product5m_v2/subset_v4_test_id_label.json").keys()) 55 | 56 | self.infiles = ["{}/{}".format("/multi_modal/data/tsv_features",i) for i in file_name_list] 57 | self.counts = [] 58 | # self.num_caps = 3136161 59 | # self.num_caps = 15027 + 113970 # TODO: modify each time 60 | self.num_caps = len(self.ava_ids) 61 | 62 | 63 | 64 | def __len__(self): 65 | return self.num_caps 66 | 67 | def __iter__(self): 68 | cnt = 0 69 | for infile in self.infiles: 70 | count = 0 71 | with open(infile) as tsv_in_file: 72 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES) 73 | for item in reader: 74 | cnt+=1 75 | try: 76 | image_id = item['image_id'] 77 | if image_id not in self.ava_ids: continue 78 | 79 | image_h = int(item['image_h']) 80 | image_w = int(item['image_w']) 81 | num_boxes = item['num_boxes'] 82 | boxes = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape( 83 | int(num_boxes), 4) 84 | features = np.frombuffer(base64.b64decode(item['features']), dtype=np.float32).reshape( 85 | int(num_boxes), 2048) 86 | # print("image_id: ",image_id) 87 | caption = item["title"] 88 | except: 89 | print("error: ",cnt,image_id) 90 | continue 91 | yield [features, boxes, num_boxes, image_h, image_w, image_id, caption] 92 | 93 | 94 | import time 95 | if __name__ == '__main__': 96 | import time 97 | # time.sleep(3600*3) 98 | corpus_path = '' 99 | 100 | # time.sleep(25200) 101 | ds = Conceptual_Caption(corpus_path) 102 | 103 | # for each in ds: 104 | # print(each) 105 | # break 106 | 107 | ds1 = PrefetchDataZMQ(ds) 108 | LMDBSerializer.save(ds1, '/multi_modal/data/lmdb_features/subset_v4_test_feature.lmdb') 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /tools/bp_feature/convert/convert_subset_v4_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | # from tensorpack.dataflow import RNGDataFlow, PrefetchDataZMQ 4 | from tensorpack.dataflow import * 5 | import lmdb 6 | import json 7 | import pdb 8 | import csv 9 | FIELDNAMES = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features',"title"] 10 | import sys 11 | import pandas as pd 12 | import zlib 13 | import base64 14 | 15 | csv.field_size_limit(sys.maxsize) 16 | 17 | 18 | def read_json(file): 19 | f=open(file,"r",encoding="utf-8").read() 20 | return json.loads(f) 21 | 22 | def write_json(file,data): 23 | f=open(file,"w",encoding="utf-8") 24 | json.dump(data,f,indent=2,ensure_ascii=False) 25 | return 26 | 27 | 28 | def open_tsv(fname, folder): 29 | print("Opening %s Data File..." 
% fname) 30 | df = pd.read_csv(fname, sep=',', names=["caption", "url"], usecols=range(1, 3)) 31 | df['folder'] = folder 32 | print("Processing", len(df), " Images:") 33 | return df 34 | 35 | 36 | def _file_name(row): 37 | return "%s/%s" % (row['folder'], (zlib.crc32(row['url'].encode('utf-8')) & 0xffffffff)) 38 | 39 | class Conceptual_Caption(RNGDataFlow): 40 | """ 41 | """ 42 | def __init__(self, corpus_path, shuffle=False): 43 | """ 44 | Same as in :class:`ILSVRC12`. 45 | """ 46 | self.shuffle = shuffle 47 | self.num_file = 30 48 | # self.name = os.path.join(corpus_path, 'conceptual_caption_trainsubset_resnet101_faster_rcnn_genome.tsv.%d') 49 | 50 | file_name_list=[ 51 | "subset_v2_train.tsv", 52 | ] 53 | 54 | self.ava_ids=list(read_json("/multi_modal/data/product5m_v2/subset_v4_train_id_label.json").keys()) 55 | 56 | self.infiles = ["{}/{}".format("/multi_modal/data/tsv_features",i) for i in file_name_list] 57 | self.counts = [] 58 | # self.num_caps = 3136161 59 | # self.num_caps = 15027 + 113970 # TODO: modify each time 60 | self.num_caps = len(self.ava_ids) 61 | 62 | 63 | 64 | def __len__(self): 65 | return self.num_caps 66 | 67 | def __iter__(self): 68 | cnt = 0 69 | for infile in self.infiles: 70 | count = 0 71 | with open(infile) as tsv_in_file: 72 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES) 73 | for item in reader: 74 | cnt+=1 75 | try: 76 | image_id = item['image_id'] 77 | if image_id not in self.ava_ids: continue 78 | 79 | image_h = int(item['image_h']) 80 | image_w = int(item['image_w']) 81 | num_boxes = item['num_boxes'] 82 | boxes = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape( 83 | int(num_boxes), 4) 84 | features = np.frombuffer(base64.b64decode(item['features']), dtype=np.float32).reshape( 85 | int(num_boxes), 2048) 86 | # print("image_id: ",image_id) 87 | caption = item["title"] 88 | except: 89 | print("error: ",cnt,image_id) 90 | continue 91 | yield [features, boxes, num_boxes, image_h, image_w, image_id, caption] 92 | 93 | 94 | import time 95 | if __name__ == '__main__': 96 | import time 97 | # time.sleep(3600*3) 98 | corpus_path = '' 99 | 100 | # time.sleep(25200) 101 | ds = Conceptual_Caption(corpus_path) 102 | 103 | # for each in ds: 104 | # print(each) 105 | # break 106 | 107 | ds1 = PrefetchDataZMQ(ds) 108 | LMDBSerializer.save(ds1, '/multi_modal/data/lmdb_features/subset_v4_train_feature.lmdb') 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /tools/bp_feature/convert/convert_train_all.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | # from tensorpack.dataflow import RNGDataFlow, PrefetchDataZMQ 4 | from tensorpack.dataflow import * 5 | import lmdb 6 | import json 7 | import pdb 8 | import csv 9 | FIELDNAMES = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features',"title"] 10 | import sys 11 | import pandas as pd 12 | import zlib 13 | import base64 14 | 15 | csv.field_size_limit(sys.maxsize) 16 | 17 | 18 | def read_json(file): 19 | f=open(file,"r",encoding="utf-8").read() 20 | return json.loads(f) 21 | 22 | def write_json(file,data): 23 | f=open(file,"w",encoding="utf-8") 24 | json.dump(data,f,indent=2,ensure_ascii=False) 25 | return 26 | 27 | 28 | def open_tsv(fname, folder): 29 | print("Opening %s Data File..." 
% fname) 30 | df = pd.read_csv(fname, sep=',', names=["caption", "url"], usecols=range(1, 3)) 31 | df['folder'] = folder 32 | print("Processing", len(df), " Images:") 33 | return df 34 | 35 | 36 | def _file_name(row): 37 | return "%s/%s" % (row['folder'], (zlib.crc32(row['url'].encode('utf-8')) & 0xffffffff)) 38 | 39 | class Conceptual_Caption(RNGDataFlow): 40 | """ 41 | """ 42 | def __init__(self, corpus_path, shuffle=False): 43 | """ 44 | Same as in :class:`ILSVRC12`. 45 | """ 46 | self.shuffle = shuffle 47 | self.num_file = 30 48 | # self.name = os.path.join(corpus_path, 'conceptual_caption_trainsubset_resnet101_faster_rcnn_genome.tsv.%d') 49 | 50 | file_name_list=[ 51 | "train_2.5m_3.5m_50_62.tsv", 52 | "train_1m_2m_25_37.tsv", 53 | "train_3.5m_5m_75_87.tsv", 54 | "train_3.5m_5m_12_25.tsv", 55 | "train_1m_75_87.tsv", 56 | "train_2.5m_3.5m_62_75.tsv", 57 | "train_2m_2.5m_37_50.tsv", 58 | "train_3.5m_5m_37_50.tsv", 59 | "train_2.5m_3.5m_87_100.tsv", 60 | "train_1m_2m_50_62.tsv", 61 | "train_3.5m_5m_50_62.tsv", 62 | "train_1m_25_37.tsv", 63 | "train_1m_37_50.tsv", 64 | "train_1m_2m_75_87.tsv", 65 | "train_2.5m_3.5m_0_12.tsv", 66 | "train_1m_2m_12_25.tsv", 67 | "train_1m_12_25.tsv", 68 | "train_2m_2.5m_12_25.tsv", 69 | "train_1m_50_62.tsv", 70 | "train_2.5m_3.5m_75_87.tsv", 71 | "train_3.5m_5m_0_12.tsv", 72 | "train_1m_87_100.tsv", 73 | "train_1m_0_12.tsv", 74 | "train_3.5m_5m_87_100.tsv", 75 | "train_3.5m_5m_62_75.tsv", 76 | "train_1m_2m_87_100.tsv", 77 | "train_1m_2m_62_75.tsv", 78 | "train_2m_2.5m_25_37.tsv", 79 | "train_2m_2.5m_0_12.tsv", 80 | "train_2.5m_3.5m_12_25.tsv", 81 | "train_1m_62_75.tsv", 82 | "train_2.5m_3.5m_37_50.tsv", 83 | "train_1m_2m_37_50.tsv", 84 | "train_2.5m_3.5m_25_37.tsv", 85 | "train_3.5m_5m_25_37.tsv", 86 | "train_1m_2m_0_12.tsv" 87 | ] 88 | 89 | self.infiles = ["{}/{}".format("/multi_modal/data/tsv_features",i) for i in file_name_list] 90 | self.counts = [] 91 | # self.num_caps = 3136161 92 | # self.num_caps = 15027 + 113970 # TODO: modify each time 93 | self.num_caps = 4363122 94 | 95 | 96 | 97 | 98 | 99 | 100 | def __len__(self): 101 | return self.num_caps 102 | 103 | def __iter__(self): 104 | cnt = 0 105 | for infile in self.infiles: 106 | count = 0 107 | with open(infile) as tsv_in_file: 108 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES) 109 | for item in reader: 110 | cnt+=1 111 | 112 | # if cnt>521820 and cnt<521830: 113 | # continue 114 | 115 | try: 116 | image_id = item['image_id'] 117 | # if image_id not in self.id_list: continue 118 | 119 | image_h = int(item['image_h']) 120 | image_w = int(item['image_w']) 121 | num_boxes = item['num_boxes'] 122 | boxes = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape( 123 | int(num_boxes), 4) 124 | features = np.frombuffer(base64.b64decode(item['features']), dtype=np.float32).reshape( 125 | int(num_boxes), 2048) 126 | # print("image_id: ",image_id) 127 | caption = item["title"] 128 | except: 129 | print("error: ",cnt,image_id) 130 | continue 131 | yield [features, boxes, num_boxes, image_h, image_w, image_id, caption] 132 | 133 | 134 | 135 | 136 | import time 137 | if __name__ == '__main__': 138 | import time 139 | # time.sleep(3600*3) 140 | corpus_path = '' 141 | 142 | # time.sleep(25200) 143 | ds = Conceptual_Caption(corpus_path) 144 | 145 | # for each in ds: 146 | # print(each) 147 | # break 148 | 149 | ds1 = PrefetchDataZMQ(ds) 150 | LMDBSerializer.save(ds1, '/multi_modal/data/lmdb_features/train_feature.lmdb') 151 | 152 | 153 | 154 | 155 | 156 | 157 | 
158 | 159 | 160 | 161 | 162 | 163 | -------------------------------------------------------------------------------- /tools/bp_feature/convert/get_all_tsv_filename.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | 5 | def read_json(file): 6 | f=open(file,"r",encoding="utf-8").read() 7 | return json.loads(f) 8 | 9 | def write_json(file,data): 10 | f=open(file,"w",encoding="utf-8") 11 | json.dump(data,f,indent=2,ensure_ascii=False) 12 | return 13 | 14 | 15 | 16 | if __name__ == '__main__': 17 | print() 18 | 19 | 20 | tsv_file_list=list(os.listdir("/multi_modal/data/tsv_features")) 21 | 22 | write_json("tsv_filename_list.json", tsv_file_list) 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /tools/bp_feature/extract/run_subset_v3_1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | CUDA_VISIBLE_DEVICES=0 python generate_subset_tsv.py \ 5 | --start 0 \ 6 | --end 5 7 | 8 | -------------------------------------------------------------------------------- /tools/bp_feature/extract/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import jsonlines 3 | import pickle 4 | import csv 5 | import re 6 | 7 | class IOProcessor(): 8 | def read_jsonline(self,file): 9 | file=open(file,"r",encoding="utf-8") 10 | data=[json.loads(line) for line in file.readlines()] 11 | return data 12 | 13 | def write_jsonline(self,file,data): 14 | f=jsonlines.open(file,"w") 15 | for each in data: 16 | jsonlines.Writer.write(f,each) 17 | return 18 | 19 | def read_json(self,file): 20 | f=open(file,"r",encoding="utf-8").read() 21 | return json.loads(f) 22 | 23 | def write_json(self,file,data): 24 | f=open(file,"w",encoding="utf-8") 25 | json.dump(data,f,indent=2,ensure_ascii=False) 26 | return 27 | 28 | def read_pickle(self,filename): 29 | return pickle.loads(open(filename,"rb").read()) 30 | 31 | def write_pickle(self,filename,data): 32 | open(filename,"wb").write(pickle.dumps(data)) 33 | return 34 | 35 | 36 | def read_csv(self,filename): 37 | csv_data = csv.reader(open(filename, "r", encoding="utf-8")) 38 | csv_data=[each for each in csv_data] 39 | return csv_data 40 | 41 | 42 | if __name__ == '__main__': 43 | pass 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | --------------------------------------------------------------------------------