├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── args ├── api.tsv └── kb_sites.txt ├── fig ├── dialog_web_demo.PNG └── doc_web_demo.PNG ├── setup.sh ├── setup_linux.sh ├── setup_win.sh └── src ├── cmr ├── __init__.py ├── batcher.py ├── common.py ├── config.py ├── dreader.py ├── dreader_seq2seq.py ├── dropout_wrapper.py ├── encoder.py ├── fetch_realtime_grounding.py ├── model.py ├── my_optim.py ├── my_utils │ ├── __init__.py │ ├── eval_bleu.py │ ├── eval_nist.py │ ├── log_wrapper.py │ ├── squad_eval.py │ ├── tokenizer.py │ ├── utils.py │ └── word2vec_utils.py ├── process_raw_data.py ├── recurrent.py ├── san_decoder.py ├── similarity.py └── sub_layers.py ├── common.txt ├── demo_dialog.py ├── demo_doc_gen.py ├── grounded.py ├── knowledge.py ├── lm.py ├── mrc.py ├── onmt ├── Beam.py ├── Constants.py ├── Dataset.py ├── Dict.py ├── Models.py ├── Optim.py ├── Translator.py ├── __init__.py └── modules │ ├── GlobalAttention.py │ └── __init__.py ├── open_dialog.py ├── ranker.py ├── shared.py ├── templates ├── dialog.html └── doc_gen.html ├── todo.py └── tts.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.pyc 3 | src/transformers/ 4 | models/ 5 | voice/ 6 | temp/ 7 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MixingBoard: a Knowledgeable Stylized Integrated Text Generation Platform 2 | 3 | We present [MixingBoard](https://arxiv.org/abs/2005.08365), a platform for quickly building demos with a focus on knowledge grounded stylized text generation. We unify existing text generation algorithms in a shared codebase and further adapt earlier algorithms for constrained generation. To borrow advantages from different models, we implement strategies for cross-model integration, from the token probability level to the latent space level. An interface to external knowledge is provided via a module that retrieves on-the-fly relevant knowledge from passages on the web or any document collection. A user interface for local development, remote webpage access, and a RESTful API are provided to make it simple for users to build their own demos. 4 | 5 | # News 6 | * July 6, 2020: MixingBoard repo is released on GitHub. 7 | * Apr 3, 2020: MixingBoard [paper](https://arxiv.org/abs/2005.08365) is accepted to appear in the [ACL 2020](https://acl2020.org/) Demo track. 8 | 9 | # Setup 10 | 11 | We recommend using [Anaconda](https://www.anaconda.com/) for setup. 12 | First, create an environment with Python 3.6 13 | ``` 14 | conda create -n mixingboard python=3.6 15 | conda activate mixingboard 16 | ``` 17 | Then, install Python packages with 18 | ``` 19 | sh setup.sh 20 | ``` 21 | Then, depending on your operating system, download pretrained models with 22 | ``` 23 | # if using Windows 24 | sh setup_win.sh 25 | # if using Linux 26 | sh setup_linux.sh 27 | ``` 28 | 29 | If you want to use the web search and text-to-speech functions, please apply for the following accounts. 30 | * **Bing Search API**: open an account and/or try for free on [Azure Cognitive Services](https://azure.microsoft.com/en-us/services/cognitive-services/bing-web-search-api/). Once you have obtained the key, please put it in `args/api.tsv`. You can also try other search engines; however, we currently only support Bing Search v7.0 in `src/knowledge.py`. 31 | * **Text-to-Speech**: open an account and/or try for free on [Azure Cognitive Services](https://azure.microsoft.com/en-us/services/cognitive-services/text-to-speech/). Once you have obtained the key, please put it in `args/api.tsv`. 32 | 33 | Finally, please implement your own `pick_tokens` function in `src/todo.py` (see [Disclaimer](#Disclaimer)). This function picks tokens at each generation time step given the predicted token probability distribution. Many choices are available, e.g. greedy decoding, top-k sampling, top-p (nucleus) sampling, or pure sampling. 34 | 35 | 36 | # Modules 37 | 38 | ## Knowledge passage retrieval 39 | We use the following unstructured free-text sources to retrieve relevant knowledge passages: search engines, specialized websites (e.g. Wikipedia), and user-provided documents. 40 | ``` 41 | python src/knowledge.py 42 | ``` 43 | The above command calls the Bing search API; the following shows the results of an example query. 44 | ``` 45 | QUERY: what is deep learning?
46 | 47 | URL: https://en.wikipedia.org/wiki/Deep_learning 48 | TXT: Deep learning is a class of machine learning algorithms that (pp199–200) uses multiple layers to progressively extract higher level features from the raw input. For example, in image processing, lower layers may identify edges, while higher layers may identify the concepts relevant to a human such as digits or letters or faces.. Overview. Most modern deep learning models are based on ... 49 | 50 | URL: https://machinelearningmastery.com/what-is-deep-learning/ 51 | TXT: Deep Learning is Large Neural Networks. Andrew Ng from Coursera and Chief Scientist at Baidu Research formally founded Google Brain that eventually resulted in the productization of deep learning technologies across a large number of Google services.. He has spoken and written a lot about what deep learning is and is a good place to start. In early talks on deep learning, Andrew described deep ... 52 | 53 | URL: https://www.forbes.com/sites/bernardmarr/2018/10/01/what-is-deep-learning-ai-a-simple-guide-with-8-practical-examples/ 54 | TXT: Since deep-learning algorithms require a ton of data to learn from, this increase in data creation is one reason that deep learning capabilities have grown in recent years. 55 | ``` 56 | 57 | ## Open-ended dialogue generation 58 | We use [DialoGPT](https://github.com/microsoft/DialoGPT) as an example. 59 | ``` 60 | python src/open_dialog.py 61 | ``` 62 | The following shows DialoGPT (`DPT`) predictions of an example query using one implementation of the `pick_tokens` function. 63 | ``` 64 | CONTEXT: What's your dream? 65 | DPT 0.198 First one is to be a professional footballer. Second one is to be a dad. Third one is to be a father of a second son. 66 | DPT 0.198 First one is to be a professional footballer. Second one is to be a dad. Third one is to be a father of two. 67 | ... 68 | ``` 69 | 70 | ## Generation with language model 71 | We use [GPT-2](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) as an example. 72 | ``` 73 | python src/lm.py 74 | ``` 75 | The following shows GPT-2 predictions of an example query using one implementation of the `pick_tokens` function. 76 | ``` 77 | CONTEXT: Deep learning and Natural Language Processing are 78 | GPT2 0.128 not to be relied on in everyday life. The good news is, with a little practice, you'll be able to master them. 79 | GPT2 0.101 not to be relied on in everyday life. The good news is, with a little practice, you'll be able to master them quickly 80 | GPT2 0.096 not to be relied on in everyday life. The good news is, with a little practice, you will be able to solve complex problem 81 | ... 82 | ``` 83 | 84 | ## Machine reading comprehension 85 | ``` 86 | python src/mrc.py 87 | ``` 88 | The above command calls [BiDAF](https://allenai.github.io/bi-att-flow/) model. Given a passage from a [Wikipedia page](https://en.wikipedia.org/wiki/Geoffrey_Hinton) and an example query, it returns the following results 89 | ``` 90 | QUERY: Who is Jeffrey Hinton? 91 | PASSAGE: Geoffrey Everest Hinton CC FRS FRSC is an English Canadian cognitive psychologist and computer scientist, most noted for his work on artificial neural networks. Since 2013 he divides his time working for Google and the University of Toronto. In 2017, he cofounded and became the Chief Scientific Advisor of the Vector Institute in Toronto. 
92 | Bidaf 0.352 an English Canadian cognitive psychologist and computer scientist 93 | ``` 94 | 95 | ## Grounded generation 96 | We consider two document-grounded text generation algorithms: 97 | * [Conversing-by-Reading](https://github.com/qkaren/converse_reading_cmr), which aims to generate a proper dialog response grounded in a relevant document or text knowledge passage. It can be called with the command below 98 | ``` 99 | python src/grounded.py cmr 100 | ``` 101 | * [Content-Transfer](https://github.com/shrimai/Towards-Content-Transfer-through-Grounded-Text-Generation), which aims to generate proper sentences in a given document context, given another relevant document or text knowledge passage. It can be called with the command below 102 | ``` 103 | python src/grounded.py ct 104 | ``` 105 | 106 | ## Text-to-speech 107 | ``` 108 | python src/tts.py 109 | ``` 110 | The above command calls the [Microsoft Azure Text-to-Speech API](https://azure.microsoft.com/en-us/services/cognitive-services/text-to-speech/), then saves and plays the audio. The following is one example. 111 | ``` 112 | TXT: Hello there, welcome to the Mixing Board repo! 113 | audio saved to voice/hellotherewelcometothemixingboardrepo_en-US-JessaNeural.wav 114 | ``` 115 | 116 | ## Ranking 117 | We consider multiple metrics to rank the hypotheses, including 1) forward and reverse generation likelihood, 2) repetition penalty, 3) informativeness, and 4) style intensity. 118 | ``` 119 | python src/ranker.py 120 | ``` 121 | The following are some example outputs of the above command. 122 | ``` 123 | TXT: This is a normal sentence. 124 | rep -0.0000 info 0.1619 score 0.1619 125 | 126 | TXT: This is a repetive and repetive sentence. 127 | rep -0.1429 info 0.2518 score 0.1089 128 | 129 | TXT: This is a informative sentence from the MixingBoard GitHub repo. 130 | rep -0.0000 info 0.4416 score 0.4416 131 | ``` 132 | 133 | ## Coming soon 134 | The modules for stylization, constrained generation, and cross-model integration will be available soon in this repo. 135 | 136 | # Dialog Demo 137 | 138 | ## Command-line interface 139 | The command-line interface can be started with the following command. 140 | ``` 141 | python src/demo_dialog.py cmd 142 | ``` 143 | ## Webpage interface 144 | ``` 145 | python src/demo_dialog.py web 146 | ``` 147 | The command above creates a webpage demo that can be visited by typing `localhost:5000` in your browser. You can interact with the models; the following screenshot is an example 148 | ![](https://github.com/microsoft/MixingBoard/blob/master/fig/dialog_web_demo.PNG) 149 | 150 | ## RESTful API 151 | ``` 152 | python src/demo_dialog.py api 153 | ``` 154 | Running the command above on your machine `A` (say its IP address is `IP_address_A`) starts hosting the models on machine `A` with a RESTful API. Then, you can call this API from another machine, say machine `B`, with the following command, using "what is machine learning?" as an example context 155 | ``` 156 | curl IP_address_A:5000 -d "context=what is machine learning?" -X GET 157 | ``` 158 | which returns a JSON object in the following format 159 | ```json 160 | { 161 | "context": "what is machine learning?", 162 | "passages": [[ 163 | "https://en.wikipedia.org/wiki/Machine_learning", 164 | "Machine learning (ML) is the study of computer algorithms that improve automatically through experience.
It is seen as a subset of artificial intelligence.Machine learning algorithms build a mathematical model based on sample data, known as \"training data\", in order to make predictions or decisions without being explicitly programmed to do so. Machine learning algorithms are used in a wide ..." 165 | ]], 166 | "responses": [ 167 | { 168 | "rep": -0.0, "info": 0.4280192169639406, "fwd": 0.014708111993968487, "rvs": 0.10698941218944846, "score": 0.5497167508995263, "way": "Bidaf", 169 | "hyp": "computer algorithms that improve automatically through experience"}, 170 | { 171 | "rep": -0.0, "info": 0.24637171873352778, "fwd": 0.16426260769367218, "rvs": 0.05065313921885011, "score": 0.46128747495542344, "way": "DPT", 172 | "hyp": "I believe that is a fancy way to say artificial intelligence."}, 173 | { 174 | "rep": -0.1428571428571429, "info": 0.22310269295193919, "fwd": 0.1599835902452469, "rvs": 0.21712445686414383, "score": 0.4573535985050974, "way": "DPT", 175 | "hyp": "I believe that is a fancy way to put it. Machine learning is a set of algorithms and algorithms are machines."} 176 | ]} 177 | ``` 178 | Besides calling the API with `curl`, you can also launch a webpage demo on machine `B` that uses the backend running on machine `A` via the API, using the following command 179 | ``` 180 | python src/demo_dialog.py web --remote=IP_address_A:5000 --port=5001 181 | ``` 182 | 183 | # Document Generation Demo 184 | 185 | ## Command-line interface 186 | The command-line interface can be started with the following command. 187 | ``` 188 | python src/demo_doc_gen.py cmd 189 | ``` 190 | ## Webpage interface 191 | ``` 192 | python src/demo_doc_gen.py web 193 | ``` 194 | The command above creates a webpage demo that can be visited by typing `localhost:5000` in your browser. You can interact with the models; the following screenshot is an example 195 | ![](https://github.com/microsoft/MixingBoard/blob/master/fig/doc_web_demo.PNG) 196 | 197 | ## RESTful API 198 | ``` 199 | python src/demo_doc_gen.py api 200 | ``` 201 | Running the command above on your machine `A` (say its IP address is `IP_address_A`) starts hosting the models on machine `A` with a RESTful API. Similar to the Dialog Demo, you can use `curl` to call this backend from another machine 202 | ``` 203 | curl IP_address_A:5000 -d "context=Deep learning is" -X GET 204 | ``` 205 | which returns a JSON object in the following format 206 | ```json 207 | { 208 | "context": "Deep learning is", 209 | "passages": [[ 210 | "https://en.wikipedia.org/wiki/Deep_learning", 211 | "Deep learning is a class of machine learning algorithms that (pp199\u2013200) uses multiple layers to progressively extract higher level features from the raw input. For example, in image processing, lower layers may identify edges, while higher layers may identify the concepts relevant to a human such as digits or letters or faces.. Overview. Most modern deep learning models are based on ..."
212 | ]], 213 | "responses": [ 214 | { 215 | "rep": -0.07407407407407407, "info": 0.36715887127947433, "fwd": 0.06162497028708458, "score": 0.35470977305382584, "way": "GPT2", 216 | "hyp": "particularly exciting at work because it allows anyone with a background in machine learning or machine learning algorithms to solve real-world problems using artificial neural networks," 217 | } 218 | ]} 219 | ``` 220 | Alternatively, you can use the model hosted on machine `A` as the backend of a webpage demo hosted on machine `B` with the following command 221 | ``` 222 | python src/demo_doc_gen.py web --remote=IP_address_A:5000 --port=5001 223 | ``` 224 | 225 | # Contributing 226 | 227 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 228 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 229 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 230 | 231 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 232 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 233 | provided by the bot. You will only need to do this once across all repos using our CLA. 234 | 235 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 236 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 237 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 238 | 239 | # Disclaimer 240 | MixingBoard is mainly released as a platform that helps developers build demos with a focus on knowledge grounded stylized text generation, and is not meant as an end-to-end system on its own. The responsibility of decoder implementation resides with the developer, and the developer needs to implement the method `pick_tokens` in `MixingBoard/src/todo.py` to have a workable system. Despite our efforts to minimize the amount of overtly offensive data in our processing pipelines, models made available with MixingBoard retain the potential of generating output that may trigger offense. Output may reflect gender and other biases implicit in the data. Responses created using these models may exhibit a tendency to agree with propositions that are unethical, biased, or offensive (or conversely, to disagree with otherwise ethical statements). These are known issues in current state-of-the-art end-to-end conversation models trained on large, naturally occurring datasets. In no case should inappropriate content generated as a result of using MixingBoard be interpreted as reflecting the views or values of either the authors or Microsoft Corp. 241 | 242 | 243 | # Citation 244 | 245 | If you use this code in your work, you can cite our [arXiv](https://arxiv.org/abs/2005.08365) paper: 246 | 247 | ``` 248 | @article{gao2020mixingboard, 249 | title={MixingBoard: a Knowledgeable Stylized Integrated Text Generation Platform}, 250 | author={Gao, Xiang and Galley, Michel and Dolan, Bill}, 251 | journal={Proc.
of ACL}, 252 | year={2020} 253 | } 254 | ``` -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 
40 | 41 | -------------------------------------------------------------------------------- /args/api.tsv: -------------------------------------------------------------------------------- 1 | speech __your_API_region__(e.g.`westeurope`) __your_key___ 2 | bing_v7 __your_key___ -------------------------------------------------------------------------------- /args/kb_sites.txt: -------------------------------------------------------------------------------- 1 | wikipedia.org -------------------------------------------------------------------------------- /fig/dialog_web_demo.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MixingBoard/4383b382ffd9493007ac484c54d2aea2b64762a8/fig/dialog_web_demo.PNG -------------------------------------------------------------------------------- /fig/doc_web_demo.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MixingBoard/4383b382ffd9493007ac484c54d2aea2b64762a8/fig/doc_web_demo.PNG -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | # install python packages ==== 2 | 3 | conda install pytorch -c pytorch 4 | pip install azure-cognitiveservices-search-websearch==1.0.0 5 | pip install git+https://github.com/boudinfl/pke.git 6 | pip install allennlp allennlp-models 7 | pip install flask flask_restful 8 | pip install spacy regex nltk pyaudio 9 | pip install sentencepiece sacremoses 10 | 11 | python -m nltk.downloader stopwords 12 | python -m nltk.downloader wordnet 13 | python -m nltk.downloader universal_tagset 14 | python -m spacy download en 15 | python -m spacy download en_core_web_sm 16 | 17 | # download an older version of Transformers that is compatible with the current MixingBoard 18 | mkdir temp 19 | wget https://github.com/huggingface/transformers/archive/4d45654.zip -O temp/transformers.zip 20 | tar -xf temp/transformers.zip -C temp 21 | mv temp/transformers-4d456542e9d381090f9a00b2bcc5a4cb07f6f3f7/transformers src/transformers -------------------------------------------------------------------------------- /setup_linux.sh: -------------------------------------------------------------------------------- 1 | # DialoGPT forward and reverse model 2 | mkdir models/DialoGPT 3 | wget https://convaisharables.blob.core.windows.net/lsp/multiref/medium_ft.pkl -O models/DialoGPT/medium_ft.pkl 4 | wget https://convaisharables.blob.core.windows.net/lsp/multiref/small_reverse.pkl -O models/DialoGPT/small_reverse.pkl 5 | 6 | # BiDAF model 7 | mkdir models/BiDAF 8 | wget https://storage.googleapis.com/allennlp-public-models/bidaf-model-2020.03.19.tar.gz -O models/BiDAF/bidaf-model-2020.03.19.tar.gz 9 | 10 | # Content Transfer 11 | mkdir models/crg 12 | wget http://tts.speech.cs.cmu.edu/content_transfer/crg_model.zip -O temp/crg_model.zip 13 | unzip temp/crg_model.zip -d temp 14 | mv temp/crg_model/crg_model/crg_model.pt models/crg/crg_model.pt 15 | 16 | wget http://tts.speech.cs.cmu.edu/content_transfer/sentencepieceModel.zip -O temp/bpe.zip 17 | unzip temp/bpe.zip -d temp 18 | mv temp/bpe/sentencepieceModel/bpeM.vocab models/crg/bpeM.vocab 19 | mv temp/bpe/sentencepieceModel/bpeM.model models/crg/bpeM.model -------------------------------------------------------------------------------- /setup_win.sh: -------------------------------------------------------------------------------- 1 |
# DialoGPT forward and reverse model 2 | mkdir models/DialoGPT 3 | wget https://convaisharables.blob.core.windows.net/lsp/multiref/medium_ft.pkl -O models/DialoGPT/medium_ft.pkl 4 | wget https://convaisharables.blob.core.windows.net/lsp/multiref/small_reverse.pkl -O models/DialoGPT/small_reverse.pkl 5 | 6 | # BiDAF model 7 | mkdir models/BiDAF 8 | wget https://storage.googleapis.com/allennlp-public-models/bidaf-model-2020.03.19.tar.gz -O models/BiDAF/bidaf-model-2020.03.19.tar.gz 9 | 10 | # Content Transfer 11 | mkdir models/crg 12 | wget http://tts.speech.cs.cmu.edu/content_transfer/crg_model.zip -O temp/crg_model.zip 13 | tar -xf temp/crg_model.zip -C temp 14 | move temp/crg_model/crg_model/crg_model.pt models/crg/crg_model.pt 15 | 16 | wget http://tts.speech.cs.cmu.edu/content_transfer/sentencepieceModel.zip -O temp/bpe.zip 17 | tar -xf temp/bpe.zip -C temp/bpe 18 | move temp/bpe/sentencepieceModel/bpeM.vocab models/crg/bpeM.vocab 19 | move temp/bpe/sentencepieceModel/bpeM.model models/crg/bpeM.model -------------------------------------------------------------------------------- /src/cmr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MixingBoard/4383b382ffd9493007ac484c54d2aea2b64762a8/src/cmr/__init__.py -------------------------------------------------------------------------------- /src/cmr/batcher.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import torch 5 | import random 6 | import string 7 | import logging 8 | import numpy as np 9 | import pickle as pkl 10 | from shutil import copyfile 11 | from .my_utils.tokenizer import Vocabulary 12 | 13 | def load_meta(opt, meta_path): 14 | with open(meta_path, 'rb') as f: 15 | meta = pkl.load(f) 16 | embedding = torch.Tensor(meta['embedding']) 17 | opt['pos_vocab_size'] = len(meta['vocab_tag.tok2ind']) 18 | opt['ner_vocab_size'] = len(meta['vocab_ner.tok2ind']) 19 | opt['vocab_size'] = len(meta['vocab.tok2ind']) 20 | vocab = Vocabulary(meta['vocab.neat']) 21 | vocab.ind2tok = meta['vocab.ind2tok'] 22 | vocab.tok2ind = meta['vocab.tok2ind'] 23 | return embedding, opt, vocab 24 | 25 | 26 | def prepare_batch_data(batch, ground_truth=True): 27 | batch_size = len(batch) 28 | batch_dict = {} 29 | 30 | doc_len = max(len(x['doc_tok']) for x in batch) 31 | if ground_truth: 32 | ans_len = max(len(x['answer_tok']) for x in batch) 33 | # feature vector 34 | feature_len = len(eval(batch[0]['doc_fea'])[0]) if len(batch[0].get('doc_fea', [])) > 0 else 1 35 | doc_id = torch.LongTensor(batch_size, doc_len).fill_(0) 36 | doc_tag = torch.LongTensor(batch_size, doc_len).fill_(0) 37 | doc_ent = torch.LongTensor(batch_size, doc_len).fill_(0) 38 | doc_feature = torch.Tensor(batch_size, doc_len, feature_len).fill_(0) 39 | if ground_truth: 40 | doc_ans = torch.LongTensor(batch_size, ans_len + 2).fill_(0) 41 | 42 | for i, sample in enumerate(batch): 43 | select_len = min(len(sample['doc_tok']), doc_len) 44 | if select_len ==0: 45 | continue 46 | doc_id[i, :select_len] = torch.LongTensor(sample['doc_tok'][:select_len]) 47 | if ground_truth: 48 | answer_tok_ori = sample['answer_tok'] 49 | answer_tok = [2] + answer_tok_ori + [3] 50 | doc_ans[i, :len(answer_tok)] = torch.LongTensor(answer_tok) 51 | 52 | query_len = max(len(x['query_tok']) for x in batch) 53 | query_id = torch.LongTensor(batch_size, query_len).fill_(0) 54 | 55 | for i, sample in enumerate(batch): 56 | select_len = 
min(len(sample['query_tok']), query_len) 57 | if select_len == 0: 58 | continue 59 | query_id[i, :len(sample['query_tok'])] = torch.LongTensor(sample['query_tok'][:select_len]) 60 | 61 | doc_mask = torch.eq(doc_id, 0) 62 | query_mask = torch.eq(query_id, 0) 63 | if ground_truth: 64 | ans_mask = torch.eq(doc_ans, 0) 65 | 66 | batch_dict['doc_tok'] = doc_id 67 | batch_dict['doc_pos'] = doc_tag 68 | batch_dict['doc_ner'] = doc_ent 69 | batch_dict['doc_fea'] = doc_feature 70 | batch_dict['query_tok'] = query_id 71 | batch_dict['doc_mask'] = doc_mask 72 | batch_dict['query_mask'] = query_mask 73 | if ground_truth: 74 | batch_dict['answer_token'] = doc_ans 75 | batch_dict['ans_mask'] = ans_mask 76 | 77 | return batch_dict 78 | 79 | 80 | class BatchGen: 81 | def __init__(self, data_path, batch_size, gpu, is_train=True, doc_maxlen=100): 82 | self.batch_size = batch_size 83 | self.doc_maxlen = doc_maxlen 84 | self.is_train = is_train 85 | self.gpu = gpu 86 | self.data_path = data_path 87 | self.data = self.load(self.data_path, is_train, doc_maxlen) 88 | if is_train: 89 | indices = list(range(len(self.data))) 90 | random.shuffle(indices) 91 | data = [self.data[i] for i in indices] 92 | data = [self.data[i:i + batch_size] for i in range(0, len(self.data), batch_size)] 93 | self.data = data 94 | self.offset = 0 95 | 96 | def load(self, path, is_train, doc_maxlen=100): 97 | with open(path, 'r', encoding='utf-8') as reader: 98 | data = [] 99 | cnt = 0 100 | for line in reader: 101 | sample = json.loads(line) 102 | cnt += 1 103 | try: 104 | if len(sample['doc_tok']) > doc_maxlen: 105 | sample['doc_tok'] = sample['doc_tok'][:doc_maxlen] 106 | except TypeError: 107 | print(sample['doc_tok']) 108 | print(sample) 109 | raise 110 | 111 | data.append(sample) 112 | print('Loaded {} samples out of {}'.format(len(data), cnt)) 113 | return data 114 | 115 | def reset(self): 116 | if self.is_train: 117 | indices = list(range(len(self.data))) 118 | random.shuffle(indices) 119 | self.data = [self.data[i] for i in indices] 120 | self.offset = 0 121 | 122 | def __len__(self): 123 | return len(self.data) 124 | 125 | def __iter__(self): 126 | while self.offset < len(self): 127 | batch = self.data[self.offset] 128 | 129 | # Convert data into model-ready format 130 | batch_dict = prepare_batch_data(batch) 131 | 132 | if self.gpu: 133 | for k, v in batch_dict.items(): 134 | batch_dict[k] = v.pin_memory() 135 | self.offset += 1 136 | 137 | yield batch_dict 138 | -------------------------------------------------------------------------------- /src/cmr/common.py: -------------------------------------------------------------------------------- 1 | from torch.nn.functional import tanh, relu, prelu, leaky_relu, sigmoid, elu, selu 2 | from torch.nn.init import uniform, normal, eye, xavier_uniform, xavier_normal, kaiming_uniform, kaiming_normal, orthogonal 3 | 4 | def linear(x): 5 | return x 6 | 7 | def activation(func_a): 8 | """Activation function wrapper 9 | """ 10 | return eval(func_a) 11 | 12 | def init_wrapper(init='xavier_uniform'): 13 | return eval(init) 14 | -------------------------------------------------------------------------------- /src/cmr/config.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # /usr/bin/env python3 4 | import argparse 5 | import multiprocessing 6 | import torch 7 | import json 8 | 9 | """ 10 | Configuration file 11 | """ 12 | 13 | 14 | def model_config(parser): 15 | parser.add_argument('--vocab_size', type=int, default=0) 16 | 
parser.add_argument('--wemb_dim', type=int, default=300) 17 | parser.add_argument('--covec_on', action='store_false') 18 | parser.add_argument('--embedding_dim', type=int, default=300) 19 | 20 | # pos 21 | parser.add_argument('--no_pos', dest='pos_on', action='store_false') 22 | parser.add_argument('--pos_vocab_size', type=int, default=56) 23 | parser.add_argument('--pos_dim', type=int, default=12) 24 | parser.add_argument('--no_ner', dest='ner_on', action='store_false') 25 | parser.add_argument('--ner_vocab_size', type=int, default=19) 26 | parser.add_argument('--ner_dim', type=int, default=8) 27 | parser.add_argument('--no_feat', dest='feat_on', action='store_false') 28 | parser.add_argument('--num_features', type=int, default=4) 29 | # q->p 30 | parser.add_argument('--prealign_on', action='store_false') 31 | parser.add_argument('--prealign_head', type=int, default=1) 32 | parser.add_argument('--prealign_att_dropout', type=float, default=0) 33 | parser.add_argument('--prealign_norm_on', action='store_true') 34 | parser.add_argument('--prealign_proj_on', action='store_true') 35 | parser.add_argument('--prealign_bidi', action='store_true') 36 | parser.add_argument('--prealign_hidden_size', type=int, default=64) 37 | parser.add_argument('--prealign_share', action='store_false') 38 | parser.add_argument('--prealign_residual_on', action='store_true') 39 | parser.add_argument('--prealign_scale_on', action='store_false') 40 | parser.add_argument('--prealign_sim_func', type=str, default='dotproductproject') 41 | parser.add_argument('--prealign_activation', type=str, default='relu') 42 | parser.add_argument('--pwnn_on', action='store_false') 43 | parser.add_argument('--pwnn_hidden_size', type=int, default=64) 44 | 45 | ##contextual encoding 46 | parser.add_argument('--contextual_hidden_size', type=int, default=64) 47 | parser.add_argument('--contextual_cell_type', type=str, default='lstm') 48 | parser.add_argument('--contextual_weight_norm_on', action='store_true') 49 | parser.add_argument('--contextual_maxout_on', action='store_true') 50 | parser.add_argument('--contextual_residual_on', action='store_true') 51 | parser.add_argument('--contextual_encoder_share', action='store_true') 52 | parser.add_argument('--contextual_num_layers', type=int, default=2) 53 | 54 | ## mem setting 55 | parser.add_argument('--msum_hidden_size', type=int, default=64) 56 | parser.add_argument('--msum_cell_type', type=str, default='lstm') 57 | parser.add_argument('--msum_weight_norm_on', action='store_true') 58 | parser.add_argument('--msum_maxout_on', action='store_true') 59 | parser.add_argument('--msum_residual_on', action='store_true') 60 | parser.add_argument('--msum_lexicon_input_on', action='store_true') 61 | parser.add_argument('--msum_num_layers', type=int, default=1) 62 | 63 | # attention 64 | parser.add_argument('--deep_att_lexicon_input_on', action='store_false') 65 | parser.add_argument('--deep_att_hidden_size', type=int, default=64) 66 | parser.add_argument('--deep_att_sim_func', type=str, default='dotproductproject') 67 | parser.add_argument('--deep_att_activation', type=str, default='relu') 68 | parser.add_argument('--deep_att_norm_on', action='store_false') 69 | parser.add_argument('--deep_att_proj_on', action='store_true') 70 | parser.add_argument('--deep_att_residual_on', action='store_true') 71 | parser.add_argument('--deep_att_share', action='store_false') 72 | parser.add_argument('--deep_att_opt', type=int, default=0) 73 | 74 | # self attn 75 | parser.add_argument('--self_attention_on', 
action='store_false') 76 | parser.add_argument('--self_att_hidden_size', type=int, default=64) 77 | parser.add_argument('--self_att_sim_func', type=str, default='dotproductproject') 78 | parser.add_argument('--self_att_activation', type=str, default='relu') 79 | parser.add_argument('--self_att_norm_on', action='store_true') 80 | parser.add_argument('--self_att_proj_on', action='store_true') 81 | parser.add_argument('--self_att_residual_on', action='store_true') 82 | parser.add_argument('--self_att_dropout', type=float, default=0.1) 83 | parser.add_argument('--self_att_drop_diagonal', action='store_false') 84 | parser.add_argument('--self_att_share', action='store_false') 85 | 86 | # query summary 87 | parser.add_argument('--query_sum_att_type', type=str, default='linear', 88 | help='linear/mlp') 89 | parser.add_argument('--query_sum_norm_on', action='store_true') 90 | parser.add_argument('--san_on', action='store_true') 91 | parser.add_argument('--max_len', type=int, default=30) 92 | parser.add_argument('--decoder_hidden_size', type=int, default=512) 93 | parser.add_argument('--decoder_ptr_update_on', action='store_true') 94 | parser.add_argument('--decoder_num_turn', type=int, default=5) 95 | parser.add_argument('--decoder_mem_type', type=int, default=3) 96 | parser.add_argument('--decoder_mem_drop_p', type=float, default=0.2) 97 | parser.add_argument('--decoder_opt', type=int, default=0) 98 | parser.add_argument('--decoder_att_type', type=str, default='bilinear', 99 | help='bilinear/simple/default') 100 | parser.add_argument('--decoder_rnn_type', type=str, default='gru', 101 | help='rnn/gru/lstm') 102 | parser.add_argument('--decoder_sum_att_type', type=str, default='bilinear', 103 | help='bilinear/simple/default') 104 | parser.add_argument('--decoder_weight_norm_on', action='store_true') 105 | return parser 106 | 107 | 108 | def data_config(parser): 109 | parser.add_argument('--log_file', default='./log/reddit.log', help='path for log file.') 110 | parser.add_argument('--data_dir', default='data') 111 | parser.add_argument('--raw_data_dir', default='./raw_data') 112 | parser.add_argument('--meta', default='models/cmr/new_meta.pick', help='path to preprocessed meta file.') 113 | parser.add_argument('--train_data', default='train_100k.json', 114 | help='path to preprocessed training data file.') 115 | parser.add_argument('--dev_data', default='dev_100k.json', 116 | help='path to preprocessed validation data file.') 117 | parser.add_argument('--dev_gold', default='dev_seq_answer', 118 | help='path to preprocessed validation data file.') 119 | parser.add_argument('--covec_path', default='models/cmr/MT-LSTM.pt') 120 | parser.add_argument('--glove', default='data_processing/glove.840B.300d.txt', 121 | help='path to word vector file.') 122 | parser.add_argument('--glove_dim', type=int, default=300, 123 | help='word vector dimension.') 124 | parser.add_argument('--sort_all', action='store_true', 125 | help='sort the vocabulary by frequencies of all words.' 
126 | 'Otherwise consider question words first.') 127 | parser.add_argument('--threads', type=int, default=multiprocessing.cpu_count(), 128 | help='number of threads for preprocessing.') 129 | parser.add_argument('--dev_full', default='dev.full') 130 | parser.add_argument('--test_full', default='test.full') 131 | parser.add_argument('--test_data', default='test.json') 132 | 133 | parser.add_argument('--test_output', default='test_output') 134 | return parser 135 | 136 | 137 | def train_config(parser): 138 | parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available(), 139 | help='whether to use GPU acceleration.') 140 | parser.add_argument('--log_per_updates', type=int, default=150) 141 | parser.add_argument('--epoches', type=int, default=400) 142 | parser.add_argument('--eval_step', type=int, default=3000) 143 | parser.add_argument('--batch_size', type=int, default=32) 144 | parser.add_argument('--resume', type=str, default='') 145 | parser.add_argument('--optimizer', default='adam', 146 | help='supported optimizer: adamax, sgd, adadelta, adam') 147 | parser.add_argument('--grad_clipping', type=float, default=5) 148 | parser.add_argument('--weight_decay', type=float, default=0) 149 | parser.add_argument('--learning_rate', type=float, default=0.002) 150 | parser.add_argument('--momentum', type=float, default=0.9) 151 | parser.add_argument('--vb_dropout', action='store_false') 152 | parser.add_argument('--dropout_p', type=float, default=0.4) 153 | parser.add_argument('--dropout_emb', type=float, default=0.4) 154 | parser.add_argument('--dropout_w', type=float, default=0.05) 155 | parser.add_argument('--unk_id', type=int, default=1) 156 | parser.add_argument('--decoding', type=str, default='greedy', help='greedy/sample') 157 | parser.add_argument('--temperature', type=float, default=1.0) 158 | parser.add_argument('--top_k', type=int, default=1) 159 | parser.add_argument('--if_train', type=int, default=1) 160 | parser.add_argument('--curve_file', type=str, default='dev_curve.csv') 161 | parser.add_argument('--smooth', type=int, default=-1) 162 | parser.add_argument('--max_doc', type=int, default=100) 163 | parser.add_argument('--is_rep', type=float, default=0) 164 | parser.add_argument('--decoding_topk', type=int, default=8) 165 | parser.add_argument('--decoding_bleu_lambda', type=float, default=0.5) 166 | parser.add_argument('--decoding_bleu_normalize', action='store_true') 167 | parser.add_argument('--model_type', type=str, default='san', help='[san|seq2seq|memnet]') 168 | parser.add_argument('--weight_type', type=str, default='bleu', help='[bleu|nist]') 169 | parser.add_argument('--no_lr_scheduler', dest='have_lr_scheduler', action='store_true') 170 | parser.add_argument('--multi_step_lr', type=str, default='10,20,30') 171 | parser.add_argument('--lr_gamma', type=float, default=0.5) 172 | parser.add_argument('--scheduler_type', type=str, default='ms', help='ms/rop/exp') 173 | parser.add_argument('--fix_embeddings', action='store_true', help='if true, `tune_partial` will be ignored.') 174 | parser.add_argument('--tune_partial', type=int, default=1000, 175 | help='finetune top-x embeddings (including , ).') 176 | parser.add_argument('--model_dir', default='models/cmr/san_checkpoint.pt') 177 | parser.add_argument('--seed', type=int, default=2018, 178 | help='random seed for data shuffling, embedding init, etc.') 179 | return parser 180 | 181 | 182 | def decoding_config(parser): 183 | parser.add_argument('--skip_tokens_file', type=str, default="") 184 | 
parser.add_argument('--skip_tokens_first_file', type=str, default="") 185 | return parser 186 | 187 | 188 | def set_args(): 189 | parser = argparse.ArgumentParser() 190 | parser = data_config(parser) 191 | parser = model_config(parser) 192 | parser = train_config(parser) 193 | parser = decoding_config(parser) 194 | args = parser.parse_args() 195 | return args 196 | 197 | 198 | def args2json(args, path): 199 | d = args.__dict__ 200 | s = json.dumps(d) 201 | with open(path, 'w', encoding='utf-8') as f: 202 | f.write(s) 203 | print('args saved to: '+path) 204 | 205 | 206 | 207 | 208 | if __name__ == "__main__": 209 | # python src/cmr_config.py --pwnn_on --no_pos --no_ner --no_feat 210 | args = set_args() 211 | print(args) 212 | path = 'models/cmr/args.json' 213 | args2json(args, path) -------------------------------------------------------------------------------- /src/cmr/dreader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from .recurrent import OneLayerBRNN, ContextualEmbed 8 | from .dropout_wrapper import DropoutWrapper 9 | from .encoder import LexiconEncoder 10 | from .similarity import DeepAttentionWrapper, FlatSimilarityWrapper, SelfAttnWrapper 11 | from .similarity import AttentionWrapper 12 | from .san_decoder import SANDecoder 13 | 14 | class DNetwork(nn.Module): 15 | """Network for SAN doc reader.""" 16 | def __init__(self, opt, embedding=None, padding_idx=0): 17 | super(DNetwork, self).__init__() 18 | my_dropout = DropoutWrapper(opt['dropout_p'], opt['vb_dropout']) 19 | self.dropout = my_dropout 20 | 21 | self.lexicon_encoder = LexiconEncoder(opt, embedding=embedding, dropout=my_dropout) 22 | query_input_size = self.lexicon_encoder.query_input_size 23 | doc_input_size = self.lexicon_encoder.doc_input_size 24 | 25 | covec_size = self.lexicon_encoder.covec_size 26 | embedding_size = self.lexicon_encoder.embedding_dim 27 | 28 | # share net 29 | contextual_share = opt.get('contextual_encoder_share', False) 30 | prefix = 'contextual' 31 | prefix = 'contextual' 32 | 33 | # doc_hidden_size 34 | self.doc_encoder_low = OneLayerBRNN(doc_input_size + covec_size, opt['contextual_hidden_size'], prefix=prefix, opt=opt, dropout=my_dropout) 35 | self.doc_encoder_high = OneLayerBRNN(self.doc_encoder_low.output_size + covec_size, opt['contextual_hidden_size'], prefix=prefix, opt=opt, dropout=my_dropout) 36 | if contextual_share: 37 | self.query_encoder_low = self.doc_encoder_low 38 | self.query_encoder_high = self.doc_encoder_high 39 | else: 40 | self.query_encoder_low = OneLayerBRNN(query_input_size + covec_size, opt['contextual_hidden_size'], prefix=prefix, opt=opt, dropout=my_dropout) 41 | self.query_encoder_high = OneLayerBRNN(self.query_encoder_low.output_size + covec_size, opt['contextual_hidden_size'], prefix=prefix, opt=opt, dropout=my_dropout) 42 | 43 | doc_hidden_size = self.doc_encoder_low.output_size + self.doc_encoder_high.output_size 44 | query_hidden_size = self.query_encoder_low.output_size + self.query_encoder_high.output_size 45 | 46 | self.query_understand = OneLayerBRNN(query_hidden_size, opt['msum_hidden_size'], prefix='msum', opt=opt, dropout=my_dropout) 47 | doc_attn_size = doc_hidden_size + covec_size + embedding_size 48 | query_attn_size = query_hidden_size + covec_size + embedding_size 49 | num_layers = 3 50 | 51 | prefix = 'deep_att' 52 | self.deep_attn = 
DeepAttentionWrapper(doc_attn_size, query_attn_size, num_layers, prefix, opt, my_dropout) 53 | 54 | doc_und_size = doc_hidden_size + query_hidden_size + self.query_understand.output_size 55 | self.doc_understand = OneLayerBRNN(doc_und_size, opt['msum_hidden_size'], prefix='msum', opt=opt, dropout=my_dropout) 56 | query_mem_hidden_size = self.query_understand.output_size 57 | doc_mem_hidden_size = self.doc_understand.output_size 58 | 59 | if opt['self_attention_on']: 60 | att_size = embedding_size + covec_size + doc_hidden_size + query_hidden_size + self.query_understand.output_size + self.doc_understand.output_size 61 | self.doc_self_attn = AttentionWrapper(att_size, att_size, prefix='self_att', opt=opt, dropout=my_dropout) 62 | doc_mem_hidden_size = doc_mem_hidden_size * 2 63 | self.doc_mem_gen = OneLayerBRNN(doc_mem_hidden_size, opt['msum_hidden_size'], 'msum', opt, my_dropout) 64 | doc_mem_hidden_size = self.doc_mem_gen.output_size 65 | 66 | # Question merging 67 | self.query_sum_attn = SelfAttnWrapper(query_mem_hidden_size, prefix='query_sum', opt=opt, dropout=my_dropout) 68 | self.decoder = SANDecoder(doc_mem_hidden_size, query_mem_hidden_size, opt, prefix='decoder', dropout=my_dropout) 69 | self.opt = opt 70 | 71 | self.hidden_size = self.query_understand.output_size 72 | 73 | def forward(self, batch): 74 | doc_input, query_input,\ 75 | doc_emb, query_emb,\ 76 | doc_cove_low, doc_cove_high,\ 77 | query_cove_low, query_cove_high,\ 78 | doc_mask, query_mask = self.lexicon_encoder(batch) 79 | 80 | query_list, doc_list = [], [] 81 | query_list.append(query_input) 82 | doc_list.append(doc_input) 83 | 84 | # doc encode 85 | doc_low = self.doc_encoder_low(torch.cat([doc_input, doc_cove_low], 2), doc_mask) 86 | doc_low = self.dropout(doc_low) 87 | doc_high = self.doc_encoder_high(torch.cat([doc_low, doc_cove_high], 2), doc_mask) 88 | doc_high = self.dropout(doc_high) 89 | # query 90 | query_low = self.query_encoder_low(torch.cat([query_input, query_cove_low], 2), query_mask) 91 | query_low = self.dropout(query_low) 92 | query_high = self.query_encoder_high(torch.cat([query_low, query_cove_high], 2), query_mask) 93 | query_high = self.dropout(query_high) 94 | 95 | query_mem_hiddens = self.query_understand(torch.cat([query_low, query_high], 2), query_mask) 96 | query_mem_hiddens = self.dropout(query_mem_hiddens) 97 | query_list = [query_low, query_high, query_mem_hiddens] 98 | doc_list = [doc_low, doc_high] 99 | 100 | query_att_input = torch.cat([query_emb, query_cove_high, query_low, query_high], 2) 101 | doc_att_input = torch.cat([doc_emb, doc_cove_high] + doc_list, 2) 102 | doc_attn_hiddens = self.deep_attn(doc_att_input, query_att_input, query_list, query_mask) 103 | doc_attn_hiddens = self.dropout(doc_attn_hiddens) 104 | doc_mem_hiddens = self.doc_understand(torch.cat([doc_attn_hiddens] + doc_list, 2), doc_mask) 105 | doc_mem_hiddens = self.dropout(doc_mem_hiddens) 106 | doc_mem_inputs = torch.cat([doc_attn_hiddens] + doc_list, 2) 107 | if self.opt['self_attention_on']: 108 | doc_att = torch.cat([doc_mem_inputs, doc_mem_hiddens, doc_cove_high, doc_emb], 2) 109 | doc_self_hiddens = self.doc_self_attn(doc_att, doc_att, doc_mask, x3=doc_mem_hiddens) 110 | doc_mem = self.doc_mem_gen(torch.cat([doc_mem_hiddens, doc_self_hiddens], 2), doc_mask) 111 | else: 112 | doc_mem = doc_mem_hiddens 113 | query_mem = self.query_sum_attn(query_mem_hiddens, query_mask) 114 | return doc_mem, query_mem, doc_mask 115 | -------------------------------------------------------------------------------- 
/src/cmr/dreader_seq2seq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from .recurrent import OneLayerBRNN, ContextualEmbed 8 | from .dropout_wrapper import DropoutWrapper 9 | from .encoder import LexiconEncoder 10 | from .similarity import DeepAttentionWrapper, FlatSimilarityWrapper, SelfAttnWrapper 11 | from .similarity import AttentionWrapper 12 | from .san_decoder import SANDecoder 13 | 14 | class DNetwork_Seq2seq(nn.Module): 15 | """Network for Seq2seq/Memnet doc reader.""" 16 | def __init__(self, opt, embedding=None, padding_idx=0): 17 | super(DNetwork_Seq2seq, self).__init__() 18 | my_dropout = DropoutWrapper(opt['dropout_p'], opt['vb_dropout']) 19 | self.dropout = my_dropout 20 | 21 | self.lexicon_encoder = LexiconEncoder(opt, embedding=embedding, dropout=my_dropout) 22 | query_input_size = self.lexicon_encoder.query_input_size 23 | doc_input_size = self.lexicon_encoder.doc_input_size 24 | 25 | print('Lexicon encoding size for query and doc are:{}', doc_input_size, query_input_size) 26 | covec_size = self.lexicon_encoder.covec_size 27 | embedding_size = self.lexicon_encoder.embedding_dim 28 | # share net 29 | contextual_share = opt.get('contextual_encoder_share', False) 30 | prefix = 'contextual' 31 | prefix = 'contextual' 32 | # doc_hidden_size 33 | self.doc_encoder_low = OneLayerBRNN(doc_input_size + covec_size, opt['contextual_hidden_size'], prefix=prefix, opt=opt, dropout=my_dropout) 34 | self.doc_encoder_high = OneLayerBRNN(self.doc_encoder_low.output_size + covec_size, opt['contextual_hidden_size'], prefix=prefix, opt=opt, dropout=my_dropout) 35 | if contextual_share: 36 | self.query_encoder_low = self.doc_encoder_low 37 | self.query_encoder_high = self.doc_encoder_high 38 | else: 39 | self.query_encoder_low = OneLayerBRNN(query_input_size + covec_size, opt['contextual_hidden_size'], prefix=prefix, opt=opt, dropout=my_dropout) 40 | self.query_encoder_high = OneLayerBRNN(self.query_encoder_low.output_size + covec_size, opt['contextual_hidden_size'], prefix=prefix, opt=opt, dropout=my_dropout) 41 | 42 | doc_hidden_size = self.doc_encoder_low.output_size + self.doc_encoder_high.output_size 43 | query_hidden_size = self.query_encoder_low.output_size + self.query_encoder_high.output_size 44 | 45 | self.query_understand = OneLayerBRNN(query_hidden_size, opt['msum_hidden_size'], prefix='msum', opt=opt, dropout=my_dropout) 46 | doc_attn_size = doc_hidden_size + covec_size + embedding_size 47 | query_attn_size = query_hidden_size + covec_size + embedding_size 48 | num_layers = 3 49 | 50 | prefix = 'deep_att' 51 | self.deep_attn = DeepAttentionWrapper(doc_attn_size, query_attn_size, num_layers, prefix, opt, my_dropout) 52 | 53 | doc_und_size = doc_hidden_size + query_hidden_size + self.query_understand.output_size 54 | self.doc_understand = OneLayerBRNN(doc_und_size, opt['msum_hidden_size'], prefix='msum', opt=opt, dropout=my_dropout) 55 | query_mem_hidden_size = self.query_understand.output_size 56 | doc_mem_hidden_size = self.doc_understand.output_size 57 | 58 | if opt['self_attention_on']: 59 | att_size = embedding_size + covec_size + doc_hidden_size + query_hidden_size + self.query_understand.output_size + self.doc_understand.output_size 60 | self.doc_self_attn = AttentionWrapper(att_size, att_size, prefix='self_att', opt=opt, dropout=my_dropout) 61 | doc_mem_hidden_size = 
doc_mem_hidden_size * 2 62 | self.doc_mem_gen = OneLayerBRNN(doc_mem_hidden_size, opt['msum_hidden_size'], 'msum', opt, my_dropout) 63 | doc_mem_hidden_size = self.doc_mem_gen.output_size 64 | # Question merging 65 | self.query_sum_attn = SelfAttnWrapper(query_mem_hidden_size, prefix='query_sum', opt=opt, dropout=my_dropout) 66 | self.decoder = SANDecoder(doc_mem_hidden_size, query_mem_hidden_size, opt, prefix='decoder', dropout=my_dropout) 67 | self.opt = opt 68 | self.hidden_size = self.query_understand.output_size 69 | self.gru = nn.GRUCell(embedding_size, self.hidden_size) 70 | self.embedding = nn.Embedding(opt['vocab_size'], embedding_size, padding_idx=0) 71 | self.memA = nn.Linear(embedding_size, self.hidden_size, bias=False) 72 | self.memC = nn.Linear(embedding_size, self.hidden_size, bias=False) 73 | 74 | def forward(self, input, hidden): 75 | hidden = self.gru(input, hidden) 76 | return hidden 77 | 78 | def initHidden(self, batch_size): 79 | return torch.zeros(batch_size, self.hidden_size) 80 | 81 | def _get_doc_sentence_embeddings(self, batch): 82 | sentence_end_tok_ids = [' 5 ', ' 157 ', ' 80 ', ' 180 '] 83 | 84 | batch_size = len(batch['doc_tok']) 85 | 86 | doc_sents = [] 87 | max_num_sents = -1 88 | max_sent_len = -1 89 | max_doc_len = -1 90 | 91 | doc = Variable(torch.cat([batch['doc_tok'], torch.LongTensor([[2, 5, 1, 2]] * batch_size)], dim=1)).data.numpy() 92 | for doc_i in doc: # i-th example, e.g., [9 9 5 8 8] 93 | 94 | max_doc_len = max(max_doc_len, len(doc_i)) 95 | 96 | doc_i_str = ' '.join([str(_) for _ in doc_i]) # '9 9 5 8 8' 97 | for se in sentence_end_tok_ids: 98 | doc_i_str = doc_i_str.replace(se, '') # ['9 9 8 8'] 99 | doc_i_str_split = doc_i_str.split('') # ['9 9', '8 8'] 100 | doc_i_str_toks = [_.strip().split() for _ in doc_i_str_split] # [['9', '9'], ['8', '8']] 101 | 102 | num_sent = len(doc_i_str_toks) 103 | max_num_sents = max(num_sent, max_num_sents) 104 | 105 | max_sent_len_i = max(len(_) for _ in doc_i_str_toks) 106 | max_sent_len = max(max_sent_len, max_sent_len_i) 107 | 108 | def _doc_ij_str_to_idx(doc_ij_str): 109 | return [int(_) for _ in doc_ij_str] 110 | 111 | doc_sents.append([_doc_ij_str_to_idx(_) for _ in doc_i_str_toks]) 112 | # [..., [[9, 9], [8, 8]], ...] 
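# NOTE (descriptive comment added for readability): the per-sentence token-id lists built above are right-padded with zeros into a (batch_size, max_num_sents, max_sent_len) tensor below, looked up in the word embedding, and mean-pooled over the token dimension, so each document sentence is represented by the average of its token embeddings before being used as memory slots (memA/memC) in add_fact_memory().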
113 | 114 | doc_sents_tensor = torch.LongTensor(batch_size, max_num_sents, max_sent_len).fill_(0) 115 | for i, doc_i in enumerate(doc_sents): 116 | for j, doc_ij in enumerate(doc_i): 117 | sent_len_ij = len(doc_ij) 118 | doc_sents_tensor[i, j, :sent_len_ij] = torch.LongTensor(doc_ij) 119 | 120 | if self.opt['cuda']: 121 | doc_sents_tensor = doc_sents_tensor.cuda() 122 | doc_sents_tensor = Variable(doc_sents_tensor) 123 | doc_sents_emb = self.embedding(doc_sents_tensor.view(batch_size, -1)) 124 | doc_sents_emb = doc_sents_emb.view(batch_size, max_num_sents, max_sent_len, -1) 125 | doc_sents_emb = torch.mean(doc_sents_emb, dim=2) 126 | return doc_sents_emb 127 | 128 | def add_fact_memory(self, query_final_hidden, batch): 129 | doc_emb = self._get_doc_sentence_embeddings(batch) 130 | 131 | m = self.memA(doc_emb) 132 | c = self.memC(doc_emb) 133 | 134 | u = query_final_hidden.unsqueeze(1) 135 | attn = m * u 136 | attn = torch.sum(attn, dim=-1).squeeze() 137 | attn = F.softmax(attn) 138 | mem_hidden = torch.sum(attn.unsqueeze(2) * c, dim=1).squeeze() 139 | return mem_hidden 140 | -------------------------------------------------------------------------------- /src/cmr/dropout_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | class DropoutWrapper(nn.Module): 7 | """ 8 | This is a dropout wrapper which supports the fix mask dropout 9 | by: xiaodl 10 | """ 11 | def __init__(self, dropout_p=0, enable_vbp=True): 12 | super(DropoutWrapper, self).__init__() 13 | """variational dropout means fix dropout mask 14 | ref: https://discuss.pytorch.org/t/dropout-for-rnns/633/11 15 | """ 16 | self.enable_variational_dropout = enable_vbp 17 | self.dropout_p = dropout_p 18 | 19 | def forward(self, x): 20 | """ 21 | :param x: batch * len * input_size 22 | """ 23 | if self.training == False or self.dropout_p == 0: 24 | return x 25 | 26 | if len(x.size()) == 3: 27 | mask = Variable(1.0 / (1-self.dropout_p) * torch.bernoulli((1-self.dropout_p) * (x.data.new(x.size(0), x.size(2)).zero_() + 1)), requires_grad=False) 28 | return mask.unsqueeze(1).expand_as(x) * x 29 | else: 30 | return F.dropout(x, p=self.dropout_p, training=self.training) 31 | -------------------------------------------------------------------------------- /src/cmr/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from torch.nn.utils import weight_norm 7 | from .recurrent import BRNNEncoder, ContextualEmbed 8 | from .dropout_wrapper import DropoutWrapper 9 | from .common import activation 10 | from .similarity import AttentionWrapper 11 | from .sub_layers import PositionwiseNN 12 | 13 | class LexiconEncoder(nn.Module): 14 | def create_embed(self, vocab_size, embed_dim, padding_idx=0): 15 | return nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx) 16 | 17 | def create_word_embed(self, embedding=None, opt={}, prefix='wemb'): 18 | vocab_size = opt.get('vocab_size', 1) 19 | embed_dim = opt.get('{}_dim'.format(prefix), 300) 20 | self.embedding = self.create_embed(vocab_size, embed_dim) 21 | if embedding is not None: 22 | self.embedding.weight.data = embedding 23 | if opt['fix_embeddings'] or opt['tune_partial'] == 0: 24 | opt['fix_embeddings'] = True 25 | opt['tune_partial'] = 0 26 | for p in 
self.embedding.parameters(): 27 | p.requires_grad = False 28 | else: 29 | fixed_embedding = embedding[embedding.size(0)-1:] 30 | 31 | self.register_buffer('fixed_embedding', fixed_embedding) 32 | self.fixed_embedding = fixed_embedding 33 | return embed_dim 34 | 35 | def create_pos_embed(self, opt={}, prefix='pos'): 36 | vocab_size = opt.get('{}_vocab_size'.format(prefix), 56) 37 | embed_dim = opt.get('{}_dim'.format(prefix), 12) 38 | self.pos_embedding = self.create_embed(vocab_size, embed_dim) 39 | return embed_dim 40 | 41 | def create_ner_embed(self, opt={}, prefix='ner'): 42 | vocab_size = opt.get('{}_vocab_size'.format(prefix), 19) 43 | embed_dim = opt.get('{}_dim'.format(prefix), 8) 44 | self.ner_embedding = self.create_embed(vocab_size, embed_dim) 45 | return embed_dim 46 | 47 | def create_cove(self, vocab_size, embedding=None, embed_dim=300, padding_idx=0, opt=None): 48 | self.ContextualEmbed= ContextualEmbed(opt['covec_path'], opt['vocab_size'], embedding=embedding, padding_idx=padding_idx) 49 | return self.ContextualEmbed.output_size 50 | 51 | def create_prealign(self, x1_dim, x2_dim, opt={}, prefix='prealign'): 52 | self.prealign = AttentionWrapper(x1_dim, x2_dim, prefix, opt, self.dropout) 53 | 54 | def __init__(self, opt, pwnn_on=True, embedding=None, padding_idx=0, dropout=None): 55 | super(LexiconEncoder, self).__init__() 56 | doc_input_size = 0 57 | que_input_size = 0 58 | self.dropout = DropoutWrapper(opt['dropout_p']) if dropout == None else dropout 59 | self.dropout_emb = DropoutWrapper(opt['dropout_emb']) 60 | # word embedding 61 | embedding_dim = self.create_word_embed(embedding, opt) 62 | self.embedding_dim = embedding_dim 63 | doc_input_size += embedding_dim 64 | que_input_size += embedding_dim 65 | 66 | # pre-trained contextual vector 67 | covec_size = self.create_cove(opt['vocab_size'], embedding, opt=opt) if opt['covec_on'] else 0 68 | self.covec_size = covec_size 69 | 70 | prealign_size = 0 71 | if opt['prealign_on'] and embedding_dim > 0: 72 | prealign_size = embedding_dim 73 | self.create_prealign(embedding_dim, embedding_dim, opt) 74 | self.prealign_size = prealign_size 75 | pos_size = self.create_pos_embed(opt) if opt['pos_on'] else 0 76 | ner_size = self.create_ner_embed(opt) if opt['ner_on'] else 0 77 | feat_size = opt['num_features'] if opt['feat_on'] else 0 78 | print(feat_size) 79 | doc_hidden_size = embedding_dim + covec_size + prealign_size + pos_size + ner_size + feat_size 80 | que_hidden_size = embedding_dim + covec_size 81 | if opt['prealign_bidi']: 82 | que_hidden_size += prealign_size 83 | self.pwnn_on = pwnn_on 84 | self.opt = opt 85 | if self.pwnn_on: 86 | #print('here: doc_pwnn') 87 | #import pdb; pdb.set_trace() 88 | self.doc_pwnn = PositionwiseNN(doc_hidden_size, opt['pwnn_hidden_size'], dropout) 89 | if doc_hidden_size == que_hidden_size: 90 | self.que_pwnn = self.doc_pwnn 91 | else: 92 | self.que_pwnn = PositionwiseNN(que_hidden_size, opt['pwnn_hidden_size'], dropout) 93 | doc_input_size, que_input_size = opt['pwnn_hidden_size'], opt['pwnn_hidden_size'] 94 | self.doc_input_size = doc_input_size 95 | self.query_input_size = que_input_size 96 | 97 | def patch(self, v): 98 | if self.opt['cuda']: 99 | v = Variable(v.cuda()) 100 | else: 101 | v = Variable(v) 102 | return v 103 | 104 | def forward(self, batch): 105 | drnn_input_list = [] 106 | qrnn_input_list = [] 107 | emb = self.embedding if self.training else self.eval_embed 108 | doc_tok = self.patch(batch['doc_tok']) 109 | doc_mask = self.patch(batch['doc_mask']) 110 | query_tok = 
self.patch(batch['query_tok']) 111 | query_mask = self.patch(batch['query_mask']) 112 | 113 | doc_emb, query_emb = emb(doc_tok), emb(query_tok) 114 | # Dropout on embeddings 115 | if self.opt['dropout_emb'] > 0: 116 | doc_emb = self.dropout_emb(doc_emb) 117 | query_emb = self.dropout_emb(query_emb) 118 | drnn_input_list.append(doc_emb) 119 | qrnn_input_list.append(query_emb) 120 | 121 | doc_cove_low, doc_cove_high = None, None 122 | query_cove_low, query_cove_high = None, None 123 | if self.opt['covec_on']: 124 | doc_cove_low, doc_cove_high = self.ContextualEmbed(doc_tok, doc_mask) 125 | query_cove_low, query_cove_high = self.ContextualEmbed(query_tok, query_mask) 126 | doc_cove_low = self.dropout(doc_cove_low) 127 | doc_cove_high = self.dropout(doc_cove_high) 128 | query_cove_low = self.dropout(query_cove_low) 129 | query_cove_high = self.dropout(query_cove_high) 130 | drnn_input_list.append(doc_cove_low) 131 | qrnn_input_list.append(query_cove_low) 132 | 133 | if self.opt['prealign_on']: 134 | q2d_atten = self.prealign(doc_emb, query_emb, query_mask) 135 | d2q_atten = self.prealign(query_emb, doc_emb, doc_mask) 136 | drnn_input_list.append(q2d_atten) 137 | if self.opt['prealign_bidi']: 138 | qrnn_input_list.append(d2q_atten) 139 | 140 | if self.opt['pos_on']: 141 | doc_pos = self.patch(batch['doc_pos']) 142 | doc_pos_emb = self.pos_embedding(doc_pos) 143 | drnn_input_list.append(doc_pos_emb) 144 | 145 | if self.opt['ner_on']: 146 | doc_ner = self.patch(batch['doc_ner']) 147 | doc_ner_emb = self.ner_embedding(doc_ner) 148 | drnn_input_list.append(doc_ner_emb) 149 | 150 | if self.opt['feat_on']: 151 | doc_fea = self.patch(batch['doc_fea']) 152 | drnn_input_list.append(doc_fea) 153 | 154 | doc_input = torch.cat(drnn_input_list, 2) 155 | query_input = torch.cat(qrnn_input_list, 2) 156 | if self.pwnn_on: 157 | doc_input = self.doc_pwnn(doc_input) 158 | query_input = self.que_pwnn(query_input) 159 | doc_input = self.dropout(doc_input) 160 | query_input = self.dropout(query_input) 161 | return doc_input, query_input, doc_emb, query_emb, doc_cove_low, \ 162 | doc_cove_high, query_cove_low, query_cove_high, doc_mask, query_mask 163 | -------------------------------------------------------------------------------- /src/cmr/fetch_realtime_grounding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # PAge normalization/data extraction logic taken from https://github.com/qkaren/converse_reading_cmr (for consistency) 3 | 4 | import requests 5 | from bs4 import BeautifulSoup 6 | from bs4.element import NavigableString 7 | import re 8 | from nltk.tokenize import TweetTokenizer 9 | from itertools import chain 10 | import pke 11 | 12 | 13 | class GroudingGenerator: 14 | def __init__(self, max_fact_len=12, max_facts_count=500, min_fact_len=8): 15 | self.tokenizer = TweetTokenizer(preserve_case=False) 16 | self.extractor = pke.unsupervised.TopicRank() 17 | self.max_fact_len = max_fact_len 18 | self.max_facts_count = max_facts_count 19 | self.min_fact_len = min_fact_len 20 | 21 | def insert_escaped_tags(self, tags): 22 | """For each tag in "tags", insert contextual tags (e.g.,

) as escaped text 23 | so that these tags are still there when html markup is stripped out.""" 24 | found = False 25 | for tag in tags: 26 | strs = list(tag.strings) 27 | if len(strs) > 0: 28 | l = tag.name 29 | strs[0].parent.insert(0, NavigableString("<"+l+">")) 30 | strs[-1].parent.append(NavigableString("")) 31 | found = True 32 | return found 33 | 34 | def norm_fact(self, t, tokenize=True): 35 | # Minimalistic processing: remove extra space characters 36 | t = re.sub("[ \n\r\t]+", " ", t) 37 | t = t.strip() 38 | if tokenize: 39 | t = " ".join(self.tokenizer.tokenize(t)) 40 | t = t.replace('[ deleted ]','[deleted]'); 41 | # Preprocessing specific to fact 42 | t = self.filter_text(t) 43 | t = re.sub('- wikipedia ', '', t, 1) 44 | t = re.sub(' \[ edit \]', '', t, 1) 45 | t = re.sub('

<h1> navigation menu </h1>

', '', t) 46 | return t 47 | 48 | def norm_article(self, t): 49 | """Minimalistic processing with linebreaking.""" 50 | t = re.sub("\s*\n+\s*","\n", t) 51 | t = re.sub(r'()',r'\1\n', t) 52 | t = re.sub("[ \t]+"," ", t) 53 | t = t.strip() 54 | return t 55 | 56 | def get_wiki_page_url(self, title): 57 | """Search for wiki URL for given topic""" 58 | r = requests.get("https://en.wikipedia.org/w/api.php?action=opensearch&search=%s&limit=1&format=json" % "%20".join(title.split(" "))).json()[3] 59 | if len(r) == 0: 60 | return None 61 | wanted_url = r[0] 62 | main_content = requests.get(wanted_url) 63 | return main_content 64 | 65 | def get_desired_content(self, page_content): 66 | """Return facts extracted from website""" 67 | notext_tags = ['script', 'style'] 68 | important_tags = ['title', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'hr', 'p'] 69 | b = BeautifulSoup(page_content,'html.parser') 70 | # Remove tags whose text we don't care about (javascript, etc.): 71 | for el in b(notext_tags): 72 | el.decompose() 73 | # Delete other unimportant tags, but keep the text: 74 | for tag in b.findAll(True): 75 | if tag.name not in important_tags: 76 | tag.append(' ') 77 | tag.replaceWithChildren() 78 | # All tags left are important (e.g.,

) so add them to the text: 79 | self.insert_escaped_tags(b.find_all(True)) 80 | # Extract facts from html: 81 | t = b.get_text(" ") 82 | t = self.norm_article(t) 83 | facts = [] 84 | for sent in filter(None, t.split("\n")): 85 | if len(sent.split(" ")) >= self.min_fact_len: 86 | facts.append(self.process_fact(sent)) 87 | return self.combine_facts(facts) 88 | 89 | def filter_text(self, text): 90 | #https://stackoverflow.com/questions/4703390/how-to-extract-a-floating-number-from-a-string 91 | text = re.sub(r'[-+]?(\d+([.,]\d*)?|[.,]\d+)([eE][-+]?\d+)?', '', text) 92 | text = re.sub("[\(].*?[\)]", "", text) # [\[\(].*?[\]\)] 93 | text = text.split() 94 | new_text = [] 95 | for x in text: 96 | if 'www' in x or 'http' in x: 97 | continue 98 | new_text.append(x) 99 | return ' '.join(new_text) 100 | 101 | def process_fact(self, fact): 102 | fact = self.filter_text(self.norm_fact(fact)).strip().split() 103 | if len(fact) > 100: 104 | fact = fact[:100] + [''] 105 | return fact 106 | 107 | def combine_facts(self, facts): 108 | facts = facts[:self.max_fact_len] 109 | facts = ' '.join(list(chain(*facts))) 110 | facts = facts.split() 111 | if len(facts) == 0: 112 | facts = ['UNK'] 113 | if len(facts) > self.max_facts_count: 114 | facts = facts[:self.max_facts_count] + [''] 115 | return facts 116 | 117 | def topic_extraction(self, text): 118 | self.extractor.load_document(input=text, language='en') 119 | self.extractor.candidate_selection() 120 | self.extractor.candidate_weighting() 121 | keyphrases = self.extractor.get_n_best(n=10) 122 | return keyphrases 123 | 124 | def get_appropriate_grounding(self, topics): 125 | for topic in topics: 126 | x = self.get_wiki_page_url(topic[0]) 127 | if x: 128 | return x 129 | return None 130 | 131 | def get_grounding_data(self, text): 132 | topics = self.topic_extraction(text) 133 | url = self.get_appropriate_grounding(topics) 134 | grounding = self.get_desired_content(url.content) 135 | return grounding 136 | 137 | 138 | if __name__ == "__main__": 139 | print("Generating grounding data. This may take a while...") 140 | conversation = "hey thhere, what is up? I love Nokia phones." 
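# The calls below exercise the full pipeline defined above: topic_extraction() ranks
# keyphrases with pke's TopicRank, get_appropriate_grounding() takes the first topic that
# resolves to a Wikipedia page via the opensearch API, and get_desired_content() strips
# that page down to a flat list of fact tokens.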
141 | g = GroudingGenerator() 142 | grounding = g.get_grounding_data(conversation) 143 | print(grounding) 144 | -------------------------------------------------------------------------------- /src/cmr/my_optim.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import torch 3 | from torch.nn import Parameter 4 | from functools import wraps 5 | 6 | 7 | class EMA: 8 | """EMA 9 | by: xiaodl 10 | """ 11 | 12 | def __init__(self, model, gamma=0.99): 13 | self.gamma = gamma 14 | self.shadow = deepcopy(list(p.data.new(p.size()).zero_() for p in model.parameters())) 15 | 16 | def update(self, parameters): 17 | gamma = self.gamma 18 | for p, avg in zip(parameters, self.shadow): 19 | avg.mul_(gamma).add_(p.mul(1 - gamma)) 20 | 21 | def copy_out(self): 22 | return self.shadow 23 | 24 | def dump(model): 25 | parameters = deepcopy(list(p.data for p in model.parameters())) 26 | return parameters 27 | 28 | def reset(model, parameters): 29 | model_p = list(p.data for p in model.parameters()) 30 | for m, p in zip(model_p, parameters): 31 | m = deepcopy(p) 32 | 33 | 34 | """ 35 | Adapted from 36 | https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/weight_norm.py 37 | and https://github.com/salesforce/awd-lstm-lm/blob/master/weight_drop.py 38 | """ 39 | 40 | 41 | def _norm(p, dim): 42 | """Computes the norm over all dimensions except dim""" 43 | if dim is None: 44 | return p.norm() 45 | elif dim == 0: 46 | output_size = (p.size(0),) + (1,) * (p.dim() - 1) 47 | return p.contiguous().view(p.size(0), -1).norm(dim=1).view(*output_size) 48 | elif dim == p.dim() - 1: 49 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),) 50 | return p.contiguous().view(-1, p.size(-1)).norm(dim=0).view(*output_size) 51 | else: 52 | return _norm(p.transpose(0, dim), 0).transpose(0, dim) 53 | 54 | 55 | def _dummy(*args, **kwargs): 56 | """ 57 | We need to replace flatten_parameters with a nothing function 58 | """ 59 | return 60 | 61 | 62 | class WeightNorm(torch.nn.Module): 63 | 64 | def __init__(self, weights, dim): 65 | super(WeightNorm, self).__init__() 66 | self.weights = weights 67 | self.dim = dim 68 | 69 | def compute_weight(self, module, name): 70 | g = getattr(module, name + '_g') 71 | v = getattr(module, name + '_v') 72 | return v * (g / _norm(v, self.dim)) 73 | 74 | @staticmethod 75 | def apply(module, weights, dim): 76 | if issubclass(type(module), torch.nn.RNNBase): 77 | module.flatten_parameters = _dummy 78 | if weights is None: # do for all weight params 79 | weights = [w for w in module._parameters.keys() if 'weight' in w] 80 | fn = WeightNorm(weights, dim) 81 | for name in weights: 82 | if hasattr(module, name): 83 | print('Applying weight norm to {} - {}'.format(str(module), name)) 84 | weight = getattr(module, name) 85 | # remove w from parameter list 86 | del module._parameters[name] 87 | # add g and v as new parameters and express w as g/||v|| * v 88 | module.register_parameter( 89 | name + '_g', Parameter(_norm(weight, dim).data)) 90 | module.register_parameter(name + '_v', Parameter(weight.data)) 91 | setattr(module, name, fn.compute_weight(module, name)) 92 | 93 | # recompute weight before every forward() 94 | module.register_forward_pre_hook(fn) 95 | 96 | return fn 97 | 98 | def remove(self, module): 99 | for name in self.weights: 100 | weight = self.compute_weight(module) 101 | delattr(module, name) 102 | del module._parameters[name + '_g'] 103 | del module._parameters[name + '_v'] 104 | module.register_parameter(name, 
Parameter(weight.data)) 105 | 106 | def __call__(self, module, inputs): 107 | for name in self.weights: 108 | setattr(module, name, self.compute_weight(module, name)) 109 | 110 | 111 | def weight_norm(module, weights=None, dim=0): 112 | WeightNorm.apply(module, weights, dim) 113 | return module 114 | -------------------------------------------------------------------------------- /src/cmr/my_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MixingBoard/4383b382ffd9493007ac484c54d2aea2b64762a8/src/cmr/my_utils/__init__.py -------------------------------------------------------------------------------- /src/cmr/my_utils/eval_bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Modifications copyright (C) 2018 Texar 16 | # ============================================================================== 17 | """ 18 | Python implementation of BLEU and smoothed BLEU adapted from: 19 | `https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py` 20 | 21 | This module provides a Python implementation of BLEU and smoothed BLEU. 22 | Smooth BLEU is computed following the method outlined in the paper: 23 | 24 | (Lin et al. 2004) ORANGE: a method for evaluating automatic evaluation 25 | metrics for maching translation. 26 | Chin-Yew Lin, Franz Josef Och. COLING 2004. 27 | """ 28 | 29 | from __future__ import absolute_import 30 | from __future__ import print_function 31 | from __future__ import division 32 | from __future__ import unicode_literals 33 | 34 | import collections 35 | import math 36 | import six 37 | 38 | # pylint: disable=invalid-name, too-many-branches, too-many-locals 39 | # pylint: disable=too-many-arguments 40 | 41 | __all__ = [ 42 | "sentence_bleu", 43 | "corpus_bleu" 44 | ] 45 | 46 | def is_str(x): 47 | """Returns `True` if :attr:`x` is either a str or unicode. Returns `False` 48 | otherwise. 49 | """ 50 | return isinstance(x, six.string_types) 51 | 52 | def _get_ngrams(segment, max_order): 53 | """Extracts all n-grams up to a given maximum order from an input segment. 54 | 55 | Args: 56 | segment: text segment from which n-grams will be extracted. 57 | max_order: maximum length in tokens of the n-grams returned by this 58 | methods. 59 | 60 | Returns: 61 | The Counter containing all n-grams upto max_order in segment 62 | with a count of how many times each n-gram occurred. 
63 | """ 64 | ngram_counts = collections.Counter() 65 | for order in range(1, max_order + 1): 66 | for i in range(0, len(segment) - order + 1): 67 | ngram = tuple(segment[i:i+order]) 68 | ngram_counts[ngram] += 1 69 | return ngram_counts 70 | 71 | def _maybe_str_to_list(list_or_str): 72 | if is_str(list_or_str): 73 | return list_or_str.split() 74 | return list_or_str 75 | 76 | def _lowercase(str_list): 77 | return [str_.lower() for str_ in str_list] 78 | 79 | def sentence_bleu(references, hypothesis, max_order=4, lowercase=False, 80 | smooth=False, return_all=False): 81 | """Calculates BLEU score of a hypothesis sentence. 82 | 83 | Args: 84 | references: A list of reference for the hypothesis. 85 | Each reference can be either a list of string tokens, or a string 86 | containing tokenized tokens separated with whitespaces. 87 | List can also be numpy array. 88 | hypotheses: A hypothesis sentence. 89 | Each hypothesis can be either a list of string tokens, or a 90 | string containing tokenized tokens separated with whitespaces. 91 | List can also be numpy array. 92 | lowercase (bool): If `True`, pass the "-lc" flag to the multi-bleu 93 | script. 94 | max_order (int): Maximum n-gram order to use when computing BLEU score. 95 | smooth (bool): Whether or not to apply (Lin et al. 2004) smoothing. 96 | return_all (bool): If `True`, returns BLEU and all n-gram precisions. 97 | 98 | Returns: 99 | If :attr:`return_all` is `False` (default), returns a float32 100 | BLEU score. 101 | 102 | If :attr:`return_all` is `True`, returns a list of float32 scores: 103 | `[BLEU] + n-gram precisions`, which is of length :attr:`max_order`+1. 104 | """ 105 | return corpus_bleu( 106 | [references], [hypothesis], max_order=max_order, lowercase=lowercase, 107 | smooth=smooth, return_all=return_all) 108 | 109 | def corpus_bleu(list_of_references, hypotheses, max_order=4, lowercase=False, 110 | smooth=False, return_all=True): 111 | """Computes corpus-level BLEU score. 112 | 113 | Args: 114 | list_of_references: A list of lists of references for each hypothesis. 115 | Each reference can be either a list of string tokens, or a string 116 | containing tokenized tokens separated with whitespaces. 117 | List can also be numpy array. 118 | hypotheses: A list of hypothesis sentences. 119 | Each hypothesis can be either a list of string tokens, or a 120 | string containing tokenized tokens separated with whitespaces. 121 | List can also be numpy array. 122 | lowercase (bool): If `True`, lowercase reference and hypothesis tokens. 123 | max_order (int): Maximum n-gram order to use when computing BLEU score. 124 | smooth (bool): Whether or not to apply (Lin et al. 2004) smoothing. 125 | return_all (bool): If `True`, returns BLEU and all n-gram precisions. 126 | 127 | Returns: 128 | If :attr:`return_all` is `False` (default), returns a float32 129 | BLEU score. 130 | 131 | If :attr:`return_all` is `True`, returns a list of float32 scores: 132 | `[BLEU] + n-gram precisions`, which is of length :attr:`max_order`+1. 
133 | """ 134 | list_of_references = list_of_references 135 | hypotheses = hypotheses 136 | 137 | matches_by_order = [0] * max_order 138 | possible_matches_by_order = [0] * max_order 139 | reference_length = 0 140 | hyperthsis_length = 0 141 | for (references, hyperthsis) in zip(list_of_references, hypotheses): 142 | reference_length += min(len(r) for r in references) 143 | hyperthsis_length += len(hyperthsis) 144 | 145 | merged_ref_ngram_counts = collections.Counter() 146 | for reference in references: 147 | reference = _maybe_str_to_list(reference) 148 | if lowercase: 149 | reference = _lowercase(reference) 150 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order) 151 | 152 | hyperthsis = _maybe_str_to_list(hyperthsis) 153 | if lowercase: 154 | hyperthsis = _lowercase(hyperthsis) 155 | hyperthsis_ngram_counts = _get_ngrams(hyperthsis, max_order) 156 | 157 | overlap = hyperthsis_ngram_counts & merged_ref_ngram_counts 158 | for ngram in overlap: 159 | matches_by_order[len(ngram)-1] += overlap[ngram] 160 | for order in range(1, max_order+1): 161 | possible_matches = len(hyperthsis) - order + 1 162 | if possible_matches > 0: 163 | possible_matches_by_order[order-1] += possible_matches 164 | 165 | precisions = [0] * max_order 166 | for i in range(0, max_order): 167 | if smooth: 168 | precisions[i] = ((matches_by_order[i] + 1.) / 169 | (possible_matches_by_order[i] + 1.)) 170 | else: 171 | if possible_matches_by_order[i] > 0: 172 | precisions[i] = (float(matches_by_order[i]) / 173 | possible_matches_by_order[i]) 174 | else: 175 | precisions[i] = 0.0 176 | 177 | if min(precisions) > 0: 178 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) 179 | geo_mean = math.exp(p_log_sum) 180 | else: 181 | geo_mean = 0 182 | 183 | ratio = float(hyperthsis_length) / reference_length 184 | 185 | if ratio > 1.0: 186 | bp = 1. 187 | else: 188 | try: 189 | bp = math.exp(1 - 1. / ratio) 190 | except ZeroDivisionError: 191 | bp = math.exp(1 - 1. / (ratio + 1e-8)) 192 | 193 | bleu = geo_mean 194 | 195 | if return_all: 196 | return [bleu * 100] + [p * 100 for p in precisions] 197 | else: 198 | return bleu * 100 -------------------------------------------------------------------------------- /src/cmr/my_utils/eval_nist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Modifications copyright (C) 2018 Texar 16 | # ============================================================================== 17 | """ 18 | Approximate Python implementation of NIST and smoothed NIST adapted from 19 | `eval_bleu.py`. The adaptation is based on the descrition here: 20 | https://www.nltk.org/api/nltk.translate.html 21 | 22 | which says: 23 | 24 | The main differences between NIST and BLEU are: 25 | 26 | BLEU uses geometric mean of the ngram overlaps, NIST uses arithmetic mean. 
27 | NIST has a different brevity penalty 28 | NIST score from mteval-14.pl has a self-contained tokenizer 29 | """ 30 | 31 | from __future__ import absolute_import 32 | from __future__ import print_function 33 | from __future__ import division 34 | from __future__ import unicode_literals 35 | 36 | import collections 37 | import math 38 | import six 39 | import numpy as np 40 | 41 | # pylint: disable=invalid-name, too-many-branches, too-many-locals 42 | # pylint: disable=too-many-arguments 43 | 44 | __all__ = [ 45 | "sentence_nist", 46 | "corpus_nist" 47 | ] 48 | 49 | def is_str(x): 50 | """Returns `True` if :attr:`x` is either a str or unicode. Returns `False` 51 | otherwise. 52 | """ 53 | return isinstance(x, six.string_types) 54 | 55 | def _get_ngrams(segment, max_order): 56 | """Extracts all n-grams up to a given maximum order from an input segment. 57 | 58 | Args: 59 | segment: text segment from which n-grams will be extracted. 60 | max_order: maximum length in tokens of the n-grams returned by this 61 | methods. 62 | 63 | Returns: 64 | The Counter containing all n-grams upto max_order in segment 65 | with a count of how many times each n-gram occurred. 66 | """ 67 | ngram_counts = collections.Counter() 68 | for order in range(1, max_order + 1): 69 | for i in range(0, len(segment) - order + 1): 70 | ngram = tuple(segment[i:i+order]) 71 | ngram_counts[ngram] += 1 72 | return ngram_counts 73 | 74 | def _maybe_str_to_list(list_or_str): 75 | if is_str(list_or_str): 76 | return list_or_str.split() 77 | return list_or_str 78 | 79 | def _lowercase(str_list): 80 | return [str_.lower() for str_ in str_list] 81 | 82 | def sentence_nist(references, hypothesis, max_order=4, lowercase=False, 83 | smooth=False, return_all=False): 84 | """Calculates nist score of a hypothesis sentence. 85 | 86 | Args: 87 | references: A list of reference for the hypothesis. 88 | Each reference can be either a list of string tokens, or a string 89 | containing tokenized tokens separated with whitespaces. 90 | List can also be numpy array. 91 | hypotheses: A hypothesis sentence. 92 | Each hypothesis can be either a list of string tokens, or a 93 | string containing tokenized tokens separated with whitespaces. 94 | List can also be numpy array. 95 | lowercase (bool): If `True`, pass the "-lc" flag to the multi-nist 96 | script. 97 | max_order (int): Maximum n-gram order to use when computing nist score. 98 | smooth (bool): Whether or not to apply (Lin et al. 2004) smoothing. 99 | return_all (bool): If `True`, returns nist and all n-gram precisions. 100 | 101 | Returns: 102 | If :attr:`return_all` is `False` (default), returns a float32 103 | nist score. 104 | 105 | If :attr:`return_all` is `True`, returns a list of float32 scores: 106 | `[nist] + n-gram precisions`, which is of length :attr:`max_order`+1. 107 | """ 108 | return corpus_nist( 109 | [references], [hypothesis], max_order=max_order, lowercase=lowercase, 110 | smooth=smooth, return_all=return_all) 111 | 112 | def corpus_nist(list_of_references, hypotheses, max_order=4, lowercase=False, 113 | smooth=False, return_all=True): 114 | """Computes corpus-level nist score. 115 | 116 | Args: 117 | list_of_references: A list of lists of references for each hypothesis. 118 | Each reference can be either a list of string tokens, or a string 119 | containing tokenized tokens separated with whitespaces. 120 | List can also be numpy array. 121 | hypotheses: A list of hypothesis sentences. 
122 | Each hypothesis can be either a list of string tokens, or a 123 | string containing tokenized tokens separated with whitespaces. 124 | List can also be numpy array. 125 | lowercase (bool): If `True`, lowercase reference and hypothesis tokens. 126 | max_order (int): Maximum n-gram order to use when computing nist score. 127 | smooth (bool): Whether or not to apply (Lin et al. 2004) smoothing. 128 | return_all (bool): If `True`, returns nist and all n-gram precisions. 129 | 130 | Returns: 131 | If :attr:`return_all` is `False` (default), returns a float32 132 | nist score. 133 | 134 | If :attr:`return_all` is `True`, returns a list of float32 scores: 135 | `[nist] + n-gram precisions`, which is of length :attr:`max_order`+1. 136 | """ 137 | list_of_references = list_of_references 138 | hypotheses = hypotheses 139 | 140 | matches_by_order = [0] * max_order 141 | possible_matches_by_order = [0] * max_order 142 | reference_length = 0 143 | hyperthsis_length = 0 144 | for (references, hyperthsis) in zip(list_of_references, hypotheses): 145 | reference_length += min(len(r) for r in references) 146 | hyperthsis_length += len(hyperthsis) 147 | 148 | merged_ref_ngram_counts = collections.Counter() 149 | for reference in references: 150 | reference = _maybe_str_to_list(reference) 151 | if lowercase: 152 | reference = _lowercase(reference) 153 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order) 154 | 155 | hyperthsis = _maybe_str_to_list(hyperthsis) 156 | if lowercase: 157 | hyperthsis = _lowercase(hyperthsis) 158 | hyperthsis_ngram_counts = _get_ngrams(hyperthsis, max_order) 159 | 160 | overlap = hyperthsis_ngram_counts & merged_ref_ngram_counts 161 | for ngram in overlap: 162 | matches_by_order[len(ngram)-1] += overlap[ngram] 163 | for order in range(1, max_order+1): 164 | possible_matches = len(hyperthsis) - order + 1 165 | if possible_matches > 0: 166 | possible_matches_by_order[order-1] += possible_matches 167 | 168 | precisions = [0] * max_order 169 | for i in range(0, max_order): 170 | if smooth: 171 | precisions[i] = ((matches_by_order[i] + 1.) / 172 | (possible_matches_by_order[i] + 1.)) 173 | else: 174 | if possible_matches_by_order[i] > 0: 175 | precisions[i] = (float(matches_by_order[i]) / 176 | possible_matches_by_order[i]) 177 | else: 178 | precisions[i] = 0.0 179 | 180 | if min(precisions) > 0: 181 | ari_mean = np.mean(precisions) 182 | else: 183 | ari_mean = 0 184 | 185 | ratio = float(hyperthsis_length) / reference_length 186 | 187 | if ratio > 1.0: 188 | bp = 1. 189 | else: 190 | try: 191 | bp = math.exp(1 - 1. / ratio) 192 | except ZeroDivisionError: 193 | bp = math.exp(1 - 1. 
/ (ratio + 1e-8)) 194 | 195 | nist = ari_mean #geo_mean 196 | 197 | if return_all: 198 | return [nist * 100] + [p * 100 for p in precisions] 199 | else: 200 | return nist * 100 201 | 202 | 203 | if __name__ == "__main__": 204 | hypothesis1 = [ 205 | 'It', 'is', 'a', 'guide', 'to', 'action', 'which', 206 | 'ensures', 'that', 'the', 'military', 'always', 207 | 'obeys', 'the', 'commands', 'of', 'the', 'party', 208 | 'It', 'is', 'a', 'guide', 'to', 'action', 'which', 209 | 'ensures', 'that', 'the', 'military', 'always', 210 | 'obeys', 'the', 'commands', 'of', 'the', 'party' 211 | ] 212 | hypothesis2 = [ 213 | 'It', 'is', 'to', 'insure', 'the', 'troops', 214 | 'forever', 'hearing', 'the', 'activity', 'guidebook', 215 | 'that', 'party', 'direct', 216 | 'It', 'is', 'to', 'insure', 'the', 'troops', 217 | 'forever', 'hearing', 'the', 'activity', 'guidebook', 218 | 'that', 'party', 'direct', 219 | ] 220 | reference1 = [ 221 | 'It', 'is', 'a', 'guide', 'to', 'action', 'that', 222 | 'ensures', 'that', 'the', 'military', 'will', 'forever', 223 | 'heed', 'Party', 'commands', 224 | ] 225 | reference2 = [ 226 | 'It', 'is', 'the', 'guiding', 'principle', 'which', 227 | 'guarantees', 'the', 'military', 'forces', 'always', 228 | 'being', 'under', 'the', 'command', 'of', 'the', 229 | 'Party', 230 | ] 231 | reference3 = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the', 232 | 'army', 'always', 'to', 'heed', 'the', 'directions', 233 | 'of', 'the', 'party'] 234 | 235 | refs = [reference1] 236 | 237 | nist_1 = sentence_nist(refs, hypothesis1, smooth=True) 238 | print(nist_1) 239 | 240 | nist_2 = sentence_nist(refs, hypothesis2, smooth=True) 241 | print(nist_2) 242 | 243 | import nltk 244 | nist_1 = nltk.translate.nist_score.sentence_nist(refs, hypothesis1) 245 | print(nist_1) 246 | nist_2 = nltk.translate.nist_score.sentence_nist(refs, hypothesis2) 247 | print(nist_2) 248 | 249 | -------------------------------------------------------------------------------- /src/cmr/my_utils/log_wrapper.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from time import gmtime, strftime 3 | import sys 4 | 5 | def create_logger(name, silent=False, to_disk=False, log_file=None, prefix=None): 6 | """Logger wrapper 7 | by xiaodong liu, xiaodl@microsoft.com 8 | """ 9 | # setup logger 10 | log = logging.getLogger(name) 11 | log.setLevel(logging.DEBUG) 12 | formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S') 13 | if not silent: 14 | ch = logging.StreamHandler(sys.stdout) 15 | ch.setLevel(logging.INFO) 16 | ch.setFormatter(formatter) 17 | log.addHandler(ch) 18 | if to_disk: 19 | prefix = prefix if prefix is not None else 'my_log' 20 | log_file = log_file if log_file is not None else strftime('/log/{}-%Y-%m-%d-%H-%M-%S.log'.format(prefix), gmtime()) 21 | fh = logging.FileHandler(log_file) 22 | fh.setLevel(logging.DEBUG) 23 | fh.setFormatter(formatter) 24 | log.addHandler(fh) 25 | return log 26 | -------------------------------------------------------------------------------- /src/cmr/my_utils/squad_eval.py: -------------------------------------------------------------------------------- 1 | """ Official evaluation script for v1.1 of the SQuAD dataset. 
""" 2 | from __future__ import print_function 3 | from collections import Counter 4 | import string 5 | import re 6 | import argparse 7 | import json 8 | import sys 9 | import subprocess 10 | import numpy as np 11 | 12 | def normalize_answer(s): 13 | """Lower text and remove punctuation, articles and extra whitespace.""" 14 | def remove_articles(text): 15 | return re.sub(r'\b(a|an|the)\b', ' ', text) 16 | 17 | def white_space_fix(text): 18 | return ' '.join(text.split()) 19 | 20 | def remove_punc(text): 21 | exclude = set(string.punctuation) 22 | return ''.join(ch for ch in text if ch not in exclude) 23 | 24 | def lower(text): 25 | return text.lower() 26 | 27 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 28 | 29 | 30 | def f1_score(prediction, ground_truth): 31 | prediction_tokens = normalize_answer(prediction).split() 32 | ground_truth_tokens = normalize_answer(ground_truth).split() 33 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 34 | num_same = sum(common.values()) 35 | if num_same == 0: 36 | return 0 37 | precision = 1.0 * num_same / len(prediction_tokens) 38 | recall = 1.0 * num_same / len(ground_truth_tokens) 39 | f1 = (2 * precision * recall) / (precision + recall) 40 | return f1 41 | 42 | 43 | def exact_match_score(prediction, ground_truth): 44 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 45 | 46 | 47 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 48 | scores_for_ground_truths = [] 49 | for ground_truth in ground_truths: 50 | score = metric_fn(prediction, ground_truth) 51 | scores_for_ground_truths.append(score) 52 | return max(scores_for_ground_truths) 53 | 54 | def evaluate_file(data_path, predictions): 55 | with open(data_path) as dataset_file: 56 | dataset_json = json.load(dataset_file) 57 | if (dataset_json['version'] != expected_version): 58 | print('Evaluation expects v-' + expected_version + 59 | ', but got dataset with v-' + dataset_json['version'], 60 | file=sys.stderr) 61 | ground_truths = [] 62 | for line in dataset_json: 63 | ground_truth = dataset_json['answer_tok'] 64 | ground_truths.append(ground_truth) 65 | return get_bleu(ground_truth, predictions) 66 | 67 | def evaluate(dataset, predictions): 68 | f1 = exact_match = total = 0 69 | for article in dataset: 70 | for paragraph in article['paragraphs']: 71 | for qa in paragraph['qas']: 72 | total += 1 73 | if qa['id'] not in predictions: 74 | message = 'Unanswered question ' + qa['id'] + \ 75 | ' will receive score 0.' 76 | print(message, file=sys.stderr) 77 | continue 78 | ground_truths = list(map(lambda x: x['text'], qa['answers'])) 79 | prediction = predictions[qa['id']] 80 | exact_match += metric_max_over_ground_truths( 81 | exact_match_score, prediction, ground_truths) 82 | f1 += metric_max_over_ground_truths( 83 | f1_score, prediction, ground_truths) 84 | 85 | exact_match = 100.0 * exact_match / total 86 | f1 = 100.0 * f1 / total 87 | return {'exact_match': exact_match, 'f1': f1} 88 | 89 | def bleu(stats): 90 | """Compute BLEU given n-gram statistics.""" 91 | if len(filter(lambda x: x == 0, stats)) > 0: 92 | return 0 93 | (c, r) = stats[:2] 94 | log_bleu_prec = sum( 95 | [math.log(float(x) / y) for x, y in zip(stats[2::2], stats[3::2])] 96 | ) / 4. 
97 | return math.exp(min([0, 1 - float(r) / c]) + log_bleu_prec) 98 | 99 | 100 | def get_bleu(hypotheses, reference): 101 | """Get validation BLEU score for dev set.""" 102 | stats = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) 103 | for hyp, ref in zip(hypotheses, reference): 104 | stats += np.array(bleu_stats(hyp, ref)) 105 | return 100 * bleu(stats) 106 | 107 | 108 | def get_bleu_moses(hypotheses, reference): 109 | """Get BLEU score with moses bleu score.""" 110 | with open('tmp_hypotheses.txt', 'w') as f: 111 | for hypothesis in hypotheses: 112 | f.write(' '.join(hypothesis) + '\n') 113 | 114 | with open('tmp_reference.txt', 'w') as f: 115 | for ref in reference: 116 | f.write(' '.join(ref) + '\n') 117 | 118 | hypothesis_pipe = '\n'.join([' '.join(hyp) for hyp in hypotheses]) 119 | pipe = subprocess.Popen( 120 | ["perl", './bleu_eval/multi-bleu.perl', '-lc', 'tmp_reference.txt'], 121 | stdin=subprocess.PIPE, 122 | stdout=subprocess.PIPE 123 | ) 124 | pipe.stdin.write(hypothesis_pipe.encode()) 125 | pipe.stdin.close() 126 | return pipe.stdout.read() 127 | 128 | def get_bleu_moses_score(hypotheses, reference): 129 | """Get BLEU score with moses bleu score.""" 130 | with open('tmp_hypotheses.txt', 'w') as f: 131 | for hypothesis in hypotheses: 132 | f.write(' '.join(hypothesis) + '\n') 133 | 134 | with open('tmp_reference.txt', 'w') as f: 135 | for ref in reference: 136 | f.write(' '.join(ref) + '\n') 137 | 138 | hypothesis_pipe = '\n'.join([' '.join(hyp) for hyp in hypotheses]) 139 | print(hypothesis_pipe) 140 | print('\n'.join([' '.join(ref) for ref in reference])) 141 | pipe = subprocess.Popen( 142 | ["perl", './bleu_eval/multi-bleu.perl', '-lc', 'tmp_reference.txt'], 143 | stdin=subprocess.PIPE, 144 | stdout=subprocess.PIPE 145 | ) 146 | pipe.stdin.write(hypothesis_pipe.encode()) 147 | pipe.stdin.close() 148 | bleu_str = pipe.stdout.read() 149 | print(bleu_str) 150 | 151 | bleu_str = str(bleu_str) 152 | bleu_score = re.search(r"BLEU = (.+?),", bleu_str).group(1) 153 | bleu_score = np.float32(bleu_score) 154 | return bleu_score 155 | 156 | 157 | if __name__ == '__main__': 158 | expected_version = '1.1' 159 | parser = argparse.ArgumentParser( 160 | description='Evaluation for SQuAD ' + expected_version) 161 | parser.add_argument('dataset_file', help='Dataset file') 162 | parser.add_argument('prediction_file', help='Prediction File') 163 | args = parser.parse_args() 164 | with open(args.dataset_file) as dataset_file: 165 | dataset_json = json.load(dataset_file) 166 | if (dataset_json['version'] != expected_version): 167 | print('Evaluation expects v-' + expected_version + 168 | ', but got dataset with v-' + dataset_json['version'], 169 | file=sys.stderr) 170 | dataset = dataset_json['data'] 171 | with open(args.prediction_file) as prediction_file: 172 | predictions = json.load(prediction_file) 173 | print(json.dumps(evaluate(dataset, predictions))) 174 | -------------------------------------------------------------------------------- /src/cmr/my_utils/tokenizer.py: -------------------------------------------------------------------------------- 1 | import re 2 | import warnings 3 | import spacy 4 | import tqdm 5 | import logging 6 | import unicodedata 7 | from collections import Counter 8 | from functools import partial 9 | from multiprocessing import Pool as ThreadPool 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | PAD = '' 14 | UNK = '' 15 | STA= '' 16 | END = '' 17 | 18 | PAD_ID = 0 19 | UNK_ID = 1 20 | STA_ID = 2 21 | END_ID = 3 22 | 23 | def normalize_text(text): 24 
| return unicodedata.normalize('NFD', text) 25 | 26 | def space_extend(matchobj): 27 | return ' ' + matchobj.group(0) + ' ' 28 | 29 | def reform_text(text): 30 | text = re.sub(u'-|¢|¥|€|£|\u2010|\u2011|\u2012|\u2013|\u2014|\u2015|%|\[|\]|:|\(|\)|/', space_extend, text) 31 | text = text.strip(' \n') 32 | text = re.sub('\s+', ' ', text) 33 | return text 34 | 35 | class Vocabulary(object): 36 | INIT_LEN = 4 37 | def __init__(self, neat=False): 38 | self.neat = neat 39 | if not neat: 40 | self.tok2ind = {PAD: PAD_ID, UNK: UNK_ID, STA: STA_ID, END: END_ID} 41 | self.ind2tok = {PAD_ID: PAD, UNK_ID: UNK, STA_ID: STA, END_ID:END} 42 | else: 43 | self.tok2ind = {} 44 | self.ind2tok = {} 45 | 46 | def __len__(self): 47 | return len(self.tok2ind) 48 | 49 | def __iter__(self): 50 | return iter(self.tok2ind) 51 | 52 | def __contains__(self, key): 53 | if type(key) == int: 54 | return key in self.ind2tok 55 | elif type(key) == str: 56 | return key in self.tok2ind 57 | 58 | def __getitem__(self, key): 59 | if type(key) == int: 60 | return self.ind2tok.get(key, -1) if self.neat else self.ind2tok.get(key, UNK) 61 | if type(key) == str: 62 | return self.tok2ind.get(key, None) if self.neat else self.tok2ind.get(key,self.tok2ind.get(UNK)) 63 | 64 | def __setitem__(self, key, item): 65 | if type(key) == int and type(item) == str: 66 | self.ind2tok[key] = item 67 | elif type(key) == str and type(item) == int: 68 | self.tok2ind[key] = item 69 | else: 70 | raise RuntimeError('Invalid (key, item) types.') 71 | 72 | def add(self, token): 73 | if token not in self.tok2ind: 74 | index = len(self.tok2ind) 75 | self.tok2ind[token] = index 76 | self.ind2tok[index] = token 77 | 78 | def get_vocab_list(self, with_order=True): 79 | if with_order: 80 | words = [self[k] for k in range(0, len(self))] 81 | else: 82 | words = [k for k in self.tok2ind.keys() 83 | if k not in {PAD, UNK, STA, END}] 84 | return words 85 | 86 | def toidx(self, tokens): 87 | return [self[tok] for tok in tokens] 88 | 89 | def copy(self): 90 | """Deep copy 91 | """ 92 | new_vocab = Vocabulary(self.neat) 93 | for w in self: 94 | new_vocab.add(w) 95 | return new_vocab 96 | 97 | def build(words, neat=False): 98 | vocab = Vocabulary(neat) 99 | for w in words: vocab.add(w) 100 | return vocab 101 | -------------------------------------------------------------------------------- /src/cmr/my_utils/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value.""" 7 | def __init__(self): 8 | self.reset() 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | 22 | def set_environment(seed, set_cuda=False): 23 | random.seed(seed) 24 | numpy.random.seed(seed) 25 | torch.manual_seed(seed) 26 | if torch.cuda.is_available() and set_cuda: 27 | torch.cuda.manual_seed_all(seed) 28 | -------------------------------------------------------------------------------- /src/cmr/my_utils/word2vec_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .tokenizer import normalize_text 3 | 4 | def load_glove_vocab(path, glove_dim, wv_dim=300): 5 | vocab = set() 6 | with open(path, encoding="utf8") as f: 7 | for line in f: 8 | elems = line.split() 9 | token 
= normalize_text(' '.join(elems[0:-wv_dim])) 10 | vocab.add(token) 11 | return vocab 12 | 13 | def build_embedding(path, targ_vocab, wv_dim): 14 | vocab_size = len(targ_vocab) 15 | emb = np.zeros((vocab_size, wv_dim)) 16 | emb[0] = 0 17 | count = 0 18 | 19 | w2id = {w: i for i, w in enumerate(targ_vocab)} 20 | 21 | with open(path, encoding="utf8") as f: 22 | for line in f: 23 | elems = line.split() 24 | token = normalize_text(' '.join(elems[0:-wv_dim])) 25 | if token in w2id: 26 | emb[w2id[token]] = [float(v) for v in elems[-wv_dim:]] 27 | count += 1 28 | print ('loading glove done!') 29 | print (count) 30 | return emb 31 | -------------------------------------------------------------------------------- /src/cmr/process_raw_data.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | from itertools import chain 4 | 5 | 6 | def filter_text(text): 7 | #https://stackoverflow.com/questions/4703390/how-to-extract-a-floating-number-from-a-string 8 | text = re.sub(r'[-+]?(\d+([.,]\d*)?|[.,]\d+)([eE][-+]?\d+)?', '', text) 9 | text = re.sub("[\(].*?[\)]", "", text) # [\[\(].*?[\]\)] 10 | text = text.split() 11 | new_text = [] 12 | for x in text: 13 | if 'www' in x or 'http' in x: 14 | continue 15 | new_text.append(x) 16 | return ' '.join(new_text) 17 | 18 | 19 | def filter_query(text, max_len=100): 20 | 'removing the history of query' 21 | if 'EOS' in text: 22 | text = text[text.rindex('EOS')+4:] 23 | text = filter_text(text) 24 | if text.startswith('til'): 25 | text = re.sub('til ', '', text, 1) 26 | if text.startswith('...'): 27 | text = re.sub('...', '', text, 1) 28 | if len(text.split()) > 50 and text.endswith('...'): 29 | text = text[:-3] + '' 30 | text = text.strip().split() 31 | if len(text) > max_len: 32 | text = [''] + text[-max_len:] 33 | return ' '.join(text) 34 | 35 | 36 | def filter_fact(text): 37 | text = filter_text(text) 38 | text = re.sub('- wikipedia ', '', text, 1) 39 | text = re.sub(' \[ edit \]', '', text, 1) 40 | text = re.sub('

<h1> navigation menu </h1>

', '', text) 41 | return text 42 | 43 | 44 | def filter_resp(text): 45 | text = re.sub('\[|\]', '', text) 46 | return filter_text(text) 47 | # anything else? 48 | 49 | 50 | def no_label(fact): 51 | include_labels = [ 52 | '', '<anchor>', '<p>', '<h1>', '<h2>', '<h3>', '<h4>', 53 | '', '', '

', '', '', '', ''] 54 | for il in include_labels: 55 | if il in fact: 56 | return False 57 | return True 58 | 59 | 60 | def write_files(output_path, data): 61 | with open(output_path + '.full', 'w', encoding='utf8') as fw_full: 62 | with open(output_path + '.query', 'w', encoding='utf8') as fw_query: 63 | with open(output_path + '.response', 'w', encoding='utf8') as fw_response: 64 | with open(output_path + '.fact', 'w', encoding='utf8') as fw_fact: 65 | for line in data: 66 | fw_query.write(' '.join(line['query']) + '\n') 67 | fw_response.write(' '.join(line['response']) + '\n') 68 | fw_fact.write(' '.join(line['fact']) + '\n') 69 | fw_full.write(line['raw']) 70 | 71 | 72 | def load_facts(fact_path): 73 | fact_dict = {} 74 | anchor_idx = 0 75 | anc_flag = False 76 | with open(fact_path, 'r', encoding='utf8') as fin: 77 | for i, line in enumerate(fin): 78 | parts = line.strip().split('\t') 79 | if len(parts) != 5: 80 | print('[Warning] loss fact #parts !=5, line %d in %s' 81 | % (i, fact_path)) 82 | continue 83 | fact_id = parts[2].strip() 84 | fact = filter_fact(parts[-1].strip()).split() 85 | 86 | if no_label(fact): 87 | continue 88 | 89 | if len(fact) > 100: 90 | fact = fact[:100] + [''] 91 | 92 | if fact_id in fact_dict: 93 | anchor_idx += 1 94 | fact_dict[fact_id]['facts'].append(fact) 95 | if '' in fact: 96 | anc_flag = True 97 | fact_dict[fact_id]['anchor_idx'] = anchor_idx 98 | fact_dict[fact_id]['anchor_label'] = fact[0] 99 | else: 100 | anchor_idx = 0 101 | anc_flag = False 102 | fact_dict[fact_id] = {} 103 | fact_dict[fact_id]['anchor_idx'] = anchor_idx 104 | fact_dict[fact_id]['anchor_label'] = '' 105 | fact_dict[fact_id]['facts'] = [fact] 106 | fact_dict[fact_id]['anchor_status'] = anc_flag 107 | return fact_dict 108 | 109 | 110 | def combine_fact(fact_dict, anc_type='section', fact_len=12, just_anc=False): 111 | anc_end_idx = 0 112 | ret_fact_dict = {} 113 | for fact_id in fact_dict.keys(): 114 | facts = fact_dict[fact_id]['facts'] 115 | anc_idx = fact_dict[fact_id]['anchor_idx'] 116 | anc_label = fact_dict[fact_id]['anchor_label'] 117 | anc_flag = fact_dict[fact_id]['anchor_status'] 118 | 119 | if just_anc and not anc_flag: 120 | continue 121 | 122 | if not anc_flag: 123 | facts = facts[:fact_len] 124 | else: 125 | if anc_type == 'sentence': 126 | anc_end_idx = anc_idx + 2 127 | elif anc_type == 'section': 128 | for anc_i, anc_text in enumerate(facts[anc_idx + 1:]): 129 | if anc_label == anc_text[0]: #

130 | anc_end_idx = anc_idx + anc_i + 1 131 | break 132 | else: 133 | print('anchor type error') 134 | exit() 135 | facts = facts[:2] + facts[anc_idx:anc_end_idx] 136 | 137 | ret_fact_dict[fact_id] = fact_dict[fact_id] 138 | facts = ' '.join(list(chain(*facts))) 139 | ret_fact_dict[fact_id]['facts'] = facts 140 | return ret_fact_dict 141 | 142 | 143 | def load_conv(conv_path, fact_dict, is_train=True, min_que=5): 144 | conv_fact_list = [] 145 | count_min = 0 146 | count_dup = 0 147 | count_short_fact = 0 148 | hash_set = set() 149 | with open(conv_path, 'r', encoding='utf8') as fin: 150 | for i, line in enumerate(fin): 151 | parts = line.strip().split('\t') 152 | if len(parts) != 7: 153 | print('[Warning] loss convos #parts != 7, line %d in %s' 154 | % (i, conv_path)) 155 | continue 156 | 157 | hash_id = parts[0].strip() 158 | if hash_id in hash_set: 159 | count_dup += 1 160 | continue 161 | else: 162 | hash_set.add(hash_id) 163 | 164 | conv_id = parts[2].strip() 165 | query = filter_query(parts[-2].strip()).split() 166 | response = filter_resp(parts[-1].strip()).split() 167 | 168 | if conv_id in fact_dict: 169 | facts = fact_dict[conv_id]['facts'] 170 | facts = filter_text(facts) 171 | else: 172 | continue 173 | 174 | facts = facts.split() 175 | 176 | if is_train: 177 | if len(response) < 5: 178 | count_min += 1 179 | continue 180 | 181 | if len(query) <= 1 or \ 182 | len(facts) <= 1: 183 | continue 184 | else: 185 | if len(query) == 0: 186 | query = ['UNK'] 187 | if len(facts) == 0: 188 | facts = ['UNK'] 189 | 190 | if len(facts) < 10: 191 | count_short_fact += 1 192 | 193 | if len(facts) > 500: 194 | facts = facts[:500] + [''] 195 | 196 | if len(response) > 30: 197 | response = response[:30] 198 | 199 | conv_fact_dict = { 200 | 'query': query, 201 | 'response': response, 202 | 'fact': facts, 203 | 'conv_id': conv_id, 204 | 'raw': line, 205 | 'hash_id': hash_id 206 | } 207 | conv_fact_list.append(conv_fact_dict) 208 | return conv_fact_list, count_min, count_dup, count_short_fact 209 | 210 | 211 | def combine_files(files_path, anc_type='section', fact_len=12, just_anc=False, is_train=False): 212 | file_name = set() 213 | data_list = [] 214 | count_mins = 0 215 | count_dups = 0 216 | count_short_facts = 0 217 | for file in os.listdir(files_path): 218 | file_name.add(file[:file.find('.')]) 219 | print(file_name) 220 | 221 | for file in file_name: 222 | init_path = os.path.join(files_path, file) 223 | fact_dict = load_facts(init_path + '.facts.txt') 224 | comb_fact = combine_fact(fact_dict, anc_type, fact_len, just_anc) 225 | conv_fact, count_min, count_dup,count_short_fact = load_conv(init_path + '.convos.txt', comb_fact, is_train) 226 | data_list += conv_fact 227 | count_mins += count_min 228 | count_dups += count_dup 229 | count_short_facts += count_short_fact 230 | print(len(data_list)) 231 | print('have discard the short response {} times'.format(count_mins)) 232 | print('the percentage of discarding is {0:.2f}%'.format(count_mins / len(data_list) * 100)) 233 | print('there are {} hash value are the same'.format(count_dups)) 234 | print('the percentage of duplicates is {0:.2f}%'.format(count_dups / len(data_list) * 100)) 235 | print('there are {} count_short_facts'.format(count_short_facts)) 236 | print('the percentage of count_short_facts is {0:.2f}%'.format(count_short_facts / len(data_list) * 100)) 237 | return data_list -------------------------------------------------------------------------------- /src/cmr/recurrent.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 5 | from torch.nn.utils.rnn import pack_padded_sequence as pack 6 | from .my_optim import weight_norm as WN 7 | 8 | # TODO: use system func to bind ~ 9 | RNN_MAP = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN} 10 | 11 | class OneLayerBRNN(nn.Module): 12 | def __init__(self, input_size, hidden_size, prefix='stack_rnn', opt={}, dropout=None): 13 | super(OneLayerBRNN, self).__init__() 14 | self.opt = opt 15 | self.prefix = prefix 16 | self.cell_type = self.opt.get('{}_cell'.format(self.prefix), 'lstm') 17 | self.emb_dim = self.opt.get('{}_embd_dim'.format(self.prefix), 0) 18 | self.maxout_on = self.opt.get('{}_maxout_on'.format(self.prefix), False) 19 | self.weight_norm_on = self.opt.get('{}_weight_norm_on'.format(self.prefix), False) 20 | self.dropout = dropout 21 | self.output_size = hidden_size if self.maxout_on else hidden_size * 2 22 | self.hidden_size = hidden_size 23 | self.rnn = RNN_MAP[self.cell_type](input_size, hidden_size, num_layers=1, bidirectional=True) 24 | 25 | def forward(self, x, x_mask): 26 | x = x.transpose(0, 1) 27 | size = list(x.size()) 28 | rnn_output, h = self.rnn(x) 29 | if self.maxout_on: 30 | rnn_output = rnn_output.view(size[0], size[1], self.hidden_size, 2).max(-1)[0] 31 | # Transpose back 32 | hiddens = rnn_output.transpose(0, 1) 33 | return hiddens 34 | 35 | class BRNNEncoder(nn.Module): 36 | def __init__(self, input_size, hidden_size, prefix='rnn', opt={}, dropout=None): 37 | super(BRNNEncoder, self).__init__() 38 | self.opt = opt 39 | self.dropout = dropout 40 | self.cell_type = opt.get('{}_cell'.format(self.prefix), 'gru') 41 | self.weight_norm_on = opt.get('{}_weight_norm_on'.format(self.prefix), False) 42 | self.top_layer_only = opt.get('{}_top_layer_only'.format(self.prefix), False) 43 | self.num_layers = opt.get('{}_num_layers'.format(self.prefix), 1) 44 | self.rnn = RNN_MAP[self.cell_type](input_size, hidden_size, self.num_layers, bidirectional=True) 45 | if self.weight_norm_on: 46 | self.rnn = WN(self.rnn) 47 | if self.top_layer_only: 48 | self.output_size = hidden_size * 2 49 | else: 50 | self.output_size = self.num_layers * hidden_size * 2 51 | 52 | def forward(self, x, x_mask): 53 | x = self.dropout(x) 54 | _, h = self.rnn(x.transpose(0, 1).contiguous()) 55 | if self.cell_type == 'lstm': 56 | h = h[0] 57 | shape = h.size() 58 | h = h.view(self.num_layers, 2, shape[1], shape[3]).transpose(1,2).contiguous() 59 | h = h.view(self.num_layers, shape[1], 2 * shape[3]) 60 | if self.top_layer_only: 61 | return h[-1] 62 | else: 63 | return h.transose(0, 1).contiguous().view(x.size(0), -1) 64 | 65 | 66 | #------------------------------ 67 | # Contextual embedding 68 | # TODO: remove packing to speed up 69 | # Credit from: https://github.com/salesforce/cove 70 | #------------------------------ 71 | class ContextualEmbedV2(nn.Module): 72 | def __init__(self, model_path, padding_idx=0): 73 | super(ContextualEmbedV2, self).__init__() 74 | state_dict = torch.load(model_path) 75 | self.rnn1 = nn.LSTM(300, 300, num_layers=1, bidirectional=True) 76 | self.rnn2 = nn.LSTM(600, 300, num_layers=1, bidirectional=True) 77 | state_dict1 = dict([(name, param.data) if isinstance(param, Parameter) else (name, param) 78 | for name, param in state_dict.items() if '0' in name]) 79 | state_dict2 = dict([(name.replace('1', '0'), param.data) if 
isinstance(param, Parameter) else (name.replace('1', '0'), param) 80 | for name, param in state_dict.items() if '1' in name]) 81 | self.rnn1.load_state_dict(state_dict1) 82 | self.rnn2.load_state_dict(state_dict2) 83 | for p in self.parameters(): p.requires_grad = False 84 | self.output_size = 600 85 | self.output_size = 600 86 | 87 | def setup_eval_embed(self, eval_embed, padding_idx=0): 88 | pass 89 | 90 | def forward(self, x, x_mask): 91 | """A pretrained MT-LSTM (McCann et. al. 2017). 92 | """ 93 | lengths = x_mask.data.eq(0).long().sum(1).squeeze() 94 | lens, indices = torch.sort(lengths, 0, True) 95 | output1, _ = self.rnn1(pack(x[indices], lens.tolist(), batch_first=True)) 96 | output2, _ = self.rnn2(output1) 97 | 98 | output1 = unpack(output1, batch_first=True)[0] 99 | output2 = unpack(output2, batch_first=True)[0] 100 | _, _indices = torch.sort(indices, 0) 101 | output1 = output1[_indices] 102 | output2 = output2[_indices] 103 | 104 | return output1, output2 105 | 106 | 107 | class ContextualEmbed(nn.Module): 108 | def __init__(self, path, vocab_size, emb_dim=300, embedding=None, padding_idx=0): 109 | super(ContextualEmbed, self).__init__() 110 | 111 | self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=padding_idx) 112 | if embedding is not None: 113 | self.embedding.weight.data = embedding 114 | 115 | state_dict = torch.load(path) 116 | self.rnn1 = nn.LSTM(300, 300, num_layers=1, bidirectional=True) 117 | self.rnn2 = nn.LSTM(600, 300, num_layers=1, bidirectional=True) 118 | state_dict1 = dict([(name, param.data) if isinstance(param, Parameter) else (name, param) 119 | for name, param in state_dict.items() if '0' in name]) 120 | state_dict2 = dict([(name.replace('1', '0'), param.data) if isinstance(param, Parameter) else (name.replace('1', '0'), param) 121 | for name, param in state_dict.items() if '1' in name]) 122 | self.rnn1.load_state_dict(state_dict1) 123 | self.rnn2.load_state_dict(state_dict2) 124 | for p in self.parameters(): p.requires_grad = False 125 | self.output_size = 600 126 | 127 | def setup_eval_embed(self, eval_embed, padding_idx=0): 128 | self.eval_embed = nn.Embedding(eval_embed.size(0), eval_embed.size(1), padding_idx = padding_idx) 129 | self.eval_embed.weight.data = eval_embed 130 | for p in self.eval_embed.parameters(): 131 | p.requires_grad = False 132 | 133 | def forward(self, x_idx, x_mask): 134 | emb = self.embedding if self.training else self.eval_embed 135 | x_hiddens = emb(x_idx) 136 | lengths = x_mask.data.eq(0).long().sum(1) 137 | lens, indices = torch.sort(lengths, 0, True) 138 | output1, _ = self.rnn1(pack(x_hiddens[indices], lens.tolist(), batch_first=True)) 139 | output2, _ = self.rnn2(output1) 140 | 141 | output1 = unpack(output1, batch_first=True)[0] 142 | output2 = unpack(output2, batch_first=True)[0] 143 | _, _indices = torch.sort(indices, 0) 144 | output1 = output1[_indices] 145 | output2 = output2[_indices] 146 | return output1, output2 147 | 148 | -------------------------------------------------------------------------------- /src/cmr/san_decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.nn.init as init 7 | from torch.autograd import Variable 8 | from torch.nn.parameter import Parameter 9 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 10 | from torch.nn.utils.rnn import pack_padded_sequence as pack 11 | from torch.nn.utils import 
weight_norm 12 | from torch.nn import AlphaDropout 13 | import numpy as np 14 | from functools import wraps 15 | from .common import activation 16 | from .similarity import FlatSimilarityWrapper 17 | from .recurrent import RNN_MAP 18 | from .dropout_wrapper import DropoutWrapper 19 | 20 | SMALL_POS_NUM=1.0e-30 21 | RNN_MAP = {'lstm': nn.LSTMCell, 'gru': nn.GRUCell, 'rnn': nn.RNNCell} 22 | 23 | def generate_mask(new_data, dropout_p=0.0): 24 | new_data = (1-dropout_p) * (new_data.zero_() + 1) 25 | for i in range(new_data.size(0)): 26 | one = random.randint(0, new_data.size(1)-1) 27 | new_data[i][one] = 1 28 | mask = Variable(1.0/(1 - dropout_p) * torch.bernoulli(new_data), requires_grad=False) 29 | return mask 30 | 31 | class SANDecoder(nn.Module): 32 | def __init__(self, x_size, h_size, opt={}, prefix='answer', dropout=None): 33 | super(SANDecoder, self).__init__() 34 | self.prefix = prefix 35 | self.attn_b = FlatSimilarityWrapper(x_size, h_size, prefix, opt, dropout) 36 | self.attn_e = FlatSimilarityWrapper(x_size, h_size, prefix, opt, dropout) 37 | self.rnn_type = opt.get('{}_rnn_type'.format(prefix), 'gru') 38 | self.rnn =RNN_MAP.get(self.rnn_type, nn.GRUCell)(x_size, h_size) 39 | self.num_turn = opt.get('{}_num_turn'.format(prefix), 5) 40 | self.opt = opt 41 | self.mem_random_drop = opt.get('{}_mem_drop_p'.format(prefix), 0) 42 | self.answer_opt = opt.get('{}_opt'.format(prefix), 0) 43 | # 0: std mem; 1: random select step; 2 random selection; voting in pred; 3:sort merge 44 | self.mem_type = opt.get('{}_mem_type'.format(prefix), 0) 45 | self.gamma = opt.get('{}_mem_gamma'.format(prefix), 0.5) 46 | self.alpha = Parameter(torch.zeros(1, 1, 1)) 47 | 48 | self.proj = nn.Linear(h_size, x_size) if h_size != x_size else None 49 | if dropout is None: 50 | self.dropout = DropoutWrapper(opt.get('{}_dropout_p'.format(self.prefix), 0)) 51 | else: 52 | self.dropout = dropout 53 | self.h2h = nn.Linear(h_size, h_size) 54 | self.a2h = nn.Linear(x_size, h_size, bias=False) 55 | self.luong_output_layer = nn.Linear(h_size + x_size, h_size) 56 | 57 | def forward(self, input, hidden, context, context_mask): 58 | #print(input.size(), hidden.size(), context.size(), context_mask.size()) 59 | hidden = self.dropout(hidden) 60 | hidden = self.rnn(input, hidden) 61 | 62 | if self.opt['model_type'] == 'san': 63 | attn = self.attention(context, hidden, context_mask) 64 | attn_h = torch.cat([hidden, attn], dim=1) 65 | new_hidden = F.tanh(self.luong_output_layer(attn_h)) 66 | elif self.opt['model_type'] in {'seq2seq', 'memnet'}: 67 | new_hidden = hidden 68 | else: 69 | raise ValueError('Unknown model type: {}'.format(self.opt['model_type'])) 70 | 71 | return new_hidden 72 | 73 | def attention(self, x, h0, x_mask): 74 | if self.answer_opt in {1, 2, 3}: 75 | st_scores = self.attn_b(x, h0, x_mask) 76 | if self.answer_opt == 3: 77 | ptr_net_b = torch.bmm(F.softmax(st_scores, 1).unsqueeze(1), x).squeeze(1) 78 | ptr_net_b = self.dropout(ptr_net_b) 79 | xb = ptr_net_b if self.proj is None else self.proj(ptr_net_b) 80 | end_scores = self.attn_e(x, h0 + xb, x_mask) 81 | ptr_net_e = torch.bmm(F.softmax(end_scores, 1).unsqueeze(1), x).squeeze(1) 82 | ptr_net_in = (ptr_net_b + ptr_net_e)/2.0 83 | elif self.answer_opt == 2: 84 | ptr_net_b = torch.bmm(F.softmax(st_scores, 1).unsqueeze(1), x).squeeze(1) 85 | ptr_net_b = self.dropout(ptr_net_b) 86 | xb = ptr_net_b if self.proj is None else self.proj(ptr_net_b) 87 | end_scores = self.attn_e(x, xb, x_mask) 88 | ptr_net_e = torch.bmm(F.softmax(end_scores, 1).unsqueeze(1), x).squeeze(1) 
89 | ptr_net_in = ptr_net_e 90 | elif self.answer_opt == 1: 91 | ptr_net_b = torch.bmm(F.softmax(st_scores, 1).unsqueeze(1), x).squeeze(1) 92 | ptr_net_b = self.dropout(ptr_net_b) 93 | ptr_net_in = ptr_net_b 94 | else: 95 | end_scores = self.attn_e(x, h0, x_mask) 96 | ptr_net_e = torch.bmm(F.softmax(end_scores, 1).unsqueeze(1), x).squeeze(1) 97 | ptr_net_in = ptr_net_e 98 | 99 | return ptr_net_in 100 | -------------------------------------------------------------------------------- /src/cmr/sub_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.parameter import Parameter 5 | 6 | 7 | class PositionwiseNN(nn.Module): 8 | def __init__(self, idim, hdim, dropout=None): 9 | super(PositionwiseNN, self).__init__() 10 | self.w_0 = nn.Conv1d(idim, hdim, 1) 11 | self.w_1 = nn.Conv1d(hdim, hdim, 1) 12 | self.dropout = dropout 13 | 14 | def forward(self, x): 15 | output = F.relu(self.w_0(x.transpose(1, 2))) 16 | output = self.dropout(output) 17 | output = self.w_1(output) 18 | output = self.dropout(output).transpose(2, 1) 19 | return output 20 | 21 | 22 | class LayerNorm(nn.Module): 23 | """ 24 | ref: https://github.com/pytorch/pytorch/issues/1959 25 | :https://arxiv.org/pdf/1607.06450.pdf 26 | """ 27 | 28 | 29 | def __init__(self, hidden_size, eps=1e-4): 30 | super(LayerNorm, self).__init__() 31 | self.alpha = Parameter(torch.ones(1, 1, hidden_size)) # gain g 32 | self.beta = Parameter(torch.zeros(1, 1, hidden_size)) # bias b 33 | self.eps = eps 34 | 35 | 36 | def forward(self, x): 37 | """ 38 | Args: 39 | :param x: batch * len * input_size 40 | 41 | Returns: 42 | normalized x 43 | """ 44 | mu = torch.mean(x, 2, keepdim=True).expand_as(x) 45 | sigma = torch.std(x, 2, keepdim=True).expand_as(x) 46 | return (x - mu) / (sigma + self.eps) * self.alpha.expand_as(x) + self.beta.expand_as(x) 47 | 48 | 49 | class Highway(nn.Module): 50 | def __init__(self, size, num_layers, f=F.relu): 51 | super(Highway, self).__init__() 52 | self.num_layers = num_layers 53 | self.nonlinear = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)]) 54 | self.linear = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)]) 55 | self.gate = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)]) 56 | self.f = f 57 | 58 | def forward(self, x): 59 | """ 60 | Args: 61 | :param x: tensor with shape of [batch_size, size] 62 | Returns: 63 | tensor with shape of [batch_size, size] 64 | applies σ(x)* (f(G(x))) + (1 - σ(x))* (Q(x)) transformation 65 | G and Q is affine transformation, 66 | f is non-linear transformation, σ(x) is affine transformation with sigmoid non-linearition 67 | and * is element-wise multiplication 68 | """ 69 | for layer in range(self.num_layers): 70 | gate = F.sigmoid(self.gate[layer](x)) 71 | nonlinear = self.f(self.nonlinear[layer](x)) 72 | linear = self.linear[layer](x) 73 | x = gate * nonlinear + (1 - gate) * linear 74 | return x 75 | -------------------------------------------------------------------------------- /src/demo_dialog.py: -------------------------------------------------------------------------------- 1 | #// Copyright (c) Microsoft Corporation.// Licensed under the MIT license. 
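# Dialog demo entry point. DialogBackendLocal combines several response generators over the
# same dialog context: DialoGPT (open_dialog.py), BiDAF machine reading comprehension
# (mrc.py) and the CMR grounded generator (grounded.py), each fed with web snippets that
# KnowledgeBase (knowledge.py) retrieves for the sites listed in args/kb_sites.txt.
# Candidate responses are then scored and sorted by Ranker (ranker.py), which combines a
# forward DialoGPT model with a reverse model loaded from models/DialoGPT/small_reverse.pkl.
# Usage: python src/demo_dialog.py {cmd,web,api} [--port 5000] [--remote host:port]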
2 | 3 | import json, subprocess, torch, os, time, pdb 4 | import numpy as np 5 | from knowledge import KnowledgeBase 6 | from mrc import BidafQA 7 | import base64 8 | from flask import Flask, request, render_template 9 | from open_dialog import DialoGPT 10 | from grounded import ConversingByReading 11 | from flask_restful import Resource, Api 12 | from tts import TextToSpeech 13 | from ranker import Ranker 14 | score_names = ['fwd', 'rvs', 'rep', 'info', 'score'] 15 | 16 | 17 | class DialogBackend: 18 | def __init__(self): 19 | self.turn_sep = ' <|endoftext|> ' 20 | 21 | def history2inp(self, context): 22 | turns = context.split(' __EOS__ ') 23 | context = self.turn_sep.join(turns).strip() 24 | query = turns[-1].strip() 25 | return context, query 26 | 27 | 28 | class DialogBackendLocal(DialogBackend): 29 | 30 | def __init__(self): 31 | super().__init__() 32 | 33 | self.model_mrc = BidafQA() 34 | self.model_cmr = ConversingByReading() 35 | self.model_open = DialoGPT() 36 | self.kb = KnowledgeBase() 37 | model_mmi = DialoGPT(path_model='models/DialoGPT/small_reverse.pkl') 38 | self.ranker = Ranker(self.model_open, model_mmi) 39 | self.local = True 40 | 41 | 42 | def predict(self, context, max_n=-1): 43 | 44 | context, query = self.history2inp(context) 45 | print('backend running, context = %s'%context) 46 | 47 | # get results from different models 48 | results = self.model_open.predict(context) 49 | 50 | passages = [] 51 | url_snippet = [] 52 | for line in open('args/kb_sites.txt', encoding='utf-8'): 53 | cust = line.strip('\n') 54 | kb_args = {'domain': 'cust', 'cust': cust, 'must_include':[]} 55 | url_snippet.append(self.kb.predict(query, args=kb_args)[0]) 56 | passage = ' ... '.join([snippet for _, snippet in url_snippet]) 57 | passages.append((passage, query)) 58 | 59 | for passage, kb_query in passages: 60 | results += self.model_mrc.predict(kb_query, passage) 61 | results += self.model_cmr.predict(kb_query, passage) 62 | 63 | # rank hyps from different models 64 | 65 | hyps = [hyp for _, _, hyp in results] 66 | scored = self.ranker.predict(context, hyps) 67 | ret = [] 68 | for i, d in enumerate(scored): 69 | d['way'], _, d['hyp'] = results[i] 70 | ret.append((d['score'], d)) 71 | ranked = [d for _, d in sorted(ret, reverse=True)] 72 | if max_n > 0: 73 | ranked = ranked[:min(len(ranked), max_n)] 74 | return ranked, url_snippet 75 | 76 | 77 | 78 | class DialogBackendRemote(DialogBackend): 79 | 80 | def __init__(self, host): 81 | super().__init__() 82 | self.local = False 83 | self.host = host 84 | 85 | def predict(self, context): 86 | cmd = 'curl http://%s/ -d "context=%s" -X GET'%(self.host, context) 87 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE) 88 | output, error = process.communicate() 89 | if error is not None: 90 | print(error) 91 | return [], [] 92 | ret = json.loads(output.decode()) 93 | return ret['responses'], ret['passages'] 94 | 95 | 96 | def cmd_demo(backend): 97 | while True: 98 | src = input('\nUSER: ') 99 | if len(src) == 0: 100 | break 101 | with torch.no_grad(): 102 | ranked, url_passages = backend.predict(src) 103 | for url, passage in url_passages: 104 | print(url) 105 | print(passage) 106 | print() 107 | 108 | for d in ranked: 109 | ss = [] 110 | for k in d: 111 | if k not in ['way', 'hyp']: 112 | ss.append('%s %.3f'%(k, d[k])) 113 | line = '\t'.join([' '.join(ss), d['way'], d['hyp']]) 114 | print(line) 115 | 116 | 117 | def encode_file(path): 118 | code = base64.b64encode(open(path, 'rb').read()) 119 | return code.decode('utf-8') 120 | 121 | 122 | 
class Memo: 123 | def __init__(self): 124 | self.reset() 125 | 126 | def add_turn(self, tup): 127 | self.history.append(tup) 128 | 129 | def reset(self): 130 | self.history = [] 131 | 132 | def get_history(self): 133 | return self.history[:] 134 | 135 | 136 | def web_demo(backend, port=5000): 137 | tts = TextToSpeech() 138 | app = Flask(__name__) 139 | memo = Memo() 140 | 141 | @app.route('/', methods=['GET', 'POST']) 142 | def root(): 143 | if request.method == 'POST': 144 | query = request.form['inp_query'] 145 | v_new = (request.form.get('inp_new') is not None) 146 | if v_new: 147 | memo.reset() 148 | 149 | memo.add_turn(('User', query)) 150 | history = memo.get_history() 151 | context = ' __EOS__ '.join([utt for _, utt in history]) 152 | 153 | with torch.no_grad(): 154 | dd_hyp, url_snippet = backend.predict(context) 155 | 156 | memo.add_turn(('Agent', dd_hyp[0]['hyp'])) 157 | hyps = [] 158 | for d in dd_hyp: 159 | score = ['%.2f'%d.get(k, np.nan) for k in score_names] 160 | hyps.append([score, d['way'], d['hyp'], d.get('hyp_en','')]) 161 | 162 | path_audio = None 163 | if len(dd_hyp) > 0: 164 | hyp0 = dd_hyp[0]['hyp'] 165 | path_audio = tts.get_audio(hyp0) 166 | v_new = 0 167 | 168 | else: 169 | history = [] 170 | url_snippet = [] 171 | hyps = [] 172 | path_audio = None 173 | v_new = 1 174 | 175 | if path_audio is None: 176 | audio_code = '' 177 | else: 178 | audio_code = encode_file(path_audio) 179 | 180 | max_len = 30 181 | passages = [] 182 | for url, snippet in url_snippet: 183 | url_display = url.replace('http:','').replace('https:','').strip('/') 184 | url_display = url_display.replace('en.wikipedia.org','') 185 | if len(url_display) > max_len: 186 | url_display = url_display[:max_len] + '...' 187 | passages.append([url_display, url, snippet]) 188 | 189 | html = render_template('dialog.html', 190 | score_header=score_names, 191 | history=history, 192 | passages=passages, hyps=hyps, 193 | audio_code=audio_code, 194 | v_new=v_new, 195 | ) 196 | 197 | html = html.replace('value="1"','checked') 198 | html = html.replace('value="0"','') 199 | return html 200 | 201 | app.run(host='0.0.0.0', port=port) 202 | 203 | 204 | class ApiResource(Resource): 205 | def __init__(self, backend): 206 | self.backend = backend 207 | 208 | def get(self): 209 | context = request.form['context'] 210 | with torch.no_grad(): 211 | dd_hyp, url_snippet = self.backend.predict( 212 | context, 213 | ) 214 | 215 | ret = { 216 | # parsed input ---- 217 | 'context': context, 218 | 'passages': url_snippet, 219 | 'responses': dd_hyp, 220 | } 221 | return ret 222 | 223 | 224 | 225 | def restful_api_demo(backend, port=5000): 226 | app = Flask(__name__) 227 | api = Api(app) 228 | api.add_resource(ApiResource, '/', 229 | resource_class_kwargs={'backend': backend}) 230 | app.run(host='0.0.0.0', port=port) 231 | 232 | 233 | 234 | if __name__ == "__main__": 235 | import argparse 236 | parser = argparse.ArgumentParser() 237 | parser.add_argument('mode', type=str, help="`cmd`, `api`, or `web`") 238 | parser.add_argument('--remote', type=str, default='') 239 | parser.add_argument('--port', type=int, default=5000) 240 | args = parser.parse_args() 241 | 242 | if args.remote: 243 | backend = DialogBackendRemote(args.remote) 244 | else: 245 | backend = DialogBackendLocal() 246 | if args.mode == 'cmd': 247 | cmd_demo(backend) 248 | elif args.mode == 'web': 249 | web_demo(backend, port=args.port) 250 | elif args.mode == 'api': 251 | restful_api_demo(backend, port=args.port) 
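A minimal client sketch for the RESTful mode above (not part of the repository): it assumes `python src/demo_dialog.py api --port 5000` is already running on the local machine and that `src/` is on `PYTHONPATH`, and it simply reuses `DialogBackendRemote`, which wraps the same `curl` GET shown in its `predict` method.
```python
# Hypothetical client; the host/port and the example query are assumptions.
from demo_dialog import DialogBackendRemote

backend = DialogBackendRemote('localhost:5000')   # host string is interpolated into the curl command
responses, passages = backend.predict('what is deep learning ?')

for url, snippet in passages:                     # knowledge snippets the backend grounded on
    print(url, snippet)
for d in responses[:3]:                           # ranked hypotheses; keys follow score_names
    print('%.3f\t%s\t%s' % (d['score'], d['way'], d['hyp']))
```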
-------------------------------------------------------------------------------- /src/demo_doc_gen.py: -------------------------------------------------------------------------------- 1 | #// Copyright (c) Microsoft Corporation.// Licensed under the MIT license. 2 | 3 | import json, subprocess, torch, os, time, pdb 4 | import numpy as np 5 | from knowledge import KnowledgeBase 6 | import base64 7 | from flask import Flask, request, render_template 8 | from lm import LanguageModel 9 | from grounded import ContentTransfer 10 | from flask_restful import Resource, Api 11 | from ranker import Ranker 12 | score_names = ['fwd', 'rep', 'info', 'score'] 13 | 14 | 15 | class DialogBackend: 16 | def get_query(self, context): 17 | return context.split('. ')[-1] 18 | 19 | 20 | class DialogBackendLocal(DialogBackend): 21 | 22 | def __init__(self): 23 | super().__init__() 24 | 25 | self.model_lm = LanguageModel() 26 | self.model_ct = ContentTransfer() 27 | self.kb = KnowledgeBase() 28 | self.ranker = Ranker(self.model_lm) 29 | self.local = True 30 | 31 | 32 | def predict(self, context, max_n=1): 33 | print('backend running, context = %s'%context) 34 | query = self.get_query(context) 35 | 36 | # get results from different models 37 | results = self.model_lm.predict(context) 38 | 39 | passages = [] 40 | url_snippet = [] 41 | for line in open('args/kb_sites.txt', encoding='utf-8'): 42 | cust = line.strip('\n') 43 | kb_args = {'domain': 'cust', 'cust': cust, 'must_include':[]} 44 | url_snippet.append(self.kb.predict(query, args=kb_args)[0]) 45 | passage = ' ... '.join([snippet for _, snippet in url_snippet]) 46 | passages.append((passage, query)) 47 | 48 | for passage, kb_query in passages: 49 | results += self.model_ct.predict(kb_query, passage) 50 | 51 | # rank hyps from different models 52 | 53 | hyps = [hyp for _, _, hyp in results] 54 | scored = self.ranker.predict(context, hyps) 55 | ret = [] 56 | for i, d in enumerate(scored): 57 | d['way'], _, d['hyp'] = results[i] 58 | ret.append((d['score'], d)) 59 | ranked = [d for _, d in sorted(ret, reverse=True)] 60 | if max_n > 0: 61 | ranked = ranked[:min(len(ranked), max_n)] 62 | return ranked, url_snippet 63 | 64 | 65 | 66 | class DialogBackendRemote(DialogBackend): 67 | 68 | def __init__(self, host): 69 | super().__init__() 70 | self.local = False 71 | self.host = host 72 | 73 | def predict(self, context): 74 | cmd = 'curl http://%s/ -d "context=%s" -X GET'%(self.host, context) 75 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE) 76 | output, error = process.communicate() 77 | if error is not None: 78 | print(error) 79 | return [], [] 80 | ret = json.loads(output.decode()) 81 | return ret['responses'], ret['passages'] 82 | 83 | 84 | def cmd_demo(backend): 85 | while True: 86 | src = input('\nUSER: ') 87 | if len(src) == 0: 88 | break 89 | with torch.no_grad(): 90 | ranked, url_passages = backend.predict(src) 91 | for url, passage in url_passages: 92 | print(url) 93 | print(passage) 94 | print() 95 | 96 | for d in ranked: 97 | ss = [] 98 | for k in d: 99 | if k not in ['way', 'hyp']: 100 | ss.append('%s %.3f'%(k, d[k])) 101 | line = '\t'.join([' '.join(ss), d['way'], d['hyp'].replace('\n',' ')]) 102 | print(line) 103 | 104 | 105 | 106 | def web_demo(backend, port=5000): 107 | app = Flask(__name__) 108 | 109 | @app.route('/', methods=['GET', 'POST']) 110 | def root(): 111 | if request.method == 'POST': 112 | context = request.form['inp_cxt'] 113 | with torch.no_grad(): 114 | dd_hyp, url_snippet = backend.predict(context) 115 | 116 | hyps = [] 117 | 
for d in dd_hyp: 118 | score = ['%.2f'%d.get(k, np.nan) for k in score_names] 119 | hyps.append([score, d['way'], d['hyp'], d.get('hyp_en','')]) 120 | 121 | else: 122 | context = '' 123 | url_snippet = [] 124 | hyps = [] 125 | 126 | max_len = 30 127 | passages = [] 128 | for url, snippet in url_snippet: 129 | url_display = url.replace('http:','').replace('https:','').strip('/') 130 | url_display = url_display.replace('en.wikipedia.org','') 131 | if len(url_display) > max_len: 132 | url_display = url_display[:max_len] + '...' 133 | passages.append([url_display, url, snippet]) 134 | 135 | html = render_template('doc_gen.html', 136 | score_header=score_names, 137 | passages=passages, hyps=hyps, 138 | cxt_val=context, 139 | ) 140 | 141 | html = html.replace('value="1"','checked') 142 | html = html.replace('value="0"','') 143 | return html 144 | 145 | app.run(host='0.0.0.0', port=port) 146 | 147 | 148 | class ApiResource(Resource): 149 | def __init__(self, backend): 150 | self.backend = backend 151 | 152 | def get(self): 153 | context = request.form['context'] 154 | with torch.no_grad(): 155 | dd_hyp, url_snippet = self.backend.predict( 156 | context, 157 | ) 158 | 159 | ret = { 160 | # parsed input ---- 161 | 'context': context, 162 | 'passages': url_snippet, 163 | 'responses': dd_hyp, 164 | } 165 | return ret 166 | 167 | 168 | 169 | def restful_api_demo(backend, port=5000): 170 | app = Flask(__name__) 171 | api = Api(app) 172 | api.add_resource(ApiResource, '/', 173 | resource_class_kwargs={'backend': backend}) 174 | app.run(host='0.0.0.0', port=port) 175 | 176 | 177 | 178 | if __name__ == "__main__": 179 | import argparse 180 | parser = argparse.ArgumentParser() 181 | parser.add_argument('mode', type=str, help="`cmd`, `api`, or `web`") 182 | parser.add_argument('--remote', type=str, default='') 183 | parser.add_argument('--port', type=int, default=5000) 184 | args = parser.parse_args() 185 | 186 | if args.remote: 187 | backend = DialogBackendRemote(args.remote) 188 | else: 189 | backend = DialogBackendLocal() 190 | if args.mode == 'cmd': 191 | cmd_demo(backend) 192 | elif args.mode == 'web': 193 | web_demo(backend, port=args.port) 194 | elif args.mode == 'api': 195 | restful_api_demo(backend, port=args.port) -------------------------------------------------------------------------------- /src/grounded.py: -------------------------------------------------------------------------------- 1 | #// Copyright (c) Microsoft Corporation.// Licensed under the MIT license. 
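# This module wraps two grounded text generators:
#   * ConversingByReading ("CMR", https://github.com/qkaren/converse_reading_cmr), configured
#     from models/cmr/args.json and the checkpoint referenced by its model_dir field;
#   * ContentTransfer ("CT"), which loads models/crg/crg_model.pt together with the
#     sentencepiece model models/crg/bpeM.model (see OptionContentTransfer below).
# Either model can be tried interactively via `python src/grounded.py cmr` or
# `python src/grounded.py ct`, which enters the play_grounded() loop at the bottom of the file.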
2 | 3 | from cmr.process_raw_data import filter_query, filter_fact 4 | from cmr.batcher import load_meta, prepare_batch_data 5 | from cmr.model import DocReaderModel 6 | import json, os, torch 7 | import numpy as np 8 | from todo import pick_tokens 9 | 10 | 11 | class JsonConfig: 12 | def __init__(self, path): 13 | d = json.loads(open(path, encoding='utf-8').readline()) 14 | for k in d: 15 | setattr(self, k, d[k]) 16 | 17 | 18 | class ConversingByReading: 19 | # ref: https://github.com/qkaren/converse_reading_cmr 20 | 21 | def __init__(self, use_cuda=True): 22 | args = JsonConfig('models/cmr/args.json') 23 | self.embedding, self.opt, self.vocab = load_meta(vars(args), args.meta) 24 | self.opt['skip_tokens'] = self.get_skip_tokens(self.opt["skip_tokens_file"]) 25 | self.opt['skip_tokens_first'] = self.get_skip_tokens(self.opt["skip_tokens_first_file"]) 26 | self.state_dict = torch.load(args.model_dir)["state_dict"] 27 | self.model = DocReaderModel(self.opt, self.embedding, self.state_dict) 28 | self.model.setup_eval_embed(self.embedding) 29 | if use_cuda: 30 | self.model.cuda() 31 | 32 | def get_skip_tokens(self, path): 33 | skip_tokens = None 34 | if path and os.path.isfile(path): 35 | skip_tokens = [] 36 | with open(path, 'r') as f: 37 | for word in f: 38 | word = word.strip().rstrip('\n') 39 | try: 40 | skip_tokens.append(self.vocab[word]) 41 | except: 42 | print("Token %s not present in dictionary" % word) 43 | return skip_tokens 44 | 45 | def predict(self, context, passage, top_k=2, verbose=False): 46 | data = [{'query': context, 'fact': passage}] 47 | 48 | def pred2words(prediction, vocab): 49 | EOS_token = 3 50 | outputs = [] 51 | for pred in prediction: 52 | new_pred = pred 53 | for i, x in enumerate(pred): 54 | if int(x) == EOS_token: 55 | new_pred = pred[:i] 56 | break 57 | outputs.append(' '.join([vocab[int(x)] for x in new_pred])) 58 | return outputs 59 | 60 | processed_data = prepare_batch_data([self.preprocess_data(x) for x in data], ground_truth=False) 61 | logPs, predictions = self.model.predict(processed_data, pick_tokens=pick_tokens) 62 | pred_word = pred2words(predictions, self.vocab) 63 | hyps = [np.asarray(x, dtype=np.str).tolist() for x in pred_word] 64 | return [('CMR', np.exp(logPs[i]), hyps[i]) for i in range(len(hyps))] 65 | 66 | 67 | def preprocess_data(self, sample, q_cutoff=30, doc_cutoff=500): 68 | def tok_func(toks): 69 | return [self.vocab[w] for w in toks] 70 | 71 | fea_dict = {} 72 | 73 | query_tokend = filter_query(sample['query'].strip(), max_len=q_cutoff).split() 74 | doc_tokend = filter_fact(sample['fact'].strip()).split() 75 | if len(doc_tokend) > doc_cutoff: 76 | doc_tokend = doc_tokend[:doc_cutoff] + [''] 77 | 78 | # TODO 79 | fea_dict['query_tok'] = tok_func(query_tokend) 80 | fea_dict['query_pos'] = [] 81 | fea_dict['query_ner'] = [] 82 | 83 | fea_dict['doc_tok'] = tok_func(doc_tokend) 84 | fea_dict['doc_pos'] = [] 85 | fea_dict['doc_ner'] = [] 86 | fea_dict['doc_fea'] = '' 87 | 88 | if len(fea_dict['query_tok']) == 0: 89 | fea_dict['query_tok'] = [0] 90 | if len(fea_dict['doc_tok']) == 0: 91 | fea_dict['doc_tok'] = [0] 92 | 93 | return fea_dict 94 | 95 | 96 | class OptionContentTransfer: 97 | def __init__(self): 98 | self.path_model = 'models/crg/crg_model.pt' 99 | self.path_tokenizer = 'models/crg/bpeM.model' 100 | self.beam_size = 5 101 | self.batch_size = 1 102 | self.max_sent_length = 100 103 | self.replace_unk = True 104 | self.verbose = False 105 | self.n_best = 1 106 | self.cuda = True 107 | 108 | 109 | import onmt 110 | class 
ContentTransfer: 111 | # ref: https://github.com/shrimai/Towards-Content-Transfer-through-Grounded-Text-Generation 112 | 113 | def __init__(self): 114 | import sentencepiece as spm 115 | opt = OptionContentTransfer() 116 | self.model = onmt.Translator(opt) 117 | self.tokenizer = spm.SentencePieceProcessor() 118 | self.tokenizer.load(opt.path_tokenizer) 119 | 120 | def encode(self, s): 121 | return ' '.join(self.tokenizer.EncodeAsPieces(s.lower())) 122 | 123 | def decode(self, s): 124 | return self.tokenizer.DecodePieces(s).replace(chr(92),' ') 125 | 126 | def predict(self, query, passage, min_score_article=0.3, min_score_passage=0.1, verbose=False): 127 | cxt = self.encode(query) 128 | src = self.encode(passage) 129 | srcBatch = [src.split()] 130 | cxtBatch = [cxt.split()] 131 | hyp, _, _ = self.model.translate(srcBatch, cxtBatch, None) 132 | hyp = self.decode(hyp[0][0]).split('|')[0].strip() 133 | return [('CT', 0, hyp)] 134 | 135 | 136 | 137 | def play_grounded(which): 138 | if which == 'cmr': 139 | model = ConversingByReading() 140 | elif which == 'ct': 141 | model = ContentTransfer() 142 | 143 | while True: 144 | cxt = input('\nCONTEXT:\t') 145 | if not cxt: 146 | break 147 | passage = input('\nPASSAGE:\t') 148 | if not passage: 149 | break 150 | ret = model.predict(cxt, passage) 151 | for way, prob, hyp in ret: 152 | print('%s %.3f\t%s'%(way, prob, hyp)) 153 | 154 | 155 | if __name__ == "__main__": 156 | import sys 157 | play_grounded(sys.argv[1]) -------------------------------------------------------------------------------- /src/knowledge.py: -------------------------------------------------------------------------------- 1 | #// Copyright (c) Microsoft Corporation.// Licensed under the MIT license. 2 | 3 | import pke, time 4 | import numpy as np 5 | import pdb, pickle, os, re, json, requests 6 | from shared import get_api_key 7 | 8 | from azure.cognitiveservices.search.websearch import WebSearchAPI 9 | from msrest.authentication import CognitiveServicesCredentials 10 | 11 | 12 | def extract_keyphrase(txt, n_max=5): 13 | try: 14 | extractor = pke.unsupervised.TopicRank() 15 | extractor.load_document(input=txt, language='en') 16 | extractor.candidate_selection() 17 | extractor.candidate_weighting() 18 | results = extractor.get_n_best(n=n_max) 19 | except ValueError: 20 | return dict() 21 | 22 | ret = dict() 23 | for k, score in results: 24 | ret[k] = score 25 | return ret 26 | 27 | 28 | class KnowledgeBase: 29 | # select the most relavant snippets from external knowledge source 30 | 31 | def __init__(self): 32 | self.fld = 'kb' 33 | bing_v7_key = get_api_key('bing_v7')[0] 34 | self.bing_web_client = WebSearchAPI(CognitiveServicesCredentials(bing_v7_key)) 35 | 36 | 37 | def build_query(self, query, site=None, must_include=[]): 38 | if site is not None: 39 | query = query + ' site:%s'%site.strip() 40 | return query + ' ' + ' '.join(['"%s"'%w for w in must_include]) 41 | 42 | 43 | def search_bing_web(self, query, site=None, must_include=[]): 44 | # https://docs.microsoft.com/en-us/azure/cognitive-services/bing-web-search/web-sdk-python-quickstart 45 | ord_url_snippet = [] 46 | snippets = [] 47 | t0 = time.time() 48 | query = self.build_query(query, site, must_include) 49 | web_data = self.bing_web_client.web.search( 50 | query=query, 51 | ) 52 | if web_data.web_pages is None: 53 | return [] 54 | for i, data in enumerate(web_data.web_pages.value): 55 | snippets.append(data.snippet) 56 | ord_url_snippet.append((i, data.url, data.snippet)) 57 | return ord_url_snippet 58 | 59 | 60 | def 
rank_passage(self, query, ord_url_snippet, n_max=3): 61 | # a heuristic method by Xiang Gao, see Section 3.1 of https://arxiv.org/abs/2005.08365 62 | 63 | query_k = extract_keyphrase(query, n_max=n_max) 64 | set_q = dict() 65 | for k in query_k: 66 | set_q[k] = set(k.lower().split()) 67 | n = len(ord_url_snippet) 68 | score_i = [] 69 | for i in range(n): 70 | order, _, snippet = ord_url_snippet[i] 71 | ww = set(re.sub(r"[^a-z]", " ", snippet.lower()).split()) 72 | overlap = 0 73 | for k in query_k: 74 | overlap += len(set_q[k] & ww)/len(set_q[k]) * query_k[k] 75 | score = overlap * 1./(order + 1) 76 | score_i.append((score, i)) 77 | picked = pick_top(score_i) 78 | return [(ord_url_snippet[i][1], ord_url_snippet[i][2]) for _, i in picked[:min(len(picked), n_max)]] 79 | 80 | 81 | def predict(self, query, args=None): 82 | query = query.lower() 83 | if args is None: 84 | args = {'domain':'web', 'must_include':[]} 85 | domain = args['domain'] 86 | 87 | if domain in ['web', 'cust']: 88 | site = args['cust'] if domain == 'cust' else None 89 | url_snippet = self.search_bing_web(query, site=site, must_include=args['must_include']) 90 | allowed = args.get('allowed') 91 | if allowed is not None: 92 | kept = [] 93 | for url, snippet in url_snippet: 94 | if url in allowed: 95 | kept.append([url, snippet]) 96 | url_snippet = kept[:] 97 | elif domain == 'news': 98 | url_snippet = self.search_bing_news(query, must_include=args['must_include']) 99 | elif domain == 'user': 100 | lines = args['txt_kb'].split('\n') 101 | url_snippet = [] 102 | window = 1 103 | for i in range(len(lines) - window + 1): 104 | snippet = ' '.join(lines[i:i+window]).strip() 105 | if len(snippet) > 0: 106 | url_snippet.append(['user', snippet]) 107 | else: 108 | raise ValueError 109 | ranked = self.rank_passage(query, url_snippet) 110 | return ranked 111 | 112 | 113 | def pick_top(score_v, crit=0.1): 114 | if len(score_v) == 0: 115 | return [] 116 | s = sorted(score_v, reverse=True) 117 | max_score = s[0][0] 118 | picked = [] 119 | for score, v in s: 120 | if score < max_score * crit: 121 | break 122 | picked.append((score, v)) 123 | return picked 124 | 125 | 126 | def play_kb(): 127 | kb = KnowledgeBase() 128 | while True: 129 | print('\n(empty query to exit)') 130 | query = input('QUERY:\t') 131 | if not query: 132 | break 133 | url_snippets = kb.predict(query) 134 | for url, snippet in url_snippets: 135 | print('\nURL:\t%s'%url) 136 | print('TXT:\t%s'%snippet) 137 | 138 | 139 | if __name__ == "__main__": 140 | play_kb() -------------------------------------------------------------------------------- /src/lm.py: -------------------------------------------------------------------------------- 1 | from todo import pick_tokens 2 | from transformers import GPT2LMHeadModel, GPT2Tokenizer 3 | import torch, pdb 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | 8 | class LanguageModel: 9 | # based on: https://github.com/huggingface/transformers/blob/master/examples/run_generation.py 10 | 11 | def __init__(self, use_cuda=True): 12 | self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2') 13 | self.model = GPT2LMHeadModel.from_pretrained('gpt2') 14 | if use_cuda: 15 | self.model = self.model.cuda() 16 | self.model.eval() 17 | self.use_cuda = use_cuda 18 | self.ix_EOS = 50256 19 | 20 | 21 | def predict(self, inp, beam=10, max_t=30): 22 | # num_samples=1, temperature=0., repetition_penalty=1.,top_k=0, top_p=0.9 23 | ids = self.tokenizer.encode(inp) 24 | len_cxt = len(ids) 25 | context = torch.tensor(ids, dtype=torch.long) 26 | 
if self.use_cuda: 27 | context = context.cuda() 28 | context = context.unsqueeze(0) 29 | tokens = context 30 | way = 'GPT2' 31 | 32 | finished = [] 33 | hyp_set = set() 34 | sum_logP = [0] 35 | max_t = 30 36 | for t in range(max_t): 37 | with torch.no_grad(): 38 | outputs = self.model(tokens) 39 | predictions = outputs[0] 40 | logits = predictions[:, -1, :] # only care the last step. [n_hyp, vocab] 41 | prob = F.softmax(logits, dim=-1) 42 | logP = torch.log(prob) 43 | if t == max_t - 1: 44 | picked_tokens = torch.LongTensor([self.ix_EOS] * logits.shape[0]).view(-1, 1) 45 | if self.use_cuda: 46 | picked_tokens = picked_tokens.cuda() 47 | else: 48 | picked_tokens = pick_tokens(prob) 49 | 50 | cand = [] 51 | for i in range(picked_tokens.shape[0]): 52 | for j in range(picked_tokens.shape[1]): 53 | ix = picked_tokens[i, j].item() 54 | _sum_logP = sum_logP[i] + logP[i, ix].item() 55 | cand.append((_sum_logP, i, j)) 56 | 57 | if not cand: 58 | break 59 | cand = sorted(cand, reverse=True) 60 | cand = cand[:min(len(cand), beam)] 61 | sum_logP = [] 62 | cur = [] 63 | nxt = [] 64 | for _sum_logP, i, j in cand: 65 | ix = picked_tokens[i, j].item() 66 | if ix == self.ix_EOS: 67 | seq = [w.item() for w in tokens[i, len_cxt: len_cxt + t]] 68 | seq_tup = tuple(seq) 69 | if seq_tup not in hyp_set: 70 | finished.append((np.exp(_sum_logP/len(seq)), seq)) 71 | hyp_set.add(seq_tup) 72 | continue 73 | 74 | cur.append(tokens[i:i+1,:]) 75 | nxt.append(picked_tokens[i:i+1, j]) 76 | sum_logP.append(_sum_logP) 77 | if len(cur) == beam: 78 | break 79 | 80 | if not cur: 81 | break 82 | tokens = torch.cat([torch.cat(cur, dim=0), torch.cat(nxt, dim=0).unsqueeze(-1)], dim=-1) 83 | 84 | finished = sorted(finished, reverse=True) 85 | ret = [] 86 | for prob, seq in finished: 87 | hyp = self.tokenizer.decode(seq).strip() 88 | ret.append((way, prob, hyp)) 89 | if len(ret) == beam: 90 | break 91 | return sorted(ret, reverse=True) 92 | 93 | 94 | 95 | def tf_prob(self, context, hyps, batch=10, return_np=True): 96 | if isinstance(hyps, str): 97 | hyps = [hyps] 98 | i0 = 0 99 | prob = [] 100 | while i0 < len(hyps): 101 | i1 = min(i0 + batch, len(hyps)) 102 | with torch.no_grad(): 103 | prob.append(self._tf_prob(context, hyps[i0:i1])) 104 | i0 = i1 105 | if len(prob) > 1: 106 | prob = torch.cat(prob, dim=0) 107 | else: 108 | prob = prob[0] 109 | if return_np: 110 | if self.use_cuda: 111 | prob = prob.cpu() 112 | return prob.detach().numpy() 113 | else: 114 | return prob 115 | 116 | 117 | def _tf_prob(self, context, hyps): 118 | # converted what's from tokenizer.encode to what's should be used in logits 119 | enc2pred = {} 120 | ids_cxt = self.tokenizer.encode(context) 121 | ids_hyp = [] 122 | hyp_len = [] 123 | for hyp in hyps: 124 | raw_hyp_tokens = self.tokenizer.encode(hyp) 125 | hyp_tokens = [] 126 | for token in raw_hyp_tokens: 127 | hyp_tokens.append(enc2pred.get(token, token)) 128 | ids_hyp.append(hyp_tokens) 129 | hyp_len.append(len(hyp_tokens)) 130 | 131 | max_len = max(hyp_len) 132 | ids = [] 133 | mask = [] 134 | for i, seq in enumerate(ids_hyp): 135 | cat = ids_cxt + seq + [self.ix_EOS] * (max_len - hyp_len[i]) 136 | ids.append(cat) 137 | mask.append([1] * hyp_len[i] + [0] * (max_len - hyp_len[i])) 138 | ids = torch.tensor(ids) 139 | mask = torch.FloatTensor(mask) 140 | hyp_len = torch.FloatTensor(hyp_len) 141 | if self.use_cuda: 142 | ids = ids.to('cuda') 143 | mask = mask.to('cuda') 144 | hyp_len = hyp_len.to('cuda') 145 | 146 | l_cxt = len(ids_cxt) 147 | with torch.no_grad(): 148 | logits, _ = self.model(ids) 149 | 
logits = logits[:, l_cxt - 1: -1, :] # only care the part after cxt. ignore -1. 150 | logP = torch.log(F.softmax(logits, dim=-1)) 151 | 152 | logP_ids = logP.gather(dim=-1, index=ids[:,l_cxt:].unsqueeze(-1)).squeeze(-1) 153 | avg_logP = (logP_ids * mask).sum(dim=-1) / hyp_len 154 | return torch.exp(avg_logP) 155 | 156 | 157 | def play_lm(): 158 | lm = LanguageModel() 159 | while True: 160 | cxt = input('\nCONTEXT:\t') 161 | if not cxt: 162 | break 163 | ret = lm.predict(cxt) 164 | for way, prob, hyp in ret: 165 | print('%s %.3f\t%s'%(way, prob, hyp.replace('\n',' '))) 166 | 167 | 168 | if __name__ == "__main__": 169 | play_lm() -------------------------------------------------------------------------------- /src/mrc.py: -------------------------------------------------------------------------------- 1 | #// Copyright (c) Microsoft Corporation.// Licensed under the MIT license. 2 | 3 | from allennlp.predictors.predictor import Predictor 4 | 5 | class BidafQA: 6 | def __init__(self): 7 | self.model = Predictor.from_path('models/BiDAF/bidaf-model-2020.03.19.tar.gz') 8 | 9 | def predict(self, query, passage): 10 | ret = self.model.predict(passage=passage, question=query) 11 | span_str = ret['best_span_str'] 12 | span_prob = ret['span_start_probs'][ret['best_span'][0]] * ret['span_end_probs'][ret['best_span'][1]] 13 | return [('Bidaf', span_prob, span_str)] 14 | 15 | def play_mrc(): 16 | model = BidafQA() 17 | while True: 18 | q = input('QUERY:\t') 19 | p = input('PASSAGE:\t') 20 | ret = model.predict(q, p) 21 | for way, prob, ans in ret: 22 | print('%s %.3f\t%s'%(way, prob, ans)) 23 | 24 | if __name__ == "__main__": 25 | play_mrc() 26 | -------------------------------------------------------------------------------- /src/onmt/Beam.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import onmt 5 | 6 | 7 | class Beam(object): 8 | def __init__(self, size, cuda=False): 9 | 10 | self.size = size 11 | self.done = False 12 | 13 | self.tt = torch.cuda if cuda else torch 14 | 15 | # The score for each translation on the beam. 16 | self.scores = self.tt.FloatTensor(size).zero_() 17 | 18 | # The backpointers at each time-step. 19 | self.prevKs = [] 20 | 21 | # The outputs at each time-step. 22 | self.nextYs = [self.tt.LongTensor(size).fill_(onmt.Constants.PAD)] 23 | self.nextYs[0][0] = onmt.Constants.BOS 24 | 25 | # The attentions (matrix) for each time. 26 | self.attn = [] 27 | 28 | # Get the outputs for the current timestep. 29 | def getCurrentState(self): 30 | return self.nextYs[-1] 31 | 32 | # Get the backpointers for the current timestep. 33 | def getCurrentOrigin(self): 34 | return self.prevKs[-1] 35 | 36 | def advance(self, wordLk, attnOut): 37 | 38 | numWords = wordLk.size(1) 39 | 40 | # Sum the previous scores. 41 | if len(self.prevKs) > 0: 42 | beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk) 43 | else: 44 | beamLk = wordLk[0] 45 | 46 | flatBeamLk = beamLk.view(-1) 47 | 48 | bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True) 49 | self.scores = bestScores 50 | 51 | # bestScoresId is flattened beam x word array, so calculate which 52 | # word and beam each score came from 53 | prevK = bestScoresId / numWords 54 | self.prevKs.append(prevK) 55 | self.nextYs.append(bestScoresId - prevK * numWords) 56 | self.attn.append(attnOut.index_select(0, prevK)) 57 | 58 | # End condition is when top-of-beam is EOS. 
59 | if self.nextYs[-1][0] == onmt.Constants.EOS: 60 | self.done = True 61 | 62 | return self.done 63 | 64 | def sortBest(self): 65 | return torch.sort(self.scores, 0, True) 66 | 67 | # Get the score of the best in the beam. 68 | def getBest(self): 69 | scores, ids = self.sortBest() 70 | return scores[1], ids[1] 71 | 72 | def getHyp(self, k): 73 | hyp, attn = [], [] 74 | # print(len(self.prevKs), len(self.nextYs), len(self.attn)) 75 | for j in range(len(self.prevKs) - 1, -1, -1): 76 | hyp.append(self.nextYs[j+1][k]) 77 | attn.append(self.attn[j][k]) 78 | k = self.prevKs[j][k] 79 | 80 | return hyp[::-1], torch.stack(attn[::-1]) 81 | -------------------------------------------------------------------------------- /src/onmt/Constants.py: -------------------------------------------------------------------------------- 1 | 2 | PAD = 0 3 | UNK = 1 4 | BOS = 2 5 | EOS = 3 6 | 7 | PAD_WORD = '<blank>' 8 | UNK_WORD = '<unk>' 9 | BOS_WORD = '<s>' 10 | EOS_WORD = '</s>' 11 | -------------------------------------------------------------------------------- /src/onmt/Dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import math 4 | import random 5 | 6 | import torch 7 | from torch.autograd import Variable 8 | 9 | import onmt 10 | 11 | 12 | class Dataset(object): 13 | 14 | def __init__(self, srcData, cxtData, tgtData, batchSize, cuda, volatile=False): 15 | self.src = srcData 16 | if cxtData: 17 | self.cxt = cxtData 18 | assert(len(self.src) == len(self.cxt)) 19 | else: 20 | self.cxt = None 21 | if tgtData: 22 | self.tgt = tgtData 23 | assert(len(self.src) == len(self.tgt)) 24 | else: 25 | self.tgt = None 26 | self.cuda = cuda 27 | 28 | self.batchSize = batchSize 29 | self.numBatches = math.ceil(len(self.src)/batchSize) 30 | self.volatile = volatile 31 | 32 | def _batchify(self, data, align_right=False, include_lengths=False): 33 | lengths = [x.size(0) for x in data] 34 | max_length = max(lengths) 35 | out = data[0].new(len(data), max_length).fill_(onmt.Constants.PAD) 36 | for i in range(len(data)): 37 | data_length = data[i].size(0) 38 | offset = max_length - data_length if align_right else 0 39 | out[i].narrow(0, offset, data_length).copy_(data[i]) 40 | 41 | if include_lengths: 42 | return out, lengths 43 | else: 44 | return out 45 | 46 | def __getitem__(self, index): 47 | assert index < self.numBatches, "%d > %d" % (index, self.numBatches) 48 | srcBatch, lengths = self._batchify( 49 | self.src[index*self.batchSize:(index+1)*self.batchSize], 50 | align_right=False, include_lengths=True) 51 | 52 | if self.cxt: 53 | cxtBatch = self._batchify( 54 | self.cxt[index*self.batchSize:(index+1)*self.batchSize], 55 | align_right=False) 56 | else: 57 | cxtBatch = None 58 | if self.tgt: 59 | tgtBatch = self._batchify( 60 | self.tgt[index*self.batchSize:(index+1)*self.batchSize]) 61 | else: 62 | tgtBatch = None 63 | 64 | # within batch sorting by decreasing length for variable length rnns 65 | indices = range(len(srcBatch)) 66 | if tgtBatch is None and cxtBatch is None: 67 | batch = zip(indices, srcBatch) 68 | batch, lengths = zip(*sorted(zip(batch, lengths), key=lambda x: -x[1])) 69 | indices, srcBatch = zip(*batch) 70 | elif tgtBatch is None: 71 | batch = zip(indices, srcBatch, cxtBatch) 72 | batch, lengths = zip(*sorted(zip(batch, lengths), key=lambda x: -x[1])) 73 | indices, srcBatch, cxtBatch = zip(*batch) 74 | else: 75 | batch = zip(indices, srcBatch, cxtBatch, tgtBatch) 76 | batch, lengths = zip(*sorted(zip(batch, lengths), key=lambda x:
-x[1])) 77 | indices, srcBatch, cxtBatch, tgtBatch = zip(*batch) 78 | 79 | def wrap(b): 80 | if b is None: 81 | return b 82 | b = torch.stack(b, 0).t().contiguous() 83 | if self.cuda: 84 | b = b.cuda() 85 | b = Variable(b, volatile=self.volatile) 86 | return b 87 | 88 | return (wrap(srcBatch), lengths), wrap(cxtBatch), wrap(tgtBatch), indices 89 | 90 | def __len__(self): 91 | return self.numBatches 92 | 93 | 94 | def shuffle(self): 95 | data = list(zip(self.src, self.cxt, self.tgt)) 96 | self.src, self.cxt, self.tgt = zip(*[data[i] for i in torch.randperm(len(data))]) 97 | -------------------------------------------------------------------------------- /src/onmt/Dict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import codecs 3 | 4 | class Dict(object): 5 | def __init__(self, data=None, lower=False): 6 | self.idxToLabel = {} 7 | self.labelToIdx = {} 8 | self.frequencies = {} 9 | self.lower = lower 10 | 11 | # Special entries will not be pruned. 12 | self.special = [] 13 | 14 | if data is not None: 15 | if type(data) == str: 16 | self.loadFile(data) 17 | else: 18 | self.addSpecials(data) 19 | 20 | def size(self): 21 | return len(self.idxToLabel) 22 | 23 | # Load entries from a file. 24 | def loadFile(self, filename): 25 | for line in open(filename): 26 | fields = line.split() 27 | if len(fields) > 2: 28 | idx = int(fields[-1]) 29 | label = ' '.join(fields[:-1]) 30 | else: 31 | label = fields[0] 32 | idx = int(fields[1]) 33 | self.add(label, idx) 34 | 35 | # Write entries to a file. 36 | def writeFile(self, filename): 37 | with codecs.open(filename, 'w', "utf-8") as file: 38 | for i in range(self.size()): 39 | label = self.idxToLabel[i] 40 | file.write('%s %d\n' % (label, i)) 41 | 42 | file.close() 43 | 44 | def lookup(self, key, default=None): 45 | key = key.lower() if self.lower else key 46 | try: 47 | return self.labelToIdx[key] 48 | except KeyError: 49 | return default 50 | 51 | def getLabel(self, idx, default=None): 52 | try: 53 | if torch.is_tensor(idx): 54 | idx = idx.item() 55 | return self.idxToLabel[idx] 56 | except KeyError: 57 | return default 58 | 59 | # Mark this `label` and `idx` as special (i.e. will not be pruned). 60 | def addSpecial(self, label, idx=None): 61 | idx = self.add(label, idx) 62 | self.special += [idx] 63 | 64 | # Mark all labels in `labels` as specials (i.e. will not be pruned). 65 | def addSpecials(self, labels): 66 | for label in labels: 67 | self.addSpecial(label) 68 | 69 | # Add `label` in the dictionary. Use `idx` as its index if given. 70 | def add(self, label, idx=None): 71 | label = label.lower() if self.lower else label 72 | if idx is not None: 73 | self.idxToLabel[idx] = label 74 | self.labelToIdx[label] = idx 75 | else: 76 | if label in self.labelToIdx: 77 | idx = self.labelToIdx[label] 78 | else: 79 | idx = len(self.idxToLabel) 80 | self.idxToLabel[idx] = label 81 | self.labelToIdx[label] = idx 82 | 83 | if idx not in self.frequencies: 84 | self.frequencies[idx] = 1 85 | else: 86 | self.frequencies[idx] += 1 87 | 88 | return idx 89 | 90 | # Return a new dictionary with the `size` most frequent entries. 91 | def prune(self, size): 92 | if size >= self.size(): 93 | return self 94 | 95 | # Only keep the `size` most frequent entries. 96 | freq = torch.Tensor( 97 | [self.frequencies[i] for i in range(len(self.frequencies))]) 98 | _, idx = torch.sort(freq, 0, True) 99 | 100 | newDict = Dict() 101 | newDict.lower = self.lower 102 | 103 | # Add special entries in all cases. 
104 | for i in self.special: 105 | newDict.addSpecial(self.idxToLabel[i]) 106 | 107 | for i in idx[:size]: 108 | newDict.add(self.idxToLabel[i]) 109 | 110 | return newDict 111 | 112 | # Convert `labels` to indices. Use `unkWord` if not found. 113 | # Optionally insert `bosWord` at the beginning and `eosWord` at the . 114 | def convertToIdx(self, labels, unkWord, bosWord=None, eosWord=None): 115 | vec = [] 116 | 117 | if bosWord is not None: 118 | vec += [self.lookup(bosWord)] 119 | 120 | unk = self.lookup(unkWord) 121 | vec += [self.lookup(label, default=unk) for label in labels] 122 | 123 | if eosWord is not None: 124 | vec += [self.lookup(eosWord)] 125 | 126 | return torch.LongTensor(vec) 127 | 128 | # Convert `idx` to labels. If index `stop` is reached, convert it and return. 129 | def convertToLabels(self, idx, stop): 130 | labels = [] 131 | 132 | for i in idx: 133 | labels += [self.getLabel(i)] 134 | if i == stop: 135 | break 136 | 137 | return labels 138 | -------------------------------------------------------------------------------- /src/onmt/Models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import onmt.modules 5 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 6 | from torch.nn.utils.rnn import pack_padded_sequence as pack 7 | 8 | class Encoder(nn.Module): 9 | 10 | def __init__(self, opt, dicts): 11 | self.layers = opt.layers 12 | self.num_directions = 2 if opt.brnn else 1 13 | assert opt.rnn_size % self.num_directions == 0 14 | self.hidden_size = opt.rnn_size // self.num_directions 15 | input_size = opt.word_vec_size 16 | 17 | super(Encoder, self).__init__() 18 | self.word_lut = nn.Embedding(dicts.size(), 19 | opt.word_vec_size, 20 | padding_idx=onmt.Constants.PAD) 21 | self.rnn = nn.LSTM(input_size, self.hidden_size, 22 | num_layers=opt.layers, 23 | dropout=opt.dropout, 24 | bidirectional=opt.brnn) 25 | 26 | def load_pretrained_vectors(self, opt): 27 | if opt.pre_word_vecs_enc is not None: 28 | pretrained = torch.load(opt.pre_word_vecs_enc) 29 | self.word_lut.weight.data.copy_(pretrained) 30 | 31 | def forward(self, input, hidden=None): 32 | if isinstance(input, tuple): 33 | #import pdb; pdb.set_trace() 34 | emb = pack(self.word_lut(input[0]), list(input[1])) 35 | else: 36 | emb = self.word_lut(input) 37 | outputs, hidden_t = self.rnn(emb, hidden) 38 | if isinstance(input, tuple): 39 | outputs = unpack(outputs)[0] 40 | return hidden_t, outputs 41 | 42 | 43 | class StackedLSTM(nn.Module): 44 | def __init__(self, num_layers, input_size, rnn_size, dropout): 45 | super(StackedLSTM, self).__init__() 46 | self.dropout = nn.Dropout(dropout) 47 | self.num_layers = num_layers 48 | self.layers = nn.ModuleList() 49 | 50 | for i in range(num_layers): 51 | self.layers.append(nn.LSTMCell(input_size, rnn_size)) 52 | input_size = rnn_size 53 | 54 | def forward(self, input, hidden): 55 | h_0, c_0 = hidden 56 | h_1, c_1 = [], [] 57 | for i, layer in enumerate(self.layers): 58 | h_1_i, c_1_i = layer(input, (h_0[i], c_0[i])) 59 | input = h_1_i 60 | if i + 1 != self.num_layers: 61 | input = self.dropout(input) 62 | h_1 += [h_1_i] 63 | c_1 += [c_1_i] 64 | 65 | h_1 = torch.stack(h_1) 66 | c_1 = torch.stack(c_1) 67 | 68 | return input, (h_1, c_1) 69 | 70 | 71 | class Decoder(nn.Module): 72 | 73 | def __init__(self, opt, dicts): 74 | self.layers = opt.layers 75 | self.input_feed = opt.input_feed 76 | self.context = opt.add_context 77 | input_size = 
opt.word_vec_size 78 | if self.input_feed: 79 | input_size += opt.rnn_size 80 | if self.context: 81 | input_size += opt.rnn_size 82 | 83 | super(Decoder, self).__init__() 84 | self.word_lut = nn.Embedding(dicts.size(), 85 | opt.word_vec_size, 86 | padding_idx=onmt.Constants.PAD) 87 | self.rnn = StackedLSTM(opt.layers, input_size, opt.rnn_size, opt.dropout) 88 | self.attn = onmt.modules.GlobalAttention(opt.rnn_size) 89 | self.dropout = nn.Dropout(opt.dropout) 90 | 91 | self.hidden_size = opt.rnn_size 92 | 93 | def load_pretrained_vectors(self, opt): 94 | if opt.pre_word_vecs_dec is not None: 95 | pretrained = torch.load(opt.pre_word_vecs_dec) 96 | self.word_lut.weight.data.copy_(pretrained) 97 | 98 | def forward(self, input, hidden, context, hidden_cxt, init_output): 99 | emb = self.word_lut(input) 100 | #print(context.size()) 101 | outputs = [] 102 | output = init_output 103 | for emb_t in emb.split(1): 104 | emb_t = emb_t.squeeze(0) 105 | if self.input_feed: 106 | emb_t = torch.cat([emb_t, output], 1) 107 | if self.context: 108 | emb_t = torch.cat([emb_t, hidden_cxt], 1) 109 | 110 | output, hidden = self.rnn(emb_t, hidden) 111 | output, attn = self.attn(output, context.transpose(0, 1)) 112 | output = self.dropout(output) 113 | outputs += [output] 114 | 115 | outputs = torch.stack(outputs) 116 | return outputs, hidden, attn 117 | 118 | 119 | class NMTModel(nn.Module): 120 | 121 | def __init__(self, encoder_src, encoder_cxt, decoder): 122 | super(NMTModel, self).__init__() 123 | self.encoder_src = encoder_src 124 | self.encoder_cxt = encoder_cxt 125 | self.decoder = decoder 126 | 127 | def make_init_decoder_output(self, context): 128 | batch_size = context.size(1) 129 | h_size = (batch_size, self.decoder.hidden_size) 130 | return Variable(context.data.new(*h_size).zero_(), requires_grad=False) 131 | 132 | def _fix_enc_hidden(self, h): 133 | # the encoder hidden is (layers*directions) x batch x dim 134 | # we need to convert it to layers x batch x (directions*dim) 135 | if self.encoder_src.num_directions == 2: 136 | return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \ 137 | .transpose(1, 2).contiguous() \ 138 | .view(h.size(0) // 2, h.size(1), h.size(2) * 2) 139 | else: 140 | return h 141 | 142 | def forward(self, input): 143 | src = input[0] 144 | cxt = input[1] 145 | tgt = input[2][:-1] # exclude last target from inputs 146 | enc_hidden_src, context_src = self.encoder_src(src) 147 | if self.encoder_cxt: 148 | enc_hidden_cxt, _ = self.encoder_cxt(cxt) 149 | if enc_hidden_cxt[0].size(1) == 1: 150 | enc_hidden_cxt = torch.cat((enc_hidden_cxt[0][-1], \ 151 | enc_hidden_cxt[0][-2]), 1) 152 | else: 153 | enc_hidden_cxt = torch.cat((enc_hidden_cxt[0][-1].squeeze(0), \ 154 | enc_hidden_cxt[0][-2].squeeze(0)), 1) 155 | else: 156 | enc_hidden_cxt = None 157 | #print(context_cxt.size()) 158 | #print(enc_hidden_cxt.size()) 159 | #print(enc_hidden_cxt[1].size()) 160 | init_output = self.make_init_decoder_output(context_src) 161 | 162 | enc_hidden_src = (self._fix_enc_hidden(enc_hidden_src[0]), 163 | self._fix_enc_hidden(enc_hidden_src[1])) 164 | 165 | out, dec_hidden, _attn = self.decoder(tgt, enc_hidden_src, context_src, enc_hidden_cxt, init_output) 166 | #print(context_src.size()) 167 | 168 | return out 169 | -------------------------------------------------------------------------------- /src/onmt/Optim.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.optim as optim 3 | import torch.nn as nn 4 | from torch.nn.utils import 
clip_grad_norm 5 | 6 | class Optim(object): 7 | 8 | def set_parameters(self, params): 9 | self.params = list(params) # careful: params may be a generator 10 | if self.method == 'sgd': 11 | self.optimizer = optim.SGD(self.params, lr=self.lr) 12 | elif self.method == 'adagrad': 13 | self.optimizer = optim.Adagrad(self.params, lr=self.lr) 14 | elif self.method == 'adadelta': 15 | self.optimizer = optim.Adadelta(self.params, lr=self.lr) 16 | elif self.method == 'adam': 17 | self.optimizer = optim.Adam(self.params, lr=self.lr) 18 | else: 19 | raise RuntimeError("Invalid optim method: " + self.method) 20 | 21 | def __init__(self, method, lr, max_grad_norm, lr_decay=1, start_decay_at=None): 22 | self.last_ppl = None 23 | self.lr = lr 24 | self.max_grad_norm = max_grad_norm 25 | self.method = method 26 | self.lr_decay = lr_decay 27 | self.start_decay_at = start_decay_at 28 | self.start_decay = False 29 | 30 | def step(self): 31 | # Compute gradients norm. 32 | if self.max_grad_norm: 33 | clip_grad_norm(self.params, self.max_grad_norm) 34 | self.optimizer.step() 35 | 36 | # decay learning rate if val perf does not improve or we hit the start_decay_at limit 37 | def updateLearningRate(self, ppl, epoch): 38 | if self.start_decay_at is not None and epoch >= self.start_decay_at: 39 | self.start_decay = True 40 | if self.last_ppl is not None and ppl > self.last_ppl: 41 | self.start_decay = True 42 | 43 | if self.start_decay: 44 | self.lr = self.lr * self.lr_decay 45 | print("Decaying learning rate to %g" % self.lr) 46 | 47 | self.last_ppl = ppl 48 | self.optimizer.param_groups[0]['lr'] = self.lr 49 | -------------------------------------------------------------------------------- /src/onmt/Translator.py: -------------------------------------------------------------------------------- 1 | import onmt 2 | import torch.nn as nn 3 | import torch 4 | from torch.autograd import Variable 5 | 6 | 7 | class Translator(object): 8 | def __init__(self, opt): 9 | self.opt = opt 10 | self.tt = torch.cuda if opt.cuda else torch 11 | 12 | checkpoint = torch.load(opt.path_model) 13 | 14 | model_opt = checkpoint['opt'] 15 | self.src_dict = checkpoint['dicts']['src'] 16 | self.cxt_dict = checkpoint['dicts']['cxt'] 17 | self.tgt_dict = checkpoint['dicts']['tgt'] 18 | 19 | encoder_src = onmt.Models.Encoder(model_opt, self.src_dict) 20 | encoder_src.load_state_dict(checkpoint['encoder_src']) 21 | encoder_cxt = onmt.Models.Encoder(model_opt, self.cxt_dict) 22 | encoder_src.load_state_dict(checkpoint['encoder_src']) 23 | decoder = onmt.Models.Decoder(model_opt, self.tgt_dict) 24 | decoder.load_state_dict(checkpoint['decoder']) 25 | model = onmt.Models.NMTModel(encoder_src, encoder_cxt, decoder) 26 | 27 | generator = nn.Sequential( 28 | nn.Linear(model_opt.rnn_size, self.tgt_dict.size()), 29 | nn.LogSoftmax()) 30 | 31 | generator.load_state_dict(checkpoint['generator']) 32 | 33 | if opt.cuda: 34 | model.cuda() 35 | generator.cuda() 36 | else: 37 | model.cpu() 38 | generator.cpu() 39 | 40 | model.generator = generator 41 | 42 | self.model = model 43 | self.model.eval() 44 | 45 | 46 | def buildData(self, srcBatch, cxtBatch, goldBatch): 47 | srcData = [self.src_dict.convertToIdx(b, 48 | onmt.Constants.UNK_WORD) for b in srcBatch] 49 | cxtData = None 50 | if cxtBatch: 51 | cxtData = [self.cxt_dict.convertToIdx(b, 52 | onmt.Constants.UNK_WORD) for b in cxtBatch] 53 | tgtData = None 54 | if goldBatch: 55 | tgtData = [self.tgt_dict.convertToIdx(b, 56 | onmt.Constants.UNK_WORD, 57 | onmt.Constants.BOS_WORD, 58 | 
onmt.Constants.EOS_WORD) for b in goldBatch] 59 | 60 | return onmt.Dataset(srcData, cxtData, tgtData, 61 | self.opt.batch_size, self.opt.cuda, volatile=True) 62 | 63 | def buildTargetTokens(self, pred, src, attn): 64 | tokens = self.tgt_dict.convertToLabels(pred, onmt.Constants.EOS) 65 | tokens = tokens[:-1] # EOS 66 | if self.opt.replace_unk: 67 | for i in range(len(tokens)): 68 | if tokens[i] == onmt.Constants.UNK_WORD: 69 | _, maxIndex = attn[i].max(0) 70 | tokens[i] = src[maxIndex[0]] 71 | return tokens 72 | 73 | def translateBatch(self, srcBatch, cxtBatch, tgtBatch): 74 | batchSize = srcBatch[0].size(1) 75 | beamSize = self.opt.beam_size 76 | #import pdb; pdb.set_trace() 77 | # (1) run the encoder on the src 78 | encStates_src, context_src = self.model.encoder_src(srcBatch) 79 | encStates_cxt, context_cxt = self.model.encoder_src(cxtBatch) 80 | encStates_cxt = torch.cat((encStates_cxt[0][-1], encStates_cxt[0][-2]), 1) 81 | srcBatch = srcBatch[0] # drop the lengths needed for encoder 82 | 83 | rnnSize = context_src.size(2) 84 | encStates_src = (self.model._fix_enc_hidden(encStates_src[0]), 85 | self.model._fix_enc_hidden(encStates_src[1])) 86 | 87 | # This mask is applied to the attention model inside the decoder 88 | # so that the attention ignores source padding 89 | padMask = srcBatch.data.eq(onmt.Constants.PAD).t() 90 | def applyContextMask(m): 91 | if isinstance(m, onmt.modules.GlobalAttention): 92 | m.applyMask(padMask) 93 | 94 | # (2) if a target is specified, compute the 'goldScore' 95 | # (i.e. log likelihood) of the target under the model 96 | goldScores = context_src.data.new(batchSize).zero_() 97 | if tgtBatch is not None: 98 | decStates = encStates_src 99 | decOut = self.model.make_init_decoder_output(context_src) 100 | self.model.decoder.apply(applyContextMask) 101 | initOutput = self.model.make_init_decoder_output(context_src) 102 | 103 | decOut, decStates, attn = self.model.decoder( 104 | tgtBatch[:-1], decStates, context_src, encStates_cxt, initOutput) 105 | for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data): 106 | gen_t = self.model.generator.forward(dec_t) 107 | tgt_t = tgt_t.unsqueeze(1) 108 | scores = gen_t.data.gather(1, tgt_t) 109 | scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0) 110 | goldScores += scores 111 | 112 | # (3) run the decoder to generate sentences, using beam search 113 | 114 | # Expand tensors for each beam. 115 | context_src = Variable(context_src.data.repeat(1, beamSize, 1)) 116 | decStates = (Variable(encStates_src[0].data.repeat(1, beamSize, 1)), 117 | Variable(encStates_src[1].data.repeat(1, beamSize, 1))) 118 | 119 | encStates_cxt = Variable(encStates_cxt.data.repeat(beamSize, 1)) 120 | 121 | beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)] 122 | 123 | decOut = self.model.make_init_decoder_output(context_src) 124 | 125 | padMask = srcBatch.data.eq(onmt.Constants.PAD).t().unsqueeze(0).repeat(beamSize, 1, 1) 126 | 127 | batchIdx = list(range(batchSize)) 128 | remainingSents = batchSize 129 | for i in range(self.opt.max_sent_length): 130 | 131 | self.model.decoder.apply(applyContextMask) 132 | 133 | # Prepare decoder input. 
134 | input = torch.stack([b.getCurrentState() for b in beam 135 | if not b.done]).t().contiguous().view(1, -1) 136 | 137 | #import pdb; pdb.set_trace() 138 | decOut, decStates, attn = self.model.decoder(Variable(input, volatile=True), decStates, context_src, encStates_cxt, decOut) 139 | # decOut: 1 x (beam*batch) x numWords 140 | decOut = decOut.squeeze(0) 141 | out = self.model.generator.forward(decOut) 142 | 143 | # batch x beam x numWords 144 | wordLk = out.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous() 145 | attn = attn.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous() 146 | 147 | active = [] 148 | for b in range(batchSize): 149 | if beam[b].done: 150 | continue 151 | 152 | idx = batchIdx[b] 153 | if not beam[b].advance(wordLk.data[idx], attn.data[idx]): 154 | active += [b] 155 | 156 | for decState in decStates: # iterate over h, c 157 | # layers x beam*sent x dim 158 | sentStates = decState.view( 159 | -1, beamSize, remainingSents, decState.size(2))[:, :, idx] 160 | sentStates.data.copy_( 161 | sentStates.data.index_select(1, beam[b].getCurrentOrigin())) 162 | 163 | if not active: 164 | break 165 | 166 | # in this section, the sentences that are still active are 167 | # compacted so that the decoder is not run on completed sentences 168 | activeIdx = self.tt.LongTensor([batchIdx[k] for k in active]) 169 | batchIdx = {beam: idx for idx, beam in enumerate(active)} 170 | 171 | def updateActive(t): 172 | # select only the remaining active sentences 173 | view = t.data.view(-1, remainingSents, rnnSize) 174 | newSize = list(t.size()) 175 | newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents 176 | return Variable(view.index_select(1, activeIdx) \ 177 | .view(*newSize), volatile=True) 178 | 179 | decStates = (updateActive(decStates[0]), updateActive(decStates[1])) 180 | encStates_cxt = updateActive(encStates_cxt) 181 | decOut = updateActive(decOut) 182 | context_src = updateActive(context_src) 183 | padMask = padMask.index_select(1, activeIdx) 184 | 185 | remainingSents = len(active) 186 | 187 | # (4) package everything up 188 | 189 | allHyp, allScores, allAttn = [], [], [] 190 | n_best = self.opt.n_best 191 | 192 | for b in range(batchSize): 193 | scores, ks = beam[b].sortBest() 194 | 195 | allScores += [scores[:n_best]] 196 | valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD).nonzero().squeeze(1) 197 | hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]]) 198 | attn = [a.index_select(1, valid_attn) for a in attn] 199 | allHyp += [hyps] 200 | allAttn += [attn] 201 | 202 | return allHyp, allScores, allAttn, goldScores 203 | 204 | def translate(self, srcBatch, cxtBatch, goldBatch): 205 | # (1) convert words to indexes 206 | dataset = self.buildData(srcBatch, cxtBatch, goldBatch) 207 | src, cxt, tgt, indices = dataset[0] 208 | 209 | # src = ([L,batch], (L0,L1,...)) 210 | 211 | # (2) translate 212 | pred, predScore, attn, goldScore = self.translateBatch(src, cxt, tgt) 213 | pred, predScore, attn, goldScore = list(zip(*sorted(zip(pred, predScore, attn, goldScore, indices), key=lambda x: x[-1])))[:-1] 214 | 215 | # (3) convert indexes to words 216 | predBatch = [] 217 | for b in range(src[0].size(1)): 218 | predBatch.append( 219 | [self.buildTargetTokens(pred[b][n], srcBatch[b], attn[b][n]) 220 | for n in range(self.opt.n_best)] 221 | ) 222 | 223 | return predBatch, predScore, goldScore 224 | -------------------------------------------------------------------------------- /src/onmt/__init__.py: 
-------------------------------------------------------------------------------- 1 | import onmt.Constants 2 | import onmt.Models 3 | from onmt.Translator import Translator as Translator 4 | from onmt.Dataset import Dataset 5 | from onmt.Optim import Optim 6 | from onmt.Dict import Dict 7 | from onmt.Beam import Beam 8 | 9 | """ 10 | commit https://github.com/shrimai/Towards-Content-Transfer-through-Grounded-Text-Generation/commit/dd2246f557c39aee69711aabc1e3540a7c9f27a7 11 | some bugs-fixing/modification by Xiang Gao @ Microsoft Research 12 | """ -------------------------------------------------------------------------------- /src/onmt/modules/GlobalAttention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Global attention takes a matrix and a query vector. It 3 | then computes a parameterized convex combination of the matrix 4 | based on the input query. 5 | 6 | 7 | H_1 H_2 H_3 ... H_n 8 | q q q q 9 | | | | | 10 | \ | | / 11 | ..... 12 | \ | / 13 | a 14 | 15 | Constructs a unit mapping. 16 | $$(H_1 + H_n, q) => (a)$$ 17 | Where H is of `batch x n x dim` and q is of `batch x dim`. 18 | 19 | The full def is $$\tanh(W_2 [(softmax((W_1 q + b_1) H) H), q] + b_2)$$.: 20 | 21 | """ 22 | 23 | import torch 24 | import torch.nn as nn 25 | import math 26 | 27 | class GlobalAttention(nn.Module): 28 | def __init__(self, dim): 29 | super(GlobalAttention, self).__init__() 30 | self.linear_in = nn.Linear(dim, dim, bias=False) 31 | self.sm = nn.Softmax() 32 | self.linear_out = nn.Linear(dim*2, dim, bias=False) 33 | self.tanh = nn.Tanh() 34 | self.mask = None 35 | 36 | def applyMask(self, mask): 37 | self.mask = mask 38 | 39 | def forward(self, input, context): 40 | """ 41 | input: batch x dim 42 | context: batch x sourceL x dim 43 | """ 44 | targetT = self.linear_in(input).unsqueeze(2) # batch x dim x 1 45 | 46 | # Get attention 47 | attn = torch.bmm(context, targetT).squeeze(2) # batch x sourceL 48 | #import pdb; pdb.set_trace() 49 | if self.mask is not None: 50 | attn.data.masked_fill_(self.mask.view(-1, attn.shape[-1]), -float('inf')) 51 | attn = self.sm(attn) 52 | attn3 = attn.view(attn.size(0), 1, attn.size(1)) # batch x 1 x sourceL 53 | 54 | weightedContext = torch.bmm(attn3, context).squeeze(1) # batch x dim 55 | contextCombined = torch.cat((weightedContext, input), 1) 56 | 57 | contextOutput = self.tanh(self.linear_out(contextCombined)) 58 | 59 | return contextOutput, attn 60 | -------------------------------------------------------------------------------- /src/onmt/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from onmt.modules.GlobalAttention import GlobalAttention 2 | -------------------------------------------------------------------------------- /src/open_dialog.py: -------------------------------------------------------------------------------- 1 | #// Copyright (c) Microsoft Corporation.// Licensed under the MIT license. 
2 | 3 | import torch, os, pdb 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config, BertTokenizer 7 | import time 8 | from todo import pick_tokens 9 | 10 | 11 | class DialoGPT: 12 | 13 | def __init__(self, use_cuda=True, path_model='models/DialoGPT/medium_ft.pkl'): 14 | self.use_cuda = use_cuda 15 | self.turn_sep = ' <|endoftext|> ' 16 | self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2') 17 | model_config = GPT2Config(n_embd=1024, n_layer=24, n_head=16) 18 | self.model = GPT2LMHeadModel(model_config) 19 | weights = torch.load(path_model) 20 | weights["lm_head.weight"] = weights["lm_head.decoder.weight"] 21 | weights.pop("lm_head.decoder.weight",None) 22 | self.model.load_state_dict(weights) 23 | if self.use_cuda: 24 | self.model = self.model.cuda() 25 | self.ix_EOS = 50256 26 | self.way = 'DPT' 27 | self.model.eval() 28 | 29 | 30 | def tf_prob(self, context, hyps, use_EOS=True, batch=10, return_np=True): 31 | if isinstance(hyps, str): 32 | hyps = [hyps] 33 | i0 = 0 34 | prob = [] 35 | while i0 < len(hyps): 36 | i1 = min(i0 + batch, len(hyps)) 37 | with torch.no_grad(): 38 | prob.append(self._tf_prob(context, hyps[i0:i1], use_EOS=use_EOS)) 39 | i0 = i1 40 | if len(prob) > 1: 41 | prob = torch.cat(prob, dim=0) 42 | else: 43 | prob = prob[0] 44 | if return_np: 45 | if self.use_cuda: 46 | prob = prob.cpu() 47 | return prob.detach().numpy() 48 | else: 49 | return prob 50 | 51 | 52 | def _tf_prob(self, context, hyps, use_EOS=True): 53 | # converted what's from tokenizer.encode to what's should be used in logits 54 | enc2pred = { 55 | 11:837, # ',' => 'Ġ,' 56 | 13:764, # '.' => 'Ġ.' 57 | 0:5145, # '!' => 'Ġ!' 58 | 30:5633, # '?' => 'Ġ?' 59 | } 60 | ids_cxt = self.tokenizer.encode(context) + [self.ix_EOS] 61 | ids_hyp = [] 62 | hyp_len = [] 63 | for hyp in hyps: 64 | raw_hyp_tokens = self.tokenizer.encode(hyp) 65 | 66 | # if not use_EOS, then hyps are some incomplete hyps, as in decoding with cross-model scoring 67 | if use_EOS: 68 | raw_hyp_tokens.append(self.ix_EOS) 69 | hyp_tokens = [] 70 | for token in raw_hyp_tokens: 71 | hyp_tokens.append(enc2pred.get(token, token)) 72 | ids_hyp.append(hyp_tokens) 73 | hyp_len.append(len(hyp_tokens)) 74 | 75 | max_len = max(hyp_len) 76 | ids = [] 77 | mask = [] 78 | for i, seq in enumerate(ids_hyp): 79 | cat = ids_cxt + seq + [self.ix_EOS] * (max_len - hyp_len[i]) 80 | ids.append(cat) 81 | mask.append([1] * hyp_len[i] + [0] * (max_len - hyp_len[i])) 82 | ids = torch.tensor(ids) 83 | mask = torch.FloatTensor(mask) 84 | hyp_len = torch.FloatTensor(hyp_len) 85 | if self.use_cuda: 86 | ids = ids.to('cuda') 87 | mask = mask.to('cuda') 88 | hyp_len = hyp_len.to('cuda') 89 | 90 | l_cxt = len(ids_cxt) 91 | with torch.no_grad(): 92 | logits, _ = self.model(ids) 93 | logits = logits[:, l_cxt - 1: -1, :] # only care the part after cxt. ignore -1. 
94 | logP = torch.log(F.softmax(logits, dim=-1)) 95 | 96 | logP_ids = logP.gather(dim=-1, index=ids[:,l_cxt:].unsqueeze(-1)).squeeze(-1) 97 | avg_logP = (logP_ids * mask).sum(dim=-1) / hyp_len 98 | return torch.exp(avg_logP) 99 | 100 | 101 | def rvs_prob(self, cxt, hyps, batch=10): 102 | i0 = 0 103 | prob = [] 104 | while i0 < len(hyps): 105 | i1 = min(i0 + batch, len(hyps)) 106 | with torch.no_grad(): 107 | prob.append(self._rvs_prob(cxt, hyps[i0:i1])) 108 | i0 = i1 109 | return np.concatenate(prob, axis=0) 110 | 111 | 112 | def _rvs_prob(self, context, hyps): 113 | # converted what's from tokenizer.encode to what's should be used in logits 114 | enc2pred = { 115 | 11:837, # ',' => 'Ġ,' 116 | 13:764, # '.' => 'Ġ.' 117 | 0:5145, # '!' => 'Ġ!' 118 | 30:5633, # '?' => 'Ġ?' 119 | } 120 | 121 | raw_ids_cxt = self.tokenizer.encode(context) + [self.ix_EOS] 122 | ids_cxt = [] 123 | for token in raw_ids_cxt: 124 | ids_cxt.append(enc2pred.get(token, token)) 125 | 126 | ids_hyp = [] 127 | hyp_len = [] 128 | for hyp in hyps: 129 | hyp = (' ' + hyp + ' ').replace(' i ',' I ') 130 | hyp = hyp.strip().replace(" '","'") 131 | hyp = hyp[0].upper() + hyp[1:] 132 | 133 | hyp_tokens = self.tokenizer.encode(hyp) + [self.ix_EOS] 134 | ids_hyp.append(hyp_tokens) 135 | hyp_len.append(len(hyp_tokens)) 136 | 137 | max_len = max(hyp_len) 138 | ids = [] 139 | for i, seq in enumerate(ids_hyp): 140 | cat = seq + ids_cxt + [self.ix_EOS] * (max_len - hyp_len[i]) 141 | ids.append(cat) 142 | ids = torch.tensor(ids) 143 | if self.use_cuda: 144 | ids = ids.to('cuda') 145 | with torch.no_grad(): 146 | logits, _ = self.model(ids) 147 | logP = torch.log(F.softmax(logits, dim=-1)) 148 | 149 | logP_cxt = [] 150 | for i, l in enumerate(hyp_len): 151 | _logP = [] 152 | for t, token in enumerate(ids_cxt): 153 | _logP.append(logP[i, l + t - 1, token].item()) 154 | logP_cxt.append(np.mean(_logP)) 155 | return np.exp(logP_cxt) 156 | 157 | 158 | def predict(self, context, beam=10): 159 | # return n hypotheses given context, in parallel 160 | # context is str 161 | way = self.way 162 | 163 | conditioned_tokens = self.tokenizer.encode(context) + [self.ix_EOS] 164 | len_cxt = len(conditioned_tokens) 165 | tokens = torch.tensor([conditioned_tokens]).view(1, -1) 166 | if self.use_cuda: 167 | tokens = tokens.cuda() 168 | 169 | finished = [] 170 | hyp_set = set() 171 | sum_logP = [0] 172 | max_t = 30 173 | for t in range(max_t): 174 | with torch.no_grad(): 175 | outputs = self.model(tokens) 176 | predictions = outputs[0] 177 | logits = predictions[:, -1, :] # only care the last step. 
[n_hyp, vocab] 178 | prob = F.softmax(logits, dim=-1) 179 | logP = torch.log(prob) 180 | picked_tokens = pick_tokens(prob) 181 | 182 | cand = [] 183 | #tokens_np = (tokens.cpu() if self.use_cuda else tokens).detach().numpy() 184 | for i in range(picked_tokens.shape[0]): 185 | for j in range(picked_tokens.shape[1]): 186 | ix = picked_tokens[i, j].item() 187 | _sum_logP = sum_logP[i] + logP[i, ix].item() 188 | cand.append((_sum_logP, i, j)) 189 | 190 | if not cand: 191 | break 192 | cand = sorted(cand, reverse=True) 193 | cand = cand[:min(len(cand), beam)] 194 | sum_logP = [] 195 | cur = [] 196 | nxt = [] 197 | for _sum_logP, i, j in cand: 198 | ix = picked_tokens[i, j].item() 199 | if ix == self.ix_EOS: 200 | seq = [w.item() for w in tokens[i, len_cxt: len_cxt + t]] 201 | seq_tup = tuple(seq) 202 | if seq_tup not in hyp_set: 203 | finished.append((np.exp(_sum_logP/len(seq)), seq)) 204 | hyp_set.add(seq_tup) 205 | continue 206 | 207 | cur.append(tokens[i:i+1,:]) 208 | nxt.append(picked_tokens[i:i+1, j]) 209 | sum_logP.append(_sum_logP) 210 | if len(cur) == beam: 211 | break 212 | 213 | if not cur: 214 | break 215 | tokens = torch.cat([torch.cat(cur, dim=0), torch.cat(nxt, dim=0).unsqueeze(-1)], dim=-1) 216 | 217 | finished = sorted(finished, reverse=True) 218 | ret = [] 219 | for prob, seq in finished: 220 | hyp = self.tokenizer.decode(seq).strip() 221 | ret.append((way, prob, hyp)) 222 | if len(ret) == beam: 223 | break 224 | return sorted(ret, reverse=True) 225 | 226 | 227 | def play_dpt(): 228 | dialogpt = DialoGPT() 229 | while True: 230 | print('\n(empty query to exit)') 231 | cxt = input('CONTEXT:\t') 232 | if not cxt: 233 | break 234 | ret = dialogpt.predict(cxt) 235 | for way, prob, hyp in ret: 236 | print('%s %.3f\t%s'%(way, prob, hyp)) 237 | 238 | 239 | if __name__ == "__main__": 240 | play_dpt() -------------------------------------------------------------------------------- /src/ranker.py: -------------------------------------------------------------------------------- 1 | #// Copyright (c) Microsoft Corporation.// Licensed under the MIT license. 
2 | 3 | import numpy as np 4 | from shared import alnum_only 5 | import os 6 | 7 | 8 | class ScorerRepetition: 9 | # measuring repetition penalty, proposed in https://arxiv.org/abs/2005.08365 10 | 11 | def predict(self, txts): 12 | scores = [] 13 | for txt in txts: 14 | ww = [] 15 | for w in alnum_only(txt).split(): 16 | if w: 17 | ww.append(w) 18 | if not ww: 19 | scores.append(0); continue # empty text: score 0 instead of aborting the whole batch 20 | rep = 1 - len(set(ww)) / len(ww) 21 | scores.append(- rep) 22 | return scores 23 | 24 | 25 | class ScorerInfo: 26 | # measuring informativeness, proposed in https://arxiv.org/abs/2005.08365 27 | 28 | def __init__(self): 29 | fld = 'models/info' 30 | os.makedirs(fld, exist_ok=True) 31 | self.path = 'src/common.txt' 32 | 33 | 34 | def load(self): 35 | self.w2rank = dict() 36 | for i, line in enumerate(open(self.path)): 37 | for w in line.strip('\n').split(): 38 | self.w2rank[w] = i 39 | self.max_rank = i 40 | 41 | 42 | def score(self, w): 43 | rank = self.w2rank.get(w, self.max_rank) 44 | return rank/self.max_rank 45 | 46 | 47 | def train(self, path_corpus, max_n=1e6): 48 | from collections import Counter, defaultdict 49 | counter = Counter() 50 | n = 0 51 | for line in open(path_corpus, encoding='utf-8'): 52 | ww = alnum_only(line).split() 53 | for w in ww: 54 | if w: 55 | counter[w] += 1 56 | n += 1 57 | if n == max_n: 58 | break 59 | 60 | freq_ww = defaultdict(list) 61 | for w, freq in counter.most_common(): 62 | if freq < 5: 63 | break 64 | freq_ww[freq].append(w) 65 | 66 | lines = [] 67 | for freq in sorted(list(freq_ww.keys()), reverse=True): 68 | lines.append(' '.join(freq_ww[freq])) 69 | 70 | with open(self.path, 'w', encoding='utf-8') as f: 71 | f.write('\n'.join(lines)) 72 | 73 | 74 | def predict(self, txts): 75 | scores = [] 76 | for txt in txts: 77 | score = [] 78 | ww = set(alnum_only(txt).split()) 79 | for w in ww: 80 | if w: 81 | score.append(self.score(w)) 82 | scores.append(np.mean(score) if score else 0) # guard against empty text 83 | return scores 84 | 85 | 86 | class Ranker: 87 | 88 | def __init__(self, scorer_fwd=None, scorer_rvs=None, scorer_style=None): 89 | self.scorer_fwd = scorer_fwd 90 | self.scorer_rvs = scorer_rvs 91 | self.scorer_style = scorer_style 92 | 93 | self.scorer_rep = ScorerRepetition() 94 | self.scorer_info = ScorerInfo() 95 | self.scorer_info.load() 96 | 97 | 98 | def predict(self, cxt, hyps): 99 | info = self.scorer_info.predict(hyps) 100 | rep = self.scorer_rep.predict(hyps) 101 | if self.scorer_fwd is not None: 102 | prob_fwd = self.scorer_fwd.tf_prob(cxt, hyps) 103 | if self.scorer_rvs is not None: 104 | prob_rvs = self.scorer_rvs.rvs_prob(cxt, hyps) 105 | if self.scorer_style is not None: 106 | style = self.scorer_style.predict(hyps) 107 | 108 | scored = [] 109 | for i in range(len(hyps)): 110 | d = { 111 | 'rep': rep[i], 112 | 'info': info[i], 113 | } 114 | if self.scorer_fwd is not None: 115 | d['fwd'] = float(prob_fwd[i]) 116 | if self.scorer_rvs is not None: 117 | d['rvs'] = float(prob_rvs[i]) 118 | if self.scorer_style is not None: 119 | d['style'] = float(style[i]) 120 | score = sum([d[k] for k in d]) 121 | d['score'] = score + np.random.random() * 1e-8 # to avoid exactly the same score 122 | scored.append(d) 123 | 124 | return scored 125 | 126 | 127 | def play_ranker(): 128 | ranker = Ranker() 129 | while True: 130 | txt = input('\nTXT:\t') 131 | if not txt: 132 | break 133 | scored = ranker.predict('', [txt])[0] 134 | print(' '.join(['%s %.4f'%(k, scored[k]) for k in scored])) 135 | 136 | 137 | if __name__ == "__main__": 138 | play_ranker() 139 | 
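`Ranker.predict` above sums whatever per-hypothesis scores are available (repetition, informativeness, plus optional forward/reverse model probabilities and style). The following is a minimal usage sketch, not part of the repository: it assumes the working directory is the repo root with `src/` on `PYTHONPATH`, that `models/DialoGPT/medium_ft.pkl` (the default path in `DialoGPT.__init__`) is in place, and that `pick_tokens` in `src/todo.py` has been implemented (one possible sketch for it is given after `src/tts.py` at the end of this listing). Reusing the same forward model as the reverse scorer is purely illustrative; a separately trained reverse model fits the same interface.
```
# Usage sketch only (not part of the repo): rank DialoGPT hypotheses with Ranker.
# Assumed: run from the repo root, src/ on PYTHONPATH, DialoGPT weights available,
# and pick_tokens implemented in src/todo.py.
import torch
from open_dialog import DialoGPT
from ranker import Ranker

dialogpt = DialoGPT(use_cuda=torch.cuda.is_available())
ranker = Ranker(scorer_fwd=dialogpt, scorer_rvs=dialogpt)   # same weights reused for illustration

cxt = 'what is deep learning?'
hyps = [hyp for _, _, hyp in dialogpt.predict(cxt, beam=5)]  # predict() returns (way, prob, hyp) tuples
scored = ranker.predict(cxt, hyps)
for hyp, d in sorted(zip(hyps, scored), key=lambda x: -x[1]['score']):
    print('%.4f\t%s' % (d['score'], hyp))
```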
-------------------------------------------------------------------------------- /src/shared.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def get_api_key(name): 4 | for line in open('args/api.tsv'): 5 | ss = line.strip('\n').split('\t') 6 | if ss[0] == name: 7 | return ss[1:] 8 | return None 9 | 10 | 11 | def alnum_only(s): 12 | return re.sub(r"[^a-z0-9]", ' ', s.lower()) -------------------------------------------------------------------------------- /src/templates/dialog.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Grounded Conversational Demo 5 | 6 | 7 | 8 | 9 | 10 | 11 | 29 | 30 | 31 | 32 | 33 |

Grounded Conversational Demo

34 | 35 | 36 | 37 |

Context

38 | 39 | 40 | {% for speaker, utt in history %} 41 | 42 | 43 | 44 | 45 | {% endfor %} 46 |
{{ speaker }}{{ utt }}
47 | 48 | 49 |
50 | 51 |

Response

52 | 53 | 56 | 57 | {% for v in notes %} 58 | {{ v }}
59 | {% endfor %} 60 | 61 | 62 | 63 | {% for s in score_header %} 64 | 65 | {% endfor %} 66 | 67 | 68 | 69 | 70 | {% for score, model, hyp, hyp_en in hyps %} 71 | 72 | {% for v in score %} 73 | 74 | {% endfor %} 75 | 76 | 77 | 78 | 79 | {% endfor %} 80 |
{{ s }}modelcandidate
{{ v }}{{ model }}{{ hyp }}{{ hyp_en }}
81 | 82 |

Retrieved Passage:

83 | 84 | {% for url_display, url_full, snippet in passages %} 85 | 86 | 87 | 88 | 89 | {% endfor %} 90 |
{{ url_display }}{{ snippet }}
91 | 92 | 93 |

Input

94 |
95 | 96 | 97 | 101 | 102 | 103 | 106 | 107 |
98 | Message 99 | 100 |
104 | New session 105 |
108 |

109 |
110 | 111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /src/templates/doc_gen.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Document generation demo 5 | 6 | 7 | 8 | 9 | 10 | 11 | 29 | 30 | 31 | 32 | 33 |

Document generation demo

34 | 35 |

Input

36 |
37 |
38 |

39 |
40 | 41 | 42 | 43 |
44 |
45 | 46 |

Suggestion

47 | {% for v in notes %} 48 | {{ v }}
49 | {% endfor %} 50 | 51 | 52 | 53 | {% for s in score_header %} 54 | 55 | {% endfor %} 56 | 57 | 58 | 59 | 60 | {% for score, model, hyp, hyp_en in hyps %} 61 | 62 | {% for v in score %} 63 | 64 | {% endfor %} 65 | 66 | 67 | 68 | 69 | {% endfor %} 70 |
{{ s }}modelcandidate
{{ v }}{{ model }}{{ hyp }}{{ hyp_en }}
71 | 72 |

Retrieved Passage:

73 | 74 | {% for url_display, url_full, snippet in passages %} 75 | 76 | 77 | 78 | 79 | {% endfor %} 80 |
{{ url_display }}{{ snippet }}
81 | 82 | 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /src/todo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def pick_tokens(prob): 4 | # prob: tensor, shape = [n, vocab_size], the predicted token generation probablity 5 | # return: tensor, shape = [n, k], picked token index based on prob 6 | # where k is a hyperparameter you can choose 7 | 8 | # please implement your algorithm here 9 | pass 10 | 11 | -------------------------------------------------------------------------------- /src/tts.py: -------------------------------------------------------------------------------- 1 | import requests, time, re, pyaudio, wave, os 2 | from xml.etree import ElementTree 3 | from shared import get_api_key 4 | 5 | 6 | class TextToSpeech: 7 | 8 | def __init__(self): 9 | region, key = get_api_key('speech') 10 | self.token_url = "https://%s.api.cognitive.microsoft.com/sts/v1.0/issueToken"%region 11 | self.token_headers = {'Ocp-Apim-Subscription-Key': key} 12 | self.tts_url = 'https://%s.tts.speech.microsoft.com/cognitiveservices/v1'%region 13 | self.tts_headers = { 14 | 'Content-Type': 'application/ssml+xml', 15 | 'X-Microsoft-OutputFormat': 'riff-24khz-16bit-mono-pcm', 16 | } 17 | response = requests.post(self.token_url, headers=self.token_headers) 18 | self.tts_headers['Authorization'] = 'Bearer ' + str(response.text) 19 | 20 | self.fld_out = 'voice' 21 | os.makedirs(self.fld_out, exist_ok=True) 22 | 23 | 24 | def get_audio(self, txt, name='en-US-JessaNeural'): 25 | # see: https://docs.microsoft.com/en-us/azure/cognitive-services/speech-service/language-support 26 | 27 | txt_fname = re.sub(r"[^A-Za-z0-9]", "", txt).lower() 28 | txt_fname = txt_fname[:min(20, len(txt_fname))] 29 | path_out = self.fld_out + '/%s_%s.wav'%(txt_fname, name) 30 | 31 | lang = '-'.join(name.split('-')[:2]) 32 | xml_body = ElementTree.Element('speak', version='1.0') 33 | xml_body.set('{http://www.w3.org/XML/1998/namespace}lang', lang) 34 | voice = ElementTree.SubElement(xml_body, 'voice') 35 | voice.set('{http://www.w3.org/XML/1998/namespace}lang', lang) 36 | voice.set('name', name) 37 | voice.text = txt 38 | body = ElementTree.tostring(xml_body) 39 | 40 | response = requests.post(self.tts_url, headers=self.tts_headers, data=body) 41 | if response.status_code == 200: 42 | with open(path_out, 'wb') as audio: 43 | audio.write(response.content) 44 | return path_out 45 | else: 46 | print('[TTS] failed with status code: ' + str(response.status_code)) 47 | return None 48 | 49 | 50 | def open_audio(self, path_audio): 51 | if path_audio is None: 52 | return 53 | 54 | chunk = 1024 55 | f = wave.open(path_audio,"rb") 56 | p = pyaudio.PyAudio() 57 | stream = p.open(format=p.get_format_from_width(f.getsampwidth()), 58 | channels=f.getnchannels(), 59 | rate=f.getframerate(), 60 | output=True) 61 | data = f.readframes(chunk) 62 | while data: 63 | stream.write(data) 64 | data = f.readframes(chunk) 65 | stream.stop_stream() 66 | stream.close() 67 | p.terminate() 68 | 69 | 70 | def play_tts(): 71 | tts = TextToSpeech() 72 | while True: 73 | txt = input('\nTXT:\t') 74 | if len(txt) == 0: 75 | break 76 | path_audio = tts.get_audio(txt) 77 | if path_audio is not None: 78 | print('audio saved to '+path_audio) 79 | tts.open_audio(path_audio) 80 | 81 | 82 | if __name__ == "__main__": 83 | play_tts() --------------------------------------------------------------------------------
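`pick_tokens` in `src/todo.py` above is deliberately left as a stub: given `prob` of shape `[n, vocab_size]`, it should return an index tensor of shape `[n, k]`. Below is one minimal way to satisfy that contract, a plain top-k picker with an arbitrary `k=10`; it is a sketch rather than the prescribed strategy, and top-p or sampling-based pickers fit the same signature.
```
# A minimal top-k sketch for pick_tokens (one possible choice, not the prescribed one).
import torch

def pick_tokens(prob, k=10):
    # prob: [n, vocab_size] predicted token probabilities for the current step
    # returns: [n, k] indices of the k most probable tokens in each row
    _, top_ix = torch.topk(prob, k, dim=-1)
    return top_ix

if __name__ == "__main__":
    dummy = torch.softmax(torch.randn(2, 50257), dim=-1)  # GPT-2-sized dummy distribution
    print(pick_tokens(dummy).shape)                        # torch.Size([2, 10])
```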