├── src ├── data │ └── config │ │ ├── index │ │ ├── none.yaml │ │ ├── _default.yaml │ │ ├── trie.yaml │ │ ├── fm.yaml │ │ ├── faiss.yaml │ │ ├── impact-tok.yaml │ │ ├── impact-word.yaml │ │ ├── invhit.yaml │ │ ├── invvec.yaml │ │ ├── impact.yaml │ │ ├── bm25.yaml │ │ └── wordset.yaml │ │ ├── mode │ │ ├── script.yaml │ │ ├── encode-query.yaml │ │ ├── encode-text.yaml │ │ ├── code.yaml │ │ ├── deploy.yaml │ │ ├── migrate.yaml │ │ ├── eval.yaml │ │ ├── index.yaml │ │ ├── encode.yaml │ │ ├── cluster.yaml │ │ ├── _eval.yaml │ │ └── train.yaml │ │ ├── script │ │ ├── _default.yaml │ │ ├── download.yaml │ │ ├── eval.yaml │ │ ├── negative.yaml │ │ ├── ttest.yaml │ │ ├── preprocess.yaml │ │ └── doct5.yaml │ │ ├── model │ │ ├── ranker.yaml │ │ ├── dense.yaml │ │ ├── sparse.yaml │ │ ├── generative.yaml │ │ └── _default.yaml │ │ ├── extra │ │ └── code.yaml │ │ ├── base │ │ ├── NQ-open.yaml │ │ ├── MSMARCO-passage.yaml │ │ ├── MS300k.yaml │ │ ├── MS600k.yaml │ │ ├── MS300k-unseen.yaml │ │ ├── MSMARCO-doc.yaml │ │ ├── Top300k-filter.yaml │ │ ├── Rand300k-filter.yaml │ │ ├── NQ320k.yaml │ │ ├── NQ-50k-seen.yaml │ │ ├── NQ320k-seen.yaml │ │ ├── NQ-50k-unseen.yaml │ │ ├── NQ320k-unseen.yaml │ │ ├── LECARD.yaml │ │ └── _default.yaml │ │ ├── _default.yaml │ │ ├── _example.yaml │ │ ├── bm25.yaml │ │ ├── crossenc.yaml │ │ ├── colbert.yaml │ │ ├── coil.yaml │ │ ├── dpr.yaml │ │ ├── ar2.yaml │ │ ├── contriever.yaml │ │ ├── ance.yaml │ │ ├── deepimpact.yaml │ │ ├── gtr.yaml │ │ ├── rankt5.yaml │ │ ├── unicoil.yaml │ │ ├── sparta.yaml │ │ ├── retromae.yaml │ │ ├── ivf.yaml │ │ ├── genre.yaml │ │ ├── dsiqg.yaml │ │ ├── seal.yaml │ │ ├── spladev2.yaml │ │ ├── tsgen.yaml │ │ ├── dsi.yaml │ │ ├── bow.yaml │ │ ├── deepspeed.json │ │ ├── tokivf.yaml │ │ ├── topivf.yaml │ │ ├── bivfpq-nq.yaml │ │ ├── uniretriever.yaml │ │ ├── bivfpq.yaml │ │ ├── sequer.yaml │ │ └── distillvq.yaml ├── models │ ├── AR2.py │ ├── KeyRank.py │ ├── DeepImpact.py │ ├── RankT5.py │ ├── AutoModel.py │ ├── ColBERT.py │ ├── CrossEnc.py │ ├── SPARTA.py │ ├── UniRetriever.py │ ├── SEAL.py │ ├── SPLADE.py │ ├── COIL.py │ ├── DPR.py │ ├── Sequer.py │ ├── DSI.py │ ├── IVF.py │ ├── BOW.py │ ├── TSGen.py │ ├── UniCOIL.py │ ├── BM25.py │ └── VQ.py ├── utils │ ├── __init__.py │ └── static.py ├── scripts │ ├── download.py │ ├── eval.py │ ├── select_phrases.py │ ├── negative.py │ ├── ttest.py │ ├── doct5.py │ ├── evalnq.py │ └── preprocess.py ├── run.py └── notebooks │ └── tsgen.ipynb ├── .gitattributes ├── LICENSE ├── README.md └── .gitignore /src/data/config/index/none.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data/config/mode/script.yaml: -------------------------------------------------------------------------------- 1 | mode: script -------------------------------------------------------------------------------- /src/data/config/mode/encode-query.yaml: -------------------------------------------------------------------------------- 1 | mode: encode-query -------------------------------------------------------------------------------- /src/data/config/mode/encode-text.yaml: -------------------------------------------------------------------------------- 1 | mode: encode-text -------------------------------------------------------------------------------- /src/data/config/mode/code.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | mode: code 
-------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | anserini-*.jar filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /src/data/config/mode/deploy.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | mode: deploy 3 | -------------------------------------------------------------------------------- /src/data/config/mode/migrate.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | mode: migrate -------------------------------------------------------------------------------- /src/data/config/mode/eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_.eval 2 | defaults: 3 | - _eval 4 | 5 | mode: eval 6 | -------------------------------------------------------------------------------- /src/data/config/mode/index.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _eval 4 | 5 | mode: index 6 | -------------------------------------------------------------------------------- /src/data/config/mode/encode.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_.infer 2 | 3 | mode: encode 4 | do_text: true 5 | do_query: true -------------------------------------------------------------------------------- /src/data/config/script/_default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | # essential to relocate the package 3 | - /_default 4 | 5 | mode: script -------------------------------------------------------------------------------- /src/data/config/index/_default.yaml: -------------------------------------------------------------------------------- 1 | index_type: none 2 | 3 | # load the existing index 4 | load_index: false 5 | save_index: false 6 | -------------------------------------------------------------------------------- /src/data/config/model/ranker.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | 4 | # concate query and text as sentence pair 5 | return_pair: true 6 | -------------------------------------------------------------------------------- /src/data/config/index/trie.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | - _self_ 4 | 5 | index_type: trie 6 | load_index: true 7 | save_index: true 8 | -------------------------------------------------------------------------------- /src/data/config/index/fm.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | - _self_ 4 | 5 | index_type: fm 6 | load_index: true 7 | load_collection: true 8 | index_thread: 32 9 | -------------------------------------------------------------------------------- /src/data/config/index/faiss.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | - _self_ 4 | 5 | index_type: flat 6 | 7 | nprobe: 1 8 | by_residual: true 9 | hnswef: 1000 10 | -------------------------------------------------------------------------------- /src/data/config/model/dense.yaml: 
-------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | 4 | # metric used for dense retrieval 5 | dense_metric: ip 6 | # separate query and document encoder? 7 | untie_encoder: false 8 | -------------------------------------------------------------------------------- /src/data/config/index/impact-tok.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | - _self_ 4 | 5 | index_type: impact-tok 6 | 7 | load_collection: false 8 | quantize_bit: 8 9 | language: eng 10 | 11 | index_thread: 32 12 | -------------------------------------------------------------------------------- /src/data/config/extra/code.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_.code 2 | 3 | code_type: none 4 | code_length: 0 5 | code_tokenizer: t5 6 | code_sep: " " 7 | code_src: none 8 | return_code: true 9 | return_query_code: false 10 | -------------------------------------------------------------------------------- /src/data/config/index/impact-word.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | - _self_ 4 | 5 | index_type: impact-word 6 | 7 | load_collection: false 8 | quantize_bit: 8 9 | language: eng 10 | 11 | reduce: max 12 | 13 | index_thread: 32 14 | -------------------------------------------------------------------------------- /src/data/config/base/NQ-open.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | 4 | dataset: NQ-open 5 | text_length: 128 6 | query_length: 32 7 | 8 | max_text_length: 256 9 | max_query_length: 64 10 | 11 | plm: bert 12 | 13 | text_col: [1, 2] 14 | text_col_sep: sep 15 | -------------------------------------------------------------------------------- /src/data/config/mode/cluster.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | mode: cluster 3 | # the dense embedding can be used for clustering 4 | cluster_type: hier-l2 5 | # the number of clusters 6 | ncluster: 10 7 | # the number of leaf nodes in hierarchical clustering 8 | nleaf: 100 9 | -------------------------------------------------------------------------------- /src/data/config/base/MSMARCO-passage.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | 4 | dataset: MSMARCO-passage 5 | text_length: 128 6 | query_length: 32 7 | 8 | max_text_length: 256 9 | max_query_length: 64 10 | 11 | plm: bert 12 | 13 | text_col: [2] 14 | text_col_sep: " " 15 | -------------------------------------------------------------------------------- /src/data/config/index/invhit.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | - _self_ 4 | 5 | index_type: invhit 6 | 7 | # the fraction or number of items per posting 8 | posting_prune: 0. 9 | index_shard: 32 10 | index_thread: 10 11 | load_index: true 12 | save_index: true 13 | -------------------------------------------------------------------------------- /src/data/config/index/invvec.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | - _self_ 4 | 5 | index_type: invvec 6 | 7 | # the fraction or number of items per posting 8 | posting_prune: 0.
9 | index_shard: 32 10 | index_thread: 10 11 | load_index: true 12 | save_index: true 13 | -------------------------------------------------------------------------------- /src/data/config/index/impact.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | - _self_ 4 | 5 | index_type: impact 6 | 7 | load_collection: false 8 | load_index: false 9 | 10 | index_thread: 32 11 | language: eng 12 | 13 | quantize_bit: 8 14 | granularity: token 15 | reduce: max 16 | -------------------------------------------------------------------------------- /src/data/config/base/MS300k.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | 4 | dataset: MS300k 5 | text_length: 512 6 | query_length: 64 7 | 8 | max_text_length: 512 9 | max_query_length: 64 10 | 11 | plm: t5 12 | 13 | # title + abstract 14 | text_col: [1, 2, 3] 15 | text_col_sep: " " 16 | -------------------------------------------------------------------------------- /src/data/config/base/MS600k.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | 4 | dataset: MS600k 5 | text_length: 512 6 | query_length: 64 7 | 8 | max_text_length: 512 9 | max_query_length: 64 10 | 11 | plm: t5 12 | 13 | # title + abstract 14 | text_col: [1, 2, 3] 15 | text_col_sep: " " 16 | -------------------------------------------------------------------------------- /src/data/config/base/MS300k-unseen.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | 4 | dataset: MS300k-unseen 5 | text_length: 512 6 | query_length: 64 7 | 8 | max_text_length: 512 9 | max_query_length: 64 10 | 11 | plm: t5 12 | 13 | # title + abstract 14 | text_col: [1, 2, 3] 15 | text_col_sep: " " 16 | -------------------------------------------------------------------------------- /src/data/config/base/MSMARCO-doc.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | 4 | dataset: MSMARCO-doc 5 | text_length: 512 6 | query_length: 64 7 | 8 | max_text_length: 512 9 | max_query_length: 64 10 | 11 | plm: t5 12 | 13 | # title + url + text 14 | text_col: [2, 1, 3] 15 | text_col_sep: " " 16 | -------------------------------------------------------------------------------- /src/data/config/base/Top300k-filter.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | 4 | dataset: Top300k-filter 5 | text_length: 512 6 | query_length: 64 7 | 8 | max_text_length: 512 9 | max_query_length: 64 10 | 11 | plm: t5 12 | 13 | # title + abstract 14 | text_col: [1, 2, 3] 15 | text_col_sep: " " 16 | -------------------------------------------------------------------------------- /src/data/config/base/Rand300k-filter.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | 4 | dataset: Rand300k-filter 5 | text_length: 512 6 | query_length: 64 7 | 8 | max_text_length: 512 9 | max_query_length: 64 10 | 11 | plm: t5 12 | 13 | # title + abstract 14 | text_col: [1, 2, 3] 15 | text_col_sep: " " 16 | -------------------------------------------------------------------------------- /src/models/AR2.py: -------------------------------------------------------------------------------- 1 | from .DPR import DPR 2 | from utils.util import load_pickle 3 | 4 | 5 | class AR2(DPR): 6 | def 
__init__(self, config): 7 | super().__init__(config) 8 | 9 | 10 | def forward(self, x): 11 | raise NotImplementedError("AR2 training not implemented!") 12 | -------------------------------------------------------------------------------- /src/data/config/_default.yaml: -------------------------------------------------------------------------------- 1 | # disable the hydra outputs 2 | defaults: 3 | - _self_ 4 | # override package to be imported from the folder 5 | - override /hydra/hydra_logging@_group_: none 6 | - override /hydra/job_logging@_group_: none 7 | 8 | hydra: 9 | output_subdir: null 10 | run: 11 | dir: . -------------------------------------------------------------------------------- /src/data/config/index/bm25.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | - _self_ 4 | 5 | index_type: bm25 6 | 7 | load_collection: false 8 | load_index: false 9 | 10 | index_thread: 32 11 | language: eng 12 | 13 | # k1 and b used in bm25 14 | k1: 0.82 15 | b: 0.68 16 | pretokenize: false 17 | granularity: word 18 | 19 | -------------------------------------------------------------------------------- /src/data/config/base/NQ320k.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | 4 | dataset: NQ320k 5 | text_length: 512 6 | query_length: 32 7 | 8 | max_text_length: 512 9 | max_query_length: 64 10 | 11 | plm: t5 12 | 13 | # title + abstract + content 14 | text_col: [1, 2, 3] 15 | # t5 by default has no sep token 16 | text_col_sep: " " 17 | -------------------------------------------------------------------------------- /src/data/config/base/NQ-50k-seen.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | 4 | dataset: NQ-50k-seen 5 | text_length: 512 6 | query_length: 32 7 | 8 | max_text_length: 512 9 | max_query_length: 64 10 | 11 | plm: t5 12 | 13 | # title + abstract + content 14 | text_col: [1, 2, 3] 15 | # t5 by default has no sep token 16 | text_col_sep: " " 17 | -------------------------------------------------------------------------------- /src/data/config/base/NQ320k-seen.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | 4 | dataset: NQ320k-seen 5 | text_length: 512 6 | query_length: 32 7 | 8 | max_text_length: 512 9 | max_query_length: 64 10 | 11 | plm: t5 12 | 13 | # title + abstract + content 14 | text_col: [1, 2, 3] 15 | # t5 by default has no sep token 16 | text_col_sep: " " 17 | -------------------------------------------------------------------------------- /src/data/config/base/NQ-50k-unseen.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | 4 | dataset: NQ-50k-unseen 5 | text_length: 512 6 | query_length: 32 7 | 8 | max_text_length: 512 9 | max_query_length: 64 10 | 11 | plm: t5 12 | 13 | # title + abstract + content 14 | text_col: [1, 2, 3] 15 | # t5 by default has no sep token 16 | text_col_sep: " " 17 | -------------------------------------------------------------------------------- /src/data/config/base/NQ320k-unseen.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | 4 | dataset: NQ320k-unseen 5 | text_length: 512 6 | query_length: 32 7 | 8 | max_text_length: 512 9 | max_query_length: 64 10 | 11 | plm: t5 12 | 13 | # title + abstract + content 14 | text_col: [1, 2, 3] 
15 | # t5 by default has no sep token 16 | text_col_sep: " " 17 | -------------------------------------------------------------------------------- /src/data/config/base/LECARD.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | 4 | dataset: LECARD 5 | text_length: 512 6 | query_length: 512 7 | 8 | max_text_length: 512 9 | max_query_length: 512 10 | 11 | plm: bert-chinese 12 | 13 | text_col: [1] 14 | text_col_sep: "sep" 15 | 16 | eval_metric: [mrr, map, precision, ndcg, recall] 17 | eval_metric_cutoff: [5, 10, 20, 30] -------------------------------------------------------------------------------- /src/data/config/index/wordset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | - _self_ 4 | 5 | index_type: wordset 6 | index_thread: 10 7 | index_shard: 32 8 | 9 | load_index: true 10 | save_index: true 11 | 12 | # early stop when decoding? (specifically designed for wordset index) 13 | wordset_early_stop: true 14 | # at which step to enable early stop 15 | early_stop_start_len: 0 16 | -------------------------------------------------------------------------------- /src/models/KeyRank.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .BaseModel import BaseModel 3 | 4 | class KeyRank(BaseModel): 5 | """ 6 | Select keywords from the document for ranking, using REINFORCE policy gradient. 7 | """ 8 | def __init__(self, config): 9 | super().__init__(config) 10 | 11 | 12 | def forward(self, x): 13 | pass 14 | 15 | 16 | def rerank_step(self, x): 17 | pass 18 | -------------------------------------------------------------------------------- /src/data/config/script/download.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # very essential to put the package directive so that the following config parameters are situated at the root layer 3 | 4 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 5 | defaults: 6 | - _default 7 | # add group package so the default list can be overridden from cli by name 8 | - /base@_group_: MSMARCO-passage 9 | -------------------------------------------------------------------------------- /src/data/config/model/sparse.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | 4 | # the number of tokens to keep in text 5 | text_gate_k: 0 6 | # the number of tokens to keep in query 7 | query_gate_k: 0 8 | 9 | # return attention mask for the eos/sep token 10 | return_special_mask: false 11 | # return the attention mask for the first occurrence of a token in a piece of text 12 | return_first_mask: false 13 | # separate query and document encoder?
14 | untie_encoder: false 15 | -------------------------------------------------------------------------------- /src/data/config/_example.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - mode: train 6 | - model: sparse 7 | - index: invvec 8 | # add _self_ here so that the following arguments can be rewritten 9 | - _self_ 10 | 11 | base: 12 | eval_batch_size: 2 13 | num_worker: 0 14 | 15 | train: 16 | batch_size: 2 17 | neg_type: random 18 | -------------------------------------------------------------------------------- /src/data/config/bm25.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: sparse 6 | - index: bm25 7 | - mode: eval 8 | # add _self_ here so that the following arguments can be rewritten 9 | - _self_ 10 | 11 | base: 12 | device: cpu 13 | plm: bert 14 | eval_batch_size: 1 15 | 16 | model: 17 | model_type: bm25 18 | 19 | index: 20 | pretokenize: false 21 | -------------------------------------------------------------------------------- /src/data/config/crossenc.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: ranker 6 | - train: neg 7 | - mode: train 8 | # add _self_ here so that the following arguments can be rewritten 9 | - _self_ 10 | 11 | base: 12 | plm: bert 13 | 14 | model: 15 | model_type: crossenc 16 | 17 | train: 18 | batch_size: 16 19 | learning_rate: 3e-5 20 | nneg: 7 21 | -------------------------------------------------------------------------------- /src/data/config/colbert.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: ranker 6 | - mode: train 7 | # add _self_ here so that the following arguments can be rewritten 8 | - _self_ 9 | 10 | base: 11 | plm: bert 12 | 13 | model: 14 | model_type: colbert 15 | token_dim: 128 16 | 17 | train: 18 | learning_rate: 3e-5 19 | 20 | eval: 21 | eval_mode: rerank 22 | -------------------------------------------------------------------------------- /src/data/config/script/eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # very essential to put the package directive so that the following config parameters are situated at the root layer 3 | 4 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 5 | defaults: 6 | - _default 7 | # add group package so the default list can be overriden from cli by name 8 | - /base@_group_: MSMARCO-passage 9 | - /mode@_here_: 10 | - _eval 11 | - _self_ 12 | 13 | src: ??? 
14 | -------------------------------------------------------------------------------- /src/data/config/coil.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: sparse 6 | - index: invvec 7 | - mode: train 8 | # add _self_ here so that the following arguments can be rewritten 9 | - _self_ 10 | 11 | model: 12 | model_type: coil 13 | return_special_mask: true 14 | token_dim: 32 15 | 16 | train: 17 | nneg: 7 18 | 19 | eval: 20 | eval_posting_length: true 21 | -------------------------------------------------------------------------------- /src/data/config/script/negative.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # very essential to put the package directive so that the following config parameters are situated at the root layer 3 | 4 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 5 | defaults: 6 | - _default 7 | # add group package so the default list can be overriden from cli by name 8 | - /base@_group_: MSMARCO-passage 9 | - _self_ 10 | 11 | query_set: [train] 12 | neg_type: BM25 13 | save_name: default 14 | hits: 200 -------------------------------------------------------------------------------- /src/data/config/dpr.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: dense 6 | - index: faiss 7 | - mode: train 8 | # add _self_ here so that the following arguments can be rewritten 9 | - _self_ 10 | 11 | base: 12 | plm: bert 13 | 14 | model: 15 | model_type: dpr 16 | 17 | index: 18 | index_type: Flat 19 | 20 | train: 21 | nneg: 7 22 | learning_rate: 3e-5 23 | scheduler: linear 24 | -------------------------------------------------------------------------------- /src/models/DeepImpact.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .UniCOIL import UniCOIL 3 | from .BaseModel import BaseSparseModel 4 | 5 | 6 | 7 | class DeepImpact(UniCOIL): 8 | def __init__(self, config): 9 | """ 10 | `DeepImpact model `_ 11 | """ 12 | super().__init__(config) 13 | 14 | 15 | def encode_query_step(self, x): 16 | """ 17 | not contextualized 18 | """ 19 | return BaseSparseModel.encode_query_step(self, x) 20 | -------------------------------------------------------------------------------- /src/data/config/ar2.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: dense 6 | - index: faiss 7 | - mode: eval 8 | # add _self_ here so that the following arguments can be rewritten 9 | - _self_ 10 | 11 | base: 12 | text_col: [1, 2] 13 | 14 | model: 15 | model_type: ar2 16 | untie_encoder: true 17 | 18 | index: 19 | index_type: Flat 20 | 21 | eval: 22 | eval_posting_length: true 23 | 24 | -------------------------------------------------------------------------------- /src/data/config/contriever.yaml: -------------------------------------------------------------------------------- 1 | # load the 
default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: dense 6 | - index: faiss 7 | - mode: train 8 | # add _self_ here so that the following arguments can be rewritten 9 | - _self_ 10 | 11 | base: 12 | plm: bert 13 | 14 | model: 15 | model_type: contriever 16 | 17 | index: 18 | index_type: Flat 19 | 20 | train: 21 | nneg: 7 22 | learning_rate: 3e-5 23 | scheduler: linear 24 | -------------------------------------------------------------------------------- /src/data/config/model/generative.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _default 3 | 4 | # the beams when decoding 5 | nbeam: 10 6 | # how to measure relevance, generative probability or eos hidden states 7 | rank_type: prob 8 | # stop if the threshold has been reduced to 9 | beam_trsd: 0 10 | # how many steps to examine threshold 11 | trsd_start_len: 0 12 | # sample instead of topk? 13 | decode_do_sample: false 14 | decode_do_greedy: false 15 | decode_renorm_logit: false 16 | 17 | sample_topk: null 18 | sample_topp: null 19 | sample_typicalp: null 20 | sample_tau: null 21 | -------------------------------------------------------------------------------- /src/data/config/script/ttest.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # very essential to put the package directive so that the following config parameters are situated at the root layer 3 | 4 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 5 | defaults: 6 | - _default 7 | # add group package so the default list can be overriden from cli by name 8 | - /base@_group_: MSMARCO-passage 9 | - /mode@_group_: script 10 | - _self_ 11 | 12 | x_model: ??? 13 | y_model: ??? 
14 | ttest_metric: [MRR@10, Recall@10] 15 | -------------------------------------------------------------------------------- /src/data/config/ance.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: dense 6 | - index: faiss 7 | - mode: train 8 | # add _self_ here so that the following arguments can be rewritten 9 | - _self_ 10 | 11 | base: 12 | plm: bert 13 | 14 | model: 15 | model_type: dpr 16 | 17 | index: 18 | index_type: Flat 19 | 20 | train: 21 | nneg: 7 22 | neg_type: DPR 23 | learning_rate: 3e-5 24 | scheduler: linear 25 | 26 | -------------------------------------------------------------------------------- /src/data/config/deepimpact.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: sparse 6 | - index: impact 7 | - mode: train 8 | # add _self_ here so that the following arguments can be rewritten 9 | - _self_ 10 | 11 | base: 12 | plm: bert 13 | 14 | model: 15 | model_type: deepimpact 16 | return_special_mask: true 17 | 18 | index: 19 | granularity: word 20 | 21 | train: 22 | nneg: 7 23 | max_grad_norm: 2.0 24 | -------------------------------------------------------------------------------- /src/data/config/gtr.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: dense 6 | - index: faiss 7 | - mode: train 8 | # add _self_ here so that the following arguments can be rewritten 9 | - _self_ 10 | 11 | base: 12 | plm: bert 13 | 14 | model: 15 | model_type: gtr 16 | dense_metric: cos 17 | 18 | index: 19 | index_type: Flat 20 | 21 | train: 22 | nneg: 7 23 | learning_rate: 3e-5 24 | scheduler: linear 25 | -------------------------------------------------------------------------------- /src/data/config/script/preprocess.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # very essential to put the package directive so that the following config parameters are situated at the root layer 3 | 4 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 5 | defaults: 6 | - _default 7 | # add group package so the default list can be overriden from cli by name 8 | - /base@_group_: MSMARCO-passage 9 | - _self_ 10 | 11 | do_text: true 12 | do_query: true 13 | query_set: [train, dev] 14 | pretokenize: true 15 | tokenize_thread: 32 16 | -------------------------------------------------------------------------------- /src/data/config/rankt5.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: ranker 6 | - mode: train 7 | # add _self_ here so that the following arguments can be rewritten 8 | - _self_ 9 | 10 | base: 11 | plm: t5 12 | 13 | model: 14 | model_type: rankt5 15 | ranking_token: 32089 # 16 | 17 | query_prefix: "Query:" 18 | text_prefix: "Text:" 19 | 
20 | train: 21 | batch_size: 32 22 | learning_rate: 1e-4 23 | nneg: 7 24 | -------------------------------------------------------------------------------- /src/data/config/unicoil.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: sparse 6 | - index: invvec 7 | - mode: train 8 | # add _self_ here so that the following arguments can be rewritten 9 | - _self_ 10 | 11 | base: 12 | plm: bert 13 | 14 | model: 15 | model_type: unicoil 16 | return_special_mask: true 17 | return_first_mask: true 18 | 19 | train: 20 | nneg: 7 21 | 22 | eval: 23 | eval_posting_length: true 24 | eval_flops: true 25 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # very essential that we import faiss before torch on zhiyuan machine 2 | import faiss 3 | import torch 4 | 5 | import logging 6 | logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(levelname)s (%(name)s) %(message)s") 7 | logging.getLogger("faiss.loader").setLevel(logging.ERROR) 8 | logging.getLogger("torch.distributed.distributed_c10d").setLevel(logging.WARNING) 9 | 10 | import transformers 11 | # prevent warnings from transformers 12 | transformers.logging.set_verbosity_error() 13 | 14 | import os 15 | os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' 16 | -------------------------------------------------------------------------------- /src/data/config/mode/_eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_.eval 2 | 3 | # evaluation metrics, separated by colon 4 | eval_metric: [mrr, recall] 5 | # the cutoff for each evaluation metric 6 | eval_metric_cutoff: [1,5,10,100,1000] 7 | # the cutoff for retrieval result 8 | hits: 1000 9 | # evaluate flops? 10 | eval_flops: false 11 | # evaluate posting length in inverted indexes?
12 | eval_posting_length: false 13 | 14 | # the post verifier 15 | verifier_type: none 16 | # the source of verifier 17 | verifier_src: none 18 | # the (pq) index used 19 | verifier_index: none 20 | # the final hits 21 | verifier_hits: 1000 22 | -------------------------------------------------------------------------------- /src/scripts/download.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from utils.util import Config 3 | 4 | import hydra 5 | from pathlib import Path 6 | from omegaconf import OmegaConf 7 | @hydra.main(version_base=None, config_path="../data/config/", config_name=f"script/{Path(__file__).stem}") 8 | def get_config(hydra_config: OmegaConf): 9 | config._from_hydra(hydra_config) 10 | 11 | 12 | if __name__ == "__main__": 13 | # manually action="store_true" because hydra doesn't support it 14 | for i, arg in enumerate(sys.argv): 15 | if "=" not in arg: 16 | sys.argv[i] += "=true" 17 | 18 | config = Config() 19 | get_config() 20 | -------------------------------------------------------------------------------- /src/data/config/sparta.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: sparse 6 | - index: invvec 7 | - mode: train 8 | # add _self_ here so that the following arguments can be rewritten 9 | - _self_ 10 | 11 | model: 12 | model_type: sparta 13 | text_decode_k: 200 14 | 15 | index: 16 | return_first_mask: false 17 | load_index: false 18 | save_index: false 19 | 20 | train: 21 | nneg: 7 22 | learning_rate: 3e-5 23 | 24 | eval: 25 | eval_posting_length: true 26 | eval_flops: true 27 | -------------------------------------------------------------------------------- /src/data/config/retromae.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: dense 6 | - index: faiss 7 | - mode: eval 8 | # add _self_ here so that the following arguments can be rewritten 9 | - _self_ 10 | 11 | base: 12 | plm: retromae_distill 13 | text_col: [1, 2] 14 | text_length: 140 15 | 16 | model: 17 | model_type: dpr 18 | 19 | index: 20 | index_type: Flat 21 | 22 | train: 23 | nneg: 15 24 | learning_rate: 2e-5 25 | scheduler: linear 26 | 27 | eval: 28 | eval_posting_length: true -------------------------------------------------------------------------------- /src/data/config/model/_default.yaml: -------------------------------------------------------------------------------- 1 | # model class 2 | model_type: null 3 | 4 | # the checkpoint to load 5 | load_ckpt: null 6 | # the checkpoint path to save 7 | save_ckpt: best 8 | 9 | # load the encoded cache 10 | load_encode: false 11 | # save the encoded result 12 | save_encode: false 13 | load_text_encode: false 14 | load_query_encode: false 15 | 16 | # load the existing retrieval result 17 | load_result: false 18 | # save the model after main function 19 | save_model: false 20 | # save the retrieval result together with the score of each retrieved document 21 | save_score: false 22 | # file name for the retrieval results 23 | save_res: retrieval_result 24 | 25 | -------------------------------------------------------------------------------- /src/data/config/ivf.yaml: 
-------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: sparse 6 | - index: invhit 7 | - mode: eval 8 | # add _self_ here so that the following arguments can be rewritten 9 | - _self_ 10 | 11 | model: 12 | model_type: ivf 13 | 14 | query_gate_k: 20 15 | 16 | vq_src: RetroMAE 17 | vq_index: IVF10000,PQ64x8 18 | embedding_src: RetroMAE 19 | return_embedding: true 20 | 21 | load_ckpt: none 22 | 23 | eval: 24 | hits: 0 25 | eval_posting_length: true 26 | 27 | verifier_type: pq 28 | verifier_src: DistillVQ_d-RetroMAE 29 | verifier_index: OPQ96,PQ96x8 30 | -------------------------------------------------------------------------------- /src/data/config/base/_default.yaml: -------------------------------------------------------------------------------- 1 | # the root directory of the raw data 2 | data_root: /data/TSGen 3 | plm_root: /data/TSGen/PLMs 4 | seed: 42 5 | 6 | # the device to run the model or script 7 | device: 0 8 | 9 | text_type: default 10 | data_format: memmap 11 | num_worker: 2 12 | 13 | # the batch size fed to the loader_eval 14 | eval_batch_size: 100 15 | # the dataset to evaluate the model or run commands 16 | eval_set: dev 17 | # the mode to evaluate the model: retrieve or rerank 18 | eval_mode: retrieve 19 | # use the debug mode (will train 2 steps and encode 10 steps) 20 | debug: false 21 | 22 | # when using distributed training/evaluating, we can choose to split text 23 | # or query across processes 24 | parallel: text 25 | -------------------------------------------------------------------------------- /src/data/config/genre.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: NQ320k 5 | - model: generative 6 | - index: trie 7 | - mode: train 8 | - extra: code 9 | # add _self_ here so that the following arguments can be rewritten 10 | - _self_ 11 | 12 | base: 13 | plm: t5 14 | parallel: query 15 | 16 | model: 17 | model_type: genre 18 | nbeam: 10 19 | 20 | train: 21 | bf16: true 22 | # only query-pos pair 23 | neg_type: none 24 | 25 | epoch: 80 26 | eval_delay: 10e 27 | 28 | learning_rate: 1e-3 29 | scheduler: linear 30 | batch_size: 400 31 | main_metric: MRR@10 32 | 33 | 34 | code: 35 | code_type: title 36 | code_length: 26 37 | 38 | -------------------------------------------------------------------------------- /src/data/config/dsiqg.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: NQ320k 5 | - model: generative 6 | - index: trie 7 | - mode: train 8 | - extra: code 9 | # add _self_ here so that the following arguments can be rewritten 10 | - _self_ 11 | 12 | base: 13 | plm: t5 14 | parallel: query 15 | 16 | model: 17 | model_type: dsiqg 18 | nbeam: 10 19 | 20 | train: 21 | bf16: true 22 | # only query-pos pair 23 | neg_type: none 24 | epoch: 80 25 | 26 | learning_rate: 1e-3 27 | scheduler: linear 28 | batch_size: 400 29 | eval_delay: 20e 30 | early_stop_patience: 10 31 | 32 | main_metric: MRR@10 33 | 34 | train_set: [train, doct5] 35 | 36 | code: 37 | code_type: id 38 | code_length: 8 39 | 
-------------------------------------------------------------------------------- /src/data/config/script/doct5.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # very essential to put the package directive so that the following config parameters are situated at the root layer 3 | 4 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 5 | defaults: 6 | - _default 7 | # add group package so the default list can be overriden from cli by name 8 | - /base@_group_: NQ320k 9 | - _self_ 10 | 11 | base: 12 | plm: doct5 13 | eval_batch_size: 50 14 | 15 | # sometimes we want to tokenize the generated queries with another plm and save the results 16 | dest_plm: t5 17 | # how many queries to generate for each document 18 | query_per_doc: 10 19 | # load previously stored memmap file? 20 | load_encode: false 21 | # how many threads to use? 22 | tokenize_thread: 32 23 | -------------------------------------------------------------------------------- /src/data/config/seal.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: NQ320k 5 | - model: generative 6 | - index: fm 7 | - mode: eval 8 | - extra: code 9 | # add _self_ here so that the following arguments can be rewritten 10 | - _self_ 11 | 12 | base: 13 | plm: bart 14 | parallel: query 15 | eval_batch_size: 1 16 | 17 | model: 18 | model_type: seal 19 | nbeam: 10 20 | 21 | train: 22 | # only query-pos pair 23 | neg_type: none 24 | epoch: 80 25 | max_grad_norm: 0.1 26 | 27 | learning_rate: 1e-3 28 | scheduler: linear 29 | batch_size: 400 30 | eval_delay: 40e 31 | 32 | main_metric: MRR@10 33 | 34 | code: 35 | code_type: seal 36 | code_length: 10 37 | return_code: false 38 | # return_query_code: true 39 | -------------------------------------------------------------------------------- /src/data/config/spladev2.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: sparse 6 | - index: impact 7 | - mode: train 8 | # add _self_ here so that the following arguments can be rewritten 9 | - _self_ 10 | 11 | model: 12 | model_type: spladev2 13 | save_encode: true 14 | 15 | text_decode_k: 128 16 | query_decode_k: 128 17 | text_lambda: 1e-2 18 | query_lambda: 1e-2 19 | 20 | 21 | index: 22 | return_first_mask: false 23 | load_index: false 24 | save_index: false 25 | 26 | train: 27 | 28 | learning_rate: 2e-5 29 | scheduler: linear 30 | batch_size: 64 31 | nneg: 7 32 | lambda_warmup_step: 0 33 | eval_step: 5e 34 | 35 | eval: 36 | eval_posting_length: true 37 | eval_flops: true 38 | -------------------------------------------------------------------------------- /src/data/config/tsgen.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: NQ320k 5 | - model: generative 6 | - index: wordset 7 | - mode: train 8 | - extra: code 9 | # add _self_ here so that the following arguments can be rewritten 10 | - _self_ 11 | 12 | base: 13 | plm: t5 14 | parallel: query 15 | 16 | model: 17 | model_type: 
tsgen 18 | nbeam: 10 19 | 20 | train: 21 | # only query-pos pair 22 | neg_type: none 23 | epoch: 50 24 | batch_size: 512 25 | bf16: true 26 | 27 | learning_rate: 2e-3 28 | scheduler: linear 29 | eval_delay: 20e 30 | early_stop_patience: 5 31 | main_metric: MRR@10 32 | 33 | code: 34 | code_type: term 35 | code_length: 26 36 | code_sep: "," 37 | reduce_code: min 38 | permute_code: 0 39 | -------------------------------------------------------------------------------- /src/data/config/dsi.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: NQ320k 5 | - model: generative 6 | - index: trie 7 | - mode: train 8 | - extra: code 9 | # add _self_ here so that the following arguments can be rewritten 10 | - _self_ 11 | 12 | base: 13 | plm: t5 14 | parallel: query 15 | 16 | model: 17 | model_type: dsi 18 | nbeam: 10 19 | 20 | train: 21 | train_set: [train, doc] 22 | # only query-pos pair 23 | neg_type: none 24 | epoch: 80 25 | bf16: true 26 | 27 | learning_rate: 1e-3 28 | scheduler: linear 29 | batch_size: 400 30 | eval_delay: 40e 31 | early_stop_patience: 0 32 | 33 | main_metric: MRR@10 34 | 35 | code: 36 | code_type: ANCE_hier 37 | code_length: 10 38 | -------------------------------------------------------------------------------- /src/data/config/bow.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: NQ 5 | - model: generative 6 | - index: wordset 7 | - mode: train 8 | - extra: code 9 | # add _self_ here so that the following arguments can be rewritten 10 | - _self_ 11 | 12 | base: 13 | plm: t5 14 | parallel: query 15 | 16 | model: 17 | model_type: bow 18 | nbeam: 10 19 | 20 | train: 21 | # only query-pos pair 22 | neg_type: none 23 | epoch: 50 24 | batch_size: 400 25 | bf16: true 26 | 27 | learning_rate: 1e-3 28 | scheduler: linear 29 | eval_delay: 20e 30 | early_stop_patience: 5 31 | main_metric: MRR@10 32 | 33 | code: 34 | code_type: words_comma_plus_stem 35 | code_length: 26 36 | code_sep: "," 37 | reduce_code: min 38 | permute_code: 0 39 | -------------------------------------------------------------------------------- /src/data/config/deepspeed.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "zero_optimization": { 14 | "stage": 2, 15 | "allgather_partitions": true, 16 | "allgather_bucket_size": 5e8, 17 | "overlap_comm": true, 18 | "reduce_scatter": true, 19 | "reduce_bucket_size": 5e8, 20 | "contiguous_gradients": true 21 | }, 22 | "gradient_accumulation_steps": "auto", 23 | "gradient_clipping": "auto", 24 | "steps_per_print": 2000, 25 | "train_batch_size": "auto", 26 | "train_micro_batch_size_per_gpu": "auto", 27 | "wall_clock_breakdown": false 28 | } 29 | -------------------------------------------------------------------------------- /src/data/config/tokivf.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 
| - _default 4 | - base: MSMARCO-passage 5 | - model: sparse 6 | - index: invhit 7 | - mode: train 8 | # add _self_ here so that the following arguments can be rewritten 9 | - _self_ 10 | 11 | base: 12 | plm: retromae_msmarco 13 | text_col: [1, 2] 14 | 15 | model: 16 | model_type: tokivf 17 | return_special_mask: true 18 | # how many token postings to scan 19 | text_gate_k: 3 20 | 21 | index: 22 | # what percentile of the inverted lists are kept 23 | posting_prune: 0.996 24 | 25 | train: 26 | nneg: 7 27 | enable_distill: bi 28 | distill_src: RetroMAE 29 | 30 | eval: 31 | eval_posting_length: true 32 | verifier_type: pq 33 | verifier_src: DistillVQ_d-RetroMAE 34 | verifier_index: OPQ96,PQ96x8 35 | -------------------------------------------------------------------------------- /src/data/config/topivf.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: sparse 6 | - index: invhit 7 | - mode: train 8 | # add _self_ here so that the following arguments can be rewritten 9 | - _self_ 10 | 11 | base: 12 | eval_batch_size: 500 13 | 14 | model: 15 | model_type: topivf 16 | 17 | query_gate_k: 20 18 | 19 | vq_src: RetroMAE 20 | vq_index: IVF10000,PQ64x8 21 | 22 | embedding_src: RetroMAE 23 | return_embedding: true 24 | 25 | load_ckpt: none 26 | 27 | enable_commit_loss: true 28 | 29 | train: 30 | epoch: 50 31 | 32 | learning_rate: 1e-4 33 | scheduler: linear 34 | 35 | eval: 36 | hits: 0 37 | eval_posting_length: true 38 | verifier_type: pq 39 | verifier_src: DistillVQ_d-RetroMAE 40 | verifier_index: OPQ96,PQ96x8 41 | -------------------------------------------------------------------------------- /src/data/config/bivfpq-nq.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: NQ-open 5 | - model: _default 6 | - mode: eval 7 | 8 | model: 9 | model_type: uniretriever 10 | 11 | return_embedding: true 12 | embedding_src: AR2 13 | 14 | x_model: TokIVF 15 | x_index_type: invhit 16 | x_hits: 0 17 | x_load_encode: true 18 | x_text_gate_k: 3 19 | x_load_ckpt: best 20 | x_posting_prune: 0.996 21 | 22 | y_model: TopIVF 23 | y_index_type: invhit 24 | y_hits: 0 25 | y_load_encode: false 26 | y_query_gate_k: 20 27 | y_load_ckpt: best 28 | 29 | x_eval_flops: false 30 | y_eval_flops: false 31 | x_eval_posting_length: true 32 | y_eval_posting_length: true 33 | 34 | eval: 35 | verifier_type: pq 36 | verifier_src: DistillVQ_d-AR2 37 | verifier_index: OPQ96,PQ96x8 38 | -------------------------------------------------------------------------------- /src/data/config/uniretriever.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: _default 6 | - mode: eval 7 | # add _self_ here so that the following arguments can be rewritten 8 | - _self_ 9 | 10 | mode: 11 | verifier_type: flat 12 | verifier_src: AR2 13 | # verifier_index: OPQ96,PQ96x8 14 | 15 | eval: 16 | model_type: uniretriever 17 | load_index: true 18 | 19 | x_model: BM25 20 | x_index_type: bm25 21 | x_hits: 1000 22 | 23 | y_model: AR2 24 | y_index_type: IVF10000,PQ64x8 25 
| y_hits: 1000 26 | 27 | x_load_encode: true 28 | y_load_encode: true 29 | x_load_index: true 30 | y_load_index: true 31 | 32 | x_load_ckpt: best 33 | y_load_ckpt: best 34 | 35 | x_verifier_type: none 36 | y_verifier_type: none 37 | 38 | x_eval_posting_length: true 39 | y_eval_posting_length: true 40 | -------------------------------------------------------------------------------- /src/data/config/bivfpq.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: _default 6 | - mode: eval 7 | 8 | base: 9 | text_col: [1, 2] 10 | 11 | model: 12 | return_embedding: true 13 | embedding_src: RetroMAE 14 | 15 | model_type: uniretriever 16 | 17 | x_model: TokIVF 18 | x_index_type: invhit 19 | x_hits: 0 20 | x_load_encode: true 21 | x_text_gate_k: 3 22 | x_load_ckpt: best 23 | x_posting_prune: 0.996 24 | 25 | y_model: TopIVF 26 | y_index_type: invhit 27 | y_hits: 0 28 | y_load_encode: false 29 | y_query_gate_k: 20 30 | y_load_ckpt: best 31 | 32 | x_eval_flops: false 33 | y_eval_flops: false 34 | x_eval_posting_length: true 35 | y_eval_posting_length: true 36 | 37 | eval: 38 | verifier_type: pq 39 | verifier_src: DistillVQ_d-RetroMAE 40 | verifier_index: OPQ96,PQ96x8 41 | -------------------------------------------------------------------------------- /src/data/config/sequer.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: NQ 5 | - model: ranker 6 | - index: trie 7 | - mode: train 8 | - extra: code 9 | # add _self_ here so that the following arguments can be rewritten 10 | - _self_ 11 | 12 | base: 13 | plm: t5 14 | parallel: query 15 | 16 | model: 17 | model_type: sequer 18 | train_scheme: contra 19 | # how to rank when evaluating 20 | rank_type: eos 21 | # beam size 22 | nbeam: 10 23 | 24 | index: 25 | # threshold for relaxed beam search 26 | beam_trsd: 0 27 | trsd_start_len: 3 28 | 29 | train: 30 | epoch: 50 31 | batch_size: 64 32 | learning_rate: 3e-5 33 | neg_type: BM25 34 | nneg: 23 35 | main_metric: MRR@10 36 | return_prefix_mask: true 37 | bf16: true 38 | 39 | eval: 40 | eval_mode: rerank 41 | cand_type: BM25 42 | ncand: 100 43 | 44 | code: 45 | code_type: words_comma_plus_stem 46 | code_length: 26 47 | code_sep: "," 48 | -------------------------------------------------------------------------------- /src/data/config/distillvq.yaml: -------------------------------------------------------------------------------- 1 | # load the default lists, whose parameters can be changed by referencing its namespace in the following 2 | defaults: 3 | - _default 4 | - base: MSMARCO-passage 5 | - model: dense 6 | - index: faiss 7 | - mode: train 8 | # add _self_ here so that the following arguments can be rewritten 9 | - _self_ 10 | 11 | model: 12 | model_type: distillvq 13 | 14 | return_embedding: true 15 | embedding_src: RetroMAE 16 | 17 | # dynamically update ivf assignments 18 | train_ivf_assign: false 19 | # dynamically update pq assignments 20 | train_pq_assign: false 21 | # train query encoder together with the index 22 | train_encoder: false 23 | # freeze pq centroids, only update IVF centroids 24 | freeze_pq: false 25 | 26 | index: 27 | index_type: OPQ96,PQ96x8 28 | 29 | train: 30 | distill_src: RetroMAE 31 | 
enable_distill: bi 32 | 33 | epoch: 50 34 | batch_size: 128 35 | nneg: 31 36 | 37 | learning_rate: 1e-5 38 | learning_rate_pq: 1e-4 39 | learning_rate_ivf: 1e-4 40 | scheduler: linear 41 | 42 | eval: 43 | eval_posting_length: true 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /src/scripts/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from utils.util import compute_metrics, load_pickle, Config 4 | 5 | import hydra 6 | from pathlib import Path 7 | from omegaconf import OmegaConf 8 | @hydra.main(version_base=None, config_path="../data/config/", config_name=f"script/{Path(__file__).stem}") 9 | def get_config(hydra_config: OmegaConf): 10 | config._from_hydra(hydra_config) 11 | 12 | 13 | if __name__ == "__main__": 14 | # manually action="store_true" because hydra doesn't support it 15 | for i, arg in enumerate(sys.argv): 16 | if "=" not in arg: 17 | sys.argv[i] += "=true" 18 | 19 | config = Config() 20 | get_config() 21 | 22 | if os.path.exists(config.src): 23 | path = config.src 24 | elif os.path.exists(os.path.join(config.cache_root, config.eval_mode, config.src, config.eval_set, "retrieval_result.pkl")): 25 | path = os.path.join(config.cache_root, config.eval_mode, config.src, config.eval_set, "retrieval_result.pkl") 26 | else: 27 | raise FileNotFoundError 28 | 29 | retrieval_result = load_pickle(path) 30 | 31 | ground_truth = load_pickle(os.path.join(config.cache_root, "dataset", "query", config.eval_set, "positives.pkl")) 32 | metrics = compute_metrics(retrieval_result, ground_truth, metrics=config.eval_metric, cutoffs=config.eval_metric_cutoff) 33 | print() 34 | print(metrics) -------------------------------------------------------------------------------- /src/data/config/mode/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_.train 2 | defaults: 3 | - _eval # load the configs 4 | 5 | mode: train 6 | 7 | # default to use negative 8 | loader_train: neg 9 | # query set for training 10 | train_set: [train] 11 | 12 | epoch: 20 13 | # the total batch size 14 | batch_size: 128 15 | # mixed precision 16 | 
fp16: false 17 | bf16: false 18 | # gradient accumulation 19 | grad_accum_step: 1 20 | # Stop training when the evaluation results is inferior to the best one for ? times. 21 | early_stop_patience: 5 22 | # clip grad 23 | max_grad_norm: 0 24 | # maximum steps for training 25 | max_step: 0 26 | # wandb 27 | report_to: none 28 | # deepspeed configuration file path 29 | deepspeed: null 30 | 31 | learning_rate: 3e-6 32 | adam_beta1: 0.9 33 | adam_beta2: 0.999 34 | adam_epsilon: 1e-8 35 | weight_decay: 0.01 36 | scheduler: constant 37 | warmup_ratio: 0.1 38 | warmup_step: 0 39 | 40 | main_metric: Recall@10 41 | # interval of testing the model performance 42 | eval_step: 1e 43 | # donot test the model performance before eval_delay steps 44 | eval_delay: 0 45 | # if true, save the model after validation 46 | # otherwise, only store the ever-best performance model 47 | save_at_eval: false 48 | 49 | 50 | # how many hard negatives to use? 51 | nneg: 1 52 | # what kind of hard negatives to use? 53 | neg_type: BM25 54 | # use inbatch negative? 55 | enable_inbatch_negative: true 56 | # gather all the embeddings across processes in distributed training? 57 | enable_all_gather: true 58 | # distillation 59 | enable_distill: false 60 | # distill from which model? 61 | distill_src: none 62 | -------------------------------------------------------------------------------- /src/scripts/select_phrases.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from tqdm import tqdm 4 | from sentence_transformers import SentenceTransformer 5 | from keybert import KeyBERT 6 | from accelerate import Accelerator 7 | 8 | 9 | if __name__ == "__main__": 10 | accelerator = Accelerator(cpu=True) 11 | 12 | model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder="/share/LMs/", device="cpu") 13 | model = accelerator.prepare_model(model, device_placement=True, evaluation_mode=True).eval() 14 | 15 | kw_model = KeyBERT(model) 16 | 17 | ndoc = 109739 18 | ndoc_per_node = ndoc / accelerator.num_processes 19 | start_idx = round(ndoc_per_node * accelerator.process_index) 20 | end_idx = round(ndoc_per_node * (accelerator.process_index + 1)) 21 | 22 | with open("/share/peitian/share/Datasets/Adon/NQ320k/collection.tsv") as f, open(f"/share/peitian/share/Datasets/Adon/NQ320k/phrases/2grams.{accelerator.process_index}.json", "w") as g: 23 | for i, line in enumerate(tqdm(f, total=109739)): 24 | if i < start_idx: 25 | continue 26 | if i >= end_idx: 27 | break 28 | 29 | text = " ".join(line.split("\t")[1:]).strip() 30 | text = " ".join(text.split()[:1024]) 31 | 32 | with torch.no_grad(): 33 | phrases = kw_model.extract_keywords(text, keyphrase_ngram_range=(2, 2), stop_words="english", top_n=200, use_mmr=True, diversity=0.5) 34 | # phrases = kw_model.extract_keywords(text, keyphrase_ngram_range=(2, 2), stop_words="english", top_n=200) 35 | phrases = [x[0] for x in phrases] 36 | g.write(json.dumps(phrases, ensure_ascii=False) + "\n") 37 | -------------------------------------------------------------------------------- /src/models/RankT5.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from transformers import T5ForConditionalGeneration 4 | from .BaseModel import BaseModel 5 | 6 | 7 | 8 | class RankT5(BaseModel): 9 | def __init__(self, config): 10 | super().__init__(config) 11 | self.plm = T5ForConditionalGeneration.from_pretrained(config.plm_dir) 12 | 13 | 14 
| def _compute_score(self, **kwargs): 15 | """ concate the query and the input text; 16 | Args: 17 | query_token_id: B, LQ 18 | text_token_id: B, LS 19 | Returns: 20 | tensor of [B] 21 | """ 22 | for k, v in kwargs.items(): 23 | # B, 1+N, L -> B * (1+N), L 24 | if v.dim() == 3: 25 | kwargs[k] = v.view(-1, v.shape[-1]) 26 | 27 | batch_size = kwargs["input_ids"].shape[0] 28 | score = self.plm(**kwargs, decoder_input_ids=torch.zeros((batch_size, 1), dtype=torch.long, device=self.config.device)).logits[:, 0, self.config.ranking_token] 29 | return score 30 | 31 | 32 | def rerank_step(self, x): 33 | x = self._move_to_device(x) 34 | pair = x["pair"] 35 | score = self._compute_score(**pair) 36 | return score 37 | 38 | 39 | def forward(self, x): 40 | pair = x["pair"] 41 | score = self.rerank_step(x) # B*(1+N) 42 | 43 | if pair["input_ids"].dim() == 3: 44 | # use cross entropy loss 45 | score = score.view(x["pair"]["input_ids"].shape[0], -1) 46 | label = torch.zeros(score.shape[0], dtype=torch.long, device=self.config.device) 47 | loss = F.cross_entropy(score, label) 48 | 49 | elif pair["input_ids"].dim() == 2: 50 | label = x["label"] 51 | loss = F.binary_cross_entropy(torch.sigmoid(score), label) 52 | 53 | return loss 54 | 55 | 56 | -------------------------------------------------------------------------------- /src/scripts/negative.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate negatives from the ``retrieval_result`` returned by :func:`models.BaseModel.BaseModel.retrieve` over ``train`` set. 3 | """ 4 | import sys 5 | import numpy as np 6 | from tqdm import tqdm 7 | from collections import defaultdict 8 | from utils.util import load_pickle, save_pickle, Config 9 | 10 | import hydra 11 | from pathlib import Path 12 | from omegaconf import OmegaConf 13 | @hydra.main(version_base=None, config_path="../data/config/", config_name=f"script/{Path(__file__).stem}") 14 | def get_config(hydra_config: OmegaConf): 15 | config._from_hydra(hydra_config) 16 | 17 | 18 | if __name__ == "__main__": 19 | # manually action="store_true" because hydra doesn't support it 20 | for i, arg in enumerate(sys.argv): 21 | if "=" not in arg: 22 | sys.argv[i] += "=true" 23 | 24 | config = Config() 25 | get_config() 26 | 27 | for query_set in config.query_set: 28 | positives = load_pickle(f"{config.cache_root}/dataset/query/{query_set}/positives.pkl") 29 | 30 | retrieval_result = load_pickle(f"{config.cache_root}/retrieve/{config.neg_type}/{query_set}/retrieval_result.pkl") 31 | hard_negatives = defaultdict(list) 32 | for k,v in tqdm(retrieval_result.items(), desc="Collecting Negatives", ncols=100): 33 | for i, x in enumerate(v[:config.hits]): 34 | if x in positives[k]: 35 | continue 36 | hard_negatives[k].append(x) 37 | 38 | nnegs = np.array([len(x) for x in hard_negatives.values()]) 39 | print(f"the collected query number is {len(hard_negatives)}, whose negative number is MEAN: {np.round(nnegs.mean(), 1)}, MAX: {nnegs.max()}, MIN: {nnegs.min()}") 40 | 41 | if config.save_name != "default": 42 | save_name = config.save_name 43 | else: 44 | save_name = config.neg_type 45 | save_path = f"{config.cache_root}/dataset/query/{query_set}/negatives_{save_name}.pkl" 46 | save_pickle(dict(hard_negatives), save_path) 47 | print(f"saved negatives at {save_path}") -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generative Retrieval via Term Set 
Generation (TSGen) 2 | 3 | This repository contains the implementation of the SIGIR'24 paper Generative Retrieval via Term Set Generation (TSGen). 4 | 5 | ## Quick Maps 6 | - For training configurations of TSGen, check [train.yaml](src/data/config/mode/train.yaml) and [tsgen.yaml](src/data/config/tsgen.yaml) 7 | - For the implementation of our proposed *permutation-invariant decoding* algorithm, check [index.py](src/utils/index.py), from line 1459 to line 2465 (we modify the `.generate()` method in huggingface transformers and implement it with a new class named `BeamDecoder`). 8 | 9 | ## Reproduction 10 | - Environment 11 | - `python==3.9.12` 12 | - `torch==1.10.1` 13 | - `transformers==4.21.3` 14 | - `faiss==1.7.2` 15 | 16 | - Data 17 | ```bash 18 | # suppose you want to save the dataset at /data/TSGen 19 | # NOTE: if you prefer another location, remember to set 'data_root' and 'plm_root' in src/data/config/base/_default.yaml accordingly 20 | mkdir /data/TSGen 21 | cd /data/TSGen 22 | # download NQ320k dataset 23 | wget https://huggingface.co/datasets/namespace-Pt/adon/resolve/main/NQ320k.tar.gz?download=true -O NQ320k.tar.gz 24 | # untar the file, which results in the folder /data/TSGen/NQ320k 25 | tar -xzvf NQ320k.tar.gz 26 | 27 | # move to the code folder 28 | cd TSGen/src 29 | # preprocess the dataset, which results in the folder TSGen/src/data/cache/NQ320k/dataset 30 | python -m scripts.preprocess base=NQ320k ++query_set=test 31 | 32 | # move to the cache folder 33 | cd TSGen/src/data/cache/NQ320k 34 | # download the checkpoint and the term-set docids 35 | wget https://huggingface.co/datasets/namespace-Pt/adon/resolve/main/tsgen.tar.gz?download=true -O TSGen.tar.gz 36 | # untar the file, which results in the folders TSGen/src/data/cache/NQ320k/ckpts and TSGen/src/data/cache/NQ320k/codes 37 | tar -xzvf TSGen.tar.gz 38 | ``` 39 | 40 | - Evaluation 41 | ```bash 42 | # evaluate with 100 beams 43 | torchrun --nproc_per_node 8 run.py TSGen base=NQ320k mode=eval ++nbeam=100 ++eval_batch_size=20 44 | ``` 45 | The results should be similar to: 46 | |MRR@10|MRR@100|Recall@1|Recall@10|Recall@100| |:-:|:-:|:-:|:-:|:-:| 47 | |0.771|0.774|0.708|0.889|0.948| 48 | -------------------------------------------------------------------------------- /src/models/AutoModel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from .BaseModel import BaseModel 4 | 5 | from .AR2 import AR2 6 | from .TSGen import TSGen 7 | from .BM25 import BM25 8 | from .COIL import COIL 9 | from .ColBERT import ColBERT 10 | from .DSI import DSI, GENRE, DSIQG 11 | from .DeepImpact import DeepImpact 12 | from .DPR import DPR, Contriever, GTR 13 | from .IVF import IVF, TopIVF, TokIVF 14 | from .SPARTA import SPARTA 15 | from .SPLADE import SPLADEv2 16 | from .SEAL import SEAL 17 | from .RankT5 import RankT5 18 | from .CrossEnc import CrossEncoder 19 | from .UniCOIL import UniCOIL 20 | from .UniRetriever import UniRetriever 21 | from .VQ import DistillVQ 22 | 23 | MODEL_MAP = { 24 | "ar2": AR2, 25 | "bm25": BM25, 26 | "tsgen": TSGen, 27 | "coil": COIL, 28 | "colbert": ColBERT, 29 | "crossenc": CrossEncoder, 30 | "contriever": Contriever, 31 | "deepimpact": DeepImpact, 32 | "distillvq": DistillVQ, 33 | "dpr": DPR, 34 | "dsi": DSI, 35 | "dsiqg": DSIQG, 36 | "genre": GENRE, 37 | "gtr": GTR, 38 | "ivf": IVF, 39 | "rankt5": RankT5, 40 | "sparta": SPARTA, 41 | "spladev2": SPLADEv2, 42 | "seal": SEAL, 43 | "topivf": TopIVF, 44 | "tokivf": TokIVF, 45 | "unicoil": UniCOIL, 46 | "uniretriever":
UniRetriever 47 | } 48 | 49 | 50 | class AutoModel(BaseModel): 51 | @classmethod 52 | def from_pretrained(cls, ckpt_path, **kwargs): 53 | state_dict = torch.load(ckpt_path, map_location="cpu") 54 | 55 | config = state_dict["config"] 56 | model_name_current = os.path.abspath(ckpt_path).split(os.sep)[-2] 57 | model_name_ckpt = config.name 58 | model_type = model_name_current.split("_")[0].lower() 59 | 60 | # override model name 61 | config.update(**kwargs, name=model_name_current) 62 | # re-initialize the config so the distributed information is properly set 63 | config.__post_init__() 64 | 65 | try: 66 | model = MODEL_MAP[model_type](config).to(config.device) 67 | except KeyError: 68 | raise NotImplementedError(f"Model {model_type} not implemented!") 69 | if model_name_ckpt != model_name_current: 70 | model.logger.warning(f"model name in the checkpoint is {model_name_ckpt}, while it's {model_name_current} now!") 71 | 72 | model.logger.info(f"loading model from {ckpt_path} with checkpoint config...") 73 | model.load_state_dict(state_dict["model"]) 74 | model.metrics = state_dict["metrics"] 75 | 76 | model.eval() 77 | return model 78 | -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import hydra 3 | from omegaconf import OmegaConf 4 | from utils.util import Config 5 | from utils.data import prepare_data 6 | from models.AutoModel import MODEL_MAP 7 | 8 | name = None 9 | 10 | @hydra.main(version_base=None, config_path="data/config/") 11 | def get_config(hydra_config: OmegaConf): 12 | config._from_hydra(hydra_config) 13 | config.name = name 14 | 15 | 16 | def main(config:Config): 17 | """ train/dev/test the model (in distributed) 18 | 19 | Args: 20 | rank: current process id 21 | world_size: total gpus 22 | """ 23 | loaders = prepare_data(config) 24 | model = MODEL_MAP[config.model_type](config).to(config.device) 25 | 26 | if config.mode == "train": 27 | from utils.trainer import train 28 | model.load() 29 | train(model, loaders) 30 | 31 | elif config.mode == "eval": 32 | model.load() 33 | model.evaluate(loaders) 34 | 35 | elif config.mode == "encode": 36 | model.load() 37 | model.encode(loaders) 38 | 39 | elif config.mode == "cluster": 40 | model.load() 41 | model.cluster(loaders) 42 | 43 | elif config.mode == "code": 44 | model.load() 45 | model.generate_code(loaders) 46 | 47 | elif config.mode == "migrate": 48 | from utils.util import load_from_previous 49 | if config.is_main_proc: 50 | path = f"{config.cache_root}/ckpts/{model.name}/{config.load_ckpt}" 51 | load_from_previous(model, path) 52 | model.save() 53 | 54 | elif config.mode == "deploy": 55 | model.load() 56 | model.deploy() 57 | 58 | elif config.mode == "index": 59 | model.load() 60 | model.index(loaders) 61 | 62 | else: 63 | raise ValueError(f"Invalid mode {config.mode}!") 64 | 65 | if config.save_model: 66 | model.save() 67 | 68 | 69 | 70 | if __name__ == "__main__": 71 | # get the model full name 72 | name = sys.argv.pop(1) 73 | # parse the config_name, which is the first part in the list split by _ 74 | config_name = name.split("_")[0].lower() 75 | # add the parsed config_name back to the sys.argv so that hydra can use it 76 | sys.argv.insert(1, config_name) 77 | sys.argv.insert(1, "--config-name") 78 | 79 | # manually action="store_true" because hydra doesn't support it 80 | for i, arg in enumerate(sys.argv): 81 | if i > 2 and "=" not in arg: 82 | sys.argv[i] += "=true" 83 | 84 | 
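# e.g. `run.py TSGen base=NQ320k mode=eval ++nbeam=100` (the README invocation) is rewritten to
# `--config-name tsgen base=NQ320k mode=eval ++nbeam=100`, so src/data/config/tsgen.yaml is loaded;
# a longer name such as `TSGen_ablation` (hypothetical) still resolves to the `tsgen` config while the
# full name is kept in config.name (used e.g. for the ckpts/<name> directory). Bare trailing arguments
# like `debug` become `debug=true` because hydra has no store_true action.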
config = Config() 85 | get_config() 86 | 87 | main(config) 88 | -------------------------------------------------------------------------------- /src/scripts/ttest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import scipy.stats as stats 4 | from utils.util import load_pickle, compute_metrics, Config 5 | 6 | import hydra 7 | from pathlib import Path 8 | from omegaconf import OmegaConf 9 | @hydra.main(version_base=None, config_path="../data/config/", config_name=f"script/{Path(__file__).stem}") 10 | def get_config(hydra_config: OmegaConf): 11 | config._from_hydra(hydra_config) 12 | 13 | 14 | if __name__ == "__main__": 15 | # manually action="store_true" because hydra doesn't support it 16 | for i, arg in enumerate(sys.argv): 17 | if "=" not in arg: 18 | sys.argv[i] += "=true" 19 | 20 | config = Config() 21 | get_config() 22 | 23 | if os.path.exists(config.x_model): 24 | x_path = config.x_model 25 | elif os.path.exists(os.path.join(config.cache_root, config.eval_mode, config.x_model, config.eval_set, "retrieval_result.pkl")): 26 | x_path = os.path.join(config.cache_root, config.eval_mode, config.x_model, config.eval_set, "retrieval_result.pkl") 27 | else: 28 | raise FileNotFoundError 29 | 30 | if os.path.exists(config.y_model): 31 | y_path = config.y_model 32 | elif os.path.exists(os.path.join(config.cache_root, config.eval_mode, config.y_model, config.eval_set, "retrieval_result.pkl")): 33 | y_path = os.path.join(config.cache_root, config.eval_mode, config.y_model, config.eval_set, "retrieval_result.pkl") 34 | else: 35 | raise FileNotFoundError 36 | 37 | print(x_path, y_path) 38 | 39 | x_retrieval_result = load_pickle(x_path) 40 | y_retrieval_result = load_pickle(y_path) 41 | 42 | ground_truth = load_pickle(os.path.join(config.cache_root, "dataset", config.eval_set, "positives.pkl")) 43 | 44 | all_metrics = set() 45 | cutoffs = set() 46 | for metric in config.ttest_metric: 47 | if "@" in metric: 48 | metric_body, cutoff = metric.split("@") 49 | all_metrics.add(metric_body.lower()) 50 | cutoffs.add(int(cutoff)) 51 | else: 52 | all_metrics.add(metric.lower()) 53 | 54 | all_metrics = list(all_metrics) 55 | cutoffs = list(cutoffs) 56 | 57 | x_metrics_per_query = compute_metrics(x_retrieval_result, ground_truth, metrics=all_metrics, cutoffs=cutoffs, return_each_query=True) 58 | y_metrics_per_query = compute_metrics(y_retrieval_result, ground_truth, metrics=all_metrics, cutoffs=cutoffs, return_each_query=True) 59 | 60 | print("*" * 10 + f" {config.x_model} (X) v.s. 
{config.y_model} (Y) " + "*" * 10) 61 | for metric in config.ttest_metric: 62 | print(f"the p of {metric}: {' '*(20 - len(metric))}{stats.ttest_rel(x_metrics_per_query[metric], y_metrics_per_query[metric]).pvalue}") 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | anserini 2 | cache 3 | *.pyc 4 | *.log 5 | __pycache__ 6 | test*.py 7 | tmp*.py 8 | tmp*.sh 9 | discussions 10 | reviews 11 | backup 12 | ppts 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # auto-generated Sphinx api docs 23 | /docs/generated 24 | 25 | # Distribution / packaging 26 | .Python 27 | build/ 28 | develop-eggs/ 29 | dist/ 30 | downloads/ 31 | eggs/ 32 | .eggs/ 33 | lib/ 34 | lib64/ 35 | parts/ 36 | sdist/ 37 | var/ 38 | wheels/ 39 | pip-wheel-metadata/ 40 | share/python-wheels/ 41 | *.egg-info/ 42 | .installed.cfg 43 | *.egg 44 | MANIFEST 45 | 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .nox/ 60 | .coverage 61 | .coverage.* 62 | .cache 63 | nosetests.xml 64 | coverage.xml 65 | *.cover 66 | *.py,cover 67 | .hypothesis/ 68 | .pytest_cache/ 69 | 70 | # Translations 71 | *.mo 72 | *.pot 73 | 74 | # Django stuff: 75 | *.log 76 | local_settings.py 77 | db.sqlite3 78 | db.sqlite3-journal 79 | 80 | # Flask stuff: 81 | instance/ 82 | .webassets-cache 83 | 84 | # Scrapy stuff: 85 | .scrapy 86 | 87 | # Sphinx documentation 88 | docs/_build/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # pyenv 98 | .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | -------------------------------------------------------------------------------- /src/models/ColBERT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import AutoModel 4 | from .BaseModel import BaseModel 5 | 6 | 7 | 8 | class ColBERT(BaseModel): 9 | def __init__(self, config): 10 | super().__init__(config) 11 | 12 | self._set_encoder() 13 | 14 | self.tokenProject = nn.Linear(self.textEncoder.config.hidden_size, config.token_dim) 15 | 16 | 17 | def _encode_text(self, **kwargs): 18 | for k, v in kwargs.items(): 19 | # B, 1+N, L -> B * (1+N), L 20 | if v.dim() == 3: 21 | kwargs[k] = v.view(-1, v.shape[-1]) 22 | 23 | token_all_embedding = self.textEncoder(**kwargs)[0] 24 | token_embedding = self.tokenProject(token_all_embedding) 25 | return token_embedding 26 | 27 | 28 | def _encode_query(self, **kwargs): 29 | token_all_embedding = self.queryEncoder(**kwargs)[0] 30 | token_embedding = self.tokenProject(token_all_embedding) 31 | return token_embedding 32 | 33 | 34 | def forward(self, x): 35 | x = self._move_to_device(x) 36 | 37 | query_token_embedding = self._encode_query(**x["query"]) # B, LQ, D 38 | text_token_embedding = self._encode_text(**x["text"]) # B*(1+N), LS, D 39 | 40 | if self.config.is_distributed and self.config.enable_all_gather: 41 | query_token_embedding = self._gather_tensors(query_token_embedding) 42 | text_token_embedding = self._gather_tensors(text_token_embedding) 43 | 44 | query_text_score = torch.einsum('qin,tjn->qitj', query_token_embedding, text_token_embedding) 45 | query_text_score = query_text_score.max(dim=-1)[0] # B, LQ, B*(1+N) 46 | score = query_text_score.sum(dim=1) # B, B*(1+N) 47 | 48 | B = score.shape[0] 49 | if self.config.enable_inbatch_negative: 50 | label = torch.arange(B, device=self.config.device) 51 | label = label * (text_token_embedding.shape[0] // query_token_embedding.shape[0]) 52 | else: 53 | label = torch.zeros(B, dtype=torch.long, device=self.config.device) 54 | score = score.view(B, B, -1)[range(B), range(B)] # B, 1+N 55 | 56 | loss = self._compute_loss(score, label, self._compute_teacher_score(x)) 57 | return loss 58 | 59 | 60 | def rerank_step(self, x): 61 | """ 62 | given a query and a sequence, output the sequence's score 63 | """ 64 | x = self._move_to_device(x) 65 | query_token_embedding = self._encode_query(**x["query"]) # B, LQ, D 66 | text_token_embedding = self._encode_text(**x["text"]) # B, LS, D 67 | 68 | query_text_score = query_token_embedding.matmul(text_token_embedding.transpose(-1,-2)) 69 | score = query_text_score.max(dim=-1)[0].sum(dim=-1) # B 70 | return score 71 | 72 | 73 | def retrieve(self, manager, loaders): 74 | self.logger.error("currently we do not support retrieval with ColBERT, instead we evaluate it by reranking task") 75 | raise 76 | 77 | -------------------------------------------------------------------------------- 
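As a reading aid for ColBERT's late-interaction scoring above: `forward` compares every query token with every text token, keeps the best-matching text token per query token (MaxSim), and sums over query tokens. Below is a minimal sketch of the same computation on dummy tensors; the shapes, names and values are illustrative and not part of the repository:

```python
import torch

# toy shapes: B queries and B texts in the batch, LQ/LS tokens each, token dim D
B, LQ, LS, D = 2, 4, 6, 8
query_tok = torch.randn(B, LQ, D)   # stands in for _encode_query(...)
text_tok = torch.randn(B, LS, D)    # stands in for _encode_text(...)

# all-pairs token similarity between every query and every text in the batch
sim = torch.einsum("qin,tjn->qitj", query_tok, text_tok)   # B, LQ, B, LS
# MaxSim: keep the best-matching text token for each query token
sim = sim.max(dim=-1).values                               # B, LQ, B
# sum over query tokens -> query-vs-text score matrix
score = sim.sum(dim=1)                                     # B, B
print(score.shape)  # torch.Size([2, 2]); the diagonal holds the in-batch positive pairs
```

In the model itself the token embeddings come from the PLM outputs projected through `tokenProject`, and the same score matrix feeds the contrastive loss with in-batch negatives.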
/src/utils/static.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import DataLoader 4 | from typing import * 5 | 6 | DEVICE = Union[int,Literal["cpu"]] 7 | 8 | RETRIEVAL_MAPPING = Union[dict[int, list[int]], dict[int, list[tuple[int,float]]]] 9 | ID_MAPPING = dict[str, int] 10 | 11 | DENSE_METRIC = Literal["ip", "cos", "l2"] 12 | DATA_FORMAT = Literal["memmap", "raw"] 13 | 14 | TENSOR = torch.Tensor 15 | LOADERS = dict[str,DataLoader] 16 | NN_MODULE = torch.nn.Module 17 | INDICES = Union[np.ndarray,list,torch.Tensor] 18 | 19 | PLM_MAP = { 20 | "bert": { 21 | # different model may share the same tokenizer, so we can load the same tokenized data for them 22 | "tokenizer": "bert", 23 | "load_name": "bert-base-uncased" 24 | }, 25 | "distilbert": { 26 | "tokenizer": "bert", 27 | "load_name": "distilbert-base-uncased", 28 | }, 29 | "ernie": { 30 | "tokenizer": "bert", 31 | "load_name": "nghuyong/ernie-2.0-en" 32 | }, 33 | "contriever": { 34 | "tokenizer": "contriever", 35 | "load_name": "null" 36 | }, 37 | "gtr": { 38 | "tokenizer": "gtr", 39 | "load_name": "null" 40 | }, 41 | "bert-chinese": { 42 | "tokenizer": "bert-chinese", 43 | "load_name": "bert-base-chinese" 44 | }, 45 | "bert-xingshi": { 46 | "tokenizer": "bert-xingshi", 47 | "load_name": "null" 48 | }, 49 | "t5-small": { 50 | "tokenizer": "t5", 51 | "load_name": "t5-small" 52 | }, 53 | "t5": { 54 | "tokenizer": "t5", 55 | "load_name": "t5-base" 56 | }, 57 | "t5-large": { 58 | "tokenizer": "t5", 59 | "load_name": "t5-large" 60 | }, 61 | "doct5": { 62 | "tokenizer": "t5", 63 | "load_name": "castorini/doc2query-t5-base-msmarco" 64 | }, 65 | "distilsplade": { 66 | "tokenizer": "bert", 67 | "load_name": "null" 68 | }, 69 | "splade": { 70 | "tokenizer": "bert", 71 | "load_name": "null" 72 | }, 73 | "bart": { 74 | "tokenizer": "bart", 75 | "load_name": "facebook/bart-base" 76 | }, 77 | "bart-large": { 78 | "tokenizer": "bart", 79 | "load_name": "facebook/bart-large" 80 | }, 81 | "retromae": { 82 | "tokenizer": "bert", 83 | "load_name": "Shitao/RetroMAE" 84 | }, 85 | "retromae_msmarco": { 86 | "tokenizer": "bert", 87 | "load_name": "Shitao/RetroMAE_MSMARCO" 88 | }, 89 | "retromae_distill": { 90 | "tokenizer": "bert", 91 | "load_name": "Shitao/RetroMAE_MSMARCO_distill" 92 | }, 93 | "deberta": { 94 | "tokenizer": "deberta", 95 | "load_name": "microsoft/deberta-base" 96 | }, 97 | "keyt5": { 98 | "tokenizer": "t5", 99 | "load_name": "snrspeaks/KeyPhraseTransformer" 100 | }, 101 | "seal": { 102 | "tokenizer": "bart", 103 | "load_name": "tuner007/pegasus_paraphrase" 104 | }, 105 | "doct5-nq": { 106 | "tokenizer": "t5", 107 | "load_name": "namespace-Pt/doct5-nq320k" 108 | } 109 | } -------------------------------------------------------------------------------- /src/models/CrossEnc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from transformers import AutoModelForSequenceClassification 4 | from .BaseModel import BaseModel 5 | 6 | 7 | 8 | class CrossEncoder(BaseModel): 9 | def __init__(self, config): 10 | super().__init__(config) 11 | 12 | self.plm = AutoModelForSequenceClassification.from_pretrained(config.plm_dir, num_labels=1) 13 | # self.plm.pooler = None 14 | if config.code_size > 0: 15 | self.plm.resize_token_embeddings(config.vocab_size + config.code_size) 16 | 17 | 18 | def _compute_score(self, **kwargs): 19 | """ concate the query and the input 
text; 20 | Args: 21 | query_token_id: B, LQ 22 | text_token_id: B, LS 23 | Returns: 24 | tensor of [B] 25 | """ 26 | for k, v in kwargs.items(): 27 | # B, 1+N, L -> B * (1+N), L 28 | if v.dim() == 3: 29 | kwargs[k] = v.view(-1, v.shape[-1]) 30 | 31 | score = self.plm(**kwargs).logits.squeeze(-1) 32 | return score 33 | 34 | 35 | def rerank_step(self, x): 36 | x = self._move_to_device(x) 37 | if "text_code" in x: 38 | # concate query and text code as inputs 39 | query_token_id = x["query"]["input_ids"] # B, L 40 | query_attn_mask = x["query"]["attention_mask"] 41 | 42 | text_code = x["text_code"] # B, 1+N, LC or B, LC 43 | if text_code.dim() == 3: 44 | text_code = text_code.flatten(0, 1) # B*(1+N) or B, LC 45 | 46 | M, L = text_code.shape[0] // query_token_id.shape[0], query_token_id.shape[-1] 47 | 48 | pair_token_id = torch.zeros((text_code.shape[0], text_code.shape[-1] + query_token_id.shape[-1] - 1), device=text_code.device) 49 | pair_token_id[:, :L] = query_token_id.repeat_interleave(M, 0) 50 | # remove the leading 0 51 | pair_token_id[:, L:] = text_code[:, 1:] 52 | 53 | pair_attn_mask = torch.zeros_like(pair_token_id) 54 | pair_attn_mask[:, :L] = query_attn_mask.repeat_interleave(M, 0) 55 | pair_attn_mask[:, L:] = (text_code != -1).float() 56 | 57 | pair = { 58 | "input_ids": pair_token_id, 59 | "attention_mask": pair_attn_mask 60 | } 61 | if "token_type_ids" in x["query"]: 62 | pair_type_id = torch.zeros_like(pair_attn_mask) 63 | pair_type_id[:, L:] = 1 64 | pair["token_type_ids"] = pair_type_id 65 | else: 66 | pair = x["pair"] 67 | 68 | score = self._compute_score(**pair) # B or B*(1+N) 69 | return score 70 | 71 | 72 | def forward(self, x): 73 | pair = x["pair"] 74 | score = self.rerank_step(x) # B*(1+N) 75 | 76 | if pair["input_ids"].dim() == 3: 77 | # use cross entropy loss 78 | score = score.view(x["pair"]["input_ids"].shape[0], -1) 79 | label = torch.zeros(score.shape[0], dtype=torch.long, device=self.config.device) 80 | loss = F.cross_entropy(score, label) 81 | 82 | elif pair["input_ids"].dim() == 2: 83 | label = x["label"] 84 | loss = F.binary_cross_entropy(torch.sigmoid(score), label) 85 | 86 | return loss 87 | 88 | 89 | -------------------------------------------------------------------------------- /src/models/SPARTA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | import torch.distributed as dist 6 | from tqdm import tqdm 7 | from transformers import AutoModel, AutoTokenizer 8 | from .BaseModel import BaseSparseModel 9 | from utils.util import BaseOutput 10 | 11 | 12 | class SPARTA(BaseSparseModel): 13 | def __init__(self, config): 14 | super().__init__(config) 15 | 16 | self.plm = AutoModel.from_pretrained(config.plm_dir) 17 | self.plm.pooler = None 18 | 19 | self._skip_special_tokens = True 20 | self._text_length = self.config.text_decode_k 21 | 22 | 23 | def _encode_text(self, **kwargs): 24 | for k, v in kwargs.items(): 25 | # B, 1+N, L -> B * (1+N), L 26 | if v.dim() == 3: 27 | kwargs[k] = v.view(-1, v.shape[-1]) 28 | 29 | token_embedding = self.plm(**kwargs)[0] # B, L, D 30 | return token_embedding 31 | 32 | 33 | def _encode_query(self, token_id): 34 | return self.plm.embeddings.word_embeddings(token_id) 35 | 36 | 37 | def forward(self, x): 38 | x = self._move_to_device(x) 39 | 40 | query_token_embedding = self._encode_query(x["query"]["input_ids"]) # B, L, D 41 | text_token_embedding = self._encode_text(**x["text"]) # B*(1+N), L, D 42 | 43 | 
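# when training is distributed and enable_all_gather is set (see mode/train.yaml), the token
# embeddings are gathered from every process so that in-batch negatives cover the global batch
# rather than only the local shard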
if self.config.is_distributed and self.config.enable_all_gather: 44 | query_token_embedding = self._gather_tensors(query_token_embedding) 45 | text_token_embedding = self._gather_tensors(text_token_embedding) 46 | 47 | query_text_score = torch.einsum('qin,tjn->qitj', query_token_embedding, text_token_embedding) 48 | query_text_score = query_text_score.max(dim=-1)[0] # B, LQ, B*(1+N) 49 | query_text_score = torch.log(torch.relu(query_text_score) + 1) 50 | score = query_text_score.sum(dim=1) # B, B*(1+N) 51 | 52 | B = score.shape[0] 53 | if self.config.enable_inbatch_negative: 54 | label = torch.arange(B, device=self.config.device) 55 | label = label * (text_token_embedding.shape[0] // query_token_embedding.shape[0]) 56 | else: 57 | label = torch.zeros(B, dtype=torch.long, device=self.config.device) 58 | score = score.view(B, B, -1)[range(B), range(B)] # B, 1+N 59 | 60 | loss = self._compute_loss(score, label, self._compute_teacher_score(x)) 61 | return loss 62 | 63 | 64 | def encode_text_step(self, x): 65 | """ 66 | Pre-compute interactions of all possible tokens with each text token, keep the most matching text token; then only index the topk decoded tokens (top k important tokens in the sense that they will contribute most to the final text score) 67 | """ 68 | text = self._move_to_device(x["text"]) 69 | text_token_embedding = self._encode_text(**text) # B, L, D 70 | vocab_embedding = self.plm.embeddings.word_embeddings.weight # V, D 71 | text_token_embedding = torch.einsum("vd,...ld->...lv", vocab_embedding, text_token_embedding) # B, L, V 72 | text_embedding = torch.log(torch.relu(text_token_embedding.max(1)[0]) + 1) # B, V 73 | 74 | text_token_id, text_token_weight = text_embedding.topk(k=self._text_length) 75 | return text_token_id.cpu().numpy(), text_token_weight.unsqueeze(-1).cpu().numpy() 76 | 77 | -------------------------------------------------------------------------------- /src/models/UniRetriever.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | from .BaseModel import BaseModel 4 | from utils.util import load_pickle 5 | 6 | 7 | class UniRetriever(BaseModel): 8 | def __init__(self, config): 9 | from .AutoModel import AutoModel as AM 10 | super().__init__(config) 11 | 12 | if config.x_model != "none": 13 | additional_kwargs = { 14 | "data_root": config.data_root, 15 | "plm_root": config.plm_root, 16 | "text_col": config.text_col, 17 | "device": config.get("x_device", config.device), 18 | "verifier_type": config.verifier_type, 19 | "verifier_src": config.verifier_src, 20 | "verifier_index": config.verifier_index 21 | } 22 | for k,v in config.items(): 23 | if k.startswith("x_") and k != "x_model": 24 | additional_kwargs[k[2:]] = v 25 | 26 | XModel = AM.from_pretrained(os.path.join(config.cache_root, "ckpts", config.x_model, config.x_load_ckpt), **additional_kwargs) 27 | else: 28 | XModel = None 29 | 30 | if config.y_model != "none": 31 | additional_kwargs = { 32 | "data_root": config.data_root, 33 | "plm_root": config.plm_root, 34 | "text_col": config.text_col, 35 | "device": config.get("y_device", config.device), 36 | "verifier_type": config.verifier_type, 37 | "verifier_src": config.verifier_src, 38 | "verifier_index": config.verifier_index 39 | } 40 | for k,v in config.items(): 41 | if k.startswith("y_") and k != "y_model": 42 | additional_kwargs[k[2:]] = v 43 | 44 | YModel = AM.from_pretrained(os.path.join(config.cache_root, "ckpts", config.y_model, config.y_load_ckpt), 
**additional_kwargs) 45 | 46 | else: 47 | YModel = None 48 | 49 | self.XModel = XModel 50 | self.YModel = YModel 51 | 52 | 53 | def retrieve(self, loaders): 54 | """ retrieve by index 55 | 56 | Args: 57 | encode_query: if true, compute query embedding before retrieving 58 | """ 59 | if self.XModel is not None: 60 | x_retrieval_result = self.XModel.retrieve(loaders) 61 | self.metrics.update({f"X {k}": v for k, v in self.XModel.metrics.items() if k in ["Posting_List_Length"]}) 62 | else: 63 | x_retrieval_result = {} 64 | 65 | if self.YModel is not None: 66 | y_retrieval_result = self.YModel.retrieve(loaders) 67 | self.metrics.update({f"Y {k}": v for k, v in self.YModel.metrics.items() if k in ["Posting_List_Length"]}) 68 | else: 69 | y_retrieval_result = {} 70 | 71 | try: 72 | posting_length = self.metrics["X Posting_List_Length"] + self.metrics["Y Posting_List_Length"] 73 | flops = round((posting_length) * 48 / len(loaders["text"].dataset), 2) 74 | self.metrics.update({"Posting_List_length": posting_length, "FLOPs": flops}) 75 | except: 76 | pass 77 | 78 | if self.config.get("save_intm_result"): 79 | self.XModel._gather_retrieval_result( 80 | x_retrieval_result, 81 | retrieval_result_path=os.path.join(self.retrieve_dir, "x_retrieval_result.pkl") 82 | ) 83 | self.YModel._gather_retrieval_result( 84 | y_retrieval_result, 85 | retrieval_result_path=os.path.join(self.retrieve_dir, "y_retrieval_result.pkl") 86 | ) 87 | 88 | loader_query = loaders["query"] 89 | retrieval_result = {} 90 | for qidx in range(loader_query.sampler.start, loader_query.sampler.end): 91 | res = dict(x_retrieval_result.get(qidx, [])) 92 | res.update(dict(y_retrieval_result.get(qidx, []))) 93 | sorted_res = sorted(res.items(), key=lambda x: x[1], reverse=True)[:self.config.hits] 94 | retrieval_result[qidx] = sorted_res 95 | 96 | return retrieval_result 97 | -------------------------------------------------------------------------------- /src/models/SEAL.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import subprocess 4 | from tqdm import tqdm 5 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 6 | from utils.util import BaseOutput, synchronize, makedirs 7 | from .BaseModel import BaseGenerativeModel 8 | 9 | 10 | class SEAL(BaseGenerativeModel): 11 | def __init__(self, config): 12 | super().__init__(config) 13 | assert "bart" in self.config.plm 14 | self.plm = AutoModelForSeq2SeqLM.from_pretrained(self.config.plm_dir) 15 | 16 | 17 | def index(self, loaders): 18 | """ 19 | Build FM index. 20 | """ 21 | import seal 22 | assert self.config.index_type == "fm", "Must use fm index!" 23 | 24 | index_dir = os.path.join(self.index_dir, "index") 25 | index_path = os.path.join(index_dir, "fm_index") 26 | if self.config.is_main_proc: 27 | collection_dir = os.path.join(self.index_dir, "collection") 28 | collection_path = os.path.join(collection_dir, "collection.tsv") 29 | 30 | makedirs(collection_path) 31 | makedirs(index_path) 32 | 33 | tokenizer = AutoTokenizer.from_pretrained(self.config.plm_dir) 34 | 35 | if (self.config.load_index and os.path.exists(index_path + ".oth")) or (self.config.load_collection and os.path.exists(collection_path)): 36 | pass 37 | else: 38 | assert self.config.get("title_col") is not None, "Must specify title column index!" 
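# the rebuilt collection.tsv below follows a `docid \t title \t body` layout (what is fed to
# seal.build_fm_index with --include_title); the body is re-truncated to config.text_length
# tokens so the FM-index covers the same text budget as the other models ("for fair comparison")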
39 | loader_text = loaders["text"] 40 | with open(f"{self.config.data_root}/{self.config.dataset}/collection.tsv") as f, \ 41 | open(collection_path, "w") as g: 42 | for line in tqdm(f, total=loader_text.dataset.text_num, ncols=100, desc="Building Collection"): 43 | fields = line.split("\t") 44 | fields = [field.strip() for field in fields] 45 | 46 | tid = fields[0] 47 | title = fields[self.config.title_col] 48 | text = " ".join(fields[self.config.title_col + 1:]) 49 | 50 | # for fair comparison 51 | text = tokenizer.decode(tokenizer.encode(text, add_special_tokens=False)[:self.config.text_length], skip_special_tokens=True) 52 | 53 | g.write("\t".join([tid, title, text]) + "\n") 54 | 55 | if self.config.load_index and os.path.exists(index_path + ".oth"): 56 | pass 57 | else: 58 | subprocess.run( 59 | f"python -m seal.build_fm_index {collection_path} {index_path} --hf_model {self.config.plm_dir} --jobs {self.config.index_thread} --include_title --lowercase", shell=True) 60 | 61 | synchronize() 62 | # fm_index = seal.SEALSearcher.load_fm_index(index_path) 63 | return BaseOutput() 64 | 65 | 66 | @synchronize 67 | @torch.no_grad() 68 | def retrieve(self, loaders): 69 | import seal 70 | index = self.index(loaders).index 71 | loader_query = loaders["query"] 72 | 73 | self.logger.info("searching...") 74 | tokenizer = AutoTokenizer.from_pretrained(self.config.plm_dir) 75 | 76 | searcher = seal.SEALSearcher.load(os.path.join(self.index_dir, "index", "fm_index"), bart_model_path=f"../../SEAL/{self.config.dataset}/checkpoints/checkpoint_best.pt", backbone=self.config.plm_dir, device=self.config.device) 77 | # searcher = seal.SEALSearcher.load(os.path.join(self.index_dir, "index", "fm_index"), None, backbone=self.config.plm_dir, device=self.config.device) 78 | searcher.include_keys = True 79 | 80 | retrieval_result = {} 81 | for i, x in enumerate(tqdm(loader_query, leave=False, ncols=100)): 82 | query = x["query"]["input_ids"] 83 | query_idx = x["query_idx"].tolist() 84 | 85 | query = tokenizer.batch_decode(query, skip_special_tokens=True) 86 | for j, q in zip(query_idx, query): 87 | res = searcher.search(q, k=self.config.hits) 88 | retrieval_result[j] = [doc.id() for doc in res] 89 | 90 | if self.config.get("debug") and i > 1: 91 | break 92 | 93 | return retrieval_result 94 | 95 | 96 | def generate_code(self, loaders): 97 | pass 98 | 99 | 100 | -------------------------------------------------------------------------------- /src/models/SPLADE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from transformers import AutoModelForMaskedLM 4 | from .BaseModel import BaseSparseModel 5 | 6 | 7 | 8 | class SPLADEv2(BaseSparseModel): 9 | def __init__(self, config): 10 | super().__init__(config) 11 | 12 | self._set_encoder(AutoModelForMaskedLM) 13 | 14 | if self.config.mode == "train": 15 | self._step = 0 16 | self._lambda_warmup_step = config.lambda_warmup_step // config.world_size 17 | 18 | self._skip_special_tokens = False 19 | self._text_length = self.config.text_decode_k 20 | self._query_length = self.config.query_decode_k 21 | 22 | def _encode_text(self, **kwargs): 23 | for k, v in kwargs.items(): 24 | # B, 1+N, L -> B * (1+N), L 25 | if v.dim() == 3: 26 | kwargs[k] = v.view(-1, v.shape[-1]) 27 | 28 | token_all_embedding = self.textEncoder(**kwargs, return_dict=True).logits 29 | token_all_embedding = torch.log(F.relu(token_all_embedding) + 1) * kwargs["attention_mask"].unsqueeze(-1) # B, L, V 30 | 
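# max-pool the log(1 + ReLU(.))-saturated token activations over the sequence, giving one
# |V|-dimensional sparse lexical vector per text; only the top-k entries are kept at encoding
# time (see encode_text_step below)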
token_embedding = token_all_embedding.max(dim=1)[0] # B, V 31 | return token_embedding 32 | 33 | def _encode_query(self, **kwargs): 34 | token_all_embedding = self.queryEncoder(**kwargs, return_dict=True).logits 35 | token_all_embedding = torch.log(F.relu(token_all_embedding) + 1) * kwargs["attention_mask"].unsqueeze(-1) # B, L, V 36 | token_embedding = token_all_embedding.max(dim=1)[0] # B, V 37 | return token_embedding 38 | 39 | 40 | def _compute_flops(self, embedding): 41 | return torch.sum(torch.mean(torch.abs(embedding), dim=0) ** 2) 42 | 43 | 44 | def _refresh_lambda(self): 45 | if self._step <= self._lambda_warmup_step and self._lambda_warmup_step > 0: 46 | self._text_lambda = self.config.text_lambda * (self._step / self._lambda_warmup_step) ** 2 47 | self._query_lambda = self.config.query_lambda * (self._step / self._lambda_warmup_step) ** 2 48 | else: 49 | self._text_lambda = self.config.text_lambda 50 | self._query_lambda = self.config.query_lambda 51 | self._step += 1 52 | 53 | 54 | def forward(self, x): 55 | if self.training: 56 | self._refresh_lambda() 57 | 58 | x = self._move_to_device(x) 59 | 60 | query_embedding = self._encode_query(**x["query"]) # B, V 61 | text_embedding = self._encode_text(**x["text"]) # B*(1+N), V 62 | 63 | if self.config.is_distributed and self.config.enable_all_gather: 64 | query_embedding = self._gather_tensors(query_embedding) 65 | text_embedding = self._gather_tensors(text_embedding) 66 | 67 | score = query_embedding.matmul(text_embedding.transpose(-1,-2)) # B, B*(1+N) 68 | 69 | B = query_embedding.size(0) 70 | # in batch negative 71 | if self.config.enable_inbatch_negative: 72 | label = torch.arange(B, device=self.config.device) 73 | label = label * (text_embedding.shape[0] // query_embedding.shape[0]) 74 | else: 75 | label = torch.zeros(B, dtype=torch.long, device=self.config.device) 76 | score = score.view(B, B, -1)[range(B), range(B)] # B, 1+N 77 | 78 | query_flops_loss = self._compute_flops(query_embedding) * self._query_lambda 79 | text_flops_loss = self._compute_flops(text_embedding) * self._text_lambda 80 | flops_loss = query_flops_loss + text_flops_loss 81 | 82 | loss = self._compute_loss(score, label, self._compute_teacher_score(x)) + flops_loss 83 | 84 | return loss 85 | 86 | 87 | def encode_text_step(self, x): 88 | text = self._move_to_device(x["text"]) 89 | text_embedding = self._encode_text(**text) 90 | 91 | text_token_weight, text_token_id = text_embedding.topk(k=self._text_length, dim=1) # B, K 92 | 93 | # unsqueeze to map it to the _output_dim (1) 94 | return text_token_id.cpu().numpy(), text_token_weight.unsqueeze(-1).cpu().numpy() 95 | 96 | 97 | def encode_query_step(self, x): 98 | query = self._move_to_device(x["query"]) 99 | query_embedding = self._encode_query(**query) 100 | 101 | query_token_weight, query_token_id = query_embedding.topk(k=self._query_length, dim=1) # B, K 102 | 103 | # unsqueeze to map it to the _output_dim (1) 104 | return query_token_id.cpu().numpy(), query_token_weight.unsqueeze(-1).cpu().numpy() 105 | -------------------------------------------------------------------------------- /src/scripts/doct5.py: -------------------------------------------------------------------------------- 1 | # on zhiyuan machine, must import utils first to load faiss before torch 2 | from utils.util import synchronize, save_pickle, makedirs, Config 3 | from utils.data import prepare_data 4 | 5 | import os 6 | import sys 7 | import torch 8 | import numpy as np 9 | from tqdm import tqdm 10 | from multiprocessing import Pool 11 | 
from transformers import T5ForConditionalGeneration, AutoTokenizer 12 | 13 | import hydra 14 | from pathlib import Path 15 | from omegaconf import OmegaConf 16 | @hydra.main(version_base=None, config_path="../data/config/", config_name=f"script/{Path(__file__).stem}") 17 | def get_config(hydra_config: OmegaConf): 18 | config._from_hydra(hydra_config) 19 | 20 | 21 | def main(config:Config): 22 | loaders = prepare_data(config) 23 | loader_text = loaders["text"] 24 | 25 | max_length = config.query_length 26 | query_per_doc = config.query_per_doc 27 | 28 | doct5_path = os.path.join(config.data_root, config.dataset, "queries.doct5.tsv") 29 | doct5_qrel_path = os.path.join(config.data_root, config.dataset, "qrels.doct5.tsv") 30 | cache_dir = os.path.join(config.cache_root, "dataset", "query", "doct5") 31 | mmp_path = os.path.join(cache_dir, config.plm_tokenizer, "token_ids.mmp") 32 | makedirs(mmp_path) 33 | 34 | model = T5ForConditionalGeneration.from_pretrained(config.plm_dir).to(config.device) 35 | tokenizer = AutoTokenizer.from_pretrained(config.plm_dir) 36 | 37 | # generate psudo queries 38 | if not config.load_encode: 39 | # -1 is the pad_token_id 40 | query_token_ids = np.zeros((len(loader_text.sampler), query_per_doc, max_length), dtype=np.int32) - 1 41 | start_idx = end_idx = 0 42 | with torch.no_grad(): 43 | for i, x in enumerate(tqdm(loader_text, ncols=100, desc="Generating Queries")): 44 | text = x["text"] 45 | for k, v in text.items(): 46 | text[k] = v.to(config.device, non_blocking=True) 47 | 48 | B = text["input_ids"].shape[0] 49 | 50 | sequences = model.generate( 51 | **text, 52 | do_sample=True, 53 | max_length=max_length, 54 | temperature=3.0, 55 | top_k=10, 56 | num_return_sequences=query_per_doc 57 | ).view(B, query_per_doc, -1).cpu().numpy() # B, N, L 58 | 59 | end_idx += B 60 | query_token_ids[start_idx: end_idx, :, :sequences.shape[-1]] = sequences 61 | start_idx = end_idx 62 | 63 | # mask eos tokens 64 | query_token_ids[query_token_ids == config.special_token_ids["eos"][1]] = -1 65 | 66 | # use memmap to temperarily save the generated token ids 67 | if config.is_main_proc: 68 | query_token_ids_mmp = np.memmap( 69 | mmp_path, 70 | shape=(len(loader_text.dataset) * query_per_doc, max_length), 71 | dtype=np.int32, 72 | mode="w+" 73 | ) 74 | synchronize() 75 | 76 | query_token_ids_mmp = np.memmap( 77 | mmp_path, 78 | dtype=np.int32, 79 | mode="r+" 80 | ).reshape(len(loader_text.dataset), query_per_doc, max_length) 81 | query_token_ids_mmp[loader_text.sampler.start: loader_text.sampler.end] = query_token_ids 82 | synchronize() 83 | 84 | if config.is_main_proc: 85 | # load all saved token ids 86 | query_token_ids = np.memmap( 87 | mmp_path, 88 | dtype=np.int32, 89 | mode="r+" 90 | ).reshape(len(loader_text.dataset) * query_per_doc, max_length) 91 | 92 | # decode to strings and write to the query file 93 | idx = 0 94 | with open(doct5_path, "w") as f, open(doct5_qrel_path, "w") as g: 95 | for i, queries in enumerate(tqdm(query_token_ids.reshape(len(loader_text.dataset), query_per_doc, max_length), ncols=100, desc="Decoding")): 96 | for j, query in enumerate(queries): 97 | seq = tokenizer.decode(query[query != -1], skip_special_tokens=True) # N 98 | f.write("\t".join([str(idx), seq]) + "\n") 99 | g.write("\t".join([str(idx), "0", str(i), "1"]) + "\n") 100 | idx += 1 101 | 102 | 103 | if __name__ == "__main__": 104 | # manually action="store_true" because hydra doesn't support it 105 | for i, arg in enumerate(sys.argv): 106 | if "=" not in arg: 107 | sys.argv[i] += "=true" 108 | 
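# e.g. `python -m scripts.doct5 base=NQ320k load_encode` turns the bare `load_encode` into
# `load_encode=true` before hydra parses it; load_encode is the switch checked in main() to
# skip re-generating the pseudo queries (example invocation only)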
109 | config = Config() 110 | get_config() 111 | main(config) 112 | -------------------------------------------------------------------------------- /src/scripts/evalnq.py: -------------------------------------------------------------------------------- 1 | import regex 2 | import unicodedata 3 | import argparse 4 | import pickle 5 | import sys 6 | import numpy as np 7 | from torch.utils.data.dataloader import DataLoader 8 | from torch.utils.data.dataset import Dataset 9 | from tqdm import tqdm 10 | 11 | 12 | 13 | def load_test_data(query_andwer_path, collection_path): 14 | answers = [] 15 | for line in open(query_andwer_path, encoding='utf-8'): 16 | line = line.strip().split('\t') 17 | answers.append(eval(line[1])) 18 | 19 | collection = [] 20 | for line in tqdm(open(collection_path, encoding='utf-8'), ncols=100, desc="Collecting Passages", leave=False): 21 | line = line.strip().split('\t') 22 | collection.append(line[2]) 23 | return answers, collection 24 | 25 | 26 | class SimpleTokenizer: 27 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 28 | NON_WS = r'[^\p{Z}\p{C}]' 29 | 30 | def __init__(self, **kwargs): 31 | """ 32 | Args: 33 | annotators: None or empty set (only tokenizes). 34 | """ 35 | self._regexp = regex.compile( 36 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 37 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 38 | ) 39 | 40 | def tokenize(self, text, uncase=False): 41 | tokens = [] 42 | matches = [m for m in self._regexp.finditer(text)] 43 | for i in range(len(matches)): 44 | # Get text 45 | token = matches[i].group() 46 | # Format data 47 | if uncase: 48 | tokens.append(token.lower()) 49 | else: 50 | tokens.append(token) 51 | return tokens 52 | 53 | 54 | def _normalize(text): 55 | return unicodedata.normalize('NFD', text) 56 | 57 | 58 | def has_answer(answers, text, tokenizer) -> bool: 59 | """Check if a document contains an answer string. 
60 | """ 61 | text = _normalize(text) 62 | 63 | # Answer is a list of possible strings 64 | text = tokenizer.tokenize(text, uncase=True) 65 | 66 | for answer in answers: 67 | answer = _normalize(answer) 68 | answer = tokenizer.tokenize(answer, uncase=True) 69 | 70 | for i in range(0, len(text) - len(answer) + 1): 71 | if answer == text[i: i + len(answer)]: 72 | return True 73 | return False 74 | 75 | 76 | def collate_fn(batch_hits): 77 | return batch_hits 78 | 79 | 80 | class EvalDataset(Dataset): 81 | def __init__(self, retrieval_result, answers, collection): 82 | self.collection = collection 83 | self.answers = answers 84 | self.retrieval_result = retrieval_result 85 | self.tokenizer = SimpleTokenizer() 86 | 87 | def __getitem__(self, qidx): 88 | res = self.retrieval_result[qidx] 89 | hits = [] 90 | for i, tidx in enumerate(res): 91 | if tidx == -1: 92 | hits.append(False) 93 | else: 94 | hits.append(has_answer(self.answers[qidx], self.collection[tidx], self.tokenizer)) 95 | return hits 96 | 97 | def __len__(self): 98 | return len(self.retrieval_result) 99 | 100 | 101 | def validate(retrieval_result, answers, collection, num_workers=16, batch_size=16): 102 | dataset = EvalDataset(retrieval_result, answers, collection) 103 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn) 104 | 105 | final_scores = [] 106 | for scores in tqdm(dataloader, total=len(dataloader), ncols=100, desc="Computing Metrics"): 107 | final_scores.extend(scores) 108 | 109 | relaxed_hits = np.zeros(max([len(x) for x in retrieval_result.values()])) 110 | for question_hits in final_scores: 111 | best_hit = next((i for i, x in enumerate(question_hits) if x), None) 112 | if best_hit is not None: 113 | relaxed_hits[best_hit:] += 1 114 | 115 | relaxed_recall = relaxed_hits / len(retrieval_result) 116 | 117 | return { 118 | "Recall@1": round(relaxed_recall[0], 4), 119 | "Recall@5": round(relaxed_recall[4], 4), 120 | "Recall@10": round(relaxed_recall[9], 4), 121 | "Recall@20": round(relaxed_recall[19], 4), 122 | "Recall@100": round(relaxed_recall[99], 4) 123 | } 124 | 125 | 126 | if __name__ == "__main__": 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument("--retrieval_result_path") 129 | parser.add_argument("--query_answer_path", default="../../../Data/NQ/nq-test.qa.csv") 130 | parser.add_argument("--collection_path", default="../../../Data/NQ/collection.tsv") 131 | args = parser.parse_args() 132 | 133 | with open(args.retrieval_result_path, "rb") as f: 134 | retrieval_result = pickle.load(f) 135 | 136 | metric = validate(retrieval_result, *load_test_data(args.query_answer_path, args.collection_path)) 137 | sys.stdout.write(str(metric)) -------------------------------------------------------------------------------- /src/models/COIL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .BaseModel import BaseSparseModel 4 | from transformers import AutoModel 5 | 6 | 7 | 8 | class COIL(BaseSparseModel): 9 | def __init__(self, config): 10 | super().__init__(config) 11 | 12 | self._set_encoder() 13 | 14 | self.tokenProject = nn.Linear(self.textEncoder.config.hidden_size, config.token_dim) 15 | 16 | self._output_dim = config.token_dim 17 | 18 | 19 | def _encode_text(self, **kwargs): 20 | for k, v in kwargs.items(): 21 | # B, 1+N, L -> B * (1+N), L 22 | if v.dim() == 3: 23 | kwargs[k] = v.view(-1, v.shape[-1]) 24 | 25 | token_all_embedding = 
self.textEncoder(**kwargs)[0] 26 | token_embedding = self.tokenProject(token_all_embedding) 27 | return token_embedding 28 | 29 | 30 | def _encode_query(self, **kwargs): 31 | token_all_embedding = self.queryEncoder(**kwargs)[0] 32 | token_embedding = self.tokenProject(token_all_embedding) 33 | return token_embedding 34 | 35 | 36 | def forward(self, x): 37 | x = self._move_to_device(x) 38 | 39 | query_token_embedding = self._encode_query(**x["query"]) # B, LQ, D 40 | text_token_embedding = self._encode_text(**x["text"]) # B * (1+N), LS, D 41 | 42 | query_token_id = x["query"]["input_ids"] 43 | query_special_mask = x["query_special_mask"] 44 | text_token_id = x["text"]["input_ids"].view(text_token_embedding.shape[:-1]) 45 | if self.config.is_distributed and self.config.enable_all_gather: 46 | query_token_id = self._gather_tensors(query_token_id) 47 | query_special_mask = self._gather_tensors(query_special_mask) 48 | text_token_id = self._gather_tensors(text_token_id) 49 | query_token_embedding = self._gather_tensors(query_token_embedding) 50 | text_token_embedding = self._gather_tensors(text_token_embedding) 51 | 52 | B, LQ, D = query_token_embedding.shape 53 | LS = text_token_id.shape[1] 54 | 55 | query_text_overlap = self._compute_overlap(query_token_id, text_token_id) # B, LQ, B * (1+N), LS 56 | query_text_score = query_token_embedding.view(-1, D).matmul(text_token_embedding.view(-1, D).transpose(0, 1)).view(B, LQ, -1, LS) 57 | # only keep the overlapping tokens 58 | query_text_score = query_text_score * query_text_overlap 59 | # max pooling 60 | query_text_score = query_text_score.max(dim=-1)[0] # B, LQ, B * (1+N) 61 | # mask [CLS] and [SEP] and [PAD] 62 | query_text_score = query_text_score * query_special_mask.unsqueeze(-1) 63 | score = query_text_score.sum(dim=1) # B, B * (1+N) 64 | 65 | if self.config.enable_inbatch_negative: 66 | label = torch.arange(B, device=self.config.device) 67 | label = label * (text_token_embedding.shape[0] // query_token_embedding.shape[0]) 68 | 69 | else: 70 | label = torch.zeros(B, dtype=torch.long, device=self.config.device) 71 | score = score.view(B, B, -1)[range(B), range(B)] # B, 1+N 72 | 73 | loss = self._compute_loss(score, label, self._compute_teacher_score(x)) 74 | return loss 75 | 76 | 77 | def encode_text_step(self, x): 78 | # only move text because others are not needed 79 | text = self._move_to_device(x["text"]) 80 | text_token_id = text["input_ids"] 81 | text_token_embedding = self._encode_text(**text) 82 | text_token_embedding *= text["attention_mask"].unsqueeze(-1) 83 | return text_token_id.cpu().numpy(), text_token_embedding.cpu().numpy() 84 | 85 | 86 | def encode_query_step(self, x): 87 | # only move query because others are not needed 88 | query = self._move_to_device(x["query"]) 89 | query_token_id = query["input_ids"] 90 | query_token_embedding = self._encode_query(**query) 91 | query_token_embedding *= query["attention_mask"].unsqueeze(-1) 92 | return query_token_id.cpu().numpy(), query_token_embedding.cpu().numpy() 93 | 94 | 95 | def rerank_step(self, x): 96 | """ 97 | Given a query and a sequence, output the sequence's score 98 | """ 99 | x = self._move_to_device(x) 100 | 101 | query_token_embedding = self._encode_query(**x["query"]) # B, LQ, D 102 | text_token_embedding = self._encode_text(**x["text"]) # B, LS, D 103 | 104 | overlap = self._compute_overlap(x["query"]["input_ids"], x["text"]["input_ids"], cross_batch=False) # B, LQ, LS 105 | query_text_score = query_token_embedding.matmul(text_token_embedding.transpose(-1,-2)) * 
overlap # B, LQ, LS 106 | # mask the [SEP] token and [CLS] token 107 | query_text_score = query_text_score.max(dim=-1)[0] * x["query_special_mask"] # B, LQ 108 | score = query_text_score.sum(dim=1) 109 | return score 110 | 111 | -------------------------------------------------------------------------------- /src/models/DPR.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import torch.nn.functional as F 5 | from transformers import AutoModel, AutoTokenizer 6 | from .BaseModel import BaseDenseModel 7 | 8 | 9 | 10 | class DPR(BaseDenseModel): 11 | """ 12 | The basic dense retriever. `Paper `_. 13 | """ 14 | def __init__(self, config): 15 | super().__init__(config) 16 | 17 | self._set_encoder() 18 | self._output_dim = self.textEncoder.config.hidden_size 19 | 20 | 21 | def _encode_text(self, **kwargs): 22 | """ 23 | encode tokens with bert 24 | """ 25 | for k, v in kwargs.items(): 26 | # B, 1+N, L -> B * (1+N), L 27 | if v.dim() == 3: 28 | kwargs[k] = v.view(-1, v.shape[-1]) 29 | 30 | embedding = self.textEncoder(**kwargs)[0][:, 0] 31 | if self.config.dense_metric == "cos": 32 | embedding = F.normalize(embedding, dim=-1) 33 | return embedding 34 | 35 | 36 | def _encode_query(self, **kwargs): 37 | embedding = self.queryEncoder(**kwargs)[0][:, 0] 38 | if self.config.dense_metric == "cos": 39 | embedding = F.normalize(embedding, dim=-1) 40 | return embedding 41 | 42 | 43 | def forward(self, x): 44 | x = self._move_to_device(x) 45 | query_embedding = self._encode_query(**x["query"]) # B, D 46 | text_embedding = self._encode_text(**x["text"]) # *, D 47 | 48 | if self.config.is_distributed and self.config.enable_all_gather: 49 | query_embedding = self._gather_tensors(query_embedding) 50 | text_embedding = self._gather_tensors(text_embedding) 51 | 52 | if self.config.dense_metric == "ip": 53 | score = query_embedding.matmul(text_embedding.transpose(-1,-2)) # B, B*(1+N) 54 | elif self.config.dense_metric == "cos": 55 | score = self._cos_sim(query_embedding, text_embedding) 56 | elif self.config.dense_metric == "l2": 57 | score = self._l2_sim(query_embedding, text_embedding) 58 | else: 59 | raise NotImplementedError 60 | 61 | B = query_embedding.size(0) 62 | # in batch negative 63 | if self.config.enable_inbatch_negative: 64 | label = torch.arange(B, device=self.config.device) 65 | label = label * (text_embedding.shape[0] // query_embedding.shape[0]) 66 | else: 67 | label = torch.zeros(B, dtype=torch.long, device=self.config.device) 68 | score = score.view(B, B, -1)[range(B), range(B)] # B, 1+N 69 | 70 | loss = self._compute_loss(score, label, self._compute_teacher_score(x)) 71 | return loss 72 | 73 | 74 | def rerank_step(self, x): 75 | """ 76 | given a query and a sequence, output the sequence's score 77 | """ 78 | query_embedding = self._encode_query(**x["query"]) # B, D 79 | text_embedding = self._encode_text(**x["text"]) # B, D 80 | B = query_embedding.size(0) 81 | score = query_embedding.matmul(text_embedding.transpose(-1, -2))[range(B), range(B)] 82 | return score 83 | 84 | 85 | def deploy(self): 86 | deploy_dir = os.path.join(self.config.cache_root, "deploy", self.name) 87 | os.makedirs(deploy_dir, exist_ok=True) 88 | 89 | AutoTokenizer.from_pretrained(self.config.plm_dir).save_pretrained(deploy_dir) 90 | if self.config.untie_encoder: 91 | self.queryEncoder.save_pretrained(os.path.join(deploy_dir, "query")) 92 | self.textEncoder.save_pretrained(os.path.join(deploy_dir, "text")) 93 | else: 94 | 
self.logger.info(f"saving plm model and tokenizer at {deploy_dir}...") 95 | self.plm.save_pretrained(deploy_dir) 96 | 97 | 98 | class Contriever(DPR): 99 | def encode_text_step(self, x): 100 | text = self._move_to_device(x["text"]) 101 | token_embeddings = self.textEncoder(**text)[0] 102 | mask = text["attention_mask"] 103 | token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.) 104 | embedding = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None] 105 | if self.config.dense_metric == "cos": 106 | embedding = F.normalize(embedding, dim=-1) 107 | return embedding.cpu().numpy() 108 | 109 | def encode_query_step(self, x): 110 | query = self._move_to_device(x["query"]) 111 | token_embeddings = self.textEncoder(**query)[0] 112 | mask = query["attention_mask"] 113 | token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.) 114 | embedding = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None] 115 | if self.config.dense_metric == "cos": 116 | embedding = F.normalize(embedding, dim=-1) 117 | return embedding.cpu().numpy() 118 | 119 | 120 | class GTR(BaseDenseModel): 121 | def __init__(self, config): 122 | super().__init__(config) 123 | from sentence_transformers import SentenceTransformer 124 | 125 | self.encoder = SentenceTransformer('sentence-transformers/gtr-t5-base') 126 | self.tokenizer = self.encoder.tokenizer 127 | self._output_dim = 768 128 | 129 | def encode_text_step(self, x): 130 | text = self.tokenizer.batch_decode(x["text"]["input_ids"], skip_special_tokens=True) 131 | embedding = self.encoder.encode(text, batch_size=self.config.eval_batch_size, convert_to_tensor=True, device=self.config.device) 132 | if self.config.dense_metric == "cos": 133 | embedding = F.normalize(embedding, dim=-1) 134 | return embedding.cpu().numpy() 135 | 136 | def encode_query_step(self, x): 137 | query = self.tokenizer.batch_decode(x["query"]["input_ids"], skip_special_tokens=True) 138 | embedding = self.encoder.encode(query, batch_size=self.config.eval_batch_size, convert_to_tensor=True, device=self.config.device) 139 | if self.config.dense_metric == "cos": 140 | embedding = F.normalize(embedding, dim=-1) 141 | return embedding.cpu().numpy() 142 | -------------------------------------------------------------------------------- /src/models/Sequer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from transformers import T5ForConditionalGeneration 6 | from .BaseModel import BaseGenerativeModel 7 | 8 | 9 | 10 | class Sequer(BaseGenerativeModel): 11 | def __init__(self, config): 12 | super().__init__(config) 13 | 14 | plm = T5ForConditionalGeneration.from_pretrained(config.plm_dir) 15 | 16 | if config.get("code_size", 0) > 0: 17 | plm.resize_token_embeddings(config.vocab_size + config.code_size) 18 | 19 | self.plm = plm 20 | 21 | if config.get("rank_type") == "eos": 22 | self.scorer = torch.nn.Linear(plm.config.d_model, 1) 23 | 24 | 25 | def _prepare_decoder_inputs(self, text=None, text_code=None): 26 | """ 27 | Prepare for _compute_logits. For regular text, shift right by 1 position; for code, keep it as is and generate attention mask. 
28 | 29 | Returns: 30 | text_token_id: tensor of [B, L]; starting with 0 31 | text_attn_mask: tensor of [B, L] 32 | """ 33 | if text_code is not None: 34 | text_token_id = text_code 35 | text_attn_mask = (text_token_id != -1).float() 36 | # replace -1 with 0 because it cannot be recognized by the model 37 | text_token_id = text_token_id.masked_fill(text_token_id == -1, 0) 38 | 39 | elif text is not None: 40 | # the original text_token_id does not start with 0, so we prepend it 41 | text_token_id = text.input_ids 42 | text_attn_mask = text.attention_mask 43 | pad_token_id = torch.zeros((*text_token_id.shape[:-1], 1), dtype=text_token_id.dtype, device=text_token_id.device) 44 | text_token_id = torch.cat([pad_token_id, text_token_id], dim=-1) 45 | text_attn_mask = torch.cat([torch.ones_like(pad_token_id), text_attn_mask], dim=-1) 46 | 47 | else: 48 | raise ValueError("Must provide either text or text_code!") 49 | 50 | return text_token_id, text_attn_mask 51 | 52 | 53 | def _compute_logits(self, text_token_id, **kwargs): 54 | """ 55 | Wrapped method to compute the decoder's log-probabilities and last-layer token embeddings. 56 | 57 | Returns: 58 | logits: tensor of [B, L, V], log-softmax over the vocabulary 59 | token_embedding: tensor of [B, L, D] 60 | """ 61 | outputs = self.plm(decoder_input_ids=text_token_id, output_hidden_states=True, **kwargs) # *, L, V 62 | token_embedding = outputs.decoder_hidden_states[-1] # *, L, D 63 | logits = torch.log_softmax(outputs.logits, dim=-1) 64 | # target_token_id = text_token_id[:, 1:] 65 | # token_score = logits.gather(index=target_token_id.unsqueeze(-1), dim=-1).squeeze(-1) # *, L - 1 66 | return logits, token_embedding 67 | 68 | 69 | def forward(self, x): 70 | x = self._move_to_device(x) 71 | encoder_outputs = self.plm.encoder(**x["query"]) 72 | 73 | # start with 0 74 | text_token_id, text_attn_mask = self._prepare_decoder_inputs(x["text"], x.get("text_code")) 75 | 76 | if text_token_id.dim() == 3: 77 | text_token_id = text_token_id.flatten(0, 1) # B*N, L 78 | text_attn_mask = text_attn_mask.flatten(0, 1) 79 | 80 | B = x["query"]["input_ids"].shape[0] 81 | M = text_token_id.shape[0] // B 82 | 83 | query_attn_mask = x["query"]["attention_mask"] 84 | if M > 1: 85 | # repeat query encode outputs to batchify 86 | for k, v in encoder_outputs.items(): 87 | encoder_outputs[k] = v.repeat_interleave(M, 0) 88 | query_attn_mask = query_attn_mask.repeat_interleave(M, 0) 89 | 90 | # important to add attention mask to properly read from encoder_outputs 91 | logits, token_embedding = self._compute_logits(text_token_id, encoder_outputs=encoder_outputs, attention_mask=query_attn_mask) 92 | 93 | loss = 0 94 | 95 | if "gen" in self.config.train_scheme: 96 | gen_logits = logits.unflatten(0, (B, M))[:, 0] # B, L, V 97 | pos_text_token_id = text_token_id.unflatten(0, (B, M))[:, 0] # B, L 98 | pos_text_attn_mask = text_attn_mask.unflatten(0, (B, M))[:, 0] # B, L 99 | 100 | labels = torch.zeros_like(pos_text_token_id) # B, L 101 | labels_mask = torch.zeros_like(pos_text_attn_mask) 102 | # shift left 103 | labels[:, :-1] = pos_text_token_id[:, 1:] # B, L 104 | labels_mask[:, :-1] = pos_text_attn_mask[:, 1:] # B, L 105 | # the pad token will be ignored in computing loss 106 | labels = labels.masked_fill(~labels_mask.bool(), -100) 107 | loss += F.nll_loss(gen_logits.flatten(0, 1), labels.view(-1), ignore_index=-100) 108 | 109 | if "contra" in self.config.train_scheme: 110 | if self.config.rank_type == "eos": 111 | valid_token_length = text_attn_mask.sum(dim=-1).long() - 1 112 | eos_embedding = token_embedding[range(valid_token_length.shape[0]), 
valid_token_length] 113 | score = self.scorer(eos_embedding).squeeze(-1) 114 | 115 | elif self.config.rank_type == "prob": 116 | score = logits.gather(dim=-1, index=text_token_id[:, 1:, None]).squeeze(-1) # B*N, L-1 117 | score = score.sum(-1) 118 | 119 | # cross entropy 120 | label = torch.zeros(B, device=self.config.device, dtype=torch.long) 121 | score = score.view(B, M) 122 | loss += F.cross_entropy(score, label) 123 | 124 | return loss 125 | 126 | 127 | def rerank_step(self, x): 128 | x = self._move_to_device(x) 129 | 130 | text_token_id, text_attn_mask = self._prepare_decoder_inputs(x.get("text"), x.get("text_code")) 131 | 132 | logits, token_embedding = self._compute_logits(text_token_id, **x["query"], encoder_outputs=x.get("encoder_outputs")) 133 | 134 | # rank either by the eos token embedding or by the sequence log-probability 135 | if self.config.rank_type == "eos": 136 | valid_token_length = text_attn_mask.sum(dim=-1).long() - 1 # B 137 | eos_embedding = token_embedding[range(valid_token_length.shape[0]), valid_token_length] 138 | score = self.scorer(eos_embedding).squeeze(-1) 139 | elif self.config.rank_type == "prob": 140 | logits = logits.log_softmax(-1) # B*N, L 141 | score = logits.gather(dim=-1, index=text_token_id[:, 1:, None]).squeeze(-1).sum(-1) # B 142 | return score 143 | 144 | -------------------------------------------------------------------------------- /src/models/DSI.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from tqdm import tqdm 5 | from transformers import AutoModelForSeq2SeqLM 6 | from .BaseModel import BaseGenerativeModel 7 | 8 | 9 | class DSI(BaseGenerativeModel): 10 | def __init__(self, config): 11 | super().__init__(config) 12 | 13 | self.plm = AutoModelForSeq2SeqLM.from_pretrained(config.plm_dir) 14 | if config.get("code_size") and self.config.code_size > 0: 15 | self.plm.resize_token_embeddings(config.vocab_size + config.code_size) 16 | 17 | # NOTE: set hidden_size so it can be accessed when using deepspeed 18 | self.config.hidden_size = self.plm.config.hidden_size 19 | 20 | def forward(self, x): 21 | x = self._move_to_device(x) 22 | query = x["query"] 23 | # squeeze the auxiliary dimension 24 | text_code = x["text_code"].squeeze(1) 25 | # the code has a leading 0, shift left one position so t5 can shift it back 26 | labels = torch.zeros_like(text_code) 27 | labels[:, :-1] = text_code[:, 1:] 28 | # replace the padding code with -100 (ignored when computing loss) 29 | labels = labels.masked_fill(labels == -1, -100) 30 | 31 | loss = self.plm(**query, labels=labels).loss 32 | return loss 33 | 34 | def rerank_step(self, x): 35 | """ 36 | Rerank using the sum of the token-level log generation probabilities.
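        Concretely, score(q, d) = sum_t log p(c_t | c_<t, q), where c is the text code that is force-decoded below.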
37 | """ 38 | x = self._move_to_device(x) 39 | query = x["query"] 40 | # starts with 0 41 | text_code = x["text_code"] 42 | text_code[text_code == -1] = 0 43 | logits = self.plm(**query, decoder_input_ids=text_code).logits 44 | logits = logits.log_softmax(-1) 45 | score = logits.gather(dim=-1, index=text_code[:, 1:, None]).squeeze(-1).sum(-1) 46 | return score 47 | 48 | 49 | class GENRE(DSI): 50 | def __init__(self, config): 51 | super().__init__(config) 52 | 53 | def generate_code(self, loaders): 54 | import multiprocessing as mp 55 | from transformers import AutoTokenizer, AutoModel 56 | 57 | assert self.config.code_type == "title" 58 | if self.config.is_main_proc: 59 | from utils.util import _get_title_code, makedirs 60 | # the code is bind to the plm_tokenizer 61 | code_path = os.path.join(self.config.cache_root, "codes", self.config.code_type, self.config.code_tokenizer, str(self.config.code_length), "codes.mmp") 62 | # all codes are led by 0 and padded by -1 63 | self.logger.info(f"generating codes from {self.config.code_type} with code_length: {self.config.code_length}, saving at {code_path}...") 64 | 65 | loader_text = loaders["text"] 66 | text_num = len(loader_text.dataset) 67 | makedirs(code_path) 68 | 69 | # load all saved token ids 70 | text_codes = np.memmap( 71 | code_path, 72 | dtype=np.int32, 73 | mode="w+", 74 | shape=(text_num, self.config.code_length) 75 | ) 76 | tokenizer = AutoTokenizer.from_pretrained(os.path.join(self.config.plm_root, self.config.code_tokenizer)) 77 | model = AutoModel.from_pretrained(os.path.join(self.config.plm_root, self.config.code_tokenizer)) 78 | try: 79 | start_token_id = model._get_decoder_start_token_id() 80 | except ValueError: 81 | start_token_id = model.config.pad_token_id 82 | self.logger.warning(f"Decoder start token id not found, use pad token id ({start_token_id}) instead!") 83 | 84 | # the codes are always led by start_token_id and padded by -1 85 | text_codes[:, 0] = start_token_id 86 | text_codes[:, 1:] = -1 87 | 88 | collection_path = os.path.join(self.config.data_root, self.config.dataset, "collection.tsv") 89 | preprocess_threads = 10 90 | 91 | arguments = [] 92 | for i in range(preprocess_threads): 93 | start_idx = round(text_num * i / preprocess_threads) 94 | end_idx = round(text_num * (i+1) / preprocess_threads) 95 | arguments.append(( 96 | collection_path, 97 | code_path, 98 | text_num, 99 | start_idx, 100 | end_idx, 101 | tokenizer, 102 | self.config.code_length, 103 | self.config.text_col, 104 | # FIXME: add stopwords 105 | None, 106 | self.config.get("code_sep", " "), 107 | self.config.get("dedup_code"), 108 | self.config.get("stem_code"), 109 | self.config.get("filter_num"), 110 | self.config.get("filter_unit") 111 | )) 112 | with mp.Pool(preprocess_threads) as p: 113 | p.starmap(_get_title_code, arguments) 114 | 115 | 116 | class DSIQG(DSI): 117 | def __init__(self, config): 118 | super().__init__(config) 119 | 120 | def generate_code(self, loaders): 121 | from transformers import AutoTokenizer, AutoModel 122 | from utils.util import makedirs 123 | if self.config.is_main_proc: 124 | code_path = os.path.join(self.config.cache_root, "codes", self.config.code_type, self.config.code_tokenizer, str(self.config.code_length), "codes.mmp") 125 | # all codes are led by 0 and padded by -1 126 | self.logger.info(f"generating codes from {self.config.code_type} with code_length: {self.config.code_length}, saving at {code_path}...") 127 | makedirs(code_path) 128 | 129 | loader_text = loaders["text"] 130 | text_num = 
len(loader_text.dataset) 131 | 132 | # load all saved token ids 133 | text_codes = np.memmap( 134 | code_path, 135 | dtype=np.int32, 136 | mode="w+", 137 | shape=(text_num, self.config.code_length) 138 | ) 139 | tokenizer = AutoTokenizer.from_pretrained(os.path.join(self.config.plm_root, self.config.code_tokenizer)) 140 | model = AutoModel.from_pretrained(os.path.join(self.config.plm_root, self.config.code_tokenizer)) 141 | try: 142 | start_token_id = model._get_decoder_start_token_id() 143 | except ValueError: 144 | start_token_id = model.config.pad_token_id 145 | self.logger.warning(f"Decoder start token id not found, use pad token id ({start_token_id}) instead!") 146 | 147 | # the codes are always led by start_token_id and padded by -1 148 | text_codes[:, 0] = start_token_id 149 | text_codes[:, 1:] = -1 150 | 151 | eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id 152 | for i in tqdm(range(text_num)): 153 | code = i 154 | if self.config.get("code_bias"): 155 | code += self.config.code_bias 156 | code = tokenizer.encode(str(code), add_special_tokens=False) 157 | code.append(eos_token_id) 158 | text_codes[i, 1: len(code) + 1] = code 159 | 160 | 161 | -------------------------------------------------------------------------------- /src/models/IVF.py: -------------------------------------------------------------------------------- 1 | import os 2 | import faiss 3 | import torch 4 | import numpy as np 5 | import torch.nn as nn 6 | from .BaseModel import BaseSparseModel 7 | from .UniCOIL import UniCOIL 8 | from utils.util import BaseOutput, synchronize 9 | from utils.index import FaissIndex 10 | from utils.static import * 11 | 12 | 13 | 14 | class IVF(BaseSparseModel): 15 | def __init__(self, config): 16 | super().__init__(config) 17 | 18 | path = os.path.join(config.cache_root, "index", config.vq_src, "faiss", config.vq_index) 19 | self.logger.info(f"loading index from {path}...") 20 | index = faiss.read_index(path) 21 | if isinstance(index, faiss.IndexPreTransform): 22 | ivf_index = faiss.downcast_index(index.index) 23 | vt = faiss.downcast_VectorTransform(index.chain.at(0)) 24 | opq = faiss.vector_to_array(vt.A).reshape(vt.d_out, vt.d_in).T 25 | self.register_buffer("vt", torch.tensor(opq)) 26 | else: 27 | ivf_index = index 28 | 29 | ivf_centroids = FaissIndex.get_xb(index.quantizer) 30 | self.register_buffer("ivfCentroids", torch.tensor(ivf_centroids)) 31 | 32 | invlists = ivf_index.invlists 33 | ivf_codes = np.zeros(index.ntotal, dtype=np.int32) 34 | for i in range(index.nlist): 35 | ls = invlists.list_size(i) 36 | list_ids = faiss.rev_swig_ptr(invlists.get_ids(i), ls) 37 | for j, docid in enumerate(list_ids): 38 | ivf_codes[docid] = i 39 | 40 | self._ivf_codes = ivf_codes 41 | 42 | self._posting_entry_num = ivf_index.nlist 43 | self._skip_special_tokens = False 44 | self._text_length = 1 45 | self._query_length = config.query_gate_k 46 | 47 | 48 | def _encode_query(self, embedding:Optional[TENSOR]) -> TENSOR: 49 | ivf_quantization = embedding.matmul(self.ivfCentroids.transpose(-1, -2)) # B, ncluster 50 | ivf_weight, ivf_id = ivf_quantization.topk(dim=-1, k=self._query_length) 51 | return ivf_id, ivf_weight 52 | 53 | 54 | def encode_query_step(self, x): 55 | query_embedding = x["query_embedding"].to(self.config.device) 56 | if hasattr(self, "vt"): 57 | query_embedding = query_embedding.matmul(self.vt) 58 | query_ivf_id, query_ivf_weight = self._encode_query(query_embedding) 59 | return query_ivf_id.cpu().numpy(), 
query_ivf_weight.unsqueeze(-1).cpu().numpy() 60 | 61 | 62 | @synchronize 63 | @torch.no_grad() 64 | def encode_text(self, loader_text, load_all_encode=False): 65 | text_token_id_path = os.path.join(self.text_dir, "text_token_ids.mmp") 66 | text_embedding_path = os.path.join(self.text_dir, "text_embeddings.mmp") 67 | 68 | if load_all_encode: 69 | text_embeddings = np.memmap( 70 | text_embedding_path, 71 | mode="r", 72 | dtype=np.float32 73 | ).reshape(len(loader_text.dataset), self._text_length, self._output_dim).copy() 74 | text_token_ids = np.memmap( 75 | text_token_id_path, 76 | mode="r", 77 | dtype=np.int32 78 | ).reshape(len(loader_text.dataset), self._text_length).copy() 79 | 80 | elif self.config.load_encode or self.config.load_text_encode: 81 | text_embeddings = np.memmap( 82 | text_embedding_path, 83 | mode="r", 84 | dtype=np.float32 85 | ).reshape(len(loader_text.dataset), self._text_length, self._output_dim)[loader_text.sampler.start: loader_text.sampler.end].copy() 86 | text_token_ids = np.memmap( 87 | text_token_id_path, 88 | mode="r", 89 | dtype=np.int32 90 | ).reshape(len(loader_text.dataset), self._text_length)[loader_text.sampler.start: loader_text.sampler.end].copy() 91 | 92 | else: 93 | self.logger.info(f"encoding {self.config.dataset} text...") 94 | text_token_ids = np.expand_dims(self._ivf_codes[loader_text.sampler.start: loader_text.sampler.end], -1) 95 | text_embeddings = np.ones((*text_token_ids.shape, 1), dtype=np.float32) 96 | 97 | if self.config.save_encode: 98 | self.save_to_mmp( 99 | path=text_embedding_path, 100 | shape=(len(loader_text.dataset), self._text_length, self._output_dim), 101 | dtype=np.float32, 102 | loader=loader_text, 103 | obj=text_embeddings 104 | ) 105 | self.save_to_mmp( 106 | path=text_token_id_path, 107 | shape=(len(loader_text.dataset), self._text_length), 108 | dtype=np.int32, 109 | loader=loader_text, 110 | obj=text_token_ids 111 | ) 112 | 113 | return BaseOutput(embeddings=text_embeddings, token_ids=text_token_ids) 114 | 115 | 116 | class TopIVF(IVF): 117 | """ 118 | Fix IVF assignments, and optimize centroid embeddings. 
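    When no fixed assignment is supplied, the hard cluster choice is made differentiable with a straight-through estimator: the forward pass uses the one-hot assignment while the gradient flows through the softmax over centroids (see _encode_text below).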
119 | """ 120 | def __init__(self, config): 121 | super().__init__(config) 122 | self.ivfCentroids = nn.parameter.Parameter(self.ivfCentroids) 123 | 124 | def _encode_text(self, embedding, text_idx=None): 125 | ivf_assign = embedding.matmul(self.ivfCentroids.transpose(-1, -2)) # B, ncluster 126 | if text_idx is None: 127 | ivf_weight, ivf_id = ivf_assign.max(dim=-1) # B 128 | ivf_assign_soft = torch.softmax(ivf_assign, dim=-1) # B, nlist 129 | ivf_assign_hard = torch.zeros_like(ivf_assign_soft).scatter_(-1, ivf_id.unsqueeze(-1), 1.0) 130 | # straight-through trick 131 | ivf_assign_st = ivf_assign_hard.detach() - ivf_assign_soft.detach() + ivf_assign_soft # B, nlist 132 | quantized_embedding = ivf_assign_st.matmul(self.ivfCentroids) 133 | 134 | else: 135 | ivf_id = torch.as_tensor(self._ivf_codes[text_idx], device=self.config.device, dtype=torch.long) 136 | ivf_weight = torch.ones(ivf_id.shape, device=ivf_id.device) 137 | quantized_embedding = self.ivfCentroids[ivf_id] 138 | 139 | returns = (ivf_id, ivf_weight) 140 | if self.training: 141 | returns = (quantized_embedding, ivf_assign) + returns 142 | return returns 143 | 144 | def forward(self, x): 145 | x = self._move_to_device(x) 146 | 147 | text_idx = x["text_idx"].view(-1).cpu().numpy() 148 | text_embedding = x["text_embedding"].flatten(0, 1) # B*(1+N), D 149 | query_embedding = x["query_embedding"] # B, D 150 | 151 | quantized_text_embedding, ivf_assign, ivf_id, _ = self._encode_text(text_embedding, text_idx) # B*(1+N), D; B*(1+N), nlist; B*(1+N) 152 | 153 | if self.config.is_distributed and self.config.enable_all_gather: 154 | query_embedding = self._gather_tensors(query_embedding) 155 | text_embedding = self._gather_tensors(text_embedding) 156 | quantized_text_embedding = self._gather_tensors(quantized_text_embedding) 157 | ivf_assign = self._gather_tensors(ivf_assign) 158 | ivf_id = self._gather_tensors(ivf_id) 159 | 160 | score_ivf = query_embedding.matmul(quantized_text_embedding.transpose(-1,-2)) # B, B*(1+N) 161 | 162 | B = query_embedding.size(0) 163 | if self.config.enable_inbatch_negative: 164 | label = torch.arange(B, device=self.config.device) 165 | label = label * (text_embedding.shape[0] // query_embedding.shape[0]) 166 | # mask_ind = ivf_id.view(B, -1) # B, 1+N 167 | # # if the negative's label equals to the positive's 168 | # mask_ind = (mask_ind.T == mask_ind[:, 0]).T # B, 1+N 169 | # mask_ind[:, 0] = False 170 | # # mask the conflicting labels 171 | # label[mask_ind.view(-1)] = -100 172 | else: 173 | label = torch.zeros(B, dtype=torch.long, device=self.config.device) 174 | score_ivf = score_ivf.view(B, B, -1)[range(B), range(B)] # B, 1+N 175 | 176 | loss = self._compute_loss(score_ivf, label) 177 | 178 | if self.config.enable_commit_loss: 179 | score_commit = ivf_assign # B*(1+N), nlist 180 | label_commit = ivf_id # B*(1+N) 181 | loss += self._compute_loss(score_commit, label_commit) 182 | 183 | return loss 184 | 185 | 186 | class TokIVF(UniCOIL): 187 | """ 188 | Uses explicit tokens as IVF entries. 
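    Queries are kept as raw token ids with unit weights (see _encode_query), so every posting list is keyed by a vocabulary token rather than a learned centroid.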
189 | """ 190 | def __init__(self, config): 191 | super().__init__(config) 192 | self.queryEncoder = None 193 | 194 | 195 | def _encode_query(self, **kwargs): 196 | token_embedding = torch.ones_like(kwargs["input_ids"], dtype=torch.float).unsqueeze(-1) 197 | return token_embedding 198 | 199 | 200 | def encode_query_step(self, x): 201 | return BaseSparseModel.encode_query_step(self, x) 202 | 203 | -------------------------------------------------------------------------------- /src/models/BOW.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .DSI import DSI 4 | from transformers import AutoTokenizer 5 | 6 | 7 | class BOW(DSI): 8 | def __init__(self, config): 9 | super().__init__(config) 10 | 11 | def forward(self, x): 12 | x = self._move_to_device(x) 13 | # squeeze the auxillary dimension 14 | if "query_code" in x: 15 | text_code = x["query_code"] 16 | else: 17 | text_code = x["text_code"] 18 | 19 | # in case there are multiple codes for one text (config.permute_code > 0) 20 | encoder_outputs = self.plm.encoder(**x["query"]) 21 | query_attn_mask = x["query"]["attention_mask"] 22 | B, M = text_code.shape[:2] 23 | for k, v in encoder_outputs.items(): 24 | encoder_outputs[k] = v.repeat_interleave(M, 0) 25 | query_attn_mask = query_attn_mask.repeat_interleave(M, 0) # B*M, L 26 | 27 | text_code = text_code.flatten(0,1) # B*M, CL 28 | # the code has a leading 0, shift left one position so t5 can shift it back 29 | # default to -1 30 | labels = torch.zeros_like(text_code) - 1 31 | labels[:, :-1] = text_code[:, 1:] 32 | # replace the padding code with the -100 (ignored when computing loss) 33 | labels = labels.masked_fill(labels == -1, -100) 34 | logits = self.plm(attention_mask=query_attn_mask, encoder_outputs=encoder_outputs, labels=labels).logits # B*M, CL, V 35 | 36 | # ignore_index defaults to -100 37 | loss = nn.functional.cross_entropy(logits.flatten(0,1), labels.view(-1), reduction="none").view(B, M, -1).mean(-1) # B, M 38 | # sum 39 | if self.config.reduce_code == "mean": 40 | loss = loss.mean() 41 | elif self.config.reduce_code == "min": 42 | min_loss, min_index = loss.min(-1) 43 | loss = min_loss.mean() 44 | else: 45 | raise NotImplementedError(f"Reduction type {self.config.reduce_code} is not implemented!") 46 | return loss 47 | 48 | def generate_code(self, loaders): 49 | """ 50 | Greedily decode the keywords. 51 | """ 52 | import os 53 | import numpy as np 54 | from utils.util import synchronize, makedirs 55 | assert self.config.index_type == "wordset", "Only wordset index can be used when sorting keywords!" 
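        # Rough flow of the sort_code branch below (a descriptive sketch of the existing code, not new behaviour):
        # build the wordset index, iterate over (query, positive text) pairs, let the sorter re-order each text's
        # keyword code under the constraint of that text's own word set, and write the results into a codes.mmp
        # memmap of shape (num_qrels, nseq, code_length), padded with -1.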
56 | 57 | if self.config.get("sort_code"): 58 | from tqdm import tqdm 59 | from utils.index import BeamDecoder, GreedyCodeSorter 60 | from utils.data import prepare_train_data 61 | index = self.index(loaders).index 62 | 63 | text_dataset = loaders["text"].dataset 64 | # set necessary attributes to enable loader_train 65 | self.config.loader_train = "neg" 66 | self.config.train_set = [self.config.eval_set] 67 | self.config.neg_type = "none" 68 | self.config.batch_size = self.config.eval_batch_size 69 | 70 | loader_train = prepare_train_data(self.config, text_dataset, return_dataloader=True) 71 | query_dataset = loader_train.dataset.query_datasets[0] 72 | 73 | code_path = os.path.join(self.code_dir, self.config.code_src, query_dataset.query_set, "codes.mmp") 74 | makedirs(code_path) 75 | 76 | self.logger.info(f"sorting keywords from {self.config.code_type} and saving at {code_path}...") 77 | 78 | if self.config.get("nbeam", 1) > 1: 79 | sorter = BeamDecoder() 80 | nseq = self.config.nbeam 81 | else: 82 | sorter = GreedyCodeSorter() 83 | nseq = self.config.get("decode_nseq", 1) 84 | 85 | if self.config.is_main_proc: 86 | query_codes = np.memmap( 87 | code_path, 88 | dtype=np.int32, 89 | shape=(len(loader_train.dataset), nseq, self.config.code_length), 90 | mode="w+" 91 | ) 92 | # default to -1 to be used as padding 93 | query_codes[:] = -1 94 | synchronize() 95 | query_codes = np.memmap( 96 | code_path, 97 | dtype=np.int32, 98 | mode="r+" 99 | ).reshape(len(loader_train.dataset), nseq, self.config.code_length) 100 | 101 | tokenizer = AutoTokenizer.from_pretrained(os.path.join(self.config.plm_root, self.config.code_tokenizer)) 102 | 103 | start_idx = 0 104 | for i, x in enumerate(tqdm(loader_train, leave=False, ncols=100)): 105 | # if i < 69198: 106 | # continue 107 | qrel_idx = x["qrel_idx"] 108 | query = self._move_to_device(x["query"]) 109 | 110 | encoder_outputs = self.plm.encoder(**query) 111 | B = query["input_ids"].shape[0] 112 | end_idx = start_idx + B 113 | 114 | sorter.search( 115 | model=self.plm, 116 | query={**query, "encoder_outputs": encoder_outputs}, 117 | nbeam=self.config.nbeam, 118 | max_new_tokens=self.config.code_length - 1, 119 | constrain_index=index, 120 | text_indices=x["text_idx"].squeeze(1).numpy(), 121 | # forbid early stop as we must generate the entire sequence 122 | do_early_stop=False, 123 | do_sample=self.config.get("decode_do_sample", False), 124 | num_return_sequences=self.config.get("decode_nseq", 1), 125 | temperature=self.config.get("decode_tau", 1), 126 | ) 127 | # print(tokenizer.decode(x["query"]["input_ids"][0], skip_special_tokens=True), tokenizer.batch_decode(sorter.beams[0])) 128 | # input() 129 | 130 | res = sorter.beams # batch_size, nseq, code_length 131 | 132 | # assign query_codes 133 | for j, y in enumerate(res): 134 | length = len(y[0]) 135 | query_codes[qrel_idx[j], :, :length] = y 136 | 137 | start_idx = end_idx 138 | if self.config.debug: 139 | if i > 2: 140 | break 141 | 142 | if self.config.is_main_proc: 143 | same_count = 0 144 | text_codes = text_dataset.text_codes 145 | for qrel in loader_train.dataset.qrels: 146 | qrel_idx, query_set_idx, query_idx, text_idx = qrel 147 | query_code = query_codes[qrel_idx] 148 | text_code = text_codes[text_idx] 149 | for qc in query_code: 150 | if (qc == text_code).all(): 151 | same_count += 1 152 | 153 | self.logger.info(f"{same_count}/{query_codes.shape[0] * query_codes.shape[1]} query codes are identical to the text codes!") 154 | 155 | else: 156 | if self.config.code_type == "chat": 157 | import 
multiprocessing as mp 158 | from utils.util import _get_chatgpt_code 159 | 160 | # the code is bind to the code_tokenizer 161 | code_path = os.path.join(self.config.cache_root, "codes", self.config.code_type, self.config.code_tokenizer, str(self.config.code_length), "codes.mmp") 162 | self.logger.info(f"generating codes from {self.config.code_type} with code_length: {self.config.code_length}, saving at {code_path}...") 163 | 164 | loader_text = loaders["text"] 165 | text_num = len(loader_text.dataset) 166 | makedirs(code_path) 167 | 168 | tokenizer = AutoTokenizer.from_pretrained(os.path.join(self.config.plm_root, self.config.code_tokenizer)) 169 | 170 | # load all saved token ids 171 | # all codes are led by 0 and padded by -1 172 | text_codes = np.memmap( 173 | code_path, 174 | dtype=np.int32, 175 | mode="w+", 176 | shape=(text_num, self.config.code_length) 177 | ) 178 | # the codes are always led by 0 and padded by -1 179 | text_codes[:, 0] = tokenizer.pad_token_id 180 | text_codes[:, 1:] = -1 181 | 182 | thread_num = 10 183 | # each thread creates one jsonl file 184 | text_num_per_thread = text_num / thread_num 185 | 186 | arguments = [] 187 | # re-tokenize words in the collection folder 188 | for i in range(thread_num): 189 | input_path = os.path.join(self.config.data_root, self.config.dataset, "keywords.tsv") 190 | start_idx = round(text_num_per_thread * i) 191 | end_idx = round(text_num_per_thread * (i+1)) 192 | 193 | arguments.append(( 194 | input_path, 195 | code_path, 196 | text_num, 197 | start_idx, 198 | end_idx, 199 | tokenizer, 200 | self.config.code_length, 201 | )) 202 | 203 | # the collection has no special_tokens so we don't need to filter them out 204 | with mp.Pool(thread_num) as p: 205 | p.starmap(_get_chatgpt_code, arguments) 206 | 207 | -------------------------------------------------------------------------------- /src/models/TSGen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .DSI import DSI 4 | from transformers import AutoTokenizer 5 | 6 | 7 | class TSGen(DSI): 8 | def __init__(self, config): 9 | super().__init__(config) 10 | 11 | def forward(self, x): 12 | x = self._move_to_device(x) 13 | # squeeze the auxillary dimension 14 | if "query_code" in x: 15 | text_code = x["query_code"] 16 | else: 17 | text_code = x["text_code"] 18 | invalid_code_pos = (text_code == -1).all(-1) 19 | 20 | # in case there are multiple codes for one text (config.permute_code > 0) 21 | encoder_outputs = self.plm.encoder(**x["query"]) 22 | query_attn_mask = x["query"]["attention_mask"] 23 | B, M = text_code.shape[:2] 24 | for k, v in encoder_outputs.items(): 25 | encoder_outputs[k] = v.repeat_interleave(M, 0) 26 | query_attn_mask = query_attn_mask.repeat_interleave(M, 0) # B*M, L 27 | 28 | text_code = text_code.flatten(0,1) # B*M, CL 29 | # the code has a leading 0, shift left one position so t5 can shift it back 30 | # default to -1 31 | labels = torch.zeros_like(text_code) - 1 32 | labels[:, :-1] = text_code[:, 1:] 33 | # replace the padding code with the -100 (ignored when computing loss) 34 | labels = labels.masked_fill(labels == -1, -100) 35 | logits = self.plm(attention_mask=query_attn_mask, encoder_outputs=encoder_outputs, labels=labels).logits # B*M, CL, V 36 | 37 | # ignore_index defaults to -100 38 | loss = nn.functional.cross_entropy(logits.flatten(0,1), labels.view(-1), reduction="none").view(B, M, -1).mean(-1) # B, M 39 | # sum 40 | if self.config.reduce_code == "mean": 41 | 
loss[invalid_code_pos] = 0 42 | loss = loss.mean() 43 | elif self.config.reduce_code == "min": 44 | loss[invalid_code_pos] = 1e6 45 | min_loss, min_index = loss.min(-1) 46 | loss = min_loss.mean() 47 | else: 48 | raise NotImplementedError(f"Reduction type {self.config.reduce_code} is not implemented!") 49 | return loss 50 | 51 | def generate_code(self, loaders): 52 | """ 53 | Greedily decode the keywords. 54 | """ 55 | import os 56 | import numpy as np 57 | from utils.util import synchronize, makedirs 58 | assert self.config.index_type == "wordset", "Only wordset index can be used when sorting keywords!" 59 | 60 | if self.config.get("sort_code"): 61 | from tqdm import tqdm 62 | from utils.index import BeamDecoder 63 | from utils.data import prepare_train_data 64 | index = self.index(loaders).index 65 | 66 | text_dataset = loaders["text"].dataset 67 | # set necessary attributes to enable loader_train 68 | self.config.loader_train = "neg" 69 | self.config.train_set = [self.config.eval_set] 70 | self.config.neg_type = "none" 71 | self.config.batch_size = self.config.eval_batch_size 72 | 73 | loader_train = prepare_train_data(self.config, text_dataset, return_dataloader=True) 74 | query_dataset = loader_train.dataset.query_datasets[0] 75 | 76 | code_path = os.path.join(self.code_dir, self.config.code_src, query_dataset.query_set, "codes.mmp") 77 | makedirs(code_path) 78 | 79 | self.logger.info(f"sorting keywords from {self.config.code_type} and saving at {code_path}...") 80 | 81 | sorter = BeamDecoder() 82 | nseq = self.config.nbeam 83 | 84 | if self.config.is_main_proc: 85 | query_codes = np.memmap( 86 | code_path, 87 | dtype=np.int32, 88 | shape=(len(loader_train.dataset), nseq, self.config.code_length), 89 | mode="w+" 90 | ) 91 | # default to -1 to be used as padding 92 | query_codes[:] = -1 93 | synchronize() 94 | query_codes = np.memmap( 95 | code_path, 96 | dtype=np.int32, 97 | mode="r+" 98 | ).reshape(len(loader_train.dataset), nseq, self.config.code_length) 99 | 100 | tokenizer = AutoTokenizer.from_pretrained(os.path.join(self.config.plm_root, self.config.code_tokenizer)) 101 | 102 | start_idx = 0 103 | for i, x in enumerate(tqdm(loader_train, leave=False, ncols=100)): 104 | # if i < 328: 105 | # continue 106 | qrel_idx = x["qrel_idx"] 107 | query = self._move_to_device(x["query"]) 108 | 109 | encoder_outputs = self.plm.encoder(**query) 110 | B = query["input_ids"].shape[0] 111 | end_idx = start_idx + B 112 | 113 | sorter.search( 114 | model=self.plm, 115 | query={**query, "encoder_outputs": encoder_outputs}, 116 | nbeam=self.config.nbeam, 117 | max_new_tokens=self.config.code_length - 1, 118 | constrain_index=index, 119 | text_indices=x["text_idx"].squeeze(1).numpy(), 120 | tokenizer=tokenizer, 121 | do_sample=self.config.decode_do_sample, 122 | do_greedy=self.config.decode_do_greedy, 123 | topk=self.config.sample_topk, 124 | topp=float(self.config.sample_topp) if self.config.sample_topp is not None else None, 125 | typical_p=float(self.config.sample_typicalp) if self.config.sample_typicalp is not None else None, 126 | temperature=float(self.config.sample_tau) if self.config.sample_tau is not None else None, 127 | renormalize_logits=self.config.decode_renorm_logit, 128 | # forbid early stop as we must generate the entire sequence 129 | do_early_stop=False, 130 | ) 131 | res = sorter.beams # batch_size, nseq, code_length 132 | 133 | # assign query_codes 134 | for j, y in enumerate(res): 135 | length = len(y[0]) 136 | try: 137 | query_codes[qrel_idx[j], :len(y), :length] = y 138 | 
except: 139 | print(i, j, self.config.rank) 140 | raise 141 | 142 | start_idx = end_idx 143 | if self.config.debug: 144 | if i > 2: 145 | break 146 | 147 | if self.config.is_main_proc: 148 | same_count = 0 149 | text_codes = text_dataset.text_codes 150 | for qrel in loader_train.dataset.qrels: 151 | qrel_idx, query_set_idx, query_idx, text_idx = qrel 152 | query_code = query_codes[qrel_idx] 153 | text_code = text_codes[text_idx] 154 | for qc in query_code: 155 | if (qc == text_code).all(): 156 | same_count += 1 157 | 158 | self.logger.info(f"{same_count}/{query_codes.shape[0] * query_codes.shape[1]} query codes are identical to the text codes!") 159 | 160 | else: 161 | if self.config.code_type == "chat": 162 | import multiprocessing as mp 163 | from utils.util import _get_chatgpt_code 164 | 165 | # the code is bind to the code_tokenizer 166 | code_path = os.path.join(self.config.cache_root, "codes", self.config.code_type, self.config.code_tokenizer, str(self.config.code_length), "codes.mmp") 167 | self.logger.info(f"generating codes from {self.config.code_type} with code_length: {self.config.code_length}, saving at {code_path}...") 168 | 169 | loader_text = loaders["text"] 170 | text_num = len(loader_text.dataset) 171 | makedirs(code_path) 172 | 173 | tokenizer = AutoTokenizer.from_pretrained(os.path.join(self.config.plm_root, self.config.code_tokenizer)) 174 | 175 | # load all saved token ids 176 | # all codes are led by 0 and padded by -1 177 | text_codes = np.memmap( 178 | code_path, 179 | dtype=np.int32, 180 | mode="w+", 181 | shape=(text_num, self.config.code_length) 182 | ) 183 | # the codes are always led by 0 and padded by -1 184 | text_codes[:, 0] = tokenizer.pad_token_id 185 | text_codes[:, 1:] = -1 186 | 187 | thread_num = 10 188 | # each thread creates one jsonl file 189 | text_num_per_thread = text_num / thread_num 190 | 191 | arguments = [] 192 | # re-tokenize words in the collection folder 193 | for i in range(thread_num): 194 | input_path = os.path.join(self.config.data_root, self.config.dataset, "keywords.tsv") 195 | start_idx = round(text_num_per_thread * i) 196 | end_idx = round(text_num_per_thread * (i+1)) 197 | 198 | arguments.append(( 199 | input_path, 200 | code_path, 201 | text_num, 202 | start_idx, 203 | end_idx, 204 | tokenizer, 205 | self.config.code_length, 206 | )) 207 | 208 | # the collection has no special_tokens so we don't need to filter them out 209 | with mp.Pool(thread_num) as p: 210 | p.starmap(_get_chatgpt_code, arguments) 211 | 212 | -------------------------------------------------------------------------------- /src/scripts/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | import numpy as np 5 | from tqdm import tqdm 6 | from multiprocessing import Pool 7 | from collections import defaultdict 8 | from transformers import AutoTokenizer 9 | from utils.util import save_pickle, load_pickle, Config 10 | from utils.static import * 11 | 12 | import hydra 13 | from pathlib import Path 14 | from omegaconf import OmegaConf 15 | @hydra.main(version_base=None, config_path="../data/config/", config_name=f"script/{Path(__file__).stem}") 16 | def get_config(hydra_config: OmegaConf): 17 | config._from_hydra(hydra_config) 18 | 19 | 20 | def init_text(collection_path:str, cache_dir:str) -> ID_MAPPING: 21 | """ 22 | convert document ids to offsets 23 | """ 24 | tid2index = {} 25 | os.makedirs(cache_dir, exist_ok=True) 26 | with open(collection_path, "r", encoding="utf-8") 
as f: 27 | for i, line in enumerate(tqdm(f, desc="Collecting Text IDs", ncols=100, leave=False)): 28 | tid = line.split('\t')[0].strip() 29 | tid2index[tid] = len(tid2index) 30 | save_pickle(tid2index, os.path.join(cache_dir, "id2index.pkl")) 31 | return tid2index 32 | 33 | 34 | def init_query_and_qrel(query_path:str, qrel_path:str, cache_dir:str, tid2index:ID_MAPPING): 35 | """ 36 | Tokenize query file and transfer passage/document/query id in qrel file to its index in the saved token-ids matrix. 37 | 38 | Args: 39 | query_path: query file path 40 | qrel_path: qrel file path 41 | cache_dir: the directory to save files 42 | tid2index: mapping from text ids to text indices 43 | """ 44 | os.makedirs(cache_dir, exist_ok=True) 45 | 46 | valid_queries = set() 47 | for line in open(qrel_path, 'r', encoding='utf-8'): 48 | try: 49 | query_id, _, positive_text_id, _ = line.strip().split() 50 | except: 51 | raise ValueError(f"Invalid format: {line}") 52 | if query_id in valid_queries: 53 | pass 54 | else: 55 | valid_queries.add(query_id) 56 | 57 | print("valid query number: {}".format(len(valid_queries))) 58 | 59 | qid2index = {} 60 | invalid_count = 0 61 | tmp_query_path = ".".join([*query_path.split(".")[:-1], "tmp", "tsv"]) 62 | 63 | with open(query_path, "r", encoding="utf-8") as f, \ 64 | open(tmp_query_path, "w", encoding="utf-8") as g: 65 | for i, line in enumerate(tqdm(f, desc="Removing Missing Queries", ncols=100, leave=False)): 66 | query_id = line.split('\t')[0] 67 | if query_id not in valid_queries: 68 | invalid_count += 1 69 | continue 70 | qid2index[query_id] = len(qid2index) 71 | g.write(line) 72 | 73 | if invalid_count > 0: 74 | # backup queries that appear in the query file but not in the qrel file 75 | backup_query_path = ".".join([*query_path.split(".")[:-1], "backup", "tsv"]) 76 | print(f"There are {invalid_count} queries that appear in the query file but not in the qrel file! 
The original query file is saved at {backup_query_path}.") 77 | os.rename(query_path, backup_query_path) 78 | else: 79 | os.remove(query_path) 80 | os.rename(tmp_query_path, query_path) 81 | 82 | qrels = [] 83 | positives = defaultdict(list) 84 | with open(qrel_path, "r") as g: 85 | for line in tqdm(g, ncols=100, leave=False, desc="Processing Qrels"): 86 | fields = line.split() 87 | if len(fields) == 4: 88 | query_id, _, text_id, _ = fields 89 | elif len(fields) == 2: 90 | query_id, text_id = fields 91 | else: 92 | raise NotImplementedError 93 | query_index = qid2index[query_id] 94 | text_index = tid2index[text_id] 95 | qrels.append((query_index, text_index)) 96 | # there may be multiple positive samples correpsonding to one query 97 | positives[query_index].append(text_index) 98 | 99 | save_pickle(qid2index, os.path.join(cache_dir, "id2index.pkl")) 100 | save_pickle(qrels, os.path.join(cache_dir, "qrels.pkl")) 101 | save_pickle(dict(positives), os.path.join(cache_dir, "positives.pkl")) 102 | return qid2index 103 | 104 | 105 | def tokenize_to_memmap(input_path:str, cache_dir:str, num_rec:int, max_length:int, tokenizer:Any, tokenizer_type:str, tokenize_thread:int, text_col:list[int]=None, text_col_sep:str=None, is_query:bool=False) -> ID_MAPPING: 106 | """ 107 | tokenize the passage/document text in multiple threads 108 | 109 | Args: 110 | input_path: query/passage/document file path 111 | cache_dir: save the output token ids etc 112 | num_rec: the number of records 113 | max_length: max length of tokens 114 | tokenizer(transformers.Tokenizer) 115 | tokenizer_type: the actual tokenizer vocabulary used 116 | tokenize_thread 117 | text_col 118 | text_col_sep 119 | is_query: if the input is a query 120 | 121 | Returns: 122 | mapping from the id to the index in the saved token-id matrix 123 | """ 124 | cache_dir_with_plm = os.path.join(cache_dir, tokenizer_type) 125 | os.makedirs(cache_dir_with_plm, exist_ok=True) 126 | 127 | print(f"tokenizing {input_path} in {tokenize_thread} threads, output file will be saved at {cache_dir_with_plm}") 128 | 129 | arguments = [] 130 | 131 | memmap_path = os.path.join(cache_dir_with_plm, "token_ids.mmp") 132 | # remove old memmap file 133 | if os.path.exists(memmap_path): 134 | os.remove(memmap_path) 135 | 136 | # create memmap first 137 | token_ids = np.memmap( 138 | memmap_path, 139 | shape=(num_rec, max_length), 140 | mode="w+", 141 | dtype=np.int32 142 | ) 143 | 144 | for i in range(tokenize_thread): 145 | start_idx = round(num_rec * i / tokenize_thread) 146 | end_idx = round(num_rec * (i+1) / tokenize_thread) 147 | arguments.append((input_path, cache_dir_with_plm, num_rec, start_idx, end_idx, tokenizer, max_length, text_col, text_col_sep, is_query)) 148 | 149 | with Pool(tokenize_thread) as p: 150 | p.starmap(_tokenize_to_memmap, arguments) 151 | 152 | 153 | 154 | def _tokenize_to_memmap(input_path:str, output_dir:str, num_rec:int, start_idx:int, end_idx:int, tokenizer:Any, max_length:int, text_col:list[int]=None, text_col_sep:str=None, is_query:bool=False): 155 | """ 156 | #. Tokenize the input text; 157 | 158 | #. do padding and truncation; 159 | 160 | #. 
then save the token ids, token_lengths, text ids 161 | 162 | Args: 163 | input_path: input text file path 164 | output_dir: directory of output numpy arrays 165 | start_idx: the begining index to read 166 | end_idx: the ending index 167 | tokenizer: transformer tokenizer 168 | max_length: max length of tokens 169 | text_col: 170 | text_col_sep: 171 | is_query 172 | """ 173 | # some models such as t5 doesn't have sep token, we use space instead 174 | separator = text_col_sep 175 | pad_token_id = -1 176 | 177 | token_ids = np.memmap( 178 | os.path.join(output_dir, "token_ids.mmp"), 179 | shape=(num_rec, max_length), 180 | mode="r+", 181 | dtype=np.int32 182 | ) 183 | 184 | with open(input_path, 'r') as f: 185 | pbar = tqdm(total=end_idx-start_idx, desc="Tokenizing", ncols=100, leave=False) 186 | for idx, line in enumerate(f): 187 | if idx < start_idx: 188 | continue 189 | if idx >= end_idx: 190 | break 191 | 192 | columns = line.split('\t') 193 | 194 | if is_query: 195 | # query has only one textual column 196 | text = columns[-1].strip() 197 | else: 198 | text = [] 199 | for col_idx in text_col: 200 | text.append(columns[col_idx].strip()) 201 | text = separator.join(text) 202 | 203 | # only encode text 204 | token_id = tokenizer.encode(text, add_special_tokens=False, truncation=True, max_length=max_length) 205 | token_length = len(token_id) 206 | 207 | token_ids[idx, :] = token_id + [pad_token_id] * (max_length - token_length) 208 | pbar.update(1) 209 | pbar.close() 210 | 211 | 212 | if __name__ == "__main__": 213 | # manually action="store_true" because hydra doesn't support it 214 | for i, arg in enumerate(sys.argv): 215 | if "=" not in arg: 216 | sys.argv[i] += "=true" 217 | 218 | config = Config() 219 | get_config() 220 | 221 | cache_dir = os.path.join(config.cache_root, "dataset") 222 | data_dir = os.path.join(config.data_root, config.dataset) 223 | text_dir = os.path.join(cache_dir, "text") 224 | 225 | tokenizer = AutoTokenizer.from_pretrained(config.plm_dir) 226 | tokenizer_type = config.plm_tokenizer 227 | 228 | collection_path = os.path.join(data_dir, "collection.tsv") 229 | 230 | if config.do_text: 231 | tid2index = init_text(collection_path, text_dir) 232 | if config.pretokenize: 233 | tokenize_to_memmap(collection_path, os.path.join(text_dir, ','.join([str(x) for x in config.text_col])), len(tid2index), config.max_text_length, tokenizer, tokenizer_type, config.tokenize_thread, text_col=config.text_col, text_col_sep=config.text_col_sep) 234 | 235 | if config.do_query: 236 | tid2index = load_pickle(os.path.join(text_dir, "id2index.pkl")) 237 | 238 | for query_set in config.query_set: 239 | query_path = os.path.join(data_dir, f"queries.{query_set}.tsv") 240 | qrel_path = os.path.join(data_dir, f"qrels.{query_set}.tsv") 241 | qid2index = init_query_and_qrel(query_path, qrel_path, os.path.join(cache_dir, "query", query_set), tid2index) 242 | if config.pretokenize: 243 | tokenize_to_memmap(os.path.join(data_dir, f"queries.{query_set}.tsv"), os.path.join(cache_dir, "query", query_set), len(qid2index), config.max_query_length, tokenizer, tokenizer_type, config.tokenize_thread, is_query=True) 244 | -------------------------------------------------------------------------------- /src/models/UniCOIL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .COIL import COIL 4 | 5 | 6 | 7 | class UniCOIL(COIL): 8 | def __init__(self, config): 9 | config.token_dim = 1 10 | super().__init__(config) 11 | 12 | 
plm_dim = self.textEncoder.config.hidden_size 13 | self.tokenProject = nn.Sequential( 14 | nn.Linear(plm_dim, plm_dim), 15 | nn.ReLU(), 16 | nn.Linear(plm_dim, 1), 17 | nn.ReLU() 18 | ) 19 | 20 | self.special_token_ids = [x[1] for x in config.special_token_ids.values() if x[0] is not None] 21 | 22 | 23 | def _to_bow(self, token_ids, token_weights): 24 | """ 25 | Convert the token sequence (maybe repetitive tokens) into BOW (no repetitive tokens except pad token) 26 | 27 | Args: 28 | token_ids: tensor of B, L 29 | token_weights: tensor of B, L, 1 30 | 31 | Returns: 32 | bow representation of B, V 33 | """ 34 | # create the src 35 | dest = torch.zeros((*token_ids.shape, self.config.vocab_size), device=token_ids.device) - 1 # B, L, V 36 | bow = torch.scatter(dest, dim=-1, index=token_ids.unsqueeze(-1), src=token_weights) 37 | bow = bow.max(dim=1)[0] # B, V 38 | # only pad token and the tokens with positive weights are valid 39 | bow[:, self.special_token_ids] = 0 40 | return bow 41 | 42 | 43 | def encode_text_step(self, x): 44 | text = self._move_to_device(x["text"]) 45 | text_token_id = text["input_ids"] 46 | text_token_weight = self._encode_text(**text) 47 | 48 | if "text_first_mask" in x: 49 | text_bow = self._to_bow(text_token_id, text_token_weight) 50 | text_token_weight = text_bow.gather(index=text_token_id, dim=-1) 51 | 52 | # mask the duplicated tokens' weight 53 | text_first_mask = self._move_to_device(x["text_first_mask"]) 54 | # mask duplicated tokens' id 55 | text_token_id = text_token_id.masked_fill(~text_first_mask, 0) 56 | text_token_weight = text_token_weight.masked_fill(~text_first_mask, 0).unsqueeze(-1) 57 | 58 | # unsqueeze to map it to the _output_dim (1) 59 | return text_token_id.cpu().numpy(), text_token_weight.cpu().numpy() 60 | 61 | 62 | def encode_query_step(self, x): 63 | query = self._move_to_device(x["query"]) 64 | query_token_id = query["input_ids"] 65 | 66 | query_token_weight = self._encode_query(**query) 67 | query_token_weight *= query["attention_mask"].unsqueeze(-1) 68 | 69 | # unsqueeze to map it to the _output_dim (1) 70 | return query_token_id.cpu().numpy(), query_token_weight.cpu().numpy() 71 | 72 | 73 | # FIXME: refactor 74 | # @torch.no_grad() 75 | # def generate_code(self, loaders): 76 | # if self.config.get("sort_code"): 77 | # import os 78 | # import numpy as np 79 | # from tqdm import tqdm 80 | # from transformers import AutoTokenizer, AutoModel 81 | # from utils.index import convert_tokens_to_words, subword_to_word_bert 82 | # from utils.data import prepare_train_data 83 | # from utils.util import synchronize, makedirs 84 | 85 | # text_dataset = loaders["text"].dataset 86 | # # set necessary attributes to enable loader_train 87 | # self.config.loader_train = "neg" 88 | # self.config.train_set = [self.config.eval_set] 89 | # self.config.neg_type = "none" 90 | # self.config.batch_size = self.config.eval_batch_size 91 | 92 | # loader_train = prepare_train_data(self.config, text_dataset, return_dataloader=True) 93 | # query_dataset = loader_train.dataset.query_datasets[0] 94 | 95 | # code_path = os.path.join(self.config.cache_root, "codes", self.config.code_type, self.config.code_tokenizer, str(self.config.code_length), self.config.code_src, query_dataset.query_set, "codes.mmp") 96 | # makedirs(code_path) 97 | 98 | # self.logger.info(f"sorting keywords from {self.config.code_type} and saving at {code_path}...") 99 | 100 | # text_tokenizer = AutoTokenizer.from_pretrained(self.config.plm_dir) 101 | # code_tokenizer = 
AutoTokenizer.from_pretrained(os.path.join(self.config.plm_root, self.config.code_tokenizer)) 102 | 103 | # special_tokens = set([x[0] for x in self.config.special_token_ids.values()]) 104 | 105 | # # load all encoded results 106 | # # text_outputs = self.encode_text(loaders["text"], load_all_cache=True) 107 | # # query_outputs = self.encode_query(loaders["query"], load_all_cache=True) 108 | # # text_token_ids = text_outputs.token_ids 109 | # # text_token_weights = text_outputs.embeddings.squeeze(-1) 110 | # # query_token_ids = query_outputs.token_ids 111 | # # query_token_weights = query_outputs.embeddings.squeeze(-1) 112 | 113 | # if self.config.is_main_proc: 114 | # query_codes = np.memmap( 115 | # code_path, 116 | # dtype=np.int32, 117 | # shape=(len(loader_train.dataset), self.config.code_length), 118 | # mode="w+" 119 | # ) 120 | # model = AutoModel.from_pretrained(os.path.join(self.config.plm_root, self.config.code_tokenizer)) 121 | # try: 122 | # start_token_id = model._get_decoder_start_token_id() 123 | # except ValueError: 124 | # start_token_id = model.config.pad_token_id 125 | # self.logger.warning(f"Decoder start token id not found, use pad token id ({start_token_id}) instead!") 126 | # # the codes are always led by start_token_id and padded by -1 127 | # query_codes[:, 0] = start_token_id 128 | 129 | # synchronize() 130 | # query_codes = np.memmap( 131 | # code_path, 132 | # dtype=np.int32, 133 | # mode="r+" 134 | # ).reshape(len(loader_train.dataset), self.config.code_length) 135 | 136 | # for i, x in enumerate(tqdm(loader_train, leave=False, ncols=100)): 137 | # qrel_idx = x["qrel_idx"].numpy() 138 | 139 | # # squeeze the second dimension (text_num) 140 | # for k, v in x.items(): 141 | # if k == "text": 142 | # for sk, sv in v.items(): 143 | # x[k][sk] = sv.squeeze(1) 144 | # elif "text" in k: 145 | # x[k] = v.squeeze(1) 146 | 147 | # text_code = x["text_code"] 148 | 149 | # x = self._move_to_device(x) 150 | # text_token_id = x["text"]["input_ids"] # B, LS 151 | # query_token_id = x["query"]["input_ids"] # B, LQ 152 | 153 | # text_token_weight = self._encode_text(**x["text"]) # B, LS, 1 154 | # query_token_weight = self._encode_query(**x["query"]) # B, LQ, 1 155 | 156 | # overlap = query_token_id.unsqueeze(-1) == text_token_id.unsqueeze(1) # B, LS, LQ 157 | # # accumulate query_token_weight when there is overlap 158 | # text_token_weight = text_token_weight.squeeze(-1) + (query_token_weight * overlap).sum(1) # B, LS 159 | # text_token_weight = text_token_weight.cpu().numpy() 160 | 161 | # for j, token_id in enumerate(text_token_id): 162 | # token_weight = text_token_weight[j] # LS 163 | # tokens = text_tokenizer.convert_ids_to_tokens(token_id) 164 | # words, weights = convert_tokens_to_words(tokens, subword_to_word_bert, scores=token_weight, reduce="max") 165 | 166 | # # a dict mapping the compressed phrase tokens to the phrase weight 167 | # word_weight_dict = {} 168 | # phrase_weight = 0 169 | # phrase_weights = [] 170 | # # collect the accumulated weights for each phrase (comma separated) 171 | # for word, weight in zip(words, weights): 172 | # if word in special_tokens: 173 | # continue 174 | # # comma will be a standalone token 175 | # elif word == self.config.code_sep: 176 | # # compress all tokens to overcome the space issues from different tokenizers 177 | # phrase_weights.append(phrase_weight) 178 | # phrase_weight = 0 179 | # else: 180 | # phrase_weight += weight 181 | 182 | # phrase_weights = np.array(phrase_weights) 183 | # # sort the words by their weights 
descendingly 184 | # sorted_indices = np.argsort(phrase_weights)[::-1] 185 | 186 | # src_code = text_code[j] 187 | # src_phrases = code_tokenizer.decode(src_code[src_code != -1], skip_special_tokens=True) 188 | # src_phrases = [prs.strip() for prs in src_phrases.split(self.config.code_sep) if len(prs.strip())] 189 | 190 | # dest_phrases = [] 191 | # try: 192 | # for idx in sorted_indices: 193 | # dest_phrases.append(src_phrases[idx]) 194 | # except: 195 | # print(x["text_idx"][j], tokens, words) 196 | # raise 197 | 198 | # # add separator at the end of each word 199 | # words = [prs + self.config.code_sep + " " for prs in dest_phrases] 200 | 201 | # query_code = "".join(words) 202 | # # the query code must be less than code_length - 1 203 | # query_code = code_tokenizer.encode(query_code) 204 | # assert len(query_code) < self.config.code_length 205 | 206 | # query_codes[qrel_idx[j], 1: len(query_code) + 1] = query_code 207 | 208 | # else: 209 | # return super().generate_code(loaders) 210 | -------------------------------------------------------------------------------- /src/models/BM25.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from collections import defaultdict 7 | from .BaseModel import BaseSparseModel 8 | from transformers import AutoTokenizer 9 | from utils.static import * 10 | from utils.util import BaseOutput, synchronize, save_pickle 11 | 12 | 13 | 14 | class BM25(BaseSparseModel): 15 | def __init__(self, config): 16 | super().__init__(config) 17 | self.tokenizer = AutoTokenizer.from_pretrained(config.plm_dir) 18 | 19 | @synchronize 20 | @torch.no_grad() 21 | def encode_text(self, *args, **kwargs): 22 | """ 23 | One step in encode_text. 24 | 25 | Args: 26 | x: a data record. 27 | 28 | Returns: 29 | the text token id for indexing, array of [B, L] 30 | the text token embedding for indexing, array of [B, L, D] 31 | """ 32 | if self.config.pretokenize: 33 | return BaseSparseModel.encode_text(self, *args, **kwargs) 34 | else: 35 | return BaseOutput() 36 | 37 | @synchronize 38 | @torch.no_grad() 39 | def encode_query(self, *args, **kwargs): 40 | """ 41 | One step in encode_text. 42 | 43 | Args: 44 | x: a data record. 45 | 46 | Returns: 47 | the query token id for indexing, array of [B, L] 48 | the query token embedding for indexing, array of [B, L, D] 49 | """ 50 | if self.config.pretokenize: 51 | return BaseSparseModel.encode_query(self, *args, **kwargs) 52 | else: 53 | return BaseOutput() 54 | 55 | 56 | @synchronize 57 | def generate_code(self, loaders: LOADERS): 58 | """ 59 | Generate code by BM25 term weights. 60 | """ 61 | import json 62 | import shutil 63 | from transformers import AutoModel 64 | from pyserini.index.lucene import IndexReader 65 | from utils.util import _get_token_code, makedirs, isempty 66 | 67 | assert self.config.pretokenize, f"Enable pretokenize!" 
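        # The two term-weighting options used below, written out for reference (the BM25 form and its k1/b
        # defaults come from the prebuilt Lucene index and are an assumption here, not something this file sets):
        #   tf-idf: w(t, d) = tf(t, d) / |d| * log(N / df(t))                              (when config.use_tfidf)
        #   BM25:   w(t, d) = idf(t) * tf(t, d) * (k1 + 1) / (tf(t, d) + k1 * (1 - b + b * |d| / avgdl))
        #           as returned by pyserini's compute_bm25_term_weight.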
68 | 69 | # the code is bind to the code_tokenizer 70 | code_path = os.path.join(self.config.cache_root, "codes", self.config.code_type, self.config.code_tokenizer, str(self.config.code_length), "codes.mmp") 71 | self.logger.info(f"generating codes from {self.config.code_type} with code_length: {self.config.code_length}, saving at {code_path}...") 72 | makedirs(code_path) 73 | 74 | tokenizer = AutoTokenizer.from_pretrained(self.config.plm_dir) 75 | code_tokenizer = AutoTokenizer.from_pretrained(os.path.join(self.config.plm_root, self.config.code_tokenizer)) 76 | 77 | loader_text = loaders["text"] 78 | text_num = len(loader_text.dataset) 79 | start_idx = loader_text.sampler.start 80 | end_idx = loader_text.sampler.end 81 | 82 | if self.config.is_main_proc: 83 | # load all saved token ids 84 | # all codes are led by 0 and padded by -1 85 | text_codes = np.memmap( 86 | code_path, 87 | dtype=np.int32, 88 | mode="w+", 89 | shape=(text_num, self.config.code_length) 90 | ) 91 | model = AutoModel.from_pretrained(os.path.join(self.config.plm_root, self.config.code_tokenizer)) 92 | try: 93 | start_token_id = model._get_decoder_start_token_id() 94 | except ValueError: 95 | start_token_id = model.config.pad_token_id 96 | self.logger.warning(f"Decoder start token id not found, use pad token id ({start_token_id}) instead!") 97 | # the codes are always led by start_token_id and padded by -1 98 | text_codes[:, 0] = start_token_id 99 | text_codes[:, 1:] = -1 100 | 101 | synchronize() 102 | 103 | stop_words = set() 104 | punctuations = set([x for x in ";:'\\\"`~[]<>()\{\}/|?!@$#%^&*…-_=+,."]) 105 | nltk_stop_words = set(["a", "about", "also", "am", "to", "an", "and", "another", "any", "anyone", "are", "aren't", "as", "at", "be", "been", "being", "but", "by", "despite", "did", "didn't", "do", "does", "doesn't", "doing", "done", "don't", "each", "etc", "every", "everyone", "for", "from", "further", "had", "hadn't", "has", "hasn't", "have", "haven't", "having", "he", "he'd", "he'll", "her", "here", "here's", "hers", "herself", "he's", "him", "himself", "his", "however", "i", "i'd", "if", "i'll", "i'm", "in", "into", "is", "isn't", "it", "its", "it's", "itself", "i've", "just", "let's", "like", "lot", "may", "me", "might", "mightn't", "my", "myself", "no", "nor", "not", "of", "on", "onto", "or", "other", "ought", "oughtn't", "our", "ours", "ourselves", "out", "over", "shall", "shan't", "she", "she'd", "she'll", "she's", "since", "so", "some", "something", "such", "than", "that", "that's", "the", "their", "theirs", "them", "themselves", "then", "there", "there's", "these", "they", "they'd", "they'll", "they're", "they've", "this", "those", "through", "tht", "to", "too", "usually", "very", "via", "was", "wasn't", "we", "we'd", "well", "we'll", "were", "we're", "weren't", "we've", "will", "with", "without", "won't", "would", "wouldn't", "yes", "yet", "you", "you'd", "you'll", "your", "you're", "yours", "yourself", "yourselves", "you've"]) 106 | # include punctuations 107 | stop_words = stop_words | punctuations 108 | # include nltk stop words 109 | stop_words = stop_words | nltk_stop_words 110 | # include numbers in stopwords 111 | stop_words.add(r"\d") 112 | 113 | collection_dir = os.path.join(os.path.join(self.index_dir, "collection"), "weighted") 114 | 115 | input_path = f"{collection_dir}/{self.config.rank:02d}.jsonl" 116 | 117 | if self.config.get("load_collection"): 118 | pass 119 | else: 120 | if self.config.is_main_proc: 121 | if not isempty(collection_dir): 122 | shutil.rmtree(collection_dir) 123 | 
makedirs(input_path) 124 | synchronize() 125 | 126 | bm25_index = IndexReader(os.path.join(self.index_dir, "index")) 127 | if self.config.get("use_tfidf"): 128 | doc_count = bm25_index.stats()["documents"] 129 | 130 | with open(input_path, "w") as f: 131 | for i in tqdm(range(start_idx, end_idx), leave=False, ncols=100, desc="Collecting DFs"): 132 | x = loader_text.dataset[i] 133 | text_idx = str(x["text_idx"]) 134 | 135 | doc = bm25_index.get_document_vector(text_idx) 136 | if self.config.get("use_tfidf"): 137 | doc_len = sum(doc.values()) 138 | 139 | word_weight_pairs = {} 140 | for word in doc: 141 | if word[-1] in punctuations: 142 | continue 143 | if word not in word_weight_pairs: 144 | # NOTE: set analyzer to None because this is a pretokenized index 145 | if self.config.get("use_tfidf"): 146 | df = bm25_index.get_term_counts(word, analyzer=None)[0] 147 | word_weight_pairs[word] = round(doc[word] / doc_len * math.log(doc_count / df), 3) 148 | else: 149 | word_weight_pairs[word] = round(bm25_index.compute_bm25_term_weight(text_idx, word, analyzer=None), 3) 150 | 151 | doc_vec = {"id": int(text_idx), "vector": word_weight_pairs} 152 | f.write(json.dumps(doc_vec) + "\n") 153 | 154 | code_fields = self.config.code_type.split("-") 155 | defaults = ["weight", None] 156 | code_fields.extend(defaults[-(3 - len(code_fields)):]) 157 | code_name, code_init_order, code_post_order = code_fields[:3] 158 | 159 | # force to stem 160 | _get_token_code( 161 | input_path, 162 | code_path, 163 | text_num, 164 | start_idx, 165 | end_idx, 166 | code_tokenizer, 167 | self.config.code_length, 168 | code_init_order, 169 | code_post_order, 170 | stop_words, 171 | self.config.get("code_sep", " "), 172 | self.config.get("stem_code"), 173 | self.config.get("filter_num"), 174 | self.config.get("filter_unit"), 175 | ) 176 | 177 | 178 | # FIXME: refactor 179 | def rerank(self, loaders: dict): 180 | from pyserini.index.lucene import IndexReader 181 | from utils.util import load_pickle 182 | 183 | index_reader = IndexReader(self.index_dir) 184 | 185 | tid2index = load_pickle(os.path.join(os.path.join(self.config.cache_root, "dataset", "text", "id2index.pkl"))) 186 | tindex2id = {v: k for k, v in tid2index.items()} 187 | 188 | loader_rerank = loaders["rerank"] 189 | retrieval_result = defaultdict(list) 190 | 191 | self.logger.info("reranking...") 192 | for i, x in enumerate(tqdm(loader_rerank.dataset, ncols=100, leave=False)): 193 | query_idx = x["query_idx"] 194 | seq_idx = x["seq_idx"] 195 | 196 | seq_id = tindex2id[seq_idx] 197 | query = self.tokenizer.decode(x["query_token_id"], skip_special_tokens=True) 198 | score = index_reader.compute_query_document_score(seq_id, query) 199 | 200 | retrieval_result[query_idx].append((seq_idx, score)) 201 | 202 | retrieval_result = self._gather_retrieval_result(retrieval_result) 203 | return retrieval_result 204 | 205 | 206 | @torch.no_grad() 207 | def compute_flops(self, loaders, log=True): 208 | """ compute flops as stated in SPLADE 209 | """ 210 | from pyserini.index.lucene import IndexReader 211 | 212 | # document side 213 | loader_text = loaders["text"] 214 | doc_num = len(loader_text.dataset) 215 | query_num = len(loaders["query"].dataset) 216 | 217 | index = IndexReader(self.index_dir) 218 | terms = {} 219 | for k in tqdm(index.terms(), ncols=100, desc="Collecting Vocabulary"): 220 | terms[k.term] = 0 221 | 222 | D = terms 223 | Q = terms.copy() 224 | 225 | for i in tqdm(range(doc_num), ncols=100, desc="Collecting Text Terms"): 226 | tid = str(i) 227 | base_doc = 
index.get_document_vector(tid) 228 | base_key = base_doc.keys() 229 | for k in base_key: 230 | D[k] += 1 231 | 232 | with open(f"{self.config.data_root}/{self.config.dataset}/queries.dev.tsv") as f: 233 | for line in tqdm(f, ncols=100, desc="Collecting Query Terms"): 234 | qid, text = line.strip().split("\t") 235 | analysed = index.analyze(text) 236 | for k in analysed: 237 | try: 238 | Q[k] += 1 239 | except: 240 | pass 241 | 242 | flops = 0 243 | for k in D.keys(): 244 | flops += (D[k] / doc_num * Q[k] / query_num) 245 | 246 | flops = round(flops, 2) 247 | self.metrics.update({"FLOPs": flops}) 248 | if log: 249 | self.log_result() 250 | self.logger.info(f"FLOPs: {flops}") 251 | 252 | -------------------------------------------------------------------------------- /src/notebooks/tsgen.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "[2024-04-03 07:43:15,463] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" 13 | ] 14 | }, 15 | { 16 | "name": "stderr", 17 | "output_type": "stream", 18 | "text": [ 19 | "[2024-04-03 07:43:16,245] INFO (Config) setting seed to 42...\n", 20 | "[2024-04-03 07:43:16,250] INFO (Config) setting PLM to t5...\n" 21 | ] 22 | }, 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "downloading t5-base\n" 28 | ] 29 | }, 30 | { 31 | "name": "stderr", 32 | "output_type": "stream", 33 | "text": [ 34 | "/share/peitian/Envs/adon/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5_fast.py:156: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n", 35 | "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n", 36 | "- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n", 37 | "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n", 38 | "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n", 39 | " warnings.warn(\n", 40 | "[2024-04-03 07:43:26,151] INFO (Config) Config: {'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_epsilon': 1e-08, 'batch_size': 2, 'bf16': False, 'cache_root': 'data/cache/NQ320k', 'data_format': 'memmap', 'data_root': '/data/TSGen', 'dataset': 'NQ320k', 'debug': False, 'deepspeed': None, 'device': 0, 'distill_src': 'none', 'early_stop_patience': 5, 'enable_all_gather': True, 'enable_distill': False, 'enable_inbatch_negative': True, 'epoch': 20, 'eval_batch_size': 2, 'eval_delay': 0, 'eval_flops': False, 'eval_metric': ['mrr', 'recall'], 'eval_metric_cutoff': [1, 5, 10, 100, 1000], 'eval_mode': 'retrieve', 'eval_posting_length': False, 'eval_set': 'dev', 'eval_step': '1e', 'fp16': False, 'grad_accum_step': 1, 'hits': 1000, 'index_shard': 32, 'index_thread': 10, 'index_type': 'invvec', 'learning_rate': 3e-06, 'load_ckpt': None, 'load_encode': False, 'load_index': True, 'load_query_encode': False, 'load_result': False, 'load_text_encode': False, 'loader_train': 'neg', 'main_metric': 'Recall@10', 'max_grad_norm': 0, 'max_query_length': 64, 'max_step': 0, 
'max_text_length': 512, 'mode': 'train', 'model_type': None, 'neg_type': 'random', 'nneg': 1, 'num_worker': 0, 'parallel': 'text', 'plm': 't5', 'plm_dir': '/data/TSGen/PLMs/t5', 'plm_root': '/data/TSGen/PLMs', 'plm_tokenizer': 't5', 'posting_prune': 0.0, 'query_gate_k': 0, 'query_length': 32, 'report_to': 'none', 'return_first_mask': False, 'return_special_mask': False, 'save_at_eval': False, 'save_ckpt': 'best', 'save_encode': False, 'save_index': True, 'save_model': False, 'save_res': 'retrieval_result', 'save_score': False, 'scheduler': 'constant', 'seed': 42, 'special_token_ids': {'cls': (None, None), 'pad': ('', 0), 'unk': ('', 2), 'sep': (None, None), 'eos': ('', 1)}, 'text_col': [1, 2, 3], 'text_col_sep': ' ', 'text_gate_k': 0, 'text_length': 512, 'text_type': 'default', 'train_set': ['train'], 'untie_encoder': False, 'verifier_hits': 1000, 'verifier_index': 'none', 'verifier_src': 'none', 'verifier_type': 'none', 'vocab_size': 32100, 'warmup_ratio': 0.1, 'warmup_step': 0, 'weight_decay': 0.01}\n", 41 | "[2024-04-03 07:43:26,242] INFO (Dataset) initializing NQ320k memmap Text dataset...\n", 42 | "[2024-04-03 07:43:26,293] INFO (Dataset) initializing NQ320k memmap Query dev dataset...\n" 43 | ] 44 | } 45 | ], 46 | "source": [ 47 | "import os\n", 48 | "import sys\n", 49 | "if sys.path[-1] != \"../\":\n", 50 | " sys.path.append(\"../\")\n", 51 | " os.chdir(\"../\")\n", 52 | "\n", 53 | "os.environ['https_proxy'] = \"http://127.0.0.1:15777\"\n", 54 | "os.environ['http_proxy'] = \"http://127.0.0.1:15777\"\n", 55 | "\n", 56 | "import numpy as np\n", 57 | "import pandas as pd\n", 58 | "from IPython.display import display\n", 59 | "from random import sample\n", 60 | "from transformers import AutoModel, AutoTokenizer\n", 61 | "\n", 62 | "import torch\n", 63 | "from utils.util import *\n", 64 | "from utils.index import *\n", 65 | "from utils.data import *\n", 66 | "\n", 67 | "from hydra import initialize, compose\n", 68 | "\n", 69 | "config = Config()\n", 70 | "with initialize(version_base=None, config_path=\"../data/config/\"):\n", 71 | " overrides = [\n", 72 | " \"base=NQ320k\",\n", 73 | " # \"base=MS300k\",\n", 74 | " # \"++plm=t5\",\n", 75 | " ]\n", 76 | " hydra_config = compose(config_name=\"_example\", overrides=overrides)\n", 77 | " config._from_hydra(hydra_config)\n", 78 | "\n", 79 | "loaders = prepare_data(config)\n", 80 | "\n", 81 | "loader_text = loaders[\"text\"]\n", 82 | "loader_query = loaders[\"query\"]\n", 83 | "text_dataset = loader_text.dataset\n", 84 | "query_dataset = loader_query.dataset\n", 85 | "\n", 86 | "# tokenizer = AutoTokenizer.from_pretrained(os.path.join(config.plm_root, config.plm_tokenizer))" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 12, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "# load terms\n", 96 | "code_type = \"term\"\n", 97 | "code_tokenizer = \"t5\"\n", 98 | "# for NQ320k\n", 99 | "code_length = 26\n", 100 | "\n", 101 | "tokenizer = AutoTokenizer.from_pretrained(os.path.join(config.plm_root, code_tokenizer))\n", 102 | "\n", 103 | "text_codes = np.memmap(\n", 104 | " f\"data/cache/{config.dataset}/codes/{code_type}/{code_tokenizer}/{code_length}/codes.mmp\",\n", 105 | " mode=\"r\",\n", 106 | " dtype=np.int32\n", 107 | ").reshape(len(text_dataset), -1).copy()" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 13, 113 | "metadata": {}, 114 | "outputs": [ 115 | { 116 | "data": { 117 | "text/plain": [ 118 | "[' email, mail, marketing, sent, messages, customer, sending, hide, 
purpose, opt, online, merchant,',\n", 119 | " ' mother, ted, umbrella, met, meet, tracy, mcconnell, mom,',\n", 120 | " ' sperm, fertilization, spermatozoon, vitro, spermatozoa, egg,',\n", 121 | " ' quarterback, nfl, wins, career, brady, football, peyton, manning,',\n", 122 | " ' roanoke, colony, lost, disappeared, dare, raleigh, established,',\n", 123 | " ' africa, african, regions, five, subregions, west, north, six, south, continent,',\n", 124 | " ' mantis, guardians, actress, french, galaxy, spring, infinity, hacker, sleepless,',\n", 125 | " ' frosty, december, hat, christmas, melt, snowman, life, 1969, special,',\n", 126 | " ' acadians, acadia, french, colonial, acadiens,',\n", 127 | " ' banks, outer, graveyard, carolina, roanoke, islands, wright,']" 128 | ] 129 | }, 130 | "metadata": {}, 131 | "output_type": "display_data" 132 | }, 133 | { 134 | "data": { 135 | "text/plain": [ 136 | "['Email marketing Email marketing is the act of sending a commercial message, typically to a group of people, using email. In its broadest sense, every email sent to a potential or current customer could be considered email marketing. It usually involves using email to send advertisements, request business, or solicit sales or donations, and is meant to build loyalty, trust, or brand awareness. Marketing emails can be sent to a purchased lead list',\n", 137 | " \"The Mother ( How I Met Your Mother ) Tracy McConnell, better known as The Mother '', is the title character from the CBS television sitcom How I Met Your Mother. The show, narrated by Future Ted, tells the story of how Ted Mosby met The Mother. Tracy McConnell appears in 8 episodes from Lucky Penny '' to The Time Travelers '' as\",\n", 138 | " 'Human fertilization Human fertilization is the union of a human egg and sperm, usually occurring in the ampulla of the fallopian tube. The result of this union is the production of a zygote cell, or fertilized egg, initiating prenatal development. Scientists discovered the dynamics of human fertilization in the nineteenth century. The process of fertilization involves a sperm fusing with an ov',\n", 139 | " 'List of National Football League career quarterback wins leaders The following is a list of the top National Football League ( NFL ) quarterbacks in wins. In the NFL, the quarterback is the only position that is credited with records of wins and losses. Active quarterback Tom Brady holds the records for most wins with 220, most regular season wins with 195, and most postseason wins with 25, as of Week 16 of the 2017 NFL season. Having',\n", 140 | " \"Roanoke Colony The Roanoke Colony ( / ronok / ), also known as the Lost Colony, was established in 1585 on Roanoke Island in what is today's Dare County, North Carolina. It was a late 16th - century attempt by Queen Elizabeth I to establish a permanent English settlement in North America. The colony\",\n", 141 | " 'List of regions of Africa The continent of Africa is commonly divided into five regions or subregions, four of which are in Sub-Saharan Africa, though some definitions may contain four ( removing Central Africa ) or six regions ( separating the horn of Africa into its own region ). Contents ( hide ) 1 List of subregions in Africa 2 Directional approach 3 Physiographic approach 4 Linguistic approach 4.1 By official',\n", 142 | " \"Pom Klementieff Pom Klementieff ( born 3 May 1986 ) is a French actress. 
She was trained at the Cours Florent drama school in Paris and has appeared in such films as Loup ( 2009 ), Sleepless Night ( 2011 ) and Hacker's Game ( 2015 ). She plays the role of Mantis in the film Guardians of the Galaxy Vol. 2 ( 2017 ) and will appear\",\n", 143 | " \"Frosty the Snowman ( film ) Frosty the Snowman is a 1969 animated Christmas television special based on the song Frosty the Snowman ''. The program, which first aired on December 7, 1969 on CBS ( where it continues to air annually ), was produced for television by Rankin / Bass Productions and featured the voices of comedians Jimmy Durante as the film's narr\",\n", 144 | " 'History of the Acadians The Acadians ( French : Acadiens ) are the descendants of the French settlers, and sometimes the Indigenous peoples, of parts of Acadia ( French : Acadie ) in the northeastern region of North America comprising what is now the Canadian Maritime Provinces of New Brunswick, Nova Scotia, and Prince Edward Island, Gaspé, in Quebec, and to the',\n", 145 | " 'Outer Banks The Outer Banks ( OBX ) is a 200 - mile - long ( 320 km ) string of barrier islands and spits off the coast of North Carolina and southeastern Virginia, on the east coast of the United States. They cover most of the North Carolina coastline, separating the Currituck Sound, Albemarle Sound, and Pamlico Sound from the Atlantic Ocean ']" 146 | ] 147 | }, 148 | "metadata": {}, 149 | "output_type": "display_data" 150 | } 151 | ], 152 | "source": [ 153 | "indices = range(10)\n", 154 | "text_code = text_codes[indices]\n", 155 | "text_code[text_code == -1] = 0\n", 156 | "display(tokenizer.batch_decode(text_code))\n", 157 | "display(tokenizer.batch_decode(np.array(text_dataset[indices][\"text\"][\"input_ids\"])[:, :100]))" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 14, 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "data": { 167 | "text/plain": [ 168 | "(4, 8)" 169 | ] 170 | }, 171 | "execution_count": 14, 172 | "metadata": {}, 173 | "output_type": "execute_result" 174 | } 175 | ], 176 | "source": [ 177 | "# trie = TrieIndex(save_dir=f\"data/cache/{config.dataset}/codes/{code_type}/{code_tokenizer}/{code_length}\")\n", 178 | "# trie.load()\n", 179 | "\n", 180 | "# wordset = WordSetIndex(save_dir=f\"data/cache/{config.dataset}/codes/{code_type}/{code_tokenizer}/{code_length}\", sep_token_id=6)\n", 181 | "# wordset.fit(None)\n", 182 | "\n", 183 | "# text_codes = np.sort(text_codes, axis=-1)\n", 184 | "df = pd.DataFrame(text_codes)\n", 185 | "duplicates = df.groupby(df.columns.tolist(),as_index=False).size()\n", 186 | "duplicates = duplicates.sort_values(\"size\", ascending=False)\n", 187 | "duplicates.reset_index(drop=True, inplace=True)\n", 188 | "\n", 189 | "dup = df.duplicated(keep=\"first\").to_numpy()\n", 190 | "dup_indices = np.argwhere(dup)[:, 0]\n", 191 | "len(dup_indices), duplicates[\"size\"][duplicates[\"size\"] > 1].sum()" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [] 200 | } 201 | ], 202 | "metadata": { 203 | "kernelspec": { 204 | "display_name": "Python 3.9.12", 205 | "language": "python", 206 | "name": "python3" 207 | }, 208 | "language_info": { 209 | "codemirror_mode": { 210 | "name": "ipython", 211 | "version": 3 212 | }, 213 | "file_extension": ".py", 214 | "mimetype": "text/x-python", 215 | "name": "python", 216 | "nbconvert_exporter": "python", 217 | "pygments_lexer": "ipython3", 218 | "version": "3.9.12" 219 | }, 220 
| "orig_nbformat": 4, 221 | "vscode": { 222 | "interpreter": { 223 | "hash": "778a5a6b0df35a46498564cf16af2e5ec016022ef7dc9d5934de67fcb1f6bfb9" 224 | } 225 | } 226 | }, 227 | "nbformat": 4, 228 | "nbformat_minor": 2 229 | } 230 | -------------------------------------------------------------------------------- /src/models/VQ.py: -------------------------------------------------------------------------------- 1 | import os 2 | import faiss 3 | import torch 4 | import subprocess 5 | import numpy as np 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | from transformers import AutoModel 10 | from .BaseModel import BaseDenseModel 11 | from utils.util import BaseOutput, readlink 12 | from utils.index import FaissIndex 13 | from utils.static import * 14 | 15 | 16 | 17 | class DistillVQ(BaseDenseModel): 18 | """ 19 | The model is proposed in `this paper `_. The implementation here follows its own `git repository `_. 20 | """ 21 | def __init__(self, config): 22 | assert "PQ" in config.index_type, "DistillVQ is intended for PQ based methods!" 23 | 24 | super().__init__(config) 25 | 26 | index = faiss.read_index(os.path.join(config.cache_root, "index", config.embedding_src, "faiss", config.index_type)) 27 | if isinstance(index, faiss.IndexPreTransform): 28 | vt = faiss.downcast_VectorTransform(index.chain.at(0)) 29 | opq = faiss.vector_to_array(vt.A).reshape(vt.d_out, vt.d_in).T 30 | self.register_buffer("vt", torch.tensor(opq)) 31 | 32 | if "IVF" in config.index_type: 33 | if isinstance(index, faiss.IndexPreTransform): 34 | ivf_index = faiss.downcast_index(index.index) 35 | else: 36 | ivf_index = index 37 | 38 | quantizer = faiss.downcast_index(ivf_index.quantizer) 39 | ivf_centroids = FaissIndex.get_xb(quantizer) 40 | self.ivfCentroids = nn.parameter.Parameter(torch.tensor(ivf_centroids)) 41 | 42 | pq = ivf_index.pq 43 | pq_centroids = FaissIndex.get_pq_codebook(pq) 44 | self.pqCentroids = nn.parameter.Parameter(torch.tensor(pq_centroids)) 45 | 46 | invlists = ivf_index.invlists 47 | cs = invlists.code_size 48 | pq_codes = np.zeros((ivf_index.ntotal, pq.M), dtype=np.float32) 49 | ivf_codes = np.zeros(ivf_index.ntotal, dtype=np.float32) 50 | for i in tqdm(range(ivf_index.nlist), ncols=100, desc="Collecting IVF Codes"): 51 | ls = invlists.list_size(i) 52 | list_ids = faiss.rev_swig_ptr(invlists.get_ids(i), ls) 53 | list_codes = faiss.rev_swig_ptr(invlists.get_codes(i), ls * cs).reshape(ls, cs) 54 | for j, docid in enumerate(list_ids): 55 | pq_codes[docid] = list_codes[j] 56 | ivf_codes[docid] = i 57 | 58 | self._ivf_codes = ivf_codes 59 | self._pq_codes = pq_codes 60 | 61 | # load pq centroids 62 | elif "PQ" in config.index_type: 63 | if isinstance(index, faiss.IndexPreTransform): 64 | pq_index = faiss.downcast_index(index.index) 65 | else: 66 | pq_index = index 67 | # for both ivfpq and pq index, the product quantizer can be accessed by index.pq 68 | pq = pq_index.pq 69 | pq_centroids = faiss.vector_to_array(pq.centroids).reshape(pq.M, pq.ksub, pq.dsub) 70 | self.pqCentroids = nn.parameter.Parameter(torch.tensor(pq_centroids)) 71 | 72 | pq_codes = faiss.vector_to_array(pq_index.codes).reshape(-1, pq.M) 73 | self._pq_codes = pq_codes 74 | 75 | else: 76 | raise NotImplementedError(f"{config.index_type} not implemented!") 77 | 78 | if config.train_encoder: 79 | self.queryEncoder = AutoModel.from_pretrained(f"{config.plm_root}/retromae_distill") 80 | self.queryEncoder.pooler = None 81 | 82 | self._output_dim = index.d 83 | 84 | 85 | def create_optimizer(self) -> 
torch.optim.Optimizer: 86 | ivf_parameter_names = ["ivfCentroids"] 87 | pq_parameter_names = ["pqCentroids"] 88 | 89 | ivf_parameters = [] 90 | pq_parameters = [] 91 | encoder_parameters = [] 92 | 93 | for name, param in self.named_parameters(): 94 | if any(x in name for x in ivf_parameter_names): 95 | ivf_parameters.append(param) 96 | elif any(x in name for x in pq_parameter_names): 97 | pq_parameters.append(param) 98 | else: 99 | encoder_parameters.append(param) 100 | 101 | optimizer_grouped_parameters = [ 102 | { 103 | "params": ivf_parameters, 104 | "lr": self.config.learning_rate_ivf 105 | }, 106 | { 107 | "params": pq_parameters, 108 | "lr": self.config.learning_rate_pq 109 | }, 110 | { 111 | "params": encoder_parameters, 112 | "lr": self.config.learning_rate 113 | } 114 | ] 115 | 116 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters) 117 | return optimizer 118 | 119 | 120 | def _quantize_ivf(self, text_idx:np.ndarray, text_embedding:TENSOR) -> TENSOR: 121 | """ 122 | Args: 123 | text_idx: the indices of the input documents, used to look up ``self._ivf_codes``; tensor of [B] 124 | embedding: the embedding of the documents, used to dynamically compute ivf assignments when ``self.config.train_ivf_assign==True`` 125 | 126 | Returns: 127 | quantized ivf embedding (the closest ivf centroid) 128 | """ 129 | if self.config.train_ivf_assign: 130 | ivf_sim = text_embedding.matmul(self.ivfCentroids.T) 131 | ivf_assign_soft = torch.softmax(ivf_sim, dim=-1) # B, nlist 132 | _, max_index = ivf_assign_soft.max(dim=-1, keepdim=True) # B, 1 133 | ivf_assign_hard = torch.zeros_like(ivf_assign_soft, device=ivf_assign_soft.device, dtype=ivf_assign_soft.dtype).scatter_(-1, max_index, 1.0) 134 | # straight-through trick 135 | ivf_assign = ivf_assign_hard.detach() - ivf_assign_soft.detach() + ivf_assign_soft # B, C 136 | quantized_embedding = ivf_assign.matmul(self.ivfCentroids) 137 | 138 | else: 139 | ivf_id = torch.as_tensor(self._ivf_codes[text_idx], device=self.config.device, dtype=torch.long) # B 140 | quantized_embedding = self.ivfCentroids[ivf_id] 141 | 142 | return quantized_embedding 143 | 144 | 145 | def _quantize_pq(self, text_idx:np.ndarray, embedding:TENSOR) -> TENSOR: 146 | """ 147 | Args: 148 | text_idx: the indices of the input documents, used to look up ``self._pq_codes``; tensor of [B] 149 | embedding: the embedding of the documents, used to dynamically compute pq assignments when ``self.config.train_pq_assign==True`` 150 | 151 | Returns: 152 | quantized pq embedding (the closest pq centroid) 153 | """ 154 | if self.config.train_pq_assign: 155 | M, ksub, dsub = self.pqCentroids.shape 156 | B = embedding.shape[0] 157 | embedding = embedding.view(B, M, dsub) 158 | codebook = self.pqCentroids 159 | distance = - torch.sum((embedding.unsqueeze(-2) - codebook) ** 2, -1) 160 | pq_assign_soft = torch.softmax(distance, dim=-1) # B, M, ksub 161 | _, max_index = pq_assign_soft.max(dim=-1, keepdim=True) 162 | pq_assign_hard = torch.zeros_like(pq_assign_soft, device=pq_assign_soft.device, dtype=pq_assign_soft.dtype).scatter_(-1, max_index, 1.0) 163 | # straight-through trick 164 | pq_assign = pq_assign_hard.detach() - pq_assign_soft.detach() + pq_assign_soft 165 | pq_assign = pq_assign.unsqueeze(-2) # B, M, 1, ksub 166 | 167 | codebook = codebook.unsqueeze(0).expand(B, -1, -1, -1) # B, M, ksub, dsub 168 | quantized_embedding = torch.matmul(pq_assign, codebook).view(B, -1) # B, D 169 | 170 | else: 171 | pq_id = torch.as_tensor(self._pq_codes[text_idx], device=self.config.device, 
dtype=torch.long) 172 | quantized_embedding = FaissIndex.pq_quantize(pq_id, self.pqCentroids) 173 | 174 | return quantized_embedding 175 | 176 | 177 | def _encode_query(self, **kwargs): 178 | """ 179 | encode tokens with bert 180 | """ 181 | embedding = self.queryEncoder(**kwargs)[0][:, 0] 182 | return embedding 183 | 184 | 185 | def forward(self, x:dict) -> TENSOR: 186 | x = self._move_to_device(x) 187 | text_idx = x["text_idx"].view(-1).numpy() 188 | 189 | text_embedding = x["text_embedding"].flatten(0, 1) # B*(1+N), D 190 | 191 | if self.config.train_encoder: 192 | query_embedding = self._encode_query(**x["query"]) 193 | else: 194 | query_embedding = x["query_embedding"] 195 | 196 | if hasattr(self, "vt"): 197 | rotate_query_embedding = query_embedding.matmul(self.vt) 198 | rotate_text_embedding = text_embedding.matmul(self.vt) 199 | else: 200 | rotate_query_embedding = query_embedding 201 | rotate_text_embedding = text_embedding 202 | 203 | if hasattr(self, "ivfCentroids"): 204 | text_ivf_quantization = self._quantize_ivf(text_idx, rotate_text_embedding) 205 | quantize_text_embedding = rotate_text_embedding - text_ivf_quantization 206 | text_pq_quantization = self._quantize_pq(text_idx, quantize_text_embedding) + text_ivf_quantization 207 | else: 208 | text_ivf_quantization = None 209 | text_pq_quantization = self._quantize_pq(text_idx, rotate_text_embedding) 210 | 211 | if self.config.is_distributed and self.config.enable_all_gather: 212 | rotate_query_embedding = self._gather_tensors(rotate_query_embedding) 213 | rotate_text_embedding = self._gather_tensors(rotate_text_embedding) 214 | text_ivf_quantization = self._gather_tensors(text_ivf_quantization) 215 | text_pq_quantization = self._gather_tensors(text_pq_quantization) 216 | 217 | score_pq = rotate_query_embedding.matmul(text_pq_quantization.transpose(-1, -2)) # B, B*(1+N) 218 | if text_ivf_quantization is not None: 219 | score_ivf = rotate_query_embedding.matmul(text_ivf_quantization.transpose(-1,-2)) # B, B*(1+N) 220 | if self.config.train_encoder: 221 | score_dense = rotate_query_embedding.matmul(rotate_text_embedding.transpose(-1, -2)) # B, B*(1+N) 222 | 223 | B = rotate_query_embedding.size(0) 224 | if self.config.enable_inbatch_negative: 225 | label = torch.arange(B, device=self.config.device) 226 | label = label * (rotate_text_embedding.size(0) // rotate_query_embedding.size(0)) 227 | else: 228 | label = torch.zeros(B, dtype=torch.long, device=self.config.device) 229 | score_pq = score_pq.view(B, B, -1)[range(B), range(B)] # B, 1+N 230 | if text_ivf_quantization is not None: 231 | score_ivf = score_ivf.view(B, B, -1)[range(B), range(B)] # B, 1+N 232 | if self.config.train_encoder: 233 | score_dense = score_dense.view(B, B, -1)[range(B), range(B)] # B, 1+N 234 | 235 | teacher_score = self._compute_teacher_score(x) 236 | loss_pq = self._compute_loss(score_pq, label, teacher_score) 237 | if text_ivf_quantization is not None: 238 | loss_ivf = self._compute_loss(score_ivf, label, teacher_score) 239 | else: 240 | loss_ivf = 0 241 | if self.config.train_encoder: 242 | loss_dense = self._compute_loss(score_dense, label, teacher_score) 243 | else: 244 | loss_dense = 0 245 | 246 | # scale the ivf loss, it's usually way bigger than PQ's loss 247 | # if text_ivf_quantization is not None: 248 | # # rescale the ivf loss 249 | # loss_ivf = loss_ivf * float(float(loss_pq) / (loss_ivf + 1e-6)) 250 | 251 | return loss_pq + loss_ivf + loss_dense 252 | 253 | 254 | @torch.no_grad() 255 | def encode_text(self, loader_text:DataLoader, 
load_all_encode:bool=False): 256 | if load_all_encode: 257 | text_embeddings = loader_text.dataset.text_embeddings 258 | else: 259 | # create soft link to the embedding_src 260 | if self.config.is_main_proc and self.config.save_encode: 261 | os.makedirs(self.text_dir, exist_ok=True) 262 | subprocess.run( 263 | f"ln -sf {os.path.join(self.config.cache_root, 'encode', self.config.embedding_src, 'text', self.config.text_type, 'text_embeddings.mmp')} {os.path.join(self.text_dir, 'text_embeddings.mmp')}", 264 | shell=True 265 | ) 266 | 267 | text_embeddings = loader_text.dataset.text_embeddings[loader_text.sampler.start: loader_text.sampler.end] 268 | return BaseOutput(embeddings=text_embeddings) 269 | 270 | 271 | @torch.no_grad() 272 | def encode_query(self, loader_query:DataLoader, load_all_encode:bool=False): 273 | query_embedding_path = os.path.join(self.query_dir, "query_embeddings.mmp") 274 | 275 | if load_all_encode: 276 | query_embeddings = np.memmap( 277 | readlink(query_embedding_path), 278 | mode="r+", 279 | dtype=np.float32 280 | ).reshape(len(loader_query.dataset), self._output_dim) 281 | 282 | elif self.config.load_encode or self.config.load_query_encode: 283 | query_embeddings = np.memmap( 284 | readlink(query_embedding_path), 285 | mode="r+", 286 | dtype=np.float32 287 | ).reshape(len(loader_query.dataset), self._output_dim)[loader_query.sampler.start: loader_query.sampler.end] 288 | 289 | else: 290 | if hasattr(self, "queryEncoder"): 291 | query_embeddings = np.zeros((len(loader_query.sampler), self._output_dim), dtype=np.float32) 292 | start_idx = end_idx = 0 293 | self.logger.info(f"encoding {self.config.dataset} {self.config.eval_set} query...") 294 | for i, x in enumerate(tqdm(loader_query, leave=False, ncols=100)): 295 | query = self._move_to_device(x["query"]) 296 | query_embedding = self._encode_query(**query).cpu().numpy() # B, LS, D 297 | end_idx += query_embedding.shape[0] 298 | query_embeddings[start_idx: end_idx] = query_embedding 299 | start_idx = end_idx 300 | if self.config.debug: 301 | if i > 10: 302 | break 303 | 304 | if self.config.save_encode: 305 | self.save_to_mmp( 306 | path=query_embedding_path, 307 | shape=(len(loader_query.dataset), self._output_dim), 308 | dtype=np.float32, 309 | loader=loader_query, 310 | obj=query_embeddings 311 | ) 312 | else: 313 | # create soft link to the embedding_src 314 | if self.config.is_main_proc and self.config.save_encode: 315 | os.makedirs(self.query_dir, exist_ok=True) 316 | subprocess.run( 317 | f"ln -sf {os.path.join(self.config.cache_root, 'encode', self.config.embedding_src, 'query', self.config.eval_set, 'query_embeddings.mmp')} {os.path.join(self.query_dir, 'query_embeddings.mmp')}", 318 | shell=True 319 | ) 320 | 321 | query_embeddings = loader_query.dataset.query_embeddings[loader_query.sampler.start: loader_query.sampler.end] 322 | 323 | return BaseOutput(embeddings=query_embeddings) 324 | 325 | 326 | def index(self, loaders:LOADERS): 327 | if self.config.load_index: 328 | return super().index(loaders) 329 | 330 | loader_text = loaders["text"] 331 | text_embeddings = self.encode_text(loader_text).embeddings 332 | 333 | if self.config.index_type != "Flat" and not self.config.is_main_proc: 334 | index = None 335 | else: 336 | index = FaissIndex( 337 | index_type=self.config.index_type, 338 | d=text_embeddings.shape[1], 339 | metric=self.config.dense_metric, 340 | start_text_idx=loader_text.sampler.start, 341 | device=self.config.device, 342 | save_dir=self.index_dir, 343 | ) 344 | 
index.load(os.path.join(self.config.cache_root, "index", self.config.embedding_src, "faiss", self.config.index_type)) 345 | 346 | # load opq transformation 347 | if isinstance(index.index, faiss.IndexPreTransform): 348 | vt = faiss.downcast_VectorTransform(index.index.chain.at(0)) 349 | faiss.copy_array_to_vector(self.vt.T.cpu().numpy().ravel(), vt.A) 350 | vt.is_trained = True 351 | 352 | if "IVF" in self.config.index_type: 353 | # Important! Don't move the index to gpu; otherwise may trigger wierd performance issue of Faiss 354 | index.device = "cpu" 355 | 356 | if isinstance(index.index, faiss.IndexPreTransform): 357 | ivfpq = faiss.downcast_index(index.index.index) 358 | else: 359 | ivfpq = index.index 360 | 361 | quantizer = faiss.downcast_index(ivfpq.quantizer) 362 | # remove all embeddings from the quantizer 363 | quantizer.reset() 364 | quantizer.add(self.ivfCentroids.cpu().numpy()) 365 | 366 | if self.config.train_ivf_assign and not self.config.train_pq_assign: 367 | ivfpq.reset() 368 | index.index.ntotal = 0 369 | index.fit(text_embeddings) 370 | # copy pq centroids after adding embeddings 371 | pq = ivfpq.pq 372 | faiss.copy_array_to_vector(self.pqCentroids.cpu().numpy().ravel(), pq.centroids) 373 | 374 | elif self.config.train_pq_assign: 375 | # copy pq centroids before adding embeddings 376 | pq = ivfpq.pq 377 | faiss.copy_array_to_vector(self.pqCentroids.cpu().numpy().ravel(), pq.centroids) 378 | ivfpq.reset() 379 | index.index.ntotal = 0 380 | index.fit(text_embeddings) 381 | 382 | else: 383 | pq = ivfpq.pq 384 | faiss.copy_array_to_vector(self.pqCentroids.cpu().numpy().ravel(), pq.centroids) 385 | # do nothing 386 | index.fit(text_embeddings) 387 | 388 | elif "PQ" in self.config.index_type: 389 | if isinstance(index.index, faiss.IndexPreTransform): 390 | pq_index = faiss.downcast_index(index.index.index) 391 | else: 392 | pq_index = index.index 393 | pq = pq_index.pq 394 | faiss.copy_array_to_vector(self.pqCentroids.cpu().numpy().ravel(), pq.centroids) 395 | pq_index.is_trained = True 396 | if self.config.train_pq_assign: 397 | pq_index.reset() 398 | index.index.ntotal = 0 399 | 400 | index.fit(text_embeddings) 401 | else: 402 | raise NotImplementedError 403 | 404 | if self.config.save_index: 405 | index.save() 406 | 407 | return BaseOutput(index=index) 408 | 409 | --------------------------------------------------------------------------------