├── convinse
│   ├── __init__.py
│   ├── library
│   │   ├── __init__.py
│   │   ├── wikipedia_library.py
│   │   ├── utils.py
│   │   └── custom_trainer.py
│   ├── distant_supervision
│   │   ├── __init__.py
│   │   ├── README.md
│   │   ├── turn_relevance_annotator.py
│   │   └── structured_representation_annotator.py
│   ├── evidence_retrieval_scoring
│   │   ├── wikipedia_retriever
│   │   │   ├── __init__.py
│   │   │   ├── text_parser.py
│   │   │   ├── table_parser.py
│   │   │   ├── infobox_parser.py
│   │   │   └── wikipedia_retriever.py
│   │   ├── bm25_es.py
│   │   ├── clocq_bm25.py
│   │   ├── evidence_retrieval_scoring.py
│   │   ├── README.md
│   │   └── clocq_er.py
│   ├── heterogeneous_answering
│   │   ├── heterogeneous_answering.py
│   │   ├── README.md
│   │   └── fid_module
│   │       ├── fid_utils.py
│   │       └── fid_module.py
│   ├── question_understanding
│   │   ├── question_understanding.py
│   │   ├── README.md
│   │   ├── question_rewriting
│   │   │   ├── dataset_question_rewriting.py
│   │   │   ├── question_rewriting_model.py
│   │   │   └── question_rewriting_module.py
│   │   ├── naive_concat
│   │   │   └── naive_concat.py
│   │   ├── structured_representation
│   │   │   ├── dataset_structured_representation.py
│   │   │   ├── structured_representation_module.py
│   │   │   └── structured_representation_model.py
│   │   └── question_resolution
│   │       ├── question_resolution_module.py
│   │       └── question_resolution_utils.py
│   ├── evaluation.py
│   └── pipeline.py
├── Makefile
├── requirements.txt
├── scripts
│   ├── silver_annotation.sh
│   ├── initialize.sh
│   ├── pipeline.sh
│   └── download.sh
├── .gitignore
├── setup.py
├── LICENSE
└── config
    └── convmix
        ├── nc_all-clocq_bm25-fid.yml
        ├── nc_init-clocq_bm25-fid.yml
        ├── nc_prev-clocq_bm25-fid.yml
        ├── nc_init_prev-clocq_bm25-fid.yml
        ├── qres-clocq_bm25-fid.yml
        ├── convinse.yml
        └── qrew-clocq_bm25-fid.yml
/convinse/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/convinse/library/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/convinse/distant_supervision/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | style:
2 | 	black --line-length 100 --target-version py38 .
--------------------------------------------------------------------------------
/convinse/evidence_retrieval_scoring/wikipedia_retriever/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | black
2 | bs4
3 | datasets
4 | matplotlib
5 | networkx
6 | numpy
7 | nltk
8 | pybind11
9 | python-Levenshtein
10 | torch_transformers
11 | pyyaml
12 | rank-bm25
13 | requests
14 | scikit-learn
15 | sentencepiece
16 | tensorboardX
17 | tqdm
18 | transformers
19 | wikitables
--------------------------------------------------------------------------------
/scripts/silver_annotation.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/bash
2 | # read config parameter: if none present, stick to convinse.yml
3 | CONFIG=${1:-config/convmix/convinse.yml}
4 | 
5 | # adjust name to log
6 | IFS='/' read -ra NAME <<< "$CONFIG"
7 | DATA=${NAME[1]}
8 | IFS='.'
read -ra NAME <<< "${NAME[2]}" 9 | NAME=${NAME[0]} 10 | OUT=out/${DATA}/silver_annotation_${NAME}.out 11 | mkdir -p out/${DATA} 12 | 13 | # start script 14 | nohup python -u convinse/distant_supervision/silver_annotation.py --inference $CONFIG > $OUT 2>&1 & -------------------------------------------------------------------------------- /scripts/initialize.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | # initialize root dir 4 | CONVINSE_ROOT=$(pwd) 5 | 6 | # create directories 7 | mkdir -p _benchmarks 8 | mkdir -p _data 9 | mkdir -p _intermediate_representations 10 | mkdir -p _results 11 | mkdir -p _results/convmix 12 | mkdir -p out 13 | mkdir -p out/convmix 14 | mkdir -p out/slurm 15 | 16 | # download 17 | bash scripts/download.sh convmix 18 | bash scripts/download.sh data 19 | bash scripts/download.sh wikipedia 20 | bash scripts/download.sh convinse 21 | bash scripts/download.sh annotated 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # meta data 2 | *__pycache__ 3 | convinse.egg-info/ 4 | .DS_Store 5 | 6 | # data folders 7 | _data/ 8 | _intermediate_representations/ 9 | _benchmarks/ 10 | 11 | # specific folders 12 | *results/ 13 | *logs/ 14 | *cache/ 15 | *examples/ 16 | *out/ 17 | *runs/ 18 | *tmp/ 19 | 20 | # data files 21 | *.pickle 22 | *.out 23 | *.json 24 | *.txt 25 | 26 | # specific paths 27 | clocq/ 28 | convmix/ 29 | test.py 30 | 31 | # other repos 32 | convinse/heterogeneous_answering/fid_module/FiD/ 33 | convinse/question_understanding/question_resolution/quretec/ -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | with open("requirements.txt", "r") as f: 4 | requirements = list(f.read().splitlines()) 5 | 6 | setup( 7 | name="convinse", 8 | version="1.0", 9 | description="Code for the CONVINSE project (published in SIGIR 2022).", 10 | long_description=open("README.md").read(), 11 | long_description_content_type="text/markdown", 12 | author="Philipp Christmann", 13 | author_email="pchristm@mpi-inf.mpg.de", 14 | url="https://convinse.mpi-inf.mpg.de", 15 | packages=find_packages(), 16 | include_package_data=False, 17 | keywords=["qa", "question answering", "heterogeneous QA", "conversational", "ConvQA", "knowledge bases", "heterogeneous sources"], 18 | classifiers=[ 19 | "Programming Language :: Python :: 3.8" 20 | ], 21 | install_requires=requirements 22 | ) 23 | -------------------------------------------------------------------------------- /convinse/library/wikipedia_library.py: -------------------------------------------------------------------------------- 1 | """ 2 | Library for different string and path functions 3 | for the Wikipedia retriever. 
4 | """ 5 | 6 | 7 | def format_wiki_path(value): 8 | """Reformat Wikipedia entity link.""" 9 | return value.replace("/wiki/", "") 10 | 11 | 12 | def is_wikipedia_path(value): 13 | """Check if the value is a Wikipedia entity.""" 14 | if not value: 15 | return False 16 | elif not value.startswith("/wiki"): 17 | return False 18 | elif value.startswith("/wiki/File:"): 19 | return False 20 | elif "Category:" in value: 21 | return False 22 | elif "Special:" in value: 23 | return False 24 | return True 25 | 26 | 27 | def _wiki_title_to_path(wiki_title): 28 | wiki_path = wiki_title.replace(" ", "_") 29 | wiki_path = wiki_path.replace("'", "%27") 30 | wiki_path = wiki_path.replace("-", "_") 31 | return wiki_path 32 | 33 | 34 | def _wiki_path_to_title(wiki_path): 35 | wiki_title = wiki_path.replace("_", " ") 36 | wiki_title = wiki_title.replace("%27", "'") 37 | return wiki_title 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Philipp Christmann, Rishiraj Saha Roy, Gerhard Weikum 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /scripts/pipeline.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | #SBATCH -o out/slurm/$OUT 3 | 4 | ## check argument length 5 | if [[ $# -lt 1 ]] 6 | then 7 | echo "Error: Invalid number of options: Please specify at least the pipeline-function." 
8 | 	echo "Usage: bash scripts/pipeline.sh --train/--pred-answers/--gold-answers/--main-results/--example [<config>] [<sources>]"
9 | 	exit 1
10 | fi
11 | 
12 | ## read config parameter: if none present, stick to the default (convinse.yml)
13 | FUNCTION=$1
14 | CONFIG=${2:-"config/convmix/convinse.yml"}
15 | SOURCES=${3:-"kb_text_table_info"}
16 | 
17 | ## set path for output
18 | # get function name
19 | FUNCTION_NAME=${FUNCTION#"--"}
20 | # get data name
21 | IFS='/' read -ra NAME <<< "$CONFIG"
22 | DATA=${NAME[1]}
23 | # get config name
24 | CFG_NAME=${NAME[2]%".yml"}
25 | 
26 | # set output path (include sources only if not default value)
27 | if [[ $# -lt 3 ]]
28 | then
29 | 	OUT="out/${DATA}/pipeline-${FUNCTION_NAME}-${CFG_NAME}.out"
30 | else
31 | 	OUT="out/${DATA}/pipeline-${FUNCTION_NAME}-${CFG_NAME}-${SOURCES}.out"
32 | fi
33 | 
34 | 
35 | ## fix global vars (required for FiD)
36 | export SLURM_NTASKS=1
37 | export TOKENIZERS_PARALLELISM=false
38 | 
39 | ## start script
40 | if ! command -v sbatch &> /dev/null
41 | then
42 | 	# no slurm setup: run via nohup
43 | 	nohup python -u convinse/pipeline.py $FUNCTION $CONFIG $SOURCES > $OUT 2>&1 &
44 | else
45 | 	# run with sbatch
46 | 	sbatch <<EOF
47 | #!/bin/bash
48 | #SBATCH -o $OUT
49 | python -u convinse/pipeline.py $FUNCTION $CONFIG $SOURCES
50 | EOF
51 | fi
--------------------------------------------------------------------------------
/scripts/download.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/bash
2 | 
3 | ## check argument length
4 | if [[ $# -lt 1 ]]
5 | then
6 | 	echo "Error: Invalid number of options: Please specify the data you would like to download."
7 | 	echo "Usage: bash scripts/download.sh <convinse/convmix/wikipedia/annotated/data>"
8 | 	exit 1
9 | fi
10 | 
11 | case "$1" in
12 | 	"convinse")
13 | 		echo "Downloading CONVINSE data..."
14 | 		wget http://qa.mpi-inf.mpg.de/convinse/convmix_data/convinse.zip
15 | 		mkdir -p _data/convmix/
16 | 		unzip convinse.zip -d _data/convmix/
17 | 		rm convinse.zip
18 | 		echo "Successfully downloaded CONVINSE data!"
19 | 		;;
20 | 	"convmix")
21 | 		echo "Downloading ConvMix dataset..."
22 | 		mkdir -p _benchmarks/convmix
23 | 		cd _benchmarks/convmix
24 | 		wget http://qa.mpi-inf.mpg.de/convinse/train_set.zip
25 | 		unzip train_set.zip
26 | 		rm train_set.zip
27 | 		wget http://qa.mpi-inf.mpg.de/convinse/dev_set.zip
28 | 		unzip dev_set.zip
29 | 		rm dev_set.zip
30 | 		wget http://qa.mpi-inf.mpg.de/convinse/test_set.zip
31 | 		unzip test_set.zip
32 | 		rm test_set.zip
33 | 		echo "Successfully downloaded ConvMix dataset!"
34 | 		;;
35 | 	"wikipedia")
36 | 		echo "Downloading Wikipedia dump..."
37 | 		wget http://qa.mpi-inf.mpg.de/convinse/convmix_data/wikipedia.zip
38 | 		mkdir -p _data/convmix/
39 | 		unzip wikipedia.zip -d _data/convmix/
40 | 		rm wikipedia.zip
41 | 		echo "Successfully downloaded Wikipedia dump!"
42 | 		;;
43 | 	"annotated")
44 | 		echo "Downloading annotated ConvMix data..."
45 | 		wget http://qa.mpi-inf.mpg.de/convinse/convmix_data/annotated.zip
46 | 		mkdir -p _intermediate_representations/convmix/
47 | 		unzip annotated.zip -d _intermediate_representations/convmix/
48 | 		rm annotated.zip
49 | 		echo "Successfully downloaded annotated ConvMix data!"
50 | 		;;
51 | 	"data")
52 | 		echo "Downloading general repo data..."
53 | 		wget http://qa.mpi-inf.mpg.de/convinse/data.zip
54 | 		unzip data.zip -d _data
55 | 		rm data.zip
56 | 		echo "Successfully downloaded general repo data!"
57 | 		;;
58 | 	*)
59 | 		echo "Error: Invalid specification of the data. Data $1 could not be found."
60 | 		exit 1
61 | 		;;
62 | esac
63 | 
--------------------------------------------------------------------------------
/convinse/heterogeneous_answering/heterogeneous_answering.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | 
4 | from convinse.library.utils import store_json_with_mkdir, get_logger
5 | 
6 | 
7 | class HeterogeneousAnswering:
8 |     def __init__(self, config):
9 |         """Initialize HA module."""
10 |         self.config = config
11 |         self.logger = get_logger(__name__, config)
12 | 
13 |     def train(self, sources=["kb", "text", "table", "info"]):
14 |         """Method used in case no training is required for the HA phase."""
15 |         self.logger.info("Module used does not require training.")
16 | 
17 |     def inference(self):
18 |         """Run HA on the data and add answers for each source combination."""
19 |         input_dir = self.config["path_to_annotated"]
20 |         output_dir = self.config["path_to_intermediate_results"]
21 | 
22 |         qu = self.config["qu"]
23 |         ers = self.config["ers"]
24 |         ha = self.config["ha"]
25 | 
26 |         source_combinations = self.config["source_combinations"]
27 |         for sources in source_combinations:
28 |             sources_string = "_".join(sources)
29 | 
30 |             input_path = os.path.join(input_dir, qu, ers, sources_string, "test_ers.jsonl")
31 |             output_path = os.path.join(output_dir, qu, ers, sources_string, ha, "test_ha.json")
32 |             self.inference_on_data_split(input_path, output_path, sources)
33 | 
34 |     def inference_on_data_split(self, input_path, output_path, sources=None):
35 |         """Run HA on the given data split."""
36 |         # open data: one conversation per line (.jsonl)
37 |         input_turns = list()
38 |         data = list()
39 |         with open(input_path, "r") as fp:
40 |             line = fp.readline()
41 |             while line:
42 |                 conversation = json.loads(line)
43 |                 input_turns += [turn for turn in conversation["questions"]]
44 |                 data.append(conversation)
45 |                 line = fp.readline()
46 | 
47 |         # inference
48 |         self.inference_on_turns(input_turns)
49 | 
50 |         # store processed data
51 |         store_json_with_mkdir(data, output_path)
52 | 
53 |     def inference_on_data(self, input_data):
54 |         """Run HA on the given data."""
55 |         input_turns = [turn for conv in input_data for turn in conv["questions"]]
56 |         self.inference_on_turns(input_turns)
57 |         return input_data
58 | 
59 |     def inference_on_turns(self, input_turns):
60 |         """Run HA on a set of turns."""
61 |         for turn in input_turns:
62 |             self.inference_on_turn(turn)
63 | 
64 |     def inference_on_turn(self, turn):
65 |         raise Exception(
66 |             "This is an abstract function which should be overwritten in a derived class!"
67 | ) 68 | -------------------------------------------------------------------------------- /convinse/question_understanding/question_understanding.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from tqdm import tqdm 5 | 6 | from convinse.library.utils import store_json_with_mkdir, get_logger 7 | 8 | 9 | class QuestionUnderstanding: 10 | """Abstract class for QU phase.""" 11 | 12 | def __init__(self, config, use_gold_answers): 13 | """Initialize QU module.""" 14 | self.config = config 15 | self.logger = get_logger(__name__, config) 16 | self.use_gold_answers = use_gold_answers 17 | 18 | def train(self): 19 | """Method used in case no training required for QU phase.""" 20 | self.logger.info("QU - Module used does not require training.") 21 | 22 | def inference(self): 23 | """Run model on data and add predictions.""" 24 | # inference: add predictions to data 25 | qu = self.config["qu"] 26 | input_dir = self.config["path_to_annotated"] 27 | output_dir = self.config["path_to_intermediate_results"] 28 | 29 | input_path = os.path.join(input_dir, "annotated_train.json") 30 | output_path = os.path.join(output_dir, qu, "train_qu.json") 31 | self.inference_on_data_split(input_path, output_path) 32 | 33 | input_path = os.path.join(input_dir, "annotated_dev.json") 34 | output_path = os.path.join(output_dir, qu, "dev_qu.json") 35 | self.inference_on_data_split(input_path, output_path) 36 | 37 | input_path = os.path.join(input_dir, "annotated_test.json") 38 | output_path = os.path.join(output_dir, qu, "test_qu.json") 39 | self.inference_on_data_split(input_path, output_path) 40 | 41 | def inference_on_data_split(self, input_path, output_path): 42 | """Run model on data and add predictions.""" 43 | self.logger.info(f"QU - Starting inference on {input_path}.") 44 | 45 | # open data 46 | with open(input_path, "r") as fp: 47 | data = json.load(fp) 48 | 49 | # model inference on given data 50 | self.inference_on_data(data) 51 | 52 | # store data 53 | store_json_with_mkdir(data, output_path) 54 | 55 | # log 56 | self.logger.info(f"QU - Inference done on {input_path}.") 57 | 58 | def inference_on_data(self, input_data): 59 | """Run model on data and add predictions.""" 60 | # model inference on given data 61 | for conversation in tqdm(input_data): 62 | self.inference_on_conversation(conversation) 63 | return input_data 64 | 65 | def inference_on_conversation(self, conversation): 66 | raise Exception( 67 | "This is an abstract function which should be overwritten in a derived class!" 68 | ) 69 | 70 | def inference_on_turn(self, turn, history_turns): 71 | raise Exception( 72 | "This is an abstract function which should be overwritten in a derived class!" 73 | ) 74 | -------------------------------------------------------------------------------- /convinse/question_understanding/README.md: -------------------------------------------------------------------------------- 1 | # Question Understanding (QU) 2 | 3 | Module to create an intent-explicit form of the current question and the corresponding conversational history. 4 | 5 | - [Create your own QU module](#create-your-own-qu-module) 6 | - [`inference_on_turn` function](#inference_on_turn-function) 7 | - [`inference_on_conversation` function](#inference_on_conversation-function) 8 | - [`train` function](#optional-train-function) 9 | 10 | ## Create your own QU module 11 | You can inherit from the [QuestionUnderstanding](question_understanding.py) class and create your own QU module. 
Implementing the functions `inference_on_turn` and `inference_on_conversation` is sufficient for the pipeline to run properly (a minimal sketch is given after this file). You might want to implement your own training procedure for your module via the `train` function though.
12 | 
13 | Further, you need to instantiate a logger in the class, which will be used in the parent class.
14 | Alternatively, you can call the `__init__` method of the parent class.
15 | Also, make sure to use the `use_gold_answers` parameter properly in your derived class.
16 | This parameter is given when the module is initialized.
17 | 
18 | ## `inference_on_turn` function
19 | 
20 | **Inputs**:
21 | - `turn`: the current turn for which the intent-explicit representation should be generated.
22 | - `history_turns`: the previous turns in the conversation, which can be used to generate the intent-explicit form. List of turn dictionaries.
23 | 
24 | **Description**:
25 | This method is supposed to generate an intent-explicit form of the current question, given the conversational history.
26 | Please make sure that the class parameter `use_gold_answers` controls whether the gold answer(s) (in `turn["answers"]`) or the predicted answer(s) (in `turn["pred_answers"]`) are used.
27 | 
28 | **Output**:
29 | Returns the turn. Make sure to store the intent-explicit representation of the information need in `turn["structured_representation"]`.
30 | 
31 | ## `inference_on_conversation` function
32 | 
33 | **Inputs**:
34 | - `conversation`: the conversation for which the intent-explicit representations should be generated.
35 | 
36 | **Description**:
37 | This method is supposed to generate intent-explicit forms for all turns in the conversation. In the method, you can keep track of the conversational history (e.g. using a list). Please make sure that the class parameter `use_gold_answers` controls whether the gold answer(s) (in `turn["answers"]`) or the predicted answer(s) (in `turn["pred_answers"]`) are used.
38 | 
39 | **Output**:
40 | Returns the conversation. Make sure to store the intent-explicit representation of the information need for every turn in the conversation, in `turn["structured_representation"]`.
41 | 
42 | ## [Optional] `train` function
43 | 
44 | **Inputs**: NONE
45 | 
46 | **Description**:
47 | If required, you can train your QU module here. You can make use of whatever parameters are stored in your .yml file.
48 | 
49 | **Output**: NONE
50 | 
--------------------------------------------------------------------------------
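
The following is a minimal sketch of a possible QU module; the class name and the naive concatenation strategy are illustrative, not part of the repository. It calls the `__init__` method of the parent class (which instantiates the logger), and respects `use_gold_answers`:

``` python
from convinse.question_understanding.question_understanding import QuestionUnderstanding


class NaiveHistoryConcat(QuestionUnderstanding):
    """Toy QU module: concatenate history questions/answers with the current question."""

    def __init__(self, config, use_gold_answers):
        # parent class sets self.config, self.logger and self.use_gold_answers
        super(NaiveHistoryConcat, self).__init__(config, use_gold_answers)

    def inference_on_conversation(self, conversation):
        """Generate intent-explicit forms for all turns in the conversation."""
        history_turns = list()
        for turn in conversation["questions"]:
            self.inference_on_turn(turn, history_turns)
            history_turns.append(turn)
        return conversation

    def inference_on_turn(self, turn, history_turns):
        """Generate the intent-explicit form for the current turn."""
        context = list()
        for history_turn in history_turns:
            context.append(history_turn["question"])
            # gold vs. predicted answers is controlled by use_gold_answers
            answers = history_turn["answers"] if self.use_gold_answers else history_turn.get("pred_answers", [])
            context += [answer["label"] for answer in answers]
        turn["structured_representation"] = " ".join(context + [turn["question"]])
        return turn
```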
/convinse/heterogeneous_answering/README.md:
--------------------------------------------------------------------------------
1 | # Heterogeneous Answering (HA)
2 | 
3 | Module to answer the intent-explicit representation of the question, using the top-*e* retrieved evidences.
4 | 
5 | - [Create your own HA module](#create-your-own-ha-module)
6 | - [`inference_on_turn` function](#inference_on_turn-function)
7 | - [`train` function](#optional-train-function)
8 | - [Answer format](#answer-format)
9 | 
10 | ## Create your own HA module
11 | You can inherit from the [HeterogeneousAnswering](heterogeneous_answering.py) class and create your own HA module. Implementing the function `inference_on_turn` is sufficient for the pipeline to run properly (a minimal sketch is given after this file). You might want to implement your own training procedure for your module via the `train` function though.
12 | 
13 | Further, you need to instantiate a logger in the class, which will be used in the parent class.
14 | Alternatively, you can call the `__init__` method of the parent class.
15 | 
16 | ## `inference_on_turn` function
17 | 
18 | **Inputs**:
19 | - `turn`: the turn for which the answer should be predicted. You can access the intent-explicit representation of the information need via `turn["structured_representation"]`, and the top-*e* evidences via `turn["top_evidences"]`.
20 | 
21 | **Description**:
22 | Run the HA module on the information need, and predict the answer(s).
23 | 
24 | **Output**:
25 | Returns the turn. Make sure to add the predicted answers to `turn["pred_answers"]`. You can find additional information on the [expected answer format](#answer-format) below.
26 | 
27 | ## [Optional] `train` function
28 | 
29 | **Inputs**:
30 | - `sources`: list of sources for which the HA module should be trained. The default setting is to train a single model for all sources (and combinations of sources) for generalizability.
31 | 
32 | **Description**:
33 | If required, you can train your HA module here. You can make use of whatever parameters are stored in your .yml file.
34 | 
35 | **Output**: NONE
36 | 
37 | ## Answer format
38 | The predicted answers are given as a list of answer dictionaries, and should be stored in `turn["pred_answers"]`.
39 | Note that the answers should be normalized to Wikidata. This allows for fair comparison beyond plain string matching.
40 | Further, in a real use case this has the advantage that knowledge cards can be shown for the given KB items.
41 | In case a date or year is returned, give the corresponding timestamp as the ID ("2011-04-17T00:00:00Z"; the standard format in Wikidata and the CLOCQ API), and a verbalized version as the label ("17 April 2011"). You can make use of the timestamp-related functions in the [StringLibrary](../library/string_library.py).
42 | ``` json
43 | [{
44 |     "id": "<answer id>",
45 |     "label": "<answer label>",
46 |     "rank": "<answer rank>"
47 | }]
48 | ```
49 | 
50 | `rank` starts with 1, and has exactly one answer at every rank (for comparison on ConvMix).
51 | 
52 | Example:
53 | ``` json
54 | [
55 |     {
56 |         "id": "Q23633",
57 |         "label": "HBO",
58 |         "rank": "1"
59 |     },
60 |     {
61 |         "id": "2011-04-17T00:00:00Z",
62 |         "label": "17 April 2011",
63 |         "rank": "2"
64 |     }
65 | ]
66 | ```
67 | 
--------------------------------------------------------------------------------
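
For illustration, here is a minimal sketch of a possible HA module; the class name and the naive ranking strategy are made up for this example. It ranks the disambiguated entities of the top evidences, and stores the predictions in the expected answer format:

``` python
from convinse.heterogeneous_answering.heterogeneous_answering import HeterogeneousAnswering


class TopEvidenceAnswering(HeterogeneousAnswering):
    """Toy HA module: rank the disambiguated entities of the top evidences."""

    def __init__(self, config):
        # parent class sets self.config and self.logger
        super(TopEvidenceAnswering, self).__init__(config)

    def inference_on_turn(self, turn):
        """Predict answers for the given turn."""
        pred_answers = list()
        seen_ids = set()
        for evidence in turn["top_evidences"]:
            # each disambiguation is a (label, item id) pair for an entity in the evidence
            for label, item_id in evidence["disambiguations"]:
                if item_id in seen_ids:
                    continue
                seen_ids.add(item_id)
                pred_answers.append({"id": item_id, "label": label, "rank": str(len(pred_answers) + 1)})
        turn["pred_answers"] = pred_answers[: self.config["ha_max_answers"]]
        return turn
```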
/convinse/evidence_retrieval_scoring/clocq_bm25.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import time
5 | import logging
6 | 
7 | from tqdm import tqdm
8 | from pathlib import Path
9 | 
10 | from convinse.library.utils import get_config, get_logger
11 | from convinse.evidence_retrieval_scoring.evidence_retrieval_scoring import EvidenceRetrievalScoring
12 | from convinse.evidence_retrieval_scoring.clocq_er import ClocqRetriever
13 | from convinse.evidence_retrieval_scoring.bm25_es import BM25Scoring
14 | 
15 | 
16 | class ClocqBM25(EvidenceRetrievalScoring):
17 |     def __init__(self, config):
18 |         self.config = config
19 |         self.logger = get_logger(__name__, config)
20 |         self.evr = ClocqRetriever(config)
21 |         self.evs = BM25Scoring(config)
22 | 
23 |     def inference_on_turn(self, turn, sources=["kb", "text", "table", "info"]):
24 |         """Retrieve the best evidences for the SR of the given turn."""
25 |         structured_representation = turn["structured_representation"]
26 |         evidences, _ = self.evr.retrieve_evidences(structured_representation, sources)
27 |         top_evidences = self.evs.get_top_evidences(structured_representation, evidences)
28 |         turn["top_evidences"] = top_evidences
29 |         return top_evidences
30 | 
31 |     def store_cache(self):
32 |         """Store cache of evidence retriever."""
33 |         self.evr.store_cache()
34 | 
35 | 
36 | #######################################################################################################################
37 | #######################################################################################################################
38 | if __name__ == "__main__":
39 |     if len(sys.argv) != 2:
40 |         raise Exception("python convinse/evidence_retrieval_scoring/clocq_bm25.py <path_to_config>")
41 | 
42 |     # load config
43 |     config_path = sys.argv[1]
44 |     config = get_config(config_path)
45 |     ers = ClocqBM25(config)
46 | 
47 |     # inference: add predictions to data
48 |     input_dir = config["path_to_annotated"]
49 |     output_dir = config["path_to_intermediate_results"]
50 | 
51 |     qu = config["qu"]
52 |     source_combinations = config["source_combinations"]
53 | 
54 |     for sources in source_combinations:
55 |         sources_string = "_".join(sources)
56 | 
57 |         input_path = os.path.join(input_dir, qu, "train_qu.json")
58 |         if os.path.exists(input_path):
59 |             output_path = os.path.join(
60 |                 output_dir, qu, "clocq_bm25", sources_string, "train_ers.jsonl"
61 |             )
62 |             ers.inference_on_data_split(input_path, output_path, sources)
63 | 
64 |         input_path = os.path.join(input_dir, qu, "dev_qu.json")
65 |         if os.path.exists(input_path):
66 |             output_path = os.path.join(
67 |                 output_dir, qu, "clocq_bm25", sources_string, "dev_ers.jsonl"
68 |             )
69 |             ers.inference_on_data_split(input_path, output_path, sources)
70 | 
71 |         input_path = os.path.join(input_dir, qu, "test_qu.json")
72 |         output_path = os.path.join(output_dir, qu, "clocq_bm25", sources_string, "test_ers.jsonl")
73 |         ers.inference_on_data_split(input_path, output_path, sources)
74 | 
75 |     # store results in cache
76 |     ers.store_cache()
--------------------------------------------------------------------------------
/convinse/distant_supervision/README.md:
--------------------------------------------------------------------------------
1 | # Distant Supervision
2 | 
3 | - [Usage](#usage)
4 | - [Input format](#input-format)
5 | - [Output format](#output-format)
6 | 
7 | ## Usage
8 | For running the distant supervision on a given dataset, simply run:
9 | ```
10 | bash scripts/silver_annotation.sh [<config>]
11 | ```
12 | from the ROOT directory of the project.
13 | The paths to the input files will be read from the given config values for `train_input_path`, `dev_input_path`, and `test_input_path`.
14 | This will create annotated versions of the benchmark in `_intermediate_representations/<benchmark>/`.
15 | 
16 | ## Input format
17 | The annotation script expects the benchmark in the following (minimal) format:
18 | ```
19 | [
20 |     // first conversation
21 |     {
22 |         "conversation_id": "<conversation id>",
23 |         "questions": [
24 |             // question 1 (complete)
25 |             {
26 |                 "turn": 0,
27 |                 "question_id": "<question id>",
28 |                 "question": "<question>",
29 |                 "answers": [
30 |                     {
31 |                         "id": "<answer id>",
32 |                         "label": "<answer label>"
33 |                     }
34 |                 ]
35 |             },
36 |             // question 2 (incomplete)
37 |             {
38 |                 "turn": 1,
39 |                 "question_id": "<question id>",
40 |                 "question": "<question>",
41 |                 "answers": [
42 |                     {
43 |                         "id": "<answer id>",
44 |                         "label": "<answer label>"
45 |                     }
46 |                 ]
47 |             }
48 |         ]
49 |     },
50 |     // second conversation
51 |     {
52 |         ...
53 |     },
54 |     // ...
55 | ]
56 | ```
57 | Any other keys can be provided, and will be written to the output.
58 | See [here](../heterogeneous_answering#answer-format) for additional information on the expected format of the answer IDs and labels.
59 | 
60 | ## Output format
61 | The result will be stored in a .json file:
62 | 
63 | ```
64 | [
65 |     // first conversation
66 |     {
67 |         "conversation_id": "<conversation id>",
68 |         "questions": [
69 |             // question 1 (complete)
70 |             {
71 |                 "turn": 0,
72 |                 "question_id": "<question id>",
73 |                 "question": "<question>",
74 |                 "answers": [
75 |                     {
76 |                         "id": "<answer id>",
77 |                         "label": "<answer label>"
78 |                     }
79 |                 ],
80 |                 // supervision signals from weak supervision
81 |                 "silver_SR": [
82 |                     // SR 1
83 |                     ["<entities>", "<relation>", "<answer type>"]
84 |                 ],
85 |                 "silver_relevant_turns": [
86 |                     // list of integers referring to the relevant turns
87 |                     // -> this data is not used in the current framework
88 |                     0
89 |                 ]
90 |             },
91 |             // question 2 (incomplete)
92 |             {
93 |                 "turn": 1,
94 |                 "question_id": "<question id>",
95 |                 "question": "<question>",
96 |                 "completed_question": "<completed question>",
97 |                 "answers": [
98 |                     {
99 |                         "id": "<answer id>",
100 |                         "label": "<answer label>"
101 |                     }
102 |                 ],
103 |                 // supervision signals from weak supervision
104 |                 "silver_SR": [
105 |                     // SR 1
106 |                     ["<entities>", "<relation>", "<answer type>"]
107 |                 ],
108 |                 "silver_relevant_turns": [
109 |                     // list of integers referring to the relevant turns
110 |                     // -> this data is not used in the current framework
111 |                     0
112 |                 ]
113 |             },
114 |             // ...
115 |         ]
116 |     },
117 |     // second conversation
118 |     {
119 |         ...
120 |     },
121 |     // ...
122 | ]
123 | ```
--------------------------------------------------------------------------------
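
As a quick orientation for the formats above, here is a minimal sketch of loading an annotated benchmark file and iterating over its turns; the file path is illustrative:

``` python
import json

# illustrative path; annotated files are created by the silver annotation script
with open("_intermediate_representations/convmix/annotated_train.json", "r") as fp:
    data = json.load(fp)

for conversation in data:
    for turn in conversation["questions"]:
        gold_labels = [answer["label"] for answer in turn["answers"]]
        # silver supervision signals added by the annotation (see output format above)
        silver_srs = turn.get("silver_SR", [])
        print(turn["question"], gold_labels, silver_srs)
```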
/convinse/library/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import yaml
4 | import json
5 | import logging
6 | from pathlib import Path
7 | 
8 | 
9 | def get_config(path):
10 |     """Load the config dict from the given .yml file."""
11 |     with open(path, "r") as fp:
12 |         config = yaml.safe_load(fp)
13 |     return config
14 | 
15 | 
16 | def store_json_with_mkdir(data, output_path, indent=True):
17 |     """Store the JSON data in the given path."""
18 |     # create path if not exists
19 |     output_dir = os.path.dirname(output_path)
20 |     Path(output_dir).mkdir(parents=True, exist_ok=True)
21 |     with open(output_path, "w") as fp:
22 |         fp.write(json.dumps(data, indent=4))
23 | 
24 | 
25 | def get_logger(mod_name, config):
26 |     """Get a logger instance for the given module name."""
27 |     # create logger
28 |     logger = logging.getLogger(mod_name)
29 |     # add handler and format
30 |     handler = logging.StreamHandler(sys.stdout)
31 |     formatter = logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s')
32 |     handler.setFormatter(formatter)
33 |     logger.addHandler(handler)
34 |     # set log level
35 |     log_level = config["log_level"]
36 |     logger.setLevel(getattr(logging, log_level))
37 |     return logger
38 | 
39 | 
40 | def get_result_logger(config):
41 |     """Get a logger instance for logging results."""
42 |     # create logger
43 |     logger = logging.getLogger("result_logger")
44 |     # add handler and format
45 |     method_name = config["name"]
46 |     benchmark = config["benchmark"]
47 |     result_file = f"_results/{benchmark}/{method_name}.res"
48 |     result_dir = os.path.dirname(result_file)
49 |     Path(result_dir).mkdir(parents=True, exist_ok=True)
50 |     handler = logging.FileHandler(result_file)
51 |     formatter = logging.Formatter('%(asctime)s %(message)s')
52 |     handler.setFormatter(formatter)
53 |     logger.addHandler(handler)
54 |     # set log level
55 |     logger.setLevel("INFO")
56 |     return logger
57 | 
58 | 
59 | def plot_flow_graph(graph):
60 |     """
61 |     Plot the given turn-relevance flow graph.
62 |     Requires networkx, matplotlib, and pygraphviz.
63 |     """
64 |     # plotting dependencies are only required here
65 |     import matplotlib.pyplot as plt
66 |     import networkx as nx
67 |     nx.nx_agraph.write_dot(graph, "test.dot")
68 |     pos = nx.nx_agraph.graphviz_layout(graph, prog="dot")
69 |     plt.figure(figsize=(18, 20))
70 |     nx.draw(graph, pos, with_labels=True, arrows=True, node_size=100)
71 |     plt.xlim([-1, 800])
72 |     plt.show()
73 | 
74 | 
75 | def print_dict(python_dict):
76 |     """Print python dict as json-string."""
77 |     json_string = json.dumps(python_dict)
78 |     print(json_string)
79 | 
80 | 
81 | def print_verbose(config, string):
82 |     """Print the given string if verbose is set."""
83 |     if config["verbose"]:
84 |         print(str(string))
85 | 
86 | 
87 | def extract_mapping_incomplete_complete(data_paths):
88 |     """
89 |     Extract mapping from incomplete questions to complete
90 |     questions for all follow-up questions.
91 | """ 92 | mapping_incomplete_to_complete = dict() 93 | for data_path in data_paths: 94 | with open(data_path, "r") as fp: 95 | dataset = json.load(fp) 96 | 97 | for conversation in dataset: 98 | for turn in conversation["questions"]: 99 | if turn["turn"] == 0: 100 | continue 101 | question = turn["question"] 102 | completed = turn["completed"] 103 | mapping_incomplete_to_complete[question] = completed 104 | return mapping_incomplete_to_complete 105 | 106 | -------------------------------------------------------------------------------- /config/convmix/nc_all-clocq_bm25-fid.yml: -------------------------------------------------------------------------------- 1 | name: "nc_all-clocq_bm25-fid" 2 | log_level: "INFO" 3 | 4 | # Construct pipeline 5 | qu: nc_all 6 | ers: clocq_bm25 7 | ha: fid 8 | 9 | # Define source combinations 10 | source_combinations: 11 | - - kb 12 | - - text 13 | - - table 14 | - - info 15 | - - kb 16 | - text 17 | - - kb 18 | - table 19 | - - kb 20 | - info 21 | - - text 22 | - table 23 | - - text 24 | - info 25 | - - table 26 | - info 27 | - - kb 28 | - text 29 | - table 30 | - info 31 | 32 | ################################################################# 33 | # General file paths 34 | ################################################################# 35 | path_to_stopwords: "_data/stopwords.txt" 36 | path_to_labels: "_data/labels.json" 37 | path_to_wikipedia_mappings: "_data/wikipedia_mappings.json" 38 | path_to_wikidata_mappings: "_data/wikidata_mappings.json" 39 | 40 | ################################################################# 41 | # Benchmark specific settings 42 | ################################################################# 43 | benchmark: "convmix" 44 | benchmark_path: "_benchmarks/convmix" 45 | 46 | train_input_path: "train_set/train_set_ALL.json" 47 | dev_input_path: "dev_set/dev_set_ALL.json" 48 | test_input_path: "test_set/test_set_ALL.json" 49 | 50 | path_to_annotated: "_intermediate_representations/convmix" # where annotated inputs come from 51 | path_to_intermediate_results: "_intermediate_representations/convmix" 52 | 53 | ################################################################# 54 | # Parameters - CLOCQ 55 | ################################################################# 56 | clocq_params: 57 | h_match: 0.4 58 | h_rel: 0.2 59 | h_conn: 0.3 60 | h_coh: 0.1 61 | d: 20 62 | k: "AUTO" 63 | p_setting: 1000 # setting for search_space function 64 | bm25_limit: False 65 | clocq_p: 1000 # setting for neighborhood function(s) 66 | clocq_use_api: True # using CLOCQClientInterface 67 | clocq_host: "https://clocq.mpi-inf.mpg.de/api" # host for client 68 | clocq_port: "443" # port for client 69 | 70 | ################################################################# 71 | # Parameters - Silver annotation 72 | ################################################################# 73 | # annotation - SR 74 | sr_relation_shared_active: True 75 | sr_remove_stopwords: True 76 | 77 | # OPTIONAL: annotation - turn relevance 78 | tr_transitive_relevances: False 79 | tr_extract_dataset: True 80 | 81 | ################################################################# 82 | # Parameters - QU 83 | ################################################################# 84 | naive_concat: all 85 | 86 | ################################################################# 87 | # Parameters - ERS 88 | ################################################################# 89 | # cache path 90 | ers_use_cache: True 91 | ers_cache_path: "_data/convmix/nc_all/er_cache.pickle" 92 | 
ers_wikipedia_dump: "_data/convmix/wikipedia_dump.pickle" 93 | ers_on_the_fly: True 94 | 95 | # evidence retrieval 96 | evr_min_evidence_length: 3 97 | evr_max_evidence_length: 200 98 | evr_max_entities: 10 # max entities per evidence 99 | evr_max_pos_evidences: 10 100 | 101 | # evidence scoring 102 | evs_max_evidences: 100 103 | 104 | ################################################################# 105 | # Parameters - HA 106 | ################################################################# 107 | # general 108 | ha_max_answers: 50 109 | 110 | fid_model_path: "_data/convmix/nc_all/fid/best_dev" 111 | fid_per_gpu_batch_size: 1 112 | fid_max_evidences: 100 113 | 114 | # train 115 | fid_lr: 0.00005 116 | fid_optim: adamw 117 | fid_scheduler: linear 118 | fid_weight_decay: 0.01 119 | fid_text_maxlength: 250 120 | fid_answer_maxlength: 10 121 | fid_total_step: 15000 122 | fid_warmup_step: 1000 123 | 124 | # inference 125 | fid_max_evidences: 100 126 | fid_num_beams: 20 127 | -------------------------------------------------------------------------------- /config/convmix/nc_init-clocq_bm25-fid.yml: -------------------------------------------------------------------------------- 1 | name: "nc_init-clocq_bm25-fid" 2 | log_level: "INFO" 3 | 4 | # Construct pipeline 5 | qu: nc_init 6 | ers: clocq_bm25 7 | ha: fid 8 | 9 | # Define source combinations 10 | source_combinations: 11 | - - kb 12 | - - text 13 | - - table 14 | - - info 15 | - - kb 16 | - text 17 | - - kb 18 | - table 19 | - - kb 20 | - info 21 | - - text 22 | - table 23 | - - text 24 | - info 25 | - - table 26 | - info 27 | - - kb 28 | - text 29 | - table 30 | - info 31 | 32 | ################################################################# 33 | # General file paths 34 | ################################################################# 35 | path_to_stopwords: "_data/stopwords.txt" 36 | path_to_labels: "_data/labels.json" 37 | path_to_wikipedia_mappings: "_data/wikipedia_mappings.json" 38 | path_to_wikidata_mappings: "_data/wikidata_mappings.json" 39 | 40 | ################################################################# 41 | # Benchmark specific settings 42 | ################################################################# 43 | benchmark: "convmix" 44 | benchmark_path: "_benchmarks/convmix" 45 | 46 | train_input_path: "train_set/train_set_ALL.json" 47 | dev_input_path: "dev_set/dev_set_ALL.json" 48 | test_input_path: "test_set/test_set_ALL.json" 49 | 50 | path_to_annotated: "_intermediate_representations/convmix" # where annotated inputs come from 51 | path_to_intermediate_results: "_intermediate_representations/convmix" 52 | 53 | ################################################################# 54 | # Parameters - CLOCQ 55 | ################################################################# 56 | clocq_params: 57 | h_match: 0.4 58 | h_rel: 0.2 59 | h_conn: 0.3 60 | h_coh: 0.1 61 | d: 20 62 | k: "AUTO" 63 | p_setting: 1000 # setting for search_space function 64 | bm25_limit: False 65 | clocq_p: 1000 # setting for neighborhood function(s) 66 | clocq_use_api: True # using CLOCQClientInterface 67 | clocq_host: "https://clocq.mpi-inf.mpg.de/api" # host for client 68 | clocq_port: "443" # port for client 69 | 70 | ################################################################# 71 | # Parameters - Silver annotation 72 | ################################################################# 73 | # annotation - SR 74 | sr_relation_shared_active: True 75 | sr_remove_stopwords: True 76 | 77 | # OPTIONAL: annotation - turn relevance 78 | 
tr_transitive_relevances: False 79 | tr_extract_dataset: True 80 | 81 | ################################################################# 82 | # Parameters - QU 83 | ################################################################# 84 | naive_concat: init 85 | 86 | ################################################################# 87 | # Parameters - ERS 88 | ################################################################# 89 | # cache path 90 | ers_use_cache: True 91 | ers_cache_path: "_data/convmix/nc_init/er_cache.pickle" 92 | ers_wikipedia_dump: "_data/convmix/wikipedia_dump.pickle" 93 | ers_on_the_fly: True 94 | 95 | # evidence retrieval 96 | evr_min_evidence_length: 3 97 | evr_max_evidence_length: 200 98 | evr_max_entities: 10 # max entities per evidence 99 | evr_max_pos_evidences: 10 100 | 101 | # evidence scoring 102 | evs_max_evidences: 100 103 | 104 | ################################################################# 105 | # Parameters - HA 106 | ################################################################# 107 | # general 108 | ha_max_answers: 50 109 | 110 | fid_model_path: "_data/convmix/nc_init/fid/best_dev" 111 | fid_per_gpu_batch_size: 1 112 | fid_max_evidences: 100 113 | 114 | # train 115 | fid_lr: 0.00005 116 | fid_optim: adamw 117 | fid_scheduler: linear 118 | fid_weight_decay: 0.01 119 | fid_text_maxlength: 250 120 | fid_answer_maxlength: 10 121 | fid_total_step: 15000 122 | fid_warmup_step: 1000 123 | 124 | # inference 125 | fid_max_evidences: 100 126 | fid_num_beams: 20 -------------------------------------------------------------------------------- /config/convmix/nc_prev-clocq_bm25-fid.yml: -------------------------------------------------------------------------------- 1 | name: "nc_prev-clocq_bm25-fid" 2 | log_level: "INFO" 3 | 4 | # Construct pipeline 5 | qu: nc_prev 6 | ers: clocq_bm25 7 | ha: fid 8 | 9 | # Define source combinations 10 | source_combinations: 11 | - - kb 12 | - - text 13 | - - table 14 | - - info 15 | - - kb 16 | - text 17 | - - kb 18 | - table 19 | - - kb 20 | - info 21 | - - text 22 | - table 23 | - - text 24 | - info 25 | - - table 26 | - info 27 | - - kb 28 | - text 29 | - table 30 | - info 31 | 32 | ################################################################# 33 | # General file paths 34 | ################################################################# 35 | path_to_stopwords: "_data/stopwords.txt" 36 | path_to_labels: "_data/labels.json" 37 | path_to_wikipedia_mappings: "_data/wikipedia_mappings.json" 38 | path_to_wikidata_mappings: "_data/wikidata_mappings.json" 39 | 40 | ################################################################# 41 | # Benchmark specific settings 42 | ################################################################# 43 | benchmark: "convmix" 44 | benchmark_path: "_benchmarks/convmix" 45 | 46 | train_input_path: "train_set/train_set_ALL.json" 47 | dev_input_path: "dev_set/dev_set_ALL.json" 48 | test_input_path: "test_set/test_set_ALL.json" 49 | 50 | path_to_annotated: "_intermediate_representations/convmix" # where annotated inputs come from 51 | path_to_intermediate_results: "_intermediate_representations/convmix" 52 | 53 | ################################################################# 54 | # Parameters - CLOCQ 55 | ################################################################# 56 | clocq_params: 57 | h_match: 0.4 58 | h_rel: 0.2 59 | h_conn: 0.3 60 | h_coh: 0.1 61 | d: 20 62 | k: "AUTO" 63 | p_setting: 1000 # setting for search_space function 64 | bm25_limit: False 65 | clocq_p: 1000 # setting for 
neighborhood function(s) 66 | clocq_use_api: True # using CLOCQClientInterface 67 | clocq_host: "https://clocq.mpi-inf.mpg.de/api" # host for client 68 | clocq_port: "443" # port for client 69 | 70 | ################################################################# 71 | # Parameters - Silver annotation 72 | ################################################################# 73 | # annotation - SR 74 | sr_relation_shared_active: True 75 | sr_remove_stopwords: True 76 | 77 | # OPTIONAL: annotation - turn relevance 78 | tr_transitive_relevances: False 79 | tr_extract_dataset: True 80 | 81 | ################################################################# 82 | # Parameters - QU 83 | ################################################################# 84 | naive_concat: prev 85 | 86 | ################################################################# 87 | # Parameters - ERS 88 | ################################################################# 89 | # cache path 90 | ers_use_cache: True 91 | ers_cache_path: "_data/convmix/nc_prev/er_cache.pickle" 92 | ers_wikipedia_dump: "_data/convmix/wikipedia_dump.pickle" 93 | ers_on_the_fly: True 94 | 95 | # evidence retrieval 96 | evr_min_evidence_length: 3 97 | evr_max_evidence_length: 200 98 | evr_max_entities: 10 # max entities per evidence 99 | evr_max_pos_evidences: 10 100 | 101 | # evidence scoring 102 | evs_max_evidences: 100 103 | 104 | ################################################################# 105 | # Parameters - HA 106 | ################################################################# 107 | # general 108 | ha_max_answers: 50 109 | 110 | fid_model_path: "_data/convmix/nc_prev/fid/best_dev" 111 | fid_per_gpu_batch_size: 1 112 | fid_max_evidences: 100 113 | 114 | # train 115 | fid_lr: 0.00005 116 | fid_optim: adamw 117 | fid_scheduler: linear 118 | fid_weight_decay: 0.01 119 | fid_text_maxlength: 250 120 | fid_answer_maxlength: 10 121 | fid_total_step: 15000 122 | fid_warmup_step: 1000 123 | 124 | # inference 125 | fid_max_evidences: 100 126 | fid_num_beams: 20 127 | -------------------------------------------------------------------------------- /config/convmix/nc_init_prev-clocq_bm25-fid.yml: -------------------------------------------------------------------------------- 1 | name: "nc_init_prev-clocq_bm25-fid" 2 | log_level: "INFO" 3 | 4 | # Construct pipeline 5 | qu: nc_init_prev 6 | ers: clocq_bm25 7 | ha: fid 8 | 9 | # Define source combinations 10 | source_combinations: 11 | - - kb 12 | - - text 13 | - - table 14 | - - info 15 | - - kb 16 | - text 17 | - - kb 18 | - table 19 | - - kb 20 | - info 21 | - - text 22 | - table 23 | - - text 24 | - info 25 | - - table 26 | - info 27 | - - kb 28 | - text 29 | - table 30 | - info 31 | 32 | ################################################################# 33 | # General file paths 34 | ################################################################# 35 | path_to_stopwords: "_data/stopwords.txt" 36 | path_to_labels: "_data/labels.json" 37 | path_to_wikipedia_mappings: "_data/wikipedia_mappings.json" 38 | path_to_wikidata_mappings: "_data/wikidata_mappings.json" 39 | 40 | ################################################################# 41 | # Benchmark specific settings 42 | ################################################################# 43 | benchmark: "convmix" 44 | benchmark_path: "_benchmarks/convmix" 45 | 46 | train_input_path: "train_set/train_set_ALL.json" 47 | dev_input_path: "dev_set/dev_set_ALL.json" 48 | test_input_path: "test_set/test_set_ALL.json" 49 | 50 | path_to_annotated: 
"_intermediate_representations/convmix" # where annotated inputs come from 51 | path_to_intermediate_results: "_intermediate_representations/convmix" 52 | 53 | ################################################################# 54 | # Parameters - CLOCQ 55 | ################################################################# 56 | clocq_params: 57 | h_match: 0.4 58 | h_rel: 0.2 59 | h_conn: 0.3 60 | h_coh: 0.1 61 | d: 20 62 | k: "AUTO" 63 | p_setting: 1000 # setting for search_space function 64 | bm25_limit: False 65 | clocq_p: 1000 # setting for neighborhood function(s) 66 | clocq_use_api: True # using CLOCQClientInterface 67 | clocq_host: "https://clocq.mpi-inf.mpg.de/api" # host for client 68 | clocq_port: "443" # port for client 69 | 70 | ################################################################# 71 | # Parameters - Silver annotation 72 | ################################################################# 73 | # annotation - SR 74 | sr_relation_shared_active: True 75 | sr_remove_stopwords: True 76 | 77 | # OPTIONAL: annotation - turn relevance 78 | tr_transitive_relevances: False 79 | tr_extract_dataset: True 80 | 81 | ################################################################# 82 | # Parameters - QU 83 | ################################################################# 84 | naive_concat: init_prev 85 | 86 | ################################################################# 87 | # Parameters - ERS 88 | ################################################################# 89 | # cache path 90 | ers_use_cache: True 91 | ers_cache_path: "_data/convmix/nc_init_prev/er_cache.pickle" 92 | ers_wikipedia_dump: "_data/convmix/wikipedia_dump.pickle" 93 | ers_on_the_fly: True 94 | 95 | # evidence retrieval 96 | evr_min_evidence_length: 3 97 | evr_max_evidence_length: 200 98 | evr_max_entities: 10 # max entities per evidence 99 | evr_max_pos_evidences: 10 100 | 101 | # evidence scoring 102 | evs_max_evidences: 100 103 | 104 | ################################################################# 105 | # Parameters - HA 106 | ################################################################# 107 | # general 108 | ha_max_answers: 50 109 | 110 | fid_model_path: "_data/convmix/nc_init_prev/fid/best_dev" 111 | fid_per_gpu_batch_size: 1 112 | fid_max_evidences: 100 113 | 114 | # train 115 | fid_lr: 0.00005 116 | fid_optim: adamw 117 | fid_scheduler: linear 118 | fid_weight_decay: 0.01 119 | fid_text_maxlength: 250 120 | fid_answer_maxlength: 10 121 | fid_total_step: 15000 122 | fid_warmup_step: 1000 123 | 124 | # inference 125 | fid_max_evidences: 100 126 | fid_num_beams: 20 -------------------------------------------------------------------------------- /config/convmix/qres-clocq_bm25-fid.yml: -------------------------------------------------------------------------------- 1 | name: "qres-clocq_bm25-fid" 2 | log_level: "INFO" 3 | 4 | # Construct pipeline 5 | qu: qres 6 | ers: clocq_bm25 7 | ha: fid 8 | 9 | # Define source combinations 10 | source_combinations: 11 | - - kb 12 | - - text 13 | - - table 14 | - - info 15 | - - kb 16 | - text 17 | - - kb 18 | - table 19 | - - kb 20 | - info 21 | - - text 22 | - table 23 | - - text 24 | - info 25 | - - table 26 | - info 27 | - - kb 28 | - text 29 | - table 30 | - info 31 | 32 | ################################################################# 33 | # General file paths 34 | ################################################################# 35 | path_to_stopwords: "_data/stopwords.txt" 36 | path_to_labels: "_data/labels.json" 37 | path_to_wikipedia_mappings: 
"_data/wikipedia_mappings.json" 38 | path_to_wikidata_mappings: "_data/wikidata_mappings.json" 39 | 40 | ################################################################# 41 | # Benchmark specific settings 42 | ################################################################# 43 | benchmark: "convmix" 44 | benchmark_path: "_benchmarks/convmix" 45 | 46 | train_input_path: "train_set/train_set_ALL.json" 47 | dev_input_path: "dev_set/dev_set_ALL.json" 48 | test_input_path: "test_set/test_set_ALL.json" 49 | 50 | path_to_annotated: "_intermediate_representations/convmix" # where annotated inputs come from 51 | path_to_intermediate_results: "_intermediate_representations/convmix" 52 | 53 | ################################################################# 54 | # Parameters - CLOCQ 55 | ################################################################# 56 | clocq_params: 57 | h_match: 0.4 58 | h_rel: 0.2 59 | h_conn: 0.3 60 | h_coh: 0.1 61 | d: 20 62 | k: "AUTO" 63 | p_setting: 1000 # setting for search_space function 64 | bm25_limit: False 65 | clocq_p: 1000 # setting for neighborhood function(s) 66 | clocq_use_api: True # using CLOCQClientInterface 67 | clocq_host: "https://clocq.mpi-inf.mpg.de/api" # host for client 68 | clocq_port: "443" # port for client 69 | 70 | ################################################################# 71 | # Parameters - Silver annotation 72 | ################################################################# 73 | # annotation - SR 74 | sr_relation_shared_active: True 75 | sr_remove_stopwords: True 76 | 77 | # OPTIONAL: annotation - turn relevance 78 | tr_transitive_relevances: False 79 | tr_extract_dataset: True 80 | 81 | ################################################################# 82 | # Parameters - QU 83 | ################################################################# 84 | qres_input_separator: "[SEP]" 85 | qres_model_id: "convmix_qres" 86 | qres_model_dir: "_data/convmix/qres/" 87 | 88 | ################################################################# 89 | # Parameters - ERS 90 | ################################################################# 91 | # cache path 92 | ers_use_cache: True 93 | ers_cache_path: "_data/convmix/qres/er_cache.pickle" 94 | ers_wikipedia_dump: "_data/convmix/wikipedia_dump.pickle" 95 | ers_on_the_fly: True 96 | 97 | # evidence retrieval 98 | evr_min_evidence_length: 3 99 | evr_max_evidence_length: 200 100 | evr_max_entities: 10 # max entities per evidence 101 | evr_max_pos_evidences: 10 102 | 103 | # evidence scoring 104 | evs_max_evidences: 100 105 | 106 | ################################################################# 107 | # Parameters - HA 108 | ################################################################# 109 | # general 110 | ha_max_answers: 50 111 | 112 | fid_model_path: "_data/convmix/qres/fid/best_dev" 113 | fid_per_gpu_batch_size: 1 114 | fid_max_evidences: 100 115 | 116 | # train 117 | fid_lr: 0.00005 118 | fid_optim: adamw 119 | fid_scheduler: linear 120 | fid_weight_decay: 0.01 121 | fid_text_maxlength: 250 122 | fid_answer_maxlength: 10 123 | fid_total_step: 15000 124 | fid_warmup_step: 1000 125 | 126 | # inference 127 | fid_max_evidences: 100 128 | fid_num_beams: 20 -------------------------------------------------------------------------------- /convinse/heterogeneous_answering/fid_module/fid_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import random 5 | import logging 6 | 7 | from pathlib import Path 8 | 9 | from 
convinse.library.utils import get_config 10 | from convinse.evaluation import evidence_has_answer, question_is_existential 11 | 12 | 13 | def prepare_turn(config, input_turn, output_path, train=False): 14 | """ 15 | Prepare the given turn for input into FiD. 16 | Input will be top-100 evidences per question 17 | as predicted by ERS stage. 18 | Writes the result in the given output path. 19 | """ 20 | # create output dir 21 | output_dir = os.path.dirname(output_path) 22 | Path(output_dir).mkdir(parents=True, exist_ok=True) 23 | 24 | # prepare 25 | res = _prepare_turn(config, input_turn, train) 26 | if res is None: 27 | sr = input_turn["structured_representation"] 28 | raise Exception(f"No evidences found for this turn! SR: {sr}.") 29 | 30 | # store 31 | with open(output_path, "w") as fp: 32 | fp.write(json.dumps(res)) 33 | fp.write("\n") 34 | 35 | 36 | def prepare_data(config, input_turns, output_path, train=False): 37 | """ 38 | Prepare the given data for input into FiD. 39 | Input will be top-100 evidences per question 40 | as predicted by ERS stage. 41 | """ 42 | # create output dir 43 | output_dir = os.path.dirname(output_path) 44 | Path(output_dir).mkdir(parents=True, exist_ok=True) 45 | 46 | # process data 47 | with open(output_path, "w") as fp_out: 48 | # transform 49 | for turn in input_turns: 50 | # skip instances that are already processed 51 | if not turn.get("pred_answers") is None: 52 | continue 53 | 54 | res = _prepare_turn(config, turn, train) 55 | # skip turns for which no evidences were found 56 | if res is None: 57 | continue 58 | 59 | # write new 60 | fp_out.write(json.dumps(res)) 61 | fp_out.write("\n") 62 | 63 | 64 | def _prepare_turn(config, input_turn, train): 65 | """ 66 | Prepare the given turn for input into FiD. 67 | Input will be top-100 evidences per question 68 | as predicted by ERS stage. 69 | Returns the object. For internal usage! 70 | """ 71 | # construct set of answers that are present (from silver evidences) 72 | answer_ids = [answer["id"] for answer in input_turn["answers"]] 73 | 74 | # prepare target answers 75 | target_answers = set() 76 | # retrieve target answers from answering evidences -> preserve order! 77 | evidences = input_turn["top_evidences"] 78 | for evidence in evidences: 79 | if evidence_has_answer(evidence, input_turn["answers"]): 80 | for disambiguation in evidence["disambiguations"]: 81 | if disambiguation[1] in answer_ids: 82 | target_answers.add(disambiguation[0]) 83 | 84 | # if no answer can be found, skip instance during train/dev! 
85 |         if train and not target_answers:
86 |             return None
87 | 
88 |         # if no answer in dataset, skip (fix for TimeQuestions dataset)
89 |         if not input_turn["answers"]:
90 |             return None
91 | 
92 |         evidences = input_turn["top_evidences"]
93 |         evidences = evidences[: config["fid_max_evidences"]]
94 | 
95 |         # create data
96 |         answers = list(target_answers) + [answer["label"] for answer in input_turn["answers"]]
97 |         target_answer = answers[0]  # a silver mention if available, else the first gold label
98 |         evidences = [
99 |             {"title": evidence["retrieved_for_entity"]["label"], "text": evidence["evidence_text"]}
100 |             for evidence in evidences  # iterate over the truncated list, not the full top_evidences
101 |         ]
102 | 
103 |         # if there are no evidences, return None (=skip instance)
104 |         if evidences == []:
105 |             return None
106 | 
107 |         # return transformed instance
108 |         return {
109 |             "id": input_turn["question_id"],
110 |             "question": input_turn["structured_representation"],
111 |             "target": target_answer,
112 |             "answers": answers,
113 |             "ctxs": evidences,
114 |         }
--------------------------------------------------------------------------------
/convinse/question_understanding/question_rewriting/dataset_question_rewriting.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | 
5 | from convinse.library.string_library import StringLibrary as string_lib
6 | from convinse.library.utils import extract_mapping_incomplete_complete
7 | 
8 | def input_to_text(history_turns, current_turn, history_separator):
9 |     """
10 |     Transform the relevant turns and current turn into the input text.
11 |     """
12 |     history_text = history_separator.join(
13 |         [_history_turn_to_text(history_turn, history_separator) for history_turn in history_turns]
14 |     )
15 | 
16 |     # create input
17 |     current_question = current_turn["question"]
18 |     input_text = f"{history_text}{history_separator}{current_question}"
19 |     return input_text
20 | 
21 | 
22 | def _history_turn_to_text(history_turn, history_separator):
23 |     """
24 |     Transform the given history turn to text.
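    For illustration (assuming the default history_separator " ||| " from the
    config, and made-up question/answer values): a history turn with question
    "who founded Apple?" and gold answers "Steve Jobs" and "Steve Wozniak"
    becomes "who founded Apple? ||| Steve Jobs Steve Wozniak".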
25 | """ 26 | question = history_turn["question"] 27 | answers = history_turn["answers"] 28 | answers_text = " ".join([answer["label"] for answer in answers]) 29 | history_turn_text = f"{question}{history_separator}{answers_text}" 30 | return history_turn_text 31 | 32 | 33 | class DatasetQuestionRewriting(torch.utils.data.Dataset): 34 | def __init__(self, config, tokenizer, path): 35 | self.config = config 36 | self.tokenizer = tokenizer 37 | self.history_separator = config["history_separator"] 38 | 39 | benchmark_path = config["benchmark_path"] 40 | train_path = os.path.join(benchmark_path, config["train_input_path"]) 41 | dev_path = os.path.join(benchmark_path, config["dev_input_path"]) 42 | data_paths = [train_path, dev_path] 43 | self.mapping_incomplete_to_complete = extract_mapping_incomplete_complete(data_paths) 44 | 45 | input_encodings, output_encodings, dataset_length = self._load_data(path) 46 | self.input_encodings = input_encodings 47 | self.output_encodings = output_encodings 48 | self.dataset_length = dataset_length 49 | 50 | def __getitem__(self, idx): 51 | item = {key: torch.tensor(val[idx]) for key, val in self.input_encodings.items()} 52 | labels = self.output_encodings["input_ids"][idx] 53 | item = { 54 | "input_ids": item["input_ids"], 55 | "attention_mask": item["attention_mask"], 56 | "labels": labels, 57 | } 58 | return item 59 | 60 | def __len__(self): 61 | return self.dataset_length 62 | 63 | def _load_data(self, path): 64 | """ 65 | Opens the file, and loads the data into 66 | a format that can be put into the model. 67 | 68 | The whole history is given as input. 69 | The complete question, as annotated in the dataset, 70 | is the gold output. 71 | """ 72 | # open data 73 | with open(path, "r") as fp: 74 | dataset = json.load(fp) 75 | 76 | inputs = list() 77 | outputs = list() 78 | 79 | for conversation in dataset: 80 | history = list() 81 | for turn in conversation["questions"]: 82 | # skip initial turn: no rewrite required! 
83 | if turn["turn"] == 0: 84 | continue 85 | 86 | # create input 87 | inputs.append(input_to_text(history, turn, self.history_separator)) 88 | 89 | # create output 90 | question = turn["question"] 91 | complete = self.mapping_incomplete_to_complete.get(question) 92 | outputs.append(complete) 93 | 94 | # append to history 95 | history.append(turn) 96 | 97 | input_encodings = self.tokenizer( 98 | inputs, padding=True, truncation=True, max_length=self.config["qrew_max_input_length"] 99 | ) 100 | output_encodings = self.tokenizer( 101 | outputs, padding=True, truncation=True, max_length=self.config["qrew_max_input_length"] 102 | ) 103 | dataset_length = len(inputs) 104 | 105 | return input_encodings, output_encodings, dataset_length 106 | -------------------------------------------------------------------------------- /convinse/question_understanding/question_rewriting/question_rewriting_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import transformers 4 | 5 | from pathlib import Path 6 | import convinse.question_understanding.question_rewriting.dataset_question_rewriting as dataset 7 | 8 | 9 | class QuestionRewritingModel(torch.nn.Module): 10 | def __init__(self, config): 11 | super(QuestionRewritingModel, self).__init__() 12 | self.config = config 13 | self.model = transformers.T5ForConditionalGeneration.from_pretrained( 14 | "castorini/t5-base-canard" 15 | ) 16 | self.tokenizer = transformers.T5TokenizerFast.from_pretrained("castorini/t5-base-canard") 17 | 18 | def set_eval_mode(self): 19 | """Set model to eval mode.""" 20 | self.model.eval() 21 | 22 | def save(self): 23 | """Save model.""" 24 | model_path = self.config["qrew_model_path"] 25 | # create dir if not exists 26 | model_dir = os.path.dirname(model_path) 27 | Path(model_dir).mkdir(parents=True, exist_ok=True) 28 | torch.save(self.model.state_dict(), model_path) 29 | 30 | def load(self): 31 | """Load model.""" 32 | state_dict = torch.load(self.config["qrew_model_path"]) 33 | self.model.load_state_dict(state_dict) 34 | # move to GPU (if possible) 35 | if torch.cuda.is_available(): 36 | self.model = self.model.cuda() 37 | 38 | def train(self, train_path, dev_path): 39 | """Train model.""" 40 | # load datasets 41 | train_dataset = dataset.DatasetQuestionRewriting(self.config, self.tokenizer, train_path) 42 | dev_dataset = dataset.DatasetQuestionRewriting(self.config, self.tokenizer, dev_path) 43 | # arguments for training 44 | training_args = transformers.Seq2SeqTrainingArguments( 45 | output_dir="convinse/question_understanding/question_rewriting/results", # output directory 46 | num_train_epochs=self.config[ 47 | "qrew_num_train_epochs" 48 | ], # total number of training epochs 49 | per_device_train_batch_size=self.config[ 50 | "qrew_per_device_train_batch_size" 51 | ], # batch size per device during training 52 | per_device_eval_batch_size=self.config[ 53 | "qrew_per_device_eval_batch_size" 54 | ], # batch size for evaluation 55 | warmup_steps=self.config[ 56 | "qrew_warmup_steps" 57 | ], # number of warmup steps for learning rate scheduler 58 | weight_decay=self.config["qrew_weight_decay"], # strength of weight decay 59 | logging_dir="convinse/question_understanding/question_rewriting/logs", # directory for storing logs 60 | logging_steps=1000, 61 | evaluation_strategy="epoch", 62 | save_strategy="epoch", 63 | load_best_model_at_end="True" 64 | # predict_with_generate=True 65 | ) 66 | # create the object for training 67 | trainer = 
transformers.Seq2SeqTrainer( 68 | model=self.model, 69 | args=training_args, 70 | train_dataset=train_dataset, 71 | eval_dataset=dev_dataset, 72 | ) 73 | # training progress 74 | trainer.train() 75 | # store model 76 | self.save() 77 | 78 | def inference(self, inputs): 79 | """ 80 | Run the model on the given input. 81 | Snippet taken from: https://github.com/gonced8/rachael-scai/blob/main/demo.py 82 | """ 83 | # encode 84 | rewrite_input_ids = self.tokenizer.encode( 85 | inputs, 86 | truncation=False, 87 | return_tensors="pt", 88 | ) 89 | if torch.cuda.is_available(): 90 | rewrite_input_ids = rewrite_input_ids.cuda() 91 | # generate 92 | output = self.model.generate( 93 | rewrite_input_ids, 94 | max_length=self.config["qrew_max_output_length"], 95 | do_sample=self.config["qrew_do_sample"], 96 | ) 97 | # decoding 98 | model_rewrite = self.tokenizer.batch_decode( 99 | output, 100 | skip_special_tokens=True, 101 | clean_up_tokenization_spaces=True, 102 | )[0] 103 | return model_rewrite 104 | -------------------------------------------------------------------------------- /config/convmix/convinse.yml: -------------------------------------------------------------------------------- 1 | name: "convinse" 2 | log_level: "INFO" 3 | 4 | # Construct pipeline 5 | qu: sr 6 | ers: clocq_bm25 7 | ha: fid 8 | 9 | # Define source combinations 10 | source_combinations: 11 | - - kb 12 | - - text 13 | - - table 14 | - - info 15 | - - kb 16 | - text 17 | - - kb 18 | - table 19 | - - kb 20 | - info 21 | - - text 22 | - table 23 | - - text 24 | - info 25 | - - table 26 | - info 27 | - - kb 28 | - text 29 | - table 30 | - info 31 | 32 | ################################################################# 33 | # General file paths 34 | ################################################################# 35 | path_to_stopwords: "_data/stopwords.txt" 36 | path_to_labels: "_data/labels.json" 37 | path_to_wikipedia_mappings: "_data/wikipedia_mappings.json" 38 | path_to_wikidata_mappings: "_data/wikidata_mappings.json" 39 | 40 | ################################################################# 41 | # Benchmark specific settings 42 | ################################################################# 43 | benchmark: "convmix" 44 | benchmark_path: "_benchmarks/convmix" 45 | 46 | train_input_path: "train_set/train_set_ALL.json" 47 | dev_input_path: "dev_set/dev_set_ALL.json" 48 | test_input_path: "test_set/test_set_ALL.json" 49 | 50 | path_to_annotated: "_intermediate_representations/convmix" # where annotated inputs come from 51 | path_to_intermediate_results: "_intermediate_representations/convmix" 52 | 53 | ################################################################# 54 | # Parameters - CLOCQ 55 | ################################################################# 56 | clocq_params: 57 | h_match: 0.4 58 | h_rel: 0.2 59 | h_conn: 0.3 60 | h_coh: 0.1 61 | d: 20 62 | k: "AUTO" 63 | p_setting: 1000 # setting for search_space function 64 | bm25_limit: False 65 | clocq_p: 1000 # setting for neighborhood function(s) 66 | clocq_use_api: True # using CLOCQClientInterface 67 | clocq_host: "https://clocq.mpi-inf.mpg.de/api" # host for client 68 | clocq_port: "443" # port for client 69 | 70 | ################################################################# 71 | # Parameters - Silver annotation 72 | ################################################################# 73 | # annotation - SR 74 | sr_relation_shared_active: True 75 | sr_remove_stopwords: True 76 | 77 | # OPTIONAL: annotation - turn relevance 78 | tr_transitive_relevances: 
False 79 | tr_extract_dataset: False 80 | 81 | ################################################################# 82 | # Parameters - QU 83 | ################################################################# 84 | sr_architecture: BART 85 | sr_model_path: "_data/convmix/convinse/sr_model.bin" 86 | sr_max_input_length: 512 87 | 88 | history_separator: " ||| " 89 | sr_separator: " || " 90 | 91 | # training parameters 92 | sr_num_train_epochs: 5 93 | sr_per_device_train_batch_size: 10 94 | sr_per_device_eval_batch_size: 10 95 | sr_warmup_steps: 500 96 | sr_weight_decay: 0.01 97 | 98 | # generation parameters 99 | sr_no_repeat_ngram_size: 2 100 | sr_num_beams: 20 101 | sr_early_stopping: True 102 | 103 | sr_delimiter: "||" 104 | 105 | ################################################################# 106 | # Parameters - ERS 107 | ################################################################# 108 | # cache path 109 | ers_use_cache: True 110 | ers_cache_path: "_data/convmix/convinse/er_cache.pickle" 111 | ers_wikipedia_dump: "_data/convmix/wikipedia_dump.pickle" 112 | ers_on_the_fly: True 113 | 114 | # evidence retrieval 115 | evr_min_evidence_length: 3 116 | evr_max_evidence_length: 200 117 | evr_max_entities: 10 # max entities per evidence 118 | evr_max_pos_evidences: 10 119 | 120 | # evidence scoring 121 | evs_max_evidences: 100 122 | 123 | ################################################################# 124 | # Parameters - HA 125 | ################################################################# 126 | # general 127 | ha_max_answers: 50 128 | 129 | fid_model_path: "_data/convmix/convinse/fid/best_dev" 130 | fid_per_gpu_batch_size: 1 131 | fid_max_evidences: 100 132 | 133 | # train 134 | fid_lr: 0.00005 135 | fid_optim: adamw 136 | fid_scheduler: linear 137 | fid_weight_decay: 0.01 138 | fid_text_maxlength: 250 139 | fid_answer_maxlength: 10 140 | fid_total_step: 15000 141 | fid_warmup_step: 1000 142 | 143 | # inference 144 | fid_max_evidences: 100 145 | fid_num_beams: 20 146 | 147 | 148 | -------------------------------------------------------------------------------- /config/convmix/qrew-clocq_bm25-fid.yml: -------------------------------------------------------------------------------- 1 | name: "qrew-clocq_bm25-fid" 2 | log_level: "INFO" 3 | 4 | # Construct pipeline 5 | qu: qrew 6 | ers: clocq_bm25 7 | ha: fid 8 | 9 | # Define source combinations 10 | source_combinations: 11 | - - kb 12 | - - text 13 | - - table 14 | - - info 15 | - - kb 16 | - text 17 | - - kb 18 | - table 19 | - - kb 20 | - info 21 | - - text 22 | - table 23 | - - text 24 | - info 25 | - - table 26 | - info 27 | - - kb 28 | - text 29 | - table 30 | - info 31 | 32 | ################################################################# 33 | # General file paths 34 | ################################################################# 35 | path_to_stopwords: "_data/stopwords.txt" 36 | path_to_labels: "_data/labels.json" 37 | path_to_wikipedia_mappings: "_data/wikipedia_mappings.json" 38 | path_to_wikidata_mappings: "_data/wikidata_mappings.json" 39 | 40 | ################################################################# 41 | # Benchmark specific settings 42 | ################################################################# 43 | benchmark: "convmix" 44 | benchmark_path: "_benchmarks/convmix" 45 | seed_conversations_path: "_benchmarks/convmix/ConvMixSeed.json" 46 | 47 | train_input_path: "train_set/train_set_ALL.json" 48 | dev_input_path: "dev_set/dev_set_ALL.json" 49 | test_input_path: "test_set/test_set_ALL.json" 50 | 51 | 
path_to_annotated: "_intermediate_representations/convmix" # where annotated inputs come from 52 | path_to_intermediate_results: "_intermediate_representations/convmix" 53 | 54 | ################################################################# 55 | # Parameters - CLOCQ 56 | ################################################################# 57 | clocq_params: 58 | h_match: 0.4 59 | h_rel: 0.2 60 | h_conn: 0.3 61 | h_coh: 0.1 62 | d: 20 63 | k: "AUTO" 64 | p_setting: 1000 # setting for search_space function 65 | bm25_limit: False 66 | clocq_p: 1000 # setting for neighborhood function(s) 67 | clocq_use_api: True # using CLOCQClientInterface 68 | clocq_host: "https://clocq.mpi-inf.mpg.de/api" # host for client 69 | clocq_port: "443" # port for client 70 | 71 | ################################################################# 72 | # Parameters - Silver annotation 73 | ################################################################# 74 | # annotation - SR 75 | sr_relation_shared_active: True 76 | sr_remove_stopwords: True 77 | 78 | # OPTIONAL: annotation - turn relevance 79 | tr_transitive_relevances: False 80 | tr_extract_dataset: True 81 | 82 | ################################################################# 83 | # Parameters - QU 84 | ################################################################# 85 | qrew_model_path: "_data/convmix/qrew/qrew.bin" 86 | qrew_max_input_length: 512 87 | 88 | history_separator: " ||| " 89 | 90 | # training parameters 91 | qrew_num_train_epochs: 3 92 | qrew_per_device_train_batch_size: 10 93 | qrew_per_device_eval_batch_size: 10 94 | qrew_warmup_steps: 500 95 | qrew_weight_decay: 0.01 96 | 97 | # generation parameters 98 | qrew_no_repeat_ngram_size: 2 99 | qrew_max_output_length: 100 100 | qrew_do_sample: True 101 | 102 | ################################################################# 103 | # Parameters - ERS 104 | ################################################################# 105 | # cache path 106 | ers_use_cache: True 107 | ers_cache_path: "_data/convmix/qrew/er_cache.pickle" 108 | ers_wikipedia_dump: "_data/convmix/wikipedia_dump.pickle" 109 | ers_on_the_fly: True 110 | 111 | # evidence retrieval 112 | evr_min_evidence_length: 3 113 | evr_max_evidence_length: 200 114 | evr_max_entities: 10 # max entities per evidence 115 | evr_max_pos_evidences: 10 116 | 117 | # evidence scoring 118 | evs_max_evidences: 100 119 | 120 | ################################################################# 121 | # Parameters - HA 122 | ################################################################# 123 | # general 124 | ha_max_answers: 50 125 | 126 | fid_model_path: "_data/convmix/qrew/fid/best_dev" 127 | fid_per_gpu_batch_size: 1 128 | fid_max_evidences: 100 129 | 130 | # train 131 | fid_lr: 0.00005 132 | fid_optim: adamw 133 | fid_scheduler: linear 134 | fid_weight_decay: 0.01 135 | fid_text_maxlength: 250 136 | fid_answer_maxlength: 10 137 | fid_total_step: 15000 138 | fid_warmup_step: 1000 139 | 140 | # inference 141 | fid_max_evidences: 100 142 | fid_num_beams: 20 143 | -------------------------------------------------------------------------------- /convinse/question_understanding/naive_concat/naive_concat.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import logging 4 | 5 | from tqdm import tqdm 6 | 7 | from pathlib import Path 8 | from convinse.question_understanding.question_understanding import QuestionUnderstanding 9 | from convinse.library.utils import get_config 10 | 11 | 12 | class 
NaiveConcat(QuestionUnderstanding):
13 |     """
14 |     Prepend various parts of the ongoing conversation to the current turn.
15 |     A turn refers to the question and the answers.
16 |     Answers can be gold answers or predicted answers.
17 |     - Option 1 (init): Prepend initial turn
18 |     - Option 2 (prev): Prepend previous turn
19 |     - Option 3 (init_prev): Prepend initial and previous turn
20 |     - Option 4 (all): Prepend ALL previous turns
21 |     The option can be set in the config file.
22 |     """
23 | 
24 |     def inference_on_turn(self, turn, history_turns):
25 |         """Run model on single turn and add predictions."""
26 |         intent_explicit = self._prepend_history(history_turns, turn)
27 |         turn["structured_representation"] = intent_explicit
28 |         return turn
29 | 
30 |     def inference_on_conversation(self, conversation):
31 |         """Run inference on a single conversation."""
32 |         history_turns = list()
33 |         for turn in conversation["questions"]:
34 |             # concat history to question
35 |             question = self._prepend_history(history_turns, turn)
36 |             turn["structured_representation"] = question
37 | 
38 |             # append to history
39 |             history_turns.append(turn)
40 |         return conversation
41 | 
42 |     def _prepend_history(self, history_turns, current_turn):
43 |         """
44 |         Transform the relevant turns and current turn into the input text.
45 |         """
46 |         ## consider last turn and first turn only
47 |         if self.config["naive_concat"] == "init_prev":
48 |             if len(history_turns) > 2:
49 |                 history_turns = [history_turns[0], history_turns[-1]]
50 | 
51 |         ## consider first turn only
52 |         elif self.config["naive_concat"] == "init":
53 |             if len(history_turns) > 1:
54 |                 history_turns = [history_turns[0]]
55 | 
56 |         ## consider last turn only
57 |         elif self.config["naive_concat"] == "prev":
58 |             if len(history_turns) > 1:
59 |                 history_turns = [history_turns[-1]]
60 | 
61 |         ## consider ALL turns
62 |         elif self.config["naive_concat"] == "all":
63 |             pass  # keep the full history
64 | 
65 |         ## consider only current turn
66 |         elif self.config["naive_concat"] == "none":
67 |             history_turns = []
68 | 
69 |         else:
70 |             raise Exception("Unknown value for naive_concat!")
71 | 
72 |         # create history text
73 |         history_text = " ".join(
74 |             [self._history_turn_to_text(history_turn) for history_turn in history_turns]
75 |         )
76 | 
77 |         # create input
78 |         current_question = current_turn["question"]
79 |         input_text = f"{history_text} {current_question}"
80 |         return input_text
81 | 
82 |     def _history_turn_to_text(self, history_turn):
83 |         """
84 |         Transform the given history turn to text.
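        For illustration (with made-up values): the history turn "who directed
        Inception?" with gold answer "Christopher Nolan" becomes
        "who directed Inception? Christopher Nolan". With gold answers, multiple
        labels are joined by ", "; in end-to-end mode, only the label of the
        top-1 predicted answer is used.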
85 | """ 86 | turn = history_turn["turn"] 87 | question = history_turn["question"] 88 | 89 | # use predicted answer in end-to-end evaluation 90 | if self.use_gold_answers: 91 | answers = history_turn["answers"] 92 | answers_text = ", ".join([answer["label"] for answer in answers]) 93 | else: 94 | answer = history_turn["pred_answers"][0] 95 | answers_text = answer["label"] 96 | 97 | history_turn_text = f"{question} {answers_text}" 98 | return history_turn_text 99 | 100 | 101 | ####################################################################################################################### 102 | ####################################################################################################################### 103 | if __name__ == "__main__": 104 | if len(sys.argv) != 2: 105 | raise Exception( 106 | "Invalid number of options provided.\nUsage: python convinse/question_understanding/naive_concat/naive_concat.py " 107 | ) 108 | 109 | # load config 110 | config_path = sys.argv[1] 111 | config = get_config(config_path) 112 | naive_concat = NaiveConcat(config, use_gold_answers=True) 113 | naive_concat.inference() 114 | -------------------------------------------------------------------------------- /convinse/question_understanding/structured_representation/dataset_structured_representation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | 4 | from convinse.library.string_library import StringLibrary as string_lib 5 | 6 | 7 | def input_to_text(history_turns, current_turn, history_separator): 8 | """ 9 | Transform the history turns and current turn into the input text. 10 | """ 11 | # create history text 12 | history_text = history_separator.join( 13 | [_history_turn_to_text(history_turn, history_separator) for history_turn in history_turns] 14 | ) 15 | 16 | # create input 17 | current_question = current_turn["question"] 18 | input_text = f"{history_text}{history_separator}{current_question}" 19 | return input_text 20 | 21 | 22 | def _history_turn_to_text(history_turn, history_separator): 23 | """ 24 | Transform the given history turn to text. 25 | """ 26 | question = history_turn["question"] 27 | answers = history_turn["answers"] 28 | answers_text = " ".join([answer["label"] for answer in answers]) 29 | history_turn_text = f"{question}{history_separator}{answers_text}" 30 | return history_turn_text 31 | 32 | 33 | def output_to_text(silver_SR, SR_delimiter): 34 | """ 35 | Transform the given silver abstract representation to text. 36 | The (recursive) list data structure is resolved and flattened. 
37 | """ 38 | sep = ", " 39 | topic, entities, relation, ans_type = silver_SR[0] 40 | 41 | # create individual components 42 | topic = " ".join(topic).strip() 43 | entities = " ".join(entities).strip() 44 | relation = " ".join(relation).strip() 45 | ans_type = ans_type.strip() if ans_type else "" 46 | 47 | # create ar text 48 | sr_text = f"{topic}{SR_delimiter}{entities}{SR_delimiter}{relation}{SR_delimiter}{ans_type}" 49 | 50 | # remove whitespaces in AR 51 | while " " in sr_text: 52 | sr_text = sr_text.replace(" ", " ") 53 | sr_text.replace(" , ", ", ") 54 | sr_text = sr_text.strip() 55 | return sr_text 56 | 57 | 58 | class DatasetStructuredRepresentation(torch.utils.data.Dataset): 59 | def __init__(self, config, tokenizer, path): 60 | self.config = config 61 | self.tokenizer = tokenizer 62 | self.history_separator = config["history_separator"] 63 | self.sr_separator = config["sr_separator"] 64 | 65 | input_encodings, output_encodings, dataset_length = self._load_data(path) 66 | self.input_encodings = input_encodings 67 | self.output_encodings = output_encodings 68 | self.dataset_length = dataset_length 69 | 70 | def __getitem__(self, idx): 71 | item = {key: torch.tensor(val[idx]) for key, val in self.input_encodings.items()} 72 | labels = self.output_encodings["input_ids"][idx] 73 | item = { 74 | "input_ids": item["input_ids"], 75 | "attention_mask": item["attention_mask"], 76 | "labels": labels, 77 | } 78 | return item 79 | 80 | def __len__(self): 81 | return self.dataset_length 82 | 83 | def _load_data(self, path): 84 | """ 85 | Opens the file, and loads the data into 86 | a format that can be put into the model. 87 | 88 | The input dataset should be annotated using 89 | the silver_annotation.py class. 90 | 91 | The whole history is given as input. 
92 | """ 93 | # open data 94 | with open(path, "r") as fp: 95 | dataset = json.load(fp) 96 | 97 | inputs = list() 98 | outputs = list() 99 | 100 | for conversation in dataset: 101 | history = list() 102 | for turn in conversation["questions"]: 103 | # skip examples for which no gold SR was found, or for first turn 104 | if not turn["silver_SR"]: 105 | continue 106 | 107 | inputs.append(input_to_text(history, turn, self.history_separator)) 108 | outputs.append(output_to_text(turn["silver_SR"], self.sr_separator)) 109 | 110 | # append to history 111 | history.append(turn) 112 | 113 | # encode 114 | input_encodings = self.tokenizer( 115 | inputs, padding=True, truncation=True, max_length=self.config["sr_max_input_length"] 116 | ) 117 | output_encodings = self.tokenizer( 118 | outputs, padding=True, truncation=True, max_length=self.config["sr_max_input_length"] 119 | ) 120 | dataset_length = len(inputs) 121 | 122 | return input_encodings, output_encodings, dataset_length 123 | -------------------------------------------------------------------------------- /convinse/question_understanding/question_rewriting/question_rewriting_module.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | 5 | from convinse.library.utils import get_config, get_logger 6 | from convinse.question_understanding.question_understanding import QuestionUnderstanding 7 | from convinse.question_understanding.question_rewriting.question_rewriting_model import ( 8 | QuestionRewritingModel, 9 | ) 10 | import convinse.question_understanding.question_rewriting.dataset_question_rewriting as dataset 11 | 12 | 13 | class QuestionRewritingModule(QuestionUnderstanding): 14 | def __init__(self, config, use_gold_answers): 15 | """Initialize QR module.""" 16 | self.config = config 17 | self.logger = get_logger(__name__, config) 18 | self.use_gold_answers = use_gold_answers 19 | 20 | # create model 21 | self.qr_model = QuestionRewritingModel(config) 22 | self.model_loaded = False 23 | 24 | self.history_separator = config["history_separator"] 25 | 26 | def train(self): 27 | """Train the model on silver AR data.""" 28 | # create paths 29 | self.logger.info(f"Starting training...") 30 | data_dir = self.config["path_to_annotated"] 31 | train_path = os.path.join(data_dir, "annotated_train.json") 32 | dev_path = os.path.join(data_dir, "annotated_dev.json") 33 | self.qr_model.train(train_path, dev_path) 34 | self.logger.info(f"Finished training.") 35 | 36 | def inference_on_conversation(self, conversation): 37 | """Run inference on a single conversation.""" 38 | # load QR model (if required) 39 | self._load() 40 | 41 | # QR model inference 42 | history_turns = list() 43 | for i, turn in enumerate(conversation["questions"]): 44 | # append to history 45 | question = turn["question"] 46 | history_turns.append(question) 47 | 48 | # prepare input (omitt gold answer(s)) 49 | rewrite_input = self.history_separator.join(history_turns) 50 | 51 | # run inference 52 | qrew = self.qr_model.inference(rewrite_input) 53 | turn["structured_representation"] = qrew 54 | 55 | # only append answer if there is a next question 56 | if i + 1 < len(conversation["questions"]): 57 | if self.use_gold_answers: 58 | answer_text = " ".join([answer["label"] for answer in turn["answers"]]) 59 | else: 60 | # answer_text = ", ".join([answer["label"] for answer in turn["pred_answers"]]) 61 | answer_text = turn["pred_answers"][0]["label"] 62 | history_turns.append(answer_text) 63 | return 
conversation
64 | 
65 |     def inference_on_turn(self, turn, history_turns):
66 |         """Run inference on a single turn (and history)."""
67 |         # load QR model (if required)
68 |         self._load()
69 | 
70 |         # QR model inference
71 |         question = turn["question"]
72 |         history_turns.append(question)
73 | 
74 |         # prepare input (omit gold answer(s))
75 |         rewrite_input = self.history_separator.join(history_turns)
76 | 
77 |         # run inference
78 |         intent_explicit = self.qr_model.inference(rewrite_input)
79 |         turn["structured_representation"] = intent_explicit
80 |         return turn
81 | 
82 |     def _load(self):
83 |         """Load the QRew model."""
84 |         # only load if not already done so
85 |         if not self.model_loaded:
86 |             self.qr_model.load()
87 |             self.qr_model.set_eval_mode()
88 |             self.model_loaded = True
89 | 
90 | 
91 | #######################################################################################################################
92 | #######################################################################################################################
93 | if __name__ == "__main__":
94 |     if len(sys.argv) != 3:
95 |         raise Exception(
96 |             "Invalid number of options provided.\nUsage: python convinse/question_understanding/question_rewriting/question_rewriting_module.py --train/--inference <PATH_TO_CONFIG>"
97 |         )
98 | 
99 |     function = sys.argv[1]
100 |     config_path = sys.argv[2]
101 |     config = get_config(config_path)
102 | 
103 |     # train: train model
104 |     if function == "--train":
105 |         qrm = QuestionRewritingModule(config, use_gold_answers=True)
106 |         qrm.train()
107 | 
108 |     # inference: add predictions to data
109 |     elif function == "--inference":
110 |         # load config
111 |         qrm = QuestionRewritingModule(config, use_gold_answers=True)
112 |         qrm.inference()
113 | 
--------------------------------------------------------------------------------
/convinse/question_understanding/question_resolution/question_resolution_module.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import logging
5 | 
6 | from subprocess import Popen, PIPE
7 | 
8 | from convinse.library.utils import get_config, get_logger, store_json_with_mkdir
9 | from convinse.question_understanding.question_understanding import QuestionUnderstanding
10 | import convinse.question_understanding.question_resolution.question_resolution_utils as qres_utils
11 | 
12 | 
13 | class QuestionResolutionModule(QuestionUnderstanding):
14 |     def __init__(self, config, use_gold_answers):
15 |         """Initialize QRes module."""
16 |         self.config = config
17 |         self.logger = get_logger(__name__, config)
18 |         self.use_gold_answers = use_gold_answers
19 |         self.path_to_quretec = "convinse/question_understanding/question_resolution/quretec"
20 |         self.model_id = config["qres_model_id"]
21 |         self.model_dir = config["qres_model_dir"]
22 | 
23 |     def train(self):
24 |         """Train the model on silver SR data."""
25 |         # train model
26 |         self.logger.info("Starting training...")
27 |         input_dir = self.config["path_to_annotated"]
28 |         train_path = os.path.join(input_dir, "annotated_train.json")
29 |         dev_path = os.path.join(input_dir, "annotated_dev.json")
30 |         qres_utils.prepare_data_for_training(self.config, train_path, dev_path)
31 |         benchmark = self.config["benchmark"]
32 |         data_dir = f"_intermediate_representations/{benchmark}/qres/data"
33 | 
34 |         # run training
35 |         COMMAND = ["python", f"{self.path_to_quretec}/run_ner.py"]
36 |         COMMAND += ["--task_name", "ner"]
37 |         COMMAND += ["--bert_model", "bert-large-uncased"]
38 |         COMMAND += ["--max_seq_length", "300"]
39 
| COMMAND += ["--train_batch_size", "20"] 40 | COMMAND += ["--train_on", "train"] 41 | COMMAND += ["--hidden_dropout_prob", "0.4"] 42 | COMMAND += ["--dev_on", "dev"] 43 | COMMAND += ["--do_train"] 44 | COMMAND += ["--data_dir", data_dir] 45 | COMMAND += ["--base_dir", self.model_dir] 46 | COMMAND += ["--model_id", self.model_id] 47 | process = Popen(COMMAND, stdout=sys.stdout, stderr=sys.stderr) 48 | self.logger.info(f"Finished training.") 49 | 50 | def inference_on_data(self, input_data): 51 | """Run model on data and add predictions.""" 52 | benchmark = self.config["benchmark"] 53 | data_dir = f"_intermediate_representations/{benchmark}/qres/data" 54 | output_path = os.path.join(data_dir, "data_for_inference.json") 55 | 56 | # model inference on given data 57 | qres_utils.prepare_data_for_inference( 58 | self.config, input_data, output_path, use_gold_answers=self.use_gold_answers 59 | ) 60 | self._inference() 61 | 62 | # postprocess predictions 63 | quretec_pred_path = os.path.join( 64 | self.model_dir, self.model_id, "eval_results_data_for_inference_epoch0.json" 65 | ) 66 | qres_utils.postprocess_data(input_data, quretec_pred_path) 67 | return input_data 68 | 69 | def inference_on_turn(self, turn, history_turns): 70 | """Run inference on a single turn (and history).""" 71 | if not history_turns: 72 | turn["structured_representation"] = turn["question"] 73 | return turn 74 | 75 | benchmark = self.config["benchmark"] 76 | data_dir = f"_intermediate_representations/{benchmark}/qres/data" 77 | output_path = os.path.join(data_dir, "data_for_inference.json") 78 | 79 | # model inference on given data 80 | qres_utils.prepare_turn_for_inference(self.config, turn, history_turns, output_path, self.use_gold_answers) 81 | self._inference() 82 | 83 | # postprocess predictions 84 | quretec_pred_path = os.path.join( 85 | self.model_dir, self.model_id, "eval_results_data_for_inference_epoch0.json" 86 | ) 87 | qres_utils.postprocess_turn(turn, quretec_pred_path) 88 | return turn 89 | 90 | def _inference(self): 91 | """Run QuReTeC model on given input via separate script.""" 92 | benchmark = self.config["benchmark"] 93 | data_dir = f"_intermediate_representations/{benchmark}/qres/data" 94 | 95 | # run inference 96 | COMMAND = ["python", f"{self.path_to_quretec}/run_ner.py"] 97 | COMMAND += ["--task_name", "ner"] 98 | COMMAND += ["--do_eval"] 99 | COMMAND += ["--do_lower_case"] 100 | COMMAND += ["--data_dir", data_dir] 101 | COMMAND += ["--base_dir", self.model_dir] 102 | COMMAND += ["--dev_on", "data_for_inference"] 103 | COMMAND += ["--model_id", self.model_id] 104 | COMMAND += ["--no_cuda"] 105 | process = Popen(COMMAND, stdout=sys.stdout, stderr=sys.stderr) 106 | process.communicate() 107 | 108 | 109 | ####################################################################################################################### 110 | ####################################################################################################################### 111 | if __name__ == "__main__": 112 | if len(sys.argv) != 3: 113 | raise Exception( 114 | "Usage: python convinse/question_understanding/question_resolution/question_resolution_module.py -- " 115 | ) 116 | 117 | function = sys.argv[1] 118 | config_path = sys.argv[2] 119 | config = get_config(config_path) 120 | 121 | # train: train model 122 | if function == "--train": 123 | qrm = QuestionResolutionModule(config, use_gold_answers=True) 124 | qrm.train() 125 | 126 | # inference: add predictions to data 127 | elif function == "--inference": 128 | # load config 129 
| qrm = QuestionResolutionModule(config, use_gold_answers=True) 130 | qrm.inference() 131 | -------------------------------------------------------------------------------- /convinse/library/custom_trainer.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import inspect 3 | import math 4 | import os 5 | import random 6 | import re 7 | import shutil 8 | import sys 9 | import time 10 | import json 11 | import warnings 12 | from logging import StreamHandler 13 | from pathlib import Path 14 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union 15 | 16 | from tqdm.auto import tqdm 17 | 18 | import numpy as np 19 | import torch 20 | from packaging import version 21 | from torch import nn 22 | from torch.utils.data.dataloader import DataLoader 23 | from torch.utils.data.dataset import Dataset, IterableDataset 24 | from torch.utils.data.distributed import DistributedSampler 25 | from torch.utils.data.sampler import RandomSampler, SequentialSampler 26 | 27 | from transformers import Trainer 28 | from transformers.trainer_utils import speed_metrics 29 | from transformers.debug_utils import DebugOption, DebugUnderflowOverflow 30 | 31 | 32 | class CustomTrainer(Trainer): 33 | def __init__( 34 | self, 35 | model, 36 | args=None, 37 | data_collator=None, 38 | train_dataset=None, 39 | eval_dataset=None, 40 | tokenizer=None, 41 | model_init=None, 42 | compute_metrics=None, 43 | callbacks=None, 44 | optimizers=(None, None), 45 | path_to_best_model="models/model", 46 | ): 47 | super().__init__( 48 | model, 49 | args, 50 | data_collator, 51 | train_dataset, 52 | eval_dataset, 53 | tokenizer, 54 | model_init, 55 | compute_metrics, 56 | callbacks, 57 | optimizers, 58 | ) 59 | self.path_to_best_model = path_to_best_model 60 | 61 | def evaluate( 62 | self, 63 | train_dataset=None, 64 | eval_dataset: Optional[Dataset] = None, 65 | ignore_keys: Optional[List[str]] = None, 66 | metric_key_prefix: str = "eval", 67 | ) -> Dict[str, float]: 68 | 69 | self._memory_tracker.start() 70 | 71 | train_dataloader = self.get_train_dataloader() 72 | eval_dataloader = self.get_eval_dataloader(eval_dataset) 73 | start_time = time.time() 74 | 75 | eval_loop = ( 76 | self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop 77 | ) 78 | 79 | train_output = eval_loop( 80 | train_dataloader, 81 | description="Evaluation on train", 82 | # No point gathering the predictions if there are no metrics, otherwise we defer to 83 | # self.args.prediction_loss_only 84 | prediction_loss_only=True if self.compute_metrics is None else None, 85 | ignore_keys=ignore_keys, 86 | metric_key_prefix="train", 87 | ) 88 | 89 | eval_output = eval_loop( 90 | eval_dataloader, 91 | description="Evaluation on dev", 92 | # No point gathering the predictions if there are no metrics, otherwise we defer to 93 | # self.args.prediction_loss_only 94 | prediction_loss_only=True if self.compute_metrics is None else None, 95 | ignore_keys=ignore_keys, 96 | metric_key_prefix="eval", 97 | ) 98 | 99 | total_batch_size = self.args.eval_batch_size * self.args.world_size 100 | 101 | train_output.metrics.update( 102 | speed_metrics(metric_key_prefix, start_time, train_output.num_samples) 103 | ) 104 | 105 | eval_output.metrics.update( 106 | speed_metrics(metric_key_prefix, start_time, eval_output.num_samples) 107 | ) 108 | 109 | self.log(train_output.metrics) 110 | self.log(eval_output.metrics) 111 | 112 | if DebugOption.TPU_METRICS_DEBUG in 
self.args.debug: 113 | # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 114 | xm.master_print(met.metrics_report()) 115 | 116 | self.control = self.callback_handler.on_evaluate( 117 | self.args, self.state, self.control, train_output.metrics 118 | ) 119 | self.control = self.callback_handler.on_evaluate( 120 | self.args, self.state, self.control, eval_output.metrics 121 | ) 122 | 123 | self._memory_tracker.stop_and_update_metrics(train_output.metrics) 124 | self._memory_tracker.stop_and_update_metrics(eval_output.metrics) 125 | 126 | dic = { 127 | "Training metrics": train_output.metrics, 128 | "Validation metrics": eval_output.metrics, 129 | } 130 | print(eval_output.metrics.keys()) 131 | eval_accuracy = eval_output.metrics["eval_accuracy"] 132 | 133 | # store model if performance improved 134 | if self.state.best_model_checkpoint is None or eval_accuracy > self.state.best_metric: 135 | self.state.best_model_checkpoint = self.state.global_step 136 | self.state.best_metric = eval_accuracy 137 | self._save_model(self.path_to_best_model) 138 | self._store_metadata_best_model() 139 | return dic 140 | 141 | def _store_metadata_best_model(self): 142 | """ 143 | Store metadata of best model to .txt file. 144 | """ 145 | # change extension of path 146 | path, ext = os.path.splitext(self.path_to_best_model) 147 | path_to_metadata = f"{path}.txt" 148 | 149 | # create metadata string 150 | metadata = f"Best metric: {self.state.best_metric}, global_step: {self.state.best_model_checkpoint}" 151 | 152 | # store metadata 153 | with open(path_to_metadata, "w") as fp: 154 | fp.write(metadata) 155 | 156 | def _save_model(self, output_dir: Optional[str] = None): 157 | """ 158 | Stores the best model found so far. 159 | """ 160 | print("Storing best model") 161 | super().save_model(output_dir) 162 | -------------------------------------------------------------------------------- /convinse/evaluation.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from tqdm import tqdm 4 | 5 | from convinse.library.string_library import StringLibrary 6 | from Levenshtein import distance as levenshtein_distance 7 | 8 | 9 | def answer_presence(evidences, answers): 10 | """ 11 | Compute the answer presence for a set of evidences 12 | and a parsed answer dict, and return a list of 13 | answering evidences. 
14 | Return format: (boolean, [evidence-dict, ...]) 15 | """ 16 | # initialize 17 | answer_present = False 18 | answering_evidences = list() 19 | 20 | # go through evidences 21 | for evidence in evidences: 22 | if evidence_has_answer(evidence, answers): 23 | # remember evidence 24 | answer_present = True 25 | answering_evidences.append(evidence) 26 | # return results 27 | return (answer_present, answering_evidences) 28 | 29 | 30 | def evidence_has_answer(evidence, gold_answers): 31 | """Check whether the given evidence has any of the answers.""" 32 | for answer_candidate in evidence["wikidata_entities"]: 33 | # check if answering candidate 34 | if candidate_in_answers(answer_candidate, gold_answers): 35 | return True 36 | return False 37 | 38 | 39 | def candidate_in_answers(answer_candidate, gold_answers): 40 | """Check if candidate is answer.""" 41 | # get ids 42 | answer_candidate_id = answer_candidate["id"] 43 | gold_answer_ids = [answer["id"] for answer in gold_answers] 44 | 45 | # normalize 46 | answer_candidate_id = answer_candidate_id.lower().strip().replace('"', "").replace("+", "") 47 | gold_answer_ids = [answer.lower().strip().replace('"', "") for answer in gold_answer_ids] 48 | 49 | # perform check 50 | if answer_candidate_id in gold_answer_ids: 51 | return True 52 | 53 | # no match found 54 | return False 55 | 56 | 57 | def mrr_score(answers, gold_answers): 58 | """Compute MRR score for given answers and gold answers.""" 59 | # check if any answer was given 60 | if not answers: 61 | return 0.0 62 | # go through answer candidates 63 | for answer in answers: 64 | if candidate_in_answers(answer["answer"], gold_answers): 65 | return 1.0 / float(answer["rank"]) 66 | return 0.0 67 | 68 | 69 | def precision_at_1(answers, gold_answers): 70 | """Compute P@1 score for given answers and gold answers.""" 71 | # check if any answer was given 72 | if not answers: 73 | return 0.0 74 | # go through answer candidates 75 | for answer in answers: 76 | if float(answer["rank"]) > float(1.0): 77 | break 78 | elif candidate_in_answers(answer["answer"], gold_answers): 79 | return 1.0 80 | return 0.0 81 | 82 | 83 | def hit_at_5(answers, gold_answers): 84 | """Compute Hit@5 score for given answers and gold answers.""" 85 | # check if any answer was given 86 | if not answers: 87 | return 0.0 88 | # go through answer candidates 89 | for answer in answers: 90 | if float(answer["rank"]) > float(5.0): 91 | break 92 | elif candidate_in_answers(answer["answer"], gold_answers): 93 | return 1.0 94 | return 0.0 95 | 96 | 97 | def get_ranked_answers(config, generated_answer, turn): 98 | """ 99 | Convert the predicted answer text to a Wikidata ID (or Yes/No), 100 | and return the ranked answers. 101 | Can be used for any method that predicts an answer string (instead of a KB item). 
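    Candidate mentions from the disambiguations of the top evidences are ranked
    by Levenshtein distance to the generated answer string (lower distance =
    better rank). An illustrative return value (with made-up IDs and labels):
        [{"answer": {"id": "Q42", "label": "Douglas Adams"}, "score": 0, "rank": 1}, ...]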
102 | """ 103 | # check if existential (special treatment) 104 | question = turn["question"] 105 | if question_is_existential(question): 106 | ranked_answers = [ 107 | {"answer": {"id": "yes", "label": "yes"}, "score": 1.0, "rank": 1}, 108 | {"answer": {"id": "no", "label": "no"}, "score": 0.5, "rank": 2}, 109 | ] 110 | # no existential 111 | else: 112 | # return dummy answer in case None was found (if no evidences found) 113 | if generated_answer is None: 114 | return [{"answer": {"id": "None", "label": "None"}, "rank": 1, "score": 0.0}] 115 | smallest_diff = 100000 116 | all_answers = list() 117 | mentions = set() 118 | for evidence in turn["top_evidences"]: 119 | for disambiguation in evidence["disambiguations"]: 120 | mention = disambiguation[0] 121 | id = disambiguation[1] 122 | if id is None or id == False: 123 | continue 124 | 125 | # skip duplicates 126 | ans = str(mention) + str(id) 127 | if ans in mentions: 128 | continue 129 | mentions.add(ans) 130 | # exact match 131 | if generated_answer == mention: 132 | diff = 0 133 | # otherwise compute edit distance 134 | else: 135 | diff = levenshtein_distance(generated_answer, mention) 136 | 137 | all_answers.append({"answer": {"id": id, "label": mention}, "score": diff}) 138 | 139 | sorted_answers = sorted(all_answers, key = lambda j: j['score']) 140 | ranked_answers = [ 141 | {"answer": answer["answer"], "score": answer["score"], "rank": i+1} 142 | for i, answer in enumerate(sorted_answers) 143 | ] 144 | 145 | # don't return all answers 146 | max_answers = config["ha_max_answers"] 147 | ranked_answers = ranked_answers[:max_answers] 148 | if not ranked_answers: 149 | ranked_answers = [{"answer": {"id": "None", "label": "None"}, "rank": 1, "score": 0.0}] 150 | return ranked_answers 151 | 152 | 153 | def question_is_existential(question): 154 | existential_keywords = [ 155 | "is", 156 | "are", 157 | "was", 158 | "were", 159 | "am", 160 | "be", 161 | "being", 162 | "been", 163 | "did", 164 | "do", 165 | "does", 166 | "done", 167 | "doing", 168 | "has", 169 | "have", 170 | "had", 171 | "having", 172 | ] 173 | lowercase_question = question.lower() 174 | lowercase_question = lowercase_question.strip() 175 | for keyword in existential_keywords: 176 | if lowercase_question.startswith(keyword): 177 | return True 178 | return False 179 | -------------------------------------------------------------------------------- /convinse/evidence_retrieval_scoring/evidence_retrieval_scoring.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from pathlib import Path 5 | from tqdm import tqdm 6 | 7 | from convinse.library.utils import get_config, get_logger 8 | from convinse.evaluation import answer_presence 9 | 10 | 11 | class EvidenceRetrievalScoring: 12 | """Abstract class for ERs phase.""" 13 | 14 | def __init__(self, config): 15 | """Initialize ERS module.""" 16 | self.config = config 17 | self.logger = get_logger(__name__, config) 18 | 19 | def train(self, sources=None): 20 | """Method used in case no training required for ERS phase.""" 21 | self.logger.info("Module used does not require training.") 22 | 23 | def inference(self, sources=None): 24 | """Run ERS on data and add retrieve top-e evidences for each source combination.""" 25 | input_dir = self.config["path_to_annotated"] 26 | output_dir = self.config["path_to_intermediate_results"] 27 | 28 | qu = self.config["qu"] 29 | ers = self.config["ers"] 30 | 31 | # either use given option, or from config 32 | if not sources is None: 
33 | source_combinations = [sources] 34 | else: 35 | source_combinations = self.config["source_combinations"] 36 | 37 | # go through all combinations 38 | for sources in source_combinations: 39 | sources_string = "_".join(sources) 40 | 41 | input_path = os.path.join(input_dir, qu, "train_qu.json") 42 | output_path = os.path.join(output_dir, qu, ers, sources_string, "train_ers.jsonl") 43 | self.inference_on_data_split(input_path, output_path, sources) 44 | 45 | input_path = os.path.join(input_dir, qu, "dev_qu.json") 46 | output_path = os.path.join(output_dir, qu, ers, sources_string, "dev_ers.jsonl") 47 | self.inference_on_data_split(input_path, output_path, sources) 48 | 49 | input_path = os.path.join(input_dir, qu, "test_qu.json") 50 | output_path = os.path.join(output_dir, qu, ers, sources_string, "test_ers.jsonl") 51 | self.inference_on_data_split(input_path, output_path, sources) 52 | 53 | # store results in cache (if applicable) 54 | self.store_cache() 55 | 56 | def inference_on_data_split(self, input_path, output_path, sources): 57 | """ 58 | Run ERS on the dataset to predict 59 | answering evidences for each SR in the dataset. 60 | """ 61 | # open data 62 | with open(input_path, "r") as fp: 63 | data = json.load(fp) 64 | self.logger.info(f"Input data loaded from: {input_path}.") 65 | 66 | # score 67 | answer_presences = list() 68 | source_to_ans_pres = {"kb": 0, "text": 0, "table": 0, "info": 0, "all": 0} 69 | 70 | # create folder if not exists 71 | output_dir = os.path.dirname(output_path) 72 | Path(output_dir).mkdir(parents=True, exist_ok=True) 73 | 74 | # process data 75 | with open(output_path, "w") as fp: 76 | for conversation in tqdm(data): 77 | for turn in conversation["questions"]: 78 | top_evidences = self.inference_on_turn(turn, sources) 79 | turn["top_evidences"] = top_evidences 80 | 81 | # answer presence 82 | hit, answering_evidences = answer_presence(top_evidences, turn["answers"]) 83 | turn["answer_presence"] = hit 84 | turn["answer_presence_per_src"] = { 85 | evidence["source"]: 1 for evidence in answering_evidences 86 | } 87 | 88 | # write conversation to file 89 | fp.write(json.dumps(conversation)) 90 | fp.write("\n") 91 | 92 | # accumulate results 93 | c_answer_presences = [turn["answer_presence"] for turn in conversation["questions"]] 94 | answer_presences += c_answer_presences 95 | for turn in conversation["questions"]: 96 | answer_presence_per_src = turn["answer_presence_per_src"] 97 | # add per source answer presence 98 | for src, ans_presence in answer_presence_per_src.items(): 99 | source_to_ans_pres[src] += ans_presence 100 | # aggregate overall answer presence for validation 101 | if len(answer_presence_per_src.items()): 102 | source_to_ans_pres["all"] += 1 103 | 104 | # print results 105 | res_path = output_path.replace(".jsonl", ".res") 106 | with open(res_path, "w") as fp: 107 | avg_answer_presence = sum(answer_presences) / len(answer_presences) 108 | fp.write(f"Avg. 
answer presence: {avg_answer_presence}\n") 109 | answer_presence_per_src = { 110 | src: (num / len(answer_presences)) for src, num in source_to_ans_pres.items() 111 | } 112 | fp.write(f"Answer presence per source: {answer_presence_per_src}") 113 | 114 | # log 115 | self.logger.info(f"Done with processing: {input_path}.") 116 | 117 | def inference_on_data(self, input_data, sources=["kb", "text", "table", "info"]): 118 | """Run ERS on given data.""" 119 | input_turns = [turn for conv in input_data for turn in conv["questions"]] 120 | self.inference_on_turns(input_turns, sources) 121 | return input_data 122 | 123 | def inference_on_turns(self, input_turns, sources=["kb", "text", "table", "info"]): 124 | """Run ERS on given turns.""" 125 | for turn in input_turns: 126 | top_evidences = self.inference_on_turn(turn, sources) 127 | turn["top_evidences"] = top_evidences 128 | 129 | # answer presence 130 | hit, answering_evidences = answer_presence(top_evidences, turn["answers"]) 131 | turn["answer_presence"] = hit 132 | turn["answer_presence_per_src"] = { 133 | evidence["source"]: 1 for evidence in answering_evidences 134 | } 135 | return input_turns 136 | 137 | def inference_on_turn(self): 138 | raise Exception( 139 | "This is an abstract function which should be overwritten in a derived class!" 140 | ) 141 | 142 | def store_cache(self): 143 | pass 144 | -------------------------------------------------------------------------------- /convinse/question_understanding/structured_representation/structured_representation_module.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import logging 5 | import random 6 | 7 | from convinse.library.utils import get_config, get_logger 8 | from convinse.question_understanding.question_understanding import QuestionUnderstanding 9 | from convinse.question_understanding.structured_representation.structured_representation_model import ( 10 | StructuredRepresentationModel, 11 | ) 12 | import convinse.question_understanding.structured_representation.dataset_structured_representation as dataset 13 | 14 | 15 | class StructuredRepresentationModule(QuestionUnderstanding): 16 | def __init__(self, config, use_gold_answers): 17 | """Initialize SR module.""" 18 | self.config = config 19 | self.logger = get_logger(__name__, config) 20 | self.use_gold_answers = use_gold_answers 21 | 22 | # create model 23 | self.sr_model = StructuredRepresentationModel(config) 24 | self.model_loaded = False 25 | 26 | self.history_separator = config["history_separator"] 27 | self.sr_delimiter = config["sr_delimiter"] 28 | 29 | def train(self): 30 | """Train the model on silver SR data.""" 31 | # train model 32 | self.logger.info(f"Starting training...") 33 | data_dir = self.config["path_to_annotated"] 34 | train_path = os.path.join(data_dir, "annotated_train.json") 35 | dev_path = os.path.join(data_dir, "annotated_dev.json") 36 | self.sr_model.train(train_path, dev_path) 37 | self.logger.info(f"Finished training.") 38 | 39 | def inference_on_conversation(self, conversation): 40 | """Run inference on a single conversation.""" 41 | # load SR model (if required) 42 | self._load() 43 | 44 | with torch.no_grad(): 45 | # SR model inference 46 | history_turns = list() 47 | for i, turn in enumerate(conversation["questions"]): 48 | self.inference_on_turn(turn, history_turns) 49 | 50 | # only append answer if there is a next question 51 | if i + 1 < len(conversation["questions"]): 52 | if self.use_gold_answers: 53 | 
answer_text = ", ".join([answer["label"] for answer in turn["answers"]]) 54 | else: 55 | # answer_text = ", ".join([answer["label"] for answer in turn["pred_answers"]]) 56 | answer_text = turn["pred_answers"][0]["label"] 57 | history_turns.append(answer_text) 58 | return conversation 59 | 60 | def inference_on_turn(self, turn, history_turns): 61 | """Run inference on a single turn.""" 62 | # load SR model (if required) 63 | self._load() 64 | 65 | with torch.no_grad(): 66 | # SR model inference 67 | question = turn["question"] 68 | history_turns.append(question) 69 | 70 | # prepare input (omitt gold answer(s)) 71 | rewrite_input = self.history_separator.join(history_turns) 72 | 73 | # run inference 74 | sr = self.sr_model.inference(rewrite_input) 75 | turn["structured_representation"] = sr 76 | return turn 77 | 78 | def _load(self): 79 | """Load the SR model.""" 80 | # only load if not already done so 81 | if not self.model_loaded: 82 | self.sr_model.load() 83 | self.sr_model.set_eval_mode() 84 | self.model_loaded = True 85 | 86 | def adjust_sr_for_ablation(self, sr, ablation_type): 87 | """ 88 | Adjust the given SR based on the specific ablation type. 89 | """ 90 | slots = sr.split(self.sr_delimiter, 3) 91 | if len(slots) < 4 and not slots[0]: 92 | # type missing 93 | slots = slots + [""] 94 | elif len(slots) < 4: 95 | # context missing 96 | slots = [""] + slots 97 | if len(slots) < 4: 98 | # fix other (strange) cases 99 | slots = slots + (4 - len(slots)) * [""] 100 | context, entity, pred, ans_type = slots 101 | if ablation_type == "nocontext": 102 | sr = f"{entity.strip()} {self.sr_delimiter} {pred.strip()} {self.sr_delimiter} {ans_type.strip()}" 103 | elif ablation_type == "noentity": 104 | sr = f"{context.strip()} {self.sr_delimiter} {pred.strip()} {self.sr_delimiter} {ans_type.strip()}" 105 | elif ablation_type == "nopred": 106 | sr = f"{context.strip()} {self.sr_delimiter} {entity.strip()} {self.sr_delimiter} {ans_type.strip()}" 107 | elif ablation_type == "notype": 108 | sr = f"{context.strip()} {self.sr_delimiter} {entity.strip()} {self.sr_delimiter} {pred.strip()}" 109 | elif ablation_type == "nostructure": 110 | slots = [context, entity, pred, ans_type] 111 | random.shuffle(slots) 112 | sr = f"{slots[0].strip()} {self.sr_delimiter} {slots[1].strip()} {self.sr_delimiter} {slots[2].strip()} {self.sr_delimiter} {slots[3].strip()}" 113 | elif ablation_type == "full": 114 | sr = f"{context.strip()} {self.sr_delimiter} {entity.strip()} {self.sr_delimiter} {pred.strip()} {self.sr_delimiter} {ans_type.strip()}" 115 | else: 116 | raise Exception(f"Unknown ablation type: {ablation_type}") 117 | return sr 118 | 119 | 120 | ####################################################################################################################### 121 | ####################################################################################################################### 122 | if __name__ == "__main__": 123 | if len(sys.argv) != 3: 124 | raise Exception( 125 | "Usage: python convinse/question_understanding/structured_representation/structured_representation_module.py -- " 126 | ) 127 | 128 | function = sys.argv[1] 129 | config_path = sys.argv[2] 130 | config = get_config(config_path) 131 | 132 | # train: train model 133 | if function == "--train": 134 | srm = StructuredRepresentationModule(config, use_gold_answers=True) 135 | srm.train() 136 | 137 | # inference: add predictions to data 138 | elif function == "--inference": 139 | # load config 140 | srm = StructuredRepresentationModule(config, 
141 |         srm.inference()
142 | 
--------------------------------------------------------------------------------
/convinse/evidence_retrieval_scoring/README.md:
--------------------------------------------------------------------------------
1 | # Evidence Retrieval and Scoring (ERS)
2 | 
3 | Module to retrieve relevant evidences from (heterogeneous) information sources.
4 | 
5 | - [Create your own ERS module](#create-your-own-ers-module)
6 |     - [`inference_on_turn` function](#inference_on_turn-function)
7 |     - [`store_cache` function](#optional-store_cache-function)
8 |     - [`train` function](#optional-train-function)
9 | - [Available information sources](#available-information-sources)
10 | - [Wikidata access](#wikidata-access)
11 | - [Wikipedia access](#wikipedia-access)
12 | - [Evidences format](#evidences-format)
13 | 
14 | 
15 | ## Create your own ERS module
16 | You can inherit from the [`EvidenceRetrievalScoring`](evidence_retrieval_scoring.py) class to create your own ERS module.
17 | Implementing the `inference_on_turn` function is sufficient for the pipeline to run; a minimal sketch is provided at the end of this section.
18 | In case you would like to store intermediate retrieval results, make sure to implement the `store_cache` function, which is called after the ERS module has run.
19 | Further, you need to instantiate a logger in the class, which is used by the parent class.
20 | Alternatively, you can call the `__init__` method of the parent class.
21 | Please find further details below.
22 | 
23 | 
24 | ## `inference_on_turn` function
25 | 
26 | **Inputs**:
27 | - `turn`: the turn that evidences are retrieved for; the intent-explicit form of the current question is given in `turn["structured_representation"]`.
28 | - `sources`: list of input sources.
29 | 
30 | **Description**:
31 | For the given intent-explicit representation of the question, retrieve relevant evidences from heterogeneous information sources. Depending on your individual implementation, the initially retrieved evidences need to be scored to identify the top-*e* most relevant ones (*e* is defined by `evs_max_evidences` in the config).
32 | 
33 | **Output**:
34 | Returns the top-*e* evidences; however, the current pipeline does not make use of the return value.
35 | Make sure to also store these evidences in `turn["top_evidences"]`, and make sure that the config parameter `evs_max_evidences` controls the number of evidences going into the HA part.
36 | 
37 | 
38 | ## [Optional] `store_cache` function
39 | 
40 | **Inputs**: NONE
41 | 
42 | **Description**:
43 | Whatever intermediate retrieval results you obtain in your implementation of the class, you can store these on disk to re-use them in a future run (e.g. for efficiency or reproducibility). The default implementation (in [`EvidenceRetrievalScoring`](evidence_retrieval_scoring.py)) does not do anything. If you do not require storing any data, you can simply skip this function.
44 | 
45 | **Output**: NONE
46 | 
47 | ## [Optional] `train` function
48 | 
49 | **Inputs**: NONE
50 | 
51 | **Description**:
52 | If required, you can train your ERS module here. You can make use of whatever parameters are stored in your .yml file.
53 | 
54 | **Output**: NONE
55 | 
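56 | Below is a minimal sketch of a custom ERS module. Note that this is an illustrative example, not part of the codebase: the `ToyERSModule` class and its naive retrieval logic are hypothetical placeholders for your own retrieval and scoring method.
57 | 
58 | ``` python
59 | from convinse.library.utils import get_logger
60 | from convinse.evidence_retrieval_scoring.evidence_retrieval_scoring import EvidenceRetrievalScoring
61 | 
62 | 
63 | class ToyERSModule(EvidenceRetrievalScoring):
64 |     def __init__(self, config):
65 |         self.config = config
66 |         # instantiate the logger that is used by the parent class
67 |         self.logger = get_logger(__name__, config)
68 | 
69 |     def inference_on_turn(self, turn, sources=["kb", "text", "table", "info"]):
70 |         # intent-explicit form of the current question
71 |         sr = turn["structured_representation"]
72 | 
73 |         # toy retrieval: a real module would query the given sources here
74 |         evidences = [
75 |             {"evidence_text": f"evidence for: {sr}", "source": src} for src in sources
76 |         ]
77 | 
78 |         # keep the top-e evidences, with e given by `evs_max_evidences` in the config
79 |         top_evidences = evidences[: self.config["evs_max_evidences"]]
80 |         turn["top_evidences"] = top_evidences
81 |         return top_evidences
82 | ```
83 | 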
84 | ## Available information sources
85 | The following information sources are implemented in the native CONVINSE pipeline:
86 | - `"kb"`: KB-facts from Wikidata,
87 | - `"text"`: text-snippets (sentence-level) from Wikipedia,
88 | - `"table"`: table-records (row-level) from Wikipedia,
89 | - `"info"`: infobox-entries (attribute-value pairs) from Wikipedia.
90 | 
91 | To specify combinations of information sources (e.g. for retrieval or training), you can either adjust the respective parameters in the config, or provide them as an argument to the bash script. E.g. the option "kb_text_info" specifies that "kb", "text" and "info" should be used.
92 | 
93 | ## Wikidata access
94 | For accessing Wikidata, you can make use of the [`ClocqRetriever`](clocq_er.py) class.
95 | You can:
96 | 1) retrieve relevant KB-facts for a given input snippet, using [CLOCQ](https://clocq.mpi-inf.mpg.de)'s search space retrieval functionality via the `retrieve_evidences` function, specifying the desired [input information sources](#available-information-sources) in a list,
97 | 2) retrieve relevant KB-facts for a given input snippet, using [CLOCQ](https://clocq.mpi-inf.mpg.de)'s search space retrieval functionality via the `retrieve_KB_facts` function, which will only return evidences from the KB, or
98 | 3) retrieve KB-facts for a given Wikidata item ID via the `retrieve_kb_facts_for_item` function.
99 | 
100 | The CLOCQ parameters in the config will be used as input for the CLOCQ functions.
101 | For quickly getting started, you can make use of the publicly available [CLOCQ API](https://clocq.mpi-inf.mpg.de), which is the default setup.
102 | For more efficient access, you can run the CLOCQ algorithm on your local machine. Note that this comes with substantial memory requirements of ~400 GB.
103 | 
104 | ## Wikipedia access
105 | For accessing Wikipedia text, tables and infoboxes, you can use the [`ClocqRetriever`](clocq_er.py) class, or directly use the [`WikipediaRetriever`](wikipedia_retriever/wikipedia_retriever.py) package.
106 | You can:
107 | 1) retrieve facts from Wikipedia for a given input snippet, using [CLOCQ](https://clocq.mpi-inf.mpg.de)'s search space retrieval functionality via the `retrieve_evidences` function, specifying the desired [input information sources](#available-information-sources) in a list,
108 | 2) retrieve facts from Wikipedia for a given Wikidata item ID, using the `retrieve_wikipedia_evidences` function in the [`ClocqRetriever`](clocq_er.py) class, or
109 | 3) retrieve facts from Wikipedia for a given Wikidata item ID, using the `retrieve_wp_evidences` function in the [`WikipediaRetriever`](wikipedia_retriever/wikipedia_retriever.py) package. You can adjust this function as required. Make sure to add the `retrieved_for_entity` key to the resulting evidences (this is not taken care of in this function).
110 | 
111 | Either way, the pipeline will first try to read evidences from the cache or from the Wikipedia dump (specified with the `ers_wikipedia_dump` keyword in the config).
112 | The parameter `ers_on_the_fly` controls whether the Wikipedia API is called on-the-fly to retrieve evidences for entities that are not included in the specified Wikipedia dump. If `ers_on_the_fly=False`, an empty list of evidences is returned in case an entity is not included. A short usage sketch of the `ClocqRetriever` class follows below.
113 | 
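114 | For illustration, the snippet below sketches how evidences could be fetched via the [`ClocqRetriever`](clocq_er.py) class. It is a rough sketch only: the constructor and function signatures shown, as well as the example inputs, are assumptions inferred from the descriptions above, so please check [`clocq_er.py`](clocq_er.py) for the authoritative interface.
115 | 
116 | ``` python
117 | from convinse.library.utils import get_config
118 | from convinse.evidence_retrieval_scoring.clocq_er import ClocqRetriever
119 | 
120 | # load a pipeline config, which holds the CLOCQ parameters
121 | config = get_config("config/convmix/convinse.yml")
122 | retriever = ClocqRetriever(config)
123 | 
124 | # retrieve evidences for an intent-explicit question from selected sources
125 | evidences = retriever.retrieve_evidences("Albert Einstein place of birth", ["kb", "text"])
126 | 
127 | # retrieve KB-facts for a given Wikidata item ID (Q937 = Albert Einstein)
128 | kb_facts = retriever.retrieve_kb_facts_for_item("Q937")
129 | ```
130 | 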
131 | ## Evidences format
132 | Evidences are stored and processed in the following format. If you plan your own implementation of the ERS module, make sure that you match this format.
133 | 
134 | ``` json
135 | {
136 |     "evidence_text": "<EVIDENCE_TEXT>",
137 |     "source": "kb|text|table|info",
138 |     "disambiguations": [["<ENTITY_MENTION>", "<ITEM_ID>"], ...],
139 |     "wikidata_entities": [{"id": "<ITEM_ID>", "label": "