├── imgs
├── ARES_black.jpg
├── ARES_white.jpg
├── ARES_simple_ag.pdf
├── ARES_simple_ag.png
├── few-shot-metric.jpg
├── few-shot-metric.pdf
└── few-shot-metric.png
├── model
├── __init__.py
└── modeling.py
├── requirements.txt
├── preprocess
├── anserini_scripts
│ ├── do_bm25_search.sh
│ └── build_index.sh
├── convert_to_pred.py
├── README.md
├── convert_tokenize.py
└── Eval4.0.pl
├── example
└── rerank.py
├── finetune
├── modelsize_estimate.py
├── config.py
├── ms_marco_eval.py
├── dataloader.py
└── train.py
├── .gitignore
├── visualization
├── visualization.py
├── config.py
├── dataloader.py
├── visual.py
└── output_ARES_simple.html
├── pretrain
├── config.py
├── train.py
└── dataloader.py
├── README.md
└── LICENSE
/imgs/ARES_black.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanyuan14/ARES/HEAD/imgs/ARES_black.jpg
--------------------------------------------------------------------------------
/imgs/ARES_white.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanyuan14/ARES/HEAD/imgs/ARES_white.jpg
--------------------------------------------------------------------------------
/imgs/ARES_simple_ag.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanyuan14/ARES/HEAD/imgs/ARES_simple_ag.pdf
--------------------------------------------------------------------------------
/imgs/ARES_simple_ag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanyuan14/ARES/HEAD/imgs/ARES_simple_ag.png
--------------------------------------------------------------------------------
/imgs/few-shot-metric.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanyuan14/ARES/HEAD/imgs/few-shot-metric.jpg
--------------------------------------------------------------------------------
/imgs/few-shot-metric.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanyuan14/ARES/HEAD/imgs/few-shot-metric.pdf
--------------------------------------------------------------------------------
/imgs/few-shot-metric.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanyuan14/ARES/HEAD/imgs/few-shot-metric.png
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | '''
2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search
3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma.
4 | '''
5 | # encoding: utf-8
6 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # PyTorch
2 | torch==1.9.0
3 | # Huggingface transformers
4 | transformers==4.9.2
5 | # progress bars in model download and training scripts
6 | tqdm
7 | # Accessing files from S3 directly.
8 | boto3
9 | # nltk
10 | nltk
11 | # numpy
12 | numpy
--------------------------------------------------------------------------------
/preprocess/anserini_scripts/do_bm25_search.sh:
--------------------------------------------------------------------------------
1 | python -m pyserini.search --index path_to_index \
2 | --topics path_to_queries \
3 | --output path_to_trec \
4 | --bm25 \
5 | --hits 200
--------------------------------------------------------------------------------
/preprocess/anserini_scripts/build_index.sh:
--------------------------------------------------------------------------------
1 | python -m pyserini.index -collection JsonCollection \
2 | -generator DefaultLuceneDocumentGenerator \
3 | -threads 8 \
4 | -input path_to_collection \
5 | -index path_to_index \
6 | -storePositions -storeDocvectors
--------------------------------------------------------------------------------
/preprocess/convert_to_pred.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from collections import defaultdict
3 | import argparse
4 |
def trec_to_pred(args):
    """Convert a TREC run file into a one-score-per-line prediction file.

    Reads ``args.input_trec`` (six whitespace-separated columns per line:
    qid, unused, docid, rank, score, unused) and, for every line of
    ``args.qrels``, writes the matching score to ``args.output``. Pairs that
    are absent from the run get ``0.0``.

    Each qrels line is expected to carry ``qid:<id>`` as its second token and
    the document id as its seventh token from the end.

    Args:
        args: object with ``input_trec``, ``qrels`` and ``output`` path
            attributes (e.g. an ``argparse.Namespace``).
    """
    trec = defaultdict(dict)
    with open(args.input_trec, 'r') as run_file:
        for line in run_file:
            # split() accepts tabs and repeated spaces, unlike split(' ').
            fields = line.split()
            if not fields:
                continue
            qid, _, docid, _rank, score, _ = fields
            trec[qid][docid] = score

    # Both handles are managed so the output is flushed and closed even if a
    # malformed qrels line raises part-way through.
    with open(args.qrels, 'r') as qrels_file, open(args.output, 'w') as out_file:
        for line in qrels_file:
            fields = line.split()
            if not fields:
                continue
            qid = fields[1].split(':')[1]
            docid = fields[-7]
            out_file.write(trec.get(qid, {}).get(docid, '0.0') + '\n')
24 |
25 |
if __name__ == "__main__":
    # Command-line entry point: the three path options are all mandatory.
    cli = argparse.ArgumentParser()
    for flag in ("--input_trec", "--output", "--qrels"):
        cli.add_argument(flag, default='', type=str, required=True)

    trec_to_pred(cli.parse_args())
--------------------------------------------------------------------------------
/example/rerank.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, '../')
4 |
5 | from tqdm import tqdm
6 | import json
7 | import torch
8 | import numpy as np
9 | import pandas as pd
10 | from datetime import timedelta
11 |
12 | from model.modeling import ARESReranker
13 |
14 |
if __name__ == "__main__":
    # Minimal usage example for ARESReranker: replace model_path with the
    # directory of a trained checkpoint.
    model_path = "path/to/model"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = ARESReranker.from_pretrained(model_path).to(device)

    query1 = "What is the best way to get to the airport"
    query2 = "what do you like to eat?"

    doc1 = "The best way to get to the airport is to take the bus"
    doc2 = "I like to eat apples"

    ### Score a batch of q-d pairs
    qd_pairs = [
        (query1, doc1), (query1, doc2),
        (query2, doc1), (query2, doc2)
    ]

    score = model.score(qd_pairs)
    print("qd scores", score)

    ### Rerank a single query
    score = model.rerank_query(query1, [doc1, doc2])
    print("query1 scores", score)

    ### Rerank a batch of queries
    query1_topk = [doc1, doc2]
    query2_topk = [doc1, doc2]

    score = model.rerank([query1, query2], [query1_topk, query2_topk])
    # Show the batch result too; it was previously computed and discarded.
    print("batch rerank scores", score)
45 |
--------------------------------------------------------------------------------
/finetune/modelsize_estimate.py:
--------------------------------------------------------------------------------
1 | '''
2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search
3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma.
4 | '''
5 | # encoding: utf-8
6 | import torch.nn as nn
7 | import numpy as np
8 |
9 |
def modelsize(model, input, type_size=4):
    """Print a rough size estimate for ``model`` given one example ``input``.

    Three lines are printed: the parameter footprint, and the footprint of
    the intermediate outputs without and with backward (the latter is simply
    twice the former). Element counts are multiplied by ``type_size`` and
    divided by 1000 twice.

    The intermediate estimate feeds a copy of ``input`` through every entry
    of ``model.modules()`` after the first, one after another, skipping
    in-place ``nn.ReLU`` layers.

    Args:
        model: the ``nn.Module`` to inspect.
        input: example tensor; it is cloned, not modified.
        type_size: multiplier applied to element counts
            (presumably bytes per element).
    """
    name = model._get_name()

    param_count = sum(np.prod(list(p.size())) for p in model.parameters())
    print('Model {} : params: {:4f}M'.format(name, param_count * type_size / 1000 / 1000))

    activation = input.clone()
    activation.requires_grad_(requires_grad=False)

    shapes = []
    for layer in list(model.modules())[1:]:
        # An in-place ReLU produces no new tensor, so it is not counted.
        if isinstance(layer, nn.ReLU) and layer.inplace:
            continue
        activation = layer(activation)
        shapes.append(np.array(activation.size()))

    total_nums = sum(np.prod(np.array(shape)) for shape in shapes)

    forward_only = total_nums * type_size / 1000 / 1000
    with_backward = total_nums * type_size*2 / 1000 / 1000
    print('Model {} : intermedite variables: {:3f} M (without backward)'
          .format(name, forward_only))
    print('Model {} : intermedite variables: {:3f} M (with backward)'
          .format(name, with_backward))
42 |
43 |
--------------------------------------------------------------------------------
/preprocess/README.md:
--------------------------------------------------------------------------------
1 | ## Data Preprocess
2 |
3 | Since different datasets require different pre-processing, we only provide some helper functions and scripts here.
4 |
5 | ### Anserini Scripts
6 |
7 | We use BM25 implemented by `anserini` to perform first-stage retrieval.
8 |
9 | Please make sure you have correctly installed `anserini` and `pyserini`.
10 |
11 | ### Tokenize
12 |
13 | You can pre-tokenize your dataset offline for faster training.
14 | ```bash
15 | python convert_tokenize.py \
16 | --vocab_dir {path_to_vocab} \
17 | --type {'query', 'doc', 'triples'} \
18 | --input {path_to_input} \
19 | --output {path_to_output}
20 | ```
21 | File format:
22 |
23 | * query: `qid \t query` for each line
24 | * doc: `{"id": docid, "contents": doc}` for each line
25 | * triples: `{"query": query_text, "doc_pos": positive_doc, "doc_neg": negative_doc}` for each line
26 |
27 | ### Small Datasets
28 |
29 | #### TREC-COVID
30 |
31 | We follow the same data preprocessing as `OpenMatch`; please refer to [experiments-treccovid](https://github.com/thunlp/OpenMatch/blob/master/docs/experiments-treccovid.md).
32 |
33 | #### Robust04
34 |
35 | We use BM25 to generate Top-200 candidates for each query, and the fine-tuning procedure is similar to MS-MARCO
36 |
37 | #### MQ2007
38 |
39 | We use BM25 to generate Top-200 candidates for each query, and the fine-tuning procedure is similar to MS-MARCO
40 |
41 | Note that `trec_eval` cannot be used to compute metrics for MQ2007 directly. You should first convert the `trec` output file and use `Eval4.0.pl` for evaluation. `Eval4.0.pl` is from [LETOR4.0](https://www.microsoft.com/en-us/research/project/letor-learning-rank-information-retrieval/letor-4-0/)
42 | ```bash
43 | python convert_to_pred.py \
44 | --input_trec {path_to_trec_output} \
45 | --qrels {path_to_qrels} \
46 | --output {path_to_output}
47 |
48 | perl Eval4.0.pl {path_to_qrels} {path_to_output} ./eval_result 0
49 | ```
50 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/preprocess/convert_tokenize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import numpy as np
5 | from tqdm import tqdm
6 | from transformers import AutoTokenizer
7 |
8 |
def tokenize_file(tokenizer, input_file, output_file, file_type):
    """Tokenize ``input_file`` line by line and write JSON lines of token ids.

    Each output line is ``{"id": <sequence id>, "ids": [<token ids>]}`` with
    at most 512 ids.

    Args:
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``.
        input_file: path of the file to read.
        output_file: path of the JSON-lines file to write.
        file_type: ``"query"`` for ``id<TAB>text`` lines; any other value is
            treated as one JSON object per line.
    """
    with open(input_file) as counter:
        total_size = sum(1 for _ in counter)

    with open(input_file) as inFile, open(output_file, 'w') as outFile:
        for line in tqdm(inFile, total=total_size,
                         desc=f"Tokenize: {os.path.basename(input_file)}"):
            if file_type == "query":
                # Split once so a tab inside the text does not break unpacking.
                seq_id, text = line.rstrip("\n").split("\t", 1)
            else:
                record = json.loads(line.strip())
                # This branch used to leave seq_id/text unassigned (NameError).
                # NOTE(review): key names follow the README ("id"/"contents")
                # with a fallback to the "doc_id"/"doc" keys that
                # tokenize_docs reads - confirm against the real data.
                seq_id = record["id"] if "id" in record else record["doc_id"]
                text = record["contents"] if "contents" in record else record["doc"]

            tokens = tokenizer.tokenize(text)
            ids = tokenizer.convert_tokens_to_ids(tokens)[: 512]
            outFile.write(json.dumps(
                {"id": seq_id, "ids": ids}
            ))
            outFile.write("\n")
25 |
26 |
def tokenize_queries(tokenizer, input_file, output_file):
    """Tokenize a ``query_id<TAB>query`` file into JSON lines.

    Each output line is ``{"query_id": <id>, "query": [<token ids>]}`` with
    at most 512 ids. Blank input lines are skipped.

    Args:
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``.
        input_file: path of the tab-separated query file.
        output_file: path of the JSON-lines file to write.
    """
    # Counting pass (for the progress bar) uses its own managed handle.
    with open(input_file, 'r') as counter:
        total_size = sum(1 for _ in counter)

    with open(input_file, 'r') as r, open(output_file, 'w') as f:
        for line in tqdm(r, total=total_size):
            line = line.strip()
            if not line:
                continue
            # Split once so a tab inside the query text is kept as text.
            query_id, text = line.split('\t', 1)
            tokens = tokenizer.tokenize(text)
            ids = tokenizer.convert_tokens_to_ids(tokens)[: 512]
            f.write(json.dumps({
                'query_id': query_id,
                'query': ids
            }) + '\n')
40 |
41 |
def tokenize_docs(tokenizer, input_file, output_file):
    """Tokenize a JSON-lines document file into JSON lines of token ids.

    Each output line is ``{"id": <doc id>, "contents": [<token ids>]}`` with
    at most 512 ids.

    Args:
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``.
        input_file: path of the JSON-lines document file.
        output_file: path of the JSON-lines file to write.

    Raises:
        KeyError: if a record has neither key of an expected pair.
    """
    with open(input_file, 'r') as counter:
        total_size = sum(1 for _ in counter)

    with open(input_file, 'r') as r, open(output_file, 'w') as f:
        for line in tqdm(r, total=total_size):
            record = json.loads(line.strip())
            # The original keys ("doc_id"/"doc") are tried first. The README
            # documents "id"/"contents", so those are accepted as a fallback.
            # NOTE(review): confirm which schema the datasets really use.
            doc_id = record['doc_id'] if 'doc_id' in record else record['id']
            text = record['doc'] if 'doc' in record else record['contents']
            tokens = tokenizer.tokenize(text)
            ids = tokenizer.convert_tokens_to_ids(tokens)[: 512]
            f.write(json.dumps({
                'id': doc_id,
                'contents': ids
            }) + '\n')
55 |
56 |
def tokenize_pairwise(tokenizer, input_file, output_file):
    """Tokenize JSON-lines training triples into JSON lines of token ids.

    Each input line must provide ``query``, ``doc_pos`` and ``doc_neg``
    texts; the output line carries the same three keys mapped to lists of
    at most 512 token ids.

    Args:
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``.
        input_file: path of the JSON-lines triples file.
        output_file: path of the JSON-lines file to write.
    """
    def encode(text):
        # Same tokenize -> ids -> truncate step for all three fields.
        return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))[: 512]

    with open(input_file, 'r') as counter:
        total_size = sum(1 for _ in counter)

    # Managed handles: the output is closed even if a record is malformed.
    with open(input_file, 'r') as r, open(output_file, 'w') as f:
        for line in tqdm(r, total=total_size):
            record = json.loads(line.strip())
            query_ids = encode(record['query'])
            pos_ids = encode(record['doc_pos'])
            neg_ids = encode(record['doc_neg'])
            f.write(json.dumps({
                'query': query_ids,
                'doc_pos': pos_ids,
                'doc_neg': neg_ids
            }) + '\n')
77 |
78 |
if __name__ == "__main__":
    # Command-line entry point: pick the tokenizer routine by --type.
    parser = argparse.ArgumentParser()

    parser.add_argument("--vocab_dir", default='bert-base-uncased', type=str)
    # Restrict --type to the handled values so a typo fails loudly instead
    # of silently producing no output.
    parser.add_argument("--type", default='query', type=str,
                        choices=['query', 'doc', 'triples'])
    parser.add_argument("--input", default='', type=str, required=True)
    parser.add_argument("--output", default='', type=str, required=True)
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.vocab_dir)

    if args.type == "query":
        tokenize_queries(tokenizer, args.input, args.output)
    elif args.type == "doc":
        tokenize_docs(tokenizer, args.input, args.output)
    elif args.type == "triples":
        tokenize_pairwise(tokenizer, args.input, args.output)
--------------------------------------------------------------------------------
/visualization/visualization.py:
--------------------------------------------------------------------------------
1 | '''
2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search
3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma.
4 | '''
5 | from typing import Any, Iterable, List, Tuple, Union
6 | try:
7 | from IPython.core.display import HTML, display
8 |
9 | HAS_IPYTHON = True
10 | except ImportError:
11 | HAS_IPYTHON = False
12 |
13 |
class VisualizationDataRecord:
    r"""
    Plain container holding the values needed to render one attribution row.

    Instances only accept the attributes named in ``__slots__``; the
    constructor stores its seven arguments unchanged.
    """
    __slots__ = [
        "word_attributions",
        "level",
        "rank",
        "v_q_id",
        "v_d_id",
        "doc_tokens",
        "convergence_score",
    ]

    def __init__(
        self,
        word_attributions,
        level,
        rank,
        v_q_id,
        v_d_id,
        doc_tokens,
        convergence_score,
    ):
        # Attribution payload and the tokens it refers to.
        self.word_attributions, self.doc_tokens = word_attributions, doc_tokens
        # Relevance level and rank shown next to the row.
        self.level, self.rank = level, rank
        # Query / document identifiers.
        self.v_q_id, self.v_d_id = v_q_id, v_d_id
        self.convergence_score = convergence_score
45 |
46 |
47 | def _get_color(attr):
48 | # clip values to prevent CSS errors (Values should be from [-1,1])
49 | # attr = max(-1, min(1, attr))
50 | if attr > 0:
51 | hue = 10
52 | sat = 75
53 | lig = 100-int(100 * attr)
54 | else:
55 | hue = 220
56 | sat = 75
57 | lig = 100 - int(-100 * attr)
58 | return "hsl({}, {}%, {}%)".format(hue, sat, lig)
59 |
60 |
61 | def format_classname(classname):
62 | return '
{} | '.format(classname)
63 |
64 |
def format_special_tokens(token):
    """Rewrite ``<name>``-style tokens as ``#name``; return others unchanged.

    A token counts as special when it starts with ``<`` and ends with ``>``;
    all leading and trailing ``<``/``>`` characters are then stripped.
    """
    is_special = token.startswith("<") and token.endswith(">")
    return ("#" + token.strip("<>")) if is_special else token
69 |
70 |
71 | def format_tooltip(item, text):
72 | return '{item}\
73 | {text}\
74 |
'.format(
75 | item=item, text=text
76 | )
77 |
78 |
79 | def format_word_importances(words, importances):
80 | if importances is None or len(importances) == 0:
81 | return " | "
82 | assert len(words) <= len(importances)
83 | tags = [""]
84 | for word, importance in zip(words, importances[: len(words)]):
85 | print(word, importance)
86 | word = format_special_tokens(word)
87 | color = _get_color(importance)
88 | # unwrapped_tag = ' {word}\
90 | # '.format(
91 | # color=color, word=word
92 | # )
93 | if word.startswith("##"):
94 | unwrapped_tag = '{word}'.format(
95 | color=color, word=word.replace("##","")
96 | )
97 | else:
98 | unwrapped_tag = ' {word}'.format(
99 | color=color, word=word
100 | )
101 | tags.append(unwrapped_tag)
102 | tags.append(" | ")
103 | return "".join(tags)
104 |
105 |
106 | def visualize_text(
107 | datarecords: Iterable[VisualizationDataRecord], legend: bool = False
108 | ) -> "HTML": # In quotes because this type doesn't exist in standalone mode
109 | assert HAS_IPYTHON, (
110 | "IPython must be available to visualize text. "
111 | "Please run 'pip install ipython'."
112 | )
113 | dom = []
114 | dom.append("")
115 | dom.append("")
116 | dom.append("")
117 | dom.append("")
118 | rows = [
119 | '| QID DID | '
120 | 'Relevance Level/Rank | '
121 | 'Word Importance | '
122 | ]
123 | for datarecord in datarecords:
124 | rows.append(
125 | "".join(
126 | [
127 | "
|---|
",
128 | format_classname("{}\n{}".format(datarecord.v_q_id,datarecord.v_d_id)),
129 | format_classname(
130 | "{}/{}".format(
131 | datarecord.level, datarecord.rank
132 | )
133 | ),
134 | format_word_importances(
135 | datarecord.doc_tokens, datarecord.word_attributions
136 | ),
137 | "
",
138 | ]
139 | )
140 | )
141 |
142 | dom.append("".join(rows))
143 | dom.append("
")
144 | dom.append("")
145 | dom.append("")
146 | html = HTML("".join(dom))
147 | display(html)
148 |
149 | return html
--------------------------------------------------------------------------------
/visualization/config.py:
--------------------------------------------------------------------------------
1 | '''
2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search
3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma.
4 | '''
5 | import argparse
6 | import pprint
7 | import yaml
8 |
9 |
def str2bool(v):
    """Parse a textual boolean (case-insensitive) for use as an argparse type.

    Returns True for yes/true/t/y/1 and False for no/false/f/n/0.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    lowered = v.lower()
    if lowered in {'yes', 'true', 't', 'y', '1'}:
        return True
    if lowered in {'no', 'false', 'f', 'n', '0'}:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
17 |
18 |
class Config(object):
    """Attribute bag for run settings, with YAML save/load helpers."""

    def __init__(self, **kwargs):
        """Configuration Class: set kwargs as class attributes with setattr"""
        for k, v in kwargs.items():
            setattr(self, k, v)

    @property
    def config_str(self):
        """Pretty-printed dump of all stored attributes."""
        return pprint.pformat(self.__dict__)

    def __repr__(self):
        """Pretty-print configurations in alphabetical order"""
        return 'Configurations\n' + self.config_str

    def save(self, path):
        """Write all stored attributes to ``path`` as block-style YAML."""
        with open(path, 'w') as f:
            yaml.dump(self.__dict__, f, default_flow_style=False)

    @classmethod
    def load(cls, path):
        """Build a configuration from the YAML mapping stored at ``path``.

        ``yaml.safe_load`` is used because a bare ``yaml.load(f)`` can build
        arbitrary Python objects on old PyYAML and raises ``TypeError`` on
        PyYAML >= 6, where ``Loader`` is mandatory. An empty file yields an
        empty configuration.
        """
        with open(path, 'r') as f:
            kwargs = yaml.safe_load(f) or {}

        # Use cls so subclasses load as their own type.
        return cls(**kwargs)
45 |
46 |
def read_config(path):
    """Return the configuration stored at ``path`` (delegates to Config.load)."""
    loaded = Config.load(path)
    return loaded
49 |
50 |
def get_config(parse=True, **optional_kwargs):
    """
    Get configurations as attributes of class
    1. Parse configurations with argparse.
    2. Create Config class initialized with parsed kwargs.
    3. Return Config class.

    Args:
        parse: when True, parse strictly with ``parse_args``; otherwise use
            ``parse_known_args`` and ignore unrecognised arguments.
        **optional_kwargs: values that override or extend the parsed ones.
    """
    parser = argparse.ArgumentParser()
    # Training
    parser.add_argument('--test', action="store_true")
    parser.add_argument('--epochs', type=int, default=10,
                        help='num_epochs')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='batch size')
    parser.add_argument('--neg_docs_per_q', type=int, default=4,
                        help='number of sampled docs per q-d pair')
    parser.add_argument("--adam_epsilon", default=1e-8, type=float)
    parser.add_argument("--weight_decay", default=0.01, type=float)
    parser.add_argument('--lr', type=float, default=3e-5,
                        help='learning rate')
    parser.add_argument('--clip', type=float, default=1.0,
                        help='gradient clip norm')
    parser.add_argument('--warm_up', type=float, default=0.1,
                        help='warm up proportion')
    parser.add_argument('--gradient_checkpointing', action="store_true")
    parser.add_argument('--max_len', type=int, default=512,
                        help='max length')
    parser.add_argument('--max_q_len', type=int, default=40,
                        help='max query length')
    parser.add_argument('--visual_q_num', type=int, default=1,
                        help='visual query number')
    parser.add_argument('--visual_d_num', type=int, default=5,
                        help='visual document number')
    parser.add_argument('--model_name', type=str, default='ARES_simple',
                        choices=['ARES_simple', 'ARES_hardest', 'BERT', 'PROP_msmarco'],
                        help='the model name')
    parser.add_argument('--model_type', type=str, default='ARES',
                        choices=['ARES', 'PROP', 'BERT'],
                        help='the model type')
    parser.add_argument('--optim', type=str, default='adamw',
                        choices=['adam', 'amsgrad', 'adagrad', 'adamw'],
                        help='optimizer')
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--distributed_train', action="store_true")
    parser.add_argument('--gpu_num', type=int, default=1)
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--PRE_TRAINED_MODEL_NAME', default='/path/to/ARES-simple/',
                        help='huggingface model name')
    parser.add_argument('--model_path', default='model_state_ARES', help='name of checkpoint to load')
    # type=int so a value given on the command line is not left as a string.
    parser.add_argument('--print_every', type=int, default=200)
    parser.add_argument('--local_rank', type=int, default=0, help='node rank for distributed training')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=4,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")

    # human labels
    parser.add_argument('--dl2019_qd_dir', default='../preprocess/2019qrels-docs.txt')

    # queries
    parser.add_argument('--dl2019_qs_dir', default='../preprocess/queries.dl2019.json')

    # docs
    parser.add_argument('--memmap_doc_dir', default='../preprocess/doc_token_ids.memmap')
    parser.add_argument('--docid2id_dir', default='../preprocess/docid2idx.json')

    # STAR+ADORE Top100
    parser.add_argument('--dl100_dir', default='../preprocess/test.rank.tsv')

    if parse:
        kwargs = parser.parse_args()
    else:
        kwargs = parser.parse_known_args()[0]

    # Namespace => Dictionary
    kwargs = vars(kwargs)
    kwargs.update(optional_kwargs)

    return Config(**kwargs)
--------------------------------------------------------------------------------
/finetune/config.py:
--------------------------------------------------------------------------------
1 | '''
2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search
3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma.
4 | '''
5 | # encoding: utf-8
6 | import argparse
7 | import pprint
8 | import yaml
9 |
10 |
def str2bool(v):
    """Parse a textual boolean (case-insensitive) for use as an argparse type.

    Returns True for yes/true/t/y/1 and False for no/false/f/n/0.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    lowered = v.lower()
    if lowered in {'yes', 'true', 't', 'y', '1'}:
        return True
    if lowered in {'no', 'false', 'f', 'n', '0'}:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
18 |
19 |
class Config(object):
    """Attribute bag for run settings, with YAML save/load helpers."""

    def __init__(self, **kwargs):
        """Configuration Class: set kwargs as class attributes with setattr"""
        for k, v in kwargs.items():
            setattr(self, k, v)

    @property
    def config_str(self):
        """Pretty-printed dump of all stored attributes."""
        return pprint.pformat(self.__dict__)

    def __repr__(self):
        """Pretty-print configurations in alphabetical order"""
        return 'Configurations\n' + self.config_str

    def save(self, path):
        """Write all stored attributes to ``path`` as block-style YAML."""
        with open(path, 'w') as f:
            yaml.dump(self.__dict__, f, default_flow_style=False)

    @classmethod
    def load(cls, path):
        """Build a configuration from the YAML mapping stored at ``path``.

        ``yaml.safe_load`` is used because a bare ``yaml.load(f)`` can build
        arbitrary Python objects on old PyYAML and raises ``TypeError`` on
        PyYAML >= 6, where ``Loader`` is mandatory. An empty file yields an
        empty configuration.
        """
        with open(path, 'r') as f:
            kwargs = yaml.safe_load(f) or {}

        # Use cls so subclasses load as their own type.
        return cls(**kwargs)
46 |
47 |
def read_config(path):
    """Return the configuration stored at ``path`` (delegates to Config.load)."""
    loaded = Config.load(path)
    return loaded
50 |
51 |
def get_config(parse=True, **optional_kwargs):
    """
    Get configurations as attributes of class
    1. Parse configurations with argparse.
    2. Create Config class initialized with parsed kwargs.
    3. Return Config class.

    Args:
        parse: when True, parse strictly with ``parse_args``; otherwise use
            ``parse_known_args`` and ignore unrecognised arguments.
        **optional_kwargs: values that override or extend the parsed ones.
    """
    parser = argparse.ArgumentParser()
    # Training
    parser.add_argument('--test', action="store_true")
    parser.add_argument('--epochs', type=int, default=20,
                        help='num_epochs')
    parser.add_argument('--batch_size', type=int, default=25,
                        help='batch size')
    parser.add_argument('--neg_docs_per_q', type=int, default=4,
                        help='number of sampled docs per q-d pair')
    parser.add_argument("--adam_epsilon", default=1e-8, type=float)
    parser.add_argument("--weight_decay", default=0.01, type=float)
    parser.add_argument('--lr', type=float, default=3e-5,
                        help='learning rate')
    parser.add_argument('--clip', type=float, default=1.0,
                        help='gradient clip norm')
    parser.add_argument('--warm_up', type=float, default=0.1,
                        help='warm up proportion')
    parser.add_argument('--gradient_checkpointing', action="store_true")
    parser.add_argument('--max_len', type=int, default=512,
                        help='max length')
    parser.add_argument('--max_q_len', type=int, default=15,
                        help='max query length')
    parser.add_argument('--model_name', type=str, default='ARES_simple',
                        help='the model name')
    parser.add_argument('--model_type', type=str, default='ARES',
                        choices=['ARES', 'PROP', 'BERT', 'ICT'],
                        help='the model type')
    parser.add_argument('--optim', type=str, default='adamw',
                        choices=['adam', 'amsgrad', 'adagrad', 'adamw'],
                        help='optimizer')
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--embed_dim', type=int, default=100)
    # type=bool would turn any non-empty string (even "False") into True,
    # so parse the text with str2bool instead.
    parser.add_argument('--freeze', type=str2bool, default=False)
    parser.add_argument('--world_size', type=int, default=4)
    parser.add_argument('--distributed_train', action="store_true")
    parser.add_argument('--gpu_num', type=int, default=1)
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--PRE_TRAINED_MODEL_NAME', default='/path/to/ares-simple/',
                        help='huggingface model name')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=4,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--load_ckpt', action="store_true", help='whether to load a trained checkpoint')
    parser.add_argument('--model_path', default='model_state_ARES', help='name of checkpoint to load')
    # type=int so a value given on the command line is not left as a string.
    parser.add_argument('--print_every', type=int, default=200)
    parser.add_argument('--local_rank', type=int, default=0, help='node rank for distributed training')

    # human labels
    parser.add_argument('--train_qd_dir', default='../preprocess/msmarco-doctrain-qrels.tsv')
    parser.add_argument('--test_qd_dir', default='../preprocess/dev-qrels.txt')
    parser.add_argument('--dl2019_qd_dir', default='../preprocess/2019qrels-docs.txt')

    # queries
    parser.add_argument('--train_qs_dir', default='../preprocess/queries.doctrain.json')
    parser.add_argument('--test_qs_dir', default='../preprocess/queries.docdev.json')
    parser.add_argument('--dl2019_qs_dir', default='../preprocess/queries.dl2019.json')

    # docs
    parser.add_argument('--memmap_doc_dir', default='../preprocess/doc_token_ids.memmap')
    parser.add_argument('--docid2id_dir', default='../preprocess/docid2idx.json')

    # STAR+ADORE Top100
    parser.add_argument('--train100_dir', default='../preprocess/train.rank.tsv')
    parser.add_argument('--test100_dir', default='../preprocess/dev.rank.tsv')
    parser.add_argument('--dl100_dir', default='../preprocess/test.rank.tsv')

    if parse:
        kwargs = parser.parse_args()
    else:
        kwargs = parser.parse_known_args()[0]

    # Namespace => Dictionary
    kwargs = vars(kwargs)
    kwargs.update(optional_kwargs)

    return Config(**kwargs)
--------------------------------------------------------------------------------
/pretrain/config.py:
--------------------------------------------------------------------------------
1 | '''
2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search
3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma.
4 | '''
5 | # encoding: utf-8
6 | import argparse
7 | import pprint
8 | import yaml
9 |
10 |
def str2bool(v):
    """Parse a textual boolean (case-insensitive) for use as an argparse type.

    Returns True for yes/true/t/y/1 and False for no/false/f/n/0.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    lowered = v.lower()
    if lowered in {'yes', 'true', 't', 'y', '1'}:
        return True
    if lowered in {'no', 'false', 'f', 'n', '0'}:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
18 |
19 |
class Config(object):
    """Attribute bag for run settings, with YAML save/load helpers."""

    def __init__(self, **kwargs):
        """Configuration Class: set kwargs as class attributes with setattr"""
        for k, v in kwargs.items():
            setattr(self, k, v)

    @property
    def config_str(self):
        """Pretty-printed dump of all stored attributes."""
        return pprint.pformat(self.__dict__)

    def __repr__(self):
        """Pretty-print configurations in alphabetical order"""
        return 'Configurations\n' + self.config_str

    def save(self, path):
        """Write all stored attributes to ``path`` as block-style YAML."""
        with open(path, 'w') as f:
            yaml.dump(self.__dict__, f, default_flow_style=False)

    @classmethod
    def load(cls, path):
        """Build a configuration from the YAML mapping stored at ``path``.

        ``yaml.safe_load`` is used because a bare ``yaml.load(f)`` can build
        arbitrary Python objects on old PyYAML and raises ``TypeError`` on
        PyYAML >= 6, where ``Loader`` is mandatory. An empty file yields an
        empty configuration.
        """
        with open(path, 'r') as f:
            kwargs = yaml.safe_load(f) or {}

        # Use cls so subclasses load as their own type.
        return cls(**kwargs)
46 |
47 |
def read_config(path):
    """Return the configuration object that ``Config.load`` builds from ``path``."""
    loaded = Config.load(path)
    return loaded
50 |
51 |
def get_config(parse=True, **optional_kwargs):
    """
    Get configurations as attributes of class
    1. Parse configurations with argparse.
    2. Create Config class initialized with parsed kwargs.
    3. Return Config class.

    Args:
        parse: When True the command line is parsed with ``parse_args``
            (unknown options are an error); otherwise ``parse_known_args``
            is used and unknown options are ignored.
        **optional_kwargs: Values applied on top of the parsed options
            (they win on key collisions).

    Returns:
        Config: object exposing every option as an attribute.

    Note:
        ``--axiom`` is declared ``required=True``, so argparse aborts when it
        is absent from the command line.
    """
    parser = argparse.ArgumentParser()

    # Training
    parser.add_argument('--epochs', type=int, default=1,
                        help='num_epochs')
    parser.add_argument('--batch_size', type=int, default=22,
                        help='batch size')
    parser.add_argument('--neg_docs_per_q', type=int, default=4,
                        help='number of sampled docs per q-d pair')
    parser.add_argument("--adam_epsilon", default=1e-8, type=float)
    parser.add_argument("--weight_decay", default=0.01, type=float)
    parser.add_argument('--lr', type=float, default=2e-5,
                        help='learning rate')
    parser.add_argument('--clip', type=float, default=1.0,
                        help='gradient clip norm')
    parser.add_argument('--warm_up', type=float, default=0.1,
                        help='warm up proportion')
    parser.add_argument('--gradient_checkpointing', action="store_true")
    parser.add_argument('--max_len', type=int, default=512,
                        help='max length')
    parser.add_argument('--max_q_len', type=int, default=40,
                        help='max query length')
    parser.add_argument('--model_name', type=str, default='ARES_simple',
                        help='the model name')
    parser.add_argument('--model_type', type=str, default='ARES',
                        choices=['ARES', 'ICT'],
                        help='the model type')
    parser.add_argument('--optim', type=str, default='adamw', choices=['adam', 'amsgrad', 'adagrad', 'adamw'],
                        help='optimizer')
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--embed_dim', type=int, default=100)
    # type=bool would turn every non-empty string (including "False") into
    # True; str2bool parses the textual value instead.
    parser.add_argument('--freeze', type=str2bool, default=False)
    parser.add_argument('--world_size', type=int, default=4)
    parser.add_argument('--distributed_train', action="store_true")
    parser.add_argument('--gpu_num', type=int, default=1)
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--PRE_TRAINED_MODEL_NAME', default='/path/to/bert-base/',
                        help='huggingface model name')
    parser.add_argument('--load_ckpt', action="store_true",
                        help='whether to load a trained checkpoint')
    parser.add_argument('--model_path', default='model_state_ARES',
                        help='name of checkpoint to load')
    parser.add_argument('--clf_model', default='/path/to/xgboost.model',
                        help='the axiom classifier model path (xgboost)')
    parser.add_argument('--MLM', action="store_true", help='whether to add MLM loss while pre-training')
    # Explicit numeric types: without them a value given on the command line
    # stays a string while the default is a number.
    parser.add_argument('--masked_lm_prob', type=float, default=0.15, help='only used when MLM is true')
    parser.add_argument('--max_predictions_per_seq', type=int, default=60,
                        help='only used when MLM is true')
    parser.add_argument('--print_every', type=int, default=200)
    parser.add_argument('--local_rank', type=int, default=0,
                        help='node rank for distributed training')

    # tricks
    parser.add_argument('--gradient_accumulation_steps', type=int, default=4,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")

    # human labels
    parser.add_argument('--train_qd_dir', default='../preprocess/msmarco-doctrain-qrels.tsv')
    parser.add_argument('--test_qd_dir', default='../preprocess/dev-qrels.txt')
    parser.add_argument('--dl2019_qd_dir', default='../preprocess/2019qrels-docs.txt')

    # queries
    parser.add_argument('--train_qs_dir', default='../preprocess/queries.doctrain.json')
    parser.add_argument('--test_qs_dir', default='../preprocess/queries.docdev.json')
    parser.add_argument('--dl2019_qs_dir', default='../preprocess/queries.dl2019.json')

    # docs
    parser.add_argument('--memmap_doc_dir', default='../preprocess/doc_token_ids.memmap')
    parser.add_argument('--docid2id_dir', default='../preprocess/docid2idx.json')

    # STAR+ADORE Top100 candidates
    parser.add_argument('--train100_dir', default='../preprocess/A+S_top100/train.rank.tsv')
    parser.add_argument('--test100_dir', default='../preprocess/A+S_top100/dev.rank.tsv')
    parser.add_argument('--dl100_dir', default='../preprocess/A+S_top100/test.rank.tsv')

    # candidate queries and axioms
    parser.add_argument('--doc2query_dir', default='../preprocess/doc2qs.json')
    parser.add_argument('--gen_qs_memmap_dir', default='../preprocess/sample_qs_token_ids.memmap')
    parser.add_argument('--gen_qid2id_dir', default='../preprocess/sample_qid2id.json')  # qid idx
    parser.add_argument('--axiom', type=str, nargs='+',
                        help="Basic axioms: [RANK, REP], Auxiliary axioms: [PROX, REG, STM], you should choose at least one basic axiom.", required=True)
    parser.add_argument('--axiom_feature_dir', default='../preprocess/axioms')

    if parse:
        kwargs = parser.parse_args()
    else:
        kwargs = parser.parse_known_args()[0]

    # Namespace => Dictionary
    kwargs = vars(kwargs)
    kwargs.update(optional_kwargs)

    return Config(**kwargs)
--------------------------------------------------------------------------------
/finetune/ms_marco_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | This module computes evaluation metrics for the MS MARCO dataset on the ranking task. Internal version with hard-coded eval files. DO NOT PUBLISH!
3 | Command line:
4 | python msmarco_eval_ranking.py
5 | Creation Date : 06/12/2018
6 | Last Modified : 4/09/2019
7 | Authors : Daniel Campos , Rutger van Haasteren
8 | """
9 |
10 | import sys
11 | import math
12 | import numpy as np
13 | from collections import Counter
14 |
# Rank cut-offs used by compute_metrics: candidates are scanned down to these
# depths for the 'MRR @10' and 'MRR @100' scores respectively.
MaxMRRRank1 = 10
MaxMRRRank2 = 100
17 |
18 |
def load_reference_from_stream(f):
    """Load the relevant passage ids of every query from an open stream.

    Each non-blank line must contain at least two tab-separated fields: the
    query id followed by a relevant passage id. Extra fields are ignored and
    blank lines are skipped.

    Args:
        f (stream): iterable yielding text lines.

    Returns:
        dict: maps query_id (int) to the list of its relevant passage ids,
        kept as the strings read from the stream, in input order.

    Raises:
        IOError: if a line has fewer than two fields or its query id is not
            an integer.
    """
    qids_to_relevant_passageids = {}
    for line in f:
        fields = line.strip().split('\t')
        if fields == ['']:
            # Blank line (e.g. a trailing newline at end of file).
            continue
        try:
            qid = int(fields[0])
            pid = fields[1]
        except (ValueError, IndexError) as err:
            # Only format problems are translated; a bare ``except`` would
            # also trap KeyboardInterrupt/SystemExit.
            raise IOError('\"%s\" is not valid format' % line.strip()) from err
        qids_to_relevant_passageids.setdefault(qid, []).append(pid)
    return qids_to_relevant_passageids
37 |
38 |
def load_reference(path_to_reference):
    """Read the reference (relevant passage) file at the given path.

    Args:
        path_to_reference (str): path of the file to open in text mode.

    Returns:
        dict: whatever load_reference_from_stream builds from the file's lines
        (query id -> relevant passage ids).
    """
    with open(path_to_reference, 'r') as reference_file:
        return load_reference_from_stream(reference_file)
47 |
48 |
def load_candidate_from_stream(f):
    """Load ranked candidate passages for every query from an open stream.

    Each non-blank line must contain at least three tab-separated fields:
    query id, passage id and the 1-based rank of that passage. Blank lines
    are skipped.

    Args:
        f (stream): iterable yielding text lines.

    Returns:
        dict: maps query_id (int) to a list of 1000 entries where position
        ``rank - 1`` holds the passage id (str, as read) and every position
        that was not supplied holds the placeholder ``0``.

    Raises:
        IOError: if a line has too few fields, a non-integer query id or
            rank, or a rank outside ``1..1000``.
    """
    max_rank = 1000
    qid_to_ranked_candidate_passages = {}
    for line in f:
        fields = line.strip().split('\t')
        if fields == ['']:
            # Blank line (e.g. a trailing newline at end of file).
            continue
        try:
            qid = int(fields[0])
            pid = fields[1]
            rank = int(fields[2])
        except (ValueError, IndexError) as err:
            raise IOError('\"%s\" is not valid format' % line.strip()) from err
        # A rank <= 0 would otherwise index from the end of the list and
        # silently overwrite one of the last slots.
        if not 1 <= rank <= max_rank:
            raise IOError('\"%s\" is not valid format' % line.strip())
        if qid not in qid_to_ranked_candidate_passages:
            # By default, all PIDs in the list of 1000 are 0. Only override those that are given
            qid_to_ranked_candidate_passages[qid] = [0] * max_rank
        qid_to_ranked_candidate_passages[qid][rank - 1] = pid
    return qid_to_ranked_candidate_passages
71 |
72 |
def load_candidate(path_to_candidate):
    """Read the candidate ranking file at the given path.

    Args:
        path_to_candidate (str): path of the file to open in text mode.

    Returns:
        dict: whatever load_candidate_from_stream builds from the file's lines
        (query id -> ranked candidate passage ids).
    """
    with open(path_to_candidate, 'r') as candidate_file:
        return load_candidate_from_stream(candidate_file)
82 |
83 |
def quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
    """Check that no passage is ranked more than once for a single query.

    Args:
        qids_to_relevant_passageids (dict): dictionary of query-passage mapping
            Dict as read in with load_reference or load_reference_from_stream.
            Accepted for interface compatibility; not inspected by this check.
        qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates

    Returns:
        bool,str: Boolean whether allowed, message to be shown in case of a problem
        (the message describes the last offending query encountered; it is
        empty when there is no problem).
    """
    message = ''
    allowed = True

    for qid, candidates in qids_to_ranked_candidate_passages.items():
        # 0 is the placeholder for "no passage at this rank" and normally
        # repeats many times, so it must never be reported as a duplicate.
        duplicate_pids = [pid for pid, count in Counter(candidates).items()
                          if count > 1 and pid != 0]

        if duplicate_pids:
            message = "Cannot rank a passage multiple times for a single query. QID={qid}, PID={pid}".format(
                qid=qid, pid=duplicate_pids[0])
            allowed = False

    return allowed, message
112 |
113 |
def compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
    """Compute MRR@10 and MRR@100 for the ranked candidates.

    For every candidate query that also appears in the reference, the
    reciprocal rank of the first relevant passage is accumulated (once per
    cut-off in which it falls). The sums are then divided by the number of
    queries in the reference dictionary.

    Args:
        qids_to_relevant_passageids (dict): dictionary of query-passage mapping
            Dict as read in with load_reference or load_reference_from_stream
        qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates

    Returns:
        dict: {'MRR @10': float, 'MRR @100': float, 'QueriesRanked': int}

    Raises:
        IOError: if no candidate query id occurs in the reference.
    """
    all_scores = {}
    MRR_10, MRR_100 = 0, 0
    matched_queries = 0
    deepest_rank = max(MaxMRRRank1, MaxMRRRank2)

    for qid, candidate_pids in qids_to_ranked_candidate_passages.items():
        if qid not in qids_to_relevant_passageids:
            continue
        matched_queries += 1
        target_pids = qids_to_relevant_passageids[qid]
        # Slicing (instead of indexing up to the cut-off) keeps candidate
        # lists shorter than the cut-off from raising IndexError.
        for i, pid in enumerate(candidate_pids[:deepest_rank]):
            if pid in target_pids:
                reciprocal_rank = 1 / (i + 1)
                if i < MaxMRRRank1:
                    MRR_10 += reciprocal_rank
                if i < MaxMRRRank2:
                    MRR_100 += reciprocal_rank
                # Only the first relevant passage counts.
                break

    if matched_queries == 0:
        raise IOError("No matching QIDs found. Are you sure you are scoring the evaluation set?")

    # Normalised by all reference queries, including those that have no
    # candidate ranking at all.
    MRR_10 = MRR_10 / len(qids_to_relevant_passageids)
    MRR_100 = MRR_100 / len(qids_to_relevant_passageids)
    all_scores['MRR @10'] = MRR_10
    all_scores['MRR @100'] = MRR_100
    all_scores['QueriesRanked'] = len(qids_to_ranked_candidate_passages)
    return all_scores
151 |
152 |
def compute_metrics_from_files(path_to_reference, path_to_candidate, perform_checks=True):
    """Load a reference file and a candidate file, then score the candidates.

    Args:
        path_to_reference (str): path to the reference file; it is read with
            load_reference (tab-separated query id and relevant passage id
            per line; a query id may repeat on several lines).
        path_to_candidate (str): path to the candidate file; it is read with
            load_candidate (tab-separated query id, passage id and rank per
            line).
        perform_checks (bool): when True, run quality_checks_qids first and
            print its message if it is non-empty. Scoring proceeds either way.

    Returns:
        dict: the metrics dictionary returned by compute_metrics.
    """
    reference = load_reference(path_to_reference)
    candidates = load_candidate(path_to_candidate)

    if perform_checks:
        _, message = quality_checks_qids(reference, candidates)
        if message != '':
            print(message)

    return compute_metrics(reference, candidates)
177 |
178 |
def main():
    """Score a candidate ranking file against a reference file and print the metrics.

    Command line:
        python <this script> <path_to_reference> <path_to_candidate>

    Exits with a usage message when fewer than two paths are supplied.
    """
    if len(sys.argv) < 3:
        # Fail with an explanation instead of a bare IndexError.
        sys.exit('Usage: python {} <path_to_reference> <path_to_candidate>'.format(sys.argv[0]))

    path_to_reference = sys.argv[1]
    path_to_candidate = sys.argv[2]
    metrics = compute_metrics_from_files(path_to_reference, path_to_candidate)
    print('#####################')
    for metric in sorted(metrics):
        print('{}: {}'.format(metric, metrics[metric]))
    print('#####################')
190 |
191 |
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/finetune/dataloader.py:
--------------------------------------------------------------------------------
1 | '''
2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search
3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma.
4 | '''
5 | # encoding: utf-8
6 | import random
7 | import numpy as np
8 | from tqdm import tqdm
9 | from torch.utils.data import Dataset, DataLoader
10 | from torch.utils.data.distributed import DistributedSampler
11 |
12 |
class TrainQDDatasetPairwise(Dataset):
    """Pairwise training set: each item encodes one query with a positive and a negative document.

    ``q_ids[i]`` is a one-element list holding the query id and ``d_ids[i]``
    is ``[positive_doc_id, negative_doc_id]``. Query token ids come from
    ``q_dict[str(query_id)]``; document token ids come from
    ``d_dict[did2idx[doc_id]].tolist()``.
    """

    def __init__(self, q_ids, d_ids, q_dict, d_dict, did2idx, config, labels, mode='train'):
        self.q_ids, self.d_ids = q_ids, d_ids
        self.q_dict, self.d_dict = q_dict, d_dict
        self.did2idx = did2idx
        self.labels = labels
        self.mode = mode
        self.config = config

    def __len__(self):
        return len(self.q_ids)

    def __getitem__(self, item):
        """Return stacked (positive, negative) inputs laid out as [CLS] q [SEP] d [SEP].

        The result holds 'token_ids', 'attention_mask' (1 where the token id
        is > 0) and 'token_type_ids' (0 for [CLS] + query + first [SEP],
        1 for document + last [SEP]); row 0 is the positive pair and row 1
        the negative pair.
        """
        cls_id, sep_id = 101, 102
        qid = self.q_ids[item][0]
        doc_pair = self.d_ids[item]
        pos_did, neg_did = doc_pair[0], doc_pair[1]

        query = self.q_dict[str(qid)]
        docs = [self.d_dict[self.did2idx[did]].tolist() for did in (pos_did, neg_did)]

        query = query[: self.config.max_q_len]
        # Three positions are reserved for [CLS] and the two [SEP] tokens.
        doc_budget = self.config.max_len - 3 - len(query)
        head = [cls_id] + query + [sep_id]

        id_rows, type_rows = [], []
        for doc in docs:
            doc = doc[:doc_budget]
            id_rows.append(np.array(head + doc + [sep_id]))
            type_rows.append(np.array([0] * len(head) + [1] * (len(doc) + 1)))

        token_ids = np.stack(id_rows)
        return {
            'token_ids': token_ids,
            'attention_mask': np.int64(token_ids > 0),
            'token_type_ids': np.stack(type_rows),
        }
66 |
67 |
class TestQDDataset(Dataset):
    """Evaluation set over pre-built inputs: item ``i`` pairs ``q_ids[i]``/``d_ids[i]`` with row ``i`` of each feature table."""

    def __init__(self, q_ids, d_ids, token_ids, attention_mask, token_type_ids, mode='test'):
        self.q_ids, self.d_ids = q_ids, d_ids
        self.token_ids = token_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.mode = mode

    def __len__(self):
        return len(self.q_ids)

    def __getitem__(self, item):
        """Return the ids plus the three feature rows as flattened numpy arrays."""
        sample = {"q_id": self.q_ids[item], "d_id": self.d_ids[item]}
        feature_tables = (
            ('token_ids', self.token_ids),
            ('attention_mask', self.attention_mask),
            ('token_type_ids', self.token_type_ids),
        )
        for key, table in feature_tables:
            sample[key] = np.array(table[item]).flatten()
        return sample
94 |
95 |
96 | # [CLS] q [SEP] d [SEP]
def get_train_qd_loader(df_qds, train_top100, q_dict, d_dict, did2idx, config, mode='train'):
    """Build the pairwise training DataLoader.

    For every (query, labelled doc) row of ``df_qds`` (columns 0 and 2),
    ``config.neg_docs_per_q`` negatives are drawn with ``random.sample`` from
    that query's ``train_top100`` candidates (columns 0 and 1) minus its
    labelled docs, yielding one (positive, negative) pair per negative.

    Each dataset item already contains two sequences, so the loader batch
    size is ``config.batch_size // 2``. With ``config.distributed_train`` a
    DistributedSampler is used; otherwise batches are shuffled only when
    ``mode == 'train'``.
    """
    pair_batch_size = config.batch_size // 2
    q_ids = df_qds[0].values.tolist()
    d_ids = df_qds[2].values.tolist()

    # query id -> labelled doc ids
    labelled_docs = {}
    for qid, did in zip(q_ids, d_ids):
        labelled_docs.setdefault(qid, []).append(did)

    # query id -> retrieved candidate doc ids
    candidate_docs = {}
    for qid, did in zip(train_top100[0].values.tolist(), train_top100[1].values.tolist()):
        candidate_docs.setdefault(qid, []).append(did)

    pair_q_ids, pair_d_ids, labels = [], [], []
    for this_qid, pos_did in tqdm(zip(q_ids, d_ids), total=len(q_ids), desc="Loading train q-d progress"):
        neg_pool = list(set(candidate_docs[this_qid]) - set(labelled_docs[this_qid]))
        for neg_did in random.sample(neg_pool, config.neg_docs_per_q):
            pair_q_ids.append([this_qid])
            pair_d_ids.append([pos_did, neg_did])
            labels.append([1, 0])

    print('Loading tokens...')
    ds = TrainQDDatasetPairwise(
        q_ids=pair_q_ids,
        d_ids=pair_d_ids,
        q_dict=q_dict,
        d_dict=d_dict,
        did2idx=did2idx,
        config=config,
        labels=labels,
        mode='train'
    )

    if config.distributed_train:
        sampler = DistributedSampler(ds, num_replicas=config.world_size, rank=config.local_rank)
        return DataLoader(ds, batch_size=pair_batch_size, num_workers=0, sampler=sampler)

    return DataLoader(ds, batch_size=pair_batch_size, num_workers=0, shuffle=(mode == 'train'))
165 |
166 |
def get_test_qd_loader(top100qd, q_dict, d_dict, did2idx, config):
    """Build the evaluation DataLoader: 100 candidate documents per query, one query per batch.

    ``top100qd`` columns 0 and 1 give (query id, doc id) candidates. Every
    key of ``q_dict`` must have exactly 100 candidates (asserted). Inputs are
    laid out as [CLS] q [SEP] d [SEP] in zero-initialised rows of width
    ``config.max_len``; the attention mask is 1 wherever the token id is > 0.
    """
    cls_id, sep_id = 101, 102
    docs_per_query = 100

    # query id -> candidate doc ids, in file order
    candidates = {}
    for qid, did in zip(top100qd[0].values.tolist(), top100qd[1].values.tolist()):
        candidates.setdefault(qid, []).append(did)

    qids = list(set(q_dict.keys()))
    n_rows = len(q_dict) * docs_per_query
    tokens_np = np.zeros((n_rows, config.max_len), dtype='int32')  # (q_num * 100) x max_len
    token_type_np = np.zeros((n_rows, config.max_len), dtype='int32')  # (q_num * 100) x max_len

    new_q_ids, new_d_ids = [], []
    for idx, this_qid in enumerate(tqdm(qids, desc="Loading test q-d pair progress")):
        query_input_ids = q_dict[str(this_qid)][: config.max_q_len]
        # Three positions are reserved for [CLS] and the two [SEP] tokens.
        max_passage_length = config.max_len - 3 - len(query_input_ids)
        prefix = [cls_id] + query_input_ids + [sep_id]

        dids = candidates[int(this_qid)]
        assert len(dids) == 100
        for rank, this_did in enumerate(dids):
            doc_input_ids = d_dict[did2idx[this_did]].tolist()[:max_passage_length]
            input_ids = prefix + doc_input_ids + [sep_id]
            token_type_ids = [0] * len(prefix) + [1] * (1 + len(doc_input_ids))
            cat_len = min(len(input_ids), config.max_len)
            row = idx * docs_per_query + rank

            new_q_ids.append(this_qid)
            new_d_ids.append(this_did)
            tokens_np[row, :cat_len] = np.array(input_ids)
            token_type_np[row, :cat_len] = np.array(token_type_ids)

    ds = TestQDDataset(
        q_ids=new_q_ids,
        d_ids=new_d_ids,
        token_ids=tokens_np.tolist(),
        token_type_ids=token_type_np.tolist(),
        attention_mask=np.int64(tokens_np > 0).tolist(),
        mode='test'
    )

    # batch_size=100 so that each batch holds the candidates of one query
    return DataLoader(ds, batch_size=100, num_workers=0, shuffle=False)
225 |
226 |
--------------------------------------------------------------------------------
/visualization/dataloader.py:
--------------------------------------------------------------------------------
1 | '''
2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search
3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma.
4 | '''
5 | import numpy as np
6 | from tqdm import tqdm
7 | from torch.utils.data import Dataset, DataLoader
8 | from torch.utils.data.distributed import DistributedSampler
9 |
10 |
class TestQDDataset(Dataset):
    """Index-aligned evaluation samples: ids plus pre-computed token, mask and segment rows."""

    def __init__(self, q_ids, d_ids, token_ids, attention_mask, token_type_ids, mode='test'):
        self.q_ids = q_ids
        self.d_ids = d_ids
        self.token_ids, self.attention_mask, self.token_type_ids = token_ids, attention_mask, token_type_ids
        self.mode = mode

    def __len__(self):
        return len(self.q_ids)

    def __getitem__(self, item):
        """Return sample ``item``; the three feature rows are converted to flattened numpy arrays."""
        def flat_row(table):
            return np.array(table[item]).flatten()

        return {
            "q_id": self.q_ids[item],
            "d_id": self.d_ids[item],
            'token_ids': flat_row(self.token_ids),
            'attention_mask': flat_row(self.attention_mask),
            'token_type_ids': flat_row(self.token_type_ids),
        }
37 |
38 |
class VisualTestQDDataset(Dataset):
    """Evaluation samples for visualization: each item also carries a rank and a reference input (``ref_*``)."""

    def __init__(self, q_ids, d_ids, ranks,token_ids, ref_token_ids, token_type_ids, ref_token_type_ids, attention_mask, mode='test'):
        self.q_ids, self.d_ids, self.ranks = q_ids, d_ids, ranks
        self.token_ids, self.ref_token_ids = token_ids, ref_token_ids
        self.token_type_ids, self.ref_token_type_ids = token_type_ids, ref_token_type_ids
        self.attention_mask = attention_mask
        self.mode = mode

    def __len__(self):
        return len(self.q_ids)

    def __getitem__(self, item):
        """Return sample ``item``; every feature row is converted to a flattened numpy array."""
        sample = {
            "q_id": self.q_ids[item],
            "d_id": self.d_ids[item],
            "rank": self.ranks[item],
        }
        feature_tables = (
            ('token_ids', self.token_ids),
            ('attention_mask', self.attention_mask),
            ('token_type_ids', self.token_type_ids),
            ('ref_token_ids', self.ref_token_ids),
            ('ref_token_type_ids', self.ref_token_type_ids),
        )
        for key, table in feature_tables:
            sample[key] = np.array(table[item]).flatten()
        return sample
73 |
74 |
def get_test_qd_loader(top100qd, q_dict, d_dict, did2idx, config):
    """Create the evaluation DataLoader with 100 candidate documents per query.

    Candidates are read from columns 0 (query id) and 1 (doc id) of
    ``top100qd``; each key of ``q_dict`` must own exactly 100 of them
    (asserted). Sequences follow [CLS] q [SEP] d [SEP] and are written into
    zero-filled rows of width ``config.max_len``; positions whose token id
    is > 0 get attention mask 1.
    """
    cls_id, sep_id = 101, 102
    per_query = 100

    # Group candidate doc ids by query id, preserving file order.
    grouped = {}
    for qid, did in zip(top100qd[0].values.tolist(), top100qd[1].values.tolist()):
        if qid in grouped:
            grouped[qid].append(did)
        else:
            grouped[qid] = [did]

    total_rows = len(q_dict) * per_query
    qids = list(set(q_dict.keys()))
    tokens_np = np.zeros((total_rows, config.max_len), dtype='int32')  # (q_num * 100) x max_len
    token_type_np = np.zeros((total_rows, config.max_len), dtype='int32')  # (q_num * 100) x max_len

    new_q_ids, new_d_ids = [], []
    for idx, this_qid in enumerate(tqdm(qids, desc="Loading test q-d pair progress")):
        query_input_ids = q_dict[str(this_qid)]
        query_input_ids = query_input_ids[: config.max_q_len]
        # [CLS] and two [SEP] tokens take three positions.
        max_passage_length = config.max_len - 3 - len(query_input_ids)
        query_segment = [0] * (2 + len(query_input_ids))

        dids = grouped[int(this_qid)]
        assert len(dids) == 100
        row = idx * per_query
        for this_did in dids:
            doc_input_ids = d_dict[did2idx[this_did]].tolist()[:max_passage_length]
            input_ids = [cls_id] + query_input_ids + [sep_id] + doc_input_ids + [sep_id]
            token_type_ids = query_segment + [1] * (1 + len(doc_input_ids))
            cat_len = min(len(input_ids), config.max_len)

            new_q_ids.append(this_qid)
            new_d_ids.append(this_did)
            tokens_np[row, :cat_len] = np.array(input_ids)
            token_type_np[row, :cat_len] = np.array(token_type_ids)
            row += 1

    attention_mask = np.int64(tokens_np > 0).tolist()
    ds = TestQDDataset(
        q_ids=new_q_ids,
        d_ids=new_d_ids,
        token_ids=tokens_np.tolist(),
        token_type_ids=token_type_np.tolist(),
        attention_mask=attention_mask,
        mode='test'
    )

    # One batch == the 100 candidates of one query.
    return DataLoader(ds, batch_size=100, num_workers=0, shuffle=False)
133 |
134 |
def get_visual_test_qd_loader(top100qd, q_dict, d_dict, did2idx, config):
    """Create the DataLoader used for visualization.

    Takes the first ``config.visual_q_num`` keys of ``q_dict`` and, for each,
    the first ``config.visual_d_num`` (doc id, rank) candidates found in the
    "q_id"/"d_id"/"rank" columns of ``top100qd`` (asserted to exist). Besides
    the usual [CLS] q [SEP] d [SEP] input, a reference input of the same
    length is built in which every query and document token is replaced by
    the pad id and all segment ids are 0. Batches have ``config.batch_size``
    items and are not shuffled.
    """
    cls_id, sep_id, pad_id = 101, 102, 0
    d_num = config.visual_d_num
    q_num = config.visual_q_num

    # query id -> [[doc id, rank], ...] in file order
    grouped = {}
    rows = zip(top100qd["q_id"].values.tolist(),
               top100qd["d_id"].values.tolist(),
               top100qd["rank"].values.tolist())
    for qid, did, rank_value in rows:
        grouped.setdefault(qid, []).append([did, rank_value])

    qids = list(q_dict.keys())[:q_num]
    table_shape = (q_num * d_num, config.max_len)  # (q_num * d_num) x max_len
    tokens_np = np.zeros(table_shape, dtype='int32')
    token_type_np = np.zeros(table_shape, dtype='int32')
    ref_tokens_np = np.zeros(table_shape, dtype='int32')
    ref_token_type_np = np.zeros(table_shape, dtype='int32')

    new_q_ids, new_d_ids, new_ranks = [], [], []
    for idx, this_qid in enumerate(tqdm(qids, desc="Loading test q-d pair progress")):
        query_input_ids = q_dict[str(this_qid)][: config.max_q_len]
        # [CLS] and two [SEP] tokens take three positions.
        max_passage_length = config.max_len - 3 - len(query_input_ids)

        did_ranks = grouped[str(this_qid)][:d_num]
        assert len(did_ranks) == d_num
        for offset, (this_did, this_rank) in enumerate(did_ranks):
            doc_input_ids = d_dict[did2idx[this_did]].tolist()[:max_passage_length]
            input_ids = [cls_id] + query_input_ids + [sep_id] + doc_input_ids + [sep_id]
            ref_input_ids = ([cls_id] + [pad_id] * len(query_input_ids) + [sep_id]
                             + [pad_id] * len(doc_input_ids) + [sep_id])
            token_type_ids = [0] * (2 + len(query_input_ids)) + [1] * (1 + len(doc_input_ids))
            ref_token_type_ids = [0] * len(token_type_ids)
            cat_len = min(len(input_ids), config.max_len)
            row = idx * d_num + offset

            new_q_ids.append(this_qid)
            new_d_ids.append(this_did)
            new_ranks.append(this_rank)
            tokens_np[row, :cat_len] = np.array(input_ids)
            token_type_np[row, :cat_len] = np.array(token_type_ids)
            ref_tokens_np[row, :cat_len] = np.array(ref_input_ids)
            ref_token_type_np[row, :cat_len] = np.array(ref_token_type_ids)

    ds = VisualTestQDDataset(
        q_ids=new_q_ids,
        d_ids=new_d_ids,
        ranks=new_ranks,
        token_ids=tokens_np.tolist(),
        ref_token_ids=ref_tokens_np.tolist(),
        token_type_ids=token_type_np.tolist(),
        ref_token_type_ids=ref_token_type_np.tolist(),
        attention_mask=np.int64(tokens_np > 0).tolist(),
        mode='test'
    )

    return DataLoader(ds, batch_size=config.batch_size, num_workers=0, shuffle=False)
204 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | ## Introduction
21 | This codebase contains the source code of the Python-based implementation (ARES) of our SIGIR 2022 paper.
22 | - [Chen, Jia, et al. "Axiomatically Regularized Pre-training for Ad hoc Search." To Appear in the Proceedings of the 45th International ACM SIGIR Conference on Research and Development in Information Retrieval. 2022.](https://xuanyuan14.github.io/files/SIGIR22Chen.pdf)
23 |
24 | ## Requirements
25 | * python 3.7
26 | * torch==1.9.0
27 | * transformers==4.9.2
28 | * tqdm, nltk, numpy, boto3
29 | * [trec_eval](https://github.com/usnistgov/trec_eval) for evaluation on TREC DL 2019
30 | * [anserini](https://github.com/castorini/anserini) for generating "RANK" axiom scores
31 |
32 | ## Why this repo?
33 | In this repo, you can pre-train ARESsimple and TransformerICT models, and fine-tune all pre-trained models with the same architecture as BERT. The papers are listed as follows:
34 | * BERT ([Bert: Pre-training of deep bidirectional transformers for language understanding](https://arxiv.org/pdf/1810.04805.pdf&usg=ALkJrhhzxlCL6yTht2BRmH9atgvKFxHsxQ))
35 | * TransformerICT ([Latent Retrieval for Weakly Supervised Open Domain Question Answering.](https://arxiv.org/pdf/1906.00300))
36 | * PROP ([PROP: Pre-training with representative words prediction for ad-hoc retrieval.](https://dl.acm.org/doi/pdf/10.1145/3437963.3441777))
37 | * ARES ([Axiomatically Regularized Pre-training for Ad hoc Search.](https://xuanyuan14.github.io/files/SIGIR22Chen.pdf))
38 |
39 | You can download the pre-trained ARES checkpoint [ARESsimple](https://drive.google.com/file/d/1QvJ-hs6VtK4nlrlFkzPZAXfTtY-QjTiU/view?usp=sharing) from Google drive and extract it.
40 |
41 | ## Pre-training Data
42 |
43 | ### Download data
44 | Download the **MS MARCO** corpus from the official [website](https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docs.tsv.gz).
45 | Download the **ADORE+STAR Top100 Candidates** files from this [repo](https://github.com/jingtaozhan/DRhard).
46 |
47 | ### Pre-process data
48 | To save memory, we store most files using the numpy `memmap` or `jsonl` format in the `./preprocess` directory.
49 |
50 | Document files:
51 | * `doc_token_ids.memmap`: each line is the token ids for a document
52 | * `docid2idx.json`: `{docid: memmap_line_id}`
53 |
54 | Query files:
55 | * `queries.doctrain.jsonl`: MS MARCO training queries `{"id": qid, "ids": token_ids}` for each line
56 | * `queries.docdev.jsonl`: MS MARCO validation queries `{"id": qid, "ids": token_ids}` for each line
57 | * `queries.dl2019.jsonl`: TREC DL 2019 queries `{"id": qid, "ids": token_ids}` for each line
58 |
59 | Human label files:
60 | * `msmarco-doctrain-qrels.tsv`: `qid 0 docid 1` for training set
61 | * `dev-qrels.txt`: `qid relevant_docid` for validating set
62 | * `2019qrels-docs.txt`: `qid relevant_docid` for TREC DL 2019 set
63 |
64 | Top 100 candidate files:
65 | * `train.rank.tsv`, `dev.rank.tsv`, `test.rank.tsv`: `qid docid rank` for each line
66 |
67 | Pseudo queries and axiomatic features:
68 | * `doc2qs.jsonl`: `{"docid": docid, "queries": [qids]}` for each line
69 | * `sample_qs_token_ids.memmap`: each line is the token ids for a pseudo query
70 | * `sample_qid2id.json`: `{qid: memmap_line_id}`
71 | * `axiom.memmap`: axiom can be one of the `['rank', 'prox-1', 'prox-2', 'rep-ql', 'rep-tfidf', 'reg', 'stm-1', 'stm-2', 'stm-3']`, each line is an axiomatic score for a query
72 |
73 |
74 | ## Quick Start
75 |
76 | ### Example Usage
77 | ```python
78 | from model.modeling import ARESReranker
79 |
80 | model = ARESReranker.from_pretrained(model_path).to(device)
81 |
82 | query1 = "What is the best way to get to the airport"
83 | query2 = "what do you like to eat?"
84 |
85 | doc1 = "The best way to get to the airport is to take the bus"
86 | doc2 = "I like to eat apples"
87 |
88 | qd_pairs = [
89 | (query1, doc1), (query1, doc2),
90 | (query2, doc1), (query2, doc2)
91 | ]
92 |
93 | score = model.score(qd_pairs)
94 | ```
95 |
96 | You will get
97 | ```bash
98 | scores: [ 41.60 -33.66
99 | -38.00 30.03 ]
100 | ```
101 |
102 | Note that to accelerate the training process, we adopt the parallel training technique. The scripts for pre-training and fine-tuning are as follows:
103 |
104 | ### Pre-training
105 |
106 | ```shell
107 | export BERT_DIR=/path/to/bert-base/
108 | export XGB_DIR=/path/to/xgboost.model
109 |
110 | cd pretrain
111 |
112 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 NCCL_BLOCKING_WAIT=1 \
113 | python -m torch.distributed.launch --nproc_per_node=6 --nnodes=1 train.py \
114 | --model_type ARES \
115 | --PRE_TRAINED_MODEL_NAME $BERT_DIR \
116 | --gpu_num 6 --world_size 6 \
117 | --MLM --axiom REP RANK REG PROX STM \
118 | --clf_model $XGB_DIR
119 | ```
120 | Here model type can be `ARES` or `ICT`.
121 |
122 | ### Zero-shot evaluation (based on AS top100)
123 | ```shell
124 | export MODEL_DIR=/path/to/ares-simple/
125 | export CKPT_NAME=ares.ckpt
126 |
127 | cd finetune
128 |
129 | CUDA_VISIBLE_DEVICES=0 python train.py \
130 | --test \
131 | --PRE_TRAINED_MODEL_NAME $MODEL_DIR \
132 | --model_type ARES \
133 | --model_name ARES_simple \
134 | --load_ckpt \
135 | --model_path $CKPT_NAME
136 | ```
137 | You can get:
138 | ```bash
139 | #####################
140 | <----- MS Dev ----->
141 | MRR @10: 0.2991
142 | MRR @100: 0.3130
143 | QueriesRanked: 5193
144 | #####################
145 | ```
146 | on MS MARCO dev set and:
147 | ```bash
148 | #############################
149 | <--------- DL 2019 --------->
150 | QueriesRanked: 43
151 | nDCG @10: 0.5955
152 | nDCG @100: 0.4863
153 | #############################
154 | ```
155 | on DL 2019 set.
156 |
157 | ### Fine-tuning
158 | ```shell
159 | export MODEL_DIR=/path/to/ares-simple/
160 |
161 | cd finetune
162 |
163 | CUDA_VISIBLE_DEVICES=0,1,2,3 NCCL_BLOCKING_WAIT=1 \
164 | python -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 train.py \
165 | --model_type ARES \
166 | --distributed_train \
167 | --PRE_TRAINED_MODEL_NAME $MODEL_DIR \
168 | --gpu_num 4 --world_size 4 \
169 | --model_name ARES_simple
170 | ```
171 |
172 | ### Visualization
173 | ```shell
174 | export MODEL_DIR=/path/to/ares-simple/
175 | export SAVE_DIR=/path/to/output/
176 | export CKPT_NAME=ares.ckpt
177 |
178 | cd visualization
179 |
180 | CUDA_VISIBLE_DEVICES=0 python visual.py \
181 | --PRE_TRAINED_MODEL_NAME $MODEL_DIR \
182 | --model_name ARES_simple \
183 | --visual_q_num 1 \
184 | --visual_d_num 5 \
185 | --save_path $SAVE_DIR \
186 | --model_path $CKPT_NAME
187 | ```
188 |
189 | ## Results
190 | Zero-shot performance:
191 |
192 | | Model Name | MS MARCO MRR@10 | MS MARCO MRR@100 | DL NDCG@10 | DL NDCG@100 | COVID | EQ |
193 | | :--: | :--: | :--: | :--: | :--: | :--: | :--: |
194 | | BM25 | 0.2962 | 0.3107 | 0.5776 | 0.4795 | 0.4857 | 0.6690 |
195 | | BERT | 0.1820 | 0.2012 | 0.4059 | 0.4198 | 0.4314 | 0.6055 |
196 | | PROPwiki | 0.2429 | 0.2596 | 0.5088 | 0.4525 | 0.4857 | 0.5991 |
197 | | PROPmarco | 0.2763 | 0.2914 | 0.5317 | 0.4623 | 0.4829 | 0.6454 |
198 | | ARESstrict | 0.2630 | 0.2785 | 0.4942 | 0.4504 | 0.4786 | 0.6923 |
199 | | AREShard | 0.2627 | 0.2780 | 0.5189 | 0.4613 | 0.4943 | 0.6822 |
200 | | ARESsimple | 0.2991 | 0.3130 | 0.5955 | 0.4863 | 0.4957 | 0.6916 |
201 |
202 |
203 | Few-shot performance:
204 | 
205 |
206 | Visualization (attribution values have been normalized within a document):
207 | 
208 |
209 | ## Citation
210 | If you find our work useful, please consider giving this repo a star and citing our work:
211 | ```
212 | @inproceedings{chen2022axiomatically,
213 | title={Axiomatically Regularized Pre-training for Ad hoc Search},
214 | author={Chen, Jia and Liu, Yiqun and Fang, Yan and Mao, Jiaxin and Fang, Hui and Yang, Shenghao and Xie, Xiaohui and Zhang, Min and Ma, Shaoping},
215 | booktitle={Proceedings of the 45th International ACM SIGIR Conference on Research and Development in Information Retrieval},
216 | year={2022}
217 | }
218 | ```
219 |
220 |
221 | ## Notice
222 | * Please make sure that all the pre-trained model parameters have been loaded correctly, or the zero-shot and the fine-tuning performance will be greatly impacted.
223 | * We welcome anyone who would like to contribute to this repo. 🤗
224 | * If you have any other questions, please feel free to contact me via [chenjia0831@gmail.com](mailto:chenjia0831@gmail.com) or open an issue.
225 | * Code for data preprocessing will come soon. Please stay tuned~
226 |
--------------------------------------------------------------------------------
/visualization/visual.py:
--------------------------------------------------------------------------------
1 | '''
2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search
3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma.
4 | '''
5 | import os
6 | import random
7 | from tqdm import tqdm
8 | import json
9 | import torch
10 | import torch.nn as nn
11 | import numpy as np
12 | import pandas as pd
13 | from model.modeling import ARES
14 | from transformers import PretrainedConfig, BertConfig,BertTokenizer
15 | from dataloader import get_visual_test_qd_loader, get_test_qd_loader
16 | from config import get_config
17 | import warnings
18 | from captum.attr import LayerIntegratedGradients
19 | # from captum.attr import visualization as viz
20 | import visualization as viz
21 | from gensim.models import KeyedVectors
22 | warnings.filterwarnings("ignore")
23 |
24 |
def eval_model(model, test_qd_loader, device, config):
    """Score all query-document batches, write a TREC run file and report nDCG.

    Every batch yielded by ``test_qd_loader`` is ranked on its own: its rows
    are sorted by descending model score and numbered from 1. This presumes
    that one batch holds the candidate list of one query -- TODO confirm
    against the dataloader.

    Args:
        model: callable invoked with ``input_ids``, ``config``, ``input_mask``
            and ``token_type_ids``; expected to return one score per row.
        test_qd_loader: sized iterable of dicts with the keys ``token_ids``,
            ``attention_mask``, ``token_type_ids``, ``q_id`` and ``d_id``.
        device: torch device the input tensors are moved to.
        config: run configuration. ``save_path``, ``model_name`` and
            ``dl2019_qd_dir`` are read here; the object is also forwarded to
            the model.

    Returns:
        pandas.DataFrame with the columns ``q_id, Q0, d_id, rank, score,
        standard`` -- the same rows that are written to the run file.
    """
    model.eval()
    q_id_list, d_id_list, rank, score = [], [], [], []
    with torch.no_grad():
        for batch_data in tqdm(test_qd_loader, desc="Evaluating progress", total=len(test_qd_loader)):
            input_ids = batch_data["token_ids"].to(device)  # bs x 512
            attention_mask = batch_data["attention_mask"].to(device)  # bs x 512
            token_type_ids = batch_data["token_type_ids"].to(device)

            output = model(
                input_ids=input_ids,
                config=config,
                input_mask=attention_mask,
                token_type_ids=token_type_ids,
            )  # 100 x 1

            # reshape(-1) rather than squeeze(): squeeze() collapses a batch of
            # one into a 0-d tensor, whose tolist() is a bare float that can
            # neither be zipped nor sorted below.
            scores = output.reshape(-1).cpu().tolist()
            triples = zip(batch_data["q_id"], batch_data["d_id"], scores)
            ranked = sorted(triples, key=lambda x: x[2], reverse=True)
            for position, (q_id, d_id, this_score) in enumerate(ranked, start=1):
                q_id_list.append(q_id)
                d_id_list.append(d_id)
                rank.append(position)
                score.append(this_score)

    # Six-column TREC run layout: qid Q0 docid rank score run-tag.
    df_rank = pd.DataFrame(columns=['q_id', 'Q0', 'd_id', 'rank', 'score', 'standard'])
    df_rank['q_id'] = q_id_list
    df_rank['Q0'] = ['Q0'] * len(q_id_list)
    df_rank['d_id'] = d_id_list
    df_rank['rank'] = rank
    df_rank['score'] = score
    df_rank['standard'] = ['STANDARD'] * len(q_id_list)

    run_file = f"{config.save_path}/dl2019_qd_rank_{config.model_name}.tsv"
    df_rank.to_csv(run_file, sep=' ', index=False, header=False)

    # trec_eval prints one "<metric> all <value>" line per requested cutoff;
    # the value is the last whitespace-separated field of each line.
    result_lines = os.popen(f'trec_eval -m ndcg_cut.10,100 {config.dl2019_qd_dir} {run_file}').read().strip().split("\n")
    ndcg_10, ndcg_100 = float(result_lines[0].strip().split()[-1]), float(result_lines[1].strip().split()[-1])
    metrics = {'nDCG @10': ndcg_10, 'nDCG @100': ndcg_100, 'QueriesRanked': len(set(df_rank['q_id']))}
    print('\n#############################')
    print(config.model_name)
    print('<--------- DL 2019 --------->')
    for metric in sorted(metrics):
        print('{}: {}'.format(metric, metrics[metric]))
    print('#############################\n')
    return df_rank
83 |
84 |
def visual_model(lig, tokenizer, qd_loader, df_dl2019_qds, device, config):
    """Compute integrated-gradients attributions and write them to an HTML file.

    For every query-document pair delivered by ``qd_loader`` the attributions
    returned by ``lig.attribute`` are summed over the last dimension, cut to
    the query tokens plus at most 250 tokens starting at the first ``[SEP]``,
    normalised, and collected into a ``viz.VisualizationDataRecord``. All
    records are rendered with ``viz.visualize_text`` and written to
    ``output_{config.model_name}.html`` in the current working directory.

    Args:
        lig: object exposing ``attribute(...)`` (a ``LayerIntegratedGradients``
            instance in this script).
        tokenizer: tokenizer providing ``convert_ids_to_tokens``.
        qd_loader: sized iterable of dicts with the keys ``q_id``, ``d_id``,
            ``rank``, ``token_ids``, ``ref_token_ids``, ``attention_mask``,
            ``token_type_ids`` and ``ref_token_type_ids``.
        df_dl2019_qds: qrels DataFrame without header; column 0 is looked up
            by query id, column 2 by document id, column 3 is the label.
        device: torch device the tensors are moved to.
        config: run configuration; only ``model_name`` is read here.
    """
    score_viz_list = []
    for batch_data in tqdm(qd_loader, desc="IG progress", total=len(qd_loader)):
        q_ids, d_ids, ranks = batch_data["q_id"], batch_data["d_id"], batch_data["rank"]
        input_ids = batch_data["token_ids"].to(device)  # bs x 512
        ref_input_ids = batch_data["ref_token_ids"].to(device)  # bs x 512
        attention_mask = batch_data["attention_mask"].to(device)  # bs x 512
        token_type_ids = batch_data["token_type_ids"].to(device)
        ref_token_type_ids = batch_data["ref_token_type_ids"].to(device)
        attributions, deltas = lig.attribute(
            inputs=(input_ids, token_type_ids),
            baselines=(ref_input_ids, ref_token_type_ids),
            return_convergence_delta=True,
            additional_forward_args=attention_mask,
            internal_batch_size=5
        )
        # enumerate() yields (index, (attribution, delta)); the pair has to be
        # unpacked with nested parentheses, a flat three-name target raises
        # ValueError on the first item.
        for j, (attribution, delta) in enumerate(zip(attributions, deltas)):  # for 512*768 in bs*512*768
            attribution_sum = attribution.sum(dim=-1).squeeze(0)  # 512
            tokens = [token.replace("Ġ", "") for token in tokenizer.convert_ids_to_tokens(input_ids[j])]
            sep_index = tokens.index('[SEP]')
            doc_end = sep_index + 250
            # The document part starts at the first [SEP] itself and is capped
            # at 250 tokens.
            tokens = tokens[:sep_index] + tokens[sep_index:doc_end]
            # The query part is scaled by the norm of the whole sequence, the
            # document part by the norm of its own slice.
            query_attribution_sum = attribution_sum[:sep_index] / torch.norm(attribution_sum)
            doc_slice = attribution_sum[sep_index:doc_end]
            doc_attribution_sum = doc_slice / torch.norm(doc_slice)
            attribution_sum = torch.cat((query_attribution_sum, doc_attribution_sum), dim=-1)
            v_q_id, v_d_id, rank = q_ids[j], d_ids[j], ranks[j]
            try:
                level = df_dl2019_qds.set_index(0).loc[int(v_q_id)].set_index(2).loc[str(v_d_id)][3]
            except Exception:
                # Best effort: a pair without a usable qrels entry is shown
                # with label -1. `Exception` (not a bare except) so that
                # KeyboardInterrupt / SystemExit still propagate.
                level = -1
            score_viz_list.append(
                viz.VisualizationDataRecord(
                    attribution_sum,
                    level,
                    rank,
                    v_q_id,
                    v_d_id,
                    tokens,
                    delta,
                )
            )
    html = viz.visualize_text(score_viz_list)
    html_filepath = f"output_{config.model_name}.html"
    with open(html_filepath, "w") as html_file:
        html_file.write(html.data)
134 |
135 |
if __name__ == '__main__':
    # Entry point: rank the DL 2019 top-100 candidates with the selected model
    # (eval_model), then render integrated-gradients attributions for the
    # ranked pairs (visual_model).
    config = get_config()
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(config.seed)
        torch.cuda.set_device(config.local_rank)
        print('GPU is ON!')
        device = torch.device('cuda')
    else:
        device = torch.device("cpu")

    df_dl2019_qds = pd.read_csv(config.dl2019_qd_dir, sep=' ', header=None)
    dl2019_top100 = pd.read_csv(config.dl100_dir, sep='\t', header=None)
    dl2019_qs = {}
    with open(config.dl2019_qs_dir) as f_qs:
        for line in f_qs:
            es = json.loads(line)
            # setdefault keeps the first occurrence of a query id.
            dl2019_qs.setdefault(es["id"], es["ids"])
    with open(config.docid2id_dir) as f_docid2id:
        docid2id = json.load(f_docid2id)
    collection_size = len(docid2id)
    doc_tokens = np.memmap(config.memmap_doc_dir, dtype='int32', shape=(collection_size, 512))
    print("\n========== Loading DL 2019 data ==========")
    dl2019_qd_loader = get_test_qd_loader(dl2019_top100, dl2019_qs, doc_tokens, docid2id, config)
    print(f"dl2019_q: {len(dl2019_qs)}, dl2019_q_batchs:{len(dl2019_qd_loader)}")

    print("Loading model...")
    if 'BERT' in config.model_name:
        model = ARES.from_pretrained(config.PRE_TRAINED_MODEL_NAME)
    elif 'ARES' in config.model_name or 'PROP' in config.model_name:
        cfg = PretrainedConfig.get_config_dict(config.PRE_TRAINED_MODEL_NAME)[0]
        if not config.gradient_checkpointing:
            # pop() instead of del: a config.json without these keys must not
            # abort the run with KeyError.
            cfg.pop("gradient_checkpointing", None)
            cfg.pop("parameter_sharing", None)
        cfg = BertConfig.from_dict(cfg)
        model = ARES(config=cfg)
        # Load onto the CPU first so the checkpoint can be read on any host;
        # the model is moved to the target device below.
        state_dict = torch.load(f"{config.model_path}/{config.model_name}", map_location="cpu")
        load_result = model.load_state_dict(
            {k.replace("module.", ""): v for k, v in state_dict.items()},
            strict=False,
        )
        # strict=False hides mismatches; surface them so a partially loaded
        # model is not evaluated unnoticed.
        if load_result.missing_keys or load_result.unexpected_keys:
            print(f"[WARN] missing keys: {load_result.missing_keys}")
            print(f"[WARN] unexpected keys: {load_result.unexpected_keys}")
    else:
        raise ValueError(
            f"Unsupported model_name '{config.model_name}': expected it to contain 'BERT', 'ARES' or 'PROP'"
        )
    tokenizer = BertTokenizer.from_pretrained(config.PRE_TRAINED_MODEL_NAME)
    model = model.to(device)
    print("Loading model finish")

    # Attributions are computed with respect to the embedding layer of the
    # wrapped base model.
    model_base = getattr(model, model.base_model_prefix)
    if not hasattr(model_base, "embeddings"):
        raise AttributeError(
            f"'{model.base_model_prefix}' has no 'embeddings' layer to attribute against"
        )
    lig = LayerIntegratedGradients(model, model_base.embeddings)

    qd_rank = eval_model(
        model,
        dl2019_qd_loader,
        device,
        config
    )
    print("\n========== Loading visual DL 2019 data ==========")
    visual_dl2019_qd_loader = get_visual_test_qd_loader(qd_rank, dl2019_qs, doc_tokens, docid2id, config)
    visual_model(
        lig,
        tokenizer,
        visual_dl2019_qd_loader,
        df_dl2019_qds,
        device,
        config
    )
--------------------------------------------------------------------------------
/preprocess/Eval4.0.pl:
--------------------------------------------------------------------------------
#!/usr/bin/env perl
# author: Jun Xu and Tie-Yan Liu
# modified by Jun Xu, March 3, 2009 (for Letor 4.0)
#
# Reads a feature file and a prediction file (one score per feature line),
# computes P@N, MAP, NDCG@N and MeanNDCG per query, and writes the tables to
# the result file.
use strict;

#hash table for NDCG,
# maps a relevance label to the gain used in DCG
my %hsNdcgRelScore = (  "2", 3,
                        "1", 1,
                        "0", 0,
                    );

#hash table for Precision@N and MAP
# maps a relevance label to 1 (relevant) or 0 (not relevant)
my %hsPrecisionRel = ("2", 1,
                      "1", 1,
                      "0", 0
                );
#modified by Jun Xu, March 3, 2009
# for Letor 4.0. only output top 10 precision and ndcg
# my $iMaxPosition = 16;
my $iMaxPosition = 10;

my $argc = $#ARGV+1;
if($argc != 4)
{
    print "Invalid command line.\n";
    print "Usage: perl Eval.pl argv[1] argv[2] argv[3] argv[4]\n";
    print "argv[1]: feature file \n";
    print "argv[2]: prediction file\n";
    print "argv[3]: result (output) file\n";
    print "argv[4]: flag. If flag equals 1, output the evaluation results per query; if flag equals 0, simply output the average results.\n";
    exit -1;
}
my $fnFeature = $ARGV[0];
my $fnPrediction = $ARGV[1];
my $fnResult = $ARGV[2];
my $flag = $ARGV[3];
# Match the text instead of comparing numerically: under `!=` every
# non-numeric argument evaluates to 0 and would be accepted as flag 0.
if($flag !~ /^[01]$/)
{
    print "Invalid command line.\n";
    print "Usage: perl Eval.pl argv[1] argv[2] argv[3] argv[4]\n";
    print "Flag should be 0 or 1\n";
    exit -1;
}

# Pipeline: parse both input files, evaluate every query, write the tables.
my %hsQueryDocLabelScore = ReadInputFiles($fnFeature, $fnPrediction);
my %hsQueryEval = EvalQuery(\%hsQueryDocLabelScore);
OuputResults($fnResult, %hsQueryEval);
48 |
49 |
# Writes the precision/MAP table followed by the NDCG/MeanNDCG table to
# $fnOut. %hsResult maps a query id to a hash with the keys "PatN", "MAP",
# "NDCG" and "MeanNDCG". Per-query rows are written only when the file-level
# $flag equals 1; the "Average" row is always written.
sub OuputResults
{
    my ($fnOut, %hsResult) = @_;
    # Three-argument open with a failure check: an unwritable path used to
    # produce no output and no error message.
    open(my $fhOut, ">", $fnOut) or die "Open \"$fnOut\" failed: $!\n";

    my @qids = sort{$a <=> $b} keys(%hsResult);

    #Precision@N and MAP
    # modified by Jun Xu, March 3, 2009
    # changing the output format
    WriteMetricTable($fhOut,
        "qid\tP\@1\tP\@2\tP\@3\tP\@4\tP\@5\tP\@6\tP\@7\tP\@8\tP\@9\tP\@10\tMAP\n",
        \@qids, \%hsResult, "PatN", "MAP");

    #NDCG and MeanNDCG
    # modified by Jun Xu, March 3, 2009
    # changing the output format
    WriteMetricTable($fhOut,
        "qid\tNDCG\@1\tNDCG\@2\tNDCG\@3\tNDCG\@4\tNDCG\@5\tNDCG\@6\tNDCG\@7\tNDCG\@8\tNDCG\@9\tNDCG\@10\tMeanNDCG\n",
        \@qids, \%hsResult, "NDCG", "MeanNDCG");

    close($fhOut);
}

# Private helper of OuputResults: writes one table (header, optional
# per-query rows, "Average" row, blank line) for the per-position list stored
# under $listKey and the per-query scalar stored under $meanKey.
sub WriteMetricTable
{
    my ($fh, $header, $pQids, $pResult, $listKey, $meanKey) = @_;
    my $numQuery = @$pQids;
    # With no queries at all the averages are reported as 0 instead of dying
    # on a division by zero.
    my $denom = ($numQuery > 0) ? $numQuery : 1;

    print {$fh} $header;
    my @sumAtPos;
    my $sumMean = 0;
    foreach my $qid (@$pQids)
    {
        # modified by Jun Xu, March 3, 2009
        # output the real query id
        my @valAtPos = @{$$pResult{$qid}{$listKey}};
        my $mean_q = $$pResult{$qid}{$meanKey};
        if ($flag == 1)
        {
            print {$fh} "$qid\t";
            for(my $iPos = 0; $iPos < $iMaxPosition; $iPos ++)
            {
                printf {$fh} "%.4f\t", $valAtPos[$iPos];
            }
            printf {$fh} "%.4f\n", $mean_q;
        }
        for(my $iPos = 0; $iPos < $iMaxPosition; $iPos ++)
        {
            $sumAtPos[$iPos] += $valAtPos[$iPos];
        }
        $sumMean += $mean_q;
    }
    print {$fh} "Average\t";
    for(my $iPos = 0; $iPos < $iMaxPosition; $iPos ++)
    {
        printf {$fh} "%.4f\t", $sumAtPos[$iPos] / $denom;
    }
    printf {$fh} "%.4f\n\n", $sumMean / $denom;
}
136 |
# Evaluates every query of the hash referenced by the first argument
# ({qid}{docid}{"label"|"pred"|"lineNum"}) and returns a hash
# {qid}{"PatN"|"MAP"|"NDCG"|"MeanNDCG"}.
sub EvalQuery
{
    my $pHash = $_[0];
    my %hsResults;

    foreach my $qid (sort{$a <=> $b} keys(%$pHash))
    {
        my $pDocs = $$pHash{$qid};
        # First restore input order, then order by descending prediction.
        my @byLine = sort{$$pDocs{$a}{"lineNum"} <=> $$pDocs{$b}{"lineNum"}} keys(%$pDocs);
        my @ranked = sort{$$pDocs{$b}{"pred"} <=> $$pDocs{$a}{"pred"}} @byLine;
        my @rates = map { $$pDocs{$_}{"label"} } @ranked;

        # modified by Jun Xu, calculate all possible positions' NDCG for MeanNDCG
        my @Ndcg = NDCG(scalar(@rates), @rates);
        my $ndcgTotal = 0;
        foreach my $value (@Ndcg)
        {
            $ndcgTotal += $value;
        }

        $hsResults{$qid}{"MAP"} = MAP(@rates);
        $hsResults{$qid}{"PatN"} = [PrecisionAtN($iMaxPosition, @rates)];
        $hsResults{$qid}{"NDCG"} = [@Ndcg];
        $hsResults{$qid}{"MeanNDCG"} = $ndcgTotal / scalar(@Ndcg);
    }
    return %hsResults;
}
177 |
# Reads the feature file and the prediction file in lockstep (line i of the
# prediction file is the score for line i of the feature file) and returns a
# hash {qid}{docid}{"label"|"inc"|"prob"|"pred"|"lineNum"}.
# Exits with -2 when a file cannot be opened, the prediction file has fewer
# lines than the feature file, or a feature line does not match the format.
sub ReadInputFiles
{
    my ($fnFeature, $fnPred) = @_;
    my %hsQueryDocLabelScore;

    if(!open(FIN_Feature, $fnFeature))
    {
        print "Invalid command line.\n";
        print "Open \"$fnFeature\" failed.\n";
        exit -2;
    }
    if(!open(FIN_Pred, $fnPred))
    {
        print "Invalid command line.\n";
        print "Open \"$fnPred\" failed.\n";
        exit -2;
    }

    my $lineNum = 0;
    while(defined(my $lnFea = <FIN_Feature>))
    {
        $lineNum ++;
        chomp($lnFea);
        my $predScore = <FIN_Pred>;
        if (!defined($predScore))
        {
            print "Error to read $fnPred at line $lineNum.\n";
            exit -2;
        }
        chomp($predScore);
# modified by Jun Xu, 2008-9-9
# Labels may have more than 3 levels
# qid and docid may not be numeric
#        if ($lnFea =~ m/^([0-2]) qid\:(\d+).*?\#docid = (\d+)$/)

# modified by Jun Xu, March 3, 2009
# Letor 4.0's file format is different to Letor 3.0
#        if ($lnFea =~ m/^(\d+) qid\:([^\s]+).*?\#docid = ([^\s]+)$/)
        if ($lnFea =~ m/^(\d+) qid\:([^\s]+).*?\#docid = ([^\s]+) inc = ([^\s]+) prob = ([^\s]+).$/)
        {
            my $label = $1;
            my $qid = $2;
            my $did = $3;
            my $inc = $4;
            my $prob= $5;
            $hsQueryDocLabelScore{$qid}{$did}{"label"} = $label;
            $hsQueryDocLabelScore{$qid}{$did}{"inc"} = $inc;
            $hsQueryDocLabelScore{$qid}{$did}{"prob"} = $prob;
            $hsQueryDocLabelScore{$qid}{$did}{"pred"} = $predScore;
            # remembered so that documents with equal scores can be ordered
            # by their position in the input
            $hsQueryDocLabelScore{$qid}{$did}{"lineNum"} = $lineNum;
        }
        else
        {
            print "Error to parse $fnFeature at line $lineNum:\n$lnFea\n";
            exit -2;
        }
    }
    close(FIN_Feature);
    close(FIN_Pred);
    return %hsQueryDocLabelScore;
}
239 |
240 |
# Returns the list P@1 .. P@$topN for the ranked label list @rates.
sub PrecisionAtN
{
    my ($topN, @rates) = @_;
    my @PrecN;
    my $hits = 0;
    # modified by Jun Xu, 2009-4-24.
    # if # retrieved doc < $topN, the P@N will consider the hole as irrelevant
    foreach my $pos (0 .. $topN - 1)
    {
        my $label = ($pos <= $#rates) ? $rates[$pos] : 0;
        $hits ++ if ($hsPrecisionRel{$label} == 1);
        push(@PrecN, $hits / ($pos + 1));
    }
    return @PrecN;
}
266 |
# Returns the average precision of the ranked label list @rates: the mean of
# the precision values taken at every relevant position, or 0.0 when no
# label counts as relevant.
sub MAP
{
    my @rates = @_;

    my $hits = 0;
    my $precisionSum = 0.0;
    my $rank = 0;
    foreach my $label (@rates)
    {
        $rank ++;
        next unless ($hsPrecisionRel{$label} == 1);
        $hits ++;
        $precisionSum += ($hits / $rank);
    }
    #return sprintf("%.4f", $avgPrecision / $numRelevant);
    return ($hits == 0) ? 0.0 : $precisionSum / $hits;
}
285 |
# Returns the cumulative gain list DCG@1 .. DCG@$topN for the ranked label
# list @rates. The first two positions are added undiscounted; from the
# third position on the gain is multiplied by log(2)/log(position index + 1).
sub DCG
{
    my ($topN, @rates) = @_;
    my @dcg = ($hsNdcgRelScore{$rates[0]});

    # Modified by Jun Xu, 2009-4-24
    # if # retrieved doc < $topN, the NDCG@N will consider the hole as irrelevant
    foreach my $pos (1 .. $topN - 1)
    {
        my $label = ($pos <= $#rates) ? $rates[$pos] : 0;
        my $gain = $hsNdcgRelScore{$label};
        $gain = $gain * log(2.0) / log($pos + 1.0) unless ($pos < 2);
        push(@dcg, $dcg[$pos - 1] + $gain);
    }
    return @dcg;
}
# Returns NDCG@1 .. NDCG@min($topN, number of labels): the DCG of @rates
# divided position by position by the DCG of the same labels sorted by
# descending gain. A position whose ideal DCG is 0 yields 0.
sub NDCG
{
    my ($topN, @rates) = @_;
    my @dcg = DCG($topN, @rates);
    my @idealRates = sort {$hsNdcgRelScore{$b} <=> $hsNdcgRelScore{$a}} @rates;
    my @bestDcg = DCG($topN, @idealRates);

    my $numRates = $#rates + 1;
    my $limit = ($topN < $numRates) ? $topN : $numRates;
    my @ndcg = map { ($bestDcg[$_] != 0) ? $dcg[$_] / $bestDcg[$_] : 0 } (0 .. $limit - 1);
    return @ndcg;
}
--------------------------------------------------------------------------------
/pretrain/train.py:
--------------------------------------------------------------------------------
1 | '''
2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search
3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma.
4 | '''
5 | # encoding: utf-8
6 | import os
7 | import sys
8 | sys.path.insert(0, '../')
9 |
10 | from tqdm import tqdm
11 | import json
12 | import torch
13 | import numpy as np
14 | from datetime import timedelta, datetime
15 | from model.modeling import ARES, ICT
16 |
17 | from transformers import AdamW, get_linear_schedule_with_warmup
18 | from transformers import PretrainedConfig, BertConfig
19 | from torch import nn
20 | from torch.cuda.amp import autocast, GradScaler
21 |
22 | from dataloader import get_train_qd_loader, get_ict_loader
23 | from config import get_config
24 | import warnings
25 |
26 | warnings.filterwarnings("ignore")
27 | torch.backends.cudnn.benchmark = True
28 |
29 |
def train_epoch(model, scaler, qd_loader, optimizer, scheduler, device, config):
    """Run one pre-training epoch with mixed precision and gradient accumulation.

    Each batch is reshaped to ``(2 * batch_size, -1)``, moved to ``device``
    and passed through the model under ``autocast``; the model is expected to
    return the loss. Every ``config.gradient_accumulation_steps`` steps the
    gradients are unscaled, clipped to ``config.clip`` and applied. Every
    5000 steps (including step 0) the process with ``config.local_rank == 0``
    saves a checkpoint.

    Note: the checkpoint path uses the module-level ``save_dir`` that is set
    in the ``__main__`` section of this file.

    Args:
        model: network to train; may be wrapped (``.module`` is unwrapped
            before saving).
        scaler: ``torch.cuda.amp.GradScaler`` used for loss scaling.
        qd_loader: sized iterable of dicts with the keys ``token_ids``,
            ``attention_mask``, ``masked_lm_ids`` and, unless
            ``config.model_type == 'ICT'``, ``token_type_ids``.
        optimizer: optimizer stepped through ``scaler``.
        scheduler: learning-rate scheduler stepped after each optimizer step.
        device: torch device the tensors are moved to.
        config: run configuration (``model_type``,
            ``gradient_accumulation_steps``, ``clip``, ``print_every``,
            ``local_rank`` and ``model_name`` are read here).

    Returns:
        ``np.mean`` of the loss values recorded during the epoch.
    """
    model.train()
    losses = []

    num_instances = len(qd_loader)
    for step, batch_data in enumerate(tqdm(qd_loader, desc=f"Pretraining {config.model_type} progress", total=num_instances)):
        input_ids, attention_mask, masked_lm_ids = batch_data["token_ids"], batch_data["attention_mask"], batch_data["masked_lm_ids"]
        if config.model_type == 'ICT':
            token_type_ids = None
            input_ids, attention_mask, masked_lm_ids = input_ids.squeeze(), attention_mask.squeeze(), masked_lm_ids.squeeze()
            this_batch_size = input_ids.size()[0]
            # NOTE(review): batches with fewer than two rows are skipped --
            # presumably because ICT contrasts rows within a batch; confirm.
            if this_batch_size < 2:
                continue
        else:
            this_batch_size = input_ids.size()[0]
            token_type_ids = batch_data["token_type_ids"]

        input_ids = input_ids.reshape(this_batch_size * 2, -1)
        attention_mask = attention_mask.reshape(this_batch_size * 2, -1)
        masked_lm_ids = masked_lm_ids.reshape(this_batch_size * 2, -1)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.reshape(this_batch_size * 2, -1)

        input_ids = input_ids.to(device)  # bs x 512
        attention_mask = attention_mask.to(device)  # bs x 512
        masked_lm_ids = masked_lm_ids.to(device)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(device)

        with autocast():
            loss = model(
                input_ids=input_ids,
                config=config,
                input_mask=attention_mask,
                token_type_ids=token_type_ids,
                masked_lm_labels=masked_lm_ids,
                device=device
            )

        losses.append(loss.item())
        scaler.scale(loss).backward()

        # gradient accumulation
        if (step + 1) % config.gradient_accumulation_steps == 0:
            # The gradients still carry the loss scale at this point; they
            # must be unscaled first, otherwise max_norm is compared against
            # scaled values and the clipping threshold is meaningless.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip)
            scaler.step(optimizer)
            scaler.update()

            scheduler.step()
            optimizer.zero_grad()

        if step % int(config.print_every) == 0:
            print(f"\n[Train] Loss at step {step} = {loss.item()}, lr = {optimizer.param_groups[0]['lr']}")

        if step % 5000 == 0 and config.local_rank == 0:
            print('[SAVE] Saving model ... ')
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
            this_loss = round(float(np.mean(losses)), 4)
            torch.save(model_to_save.state_dict(), f"{save_dir}/{config.model_name}_{this_loss}_step{step}")
    return np.mean(losses)
89 |
90 |
91 | if __name__ == '__main__':
92 |
93 | # get configs
94 | config = get_config()
95 |
96 | # set save dir
97 | today = datetime.today().strftime('%Y-%m-%d')
98 | save_dir = f"{config.PRE_TRAINED_MODEL_NAME}/ckpt/{today}"
99 | if not os.path.exists(save_dir):
100 | os.mkdir(save_dir)
101 |
102 | np.random.seed(config.seed)
103 | torch.manual_seed(config.seed)
104 | config.local_rank = config.local_rank
105 | if torch.cuda.is_available():
106 | torch.cuda.manual_seed_all(config.seed)
107 | torch.cuda.set_device(config.local_rank)
108 | print('GPU is ON!')
109 | device = torch.device(f'cuda:{config.local_rank}')
110 | else:
111 | device = torch.device("cpu")
112 |
113 | # set max timeout=50 hours
114 | if config.distributed_train:
115 | torch.distributed.init_process_group(backend="nccl", timeout=timedelta(180000000), rank=config.local_rank, world_size=config.world_size)
116 | local_rank = config.local_rank
117 | if local_rank != -1:
118 | print("Using Distributed")
119 |
120 | # json files
121 | doc2query = {}
122 | with open(config.doc2query_dir) as f_doc2query:
123 | for line in f_doc2query:
124 | es = json.loads(line)
125 | docid = es["docid"]
126 | queries = es["queries"]
127 | if docid not in doc2query:
128 | doc2query[docid] = queries
129 |
130 | with open(config.gen_qid2id_dir) as f_gen_qid2id:
131 | gen_qid2id = json.load(f_gen_qid2id)
132 |
133 | # save memory
134 | if config.model_type == 'ARES':
135 | q_num = len(gen_qid2id)
136 |
137 | axiom_rank = np.memmap(f"{config.axiom_feature_dir}/memmap/rank.memmap", dtype='float', shape=(q_num, 1))
138 | axiom_list = []
139 | print(config.axiom)
140 | if 'PROX' in config.axiom:
141 | prox_1 = np.memmap(f"{config.axiom_feature_dir}/memmap/prox-1.memmap", dtype='float', shape=(q_num, 1))
142 | prox_2 = np.memmap(f"{config.axiom_feature_dir}/memmap/prox-2.memmap", dtype='float', shape=(q_num, 1))
143 | axiom_list.append(['PROX-1', prox_1])
144 | axiom_list.append(['PROX-2', prox_2])
145 |
146 | if 'REP' in config.axiom:
147 | rep_ql = np.memmap(f"{config.axiom_feature_dir}/memmap/rep-ql.memmap", dtype='float', shape=(q_num, 1))
148 | rep_tfidf = np.memmap(f"{config.axiom_feature_dir}/memmap/rep-tfidf.memmap", dtype='float', shape=(q_num, 1))
149 | axiom_list.append(['REP-QL', rep_ql])
150 | axiom_list.append(['REP-TFIDF', rep_tfidf])
151 |
152 | if 'REG' in config.axiom:
153 | reg = np.memmap(f"{config.axiom_feature_dir}/memmap/reg.memmap", dtype='float', shape=(q_num, 1))
154 | axiom_list.append(['REG', reg])
155 |
156 | if 'STM' in config.axiom:
157 | stm_1 = np.memmap(f"{config.axiom_feature_dir}/memmap/stm-1.memmap", dtype='float', shape=(q_num, 1))
158 | stm_2 = np.memmap(f"{config.axiom_feature_dir}/memmap/stm-2.memmap", dtype='float', shape=(q_num, 1))
159 | stm_3 = np.memmap(f"{config.axiom_feature_dir}/memmap/stm-3.memmap", dtype='float', shape=(q_num, 1))
160 |
161 | axiom_list.append(['STM-1', stm_1])
162 | axiom_list.append(['STM-2', stm_2])
163 | axiom_list.append(['STM-3', stm_3])
164 |
165 | axiom_list.append(['RANK', axiom_rank])
166 | gen_qs_size = len(gen_qid2id)
167 | gen_qs_tokens = np.memmap(config.gen_qs_memmap_dir, dtype='int32', shape=(gen_qs_size, 15))
168 |
169 | with open(config.docid2id_dir) as f_docid2id:
170 | docid2id = json.load(f_docid2id)
171 | collection_size = len(docid2id)
172 | doc_tokens = np.memmap(config.memmap_doc_dir, dtype='int32', shape=(collection_size, 512))
173 |
174 | print("Load data done!")
175 |
176 | cfg = PretrainedConfig.get_config_dict(config.PRE_TRAINED_MODEL_NAME)[0]
177 | if not config.gradient_checkpointing:
178 | del cfg["gradient_checkpointing"] # gradient checkpointing conflicts with parallel training
179 | del cfg["parameter_sharing"]
180 | cfg = BertConfig.from_dict(cfg)
181 |
182 | # train
183 | if not config.load_ckpt:
184 | if config.model_type == 'ICT':
185 | model = ICT.from_pretrained(config.PRE_TRAINED_MODEL_NAME, config=cfg)
186 | else:
187 | model = ARES.from_pretrained(config.PRE_TRAINED_MODEL_NAME, config=cfg)
188 | else:
189 | if config.model_type == 'ICT':
190 | model = ICT(config=cfg)
191 | else:
192 | model = ARES(config=cfg)
193 | model.load_state_dict({k.replace("module.", ""): v for k, v in torch.load(f"{config.PRE_TRAINED_MODEL_NAME}/ckpt/{config.model_path}",
194 | map_location={'cuda:0': f'cuda:{config.local_rank}'}).items()})
195 | model = model.to(device)
196 | print("Loading model...")
197 | model = model.cuda()
198 |
199 | if config.optim == 'adam':
200 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
201 | elif config.optim == 'amsgrad':
202 | optimizer = torch.optim.Amsgrad(model.parameters(), lr=config.lr)
203 | elif config.optim == 'adagrad':
204 | optimizer = torch.optim.Adagrad(model.parameters(), lr=config.lr)
205 | else: # adamw, weight decay not depend on the lr
206 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
207 | optimizer_grouped_parameters = [
208 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': config.weight_decay},
209 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
210 | ]
211 | optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr, eps=config.adam_epsilon)
212 |
213 | # train
214 | if config.distributed_train:
215 | model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], broadcast_buffers=False, find_unused_parameters=True)
216 | config.warm_up = config.warm_up / config.gpu_num
217 |
218 | for epoch in range(config.epochs):
219 | print(f'Epoch {epoch + 1}/{config.epochs}')
220 | print('-' * 10)
221 |
222 | print("========== Loading training data ==========")
223 | if config.model_type == 'ARES':
224 | train_qd_loader = get_train_qd_loader(doc_tokens, docid2id, config,
225 | doc2query=doc2query,
226 | gen_qs=gen_qs_tokens,
227 | gen_qid2id=gen_qid2id,
228 | axiom_feature=axiom_list) # b_sz * data samples
229 | else:
230 | train_qd_loader = get_ict_loader(doc_tokens, docid2id, config)
231 | print(f"train_batchs:{len(train_qd_loader)}, batch_size: {config.batch_size}")
232 |
233 | scaler = GradScaler(enabled=True)
234 | total_steps = len(train_qd_loader) * config.epochs
235 |
236 | scheduler = get_linear_schedule_with_warmup(
237 | optimizer,
238 | num_warmup_steps=int(total_steps * config.warm_up),
239 | num_training_steps=total_steps
240 | )
241 |
242 | train_loss = train_epoch(
243 | model,
244 | scaler,
245 | train_qd_loader,
246 | optimizer,
247 | scheduler,
248 | device,
249 | config,
250 | )
251 | scheduler.step()
252 | print(f'Train loss {train_loss}')
253 |
254 | if config.local_rank == 0:
255 | print('[SAVE] Saving model ... ')
256 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
257 | this_loss = round(float(train_loss), 4)
258 | torch.save(model_to_save.state_dict(), f"{save_dir}/{config.model_name}_{this_loss}")
259 |
260 |
261 |
262 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/model/modeling.py:
--------------------------------------------------------------------------------
1 | '''
2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search
3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma.
4 | '''
5 | # encoding: utf-8
6 | import sys
7 | import numpy as np
8 | import math
9 | import torch
10 | import torch.nn as nn
11 | from torch import Tensor
12 | from torch.nn import CrossEntropyLoss, MarginRankingLoss
13 | from torch.nn import Softmax
14 | from torch.cuda.amp import autocast
15 | from transformers import BertModel, BertPreTrainedModel
16 | from transformers import AutoTokenizer, AutoConfig, AutoModel
17 |
18 |
19 |
# Put the parent directory first on the module search path.
sys.path.insert(0, '../')
# Model name -> URL of the corresponding pre-trained ``.tar.gz`` archive.
# NOTE(review): nothing else in this module reads this mapping — confirm
# whether other modules import it before removing.
PRETRAINED_MODEL_ARCHIVE_MAP = {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
    'spanbert-base-cased': "https://dl.fbaipublicfiles.com/fairseq/models/spanbert_hf_base.tar.gz",
    'spanbert-large-cased': "https://dl.fbaipublicfiles.com/fairseq/models/spanbert_hf.tar.gz"
}
29 |
def batch_to_device(batch, target_device):
    """
    Move every tensor stored in ``batch`` onto ``target_device`` (CPU/GPU).

    Values that are not tensors are left untouched. The mapping is updated
    in place and the same object is returned.
    """
    for name, value in batch.items():
        if not isinstance(value, Tensor):
            continue
        # Re-assigning an existing key does not resize the dict, so this is
        # safe while iterating over it.
        batch[name] = value.to(target_device)
    return batch
38 |
39 |
class BertLayerNorm(nn.Module):
    """Layer normalisation in the TF style (epsilon inside the square root)."""

    def __init__(self, hidden_size, eps=1e-12):
        """Create a learnable scale (ones) and shift (zeros) of size ``hidden_size``."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalise ``x`` over its last dimension, then apply scale and shift."""
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias
54 |
55 |
def gelu(x):
    """Exact (erf-based) GELU activation.

    OpenAI GPT uses a slightly different tanh approximation, which gives
    slightly different results:
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    Also see https://arxiv.org/abs/1606.08415
    """
    gate = 1.0 + torch.erf(x / math.sqrt(2.0))
    return x * 0.5 * gate
63 |
64 |
def swish(x):
    """Swish activation: the input multiplied by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate
67 |
68 |
# Activation name (as it appears in ``config.hidden_act``) -> implementation.
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
70 |
71 |
class BertPredictionHeadTransform(nn.Module):
    """Dense layer, activation and layer norm applied before the LM decoder."""

    def __init__(self, config):
        """Build the transform from ``config.hidden_size`` and ``config.hidden_act``.

        ``config.hidden_act`` may be either a name registered in ``ACT2FN``
        or a callable that is used directly. An unknown name raises
        ``KeyError``.
        """
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        hidden_act = config.hidden_act
        # The former Python 2 ``unicode`` check was dead code on Python 3
        # (and referenced an undefined name), so only ``str`` is tested.
        if isinstance(hidden_act, str):
            self.transform_act_fn = ACT2FN[hidden_act]
        else:
            self.transform_act_fn = hidden_act
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states):
        """Return ``LayerNorm(act(dense(hidden_states)))``."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        return self.LayerNorm(hidden_states)
87 |
88 |
class BertLMPredictionHead(nn.Module):
    """Masked-LM head whose decoder weights are the given embedding matrix."""

    def __init__(self, config, bert_model_embedding_weights):
        """Create the transform, the tied decoder and a per-token output bias."""
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        vocab_size = bert_model_embedding_weights.size(0)
        embedding_dim = bert_model_embedding_weights.size(1)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(embedding_dim, vocab_size, bias=False)
        self.decoder.weight = bert_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, hidden_states):
        """Transform the hidden states and project them onto the vocabulary."""
        transformed = self.transform(hidden_states)
        return self.decoder(transformed) + self.bias
106 |
107 |
108 | # TransformerICT
class ICT(BertPreTrainedModel):
    """BERT encoder with a scoring head, an MLM head and an in-batch contrastive loss."""

    def __init__(self, config):
        """Build the encoder, the scalar scoring head and the tied MLM head."""
        super(ICT, self).__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.cls = nn.Linear(config.hidden_size, 1)
        # The MLM head is attached as a sub-module of ``cls`` so that its
        # parameters live under the ``cls.predictions.*`` prefix.
        self.cls.predictions = BertLMPredictionHead(config, self.bert.embeddings.word_embeddings.weight)
        self.config = config

        self.init_weights()

    @autocast()
    def forward(self, input_ids, config, input_mask, token_type_ids=None, masked_lm_labels=None, device=None):
        """Return the training loss, or relevance scores when no labels are given.

        Args:
            input_ids: token ids; consecutive rows (2i, 2i + 1) form one pair
                when training, so the batch size must then be even.
            config: run-time options; only ``config.MLM`` is read, and only
                when ``masked_lm_labels`` is given.
            input_mask: attention mask passed to the encoder.
            token_type_ids: accepted for interface compatibility; it is not
                forwarded to the encoder.
            masked_lm_labels: MLM targets (``-1`` entries are ignored). When
                ``None`` the model runs in scoring mode.
            device: kept for backward compatibility; the contrastive targets
                are now created on the same device as the logits.

        Returns:
            ``mlm_loss + ict_loss`` when ``masked_lm_labels`` is given,
            otherwise the output of the scoring head on the pooled output.
        """
        batch_size = input_ids.size(0)
        outputs = self.bert(input_ids,
                            attention_mask=input_mask,
                            return_dict=False
                            )

        sequence_output, pooled_output = outputs[0], outputs[1]

        if masked_lm_labels is None:
            return self.cls(self.dropout(pooled_output))

        # MLM loss -- the vocabulary projection is only computed when it is used.
        if config.MLM:
            lm_prediction_scores = self.cls.predictions(sequence_output)
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            mlm_loss = loss_fct(lm_prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
        else:
            mlm_loss = 0.

        # ICT loss
        num_pairs = batch_size // 2
        paired = pooled_output.reshape(num_pairs, 2, self.config.hidden_size)
        s_encode = paired[:, 0, :]  # bs/2 x h
        c_encode = paired[:, 1, :]  # bs/2 x h

        # Similarity of every first element with every second element; the
        # matching one sits on the diagonal.
        logit = torch.matmul(s_encode, c_encode.transpose(-2, -1))  # bs/2 x bs/2
        target = torch.arange(num_pairs, dtype=torch.long, device=logit.device)
        ict_loss = nn.CrossEntropyLoss()(logit, target)

        return mlm_loss + ict_loss
153 |
154 |
class ARES(BertPreTrainedModel):
    """BERT cross-encoder with a scalar relevance head and an MLM head."""

    def __init__(self, config):
        """Build the encoder, dropout, scoring head, sigmoid and tied MLM head."""
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.cls = nn.Linear(config.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()
        # Registered under ``cls`` so its parameters are named ``cls.predictions.*``.
        self.cls.predictions = BertLMPredictionHead(config, self.bert.embeddings.word_embeddings.weight)
        self.config = config

        self.init_weights()

    @autocast()
    def forward(self, input_ids, config, input_mask, token_type_ids, masked_lm_labels=None, device=None):
        """Score the inputs, or return the pre-training loss when labels are given.

        Without ``masked_lm_labels`` the raw output of the scoring head is
        returned. With labels, consecutive rows (2i, 2i + 1) are treated as a
        (preferred, other) pair and the result is the MLM loss (``0.`` when
        ``config.MLM`` is falsy) plus a margin ranking loss over the
        pair-wise softmax of the scores. ``device`` is not used.
        """
        n_rows = input_ids.size(0)
        encoded = self.bert(input_ids,
                            attention_mask=input_mask,
                            token_type_ids=token_type_ids,
                            return_dict=False
                            )

        sequence_output, pooled_output = encoded[0], encoded[1]
        prediction_scores = self.cls(self.dropout(pooled_output))

        if masked_lm_labels is None:
            return prediction_scores

        # MLM loss
        lm_prediction_scores = self.cls.predictions(sequence_output)
        mlm_criterion = CrossEntropyLoss(ignore_index=-1)
        if config.MLM:
            mlm_loss = mlm_criterion(lm_prediction_scores.view(-1, self.config.vocab_size),
                                     masked_lm_labels.view(-1))
        else:
            mlm_loss = 0.

        # Pairwise loss
        pair_probs = Softmax(dim=1)(prediction_scores.reshape(n_rows // 2, 2))
        pos_logits, neg_logits = pair_probs[:, 0], pair_probs[:, 1]
        ranking_criterion = MarginRankingLoss(margin=1.0, reduction='mean')
        rep_loss = ranking_criterion(pos_logits, neg_logits, torch.ones_like(pos_logits))

        return mlm_loss + rep_loss
201 |
202 |
class ARESReranker(ARES):
    """ARES model bundled with a tokenizer for scoring raw (query, document) text.

    NOTE: none of the methods switch the module to evaluation mode; call
    ``.eval()`` on the instance before scoring if dropout must be disabled.
    """

    def __init__(self, config, max_input_length=512):
        """Build the model and load the ``bert-base-uncased`` tokenizer from local files.

        Args:
            config: model configuration passed to ``ARES``.
            max_input_length: maximum number of tokens per encoded pair,
                including the three special tokens.
        """
        super().__init__(config)
        self.tokenizer = AutoTokenizer.from_pretrained(
            "bert-base-uncased", config=config, local_files_only=True)
        self.max_input_length = max_input_length

    def tokenize(self, qd_pairs):
        """Encode (query, doc) string pairs as ``[CLS] query [SEP] doc [SEP]``.

        Queries are cut to at most 32 tokens and documents to whatever room
        is left within ``self.max_input_length``. Rows are right-padded with
        zeros to the longest row of the batch.

        Returns:
            dict with int64 tensors ``input_ids``, ``token_type_ids`` and
            ``input_mask``, each of shape (len(qd_pairs), longest row).
        """
        cls_id, sep_id = 101, 102  # [CLS] / [SEP] ids of the bert-base-uncased vocabulary
        num_special = 3            # one [CLS] and two [SEP]
        budget = max(self.max_input_length - num_special, 0)
        query_max_len = min(32, budget)
        doc_max_len = budget - query_max_len

        rows = []
        for query, doc in qd_pairs:
            query_tokens = self.tokenizer.tokenize(query)
            query_input_ids = self.tokenizer.convert_tokens_to_ids(query_tokens)[:query_max_len]

            doc_tokens = self.tokenizer.tokenize(doc)
            doc_input_ids = self.tokenizer.convert_tokens_to_ids(doc_tokens)[:doc_max_len]

            input_ids = [cls_id] + query_input_ids + [sep_id] + doc_input_ids + [sep_id]
            token_type_ids = [0] * (len(query_input_ids) + 2) + [1] * (len(doc_input_ids) + 1)
            # Id 0 is the padding id, so anything above it is a real token.
            attention_mask = [int(token_id > 0) for token_id in input_ids]
            rows.append((input_ids, token_type_ids, attention_mask))

        # padding to same length
        max_len = max((len(ids) for ids, _, _ in rows), default=0)

        def stack(column):
            padded = [row[column] + [0] * (max_len - len(row[column])) for row in rows]
            # reshape keeps a well-formed (0, 0) tensor for an empty batch
            return torch.tensor(padded, dtype=torch.long).reshape(len(rows), max_len)

        return {
            "input_ids": stack(0),
            "token_type_ids": stack(1),
            "input_mask": stack(2)
        }

    def score(self, qd_pairs):
        """Return a flat numpy array with one model score per (query, doc) pair.

        An empty ``qd_pairs`` yields an empty array without running the model.
        """
        if len(qd_pairs) == 0:
            return np.empty(0, dtype=np.float32)
        features = self.tokenize(qd_pairs)
        batch_to_device(features, self.device)
        with torch.cuda.amp.autocast():
            with torch.no_grad():
                scores = self.forward(config=None, **features)
        return scores.cpu().numpy().reshape(-1)

    def rerank_query(self, query, docs):
        """Score every document in ``docs`` against ``query``, 100 pairs at a time.

        Returns:
            list of float scores in the order of ``docs``; ``[]`` when
            ``docs`` is empty.
        """
        batch_size = 100

        qd_pairs = [(query, doc) for doc in docs]
        if not qd_pairs:
            # np.concatenate rejects an empty sequence of arrays
            return []

        scores = [self.score(qd_pairs[start: start + batch_size])
                  for start in range(0, len(qd_pairs), batch_size)]
        return np.concatenate(scores, axis=0).reshape(-1).tolist()

    def rerank(self, queries, docs_topk):
        """Return, for each query, the list of scores of its candidate documents.

        ``queries`` and ``docs_topk`` must have the same length; entry ``i``
        of ``docs_topk`` holds the candidate documents of query ``i``.
        """
        assert len(queries) == len(docs_topk)
        return [self.rerank_query(query, docs) for query, docs in zip(queries, docs_topk)]
279 |
--------------------------------------------------------------------------------
/finetune/train.py:
--------------------------------------------------------------------------------
1 | '''
2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search
3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma.
4 | '''
5 | # encoding: utf-8
6 | import os
7 | import sys
8 | sys.path.insert(0, '../')
9 |
10 | from tqdm import tqdm
11 | import json
12 | import torch
13 | import numpy as np
14 | import pandas as pd
15 | from datetime import timedelta
16 |
17 |
18 | from transformers import AutoModel, AutoTokenizer, AdamW, get_linear_schedule_with_warmup
19 | from transformers import PretrainedConfig, BertConfig
20 | from torch import nn, optim
21 | from torch.cuda.amp import autocast, GradScaler
22 | from model.modeling import ARES, ICT
23 |
24 | from dataloader import get_train_qd_loader, get_test_qd_loader
25 | from config import get_config
26 | from ms_marco_eval import compute_metrics_from_files
27 | import warnings
28 |
# Suppress every Python warning for the whole process.
warnings.filterwarnings("ignore")
# Enable the cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True
31 |
32 |
def train_epoch(model, scaler, qd_loader, optimizer, scheduler, device, config):
    """Run one fine-tuning epoch with the pairwise margin ranking loss.

    Args:
        model: module called with ``input_ids``, ``config``, ``input_mask``
            and ``token_type_ids``; it must return one score per input row.
        scaler: ``GradScaler`` used for mixed-precision loss scaling.
        qd_loader: yields dicts with ``token_ids``, ``attention_mask`` and
            ``token_type_ids`` tensors; each item is flattened into two rows
            (assumed to be a positive followed by a negative — TODO confirm
            against the data loader).
        optimizer: optimizer stepped every ``gradient_accumulation_steps`` batches.
        scheduler: learning-rate scheduler stepped together with the optimizer.
        device: device the batch tensors are moved to.
        config: reads ``model_name``, ``gradient_accumulation_steps``,
            ``clip`` and ``print_every``.

    Returns:
        Mean of the recorded per-batch losses. The recorded values are
        already divided by ``gradient_accumulation_steps``.
    """
    model.train()
    losses = []

    # Stateless loss modules, created once instead of on every step.
    softmax = nn.Softmax(dim=1)
    marginloss = nn.MarginRankingLoss(margin=1.0, reduction='mean')

    num_batches = len(qd_loader)
    model_name = config.model_name
    progress = tqdm(qd_loader, desc=f"Fine-tuning {model_name} progress", total=num_batches)
    for step, batch_data in enumerate(progress):
        input_ids = batch_data["token_ids"]
        attention_mask = batch_data["attention_mask"]
        token_type_ids = batch_data["token_type_ids"]
        this_batch_size = input_ids.size()[0]

        # n x 2 x seq_len ==> 2n x seq_len, then move to the target device
        input_ids = input_ids.reshape(this_batch_size * 2, -1).to(device)
        attention_mask = attention_mask.reshape(this_batch_size * 2, -1).to(device)
        token_type_ids = token_type_ids.reshape(this_batch_size * 2, -1).to(device)

        with autocast():
            output = model(
                input_ids=input_ids,
                config=config,
                input_mask=attention_mask,
                token_type_ids=token_type_ids,
            )  # 2n x 1

        batch_size = output.size(0)
        logits = softmax(output.reshape(batch_size // 2, 2))
        pos_logits = logits[:, 0]
        neg_logits = logits[:, 1]
        rop_label = torch.ones_like(pos_logits)
        loss = marginloss(pos_logits, neg_logits, rop_label)

        loss = loss / config.gradient_accumulation_steps
        losses.append(loss.item())
        scaler.scale(loss).backward()

        # gradient accumulation
        if (step + 1) % config.gradient_accumulation_steps == 0:
            # The accumulated gradients are still multiplied by the loss
            # scale; unscale them first so that ``config.clip`` is compared
            # against the true gradient norm.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip)
            scaler.step(optimizer)
            scaler.update()

            scheduler.step()
            optimizer.zero_grad()

        if step % int(config.print_every) == 0:
            print(f"\n[Train] Loss at step {step} = {loss.item()}, lr = {optimizer.state_dict()['param_groups'][0]['lr']}")
    return np.mean(losses)
86 |
87 |
def eval_model(model, qd_loader, device, config):
    """Score every batch of ``qd_loader`` and return the per-batch rankings.

    Args:
        model: module called with ``input_ids``, ``config``, ``input_mask``
            and ``token_type_ids``; it must return one score per input row.
        qd_loader: yields dicts with ``token_ids``, ``attention_mask``,
            ``token_type_ids``, ``q_id`` and ``d_id``.
        device: device the batch tensors are moved to.
        config: forwarded to the model unchanged.

    Returns:
        DataFrame with columns ``q_id``, ``d_id``, ``rank`` and ``score``.
        ``rank`` is the 1-based position of the row inside its own batch
        after sorting by descending score (presumably each batch holds the
        candidates of a single query — verify against the data loader).
    """
    model.eval()
    df_rank = pd.DataFrame(columns=['q_id', 'd_id', 'rank', 'score'])
    q_id_list, d_id_list, rank, score = [], [], [], []

    num_batches = len(qd_loader)
    with torch.no_grad():
        for batch_data in tqdm(qd_loader, desc="Evaluating progress", total=num_batches):
            input_ids = batch_data["token_ids"].to(device)  # bs x 512
            attention_mask = batch_data["attention_mask"].to(device)  # bs x 512
            token_type_ids = batch_data["token_type_ids"].to(device)

            output = model(
                input_ids=input_ids,
                config=config,
                input_mask=attention_mask,
                token_type_ids=token_type_ids,
            )  # 100 x 1

            # reshape(-1) instead of squeeze(): a single-row batch must stay
            # 1-D, otherwise tolist() returns a bare float and zip() fails.
            scores = output.reshape(-1).cpu().tolist()
            q_ids = batch_data["q_id"]
            d_ids = batch_data["d_id"]
            # Sort by score to inspect the score distribution of the top 100.
            sorted_tuples = sorted(zip(q_ids, d_ids, scores), key=lambda x: x[2], reverse=True)
            for position, (q_id, d_id, this_score) in enumerate(sorted_tuples, start=1):
                q_id_list.append(q_id)
                d_id_list.append(d_id)
                rank.append(position)
                score.append(this_score)

    df_rank['q_id'] = q_id_list
    df_rank['d_id'] = d_id_list
    df_rank['rank'] = rank
    df_rank['score'] = score
    return df_rank
127 |
128 |
def _load_query_token_ids(path):
    """Read a JSON-lines file into a ``{id: ids}`` dict.

    Each line is parsed as a JSON object and its ``"id"`` / ``"ids"`` fields are
    read. When an id appears on several lines, the first occurrence is kept.
    """
    mapping = {}
    with open(path) as f_in:
        for line in f_in:
            es = json.loads(line)
            mapping.setdefault(es["id"], es["ids"])
    return mapping


def _to_trec_run(qd_rank):
    """Copy a ranking frame into the six-column layout written for ``trec_eval``."""
    df_rank = pd.DataFrame(columns=['q_id', 'Q0', 'd_id', 'rank', 'score', 'standard'])
    df_rank['q_id'] = qd_rank['q_id']
    df_rank['Q0'] = ['Q0'] * len(qd_rank['q_id'])
    df_rank['d_id'] = qd_rank['d_id']
    df_rank['rank'] = qd_rank['rank']
    df_rank['score'] = qd_rank['score']
    df_rank['standard'] = ['STANDARD'] * len(qd_rank['q_id'])
    return df_rank


def _trec_eval_ndcg(qrels_path, run_path):
    """Run ``trec_eval -m ndcg_cut.10,100`` and return ``(nDCG@10, nDCG@100)``.

    Raises:
        FileNotFoundError: if the ``trec_eval`` executable is not on ``PATH``.
        subprocess.CalledProcessError: if ``trec_eval`` exits with an error.
        RuntimeError: if fewer than two result lines are printed.
    """
    # Local import so this script does not depend on a new file-level import.
    import subprocess

    # An argument list (no shell) keeps paths with spaces intact and makes a
    # failing trec_eval surface as an explicit error instead of an IndexError
    # while parsing empty output.
    completed = subprocess.run(
        ['trec_eval', '-m', 'ndcg_cut.10,100', str(qrels_path), str(run_path)],
        stdout=subprocess.PIPE,
        universal_newlines=True,
        check=True,
    )
    result_lines = completed.stdout.strip().split("\n")
    if len(result_lines) < 2:
        raise RuntimeError(f"unexpected trec_eval output: {completed.stdout!r}")
    return float(result_lines[0].strip().split()[-1]), float(result_lines[1].strip().split()[-1])


# Script entry point: builds the model, then either trains it (evaluating on
# DL 2019 and MS MARCO dev after every epoch on rank 0) or, with ``config.test``
# set, only runs the two evaluations.
if __name__ == '__main__':
    config = get_config()

    # Create the checkpoint directory. ``exist_ok`` avoids the race between an
    # existence check and the creation when several ranks start together.
    save_dir = f"{config.PRE_TRAINED_MODEL_NAME}/ckpt"
    os.makedirs(save_dir, exist_ok=True)
    save_model_path = f"{config.PRE_TRAINED_MODEL_NAME}/ckpt/model_state"

    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(config.seed)
        torch.cuda.set_device(config.local_rank)
        print('GPU is ON!')
        device = torch.device(f'cuda:{config.local_rank}')
    else:
        device = torch.device("cpu")

    # distributed training
    local_rank = config.local_rank
    if config.distributed_train and not config.test:
        # The first positional argument of timedelta is days, spelled out here;
        # the value is large enough to act as "no timeout".
        torch.distributed.init_process_group(backend="nccl", timeout=timedelta(days=180000000))
        if local_rank != -1:
            print("Using Distributed")

    # Train Data Loader (evaluation files are only read on rank 0)
    df_train_qds = pd.read_csv(config.train_qd_dir, sep=' ', header=None)
    if config.local_rank == 0:
        df_test_qds = pd.read_csv(config.test_qd_dir, sep=' ', header=None)
        df_dl2019_qds = pd.read_csv(config.dl2019_qd_dir, sep=' ', header=None)

    best_nDCG_dl2019, best_MRR_test = 0., 0.
    train_top100 = pd.read_csv(config.train100_dir, sep='\t', header=None)
    if config.local_rank == 0:
        test_top100 = pd.read_csv(config.test100_dir, sep='\t', header=None)
        dl2019_top100 = pd.read_csv(config.dl100_dir, sep='\t', header=None)

    # json files
    train_qs, test_qs, dl2019_qs, doc2query = {}, {}, {}, {}
    train_qs = _load_query_token_ids(config.train_qs_dir)
    if config.local_rank == 0:
        test_qs = _load_query_token_ids(config.test_qs_dir)
        dl2019_qs = _load_query_token_ids(config.dl2019_qs_dir)

    with open(config.docid2id_dir) as f_docid2id:
        docid2id = json.load(f_docid2id)
    print("Load dicts done!")

    collection_size = len(docid2id)
    doc_tokens = np.memmap(config.memmap_doc_dir, dtype='int32', shape=(collection_size, 512))

    cfg = PretrainedConfig.get_config_dict(config.PRE_TRAINED_MODEL_NAME)[0]
    if not config.gradient_checkpointing:
        # pop() instead of del: the key may be missing from the stored config.
        cfg.pop("gradient_checkpointing", None)
    cfg = BertConfig.from_dict(cfg)

    if not config.load_ckpt:  # train
        if config.model_type == 'ICT':
            model = ICT.from_pretrained(config.PRE_TRAINED_MODEL_NAME, config=cfg)
        else:
            model = ARES.from_pretrained(config.PRE_TRAINED_MODEL_NAME, config=cfg)
    else:  # test
        if config.model_type == 'ARES':
            model = ARES(config=cfg)
        elif config.model_type == 'PROP':
            model = PROP(config=cfg)
        else:
            model = ICT(config=cfg)
        # Load straight onto the selected device so checkpoints saved from any
        # GPU index also load on another GPU or on a CPU-only machine.
        state_dict = torch.load(f"{config.PRE_TRAINED_MODEL_NAME}/ckpt/{config.model_path}", map_location=device)
        # Checkpoints saved from a DataParallel/DDP wrapper prefix keys with "module.".
        model.load_state_dict({k.replace("module.", ""): v for k, v in state_dict.items()})

    # ``device`` already is cuda:<local_rank> when CUDA is available, so an
    # extra unconditional ``.cuda()`` is unnecessary and would crash on CPU.
    model = model.to(device)
    print("Loading model...")

    # Gradient scaling only applies to CUDA mixed precision.
    scaler = GradScaler(enabled=device.type == 'cuda')

    if config.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    elif config.optim == 'amsgrad':
        # torch.optim has no ``Amsgrad`` class; AMSGrad is a flag of Adam.
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, amsgrad=True)
    elif config.optim == 'adagrad':
        optimizer = torch.optim.Adagrad(model.parameters(), lr=config.lr)
    else:  # adamw, weight decay not depend on the lr
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': config.weight_decay},
            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr, eps=config.adam_epsilon)

    if not config.test:  # train
        if config.distributed_train:
            model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True, broadcast_buffers=False)
            config.warm_up = config.warm_up / config.gpu_num

        if config.local_rank == 0:
            print("\n========== Loading dev data ==========")
            test_qd_loader = get_test_qd_loader(test_top100, test_qs, doc_tokens, docid2id, config)
            print(f"test_q: {len(test_qs)}, test_q_batchs:{len(test_qd_loader)}")

            print("\n========== Loading DL 2019 data ==========")
            dl2019_qd_loader = get_test_qd_loader(dl2019_top100, dl2019_qs, doc_tokens, docid2id, config)
            print(f"dl2019_q: {len(dl2019_qs)}, dl2019_q_batchs:{len(dl2019_qd_loader)}")

        for epoch in range(config.epochs):
            print(f'Epoch {epoch + 1}/{config.epochs}')
            print('-' * 10)

            print("========== Loading training data ==========")
            train_qd_loader = get_train_qd_loader(df_train_qds, train_top100, train_qs, doc_tokens, docid2id, config, mode='train')  # b_sz * data samples
            print(f"train_qd_pairs: {len(df_train_qds)}, train_batchs:{len(train_qd_loader)}, batch_size: {config.batch_size}")

            # A fresh warm-up/linear-decay schedule is built for every epoch.
            total_steps = len(train_qd_loader)
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=int(total_steps * config.warm_up),
                num_training_steps=total_steps
            )

            train_loss = train_epoch(
                model,
                scaler,
                train_qd_loader,
                optimizer,
                scheduler,
                device,
                config,
            )
            # NOTE(review): train_epoch is not visible from here; if it already
            # steps the scheduler per batch this extra step is redundant - confirm.
            scheduler.step()
            print(f'Train loss {train_loss}')

            if config.local_rank == 0:
                # TREC DL 2019: nDCG via the external trec_eval binary
                qd_rank = eval_model(
                    model,
                    dl2019_qd_loader,
                    device,
                    config,
                )
                df_rank = _to_trec_run(qd_rank)
                df_rank.to_csv(f"{save_dir}/dl2019_qd_rank.tsv", sep=' ', index=False, header=False)
                ndcg_10, ndcg_100 = _trec_eval_ndcg(config.dl2019_qd_dir, f"{save_dir}/dl2019_qd_rank.tsv")
                metrics = {'nDCG @10': ndcg_10, 'nDCG @100': ndcg_100, 'QueriesRanked': len(set(qd_rank['q_id']))}

                print('\n#############################')
                print('<--------- DL 2019 --------->')
                for metric in sorted(metrics):
                    print('{}: {}'.format(metric, metrics[metric]))
                print('#############################\n')
                nDCG_dl2019 = round(metrics['nDCG @10'], 4)
                nDCG_dl2019_100 = round(metrics['nDCG @100'], 4)
                if nDCG_dl2019 > best_nDCG_dl2019:
                    best_nDCG_dl2019 = nDCG_dl2019
                    qd_rank.to_csv(f"{save_dir}/best_{config.model_type}_dl2019_qd_rank.tsv", sep='\t', index=False,
                                   header=False)

                # test msmarco dev
                qd_rank = eval_model(
                    model,
                    test_qd_loader,
                    device,
                    config,
                )
                qd_rank.to_csv(f"{save_dir}/test_qd_rank.tsv", sep='\t', index=False, header=False)
                metrics = compute_metrics_from_files(config.test_qd_dir, f"{save_dir}/test_qd_rank.tsv")
                print('\n#####################')
                print('<----- MS Dev ----->')
                for metric in sorted(metrics):
                    print('{}: {}'.format(metric, metrics[metric]))
                print('#####################\n')
                MRR_test = round(metrics['MRR @10'], 4)
                MRR_test_100 = round(metrics['MRR @100'], 4)
                if MRR_test > best_MRR_test:
                    best_MRR_test = MRR_test
                    qd_rank.to_csv(f"{save_dir}/best_{config.model_type}_test_qd_rank.tsv", sep='\t', index=False, header=False)

                # A checkpoint is written after every epoch; its name carries
                # the dev MRR values and the epoch number.
                print('[SAVE] Saving model ... ')
                model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
                torch.save(model_to_save.state_dict(), f"{save_dir}/{config.model_type}_{MRR_test}_{MRR_test_100}_e{epoch + 1}")

    else:  # test
        print("\n========== Loading dev data ==========")
        test_qd_loader = get_test_qd_loader(test_top100, test_qs, doc_tokens, docid2id, config)
        print(f"test_q: {len(test_qs)}, test_q_batchs:{len(test_qd_loader)}")

        print("\n========== Loading DL 2019 data ==========")
        dl2019_qd_loader = get_test_qd_loader(dl2019_top100, dl2019_qs, doc_tokens, docid2id, config)
        print(f"dl2019_q: {len(dl2019_qs)}, dl2019_q_batchs:{len(dl2019_qd_loader)}")

        qd_rank = eval_model(
            model,
            dl2019_qd_loader,
            device,
            config,
        )
        df_rank = _to_trec_run(qd_rank)
        df_rank.to_csv(f"{save_dir}/dl2019_qd_rank_as100.tsv", sep=' ', index=False, header=False)
        ndcg_10, ndcg_100 = _trec_eval_ndcg(config.dl2019_qd_dir, f"{save_dir}/dl2019_qd_rank_as100.tsv")
        metrics = {'nDCG @10': ndcg_10, 'nDCG @100': ndcg_100, 'QueriesRanked': len(set(qd_rank['q_id']))}
        print('\n#############################')
        print('<--------- DL 2019 --------->')
        for metric in sorted(metrics):
            print('{}: {}'.format(metric, metrics[metric]))
        print('#############################\n')

        # test msmarco dev
        qd_rank = eval_model(
            model,
            test_qd_loader,
            device,
            config,
        )
        qd_rank.to_csv(f"{save_dir}/test_qd_rank_as100.tsv", sep='\t', index=False, header=False)
        metrics = compute_metrics_from_files(config.test_qd_dir, f"{save_dir}/test_qd_rank_as100.tsv")
        print('\n#####################')
        print('<----- MS Dev ----->')
        for metric in sorted(metrics):
            print('{}: {}'.format(metric, metrics[metric]))
        print('#####################\n')
380 |
--------------------------------------------------------------------------------
/pretrain/dataloader.py:
--------------------------------------------------------------------------------
1 | '''
2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search
3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma.
4 | '''
5 | # encoding: utf-8
6 | import random
7 | import numpy as np
8 | import collections
9 | import pandas as pd
10 | from tqdm import tqdm
11 | import xgboost as xgb
12 | from scipy.special import *
13 | from torch.utils.data import Dataset, DataLoader
14 | from torch.utils.data.distributed import DistributedSampler
15 | from transformers import BertTokenizer
16 |
17 |
MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"])


def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, whole_word_mask, vocab_list, id2token):
    """Mask tokens that come after the first [SEP] for the masked-LM objective.

    Only positions following the first token id 102 ([SEP]) are candidates;
    ids 101 ([CLS]) and 102 themselves are never candidates. With
    ``whole_word_mask`` a token whose text in ``id2token`` starts with ``"##"``
    is grouped with the preceding candidate, and a group is masked as a whole
    or not at all.

    Args:
        tokens: mutable sequence of token ids; it is modified in place.
        masked_lm_prob: fraction of candidate groups to mask (at least one).
        max_predictions_per_seq: upper bound on the number of masked positions.
        whole_word_mask: group ``##`` continuation pieces with their word.
        vocab_list: token ids to draw random replacements from.
        id2token: mapping from token id to token text.

    Returns:
        ``(tokens, masked_token_labels, mask_indices)``: the same ``tokens``
        object after masking, the original ids of the masked positions, and
        those positions in ascending order.
    """
    word_groups = []
    in_document = False
    for position, token_id in enumerate(tokens):  # token_ids
        if token_id == 102:  # SEP: everything after the first one is document text
            in_document = True
        elif token_id != 101 and in_document:  # 101 is CLS
            continues_word = (
                whole_word_mask
                and len(word_groups) >= 1
                and id2token[token_id].startswith("##")
            )
            if continues_word:
                word_groups[-1].append(position)
            else:
                word_groups.append([position])

    budget = min(max_predictions_per_seq, max(1, int(round(len(word_groups) * masked_lm_prob))))
    random.shuffle(word_groups)

    chosen = []
    used_positions = set()
    for group in word_groups:
        if len(chosen) >= budget:
            break
        # Skip a group that would overshoot the budget or that overlaps
        # positions which were already masked.
        if len(chosen) + len(group) > budget or not used_positions.isdisjoint(group):
            continue
        for position in group:
            used_positions.add(position)
            if random.random() < 0.8:
                # 80% of the time, replace with [MASK]
                replacement = 103
            elif random.random() < 0.5:
                # 10% of the time, keep original
                replacement = tokens[position]
            else:
                # 10% of the time, replace with random word
                replacement = random.choice(vocab_list)
            chosen.append(MaskedLmInstance(index=position, label=tokens[position]))
            tokens[position] = replacement

    assert len(chosen) <= budget
    chosen.sort(key=lambda instance: instance.index)
    mask_indices = [instance.index for instance in chosen]
    masked_token_labels = [instance.label for instance in chosen]

    return tokens, masked_token_labels, mask_indices
79 |
80 |
def create_masked_lm_predictions_ict(tokens, masked_lm_prob, max_predictions_per_seq, whole_word_mask, vocab_list, id2token):
    """Mask tokens of a single sequence for the masked-LM objective (ICT inputs).

    Every position of ``tokens`` is a masking candidate. With
    ``whole_word_mask`` a token whose text in ``id2token`` starts with ``"##"``
    is grouped with the preceding candidate, and a group is masked as a whole
    or not at all.

    NOTE(review): no token id is excluded from the candidate pool, so a
    leading [CLS] or zero padding in ``tokens`` can be selected for masking -
    confirm this is intended.

    Args:
        tokens: mutable sequence of token ids; it is modified in place.
        masked_lm_prob: fraction of candidate groups to mask (at least one).
        max_predictions_per_seq: upper bound on the number of masked positions.
        whole_word_mask: group ``##`` continuation pieces with their word.
        vocab_list: token ids to draw random replacements from.
        id2token: mapping from token id to token text.

    Returns:
        ``(tokens, masked_token_labels, mask_indices)``: the same ``tokens``
        object after masking, the original ids of the masked positions, and
        those positions in ascending order.
    """
    word_groups = []
    for position, token_id in enumerate(tokens):  # token_ids
        continues_word = (
            whole_word_mask
            and len(word_groups) >= 1
            and id2token[token_id].startswith("##")
        )
        if continues_word:
            word_groups[-1].append(position)
        else:
            word_groups.append([position])

    budget = min(max_predictions_per_seq, max(1, int(round(len(word_groups) * masked_lm_prob))))
    random.shuffle(word_groups)

    chosen = []
    used_positions = set()
    for group in word_groups:
        if len(chosen) >= budget:
            break
        # Skip a group that would overshoot the budget or that overlaps
        # positions which were already masked.
        if len(chosen) + len(group) > budget or not used_positions.isdisjoint(group):
            continue
        for position in group:
            used_positions.add(position)
            if random.random() < 0.8:
                # 80% of the time, replace with [MASK]
                replacement = 103
            elif random.random() < 0.5:
                # 10% of the time, keep original
                replacement = tokens[position]
            else:
                # 10% of the time, replace with random word
                replacement = random.choice(vocab_list)
            chosen.append(MaskedLmInstance(index=position, label=tokens[position]))
            tokens[position] = replacement

    assert len(chosen) <= budget
    chosen.sort(key=lambda instance: instance.index)
    mask_indices = [instance.index for instance in chosen]
    masked_token_labels = [instance.label for instance in chosen]

    return tokens, masked_token_labels, mask_indices
130 |
131 |
class TrainICTPairwise(Dataset):
    """Dataset yielding masked (sentence, context) pairs built from one document.

    A document is split into sentences at every "." token. Each sentence is
    selected with probability 0.9; a selected sentence becomes one row and the
    remaining sentences of the document, concatenated, become the following
    row. Both rows start with [CLS], are truncated / zero-padded to
    ``config.max_len`` and are then passed through
    ``create_masked_lm_predictions_ict``.
    """

    def __init__(self, dids, d_dict, did2idx, config):
        """
        Args:
            dids: sequence of document ids; one dataset item per entry.
            d_dict: indexable store of document token ids (rows support ``.tolist()``).
            did2idx: mapping from document id to its row index in ``d_dict``.
            config: object providing ``PRE_TRAINED_MODEL_NAME``, ``max_len``,
                ``batch_size``, ``masked_lm_prob`` and ``max_predictions_per_seq``.
        """
        self.dids = dids
        self.d_dict = d_dict
        self.did2idx = did2idx
        self.config = config

        self.tokenizer = BertTokenizer.from_pretrained(config.PRE_TRAINED_MODEL_NAME)
        self.vocab_list = list(self.tokenizer.vocab[key] for key in self.tokenizer.vocab)
        self.id2token = {self.tokenizer.vocab[key]: key for key in self.tokenizer.vocab}
        # Sentences are delimited by the "." token, not by [SEP].
        self.sep_token_id = self.tokenizer.vocab["."]
        self.cls_id = 101

    def __len__(self):
        return len(self.dids)

    def _pad(self, ids):
        """Return ``ids`` right-padded with zeros to ``config.max_len``."""
        # np.int64 instead of ``np.int``: that alias was removed in NumPy 1.24.
        padded = np.zeros(self.config.max_len, dtype=np.int64)
        padded[: len(ids)] = ids
        return padded

    def _mask(self, input_ids):
        """Mask ``input_ids`` in place; return them with a label row (-1 = not masked)."""
        input_ids, masked_lm_ids, masked_lm_positions = create_masked_lm_predictions_ict(
            input_ids,
            masked_lm_prob=self.config.masked_lm_prob,
            max_predictions_per_seq=self.config.max_predictions_per_seq,
            whole_word_mask=True,
            vocab_list=self.vocab_list,
            id2token=self.id2token)
        lm_label_array = np.full(self.config.max_len, dtype=np.int64, fill_value=-1)
        lm_label_array[masked_lm_positions] = masked_lm_ids
        return input_ids, lm_label_array

    def __getitem__(self, item):
        """Build the pair rows for document ``item``.

        Returns:
            dict with ``token_ids``, ``attention_mask`` and ``masked_lm_ids``.
            Rows alternate sentence / context and at most ``config.batch_size``
            rows are kept. When no sentence was selected, each value is an
            empty array of shape ``(1, 0)``.
        """
        max_len = self.config.max_len
        this_did = self.dids[item]

        doc_ids = self.d_dict[self.did2idx[this_did]].tolist()
        # Sentence boundaries: just before the start, every ".", and the last position.
        sep_pos = [-1] + [i for i, tok in enumerate(doc_ids) if tok == self.sep_token_id] + [len(doc_ids) - 1]
        sentences = [doc_ids[sep_pos[i] + 1: sep_pos[i + 1] + 1] for i in range(len(sep_pos) - 1)]
        # All selection draws happen before any masking draw.
        removes = [random.random() < 0.9 for _ in range(len(sentences))]

        token_rows, attention_rows, label_rows = [], [], []
        for idx, remove in enumerate(removes):
            if not remove:
                continue

            sentence = [self.cls_id] + sentences[idx]
            context = sentences[: idx] + sentences[idx + 1:]
            context = [self.cls_id] + [w for s in context for w in s]

            s_input_ids = self._pad(sentence[: max_len])
            c_input_ids = self._pad(context[: max_len])

            # Attention masks are taken before masking replaces any token.
            attention_rows.append(np.stack((np.int64(s_input_ids > 0), np.int64(c_input_ids > 0))))

            # Sentence first, then context (keeps the order of random draws).
            s_input_ids, s_lm_label_array = self._mask(s_input_ids)
            c_input_ids, c_lm_label_array = self._mask(c_input_ids)

            label_rows.append(np.stack((s_lm_label_array, c_lm_label_array)))
            token_rows.append(np.stack((s_input_ids.flatten(), c_input_ids.flatten())))

        if token_rows:
            # One concatenation at the end instead of re-copying per sentence.
            b_token_ids = np.concatenate(token_rows, axis=0)
            b_attention_mask = np.concatenate(attention_rows, axis=0)
            b_masked_lm_ids = np.concatenate(label_rows, axis=0)
        else:
            # No sentence selected: keep the historical empty (1, 0) placeholders.
            b_token_ids, b_attention_mask, b_masked_lm_ids = np.array([[]]), np.array([[]]), np.array([[]])

        # clip
        b_token_ids = b_token_ids[: self.config.batch_size, :]
        b_attention_mask = b_attention_mask[: self.config.batch_size, :]
        b_masked_lm_ids = b_masked_lm_ids[: self.config.batch_size, :]  # no greater than max batch size

        return {
            'token_ids': b_token_ids,  # b x 2
            'attention_mask': b_attention_mask,
            'masked_lm_ids': b_masked_lm_ids,
        }
214 |
215 |
def get_ict_loader(d_dict, did2idx, config):
    """Build the DataLoader over ``TrainICTPairwise`` for ICT pre-training.

    The loader batch size is fixed to 1: every dataset item already stacks up
    to ``config.batch_size`` rows. With ``config.distributed_train`` a
    ``DistributedSampler`` (``config.world_size`` replicas, rank
    ``config.local_rank``) picks the items; otherwise items are shuffled.

    Args:
        d_dict: store of document token ids, forwarded to the dataset.
        did2idx: mapping from document id to row index; its keys are the items.
        config: run configuration.

    Returns:
        A ``torch.utils.data.DataLoader``.
    """
    doc_ids = list(did2idx.keys())
    print('Loading tokens...')
    dataset = TrainICTPairwise(dids=doc_ids, d_dict=d_dict, did2idx=did2idx, config=config)

    loader_kwargs = {'batch_size': 1, 'num_workers': 0}
    if not config.distributed_train:
        return DataLoader(dataset, shuffle=True, **loader_kwargs)

    sampler = DistributedSampler(dataset, num_replicas=config.world_size, rank=config.local_rank)
    return DataLoader(dataset, sampler=sampler, **loader_kwargs)
242 |
243 |
class TrainQDDatasetPairwise(Dataset):
    """Dataset of (preferred query, other query) pairs that share one document.

    Each item encodes two sequences ``[CLS] query [SEP] doc [SEP]`` - first for
    ``q_ids[item][0]``, then for ``q_ids[item][1]`` - against the document
    ``d_ids[item][0]`` and applies ``create_masked_lm_predictions`` to each.
    """

    def __init__(self, q_ids, d_ids, d_dict, did2idx, config, gen_qs, gen_qid2id):
        """
        Args:
            q_ids: list of ``[first_qid, second_qid]`` pairs.
            d_ids: list of ``[did]`` entries aligned with ``q_ids``.
            d_dict: indexable store of document token ids (rows support ``.tolist()``).
            did2idx: mapping from document id to its row index in ``d_dict``.
            config: object providing ``PRE_TRAINED_MODEL_NAME``, ``max_len``,
                ``max_q_len``, ``masked_lm_prob`` and ``max_predictions_per_seq``.
            gen_qs: indexable store of query token ids (rows support ``.tolist()``).
            gen_qid2id: mapping from query id to its row index in ``gen_qs``.
        """
        self.q_ids = q_ids
        self.d_ids = d_ids
        self.d_dict = d_dict
        self.did2idx = did2idx
        self.gen_qs = gen_qs
        self.gen_qid2id = gen_qid2id
        self.config = config
        self.tokenizer = BertTokenizer.from_pretrained(self.config.PRE_TRAINED_MODEL_NAME)
        self.vocab_list = list(self.tokenizer.vocab[key] for key in self.tokenizer.vocab)
        self.id2token = {self.tokenizer.vocab[key]: key for key in self.tokenizer.vocab}

    def __len__(self):
        return len(self.q_ids)

    def _encode(self, query_input_ids, doc_input_ids):
        """Encode one query/document sequence and mask its document part.

        Returns:
            ``(token_ids, attention_mask, token_type_ids, lm_labels)`` where
            ``lm_labels`` has length ``config.max_len`` and is -1 everywhere
            except at masked positions, which hold the original token id.
        """
        cls_id, sep_id = 101, 102
        query_input_ids = query_input_ids[: self.config.max_q_len]
        # Three slots are reserved for [CLS] and the two [SEP] tokens.
        max_passage_length = self.config.max_len - 3 - len(query_input_ids)
        doc_input_ids = doc_input_ids[:max_passage_length]

        input_ids = [cls_id] + query_input_ids + [sep_id] + doc_input_ids + [sep_id]
        token_type_ids = np.array([0] * (2 + len(query_input_ids)) + [1] * (1 + len(doc_input_ids)))

        token_ids = np.array(input_ids)
        # The attention mask is taken before masking replaces any token.
        attention_mask = np.int64(token_ids > 0)

        token_ids, masked_lm_ids, masked_lm_positions = create_masked_lm_predictions(
            token_ids,
            masked_lm_prob=self.config.masked_lm_prob,
            max_predictions_per_seq=self.config.max_predictions_per_seq,
            whole_word_mask=True,
            vocab_list=self.vocab_list,
            id2token=self.id2token)

        # np.int64 instead of ``np.int``: that alias was removed in NumPy 1.24.
        lm_label_array = np.full(self.config.max_len, dtype=np.int64, fill_value=-1)
        lm_label_array[masked_lm_positions] = masked_lm_ids
        return token_ids.flatten(), attention_mask, token_type_ids, lm_label_array

    def __getitem__(self, item):
        """Return the stacked encodings (first-query row, second-query row) for pair ``item``."""
        q_id = self.q_ids[item]
        d_id = self.d_ids[item]

        pos_q_id, neg_q_id = q_id[0], q_id[1]
        did = d_id[0]

        pos_query_input_ids = self.gen_qs[self.gen_qid2id[pos_q_id]].tolist()
        neg_query_input_ids = self.gen_qs[self.gen_qid2id[neg_q_id]].tolist()
        doc_input_ids = self.d_dict[self.did2idx[did]].tolist()

        # First query first, then the second (keeps the order of random draws).
        pos_token_ids, pos_attention_mask, pos_token_type_ids, pos_lm_label_array = self._encode(
            pos_query_input_ids, doc_input_ids)
        neg_token_ids, neg_attention_mask, neg_token_type_ids, neg_lm_label_array = self._encode(
            neg_query_input_ids, doc_input_ids)

        return {
            'token_ids': np.stack((pos_token_ids, neg_token_ids)),
            'attention_mask': np.stack((pos_attention_mask, neg_attention_mask)),
            'token_type_ids': np.stack((pos_token_type_ids, neg_token_type_ids)),
            'masked_lm_ids': np.stack((pos_lm_label_array, neg_lm_label_array)),
        }
327 |
328 |
329 | # [CLS] q [SEP] d [SEP]
def get_train_qd_loader(d_dict, did2idx, config, doc2query=None, gen_qs=None, gen_qid2id=None, axiom_feature=None):
    """Sample query pairs per document, order them with the classifier and wrap them in a DataLoader.

    For every document listed in ``doc2query`` up to two pairs of its queries
    are sampled. Each pair is described by one value per axiom feature
    (1 / 0 / -1: first query scores higher / equal / lower). The XGBoost model
    loaded from ``config.clf_model`` predicts on these rows; a pair whose
    prediction is not above 0.5 is swapped, so the preferred query ends up
    first.

    Args:
        d_dict: store of document token ids, forwarded to the dataset.
        did2idx: mapping from document id to row index; its keys are iterated.
        config: run configuration (``clf_model``, ``batch_size``,
            ``distributed_train``, ``world_size``, ``local_rank``, ...).
        doc2query: mapping from document id to the ids of its queries.
        gen_qs: store of query token ids, forwarded to the dataset.
        gen_qid2id: mapping from query id to row index in the feature stores.
        axiom_feature: sequence of ``(feature_name, values)`` entries, read as
            ``values[row][0]``. The column names assigned below require nine
            entries in that fixed order.

    Returns:
        A ``torch.utils.data.DataLoader`` with batch size ``config.batch_size // 2``
        (each dataset item carries two sequences).
    """
    # Features whose reciprocal is used as the score.
    inverted_axioms = ['PROX-1', 'PROX-2', 'RANK']

    # loading xgboost model
    classifier = xgb.XGBRFClassifier()
    classifier.load_model(config.clf_model)

    pair_q_ids, pair_d_ids, comparisons = [], [], []
    for this_did in tqdm(list(did2idx.keys()), desc="Sampling Pre-train Query Pairs progress"):
        if this_did not in doc2query:
            continue

        # One entry per query: [q_id, score_1, ..., score_k]
        scored_queries = []
        for q_id in doc2query[this_did]:
            row = gen_qid2id[q_id]
            entry = [q_id]
            for feature in axiom_feature:
                feature_name, values = feature[0], feature[1]
                raw = values[row][0]
                if feature_name not in inverted_axioms:
                    entry.append(raw)
                    continue
                if feature_name == 'RANK' and raw == 0:
                    # a zero RANK value is replaced by 1e12 before inverting
                    raw = 1e12
                entry.append(1 / (raw + 1e-12))
            scored_queries.append(entry)

        query_count = len(scored_queries)
        candidate_pairs = [
            [scored_queries[i], scored_queries[j]]
            for i in range(query_count)
            for j in range(i + 1, query_count)
        ]

        for first, second in random.sample(candidate_pairs, k=min(2, len(candidate_pairs))):
            comparisons.append([
                1 if a > b else (0 if a == b else -1)
                for a, b in zip(first[1:], second[1:])
            ])
            pair_q_ids.append([first[0], second[0]])
            pair_d_ids.append([this_did])

    case_frame = pd.DataFrame(np.array(comparisons))
    case_frame.columns = ['PROX-1', 'PROX-2', 'REP-QL', 'REP-TFIDF', 'REG', 'STM-1', 'STM-2', 'STM-3', 'RANK']
    predictions = classifier.predict(case_frame)
    for q_pair, pred in zip(pair_q_ids, predictions):
        if not pred > 0.5:  # swap
            q_pair.reverse()

    print('Loading tokens...')
    dataset = TrainQDDatasetPairwise(
        q_ids=pair_q_ids,
        d_ids=pair_d_ids,
        d_dict=d_dict,
        did2idx=did2idx,
        config=config,
        gen_qs=gen_qs,
        gen_qid2id=gen_qid2id,
    )

    loader_kwargs = {'batch_size': config.batch_size // 2, 'num_workers': 0}
    if not config.distributed_train:
        return DataLoader(dataset, shuffle=True, **loader_kwargs)

    sampler = DistributedSampler(dataset, num_replicas=config.world_size, rank=config.local_rank)
    return DataLoader(dataset, sampler=sampler, **loader_kwargs)
424 |
425 |
426 |
--------------------------------------------------------------------------------
/visualization/output_ARES_simple.html:
--------------------------------------------------------------------------------
1 | | QID DID | Relevance Level/Rank | Word Importance |
|---|
| 156493
2 | D3356945 | 3/1 | [CLS] do goldfish grow [SEP] https : / / answers . yahoo . com / question / index ? qid = 20100226170159aawholxhow to make goldfish grow faster ? " pets fish how to make goldfish grow faster ? just wondering ? update : what kind of foods could i use ? would warmer water help ? update 2 : gabe tech , retard they aren ' t in a bowl and if i did what you said , they ' d die ! follow 18 answers answers relevance rating newest oldest best answer : really people ? if you put a small child into a large house , will he grow faster ? no ! a tank that is too small will slow his growth down and even stop it but a bigger tank than needed won ' t have any effect . make sure his water is good and that he has adequate room and food , and he will grow at his own pace . really people fish are just like any other animal on the planet they aren ' t little aliens . t he only thing weird about how a fish grows is that they put out a hormone into the water that will slow down the growth of other fish and them selves . and dont put fill your bowl with juice thats an acid and it will kill your |
|
|
--------------------------------------------------------------------------------
|