├── output └── .gitkeep ├── examples ├── __init__.py └── messenger │ ├── __init__.py │ ├── config.yml │ ├── README.md │ └── worlds.py ├── chatty_goose ├── agents │ ├── __init__.py │ └── chat.py ├── __init__.py ├── pipeline │ ├── __init__.py │ └── retrieval_pipeline.py ├── cqr │ ├── __init__.py │ ├── cqr.py │ ├── ntr.py │ ├── cqe.py │ └── hqe.py ├── logger.py ├── types.py ├── settings.py └── util.py ├── requirements.txt ├── pyproject.toml ├── .gitmodules ├── .gitignore ├── setup.py ├── data └── cast_preprocess.py ├── README.md ├── docs ├── cqr_experiments.md ├── cqr_experiment_2020.md ├── t5_finetuning.md └── conversation_dense_retrieval_experiments.md ├── experiments └── run_retrieval.py └── LICENSE /output/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chatty_goose/agents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/messenger/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chatty_goose/__init__.py: -------------------------------------------------------------------------------- 1 | from chatty_goose.logger import * 2 | -------------------------------------------------------------------------------- /chatty_goose/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from .retrieval_pipeline import * 2 | -------------------------------------------------------------------------------- /chatty_goose/cqr/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .hqe import * 2 | from .ntr import * 3 | from .cqr import * 4 | from .cqe import * -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | coloredlogs 2 | parlai==1.1.0 3 | pydantic>=1.5 4 | pygaggle==0.0.3.1 5 | spacy>=2.2.4,<=2.3.5 6 | pyserini==0.14.0 7 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "wheel" 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | -------------------------------------------------------------------------------- /chatty_goose/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import coloredlogs 4 | 5 | 6 | coloredlogs.install( 7 | level=os.environ.get("CHATTY_GOOSE_LOG_LEVEL", "WARN"), 8 | fmt="%(asctime)s [%(levelname)s]: %(message)s", 9 | ) 10 | -------------------------------------------------------------------------------- /chatty_goose/types.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class PosFilter(str, Enum): 5 | NO = "no" 6 | POS = "pos" 7 | STP = "stp" 8 | 9 | class CqrType(str, Enum): 10 | HQE = "hqe" 11 | CQE = "cqe" 12 | T5 = "t5" 13 | HQE_T5_FUSION = "hqe_t5_fusion" 14 | CQE_T5_FUSION = "cqe_t5_fusion" 15 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "data/trec-cast-tools"] 2 | path = data/trec-cast-tools 3 | url = https://github.com/gla-ial/trec-cast-tools 4 | [submodule "data/treccastweb"] 5 | path = data/treccastweb 6 | url = 
https://github.com/daltonj/treccastweb 7 | [submodule "data/canard"] 8 | path = data/canard 9 | url = https://github.com/aagohary/canard 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | output/* 2 | !output/.gitkeep 3 | .vscode/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | -------------------------------------------------------------------------------- /examples/messenger/config.yml: -------------------------------------------------------------------------------- 1 | tasks: 2 | default: 3 | onboard_world: ChattyGooseMessengerOnboardWorld 4 | task_world: ChattyGooseMessengerTaskWorld 5 | timeout: 180 6 | agents_required: 1 7 | task_name: chatbot 8 | world_module: examples.messenger.worlds 9 | overworld: ChattyGooseMessengerOverworld 10 | max_workers: 30 11 | opt: 12 | debug: True 13 | password: ChattyGoose 14 | models: 15 | chatty_goose: 16 | model: chatty_goose.agents.chat:ChattyGooseAgent 17 | additional_args: 18 | page_id: ChattyGooseIR # Configure for custom page 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | with open("requirements.txt") as f: 7 | requirements = f.read().splitlines() 8 | 9 | excluded = ["data*", "examples*", "experiments*"] 10 | 11 | 12 | setuptools.setup( 13 | name="chatty-goose", 14 | 
version="0.2.0", 15 | author="Anserini Gaggle", 16 | author_email="anserini.gaggle@gmail.com", 17 | description="A conversational passage retrieval toolkit", 18 | long_description=long_description, 19 | long_description_content_type="text/markdown", 20 | url="https://github.com/castorini/chatty-goose", 21 | install_requires=requirements, 22 | packages=setuptools.find_packages(exclude=excluded), 23 | classifiers=[ 24 | "Programming Language :: Python :: 3", 25 | "License :: OSI Approved :: Apache Software License", 26 | "Operating System :: OS Independent", 27 | ], 28 | python_requires=">=3.7", 29 | ) 30 | -------------------------------------------------------------------------------- /chatty_goose/cqr/cqr.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from abc import abstractmethod 3 | import logging 4 | 5 | 6 | __all__ = ["ConversationalQueryRewriter"] 7 | 8 | 9 | class ConversationalQueryRewriter: 10 | """Base conversational query reformulation class""" 11 | 12 | def __init__(self, name: str, verbose: bool = False): 13 | self.name = name 14 | self.turn_id: int = -1 15 | self.total_latency: float = 0.0 16 | self.verbose: bool = verbose 17 | 18 | @abstractmethod 19 | def rewrite(self, query: str, context: Optional[str] = None) -> str: 20 | """Rewrite original query text""" 21 | raise NotImplementedError 22 | 23 | def reset_history(self): 24 | """Reset conversation history for model""" 25 | if self.verbose and self.turn_id > -1: 26 | turns = self.turn_id + 1 27 | logging.info( 28 | "Resetting {} after {} turns (average reformulation latency {:.4f}s)".format( 29 | self.name, turns, self.total_latency / (turns) 30 | ) 31 | ) 32 | self.turn_id = -1 33 | self.total_latency = 0 34 | -------------------------------------------------------------------------------- /data/cast_preprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | import 
argparse 3 | 4 | from pathlib import Path 5 | from spacy.lang.en import English 6 | 7 | 8 | def main(): 9 | """Prepare inference file by transforming CAsT queries into CANARD format""" 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("input_queries", help="Input CAsT query JSON") 12 | parser.add_argument("output_dir", help="Output inference file path") 13 | args = parser.parse_args() 14 | 15 | nlp = English() 16 | 17 | with open(Path(args.input_queries), "r") as fin, open( 18 | Path(args.output_dir), "w" 19 | ) as fout: 20 | js_list = json.load(fin) 21 | for topic in js_list: 22 | history = [] 23 | topic_id = topic["number"] 24 | turns = topic["turn"] 25 | print(f"========TOPIC {topic_id}=======") 26 | for turn in turns: 27 | turn_id = turn["number"] 28 | raw_query = turn["raw_utterance"] 29 | src_text = " ||| ".join(history + [raw_query]) 30 | src_text = " ".join([tok.text for tok in nlp(src_text)]) 31 | history += [raw_query] 32 | print(src_text) 33 | fout.write(src_text + "\n") 34 | 35 | 36 | if __name__ == "__main__": 37 | main() 38 | -------------------------------------------------------------------------------- /examples/messenger/README.md: -------------------------------------------------------------------------------- 1 | # Deploying a Facebook Messenger Chat Agent 2 | 3 | This guide is based on ParlAI's [chat service tutorial](https://parl.ai/docs/tutorial_chat_service.html), where we provide the base configuration and classes for deploying our example `ChattyGooseAgent` as a Facebook Messenger chatbot. The following steps deploy the webhook server to Heroku, however it is also possible to deploy it locally by setting the `local: True` parameter under `opt` in `config.yml`. 4 | 5 | 1. Create a new [Facebook Page](https://www.facebook.com/pages/create) and [Facebook App for Messenger](https://developers.facebook.com/docs/messenger-platform/getting-started/app-setup) to host the Chatty Goose agent. 
Add your Facebook Page under the Webhooks settings for Messenger, and check the "messages" subscription field. 6 | 7 | 2. If deploying to Heroku, create a free account and log into the [Heroku CLI](https://devcenter.heroku.com/articles/heroku-cli) on your machine. 8 | 9 | 3. Run the webhook server and Chatty Goose agent using our provided configuration. This assumes you have the ParlAI Python package installed and are inside the `chatty-goose` root repository folder. 10 | 11 | ``` 12 | python -m parlai.chat_service.services.messenger.run --config-path examples/messenger/config.yml 13 | ``` 14 | 15 | 4. Add the webhook URL outputted from the above command as a callback URL for the Messenger App settings, and set the verify token to `Messenger4ParlAI`. For Heroku, this URL should look like `https://firstname-parlai-messenger-chatbot.herokuapp.com/webhook`. 16 | 17 | 5. Visiting your page and sending a message should now trigger the agent to respond! 18 | -------------------------------------------------------------------------------- /chatty_goose/settings.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseSettings 2 | import typing 3 | from chatty_goose.types import PosFilter 4 | 5 | __all__ = ["SearcherSettings", "HqeSettings", "NtrSettings", "CqeSettings"] 6 | 7 | 8 | class SearcherSettings(BaseSettings): 9 | """Settings for Anserini searcher""" 10 | 11 | index_path: typing.Any # Pre-built index name or path to Lucene index 12 | k1: float = 0.82 # BM25 k parameter 13 | b: float = 0.68 # BM25 b parameter 14 | rm3: bool = False # use RM3 15 | fb_terms: int = 10 # RM3 number of expansion trees 16 | fb_docs: int = 10 # RM3 number of documents 17 | original_query_weight: float = 0.8 # RM3 weigh to assign initial query 18 | 19 | class DenseSearcherSettings(BaseSettings): 20 | """Settings for Pyserini dsearcher""" 21 | 22 | index_path: typing.Any # Pre-built index name or path to faiss index 23 | 
query_encoder: str # path to huggingface model or hub 24 | 25 | class CqrSettings(BaseSettings): 26 | verbose: bool = False 27 | 28 | 29 | class HqeSettings(CqrSettings): 30 | """Settings for HQE with defaults tuned on CAsT""" 31 | 32 | M: int = 5 # number of aggregate historical queries 33 | eta: float = 10.0 # QPP threshold for first stage retrieval 34 | R_topic: float = 4.5 # topic keyword threshold 35 | R_sub: float = 3.5 # subtopic keyword threshold 36 | filter: PosFilter = PosFilter.POS # 'no' or 'pos' or 'stp' 37 | 38 | class CqeSettings(CqrSettings): 39 | """Settings for CQE model for NTR""" 40 | l2_threshold: float = 10.5 41 | model_name: str = "castorini/cqe" 42 | max_context_length: int = 100 43 | max_query_length: int = 36 44 | 45 | 46 | class NtrSettings(CqrSettings): 47 | """Settings for T5 model for NTR""" 48 | 49 | model_name: str = "castorini/t5-base-canard" 50 | max_length: int = 64 51 | num_beams: int = 10 52 | early_stopping: bool = True 53 | -------------------------------------------------------------------------------- /chatty_goose/cqr/ntr.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import torch 4 | from typing import Optional 5 | 6 | from chatty_goose.settings import NtrSettings 7 | from spacy.lang.en import English 8 | from transformers import T5ForConditionalGeneration, T5Tokenizer 9 | 10 | from .cqr import ConversationalQueryRewriter 11 | 12 | __all__ = ["Ntr"] 13 | 14 | 15 | class Ntr(ConversationalQueryRewriter): 16 | """Neural Transfer Reformulation using a trained T5 model""" 17 | 18 | def __init__(self, settings: NtrSettings = NtrSettings(), device: str = None): 19 | super().__init__("Ntr", verbose=settings.verbose) 20 | 21 | # Model settings 22 | self.max_length = settings.max_length 23 | self.num_beams = settings.num_beams 24 | self.early_stopping = settings.early_stopping 25 | 26 | device = device or ("cuda" if torch.cuda.is_available() else "cpu") 27 
| self.device = torch.device(device) 28 | 29 | if self.verbose: 30 | logging.info(f"Initializing T5 using model {settings.model_name}...") 31 | self.model = ( 32 | T5ForConditionalGeneration.from_pretrained(settings.model_name) 33 | .to(device) 34 | .eval() 35 | ) 36 | self.tokenizer = T5Tokenizer.from_pretrained(settings.model_name) 37 | self.nlp = English() 38 | self.history_query = [] 39 | self.history = [] 40 | 41 | def rewrite(self, query: str, context: Optional[str] = None, response_num: Optional[int] = 0) -> str: 42 | start_time = time.time() 43 | 44 | # If the passage from canonical result (context) is provided, it is added to history. 45 | # Since canonical passage can be large and there is limit on length of tokens, 46 | # only one passage for the new query is used at a time. 47 | 48 | # if len(self.history) >= 2 and self.has_canonical_context: 49 | # self.history.pop(-2) 50 | # self.has_canonical_context = False 51 | self.history_query += [query] 52 | self.history += [query] 53 | 54 | 55 | # Build input sequence from query and history 56 | if response_num!=0: 57 | src_text = " ||| ".join(self.history_query[:-response_num] + self.history[-2*response_num:]) 58 | else: 59 | src_text = " ||| ".join(self.history_query) 60 | 61 | 62 | src_text = " ".join([tok.text for tok in self.nlp(src_text)]) 63 | input_ids = self.tokenizer( 64 | src_text, return_tensors="pt", add_special_tokens=True 65 | ).input_ids.to(self.device) 66 | 67 | # Generate new sequence 68 | output_ids = self.model.generate( 69 | input_ids, 70 | max_length=self.max_length, 71 | num_beams=self.num_beams, 72 | early_stopping=self.early_stopping, 73 | ) 74 | 75 | # Decode output 76 | rewrite_text = self.tokenizer.decode( 77 | output_ids[0, 0:], 78 | clean_up_tokenization_spaces=True, 79 | skip_special_tokens=True, 80 | ) 81 | if context: 82 | self.history += [context] 83 | self.total_latency += time.time() - start_time 84 | return rewrite_text 85 | 86 | def reset_history(self): 87 | 
super().reset_history() 88 | self.history = [] 89 | self.history_query = [] 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Chatty Goose 2 | 3 | ## Multi-stage Conversational Passage Retrieval: An Approach to Fusing Term Importance Estimation and Neural Query Rewriting 4 | 5 | --- 6 | 7 | [![PyPI](https://img.shields.io/pypi/v/chatty-goose?color=brightgreen)](https://pypi.org/project/chatty-goose/) 8 | [![LICENSE](https://img.shields.io/badge/license-Apache-blue.svg?style=flat)](https://www.apache.org/licenses/LICENSE-2.0) 9 | 10 | ## Installation 11 | 12 | 1. Make sure Java 11+ and **Python 3.7+** are installed 13 | 14 | 2. Install the `chatty-goose` PyPI module 15 | 16 | ``` 17 | pip install chatty-goose 18 | ``` 19 | 20 | 3. If you are using T5 or BERT, make sure to install [PyTorch 1.4.0 - 1.7.1](https://pytorch.org/) using your specific platform instructions. Note that PyTorch 1.8 is currently incompatible due to the `transformers` version we currently use. Also make sure to install the corresponding [torchtext](https://pypi.org/project/torchtext/) version. 21 | 22 | 4. Download the English model for spaCy 23 | 24 | ``` 25 | python -m spacy download en_core_web_sm 26 | ``` 27 | 28 | ## Quickstart Guide 29 | 30 | The following example shows how to initialize a searcher and build a `ConversationalQueryRewriter` agent from scratch using HQE and T5 as first-stage retrievers, and a BERT reranker. To see a working example agent, see [chatty_goose/agents/chat.py](chatty_goose/agents/chat.py). 
31 | 32 | First, load a searcher 33 | 34 | ``` 35 | from pyserini.search import SimpleSearcher 36 | 37 | # Option 1: load a prebuilt index 38 | searcher = SimpleSearcher.from_prebuilt_index("INDEX_NAME_HERE") 39 | # Option 2: load a local Lucene index 40 | searcher = SimpleSearcher("PATH_TO_INDEX") 41 | 42 | searcher.set_bm25(0.82, 0.68) 43 | ``` 44 | 45 | Next, initialize one or more first-stage CQR retrievers 46 | 47 | ``` 48 | from chatty_goose.cqr import Hqe, Ntr 49 | from chatty_goose.settings import HqeSettings, NtrSettings 50 | 51 | hqe = Hqe(searcher, HqeSettings()) 52 | ntr = Ntr(NtrSettings()) 53 | ``` 54 | 55 | Load a reranker 56 | 57 | ``` 58 | from chatty_goose.util import build_bert_reranker 59 | 60 | reranker = build_bert_reranker() 61 | ``` 62 | 63 | Create a new `RetrievalPipeline` 64 | 65 | ``` 66 | from chatty_goose.pipeline import RetrievalPipeline 67 | 68 | rp = RetrievalPipeline(searcher, [hqe, ntr], searcher_num_hits=50, reranker=reranker) 69 | ``` 70 | 71 | And we're done! Simply call `rp.retrieve(query)` to retrieve passages, or call `rp.reset_history()` to reset the conversational history of the retrievers. 72 | 73 | ## Running Experiments 74 | 75 | 1. Clone the repo and all submodules (`git submodule update --init --recursive`) 76 | 77 | 2. Clone and build [Anserini](https://github.com/castorini/anserini) for evaluation tools 78 | 79 | 3. Install dependencies 80 | 81 | ``` 82 | pip install -r requirements.txt 83 | ``` 84 | 85 | 4. Follow the instructions under [docs/cqr_experiments.md](docs/cqr_experiments.md) to run experiments using HQE, T5, or fusion. 86 | 87 | ## Example Agent 88 | 89 | To run an interactive conversational search agent with ParlAI, simply run [`chat.py`](chatty_goose/agents/chat.py). By default, we use the CAsT 2019 pre-built Pyserini index, but it is possible to specify other indexes using the `--from_prebuilt` flag. 
See the file for other possible arguments: 90 | 91 | ``` 92 | python -m chatty_goose.agents.chat 93 | ``` 94 | 95 | Alternatively, run the agent using ParlAI's command line interface: 96 | 97 | ``` 98 | python -m parlai interactive --model chatty_goose.agents.chat:ChattyGooseAgent 99 | ``` 100 | 101 | We also provide instructions to deploy the agent to Facebook Messenger using ParlAI under [`examples/messenger`](examples/messenger/README.md). 102 | -------------------------------------------------------------------------------- /docs/cqr_experiments.md: -------------------------------------------------------------------------------- 1 | # Experiments for Conversational Query Reformulation 2 | 3 | ## Data Preparation 4 | 5 | 1. Download either the [training](https://github.com/daltonj/treccastweb/blob/master/2019/data/training/train_topics_v1.0.json) or the [evaluation](https://github.com/daltonj/treccastweb/blob/master/2019/data/evaluation/evaluation_topics_v1.0.json) input query JSON file. These files can be found under `data/treccastweb/2019/data` if you cloned the submodules for this repo. 6 | 7 | Store the path to the chosen JSON file in a variable, for example (evaluation topics): 8 | 9 | ```shell=bash 10 | export input_query_json=data/treccastweb/2019/data/evaluation/evaluation_topics_v1.0.json 11 | ``` 12 | 13 | 2. Download the corresponding answer (qrels) file for [training](https://github.com/daltonj/treccastweb/blob/master/2019/data/training/train_topics_mod.qrel) or [evaluation](https://trec.nist.gov/data/cast/2019qrels.txt). The training answer file is found under `data/treccastweb/2019/data`. 14 | 15 | ## Run CQR retrieval 16 | 17 | The following command is for HQE, but you can also run other CQR methods using `t5` or `fusion` instead of `hqe` as the input to the `--experiment` flag. Running the command for the first time will download the CAsT 2019 index (or whatever index is specified for the `--sparse_index` flag). It is also possible to supply a path to a local directory containing the index.
18 | 19 | ```shell=bash 20 | python -m experiments.run_retrieval \ 21 | --experiment hqe \ 22 | --hits 1000 \ 23 | --sparse_index cast2019 \ 24 | --qid_queries $input_query_json \ 25 | --output ./output/hqe_bm25 26 | ``` 27 | 28 | The experiment will output the retrieval results at the specified location in TSV format. By default, this will perform retrieval using only BM25, but you can add the `--rerank` flag to further rerank these results using BERT. For other command line arguments, see [run_retrieval.py](../experiments/run_retrieval.py). 29 | 30 | ## Evaluate CQR results 31 | 32 | Convert the TSV file from above to TREC format and use `trec_eval` to evaluate the results in terms of Recall@1000, mAP, NDCG@1 and NDCG@3. In the command below, `$qrel` is the path to the answer (qrels) file downloaded in step 2 of Data Preparation. 33 | 34 | ```shell=bash 35 | python -m pyserini.eval.trec_eval -c -mndcg_cut.3,1 -mrecall.1000 -mmap $qrel ./output/hqe_bm25.trec 36 | ``` 37 | 38 | ## Evaluation results 39 | 40 | Results for the CAsT 2019 evaluation dataset are provided below. The results may differ slightly from the numbers reported in the paper due to implementation differences between Hugging Face Transformers and spaCy versions. As of writing, we use `spacy==2.2.4` with the English model `en_core_web_sm==2.2.5`, and `transformers==4.0.0`.
41 | 42 | | | HQE BM25 | HQE BM25 + BERT | T5 BM25 | T5 BM25 + BERT | Fusion BM25 | Fusion BM25 + BERT | 43 | | ----------- | :------: | :-------------: | :-----: | :------------: | :---------: | :----------------: | 44 | | mAP | 0.2109 | 0.3058 | 0.2250 | 0.3555 | 0.2584 | 0.3739 | 45 | | Recall@1000 | 0.7322 | 0.7322 | 0.7392 | 0.7392 | 0.8028 | 0.8028 | 46 | | NDCG@1 | 0.2640 | 0.4745 | 0.2842 | 0.5751 | 0.3353 | 0.5838 | 47 | | NDCG@3 | 0.2606 | 0.4798 | 0.2954 | 0.5464 | 0.3247 | 0.5640 | 48 | 49 | ## Reproduction Log 50 | 51 | + Results reproduced by [@saileshnankani](https://github.com/saileshnankani) on 2021-05-07 (commit [`3847d15`](https://github.com/castorini/chatty-goose/commit/3847d15f3fb39a57ac061648c863798dd4510049)) (Fusion BM25) 52 | + Results reproduced by [@ArthurChen189](https://github.com/ArthurChen189) on 2021-05-07 (commit [`ef3a271`](https://github.com/castorini/chatty-goose/commit/ef3a27119d6825a96ae85d1453d6b4eac4ed22b7)) (Fusion BM25) 53 | + Results reproduced by [@andrewyguo](https://github.com/andrewyguo) on 2021-05-07 (commit [`79f89dc`](https://github.com/castorini/chatty-goose/commit/79f89dcaccc9b5d6b89b3d3012b98e98548bf6c7)) (Fusion BM25) 54 | -------------------------------------------------------------------------------- /chatty_goose/util.py: -------------------------------------------------------------------------------- 1 | import sys 2 | # sys.path.append('./pyserini') 3 | import logging 4 | from os import path 5 | from typing import Dict, List, Tuple 6 | 7 | from pygaggle.rerank.transformer import MonoBERT 8 | from pyserini.search import JSimpleSearcherResult, SimpleSearcher 9 | from pyserini.dsearch import SimpleDenseSearcher 10 | from chatty_goose.settings import SearcherSettings 11 | 12 | 13 | def reciprocal_rank_fusion( 14 | hit_lists: List[List[JSimpleSearcherResult]], k: int = 60 15 | ) -> List[JSimpleSearcherResult]: 16 | """ 17 | Implements reciprocal rank fusion as defined in 18 | "Reciprocal Rank Fusion Outperforms 
Condorcet and Individual Rank Learning Methods" by Cormack, Clarke and Buettcher. 19 | 20 | Parameters: 21 | hit_lists: lists of hits to merge using reciprocal rank fusion 22 | k: term to avoid vanishing importance of lower-ranked documents (default 60 from original paper) 23 | """ 24 | if len(hit_lists) == 0: 25 | return [] 26 | 27 | if len(hit_lists) == 1: 28 | return hit_lists[0] 29 | 30 | doc_scores: Dict[str, Tuple[float, JSimpleSearcherResult]] = {} 31 | for hits in hit_lists: 32 | for pos, hit in enumerate(hits, start=1): 33 | cur_rank = doc_scores.get(hit.docid, (0.0, hit))[0] 34 | doc_scores[hit.docid] = (cur_rank + 1.0 / (k + pos), hit) 35 | 36 | # Sort by highest score 37 | result = [] 38 | for _, score_hit in sorted(iter(doc_scores.items()), key=lambda kv: -kv[1][0]): 39 | score_hit[1].score = score_hit[0] #update score with rrf fusion score 40 | result.append(score_hit[1]) 41 | return result 42 | 43 | 44 | def build_bert_reranker( 45 | name_or_path: str = "castorini/monobert-large-msmarco-finetune-only", 46 | device: str = None, 47 | ): 48 | """Returns a BERT reranker using the provided model name or path to load from""" 49 | model = MonoBERT.get_model(name_or_path, device=device) 50 | tokenizer = MonoBERT.get_tokenizer(name_or_path) 51 | return MonoBERT(model, tokenizer) 52 | 53 | 54 | def build_searcher(settings: SearcherSettings) -> SimpleSearcher: 55 | if settings.index_path==None: 56 | logging.info( 57 | "Cannot find bm25 index, skip bm25 search!") 58 | searcher=None 59 | else: 60 | if path.isdir(settings.index_path): 61 | searcher = SimpleSearcher(settings.index_path) 62 | else: 63 | searcher = SimpleSearcher.from_prebuilt_index(settings.index_path) 64 | searcher.set_bm25(float(settings.k1), float(settings.b)) 65 | logging.info( 66 | "Initializing BM25, setting k1={} and b={}".format(settings.k1, settings.b) 67 | ) 68 | if settings.rm3: 69 | searcher.set_rm3( 70 | settings.fb_terms, settings.fb_docs, settings.original_query_weight 71 | ) 72 | 
logging.info( 73 | "Initializing RM3, setting fbTerms={}, fbDocs={} and originalQueryWeight={}".format( 74 | settings.fb_terms, settings.fb_docs, settings.original_query_weight 75 | ) 76 | ) 77 | return searcher 78 | 79 | def build_dense_searcher(settings: SearcherSettings) -> SimpleDenseSearcher: 80 | if (settings.index_path==None) or (settings.query_encoder==None): 81 | logging.info( 82 | "Cannot find dense index or query encoder, skip dense search!") 83 | searcher = None 84 | else: 85 | logging.info( 86 | "Load dense index: {}".format(settings.index_path)) 87 | if path.isdir(settings.index_path): 88 | searcher = SimpleDenseSearcher(settings.index_path, settings.query_encoder) 89 | else: 90 | searcher = SimpleDenseSearcher.from_prebuilt_index(settings.index_path, settings.query_encoder) 91 | # logging.info( 92 | # "Load, setting k1={} and b={}".format(settings.k1, settings.b) 93 | # ) 94 | return searcher 95 | -------------------------------------------------------------------------------- /chatty_goose/agents/chat.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from chatty_goose.cqr import Hqe, Ntr 4 | from chatty_goose.pipeline import RetrievalPipeline 5 | from chatty_goose.settings import HqeSettings, NtrSettings 6 | from chatty_goose.types import CqrType, PosFilter 7 | from parlai.core.agents import Agent, register_agent 8 | from pyserini.search import SimpleSearcher 9 | 10 | 11 | @register_agent("ChattyGooseAgent") 12 | class ChattyGooseAgent(Agent): 13 | @classmethod 14 | def add_cmdline_args(cls, parser, partial_opt = None): 15 | parser.add_argument('--name', type=str, default='CQR', help="The agent's name.") 16 | parser.add_argument('--cqr_type', type=str, default='fusion', help="hqe, t5, or fusion") 17 | parser.add_argument('--episode_done', type=str, default='[END]', help="end signal for interactive mode") 18 | parser.add_argument('--hits', type=int, default=50, help="number of hits to 
retrieve from searcher") 19 | 20 | # Pyserini 21 | parser.add_argument('--k1', default=0.82, help='BM25 k1 parameter') 22 | parser.add_argument('--b', default=0.68, help='BM25 b parameter') 23 | parser.add_argument('--from_prebuilt', type=str, default='cast2019', help="Pyserini prebuilt index") 24 | 25 | # T5 26 | parser.add_argument('--from_pretrained', type=str, default='castorini/t5-base-canard', help="Huggingface T5 checkpoint") 27 | 28 | # HQE 29 | parser.add_argument('--M', default=5, type=int, help='aggregate historcial queries for first stage (BM25) retrieval') 30 | parser.add_argument('--eta', default=10, type=float, help='QPP threshold for first stage (BM25) retrieval') 31 | parser.add_argument('--R_topic', default=4.5, type=float, help='topic keyword threshold for first stage (BM25) retrieval') 32 | parser.add_argument('--R_sub', default=3.5, type=float, help='subtopic keyword threshold for first stage (BM25) retrieval') 33 | parser.add_argument('--filter', default='pos', help='filter word method: no, pos, stp') 34 | parser.add_argument('--verbose', action='store_true') 35 | return parser 36 | 37 | def __init__(self, opt, shared=None): 38 | super().__init__(opt, shared) 39 | self.name = opt["name"] 40 | self.episode_done = opt["episode_done"] 41 | self.cqr_type = CqrType(opt["cqr_type"]) 42 | 43 | # Initialize searcher 44 | searcher = SimpleSearcher.from_prebuilt_index(opt["from_prebuilt"]) 45 | searcher.set_bm25(float(opt["k1"]), float(opt["b"])) 46 | 47 | # Initialize retrievers 48 | retrievers = [] 49 | if self.cqr_type == CqrType.HQE or self.cqr_type == CqrType.FUSION: 50 | hqe_settings = HqeSettings( 51 | M=opt["M"], 52 | eta=opt["eta"], 53 | R_topic=opt["R_topic"], 54 | R_sub=opt["R_sub"], 55 | filter=PosFilter(opt["filter"]), 56 | verbose=opt["verbose"], 57 | ) 58 | hqe = Hqe(searcher, hqe_settings) 59 | retrievers.append(hqe) 60 | if self.cqr_type == CqrType.T5 or self.cqr_type == CqrType.FUSION: 61 | t5_settings = 
NtrSettings(model_name=opt["from_pretrained"], verbose=opt["verbose"]) 62 | t5 = Ntr(t5_settings) 63 | retrievers.append(t5) 64 | 65 | self.rp = RetrievalPipeline(searcher, retrievers, int(opt["hits"])) 66 | 67 | def observe(self, observation): 68 | # Gather the last word from the other user's input 69 | self.query = observation.get("text", "") 70 | if observation.get("episode_done") or self.query == self.episode_done: 71 | logging.info("Resetting agent history") 72 | self.rp.reset_history() 73 | 74 | def act(self): 75 | if self.query == self.episode_done: 76 | return {"id": self.id, "text": "Session finished"} 77 | 78 | # Retrieve hits 79 | hits = self.rp.retrieve(self.query) 80 | if len(hits) == 0: 81 | result = "Sorry, I couldn't find any results" 82 | else: 83 | result = hits[0].raw 84 | return { "id": self.id, "text": result } 85 | 86 | 87 | if __name__ == "__main__": 88 | from parlai.scripts.interactive import Interactive 89 | 90 | Interactive.main(model="ChattyGooseAgent", cqr_type="fusion") 91 | -------------------------------------------------------------------------------- /examples/messenger/worlds.py: -------------------------------------------------------------------------------- 1 | from parlai.core.worlds import World 2 | from parlai.chat_service.services.messenger.worlds import OnboardWorld 3 | from parlai.core.agents import create_agent_from_shared 4 | 5 | 6 | class ChattyGooseMessengerOnboardWorld(OnboardWorld): 7 | """ 8 | Example messenger onboarding world for Chatty Goose. 9 | """ 10 | 11 | @staticmethod 12 | def generate_world(opt, agents): 13 | return ChattyGooseMessengerOnboardWorld(opt=opt, agent=agents[0]) 14 | 15 | def parley(self): 16 | self.episodeDone = True 17 | 18 | 19 | class ChattyGooseMessengerTaskWorld(World): 20 | """ 21 | Example one person world that talks to a provided agent. 
22 | """ 23 | MODEL_KEY = 'chatty_goose' 24 | 25 | def __init__(self, opt, agent, bot): 26 | self.agent = agent 27 | self.episodeDone = False 28 | self.model = bot 29 | self.first_time = True 30 | 31 | @staticmethod 32 | def generate_world(opt, agents): 33 | if opt['models'] is None: 34 | raise RuntimeError("Model must be specified") 35 | return ChattyGooseMessengerTaskWorld( 36 | opt, 37 | agents[0], 38 | create_agent_from_shared( 39 | opt['shared_bot_params'][ChattyGooseMessengerTaskWorld.MODEL_KEY] 40 | ), 41 | ) 42 | 43 | @staticmethod 44 | def assign_roles(agents): 45 | agents[0].disp_id = 'ChattyGooseAgent' 46 | 47 | def parley(self): 48 | if self.first_time: 49 | self.agent.observe( 50 | { 51 | 'id': '', 52 | 'text': 'Welcome to the Chatty Goose demo! ' 53 | 'Please type a query. ' 54 | 'Type [DONE] to finish the chat, or [RESET] to reset the dialogue history.', 55 | } 56 | ) 57 | self.first_time = False 58 | a = self.agent.act() 59 | if a is not None: 60 | if '[DONE]' in a['text']: 61 | self.episodeDone = True 62 | elif '[RESET]' in a['text']: 63 | self.model.reset() 64 | self.agent.observe( 65 | {"text": "[History Cleared]", "episode_done": False}) 66 | else: 67 | self.model.observe(a) 68 | response = self.model.act() 69 | # Make sure prefix from agent is not displayed 70 | response['id'] = '' 71 | self.agent.observe(response) 72 | 73 | def episode_done(self): 74 | return self.episodeDone 75 | 76 | def shutdown(self): 77 | self.agent.shutdown() 78 | 79 | 80 | class ChattyGooseMessengerOverworld(World): 81 | """ 82 | World to handle moving agents to their proper places. 
83 | """ 84 | 85 | def __init__(self, opt, agent): 86 | self.agent = agent 87 | self.opt = opt 88 | self.first_time = True 89 | self.episodeDone = False 90 | 91 | @staticmethod 92 | def generate_world(opt, agents): 93 | return ChattyGooseMessengerOverworld(opt, agents[0]) 94 | 95 | @staticmethod 96 | def assign_roles(agents): 97 | for a in agents: 98 | a.disp_id = 'Agent' 99 | 100 | def episode_done(self): 101 | return self.episodeDone 102 | 103 | def parley(self): 104 | if self.first_time: 105 | self.agent.observe( 106 | { 107 | 'id': 'Overworld', 108 | 'text': 'Welcome to the Chatty Goose Messenger overworld! ' 109 | 'Please type "Start" to start, or "Exit" to exit. ', 110 | 'quick_replies': ['Start', 'Exit'], 111 | } 112 | ) 113 | self.first_time = False 114 | a = self.agent.act() 115 | if a is not None and a['text'].lower() == 'exit': 116 | self.episode_done = True 117 | return 'EXIT' 118 | if a is not None and a['text'].lower() == 'start': 119 | self.episodeDone = True 120 | return 'default' 121 | elif a is not None: 122 | self.agent.observe( 123 | { 124 | 'id': 'Overworld', 125 | 'text': 'Invalid option. Please type "Start".', 126 | 'quick_replies': ['Start'], 127 | } 128 | ) 129 | -------------------------------------------------------------------------------- /docs/cqr_experiment_2020.md: -------------------------------------------------------------------------------- 1 | # Experiments for CQR with CAsT 2020 2 | 3 | ## Data Preparation 4 | 5 | 1. Download [evaluation](https://github.com/daltonj/treccastweb/blob/master/2020/2020_manual_evaluation_topics_v1.0.json) input query JSON file. This file can be found under `data/treccastweb/2020` if you cloned the submodules for this repo. 6 | 7 | 2. Download the evaluation answer files for [evaluation](https://trec.nist.gov/data/cast/2020qrels.txt). 8 | 9 | ## Run and evaluate CQR retrieval 10 | 11 | The run is similar to the run for the [CQR experiment for CAsT2019](./cqr_experiments.md#run-cqr-retrieval). 
For canonical runs, you also need to specify an extra `--context_index` flag to define the index from which the canonical passage is retrieved. `--add_response` controls how many previous responses are added to the context; 0 means only the historical queries are used. 12 | 13 | 14 | The index `cast2019` can still be used to perform BM25 search since `cast2019` and `cast2020` share the same corpus. 15 | 16 | In the naive run, only the `raw_utterance` is used. In the canonical run, the passage corresponding to `manual_canonical_result_id` is also used in the context. 17 | 18 | Results for the CAsT 2020 evaluation dataset are provided below for both naive and canonical runs. As of writing, we use `spacy==2.2.4` with the English model `en_core_web_sm==2.2.5`, and `transformers==4.0.0`. 19 | 20 | ### Historical Query only 21 | 22 | ```shell=bash 23 | python -m experiments.run_retrieval \ 24 | --experiment hqe or t5 or hqe_t5_fusion \ 25 | --hits 1000 \ 26 | --sparse_index cast2019 \ 27 | --qid_queries $input_query_json \ 28 | --output ./output/result \ 29 | 30 | python -m pyserini.eval.trec_eval -c -mndcg_cut.3,1 -mrecall.1000 -mmap $qrel ./output/result.trec 31 | ``` 32 | 33 | | | HQE BM25 | T5 BM25 | Fusion BM25 | 34 | | ----------- | :------: | :-------------: | :---------: | 35 | | mAP | 0.1155 | 0.1236 | 0.1386 | 36 | | Recall@1000 | 0.5316 | 0.5551 | 0.6063 | 37 | | NDCG@1 | 0.1635 | 0.1639 | 0.2015 | 38 | | NDCG@3 | 0.1640 | 0.1620 | 0.1879 | 39 | 40 | --------- 41 | 42 | ### One Canonical Response 43 | 44 | ```shell=bash 45 | python -m experiments.run_retrieval \ 46 | --experiment hqe or t5 or hqe_t5_fusion \ 47 | --hits 1000 \ 48 | --sparse_index cast2019 \ 49 | --qid_queries $input_query_json \ 50 | --output ./output/result \ 51 | --add_response 1 \ 52 | 53 | python -m pyserini.eval.trec_eval -c -mndcg_cut.3,1 -mrecall.1000 -mmap $qrel ./output/result.trec 54 | ``` 55 | 56 | | | HQE BM25 | T5 BM25 | Fusion BM25 | 57 | | ----------- | :------: |
:-------------: | :---------: | 58 | | mAP | 0.1061 | 0.1271 | 0.1478 | 59 | | Recall@1000 | 0.5887 | 0.5625 | 0.6594 | 60 | | NDCG@1 | 0.1306 | 0.1807 | 0.1959 | 61 | | NDCG@3 | 0.1244 | 0.1652 | 0.1925 | 62 | 63 | --------- 64 | 65 | ### Two Canonical Responses 66 | 67 | ```shell=bash 68 | python -m experiments.run_retrieval \ 69 | --experiment hqe or t5 or hqe_t5_fusion \ 70 | --hits 1000 \ 71 | --sparse_index cast2019 \ 72 | --qid_queries $input_query_json \ 73 | --output ./output/result \ 74 | --add_response 2 \ 75 | 76 | python -m pyserini.eval.trec_eval -c -mndcg_cut.3,1 -mrecall.1000 -mmap $qrel ./output/result.trec 77 | ``` 78 | 79 | | | HQE BM25 | T5 BM25 | Fusion BM25 | 80 | | ----------- | :------: | :------------: | :------------: | 81 | | mAP | 0.0948 | 0.1319 | 0.1446 | 82 | | Recall@1000 | 0.5755 | 0.5747 | 0.6618 | 83 | | NDCG@1 | 0.1306 | 0.1955 | 0.2011 | 84 | | NDCG@3 | 0.1185 | 0.1785 | 0.1967 | 85 | 86 | --------- 87 | 88 | ### BERT Reranking 89 | ```shell=bash 90 | python -m experiments.run_retrieval \ 91 | --experiment hqe_t5_fusion \ 92 | --hits 1000 \ 93 | --sparse_index cast2019 \ 94 | --qid_queries $input_query_json \ 95 | --output ./output/result \ 96 | --add_response 2 \ 97 | --rerank 98 | 99 | python -m pyserini.eval.trec_eval -c -mndcg_cut.3,1 -mrecall.1000 -mmap $qrel ./output/result.trec 100 | ``` 101 | 102 | | | Fusion BM25 Rerank | 103 | | ----------- | :------------: | 104 | | mAP | 0.2861 | 105 | | Recall@1000 | 0.6618 | 106 | | NDCG@1 | 0.4623 | 107 | | NDCG@3 | 0.4202 | 108 | 109 | This [link](https://colab.research.google.com/drive/1KBm-BJAy9Yhb5b7NMuuW4v_gj8KJ0VAv?usp=sharing) is the Colab demo for HQE-T5 early fusion with BERT reranking.
110 | ## Reproduction Log 111 | 112 | 113 | -------------------------------------------------------------------------------- /docs/t5_finetuning.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Conversational Question Reformulation" 3 | author: Matt Yang 4 | --- 5 | 6 | # T5 for Conversational Query Reformulation (CQR) 7 | 8 | ## Overview 9 | 10 | This is the runbook for fine-tuning T5 on the conversational query reformulation task for the TREC CAsT 2019 dataset. 11 | 12 | ## Training data preparation 13 | 14 | We first use a conversational query reformulation dataset from [CANARD](https://sites.google.com/view/qanta/projects/canard). 15 | 16 | > [Dataset Paper (EMNLP'19)](http://users.umiacs.umd.edu/~jbg/docs/2019_emnlp_sequentialqa.pdf) 17 | 18 | First, we use the dataset and preprocessing script released by the authors. Note that you have to install [SpaCy](https://spacy.io/) yourself. 19 | Make sure you have cloned the submodules for this repository, which will pull the latest changes from `canard` under `data/canard`. 20 | 21 | ```shell=bash 22 | # run preprocessing script 23 | cd data/canard 24 | # modify & run preprocessing script 25 | sed 's/seq2seq\/release/release/g' preprocess.sh > preprocess.mod.sh 26 | bash preprocess.mod.sh 27 | 28 | # prepare a tsv file for finetuning T5 29 | paste data/seq2seq/train-src.txt data/seq2seq/train-tgt.tsv > history_query_pair.tsv 30 | ``` 31 | 32 | This produces the T5 training file `history_query_pair.tsv`; upload it to `gs://your_bucket/data/history_query_pairs.train.tsv` for the fine-tuning step below. 33 | 34 | ## Replicating T5 fine-tuning for CQR 35 | 36 | To begin, follow the [T5 repo](https://github.com/google-research/text-to-text-transfer-transformer) to install packages. The following guide will show you how to train/predict the reformulated queries with the T5 model and the data from CANARD.
37 | Here we show how to use the [Text-To-Text Transfer Transformer (T5)](https://github.com/google-research/text-to-text-transfer-transformer) model from its original GitHub repo to fine-tune T5 as a CQR module. The following command will train a T5-base model for about 4k additional steps (from checkpoint step 999900 to step 1004000) to predict reformulated queries from the conversation history. We assume you put the tsv training file in `gs://your_bucket/data/history_query_pairs.train.tsv` (prepared above). Also, change `your_tpu_name`, `your_tpu_zone`, `your_project_id`, and `your_bucket` accordingly. 38 | 39 | ```bash 40 | # please check the original T5 repo for full guide 41 | pip install t5[gcp] 42 | ``` 43 | 44 | ```bash 45 | t5_mesh_transformer \ 46 | --tpu="your_tpu_name" \ 47 | --gcp_project="your_project_id" \ 48 | --tpu_zone="your_tpu_zone" \ 49 | --model_dir="gs://your_bucket/models/" \ 50 | --gin_param="init_checkpoint = 'gs://t5-data/pretrained_models/base/model.ckpt-999900'" \ 51 | --gin_file="dataset.gin" \ 52 | --gin_file="models/bi_v1.gin" \ 53 | --gin_file="gs://t5-data/pretrained_models/base/operative_config.gin" \ 54 | --gin_param="utils.tpu_mesh_shape.model_parallelism = 1" \ 55 | --gin_param="utils.tpu_mesh_shape.tpu_topology = '2x2'" \ 56 | --gin_param="utils.run.train_dataset_fn = @t5.models.mesh_transformer.tsv_dataset_fn" \ 57 | --gin_param="tsv_dataset_fn.filename = 'gs://your_bucket/data/history_query_pairs.train.tsv'" \ 58 | --gin_file="learning_rate_schedules/constant_0_001.gin" \ 59 | --gin_param="run.train_steps = 1004000" \ 60 | --gin_param="tokens_per_batch = 131072" 61 | ``` 62 | 63 | ## Predicting Queries from History 64 | 65 | ```bash 66 | # prepare your data on the google cloud storage 67 | gsutil cp data/canard/data/seq2seq/test-src.txt gs://your_bucket/data/test-src.canard.txt 68 | ``` 69 | 70 | ```bash 71 | ### BEAM search decoding 72 | for BEAM in {1,5,10,15,20}; do 73 | t5_mesh_transformer \ 74 | --tpu="your_tpu_name" \ 75 | --gcp_project="your_project_id" \ 76 | --tpu_zone="your_tpu_zone" \ 77 |
--model_dir="gs://your_bucket/models/" \ 78 | --gin_file="gs://t5-data/pretrained_models/base/operative_config.gin" \ 79 | --gin_file="infer.gin" \ 80 | --gin_file="beam_search.gin" \ 81 | --gin_param="Bitransformer.decode.beam_size = ${BEAM}" \ 82 | --gin_param="utils.tpu_mesh_shape.tpu_topology = '2x2'" \ 83 | --gin_param="infer_checkpoint_step = 1004000" \ 84 | --gin_param="utils.run.sequence_length = {'inputs': 512, 'targets': 64}" \ 85 | --gin_param="Bitransformer.decode.max_decode_length = 64" \ 86 | --gin_param="input_filename = 'gs://your_bucket/data/test-src.canard.txt'" \ 87 | --gin_param="output_filename = 'gs://your_bucket/data/test_pred_beam_${BEAM}.txt'" \ 88 | --gin_param="tokens_per_batch = 131072" \ 89 | --gin_param="Bitransformer.decode.temperature = 0.0" \ 90 | --gin_param="Unitransformer.sample_autoregressive.sampling_keep_top_k = -1" 91 | done 92 | ``` 93 | 94 | ## Preparing Inference File for CAsT 95 | 96 | We provide a preprocessing script that transforms CAsT queries into a CANARD-compatible format: 97 | 98 | ```bash 99 | python -m data.cast_preprocess data/treccastweb/2019/data/evaluation/evaluation_topics_v1.0.json data/eval-queries.txt 100 | ``` 101 | 102 | After transforming the queries, use the same procedure but point `'input_filename = ...'` at the transformed CAsT queries for CQR inference.
103 | -------------------------------------------------------------------------------- /chatty_goose/cqr/cqe.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import torch 4 | from typing import Optional 5 | from pyserini.search import querybuilder 6 | from transformers import BertModel, BertTokenizer, BertTokenizerFast 7 | from .cqr import ConversationalQueryRewriter 8 | from chatty_goose.settings import CqeSettings 9 | import numpy as np 10 | from numpy import linalg as LA 11 | 12 | name = "Cqe" 13 | __all__ = [name] 14 | 15 | 16 | class Cqe(ConversationalQueryRewriter): 17 | """Neural Transfer Reformulation using a trained T5 model""" 18 | 19 | def __init__(self, settings: CqeSettings = CqeSettings(), device: str = None): 20 | super().__init__(name, verbose=settings.verbose) 21 | self.name=name 22 | # Model settings 23 | if settings.model_name: 24 | self.l2_threshold = settings.l2_threshold 25 | self.max_query_length = settings.max_query_length 26 | self.max_context_length = settings.max_context_length 27 | self.max_length = self.max_query_length + self.max_context_length 28 | device = device or ("cuda" if torch.cuda.is_available() else "cpu") 29 | self.device = torch.device(device) 30 | if self.verbose: 31 | logging.info(f"Initializing CQE using model {settings.model_name}...") 32 | logging.info(f"Initializing CQE using model {settings.model_name}...") 33 | self.model = BertModel.from_pretrained(settings.model_name) 34 | self.model.to(self.device) 35 | self.tokenizer = BertTokenizer.from_pretrained(settings.model_name) 36 | self.has_model = True 37 | self.history_query = [] 38 | self.history = [] 39 | if (not self.has_model): 40 | raise Exception('Neither query encoder model nor encoded queries provided. 
Please provide at least one') 41 | 42 | def rewrite(self, query: str, context: Optional[str] = None, response_num: Optional[int] = 0) -> str: 43 | start_time = time.time() 44 | self.turn_id += 1 45 | 46 | # If the passage from canonical result (context) is provided, it is added to history. 47 | # Since canonical passage can be large and there is limit on length of tokens, 48 | # only one passage for the new query is used at a time. 49 | self.history_query += [query] 50 | self.history += [query] 51 | if response_num!=0: 52 | src_ctx = '[CLS] ' +"|".join(self.history_query[:-response_num] + self.history[-2*response_num:-1]) 53 | else: 54 | src_ctx = '[CLS] ' +"|".join(self.history_query[:-1]) 55 | src_q = ' [Q] ' + self.history[-1] + '[MASK]' * self.max_query_length 56 | if self.turn_id == 0: 57 | input_ids = self.tokenizer.encode('[CLS] ' +src_q, max_length=self.max_query_length, add_special_tokens=False) 58 | token_type_ids =[0] + [1] * (len(input_ids)-1) 59 | self.context_length = 3 60 | effective_token_weights = [0]*4 + [1]*(self.max_query_length-4) 61 | else: 62 | src_ctx_token_ids = self.tokenizer.encode(src_ctx, max_length=self.max_context_length, add_special_tokens=False) 63 | src_query_token_ids = self.tokenizer.encode(src_q, max_length=self.max_query_length, add_special_tokens=False) 64 | input_ids = src_ctx_token_ids + src_query_token_ids 65 | token_type_ids = [0]*len(src_ctx_token_ids) + [1]*len(src_query_token_ids) 66 | self.context_length = len(src_ctx_token_ids) - 1 + 3 67 | effective_token_weights = [0] + [1]*(len(src_ctx_token_ids)-1) + [0]*3 + [1]*(self.max_query_length-3) 68 | input_tokens = [] 69 | for input_id in input_ids[1:]: 70 | input_tokens.append(self.tokenizer.decode(input_id).replace(' ','') ) #transform whitepiece token into text by removing ' ' 71 | 72 | input_ids = torch.LongTensor([input_ids]).to(self.device) 73 | token_type_ids = torch.LongTensor([token_type_ids]).to(self.device) 74 | 75 | 76 | 77 | outputs = 
self.model(input_ids=input_ids, token_type_ids=token_type_ids) 78 | embeddings = outputs.last_hidden_state.detach().cpu().numpy() 79 | 80 | self.query_embs = np.average(embeddings, axis=1, weights=effective_token_weights).astype(np.float32) 81 | 82 | # Generate reformulated query with term weight 83 | query_token_weights = np.squeeze(LA.norm(embeddings[:,1:,:], axis=-1)) 84 | 85 | rewrite_text = self.build_query(input_tokens, query_token_weights, self.l2_threshold) 86 | if context: 87 | self.history += [context] 88 | 89 | self.total_latency += time.time() - start_time 90 | return rewrite_text 91 | 92 | def reset_history(self): 93 | super().reset_history() 94 | self.history_query = [] 95 | self.history = [] 96 | 97 | def build_query(self, query_tokens, query_token_weights, threshold): 98 | context_token_num = self.context_length 99 | mean_l2 = query_token_weights.mean() 100 | 101 | should = querybuilder.JBooleanClauseOccur['should'].value 102 | boolean_query_builder = querybuilder.get_boolean_query_builder() 103 | is_context = True 104 | term_weight = 0 105 | term ='' 106 | 107 | for i, token in enumerate(query_tokens): 108 | if token == '[MASK]': 109 | 110 | continue 111 | if '##' in token: 112 | term += token[2:] 113 | term_weight = max(term_weight, query_token_weights[i]) 114 | else: 115 | if ( (term_weight > threshold) or (i>=context_token_num)): #10.5, 12 116 | try: 117 | term = querybuilder.get_term_query(term) 118 | boost = querybuilder.get_boost_query(term, term_weight/mean_l2) 119 | boolean_query_builder.add(boost, should) 120 | except: 121 | pass 122 | 123 | term = token 124 | term_weight = query_token_weights[i] 125 | 126 | return boolean_query_builder.build() 127 | 128 | -------------------------------------------------------------------------------- /docs/conversation_dense_retrieval_experiments.md: -------------------------------------------------------------------------------- 1 | # Experiments for Conversational Dense Retrieval 2 | This is the 
Chatty Goose reproduction of our paper *[Contextualized Query Embeddings for Conversational Search](https://arxiv.org/abs/2104.08707)* by Sheng-Chieh Lin, Jheng-Hong Yang, and Jimmy Lin. Note that, to reduce query latency, this repo uses an HNSW dense index, unlike our original paper and the [CQE GitHub repository](https://github.com/castorini/CQE); thus, the results may be slightly different. 3 | ## Data Preparation 4 | 5 | 1. Download the input query JSON files for [CAsT 2019](https://github.com/daltonj/treccastweb/blob/master/2019/data/evaluation/evaluation_topics_v1.0.json) and [CAsT 2020](https://github.com/daltonj/treccastweb/blob/master/2020/2020_manual_evaluation_topics_v1.0.json). These files can be found under `data/treccastweb/2019/data` and `data/treccastweb/2020`, respectively, if you cloned the submodules for this repo. 6 | 7 | Store the path in a variable. Here we use CAsT 2020 as an example. 8 | default: 9 | ```shell=bash 10 | export input_query_json=./data/treccastweb/2020/2020_manual_evaluation_topics_v1.0.json 11 | ``` 12 | 13 | 2. Download the evaluation answer (qrels) files for [CAsT 2019](https://trec.nist.gov/data/cast/2019qrels.txt) and [CAsT 2020](https://trec.nist.gov/data/cast/2020qrels.txt), and store the path of the file you evaluate against in the variable `qrel` (used as `$qrel` below). 14 | 15 | ## Run CQE retrieval 16 | 17 | The following command is for CQE, but you can also run other CQR methods using `t5` or `cqe_t5_fusion` as the input to the `--experiment` flag (currently, dense retrieval does not support HQE since it requires longer query sequences). Running the command for the first time will download the CAsT 2019 index (or whatever index is specified for the `--sparse_index` flag). It is also possible to supply a path to a local directory containing the index. Note that `--cqe_l2_threshold` controls which terms CQE includes in the text query used for sparse search. For CQE BM25 retrieval, we use the default setting, and for CQE hybrid search, we set `--cqe_l2_threshold=12`.
18 | ### CQE BM25 Retrieval 19 | 20 | ```shell=bash 21 | python -m experiments.run_retrieval \ 22 | --experiment cqe \ 23 | --sparse_index cast2019 \ 24 | --hits 1000 \ 25 | --qid_queries $input_query_json \ 26 | --output ./output/cqe_bm25 \ 27 | 28 | python -m pyserini.eval.trec_eval -c -mndcg_cut.3,1 -mrecall.1000 -mmap $qrel ./output/cqe_bm25.trec 29 | ``` 30 | ### CQE Dense Retrieval 31 | ```shell=bash 32 | python -m experiments.run_retrieval \ 33 | --experiment cqe \ 34 | --dense_index cast2019-tct_colbert-v2-hnsw \ 35 | --hits 1000 \ 36 | --qid_queries $input_query_json \ 37 | --output ./output/cqe_dpr \ 38 | 39 | python -m pyserini.eval.trec_eval -c -mndcg_cut.3,1 -mrecall.1000 -mmap $qrel ./output/cqe_dpr.trec 40 | ``` 41 | ### CQE Sparse-Dense Hybrid Retrieval 42 | ```shell=bash 43 | python -m experiments.run_retrieval \ 44 | --experiment cqe \ 45 | --cqe_l2_threshold 12 \ 46 | --dense_index cast2019-tct_colbert-v2-hnsw \ 47 | --sparse_index cast2019 \ 48 | --hits 1000 \ 49 | --qid_queries $input_query_json \ 50 | --output ./output/cqe_hybrid \ 51 | 52 | python -m pyserini.eval.trec_eval -c -mndcg_cut.3,1 -mrecall.1000 -mmap $qrel ./output/cqe_hybrid.trec 53 | ``` 54 | ### CQE fuse T5 Sparse-Dense Hybrid Retrieval 55 | ```shell=bash 56 | python -m experiments.run_retrieval \ 57 | --experiment cqe_t5_fusion \ 58 | --cqe_l2_threshold 12 \ 59 | --dense_index cast2019-tct_colbert-v2-hnsw \ 60 | --sparse_index cast2019 \ 61 | --hits 1000 \ 62 | --qid_queries $input_query_json \ 63 | --output ./output/cqe_t5_hybrid \ 64 | 65 | python -m pyserini.eval.trec_eval -c -mndcg_cut.3,1 -mrecall.1000 -mmap $qrel ./output/cqe_t5_hybrid.trec 66 | ``` 67 | 68 | ## Evaluation results 69 | 70 | Results for the CAsT 2019 and 2020 evaluation dataset are provided below. The results may be slightly different from the numbers reported in the paper due to implementation differences between Huggingface and SpaCy versions. 
As of writing, we use `spacy==2.2.4` with the English model `en_core_web_sm==2.2.5`, and `transformers==4.0.0`. Note that the Recall@1000 numbers reported in the [CQE paper](https://arxiv.org/abs/2104.08707) count passages with relevance greater than 2 as relevant, whereas this repo, to be consistent with previous experiments, counts relevance greater than 1. 71 | 72 | | CAsT2019 | CQE BM25 | CQE Dense Retrieval | CQE Hybrid | T5 BM25 | T5 Dense Retrieval | T5 Hybrid | CQE+T5 Hybrid Fusion | 73 | | ----------- | :------: | :-------------: | :-------------: | :-----: | :------------: | :---------: | :----------------: | 74 | | mAP | 0.2059 | 0.2616 | 0.2997 | 0.2250 | 0.2512 | 0.3043 | 0.3391 | 75 | | Recall@1000 | 0.7705 | 0.7248 | 0.7984 | 0.7392 | 0.6734 | 0.7856 | 0.8376 | 76 | | NDCG@1 | 0.3030 | 0.5082 | 0.4971 | 0.2842 | 0.4841 | 0.5077 | 0.5318 | 77 | | NDCG@3 | 0.2740 | 0.4924 | 0.5032 | 0.2954 | 0.4688 | 0.5065 | 0.5226 | 78 | 79 | | CAsT2020 | CQE BM25 | CQE Dense Retrieval | CQE Hybrid | T5 BM25 | T5 Dense Retrieval | T5 Hybrid | CQE+T5 Hybrid Fusion | 80 | | ----------- | :------: | :-------------: | :-------------: | :-----: | :------------: | :---------: | :----------------: | 81 | | mAP | 0.1301 | 0.2072 | 0.2400 | 0.1236 | 0.1989 | 0.2309 | 0.2495 | 82 | | Recall@1000 | 0.6097 | 0.7008 | 0.7410 | 0.5551 | 0.6380 | 0.6983 | 0.7638 | 83 | | NDCG@1 | 0.1875 | 0.3429 | 0.3794 | 0.1639 | 0.3538 | 0.3742 | 0.3982 | 84 | | NDCG@3 | 0.1712 | 0.3122 | 0.3383 | 0.1620 | 0.3182 | 0.3323 | 0.3599 | 85 | 86 | ## Add canonical response 87 | For CAsT 2020, you can also add canonical responses to the context with the `--add_response` option, which specifies how many previous responses to add.
88 | ```shell=bash 89 | python -m experiments.run_retrieval \ 90 | --experiment t5 \ 91 | --cqe_l2_threshold 12 \ 92 | --add_response 1 \ 93 | --dense_index cast2019-tct_colbert-v2-hnsw \ 94 | --sparse_index cast2019 \ 95 | --hits 1000 \ 96 | --qid_queries $input_query_json \ 97 | --output ./output/t5_hybrid \ 98 | 99 | python -m pyserini.eval.trec_eval -c -mndcg_cut.3,1 -mrecall.1000 -mmap $qrel ./output/t5_hybrid.trec 100 | ``` 101 | | CAsT2020 | T5 Hybrid | 102 | | ----------- | :----------------: | 103 | | mAP | 0.2333 | 104 | | Recall@1000 | 0.6930 | 105 | | NDCG@1 | 0.3934 | 106 | | NDCG@3 | 0.3406 | 107 | 108 | ## Reproduction Log 109 | 110 | -------------------------------------------------------------------------------- /chatty_goose/cqr/hqe.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import re 3 | import time 4 | from typing import Optional 5 | 6 | import spacy 7 | from chatty_goose.settings import HqeSettings 8 | from .cqr import ConversationalQueryRewriter 9 | 10 | import sys 11 | sys.path.append('../pyserini') 12 | from pyserini.search import SimpleSearcher 13 | 14 | 15 | name = "Hqe" 16 | __all__ = [name] 17 | 18 | nlp = spacy.load("en_core_web_sm") 19 | STOP_WORDS = nlp.Defaults.stop_words 20 | 21 | 22 | class Hqe(ConversationalQueryRewriter): 23 | """Historical Query Expansion for conversational query reformulation""" 24 | 25 | def __init__(self, searcher: SimpleSearcher, settings: HqeSettings = HqeSettings()): 26 | super().__init__(name, verbose=settings.verbose) 27 | self.name = name 28 | # Model settings 29 | self.M = settings.M 30 | self.eta = settings.eta 31 | self.R_topic = settings.R_topic 32 | self.R_sub = settings.R_sub 33 | self.pos_filter = settings.filter 34 | self.searcher = searcher 35 | 36 | # History 37 | self.q_key_word_list = collections.defaultdict(list) 38 | self.q_subkey_word_list = collections.defaultdict(list) 39 | 40 | self.r_key_word_list = 
collections.defaultdict(list) 41 | self.r_subkey_word_list = collections.defaultdict(list) 42 | 43 | def rewrite(self, query: str, context: Optional[str] = None, response_num: Optional[int] = 0) -> str: 44 | start_time = time.time() 45 | self.turn_id += 1 46 | # self.history_query += [query] 47 | # if response_num!=0: 48 | # src_text = " ".join(self.history_query[:-response_num] + self.history[-2*response_num:]) 49 | # else: 50 | # src_text = " ".join(self.history_query) 51 | 52 | self.key_word_extraction(query, 'q') 53 | if (response_num > 0): 54 | self.key_word_extraction(context, 'r') 55 | if self.turn_id>0: 56 | hits = self.searcher.search(query, 1) 57 | if response_num>0: 58 | key_word = self.query_expansion(self.q_key_word_list, 0, self.turn_id) + " "+ self.query_expansion(self.r_key_word_list, self.turn_id-response_num-1, self.turn_id-1) 59 | else: 60 | key_word = self.query_expansion(self.q_key_word_list, 0, self.turn_id) 61 | subkey_word = "" 62 | if len(hits) == 0 or hits[0].score <= self.eta: 63 | end_turn = self.turn_id + 1 64 | start_turn = end_turn - self.M 65 | if start_turn < 0: 66 | start_turn = 0 67 | if response_num>0: 68 | subkey_word = self.query_expansion(self.q_subkey_word_list, start_turn, end_turn)+ " " + self.query_expansion(self.r_subkey_word_list, self.turn_id-response_num-1, self.turn_id-1) 69 | else: 70 | subkey_word = self.query_expansion(self.q_subkey_word_list, start_turn, end_turn) 71 | query = key_word + " " + subkey_word + " " + query 72 | self.total_latency += time.time() - start_time 73 | 74 | return query 75 | 76 | def reset_history(self): 77 | super().reset_history() 78 | self.q_key_word_list = collections.defaultdict(list) 79 | self.q_subkey_word_list = collections.defaultdict(list) 80 | self.r_key_word_list = collections.defaultdict(list) 81 | self.r_subkey_word_list = collections.defaultdict(list) 82 | 83 | 84 | def key_word_extraction(self, query, content): 85 | proc_query = self.calc_word_score(query) 86 | # Extract 
topic keyword 87 | if self.pos_filter == "no": 88 | for i, word in enumerate(proc_query["word"]): 89 | if proc_query["score"][i] >= self.R_topic: 90 | if content=='r': 91 | self.r_key_word_list[self.turn_id].append(word) 92 | if content=='q': 93 | self.q_key_word_list[self.turn_id].append(word) 94 | if (proc_query["score"][i] >= self.R_sub) & ( 95 | proc_query["score"][i] < self.R_topic 96 | ): 97 | self.subkey_word_list[self.turn_id].append(word) 98 | elif self.pos_filter == "pos": 99 | for i, word in enumerate(proc_query["word"]): 100 | if ("NN" in proc_query["pos"][i]) or ("JJ" in proc_query["pos"][i]): 101 | if proc_query["score"][i] >= self.R_topic: 102 | if content=='r': 103 | self.r_key_word_list[self.turn_id].append(word) 104 | if content=='q': 105 | self.q_key_word_list[self.turn_id].append(word) 106 | if (proc_query["score"][i] >= self.R_sub) & ( 107 | proc_query["score"][i] < self.R_topic 108 | ): 109 | if content=='r': 110 | self.r_subkey_word_list[self.turn_id].append(word) 111 | if content=='q': 112 | self.q_subkey_word_list[self.turn_id].append(word) 113 | elif self.pos_filter == "stp": 114 | for i, word in enumerate(proc_query["word"]): 115 | if word not in STOP_WORDS: 116 | if proc_query["score"][i] >= self.R_topic: 117 | if content=='r': 118 | self.r_key_word_list[self.turn_id].append(word) 119 | if content=='q': 120 | self.q_key_word_list[self.turn_id].append(word) 121 | if (proc_query["score"][i] >= self.R_sub) & ( 122 | proc_query["score"][i] < self.R_topic 123 | ): 124 | if content=='r': 125 | self.r_subkey_word_list[self.turn_id].append(word) 126 | if ontent=='q': 127 | self.q_key_word_list[self.turn_id].append(word) 128 | 129 | def calc_word_score(self, query): 130 | nlp_query = nlp(pre_process(query)) 131 | proc_query = process(nlp_query) 132 | query_words = proc_query["word"] 133 | proc_query["score"] = [] 134 | 135 | for query_word in query_words: 136 | hits = self.searcher.search(query_word, 1) 137 | try: 138 | score = hits[0].score 139 
| proc_query["score"].append(score) 140 | except: 141 | proc_query["score"].append(-1) 142 | 143 | return proc_query 144 | 145 | @staticmethod 146 | def query_expansion(key_word_list, start_turn, end_turn): 147 | query_expansion = "" 148 | if start_turn<0: 149 | start_turn=0 150 | for turn in range(start_turn, end_turn + 1): 151 | for word in key_word_list[turn]: 152 | query_expansion = query_expansion + " " + word 153 | return query_expansion 154 | 155 | 156 | def pre_process(text): 157 | text = re.sub( 158 | u"-|\u2010|\u2011|\u2012|\u2013|\u2014|\u2015|%|\[|\]|:|\(|\)|/|\t", 159 | _space_extend, 160 | text, 161 | ) 162 | text = text.strip(" \n") 163 | text = re.sub("\s+", " ", text) 164 | return text 165 | 166 | 167 | def process(parsed_text): 168 | output = { 169 | "word": [], 170 | "lemma": [], 171 | "pos": [], 172 | "pos_id": [], 173 | "ent": [], 174 | "ent_id": [], 175 | "offsets": [], 176 | "sentences": [], 177 | } 178 | 179 | for token in parsed_text: 180 | output["word"].append(_str(token.text)) 181 | pos = token.tag_ 182 | output["pos"].append(pos) 183 | 184 | return output 185 | 186 | 187 | def _space_extend(matchobj): 188 | return " " + matchobj.group(0) + " " 189 | 190 | 191 | def _str(s): 192 | """ Convert PTB tokens to normal tokens """ 193 | if s.lower() == "-lrb-": 194 | s = "(" 195 | elif s.lower() == "-rrb-": 196 | s = ")" 197 | elif s.lower() == "-lsb-": 198 | s = "[" 199 | elif s.lower() == "-rsb-": 200 | s = "]" 201 | elif s.lower() == "-lcb-": 202 | s = "{" 203 | elif s.lower() == "-rcb-": 204 | s = "}" 205 | return s 206 | -------------------------------------------------------------------------------- /chatty_goose/pipeline/retrieval_pipeline.py: -------------------------------------------------------------------------------- 1 | import sys 2 | # sys.path.append('../pyserini') # Here we import pyserini from our local, since the current version pyserini did not support dense retrieval with embedding as input (See line 67) 3 | from os import 
from os import path
import logging
import json
from typing import List, Optional, Union

from pyserini.hsearch import HybridSearcher
from pyserini.dsearch import SimpleDenseSearcher, DenseSearchResult
from chatty_goose.cqr import ConversationalQueryRewriter
from chatty_goose.util import reciprocal_rank_fusion
from pygaggle.rerank.base import Query, Reranker, hits_to_texts
from pyserini.search import JSimpleSearcherResult, SimpleSearcher
import spacy


__all__ = ["RetrievalPipeline"]


class RetrievalPipeline:
    """
    End-to-end conversational passage retrieval pipeline

    Parameters:
        searcher (SimpleSearcher): Pyserini searcher for Lucene index; None disables sparse search
        dense_searcher (SimpleDenseSearcher): Pyserini dense searcher; None disables dense search
        reformulators (List[ConversationalQueryRewriter]): List of CQR methods to use for first-stage retrieval
        searcher_num_hits (int): number of hits returned by searcher - default 10
        early_fusion (bool): flag to perform fusion before second-stage retrieval - default True
        reranker (Reranker): optional reranker for second-stage retrieval
        reranker_query_index (int): retriever index to use for reranking query - defaults to last retriever
        reranker_query_reformulator (ConversationalQueryRewriter): CQR method for generating reranker query,
            overrides reranker_query_index if provided
        add_response (int): forwarded to every reformulator's rewrite(); a value > 0 also enables
            get_context() - default 0
        context_index_path (str): local index directory or prebuilt index name used by get_context()
            when no sparse searcher is given
    """

    def __init__(
        self,
        searcher: Optional[SimpleSearcher],
        dense_searcher: Optional[SimpleDenseSearcher],
        reformulators: List[ConversationalQueryRewriter],
        searcher_num_hits: int = 10,
        early_fusion: bool = True,
        reranker: Optional[Reranker] = None,
        reranker_query_index: int = -1,
        reranker_query_reformulator: Optional[ConversationalQueryRewriter] = None,
        add_response: int = 0,
        context_index_path: Optional[str] = None,
    ):
        self.searcher = searcher
        self.dense_searcher = dense_searcher
        self.reformulators = reformulators
        # Command-line callers may hand this over as a string
        self.searcher_num_hits = int(searcher_num_hits)
        self.early_fusion = early_fusion
        self.reranker = reranker
        self.reranker_query_index = reranker_query_index
        self.reranker_query_reformulator = reranker_query_reformulator
        self.add_response = add_response
        if add_response > 0:
            # Sentence splitter used by get_context() to truncate context passages
            self.nlp = spacy.load("en_core_web_sm")
            self.nlp.add_pipe(self.nlp.create_pipe("sentencizer"))

        if self.searcher is not None:
            # The sparse index doubles as the source of context passages
            self.context_searcher = self.searcher
        elif add_response > 0:
            if context_index_path is None:
                raise ValueError("No context index path")
            logging.info("We do not conduct for sparse search. Load another index: {}, for context search ...".format(context_index_path))
            if path.isdir(context_index_path):
                self.context_searcher = SimpleSearcher(context_index_path)
            else:
                self.context_searcher = SimpleSearcher.from_prebuilt_index(context_index_path)
        else:
            # get_context() returns None whenever add_response <= 0, so loading
            # (or downloading) a context index here would be wasted work
            self.context_searcher = None

    def retrieve(self, query, context: Optional[str] = None) -> List[JSimpleSearcherResult]:
        """
        Rewrite *query* with every reformulator, search, then fuse and/or rerank the hits.

        Parameters:
            query: raw utterance of the current turn
            context: optional passage text forwarded to each reformulator's rewrite()

        Returns:
            Ranked hits. Without a reranker the fused list is truncated to searcher_num_hits.
        """
        cqr_hits = []
        cqr_queries = []
        for cqr in self.reformulators:
            sparse_hits, dense_hits = None, None
            new_query = cqr.rewrite(query, context, self.add_response)
            # Sparse search
            if self.searcher is not None:
                sparse_hits = self.searcher.search(new_query, k=self.searcher_num_hits)
            # Dense search
            if self.dense_searcher is not None:
                if cqr.name == 'Cqe':
                    # CQE embeddings are produced during query rewriting, so they are
                    # fed to the dense searcher directly instead of the rewritten text
                    dense_hits = self.dense_searcher.search(cqr.query_embs, k=self.searcher_num_hits)
                else:
                    dense_hits = self.dense_searcher.search(new_query, k=self.searcher_num_hits)

            # alpha=0.1 scales the sparse score (weight_on_dense defaults to False)
            hits = self._hybrid_results(dense_hits, sparse_hits, 0.1, self.searcher_num_hits)
            cqr_hits.append(hits)
            cqr_queries.append(new_query)

        # Merge results from multiple CQR methods if required
        if self.early_fusion or self.reranker is None:
            cqr_hits = reciprocal_rank_fusion(cqr_hits)

        # Return results if no reranker
        if self.reranker is None:
            return cqr_hits[:self.searcher_num_hits]

        # Get query for reranker
        if self.reranker_query_reformulator is None:
            rerank_query = cqr_queries[self.reranker_query_index]
        else:
            rerank_query = self.reranker_query_reformulator.rewrite(query)

        if self.early_fusion:
            return self.rerank(rerank_query, cqr_hits[:self.searcher_num_hits])

        # Late fusion: rerank each CQR method's hit list on its own, then fuse the
        # reranked lists (reciprocal_rank_fusion takes a list of hit lists)
        reranked_lists = [self.rerank(rerank_query, hits) for hits in cqr_hits]
        return reciprocal_rank_fusion(reranked_lists)

    def rerank(self, query, hits):
        """
        Rescore *hits* against *query* with the reranker and return them sorted by that score.

        Each hit's ``score`` attribute is overwritten in place with its reranker score.
        Returns *hits* unchanged when no reranker is configured.
        """
        if self.reranker is None:
            logging.info("Reranker not available, skipping reranking")
            return hits

        reranked = self.reranker.rerank(Query(query), hits_to_texts(hits))
        reranked_scores = [r.score for r in reranked]

        # NOTE(review): pairing by position assumes the reranker returns texts in
        # input order -- confirm against the pygaggle Reranker in use
        scored = sorted(zip(hits, reranked_scores), key=lambda pair: pair[1], reverse=True)
        reranked_hits = []
        for hit, score in scored:
            hit.score = score  # replace the first-stage score with the reranker score
            reranked_hits.append(hit)
        return reranked_hits

    def reset_history(self):
        """Clear the conversation history of every reformulator, including the reranker's."""
        for cqr in self.reformulators:
            cqr.reset_history()

        if self.reranker_query_reformulator:
            self.reranker_query_reformulator.reset_history()

    def get_context(self, docid: Union[str, int], sent_num=1) -> Optional[str]:
        """
        Return the first *sent_num* sentences of the stored raw text of *docid*, space-joined.

        Returns None when context is disabled (add_response <= 0), the document is not
        in the context index, or the document has no stored raw text.
        """
        if self.add_response <= 0 or self.context_searcher is None:
            return None
        # SimpleSearcher.doc() returns None for an unknown docid, so check it
        # before calling .raw() on it
        document = self.context_searcher.doc(docid)
        if document is None:
            return None
        raw = document.raw()
        if raw is None:
            return None
        # Span.text works on spaCy v2 and v3 (Span.string was removed in v3)
        sentences = [sent.text.strip() for sent in self.nlp(raw).sents]
        return ' '.join(sentences[:sent_num])

    # Adapted from pyserini (https://github.com/castorini/pyserini/blob/master/pyserini/hsearch/_hybrid.py)
    @staticmethod
    def _hybrid_results(dense_results, sparse_results, alpha, k, normalization=False, weight_on_dense=False):
        """
        Linearly combine dense and sparse hit lists and return the top *k* by combined score.

        A document missing from one list takes that list's minimum observed score.
        If only one list is given it is truncated to *k*; if neither is given, [] is returned.
        """
        if dense_results is None and sparse_results is None:
            return []
        if dense_results is None:
            return sparse_results[:k]
        if sparse_results is None:
            return dense_results[:k]

        dense_hits = {hit.docid: hit.score for hit in dense_results}
        sparse_hits = {hit.docid: hit.score for hit in sparse_results}
        min_dense_score = min(dense_hits.values()) if dense_hits else 0
        max_dense_score = max(dense_hits.values()) if dense_hits else 1
        min_sparse_score = min(sparse_hits.values()) if sparse_hits else 0
        max_sparse_score = max(sparse_hits.values()) if sparse_hits else 1
        dense_range = max_dense_score - min_dense_score
        sparse_range = max_sparse_score - min_sparse_score

        hybrid_result = []
        for doc in set(dense_hits) | set(sparse_hits):
            sparse_score = sparse_hits.get(doc, min_sparse_score)
            dense_score = dense_hits.get(doc, min_dense_score)
            if normalization:
                # Center on the mid-range and scale by the range; a zero range
                # (all scores equal) maps to 0 instead of dividing by zero
                sparse_score = (
                    (sparse_score - (min_sparse_score + max_sparse_score) / 2) / sparse_range
                    if sparse_range else 0.0
                )
                dense_score = (
                    (dense_score - (min_dense_score + max_dense_score) / 2) / dense_range
                    if dense_range else 0.0
                )
            score = alpha * sparse_score + dense_score if not weight_on_dense else sparse_score + alpha * dense_score
            hybrid_result.append(DenseSearchResult(doc, score))
        return sorted(hybrid_result, key=lambda x: x.score, reverse=True)[:k]
# -------------------------------------------------------------------------------- /experiments/run_retrieval.py: -------------------------------------------------------------------------------- 1 | # Now we import pyserini directly from folder
# Now we import pyserini directly from folder since input embeddings to SimpleDenseSearcher.search are not upgraded
import sys
# sys.path.append('../pyserini')
import argparse
import json
import time

from chatty_goose.cqr import Hqe, Ntr, Cqe
from chatty_goose.pipeline import RetrievalPipeline
from chatty_goose.settings import SearcherSettings, DenseSearcherSettings, HqeSettings, NtrSettings, CqeSettings
from chatty_goose.types import CqrType, PosFilter
from chatty_goose.util import build_bert_reranker, build_searcher, build_dense_searcher

from pyserini.search import SimpleSearcher
from pyserini.dsearch import SimpleDenseSearcher


def parse_experiment_args():
    """
    Parse the command-line options of a CAsT retrieval experiment.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='CQR experiments for CAsT 2019.')
    # Required and restricted to CqrType values: the run name and CqrType(...) both need it
    parser.add_argument('--experiment', type=str, required=True, choices=[t.value for t in CqrType],
                        help='Type of experiment (cqe, hqe, t5, hqe_t5_fusion, cqe_t5_fusion)')
    parser.add_argument('--qid_queries', required=True, default='', help='query id - query mapping file')
    parser.add_argument('--output', required=True, default='', help='output file')
    parser.add_argument('--sparse_index', default=None, help='bm25 index path')
    parser.add_argument('--dense_index', default=None, help='dense index path')
    parser.add_argument('--context_index', default='cast2019', help='index for searching context text')
    parser.add_argument('--query_encoder', default='castorini/tct_colbert-v2-msmarco', help='query encoder model path')
    parser.add_argument('--hits', default=10, type=int, help='number of hits to retrieve')
    parser.add_argument('--rerank', action='store_true', help='rerank BM25 output using BERT')
    parser.add_argument('--reranker_device', default='cuda', help='reranker device to use')
    parser.add_argument('--late_fusion', action='store_true', help='perform late instead of early fusion')
    parser.add_argument('--verbose', action='store_true', help='verbose log output')
    parser.add_argument('--context_field', default='manual_canonical_result_id', help='doc id for additional context')
    parser.add_argument('--add_response', type=int, default=0, help='How many response to add in context')
    parser.add_argument('--run_name', type=str, default=None, help='run file name printed in trec file')

    # Parameters for BM25. See Anserini MS MARCO documentation to understand how these parameter values were tuned
    parser.add_argument('--k1', default=0.82, type=float, help='BM25 k1 parameter')
    parser.add_argument('--b', default=0.68, type=float, help='BM25 b parameter')
    parser.add_argument('--rm3', action='store_true', default=False, help='use RM3')
    parser.add_argument('--fb_terms', default=10, type=int, help='RM3 parameter: number of expansion terms')
    parser.add_argument('--fb_docs', default=10, type=int, help='RM3 parameter: number of documents')
    parser.add_argument('--original_query_weight', default=0.8, type=float, help='RM3 parameter: weight to assign to the original query')

    # Parameters for HQE. The default values are tuned on CAsT train data
    parser.add_argument('--M0', default=5, type=int, help='aggregate historcial queries for first stage (BM25) retrieval')
    parser.add_argument('--M1', default=1, type=int, help='aggregate historcial queries for second stage (BERT) retrieval')
    parser.add_argument('--eta0', default=10, type=float, help='QPP threshold for first stage (BM25) retrieval')
    parser.add_argument('--eta1', default=12, type=float, help='QPP threshold for second stage (BERT) retrieval')
    parser.add_argument('--R0_topic', default=4.5, type=float, help='Topic keyword threshold for first stage (BM25) retrieval')
    parser.add_argument('--R1_topic', default=4, type=float, help='Topic keyword threshold for second stage (BERT) retrieval')
    parser.add_argument('--R0_sub', default=3.5, type=float, help='Subtopic keyword threshold for first stage (BM25) retrieval')
    parser.add_argument('--R1_sub', default=3, type=float, help='Subtopic keyword threshold for second stage (BERT) retrieval')
    parser.add_argument('--filter', default='pos', help='filter word method (no, pos, stp')

    # Parameters for T5
    parser.add_argument('--t5_model_name', default='castorini/t5-base-canard', help='T5 model name')
    parser.add_argument('--max_length', default=64, type=int, help='T5 max sequence length')
    parser.add_argument('--num_beams', default=10, type=int, help='T5 number of beams')
    # store_true: passing the flag disables early stopping; it is enabled by default.
    # (store_false combined with `not args.no_early_stopping` inverted the meaning.)
    parser.add_argument('--no_early_stopping', action='store_true', help='T5 disable early stopping')
    parser.add_argument('--t5_device', default='cuda', help='T5 device to use')

    # Parameters for CQE
    parser.add_argument('--cqe_model_name', default='castorini/tct_colbert-v2-msmarco-cqe', help='CQE model name')
    parser.add_argument('--cqe_l2_threshold', default=10.5, type=float, help='Term weight threashold for select terms')
    parser.add_argument('--cqe_max_context_length', default=100, type=int, help='CQE max context length')
    parser.add_argument('--cqe_max_query_length', default=36, type=int, help='CQE max query length')
    parser.add_argument('--cqe_device', default='cpu', help='CQE device to use')

    # Return args
    return parser.parse_args()


def run_experiment(rp: RetrievalPipeline, experiment_args=None):
    """
    Retrieve hits for every turn of every session and write them as a TREC run file.

    Parameters:
        rp: configured retrieval pipeline; its history is reset after each session
        experiment_args: parsed options; defaults to the module-level ``args``
            assigned by the ``__main__`` block

    Writes ``<output>.trec`` and prints per-session and overall latency.
    """
    opts = experiment_args if experiment_args is not None else args

    # Load the queries before opening the run file so a bad input path does not
    # truncate an existing run
    with open(opts.qid_queries) as json_file:
        data = json.load(json_file)

    total_query_count = 0
    initial_time = time.time()
    with open(opts.output + ".trec", "w") as fout:
        for session in data:
            session_num = str(session["number"])
            turns = session["turn"]
            start_time = time.time()

            for conversation in turns:
                query = conversation["raw_utterance"]
                total_query_count += 1
                qid = session_num + "_" + str(conversation["number"])

                context = None
                if opts.add_response != 0:
                    context = rp.get_context(conversation[opts.context_field])
                # The context is handed to the reformulators together with this turn's query
                hits = rp.retrieve(query, context)

                for rank, hit in enumerate(hits, start=1):
                    fout.write("{} Q0 {} {} {} {}\n".format(qid, hit.docid, rank, hit.score, opts.run_name))

            rp.reset_history()
            num_turns = len(turns)
            # Guard against sessions without turns (previously NameError / stale turn id)
            time_per_query = (time.time() - start_time) / num_turns if num_turns else 0.0
            print(
                "Retrieving session {} with {} queries ({:0.3f} s/query)".format(
                    session["number"], num_turns, time_per_query
                )
            )

    time_per_query = (time.time() - initial_time) / total_query_count if total_query_count else 0.0
    qr_total_time = sum(reformulator.total_latency for reformulator in rp.reformulators)
    qr_time_per_query = qr_total_time / total_query_count if total_query_count else 0.0
    print(
        "Retrieving {} queries ({:0.3f} s/query, QR {:0.3f} s/query)".format(
            total_query_count, time_per_query, qr_time_per_query
        )
    )

    print("total Query Counts %d" % (total_query_count))
    print("Done!")


if __name__ == "__main__":
    args = parse_experiment_args()
    # Explicit checks instead of assert, which is stripped under `python -O`
    if args.sparse_index is None and args.dense_index is None:
        sys.exit("Must input at least one index for search")
    if args.sparse_index is None:
        if args.context_index is None and args.add_response != 0:
            sys.exit("Must input argument context_index")
    else:
        # The sparse index also serves the context passages
        args.context_index = args.sparse_index
    if args.run_name is None:
        args.run_name = 'chatty-goose_' + args.experiment
    experiment = CqrType(args.experiment)

    uses_hqe = experiment in (CqrType.HQE, CqrType.HQE_T5_FUSION)
    uses_t5 = experiment in (CqrType.T5, CqrType.HQE_T5_FUSION, CqrType.CQE_T5_FUSION)
    uses_cqe = experiment in (CqrType.CQE, CqrType.CQE_T5_FUSION)

    searcher_settings = SearcherSettings(
        index_path=args.sparse_index,
        k1=args.k1,
        b=args.b,
        rm3=args.rm3,
        fb_terms=args.fb_terms,
        fb_docs=args.fb_docs,
        original_query_weight=args.original_query_weight,
    )

    if uses_hqe and args.dense_index is not None:
        # Currently, dense retrieval does not support HQE since it requires longer query sequence
        sys.exit("HQE does not support dense retrieval. Do not input dense index while using HQE.")
    dense_searcher_settings = DenseSearcherSettings(
        index_path=args.dense_index,
        query_encoder=args.query_encoder,
    )

    searcher = build_searcher(searcher_settings)
    dense_searcher = build_dense_searcher(dense_searcher_settings)

    # Initialize CQR and reranker
    reformulators = []
    reranker_query_reformulator = None
    reranker = build_bert_reranker(device=args.reranker_device) if args.rerank else None

    if uses_hqe:
        hqe_bm25_settings = HqeSettings(
            M=args.M0,
            eta=args.eta0,
            R_topic=args.R0_topic,
            R_sub=args.R0_sub,
            filter=PosFilter(args.filter),
            verbose=args.verbose,
        )
        reformulators.append(Hqe(searcher, hqe_bm25_settings))

    if uses_t5:
        # Initialize T5 NTR
        t5_settings = NtrSettings(
            model_name=args.t5_model_name,
            max_length=args.max_length,
            num_beams=args.num_beams,
            early_stopping=not args.no_early_stopping,
            verbose=args.verbose,
        )
        reformulators.append(Ntr(t5_settings, device=args.t5_device))

    if experiment == CqrType.HQE:
        # A separately tuned HQE produces the query used by the second-stage reranker
        hqe_bert_settings = HqeSettings(
            M=args.M1,
            eta=args.eta1,
            R_topic=args.R1_topic,
            R_sub=args.R1_sub,
            filter=PosFilter(args.filter),
        )
        reranker_query_reformulator = Hqe(searcher, hqe_bert_settings)

    if uses_cqe:
        cqe_settings = CqeSettings(
            model_name=args.cqe_model_name,
            l2_threshold=args.cqe_l2_threshold,
            max_context_length=args.cqe_max_context_length,
            max_query_length=args.cqe_max_query_length,
            verbose=args.verbose,
        )
        reformulators.append(Cqe(cqe_settings, device=args.cqe_device))

    rp = RetrievalPipeline(
        searcher,
        dense_searcher,
        reformulators,
        searcher_num_hits=args.hits,
        early_fusion=not args.late_fusion,
        reranker=reranker,
        reranker_query_reformulator=reranker_query_reformulator,
        add_response=args.add_response,
        context_index_path=args.context_index,
    )
    run_experiment(rp, args)
# -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files.
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2021 Anserini Gaggle 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | --------------------------------------------------------------------------------