├── imgs
│   ├── ARES_black.jpg
│   ├── ARES_white.jpg
│   ├── ARES_simple_ag.pdf
│   ├── ARES_simple_ag.png
│   ├── few-shot-metric.jpg
│   ├── few-shot-metric.pdf
│   └── few-shot-metric.png
├── model
│   ├── __init__.py
│   └── modeling.py
├── requirements.txt
├── preprocess
│   ├── anserini_scripts
│   │   ├── do_bm25_search.sh
│   │   └── build_index.sh
│   ├── convert_to_pred.py
│   ├── README.md
│   ├── convert_tokenize.py
│   └── Eval4.0.pl
├── example
│   └── rerank.py
├── finetune
│   ├── modelsize_estimate.py
│   ├── config.py
│   ├── ms_marco_eval.py
│   ├── dataloader.py
│   └── train.py
├── .gitignore
├── visualization
│   ├── visualization.py
│   ├── config.py
│   ├── dataloader.py
│   ├── visual.py
│   └── output_ARES_simple.html
├── pretrain
│   ├── config.py
│   ├── train.py
│   └── dataloader.py
├── README.md
└── LICENSE
--------------------------------------------------------------------------------
/imgs/ARES_black.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanyuan14/ARES/HEAD/imgs/ARES_black.jpg
--------------------------------------------------------------------------------
/imgs/ARES_white.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanyuan14/ARES/HEAD/imgs/ARES_white.jpg
--------------------------------------------------------------------------------
/imgs/ARES_simple_ag.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanyuan14/ARES/HEAD/imgs/ARES_simple_ag.pdf
--------------------------------------------------------------------------------
/imgs/ARES_simple_ag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanyuan14/ARES/HEAD/imgs/ARES_simple_ag.png
--------------------------------------------------------------------------------
/imgs/few-shot-metric.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanyuan14/ARES/HEAD/imgs/few-shot-metric.jpg
--------------------------------------------------------------------------------
/imgs/few-shot-metric.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanyuan14/ARES/HEAD/imgs/few-shot-metric.pdf
--------------------------------------------------------------------------------
/imgs/few-shot-metric.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanyuan14/ARES/HEAD/imgs/few-shot-metric.png
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | '''
2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search
3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma.
4 | '''
5 | # encoding: utf-8
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # PyTorch
2 | torch==1.9.0
3 | # Huggingface transformers
4 | transformers==4.9.2
5 | # progress bars in model download and training scripts
6 | tqdm
7 | # Accessing files from S3 directly.
8 | boto3 9 | # nltk 10 | nltk 11 | # numpy 12 | numpy -------------------------------------------------------------------------------- /preprocess/anserini_scripts/do_bm25_search.sh: -------------------------------------------------------------------------------- 1 | python -m pyserini.search --index path_to_index \ 2 | --topics path_to_queries \ 3 | --output path_to_trec \ 4 | --bm25 \ 5 | --hits 200 -------------------------------------------------------------------------------- /preprocess/anserini_scripts/build_index.sh: -------------------------------------------------------------------------------- 1 | python -m pyserini.index -collection JsonCollection \ 2 | -generator DefaultLuceneDocumentGenerator \ 3 | -threads 8 \ 4 | -input path_to_collection \ 5 | -index path_to_index \ 6 | -storePositions -storeDocvectors -------------------------------------------------------------------------------- /preprocess/convert_to_pred.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from collections import defaultdict 3 | import argparse 4 | 5 | def trec_to_pred(args): 6 | trec = defaultdict(dict) 7 | with open(args.input_trec, 'r') as f: 8 | for line in f: 9 | qid, _, docid, rank, score, _ = line.strip().split(' ') 10 | trec[qid][docid] = score 11 | 12 | f = open(args.output, 'w') 13 | with open(args.qrels, 'r') as r: 14 | for line in r: 15 | line = line.strip().split() 16 | qid = line[1].split(':')[1] 17 | docid = line[-7] 18 | if docid in trec[qid]: 19 | f.write(trec[qid][docid] + '\n') 20 | else: 21 | f.write('0.0\n') 22 | 23 | f.close() 24 | 25 | 26 | if __name__ == "__main__": 27 | parser = argparse.ArgumentParser() 28 | 29 | parser.add_argument("--input_trec", default='', type=str, required=True) 30 | parser.add_argument("--output", default='', type=str, required=True) 31 | parser.add_argument("--qrels", default='', type=str, required=True) 32 | args = parser.parse_args() 33 | 34 | trec_to_pred(args) -------------------------------------------------------------------------------- /example/rerank.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, '../') 4 | 5 | from tqdm import tqdm 6 | import json 7 | import torch 8 | import numpy as np 9 | import pandas as pd 10 | from datetime import timedelta 11 | 12 | from model.modeling import ARESReranker 13 | 14 | 15 | if __name__ == "__main__": 16 | model_path = "path/to/model" 17 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 18 | model = ARESReranker.from_pretrained(model_path).to(device) 19 | 20 | query1 = "What is the best way to get to the airport" 21 | query2 = "what do you like to eat?" 
22 |
23 |     doc1 = "The best way to get to the airport is to take the bus"
24 |     doc2 = "I like to eat apples"
25 |
26 |
27 |     ### Score a batch of q-d pairs
28 |     qd_pairs = [
29 |         (query1, doc1), (query1, doc2),
30 |         (query2, doc1), (query2, doc2)
31 |     ]
32 |
33 |     score = model.score(qd_pairs)
34 |     print("qd scores", score)
35 |
36 |     ### Rerank a single query
37 |     score = model.rerank_query(query1, [doc1, doc2])
38 |     print("query1 scores", score)
39 |
40 |     ### Rerank a batch of queries
41 |     query1_topk = [ doc1, doc2 ]
42 |     query2_topk = [ doc1, doc2 ]
43 |
44 |     score = model.rerank([query1, query2], [query1_topk, query2_topk])
45 |     print("batch scores", score)
--------------------------------------------------------------------------------
/finetune/modelsize_estimate.py:
--------------------------------------------------------------------------------
 1 | '''
 2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search
 3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma.
 4 | '''
 5 | # encoding: utf-8
 6 | import torch.nn as nn
 7 | import numpy as np
 8 |
 9 |
10 | def modelsize(model, input, type_size=4):
11 |     para = sum([np.prod(list(p.size())) for p in model.parameters()])
12 |     # print('Model {} : Number of params: {}'.format(model._get_name(), para))
13 |     print('Model {} : params: {:.4f}M'.format(model._get_name(), para * type_size / 1000 / 1000))
14 |
15 |     input_ = input.clone()
16 |     input_.requires_grad_(requires_grad=False)
17 |
18 |     mods = list(model.modules())
19 |     out_sizes = []
20 |
21 |     for i in range(1, len(mods)):
22 |         m = mods[i]
23 |         if isinstance(m, nn.ReLU):
24 |             if m.inplace:
25 |                 continue
26 |         out = m(input_)
27 |         out_sizes.append(np.array(out.size()))
28 |         input_ = out
29 |
30 |     total_nums = 0
31 |     for i in range(len(out_sizes)):
32 |         s = out_sizes[i]
33 |         nums = np.prod(np.array(s))
34 |         total_nums += nums
35 |
36 |     # print('Model {} : Number of intermediate variables without backward: {}'.format(model._get_name(), total_nums))
37 |     # print('Model {} : Number of intermediate variables with backward: {}'.format(model._get_name(), total_nums*2))
38 |     print('Model {} : intermediate variables: {:.3f} M (without backward)'
39 |           .format(model._get_name(), total_nums * type_size / 1000 / 1000))
40 |     print('Model {} : intermediate variables: {:.3f} M (with backward)'
41 |           .format(model._get_name(), total_nums * type_size*2 / 1000 / 1000))
42 |
--------------------------------------------------------------------------------
/preprocess/README.md:
--------------------------------------------------------------------------------
 1 | ## Data Preprocessing
 2 |
 3 | Since different datasets require different pre-processing, we only provide some helper functions and scripts here.
 4 |
 5 | ### Anserini Scripts
 6 |
 7 | We use the BM25 implementation in `anserini` to perform first-stage retrieval.
 8 |
 9 | Please make sure you have correctly installed `anserini` and `pyserini`.
10 |
11 | ### Tokenize
12 |
13 | You can pre-tokenize your dataset offline for faster training.
14 | ```bash
15 | python convert_tokenize.py \
16 |     --vocab_dir {path_to_vocab} \
17 |     --type {'query', 'doc', 'triples'} \
18 |     --input {path_to_input} \
19 |     --output {path_to_output}
20 | ```
21 | File format:
22 |
23 | * query: `qid \t query` for each line
24 | * doc: `{"id": docid, "contents": doc}` for each line
25 | * triples: `{"query": query_text, "doc_pos": positive_doc, "doc_neg": negative_doc}` for each line
26 |
27 | ### Small Datasets
28 |
29 | #### TREC-COVID
30 |
31 | We follow the same data preprocessing as `OpenMatch`; please refer to [experiments-treccovid](https://github.com/thunlp/OpenMatch/blob/master/docs/experiments-treccovid.md).
32 |
33 | #### Robust04
34 |
35 | We use BM25 to generate the Top-200 candidates for each query, and the fine-tuning procedure is similar to that on MS MARCO.
36 |
37 | #### MQ2007
38 |
39 | We use BM25 to generate the Top-200 candidates for each query, and the fine-tuning procedure is similar to that on MS MARCO.
40 |
41 | Note that `trec_eval` cannot compute metrics for MQ2007 directly. You should first convert the `trec` output file and then use `Eval4.0.pl` for evaluation. `Eval4.0.pl` comes from [LETOR4.0](https://www.microsoft.com/en-us/research/project/letor-learning-rank-information-retrieval/letor-4-0/).
42 | ```bash
43 | python convert_to_pred.py \
44 |     --input_trec {path_to_trec_output} \
45 |     --qrels {path_to_qrels} \
46 |     --output {path_to_output}
47 |
48 | perl Eval4.0.pl {path_to_qrels} {path_to_output} ./eval_result 0
49 | ```
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
  1 | # Byte-compiled / optimized / DLL files
  2 | __pycache__/
  3 | *.py[cod]
  4 | *$py.class
  5 |
  6 | # C extensions
  7 | *.so
  8 |
  9 | # Distribution / packaging
 10 | .Python
 11 | build/
 12 | develop-eggs/
 13 | dist/
 14 | downloads/
 15 | eggs/
 16 | .eggs/
 17 | lib/
 18 | lib64/
 19 | parts/
 20 | sdist/
 21 | var/
 22 | wheels/
 23 | pip-wheel-metadata/
 24 | share/python-wheels/
 25 | *.egg-info/
 26 | .installed.cfg
 27 | *.egg
 28 | MANIFEST
 29 |
 30 | # PyInstaller
 31 | # Usually these files are written by a python script from a template
 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
 33 | *.manifest
 34 | *.spec
 35 |
 36 | # Installer logs
 37 | pip-log.txt
 38 | pip-delete-this-directory.txt
 39 |
 40 | # Unit test / coverage reports
 41 | htmlcov/
 42 | .tox/
 43 | .nox/
 44 | .coverage
 45 | .coverage.*
 46 | .cache
 47 | nosetests.xml
 48 | coverage.xml
 49 | *.cover
 50 | *.py,cover
 51 | .hypothesis/
 52 | .pytest_cache/
 53 |
 54 | # Translations
 55 | *.mo
 56 | *.pot
 57 |
 58 | # Django stuff:
 59 | *.log
 60 | local_settings.py
 61 | db.sqlite3
 62 | db.sqlite3-journal
 63 |
 64 | # Flask stuff:
 65 | instance/
 66 | .webassets-cache
 67 |
 68 | # Scrapy stuff:
 69 | .scrapy
 70 |
 71 | # Sphinx documentation
 72 | docs/_build/
 73 |
 74 | # PyBuilder
 75 | target/
 76 |
 77 | # Jupyter Notebook
 78 | .ipynb_checkpoints
 79 |
 80 | # IPython
 81 | profile_default/
 82 | ipython_config.py
 83 |
 84 | # pyenv
 85 | .python-version
 86 |
 87 | # pipenv
 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
 91 | # install all needed dependencies.
 92 | #Pipfile.lock
 93 |
 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
 95 | __pypackages__/
 96 |
 97 | # Celery stuff
 98 | celerybeat-schedule
 99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
--------------------------------------------------------------------------------
/preprocess/convert_tokenize.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import json
 3 | import argparse
 4 | import numpy as np
 5 | from tqdm import tqdm
 6 | from transformers import AutoTokenizer
 7 |
 8 |
 9 | def tokenize_file(tokenizer, input_file, output_file, file_type):
10 |     total_size = sum(1 for _ in open(input_file))
11 |     with open(output_file, 'w') as outFile:
12 |         for line in tqdm(open(input_file), total=total_size,
13 |                          desc=f"Tokenize: {os.path.basename(input_file)}"):
14 |             if file_type == "query":
15 |                 seq_id, text = line.strip().split("\t")
16 |             else:  # doc format: {"id": docid, "contents": doc}
17 |                 line = json.loads(line.strip())
18 |                 seq_id, text = line["id"], line["contents"]
19 |             tokens = tokenizer.tokenize(text)
20 |             ids = tokenizer.convert_tokens_to_ids(tokens)[: 512]
21 |             outFile.write(json.dumps(
22 |                 {"id": seq_id, "ids": ids}
23 |             ))
24 |             outFile.write("\n")
25 |
26 |
27 | def tokenize_queries(tokenizer, input_file, output_file):
28 |     total_size = sum(1 for _ in open(input_file))
29 |     f = open(output_file, 'w')
30 |     with open(input_file, 'r') as r:
31 |         for line in tqdm(r, total=total_size):
32 |             query_id, text = line.strip().split('\t')
33 |             tokens = tokenizer.tokenize(text)
34 |             ids = tokenizer.convert_tokens_to_ids(tokens)[: 512]
35 |             f.write(json.dumps({
36 |                 'query_id': query_id,
37 |                 'query': ids
38 |             }) + '\n')
39 |     f.close()
40 |
41 |
42 | def tokenize_docs(tokenizer, input_file, output_file):
43 |     total_size = sum(1 for _ in open(input_file))
44 |     f = open(output_file, 'w')
45 |     with open(input_file, 'r') as r:
46 |         for line in tqdm(r, total=total_size):
47 |             line = json.loads(line.strip())
48 |             tokens = tokenizer.tokenize(line['doc'])
49 |             ids = tokenizer.convert_tokens_to_ids(tokens)[: 512]
50 |             f.write(json.dumps({
51 |                 'id': line['doc_id'],
52 |                 'contents': ids
53 |             }) + '\n')
54 |     f.close()
55 |
56 |
57 | def tokenize_pairwise(tokenizer, input_file, output_file):
58 |     total_size = sum(1 for _ in open(input_file))
59 |     f = open(output_file, 'w')
60 |     with open(input_file, 'r') as r:
61 |         for line in tqdm(r, total=total_size):
62 |             line = json.loads(line.strip())
63 |             tokens = tokenizer.tokenize(line['query'])
64 |             query_ids = tokenizer.convert_tokens_to_ids(tokens)[: 512]
65 |
66 |             tokens = tokenizer.tokenize(line['doc_pos'])
67 |             pos_ids = tokenizer.convert_tokens_to_ids(tokens)[: 512]
68 |
69 |             tokens = tokenizer.tokenize(line['doc_neg'])
70 |             neg_ids = tokenizer.convert_tokens_to_ids(tokens)[: 512]
71 |             f.write(json.dumps({
72 |                 'query': query_ids,
73 |                 'doc_pos': pos_ids,
74 |                 'doc_neg': neg_ids
75 |             }) + '\n')
76 |     f.close()
77 |
78 |
79 | if __name__ == "__main__":
80 |
81 |     parser = argparse.ArgumentParser()
82 |
83 |     parser.add_argument("--vocab_dir", default='bert-base-uncased', type=str)
84 |     parser.add_argument("--type", default='query', type=str)
85 |     parser.add_argument("--input", default='', type=str, required=True)
86 | parser.add_argument("--output", default='', type=str, required=True) 87 | args = parser.parse_args() 88 | 89 | tokenizer = AutoTokenizer.from_pretrained(args.vocab_dir) 90 | 91 | if args.type == "query": 92 | tokenize_queries(tokenizer, args.input, args.output) 93 | elif args.type == "doc": 94 | tokenize_docs(tokenizer, args.input, args.output) 95 | elif args.type == "triples": 96 | tokenize_pairwise(tokenizer, args.input, args.output) -------------------------------------------------------------------------------- /visualization/visualization.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search 3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma. 4 | ''' 5 | from typing import Any, Iterable, List, Tuple, Union 6 | try: 7 | from IPython.core.display import HTML, display 8 | 9 | HAS_IPYTHON = True 10 | except ImportError: 11 | HAS_IPYTHON = False 12 | 13 | 14 | class VisualizationDataRecord: 15 | r""" 16 | A data record for storing attribution relevant information 17 | """ 18 | __slots__ = [ 19 | "word_attributions", 20 | "level", 21 | "rank", 22 | "v_q_id", 23 | "v_d_id", 24 | "doc_tokens", 25 | "convergence_score", 26 | ] 27 | 28 | def __init__( 29 | self, 30 | word_attributions, 31 | level, 32 | rank, 33 | v_q_id, 34 | v_d_id, 35 | doc_tokens, 36 | convergence_score, 37 | ): 38 | self.word_attributions = word_attributions 39 | self.level = level 40 | self.rank = rank 41 | self.v_q_id = v_q_id 42 | self.v_d_id = v_d_id 43 | self.doc_tokens = doc_tokens 44 | self.convergence_score = convergence_score 45 | 46 | 47 | def _get_color(attr): 48 | # clip values to prevent CSS errors (Values should be from [-1,1]) 49 | # attr = max(-1, min(1, attr)) 50 | if attr > 0: 51 | hue = 10 52 | sat = 75 53 | lig = 100-int(100 * attr) 54 | else: 55 | hue = 220 56 | sat = 75 57 | lig = 100 - int(-100 * attr) 58 | return "hsl({}, {}%, {}%)".format(hue, sat, lig) 59 | 60 | 61 | def format_classname(classname): 62 | return '{}'.format(classname) 63 | 64 | 65 | def format_special_tokens(token): 66 | if token.startswith("<") and token.endswith(">"): 67 | return "#" + token.strip("<>") 68 | return token 69 | 70 | 71 | def format_tooltip(item, text): 72 | return '
{item}\ 73 | {text}\ 74 |
'.format( 75 | item=item, text=text 76 | ) 77 | 78 | 79 | def format_word_importances(words, importances): 80 | if importances is None or len(importances) == 0: 81 | return "" 82 | assert len(words) <= len(importances) 83 | tags = [""] 84 | for word, importance in zip(words, importances[: len(words)]): 85 | print(word, importance) 86 | word = format_special_tokens(word) 87 | color = _get_color(importance) 88 | # unwrapped_tag = ' {word}\ 90 | # '.format( 91 | # color=color, word=word 92 | # ) 93 | if word.startswith("##"): 94 | unwrapped_tag = '{word}'.format( 95 | color=color, word=word.replace("##","") 96 | ) 97 | else: 98 | unwrapped_tag = ' {word}'.format( 99 | color=color, word=word 100 | ) 101 | tags.append(unwrapped_tag) 102 | tags.append("") 103 | return "".join(tags) 104 | 105 | 106 | def visualize_text( 107 | datarecords: Iterable[VisualizationDataRecord], legend: bool = False 108 | ) -> "HTML": # In quotes because this type doesn't exist in standalone mode 109 | assert HAS_IPYTHON, ( 110 | "IPython must be available to visualize text. " 111 | "Please run 'pip install ipython'." 112 | ) 113 | dom = [] 114 | dom.append("") 115 | dom.append("") 116 | dom.append("") 117 | dom.append("") 118 | rows = [ 119 | '' 120 | '' 121 | '' 122 | ] 123 | for datarecord in datarecords: 124 | rows.append( 125 | "".join( 126 | [ 127 | "", 128 | format_classname("{}\n{}".format(datarecord.v_q_id,datarecord.v_d_id)), 129 | format_classname( 130 | "{}/{}".format( 131 | datarecord.level, datarecord.rank 132 | ) 133 | ), 134 | format_word_importances( 135 | datarecord.doc_tokens, datarecord.word_attributions 136 | ), 137 | "", 138 | ] 139 | ) 140 | ) 141 | 142 | dom.append("".join(rows)) 143 | dom.append("
QID DIDRelevance Level/RankWord Importance
") 144 | dom.append("") 145 | dom.append("") 146 | html = HTML("".join(dom)) 147 | display(html) 148 | 149 | return html -------------------------------------------------------------------------------- /visualization/config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search 3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma. 4 | ''' 5 | import argparse 6 | import pprint 7 | import yaml 8 | 9 | 10 | def str2bool(v): 11 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 12 | return True 13 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 14 | return False 15 | else: 16 | raise argparse.ArgumentTypeError('Boolean value expected.') 17 | 18 | 19 | class Config(object): 20 | def __init__(self, **kwargs): 21 | """Configuration Class: set kwargs as class attributes with setattr""" 22 | for k, v in kwargs.items(): 23 | setattr(self, k, v) 24 | 25 | @property 26 | def config_str(self): 27 | return pprint.pformat(self.__dict__) 28 | 29 | def __repr__(self): 30 | """Pretty-print configurations in alphabetical order""" 31 | config_str = 'Configurations\n' 32 | config_str += self.config_str 33 | return config_str 34 | 35 | def save(self, path): 36 | with open(path, 'w') as f: 37 | yaml.dump(self.__dict__, f, default_flow_style=False) 38 | 39 | @classmethod 40 | def load(cls, path): 41 | with open(path, 'r') as f: 42 | kwargs = yaml.load(f) 43 | 44 | return Config(**kwargs) 45 | 46 | 47 | def read_config(path): 48 | return Config.load(path) 49 | 50 | 51 | def get_config(parse=True, **optional_kwargs): 52 | """ 53 | Get configurations as attributes of class 54 | 1. Parse configurations with argparse. 55 | 2. Create Config class initilized with parsed kwargs. 56 | 3. Return Config class. 
57 | """ 58 | parser = argparse.ArgumentParser() 59 | # Training 60 | parser.add_argument('--test', action="store_true") 61 | parser.add_argument('--epochs', type=int, default=10, 62 | help='num_epochs') 63 | parser.add_argument('--batch_size', type=int, default=1, 64 | help='batch size') 65 | parser.add_argument('--neg_docs_per_q', type=int, default=4, 66 | help='number of sampled docs per q-d pair') 67 | parser.add_argument("--adam_epsilon", default=1e-8, type=float) 68 | parser.add_argument("--weight_decay", default=0.01, type=float) 69 | parser.add_argument('--lr', type=float, default=3e-5, 70 | help='learning rate') 71 | parser.add_argument('--clip', type=float, default=1.0, 72 | help='gradient clip norm') 73 | parser.add_argument('--warm_up', type=float, default=0.1, 74 | help='warm up proportion') 75 | parser.add_argument('--gradient_checkpointing', action="store_true") 76 | parser.add_argument('--max_len', type=int, default=512, 77 | help='max length') 78 | parser.add_argument('--max_q_len', type=int, default=40, 79 | help='max query length') 80 | parser.add_argument('--visual_q_num', type=int, default=1, 81 | help='visual query number') 82 | parser.add_argument('--visual_d_num', type=int, default=5, 83 | help='visual document number') 84 | parser.add_argument('--model_name', type=str, default='ARES_simple', 85 | choices=['ARES_simple', 'ARES_hardest', 'BERT', 'PROP_msmarco'], 86 | help='the model name') 87 | parser.add_argument('--model_type', type=str, default='ARES', 88 | choices=['ARES', 'PROP', 'BERT'], 89 | help='the model type') 90 | parser.add_argument('--optim', type=str, default='adamw', 91 | choices=['adam', 'amsgrad', 'adagrad', 'adamw'], 92 | help='optimizer') 93 | parser.add_argument('--dropout', type=float, default=0.2) 94 | parser.add_argument('--distributed_train', action="store_true") 95 | parser.add_argument('--gpu_num', type=int, default=1) 96 | parser.add_argument('--seed', type=int, default=42, 97 | help='Random seed') 98 | parser.add_argument('--PRE_TRAINED_MODEL_NAME', default='/path/to/ARES-simple/', 99 | help='huggingface model name') 100 | parser.add_argument('--model_path', default='model_state_ARES', help='name of checkpoint to load') 101 | parser.add_argument('--print_every', default=200) 102 | parser.add_argument('--local_rank', type=int, default=0, help='node rank for distributed training') 103 | parser.add_argument('--gradient_accumulation_steps', type=int, default=4, 104 | help="Number of updates steps to accumulate before performing a backward/update pass.") 105 | 106 | # human labels 107 | parser.add_argument('--dl2019_qd_dir', default='../preprocess/2019qrels-docs.txt') 108 | 109 | # queries 110 | parser.add_argument('--dl2019_qs_dir', default='../preprocess/queries.dl2019.json') 111 | 112 | # docs 113 | parser.add_argument('--memmap_doc_dir', default='../preprocess/doc_token_ids.memmap') 114 | parser.add_argument('--docid2id_dir', default='../preprocess/docid2idx.json') 115 | 116 | # STAR+ADORE Top100 117 | parser.add_argument('--dl100_dir', default='../preprocess/test.rank.tsv') 118 | 119 | if parse: 120 | kwargs = parser.parse_args() 121 | else: 122 | kwargs = parser.parse_known_args()[0] 123 | 124 | # Namespace => Dictionary 125 | kwargs = vars(kwargs) 126 | kwargs.update(optional_kwargs) 127 | 128 | return Config(**kwargs) -------------------------------------------------------------------------------- /finetune/config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @ref: Axiomatically 
Regularized Pre-training for Ad hoc Search
  3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma.
  4 | '''
  5 | # encoding: utf-8
  6 | import argparse
  7 | import pprint
  8 | import yaml
  9 |
 10 |
 11 | def str2bool(v):
 12 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
 13 |         return True
 14 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
 15 |         return False
 16 |     else:
 17 |         raise argparse.ArgumentTypeError('Boolean value expected.')
 18 |
 19 |
 20 | class Config(object):
 21 |     def __init__(self, **kwargs):
 22 |         """Configuration Class: set kwargs as class attributes with setattr"""
 23 |         for k, v in kwargs.items():
 24 |             setattr(self, k, v)
 25 |
 26 |     @property
 27 |     def config_str(self):
 28 |         return pprint.pformat(self.__dict__)
 29 |
 30 |     def __repr__(self):
 31 |         """Pretty-print configurations in alphabetical order"""
 32 |         config_str = 'Configurations\n'
 33 |         config_str += self.config_str
 34 |         return config_str
 35 |
 36 |     def save(self, path):
 37 |         with open(path, 'w') as f:
 38 |             yaml.dump(self.__dict__, f, default_flow_style=False)
 39 |
 40 |     @classmethod
 41 |     def load(cls, path):
 42 |         with open(path, 'r') as f:
 43 |             kwargs = yaml.load(f, Loader=yaml.FullLoader)
 44 |
 45 |         return Config(**kwargs)
 46 |
 47 |
 48 | def read_config(path):
 49 |     return Config.load(path)
 50 |
 51 |
 52 | def get_config(parse=True, **optional_kwargs):
 53 |     """
 54 |     Get configurations as attributes of class
 55 |     1. Parse configurations with argparse.
 56 |     2. Create Config class initialized with parsed kwargs.
 57 |     3. Return Config class.
 58 |     """
 59 |     parser = argparse.ArgumentParser()
 60 |     # Training
 61 |     parser.add_argument('--test', action="store_true")
 62 |     parser.add_argument('--epochs', type=int, default=20,
 63 |                         help='num_epochs')
 64 |     parser.add_argument('--batch_size', type=int, default=25,
 65 |                         help='batch size')
 66 |     parser.add_argument('--neg_docs_per_q', type=int, default=4,
 67 |                         help='number of sampled docs per q-d pair')
 68 |     parser.add_argument("--adam_epsilon", default=1e-8, type=float)
 69 |     parser.add_argument("--weight_decay", default=0.01, type=float)
 70 |     parser.add_argument('--lr', type=float, default=3e-5,
 71 |                         help='learning rate')
 72 |     parser.add_argument('--clip', type=float, default=1.0,
 73 |                         help='gradient clip norm')
 74 |     parser.add_argument('--warm_up', type=float, default=0.1,
 75 |                         help='warm up proportion')
 76 |     parser.add_argument('--gradient_checkpointing', action="store_true")
 77 |     parser.add_argument('--max_len', type=int, default=512,
 78 |                         help='max length')
 79 |     parser.add_argument('--max_q_len', type=int, default=15,
 80 |                         help='max query length')
 81 |     parser.add_argument('--model_name', type=str, default='ARES_simple',
 82 |                         help='the model name')
 83 |     parser.add_argument('--model_type', type=str, default='ARES',
 84 |                         choices=['ARES', 'PROP', 'BERT', 'ICT'],
 85 |                         help='the model type')
 86 |     parser.add_argument('--optim', type=str, default='adamw',
 87 |                         choices=['adam', 'amsgrad', 'adagrad', 'adamw'],
 88 |                         help='optimizer')
 89 |     parser.add_argument('--dropout', type=float, default=0.2)
 90 |     parser.add_argument('--embed_dim', type=int, default=100)
 91 |     parser.add_argument('--freeze', type=bool, default=False)
 92 |     parser.add_argument('--world_size', type=int, default=4)
 93 |     parser.add_argument('--distributed_train', action="store_true")
 94 |     parser.add_argument('--gpu_num', type=int, default=1)
 95 |     parser.add_argument('--seed', type=int, default=42,
 96 |                         help='Random seed')
 97 |     parser.add_argument('--PRE_TRAINED_MODEL_NAME',
default='/path/to/ares-simple/', 98 | help='huggingface model name') 99 | parser.add_argument('--gradient_accumulation_steps', type=int, default=4, 100 | help="Number of updates steps to accumulate before performing a backward/update pass.") 101 | parser.add_argument('--load_ckpt', action="store_true", help='whether to load a trained checkpoint') 102 | parser.add_argument('--model_path', default='model_state_ARES', help='name of checkpoint to load') 103 | parser.add_argument('--print_every', default=200) 104 | parser.add_argument('--local_rank', type=int, default=0, help='node rank for distributed training') 105 | 106 | # human labels 107 | parser.add_argument('--train_qd_dir', default='../preprocess/msmarco-doctrain-qrels.tsv') 108 | parser.add_argument('--test_qd_dir', default='../preprocess/dev-qrels.txt') 109 | parser.add_argument('--dl2019_qd_dir', default='../preprocess/2019qrels-docs.txt') 110 | 111 | # queries 112 | parser.add_argument('--train_qs_dir', default='../preprocess/queries.doctrain.json') 113 | parser.add_argument('--test_qs_dir', default='../preprocess/queries.docdev.json') 114 | parser.add_argument('--dl2019_qs_dir', default='../preprocess/queries.dl2019.json') 115 | 116 | # docs 117 | parser.add_argument('--memmap_doc_dir', default='../preprocess/doc_token_ids.memmap') 118 | parser.add_argument('--docid2id_dir', default='../preprocess/docid2idx.json') 119 | 120 | # STAR+ADORE Top100 121 | parser.add_argument('--train100_dir', default='../preprocess/train.rank.tsv') 122 | parser.add_argument('--test100_dir', default='../preprocess/dev.rank.tsv') 123 | parser.add_argument('--dl100_dir', default='../preprocess/test.rank.tsv') 124 | 125 | if parse: 126 | kwargs = parser.parse_args() 127 | else: 128 | kwargs = parser.parse_known_args()[0] 129 | 130 | # Namespace => Dictionary 131 | kwargs = vars(kwargs) 132 | kwargs.update(optional_kwargs) 133 | 134 | return Config(**kwargs) -------------------------------------------------------------------------------- /pretrain/config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search 3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma. 
  4 | '''
  5 | # encoding: utf-8
  6 | import argparse
  7 | import pprint
  8 | import yaml
  9 |
 10 |
 11 | def str2bool(v):
 12 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
 13 |         return True
 14 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
 15 |         return False
 16 |     else:
 17 |         raise argparse.ArgumentTypeError('Boolean value expected.')
 18 |
 19 |
 20 | class Config(object):
 21 |     def __init__(self, **kwargs):
 22 |         """Configuration Class: set kwargs as class attributes with setattr"""
 23 |         for k, v in kwargs.items():
 24 |             setattr(self, k, v)
 25 |
 26 |     @property
 27 |     def config_str(self):
 28 |         return pprint.pformat(self.__dict__)
 29 |
 30 |     def __repr__(self):
 31 |         """Pretty-print configurations in alphabetical order"""
 32 |         config_str = 'Configurations\n'
 33 |         config_str += self.config_str
 34 |         return config_str
 35 |
 36 |     def save(self, path):
 37 |         with open(path, 'w') as f:
 38 |             yaml.dump(self.__dict__, f, default_flow_style=False)
 39 |
 40 |     @classmethod
 41 |     def load(cls, path):
 42 |         with open(path, 'r') as f:
 43 |             kwargs = yaml.load(f, Loader=yaml.FullLoader)
 44 |
 45 |         return Config(**kwargs)
 46 |
 47 |
 48 | def read_config(path):
 49 |     return Config.load(path)
 50 |
 51 |
 52 | def get_config(parse=True, **optional_kwargs):
 53 |     """
 54 |     Get configurations as attributes of class
 55 |     1. Parse configurations with argparse.
 56 |     2. Create Config class initialized with parsed kwargs.
 57 |     3. Return Config class.
 58 |     """
 59 |     parser = argparse.ArgumentParser()
 60 |
 61 |     # Training
 62 |     parser.add_argument('--epochs', type=int, default=1,
 63 |                         help='num_epochs')
 64 |     parser.add_argument('--batch_size', type=int, default=22,
 65 |                         help='batch size')
 66 |     parser.add_argument('--neg_docs_per_q', type=int, default=4,
 67 |                         help='number of sampled docs per q-d pair')
 68 |     parser.add_argument("--adam_epsilon", default=1e-8, type=float)
 69 |     parser.add_argument("--weight_decay", default=0.01, type=float)
 70 |     parser.add_argument('--lr', type=float, default=2e-5,
 71 |                         help='learning rate')
 72 |     parser.add_argument('--clip', type=float, default=1.0,
 73 |                         help='gradient clip norm')
 74 |     parser.add_argument('--warm_up', type=float, default=0.1,
 75 |                         help='warm up proportion')
 76 |     parser.add_argument('--gradient_checkpointing', action="store_true")
 77 |     parser.add_argument('--max_len', type=int, default=512,
 78 |                         help='max length')
 79 |     parser.add_argument('--max_q_len', type=int, default=40,
 80 |                         help='max query length')
 81 |     parser.add_argument('--model_name', type=str, default='ARES_simple',
 82 |                         help='the model name')
 83 |     parser.add_argument('--model_type', type=str, default='ARES',
 84 |                         choices=['ARES', 'ICT'],
 85 |                         help='the model type')
 86 |     parser.add_argument('--optim', type=str, default='adamw', choices=['adam', 'amsgrad', 'adagrad', 'adamw'],
 87 |                         help='optimizer')
 88 |     parser.add_argument('--dropout', type=float, default=0.2)
 89 |     parser.add_argument('--embed_dim', type=int, default=100)
 90 |     parser.add_argument('--freeze', type=bool, default=False)
 91 |     parser.add_argument('--world_size', type=int, default=4)
 92 |     parser.add_argument('--distributed_train', action="store_true")
 93 |     parser.add_argument('--gpu_num', type=int, default=1)
 94 |     parser.add_argument('--seed', type=int, default=42,
 95 |                         help='Random seed')
 96 |     parser.add_argument('--PRE_TRAINED_MODEL_NAME', default='/path/to/bert-base/',
 97 |                         help='huggingface model name')
 98 |     parser.add_argument('--load_ckpt', action="store_true",
 99 |                         help='whether to load a trained checkpoint')
100 |     parser.add_argument('--model_path', default='model_state_ARES',
101 |                         help='name of checkpoint to load')
102 |     parser.add_argument('--clf_model', default='/path/to/xgboost.model',
103 |                         help='the axiom classifier model path (xgboost)')
104 |     parser.add_argument('--MLM', action="store_true", help='whether to add MLM loss while pre-training')
105 |     parser.add_argument('--masked_lm_prob', default=0.15, help='only used when MLM is true')
106 |     parser.add_argument('--max_predictions_per_seq', default=60,
107 |                         help='only used when MLM is true')
108 |     parser.add_argument('--print_every', default=200)
109 |     parser.add_argument('--local_rank', type=int, default=0,
110 |                         help='node rank for distributed training')
111 |
112 |     # tricks
113 |     parser.add_argument('--gradient_accumulation_steps', type=int, default=4,
114 |                         help="Number of updates steps to accumulate before performing a backward/update pass.")
115 |
116 |     # human labels
117 |     parser.add_argument('--train_qd_dir', default='../preprocess/msmarco-doctrain-qrels.tsv')
118 |     parser.add_argument('--test_qd_dir', default='../preprocess/dev-qrels.txt')
119 |     parser.add_argument('--dl2019_qd_dir', default='../preprocess/2019qrels-docs.txt')
120 |
121 |     # queries
122 |     parser.add_argument('--train_qs_dir', default='../preprocess/queries.doctrain.json')
123 |     parser.add_argument('--test_qs_dir', default='../preprocess/queries.docdev.json')
124 |     parser.add_argument('--dl2019_qs_dir', default='../preprocess/queries.dl2019.json')
125 |
126 |     # docs
127 |     parser.add_argument('--memmap_doc_dir', default='../preprocess/doc_token_ids.memmap')
128 |     parser.add_argument('--docid2id_dir', default='../preprocess/docid2idx.json')
129 |
130 |     # STAR+ADORE Top100 candidates
131 |     parser.add_argument('--train100_dir', default='../preprocess/A+S_top100/train.rank.tsv')
132 |     parser.add_argument('--test100_dir', default='../preprocess/A+S_top100/dev.rank.tsv')
133 |     parser.add_argument('--dl100_dir', default='../preprocess/A+S_top100/test.rank.tsv')
134 |
135 |     # candidate queries and axioms
136 |     parser.add_argument('--doc2query_dir', default='../preprocess/doc2qs.json')
137 |     parser.add_argument('--gen_qs_memmap_dir', default='../preprocess/sample_qs_token_ids.memmap')
138 |     parser.add_argument('--gen_qid2id_dir', default='../preprocess/sample_qid2id.json')  # qid idx
139 |     parser.add_argument('--axiom', type=str, nargs='+',
140 |                         help="Basic axioms: [RANK, REP], Auxiliary axioms: [PROX, REG, STM], you should choose at least one basic axiom.", required=True)
141 |     parser.add_argument('--axiom_feature_dir', default='../preprocess/axioms')
142 |
143 |     if parse:
144 |         kwargs = parser.parse_args()
145 |     else:
146 |         kwargs = parser.parse_known_args()[0]
147 |
148 |     # Namespace => Dictionary
149 |     kwargs = vars(kwargs)
150 |     kwargs.update(optional_kwargs)
151 |
152 |     return Config(**kwargs)
--------------------------------------------------------------------------------
/finetune/ms_marco_eval.py:
--------------------------------------------------------------------------------
  1 | """
  2 | This module computes evaluation metrics for the MS MARCO dataset on the ranking task. Internal hard-coded eval files version. DO NOT PUBLISH!
  3 | Command line:
  4 | python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>
  5 | Creation Date : 06/12/2018
  6 | Last Modified : 4/09/2019
  7 | Authors : Daniel Campos <dacamp@microsoft.com>, Rutger van Haasteren <ruvanh@microsoft.com>
  8 | """
  9 |
 10 | import sys
 11 | import math
 12 | import numpy as np
 13 | from collections import Counter
 14 |
 15 | MaxMRRRank1 = 10
 16 | MaxMRRRank2 = 100
 17 |
 18 |
 19 | def load_reference_from_stream(f):
 20 |     """Load Reference reference relevant passages
 21 |     Args:f (stream): stream to load.
 22 |     Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints).
 23 |     """
 24 |     qids_to_relevant_passageids = {}
 25 |     for l in f:
 26 |         try:
 27 |             l = l.strip().split('\t')
 28 |             qid = int(l[0])
 29 |             if qid in qids_to_relevant_passageids:
 30 |                 pass
 31 |             else:
 32 |                 qids_to_relevant_passageids[qid] = []
 33 |             qids_to_relevant_passageids[qid].append(l[1])
 34 |         except:
 35 |             raise IOError('\"%s\" is not valid format' % l)
 36 |     return qids_to_relevant_passageids
 37 |
 38 |
 39 | def load_reference(path_to_reference):
 40 |     """Load Reference reference relevant passages
 41 |     Args:path_to_reference (str): path to a file to load.
 42 |     Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints).
 43 |     """
 44 |     with open(path_to_reference, 'r') as f:
 45 |         qids_to_relevant_passageids = load_reference_from_stream(f)
 46 |     return qids_to_relevant_passageids
 47 |
 48 |
 49 | def load_candidate_from_stream(f):
 50 |     """Load candidate data from a stream.
 51 |     Args:f (stream): stream to load.
 52 |     Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids(int) ranked by relevance and importance
 53 |     """
 54 |     qid_to_ranked_candidate_passages = {}
 55 |     for l in f:
 56 |         try:
 57 |             l = l.strip().split('\t')
 58 |             qid = int(l[0])
 59 |             pid = l[1]
 60 |             rank = int(l[2])
 61 |             if qid in qid_to_ranked_candidate_passages:
 62 |                 pass
 63 |             else:
 64 |                 # By default, all PIDs in the list of 1000 are 0. Only override those that are given
 65 |                 tmp = [0] * 1000
 66 |                 qid_to_ranked_candidate_passages[qid] = tmp
 67 |             qid_to_ranked_candidate_passages[qid][rank - 1] = pid
 68 |         except:
 69 |             raise IOError('\"%s\" is not valid format' % l)
 70 |     return qid_to_ranked_candidate_passages
 71 |
 72 |
 73 | def load_candidate(path_to_candidate):
 74 |     """Load candidate data from a file.
 75 |     Args:path_to_candidate (str): path to file to load.
 76 |     Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids(int) ranked by relevance and importance
 77 |     """
 78 |
 79 |     with open(path_to_candidate, 'r') as f:
 80 |         qid_to_ranked_candidate_passages = load_candidate_from_stream(f)
 81 |     return qid_to_ranked_candidate_passages
 82 |
 83 |
 84 | def quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
 85 |     """Perform quality checks on the dictionaries
 86 |     Args:
 87 |     p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping
 88 |         Dict as read in with load_reference or load_reference_from_stream
 89 |     p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates
 90 |     Returns:
 91 |         bool,str: Boolean whether allowed, message to be shown in case of a problem
 92 |     """
 93 |     message = ''
 94 |     allowed = True
 95 |
 96 |     # Create sets of the QIDs for the submitted and reference queries
 97 |     candidate_set = set(qids_to_ranked_candidate_passages.keys())
 98 |     ref_set = set(qids_to_relevant_passageids.keys())
 99 |
100 |     # Check that we do not have multiple passages per query
101 |     for qid in qids_to_ranked_candidate_passages:
102 |         # Remove all zeros from the candidates
103 |         duplicate_pids = set(
104 |             [item for item, count in Counter(qids_to_ranked_candidate_passages[qid]).items() if count > 1])
105 |
106 |         if len(duplicate_pids - set([0])) > 0:
107 |             message = "Cannot rank a passage multiple times for a single query. QID={qid}, PID={pid}".format(
108 |                 qid=qid, pid=list(duplicate_pids)[0])
109 |             allowed = False
110 |
111 |     return allowed, message
112 |
113 |
114 | def compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
115 |     """Compute MRR metric
116 |     Args:
117 |     p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping
118 |         Dict as read in with load_reference or load_reference_from_stream
119 |     p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates
120 |     Returns:
121 |         dict: dictionary of metrics {'MRR': <MRR Score>}
122 |     """
123 |     all_scores = {}
124 |     MRR_10, MRR_100 = 0, 0
125 |     qids_with_relevant_passages = 0
126 |     ranking = []
127 |     for qid in qids_to_ranked_candidate_passages:
128 |         if qid in qids_to_relevant_passageids:
129 |             ranking.append(0)
130 |             target_pid = qids_to_relevant_passageids[qid]
131 |             candidate_pid = qids_to_ranked_candidate_passages[qid]
132 |             for i in range(0, MaxMRRRank1):
133 |                 if candidate_pid[i] in target_pid:
134 |                     MRR_10 += 1 / (i + 1)
135 |                     ranking.pop()
136 |                     ranking.append(i + 1)
137 |                     break
138 |             for i in range(0, MaxMRRRank2):
139 |                 if candidate_pid[i] in target_pid:
140 |                     MRR_100 += 1 / (i + 1)
141 |                     break
142 |     if len(ranking) == 0:
143 |         raise IOError("No matching QIDs found. Are you sure you are scoring the evaluation set?")
144 |
145 |     MRR_10 = MRR_10 / len(qids_to_relevant_passageids)
146 |     MRR_100 = MRR_100 / len(qids_to_relevant_passageids)
147 |     all_scores['MRR @10'] = MRR_10
148 |     all_scores['MRR @100'] = MRR_100
149 |     all_scores['QueriesRanked'] = len(qids_to_ranked_candidate_passages)
150 |     return all_scores
151 |
152 |
153 | def compute_metrics_from_files(path_to_reference, path_to_candidate, perform_checks=True):
154 |     """Compute MRR metric
155 |     Args:
156 |     p_path_to_reference_file (str): path to reference file.
157 |         Reference file should contain lines in the following format:
158 |             QUERYID\tPASSAGEID
159 |             Where PASSAGEID is a relevant passage for a query. Note QUERYID can repeat on different lines with different PASSAGEIDs
160 |     p_path_to_candidate_file (str): path to candidate file.
161 |         Candidate file should contain lines in the following format:
162 |             QUERYID\tPASSAGEID1\tRank
163 |         If a user wishes to use the TREC format please run the script with a -t flag at the end. If this flag is used the expected format is
164 |             QUERYID\tITER\tDOCNO\tRANK\tSIM\tRUNID
165 |         Where the values are separated by tabs and ranked in order of relevance
166 |     Returns:
167 |         dict: dictionary of metrics {'MRR': <MRR Score>}
168 |     """
169 |
170 |     qids_to_relevant_passageids = load_reference(path_to_reference)
171 |     qids_to_ranked_candidate_passages = load_candidate(path_to_candidate)
172 |     if perform_checks:
173 |         allowed, message = quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
174 |         if message != '': print(message)
175 |
176 |     return compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
177 |
178 |
179 | def main():
180 |     """Command line:
181 |     python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>
182 |     """
183 |     path_to_candidate = sys.argv[2]
184 |     path_to_reference = sys.argv[1]
185 |     metrics = compute_metrics_from_files(path_to_reference, path_to_candidate)
186 |     print('#####################')
187 |     for metric in sorted(metrics):
188 |         print('{}: {}'.format(metric, metrics[metric]))
189 |     print('#####################')
190 |
191 |
192 | if __name__ == '__main__':
193 |     main()
--------------------------------------------------------------------------------
/finetune/dataloader.py:
--------------------------------------------------------------------------------
  1 | '''
  2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search
  3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma.
4 | ''' 5 | # encoding: utf-8 6 | import random 7 | import numpy as np 8 | from tqdm import tqdm 9 | from torch.utils.data import Dataset, DataLoader 10 | from torch.utils.data.distributed import DistributedSampler 11 | 12 | 13 | class TrainQDDatasetPairwise(Dataset): 14 | def __init__(self, q_ids, d_ids, q_dict, d_dict, did2idx, config, labels, mode='train'): 15 | self.q_ids = q_ids 16 | self.d_ids = d_ids 17 | self.q_dict = q_dict 18 | self.d_dict = d_dict 19 | self.did2idx = did2idx 20 | self.labels = labels 21 | self.mode = mode 22 | self.config = config 23 | 24 | def __len__(self): 25 | return len(self.q_ids) 26 | 27 | def __getitem__(self, item): 28 | cls_id, sep_id = 101, 102 29 | q_id = self.q_ids[item] 30 | d_id = self.d_ids[item] 31 | 32 | q_id = q_id[0] 33 | pos_did, neg_did = d_id[0], d_id[1] 34 | 35 | query_input_ids, pos_doc_input_ids, neg_doc_input_ids = self.q_dict[str(q_id)], self.d_dict[self.did2idx[pos_did]].tolist(), \ 36 | self.d_dict[self.did2idx[neg_did]].tolist() 37 | query_input_ids = query_input_ids[: self.config.max_q_len] 38 | max_passage_length = self.config.max_len - 3 - len(query_input_ids) 39 | 40 | pos_doc_input_ids = pos_doc_input_ids[:max_passage_length] 41 | neg_doc_input_ids = neg_doc_input_ids[:max_passage_length] 42 | 43 | pos_input_ids = [cls_id] + query_input_ids + [sep_id] + pos_doc_input_ids + [sep_id] 44 | neg_input_ids = [cls_id] + query_input_ids + [sep_id] + neg_doc_input_ids + [sep_id] 45 | 46 | pos_token_type_ids = [0] * (2 + len(query_input_ids)) + [1] * (1 + len(pos_doc_input_ids)) 47 | neg_token_type_ids = [0] * (2 + len(query_input_ids)) + [1] * (1 + len(neg_doc_input_ids)) 48 | 49 | pos_token_ids = np.array(pos_input_ids) 50 | neg_token_ids = np.array(neg_input_ids) 51 | token_ids = np.stack((pos_token_ids.flatten(), neg_token_ids.flatten())) 52 | 53 | pos_attention_mask = np.int64(pos_token_ids > 0) 54 | neg_attention_mask = np.int64(neg_token_ids > 0) 55 | attention_mask = np.stack((pos_attention_mask, neg_attention_mask)) 56 | 57 | pos_token_type_ids = np.array(pos_token_type_ids) 58 | neg_token_type_ids = np.array(neg_token_type_ids) 59 | token_type_ids = np.stack((pos_token_type_ids, neg_token_type_ids)) 60 | 61 | return { 62 | 'token_ids': token_ids, 63 | 'attention_mask': attention_mask, 64 | 'token_type_ids': token_type_ids, 65 | } 66 | 67 | 68 | class TestQDDataset(Dataset): 69 | def __init__(self, q_ids, d_ids, token_ids, attention_mask, token_type_ids, mode='test'): 70 | self.q_ids = q_ids 71 | self.d_ids = d_ids 72 | self.token_ids = token_ids 73 | self.attention_mask = attention_mask 74 | self.token_type_ids= token_type_ids 75 | self.mode = mode 76 | 77 | def __len__(self): 78 | return len(self.q_ids) 79 | 80 | def __getitem__(self, item): 81 | q_id = self.q_ids[item] 82 | d_id = self.d_ids[item] 83 | token_ids = np.array(self.token_ids[item]) 84 | attention_mask = np.array(self.attention_mask[item]) 85 | token_type_ids = np.array(self.token_type_ids[item]) 86 | 87 | return { 88 | "q_id": q_id, 89 | "d_id": d_id, 90 | 'token_ids': token_ids.flatten(), 91 | 'attention_mask': attention_mask.flatten(), 92 | 'token_type_ids': token_type_ids.flatten(), 93 | } 94 | 95 | 96 | # [CLS] q [SEP] d [SEP] 97 | def get_train_qd_loader(df_qds, train_top100, q_dict, d_dict, did2idx, config, mode='train'): 98 | q_max_len, max_len, batch_size = config.max_q_len, config.max_len, config.batch_size 99 | q_ids = df_qds[0].values.tolist() 100 | d_ids = df_qds[2].values.tolist() 101 | 102 | qd_dict = {} 103 | for q_id, d_id in zip(q_ids, d_ids): 
104 | if q_id not in qd_dict: 105 | qd_dict[q_id] = [] 106 | qd_dict[q_id].append(d_id) 107 | 108 | top100_dict = {} 109 | top_qids = train_top100[0].values.tolist() 110 | top_dids = train_top100[1].values.tolist() 111 | for qid, did in zip(top_qids, top_dids): 112 | if qid not in top100_dict: 113 | top100_dict[qid] = [] 114 | top100_dict[qid].append(did) 115 | 116 | new_q_ids, new_d_ids, labels = [], [], [] 117 | 118 | q_num = len(q_ids) 119 | for idx in tqdm(range(q_num), desc=f"Loading train q-d progress"): 120 | this_qid = q_ids[idx] 121 | neg_cands = set(top100_dict[this_qid]) - set(qd_dict[this_qid]) 122 | neg_cands = list(neg_cands) 123 | neg_dids = random.sample(neg_cands, config.neg_docs_per_q) 124 | for i in range(config.neg_docs_per_q): 125 | new_q_ids.append([this_qid]) 126 | new_d_ids.append([d_ids[idx], neg_dids[i]]) 127 | labels.append([1, 0]) 128 | 129 | print('Loading tokens...') 130 | ds = TrainQDDatasetPairwise( 131 | q_ids=new_q_ids, 132 | d_ids=new_d_ids, 133 | q_dict=q_dict, 134 | d_dict=d_dict, 135 | did2idx=did2idx, 136 | config=config, 137 | labels=labels, 138 | mode='train' 139 | ) 140 | batch_size = batch_size // 2 141 | 142 | if config.distributed_train: 143 | sampler = DistributedSampler(ds, num_replicas=config.world_size, rank=config.local_rank) 144 | return DataLoader( 145 | ds, 146 | batch_size=batch_size, 147 | num_workers=0, 148 | sampler=sampler 149 | ) 150 | else: 151 | if mode == 'train': 152 | return DataLoader( 153 | ds, 154 | batch_size=batch_size, 155 | num_workers=0, 156 | shuffle=True, 157 | ) 158 | else: 159 | return DataLoader( 160 | ds, 161 | batch_size=batch_size, 162 | num_workers=0, 163 | shuffle=False, 164 | ) 165 | 166 | 167 | def get_test_qd_loader(top100qd, q_dict, d_dict, did2idx, config): 168 | cls_id, sep_id = 101, 102 169 | q_ids = top100qd[0].values.tolist() 170 | d_ids = top100qd[1].values.tolist() 171 | 172 | qd_dict = {} 173 | for q_id, d_id in zip(q_ids, d_ids): 174 | if q_id not in qd_dict: 175 | qd_dict[q_id] = [] 176 | qd_dict[q_id].append(d_id) 177 | 178 | q_num = len(q_dict) 179 | qids = list(set(q_dict.keys())) 180 | tokens_np = np.zeros((q_num * 100, config.max_len), dtype='int32') # (q_num * 100) x 512 181 | token_type_np = np.zeros((q_num * 100, config.max_len), dtype='int32') # (q_num * 100) x 512 182 | 183 | new_q_ids, new_d_ids = [], [] 184 | for idx in tqdm(range(len(qids)), desc=f"Loading test q-d pair progress"): 185 | this_qid = qids[idx] 186 | 187 | query_input_ids = q_dict[str(this_qid)] 188 | query_input_ids = query_input_ids[: config.max_q_len] 189 | max_passage_length = config.max_len - 3 - len(query_input_ids) 190 | 191 | dids = qd_dict[int(this_qid)] 192 | assert len(dids) == 100 193 | for rank in range(len(dids)): 194 | this_did = dids[rank] 195 | doc_input_ids = d_dict[did2idx[this_did]].tolist() 196 | doc_input_ids = doc_input_ids[:max_passage_length] 197 | input_ids = [cls_id] + query_input_ids + [sep_id] + doc_input_ids + [sep_id] 198 | token_type_ids = [0] * (2 + len(query_input_ids)) + [1] * (1 + len(doc_input_ids)) 199 | cat_len = min(len(input_ids), config.max_len) 200 | 201 | new_q_ids.append(this_qid) 202 | new_d_ids.append(this_did) 203 | tokens_np[idx * 100 + rank, :cat_len] = np.array(input_ids) 204 | token_type_np[idx * 100 + rank, :cat_len] = np.array(token_type_ids) 205 | 206 | attention_mask = np.int64(tokens_np > 0).tolist() 207 | tokens = tokens_np.tolist() # q_num x 512 208 | token_type = token_type_np.tolist() # q_num x 512 209 | 210 | ds = TestQDDataset( 211 | q_ids=new_q_ids, 212 | 
d_ids=new_d_ids, 213 | token_ids=tokens, 214 | token_type_ids=token_type, 215 | attention_mask=attention_mask, 216 | mode='test' 217 | ) 218 | 219 | return DataLoader( 220 | ds, 221 | batch_size=100, # 100 docs per q 222 | num_workers=0, 223 | shuffle=False, 224 | ) 225 | 226 | -------------------------------------------------------------------------------- /visualization/dataloader.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search 3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma. 4 | ''' 5 | import numpy as np 6 | from tqdm import tqdm 7 | from torch.utils.data import Dataset, DataLoader 8 | from torch.utils.data.distributed import DistributedSampler 9 | 10 | 11 | class TestQDDataset(Dataset): 12 | def __init__(self, q_ids, d_ids, token_ids, attention_mask, token_type_ids, mode='test'): 13 | self.q_ids = q_ids 14 | self.d_ids = d_ids 15 | self.token_ids = token_ids 16 | self.attention_mask = attention_mask 17 | self.token_type_ids= token_type_ids 18 | self.mode = mode 19 | 20 | def __len__(self): 21 | return len(self.q_ids) 22 | 23 | def __getitem__(self, item): 24 | q_id = self.q_ids[item] 25 | d_id = self.d_ids[item] 26 | token_ids = np.array(self.token_ids[item]) 27 | attention_mask = np.array(self.attention_mask[item]) 28 | token_type_ids = np.array(self.token_type_ids[item]) 29 | 30 | return { 31 | "q_id": q_id, 32 | "d_id": d_id, 33 | 'token_ids': token_ids.flatten(), 34 | 'attention_mask': attention_mask.flatten(), 35 | 'token_type_ids': token_type_ids.flatten(), 36 | } 37 | 38 | 39 | class VisualTestQDDataset(Dataset): 40 | def __init__(self, q_ids, d_ids, ranks,token_ids, ref_token_ids, token_type_ids, ref_token_type_ids, attention_mask, mode='test'): 41 | self.q_ids = q_ids 42 | self.d_ids = d_ids 43 | self.ranks = ranks 44 | self.token_ids = token_ids 45 | self.ref_token_ids=ref_token_ids 46 | self.token_type_ids = token_type_ids 47 | self.ref_token_type_ids = ref_token_type_ids 48 | self.attention_mask = attention_mask 49 | self.mode = mode 50 | 51 | def __len__(self): 52 | return len(self.q_ids) 53 | 54 | def __getitem__(self, item): 55 | q_id = self.q_ids[item] 56 | d_id = self.d_ids[item] 57 | rank = self.ranks[item] 58 | token_ids = np.array(self.token_ids[item]) 59 | attention_mask = np.array(self.attention_mask[item]) 60 | token_type_ids = np.array(self.token_type_ids[item]) 61 | ref_token_ids=np.array(self.ref_token_ids[item]) 62 | ref_token_type_ids=np.array(self.ref_token_type_ids[item]) 63 | return { 64 | "q_id": q_id, 65 | "d_id": d_id, 66 | "rank": rank, 67 | 'token_ids': token_ids.flatten(), 68 | 'attention_mask': attention_mask.flatten(), 69 | 'token_type_ids': token_type_ids.flatten(), 70 | 'ref_token_ids': ref_token_ids.flatten(), 71 | 'ref_token_type_ids': ref_token_type_ids.flatten() 72 | } 73 | 74 | 75 | def get_test_qd_loader(top100qd, q_dict, d_dict, did2idx, config): 76 | cls_id, sep_id = 101, 102 77 | q_ids = top100qd[0].values.tolist() 78 | d_ids = top100qd[1].values.tolist() 79 | 80 | qd_dict = {} 81 | for q_id, d_id in zip(q_ids, d_ids): 82 | if q_id not in qd_dict: 83 | qd_dict[q_id] = [] 84 | qd_dict[q_id].append(d_id) 85 | 86 | q_num = len(q_dict) 87 | qids = list(set(q_dict.keys())) 88 | tokens_np = np.zeros((q_num * 100, config.max_len), dtype='int32') # (q_num * 100) x 512 89 | token_type_np = np.zeros((q_num * 100, config.max_len), dtype='int32') # (q_num * 100) x 512 
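    # Row layout (see the fill loop below): row (idx * 100 + rank) holds the
    # [CLS] q [SEP] d [SEP] ids of the rank-th candidate for the idx-th query;
    # token_type marks the document segment with 1s (BERT segment B).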
90 | 91 | new_q_ids, new_d_ids = [], [] 92 | for idx in tqdm(range(len(qids)), desc=f"Loading test q-d pair progress"): 93 | this_qid = qids[idx] 94 | 95 | query_input_ids = q_dict[str(this_qid)] 96 | query_input_ids = query_input_ids[: config.max_q_len] 97 | max_passage_length = config.max_len - 3 - len(query_input_ids) 98 | 99 | dids = qd_dict[int(this_qid)] 100 | assert len(dids) == 100 101 | for rank in range(len(dids)): 102 | this_did = dids[rank] 103 | doc_input_ids = d_dict[did2idx[this_did]].tolist() 104 | doc_input_ids = doc_input_ids[:max_passage_length] 105 | input_ids = [cls_id] + query_input_ids + [sep_id] + doc_input_ids + [sep_id] 106 | token_type_ids = [0] * (2 + len(query_input_ids)) + [1] * (1 + len(doc_input_ids)) 107 | cat_len = min(len(input_ids), config.max_len) 108 | 109 | new_q_ids.append(this_qid) 110 | new_d_ids.append(this_did) 111 | tokens_np[idx * 100 + rank, :cat_len] = np.array(input_ids) 112 | token_type_np[idx * 100 + rank, :cat_len] = np.array(token_type_ids) 113 | 114 | attention_mask = np.int64(tokens_np > 0).tolist() 115 | tokens = tokens_np.tolist() # q_num x 512 116 | token_type = token_type_np.tolist() # q_num x 512 117 | 118 | ds = TestQDDataset( 119 | q_ids=new_q_ids, 120 | d_ids=new_d_ids, 121 | token_ids=tokens, 122 | token_type_ids=token_type, 123 | attention_mask=attention_mask, 124 | mode='test' 125 | ) 126 | 127 | return DataLoader( 128 | ds, 129 | batch_size=100, # 100 docs per q 130 | num_workers=0, 131 | shuffle=False, 132 | ) 133 | 134 | 135 | def get_visual_test_qd_loader(top100qd, q_dict, d_dict, did2idx, config): 136 | cls_id, sep_id, pad_id = 101, 102, 0 137 | q_ids = top100qd["q_id"].values.tolist() 138 | d_ids = top100qd["d_id"].values.tolist() 139 | ranks = top100qd["rank"].values.tolist() 140 | d_num = config.visual_d_num 141 | q_num = config.visual_q_num 142 | qd_dict = {} 143 | for q_id, d_id, rank in zip(q_ids, d_ids,ranks): 144 | if q_id not in qd_dict: 145 | qd_dict[q_id] = [] 146 | qd_dict[q_id].append([d_id,rank]) 147 | 148 | qids = list(q_dict.keys())[:q_num] 149 | tokens_np = np.zeros((q_num * d_num, config.max_len), dtype='int32') # (q_num * d_num) x 512 150 | token_type_np = np.zeros((q_num * d_num, config.max_len), dtype='int32') # (q_num * d_num) x 512 151 | ref_tokens_np = np.zeros((q_num * d_num, config.max_len), dtype='int32') # (q_num * d_num) x 512 152 | ref_token_type_np = np.zeros((q_num * d_num, config.max_len), dtype='int32') # (q_num * d_num) x 512 153 | 154 | new_q_ids, new_d_ids, new_ranks = [], [], [] 155 | for idx in tqdm(range(len(qids)), desc=f"Loading test q-d pair progress"): 156 | this_qid = qids[idx] 157 | query_input_ids = q_dict[str(this_qid)] 158 | query_input_ids = query_input_ids[: config.max_q_len] 159 | max_passage_length = config.max_len - 3 - len(query_input_ids) 160 | 161 | did_ranks = qd_dict[str(this_qid)][:d_num] 162 | assert len(did_ranks) == d_num 163 | for rank in range(len(did_ranks)): 164 | this_did,this_rank = did_ranks[rank] 165 | doc_input_ids = d_dict[did2idx[this_did]].tolist() 166 | doc_input_ids = doc_input_ids[:max_passage_length] 167 | input_ids = [cls_id] + query_input_ids + [sep_id] + doc_input_ids + [sep_id] 168 | ref_input_ids = [cls_id] + [pad_id] * len(query_input_ids) + [sep_id] + [pad_id] * len(doc_input_ids) + [sep_id] 169 | token_type_ids = [0] * (2 + len(query_input_ids)) + [1] * (1 + len(doc_input_ids)) 170 | ref_token_type_ids = [0] * len(token_type_ids) 171 | cat_len = min(len(input_ids), config.max_len) 172 | 173 | new_q_ids.append(this_qid) 174 | 
new_d_ids.append(this_did) 175 | new_ranks.append(this_rank) 176 | tokens_np[idx * d_num + rank, :cat_len] = np.array(input_ids) 177 | token_type_np[idx * d_num + rank, :cat_len] = np.array(token_type_ids) 178 | ref_tokens_np[idx * d_num + rank, :cat_len] = np.array(ref_input_ids) 179 | ref_token_type_np[idx * d_num + rank, :cat_len] = np.array(ref_token_type_ids) 180 | attention_mask = np.int64(tokens_np > 0).tolist() 181 | tokens = tokens_np.tolist() # q_num x 512 182 | token_type = token_type_np.tolist() # q_num x 512 183 | ref_tokens = ref_tokens_np.tolist() # q_num x 512 184 | ref_token_type = ref_token_type_np.tolist() # q_num x 512 185 | 186 | ds = VisualTestQDDataset( 187 | q_ids=new_q_ids, 188 | d_ids=new_d_ids, 189 | ranks=new_ranks, 190 | token_ids=tokens, 191 | ref_token_ids=ref_tokens, 192 | token_type_ids=token_type, 193 | ref_token_type_ids=ref_token_type, 194 | attention_mask=attention_mask, 195 | mode='test' 196 | ) 197 | 198 | return DataLoader( 199 | ds, 200 | batch_size=config.batch_size, # 100 docs per q 201 | num_workers=0, 202 | shuffle=False, 203 | ) 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![img](imgs/ARES_black.jpg) 2 | 3 |

4 | 5 | 6 | THUIR 7 | 8 | 9 | License 10 | 11 | 12 | made-with-python 13 | 14 | 15 | code-size 16 | 17 | 18 |

19 | 
20 | ## Introduction
21 | This codebase contains the source code of ARES, a Python-based implementation of our SIGIR 2022 paper.
22 | - [Chen, Jia, et al. "Axiomatically Regularized Pre-training for Ad hoc Search." To Appear in the Proceedings of the 45th International ACM SIGIR Conference on Research and Development in Information Retrieval. 2022.](https://xuanyuan14.github.io/files/SIGIR22Chen.pdf)
23 | 
24 | ## Requirements
25 | * python 3.7
26 | * torch==1.9.0
27 | * transformers==4.9.2
28 | * tqdm, nltk, numpy, boto3
29 | * [trec_eval](https://github.com/usnistgov/trec_eval) for evaluation on TREC DL 2019
30 | * [anserini](https://github.com/castorini/anserini) for generating "RANK" axiom scores
31 | 
32 | ## Why this repo?
33 | In this repo, you can pre-train ARESsimple and TransformerICT models, and fine-tune all pre-trained models with the same architecture as BERT. The papers are listed as follows:
34 | * BERT ([Bert: Pre-training of deep bidirectional transformers for language understanding](https://arxiv.org/pdf/1810.04805.pdf))
35 | * TransformerICT ([Latent Retrieval for Weakly Supervised Open Domain Question Answering.](https://arxiv.org/pdf/1906.00300))
36 | * PROP ([PROP: Pre-training with representative words prediction for ad-hoc retrieval.](https://dl.acm.org/doi/pdf/10.1145/3437963.3441777))
37 | * ARES ([Axiomatically Regularized Pre-training for Ad hoc Search.](https://xuanyuan14.github.io/files/SIGIR22Chen.pdf))
38 | 
39 | You can download the pre-trained ARES checkpoint [ARESsimple](https://drive.google.com/file/d/1QvJ-hs6VtK4nlrlFkzPZAXfTtY-QjTiU/view?usp=sharing) from Google Drive and extract it.
40 | 
41 | ## Pre-training Data
42 | 
43 | ### Download data
44 | Download the **MS MARCO** corpus from the official [website](https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docs.tsv.gz).
45 | Download the **ADORE+STAR Top100 Candidates** files from this [repo](https://github.com/jingtaozhan/DRhard).
46 | 
47 | ### Pre-process data
48 | To save memory, we store most files using the numpy `memmap` or `jsonl` format in the `./preprocess` directory.
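These files can be read back with plain `numpy` and `json`. The following minimal sketch (not a script shipped in this repo) illustrates the pattern; the `int32` dtype and 512-token row width follow the `np.memmap` calls in `pretrain/train.py` and `visualization/visual.py`, and the file names follow the lists below.

```python
import json
import numpy as np

# docid -> row index of the document memmap
with open("preprocess/docid2idx.json") as f:
    docid2idx = json.load(f)

# one row of 512 token ids per document
doc_tokens = np.memmap("preprocess/doc_token_ids.memmap",
                       dtype="int32", shape=(len(docid2idx), 512))

# queries are stored as jsonl, one {"id": qid, "ids": token_ids} object per line
queries = {}
with open("preprocess/queries.doctrain.jsonl") as f:
    for line in f:
        record = json.loads(line)
        queries[record["id"]] = record["ids"]

some_docid = next(iter(docid2idx))
print(doc_tokens[docid2idx[some_docid]][:10])  # first 10 token ids of that document
```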
49 | 
50 | Document files:
51 | * `doc_token_ids.memmap`: each line is the token ids for a document
52 | * `docid2idx.json`: `{docid: memmap_line_id}`
53 | 
54 | Query files:
55 | * `queries.doctrain.jsonl`: MS MARCO training queries, `{"id": qid, "ids": token_ids}` for each line
56 | * `queries.docdev.jsonl`: MS MARCO validation queries, `{"id": qid, "ids": token_ids}` for each line
57 | * `queries.dl2019.jsonl`: TREC DL 2019 queries, `{"id": qid, "ids": token_ids}` for each line
58 | 
59 | Human label files:
60 | * `msmarco-doctrain-qrels.tsv`: `qid 0 docid 1` for the training set
61 | * `dev-qrels.txt`: `qid relevant_docid` for the validation set
62 | * `2019qrels-docs.txt`: `qid relevant_docid` for the TREC DL 2019 set
63 | 
64 | Top 100 candidate files:
65 | * `train.rank.tsv`, `dev.rank.tsv`, `test.rank.tsv`: `qid docid rank` for each line
66 | 
67 | Pseudo queries and axiomatic features:
68 | * `doc2qs.jsonl`: `{"docid": docid, "queries": [qids]}` for each line
69 | * `sample_qs_token_ids.memmap`: each line is the token ids for a pseudo query
70 | * `sample_qid2id.json`: `{qid: memmap_line_id}`
71 | * `axiom.memmap`: axiom can be one of `['rank', 'prox-1', 'prox-2', 'rep-ql', 'rep-tfidf', 'reg', 'stm-1', 'stm-2', 'stm-3']`, each line is an axiomatic score for a query
72 | 
73 | 
74 | ## Quick Start
75 | 
76 | ### Example Usage
77 | ```python
78 | from model.modeling import ARESReranker
79 | 
80 | model = ARESReranker.from_pretrained(model_path).to(device)
81 | 
82 | query1 = "What is the best way to get to the airport"
83 | query2 = "what do you like to eat?"
84 | 
85 | doc1 = "The best way to get to the airport is to take the bus"
86 | doc2 = "I like to eat apples"
87 | 
88 | qd_pairs = [
89 |     (query1, doc1), (query1, doc2),
90 |     (query2, doc1), (query2, doc2)
91 | ]
92 | 
93 | scores = model.score(qd_pairs)
94 | ```
95 | 
96 | You will get:
97 | ```bash
98 | scores: [ 41.60 -33.66
99 |          -38.00  30.03 ]
100 | ```
101 | 
102 | Note that to accelerate training, we adopt distributed data-parallel training. The scripts for pre-training and fine-tuning are as follows:
103 | 
104 | ### Pre-training
105 | 
106 | ```shell
107 | export BERT_DIR=/path/to/bert-base/
108 | export XGB_DIR=/path/to/xgboost.model
109 | 
110 | cd pretrain
111 | 
112 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 NCCL_BLOCKING_WAIT=1 \
113 | python -m torch.distributed.launch --nproc_per_node=6 --nnodes=1 train.py \
114 |     --model_type ARES \
115 |     --PRE_TRAINED_MODEL_NAME $BERT_DIR \
116 |     --gpu_num 6 --world_size 6 \
117 |     --MLM --axiom REP RANK REG PROX STM \
118 |     --clf_model $XGB_DIR
119 | ```
120 | Here the model type can be `ARES` or `ICT`.
121 | 
122 | ### Zero-shot evaluation (based on ADORE+STAR top100)
123 | ```shell
124 | export MODEL_DIR=/path/to/ares-simple/
125 | export CKPT_NAME=ares.ckpt
126 | 
127 | cd finetune
128 | 
129 | CUDA_VISIBLE_DEVICES=0 python train.py \
130 |     --test \
131 |     --PRE_TRAINED_MODEL_NAME $MODEL_DIR \
132 |     --model_type ARES \
133 |     --model_name ARES_simple \
134 |     --load_ckpt \
135 |     --model_path $CKPT_NAME
136 | ```
137 | You will get:
138 | ```bash
139 | #####################
140 | <----- MS Dev ----->
141 | MRR @10: 0.2991
142 | MRR @100: 0.3130
143 | QueriesRanked: 5193
144 | #####################
145 | ```
146 | on the MS MARCO dev set, and:
147 | ```bash
148 | #############################
149 | <--------- DL 2019 --------->
150 | QueriesRanked: 43
151 | nDCG @10: 0.5955
152 | nDCG @100: 0.4863
153 | #############################
154 | ```
155 | on the DL 2019 set.
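For reference, the MRR numbers above follow the standard MS MARCO definition: the reciprocal rank of the first relevant document within the top k, averaged over all queries. The sketch below is a minimal illustration with hypothetical `run`/`qrels` dicts; the repo's `finetune/ms_marco_eval.py` is the authoritative implementation.

```python
def mrr_at_k(run, qrels, k=10):
    """run: {qid: [docid, ...] ranked best-first}; qrels: {qid: set of relevant docids}."""
    total = 0.0
    for qid, ranked_docs in run.items():
        for rank, docid in enumerate(ranked_docs[:k], start=1):
            if docid in qrels.get(qid, set()):
                total += 1.0 / rank
                break  # only the first relevant hit counts
    return total / len(run)
```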
156 | 
157 | ### Fine-tuning
158 | ```shell
159 | export MODEL_DIR=/path/to/ares-simple/
160 | 
161 | cd finetune
162 | 
163 | CUDA_VISIBLE_DEVICES=0,1,2,3 NCCL_BLOCKING_WAIT=1 \
164 | python -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 train.py \
165 |     --model_type ARES \
166 |     --distributed_train \
167 |     --PRE_TRAINED_MODEL_NAME $MODEL_DIR \
168 |     --gpu_num 4 --world_size 4 \
169 |     --model_name ARES_simple
170 | ```
171 | 
172 | ### Visualization
173 | ```shell
174 | export MODEL_DIR=/path/to/ares-simple/
175 | export SAVE_DIR=/path/to/output/
176 | export CKPT_NAME=ares.ckpt
177 | 
178 | cd visualization
179 | 
180 | CUDA_VISIBLE_DEVICES=0 python visual.py \
181 |     --PRE_TRAINED_MODEL_NAME $MODEL_DIR \
182 |     --model_name ARES_simple \
183 |     --visual_q_num 1 \
184 |     --visual_d_num 5 \
185 |     --save_path $SAVE_DIR \
186 |     --model_path $CKPT_NAME
187 | ```
188 | 
189 | ## Results
190 | Zero-shot performance:
191 | 
192 | | Model Name | MS MARCO MRR@10 | MS MARCO MRR@100 | DL NDCG@10 | DL NDCG@100 | COVID | EQ |
193 | | :--: | :--: | :--: | :--: | :--: | :--: | :--: |
194 | | BM25 | 0.2962 | 0.3107 | 0.5776 | 0.4795 | 0.4857 | 0.6690 |
195 | | BERT | 0.1820 | 0.2012 | 0.4059 | 0.4198 | 0.4314 | 0.6055 |
196 | | PROPwiki | 0.2429 | 0.2596 | 0.5088 | 0.4525 | 0.4857 | 0.5991 |
197 | | PROPmarco | 0.2763 | 0.2914 | 0.5317 | 0.4623 | 0.4829 | 0.6454 |
198 | | ARESstrict | 0.2630 | 0.2785 | 0.4942 | 0.4504 | 0.4786 | 0.6923 |
199 | | AREShard | 0.2627 | 0.2780 | 0.5189 | 0.4613 | 0.4943 | 0.6822 |
200 | | ARESsimple | 0.2991 | 0.3130 | 0.5955 | 0.4863 | 0.4957 | 0.6916 |
201 | 
202 | 
203 | Few-shot performance:
204 | ![img](imgs/few-shot-metric.png)
205 | 
206 | Visualization (attribution values have been normalized within a document):
207 | ![img](imgs/ARES_simple_ag.png)
208 | 
209 | ## Citation
210 | If you find our work useful, please star this repo and cite our work:
211 | ```
212 | @inproceedings{chen2022axiomatically,
213 |   title={Axiomatically Regularized Pre-training for Ad hoc Search},
214 |   author={Chen, Jia and Liu, Yiqun and Fang, Yan and Mao, Jiaxin and Fang, Hui and Yang, Shenghao and Xie, Xiaohui and Zhang, Min and Ma, Shaoping},
215 |   booktitle={Proceedings of the 45th International ACM SIGIR Conference on Research and Development in Information Retrieval},
216 |   year={2022}
217 | }
218 | ```
219 | 
220 | 
221 | ## Notice
222 | * Please make sure that all the pre-trained model parameters have been loaded correctly, otherwise the zero-shot and fine-tuning performance will be greatly degraded.
223 | * We welcome anyone who would like to contribute to this repo. 🤗
224 | * If you have any other questions, please feel free to contact me via [chenjia0831@gmail.com](mailto:chenjia0831@gmail.com) or open an issue.
225 | * Code for data preprocessing will come soon. Please stay tuned~
226 | 
--------------------------------------------------------------------------------
/visualization/visual.py:
--------------------------------------------------------------------------------
1 | '''
2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search
3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma.
4 | ''' 5 | import os 6 | import random 7 | from tqdm import tqdm 8 | import json 9 | import torch 10 | import torch.nn as nn 11 | import numpy as np 12 | import pandas as pd 13 | from model.modeling import ARES 14 | from transformers import PretrainedConfig, BertConfig,BertTokenizer 15 | from dataloader import get_visual_test_qd_loader, get_test_qd_loader 16 | from config import get_config 17 | import warnings 18 | from captum.attr import LayerIntegratedGradients 19 | # from captum.attr import visualization as viz 20 | import visualization as viz 21 | from gensim.models import KeyedVectors 22 | warnings.filterwarnings("ignore") 23 | 24 | 25 | def eval_model(model, test_qd_loader, device, config): 26 | model.eval() 27 | qd_rank = pd.DataFrame(columns=['q_id', 'd_id', 'rank', 'score']) 28 | q_id_list, d_id_list, rank, score = [], [], [], [] 29 | top5_q_id_list, top5_d_id_list, top5_rank_list,top5_score_list = [], [], [], [] 30 | num_instances = len(test_qd_loader) 31 | with torch.no_grad(): 32 | for i, batch_data in enumerate(tqdm(test_qd_loader, desc=f"Evaluating progress", total=num_instances)): 33 | input_ids, attention_mask, token_type_ids = batch_data["token_ids"], batch_data["attention_mask"], \ 34 | batch_data["token_type_ids"] 35 | 36 | input_ids = input_ids.to(device) # bs x 512 37 | attention_mask = attention_mask.to(device) # bs x 512 38 | token_type_ids = token_type_ids.to(device) 39 | 40 | output = model( 41 | input_ids=input_ids, 42 | config=config, 43 | input_mask=attention_mask, 44 | token_type_ids=token_type_ids, 45 | ) # 100 x 1 46 | 47 | output = output.squeeze() 48 | q_ids = batch_data["q_id"] 49 | d_ids = batch_data["d_id"] 50 | scores = output.cpu().tolist() 51 | top5_q_id_list.extend(q_ids[:5]) 52 | top5_d_id_list.extend(d_ids[:5]) 53 | top5_score_list.extend(scores[:5]) 54 | tuples = list(zip(q_ids, d_ids, scores)) 55 | sorted_tuples = sorted(tuples, key=lambda x: x[2], reverse=True) 56 | for idx, this_tuple in enumerate(sorted_tuples): 57 | q_id_list.append(this_tuple[0]) 58 | d_id_list.append(this_tuple[1]) 59 | rank.append(idx + 1) 60 | score.append(this_tuple[2]) 61 | qd_rank['q_id'] = q_id_list 62 | qd_rank['d_id'] = d_id_list 63 | qd_rank['rank'] = rank 64 | qd_rank['score'] = score 65 | df_rank = pd.DataFrame(columns=['q_id', 'Q0', 'd_id', 'rank', 'score', 'standard']) 66 | df_rank['q_id'] = qd_rank['q_id'] 67 | df_rank['Q0'] = ['Q0'] * len(qd_rank['q_id']) 68 | df_rank['d_id'] = qd_rank['d_id'] 69 | df_rank['rank'] = qd_rank['rank'] 70 | df_rank['score'] = qd_rank['score'] 71 | df_rank['standard'] = ['STANDARD'] * len(qd_rank['q_id']) 72 | df_rank.to_csv(f"{config.save_path}/dl2019_qd_rank_{config.model_name}.tsv", sep=' ', index=False, header=False) 73 | result_lines = os.popen(f'trec_eval -m ndcg_cut.10,100 {config.dl2019_qd_dir} {config.save_path}/dl2019_qd_rank_{config.model_name}.tsv').read().strip().split("\n") 74 | ndcg_10, ndcg_100 = float(result_lines[0].strip().split()[-1]), float(result_lines[1].strip().split()[-1]) 75 | metrics = {'nDCG @10': ndcg_10, 'nDCG @100': ndcg_100, 'QueriesRanked': len(set(qd_rank['q_id']))} 76 | print('\n#############################') 77 | print(config.model_name) 78 | print('<--------- DL 2019 --------->') 79 | for metric in sorted(metrics): 80 | print('{}: {}'.format(metric, metrics[metric])) 81 | print('#############################\n') 82 | return df_rank 83 | 84 | 85 | def visual_model(lig, tokenizer, qd_loader, df_dl2019_qds,device, config): 86 | score_viz_list = [] 87 | index = 0 88 | for i, batch_data in 
enumerate(tqdm(qd_loader, desc=f"IG progress", total=len(qd_loader))):
 89 |         q_ids, d_ids, ranks, input_ids, ref_input_ids, attention_mask, token_type_ids, ref_token_type_ids = \
 90 |             batch_data["q_id"], batch_data["d_id"], batch_data["rank"],\
 91 |             batch_data["token_ids"], batch_data["ref_token_ids"], batch_data["attention_mask"], batch_data["token_type_ids"], batch_data["ref_token_type_ids"]
 92 |         input_ids = input_ids.to(device)  # bs x 512
 93 |         ref_input_ids = ref_input_ids.to(device)  # bs x 512
 94 |         attention_mask = attention_mask.to(device)  # bs x 512
 95 |         token_type_ids = token_type_ids.to(device)
 96 |         ref_token_type_ids = ref_token_type_ids.to(device)
 97 |         attributions, deltas = lig.attribute(
 98 |             inputs=(input_ids, token_type_ids),
 99 |             baselines=(ref_input_ids, ref_token_type_ids),
100 |             return_convergence_delta=True,
101 |             additional_forward_args=attention_mask,
102 |             internal_batch_size=5
103 |         )
104 |         for j, (attribution, delta) in enumerate(zip(attributions, deltas)):  # for 512*768 in bs*512*768
105 |             attribution_sum = attribution.sum(dim=-1).squeeze(0)  # 512
106 |             tokens = [token.replace("Ġ", "") for token in tokenizer.convert_ids_to_tokens(input_ids[j])]
107 |             sep_index = tokens.index('[SEP]')
108 |             query_tokens = tokens[:sep_index]
109 |             doc_tokens = tokens[sep_index:sep_index+250]
110 |             tokens = query_tokens + doc_tokens
111 |             query_attribution_sum = attribution_sum[:sep_index] / torch.norm(attribution_sum)
112 |             doc_attribution_sum = attribution_sum[sep_index:sep_index+250] / torch.norm(attribution_sum[sep_index:sep_index+250])
113 |             attribution_sum = torch.cat((query_attribution_sum, doc_attribution_sum), dim=-1)
114 |             v_q_id, v_d_id, rank = q_ids[j], d_ids[j], ranks[j]
115 |             try:
116 |                 level = df_dl2019_qds.set_index(0).loc[int(v_q_id)].set_index(2).loc[str(v_d_id)][3]
117 |             except Exception:
118 |                 level = -1
119 |             score_viz = viz.VisualizationDataRecord(
120 |                 attribution_sum,
121 |                 level,
122 |                 rank,
123 |                 v_q_id,
124 |                 v_d_id,
125 |                 tokens,
126 |                 delta,
127 |             )
128 |             score_viz_list.append(score_viz)
129 |             # index += 1
130 |     html = viz.visualize_text(score_viz_list)
131 |     html_filepath = f"output_{config.model_name}.html"
132 |     with open(html_filepath, "w") as html_file:
133 |         html_file.write(html.data)
134 | 
135 | 
136 | if __name__ == '__main__':
137 |     config = get_config()
138 |     random.seed(config.seed)
139 |     np.random.seed(config.seed)
140 |     torch.manual_seed(config.seed)
141 |     config.local_rank = config.local_rank
142 |     if torch.cuda.is_available():
143 |         torch.cuda.manual_seed_all(config.seed)
144 |         torch.cuda.set_device(config.local_rank)
145 |         print('GPU is ON!')
146 |         device = torch.device('cuda')
147 |     else:
148 |         device = torch.device("cpu")
149 |     df_dl2019_qds = pd.read_csv(config.dl2019_qd_dir, sep=' ', header=None)
150 |     dl2019_top100 = pd.read_csv(config.dl100_dir, sep='\t', header=None)
151 |     dl2019_qs = {}
152 |     with open(config.dl2019_qs_dir) as f_qs:
153 |         for line in f_qs:
154 |             es = json.loads(line)
155 |             qid, ids = es["id"], es["ids"]
156 |             if qid not in dl2019_qs:
157 |                 dl2019_qs[qid] = ids
158 |     with open(config.docid2id_dir) as f_docid2id:
159 |         docid2id = json.load(f_docid2id)
160 |     collection_size = len(docid2id)
161 |     doc_tokens = np.memmap(config.memmap_doc_dir, dtype='int32', shape=(collection_size, 512))
162 |     print("\n========== Loading DL 2019 data ==========")
163 |     dl2019_qd_loader = get_test_qd_loader(dl2019_top100, dl2019_qs, doc_tokens, docid2id, config)
164 |     print(f"dl2019_q: {len(dl2019_qs)}, dl2019_q_batches:{len(dl2019_qd_loader)}")
165 | 
166 | 
print("Loading model...") 167 | if 'BERT' in config.model_name: 168 | model = ARES.from_pretrained(config.PRE_TRAINED_MODEL_NAME) 169 | tokenizer = BertTokenizer.from_pretrained(config.PRE_TRAINED_MODEL_NAME) 170 | elif 'ARES' in config.model_name or 'PROP' in config.model_name: 171 | cfg = PretrainedConfig.get_config_dict(config.PRE_TRAINED_MODEL_NAME)[0] 172 | if not config.gradient_checkpointing: 173 | del cfg["gradient_checkpointing"] 174 | del cfg["parameter_sharing"] 175 | cfg = BertConfig.from_dict(cfg) 176 | model = ARES(config=cfg) 177 | model.load_state_dict({k.replace("module.", ""): v for k, v in torch.load(f"{config.model_path}/{config.model_name}", map_location={'cuda:0':f'cuda:{config.local_rank}'}).items()},strict=False) 178 | tokenizer = BertTokenizer.from_pretrained(config.PRE_TRAINED_MODEL_NAME) 179 | model = model.to(device) 180 | print("Loading model finish") 181 | model_prefix = model.base_model_prefix 182 | model_base = getattr(model, model_prefix) 183 | if hasattr(model_base, "embeddings"): 184 | model_embeddings = getattr(model_base, "embeddings") 185 | lig = LayerIntegratedGradients(model, model_embeddings) 186 | qd_rank = eval_model( 187 | model, 188 | dl2019_qd_loader, 189 | device, 190 | config 191 | ) 192 | print("\n========== Loading visual DL 2019 data ==========") 193 | visual_dl2019_qd_loader = get_visual_test_qd_loader(qd_rank, dl2019_qs, doc_tokens, docid2id,config) 194 | visual_model( 195 | lig, 196 | tokenizer, 197 | visual_dl2019_qd_loader, 198 | df_dl2019_qds, 199 | device, 200 | config 201 | ) -------------------------------------------------------------------------------- /preprocess/Eval4.0.pl: -------------------------------------------------------------------------------- 1 | #! 2 | # author: Jun Xu and Tie-Yan Liu 3 | # modified by Jun Xu, March 3, 2009 (for Letor 4.0) 4 | use strict; 5 | 6 | #hash table for NDCG, 7 | my %hsNdcgRelScore = ( "2", 3, 8 | "1", 1, 9 | "0", 0, 10 | ); 11 | 12 | #hash table for Precision@N and MAP 13 | my %hsPrecisionRel = ("2", 1, 14 | "1", 1, 15 | "0", 0 16 | ); 17 | #modified by Jun Xu, March 3, 2009 18 | # for Letor 4.0. only output top 10 precision and ndcg 19 | # my $iMaxPosition = 16; 20 | my $iMaxPosition = 10; 21 | 22 | my $argc = $#ARGV+1; 23 | if($argc != 4) 24 | { 25 | print "Invalid command line.\n"; 26 | print "Usage: perl Eval.pl argv[1] argv[2] argv[3] argv[4]\n"; 27 | print "argv[1]: feature file \n"; 28 | print "argv[2]: prediction file\n"; 29 | print "argv[3]: result (output) file\n"; 30 | print "argv[4]: flag. 
If flag equals 1, output the evaluation results per query; if flag equals 0, simply output the average results.\n"; 31 | exit -1; 32 | } 33 | my $fnFeature = $ARGV[0]; 34 | my $fnPrediction = $ARGV[1]; 35 | my $fnResult = $ARGV[2]; 36 | my $flag = $ARGV[3]; 37 | if($flag != 1 && $flag != 0) 38 | { 39 | print "Invalid command line.\n"; 40 | print "Usage: perl Eval.pl argv[1] argv[2] argv[3] argv[4]\n"; 41 | print "Flag should be 0 or 1\n"; 42 | exit -1; 43 | } 44 | 45 | my %hsQueryDocLabelScore = ReadInputFiles($fnFeature, $fnPrediction); 46 | my %hsQueryEval = EvalQuery(\%hsQueryDocLabelScore); 47 | OuputResults($fnResult, %hsQueryEval); 48 | 49 | 50 | sub OuputResults 51 | { 52 | my ($fnOut, %hsResult) = @_; 53 | open(FOUT, ">$fnOut"); 54 | 55 | my @qids = sort{$a <=> $b} keys(%hsResult); 56 | my $numQuery = @qids; 57 | 58 | #Precision@N and MAP 59 | # modified by Jun Xu, March 3, 2009 60 | # changing the output format 61 | print FOUT "qid\tP\@1\tP\@2\tP\@3\tP\@4\tP\@5\tP\@6\tP\@7\tP\@8\tP\@9\tP\@10\tMAP\n"; 62 | #--------------------------------------------- 63 | my @prec; 64 | my $map = 0; 65 | for(my $i = 0; $i < $#qids + 1; $i ++) 66 | { 67 | # modified by Jun Xu, March 3, 2009 68 | # output the real query id 69 | my $qid = $qids[$i]; 70 | my @pN = @{$hsResult{$qid}{"PatN"}}; 71 | my $map_q = $hsResult{$qid}{"MAP"}; 72 | if ($flag == 1) 73 | { 74 | print FOUT "$qid\t"; 75 | for(my $iPos = 0; $iPos < $iMaxPosition; $iPos ++) 76 | { 77 | print FOUT sprintf("%.4f\t", $pN[$iPos]); 78 | } 79 | print FOUT sprintf("%.4f\n", $map_q); 80 | } 81 | for(my $iPos = 0; $iPos < $iMaxPosition; $iPos ++) 82 | { 83 | $prec[$iPos] += $pN[$iPos]; 84 | } 85 | $map += $map_q; 86 | } 87 | print FOUT "Average\t"; 88 | for(my $iPos = 0; $iPos < $iMaxPosition; $iPos ++) 89 | { 90 | $prec[$iPos] /= ($#qids + 1); 91 | print FOUT sprintf("%.4f\t", $prec[$iPos]); 92 | } 93 | $map /= ($#qids + 1); 94 | print FOUT sprintf("%.4f\n\n", $map); 95 | 96 | #NDCG and MeanNDCG 97 | # modified by Jun Xu, March 3, 2009 98 | # changing the output format 99 | print FOUT "qid\tNDCG\@1\tNDCG\@2\tNDCG\@3\tNDCG\@4\tNDCG\@5\tNDCG\@6\tNDCG\@7\tNDCG\@8\tNDCG\@9\tNDCG\@10\tMeanNDCG\n"; 100 | #--------------------------------------------- 101 | my @ndcg; 102 | my $meanNdcg = 0; 103 | for(my $i = 0; $i < $#qids + 1; $i ++) 104 | { 105 | # modified by Jun Xu, March 3, 2009 106 | # output the real query id 107 | my $qid = $qids[$i]; 108 | my @ndcg_q = @{$hsResult{$qid}{"NDCG"}}; 109 | my $meanNdcg_q = $hsResult{$qid}{"MeanNDCG"}; 110 | if ($flag == 1) 111 | { 112 | print FOUT "$qid\t"; 113 | for(my $iPos = 0; $iPos < $iMaxPosition; $iPos ++) 114 | { 115 | print FOUT sprintf("%.4f\t", $ndcg_q[$iPos]); 116 | } 117 | print FOUT sprintf("%.4f\n", $meanNdcg_q); 118 | } 119 | for(my $iPos = 0; $iPos < $iMaxPosition; $iPos ++) 120 | { 121 | $ndcg[$iPos] += $ndcg_q[$iPos]; 122 | } 123 | $meanNdcg += $meanNdcg_q; 124 | } 125 | print FOUT "Average\t"; 126 | for(my $iPos = 0; $iPos < $iMaxPosition; $iPos ++) 127 | { 128 | $ndcg[$iPos] /= ($#qids + 1); 129 | print FOUT sprintf("%.4f\t", $ndcg[$iPos]); 130 | } 131 | $meanNdcg /= ($#qids + 1); 132 | print FOUT sprintf("%.4f\n\n", $meanNdcg); 133 | 134 | close(FOUT); 135 | } 136 | 137 | sub EvalQuery 138 | { 139 | my $pHash = $_[0]; 140 | my %hsResults; 141 | 142 | my @qids = sort{$a <=> $b} keys(%$pHash); 143 | for(my $i = 0; $i < @qids; $i ++) 144 | { 145 | my $qid = $qids[$i]; 146 | my @tmpDid = sort{$$pHash{$qid}{$a}{"lineNum"} <=> $$pHash{$qid}{$b}{"lineNum"}} keys(%{$$pHash{$qid}}); 147 | my 
@docids = sort{$$pHash{$qid}{$b}{"pred"} <=> $$pHash{$qid}{$a}{"pred"}} @tmpDid;
148 |         my @rates;
149 | 
150 |         for(my $iPos = 0; $iPos < $#docids + 1; $iPos ++)
151 |         {
152 |             $rates[$iPos] = $$pHash{$qid}{$docids[$iPos]}{"label"};
153 |         }
154 | 
155 |         my $map = MAP(@rates);
156 |         my @PAtN = PrecisionAtN($iMaxPosition, @rates);
157 |         # modified by Jun Xu, calculate all possible positions' NDCG for MeanNDCG
158 |         #my @Ndcg = NDCG($iMaxPosition, @rates);
159 | 
160 |         my @Ndcg = NDCG($#rates + 1, @rates);
161 |         my $meanNdcg = 0;
162 |         for(my $iPos = 0; $iPos < $#Ndcg + 1; $iPos ++)
163 |         {
164 |             $meanNdcg += $Ndcg[$iPos];
165 |         }
166 |         $meanNdcg /= ($#Ndcg + 1);
167 | 
168 | 
169 |         @{$hsResults{$qid}{"PatN"}} = @PAtN;
170 |         $hsResults{$qid}{"MAP"} = $map;
171 |         @{$hsResults{$qid}{"NDCG"}} = @Ndcg;
172 |         $hsResults{$qid}{"MeanNDCG"} = $meanNdcg;
173 | 
174 |     }
175 |     return %hsResults;
176 | }
177 | 
178 | sub ReadInputFiles
179 | {
180 |     my ($fnFeature, $fnPred) = @_;
181 |     my %hsQueryDocLabelScore;
182 | 
183 |     if(!open(FIN_Feature, $fnFeature))
184 |     {
185 |         print "Invalid command line.\n";
186 |         print "Open \"$fnFeature\" failed.\n";
187 |         exit -2;
188 |     }
189 |     if(!open(FIN_Pred, $fnPred))
190 |     {
191 |         print "Invalid command line.\n";
192 |         print "Open \"$fnPred\" failed.\n";
193 |         exit -2;
194 |     }
195 | 
196 |     my $lineNum = 0;
197 |     while(defined(my $lnFea = <FIN_Feature>))
198 |     {
199 |         $lineNum ++;
200 |         chomp($lnFea);
201 |         my $predScore = <FIN_Pred>;
202 |         if (!defined($predScore))
203 |         {
204 |             print "Error to read $fnPred at line $lineNum.\n";
205 |             exit -2;
206 |         }
207 |         chomp($predScore);
208 |         # modified by Jun Xu, 2008-9-9
209 |         # Labels may have more than 3 levels
210 |         # qid and docid may not be numeric
211 |         # if ($lnFea =~ m/^([0-2]) qid\:(\d+).*?\#docid = (\d+)$/)
212 | 
213 |         # modified by Jun Xu, March 3, 2009
214 |         # Letor 4.0's file format is different from Letor 3.0
215 |         # if ($lnFea =~ m/^(\d+) qid\:([^\s]+).*?\#docid = ([^\s]+)$/)
216 |         if ($lnFea =~ m/^(\d+) qid\:([^\s]+).*?\#docid = ([^\s]+) inc = ([^\s]+) prob = ([^\s]+).$/)
217 |         {
218 |             my $label = $1;
219 |             my $qid = $2;
220 |             my $did = $3;
221 |             my $inc = $4;
222 |             my $prob = $5;
223 |             $hsQueryDocLabelScore{$qid}{$did}{"label"} = $label;
224 |             $hsQueryDocLabelScore{$qid}{$did}{"inc"} = $inc;
225 |             $hsQueryDocLabelScore{$qid}{$did}{"prob"} = $prob;
226 |             $hsQueryDocLabelScore{$qid}{$did}{"pred"} = $predScore;
227 |             $hsQueryDocLabelScore{$qid}{$did}{"lineNum"} = $lineNum;
228 |         }
229 |         else
230 |         {
231 |             print "Error to parse $fnFeature at line $lineNum:\n$lnFea\n";
232 |             exit -2;
233 |         }
234 |     }
235 |     close(FIN_Feature);
236 |     close(FIN_Pred);
237 |     return %hsQueryDocLabelScore;
238 | }
239 | 
240 | 
241 | sub PrecisionAtN
242 | {
243 |     my ($topN, @rates) = @_;
244 |     my @PrecN;
245 |     my $numRelevant = 0;
246 |     # modified by Jun Xu, 2009-4-24. 
247 | # if # retrieved doc < $topN, the P@N will consider the hole as irrelevant 248 | # for(my $iPos = 0; $iPos < $topN && $iPos < $#rates + 1; $iPos ++) 249 | # 250 | for (my $iPos = 0; $iPos < $topN; $iPos ++) 251 | { 252 | my $r; 253 | if ($iPos < $#rates + 1) 254 | { 255 | $r = $rates[$iPos]; 256 | } 257 | else 258 | { 259 | $r = 0; 260 | } 261 | $numRelevant ++ if ($hsPrecisionRel{$r} == 1); 262 | $PrecN[$iPos] = $numRelevant / ($iPos + 1); 263 | } 264 | return @PrecN; 265 | } 266 | 267 | sub MAP 268 | { 269 | my @rates = @_; 270 | 271 | my $numRelevant = 0; 272 | my $avgPrecision = 0.0; 273 | for(my $iPos = 0; $iPos < $#rates + 1; $iPos ++) 274 | { 275 | if ($hsPrecisionRel{$rates[$iPos]} == 1) 276 | { 277 | $numRelevant ++; 278 | $avgPrecision += ($numRelevant / ($iPos + 1)); 279 | } 280 | } 281 | return 0.0 if ($numRelevant == 0); 282 | #return sprintf("%.4f", $avgPrecision / $numRelevant); 283 | return $avgPrecision / $numRelevant; 284 | } 285 | 286 | sub DCG 287 | { 288 | my ($topN, @rates) = @_; 289 | my @dcg; 290 | 291 | $dcg[0] = $hsNdcgRelScore{$rates[0]}; 292 | # Modified by Jun Xu, 2009-4-24 293 | # if # retrieved doc < $topN, the NDCG@N will consider the hole as irrelevant 294 | # for(my $iPos = 1; $iPos < $topN && $iPos < $#rates + 1; $iPos ++) 295 | # 296 | for(my $iPos = 1; $iPos < $topN; $iPos ++) 297 | { 298 | my $r; 299 | if ($iPos < $#rates + 1) 300 | { 301 | $r = $rates[$iPos]; 302 | } 303 | else 304 | { 305 | $r = 0; 306 | } 307 | if ($iPos < 2) 308 | { 309 | $dcg[$iPos] = $dcg[$iPos - 1] + $hsNdcgRelScore{$r}; 310 | } 311 | else 312 | { 313 | $dcg[$iPos] = $dcg[$iPos - 1] + ($hsNdcgRelScore{$r} * log(2.0) / log($iPos + 1.0)); 314 | } 315 | } 316 | return @dcg; 317 | } 318 | sub NDCG 319 | { 320 | my ($topN, @rates) = @_; 321 | my @ndcg; 322 | my @dcg = DCG($topN, @rates); 323 | my @stRates = sort {$hsNdcgRelScore{$b} <=> $hsNdcgRelScore{$a}} @rates; 324 | my @bestDcg = DCG($topN, @stRates); 325 | 326 | for(my $iPos =0; $iPos < $topN && $iPos < $#rates + 1; $iPos ++) 327 | { 328 | $ndcg[$iPos] = 0; 329 | $ndcg[$iPos] = $dcg[$iPos] / $bestDcg[$iPos] if ($bestDcg[$iPos] != 0); 330 | } 331 | return @ndcg; 332 | } -------------------------------------------------------------------------------- /pretrain/train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search 3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma. 
4 | ''' 5 | # encoding: utf-8 6 | import os 7 | import sys 8 | sys.path.insert(0, '../') 9 | 10 | from tqdm import tqdm 11 | import json 12 | import torch 13 | import numpy as np 14 | from datetime import timedelta, datetime 15 | from model.modeling import ARES, ICT 16 | 17 | from transformers import AdamW, get_linear_schedule_with_warmup 18 | from transformers import PretrainedConfig, BertConfig 19 | from torch import nn 20 | from torch.cuda.amp import autocast, GradScaler 21 | 22 | from dataloader import get_train_qd_loader, get_ict_loader 23 | from config import get_config 24 | import warnings 25 | 26 | warnings.filterwarnings("ignore") 27 | torch.backends.cudnn.benchmark = True 28 | 29 | 30 | def train_epoch(model, scaler, qd_loader, optimizer, scheduler, device, config): 31 | model.train() 32 | losses = [] 33 | 34 | num_instances = len(qd_loader) 35 | for step, batch_data in enumerate(tqdm(qd_loader, desc=f"Pretraining {config.model_type} progress", total=num_instances)): 36 | input_ids, attention_mask, masked_lm_ids = batch_data["token_ids"], batch_data["attention_mask"], batch_data["masked_lm_ids"] 37 | if config.model_type == 'ICT': 38 | token_type_ids = None 39 | input_ids, attention_mask, masked_lm_ids = input_ids.squeeze(), attention_mask.squeeze(), masked_lm_ids.squeeze() 40 | this_batch_size = input_ids.size()[0] 41 | if this_batch_size < 2: 42 | continue 43 | else: 44 | this_batch_size = input_ids.size()[0] 45 | token_type_ids = batch_data["token_type_ids"] 46 | 47 | input_ids = input_ids.reshape(this_batch_size * 2, -1) 48 | attention_mask = attention_mask.reshape(this_batch_size * 2, -1) 49 | masked_lm_ids = masked_lm_ids.reshape(this_batch_size * 2, -1) 50 | token_type_ids = token_type_ids.reshape(this_batch_size * 2, -1) if token_type_ids is not None else token_type_ids 51 | 52 | input_ids = input_ids.to(device) # bs x 512 53 | attention_mask = attention_mask.to(device) # bs x 512 54 | masked_lm_ids = masked_lm_ids.to(device) 55 | 56 | token_type_ids = token_type_ids.to(device) if token_type_ids is not None else token_type_ids 57 | 58 | with autocast(): 59 | loss = model( 60 | input_ids=input_ids, 61 | config=config, 62 | input_mask=attention_mask, 63 | token_type_ids=token_type_ids, 64 | masked_lm_labels=masked_lm_ids, 65 | device=device 66 | ) 67 | 68 | losses.append(loss.item()) 69 | scaler.scale(loss).backward() 70 | 71 | # gradient accumulation 72 | if (step + 1) % config.gradient_accumulation_steps == 0: 73 | nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip) 74 | scaler.step(optimizer) 75 | scaler.update() 76 | 77 | scheduler.step() 78 | optimizer.zero_grad() 79 | 80 | if step % int(config.print_every) == 0: 81 | print(f"\n[Train] Loss at step {step} = {loss.item()}, lr = {optimizer.state_dict()['param_groups'][0]['lr']}") 82 | 83 | if step % 5000 == 0 and config.local_rank == 0: 84 | print('[SAVE] Saving model ... 
')
 85 |             model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
 86 |             this_loss = round(float(np.mean(losses)), 4)
 87 |             torch.save(model_to_save.state_dict(), f"{save_dir}/{config.model_name}_{this_loss}_step{step}")
 88 |     return np.mean(losses)
 89 | 
 90 | 
 91 | if __name__ == '__main__':
 92 | 
 93 |     # get configs
 94 |     config = get_config()
 95 | 
 96 |     # set save dir
 97 |     today = datetime.today().strftime('%Y-%m-%d')
 98 |     save_dir = f"{config.PRE_TRAINED_MODEL_NAME}/ckpt/{today}"
 99 |     if not os.path.exists(save_dir):
100 |         os.makedirs(save_dir)
101 | 
102 |     np.random.seed(config.seed)
103 |     torch.manual_seed(config.seed)
104 |     config.local_rank = config.local_rank
105 |     if torch.cuda.is_available():
106 |         torch.cuda.manual_seed_all(config.seed)
107 |         torch.cuda.set_device(config.local_rank)
108 |         print('GPU is ON!')
109 |         device = torch.device(f'cuda:{config.local_rank}')
110 |     else:
111 |         device = torch.device("cpu")
112 | 
113 |     # set max timeout = 50 hours
114 |     if config.distributed_train:
115 |         torch.distributed.init_process_group(backend="nccl", timeout=timedelta(seconds=180000), rank=config.local_rank, world_size=config.world_size)
116 |     local_rank = config.local_rank
117 |     if local_rank != -1:
118 |         print("Using Distributed")
119 | 
120 |     # json files
121 |     doc2query = {}
122 |     with open(config.doc2query_dir) as f_doc2query:
123 |         for line in f_doc2query:
124 |             es = json.loads(line)
125 |             docid = es["docid"]
126 |             queries = es["queries"]
127 |             if docid not in doc2query:
128 |                 doc2query[docid] = queries
129 | 
130 |     with open(config.gen_qid2id_dir) as f_gen_qid2id:
131 |         gen_qid2id = json.load(f_gen_qid2id)
132 | 
133 |     # save memory
134 |     if config.model_type == 'ARES':
135 |         q_num = len(gen_qid2id)
136 | 
137 |         axiom_rank = np.memmap(f"{config.axiom_feature_dir}/memmap/rank.memmap", dtype='float', shape=(q_num, 1))
138 |         axiom_list = []
139 |         print(config.axiom)
140 |         if 'PROX' in config.axiom:
141 |             prox_1 = np.memmap(f"{config.axiom_feature_dir}/memmap/prox-1.memmap", dtype='float', shape=(q_num, 1))
142 |             prox_2 = np.memmap(f"{config.axiom_feature_dir}/memmap/prox-2.memmap", dtype='float', shape=(q_num, 1))
143 |             axiom_list.append(['PROX-1', prox_1])
144 |             axiom_list.append(['PROX-2', prox_2])
145 | 
146 |         if 'REP' in config.axiom:
147 |             rep_ql = np.memmap(f"{config.axiom_feature_dir}/memmap/rep-ql.memmap", dtype='float', shape=(q_num, 1))
148 |             rep_tfidf = np.memmap(f"{config.axiom_feature_dir}/memmap/rep-tfidf.memmap", dtype='float', shape=(q_num, 1))
149 |             axiom_list.append(['REP-QL', rep_ql])
150 |             axiom_list.append(['REP-TFIDF', rep_tfidf])
151 | 
152 |         if 'REG' in config.axiom:
153 |             reg = np.memmap(f"{config.axiom_feature_dir}/memmap/reg.memmap", dtype='float', shape=(q_num, 1))
154 |             axiom_list.append(['REG', reg])
155 | 
156 |         if 'STM' in config.axiom:
157 |             stm_1 = np.memmap(f"{config.axiom_feature_dir}/memmap/stm-1.memmap", dtype='float', shape=(q_num, 1))
158 |             stm_2 = np.memmap(f"{config.axiom_feature_dir}/memmap/stm-2.memmap", dtype='float', shape=(q_num, 1))
159 |             stm_3 = np.memmap(f"{config.axiom_feature_dir}/memmap/stm-3.memmap", dtype='float', shape=(q_num, 1))
160 | 
161 |             axiom_list.append(['STM-1', stm_1])
162 |             axiom_list.append(['STM-2', stm_2])
163 |             axiom_list.append(['STM-3', stm_3])
164 | 
165 |         axiom_list.append(['RANK', axiom_rank])
166 |         gen_qs_size = len(gen_qid2id)
167 |         gen_qs_tokens = np.memmap(config.gen_qs_memmap_dir, dtype='int32', shape=(gen_qs_size, 15))
168 | 
169 |     with open(config.docid2id_dir) as f_docid2id:
170 |         docid2id 
= json.load(f_docid2id)
171 |     collection_size = len(docid2id)
172 |     doc_tokens = np.memmap(config.memmap_doc_dir, dtype='int32', shape=(collection_size, 512))
173 | 
174 |     print("Load data done!")
175 | 
176 |     cfg = PretrainedConfig.get_config_dict(config.PRE_TRAINED_MODEL_NAME)[0]
177 |     if not config.gradient_checkpointing:
178 |         del cfg["gradient_checkpointing"]  # gradient checkpointing conflicts with parallel training
179 |         del cfg["parameter_sharing"]
180 |     cfg = BertConfig.from_dict(cfg)
181 | 
182 |     # train
183 |     if not config.load_ckpt:
184 |         if config.model_type == 'ICT':
185 |             model = ICT.from_pretrained(config.PRE_TRAINED_MODEL_NAME, config=cfg)
186 |         else:
187 |             model = ARES.from_pretrained(config.PRE_TRAINED_MODEL_NAME, config=cfg)
188 |     else:
189 |         if config.model_type == 'ICT':
190 |             model = ICT(config=cfg)
191 |         else:
192 |             model = ARES(config=cfg)
193 |         model.load_state_dict({k.replace("module.", ""): v for k, v in torch.load(f"{config.PRE_TRAINED_MODEL_NAME}/ckpt/{config.model_path}",
194 |                                map_location={'cuda:0': f'cuda:{config.local_rank}'}).items()})
195 |     model = model.to(device)
196 |     print("Loading model...")
197 |     model = model.cuda()
198 | 
199 |     if config.optim == 'adam':
200 |         optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
201 |     elif config.optim == 'amsgrad':
202 |         optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, amsgrad=True)  # AMSGrad variant of Adam
203 |     elif config.optim == 'adagrad':
204 |         optimizer = torch.optim.Adagrad(model.parameters(), lr=config.lr)
205 |     else:  # adamw: weight decay does not depend on the lr
206 |         no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
207 |         optimizer_grouped_parameters = [
208 |             {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': config.weight_decay},
209 |             {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
210 |         ]
211 |         optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr, eps=config.adam_epsilon)
212 | 
213 |     # train
214 |     if config.distributed_train:
215 |         model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], broadcast_buffers=False, find_unused_parameters=True)
216 |     config.warm_up = config.warm_up / config.gpu_num
217 | 
218 |     for epoch in range(config.epochs):
219 |         print(f'Epoch {epoch + 1}/{config.epochs}')
220 |         print('-' * 10)
221 | 
222 |         print("========== Loading training data ==========")
223 |         if config.model_type == 'ARES':
224 |             train_qd_loader = get_train_qd_loader(doc_tokens, docid2id, config,
225 |                                                   doc2query=doc2query,
226 |                                                   gen_qs=gen_qs_tokens,
227 |                                                   gen_qid2id=gen_qid2id,
228 |                                                   axiom_feature=axiom_list)  # b_sz * data samples
229 |         else:
230 |             train_qd_loader = get_ict_loader(doc_tokens, docid2id, config)
231 |         print(f"train_batches:{len(train_qd_loader)}, batch_size: {config.batch_size}")
232 | 
233 |         scaler = GradScaler(enabled=True)
234 |         total_steps = len(train_qd_loader) * config.epochs
235 | 
236 |         scheduler = get_linear_schedule_with_warmup(
237 |             optimizer,
238 |             num_warmup_steps=int(total_steps * config.warm_up),
239 |             num_training_steps=total_steps
240 |         )
241 | 
242 |         train_loss = train_epoch(
243 |             model,
244 |             scaler,
245 |             train_qd_loader,
246 |             optimizer,
247 |             scheduler,
248 |             device,
249 |             config,
250 |         )
251 |         scheduler.step()
252 |         print(f'Train loss {train_loss}')
253 | 
254 |         if config.local_rank == 0:
255 |             print('[SAVE] Saving model ... 
') 256 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 257 | this_loss = round(float(train_loss), 4) 258 | torch.save(model_to_save.state_dict(), f"{save_dir}/{config.model_name}_{this_loss}") 259 | 260 | 261 | 262 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /model/modeling.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search 3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma. 
4 | ''' 5 | # encoding: utf-8 6 | import sys 7 | import numpy as np 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | from torch import Tensor 12 | from torch.nn import CrossEntropyLoss, MarginRankingLoss 13 | from torch.nn import Softmax 14 | from torch.cuda.amp import autocast 15 | from transformers import BertModel, BertPreTrainedModel 16 | from transformers import AutoTokenizer, AutoConfig, AutoModel 17 | 18 | 19 | 20 | sys.path.insert(0, '../') 21 | PRETRAINED_MODEL_ARCHIVE_MAP = { 22 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", 23 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", 24 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", 25 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", 26 | 'spanbert-base-cased': "https://dl.fbaipublicfiles.com/fairseq/models/spanbert_hf_base.tar.gz", 27 | 'spanbert-large-cased': "https://dl.fbaipublicfiles.com/fairseq/models/spanbert_hf.tar.gz" 28 | } 29 | 30 | def batch_to_device(batch, target_device): 31 | """ 32 | send a pytorch batch to a device (CPU/GPU) 33 | """ 34 | for key in batch: 35 | if isinstance(batch[key], Tensor): 36 | batch[key] = batch[key].to(target_device) 37 | return batch 38 | 39 | 40 | class BertLayerNorm(nn.Module): 41 | def __init__(self, hidden_size, eps=1e-12): 42 | """Construct a layernorm module in the TF style (epsilon inside the square root). 43 | """ 44 | super(BertLayerNorm, self).__init__() 45 | self.weight = nn.Parameter(torch.ones(hidden_size)) 46 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 47 | self.variance_epsilon = eps 48 | 49 | def forward(self, x): 50 | u = x.mean(-1, keepdim=True) 51 | s = (x - u).pow(2).mean(-1, keepdim=True) 52 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 53 | return self.weight * x + self.bias 54 | 55 | 56 | def gelu(x): 57 | """Implementation of the gelu activation function. 
58 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 59 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 60 | Also see https://arxiv.org/abs/1606.08415 61 | """ 62 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 63 | 64 | 65 | def swish(x): 66 | return x * torch.sigmoid(x) 67 | 68 | 69 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 70 | 71 | 72 | class BertPredictionHeadTransform(nn.Module): 73 | def __init__(self, config): 74 | super(BertPredictionHeadTransform, self).__init__() 75 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 76 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): 77 | self.transform_act_fn = ACT2FN[config.hidden_act] 78 | else: 79 | self.transform_act_fn = config.hidden_act 80 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 81 | 82 | def forward(self, hidden_states): 83 | hidden_states = self.dense(hidden_states) 84 | hidden_states = self.transform_act_fn(hidden_states) 85 | hidden_states = self.LayerNorm(hidden_states) 86 | return hidden_states 87 | 88 | 89 | class BertLMPredictionHead(nn.Module): 90 | def __init__(self, config, bert_model_embedding_weights): 91 | super(BertLMPredictionHead, self).__init__() 92 | self.transform = BertPredictionHeadTransform(config) 93 | 94 | # The output weights are the same as the input embeddings, but there is 95 | # an output-only bias for each token. 96 | self.decoder = nn.Linear(bert_model_embedding_weights.size(1), 97 | bert_model_embedding_weights.size(0), 98 | bias=False) 99 | self.decoder.weight = bert_model_embedding_weights 100 | self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) 101 | 102 | def forward(self, hidden_states): 103 | hidden_states = self.transform(hidden_states) 104 | hidden_states = self.decoder(hidden_states) + self.bias 105 | return hidden_states 106 | 107 | 108 | # TransformerICT 109 | class ICT(BertPreTrainedModel): 110 | def __init__(self, config): 111 | super(ICT, self).__init__(config) 112 | self.bert = BertModel(config) 113 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 114 | self.cls = nn.Linear(config.hidden_size, 1) 115 | self.cls.predictions = BertLMPredictionHead(config, self.bert.embeddings.word_embeddings.weight) 116 | self.config = config 117 | 118 | self.init_weights() 119 | 120 | @autocast() 121 | def forward(self, input_ids, config, input_mask, token_type_ids=None, masked_lm_labels=None, device=None): 122 | 123 | batch_size = input_ids.size(0) 124 | outputs = self.bert(input_ids, 125 | attention_mask=input_mask, 126 | return_dict=False 127 | ) 128 | 129 | sequence_output, pooled_output = outputs[0], outputs[1] 130 | 131 | if masked_lm_labels is not None: 132 | # MLM loss 133 | lm_prediction_scores = self.cls.predictions(sequence_output) 134 | loss_fct = CrossEntropyLoss(ignore_index=-1) 135 | mlm_loss = loss_fct(lm_prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) if config.MLM else 0. 
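            # Note on the MLM term above: the prediction head ties its decoder weights to the
            # input word embeddings (see BertLMPredictionHead), and ignore_index=-1 skips
            # all unmasked positions when computing the loss.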
136 | 137 | # ICT loss 138 | logits = pooled_output.reshape(batch_size//2, 2, self.config.hidden_size) 139 | s_encode = logits[:, 0, :] # bs/2, 1, h 140 | c_encode = logits[:, 1, :] # bs/2, 1, h 141 | 142 | logit = torch.matmul(s_encode, c_encode.transpose(-2, -1)) 143 | target = torch.from_numpy(np.array([i for i in range(batch_size // 2)])).long().to(device) 144 | loss = nn.CrossEntropyLoss() 145 | ict_loss = loss(logit, target).mean() 146 | 147 | loss = mlm_loss + ict_loss 148 | return loss 149 | 150 | else: 151 | prediction_scores = self.cls(self.dropout(pooled_output)) 152 | return prediction_scores 153 | 154 | 155 | class ARES(BertPreTrainedModel): 156 | def __init__(self, config): 157 | super(ARES, self).__init__(config) 158 | self.bert = BertModel(config) 159 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 160 | self.cls = nn.Linear(config.hidden_size, 1) 161 | self.sigmoid = nn.Sigmoid() 162 | self.cls.predictions = BertLMPredictionHead(config, self.bert.embeddings.word_embeddings.weight) 163 | self.config = config 164 | 165 | self.init_weights() 166 | 167 | @autocast() 168 | def forward(self, input_ids, config, input_mask, token_type_ids, masked_lm_labels=None, device=None): 169 | 170 | batch_size = input_ids.size(0) 171 | outputs = self.bert(input_ids, 172 | attention_mask=input_mask, 173 | token_type_ids=token_type_ids, 174 | return_dict=False 175 | ) 176 | 177 | sequence_output, pooled_output = outputs[0], outputs[1] 178 | prediction_scores = self.cls(self.dropout(pooled_output)) 179 | 180 | if masked_lm_labels is not None: 181 | # MLM loss 182 | lm_prediction_scores = self.cls.predictions(sequence_output) 183 | loss_fct = CrossEntropyLoss(ignore_index=-1) 184 | mlm_loss = loss_fct(lm_prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) if config.MLM else 0. 
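            # Scores arrive as consecutive (positive, negative) document pairs: after the
            # reshape to (batch_size // 2, 2) below, column 0 is the positive score and
            # column 1 the negative one, trained with a margin ranking loss.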
185 |
186 | # Pairwise loss
187 | logits = prediction_scores.reshape(batch_size // 2, 2)
188 | softmax = Softmax(dim=1)
189 | logits = softmax(logits)
190 | pos_logits = logits[:, 0]
191 | neg_logits = logits[:, 1]
192 | marginloss = MarginRankingLoss(margin=1.0, reduction='mean')
193 |
194 | rep_label = torch.ones_like(pos_logits)
195 | rep_loss = marginloss(pos_logits, neg_logits, rep_label)
196 |
197 | loss = mlm_loss + rep_loss
198 | return loss
199 | else:
200 | return prediction_scores
201 |
202 |
203 | class ARESReranker(ARES):
204 | def __init__(self, config, max_input_length=512):
205 | super().__init__(config)
206 | self.tokenizer = AutoTokenizer.from_pretrained(
207 | "bert-base-uncased", config=config, local_files_only=True)
208 | self.max_input_length = max_input_length
209 |
210 | def tokenize(self, qd_pairs):
211 | feature_input_ids = []
212 | feature_token_type_ids = []
213 | feature_attention_mask = []
214 | for query, doc in qd_pairs:
215 | cls_id, sep_id = 101, 102
216 | query_max_len = 32
217 | doc_max_len = self.max_input_length - 3 - query_max_len  # honor max_input_length instead of the previous hard-coded 512
218 | tokens = self.tokenizer.tokenize(query)
219 | query_input_ids = self.tokenizer.convert_tokens_to_ids(tokens)[: query_max_len]
220 |
221 | tokens = self.tokenizer.tokenize(doc)
222 | doc_input_ids = self.tokenizer.convert_tokens_to_ids(tokens)[: doc_max_len]
223 |
224 | input_ids = [cls_id] + query_input_ids + [sep_id] + doc_input_ids + [sep_id]
225 | token_type_ids = [0] * (len(query_input_ids) + 2) + [1] * (len(doc_input_ids) + 1)
226 | attention_mask = np.int64(np.array(input_ids) > 0)
227 |
228 | feature_input_ids.append(torch.tensor(input_ids))
229 | feature_token_type_ids.append(torch.tensor(token_type_ids))
230 | feature_attention_mask.append(torch.tensor(attention_mask))
231 |
232 |
233 | # padding to same length
234 | max_len = max([len(x) for x in feature_input_ids])
235 | for i in range(len(feature_input_ids)):
236 | pad_len = max_len - len(feature_input_ids[i])
237 | feature_input_ids[i] = torch.cat([feature_input_ids[i], torch.zeros(pad_len).long()])
238 | feature_token_type_ids[i] = torch.cat([feature_token_type_ids[i], torch.zeros(pad_len).long()])
239 | feature_attention_mask[i] = torch.cat([feature_attention_mask[i], torch.zeros(pad_len).long()])
240 |
241 | feature_input_ids = torch.vstack(feature_input_ids)
242 | feature_token_type_ids = torch.vstack(feature_token_type_ids)
243 | feature_attention_mask = torch.vstack(feature_attention_mask)
244 |
245 | return {
246 | "input_ids": feature_input_ids,
247 | "token_type_ids": feature_token_type_ids,
248 | "input_mask": feature_attention_mask
249 | }
250 |
251 | def score(self, qd_pairs):
252 | features = self.tokenize(qd_pairs)
253 | batch_to_device(features, self.device)  # helper expected to move each tensor in the dict to the model's device
254 | with torch.cuda.amp.autocast():
255 | with torch.no_grad():
256 | scores = self.forward(config=None, **features)
257 | scores = scores.cpu().numpy().reshape(-1)
258 | return scores
259 |
260 |
261 | def rerank_query(self, query, docs):
262 | batch_size = 100
263 |
264 | qd_pairs = [(query, doc) for doc in docs]
265 | scores = []
266 | for i in range(0, len(qd_pairs), batch_size):
267 | scores.append(self.score(qd_pairs[i: i + batch_size]))
268 |
269 | scores = np.concatenate(scores, axis=0)
270 | scores = scores.reshape(-1)
271 | return scores.tolist()
272 |
273 | def rerank(self, queries, docs_topk):
274 | assert len(queries) == len(docs_topk)
275 | scores_for_queries = []
276 | for query, docs in zip(queries, docs_topk):
277 | scores_for_queries.append(self.rerank_query(query, docs))
278 | return
scores_for_queries 279 | -------------------------------------------------------------------------------- /finetune/train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search 3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma. 4 | ''' 5 | # encoding: utf-8 6 | import os 7 | import sys 8 | sys.path.insert(0, '../') 9 | 10 | from tqdm import tqdm 11 | import json 12 | import torch 13 | import numpy as np 14 | import pandas as pd 15 | from datetime import timedelta 16 | 17 | 18 | from transformers import AutoModel, AutoTokenizer, AdamW, get_linear_schedule_with_warmup 19 | from transformers import PretrainedConfig, BertConfig 20 | from torch import nn, optim 21 | from torch.cuda.amp import autocast, GradScaler 22 | from model.modeling import ARES, ICT 23 | 24 | from dataloader import get_train_qd_loader, get_test_qd_loader 25 | from config import get_config 26 | from ms_marco_eval import compute_metrics_from_files 27 | import warnings 28 | 29 | warnings.filterwarnings("ignore") 30 | torch.backends.cudnn.benchmark = True 31 | 32 | 33 | def train_epoch(model, scaler, qd_loader, optimizer, scheduler, device, config): 34 | model.train() 35 | losses = [] 36 | 37 | num_instances = len(qd_loader) 38 | model_name = config.model_name 39 | for step, batch_data in enumerate(tqdm(qd_loader, desc=f"Fine-tuning {model_name} progress", total=num_instances)): 40 | input_ids, attention_mask, token_type_ids = batch_data["token_ids"], batch_data["attention_mask"], batch_data["token_type_ids"] 41 | this_batch_size = input_ids.size()[0] 42 | 43 | # b/2 x 2 x 512 ==> b x 512 44 | input_ids = input_ids.reshape(this_batch_size * 2, -1) 45 | attention_mask = attention_mask.reshape(this_batch_size * 2, -1) 46 | token_type_ids = token_type_ids.reshape(this_batch_size * 2, -1) 47 | 48 | input_ids = input_ids.to(device) # bs x 512 49 | attention_mask = attention_mask.to(device) # bs x 512 50 | token_type_ids = token_type_ids.to(device) 51 | 52 | with autocast(): 53 | output = model( 54 | input_ids=input_ids, 55 | config=config, 56 | input_mask=attention_mask, 57 | token_type_ids=token_type_ids, 58 | ) # bs x 1 59 | 60 | softmax = nn.Softmax(dim=1) 61 | marginloss = nn.MarginRankingLoss(margin=1.0, reduction='mean') 62 | batch_size = output.size(0) 63 | logits = output.reshape(batch_size // 2, 2) 64 | logits = softmax(logits) 65 | pos_logits = logits[:, 0] 66 | neg_logits = logits[:, 1] 67 | rop_label = torch.ones_like(pos_logits) 68 | loss = marginloss(pos_logits, neg_logits, rop_label) 69 | 70 | loss = loss / config.gradient_accumulation_steps 71 | losses.append(loss.item()) 72 | scaler.scale(loss).backward() 73 | 74 | # gradient accumulation 75 | if (step + 1) % config.gradient_accumulation_steps == 0: 76 | nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip) 77 | scaler.step(optimizer) 78 | scaler.update() 79 | 80 | scheduler.step() 81 | optimizer.zero_grad() 82 | 83 | if step % int(config.print_every) == 0: 84 | print(f"\n[Train] Loss at step {step} = {loss.item()}, lr = {optimizer.state_dict()['param_groups'][0]['lr']}") 85 | return np.mean(losses) 86 | 87 | 88 | def eval_model(model, qd_loader, device, config): 89 | model.eval() 90 | df_rank = pd.DataFrame(columns=['q_id', 'd_id', 'rank', 'score']) 91 | q_id_list, d_id_list, rank, score = [], [], [], [] 92 | 93 | num_instances = len(qd_loader) 94 | with torch.no_grad(): 95 | for i, 
batch_data in enumerate(tqdm(qd_loader, desc="Evaluating progress", total=num_instances)):
96 | input_ids, attention_mask, token_type_ids = batch_data["token_ids"], batch_data["attention_mask"], \
97 | batch_data["token_type_ids"]
98 |
99 | input_ids = input_ids.to(device) # bs x 512
100 | attention_mask = attention_mask.to(device) # bs x 512
101 | token_type_ids = token_type_ids.to(device)
102 |
103 | output = model(
104 | input_ids=input_ids,
105 | config=config,
106 | input_mask=attention_mask,
107 | token_type_ids=token_type_ids,
108 | ) # 100 x 1
109 |
110 | output = output.squeeze(-1)  # squeeze only the score dim so a batch of size 1 still yields a list
111 | q_ids = batch_data["q_id"]
112 | d_ids = batch_data["d_id"]
113 | scores = output.cpu().tolist()
114 | tuples = list(zip(q_ids, d_ids, scores))
115 | sorted_tuples = sorted(tuples, key=lambda x: x[2], reverse=True) # inspect the score distribution of the top-100 docs
116 | for idx, this_tuple in enumerate(sorted_tuples):
117 | q_id_list.append(this_tuple[0])
118 | d_id_list.append(this_tuple[1])
119 | rank.append(idx + 1)
120 | score.append(this_tuple[2])
121 |
122 | df_rank['q_id'] = q_id_list
123 | df_rank['d_id'] = d_id_list
124 | df_rank['rank'] = rank
125 | df_rank['score'] = score
126 | return df_rank
127 |
128 |
129 | if __name__ == '__main__':
130 | config = get_config()
131 |
132 | # automatically create save dirs
133 | save_dir = f"{config.PRE_TRAINED_MODEL_NAME}/ckpt"
134 | if not os.path.exists(save_dir):
135 | os.makedirs(save_dir, exist_ok=True)
136 | save_model_path = f"{config.PRE_TRAINED_MODEL_NAME}/ckpt/model_state"
137 |
138 | np.random.seed(config.seed)
139 | torch.manual_seed(config.seed)
140 | if torch.cuda.is_available():
141 | torch.cuda.manual_seed_all(config.seed)
142 | torch.cuda.set_device(config.local_rank)
143 | print('GPU is ON!')
144 | device = torch.device(f'cuda:{config.local_rank}')
145 | else:
146 | device = torch.device("cpu")
147 |
148 | # distributed training
149 | if config.distributed_train and not config.test:
150 | torch.distributed.init_process_group(backend="nccl", timeout=timedelta(seconds=180000000))  # timedelta's first positional arg is days; seconds is presumably the intent here
151 | local_rank = config.local_rank
152 | if local_rank != -1:
153 | print("Using Distributed")
154 |
155 | # Train Data Loader
156 | df_train_qds = pd.read_csv(config.train_qd_dir, sep=' ', header=None)
157 | if config.local_rank == 0:
158 | df_test_qds = pd.read_csv(config.test_qd_dir, sep=' ', header=None)
159 | df_dl2019_qds = pd.read_csv(config.dl2019_qd_dir, sep=' ', header=None)
160 |
161 | best_nDCG_dl2019, best_MRR_test = 0., 0.
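# Note (editorial, not in the original source): the config.local_rank == 0
# guards here and below mean that only the main process loads the dev/DL2019
# data and runs evaluation, while every rank participates in training.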
162 | train_top100 = pd.read_csv(config.train100_dir, sep='\t', header=None)
163 | if config.local_rank == 0:
164 | test_top100 = pd.read_csv(config.test100_dir, sep='\t', header=None)
165 | dl2019_top100 = pd.read_csv(config.dl100_dir, sep='\t', header=None)
166 |
167 | # json files
168 | train_qs, test_qs, dl2019_qs, doc2query = {}, {}, {}, {}
169 | with open(config.train_qs_dir) as f_train_qs:
170 | for line in f_train_qs:
171 | es = json.loads(line)
172 | qid, ids = es["id"], es["ids"]
173 | if qid not in train_qs:
174 | train_qs[qid] = ids
175 |
176 | if config.local_rank == 0:
177 | with open(config.test_qs_dir) as f_test_qs:
178 | for line in f_test_qs:
179 | es = json.loads(line)
180 | qid, ids = es["id"], es["ids"]
181 | if qid not in test_qs:
182 | test_qs[qid] = ids
183 | with open(config.dl2019_qs_dir) as f_dl2019_qs:
184 | for line in f_dl2019_qs:
185 | es = json.loads(line)
186 | qid, ids = es["id"], es["ids"]
187 | if qid not in dl2019_qs:
188 | dl2019_qs[qid] = ids
189 |
190 | with open(config.docid2id_dir) as f_docid2id:
191 | docid2id = json.load(f_docid2id)
192 | print("Load dicts done!")
193 |
194 | collection_size = len(docid2id)
195 | doc_tokens = np.memmap(config.memmap_doc_dir, dtype='int32', shape=(collection_size, 512))
196 |
197 | cfg = PretrainedConfig.get_config_dict(config.PRE_TRAINED_MODEL_NAME)[0]
198 | if not config.gradient_checkpointing:
199 | del cfg["gradient_checkpointing"]
200 | cfg = BertConfig.from_dict(cfg)
201 |
202 | if not config.load_ckpt: # train
203 | if config.model_type == 'ICT':
204 | model = ICT.from_pretrained(config.PRE_TRAINED_MODEL_NAME, config=cfg)
205 | else:
206 | model = ARES.from_pretrained(config.PRE_TRAINED_MODEL_NAME, config=cfg)
207 | else: # test
208 | if config.model_type == 'ARES':
209 | model = ARES(config=cfg)
210 | elif config.model_type == 'PROP':
211 | model = PROP(config=cfg)  # NOTE: PROP is not imported above; it must be exposed by model.modeling for this branch to work
212 | else:
213 | model = ICT(config=cfg)
214 | model.load_state_dict({k.replace("module.", ""): v for k, v in torch.load(f"{config.PRE_TRAINED_MODEL_NAME}/ckpt/{config.model_path}",
215 | map_location={'cuda:0': f'cuda:{config.local_rank}'}).items()})
216 |
217 | model = model.to(device)
218 | print("Loading model...")
219 | # (redundant model.cuda() removed; the .to(device) call above already handles GPU/CPU placement)
220 |
221 | scaler = GradScaler(enabled=True)
222 |
223 | if config.optim == 'adam':
224 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
225 | elif config.optim == 'amsgrad':
226 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, amsgrad=True)  # PyTorch has no Amsgrad class; AMSGrad is a flag on Adam
227 | elif config.optim == 'adagrad':
228 | optimizer = torch.optim.Adagrad(model.parameters(), lr=config.lr)
229 | else: # adamw: weight decay does not depend on the lr
230 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
231 | optimizer_grouped_parameters = [
232 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': config.weight_decay},
233 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
234 | ]
235 | optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr, eps=config.adam_epsilon)
236 |
237 | if not config.test: # train
238 | if config.distributed_train:
239 | model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True, broadcast_buffers=False)
240 | config.warm_up = config.warm_up / config.gpu_num
241 |
242 | if config.local_rank == 0:
243 | print("\n========== Loading dev data ==========")
244 | test_qd_loader = get_test_qd_loader(test_top100, test_qs, doc_tokens,
docid2id, config) 245 | print(f"test_q: {len(test_qs)}, test_q_batchs:{len(test_qd_loader)}") 246 | 247 | print("\n========== Loading DL 2019 data ==========") 248 | dl2019_qd_loader = get_test_qd_loader(dl2019_top100, dl2019_qs, doc_tokens, docid2id, config) 249 | print(f"dl2019_q: {len(dl2019_qs)}, dl2019_q_batchs:{len(dl2019_qd_loader)}") 250 | 251 | for epoch in range(config.epochs): 252 | print(f'Epoch {epoch + 1}/{config.epochs}') 253 | print('-' * 10) 254 | 255 | print("========== Loading training data ==========") 256 | train_qd_loader = get_train_qd_loader(df_train_qds, train_top100, train_qs, doc_tokens, docid2id, config, mode='train') # b_sz * data samples 257 | print(f"train_qd_pairs: {len(df_train_qds)}, train_batchs:{len(train_qd_loader)}, batch_size: {config.batch_size}") 258 | 259 | total_steps = len(train_qd_loader) 260 | scheduler = get_linear_schedule_with_warmup( 261 | optimizer, 262 | num_warmup_steps=int(total_steps * config.warm_up), 263 | num_training_steps=total_steps 264 | ) 265 | 266 | train_loss = train_epoch( 267 | model, 268 | scaler, 269 | train_qd_loader, 270 | optimizer, 271 | scheduler, 272 | device, 273 | config, 274 | ) 275 | scheduler.step() 276 | print(f'Train loss {train_loss}') 277 | 278 | if config.local_rank == 0: 279 | qd_rank = eval_model( 280 | model, 281 | dl2019_qd_loader, 282 | device, 283 | config, 284 | ) 285 | df_rank = pd.DataFrame(columns=['q_id', 'Q0', 'd_id', 'rank', 'score', 'standard']) 286 | df_rank['q_id'] = qd_rank['q_id'] 287 | df_rank['Q0'] = ['Q0'] * len(qd_rank['q_id']) 288 | df_rank['d_id'] = qd_rank['d_id'] 289 | df_rank['rank'] = qd_rank['rank'] 290 | df_rank['score'] = qd_rank['score'] 291 | df_rank['standard'] = ['STANDARD'] * len(qd_rank['q_id']) 292 | df_rank.to_csv(f"{save_dir}/dl2019_qd_rank.tsv", sep=' ', index=False, header=False) # ! 
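# Note (editorial, not in the original source): the six space-separated columns
# written above (qid, "Q0", docid, rank, score, "STANDARD") follow the standard
# TREC run-file format, which is what the trec_eval call below consumes together
# with the qrels file.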
293 | result_lines = os.popen(f'trec_eval -m ndcg_cut.10,100 {config.dl2019_qd_dir} {save_dir}/dl2019_qd_rank.tsv').read().strip().split("\n") 294 | ndcg_10, ndcg_100 = float(result_lines[0].strip().split()[-1]), float( 295 | result_lines[1].strip().split()[-1]) 296 | metrics = {'nDCG @10': ndcg_10, 'nDCG @100': ndcg_100, 'QueriesRanked': len(set(qd_rank['q_id']))} 297 | 298 | print('\n#############################') 299 | print('<--------- DL 2019 --------->') 300 | for metric in sorted(metrics): 301 | print('{}: {}'.format(metric, metrics[metric])) 302 | print('#############################\n') 303 | nDCG_dl2019 = round(metrics['nDCG @10'], 4) 304 | nDCG_dl2019_100 = round(metrics['nDCG @100'], 4) 305 | if nDCG_dl2019 > best_nDCG_dl2019: 306 | best_nDCG_dl2019 = nDCG_dl2019 307 | qd_rank.to_csv(f"{save_dir}/best_{config.model_type}_dl2019_qd_rank.tsv", sep='\t', index=False, 308 | header=False) 309 | 310 | # test msmarco dev 311 | qd_rank = eval_model( 312 | model, 313 | test_qd_loader, 314 | device, 315 | config, 316 | ) 317 | qd_rank.to_csv(f"{save_dir}/test_qd_rank.tsv", sep='\t', index=False, header=False) 318 | metrics = compute_metrics_from_files(config.test_qd_dir, f"{save_dir}/test_qd_rank.tsv") 319 | print('\n#####################') 320 | print('<----- MS Dev ----->') 321 | for metric in sorted(metrics): 322 | print('{}: {}'.format(metric, metrics[metric])) 323 | print('#####################\n') 324 | MRR_test = round(metrics['MRR @10'], 4) 325 | MRR_test_100 = round(metrics['MRR @100'], 4) 326 | if MRR_test > best_MRR_test: 327 | best_MRR_test = MRR_test 328 | qd_rank.to_csv(f"{save_dir}/best_{config.model_type}_test_qd_rank.tsv", sep='\t', index=False, header=False) 329 | 330 | print('[SAVE] Saving model ... ') 331 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 332 | torch.save(model_to_save.state_dict(),f"{save_dir}/{config.model_type}_{MRR_test}_{MRR_test_100}_e{epoch + 1}") 333 | 334 | else: # test 335 | print("\n========== Loading dev data ==========") 336 | test_qd_loader = get_test_qd_loader(test_top100, test_qs, doc_tokens, docid2id, config) 337 | print(f"test_q: {len(test_qs)}, test_q_batchs:{len(test_qd_loader)}") 338 | 339 | print("\n========== Loading DL 2019 data ==========") 340 | dl2019_qd_loader = get_test_qd_loader(dl2019_top100, dl2019_qs, doc_tokens, docid2id, config) 341 | print(f"dl2019_q: {len(dl2019_qs)}, dl2019_q_batchs:{len(dl2019_qd_loader)}") 342 | 343 | qd_rank = eval_model( 344 | model, 345 | dl2019_qd_loader, 346 | device, 347 | config, 348 | ) 349 | df_rank = pd.DataFrame(columns=['q_id', 'Q0', 'd_id', 'rank', 'score', 'standard']) 350 | df_rank['q_id'] = qd_rank['q_id'] 351 | df_rank['Q0'] = ['Q0'] * len(qd_rank['q_id']) 352 | df_rank['d_id'] = qd_rank['d_id'] 353 | df_rank['rank'] = qd_rank['rank'] 354 | df_rank['score'] = qd_rank['score'] 355 | df_rank['standard'] = ['STANDARD'] * len(qd_rank['q_id']) 356 | df_rank.to_csv(f"{save_dir}/dl2019_qd_rank_as100.tsv", sep=' ', index=False, header=False) 357 | result_lines = os.popen(f'trec_eval -m ndcg_cut.10,100 {config.dl2019_qd_dir} {save_dir}/dl2019_qd_rank_as100.tsv').read().strip().split("\n") 358 | ndcg_10, ndcg_100 = float(result_lines[0].strip().split()[-1]), float(result_lines[1].strip().split()[-1]) 359 | metrics = {'nDCG @10': ndcg_10, 'nDCG @100': ndcg_100, 'QueriesRanked': len(set(qd_rank['q_id']))} 360 | print('\n#############################') 361 | print('<--------- DL 2019 --------->') 362 | for metric in sorted(metrics): 363 | 
print('{}: {}'.format(metric, metrics[metric])) 364 | print('#############################\n') 365 | 366 | # test msmarco dev 367 | qd_rank = eval_model( 368 | model, 369 | test_qd_loader, 370 | device, 371 | config, 372 | ) 373 | qd_rank.to_csv(f"{save_dir}/test_qd_rank_as100.tsv", sep='\t', index=False, header=False) 374 | metrics = compute_metrics_from_files(config.test_qd_dir, f"{save_dir}/test_qd_rank_as100.tsv") 375 | print('\n#####################') 376 | print('<----- MS Dev ----->') 377 | for metric in sorted(metrics): 378 | print('{}: {}'.format(metric, metrics[metric])) 379 | print('#####################\n') 380 | -------------------------------------------------------------------------------- /pretrain/dataloader.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @ref: Axiomatically Regularized Pre-training for Ad hoc Search 3 | @author: Jia Chen, Yiqun Liu, Yan Fang, Jiaxin Mao, Hui Fang, Shenghao Yang, Xiaohui Xie, Min Zhang, Shaoping Ma. 4 | ''' 5 | # encoding: utf-8 6 | import random 7 | import numpy as np 8 | import collections 9 | import pandas as pd 10 | from tqdm import tqdm 11 | import xgboost as xgb 12 | from scipy.special import * 13 | from torch.utils.data import Dataset, DataLoader 14 | from torch.utils.data.distributed import DistributedSampler 15 | from transformers import BertTokenizer 16 | 17 | 18 | MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"]) 19 | # masked_lm_prob=0.15, max_predictions_per_seq=60, True, bert_vocab_list (id) 20 | def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, whole_word_mask, vocab_list, id2token): 21 | """Creates the predictions for the masked LM objective. This is mostly copied from the Google BERT repo, but 22 | with several refactors to clean it up and remove a lot of unnecessary variables.""" 23 | cand_indices = [] 24 | 25 | START_DOC = False 26 | for (i, token) in enumerate(tokens): # token_ids 27 | if token == 102: # SEP 28 | START_DOC = True 29 | continue 30 | if token == 101: # CLS 31 | continue 32 | if not START_DOC: 33 | continue 34 | 35 | if (whole_word_mask and len(cand_indices) >= 1 and id2token[token].startswith("##")): 36 | cand_indices[-1].append(i) 37 | else: 38 | cand_indices.append([i]) 39 | 40 | num_to_mask = min(max_predictions_per_seq, max(1, int(round(len(cand_indices) * masked_lm_prob)))) 41 | random.shuffle(cand_indices) 42 | masked_lms = [] 43 | covered_indexes = set() 44 | for index_set in cand_indices: 45 | if len(masked_lms) >= num_to_mask: 46 | break 47 | # If adding a whole-word mask would exceed the maximum number of 48 | # predictions, then just skip this candidate. 
49 | if len(masked_lms) + len(index_set) > num_to_mask: 50 | continue 51 | is_any_index_covered = False 52 | for index in index_set: 53 | if index in covered_indexes: 54 | is_any_index_covered = True 55 | break 56 | if is_any_index_covered: 57 | continue 58 | for index in index_set: 59 | covered_indexes.add(index) 60 | # 80% of the time, replace with [MASK] 61 | if random.random() < 0.8: 62 | masked_token = 103 63 | else: 64 | # 10% of the time, keep original 65 | if random.random() < 0.5: 66 | masked_token = tokens[index] 67 | # 10% of the time, replace with random word 68 | else: 69 | masked_token = random.choice(vocab_list) 70 | masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) 71 | tokens[index] = masked_token 72 | 73 | assert len(masked_lms) <= num_to_mask 74 | masked_lms = sorted(masked_lms, key=lambda x: x.index) 75 | mask_indices = [p.index for p in masked_lms] 76 | masked_token_labels = [p.label for p in masked_lms] 77 | 78 | return tokens, masked_token_labels, mask_indices 79 | 80 | 81 | def create_masked_lm_predictions_ict(tokens, masked_lm_prob, max_predictions_per_seq, whole_word_mask, vocab_list, id2token): 82 | """Creates the predictions for the masked LM objective. This is mostly copied from the Google BERT repo, but 83 | with several refactors to clean it up and remove a lot of unnecessary variables.""" 84 | cand_indices = [] 85 | for (i, token) in enumerate(tokens): # token_ids 86 | if (whole_word_mask and len(cand_indices) >= 1 and id2token[token].startswith("##")): # startswith ## 87 | cand_indices[-1].append(i) 88 | else: 89 | cand_indices.append([i]) 90 | 91 | num_to_mask = min(max_predictions_per_seq, max(1, int(round(len(cand_indices) * masked_lm_prob)))) 92 | random.shuffle(cand_indices) 93 | masked_lms = [] 94 | covered_indexes = set() 95 | for index_set in cand_indices: 96 | if len(masked_lms) >= num_to_mask: 97 | break 98 | # If adding a whole-word mask would exceed the maximum number of 99 | # predictions, then just skip this candidate. 
100 | if len(masked_lms) + len(index_set) > num_to_mask:
101 | continue
102 | is_any_index_covered = False
103 | for index in index_set:
104 | if index in covered_indexes:
105 | is_any_index_covered = True
106 | break
107 | if is_any_index_covered:
108 | continue
109 | for index in index_set:
110 | covered_indexes.add(index)
111 | # 80% of the time, replace with [MASK]
112 | if random.random() < 0.8:
113 | masked_token = 103
114 | else:
115 | # 10% of the time, keep original
116 | if random.random() < 0.5:
117 | masked_token = tokens[index]
118 | # 10% of the time, replace with random word
119 | else:
120 | masked_token = random.choice(vocab_list)
121 | masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
122 | tokens[index] = masked_token
123 |
124 | assert len(masked_lms) <= num_to_mask
125 | masked_lms = sorted(masked_lms, key=lambda x: x.index)
126 | mask_indices = [p.index for p in masked_lms]
127 | masked_token_labels = [p.label for p in masked_lms]
128 |
129 | return tokens, masked_token_labels, mask_indices
130 |
131 |
132 | class TrainICTPairwise(Dataset):
133 | def __init__(self, dids, d_dict, did2idx, config):
134 | self.dids = dids
135 | self.d_dict = d_dict
136 | self.did2idx = did2idx
137 | self.config = config
138 |
139 | self.tokenizer = BertTokenizer.from_pretrained(config.PRE_TRAINED_MODEL_NAME)
140 | self.vocab_list = list(self.tokenizer.vocab[key] for key in self.tokenizer.vocab)
141 | self.id2token = {self.tokenizer.vocab[key]: key for key in self.tokenizer.vocab}
142 | self.sep_token_id = self.tokenizer.vocab["."]
143 | self.cls_id = 101
144 |
145 | def __len__(self):
146 | return len(self.dids)
147 |
148 | def __getitem__(self, item):
149 | this_did = self.dids[item]
150 |
151 | doc_ids = self.d_dict[self.did2idx[this_did]].tolist()
152 | sep_pos = [-1] + [i for i, id in enumerate(doc_ids) if id == self.sep_token_id] + [len(doc_ids) - 1]
153 | sentences = [doc_ids[sep_pos[i] + 1: sep_pos[i + 1] + 1] for i in range(len(sep_pos) - 1)]
154 | removes = [random.random() < 0.9 for _ in range(len(sentences))]
155 |
156 | s_ids, c_ids = [], []
157 | b_token_ids, b_attention_mask, b_masked_lm_ids = np.array([[]]), np.array([[]]), np.array([[]])
158 |
159 | for idx, remove in enumerate(removes):
160 | if remove:
161 | sentence = [self.cls_id] + sentences[idx]
162 | context = sentences[: idx] + sentences[idx + 1:]
163 | context = [self.cls_id] + [w for s in context for w in s]
164 |
165 | sentence = sentence[: self.config.max_len]
166 | context = context[: self.config.max_len]
167 | s_ids.append(sentence)
168 | c_ids.append(context)
169 |
170 | s_input_ids = np.zeros(self.config.max_len, dtype=np.int64)  # np.int was removed in NumPy 1.20+
171 | c_input_ids = np.zeros(self.config.max_len, dtype=np.int64)
172 | s_input_ids[: len(sentence)] = sentence
173 | c_input_ids[: len(context)] = context
174 |
175 | s_attention_mask = np.int64(s_input_ids > 0)
176 | c_attention_mask = np.int64(c_input_ids > 0)
177 | attention_mask = np.stack((s_attention_mask, c_attention_mask))
178 |
179 | s_input_ids, s_masked_lm_ids, s_masked_lm_positions = create_masked_lm_predictions_ict(
180 | s_input_ids,
181 | masked_lm_prob=self.config.masked_lm_prob,
182 | max_predictions_per_seq=self.config.max_predictions_per_seq,
183 | whole_word_mask=True,
184 | vocab_list=self.vocab_list,
185 | id2token=self.id2token)
186 | c_input_ids, c_masked_lm_ids, c_masked_lm_positions = create_masked_lm_predictions_ict(
187 | c_input_ids,
188 | masked_lm_prob=self.config.masked_lm_prob,
189 |
max_predictions_per_seq=self.config.max_predictions_per_seq,
190 | whole_word_mask=True,
191 | vocab_list=self.vocab_list,
192 | id2token=self.id2token)
193 | s_lm_label_array = np.full(self.config.max_len, dtype=np.int64, fill_value=-1)  # np.int was removed in NumPy 1.20+
194 | c_lm_label_array = np.full(self.config.max_len, dtype=np.int64, fill_value=-1)
195 | s_lm_label_array[s_masked_lm_positions] = s_masked_lm_ids
196 | c_lm_label_array[c_masked_lm_positions] = c_masked_lm_ids
197 | masked_lm_ids = np.stack((s_lm_label_array, c_lm_label_array))
198 |
199 | token_ids = np.stack((s_input_ids.flatten(), c_input_ids.flatten()))
200 | b_token_ids = token_ids if len(b_token_ids) == 1 else np.concatenate((b_token_ids, token_ids), axis=0)
201 | b_attention_mask = attention_mask if len(b_attention_mask) == 1 else np.concatenate((b_attention_mask, attention_mask), axis=0)
202 | b_masked_lm_ids = masked_lm_ids if len(b_masked_lm_ids) == 1 else np.concatenate((b_masked_lm_ids, masked_lm_ids), axis=0)
203 |
204 | # clip
205 | b_token_ids = b_token_ids[: self.config.batch_size, :]
206 | b_attention_mask = b_attention_mask[: self.config.batch_size, :]
207 | b_masked_lm_ids = b_masked_lm_ids[: self.config.batch_size, :] # no greater than max batch size
208 |
209 | return {
210 | 'token_ids': b_token_ids, # b x 2
211 | 'attention_mask': b_attention_mask,
212 | 'masked_lm_ids': b_masked_lm_ids,
213 | }
214 |
215 |
216 | def get_ict_loader(d_dict, did2idx, config):
217 |
218 | dids = list(did2idx.keys())
219 | print('Loading tokens...')
220 | ds = TrainICTPairwise(
221 | dids=dids,
222 | d_dict=d_dict,
223 | did2idx=did2idx,
224 | config=config
225 | )
226 | batch_size = 1  # each __getitem__ already returns a per-document batch of (sentence, context) pairs
227 | if config.distributed_train:
228 | sampler = DistributedSampler(ds, num_replicas=config.world_size, rank=config.local_rank)
229 | return DataLoader(
230 | ds,
231 | batch_size=batch_size,
232 | num_workers=0,
233 | sampler=sampler
234 | )
235 | else:
236 | return DataLoader(
237 | ds,
238 | batch_size=batch_size,
239 | num_workers=0,
240 | shuffle=True,
241 | )
242 |
243 |
244 | class TrainQDDatasetPairwise(Dataset):
245 | def __init__(self, q_ids, d_ids, d_dict, did2idx, config, gen_qs, gen_qid2id):
246 | self.q_ids = q_ids
247 | self.d_ids = d_ids
248 | self.d_dict = d_dict
249 | self.did2idx = did2idx
250 | self.gen_qs = gen_qs
251 | self.gen_qid2id = gen_qid2id
252 | self.config = config
253 | self.tokenizer = BertTokenizer.from_pretrained(self.config.PRE_TRAINED_MODEL_NAME)
254 | self.vocab_list = list(self.tokenizer.vocab[key] for key in self.tokenizer.vocab)
255 | self.id2token = {self.tokenizer.vocab[key]: key for key in self.tokenizer.vocab}
256 |
257 | def __len__(self):
258 | return len(self.q_ids)
259 |
260 | def __getitem__(self, item):
261 | cls_id, sep_id = 101, 102
262 | q_id = self.q_ids[item]
263 | d_id = self.d_ids[item]
264 |
265 | pos_q_id, neg_q_id = q_id[0], q_id[1]
266 | did = d_id[0]
267 |
268 | pos_query_input_ids = self.gen_qs[self.gen_qid2id[pos_q_id]].tolist()
269 | neg_query_input_ids = self.gen_qs[self.gen_qid2id[neg_q_id]].tolist()
270 |
271 | doc_input_ids = self.d_dict[self.did2idx[did]].tolist()
272 | pos_query_input_ids = pos_query_input_ids[: self.config.max_q_len]
273 | neg_query_input_ids = neg_query_input_ids[: self.config.max_q_len]
274 |
275 | pos_max_passage_length = self.config.max_len - 3 - len(pos_query_input_ids)
276 | neg_max_passage_length = self.config.max_len - 3 - len(neg_query_input_ids)
277 |
278 | pos_doc_input_ids = doc_input_ids[:pos_max_passage_length]
279 | neg_doc_input_ids =
doc_input_ids[:neg_max_passage_length]
280 |
281 | pos_input_ids = [cls_id] + pos_query_input_ids + [sep_id] + pos_doc_input_ids + [sep_id]
282 | neg_input_ids = [cls_id] + neg_query_input_ids + [sep_id] + neg_doc_input_ids + [sep_id]
283 |
284 | pos_token_type_ids = [0] * (2 + len(pos_query_input_ids)) + [1] * (1 + len(pos_doc_input_ids))
285 | neg_token_type_ids = [0] * (2 + len(neg_query_input_ids)) + [1] * (1 + len(neg_doc_input_ids))
286 |
287 | pos_token_ids = np.array(pos_input_ids)
288 | neg_token_ids = np.array(neg_input_ids)
289 |
290 | pos_attention_mask = np.int64(pos_token_ids > 0)
291 | neg_attention_mask = np.int64(neg_token_ids > 0)
292 | attention_mask = np.stack((pos_attention_mask, neg_attention_mask))
293 |
294 | pos_token_type_ids = np.array(pos_token_type_ids)
295 | neg_token_type_ids = np.array(neg_token_type_ids)
296 | token_type_ids = np.stack((pos_token_type_ids, neg_token_type_ids))
297 |
298 | pos_token_ids, pos_masked_lm_ids, pos_masked_lm_positions = create_masked_lm_predictions(
299 | pos_token_ids,
300 | masked_lm_prob=self.config.masked_lm_prob,
301 | max_predictions_per_seq=self.config.max_predictions_per_seq,
302 | whole_word_mask=True,
303 | vocab_list=self.vocab_list,
304 | id2token=self.id2token)
305 | neg_token_ids, neg_masked_lm_ids, neg_masked_lm_positions = create_masked_lm_predictions(
306 | neg_token_ids,
307 | masked_lm_prob=self.config.masked_lm_prob,
308 | max_predictions_per_seq=self.config.max_predictions_per_seq,
309 | whole_word_mask=True,
310 | vocab_list=self.vocab_list,
311 | id2token=self.id2token)
312 | token_ids = np.stack((pos_token_ids.flatten(), neg_token_ids.flatten()))
313 |
314 | pos_lm_label_array = np.full(self.config.max_len, dtype=np.int64, fill_value=-1)  # np.int was removed in NumPy 1.20+
315 | neg_lm_label_array = np.full(self.config.max_len, dtype=np.int64, fill_value=-1)
316 | pos_lm_label_array[pos_masked_lm_positions] = pos_masked_lm_ids
317 | neg_lm_label_array[neg_masked_lm_positions] = neg_masked_lm_ids
318 |
319 | masked_lm_ids = np.stack((pos_lm_label_array, neg_lm_label_array))
320 |
321 | return {
322 | 'token_ids': token_ids,
323 | 'attention_mask': attention_mask,
324 | 'token_type_ids': token_type_ids,
325 | 'masked_lm_ids': masked_lm_ids,
326 | }
327 |
328 |
329 | # [CLS] q [SEP] d [SEP]
330 | def get_train_qd_loader(d_dict, did2idx, config, doc2query=None, gen_qs=None, gen_qid2id=None, axiom_feature=None):
331 | q_max_len, max_len, batch_size = config.max_q_len, config.max_len, config.batch_size
332 |
333 | new_q_ids, new_d_ids = [], []
334 | doc_num = len(did2idx)
335 | dids = list(did2idx.keys())
336 |
337 | # loading xgboost model
338 | model = xgb.XGBRFClassifier()
339 | model.load_model(config.clf_model)
340 |
341 | all_case = []
342 | for idx in tqdm(range(doc_num), desc="Sampling Pre-train Query Pairs progress"):
343 | this_did = dids[idx]
344 | if this_did not in doc2query:
345 | continue
346 |
347 | qids = [[qid] for qid in doc2query[this_did]]
348 | q_num = len(qids)
349 | for i in range(q_num):
350 | q_id = qids[i][0]
351 | gen_idx = gen_qid2id[q_id]  # renamed from idx to avoid shadowing the outer document-loop variable
352 | for k in range(len(axiom_feature)):
353 | this_feature_name, this_feature = axiom_feature[k][0], axiom_feature[k][1]
354 | if this_feature_name == 'RANK':
355 | score = this_feature[gen_idx][0] if this_feature[gen_idx][0] != 0 else 1e12
356 | else:
357 | score = this_feature[gen_idx][0]
358 | score = this_feature[gen_idx][0] if this_feature_name not in ['PROX-1', 'PROX-2', 'RANK'] else (1 / (score + 1e-12))
359 | qids[i].append(score)
360 |
361 | all_pairs = []
362 | for i in range(q_num):
363 | for
j in range(i+1, q_num): 364 | q1, q2 = qids[i], qids[j] 365 | all_pairs.append([q1, q2]) 366 | 367 | k = min(2, len(all_pairs)) 368 | sampled_pairs = random.sample(all_pairs, k=k) 369 | 370 | for pair in sampled_pairs: 371 | qid1, qid2 = pair[0][0], pair[1][0] 372 | case = [] 373 | for i in range(len(axiom_feature)): 374 | axiom_1 = pair[0][i + 1] 375 | axiom_2 = pair[1][i + 1] 376 | if axiom_1 > axiom_2: 377 | case.append(1) 378 | elif axiom_1 == axiom_2: 379 | case.append(0) 380 | else: 381 | case.append(-1) 382 | all_case.append(case) 383 | new_q_ids.append([qid1, qid2]) 384 | new_d_ids.append([this_did]) 385 | 386 | all_case = pd.DataFrame(np.array(all_case)) 387 | all_case.columns = ['PROX-1', 'PROX-2', 'REP-QL', 'REP-TFIDF', 'REG', 'STM-1', 'STM-2', 'STM-3', 'RANK'] 388 | pred_prob = model.predict(all_case) 389 | for idx, pred in enumerate(pred_prob): 390 | result = 1 if pred > 0.5 else 0 391 | if result == 0: # swap 392 | qid1 = new_q_ids[idx][0] 393 | qid2 = new_q_ids[idx][1] 394 | new_q_ids[idx][0] = qid2 395 | new_q_ids[idx][1] = qid1 396 | 397 | print('Loading tokens...') 398 | ds = TrainQDDatasetPairwise( 399 | q_ids=new_q_ids, 400 | d_ids=new_d_ids, 401 | d_dict=d_dict, 402 | did2idx=did2idx, 403 | config=config, 404 | gen_qs=gen_qs, 405 | gen_qid2id=gen_qid2id, 406 | ) 407 | batch_size = batch_size // 2 408 | 409 | if config.distributed_train: 410 | sampler = DistributedSampler(ds, num_replicas=config.world_size, rank=config.local_rank) 411 | return DataLoader( 412 | ds, 413 | batch_size=batch_size, 414 | num_workers=0, 415 | sampler=sampler 416 | ) 417 | else: 418 | return DataLoader( 419 | ds, 420 | batch_size=batch_size, 421 | num_workers=0, 422 | shuffle=True, 423 | ) 424 | 425 | 426 | -------------------------------------------------------------------------------- /visualization/output_ARES_simple.html: -------------------------------------------------------------------------------- 1 |
(HTML table; markup was lost in extraction) QID | DID | Relevance Level/Rank | Word Importance
2 | 156493 | D3356945 | 3/1 | [CLS] do goldfish grow [SEP] https : / / answers . yahoo . com / question / index ? qid = 20100226170159aawholx how to make goldfish grow faster ? " pets fish how to make goldfish grow faster ? just wondering ? update : what kind of foods could i use ? would warmer water help ? update 2 : gabe tech , retard they aren ' t in a bowl and if i did what you said , they ' d die ! follow 18 answers answers relevance rating newest oldest best answer : really people ? if you put a small child into a large house , will he grow faster ? no ! a tank that is too small will slow his growth down and even stop it but a bigger tank than needed won ' t have any effect . make sure his water is good and that he has adequate room and food , and he will grow at his own pace . really people fish are just like any other animal on the planet they aren ' t little aliens . t he only thing weird about how a fish grows is that they put out a hormone into the water that will slow down the growth of other fish and them selves . and dont put fill your bowl with juice thats an acid and it will kill your
--------------------------------------------------------------------------------