├── .gitignore
├── ALLIES
├── README.md
├── assets
│ └── model.jpg
├── dataset
│ ├── nq-test.jsonl
│ ├── tqa-test.jsonl
│ └── webq-test.jsonl
├── embedding_generation.ipynb
├── main.py
├── modeling_bert.py
├── retrieval_utils.py
├── tools.py
└── utils.py
├── CAPSTONE
├── README.md
├── generate_query.sh
├── get_doc2query_marco.sh
├── get_msmarco.sh
├── get_trec.sh
├── merge_beir_result.py
├── models
│ ├── co_training_generate_new_train_wiki.py
│ ├── generate_query.py
│ ├── merge_query.py
│ ├── modules.py
│ └── run_de_model_ernie.py
├── preprocess
│ ├── merge_query.py
│ ├── preprocess_msmarco.py
│ └── preprocess_trec.py
├── run_de_model_expand_corpus_cocondenser.sh
├── run_de_model_expand_corpus_cocondenser_step2.sh
└── utils
│ ├── __init__.py
│ ├── data_utils.py
│ ├── dpr_utils.py
│ ├── evaluate_trec.py
│ ├── lamb.py
│ ├── metric_utils.py
│ └── util.py
├── CODE_OF_CONDUCT.md
├── HOW_TO_DOWNLOAD.md
├── LEAD
├── README.md
├── assets
│ ├── main_result.jpg
│ └── model.jpg
├── data_preprocess
│ ├── .DS_Store
│ ├── data_preprocess.py
│ ├── get_data.sh
│ └── hard_negative_generation
│ │ ├── retrieve_hard_negatives_academic.sh
│ │ └── train_de.sh
├── dataset.py
├── distillation
│ ├── distill_from_12cb_to_6de.sh
│ ├── distill_from_12ce_to_6de.sh
│ ├── distill_from_12de_to_6de.sh
│ └── distill_from_24cb_to_12de.sh
├── inference
│ ├── evaluate_12_layer_de.sh
│ └── evaluate_6_layer_de.sh
├── inference_de.py
├── modeling_bert.py
├── modeling_distilbert.py
├── models.py
├── retrieve_hard_negative.py
├── run_LEAD.py
├── run_single_model.py
├── util.py
└── warm_up
│ ├── train_12_layer_ce.sh
│ ├── train_12_layer_col.sh
│ ├── train_12_layer_de.sh
│ ├── train_24_layer_col.sh
│ └── train_6_layer_de.sh
├── LICENSE
├── MASTER
├── README.md
├── figs
│ ├── master_main.jpg
│ └── master_main_result.jpg
├── finetune
│ ├── MS
│ │ ├── co_training_model.py
│ │ ├── co_training_model_ele.py
│ │ ├── co_training_model_ele_continue.py
│ │ ├── inference_de.py
│ │ ├── inference_de_prob.py
│ │ ├── run_ce_model.py
│ │ ├── run_ce_model_ele.py
│ │ └── run_de_model.py
│ ├── ft_MS_MASTER.sh
│ ├── ft_wiki_NQ.sh
│ ├── ft_wiki_TQ.sh
│ ├── model
│ │ ├── models.py
│ │ └── models_ele.py
│ ├── utils
│ │ ├── MARCO_until.py
│ │ ├── dpr_utils.py
│ │ ├── lamb.py
│ │ └── util.py
│ └── wiki
│ │ ├── co_training_model.py
│ │ ├── inference_de.py
│ │ ├── run_ce_model.py
│ │ └── run_de_model.py
└── pretrain
│ ├── arguments.py
│ ├── data.py
│ ├── modeling.py
│ ├── run_pre_training.py
│ ├── run_pretrain.sh
│ └── trainer.py
├── PROD
├── MarcoDoc_Data.sh
├── MarcoPas_Data.sh
├── ProD_KD
│ ├── model
│ │ ├── __init__.py
│ │ └── models.py
│ ├── run_progressive_distill_marco.py
│ ├── run_progressive_distill_marcodoc.py
│ ├── run_progressive_distill_nq.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── build_marco_train.py
│ │ ├── dataset_division_marco.py
│ │ ├── dataset_division_marcodoc.py
│ │ ├── dataset_division_nq.py
│ │ ├── dpr_utils.py
│ │ ├── lamb.py
│ │ ├── marco_until.py
│ │ ├── prepare_ce_data_nq.py
│ │ ├── preprae_ce_marco_train.py
│ │ ├── preprae_ce_marcodoc_train.py
│ │ └── util.py
├── ProD_base
│ ├── inference_DE_marco.py
│ ├── inference_DE_marcodoc.py
│ ├── inference_DE_nq.py
│ ├── inference_DE_trec.py
│ ├── model
│ │ ├── __init__.py
│ │ └── models.py
│ ├── rerank_eval_marco.py
│ ├── rerank_eval_marcodoc.py
│ ├── rerank_eval_nq.py
│ ├── rerank_train_eval_marco.py
│ ├── rerank_train_eval_marcodoc.py
│ ├── rerank_train_eval_nq.py
│ ├── train_CE_model_marco.py
│ ├── train_CE_model_marcodoc.py
│ ├── train_CE_model_nq.py
│ ├── train_DE_model_marco.py
│ ├── train_DE_model_marcodoc.py
│ ├── train_DE_model_nq.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── dpr_utils.py
│ │ ├── lamb.py
│ │ ├── marco_until.py
│ │ └── util.py
├── README.md
└── image
│ └── framework.jpg
├── README.md
├── SECURITY.md
├── SUPPORT.md
└── SimANS
├── Doc_training
├── co_training_doc_generate.py
├── co_training_doc_train.py
├── co_training_generate_new_train.py
└── star_tokenizer.py
├── README.md
├── best_simans_ckpt
├── MS-Doc
│ └── .gitkeep
├── MS-Pas
│ └── .gitkeep
├── NQ
│ └── .gitkeep
└── TQ
│ └── .gitkeep
├── ckpt
├── MS-Doc
│ └── .gitkeep
├── MS-Pas
│ └── .gitkeep
├── NQ
│ └── .gitkeep
└── TQ
│ └── .gitkeep
├── co_training
├── co_training_generate.py
├── co_training_marco_generate.py
└── co_training_marco_train.py
├── data
├── MS-Doc
│ └── .gitkeep
├── MS-Pas
│ └── .gitkeep
├── NQ
│ └── .gitkeep
└── TQ
│ └── .gitkeep
├── figs
├── simans_industry_result.jpg
├── simans_main.jpg
└── simans_main_result.jpg
├── model
└── models.py
├── train_MS_Doc_AR2.sh
├── train_MS_Pas_AR2.sh
├── train_NQ_AR2.sh
├── train_TQ_AR2.sh
├── utils
├── MARCO_until_Doc.py
├── MARCO_until_new.py
├── dpr_utils.py
├── util.py
└── util_wiki.py
└── wiki
├── co_training_generate_new_train_wiki.py
├── co_training_wiki_generate.py
└── co_training_wiki_train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Ignore Visual Studio temporary files, build results, and
2 | ## files generated by popular Visual Studio add-ons.
3 | ##
4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
5 |
6 | # User-specific files
7 | *.rsuser
8 | *.suo
9 | *.user
10 | *.userosscache
11 | *.sln.docstates
12 |
13 | # User-specific files (MonoDevelop/Xamarin Studio)
14 | *.userprefs
15 |
16 | # Mono auto generated files
17 | mono_crash.*
18 |
19 | # Build results
20 | [Dd]ebug/
21 | [Dd]ebugPublic/
22 | [Rr]elease/
23 | [Rr]eleases/
24 | x64/
25 | x86/
26 | [Aa][Rr][Mm]/
27 | [Aa][Rr][Mm]64/
28 | bld/
29 | [Bb]in/
30 | [Oo]bj/
31 | [Ll]og/
32 | [Ll]ogs/
33 |
34 | # Visual Studio 2015/2017 cache/options directory
35 | .vs/
36 | # Uncomment if you have tasks that create the project's static files in wwwroot
37 | #wwwroot/
38 |
39 | # Visual Studio 2017 auto generated files
40 | Generated\ Files/
41 |
42 | # MSTest test Results
43 | [Tt]est[Rr]esult*/
44 | [Bb]uild[Ll]og.*
45 |
46 | # NUnit
47 | *.VisualState.xml
48 | TestResult.xml
49 | nunit-*.xml
50 |
51 | # Build Results of an ATL Project
52 | [Dd]ebugPS/
53 | [Rr]eleasePS/
54 | dlldata.c
55 |
56 | # Benchmark Results
57 | BenchmarkDotNet.Artifacts/
58 |
59 | # .NET Core
60 | project.lock.json
61 | project.fragment.lock.json
62 | artifacts/
63 |
64 | # StyleCop
65 | StyleCopReport.xml
66 |
67 | # Files built by Visual Studio
68 | *_i.c
69 | *_p.c
70 | *_h.h
71 | *.ilk
72 | *.meta
73 | *.obj
74 | *.iobj
75 | *.pch
76 | *.pdb
77 | *.ipdb
78 | *.pgc
79 | *.pgd
80 | *.rsp
81 | *.sbr
82 | *.tlb
83 | *.tli
84 | *.tlh
85 | *.tmp
86 | *.tmp_proj
87 | *_wpftmp.csproj
88 | *.log
89 | *.vspscc
90 | *.vssscc
91 | .builds
92 | *.pidb
93 | *.svclog
94 | *.scc
95 |
96 | # Chutzpah Test files
97 | _Chutzpah*
98 |
99 | # Visual C++ cache files
100 | ipch/
101 | *.aps
102 | *.ncb
103 | *.opendb
104 | *.opensdf
105 | *.sdf
106 | *.cachefile
107 | *.VC.db
108 | *.VC.VC.opendb
109 |
110 | # Visual Studio profiler
111 | *.psess
112 | *.vsp
113 | *.vspx
114 | *.sap
115 |
116 | # Visual Studio Trace Files
117 | *.e2e
118 |
119 | # TFS 2012 Local Workspace
120 | $tf/
121 |
122 | # Guidance Automation Toolkit
123 | *.gpState
124 |
125 | # ReSharper is a .NET coding add-in
126 | _ReSharper*/
127 | *.[Rr]e[Ss]harper
128 | *.DotSettings.user
129 |
130 | # TeamCity is a build add-in
131 | _TeamCity*
132 |
133 | # DotCover is a Code Coverage Tool
134 | *.dotCover
135 |
136 | # AxoCover is a Code Coverage Tool
137 | .axoCover/*
138 | !.axoCover/settings.json
139 |
140 | # Visual Studio code coverage results
141 | *.coverage
142 | *.coveragexml
143 |
144 | # NCrunch
145 | _NCrunch_*
146 | .*crunch*.local.xml
147 | nCrunchTemp_*
148 |
149 | # MightyMoose
150 | *.mm.*
151 | AutoTest.Net/
152 |
153 | # Web workbench (sass)
154 | .sass-cache/
155 |
156 | # Installshield output folder
157 | [Ee]xpress/
158 |
159 | # DocProject is a documentation generator add-in
160 | DocProject/buildhelp/
161 | DocProject/Help/*.HxT
162 | DocProject/Help/*.HxC
163 | DocProject/Help/*.hhc
164 | DocProject/Help/*.hhk
165 | DocProject/Help/*.hhp
166 | DocProject/Help/Html2
167 | DocProject/Help/html
168 |
169 | # Click-Once directory
170 | publish/
171 |
172 | # Publish Web Output
173 | *.[Pp]ublish.xml
174 | *.azurePubxml
175 | # Note: Comment the next line if you want to checkin your web deploy settings,
176 | # but database connection strings (with potential passwords) will be unencrypted
177 | *.pubxml
178 | *.publishproj
179 |
180 | # Microsoft Azure Web App publish settings. Comment the next line if you want to
181 | # checkin your Azure Web App publish settings, but sensitive information contained
182 | # in these scripts will be unencrypted
183 | PublishScripts/
184 |
185 | # NuGet Packages
186 | *.nupkg
187 | # NuGet Symbol Packages
188 | *.snupkg
189 | # The packages folder can be ignored because of Package Restore
190 | **/[Pp]ackages/*
191 | # except build/, which is used as an MSBuild target.
192 | !**/[Pp]ackages/build/
193 | # Uncomment if necessary however generally it will be regenerated when needed
194 | #!**/[Pp]ackages/repositories.config
195 | # NuGet v3's project.json files produces more ignorable files
196 | *.nuget.props
197 | *.nuget.targets
198 |
199 | # Microsoft Azure Build Output
200 | csx/
201 | *.build.csdef
202 |
203 | # Microsoft Azure Emulator
204 | ecf/
205 | rcf/
206 |
207 | # Windows Store app package directories and files
208 | AppPackages/
209 | BundleArtifacts/
210 | Package.StoreAssociation.xml
211 | _pkginfo.txt
212 | *.appx
213 | *.appxbundle
214 | *.appxupload
215 |
216 | # Visual Studio cache files
217 | # files ending in .cache can be ignored
218 | *.[Cc]ache
219 | # but keep track of directories ending in .cache
220 | !?*.[Cc]ache/
221 |
222 | # Others
223 | ClientBin/
224 | ~$*
225 | *~
226 | *.dbmdl
227 | *.dbproj.schemaview
228 | *.jfm
229 | *.pfx
230 | *.publishsettings
231 | orleans.codegen.cs
232 |
233 | # Including strong name files can present a security risk
234 | # (https://github.com/github/gitignore/pull/2483#issue-259490424)
235 | #*.snk
236 |
237 | # Since there are multiple workflows, uncomment next line to ignore bower_components
238 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
239 | #bower_components/
240 |
241 | # RIA/Silverlight projects
242 | Generated_Code/
243 |
244 | # Backup & report files from converting an old project file
245 | # to a newer Visual Studio version. Backup files are not needed,
246 | # because we have git ;-)
247 | _UpgradeReport_Files/
248 | Backup*/
249 | UpgradeLog*.XML
250 | UpgradeLog*.htm
251 | ServiceFabricBackup/
252 | *.rptproj.bak
253 |
254 | # SQL Server files
255 | *.mdf
256 | *.ldf
257 | *.ndf
258 |
259 | # Business Intelligence projects
260 | *.rdl.data
261 | *.bim.layout
262 | *.bim_*.settings
263 | *.rptproj.rsuser
264 | *- [Bb]ackup.rdl
265 | *- [Bb]ackup ([0-9]).rdl
266 | *- [Bb]ackup ([0-9][0-9]).rdl
267 |
268 | # Microsoft Fakes
269 | FakesAssemblies/
270 |
271 | # GhostDoc plugin setting file
272 | *.GhostDoc.xml
273 |
274 | # Node.js Tools for Visual Studio
275 | .ntvs_analysis.dat
276 | node_modules/
277 |
278 | # Visual Studio 6 build log
279 | *.plg
280 |
281 | # Visual Studio 6 workspace options file
282 | *.opt
283 |
284 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
285 | *.vbw
286 |
287 | # Visual Studio LightSwitch build output
288 | **/*.HTMLClient/GeneratedArtifacts
289 | **/*.DesktopClient/GeneratedArtifacts
290 | **/*.DesktopClient/ModelManifest.xml
291 | **/*.Server/GeneratedArtifacts
292 | **/*.Server/ModelManifest.xml
293 | _Pvt_Extensions
294 |
295 | # Paket dependency manager
296 | .paket/paket.exe
297 | paket-files/
298 |
299 | # FAKE - F# Make
300 | .fake/
301 |
302 | # CodeRush personal settings
303 | .cr/personal
304 |
305 | # Python Tools for Visual Studio (PTVS)
306 | __pycache__/
307 | *.pyc
308 |
309 | # Cake - Uncomment if you are using it
310 | # tools/**
311 | # !tools/packages.config
312 |
313 | # Tabs Studio
314 | *.tss
315 |
316 | # Telerik's JustMock configuration file
317 | *.jmconfig
318 |
319 | # BizTalk build output
320 | *.btp.cs
321 | *.btm.cs
322 | *.odx.cs
323 | *.xsd.cs
324 |
325 | # OpenCover UI analysis results
326 | OpenCover/
327 |
328 | # Azure Stream Analytics local run output
329 | ASALocalRun/
330 |
331 | # MSBuild Binary and Structured Log
332 | *.binlog
333 |
334 | # NVidia Nsight GPU debugger configuration file
335 | *.nvuser
336 |
337 | # MFractors (Xamarin productivity tool) working folder
338 | .mfractor/
339 |
340 | # Local History for Visual Studio
341 | .localhistory/
342 |
343 | # BeatPulse healthcheck temp database
344 | healthchecksdb
345 |
346 | # Backup folder for Package Reference Convert tool in Visual Studio 2017
347 | MigrationBackup/
348 |
349 | # Ionide (cross platform F# VS Code tools) working folder
350 | .ionide/
351 |
--------------------------------------------------------------------------------
/ALLIES/README.md:
--------------------------------------------------------------------------------
1 | # ALLIES
2 |
3 | The code for our paper [ALLIES: Prompting Large Language Model with Beam Search](https://arxiv.org/abs/2305.14766).
4 |
5 | 
6 |
7 | ## Dataset
8 |
9 | ### NQ
10 |
11 | dataset/nq-test.jsonl
12 |
13 | ### TriviaQA
14 |
15 | dataset/tqa-test.jsonl
16 |
17 | ### WebQ
18 |
19 | dataset/webq-test.jsonl
20 |
21 |
22 | ## Released Resources
23 |
24 | We release the preprocessed data and trained checkpoints in [Azure Blob](https://msranlciropen.blob.core.windows.net/simxns/ALLIES/).
25 | The file list under this URL is:
26 |
27 | ```
28 | INFO: nq/de-checkpoint-10000/passage_embedding.pb; Content Length: 60.13 GiB
29 | INFO: nq/de-checkpoint-10000/passage_embedding2id.pb; Content Length: 160.33 MiB
30 | INFO: webq/de-checkpoint-400/passage_embedding.pb; Content Length: 60.13 GiB
31 | INFO: webq/de-checkpoint-400/passage_embedding2id.pb; Content Length: 160.33 MiB
32 | INFO: tq/de-checkpoint-10000/passage_embedding.pb; Content Length: 60.13 GiB
33 | INFO: tq/de-checkpoint-10000/passage_embedding2id.pb; Content Length: 160.33 MiB
34 | ```
35 |
36 | To download the files, please refer to [HOW_TO_DOWNLOAD](https://github.com/microsoft/SimXNS/tree/main/HOW_TO_DOWNLOAD.md).
37 |
38 |
39 |
40 | ## Run
41 |
42 | ### Answer directly (without retrieval)
43 |
44 | ```
45 | python main.py --dataset $dataset --task answer_without_retrieval --apikey $ID
46 | ```
47 |
48 | ### Answer with retrieval
49 |
50 | ```
51 | python main.py --dataset $dataset --task answer_with_retrieval --topK $retrieval_num --apikey $ID
52 | ```
53 |
54 | ### GenRead
55 |
56 | ```
57 | python main.py --dataset $dataset --task genread --apikey $ID
58 | ```
59 |
60 | ### Allies
61 |
62 | ```
63 | ##GENREAD
64 | python main.py --dataset $dataset --task ALLIES --retrieval_type generate --beam_size $beam_size --beam_Depth $beam_depth --ask_question_num $ask_question_num --apikey $ID
65 |
66 | ##Retrieval
67 | python main.py --dataset $dataset --task ALLIES --topK $retrieval_num --retrieval_type retrieve --beam_size $beam_size --beam_Depth $beam_depth --ask_question_num $ask_question_num --apikey $ID
68 | ```
69 |
70 | ## Parameters
71 |
72 | - $dataset: Dataset to test on (NQ, TriviaQA, or WebQ)
73 | - $ID: OpenAI API key
74 | - $beam_size: Beam size
75 | - $beam_depth: Beam depth
76 | - $ask_question_num: Number of questions to ask
77 | - $retrieval_num: Number of documents to retrieve
78 |
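79 | ### Example
80 |
81 | For instance, a hypothetical ALLIES run with retrieval on NQ (the argument values below, including the dataset name, are illustrative rather than tuned defaults) could look like:
82 |
83 | ```
84 | python main.py --dataset nq --task ALLIES --retrieval_type retrieve \
85 |     --topK 5 --beam_size 2 --beam_Depth 2 --ask_question_num 2 --apikey $ID
86 | ```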
--------------------------------------------------------------------------------
/ALLIES/assets/model.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/ALLIES/assets/model.jpg
--------------------------------------------------------------------------------
/ALLIES/embedding_generation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "92874557",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "from tqdm import trange\n",
11 |     "import os, json, pickle\n",
12 |     "import numpy as np"
13 | ]
14 | },
15 | {
16 | "cell_type": "markdown",
17 | "id": "542ce71c",
18 | "metadata": {},
19 | "source": [
20 | "## Step 1: Prepare the training dataset for QA"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "id": "f92b9312",
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "train_data_list = []\n",
31 | "\n",
32 | "for qid in trange(len(dataset)):\n",
33 | " example = {}\n",
34 | " \n",
35 | " positive_ctxs = []\n",
36 | " \n",
37 | " example['question'] = dataset[qid]['question']\n",
38 | " \n",
39 | " example['answers'] = dataset[qid]['answer']\n",
40 | " \n",
41 | " train_data_list.append(example)\n",
42 | "\n",
43 | "with open(f'/{data_path}/biencoder-dataset-train.json', 'w', encoding='utf-8') as json_file:\n",
44 | " json_file.write(json.dumps(train_data_list, indent=4))"
45 | ]
46 | },
47 | {
48 | "cell_type": "markdown",
49 | "id": "e8a951ff",
50 | "metadata": {},
51 | "source": [
52 | "## Step 2: Retrieve the initial documents and generate the training data for dense retrieval"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "id": "37c2f93b",
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "## Please follow the instruction in LEAD (https://github.com/microsoft/SimXNS/tree/main/LEAD)\n",
63 | "\n",
64 | "## bash retrieve_hard_negatives_academic.sh $DATASET $MASTER_PORT $MODEL_PATH $CKPT_NAME $MAX_DOC_LENGTH $MAX_QUERY_LENGTH"
65 | ]
66 | },
67 | {
68 | "cell_type": "markdown",
69 | "id": "9c017de6",
70 | "metadata": {},
71 | "source": [
72 | "## Step 3: Train the retriever using the generated dataset"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "id": "866ad0b7",
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "## Please follow the instruction in LEAD (https://github.com/microsoft/SimXNS/tree/main/LEAD)\n",
83 | "\n",
84 | "## bash train_12_layer_de.sh $DATASET $MASTER_PORT $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES"
85 | ]
86 | },
87 | {
88 | "cell_type": "markdown",
89 | "id": "4d370f43",
90 | "metadata": {},
91 | "source": [
92 | "## Step 4: Output the embedding file using the trained checkpoint"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": null,
98 | "id": "f47e70ec",
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "## Please follow the instruction in LEAD (https://github.com/microsoft/SimXNS/tree/main/LEAD)\n",
103 | "\n",
104 | "## bash evaluate_12_layer_de.sh $DATASET $MASTER_PORT $MODEL_PATH $FIRST_STEPS $EVAL_STEPS $MAX_STEPS $MAX_DOC_LENGTH $MAX_QUERY_LENGTH"
105 | ]
106 | },
107 | {
108 | "cell_type": "markdown",
109 | "id": "ab021dbd",
110 | "metadata": {},
111 | "source": [
112 |     "## Step 5: Merge the inferred embedding files as follows:"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "id": "23acb765",
119 | "metadata": {},
120 | "outputs": [],
121 | "source": [
122 | "passage_embedding_list = []\n",
123 | "passage_embedding_id_list = []\n",
124 | "for i in trange(N): \n",
125 | " pickle_path = os.path.join(f'/{data_path}/', \"{1}_data_obj_{0}.pb\".format(str(i), 'passage_embedding'))\n",
126 | " with open(pickle_path, 'rb') as handle:\n",
127 | " b = pickle.load(handle)\n",
128 | " passage_embedding_list.append(b)\n",
129 | "for i in trange(N): \n",
130 | " pickle_path = os.path.join(f'/{data_path}/', \"{1}_data_obj_{0}.pb\".format(str(i), 'passage_embedding_id'))\n",
131 | " with open(pickle_path, 'rb') as handle:\n",
132 | " b = pickle.load(handle)\n",
133 | " passage_embedding_id_list.append(b)\n",
134 | "passage_embedding = np.concatenate(passage_embedding_list, axis=0)\n",
135 | "passage_embedding_id = np.concatenate(passage_embedding_id_list, axis=0)\n",
136 | "\n",
137 | "with open(f'/{data_path}/passage_embedding.pb', 'wb') as f:\n",
138 | " pickle.dump(passage_embedding, f)\n",
139 | "f.close()\n",
140 | "\n",
141 | "with open(f'/{data_path}/passage_embedding2id.pb', 'wb') as f:\n",
142 | " pickle.dump(passage_embedding_id, f)\n",
143 | "f.close()"
144 | ]
145 | }
146 | ],
147 | "metadata": {
148 | "kernelspec": {
149 | "display_name": "py37",
150 | "language": "python",
151 | "name": "py37"
152 | },
153 | "language_info": {
154 | "codemirror_mode": {
155 | "name": "ipython",
156 | "version": 3
157 | },
158 | "file_extension": ".py",
159 | "mimetype": "text/x-python",
160 | "name": "python",
161 | "nbconvert_exporter": "python",
162 | "pygments_lexer": "ipython3",
163 | "version": "3.10.9"
164 | }
165 | },
166 | "nbformat": 4,
167 | "nbformat_minor": 5
168 | }
169 |
--------------------------------------------------------------------------------
/ALLIES/utils.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import json
3 | import time
4 | from tqdm import trange
5 | import regex
6 | import json
7 | import string
8 | import unicodedata
9 | from typing import List
10 | import numpy as np
11 | from collections import Counter
12 | from rouge import Rouge
13 | import json
14 | import openai
15 | import backoff
16 | import os
17 | from multiprocessing.pool import ThreadPool
18 | import threading
19 | import time
20 | import datetime
21 | from msal import PublicClientApplication, SerializableTokenCache
22 | import json
23 | import os
24 | import atexit
25 | import requests
26 | from utils import *
27 | from retrieval_utils import *
28 | import _thread
29 | from contextlib import contextmanager
30 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
31 |
32 |
33 | import signal
34 |
35 | cache = {}
36 | default_engine = None
37 |
38 | from transformers import AutoTokenizer, AutoModel
39 |
40 | class SimpleTokenizer(object):
41 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
42 | NON_WS = r'[^\p{Z}\p{C}]'
43 |
44 | def __init__(self):
45 | """
46 | Args:
47 | annotators: None or empty set (only tokenizes).
48 | """
49 | self._regexp = regex.compile(
50 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
51 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
52 | )
53 |
54 | def tokenize(self, text, uncased=False):
55 | matches = [m for m in self._regexp.finditer(text)]
56 | if uncased:
57 | tokens = [m.group().lower() for m in matches]
58 | else:
59 | tokens = [m.group() for m in matches]
60 | return tokens
61 |
62 | def readfiles(infile):
63 |
64 | if infile.endswith('json'):
65 | lines = json.load(open(infile, 'r', encoding='utf8'))
66 | elif infile.endswith('jsonl'):
67 | lines = open(infile, 'r', encoding='utf8').readlines()
68 | lines = [json.loads(l) for l in lines]
69 | else:
70 | raise NotImplementedError
71 |
72 | if len(lines[0]) == 1 and lines[0].get('prompt'):
73 | lines = lines[1:] ## skip prompt line
74 |
75 | return lines
76 |
77 | def write(log_file, text):
78 | log_file.write(text + '\n')
79 | log_file.flush()
80 |
81 | def create_directory(path):
82 | try:
83 | os.makedirs(path)
84 | print(f"Created directory: {path}")
85 | except FileExistsError:
86 | print(f"Directory already exists: {path}")
87 |
88 | def sparse_retrieval(searcher_sparse, query, top_K):
89 | hits_sparse = searcher_sparse.search(query)
90 | result = []
91 | for i in range(top_K):
92 | result.append([field.stringValue() for field in hits_sparse[i].lucene_document.getFields()][1])
93 | return result
94 |
95 | def retrieval(questions, top_k, tokenizer, model, cpu_index, passage_embedding2id, passages, args):
96 | question_embedding = get_question_embeddings(questions, tokenizer, model, args)
97 | _, dev_I = cpu_index.search(question_embedding.astype(np.float32), top_k) # I: [number of queries, topk]
98 | topk_document = [passage_embedding2id[dev_I[0][index]] for index in range(len(dev_I[0]))]
99 |
100 | return [passages[doc][1] for doc in topk_document]
101 |
102 | def check_answer(example, tokenizer) -> List[bool]:
103 | """Search through all the top docs to see if they have any of the answers."""
104 | answers = example['answers']
105 | ctxs = example['ctxs']
106 |
107 | hits = []
108 |
109 | for _, doc in enumerate(ctxs):
110 | text = doc['text']
111 |
112 | if text is None: # cannot find the document for some reason
113 | hits.append(False)
114 | continue
115 |
116 | hits.append(has_answer(answers, text, tokenizer))
117 |
118 | return hits
119 |
120 | def has_answer(answers, text, tokenizer=SimpleTokenizer()) -> bool:
121 | """Check if a document contains an answer string."""
122 | text = _normalize(text)
123 | text = tokenizer.tokenize(text, uncased=True)
124 |
125 | for answer in answers:
126 | answer = _normalize(answer)
127 | answer = tokenizer.tokenize(answer, uncased=True)
128 | for i in range(0, len(text) - len(answer) + 1):
129 | if answer == text[i: i + len(answer)]:
130 | return True
131 | return False
132 |
133 | def _normalize(text):
134 | return unicodedata.normalize('NFD', text)
135 |
136 | def normalize_answer(s):
137 | def remove_articles(text):
138 | return regex.sub(r'\b(a|an|the)\b', ' ', text)
139 |
140 | def white_space_fix(text):
141 | return ' '.join(text.split())
142 |
143 | def remove_punc(text):
144 | exclude = set(string.punctuation)
145 | return ''.join(ch for ch in text if ch not in exclude)
146 |
147 | def lower(text):
148 | return text.lower()
149 |
150 | return white_space_fix(remove_articles(remove_punc(lower(s))))
151 |
152 | def exact_match_score(prediction, ground_truth):
153 | return normalize_answer(prediction) == normalize_answer(ground_truth)
154 |
155 | def ems(prediction, ground_truths):
156 | return max([exact_match_score(prediction, gt) for gt in ground_truths])
157 |
158 | def f1_score(prediction, ground_truth):
159 | prediction_tokens = normalize_answer(prediction).split()
160 | ground_truth_tokens = normalize_answer(ground_truth).split()
161 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
162 | num_same = sum(common.values())
163 | if num_same == 0:
164 | return 0
165 | precision = 1.0 * num_same / len(prediction_tokens)
166 | recall = 1.0 * num_same / len(ground_truth_tokens)
167 | f1 = (2 * precision * recall) / (precision + recall)
168 | return f1
169 |
170 | def f1(prediction, ground_truths):
171 | return max([f1_score(prediction, gt) for gt in ground_truths])
172 |
173 | def rougel_score(prediction, ground_truth):
174 | rouge = Rouge()
175 | # no normalization
176 | try:
177 | scores = rouge.get_scores(prediction, ground_truth, avg=True)
178 | except ValueError: # "Hypothesis is empty."
179 | return 0.0
180 | return scores["rouge-l"]["f"]
181 |
182 | def rl(prediction, ground_truths):
183 | return max([rougel_score(prediction, gt) for gt in ground_truths])
184 |
185 | def run_inference_openai(inputs_with_prompts):
186 | for _ in range(200):
187 | try:
188 | completions = openai.ChatCompletion.create(
189 | model="gpt-3.5-turbo",
190 | messages=[
191 | {"role": "system", "content": "You are a helpful AI assistant."},
192 | {"role": "user", "content": inputs_with_prompts},
193 | ],
194 | temperature=0,
195 |             )
196 |             break  # stop retrying once a response is returned
197 |         except Exception:
198 |             continue
199 | outputs = [c["message"]['content'] for c in completions["choices"]]
200 | return outputs, completions['usage']['total_tokens']
201 |
202 | def print_args(args):
203 | print("=================================================================")
204 | print("======================General Setting=========================")
205 | print(f"Dataset:".rjust(30) + " " + str(args.dataset))
206 | print(f"Data Path:".rjust(30) + " " + str(args.data_path))
207 | print(f"Task:".rjust(30) + " " + str(args.task))
208 | print(f"Retrieval TopK:".rjust(30) + " " + str(args.topK))
209 | print(f"Beam Size:".rjust(30) + " " + str(args.beam_size))
210 | print(f"Beam Depth:".rjust(30) + " " + str(args.beam_Depth))
211 | print(f"Ask Question Num:".rjust(30) + " " + str(args.ask_question_num))
212 | print(f"Threshold:".rjust(30) + " " + str(args.threshold))
213 | print(f"Device:".rjust(30) + " " + str(args.device))
214 | print(f"Summary:".rjust(30) + " " + str(args.summary))
215 | print(f"Retrieval Type:".rjust(30) + " " + str(args.retrieval_type))
216 | print(f"Identifier:".rjust(30) + " " + str(args.unique_identifier))
217 | print(f"Save File:".rjust(30) + " " + str(args.save_file))
218 | print("=================================================================")
219 | print("=================================================================")
220 |
221 | def set_openai(args):
222 | openai.api_key = args.apikey
223 |
--------------------------------------------------------------------------------
/CAPSTONE/README.md:
--------------------------------------------------------------------------------
1 | # CAPSTONE
2 | This repo provides the code and models in [*CAPSTONE*](https://arxiv.org/abs/2212.09114).
3 | In the paper, we propose CAPSTONE, a curriculum sampling strategy for dense retrieval with document expansion, which bridges the gap between training and inference for the dual-cross-encoder.
4 |
5 |
6 | # Requirements
7 | ```
8 | pip install datasets==2.4.0
9 | pip install rouge_score==0.1.2
10 | pip install nltk
11 | pip install transformers==4.21.1
12 | conda install -y -c pytorch faiss-gpu==1.7.1
13 | pip install tensorboard
14 | pip install pytrec-eval
15 | ```
16 | # Get Data
17 | Download the cleaned corpus hosted by the RocketQA team and generate BM25 negatives for MS-MARCO. Then, download [TREC-Deep-Learning-2019](https://microsoft.github.io/msmarco/TREC-Deep-Learning-2019) (TREC DL 19) and [TREC-Deep-Learning-2020](https://microsoft.github.io/msmarco/TREC-Deep-Learning-2020) (TREC DL 20).
18 |
19 | ```
20 | bash get_msmarco.sh
21 | cd ..
22 | bash get_trec.sh
23 | python preprocess/preprocess_msmarco.py --data_dir ./marco --output_path ./reformed_marco
24 | python preprocess/preprocess_trec.py
25 | ```
26 |
27 | Download the generated queries for MS-MARCO and merge duplicated queries. Note the query file is around 19GB.
28 | ```
29 | bash get_doc2query_marco.sh
30 | python preprocess/merge_query.py
31 | ```
32 |
33 | # Train CAPSTONE on MS-MARCO
34 |
35 | Train CAPSTONE on MS-MARCO in two stages, initializing the retriever with coCondenser at each stage.
36 | In the first stage, the hard negatives are sampled from the official BM25 negatives; in the second stage, they are sampled from the hard negatives mined by the first-stage retriever.
37 |
38 | Then evaluate CAPSTONE on the MS-MARCO development set and the TREC DL 19 and 20 test sets.
39 | ```bash
40 | bash run_de_model_expand_corpus_cocondenser.sh
41 | bash run_de_model_expand_corpus_cocondenser_step2.sh
42 | ```
43 | To evaluate CAPSTONE on the BEIR benchmark, first download the [BEIR](https://github.com/beir-cellar/beir) datasets and generate queries for them:
44 | ```bash
45 | bash generate_query.sh
46 | ```
47 |
48 | # Well-trained Checkpoints
49 | | Model | Download link |
50 | |----------------------|--------|
51 | | The checkpoint for CAPSTONE trained on MS-MARCO at the first step| [\[link\]](https://drive.google.com/file/d/1QTsHQV8BQDJmxGD--fr4hmYdjpizE0zq/view?usp=sharing) |
52 | | The checkpoint for CAPSTONE trained on MS-MARCO at the second step| [\[link\]](https://drive.google.com/file/d/1tssOGIRwwXpn2yg4StiRC3B3u0_UqzLG/view?usp=sharing) |
53 |
54 |
55 |
56 |
57 |
58 | # Citation
59 | If you want to use this code in your research, please cite our [paper](https://arxiv.org/abs/2212.09114):
60 | ```bibtex
61 | @inproceedings{he-CAPSTONE,
62 | title = "CAPSTONE: Curriculum Sampling for Dense Retrieval with Document Expansion",
63 | author = "Xingwei He and Yeyun Gong and A-Long Jin and Hang Zhang and Anlei Dong and Jian Jiao and Siu Ming Yiu and Nan Duan",
64 | booktitle = "Proceedings of EMNLP",
65 | year = "2023",
66 | }
67 | ```
--------------------------------------------------------------------------------
/CAPSTONE/generate_query.sh:
--------------------------------------------------------------------------------
1 | # generate queries for beir data
2 | for dataset in cqadupstack/android cqadupstack/english cqadupstack/gaming cqadupstack/gis cqadupstack/mathematica cqadupstack/physics cqadupstack/programmers cqadupstack/stats cqadupstack/tex cqadupstack/unix cqadupstack/webmasters cqadupstack/wordpress \
3 | quora robust04 trec-news nq signal1m dbpedia-entity \
4 | nfcorpus scifact arguana scidocs fiqa trec-covid webis-touche2020 \
5 | bioasq hotpotqa fever climate-fever
6 | do
7 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9589 \
8 | ./models/generate_query.py --dataset $dataset
9 | done
10 |
--------------------------------------------------------------------------------
/CAPSTONE/get_doc2query_marco.sh:
--------------------------------------------------------------------------------
1 |
2 | # get the generated query from https://github.com/castorini/docTTTTTquery
3 | mkdir tmp
4 | mkdir docTTTTTquery_full
5 | cd tmp
6 | wget https://git.uwaterloo.ca/jimmylin/doc2query-data/raw/master/T5-passage/predicted_queries_topk_sampling.zip
7 |
8 | unzip predicted_queries_topk_sampling.zip
9 |
10 | for i in $(seq -f "%03g" 0 17); do
11 | echo "Processing chunk $i"
12 | paste predicted_queries_topk_sample???.txt${i}-1004000 \
13 | > predicted_queries_topk.txt${i}-1004000
14 | done
15 |
16 | cat predicted_queries_topk.txt???-1004000 > doc2query.tsv
17 | mv doc2query.tsv ../docTTTTTquery_full
18 | cd ..
19 | rm -rf tmp
20 |
--------------------------------------------------------------------------------
/CAPSTONE/get_msmarco.sh:
--------------------------------------------------------------------------------
1 | # this shell script is from https://github.com/NLPCode/tevatron/blob/main/examples/coCondenser-marco/get_data.sh
2 | SCRIPT_DIR=$PWD
3 |
4 | wget --no-check-certificate https://rocketqa.bj.bcebos.com/corpus/marco.tar.gz
5 | tar -zxf marco.tar.gz
6 | rm -rf marco.tar.gz
7 | cd marco
8 |
9 | wget https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz
10 | wget https://msmarco.blob.core.windows.net/msmarcoranking/qrels.train.tsv -O qrels.train.tsv
11 | gunzip qidpidtriples.train.full.2.tsv.gz
12 | join -t "$(echo -en '\t')" -e '' -a 1 -o 1.1 2.2 1.2 <(sort -k1,1 para.txt) <(sort -k1,1 para.title.txt) | sort -k1,1 -n > corpus.tsv
13 | awk -v RS='\r\n' '$1==last {printf ",%s",$3; next} NR>1 {print "";} {last=$1; printf "%s\t%s",$1,$3;} END{print "";}' qidpidtriples.train.full.2.tsv > train.negatives.tsv
14 |
15 |
16 |
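17 | # Note: after this script finishes, ./marco should contain, in addition to the files shipped in
18 | # marco.tar.gz, the files consumed by preprocess/preprocess_msmarco.py:
19 | #   corpus.tsv            # psg_id <TAB> title <TAB> text, built by the join above
20 | #   train.negatives.tsv   # qid <TAB> comma-separated BM25 negative passage ids
21 | #   qrels.train.tsv       # official MS-MARCO train qrels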
--------------------------------------------------------------------------------
/CAPSTONE/get_trec.sh:
--------------------------------------------------------------------------------
1 | # this shell script is from https://github.com/NLPCode/tevatron/blob/main/examples/coCondenser-marco/get_data.sh
2 | SCRIPT_DIR=$PWD
3 |
4 | #https://microsoft.github.io/msmarco/TREC-Deep-Learning-2019
5 | mkdir trec_19
6 | cd trec_19
7 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz
8 | gunzip msmarco-test2019-queries.tsv.gz
9 | rm msmarco-test2019-queries.tsv.gz
10 | wget --no-check-certificate https://trec.nist.gov/data/deep/2019qrels-pass.txt
11 |
12 | cd ..
13 |
14 | #https://microsoft.github.io/msmarco/TREC-Deep-Learning-2020
15 | mkdir trec_20
16 | cd trec_20
17 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2020-queries.tsv.gz
18 | gunzip msmarco-test2020-queries.tsv.gz
19 | rm msmarco-test2020-queries.tsv.gz
20 | wget --no-check-certificate https://trec.nist.gov/data/deep/2020qrels-pass.txt
--------------------------------------------------------------------------------
/CAPSTONE/merge_beir_result.py:
--------------------------------------------------------------------------------
1 | """
2 | this file is used to merge the beir results of each dataset.
3 | """
4 | import argparse
5 | import os
6 | import json
7 | corpus_list = ['trec-covid', 'bioasq', 'nfcorpus', 'nq', 'hotpotqa', 'fiqa', 'signal1m', 'trec-news', 'robust04', 'arguana',
8 | 'webis-touche2020', 'cqadupstack', 'quora', 'dbpedia-entity', 'scidocs', 'fever', 'climate-fever', 'scifact' ]
9 | corpus_list2 = ["cqadupstack/android", "cqadupstack/english", "cqadupstack/gaming", "cqadupstack/gis", "cqadupstack/mathematica", "cqadupstack/physics",
10 | "cqadupstack/programmers", "cqadupstack/stats", "cqadupstack/tex", "cqadupstack/unix", "cqadupstack/webmasters", "cqadupstack/wordpress"]
11 |
12 |
13 | if __name__ == "__main__":
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("--beir_data_path", default=None, type=str, help="path of a dataset in beir")
16 | parser.add_argument("--checkpoint_num", default=20000, type=int)
17 | args = parser.parse_args()
18 |     print(os.path.join(args.beir_data_path, f'test_eval_result{args.checkpoint_num}.txt'))
19 | with open(os.path.join(args.beir_data_path, f'test_eval_result{args.checkpoint_num}.txt'), 'w') as fw:
20 | fw.write("NDCG@10\n")
21 | # fw.write('K=10\n')
22 | # total_value = 0
23 | # for corpus in corpus_list:
24 | # if corpus=='cqadupstack':
25 | # total = 0
26 | # for subcorpus in corpus_list2:
27 | # filename = os.path.join(args.beir_data_path, f'{subcorpus}', f'test_eval_result{args.checkpoint_num}_0_query.json')
28 | # with open(filename, 'r') as fr:
29 | # results = json.load(fr)
30 | # total += float(results['NDCG@10'])
31 | # value = total/len(corpus_list2)
32 | # total_value += value
33 | # value = str(value)
34 |
35 | # else:
36 | # filename = os.path.join(args.beir_data_path, f'{corpus}', f'test_eval_result{args.checkpoint_num}_0_query.json')
37 | # if os.path.exists(filename):
38 | # with open(filename, 'r') as fr:
39 | # results = json.load(fr)
40 | # value = results['NDCG@10']
41 | # else:
42 | # value ='0'
43 | # print(f'{corpus} no results.')
44 | # total_value += float(value)
45 | # fw.write(f'{corpus}: {value:.3}\n')
46 | # fw.write(f'Average: {total_value/len(corpus_list):.3}\n')
47 |
48 | # fw.write('K=5\n')
49 | # total_value = 0
50 | # for corpus in corpus_list:
51 | # if corpus=='cqadupstack':
52 | # total = 0
53 | # for subcorpus in corpus_list2:
54 | # # filename = os.path.join(args.beir_data_path, f'{subcorpus}', f'test_eval_result{args.checkpoint_num}.json')
55 | # filename = os.path.join(args.beir_data_path, f'{subcorpus}', f'test_eval_result{args.checkpoint_num}_0_5_query.json')
56 | # with open(filename, 'r') as fr:
57 | # results = json.load(fr)
58 | # total += float(results['NDCG@10'])
59 | # value = total/len(corpus_list2)
60 | # total_value += value
61 | # value = str(value)
62 |
63 | # else:
64 | # # filename = os.path.join(args.beir_data_path, f'{corpus}', f'test_eval_result{args.checkpoint_num}.json')
65 | # filename = os.path.join(args.beir_data_path, f'{corpus}', f'test_eval_result{args.checkpoint_num}_0_5_query.json')
66 | # if os.path.exists(filename):
67 | # with open(filename, 'r') as fr:
68 | # results = json.load(fr)
69 | # value = results['NDCG@10']
70 | # else:
71 | # value ='0'
72 | # print(f'{corpus} no results.')
73 | # total_value += float(value)
74 | # fw.write(f'{corpus}: {value:.3}\n')
75 | # fw.write(f'Average: {total_value/len(corpus_list):.3}\n')
76 |
77 |
78 | fw.write('base\n')
79 | total_value = 0
80 | for corpus in corpus_list:
81 | if corpus=='cqadupstack':
82 | total = 0
83 | for subcorpus in corpus_list2:
84 | # filename = os.path.join(args.beir_data_path, f'{subcorpus}', f'test_eval_result{args.checkpoint_num}.json')
85 | filename = os.path.join(args.beir_data_path, f'{subcorpus}', f'test_eval_result{args.checkpoint_num}.json')
86 | with open(filename, 'r') as fr:
87 | results = json.load(fr)
88 | total += float(results['NDCG@10'])
89 | value = total/len(corpus_list2)
90 | total_value += value
91 | value = str(value)
92 |
93 | else:
94 | # filename = os.path.join(args.beir_data_path, f'{corpus}', f'test_eval_result{args.checkpoint_num}.json')
95 | filename = os.path.join(args.beir_data_path, f'{corpus}', f'test_eval_result{args.checkpoint_num}.json')
96 | if os.path.exists(filename):
97 | with open(filename, 'r') as fr:
98 | results = json.load(fr)
99 | value = results['NDCG@10']
100 | else:
101 | value ='0'
102 | print(f'{corpus} no results.')
103 | total_value += float(value)
104 | fw.write(f'{corpus}: {value:.3}\n')
105 | fw.write(f'Average: {total_value/len(corpus_list):.3}\n')
106 |
--------------------------------------------------------------------------------
/CAPSTONE/models/merge_query.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | This file merges the queries produced by the different query generators
5 | (one generator per training split, see num_split) into a single file.
6 | Output: a json file that maps each passage_id to its merged generated_queries.
7 | """
8 | from dataclasses import dataclass, field
9 | import csv
10 | import pickle
11 |
12 | from transformers import (
13 | HfArgumentParser,
14 | )
15 |
16 | import os
17 | import sys
18 | import json
19 | import random
20 | from tabulate import tabulate
21 | from tqdm import tqdm
22 | import logging
23 |
24 | pyfile_path = os.path.abspath(__file__)
25 | pyfile_dir = os.path.dirname(os.path.dirname(pyfile_path)) # equals to the path '../'
26 | sys.path.append(pyfile_dir)
27 | from utils.util import normalize_question, set_seed, get_optimizer, sum_main
28 | from utils.data_utils import load_passage
29 |
30 | # logger = logging.getLogger(__name__)
31 | logger = logging.getLogger("__main__")
32 |
33 | @dataclass
34 | class Arguments:
35 | dir_name: str = field(
36 | default="",
37 | metadata={
38 | "help": (
39 | "The path containing the generated queries."
40 | )
41 | },
42 | )
43 |
44 | seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
45 |
46 | # for cross train
47 | num_split: int = field(default=1,
48 | metadata={"help": "Split the train set into num_split parts. "
49 | "Note There is no overlap between different parts in terms of the positive doc."
50 | "Suppose num_split=3, we need to train three different generators. "
51 |                                     "The first generator is trained on data_2 and data_3, "
52 | "the second is trained on the data_1 and data_3, the third is trained on data_1 and data_2."})
53 |
54 |
55 |
56 |
57 | def main():
58 | parser = HfArgumentParser((Arguments))
59 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
60 | # If we pass only one argument to the script and it's the path to a json file,
61 | # let's parse it to get our arguments.
62 | args, = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
63 | else:
64 | args, = parser.parse_args_into_dataclasses()
65 | print(args)
66 | removed_passage_ids_list = []
67 | # removed_passage_id_path = os.path.join('checkpoints', args.dir_name, f'0_th_generator/removed_passge_id_for_0_th_generator.json')
68 | # removed_passage_id_set = set(json.load(open(removed_passage_id_path, 'r', encoding='utf-8')))
69 |
70 | # removed_passage_id_path = os.path.join('checkpoints', args.dir_name, f'1_th_generator/removed_passge_id_for_1_th_generator.json')
71 | # removed_passage_id_set2 = set(json.load(open(removed_passage_id_path, 'r', encoding='utf-8')))
72 | # print(len(removed_passage_id_set), len(removed_passage_id_set2), len(removed_passage_id_set|removed_passage_id_set2), len(removed_passage_id_set&removed_passage_id_set2))
73 | # exit()
74 | union_set = set()
75 | for i in range(args.num_split):
76 | # query_path = os.path.join(args.dir_name, f'{i}_th_generator/corpus/top_k_10_rs_10.json')
77 | removed_passage_id_path = os.path.join('checkpoints', args.dir_name, f'{i}_th_generator/removed_passge_id_for_{i}_th_generator.json')
78 | print(os.path.exists(removed_passage_id_path))
79 | removed_passage_ids = json.load(open(removed_passage_id_path, 'r', encoding='utf-8'))
80 | removed_passage_ids_list.append(removed_passage_ids)
81 | union_set = union_set | set(removed_passage_ids)
82 |
83 |
84 | final_queries_dict = {}
85 | for i in range(args.num_split):
86 | query_path = os.path.join('checkpoints', args.dir_name, f'{i}_th_generator/corpus/top_k_10_rs_10.json')
87 | removed_passage_ids = set(removed_passage_ids_list[i])
88 | queries_list = json.load(open(query_path, 'r', encoding='utf-8'))
89 | for e in queries_list:
90 | passage_id = e['passage_id']
91 | generated_queries = e['generated_queries']
92 | if passage_id in removed_passage_ids or passage_id not in union_set:
93 | final_queries_dict[passage_id] = final_queries_dict.get(passage_id, []) + generated_queries
94 | merge_output_data = []
95 | for k, v in final_queries_dict.items():
96 | merge_output_data.append({'passage_id': k, 'generated_queries':v})
97 | sorted_merge_output_data = sorted(merge_output_data, key=lambda x: int(x['passage_id']))
98 | with open(os.path.join('checkpoints', args.dir_name, f'top_k_10_rs_10.json'), 'w', encoding='utf-8') as fw:
99 | json.dump(sorted_merge_output_data, fw, indent=2)
100 |
101 |
102 | if __name__ == "__main__":
103 | main()
--------------------------------------------------------------------------------
/CAPSTONE/preprocess/merge_query.py:
--------------------------------------------------------------------------------
1 | import json
2 | query_filename='./docTTTTTquery_full/doc2query.tsv'
3 |
4 | # reform
5 | min_q=80
6 | with open(query_filename, 'r', encoding='utf-8') as fr, open('doc2query_merge.tsv', 'w', encoding='utf-8') as fw:
7 | for psg_id, line in enumerate(fr):
8 | example = line.strip().split('\t')
9 | example = list(set([e.strip() for e in example]))
10 | min_q=min(min_q, len(example))
11 | output_data = '\t'.join([str(psg_id)]+example)
12 | #output_data = {'passage_id': str(psg_id), 'generated_queries': example}
13 | fw.write(output_data+'\n')
14 | print(min_q)
15 |
16 |
--------------------------------------------------------------------------------
/CAPSTONE/preprocess/preprocess_msmarco.py:
--------------------------------------------------------------------------------
1 | """
2 | This script converts the MS-MARCO data into the DPR-style biencoder json format and generates the qa csv files.
3 | """
4 | from tqdm import tqdm
5 | import json
6 | import csv
7 | import datasets
8 | from dataclasses import dataclass
9 | from argparse import ArgumentParser
10 | import os
11 | import random
12 | from tqdm import tqdm
13 | from datetime import datetime
14 | from collections import OrderedDict
15 | import logging
16 | # logger = logging.getLogger("__main__")
17 | logger = logging.getLogger(__name__)
18 | @dataclass
19 | class MSMARCOPreProcessor:
20 | data_dir: str
21 | split: str
22 |
23 | columns = ['text_id', 'title', 'text']
24 | title_field = 'title'
25 | text_field = 'text'
26 |
27 | def __post_init__(self):
28 | assert self.split in ['train', 'dev']
29 | self.query_file = os.path.join(self.data_dir, f'{self.split}.query.txt')
30 | self.queries = self.read_queries(self.query_file)
31 |
32 |
33 | self.collection_file = os.path.join(self.data_dir, f'corpus.tsv')
34 | print(f'Loading passages from: {self.collection_file}')
35 | self.collection = datasets.load_dataset(
36 | 'csv',
37 | data_files=self.collection_file,
38 | column_names=self.columns,
39 | delimiter='\t',
40 | )['train']
41 |         print(f'the corpus has {len(self.collection)} passages.')
42 |
43 | self.relevance_file = os.path.join(self.data_dir, f'qrels.{self.split}.tsv')
44 | self.query_2_pos_doc_dict = self.read_positive(self.relevance_file)
45 | if self.split =='train':
46 | self.negative_file = os.path.join(self.data_dir, f'train.negatives.tsv')
47 | self.query_2_neg_doc_dict = self.read_negative(self.negative_file)
48 | else:
49 | self.negative_file = None
50 | self.query_2_neg_doc_dict = None
51 |
52 | @staticmethod
53 | def read_negative(negative_file):
54 | print(f'Load the negative docs for queries from {negative_file}.')
55 | query_2_neg_doc_dict = {}
56 | with open(negative_file, 'r', encoding='utf8') as fr:
57 | for line in fr:
58 | query_id, negative_doc_id_list = line.strip().split('\t')
59 | negative_doc_id_list = negative_doc_id_list.split(',')
60 | random.shuffle(negative_doc_id_list)
61 | query_2_neg_doc_dict[query_id] = negative_doc_id_list
62 | print(f'{len(query_2_neg_doc_dict)} queries have negative docs.')
63 | return query_2_neg_doc_dict
64 |
65 | @staticmethod
66 | def read_queries(queries):
67 | print(f'Load the query from {queries}.')
68 | qmap = OrderedDict()
69 | with open(queries, 'r', encoding='utf8') as f:
70 | for l in f:
71 | qid, qry = l.strip().split('\t')
72 | qmap[qid] = qry
73 |         print(f'Loaded {len(qmap)} queries.')
74 | return qmap
75 |
76 |
77 | @staticmethod
78 | def read_positive(relevance_file):
79 | print(f'Load the positive docs for queries from {relevance_file}.')
80 | query_2_pos_doc_dict = {}
81 | with open(relevance_file, 'r', encoding='utf8') as f:
82 | tsvreader = csv.reader(f, delimiter="\t")
83 | if 'train' in relevance_file:
84 | print('train')
85 | for [query_id, _, doc_id, rel] in tsvreader:
86 | if query_id in query_2_pos_doc_dict:
87 | assert rel == "1"
88 | query_2_pos_doc_dict[query_id].append(doc_id)
89 | else:
90 | query_2_pos_doc_dict[query_id] = [doc_id]
91 | else: # dev
92 | print('dev')
93 | for [query_id, doc_id] in tsvreader:
94 | if query_id in query_2_pos_doc_dict:
95 | query_2_pos_doc_dict[query_id].append(doc_id)
96 | else:
97 | query_2_pos_doc_dict[query_id] = [doc_id]
98 | print(f'{len(query_2_pos_doc_dict)} queries have positive docs.')
99 | return query_2_pos_doc_dict
100 |
101 |
102 | def reform_data(self, n_sample, output_path):
103 | output_list = []
104 | for query_id in self.queries:
105 | positive_doc_id_list = self.query_2_pos_doc_dict[query_id]
106 | if self.query_2_neg_doc_dict is not None and query_id in self.query_2_neg_doc_dict:
107 | negative_doc_id_list = self.query_2_neg_doc_dict[query_id][:n_sample]
108 | else:
109 | negative_doc_id_list = []
110 |
111 | q_str = self.queries[query_id]
112 | q_answer = 'nil'
113 | positive_ctxs = []
114 | negative_ctxs = []
115 | for passage_id in negative_doc_id_list:
116 | entry = self.collection[int(passage_id)]
117 | title = entry[self.title_field]
118 | title = "" if title is None else title
119 | body = entry[self.text_field]
120 |
121 | negative_ctxs.append(
122 | {'title': title, 'text': body, 'passage_id': passage_id, 'score': 'nil'}
123 | )
124 |
125 | for passage_id in positive_doc_id_list:
126 | entry = self.collection[int(passage_id)]
127 | title = entry[self.title_field]
128 | title = "" if title is None else title
129 | body = entry[self.text_field]
130 |
131 | positive_ctxs.append(
132 | {'title': title, 'text': body, 'passage_id': passage_id, 'score': 'nil'}
133 | )
134 |
135 | output_list.append(
136 | {
137 | "q_id": query_id, "question": q_str, "answers": q_answer, "positive_ctxs": positive_ctxs,
138 | "hard_negative_ctxs": negative_ctxs, "negative_ctxs": []
139 | }
140 | )
141 |
142 |
143 | with open(output_path, 'w', encoding='utf8') as f:
144 | json.dump(output_list, f, indent=2)
145 |
146 |
147 | def generate_qa_files(self, output_path):
148 | with open(output_path, 'w', encoding='utf8') as fw:
149 | for q_id, query in self.queries.items():
150 | positive_doc_id_list = self.query_2_pos_doc_dict[q_id]
151 | fw.write(f"{query}\t{json.dumps(positive_doc_id_list)}\t{q_id}\n")
152 |
153 | def reform_corpus(self, output_path):
154 | # original format: psg_id, title, text
155 |         # reformed format: psg_id, text, title (to be consistent with nq and tq)
156 | with open(self.collection_file, 'r', encoding='utf8') as fr, open(output_path, 'w', encoding='utf8') as fw:
157 | for line in fr:
158 | l = line.strip().split('\t')
159 | assert len(l) == 3
160 | fw.write(f'{l[0]}\t{l[2]}\t{l[1]}\n')
161 |
162 |
163 | if __name__ == '__main__':
164 | random.seed(0)
165 | parser = ArgumentParser()
166 | parser.add_argument('--data_dir', type=str, default='./marco')
167 | parser.add_argument('--output_path', type=str, default='./reformed_marco')
168 | parser.add_argument('--n_sample', type=int, default=30, help='number of selected negative examples.')
169 |
170 |
171 | args = parser.parse_args()
172 | os.makedirs(args.output_path, exist_ok=True)
173 |
174 | processor = MSMARCOPreProcessor(
175 | data_dir=args.data_dir,
176 | split='train'
177 | )
178 | processor.reform_data(args.n_sample, os.path.join(args.output_path, 'biencoder-marco-train.json'))
179 | processor.generate_qa_files(os.path.join(args.output_path, 'marco-train.qa.csv'))
180 |
181 | processor = MSMARCOPreProcessor(
182 | data_dir=args.data_dir,
183 | split='dev'
184 | )
185 | processor.reform_corpus( os.path.join(args.output_path, 'corpus.tsv'))
186 | processor.reform_data(args.n_sample, os.path.join(args.output_path, 'biencoder-marco-dev.json'))
187 | processor.generate_qa_files(os.path.join(args.output_path, 'marco-dev.qa.csv'))
--------------------------------------------------------------------------------
/CAPSTONE/preprocess/preprocess_trec.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import json
3 | import csv
4 | import datasets
5 | from dataclasses import dataclass
6 | from argparse import ArgumentParser
7 | import os
8 | import random
9 | from tqdm import tqdm
10 | from datetime import datetime
11 | from collections import OrderedDict
12 | import logging
13 |
14 | def read_queries(queries):
15 | print(f'Load the query from {queries}.')
16 | qmap = OrderedDict()
17 | with open(queries, 'r', encoding='utf8') as f:
18 | for l in f:
19 | qid, qry = l.strip().split('\t')
20 | qmap[qid] = qry
21 |     print(f'Loaded {len(qmap)} queries.')
22 | return qmap
23 |
24 |
25 | def read_positive(relevance_file):
26 | print(f'Load the positive docs for queries from {relevance_file}.')
27 | query_2_pos_doc_dict = {}
28 | with open(relevance_file, 'r', encoding='utf8') as fr:
29 | for line in fr:
30 | query_id, _, doc_id, _ = line.split()
31 | if query_id in query_2_pos_doc_dict:
32 | query_2_pos_doc_dict[query_id].append(doc_id)
33 | else:
34 | query_2_pos_doc_dict[query_id] = [doc_id]
35 | print(f'{len(query_2_pos_doc_dict)} queries have positive docs.')
36 | return query_2_pos_doc_dict
37 |
38 | def generate_qa_files(output_path, queries, query_2_pos_doc_dict):
39 | with open(output_path, 'w', encoding='utf8') as fw:
40 | for q_id, query in queries.items():
41 | if q_id not in query_2_pos_doc_dict:
42 | continue
43 | positive_doc_id_list = query_2_pos_doc_dict[q_id]
44 | fw.write(f"{query}\t{json.dumps(positive_doc_id_list)}\t{q_id}\n")
45 |
46 | if __name__ == '__main__':
47 | random.seed(0)
48 | parser = ArgumentParser()
49 | args = parser.parse_args()
50 |
51 | # trec_19
52 | data_dir = './trec_19'
53 | queries = read_queries(os.path.join(data_dir, 'msmarco-test2019-queries.tsv'))
54 | query_2_pos_doc_dict = read_positive(os.path.join(data_dir, '2019qrels-pass.txt'))
55 | generate_qa_files(os.path.join(data_dir, 'test2019.qa.csv'), queries, query_2_pos_doc_dict)
56 |
57 | # trec_20
58 | data_dir = './trec_20'
59 | queries = read_queries(os.path.join(data_dir, 'msmarco-test2020-queries.tsv'))
60 | query_2_pos_doc_dict = read_positive(os.path.join(data_dir, '2020qrels-pass.txt'))
61 | generate_qa_files(os.path.join(data_dir, 'test2020.qa.csv'), queries, query_2_pos_doc_dict)
--------------------------------------------------------------------------------
/CAPSTONE/run_de_model_expand_corpus_cocondenser.sh:
--------------------------------------------------------------------------------
1 | echo "start de model warmup"
2 |
3 | MAX_SEQ_LEN=144
4 | MAX_Q_LEN=32
5 | DATA_DIR=./reformed_marco
6 |
7 |
8 | QUERY_PATH=./docTTTTTquery_full/doc2query_merge.tsv # each doc is paired with 80 generated queries.
9 | CORPUS_NAME=corpus
10 | CKPT_NUM=20000
11 | MODEL_TYPE=Luyu/co-condenser-marco
12 | MODEL=cocondenser
13 |
14 | number_neg=31
15 | total_part=3
16 | select_generated_query=gradual
17 | delimiter=sep
18 | gold_query_prob=0
19 |
20 | EXP_NAME=run_de_marco_${MAX_SEQ_LEN}_${MAX_Q_LEN}_t5_append_${delimiter}_gqp${gold_query_prob}_${MODEL}_neg${number_neg}_${select_generated_query}_${total_part}_part
21 | OUT_DIR=checkpoints_full_query/$EXP_NAME
22 | TB_DIR=tensorboard_log_full_query/$EXP_NAME # tensorboard log path
23 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9589 \
24 | ./models/run_de_model_ernie.py \
25 | --model_type $MODEL_TYPE \
26 | --train_file $DATA_DIR/biencoder-marco-train.json \
27 | --validation_file $DATA_DIR/biencoder-marco-dev.json \
28 | --max_seq_length $MAX_SEQ_LEN --max_query_length $MAX_Q_LEN \
29 | --per_device_train_batch_size 8 --gradient_accumulation_steps 1 \
30 | --number_neg $number_neg --learning_rate 5e-6 \
31 | --output_dir $OUT_DIR \
32 | --warmup_steps 2000 --logging_steps 100 --save_steps 1000 --max_steps 20000 \
33 | --log_dir $TB_DIR \
34 | --shuffle_positives \
35 | --fp16 --do_train \
36 | --expand_corpus --query_path $QUERY_PATH \
37 | --top_k_query 1 --append --delimiter $delimiter \
38 | --gold_query_prob $gold_query_prob --select_generated_query $select_generated_query --total_part $total_part
39 |
40 |
41 | # 2. Evaluate the retriever and generate the top-k hard negatives
42 | echo "start de model inference"
43 | for k in 10
44 | do
45 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9536 \
46 | ./models/run_de_model_ernie.py \
47 | --model_type $MODEL_TYPE \
48 | --model_name_or_path $OUT_DIR/checkpoint-$CKPT_NUM \
49 | --output_dir ${OUT_DIR}/${CKPT_NUM}_${CORPUS_NAME}_top_k_query_${k} \
50 | --train_file $DATA_DIR/biencoder-marco-train.json \
51 | --validation_file $DATA_DIR/biencoder-marco-dev.json \
52 | --train_qa_path $DATA_DIR/marco-train.qa.csv \
53 | --dev_qa_path $DATA_DIR/marco-dev.qa.csv \
54 | --max_seq_length $MAX_SEQ_LEN --max_query_length $MAX_Q_LEN --per_device_eval_batch_size 1024 \
55 | --passage_path $DATA_DIR/${CORPUS_NAME}.tsv \
56 | --fp16 --do_predict \
57 | --expand_corpus --query_path $QUERY_PATH \
58 | --top_k_query $k --append --delimiter $delimiter
59 |
60 | # TREC 19 and 20
61 | for year in 19 20
62 | do
63 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9535 \
64 | ./models/run_de_model_ernie.py \
65 | --model_type $MODEL_TYPE \
66 | --model_name_or_path $OUT_DIR/checkpoint-$CKPT_NUM \
67 | --output_dir ${OUT_DIR}/${CKPT_NUM}_${CORPUS_NAME}_top_k_query_${k} \
68 | --max_seq_length $MAX_SEQ_LEN --max_query_length $MAX_Q_LEN --per_device_eval_batch_size 512 \
69 | --passage_path $DATA_DIR/${CORPUS_NAME}.tsv \
70 | --evaluate_trec --prefix trec${year}_test \
71 | --test_qa_path ./trec_${year}/test20${year}.qa.csv \
72 | --query_positive_id_path ./trec_${year}/20${year}qrels-pass.txt \
73 | --fp16 --do_predict \
74 | --expand_corpus --query_path $QUERY_PATH \
75 | --top_k_query $k --append --delimiter $delimiter \
76 | --load_cache
77 | done
78 | done
79 |
--------------------------------------------------------------------------------
/CAPSTONE/run_de_model_expand_corpus_cocondenser_step2.sh:
--------------------------------------------------------------------------------
1 | echo "start de model warmup"
2 |
3 | MAX_SEQ_LEN=144
4 | MAX_Q_LEN=32
5 | DATA_DIR=./reformed_marco
6 | QUERY_PATH=./docTTTTTquery_full/doc2query_merge.tsv # each doc is paired with 80 generated queries.
7 | CORPUS_NAME=corpus
8 | CKPT_NUM=25000
9 |
10 | MODEL_TYPE=Luyu/co-condenser-marco
11 | MODEL=cocondenser
12 |
13 | number_neg=31
14 | total_part=4
15 | select_generated_query=gradual
16 | delimiter=sep
17 | gold_query_prob=0
18 |
19 | # path to the mined hard negatives
20 | train_file=checkpoints_full_query/run_de_marco_144_32_t5_append_sep_gqp0_cocondenser_neg31_gradual_3_part/20000_corpus_top_k_query_10/train_20000_0_query.json
21 | validation_file=checkpoints_full_query/run_de_marco_144_32_t5_append_sep_gqp0_cocondenser_neg31_gradual_3_part/20000_corpus_top_k_query_10/dev_20000_0_query.json
22 |
23 | EXP_NAME=run_de_marco_${MAX_SEQ_LEN}_${MAX_Q_LEN}_t5_append_${delimiter}_gqp${gold_query_prob}_${MODEL}_neg${number_neg}_${select_generated_query}_${total_part}_part_step2_selfdata_shuffle_positives_F
24 | OUT_DIR=checkpoints_full_query/$EXP_NAME
25 | TB_DIR=tensorboard_log_full_query/$EXP_NAME # tensorboard log path
26 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9589 \
27 | ./models/run_de_model_ernie.py \
28 | --model_type $MODEL_TYPE \
29 | --train_file $train_file \
30 | --validation_file $validation_file \
31 | --max_seq_length $MAX_SEQ_LEN --max_query_length $MAX_Q_LEN \
32 | --per_device_train_batch_size=8 --gradient_accumulation_steps=1 \
33 | --number_neg $number_neg --learning_rate 5e-6 \
34 | --output_dir $OUT_DIR \
35 | --warmup_steps 2500 --logging_steps 100 --save_steps 1000 --max_steps 25000 \
36 | --log_dir $TB_DIR \
37 | --passage_path=$DATA_DIR/${CORPUS_NAME}.tsv \
38 | --fp16 --do_train \
39 | --expand_corpus --query_path $QUERY_PATH \
40 | --top_k_query 1 --append --delimiter $delimiter \
41 | --gold_query_prob $gold_query_prob --select_generated_query $select_generated_query --total_part $total_part
42 |
43 | # 2. Evaluate the retriever and generate the top-k hard negatives
44 | echo "start de model inference"
45 | for k in 10
46 | do
47 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9536 \
48 | ./models/run_de_model_ernie.py \
49 | --model_type $MODEL_TYPE \
50 | --model_name_or_path $OUT_DIR/checkpoint-$CKPT_NUM \
51 | --output_dir ${OUT_DIR}/${CKPT_NUM}_${CORPUS_NAME}_top_k_query_${k} \
52 | --train_file $DATA_DIR/biencoder-marco-train.json \
53 | --validation_file $DATA_DIR/biencoder-marco-dev.json \
54 | --train_qa_path $DATA_DIR/marco-train.qa.csv \
55 | --dev_qa_path $DATA_DIR/marco-dev.qa.csv \
56 | --max_seq_length $MAX_SEQ_LEN --max_query_length $MAX_Q_LEN --per_device_eval_batch_size 1024 \
57 | --passage_path $DATA_DIR/${CORPUS_NAME}.tsv \
58 | --fp16 --do_predict \
59 | --expand_corpus --query_path $QUERY_PATH \
60 | --top_k_query $k --append --delimiter $delimiter
61 |
62 | # TREC 19 and 20
63 | for year in 19 20
64 | do
65 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9535 \
66 | ./models/run_de_model_ernie.py \
67 | --model_type $MODEL_TYPE \
68 | --model_name_or_path $OUT_DIR/checkpoint-$CKPT_NUM \
69 | --output_dir ${OUT_DIR}/${CKPT_NUM}_${CORPUS_NAME}_top_k_query_${k} \
70 | --max_seq_length $MAX_SEQ_LEN --max_query_length $MAX_Q_LEN --per_device_eval_batch_size 512 \
71 | --passage_path $DATA_DIR/${CORPUS_NAME}.tsv \
72 | --evaluate_trec --prefix trec${year}_test \
73 | --test_qa_path ./trec_${year}/test20${year}.qa.csv \
74 | --query_positive_id_path ./trec_${year}/20${year}qrels-pass.txt \
75 | --fp16 --do_predict \
76 | --expand_corpus --query_path $QUERY_PATH \
77 | --top_k_query $k --append --delimiter $delimiter \
78 | --load_cache
79 | done
80 | done
81 |
82 |
--------------------------------------------------------------------------------
/CAPSTONE/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/CAPSTONE/utils/__init__.py
--------------------------------------------------------------------------------
/CAPSTONE/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import csv
5 | import json
6 | import datasets
7 | from dataclasses import dataclass
8 | from collections import OrderedDict
9 | csv.field_size_limit(sys.maxsize)
10 | logger = logging.getLogger("__main__")
11 | # logger = logging.getLogger(__name__)
12 |
13 | def load_passage(passage_path:str):
14 | """
15 | for nq, tq, and msmarco
16 | """
17 | if not os.path.exists(passage_path):
18 | logger.info(f'{passage_path} does not exist')
19 | return
20 | logger.info(f'Loading passages from: {passage_path}')
21 | passages = OrderedDict()
22 | with open(passage_path, 'r', encoding='utf8') as fin:
23 | reader = csv.reader(fin, delimiter='\t')
24 | for row in reader:
25 | if row[0] != 'id':
26 | try:
27 | if len(row) == 3:
28 | # psg_id, text, title
29 | passages[row[0]] = (row[0], row[1].strip(), row[2].strip())
30 | else:
31 | # psg_id, text, title, original psg_id in the psgs_w100.tsv file
32 | passages[row[3]] = (row[3], row[1].strip(), row[2].strip())
33 | except Exception:
34 | logger.warning(f'The following input line has not been correctly loaded: {row}')
35 | logger.info(f'{passage_path} has {len(passages)} passages.')
36 | return passages
37 |
--------------------------------------------------------------------------------
/CAPSTONE/utils/evaluate_trec.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is used to evaluate the trec-19, 20 test set.
3 | """
4 | import pytrec_eval
5 |
6 | def load_qrel_data(query_positive_id_path):
7 | dev_query_positive_id = {}
8 | with open(query_positive_id_path, 'r', encoding='utf8') as f:
9 | for line in f:
10 | query_id, _, doc_id, rel = line.split()
11 | query_id = int(query_id)
12 | doc_id = int(doc_id)
13 | if query_id not in dev_query_positive_id:
14 | dev_query_positive_id[query_id] = {}
15 | dev_query_positive_id[query_id][doc_id] = int(rel)
16 | return dev_query_positive_id
17 |
18 | def convert_to_string_id(result_dict):
19 | string_id_dict = {}
20 |
21 | # format [string, dict[string, val]]
22 | for k, v in result_dict.items():
23 | _temp_v = {}
24 | for inner_k, inner_v in v.items():
25 | _temp_v[str(inner_k)] = inner_v
26 |
27 | string_id_dict[str(k)] = _temp_v
28 |
29 | return string_id_dict
30 | def EvalDevQuery(query_embedding2id, passage_embedding2id, dev_query_positive_id, I_nearest_neighbor, topN):
31 | prediction = {} #[qid][docid] = docscore, here we use -rank as score, so the higher the rank (1 > 2), the higher the score (-1 > -2)
32 |
33 | total = 0
34 | labeled = 0
35 | Atotal = 0
36 | Alabeled = 0
37 | qids_to_ranked_candidate_passages = {}
38 | for query_idx in range(len(I_nearest_neighbor)):
39 | seen_pid = set()
40 | query_id = query_embedding2id[query_idx]
41 | prediction[query_id] = {}
42 |
43 | top_ann_pid = I_nearest_neighbor[query_idx].copy()
44 | selected_ann_idx = top_ann_pid[:topN]
45 | rank = 0
46 |
47 | if query_id in qids_to_ranked_candidate_passages:
48 | pass
49 | else:
50 | # By default, all PIDs in the list of 1000 are 0. Only override those that are given
51 | tmp = [0] * 1000
52 | qids_to_ranked_candidate_passages[query_id] = tmp
53 |
54 | for idx in selected_ann_idx:
55 | pred_pid = passage_embedding2id[idx]
56 |
57 | if pred_pid not in seen_pid:
58 | # this check handles multiple vector per document
59 | qids_to_ranked_candidate_passages[query_id][rank]=pred_pid
60 | Atotal += 1
61 | if pred_pid not in dev_query_positive_id[query_id]:
62 | Alabeled += 1
63 | if rank < 10:
64 | total += 1
65 | if pred_pid not in dev_query_positive_id[query_id]:
66 | labeled += 1
67 | rank += 1
68 | prediction[query_id][pred_pid] = -rank
69 | seen_pid.add(pred_pid)
70 |
71 | # use out of the box evaluation script
72 | evaluator = pytrec_eval.RelevanceEvaluator(
73 | convert_to_string_id(dev_query_positive_id), {'map_cut', 'ndcg_cut', 'recip_rank','recall'})
74 |
75 | eval_query_cnt = 0
76 | result = evaluator.evaluate(convert_to_string_id(prediction))
77 |
78 | qids_to_relevant_passageids = {}
79 | for qid in dev_query_positive_id:
80 | qid = int(qid)
81 | if qid in qids_to_relevant_passageids:
82 | pass
83 | else:
84 | qids_to_relevant_passageids[qid] = []
85 | for pid in dev_query_positive_id[qid]:
86 | if pid>0:
87 | qids_to_relevant_passageids[qid].append(pid)
88 |
89 | # ms_mrr = compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
90 |
91 | ndcg = 0
92 | Map = 0
93 | mrr = 0
94 | recall = 0
95 | recall_1000 = 0
96 |
97 | for k in result.keys():
98 | eval_query_cnt += 1
99 | ndcg += result[k]["ndcg_cut_10"]
100 | Map += result[k]["map_cut_10"]
101 | mrr += result[k]["recip_rank"]
102 | recall += result[k]["recall_"+str(topN)]
103 |
104 | final_ndcg = ndcg / eval_query_cnt
105 | final_Map = Map / eval_query_cnt
106 | final_mrr = mrr / eval_query_cnt
107 | final_recall = recall / eval_query_cnt
108 | hole_rate = labeled/total  # fraction of the top-10 retrieved passages without a relevance judgment
109 | Ahole_rate = Alabeled/Atotal  # fraction of all top-N retrieved passages without a relevance judgment
110 |
111 | return final_ndcg, eval_query_cnt, final_Map, final_mrr, final_recall, hole_rate, Ahole_rate, result, prediction
--------------------------------------------------------------------------------
/CAPSTONE/utils/lamb.py:
--------------------------------------------------------------------------------
1 | """Lamb optimizer."""
2 |
3 | import collections
4 | import math
5 |
6 | import torch
7 | from torch.utils.tensorboard import SummaryWriter
8 | from torch.optim import Optimizer
9 |
10 |
11 | def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
12 | """Log a histogram of trust ratio scalars in across layers."""
13 | results = collections.defaultdict(list)
14 | for group in optimizer.param_groups:
15 | for p in group['params']:
16 | state = optimizer.state[p]
17 | for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
18 | if i in state:
19 | results[i].append(state[i])
20 |
21 | for k, v in results.items():
22 | event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)
23 |
24 | class Lamb(Optimizer):
25 | r"""Implements Lamb algorithm.
26 |
27 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
28 |
29 | Arguments:
30 | params (iterable): iterable of parameters to optimize or dicts defining
31 | parameter groups
32 | lr (float, optional): learning rate (default: 1e-3)
33 | betas (Tuple[float, float], optional): coefficients used for computing
34 | running averages of gradient and its square (default: (0.9, 0.999))
35 | eps (float, optional): term added to the denominator to improve
36 | numerical stability (default: 1e-8)
37 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
38 | adam (bool, optional): always use trust ratio = 1, which turns this into
39 | Adam. Useful for comparison purposes.
40 |
41 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
42 | https://arxiv.org/abs/1904.00962
43 | """
44 |
45 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
46 | weight_decay=0, adam=False):
47 | if not 0.0 <= lr:
48 | raise ValueError("Invalid learning rate: {}".format(lr))
49 | if not 0.0 <= eps:
50 | raise ValueError("Invalid epsilon value: {}".format(eps))
51 | if not 0.0 <= betas[0] < 1.0:
52 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
53 | if not 0.0 <= betas[1] < 1.0:
54 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
55 | defaults = dict(lr=lr, betas=betas, eps=eps,
56 | weight_decay=weight_decay)
57 | self.adam = adam
58 | super(Lamb, self).__init__(params, defaults)
59 |
60 | def step(self, closure=None):
61 | """Performs a single optimization step.
62 |
63 | Arguments:
64 | closure (callable, optional): A closure that reevaluates the model
65 | and returns the loss.
66 | """
67 | loss = None
68 | if closure is not None:
69 | loss = closure()
70 |
71 | for group in self.param_groups:
72 | for p in group['params']:
73 | if p.grad is None:
74 | continue
75 | grad = p.grad.data
76 | if grad.is_sparse:
77 | raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
78 |
79 | state = self.state[p]
80 |
81 | # State initialization
82 | if len(state) == 0:
83 | state['step'] = 0
84 | # Exponential moving average of gradient values
85 | state['exp_avg'] = torch.zeros_like(p.data)
86 | # Exponential moving average of squared gradient values
87 | state['exp_avg_sq'] = torch.zeros_like(p.data)
88 |
89 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
90 | beta1, beta2 = group['betas']
91 |
92 | state['step'] += 1
93 |
94 | # Decay the first and second moment running average coefficient
95 | # m_t
96 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
97 | # v_t
98 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
99 |
100 | # Paper v3 does not use debiasing.
101 | # Apply bias to lr to avoid broadcast.
102 | step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1
103 |
104 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
105 |
106 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
107 | if group['weight_decay'] != 0:
108 | adam_step.add_(group['weight_decay'], p.data)
109 |
110 | adam_norm = adam_step.pow(2).sum().sqrt()
111 | if weight_norm == 0 or adam_norm == 0:
112 | trust_ratio = 1
113 | else:
114 | trust_ratio = weight_norm / adam_norm
115 | state['weight_norm'] = weight_norm
116 | state['adam_norm'] = adam_norm
117 | state['trust_ratio'] = trust_ratio
118 | if self.adam:
119 | trust_ratio = 1
120 |
121 | p.data.add_(-step_size * trust_ratio, adam_step)
122 |
123 | return loss
124 |
--------------------------------------------------------------------------------
/CAPSTONE/utils/metric_utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import torch
3 | import numpy as np
4 | from collections import Counter
5 | from dataclasses import dataclass
6 | from datasets import load_metric
7 | from rouge_score import rouge_scorer
8 |
9 | import nltk
10 | import random
11 | from typing import Optional, List
12 |
13 | class compute_metric:
14 | def __init__(self):
15 | self.rouge_metric = load_metric('rouge')
16 | self.rouge_scorer = rouge_scorer.RougeScorer(rouge_types = ["rougeL"], use_stemmer=True)
17 | self.bleu_scorer = load_metric('bleu')
18 | self.meteor_scorer = load_metric('meteor')
19 |
20 | def postprocess_text_bleu(self, preds, labels):
21 | preds = [nltk.word_tokenize(pred) for pred in preds]
22 | labels = [nltk.word_tokenize(label) for label in labels]
23 |
24 | return preds, labels
25 |
26 | def __call__(self, preds, labels):
27 | # preds, labels = eval_preds
28 | result = {}
29 | # Some simple post-processing
30 | preds_bleu, labels_bleu = self.postprocess_text_bleu(preds, labels)
31 | result['rougeL'] = self.rouge_metric.compute(predictions=preds, references=labels, use_stemmer=True)['rougeL'].mid.fmeasure * 100
32 |
33 | result['bleu_4'] = self.bleu_scorer.compute(predictions=preds_bleu, references=[[l] for l in labels_bleu], max_order=4)['bleu'] * 100
34 |
35 | result['meteor'] = self.meteor_scorer.compute(predictions=preds_bleu, references=labels_bleu)['meteor'] * 100
36 | result = {k: round(v, 4) for k, v in result.items()}
37 | return result
38 |
39 | def compute_rouge(self, preds, labels):
40 | result = {}
41 | # Some simple post-processing
42 | preds_bleu, labels_bleu = self.postprocess_text_bleu(preds, labels)
43 | result['rougeL'] = self.rouge_metric.compute(predictions=preds, references=labels, use_stemmer=True)['rougeL'].mid.fmeasure * 100
44 | result = {k: round(v, 4) for k, v in result.items()}
45 | return result
46 |
47 | def get_candidates(self, targets: List[str], preds: List[str], num_cand:int, num_cand_picked:int, strategy:str, gold_as_positive=False):
48 | """
49 | args:
50 | targets: list of targets for each sample
51 | preds: list of predictions, length == len(targets) * num_cand
52 | num_cand: number of candidates
53 | num_cand_picked: number of returned indices per sample
54 | strategy: how to select num_cand_picked negatives from the preds
55 | gold_as_positive: if True, use the gold target as the positive; otherwise use the prediction with the highest reward as the positive.
56 | returns:
57 | candidates: the selected candidates, with length len(targets) * num_cand_picked
58 | rewards: FloatTensor of shape (B, num_cand_picked) holding the reward of each selected candidate
59 | NOTE: the positive sequence is always kept as the first candidate of each sample
60 | """
61 | preds_bleu, targets_bleu = self.postprocess_text_bleu(preds, targets)
62 | # preds_meteor = [' '.join(pred) for pred in preds_bleu]
63 | # targets_meteor = [' '.join(label) for label in targets_bleu]
64 | # print(targets_meteor)
65 |
66 | indices = []
67 | candidates = []
68 | rewards = []
69 | for i,t in enumerate(targets):
70 | scores = []
71 | ps = preds[i * num_cand: (i+1)*num_cand]
72 | ps_bleu = preds_bleu[i * num_cand: (i+1)*num_cand]
73 | for j,p in enumerate(ps):
74 | if len(ps_bleu[j]) == 0:
75 | rouge_score = 0
76 | bleu_score = 0
77 | meteor_score = 0
78 | else:
79 | # rouge_score = self.rouge_metric.compute(predictions=[p], references=[t], use_stemmer=True)['rougeL'].mid.fmeasure
80 | bleu_score = self.bleu_scorer.compute(predictions = [ps_bleu[j]], references = [[targets_bleu[i]]], max_order = 4)['bleu']
81 | meteor_score = self.meteor_scorer.compute(predictions = [ps_bleu[j]], references = [targets_bleu[i]])['meteor']
82 | # reward = rouge_score + bleu_score + meteor_score
83 | reward = bleu_score + meteor_score
84 | scores.append((j + i * num_cand, reward , p))
85 |
86 | scores = sorted(scores, key = lambda x: x[1], reverse=True)
87 |
88 | if gold_as_positive:
89 | # rouge_score = self.rouge_metric.compute(predictions = [t], references= [t], use_stemmer=True)['rougeL'].mid.fmeasure
90 | bleu_score = self.bleu_scorer.compute(predictions = [targets_bleu[i]], references = [[targets_bleu[i]]], max_order = 4)['bleu']
91 | meteor_score = self.meteor_scorer.compute(predictions = [targets_bleu[i]], references = [targets_bleu[i]])['meteor']
92 | # reward = rouge_score + bleu_score + meteor_score
93 | reward = bleu_score + meteor_score
94 | idx_this = []
95 | cand_this = [t]
96 | rewards_this = [reward]
97 | max_num = num_cand_picked - 1
98 | else:
99 | idx_this = [scores[0][0]] # the first as pos
100 | cand_this = [scores[0][2]]
101 | rewards_this = [scores[0][1]]
102 | scores = scores[1:]
103 | max_num = num_cand_picked - 1
104 |
105 | if strategy == 'random':
106 | s_for_pick = random.sample(scores, max_num)
107 | idx_this += [s[0] for s in s_for_pick]
108 | cand_this += [s[2] for s in s_for_pick]
109 | rewards_this += [s[1] for s in s_for_pick]
110 | else:
111 | if strategy == 'top':
112 | idx_this += [s[0] for s in scores[:max_num]]
113 | cand_this += [s[2] for s in scores[:max_num]]
114 | rewards_this += [s[1] for s in scores[:max_num]]
115 | elif strategy == 'bottom':
116 | idx_this += [s[0] for s in scores[-max_num:]]
117 | cand_this += [s[2] for s in scores[-max_num:]]
118 | rewards_this += [s[1] for s in scores[-max_num:]]
119 | elif strategy == 'top-bottom':
120 | n_top = max_num // 2
121 | n_bottom = max_num - n_top
122 | idx_this += [s[0] for s in scores[:n_top]]
123 | cand_this += [s[2] for s in scores[:n_top]]
124 | idx_this += [s[0] for s in scores[-n_bottom:]]
125 | cand_this += [s[2] for s in scores[-n_bottom:]]
126 | rewards_this += [s[1] for s in scores[:n_top]]
127 | rewards_this += [s[1] for s in scores[-n_bottom:]]
128 |
129 | indices += idx_this
130 | candidates += cand_this
131 | rewards.append(rewards_this)
132 | return candidates, torch.FloatTensor(rewards)
133 | # return torch.LongTensor(indices), candidates, torch.FloatTensor(rewards)
134 |
135 | def compute_metric_score(self, target:str, preds:List[str], metric:str):
136 | score_list = []
137 | preds_bleu, targets_bleu = self.postprocess_text_bleu(preds, [target])
138 |
139 | for i, pred in enumerate(preds_bleu):
140 | if metric=='rouge-l':
141 | score = self.rouge_scorer.score(target=target, prediction=preds[i])['rougeL'].fmeasure
142 | # score = self.rouge_metric.compute(predictions=[preds[i]], references=[target], use_stemmer=True)['rougeL'].mid.fmeasure
143 | elif metric =='bleu':
144 | score = self.bleu_scorer.compute(predictions = [preds_bleu[i]], references = [[targets_bleu[0]]], max_order = 4)['bleu']
145 | elif metric=='meteor':
146 | score = self.meteor_scorer.compute(predictions = [preds_bleu[i]], references = [targets_bleu[0]])['meteor']
147 | else:
148 | raise ValueError()
149 | score_list.append(score)
150 | return score_list
151 |
152 |
153 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/HOW_TO_DOWNLOAD.md:
--------------------------------------------------------------------------------
1 | # How to Download Desired Resources
2 |
3 | Since most of the resources are large, for most of our open-source projects we release the preprocessed data and trained ckpts in an Azure Blob container with public access.
4 | Here we provide two methods to download the resources.
5 |
6 |
7 | ## You Must Know
8 |
9 | For almost every project, we provide the **basic URL** of the project in its `README.md` file.
10 | For example, for [SimANS](https://github.com/microsoft/SimXNS/tree/main/SimANS), the basic URL is `https://msranlciropen.blob.core.windows.net/simxns/SimANS/`.
11 |
12 | If you want to view the content of the blob, you can use [Microsoft's AzCopy CLI tool](https://learn.microsoft.com/en-us/azure/storage/common/storage-ref-azcopy):
13 | ```bash
14 | azcopy list https://msranlciropen.blob.core.windows.net/simxns/SimANS/
15 | ```
16 | We also provide this list in each project's `README.md` file.
17 |
18 |
19 | ## Method 1: Directly Download using the URLs
20 |
21 | You may choose to download only part of the resources by appending the relative path to the blob URL and downloading it directly.
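22 | For example, a single file can be fetched with a plain HTTP client such as `wget`; the relative path below is one of the LEAD files listed in its `README.md`, so replace it with whatever resource you need:
23 | ```bash
24 | wget "https://msranlciropen.blob.core.windows.net/simxns/LEAD/dataset/mspas/mspas-test.qa.csv"
25 | ```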
22 |
23 |
24 | ## Method 2: Faster Transmission using AzCopy
25 |
26 | You can also use [the copy function of Microsoft's AzCopy CLI tool](https://learn.microsoft.com/en-us/azure/storage/common/storage-ref-azcopy-copy).
27 | Here is the command for copying the entire folder:
28 | ```bash
29 | azcopy copy --recursive "https://msranlciropen.blob.core.windows.net/simxns/SimANS/" .
30 | ```
31 | You may also use the tool to copy only a single file:
32 | ```bash
33 | azcopy copy "https://msranlciropen.blob.core.windows.net/simxns/SimANS/best_simans_ckpt/TQ/checkpoint-10000" .
34 | ```
35 |
--------------------------------------------------------------------------------
/LEAD/README.md:
--------------------------------------------------------------------------------
1 | # LEAD
2 |
3 | The code for our paper [LEAD: Liberal Feature-based Distillation for Dense Retrieval](https://arxiv.org/abs/2212.05225).
4 |
5 | 
6 |
7 | ## Overview
8 |
9 | Our proposed method LEAD aligns the layer features of the student and the teacher, placing more emphasis on the informative layers through re-weighting.
10 |
11 | Here we show the main results on [MS MARCO](https://microsoft.github.io/msmarco/). Our method outperforms all the baselines.
12 |
13 | 
14 |
15 |
16 | ## Released Resources
17 |
18 | We release the preprocessed data and trained ckpts in [Azure Blob](https://msranlciropen.blob.core.windows.net/simxns/LEAD/).
19 | Here we also provide the file list under this URL:
20 |
21 | Click here to see the file list.
22 | INFO: ckpt/24_to_12_msdoc.ckpt; Content Length: 1.22 GiB
23 | INFO: ckpt/24_to_12_mspas.ckpt; Content Length: 1.22 GiB
24 | INFO: ckpt/24_to_12_trec_doc_19.ckpt; Content Length: 1.22 GiB
25 | INFO: ckpt/24_to_12_trec_doc_20.ckpt; Content Length: 1.22 GiB
26 | INFO: ckpt/24_to_12_trec_pas_19.ckpt; Content Length: 1.22 GiB
27 | INFO: ckpt/24_to_12_trec_pas_20.ckpt; Content Length: 1.22 GiB
28 | INFO: dataset/mspas/biencoder-mspas-train-hard.json; Content Length: 1.56 GiB
29 | INFO: dataset/mspas/biencoder-mspas-train.json; Content Length: 5.35 GiB
30 | INFO: dataset/mspas/mspas-test.qa.csv; Content Length: 319.98 KiB
31 | INFO: dataset/mspas/psgs_w100.tsv; Content Length: 3.06 GiB
32 | INFO: dataset/mspas/trec2019-test.qa.csv; Content Length: 99.72 KiB
33 | INFO: dataset/mspas/trec2019-test.rating.csv; Content Length: 46.67 KiB
34 | INFO: dataset/mspas/trec2020-test.qa.csv; Content Length: 122.86 KiB
35 | INFO: dataset/mspas/trec2020-test.rating.csv; Content Length: 57.48 KiB
36 | INFO: dataset/msdoc/biencoder-msdoc-train-hard.json; Content Length: 800.13 MiB
37 | INFO: dataset/msdoc/biencoder-msdoc-train.json; Content Length: 294.04 MiB
38 | INFO: dataset/msdoc/msdoc-test.qa.csv; Content Length: 232.16 KiB
39 | INFO: dataset/msdoc/psgs_w100.tsv; Content Length: 21.11 GiB
40 | INFO: dataset/msdoc/trec2019-test.qa.csv; Content Length: 170.55 KiB
41 | INFO: dataset/msdoc/trec2019-test.rating.csv; Content Length: 80.84 KiB
42 | INFO: dataset/msdoc/trec2020-test.qa.csv; Content Length: 96.26 KiB
43 | INFO: dataset/msdoc/trec2020-test.rating.csv; Content Length: 46.04 KiB
44 |
45 |
46 | To download the files, please refer to [HOW_TO_DOWNLOAD](https://github.com/microsoft/SimXNS/tree/main/HOW_TO_DOWNLOAD.md).
47 |
48 |
49 | ## Environments Setting
50 |
51 | We implement our approach based on PyTorch and Hugging Face Transformers. The commands to prepare the experimental environment are listed below:
52 |
53 | ```
54 | yes | conda create -n lead
55 | conda init bash
56 | source ~/.bashrc
57 | conda activate lead
58 | conda install -y pytorch==1.11.0 cudatoolkit=11.5 faiss-gpu -c pytorch
59 | conda install -y pandas
60 | pip install transformers
61 | pip install tqdm
62 | pip install wandb
63 | pip install scikit-learn
64 | pip install pytrec-eval
65 | ```
66 |
67 | ## Dataset Preprocess
68 |
69 | We conduct experiments on the MS-PAS and MS-DOC datasets. You can download and preprocess the data with our code:
70 |
71 | ```
72 | cd data_preprocess
73 | bash get_data.sh
74 | ```
75 |
76 | ## Hard Negative Generation
77 |
78 | The negatives in the original dataset are mainly random or BM25 negatives. Before training, we first train a 12-layer DE and use it to retrieve the top-100 hard negatives.
79 |
80 | ```
81 | cd hard_negative_generation
82 | bash train_de.sh $DATASET $MASTER_PORT $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES
83 | bash retrieve_hard_negatives_academic.sh $DATASET $MASTER_PORT $MODEL_PATH $CKPT_NAME $MAX_DOC_LENGTH $MAX_QUERY_LENGTH
84 | ```
85 |
86 | ## Single Model Training
87 |
88 | Before distillation, we warm up the teacher and the student by training them on the mined hard negatives. Go to the `warm_up` directory and run the following commands:
89 |
90 | ### Train a 6-layer Dual Encoder
91 |
92 | ```
93 | bash train_6_layer_de.sh $DATASET $MASTER_PORT $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES
94 | ```
95 |
96 | ### Train a 12-layer Dual Encoder
97 |
98 | ```
99 | bash train_12_layer_de.sh $DATASET $MASTER_PORT $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES
100 | ```
101 |
102 | ### Train a 12-layer ColBERT
103 |
104 | ```
105 | bash train_12_layer_col.sh $DATASET $MASTER_PORT $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES
106 | ```
107 |
108 | ### Train a 12-layer Cross Encoder
109 |
110 | ```
111 | bash train_12_layer_ce.sh $DATASET $MASTER_PORT $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES
112 | ```
113 |
114 | ### Train a 24-layer ColBERT
115 |
116 | ```
117 | bash train_24_layer_col.sh $DATASET $MASTER_PORT $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES
118 | ```
119 |
120 | ## LEAD distillation
121 |
122 | ### Distill from 12-layer DE to 6-layer DE
123 |
124 | ```
125 | bash distill_from_12de_to_6de.sh $DATASET $MASTER_PORT $DISTILL_LAYER_NUM $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES $WARM_RATIO $DE_MODEL_PATH $DB_MODEL_PATH
126 | ```
127 |
128 | ### Distill from 12-layer CB to 6-layer DE
129 |
130 | ```
131 | bash distill_from_12cb_to_6de.sh $DATASET $MASTER_PORT $DISTILL_LAYER_NUM $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES $WARM_RATIO $CB_MODEL_PATH $DB_MODEL_PATH
132 | ```
133 |
134 | ### Distill from 12-layer CE to 6-layer DE
135 |
136 | ```
137 | bash distill_from_12ce_to_6de.sh $DATASET $MASTER_PORT $DISTILL_LAYER_NUM $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES $WARM_RATIO $CE_MODEL_PATH $DB_MODEL_PATH
138 | ```
139 |
140 | ### Distill from 24-layer CB to 12-layer DE
141 |
142 | ```
143 | bash distill_from_24cb_to_12de.sh $DATASET $MASTER_PORT $DISTILL_LAYER_NUM $MAX_DOC_LENGTH $MAX_QUERY_LENGTH $TRAIN_BATCH_SIZE $NUM_NEGATIVES $WARM_RATIO $CB_MODEL_PATH $DB_MODEL_PATH
144 | ```
145 |
146 | ## Inference
147 |
148 | ### Inference 6-layer DE
149 |
150 | ```
151 | bash evaluate_6_layer_de.sh $DATASET $MASTER_PORT $MODEL_PATH $FIRST_STEPS $EVAL_STEPS $MAX_STEPS $MAX_DOC_LENGTH $MAX_QUERY_LENGTH
152 | ```
153 |
154 | ### Inference 12-layer DE
155 |
156 | ```
157 | bash evaluate_12_layer_de.sh $DATASET $MASTER_PORT $MODEL_PATH $FIRST_STEPS $EVAL_STEPS $MAX_STEPS $MAX_DOC_LENGTH $MAX_QUERY_LENGTH
158 | ```
159 |
160 | ## Parameters
161 | The meaning of each parameter is defined as follows; an illustrative invocation is given after the list. For the detailed initialization of these parameters, please refer to the section "Hyper-parameter and Initialization" in Appendix B of our paper.
162 | - `--DATASET`: The dataset used for training, e.g., mspas or msdoc.
163 | - `--MASTER_PORT`: The master port for parallel training, e.g., 3000.
164 | - `--FIRST_STEPS`: The first evaluation step.
165 | - `--EVAL_STEPS`: The evaluation step gap.
166 | - `--MAX_STEPS`: The maximum evaluation step.
167 | - `--WARM_RATIO`: The warm up ratio of the training.
168 | - `--MAX_DOC_LENGTH`: The maximum input length for each document.
169 | - `--MAX_QUERY_LENGTH`: The maximum input length for each query.
170 | - `--TRAIN_BATCH_SIZE`: Batch size per GPU.
171 | - `--NUM_NEGATIVES`: The number of negatives for each query.
172 | - `--DISTILL_LAYER_NUM`: The number of layers for distillation.
173 | - `--MODEL_PATH`: The directory name of saved checkpoint.
174 | - `--DE_MODEL_PATH`: The path of DE teacher checkpoint.
175 | - `--CB_MODEL_PATH`: The path of CB teacher checkpoint.
176 | - `--CE_MODEL_PATH`: The path of CE teacher checkpoint.
177 | - `--DB_MODEL_PATH`: The path of DE student checkpoint.
178 | - `--CKPT_NAME`: The training step of the checkpoint.
179 |
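180 | For example, a 12-layer DE to 6-layer DE distillation run could be launched as follows. This is only an illustrative sketch: the numeric values are placeholders rather than the settings reported in the paper, and `my_12de_warmup` / `my_6de_warmup` stand for your own warmed-up teacher and student checkpoint directories under `model_output/`.
181 | 
182 | ```
183 | cd distillation
184 | bash distill_from_12de_to_6de.sh mspas 3000 4 128 32 16 15 0.1 my_12de_warmup my_6de_warmup
185 | ```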
180 | ## 📜 Citation
181 |
182 | Please cite our paper if you use [LEAD](https://arxiv.org/abs/2212.05225) in your work:
183 | ```bibtex
184 | @article{sun2022lead,
185 | title={LEAD: Liberal Feature-based Distillation for Dense Retrieval},
186 | author={Sun, Hao and Liu, Xiao and Gong, Yeyun and Dong, Anlei and Jiao, Jian and Lu, Jingwen and Zhang, Yan and Jiang, Daxin and Yang, Linjun and Majumder, Rangan and others},
187 | journal={arXiv preprint arXiv:2212.05225},
188 | year={2022}
189 | }
190 | ```
191 |
--------------------------------------------------------------------------------
/LEAD/assets/main_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/LEAD/assets/main_result.jpg
--------------------------------------------------------------------------------
/LEAD/assets/model.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/LEAD/assets/model.jpg
--------------------------------------------------------------------------------
/LEAD/data_preprocess/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/LEAD/data_preprocess/.DS_Store
--------------------------------------------------------------------------------
/LEAD/data_preprocess/data_preprocess.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm, trange
3 | from random import sample
4 |
5 | def csv_reader(fd, delimiter='\t', trainer_id=0, trainer_num=1):
6 | def gen():
7 | for i, line in tqdm(enumerate(fd)):
8 | if i % trainer_num == trainer_id:
9 | slots = line.rstrip('\n').split(delimiter)
10 | if len(slots) == 1:
11 | yield slots,
12 | else:
13 | yield slots
14 | return gen()
15 |
16 | def read_qrel_train(relevance_file):
17 | qrel = {}
18 | with open(relevance_file, encoding='utf8') as f:
19 | tsvreader = csv_reader(f, delimiter="\t")
20 | for [topicid, _, docid, rel] in tsvreader:
21 | assert rel == "1"
22 | if topicid in qrel:
23 | qrel[topicid].append(docid)
24 | else:
25 | qrel[topicid] = [docid]
26 | return qrel
27 |
28 | def read_qrel_train_2(relevance_file):
29 | qrel = {}
30 | with open(relevance_file, encoding='utf8') as f:
31 | tsvreader = csv_reader(f, delimiter=" ")
32 | for [topicid, _, docid, rel] in tsvreader:
33 | assert rel == "1"
34 | if topicid in qrel:
35 | qrel[topicid].append(docid)
36 | else:
37 | qrel[topicid] = [docid]
38 | return qrel
39 |
40 |
41 | def read_qrel_dev(relevance_file):
42 | qrel = {}
43 | with open(relevance_file, encoding='utf8') as f:
44 | tsvreader = csv_reader(f, delimiter="\t")
45 | for [topicid, docid] in tsvreader:
46 | if topicid in qrel:
47 | qrel[topicid].append(docid)
48 | else:
49 | qrel[topicid] = [docid]
50 | return qrel
51 |
52 | def read_qstring(query_file):
53 | q_string = {}
54 | with open(query_file, 'r', encoding='utf-8') as file:
55 | for num, line in enumerate(file):
56 | line = line.strip('\n') # strip the trailing newline
57 | line = line.split('\t')
58 | q_string[line[0]] = line[1]
59 | return q_string
60 |
61 | def read_dstring(corpus_file):
62 | d_string = {}
63 | with open(corpus_file, 'r', encoding='utf-8') as file:
64 | for num, line in enumerate(file):
65 | line = line.strip('\n') # strip the trailing newline
66 | line = line.split('\t')
67 | d_string[line[0]] = [line[1], line[2]]
68 | return d_string
69 |
70 | def read_dstring_2(corpus_file):
71 | d_string = {}
72 | with open(corpus_file, 'r', encoding='utf-8') as file:
73 | for num, line in enumerate(file):
74 | line = line.strip('\n') # strip the trailing newline
75 | line = line.split('\t')
76 | d_string[line[0]] = [line[2], line[3]]
77 | return d_string
78 |
79 |
80 | def construct_mspas():
81 | train_relevance_file = "mspas/qrels.train.tsv"
82 | train_query_file = "mspas/train.query.txt"
83 | train_negative_file = "mspas/train.negatives.tsv"
84 | corpus_file = "mspas/corpus.tsv"
85 | dev_relevance_file = "mspas/qrels.dev.tsv"
86 | dev_query_file = "mspas/dev.query.txt"
87 | n_sample = 100
88 |
89 | ## qid: docid
90 | train_qrel = read_qrel_train(train_relevance_file)
91 |
92 | ## qid: query
93 | train_q_string = read_qstring(train_query_file)
94 |
95 | ## qid: docid
96 | dev_qrel = read_qrel_dev(dev_relevance_file)
97 |
98 | ## qid: query
99 | dev_q_string = read_qstring(dev_query_file)
100 |
101 | # docid: doc
102 | d_string = read_dstring(corpus_file)
103 |
104 | ## qid: [docid1, ...]
105 | negative = {}
106 | with open(train_negative_file, 'r', encoding='utf8') as nf:
107 | reader = csv_reader(nf)
108 | for cnt, line in enumerate(reader):
109 | q = line[0]
110 | nn = line[1]
111 | nn = nn.split(',')
112 | negative[q] = nn
113 |
114 | train_file = open("mspas/psgs_w100.tsv", 'w')
115 | for did in d_string:
116 | train_file.write('\t'.join([str(int(did) + 1), d_string[did][1], d_string[did][0]]) + '\n')
117 | train_file.flush()
118 |
119 | train_data_list = []
120 | for qid in tqdm(negative):
121 | example = {}
122 | example['question'] = train_q_string[qid]
123 | example['answers'] = train_qrel[qid]
124 | example['positive_ctxs'] = [did for did in train_qrel[qid]]
125 | example['hard_negative_ctxs'] = [did for did in negative[qid][:n_sample]]
126 | example['negative_ctxs'] = []
127 | train_data_list.append(example)
128 |
129 | with open('mspas/biencoder-mspas-train.json', 'w', encoding='utf-8') as json_file:
130 | json_file.write(json.dumps(train_data_list, indent=4))
131 |
132 |
133 | train_data_list = []
134 | for qid in tqdm(train_q_string):
135 | example = {}
136 | example['question'] = train_q_string[qid]
137 | example['answers'] = train_qrel[qid]
138 | example['positive_ctxs'] = [did for did in train_qrel[qid]]
139 | example['negative_ctxs'] = []
140 | train_data_list.append(example)
141 |
142 | with open('mspas/biencoder-mspas-train-full.json', 'w', encoding='utf-8') as json_file:
143 | json_file.write(json.dumps(train_data_list, indent=4))
144 |
145 | train_file = open("mspas/mspas-test.qa.csv", 'w')
146 | for qid in dev_qrel:
147 | train_file.write('\t'.join([dev_q_string[qid], str(dev_qrel[qid])]) + '\n')
148 | train_file.flush()
149 |
150 |
151 | def construct_msdoc():
152 | train_relevance_file = "msdoc/msmarco-doctrain-qrels.tsv"
153 | train_query_file = "msdoc/msmarco-doctrain-queries.tsv"
154 | corpus_file = "msdoc/msmarco-docs.tsv"
155 | dev_relevance_file = "msdoc/msmarco-docdev-qrels.tsv"
156 | dev_query_file = "msdoc/msmarco-docdev-queries.tsv"
157 | n_sample = 100
158 |
159 | ## qid: docid
160 | train_qrel = read_qrel_train_2(train_relevance_file)
161 |
162 | ## qid: query
163 | train_q_string = read_qstring(train_query_file)
164 |
165 | ## qid: docid
166 | dev_qrel = read_qrel_train_2(dev_relevance_file)
167 |
168 | ## qid: query
169 | dev_q_string = read_qstring(dev_query_file)
170 |
171 | # docid: doc
172 | d_string = read_dstring_2(corpus_file)
173 |
174 | docid_int = {}
175 | int_docid = {}
176 | idx = 0
177 | for docid in d_string:
178 | docid_int[docid] = idx
179 | int_docid[idx] = docid
180 | idx += 1
181 |
182 | ## qid: [docid1, ...]
183 | negative = {}
184 | all_doc_list = [elem for elem in docid_int]
185 |
186 | for docid in train_qrel:
187 | negative[docid] = sample(all_doc_list, n_sample)
188 |
189 | train_data_list = []
190 | for qid in tqdm(train_qrel):
191 | example = {}
192 | example['question'] = train_q_string[qid]
193 | example['answers'] = [docid_int[did] for did in train_qrel[qid]]
194 | example['positive_ctxs'] = [docid_int[did] for did in train_qrel[qid]]
195 | example['hard_negative_ctxs'] = [docid_int[did] for did in negative[qid][:n_sample]]
196 | example['negative_ctxs'] = []
197 | train_data_list.append(example)
198 |
199 | with open('msdoc/biencoder-msdoc-train.json', 'w', encoding='utf-8') as json_file:
200 | json_file.write(json.dumps(train_data_list, indent=4))
201 |
202 | train_file = open("msdoc/psgs_w100.tsv", 'w')
203 | for did in int_docid:
204 | train_file.write('\t'.join([str(did + 1), d_string[int_docid[did]][1], d_string[int_docid[did]][0]]) + '\n')
205 | train_file.flush()
206 |
207 | train_file = open("msdoc/msdoc-test.qa.csv", 'w')
208 | for qid in dev_qrel:
209 | train_file.write('\t'.join([dev_q_string[qid], str([str(docid_int[elem]) for elem in dev_qrel[qid]])]) + '\n')
210 | train_file.flush()
211 |
212 |
213 | print('Start Processing')
214 | construct_mspas()
215 | construct_msdoc()
216 | print('Finish Processing')
--------------------------------------------------------------------------------
/LEAD/data_preprocess/get_data.sh:
--------------------------------------------------------------------------------
1 | # Download MSMARCO-PAS data
2 | wget --no-check-certificate https://rocketqa.bj.bcebos.com/corpus/marco.tar.gz
3 | tar -zxf marco.tar.gz
4 | rm -rf marco.tar.gz
5 | mv marco mspas
6 | cd mspas
7 | wget https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz
8 | wget https://msmarco.blob.core.windows.net/msmarcoranking/qrels.train.tsv -O qrels.train.tsv
9 | gunzip qidpidtriples.train.full.2.tsv.gz
10 | join -t "$(echo -en '\t')" -e '' -a 1 -o 1.1 2.2 1.2 <(sort -k1,1 para.txt) <(sort -k1,1 para.title.txt) | sort -k1,1 -n > corpus.tsv
11 | awk -v RS='\r\n' '$1==last {printf ",%s",$3; next} NR>1 {print "";} {last=$1; printf "%s\t%s",$1,$3;} END{print "";}' qidpidtriples.train.full.2.tsv > train.negatives.tsv
12 |
13 |
14 |
15 | # Download MSMARCO-DOC data
16 | cd ..
17 | mkdir msdoc
18 | cd msdoc
19 |
20 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docs.tsv.gz
21 | gunzip msmarco-docs.tsv.gz
22 |
23 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-doctrain-queries.tsv.gz
24 | gunzip msmarco-doctrain-queries.tsv.gz
25 |
26 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-doctrain-qrels.tsv.gz
27 | gunzip msmarco-doctrain-qrels.tsv.gz
28 |
29 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz
30 | gunzip msmarco-test2019-queries.tsv.gz
31 |
32 | wget https://trec.nist.gov/data/deep/2019qrels-docs.txt
33 |
34 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docdev-queries.tsv.gz
35 | gunzip msmarco-docdev-queries.tsv.gz
36 |
37 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docdev-qrels.tsv.gz
38 | gunzip msmarco-docdev-qrels.tsv.gz
39 |
40 | cd ..
41 | python data_preprocess.py
--------------------------------------------------------------------------------
/LEAD/data_preprocess/hard_negative_generation/retrieve_hard_negatives_academic.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATASET=$1
4 | MASTER_PORT=$2
5 | MODEL_PATH=$3
6 | CKPT_NAME=$4
7 | MAX_DOC_LENGTH=$5
8 | MAX_QUERY_LENGTH=$6
9 |
10 | MODEL_TYPE=dual_encoder
11 | BASE_DATA_DIR=data_preprocess/${DATASET}
12 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE}
13 |
14 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv
15 |
16 | if [ ${DATASET} == 'mspas' ];
17 | then
18 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-mspas-train-full.json
19 | else
20 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train.json
21 | fi
22 |
23 | TOKENIZER_NAME=bert-base-uncased
24 | PRETRAINED_MODEL_NAME=Luyu/co-condenser-marco
25 |
26 | EVAL_BATCHSIZE=256
27 |
28 | ####################################### Multi Evaluation ########################################
29 | echo "****************begin Retrieve****************"
30 | cd ..
31 |
32 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} retrieve_hard_negative.py \
33 | --model_type=${MODEL_TYPE} \
34 | --dataset ${DATASET} \
35 | --tokenizer_name=${TOKENIZER_NAME} \
36 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \
37 | --eval_model_dir=${BASE_OUTPUT_DIR}/models/${MODEL_PATH}/de-checkpoint-${CKPT_NAME} \
38 | --output_dir=${BASE_DATA_DIR} \
39 | --train_file=${TRAIN_FILE} \
40 | --passage_path=${PASSAGE_PATH} \
41 | --max_query_length=${MAX_QUERY_LENGTH} \
42 | --max_doc_length=${MAX_DOC_LENGTH} \
43 | --per_gpu_eval_batch_size $EVAL_BATCHSIZE
44 |
45 | echo "****************End Retrieve****************"
--------------------------------------------------------------------------------
/LEAD/data_preprocess/hard_negative_generation/train_de.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MODEL_TYPE=dual_encoder
4 | DATASET=$1
5 | MASTER_PORT=$2
6 | MAX_DOC_LENGTH=$3
7 | MAX_QUERY_LENGTH=$4
8 | TRAIN_BATCH_SIZE=$5
9 | NUM_NEGATIVES=$6
10 |
11 | BASE_DATA_DIR=data_preprocess/${DATASET}
12 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE}
13 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train.json
14 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv
15 |
16 | TOKENIZER_NAME=bert-base-uncased
17 | PRETRAINED_MODEL_NAME=Luyu/co-condenser-marco
18 | MAX_STEPS=20000
19 | LOGGING_STEPS=10
20 | SAVE_STEPS=1000
21 | LR=2e-5
22 | GRADIENT_ACCUMULATION_STEPS=1
23 | ######################################## Training ########################################
24 | echo "****************begin Train****************"
25 | cd ../
26 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \
27 | ./run_single_model.py \
28 | --model_type=${MODEL_TYPE} \
29 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \
30 | --tokenizer_name=${TOKENIZER_NAME} \
31 | --dataset=${DATASET} \
32 | --train_file=${TRAIN_FILE} \
33 | --passage_path=${PASSAGE_PATH} \
34 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \
35 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \
36 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \
37 | --share_weight \
38 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \
39 | --output_dir=${BASE_OUTPUT_DIR}/models
40 |
41 | echo "****************End Train****************"
42 |
--------------------------------------------------------------------------------
/LEAD/distillation/distill_from_12cb_to_6de.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MODEL_TYPE=distilbert
4 | DATASET=$1
5 | MASTER_PORT=$2
6 | DISTILL_LAYER_NUM=$3
7 | MAX_DOC_LENGTH=$4
8 | MAX_QUERY_LENGTH=$5
9 | TRAIN_BATCH_SIZE=$6
10 | NUM_NEGATIVES=$7
11 | WARM_RATIO=$8
12 | COL_MODEL_PATH=$9
13 | DB_MODEL_PATH=${10}
14 |
15 |
16 | BASE_DATA_DIR=data_preprocess/${DATASET}
17 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE}
18 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train-hard.json
19 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv
20 | DISTILL_COL_PATH=model_output/colbert/models/${COL_MODEL_PATH} # The initialized parameter of col (the ColBERT teacher)
21 | DISTILL_DB_PATH=model_output/distilbert/models/${DB_MODEL_PATH} # The initialized parameter of db
22 |
23 |
24 | TOKENIZER_NAME=bert-base-uncased
25 | PRETRAINED_MODEL_NAME=Luyu/co-condenser-marco
26 | DISTILL_PARA_DB=1 # The distillation parameter of loss_db
27 | DISTILL_PARA_COL=1 # The distillation parameter of loss_col
28 | DISTILL_PARA_COL_DB_DIS=1
29 | DISTILL_PARA_COL_DB_LAYER_SCORE=1
30 | GRADIENT_ACCUMULATION_STEPS=1
31 | SAVE_STEPS=10
32 | TEMPERATURE=1
33 | LAYER_TEMPERATURE=10
34 | MAX_STEPS=50000
35 | LOGGING_STEPS=10
36 | LR=2e-5
37 |
38 |
39 | cd ..
40 | ######################################## Training ########################################
41 | echo "****************begin Train****************"
42 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \
43 | ./run_LEAD.py \
44 | --model_type=${MODEL_TYPE} \
45 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \
46 | --tokenizer_name=${TOKENIZER_NAME} \
47 | --train_file=${TRAIN_FILE} \
48 | --passage_path=${PASSAGE_PATH} \
49 | --dataset=${DATASET} \
50 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \
51 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \
52 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \
53 | --share_weight \
54 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \
55 | --temperature ${TEMPERATURE} \
56 | --layer_temperature ${LAYER_TEMPERATURE} \
57 | --output_dir=${BASE_OUTPUT_DIR}/models \
58 | --distill_para_db=${DISTILL_PARA_DB} \
59 | --distill_para_col=${DISTILL_PARA_COL} \
60 | --distill_para_col_db_dis=${DISTILL_PARA_COL_DB_DIS} \
61 | --distill_para_col_db_layer_score=${DISTILL_PARA_COL_DB_LAYER_SCORE} \
62 | --distill_col --train_col\
63 | --distill_col_path=${DISTILL_COL_PATH} \
64 | --distill_db --train_db\
65 | --distill_db_path=${DISTILL_DB_PATH} \
66 | --distill_col_db_layer_score \
67 | --disitll_layer_num=${DISTILL_LAYER_NUM} \
68 | --layer_selection_random \
69 | --layer_score_reweight \
70 | --warm_up_ratio=${WARM_RATIO}
71 |
72 | echo "****************End Train****************"
--------------------------------------------------------------------------------
/LEAD/distillation/distill_from_12ce_to_6de.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MODEL_TYPE=distilbert
4 | DATASET=$1
5 | MASTER_PORT=$2
6 | DISTILL_LAYER_NUM=$3
7 | MAX_DOC_LENGTH=$4
8 | MAX_QUERY_LENGTH=$5
9 | TRAIN_BATCH_SIZE=$6
10 | NUM_NEGATIVES=$7
11 | WARM_RATIO=$8
12 | CE_MODEL_PATH=$9
13 | DB_MODEL_PATH=${10}
14 |
15 |
16 | BASE_DATA_DIR=data_preprocess/${DATASET}
17 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE}
18 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train-hard.json
19 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv
20 | DISTILL_CE_PATH=model_output/cross_encoder/models/${CE_MODEL_PATH} # The initialized parameter of ce
21 | DISTILL_DB_PATH=model_output/distilbert/models/${DB_MODEL_PATH} # The initialized parameter of db
22 |
23 |
24 | TOKENIZER_NAME=bert-base-uncased
25 | PRETRAINED_MODEL_NAME=Luyu/co-condenser-marco
26 | DISTILL_PARA_DB=1 # The distillation parameter of loss_db
27 | DISTILL_PARA_CE=1 # The distillation parameter of loss_ce
28 | DISTILL_PARA_CE_DB_DIS=1
29 | DISTILL_PARA_CE_DB_LAYER_SCORE=1
30 | GRADIENT_ACCUMULATION_STEPS=10
31 | SAVE_STEPS=10
32 | TEMPERATURE=1
33 | LAYER_TEMPERATURE=10
34 | MAX_STEPS=100000
35 | LOGGING_STEPS=10
36 | LR=5e-5
37 |
38 | cd ..
39 | ######################################## Training ########################################
40 | echo "****************begin Train****************"
41 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \
42 | ./run_LEAD.py \
43 | --model_type=${MODEL_TYPE} \
44 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \
45 | --tokenizer_name=${TOKENIZER_NAME} \
46 | --train_file=${TRAIN_FILE} \
47 | --passage_path=${PASSAGE_PATH} \
48 | --dataset=${DATASET} \
49 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \
50 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \
51 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \
52 | --share_weight \
53 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \
54 | --temperature ${TEMPERATURE} \
55 | --layer_temperature ${LAYER_TEMPERATURE} \
56 | --output_dir=${BASE_OUTPUT_DIR}/models \
57 | --distill_para_db=${DISTILL_PARA_DB} \
58 | --distill_para_ce=${DISTILL_PARA_CE} \
59 | --distill_para_ce_db_dis=${DISTILL_PARA_CE_DB_DIS} \
60 | --distill_para_ce_db_layer_score=${DISTILL_PARA_CE_DB_LAYER_SCORE} \
61 | --distill_ce --train_ce\
62 | --distill_ce_path=${DISTILL_CE_PATH} \
63 | --distill_db --train_db\
64 | --distill_db_path=${DISTILL_DB_PATH} \
65 | --distill_ce_db_layer_score \
66 | --disitll_layer_num=${DISTILL_LAYER_NUM} \
67 | --layer_selection_random \
68 | --layer_score_reweight \
69 | --warm_up_ratio=${WARM_RATIO}
70 |
71 | echo "****************End Train****************"
--------------------------------------------------------------------------------
/LEAD/distillation/distill_from_12de_to_6de.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MODEL_TYPE=distilbert
4 | DATASET=$1
5 | MASTER_PORT=$2
6 | DISTILL_LAYER_NUM=$3
7 | MAX_DOC_LENGTH=$4
8 | MAX_QUERY_LENGTH=$5
9 | TRAIN_BATCH_SIZE=$6
10 | NUM_NEGATIVES=$7
11 | WARM_RATIO=$8
12 | DE_MODEL_PATH=$9
13 | DB_MODEL_PATH=${10}
14 |
15 |
16 | BASE_DATA_DIR=data_preprocess/${DATASET}
17 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE}
18 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train-hard.json
19 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv
20 | DISTILL_DE_PATH=model_output/dual_encoder/models/${DE_MODEL_PATH} # The initialized parameter of de (the DE teacher)
21 | DISTILL_DB_PATH=model_output/distilbert/models/${DB_MODEL_PATH} # The initialized parameter of db
22 |
23 |
24 | TOKENIZER_NAME=bert-base-uncased
25 | PRETRAINED_MODEL_NAME=Luyu/co-condenser-marco
26 | DISTILL_PARA_DB=1 # The distillation parameter of loss_db
27 | DISTILL_PARA_DE=1 # The distillation parameter of loss_de
28 | DISTILL_PARA_DE_DB_DIS=1
29 | DISTILL_PARA_DE_DB_LAYER_SCORE=1
30 | GRADIENT_ACCUMULATION_STEPS=1
31 | SAVE_STEPS=10
32 | TEMPERATURE=1
33 | LAYER_TEMPERATURE=10
34 | MAX_STEPS=50000
35 | LOGGING_STEPS=10
36 | LR=5e-5
37 |
38 | cd ..
39 | ######################################## Training ########################################
40 | echo "****************begin Train****************"
41 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \
42 | ./run_LEAD.py \
43 | --model_type=${MODEL_TYPE} \
44 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \
45 | --tokenizer_name=${TOKENIZER_NAME} \
46 | --train_file=${TRAIN_FILE} \
47 | --passage_path=${PASSAGE_PATH} \
48 | --dataset=${DATASET} \
49 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \
50 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \
51 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \
52 | --share_weight \
53 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \
54 | --temperature ${TEMPERATURE} \
55 | --layer_temperature ${LAYER_TEMPERATURE} \
56 | --output_dir=${BASE_OUTPUT_DIR}/models \
57 | --distill_para_de=${DISTILL_PARA_DE} \
58 | --distill_para_db=${DISTILL_PARA_DB} \
59 | --distill_para_de_db_dis=${DISTILL_PARA_DE_DB_DIS} \
60 | --distill_para_de_db_layer_score=${DISTILL_PARA_DE_DB_LAYER_SCORE} \
61 | --distill_de --train_de \
62 | --distill_de_path=${DISTILL_DE_PATH} \
63 | --distill_db --train_db \
64 | --distill_db_path=${DISTILL_DB_PATH} \
65 | --distill_de_db_layer_score \
66 | --disitll_layer_num=${DISTILL_LAYER_NUM} \
67 | --layer_selection_random \
68 | --layer_score_reweight \
69 | --warm_up_ratio=${WARM_RATIO}
70 |
71 | echo "****************End Train****************"
--------------------------------------------------------------------------------
/LEAD/distillation/distill_from_24cb_to_12de.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MODEL_TYPE=dual_encoder
4 | DATASET=$1
5 | MASTER_PORT=$2
6 | DISTILL_LAYER_NUM=$3
7 | MAX_DOC_LENGTH=$4
8 | MAX_QUERY_LENGTH=$5
9 | TRAIN_BATCH_SIZE=$6
10 | NUM_NEGATIVES=$7
11 | WARM_RATIO=$8
12 | COL_MODEL_PATH=$9
13 | DB_MODEL_PATH=${10}
14 |
15 |
16 | BASE_DATA_DIR=data_preprocess/${DATASET}
17 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE}
18 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train-hard.json
19 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv
20 | DISTILL_COL_PATH=model_output/colbert/models/${COL_MODEL_PATH} # Initial parameters of the 24-layer ColBERT teacher (col)
21 | DISTILL_DB_PATH=model_output/dual_encoder/models/${DB_MODEL_PATH} # Initial parameters of the 12-layer dual-encoder student (db)
22 |
23 |
24 | TOKENIZER_NAME=bert-base-uncased
25 | PRETRAINED_MODEL_NAME=nghuyong/ernie-2.0-large-en
26 | DISTILL_PARA_DB=1 # The distillation parameter of loss_db
27 | DISTILL_PARA_COL=1 # The distillation parameter of loss_col
28 | DISTILL_PARA_COL_DB_DIS=1
29 | DISTILL_PARA_COL_DB_LAYER_SCORE=1
30 | GRADIENT_ACCUMULATION_STEPS=1
31 | SAVE_STEPS=10
32 | TEMPERATURE=1
33 | LAYER_TEMPERATURE=10
34 | MAX_STEPS=50000
35 | LOGGING_STEPS=10
36 | LR=2e-5
37 |
38 |
39 | cd ..
40 | ######################################## Training ########################################
41 | echo "****************begin Train****************"
42 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \
43 | ./run_LEAD.py \
44 | --model_type=${MODEL_TYPE} \
45 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \
46 | --tokenizer_name=${TOKENIZER_NAME} \
47 | --train_file=${TRAIN_FILE} \
48 | --passage_path=${PASSAGE_PATH} \
49 | --dataset=${DATASET} \
50 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \
51 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \
52 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \
53 | --share_weight \
54 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \
55 | --temperature ${TEMPERATURE} \
56 | --layer_temperature ${LAYER_TEMPERATURE} \
57 | --output_dir=${BASE_OUTPUT_DIR}/models \
58 | --distill_para_db=${DISTILL_PARA_DB} \
59 | --distill_para_col=${DISTILL_PARA_COL} \
60 | --distill_para_col_db_dis=${DISTILL_PARA_COL_DB_DIS} \
61 | --distill_para_col_db_layer_score=${DISTILL_PARA_COL_DB_LAYER_SCORE} \
62 | --distill_col --train_col \
63 | --distill_col_path=${DISTILL_COL_PATH} \
64 | --distill_db --train_db \
65 | --distill_db_path=${DISTILL_DB_PATH} \
66 | --distill_col_db_layer_score \
67 | --disitll_layer_num=${DISTILL_LAYER_NUM} \
68 | --layer_selection_random \
69 | --layer_score_reweight \
70 | --warm_up_ratio=${WARM_RATIO}
71 |
72 | echo "****************End Train****************"
--------------------------------------------------------------------------------
/LEAD/inference/evaluate_12_layer_de.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MODEL_TYPE=dual_encoder
4 | DATASET=$1
5 | MASTER_PORT=$2
6 | MODEL_PATH=$3
7 | FIRST_STEPS=$4
8 | EVAL_STEPS=$5
9 | MAX_STEPS=$6
10 | MAX_DOC_LENGTH=$7
11 | MAX_QUERY_LENGTH=$8
12 |
13 | BASE_DATA_DIR=data_preprocess/${DATASET}
14 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE}
15 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv
16 | TEST_FILE=${BASE_DATA_DIR}/${DATASET}-test.qa.csv
17 |
18 | TOKENIZER_NAME=bert-base-uncased
19 | PRETRAINED_MODEL_NAME=master
20 | #PRETRAINED_MODEL_NAME=Luyu/co-condenser-marco
21 | EVAL_BATCHSIZE=256
22 |
23 | ## Remember to change 'de-checkpoint-${CKPT_NAME}' to 'db-checkpoint-${CKPT_NAME}' when evaluating distillation results.
24 | ####################################### Multi Evaluation ########################################
25 | echo "****************begin Evaluate****************"
26 | cd ..
27 | for ITER in $(seq 0 $((($MAX_STEPS - $FIRST_STEPS)/ $EVAL_STEPS)))
28 | do
29 | CKPT_NAME=$(($FIRST_STEPS + $ITER * $EVAL_STEPS))
30 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} inference_de.py \
31 | --model_type=${MODEL_TYPE} \
32 | --tokenizer_name=${TOKENIZER_NAME} \
33 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \
34 | --eval_model_dir=${BASE_OUTPUT_DIR}/models/${MODEL_PATH}/de-checkpoint-${CKPT_NAME} \
35 | --output_dir=${BASE_OUTPUT_DIR}/eval/${MODEL_PATH}/de-checkpoint-${CKPT_NAME} \
36 | --test_file=${TEST_FILE} \
37 | --passage_path=${PASSAGE_PATH} \
38 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \
39 | --per_gpu_eval_batch_size $EVAL_BATCHSIZE
40 | done
41 |
42 | echo "****************End Evaluate****************"
43 |
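44 | # Illustrative invocation (placeholder values; adjust to your run):
45 | #   bash evaluate_12_layer_de.sh nq 23456 <model dir under model_output/dual_encoder/models> 10000 10000 50000 128 32
46 | # The loop above evaluates every de-checkpoint-${STEP} from FIRST_STEPS up to MAX_STEPS in increments of EVAL_STEPS.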
--------------------------------------------------------------------------------
/LEAD/inference/evaluate_6_layer_de.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MODEL_TYPE=distilbert
4 | DATASET=$1
5 | MASTER_PORT=$2
6 | MODEL_PATH=$3
7 | FIRST_STEPS=$4
8 | EVAL_STEPS=$5
9 | MAX_STEPS=$6
10 | MAX_DOC_LENGTH=$7
11 | MAX_QUERY_LENGTH=$8
12 |
13 | BASE_DATA_DIR=data_preprocess/${DATASET}
14 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE}
15 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv
16 | TEST_FILE=${BASE_DATA_DIR}/${DATASET}-test.qa.csv
17 |
18 | TOKENIZER_NAME=bert-base-uncased
19 | PRETRAINED_MODEL_NAME=distilbert-base-uncased
20 | EVAL_BATCHSIZE=256
21 |
22 | ####################################### Multi Evaluation ########################################
23 | echo "****************begin Evaluate****************"
24 | cd ..
25 | for ITER in $(seq 0 $((($MAX_STEPS - $FIRST_STEPS)/ $EVAL_STEPS)))
26 | do
27 | CKPT_NAME=$(($FIRST_STEPS + $ITER * $EVAL_STEPS))
28 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} inference_de.py \
29 | --model_type=${MODEL_TYPE} \
30 | --tokenizer_name=${TOKENIZER_NAME} \
31 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \
32 | --eval_model_dir=${BASE_OUTPUT_DIR}/models/${MODEL_PATH}/db-checkpoint-${CKPT_NAME} \
33 | --output_dir=${BASE_OUTPUT_DIR}/eval/${MODEL_PATH}/db-checkpoint-${CKPT_NAME} \
34 | --test_file=${TEST_FILE} \
35 | --passage_path=${PASSAGE_PATH} \
36 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \
37 | --per_gpu_eval_batch_size $EVAL_BATCHSIZE
38 | done
39 |
40 | echo "****************End Evaluate****************"
41 |
--------------------------------------------------------------------------------
/LEAD/warm_up/train_12_layer_ce.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MODEL_TYPE=cross_encoder
4 | DATASET=$1
5 | MASTER_PORT=$2
6 | MAX_DOC_LENGTH=$3
7 | MAX_QUERY_LENGTH=$4
8 | TRAIN_BATCH_SIZE=$5
9 | NUM_NEGATIVES=$6
10 |
11 | BASE_DATA_DIR=data_preprocess/${DATASET}
12 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE}
13 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train-hard.json
14 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv
15 |
16 | TOKENIZER_NAME=bert-base-uncased
17 | PRETRAINED_MODEL_NAME=Luyu/co-condenser-marco
18 |
19 | MAX_STEPS=20000
20 | GRADIENT_ACCUMULATION_STEPS=10
21 | LOGGING_STEPS=10
22 | SAVE_STEPS=10
23 | LR=1e-5
24 | ######################################## Training ########################################
25 | echo "****************begin Train****************"
26 | cd ../
27 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \
28 | ./run_single_model.py \
29 | --model_type=${MODEL_TYPE} \
30 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \
31 | --tokenizer_name=${TOKENIZER_NAME} \
32 | --train_file=${TRAIN_FILE} \
33 | --dataset=${DATASET} \
34 | --passage_path=${PASSAGE_PATH} \
35 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \
36 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \
37 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \
38 | --share_weight \
39 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \
40 | --output_dir=${BASE_OUTPUT_DIR}/models
41 |
42 | echo "****************End Train****************"
--------------------------------------------------------------------------------
/LEAD/warm_up/train_12_layer_col.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MODEL_TYPE=colbert
4 | DATASET=$1
5 | MASTER_PORT=$2
6 | MAX_DOC_LENGTH=$3
7 | MAX_QUERY_LENGTH=$4
8 | TRAIN_BATCH_SIZE=$5
9 | NUM_NEGATIVES=$6
10 |
11 | BASE_DATA_DIR=data_preprocess/${DATASET}
12 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE}
13 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train-hard.json
14 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv
15 |
16 | TOKENIZER_NAME=bert-base-uncased
17 | PRETRAINED_MODEL_NAME=Luyu/co-condenser-marco
18 |
19 | MAX_STEPS=20000
20 | LOGGING_STEPS=10
21 | SAVE_STEPS=10
22 | LR=2e-5
23 | GRADIENT_ACCUMULATION_STEPS=1
24 | ######################################## Training ########################################
25 | echo "****************begin Train****************"
26 | cd ../
27 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \
28 | ./run_single_model.py \
29 | --model_type=${MODEL_TYPE} \
30 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \
31 | --tokenizer_name=${TOKENIZER_NAME} \
32 | --train_file=${TRAIN_FILE} \
33 | --dataset=${DATASET} \
34 | --passage_path=${PASSAGE_PATH} \
35 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \
36 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \
37 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \
38 | --share_weight \
39 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \
40 | --output_dir=${BASE_OUTPUT_DIR}/models
41 |
42 | echo "****************End Train****************"
--------------------------------------------------------------------------------
/LEAD/warm_up/train_12_layer_de.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MODEL_TYPE=dual_encoder
4 | DATASET=$1
5 | MASTER_PORT=$2
6 | MAX_DOC_LENGTH=$3
7 | MAX_QUERY_LENGTH=$4
8 | TRAIN_BATCH_SIZE=$5
9 | NUM_NEGATIVES=$6
10 |
11 | BASE_DATA_DIR=data_preprocess/${DATASET}
12 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE}
13 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train-hard.json
14 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv
15 |
16 | TOKENIZER_NAME=bert-base-uncased
17 | PRETRAINED_MODEL_NAME=master
18 | #PRETRAINED_MODEL_NAME=Luyu/co-condenser-marco
19 |
20 | MAX_STEPS=20000
21 | LOGGING_STEPS=10
22 | SAVE_STEPS=10
23 | LR=2e-5
24 | GRADIENT_ACCUMULATION_STEPS=1
25 | ######################################## Training ########################################
26 | echo "****************begin Train****************"
27 | cd ../
28 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \
29 | ./run_single_model.py \
30 | --model_type=${MODEL_TYPE} \
31 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \
32 | --tokenizer_name=${TOKENIZER_NAME} \
33 | --train_file=${TRAIN_FILE} \
34 | --dataset=${DATASET} \
35 | --passage_path=${PASSAGE_PATH} \
36 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \
37 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \
38 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \
39 | --share_weight \
40 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \
41 | --output_dir=${BASE_OUTPUT_DIR}/models
42 |
43 | echo "****************End Train****************"
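44 |
45 | # Illustrative invocation (placeholder values; adjust to your setup):
46 | #   bash train_12_layer_de.sh nq 23456 128 32 16 15
47 | # Arguments: DATASET, MASTER_PORT, MAX_DOC_LENGTH, MAX_QUERY_LENGTH, TRAIN_BATCH_SIZE, NUM_NEGATIVES.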
--------------------------------------------------------------------------------
/LEAD/warm_up/train_24_layer_col.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MODEL_TYPE=colbert
4 | DATASET=$1
5 | MASTER_PORT=$2
6 | MAX_DOC_LENGTH=$3
7 | MAX_QUERY_LENGTH=$4
8 | TRAIN_BATCH_SIZE=$5
9 | NUM_NEGATIVES=$6
10 |
11 | BASE_DATA_DIR=data_preprocess/${DATASET}
12 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE}
13 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train-hard.json
14 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv
15 |
16 | TOKENIZER_NAME=bert-base-uncased
17 | PRETRAINED_MODEL_NAME=nghuyong/ernie-2.0-large-en
18 |
19 | MAX_STEPS=20000
20 | LOGGING_STEPS=10
21 | SAVE_STEPS=10
22 | LR=1e-5
23 | GRADIENT_ACCUMULATION_STEPS=1
24 | ######################################## Training ########################################
25 | echo "****************begin Train****************"
26 | cd ../
27 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \
28 | ./run_single_model.py \
29 | --model_type=${MODEL_TYPE} \
30 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \
31 | --tokenizer_name=${TOKENIZER_NAME} \
32 | --train_file=${TRAIN_FILE} \
33 | --dataset=${DATASET} \
34 | --passage_path=${PASSAGE_PATH} \
35 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \
36 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \
37 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \
38 | --share_weight \
39 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \
40 | --output_dir=${BASE_OUTPUT_DIR}/models
41 |
42 | echo "****************End Train****************"
--------------------------------------------------------------------------------
/LEAD/warm_up/train_6_layer_de.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MODEL_TYPE=distilbert
4 | DATASET=$1
5 | MASTER_PORT=$2
6 | MAX_DOC_LENGTH=$3
7 | MAX_QUERY_LENGTH=$4
8 | TRAIN_BATCH_SIZE=$5
9 | NUM_NEGATIVES=$6
10 |
11 | BASE_DATA_DIR=data_preprocess/${DATASET}
12 | BASE_OUTPUT_DIR=model_output/${MODEL_TYPE}
13 | TRAIN_FILE=${BASE_DATA_DIR}/biencoder-${DATASET}-train-hard.json
14 | PASSAGE_PATH=${BASE_DATA_DIR}/psgs_w100.tsv
15 |
16 | TOKENIZER_NAME=bert-base-uncased
17 | PRETRAINED_MODEL_NAME=distilbert-base-uncased
18 |
19 | MAX_STEPS=20000
20 | LOGGING_STEPS=10
21 | SAVE_STEPS=10
22 | LR=2e-5
23 | GRADIENT_ACCUMULATION_STEPS=1
24 | ######################################## Training ########################################
25 | echo "****************begin Train****************"
26 | cd ../
27 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \
28 | ./run_single_model.py \
29 | --model_type=${MODEL_TYPE} \
30 | --pretrained_model_name=${PRETRAINED_MODEL_NAME} \
31 | --tokenizer_name=${TOKENIZER_NAME} \
32 | --train_file=${TRAIN_FILE} \
33 | --dataset=${DATASET} \
34 | --passage_path=${PASSAGE_PATH} \
35 | --max_doc_length=${MAX_DOC_LENGTH} --max_query_length=${MAX_QUERY_LENGTH} \
36 | --per_gpu_train_batch_size=${TRAIN_BATCH_SIZE} --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \
37 | --learning_rate=${LR} --num_hard_negatives ${NUM_NEGATIVES} \
38 | --share_weight \
39 | --logging_steps ${LOGGING_STEPS} --save_steps ${SAVE_STEPS} --max_steps ${MAX_STEPS} --seed 888 \
40 | --output_dir=${BASE_OUTPUT_DIR}/models
41 |
42 | echo "****************End Train****************"
43 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MASTER/figs/master_main.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/MASTER/figs/master_main.jpg
--------------------------------------------------------------------------------
/MASTER/figs/master_main_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/MASTER/figs/master_main_result.jpg
--------------------------------------------------------------------------------
/MASTER/finetune/ft_MS_MASTER.sh:
--------------------------------------------------------------------------------
1 | EXP_NAME=run_de_ms_MASTER_all_1 # de means dual encoder.
2 | DATA_DIR=/kun_data/marco/
3 | OUT_DIR=output/$EXP_NAME
4 | TB_DIR=tensorboard_log/$EXP_NAME # tensorboard log path
5 |
6 | for epoch in 80000
7 | do
8 |
9 | # Fine-tune with BM25 negatives, using the checkpoint from this epoch as initialization
10 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=69539 \
11 | ./MS/run_de_model.py \
12 | --model_type=/kun_data/Austerlitz/MT5_Gen/ckpt/SIMLM/checkpoint-$epoch \
13 | --origin_data_dir=$DATA_DIR/train_stage1.tsv \
14 | --origin_data_dir_dev=$DATA_DIR/dev.query.txt \
15 | --passage_path=/kun_data/marco \
16 | --dataset=MS-MARCO \
17 | --max_seq_length=128 --per_gpu_train_batch_size=8 --gradient_accumulation_steps=1 \
18 | --learning_rate=5e-6 --output_dir $OUT_DIR \
19 | --warmup_steps 1000 --logging_steps 200 --save_steps 5000 --max_steps 30000 \
20 | --log_dir $TB_DIR \
21 | --number_neg 31 --fp16
22 |
23 | # Evaluation (about 3 hours in total)
24 | for CKPT_NUM in 20000 25000 30000
25 | do
26 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \
27 | ./MS/inference_de.py \
28 | --model_type=bert-base-uncased \
29 | --eval_model_dir=$OUT_DIR/checkpoint-$CKPT_NUM \
30 | --output_dir=$OUT_DIR/$CKPT_NUM \
31 | --train_qa_path=/kun_data/marco/train.query.txt \
32 | --test_qa_path=/kun_data/marco/dev.query.txt \
33 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \
34 | --passage_path=/kun_data/marco \
35 | --dataset=MS-MARCO \
36 | --fp16
37 | done
38 |
39 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \
40 | ./MS/inference_de.py \
41 | --model_type=bert-base-uncased \
42 | --eval_model_dir=$OUT_DIR/checkpoint-25000 \
43 | --output_dir=$OUT_DIR/25000 \
44 | --train_qa_path=/kun_data/marco/train.query.txt \
45 | --test_qa_path=/kun_data/marco/dev.query.txt \
46 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \
47 | --passage_path=/kun_data/marco \
48 | --dataset=MS-MARCO \
49 | --fp16 --write_hardneg=True
50 |
51 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=69539 \
52 | ./MS/run_de_model.py \
53 | --model_type=/kun_data/Austerlitz/MT5_Gen/ckpt/SIMLM/checkpoint-$epoch \
54 | --origin_data_dir=$OUT_DIR/25000/train_ce_hardneg.tsv \
55 | --origin_data_dir_dev=$DATA_DIR/dev.query.txt \
56 | --passage_path=/kun_data/marco \
57 | --dataset=MS-MARCO \
58 | --max_seq_length=128 --per_gpu_train_batch_size=8 --gradient_accumulation_steps=1 \
59 | --learning_rate=5e-6 --output_dir $OUT_DIR \
60 | --warmup_steps 1000 --logging_steps 200 --save_steps 5000 --max_steps 40000 \
61 | --log_dir $TB_DIR \
62 | --number_neg 31 --fp16
63 |
64 | # Evaluation (about 3 hours in total)
65 | for CKPT_NUM in 25000 30000 35000 40000
66 | do
67 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \
68 | ./MS/inference_de.py \
69 | --model_type=bert-base-uncased \
70 | --eval_model_dir=$OUT_DIR/checkpoint-$CKPT_NUM \
71 | --output_dir=$OUT_DIR/$CKPT_NUM \
72 | --train_qa_path=/kun_data/marco/train.query.txt \
73 | --test_qa_path=/kun_data/marco/dev.query.txt \
74 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \
75 | --passage_path=/kun_data/marco \
76 | --dataset=MS-MARCO \
77 | --fp16
78 | done
79 |
80 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \
81 | ./MS/inference_de.py \
82 | --model_type=bert-base-uncased \
83 | --eval_model_dir=$OUT_DIR/checkpoint-40000 \
84 | --output_dir=$OUT_DIR/40000 \
85 | --train_qa_path=/kun_data/marco/train.query.txt \
86 | --test_qa_path=/kun_data/marco/dev.query.txt \
87 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \
88 | --passage_path=/kun_data/marco \
89 | --dataset=MS-MARCO \
90 | --fp16 --write_hardneg=True
91 |
92 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \
93 | ./MS/run_ce_model_ele.py \
94 | --model_type=google/electra-base-discriminator --max_seq_length=192 \
95 | --per_gpu_train_batch_size=1 --gradient_accumulation_steps=8 \
96 | --number_neg=63 --learning_rate=1e-5 \
97 | --output_dir=$OUT_DIR \
98 | --origin_data_dir=$OUT_DIR/32000/train_ce_hardneg.tsv \
99 | --origin_data_dir_dev=$DATA_DIR/dev.query.txt \
100 | --passage_path=/kun_data/marco \
101 | --dataset=MS-MARCO \
102 | --warmup_steps=2000 --logging_steps=500 --save_steps=8000 \
103 | --max_steps=33000 --log_dir=$TB_DIR
104 |
105 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 MS/co_training_model_ele.py \
106 | --model_type=/kun_data/Austerlitz/MT5_Gen/ckpt/SIMLM/checkpoint-$epoch \
107 | --max_seq_length=128 --per_gpu_train_batch_size=16 --gradient_accumulation_steps=4 \
108 | --number_neg=41 --learning_rate=5e-6 \
109 | --reranker_model_type=google/electra-base-discriminator \
110 | --reranker_model_path=output/run_de_ms_MASTER_all_1/checkpoint-ce-24000 \
111 | --output_dir=$OUT_DIR \
112 | --log_dir=$TB_DIR \
113 | --origin_data_dir=output/run_de_ms_MASTER_all_1/32000/train_ce_hardneg.tsv \
114 | --origin_data_dir_dev=$DATA_DIR/dev.query.txt \
115 | --passage_path=/kun_data/marco \
116 | --dataset=MS-MARCO \
117 | --warmup_steps 5000 --logging_steps 500 --save_steps 5000 --max_steps 25000 \
118 | --gradient_checkpointing --normal_loss \
119 | --temperature_normal=1
120 |
121 | for CKPT_NUM in 10000 15000 20000 25000
122 | do
123 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \
124 | ./MS/inference_de.py \
125 | --model_type=bert-base-uncased \
126 | --eval_model_dir=$OUT_DIR/checkpoint-$CKPT_NUM \
127 | --output_dir=$OUT_DIR/$CKPT_NUM \
128 | --train_qa_path=/kun_data/marco/train.query.txt \
129 | --test_qa_path=/kun_data/marco/dev.query.txt \
130 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \
131 | --passage_path=/kun_data/marco \
132 | --dataset=MS-MARCO \
133 | --fp16
134 | done
135 |
136 | done
137 |
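138 | # Descriptive summary of the pipeline above: (1) train the dual encoder on BM25 negatives (train_stage1.tsv),
139 | # (2) evaluate several checkpoints, (3) mine hard negatives with --write_hardneg=True, (4) retrain the dual
140 | # encoder on the mined negatives, (5) train an ELECTRA cross-encoder reranker on the hard negatives, (6) run
141 | # co_training_model_ele.py to train the dual encoder together with the reranker, and (7) evaluate the final checkpoints.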
--------------------------------------------------------------------------------
/MASTER/finetune/ft_wiki_NQ.sh:
--------------------------------------------------------------------------------
1 | EXP_NAME=run_de_nq_MT5_Gen_hardneg_KD # de means dual encoder.
2 | DATA_DIR=/kun_data/DR/ARR/data/
3 | OUT_DIR=output/$EXP_NAME
4 | TB_DIR=tensorboard_log/$EXP_NAME # tensorboard log path
5 |
6 | for epoch in 140000
7 | do
8 |
9 | # Fine-tune with BM25 negatives, using the checkpoint from this epoch as initialization
10 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=69539 \
11 | ./wiki/run_de_model.py \
12 | --model_type=/kun_data/Austerlitz/MT5_Gen/ckpt/Wiki_1/checkpoint-$epoch \
13 | --origin_data_dir=$DATA_DIR/biencoder-nq-train.json \
14 | --origin_data_dir_dev=$DATA_DIR/biencoder-nq-dev.json \
15 | --max_seq_length=128 --per_gpu_train_batch_size=32 --gradient_accumulation_steps=1 \
16 | --learning_rate=1e-5 --output_dir $OUT_DIR \
17 | --warmup_steps 8000 --logging_steps 100 --save_steps 10000 --max_steps 80000 \
18 | --log_dir $TB_DIR \
19 | --number_neg 1 --fp16
20 |
21 | # Evaluation (about 3 hours in total)
22 | for CKPT_NUM in 40000 50000 60000 70000 80000
23 | do
24 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \
25 | ./wiki/inference_de.py \
26 | --model_type=bert-base-uncased \
27 | --eval_model_dir=$OUT_DIR/checkpoint-$CKPT_NUM \
28 | --output_dir=$OUT_DIR/$CKPT_NUM \
29 | --test_qa_path=$DATA_DIR/nq-test.qa.csv \
30 | --train_qa_path=$DATA_DIR/nq-train.qa.csv \
31 | --dev_qa_path=$DATA_DIR/nq-dev.qa.csv \
32 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \
33 | --passage_path=$DATA_DIR/psgs_w100.tsv \
34 | --fp16
35 | done
36 |
37 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \
38 | ./wiki/inference_de.py \
39 | --model_type=bert-base-uncased \
40 | --eval_model_dir=$OUT_DIR/checkpoint-50000 \
41 | --output_dir=$OUT_DIR/$CKPT_NUM \
42 | --test_qa_path=$DATA_DIR/nq-test.qa.csv \
43 | --train_qa_path=$DATA_DIR/nq-train.qa.csv \
44 | --dev_qa_path=$DATA_DIR/nq-dev.qa.csv \
45 | --golden_train_qa_path=$DATA_DIR/biencoder-nq-train.json \
46 | --golden_dev_qa_path=$DATA_DIR/biencoder-nq-dev.json \
47 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \
48 | --passage_path=$DATA_DIR/psgs_w100.tsv \
49 | --fp16 --write_hardneg=True
50 |
51 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=69539 \
52 | ./wiki/run_de_model.py \
53 | --model_type=/kun_data/Austerlitz/MT5_Gen/ckpt/Wiki_1/checkpoint-$epoch \
54 | --origin_data_dir=$OUT_DIR/80000/train_ce_hardneg.json \
55 | --origin_data_dir_dev=$OUT_DIR/80000/dev_ce_hardneg.json \
56 | --max_seq_length=128 --per_gpu_train_batch_size=16 --gradient_accumulation_steps=1 \
57 | --learning_rate=5e-6 --output_dir $OUT_DIR \
58 | --warmup_steps 8000 --logging_steps 100 --save_steps 5000 --max_steps 50000 \
59 | --log_dir $TB_DIR \
60 | --number_neg 1 --fp16
61 |
62 | # Evaluation (about 3 hours in total)
63 | for CKPT_NUM in 30000 35000 40000 45000 50000
64 | do
65 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \
66 | ./wiki/inference_de.py \
67 | --model_type=bert-base-uncased \
68 | --eval_model_dir=$OUT_DIR/checkpoint-$CKPT_NUM \
69 | --output_dir=$OUT_DIR/$CKPT_NUM \
70 | --test_qa_path=$DATA_DIR/nq-test.qa.csv \
71 | --train_qa_path=$DATA_DIR/nq-train.qa.csv \
72 | --dev_qa_path=$DATA_DIR/nq-dev.qa.csv \
73 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \
74 | --passage_path=$DATA_DIR/psgs_w100.tsv \
75 | --fp16
76 | done
77 | : <<'DISABLED_BLOCK'  # the commands below are intentionally disabled (quoted heredoc no-op)
78 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \
79 | ./wiki/inference_de.py \
80 | --model_type=bert-base-uncased \
81 | --eval_model_dir=$OUT_DIR/checkpoint-50000 \
82 | --output_dir=$OUT_DIR/$CKPT_NUM \
83 | --test_qa_path=$DATA_DIR/nq-test.qa.csv \
84 | --train_qa_path=$DATA_DIR/nq-train.qa.csv \
85 | --dev_qa_path=$DATA_DIR/nq-dev.qa.csv \
86 | --golden_train_qa_path=$DATA_DIR/biencoder-nq-train.json \
87 | --golden_dev_qa_path=$DATA_DIR/biencoder-nq-dev.json \
88 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \
89 | --passage_path=$DATA_DIR/psgs_w100.tsv \
90 | --fp16 --write_hardneg=True
91 |
92 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \
93 | ./wiki/run_ce_model.py \
94 | --model_type=bert-large-uncased --max_seq_length=256 \
95 | --per_gpu_train_batch_size=1 --gradient_accumulation_steps=8 \
96 | --number_neg=15 --learning_rate=1e-5 \
97 | --output_dir=$OUT_DIR \
98 | --origin_data_dir=$OUT_DIR/$CKPT_NUM/train_ce_hardneg.json \
99 | --origin_data_dir_dev=$OUT_DIR/$CKPT_NUM/dev_ce_hardneg.json \
100 | --warmup_steps=1000 --logging_steps=100 --save_steps=1000 \
101 | --max_steps=5000 --log_dir=$TB_DIR
102 |
103 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 wiki/co_training_model.py \
104 | --model_type=/kun_data/Austerlitz/MT5_Gen/ckpt/Wiki/checkpoint-$epoch \
105 | --max_seq_length=128 --per_gpu_train_batch_size=8 --gradient_accumulation_steps=1 \
106 | --number_neg=15 --learning_rate=1e-5 \
107 | --reranker_model_type=nghuyong/ernie-2.0-large-en \
108 | --reranker_model_path=$OUT_DIR/checkpoint-ce-4000 \
109 | --output_dir=$OUT_DIR \
110 | --log_dir=$TB_DIR \
111 | --origin_data_dir=$OUT_DIR/$CKPT_NUM/train_ce_hardneg.json \
112 | --origin_data_dir_dev=$OUT_DIR/$CKPT_NUM/dev_ce_hardneg.json \
113 | --warmup_steps=8000 --logging_steps=10 --save_steps=10000 --max_steps=80000 \
114 | --gradient_checkpointing --normal_loss \
115 | --temperature_normal=1
116 |
117 | for CKPT_NUM in 40000 50000 60000 70000 80000
118 | do
119 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \
120 | ./wiki/inference_de.py \
121 | --model_type=bert-base-uncased \
122 | --eval_model_dir=$OUT_DIR/checkpoint-$CKPT_NUM \
123 | --output_dir=$OUT_DIR/$CKPT_NUM \
124 | --test_qa_path=$DATA_DIR/nq-test.qa.csv \
125 | --train_qa_path=$DATA_DIR/nq-train.qa.csv \
126 | --dev_qa_path=$DATA_DIR/nq-dev.qa.csv \
127 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \
128 | --passage_path=$DATA_DIR/psgs_w100.tsv \
129 | --fp16
130 | done
131 | DISABLED_BLOCK
132 | done
133 |
--------------------------------------------------------------------------------
/MASTER/finetune/ft_wiki_TQ.sh:
--------------------------------------------------------------------------------
1 | EXP_NAME=run_de_tq_MT5_Gen_hardneg_KD # de means dual encoder.
2 | DATA_DIR=/kun_data/DR/ARR/data/
3 | OUT_DIR=output/$EXP_NAME
4 | TB_DIR=tensorboard_log/$EXP_NAME # tensorboard log path
5 |
6 | for epoch in 80000
7 | do
8 | # Fine-tune with BM25 negatives, using the checkpoint from this epoch as initialization
9 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=69539 \
10 | ./wiki/run_de_model.py \
11 | --model_type=/kun_data/Austerlitz/MT5_Gen/ckpt/Wiki/checkpoint-$epoch \
12 | --origin_data_dir=$DATA_DIR/biencoder-trivia-train.json \
13 | --origin_data_dir_dev=$DATA_DIR/biencoder-trivia-dev.json \
14 | --max_seq_length=128 --per_gpu_train_batch_size=32 --gradient_accumulation_steps=1 \
15 | --learning_rate=1e-5 --output_dir $OUT_DIR \
16 | --warmup_steps 8000 --logging_steps 100 --save_steps 10000 --max_steps 80000 \
17 | --log_dir $TB_DIR \
18 | --number_neg 1 --fp16
19 |
20 | # Evaluation (about 3 hours in total)
21 | for CKPT_NUM in 40000 50000 60000 70000 80000
22 | do
23 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \
24 | ./wiki/inference_de.py \
25 | --model_type=bert-base-uncased \
26 | --eval_model_dir=$OUT_DIR/checkpoint-$CKPT_NUM \
27 | --output_dir=$OUT_DIR/$CKPT_NUM \
28 | --test_qa_path=$DATA_DIR/trivia-test.qa.csv \
29 | --train_qa_path=$DATA_DIR/trivia-train.qa.csv \
30 | --dev_qa_path=$DATA_DIR/trivia-dev.qa.csv \
31 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \
32 | --passage_path=$DATA_DIR/psgs_w100.tsv \
33 | --fp16
34 | done
35 |
36 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \
37 | ./wiki/inference_de.py \
38 | --model_type=bert-base-uncased \
39 | --eval_model_dir=$OUT_DIR/checkpoint-50000 \
40 | --output_dir=$OUT_DIR/$CKPT_NUM \
41 | --test_qa_path=$DATA_DIR/trivia-test.qa.csv \
42 | --train_qa_path=$DATA_DIR/trivia-train.qa.csv \
43 | --dev_qa_path=$DATA_DIR/trivia-dev.qa.csv \
44 | --golden_train_qa_path=$DATA_DIR/biencoder-trivia-train.json \
45 | --golden_dev_qa_path=$DATA_DIR/biencoder-trivia-dev.json \
46 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \
47 | --passage_path=$DATA_DIR/psgs_w100.tsv \
48 | --fp16 --write_hardneg=True
49 |
50 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=69539 \
51 | ./wiki/run_de_model.py \
52 | --model_type=/kun_data/Austerlitz/MT5_Gen/ckpt/Wiki/checkpoint-$epoch \
53 | --origin_data_dir=$OUT_DIR/$CKPT_NUM/train_ce_hardneg.json \
54 | --origin_data_dir_dev=$OUT_DIR/$CKPT_NUM/dev_ce_hardneg.json \
55 | --max_seq_length=128 --per_gpu_train_batch_size=8 --gradient_accumulation_steps=1 \
56 | --learning_rate=5e-6 --output_dir $OUT_DIR \
57 | --warmup_steps 8000 --logging_steps 100 --save_steps 10000 --max_steps 80000 \
58 | --log_dir $TB_DIR \
59 | --number_neg 1 --fp16
60 |
61 | # Evaluation (about 3 hours in total)
62 | for CKPT_NUM in 40000 50000 60000 70000 80000
63 | do
64 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=19539 \
65 | ./wiki/inference_de.py \
66 | --model_type=bert-base-uncased \
67 | --eval_model_dir=$OUT_DIR/checkpoint-$CKPT_NUM \
68 | --output_dir=$OUT_DIR/$CKPT_NUM \
69 | --test_qa_path=$DATA_DIR/trivia-test.qa.csv \
70 | --train_qa_path=$DATA_DIR/trivia-train.qa.csv \
71 | --dev_qa_path=$DATA_DIR/trivia-dev.qa.csv \
72 | --max_seq_length=128 --per_gpu_eval_batch_size=1024 \
73 | --passage_path=$DATA_DIR/psgs_w100.tsv \
74 | --fp16
75 | done
76 |
77 | done
--------------------------------------------------------------------------------
/MASTER/finetune/utils/lamb.py:
--------------------------------------------------------------------------------
1 | """Lamb optimizer."""
2 |
3 | import collections
4 | import math
5 |
6 | import torch
7 | from tensorboardX import SummaryWriter
8 | from torch.optim import Optimizer
9 |
10 |
11 | def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
12 | """Log a histogram of trust ratio scalars in across layers."""
13 | results = collections.defaultdict(list)
14 | for group in optimizer.param_groups:
15 | for p in group['params']:
16 | state = optimizer.state[p]
17 | for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
18 | if i in state:
19 | results[i].append(state[i])
20 |
21 | for k, v in results.items():
22 | event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)
23 |
24 | class Lamb(Optimizer):
25 | r"""Implements Lamb algorithm.
26 |
27 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
28 |
29 | Arguments:
30 | params (iterable): iterable of parameters to optimize or dicts defining
31 | parameter groups
32 | lr (float, optional): learning rate (default: 1e-3)
33 | betas (Tuple[float, float], optional): coefficients used for computing
34 | running averages of gradient and its square (default: (0.9, 0.999))
35 | eps (float, optional): term added to the denominator to improve
36 | numerical stability (default: 1e-8)
37 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
38 | adam (bool, optional): always use trust ratio = 1, which turns this into
39 | Adam. Useful for comparison purposes.
40 |
41 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
42 | https://arxiv.org/abs/1904.00962
43 | """
44 |
45 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
46 | weight_decay=0, adam=False):
47 | if not 0.0 <= lr:
48 | raise ValueError("Invalid learning rate: {}".format(lr))
49 | if not 0.0 <= eps:
50 | raise ValueError("Invalid epsilon value: {}".format(eps))
51 | if not 0.0 <= betas[0] < 1.0:
52 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
53 | if not 0.0 <= betas[1] < 1.0:
54 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
55 | defaults = dict(lr=lr, betas=betas, eps=eps,
56 | weight_decay=weight_decay)
57 | self.adam = adam
58 | super(Lamb, self).__init__(params, defaults)
59 |
60 | def step(self, closure=None):
61 | """Performs a single optimization step.
62 |
63 | Arguments:
64 | closure (callable, optional): A closure that reevaluates the model
65 | and returns the loss.
66 | """
67 | loss = None
68 | if closure is not None:
69 | loss = closure()
70 |
71 | for group in self.param_groups:
72 | for p in group['params']:
73 | if p.grad is None:
74 | continue
75 | grad = p.grad.data
76 | if grad.is_sparse:
77 | raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
78 |
79 | state = self.state[p]
80 |
81 | # State initialization
82 | if len(state) == 0:
83 | state['step'] = 0
84 | # Exponential moving average of gradient values
85 | state['exp_avg'] = torch.zeros_like(p.data)
86 | # Exponential moving average of squared gradient values
87 | state['exp_avg_sq'] = torch.zeros_like(p.data)
88 |
89 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
90 | beta1, beta2 = group['betas']
91 |
92 | state['step'] += 1
93 |
94 | # Decay the first and second moment running average coefficient
95 | # m_t
96 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
97 | # v_t
98 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
99 |
100 | # Paper v3 does not use debiasing.
101 | # Apply bias to lr to avoid broadcast.
102 | step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1
103 |
104 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
105 |
106 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
107 | if group['weight_decay'] != 0:
108 | adam_step.add_(group['weight_decay'], p.data)
109 |
110 | adam_norm = adam_step.pow(2).sum().sqrt()
111 | if weight_norm == 0 or adam_norm == 0:
112 | trust_ratio = 1
113 | else:
114 | trust_ratio = weight_norm / adam_norm
115 | state['weight_norm'] = weight_norm
116 | state['adam_norm'] = adam_norm
117 | state['trust_ratio'] = trust_ratio
118 | if self.adam:
119 | trust_ratio = 1
120 |
121 | p.data.add_(-step_size * trust_ratio, adam_step)
122 |
123 | return loss
124 |
--------------------------------------------------------------------------------
/MASTER/pretrain/arguments.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional, Union
3 | import os
4 | from transformers import TrainingArguments
5 |
6 | @dataclass
7 | class DataTrainingArguments:
8 | """
9 | Arguments pertaining to what data we are going to input our model for training and eval.
10 | """
11 |
12 | dataset_name: Optional[str] = field(
13 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
14 | )
15 | dataset_config_name: Optional[str] = field(
16 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
17 | )
18 | train_dir: str = field(
19 | default=None, metadata={"help": "Path to train directory"}
20 | )
21 | train_path: Union[str] = field(
22 | default=None, metadata={"help": "Path to train data"}
23 | )
24 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
25 | validation_file: Optional[str] = field(
26 | default=None,
27 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
28 | )
29 | train_ref_file: Optional[str] = field(
30 | default=None,
31 | metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
32 | )
33 | validation_ref_file: Optional[str] = field(
34 | default=None,
35 | metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
36 | )
37 | frequency_dict: Union[str] = field(
38 | default=None, metadata={"help": "Path to frequency dict"}
39 | )
40 | overwrite_cache: bool = field(
41 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
42 | )
43 | max_seq_length: Optional[int] = field(
44 | default=None,
45 | metadata={
46 | "help": "The maximum total input sequence length after tokenization. Sequences longer "
47 | "than this will be truncated. Default to the max input length of the model."
48 | },
49 | )
50 | min_seq_length: int = field(default=16)
51 | preprocessing_num_workers: Optional[int] = field(
52 | default=None,
53 | metadata={"help": "The number of processes to use for the preprocessing."},
54 | )
55 | mlm_probability: float = field(
56 | default=0.3, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
57 | )
58 | decoder_mlm_probability: float = field(
59 | default=0.5, metadata={"help": "Ratio of tokens to mask for decoder masked language modeling loss"}
60 | )
61 | pad_to_max_length: bool = field(
62 | default=False,
63 | metadata={
64 | "help": "Whether to pad all samples to `max_seq_length`. "
65 | "If False, will pad the samples dynamically when batching to the maximum length in the batch."
66 | },
67 | )
68 |
69 | def __post_init__(self):
70 | if self.train_dir is not None:
71 | files = os.listdir(self.train_dir)
72 | self.train_path = [
73 | os.path.join(self.train_dir, f)
74 | for f in files
75 | if f.endswith('tsv') or f.endswith('json')
76 | ]
77 |
78 | @dataclass
79 | class ModelArguments:
80 | """
81 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
82 | """
83 |
84 | model_name_or_path: Optional[str] = field(
85 | default=None,
86 | metadata={
87 | "help": "The model checkpoint for weights initialization."
88 | "Don't set if you want to train a model from scratch."
89 | },
90 | )
91 | model_type: Optional[str] = field(
92 | default='bert',
93 | )
94 | config_name: Optional[str] = field(
95 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
96 | )
97 | tokenizer_name: Optional[str] = field(
98 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
99 | )
100 | cache_dir: Optional[str] = field(
101 | default=None,
102 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
103 | )
104 | use_fast_tokenizer: bool = field(
105 | default=True,
106 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
107 | )
108 |
109 | n_head_layers: int = field(default=2)
110 | skip_from: int = field(default=2)
111 | late_mlm: bool = field(default=False)
112 | temp: float = field(default=0.05)
113 |
114 |
115 | @dataclass
116 | class CondenserPreTrainingArguments(TrainingArguments):
117 | warmup_ratio: float = field(default=0.1)
118 |
119 |
120 | @dataclass
121 | class CoCondenserPreTrainingArguments(CondenserPreTrainingArguments):
122 | cache_chunk_size: int = field(default=-1)
123 |
--------------------------------------------------------------------------------
/MASTER/pretrain/run_pre_training.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import logging
3 | import math
4 | import os
5 | import sys
6 | from datasets import load_dataset
7 | import json
8 | from arguments import DataTrainingArguments, ModelArguments, \
9 | CondenserPreTrainingArguments as TrainingArguments
10 | from data import CondenserCollator
11 | from modeling import CondenserForPretraining, RobertaCondenserForPretraining, ELECTRACondenserForPretraining
12 | from trainer import CondenserPreTrainer as Trainer
13 | import transformers
14 | from transformers import (
15 | CONFIG_MAPPING,
16 | AutoConfig,
17 | AutoTokenizer,
18 | HfArgumentParser,
19 | set_seed, )
20 | from transformers.trainer_utils import is_main_process
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 | CONDENSER_TYPE_MAP = {
25 | 'bert': CondenserForPretraining,
26 | 'roberta': RobertaCondenserForPretraining,
27 | }
28 |
29 |
30 | def main():
31 | # See all possible arguments in src/transformers/training_args.py
32 | # or by passing the --help flag to this script.
33 | # We now keep distinct sets of args, for a cleaner separation of concerns.
34 |
35 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
36 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
37 | # If we pass only one argument to the script and it's the path to a json file,
38 | # let's parse it to get our arguments.
39 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
40 | else:
41 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
42 |
43 | if (
44 | os.path.exists(training_args.output_dir)
45 | and os.listdir(training_args.output_dir)
46 | and training_args.do_train
47 | and not training_args.overwrite_output_dir
48 | ):
49 | raise ValueError(
50 | f"Output directory ({training_args.output_dir}) already exists and is not empty."
51 | "Use --overwrite_output_dir to overcome."
52 | )
53 |
54 | model_args: ModelArguments
55 | data_args: DataTrainingArguments
56 | training_args: TrainingArguments
57 |
58 | # Setup logging
59 | logging.basicConfig(
60 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
61 | datefmt="%m/%d/%Y %H:%M:%S",
62 | level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
63 | )
64 |
65 | # Log on each process the small summary:
66 | logger.warning(
67 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
68 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
69 | )
70 | # Set the verbosity to info of the Transformers logger (on main process only):
71 | if is_main_process(training_args.local_rank):
72 | transformers.utils.logging.set_verbosity_info()
73 | transformers.utils.logging.enable_default_handler()
74 | transformers.utils.logging.enable_explicit_format()
75 | logger.info("Training/evaluation parameters %s", training_args)
76 |
77 | # Set seed before initializing model.
78 | set_seed(training_args.seed)
79 |
80 | train_set = load_dataset(
81 | 'json',
82 | data_files=data_args.train_path,
83 | block_size=2**25,
84 | )['train']
85 | dev_set = load_dataset(
86 | 'json',
87 | data_files=data_args.validation_file,
88 | block_size=2**25
89 | )['train'] \
90 | if data_args.validation_file is not None else None
91 |
92 | if model_args.config_name:
93 | config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
94 | elif model_args.model_name_or_path:
95 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
96 | else:
97 | config = CONFIG_MAPPING[model_args.model_type]()
98 | logger.warning("You are instantiating a new config instance from scratch.")
99 |
100 | if model_args.tokenizer_name:
101 | tokenizer = AutoTokenizer.from_pretrained(
102 | model_args.tokenizer_name,
103 | cache_dir=model_args.cache_dir, use_fast=False
104 | )
105 | elif model_args.model_name_or_path:
106 | tokenizer = AutoTokenizer.from_pretrained(
107 | model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=False
108 | )
109 | else:
110 | raise ValueError(
111 | "You are instantiating a new tokenizer from scratch. This is not supported by this script."
112 | "You can do it from another script, save it, and load it from here, using --tokenizer_name."
113 | )
114 |
115 | # initialize the Condenser pre-training LM
116 | if model_args.model_type not in CONDENSER_TYPE_MAP:
117 | raise NotImplementedError(f'Condenser for {model_args.model_type} LM is not implemented')
118 | _condenser_cls = CONDENSER_TYPE_MAP[model_args.model_type]
119 | if model_args.model_name_or_path:
120 | model = _condenser_cls.from_pretrained(
121 | model_args, data_args, training_args,
122 | model_args.model_name_or_path,
123 | from_tf=bool(".ckpt" in model_args.model_name_or_path),
124 | config=config,
125 | cache_dir=model_args.cache_dir,
126 | )
127 | else:
128 | logger.warning('Training from scratch.')
129 | model = _condenser_cls.from_config(
130 | config, model_args, data_args, training_args)
131 |
132 | # Data collator
133 | # This one will take care of randomly masking the tokens.
134 | if data_args.frequency_dict:
135 | logger.info("Load Frequency Dictionary from %s", data_args.frequency_dict)
136 | frequency_dict = json.load(open(data_args.frequency_dict))
137 | else:
138 | frequency_dict = None
139 | data_collator = CondenserCollator(
140 | tokenizer=tokenizer,
141 | mlm_probability=data_args.mlm_probability,
142 | decoder_mlm_probability=data_args.decoder_mlm_probability,
143 | max_seq_length=data_args.max_seq_length,
144 | frequency_dict=frequency_dict
145 | )
146 | # Initialize our Trainer
147 | trainer = Trainer(
148 | model=model,
149 | args=training_args,
150 | train_dataset=train_set,
151 | eval_dataset=dev_set,
152 | tokenizer=tokenizer,
153 | data_collator=data_collator,
154 | )
155 |
156 | # Training
157 | if training_args.do_train:
158 | model_path = (
159 | model_args.model_name_or_path
160 | if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
161 | else None
162 | )
163 | trainer.train(model_path=model_path)
164 | trainer.save_model() # Saves the tokenizer too for easy upload
165 |
166 | # Evaluation
167 | results = {}
168 | if training_args.do_eval:
169 | logger.info("*** Evaluate ***")
170 |
171 | eval_output = trainer.evaluate()
172 |
173 | perplexity = math.exp(eval_output["eval_loss"])
174 | results["perplexity"] = perplexity
175 |
176 | output_eval_file = os.path.join(training_args.output_dir, "eval_results_mlm_wwm.txt")
177 | if trainer.is_world_process_zero():
178 | with open(output_eval_file, "w") as writer:
179 | logger.info("***** Eval results *****")
180 | for key, value in results.items():
181 | logger.info(f" {key} = {value}")
182 | writer.write(f"{key} = {value}\n")
183 |
184 | return results
185 |
186 |
187 | def _mp_fn(index):
188 | # For xla_spawn (TPUs)
189 | main()
190 |
191 |
192 | if __name__ == "__main__":
193 | main()
194 |
--------------------------------------------------------------------------------
/MASTER/pretrain/run_pretrain.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node 8 --master_port 50000 run_pre_training.py \
2 | --output_dir ckpt/MASTER \
3 | --model_name_or_path bert-base-uncased \
4 | --do_train \
5 | --save_steps 40000 \
6 | --per_device_train_batch_size 128 \
7 | --gradient_accumulation_steps 2 \
8 | --warmup_ratio 0.1 \
9 | --learning_rate 3e-4 \
10 | --num_train_epochs 40 \
11 | --overwrite_output_dir \
12 | --dataloader_num_workers 32 \
13 | --n_head_layers 2 \
14 | --skip_from 6 \
15 | --max_seq_length 128 \
16 | --train_dir process_data \
17 | --frequency_dict frequency_dict_MS_doc.json \
18 | --weight_decay 0.01 \
19 | --late_mlm \
20 | --fp16
21 |
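22 | # Descriptive notes: --n_head_layers, --skip_from and --late_mlm are ModelArguments defined in arguments.py;
23 | # --frequency_dict points to a token-frequency JSON that run_pre_training.py loads and hands to the data
24 | # collator; --train_dir is scanned for .tsv/.json files; the remaining flags are standard HuggingFace TrainingArguments.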
--------------------------------------------------------------------------------
/PROD/MarcoDoc_Data.sh:
--------------------------------------------------------------------------------
1 | # download MSMARCO doc data
2 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docs.tsv.gz
3 | gunzip msmarco-docs.tsv.gz
4 |
5 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-doctrain-queries.tsv.gz
6 | gunzip msmarco-doctrain-queries.tsv.gz
7 |
8 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-doctrain-qrels.tsv.gz
9 | gunzip msmarco-doctrain-qrels.tsv.gz
10 |
11 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz
12 | gunzip msmarco-test2019-queries.tsv.gz
13 |
14 | wget https://trec.nist.gov/data/deep/2019qrels-docs.txt
15 |
16 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docdev-queries.tsv.gz
17 | gunzip msmarco-docdev-queries.tsv.gz
18 |
19 | wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docdev-qrels.tsv.gz
20 | gunzip msmarco-docdev-qrels.tsv.gz
21 |
22 | wget https://msmarco.blob.core.windows.net/msmarcoranking/docleaderboard-queries.tsv.gz
--------------------------------------------------------------------------------
/PROD/MarcoPas_Data.sh:
--------------------------------------------------------------------------------
1 | SCRIPT_DIR=$PWD
2 |
3 | wget --no-check-certificate https://rocketqa.bj.bcebos.com/corpus/marco.tar.gz
4 | tar -zxf marco.tar.gz
5 | rm -rf marco.tar.gz
6 | cd marco
7 |
8 | wget https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz
9 | wget https://msmarco.blob.core.windows.net/msmarcoranking/qrels.train.tsv -O qrels.train.tsv
10 | gunzip qidpidtriples.train.full.2.tsv.gz
11 | join -t "$(echo -en '\t')" -e '' -a 1 -o 1.1 2.2 1.2 <(sort -k1,1 para.txt) <(sort -k1,1 para.title.txt) | sort -k1,1 -n > corpus.tsv
12 | awk -v RS='\r\n' '$1==last {printf ",%s",$3; next} NR>1 {print "";} {last=$1; printf "%s\t%s",$1,$3;} END{print "";}' qidpidtriples.train.full.2.tsv > train.negatives.tsv
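13 |
14 | # Note: `join` above merges passage text (para.txt) with titles (para.title.txt) on the passage id,
15 | # producing corpus.tsv with "id<TAB>title<TAB>text"; `awk` then groups qidpidtriples rows by query id so
16 | # train.negatives.tsv has one line per query of the form "qid<TAB>neg_pid1,neg_pid2,...".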
--------------------------------------------------------------------------------
/PROD/ProD_KD/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/PROD/ProD_KD/model/__init__.py
--------------------------------------------------------------------------------
/PROD/ProD_KD/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/PROD/ProD_KD/utils/__init__.py
--------------------------------------------------------------------------------
/PROD/ProD_KD/utils/build_marco_train.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | import random
4 |
5 | relevance_file = "/colab_space/fanshuai/KDmarco/coCondenser-marco/marco/qrels.train.tsv"
6 | query_file = "/colab_space/fanshuai/KDmarco/coCondenser-marco/marco/train.query.txt"
7 | negative_file = "/colab_space/fanshuai/KDmarco/coCondenser-marco/marco/train.negatives.tsv"
8 | outfile = "/colab_space/fanshuai/KDmarco/coCondenser-marco/marco/marco_train.json"
9 | n_sample = 30
10 |
11 | def csv_reader(fd, delimiter='\t', trainer_id=0, trainer_num=1):
12 | def gen():
13 | for i, line in tqdm(enumerate(fd)):
14 | if i % trainer_num == trainer_id:
15 | slots = line.rstrip('\n').split(delimiter)
16 | if len(slots) == 1:
17 | yield slots,
18 | else:
19 | yield slots
20 | return gen()
21 |
22 | def read_qrel(relevance_file):
23 | qrel = {}
24 | with open(relevance_file, encoding='utf8') as f:
25 | tsvreader = csv_reader(f, delimiter="\t")
26 | for [topicid, _, docid, rel] in tsvreader:
27 | assert rel == "1"
28 | if topicid in qrel:
29 | qrel[topicid].append(docid)
30 | else:
31 | qrel[topicid] = [docid]
32 | return qrel
33 |
34 | def read_qstring(query_file):
35 | q_string = {}
36 | with open(query_file, 'r', encoding='utf-8') as file:
37 | for num, line in enumerate(file):
38 | line = line.strip('\n')  # strip the trailing newline
39 | line = line.split('\t')
40 | q_string[line[0]] = line[1]
41 | return q_string
42 |
43 | datalist = []
44 | qrel = read_qrel(relevance_file)
45 | q_string = read_qstring(query_file)
46 | with open(negative_file, 'r', encoding='utf8') as nf:
47 | reader = csv_reader(nf)
48 | for cnt, line in enumerate(reader):
49 | examples = {}
50 | q = line[0]
51 | nn = line[1]
52 | nn = nn.split(',')
53 | random.shuffle(nn)
54 | nn = nn[:n_sample]
55 | examples['query_id'] = q
56 | examples['query_string'] = q_string[q]
57 | examples['pos_id'] = qrel[q]
58 | examples['neg_id'] = nn
59 | if len(examples['query_id'])!=0 and len(examples['pos_id'])!=0 and len(examples['neg_id'])!=0:
60 | datalist.append(examples)
61 |
62 | print("data len:", len(datalist))
63 | print("data keys:", datalist[0].keys())
64 | print("data info:", datalist[0])
65 |
66 | with open(outfile, 'w', encoding='utf-8') as f:
67 | json.dump(datalist, f, indent=2)
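68 |
69 | # Each record written to marco_train.json has the shape (illustrative placeholders):
70 | # {"query_id": "<qid>", "query_string": "<query text>",
71 | #  "pos_id": ["<positive passage id>", ...], "neg_id": ["<up to n_sample negative passage ids>", ...]}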
--------------------------------------------------------------------------------
/PROD/ProD_KD/utils/dataset_division_nq.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import json
3 | from tqdm import tqdm
4 | import os
5 | import sys
6 |
7 |
8 | # data file
9 | # qid \t qstring \t pos_id \t neg_id
10 | # type : str
11 | # result_file_path1="/colab_space/fanshuai/KDnq/result/24CEt6DE_LwF_5e-5/20000/train_result_dict_list.json"
12 | # result_file_path2="/colab_space/fanshuai/KDnq/result/Ranker_24layer/10000trainrank/train_rerank_dict_list.json"
13 | # data_file_path="/colab_space/fanshuai/KDnq/result/24CEt6DE_LwF_5e-5/20000/train_nq_flash4.json"
14 | # output_dir = "/colab_space/fanshuai/KDnq/result/Ranker_24layer/10000div/"
15 |
16 | def read_result(result_file_path):
17 | result_data = []
18 | with open(result_file_path, 'r', encoding="utf-8") as f:
19 | data = json.load(f)
20 | for i, example in tqdm(enumerate(data)):
21 | result_data.append(example)
22 |
23 | return result_data
24 |
25 | def load_data(data_file_path):
26 | with open(data_file_path, 'r', encoding="utf-8") as f:
27 | data = json.load(f)
28 | print('Aggregated data size: {}'.format(len(data)))
29 | # filter those without positive ctx
30 | return data
31 |
32 | # def divide_data(result_data_student):
33 | # ranking = []
34 | # recall_q_top1 = set()
35 | # recall_q_2ti = set()
36 | # recall_q_2t5 = set()
37 | # recall_q_2t10 = set()
38 | # recall_q_2t15 = set()
39 | # recall_q_6t20 = set()
40 | # recall_q_21t50 = set()
41 | # recall_q_51t100 = set()
42 | # recall_q_101tall = set()
43 | #
44 | # for data in result_data_student:
45 | # qid = data['id']
46 | # for i, ctx in enumerate(data['ctxs']):
47 | # if ctx['hit'] == 'True':
48 | # if i < 100 and i >= 50:
49 | # recall_q_51t100.add(qid)
50 | # elif i < 50 and i >= 20:
51 | # recall_q_21t50.add(qid)
52 | # elif i < 20 and i >= 5:
53 | # recall_q_6t20.add(qid)
54 | # elif i < 5 and i >= 1:
55 | # recall_q_2t5.add(qid)
56 | # elif i == 0:
57 | # recall_q_top1.add(qid)
58 | # break
59 | #
60 | # print("top1 data num :", len(recall_q_top1))
61 | # print("top2 to topi data num :", len(recall_q_2ti))
62 | # print("top2 to top5 data num :", len(recall_q_2t5))
63 | # print("top2 to top10 data num :", len(recall_q_2t10))
64 | # print("top2 to top15 data num :", len(recall_q_2t15))
65 | # print("top6 to top20 data num :", len(recall_q_6t20))
66 | # print("top21 to top50 data num :", len(recall_q_21t50))
67 | # print("top51 to top100 data num :", len(recall_q_51t100))
68 | # print("top101 to top1000 data num :", len(recall_q_101tall))
69 | # print()
70 | #
71 | #
72 | # qid_divide_dic = {}
73 | # qid_divide_dic['top1'] = recall_q_top1
74 | # qid_divide_dic['2ti'] = recall_q_2ti
75 | # qid_divide_dic['2t5'] = recall_q_2t5
76 | # qid_divide_dic['2t10'] = recall_q_2t10
77 | # qid_divide_dic['2t15'] = recall_q_2t15
78 | # qid_divide_dic['6t20'] = recall_q_6t20
79 | # qid_divide_dic['21t50'] = recall_q_21t50
80 | # qid_divide_dic['51t100'] = recall_q_51t100
81 | # qid_divide_dic['101tall'] = recall_q_101tall
82 | #
83 | # return qid_divide_dic
84 |
85 | def divide_data(result_data_student, result_data_teacher):
86 | ranking = []
87 | t2_all_better = set()
88 | t2_15_better = set()
89 |
90 | qid_data_dic={}
91 | for data in result_data_teacher:
92 | qid_data_dic[data['id']] = data
93 |
94 | for data in result_data_student:
95 | qid = data['id']
96 |
97 | s_rank = 99999
98 | t_rank = 99999
99 | for i,ctx in enumerate(data['ctxs']):
100 | if ctx['hit'] == 'True':
101 | s_rank = i
102 | break
103 |
104 | t_data = qid_data_dic[qid]
105 | for i,ctx in enumerate(t_data['ctxs']):
106 | if ctx['hit'] == 'True':
107 | t_rank = i
108 | break
109 |
110 | if t_rank < s_rank and t_rank < 15:
111 | t2_15_better.add(qid)
112 |
113 | if t_rank < s_rank and t_rank < 100:
114 | t2_all_better.add(qid)
115 |
116 | print("t2_15_better len ", len(t2_15_better))
117 | print("t2_31_better len ", len(t2_all_better))
118 | print()
119 |
120 | qid_divide_dic = {}
121 | qid_divide_dic['t2_15_better'] = t2_15_better
122 | qid_divide_dic['t2_all_better'] = t2_all_better
123 |
124 | return qid_divide_dic
125 |
126 | def main(result_file_path1, result_file_path2, data_file_path, output_dir):
127 | t1_qidset = set()
128 | t2_qidset = set()
129 |
130 | t2_better_set = set()
131 |
132 | result_data_student = read_result(result_file_path1)
133 | result_data_teacher = read_result(result_file_path2)
134 |
135 | # train_data = load_data(data_file_path)
136 |
137 | # top1 / top2-top5 / top6-top20 / top30-top50 / top50-top100 / top100-topall
138 | #qid_divide_dic1 = divide_data(result_data_student)
139 | #qid_divide_dic2 = divide_data(result_data_teacher)
140 | qid_divide_dic = divide_data(result_data_student, result_data_teacher)
141 |
142 | t2_better_set = qid_divide_dic['t2_15_better']
143 |
144 | print("t2_better_set len:", len(t2_better_set))
145 |
146 | with open(data_file_path, 'r', encoding="utf-8") as f:
147 | dataset = json.load(f)
148 | print('Aggregated data size: {}'.format(len(dataset)))
149 |
150 | t2_better_data = []
151 |
152 | for data in dataset:
153 | if data['q_id'] in t2_better_set:
154 | t2_better_data.append(data)
155 |
156 | print(len(t2_better_data))
157 |
158 | print("check if data right...")
159 | if len(t2_better_data) != len(t2_better_set):
160 | print("data error")
161 | exit(0)
162 | print("data set success!")
163 |
164 | t2_better_output_dir = os.path.join(output_dir, "flash4_top15_better.json")
165 | with open(t2_better_output_dir, 'w', encoding='utf-8') as f:
166 | json.dump(t2_better_data, f, indent=2)
167 |
168 |
169 |
170 |
171 |
172 | if __name__ == "__main__":
173 | result_file_path_s = sys.argv[1]
174 | result_file_path_t = sys.argv[2]
175 | data_file_path = sys.argv[3]
176 | output_dir = sys.argv[4]
177 | main(result_file_path_s, result_file_path_t, data_file_path, output_dir)
178 | print("data division done")
--------------------------------------------------------------------------------
/PROD/ProD_KD/utils/lamb.py:
--------------------------------------------------------------------------------
1 | """Lamb optimizer."""
2 |
3 | import collections
4 | import math
5 |
6 | import torch
7 | from tensorboardX import SummaryWriter
8 | from torch.optim import Optimizer
9 |
10 |
11 | def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
12 | """Log a histogram of trust ratio scalars in across layers."""
13 | results = collections.defaultdict(list)
14 | for group in optimizer.param_groups:
15 | for p in group['params']:
16 | state = optimizer.state[p]
17 | for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
18 | if i in state:
19 | results[i].append(state[i])
20 |
21 | for k, v in results.items():
22 | event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)
23 |
24 | class Lamb(Optimizer):
25 | r"""Implements Lamb algorithm.
26 |
27 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
28 |
29 | Arguments:
30 | params (iterable): iterable of parameters to optimize or dicts defining
31 | parameter groups
32 | lr (float, optional): learning rate (default: 1e-3)
33 | betas (Tuple[float, float], optional): coefficients used for computing
34 | running averages of gradient and its square (default: (0.9, 0.999))
35 | eps (float, optional): term added to the denominator to improve
36 |             numerical stability (default: 1e-6)
37 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
38 | adam (bool, optional): always use trust ratio = 1, which turns this into
39 | Adam. Useful for comparison purposes.
40 |
41 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
42 | https://arxiv.org/abs/1904.00962
43 | """
44 |
45 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
46 | weight_decay=0, adam=False):
47 | if not 0.0 <= lr:
48 | raise ValueError("Invalid learning rate: {}".format(lr))
49 | if not 0.0 <= eps:
50 | raise ValueError("Invalid epsilon value: {}".format(eps))
51 | if not 0.0 <= betas[0] < 1.0:
52 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
53 | if not 0.0 <= betas[1] < 1.0:
54 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
55 | defaults = dict(lr=lr, betas=betas, eps=eps,
56 | weight_decay=weight_decay)
57 | self.adam = adam
58 | super(Lamb, self).__init__(params, defaults)
59 |
60 | def step(self, closure=None):
61 | """Performs a single optimization step.
62 |
63 | Arguments:
64 | closure (callable, optional): A closure that reevaluates the model
65 | and returns the loss.
66 | """
67 | loss = None
68 | if closure is not None:
69 | loss = closure()
70 |
71 | for group in self.param_groups:
72 | for p in group['params']:
73 | if p.grad is None:
74 | continue
75 | grad = p.grad.data
76 | if grad.is_sparse:
77 |                     raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
78 |
79 | state = self.state[p]
80 |
81 | # State initialization
82 | if len(state) == 0:
83 | state['step'] = 0
84 | # Exponential moving average of gradient values
85 | state['exp_avg'] = torch.zeros_like(p.data)
86 | # Exponential moving average of squared gradient values
87 | state['exp_avg_sq'] = torch.zeros_like(p.data)
88 |
89 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
90 | beta1, beta2 = group['betas']
91 |
92 | state['step'] += 1
93 |
94 | # Decay the first and second moment running average coefficient
95 | # m_t
96 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
97 | # v_t
98 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
99 |
100 | # Paper v3 does not use debiasing.
101 | # Apply bias to lr to avoid broadcast.
102 | step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1
103 |
104 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
105 |
106 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
107 | if group['weight_decay'] != 0:
108 | adam_step.add_(group['weight_decay'], p.data)
109 |
110 | adam_norm = adam_step.pow(2).sum().sqrt()
111 | if weight_norm == 0 or adam_norm == 0:
112 | trust_ratio = 1
113 | else:
114 | trust_ratio = weight_norm / adam_norm
115 | state['weight_norm'] = weight_norm
116 | state['adam_norm'] = adam_norm
117 | state['trust_ratio'] = trust_ratio
118 | if self.adam:
119 | trust_ratio = 1
120 |
121 | p.data.add_(-step_size * trust_ratio, adam_step)
122 |
123 | return loss
124 |
--------------------------------------------------------------------------------
/PROD/ProD_KD/utils/prepare_ce_data_nq.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | from tqdm import tqdm
4 | def load_data(data_path=None):
5 | assert data_path
6 | with open(data_path, 'r',encoding='utf-8') as fin:
7 | data = json.load(fin)
8 | examples = []
9 | for k, example in enumerate(data):
10 | if not 'id' in example:
11 | example['id'] = k
12 | for c in example['ctxs']:
13 | if not 'score' in c:
14 | c['score'] = 1.0 / (k + 1)
15 | examples.append(example)
16 | return examples
17 |
18 | def read_train_pos(ground_truth_path):
19 | origin_train_path = ground_truth_path
20 | with open(origin_train_path, "r", encoding="utf-8") as ifile:
21 | # file format: question, answers
22 | train_list = json.load(ifile)
23 | train_q_pos_dict = {}
24 | for example in train_list:
25 |         if 'positive_ctxs' not in example or len(example['positive_ctxs']) == 0:  # check key existence before indexing
26 | continue
27 | train_q_pos_dict[example['question']]=example['positive_ctxs'][0]
28 | return train_q_pos_dict
29 |
30 | def reform_out(examples,outfile, ground_truth_path):
31 | train_q_pos_dict = read_train_pos(ground_truth_path)
32 | transfer_list = []
33 | easy_ctxs = []
34 | for infer_result in tqdm(examples):
35 | if 'passage_id' not in infer_result.keys():
36 | q_id = infer_result["id"]
37 | else:
38 | q_id = infer_result["passage_id"]
39 | q_str = infer_result["question"]
40 | q_answer = infer_result["answers"]
41 | positive_ctxs = []
42 | negative_ctxs = []
43 | if q_str in train_q_pos_dict.keys():
44 | # ground true
45 | real_true_dic = train_q_pos_dict[q_str]
46 | # real_true_doc_id = real_true_dic['passage_id'] if 'passage_id' in real_true_dic.keys() else real_true_dic['id']
47 | if 'passage_id' not in real_true_dic.keys() and 'id' in real_true_dic.keys():
48 | real_true_dic['passage_id'] = real_true_dic['id']
49 | elif 'psg_id' in real_true_dic.keys():
50 | real_true_dic['passage_id'] = real_true_dic['psg_id']
51 | positive_ctxs.append(real_true_dic)
52 |
53 | for doc in infer_result['ctxs']:
54 | doc_text = doc['text']
55 | doc_title = doc['title']
56 | if doc['hit']=="True":
57 | positive_ctxs.append({'title':doc_title,'text':doc_text,'passage_id':doc['d_id'],'score':str(doc['score'])})
58 | else:
59 | negative_ctxs.append({'title':doc_title,'text':doc_text,'passage_id':doc['d_id'],'score':str(doc['score'])})
60 | # easy_ctxs.append({'title':doc_title,'text':doc_text,'passage_id':doc['d_id'],'score':str(doc['score'])})
61 |
62 | transfer_list.append(
63 | {
64 | "q_id":str(q_id), "question":q_str, "answers" :q_answer ,"positive_ctxs":positive_ctxs,"hard_negative_ctxs":negative_ctxs,"negative_ctxs":[]
65 | }
66 | )
67 |
68 | print("total data to ce train: ", len(transfer_list))
69 | # print("easy data to ce train: ", len(easy_list))
70 | print("hardneg num:", len(transfer_list[0]["hard_negative_ctxs"]))
71 | # print("easyneg num:", len(transfer_list[0]["easy_negative_ctxs"]))
72 |
73 | with open(outfile, 'w',encoding='utf-8') as f:
74 | json.dump(transfer_list, f,indent=2)
75 | return
76 | if __name__ == "__main__":
77 | inference_results = sys.argv[1]
78 | outfile = sys.argv[2]
79 | ground_truth_path = sys.argv[3]
80 | examples = load_data(inference_results)
81 | reform_out(examples,outfile,ground_truth_path)
82 |
83 |
84 |
--------------------------------------------------------------------------------
/PROD/ProD_KD/utils/preprae_ce_marco_train.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | import random
4 | import pickle
5 | import sys
6 |
7 | def csv_reader(fd, delimiter='\t', trainer_id=0, trainer_num=1):
8 | def gen():
9 | for i, line in tqdm(enumerate(fd)):
10 | if i % trainer_num == trainer_id:
11 | slots = line.rstrip('\n').split(delimiter)
12 | if len(slots) == 1:
13 | yield slots,
14 | else:
15 | yield slots
16 | return gen()
17 |
18 | def read_qrel(relevance_file):
19 | qrel = {}
20 | with open(relevance_file, encoding='utf8') as f:
21 | tsvreader = csv_reader(f, delimiter="\t")
22 | for [topicid, _, docid, rel] in tsvreader:
23 | assert rel == "1"
24 | if topicid in qrel:
25 | qrel[topicid].append(docid)
26 | else:
27 | qrel[topicid] = [docid]
28 | return qrel
29 |
30 | def read_qstring(query_file):
31 | q_string = {}
32 | with open(query_file, 'r', encoding='utf-8') as file:
33 | for num, line in enumerate(file):
34 | line = line.strip('\n')
35 | line = line.split('\t')
36 | q_string[line[0]] = line[1]
37 | return q_string
38 |
39 | def load_train_reference_from_stream(input_file, trainer_id=0, trainer_num=1):
40 | """Reads a tab separated value file."""
41 | with open(input_file, 'r', encoding='utf8') as f:
42 | reader = csv_reader(f, trainer_id=trainer_id, trainer_num=trainer_num)
43 | #headers = 'query_id\tpos_id\tneg_id'.split('\t')
44 |
45 | #Example = namedtuple('Example', headers)
46 | qrel = {}
47 | for [topicid, _, docid, rel] in reader:
48 | topicid = int(topicid)
49 | assert rel == "1"
50 | if topicid in qrel:
51 | qrel[topicid].append(int(docid))
52 | else:
53 | qrel[topicid] = [int(docid)]
54 | return qrel
55 |
56 | def main(relevance_file, query_file, result_file, outfile, neg_num):
57 | with open(result_file, 'rb') as f:
58 | qids_to_ranked_candidate_passages, qids_to_ranked_candidate_scores = pickle.load(f)
59 | qids_to_relevant_passageids = load_train_reference_from_stream(relevance_file)
60 |
61 | datalist = []
62 | q_string = read_qstring(query_file)
63 | for qid in qids_to_ranked_candidate_passages:
64 | examples = {}
65 | if qid in qids_to_relevant_passageids:
66 | target_pid = qids_to_relevant_passageids[qid]
67 | examples['query_id'] = str(qid)
68 | examples['pos_id'] = [str(id) for id in target_pid]
69 | examples['query_string'] = q_string[examples['query_id']]
70 | candidate_pid = qids_to_ranked_candidate_passages[qid]
71 | examples['neg_id'] = []
72 |
73 | for i, pid in enumerate(candidate_pid):
74 | if pid in target_pid:
75 | # print(target_pid, pid)
76 | continue
77 | else:
78 | if len(examples['neg_id']) < neg_num:
79 | examples['neg_id'].append(str(pid))
80 | else:
81 | break
82 | if len(examples['query_id'])!=0 and len(examples['pos_id'])!=0 and len(examples['neg_id'])!=0:
83 | datalist.append(examples)
84 |
85 |
86 | # print(num,len(qids_to_ranked_candidate_passages),num/len(qids_to_ranked_candidate_passages))
87 | print("data len:", len(datalist))
88 | print("data keys:", datalist[0].keys())
89 | print("data info:", datalist[0])
90 | print("data info:", datalist[100])
91 | print("data info:", datalist[1000])
92 |
93 | for data in datalist:
94 | for id in data['neg_id']:
95 | if int(id) in data['pos_id']:
96 | print("data error")
97 | print(data)
98 | exit(0)
99 | if str(id) in data['pos_id']:
100 | print("data error")
101 | print(data)
102 | exit(0)
103 | print("Correctness check completed")
104 |
105 |
106 | with open(outfile, 'w',encoding='utf-8') as f:
107 | json.dump(datalist, f,indent=2)
108 |
109 | if __name__ == "__main__":
110 | relevance_file = sys.argv[1]
111 | query_file = sys.argv[2]
112 | result_file = sys.argv[3]
113 | outfile = sys.argv[4]
114 | neg_num = int(sys.argv[5])
115 | main(relevance_file, query_file, result_file, outfile, neg_num)
--------------------------------------------------------------------------------
/PROD/ProD_KD/utils/preprae_ce_marcodoc_train.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | import random
4 | import pickle
5 | import sys
6 | def csv_reader(fd, delimiter='\t', trainer_id=0, trainer_num=1):
7 | def gen():
8 | for i, line in tqdm(enumerate(fd)):
9 | if i % trainer_num == trainer_id:
10 | slots = line.rstrip('\n').split(delimiter)
11 | if len(slots) == 1:
12 | yield slots,
13 | else:
14 | yield slots
15 | return gen()
16 |
17 | def load_marcodoc_reference_from_stream(path_to_reference):
18 | """Load Reference reference relevant passages
19 | Args:f (stream): stream to load.
20 | Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints).
21 | """
22 | qids_to_relevant_passageids = {}
23 | with open(path_to_reference, 'r') as f:
24 | for l in f:
25 | l = l.strip().split(' ')
26 | qid = int(l[0])
27 | if qid in qids_to_relevant_passageids:
28 | pass
29 | else:
30 | qids_to_relevant_passageids[qid] = []
31 | qids_to_relevant_passageids[qid].append(int(l[2][1:]))
32 | return qids_to_relevant_passageids
33 |
34 | def read_qstring(query_file):
35 | q_string = {}
36 | with open(query_file, 'r', encoding='utf-8') as file:
37 | for num, line in enumerate(file):
38 | line = line.strip('\n')
39 | line = line.split('\t')
40 | q_string[int(line[0])] = line[1]
41 | return q_string
42 |
43 | def main(relevance_file, query_file, result_file, outfile, neg_num):
44 | with open(result_file, 'rb') as f:
45 | qids_to_ranked_candidate_passages, qids_to_ranked_candidate_scores = pickle.load(f)
46 | qids_to_relevant_passageids = load_marcodoc_reference_from_stream(relevance_file)
47 |
48 | datalist = []
49 | q_string = read_qstring(query_file)
50 | for qid in qids_to_ranked_candidate_passages:
51 | examples = {}
52 | if qid in qids_to_relevant_passageids:
53 | target_pid = qids_to_relevant_passageids[qid]
54 | examples['query_id'] = int(qid)
55 | examples['pos_id'] = [int(id) for id in target_pid]
56 | examples['query_string'] = q_string[examples['query_id']]
57 | # neg id,pos id
58 | candidate_pid = qids_to_ranked_candidate_passages[qid]
59 | examples['neg_id'] = []
60 |
61 | for i, pid in enumerate(candidate_pid):
62 | if pid in examples['pos_id']:
63 | # print(target_pid, pid)
64 | continue
65 | else:
66 | if len(examples['neg_id']) < neg_num:
67 | examples['neg_id'].append(int(pid))
68 | else:
69 | break
70 | if len(examples['pos_id'])!=0 and len(examples['neg_id'])!=0:
71 | datalist.append(examples)
72 |
73 |
74 | # print(num,len(qids_to_ranked_candidate_passages),num/len(qids_to_ranked_candidate_passages))
75 | print("data len:", len(datalist))
76 | print("data keys:", datalist[0].keys())
77 | print("data info:", datalist[0])
78 | print("data info:", datalist[100])
79 | print("data info:", datalist[1000])
80 |
81 | for data in datalist:
82 | for id in data['neg_id']:
83 | if int(id) in data['pos_id']:
84 | print("data error")
85 | print(data)
86 | exit(0)
87 | if str(id) in data['pos_id']:
88 | print("data error")
89 | print(data)
90 | exit(0)
91 | print("Correctness check completed")
92 |
93 |
94 | with open(outfile, 'w',encoding='utf-8') as f:
95 | json.dump(datalist, f,indent=2)
96 |
97 | if __name__ == "__main__":
98 | relevance_file = sys.argv[1]
99 | query_file = sys.argv[2]
100 | result_file = sys.argv[3]
101 | outfile = sys.argv[4]
102 | neg_num = int(sys.argv[5])
103 | main(relevance_file, query_file, result_file, outfile, neg_num)
--------------------------------------------------------------------------------
/PROD/ProD_base/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/PROD/ProD_base/model/__init__.py
--------------------------------------------------------------------------------
/PROD/ProD_base/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/PROD/ProD_base/utils/__init__.py
--------------------------------------------------------------------------------
/PROD/ProD_base/utils/lamb.py:
--------------------------------------------------------------------------------
1 | """Lamb optimizer."""
2 |
3 | import collections
4 | import math
5 |
6 | import torch
7 | from tensorboardX import SummaryWriter
8 | from torch.optim import Optimizer
9 |
10 |
11 | def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
12 | """Log a histogram of trust ratio scalars in across layers."""
13 | results = collections.defaultdict(list)
14 | for group in optimizer.param_groups:
15 | for p in group['params']:
16 | state = optimizer.state[p]
17 | for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
18 | if i in state:
19 | results[i].append(state[i])
20 |
21 | for k, v in results.items():
22 | event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)
23 |
24 | class Lamb(Optimizer):
25 | r"""Implements Lamb algorithm.
26 |
27 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
28 |
29 | Arguments:
30 | params (iterable): iterable of parameters to optimize or dicts defining
31 | parameter groups
32 | lr (float, optional): learning rate (default: 1e-3)
33 | betas (Tuple[float, float], optional): coefficients used for computing
34 | running averages of gradient and its square (default: (0.9, 0.999))
35 | eps (float, optional): term added to the denominator to improve
36 |             numerical stability (default: 1e-6)
37 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
38 | adam (bool, optional): always use trust ratio = 1, which turns this into
39 | Adam. Useful for comparison purposes.
40 |
41 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
42 | https://arxiv.org/abs/1904.00962
43 | """
44 |
45 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
46 | weight_decay=0, adam=False):
47 | if not 0.0 <= lr:
48 | raise ValueError("Invalid learning rate: {}".format(lr))
49 | if not 0.0 <= eps:
50 | raise ValueError("Invalid epsilon value: {}".format(eps))
51 | if not 0.0 <= betas[0] < 1.0:
52 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
53 | if not 0.0 <= betas[1] < 1.0:
54 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
55 | defaults = dict(lr=lr, betas=betas, eps=eps,
56 | weight_decay=weight_decay)
57 | self.adam = adam
58 | super(Lamb, self).__init__(params, defaults)
59 |
60 | def step(self, closure=None):
61 | """Performs a single optimization step.
62 |
63 | Arguments:
64 | closure (callable, optional): A closure that reevaluates the model
65 | and returns the loss.
66 | """
67 | loss = None
68 | if closure is not None:
69 | loss = closure()
70 |
71 | for group in self.param_groups:
72 | for p in group['params']:
73 | if p.grad is None:
74 | continue
75 | grad = p.grad.data
76 | if grad.is_sparse:
77 |                     raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
78 |
79 | state = self.state[p]
80 |
81 | # State initialization
82 | if len(state) == 0:
83 | state['step'] = 0
84 | # Exponential moving average of gradient values
85 | state['exp_avg'] = torch.zeros_like(p.data)
86 | # Exponential moving average of squared gradient values
87 | state['exp_avg_sq'] = torch.zeros_like(p.data)
88 |
89 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
90 | beta1, beta2 = group['betas']
91 |
92 | state['step'] += 1
93 |
94 | # Decay the first and second moment running average coefficient
95 | # m_t
96 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
97 | # v_t
98 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
99 |
100 | # Paper v3 does not use debiasing.
101 | # Apply bias to lr to avoid broadcast.
102 | step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1
103 |
104 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
105 |
106 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
107 | if group['weight_decay'] != 0:
108 | adam_step.add_(group['weight_decay'], p.data)
109 |
110 | adam_norm = adam_step.pow(2).sum().sqrt()
111 | if weight_norm == 0 or adam_norm == 0:
112 | trust_ratio = 1
113 | else:
114 | trust_ratio = weight_norm / adam_norm
115 | state['weight_norm'] = weight_norm
116 | state['adam_norm'] = adam_norm
117 | state['trust_ratio'] = trust_ratio
118 | if self.adam:
119 | trust_ratio = 1
120 |
121 | p.data.add_(-step_size * trust_ratio, adam_step)
122 |
123 | return loss
124 |
--------------------------------------------------------------------------------
/PROD/image/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/PROD/image/framework.jpg
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # TODO: The maintainer of this repo has not yet edited this file
2 |
3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4 |
5 | - **No CSS support:** Fill out this template with information about how to file issues and get help.
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
8 |
9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10 |
11 | # Support
12 |
13 | ## How to file issues and get help
14 |
15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
17 | feature request as a new Issue.
18 |
19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22 |
23 | ## Microsoft Support Policy
24 |
25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
26 |
--------------------------------------------------------------------------------
/SimANS/README.md:
--------------------------------------------------------------------------------
1 | # SimANS: Simple Ambiguous Negatives Sampling for Dense Text Retrieval
2 |
3 | This repository contains the code for our EMNLP2022 paper [***SimANS: Simple Ambiguous Negatives Sampling for Dense Text Retrieval***](https://arxiv.org/abs/2210.11773).
4 |
5 |
6 | ## 🚀 Overview
7 |
8 | We propose ***SimANS***, a simple, general and flexible ambiguous negatives sampling method for dense text retrieval. It can be easily applied to various dense retrieval methods.
9 |
10 | The key code of our SimANS is the implementation of the sampling distribution as:
11 |
12 | $$p_{i} \propto \exp{(-a\cdot(s(q,d_{i})-s(q,\tilde{d}^{+})-b)^{2})}, \forall~d_{i} \in \hat{\mathcal{D}}^{-}$$
13 |
14 | ```python
15 | def SimANS(pos_pair, neg_pairs_list):
16 | pos_id, pos_score = int(pos_pair[0]), float(pos_pair[1])
17 | neg_candidates, neg_scores = [], []
18 | for pair in neg_pairs_list:
19 | neg_id, neg_score = pair
20 | neg_score = math.exp(-(neg_score - pos_score) ** 2 * self.tau)
21 | neg_candidates.append(neg_id)
22 | neg_scores.append(neg_score)
23 | return random.choices(neg_candidates, weights=neg_scores, k=num_hard_negatives)
24 | ```
25 |
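The excerpt above comes from the training loop, so `self.tau` and `num_hard_negatives` refer to attributes of the surrounding trainer. As a minimal self-contained sketch of the same sampling rule (the function name `simans_sample`, the explicit `tau`/`num_hard_negatives` parameters, and the toy scores are illustrative assumptions, not the repository's API), with `tau` playing the role of $a$ and the shift $b$ taken as 0 as in the excerpt:

```python
import math
import random

def simans_sample(pos_pair, neg_pairs_list, tau=1.0, num_hard_negatives=2):
    """Sample negatives with probability proportional to exp(-tau * (s_neg - s_pos)^2)."""
    pos_score = float(pos_pair[1])  # relevance score of the positive passage
    neg_candidates, neg_weights = [], []
    for neg_id, neg_score in neg_pairs_list:
        # Negatives scored close to the positive (the ambiguous ones) receive the largest weights.
        neg_weights.append(math.exp(-tau * (float(neg_score) - pos_score) ** 2))
        neg_candidates.append(neg_id)
    return random.choices(neg_candidates, weights=neg_weights, k=num_hard_negatives)

# Toy usage: ids 201 and 203 score closest to the positive and are sampled most often.
pos = ("101", 7.2)
negs = [("201", 7.0), ("202", 3.1), ("203", 7.4), ("204", -1.0)]
print(simans_sample(pos, negs, tau=0.5, num_hard_negatives=2))
```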
26 | Here we show the main results on [MS MARCO](https://microsoft.github.io/msmarco/), [Natural Questions](https://ai.google.com/research/NaturalQuestions/) and [TriviaQA](http://nlp.cs.washington.edu/triviaqa/). Our method outperforms the previous state-of-the-art methods.
27 | 
28 |
29 | This method has been applied in [Microsoft Bing](https://www.bing.com/), and we also report results on an industry dataset.
30 |
31 |
32 |
33 | Please find more details in the paper.
34 |
35 |
36 | ## Released Resources
37 |
38 | We release the preprocessed data and trained checkpoints in [Azure Blob](https://msranlciropen.blob.core.windows.net/simxns/SimANS/).
39 | Here we also provide the file list under this URL:
40 |
42 | INFO: best_simans_ckpt.zip; Content Length: 7.74 GiB
43 | INFO: best_simans_ckpt/MS-Doc/checkpoint-25000; Content Length: 1.39 GiB
44 | INFO: best_simans_ckpt/MS-Doc/log.txt; Content Length: 78.32 KiB
45 | INFO: best_simans_ckpt/MS-Pas/checkpoint-20000; Content Length: 2.45 GiB
46 | INFO: best_simans_ckpt/MS-Pas/log.txt; Content Length: 82.74 KiB
47 | INFO: best_simans_ckpt/NQ/checkpoint-30000; Content Length: 2.45 GiB
48 | INFO: best_simans_ckpt/NQ/log.txt; Content Length: 298.44 KiB
49 | INFO: best_simans_ckpt/TQ/checkpoint-10000; Content Length: 2.45 GiB
50 | INFO: best_simans_ckpt/TQ/log.txt; Content Length: 99.44 KiB
51 | INFO: ckpt.zip; Content Length: 19.63 GiB
52 | INFO: ckpt/MS-Doc/adore-star/config.json; Content Length: 1.37 KiB
53 | INFO: ckpt/MS-Doc/adore-star/pytorch_model.bin; Content Length: 480.09 MiB
54 | INFO: ckpt/MS-Doc/checkpoint-20000; Content Length: 1.39 GiB
55 | INFO: ckpt/MS-Doc/checkpoint-reranker20000; Content Length: 1.39 GiB
56 | INFO: ckpt/MS-Pas/checkpoint-20000; Content Length: 2.45 GiB
57 | INFO: ckpt/MS-Pas/checkpoint-reranker20000; Content Length: 3.75 GiB
58 | INFO: ckpt/NQ/checkpoint-reranker26000; Content Length: 3.75 GiB
59 | INFO: ckpt/NQ/nq_fintinue.pkl; Content Length: 2.45 GiB
60 | INFO: ckpt/TQ/checkpoint-reranker34000; Content Length: 3.75 GiB
61 | INFO: ckpt/TQ/triviaqa_fintinue.pkl; Content Length: 2.45 GiB
62 | INFO: data.zip; Content Length: 18.43 GiB
63 | INFO: data/MS-Doc/dev_ce_0.tsv; Content Length: 15.97 MiB
64 | INFO: data/MS-Doc/msmarco-docdev-qrels.tsv; Content Length: 105.74 KiB
65 | INFO: data/MS-Doc/msmarco-docdev-queries.tsv; Content Length: 215.14 KiB
66 | INFO: data/MS-Doc/msmarco-docs.tsv; Content Length: 21.32 GiB
67 | INFO: data/MS-Doc/msmarco-doctrain-qrels.tsv; Content Length: 7.19 MiB
68 | INFO: data/MS-Doc/msmarco-doctrain-queries.tsv; Content Length: 14.76 MiB
69 | INFO: data/MS-Doc/train_ce_0.tsv; Content Length: 1.13 GiB
70 | INFO: data/MS-Pas/dev.query.txt; Content Length: 283.39 KiB
71 | INFO: data/MS-Pas/para.title.txt; Content Length: 280.76 MiB
72 | INFO: data/MS-Pas/para.txt; Content Length: 2.85 GiB
73 | INFO: data/MS-Pas/qrels.dev.tsv; Content Length: 110.89 KiB
74 | INFO: data/MS-Pas/qrels.train.addition.tsv; Content Length: 5.19 MiB
75 | INFO: data/MS-Pas/qrels.train.tsv; Content Length: 7.56 MiB
76 | INFO: data/MS-Pas/train.query.txt; Content Length: 19.79 MiB
77 | INFO: data/MS-Pas/train_ce_0.tsv; Content Length: 1.68 GiB
78 | INFO: data/NQ/dev_ce_0.json; Content Length: 632.98 MiB
79 | INFO: data/NQ/nq-dev.qa.csv; Content Length: 605.48 KiB
80 | INFO: data/NQ/nq-test.qa.csv; Content Length: 289.99 KiB
81 | INFO: data/NQ/nq-train.qa.csv; Content Length: 5.36 MiB
82 | INFO: data/NQ/train_ce_0.json; Content Length: 5.59 GiB
83 | INFO: data/TQ/dev_ce_0.json; Content Length: 646.60 MiB
84 | INFO: data/TQ/train_ce_0.json; Content Length: 5.62 GiB
85 | INFO: data/TQ/trivia-dev.qa.csv; Content Length: 3.03 MiB
86 | INFO: data/TQ/trivia-test.qa.csv; Content Length: 3.91 MiB
87 | INFO: data/TQ/trivia-train.qa.csv; Content Length: 26.67 MiB
88 | INFO: data/psgs_w100.tsv; Content Length: 12.76 GiB
89 |
90 |
91 | To download the files, please refer to [HOW_TO_DOWNLOAD](https://github.com/microsoft/SimXNS/tree/main/HOW_TO_DOWNLOAD.md).
92 |
93 |
94 | ## 🙋 How to Use
95 |
96 | **⚙️ Environment Setting**
97 |
98 | To faithfully reproduce our results, please install the PyTorch `1.7.1` build that matches your platform/CUDA version according to [Released Packages by pytorch](https://anaconda.org/pytorch/pytorch), and make sure faiss is installed successfully for evaluation.
99 |
100 | We list our command to prepare the experimental environment as follows:
101 | ```bash
102 | conda install pytorch==1.7.1 cudatoolkit=11.0 -c pytorch
103 | conda install faiss-gpu cudatoolkit=11.0 -c pytorch
104 | conda install transformers
105 | pip install tqdm
106 | pip install tensorboardX
107 | pip install lmdb
108 | pip install datasets
109 | pip install wandb
110 | pip install sklearn
111 | pip install boto3
112 | ```
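
After installation, a quick sanity check can save time before launching full training. This is a minimal sketch under the assumption that the packages above were installed into the active environment; it only verifies the PyTorch/CUDA build and that faiss can build and query a small inner-product index of the kind used for dense-retrieval evaluation:

```python
import numpy as np
import torch
import faiss

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())

# Build a tiny inner-product index and query it, mimicking dense-retrieval evaluation.
dim = 8
passages = np.random.rand(100, dim).astype("float32")
index = faiss.IndexFlatIP(dim)
index.add(passages)
scores, ids = index.search(passages[:3], 5)   # top-5 passages for 3 "queries"
print("faiss OK, result shape:", ids.shape)
```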
113 |
114 | **💾 Data and Initial Checkpoint**
115 |
116 | We list the necessary data for training on MS-Pas/MS-Doc/NQ/TQ [here](https://msranlciropen.blob.core.windows.net/simxns/SimANS/data.zip). You can download the compressed file directly or with [Microsoft's AzCopy CLI tool](https://learn.microsoft.com/en-us/azure/storage/common/storage-ref-azcopy) and put the content in `./data`.
117 | If you are working on MS-Pas, you will need this [additional file](https://msranlciropen.blob.core.windows.net/simxns/SimANS/data/MS-Pas/qrels.train.addition.tsv) for training.
118 |
119 | Our approach requires the checkpoints from AR2 for initialization. We release them [here](https://msranlciropen.blob.core.windows.net/simxns/SimANS/ckpt.zip). You can download the all-in-one compressed file and put the content in `./ckpt`.
120 |
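If AzCopy is unavailable, the archives can also be fetched with plain Python. The sketch below is only a rough illustration: it assumes the blob URLs above are directly readable (access may instead require the steps in [HOW_TO_DOWNLOAD](https://github.com/microsoft/SimXNS/tree/main/HOW_TO_DOWNLOAD.md)), that the multi-GiB downloads fit on disk, and that each archive unpacks into the expected `./data` and `./ckpt` layout:

```python
import urllib.request
import zipfile

# Archive URLs given above; both files are tens of GiB, so prefer AzCopy when possible.
archives = {
    "data.zip": "https://msranlciropen.blob.core.windows.net/simxns/SimANS/data.zip",
    "ckpt.zip": "https://msranlciropen.blob.core.windows.net/simxns/SimANS/ckpt.zip",
}

for name, url in archives.items():
    urllib.request.urlretrieve(url, name)   # may require the access steps in HOW_TO_DOWNLOAD
    with zipfile.ZipFile(name) as archive:
        archive.extractall(".")             # assumes data/ or ckpt/ sits at the archive root
```
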
121 |
122 | **📋 Training Scripts**
123 |
124 | We provide training scripts that apply SimANS to the SOTA AR2 model on the MS-MARCO Passage/Document Retrieval, NQ and TQ datasets, with the best training hyperparameters already set. You can run them to automatically finish training and evaluation.
125 | ```bash
126 | bash train_MS_Pas_AR2.sh
127 | bash train_MS_Doc_AR2.sh
128 | bash train_NQ_AR2.sh
129 | bash train_TQ_AR2.sh
130 | ```
131 |
132 | For the results in the paper, we use 8 A100 GPUs with CUDA 11. Using different types of devices or different versions of CUDA or other software may lead to different performance.
133 |
134 | **⚽ Best SimANS Checkpoint**
135 |
136 | To better reproduce our experimental results, we also release all the checkpoints of our approach [here](https://msranlciropen.blob.core.windows.net/simxns/SimANS/best_simans_ckpt.zip). You can download the compressed file and reuse the content for evaluation.
137 |
138 |
139 | ## 📜 Citation
140 |
141 | Please cite our paper if you use [SimANS](https://arxiv.org/abs/2210.11773) in your work:
142 | ```bibtex
143 | @inproceedings{zhou2022simans,
144 |   title={SimANS: Simple Ambiguous Negatives Sampling for Dense Text Retrieval},
145 |   author={Kun Zhou and Yeyun Gong and Xiao Liu and Wayne Xin Zhao and Yelong Shen and Anlei Dong and Jingwen Lu and Rangan Majumder and Ji-Rong Wen and Nan Duan and Weizhu Chen},
146 | booktitle = {{EMNLP}},
147 | year={2022}
148 | }
149 | ```
150 |
--------------------------------------------------------------------------------
/SimANS/best_simans_ckpt/MS-Doc/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/best_simans_ckpt/MS-Doc/.gitkeep
--------------------------------------------------------------------------------
/SimANS/best_simans_ckpt/MS-Pas/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/best_simans_ckpt/MS-Pas/.gitkeep
--------------------------------------------------------------------------------
/SimANS/best_simans_ckpt/NQ/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/best_simans_ckpt/NQ/.gitkeep
--------------------------------------------------------------------------------
/SimANS/best_simans_ckpt/TQ/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/best_simans_ckpt/TQ/.gitkeep
--------------------------------------------------------------------------------
/SimANS/ckpt/MS-Doc/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/ckpt/MS-Doc/.gitkeep
--------------------------------------------------------------------------------
/SimANS/ckpt/MS-Pas/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/ckpt/MS-Pas/.gitkeep
--------------------------------------------------------------------------------
/SimANS/ckpt/NQ/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/ckpt/NQ/.gitkeep
--------------------------------------------------------------------------------
/SimANS/ckpt/TQ/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/ckpt/TQ/.gitkeep
--------------------------------------------------------------------------------
/SimANS/data/MS-Doc/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/data/MS-Doc/.gitkeep
--------------------------------------------------------------------------------
/SimANS/data/MS-Pas/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/data/MS-Pas/.gitkeep
--------------------------------------------------------------------------------
/SimANS/data/NQ/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/data/NQ/.gitkeep
--------------------------------------------------------------------------------
/SimANS/data/TQ/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/data/TQ/.gitkeep
--------------------------------------------------------------------------------
/SimANS/figs/simans_industry_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/figs/simans_industry_result.jpg
--------------------------------------------------------------------------------
/SimANS/figs/simans_main.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/figs/simans_main.jpg
--------------------------------------------------------------------------------
/SimANS/figs/simans_main_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimXNS/0163380c69533be0112142ee6ba25f43bc9246bc/SimANS/figs/simans_main_result.jpg
--------------------------------------------------------------------------------
/SimANS/train_MS_Doc_AR2.sh:
--------------------------------------------------------------------------------
1 | EXP_NAME=co_training_MS_MARCO_Doc_SimANS
2 | Iteration_step=5000
3 | Iteration_reranker_step=1000
4 | MAX_STEPS=40000
5 |
6 | # for global_step in `seq 0 2000 $MAX_STEPS`; do echo $global_step; done;
7 | for global_step in `seq 0 $Iteration_step $MAX_STEPS`;
8 | do
9 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 Doc_training/co_training_doc_train.py \
10 | --model_type=ckpt/MS-Doc/adore-star \
11 | --model_name_or_path=ckpt/MS-Doc/checkpoint-20000 \
12 | --max_seq_length=512 --per_gpu_train_batch_size=32 --gradient_accumulation_steps=1 \
13 | --number_neg=15 --learning_rate=5e-6 \
14 | --teacher_model_type=roberta-base \
15 | --teacher_model_path=ckpt/MS-Doc/checkpoint-reranker20000 \
16 | --teacher_learning_rate=1e-6 \
17 | --output_dir=ckpt/$EXP_NAME \
18 | --log_dir=tensorboard/logs/$EXP_NAME \
19 | --origin_data_dir=data/MS-Doc/train_ce_0.tsv \
20 | --train_qa_path=data/MS-Doc/msmarco-doctrain-queries.tsv \
21 | --passage_path=data/MS-Doc \
22 | --logging_steps=100 --save_steps=5000 --max_steps=$MAX_STEPS \
23 | --gradient_checkpointing --distill_loss \
24 | --iteration_step=$Iteration_step \
25 | --iteration_reranker_step=$Iteration_reranker_step \
26 | --temperature_distill=1 --ann_dir=ckpt/$EXP_NAME/temp --adv_lambda 1 --global_step=$global_step
27 |
28 | g_global_step=`expr $global_step + $Iteration_step`
29 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 Doc_training/co_training_doc_generate.py \
30 | --model_type=ckpt/MS-Doc/adore-star \
31 | --max_seq_length=512 \
32 | --output_dir=ckpt/$EXP_NAME \
33 | --log_dir=tensorboard/logs/$EXP_NAME \
34 | --train_qa_path=data/MS-Doc/msmarco-doctrain-queries.tsv \
35 | --dev_qa_path=data/MS-Doc/msmarco-docdev-queries.tsv \
36 | --passage_path=data/MS-Doc \
37 | --max_steps=$MAX_STEPS \
38 | --gradient_checkpointing \
39 | --ann_dir=ckpt/$EXP_NAME/temp --global_step=$g_global_step
40 | done
41 |
--------------------------------------------------------------------------------
/SimANS/train_MS_Pas_AR2.sh:
--------------------------------------------------------------------------------
1 | EXP_NAME=co_training_MS_MARCO_Pas_SimANS
2 | Iteration_step=5000
3 | Iteration_reranker_step=500
4 | MAX_STEPS=60000
5 | # for global_step in `seq 0 2000 $MAX_STEPS`; do echo $global_step; done;
6 | for global_step in `seq 0 $Iteration_step $MAX_STEPS`;
7 | do
8 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 co_training/co_training_marco_train.py \
9 | --model_type=Luyu/co-condenser-marco \
10 | --model_name_or_path=ckpt/MS-Pas/checkpoint-20000 \
11 | --max_seq_length=128 --per_gpu_train_batch_size=16 --gradient_accumulation_steps=2 \
12 | --number_neg=15 --learning_rate=5e-6 \
13 | --teacher_model_type=nghuyong/ernie-2.0-large-en \
14 | --teacher_model_path=ckpt/MS-Pas/checkpoint-reranker20000 \
15 | --teacher_learning_rate=5e-7 \
16 | --output_dir=ckpt/$EXP_NAME \
17 | --log_dir=tensorboard/logs/$EXP_NAME \
18 | --origin_data_dir=data/MS-Pas/train_ce_0.tsv \
19 | --train_qa_path=data/MS-Pas/train.query.txt \
20 | --dev_qa_path=data/MS-Pas/dev.query.txt \
21 | --passage_path=data/MS-Pas \
22 | --logging_steps=10 --save_steps=5000 --max_steps=$MAX_STEPS \
23 | --gradient_checkpointing --distill_loss \
24 | --iteration_step=$Iteration_step \
25 | --iteration_reranker_step=$Iteration_reranker_step \
26 | --temperature_distill=1 --ann_dir=ckpt/$EXP_NAME/temp --adv_lambda 1 --global_step=$global_step
27 |
28 | g_global_step=`expr $global_step + $Iteration_step`
29 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 co_training/co_training_marco_generate.py \
30 | --model_type=Luyu/co-condenser-marco \
31 | --max_seq_length=128 \
32 | --output_dir=ckpt/$EXP_NAME \
33 | --log_dir=tensorboard/logs/$EXP_NAME \
34 | --train_qa_path=data/MS-Pas/train.query.txt \
35 | --dev_qa_path=data/MS-Pas/dev.query.txt \
36 | --passage_path=data/MS-Pas \
37 | --max_steps=$MAX_STEPS \
38 | --gradient_checkpointing --adv_step=0 \
39 | --iteration_step=$Iteration_step \
40 | --iteration_reranker_step=$Iteration_reranker_step \
41 | --ann_dir=ckpt/$EXP_NAME/temp --global_step=$g_global_step
42 | done
43 |
--------------------------------------------------------------------------------
/SimANS/train_NQ_AR2.sh:
--------------------------------------------------------------------------------
1 | EXP_NAME=co_training_nq_SimANS_test
2 | TB_DIR=tensorboard_log/$EXP_NAME # tensorboard log path
3 | OUT_DIR=output/$EXP_NAME
4 |
5 | DE_CKPT_PATH=ckpt/NQ/nq_fintinue.pkl
6 | CE_CKPT_PATH=ckpt/NQ/checkpoint-reranker26000
7 | Origin_Data_Dir=data/NQ/train_ce_0.json
8 | Origin_Data_Dir_Dev=data/NQ/dev_ce_0.json
9 |
10 | Iteration_step=2000
11 | Iteration_reranker_step=500
12 | MAX_STEPS=30000
13 |
14 | # for global_step in `seq 0 2000 $MAX_STEPS`; do echo $global_step; done;
15 | for global_step in `seq 0 $Iteration_step $MAX_STEPS`;
16 | do
17 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 wiki/co_training_wiki_train.py \
18 | --model_type=nghuyong/ernie-2.0-base-en \
19 | --model_name_or_path=$DE_CKPT_PATH \
20 | --max_seq_length=128 --per_gpu_train_batch_size=8 --gradient_accumulation_steps=1 \
21 | --number_neg=15 --learning_rate=1e-5 \
22 | --reranker_model_type=nghuyong/ernie-2.0-large-en \
23 | --reranker_model_path=$CE_CKPT_PATH \
24 | --reranker_learning_rate=1e-6 \
25 | --output_dir=$OUT_DIR \
26 | --log_dir=$TB_DIR \
27 | --origin_data_dir=$Origin_Data_Dir \
28 | --warmup_steps=2000 --logging_steps=100 --save_steps=2000 --max_steps=$MAX_STEPS \
29 | --gradient_checkpointing --normal_loss \
30 | --iteration_step=$Iteration_step \
31 | --iteration_reranker_step=$Iteration_reranker_step \
32 | --temperature_normal=1 --ann_dir=$OUT_DIR/temp --adv_lambda 0 --global_step=$global_step --b 1.0
33 |
34 | g_global_step=`expr $global_step + $Iteration_step`
35 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 wiki/co_training_wiki_generate.py \
36 | --model_type=nghuyong/ernie-2.0-base-en \
37 | --model_name_or_path=$DE_CKPT_PATH \
38 | --max_seq_length=128 --per_gpu_train_batch_size=8 \
39 | --output_dir=output/$EXP_NAME \
40 | --log_dir=tensorboard/logs/$EXP_NAME \
41 | --origin_data_dir=$Origin_Data_Dir \
42 | --origin_data_dir_dev=$Origin_Data_Dir_Dev \
43 | --train_qa_path=data/NQ/nq-train.qa.csv \
44 | --test_qa_path=data/NQ/nq-test.qa.csv \
45 | --dev_qa_path=data/NQ/nq-dev.qa.csv \
46 | --passage_path=data/psgs_w100.tsv \
47 | --max_steps=$MAX_STEPS \
48 | --gradient_checkpointing \
49 | --ann_dir=output/$EXP_NAME/temp --global_step=$g_global_step
50 | done
51 |
--------------------------------------------------------------------------------
/SimANS/train_TQ_AR2.sh:
--------------------------------------------------------------------------------
1 | EXP_NAME=co_training_tq_SimANS_test
2 | TB_DIR=tensorboard_log/$EXP_NAME # tensorboard log path
3 | OUT_DIR=output/$EXP_NAME
4 |
5 | DE_CKPT_PATH=ckpt/TQ/triviaqa_fintinue.pkl
6 | CE_CKPT_PATH=ckpt/TQ/checkpoint-reranker34000
7 | Origin_Data_Dir=data/TQ/train_ce_0.json
8 | Origin_Data_Dir_Dev=data/TQ/dev_ce_0.json
9 |
10 | Iteration_step=2000
11 | Iteration_reranker_step=500
12 | MAX_STEPS=10000
13 |
14 | # for global_step in `seq 0 2000 $MAX_STEPS`; do echo $global_step; done;
15 | for global_step in `seq 0 $Iteration_step $MAX_STEPS`;
16 | do
17 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 wiki/co_training_wiki_train.py \
18 | --model_type=nghuyong/ernie-2.0-base-en \
19 | --model_name_or_path=$DE_CKPT_PATH \
20 | --max_seq_length=256 --per_gpu_train_batch_size=8 --gradient_accumulation_steps=1 \
21 | --number_neg=15 --learning_rate=5e-6 \
22 | --reranker_model_type=nghuyong/ernie-2.0-large-en \
23 | --reranker_model_path=$CE_CKPT_PATH \
24 | --reranker_learning_rate=1e-6 \
25 | --output_dir=$OUT_DIR \
26 | --log_dir=$TB_DIR \
27 | --origin_data_dir=$Origin_Data_Dir \
28 | --warmup_steps=1000 --logging_steps=100 --save_steps=2000 --max_steps=$MAX_STEPS \
29 | --gradient_checkpointing --normal_loss \
30 | --iteration_step=$Iteration_step \
31 | --iteration_reranker_step=$Iteration_reranker_step \
32 | --temperature_normal=1 --ann_dir=$OUT_DIR/temp --adv_lambda 0.0 --global_step=$global_step --a 0.5 --b 0
33 |
34 | g_global_step=`expr $global_step + $Iteration_step`
35 | python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9539 wiki/co_training_wiki_generate.py \
36 | --model_type=nghuyong/ernie-2.0-base-en \
37 | --model_name_or_path=$DE_CKPT_PATH \
38 | --max_seq_length=256 --per_gpu_train_batch_size=8 \
39 | --output_dir=output/$EXP_NAME \
40 | --log_dir=tensorboard/logs/$EXP_NAME \
41 | --origin_data_dir=$Origin_Data_Dir \
42 | --origin_data_dir_dev=$Origin_Data_Dir_Dev \
43 | --train_qa_path=data/TQ/trivia-train.qa.csv \
44 | --test_qa_path=data/TQ/trivia-test.qa.csv \
45 | --dev_qa_path=data/TQ/trivia-dev.qa.csv \
46 | --passage_path=data/psgs_w100.tsv \
47 | --max_steps=$MAX_STEPS \
48 | --gradient_checkpointing \
49 | --ann_dir=output/$EXP_NAME/temp --global_step=$g_global_step
50 | done
51 |
52 |
--------------------------------------------------------------------------------