├── .gitattributes ├── .gitignore ├── LICENSE ├── Parse.py ├── README.md ├── SECURITY.md ├── action.py ├── check.sh ├── config.py ├── debug.py ├── eval.py ├── evalModel-indep.sh ├── evalModel.sh ├── model ├── moduleKwNoMap-b15-0-18-dynetmodel.bin ├── moduleKwNoMap-b15-0-18-extra.bin ├── moduleKwNoMap-b15-ind-0-20-dynetmodel.bin └── moduleKwNoMap-b15-ind-0-20-extra.bin ├── nnmodule.py ├── output ├── moduleKwNoMap-b15-0.out ├── moduleKwNoMap-b15-1.out ├── moduleKwNoMap-b15-2.out ├── moduleKwNoMap-b15-3.out ├── moduleKwNoMap-b15-4.out ├── moduleKwNoMap-b15-5.out ├── moduleKwNoMap-b15-ind-0.out ├── moduleKwNoMap-b15-ind-1.out ├── moduleKwNoMap-b15-ind-2.out ├── moduleKwNoMap-b15-ind-3.out ├── moduleKwNoMap-b15-ind-4.out └── moduleKwNoMap-b15-ind-5.out ├── results ├── moduleKwNoMap-b15-0-17.tsv ├── moduleKwNoMap-b15-0-18.tsv ├── moduleKwNoMap-b15-1-12.tsv ├── moduleKwNoMap-b15-2-16.tsv ├── moduleKwNoMap-b15-3-26.tsv ├── moduleKwNoMap-b15-4-20.tsv ├── moduleKwNoMap-b15-5-18.tsv ├── moduleKwNoMap-b15-ind-0-20.tsv └── moduleKwNoMap-b15-ind-0-28.tsv ├── run5fold-firstq.sh ├── run5fold-indep.sh ├── run5fold.sh ├── seqtagging.py ├── sqafirst.1.py ├── sqafirst.py ├── sqafollow.py ├── sqamodel.py ├── sqastate.py ├── statesearch.py ├── testmkl.py └── util.py /.gitattributes: -------------------------------------------------------------------------------- 1 | cache/glove.6b.100d.trie filter=lfs diff=lfs merge=lfs -text 2 | cache/glove.twitter.100d.trie filter=lfs diff=lfs merge=lfs -text 3 | cache/senna.wiki.50d.trie filter=lfs diff=lfs merge=lfs -text 4 | model/moduleKwNoMap-b15-ind-0-29-dynetmodel.bin filter=lfs diff=lfs merge=lfs -text 5 | model/moduleKwNoMap-b15-ind-0-29-extra.bin filter=lfs diff=lfs merge=lfs -text 6 | model/moduleKwNoMap-b15-0-18-dynetmodel.bin filter=lfs diff=lfs merge=lfs -text 7 | model/moduleKwNoMap-b15-0-18-extra.bin filter=lfs diff=lfs merge=lfs -text 8 | model/moduleKwNoMap-b15-ind-0-20-dynetmodel.bin filter=lfs diff=lfs merge=lfs -text 9 | model/moduleKwNoMap-b15-ind-0-20-extra.bin filter=lfs diff=lfs merge=lfs -text 10 | model/moduleKwNoMap-b15-ind-0-29-dynetmodel.bin filter=lfs diff=lfs merge=lfs -text 11 | model/moduleKwNoMap-b15-ind-0-29-extra.bin filter=lfs diff=lfs merge=lfs -text 12 | model/moduleKwNoMap-b15-0-18-dynetmodel.bin filter=lfs diff=lfs merge=lfs -text 13 | model/moduleKwNoMap-b15-0-18-extra.bin filter=lfs diff=lfs merge=lfs -text 14 | model/moduleKwNoMap-b15-ind-0-20-dynetmodel.bin filter=lfs diff=lfs merge=lfs -text 15 | model/moduleKwNoMap-b15-ind-0-20-extra.bin filter=lfs diff=lfs merge=lfs -text 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. All rights reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /Parse.py: -------------------------------------------------------------------------------- 1 | from sys import float_info 2 | 3 | import util, config 4 | from util import QuestionInfo, ResultInfo 5 | 6 | class Condition: 7 | OpGT, OpLT, OpGE, OpLE, OpEqRow, OpNeRow, OpArgMin, OpArgMax = xrange(8) 8 | 9 | def __init__(self, col=-1, op=-1, arg=None): 10 | self.cond_col = col 11 | self.operator = op 12 | self.arg = arg 13 | return 14 | 15 | @staticmethod 16 | def numValue(entry, annEntry, default): 17 | ret = default 18 | if annEntry.date != None: 19 | ret = float(annEntry.date.toordinal()) 20 | elif annEntry.number != config.NaN: 21 | ret = annEntry.number 22 | elif annEntry.num2 != config.NaN: 23 | ret = annEntry.num2 24 | elif util.is_float_try(entry): 25 | ret = float(entry) 26 | return ret 27 | 28 | # Check which entries satisfy the condition 29 | def check(self, qinfo, subtab_rows): 30 | entries = qinfo.entries 31 | annTab = qinfo.annTab 32 | if self.operator == Condition.OpEqRow: 33 | cond_row = self.arg 34 | #print("debug: cond_row, self.cond_col = ", cond_row, self.cond_col) 35 | cond_val = entries[cond_row][self.cond_col].lower() 36 | 37 | try: 38 | ret = set([r for r in subtab_rows if entries[r][self.cond_col].lower() == cond_val]) 39 | except: 40 | print("debug: qinfo", qinfo.seq_qid, qinfo.table_file) 41 | print("debug: cond_row, self.cond_col = ", cond_row, self.cond_col) 42 | print("debug: subtab_rows", subtab_rows) 43 | print("debug: entries", entries) 44 | 45 | return ret 46 | elif self.operator == Condition.OpNeRow: 47 | cond_row = self.arg 48 | cond_val = entries[cond_row][self.cond_col].lower() 49 | return set([r for r in subtab_rows if entries[r][self.cond_col].lower() != cond_val]) 50 | elif self.operator == Condition.OpGT: 51 | cond_val = self.arg 52 | return set([r for r in subtab_rows if self.numValue(entries[r][self.cond_col], annTab[(r,self.cond_col)], float_info.min) > cond_val]) 53 | elif self.operator == Condition.OpGE: 54 | cond_val = self.arg 55 | return set([r for r in subtab_rows if self.numValue(entries[r][self.cond_col], annTab[(r,self.cond_col)], float_info.min) >= cond_val]) 56 | elif self.operator == Condition.OpLT: 57 | cond_val = self.arg 58 | return set([r for r in subtab_rows if self.numValue(entries[r][self.cond_col], annTab[(r,self.cond_col)], float_info.max) < cond_val]) 59 | elif self.operator == Condition.OpLE: 60 | cond_val = self.arg 61 | return set([r for r in subtab_rows if self.numValue(entries[r][self.cond_col], annTab[(r,self.cond_col)], float_info.max) <= cond_val]) 62 | elif self.operator == Condition.OpArgMin: 63 | numeric_values = [self.numValue(entries[r][self.cond_col], annTab[(r,self.cond_col)], float_info.max) for r in subtab_rows] 64 | if len(numeric_values) != len(subtab_rows): 65 | return set() 66 | min_idx = min((v,i) for i,v in enumerate(numeric_values))[1] 67 | return [min_idx] 68 | elif self.operator == Condition.OpArgMax: 69 | numeric_values = [self.numValue(entries[r][self.cond_col], annTab[(r,self.cond_col)], float_info.min) for r in subtab_rows] 70 | if len(numeric_values) != len(subtab_rows): 71 | return set() 72 | max_idx = max((v,i) for i,v in 
enumerate(numeric_values))[1] 73 | return [max_idx] 74 | 75 | assert False, "Unknown condition operator: %d" % self.operator 76 | 77 | class Parse: 78 | Independent, FollowUp = xrange(2) 79 | 80 | def __init__(self): 81 | self.type = Parse.Independent 82 | self.select_columns = [] # the list of columns to be selected 83 | self.conditions = [] # list of Condition objects; meaning that it needs to satisfy all the conditions 84 | 85 | def run(self, qinfo, resinfo): 86 | if self.type == Parse.Independent: 87 | if self.select_columns == []: # if no columns are selected, return all columns 88 | select_columns = xrange(qinfo.num_columns) 89 | else: 90 | select_columns = self.select_columns 91 | if self.conditions == []: 92 | legit_rows = set(xrange(qinfo.num_rows)) 93 | else: 94 | legit_rows = reduce(lambda x,y: x.intersection(y), [cond.check(qinfo, xrange(qinfo.num_rows)) for cond in self.conditions]) 95 | elif self.type == Parse.FollowUp: 96 | # TODO, ASSUMING THE PREVIOUS ANSWERS ARE SINGLE COLUMN HERE 97 | ans_col = resinfo.prev_pred_answer_column # answer column 98 | if self.conditions == []: 99 | # TODO: it should be using the previous answers, but because we're restricted to only one column answers, we have to restrict to that column for now 100 | # return resinfo.prev_pred_answer_coordinates 101 | return [coor for coor in resinfo.prev_pred_answer_coordinates if coor[1] == ans_col] 102 | else: 103 | legit_rows = reduce(lambda x,y: x.intersection(y), [cond.check(qinfo, resinfo.subtab_rows) for cond in self.conditions]) 104 | select_columns = [ans_col] 105 | else: 106 | assert False, "Unknown parse type: %d" % self.type 107 | return [(r,c) for r in legit_rows for c in select_columns] 108 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # DynSP (Dynamic Neural Semantic Parser) 3 | 4 | This project contains the source code of the Dynamic Neural Semantic Parser (DynSP), 5 | based on [DyNet](https://github.com/clab/dynet). 6 | 7 | Details of DynSP can be found in the following ACL-2017 paper: 8 | 9 | [Mohit Iyyer](https://people.cs.umass.edu/~miyyer/), [Wen-tau Yih](http://scottyih.org), [Ming-Wei Chang](https://ming-wei-chang.github.io/). 10 | [Search-based Neural Structured Learning for Sequential Question Answering.](http://aclweb.org/anthology/P17-1167) ACL-2017. 11 | 12 | @InProceedings{iyyer-yih-chang:2017:Long, 13 | author = {Iyyer, Mohit and Yih, Wen-tau and Chang, Ming-Wei}, 14 | title = {Search-based Neural Structured Learning for Sequential Question Answering}, 15 | booktitle = {Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)}, 16 | month = {July}, 17 | year = {2017}, 18 | address = {Vancouver, Canada}, 19 | publisher = {Association for Computational Linguistics}, 20 | pages = {1821--1831}, 21 | } 22 | 23 | The output files and the models for producing the reported results are also included. Below are the scripts that 24 | produce the results reported in Table 2 of the paper (DynSP and DynSP*).
25 | 26 | ```bash 27 | $ ./check.sh moduleKwNoMap-b15-ind 28 | moduleKwNoMap-b15-ind 29 | Best Accuracy: 0.425963 (Reward: 0.479099) at epoch 28 30 | Best Accuracy: 0.369095 (Reward: 0.423323) at epoch 10 31 | Best Accuracy: 0.348668 (Reward: 0.405460) at epoch 24 32 | Best Accuracy: 0.377477 (Reward: 0.439594) at epoch 29 33 | Best Accuracy: 0.349951 (Reward: 0.413719) at epoch 20 34 | Best Accuracy: 0.351802 (Reward: 0.409400) at epoch 19 35 | 0.359399 20 36 | ``` 37 | 38 | ```bash 39 | $ ./evalModel-indep.sh moduleKwNoMap-b15-ind 20 40 | Sequence Accuracy = 10.15% (104/1025) 41 | Answer Accuracy = 41.97% (1264/3012) 42 | Break-down: 43 | Position 0 Accuracy = 70.93% (727/1025) 44 | Position 1 Accuracy = 35.84% (367/1024) 45 | Position 2 Accuracy = 20.06% (137/683) 46 | Position 3 Accuracy = 12.23% (28/229) 47 | Position 4 Accuracy = 13.16% (5/38) 48 | Position 5 Accuracy = 0.00% (0/9) 49 | Position 6 Accuracy = 0.00% (0/4) 50 | ``` 51 | 52 | ```bash 53 | $ ./check.sh moduleKwNoMap-b15 54 | moduleKwNoMap-b15 55 | Best Accuracy: 0.450863 (Reward: 0.516281) at epoch 17 56 | Best Accuracy: 0.379691 (Reward: 0.439837) at epoch 12 57 | Best Accuracy: 0.366021 (Reward: 0.422335) at epoch 16 58 | Best Accuracy: 0.391892 (Reward: 0.456894) at epoch 26 59 | Best Accuracy: 0.370968 (Reward: 0.442918) at epoch 20 60 | Best Accuracy: 0.368468 (Reward: 0.431721) at epoch 18 61 | 0.375408 18 62 | ``` 63 | 64 | ```bash 65 | $ ./evalModel.sh moduleKwNoMap-b15 18 66 | Sequence Accuracy = 12.78% (131/1025) 67 | Answer Accuracy = 44.65% (1345/3012) 68 | Break-down: 69 | Position 0 Accuracy = 70.44% (722/1025) 70 | Position 1 Accuracy = 41.11% (421/1024) 71 | Position 2 Accuracy = 23.57% (161/683) 72 | Position 3 Accuracy = 13.97% (32/229) 73 | Position 4 Accuracy = 18.42% (7/38) 74 | Position 5 Accuracy = 11.11% (1/9) 75 | Position 6 Accuracy = 25.00% (1/4) 76 | ``` 77 | 78 | _A [tokenizer](https://github.com/myleott/ark-twokenize-py) and some data files are not included in the initial release 79 | due to licensing issues. They can be found at a [fork](https://github.com/scottyih/DynSP)._ 80 | The [Sequential Question Answering (SQA) dataset](https://www.microsoft.com/en-us/download/details.aspx?id=54253), 81 | published and used in the same paper, can be downloaded separately. 82 | 83 | 84 | 85 | 86 | # Contributing 87 | 88 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 89 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 90 | the rights to use your contribution. For details, visit https://cla.microsoft.com. 91 | 92 | When you submit a pull request, a CLA-bot will automatically determine whether you need to provide 93 | a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions 94 | provided by the bot. You will only need to do this once across all repos using our CLA. 95 | 96 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 97 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 98 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
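# Training

The released models and outputs above come from the run5fold driver scripts in this repository. The sketch below is for reference only; it assumes the `data/` directory has been populated with the SQA files mentioned above, and that the Python 2 / DyNet / Intel MKL environment encoded in the scripts' `mklpython` alias is available.

```bash
# Train the two model variants reported above on splits 0-5; each run writes its
# console output to output/, its training log to log/, and its models to model/.
./run5fold.sh moduleKwNoMap-b15
./run5fold-indep.sh moduleKwNoMap-b15-ind

# eval.py can also score a prediction file against the gold test set directly;
# this is what evalModel.sh runs as its last step.
python eval.py data/test.tsv results/moduleKwNoMap-b15-0-18.tsv
```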
99 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 
40 | 41 | 42 | -------------------------------------------------------------------------------- /action.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from Parse import * 3 | from collections import namedtuple 4 | Action = namedtuple("Action", "type idx row col val") 5 | ''' 6 | class Action: 7 | def __init__(self, act_type, act_idx, row=-1, col=-1, val=-1): 8 | self.type = act_type # action type 9 | self.idx = act_idx # action index 10 | self.row = row # row number -- used by WhereEqRow 11 | self.col = col # column number -- used by Select, WhereCol 12 | ''' 13 | 14 | ### Action type: 15 | # 16 | # Original action type: 17 | # (0) START 18 | # (1) STOP (no more condition, 1) 19 | # (2) SELECT X (# table columns) 20 | # (3) WHERE Y=? (# columns, can be other operators) 21 | # (4) WHERE Y=Z (# rows) 22 | # (5) WHERE Y != Z (# rows) 23 | # (6) WHERE Y > z1 (# values in Q) 24 | # (7) WHERE Y >= z1 (# values in Q) 25 | # (8) WHERE Y < z1 (# values in Q) 26 | # (9) WHERE Y <= z1 (# values in Q) 27 | # (10) WHERE Y is ArgMin (1) 28 | # (11) WHERE Y is ArgMax (1) 29 | # 30 | # Follow-up action type: 31 | # (12) SameAsPrevious (1) 32 | # (13) WHERE Y=? (# columns, can be other operators) 33 | # (14) WHERE Y=Z (# rows in the subtable) 34 | # (15) WHERE Y != Z (# rows in the subtable) 35 | # (16) WHERE Y > z1 (# values in Q) 36 | # (17) WHERE Y >= z1 (# values in Q) 37 | # (18) WHERE Y < z1 (# values in Q) 38 | # (19) WHERE Y <= z1 (# values in Q) 39 | # (20) WHERE Y is ArgMin (1) 40 | # (21) WHERE Y is ArgMax (1) 41 | # 42 | # Legit sequence: Check OneNote for the transition diagram 43 | class ActionType: 44 | Num_Types = 22 45 | Start, \ 46 | Stop, Select, WhereCol, CondEqRow, CondNeRow, CondGT, CondGE, CondLT, CondLE, ArgMin, ArgMax, \ 47 | SameAsPrevious, FpWhereCol, FpCondEqRow, FpCondNeRow, FpCondGT, \ 48 | FpCondGE, FpCondLT, FpCondLE, FpArgMin, FpArgMax = xrange(Num_Types) 49 | LegitNextActionType = {} 50 | IndepLegitNextActionType = {} 51 | 52 | IndepLegitNextActionType[Start] = [Select] 53 | 54 | LegitNextActionType[Start] = [Select, SameAsPrevious, FpWhereCol] 55 | LegitNextActionType[Select] = [Stop, WhereCol] 56 | #LegitNextActionType[Select] = [Stop] 57 | LegitNextActionType[Stop] = [] 58 | LegitNextActionType[WhereCol] = [CondEqRow, CondNeRow, CondGT, CondGE, CondLT, CondLE, ArgMin, ArgMax] 59 | LegitNextActionType[CondEqRow] = [] 60 | LegitNextActionType[CondNeRow] = [] 61 | LegitNextActionType[CondGT] = [] 62 | LegitNextActionType[CondGE] = [] 63 | LegitNextActionType[CondLT] = [] 64 | LegitNextActionType[CondLE] = [] 65 | LegitNextActionType[ArgMin] = [] 66 | LegitNextActionType[ArgMax] = [] 67 | 68 | LegitNextActionType[SameAsPrevious] = [] 69 | LegitNextActionType[FpWhereCol] = [FpCondEqRow, FpCondNeRow, FpCondGT, FpCondGE, FpCondLT, FpCondLE, FpArgMin, FpArgMax] 70 | LegitNextActionType[FpCondEqRow] = [] 71 | LegitNextActionType[FpCondNeRow] = [] 72 | LegitNextActionType[FpCondGT] = [] 73 | LegitNextActionType[FpCondGE] = [] 74 | LegitNextActionType[FpCondLT] = [] 75 | LegitNextActionType[FpCondLE] = [] 76 | LegitNextActionType[FpArgMin] = [] 77 | LegitNextActionType[FpArgMax] = [] 78 | 79 | # group the action types based on their numbers of instances and more 80 | SingleInstanceActions = set([Start, Stop, SameAsPrevious, ArgMin, ArgMax, FpArgMin, FpArgMax]) 81 | ColumnActions = set([Select, WhereCol, FpWhereCol]) 82 | RowActions = set([CondEqRow, CondNeRow]) 83 | SubTabRowActions = set([FpCondEqRow, FpCondNeRow]) 84 | 
QuesValueActions = set([CondGE, CondGT, CondLE, CondLT, FpCondGT, FpCondGE, FpCondLT, FpCondLE]) 85 | 86 | WhereConditions = set([CondEqRow, CondNeRow, CondGE, CondGT, CondLE, CondLT, ArgMin, ArgMax]) 87 | FpWhereConditions = set([FpCondEqRow, FpCondNeRow, FpCondGT, FpCondGE, FpCondLT, FpCondLE, FpArgMin, FpArgMax]) 88 | Conditions = WhereConditions.union(FpWhereConditions) 89 | FirstQuestionActions = set([Start, Stop, Select, WhereCol, CondEqRow, CondNeRow, CondGT, CondGE, CondLT, CondLE, ArgMin, ArgMax]) 90 | FollowUpQuestionActions = set([Start, Stop, SameAsPrevious, FpWhereCol, FpCondEqRow, FpCondNeRow, FpCondGT, \ 91 | FpCondGE, FpCondLT, FpCondLE, FpArgMin, FpArgMax]) 92 | if config.d["ReduceRowCond"]: 93 | EqRowConditions = set([CondEqRow, CondNeRow, FpCondEqRow, FpCondNeRow]) 94 | else: 95 | EqRowConditions = set([]) 96 | 97 | NumericConditions = QuesValueActions.union(set([ArgMin, ArgMax, FpArgMin, FpArgMax])) 98 | 99 | class ActionFactory: 100 | def __init__(self, qinfo, resinfo=None): 101 | self.qinfo = qinfo 102 | self.actions = [] 103 | self.type2actidxs = {} 104 | self.legit_next_action_idxs_cache = {} 105 | self.legit_next_action_idxs_history_cache = {} 106 | if resinfo == None: 107 | self.subtab_rows = [] 108 | else: 109 | self.subtab_rows = resinfo.subtab_rows 110 | 111 | #print("qinfo.question:", qinfo.question) 112 | #print("qinfo.values_in_ques:", qinfo.values_in_ques) 113 | 114 | # Create "instances" of action types; essentially list the number of the actions each type can have 115 | p = 0 116 | for act_type in xrange(ActionType.Num_Types): 117 | # number of instances: 1 118 | if act_type in ActionType.SingleInstanceActions: 119 | num_acts = self.append_actions(act_type, p) 120 | # number of instances: # columns 121 | elif act_type in ActionType.ColumnActions: 122 | num_acts = self.append_actions(act_type, p, qinfo.num_columns, cols = xrange(qinfo.num_columns)) 123 | # number of instances: # rows 124 | elif act_type in ActionType.RowActions: 125 | num_acts = self.append_actions(act_type, p, qinfo.num_rows, rows = xrange(qinfo.num_rows)) 126 | # number of instances: # rows in subtable 127 | elif act_type in ActionType.SubTabRowActions: 128 | num_acts = self.append_actions(act_type, p, len(self.subtab_rows), rows = self.subtab_rows) 129 | # number of instances: # values in question 130 | elif act_type in ActionType.QuesValueActions: 131 | #print("qinfo.values_in_ques = ", qinfo.values_in_ques, len(qinfo.values_in_ques)) 132 | num_acts = self.append_actions(act_type, p, len(qinfo.values_in_ques), values=qinfo.values_in_ques) 133 | 134 | self.type2actidxs[act_type] = xrange(p, p+num_acts) 135 | p += num_acts 136 | 137 | # special action -- Start 138 | self.start_action_idx = self.type2actidxs[ActionType.Start][0] 139 | 140 | def append_actions(self, act_type, start_idx, num_act = 1, rows = None, cols = None, values = None): 141 | if num_act == 1 and rows == None and cols == None and values == None: # do not need either row or column info 142 | self.actions.append(Action(act_type, start_idx, -1, -1, -1)) 143 | elif rows == None and cols != None and values == None: # need column info 144 | assert(num_act == len(cols)) 145 | for p in xrange(num_act): 146 | act_idx = start_idx + p 147 | self.actions.append(Action(act_type, act_idx, -1, cols[p], -1)) 148 | elif rows != None and cols == None and values == None: # need row info 149 | assert(num_act == len(rows)) 150 | for p in xrange(num_act): 151 | act_idx = start_idx + p 152 | self.actions.append(Action(act_type, act_idx, 
rows[p], -1, -1)) 153 | elif rows == None and cols == None and values != None: # need values in the question 154 | assert(num_act == len(values)) 155 | for p,v in enumerate(values): 156 | act_idx = start_idx + p 157 | self.actions.append(Action(act_type, act_idx, -1, -1, v)) 158 | else: 159 | print("Debug: rows = ", rows, "cols = ", cols, "values = ", values) 160 | assert False, "Error! Unknown act_type: %d" % act_type 161 | return num_act 162 | 163 | def legit_next_action_idxs(self, act_idx, action_history = None): 164 | 165 | # check history cache 166 | if action_history != None: 167 | action_history_key = ','.join(map(str,action_history)) 168 | if action_history_key in self.legit_next_action_idxs_history_cache: 169 | return self.legit_next_action_idxs_history_cache[action_history_key] 170 | 171 | # The first chunk of this function determines the possible next actions based only on the given current action. 172 | # In other words, it only checks whether the transition is defined previously in the action-space graph. 173 | if act_idx not in self.legit_next_action_idxs_cache: 174 | #print ("debug:", act_idx, self.actions) 175 | act = self.actions[act_idx] 176 | act_type = act.type 177 | if not self.subtab_rows and act_type in ActionType.IndepLegitNextActionType: # no previous question, and special state 178 | legit_next_action_types = ActionType.IndepLegitNextActionType[act_type] 179 | else: 180 | legit_next_action_types = ActionType.LegitNextActionType[act_type] 181 | ret = [] 182 | for legit_type in legit_next_action_types: 183 | if legit_type in ActionType.EqRowConditions: # remove some equivalent equal row conditions (changing the action to condition on value directly is difficult given the current code design...) 184 | addedCondValues = set() 185 | for legit_act_idx in self.type2actidxs[legit_type]: 186 | act = self.actions[legit_act_idx] 187 | entVal = self.qinfo.entries[act.row][act.col] 188 | if entVal in addedCondValues: 189 | continue 190 | addedCondValues.add(entVal) 191 | ret.append(legit_act_idx) 192 | elif legit_type in ActionType.NumericConditions: 193 | cond_col = self.actions[act_idx].col 194 | if cond_col not in self.qinfo.numeric_cols: # not a numeric column 195 | continue 196 | 197 | #print ("debug: passed numeric columns", self.qinfo.table_file, cond_col) 198 | for legit_act_idx in self.type2actidxs[legit_type]: 199 | ret.append(legit_act_idx) 200 | else: 201 | for legit_act_idx in self.type2actidxs[legit_type]: 202 | ret.append(legit_act_idx) 203 | self.legit_next_action_idxs_cache[act_idx] = ret 204 | else: 205 | ret = self.legit_next_action_idxs_cache[act_idx] 206 | 207 | # The second chunk of this function "prunes" some of the possible actions by looking at the existing partial parse, 208 | # as well as taking the cues from the question. The goal is to reduce the search space whenever possible. 
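# Concretely, the checks below: (1) skip condition actions that already appear in the partial action
# history, (2) restrict the first question of a sequence (pos == 0) to first-question actions, and
# (3) restrict follow-up questions containing cues such as 'of those' / 'which one(s)' to follow-up actions.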
209 | if action_history != None: # Check if actions that have to be unique already occur 210 | act_set = set(action_history) 211 | new_ret = [] 212 | for act_idx in ret: 213 | act = self.actions[act_idx] 214 | if (act.type in ActionType.Conditions) and (act_idx in act_set): # no duplicate conditions 215 | continue 216 | if (self.qinfo.pos == 0) and (act.type not in ActionType.FirstQuestionActions): # first question 217 | continue 218 | # follow-up question with keywords indicating it's a dependent parse 219 | # TODO: unify the code in action_history_quality 220 | if self.qinfo.pos != 0 \ 221 | and (self.qinfo.contain_ngram('of those') or self.qinfo.contain_ngram('which one') or self.qinfo.contain_ngram('which ones')) \ 222 | and (act.type not in ActionType.FollowUpQuestionActions): 223 | continue 224 | new_ret.append(act_idx) 225 | 226 | #if new_ret == []: 227 | # print (self.qinfo.seq_qid, self.qinfo.question, action_history, ret) 228 | ret = new_ret 229 | self.legit_next_action_idxs_history_cache[action_history_key] = ret 230 | 231 | #print("legit_next_action_idxs", "input:", act_idx, "output:", ret) 232 | 233 | return ret 234 | 235 | def find_actions(self, act_idxs, act_type): 236 | return [self.actions[act_idx] for act_idx in act_idxs if self.actions[act_idx].type == act_type] 237 | 238 | # map a sequence of actions to a parse 239 | def action_history_to_parse(self, act_idxs): 240 | 241 | #print("act_idxs", act_idxs) 242 | 243 | parse = Parse() 244 | 245 | p, act_history_length = 0, len(act_idxs) 246 | while p < act_history_length: 247 | act_idx = act_idxs[p] 248 | act = self.actions[act_idx] 249 | if act.type == ActionType.Select: # having SELECT meaning it's an independent parse 250 | parse.type = Parse.Independent 251 | parse.select_columns.append(act.col) # record the column it selects 252 | elif act.type == ActionType.WhereCol: 253 | col = act.col 254 | # have to assume that the follow-up action is about the operator and argument 255 | p += 1 256 | #print("p =", p) 257 | act_idx = act_idxs[p] 258 | 259 | act = self.actions[act_idx] 260 | 261 | if act.type != ActionType.Stop: 262 | # This is the only legitimate type after WhereCol currently. Will have to expand the coverage later for more action types. 
263 | assert (act.type in ActionType.WhereConditions), "Illegit action type after WhereCol: %d" % act.type 264 | if act.type == ActionType.CondEqRow: 265 | cond = Condition(col, Condition.OpEqRow, act.row) 266 | elif act.type == ActionType.CondNeRow: 267 | cond = Condition(col, Condition.OpNeRow, act.row) 268 | elif act.type == ActionType.CondGT: 269 | cond = Condition(col, Condition.OpGT, act.val[1]) 270 | elif act.type == ActionType.CondGE: 271 | cond = Condition(col, Condition.OpGE, act.val[1]) 272 | elif act.type == ActionType.CondLT: 273 | cond = Condition(col, Condition.OpLT, act.val[1]) 274 | elif act.type == ActionType.CondLE: 275 | cond = Condition(col, Condition.OpLE, act.val[1]) 276 | elif act.type == ActionType.ArgMin: 277 | cond = Condition(col, Condition.OpArgMin) 278 | elif act.type == ActionType.ArgMax: 279 | cond = Condition(col, Condition.OpArgMax) 280 | parse.conditions.append(cond) 281 | elif act.type == ActionType.SameAsPrevious: 282 | parse.type = Parse.FollowUp 283 | elif act.type == ActionType.FpWhereCol: 284 | parse.type = Parse.FollowUp 285 | col = act.col 286 | # have to assume that the follow-up action is about the operator and argument 287 | p += 1 288 | act_idx = act_idxs[p] 289 | act = self.actions[act_idx] 290 | 291 | if act.type != ActionType.Stop: 292 | # This is the only legitimate type after WhereCol currently. Will have to expand the coverage later for more action types. 293 | assert (act.type in ActionType.FpWhereConditions), "Illegit action type after FpWhereCol: %d" % act.type 294 | if act.type == ActionType.FpCondEqRow: 295 | cond = Condition(col, Condition.OpEqRow, act.row) 296 | elif act.type == ActionType.FpCondNeRow: 297 | cond = Condition(col, Condition.OpNeRow, act.row) 298 | elif act.type == ActionType.FpCondGT: 299 | cond = Condition(col, Condition.OpGT, act.val[1]) 300 | elif act.type == ActionType.FpCondGE: 301 | cond = Condition(col, Condition.OpGE, act.val[1]) 302 | elif act.type == ActionType.FpCondLT: 303 | cond = Condition(col, Condition.OpLT, act.val[1]) 304 | elif act.type == ActionType.FpCondLE: 305 | cond = Condition(col, Condition.OpLE, act.val[1]) 306 | elif act.type == ActionType.FpArgMin: 307 | cond = Condition(col, Condition.OpArgMin) 308 | elif act.type == ActionType.FpArgMax: 309 | cond = Condition(col, Condition.OpArgMax) 310 | parse.conditions.append(cond) 311 | else: 312 | assert (act.type == ActionType.Start or act.type == ActionType.Stop), "Unknown action type: %d" % act.type 313 | p += 1 314 | 315 | return parse 316 | 317 | def action_history_quality(self, act_idxs): 318 | # if it's a follow-up question & contains "of those" or "which", then it's forced to switch "follow up" actions 319 | if self.qinfo.pos != 0 \ 320 | and (self.qinfo.contain_ngram('of those') or self.qinfo.contain_ngram('which one') or self.qinfo.contain_ngram('which ones')): 321 | parse = self.action_history_to_parse(act_idxs) # TODO: action_history_to_parse has been called multiple times 322 | #print ("\t".join([self.actidx2str(act) for act in act_idxs])) 323 | if parse.type == Parse.Independent: 324 | return 0.0 325 | 326 | return 1.0 327 | 328 | # for debugging 329 | def actidx2str(self, act_idx): 330 | action = self.actions[act_idx] 331 | if action.type == ActionType.Start: 332 | #return "START" 333 | return "" 334 | elif action.type == ActionType.Stop: 335 | #return "Stop" 336 | return "" 337 | elif action.type == ActionType.Select: 338 | col = action.col 339 | return "SELECT %s" % self.qinfo.headers[col] 340 | elif action.type == 
ActionType.WhereCol: 341 | col = action.col 342 | return "WHERE %s" % self.qinfo.headers[col] 343 | elif action.type == ActionType.CondEqRow: 344 | return "= ROW %d" % action.row 345 | elif action.type == ActionType.CondNeRow: 346 | return "!= ROW %d" % action.row 347 | elif action.type == ActionType.CondGT: 348 | return "> %f" % action.val[1] 349 | elif action.type == ActionType.CondGE: 350 | return ">= %f" % action.val[1] 351 | elif action.type == ActionType.CondLT: 352 | return "< %f" % action.val[1] 353 | elif action.type == ActionType.CondLE: 354 | return "<= %f" % action.val[1] 355 | elif action.type == ActionType.ArgMin: 356 | return "is Min" 357 | elif action.type == ActionType.ArgMax: 358 | return "is Max" 359 | elif action.type == ActionType.SameAsPrevious: 360 | return "SameAsPrevious" 361 | elif action.type == ActionType.FpWhereCol: 362 | col = action.col 363 | return "FollowUp WHERE %s" % self.qinfo.headers[col] 364 | elif action.type == ActionType.FpCondEqRow: 365 | return "= ROW %d" % action.row 366 | elif action.type == ActionType.FpCondNeRow: 367 | return "!= ROW %d" % action.row 368 | elif action.type == ActionType.FpCondGT: 369 | return "> %f" % action.val[1] 370 | elif action.type == ActionType.FpCondGE: 371 | return ">= %f" % action.val[1] 372 | elif action.type == ActionType.FpCondLT: 373 | return "< %f" % action.val[1] 374 | elif action.type == ActionType.FpCondLE: 375 | return "<= %f" % action.val[1] 376 | elif action.type == ActionType.FpArgMin: 377 | return "is Min" 378 | elif action.type == ActionType.FpArgMax: 379 | return "is Max" 380 | else: 381 | assert False, "Error! Unknown action.type: %d" % action.type -------------------------------------------------------------------------------- /check.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ "$#" -ne 1 ]; then 4 | echo "Usage: check.sh expSym" 5 | exit 1 6 | fi 7 | 8 | echo $1 9 | 10 | awkscript='/ Accuracy/ { best[FILENAME] = $0; acc[FILENAME] = $3; iter[FILENAME] = $NF } 11 | END { n = asorti(best,sbest); 12 | for (i=1;i<=n;i++) { print best[sbest[i]] } 13 | for (i=2;i<=n;i++) { a += acc[sbest[i]]; s += iter[sbest[i]] } 14 | print a/5, int(s/5 + 0.5) 15 | }' 16 | gawk "$awkscript" output/"$1"-[0-5].out 17 | 18 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # effort of writing python 2/3 compatiable code 4 | from __future__ import print_function 5 | from __future__ import division 6 | from __future__ import unicode_literals 7 | from future.utils import iteritems 8 | import sys,io,os 9 | 10 | reload(sys) 11 | sys.setdefaultencoding('utf8') 12 | 13 | import codecs 14 | sys.stdout = codecs.getwriter('utf8')(sys.stdout) 15 | sys.stderr = codecs.getwriter('utf8')(sys.stderr) 16 | 17 | from collections import Counter 18 | import random 19 | import time 20 | import tempfile 21 | 22 | import util 23 | import marisa_trie 24 | 25 | data_path = "./cache/" 26 | d = {} 27 | 28 | NaN=-9999999 29 | OnlyOneBest = True 30 | d["DropOut"] = True 31 | d["partial_reward"] = True 32 | d["ReduceRowCond"] = False 33 | 34 | d["USE_PRETRAIN_WORD_EMBEDDING"] = True 35 | d["WORD_EMBEDDING_DIM"] = 100 36 | 37 | d["DIST_BIAS_DIM"] = 100 38 | 39 | #d["USE_PRETRAIN_WORD_EMBEDDING"] = False 40 | #d["WORD_EMBEDDING_DIM"] = 2 41 | 42 | #d["record_path"] = data_path + "/glove.twitter.100d.trie" 43 | d["LSTM_HIDDEN_DIM"] 
= 50 44 | d["record_path"] = data_path + "/glove.6b.100d.trie" 45 | 46 | # d["WORD_EMBEDDING_DIM"] = 50 47 | # d["word_embeddings"] = 50 48 | # d["record_path"] = data_path +"/senna.wiki.50d.trie" 49 | 50 | d["recordtriestructure"] = "<" + "".join(["f"] * d["WORD_EMBEDDING_DIM"]) 51 | d["recordtriesep"] = u"|" 52 | d["embeddingtrie"] = marisa_trie.RecordTrie(d["recordtriestructure"]) 53 | d["embeddingtrie"].mmap(d["record_path"]) 54 | 55 | d["beam_size"] = 15 56 | d["NUM_ITER"] = 30 57 | 58 | UPDATE_WORD_EMD = 0 59 | NOUPDATE_WORD_EMD_ALL = 1 60 | NOUPDATE_WORD_EMD_PRETRAIN = 2 61 | 62 | d["updateEMB"] = UPDATE_WORD_EMD 63 | 64 | #d["AnnotatedTableDir"] = "/home/scottyih/Neural-RL-SP/Arvind_data/annotated" 65 | #d["dirTable"] = "file:///D:/ScottYih/Source/Repos/Neural-RL-SP/data" 66 | 67 | d["AnnotatedTableDir"] = "Arvind_data/annotated" 68 | d["dirTable"] = "data" 69 | 70 | #d["guessLogPass"] = True 71 | d["guessLogPass"] = False 72 | #d["verbose-dump"] = True 73 | d["verbose-dump"] = False 74 | 75 | if __name__ == '__main__': 76 | print("test") 77 | -------------------------------------------------------------------------------- /debug.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # effort of writing python 2/3 compatiable code 4 | from __future__ import print_function 5 | from __future__ import division 6 | from __future__ import unicode_literals 7 | from future.utils import iteritems 8 | from operator import itemgetter, attrgetter, methodcaller 9 | 10 | import sys, time, argparse, csv 11 | import cProfile 12 | 13 | if sys.version < '3': 14 | from codecs import getwriter 15 | stderr = getwriter('utf-8')(sys.stderr) 16 | stdout = getwriter('utf-8')(sys.stdout) 17 | else: 18 | stderr = sys.stderr 19 | 20 | import dynet as dt 21 | from collections import Counter 22 | import random 23 | import util 24 | import config 25 | import cPickle 26 | 27 | from action import * 28 | from statesearch import * 29 | from Parse import * 30 | 31 | ######## START OF THE CODE ######## 32 | 33 | def main1(): 34 | dat = util.get_labeled_questions(str("data/nt-13588_2.tsv"), "data") 35 | fLog = sys.stdout 36 | for i,qinfo in enumerate(dat,1): 37 | if qinfo.seq_qid[-1] != '0': 38 | parse = Parse() 39 | parse.type = Parse.FollowUp 40 | cond = Condition(3,Condition.OpEqRow,7) 41 | parse.conditions = [cond] 42 | pred = parse.run(qinfo,resinfo) 43 | 44 | fLog.write("(%s) %s\n" % (qinfo.seq_qid, qinfo.question)) 45 | fLog.write("Answer: %s\n" % ", ".join(["(%d,%d)" % coord for coord in qinfo.answer_coordinates])) 46 | fLog.write("Predictions: %s\n" % ", ".join(["(%d,%d)" % coord for coord in pred])) 47 | fLog.write("\n") 48 | fLog.flush() 49 | 50 | # use the gold answers 51 | resinfo = util.ResultInfo(qinfo.seq_qid, qinfo.question, qinfo.ques_word_sequence, 52 | qinfo.answer_coordinates, qinfo.answer_column_idx) 53 | 54 | def main(): 55 | dat = util.get_labeled_questions(str("data/random-split-2-dev.tsv"), "data") 56 | 57 | for qinfo in dat: 58 | print("%s\t%s" % (qinfo.question, list(util.findNumbers(qinfo.ques_word_sequence)))) 59 | 60 | print (len(dat)) 61 | return 62 | 63 | reader = csv.DictReader(open("data/train.tsv", 'r'), delimiter=str('\t')) 64 | for row in reader: 65 | ques = row['question'] 66 | #print("%s\t%s" % (ques, list(util.findNumbers(ques)))) 67 | print("%s\t%s" % (ques, list(util.findNumbers(ques.lower().split(' '))))) 68 | 69 | if __name__ == '__main__': 70 | main() 71 | 
-------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import csv, sys 2 | 3 | # read tsv files with the four essential columns: id, annotator, position, answer_coordinates 4 | # output: dt: sid -> pos -> ansCord 5 | def readTsv(fnTsv): 6 | dt = {} 7 | for row in csv.DictReader(open(fnTsv, 'r'), delimiter='\t'): 8 | sid = row['id'] + '\t' + row['annotator'] # sequence id 9 | pos = int(row['position']) # position 10 | ansCord = set(eval(row['answer_coordinates'])) # answer coordinates 11 | if not sid in dt: 12 | dt[sid] = {} 13 | dt[sid][pos] = ansCord 14 | return dt 15 | 16 | 17 | def evaluate(fnGold, fnPred): 18 | 19 | dtGold = readTsv(fnGold) 20 | dtPred = readTsv(fnPred) 21 | 22 | # Calcuate both sequence-level accuracy and question-level accuracy 23 | seqCnt = seqCor = 0 24 | ansCnt = ansCor = 0 25 | breakCorrect, breakTotal = {},{} 26 | for sid,qa in dtGold.items(): 27 | seqCnt += 1 28 | ansCnt += len(qa) 29 | 30 | if sid not in dtPred: continue # sequence does not exist in the prediction 31 | 32 | predQA = dtPred[sid] 33 | allQCorrect = True 34 | for q,a in qa.items(): 35 | if q not in breakTotal: 36 | breakCorrect[q] = breakTotal[q] = 0 37 | breakTotal[q] += 1 38 | 39 | if q in predQA and a == predQA[q]: 40 | ansCor += 1 # correctly answered question 41 | breakCorrect[q] += 1 42 | else: 43 | allQCorrect = False 44 | if allQCorrect: seqCor += 1 45 | 46 | print "Sequence Accuracy = %0.2f%% (%d/%d)" % (100.0 * seqCor/seqCnt, seqCor, seqCnt) 47 | print "Answer Accuracy = %0.2f%% (%d/%d)" % (100.0 * ansCor/ansCnt, ansCor, ansCnt) 48 | 49 | print "Break-down:" 50 | for q in sorted(breakTotal.keys()): 51 | print "Position %d Accuracy = %0.2f%% (%d/%d)" % (q, 100.0 * breakCorrect[q]/breakTotal[q], breakCorrect[q], breakTotal[q]) 52 | 53 | return [seqCor, seqCnt, ansCor, ansCnt] 54 | 55 | 56 | if __name__ == '__main__': 57 | if len(sys.argv) != 3: 58 | sys.stderr.write("Usage: %s goldTsv predTsv\n" % sys.argv[0]) 59 | sys.exit(-1) 60 | fnGold = sys.argv[1] 61 | fnPred = sys.argv[2] 62 | evaluate(fnGold, fnPred) 63 | -------------------------------------------------------------------------------- /evalModel-indep.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ "$#" -ne 2 ]; then 4 | echo "Usage: evalModel.sh trialname iter" 5 | exit 9 6 | fi 7 | 8 | #alias mklpython='LD_PRELOAD=/opt/intel/mkl/lib/intel64/libmkl_def.so:/opt/intel/mkl/lib/intel64/libmkl_avx2.so:/opt/intel/mkl/lib/intel64/libmkl_core.so:/opt/intel/mkl/lib/intel64/libmkl_intel_lp64.so:/opt/intel/mkl/lib/intel64/libmkl_intel_thread.so:/opt/intel/lib/intel64_lin/libiomp5.so python' 9 | alias mklpython='LD_PRELOAD=/opt/intel/mkl/lib/libmkl_def.so:/opt/intel/mkl/lib/libmkl_avx2.so:/opt/intel/mkl/lib/libmkl_core.so:/opt/intel/mkl/lib/libmkl_intel_lp64.so:/opt/intel/mkl/lib/libmkl_intel_thread.so:/opt/intel/lib/libiomp5.so python2' 10 | dm=2000 11 | ds=1 12 | 13 | name="$1"-"0" 14 | model=model/"$name"-"$2" 15 | dirData=data 16 | res=results/"$name"-"$2".tsv 17 | 18 | mklpython sqafollow.py --dynet-mem $dm --dynet-seed $ds --expSym 0 --evalModel $model --dirData $dirData --res $res --indep 19 | python eval.py data/test.tsv $res 20 | -------------------------------------------------------------------------------- /evalModel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ "$#" -ne 2 ]; then 4 
| echo "Usage: evalModel.sh trialname iter" 5 | exit 9 6 | fi 7 | 8 | #alias mklpython='LD_PRELOAD=/opt/intel/mkl/lib/intel64/libmkl_def.so:/opt/intel/mkl/lib/intel64/libmkl_avx2.so:/opt/intel/mkl/lib/intel64/libmkl_core.so:/opt/intel/mkl/lib/intel64/libmkl_intel_lp64.so:/opt/intel/mkl/lib/intel64/libmkl_intel_thread.so:/opt/intel/lib/intel64_lin/libiomp5.so python' 9 | alias mklpython='LD_PRELOAD=/opt/intel/mkl/lib/libmkl_def.so:/opt/intel/mkl/lib/libmkl_avx2.so:/opt/intel/mkl/lib/libmkl_core.so:/opt/intel/mkl/lib/libmkl_intel_lp64.so:/opt/intel/mkl/lib/libmkl_intel_thread.so:/opt/intel/lib/libiomp5.so python2' 10 | dm=2000 11 | ds=1 12 | 13 | name="$1"-"0" 14 | model=model/"$name"-"$2" 15 | dirData=data 16 | res=results/"$name"-"$2".tsv 17 | 18 | mklpython sqafollow.py --dynet-mem $dm --dynet-seed $ds --expSym 0 --evalModel $model --dirData $dirData --res $res 19 | python eval.py data/test.tsv $res 20 | -------------------------------------------------------------------------------- /model/moduleKwNoMap-b15-0-18-dynetmodel.bin: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c9bf0b70df680f09d1f16f948a79e4d823601cd16b204a4506a379a26f713f01 3 | size 54962479 4 | -------------------------------------------------------------------------------- /model/moduleKwNoMap-b15-0-18-extra.bin: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b341b60f2b460633d3cec8f019f25a402681a2047a5a322ffcd6fea98307e198 3 | size 1410711 4 | -------------------------------------------------------------------------------- /model/moduleKwNoMap-b15-ind-0-20-dynetmodel.bin: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:800e0429ded193f5371781f2514b62d7c6b4362441124e924b321803b72d43fb 3 | size 54962611 4 | -------------------------------------------------------------------------------- /model/moduleKwNoMap-b15-ind-0-20-extra.bin: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b341b60f2b460633d3cec8f019f25a402681a2047a5a322ffcd6fea98307e198 3 | size 1410711 4 | -------------------------------------------------------------------------------- /nnmodule.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # effort of writing python 2/3 compatiable code 4 | from __future__ import print_function 5 | from __future__ import division 6 | from __future__ import unicode_literals 7 | from future.utils import iteritems 8 | from operator import itemgetter, attrgetter, methodcaller 9 | 10 | import sys, time, argparse, csv 11 | import cProfile 12 | 13 | if sys.version < '3': 14 | from codecs import getwriter 15 | stderr = getwriter('utf-8')(sys.stderr) 16 | stdout = getwriter('utf-8')(sys.stdout) 17 | else: 18 | stderr = sys.stderr 19 | 20 | import dynet as dt 21 | from collections import Counter 22 | import random 23 | import util 24 | import config 25 | import cPickle 26 | 27 | #---------------------------------------------------------------------- 28 | 29 | """ This class is used as a score combiner. 
""" 30 | class FeedForwardModel: 31 | def __init__(self, model, dim_input, dim_hidden=-1): 32 | if dim_hidden == -1: # 33 | dim_hidden = dim_input 34 | self.W1 = model.add_parameters((dim_hidden, dim_input)) 35 | self.W2 = model.add_parameters((1, dim_hidden)) 36 | 37 | """ This is used for initializing parameter expressions for each example. """ 38 | def spawn_expression(self): 39 | return FeedForwardExp(self.W1, self.W2) 40 | 41 | class FeedForwardExp: 42 | def __init__(self, W1, W2): 43 | self.W1 = dt.parameter(W1) 44 | self.W2 = dt.parameter(W2) 45 | 46 | """ input_exp should be a vector of the scores to be combined """ 47 | def score_expression(self, input_exp): 48 | return self.W2 * dt.tanh(self.W1 * input_exp) 49 | 50 | #---------------------------------------------------------------------- 51 | 52 | class QuestionColumnMatchModel: 53 | def __init__(self, model, dim_word_embedding): 54 | self.ColW = model.add_parameters((dim_word_embedding)) 55 | def spawn_expression(self): 56 | return QuestionColumnMatchExp(self.ColW) 57 | 58 | class QuestionColumnMatchExp: 59 | def __init__(self, ColW): 60 | self.ColW = dt.parameter(ColW) 61 | 62 | """ return a list of scores """ 63 | def score_expression(self, qwVecs, qwAvgVec, qLSTMVec, colnameVec, colWdVecs): 64 | colPriorScore = dt.dot_product(self.ColW, colnameVec) 65 | colMaxScore = AvgMaxScore(qwVecs, colWdVecs) 66 | colAvgScore = AvgScore(qwAvgVec, colnameVec) 67 | colQLSTMScore = AvgScore(qLSTMVec, colnameVec) 68 | ret = [colPriorScore, colMaxScore, colAvgScore, colQLSTMScore] 69 | return ret 70 | 71 | #---------------------------------------------------------------------- 72 | 73 | class NegationModel: 74 | def __init__(self, model, dim_word_embedding, UNK, init_keywords = '', vw = None, E = None): 75 | if init_keywords: 76 | vector = dt.average([E[vw.w2i.get(w, UNK)] for w in init_keywords.split(' ')]) 77 | self.NegW = model.parameters_from_numpy(vector.npvalue()) 78 | else: 79 | self.NegW = model.add_parameters((dim_word_embedding)) 80 | def spawn_expression(self): 81 | return NegationExp(self.NegW) 82 | 83 | class NegationExp: 84 | def __init__(self, NegW): 85 | self.NegW = dt.parameter(NegW) 86 | 87 | def score_expression(self, qwAvgVec): 88 | ret = dt.dot_product(qwAvgVec, self.NegW) 89 | return ret 90 | 91 | #---------------------------------------------------------------------- 92 | 93 | class CompareModel: 94 | def __init__(self, model, dim_word_embedding, UNK, init_keywords = '', vw = None, E = None): 95 | if init_keywords: 96 | vector = dt.average([E[vw.w2i.get(w, UNK)] for w in init_keywords.split(' ')]) 97 | self.OpW = model.parameters_from_numpy(vector.npvalue()) 98 | else: 99 | self.OpW = model.add_parameters((dim_word_embedding)) 100 | def spawn_expression(self): 101 | return CompareExp(self.OpW) 102 | 103 | class CompareExp: 104 | def __init__(self, OpW): 105 | self.OpW = dt.parameter(OpW) 106 | 107 | def score_expression(self, qwVecs, numWdPos): 108 | if numWdPos == 0: 109 | kwVec = qwVecs[numWdPos+1] 110 | elif numWdPos == 1: 111 | kwVec = qwVecs[0] 112 | else: 113 | kwVec = dt.average(qwVecs[numWdPos-2:numWdPos]) 114 | 115 | ret = dt.dot_product(kwVec, self.OpW) 116 | return ret 117 | 118 | #---------------------------------------------------------------------- 119 | 120 | class ArgModel: 121 | def __init__(self, model, dim_word_embedding, UNK, init_keywords = '', vw = None, E = None): 122 | if init_keywords: 123 | vector = dt.average([E[vw.w2i.get(w, UNK)] for w in init_keywords.split(' ')]) 124 | self.OpW = 
model.parameters_from_numpy(vector.npvalue()) 125 | else: 126 | self.OpW = model.add_parameters((dim_word_embedding)) 127 | def spawn_expression(self): 128 | return ArgExp(self.OpW) 129 | 130 | class ArgExp: 131 | def __init__(self, OpW): 132 | self.OpW = dt.parameter(OpW) 133 | 134 | def score_expression(self, qwVecs): 135 | ret = MaxScore(qwVecs, self.OpW) 136 | return ret 137 | 138 | #---------------------------------------------------------------------- 139 | 140 | def AvgVector(txtV): 141 | if type(txtV) == list: 142 | vec = dt.average(txtV) 143 | else: 144 | vec = txtV 145 | return vec 146 | 147 | ''' txtV1 and txtV2 can be either a vector or a list of vectors ''' 148 | def AvgScore(txtV1, txtV2): 149 | vec1 = AvgVector(txtV1) 150 | vec2 = AvgVector(txtV2) 151 | ret = dt.dot_product(vec1, vec2) 152 | return ret 153 | 154 | ''' both qwVecs and colWdVecs have to be lists of vectors ''' 155 | def AvgMaxScore(qwVecs, colWdVecs): 156 | ret = dt.average([dt.emax([dt.dot_product(qwVec, colWdVec) for qwVec in qwVecs]) for colWdVec in colWdVecs]) 157 | return ret 158 | 159 | def MaxScore(qwVecs, vec): 160 | ret = dt.emax([dt.dot_product(qwVec, vec) for qwVec in qwVecs]) 161 | return ret 162 | 163 | -------------------------------------------------------------------------------- /run5fold-firstq.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ "$#" -ne 1 ]; then 4 | echo "Usage: run5fold.sh trialname" 5 | exit 9 6 | fi 7 | 8 | #alias mklpython='LD_PRELOAD=/opt/intel/mkl/lib/intel64/libmkl_def.so:/opt/intel/mkl/lib/intel64/libmkl_avx2.so:/opt/intel/mkl/lib/intel64/libmkl_core.so:/opt/intel/mkl/lib/intel64/libmkl_intel_lp64.so:/opt/intel/mkl/lib/intel64/libmkl_intel_thread.so:/opt/intel/lib/intel64_lin/libiomp5.so python' 9 | alias mklpython='LD_PRELOAD=/opt/intel/mkl/lib/libmkl_def.so:/opt/intel/mkl/lib/libmkl_avx2.so:/opt/intel/mkl/lib/libmkl_core.so:/opt/intel/mkl/lib/libmkl_intel_lp64.so:/opt/intel/mkl/lib/libmkl_intel_thread.so:/opt/intel/lib/libiomp5.so python2' 10 | dm=2000 11 | ds=1 12 | 13 | for p in 0 1 2 3 4 5; 14 | do 15 | name="$1"-"$p" 16 | out=output/"$name".out 17 | model=model/"$name" 18 | log=log/"$name".log 19 | dirData=data 20 | echo $out 21 | mklpython sqafollow.py --dynet-mem $dm --dynet-seed $ds --expSym $p --dirData $dirData --firstOnly > $out 2>&1 & 22 | echo "Split: $p, PID: $!" 
23 | done 24 | -------------------------------------------------------------------------------- /run5fold-indep.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ "$#" -ne 1 ]; then 4 | echo "Usage: run5fold.sh trialname" 5 | exit 9 6 | fi 7 | 8 | #alias mklpython='LD_PRELOAD=/opt/intel/mkl/lib/intel64/libmkl_def.so:/opt/intel/mkl/lib/intel64/libmkl_avx2.so:/opt/intel/mkl/lib/intel64/libmkl_core.so:/opt/intel/mkl/lib/intel64/libmkl_intel_lp64.so:/opt/intel/mkl/lib/intel64/libmkl_intel_thread.so:/opt/intel/lib/intel64_lin/libiomp5.so python' 9 | alias mklpython='LD_PRELOAD=/opt/intel/mkl/lib/libmkl_def.so:/opt/intel/mkl/lib/libmkl_avx2.so:/opt/intel/mkl/lib/libmkl_core.so:/opt/intel/mkl/lib/libmkl_intel_lp64.so:/opt/intel/mkl/lib/libmkl_intel_thread.so:/opt/intel/lib/libiomp5.so python2' 10 | dm=2000 11 | ds=1 12 | 13 | for p in 0 1 2 3 4 5; 14 | do 15 | name="$1"-"$p" 16 | out=output/"$name".out 17 | model=model/"$name" 18 | log=log/"$name".log 19 | dirData=data 20 | echo $out 21 | mklpython sqafollow.py --dynet-mem $dm --dynet-seed $ds --expSym $p --model $model --log $log --dirData $dirData --indep > $out 2>&1 & 22 | echo "Split: $p, PID: $!" 23 | sleep 5 24 | done 25 | -------------------------------------------------------------------------------- /run5fold.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ "$#" -ne 1 ]; then 4 | echo "Usage: run5fold.sh trialname" 5 | exit 9 6 | fi 7 | 8 | #alias mklpython='LD_PRELOAD=/opt/intel/mkl/lib/intel64/libmkl_def.so:/opt/intel/mkl/lib/intel64/libmkl_avx2.so:/opt/intel/mkl/lib/intel64/libmkl_core.so:/opt/intel/mkl/lib/intel64/libmkl_intel_lp64.so:/opt/intel/mkl/lib/intel64/libmkl_intel_thread.so:/opt/intel/lib/intel64_lin/libiomp5.so python' 9 | alias mklpython='LD_PRELOAD=/opt/intel/mkl/lib/libmkl_def.so:/opt/intel/mkl/lib/libmkl_avx2.so:/opt/intel/mkl/lib/libmkl_core.so:/opt/intel/mkl/lib/libmkl_intel_lp64.so:/opt/intel/mkl/lib/libmkl_intel_thread.so:/opt/intel/lib/libiomp5.so python2' 10 | dm=2000 11 | ds=1 12 | 13 | for p in 0 1 2 3 4 5 14 | do 15 | name="$1"-"$p" 16 | out=output/"$name".out 17 | model=model/"$name" 18 | log=log/"$name".log 19 | dirData=data 20 | echo $out 21 | mklpython sqafollow.py --dynet-mem $dm --dynet-seed $ds --expSym $p --model $model --log $log --dirData $dirData > $out 2>&1 & 22 | echo "Split: $p, PID: $!" 23 | done 24 | -------------------------------------------------------------------------------- /seqtagging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # effort of writing python 2/3 compatiable code 4 | from __future__ import print_function 5 | from __future__ import division 6 | from __future__ import unicode_literals 7 | from future.utils import iteritems 8 | from operator import itemgetter, attrgetter, methodcaller 9 | 10 | # from sys import stdin 11 | # reload(sys) 12 | # sys.setdefaultencoding('utf8') 13 | 14 | import sys 15 | 16 | if sys.version < '3': 17 | from codecs import getwriter 18 | stderr = getwriter('utf-8')(sys.stderr) 19 | stdout = getwriter('utf-8')(sys.stdout) 20 | else: 21 | stderr = sys.stderr 22 | 23 | import dynet as dt 24 | from collections import Counter 25 | import random 26 | import util 27 | 28 | from statesearch import * 29 | 30 | # format of files: each line is "wordtag", blank line is new sentence. 
31 | 32 | class SeqState: 33 | O_state = -1 34 | non_I_state_list = [] 35 | 36 | def __init__(self,sentence,vt): 37 | #print(sentence, vt) 38 | self.action_history = [] 39 | self.words = sentence 40 | self.vt =vt 41 | self.n_tags = vt.size() 42 | self.tag_idx = 0 43 | 44 | if SeqState.O_state == -1: 45 | SeqState.O_state = vt.w2i["O"] 46 | for i in range(vt.size()): 47 | if "I-" not in vt.i2w[i]: 48 | SeqState.non_I_state_list.append(i) 49 | print("non I states", SeqState.non_I_state_list) 50 | print ("ntags",vt.size()) 51 | 52 | 53 | # def get_action_set(self): 54 | # return range(self.n_tags) 55 | 56 | def get_action_set(self): 57 | if (self.tag_idx > 0 and self.action_history[-1] == SeqState.O_state): 58 | return SeqState.non_I_state_list 59 | return range(self.n_tags) 60 | 61 | def get_action_set_withans(self,gold_ans): 62 | gold_act = gold_ans[self.tag_idx] 63 | return [[gold_act],[1]] 64 | #return self.get_action_set() # no trick is done here 65 | 66 | def is_end(self): 67 | assert len(self.action_history) <= len(self.words) 68 | if len(self.action_history) == len(self.words): 69 | return True 70 | else: 71 | return False 72 | 73 | def reward(self,golds): 74 | good = 0.0 75 | bad = 0.0 76 | 77 | for w, gold_tag, pred_tag in zip(self.words, golds, self.action_history): 78 | if gold_tag == pred_tag: 79 | good += 1 80 | else: 81 | bad += 1 82 | #print (good) 83 | #return good 84 | return good/len(golds) 85 | 86 | # the estimated final reward value of by treating the partial pass as the full actions, after executing the given partial actions; 87 | # if you do not know how to estimate the reward given the partial sequence, should return 0 here 88 | def estimated_reward(self, gold_ans, action): 89 | good = 0.9 90 | bad = 0.9 91 | actions = self.action_history[:] + [action] 92 | for i in range(len(actions)): 93 | if gold_ans[i] == actions[i]: 94 | good += 1 95 | else: 96 | bad += 1 97 | 98 | return good/len(actions) 99 | 100 | 101 | class SeqTaggingModel(): 102 | 103 | def __init__(self,init_grad,n_words,n_tags): 104 | self.model = dt.Model() 105 | self.learner = dt.SimpleSGDTrainer(self.model,e0=init_grad) 106 | self.E = self.model.add_lookup_parameters((n_words, 128)) 107 | 108 | self.pH = self.model.add_parameters((32, 50*2)) 109 | self.pO = self.model.add_parameters((n_tags, 32)) 110 | 111 | self.builders=[ 112 | dt.LSTMBuilder(1, 128, 50, self.model), 113 | dt.LSTMBuilder(1, 128, 50, self.model), 114 | ] 115 | 116 | 117 | class SeqScoreExpressionState(SeqState): 118 | 119 | def __init__(self, nmodel,sentence,vw,vt,copy=False): 120 | #super(SeqScoreExpressionState, self).__init__(sentence,vt) 121 | #super(SeqScoreExpressionState, self).__init__() 122 | SeqState.__init__(self,sentence,vt) 123 | self.path_score_expression = 0 124 | self.score = 0 125 | self.nm = nmodel 126 | self.vw = vw 127 | 128 | if copy == False: 129 | UNK = self.vw.w2i["_UNK_"] 130 | sent = self.words 131 | f_init, b_init = [b.initial_state() for b in self.nm.builders] 132 | wembs = [self.nm.E[self.vw.w2i.get(w, UNK)] for w in sent] 133 | self.fw = [x.output() for x in f_init.add_inputs(wembs)] 134 | self.bw = [x.output() for x in b_init.add_inputs(reversed(wembs))] 135 | self.bw.reverse() 136 | self.H = dt.parameter(self.nm.pH) 137 | self.O = dt.parameter(self.nm.pO) 138 | 139 | def get_next_score_expressions(self, action_list): 140 | #assert len(action_list) == self.n_tags 141 | i_repr = dt.concatenate([self.fw[self.tag_idx],self.bw[self.tag_idx]]) 142 | r_t = self.O*(dt.tanh(self.H * i_repr)) #results for all list 
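        # r_t is an n_tags-dimensional expression of unnormalized tag scores for the word at
        # position tag_idx: the forward/backward LSTM outputs are concatenated, squashed by
        # tanh(H * .), and projected by O; r_t[action] below then picks out the score of a
        # single candidate tag.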
143 | res_list=[] 144 | for action in action_list: 145 | res_list.append(r_t[action]) 146 | 147 | return [dt.concatenate(res_list), [0] * len(action_list)] #row or cols? 148 | 149 | def get_new_state_after_action(self, action,meta_info): 150 | assert action in self.get_action_set() 151 | new_state = self.clone() 152 | 153 | #will make it call helper function from parnet instead 154 | new_state.action_history.append(action) 155 | new_state.tag_idx +=1 156 | return new_state 157 | 158 | def clone(self): 159 | res = SeqScoreExpressionState(self.nm,self.words,self.vw,self.vt,copy=True) 160 | 161 | #will make it call helper function from parnet instead 162 | res.action_history = self.action_history[:] 163 | res.tag_idx = self.tag_idx 164 | 165 | res.fw = self.fw 166 | res.bw = self.bw 167 | res.H = self.H 168 | res.O = self.O 169 | 170 | return res 171 | 172 | def __str__(self): 173 | #return "<"+" ".join(self.words) + "> at idx: " + str(self.tag_idx) 174 | return "> " + str(self.tag_idx) + ": " + " ".join([str(x) for x in self.action_history]) 175 | 176 | def main(): 177 | train_file="./data/conll-4types-bio-dev.txt" 178 | test_file="./data/conll-4types-bio-test.txt" 179 | 180 | train=list(util.read(train_file)) 181 | test=list(util.read(test_file)) 182 | # train = train[:1] 183 | # test = train 184 | words=[] 185 | tags=[] 186 | 187 | wc=Counter() 188 | 189 | for s in train: 190 | for w,p in s: 191 | words.append(w) 192 | tags.append(p) 193 | wc[w]+=1 194 | 195 | 196 | words.append("_UNK_") 197 | #words=[w if wc[w] > 1 else "_UNK_" for w in words] 198 | #tags.append("_START_") 199 | 200 | for s in test: 201 | for w,p in s: 202 | words.append(w) 203 | 204 | vw = util.Vocab.from_corpus([words]) 205 | vt = util.Vocab.from_corpus([tags]) 206 | 207 | # for i in range(vt.size()): 208 | # print("tagidx,tag",i,vt.i2w[i]) 209 | nwords = vw.size() 210 | ntags = vt.size() 211 | 212 | 213 | neural_model = SeqTaggingModel(0.01,nwords,ntags) 214 | sm = BeamSearchInferencer(neural_model,5) 215 | 216 | for ITER in xrange(1000): 217 | random.shuffle(train) 218 | loss = 0 219 | for i,s in enumerate(train,1): 220 | dt.renew_cg() #very important! to renew the cg 221 | words = [x[0] for x in s] 222 | tags = [x[1] for x in s] 223 | tags_idxes = [vt.w2i[t] for t in tags] 224 | init_state = SeqScoreExpressionState(neural_model,words,vw,vt,False) 225 | loss += sm.beam_train_max_margin_with_answer_guidence(init_state,tags_idxes) 226 | 227 | #loss += sm.beam_train_expected_reward(init_state,tags_idxes) 228 | #loss += sm.beam_train_max_margin(init_state,tags_idxes) 229 | #loss += sm.beam_train_max_margin_with_goldactions(init_state,tags_idxes) 230 | #loss += sm.greedy_train_max_sumlogllh(init_state,tags_idxes) 231 | 232 | if i % 1000 == 0: 233 | print (i) 234 | 235 | neural_model.learner.update_epoch(1.0) 236 | 237 | accuracy = 0.0 238 | total = 0.0 239 | oidx = vt.w2i['O'] 240 | ocount = 0.0 241 | for i,s in enumerate(test,1): 242 | dt.renew_cg() #very important! 
to renew the cg 243 | words = [x[0] for x in s] 244 | tags = [x[1] for x in s] 245 | tags_idxes = [vt.w2i[t] for t in tags] 246 | init_state = SeqScoreExpressionState(neural_model,words,vw,vt,False) 247 | top_state = sm.beam_predict(init_state)[0] 248 | #top_state = sm.greedy_predict(init_state) 249 | for t in top_state.action_history: 250 | if t == oidx: 251 | ocount += 1 252 | 253 | accuracy += top_state.reward(tags_idxes) 254 | 255 | total += len(tags) 256 | 257 | print ("accuracy",accuracy/total) 258 | print ("O percentage",ocount/total) 259 | 260 | print("In epoch ", ITER, " avg loss (or negative reward) is ", loss) 261 | 262 | if __name__ == '__main__': 263 | main() 264 | -------------------------------------------------------------------------------- /sqafirst.1.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # effort of writing python 2/3 compatiable code 4 | from __future__ import print_function 5 | from __future__ import division 6 | from __future__ import unicode_literals 7 | from future.utils import iteritems 8 | from operator import itemgetter, attrgetter, methodcaller 9 | 10 | # from sys import stdin 11 | # reload(sys) 12 | # sys.setdefaultencoding('utf8') 13 | 14 | import sys, time 15 | 16 | if sys.version < '3': 17 | from codecs import getwriter 18 | stderr = getwriter('utf-8')(sys.stderr) 19 | stdout = getwriter('utf-8')(sys.stdout) 20 | else: 21 | stderr = sys.stderr 22 | 23 | import dynet as dt 24 | from collections import Counter 25 | import random 26 | import util 27 | 28 | from statesearch import * 29 | 30 | ######## START OF THE CODE ######## 31 | 32 | class SqaState: 33 | # Action type: 34 | # (1) SELECT X (# table columns) 35 | # (2) WHERE NULL (no condition, 1) 36 | # (3) WHERE Y=? 
(# columns) 37 | # (4) WHERE Y=Z (# rows) 38 | # Legit sequence: (1) -> (2), (1) -> (3) -> (4) 39 | 40 | ActSelect, ActWhereNul, ActWhereCol, ActWhereEqRow = xrange(4) 41 | 42 | def __init__(self, qinfo): 43 | self.action_history = [] 44 | self.qinfo = qinfo 45 | self.numCol = len(qinfo.headers) 46 | self.numRow = len(qinfo.entries) 47 | self.act2type = {} 48 | 49 | # Define the actions 50 | 51 | # ActSelect 52 | self.actSetSelectStartIdx = 0 53 | self.actSetSelect = xrange(self.actSetSelectStartIdx, self.actSetSelectStartIdx + self.numCol) 54 | for act in self.actSetSelect: 55 | self.act2type[act] = SqaState.ActSelect 56 | 57 | # ActWhereNul 58 | self.actSetWhereNulStartIdx = self.actSetSelectStartIdx + len(self.actSetSelect) 59 | self.actSetWhereNul = xrange(self.actSetWhereNulStartIdx, self.actSetWhereNulStartIdx + 1) 60 | for act in self.actSetWhereNul: 61 | self.act2type[act] = SqaState.ActWhereNul 62 | 63 | # ActWhereCol 64 | self.actSetWhereColStartIdx = self.actSetWhereNulStartIdx + len(self.actSetWhereNul) 65 | self.actSetWhereCol = xrange(self.actSetWhereColStartIdx, self.actSetWhereColStartIdx + self.numCol) 66 | for act in self.actSetWhereCol: 67 | self.act2type[act] = SqaState.ActWhereCol 68 | 69 | # ActWhereEqRow 70 | self.actSetWhereEqRowStartIdx = self.actSetWhereColStartIdx + len(self.actSetWhereCol) 71 | self.actSetWhereEqRow = xrange(self.actSetWhereEqRowStartIdx, self.actSetWhereEqRowStartIdx + self.numRow) 72 | for act in self.actSetWhereEqRow: 73 | self.act2type[act] = SqaState.ActWhereEqRow 74 | 75 | 76 | ### Auxiliary routines for action index mapping 77 | # Given the id of an action that belongs to ActSelect, return the column number 78 | def selectAct2Col(self, act): 79 | # check if actIdx belongs to the right action type 80 | if (self.act2type[act] != SqaState.ActSelect): 81 | return None 82 | col = act - self.actSetSelectStartIdx 83 | return col 84 | 85 | ### Auxiliary routines for action index mapping 86 | # Given the id of an action that belongs to ActWhereEq, return the table entry coordinate 87 | def whereEqAct2Coord(self, act): 88 | # check if actIdx belongs to the right action type 89 | if (self.act2type[act] != SqaState.ActWhereEq): 90 | return None 91 | idx = act - self.actSetWhereEqStartIdx 92 | r = idx // self.numCol 93 | c = idx % self.numCol 94 | return (r,c) 95 | 96 | ### Auxiliary routines for action index mapping 97 | # Given the id of an action that belongs to ActWhereCol, return the column number 98 | def whereColAct2Col(self, act): 99 | if (self.act2type[act] != SqaState.ActWhereCol): 100 | return None 101 | col = act - self.actSetWhereColStartIdx 102 | return col 103 | 104 | ### Auxiliary routines for action index mapping 105 | # Given the id of an action that belongs to ActWhereEqRow, return the row number 106 | def whereEqRowAct2Row(self, act): 107 | if (self.act2type[act] != SqaState.ActWhereEqRow): 108 | return None 109 | row = act - self.actSetWhereEqRowStartIdx 110 | return row 111 | 112 | ### Auxiliary routines for action index mapping 113 | # Given the coordinate of a table entry, return the ActWhereEq id 114 | # Not sure if needed now 115 | def coord2whereEqAct(self,r,c): 116 | return r*numCol + c + self.actSetWhereEqStartIdx 117 | 118 | 119 | # Given the current state (self), return a list of legitimate actions 120 | # Currently, it follows the STAGG's fashion and requires it does SELECT first. We can later relax it. 
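    # Worked index example (purely hypothetical table with 3 columns and 4 rows, to make the
    # numbering defined in __init__ concrete):
    #   ActSelect     -> actions 0..2    (one per column)
    #   ActWhereNul   -> action  3
    #   ActWhereCol   -> actions 4..6    (one per column)
    #   ActWhereEqRow -> actions 7..10   (one per row)
    # so [0, 3] reads "SELECT column 0 WHERE True", and [0, 5, 9] reads
    # "SELECT column 0 WHERE column 1 = (its value in row 2)".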
121 | # 122 | # (1) ActSelect: SELECT X (# table columns) 123 | # (2) ActWhereNul: WHERE NULL (no condition, 1) 124 | # (3) ActWhereCol: WHERE Y=? (# columns) 125 | # (4) ActWhereEqRow: WHERE Y=Z (# rows) 126 | # Legit sequence: (1) -> (2), (1) -> (3) -> (4) 127 | # 128 | def get_action_set(self): 129 | if not self.action_history: # empty action_history 130 | return self.actSetSelect 131 | else: 132 | last_act = self.action_history[-1] 133 | if self.act2type[last_act] == SqaState.ActSelect: 134 | return list(self.actSetWhereNul) + list(self.actSetWhereCol) 135 | elif self.act2type[last_act] == SqaState.ActWhereCol: 136 | return list(self.actSetWhereEqRow) 137 | return [] 138 | 139 | def is_end(self): 140 | return not self.get_action_set() # empty action_set 141 | 142 | # return a set of action that can lead to the gold state from the current state 143 | def get_action_set_withans(self, gold_ans): 144 | ret = [] 145 | for act in self.get_action_set(): 146 | if self.estimated_reward(gold_ans, act) > 0: # TODO: fix redundant calling estimated_reward by the search code 147 | ret.append(act) 148 | return ret 149 | 150 | # the estimated final reward value of a full path, after executing the given action 151 | def estimated_reward(self, gold_ans, action): 152 | # if this action is the final action to goal state (i.e., (2) ActWhereNul or (4) ActWhereEqRow) 153 | # use the real reward directly 154 | if self.act2type[action] == SqaState.ActWhereNul or self.act2type[action] == SqaState.ActWhereEqRow: 155 | path = self.action_history + [action] 156 | return self.reward(gold_ans, path) 157 | else: # treat it as select column only 158 | if self.act2type[action] == SqaState.ActSelect: 159 | return self.reward(gold_ans, [action, self.actSetWhereNul[0]]) 160 | else: # action is (3) ActWhereCol: 161 | return self.reward(gold_ans, self.action_history + [self.actSetWhereNul[0]]) 162 | 163 | # Reward = #(Gold INTERSECT Pred) / #(Gold UNION Pred) 164 | def reward(self, gold, action_history = None): 165 | if not gold: 166 | gold = self.qinfo.answer_coordinates 167 | 168 | # execute the parse 169 | pred = self.execute_parse(action_history) 170 | 171 | # if verbose: 172 | # print("gold coordinates", gold) 173 | # print("pred coordinates", pred) 174 | setGold = set(gold) 175 | setPred = set(pred) 176 | 177 | return float(len(setGold.intersection(setPred))) / len(setGold.union(setPred)) 178 | #return int(float(len(setGold.intersection(setPred))) / len(setGold.union(setPred))) #0-1 reward 179 | 180 | # Currently, it follows the STAGG's fashion and requires it does SELECT first. We can later relax it. 181 | # 182 | # (1) ActSelect: SELECT X (# table columns) 183 | # (2) ActWhereNul: WHERE NULL (no condition, 1) 184 | # (3) ActWhereCol: WHERE Y=? 
(# columns) 185 | # (4) ActWhereEqRow: WHERE Y=Z (# rows) 186 | # Legit sequence: (1) -> (2), (1) -> (3) -> (4) 187 | # 188 | def execute_parse(self, action_history=None): 189 | if not action_history: 190 | action_history = self.action_history 191 | 192 | # only execute if the parse is complete (i.e., length-2 or length-3) 193 | if len(action_history) != 2 and len(action_history) != 3: 194 | return [] 195 | 196 | # answer column 197 | actSel = action_history[0] 198 | ans_col = self.selectAct2Col(actSel) 199 | 200 | # check where condition 201 | actWhere = action_history[1] 202 | if self.act2type[actWhere] == SqaState.ActWhereNul: 203 | legit_rows = [r for r in xrange(self.numRow)] 204 | elif self.act2type[actWhere] == SqaState.ActWhereCol: 205 | cond_col = self.whereColAct2Col(actWhere) 206 | actWhereEqRow = action_history[2] 207 | cond_row = self.whereEqRowAct2Row(actWhereEqRow) 208 | cond_val = self.qinfo.entries[cond_row][cond_col] 209 | legit_rows = [r for r in xrange(self.numRow) if self.qinfo.entries[r][cond_col].lower() == cond_val.lower()] 210 | 211 | return [(r,ans_col) for r in legit_rows] 212 | 213 | # For debugging 214 | def act2str(self, act): 215 | if self.act2type[act] == SqaState.ActSelect: 216 | col = self.selectAct2Col(act) 217 | return "SELECT %s" % self.qinfo.headers[col] 218 | elif self.act2type[act] == SqaState.ActWhereEq: 219 | r,c = self.whereEqAct2Coord(act) 220 | return "WHERE %s = '%s'" % (self.qinfo.headers[c], self.qinfo.entries[r][c]) 221 | else: # self.act2type[act] == SqaState.ActWhereNul: 222 | return "WHERE True" 223 | 224 | class SqaModel(): 225 | 226 | WORD_EMBEDDING_DIM = 128 227 | 228 | def __init__(self, init_learning_rate, n_words): 229 | self.model = dt.Model() 230 | self.learner = dt.SimpleSGDTrainer(self.model, e0=init_learning_rate) 231 | self.E = self.model.add_lookup_parameters((n_words, SqaModel.WORD_EMBEDDING_DIM)) 232 | # similarity(v,o): (R^Tv)^T (R^To) 233 | self.R = self.model.add_parameters((SqaModel.WORD_EMBEDDING_DIM, SqaModel.WORD_EMBEDDING_DIM)) 234 | self.NulW = self.model.add_parameters((SqaModel.WORD_EMBEDDING_DIM)) 235 | 236 | class SqaScoreExpressionState(SqaState): 237 | 238 | def __init__(self, nmodel, qinfo, vw, init_example = True): 239 | SqaState.__init__(self, qinfo) 240 | self.path_score_expression = dt.scalarInput(0) 241 | self.score = 0 242 | self.nm = nmodel 243 | self.vw = vw 244 | 245 | if init_example: 246 | UNK = self.vw.w2i["_UNK_"] 247 | self.ques_word_sequence = self.qinfo.ques_word_sequence() 248 | 249 | # vectors of question words 250 | #self.ques_emb = [self.nm.E[self.vw.w2i.get(w, UNK)] for w in self.ques_word_sequence] 251 | self.ques_emb = dt.concatenate_cols([self.nm.E[self.vw.w2i.get(w, UNK)] for w in self.ques_word_sequence]) 252 | #self.ques_avg_emb = dt.average(self.ques_emb) 253 | 254 | # avg. vectors of column names 255 | self.headers_embs = [] 256 | for colname_word_sequence in self.qinfo.headers_word_sequences(): 257 | colname_emb = dt.average([self.nm.E[self.vw.w2i.get(w, UNK)] for w in colname_word_sequence]) 258 | self.headers_embs.append(colname_emb) 259 | 260 | # avg. 
vectors of table entries 261 | self.entries_embs = [] 262 | for row_word_sequences in self.qinfo.entries_word_sequences(): 263 | row_embs = [] 264 | for cell_word_sequence in row_word_sequences: 265 | row_embs.append(dt.average([self.nm.E[self.vw.w2i.get(w, UNK)] for w in cell_word_sequence])) 266 | self.entries_embs.append(row_embs) 267 | 268 | self.R = dt.parameter(self.nm.R) 269 | self.NulW = dt.parameter(self.nm.NulW) 270 | 271 | 272 | def get_next_score_expressions(self, legit_actions): 273 | 274 | res_list = [] 275 | for act in legit_actions: 276 | act_type = self.act2type[act] 277 | #qwVecs = [dt.transpose(self.R) * qemb for qemb in self.ques_emb] 278 | qwVecs = dt.transpose(self.R) * self.ques_emb 279 | 280 | if act_type == SqaState.ActSelect: 281 | # question_embedding x column_name_embedding 282 | col = self.selectAct2Col(act) 283 | col_emb = self.headers_embs[col] 284 | colnameVec = dt.transpose(self.R) * col_emb 285 | # max_w sim(w,colname) 286 | #colScore = dt.emax([dt.dot_product(qwVec, colnameVec) for qwVec in qwVecs]) 287 | colScore = dt.softmax(dt.transpose(qwVecs) * colnameVec) 288 | res_list.append(colScore) 289 | 290 | elif act_type == SqaState.ActWhereCol: # same as SqaState.ActSelect 291 | # question_embedding x column_name_embedding 292 | col = self.whereColAct2Col(act) 293 | col_emb = self.headers_embs[col] 294 | colnameVec = dt.transpose(self.R) * col_emb 295 | # max_w sim(w,colname) 296 | #colScore = dt.emax([dt.dot_product(qwVec, colnameVec) for qwVec in qwVecs]) 297 | colScore = dt.softmax(dt.transpose(qwVecs) * colnameVec) 298 | res_list.append(colScore) 299 | 300 | elif act_type == SqaState.ActWhereEqRow: 301 | r = self.whereEqRowAct2Row(act) 302 | c = self.whereColAct2Col(self.action_history[-1]) # assuming the last action of the curren state is ActWhereCol 303 | entryVec = dt.transpose(self.R) * self.entries_embs[r][c] 304 | # max_w sim(w,entry) 305 | #entScore = dt.emax([dt.dot_product(qwVec, entryVec) for qwVec in qwVecs]) 306 | entScore = dt.softmax(dt.transpose(qwVecs) * entryVec) 307 | res_list.append(entScore) 308 | 309 | elif act_type == SqaState.ActWhereNul: 310 | #res_list.append(dt.dot_product(dt.average(qwVecs), self.NulW)) 311 | res_list.append((dt.transpose(qwVecs) * self.NulW) / dt.scalarInput(float(len(self.ques_word_sequence)))) 312 | 313 | return dt.concatenate(res_list) 314 | 315 | def get_new_state_after_action(self, action): 316 | assert action in self.get_action_set() 317 | new_state = self.clone() 318 | new_state.action_history.append(action) 319 | return new_state 320 | 321 | def clone(self): 322 | res = SqaScoreExpressionState(self.nm, self.qinfo, self.vw,False) 323 | res.action_history = self.action_history[:] 324 | res.ques_word_sequence = self.ques_word_sequence 325 | 326 | # vectors of question words 327 | res.ques_emb = self.ques_emb 328 | #res.ques_avg_emb = self.ques_avg_emb 329 | 330 | # avg. vectors of column names 331 | res.headers_embs = self.headers_embs 332 | 333 | # avg. 
vectors of table entries 334 | res.entries_embs = self.entries_embs 335 | res.R = self.R 336 | res.NulW = self.NulW 337 | 338 | return res 339 | 340 | def __str__(self): 341 | return "> " + "\t".join([self.act2str(act) for act in self.action_history]) 342 | 343 | def main(): 344 | 345 | # Prepare training and testing (development) data 346 | 347 | data_folder = "data" 348 | train_file="%s/random-split-1-train.first.tsv" % data_folder 349 | test_file="%s/random-split-1-dev.first.tsv" % data_folder 350 | 351 | train = util.get_labeled_questions(train_file, data_folder) 352 | test = util.get_labeled_questions(test_file, data_folder) 353 | 354 | #train = train[:1000] 355 | #test = train 356 | 357 | # create a word embedding table 358 | 359 | words = set(["_UNK_", "_EMPTY_"]) 360 | for ex in train: 361 | words.update(ex.all_words()) 362 | for ex in test: 363 | words.update(ex.all_words()) 364 | 365 | vw = util.Vocab.from_corpus([words]) 366 | nwords = vw.size() 367 | 368 | neural_model = SqaModel(0.01, nwords) 369 | sm = BeamSearchInferencer(neural_model,1) 370 | 371 | # main loop 372 | start_time = time.time() 373 | for ITER in xrange(100): 374 | random.shuffle(train) 375 | loss = 0 376 | for i,qinfo in enumerate(train,1): 377 | dt.renew_cg() # very important! to renew the cg 378 | 379 | init_state = SqaScoreExpressionState(neural_model, qinfo ,vw) 380 | #loss += sm.beam_train_max_margin(init_state, qinfo.answer_coordinates) 381 | loss += sm.beam_train_max_margin_with_answer_guidence(init_state, qinfo.answer_coordinates) 382 | 383 | if i % 100 == 0: 384 | print (i, "/", len(train)) 385 | 386 | neural_model.learner.update_epoch(1.0) 387 | 388 | accuracy = 0.0 389 | all_reward = 0.0 390 | total = 0.0 391 | for i,qinfo in enumerate(test,1): 392 | dt.renew_cg() # very important! 
to renew the cg 393 | init_state = SqaScoreExpressionState(neural_model, qinfo ,vw) 394 | top1_state = sm.beam_predict(init_state)[0] 395 | rew = top1_state.reward(qinfo.answer_coordinates) 396 | 397 | all_reward += rew 398 | accuracy += int(rew) # 0-1, only get a point if all predictions are correct 399 | total += 1 400 | 401 | print("In epoch ", ITER, " avg loss (or negative reward) is ", loss) 402 | print ("reward", all_reward/total) 403 | print ("accuracy", accuracy/total) 404 | now_time = time.time() 405 | print ("Time taken in this epoch", now_time - start_time) 406 | start_time = now_time 407 | print () 408 | sys.stdout.flush() 409 | 410 | if __name__ == '__main__': 411 | main() 412 | -------------------------------------------------------------------------------- /sqafirst.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # effort of writing python 2/3 compatiable code 4 | from __future__ import print_function 5 | from __future__ import division 6 | from __future__ import unicode_literals 7 | from future.utils import iteritems 8 | from operator import itemgetter, attrgetter, methodcaller 9 | 10 | # from sys import stdin 11 | # reload(sys) 12 | # sys.setdefaultencoding('utf8') 13 | 14 | import sys, time, argparse 15 | import cProfile 16 | 17 | if sys.version < '3': 18 | from codecs import getwriter 19 | stderr = getwriter('utf-8')(sys.stderr) 20 | stdout = getwriter('utf-8')(sys.stdout) 21 | else: 22 | stderr = sys.stderr 23 | 24 | import dynet as dt 25 | from collections import Counter 26 | import random 27 | import util 28 | import config 29 | 30 | from statesearch import * 31 | 32 | ######## START OF THE CODE ######## 33 | 34 | class SqaState: 35 | # Action type: 36 | # (1) SELECT X (# table columns) 37 | # (2) WHERE NULL (no condition, 1) 38 | # (3) WHERE Y=? 
(# columns) 39 | # (4) WHERE Y=Z (# rows) 40 | # Legit sequence: (1) -> (2), (1) -> (3) -> (4) 41 | 42 | ActSelect, ActWhereNul, ActWhereCol, ActWhereEqRow = xrange(4) 43 | 44 | def __init__(self, qinfo): 45 | self.action_history = [] 46 | self.qinfo = qinfo 47 | self.numCol = len(qinfo.headers) 48 | self.numRow = len(qinfo.entries) 49 | self.act2type = {} 50 | 51 | # Define the actions 52 | 53 | # ActSelect 54 | self.actSetSelectStartIdx = 0 55 | self.actSetSelect = xrange(self.actSetSelectStartIdx, self.actSetSelectStartIdx + self.numCol) 56 | for act in self.actSetSelect: 57 | self.act2type[act] = SqaState.ActSelect 58 | 59 | # ActWhereNul 60 | self.actSetWhereNulStartIdx = self.actSetSelectStartIdx + len(self.actSetSelect) 61 | self.actSetWhereNul = xrange(self.actSetWhereNulStartIdx, self.actSetWhereNulStartIdx + 1) 62 | for act in self.actSetWhereNul: 63 | self.act2type[act] = SqaState.ActWhereNul 64 | 65 | # ActWhereCol 66 | self.actSetWhereColStartIdx = self.actSetWhereNulStartIdx + len(self.actSetWhereNul) 67 | self.actSetWhereCol = xrange(self.actSetWhereColStartIdx, self.actSetWhereColStartIdx + self.numCol) 68 | for act in self.actSetWhereCol: 69 | self.act2type[act] = SqaState.ActWhereCol 70 | 71 | # ActWhereEqRow 72 | self.actSetWhereEqRowStartIdx = self.actSetWhereColStartIdx + len(self.actSetWhereCol) 73 | self.actSetWhereEqRow = xrange(self.actSetWhereEqRowStartIdx, self.actSetWhereEqRowStartIdx + self.numRow) 74 | for act in self.actSetWhereEqRow: 75 | self.act2type[act] = SqaState.ActWhereEqRow 76 | 77 | 78 | ### Auxiliary routines for action index mapping 79 | # Given the id of an action that belongs to ActSelect, return the column number 80 | def selectAct2Col(self, act): 81 | # check if actIdx belongs to the right action type 82 | if (self.act2type[act] != SqaState.ActSelect): 83 | return None 84 | col = act - self.actSetSelectStartIdx 85 | return col 86 | 87 | ### Auxiliary routines for action index mapping 88 | # Given the id of an action that belongs to ActWhereEq, return the table entry coordinate 89 | def whereEqAct2Coord(self, act): 90 | # check if actIdx belongs to the right action type 91 | if (self.act2type[act] != SqaState.ActWhereEq): 92 | return None 93 | idx = act - self.actSetWhereEqStartIdx 94 | r = idx // self.numCol 95 | c = idx % self.numCol 96 | return (r,c) 97 | 98 | ### Auxiliary routines for action index mapping 99 | # Given the id of an action that belongs to ActWhereCol, return the column number 100 | def whereColAct2Col(self, act): 101 | if (self.act2type[act] != SqaState.ActWhereCol): 102 | return None 103 | col = act - self.actSetWhereColStartIdx 104 | return col 105 | 106 | ### Auxiliary routines for action index mapping 107 | # Given the id of an action that belongs to ActWhereEqRow, return the row number 108 | def whereEqRowAct2Row(self, act): 109 | if (self.act2type[act] != SqaState.ActWhereEqRow): 110 | return None 111 | row = act - self.actSetWhereEqRowStartIdx 112 | return row 113 | 114 | ### Auxiliary routines for action index mapping 115 | # Given the coordinate of a table entry, return the ActWhereEq id 116 | # Not sure if needed now 117 | def coord2whereEqAct(self,r,c): 118 | return r*numCol + c + self.actSetWhereEqStartIdx 119 | 120 | 121 | # Given the current state (self), return a list of legitimate actions 122 | # Currently, it follows the STAGG's fashion and requires it does SELECT first. We can later relax it. 
123 | # 124 | # (1) ActSelect: SELECT X (# table columns) 125 | # (2) ActWhereNul: WHERE NULL (no condition, 1) 126 | # (3) ActWhereCol: WHERE Y=? (# columns) 127 | # (4) ActWhereEqRow: WHERE Y=Z (# rows) 128 | # Legit sequence: (1) -> (2), (1) -> (3) -> (4) 129 | # 130 | def get_action_set(self): 131 | if not self.action_history: # empty action_history 132 | return self.actSetSelect 133 | else: 134 | last_act = self.action_history[-1] 135 | if self.act2type[last_act] == SqaState.ActSelect: 136 | return list(self.actSetWhereNul) + list(self.actSetWhereCol) 137 | elif self.act2type[last_act] == SqaState.ActWhereCol: 138 | return list(self.actSetWhereEqRow) 139 | return [] 140 | 141 | def is_end(self): 142 | #return len(self.action_history) == 1 # only SELECT X 143 | return not self.get_action_set() # empty action_set 144 | 145 | # return a set of action that can lead to the gold state from the current state 146 | def get_action_set_withans(self, gold_ans): 147 | ret = [] 148 | for act in self.get_action_set(): 149 | if self.estimated_reward(gold_ans, act) > 0: # TODO: fix redundant calling estimated_reward by the search code 150 | ret.append(act) 151 | return ret 152 | 153 | # the estimated final reward value of a full path, after executing the given action 154 | def estimated_reward(self, gold_ans, action): 155 | # if this action is the final action to goal state (i.e., (2) ActWhereNul or (4) ActWhereEqRow) 156 | # use the real reward directly 157 | if self.act2type[action] == SqaState.ActWhereNul or self.act2type[action] == SqaState.ActWhereEqRow: 158 | path = self.action_history + [action] 159 | return self.reward(gold_ans, path) 160 | else: # treat it as select column only 161 | if self.act2type[action] == SqaState.ActSelect: 162 | return self.reward(gold_ans, [action, self.actSetWhereNul[0]]) 163 | else: # action is (3) ActWhereCol: 164 | return self.reward(gold_ans, self.action_history + [self.actSetWhereNul[0]]) 165 | 166 | # Reward = #(Gold INTERSECT Pred) / #(Gold UNION Pred) 167 | def reward(self, gold, action_history = None): 168 | if not gold: 169 | gold = self.qinfo.answer_coordinates 170 | 171 | # execute the parse 172 | pred = self.execute_parse(action_history) 173 | 174 | # if verbose: 175 | # print("gold coordinates", gold) 176 | # print("pred coordinates", pred) 177 | setGold = set(gold) 178 | setPred = set(pred) 179 | 180 | ret = float(len(setGold.intersection(setPred))) / len(setGold.union(setPred)) 181 | #if (ret > 0 and ret < 1): 182 | # print("qid", self.qinfo.seq_qid, "gold:", gold, "pred:", pred, "reward:", ret) 183 | 184 | return ret 185 | #return int(float(len(setGold.intersection(setPred))) / len(setGold.union(setPred))) #0-1 reward 186 | 187 | # Currently, it follows the STAGG's fashion and requires it does SELECT first. We can later relax it. 188 | # 189 | # (1) ActSelect: SELECT X (# table columns) 190 | # (2) ActWhereNul: WHERE NULL (no condition, 1) 191 | # (3) ActWhereCol: WHERE Y=? 
(# columns) 192 | # (4) ActWhereEqRow: WHERE Y=Z (# rows) 193 | # Legit sequence: (1) -> (2), (1) -> (3) -> (4) 194 | # 195 | def execute_parse(self, action_history=None): 196 | if not action_history: 197 | action_history = self.action_history 198 | 199 | # only execute if the parse is complete (i.e., length-2 or length-3) 200 | #if len(action_history) != 1 and len(action_history) != 2: # and len(action_history) != 3: 201 | 202 | if len(action_history) != 2 and len(action_history) != 3: 203 | return [] 204 | 205 | # answer column 206 | actSel = action_history[0] 207 | ans_col = self.selectAct2Col(actSel) 208 | 209 | #return [(r,ans_col) for r in xrange(self.numRow) if (r,ans_col) not in self.qinfo.illegit_answer_coordinates] 210 | 211 | # check where condition 212 | actWhere = action_history[1] 213 | if self.act2type[actWhere] == SqaState.ActWhereNul: 214 | legit_rows = [r for r in xrange(self.numRow)] 215 | elif self.act2type[actWhere] == SqaState.ActWhereCol: 216 | cond_col = self.whereColAct2Col(actWhere) 217 | actWhereEqRow = action_history[2] 218 | cond_row = self.whereEqRowAct2Row(actWhereEqRow) 219 | cond_val = self.qinfo.entries[cond_row][cond_col] 220 | legit_rows = [r for r in xrange(self.numRow) if self.qinfo.entries[r][cond_col].lower() == cond_val.lower()] 221 | 222 | return [(r,ans_col) for r in legit_rows] 223 | 224 | # For debugging 225 | def act2str(self, act): 226 | if self.act2type[act] == SqaState.ActSelect: 227 | col = self.selectAct2Col(act) 228 | return "SELECT %s" % self.qinfo.headers[col] 229 | elif self.act2type[act] == SqaState.ActWhereEq: 230 | r,c = self.whereEqAct2Coord(act) 231 | return "WHERE %s = '%s'" % (self.qinfo.headers[c], self.qinfo.entries[r][c]) 232 | else: # self.act2type[act] == SqaState.ActWhereNul: 233 | return "WHERE True" 234 | 235 | class SqaModel(): 236 | 237 | WORD_EMBEDDING_DIM = config.d["WORD_EMBEDDING_DIM"] 238 | LSTM_HIDDEN_DIM = config.d["LSTM_HIDDEN_DIM"] 239 | 240 | def __init__(self, init_learning_rate, vw): 241 | self.model = dt.Model() 242 | self.vw = vw 243 | n_words = vw.size() 244 | 245 | self.learner = dt.SimpleSGDTrainer(self.model, e0=init_learning_rate) 246 | self.E = self.model.add_lookup_parameters((n_words, SqaModel.WORD_EMBEDDING_DIM)) 247 | # similarity(v,o): v^T o 248 | self.SelColW = self.model.add_parameters((4)) 249 | self.SelColWhereW = self.model.add_parameters((4)) 250 | self.NulW = self.model.add_parameters((SqaModel.WORD_EMBEDDING_DIM)) 251 | self.ColW = self.model.add_parameters((SqaModel.WORD_EMBEDDING_DIM)) 252 | 253 | # LSTM question representation 254 | self.builders=[ 255 | dt.LSTMBuilder(1, SqaModel.WORD_EMBEDDING_DIM, SqaModel.LSTM_HIDDEN_DIM, self.model), 256 | dt.LSTMBuilder(1, SqaModel.WORD_EMBEDDING_DIM, SqaModel.LSTM_HIDDEN_DIM, self.model) 257 | ] 258 | self.pH = self.model.add_parameters((SqaModel.WORD_EMBEDDING_DIM, SqaModel.LSTM_HIDDEN_DIM*2)) 259 | 260 | if config.d["USE_PRETRAIN_WORD_EMBEDDING"]: 261 | n_hit_pretrain = 0.0 262 | trie = config.d["embeddingtrie"] 263 | print ("beginning to load embeddings....") 264 | for i in range(n_words): 265 | word = self.vw.i2w[i].lower() 266 | results = trie.items(word+ config.d["recordtriesep"]) 267 | if len(results) == 1: 268 | pretrain_v = np.array(list(results[0][1])) 269 | pretrain_v = pretrain_v/np.linalg.norm(pretrain_v) 270 | self.E.init_row(i,pretrain_v) 271 | n_hit_pretrain += 1 272 | else: 273 | pretrain_v = self.E[i].npvalue() 274 | pretrain_v = pretrain_v/np.linalg.norm(pretrain_v) 275 | self.E.init_row(i,pretrain_v) 276 | 277 | 278 | print 
("the number of words that are in pretrain", n_hit_pretrain, n_words, n_hit_pretrain/n_words) 279 | print ("loading complete!") 280 | 281 | 282 | 283 | 284 | class SqaScoreExpressionState(SqaState): 285 | 286 | def __init__(self, nmodel, qinfo, vw, init_example = True): 287 | SqaState.__init__(self, qinfo) 288 | self.path_score_expression = dt.scalarInput(0) 289 | self.score = 0 290 | self.nm = nmodel 291 | self.vw = vw 292 | self.H = dt.parameter(self.nm.pH) 293 | 294 | if init_example: 295 | UNK = self.vw.w2i["_UNK_"] 296 | 297 | # vectors of question words 298 | self.ques_emb = [self.nm.E[self.vw.w2i.get(w, UNK)] for w in self.qinfo.ques_word_sequence] 299 | #self.ques_avg_emb = dt.average(self.ques_emb) 300 | #self.ques_emb = dt.concatenate_cols([self.nm.E[self.vw.w2i.get(w, UNK)] for w in self.qinfo.ques_word_sequence]) 301 | 302 | # avg. vectors of column names 303 | self.headers_embs = [] 304 | for colname_word_sequence in self.qinfo.headers_word_sequences: 305 | colname_emb = dt.average([self.nm.E[self.vw.w2i.get(w, UNK)] for w in colname_word_sequence]) 306 | self.headers_embs.append(colname_emb) 307 | 308 | # avg. vectors of table entries 309 | self.entries_embs = [] 310 | for row_word_sequences in self.qinfo.entries_word_sequences: 311 | row_embs = [] 312 | for cell_word_sequence in row_word_sequences: 313 | row_embs.append(dt.average([self.nm.E[self.vw.w2i.get(w, UNK)] for w in cell_word_sequence])) 314 | self.entries_embs.append(row_embs) 315 | 316 | self.NulW = dt.parameter(self.nm.NulW) 317 | self.ColW = dt.parameter(self.nm.ColW) 318 | self.SelColW = dt.parameter(self.nm.SelColW) 319 | self.SelColWhereW = dt.parameter(self.nm.SelColWhereW) 320 | 321 | # question LSTM 322 | f_init, b_init = [b.initial_state() for b in self.nm.builders] 323 | wembs = [self.nm.E[self.vw.w2i.get(w, UNK)] for w in self.qinfo.ques_word_sequence] 324 | self.fw = [x.output() for x in f_init.add_inputs(wembs)] 325 | self.bw = [x.output() for x in b_init.add_inputs(reversed(wembs))] 326 | self.bw.reverse() 327 | 328 | 329 | def get_next_score_expressions(self, legit_actions): 330 | 331 | res_list = [] 332 | for act in legit_actions: 333 | act_type = self.act2type[act] 334 | qwVecs = self.ques_emb 335 | qwAvgVec = dt.average(qwVecs) 336 | 337 | i_repr = dt.concatenate([self.fw[-1],self.bw[0]]) 338 | qLSTMVec = dt.tanh(self.H * i_repr) # question words LSTM embedding 339 | 340 | if act_type == SqaState.ActSelect: 341 | # question_embedding x column_name_embedding 342 | col = self.selectAct2Col(act) 343 | colnameVec = self.headers_embs[col] 344 | 345 | colPriorScore = dt.dot_product(self.ColW, colnameVec) 346 | colMaxScore = dt.emax([dt.dot_product(qwVec, colnameVec) for qwVec in qwVecs]) 347 | colAvgScore = dt.dot_product(qwAvgVec, colnameVec) 348 | colQLSTMScore = dt.dot_product(qLSTMVec, colnameVec) 349 | 350 | colScore = dt.dot_product(self.SelColW, dt.concatenate([colPriorScore, colMaxScore, colAvgScore, colQLSTMScore])) 351 | 352 | res_list.append(colScore) 353 | 354 | elif act_type == SqaState.ActWhereCol: # same as SqaState.ActSelect 355 | # question_embedding x column_name_embedding 356 | col = self.whereColAct2Col(act) 357 | colnameVec = self.headers_embs[col] 358 | 359 | colPriorScore = dt.dot_product(self.ColW, colnameVec) 360 | colMaxScore = dt.emax([dt.dot_product(qwVec, colnameVec) for qwVec in qwVecs]) 361 | colAvgScore = dt.dot_product(qwAvgVec, colnameVec) 362 | colQLSTMScore = dt.dot_product(qLSTMVec, colnameVec) 363 | 364 | colScore = dt.dot_product(self.SelColWhereW, 
dt.concatenate([colPriorScore, colMaxScore, colAvgScore, colQLSTMScore])) 365 | 366 | res_list.append(colScore) 367 | 368 | elif act_type == SqaState.ActWhereEqRow: 369 | r = self.whereEqRowAct2Row(act) 370 | c = self.whereColAct2Col(self.action_history[-1]) # assuming the last action of the curren state is ActWhereCol 371 | entryVec = self.entries_embs[r][c] 372 | # max_w sim(w,entry) 373 | entScore = dt.emax([dt.dot_product(qwVec, entryVec) for qwVec in qwVecs]) 374 | res_list.append(entScore) 375 | 376 | elif act_type == SqaState.ActWhereNul: 377 | res_list.append(dt.dot_product(dt.average(qwVecs), self.NulW)) 378 | 379 | return dt.concatenate(res_list) 380 | 381 | def get_new_state_after_action(self, action): 382 | assert action in self.get_action_set() 383 | new_state = self.clone() 384 | new_state.action_history.append(action) 385 | return new_state 386 | 387 | def clone(self): 388 | res = SqaScoreExpressionState(self.nm, self.qinfo, self.vw, False) 389 | res.action_history = self.action_history[:] 390 | 391 | # vectors of question words 392 | res.ques_emb = self.ques_emb 393 | #res.ques_avg_emb = self.ques_avg_emb 394 | 395 | # avg. vectors of column names 396 | res.headers_embs = self.headers_embs 397 | 398 | # avg. vectors of table entries 399 | res.entries_embs = self.entries_embs 400 | res.ColW = self.ColW 401 | res.NulW = self.NulW 402 | 403 | res.SelColW = self.SelColW 404 | res.SelColWhereW = self.SelColWhereW 405 | res.fw = self.fw 406 | res.bw = self.bw 407 | 408 | return res 409 | 410 | def __str__(self): 411 | return "> " + "\t".join([self.act2str(act) for act in self.action_history]) 412 | 413 | def main(): 414 | 415 | parser = argparse.ArgumentParser(description='Targeting "first questions" only.') 416 | parser.add_argument('--expSym', help='1, 2, 3, 4, 5 or 0 (full)', type=int) 417 | parser.add_argument('--dynet-mem') 418 | parser.add_argument('--dynet-seed') 419 | args = parser.parse_args() 420 | 421 | # Prepare training and testing (development) data 422 | random.seed(1) 423 | 424 | data_folder = "data" 425 | if args.expSym == 0: 426 | print("Full Train/Test splits...") 427 | train_file="%s/train.first.tsv" % data_folder 428 | test_file="%s/test.first.tsv" % data_folder 429 | elif args.expSym in xrange(1,6): 430 | print ("Random-split-%d-train/dev..." % args.expSym) 431 | train_file="%s/random-split-%d-train.first.tsv" % (data_folder, args.expSym) 432 | test_file="%s/random-split-%d-dev.first.tsv" % (data_folder, args.expSym) 433 | else: 434 | print("Unknown experimental setting...") 435 | return 436 | 437 | print("=" * 80) 438 | print("Train",train_file) 439 | print("Test",test_file) 440 | print(config.d) 441 | print(">" * 8, "begin experiments") 442 | 443 | train = util.get_labeled_questions(train_file, data_folder) 444 | test = util.get_labeled_questions(test_file, data_folder) 445 | 446 | # create a word embedding table 447 | 448 | words = set(["_UNK_", "_EMPTY_"]) 449 | for ex in train: 450 | words.update(ex.all_words) 451 | for ex in test: 452 | words.update(ex.all_words) 453 | 454 | vw = util.Vocab.from_corpus([words]) 455 | nwords = vw.size() 456 | 457 | neural_model = SqaModel(0.01, vw) 458 | sm = BeamSearchInferencer(neural_model,config.d["beam_size"]) 459 | 460 | # main loop 461 | start_time = time.time() 462 | max_reward_at_epoch = [0,0] 463 | for ITER in xrange(config.d["NUM_ITER"]): 464 | random.shuffle(train) 465 | loss = 0 466 | for i,qinfo in enumerate(train,1): 467 | dt.renew_cg() # very important! 
to renew the cg 468 | 469 | init_state = SqaScoreExpressionState(neural_model, qinfo ,vw) 470 | #loss += sm.beam_train_max_margin(init_state, qinfo.answer_coordinates) 471 | loss += sm.beam_train_max_margin_with_answer_guidence(init_state, qinfo.answer_coordinates) 472 | 473 | if i % 100 == 0: 474 | print (i, "/", len(train)) 475 | 476 | neural_model.learner.update_epoch(1.0) 477 | 478 | accuracy = 0.0 479 | all_reward = 0.0 480 | total = 0.0 481 | for i,qinfo in enumerate(test,1): 482 | dt.renew_cg() # very important! to renew the cg 483 | init_state = SqaScoreExpressionState(neural_model, qinfo ,vw) 484 | top1_state = sm.beam_predict(init_state)[0] 485 | rew = top1_state.reward(qinfo.answer_coordinates) 486 | 487 | all_reward += rew 488 | accuracy += int(rew) # 0-1, only get a point if all predictions are correct 489 | total += 1 490 | 491 | print("In epoch ", ITER, " avg loss (or negative reward) is ", loss) 492 | reported_reward = all_reward/total 493 | reported_accuracy = accuracy/total 494 | print ("reward", reported_reward) 495 | print ("accuracy", reported_accuracy) 496 | if (reported_reward > max_reward_at_epoch[0]): 497 | max_reward_at_epoch = (reported_reward, reported_accuracy, ITER) 498 | 499 | now_time = time.time() 500 | print ("Time taken in this epoch", now_time - start_time) 501 | start_time = now_time 502 | print("Best Reward: %f (Accuracy: %f) at epoch %d" % max_reward_at_epoch) 503 | print () 504 | sys.stdout.flush() 505 | 506 | if __name__ == '__main__': 507 | cProfile.run('main()') 508 | -------------------------------------------------------------------------------- /sqafollow.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # effort of writing python 2/3 compatiable code 4 | from __future__ import print_function 5 | from __future__ import division 6 | from __future__ import unicode_literals 7 | from future.utils import iteritems 8 | from operator import itemgetter, attrgetter, methodcaller 9 | 10 | import sys, time, argparse, csv 11 | import cProfile 12 | 13 | if sys.version < '3': 14 | from codecs import getwriter 15 | 16 | stderr = getwriter('utf-8')(sys.stderr) 17 | stdout = getwriter('utf-8')(sys.stdout) 18 | else: 19 | stderr = sys.stderr 20 | 21 | import dynet as dt 22 | from collections import Counter 23 | import random 24 | import util 25 | import config 26 | import cPickle 27 | import copy 28 | 29 | from action import * 30 | from statesearch import * 31 | from sqastate import * 32 | from sqamodel import * 33 | import nnmodule as nnmod 34 | 35 | ######## START OF THE CODE ######## 36 | 37 | ''' 38 | This is forked from "sqafirst.py" and is designed for "follow-up" questions. 39 | In addition to the regular input, question and table, it also has: 40 | (1) the previous question 41 | (2) the parse of the previous question 42 | (3) the results of the previous question 43 | In other words, the action space is bigger in the sense that it can ignore the previous question completely, 44 | or it can take the previous question into account. 
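
A typical invocation (mirroring run5fold.sh, which launches one process per split under
Python 2 with the MKL libraries preloaded; the model/log paths below are just placeholders):

    python2 sqafollow.py --dynet-mem 2000 --dynet-seed 1 --expSym 1 \
        --model model/mytrial-1 --log log/mytrial-1.log --dirData data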
45 | ''' 46 | 47 | 48 | def test_model(test, search_manager, model, fLog=None, fnRes='', blWTQ=False, indep=False): 49 | if fnRes: 50 | if blWTQ: 51 | fRes = open(fnRes, 'w') 52 | else: 53 | fieldnames = ['id', 'annotator', 'position', 'answer_coordinates'] 54 | fRes = csv.DictWriter(open(fnRes, 'w'), delimiter=str('\t'), fieldnames=fieldnames) 55 | fRes.writeheader() 56 | 57 | accuracy = 0.0 58 | all_reward = 0.0 59 | total = 0.0 60 | for i, qinfo in enumerate(test, 1): 61 | dt.renew_cg() # very important! to renew the cg 62 | 63 | if qinfo.seq_qid[-1] == '0' or indep: # first question or treat all questions independently 64 | init_state = SqaScoreExpressionState(model, qinfo, testmode=True) 65 | else: 66 | init_state = SqaScoreExpressionState(model, qinfo, resinfo=resinfo, testmode=True) 67 | 68 | top_states = search_manager.beam_predict(init_state) 69 | top1_state = top_states[0] 70 | rew = top1_state.reward(qinfo.answer_coordinates) 71 | pred = top1_state.recent_pred 72 | 73 | # use the predicted answers & the most common column index 74 | 75 | if pred == []: 76 | # empty answers, set pred_column_idx to be 0, as it does not matter... 77 | resinfo = util.ResultInfo(qinfo.seq_qid, qinfo.question, qinfo.ques_word_sequence, pred, 0) 78 | else: 79 | pred_columns = [coord[1] for coord in pred] 80 | pred_column_idx = Counter(pred_columns).most_common(1)[0][0] 81 | resinfo = util.ResultInfo(qinfo.seq_qid, qinfo.question, qinfo.ques_word_sequence, 82 | pred, pred_column_idx) 83 | # output to result file 84 | if fnRes: 85 | if blWTQ: 86 | id, annotator, position = qinfo.seq_qid.split('_') 87 | fRes.write(id) 88 | for coord in pred: 89 | fRes.write("\t%s" % qinfo.entries[coord[0]][coord[1]]) 90 | fRes.write('\n') 91 | else: 92 | outRow = {} 93 | outRow['id'], outRow['annotator'], outRow['position'] = qinfo.seq_qid.split('_') 94 | outRow['answer_coordinates'] = "[%s]" % ", ".join(["'(%d, %d)'" % coord for coord in pred]) 95 | fRes.writerow(outRow) 96 | 97 | # output detailed predictions 98 | if fLog: 99 | fLog.write("(%s) %s\n" % (qinfo.seq_qid, qinfo.question)) 100 | fLog.write("%s/%s\n" % (config.d["dirTable"], qinfo.table_file)) 101 | if config.d["guessLogPass"]: 102 | best_reward_state = \ 103 | search_manager.beam_find_actions_with_answer_guidence(init_state, qinfo.answer_coordinates)[0] 104 | fLog.write("Guessed Gold Parse: %s\n" % best_reward_state) 105 | 106 | if config.d["verbose-dump"]: 107 | for i, s in enumerate(top_states, 1): 108 | fLog.write("Parse %d: [%f] %s\n" % (i, s.score, s)) 109 | else: 110 | fLog.write("Parse: %s\n" % top1_state) 111 | 112 | fLog.write("Answer: %s\n" % ", ".join(["(%d,%d)" % coord for coord in qinfo.answer_coordinates])) 113 | fLog.write("Predictions: %s\n" % ", ".join(["(%d,%d)" % coord for coord in pred])) 114 | fLog.write("Reward: %f\n" % rew) 115 | fLog.write("Accuracy: %d\n" % int(rew)) 116 | fLog.write("\n") 117 | 118 | all_reward += rew 119 | accuracy += int(rew) # 0-1, only get a point if all predictions are correct 120 | total += 1 121 | 122 | reported_reward = all_reward / total 123 | reported_accuracy = accuracy / total 124 | 125 | if fLog: 126 | fLog.write("Average reward: %f\n" % reported_reward) 127 | fLog.write("Average accuracy %f\n" % reported_accuracy) 128 | 129 | if fnRes and blWTQ: 130 | fRes.close() 131 | 132 | return reported_reward, reported_accuracy 133 | 134 | 135 | def main(): 136 | parser = argparse.ArgumentParser(description='Targeting "first questions" only.') 137 | parser.add_argument('--expSym', 138 | help='1, 2, 3, 4, 5 or 0 
(full), or 11 (WTQ), 21 (WTQ-dev1) .. 25 (WTQ-dev5), or -1 (quick test)', 139 | type=int) 140 | parser.add_argument('--dynet-mem') 141 | parser.add_argument('--dynet-seed') 142 | parser.add_argument('--log', help='file to store additional log information', default='') 143 | parser.add_argument('--res', help='file to store the predictions on test set in the "official" format', default='') 144 | parser.add_argument('--model', help='prefix of the file that stores the model parameters', default='') 145 | parser.add_argument('--firstOnly', help='load only first questions', action='store_const', const=True, 146 | default=False) 147 | parser.add_argument('--dirData', help='data folder', default='data') 148 | parser.add_argument('--evalModel', help='evaluate a particular model on the test data only', default='') 149 | parser.add_argument('--evalOracle', help='yes/no switch with see oracle results for the action space', 150 | action='store_const', const=True, default=False) 151 | parser.add_argument('--wtq', help='WikiTableQuestions output', action='store_const', const=True, default=False) 152 | parser.add_argument('--indep', help='Treat all questions independently', action='store_const', const=True, 153 | default=False) 154 | 155 | args = parser.parse_args() 156 | 157 | # Prepare training and testing (development) data 158 | random.seed(1) 159 | 160 | data_folder = args.dirData 161 | # config.d["AnnotatedTableDir"] = args.dirData.replace('data', config.d["AnnotatedTableDir"]) 162 | if args.expSym == 0: 163 | print("Full Train/Test splits...") 164 | train_file = "%s/train.tsv" % data_folder 165 | test_file = "%s/test.tsv" % data_folder 166 | elif args.expSym in xrange(1, 6): 167 | print("Random-split-%d-train/dev..." % args.expSym) 168 | train_file = "%s/random-split-%d-train.tsv" % (data_folder, args.expSym) 169 | test_file = "%s/random-split-%d-dev.tsv" % (data_folder, args.expSym) 170 | elif args.expSym == -1: 171 | print("Quick code test...") 172 | train_file = "%s/unit.tsv" % data_folder 173 | test_file = "%s/unit.tsv" % data_folder 174 | elif args.expSym == 11: 175 | print("WikiTable Questions...") 176 | train_file = "%s/training.tsv" % data_folder 177 | test_file = "%s/pristine-unseen-tables.tsv" % data_folder 178 | elif args.expSym in xrange(21, 26): 179 | print("WikiTable Questions -- Dev1/Dev1...") 180 | train_file = "%s/random-split-%d-dev.tsv" % (data_folder, args.expSym - 20) 181 | test_file = "%s/random-split-%d-dev.tsv" % (data_folder, args.expSym - 20) 182 | else: 183 | assert False, "Unknown experimental setting..." 
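        # (The return below is only reachable if Python runs with -O, which strips the
        # assert above; otherwise the assert already aborts with the same message.)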
184 | return 185 | 186 | if args.evalModel: 187 | evalModel(args.evalModel, data_folder, test_file, args.log, args.res, args.wtq, args.indep) 188 | return 189 | 190 | fLog = None 191 | if args.log: 192 | fLog = open(args.log, 'w') 193 | 194 | print("=" * 80) 195 | print("Train", train_file) 196 | print("Test", test_file) 197 | print(config.d) 198 | print(">" * 8, "begin experiments") 199 | 200 | train = util.get_labeled_questions(train_file, data_folder, args.firstOnly) 201 | test = util.get_labeled_questions(test_file, data_folder, args.firstOnly) 202 | 203 | # create a word embedding table 204 | 205 | words = set(["_UNK_", "_EMPTY_"]) 206 | for ex in train: 207 | words.update(ex.all_words) 208 | for ex in test: 209 | words.update(ex.all_words) 210 | 211 | vw = util.Vocab.from_corpus([words]) 212 | nwords = vw.size() 213 | 214 | neural_model = SqaModel(0.01, vw) 215 | sm = BeamSearchInferencer(neural_model, config.d["beam_size"]) 216 | sm.only_one_best = config.OnlyOneBest 217 | 218 | if args.evalOracle: 219 | evalOracleActions(neural_model, sm, train) 220 | return 221 | 222 | # main loop 223 | start_time = time.time() 224 | max_reward_at_epoch = (0, 0, 0) 225 | max_accuracy_at_epoch = (0, 0, 0) 226 | for ITER in xrange(config.d["NUM_ITER"]): 227 | # random.shuffle(train) 228 | loss = 0 229 | for i, qinfo in enumerate(train, 1): 230 | dt.renew_cg() # very important! to renew the cg 231 | 232 | if qinfo.seq_qid[-1] == '0' or args.indep: # first question or treat all questions independenly 233 | init_state = SqaScoreExpressionState(neural_model, qinfo) 234 | else: 235 | init_state = SqaScoreExpressionState(neural_model, qinfo, resinfo=resinfo) 236 | 237 | # loss += sm.beam_train_max_margin(init_state, qinfo.answer_coordinates) 238 | try: 239 | new_loss, end_state_list = sm.beam_train_max_margin_with_answer_guidence(init_state, 240 | qinfo.answer_coordinates) 241 | loss += new_loss 242 | 243 | if random.uniform(0, 1) < 0.5: 244 | # print("use gold!") 245 | # use the gold answers 246 | resinfo = util.ResultInfo(qinfo.seq_qid, qinfo.question, qinfo.ques_word_sequence, 247 | qinfo.answer_coordinates, qinfo.answer_column_idx) 248 | else: 249 | # print("use predict!") 250 | 251 | if len(end_state_list) == 0: 252 | # empty answers, set pred_column_idx to be 0, as it does not matter... 
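                        # (Note: `pred` is not set on this branch; it still holds the previous
                        # example's predictions, and on the very first example the resulting
                        # NameError is absorbed by the surrounding try/except.)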
253 | resinfo = util.ResultInfo(qinfo.seq_qid, qinfo.question, qinfo.ques_word_sequence, pred, 0) 254 | else: 255 | # use the predicted answers & the most common column index 256 | 257 | pred = end_state_list[0].recent_pred 258 | # use the predicted answers & the most common column index 259 | if pred == []: 260 | resinfo = util.ResultInfo(qinfo.seq_qid, qinfo.question, qinfo.ques_word_sequence, pred, 0) 261 | else: 262 | pred_columns = [coord[1] for coord in pred] 263 | pred_column_idx = Counter(pred_columns).most_common(1)[0][0] 264 | resinfo = util.ResultInfo(qinfo.seq_qid, qinfo.question, qinfo.ques_word_sequence, 265 | pred, pred_column_idx) 266 | except Exception as e: 267 | print(str(e)) 268 | print("Exception in running!") 269 | 270 | # print ("debug: resinfo:", resinfo) 271 | 272 | if i % 100 == 0: 273 | print(i, "/", len(train)) 274 | # print ("debug: dynet.parameter(sm.neural_model.SelColW).value():", dt.parameter(sm.neural_model.SelColW).value()) 275 | # print ("debug: loss:", loss) 276 | 277 | neural_model.learner.update_epoch(1.0) 278 | print("In epoch ", ITER, " avg loss (or negative reward) is ", loss / len(train)) 279 | reported_reward, reported_accuracy = test_model(test, sm, neural_model, fLog, indep=args.indep) 280 | print("In epoch ", ITER, " test reward is %f, test accuracy is %f" % (reported_reward, reported_accuracy)) 281 | 282 | if (reported_reward > max_reward_at_epoch[0]): 283 | max_reward_at_epoch = (reported_reward, reported_accuracy, ITER) 284 | 285 | if (reported_accuracy > max_accuracy_at_epoch[0]): 286 | max_accuracy_at_epoch = (reported_accuracy, reported_reward, ITER) 287 | 288 | now_time = time.time() 289 | print("Time taken in this epoch", now_time - start_time) 290 | start_time = now_time 291 | print("Best Reward: %f (Accuracy: %f) at epoch %d" % max_reward_at_epoch) 292 | print("Best Accuracy: %f (Reward: %f) at epoch %d" % max_accuracy_at_epoch) 293 | 294 | if args.model: 295 | neural_model.save("%s-%d" % (args.model, ITER)) 296 | ''' 297 | if fLog: # Test the saved model 298 | new_model = SqaModel.load(args.model) 299 | new_sm = BeamSearchInferencer(new_model,config.d["beam_size"]) 300 | reported_reward,reported_accuracy = test_model(test, new_sm, new_model, fLog) 301 | ''' 302 | print() 303 | sys.stdout.flush() 304 | 305 | if args.res: test_model(test, sm, model, fnRes=args.res, blWTQ=args.wtq, indep=args.indep) 306 | 307 | if fLog: fLog.close() 308 | 309 | 310 | def evalOracleActions(neural_model, sm, train): 311 | train_reward = 0.0 312 | count_perfect_reward = 0.0 313 | 314 | first_train_reward = 0.0 315 | first_count_perfect_reward = 0.0 316 | num_first = 0.0 317 | 318 | rest_train_reward = 0.0 319 | rest_count_perfect_reward = 0.0 320 | num_rest = 0.0 321 | 322 | for i, qinfo in enumerate(train, 1): 323 | dt.renew_cg() # very important! 
to renew the cg 324 | 325 | if qinfo.seq_qid[-1] == '0': # first question 326 | init_state = SqaScoreExpressionState(neural_model, qinfo) 327 | else: 328 | init_state = SqaScoreExpressionState(neural_model, qinfo, resinfo=resinfo) 329 | 330 | gold_ans = qinfo.answer_coordinates 331 | 332 | top_states = sm.beam_find_actions_with_answer_guidence(init_state, gold_ans) 333 | best_reward_state = top_states[0] 334 | best_reward_state_reward = best_reward_state.reward(gold_ans) 335 | 336 | # output detailed predictions 337 | sys.stdout.write("(%s) %s\n" % (qinfo.seq_qid, qinfo.question)) 338 | sys.stdout.write("%s/%s\n" % (config.d["dirTable"], qinfo.table_file)) 339 | if config.d["verbose-dump"]: 340 | for i, s in enumerate(top_states, 1): 341 | sys.stdout.write("Parse %d: estimate reward: [%f] model score: [%f] %s\n" % ( 342 | i, s.score, s.path_score_expression.value(), s)) 343 | else: 344 | sys.stdout.write("Parse: %s\n" % best_reward_state) 345 | 346 | sys.stdout.write("Answer: %s\n" % ", ".join(["(%d,%d)" % coord for coord in qinfo.answer_coordinates])) 347 | sys.stdout.write("\n") 348 | 349 | train_reward += best_reward_state_reward 350 | if (best_reward_state_reward == 1): 351 | count_perfect_reward += 1 352 | 353 | if qinfo.seq_qid[-1] == '0': # first question 354 | num_first += 1 355 | first_train_reward += best_reward_state_reward 356 | if (best_reward_state_reward == 1): 357 | first_count_perfect_reward += 1 358 | else: 359 | num_rest += 1 360 | rest_train_reward += best_reward_state_reward 361 | if (best_reward_state_reward == 1): 362 | rest_count_perfect_reward += 1 363 | 364 | # use the gold answers 365 | resinfo = util.ResultInfo(qinfo.seq_qid, qinfo.question, qinfo.ques_word_sequence, 366 | qinfo.answer_coordinates, qinfo.answer_column_idx) 367 | 368 | if i % 100 == 0: 369 | print(i, "/", len(train)) 370 | # print ("debug: dynet.parameter(sm.neural_model.SelColW).value():", dt.parameter(sm.neural_model.SelColW).value()) 371 | # print ("debug: loss:", loss) 372 | 373 | print("With beamsize ", config.d['beam_size']) 374 | print("Oracle training reward is ", (train_reward / len(train))) 375 | print("percentage of getting perfect reward is ", (count_perfect_reward / len(train))) 376 | 377 | print("# first question ", num_first) 378 | print("Oracle training reward is ", (first_train_reward / num_first)) 379 | print("percentage of getting perfect reward is ", (first_count_perfect_reward / num_first)) 380 | 381 | print("# non-first question ", num_rest) 382 | print("Oracle training reward is ", (rest_train_reward / num_rest)) 383 | print("percentage of getting perfect reward is ", (rest_count_perfect_reward / num_rest)) 384 | 385 | 386 | def evalModel(fnModel, data_folder, fnData, fnLog='', fnRes='', blWTQ=False, indep=False): 387 | data = util.get_labeled_questions(fnData, data_folder, skipEmptyAns=not blWTQ) 388 | model = SqaModel.load(fnModel) 389 | sm = BeamSearchInferencer(model, config.d["beam_size"]) 390 | 391 | fLog = None 392 | if fnLog: fLog = open(fnLog, 'w') 393 | test_model(data, sm, model, fLog=fLog, fnRes=fnRes, blWTQ=blWTQ, indep=indep) 394 | if fnLog: fLog.close() 395 | 396 | 397 | if __name__ == '__main__': 398 | # cProfile.run('main()') 399 | main() 400 | -------------------------------------------------------------------------------- /sqamodel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # effort of writing python 2/3 compatiable code 4 | from __future__ import print_function 5 | from 
__future__ import division 6 | from __future__ import unicode_literals 7 | from future.utils import iteritems 8 | from operator import itemgetter, attrgetter, methodcaller 9 | 10 | import sys, time, argparse, csv 11 | import cProfile 12 | 13 | if sys.version < '3': 14 | from codecs import getwriter 15 | stderr = getwriter('utf-8')(sys.stderr) 16 | stdout = getwriter('utf-8')(sys.stdout) 17 | else: 18 | stderr = sys.stderr 19 | 20 | import dynet as dt 21 | from collections import Counter 22 | import random 23 | import util 24 | import config 25 | import cPickle 26 | import copy 27 | 28 | from action import * 29 | from statesearch import * 30 | import nnmodule as nnmod 31 | 32 | ######## START OF THE CODE ######## 33 | 34 | class SqaModel(): 35 | 36 | WORD_EMBEDDING_DIM = config.d["WORD_EMBEDDING_DIM"] 37 | LSTM_HIDDEN_DIM = config.d["LSTM_HIDDEN_DIM"] 38 | 39 | def __init__(self, init_learning_rate, vw, reload_embeddings = True): 40 | self.model = dt.Model() 41 | 42 | self.vw = vw 43 | 44 | UNK = self.vw.w2i["_UNK_"] 45 | n_words = vw.size() 46 | print("init vw =", self.vw.size(), "words") 47 | self.learning_rate = init_learning_rate 48 | #self.learner = dt.SimpleSGDTrainer(self.model, e0=init_learning_rate) 49 | self.learner = dt.SimpleSGDTrainer(self.model) 50 | self.E = self.model.add_lookup_parameters((n_words, SqaModel.WORD_EMBEDDING_DIM)) 51 | # similarity(v,o): v^T o 52 | self.SelHW = self.model.add_parameters((4 * SqaModel.WORD_EMBEDDING_DIM)) 53 | self.SelIntraFW = self.model.add_parameters((SqaModel.WORD_EMBEDDING_DIM / 2, SqaModel.WORD_EMBEDDING_DIM)) 54 | self.SelIntraHW = self.model.add_parameters((SqaModel.WORD_EMBEDDING_DIM, SqaModel.WORD_EMBEDDING_DIM * 2)) 55 | self.SelIntraBias = self.model.add_parameters((config.d["DIST_BIAS_DIM"])) 56 | self.ColTypeN = self.model.add_parameters((1)) 57 | self.ColTypeW = self.model.add_parameters((1)) 58 | self.NulW = self.model.add_parameters((SqaModel.WORD_EMBEDDING_DIM)) 59 | 60 | ''' new ways to add module ''' 61 | self.SelColFF = nnmod.FeedForwardModel(self.model, 4) 62 | self.WhereColFF = nnmod.FeedForwardModel(self.model, 5) 63 | self.QCMatch = nnmod.QuestionColumnMatchModel(self.model, SqaModel.WORD_EMBEDDING_DIM) 64 | self.NegFF = nnmod.FeedForwardModel(self.model, 2) 65 | self.FpWhereColFF = nnmod.FeedForwardModel(self.model, 9) 66 | 67 | 68 | # LSTM question representation 69 | self.builders = [ 70 | dt.LSTMBuilder(1, SqaModel.WORD_EMBEDDING_DIM, SqaModel.LSTM_HIDDEN_DIM, self.model), 71 | dt.LSTMBuilder(1, SqaModel.WORD_EMBEDDING_DIM, SqaModel.LSTM_HIDDEN_DIM, self.model) 72 | ] 73 | self.pH = self.model.add_parameters((SqaModel.WORD_EMBEDDING_DIM, SqaModel.LSTM_HIDDEN_DIM*2)) 74 | 75 | # LSTM question representation 76 | self.prev_builders = [ 77 | dt.LSTMBuilder(1, SqaModel.WORD_EMBEDDING_DIM, SqaModel.LSTM_HIDDEN_DIM, self.model), 78 | dt.LSTMBuilder(1, SqaModel.WORD_EMBEDDING_DIM, SqaModel.LSTM_HIDDEN_DIM, self.model) 79 | ] 80 | self.prev_pH = self.model.add_parameters((SqaModel.WORD_EMBEDDING_DIM, SqaModel.LSTM_HIDDEN_DIM*2)) 81 | self.SelColFpWhereW = self.model.add_parameters((4)) 82 | self.SameAsPreviousW = self.model.add_parameters((2)) 83 | 84 | if config.d["USE_PRETRAIN_WORD_EMBEDDING"] and reload_embeddings: 85 | n_hit_pretrain = 0.0 86 | trie = config.d["embeddingtrie"] 87 | print ("beginning to load embeddings....") 88 | for i in range(n_words): 89 | word = self.vw.i2w[i].lower() 90 | results = trie.items(word+ config.d["recordtriesep"]) 91 | if len(results) == 1: 92 | pretrain_v = np.array(list(results[0][1])) 
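                    # hit in the pretrained embedding trie: the vector is L2-normalized below and
                    # written into row i of the lookup table E; missed words keep their randomly
                    # initialized rows, normalized the same way in the else-branch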
93 | pretrain_v = pretrain_v/np.linalg.norm(pretrain_v) 94 | self.E.init_row(i,pretrain_v) 95 | n_hit_pretrain += 1 96 | else: 97 | pretrain_v = self.E[i].npvalue() 98 | pretrain_v = pretrain_v/np.linalg.norm(pretrain_v) 99 | self.E.init_row(i,pretrain_v) 100 | 101 | 102 | print ("the number of words that are in pretrain", n_hit_pretrain, n_words, n_hit_pretrain/n_words) 103 | print ("loading complete!") 104 | 105 | if config.d["USE_PRETRAIN_WORD_EMBEDDING"]: 106 | self.Negate = nnmod.NegationModel(self.model, SqaModel.WORD_EMBEDDING_DIM, UNK, "not", self.vw, self.E) 107 | self.CondGT = nnmod.CompareModel(self.model, SqaModel.WORD_EMBEDDING_DIM, UNK, "more greater larger than", self.vw, self.E) 108 | self.CondGE = nnmod.CompareModel(self.model, SqaModel.WORD_EMBEDDING_DIM, UNK, "more greater larger than or equal to at least", self.vw, self.E) 109 | self.CondLT = nnmod.CompareModel(self.model, SqaModel.WORD_EMBEDDING_DIM, UNK, "less fewer smaller than", self.vw, self.E) 110 | self.CondLE = nnmod.CompareModel(self.model, SqaModel.WORD_EMBEDDING_DIM, UNK, "less fewer smaller than or equal to at most", self.vw, self.E) 111 | self.ArgMin = nnmod.ArgModel(self.model, SqaModel.WORD_EMBEDDING_DIM, UNK, "least fewest smallest lowest shortest oldest", self.vw, self.E) 112 | self.ArgMax = nnmod.ArgModel(self.model, SqaModel.WORD_EMBEDDING_DIM, UNK, "most greatest biggest largest highest longest latest tallest", self.vw, self.E) 113 | 114 | else: 115 | self.Negate = nnmod.NegationModel(self.model, SqaModel.WORD_EMBEDDING_DIM, UNK) 116 | self.CondGT = nnmod.CompareModel(self.model, SqaModel.WORD_EMBEDDING_DIM, UNK) 117 | self.CondGE = nnmod.CompareModel(self.model, SqaModel.WORD_EMBEDDING_DIM, UNK) 118 | self.CondLT = nnmod.CompareModel(self.model, SqaModel.WORD_EMBEDDING_DIM, UNK) 119 | self.CondLE = nnmod.CompareModel(self.model, SqaModel.WORD_EMBEDDING_DIM, UNK) 120 | self.ArgMin = nnmod.ArgModel(self.model, SqaModel.WORD_EMBEDDING_DIM, UNK) 121 | self.ArgMax = nnmod.ArgModel(self.model, SqaModel.WORD_EMBEDDING_DIM, UNK) 122 | 123 | def save(self,header): 124 | print("Saving model with header = ", header) 125 | f = open(header + "-extra.bin",'wb') 126 | cPickle.dump(self.vw,f) 127 | cPickle.dump(self.learning_rate,f) 128 | f.close() 129 | self.model.save(header + "-dynetmodel.bin") 130 | #print("Done!") 131 | 132 | @staticmethod 133 | def load(header): 134 | print("Loading model with header = ", header) 135 | f = open(header + "-extra.bin",'rb') 136 | vw = cPickle.load(f) 137 | lr = cPickle.load(f) 138 | f.close() 139 | res = SqaModel(lr,vw,False) # do not waste time reload embeddings 140 | #res.model.load(header + "-dynetmodel.bin") 141 | res.model.populate(header + "-dynetmodel.bin") 142 | #print("Done!") 143 | 144 | return res 145 | -------------------------------------------------------------------------------- /sqastate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # effort of writing python 2/3 compatiable code 4 | from __future__ import print_function 5 | from __future__ import division 6 | from __future__ import unicode_literals 7 | from future.utils import iteritems 8 | from operator import itemgetter, attrgetter, methodcaller 9 | 10 | import sys, time, argparse, csv 11 | import cProfile 12 | 13 | if sys.version < '3': 14 | from codecs import getwriter 15 | stderr = getwriter('utf-8')(sys.stderr) 16 | stdout = getwriter('utf-8')(sys.stdout) 17 | else: 18 | stderr = sys.stderr 19 | 20 | import dynet as dt 21 | from 
collections import Counter 22 | import random 23 | import util 24 | import config 25 | import cPickle 26 | import copy 27 | 28 | from action import * 29 | from statesearch import * 30 | import nnmodule as nnmod 31 | 32 | ######## START OF THE CODE ######## 33 | 34 | 35 | class SqaState: 36 | action_factory_cache = {} # qinfo.qid -> action_factory 37 | 38 | def __init__(self, qinfo, resinfo=None): 39 | if resinfo == None: 40 | if qinfo.seq_qid in SqaState.action_factory_cache: 41 | self.af = SqaState.action_factory_cache[qinfo.seq_qid] 42 | else: 43 | self.af = ActionFactory(qinfo) # define the actions 44 | SqaState.action_factory_cache[qinfo.seq_qid] = self.af 45 | else: # has previous question, currently no cache 46 | self.af = ActionFactory(qinfo, resinfo) # define the actions 47 | self.action_history = [self.af.start_action_idx] # fill in the "start" null action 48 | self.meta_history = [dt.inputVector(np.zeros(len(qinfo.ques_word_sequence)))] 49 | self.qinfo = qinfo 50 | self.resinfo = resinfo 51 | self.numCol = self.qinfo.num_columns 52 | self.numRow = self.qinfo.num_rows 53 | 54 | # Given the current state (self), return a list of legitimate actions 55 | # Currently, it follows the STAGG's fashion and requires it does SELECT first. We can later relax it. 56 | def get_action_set(self, action_history = None): 57 | if action_history == None: 58 | action_history = self.action_history 59 | last_act_idx = action_history[-1] 60 | return self.af.legit_next_action_idxs(last_act_idx, action_history) 61 | 62 | def is_end(self, action_history = None): 63 | return not self.get_action_set(action_history) # empty action_set 64 | 65 | # return a set of actions that can lead to the gold state from the current state 66 | def get_action_set_withans(self, gold_ans): 67 | ret_action = [] 68 | ret_estimated_reward = [] 69 | ''' 70 | print("debug: current action_history", self.action_history) 71 | print("debug: self.get_action_set()", self.get_action_set()) 72 | print("debug: self.qinfo.question", self.qinfo.question) 73 | for actidx in self.get_action_set(): 74 | print("debug: action %d: %s" % (actidx, self.af.actidx2str(actidx))) 75 | print("") 76 | ''' 77 | for act in self.get_action_set(): 78 | e_reward = self.estimated_reward(gold_ans, act) 79 | #print("debug: act = %s, e_reward = %f" % (self.af.actidx2str(actidx), e_reward)) 80 | if e_reward > 0: 81 | ret_action.append(act) 82 | ret_estimated_reward.append(e_reward) 83 | return ret_action, ret_estimated_reward 84 | 85 | # the estimated final reward value of a full path, after executing the given action 86 | def estimated_reward(self, gold_ans, act_idx): 87 | # if this action is the final action to goal state (i.e., Stop) 88 | # use the real reward directly 89 | if not self.af.legit_next_action_idxs(act_idx, self.action_history): # empty set 90 | path = self.action_history + [act_idx] 91 | ret = self.reward(gold_ans, path, True) 92 | else: 93 | act_idx_stop = self.af.type2actidxs[ActionType.Stop][0] 94 | action = self.af.actions[act_idx] 95 | #print ("action.type:", action.type) 96 | if action.type == ActionType.Select: 97 | path = self.action_history + [act_idx, act_idx_stop] 98 | ret = self.reward(gold_ans, path, True) 99 | elif action.type == ActionType.WhereCol: # TODO: ignore WhereCol although technically it's not correct 100 | path = self.action_history + [act_idx_stop] 101 | #print("path=", path) 102 | ret = self.reward(gold_ans, path, True) 103 | elif action.type == ActionType.FpWhereCol: # ignore FpWhereCol & treat it as SameAsPrevious 104 | 
act_idx_same = self.af.type2actidxs[ActionType.SameAsPrevious][0] 105 | path = self.action_history + [act_idx_same] 106 | ret = self.reward(gold_ans, path, True) 107 | else: # append action 108 | path = self.action_history + [act_idx] 109 | #print("path2=", path, "path2.action.type=", [self.af.actions[p].type for p in path]) 110 | ret = self.reward(gold_ans, path, True) 111 | 112 | if ret == 1: # check if we like this parse 113 | ret = self.af.action_history_quality(path) 114 | 115 | return ret 116 | 117 | def reward(self, gold, action_history = None, partial = None): 118 | if partial == None: 119 | partial = config.d["partial_reward"] 120 | 121 | if not gold: 122 | gold = self.qinfo.answer_coordinates 123 | 124 | # execute the parse 125 | pred = self.execute_parse(action_history) 126 | 127 | setGold = set(gold) 128 | setPred = set(pred) 129 | 130 | if partial: 131 | # Reward = #(Gold INTERSECT Pred) / #(Gold UNION Pred) 132 | ret = float(len(setGold.intersection(setPred))) / len(setGold.union(setPred)) 133 | else: 134 | # change the reward function to be 0/1 135 | if setGold == setPred: 136 | ret = 1.0 137 | else: 138 | ret = 0.0 139 | 140 | self.recent_pred = pred 141 | return ret 142 | 143 | def execute_parse(self, action_history=None): 144 | if action_history == None: 145 | action_history = self.action_history 146 | 147 | # only execute if the parse is complete 148 | if not self.is_end(action_history): 149 | return [] 150 | 151 | # map the action sequence to a parse 152 | parse = self.af.action_history_to_parse(action_history) 153 | coords = parse.run(self.qinfo, self.resinfo) 154 | 155 | return coords 156 | 157 | class SqaScoreExpressionState(SqaState): 158 | 159 | def __init__(self, nmodel, qinfo, init_example = True, resinfo = None, testmode = False): 160 | SqaState.__init__(self, qinfo, resinfo) 161 | self.path_score_expression = dt.scalarInput(0) 162 | self.score = 0 163 | self.nm = nmodel 164 | self.vw = self.nm.vw 165 | self.H = dt.parameter(self.nm.pH) 166 | self.prev_H = dt.parameter(self.nm.prev_pH) 167 | 168 | if init_example: 169 | UNK = self.vw.w2i["_UNK_"] 170 | 171 | # vectors of question words 172 | if testmode or not config.d["DropOut"]: 173 | self.ques_emb = [self.nm.E[self.vw.w2i.get(w, UNK)] for w in self.qinfo.ques_word_sequence] 174 | else: 175 | self.ques_emb = [dt.dropout(self.nm.E[self.vw.w2i.get(w, UNK)], 0.5) for w in self.qinfo.ques_word_sequence] 176 | self.ques_avg_emb = dt.average(self.ques_emb) 177 | 178 | # column name embeddings 179 | self.colname_embs = [] 180 | # avg. vectors of column names 181 | self.headers_embs = [] 182 | for colname_word_sequence in self.qinfo.headers_word_sequences: 183 | colname_emb = [self.nm.E[self.vw.w2i.get(w, UNK)] for w in colname_word_sequence] 184 | self.colname_embs.append(colname_emb) 185 | self.headers_embs.append(dt.average(colname_emb)) 186 | 187 | # avg. 
vectors of table entries 188 | self.entries_embs = [] 189 | for row_word_sequences in self.qinfo.entries_word_sequences: 190 | row_embs = [] 191 | for cell_word_sequence in row_word_sequences: 192 | row_embs.append(dt.average([self.nm.E[self.vw.w2i.get(w, UNK)] for w in cell_word_sequence])) 193 | self.entries_embs.append(row_embs) 194 | 195 | self.NulW = dt.parameter(self.nm.NulW) 196 | self.SelHW = dt.parameter(self.nm.SelHW) 197 | self.SelIntraFW = dt.parameter(self.nm.SelIntraFW) 198 | self.SelIntraHW = dt.parameter(self.nm.SelIntraHW) 199 | self.SelIntraBias = dt.parameter(self.nm.SelIntraBias) 200 | self.ColTypeN = dt.parameter(self.nm.ColTypeN) 201 | self.ColTypeW = dt.parameter(self.nm.ColTypeW) 202 | 203 | ''' new ways to add module ''' 204 | self.SelColFF = self.nm.SelColFF.spawn_expression() 205 | self.WhereColFF = self.nm.WhereColFF.spawn_expression() 206 | self.QCMatch = self.nm.QCMatch.spawn_expression() 207 | self.Negate = self.nm.Negate.spawn_expression() 208 | self.NegFF = self.nm.NegFF.spawn_expression() 209 | self.FpWhereColFF = self.nm.FpWhereColFF.spawn_expression() 210 | self.CondGT = self.nm.CondGT.spawn_expression() 211 | self.CondGE = self.nm.CondGE.spawn_expression() 212 | self.CondLT = self.nm.CondLT.spawn_expression() 213 | self.CondLE = self.nm.CondLE.spawn_expression() 214 | self.ArgMin = self.nm.ArgMin.spawn_expression() 215 | self.ArgMax = self.nm.ArgMax.spawn_expression() 216 | 217 | # question LSTM 218 | f_init, b_init = [b.initial_state() for b in self.nm.builders] 219 | self.fw = [x.output() for x in f_init.add_inputs(self.ques_emb)] 220 | self.bw = [x.output() for x in b_init.add_inputs(reversed(self.ques_emb))] 221 | self.bw.reverse() 222 | 223 | # from previous question & its answers 224 | if resinfo != None: 225 | # vectors of question words 226 | self.prev_ques_emb = [self.nm.E[self.vw.w2i.get(w, UNK)] for w in self.resinfo.prev_ques_word_sequence] 227 | self.prev_ques_avg_emb = dt.average(self.prev_ques_emb) 228 | 229 | # previous question LSTM 230 | f_init, b_init = [b.initial_state() for b in self.nm.prev_builders] 231 | self.prev_fw = [x.output() for x in f_init.add_inputs(self.prev_ques_emb)] 232 | self.prev_bw = [x.output() for x in b_init.add_inputs(reversed(self.prev_ques_emb))] 233 | self.prev_bw.reverse() 234 | 235 | self.SelColFpWhereW = dt.parameter(self.nm.SelColFpWhereW) 236 | self.SameAsPreviousW = dt.parameter(self.nm.SameAsPreviousW) 237 | 238 | def clone(self): 239 | res = SqaScoreExpressionState(self.nm, self.qinfo, False, self.resinfo) 240 | res.action_history = self.action_history[:] 241 | res.meta_history = self.meta_history[:] 242 | 243 | # vectors of question words 244 | res.ques_emb = self.ques_emb 245 | res.ques_avg_emb = self.ques_avg_emb 246 | 247 | # vectors of column names 248 | res.colname_embs = self.colname_embs 249 | res.headers_embs = self.headers_embs 250 | 251 | # avg. 
vectors of table entries 252 | res.entries_embs = self.entries_embs 253 | res.NulW = self.NulW 254 | 255 | res.SelHW = self.SelHW 256 | res.SelIntraFW = self.SelIntraFW 257 | res.SelIntraHW = self.SelIntraHW 258 | res.SelIntraBias = self.SelIntraBias 259 | res.ColTypeN = self.ColTypeN 260 | res.ColTypeW = self.ColTypeW 261 | res.fw = self.fw 262 | res.bw = self.bw 263 | 264 | ''' clone ''' 265 | ''' 266 | res.SelColFF = copy.deepcopy(self.SelColFF) 267 | res.WhereColFF = copy.deepcopy(self.WhereColFF) 268 | res.QCMatch = copy.deepcopy(self.QCMatch) 269 | res.Negate = copy.deepcopy(self.Negate) 270 | res.NegFF = copy.deepcopy(self.NegFF) 271 | res.FpWhereColFF = copy.deepcopy(self.FpWhereColFF) 272 | ''' 273 | res.SelColFF = self.nm.SelColFF.spawn_expression() 274 | res.WhereColFF = self.nm.WhereColFF.spawn_expression() 275 | res.QCMatch = self.nm.QCMatch.spawn_expression() 276 | res.Negate = self.nm.Negate.spawn_expression() 277 | res.NegFF = self.nm.NegFF.spawn_expression() 278 | res.FpWhereColFF = self.nm.FpWhereColFF.spawn_expression() 279 | res.CondGT = self.nm.CondGT.spawn_expression() 280 | res.CondGE = self.nm.CondGE.spawn_expression() 281 | res.CondLT = self.nm.CondLT.spawn_expression() 282 | res.CondLE = self.nm.CondLE.spawn_expression() 283 | res.ArgMin = self.nm.ArgMin.spawn_expression() 284 | res.ArgMax = self.nm.ArgMax.spawn_expression() 285 | 286 | if self.resinfo != None: 287 | # vectors of previous question words 288 | res.prev_ques_emb = self.prev_ques_emb 289 | res.prev_ques_avg_emb = self.prev_ques_avg_emb 290 | 291 | # previous question LSTM 292 | res.prev_fw = self.prev_fw 293 | res.prev_bw = self.prev_bw 294 | 295 | res.SelColFpWhereW = self.SelColFpWhereW 296 | 297 | return res 298 | 299 | 300 | # Decomposable attention between question and column name 301 | # Overall, it needs more efficient implementation... 
:( 302 | def decomp_attend(self, vecsA, vecsB): 303 | # Fq^T Fc -> need to expedite using native matrix/tensor multiplication 304 | Fq = vecsA # the original word vector, not yet passing a NN as in Eq.1, # need a function F 305 | Fc = vecsB # need a function F 306 | 307 | expE = [] 308 | for fq in Fq: 309 | row = [] 310 | for fc in Fc: 311 | row.append(dt.exp(dt.dot_product(fq,fc))) 312 | expE.append(row) 313 | #print ("debug: expE", expE[0][0].value()) 314 | 315 | invSumExpEi = [] 316 | for i in xrange(len(Fq)): 317 | invSumExpEi.append(dt.pow(dt.esum(expE[i]), dt.scalarInput(-1))) 318 | 319 | invSumExpEj = [] 320 | for j in xrange(len(Fc)): 321 | invSumExpEj.append(dt.pow(dt.esum([expE[i][j] for i in xrange(len(Fq))]), dt.scalarInput(-1))) 322 | 323 | beta = [] 324 | for i in xrange(len(Fq)): 325 | s = dt.esum([Fc[j] * expE[i][j] for j in xrange(len(Fc))]) 326 | beta.append(s * invSumExpEi[i]) 327 | #print("debug: beta", beta[0].value()) 328 | 329 | alpha = [] 330 | for j in xrange(len(Fc)): 331 | s = dt.esum([Fc[j] * expE[i][j] for i in xrange(len(Fq))]) 332 | alpha.append(s * invSumExpEj[j]) 333 | #print("debug: alpha", alpha[0].value()) 334 | 335 | # Compare 336 | v1i = [dt.logistic(dt.concatenate([Fq[i],beta[i]])) for i in xrange(len(Fq))] # need a function G 337 | v2j = [dt.logistic(dt.concatenate([Fc[j],alpha[j]])) for j in xrange(len(Fc))] # need a function G 338 | 339 | #print ("debug: v1i", v1i[0].value()) 340 | #print ("debug: v2j", v2j[0].value()) 341 | 342 | # Aggregate 343 | 344 | v1 = dt.esum(v1i) 345 | v2 = dt.esum(v2j) 346 | 347 | #print ("debug: v1.value()", v1.value()) 348 | #print ("debug: v2.value()", v2.value()) 349 | 350 | #colScore = dt.logistic(dt.dot_product(self.SelHW, dt.concatenate([v1,v2]))) 351 | return dt.dot_product(v1,v2) 352 | 353 | def intra_sent_attend(self, vecs): 354 | numVecs = len(vecs) 355 | fVecs = [dt.tanh(self.SelIntraFW * v) for v in vecs] 356 | expE = [] 357 | for i,fq in enumerate(fVecs): 358 | row = [] 359 | for j,fc in enumerate(fVecs): 360 | row.append(dt.exp(dt.dot_product(fq,fc) + self.SelIntraBias[i-j+int(config.d["DIST_BIAS_DIM"]/2)])) 361 | expE.append(row) 362 | 363 | invSumExpE = [] 364 | for i in xrange(numVecs): 365 | invSumExpE.append(dt.pow(dt.esum(expE[i]), dt.scalarInput(-1))) 366 | 367 | alpha = [] 368 | for i in xrange(numVecs): 369 | s = dt.esum([vecs[j] * expE[i][j] for j in xrange(numVecs)]) 370 | alpha.append(s * invSumExpE[i]) 371 | 372 | return [dt.tanh(self.SelIntraHW * dt.concatenate([v,a])) for v,a in zip(vecs, alpha)] 373 | 374 | def positional_reweight(self, vecs): 375 | return [v * dt.logistic(self.SelIntraBias[i]) for i,v in enumerate(vecs)] 376 | 377 | # input: question word vectors, averaged matched word vectors (e.g., column name or table entry) 378 | # output: a vector of length of question; each element represents how much the word is covered 379 | def determine_coverage_by_name(self, qwVecs, avgVec): 380 | return None 381 | # Compute question coverage -- hard/rough implementation to test idea first 382 | qWdMatchScore = [dt.dot_product(qwVec, avgVec).value() for qwVec in qwVecs] 383 | ret = dt.softmax(dt.inputVector(np.array(qWdMatchScore))) 384 | return ret 385 | 386 | def attend_question_coverage(self): 387 | return self.ques_emb 388 | #print("Question Info: seq_qid", self.qinfo.seq_qid, "question", self.qinfo.question) 389 | maskWdIndices = set() 390 | for coverageMap in self.meta_history: 391 | mask = coverageMap.value() 392 | if type(mask) != list: 393 | mask = [mask] 394 | max_value = max(mask) 395 
| if max_value == 0: 396 | continue 397 | max_index = mask.index(max_value) 398 | maskWdIndices.add(max_index) 399 | qwVecs = [] 400 | for i,vec in enumerate(self.ques_emb): 401 | #print (i, vec) 402 | if mask[i] == 1: 403 | qwVecs.append(dt.inputVector(np.zeros(SqaModel.WORD_EMBEDDING_DIM))) 404 | else: 405 | qwVecs.append(vec) 406 | return qwVecs 407 | 408 | 409 | def get_next_score_expressions(self, legit_act_idxs): 410 | 411 | res_list = [] 412 | meta_list = [] 413 | ''' 414 | print ("debug: self.action_history", self.action_history) 415 | print ("debug: self.is_end()", self.is_end()) 416 | print ("debug: self.qinfo.seq_qid", self.qinfo.seq_qid) 417 | print ("debug: legit_act_idxs", legit_act_idxs) 418 | ''' 419 | 420 | qwVecs = self.attend_question_coverage() 421 | qwAvgVec = self.ques_avg_emb 422 | qLSTMVec = dt.tanh(self.H * dt.concatenate([self.fw[-1],self.bw[0]])) # question words LSTM embedding 423 | 424 | if self.resinfo != None: 425 | prev_qwVecs = self.prev_ques_emb 426 | prev_qwAvgVec = self.prev_ques_avg_emb 427 | prev_qLSTMVec = dt.tanh(self.prev_H * dt.concatenate([self.prev_fw[-1], self.prev_bw[0]])) 428 | 429 | for act_idx in legit_act_idxs: 430 | action = self.af.actions[act_idx] 431 | act_type = action.type 432 | #print("act_type", act_type) 433 | 434 | col = action.col 435 | colnameVec = self.headers_embs[col] 436 | colWdVecs = self.colname_embs[col] 437 | r = action.row 438 | if self.action_history != []: 439 | # for condition check, assuming the last action of the current state is ActWhereCol 440 | c = self.af.actions[self.action_history[-1]].col 441 | condCellVec = self.entries_embs[r][c] 442 | 443 | if act_type == ActionType.Stop: 444 | # use the average after mask 445 | actScore = dt.dot_product(dt.average(qwVecs), self.NulW) 446 | coverageMap = dt.inputVector(np.zeros(len(qwVecs))) 447 | 448 | elif act_type == ActionType.Select: 449 | lstScores = self.QCMatch.score_expression(qwVecs, qwAvgVec, qLSTMVec, colnameVec, colWdVecs) 450 | scoreVec = dt.concatenate(lstScores) 451 | actScore = self.SelColFF.score_expression(scoreVec) 452 | coverageMap = self.determine_coverage_by_name(qwVecs, colnameVec) 453 | 454 | elif act_type == ActionType.WhereCol: # same as ActionType.ActSelect, but with different coefficients in weighted sum 455 | # column type embedding # TODO: MAY BE WRONG IMPLEMENTATION HERE 456 | if self.qinfo.values_in_ques: 457 | colTypeScore = self.ColTypeN 458 | else: 459 | colTypeScore = self.ColTypeW 460 | lstScores = self.QCMatch.score_expression(qwVecs, qwAvgVec, qLSTMVec, colnameVec, colWdVecs) 461 | scoreVec = dt.concatenate(lstScores + [colTypeScore]) 462 | actScore = self.WhereColFF.score_expression(scoreVec) 463 | coverageMap = self.determine_coverage_by_name(qwVecs, colnameVec) 464 | 465 | elif act_type == ActionType.CondEqRow: 466 | actScore = nnmod.MaxScore(qwVecs, condCellVec) 467 | coverageMap = self.determine_coverage_by_name(qwVecs, condCellVec) 468 | 469 | elif act_type == ActionType.CondNeRow: 470 | entScore = nnmod.MaxScore(qwVecs, condCellVec) 471 | negScore = self.Negate.score_expression(qwAvgVec) 472 | scoreVec = dt.concatenate([entScore, negScore]) 473 | actScore = self.NegFF.score_expression(scoreVec) 474 | coverageMap = self.determine_coverage_by_name(qwVecs, condCellVec) 475 | 476 | elif act_type == ActionType.CondGT or act_type == ActionType.FpCondGT: 477 | actScore = self.CondGT.score_expression(qwVecs, action.val[0]) 478 | coverageMap = self.determine_coverage_by_name(qwVecs, self.CondGT.OpW) 479 | 480 | elif act_type == 
ActionType.CondGE or act_type == ActionType.FpCondGE: 481 | actScore = self.CondGE.score_expression(qwVecs, action.val[0]) 482 | coverageMap = self.determine_coverage_by_name(qwVecs, self.CondGE.OpW) 483 | 484 | elif act_type == ActionType.CondLT or act_type == ActionType.FpCondLT: 485 | actScore = self.CondLT.score_expression(qwVecs, action.val[0]) 486 | coverageMap = self.determine_coverage_by_name(qwVecs, self.CondLT.OpW) 487 | 488 | elif act_type == ActionType.CondLE or act_type == ActionType.FpCondLE: 489 | actScore = self.CondLE.score_expression(qwVecs, action.val[0]) 490 | coverageMap = self.determine_coverage_by_name(qwVecs, self.CondLE.OpW) 491 | 492 | elif act_type == ActionType.ArgMin or act_type == ActionType.FpArgMin: 493 | actScore = self.ArgMin.score_expression(qwVecs) 494 | coverageMap = self.determine_coverage_by_name(qwVecs, self.ArgMin.OpW) 495 | 496 | elif act_type == ActionType.ArgMax or act_type == ActionType.FpArgMax: 497 | actScore = self.ArgMax.score_expression(qwVecs) 498 | coverageMap = self.determine_coverage_by_name(qwVecs, self.ArgMax.OpW) 499 | 500 | elif act_type == ActionType.FpWhereCol: # similar to ActionType.WhereCol 501 | # column type embedding 502 | if self.qinfo.values_in_ques: 503 | colTypeScore = self.ColTypeN 504 | else: 505 | colTypeScore = self.ColTypeW 506 | lstScores = self.QCMatch.score_expression(qwVecs, qwAvgVec, qLSTMVec, colnameVec, colWdVecs) 507 | lstPrevScores = self.QCMatch.score_expression(prev_qwVecs, prev_qwAvgVec, prev_qLSTMVec, colnameVec, colWdVecs) 508 | scoreVec = dt.concatenate(lstScores + [colTypeScore] + lstPrevScores) 509 | 510 | actScore = self.FpWhereColFF.score_expression(scoreVec) 511 | coverageMap = self.determine_coverage_by_name(qwVecs, colnameVec) 512 | 513 | elif act_type == ActionType.FpCondEqRow: 514 | entScore = nnmod.MaxScore(qwVecs, condCellVec) 515 | prev_entScore = nnmod.MaxScore(prev_qwVecs, condCellVec) 516 | 517 | actScore = dt.bmax(entScore, prev_entScore) 518 | coverageMap = self.determine_coverage_by_name(qwVecs, condCellVec) 519 | 520 | elif act_type == ActionType.FpCondNeRow: 521 | entScore = nnmod.MaxScore(qwVecs, condCellVec) 522 | prev_entScore = nnmod.MaxScore(prev_qwVecs, condCellVec) 523 | negScore = self.Negate.score_expression(qwAvgVec) 524 | scoreVec = dt.concatenate([dt.bmax(entScore, prev_entScore), negScore]) 525 | actScore = self.NegFF.score_expression(scoreVec) 526 | coverageMap = self.determine_coverage_by_name(qwVecs, condCellVec) 527 | 528 | 529 | elif act_type == ActionType.SameAsPrevious: 530 | quesLSTMScore = dt.dot_product(prev_qLSTMVec, qLSTMVec) 531 | quesAvgScore = dt.dot_product(prev_qwAvgVec, qwAvgVec) 532 | actScore = dt.dot_product(self.SameAsPreviousW, 533 | dt.concatenate([quesLSTMScore, quesAvgScore])) 534 | coverageMap = dt.inputVector(np.zeros(len(qwVecs))) 535 | 536 | else: 537 | assert False, "Error! 
Unknown act_type: %d" % act_type
538 | 
539 |             res_list.append(actScore)
540 |             meta_list.append(coverageMap)
541 | 
542 |         return dt.concatenate(res_list), meta_list
543 | 
544 |     def get_new_state_after_action(self, action, meta):
545 |         assert action in self.get_action_set()
546 |         new_state = self.clone()
547 |         new_state.action_history.append(action)
548 |         new_state.meta_history.append(meta)
549 |         return new_state
550 | 
551 |     def __str__(self):
552 |         return "\t".join([self.af.actidx2str(act) for act in self.action_history])
--------------------------------------------------------------------------------
/statesearch.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | # effort of writing python 2/3 compatible code
4 | from __future__ import print_function
5 | from __future__ import division
6 | from __future__ import unicode_literals
7 | from future.utils import iteritems
8 | from operator import itemgetter, attrgetter, methodcaller
9 | 
10 | # from sys import stdin
11 | # reload(sys)
12 | # sys.setdefaultencoding('utf8')
13 | 
14 | import sys
15 | 
16 | if sys.version < '3':
17 |     from codecs import getwriter
18 |     stderr = getwriter('utf-8')(sys.stderr)
19 |     stdout = getwriter('utf-8')(sys.stdout)
20 | else:
21 |     stderr = sys.stderr
22 | 
23 | import dynet as dt
24 | import numpy as np
25 | import random
26 | 
27 | class BeamSearchInferencer:
28 |     def __init__(self,nmodel,beam_size = 5):
29 |         self.beam_size = beam_size
30 |         self.neural_model = nmodel
31 |         self.only_one_best = True
32 | 
33 |     def beam_predict(self,init_state):
34 |         """ perform beam search to find the state with the best path score: maximize \sum f(S,a)
35 | 
36 |         :param init_state: the initial state
37 |         :returns: a list of end states (at most beam_size of them)
38 |         :rtype: list(State)
39 | 
40 |         """
41 |         init_state.path_score_expression = dt.scalarInput(0)
42 |         init_state.score = 0
43 |         self.cur_save_states = [init_state]
44 | 
45 |         while True:
46 |             self.next_states = []
47 |             contain_nonend_state = False
48 | 
49 |             for state in self.cur_save_states:
50 |                 cur_path_expression = state.path_score_expression
51 |                 cur_path_score = cur_path_expression.scalar_value()
52 | 
53 |                 if state.is_end():
54 |                     self.next_states.append(state) # keep end states so that they can compete with longer sequences
55 |                     continue
56 | 
57 |                 contain_nonend_state = True
58 | 
59 |                 action_list = state.get_action_set()
60 |                 new_expression_list, meta_info_list = state.get_next_score_expressions(action_list)
61 | 
62 |                 for i in range(len(action_list)):
63 |                     action = action_list[i]
64 |                     new_expression = new_expression_list[i] #dynet call for getting element
65 |                     meta_info = meta_info_list[i]
66 |                     new_state = state.get_new_state_after_action(action,meta_info)
67 |                     new_state.path_score_expression = cur_path_expression + new_expression
68 |                     new_state.score = state.score + new_expression.scalar_value()
69 |                     #print("comparison",new_state.score,new_state.path_score_expression.value())
70 |                     #assert abs(new_state.score - new_state.path_score_expression.value()) < 1e-5
71 | 
72 |                     self.next_states.append(new_state)
73 | 
74 |             self.next_states.sort(key=lambda x: -1 * x.score)
75 |             # for x in self.next_states:
76 |             # print("===>",x,x.score)
77 | 
78 |             next_size = min([len(self.next_states),self.beam_size])
79 | 
80 |             if contain_nonend_state == False:
81 |                 return self.next_states[:next_size]
82 |             else:
83 |                 self.cur_save_states = self.next_states[:next_size]
84 | 
85 |     def beam_predict_max_violation(self,init_state,gold_ans):
86 |         """ perform beam search to find the state with the best path score: maximize \sum f(S,a) - R(a) using estimated reward
87 | 
88 |         :param init_state: the initial state
89 |         :param gold_ans: the gold answers (used to calculate the step reward)
90 |         :returns: a list of end states (at most beam_size of them)
91 |         :rtype: list(State)
92 | 
93 |         """
94 |         init_state.path_score_expression = dt.scalarInput(0)
95 |         init_state.score = 0
96 |         self.cur_save_states = [init_state]
97 | 
98 |         while True:
99 |             self.next_states = []
100 |             contain_nonend_state = False
101 | 
102 |             for state in self.cur_save_states:
103 |                 cur_path_expression = state.path_score_expression
104 |                 cur_path_score = cur_path_expression.scalar_value()
105 | 
106 |                 if state.is_end():
107 |                     self.next_states.append(state) # keep end states so that they can compete with longer sequences
108 |                     continue
109 | 
110 |                 contain_nonend_state = True
111 | 
112 |                 action_list = state.get_action_set()
113 |                 estimated_reward_list = [state.estimated_reward(gold_ans,action) for action in action_list]
114 | 
115 | 
116 |                 new_expression_list, meta_info_list = state.get_next_score_expressions(action_list)
117 | 
118 | 
119 |                 for i in range(len(action_list)):
120 |                     action = action_list[i]
121 |                     new_expression = new_expression_list[i] #dynet call for getting element
122 |                     meta_info = meta_info_list[i]
123 |                     new_state = state.get_new_state_after_action(action,meta_info)
124 | 
125 |                     new_state.path_score_expression = cur_path_expression + new_expression
126 |                     new_state.score = new_state.path_score_expression.scalar_value() - estimated_reward_list[i] #TODO this forward can be expensive...
127 |                     #print("comparison",new_state.score,new_state.path_score_expression.value())
128 |                     #assert abs(new_state.score - new_state.path_score_expression.value()) < 1e-5
129 | 
130 |                     self.next_states.append(new_state)
131 | 
132 |             self.next_states.sort(key=lambda x: -1 * x.score)
133 |             # for x in self.next_states:
134 |             # print("===>",x,x.score)
135 | 
136 |             next_size = min([len(self.next_states),self.beam_size])
137 | 
138 |             if contain_nonend_state == False:
139 |                 return self.next_states[:next_size]
140 |             else:
141 |                 self.cur_save_states = self.next_states[:next_size]
142 | 
143 |     def beam_find_actions_with_answer_guidence(self,init_state,gold_ans):
144 |         """ perform beam search to find the state with the best reward: maximize \sum R(a) using estimated reward
145 | 
146 |         :param init_state: the initial state
147 |         :param gold_ans: the gold answers (used to calculate estimated_reward, and also to prune the action space)
148 |         :returns: a list of end states (at most beam_size of them)
149 |         :rtype: list(State)
150 | 
151 |         """
152 | 
153 |         init_state.path_score_expression = dt.scalarInput(0)
154 |         init_state.score = 0
155 |         self.cur_save_states = [init_state]
156 | 
157 |         #print ("*" * 100)
158 |         while True:
159 |             self.next_states = []
160 |             contain_nonend_state = False
161 | 
162 |             for state in self.cur_save_states:
163 |                 cur_path_expression = state.path_score_expression
164 |                 cur_path_score = cur_path_expression.scalar_value()
165 | 
166 |                 if state.is_end():
167 |                     self.next_states.append(state) # keep end states so that they can compete with longer sequences
168 |                     continue
169 | 
170 |                 contain_nonend_state = True
171 | 
172 |                 action_list, estimated_reward_list = state.get_action_set_withans(gold_ans)
173 |                 if action_list == []:
174 |                     continue
175 | 
176 |                 new_expression_list, meta_info_list = state.get_next_score_expressions(action_list)
177 | 
178 |                 for i in range(len(action_list)):
179 |                     action = action_list[i]
180 |                     new_expression = new_expression_list[i] #dynet call for getting element
181 |                     meta_info = meta_info_list[i]
182 |                     new_state = state.get_new_state_after_action(action,meta_info)
183 |                     new_state.path_score_expression = cur_path_expression + new_expression
184 |                     new_state.score = estimated_reward_list[i]
185 | 
186 |                     self.next_states.append(new_state)
187 | 
188 |             self.next_states.sort(key=lambda x: (-1 * x.score,-1* (x.path_score_expression.value()))) # sort by reward first; if equal, sort by model score
189 | 
190 |             next_size = min([len(self.next_states),self.beam_size])
191 | 
192 |             if contain_nonend_state == False:
193 |                 return self.next_states[:next_size]
194 |             else:
195 |                 self.cur_save_states = self.next_states[:next_size]
196 | 
197 | 
198 |     def beam_train_expected_reward(self,init_state, gold_ans):
199 | 
200 |         end_state_list = self.beam_predict(init_state)
201 | 
202 |         # find the best state in the list to make things faster
203 |         reward_states = self.beam_find_actions_with_answer_guidence(init_state, gold_ans)
204 |         if reward_states == []:
205 |             return 0
206 |         best_reward_state = reward_states[0]
207 |         find_best_state = False
208 |         for state in end_state_list:
209 |             if state.action_history == best_reward_state.action_history: #??
210 |                 find_best_state = True
211 |                 #print("found gold state", find_best_state, best_reward_state)
212 |                 break
213 |         if not find_best_state:
214 |             end_state_list.append(best_reward_state)
215 | 
216 |         exp_path_list = [dt.exp(x.path_score_expression) for x in end_state_list]
217 |         sum_exp = dt.esum(exp_path_list)
218 |         # print("size of end state", len(end_state_list))
219 |         reward_list = [dt.scalarInput(st.reward(gold_ans)) * exp_expr for st ,exp_expr in zip(end_state_list,exp_path_list)]
220 |         # print("=" * 80)
221 |         # for x in end_state_list:
222 |         # print(x, dt.exp(x.path_score_expression).value(), x.reward(gold_ans))
223 | 
224 | 
225 |         expected_negative_reward = -dt.cdiv(dt.esum(reward_list),sum_exp)
226 | 
227 |         #print("values", sum_exp.value(), (dt.esum(reward_list)).value(), dt.cdiv(dt.esum(reward_list),sum_exp).value())
228 |         # print("=" * 80)
229 |         value = expected_negative_reward.scalar_value()
230 |         #print("obj", expected_negative_reward.value())
231 |         expected_negative_reward.backward()
232 |         self.neural_model.learner.update()
233 |         return value
234 |         # possible issue: when all rewards are zero (e.g., in the first iterations), should the update be skipped?
235 | 
236 |     def beam_train_max_margin_with_answer_guidence(self, init_state, gold_ans):
237 |         # perform two beam searches: one for the prediction and one for the answer-guided state/action search
238 |         # max reward y = argmax(r(y)) with the help of gold_ans
239 |         # max y' = argmax f(x,y) - R(y')
240 |         # loss = max(f(x,y') - f(x,y) + R(y) - R(y') , 0)
241 | 
242 |         #end_state_list = self.beam_predict(init_state)
243 |         end_state_list = self.beam_predict_max_violation(init_state, gold_ans) # have to use this to make it work....
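        # illustrative numbers only: if the most-violating predicted state has model score
        # f(x,y') = 3.2 with reward R(y') = 0.4, and the answer-guided best state has
        # f(x,y) = 2.5 with reward R(y) = 1.0, then the hinge loss computed below is
        # max(3.2 - 2.5 + (1.0 - 0.4), 0) = 1.3, pushing the gold-like parse above the violator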
244 | reward_list = [x.reward(gold_ans) for x in end_state_list] 245 | violation_list = [s.path_score_expression.value() - reward for s,reward in zip(end_state_list,reward_list)] 246 | 247 | best_score_state_idx = violation_list.index(max(violation_list)) # find the best scoring seq with minimal reward 248 | best_score_state = end_state_list[best_score_state_idx] 249 | best_score_state_reward = reward_list[best_score_state_idx] 250 | 251 | loss_value = 0 252 | 253 | if self.only_one_best: 254 | best_states = self.beam_find_actions_with_answer_guidence(init_state, gold_ans) 255 | if best_states == []: 256 | return 0,[] 257 | best_reward_state = best_states[0] 258 | #print ("debug: found best_reward_state: qid =", best_reward_state.qinfo.seq_qid, best_reward_state) 259 | best_reward_state_reward = best_reward_state.reward(gold_ans) 260 | #print ("debug: best_reward_state_reward =", best_reward_state_reward) 261 | loss = dt.rectify(best_score_state.path_score_expression - best_reward_state.path_score_expression + dt.scalarInput(best_reward_state_reward - best_score_state_reward)) 262 | else: 263 | best_states = self.beam_find_actions_with_answer_guidence(init_state, gold_ans) 264 | best_states_rewards = [s.reward(gold_ans) for s in best_states] 265 | max_reward = max(best_states_rewards) 266 | best_states = [s for s,r in zip(best_states,best_states_rewards) if r == max_reward] 267 | loss = dt.average([dt.rectify(best_score_state.path_score_expression - best_reward_state.path_score_expression + dt.scalarInput(max_reward - best_score_state_reward)) for best_reward_state in best_states]) 268 | 269 | loss_value = loss.value() 270 | loss.backward() 271 | 272 | self.neural_model.learner.update() 273 | 274 | #print ("debug: beam_train_max_margin_with_answer_guidence done. 
loss_value =", loss_value) 275 | 276 | return loss_value,best_states 277 | 278 | def beam_train_max_margin(self, init_state, gold_ans): 279 | #still did not use the gold sequence but use the min risk training 280 | #max reward y = argmax(r(y)) 281 | #max y' = argmax f(x,y) - R(y') 282 | # loss = max(f(x,y') - f(x,y) + R(y) - R(y') , 0) 283 | 284 | end_state_list = self.beam_predict(init_state) 285 | reward_list = [x.reward(gold_ans) for x in end_state_list] 286 | violation_list = [s.score - reward for s,reward in zip(end_state_list,reward_list)] 287 | 288 | best_score_state_idx = violation_list.index(max(violation_list)) # find the best scoring seq with minimal reward 289 | best_reward_state_idx = reward_list.index(max(reward_list)) # find seq with the max reward in beam 290 | 291 | best_score_state = end_state_list[best_score_state_idx] 292 | best_reward_state = end_state_list[best_reward_state_idx] 293 | 294 | best_score_state_reward = reward_list[best_score_state_idx] 295 | best_reward_state_reward = reward_list[best_reward_state_idx] 296 | 297 | 298 | loss = dt.rectify(best_score_state.path_score_expression - best_reward_state.path_score_expression + dt.scalarInput(best_reward_state_reward - best_score_state_reward)) 299 | loss_value = loss.value() 300 | 301 | loss.backward() 302 | 303 | self.neural_model.learner.update() 304 | return loss_value 305 | 306 | #print("loss_value", loss_value) 307 | #print(self.neural_model.learner.status()) 308 | #for i,ss in enumerate(end_state_list): 309 | # print(i, ss, "score", ss.path_score_expression.value(), "reward", reward_list[i]) 310 | 311 | #print("first", best_score_state, "score", best_score_state.path_score_expression.value(), "reward", reward_list[best_score_state_idx]) 312 | #print("gold", best_reward_state, "score", best_reward_state.path_score_expression.value(), "reward", reward_list[best_reward_state_idx]) 313 | 314 | 315 | def beam_train_max_margin_with_goldactions(self,init_state, gold_actions): 316 | #max y = gold y 317 | #max y' = argmax f(x,y) 318 | # loss = max(f(x,y') - f(x,y) + R(y) - R(y') , 0) 319 | 320 | #loss 321 | #end_state_list = self.beam_predict(init_state) # top-k argmax_y f(x,y) 322 | end_state_list = self.beam_predict_max_violation(init_state,gold_actions) # top-k argmax_y f(x,y) + R(y*) - R(y) // Current implementation is the same as Hamming distance 323 | best_score_state = end_state_list[0] 324 | reward_list = [x.reward(gold_actions) for x in end_state_list] 325 | 326 | best_reward_state = self.get_goldstate_with_gold_actions(init_state,gold_actions) 327 | best_reward = best_reward_state.reward(gold_actions) 328 | 329 | loss = dt.rectify(best_score_state.path_score_expression - best_reward_state.path_score_expression + dt.scalarInput(best_reward-reward_list[0]) ) 330 | loss_value = loss.value() 331 | 332 | loss.backward() 333 | self.neural_model.learner.update() 334 | return loss_value 335 | 336 | 337 | def greedy_train_max_sumlogllh(self,init_state, gold_actions): 338 | 339 | total_obj = dt.scalarInput(0) 340 | 341 | cur_state = init_state 342 | res = 0 343 | idx = 0 344 | while True: 345 | if cur_state.is_end(): 346 | break 347 | 348 | action_list = list(cur_state.get_action_set()) 349 | new_expression_list, meta_info_list = cur_state.get_next_score_expressions(action_list) 350 | prob_list = dt.softmax(new_expression_list) 351 | gold_action = gold_actions[idx] 352 | action_idx = action_list.index(gold_action) 353 | total_obj += -(dt.log(prob_list[action_idx])) 354 | 355 | cur_state = 
cur_state.get_new_state_after_action(gold_action,meta_info_list[action_idx]) 356 | idx += 1 357 | #print (cur_state) 358 | 359 | res = total_obj.scalar_value() 360 | total_obj.backward() 361 | self.neural_model.learner.update() 362 | 363 | 364 | return res 365 | 366 | def get_goldstate_with_gold_actions(self,state,gold_actions): 367 | 368 | cur_state = state.clone() 369 | cur_state.path_score_expression = dt.scalarInput(0) 370 | 371 | time_idx = 0 372 | while True: 373 | if cur_state.is_end(): 374 | break 375 | action_list = list(cur_state.get_action_set()) 376 | old_expression = cur_state.path_score_expression 377 | new_expression_list, meta_info_list = state.get_next_score_expressions(action_list) 378 | gold_act = gold_actions[time_idx] 379 | action_idx = action_list.index(gold_act) 380 | 381 | cur_state = cur_state.get_new_state_after_action(gold_act,meta_info_list[action_idx]) 382 | cur_state.path_score_expression = old_expression + new_expression_list[action_idx] 383 | 384 | time_idx += 1 385 | 386 | return cur_state 387 | 388 | 389 | def greedy_predict(self,state): 390 | 391 | cur_state = state 392 | 393 | while True: 394 | if cur_state.is_end(): 395 | break 396 | 397 | action_list = cur_state.get_action_set() 398 | new_expression_list,meta_info_list = cur_state.get_next_score_expressions(action_list) 399 | prob_list = dt.softmax(new_expression_list) 400 | pred = np.argmax(prob_list.npvalue()) 401 | action = action_list[pred] 402 | 403 | cur_state = cur_state.get_new_state_after_action(action,meta_info_list[pred]) 404 | 405 | return cur_state 406 | 407 | 408 | if __name__ == '__main__': 409 | print("test") 410 | -------------------------------------------------------------------------------- /testmkl.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import dynet as dt 3 | 4 | def test_eigenvalue(): 5 | i= 5000 6 | data = np.random.random((i,i)) 7 | result = np.linalg.eig(data) 8 | print (result) 9 | 10 | def multi(): 11 | i= 100 12 | j= 100 13 | model = dt.Model() 14 | pA = model.add_parameters((i,j)) 15 | pB = model.add_parameters((j,1)) 16 | A = dt.parameter(pA) 17 | B = dt.parameter(pB) 18 | for j in range(1000000): 19 | result = A * B 20 | len(result.value()) 21 | 22 | def npmulti(i,j): 23 | A = np.random.random((i,j)) 24 | B = np.random.random((i,j)) 25 | result = A * B 26 | ''' 27 | for p in xrange(1000000): 28 | result = A * B 29 | len(result) 30 | ''' 31 | 32 | if __name__ == '__main__': 33 | #test_eigenvalue() 34 | multi() 35 | npmulti(5000,5000) 36 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | from itertools import count 2 | 3 | import argparse, os, cPickle, sys, csv, glob, json, itertools, re, datetime 4 | from unidecode import unidecode 5 | from collections import Counter, defaultdict 6 | from tokenizer import simpleTokenize 7 | 8 | import config 9 | 10 | def read(fname): 11 | sent = [] 12 | for line in file(fname): 13 | line = line.strip().split() 14 | if not line: 15 | if sent: yield sent 16 | sent = [] 17 | else: 18 | w,p = line 19 | sent.append((w,p)) 20 | 21 | class Vocab: 22 | def __init__(self, w2i=None): 23 | if w2i is None: w2i = defaultdict(count(0).next) 24 | self.w2i = dict(w2i) 25 | self.i2w = {i:w for w,i in w2i.iteritems()} 26 | @classmethod 27 | def from_corpus(cls, corpus): 28 | w2i = defaultdict(count(0).next) 29 | for sent in corpus: 30 | [w2i[word] for word in 
sent] 31 | return Vocab(w2i) 32 | 33 | def size(self): return len(self.w2i.keys()) 34 | 35 | class CorpusReader: 36 | def __init__(self, fname): 37 | self.fname = fname 38 | def __iter__(self): 39 | for line in file(self.fname): 40 | line = line.strip().split() 41 | #line = [' ' if x == '' else x for x in line] 42 | yield line 43 | 44 | class CharsCorpusReader: 45 | def __init__(self, fname, begin=None): 46 | self.fname = fname 47 | self.begin = begin 48 | def __iter__(self): 49 | begin = self.begin 50 | for line in file(self.fname): 51 | line = list(line) 52 | if begin: 53 | line = [begin] + line 54 | yield line 55 | 56 | ##------------------------------------------------------------------------------ 57 | 58 | class QuestionInfo: 59 | def __init__(self, seq_qid = "", pos = -1, question = "", table_file = "", headers = [], entries = [], types = [], 60 | answer_column_idx = -1, answer_column_name = "", answer_rows = [], is_col_select = None, 61 | eq_cond_column_idx = -1, eq_cond_column_name = "", eq_cond_value = "", 62 | complete_match = None, answer_coordinates = [], answer_text = [], annTab = None, 63 | numeric_cols = set()): 64 | 65 | self.seq_qid = seq_qid # question id 66 | self.pos = pos # question position in a sequence 67 | self.question = question # question 68 | self.table_file = table_file # table file name 69 | self.headers = headers # table header fields 70 | self.entries = entries # table content 71 | self.types = types # table column field types 72 | self.answer_column_idx = answer_column_idx # answer column index 73 | self.answer_column_name = answer_column_name # answer column field 74 | self.answer_rows = answer_rows # answer row indices 75 | self.is_col_select = is_col_select # is a column-select-only question? 76 | self.eq_cond_column_idx = eq_cond_column_idx # column index of Y in (Y=Z condition) 77 | self.eq_cond_column_name = eq_cond_column_name # column index of Y in (Y=Z condition) 78 | self.eq_cond_value = eq_cond_value # column value of Z in (Y=Z condition) 79 | self.complete_match = complete_match # does our "parse" answer the question? 
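        # answer_coordinates / answer_text below are the gold supervision; SqaState.reward()
        # compares predicted cell coordinates against answer_coordinates (Jaccard overlap when
        # config.d["partial_reward"] is on, exact-set match otherwise)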
80 | self.answer_coordinates = answer_coordinates # answer_coordinates 81 | self.answer_text = answer_text # answer_text 82 | self.annTab = annTab # annotated table 83 | self.numeric_cols = numeric_cols # set of numeric columns 84 | 85 | self.all_words = self.comp_all_words() 86 | self.ques_word_sequence = self.comp_ques_word_sequence() 87 | self.ques_word_sequence_ngram_str = ' ' + ' '.join(self.ques_word_sequence) + ' ' 88 | self.headers_word_sequences = self.comp_headers_word_sequences() 89 | self.entries_word_sequences = self.comp_entries_word_sequences() 90 | 91 | #print("before call...") 92 | numbers = findNumbers(self.ques_word_sequence) 93 | #print("after call", numbers, len(numbers)) 94 | #self.values_in_ques = findNumbers(self.ques_word_sequence), # numeric values in the question 95 | self.values_in_ques = numbers 96 | 97 | #print("values_in_ques:", self.values_in_ques, len(self.values_in_ques)) 98 | 99 | self.num_rows = len(self.entries) 100 | self.num_columns = len(self.headers) 101 | 102 | # TODO: call the following functions to make sure "lower" is consistent 103 | def comp_all_words(self): 104 | words = set() 105 | for w in simpleTokenize(self.question): words.add(w.lower()) 106 | for colname in self.headers: 107 | for w in simpleTokenize(colname): words.add(w.lower()) 108 | for row in self.entries: 109 | for ent in row: 110 | for w in simpleTokenize(ent): words.add(w.lower()) 111 | return words 112 | 113 | def comp_ques_word_sequence(self): 114 | return [w.lower() for w in simpleTokenize(self.question)] 115 | 116 | def comp_headers_word_sequences(self): 117 | ret = [] 118 | for colname in self.headers: 119 | ret.append([w.lower() for w in simpleTokenize(colname)]) 120 | return ret 121 | 122 | def comp_entries_word_sequences(self): 123 | ret = [] 124 | for row in self.entries: 125 | row_sequences = [] 126 | for ent in row: 127 | row_sequences.append([w.lower() for w in simpleTokenize(ent)]) 128 | ret.append(row_sequences) 129 | return ret 130 | 131 | def contain_ngram(self, sub_ngram_str): 132 | sub_ngram_str = ' ' + sub_ngram_str + ' ' 133 | ret = sub_ngram_str in self.ques_word_sequence_ngram_str 134 | #print(ret, sub_ngram_str, self.ques_word_sequence_ngram_str) 135 | return ret 136 | 137 | class ResultInfo: # information of the previous question and its parse and answers 138 | def __init__(self, seq_qid = "", question = "", ques_word_sequence = [], 139 | pred_answer_coordinates = [], pred_answer_column = -1, pred_question_parse = []): 140 | self.prev_seq_qid = seq_qid # previous question id 141 | self.prev_question = question # previous question 142 | self.prev_ques_word_sequence = ques_word_sequence # previous question word sequence 143 | self.prev_pred_answer_coordinates = pred_answer_coordinates # predicted answer_coordinates of previous question 144 | self.prev_pred_answer_column = pred_answer_column # predicted answer column of previous question (i.e., SELECT X) 145 | self.prev_question_parse = pred_question_parse # predicted parse of previous question (i.e., the final picked action_history) 146 | self.subtab_rows = sorted(list(set([coor[0] for coor in pred_answer_coordinates]))) 147 | def __str__(self): 148 | ret = "prev_seq_qid = '%s', prev_question = '%s', prev_ques_word_sequence = '[%s]'" % (self.prev_seq_qid, self.prev_question, self.prev_ques_word_sequence) 149 | ret += "prev_pred_answer_column = '%s'" % self.prev_pred_answer_column 150 | return ret 151 | 152 | def get_question_ids(fnTsv, dirData): 153 | reader = csv.DictReader(open("%s" % (fnTsv), 'r'), 
delimiter='\t') 154 | ret = [] 155 | for row in reader: 156 | ret.append(row['id']) 157 | return ret 158 | 159 | # read questions from dataset and store the question information 160 | def get_labeled_questions(fnTsv, dirData, firstOnly=False, skipEmptyAns=True): 161 | reader = csv.DictReader(open("%s" % (fnTsv), 'r'), delimiter='\t') 162 | data = [] 163 | 164 | for row in reader: 165 | qid = row['id'] 166 | count = row['annotator'] 167 | pos = row['position'] 168 | lstLoc = eval(row['answer_coordinates']) 169 | if skipEmptyAns and lstLoc == []: # TO-DO: try to use the numeric answers 170 | continue 171 | if lstLoc != [] and type(lstLoc[0]) == str: 172 | locations = sorted([eval(t) for t in lstLoc]) # from string to 2-tuples 173 | else: 174 | locations = lstLoc 175 | 176 | if firstOnly and pos != '0': 177 | continue 178 | 179 | seq_qid = '%s_%s_%s' % (qid, count, pos) 180 | with open("%s/%s" % (dirData, row['table_file']), 'rb') as csvfile: 181 | table_reader = csv.DictReader(csvfile) 182 | headers = table_reader.fieldnames 183 | # replace empty column name with _EMPTY_ 184 | headers = map(lambda x: "_EMPTY_" if x == '' else x, headers) 185 | question = row['question'] 186 | if 'eq_cond' in row: 187 | eq_cond = row['eq_cond'] 188 | else: 189 | eq_cond = '' 190 | eq_cond_column = -1 191 | eq_cond_column_name = "" 192 | eq_cond_value = "" 193 | if eq_cond.find(' = ') != -1: 194 | f = eq_cond.split(' = ') 195 | eq_cond_column_name = f[0] 196 | # replace empty column name with _EMPTY_ 197 | if eq_cond_column_name == '': eq_cond_column_name = "_EMPTY_" 198 | if len(f) == 2: eq_cond_value = f[1] 199 | eq_cond_column = headers.index(eq_cond_column_name) 200 | 201 | entries = [] 202 | for row_index, trow in enumerate(table_reader): 203 | dec_row = [] 204 | for key in table_reader.fieldnames: 205 | try: 206 | cell = unidecode(trow[key].decode('utf-8')).strip().replace('\n', ' ').replace('"', '') 207 | except: 208 | cell = '' 209 | if cell == '': cell = "_EMPTY_" # replace empty cell with a special _EMPTY_ token 210 | dec_row.append(cell) 211 | 212 | dec_row = dict(zip(headers, dec_row)) 213 | 214 | entry = [] 215 | for key in headers: 216 | entry.append(dec_row[key]) 217 | entries.append(entry) 218 | 219 | num_cols = len(headers) 220 | num_rows = len(entries) 221 | 222 | # check annotated table 223 | tabid = str(row['table_file']).replace('table_csv/', '').replace('.csv','').split('_') 224 | fnAnnTab = str("%s/%s-annotated/%s.annotated" % (config.d["AnnotatedTableDir"], tabid[0], tabid[1])) 225 | annTab = loadAnnotatedTable(fnAnnTab, num_rows, num_cols) 226 | 227 | # check which columns are numeric columns 228 | numeric_cols = set() 229 | for c in xrange(num_cols): 230 | sumNumRows = sum([1 for r in xrange(num_rows) if annTab[(r,c)].number != config.NaN or annTab[(r,c)].date != None]) 231 | if sumNumRows >= num_rows-1: # leave some room for table error 232 | numeric_cols.add(c) 233 | 234 | types = infer_column_types(headers, entries) 235 | if locations == []: 236 | answer_column = -1 237 | is_col_select = False 238 | answer_rows = set() 239 | else: 240 | columns = [coord[1] for coord in locations] 241 | col_count = Counter() 242 | for col in columns: 243 | col_count[col] += 1 244 | #print("debug: most_common", col_count.most_common(1), "fnAnnTab", fnAnnTab) 245 | answer_column = col_count.most_common(1)[0][0] 246 | if 'complete_match' in row: 247 | complete_match = row['complete_match'] 248 | else: 249 | complete_match = '' 250 | 251 | # make sure it's not column selection -- this snippet may have 
bugs 252 | is_contig = True 253 | prev_row_idx = locations[0][0] 254 | for r, c in locations[1:]: 255 | if r == prev_row_idx + 1: 256 | prev_row_idx = r 257 | else: 258 | is_contig = False 259 | break 260 | is_col_select = False 261 | if len(locations) == num_rows or (len(locations) == (num_rows - 1) and is_contig): 262 | is_col_select = True 263 | 264 | answer_rows = set([coord[0] for coord in locations]) 265 | 266 | quesInfo = QuestionInfo(seq_qid = seq_qid, # question id 267 | pos = pos, # the question position in the sequence 268 | question = question, # question 269 | table_file = row['table_file'], # table file name 270 | headers = headers, # table header fields 271 | entries = entries, # table content 272 | types = types, # table column field types 273 | answer_column_idx = answer_column, # answer column index 274 | answer_column_name = headers[answer_column], # answer column field 275 | answer_rows = answer_rows, # answer row indices 276 | is_col_select = is_col_select, # is a column-select-only question? 277 | eq_cond_column_idx = eq_cond_column, # column index of Y in (Y=Z condition) 278 | eq_cond_column_name = headers[eq_cond_column], # column index of Y in (Y=Z condition) 279 | eq_cond_value = eq_cond_value, # column value of Z in (Y=Z condition) 280 | complete_match = complete_match, # does our "parse" answer the question? 281 | answer_coordinates = locations, # answer_coordinates 282 | answer_text = eval(row['answer_text']), # answer_text 283 | annTab = annTab, 284 | numeric_cols = numeric_cols # set of numeric or date columns 285 | ) 286 | 287 | data.append(quesInfo) 288 | 289 | return data 290 | 291 | 292 | # figure out if each column contains text, dates, or numbers 293 | def infer_column_types(headers, entries, min_date=1700, max_date=2016): 294 | types = [] 295 | for c in range(len(headers)): 296 | curr_entries = [] 297 | for r in range(len(entries)): 298 | curr_entries.append(entries[r][c]) 299 | 300 | curr_type = 'text' 301 | 302 | # check for numbers 303 | new_entries = [] 304 | for x in curr_entries: 305 | try: 306 | # if the first thing in the cell is a number, ignore the rest 307 | # x = x.replace(',', '') 308 | y = float(x) 309 | new_entries.append(y) 310 | except: 311 | break 312 | 313 | # now figure out if it's a date or a normal number 314 | # we'll say dates are everything between 1800-2016 315 | if len(new_entries) == len(curr_entries): 316 | curr_type = 'number' 317 | 318 | min_col = min(new_entries) 319 | max_col = max(new_entries) 320 | if min_col > min_date and max_col <= max_date: 321 | curr_type = 'date' 322 | 323 | types.append(curr_type) 324 | 325 | return types 326 | 327 | def is_float_try(str): 328 | try: 329 | float(str) 330 | return True 331 | except ValueError: 332 | return False 333 | 334 | ##------------------------------------------------------------------------------ 335 | 336 | class AnnotatedTabEntry: 337 | def __init__(self, row = config.NaN, col = config.NaN, content = "", 338 | strTokens = "", strLemmaTokens = "", strPosTags = "", strNerTags = "", strNerValues = "", 339 | number = config.NaN, strDate = "", num2 = config.NaN, 340 | strList = ""): 341 | self.row = row 342 | self.col = col 343 | self.content = content 344 | self.tokens = self.mysplit(strTokens) 345 | self.lemmaTokens = self.mysplit(strLemmaTokens) 346 | self.posTags = self.mysplit(strPosTags) 347 | self.nerTags = self.mysplit(strNerTags) 348 | self.nerValues = self.mysplit(strNerValues) 349 | self.number = number 350 | if strDate: 351 | year = 4 # to allow Feb-29 352 | 
month = day = 1 353 | dateFields = self.mysplit(strDate,'-') 354 | if len(dateFields) == 3: 355 | if RepresentsInt(dateFields[0]): 356 | year = int(dateFields[0]) 357 | if year > datetime.MAXYEAR: year = datetime.MAXYEAR 358 | if year < datetime.MINYEAR: year = datetime.MINYEAR 359 | if RepresentsInt(dateFields[1]): 360 | month = int(dateFields[1]) 361 | if RepresentsInt(dateFields[2]): 362 | day = int(dateFields[2]) 363 | #print(strDate, year,month,day) 364 | self.date = datetime.date(year,month,day) 365 | else: 366 | self.date = None 367 | self.num2 = num2 368 | self.list = self.mysplit(strList) 369 | 370 | @staticmethod 371 | def mysplit(field, delim='|'): 372 | if field: return field.split(delim) 373 | return [] 374 | 375 | def RepresentsInt(s): 376 | try: 377 | int(s) 378 | return True 379 | except ValueError: 380 | return False 381 | 382 | def loadAnnotatedTable(fnAnnTab, num_rows, num_cols): 383 | ret = {} 384 | with open(fnAnnTab, 'rb') as tsvfile: 385 | table_reader = csv.DictReader(tsvfile, delimiter='\t') 386 | for trow in table_reader: 387 | r = int(trow['row']) 388 | c = int(trow['col']) 389 | number = config.NaN 390 | if trow['number']: 391 | number = float(trow['number']) 392 | num2 = config.NaN 393 | if trow['num2']: 394 | number = float(trow['num2']) 395 | 396 | entry = AnnotatedTabEntry(r, c, trow['content'], 397 | trow['tokens'], trow['lemmaTokens'], trow['posTags'], trow['nerTags'], trow['nerValues'], 398 | number, trow['date'], num2, trow['list']) 399 | ret[(r,c)] = entry 400 | # check if the annotated table is complete 401 | inComplete = False 402 | for r in xrange(-1, num_rows): 403 | for c in xrange(num_cols): 404 | if (r,c) not in ret: 405 | ret[(r,c)] = AnnotatedTabEntry() 406 | inComplete = True 407 | if inComplete: 408 | print("Warning: %s is not complete!" % fnAnnTab) 409 | 410 | return ret 411 | 412 | 413 | reNumber = re.compile(r'\d[\d,\.]*|\.\d[\d,\.]*') 414 | strWdNumbers = ['zero','one','two','three','four','five','six','seven','eight','nine','ten','eleven','twelve'] 415 | setWdNumbers = set(strWdNumbers) 416 | ''' 417 | def findNumbers(strQues): 418 | ret = set() 419 | for strNum in reNumber.findall(strQues): 420 | strN = strNum.replace(',','') 421 | if is_float_try(strN): 422 | ret.add(float(strN)) 423 | lowerQ = strQues.lower() 424 | setWds = set(lowerQ.split(' ')) 425 | for i,w in enumerate(strWdNumbers): 426 | if w in setWds: 427 | if i == 1 and "which one" in lowerQ: continue # ignore "one" in "which one" 428 | ret.add(float(i)) 429 | return ret 430 | ''' 431 | def findNumbers(quesWds): 432 | ret = [] 433 | for i,tok in enumerate(quesWds): 434 | m = reNumber.match(tok) 435 | if not m: continue 436 | strN = m.group().replace(',','') 437 | if is_float_try(strN): 438 | ret.append((i,float(strN))) 439 | 440 | if tok in setWdNumbers: 441 | ret.append((i, float(strWdNumbers.index(tok)))) 442 | 443 | #print("len(ret) = ", len(ret)) 444 | return ret 445 | --------------------------------------------------------------------------------
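
Usage sketch (illustrative only, not a file in the repository): the pieces above fit together roughly as in the evalModel() helper — load a saved SqaModel by its header, wrap it in a BeamSearchInferencer, and walk the questions of a sequence while feeding each prediction forward as the ResultInfo of the next question. The snippet is a minimal, hedged sketch under these assumptions: a Python 2 environment with dynet, future, and unidecode installed; config.d populated the same way the training scripts populate it; the file name eval_sketch.py, the model header, the "test.tsv" path, and the use of config.d["dirTable"] as the table folder are placeholders, not part of the project.

    # eval_sketch.py -- hypothetical driver; mirrors the evalModel()/test_model flow above
    from __future__ import print_function
    from collections import Counter
    import dynet as dt
    import config, util
    from sqamodel import SqaModel
    from sqastate import SqaScoreExpressionState
    from statesearch import BeamSearchInferencer

    model = SqaModel.load("model/moduleKwNoMap-b15-0-18")                  # header of a saved model (placeholder)
    inferencer = BeamSearchInferencer(model, config.d["beam_size"])
    data = util.get_labeled_questions("test.tsv", config.d["dirTable"])    # placeholder paths

    resinfo = None
    for qinfo in data:
        dt.renew_cg()                                                      # new computation graph per question
        if qinfo.seq_qid[-1] == '0' or resinfo is None:                    # first question of a sequence
            state = SqaScoreExpressionState(model, qinfo, testmode=True)
        else:                                                              # later questions also see the previous answers
            state = SqaScoreExpressionState(model, qinfo, resinfo=resinfo, testmode=True)
        best = inferencer.beam_predict(state)[0]                           # highest model-score end state
        pred = best.execute_parse()                                        # predicted (row, col) answer cells
        print(qinfo.seq_qid, pred)
        col = Counter(c for _, c in pred).most_common(1)[0][0] if pred else 0
        resinfo = util.ResultInfo(qinfo.seq_qid, qinfo.question, qinfo.ques_word_sequence, pred, col)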