├── LICENSE ├── README.md ├── conda_env.yml ├── dialdoc ├── models │ └── rag │ │ ├── configuration_rag_dialdoc.py │ │ ├── distributed_pytorch_retriever.py │ │ ├── modeling_rag_dialdoc.py │ │ └── retrieval_rag_dialdoc.py └── utils │ └── utils_rag.py ├── scripts ├── convert_dpr_original_checkpoint_to_pytorch.py ├── data_preprocessor.py ├── hf_datasets │ └── doc2dial │ │ └── doc2dial_pub.py ├── model_converter.py ├── rag │ ├── callbacks_rag.py │ ├── eval_rag.py │ ├── finetune_rag_dialdoc.py │ ├── lightning_base.py │ ├── use_own_knowledge_dataset.py │ └── utils_rag.py ├── run_converter.sh ├── run_converter_modelcard.sh ├── run_data_preprocessing.sh ├── run_data_preprocessing_domain.sh ├── run_data_preprocessing_dpr.sh ├── run_data_preprocessing_dpr_domain.sh ├── run_download.sh ├── run_eval_rag_e2e.sh ├── run_eval_rag_re.sh ├── run_finetune_rag_dialdoc.sh ├── run_kb_index.sh ├── run_kb_index_domain.sh ├── run_sharedtask_eval.sh └── sharedtask_eval.py └── sharedtask ├── README.md └── sample_files ├── sample_task_grounding_predictions.json ├── sample_task_references.json └── sample_task_utterance_predictions.json /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MultiDoc2Dial: Modeling Dialogues Grounded in Multiple Documents 2 | This repository provides data and code for the corresponding [paper](https://arxiv.org/abs/2109.12595) "MultiDoc2Dial: Modeling Dialogues Grounded in Multiple Documents" (EMNLP 2021) by Song Feng*, Siva Sankalp Patel*, Hui Wan and Sachindra Joshi.
3 | Please cite the paper and star the repository if you find the paper, data and code useful for your work. 4 | 5 | ```bibtex 6 | @inproceedings{feng2021multidoc2dial, 7 | title={MultiDoc2Dial: Modeling Dialogues Grounded in Multiple Documents}, 8 | author={Feng, Song and Patel, Siva Sankalp and Wan, Hui and Joshi, Sachindra}, 9 | booktitle={EMNLP}, 10 | year={2021} 11 | } 12 | ``` 13 | 14 | ## Installation 15 | 16 | Please refer to `conda_env.yml` for creating a virtual environment. 17 | 18 | ```bash 19 | conda env create -f conda_env.yml 20 | ``` 21 | 22 | Our scripts require the following environment variables to be set: 23 | - `HF_HOME` for caching downloads from [Huggingface](https://huggingface.co/transformers/v4.0.1/installation.html#caching-models) locally. 24 | - `CHECKPOINTS` for saving the checkpoints. 25 | 26 | ## Data 27 | 28 | Please run the following commands to download the data. They will download the document and dialogue data into the folder `data/multidoc2dial`. 29 | 30 | ```bash 31 | cd scripts 32 | ./run_download.sh 33 | ``` 34 | 35 | ### Document preprocessing 36 | 37 | To segment the documents into passages, please refer to 38 | > [`run_data_preprocessing.sh`](scripts/run_data_preprocessing.sh) 39 | 40 | ### Data preprocessing for fine-tuning DPR 41 | 42 | If you are finetuning DPR on MultiDoc2Dial, please refer to [`run_data_preprocessing_dpr.sh`](scripts/run_data_preprocessing_dpr.sh) to create positive and negative examples in the format of [DPR](https://github.com/facebookresearch/DPR). 43 | 44 | ## Run Baselines 45 | 46 | ### Finetuning DPR 47 | 48 | To finetune DPR, we use Facebook [DPR](https://github.com/facebookresearch/DPR) (March 2021 release) with an effective batch size of 128. You can finetune DPR on MultiDoc2Dial data yourself, or use our finetuned version. 49 | 50 | *If you would like to finetune DPR yourself*, please refer to Facebook [DPR](https://github.com/facebookresearch/DPR) for detailed instructions. 51 | 52 | Or 53 | 54 | *If you would like to use our finetuned DPR encoders*, please use the following paths as the model path to the ctx or question encoder (for instance, in [`run_converter_modelcard.sh`](scripts/run_converter_modelcard.sh)): 55 | - `sivasankalpp/dpr-multidoc2dial-token-question-encoder` for the fine-tuned DPR question encoder based on token-segmented document passages ([link](https://huggingface.co/sivasankalpp/dpr-multidoc2dial-token-question-encoder)) 56 | - `sivasankalpp/dpr-multidoc2dial-token-ctx-encoder` for the fine-tuned DPR ctx encoder based on token-segmented document passages ([link](https://huggingface.co/sivasankalpp/dpr-multidoc2dial-token-ctx-encoder)) 57 | - `sivasankalpp/dpr-multidoc2dial-structure-question-encoder` for the fine-tuned DPR question encoder based on structure-segmented document passages ([link](https://huggingface.co/sivasankalpp/dpr-multidoc2dial-structure-question-encoder)) 58 | - `sivasankalpp/dpr-multidoc2dial-structure-ctx-encoder` for the fine-tuned DPR ctx encoder based on structure-segmented document passages ([link](https://huggingface.co/sivasankalpp/dpr-multidoc2dial-structure-ctx-encoder)) 59 | 60 | ### Using finetuned DPR encoders in RAG 61 | 62 | *If you obtain your own finetuned DPR checkpoints,* 63 | 1. Download the following files from the RAG model cards to the "../data" folder 64 | - 65 | - 66 | 67 | 2. Convert your fine-tuned DPR checkpoint and add it to the RAG model. Please refer to [`run_converter.sh`](scripts/run_converter.sh). 68 | 69 | OR 70 | 71 | *If you use our finetuned DPR encoders*, please refer to [`run_converter_modelcard.sh`](scripts/run_converter_modelcard.sh).
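For reference, here is a minimal sketch (not one of the repository scripts) of loading the finetuned DPR encoders listed above directly with Huggingface Transformers; the model identifiers are the ones from the previous section, and the question text is invented for the example:

```python
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
)

# Token-segmented variants; swap "token" for "structure" to use the
# structure-segmented encoders instead.
q_name = "sivasankalpp/dpr-multidoc2dial-token-question-encoder"
ctx_name = "sivasankalpp/dpr-multidoc2dial-token-ctx-encoder"

q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(q_name)
q_encoder = DPRQuestionEncoder.from_pretrained(q_name)
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(ctx_name)
ctx_encoder = DPRContextEncoder.from_pretrained(ctx_name)

# Embed an example query; context passages are embedded the same way
# with the ctx encoder and then indexed with FAISS.
inputs = q_tokenizer("How do I renew my vehicle registration?", return_tensors="pt")
question_embedding = q_encoder(**inputs).pooler_output  # shape (1, 768)
```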
72 | 73 | 74 | ### Finetuning RAG 75 | 76 | Our implementation is based on [Huggingface RAG](https://huggingface.co/docs/transformers/master/model_doc/rag). Please refer to their [README](https://github.com/huggingface/transformers/tree/master/examples/research_projects/rag#readme) for more detailed explanations on document retrieval and finetuning RAG. 77 | 78 | To create the FAISS index, please refer to 79 | > [`run_kb_index.sh`](scripts/run_kb_index.sh) 80 | 81 | To finetune RAG on MultiDoc2Dial data, please refer to 82 | > [`run_finetune_rag_dialdoc.sh`](scripts/run_finetune_rag_dialdoc.sh) 83 | 84 | ## Evaluations 85 | 86 | To evaluate the retrieval results (recall@n at the passage and document level), please refer to 87 | > [`run_eval_rag_re.sh`](scripts/run_eval_rag_re.sh) 88 | 89 | To evaluate the generation results, please refer to 90 | > [`run_eval_rag_e2e.sh`](scripts/run_eval_rag_e2e.sh) 91 | 92 | ## Results 93 | 94 | The following are the evaluation results on the validation set of the agent response generation task. Please refer to the `scripts` folder for the corresponding hyperparameters. 95 | 96 | | Model | F1 | EM | BLEU | r@1 | r@5 | r@10 | 97 | | ----------- | ---- | ---- | ---- | ---- | ---- | ---- | 98 | | D-token-nq | 30.9 | 2.8 | 15.7 | 25.8 | 48.2 | 57.7 | 99 | | D-struct-nq | 31.5 | 3.2 | 16.6 | 27.4 | 51.1 | 60.2 | 100 | | D-token-ft | 33.2 | 3.4 | 18.8 | 35.2 | 63.4 | 72.9 | 101 | | D-struct-ft | 33.7 | 3.5 | 19.5 | 37.5 | 67.0 | 75.8 | 102 | 103 | ## Leaderboard 104 | 105 | Please check out our [**leaderboard**](https://eval.ai/web/challenges/challenge-page/1437/overview) and [**Shared Task**](http://doc2dial.github.io/workshop2022/#shared). 106 | 107 | ## Acknowledgement 108 | 109 | Our code is based on [Huggingface Transformers](https://github.com/huggingface/transformers). Our dataset is based on [Doc2Dial](https://arxiv.org/abs/2011.06623). We thank the authors for sharing their great work. 110 | -------------------------------------------------------------------------------- /conda_env.yml: -------------------------------------------------------------------------------- 1 | name: multidoc2dial 2 | channels: 3 | - pytorch 4 | - defaults 5 | - nvidia 6 | dependencies: 7 | - python=3.9 8 | - pip 9 | - pytorch=1.9.1 10 | - cudatoolkit=11.1 11 | - pip: 12 | - tensorboard==2.5.0 13 | - tensorboardX==2.1 14 | - transformers==4.12.1 15 | - pytorch_lightning==1.1.8 16 | - datasets==1.16.1 17 | - GitPython 18 | - psutil 19 | - sentencepiece 20 | - rouge-score 21 | - sacrebleu 22 | - rank-bm25 23 | - faiss-cpu -------------------------------------------------------------------------------- /dialdoc/models/rag/configuration_rag_dialdoc.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020, The RAG Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ RAG model configuration """ 16 | 17 | import copy 18 | 19 | from transformers.configuration_utils import PretrainedConfig 20 | from transformers.file_utils import add_start_docstrings 21 | from transformers.models.rag.configuration_rag import RagConfig 22 | 23 | 24 | RAG_CONFIG_DOC = r""" 25 | :class:`~transformers.RagConfig` stores the configuration of a `RagModel`. Configuration objects inherit from 26 | :class:`~transformers.PretrainedConfig` and can be used to control the model outputs. Read the documentation from 27 | :class:`~transformers.PretrainedConfig` for more information. 28 | 29 | Args: 30 | title_sep (:obj:`str`, `optional`, defaults to ``" / "``): 31 | Separator inserted between the title and the text of the retrieved document when calling 32 | :class:`~transformers.RagRetriever`. 33 | doc_sep (:obj:`str`, `optional`, defaults to ``" // "``): 34 | Separator inserted between the text of the retrieved document and the original input when calling 35 | :class:`~transformers.RagRetriever`. 36 | n_docs (:obj:`int`, `optional`, defaults to 5): 37 | Number of documents to retrieve. 38 | max_combined_length (:obj:`int`, `optional`, defaults to 300): 39 | Max length of contextualized input returned by :meth:`~transformers.RagRetriever.__call__`. 40 | retrieval_vector_size (:obj:`int`, `optional`, defaults to 768): 41 | Dimensionality of the document embeddings indexed by :class:`~transformers.RagRetriever`. 42 | retrieval_batch_size (:obj:`int`, `optional`, defaults to 8): 43 | Retrieval batch size, defined as the number of queries issued concurrently to the faiss index encapsulated by 44 | :class:`~transformers.RagRetriever`. 45 | dataset (:obj:`str`, `optional`, defaults to :obj:`"wiki_dpr"`): 46 | A dataset identifier of the indexed dataset in HuggingFace Datasets (list all available datasets and ids 47 | using :obj:`datasets.list_datasets()`). 48 | dataset_split (:obj:`str`, `optional`, defaults to :obj:`"train"`): 49 | Which split of the :obj:`dataset` to load. 50 | index_name (:obj:`str`, `optional`, defaults to :obj:`"compressed"`): 51 | The index name of the index associated with the :obj:`dataset`. One can choose between :obj:`"legacy"`, 52 | :obj:`"exact"` and :obj:`"compressed"`. 53 | index_path (:obj:`str`, `optional`): 54 | The path to the serialized faiss index on disk. 55 | passages_path (:obj:`str`, `optional`): 56 | A path to text passages compatible with the faiss index. Required if using 57 | :class:`~transformers.models.rag.retrieval_rag.LegacyIndex` 58 | use_dummy_dataset (:obj:`bool`, `optional`, defaults to ``False``): 59 | Whether to load a "dummy" variant of the dataset specified by :obj:`dataset`. 60 | label_smoothing (:obj:`float`, `optional`, defaults to 0.0): 61 | Only relevant if ``return_loss`` is set to :obj:`True`. Controls the ``epsilon`` parameter value for label 62 | smoothing in the loss calculation. If set to 0, no label smoothing is performed. 63 | do_marginalize (:obj:`bool`, `optional`, defaults to :obj:`False`): 64 | If :obj:`True`, the logits are marginalized over all documents by making use of 65 | ``torch.nn.functional.log_softmax``. 66 | reduce_loss (:obj:`bool`, `optional`, defaults to :obj:`False`): 67 | Whether or not to reduce the NLL loss using the ``torch.Tensor.sum`` operation.
68 | do_deduplication (:obj:`bool`, `optional`, defaults to :obj:`True`): 69 | Whether or not to deduplicate the generations from different context documents for a given input. Has to be 70 | set to :obj:`False` if used while training with a distributed backend. 71 | exclude_bos_score (:obj:`bool`, `optional`, defaults to :obj:`False`): 72 | Whether or not to disregard the BOS token when computing the loss. 73 | output_retrieved (:obj:`bool`, `optional`, defaults to :obj:`False`): 74 | If set to ``True``, :obj:`retrieved_doc_embeds`, :obj:`retrieved_doc_ids`, :obj:`context_input_ids` and 75 | :obj:`context_attention_mask` are returned. See returned tensors for more detail. 76 | use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): 77 | Whether or not the model should return the last key/values attentions (not used by all models). 78 | forced_eos_token_id (:obj:`int`, `optional`): 79 | The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to 80 | :obj:`eos_token_id`. 81 | """ 82 | 83 | 84 | @add_start_docstrings(RAG_CONFIG_DOC) 85 | class DialDocRagConfig(RagConfig): 86 | model_type = "rag" 87 | is_composition = True 88 | 89 | def __init__(self, mapping_file=None, segmentation=None, scoring_func="reranking", dataset="multidoc2dial", *args, **kwargs): 90 | self.mapping_file = mapping_file 91 | self.segmentation = segmentation 92 | self.dataset = dataset 93 | self.scoring_func = scoring_func 94 | super(DialDocRagConfig, self).__init__(*args, **kwargs) 95 | -------------------------------------------------------------------------------- /dialdoc/models/rag/distributed_pytorch_retriever.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import List, Tuple 4 | 5 | import numpy as np 6 | import psutil 7 | import torch 8 | import torch.distributed as dist 9 | 10 | from dialdoc.models.rag.retrieval_rag_dialdoc import DialDocRagRetriever 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class RagPyTorchDistributedRetriever(DialDocRagRetriever): 17 | """ 18 | A distributed retriever built on top of the ``torch.distributed`` communication package. During training, all workers 19 | initialize their own instance of the retriever; however, only the main worker loads the index into memory. The index is stored 20 | in CPU memory. The index will also work well in a non-distributed setup. 21 | 22 | Args: 23 | config (:class:`~transformers.RagConfig`): 24 | The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build. 25 | question_encoder_tokenizer (:class:`~transformers.PreTrainedTokenizer`): 26 | The tokenizer that was used to tokenize the question. 27 | It is used to decode the question and then use the generator_tokenizer. 28 | generator_tokenizer (:class:`~transformers.PreTrainedTokenizer`): 29 | The tokenizer used for the generator part of the RagModel.
30 | index (:class:`~transformers.models.rag.retrieval_rag.Index`, optional, defaults to the one defined by the configuration): 31 | If specified, use this index instead of the one built using the configuration 32 | """ 33 | 34 | def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None): 35 | super().__init__( 36 | config, 37 | question_encoder_tokenizer=question_encoder_tokenizer, 38 | generator_tokenizer=generator_tokenizer, 39 | index=index, 40 | init_retrieval=False, 41 | ) 42 | self.process_group = None 43 | 44 | def init_retrieval(self, distributed_port: int): 45 | """ 46 | Retriever initialization function; it needs to be called from the training process. The function sets some common parameters 47 | and environment variables. On top of that, (only) the main process in the process group loads the index into memory. 48 | 49 | Args: 50 | distributed_port (:obj:`int`): 51 | The port on which the main communication of the training run is carried out. We set the port for retrieval-related 52 | communication as ``distributed_port + 1``. 53 | """ 54 | 55 | logger.info("initializing retrieval") 56 | 57 | # initializing a separate process group for retrieval as the default 58 | # nccl backend doesn't support gather/scatter operations while gloo 59 | # is too slow to replace nccl for the core gpu communication 60 | if dist.is_initialized(): 61 | logger.info("dist initialized") 62 | # needs to be set manually 63 | os.environ["GLOO_SOCKET_IFNAME"] = self._infer_socket_ifname() 64 | # avoid clash with the NCCL port 65 | os.environ["MASTER_PORT"] = str(distributed_port + 1) 66 | self.process_group = dist.new_group(ranks=None, backend="gloo") 67 | 68 | # initialize retriever only on the main worker 69 | if not dist.is_initialized() or self._is_main(): 70 | logger.info("dist not initialized / main") 71 | self.index.init_index() 72 | 73 | # all processes wait until the retriever is initialized by the main process 74 | if dist.is_initialized(): 75 | torch.distributed.barrier(group=self.process_group) 76 | 77 | def _is_main(self): 78 | return dist.get_rank(group=self.process_group) == 0 79 | 80 | def _scattered(self, scatter_list, target_shape, target_type=torch.float32): 81 | target_tensor = torch.empty(target_shape, dtype=target_type) 82 | dist.scatter(target_tensor, src=0, scatter_list=scatter_list, group=self.process_group) 83 | return target_tensor 84 | 85 | def _infer_socket_ifname(self): 86 | addrs = psutil.net_if_addrs() 87 | # a hacky way to deal with varying network interface names 88 | ifname = next((addr for addr in addrs if addr.startswith("e")), None) 89 | return ifname 90 | 91 | def retrieve( 92 | self, 93 | combined_hidden_states: np.ndarray, 94 | current_hidden_states: np.ndarray, 95 | history_hidden_states: np.ndarray, 96 | n_docs: int, 97 | dialog_lengths: List[Tuple] = None, 98 | domain: List[str] = None, 99 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[dict]]: 100 | """ 101 | Retrieves documents for the specified query hidden states. The main process, which has access to the index stored in memory, gathers queries 102 | from all the processes in the main training process group, performs the retrieval and scatters back the results. 103 | 104 | Args: 105 | combined_hidden_states, current_hidden_states, history_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`): 106 | Batches of query vectors (for the full dialogue query, the current turn, and the dialogue history, respectively) to retrieve with. 107 | n_docs (:obj:`int`): 108 | The number of docs retrieved per query.
109 | 110 | Output: 111 | retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`): 112 | The retrieval embeddings of the retrieved docs per query. 113 | doc_ids (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`): 114 | The ids of the documents in the index. 115 | doc_dicts (:obj:`List[dict]`): 116 | The retrieved_doc_embeds examples per query. 117 | """ 118 | 119 | # single GPU training 120 | if not dist.is_initialized(): 121 | doc_ids, retrieved_doc_embeds, doc_scores = self._main_retrieve( 122 | combined_hidden_states, current_hidden_states, history_hidden_states, n_docs, dialog_lengths, domain 123 | ) 124 | return retrieved_doc_embeds, doc_ids, doc_scores, self.index.get_doc_dicts(doc_ids) 125 | 126 | # distributed training 127 | world_size = dist.get_world_size(group=self.process_group) 128 | 129 | # gather logic 130 | gather_list_1 = None 131 | gather_list_2 = None 132 | gather_list_3 = None 133 | if self._is_main(): 134 | gather_list_1 = [torch.empty(combined_hidden_states.shape, dtype=torch.float32) for _ in range(world_size)] 135 | gather_list_2 = [torch.empty(current_hidden_states.shape, dtype=torch.float32) for _ in range(world_size)] 136 | gather_list_3 = [torch.empty(history_hidden_states.shape, dtype=torch.float32) for _ in range(world_size)] 137 | dist.gather(torch.tensor(combined_hidden_states), dst=0, gather_list=gather_list_1, group=self.process_group) 138 | dist.gather(torch.tensor(current_hidden_states), dst=0, gather_list=gather_list_2, group=self.process_group) 139 | dist.gather(torch.tensor(history_hidden_states), dst=0, gather_list=gather_list_3, group=self.process_group) 140 | 141 | # scatter logic 142 | n_queries = combined_hidden_states.shape[0] 143 | scatter_ids = [] 144 | scatter_vectors = [] 145 | scatter_scores = [] 146 | if self._is_main(): 147 | assert len(gather_list_1) == len(gather_list_2) == len(gather_list_3) == world_size 148 | comb_h_s = torch.cat(gather_list_1).numpy() 149 | curr_h_s = torch.cat(gather_list_2).numpy() 150 | hist_h_s = torch.cat(gather_list_3).numpy() 151 | ids, vectors, scores = self._main_retrieve(comb_h_s, curr_h_s, hist_h_s, n_docs, dialog_lengths, domain) 152 | ids, vectors, scores = torch.tensor(ids), torch.tensor(vectors), torch.tensor(scores) 153 | scatter_ids = self._chunk_tensor(ids, n_queries) 154 | scatter_vectors = self._chunk_tensor(vectors, n_queries) 155 | scatter_scores = self._chunk_tensor(scores, n_queries) 156 | 157 | doc_ids = self._scattered(scatter_ids, [n_queries, n_docs], target_type=torch.int64) 158 | retrieved_doc_embeds = self._scattered(scatter_vectors, [n_queries, n_docs, combined_hidden_states.shape[1]]) 159 | doc_scores = self._scattered(scatter_scores, [n_queries, n_docs], torch.float64) 160 | 161 | return retrieved_doc_embeds.numpy(), doc_ids.numpy(), doc_scores.numpy(), self.index.get_doc_dicts(doc_ids) 162 | -------------------------------------------------------------------------------- /dialdoc/models/rag/retrieval_rag_dialdoc.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import os 4 | import time 5 | import torch 6 | import numpy as np 7 | import json 8 | 9 | from transformers.models.rag.retrieval_rag import ( 10 | HFIndexBase, 11 | RagRetriever, 12 | LegacyIndex, 13 | CustomHFIndex, 14 | CanonicalHFIndex, 15 | LEGACY_INDEX_PATH, 16 | ) 17 | from transformers.models.rag.tokenization_rag import RagTokenizer 18 | from transformers.file_utils import requires_backends 19 | from
transformers.tokenization_utils_base import BatchEncoding 20 | 21 | from transformers.utils import logging 22 | 23 | from dialdoc.models.rag.configuration_rag_dialdoc import DialDocRagConfig 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | 28 | class DialDocIndex(CustomHFIndex): 29 | def load_pid_domain_mapping(self, mapping_file): 30 | with open(mapping_file, "r") as f_in: 31 | map = json.load(f_in) 32 | 33 | new_map = {} 34 | for k, v in map.items(): 35 | new_map[int(k)] = v 36 | del map 37 | self.mapping = new_map 38 | 39 | def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 40 | scores, ids = self.dataset.search_batch("embeddings", question_hidden_states, n_docs) 41 | docs = [self.dataset[[i for i in indices if i >= 0]] for indices in ids] 42 | vectors = [doc["embeddings"] for doc in docs] 43 | for i in range(len(vectors)): 44 | if len(vectors[i]) < n_docs: 45 | vectors[i] = np.vstack([vectors[i], np.zeros((n_docs - len(vectors[i]), self.vector_size))]) 46 | return ( 47 | np.array(ids), 48 | np.array(vectors), 49 | np.array(scores), 50 | ) # shapes (batch_size, n_docs), (batch_size, n_docs, d) and (batch_size, n_docs) 51 | 52 | def search_batch_domain(self, embeddings, domain, n_docs=5): 53 | scores, ids = self.dataset.search_batch("embeddings", embeddings, 1200) 54 | filtered_scores, filtered_ids = [], [] 55 | for i in range(len(ids)): 56 | dom = domain[i] 57 | f_s, f_id = [], [] 58 | for score, id in zip(scores[i], ids[i]): 59 | if id != -1 and self.mapping[id] == dom: 60 | f_s.append(score) 61 | f_id.append(id) 62 | if len(f_id) == n_docs: 63 | filtered_scores.append(f_s) 64 | filtered_ids.append(f_id) 65 | break 66 | if 0 < len(f_id) < n_docs: ## band-aid for cases where the retriever finds fewer than n_docs docs in the domain 67 | while len(f_id) < n_docs: 68 | f_id.append(f_id[0]) 69 | f_s.append(f_s[0]) 70 | filtered_scores.append(f_s) 71 | filtered_ids.append(f_id) 72 | ## TODO: handle the case where none of the retrieved docs are in the GT domain 73 | 74 | return filtered_scores, filtered_ids 75 | 76 | def get_top_docs_domain( 77 | self, question_hidden_states: np.ndarray, domain, n_docs=5 78 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 79 | scores, ids = self.search_batch_domain(question_hidden_states, domain, n_docs) 80 | docs = [self.dataset[[i for i in indices if i >= 0]] for indices in ids] 81 | vectors = [doc["embeddings"] for doc in docs] 82 | for i in range(len(vectors)): 83 | if len(vectors[i]) < n_docs: 84 | vectors[i] = np.vstack([vectors[i], np.zeros((n_docs - len(vectors[i]), self.vector_size))]) 85 | 86 | return ( 87 | np.array(ids), 88 | np.array(vectors), 89 | np.array(scores), 90 | ) # shapes (batch_size, n_docs), (batch_size, n_docs, d) and (batch_size, n_docs) 91 | 92 | def get_top_docs_rerank_domain( 93 | self, 94 | combined_hidden_states: np.ndarray, 95 | current_hidden_states: np.ndarray, 96 | n_docs=5, 97 | dialog_lengths=None, 98 | domain=None, 99 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 100 | scores1, ids1 = self.search_batch_domain(combined_hidden_states, domain, n_docs) 101 | scores2, ids2 = self.search_batch_domain(current_hidden_states, domain, n_docs) 102 | ids3 = [[None] * (n_docs * 2) for _ in range(len(ids1))] # one fresh list per row; a shared row list could leak entries across batch items 103 | scores3 = [[0] * (n_docs * 2) for _ in range(len(ids1))] 104 | scores = [] 105 | ids = [] 106 | for r in range(len(ids1)): 107 | if dialog_lengths: 108 | if dialog_lengths[r][0] < 10: 109 | ids.append(ids1[r]) 110 | scores.append(scores1[r]) 111 | continue 112 | n1, n2 = len(ids1[r]), len(ids2[r]) 113 
| i = j = k = 0 114 | while i < n1 and j < n2: 115 | if scores1[r][i] >= scores2[r][j]: 116 | ids3[r][k] = ids1[r][i] 117 | scores3[r][k] = scores1[r][i] 118 | k, i = k + 1, i + 1 119 | else: 120 | ids3[r][k] = ids2[r][j] 121 | scores3[r][k] = scores2[r][j] 122 | k, j = k + 1, j + 1 123 | while i < n1: 124 | ids3[r][k] = ids1[r][i] 125 | scores3[r][k] = scores1[r][i] 126 | k, i = k + 1, i + 1 127 | while j < n2: 128 | ids3[r][k] = ids2[r][j] 129 | scores3[r][k] = scores2[r][j] 130 | k, j = k + 1, j + 1 131 | ids_new = [] 132 | scores_new = [] 133 | for ii, ele in enumerate(ids3[r]): 134 | if ele not in ids_new: 135 | ids_new.append(ele) 136 | scores_new.append(scores3[r][ii]) 137 | ids.append(ids_new[:n_docs]) 138 | scores.append(scores_new[:n_docs]) 139 | docs = [self.dataset[[i for i in indices if i >= 0]] for indices in ids] 140 | vectors = [doc["embeddings"] for doc in docs] 141 | for i in range(len(vectors)): 142 | if len(vectors[i]) < n_docs: 143 | vectors[i] = np.vstack([vectors[i], np.zeros((n_docs - len(vectors[i]), self.vector_size))]) 144 | return np.array(ids), np.array(vectors), np.array(scores) 145 | 146 | def get_top_docs_multihandle( 147 | self, 148 | current_hidden_states: np.ndarray, 149 | history_hidden_states: np.ndarray, 150 | scoring_func, 151 | n_docs=5, 152 | dialog_lengths=None, 153 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 154 | total_docs = len(self.dataset) 155 | scores_current, ids_current = self.dataset.search_batch("embeddings", current_hidden_states, 500) 156 | scores_history, ids_history = self.dataset.search_batch("embeddings", history_hidden_states, 500) 157 | 158 | final_scores = [] 159 | final_ids = [] 160 | for i in range(len(ids_current)): 161 | ids_current_i, scores_current_i = ids_current[i], scores_current[i] 162 | ids_history_i, scores_history_i = ids_history[i], scores_history[i] 163 | 164 | scaling_factor = None 165 | if dialog_lengths: 166 | curr_length, history_length = dialog_lengths[i] 167 | scaling_factor = 1.2 if curr_length > 10 else 1.0 168 | 169 | ## common ids between question and history 170 | common_ids = set(ids_current_i).intersection(set(ids_history_i)) 171 | common_ids = {i for i in common_ids if i >= 0} 172 | if len(common_ids) < n_docs: 173 | logger.info("Only {} common ids found".format(len(common_ids))) 174 | logger.info( 175 | "Picking the best ids from top matches with current turn and adding them to common_ids until we reach n_docs={}".format( 176 | n_docs 177 | ) 178 | ) 179 | new_ids = [] 180 | for id in ids_current_i: 181 | if len(common_ids) == n_docs: 182 | break 183 | if id not in common_ids: 184 | new_ids.append(id) 185 | common_ids.add(id) 186 | 187 | ids_current_i_common, scores_current_i_common = self.filter_ids( 188 | common_ids, ids_current_i, scores_current_i 189 | ) 190 | ids_history_i_common, scores_history_i_common = self.filter_ids( 191 | common_ids, ids_history_i, scores_history_i 192 | ) 193 | 194 | doc_dicts = self.get_doc_dicts(np.array(new_ids)) 195 | for j, id in enumerate(new_ids): 196 | ids_history_i_common.append(id) 197 | score = np.inner(history_hidden_states[i], doc_dicts[j]["embeddings"]) 198 | scores_history_i_common.append(score) 199 | 200 | assert len(ids_current_i_common) == len(ids_history_i_common) 201 | 202 | else: 203 | ## only keep ids and scores that are common between question and history 204 | ids_current_i_common, scores_current_i_common = self.filter_ids( 205 | common_ids, ids_current_i, scores_current_i 206 | ) 207 | ids_history_i_common, scores_history_i_common = 
self.filter_ids( 208 | common_ids, ids_history_i, scores_history_i 209 | ) 210 | 211 | assert len(ids_current_i_common) == len(ids_history_i_common) 212 | 213 | ## sort by ids 214 | q_doc_ids, q_doc_scores = zip(*sorted(zip(ids_current_i_common, scores_current_i_common))) 215 | h_doc_ids, h_doc_scores = zip(*sorted(zip(ids_history_i_common, scores_history_i_common))) 216 | 217 | q_doc_ids, q_doc_scores = list(q_doc_ids), list(q_doc_scores) 218 | h_doc_ids, h_doc_scores = list(h_doc_ids), list(h_doc_scores) 219 | 220 | assert q_doc_ids == h_doc_ids 221 | 222 | ## Combine scores using scoring function 223 | rescored_ids = [] 224 | rescored_scores = [] 225 | for id, q_score, h_score in zip(q_doc_ids, q_doc_scores, h_doc_scores): 226 | rescored_ids.append(id) 227 | inp = torch.Tensor([q_score, h_score]) 228 | if scaling_factor: 229 | rescored_scores.append(scoring_func(inp, scaling_factor).tolist()) 230 | else: 231 | rescored_scores.append(scoring_func(inp).tolist()) 232 | 233 | rescored_scores, rescored_ids = zip(*sorted(zip(rescored_scores, rescored_ids), reverse=True)) 234 | rescored_scores, rescored_ids = list(rescored_scores), list(rescored_ids) 235 | 236 | final_ids.append(rescored_ids[:n_docs]) 237 | final_scores.append(rescored_scores[:n_docs]) 238 | 239 | docs = [self.dataset[[i for i in indices if i >= 0]] for indices in final_ids] 240 | vectors = [doc["embeddings"] for doc in docs] 241 | for i in range(len(vectors)): 242 | if len(vectors[i]) < n_docs: 243 | vectors[i] = np.vstack([vectors[i], np.zeros((n_docs - len(vectors[i]), self.vector_size))]) 244 | return ( 245 | np.array(final_ids), 246 | np.array(vectors), 247 | np.array(final_scores), 248 | ) # shapes (batch_size, n_docs) and (batch_size, n_docs, d) 249 | 250 | def get_top_docs_rerank( 251 | self, 252 | combined_hidden_states: np.ndarray, 253 | current_hidden_states: np.ndarray, 254 | n_docs=5, 255 | dialog_lengths=None, 256 | domain=None, 257 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 258 | scores1, ids1 = self.dataset.search_batch("embeddings", combined_hidden_states, n_docs) 259 | scores2, ids2 = self.dataset.search_batch("embeddings", current_hidden_states, n_docs) 260 | ids3 = [[None] * (n_docs * 2) for _ in range(len(ids1))] # one fresh list per row to avoid aliasing across batch rows 261 | scores3 = [[0] * (n_docs * 2) for _ in range(len(ids1))] 262 | scores = [] 263 | ids = [] 264 | for r in range(len(ids1)): 265 | if dialog_lengths: 266 | if dialog_lengths[r][0] < 10: 267 | ids.append(ids1[r]) 268 | scores.append(scores1[r]) 269 | continue 270 | n1, n2 = len(ids1[r]), len(ids2[r]) 271 | i = j = k = 0 272 | while i < n1 and j < n2: 273 | if scores1[r][i] >= scores2[r][j]: 274 | ids3[r][k] = ids1[r][i] 275 | scores3[r][k] = scores1[r][i] 276 | k, i = k + 1, i + 1 277 | else: 278 | ids3[r][k] = ids2[r][j] 279 | scores3[r][k] = scores2[r][j] 280 | k, j = k + 1, j + 1 281 | while i < n1: 282 | ids3[r][k] = ids1[r][i] 283 | scores3[r][k] = scores1[r][i] 284 | k, i = k + 1, i + 1 285 | while j < n2: 286 | ids3[r][k] = ids2[r][j] 287 | scores3[r][k] = scores2[r][j] 288 | k, j = k + 1, j + 1 289 | ids_new = [] 290 | scores_new = [] 291 | for ii, ele in enumerate(ids3[r]): 292 | if ele not in ids_new: 293 | ids_new.append(ele) 294 | scores_new.append(scores3[r][ii]) 295 | ids.append(ids_new[:n_docs]) 296 | scores.append(scores_new[:n_docs]) 297 | docs = [self.dataset[[i for i in indices if i >= 0]] for indices in ids] 298 | vectors = [doc["embeddings"] for doc in docs] 299 | for i in range(len(vectors)): 300 | if len(vectors[i]) < n_docs: 301 | vectors[i] = np.vstack([vectors[i], np.zeros((n_docs - 
len(vectors[i]), self.vector_size))]) 302 | return np.array(ids), np.array(vectors), np.array(scores) 303 | 304 | 305 | def get_top_n_indices(bm25, query, n=5): 306 | query = query.lower().split() 307 | scores = bm25.get_scores(query) 308 | scores_i = [(i, score) for i, score in enumerate(scores)] 309 | sorted_indices = sorted(scores_i, key=lambda score: score[1], reverse=True) 310 | return sorted_indices[:n] 311 | 312 | 313 | class DialDocRagRetriever(RagRetriever): 314 | def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None, init_retrieval=True): 315 | super().__init__( 316 | config, question_encoder_tokenizer, generator_tokenizer, index=index, init_retrieval=init_retrieval 317 | ) 318 | if config.scoring_func in ["domain", "reranking_domain"]: 319 | self.index.load_pid_domain_mapping(config.mapping_file) 320 | 321 | if config.scoring_func == "nonlinear": 322 | logger.info("Using nonlinear scorer in RagRetriever") 323 | self.nn_scorer = torch.nn.Sequential( 324 | torch.nn.Linear(2, 2), torch.nn.ReLU(), torch.nn.Linear(2, 1), torch.nn.ReLU() 325 | ) 326 | 327 | @staticmethod 328 | def _build_index(config): 329 | if config.index_name == "legacy": 330 | return LegacyIndex( 331 | config.retrieval_vector_size, 332 | config.index_path or LEGACY_INDEX_PATH, 333 | ) 334 | elif config.index_name == "custom": 335 | return CustomHFIndex.load_from_disk( 336 | vector_size=config.retrieval_vector_size, 337 | dataset_path=config.passages_path, 338 | index_path=config.index_path, 339 | ) 340 | elif config.index_name == "dialdoc": 341 | return DialDocIndex.load_from_disk( 342 | vector_size=config.retrieval_vector_size, 343 | dataset_path=config.passages_path, 344 | index_path=config.index_path, 345 | ) 346 | else: 347 | return CanonicalHFIndex( 348 | vector_size=config.retrieval_vector_size, 349 | dataset_name=config.dataset, 350 | dataset_split=config.dataset_split, 351 | index_name=config.index_name, 352 | index_path=config.index_path, 353 | use_dummy_dataset=config.use_dummy_dataset, 354 | ) 355 | 356 | @classmethod 357 | def from_pretrained(cls, retriever_name_or_path, indexed_dataset=None, **kwargs): 358 | requires_backends(cls, ["datasets", "faiss"]) 359 | config = kwargs.pop("config", None) or DialDocRagConfig.from_pretrained(retriever_name_or_path, **kwargs) 360 | rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config) 361 | question_encoder_tokenizer = rag_tokenizer.question_encoder 362 | generator_tokenizer = rag_tokenizer.generator 363 | if indexed_dataset is not None: 364 | config.index_name = "custom" 365 | index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset) 366 | else: 367 | index = cls._build_index(config) 368 | return cls( 369 | config, 370 | question_encoder_tokenizer=question_encoder_tokenizer, 371 | generator_tokenizer=generator_tokenizer, 372 | index=index, 373 | ) 374 | 375 | def postprocess_docs(self, docs, input_strings, prefix, n_docs, return_tensors=None): 376 | r""" 377 | Postprocessing retrieved ``docs`` and combining them with ``input_strings``. 378 | 379 | Args: 380 | docs (:obj:`dict`): 381 | Retrieved documents. 382 | input_strings (:obj:`str`): 383 | Input strings decoded by ``preprocess_query``. 384 | prefix (:obj:`str`): 385 | Prefix added at the beginning of each input, typically used with T5-based models. 386 | 387 | Return: 388 | :obj:`tuple(tensors)`: a tuple consisting of two elements: contextualized ``input_ids`` and a compatible 389 | ``attention_mask``. 
390 | """ 391 | 392 | def cat_input_and_doc(doc_title, doc_text, input_string, prefix): 393 | if doc_title.startswith('"'): 394 | doc_title = doc_title[1:] 395 | if doc_title.endswith('"'): 396 | doc_title = doc_title[:-1] 397 | if prefix is None: 398 | prefix = "" 399 | out = (prefix + input_string + self.config.doc_sep + doc_text).replace("  ", " ") 400 | 401 | return out 402 | 403 | rag_input_strings = [ 404 | cat_input_and_doc( 405 | docs[i]["title"][j], 406 | docs[i]["text"][j], 407 | input_strings[i], 408 | prefix, 409 | ) 410 | for i in range(len(docs)) 411 | for j in range(n_docs) 412 | ] 413 | 414 | contextualized_inputs = self.generator_tokenizer.batch_encode_plus( 415 | rag_input_strings, 416 | max_length=self.config.max_combined_length, 417 | return_tensors=return_tensors, 418 | padding="max_length", 419 | truncation=True, 420 | ) 421 | 422 | return contextualized_inputs["input_ids"], contextualized_inputs["attention_mask"] 423 | 424 | def _main_retrieve( 425 | self, 426 | combined_hidden_states: np.ndarray, 427 | current_hidden_states: np.ndarray, 428 | history_hidden_states: np.ndarray, 429 | n_docs: int, 430 | dialog_lengths: List[Tuple] = None, 431 | domain: List[str] = None, 432 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 433 | def linear(a: List[int]): 434 | return sum(a) 435 | 436 | def linear2(a: List[int]): 437 | return a[0] + 0.5 * a[1] 438 | 439 | def linear3(a: List[int], scaling_factor=1): 440 | return scaling_factor * a[0] + 0.5 * a[1] 441 | 442 | def nonlinear(a: List[int]): 443 | with torch.no_grad(): 444 | return self.nn_scorer(a) 445 | 446 | combined_hidden_states_batched = self._chunk_tensor(combined_hidden_states, self.batch_size) 447 | current_hidden_states_batched = self._chunk_tensor(current_hidden_states, self.batch_size) 448 | history_hidden_states_batched = self._chunk_tensor(history_hidden_states, self.batch_size) 449 | if (domain is None or len(domain) == 0) and self.config.scoring_func != "domain": 450 | domain_batched = [[""]] * len(combined_hidden_states_batched) 451 | else: 452 | domain_batched = self._chunk_tensor(domain, self.batch_size) 453 | ids_batched = [] 454 | vectors_batched = [] 455 | scores_batched = [] 456 | for comb_h_s, curr_h_s, hist_h_s, dom_batch in zip( 457 | combined_hidden_states_batched, 458 | current_hidden_states_batched, 459 | history_hidden_states_batched, 460 | domain_batched, 461 | ): 462 | start_time = time.time() 463 | if self.config.scoring_func in ["linear", "linear2", "linear3", "nonlinear"]: 464 | if self.config.scoring_func == "linear": 465 | dialog_lengths = None 466 | scoring_func = linear 467 | elif self.config.scoring_func == "linear2": 468 | dialog_lengths = None 469 | scoring_func = linear2 470 | elif self.config.scoring_func == "linear3": 471 | scoring_func = linear3 472 | else: 473 | dialog_lengths = None 474 | scoring_func = nonlinear 475 | ids, vectors, scores = self.index.get_top_docs_multihandle( 476 | curr_h_s, hist_h_s, scoring_func, n_docs, dialog_lengths=dialog_lengths 477 | ) 478 | elif self.config.scoring_func in ["reranking_original", "reranking"]: 479 | ids, vectors, scores = self.index.get_top_docs_rerank(comb_h_s, curr_h_s, n_docs, None, dom_batch) 480 | elif self.config.scoring_func == "reranking2": 481 | ids, vectors, scores = self.index.get_top_docs_rerank( 482 | comb_h_s, curr_h_s, n_docs, dialog_lengths=dialog_lengths 483 | ) 484 | elif self.config.scoring_func in ["current_original", "current_pooled"]: 485 | ids, vectors, scores = self.index.get_top_docs(curr_h_s, n_docs) 486 | 
elif self.config.scoring_func in ["domain"]: 487 | ids, vectors, scores = self.index.get_top_docs_domain(comb_h_s, dom_batch, n_docs) 488 | elif self.config.scoring_func in ["reranking_domain"]: 489 | ids, vectors, scores = self.index.get_top_docs_rerank_domain( 490 | comb_h_s, curr_h_s, n_docs, None, dom_batch 491 | ) 492 | else: 493 | ids, vectors, scores = self.index.get_top_docs(comb_h_s, n_docs) 494 | logger.debug(f"index search time: {time.time() - start_time} sec, batch size {comb_h_s.shape}") 495 | ids_batched.extend(ids) 496 | vectors_batched.extend(vectors) 497 | scores_batched.extend(scores) 498 | return ( 499 | np.array(ids_batched), 500 | np.array(vectors_batched), 501 | np.array(scores_batched), 502 | ) # shapes (batch_size, n_docs) and (batch_size, n_docs, d) 503 | 504 | def retrieve( 505 | self, 506 | combined_hidden_states: np.ndarray, 507 | current_hidden_states: np.ndarray, 508 | history_hidden_states: np.ndarray, 509 | n_docs: int, 510 | dialog_lengths: List[Tuple] = None, 511 | domain: List[str] = None, 512 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[dict]]: 513 | """ 514 | Retrieves documents for the specified query hidden states. 515 | 516 | Args: 517 | combined_hidden_states, current_hidden_states, history_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`): 518 | Batches of query vectors (for the full dialogue query, the current turn, and the dialogue history, respectively) to retrieve with. 519 | n_docs (:obj:`int`): 520 | The number of docs retrieved per query. 521 | 522 | Return: 523 | :obj:`Tuple[np.ndarray, np.ndarray, List[dict]]`: A tuple with the following objects: 524 | 525 | - **retrieved_doc_embeds** (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`) -- The retrieval 526 | embeddings of the retrieved docs per query. 527 | - **doc_ids** (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`) -- The ids of the documents in the 528 | index. 529 | - **doc_dicts** (:obj:`List[dict]`): The :obj:`retrieved_doc_embeds` examples per query. 530 | """ 531 | 532 | doc_ids, retrieved_doc_embeds, doc_scores = self._main_retrieve( 533 | combined_hidden_states, current_hidden_states, history_hidden_states, n_docs, dialog_lengths, domain 534 | ) 535 | return retrieved_doc_embeds, doc_ids, doc_scores, self.index.get_doc_dicts(doc_ids) 536 | 537 | def __call__( 538 | self, 539 | question_input_ids: List[List[int]], 540 | combined_hidden_states: np.ndarray, 541 | current_hidden_states: np.ndarray, 542 | history_hidden_states: np.ndarray, 543 | dialog_lengths: List[Tuple], 544 | domain: List[str], 545 | prefix=None, 546 | n_docs=None, 547 | return_tensors=None, 548 | bm25=None, 549 | ) -> BatchEncoding: 550 | """ 551 | Retrieves documents for the specified query hidden states. 552 | 553 | Args: 554 | question_input_ids: (:obj:`List[List[int]]`) batch of input ids 555 | combined_hidden_states, current_hidden_states, history_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`): 556 | Batches of query vectors (for the full dialogue query, the current turn, and the dialogue history, respectively) to retrieve with. 557 | prefix: (:obj:`str`, `optional`): 558 | The prefix used by the generator's tokenizer. 559 | n_docs (:obj:`int`, `optional`): 560 | The number of docs retrieved per query. 561 | return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`, defaults to "pt"): 562 | If set, will return tensors instead of list of python integers. Acceptable values are: 563 | 564 | * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. 565 | * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. 566 | * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
567 | 568 | Returns: :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following 569 | fields: 570 | 571 | - **context_input_ids** -- List of token ids to be fed to a model. 572 | 573 | `What are input IDs? <../glossary.html#input-ids>`__ 574 | 575 | - **context_attention_mask** -- List of indices specifying which tokens should be attended to by the model 576 | (when :obj:`return_attention_mask=True` or if `"attention_mask"` is in :obj:`self.model_input_names`). 577 | 578 | `What are attention masks? <../glossary.html#attention-mask>`__ 579 | 580 | - **retrieved_doc_embeds** -- List of embeddings of the retrieved documents 581 | - **doc_ids** -- List of ids of the retrieved documents 582 | """ 583 | 584 | n_docs = n_docs if n_docs is not None else self.n_docs 585 | prefix = prefix if prefix is not None else self.config.generator.prefix 586 | 587 | input_strings = self.question_encoder_tokenizer.batch_decode(question_input_ids, skip_special_tokens=True) 588 | if self.config.bm25: 589 | doc_ids = [] 590 | doc_scores = [] 591 | for input_string in input_strings: 592 | # doc_ids.append(self.config.bm25.get(input_string, [])[:self.config.n_docs]) 593 | # doc_scores = ??? 594 | sorted_indices = get_top_n_indices(bm25, input_string, self.config.n_docs) 595 | doc_ids.append([x[0] for x in sorted_indices]) 596 | doc_scores.append([x[-1] for x in sorted_indices]) 597 | docs = self.index.get_doc_dicts(np.array(doc_ids)) 598 | 599 | retrieved_doc_embeds = [docs[i]["embeddings"] for i in range(len(doc_ids))] 600 | else: 601 | retrieved_doc_embeds, doc_ids, doc_scores, docs = self.retrieve( 602 | combined_hidden_states=combined_hidden_states, 603 | current_hidden_states=current_hidden_states, 604 | history_hidden_states=history_hidden_states, 605 | n_docs=n_docs, 606 | dialog_lengths=dialog_lengths, 607 | domain=domain, 608 | ) 609 | context_input_ids, context_attention_mask = self.postprocess_docs( 610 | docs, input_strings, prefix, n_docs, return_tensors=return_tensors 611 | ) 612 | 613 | return BatchEncoding( 614 | { 615 | "context_input_ids": context_input_ids, 616 | "context_attention_mask": context_attention_mask, 617 | "retrieved_doc_embeds": retrieved_doc_embeds, 618 | "doc_ids": doc_ids, 619 | "doc_scores": doc_scores, 620 | }, 621 | tensor_type=return_tensors, 622 | ) -------------------------------------------------------------------------------- /dialdoc/utils/utils_rag.py: -------------------------------------------------------------------------------- 1 | import linecache 2 | from pathlib import Path 3 | from typing import Dict 4 | from torch.utils.data import Dataset 5 | 6 | import torch 7 | 8 | from transformers import BartTokenizer, RagTokenizer, T5Tokenizer 9 | 10 | 11 | def load_bm25_results(in_path): 12 | d_query_results = {} 13 | return d_query_results 14 | 15 | 16 | def load_bm25(in_path): 17 | from rank_bm25 import BM25Okapi 18 | from datasets import load_dataset 19 | dataset = load_dataset("csv", data_files=[in_path], split="train", delimiter="\t", column_names=["title", "text"]) 20 | passages = [] 21 | for ex in dataset: 22 | for ele in ex["text"].split("####"): 23 | passages.append(ele) 24 | passages_tokenized = [passage.strip().lower().split() for passage in passages] 25 | bm25 = BM25Okapi(passages_tokenized) 26 | return bm25 27 | 28 | 29 | def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=True, return_tensors="pt"): 30 | extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) and not line.startswith(" ") else {} 31 | 
tokenizer.padding_side = padding_side 32 | return tokenizer( 33 | [line], 34 | max_length=max_length, 35 | padding="max_length" if pad_to_max_length else None, 36 | truncation=True, 37 | return_tensors=return_tensors, 38 | add_special_tokens=True, 39 | **extra_kw, 40 | ) 41 | 42 | 43 | def encode_line2(tokenizer, line, max_length, padding_side, pad_to_max_length=True, return_tensors="pt"): 44 | extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) and not line.startswith(" ") else {} 45 | tokenizer.padding_side = padding_side 46 | line = tuple(line.split("[SEP]")) 47 | return tokenizer( 48 | [line], 49 | max_length=max_length, 50 | padding="max_length" if pad_to_max_length else None, 51 | truncation=True, 52 | return_tensors=return_tensors, 53 | add_special_tokens=True, 54 | **extra_kw, 55 | ) 56 | 57 | 58 | def trim_batch( 59 | input_ids, 60 | pad_token_id, 61 | attention_mask=None, 62 | ): 63 | """Remove columns that are populated exclusively by pad_token_id""" 64 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) 65 | if attention_mask is None: 66 | return input_ids[:, keep_column_mask] 67 | else: 68 | return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) 69 | 70 | 71 | class Seq2SeqDataset(Dataset): 72 | def __init__( 73 | self, 74 | tokenizer, 75 | data_dir, 76 | max_source_length, 77 | max_target_length, 78 | type_path="train", 79 | n_obs=None, 80 | src_lang=None, 81 | tgt_lang=None, 82 | prefix="", 83 | ): 84 | super().__init__() 85 | self.src_file = Path(data_dir).joinpath(type_path + ".source") 86 | self.tgt_file = Path(data_dir).joinpath(type_path + ".target") 87 | self.src_lens = self.get_char_lens(self.src_file) 88 | self.max_source_length = max_source_length 89 | self.max_target_length = max_target_length 90 | assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" 91 | self.tokenizer = tokenizer 92 | self.prefix = prefix 93 | if n_obs is not None: 94 | self.src_lens = self.src_lens[:n_obs] 95 | self.src_lang = src_lang 96 | self.tgt_lang = tgt_lang 97 | 98 | def __len__(self): 99 | return len(self.src_lens) 100 | 101 | def __getitem__(self, index) -> Dict[str, torch.Tensor]: 102 | index = index + 1 # linecache starts at 1 103 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") 104 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 105 | assert source_line, f"empty source line for index {index}" 106 | assert tgt_line, f"empty tgt line for index {index}" 107 | 108 | # Need to add eos token manually for T5 109 | if isinstance(self.tokenizer, T5Tokenizer): 110 | source_line += self.tokenizer.eos_token 111 | tgt_line += self.tokenizer.eos_token 112 | 113 | # Pad source and target to the right 114 | source_tokenizer = ( 115 | self.tokenizer.question_encoder if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer 116 | ) 117 | target_tokenizer = self.tokenizer.generator if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer 118 | 119 | source_inputs = encode_line2(source_tokenizer, source_line, self.max_source_length, "right") 120 | target_inputs = encode_line(target_tokenizer, tgt_line, self.max_target_length, "right") 121 | 122 | source_ids = source_inputs["input_ids"].squeeze() 123 | target_ids = target_inputs["input_ids"].squeeze() 124 | src_mask = source_inputs["attention_mask"].squeeze() 125 | src_token_type_ids = source_inputs["token_type_ids"].squeeze() 126 | return { 127 | "input_ids": source_ids, 128 | "attention_mask": src_mask, 129 
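# token_type_ids mark the current turn (segment 0) vs. the dialogue history
# (segment 1), as produced by encode_line2 from the "[SEP]"-split source line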
| "token_type_ids": src_token_type_ids, 130 | "decoder_input_ids": target_ids, 131 | } 132 | 133 | @staticmethod 134 | def get_char_lens(data_file): 135 | return [len(x) for x in Path(data_file).open().readlines()] 136 | 137 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 138 | input_ids = torch.stack([x["input_ids"] for x in batch]) 139 | masks = torch.stack([x["attention_mask"] for x in batch]) 140 | token_type_ids = torch.stack([x["token_type_ids"] for x in batch]) 141 | target_ids = torch.stack([x["decoder_input_ids"] for x in batch]) 142 | tgt_pad_token_id = ( 143 | self.tokenizer.generator.pad_token_id 144 | if isinstance(self.tokenizer, RagTokenizer) 145 | else self.tokenizer.pad_token_id 146 | ) 147 | src_pad_token_id = ( 148 | self.tokenizer.question_encoder.pad_token_id 149 | if isinstance(self.tokenizer, RagTokenizer) 150 | else self.tokenizer.pad_token_id 151 | ) 152 | y = trim_batch(target_ids, tgt_pad_token_id) 153 | source_ids, source_mask = trim_batch(input_ids, src_pad_token_id, attention_mask=masks) 154 | keep_col_mask = input_ids.ne(src_pad_token_id).any(dim=0) 155 | token_type_ids = token_type_ids[:, keep_col_mask] 156 | batch = { 157 | "input_ids": source_ids, 158 | "attention_mask": source_mask, 159 | "token_type_ids": token_type_ids, 160 | "decoder_input_ids": y, 161 | } 162 | return batch 163 | -------------------------------------------------------------------------------- /scripts/convert_dpr_original_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import argparse 16 | import collections 17 | from pathlib import Path 18 | 19 | import torch 20 | from torch.serialization import default_restore_location 21 | 22 | from transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader 23 | 24 | 25 | CheckpointState = collections.namedtuple( 26 | "CheckpointState", ["model_dict", "optimizer_dict", "scheduler_dict", "offset", "epoch", "encoder_params"] 27 | ) 28 | 29 | 30 | def load_states_from_checkpoint(model_file: str) -> CheckpointState: 31 | print(f"Reading saved model from {model_file}") 32 | state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, "cpu")) 33 | return CheckpointState(**state_dict) 34 | 35 | 36 | class DPRState: 37 | def __init__(self, src_file: Path): 38 | self.src_file = src_file 39 | 40 | def load_dpr_model(self): 41 | raise NotImplementedError 42 | 43 | @staticmethod 44 | def from_type(comp_type: str, *args, **kwargs) -> "DPRState": 45 | if comp_type.startswith("c"): 46 | return DPRContextEncoderState(*args, **kwargs) 47 | if comp_type.startswith("q"): 48 | return DPRQuestionEncoderState(*args, **kwargs) 49 | if comp_type.startswith("r"): 50 | return DPRReaderState(*args, **kwargs) 51 | else: 52 | raise ValueError("Component type must be either 'ctx_encoder', 'question_encoder' or 'reader'.") 53 | 54 | 55 | class DPRContextEncoderState(DPRState): 56 | def load_dpr_model(self): 57 | model = DPRContextEncoder(DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0])) 58 | print(f"Loading DPR biencoder from {self.src_file}") 59 | saved_state = load_states_from_checkpoint(self.src_file) 60 | encoder, prefix = model.ctx_encoder, "ctx_model." 61 | # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3 62 | state_dict = {"bert_model.embeddings.position_ids": model.ctx_encoder.bert_model.embeddings.position_ids} 63 | for key, value in saved_state.model_dict.items(): 64 | if key.startswith(prefix): 65 | key = key[len(prefix) :] 66 | if not key.startswith("encode_proj."): 67 | key = "bert_model." + key 68 | state_dict[key] = value 69 | encoder.load_state_dict(state_dict) 70 | return model 71 | 72 | 73 | class DPRQuestionEncoderState(DPRState): 74 | def load_dpr_model(self): 75 | model = DPRQuestionEncoder(DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0])) 76 | print(f"Loading DPR biencoder from {self.src_file}") 77 | saved_state = load_states_from_checkpoint(self.src_file) 78 | encoder, prefix = model.question_encoder, "question_model." 79 | # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3 80 | state_dict = {"bert_model.embeddings.position_ids": model.question_encoder.bert_model.embeddings.position_ids} 81 | for key, value in saved_state.model_dict.items(): 82 | if key.startswith(prefix): 83 | key = key[len(prefix) :] 84 | if not key.startswith("encode_proj."): 85 | key = "bert_model." 
+ key 86 | state_dict[key] = value 87 | encoder.load_state_dict(state_dict) 88 | return model 89 | 90 | 91 | class DPRReaderState(DPRState): 92 | def load_dpr_model(self): 93 | model = DPRReader(DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0])) 94 | print(f"Loading DPR reader from {self.src_file}") 95 | saved_state = load_states_from_checkpoint(self.src_file) 96 | # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3 97 | state_dict = { 98 | "encoder.bert_model.embeddings.position_ids": model.span_predictor.encoder.bert_model.embeddings.position_ids 99 | } 100 | for key, value in saved_state.model_dict.items(): 101 | if key.startswith("encoder.") and not key.startswith("encoder.encode_proj"): 102 | key = "encoder.bert_model." + key[len("encoder.") :] 103 | state_dict[key] = value 104 | model.span_predictor.load_state_dict(state_dict) 105 | return model 106 | 107 | 108 | def convert(comp_type: str, src_file: Path, dest_dir: Path): 109 | dest_dir = Path(dest_dir) 110 | dest_dir.mkdir(exist_ok=True) 111 | 112 | dpr_state = DPRState.from_type(comp_type, src_file=src_file) 113 | model = dpr_state.load_dpr_model() 114 | model.save_pretrained(dest_dir) 115 | model.from_pretrained(dest_dir) # sanity check 116 | 117 | 118 | if __name__ == "__main__": 119 | parser = argparse.ArgumentParser() 120 | # Required parameters 121 | parser.add_argument( 122 | "--type", type=str, help="Type of the component to convert: 'ctx_encoder', 'question_encoder' or 'reader'." 123 | ) 124 | parser.add_argument( 125 | "--src", 126 | type=str, 127 | help="Path to the dpr checkpoint file. They can be downloaded from the official DPR repo https://github.com/facebookresearch/DPR. Note that in the official repo, both encoders are stored in the 'retriever' checkpoints.", 128 | ) 129 | parser.add_argument("--dest", type=str, default=None, help="Path to the output PyTorch model directory.") 130 | args = parser.parse_args() 131 | 132 | src_file = Path(args.src) 133 | dest_dir = f"converted-{src_file.name}" if args.dest is None else args.dest 134 | dest_dir = Path(dest_dir) 135 | assert src_file.exists() 136 | assert ( 137 | args.type is not None 138 | ), "Please specify the component type of the DPR model to convert: 'ctx_encoder', 'question_encoder' or 'reader'." 
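# convert() writes the HuggingFace-format model and immediately reloads it as a sanity check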
139 | convert(args.type, src_file, dest_dir)
140 |
-------------------------------------------------------------------------------- /scripts/data_preprocessor.py: --------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import argparse
4 | import csv
5 | import sys
6 | from pathlib import Path
7 | from collections import defaultdict
8 | from tqdm import tqdm
9 | from rank_bm25 import BM25Okapi
10 | from datasets import load_dataset
11 |
12 |
13 | DOMAINS = ["va", "ssa", "dmv", "studentaid"]
14 | SEP = "####" # separator for passages
15 |
16 | sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
17 |
18 |
19 | def rm_blank(text, is_shorten=False):
20 | text = text.replace(" ", "").replace("\n", "").replace("\t", "").replace("\r", "")
21 | if is_shorten:
22 | text = text[3:-3]
23 | return text
24 |
25 |
26 | def text2line(text):
27 | return text.replace("\n", " ").replace("\r", " ").replace("\t", " ").strip()
28 |
29 |
30 | def split_text_section(spans, title, args):
31 | def get_text(buff, title, span):
32 | text = " ".join(buff).replace("\n", " ")
33 | parent_titles = [title.replace("/", "-").rsplit("#")[0]]
34 | if len(span["parent_titles"]["text"]) > 1:
35 | parent_titles = [ele.replace("/", "-").rsplit("#")[0] for ele in span["parent_titles"]["text"]]
36 | text = " / ".join(parent_titles) + " // " + text
37 | return text2line(text)
38 |
39 | buff = []
40 | pre_sec, pre_title, pre_span = None, None, None
41 | passages = []
42 | subtitles = []
43 | for span in spans:
44 | parent_titles = [title]
45 | if len(span["parent_titles"]["text"]) > 1:
46 | parent_titles = [ele.replace("/", "-").rsplit("#")[0] for ele in span["parent_titles"]["text"]]
47 | parent_titles = " / ".join(parent_titles)
48 | if pre_sec == span["id_sec"] or pre_title == span["title"].strip():
49 | buff.append(span["text_sp"])
50 | elif buff:
51 | text = get_text(buff, title, pre_span)
52 | passages.append(text)
53 | subtitles.append(parent_titles)
54 | buff = [span["text_sp"]]
55 | else:
56 | buff.append(span["text_sp"])
57 | pre_sec = span["id_sec"]
58 | pre_span = span
59 | pre_title = span["title"].strip()
60 | if buff:
61 | text = get_text(buff, title, span)
62 | passages.append(text)
63 | subtitles.append(parent_titles)
64 | return passages, subtitles
65 |
66 |
67 | def split_text(text: str, n=100, character=" "):
68 | """Split the text every ``n``-th occurrence of ``character``"""
69 | text = text.split(character)
70 | passages = [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)]
71 | return [passage for passage in passages if len(passage) > 0]
72 |
73 |
74 | def get_bm25(passages):
75 | passages_tokenized = [passage.strip().lower().split() for passage in passages]
76 | bm25 = BM25Okapi(passages_tokenized)
77 | return bm25
78 |
79 |
80 | def get_top_n_indices(bm25, query, n=5):
81 | query = query.lower().split()
82 | scores = bm25.get_scores(query)
83 | scores_i = [(i, score) for i, score in enumerate(scores)]
84 | sorted_indices = sorted(scores_i, key=lambda score: score[1], reverse=True)
85 | return [x[0] for x in sorted_indices[:n]]
86 |
87 |
88 | def get_positive_passages(positive_pids, doc_scores, passage_map):
89 | """
90 | Get positive passages for a given grounding using BM25 scores from the positive passage pool
91 | Parameters:
92 | positive_pids: list
93 | Positive passage indices
94 | doc_scores: list
95 | BM25 scores against the query's grounding for all passages
96 | passage_map: dict
97 | All passages mapped with
their ids 98 | Returns: 99 | positive_passage_pool 100 | """ 101 | scores = [(i, score) for (i, score) in doc_scores if i in positive_pids] 102 | top_scores = sorted(scores, key=lambda x: x[1], reverse=True) 103 | 104 | top_n_passages = [ 105 | {"psg_id": ix, "score": score, "title": passage_map[ix]["title"], "text": passage_map[ix]["text"]} 106 | for ix, score in top_scores 107 | ] 108 | 109 | return top_n_passages 110 | 111 | 112 | def get_negative_passages(positive_pids, doc_scores, passage_map, begin=5, n=10): 113 | """ 114 | Get hard negative passages for a given grounding using BM25 scores across all passages. 115 | Filter out all passages from the query's positive passage pool 116 | """ 117 | scores = [(i, score) for (i, score) in doc_scores if i not in positive_pids] 118 | top_scores = sorted(scores, key=lambda x: x[1], reverse=True) 119 | negative_passages = [ 120 | {"psg_id": ix, "score": score, "title": passage_map[ix]["title"], "text": passage_map[ix]["text"]} 121 | for ix, score in top_scores[begin : begin + n] 122 | ] 123 | assert len(negative_passages) == n 124 | return negative_passages 125 | 126 | 127 | def create_dpr_data(args): 128 | dd = DD_Loader(args) 129 | args.split = "train" if not args.split else args.split 130 | dd.get_doc_passages(args) 131 | doc_passages = dd.d_doc_psg 132 | all_passages = dd.doc_psg_all 133 | all_domains = dd.doc_domain_all 134 | 135 | d_in = dd.get_dial(args) 136 | source = d_in["source"] 137 | target = d_in["target"] 138 | qids = d_in["qid"] 139 | titles = d_in["title"] 140 | pids = d_in["pid"] 141 | domains = d_in["domain"] 142 | 143 | passage_map = {} 144 | for title in doc_passages: 145 | psg_start_ix = doc_passages[title][0] 146 | n_psgs = doc_passages[title][1] 147 | for i in range(n_psgs): 148 | passage_map[psg_start_ix + i] = {"text": all_passages[psg_start_ix + i], "title": title} 149 | 150 | # Create passage index using BM25 151 | print("Creating passage index ...") 152 | bm25 = get_bm25(all_passages) 153 | 154 | dataset = [] 155 | for qid, query, grounding, title, pid_pos, domain in tqdm( 156 | zip(qids, source, target, titles, pids, domains), total=len(source), desc="Creating dataset ..." 
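# one iteration per dialogue query; pid_pos holds the ids of the passages
# that contain the query's grounding span (from map_passages)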
157 | ): 158 | if args.last_turn_only: 159 | query = query.split("[SEP]")[0].strip() 160 | scores_g = bm25.get_scores(grounding.strip().lower().split()) 161 | if args.in_domain_only: 162 | doc_scores_g = [] 163 | for idx, score in enumerate(scores_g): 164 | if dd.doc_domain_all[idx] == domain: 165 | doc_scores_g.append((idx, score)) 166 | else: 167 | doc_scores_g = [(i, score) for i, score in enumerate(scores_g)] 168 | positive_passages = get_positive_passages( 169 | positive_pids=pid_pos, doc_scores=doc_scores_g, passage_map=passage_map 170 | ) 171 | hard_negative_passages = get_negative_passages( 172 | positive_pids=pid_pos, doc_scores=doc_scores_g, passage_map=passage_map 173 | ) 174 | scores_q = bm25.get_scores(query.strip().lower().split()) 175 | if args.in_domain_only: 176 | doc_scores_q = [] 177 | for idx, score in enumerate(scores_q): 178 | if all_domains[idx] == domain: 179 | doc_scores_q.append((idx, score)) 180 | else: 181 | doc_scores_q = [(i, score) for i, score in enumerate(scores_q)] 182 | negative_passages = get_negative_passages( 183 | positive_pids=pid_pos, doc_scores=doc_scores_q, passage_map=passage_map 184 | ) 185 | sample = { 186 | "dataset": args.dataset_config_name, 187 | "qid": qid, 188 | "question": query, 189 | "answers": [grounding], 190 | "positive_ctxs": positive_passages, 191 | "negative_ctxs": negative_passages, 192 | "hard_negative_ctxs": hard_negative_passages, 193 | } 194 | dataset.append(sample) 195 | os.makedirs(args.output_dir, exist_ok=True) 196 | if args.target_domain: 197 | config = f"{args.dataset_config_name}.{args.segmentation}" 198 | else: 199 | config = f"{args.dataset_config_name}_all.{args.segmentation}" 200 | outfile = os.path.join(args.output_dir, f"dpr.{config}.{args.split}.json") 201 | print("Writing dataset to {}".format(outfile)) 202 | with open(outfile, "w") as f: 203 | json.dump(dataset, f, indent=4) 204 | passage_file = os.path.join(args.output_dir, f"dpr.psg.{config}.json") 205 | passages = [] 206 | for k, v in sorted(passage_map.items()): 207 | v.update({"id": k}) 208 | passages.append(v) 209 | with open(passage_file, "w") as f: 210 | json.dump(passages, f, indent=4) 211 | 212 | 213 | def map_passages(grounding, all_psgs, start_idx, num_psg): 214 | mapping = [] 215 | for start in range(start_idx, start_idx + num_psg): 216 | current_mapping = [] 217 | for end in range(start + 1, start_idx + num_psg + 1): 218 | content = "".join(all_psgs[start:end]) 219 | if grounding in content or rm_blank(grounding.lower(), True) in rm_blank(content.lower()): 220 | current_mapping = list(range(start, end)) 221 | if len(current_mapping) == 1: 222 | return current_mapping 223 | elif len(current_mapping) > 1: 224 | break 225 | if current_mapping: 226 | mapping = current_mapping 227 | return mapping 228 | 229 | 230 | def load_doc_dataset(args): 231 | doc_data = load_dataset(args.dataset_name, "document_domain", split="train", ignore_verifications=True) 232 | return doc_data 233 | 234 | 235 | class DD_Loader: 236 | def __init__(self, args) -> None: 237 | self.doc_dataset = load_doc_dataset(args) 238 | self.dial_dataset = load_dataset( 239 | args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, ignore_verifications=True 240 | ) 241 | self.d_doc_data = defaultdict(dict) # doc -> "doc_text", "spans" 242 | self.d_doc_psg = {} 243 | self.doc_psg_all = [] 244 | self.doc_domain_all = [] 245 | self.d_pid_domain = {} 246 | 247 | def reset(self): 248 | self.d_doc_data = defaultdict(dict) 249 | self.d_doc_psg = {} 250 | self.doc_psg_all = [] 251 | 
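# reset() clears the per-domain passage caches; the domain-adaptation loop in
# main() calls it when switching to the held-out test domain so the document
# passages are rebuilt for the new set of included domains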
self.doc_domain_all = [] 252 | self.d_pid_domain = {} 253 | 254 | def get_doc_passages(self, args): 255 | # self.doc_dataset = load_doc_dataset(args) 256 | start_idx = 0 257 | for ex in self.doc_dataset: 258 | if args.target_domain and ex["domain"] not in args.included_domains: 259 | continue 260 | if args.segmentation == "token": 261 | passages = split_text(ex["doc_text"]) 262 | else: 263 | passages, subtitles = split_text_section(ex["spans"], ex["title"], args) 264 | self.doc_psg_all.extend(passages) 265 | self.doc_domain_all.extend([ex["domain"]] * len(passages)) 266 | self.d_doc_psg[ex["doc_id"]] = (start_idx, len(passages)) 267 | for i in range(start_idx, start_idx + len(passages)): 268 | self.d_pid_domain[i] = ex["domain"] 269 | start_idx += len(passages) 270 | self.d_doc_data[ex["doc_id"]]["doc_text"] = ex["doc_text"] 271 | self.d_doc_data[ex["doc_id"]]["spans"] = {} 272 | self.d_doc_data[ex["doc_id"]]["domain"] = ex["domain"] 273 | for d_span in ex["spans"]: 274 | self.d_doc_data[ex["doc_id"]]["spans"][d_span["id_sp"]] = d_span 275 | 276 | def get_dial(self, args): 277 | source, target, qids, titles, pids, domains, das = [], [], [], [], [], [], [] 278 | # self.dial_dataset = load_dataset( 279 | # args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, ignore_verifications=True 280 | # ) 281 | 282 | for ex in self.dial_dataset[args.split]: 283 | qid = ex["id"] 284 | doc_id = ex["title"] 285 | query = ex["question"] 286 | domain = ex["domain"] 287 | da = ex["da"] 288 | if args.num_token > 0: 289 | query = " ".join(query.split()[: args.num_token]) 290 | grounding = ex["answers"]["text"][0] 291 | utterance = ex.get("utterance", "") 292 | source_txt = text2line(query) 293 | target_txt = text2line(utterance) if args.task == "generation" else text2line(grounding) 294 | if not source_txt or not target_txt: 295 | continue 296 | start_idx, num_psg = self.d_doc_psg[doc_id] 297 | pids_pos = map_passages(grounding, self.doc_psg_all, start_idx, num_psg) 298 | source.append(source_txt) 299 | target.append(target_txt) 300 | qids.append(qid) 301 | pids.append(pids_pos) 302 | titles.append(doc_id) 303 | domains.append(domain) 304 | das.append(da) 305 | d_out = { 306 | "source": source, 307 | "target": target, 308 | "qid": qids, 309 | "title": titles, 310 | "pid": pids, 311 | "domain": domains, 312 | "da": das, 313 | } 314 | return d_out 315 | 316 | def save_kb_files(self, args): 317 | os.makedirs(args.kb_dir, exist_ok=True) 318 | if args.target_domain and len(args.included_domains) > 1: 319 | config = f"{args.segmentation}-wo-{args.target_domain}" 320 | elif args.target_domain and len(args.included_domains) == 1: 321 | config = f"{args.segmentation}-{args.target_domain}" 322 | else: 323 | config = f"{args.segmentation}-all" 324 | with open( 325 | os.path.join(args.kb_dir, f"mdd-{config}.csv"), 326 | "w", 327 | encoding="utf8", 328 | ) as fp: 329 | csv_writer = csv.writer(fp, delimiter="\t") 330 | for k, (start_id, num_psg) in self.d_doc_psg.items(): 331 | psgs = [text2line(e) for e in self.doc_psg_all[start_id : start_id + num_psg]] 332 | csv_writer.writerow([k, SEP.join(psgs)]) 333 | with open(os.path.join(args.kb_dir, f"pid_domain-{config}.json"), "w", encoding="utf8") as fp: 334 | json.dump(self.d_pid_domain, fp, indent=4) 335 | 336 | def save_dial_files(self, args, d_in): 337 | sp = "val" if args.split == "validation" else args.split 338 | if not args.output_dir: 339 | args.output_dir = f"data_mdd_wo_{args.target_domain}" 340 | od = 
f"{args.output_dir}/dd-{args.task}-{args.segmentation}" 341 | os.makedirs(od, exist_ok=True) 342 | source = d_in["source"] 343 | target = d_in["target"] 344 | qids = d_in["qid"] 345 | titles = d_in["title"] 346 | pids = d_in["pid"] 347 | domains = d_in["domain"] 348 | das = d_in["da"] 349 | 350 | with open(os.path.join(od, f"{sp}.domain"), "w", encoding="utf8") as fp: 351 | fp.write("\n".join(domains)) 352 | with open(os.path.join(od, f"{sp}.da"), "w", encoding="utf8") as fp: 353 | fp.write("\n".join(das)) 354 | with open(os.path.join(od, f"{sp}.source"), "w", encoding="utf8") as fp: 355 | fp.write("\n".join(source)) 356 | with open(os.path.join(od, f"{sp}.target"), "w", encoding="utf8") as fp: 357 | fp.write("\n".join(target)) 358 | with open(os.path.join(od, f"{sp}.qids"), "w", encoding="utf8") as fp: 359 | fp.write("\n".join(qids)) 360 | with open(os.path.join(od, f"{sp}.titles"), "w", encoding="utf8") as fp: 361 | fp.write("\n".join(titles)) 362 | with open(os.path.join(od, f"{sp}.pids"), "w", encoding="utf8") as fp: 363 | lines_pid = [] 364 | for ids in pids: 365 | lines_pid.append("\t".join([str(e) for e in ids])) 366 | fp.write("\n".join(lines_pid)) 367 | 368 | 369 | def main(): 370 | parser = argparse.ArgumentParser() 371 | parser.add_argument( 372 | "--dataset_name", 373 | type=str, 374 | default="hf_datasets/doc2dial/doc2dial_pub.py", 375 | help="dataset name or path for data loader", 376 | ) 377 | parser.add_argument( 378 | "--dataset_config_name", 379 | type=str, 380 | default="multidoc2dial", 381 | help="hugging dataset config name", 382 | ) 383 | parser.add_argument( 384 | "--target_domain", 385 | type=str, 386 | default="", # default is empty, which indicates that all domains are included. 387 | help="target or test domain in domain adaptation setup, one domain from ssa, va, dmv, studentaid", 388 | ) 389 | parser.add_argument( 390 | "--output_dir", 391 | type=str, 392 | required=True, 393 | help="path to output the data files", 394 | ) 395 | parser.add_argument( 396 | "--kb_dir", 397 | type=str, 398 | default="YOUR_DIR/data_mdd_kb", 399 | help="path to output kb data files", 400 | ) 401 | parser.add_argument( 402 | "--cache_dir", 403 | type=str, 404 | default=os.environ["HF_HOME"], 405 | help="Path for caching the downloaded data by HuggingFace Datasets", 406 | ) 407 | parser.add_argument( 408 | "--split", 409 | type=str, 410 | default="", 411 | help="Data split is 'train', 'validation' or 'test'", 412 | ) 413 | parser.add_argument( 414 | "--last_turn_only", 415 | type=bool, 416 | help="Only include the latest turn in dialogue", 417 | ) 418 | parser.add_argument( 419 | "--segmentation", 420 | type=str, 421 | default="structure", 422 | help="`token` or `structure`", 423 | ) 424 | parser.add_argument( 425 | "--num_token", 426 | type=int, 427 | default=-1, 428 | help="number of tokens of a query; -1 indicates all tokens", 429 | ) 430 | parser.add_argument( 431 | "--task", 432 | default="grounding", 433 | help="task: grounding, generation", 434 | ) 435 | parser.add_argument( 436 | "--dpr", 437 | action="store_true", 438 | help="generate DPR data", 439 | ) 440 | parser.add_argument( 441 | "--in_domain_only", 442 | action="store_true", 443 | help="bm25 retrievals within domain", 444 | ) 445 | 446 | args = parser.parse_args() 447 | if not args.dataset_config_name: 448 | args.dataset_config_name = "multidoc2dial" 449 | if args.target_domain: 450 | args.dataset_config_name = f"multidoc2dial_{args.target_domain}" 451 | if not args.dpr: 452 | dd = DD_Loader(args) 453 | splits = 
[args.split] if args.split else ["train", "validation", "test"] # test split at last 454 | if not args.target_domain: 455 | dd.get_doc_passages(args) 456 | dd.save_kb_files(args) 457 | for split in splits: 458 | args.split = split 459 | d_out = dd.get_dial(args) 460 | dd.save_dial_files(args, d_out) 461 | else: 462 | for split in splits: 463 | args.split = split 464 | if split == "test": 465 | args.included_domains = [args.target_domain] 466 | dd.reset() 467 | else: 468 | args.included_domains = [ele for ele in DOMAINS if ele != args.target_domain] 469 | if not dd.doc_psg_all: 470 | dd.get_doc_passages(args) 471 | d_out = dd.get_dial(args) 472 | dd.save_kb_files(args) 473 | dd.save_dial_files(args, d_out) 474 | else: 475 | if args.target_domain: 476 | args.included_domains = [ele for ele in DOMAINS if ele != args.target_domain] 477 | create_dpr_data(args) 478 | 479 | 480 | if __name__ == "__main__": 481 | main() -------------------------------------------------------------------------------- /scripts/hf_datasets/doc2dial/doc2dial_pub.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Doc2dial: A Goal-Oriented Document-Grounded Dialogue Dataset v1.0.1""" 18 | 19 | 20 | import json 21 | import os 22 | from types import CodeType 23 | 24 | import datasets 25 | 26 | MAX_Q_LEN = 128 27 | DATA_DIR = "../data" 28 | 29 | logger = datasets.logging.get_logger(__name__) 30 | 31 | _CITATION = """\ 32 | @inproceedings{feng-etal-2020-doc2dial, 33 | title = "doc2dial: A Goal-Oriented Document-Grounded Dialogue Dataset", 34 | author = "Feng, Song and Wan, Hui and Gunasekara, Chulaka and Patel, Siva and Joshi, Sachindra and Lastras, Luis", 35 | booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)", 36 | month = nov, 37 | year = "2020", 38 | publisher = "Association for Computational Linguistics", 39 | url = "https://www.aclweb.org/anthology/2020.emnlp-main.652", 40 | } 41 | """ 42 | 43 | _DESCRIPTION = """\ 44 | Doc2dial is dataset of goal-oriented dialogues that are grounded in the associated documents. \ 45 | It includes over 4500 annotated conversations with an average of 14 turns that are grounded \ 46 | in over 450 documents from four domains. Compared to the prior document-grounded dialogue datasets \ 47 | this dataset covers a variety of dialogue scenes in information-seeking conversations. 
48 | """ 49 | 50 | _HOMEPAGE = "http://doc2dial.github.io/multidoc2dial/" 51 | 52 | 53 | _URL = "https://doc2dial.github.io/multidoc2dial/file/" 54 | 55 | _URLs = { 56 | "default": _URL + "multidoc2dial.zip", 57 | "domain": _URL + "multidoc2dial_domain.zip", 58 | } 59 | 60 | 61 | class Doc2dial(datasets.GeneratorBasedBuilder): 62 | "MultiDoc2Dial v1.0" 63 | 64 | VERSION = datasets.Version("1.0.0") 65 | 66 | BUILDER_CONFIGS = [ 67 | datasets.BuilderConfig( 68 | name="dialogue_domain", 69 | version=VERSION, 70 | description="This part of the dataset covers the dialgoue domain that has questions, answers and the associated doc ids", 71 | ), 72 | datasets.BuilderConfig( 73 | name="document_domain", 74 | version=VERSION, 75 | description="This part of the dataset covers the document domain which details all the documents in the various domains", 76 | ), 77 | datasets.BuilderConfig( 78 | name="multidoc2dial", 79 | version=VERSION, 80 | description="Load MultiDoc2Dial dataset for machine reading comprehension tasks by domain", 81 | ), 82 | datasets.BuilderConfig( 83 | name="multidoc2dial_dmv", 84 | version=VERSION, 85 | description="Load MultiDoc2Dial dataset for machine reading comprehension tasks by domain", 86 | ), 87 | datasets.BuilderConfig( 88 | name="multidoc2dial_ssa", 89 | version=VERSION, 90 | description="Load MultiDoc2Dial dataset for machine reading comprehension tasks by domain", 91 | ), 92 | datasets.BuilderConfig( 93 | name="multidoc2dial_va", 94 | version=VERSION, 95 | description="Load MultiDoc2Dial dataset for machine reading comprehension tasks by domain", 96 | ), 97 | datasets.BuilderConfig( 98 | name="multidoc2dial_studentaid", 99 | version=VERSION, 100 | description="Load MultiDoc2Dial dataset for machine reading comprehension tasks by domain", 101 | ), 102 | ] 103 | 104 | DEFAULT_CONFIG_NAME = "multidoc2dial" 105 | 106 | def _info(self): 107 | 108 | if self.config.name == "dialogue_domain": 109 | features = datasets.Features( 110 | { 111 | "dial_id": datasets.Value("string"), 112 | "doc_id": datasets.Value("string"), 113 | "domain": datasets.Value("string"), 114 | "turns": [ 115 | { 116 | "turn_id": datasets.Value("int32"), 117 | "role": datasets.Value("string"), 118 | "da": datasets.Value("string"), 119 | "references": [ 120 | { 121 | "id_sp": datasets.Value("string"), 122 | "label": datasets.Value("string"), 123 | } 124 | ], 125 | "utterance": datasets.Value("string"), 126 | } 127 | ], 128 | } 129 | ) 130 | 131 | elif "document_domain" in self.config.name: 132 | features = datasets.Features( 133 | { 134 | "domain": datasets.Value("string"), 135 | "doc_id": datasets.Value("string"), 136 | "title": datasets.Value("string"), 137 | "doc_text": datasets.Value("string"), 138 | "spans": [ 139 | { 140 | "id_sp": datasets.Value("string"), 141 | "tag": datasets.Value("string"), 142 | "start_sp": datasets.Value("int32"), 143 | "end_sp": datasets.Value("int32"), 144 | "text_sp": datasets.Value("string"), 145 | "title": datasets.Value("string"), 146 | "parent_titles": datasets.features.Sequence( 147 | { 148 | "id_sp": datasets.Value("string"), 149 | "text": datasets.Value("string"), 150 | "level": datasets.Value("string"), 151 | } 152 | ), 153 | "id_sec": datasets.Value("string"), 154 | "start_sec": datasets.Value("int32"), 155 | "text_sec": datasets.Value("string"), 156 | "end_sec": datasets.Value("int32"), 157 | } 158 | ], 159 | "doc_html_ts": datasets.Value("string"), 160 | "doc_html_raw": datasets.Value("string"), 161 | } 162 | ) 163 | 164 | else: 165 | features = 
datasets.Features( 166 | { 167 | "id": datasets.Value("string"), 168 | "title": datasets.Value("string"), 169 | "context": datasets.Value("string"), 170 | "question": datasets.Value("string"), 171 | "da": datasets.Value("string"), 172 | "answers": datasets.features.Sequence( 173 | { 174 | "text": datasets.Value("string"), 175 | "answer_start": datasets.Value("int32"), 176 | } 177 | ), 178 | "utterance": datasets.Value("string"), 179 | "domain": datasets.Value("string"), 180 | } 181 | ) 182 | 183 | return datasets.DatasetInfo( 184 | description=_DESCRIPTION, 185 | features=features, 186 | supervised_keys=None, 187 | homepage=_HOMEPAGE, 188 | citation=_CITATION, 189 | ) 190 | 191 | def _split_generators(self, dl_manager): 192 | 193 | my_urls = _URLs 194 | 195 | # data_dir = dl_manager.download_and_extract(my_urls) 196 | data_dir = DATA_DIR 197 | 198 | if self.config.name == "dialogue_domain": 199 | return [ 200 | datasets.SplitGenerator( 201 | name=datasets.Split.TRAIN, 202 | gen_kwargs={ 203 | "filepath": os.path.join(data_dir, "multidoc2dial/multidoc2dial_dial_train.json"), 204 | }, 205 | ), 206 | datasets.SplitGenerator( 207 | name=datasets.Split.VALIDATION, 208 | gen_kwargs={ 209 | "filepath": os.path.join(data_dir, "multidoc2dial/multidoc2dial_dial_validation.json"), 210 | }, 211 | ), 212 | ] 213 | elif self.config.name == "document_domain": 214 | return [ 215 | datasets.SplitGenerator( 216 | name=datasets.Split.TRAIN, 217 | gen_kwargs={ 218 | "filepath": os.path.join(data_dir, "multidoc2dial/multidoc2dial_doc.json"), 219 | }, 220 | ) 221 | ] 222 | elif "multidoc2dial_" in self.config.name: 223 | domain = self.config.name.split("_")[-1] 224 | return [ 225 | datasets.SplitGenerator( 226 | name=datasets.Split.VALIDATION, 227 | gen_kwargs={ 228 | "filepath": os.path.join( 229 | data_dir, "multidoc2dial_domain", domain, "multidoc2dial_dial_validation.json" 230 | ), 231 | }, 232 | ), 233 | datasets.SplitGenerator( 234 | name=datasets.Split.TRAIN, 235 | gen_kwargs={ 236 | "filepath": os.path.join( 237 | data_dir, "multidoc2dial_domain", domain, "multidoc2dial_dial_train.json" 238 | ), 239 | }, 240 | ), 241 | datasets.SplitGenerator( 242 | name=datasets.Split.TEST, 243 | gen_kwargs={ 244 | "filepath": os.path.join( 245 | data_dir, "multidoc2dial_domain", domain, "multidoc2dial_dial_test.json" 246 | ), 247 | }, 248 | ), 249 | ] 250 | elif self.config.name == "multidoc2dial": 251 | return [ 252 | datasets.SplitGenerator( 253 | name=datasets.Split.VALIDATION, 254 | gen_kwargs={ 255 | "filepath": os.path.join(data_dir, "multidoc2dial/multidoc2dial_dial_validation.json"), 256 | }, 257 | ), 258 | datasets.SplitGenerator( 259 | name=datasets.Split.TRAIN, 260 | gen_kwargs={ 261 | "filepath": os.path.join(data_dir, "multidoc2dial/multidoc2dial_dial_train.json"), 262 | }, 263 | ), 264 | datasets.SplitGenerator( 265 | name=datasets.Split.TEST, 266 | gen_kwargs={ 267 | "filepath": os.path.join(data_dir, "multidoc2dial/multidoc2dial_dial_test.json"), 268 | }, 269 | ), 270 | ] 271 | 272 | def _load_doc_data_rc(self, filepath): 273 | # doc_filepath = os.path.join(os.path.dirname(filepath), "multidoc2dial_doc.json") 274 | doc_filepath = os.path.join(DATA_DIR, "multidoc2dial/multidoc2dial_doc.json") 275 | with open(doc_filepath, encoding="utf-8") as f: 276 | data = json.load(f)["doc_data"] 277 | return data 278 | 279 | def _get_answers_rc(self, references, spans, doc_text): 280 | """Obtain the grounding annotation for a given dialogue turn""" 281 | if not references: 282 | return [] 283 | start, end = -1, -1 
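# merge every referenced span into a single contiguous answer: track the
# minimum start_sp and maximum end_sp over the turn's references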
284 | ls_sp = [] 285 | for ele in references: 286 | id_sp = ele["id_sp"] 287 | start_sp, end_sp = spans[id_sp]["start_sp"], spans[id_sp]["end_sp"] 288 | if start == -1 or start > start_sp: 289 | start = start_sp 290 | if end < end_sp: 291 | end = end_sp 292 | ls_sp.append(doc_text[start_sp:end_sp]) 293 | answer = {"text": doc_text[start:end], "answer_start": start} 294 | return [answer] 295 | 296 | def _generate_examples(self, filepath): 297 | """This function returns the examples in the raw (text) form.""" 298 | if self.config.name == "dialogue_domain": 299 | logger.info("generating examples from = %s", filepath) 300 | with open(filepath, encoding="utf-8") as f: 301 | data = json.load(f) 302 | for domain in data["dial_data"]: 303 | for doc_id in data["dial_data"][domain]: 304 | for dialogue in data["dial_data"][domain][doc_id]: 305 | 306 | x = { 307 | "dial_id": dialogue["dial_id"], 308 | "domain": domain, 309 | "doc_id": doc_id, 310 | "turns": dialogue["turns"], 311 | } 312 | 313 | yield dialogue["dial_id"], x 314 | 315 | elif self.config.name == "document_domain": 316 | 317 | logger.info("generating examples from = %s", filepath) 318 | with open(filepath, encoding="utf-8") as f: 319 | data = json.load(f) 320 | for domain in data["doc_data"]: 321 | for doc_id in data["doc_data"][domain]: 322 | 323 | yield doc_id, { 324 | "domain": domain, 325 | "doc_id": doc_id, 326 | "title": data["doc_data"][domain][doc_id]["title"], 327 | "doc_text": data["doc_data"][domain][doc_id]["doc_text"], 328 | "spans": [ 329 | { 330 | "id_sp": data["doc_data"][domain][doc_id]["spans"][i]["id_sp"], 331 | "tag": data["doc_data"][domain][doc_id]["spans"][i]["tag"], 332 | "start_sp": data["doc_data"][domain][doc_id]["spans"][i]["start_sp"], 333 | "end_sp": data["doc_data"][domain][doc_id]["spans"][i]["end_sp"], 334 | "text_sp": data["doc_data"][domain][doc_id]["spans"][i]["text_sp"], 335 | "title": data["doc_data"][domain][doc_id]["spans"][i]["title"], 336 | "parent_titles": data["doc_data"][domain][doc_id]["spans"][i]["parent_titles"], 337 | "id_sec": data["doc_data"][domain][doc_id]["spans"][i]["id_sec"], 338 | "start_sec": data["doc_data"][domain][doc_id]["spans"][i]["start_sec"], 339 | "text_sec": data["doc_data"][domain][doc_id]["spans"][i]["text_sec"], 340 | "end_sec": data["doc_data"][domain][doc_id]["spans"][i]["end_sec"], 341 | } 342 | for i in data["doc_data"][domain][doc_id]["spans"] 343 | ], 344 | "doc_html_ts": data["doc_data"][domain][doc_id]["doc_html_ts"], 345 | "doc_html_raw": data["doc_data"][domain][doc_id]["doc_html_raw"], 346 | } 347 | 348 | elif "multidoc2dial" in self.config.name: 349 | logger.info("generating examples from = %s", filepath) 350 | doc_data = self._load_doc_data_rc(filepath) 351 | d_doc_data = {} 352 | for domain, d_doc in doc_data.items(): 353 | for doc_id, data in d_doc.items(): 354 | d_doc_data[doc_id] = data 355 | with open(filepath, encoding="utf-8") as f: 356 | dial_data = json.load(f)["dial_data"] 357 | for domain, dialogues in dial_data.items(): 358 | for dial in dialogues: 359 | all_prev_utterances = [] 360 | for idx, turn in enumerate(dial["turns"]): 361 | doc_id = turn["references"][0]["doc_id"] 362 | doc = d_doc_data[doc_id] 363 | utterance_line = turn["utterance"].replace("\n", " ").replace("\t", " ") 364 | all_prev_utterances.append("{}: {}".format(turn["role"], utterance_line)) 365 | if turn["role"] == "agent": 366 | continue 367 | if idx + 1 < len(dial["turns"]): 368 | if ( 369 | dial["turns"][idx + 1]["role"] == "agent" 370 | and dial["turns"][idx + 1]["da"] 
!= "respond_no_solution" 371 | ): 372 | turn_to_predict = dial["turns"][idx + 1] 373 | else: 374 | continue 375 | else: 376 | continue 377 | question_str = utterance_line + "[SEP]" + "||".join(reversed(all_prev_utterances[:-1])) 378 | id_ = "{}_{}".format(dial["dial_id"], turn["turn_id"]) 379 | qa = { 380 | "id": id_, 381 | "title": doc_id, 382 | "context": doc["doc_text"], 383 | "question": question_str, 384 | "da": turn["da"], 385 | "answers": self._get_answers_rc( 386 | turn_to_predict["references"], doc["spans"], doc["doc_text"] 387 | ), 388 | "utterance": turn_to_predict["utterance"], 389 | "domain": domain, 390 | } 391 | yield id_, qa -------------------------------------------------------------------------------- /scripts/model_converter.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration, AutoTokenizer 3 | 4 | # https://huggingface.co/facebook/rag-token-base 5 | 6 | 7 | def main(): 8 | 9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument( 12 | "--model_path", 13 | type=str, 14 | ) 15 | 16 | parser.add_argument( 17 | "--out_path", 18 | type=str, 19 | ) 20 | 21 | parser.add_argument( 22 | "--index_name", 23 | type=str, 24 | default="exact", 25 | ) 26 | 27 | args = parser.parse_args() 28 | 29 | model = RagTokenForGeneration.from_pretrained_question_encoder_generator(args.model_path, "facebook/bart-large") 30 | model.config.use_dummy_dataset = True 31 | model.config.index_name = args.index_name 32 | 33 | question_encoder_tokenizer = AutoTokenizer.from_pretrained(args.model_path) 34 | generator_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large") 35 | 36 | tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer) 37 | retriever = RagRetriever(model.config, question_encoder_tokenizer, generator_tokenizer) 38 | 39 | model.save_pretrained(args.out_path) 40 | tokenizer.save_pretrained(args.out_path) 41 | retriever.save_pretrained(args.out_path) 42 | 43 | 44 | if __name__ == "__main__": 45 | main() -------------------------------------------------------------------------------- /scripts/rag/callbacks_rag.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pytorch_lightning as pl 7 | import torch 8 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint 9 | from pytorch_lightning.utilities import rank_zero_only 10 | 11 | from utils_rag import save_json 12 | 13 | 14 | def count_trainable_parameters(model): 15 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 16 | params = sum([np.prod(p.size()) for p in model_parameters]) 17 | return params 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | def get_checkpoint_callback(output_dir, metric): 24 | """Saves the best model by validation EM score.""" 25 | if metric == "rouge2": 26 | exp = "{val_avg_rouge2:.4f}-{step_count}" 27 | elif metric == "bleu": 28 | exp = "{val_avg_bleu:.4f}-{step_count}" 29 | elif metric == "em": 30 | exp = "{val_avg_em:.4f}-{step_count}" 31 | else: 32 | raise NotImplementedError( 33 | f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function." 
34 | ) 35 | 36 | checkpoint_callback = ModelCheckpoint( 37 | filepath=os.path.join(output_dir, exp), 38 | monitor=f"val_{metric}", 39 | mode="max", 40 | save_top_k=1, 41 | period=1, # maybe save a checkpoint every time val is run, not just end of epoch. 42 | ) 43 | return checkpoint_callback 44 | 45 | 46 | def get_early_stopping_callback(metric, patience): 47 | return EarlyStopping( 48 | monitor=f"val_{metric}", # does this need avg? 49 | mode="min" if "loss" in metric else "max", 50 | patience=patience, 51 | verbose=True, 52 | ) 53 | 54 | 55 | class Seq2SeqLoggingCallback(pl.Callback): 56 | def on_batch_end(self, trainer, pl_module): 57 | lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)} 58 | pl_module.logger.log_metrics(lrs) 59 | 60 | @rank_zero_only 61 | def _write_logs( 62 | self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True 63 | ) -> None: 64 | logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****") 65 | metrics = trainer.callback_metrics 66 | trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]}) 67 | # Log results 68 | od = Path(pl_module.hparams.output_dir) 69 | if type_path == "test": 70 | results_file = od / "test_results.txt" 71 | generations_file = od / "test_generations.txt" 72 | else: 73 | # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json 74 | # If people want this it will be easy enough to add back. 75 | results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt" 76 | generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt" 77 | results_file.parent.mkdir(exist_ok=True) 78 | generations_file.parent.mkdir(exist_ok=True) 79 | with open(results_file, "a+") as writer: 80 | for key in sorted(metrics): 81 | if key in ["log", "progress_bar", "preds"]: 82 | continue 83 | val = metrics[key] 84 | if isinstance(val, torch.Tensor): 85 | val = val.item() 86 | msg = f"{key}: {val:.6f}\n" 87 | writer.write(msg) 88 | 89 | if not save_generations: 90 | return 91 | 92 | if "preds" in metrics: 93 | content = "\n".join(metrics["preds"]) 94 | generations_file.open("w+").write(content) 95 | 96 | @rank_zero_only 97 | def on_train_start(self, trainer, pl_module): 98 | try: 99 | npars = pl_module.model.model.num_parameters() 100 | except AttributeError: 101 | npars = pl_module.model.num_parameters() 102 | 103 | n_trainable_pars = count_trainable_parameters(pl_module) 104 | # mp stands for million parameters 105 | trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6}) 106 | 107 | @rank_zero_only 108 | def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 109 | save_json(pl_module.metrics, pl_module.metrics_save_path) 110 | return self._write_logs(trainer, pl_module, "test") 111 | 112 | @rank_zero_only 113 | def on_validation_end(self, trainer: pl.Trainer, pl_module): 114 | save_json(pl_module.metrics, pl_module.metrics_save_path) 115 | # Uncommenting this will save val generations 116 | # return self._write_logs(trainer, pl_module, "valid") 117 | -------------------------------------------------------------------------------- /scripts/rag/eval_rag.py: -------------------------------------------------------------------------------- 1 | """ Evaluation script for RAG models.""" 2 | 3 | import argparse 4 | import ast 5 | import logging 6 | import os 7 | import sys 8 | 9 | 
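# Example invocation for retrieval evaluation (illustrative only; every path below is hypothetical):
#
#   python scripts/rag/eval_rag.py \
#       --model_name_or_path checkpoints/rag-token-dialdoc \
#       --model_type rag_token_dialdoc \
#       --eval_mode retrieval \
#       --evaluation_set data/val.source \
#       --gold_data_path data/val.titles \
#       --gold_pid_path data/val.pids \
#       --n_docs 5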
import pandas as pd 10 | import numpy as np 11 | import torch 12 | from tqdm import tqdm 13 | from datasets import load_metric 14 | 15 | from transformers import BartForConditionalGeneration, RagRetriever, RagSequenceForGeneration, RagTokenForGeneration 16 | from transformers import logging as transformers_logging 17 | 18 | 19 | sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # isort:skip 20 | from utils_rag import exact_match_score, f1_score, load_bm25 # noqa: E402 # isort:skip 21 | from dialdoc.models.rag.modeling_rag_dialdoc import DialDocRagTokenForGeneration 22 | from dialdoc.models.rag.retrieval_rag_dialdoc import DialDocRagRetriever 23 | 24 | 25 | logger = logging.getLogger(__name__) 26 | logging.basicConfig(level=logging.INFO) 27 | 28 | transformers_logging.set_verbosity_info() 29 | 30 | # os.environ['KMP_DUPLICATE_LIB_OK']='True' 31 | 32 | 33 | def get_top_n_indices(bm25, query, n=5): 34 | query = query.lower().split() 35 | scores = bm25.get_scores(query) 36 | scores_i = [(i, score) for i, score in enumerate(scores)] 37 | sorted_indices = sorted(scores_i, key=lambda score: score[1], reverse=True) 38 | return sorted_indices[:n] 39 | 40 | 41 | def infer_model_type(model_name_or_path): 42 | if "token_dialdoc" in model_name_or_path: 43 | return "rag_token_dialdoc" 44 | if "token" in model_name_or_path: 45 | return "rag_token" 46 | if "sequence" in model_name_or_path: 47 | return "rag_sequence" 48 | if "bart" in model_name_or_path: 49 | return "bart" 50 | return None 51 | 52 | 53 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 54 | return max(metric_fn(prediction, gt) for gt in ground_truths) 55 | 56 | 57 | def get_scores(args, preds_path, gold_data_path): 58 | hypos = [line.strip() for line in open(preds_path, "r").readlines()] 59 | answers = [] 60 | 61 | if args.gold_data_mode == "qa": 62 | data = pd.read_csv(gold_data_path, sep="\t", header=None) 63 | for answer_list in data[1]: 64 | ground_truths = ast.literal_eval(answer_list) 65 | answers.append(ground_truths) 66 | else: 67 | references = [line.strip() for line in open(gold_data_path, "r").readlines()] 68 | answers = [[reference] for reference in references] 69 | 70 | f1 = em = total = 0 71 | for prediction, ground_truths in zip(hypos, answers): 72 | total += 1 73 | em += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths) 74 | f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths) 75 | 76 | em = 100.0 * em / total 77 | f1 = 100.0 * f1 / total 78 | 79 | metric = load_metric("sacrebleu") 80 | metric.add_batch(predictions=hypos, references=answers) 81 | sacrebleu = metric.compute()["score"] 82 | 83 | logger.info(f"F1: {f1: .2f}") 84 | logger.info(f"EM: {em: .2f}") 85 | logger.info(f"sacrebleu: {sacrebleu: .2f}") 86 | logger.info(f"all: {f1: .2f} & {em: .2f} & {sacrebleu: .2f} ") 87 | 88 | 89 | def get_precision_at_k(args, preds_path, gold_data_path): 90 | k = args.k 91 | hypos = [line.strip().split("####")[0] for line in open(preds_path, "r").readlines()] 92 | hypos_pid = [line.strip().split("####")[-1] for line in open(preds_path, "r").readlines()] 93 | references = [line.strip() for line in open(gold_data_path, "r").readlines()] 94 | pids = [line.strip().split("\t") for line in open(args.gold_pid_path, "r").readlines()] 95 | 96 | r_1 = r_5 = r_10 = em = total = 0 97 | for hypo, reference in zip(hypos, references): 98 | hypo_provenance = set(hypo.split("\t")[:k]) 99 | ref_provenance = set(reference.split("\t")) 100 | total += 1 101 | em += 
len(hypo_provenance & ref_provenance) / k 102 | r_1 += int(bool(set(hypo.split("\t")[:1]) & ref_provenance)) 103 | r_5 += int(bool(set(hypo.split("\t")[:5]) & ref_provenance)) 104 | r_10 += int(bool(set(hypo.split("\t")[:10]) & ref_provenance)) 105 | r_1 = 100.0 * r_1 / total 106 | r_5 = 100.0 * r_5 / total 107 | r_10 = 100.0 * r_10 / total 108 | # logger.info(f"Doc_Prec@{k}: {em: .2f}") 109 | # logger.info(f"Doc_Prec@{1}: {r_1: .2f}") 110 | logger.info(f"Doc_Prec@1: {r_1: .2f}") 111 | logger.info(f"Doc_Prec@5: {r_5: .2f}") 112 | logger.info(f"Doc_Prec@10: {r_10: .2f}") 113 | 114 | r_1_p = r_5_p = r_10_p = total = 0 115 | for hypo, reference in zip(hypos_pid, pids): 116 | hypo = hypo.split("\t") 117 | # hypo_provenance = set(hypo) 118 | ref_provenance = set(reference) 119 | total += 1 120 | # em += len([r for r in reference if r in hypo_provenance]) == len(reference) 121 | r_1_p += int(bool(set(hypo[:1]) & ref_provenance)) 122 | r_5_p += int(bool(set(hypo[:5]) & ref_provenance)) 123 | r_10_p += int(bool(set(hypo[:10]) & ref_provenance)) 124 | r_1_p = 100.0 * r_1_p / total 125 | r_5_p = 100.0 * r_5_p / total 126 | r_10_p = 100.0 * r_10_p / total 127 | logger.info(f"Pid_Prec@1: {r_1_p: .2f}") 128 | logger.info(f"Pid_Prec@5: {r_5_p: .2f}") 129 | logger.info(f"Pid_Prec@10: {r_10_p: .2f}") 130 | logger.info(f"all: {r_1: .2f} & {r_5: .2f} & {r_10: .2f} & {r_1_p: .2f} & {r_5_p: .2f} & {r_10_p: .2f} & ") 131 | 132 | 133 | def mean_pool(vector: torch.LongTensor): 134 | return vector.sum(axis=0) / vector.shape[0] 135 | 136 | 137 | def get_attn_mask(tokens_tensor: torch.LongTensor) -> torch.tensor: 138 | return tokens_tensor != 0 139 | 140 | 141 | def evaluate_batch_retrieval(args, rag_model, questions, domains=None): # old_q 142 | def strip_title(title): 143 | if title.startswith('"'): 144 | title = title[1:] 145 | if title.endswith('"'): 146 | title = title[:-1] 147 | return title 148 | 149 | # retriever_input_ids_0 = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus( 150 | # old_q, 151 | # return_tensors="pt", 152 | # padding=True, 153 | # truncation=True, 154 | # )["input_ids"].to(args.device) 155 | # question_enc_outputs = rag_model.rag.question_encoder(retriever_input_ids_0) 156 | # question_enc_pool_output = question_enc_outputs[0] 157 | 158 | inputs_dict = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus( 159 | questions, 160 | return_tensors="pt", 161 | padding=True, 162 | truncation=True, 163 | add_special_tokens=True, 164 | return_token_type_ids=True, 165 | ) 166 | 167 | retriever_input_ids = inputs_dict.input_ids.to(args.device) 168 | token_type_ids = inputs_dict.token_type_ids.to(args.device) 169 | attention_mask = inputs_dict.attention_mask.to(args.device) 170 | 171 | dpr_out = rag_model.rag.question_encoder( 172 | retriever_input_ids, attention_mask=attention_mask, output_hidden_states=True, return_dict=True 173 | ) 174 | combined_out = dpr_out.pooler_output 175 | 176 | ## Get mask for current turn input ids 177 | curr_turn_mask = torch.logical_xor(attention_mask, token_type_ids) 178 | current_turn_input_ids = retriever_input_ids * curr_turn_mask 179 | current_turn_only_out = rag_model.rag.question_encoder( 180 | current_turn_input_ids, attention_mask=curr_turn_mask.long(), return_dict=True 181 | ) 182 | current_turn_output = current_turn_only_out.pooler_output 183 | 184 | ## Split the dpr sequence output 185 | sequence_output = dpr_out.hidden_states[-1] 186 | attn_mask = get_attn_mask(retriever_input_ids) 187 | ## Split sequence output, and pool each 
sequence 188 | seq_out_0 = [] # last turn, if query; doc structure if passage 189 | seq_out_1 = [] # dial history, if query; passage text if passage 190 | dialog_lengths = [] 191 | for i in range(sequence_output.shape[0]): 192 | seq_out_masked = sequence_output[i, attn_mask[i], :] 193 | segment_masked = token_type_ids[i, attn_mask[i]] 194 | seq_out_masked_0 = seq_out_masked[segment_masked == 0, :] 195 | seq_out_masked_1 = seq_out_masked[segment_masked == 1, :] 196 | dialog_lengths.append((len(seq_out_masked_0), len(seq_out_masked_1))) 197 | ### perform pooling 198 | seq_out_0.append(mean_pool(seq_out_masked_0)) 199 | seq_out_1.append(mean_pool(seq_out_masked_1)) 200 | 201 | pooled_output_0 = torch.cat([seq.view(1, -1) for seq in seq_out_0], dim=0) 202 | pooled_output_1 = torch.cat([seq.view(1, -1) for seq in seq_out_1], dim=0) 203 | 204 | if args.scoring_func in ["reranking_original", "current_original"]: 205 | current_out = current_turn_output 206 | else: 207 | current_out = pooled_output_0 208 | 209 | if args.bm25: 210 | logger.info("Using BM25 for retrieval") 211 | doc_ids = [] 212 | doc_scores = [] 213 | for input_string in questions: 214 | input_string = " [SEP] ".join(input_string) 215 | sorted_indices = get_top_n_indices(rag_model.bm25, input_string, rag_model.config.n_docs) 216 | doc_ids.append([x[0] for x in sorted_indices]) 217 | doc_scores.append([x[-1] for x in sorted_indices]) 218 | all_docs = rag_model.retriever.index.get_doc_dicts(np.array(doc_ids)) 219 | else: 220 | if args.scoring_func != "original": 221 | current_input = current_out.cpu().detach().to(torch.float32).numpy() 222 | history_input = pooled_output_1.cpu().detach().to(torch.float32).numpy() 223 | else: 224 | current_input = combined_out.cpu().detach().to(torch.float32).numpy() 225 | history_input = combined_out.cpu().detach().to(torch.float32).numpy() 226 | result = rag_model.retriever( 227 | retriever_input_ids, 228 | combined_out.cpu().detach().to(torch.float32).numpy(), 229 | current_input, 230 | history_input, 231 | dialog_lengths=dialog_lengths, 232 | domain=domains, 233 | prefix=rag_model.rag.generator.config.prefix, 234 | n_docs=rag_model.config.n_docs, 235 | return_tensors="pt", 236 | ) 237 | all_docs = rag_model.retriever.index.get_doc_dicts(result.doc_ids) 238 | doc_ids = result.doc_ids 239 | provenance_strings = [] 240 | 241 | for i, docs in enumerate(all_docs): 242 | provenance = [strip_title(title) for title in docs["title"]] 243 | # provenance_strings.append("\t".join(provenance)) 244 | pids = "\t".join([str(int(e)) for e in doc_ids[i]]) 245 | provenance_strings.append("\t".join(provenance) + "####" + pids) 246 | return provenance_strings 247 | 248 | 249 | def evaluate_batch_e2e(args, rag_model, questions, domains=None): 250 | with torch.no_grad(): 251 | inputs_dict = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus( 252 | questions, 253 | return_tensors="pt", 254 | padding=True, 255 | truncation=True, 256 | add_special_tokens=True, 257 | return_token_type_ids=True, 258 | ) 259 | 260 | input_ids = inputs_dict.input_ids.to(args.device) 261 | token_type_ids = inputs_dict.token_type_ids.to(args.device) 262 | attention_mask = inputs_dict.attention_mask.to(args.device) 263 | outputs = rag_model.generate( # rag_model overwrites generate 264 | input_ids, 265 | domain=domains, 266 | attention_mask=attention_mask, 267 | token_type_ids=token_type_ids, 268 | num_beams=args.num_beams, 269 | min_length=args.min_length, 270 | max_length=args.max_length, 271 | early_stopping=False, 272 | 
num_return_sequences=1, 273 | bad_words_ids=[[0, 0]], # BART likes to repeat BOS tokens, don't allow it to generate more than one 274 | ) 275 | answers = rag_model.retriever.generator_tokenizer.batch_decode(outputs, skip_special_tokens=True) 276 | 277 | if args.print_predictions: 278 | for q, a in zip(questions, answers): 279 | logger.info("Q: {} - A: {}".format(q, a)) 280 | 281 | return answers 282 | 283 | 284 | def get_args(): 285 | parser = argparse.ArgumentParser() 286 | parser.add_argument( 287 | "--scoring_func", 288 | default="original", 289 | type=str, 290 | help="Scoring function to use: `original`, `linear`, `nonlinear`, `reranking`, `reranking_original` or `current_original`", 291 | ) 292 | parser.add_argument( 293 | "--bm25", 294 | type=str, 295 | default=None, 296 | help="Path to the tab-separated passages file used to build the BM25 index", 297 | ) 298 | parser.add_argument( 299 | "--mapping_file", 300 | type=str, 301 | default=None, 302 | help="Path to the passage id mapping file passed on to the retriever", 303 | ) 304 | parser.add_argument( 305 | "--model_type", 306 | choices=["rag_sequence", "rag_token", "rag_token_dialdoc", "bart"], 307 | type=str, 308 | help="RAG model type: rag_sequence, rag_token, rag_token_dialdoc or bart; if none specified, the type is inferred from the model_name_or_path", 309 | ) 310 | parser.add_argument( 311 | "--index_name", 312 | default=None, 313 | choices=["dialdoc", "custom", "exact", "compressed", "legacy"], 314 | type=str, 315 | help="RAG model retriever type", 316 | ) 317 | parser.add_argument( 318 | "--index_path", 319 | default=None, 320 | type=str, 321 | help="Path to the retrieval index", 322 | ) 323 | parser.add_argument( 324 | "--passages_path", 325 | default=None, 326 | type=str, 327 | help="Path to the knowledge data", 328 | ) 329 | parser.add_argument("--n_docs", default=5, type=int, help="Number of retrieved docs") 330 | parser.add_argument( 331 | "--model_name_or_path", 332 | default=None, 333 | type=str, 334 | required=True, 335 | help="Path to pretrained checkpoints or model identifier from huggingface.co/models", 336 | ) 337 | parser.add_argument( 338 | "--eval_mode", 339 | choices=["e2e", "retrieval"], 340 | default="retrieval", 341 | type=str, 342 | help="Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates precision@k.", 343 | ) 344 | parser.add_argument("--k", default=1, type=int, help="k for the precision@k calculation") 345 | parser.add_argument( 346 | "--evaluation_set", 347 | default=None, 348 | type=str, 349 | required=True, 350 | help="Path to a file containing evaluation samples", 351 | ) 352 | parser.add_argument( 353 | "--gold_data_path", 354 | default=None, 355 | type=str, 356 | required=True, 357 | help="Path to a tab-separated file with gold samples", 358 | ) 359 | parser.add_argument( 360 | "--gold_domain_path", 361 | default=None, 362 | type=str, 363 | required=False, 364 | help="Path to a tab-separated file with gold domains", 365 | ) 366 | parser.add_argument( 367 | "--gold_pid_path", 368 | default=None, 369 | type=str, 370 | required=True, 371 | help="Path to a tab-separated file with gold passage ids", 372 | ) 373 | parser.add_argument( 374 | "--gold_data_mode", 375 | default="qa", 376 | type=str, 377 | choices=["qa", "ans"], 378 | help="Format of the gold data file. " 379 | "qa - a single line in the following format: question [tab] answer_list. " 380 | "ans - a single line of the gold file contains the expected answer string", 381 | ) 382 | parser.add_argument( 383 | "--predictions_path", 384 | type=str, 385 | default="predictions.txt", 386 | help="Name of the predictions file, to be stored in the checkpoints
directory", 387 | ) 388 | parser.add_argument( 389 | "--eval_all_checkpoints", 390 | action="store_true", 391 | help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number", 392 | ) 393 | parser.add_argument( 394 | "--eval_batch_size", 395 | default=8, 396 | type=int, 397 | help="Batch size per GPU/CPU for evaluation.", 398 | ) 399 | parser.add_argument( 400 | "--recalculate", 401 | help="Recalculate predictions even if the prediction file exists", 402 | action="store_true", 403 | ) 404 | parser.add_argument( 405 | "--num_beams", 406 | default=4, 407 | type=int, 408 | help="Number of beams to be used when generating answers", 409 | ) 410 | parser.add_argument("--min_length", default=1, type=int, help="Min length of the generated answers") 411 | parser.add_argument("--max_length", default=50, type=int, help="Max length of the generated answers") 412 | 413 | parser.add_argument( 414 | "--print_predictions", 415 | action="store_true", 416 | help="If True, prints predictions while evaluating.", 417 | ) 418 | parser.add_argument( 419 | "--print_docs", 420 | action="store_true", 421 | help="If True, prints docs retried while generating.", 422 | ) 423 | args = parser.parse_args() 424 | args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 425 | return args 426 | 427 | 428 | def main(args): 429 | model_kwargs = {} 430 | if args.model_type is None: 431 | args.model_type = infer_model_type(args.model_name_or_path) 432 | assert args.model_type is not None 433 | if args.model_type.startswith("rag"): 434 | if args.model_type == "rag_token": 435 | model_class = RagTokenForGeneration 436 | elif args.model_type == "rag_token_dialdoc": 437 | model_class = DialDocRagTokenForGeneration 438 | else: 439 | model_class = RagSequenceForGeneration 440 | model_kwargs["n_docs"] = args.n_docs 441 | if args.index_name is not None: 442 | model_kwargs["index_name"] = args.index_name 443 | if args.index_path is not None: 444 | model_kwargs["index_path"] = args.index_path 445 | if args.passages_path is not None: 446 | model_kwargs["passages_path"] = args.passages_path 447 | if args.mapping_file is not None: 448 | model_kwargs["mapping_file"] = args.mapping_file 449 | else: 450 | model_class = BartForConditionalGeneration 451 | 452 | bm25 = None 453 | if args.bm25: 454 | bm25 = load_bm25(args.bm25) 455 | 456 | checkpoints = ( 457 | [f.path for f in os.scandir(args.model_name_or_path) if f.is_dir()] 458 | if args.eval_all_checkpoints 459 | else [args.model_name_or_path] 460 | ) 461 | 462 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 463 | 464 | score_fn = get_scores if args.eval_mode == "e2e" else get_precision_at_k 465 | evaluate_batch_fn = evaluate_batch_e2e if args.eval_mode == "e2e" else evaluate_batch_retrieval 466 | 467 | for checkpoint in checkpoints: 468 | if os.path.exists(args.predictions_path) and (not args.recalculate): 469 | logger.info("Calculating metrics based on an existing predictions file: {}".format(args.predictions_path)) 470 | score_fn(args, args.predictions_path, args.gold_data_path) 471 | continue 472 | 473 | logger.info("***** Running evaluation for {} *****".format(checkpoint)) 474 | logger.info(" Batch size = %d", args.eval_batch_size) 475 | logger.info(" Predictions will be stored under {}".format(args.predictions_path)) 476 | logger.info(" Using scoring function {}".format(args.scoring_func)) 477 | 478 | if args.model_type.startswith("rag"): 479 | if "dialdoc" in args.model_type: 480 | retriever = 
DialDocRagRetriever.from_pretrained(checkpoint, **model_kwargs) 481 | retriever.config.scoring_func = args.scoring_func 482 | retriever.config.n_docs = args.n_docs 483 | retriever.config.bm25 = args.bm25 484 | retriever.config.mapping_file = args.mapping_file 485 | model = model_class.from_pretrained(checkpoint, retriever=retriever, **model_kwargs) 486 | if bm25: 487 | model.bm25 = bm25 488 | model.config.scoring_func = args.scoring_func 489 | model.config.n_docs = args.n_docs 490 | model.config.bm25 = args.bm25 491 | model.config.mapping_file = args.mapping_file 492 | 493 | else: 494 | retriever = RagRetriever.from_pretrained(checkpoint, **model_kwargs) 495 | model = model_class.from_pretrained(checkpoint, retriever=retriever, **model_kwargs) 496 | model.retriever.init_retrieval() 497 | else: 498 | model = model_class.from_pretrained(checkpoint, **model_kwargs) 499 | model.to(args.device) 500 | 501 | with open(args.evaluation_set, "r") as eval_file, open(args.predictions_path, "w") as preds_file: 502 | questions = [] 503 | if args.gold_domain_path: 504 | dom_file = open(args.gold_domain_path, "r") 505 | domains = [] 506 | for line1, line2 in tqdm(zip(eval_file, dom_file)): 507 | question = line1.strip() 508 | questions.append(question) 509 | domain = line2.strip() 510 | domains.append(domain) 511 | if len(questions) == args.eval_batch_size: 512 | new_questions = list(tuple(question.split("[SEP]")) for question in questions) 513 | answers = evaluate_batch_fn(args, model, new_questions, domains) 514 | preds_file.write("\n".join(answers) + "\n") 515 | preds_file.flush() 516 | questions, domains = [], []  # reset domains together with questions so each batch stays aligned 517 | if len(questions) > 0: 518 | new_questions = list(tuple(question.split("[SEP]")) for question in questions) 519 | answers = evaluate_batch_fn(args, model, new_questions, domains) 520 | preds_file.write("\n".join(answers)) 521 | preds_file.flush() 522 | else: 523 | for line in tqdm(eval_file): 524 | question = line.strip() 525 | questions.append(question) 526 | if len(questions) == args.eval_batch_size: 527 | new_questions = list(tuple(question.split("[SEP]")) for question in questions) 528 | answers = evaluate_batch_fn(args, model, new_questions) 529 | preds_file.write("\n".join(answers) + "\n") 530 | preds_file.flush() 531 | questions = [] 532 | if len(questions) > 0: 533 | new_questions = list(tuple(question.split("[SEP]")) for question in questions) 534 | answers = evaluate_batch_fn(args, model, new_questions) 535 | preds_file.write("\n".join(answers)) 536 | preds_file.flush() 537 | 538 | score_fn(args, args.predictions_path, args.gold_data_path) 539 | 540 | 541 | if __name__ == "__main__": 542 | args = get_args() 543 | main(args) -------------------------------------------------------------------------------- /scripts/rag/lightning_base.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | from pathlib import Path 5 | from typing import Any, Dict 6 | 7 | import pytorch_lightning as pl 8 | from pytorch_lightning.utilities import rank_zero_info 9 | 10 | from transformers import ( 11 | AdamW, 12 | AutoConfig, 13 | AutoModel, 14 | AutoModelForPreTraining, 15 | AutoModelForQuestionAnswering, 16 | AutoModelForSeq2SeqLM, 17 | AutoModelForSequenceClassification, 18 | AutoModelForTokenClassification, 19 | AutoModelWithLMHead, 20 | AutoTokenizer, 21 | PretrainedConfig, 22 | PreTrainedTokenizer, 23 | ) 24 | from transformers.optimization import ( 25 | Adafactor, 26 | get_cosine_schedule_with_warmup, 27 |
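    # (these warmup scheduler factories back the --lr_scheduler choices collected in arg_to_scheduler below)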
get_cosine_with_hard_restarts_schedule_with_warmup, 28 | get_linear_schedule_with_warmup, 29 | get_polynomial_decay_schedule_with_warmup, 30 | ) 31 | # from transformers.utils.versions import require_version_examples 32 | 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | # require_version_examples("pytorch_lightning>=1.0.4") 37 | 38 | MODEL_MODES = { 39 | "base": AutoModel, 40 | "sequence-classification": AutoModelForSequenceClassification, 41 | "question-answering": AutoModelForQuestionAnswering, 42 | "pretraining": AutoModelForPreTraining, 43 | "token-classification": AutoModelForTokenClassification, 44 | "language-modeling": AutoModelWithLMHead, 45 | "summarization": AutoModelForSeq2SeqLM, 46 | "translation": AutoModelForSeq2SeqLM, 47 | } 48 | 49 | 50 | # update this and the import above to support new schedulers from transformers.optimization 51 | arg_to_scheduler = { 52 | "linear": get_linear_schedule_with_warmup, 53 | "cosine": get_cosine_schedule_with_warmup, 54 | "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup, 55 | "polynomial": get_polynomial_decay_schedule_with_warmup, 56 | # '': get_constant_schedule, # not supported for now 57 | # '': get_constant_schedule_with_warmup, # not supported for now 58 | } 59 | arg_to_scheduler_choices = sorted(arg_to_scheduler.keys()) 60 | arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}" 61 | 62 | 63 | class BaseTransformer(pl.LightningModule): 64 | def __init__( 65 | self, 66 | hparams: argparse.Namespace, 67 | num_labels=None, 68 | mode="base", 69 | config=None, 70 | tokenizer=None, 71 | model=None, 72 | **config_kwargs 73 | ): 74 | """Initialize a model, tokenizer and config.""" 75 | super().__init__() 76 | # TODO: move to self.save_hyperparameters() 77 | # self.save_hyperparameters() 78 | # can also expand arguments into trainer signature for easier reading 79 | 80 | self.save_hyperparameters(hparams) 81 | self.step_count = 0 82 | self.output_dir = Path(self.hparams.output_dir) 83 | cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None 84 | if config is None: 85 | self.config = AutoConfig.from_pretrained( 86 | self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, 87 | **({"num_labels": num_labels} if num_labels is not None else {}), 88 | cache_dir=cache_dir, 89 | **config_kwargs, 90 | ) 91 | else: 92 | self.config: PretrainedConfig = config 93 | 94 | extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout") 95 | for p in extra_model_params: 96 | if getattr(self.hparams, p, None): 97 | assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute" 98 | setattr(self.config, p, getattr(self.hparams, p)) 99 | 100 | if tokenizer is None: 101 | self.tokenizer = AutoTokenizer.from_pretrained( 102 | self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path, 103 | cache_dir=cache_dir, 104 | ) 105 | else: 106 | self.tokenizer: PreTrainedTokenizer = tokenizer 107 | self.model_type = MODEL_MODES[mode] 108 | if model is None: 109 | self.model = self.model_type.from_pretrained( 110 | self.hparams.model_name_or_path, 111 | from_tf=bool(".ckpt" in self.hparams.model_name_or_path), 112 | config=self.config, 113 | cache_dir=cache_dir, 114 | ) 115 | else: 116 | self.model = model 117 | 118 | def load_hf_checkpoint(self, *args, **kwargs): 119 | self.model = self.model_type.from_pretrained(*args, **kwargs) 120 | 121 | def get_lr_scheduler(self): 122 | 
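        # Look up the scheduler factory selected via --lr_scheduler and return it in Lightning's dict format,
        # so the learning rate is updated once per optimizer step.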
get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler] 123 | scheduler = get_schedule_func( 124 | self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps() 125 | ) 126 | scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} 127 | return scheduler 128 | 129 | def configure_optimizers(self): 130 | """Prepare optimizer and schedule (linear warmup and decay)""" 131 | model = self.model 132 | no_decay = ["bias", "LayerNorm.weight"] 133 | optimizer_grouped_parameters = [ 134 | { 135 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 136 | "weight_decay": self.hparams.weight_decay, 137 | }, 138 | { 139 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 140 | "weight_decay": 0.0, 141 | }, 142 | ] 143 | if self.hparams.adafactor: 144 | optimizer = Adafactor( 145 | optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False 146 | ) 147 | 148 | else: 149 | optimizer = AdamW( 150 | optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon 151 | ) 152 | self.opt = optimizer 153 | 154 | scheduler = self.get_lr_scheduler() 155 | 156 | return [optimizer], [scheduler] 157 | 158 | def test_step(self, batch, batch_nb): 159 | return self.validation_step(batch, batch_nb) 160 | 161 | def test_epoch_end(self, outputs): 162 | return self.validation_end(outputs) 163 | 164 | def total_steps(self) -> int: 165 | """The number of total training steps that will be run. Used for lr scheduler purposes.""" 166 | num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores 167 | effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices 168 | return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs 169 | 170 | def setup(self, mode): 171 | if mode == "test": 172 | self.dataset_size = len(self.test_dataloader().dataset) 173 | else: 174 | self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True) 175 | self.dataset_size = len(self.train_dataloader().dataset) 176 | 177 | def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False): 178 | raise NotImplementedError("You must implement this for your task") 179 | 180 | def train_dataloader(self): 181 | return self.train_loader 182 | 183 | def val_dataloader(self): 184 | return self.get_dataloader("dev", self.hparams.eval_batch_size, shuffle=False) 185 | 186 | def test_dataloader(self): 187 | return self.get_dataloader("test", self.hparams.eval_batch_size, shuffle=False) 188 | 189 | def _feature_file(self, mode): 190 | return os.path.join( 191 | self.hparams.data_dir, 192 | "cached_{}_{}_{}".format( 193 | mode, 194 | list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(), 195 | str(self.hparams.max_seq_length), 196 | ), 197 | ) 198 | 199 | @pl.utilities.rank_zero_only 200 | def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: 201 | save_path = self.output_dir.joinpath("best_tfmr") 202 | self.model.config.save_step = self.step_count 203 | self.model.save_pretrained(save_path) 204 | self.tokenizer.save_pretrained(save_path) 205 | 206 | @staticmethod 207 | def add_model_specific_args(parser, root_dir): 208 | parser.add_argument( 209 | "--model_name_or_path", 210 | default=None, 211 | type=str, 212 | required=True, 213 | help="Path to pretrained model or model identifier from huggingface.co/models", 214 | ) 215 | 
parser.add_argument( 216 | "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name" 217 | ) 218 | parser.add_argument( 219 | "--tokenizer_name", 220 | default=None, 221 | type=str, 222 | help="Pretrained tokenizer name or path if not the same as model_name", 223 | ) 224 | parser.add_argument( 225 | "--cache_dir", 226 | default="", 227 | type=str, 228 | help="Where do you want to store the pre-trained models downloaded from huggingface.co", 229 | ) 230 | parser.add_argument( 231 | "--encoder_layerdrop", 232 | type=float, 233 | help="Encoder layer dropout probability (Optional). Goes into model.config", 234 | ) 235 | parser.add_argument( 236 | "--decoder_layerdrop", 237 | type=float, 238 | help="Decoder layer dropout probability (Optional). Goes into model.config", 239 | ) 240 | parser.add_argument( 241 | "--dropout", 242 | type=float, 243 | help="Dropout probability (Optional). Goes into model.config", 244 | ) 245 | parser.add_argument( 246 | "--attention_dropout", 247 | type=float, 248 | help="Attention dropout probability (Optional). Goes into model.config", 249 | ) 250 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 251 | parser.add_argument( 252 | "--lr_scheduler", 253 | default="linear", 254 | choices=arg_to_scheduler_choices, 255 | metavar=arg_to_scheduler_metavar, 256 | type=str, 257 | help="Learning rate scheduler", 258 | ) 259 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 260 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 261 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 262 | parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader") 263 | parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int) 264 | parser.add_argument("--train_batch_size", default=32, type=int) 265 | parser.add_argument("--eval_batch_size", default=32, type=int) 266 | parser.add_argument("--adafactor", action="store_true") 267 | 268 | 269 | class LoggingCallback(pl.Callback): 270 | def on_batch_end(self, trainer, pl_module): 271 | lr_scheduler = trainer.lr_schedulers[0]["scheduler"] 272 | lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())} 273 | pl_module.logger.log_metrics(lrs) 274 | 275 | def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 276 | rank_zero_info("***** Validation results *****") 277 | metrics = trainer.callback_metrics 278 | # Log results 279 | for key in sorted(metrics): 280 | if key not in ["log", "progress_bar"]: 281 | rank_zero_info("{} = {}\n".format(key, str(metrics[key]))) 282 | 283 | def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 284 | rank_zero_info("***** Test results *****") 285 | metrics = trainer.callback_metrics 286 | # Log and save results to file 287 | output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt") 288 | with open(output_test_results_file, "w") as writer: 289 | for key in sorted(metrics): 290 | if key not in ["log", "progress_bar"]: 291 | rank_zero_info("{} = {}\n".format(key, str(metrics[key]))) 292 | writer.write("{} = {}\n".format(key, str(metrics[key]))) 293 | 294 | 295 | def add_generic_args(parser, root_dir) -> None: 296 | # To allow all pl args uncomment the following line 297 | # parser = 
pl.Trainer.add_argparse_args(parser) 298 | parser.add_argument( 299 | "--output_dir", 300 | default=None, 301 | type=str, 302 | required=True, 303 | help="The output directory where the model predictions and checkpoints will be written.", 304 | ) 305 | parser.add_argument( 306 | "--fp16", 307 | action="store_true", 308 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", 309 | ) 310 | 311 | parser.add_argument( 312 | "--fp16_opt_level", 313 | type=str, 314 | default="O2", 315 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 316 | "See details at https://nvidia.github.io/apex/amp.html", 317 | ) 318 | parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int) 319 | parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm") 320 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 321 | parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.") 322 | parser.add_argument( 323 | "--gradient_accumulation_steps", 324 | dest="accumulate_grad_batches", 325 | type=int, 326 | default=1, 327 | help="Number of updates steps to accumulate before performing a backward/update pass.", 328 | ) 329 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 330 | parser.add_argument( 331 | "--data_dir", 332 | default=None, 333 | type=str, 334 | required=True, 335 | help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.", 336 | ) 337 | 338 | 339 | def generic_train( 340 | model: BaseTransformer, 341 | args: argparse.Namespace, 342 | early_stopping_callback=None, 343 | logger=True, # can pass WandbLogger() here 344 | extra_callbacks=[], 345 | checkpoint_callback=None, 346 | logging_callback=None, 347 | **extra_train_kwargs 348 | ): 349 | pl.seed_everything(args.seed) 350 | 351 | # init model 352 | odir = Path(model.hparams.output_dir) 353 | odir.mkdir(exist_ok=True) 354 | 355 | # add custom checkpoints 356 | if checkpoint_callback is None: 357 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 358 | filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1 359 | ) 360 | if early_stopping_callback: 361 | extra_callbacks.append(early_stopping_callback) 362 | if logging_callback is None: 363 | logging_callback = LoggingCallback() 364 | 365 | train_params = {} 366 | 367 | # TODO: remove with PyTorch 1.6 since pl uses native amp 368 | if args.fp16: 369 | train_params["precision"] = 16 370 | train_params["amp_level"] = args.fp16_opt_level 371 | 372 | if args.gpus > 1: 373 | train_params["distributed_backend"] = "ddp" 374 | 375 | train_params["accumulate_grad_batches"] = args.accumulate_grad_batches 376 | train_params["accelerator"] = extra_train_kwargs.get("accelerator", None) 377 | train_params["profiler"] = extra_train_kwargs.get("profiler", None) 378 | 379 | trainer = pl.Trainer.from_argparse_args( 380 | args, 381 | weights_summary=None, 382 | callbacks=[logging_callback] + extra_callbacks, 383 | logger=logger, 384 | checkpoint_callback=checkpoint_callback, 385 | **train_params, 386 | ) 387 | 388 | if args.do_train: 389 | trainer.fit(model) 390 | 391 | return trainer 392 | -------------------------------------------------------------------------------- /scripts/rag/use_own_knowledge_dataset.py: -------------------------------------------------------------------------------- 1 | import 
logging 2 | import os 3 | from dataclasses import dataclass, field 4 | from functools import partial 5 | from pathlib import Path 6 | from tempfile import TemporaryDirectory 7 | from typing import List, Optional 8 | 9 | import torch 10 | from datasets import Features, Sequence, Value, load_dataset 11 | 12 | import faiss 13 | from transformers import ( 14 | DPRContextEncoder, 15 | DPRContextEncoderTokenizerFast, 16 | HfArgumentParser, 17 | RagRetriever, 18 | RagSequenceForGeneration, 19 | RagTokenizer, 20 | ) 21 | 22 | 23 | logger = logging.getLogger(__name__) 24 | torch.set_grad_enabled(False) 25 | device = "cuda" if torch.cuda.is_available() else "cpu" 26 | 27 | 28 | def split_text(text: str, n=100, character=" ") -> List[str]: 29 | """Split the text every ``n``-th occurrence of ``character``""" 30 | text = text.split(character) 31 | return [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)] 32 | 33 | 34 | def split_text_dd(text: str, n=100, character=" ") -> List[str]: 35 | """Split the text into passages on the ``####`` delimiter (``n`` and ``character`` are unused)""" 36 | passages = [] 37 | for passage in split_text(text, 1, "####"): 38 | passages.append(passage) 39 | return passages 40 | 41 | 42 | def split_documents(documents: dict) -> dict: 43 | """Split documents into passages""" 44 | titles, texts = [], [] 45 | for title, text in zip(documents["title"], documents["text"]): 46 | if text is not None: 47 | for passage in split_text_dd(text): 48 | titles.append(title if title is not None else "") 49 | texts.append(passage) 50 | print("###### num passage", len(texts)) 51 | return {"title": titles, "text": texts} 52 | 53 | 54 | def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict: 55 | """Compute the DPR embeddings of document passages""" 56 | input_ids = ctx_tokenizer( 57 | documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt" 58 | )["input_ids"] 59 | embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output 60 | return {"embeddings": embeddings.detach().cpu().numpy()} 61 | 62 | 63 | def main( 64 | rag_example_args: "RagExampleArguments", 65 | processing_args: "ProcessingArguments", 66 | index_args: "IndexHnswArguments", 67 | ): 68 | 69 | ###################################### 70 | logger.info("Step 1 - Create the dataset") 71 | ###################################### 72 | 73 | # The dataset needed for RAG must have three columns: 74 | # - title (string): title of the document 75 | # - text (string): text of a passage of the document 76 | # - embeddings (array of dimension d): DPR representation of the passage 77 | 78 | # Let's say you have documents in tab-separated csv files with columns "title" and "text" 79 | assert os.path.isfile(rag_example_args.csv_path), "Please provide a valid path to a csv file" 80 | 81 | # You can load a Dataset object this way 82 | dataset = load_dataset( 83 | "csv", data_files=[rag_example_args.csv_path], split="train", delimiter="\t", column_names=["title", "text"] 84 | ) 85 | 86 | # More info about loading csv files in the documentation: https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files 87 | 88 | # Then split the documents into passages on the "####" delimiter 89 | dataset = dataset.map(split_documents, batched=True, num_proc=processing_args.num_proc) 90 | 91 | # And compute the embeddings 92 | ctx_encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device) 93 |
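    # dataset.map below batches (title, text) pairs through embed() above, running the DPR context encoder
    # on GPU when available.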
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name) 94 | new_features = Features( 95 | {"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))} 96 | ) # optional, save as float32 instead of float64 to save space 97 | dataset = dataset.map( 98 | partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer), 99 | batched=True, 100 | batch_size=processing_args.batch_size, 101 | features=new_features, 102 | ) 103 | 104 | # And finally save your dataset 105 | Path(rag_example_args.output_dir).mkdir(exist_ok=True) 106 | passages_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset") 107 | Path(passages_path).mkdir(exist_ok=True) 108 | dataset.save_to_disk(passages_path) 109 | # from datasets import load_from_disk 110 | # dataset = load_from_disk(passages_path) # to reload the dataset 111 | 112 | ###################################### 113 | logger.info("Step 2 - Index the dataset") 114 | ###################################### 115 | 116 | # Use a flat inner-product Faiss index (exact search); the approximate HNSW variant is kept commented out below 117 | # index = faiss.IndexHNSWFlat(index_args.d, index_args.m, faiss.METRIC_INNER_PRODUCT) 118 | index = faiss.IndexFlatIP(index_args.d) 119 | dataset.add_faiss_index("embeddings", custom_index=index) 120 | 121 | # And save the index 122 | index_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset_index.faiss") 123 | dataset.get_index("embeddings").save(index_path) 124 | # dataset.load_faiss_index("embeddings", index_path) # to reload the index 125 | 126 | ###################################### 127 | logger.info("Step 3 - Load RAG") 128 | ###################################### 129 | 130 | # Easy way to load the model 131 | retriever = RagRetriever.from_pretrained( 132 | rag_example_args.rag_model_name, index_name="custom", indexed_dataset=dataset 133 | ) 134 | model = RagSequenceForGeneration.from_pretrained(rag_example_args.rag_model_name, retriever=retriever) 135 | tokenizer = RagTokenizer.from_pretrained(rag_example_args.rag_model_name) 136 | 137 | # For distributed fine-tuning you'll need to provide the paths instead, as the dataset and the index are loaded separately. 138 | # retriever = RagRetriever.from_pretrained(rag_model_name, index_name="custom", passages_path=passages_path, index_path=index_path) 139 | 140 | ###################################### 141 | logger.info("Step 4 - Have fun") 142 | ###################################### 143 | 144 | question = rag_example_args.question or "What does Moses' rod turn into ?" 145 | input_ids = tokenizer.question_encoder(question, return_tensors="pt")["input_ids"] 146 | generated = model.generate(input_ids) 147 | generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)[0] 148 | logger.info("Q: " + question) 149 | logger.info("A: " + generated_string) 150 | 151 | 152 | @dataclass 153 | class RagExampleArguments: 154 | csv_path: str = field( 155 | default=str(Path(__file__).parent / "test_data" / "my_knowledge_dataset.csv"), 156 | metadata={"help": "Path to a tab-separated csv file with columns 'title' and 'text'"}, 157 | ) 158 | question: Optional[str] = field( 159 | default=None, 160 | metadata={"help": "Question that is passed as input to RAG. Default is 'What does Moses' rod turn into ?'."}, 161 | ) 162 | rag_model_name: str = field( 163 | default="facebook/rag-token-nq", 164 | metadata={"help": "The RAG model to use.
Either 'facebook/rag-sequence-nq' or 'facebook/rag-token-nq'"}, 165 | ) 166 | dpr_ctx_encoder_model_name: str = field( 167 | default="facebook/dpr-ctx_encoder-multiset-base", 168 | metadata={ 169 | "help": "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or 'facebook/dpr-ctx_encoder-multiset-base'" 170 | }, 171 | ) 172 | output_dir: Optional[str] = field( 173 | default=None, 174 | metadata={"help": "Path to a directory where the dataset passages and the index will be saved"}, 175 | ) 176 | 177 | 178 | @dataclass 179 | class ProcessingArguments: 180 | num_proc: Optional[int] = field( 181 | default=None, 182 | metadata={ 183 | "help": "The number of processes to use to split the documents into passages. Default is single process." 184 | }, 185 | ) 186 | batch_size: int = field( 187 | default=16, 188 | metadata={ 189 | "help": "The batch size to use when computing the passages embeddings using the DPR context encoder." 190 | }, 191 | ) 192 | 193 | 194 | @dataclass 195 | class IndexHnswArguments: 196 | d: int = field( 197 | default=768, 198 | metadata={"help": "The dimension of the embeddings to pass to the HNSW Faiss index."}, 199 | ) 200 | m: int = field( 201 | default=128, 202 | metadata={ 203 | "help": "The number of bi-directional links created for every new element during the HNSW index construction." 204 | }, 205 | ) 206 | 207 | 208 | if __name__ == "__main__": 209 | logging.basicConfig(level=logging.WARNING) 210 | logger.setLevel(logging.INFO) 211 | 212 | parser = HfArgumentParser((RagExampleArguments, ProcessingArguments, IndexHnswArguments)) 213 | rag_example_args, processing_args, index_args = parser.parse_args_into_dataclasses() 214 | with TemporaryDirectory() as tmp_dir: 215 | rag_example_args.output_dir = rag_example_args.output_dir or tmp_dir 216 | main(rag_example_args, processing_args, index_args) -------------------------------------------------------------------------------- /scripts/rag/utils_rag.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | import linecache 4 | import os 5 | import pickle 6 | import re 7 | import socket 8 | import string 9 | from collections import Counter 10 | from logging import getLogger 11 | from pathlib import Path 12 | from typing import Callable, Dict, Iterable, List 13 | from sacrebleu import corpus_bleu 14 | 15 | import git 16 | import torch 17 | from torch.utils.data import Dataset 18 | 19 | from rank_bm25 import BM25Okapi 20 | from datasets import load_dataset 21 | 22 | from transformers import BartTokenizer, RagTokenizer, T5Tokenizer 23 | 24 | 25 | def load_bm25(in_path): 26 | dataset = load_dataset("csv", data_files=[in_path], split="train", delimiter="\t", column_names=["title", "text"]) 27 | passages = [] 28 | for ex in dataset: 29 | passages.extend(ex["text"].split("####")) 30 | passages_tokenized = [passage.strip().lower().split() for passage in passages] 31 | bm25 = BM25Okapi(passages_tokenized) 32 | return bm25 33 | 34 | 35 | def get_top_n_indices(bm25, query, n=5): 36 | query = query.lower().split() 37 | scores = bm25.get_scores(query) 38 | scores_i = [(i, score) for i, score in enumerate(scores)] 39 | sorted_indices = sorted(scores_i, key=lambda score: score[1], reverse=True) 40 | return sorted_indices[:n] # (index, score) pairs; the BM25 branch in eval_rag.py unpacks both fields 41 | 42 | 43 | def load_bm25_results(in_path): 44 | d_query_pid = {} 45 | total = 0 46 | for split in ["train", "val", "test"]: 47 | queries, bm_rslt = [], [] 48 | with open(os.path.join(in_path,
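# (pairs each query in {split}.source with the tab-separated BM25 passage ids in {split}.bm25)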
f"{split}.source")) as f: 49 | for line in f: 50 | queries.append(line.strip()) 51 | with open(os.path.join(in_path, f"{split}.bm25")) as f: 52 | for line in f: 53 | bm_rslt.append([int(ele) for ele in line.strip().split("\t")]) 54 | total += len(queries) 55 | d_query_pid.update(dict(zip(queries, bm_rslt))) 56 | return d_query_pid 57 | 58 | 59 | def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=True, return_tensors="pt"): 60 | extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) and not line.startswith(" ") else {} 61 | tokenizer.padding_side = padding_side 62 | return tokenizer( 63 | [line], 64 | max_length=max_length, 65 | padding="max_length" if pad_to_max_length else None, 66 | truncation=True, 67 | return_tensors=return_tensors, 68 | add_special_tokens=True, 69 | **extra_kw, 70 | ) 71 | 72 | 73 | def encode_line2(tokenizer, line, max_length, padding_side, pad_to_max_length=True, return_tensors="pt"): 74 | extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) and not line.startswith(" ") else {} 75 | tokenizer.padding_side = padding_side 76 | line = tuple(line.split("[SEP]")) 77 | return tokenizer( 78 | [line], 79 | max_length=max_length, 80 | padding="max_length" if pad_to_max_length else None, 81 | truncation=True, 82 | return_tensors=return_tensors, 83 | add_special_tokens=True, 84 | **extra_kw, 85 | ) 86 | 87 | 88 | def trim_batch( 89 | input_ids, 90 | pad_token_id, 91 | attention_mask=None, 92 | ): 93 | """Remove columns that are populated exclusively by pad_token_id""" 94 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) 95 | if attention_mask is None: 96 | return input_ids[:, keep_column_mask] 97 | else: 98 | return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) 99 | 100 | 101 | class Seq2SeqDataset(Dataset): 102 | def __init__( 103 | self, 104 | tokenizer, 105 | data_dir, 106 | max_source_length, 107 | max_target_length, 108 | type_path="train", 109 | n_obs=None, 110 | src_lang=None, 111 | tgt_lang=None, 112 | prefix="", 113 | ): 114 | super().__init__() 115 | self.src_file = Path(data_dir).joinpath(type_path + ".source") 116 | self.tgt_file = Path(data_dir).joinpath(type_path + ".target") 117 | self.domain_file = Path(data_dir).joinpath(type_path + ".domain") 118 | if not os.path.exists(self.domain_file): 119 | self.domain_file = None 120 | self.src_lens = self.get_char_lens(self.src_file) 121 | self.max_source_length = max_source_length 122 | self.max_target_length = max_target_length 123 | assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" 124 | self.tokenizer = tokenizer 125 | self.prefix = prefix 126 | if n_obs is not None: 127 | self.src_lens = self.src_lens[:n_obs] 128 | self.src_lang = src_lang 129 | self.tgt_lang = tgt_lang 130 | 131 | def __len__(self): 132 | return len(self.src_lens) 133 | 134 | def __getitem__(self, index) -> Dict[str, torch.Tensor]: 135 | index = index + 1 # linecache starts at 1 136 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") 137 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 138 | assert source_line, f"empty source line for index {index}" 139 | assert tgt_line, f"empty tgt line for index {index}" 140 | domain_line = None 141 | if self.domain_file is not None: 142 | domain_line = linecache.getline(str(self.domain_file), index).rstrip("\n") 143 | assert domain_line, f"empty domain line for index {index}" 144 | 145 | # Need to add eos token manually for T5 
146 | if isinstance(self.tokenizer, T5Tokenizer): 147 | source_line += self.tokenizer.eos_token 148 | tgt_line += self.tokenizer.eos_token 149 | 150 | # Pad source and target to the right 151 | source_tokenizer = ( 152 | self.tokenizer.question_encoder if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer 153 | ) 154 | target_tokenizer = self.tokenizer.generator if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer 155 | source_inputs = encode_line2(source_tokenizer, source_line, self.max_source_length, "right") 156 | target_inputs = encode_line(target_tokenizer, tgt_line, self.max_target_length, "right") 157 | 158 | source_ids = source_inputs["input_ids"].squeeze() 159 | target_ids = target_inputs["input_ids"].squeeze() 160 | src_mask = source_inputs["attention_mask"].squeeze() 161 | src_token_type_ids = source_inputs["token_type_ids"].squeeze() 162 | return { 163 | "input_ids": source_ids, 164 | "attention_mask": src_mask, 165 | "token_type_ids": src_token_type_ids, 166 | "decoder_input_ids": target_ids, 167 | "domain": domain_line, 168 | } 169 | 170 | @staticmethod 171 | def get_char_lens(data_file): 172 | return [len(x) for x in Path(data_file).open().readlines()] 173 | 174 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 175 | input_ids = torch.stack([x["input_ids"] for x in batch]) 176 | masks = torch.stack([x["attention_mask"] for x in batch]) 177 | token_type_ids = torch.stack([x["token_type_ids"] for x in batch]) 178 | target_ids = torch.stack([x["decoder_input_ids"] for x in batch]) 179 | domain = [x["domain"] for x in batch] 180 | tgt_pad_token_id = ( 181 | self.tokenizer.generator.pad_token_id 182 | if isinstance(self.tokenizer, RagTokenizer) 183 | else self.tokenizer.pad_token_id 184 | ) 185 | src_pad_token_id = ( 186 | self.tokenizer.question_encoder.pad_token_id 187 | if isinstance(self.tokenizer, RagTokenizer) 188 | else self.tokenizer.pad_token_id 189 | ) 190 | y = trim_batch(target_ids, tgt_pad_token_id) 191 | source_ids, source_mask = trim_batch(input_ids, src_pad_token_id, attention_mask=masks) 192 | keep_col_mask = input_ids.ne(src_pad_token_id).any(dim=0) 193 | token_type_ids = token_type_ids[:, keep_col_mask] 194 | batch = { 195 | "input_ids": source_ids, 196 | "attention_mask": source_mask, 197 | "token_type_ids": token_type_ids, 198 | "decoder_input_ids": y, 199 | "domain": domain, 200 | } 201 | return batch 202 | 203 | 204 | logger = getLogger(__name__) 205 | 206 | 207 | def flatten_list(summary_ids: List[List]): 208 | return [x for x in itertools.chain.from_iterable(summary_ids)] 209 | 210 | 211 | def save_git_info(folder_path: str) -> None: 212 | """Save git information to output_dir/git_log.json""" 213 | repo_infos = get_git_info() 214 | save_json(repo_infos, os.path.join(folder_path, "git_log.json")) 215 | 216 | 217 | def save_json(content, path, indent=4, **json_dump_kwargs): 218 | with open(path, "w") as f: 219 | json.dump(content, f, indent=indent, **json_dump_kwargs) 220 | 221 | 222 | def load_json(path): 223 | with open(path) as f: 224 | return json.load(f) 225 | 226 | 227 | def get_git_info(): 228 | repo = git.Repo(search_parent_directories=True) 229 | repo_infos = { 230 | "repo_id": str(repo), 231 | "repo_sha": str(repo.head.object.hexsha), 232 | "repo_branch": str(repo.active_branch), 233 | "hostname": str(socket.gethostname()), 234 | } 235 | return repo_infos 236 | 237 | 238 | def lmap(f: Callable, x: Iterable) -> List: 239 | """list(map(f, x))""" 240 | return list(map(f, x)) 241 | 242 | 243 | def pickle_save(obj, path): 
244 | """pickle.dump(obj, path)""" 245 | with open(path, "wb") as f: 246 | return pickle.dump(obj, f) 247 | 248 | 249 | def normalize_answer(s): 250 | """Lower text and remove punctuation, articles and extra whitespace.""" 251 | 252 | def remove_articles(text): 253 | return re.sub(r"\b(a|an|the)\b", " ", text) 254 | 255 | def white_space_fix(text): 256 | return " ".join(text.split()) 257 | 258 | def remove_punc(text): 259 | exclude = set(string.punctuation) 260 | return "".join(ch for ch in text if ch not in exclude) 261 | 262 | def lower(text): 263 | return text.lower() 264 | 265 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 266 | 267 | 268 | def f1_score(prediction, ground_truth): 269 | prediction_tokens = normalize_answer(prediction).split() 270 | ground_truth_tokens = normalize_answer(ground_truth).split() 271 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 272 | num_same = sum(common.values()) 273 | if num_same == 0: 274 | return 0 275 | precision = 1.0 * num_same / len(prediction_tokens) 276 | recall = 1.0 * num_same / len(ground_truth_tokens) 277 | f1 = (2 * precision * recall) / (precision + recall) 278 | return f1 279 | 280 | 281 | def exact_match_score(prediction, ground_truth): 282 | return normalize_answer(prediction) == normalize_answer(ground_truth) 283 | 284 | 285 | def calculate_exact_match(output_lns: List[str], reference_lns: List[str]) -> Dict: 286 | assert len(output_lns) == len(reference_lns) 287 | em = 0 288 | for hypo, pred in zip(output_lns, reference_lns): 289 | em += exact_match_score(hypo, pred) 290 | if len(output_lns) > 0: 291 | em /= len(output_lns) 292 | return {"em": em} 293 | 294 | 295 | def calculate_bleu(output_lns, refs_lns) -> dict: 296 | """Uses sacrebleu's corpus_bleu implementation.""" 297 | return {"bleu": round(corpus_bleu(output_lns, [refs_lns]).score, 4)} 298 | 299 | 300 | def is_rag_model(model_prefix): 301 | return model_prefix.startswith("rag") 302 | 303 | 304 | def set_extra_model_params(extra_params, hparams, config): 305 | equivalent_param = {p: p for p in extra_params} 306 | # T5 models don't have `dropout` param, they have `dropout_rate` instead 307 | equivalent_param["dropout"] = "dropout_rate" 308 | for p in extra_params: 309 | if getattr(hparams, p, None): 310 | if not hasattr(config, p) and not hasattr(config, equivalent_param[p]): 311 | logger.info("config doesn't have a `{}` attribute".format(p)) 312 | delattr(hparams, p) 313 | continue 314 | set_p = p if hasattr(config, p) else equivalent_param[p] 315 | setattr(config, set_p, getattr(hparams, p)) 316 | delattr(hparams, p) 317 | return hparams, config 318 | -------------------------------------------------------------------------------- /scripts/run_converter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | domain=$1 4 | seg=$2 5 | 6 | config=dpr-$domain-$seg 7 | 8 | dpr=dpr_mdd-$domain-$seg 9 | src=YOUR_DPR_CHECKPOINT 10 | 11 | mkdir $CHECKPOINTS/$config 12 | 13 | python convert_dpr_original_checkpoint_to_pytorch.py \ 14 | --type question_encoder \ 15 | --src $src \ 16 | --dest $CHECKPOINTS/dpr-$domain-$seg/question_encoder 17 | 18 | python convert_dpr_original_checkpoint_to_pytorch.py \ 19 | --type ctx_encoder \ 20 | --src $src \ 21 | --dest $CHECKPOINTS/dpr-$domain-$seg/ctx_encoder 22 | 23 | 24 | # generate rag model 25 | cp ../data/tokenizer_config.json $CHECKPOINTS/$config/question_encoder/ 26 | cp ../data/vocab.txt $CHECKPOINTS/$config/question_encoder/ 27 | cp 
../data/tokenizer_config.json $CHECKPOINTS/$config/ctx_encoder/ 28 | cp ../data/vocab.txt $CHECKPOINTS/$config/ctx_encoder/ 29 | 30 | # config "model_path" for question encoder to your local path to DPR encoder; 31 | # or use our uploaded model, such as "sivasankalpp/dpr-multidoc2dial-token-question-encoder" or "sivasankalpp/dpr-multidoc2dial-structure-question-encoder" 32 | python model_converter.py \ 33 | --model_path $CHECKPOINTS/$config/question_encoder \ 34 | --out_path $CHECKPOINTS/rag-$config -------------------------------------------------------------------------------- /scripts/run_converter_modelcard.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | domain=$1 4 | seg=$2 5 | config=dpr-$domain-$seg 6 | 7 | # config "model_path" for question encoder to your local path to DPR encoder; 8 | # or use our uploaded model, such as "sivasankalpp/dpr-multidoc2dial-token-question-encoder" or "sivasankalpp/dpr-multidoc2dial-structure-question-encoder" 9 | python model_converter.py \ 10 | --model_path sivasankalpp/dpr-multidoc2dial-$seg-question-encoder \ 11 | --out_path $CHECKPOINTS/rag-$config -------------------------------------------------------------------------------- /scripts/run_data_preprocessing.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | seg=$1 # token or structure 4 | task=$2 # grounding or generation 5 | YOUR_DIR=../data # change it to your own local dir 6 | 7 | 8 | python data_preprocessor.py \ 9 | --dataset_config_name multidoc2dial \ 10 | --output_dir $YOUR_DIR/mdd_all \ 11 | --kb_dir $YOUR_DIR/mdd_kb \ 12 | --segmentation $seg \ 13 | --task $task -------------------------------------------------------------------------------- /scripts/run_data_preprocessing_domain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | domain=$1 # dmv va ssa or studentaid 4 | seg=$2 # token or structure 5 | task=$3 # grounding or generation 6 | YOUR_DIR=../data # change it to your own local dir 7 | 8 | python data_preprocessor.py \ 9 | --dataset_config_name multidoc2dial_$domain \ 10 | --output_dir $YOUR_DIR/mdd_$domain \ 11 | --target_domain $domain \ 12 | --kb_dir $YOUR_DIR/mdd_kb \ 13 | --segmentation $seg \ 14 | --task $task -------------------------------------------------------------------------------- /scripts/run_data_preprocessing_dpr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | seg=$1 # token or structure 4 | domain=$2 # dmv va ssa or studentaid 5 | YOUR_DIR=../data # change it to your own local dir 6 | 7 | python data_preprocessor.py \ 8 | --dataset_config_name multidoc2dial \ 9 | --output_dir $YOUR_DIR/mdd_dpr \ 10 | --segmentation $seg \ 11 | --dpr -------------------------------------------------------------------------------- /scripts/run_data_preprocessing_dpr_domain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | seg=$1 # token or structure 4 | domain=$2 # dmv va ssa or studentaid 5 | YOUR_DIR=../data # change it to your own local dir 6 | 7 | python data_preprocessor.py \ 8 | --dataset_config_name multidoc2dial_$domain \ 9 | --output_dir $YOUR_DIR/mdd_dpr \ 10 | --segmentation $seg \ 11 | --dpr \ 12 | --in_domain_only -------------------------------------------------------------------------------- /scripts/run_download.sh:
-------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | mkdir ../data && \ 4 | cd ../data && \ 5 | wget http://doc2dial.github.io/multidoc2dial/file/multidoc2dial.zip && \ 6 | wget http://doc2dial.github.io/multidoc2dial/file/multidoc2dial_domain.zip && \ 7 | unzip multidoc2dial.zip && \ 8 | unzip multidoc2dial_domain.zip && \ 9 | rm *.zip && \ 10 | wget https://huggingface.co/facebook/rag-token-nq/raw/main/question_encoder_tokenizer/tokenizer_config.json && \ 11 | wget https://huggingface.co/facebook/rag-token-nq/raw/main/question_encoder_tokenizer/vocab.txt -------------------------------------------------------------------------------- /scripts/run_eval_rag_e2e.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | export PYTHONPATH="../":"${PYTHONPATH}" 4 | domain=$1 # all dmv va ssa studentaid 5 | seg=$2 # token structure 6 | score=$3 # original reranking reranking_original 7 | task=$4 # grounding generation 8 | split=$5 # val test 9 | 10 | dpr=dpr-$domain-$seg 11 | DATA_DIR=../data/mdd_$domain/dd-$task-$seg 12 | KB_FOLDER=../data/mdd_kb/knowledge_dataset-$dpr 13 | MODEL_PATH=$CHECKPOINTS/mdd-$task-$dpr-$score/ 14 | 15 | 16 | python rag/eval_rag.py \ 17 | --model_type rag_token_dialdoc \ 18 | --scoring_func $score \ 19 | --gold_pid_path $DATA_DIR/$split.pids \ 20 | --passages_path $KB_FOLDER/my_knowledge_dataset \ 21 | --index_path $KB_FOLDER/my_knowledge_dataset_index.faiss \ 22 | --index_name dialdoc \ 23 | --n_docs 10 \ 24 | --model_name_or_path $MODEL_PATH \ 25 | --eval_mode e2e \ 26 | --evaluation_set $DATA_DIR/$split.source \ 27 | --gold_data_path $DATA_DIR/$split.target \ 28 | --gold_data_mode ans \ 29 | --recalculate \ 30 | --eval_all_checkpoints \ 31 | --predictions_path results.txt -------------------------------------------------------------------------------- /scripts/run_eval_rag_re.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | export PYTHONPATH="../":"${PYTHONPATH}" 4 | domain=$1 # all dmv va ssa studentaid 5 | seg=$2 # token structure 6 | score=$3 # original reranking reranking_original 7 | task=$4 # grounding generation 8 | split=$5 # val test 9 | 10 | 11 | dpr=dpr-$domain-$seg 12 | DATA_DIR=../data/mdd_$domain/dd-$task-$seg 13 | 14 | if [[ $split != "test" && $domain != "all" ]]; then 15 | KB_FOLDER=../data/mdd_kb/knowledge_dataset-$dpr-wo 16 | else 17 | KB_FOLDER=../data/mdd_kb/knowledge_dataset-$dpr 18 | fi 19 | 20 | MODEL_PATH=$CHECKPOINTS/mdd-$task-$dpr-$score/ 21 | 22 | python rag/eval_rag.py \ 23 | --model_type rag_token_dialdoc \ 24 | --scoring_func $score \ 25 | --gold_pid_path $DATA_DIR/$split.pids \ 26 | --passages_path $KB_FOLDER/my_knowledge_dataset \ 27 | --index_name dialdoc \ 28 | --index_path $KB_FOLDER/my_knowledge_dataset_index.faiss \ 29 | --n_docs 10 \ 30 | --model_name_or_path $MODEL_PATH \ 31 | --eval_mode retrieval \ 32 | --evaluation_set $DATA_DIR/$split.source \ 33 | --gold_data_path $DATA_DIR/$split.titles \ 34 | --gold_data_mode ans \ 35 | --recalculate \ 36 | --eval_all_checkpoints \ 37 | --predictions_path results.txt 38 | -------------------------------------------------------------------------------- /scripts/run_finetune_rag_dialdoc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | export PYTHONPATH="../":"${PYTHONPATH}" 4 | export TOKENIZERS_PARALLELISM=false 5 | domain=$1 # all dmv ssa va studentaid 6 | seg=$2 # token 
structure 7 | score=$3 # original reranking reranking_original 8 | task=$4 # grounding generation 9 | seed=$RANDOM 10 | 11 | dpr=dpr-$domain-$seg 12 | MODEL_NAME_OR_PATH=$CHECKPOINTS/rag-$dpr 13 | KB_FOLDER=../data/mdd_kb/knowledge_dataset-$dpr 14 | DATA_DIR=../data/mdd_$domain/dd-$task-$seg 15 | 16 | python rag/finetune_rag_dialdoc.py \ 17 | --seed $seed \ 18 | --segmentation $seg \ 19 | --do_marginalize 1 \ 20 | --data_dir $DATA_DIR \ 21 | --scoring_func $score \ 22 | --output_dir $CHECKPOINTS/mdd-$task-$dpr-$score \ 23 | --model_name_or_path $MODEL_NAME_OR_PATH \ 24 | --model_type rag_token_dialdoc \ 25 | --index_name dialdoc \ 26 | --passages_path $KB_FOLDER/my_knowledge_dataset \ 27 | --index_path $KB_FOLDER/my_knowledge_dataset_index.faiss \ 28 | --fp16 \ 29 | --profile \ 30 | --do_train \ 31 | --gpus 1 \ 32 | --n_train -1 \ 33 | --n_val -1 \ 34 | --n_test -1 \ 35 | --n_docs 5 \ 36 | --train_batch_size 8 \ 37 | --eval_batch_size 2 \ 38 | --max_combined_length 300 \ 39 | --max_source_length 128 \ 40 | --max_target_length 50 \ 41 | --val_max_target_length 50 \ 42 | --test_max_target_length 50 \ 43 | --label_smoothing 0.1 \ 44 | --dropout 0.1 \ 45 | --attention_dropout 0.1 \ 46 | --weight_decay 0.001 \ 47 | --adam_epsilon 1e-08 \ 48 | --max_grad_norm 0.1 \ 49 | --lr_scheduler polynomial \ 50 | --learning_rate 3e-05 \ 51 | --num_train_epochs 2 \ 52 | --warmup_steps 500 \ 53 | --gradient_accumulation_steps 1 54 | -------------------------------------------------------------------------------- /scripts/run_kb_index.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | domain=$1 # "all" or "dmv", "ssa", "studentaid", "va" for domain adaptation setup 4 | seg=$2 # token or structure 5 | 6 | dpr=dpr-$domain-$seg 7 | rag_model_name=$CHECKPOINTS/rag-$dpr 8 | # config "ctx_model_name" for ctx encoder to your local path to DPR encoder; 9 | # ctx_model_name=$CHECKPOINTS/$dpr/ctx_encoder 10 | # or use our fine-tuned DPR encoders, such as "sivasankalpp/dpr-multidoc2dial-token-ctx-encoder" or "sivasankalpp/dpr-multidoc2dial-structure-ctx-encoder" 11 | ctx_model_name=sivasankalpp/dpr-multidoc2dial-$seg-ctx-encoder 12 | KB_FOLDER=../data/mdd_kb/ 13 | 14 | python rag/use_own_knowledge_dataset.py \ 15 | --rag_model_name $rag_model_name \ 16 | --dpr_ctx_encoder_model_name $ctx_model_name \ 17 | --csv_path $KB_FOLDER/mdd-$seg-$domain.csv \ 18 | --output_dir $KB_FOLDER/knowledge_dataset-$dpr 19 | 20 | -------------------------------------------------------------------------------- /scripts/run_kb_index_domain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | domain=$1 # dmv ssa studentaid va 4 | seg=$2 # token structure 5 | 6 | dpr=dpr-$domain-$seg 7 | rag_model_name=$CHECKPOINTS/rag-$dpr 8 | ctx_model_name=$CHECKPOINTS/$dpr/ctx_encoder 9 | KB_FOLDER=../data/mdd_kb/ 10 | 11 | # for train and validation split (with $domain as target domain) 12 | python rag/use_own_knowledge_dataset.py \ 13 | --rag_model_name $rag_model_name \ 14 | --dpr_ctx_encoder_model_name $ctx_model_name \ 15 | --csv_path $KB_FOLDER/mdd-$seg-wo-$domain.csv \ 16 | --output_dir $KB_FOLDER/knowledge_dataset-$dpr-wo 17 | 18 | # for test split (with $domain as target domain) 19 | python rag/use_own_knowledge_dataset.py \ 20 | --rag_model_name $rag_model_name \ 21 | --dpr_ctx_encoder_model_name $ctx_model_name \ 22 | --csv_path $KB_FOLDER/mdd-$seg-$domain.csv \ 23 | --output_dir $KB_FOLDER/knowledge_dataset-$dpr 24 | 25 | 
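# Example usage (hypothetical invocation; assumes $CHECKPOINTS and ../data/mdd_kb were prepared by the scripts above):
#   sh run_kb_index_domain.sh dmv structure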
-------------------------------------------------------------------------------- /scripts/run_sharedtask_eval.sh: -------------------------------------------------------------------------------- 1 | python sharedtask_eval.py \ 2 | --task grounding \ 3 | --prediction_json ../sharedtask/sample_files/sample_task_grounding_predictions.json \ 4 | --reference_json ../sharedtask/sample_files/sample_task_references.json 5 | 6 | 7 | python sharedtask_eval.py \ 8 | --task utterance \ 9 | --prediction_json ../sharedtask/sample_files/sample_task_utterance_predictions.json \ 10 | --reference_json ../sharedtask/sample_files/sample_task_references.json -------------------------------------------------------------------------------- /scripts/sharedtask_eval.py: -------------------------------------------------------------------------------- 1 | """ F1_score is from the official evaluation script for v1.1 of the SQuAD dataset. """ 2 | import json 3 | import string 4 | import re 5 | import argparse 6 | from collections import Counter 7 | from datasets import load_metric 8 | from rag.utils_rag import f1_score, exact_match_score 9 | 10 | 11 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 12 | scores_for_ground_truths = [] 13 | for ground_truth in ground_truths: 14 | score = metric_fn(prediction, ground_truth) 15 | scores_for_ground_truths.append(score) 16 | return max(scores_for_ground_truths) 17 | 18 | 19 | def matching_evaluate(references, predictions): 20 | f1 = em = total = 0 21 | for id_, ref_text in references.items(): 22 | total += 1 23 | ground_truths = [ref_text] 24 | prediction = predictions.get(id_, "") 25 | f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths) 26 | em += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths) 27 | f1 = 100.0 * f1 / total 28 | em = 100.0 * em / total 29 | 30 | return f1, em 31 | 32 | 33 | def matching_metrics(task, reference_json, prediction_json): 34 | d_id_reference = {} 35 | references_text = [] 36 | references_list = [] 37 | with open(reference_json) as fp_ref: 38 | data = json.load(fp_ref) 39 | for d_ref in data: 40 | d_id_reference[d_ref["id"]] = d_ref[task] 41 | references_list.append([d_ref[task]]) 42 | references_text.append(d_ref[task]) 43 | predictions = [] 44 | d_id_prediction = {} 45 | with open(prediction_json) as fp_pred: 46 | data = json.load(fp_pred) 47 | for d_pred in data: 48 | d_id_prediction[d_pred["id"]] = d_pred[task] 49 | predictions.append(d_pred[task]) 50 | assert ( 51 | len(predictions) == len(references_list) == len(references_text) 52 | ), "Number of prediction instances must match the number of reference instances" 53 | 54 | output = {} 55 | f1_score, em_score = matching_evaluate(references=d_id_reference, predictions=d_id_prediction) 56 | if task == "utterance": 57 | metric_sacrebleu = load_metric("sacrebleu") 58 | results = metric_sacrebleu.compute(predictions=predictions, references=references_list) 59 | sacrebleu_score = results["score"] 60 | 61 | metric_meteor = load_metric("meteor") 62 | results = metric_meteor.compute(predictions=predictions, references=references_text) 63 | meteor_score = round(results["meteor"] * 100, 4) 64 | 65 | metric_rouge = load_metric("rouge") 66 | results = metric_rouge.compute(predictions=predictions, references=references_text) 67 | rouge_score = round(results["rougeL"].mid.fmeasure * 100, 4) 68 | output = {"F1_U": f1_score, "SACREBLEU_U": sacrebleu_score, "METEOR_U": meteor_score, "ROUGE-L_U": rouge_score} 69 | else: 70 | output = {"EM_G":
em_score, "F1_G": f1_score} 71 | return output 72 | 73 | 74 | def main(): 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument( 77 | "--task", 78 | type=str, 79 | required=True, 80 | help="Select metrics for task that is either 'grounding' or 'utterance' generation", 81 | ) 82 | parser.add_argument( 83 | "--prediction_json", 84 | type=str, 85 | required=True, 86 | help="Path to predictions", 87 | ) 88 | parser.add_argument( 89 | "--reference_json", 90 | type=str, 91 | required=True, 92 | help="Path to references", 93 | ) 94 | 95 | args = parser.parse_args() 96 | output = matching_metrics(args.task, args.reference_json, args.prediction_json) 97 | print("task:", args.task) 98 | print("output:", output) 99 | 100 | 101 | if __name__ == "__main__": 102 | """ 103 | task: grounding 104 | output: {'EM': 20.0, 'F1': 28.047519076264688} 105 | 106 | task: utterance 107 | output: {'F1': 13.25862068965517, 'SACREBLEU': 1.5941520509774114, 'METEOR': 8.3403, 'ROUGE-L': 9.0637} 108 | """ 109 | main() 110 | -------------------------------------------------------------------------------- /sharedtask/README.md: -------------------------------------------------------------------------------- 1 | # Shared Task of DialDoc Workshop at ACL 2022 2 | 3 | This shared task of [2nd DialDoc Workshop](https://doc2dial.github.io/workshop2022/) at [ACL 2022](https://www.2022.aclweb.org) focuses on modeling goal-oriented dialogues that are grounded in multiple domain documents. This repository provides the code for the baselines using the train and validation split of [MultiDoc2Dial](http://doc2dial.github.io/multidoc2dial/) dataset. 4 | 5 | Please cite the paper and star the repository if you find the paper, data and code useful for your work. 6 | 7 | ```bibtex 8 | @inproceedings{feng2021multidoc2dial, 9 | title={MultiDoc2Dial: Modeling Dialogues Grounded in Multiple Documents}, 10 | author={Feng, Song and Patel, Siva Sankalp and Wan, Hui and Joshi, Sachindra}, 11 | booktitle={EMNLP}, 12 | year={2021} 13 | } 14 | ``` 15 | 16 | Please refer to the main [README](../README.md) for details about running baselines on MultiDoc2Dial data. 17 | 18 | 19 | ## Evaluations 20 | For sample prediction/reference files (`sharedtask/sample_files`) and evaluation script for leaderboard submission, please refer to, 21 | 22 | > [`run_sharedtask_eval.sh`](../scripts/run_sharedtask_eval.sh) 23 | 24 | 25 | ## Participation 26 | 27 | Check out our [leaderboard](https://eval.ai/web/challenges/challenge-page/1437/overview)! Please refer to the [workshop website](https://doc2dial.github.io/workshop2022/#shared) for more details about participating the Shared Task. 28 | -------------------------------------------------------------------------------- /sharedtask/sample_files/sample_task_grounding_predictions.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "1409501a35697e0ce68561e29577b90a_1", 4 | "grounding": "Because we all pay indirectly for crashes involving uninsured motorists , New York State requires every motorist to maintain auto insurance every single day a vehicle is registered. 
DMV works with insurance companies to electronically monitor your insurance coverage ," 5 | }, 6 | { 7 | "id": "1409501a35697e0ce68561e29577b90a_3", 8 | "grounding": "If your driver license or your vehicle registration is suspended because of a lapse in automobile liability insurance coverage," 9 | }, 10 | { 11 | "id": "1409501a35697e0ce68561e29577b90a_5", 12 | "grounding": "use the Ask DMV a Question [8] service to request a correction. We may be able to correct it without you coming a DMV office. There is no fee for a correction." 13 | }, 14 | { 15 | "id": "1409501a35697e0ce68561e29577b90a_7", 16 | "grounding": "You must provide proof of identity and date of birth. You must be at least 16 years old except for ATV registrations. You can use a New York State" 17 | }, 18 | { 19 | "id": "1409501a35697e0ce68561e29577b90a_9", 20 | "grounding": "they can be charged based on training time no matter how much money you re paid back. Half - time training rates reduce your driver license by a half - month for each month you re enrolled." 21 | } 22 | ] -------------------------------------------------------------------------------- /sharedtask/sample_files/sample_task_references.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "1409501a35697e0ce68561e29577b90a_1", 4 | "utterance": "You will need to get insurance or we will suspend your registration and license", 5 | "grounding": "Because we all pay indirectly for crashes involving uninsured motorists , New York State requires every motorist to maintain auto insurance every single day a vehicle is registered. DMV works with insurance companies to electronically monitor your insurance coverage ," 6 | }, 7 | { 8 | "id": "1409501a35697e0ce68561e29577b90a_3", 9 | "utterance": "Okay, have you received a letter from the DMV letting you know how to clear things up?", 10 | "grounding": "we mail you an insurance inquiry letter to allow you to clear up the problem." 11 | }, 12 | { 13 | "id": "1409501a35697e0ce68561e29577b90a_5", 14 | "utterance": "Okay, we can take care of that.", 15 | "grounding": "Learn more about how to change the address on your license and registrations [1 ]" 16 | }, 17 | { 18 | "id": "1409501a35697e0ce68561e29577b90a_7", 19 | "utterance": "Sure, it is. You can contact your college and get a certified copy.", 20 | "grounding": "you may have received a General Educational Development GED certificate. You can contact the college for a certified copy. U.S. college transcripts." 21 | }, 22 | { 23 | "id": "1409501a35697e0ce68561e29577b90a_9", 24 | "utterance": "Yes, as long as they be a phone, electricity, gas, water or cable bill.", 25 | "grounding": "a phone bill, electricity / gas bill, water bill or cable bill. Divorce papers. You can contact your lawyer to find out where to go to obtain the papers, or check with the state where you got divorced." 26 | } 27 | ] -------------------------------------------------------------------------------- /sharedtask/sample_files/sample_task_utterance_predictions.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "1409501a35697e0ce68561e29577b90a_1", 4 | "utterance": "Was your insurance up to date?" 
5 | }, 6 | { 7 | "id": "1409501a35697e0ce68561e29577b90a_3", 8 | "utterance": "You will need to get insurance or we will suspend your registration and license" 9 | }, 10 | { 11 | "id": "1409501a35697e0ce68561e29577b90a_5", 12 | "utterance": "Do you want to change the address on your license and registration?" 13 | }, 14 | { 15 | "id": "1409501a35697e0ce68561e29577b90a_7", 16 | "utterance": "Yes, that is correct." 17 | }, 18 | { 19 | "id": "1409501a35697e0ce68561e29577b90a_9", 20 | "utterance": "Yes, sure. Just go online and check them out. Be aware that you must pay any fees, penalties and room and board that are due to you." 21 | } 22 | ] --------------------------------------------------------------------------------
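Prediction files for either subtask follow the same shape as the sample files above: a JSON list of objects, each carrying an `"id"` and the task field (`"grounding"` or `"utterance"`). A minimal sketch for writing such a file, where the file name is arbitrary and the IDs and texts are illustrative only:

```python
# Minimal sketch: write an utterance-prediction file in the format that
# sharedtask_eval.py expects. IDs and utterances below are illustrative only.
import json

predictions = [
    {"id": "1409501a35697e0ce68561e29577b90a_1", "utterance": "Was your insurance up to date?"},
    {"id": "1409501a35697e0ce68561e29577b90a_3", "utterance": "Have you received a letter from the DMV?"},
]

with open("my_utterance_predictions.json", "w") as fp:
    json.dump(predictions, fp, indent=2)
```

The resulting file can then be scored against the references with `python sharedtask_eval.py --task utterance --prediction_json my_utterance_predictions.json --reference_json ../sharedtask/sample_files/sample_task_references.json`.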