├── .gitignore ├── ALLIES ├── README.md ├── assets │ └── model.jpg ├── dataset │ ├── nq-test.jsonl │ ├── tqa-test.jsonl │ └── webq-test.jsonl ├── embedding_generation.ipynb ├── main.py ├── modeling_bert.py ├── retrieval_utils.py ├── tools.py └── utils.py ├── CAPSTONE ├── README.md ├── generate_query.sh ├── get_doc2query_marco.sh ├── get_msmarco.sh ├── get_trec.sh ├── merge_beir_result.py ├── models │ ├── co_training_generate_new_train_wiki.py │ ├── generate_query.py │ ├── merge_query.py │ ├── modules.py │ └── run_de_model_ernie.py ├── preprocess │ ├── merge_query.py │ ├── preprocess_msmarco.py │ └── preprocess_trec.py ├── run_de_model_expand_corpus_cocondenser.sh ├── run_de_model_expand_corpus_cocondenser_step2.sh └── utils │ ├── __init__.py │ ├── data_utils.py │ ├── dpr_utils.py │ ├── evaluate_trec.py │ ├── lamb.py │ ├── metric_utils.py │ └── util.py ├── CODE_OF_CONDUCT.md ├── HOW_TO_DOWNLOAD.md ├── LEAD ├── README.md ├── assets │ ├── main_result.jpg │ └── model.jpg ├── data_preprocess │ ├── .DS_Store │ ├── data_preprocess.py │ ├── get_data.sh │ └── hard_negative_generation │ │ ├── retrieve_hard_negatives_academic.sh │ │ └── train_de.sh ├── dataset.py ├── distillation │ ├── distill_from_12cb_to_6de.sh │ ├── distill_from_12ce_to_6de.sh │ ├── distill_from_12de_to_6de.sh │ └── distill_from_24cb_to_12de.sh ├── inference │ ├── evaluate_12_layer_de.sh │ └── evaluate_6_layer_de.sh ├── inference_de.py ├── modeling_bert.py ├── modeling_distilbert.py ├── models.py ├── retrieve_hard_negative.py ├── run_LEAD.py ├── run_single_model.py ├── util.py └── warm_up │ ├── train_12_layer_ce.sh │ ├── train_12_layer_col.sh │ ├── train_12_layer_de.sh │ ├── train_24_layer_col.sh │ └── train_6_layer_de.sh ├── LICENSE ├── MASTER ├── README.md ├── figs │ ├── master_main.jpg │ └── master_main_result.jpg ├── finetune │ ├── MS │ │ ├── co_training_model.py │ │ ├── co_training_model_ele.py │ │ ├── co_training_model_ele_continue.py │ │ ├── inference_de.py │ │ ├── inference_de_prob.py │ │ ├── run_ce_model.py │ │ ├── run_ce_model_ele.py │ │ └── run_de_model.py │ ├── ft_MS_MASTER.sh │ ├── ft_wiki_NQ.sh │ ├── ft_wiki_TQ.sh │ ├── model │ │ ├── models.py │ │ └── models_ele.py │ ├── utils │ │ ├── MARCO_until.py │ │ ├── dpr_utils.py │ │ ├── lamb.py │ │ └── util.py │ └── wiki │ │ ├── co_training_model.py │ │ ├── inference_de.py │ │ ├── run_ce_model.py │ │ └── run_de_model.py └── pretrain │ ├── arguments.py │ ├── data.py │ ├── modeling.py │ ├── run_pre_training.py │ ├── run_pretrain.sh │ └── trainer.py ├── PROD ├── MarcoDoc_Data.sh ├── MarcoPas_Data.sh ├── ProD_KD │ ├── model │ │ ├── __init__.py │ │ └── models.py │ ├── run_progressive_distill_marco.py │ ├── run_progressive_distill_marcodoc.py │ ├── run_progressive_distill_nq.py │ └── utils │ │ ├── __init__.py │ │ ├── build_marco_train.py │ │ ├── dataset_division_marco.py │ │ ├── dataset_division_marcodoc.py │ │ ├── dataset_division_nq.py │ │ ├── dpr_utils.py │ │ ├── lamb.py │ │ ├── marco_until.py │ │ ├── prepare_ce_data_nq.py │ │ ├── preprae_ce_marco_train.py │ │ ├── preprae_ce_marcodoc_train.py │ │ └── util.py ├── ProD_base │ ├── inference_DE_marco.py │ ├── inference_DE_marcodoc.py │ ├── inference_DE_nq.py │ ├── inference_DE_trec.py │ ├── model │ │ ├── __init__.py │ │ └── models.py │ ├── rerank_eval_marco.py │ ├── rerank_eval_marcodoc.py │ ├── rerank_eval_nq.py │ ├── rerank_train_eval_marco.py │ ├── rerank_train_eval_marcodoc.py │ ├── rerank_train_eval_nq.py │ ├── train_CE_model_marco.py │ ├── train_CE_model_marcodoc.py │ ├── train_CE_model_nq.py │ ├── train_DE_model_marco.py │ 
├── train_DE_model_marcodoc.py │ ├── train_DE_model_nq.py │ └── utils │ │ ├── __init__.py │ │ ├── dpr_utils.py │ │ ├── lamb.py │ │ ├── marco_until.py │ │ └── util.py ├── README.md └── image │ └── framework.jpg ├── README.md ├── SECURITY.md ├── SUPPORT.md └── SimANS ├── Doc_training ├── co_training_doc_generate.py ├── co_training_doc_train.py ├── co_training_generate_new_train.py └── star_tokenizer.py ├── README.md ├── best_simans_ckpt ├── MS-Doc │ └── .gitkeep ├── MS-Pas │ └── .gitkeep ├── NQ │ └── .gitkeep └── TQ │ └── .gitkeep ├── ckpt ├── MS-Doc │ └── .gitkeep ├── MS-Pas │ └── .gitkeep ├── NQ │ └── .gitkeep └── TQ │ └── .gitkeep ├── co_training ├── co_training_generate.py ├── co_training_marco_generate.py └── co_training_marco_train.py ├── data ├── MS-Doc │ └── .gitkeep ├── MS-Pas │ └── .gitkeep ├── NQ │ └── .gitkeep └── TQ │ └── .gitkeep ├── figs ├── simans_industry_result.jpg ├── simans_main.jpg └── simans_main_result.jpg ├── model └── models.py ├── train_MS_Doc_AR2.sh ├── train_MS_Pas_AR2.sh ├── train_NQ_AR2.sh ├── train_TQ_AR2.sh ├── utils ├── MARCO_until_Doc.py ├── MARCO_until_new.py ├── dpr_utils.py ├── util.py └── util_wiki.py └── wiki ├── co_training_generate_new_train_wiki.py ├── co_training_wiki_generate.py └── co_training_wiki_train.py /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.rsuser 8 | *.suo 9 | *.user 10 | *.userosscache 11 | *.sln.docstates 12 | 13 | # User-specific files (MonoDevelop/Xamarin Studio) 14 | *.userprefs 15 | 16 | # Mono auto generated files 17 | mono_crash.* 18 | 19 | # Build results 20 | [Dd]ebug/ 21 | [Dd]ebugPublic/ 22 | [Rr]elease/ 23 | [Rr]eleases/ 24 | x64/ 25 | x86/ 26 | [Aa][Rr][Mm]/ 27 | [Aa][Rr][Mm]64/ 28 | bld/ 29 | [Bb]in/ 30 | [Oo]bj/ 31 | [Ll]og/ 32 | [Ll]ogs/ 33 | 34 | # Visual Studio 2015/2017 cache/options directory 35 | .vs/ 36 | # Uncomment if you have tasks that create the project's static files in wwwroot 37 | #wwwroot/ 38 | 39 | # Visual Studio 2017 auto generated files 40 | Generated\ Files/ 41 | 42 | # MSTest test Results 43 | [Tt]est[Rr]esult*/ 44 | [Bb]uild[Ll]og.* 45 | 46 | # NUnit 47 | *.VisualState.xml 48 | TestResult.xml 49 | nunit-*.xml 50 | 51 | # Build Results of an ATL Project 52 | [Dd]ebugPS/ 53 | [Rr]eleasePS/ 54 | dlldata.c 55 | 56 | # Benchmark Results 57 | BenchmarkDotNet.Artifacts/ 58 | 59 | # .NET Core 60 | project.lock.json 61 | project.fragment.lock.json 62 | artifacts/ 63 | 64 | # StyleCop 65 | StyleCopReport.xml 66 | 67 | # Files built by Visual Studio 68 | *_i.c 69 | *_p.c 70 | *_h.h 71 | *.ilk 72 | *.meta 73 | *.obj 74 | *.iobj 75 | *.pch 76 | *.pdb 77 | *.ipdb 78 | *.pgc 79 | *.pgd 80 | *.rsp 81 | *.sbr 82 | *.tlb 83 | *.tli 84 | *.tlh 85 | *.tmp 86 | *.tmp_proj 87 | *_wpftmp.csproj 88 | *.log 89 | *.vspscc 90 | *.vssscc 91 | .builds 92 | *.pidb 93 | *.svclog 94 | *.scc 95 | 96 | # Chutzpah Test files 97 | _Chutzpah* 98 | 99 | # Visual C++ cache files 100 | ipch/ 101 | *.aps 102 | *.ncb 103 | *.opendb 104 | *.opensdf 105 | *.sdf 106 | *.cachefile 107 | *.VC.db 108 | *.VC.VC.opendb 109 | 110 | # Visual Studio profiler 111 | *.psess 112 | *.vsp 113 | *.vspx 114 | *.sap 115 | 116 | # Visual Studio Trace Files 117 | *.e2e 118 | 119 | # TFS 2012 Local Workspace 120 | $tf/ 121 | 122 | # Guidance Automation 
Toolkit 123 | *.gpState 124 | 125 | # ReSharper is a .NET coding add-in 126 | _ReSharper*/ 127 | *.[Rr]e[Ss]harper 128 | *.DotSettings.user 129 | 130 | # TeamCity is a build add-in 131 | _TeamCity* 132 | 133 | # DotCover is a Code Coverage Tool 134 | *.dotCover 135 | 136 | # AxoCover is a Code Coverage Tool 137 | .axoCover/* 138 | !.axoCover/settings.json 139 | 140 | # Visual Studio code coverage results 141 | *.coverage 142 | *.coveragexml 143 | 144 | # NCrunch 145 | _NCrunch_* 146 | .*crunch*.local.xml 147 | nCrunchTemp_* 148 | 149 | # MightyMoose 150 | *.mm.* 151 | AutoTest.Net/ 152 | 153 | # Web workbench (sass) 154 | .sass-cache/ 155 | 156 | # Installshield output folder 157 | [Ee]xpress/ 158 | 159 | # DocProject is a documentation generator add-in 160 | DocProject/buildhelp/ 161 | DocProject/Help/*.HxT 162 | DocProject/Help/*.HxC 163 | DocProject/Help/*.hhc 164 | DocProject/Help/*.hhk 165 | DocProject/Help/*.hhp 166 | DocProject/Help/Html2 167 | DocProject/Help/html 168 | 169 | # Click-Once directory 170 | publish/ 171 | 172 | # Publish Web Output 173 | *.[Pp]ublish.xml 174 | *.azurePubxml 175 | # Note: Comment the next line if you want to checkin your web deploy settings, 176 | # but database connection strings (with potential passwords) will be unencrypted 177 | *.pubxml 178 | *.publishproj 179 | 180 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 181 | # checkin your Azure Web App publish settings, but sensitive information contained 182 | # in these scripts will be unencrypted 183 | PublishScripts/ 184 | 185 | # NuGet Packages 186 | *.nupkg 187 | # NuGet Symbol Packages 188 | *.snupkg 189 | # The packages folder can be ignored because of Package Restore 190 | **/[Pp]ackages/* 191 | # except build/, which is used as an MSBuild target. 192 | !**/[Pp]ackages/build/ 193 | # Uncomment if necessary however generally it will be regenerated when needed 194 | #!**/[Pp]ackages/repositories.config 195 | # NuGet v3's project.json files produces more ignorable files 196 | *.nuget.props 197 | *.nuget.targets 198 | 199 | # Microsoft Azure Build Output 200 | csx/ 201 | *.build.csdef 202 | 203 | # Microsoft Azure Emulator 204 | ecf/ 205 | rcf/ 206 | 207 | # Windows Store app package directories and files 208 | AppPackages/ 209 | BundleArtifacts/ 210 | Package.StoreAssociation.xml 211 | _pkginfo.txt 212 | *.appx 213 | *.appxbundle 214 | *.appxupload 215 | 216 | # Visual Studio cache files 217 | # files ending in .cache can be ignored 218 | *.[Cc]ache 219 | # but keep track of directories ending in .cache 220 | !?*.[Cc]ache/ 221 | 222 | # Others 223 | ClientBin/ 224 | ~$* 225 | *~ 226 | *.dbmdl 227 | *.dbproj.schemaview 228 | *.jfm 229 | *.pfx 230 | *.publishsettings 231 | orleans.codegen.cs 232 | 233 | # Including strong name files can present a security risk 234 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 235 | #*.snk 236 | 237 | # Since there are multiple workflows, uncomment next line to ignore bower_components 238 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 239 | #bower_components/ 240 | 241 | # RIA/Silverlight projects 242 | Generated_Code/ 243 | 244 | # Backup & report files from converting an old project file 245 | # to a newer Visual Studio version. 
Backup files are not needed, 246 | # because we have git ;-) 247 | _UpgradeReport_Files/ 248 | Backup*/ 249 | UpgradeLog*.XML 250 | UpgradeLog*.htm 251 | ServiceFabricBackup/ 252 | *.rptproj.bak 253 | 254 | # SQL Server files 255 | *.mdf 256 | *.ldf 257 | *.ndf 258 | 259 | # Business Intelligence projects 260 | *.rdl.data 261 | *.bim.layout 262 | *.bim_*.settings 263 | *.rptproj.rsuser 264 | *- [Bb]ackup.rdl 265 | *- [Bb]ackup ([0-9]).rdl 266 | *- [Bb]ackup ([0-9][0-9]).rdl 267 | 268 | # Microsoft Fakes 269 | FakesAssemblies/ 270 | 271 | # GhostDoc plugin setting file 272 | *.GhostDoc.xml 273 | 274 | # Node.js Tools for Visual Studio 275 | .ntvs_analysis.dat 276 | node_modules/ 277 | 278 | # Visual Studio 6 build log 279 | *.plg 280 | 281 | # Visual Studio 6 workspace options file 282 | *.opt 283 | 284 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 285 | *.vbw 286 | 287 | # Visual Studio LightSwitch build output 288 | **/*.HTMLClient/GeneratedArtifacts 289 | **/*.DesktopClient/GeneratedArtifacts 290 | **/*.DesktopClient/ModelManifest.xml 291 | **/*.Server/GeneratedArtifacts 292 | **/*.Server/ModelManifest.xml 293 | _Pvt_Extensions 294 | 295 | # Paket dependency manager 296 | .paket/paket.exe 297 | paket-files/ 298 | 299 | # FAKE - F# Make 300 | .fake/ 301 | 302 | # CodeRush personal settings 303 | .cr/personal 304 | 305 | # Python Tools for Visual Studio (PTVS) 306 | __pycache__/ 307 | *.pyc 308 | 309 | # Cake - Uncomment if you are using it 310 | # tools/** 311 | # !tools/packages.config 312 | 313 | # Tabs Studio 314 | *.tss 315 | 316 | # Telerik's JustMock configuration file 317 | *.jmconfig 318 | 319 | # BizTalk build output 320 | *.btp.cs 321 | *.btm.cs 322 | *.odx.cs 323 | *.xsd.cs 324 | 325 | # OpenCover UI analysis results 326 | OpenCover/ 327 | 328 | # Azure Stream Analytics local run output 329 | ASALocalRun/ 330 | 331 | # MSBuild Binary and Structured Log 332 | *.binlog 333 | 334 | # NVidia Nsight GPU debugger configuration file 335 | *.nvuser 336 | 337 | # MFractors (Xamarin productivity tool) working folder 338 | .mfractor/ 339 | 340 | # Local History for Visual Studio 341 | .localhistory/ 342 | 343 | # BeatPulse healthcheck temp database 344 | healthchecksdb 345 | 346 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 347 | MigrationBackup/ 348 | 349 | # Ionide (cross platform F# VS Code tools) working folder 350 | .ionide/ 351 | -------------------------------------------------------------------------------- /ALLIES/README.md: -------------------------------------------------------------------------------- 1 | # ALLIES 2 | 3 | The code for our paper [ALLIES: Prompting Large Language Model with Beam Search](https://arxiv.org/abs/2305.14766). 4 | 5 | ![model](assets/model.jpg) 6 | 7 | ## Dataset 8 | 9 | ### NQ 10 | 11 | dataset/nq-test.jsonl 12 | 13 | ### TriviaQA 14 | 15 | dataset/tqa-test.jsonl 16 | 17 | ### WebQ 18 | 19 | dataset/webq-test.jsonl 20 | 21 | 22 | ## Released Resources 23 | 24 | We release the preprocessed data and trained ckpts in [Azure Blob](https://msranlciropen.blob.core.windows.net/simxns/ALLIES/). 25 | Here we also provide the file list under this URL: 26 |
27 | 28 |
INFO: nq/de-checkpoint-10000/passage_embedding.pb;  Content Length: 60.13 GiB
29 | INFO: nq/de-checkpoint-10000/passage_embedding2id.pb;  Content Length: 160.33 MiB
30 | INFO: webq/de-checkpoint-400/passage_embedding.pb;  Content Length: 60.13 GiB
31 | INFO: webq/de-checkpoint-400/passage_embedding2id.pb;  Content Length: 160.33 MiB
32 | INFO: tq/de-checkpoint-10000/passage_embedding.pb;  Content Length: 60.13 GiB
33 | INFO: tq/de-checkpoint-10000/passage_embedding2id.pb;  Content Length: 160.33 MiB
34 |
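Each `passage_embedding.pb` / `passage_embedding2id.pb` pair is a pickled NumPy array (the format written out in Step 5 of `embedding_generation.ipynb` below). A minimal loading sketch, assuming the blobs have been downloaded into a local directory with the same layout (the paths below are hypothetical):

```python
import pickle

# Hypothetical local copy of the released blobs; adjust to your download location.
emb_path = 'nq/de-checkpoint-10000/passage_embedding.pb'
id_path = 'nq/de-checkpoint-10000/passage_embedding2id.pb'

with open(emb_path, 'rb') as f:
    passage_embedding = pickle.load(f)      # array of dense passage vectors
with open(id_path, 'rb') as f:
    passage_embedding2id = pickle.load(f)   # array mapping row index -> passage id

print(passage_embedding.shape, passage_embedding2id.shape)
```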
35 | 36 | To download the files, please refer to [HOW_TO_DOWNLOAD](https://github.com/microsoft/SimXNS/tree/main/HOW_TO_DOWNLOAD.md). 37 | 38 | 39 | 40 | ## Run 41 | 42 | ### Directly Answer 43 | 44 | ``` 45 | python main.py --dataset $dataset --task answer_without_retrieval --apikey $ID 46 | ``` 47 | 48 | ### Answer with retrieval 49 | 50 | ``` 51 | python main.py --dataset $dataset --task answer_with_retrieval --topK $retrieval_num --apikey $ID 52 | ``` 53 | 54 | ### GenRead 55 | 56 | ``` 57 | python main.py --dataset $dataset --task genread --apikey $ID 58 | ``` 59 | 60 | ### Allies 61 | 62 | ``` 63 | ##GENREAD 64 | python main.py --dataset $dataset --task ALLIES --retrieval_type generate --beam_size $beam_size --beam_Depth $beam_depth --ask_question_num $ask_question_num --apikey $ID 65 | 66 | ##Retrieval 67 | python main.py --dataset $dataset --task ALLIES --topK $retrieval_num --retrieval_type retrieve --beam_size $beam_size --beam_Depth $beam_depth --ask_question_num $ask_question_num --apikey $ID 68 | ``` 69 | 70 | ## Parameters 71 | 72 | - $dataset: Dataset for testing 73 | - $ID: The key for API 74 | - $beam_size: Beam size 75 | - $beam_depth: Beam depth 76 | - $ask_question_num: Ask question number 77 | - $retrieval_num: Retrieval doc num 78 | -------------------------------------------------------------------------------- /ALLIES/assets/model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/ALLIES/assets/model.jpg -------------------------------------------------------------------------------- /ALLIES/embedding_generation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "92874557", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from tqdm import trange\n", 11 | "import os\n", 12 | "import pickle" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "id": "542ce71c", 18 | "metadata": {}, 19 | "source": [ 20 | "## Step 1: Prepare the training dataset for QA" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "f92b9312", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "train_data_list = []\n", 31 | "\n", 32 | "for qid in trange(len(dataset)):\n", 33 | " example = {}\n", 34 | " \n", 35 | " positive_ctxs = []\n", 36 | " \n", 37 | " example['question'] = dataset[qid]['question']\n", 38 | " \n", 39 | " example['answers'] = dataset[qid]['answer']\n", 40 | " \n", 41 | " train_data_list.append(example)\n", 42 | "\n", 43 | "with open(f'/{data_path}/biencoder-dataset-train.json', 'w', encoding='utf-8') as json_file:\n", 44 | " json_file.write(json.dumps(train_data_list, indent=4))" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "id": "e8a951ff", 50 | "metadata": {}, 51 | "source": [ 52 | "## Step 2: Retrieve the initial documents and generate the training data for dense retrieval" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "37c2f93b", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "## Please follow the instruction in LEAD (https://github.com/microsoft/SimXNS/tree/main/LEAD)\n", 63 | "\n", 64 | "## bash retrieve_hard_negatives_academic.sh $DATASET $MASTER_PORT $MODEL_PATH $CKPT_NAME $MAX_DOC_LENGTH $MAX_QUERY_LENGTH" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "id": "9c017de6", 70 | 
"metadata": {}, 71 | "source": [ 72 | "## Step 3: Train the retriever using the generated dataset" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "id": "866ad0b7", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "## Please follow the instruction in LEAD (https://github.com/microsoft/SimXNS/tree/main/LEAD)\n", 83 | "\n", 84 | "## bash train_12_layer_de.sh $DATASET $MASTER_PORT $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "id": "4d370f43", 90 | "metadata": {}, 91 | "source": [ 92 | "## Step 4: Output the embedding file using the trained checkpoint" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "id": "f47e70ec", 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "## Please follow the instruction in LEAD (https://github.com/microsoft/SimXNS/tree/main/LEAD)\n", 103 | "\n", 104 | "## bash evaluate_12_layer_de.sh $DATASET $MASTER_PORT $MODEL_PATH $FIRST_STEPS $EVAL_STEPS $MAX_STEPS $MAX_DOC_LENGTH $MAX_QUERY_LENGTH" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "id": "ab021dbd", 110 | "metadata": {}, 111 | "source": [ 112 | "## Step 5: Merge the inference the embedding file as follow:" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "id": "23acb765", 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "passage_embedding_list = []\n", 123 | "passage_embedding_id_list = []\n", 124 | "for i in trange(N): \n", 125 | " pickle_path = os.path.join(f'/{data_path}/', \"{1}_data_obj_{0}.pb\".format(str(i), 'passage_embedding'))\n", 126 | " with open(pickle_path, 'rb') as handle:\n", 127 | " b = pickle.load(handle)\n", 128 | " passage_embedding_list.append(b)\n", 129 | "for i in trange(N): \n", 130 | " pickle_path = os.path.join(f'/{data_path}/', \"{1}_data_obj_{0}.pb\".format(str(i), 'passage_embedding_id'))\n", 131 | " with open(pickle_path, 'rb') as handle:\n", 132 | " b = pickle.load(handle)\n", 133 | " passage_embedding_id_list.append(b)\n", 134 | "passage_embedding = np.concatenate(passage_embedding_list, axis=0)\n", 135 | "passage_embedding_id = np.concatenate(passage_embedding_id_list, axis=0)\n", 136 | "\n", 137 | "with open(f'/{data_path}/passage_embedding.pb', 'wb') as f:\n", 138 | " pickle.dump(passage_embedding, f)\n", 139 | "f.close()\n", 140 | "\n", 141 | "with open(f'/{data_path}/passage_embedding2id.pb', 'wb') as f:\n", 142 | " pickle.dump(passage_embedding_id, f)\n", 143 | "f.close()" 144 | ] 145 | } 146 | ], 147 | "metadata": { 148 | "kernelspec": { 149 | "display_name": "py37", 150 | "language": "python", 151 | "name": "py37" 152 | }, 153 | "language_info": { 154 | "codemirror_mode": { 155 | "name": "ipython", 156 | "version": 3 157 | }, 158 | "file_extension": ".py", 159 | "mimetype": "text/x-python", 160 | "name": "python", 161 | "nbconvert_exporter": "python", 162 | "pygments_lexer": "ipython3", 163 | "version": "3.10.9" 164 | } 165 | }, 166 | "nbformat": 4, 167 | "nbformat_minor": 5 168 | } 169 | -------------------------------------------------------------------------------- /ALLIES/utils.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import json 3 | import time 4 | from tqdm import trange 5 | import regex 6 | import json 7 | import string 8 | import unicodedata 9 | from typing import List 10 | import numpy as np 11 | from collections import Counter 12 | from rouge import 
Rouge 13 | import json 14 | import openai 15 | import backoff 16 | import os 17 | from multiprocessing.pool import ThreadPool 18 | import threading 19 | import time 20 | import datetime 21 | from msal import PublicClientApplication, SerializableTokenCache 22 | import json 23 | import os 24 | import atexit 25 | import requests 26 | from utils import * 27 | from retrieval_utils import * 28 | import _thread 29 | from contextlib import contextmanager 30 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer 31 | 32 | 33 | import signal 34 | 35 | cache = {} 36 | default_engine = None 37 | 38 | from transformers import AutoTokenizer, AutoModel 39 | 40 | class SimpleTokenizer(object): 41 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 42 | NON_WS = r'[^\p{Z}\p{C}]' 43 | 44 | def __init__(self): 45 | """ 46 | Args: 47 | annotators: None or empty set (only tokenizes). 48 | """ 49 | self._regexp = regex.compile( 50 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 51 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 52 | ) 53 | 54 | def tokenize(self, text, uncased=False): 55 | matches = [m for m in self._regexp.finditer(text)] 56 | if uncased: 57 | tokens = [m.group().lower() for m in matches] 58 | else: 59 | tokens = [m.group() for m in matches] 60 | return tokens 61 | 62 | def readfiles(infile): 63 | 64 | if infile.endswith('json'): 65 | lines = json.load(open(infile, 'r', encoding='utf8')) 66 | elif infile.endswith('jsonl'): 67 | lines = open(infile, 'r', encoding='utf8').readlines() 68 | lines = [json.loads(l) for l in lines] 69 | else: 70 | raise NotImplementedError 71 | 72 | if len(lines[0]) == 1 and lines[0].get('prompt'): 73 | lines = lines[1:] ## skip prompt line 74 | 75 | return lines 76 | 77 | def write(log_file, text): 78 | log_file.write(text + '\n') 79 | log_file.flush() 80 | 81 | def create_directory(path): 82 | try: 83 | os.makedirs(path) 84 | print(f"Created directory: {path}") 85 | except FileExistsError: 86 | print(f"Directory already exists: {path}") 87 | 88 | def sparse_retrieval(searcher_sparse, query, top_K): 89 | hits_sparse = searcher_sparse.search(query) 90 | result = [] 91 | for i in range(top_K): 92 | result.append([field.stringValue() for field in hits_sparse[i].lucene_document.getFields()][1]) 93 | return result 94 | 95 | def retrieval(questions, top_k, tokenizer, model, cpu_index, passage_embedding2id, passages, args): 96 | question_embedding = get_question_embeddings(questions, tokenizer, model, args) 97 | _, dev_I = cpu_index.search(question_embedding.astype(np.float32), top_k) # I: [number of queries, topk] 98 | topk_document = [passage_embedding2id[dev_I[0][index]] for index in range(len(dev_I[0]))] 99 | 100 | return [passages[doc][1] for doc in topk_document] 101 | 102 | def check_answer(example, tokenizer) -> List[bool]: 103 | """Search through all the top docs to see if they have any of the answers.""" 104 | answers = example['answers'] 105 | ctxs = example['ctxs'] 106 | 107 | hits = [] 108 | 109 | for _, doc in enumerate(ctxs): 110 | text = doc['text'] 111 | 112 | if text is None: # cannot find the document for some reason 113 | hits.append(False) 114 | continue 115 | 116 | hits.append(has_answer(answers, text, tokenizer)) 117 | 118 | return hits 119 | 120 | def has_answer(answers, text, tokenizer=SimpleTokenizer()) -> bool: 121 | """Check if a document contains an answer string.""" 122 | text = _normalize(text) 123 | text = tokenizer.tokenize(text, uncased=True) 124 | 125 | for answer in answers: 126 | answer = _normalize(answer) 127 | answer = 
tokenizer.tokenize(answer, uncased=True) 128 | for i in range(0, len(text) - len(answer) + 1): 129 | if answer == text[i: i + len(answer)]: 130 | return True 131 | return False 132 | 133 | def _normalize(text): 134 | return unicodedata.normalize('NFD', text) 135 | 136 | def normalize_answer(s): 137 | def remove_articles(text): 138 | return regex.sub(r'\b(a|an|the)\b', ' ', text) 139 | 140 | def white_space_fix(text): 141 | return ' '.join(text.split()) 142 | 143 | def remove_punc(text): 144 | exclude = set(string.punctuation) 145 | return ''.join(ch for ch in text if ch not in exclude) 146 | 147 | def lower(text): 148 | return text.lower() 149 | 150 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 151 | 152 | def exact_match_score(prediction, ground_truth): 153 | return normalize_answer(prediction) == normalize_answer(ground_truth) 154 | 155 | def ems(prediction, ground_truths): 156 | return max([exact_match_score(prediction, gt) for gt in ground_truths]) 157 | 158 | def f1_score(prediction, ground_truth): 159 | prediction_tokens = normalize_answer(prediction).split() 160 | ground_truth_tokens = normalize_answer(ground_truth).split() 161 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 162 | num_same = sum(common.values()) 163 | if num_same == 0: 164 | return 0 165 | precision = 1.0 * num_same / len(prediction_tokens) 166 | recall = 1.0 * num_same / len(ground_truth_tokens) 167 | f1 = (2 * precision * recall) / (precision + recall) 168 | return f1 169 | 170 | def f1(prediction, ground_truths): 171 | return max([f1_score(prediction, gt) for gt in ground_truths]) 172 | 173 | def rougel_score(prediction, ground_truth): 174 | rouge = Rouge() 175 | # no normalization 176 | try: 177 | scores = rouge.get_scores(prediction, ground_truth, avg=True) 178 | except ValueError: # "Hypothesis is empty." 
179 | return 0.0 180 | return scores["rouge-l"]["f"] 181 | 182 | def rl(prediction, ground_truths): 183 | return max([rougel_score(prediction, gt) for gt in ground_truths]) 184 | 185 | def run_inference_openai(inputs_with_prompts): 186 | for _ in range(200): 187 | try: 188 | completions = openai.ChatCompletion.create( 189 | model="gpt-3.5-turbo", 190 | messages=[ 191 | {"role": "system", "content": "You are a helpful AI assistant."}, 192 | {"role": "user", "content": inputs_with_prompts}, 193 | ], 194 | temperature=0, 195 | ) 196 | except: 197 | continue 198 | 199 | outputs = [c["message"]['content'] for c in completions["choices"]] 200 | return outputs, completions['usage']['total_tokens'] 201 | 202 | def print_args(args): 203 | print("=================================================================") 204 | print("======================General Setting=========================") 205 | print(f"Dataset:".rjust(30) + " " + str(args.dataset)) 206 | print(f"Data Path:".rjust(30) + " " + str(args.data_path)) 207 | print(f"Task:".rjust(30) + " " + str(args.task)) 208 | print(f"Retrieval TopK:".rjust(30) + " " + str(args.topK)) 209 | print(f"Beam Size:".rjust(30) + " " + str(args.beam_size)) 210 | print(f"Beam Depth:".rjust(30) + " " + str(args.beam_Depth)) 211 | print(f"Ask Question Num:".rjust(30) + " " + str(args.ask_question_num)) 212 | print(f"Threshold:".rjust(30) + " " + str(args.threshold)) 213 | print(f"Device:".rjust(30) + " " + str(args.device)) 214 | print(f"Summary:".rjust(30) + " " + str(args.summary)) 215 | print(f"Retrieval Type:".rjust(30) + " " + str(args.retrieval_type)) 216 | print(f"Identifier:".rjust(30) + " " + str(args.unique_identifier)) 217 | print(f"Save File:".rjust(30) + " " + str(args.save_file)) 218 | print("=================================================================") 219 | print("=================================================================") 220 | 221 | def set_openai(args): 222 | openai.api_key = args.apikey 223 | -------------------------------------------------------------------------------- /CAPSTONE/README.md: -------------------------------------------------------------------------------- 1 | # CAPSTONE 2 | This repo provides the code and models in [*CAPSTONE*](https://arxiv.org/abs/2212.09114). 3 | In the paper, we propose CAPSTONE, a curriculum sampling for dense retrieval with document expansion, to bridge the gap between training and inference for dual-cross-encoder. 4 | 5 | 6 | # Requirements 7 | pip install datasets==2.4.0 8 | pip install rouge_score==0.1.2 9 | pip install nltk 10 | pip install transformers==4.21.1 11 | 12 | conda install -y -c pytorch faiss-gpu==1.7.1 13 | pip install tensorboard 14 | pip install pytrec-eval 15 | 16 | # Get Data 17 | Download the cleaned corpus hosted by RocketQA team, generate BM25 negatives for MS-MARCO. Then, download [TREC-Deep-Learning-2019](https://microsoft.github.io/msmarco/TREC-Deep-Learning-2019)(TREC DL 19) and [TREC-Deep-Learning-2020](https://microsoft.github.io/msmarco/TREC-Deep-Learning-2020)(TREC DL 20). 18 | 19 | ``` 20 | bash get_msmarco.sh 21 | cd .. 22 | bash get_trec.sh 23 | python preprocess/preprocess_msmarco.py --data_dir ./marco --output_path ./reformed_marco 24 | python preprocess/preprocess_trec.py 25 | ``` 26 | 27 | Download the generated queries for MS-MARCO and merge duplicated queries. Note the query file is around 19GB. 
28 | ``` 29 | bash get_doc2query_marco.sh 30 | python preprocess/merge_query.py 31 | ``` 32 | 33 | # Train CAPSTONE on MS-MARCO 34 | 35 | Train CAPSTONE on MS-MARCO for two stages and initialize the retriever with coCondenser at each stage. 36 | At the first training stage, the hard negatives are sampled from the official BM25 hard negatives, but at the second training stage, the hard negatives are sampled from the mined hard negatives. 37 | 38 | Evaluate CAPSTONE on the MS-MARCO development set, TREC DL 19 and 20 test sets. 39 | ```bash 40 | bash run_de_model_expand_corpus_cocondenser.sh 41 | bash run_de_model_expand_corpus_cocondenser_step2.sh 42 | ``` 43 | If you want to evaluate CAPSTONE on BEIR benchmark, you should download [BEIR](https://github.com/beir-cellar/beir) datasets and generate queries for BEIR. 44 | ```bash 45 | bash generate_query.sh 46 | ``` 47 | 48 | # Well-trained Checkpoints 49 | | Model | Download link 50 | |----------------------|--------| 51 | | The checkpoint for CAPSTONE trained on MS-MARCO at the first step| [\[link\]](https://drive.google.com/file/d/1QTsHQV8BQDJmxGD--fr4hmYdjpizE0zq/view?usp=sharing) | 52 | | The checkpoint for CAPSTONE trained on MS-MARCO at the second step| [\[link\]](https://drive.google.com/file/d/1tssOGIRwwXpn2yg4StiRC3B3u0_UqzLG/view?usp=sharing) | 53 | 54 | 55 | 56 | 57 | 58 | # Citation 59 | If you want to use this code in your research, please cite our [paper](https://arxiv.org/abs/2212.09114): 60 | ```bibtex 61 | @inproceedings{he-CAPSTONE, 62 | title = "CAPSTONE: Curriculum Sampling for Dense Retrieval with Document Expansion", 63 | author = "Xingwei He and Yeyun Gong and A-Long Jin and Hang Zhang and Anlei Dong and Jian Jiao and Siu Ming Yiu and Nan Duan", 64 | booktitle = "Proceedings of EMNLP", 65 | year = "2023", 66 | } 67 | ``` -------------------------------------------------------------------------------- /CAPSTONE/generate_query.sh: -------------------------------------------------------------------------------- 1 | # generate queries for beir data 2 | for dataset in cqadupstack/android cqadupstack/english cqadupstack/gaming cqadupstack/gis cqadupstack/mathematica cqadupstack/physics cqadupstack/programmers cqadupstack/stats cqadupstack/tex cqadupstack/unix cqadupstack/webmasters cqadupstack/wordpress \ 3 | quora robust04 trec-news nq signal1m dbpedia-entity \ 4 | nfcorpus scifact arguana scidocs fiqa trec-covid webis-touche2020 \ 5 | bioasq hotpotqa fever climate-fever 6 | do 7 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9589 \ 8 | ./models/generate_query.py --dataset $dataset 9 | done 10 | -------------------------------------------------------------------------------- /CAPSTONE/get_doc2query_marco.sh: -------------------------------------------------------------------------------- 1 | 2 | # get the generated query from https://github.com/castorini/docTTTTTquery 3 | mkdir tmp 4 | mkdir docTTTTTquery_full 5 | cd tmp 6 | wget https://git.uwaterloo.ca/jimmylin/doc2query-data/raw/master/T5-passage/predicted_queries_topk_sampling.zip 7 | 8 | unzip predicted_queries_topk_sampling.zip 9 | 10 | for i in $(seq -f "%03g" 0 17); do 11 | echo "Processing chunk $i" 12 | paste predicted_queries_topk_sample???.txt${i}-1004000 \ 13 | > predicted_queries_topk.txt${i}-1004000 14 | done 15 | 16 | cat predicted_queries_topk.txt???-1004000 > doc2query.tsv 17 | mv doc2query.tsv ../docTTTTTquery_full 18 | cd .. 
19 | rm -rf tmp 20 | -------------------------------------------------------------------------------- /CAPSTONE/get_msmarco.sh: -------------------------------------------------------------------------------- 1 | # this shell is from https://github.com/NLPCode/tevatron/blob/main/examples/coCondenser-marco/get_data.sh 2 | SCRIPT_DIR=$PWD 3 | 4 | wget --no-check-certificate https://rocketqa.bj.bcebos.com/corpus/marco.tar.gz 5 | tar -zxf marco.tar.gz 6 | rm -rf marco.tar.gz 7 | cd marco 8 | 9 | wget https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz 10 | wget https://msmarco.blob.core.windows.net/msmarcoranking/qrels.train.tsv -O qrels.train.tsv 11 | gunzip qidpidtriples.train.full.2.tsv.gz 12 | join -t "$(echo -en '\t')" -e '' -a 1 -o 1.1 2.2 1.2 <(sort -k1,1 para.txt) <(sort -k1,1 para.title.txt) | sort -k1,1 -n > corpus.tsv 13 | awk -v RS='\r\n' '$1==last {printf ",%s",$3; next} NR>1 {print "";} {last=$1; printf "%s\t%s",$1,$3;} END{print "";}' qidpidtriples.train.full.2.tsv > train.negatives.tsv 14 | 15 | 16 | -------------------------------------------------------------------------------- /CAPSTONE/get_trec.sh: -------------------------------------------------------------------------------- 1 | # this shell is from https://github.com/NLPCode/tevatron/blob/main/examples/coCondenser-marco/get_data.sh 2 | SCRIPT_DIR=$PWD 3 | 4 | #https://microsoft.github.io/msmarco/TREC-Deep-Learning-2019 5 | mkdir trec_19 6 | cd trec_19 7 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz 8 | gunzip msmarco-test2019-queries.tsv.gz 9 | rm msmarco-test2019-queries.tsv.gz 10 | wget --no-check-certificate https://trec.nist.gov/data/deep/2019qrels-pass.txt 11 | 12 | cd .. 13 | 14 | #https://microsoft.github.io/msmarco/TREC-Deep-Learning-2020 15 | mkdir trec_20 16 | cd trec_20 17 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2020-queries.tsv.gz 18 | gunzip msmarco-test2020-queries.tsv.gz 19 | rm msmarco-test2020-queries.tsv.gz 20 | wget --no-check-certificate https://trec.nist.gov/data/deep/2020qrels-pass.txt -------------------------------------------------------------------------------- /CAPSTONE/merge_beir_result.py: -------------------------------------------------------------------------------- 1 | """ 2 | this file is used to merge the beir results of each dataset. 
3 | """ 4 | import argparse 5 | import os 6 | import json 7 | corpus_list = ['trec-covid', 'bioasq', 'nfcorpus', 'nq', 'hotpotqa', 'fiqa', 'signal1m', 'trec-news', 'robust04', 'arguana', 8 | 'webis-touche2020', 'cqadupstack', 'quora', 'dbpedia-entity', 'scidocs', 'fever', 'climate-fever', 'scifact' ] 9 | corpus_list2 = ["cqadupstack/android", "cqadupstack/english", "cqadupstack/gaming", "cqadupstack/gis", "cqadupstack/mathematica", "cqadupstack/physics", 10 | "cqadupstack/programmers", "cqadupstack/stats", "cqadupstack/tex", "cqadupstack/unix", "cqadupstack/webmasters", "cqadupstack/wordpress"] 11 | 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--beir_data_path", default=None, type=str, help="path of a dataset in beir") 16 | parser.add_argument("--checkpoint_num", default=20000, type=int) 17 | args = parser.parse_args() 18 | print(os.path.join(args.beir_data_path, f'test_eval_result{args.checkpoint_num}.json')) 19 | with open(os.path.join(args.beir_data_path, f'test_eval_result{args.checkpoint_num}.txt'), 'w') as fw: 20 | fw.write("NDCG@10\n") 21 | # fw.write('K=10\n') 22 | # total_value = 0 23 | # for corpus in corpus_list: 24 | # if corpus=='cqadupstack': 25 | # total = 0 26 | # for subcorpus in corpus_list2: 27 | # filename = os.path.join(args.beir_data_path, f'{subcorpus}', f'test_eval_result{args.checkpoint_num}_0_query.json') 28 | # with open(filename, 'r') as fr: 29 | # results = json.load(fr) 30 | # total += float(results['NDCG@10']) 31 | # value = total/len(corpus_list2) 32 | # total_value += value 33 | # value = str(value) 34 | 35 | # else: 36 | # filename = os.path.join(args.beir_data_path, f'{corpus}', f'test_eval_result{args.checkpoint_num}_0_query.json') 37 | # if os.path.exists(filename): 38 | # with open(filename, 'r') as fr: 39 | # results = json.load(fr) 40 | # value = results['NDCG@10'] 41 | # else: 42 | # value ='0' 43 | # print(f'{corpus} no results.') 44 | # total_value += float(value) 45 | # fw.write(f'{corpus}: {value:.3}\n') 46 | # fw.write(f'Average: {total_value/len(corpus_list):.3}\n') 47 | 48 | # fw.write('K=5\n') 49 | # total_value = 0 50 | # for corpus in corpus_list: 51 | # if corpus=='cqadupstack': 52 | # total = 0 53 | # for subcorpus in corpus_list2: 54 | # # filename = os.path.join(args.beir_data_path, f'{subcorpus}', f'test_eval_result{args.checkpoint_num}.json') 55 | # filename = os.path.join(args.beir_data_path, f'{subcorpus}', f'test_eval_result{args.checkpoint_num}_0_5_query.json') 56 | # with open(filename, 'r') as fr: 57 | # results = json.load(fr) 58 | # total += float(results['NDCG@10']) 59 | # value = total/len(corpus_list2) 60 | # total_value += value 61 | # value = str(value) 62 | 63 | # else: 64 | # # filename = os.path.join(args.beir_data_path, f'{corpus}', f'test_eval_result{args.checkpoint_num}.json') 65 | # filename = os.path.join(args.beir_data_path, f'{corpus}', f'test_eval_result{args.checkpoint_num}_0_5_query.json') 66 | # if os.path.exists(filename): 67 | # with open(filename, 'r') as fr: 68 | # results = json.load(fr) 69 | # value = results['NDCG@10'] 70 | # else: 71 | # value ='0' 72 | # print(f'{corpus} no results.') 73 | # total_value += float(value) 74 | # fw.write(f'{corpus}: {value:.3}\n') 75 | # fw.write(f'Average: {total_value/len(corpus_list):.3}\n') 76 | 77 | 78 | fw.write('base\n') 79 | total_value = 0 80 | for corpus in corpus_list: 81 | if corpus=='cqadupstack': 82 | total = 0 83 | for subcorpus in corpus_list2: 84 | # filename = os.path.join(args.beir_data_path, 
f'{subcorpus}', f'test_eval_result{args.checkpoint_num}.json') 85 | filename = os.path.join(args.beir_data_path, f'{subcorpus}', f'test_eval_result{args.checkpoint_num}.json') 86 | with open(filename, 'r') as fr: 87 | results = json.load(fr) 88 | total += float(results['NDCG@10']) 89 | value = total/len(corpus_list2) 90 | total_value += value 91 | value = str(value) 92 | 93 | else: 94 | # filename = os.path.join(args.beir_data_path, f'{corpus}', f'test_eval_result{args.checkpoint_num}.json') 95 | filename = os.path.join(args.beir_data_path, f'{corpus}', f'test_eval_result{args.checkpoint_num}.json') 96 | if os.path.exists(filename): 97 | with open(filename, 'r') as fr: 98 | results = json.load(fr) 99 | value = results['NDCG@10'] 100 | else: 101 | value ='0' 102 | print(f'{corpus} no results.') 103 | total_value += float(value) 104 | fw.write(f'{corpus}: {value:.3}\n') 105 | fw.write(f'Average: {total_value/len(corpus_list):.3}\n') 106 | -------------------------------------------------------------------------------- /CAPSTONE/models/merge_query.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | This file aims to generate queries for a given document, which users may ask based on the given document. 5 | Input: a document 6 | Output: top-k queries, 7 | """ 8 | from dataclasses import dataclass, field 9 | import csv 10 | import pickle 11 | from turtle import title 12 | from transformers import ( 13 | HfArgumentParser, 14 | ) 15 | 16 | import os 17 | import sys 18 | import json 19 | import random 20 | from tabulate import tabulate 21 | from tqdm import tqdm 22 | import logging 23 | 24 | pyfile_path = os.path.abspath(__file__) 25 | pyfile_dir = os.path.dirname(os.path.dirname(pyfile_path)) # equals to the path '../' 26 | sys.path.append(pyfile_dir) 27 | from utils.util import normalize_question, set_seed, get_optimizer, sum_main 28 | from utils.data_utils import load_passage 29 | 30 | # logger = logging.getLogger(__name__) 31 | logger = logging.getLogger("__main__") 32 | 33 | @dataclass 34 | class Arguments: 35 | dir_name: str = field( 36 | default="", 37 | metadata={ 38 | "help": ( 39 | "The path containing the generated queries." 40 | ) 41 | }, 42 | ) 43 | 44 | seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."}) 45 | 46 | # for cross train 47 | num_split: int = field(default=1, 48 | metadata={"help": "Split the train set into num_split parts. " 49 | "Note There is no overlap between different parts in terms of the positive doc." 50 | "Suppose num_split=3, we need to train three different generators. " 51 | "The first generator is tained on the data_2 and data_3, " 52 | "the second is trained on the data_1 and data_3, the third is trained on data_1 and data_2."}) 53 | 54 | 55 | 56 | 57 | def main(): 58 | parser = HfArgumentParser((Arguments)) 59 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 60 | # If we pass only one argument to the script and it's the path to a json file, 61 | # let's parse it to get our arguments. 
62 | args, = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 63 | else: 64 | args, = parser.parse_args_into_dataclasses() 65 | print(args) 66 | removed_passage_ids_list = [] 67 | # removed_passage_id_path = os.path.join('checkpoints', args.dir_name, f'0_th_generator/removed_passge_id_for_0_th_generator.json') 68 | # removed_passage_id_set = set(json.load(open(removed_passage_id_path, 'r', encoding='utf-8'))) 69 | 70 | # removed_passage_id_path = os.path.join('checkpoints', args.dir_name, f'1_th_generator/removed_passge_id_for_1_th_generator.json') 71 | # removed_passage_id_set2 = set(json.load(open(removed_passage_id_path, 'r', encoding='utf-8'))) 72 | # print(len(removed_passage_id_set), len(removed_passage_id_set2), len(removed_passage_id_set|removed_passage_id_set2), len(removed_passage_id_set&removed_passage_id_set2)) 73 | # exit() 74 | union_set = set() 75 | for i in range(args.num_split): 76 | # query_path = os.path.join(args.dir_name, f'{i}_th_generator/corpus/top_k_10_rs_10.json') 77 | removed_passage_id_path = os.path.join('checkpoints', args.dir_name, f'{i}_th_generator/removed_passge_id_for_{i}_th_generator.json') 78 | print(os.path.exists(removed_passage_id_path)) 79 | removed_passage_ids = json.load(open(removed_passage_id_path, 'r', encoding='utf-8')) 80 | removed_passage_ids_list.append(removed_passage_ids) 81 | union_set = union_set | set(removed_passage_ids) 82 | 83 | 84 | final_queries_dict = {} 85 | for i in range(args.num_split): 86 | query_path = os.path.join('checkpoints', args.dir_name, f'{i}_th_generator/corpus/top_k_10_rs_10.json') 87 | removed_passage_ids = set(removed_passage_ids_list[i]) 88 | queries_list = json.load(open(query_path, 'r', encoding='utf-8')) 89 | for e in queries_list: 90 | passage_id = e['passage_id'] 91 | generated_queries = e['generated_queries'] 92 | if passage_id in removed_passage_ids or passage_id not in union_set: 93 | final_queries_dict[passage_id] = final_queries_dict.get(passage_id, []) + generated_queries 94 | merge_output_data = [] 95 | for k, v in final_queries_dict.items(): 96 | merge_output_data.append({'passage_id': k, 'generated_queries':v}) 97 | sorted_merge_output_data = sorted(merge_output_data, key=lambda x: int(x['passage_id'])) 98 | with open(os.path.join('checkpoints', args.dir_name, f'top_k_10_rs_10.json'), 'w', encoding='utf-8') as fw: 99 | json.dump(sorted_merge_output_data, fw, indent=2) 100 | 101 | 102 | if __name__ == "__main__": 103 | main() -------------------------------------------------------------------------------- /CAPSTONE/preprocess/merge_query.py: -------------------------------------------------------------------------------- 1 | import json 2 | query_filename='./docTTTTTquery_full/doc2query.tsv' 3 | 4 | # reform 5 | min_q=80 6 | with open(query_filename, 'r', encoding='utf-8') as fr, open('doc2query_merge.tsv', 'w', encoding='utf-8') as fw: 7 | for psg_id, line in enumerate(fr): 8 | example = line.strip().split('\t') 9 | example = list(set([e.strip() for e in example])) 10 | min_q=min(min_q, len(example)) 11 | output_data = '\t'.join([str(psg_id)]+example) 12 | #output_data = {'passage_id': str(psg_id), 'generated_queries': example} 13 | fw.write(output_data+'\n') 14 | print(min_q) 15 | 16 | -------------------------------------------------------------------------------- /CAPSTONE/preprocess/preprocess_msmarco.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script aims to convert the data format into the 3 | """ 4 | from tqdm import 
tqdm 5 | import json 6 | import csv 7 | import datasets 8 | from dataclasses import dataclass 9 | from argparse import ArgumentParser 10 | import os 11 | import random 12 | from tqdm import tqdm 13 | from datetime import datetime 14 | from collections import OrderedDict 15 | import logging 16 | # logger = logging.getLogger("__main__") 17 | logger = logging.getLogger(__name__) 18 | @dataclass 19 | class MSMARCOPreProcessor: 20 | data_dir: str 21 | split: str 22 | 23 | columns = ['text_id', 'title', 'text'] 24 | title_field = 'title' 25 | text_field = 'text' 26 | 27 | def __post_init__(self): 28 | assert self.split in ['train', 'dev'] 29 | self.query_file = os.path.join(self.data_dir, f'{self.split}.query.txt') 30 | self.queries = self.read_queries(self.query_file) 31 | 32 | 33 | self.collection_file = os.path.join(self.data_dir, f'corpus.tsv') 34 | print(f'Loading passages from: {self.collection_file}') 35 | self.collection = datasets.load_dataset( 36 | 'csv', 37 | data_files=self.collection_file, 38 | column_names=self.columns, 39 | delimiter='\t', 40 | )['train'] 41 | print(f'the copus has {len(self.collection)} passages.') 42 | 43 | self.relevance_file = os.path.join(self.data_dir, f'qrels.{self.split}.tsv') 44 | self.query_2_pos_doc_dict = self.read_positive(self.relevance_file) 45 | if self.split =='train': 46 | self.negative_file = os.path.join(self.data_dir, f'train.negatives.tsv') 47 | self.query_2_neg_doc_dict = self.read_negative(self.negative_file) 48 | else: 49 | self.negative_file = None 50 | self.query_2_neg_doc_dict = None 51 | 52 | @staticmethod 53 | def read_negative(negative_file): 54 | print(f'Load the negative docs for queries from {negative_file}.') 55 | query_2_neg_doc_dict = {} 56 | with open(negative_file, 'r', encoding='utf8') as fr: 57 | for line in fr: 58 | query_id, negative_doc_id_list = line.strip().split('\t') 59 | negative_doc_id_list = negative_doc_id_list.split(',') 60 | random.shuffle(negative_doc_id_list) 61 | query_2_neg_doc_dict[query_id] = negative_doc_id_list 62 | print(f'{len(query_2_neg_doc_dict)} queries have negative docs.') 63 | return query_2_neg_doc_dict 64 | 65 | @staticmethod 66 | def read_queries(queries): 67 | print(f'Load the query from {queries}.') 68 | qmap = OrderedDict() 69 | with open(queries, 'r', encoding='utf8') as f: 70 | for l in f: 71 | qid, qry = l.strip().split('\t') 72 | qmap[qid] = qry 73 | print(f'Train set has {len(qmap)} queries.') 74 | return qmap 75 | 76 | 77 | @staticmethod 78 | def read_positive(relevance_file): 79 | print(f'Load the positive docs for queries from {relevance_file}.') 80 | query_2_pos_doc_dict = {} 81 | with open(relevance_file, 'r', encoding='utf8') as f: 82 | tsvreader = csv.reader(f, delimiter="\t") 83 | if 'train' in relevance_file: 84 | print('train') 85 | for [query_id, _, doc_id, rel] in tsvreader: 86 | if query_id in query_2_pos_doc_dict: 87 | assert rel == "1" 88 | query_2_pos_doc_dict[query_id].append(doc_id) 89 | else: 90 | query_2_pos_doc_dict[query_id] = [doc_id] 91 | else: # dev 92 | print('dev') 93 | for [query_id, doc_id] in tsvreader: 94 | if query_id in query_2_pos_doc_dict: 95 | query_2_pos_doc_dict[query_id].append(doc_id) 96 | else: 97 | query_2_pos_doc_dict[query_id] = [doc_id] 98 | print(f'{len(query_2_pos_doc_dict)} queries have positive docs.') 99 | return query_2_pos_doc_dict 100 | 101 | 102 | def reform_data(self, n_sample, output_path): 103 | output_list = [] 104 | for query_id in self.queries: 105 | positive_doc_id_list = self.query_2_pos_doc_dict[query_id] 106 | if 
self.query_2_neg_doc_dict is not None and query_id in self.query_2_neg_doc_dict: 107 | negative_doc_id_list = self.query_2_neg_doc_dict[query_id][:n_sample] 108 | else: 109 | negative_doc_id_list = [] 110 | 111 | q_str = self.queries[query_id] 112 | q_answer = 'nil' 113 | positive_ctxs = [] 114 | negative_ctxs = [] 115 | for passage_id in negative_doc_id_list: 116 | entry = self.collection[int(passage_id)] 117 | title = entry[self.title_field] 118 | title = "" if title is None else title 119 | body = entry[self.text_field] 120 | 121 | negative_ctxs.append( 122 | {'title': title, 'text': body, 'passage_id': passage_id, 'score': 'nil'} 123 | ) 124 | 125 | for passage_id in positive_doc_id_list: 126 | entry = self.collection[int(passage_id)] 127 | title = entry[self.title_field] 128 | title = "" if title is None else title 129 | body = entry[self.text_field] 130 | 131 | positive_ctxs.append( 132 | {'title': title, 'text': body, 'passage_id': passage_id, 'score': 'nil'} 133 | ) 134 | 135 | output_list.append( 136 | { 137 | "q_id": query_id, "question": q_str, "answers": q_answer, "positive_ctxs": positive_ctxs, 138 | "hard_negative_ctxs": negative_ctxs, "negative_ctxs": [] 139 | } 140 | ) 141 | 142 | 143 | with open(output_path, 'w', encoding='utf8') as f: 144 | json.dump(output_list, f, indent=2) 145 | 146 | 147 | def generate_qa_files(self, output_path): 148 | with open(output_path, 'w', encoding='utf8') as fw: 149 | for q_id, query in self.queries.items(): 150 | positive_doc_id_list = self.query_2_pos_doc_dict[q_id] 151 | fw.write(f"{query}\t{json.dumps(positive_doc_id_list)}\t{q_id}\n") 152 | 153 | def reform_corpus(self, output_path): 154 | # original format: psg_id, title, text 155 | # reformed format: psg_id, text, title (to consistent with nq and tq) 156 | with open(self.collection_file, 'r', encoding='utf8') as fr, open(output_path, 'w', encoding='utf8') as fw: 157 | for line in fr: 158 | l = line.strip().split('\t') 159 | assert len(l) == 3 160 | fw.write(f'{l[0]}\t{l[2]}\t{l[1]}\n') 161 | 162 | 163 | if __name__ == '__main__': 164 | random.seed(0) 165 | parser = ArgumentParser() 166 | parser.add_argument('--data_dir', type=str, default='./marco') 167 | parser.add_argument('--output_path', type=str, default='./reformed_marco') 168 | parser.add_argument('--n_sample', type=int, default=30, help='number of selected negative examples.') 169 | 170 | 171 | args = parser.parse_args() 172 | os.makedirs(args.output_path, exist_ok=True) 173 | 174 | processor = MSMARCOPreProcessor( 175 | data_dir=args.data_dir, 176 | split='train' 177 | ) 178 | processor.reform_data(args.n_sample, os.path.join(args.output_path, 'biencoder-marco-train.json')) 179 | processor.generate_qa_files(os.path.join(args.output_path, 'marco-train.qa.csv')) 180 | 181 | processor = MSMARCOPreProcessor( 182 | data_dir=args.data_dir, 183 | split='dev' 184 | ) 185 | processor.reform_corpus( os.path.join(args.output_path, 'corpus.tsv')) 186 | processor.reform_data(args.n_sample, os.path.join(args.output_path, 'biencoder-marco-dev.json')) 187 | processor.generate_qa_files(os.path.join(args.output_path, 'marco-dev.qa.csv')) -------------------------------------------------------------------------------- /CAPSTONE/preprocess/preprocess_trec.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import json 3 | import csv 4 | import datasets 5 | from dataclasses import dataclass 6 | from argparse import ArgumentParser 7 | import os 8 | import random 9 | from tqdm import 
tqdm 10 | from datetime import datetime 11 | from collections import OrderedDict 12 | import logging 13 | 14 | def read_queries(queries): 15 | print(f'Load the query from {queries}.') 16 | qmap = OrderedDict() 17 | with open(queries, 'r', encoding='utf8') as f: 18 | for l in f: 19 | qid, qry = l.strip().split('\t') 20 | qmap[qid] = qry 21 | print(f'Train set has {len(qmap)} queries.') 22 | return qmap 23 | 24 | 25 | def read_positive(relevance_file): 26 | print(f'Load the positive docs for queries from {relevance_file}.') 27 | query_2_pos_doc_dict = {} 28 | with open(relevance_file, 'r', encoding='utf8') as fr: 29 | for line in fr: 30 | query_id, _, doc_id, _ = line.split() 31 | if query_id in query_2_pos_doc_dict: 32 | query_2_pos_doc_dict[query_id].append(doc_id) 33 | else: 34 | query_2_pos_doc_dict[query_id] = [doc_id] 35 | print(f'{len(query_2_pos_doc_dict)} queries have positive docs.') 36 | return query_2_pos_doc_dict 37 | 38 | def generate_qa_files(output_path, queries, query_2_pos_doc_dict): 39 | with open(output_path, 'w', encoding='utf8') as fw: 40 | for q_id, query in queries.items(): 41 | if q_id not in query_2_pos_doc_dict: 42 | continue 43 | positive_doc_id_list = query_2_pos_doc_dict[q_id] 44 | fw.write(f"{query}\t{json.dumps(positive_doc_id_list)}\t{q_id}\n") 45 | 46 | if __name__ == '__main__': 47 | random.seed(0) 48 | parser = ArgumentParser() 49 | args = parser.parse_args() 50 | 51 | # trec_19 52 | data_dir = './trec_19' 53 | queries = read_queries(os.path.join(data_dir, 'msmarco-test2019-queries.tsv')) 54 | query_2_pos_doc_dict = read_positive(os.path.join(data_dir, '2019qrels-pass.txt')) 55 | generate_qa_files(os.path.join(data_dir, 'test2019.qa.csv'), queries, query_2_pos_doc_dict) 56 | 57 | # trec_20 58 | data_dir = './trec_20' 59 | queries = read_queries(os.path.join(data_dir, 'msmarco-test2020-queries.tsv')) 60 | query_2_pos_doc_dict = read_positive(os.path.join(data_dir, '2020qrels-pass.txt')) 61 | generate_qa_files(os.path.join(data_dir, 'test2020.qa.csv'), queries, query_2_pos_doc_dict) -------------------------------------------------------------------------------- /CAPSTONE/run_de_model_expand_corpus_cocondenser.sh: -------------------------------------------------------------------------------- 1 | echo "start de model warmup" 2 | 3 | MAX_SEQ_LEN=144 4 | MAX_Q_LEN=32 5 | DATA_DIR=./reformed_marco 6 | 7 | 8 | QUERY_PATH=./docTTTTTquery_full/doc2query_merge.tsv # each doc is paired with 80 generated quries. 
9 | CORPUS_NAME=corpus 10 | CKPT_NUM=20000 11 | MODEL_TYPE=Luyu/co-condenser-marco 12 | MODEL=cocondenser 13 | 14 | number_neg=31 15 | total_part=3 16 | select_generated_query=gradual 17 | delimiter=sep 18 | gold_query_prob=0 19 | 20 | EXP_NAME=run_de_marco_${MAX_SEQ_LEN}_${MAX_Q_LEN}_t5_append_${delimiter}_gqp${gold_query_prob}_${MODEL}_neg${number_neg}_${select_generated_query}_${total_part}_part 21 | OUT_DIR=checkpoints_full_query/$EXP_NAME 22 | TB_DIR=tensorboard_log_full_query/$EXP_NAME # tensorboard log path 23 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9589 \ 24 | ./models/run_de_model_ernie.py \ 25 | --model_type $MODEL_TYPE \ 26 | --train_file $DATA_DIR/biencoder-marco-train.json \ 27 | --validation_file $DATA_DIR/biencoder-marco-dev.json \ 28 | --max_seq_length $MAX_SEQ_LEN --max_query_length $MAX_Q_LEN \ 29 | --per_device_train_batch_size 8 --gradient_accumulation_steps 1 \ 30 | --number_neg $number_neg --learning_rate 5e-6 \ 31 | --output_dir $OUT_DIR \ 32 | --warmup_steps 2000 --logging_steps 100 --save_steps 1000 --max_steps 20000 \ 33 | --log_dir $TB_DIR \ 34 | --shuffle_positives \ 35 | --fp16 --do_train \ 36 | --expand_corpus --query_path $QUERY_PATH \ 37 | --top_k_query 1 --append --delimiter $delimiter \ 38 | --gold_query_prob $gold_query_prob --select_generated_query $select_generated_query --total_part $total_part 39 | 40 | 41 | # 2. Evaluate retriever and generate hard topk 42 | echo "start de model inference" 43 | for k in 10 44 | do 45 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9536 \ 46 | ./models/run_de_model_ernie.py \ 47 | --model_type $MODEL_TYPE \ 48 | --model_name_or_path $OUT_DIR/checkpoint-$CKPT_NUM \ 49 | --output_dir ${OUT_DIR}/${CKPT_NUM}_${CORPUS_NAME}_top_k_query_${k} \ 50 | --train_file $DATA_DIR/biencoder-marco-train.json \ 51 | --validation_file $DATA_DIR/biencoder-marco-dev.json \ 52 | --train_qa_path $DATA_DIR/marco-train.qa.csv \ 53 | --dev_qa_path $DATA_DIR/marco-dev.qa.csv \ 54 | --max_seq_length $MAX_SEQ_LEN --max_query_length $MAX_Q_LEN --per_device_eval_batch_size 1024 \ 55 | --passage_path $DATA_DIR/${CORPUS_NAME}.tsv \ 56 | --fp16 --do_predict \ 57 | --expand_corpus --query_path $QUERY_PATH \ 58 | --top_k_query $k --append --delimiter $delimiter 59 | 60 | # TREC 19 and 20 61 | for year in 19 20 62 | do 63 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9535 \ 64 | ./models/run_de_model_ernie.py \ 65 | --model_type $MODEL_TYPE \ 66 | --model_name_or_path $OUT_DIR/checkpoint-$CKPT_NUM \ 67 | --output_dir ${OUT_DIR}/${CKPT_NUM}_${CORPUS_NAME}_top_k_query_${k} \ 68 | --max_seq_length $MAX_SEQ_LEN --max_query_length $MAX_Q_LEN --per_device_eval_batch_size 512 \ 69 | --passage_path $DATA_DIR/${CORPUS_NAME}.tsv \ 70 | --evaluate_trec --prefix trec${year}_test \ 71 | --test_qa_path ./trec_${year}/test20${year}.qa.csv \ 72 | --query_positive_id_path ./trec_${year}/20${year}qrels-pass.txt \ 73 | --fp16 --do_predict \ 74 | --expand_corpus --query_path $QUERY_PATH \ 75 | --top_k_query $k --append --delimiter $delimiter \ 76 | --load_cache 77 | done 78 | done 79 | -------------------------------------------------------------------------------- /CAPSTONE/run_de_model_expand_corpus_cocondenser_step2.sh: -------------------------------------------------------------------------------- 1 | echo "start de model warmup" 2 | 3 | MAX_SEQ_LEN=144 4 | MAX_Q_LEN=32 5 | DATA_DIR=./reformed_marco 6 | QUERY_PATH=./docTTTTTquery_full/doc2query_merge.tsv # each doc is paired with 80 
generated queries. 7 | CORPUS_NAME=corpus 8 | CKPT_NUM=25000 9 | 10 | MODEL_TYPE=Luyu/co-condenser-marco 11 | MODEL=cocondenser 12 | 13 | number_neg=31 14 | total_part=4 15 | select_generated_query=gradual 16 | delimiter=sep 17 | gold_query_prob=0 18 | 19 | # path to the mined hard negatives 20 | train_file=checkpoints_full_query/run_de_marco_144_32_t5_append_sep_gqp0_cocondenser_neg31_gradual_3_part/20000_corpus_top_k_query_10/train_20000_0_query.json 21 | validation_file=checkpoints_full_query/run_de_marco_144_32_t5_append_sep_gqp0_cocondenser_neg31_gradual_3_part/20000_corpus_top_k_query_10/dev_20000_0_query.json 22 | 23 | EXP_NAME=run_de_marco_${MAX_SEQ_LEN}_${MAX_Q_LEN}_t5_append_${delimiter}_gqp${gold_query_prob}_${MODEL}_neg${number_neg}_${select_generated_query}_${total_part}_part_step2_selfdata_shuffle_positives_F 24 | OUT_DIR=checkpoints_full_query/$EXP_NAME 25 | TB_DIR=tensorboard_log_full_query/$EXP_NAME # tensorboard log path 26 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9589 \ 27 | ./models/run_de_model_ernie.py \ 28 | --model_type $MODEL_TYPE \ 29 | --train_file $train_file \ 30 | --validation_file $validation_file \ 31 | --max_seq_length $MAX_SEQ_LEN --max_query_length $MAX_Q_LEN \ 32 | --per_device_train_batch_size=8 --gradient_accumulation_steps=1 \ 33 | --number_neg $number_neg --learning_rate 5e-6 \ 34 | --output_dir $OUT_DIR \ 35 | --warmup_steps 2500 --logging_steps 100 --save_steps 1000 --max_steps 25000 \ 36 | --log_dir $TB_DIR \ 37 | --passage_path=$DATA_DIR/${CORPUS_NAME}.tsv \ 38 | --fp16 --do_train \ 39 | --expand_corpus --query_path $QUERY_PATH \ 40 | --top_k_query 1 --append --delimiter $delimiter \ 41 | --gold_query_prob $gold_query_prob --select_generated_query $select_generated_query --total_part $total_part 42 | 43 | # 2.
Evaluate retriever and generate hard topk 44 | echo "start de model inference" 45 | for k in 10 46 | do 47 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9536 \ 48 | ./models/run_de_model_ernie.py \ 49 | --model_type $MODEL_TYPE \ 50 | --model_name_or_path $OUT_DIR/checkpoint-$CKPT_NUM \ 51 | --output_dir ${OUT_DIR}/${CKPT_NUM}_${CORPUS_NAME}_top_k_query_${k} \ 52 | --train_file $DATA_DIR/biencoder-marco-train.json \ 53 | --validation_file $DATA_DIR/biencoder-marco-dev.json \ 54 | --train_qa_path $DATA_DIR/marco-train.qa.csv \ 55 | --dev_qa_path $DATA_DIR/marco-dev.qa.csv \ 56 | --max_seq_length $MAX_SEQ_LEN --max_query_length $MAX_Q_LEN --per_device_eval_batch_size 1024 \ 57 | --passage_path $DATA_DIR/${CORPUS_NAME}.tsv \ 58 | --fp16 --do_predict \ 59 | --expand_corpus --query_path $QUERY_PATH \ 60 | --top_k_query $k --append --delimiter $delimiter 61 | 62 | # TREC 19 and 20 63 | for year in 19 20 64 | do 65 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9535 \ 66 | ./models/run_de_model_ernie.py \ 67 | --model_type $MODEL_TYPE \ 68 | --model_name_or_path $OUT_DIR/checkpoint-$CKPT_NUM \ 69 | --output_dir ${OUT_DIR}/${CKPT_NUM}_${CORPUS_NAME}_top_k_query_${k} \ 70 | --max_seq_length $MAX_SEQ_LEN --max_query_length $MAX_Q_LEN --per_device_eval_batch_size 512 \ 71 | --passage_path $DATA_DIR/${CORPUS_NAME}.tsv \ 72 | --evaluate_trec --prefix trec${year}_test \ 73 | --test_qa_path ./trec_${year}/test20${year}.qa.csv \ 74 | --query_positive_id_path ./trec_${year}/20${year}qrels-pass.txt \ 75 | --fp16 --do_predict \ 76 | --expand_corpus --query_path $QUERY_PATH \ 77 | --top_k_query $k --append --delimiter $delimiter \ 78 | --load_cache 79 | done 80 | done 81 | 82 | -------------------------------------------------------------------------------- /CAPSTONE/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/CAPSTONE/utils/__init__.py -------------------------------------------------------------------------------- /CAPSTONE/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import csv 5 | import json 6 | import datasets 7 | from dataclasses import dataclass 8 | from collections import OrderedDict 9 | csv.field_size_limit(sys.maxsize) 10 | logger = logging.getLogger("__main__") 11 | # logger = logging.getLogger(__name__) 12 | 13 | def load_passage(passage_path:str): 14 | """ 15 | for nq, tq, and msmarco 16 | """ 17 | if not os.path.exists(passage_path): 18 | logger.info(f'{passage_path} does not exist') 19 | return 20 | logger.info(f'Loading passages from: {passage_path}') 21 | passages = OrderedDict() 22 | with open(passage_path, 'r', encoding='utf8') as fin: 23 | reader = csv.reader(fin, delimiter='\t') 24 | for row in reader: 25 | if row[0] != 'id': 26 | try: 27 | if len(row) == 3: 28 | # psg_id, text, title 29 | passages[row[0]] = (row[0], row[1].strip(), row[2].strip()) 30 | else: 31 | # psg_id, text, title, original psg_id in the psgs_w100.tsv file 32 | passages[row[3]] = (row[3], row[1].strip(), row[2].strip()) 33 | except Exception: 34 | logger.warning(f'The following input line has not been correctly loaded: {row}') 35 | logger.info(f'{passage_path} has {len(passages)} passages.') 36 | return passages 37 | -------------------------------------------------------------------------------- 
/CAPSTONE/utils/evaluate_trec.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is used to evaluate the trec-19, 20 test set. 3 | """ 4 | import pytrec_eval 5 | 6 | def load_qrel_data(query_positive_id_path): 7 | dev_query_positive_id = {} 8 | with open(query_positive_id_path, 'r', encoding='utf8') as f: 9 | for line in f: 10 | query_id, _, doc_id, rel = line.split() 11 | query_id = int(query_id) 12 | doc_id = int(doc_id) 13 | if query_id not in dev_query_positive_id: 14 | dev_query_positive_id[query_id] = {} 15 | dev_query_positive_id[query_id][doc_id] = int(rel) 16 | return dev_query_positive_id 17 | 18 | def convert_to_string_id(result_dict): 19 | string_id_dict = {} 20 | 21 | # format [string, dict[string, val]] 22 | for k, v in result_dict.items(): 23 | _temp_v = {} 24 | for inner_k, inner_v in v.items(): 25 | _temp_v[str(inner_k)] = inner_v 26 | 27 | string_id_dict[str(k)] = _temp_v 28 | 29 | return string_id_dict 30 | def EvalDevQuery(query_embedding2id, passage_embedding2id, dev_query_positive_id, I_nearest_neighbor,topN): 31 | prediction = {} #[qid][docid] = docscore, here we use -rank as score, so the higher the rank (1 > 2), the higher the score (-1 > -2) 32 | 33 | total = 0 34 | labeled = 0 35 | Atotal = 0 36 | Alabeled = 0 37 | qids_to_ranked_candidate_passages = {} 38 | for query_idx in range(len(I_nearest_neighbor)): 39 | seen_pid = set() 40 | query_id = query_embedding2id[query_idx] 41 | prediction[query_id] = {} 42 | 43 | top_ann_pid = I_nearest_neighbor[query_idx].copy() 44 | selected_ann_idx = top_ann_pid[:topN] 45 | rank = 0 46 | 47 | if query_id in qids_to_ranked_candidate_passages: 48 | pass 49 | else: 50 | # By default, all PIDs in the list of 1000 are 0. Only override those that are given 51 | tmp = [0] * 1000 52 | qids_to_ranked_candidate_passages[query_id] = tmp 53 | 54 | for idx in selected_ann_idx: 55 | pred_pid = passage_embedding2id[idx] 56 | 57 | if not pred_pid in seen_pid: 58 | # this check handles multiple vector per document 59 | qids_to_ranked_candidate_passages[query_id][rank]=pred_pid 60 | Atotal += 1 61 | if pred_pid not in dev_query_positive_id[query_id]: 62 | Alabeled += 1 63 | if rank < 10: 64 | total += 1 65 | if pred_pid not in dev_query_positive_id[query_id]: 66 | labeled += 1 67 | rank += 1 68 | prediction[query_id][pred_pid] = -rank 69 | seen_pid.add(pred_pid) 70 | 71 | # use out of the box evaluation script 72 | evaluator = pytrec_eval.RelevanceEvaluator( 73 | convert_to_string_id(dev_query_positive_id), {'map_cut', 'ndcg_cut', 'recip_rank','recall'}) 74 | 75 | eval_query_cnt = 0 76 | result = evaluator.evaluate(convert_to_string_id(prediction)) 77 | 78 | qids_to_relevant_passageids = {} 79 | for qid in dev_query_positive_id: 80 | qid = int(qid) 81 | if qid in qids_to_relevant_passageids: 82 | pass 83 | else: 84 | qids_to_relevant_passageids[qid] = [] 85 | for pid in dev_query_positive_id[qid]: 86 | if pid>0: 87 | qids_to_relevant_passageids[qid].append(pid) 88 | 89 | # ms_mrr = compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages) 90 | 91 | ndcg = 0 92 | Map = 0 93 | mrr = 0 94 | recall = 0 95 | recall_1000 = 0 96 | 97 | for k in result.keys(): 98 | eval_query_cnt += 1 99 | ndcg += result[k]["ndcg_cut_10"] 100 | Map += result[k]["map_cut_10"] 101 | mrr += result[k]["recip_rank"] 102 | recall += result[k]["recall_"+str(topN)] 103 | 104 | final_ndcg = ndcg / eval_query_cnt 105 | final_Map = Map / eval_query_cnt 106 | final_mrr = mrr / eval_query_cnt 107 | 
final_recall = recall / eval_query_cnt 108 | hole_rate = labeled/total 109 | Ahole_rate = Alabeled/Atotal 110 | 111 | return final_ndcg, eval_query_cnt, final_Map, final_mrr, final_recall, hole_rate, Ahole_rate, result, prediction -------------------------------------------------------------------------------- /CAPSTONE/utils/lamb.py: -------------------------------------------------------------------------------- 1 | """Lamb optimizer.""" 2 | 3 | import collections 4 | import math 5 | 6 | import torch 7 | from torch.utils.tensorboard import SummaryWriter 8 | from torch.optim import Optimizer 9 | 10 | 11 | def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int): 12 | """Log a histogram of trust ratio scalars in across layers.""" 13 | results = collections.defaultdict(list) 14 | for group in optimizer.param_groups: 15 | for p in group['params']: 16 | state = optimizer.state[p] 17 | for i in ('weight_norm', 'adam_norm', 'trust_ratio'): 18 | if i in state: 19 | results[i].append(state[i]) 20 | 21 | for k, v in results.items(): 22 | event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count) 23 | 24 | class Lamb(Optimizer): 25 | r"""Implements Lamb algorithm. 26 | 27 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 28 | 29 | Arguments: 30 | params (iterable): iterable of parameters to optimize or dicts defining 31 | parameter groups 32 | lr (float, optional): learning rate (default: 1e-3) 33 | betas (Tuple[float, float], optional): coefficients used for computing 34 | running averages of gradient and its square (default: (0.9, 0.999)) 35 | eps (float, optional): term added to the denominator to improve 36 | numerical stability (default: 1e-8) 37 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 38 | adam (bool, optional): always use trust ratio = 1, which turns this into 39 | Adam. Useful for comparison purposes. 40 | 41 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: 42 | https://arxiv.org/abs/1904.00962 43 | """ 44 | 45 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, 46 | weight_decay=0, adam=False): 47 | if not 0.0 <= lr: 48 | raise ValueError("Invalid learning rate: {}".format(lr)) 49 | if not 0.0 <= eps: 50 | raise ValueError("Invalid epsilon value: {}".format(eps)) 51 | if not 0.0 <= betas[0] < 1.0: 52 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 53 | if not 0.0 <= betas[1] < 1.0: 54 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 55 | defaults = dict(lr=lr, betas=betas, eps=eps, 56 | weight_decay=weight_decay) 57 | self.adam = adam 58 | super(Lamb, self).__init__(params, defaults) 59 | 60 | def step(self, closure=None): 61 | """Performs a single optimization step. 62 | 63 | Arguments: 64 | closure (callable, optional): A closure that reevaluates the model 65 | and returns the loss. 
66 | """ 67 | loss = None 68 | if closure is not None: 69 | loss = closure() 70 | 71 | for group in self.param_groups: 72 | for p in group['params']: 73 | if p.grad is None: 74 | continue 75 | grad = p.grad.data 76 | if grad.is_sparse: 77 | raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') 78 | 79 | state = self.state[p] 80 | 81 | # State initialization 82 | if len(state) == 0: 83 | state['step'] = 0 84 | # Exponential moving average of gradient values 85 | state['exp_avg'] = torch.zeros_like(p.data) 86 | # Exponential moving average of squared gradient values 87 | state['exp_avg_sq'] = torch.zeros_like(p.data) 88 | 89 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 90 | beta1, beta2 = group['betas'] 91 | 92 | state['step'] += 1 93 | 94 | # Decay the first and second moment running average coefficient 95 | # m_t 96 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 97 | # v_t 98 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 99 | 100 | # Paper v3 does not use debiasing. 101 | # Apply bias to lr to avoid broadcast. 102 | step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1 103 | 104 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 105 | 106 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) 107 | if group['weight_decay'] != 0: 108 | adam_step.add_(group['weight_decay'], p.data) 109 | 110 | adam_norm = adam_step.pow(2).sum().sqrt() 111 | if weight_norm == 0 or adam_norm == 0: 112 | trust_ratio = 1 113 | else: 114 | trust_ratio = weight_norm / adam_norm 115 | state['weight_norm'] = weight_norm 116 | state['adam_norm'] = adam_norm 117 | state['trust_ratio'] = trust_ratio 118 | if self.adam: 119 | trust_ratio = 1 120 | 121 | p.data.add_(-step_size * trust_ratio, adam_step) 122 | 123 | return loss 124 | -------------------------------------------------------------------------------- /CAPSTONE/utils/metric_utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import torch 3 | import numpy as np 4 | from collections import Counter 5 | from dataclasses import dataclass 6 | from datasets import load_metric 7 | from rouge_score import rouge_scorer 8 | 9 | import nltk 10 | import random 11 | from typing import Optional, List 12 | 13 | class compute_metric: 14 | def __init__(self): 15 | self.rouge_metric = load_metric('rouge') 16 | self.rouge_scorer = rouge_scorer.RougeScorer(rouge_types = ["rougeL"], use_stemmer=True) 17 | self.bleu_scorer = load_metric('bleu') 18 | self.meteor_scorer = load_metric('meteor') 19 | 20 | def postprocess_text_bleu(self, preds, labels): 21 | preds = [nltk.word_tokenize(pred) for pred in preds] 22 | labels = [nltk.word_tokenize(label) for label in labels] 23 | 24 | return preds, labels 25 | 26 | def __call__(self, preds, labels): 27 | # preds, labels = eval_preds 28 | result = {} 29 | # Some simple post-processing 30 | preds_bleu, labels_bleu = self.postprocess_text_bleu(preds, labels) 31 | result['rougeL'] = self.rouge_metric.compute(predictions=preds, references=labels, use_stemmer=True)['rougeL'].mid.fmeasure * 100 32 | 33 | result['bleu_4'] = self.bleu_scorer.compute(predictions=preds_bleu, references=[[l] for l in labels_bleu], max_order=4)['bleu'] * 100 34 | 35 | result['meteor'] = self.meteor_scorer.compute(predictions=preds_bleu, references=labels_bleu)['meteor'] * 100 36 | result = {k: round(v, 4) for k, v in result.items()} 37 | return result 38 | 39 | def compute_rouge(self, preds, labels): 
40 | result = {} 41 | # Some simple post-processing 42 | preds_bleu, labels_bleu = self.postprocess_text_bleu(preds, labels) 43 | result['rougeL'] = self.rouge_metric.compute(predictions=preds, references=labels, use_stemmer=True)['rougeL'].mid.fmeasure * 100 44 | result = {k: round(v, 4) for k, v in result.items()} 45 | return result 46 | 47 | def get_candidates(self, targets: List[str], preds: List[str], num_cand:int, num_cand_picked:int, strategy:str, gold_as_positive=False): 48 | """ 49 | args: 50 | targets: list of targets for each sample 51 | preds: list of predictions, length == len(targets) * num_cand 52 | num_cand: number of candidates 53 | num_cand_picked: number of returned indices per sample 54 | strategy: how to select num_cand_picked negatives from the preds 55 | gold_as_positive: use the gold or use the one of preds with the highest reward as the positive. 56 | returns: 57 | indices: Torch tensor, (B * (C-1), ), since the positive candidate is not generated from the generator 58 | candiates: candidates, with the length of len(targets) * num_cand_picked 59 | NOTE: We should always keep the positive sequences in the first candidate for each sample 60 | """ 61 | preds_bleu, targets_bleu = self.postprocess_text_bleu(preds, targets) 62 | # preds_meteor = [' '.join(pred) for pred in preds_bleu] 63 | # targets_meteor = [' '.join(label) for label in targets_bleu] 64 | # print(targets_meteor) 65 | 66 | indices = [] 67 | candidates = [] 68 | rewards = [] 69 | for i,t in enumerate(targets): 70 | scores = [] 71 | ps = preds[i * num_cand: (i+1)*num_cand] 72 | ps_bleu = preds_bleu[i * num_cand: (i+1)*num_cand] 73 | for j,p in enumerate(ps): 74 | if len(ps_bleu[j]) == 0: 75 | rouge_score = 0 76 | bleu_score = 0 77 | meteor_score = 0 78 | else: 79 | # rouge_score = self.rouge_metric.compute(predictions=[p], references=[t], use_stemmer=True)['rougeL'].mid.fmeasure 80 | bleu_score = self.bleu_scorer.compute(predictions = [ps_bleu[j]], references = [[targets_bleu[i]]], max_order = 4)['bleu'] 81 | meteor_score = self.meteor_scorer.compute(predictions = [ps_bleu[j]], references = [targets_bleu[i]])['meteor'] 82 | # reward = rouge_score + bleu_score + meteor_score 83 | reward = bleu_score + meteor_score 84 | scores.append((j + i * num_cand, reward , p)) 85 | 86 | scores = sorted(scores, key = lambda x: x[1], reverse=True) 87 | 88 | if gold_as_positive: 89 | # rouge_score = self.rouge_metric.compute(predictions = [t], references= [t], use_stemmer=True)['rougeL'].mid.fmeasure 90 | bleu_score = self.bleu_scorer.compute(predictions = [targets_bleu[i]], references = [[targets_bleu[i]]], max_order = 4)['bleu'] 91 | meteor_score = self.meteor_scorer.compute(predictions = [targets_bleu[i]], references = [targets_bleu[i]])['meteor'] 92 | # reward = rouge_score + bleu_score + meteor_score 93 | reward = bleu_score + meteor_score 94 | idx_this = [] 95 | cand_this = [t] 96 | rewards_this = [reward] 97 | max_num = num_cand_picked - 1 98 | else: 99 | idx_this = [scores[0][0]] # the first as pos 100 | cand_this = [scores[0][2]] 101 | rewards_this = [scores[0][1]] 102 | scores = scores[1:] 103 | max_num = num_cand_picked - 1 104 | 105 | if strategy == 'random': 106 | s_for_pick = random.sample(scores, max_num) 107 | idx_this += [s[0] for s in s_for_pick] 108 | cand_this += [s[2] for s in s_for_pick] 109 | rewards_this += [s[1] for s in s_for_pick] 110 | else: 111 | if strategy == 'top': 112 | idx_this += [s[0] for s in scores[:max_num]] 113 | cand_this += [s[2] for s in scores[:max_num]] 114 | rewards_this += 
[s[1] for s in scores[:max_num]] 115 | elif strategy == 'bottom': 116 | idx_this += [s[0] for s in scores[-max_num:]] 117 | cand_this += [s[2] for s in scores[-max_num:]] 118 | rewards_this += [s[1] for s in scores[-max_num:]] 119 | elif strategy == 'top-bottom': 120 | n_top = max_num // 2 121 | n_bottom = max_num - n_top 122 | idx_this += [s[0] for s in scores[:n_top]] 123 | cand_this += [s[2] for s in scores[:n_top]] 124 | idx_this += [s[0] for s in scores[-n_bottom:]] 125 | cand_this += [s[2] for s in scores[-n_bottom:]] 126 | rewards_this += [s[1] for s in scores[:n_top]] 127 | rewards_this += [s[1] for s in scores[-n_bottom:]] 128 | 129 | indices += idx_this 130 | candidates += cand_this 131 | rewards.append(rewards_this) 132 | return candidates, torch.FloatTensor(rewards) 133 | # return torch.LongTensor(indices), candidates, torch.FloatTensor(rewards) 134 | 135 | def compute_metric_score(self, target:str, preds:List[str], metric:str): 136 | score_list = [] 137 | preds_bleu, targets_bleu = self.postprocess_text_bleu(preds, [target]) 138 | 139 | for i, pred in enumerate(preds_bleu): 140 | if metric=='rouge-l': 141 | score = self.rouge_scorer.score(target=target, prediction=preds[i])['rougeL'].fmeasure 142 | # score = self.rouge_metric.compute(predictions=[preds[i]], references=[target], use_stemmer=True)['rougeL'].mid.fmeasure 143 | elif metric =='bleu': 144 | score = self.bleu_scorer.compute(predictions = [preds_bleu[i]], references = [[targets_bleu[0]]], max_order = 4)['bleu'] 145 | elif metric=='meteor': 146 | score = self.meteor_scorer.compute(predictions = [preds_bleu[i]], references = [targets_bleu[0]])['meteor'] 147 | else: 148 | raise ValueError() 149 | score_list.append(score) 150 | return score_list 151 | 152 | 153 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /HOW_TO_DOWNLOAD.md: -------------------------------------------------------------------------------- 1 | # How to Download Desired Resources 2 | 3 | Since most of the resources are large, we release the preprocessed data and trained ckpts in an Azure Blob container with public access for most of our open-source projects. 4 | Here we provide two methods to download the resources. 5 | 6 | 7 | ## You Must Know 8 | 9 | In almost every project, we provide the **basic URL** of the project in the `README.md` file. 10 | For example, for [SimANS](https://github.com/microsoft/SimXNS/tree/main/SimANS), the basic URL is `https://msranlciropen.blob.core.windows.net/simxns/SimANS/`. 11 | 12 | If you want to view the content of the blob, you can use [Microsoft's AzCopy CLI tool](https://learn.microsoft.com/en-us/azure/storage/common/storage-ref-azcopy): 13 | ```bash 14 | azcopy list https://msranlciropen.blob.core.windows.net/simxns/SimANS/ 15 | ``` 16 | We also provide this list in each project's `README.md` file.
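As an illustration of the basic-URL convention, appending a relative path to the basic URL yields a direct download link that any HTTP client can fetch. This is only a sketch: the relative path below is the one reused in the AzCopy example of Method 2, and fetching it directly works only if that path points to a single blob rather than a folder.
```bash
# Sketch: basic URL + relative path = direct download link (public container, no credentials).
# The path below is the checkpoint reused in the AzCopy example of Method 2.
BASE_URL=https://msranlciropen.blob.core.windows.net/simxns/SimANS
wget "${BASE_URL}/best_simans_ckpt/TQ/checkpoint-10000"
```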
17 | 18 | 19 | ## Method 1: Directly Download using the URLs 20 | 21 | You may choose to download only part of the resources by appending the relative path to the basic blob URL and downloading it directly. 22 | 23 | 24 | ## Method 2: Faster Transmission using AzCopy 25 | 26 | You can also use [the copy function of Microsoft's AzCopy CLI tool](https://learn.microsoft.com/en-us/azure/storage/common/storage-ref-azcopy-copy). 27 | Here is the command for copying an entire folder: 28 | ```bash 29 | azcopy copy --recursive "https://msranlciropen.blob.core.windows.net/simxns/SimANS/" . 30 | ``` 31 | You may also use the tool to copy only a single file: 32 | ```bash 33 | azcopy copy "https://msranlciropen.blob.core.windows.net/simxns/SimANS/best_simans_ckpt/TQ/checkpoint-10000" . 34 | ``` 35 | -------------------------------------------------------------------------------- /LEAD/README.md: -------------------------------------------------------------------------------- 1 | # LEAD 2 | 3 | The code for our paper [LEAD: Liberal Feature-based Distillation for Dense Retrieval](https://arxiv.org/abs/2212.05225). 4 | 5 | ![model](assets/model.jpg) 6 | 7 | ## Overview 8 | 9 | Our proposed method LEAD aligns the layer features of the student and the teacher, placing more emphasis on the informative layers through re-weighting. 10 | 11 | Here we show the main results on [MS MARCO](https://microsoft.github.io/msmarco/). Our method outperforms all the baselines. 12 | 13 | ![main_result](assets/main_result.jpg) 14 | 15 | 16 | ## Released Resources 17 | 18 | We release the preprocessed data and trained ckpts in [Azure Blob](https://msranlciropen.blob.core.windows.net/simxns/LEAD/). 19 | Here we also provide the file list under this URL: 20 |
21 | Click here to see the file list. 22 |
INFO: ckpt/24_to_12_msdoc.ckpt;  Content Length: 1.22 GiB
 23 | INFO: ckpt/24_to_12_mspas.ckpt;  Content Length: 1.22 GiB
 24 | INFO: ckpt/24_to_12_trec_doc_19.ckpt;  Content Length: 1.22 GiB
 25 | INFO: ckpt/24_to_12_trec_doc_20.ckpt;  Content Length: 1.22 GiB
 26 | INFO: ckpt/24_to_12_trec_pas_19.ckpt;  Content Length: 1.22 GiB
 27 | INFO: ckpt/24_to_12_trec_pas_20.ckpt;  Content Length: 1.22 GiB
 28 | INFO: dataset/mspas/biencoder-mspas-train-hard.json;  Content Length: 1.56 GiB
 29 | INFO: dataset/mspas/biencoder-mspas-train.json;  Content Length: 5.35 GiB
 30 | INFO: dataset/mspas/mspas-test.qa.csv;  Content Length: 319.98 KiB
 31 | INFO: dataset/mspas/psgs_w100.tsv;  Content Length: 3.06 GiB
 32 | INFO: dataset/mspas/trec2019-test.qa.csv;  Content Length: 99.72 KiB
 33 | INFO: dataset/mspas/trec2019-test.rating.csv;  Content Length: 46.67 KiB
 34 | INFO: dataset/mspas/trec2020-test.qa.csv;  Content Length: 122.86 KiB
 35 | INFO: dataset/mspas/trec2020-test.rating.csv;  Content Length: 57.48 KiB
 36 | INFO: dataset/msdoc/biencoder-msdoc-train-hard.json;  Content Length: 800.13 MiB
 37 | INFO: dataset/msdoc/biencoder-msdoc-train.json;  Content Length: 294.04 MiB
 38 | INFO: dataset/msdoc/msdoc-test.qa.csv;  Content Length: 232.16 KiB
 39 | INFO: dataset/msdoc/psgs_w100.tsv;  Content Length: 21.11 GiB
 40 | INFO: dataset/msdoc/trec2019-test.qa.csv;  Content Length: 170.55 KiB
 41 | INFO: dataset/msdoc/trec2019-test.rating.csv;  Content Length: 80.84 KiB
 42 | INFO: dataset/msdoc/trec2020-test.qa.csv;  Content Length: 96.26 KiB
 43 | INFO: dataset/msdoc/trec2020-test.rating.csv;  Content Length: 46.04 KiB
44 |
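For example, a single entry from the list above can be fetched by appending its relative path to the Azure Blob URL given above. This is a sketch only; the chosen checkpoint is arbitrary, and any listed path works the same way.
```bash
# Sketch: copy one listed file from the public LEAD container to the current directory.
azcopy copy "https://msranlciropen.blob.core.windows.net/simxns/LEAD/ckpt/24_to_12_mspas.ckpt" .
```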
45 | 46 | To download the files, please refer to [HOW_TO_DOWNLOAD](https://github.com/microsoft/SimXNS/tree/main/HOW_TO_DOWNLOAD.md). 47 | 48 | 49 | ## Environments Setting 50 | 51 | We implement our approach based on Pytorch and Huggingface Transformers. We list our command to prepare the experimental environment as follows: 52 | 53 | ``` 54 | yes | conda create -n lead 55 | conda init bash 56 | source ~/.bashrc 57 | conda activate lead 58 | conda install -y pytorch==1.11.0 cudatoolkit=11.5 faiss-gpu -c pytorch 59 | conda install -y pandas 60 | pip install transformers 61 | pip install tqdm 62 | pip install wandb 63 | pip install sklearn 64 | pip install pytrec-eval 65 | ``` 66 | 67 | ## Dataset Preprocess 68 | 69 | We conduct experiments on MS-PAS and MS-DOC datasets. You can download and preprocess data by using our code: 70 | 71 | ``` 72 | cd data_preprocess 73 | bash get_data.sh 74 | ``` 75 | 76 | ## Hard Negative Generation 77 | 78 | The negatives in the original dataset are mainly random or BM25 negatives. Before training, we first train a 12-layer DE and get top-100 hard negatives. 79 | 80 | ``` 81 | cd hard_negative_generation 82 | bash train_de.sh $DATASET $MASTER_PORT $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES 83 | bash retrieve_hard_negatives_academic.sh $DATASET $MASTER_PORT $MODEL_PATH $CKPT_NAME $MAX_DOC_LENGTH $MAX_QUERY_LENGTH 84 | ``` 85 | 86 | ## Single Model Training 87 | 88 | Before distillation, we train the teacher and student with the mined hard negatives as warming up. You can go to the warm_up directory and run the following command: 89 | 90 | ### Train a 6-layer Dual Encoder 91 | 92 | ``` 93 | bash train_6_layer_de.sh $DATASET $MASTER_PORT $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES 94 | ``` 95 | 96 | ### Train a 12-layer Dual Encoder 97 | 98 | ``` 99 | bash train_12_layer_de.sh $DATASET $MASTER_PORT $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES 100 | ``` 101 | 102 | ### Train a 12-layer ColBERT 103 | 104 | ``` 105 | bash train_12_layer_col.sh $DATASET $MASTER_PORT $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES 106 | ``` 107 | 108 | ### Train a 12-layer Cross Encoder 109 | 110 | ``` 111 | bash train_12_layer_ce.sh $DATASET $MASTER_PORT $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES 112 | ``` 113 | 114 | ### Train a 24-layer ColBERT 115 | 116 | ``` 117 | bash train_24_layer_col.sh $DATASET $MASTER_PORT $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES 118 | ``` 119 | 120 | ## LEAD distillation 121 | 122 | ### Distill from 12-layer DE to 6-layer DE 123 | 124 | ``` 125 | bash distill_from_12de_to_6de.sh $DATASET $MASTER_PORT $DISTILL_LAYER_NUM $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES $WARM_RATIO $DE_MODEL_PATH $DB_MODEL_PATH 126 | ``` 127 | 128 | ### Distill from 12-layer CB to 6-layer DE 129 | 130 | ``` 131 | bash distill_from_12cb_to_6de.sh $DATASET $MASTER_PORT $DISTILL_LAYER_NUM $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES $WARM_RATIO $CB_MODEL_PATH $DB_MODEL_PATH 132 | ``` 133 | 134 | ### Distill from 12-layer CE to 6-layer DE 135 | 136 | ``` 137 | bash distill_from_12ce_to_6de.sh $DATASET $MASTER_PORT $DISTILL_LAYER_NUM $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES $WARM_RATIO $CE_MODEL_PATH $DB_MODEL_PATH 138 | ``` 139 | 140 | ### Distill from 24-layer CB to 12-layer DE 141 | 142 | ``` 143 | bash distill_from_24cb_to_12de.sh $DATASET $MASTER_PORT $DISTILL_LAYER_NUM 
$MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES $WARM_RATIO $CB_MODEL_PATH $DB_MODEL_PATH 144 | ``` 145 | 146 | ## Inference 147 | 148 | ### Inference 6-layer DE 149 | 150 | ``` 151 | bash evaluate_6_layer_de.sh $DATASET $MASTER_PORT $MODEL_PATH $FIRST_STEPS $EVAL_STEPS $MAX_STEPS $MAX_DOC_LENGTH $MAX_QUERY_LENGTH 152 | ``` 153 | 154 | ### Inference 12-layer DE 155 | 156 | ``` 157 | bash evaluate_12_layer_de.sh $DATASET $MASTER_PORT $MODEL_PATH $FIRST_STEPS $EVAL_STEPS $MAX_STEPS $MAX_DOC_LENGTH $MAX_QUERY_LENGTH 158 | ``` 159 | 160 | ## Parameters 161 | The meaning of each parameter is defined as follows. For detailed initialization of these parameters, please refer to Section ''Hyper-parameter and Initialization''' in Appendix B of our paper. 162 | - `--DATASET`: The dataset used for training, e.g, mspas or msdoc. 163 | - `--MASTER_PORT`: The master port for parallel training, e.g., 3000. 164 | - `--FIRST_STEPS`: The first evaluation step. 165 | - `--EVAL_STEPS`: The evaluation step gap. 166 | - `--MAX_STEPS`: The maximum evaluation step. 167 | - `--WARM_RATIO`: The warm up ratio of the training. 168 | - `--MAX_DOC_LENGTH`: The maximum input length for each document. 169 | - `--MAX_QUERY_LENGTH`: The maximum input length for each query. 170 | - `--TRAIN_BATCH_SIZE`: Batch size per GPU. 171 | - `--NUM_NEGATIVES`: The number of negatives for each query. 172 | - `--DISTILL_LAYER_NUM`: The number of layers for distillation. 173 | - `--MODEL_PATH`: The directory name of saved checkpoint. 174 | - `--DE_MODEL_PATH`: The path of DE teacher checkpoint. 175 | - `--CB_MODEL_PATH`: The path of CB teacher checkpoint. 176 | - `--CE_MODEL_PATH`: The path of CE teacher checkpoint. 177 | - `--DB_MODEL_PATH`: The path of DE student checkpoint. 178 | - `--CKPT_NAME`: The training step of the checkpoint 179 | 180 | ## 📜 Citation 181 | 182 | Please cite our paper if you use [LEAD](https://arxiv.org/abs/2212.05225) in your work: 183 | ```bibtex 184 | @article{sun2022lead, 185 | title={LEAD: Liberal Feature-based Distillation for Dense Retrieval}, 186 | author={Sun, Hao and Liu, Xiao and Gong, Yeyun and Dong, Anlei and Jiao, Jian and Lu, Jingwen and Zhang, Yan and Jiang, Daxin and Yang, Linjun and Majumder, Rangan and others}, 187 | journal={arXiv preprint arXiv:2212.05225}, 188 | year={2022} 189 | } 190 | ``` 191 | -------------------------------------------------------------------------------- /LEAD/assets/main_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/LEAD/assets/main_result.jpg -------------------------------------------------------------------------------- /LEAD/assets/model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/LEAD/assets/model.jpg -------------------------------------------------------------------------------- /LEAD/data_preprocess/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/LEAD/data_preprocess/.DS_Store -------------------------------------------------------------------------------- /LEAD/data_preprocess/data_preprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm, trange 3 | from 
random import sample 4 | 5 | def csv_reader(fd, delimiter='\t', trainer_id=0, trainer_num=1): 6 | def gen(): 7 | for i, line in tqdm(enumerate(fd)): 8 | if i % trainer_num == trainer_id: 9 | slots = line.rstrip('\n').split(delimiter) 10 | if len(slots) == 1: 11 | yield slots, 12 | else: 13 | yield slots 14 | return gen() 15 | 16 | def read_qrel_train(relevance_file): 17 | qrel = {} 18 | with open(relevance_file, encoding='utf8') as f: 19 | tsvreader = csv_reader(f, delimiter="\t") 20 | for [topicid, _, docid, rel] in tsvreader: 21 | assert rel == "1" 22 | if topicid in qrel: 23 | qrel[topicid].append(docid) 24 | else: 25 | qrel[topicid] = [docid] 26 | return qrel 27 | 28 | def read_qrel_train_2(relevance_file): 29 | qrel = {} 30 | with open(relevance_file, encoding='utf8') as f: 31 | tsvreader = csv_reader(f, delimiter=" ") 32 | for [topicid, _, docid, rel] in tsvreader: 33 | assert rel == "1" 34 | if topicid in qrel: 35 | qrel[topicid].append(docid) 36 | else: 37 | qrel[topicid] = [docid] 38 | return qrel 39 | 40 | 41 | def read_qrel_dev(relevance_file): 42 | qrel = {} 43 | with open(relevance_file, encoding='utf8') as f: 44 | tsvreader = csv_reader(f, delimiter="\t") 45 | for [topicid, docid] in tsvreader: 46 | if topicid in qrel: 47 | qrel[topicid].append(docid) 48 | else: 49 | qrel[topicid] = [docid] 50 | return qrel 51 | 52 | def read_qstring(query_file): 53 | q_string = {} 54 | with open(query_file, 'r', encoding='utf-8') as file: 55 | for num, line in enumerate(file): 56 | line = line.strip('\n') # strip the trailing newline 57 | line = line.split('\t') 58 | q_string[line[0]] = line[1] 59 | return q_string 60 | 61 | def read_dstring(corpus_file): 62 | d_string = {} 63 | with open(corpus_file, 'r', encoding='utf-8') as file: 64 | for num, line in enumerate(file): 65 | line = line.strip('\n') # strip the trailing newline 66 | line = line.split('\t') 67 | d_string[line[0]] = [line[1], line[2]] 68 | return d_string 69 | 70 | def read_dstring_2(corpus_file): 71 | d_string = {} 72 | with open(corpus_file, 'r', encoding='utf-8') as file: 73 | for num, line in enumerate(file): 74 | line = line.strip('\n') # strip the trailing newline 75 | line = line.split('\t') 76 | d_string[line[0]] = [line[2], line[3]] 77 | return d_string 78 | 79 | 80 | def construct_mspas(): 81 | train_relevance_file = "mspas/qrels.train.tsv" 82 | train_query_file = "mspas/train.query.txt" 83 | train_negative_file = "mspas/train.negatives.tsv" 84 | corpus_file = "mspas/corpus.tsv" 85 | dev_relevance_file = "mspas/qrels.dev.tsv" 86 | dev_query_file = "mspas/dev.query.txt" 87 | n_sample = 100 88 | 89 | ## qid: docid 90 | train_qrel = read_qrel_train(train_relevance_file) 91 | 92 | ## qid: query 93 | train_q_string = read_qstring(train_query_file) 94 | 95 | ## qid: docid 96 | dev_qrel = read_qrel_dev(dev_relevance_file) 97 | 98 | ## qid: query 99 | dev_q_string = read_qstring(dev_query_file) 100 | 101 | # docid: doc 102 | d_string = read_dstring(corpus_file) 103 | 104 | ## qid: [docid1, ...]
105 | negative = {} 106 | with open(train_negative_file, 'r', encoding='utf8') as nf: 107 | reader = csv_reader(nf) 108 | for cnt, line in enumerate(reader): 109 | q = line[0] 110 | nn = line[1] 111 | nn = nn.split(',') 112 | negative[q] = nn 113 | 114 | train_file = open("mspas/psgs_w100.tsv", 'w') 115 | for did in d_string: 116 | train_file.write('\t'.join([str(int(did) + 1), d_string[did][1], d_string[did][0]]) + '\n') 117 | train_file.flush() 118 | 119 | train_data_list = [] 120 | for qid in tqdm(negative): 121 | example = {} 122 | example['question'] = train_q_string[qid] 123 | example['answers'] = train_qrel[qid] 124 | example['positive_ctxs'] = [did for did in train_qrel[qid]] 125 | example['hard_negative_ctxs'] = [did for did in negative[qid][:n_sample]] 126 | example['negative_ctxs'] = [] 127 | train_data_list.append(example) 128 | 129 | with open('mspas/biencoder-mspas-train.json', 'w', encoding='utf-8') as json_file: 130 | json_file.write(json.dumps(train_data_list, indent=4)) 131 | 132 | 133 | train_data_list = [] 134 | for qid in tqdm(train_q_string): 135 | example = {} 136 | example['question'] = train_q_string[qid] 137 | example['answers'] = train_qrel[qid] 138 | example['positive_ctxs'] = [did for did in train_qrel[qid]] 139 | example['negative_ctxs'] = [] 140 | train_data_list.append(example) 141 | 142 | with open('mspas/biencoder-mspas-train-full.json', 'w', encoding='utf-8') as json_file: 143 | json_file.write(json.dumps(train_data_list, indent=4)) 144 | 145 | train_file = open("mspas/mspas-test.qa.csv", 'w') 146 | for qid in dev_qrel: 147 | train_file.write('\t'.join([dev_q_string[qid], str(dev_qrel[qid])]) + '\n') 148 | train_file.flush() 149 | 150 | 151 | def construct_msdoc(): 152 | train_relevance_file = "msdoc/msmarco-doctrain-qrels.tsv" 153 | train_query_file = "msdoc/msmarco-doctrain-queries.tsv" 154 | corpus_file = "msdoc/msmarco-docs.tsv" 155 | dev_relevance_file = "msdoc/msmarco-docdev-qrels.tsv" 156 | dev_query_file = "msdoc/msmarco-docdev-queries.tsv" 157 | n_sample = 100 158 | 159 | ## qid: docid 160 | train_qrel = read_qrel_train_2(train_relevance_file) 161 | 162 | ## qid: query 163 | train_q_string = read_qstring(train_query_file) 164 | 165 | ## qid: docid 166 | dev_qrel = read_qrel_train_2(dev_relevance_file) 167 | 168 | ## qid: query 169 | dev_q_string = read_qstring(dev_query_file) 170 | 171 | # docid: doc 172 | d_string = read_dstring_2(corpus_file) 173 | 174 | docid_int = {} 175 | int_docid = {} 176 | idx = 0 177 | for docid in d_string: 178 | docid_int[docid] = idx 179 | int_docid[idx] = docid 180 | idx += 1 181 | 182 | ## qid: [docid1, ...] 
183 | negative = {} 184 | all_doc_list = [elem for elem in docid_int] 185 | 186 | for docid in train_qrel: 187 | negative[docid] = sample(all_doc_list, n_sample) 188 | 189 | train_data_list = [] 190 | for qid in tqdm(train_qrel): 191 | example = {} 192 | example['question'] = train_q_string[qid] 193 | example['answers'] = [docid_int[did] for did in train_qrel[qid]] 194 | example['positive_ctxs'] = [docid_int[did] for did in train_qrel[qid]] 195 | example['hard_negative_ctxs'] = [docid_int[did] for did in negative[qid][:n_sample]] 196 | example['negative_ctxs'] = [] 197 | train_data_list.append(example) 198 | 199 | with open('msdoc/biencoder-msdoc-train.json', 'w', encoding='utf-8') as json_file: 200 | json_file.write(json.dumps(train_data_list, indent=4)) 201 | 202 | train_file = open("msdoc/psgs_w100.tsv", 'w') 203 | for did in int_docid: 204 | train_file.write('\t'.join([str(did + 1), d_string[int_docid[did]][1], d_string[int_docid[did]][0]]) + '\n') 205 | train_file.flush() 206 | 207 | train_file = open("msdoc/msdoc-test.qa.csv", 'w') 208 | for qid in dev_qrel: 209 | train_file.write('\t'.join([dev_q_string[qid], str([str(docid_int[elem]) for elem in dev_qrel[qid]])]) + '\n') 210 | train_file.flush() 211 | 212 | 213 | print('Start Processing') 214 | construct_mspas() 215 | construct_msdoc() 216 | print('Finish Processing') -------------------------------------------------------------------------------- /LEAD/data_preprocess/get_data.sh: -------------------------------------------------------------------------------- 1 | # Download MSMARCO-PAS data 2 | wget --no-check-certificate https://rocketqa.bj.bcebos.com/corpus/marco.tar.gz 3 | tar -zxf marco.tar.gz 4 | rm -rf marco.tar.gz 5 | mv marco mspas 6 | cd mspas 7 | wget https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz 8 | wget https://msmarco.blob.core.windows.net/msmarcoranking/qrels.train.tsv -O qrels.train.tsv 9 | gunzip qidpidtriples.train.full.2.tsv.gz 10 | join -t "$(echo -en '\t')" -e '' -a 1 -o 1.1 2.2 1.2 <(sort -k1,1 para.txt) <(sort -k1,1 para.title.txt) | sort -k1,1 -n > corpus.tsv 11 | awk -v RS='\r\n' '$1==last {printf ",%s",$3; next} NR>1 {print "";} {last=$1; printf "%s\t%s",$1,$3;} END{print "";}' qidpidtriples.train.full.2.tsv > train.negatives.tsv 12 | 13 | 14 | 15 | # Download MSMARCO-DOC data 16 | cd .. 17 | mkdir msdoc 18 | cd msdoc 19 | 20 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docs.tsv.gz 21 | gunzip msmarco-docs.tsv.gz 22 | 23 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-doctrain-queries.tsv.gz 24 | gunzip msmarco-doctrain-queries.tsv.gz 25 | 26 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-doctrain-qrels.tsv.gz 27 | gunzip msmarco-doctrain-qrels.tsv.gz 28 | 29 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz 30 | gunzip msmarco-test2019-queries.tsv.gz 31 | 32 | wget https://trec.nist.gov/data/deep/2019qrels-docs.txt 33 | 34 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docdev-queries.tsv.gz 35 | gunzip msmarco-docdev-queries.tsv.gz 36 | 37 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docdev-qrels.tsv.gz 38 | gunzip msmarco-docdev-qrels.tsv.gz 39 | 40 | cd .. 
41 | python data_preprocess.py -------------------------------------------------------------------------------- /LEAD/data_preprocess/hard_negative_generation/retrieve_hard_negatives_academic.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATASET=$1 4 | MASTER_PORT=$2 5 | MODEL_PATH=$3 6 | CKPT_NAME=$4 7 | MAX_DOC_LENGTH=$5 8 | MAX_QUERY_LENGTH=$6 9 | 10 | MODEL_TYPE=dual_encoder 11 | BASE_DATA_DIR=data_preprocess/${DATASET} 12 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE} 13 | 14 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv 15 | 16 | if [ ${DATASET} == 'mspas' ]; 17 | then 18 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-mspas-train-full.json 19 | else 20 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train.json 21 | fi 22 | 23 | TOKENIZER_NAME=bert-base-uncased 24 | PRETRAINED_MODEL_NAME=Luyu/co-condenser-marco 25 | 26 | EVAL_BATCHSIZE=256 27 | 28 | ####################################### Multi Evaluation ######################################## 29 | echo "****************begin Retrieve****************" 30 | cd .. 31 | 32 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} retrieve_hard_negative.py \ 33 | --model_type=${MODEL_TYPE} \ 34 | --dataset ${DATASET} \ 35 | --tokenizer_name=${TOKENIZER_NAME} \ 36 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \ 37 | --eval_model_dir=${BASE_OUTPUT_DIR}/models/${MODEL_PATH}/de-checkpoint-${CKPT_NAME} \ 38 | --output_dir=${BASE_DATA_DIR} \ 39 | --train_file=${TRAIN_FILE} \ 40 | --passage_path=${PASSAGE_PATH} \ 41 | --max_query_length=${MAX_QUERY_LENGTH} \ 42 | --max_doc_length=${MAX_DOC_LENGTH} \ 43 | --per_gpu_eval_batch_size $EVAL_BATCHSIZE 44 | 45 | echo "****************End Retrieve****************" -------------------------------------------------------------------------------- /LEAD/data_preprocess/hard_negative_generation/train_de.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_TYPE=dual_encoder 4 | DATASET=$1 5 | MASTER_PORT=$2 6 | MAX_DOC_LENGTH=$3 7 | MAX_QUERY_LENGTH=$4 8 | TRAIN_BATCH_SIZE=$5 9 | NUM_NEGATIVES=$6 10 | 11 | BASE_DATA_DIR=data_preprocess/${DATASET} 12 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE} 13 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train.json 14 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv 15 | 16 | TOKENIZER_NAME=bert-base-uncased 17 | PRETRAINED_MODEL_NAME=Luyu/co-condenser-marco 18 | MAX_STEPS=20000 19 | LOGGING_STEPS=10 20 | SAVE_STEPS=1000 21 | LR=2e-5 22 | GRADIENT_ACCUMULATION_STEPS=1 23 | ######################################## Training ######################################## 24 | echo "****************begin Train****************" 25 | cd ../ 26 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \ 27 | ./run_single_model.py \ 28 | --model_type=${MODEL_TYPE} \ 29 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \ 30 | --tokenizer_name=${TOKENIZER_NAME} \ 31 | --dataset=${DATASET} \ 32 | --train_file=${TRAIN_FILE} \ 33 | --passage_path=${PASSAGE_PATH} \ 34 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \ 35 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \ 36 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \ 37 | --share_weight \ 38 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \ 39 | --output_dir=${BASE_OUTPUT_DIR}/models 40 | 41 | echo 
"****************End Train****************" 42 | -------------------------------------------------------------------------------- /LEAD/distillation/distill_from_12cb_to_6de.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_TYPE=distilbert 4 | DATASET=$1 5 | MASTER_PORT=$2 6 | DISTILL_LAYER_NUM=$3 7 | MAX_DOC_LENGTH=$4 8 | MAX_QUERY_LENGTH=$5 9 | TRAIN_BATCH_SIZE=$6 10 | NUM_NEGATIVES=$7 11 | WARM_RATIO=$8 12 | COL_MODEL_PATH=$9 13 | DB_MODEL_PATH=${10} 14 | 15 | 16 | BASE_DATA_DIR=data_preprocess/${DATASET} 17 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE} 18 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train-hard.json 19 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv 20 | DISTILL_COL_PATH=model_output/colbert/models/${COL_MODEL_PATH} # The initialized parameter of ce 21 | DISTILL_DB_PATH=model_output/distilbert/models/${DB_MODEL_PATH} # The initialized parameter of db 22 | 23 | 24 | TOKENIZER_NAME=bert-base-uncased 25 | PRETRAINED_MODEL_NAME=Luyu/co-condenser-marco 26 | DISTILL_PARA_DB=1 # The distillation parameter of loss_db 27 | DISTILL_PARA_COL=1 # The distillation parameter of loss_col 28 | DISTILL_PARA_COL_DB_DIS=1 29 | DISTILL_PARA_COL_DB_LAYER_SCORE=1 30 | GRADIENT_ACCUMULATION_STEPS=1 31 | SAVE_STEPS=10 32 | TEMPERATURE=1 33 | LAYER_TEMPERATURE=10 34 | MAX_STEPS=50000 35 | LOGGING_STEPS=10 36 | LR=2e-5 37 | 38 | 39 | cd .. 40 | ######################################## Training ######################################## 41 | echo "****************begin Train****************" 42 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \ 43 | ./run_LEAD.py \ 44 | --model_type=${MODEL_TYPE} \ 45 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \ 46 | --tokenizer_name=${TOKENIZER_NAME} \ 47 | --train_file=${TRAIN_FILE} \ 48 | --passage_path=${PASSAGE_PATH} \ 49 | --dataset=${DATASET} \ 50 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \ 51 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \ 52 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \ 53 | --share_weight \ 54 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \ 55 | --temperature ${TEMPERATURE} \ 56 | --layer_temperature ${LAYER_TEMPERATURE} \ 57 | --output_dir=${BASE_OUTPUT_DIR}/models \ 58 | --distill_para_db=${DISTILL_PARA_DB} \ 59 | --distill_para_col=${DISTILL_PARA_COL} \ 60 | --distill_para_col_db_dis=${DISTILL_PARA_COL_DB_DIS} \ 61 | --distill_para_col_db_layer_score=${DISTILL_PARA_COL_DB_LAYER_SCORE} \ 62 | --distill_col --train_col\ 63 | --distill_col_path=${DISTILL_COL_PATH} \ 64 | --distill_db --train_db\ 65 | --distill_db_path=${DISTILL_DB_PATH} \ 66 | --distill_col_db_layer_score \ 67 | --disitll_layer_num=${DISTILL_LAYER_NUM} \ 68 | --layer_selection_random \ 69 | --layer_score_reweight \ 70 | --warm_up_ratio=${WARM_RATIO} 71 | 72 | echo "****************End Train****************" -------------------------------------------------------------------------------- /LEAD/distillation/distill_from_12ce_to_6de.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_TYPE=distilbert 4 | DATASET=$1 5 | MASTER_PORT=$2 6 | DISTILL_LAYER_NUM=$3 7 | MAX_DOC_LENGTH=$4 8 | MAX_QUERY_LENGTH=$5 9 | TRAIN_BATCH_SIZE=$6 10 | NUM_NEGATIVES=$7 11 | WARM_RATIO=$8 12 | CE_MODEL_PATH=$9 13 | DB_MODEL_PATH=${10} 14 | 15 | 16 | 
BASE_DATA_DIR=data_preprocess/${DATASET} 17 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE} 18 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train-hard.json 19 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv 20 | DISTILL_CE_PATH=model_output/cross_encoder/models/${CE_MODEL_PATH} # The initialized parameter of ce 21 | DISTILL_DB_PATH=model_output/distilbert/models/${DB_MODEL_PATH} # The initialized parameter of db 22 | 23 | 24 | TOKENIZER_NAME=bert-base-uncased 25 | PRETRAINED_MODEL_NAME=Luyu/co-condenser-marco 26 | DISTILL_PARA_DB=1 # The distillation parameter of loss_db 27 | DISTILL_PARA_CE=1 # The distillation parameter of loss_ce 28 | DISTILL_PARA_CE_DB_DIS=1 29 | DISTILL_PARA_CE_DB_LAYER_SCORE=1 30 | GRADIENT_ACCUMULATION_STEPS=10 31 | SAVE_STEPS=10 32 | TEMPERATURE=1 33 | LAYER_TEMPERATURE=10 34 | MAX_STEPS=100000 35 | LOGGING_STEPS=10 36 | LR=5e-5 37 | 38 | cd .. 39 | ######################################## Training ######################################## 40 | echo "****************begin Train****************" 41 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \ 42 | ./run_LEAD.py \ 43 | --model_type=${MODEL_TYPE} \ 44 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \ 45 | --tokenizer_name=${TOKENIZER_NAME} \ 46 | --train_file=${TRAIN_FILE} \ 47 | --passage_path=${PASSAGE_PATH} \ 48 | --dataset=${DATASET} \ 49 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \ 50 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \ 51 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \ 52 | --share_weight \ 53 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \ 54 | --temperature ${TEMPERATURE} \ 55 | --layer_temperature ${LAYER_TEMPERATURE} \ 56 | --output_dir=${BASE_OUTPUT_DIR}/models \ 57 | --distill_para_db=${DISTILL_PARA_DB} \ 58 | --distill_para_ce=${DISTILL_PARA_CE} \ 59 | --distill_para_ce_db_dis=${DISTILL_PARA_CE_DB_DIS} \ 60 | --distill_para_ce_db_layer_score=${DISTILL_PARA_CE_DB_LAYER_SCORE} \ 61 | --distill_ce --train_ce\ 62 | --distill_ce_path=${DISTILL_CE_PATH} \ 63 | --distill_db --train_db\ 64 | --distill_db_path=${DISTILL_DB_PATH} \ 65 | --distill_ce_db_layer_score \ 66 | --disitll_layer_num=${DISTILL_LAYER_NUM} \ 67 | --layer_selection_random \ 68 | --layer_score_reweight \ 69 | --warm_up_ratio=${WARM_RATIO} 70 | 71 | echo "****************End Train****************" -------------------------------------------------------------------------------- /LEAD/distillation/distill_from_12de_to_6de.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_TYPE=distilbert 4 | DATASET=$1 5 | MASTER_PORT=$2 6 | DISTILL_LAYER_NUM=$3 7 | MAX_DOC_LENGTH=$4 8 | MAX_QUERY_LENGTH=$5 9 | TRAIN_BATCH_SIZE=$6 10 | NUM_NEGATIVES=$7 11 | WARM_RATIO=$8 12 | DE_MODEL_PATH=$9 13 | DB_MODEL_PATH=${10} 14 | 15 | 16 | BASE_DATA_DIR=data_preprocess/${DATASET} 17 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE} 18 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train-hard.json 19 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv 20 | DISTILL_DE_PATH=model_output/dual_encoder/models/${DE_MODEL_PATH} # The initialized parameter of ce 21 | DISTILL_DB_PATH=model_output/distilbert/models/${DB_MODEL_PATH} # The initialized parameter of db 22 | 23 | 24 | TOKENIZER_NAME=bert-base-uncased 25 | PRETRAINED_MODEL_NAME=Luyu/co-condenser-marco 26 | DISTILL_PARA_DB=1 # The distillation 
parameter of loss_db 27 | DISTILL_PARA_DE=1 # The distillation parameter of loss_de 28 | DISTILL_PARA_DE_DB_DIS=1 29 | DISTILL_PARA_DE_DB_LAYER_SCORE=1 30 | GRADIENT_ACCUMULATION_STEPS=1 31 | SAVE_STEPS=10 32 | TEMPERATURE=1 33 | LAYER_TEMPERATURE=10 34 | MAX_STEPS=50000 35 | LOGGING_STEPS=10 36 | LR=5e-5 37 | 38 | cd .. 39 | ######################################## Training ######################################## 40 | echo "****************begin Train****************" 41 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \ 42 | ./run_LEAD.py \ 43 | --model_type=${MODEL_TYPE} \ 44 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \ 45 | --tokenizer_name=${TOKENIZER_NAME} \ 46 | --train_file=${TRAIN_FILE} \ 47 | --passage_path=${PASSAGE_PATH} \ 48 | --dataset=${DATASET} \ 49 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \ 50 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \ 51 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \ 52 | --share_weight \ 53 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \ 54 | --temperature ${TEMPERATURE} \ 55 | --layer_temperature ${LAYER_TEMPERATURE} \ 56 | --output_dir=${BASE_OUTPUT_DIR}/models \ 57 | --distill_para_de=${DISTILL_PARA_DE} \ 58 | --distill_para_db=${DISTILL_PARA_DB} \ 59 | --distill_para_de_db_dis=${DISTILL_PARA_DE_DB_DIS} \ 60 | --distill_para_de_db_layer_score=${DISTILL_PARA_DE_DB_LAYER_SCORE} \ 61 | --distill_de --train_de \ 62 | --distill_de_path=${DISTILL_DE_PATH} \ 63 | --distill_db --train_db \ 64 | --distill_db_path=${DISTILL_DB_PATH} \ 65 | --distill_de_db_layer_score \ 66 | --disitll_layer_num=${DISTILL_LAYER_NUM} \ 67 | --layer_selection_random \ 68 | --layer_score_reweight \ 69 | --warm_up_ratio=${WARM_RATIO} 70 | 71 | echo "****************End Train****************" -------------------------------------------------------------------------------- /LEAD/distillation/distill_from_24cb_to_12de.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_TYPE=dual_encoder 4 | DATASET=$1 5 | MASTER_PORT=$2 6 | DISTILL_LAYER_NUM=$3 7 | MAX_DOC_LENGTH=$4 8 | MAX_QUERY_LENGTH=$5 9 | TRAIN_BATCH_SIZE=$6 10 | NUM_NEGATIVES=$7 11 | WARM_RATIO=$8 12 | COL_MODEL_PATH=$9 13 | DB_MODEL_PATH=${10} 14 | 15 | 16 | BASE_DATA_DIR=data_preprocess/${DATASET} 17 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE} 18 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train-hard.json 19 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv 20 | DISTILL_COL_PATH=model_output/colbert/models/${COL_MODEL_PATH} # The initialized parameter of ce 21 | DISTILL_DB_PATH=model_output/dual_encoder/models/${DB_MODEL_PATH} # The initialized parameter of db 22 | 23 | 24 | TOKENIZER_NAME=bert-base-uncased 25 | PRETRAINED_MODEL_NAME=nghuyong/ernie-2.0-large-en 26 | DISTILL_PARA_DB=1 # The distillation parameter of loss_db 27 | DISTILL_PARA_COL=1 # The distillation parameter of loss_col 28 | DISTILL_PARA_COL_DB_DIS=1 29 | DISTILL_PARA_COL_DB_LAYER_SCORE=1 30 | GRADIENT_ACCUMULATION_STEPS=1 31 | SAVE_STEPS=10 32 | TEMPERATURE=1 33 | LAYER_TEMPERATURE=10 34 | MAX_STEPS=50000 35 | LOGGING_STEPS=10 36 | LR=2e-5 37 | 38 | 39 | cd .. 
40 | ######################################## Training ######################################## 41 | echo "****************begin Train****************" 42 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \ 43 | ./run_LEAD.py \ 44 | --model_type=${MODEL_TYPE} \ 45 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \ 46 | --tokenizer_name=${TOKENIZER_NAME} \ 47 | --train_file=${TRAIN_FILE} \ 48 | --passage_path=${PASSAGE_PATH} \ 49 | --dataset=${DATASET} \ 50 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \ 51 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \ 52 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \ 53 | --share_weight \ 54 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \ 55 | --temperature ${TEMPERATURE} \ 56 | --layer_temperature ${LAYER_TEMPERATURE} \ 57 | --output_dir=${BASE_OUTPUT_DIR}/models \ 58 | --distill_para_db=${DISTILL_PARA_DB} \ 59 | --distill_para_col=${DISTILL_PARA_COL} \ 60 | --distill_para_col_db_dis=${DISTILL_PARA_COL_DB_DIS} \ 61 | --distill_para_col_db_layer_score=${DISTILL_PARA_COL_DB_LAYER_SCORE} \ 62 | --distill_col --train_col\ 63 | --distill_col_path=${DISTILL_COL_PATH} \ 64 | --distill_db --train_db\ 65 | --distill_db_path=${DISTILL_DB_PATH} \ 66 | --distill_col_db_layer_score \ 67 | --disitll_layer_num=${DISTILL_LAYER_NUM} \ 68 | --layer_selection_random \ 69 | --layer_score_reweight \ 70 | --warm_up_ratio=${WARM_RATIO} 71 | 72 | echo "****************End Train****************" -------------------------------------------------------------------------------- /LEAD/inference/evaluate_12_layer_de.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_TYPE=dual_encoder 4 | DATASET=$1 5 | MASTER_PORT=$2 6 | MODEL_PATH=$3 7 | FIRST_STEPS=$4 8 | EVAL_STEPS=$5 9 | MAX_STEPS=$6 10 | MAX_DOC_LENGTH=$7 11 | MAX_QUERY_LENGTH=$8 12 | 13 | BASE_DATA_DIR=data_preprocess/${DATASET} 14 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE} 15 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv 16 | TEST_FILE=${BASE_DATA_DIR}/${DATASET}-test.qa.csv 17 | 18 | TOKENIZER_NAME=bert-base-uncased 19 | PRETRAINED_MODEL_NAME=master 20 | #PRETRAINED_MODEL_NAME=Luyu/co-condenser-marco 21 | EVAL_BATCHSIZE=256 22 | 23 | ##remember to change 'de-checkpoint-${CKPT_NAME}' to 'db-checkpoint-${CKPT_NAME}' when evaluating distillation result 24 | ####################################### Multi Evaluation ######################################## 25 | echo "****************begin Evaluate****************" 26 | cd .. 
27 | for ITER in $(seq 0 $((($MAX_STEPS - $FIRST_STEPS)/ $EVAL_STEPS))) 28 | do 29 | CKPT_NAME=$(($FIRST_STEPS + $ITER * $EVAL_STEPS)) 30 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} inference_de.py \ 31 | --model_type=${MODEL_TYPE} \ 32 | --tokenizer_name=${TOKENIZER_NAME} \ 33 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \ 34 | --eval_model_dir=${BASE_OUTPUT_DIR}/models/${MODEL_PATH}/de-checkpoint-${CKPT_NAME} \ 35 | --output_dir=${BASE_OUTPUT_DIR}/eval/${MODEL_PATH}/de-checkpoint-${CKPT_NAME} \ 36 | --test_file=${TEST_FILE} \ 37 | --passage_path=${PASSAGE_PATH} \ 38 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \ 39 | --per_gpu_eval_batch_size $EVAL_BATCHSIZE 40 | done 41 | 42 | echo "****************End Evaluate****************" 43 | -------------------------------------------------------------------------------- /LEAD/inference/evaluate_6_layer_de.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_TYPE=distilbert 4 | DATASET=$1 5 | MASTER_PORT=$2 6 | MODEL_PATH=$3 7 | FIRST_STEPS=$4 8 | EVAL_STEPS=$5 9 | MAX_STEPS=$6 10 | MAX_DOC_LENGTH=$7 11 | MAX_QUERY_LENGTH=$8 12 | 13 | BASE_DATA_DIR=data_preprocess/${DATASET} 14 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE} 15 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv 16 | TEST_FILE=${BASE_DATA_DIR}/${DATASET}-test.qa.csv 17 | 18 | TOKENIZER_NAME=bert-base-uncased 19 | PRETRAINED_MODEL_NAME=distilbert-base-uncased 20 | EVAL_BATCHSIZE=256 21 | 22 | ####################################### Multi Evaluation ######################################## 23 | echo "****************begin Evaluate****************" 24 | cd .. 25 | for ITER in $(seq 0 $((($MAX_STEPS - $FIRST_STEPS)/ $EVAL_STEPS))) 26 | do 27 | CKPT_NAME=$(($FIRST_STEPS + $ITER * $EVAL_STEPS)) 28 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} inference_de.py \ 29 | --model_type=${MODEL_TYPE} \ 30 | --tokenizer_name=${TOKENIZER_NAME} \ 31 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \ 32 | --eval_model_dir=${BASE_OUTPUT_DIR}/models/${MODEL_PATH}/db-checkpoint-${CKPT_NAME} \ 33 | --output_dir=${BASE_OUTPUT_DIR}/eval/${MODEL_PATH}/db-checkpoint-${CKPT_NAME} \ 34 | --test_file=${TEST_FILE} \ 35 | --passage_path=${PASSAGE_PATH} \ 36 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \ 37 | --per_gpu_eval_batch_size $EVAL_BATCHSIZE 38 | done 39 | 40 | echo "****************End Evaluate****************" 41 | -------------------------------------------------------------------------------- /LEAD/warm_up/train_12_layer_ce.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_TYPE=cross_encoder 4 | DATASET=$1 5 | MASTER_PORT=$2 6 | MAX_DOC_LENGTH=$3 7 | MAX_QUERY_LENGTH=$4 8 | TRAIN_BATCH_SIZE=$5 9 | NUM_NEGATIVES=$6 10 | 11 | BASE_DATA_DIR=data_preprocess/${DATASET} 12 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE} 13 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train-hard.json 14 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv 15 | 16 | TOKENIZER_NAME=bert-base-uncased 17 | PRETRAINED_MODEL_NAME=Luyu/co-condenser-marco 18 | 19 | MAX_STEPS=20000 20 | GRADIENT_ACCUMULATION_STEPS=10 21 | LOGGING_STEPS=10 22 | SAVE_STEPS=10 23 | LR=1e-5 24 | ######################################## Training ######################################## 25 | echo "****************begin Train****************" 26 | cd ../ 27 | python -u -m 
torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \ 28 | ./run_single_model.py \ 29 | --model_type=${MODEL_TYPE} \ 30 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \ 31 | --tokenizer_name=${TOKENIZER_NAME} \ 32 | --train_file=${TRAIN_FILE} \ 33 | --dataset=${DATASET} \ 34 | --passage_path=${PASSAGE_PATH} \ 35 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \ 36 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \ 37 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \ 38 | --share_weight \ 39 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \ 40 | --output_dir=${BASE_OUTPUT_DIR}/models 41 | 42 | echo "****************End Train****************" -------------------------------------------------------------------------------- /LEAD/warm_up/train_12_layer_col.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_TYPE=colbert 4 | DATASET=$1 5 | MASTER_PORT=$2 6 | MAX_DOC_LENGTH=$3 7 | MAX_QUERY_LENGTH=$4 8 | TRAIN_BATCH_SIZE=$5 9 | NUM_NEGATIVES=$6 10 | 11 | BASE_DATA_DIR=data_preprocess/${DATASET} 12 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE} 13 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train-hard.json 14 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv 15 | 16 | TOKENIZER_NAME=bert-base-uncased 17 | PRETRAINED_MODEL_NAME=Luyu/co-condenser-marco 18 | 19 | MAX_STEPS=20000 20 | LOGGING_STEPS=10 21 | SAVE_STEPS=10 22 | LR=2e-5 23 | GRADIENT_ACCUMULATION_STEPS=1 24 | ######################################## Training ######################################## 25 | echo "****************begin Train****************" 26 | cd ../ 27 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \ 28 | ./run_single_model.py \ 29 | --model_type=${MODEL_TYPE} \ 30 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \ 31 | --tokenizer_name=${TOKENIZER_NAME} \ 32 | --train_file=${TRAIN_FILE} \ 33 | --dataset=${DATASET} \ 34 | --passage_path=${PASSAGE_PATH} \ 35 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \ 36 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \ 37 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \ 38 | --share_weight \ 39 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \ 40 | --output_dir=${BASE_OUTPUT_DIR}/models 41 | 42 | echo "****************End Train****************" -------------------------------------------------------------------------------- /LEAD/warm_up/train_12_layer_de.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_TYPE=dual_encoder 4 | DATASET=$1 5 | MASTER_PORT=$2 6 | MAX_DOC_LENGTH=$3 7 | MAX_QUERY_LENGTH=$4 8 | TRAIN_BATCH_SIZE=$5 9 | NUM_NEGATIVES=$6 10 | 11 | BASE_DATA_DIR=data_preprocess/${DATASET} 12 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE} 13 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train-hard.json 14 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv 15 | 16 | TOKENIZER_NAME=bert-base-uncased 17 | PRETRAINED_MODEL_NAME=master 18 | #PRETRAINED_MODEL_NAME=Luyu/co-condenser-marco 19 | 20 | MAX_STEPS=20000 21 | LOGGING_STEPS=10 22 | SAVE_STEPS=10 23 | LR=2e-5 24 | GRADIENT_ACCUMULATION_STEPS=1 25 | ######################################## Training 
######################################## 26 | echo "****************begin Train****************" 27 | cd ../ 28 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \ 29 | ./run_single_model.py \ 30 | --model_type=${MODEL_TYPE} \ 31 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \ 32 | --tokenizer_name=${TOKENIZER_NAME} \ 33 | --train_file=${TRAIN_FILE} \ 34 | --dataset=${DATASET} \ 35 | --passage_path=${PASSAGE_PATH} \ 36 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \ 37 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \ 38 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \ 39 | --share_weight \ 40 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \ 41 | --output_dir=${BASE_OUTPUT_DIR}/models 42 | 43 | echo "****************End Train****************" -------------------------------------------------------------------------------- /LEAD/warm_up/train_24_layer_col.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_TYPE=colbert 4 | DATASET=$1 5 | MASTER_PORT=$2 6 | MAX_DOC_LENGTH=$3 7 | MAX_QUERY_LENGTH=$4 8 | TRAIN_BATCH_SIZE=$5 9 | NUM_NEGATIVES=$6 10 | 11 | BASE_DATA_DIR=data_preprocess/${DATASET} 12 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE} 13 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train-hard.json 14 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv 15 | 16 | TOKENIZER_NAME=bert-base-uncased 17 | PRETRAINED_MODEL_NAME=nghuyong/ernie-2.0-large-en 18 | 19 | MAX_STEPS=20000 20 | LOGGING_STEPS=10 21 | SAVE_STEPS=10 22 | LR=1e-5 23 | GRADIENT_ACCUMULATION_STEPS=1 24 | ######################################## Training ######################################## 25 | echo "****************begin Train****************" 26 | cd ../ 27 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \ 28 | ./run_single_model.py \ 29 | --model_type=${MODEL_TYPE} \ 30 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \ 31 | --tokenizer_name=${TOKENIZER_NAME} \ 32 | --train_file=${TRAIN_FILE} \ 33 | --dataset=${DATASET} \ 34 | --passage_path=${PASSAGE_PATH} \ 35 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \ 36 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \ 37 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \ 38 | --share_weight \ 39 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \ 40 | --output_dir=${BASE_OUTPUT_DIR}/models 41 | 42 | echo "****************End Train****************" -------------------------------------------------------------------------------- /LEAD/warm_up/train_6_layer_de.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_TYPE=distilbert 4 | DATASET=$1 5 | MASTER_PORT=$2 6 | MAX_DOC_LENGTH=$3 7 | MAX_QUERY_LENGTH=$4 8 | TRAIN_BATCH_SIZE=$5 9 | NUM_NEGATIVES=$6 10 | 11 | BASE_DATA_DIR=data_preprocess/${DATASET} 12 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE} 13 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train-hard.json 14 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv 15 | 16 | TOKENIZER_NAME=bert-base-uncased 17 | PRETRAINED_MODEL_NAME=distilbert-base-uncased 18 | 19 | MAX_STEPS=20000 20 | LOGGING_STEPS=10 21 | SAVE_STEPS=10 22 | LR=2e-5 23 | 
GRADIENT_ACCUMULATION_STEPS=1 24 | ######################################## Training ######################################## 25 | echo "****************begin Train****************" 26 | cd ../ 27 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \ 28 | ./run_single_model.py \ 29 | --model_type=${MODEL_TYPE} \ 30 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \ 31 | --tokenizer_name=${TOKENIZER_NAME} \ 32 | --train_file=${TRAIN_FILE} \ 33 | --dataset=${DATASET} \ 34 | --passage_path=${PASSAGE_PATH} \ 35 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \ 36 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \ 37 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \ 38 | --share_weight \ 39 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \ 40 | --output_dir=${BASE_OUTPUT_DIR}/models 41 | 42 | echo "****************End Train****************" 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /MASTER/figs/master_main.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/MASTER/figs/master_main.jpg -------------------------------------------------------------------------------- /MASTER/figs/master_main_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/MASTER/figs/master_main_result.jpg -------------------------------------------------------------------------------- /MASTER/finetune/ft_MS_MASTER.sh: -------------------------------------------------------------------------------- 1 | EXP_NAME=run_de_ms_MASTER_all_1 # de means dual encoder. 
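# Pipeline summary (inferred from the commands below): (1) warm-up dual-encoder training on
# train_stage1.tsv (BM25 negatives), initialized from the pretrained checkpoint passed via
# --model_type; (2) evaluation of several checkpoints and hard-negative mining with
# --write_hardneg=True, producing train_ce_hardneg.tsv; (3) a second dual-encoder run on the
# mined hard negatives, followed by another round of evaluation and mining; (4) training of an
# ELECTRA cross-encoder reranker (run_ce_model_ele.py); (5) co-training / distillation of the
# dual encoder with that reranker (co_training_model_ele.py) and a final checkpoint sweep.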
2 | DATA_DIR=/kun_data/marco/ 3 | OUT_DIR=output/$EXP_NAME 4 | TB_DIR=tensorboard_log/$EXP_NAME # tensorboard log path 5 | 6 | for epoch in 80000 7 | do 8 | 9 | # Fine-tune with BM25 negatives using CKPT from epoch as initialization 10 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=69539 \ 11 | ./MS/run_de_model.py \ 12 | --model_type=/kun_data/Austerlitz/MT5_Gen/ckpt/SIMLM/checkpoint-$epoch \ 13 | --origin_data_dir=$DATA_DIR/train_stage1.tsv \ 14 | --origin_data_dir_dev=$DATA_DIR/dev.query.txt \ 15 | --passage_path=/kun_data/marco \ 16 | --dataset=MS-MARCO \ 17 | --max_seq_length=128 --per_gpu_train_batch_size=8 --gradient_accumulation_steps=1 \ 18 | --learning_rate=5e-6 --output_dir $OUT_DIR \ 19 | --warmup_steps 1000 --logging_steps 200 --save_steps 5000 --max_steps 30000 \ 20 | --log_dir $TB_DIR \ 21 | --number_neg 31 --fp16 22 | 23 | # Evaluation 3h totally 24 | for CKPT_NUM in 20000 25000 30000 25 | do 26 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \ 27 | ./MS/inference_de.py \ 28 | --model_type=bert-base-uncased \ 29 | --eval_model_dir=$OUT_DIR/checkpoint-$CKPT_NUM \ 30 | --output_dir=$OUT_DIR/$CKPT_NUM \ 31 | --train_qa_path=/kun_data/marco/train.query.txt \ 32 | --test_qa_path=/kun_data/marco/dev.query.txt \ 33 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \ 34 | --passage_path=/kun_data/marco \ 35 | --dataset=MS-MARCO \ 36 | --fp16 37 | done 38 | 39 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \ 40 | ./MS/inference_de.py \ 41 | --model_type=bert-base-uncased \ 42 | --eval_model_dir=$OUT_DIR/checkpoint-25000 \ 43 | --output_dir=$OUT_DIR/25000 \ 44 | --train_qa_path=/kun_data/marco/train.query.txt \ 45 | --test_qa_path=/kun_data/marco/dev.query.txt \ 46 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \ 47 | --passage_path=/kun_data/marco \ 48 | --dataset=MS-MARCO \ 49 | --fp16 --write_hardneg=True 50 | 51 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=69539 \ 52 | ./MS/run_de_model.py \ 53 | --model_type=/kun_data/Austerlitz/MT5_Gen/ckpt/SIMLM/checkpoint-$epoch \ 54 | --origin_data_dir=$OUT_DIR/25000/train_ce_hardneg.tsv \ 55 | --origin_data_dir_dev=$DATA_DIR/dev.query.txt \ 56 | --passage_path=/kun_data/marco \ 57 | --dataset=MS-MARCO \ 58 | --max_seq_length=128 --per_gpu_train_batch_size=8 --gradient_accumulation_steps=1 \ 59 | --learning_rate=5e-6 --output_dir $OUT_DIR \ 60 | --warmup_steps 1000 --logging_steps 200 --save_steps 5000 --max_steps 40000 \ 61 | --log_dir $TB_DIR \ 62 | --number_neg 31 --fp16 63 | 64 | # Evaluation 3h totally 65 | for CKPT_NUM in 25000 30000 35000 40000 66 | do 67 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \ 68 | ./MS/inference_de.py \ 69 | --model_type=bert-base-uncased \ 70 | --eval_model_dir=$OUT_DIR/checkpoint-$CKPT_NUM \ 71 | --output_dir=$OUT_DIR/$CKPT_NUM \ 72 | --train_qa_path=/kun_data/marco/train.query.txt \ 73 | --test_qa_path=/kun_data/marco/dev.query.txt \ 74 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \ 75 | --passage_path=/kun_data/marco \ 76 | --dataset=MS-MARCO \ 77 | --fp16 78 | done 79 | 80 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \ 81 | ./MS/inference_de.py \ 82 | --model_type=bert-base-uncased \ 83 | --eval_model_dir=$OUT_DIR/checkpoint-40000 \ 84 | --output_dir=$OUT_DIR/40000 \ 85 | --train_qa_path=/kun_data/marco/train.query.txt \ 86 | --test_qa_path=/kun_data/marco/dev.query.txt \ 87 | --max_seq_length=128 
--per_gpu_eval_batch_size=1024 \ 88 | --passage_path=/kun_data/marco \ 89 | --dataset=MS-MARCO \ 90 | --fp16 --write_hardneg=True 91 | 92 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \ 93 | ./MS/run_ce_model_ele.py \ 94 | --model_type=google/electra-base-discriminator --max_seq_length=192 \ 95 | --per_gpu_train_batch_size=1 --gradient_accumulation_steps=8 \ 96 | --number_neg=63 --learning_rate=1e-5 \ 97 | --output_dir=$OUT_DIR \ 98 | --origin_data_dir=$OUT_DIR/32000/train_ce_hardneg.tsv \ 99 | --origin_data_dir_dev=$DATA_DIR/dev.query.txt \ 100 | --passage_path=/kun_data/marco \ 101 | --dataset=MS-MARCO \ 102 | --warmup_steps=2000 --logging_steps=500 --save_steps=8000 \ 103 | --max_steps=33000 --log_dir=$TB_DIR 104 | 105 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 MS/co_training_model_ele.py \ 106 | --model_type=/kun_data/Austerlitz/MT5_Gen/ckpt/SIMLM/checkpoint-$epoch \ 107 | --max_seq_length=128 --per_gpu_train_batch_size=16 --gradient_accumulation_steps=4 \ 108 | --number_neg=41 --learning_rate=5e-6 \ 109 | --reranker_model_type=google/electra-base-discriminator \ 110 | --reranker_model_path=output/run_de_ms_MASTER_all_1/checkpoint-ce-24000 \ 111 | --output_dir=$OUT_DIR \ 112 | --log_dir=$TB_DIR \ 113 | --origin_data_dir=output/run_de_ms_MASTER_all_1/32000/train_ce_hardneg.tsv \ 114 | --origin_data_dir_dev=$DATA_DIR/dev.query.txt \ 115 | --passage_path=/kun_data/marco \ 116 | --dataset=MS-MARCO \ 117 | --warmup_steps 5000 --logging_steps 500 --save_steps 5000 --max_steps 25000 \ 118 | --gradient_checkpointing --normal_loss \ 119 | --temperature_normal=1 120 | 121 | for CKPT_NUM in 10000 15000 20000 25000 122 | do 123 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \ 124 | ./MS/inference_de.py \ 125 | --model_type=bert-base-uncased \ 126 | --eval_model_dir=$OUT_DIR/checkpoint-$CKPT_NUM \ 127 | --output_dir=$OUT_DIR/$CKPT_NUM \ 128 | --train_qa_path=/kun_data/marco/train.query.txt \ 129 | --test_qa_path=/kun_data/marco/dev.query.txt \ 130 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \ 131 | --passage_path=/kun_data/marco \ 132 | --dataset=MS-MARCO \ 133 | --fp16 134 | done 135 | 136 | done 137 | -------------------------------------------------------------------------------- /MASTER/finetune/ft_wiki_NQ.sh: -------------------------------------------------------------------------------- 1 | EXP_NAME=run_de_nq_MT5_Gen_hardneg_KD # de means dual encoder. 
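# Pipeline summary (inferred from the commands below): warm-up dual-encoder training on
# biencoder-nq-train.json, initialized from the MT5_Gen/Wiki pretraining checkpoint, then
# evaluation of the saved checkpoints on nq-*.qa.csv, hard-negative mining with
# --write_hardneg=True, and a second dual-encoder run on the mined negatives.
# Note: the later reranker (run_ce_model.py) and co-training (co_training_model.py) stages
# are wrapped between two ''' lines; bash has no Python-style block comments and parses that
# region as one long quoted word, so those commands are effectively disabled in this script.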
2 | DATA_DIR=/kun_data/DR/ARR/data/ 3 | OUT_DIR=output/$EXP_NAME 4 | TB_DIR=tensorboard_log/$EXP_NAME # tensorboard log path 5 | 6 | for epoch in 140000 7 | do 8 | 9 | # Fine-tune with BM25 negatives using CKPT from epoch as initialization 10 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=69539 \ 11 | ./wiki/run_de_model.py \ 12 | --model_type=/kun_data/Austerlitz/MT5_Gen/ckpt/Wiki_1/checkpoint-$epoch \ 13 | --origin_data_dir=$DATA_DIR/biencoder-nq-train.json \ 14 | --origin_data_dir_dev=$DATA_DIR/biencoder-nq-dev.json \ 15 | --max_seq_length=128 --per_gpu_train_batch_size=32 --gradient_accumulation_steps=1 \ 16 | --learning_rate=1e-5 --output_dir $OUT_DIR \ 17 | --warmup_steps 8000 --logging_steps 100 --save_steps 10000 --max_steps 80000 \ 18 | --log_dir $TB_DIR \ 19 | --number_neg 1 --fp16 20 | 21 | # Evaluation 3h totally 22 | for CKPT_NUM in 40000 50000 60000 70000 80000 23 | do 24 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \ 25 | ./wiki/inference_de.py \ 26 | --model_type=bert-base-uncased \ 27 | --eval_model_dir=$OUT_DIR/checkpoint-$CKPT_NUM \ 28 | --output_dir=$OUT_DIR/$CKPT_NUM \ 29 | --test_qa_path=$DATA_DIR/nq-test.qa.csv \ 30 | --train_qa_path=$DATA_DIR/nq-train.qa.csv \ 31 | --dev_qa_path=$DATA_DIR/nq-dev.qa.csv \ 32 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \ 33 | --passage_path=$DATA_DIR/psgs_w100.tsv \ 34 | --fp16 35 | done 36 | 37 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \ 38 | ./wiki/inference_de.py \ 39 | --model_type=bert-base-uncased \ 40 | --eval_model_dir=$OUT_DIR/checkpoint-50000 \ 41 | --output_dir=$OUT_DIR/$CKPT_NUM \ 42 | --test_qa_path=$DATA_DIR/nq-test.qa.csv \ 43 | --train_qa_path=$DATA_DIR/nq-train.qa.csv \ 44 | --dev_qa_path=$DATA_DIR/nq-dev.qa.csv \ 45 | --golden_train_qa_path=$DATA_DIR/biencoder-nq-train.json \ 46 | --golden_dev_qa_path=$DATA_DIR/biencoder-nq-dev.json \ 47 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \ 48 | --passage_path=$DATA_DIR/psgs_w100.tsv \ 49 | --fp16 --write_hardneg=True 50 | 51 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=69539 \ 52 | ./wiki/run_de_model.py \ 53 | --model_type=/kun_data/Austerlitz/MT5_Gen/ckpt/Wiki_1/checkpoint-$epoch \ 54 | --origin_data_dir=$OUT_DIR/80000/train_ce_hardneg.json \ 55 | --origin_data_dir_dev=$OUT_DIR/80000/dev_ce_hardneg.json \ 56 | --max_seq_length=128 --per_gpu_train_batch_size=16 --gradient_accumulation_steps=1 \ 57 | --learning_rate=5e-6 --output_dir $OUT_DIR \ 58 | --warmup_steps 8000 --logging_steps 100 --save_steps 5000 --max_steps 50000 \ 59 | --log_dir $TB_DIR \ 60 | --number_neg 1 --fp16 61 | 62 | # Evaluation 3h totally 63 | for CKPT_NUM in 30000 35000 40000 45000 50000 64 | do 65 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \ 66 | ./wiki/inference_de.py \ 67 | --model_type=bert-base-uncased \ 68 | --eval_model_dir=$OUT_DIR/checkpoint-$CKPT_NUM \ 69 | --output_dir=$OUT_DIR/$CKPT_NUM \ 70 | --test_qa_path=$DATA_DIR/nq-test.qa.csv \ 71 | --train_qa_path=$DATA_DIR/nq-train.qa.csv \ 72 | --dev_qa_path=$DATA_DIR/nq-dev.qa.csv \ 73 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \ 74 | --passage_path=$DATA_DIR/psgs_w100.tsv \ 75 | --fp16 76 | done 77 | ''' 78 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \ 79 | ./wiki/inference_de.py \ 80 | --model_type=bert-base-uncased \ 81 | --eval_model_dir=$OUT_DIR/checkpoint-50000 \ 82 | --output_dir=$OUT_DIR/$CKPT_NUM \ 83 | 
--test_qa_path=$DATA_DIR/nq-test.qa.csv \ 84 | --train_qa_path=$DATA_DIR/nq-train.qa.csv \ 85 | --dev_qa_path=$DATA_DIR/nq-dev.qa.csv \ 86 | --golden_train_qa_path=$DATA_DIR/biencoder-nq-train.json \ 87 | --golden_dev_qa_path=$DATA_DIR/biencoder-nq-dev.json \ 88 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \ 89 | --passage_path=$DATA_DIR/psgs_w100.tsv \ 90 | --fp16 --write_hardneg=True 91 | 92 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \ 93 | ./wiki/run_ce_model.py \ 94 | --model_type=bert-large-uncased --max_seq_length=256 \ 95 | --per_gpu_train_batch_size=1 --gradient_accumulation_steps=8 \ 96 | --number_neg=15 --learning_rate=1e-5 \ 97 | --output_dir=$OUT_DIR \ 98 | --origin_data_dir=$OUT_DIR/$CKPT_NUM/train_ce_hardneg.json \ 99 | --origin_data_dir_dev=$OUT_DIR/$CKPT_NUM/dev_ce_hardneg.json \ 100 | --warmup_steps=1000 --logging_steps=100 --save_steps=1000 \ 101 | --max_steps=5000 --log_dir=$TB_DIR 102 | 103 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 wiki/co_training_model.py \ 104 | --model_type=/kun_data/Austerlitz/MT5_Gen/ckpt/Wiki/checkpoint-$epoch \ 105 | --max_seq_length=128 --per_gpu_train_batch_size=8 --gradient_accumulation_steps=1 \ 106 | --number_neg=15 --learning_rate=1e-5 \ 107 | --reranker_model_type=nghuyong/ernie-2.0-large-en \ 108 | --reranker_model_path=$OUT_DIR/checkpoint-ce-4000 \ 109 | --output_dir=$OUT_DIR \ 110 | --log_dir=$TB_DIR \ 111 | --origin_data_dir=$OUT_DIR/$CKPT_NUM/train_ce_hardneg.json \ 112 | --origin_data_dir_dev=$OUT_DIR/$CKPT_NUM/dev_ce_hardneg.json \ 113 | --warmup_steps=8000 --logging_steps=10 --save_steps=10000 --max_steps=80000 \ 114 | --gradient_checkpointing --normal_loss \ 115 | --temperature_normal=1 116 | 117 | for CKPT_NUM in 40000 50000 60000 70000 80000 118 | do 119 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \ 120 | ./wiki/inference_de.py \ 121 | --model_type=bert-base-uncased \ 122 | --eval_model_dir=$OUT_DIR/checkpoint-$CKPT_NUM \ 123 | --output_dir=$OUT_DIR/$CKPT_NUM \ 124 | --test_qa_path=$DATA_DIR/nq-test.qa.csv \ 125 | --train_qa_path=$DATA_DIR/nq-train.qa.csv \ 126 | --dev_qa_path=$DATA_DIR/nq-dev.qa.csv \ 127 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \ 128 | --passage_path=$DATA_DIR/psgs_w100.tsv \ 129 | --fp16 130 | done 131 | ''' 132 | done 133 | -------------------------------------------------------------------------------- /MASTER/finetune/ft_wiki_TQ.sh: -------------------------------------------------------------------------------- 1 | EXP_NAME=run_de_tq_MT5_Gen_hardneg_KD # de means dual encoder. 
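# This script mirrors ft_wiki_NQ.sh on TriviaQA: warm-up dual-encoder training on
# biencoder-trivia-train.json, evaluation of the saved checkpoints against trivia-*.qa.csv,
# hard-negative mining with --write_hardneg=True, and a second round of dual-encoder training
# on the mined negatives.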
2 | DATA_DIR=/kun_data/DR/ARR/data/ 3 | OUT_DIR=output/$EXP_NAME 4 | TB_DIR=tensorboard_log/$EXP_NAME # tensorboard log path 5 | 6 | for epoch in 80000 7 | do 8 | # Fine-tune with BM25 negatives using CKPT from epoch as initialization 9 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=69539 \ 10 | ./wiki/run_de_model.py \ 11 | --model_type=/kun_data/Austerlitz/MT5_Gen/ckpt/Wiki/checkpoint-$epoch \ 12 | --origin_data_dir=$DATA_DIR/biencoder-trivia-train.json \ 13 | --origin_data_dir_dev=$DATA_DIR/biencoder-trivia-dev.json \ 14 | --max_seq_length=128 --per_gpu_train_batch_size=32 --gradient_accumulation_steps=1 \ 15 | --learning_rate=1e-5 --output_dir $OUT_DIR \ 16 | --warmup_steps 8000 --logging_steps 100 --save_steps 10000 --max_steps 80000 \ 17 | --log_dir $TB_DIR \ 18 | --number_neg 1 --fp16 19 | 20 | # Evaluation 3h totally 21 | for CKPT_NUM in 40000 50000 60000 70000 80000 22 | do 23 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \ 24 | ./wiki/inference_de.py \ 25 | --model_type=bert-base-uncased \ 26 | --eval_model_dir=$OUT_DIR/checkpoint-$CKPT_NUM \ 27 | --output_dir=$OUT_DIR/$CKPT_NUM \ 28 | --test_qa_path=$DATA_DIR/trivia-test.qa.csv \ 29 | --train_qa_path=$DATA_DIR/trivia-train.qa.csv \ 30 | --dev_qa_path=$DATA_DIR/trivia-dev.qa.csv \ 31 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \ 32 | --passage_path=$DATA_DIR/psgs_w100.tsv \ 33 | --fp16 34 | done 35 | 36 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \ 37 | ./wiki/inference_de.py \ 38 | --model_type=bert-base-uncased \ 39 | --eval_model_dir=$OUT_DIR/checkpoint-50000 \ 40 | --output_dir=$OUT_DIR/$CKPT_NUM \ 41 | --test_qa_path=$DATA_DIR/trivia-test.qa.csv \ 42 | --train_qa_path=$DATA_DIR/trivia-train.qa.csv \ 43 | --dev_qa_path=$DATA_DIR/trivia-dev.qa.csv \ 44 | --golden_train_qa_path=$DATA_DIR/biencoder-trivia-train.json \ 45 | --golden_dev_qa_path=$DATA_DIR/biencoder-trivia-dev.json \ 46 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \ 47 | --passage_path=$DATA_DIR/psgs_w100.tsv \ 48 | --fp16 --write_hardneg=True 49 | 50 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=69539 \ 51 | ./wiki/run_de_model.py \ 52 | --model_type=/kun_data/Austerlitz/MT5_Gen/ckpt/Wiki/checkpoint-$epoch \ 53 | --origin_data_dir=$OUT_DIR/$CKPT_NUM/train_ce_hardneg.json \ 54 | --origin_data_dir_dev=$OUT_DIR/$CKPT_NUM/dev_ce_hardneg.json \ 55 | --max_seq_length=128 --per_gpu_train_batch_size=8 --gradient_accumulation_steps=1 \ 56 | --learning_rate=5e-6 --output_dir $OUT_DIR \ 57 | --warmup_steps 8000 --logging_steps 100 --save_steps 10000 --max_steps 80000 \ 58 | --log_dir $TB_DIR \ 59 | --number_neg 1 --fp16 60 | 61 | # Evaluation 3h totally 62 | for CKPT_NUM in 40000 50000 60000 70000 80000 63 | do 64 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \ 65 | ./wiki/inference_de.py \ 66 | --model_type=bert-base-uncased \ 67 | --eval_model_dir=$OUT_DIR/checkpoint-$CKPT_NUM \ 68 | --output_dir=$OUT_DIR/$CKPT_NUM \ 69 | --test_qa_path=$DATA_DIR/trivia-test.qa.csv \ 70 | --train_qa_path=$DATA_DIR/trivia-train.qa.csv \ 71 | --dev_qa_path=$DATA_DIR/trivia-dev.qa.csv \ 72 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \ 73 | --passage_path=$DATA_DIR/psgs_w100.tsv \ 74 | --fp16 75 | done 76 | 77 | done -------------------------------------------------------------------------------- /MASTER/finetune/utils/lamb.py: -------------------------------------------------------------------------------- 1 | 
"""Lamb optimizer.""" 2 | 3 | import collections 4 | import math 5 | 6 | import torch 7 | from tensorboardX import SummaryWriter 8 | from torch.optim import Optimizer 9 | 10 | 11 | def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int): 12 | """Log a histogram of trust ratio scalars in across layers.""" 13 | results = collections.defaultdict(list) 14 | for group in optimizer.param_groups: 15 | for p in group['params']: 16 | state = optimizer.state[p] 17 | for i in ('weight_norm', 'adam_norm', 'trust_ratio'): 18 | if i in state: 19 | results[i].append(state[i]) 20 | 21 | for k, v in results.items(): 22 | event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count) 23 | 24 | class Lamb(Optimizer): 25 | r"""Implements Lamb algorithm. 26 | 27 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 28 | 29 | Arguments: 30 | params (iterable): iterable of parameters to optimize or dicts defining 31 | parameter groups 32 | lr (float, optional): learning rate (default: 1e-3) 33 | betas (Tuple[float, float], optional): coefficients used for computing 34 | running averages of gradient and its square (default: (0.9, 0.999)) 35 | eps (float, optional): term added to the denominator to improve 36 | numerical stability (default: 1e-8) 37 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 38 | adam (bool, optional): always use trust ratio = 1, which turns this into 39 | Adam. Useful for comparison purposes. 40 | 41 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: 42 | https://arxiv.org/abs/1904.00962 43 | """ 44 | 45 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, 46 | weight_decay=0, adam=False): 47 | if not 0.0 <= lr: 48 | raise ValueError("Invalid learning rate: {}".format(lr)) 49 | if not 0.0 <= eps: 50 | raise ValueError("Invalid epsilon value: {}".format(eps)) 51 | if not 0.0 <= betas[0] < 1.0: 52 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 53 | if not 0.0 <= betas[1] < 1.0: 54 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 55 | defaults = dict(lr=lr, betas=betas, eps=eps, 56 | weight_decay=weight_decay) 57 | self.adam = adam 58 | super(Lamb, self).__init__(params, defaults) 59 | 60 | def step(self, closure=None): 61 | """Performs a single optimization step. 62 | 63 | Arguments: 64 | closure (callable, optional): A closure that reevaluates the model 65 | and returns the loss. 
66 | """ 67 | loss = None 68 | if closure is not None: 69 | loss = closure() 70 | 71 | for group in self.param_groups: 72 | for p in group['params']: 73 | if p.grad is None: 74 | continue 75 | grad = p.grad.data 76 | if grad.is_sparse: 77 | raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') 78 | 79 | state = self.state[p] 80 | 81 | # State initialization 82 | if len(state) == 0: 83 | state['step'] = 0 84 | # Exponential moving average of gradient values 85 | state['exp_avg'] = torch.zeros_like(p.data) 86 | # Exponential moving average of squared gradient values 87 | state['exp_avg_sq'] = torch.zeros_like(p.data) 88 | 89 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 90 | beta1, beta2 = group['betas'] 91 | 92 | state['step'] += 1 93 | 94 | # Decay the first and second moment running average coefficient 95 | # m_t 96 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 97 | # v_t 98 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 99 | 100 | # Paper v3 does not use debiasing. 101 | # Apply bias to lr to avoid broadcast. 102 | step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1 103 | 104 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 105 | 106 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) 107 | if group['weight_decay'] != 0: 108 | adam_step.add_(group['weight_decay'], p.data) 109 | 110 | adam_norm = adam_step.pow(2).sum().sqrt() 111 | if weight_norm == 0 or adam_norm == 0: 112 | trust_ratio = 1 113 | else: 114 | trust_ratio = weight_norm / adam_norm 115 | state['weight_norm'] = weight_norm 116 | state['adam_norm'] = adam_norm 117 | state['trust_ratio'] = trust_ratio 118 | if self.adam: 119 | trust_ratio = 1 120 | 121 | p.data.add_(-step_size * trust_ratio, adam_step) 122 | 123 | return loss 124 | -------------------------------------------------------------------------------- /MASTER/pretrain/arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional, Union 3 | import os 4 | from transformers import TrainingArguments 5 | 6 | @dataclass 7 | class DataTrainingArguments: 8 | """ 9 | Arguments pertaining to what data we are going to input our model for training and eval. 
10 | """ 11 | 12 | dataset_name: Optional[str] = field( 13 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 14 | ) 15 | dataset_config_name: Optional[str] = field( 16 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 17 | ) 18 | train_dir: str = field( 19 | default=None, metadata={"help": "Path to train directory"} 20 | ) 21 | train_path: Union[str] = field( 22 | default=None, metadata={"help": "Path to train data"} 23 | ) 24 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) 25 | validation_file: Optional[str] = field( 26 | default=None, 27 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 28 | ) 29 | train_ref_file: Optional[str] = field( 30 | default=None, 31 | metadata={"help": "An optional input train ref data file for whole word masking in Chinese."}, 32 | ) 33 | validation_ref_file: Optional[str] = field( 34 | default=None, 35 | metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."}, 36 | ) 37 | frequency_dict: Union[str] = field( 38 | default=None, metadata={"help": "Path to frequency dict"} 39 | ) 40 | overwrite_cache: bool = field( 41 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 42 | ) 43 | max_seq_length: Optional[int] = field( 44 | default=None, 45 | metadata={ 46 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 47 | "than this will be truncated. Default to the max input length of the model." 48 | }, 49 | ) 50 | min_seq_length: int = field(default=16) 51 | preprocessing_num_workers: Optional[int] = field( 52 | default=None, 53 | metadata={"help": "The number of processes to use for the preprocessing."}, 54 | ) 55 | mlm_probability: float = field( 56 | default=0.3, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"} 57 | ) 58 | decoder_mlm_probability: float = field( 59 | default=0.5, metadata={"help": "Ratio of tokens to mask for decoder masked language modeling loss"} 60 | ) 61 | pad_to_max_length: bool = field( 62 | default=False, 63 | metadata={ 64 | "help": "Whether to pad all samples to `max_seq_length`. " 65 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 66 | }, 67 | ) 68 | 69 | def __post_init__(self): 70 | if self.train_dir is not None: 71 | files = os.listdir(self.train_dir) 72 | self.train_path = [ 73 | os.path.join(self.train_dir, f) 74 | for f in files 75 | if f.endswith('tsv') or f.endswith('json') 76 | ] 77 | 78 | @dataclass 79 | class ModelArguments: 80 | """ 81 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 82 | """ 83 | 84 | model_name_or_path: Optional[str] = field( 85 | default=None, 86 | metadata={ 87 | "help": "The model checkpoint for weights initialization." 88 | "Don't set if you want to train a model from scratch." 
89 | }, 90 | ) 91 | model_type: Optional[str] = field( 92 | default='bert', 93 | ) 94 | config_name: Optional[str] = field( 95 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 96 | ) 97 | tokenizer_name: Optional[str] = field( 98 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 99 | ) 100 | cache_dir: Optional[str] = field( 101 | default=None, 102 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 103 | ) 104 | use_fast_tokenizer: bool = field( 105 | default=True, 106 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 107 | ) 108 | 109 | n_head_layers: int = field(default=2) 110 | skip_from: int = field(default=2) 111 | late_mlm: bool = field(default=False) 112 | temp: float = field(default=0.05) 113 | 114 | 115 | @dataclass 116 | class CondenserPreTrainingArguments(TrainingArguments): 117 | warmup_ratio: float = field(default=0.1) 118 | 119 | 120 | @dataclass 121 | class CoCondenserPreTrainingArguments(CondenserPreTrainingArguments): 122 | cache_chunk_size: int = field(default=-1) 123 | -------------------------------------------------------------------------------- /MASTER/pretrain/run_pre_training.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import logging 3 | import math 4 | import os 5 | import sys 6 | from datasets import load_dataset 7 | import json 8 | from arguments import DataTrainingArguments, ModelArguments, \ 9 | CondenserPreTrainingArguments as TrainingArguments 10 | from data import CondenserCollator 11 | from modeling import CondenserForPretraining, RobertaCondenserForPretraining, ELECTRACondenserForPretraining 12 | from trainer import CondenserPreTrainer as Trainer 13 | import transformers 14 | from transformers import ( 15 | CONFIG_MAPPING, 16 | AutoConfig, 17 | AutoTokenizer, 18 | HfArgumentParser, 19 | set_seed, ) 20 | from transformers.trainer_utils import is_main_process 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | CONDENSER_TYPE_MAP = { 25 | 'bert': CondenserForPretraining, 26 | 'roberta': RobertaCondenserForPretraining, 27 | } 28 | 29 | 30 | def main(): 31 | # See all possible arguments in src/transformers/training_args.py 32 | # or by passing the --help flag to this script. 33 | # We now keep distinct sets of args, for a cleaner separation of concerns. 34 | 35 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 36 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 37 | # If we pass only one argument to the script and it's the path to a json file, 38 | # let's parse it to get our arguments. 39 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 40 | else: 41 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 42 | 43 | if ( 44 | os.path.exists(training_args.output_dir) 45 | and os.listdir(training_args.output_dir) 46 | and training_args.do_train 47 | and not training_args.overwrite_output_dir 48 | ): 49 | raise ValueError( 50 | f"Output directory ({training_args.output_dir}) already exists and is not empty." 51 | "Use --overwrite_output_dir to overcome." 
52 | ) 53 | 54 | model_args: ModelArguments 55 | data_args: DataTrainingArguments 56 | training_args: TrainingArguments 57 | 58 | # Setup logging 59 | logging.basicConfig( 60 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 61 | datefmt="%m/%d/%Y %H:%M:%S", 62 | level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN, 63 | ) 64 | 65 | # Log on each process the small summary: 66 | logger.warning( 67 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 68 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 69 | ) 70 | # Set the verbosity to info of the Transformers logger (on main process only): 71 | if is_main_process(training_args.local_rank): 72 | transformers.utils.logging.set_verbosity_info() 73 | transformers.utils.logging.enable_default_handler() 74 | transformers.utils.logging.enable_explicit_format() 75 | logger.info("Training/evaluation parameters %s", training_args) 76 | 77 | # Set seed before initializing model. 78 | set_seed(training_args.seed) 79 | 80 | train_set = load_dataset( 81 | 'json', 82 | data_files=data_args.train_path, 83 | block_size=2**25, 84 | )['train'] 85 | dev_set = load_dataset( 86 | 'json', 87 | data_files=data_args.validation_file, 88 | block_size=2**25 89 | )['train'] \ 90 | if data_args.validation_file is not None else None 91 | 92 | if model_args.config_name: 93 | config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir) 94 | elif model_args.model_name_or_path: 95 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) 96 | else: 97 | config = CONFIG_MAPPING[model_args.model_type]() 98 | logger.warning("You are instantiating a new config instance from scratch.") 99 | 100 | if model_args.tokenizer_name: 101 | tokenizer = AutoTokenizer.from_pretrained( 102 | model_args.tokenizer_name, 103 | cache_dir=model_args.cache_dir, use_fast=False 104 | ) 105 | elif model_args.model_name_or_path: 106 | tokenizer = AutoTokenizer.from_pretrained( 107 | model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=False 108 | ) 109 | else: 110 | raise ValueError( 111 | "You are instantiating a new tokenizer from scratch. This is not supported by this script." 112 | "You can do it from another script, save it, and load it from here, using --tokenizer_name." 113 | ) 114 | 115 | # initialize the Condenser Pre-training LMX 116 | if model_args.model_type not in CONDENSER_TYPE_MAP: 117 | raise NotImplementedError(f'Condenser for {model_args.model_type} LM is not implemented') 118 | _condenser_cls = CONDENSER_TYPE_MAP[model_args.model_type] 119 | if model_args.model_name_or_path: 120 | model = _condenser_cls.from_pretrained( 121 | model_args, data_args, training_args, 122 | model_args.model_name_or_path, 123 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 124 | config=config, 125 | cache_dir=model_args.cache_dir, 126 | ) 127 | else: 128 | logger.warning('Training from scratch.') 129 | model = _condenser_cls.from_config( 130 | config, model_args, data_args, training_args) 131 | 132 | # Data collator 133 | # This one will take care of randomly masking the tokens. 
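# If --frequency_dict is given, it is read as JSON here and handed to CondenserCollator below
# together with the encoder/decoder masking ratios (mlm_probability / decoder_mlm_probability);
# otherwise the collator is built with frequency_dict=None.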
134 | if data_args.frequency_dict: 135 | logger.info("Load Frequency Dictionary from %s", data_args.frequency_dict) 136 | frequency_dict = json.load(open(data_args.frequency_dict)) 137 | else: 138 | frequency_dict = None 139 | data_collator = CondenserCollator( 140 | tokenizer=tokenizer, 141 | mlm_probability=data_args.mlm_probability, 142 | decoder_mlm_probability=data_args.decoder_mlm_probability, 143 | max_seq_length=data_args.max_seq_length, 144 | frequency_dict=frequency_dict 145 | ) 146 | # Initialize our Trainer 147 | trainer = Trainer( 148 | model=model, 149 | args=training_args, 150 | train_dataset=train_set, 151 | eval_dataset=dev_set, 152 | tokenizer=tokenizer, 153 | data_collator=data_collator, 154 | ) 155 | 156 | # Training 157 | if training_args.do_train: 158 | model_path = ( 159 | model_args.model_name_or_path 160 | if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) 161 | else None 162 | ) 163 | trainer.train(model_path=model_path) 164 | trainer.save_model() # Saves the tokenizer too for easy upload 165 | 166 | # Evaluation 167 | results = {} 168 | if training_args.do_eval: 169 | logger.info("*** Evaluate ***") 170 | 171 | eval_output = trainer.evaluate() 172 | 173 | perplexity = math.exp(eval_output["eval_loss"]) 174 | results["perplexity"] = perplexity 175 | 176 | output_eval_file = os.path.join(training_args.output_dir, "eval_results_mlm_wwm.txt") 177 | if trainer.is_world_process_zero(): 178 | with open(output_eval_file, "w") as writer: 179 | logger.info("***** Eval results *****") 180 | for key, value in results.items(): 181 | logger.info(f" {key} = {value}") 182 | writer.write(f"{key} = {value}\n") 183 | 184 | return results 185 | 186 | 187 | def _mp_fn(index): 188 | # For xla_spawn (TPUs) 189 | main() 190 | 191 | 192 | if __name__ == "__main__": 193 | main() 194 | -------------------------------------------------------------------------------- /MASTER/pretrain/run_pretrain.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node 8 --master_port 50000 run_pre_training.py \ 2 | --output_dir ckpt/MASTER \ 3 | --model_name_or_path bert-base-uncased \ 4 | --do_train \ 5 | --save_steps 40000 \ 6 | --per_device_train_batch_size 128 \ 7 | --gradient_accumulation_steps 2 \ 8 | --warmup_ratio 0.1 \ 9 | --learning_rate 3e-4 \ 10 | --num_train_epochs 40 \ 11 | --overwrite_output_dir \ 12 | --dataloader_num_workers 32 \ 13 | --n_head_layers 2 \ 14 | --skip_from 6 \ 15 | --max_seq_length 128 \ 16 | --train_dir process_data \ 17 | --frequency_dict frequency_dict_MS_doc.json \ 18 | --weight_decay 0.01 \ 19 | --late_mlm \ 20 | --fp16 21 | -------------------------------------------------------------------------------- /PROD/MarcoDoc_Data.sh: -------------------------------------------------------------------------------- 1 | # download MSMARCO doc data 2 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docs.tsv.gz 3 | gunzip msmarco-docs.tsv.gz 4 | 5 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-doctrain-queries.tsv.gz 6 | gunzip msmarco-doctrain-queries.tsv.gz 7 | 8 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-doctrain-qrels.tsv.gz 9 | gunzip msmarco-doctrain-qrels.tsv.gz 10 | 11 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz 12 | gunzip msmarco-test2019-queries.tsv.gz 13 | 14 | wget 
https://trec.nist.gov/data/deep/2019qrels-docs.txt 15 | 16 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docdev-queries.tsv.gz 17 | gunzip msmarco-docdev-queries.tsv.gz 18 | 19 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docdev-qrels.tsv.gz 20 | gunzip msmarco-docdev-qrels.tsv.gz 21 | 22 | wget https://msmarco.blob.core.windows.net/msmarcoranking/docleaderboard-queries.tsv.gz -------------------------------------------------------------------------------- /PROD/MarcoPas_Data.sh: -------------------------------------------------------------------------------- 1 | SCRIPT_DIR=$PWD 2 | 3 | wget --no-check-certificate https://rocketqa.bj.bcebos.com/corpus/marco.tar.gz 4 | tar -zxf marco.tar.gz 5 | rm -rf marco.tar.gz 6 | cd marco 7 | 8 | wget https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz 9 | wget https://msmarco.blob.core.windows.net/msmarcoranking/qrels.train.tsv -O qrels.train.tsv 10 | gunzip qidpidtriples.train.full.2.tsv.gz 11 | join -t "$(echo -en '\t')" -e '' -a 1 -o 1.1 2.2 1.2 <(sort -k1,1 para.txt) <(sort -k1,1 para.title.txt) | sort -k1,1 -n > corpus.tsv 12 | awk -v RS='\r\n' '$1==last {printf ",%s",$3; next} NR>1 {print "";} {last=$1; printf "%s\t%s",$1,$3;} END{print "";}' qidpidtriples.train.full.2.tsv > train.negatives.tsv -------------------------------------------------------------------------------- /PROD/ProD_KD/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/PROD/ProD_KD/model/__init__.py -------------------------------------------------------------------------------- /PROD/ProD_KD/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/PROD/ProD_KD/utils/__init__.py -------------------------------------------------------------------------------- /PROD/ProD_KD/utils/build_marco_train.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | import random 4 | 5 | relevance_file = "/colab_space/fanshuai/KDmarco/coCondenser-marco/marco/qrels.train.tsv" 6 | query_file = "/colab_space/fanshuai/KDmarco/coCondenser-marco/marco/train.query.txt" 7 | negative_file = "/colab_space/fanshuai/KDmarco/coCondenser-marco/marco/train.negatives.tsv" 8 | outfile = "/colab_space/fanshuai/KDmarco/coCondenser-marco/marco/marco_train.json" 9 | n_sample = 30 10 | 11 | def csv_reader(fd, delimiter='\t', trainer_id=0, trainer_num=1): 12 | def gen(): 13 | for i, line in tqdm(enumerate(fd)): 14 | if i % trainer_num == trainer_id: 15 | slots = line.rstrip('\n').split(delimiter) 16 | if len(slots) == 1: 17 | yield slots, 18 | else: 19 | yield slots 20 | return gen() 21 | 22 | def read_qrel(relevance_file): 23 | qrel = {} 24 | with open(relevance_file, encoding='utf8') as f: 25 | tsvreader = csv_reader(f, delimiter="\t") 26 | for [topicid, _, docid, rel] in tsvreader: 27 | assert rel == "1" 28 | if topicid in qrel: 29 | qrel[topicid].append(docid) 30 | else: 31 | qrel[topicid] = [docid] 32 | return qrel 33 | 34 | def read_qstring(query_file): 35 | q_string = {} 36 | with open(query_file, 'r', encoding='utf-8') as file: 37 | for num, line in enumerate(file): 38 | line = line.strip('\n') # strip the newline character 39 | line = line.split('\t') 40 | q_string[line[0]] = line[1] 41 | return q_string
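# Input formats implied by the readers above and the main block below:
#   qrels.train.tsv       <topicid>\t<unused>\t<docid>\t1   (all positive docids are kept per query)
#   train.query.txt       <qid>\t<query text>
#   train.negatives.tsv   <qid>\t<comma-separated negative docids> (shuffled, truncated to n_sample)
# The script writes one JSON record per query with query_id, query_string, pos_id and neg_id.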
42 | 43 | datalist = [] 44 | qrel = read_qrel(relevance_file) 45 | q_string = read_qstring(query_file) 46 | with open(negative_file, 'r', encoding='utf8') as nf: 47 | reader = csv_reader(nf) 48 | for cnt, line in enumerate(reader): 49 | examples = {} 50 | q = line[0] 51 | nn = line[1] 52 | nn = nn.split(',') 53 | random.shuffle(nn) 54 | nn = nn[:n_sample] 55 | examples['query_id'] = q 56 | examples['query_string'] = q_string[q] 57 | examples['pos_id'] = qrel[q] 58 | examples['neg_id'] = nn 59 | if len(examples['query_id'])!=0 and len(examples['pos_id'])!=0 and len(examples['neg_id'])!=0: 60 | datalist.append(examples) 61 | 62 | print("data len:", len(datalist)) 63 | print("data keys:", datalist[0].keys()) 64 | print("data info:", datalist[0]) 65 | 66 | with open(outfile, 'w',encoding='utf-8') as f: 67 | json.dump(datalist, f,indent=2) -------------------------------------------------------------------------------- /PROD/ProD_KD/utils/dataset_division_nq.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import json 3 | from tqdm import tqdm 4 | import os 5 | import sys 6 | 7 | 8 | # data file 9 | # qid \t qstring \t pos_id \t neg_id 10 | # type : str 11 | # result_file_path1="/colab_space/fanshuai/KDnq/result/24CEt6DE_LwF_5e-5/20000/train_result_dict_list.json" 12 | # result_file_path2="/colab_space/fanshuai/KDnq/result/Ranker_24layer/10000trainrank/train_rerank_dict_list.json" 13 | # data_file_path="/colab_space/fanshuai/KDnq/result/24CEt6DE_LwF_5e-5/20000/train_nq_flash4.json" 14 | # output_dir = "/colab_space/fanshuai/KDnq/result/Ranker_24layer/10000div/" 15 | 16 | def read_result(result_file_path): 17 | result_data = [] 18 | with open(result_file_path, 'r', encoding="utf-8") as f: 19 | data = json.load(f) 20 | for i, example in tqdm(enumerate(data)): 21 | result_data.append(example) 22 | 23 | return result_data 24 | 25 | def load_data(data_file_path): 26 | with open(data_file_path, 'r', encoding="utf-8") as f: 27 | data = json.load(f) 28 | print('Aggregated data size: {}'.format(len(data))) 29 | # filter those without positive ctx 30 | return data 31 | 32 | # def divide_data(result_data_student): 33 | # ranking = [] 34 | # recall_q_top1 = set() 35 | # recall_q_2ti = set() 36 | # recall_q_2t5 = set() 37 | # recall_q_2t10 = set() 38 | # recall_q_2t15 = set() 39 | # recall_q_6t20 = set() 40 | # recall_q_21t50 = set() 41 | # recall_q_51t100 = set() 42 | # recall_q_101tall = set() 43 | # 44 | # for data in result_data_student: 45 | # qid = data['id'] 46 | # for i, ctx in enumerate(data['ctxs']): 47 | # if ctx['hit'] == 'True': 48 | # if i < 100 and i >= 50: 49 | # recall_q_51t100.add(qid) 50 | # elif i < 50 and i >= 20: 51 | # recall_q_21t50.add(qid) 52 | # elif i < 20 and i >= 5: 53 | # recall_q_6t20.add(qid) 54 | # elif i < 5 and i >= 1: 55 | # recall_q_2t5.add(qid) 56 | # elif i == 0: 57 | # recall_q_top1.add(qid) 58 | # break 59 | # 60 | # print("top1 data num :", len(recall_q_top1)) 61 | # print("top2 to topi data num :", len(recall_q_2ti)) 62 | # print("top2 to top5 data num :", len(recall_q_2t5)) 63 | # print("top2 to top10 data num :", len(recall_q_2t10)) 64 | # print("top2 to top15 data num :", len(recall_q_2t15)) 65 | # print("top6 to top20 data num :", len(recall_q_6t20)) 66 | # print("top21 to top50 data num :", len(recall_q_21t50)) 67 | # print("top51 to top100 data num :", len(recall_q_51t100)) 68 | # print("top101 to top1000 data num :", len(recall_q_101tall)) 69 | # print() 70 | # 71 | # 72 | # qid_divide_dic = {} 73 | 
# qid_divide_dic['top1'] = recall_q_top1 74 | # qid_divide_dic['2ti'] = recall_q_2ti 75 | # qid_divide_dic['2t5'] = recall_q_2t5 76 | # qid_divide_dic['2t10'] = recall_q_2t10 77 | # qid_divide_dic['2t15'] = recall_q_2t15 78 | # qid_divide_dic['6t20'] = recall_q_6t20 79 | # qid_divide_dic['21t50'] = recall_q_21t50 80 | # qid_divide_dic['51t100'] = recall_q_51t100 81 | # qid_divide_dic['101tall'] = recall_q_101tall 82 | # 83 | # return qid_divide_dic 84 | 85 | def divide_data(result_data_student, result_data_teacher): 86 | ranking = [] 87 | t2_all_better = set() 88 | t2_15_better = set() 89 | 90 | qid_data_dic={} 91 | for data in result_data_teacher: 92 | qid_data_dic[data['id']] = data 93 | 94 | for data in result_data_student: 95 | qid = data['id'] 96 | 97 | s_rank = 99999 98 | t_rank = 99999 99 | for i,ctx in enumerate(data['ctxs']): 100 | if ctx['hit'] == 'True': 101 | s_rank = i 102 | break 103 | 104 | t_data = qid_data_dic[qid] 105 | for i,ctx in enumerate(t_data['ctxs']): 106 | if ctx['hit'] == 'True': 107 | t_rank = i 108 | break 109 | 110 | if t_rank < s_rank and t_rank < 15: 111 | t2_15_better.add(qid) 112 | 113 | if t_rank < s_rank and t_rank < 100: 114 | t2_all_better.add(qid) 115 | 116 | print("t2_15_better len ", len(t2_15_better)) 117 | print("t2_31_better len ", len(t2_all_better)) 118 | print() 119 | 120 | qid_divide_dic = {} 121 | qid_divide_dic['t2_15_better'] = t2_15_better 122 | qid_divide_dic['t2_all_better'] = t2_all_better 123 | 124 | return qid_divide_dic 125 | 126 | def main(result_file_path1, result_file_path2, data_file_path, output_dir): 127 | t1_qidset = set() 128 | t2_qidset = set() 129 | 130 | t2_better_set = set() 131 | 132 | result_data_student = read_result(result_file_path1) 133 | result_data_teacher = read_result(result_file_path2) 134 | 135 | # train_data = load_data(data_file_path) 136 | 137 | # top1 / top2-top5 / top6-top20 / top30-top50 / top50-top100 / top100-topall 138 | #qid_divide_dic1 = divide_data(result_data_student) 139 | #qid_divide_dic2 = divide_data(result_data_teacher) 140 | qid_divide_dic = divide_data(result_data_student, result_data_teacher) 141 | 142 | t2_better_set = qid_divide_dic['t2_15_better'] 143 | 144 | print("t2_better_set len:", len(t2_better_set)) 145 | 146 | with open(data_file_path, 'r', encoding="utf-8") as f: 147 | dataset = json.load(f) 148 | print('Aggregated data size: {}'.format(len(dataset))) 149 | 150 | t2_better_data = [] 151 | 152 | for data in dataset: 153 | if data['q_id'] in t2_better_set: 154 | t2_better_data.append(data) 155 | 156 | print(len(t2_better_data)) 157 | 158 | print("check if data right...") 159 | if len(t2_better_data) != len(t2_better_set): 160 | print("data error") 161 | exit(0) 162 | print("data set success!") 163 | 164 | t2_better_output_dir = os.path.join(output_dir, "flash4_top15_better.json") 165 | with open(t2_better_output_dir, 'w', encoding='utf-8') as f: 166 | json.dump(t2_better_data, f, indent=2) 167 | 168 | 169 | 170 | 171 | 172 | if __name__ == "__main__": 173 | result_file_path_s = sys.argv[1] 174 | result_file_path_t = sys.argv[2] 175 | data_file_path = sys.argv[3] 176 | output_dir = sys.argv[4] 177 | main(result_file_path_s, result_file_path_t, data_file_path, output_dir) 178 | print("data division done") -------------------------------------------------------------------------------- /PROD/ProD_KD/utils/lamb.py: -------------------------------------------------------------------------------- 1 | """Lamb optimizer.""" 2 | 3 | import collections 4 | import math 5 | 6 | import 
torch 7 | from tensorboardX import SummaryWriter 8 | from torch.optim import Optimizer 9 | 10 | 11 | def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int): 12 | """Log a histogram of trust ratio scalars in across layers.""" 13 | results = collections.defaultdict(list) 14 | for group in optimizer.param_groups: 15 | for p in group['params']: 16 | state = optimizer.state[p] 17 | for i in ('weight_norm', 'adam_norm', 'trust_ratio'): 18 | if i in state: 19 | results[i].append(state[i]) 20 | 21 | for k, v in results.items(): 22 | event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count) 23 | 24 | class Lamb(Optimizer): 25 | r"""Implements Lamb algorithm. 26 | 27 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 28 | 29 | Arguments: 30 | params (iterable): iterable of parameters to optimize or dicts defining 31 | parameter groups 32 | lr (float, optional): learning rate (default: 1e-3) 33 | betas (Tuple[float, float], optional): coefficients used for computing 34 | running averages of gradient and its square (default: (0.9, 0.999)) 35 | eps (float, optional): term added to the denominator to improve 36 | numerical stability (default: 1e-8) 37 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 38 | adam (bool, optional): always use trust ratio = 1, which turns this into 39 | Adam. Useful for comparison purposes. 40 | 41 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: 42 | https://arxiv.org/abs/1904.00962 43 | """ 44 | 45 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, 46 | weight_decay=0, adam=False): 47 | if not 0.0 <= lr: 48 | raise ValueError("Invalid learning rate: {}".format(lr)) 49 | if not 0.0 <= eps: 50 | raise ValueError("Invalid epsilon value: {}".format(eps)) 51 | if not 0.0 <= betas[0] < 1.0: 52 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 53 | if not 0.0 <= betas[1] < 1.0: 54 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 55 | defaults = dict(lr=lr, betas=betas, eps=eps, 56 | weight_decay=weight_decay) 57 | self.adam = adam 58 | super(Lamb, self).__init__(params, defaults) 59 | 60 | def step(self, closure=None): 61 | """Performs a single optimization step. 62 | 63 | Arguments: 64 | closure (callable, optional): A closure that reevaluates the model 65 | and returns the loss. 66 | """ 67 | loss = None 68 | if closure is not None: 69 | loss = closure() 70 | 71 | for group in self.param_groups: 72 | for p in group['params']: 73 | if p.grad is None: 74 | continue 75 | grad = p.grad.data 76 | if grad.is_sparse: 77 | raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') 78 | 79 | state = self.state[p] 80 | 81 | # State initialization 82 | if len(state) == 0: 83 | state['step'] = 0 84 | # Exponential moving average of gradient values 85 | state['exp_avg'] = torch.zeros_like(p.data) 86 | # Exponential moving average of squared gradient values 87 | state['exp_avg_sq'] = torch.zeros_like(p.data) 88 | 89 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 90 | beta1, beta2 = group['betas'] 91 | 92 | state['step'] += 1 93 | 94 | # Decay the first and second moment running average coefficient 95 | # m_t 96 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 97 | # v_t 98 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 99 | 100 | # Paper v3 does not use debiasing. 101 | # Apply bias to lr to avoid broadcast. 
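# The block below is the layer-wise LAMB update: it forms the Adam-style direction
# exp_avg / (sqrt(exp_avg_sq) + eps), adds weight_decay * p to that direction when
# weight decay is configured, and then rescales the step by the trust ratio
# ||p|| / ||adam_step|| (with ||p|| clamped to [0, 10]). The ratio falls back to 1
# if either norm is zero, or unconditionally when adam=True (plain Adam, for comparison).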
102 | step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1 103 | 104 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 105 | 106 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) 107 | if group['weight_decay'] != 0: 108 | adam_step.add_(group['weight_decay'], p.data) 109 | 110 | adam_norm = adam_step.pow(2).sum().sqrt() 111 | if weight_norm == 0 or adam_norm == 0: 112 | trust_ratio = 1 113 | else: 114 | trust_ratio = weight_norm / adam_norm 115 | state['weight_norm'] = weight_norm 116 | state['adam_norm'] = adam_norm 117 | state['trust_ratio'] = trust_ratio 118 | if self.adam: 119 | trust_ratio = 1 120 | 121 | p.data.add_(-step_size * trust_ratio, adam_step) 122 | 123 | return loss 124 | -------------------------------------------------------------------------------- /PROD/ProD_KD/utils/prepare_ce_data_nq.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from tqdm import tqdm 4 | def load_data(data_path=None): 5 | assert data_path 6 | with open(data_path, 'r',encoding='utf-8') as fin: 7 | data = json.load(fin) 8 | examples = [] 9 | for k, example in enumerate(data): 10 | if not 'id' in example: 11 | example['id'] = k 12 | for c in example['ctxs']: 13 | if not 'score' in c: 14 | c['score'] = 1.0 / (k + 1) 15 | examples.append(example) 16 | return examples 17 | 18 | def read_train_pos(ground_truth_path): 19 | origin_train_path = ground_truth_path 20 | with open(origin_train_path, "r", encoding="utf-8") as ifile: 21 | # file format: question, answers 22 | train_list = json.load(ifile) 23 | train_q_pos_dict = {} 24 | for example in train_list: 25 | if len(example['positive_ctxs'])==0 or "positive_ctxs" not in example.keys(): 26 | continue 27 | train_q_pos_dict[example['question']]=example['positive_ctxs'][0] 28 | return train_q_pos_dict 29 | 30 | def reform_out(examples,outfile, ground_truth_path): 31 | train_q_pos_dict = read_train_pos(ground_truth_path) 32 | transfer_list = [] 33 | easy_ctxs = [] 34 | for infer_result in tqdm(examples): 35 | if 'passage_id' not in infer_result.keys(): 36 | q_id = infer_result["id"] 37 | else: 38 | q_id = infer_result["passage_id"] 39 | q_str = infer_result["question"] 40 | q_answer = infer_result["answers"] 41 | positive_ctxs = [] 42 | negative_ctxs = [] 43 | if q_str in train_q_pos_dict.keys(): 44 | # ground true 45 | real_true_dic = train_q_pos_dict[q_str] 46 | # real_true_doc_id = real_true_dic['passage_id'] if 'passage_id' in real_true_dic.keys() else real_true_dic['id'] 47 | if 'passage_id' not in real_true_dic.keys() and 'id' in real_true_dic.keys(): 48 | real_true_dic['passage_id'] = real_true_dic['id'] 49 | elif 'psg_id' in real_true_dic.keys(): 50 | real_true_dic['passage_id'] = real_true_dic['psg_id'] 51 | positive_ctxs.append(real_true_dic) 52 | 53 | for doc in infer_result['ctxs']: 54 | doc_text = doc['text'] 55 | doc_title = doc['title'] 56 | if doc['hit']=="True": 57 | positive_ctxs.append({'title':doc_title,'text':doc_text,'passage_id':doc['d_id'],'score':str(doc['score'])}) 58 | else: 59 | negative_ctxs.append({'title':doc_title,'text':doc_text,'passage_id':doc['d_id'],'score':str(doc['score'])}) 60 | # easy_ctxs.append({'title':doc_title,'text':doc_text,'passage_id':doc['d_id'],'score':str(doc['score'])}) 61 | 62 | transfer_list.append( 63 | { 64 | "q_id":str(q_id), "question":q_str, "answers" :q_answer ,"positive_ctxs":positive_ctxs,"hard_negative_ctxs":negative_ctxs,"negative_ctxs":[] 65 | } 66 | ) 67 | 68 | print("total data to ce 
train: ", len(transfer_list)) 69 | # print("easy data to ce train: ", len(easy_list)) 70 | print("hardneg num:", len(transfer_list[0]["hard_negative_ctxs"])) 71 | # print("easyneg num:", len(transfer_list[0]["easy_negative_ctxs"])) 72 | 73 | with open(outfile, 'w',encoding='utf-8') as f: 74 | json.dump(transfer_list, f,indent=2) 75 | return 76 | if __name__ == "__main__": 77 | inference_results = sys.argv[1] 78 | outfile = sys.argv[2] 79 | ground_truth_path = sys.argv[3] 80 | examples = load_data(inference_results) 81 | reform_out(examples,outfile,ground_truth_path) 82 | 83 | 84 | -------------------------------------------------------------------------------- /PROD/ProD_KD/utils/preprae_ce_marco_train.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | import random 4 | import pickle 5 | import sys 6 | 7 | def csv_reader(fd, delimiter='\t', trainer_id=0, trainer_num=1): 8 | def gen(): 9 | for i, line in tqdm(enumerate(fd)): 10 | if i % trainer_num == trainer_id: 11 | slots = line.rstrip('\n').split(delimiter) 12 | if len(slots) == 1: 13 | yield slots, 14 | else: 15 | yield slots 16 | return gen() 17 | 18 | def read_qrel(relevance_file): 19 | qrel = {} 20 | with open(relevance_file, encoding='utf8') as f: 21 | tsvreader = csv_reader(f, delimiter="\t") 22 | for [topicid, _, docid, rel] in tsvreader: 23 | assert rel == "1" 24 | if topicid in qrel: 25 | qrel[topicid].append(docid) 26 | else: 27 | qrel[topicid] = [docid] 28 | return qrel 29 | 30 | def read_qstring(query_file): 31 | q_string = {} 32 | with open(query_file, 'r', encoding='utf-8') as file: 33 | for num, line in enumerate(file): 34 | line = line.strip('\n') 35 | line = line.split('\t') 36 | q_string[line[0]] = line[1] 37 | return q_string 38 | 39 | def load_train_reference_from_stream(input_file, trainer_id=0, trainer_num=1): 40 | """Reads a tab separated value file.""" 41 | with open(input_file, 'r', encoding='utf8') as f: 42 | reader = csv_reader(f, trainer_id=trainer_id, trainer_num=trainer_num) 43 | #headers = 'query_id\tpos_id\tneg_id'.split('\t') 44 | 45 | #Example = namedtuple('Example', headers) 46 | qrel = {} 47 | for [topicid, _, docid, rel] in reader: 48 | topicid = int(topicid) 49 | assert rel == "1" 50 | if topicid in qrel: 51 | qrel[topicid].append(int(docid)) 52 | else: 53 | qrel[topicid] = [int(docid)] 54 | return qrel 55 | 56 | def main(relevance_file, query_file, result_file, outfile, neg_num): 57 | with open(result_file, 'rb') as f: 58 | qids_to_ranked_candidate_passages, qids_to_ranked_candidate_scores = pickle.load(f) 59 | qids_to_relevant_passageids = load_train_reference_from_stream(relevance_file) 60 | 61 | datalist = [] 62 | q_string = read_qstring(query_file) 63 | for qid in qids_to_ranked_candidate_passages: 64 | examples = {} 65 | if qid in qids_to_relevant_passageids: 66 | target_pid = qids_to_relevant_passageids[qid] 67 | examples['query_id'] = str(qid) 68 | examples['pos_id'] = [str(id) for id in target_pid] 69 | examples['query_string'] = q_string[examples['query_id']] 70 | candidate_pid = qids_to_ranked_candidate_passages[qid] 71 | examples['neg_id'] = [] 72 | 73 | for i, pid in enumerate(candidate_pid): 74 | if pid in target_pid: 75 | # print(target_pid, pid) 76 | continue 77 | else: 78 | if len(examples['neg_id']) < neg_num: 79 | examples['neg_id'].append(str(pid)) 80 | else: 81 | break 82 | if len(examples['query_id'])!=0 and len(examples['pos_id'])!=0 and len(examples['neg_id'])!=0: 83 | datalist.append(examples) 84 
| 85 | 86 | # print(num,len(qids_to_ranked_candidate_passages),num/len(qids_to_ranked_candidate_passages)) 87 | print("data len:", len(datalist)) 88 | print("data keys:", datalist[0].keys()) 89 | print("data info:", datalist[0]) 90 | print("data info:", datalist[100]) 91 | print("data info:", datalist[1000]) 92 | 93 | for data in datalist: 94 | for id in data['neg_id']: 95 | if int(id) in data['pos_id']: 96 | print("data error") 97 | print(data) 98 | exit(0) 99 | if str(id) in data['pos_id']: 100 | print("data error") 101 | print(data) 102 | exit(0) 103 | print("Correctness check completed") 104 | 105 | 106 | with open(outfile, 'w',encoding='utf-8') as f: 107 | json.dump(datalist, f,indent=2) 108 | 109 | if __name__ == "__main__": 110 | relevance_file = sys.argv[1] 111 | query_file = sys.argv[2] 112 | result_file = sys.argv[3] 113 | outfile = sys.argv[4] 114 | neg_num = int(sys.argv[5]) 115 | main(relevance_file, query_file, result_file, outfile, neg_num) -------------------------------------------------------------------------------- /PROD/ProD_KD/utils/preprae_ce_marcodoc_train.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | import random 4 | import pickle 5 | import sys 6 | def csv_reader(fd, delimiter='\t', trainer_id=0, trainer_num=1): 7 | def gen(): 8 | for i, line in tqdm(enumerate(fd)): 9 | if i % trainer_num == trainer_id: 10 | slots = line.rstrip('\n').split(delimiter) 11 | if len(slots) == 1: 12 | yield slots, 13 | else: 14 | yield slots 15 | return gen() 16 | 17 | def load_marcodoc_reference_from_stream(path_to_reference): 18 | """Load Reference reference relevant passages 19 | Args:f (stream): stream to load. 20 | Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints). 
21 | """ 22 | qids_to_relevant_passageids = {} 23 | with open(path_to_reference, 'r') as f: 24 | for l in f: 25 | l = l.strip().split(' ') 26 | qid = int(l[0]) 27 | if qid in qids_to_relevant_passageids: 28 | pass 29 | else: 30 | qids_to_relevant_passageids[qid] = [] 31 | qids_to_relevant_passageids[qid].append(int(l[2][1:])) 32 | return qids_to_relevant_passageids 33 | 34 | def read_qstring(query_file): 35 | q_string = {} 36 | with open(query_file, 'r', encoding='utf-8') as file: 37 | for num, line in enumerate(file): 38 | line = line.strip('\n') 39 | line = line.split('\t') 40 | q_string[int(line[0])] = line[1] 41 | return q_string 42 | 43 | def main(relevance_file, query_file, result_file, outfile, neg_num): 44 | with open(result_file, 'rb') as f: 45 | qids_to_ranked_candidate_passages, qids_to_ranked_candidate_scores = pickle.load(f) 46 | qids_to_relevant_passageids = load_marcodoc_reference_from_stream(relevance_file) 47 | 48 | datalist = [] 49 | q_string = read_qstring(query_file) 50 | for qid in qids_to_ranked_candidate_passages: 51 | examples = {} 52 | if qid in qids_to_relevant_passageids: 53 | target_pid = qids_to_relevant_passageids[qid] 54 | examples['query_id'] = int(qid) 55 | examples['pos_id'] = [int(id) for id in target_pid] 56 | examples['query_string'] = q_string[examples['query_id']] 57 | # neg id,pos id 58 | candidate_pid = qids_to_ranked_candidate_passages[qid] 59 | examples['neg_id'] = [] 60 | 61 | for i, pid in enumerate(candidate_pid): 62 | if pid in examples['pos_id']: 63 | # print(target_pid, pid) 64 | continue 65 | else: 66 | if len(examples['neg_id']) < neg_num: 67 | examples['neg_id'].append(int(pid)) 68 | else: 69 | break 70 | if len(examples['pos_id'])!=0 and len(examples['neg_id'])!=0: 71 | datalist.append(examples) 72 | 73 | 74 | # print(num,len(qids_to_ranked_candidate_passages),num/len(qids_to_ranked_candidate_passages)) 75 | print("data len:", len(datalist)) 76 | print("data keys:", datalist[0].keys()) 77 | print("data info:", datalist[0]) 78 | print("data info:", datalist[100]) 79 | print("data info:", datalist[1000]) 80 | 81 | for data in datalist: 82 | for id in data['neg_id']: 83 | if int(id) in data['pos_id']: 84 | print("data error") 85 | print(data) 86 | exit(0) 87 | if str(id) in data['pos_id']: 88 | print("data error") 89 | print(data) 90 | exit(0) 91 | print("Correctness check completed") 92 | 93 | 94 | with open(outfile, 'w',encoding='utf-8') as f: 95 | json.dump(datalist, f,indent=2) 96 | 97 | if __name__ == "__main__": 98 | relevance_file = sys.argv[1] 99 | query_file = sys.argv[2] 100 | result_file = sys.argv[3] 101 | outfile = sys.argv[4] 102 | neg_num = int(sys.argv[5]) 103 | main(relevance_file, query_file, result_file, outfile, neg_num) -------------------------------------------------------------------------------- /PROD/ProD_base/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/PROD/ProD_base/model/__init__.py -------------------------------------------------------------------------------- /PROD/ProD_base/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/PROD/ProD_base/utils/__init__.py -------------------------------------------------------------------------------- /PROD/ProD_base/utils/lamb.py: 
-------------------------------------------------------------------------------- 1 | """Lamb optimizer.""" 2 | 3 | import collections 4 | import math 5 | 6 | import torch 7 | from tensorboardX import SummaryWriter 8 | from torch.optim import Optimizer 9 | 10 | 11 | def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int): 12 | """Log a histogram of trust ratio scalars in across layers.""" 13 | results = collections.defaultdict(list) 14 | for group in optimizer.param_groups: 15 | for p in group['params']: 16 | state = optimizer.state[p] 17 | for i in ('weight_norm', 'adam_norm', 'trust_ratio'): 18 | if i in state: 19 | results[i].append(state[i]) 20 | 21 | for k, v in results.items(): 22 | event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count) 23 | 24 | class Lamb(Optimizer): 25 | r"""Implements Lamb algorithm. 26 | 27 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 28 | 29 | Arguments: 30 | params (iterable): iterable of parameters to optimize or dicts defining 31 | parameter groups 32 | lr (float, optional): learning rate (default: 1e-3) 33 | betas (Tuple[float, float], optional): coefficients used for computing 34 | running averages of gradient and its square (default: (0.9, 0.999)) 35 | eps (float, optional): term added to the denominator to improve 36 | numerical stability (default: 1e-8) 37 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 38 | adam (bool, optional): always use trust ratio = 1, which turns this into 39 | Adam. Useful for comparison purposes. 40 | 41 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: 42 | https://arxiv.org/abs/1904.00962 43 | """ 44 | 45 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, 46 | weight_decay=0, adam=False): 47 | if not 0.0 <= lr: 48 | raise ValueError("Invalid learning rate: {}".format(lr)) 49 | if not 0.0 <= eps: 50 | raise ValueError("Invalid epsilon value: {}".format(eps)) 51 | if not 0.0 <= betas[0] < 1.0: 52 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 53 | if not 0.0 <= betas[1] < 1.0: 54 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 55 | defaults = dict(lr=lr, betas=betas, eps=eps, 56 | weight_decay=weight_decay) 57 | self.adam = adam 58 | super(Lamb, self).__init__(params, defaults) 59 | 60 | def step(self, closure=None): 61 | """Performs a single optimization step. 62 | 63 | Arguments: 64 | closure (callable, optional): A closure that reevaluates the model 65 | and returns the loss. 
66 | """ 67 | loss = None 68 | if closure is not None: 69 | loss = closure() 70 | 71 | for group in self.param_groups: 72 | for p in group['params']: 73 | if p.grad is None: 74 | continue 75 | grad = p.grad.data 76 | if grad.is_sparse: 77 | raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') 78 | 79 | state = self.state[p] 80 | 81 | # State initialization 82 | if len(state) == 0: 83 | state['step'] = 0 84 | # Exponential moving average of gradient values 85 | state['exp_avg'] = torch.zeros_like(p.data) 86 | # Exponential moving average of squared gradient values 87 | state['exp_avg_sq'] = torch.zeros_like(p.data) 88 | 89 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 90 | beta1, beta2 = group['betas'] 91 | 92 | state['step'] += 1 93 | 94 | # Decay the first and second moment running average coefficient 95 | # m_t 96 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 97 | # v_t 98 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 99 | 100 | # Paper v3 does not use debiasing. 101 | # Apply bias to lr to avoid broadcast. 102 | step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1 103 | 104 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 105 | 106 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) 107 | if group['weight_decay'] != 0: 108 | adam_step.add_(group['weight_decay'], p.data) 109 | 110 | adam_norm = adam_step.pow(2).sum().sqrt() 111 | if weight_norm == 0 or adam_norm == 0: 112 | trust_ratio = 1 113 | else: 114 | trust_ratio = weight_norm / adam_norm 115 | state['weight_norm'] = weight_norm 116 | state['adam_norm'] = adam_norm 117 | state['trust_ratio'] = trust_ratio 118 | if self.adam: 119 | trust_ratio = 1 120 | 121 | p.data.add_(-step_size * trust_ratio, adam_step) 122 | 123 | return loss 124 | -------------------------------------------------------------------------------- /PROD/image/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/PROD/image/framework.jpg -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). 
If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
26 | -------------------------------------------------------------------------------- /SimANS/README.md: -------------------------------------------------------------------------------- 1 | # SimANS: Simple Ambiguous Negatives Sampling for Dense Text Retrieval 2 | 3 | This repository contains the code for our EMNLP 2022 paper [***SimANS: Simple Ambiguous Negatives Sampling for Dense Text Retrieval***](https://arxiv.org/abs/2210.11773). 4 | 5 | 6 | ## 🚀 Overview 7 | 8 | We propose ***SimANS***, a simple, general and flexible ambiguous negatives sampling method for dense text retrieval that can be easily applied to various dense retrieval methods. 9 | 10 | The core of SimANS is the following sampling distribution, implemented in the training code as shown below: 11 | 12 | $$p_{i} \propto \exp{(-a\cdot(s(q,d_{i})-s(q,\tilde{d}^{+})-b)^{2})}, \forall~d_{i} \in \hat{\mathcal{D}}^{-}$$ 13 | 14 | ```python 15 | def SimANS(pos_pair, neg_pairs_list): 16 | pos_id, pos_score = int(pos_pair[0]), float(pos_pair[1]) 17 | neg_candidates, neg_scores = [], [] 18 | for pair in neg_pairs_list: 19 | neg_id, neg_score = pair 20 | neg_score = math.exp(-(neg_score - pos_score) ** 2 * self.tau) 21 | neg_candidates.append(neg_id) 22 | neg_scores.append(neg_score) 23 | return random.choices(neg_candidates, weights=neg_scores, k=num_hard_negatives) 24 | ``` 25 | 26 | Here we show the main results on [MS MARCO](https://microsoft.github.io/msmarco/), [Natural Questions](https://ai.google.com/research/NaturalQuestions/) and [TriviaQA](http://nlp.cs.washington.edu/triviaqa/); SimANS outperforms the state-of-the-art methods on all three benchmarks. 27 | ![SimANS Main Result](figs/simans_main_result.jpg) 28 | 29 | SimANS has also been applied in [Microsoft Bing](https://www.bing.com/), and we additionally report results on an industry dataset. 30 | ![SimANS Industry Result](figs/simans_industry_result.jpg) 31 |
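The snippet above is excerpted from the training code, so `self.tau` and `num_hard_negatives` are defined elsewhere in the surrounding class. As a reading aid only, here is a minimal, self-contained sketch of the same sampling rule; the function and parameter names (`simans_sample`, `a`, `b`, `num_negatives`) and the toy scores are ours, with `a` and `b` playing the roles of $a$ and $b$ in the formula (and of the `--a`/`--b` flags in the training scripts):

```python
import math
import random


def simans_sample(pos_score, neg_candidates, a=0.5, b=0.0, num_negatives=15, seed=None):
    """Sample ids from `neg_candidates` with weight exp(-a * (s_neg - s_pos - b) ** 2).

    `neg_candidates` is a list of (doc_id, retrieval_score) pairs for one query and
    `pos_score` is the retrieval score of its positive passage.
    """
    rng = random.Random(seed)
    ids = [doc_id for doc_id, _ in neg_candidates]
    weights = [math.exp(-a * (score - pos_score - b) ** 2) for _, score in neg_candidates]
    return rng.choices(ids, weights=weights, k=num_negatives)  # sampling with replacement


# Toy example: negatives scored close to the positive (the "ambiguous" ones) are sampled most often.
negatives = [("d1", 12.0), ("d2", 9.5), ("d3", 15.2), ("d4", 3.0)]
print(simans_sample(pos_score=10.0, neg_candidates=negatives, a=0.5, b=0.0, num_negatives=4, seed=0))
```

Intuitively, negatives whose scores sit near the positive's score receive the largest sampling weights, while both very easy and very hard (likely false) negatives are down-weighted.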
32 | 33 | Please find more details in the paper. 34 | 35 | 36 | ## Released Resources 37 | 38 | We release the preprocessed data and trained checkpoints in [Azure Blob](https://msranlciropen.blob.core.windows.net/simxns/SimANS/). 39 | We also provide the file list under this URL: 40 |
41 | Click here to see the file list. 42 |
INFO: best_simans_ckpt.zip;  Content Length: 7.74 GiB
 43 | INFO: best_simans_ckpt/MS-Doc/checkpoint-25000;  Content Length: 1.39 GiB
 44 | INFO: best_simans_ckpt/MS-Doc/log.txt;  Content Length: 78.32 KiB
 45 | INFO: best_simans_ckpt/MS-Pas/checkpoint-20000;  Content Length: 2.45 GiB
 46 | INFO: best_simans_ckpt/MS-Pas/log.txt;  Content Length: 82.74 KiB
 47 | INFO: best_simans_ckpt/NQ/checkpoint-30000;  Content Length: 2.45 GiB
 48 | INFO: best_simans_ckpt/NQ/log.txt;  Content Length: 298.44 KiB
 49 | INFO: best_simans_ckpt/TQ/checkpoint-10000;  Content Length: 2.45 GiB
 50 | INFO: best_simans_ckpt/TQ/log.txt;  Content Length: 99.44 KiB
 51 | INFO: ckpt.zip;  Content Length: 19.63 GiB
 52 | INFO: ckpt/MS-Doc/adore-star/config.json;  Content Length: 1.37 KiB
 53 | INFO: ckpt/MS-Doc/adore-star/pytorch_model.bin;  Content Length: 480.09 MiB
 54 | INFO: ckpt/MS-Doc/checkpoint-20000;  Content Length: 1.39 GiB
 55 | INFO: ckpt/MS-Doc/checkpoint-reranker20000;  Content Length: 1.39 GiB
 56 | INFO: ckpt/MS-Pas/checkpoint-20000;  Content Length: 2.45 GiB
 57 | INFO: ckpt/MS-Pas/checkpoint-reranker20000;  Content Length: 3.75 GiB
 58 | INFO: ckpt/NQ/checkpoint-reranker26000;  Content Length: 3.75 GiB
 59 | INFO: ckpt/NQ/nq_fintinue.pkl;  Content Length: 2.45 GiB
 60 | INFO: ckpt/TQ/checkpoint-reranker34000;  Content Length: 3.75 GiB
 61 | INFO: ckpt/TQ/triviaqa_fintinue.pkl;  Content Length: 2.45 GiB
 62 | INFO: data.zip;  Content Length: 18.43 GiB
 63 | INFO: data/MS-Doc/dev_ce_0.tsv;  Content Length: 15.97 MiB
 64 | INFO: data/MS-Doc/msmarco-docdev-qrels.tsv;  Content Length: 105.74 KiB
 65 | INFO: data/MS-Doc/msmarco-docdev-queries.tsv;  Content Length: 215.14 KiB
 66 | INFO: data/MS-Doc/msmarco-docs.tsv;  Content Length: 21.32 GiB
 67 | INFO: data/MS-Doc/msmarco-doctrain-qrels.tsv;  Content Length: 7.19 MiB
 68 | INFO: data/MS-Doc/msmarco-doctrain-queries.tsv;  Content Length: 14.76 MiB
 69 | INFO: data/MS-Doc/train_ce_0.tsv;  Content Length: 1.13 GiB
 70 | INFO: data/MS-Pas/dev.query.txt;  Content Length: 283.39 KiB
 71 | INFO: data/MS-Pas/para.title.txt;  Content Length: 280.76 MiB
 72 | INFO: data/MS-Pas/para.txt;  Content Length: 2.85 GiB
 73 | INFO: data/MS-Pas/qrels.dev.tsv;  Content Length: 110.89 KiB
 74 | INFO: data/MS-Pas/qrels.train.addition.tsv;  Content Length: 5.19 MiB
 75 | INFO: data/MS-Pas/qrels.train.tsv;  Content Length: 7.56 MiB
 76 | INFO: data/MS-Pas/train.query.txt;  Content Length: 19.79 MiB
 77 | INFO: data/MS-Pas/train_ce_0.tsv;  Content Length: 1.68 GiB
 78 | INFO: data/NQ/dev_ce_0.json;  Content Length: 632.98 MiB
 79 | INFO: data/NQ/nq-dev.qa.csv;  Content Length: 605.48 KiB
 80 | INFO: data/NQ/nq-test.qa.csv;  Content Length: 289.99 KiB
 81 | INFO: data/NQ/nq-train.qa.csv;  Content Length: 5.36 MiB
 82 | INFO: data/NQ/train_ce_0.json;  Content Length: 5.59 GiB
 83 | INFO: data/TQ/dev_ce_0.json;  Content Length: 646.60 MiB
 84 | INFO: data/TQ/train_ce_0.json;  Content Length: 5.62 GiB
 85 | INFO: data/TQ/trivia-dev.qa.csv;  Content Length: 3.03 MiB
 86 | INFO: data/TQ/trivia-test.qa.csv;  Content Length: 3.91 MiB
 87 | INFO: data/TQ/trivia-train.qa.csv;  Content Length: 26.67 MiB
 88 | INFO: data/psgs_w100.tsv;  Content Length: 12.76 GiB
89 |
90 | 91 | To download the files, please refer to [HOW_TO_DOWNLOAD](https://github.com/microsoft/SimXNS/tree/main/HOW_TO_DOWNLOAD.md). 92 | 93 | 94 | ## 🙋 How to Use 95 | 96 | **⚙️ Environment Setting** 97 | 98 | To faithfully reproduce our results, please install the PyTorch `1.7.1` build that matches your platform and CUDA version (see [Released Packages by pytorch](https://anaconda.org/pytorch/pytorch)), and make sure faiss is installed correctly, since it is required for evaluation. 99 | 100 | We prepare the experimental environment with the following commands: 101 | ```bash 102 | conda install pytorch==1.7.1 cudatoolkit=11.0 -c pytorch 103 | conda install faiss-gpu cudatoolkit=11.0 -c pytorch 104 | conda install transformers 105 | pip install tqdm 106 | pip install tensorboardX 107 | pip install lmdb 108 | pip install datasets 109 | pip install wandb 110 | pip install sklearn 111 | pip install boto3 112 | ``` 113 | 114 | **💾 Data and Initial Checkpoint** 115 | 116 | The data needed for training on MS-Pas/MS-Doc/NQ/TQ is available [here](https://msranlciropen.blob.core.windows.net/simxns/SimANS/data.zip). You can download the compressed file directly or with [Microsoft's AzCopy CLI tool](https://learn.microsoft.com/en-us/azure/storage/common/storage-ref-azcopy) and put the content in `./data`. 117 | If you are working on MS-Pas, you will also need this [additional file](https://msranlciropen.blob.core.windows.net/simxns/SimANS/data/MS-Pas/qrels.train.addition.tsv) for training. 118 | 119 | Our approach uses the AR2 checkpoints for initialization. We release them [here](https://msranlciropen.blob.core.windows.net/simxns/SimANS/ckpt.zip); download the all-in-one compressed file and put the content in `./ckpt`. 120 | 121 | 122 | **📋 Training Scripts** 123 | 124 | We provide training scripts that apply SimANS to the SOTA AR2 model on the MS-MARCO Passage/Document Retrieval, NQ and TQ datasets, with the best hyperparameters already set. Running them automatically performs both training and evaluation. 125 | ```bash 126 | bash train_MS_Pas_AR2.sh 127 | bash train_MS_Doc_AR2.sh 128 | bash train_NQ_AR2.sh 129 | bash train_TQ_AR2.sh 130 | ``` 131 | 132 | For the results in the paper, we use 8 A100 GPUs with CUDA 11. Different devices or different CUDA/software versions may lead to different performance. 133 | 134 | **⚽ Best SimANS Checkpoint** 135 | 136 | To make our experimental results easier to reproduce, we also release all checkpoints of our approach [here](https://msranlciropen.blob.core.windows.net/simxns/SimANS/best_simans_ckpt.zip). You can download the compressed file and reuse its content for evaluation.
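Whether you run the training scripts above or only evaluate the released checkpoints, it is worth first confirming that the GPU stack matches the Environment Setting section. The following sanity-check sketch is ours (not part of the repository); it only prints version and device information and assumes the `faiss-gpu` build from the install commands above:

```python
import torch
import faiss
import transformers

# The paper's results were obtained with PyTorch 1.7.1 and CUDA 11 on 8 GPUs;
# other versions may still run, but can yield slightly different numbers.
print("torch:", torch.__version__, "| built for CUDA:", torch.version.cuda)
print("cuda available:", torch.cuda.is_available(), "| visible GPUs:", torch.cuda.device_count())
print("faiss GPUs:", faiss.get_num_gpus())  # should be > 0 for GPU-based index evaluation
print("transformers:", transformers.__version__)
```

If any of these imports fail or no GPUs are reported, revisit the install commands above before launching the scripts.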
137 | 138 | 139 | ## 📜 Citation 140 | 141 | Please cite our paper if you use [SimANS](https://arxiv.org/abs/2210.11773) in your work: 142 | ```bibtex 143 | @article{zhou2022simans, 144 | title={SimANS: Simple Ambiguous Negatives Sampling for Dense Text Retrieval}, 145 | author={Kun Zhou, Yeyun Gong, Xiao Liu, Wayne Xin Zhao, Yelong Shen, Anlei Dong, Jingwen Lu, Rangan Majumder, Ji-Rong Wen, Nan Duan and Weizhu Chen}, 146 | booktitle = {{EMNLP}}, 147 | year={2022} 148 | } 149 | ``` 150 | -------------------------------------------------------------------------------- /SimANS/best_simans_ckpt/MS-Doc/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/best_simans_ckpt/MS-Doc/.gitkeep -------------------------------------------------------------------------------- /SimANS/best_simans_ckpt/MS-Pas/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/best_simans_ckpt/MS-Pas/.gitkeep -------------------------------------------------------------------------------- /SimANS/best_simans_ckpt/NQ/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/best_simans_ckpt/NQ/.gitkeep -------------------------------------------------------------------------------- /SimANS/best_simans_ckpt/TQ/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/best_simans_ckpt/TQ/.gitkeep -------------------------------------------------------------------------------- /SimANS/ckpt/MS-Doc/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/ckpt/MS-Doc/.gitkeep -------------------------------------------------------------------------------- /SimANS/ckpt/MS-Pas/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/ckpt/MS-Pas/.gitkeep -------------------------------------------------------------------------------- /SimANS/ckpt/NQ/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/ckpt/NQ/.gitkeep -------------------------------------------------------------------------------- /SimANS/ckpt/TQ/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/ckpt/TQ/.gitkeep -------------------------------------------------------------------------------- /SimANS/data/MS-Doc/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/data/MS-Doc/.gitkeep -------------------------------------------------------------------------------- /SimANS/data/MS-Pas/.gitkeep: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/data/MS-Pas/.gitkeep -------------------------------------------------------------------------------- /SimANS/data/NQ/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/data/NQ/.gitkeep -------------------------------------------------------------------------------- /SimANS/data/TQ/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/data/TQ/.gitkeep -------------------------------------------------------------------------------- /SimANS/figs/simans_industry_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/figs/simans_industry_result.jpg -------------------------------------------------------------------------------- /SimANS/figs/simans_main.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/figs/simans_main.jpg -------------------------------------------------------------------------------- /SimANS/figs/simans_main_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/figs/simans_main_result.jpg -------------------------------------------------------------------------------- /SimANS/train_MS_Doc_AR2.sh: -------------------------------------------------------------------------------- 1 | EXP_NAME=co_training_MS_MARCO_Doc_SimANS 2 | Iteration_step=5000 3 | Iteration_reranker_step=1000 4 | MAX_STEPS=40000 5 | 6 | # for global_step in `seq 0 2000 $MAX_STEPS`; do echo $global_step; done; 7 | for global_step in `seq 0 $Iteration_step $MAX_STEPS`; 8 | do 9 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 Doc_training/co_training_doc_train.py \ 10 | --model_type=ckpt/MS-Doc/adore-star \ 11 | --model_name_or_path=ckpt/MS-Doc/checkpoint-20000 \ 12 | --max_seq_length=512 --per_gpu_train_batch_size=32 --gradient_accumulation_steps=1 \ 13 | --number_neg=15 --learning_rate=5e-6 \ 14 | --teacher_model_type=roberta-base \ 15 | --teacher_model_path=ckpt/MS-Doc/checkpoint-reranker20000 \ 16 | --teacher_learning_rate=1e-6 \ 17 | --output_dir=ckpt/$EXP_NAME \ 18 | --log_dir=tensorboard/logs/$EXP_NAME \ 19 | --origin_data_dir=data/MS-Doc/train_ce_0.tsv \ 20 | --train_qa_path=data/MS-Doc/msmarco-doctrain-queries.tsv \ 21 | --passage_path=data/MS-Doc \ 22 | --logging_steps=100 --save_steps=5000 --max_steps=$MAX_STEPS \ 23 | --gradient_checkpointing --distill_loss \ 24 | --iteration_step=$Iteration_step \ 25 | --iteration_reranker_step=$Iteration_reranker_step \ 26 | --temperature_distill=1 --ann_dir=ckpt/$EXP_NAME/temp --adv_lambda 1 --global_step=$global_step 27 | 28 | g_global_step=`expr $global_step + $Iteration_step` 29 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 Doc_training/co_training_doc_generate.py \ 30 | --model_type=ckpt/MS-Doc/adore-star \ 31 | --max_seq_length=512 \ 32 | 
--output_dir=ckpt/$EXP_NAME \ 33 | --log_dir=tensorboard/logs/$EXP_NAME \ 34 | --train_qa_path=data/MS-Doc/msmarco-doctrain-queries.tsv \ 35 | --dev_qa_path=data/MS-Doc/msmarco-docdev-queries.tsv \ 36 | --passage_path=data/MS-Doc \ 37 | --max_steps=$MAX_STEPS \ 38 | --gradient_checkpointing \ 39 | --ann_dir=ckpt/$EXP_NAME/temp --global_step=$g_global_step 40 | done 41 | -------------------------------------------------------------------------------- /SimANS/train_MS_Pas_AR2.sh: -------------------------------------------------------------------------------- 1 | EXP_NAME=co_training_MS_MARCO_Pas_SimANS 2 | Iteration_step=5000 3 | Iteration_reranker_step=500 4 | MAX_STEPS=60000 5 | # for global_step in `seq 0 2000 $MAX_STEPS`; do echo $global_step; done; 6 | for global_step in `seq 0 $Iteration_step $MAX_STEPS`; 7 | do 8 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 co_training/co_training_marco_train.py \ 9 | --model_type=Luyu/co-condenser-marco \ 10 | --model_name_or_path=ckpt/MS-Pas/checkpoint-20000 \ 11 | --max_seq_length=128 --per_gpu_train_batch_size=16 --gradient_accumulation_steps=2 \ 12 | --number_neg=15 --learning_rate=5e-6 \ 13 | --teacher_model_type=nghuyong/ernie-2.0-large-en \ 14 | --teacher_model_path=ckpt/MS-Pas/checkpoint-reranker20000 \ 15 | --teacher_learning_rate=5e-7 \ 16 | --output_dir=ckpt/$EXP_NAME \ 17 | --log_dir=tensorboard/logs/$EXP_NAME \ 18 | --origin_data_dir=data/MS-Pas/train_ce_0.tsv \ 19 | --train_qa_path=data/MS-Pas/train.query.txt \ 20 | --dev_qa_path=data/MS-Pas/dev.query.txt \ 21 | --passage_path=data/MS-Pas \ 22 | --logging_steps=10 --save_steps=5000 --max_steps=$MAX_STEPS \ 23 | --gradient_checkpointing --distill_loss \ 24 | --iteration_step=$Iteration_step \ 25 | --iteration_reranker_step=$Iteration_reranker_step \ 26 | --temperature_distill=1 --ann_dir=ckpt/$EXP_NAME/temp --adv_lambda 1 --global_step=$global_step 27 | 28 | g_global_step=`expr $global_step + $Iteration_step` 29 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 co_training/co_training_marco_generate.py \ 30 | --model_type=Luyu/co-condenser-marco \ 31 | --max_seq_length=128 \ 32 | --output_dir=ckpt/$EXP_NAME \ 33 | --log_dir=tensorboard/logs/$EXP_NAME \ 34 | --train_qa_path=data/MS-Pas/train.query.txt \ 35 | --dev_qa_path=data/MS-Pas/dev.query.txt \ 36 | --passage_path=data/MS-Pas \ 37 | --max_steps=$MAX_STEPS \ 38 | --gradient_checkpointing --adv_step=0 \ 39 | --iteration_step=$Iteration_step \ 40 | --iteration_reranker_step=$Iteration_reranker_step \ 41 | --ann_dir=ckpt/$EXP_NAME/temp --global_step=$g_global_step 42 | done 43 | -------------------------------------------------------------------------------- /SimANS/train_NQ_AR2.sh: -------------------------------------------------------------------------------- 1 | EXP_NAME=co_training_nq_SimANS_test 2 | TB_DIR=tensorboard_log/$EXP_NAME # tensorboard log path 3 | OUT_DIR=output/$EXP_NAME 4 | 5 | DE_CKPT_PATH=ckpt/NQ/nq_fintinue.pkl 6 | CE_CKPT_PATH=ckpt/NQ/checkpoint-reranker26000 7 | Origin_Data_Dir=data/NQ/train_ce_0.json 8 | Origin_Data_Dir_Dev=data/NQ/dev_ce_0.json 9 | 10 | Iteration_step=2000 11 | Iteration_reranker_step=500 12 | MAX_STEPS=30000 13 | 14 | # for global_step in `seq 0 2000 $MAX_STEPS`; do echo $global_step; done; 15 | for global_step in `seq 0 $Iteration_step $MAX_STEPS`; 16 | do 17 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 wiki/co_training_wiki_train.py \ 18 | --model_type=nghuyong/ernie-2.0-base-en \ 19 | 
--model_name_or_path=$DE_CKPT_PATH \ 20 | --max_seq_length=128 --per_gpu_train_batch_size=8 --gradient_accumulation_steps=1 \ 21 | --number_neg=15 --learning_rate=1e-5 \ 22 | --reranker_model_type=nghuyong/ernie-2.0-large-en \ 23 | --reranker_model_path=$CE_CKPT_PATH \ 24 | --reranker_learning_rate=1e-6 \ 25 | --output_dir=$OUT_DIR \ 26 | --log_dir=$TB_DIR \ 27 | --origin_data_dir=$Origin_Data_Dir \ 28 | --warmup_steps=2000 --logging_steps=100 --save_steps=2000 --max_steps=$MAX_STEPS \ 29 | --gradient_checkpointing --normal_loss \ 30 | --iteration_step=$Iteration_step \ 31 | --iteration_reranker_step=$Iteration_reranker_step \ 32 | --temperature_normal=1 --ann_dir=$OUT_DIR/temp --adv_lambda 0 --global_step=$global_step --b 1.0 33 | 34 | g_global_step=`expr $global_step + $Iteration_step` 35 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 wiki/co_training_wiki_generate.py \ 36 | --model_type=nghuyong/ernie-2.0-base-en \ 37 | --model_name_or_path=$DE_CKPT_PATH \ 38 | --max_seq_length=128 --per_gpu_train_batch_size=8 \ 39 | --output_dir=output/$EXP_NAME \ 40 | --log_dir=tensorboard/logs/$EXP_NAME \ 41 | --origin_data_dir=$Origin_Data_Dir \ 42 | --origin_data_dir_dev=$Origin_Data_Dir_Dev \ 43 | --train_qa_path=data/NQ/nq-train.qa.csv \ 44 | --test_qa_path=data/NQ/nq-test.qa.csv \ 45 | --dev_qa_path=data/NQ/nq-dev.qa.csv \ 46 | --passage_path=data/psgs_w100.tsv \ 47 | --max_steps=$MAX_STEPS \ 48 | --gradient_checkpointing \ 49 | --ann_dir=output/$EXP_NAME/temp --global_step=$g_global_step 50 | done 51 | -------------------------------------------------------------------------------- /SimANS/train_TQ_AR2.sh: -------------------------------------------------------------------------------- 1 | EXP_NAME=co_training_tq_SimANS_test 2 | TB_DIR=tensorboard_log/$EXP_NAME # tensorboard log path 3 | OUT_DIR=output/$EXP_NAME 4 | 5 | DE_CKPT_PATH=ckpt/TQ/triviaqa_fintinue.pkl 6 | CE_CKPT_PATH=ckpt/TQ/checkpoint-reranker34000 7 | Origin_Data_Dir=data/TQ/train_ce_0.json 8 | Origin_Data_Dir_Dev=data/TQ/dev_ce_0.json 9 | 10 | Iteration_step=2000 11 | Iteration_reranker_step=500 12 | MAX_STEPS=10000 13 | 14 | # for global_step in `seq 0 2000 $MAX_STEPS`; do echo $global_step; done; 15 | for global_step in `seq 0 $Iteration_step $MAX_STEPS`; 16 | do 17 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 wiki/co_training_wiki_train.py \ 18 | --model_type=nghuyong/ernie-2.0-base-en \ 19 | --model_name_or_path=$DE_CKPT_PATH \ 20 | --max_seq_length=256 --per_gpu_train_batch_size=8 --gradient_accumulation_steps=1 \ 21 | --number_neg=15 --learning_rate=5e-6 \ 22 | --reranker_model_type=nghuyong/ernie-2.0-large-en \ 23 | --reranker_model_path=$CE_CKPT_PATH \ 24 | --reranker_learning_rate=1e-6 \ 25 | --output_dir=$OUT_DIR \ 26 | --log_dir=$TB_DIR \ 27 | --origin_data_dir=$Origin_Data_Dir \ 28 | --warmup_steps=1000 --logging_steps=100 --save_steps=2000 --max_steps=$MAX_STEPS \ 29 | --gradient_checkpointing --normal_loss \ 30 | --iteration_step=$Iteration_step \ 31 | --iteration_reranker_step=$Iteration_reranker_step \ 32 | --temperature_normal=1 --ann_dir=$OUT_DIR/temp --adv_lambda 0.0 --global_step=$global_step --a 0.5 --b 0 33 | 34 | g_global_step=`expr $global_step + $Iteration_step` 35 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 wiki/co_training_wiki_generate.py \ 36 | --model_type=nghuyong/ernie-2.0-base-en \ 37 | --model_name_or_path=$DE_CKPT_PATH \ 38 | --max_seq_length=256 --per_gpu_train_batch_size=8 \ 39 | 
--output_dir=output/$EXP_NAME \ 40 | --log_dir=tensorboard/logs/$EXP_NAME \ 41 | --origin_data_dir=$Origin_Data_Dir \ 42 | --origin_data_dir_dev=$Origin_Data_Dir_Dev \ 43 | --train_qa_path=data/TQ/trivia-train.qa.csv \ 44 | --test_qa_path=data/TQ/trivia-test.qa.csv \ 45 | --dev_qa_path=data/TQ/trivia-dev.qa.csv \ 46 | --passage_path=data/psgs_w100.tsv \ 47 | --max_steps=$MAX_STEPS \ 48 | --gradient_checkpointing \ 49 | --ann_dir=output/$EXP_NAME/temp --global_step=$g_global_step 50 | done 51 | 52 | --------------------------------------------------------------------------------