├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── NOTICE ├── README.md ├── SECURITY.md ├── __init__.py ├── eval.py ├── eval.sh ├── preprocess ├── data_process.py ├── download_nltk.py ├── run_me.sh ├── sql2SemQL.py └── utils.py ├── requirements.txt ├── sem2SQL.py ├── src ├── __init__.py ├── args.py ├── beam.py ├── dataset.py ├── models │ ├── __init__.py │ ├── basic_model.py │ ├── model.py │ ├── nn_utils.py │ └── pointer_net.py ├── rule │ ├── __init__.py │ ├── graph.py │ ├── lf.py │ ├── semQL.py │ └── sem_utils.py └── utils.py ├── train.py └── train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.suo 8 | *.user 9 | *.userosscache 10 | *.sln.docstates 11 | 12 | # User-specific files (MonoDevelop/Xamarin Studio) 13 | *.userprefs 14 | 15 | # Build results 16 | [Dd]ebug/ 17 | [Dd]ebugPublic/ 18 | [Rr]elease/ 19 | [Rr]eleases/ 20 | x64/ 21 | x86/ 22 | bld/ 23 | [Bb]in/ 24 | [Oo]bj/ 25 | [Ll]og/ 26 | 27 | # Visual Studio 2015/2017 cache/options directory 28 | .vs/ 29 | # Uncomment if you have tasks that create the project's static files in wwwroot 30 | #wwwroot/ 31 | 32 | # Visual Studio 2017 auto generated files 33 | Generated\ Files/ 34 | 35 | # MSTest test Results 36 | [Tt]est[Rr]esult*/ 37 | [Bb]uild[Ll]og.* 38 | 39 | # NUNIT 40 | *.VisualState.xml 41 | TestResult.xml 42 | 43 | # Build Results of an ATL Project 44 | [Dd]ebugPS/ 45 | [Rr]eleasePS/ 46 | dlldata.c 47 | 48 | # Benchmark Results 49 | BenchmarkDotNet.Artifacts/ 50 | 51 | # .NET Core 52 | project.lock.json 53 | project.fragment.lock.json 54 | artifacts/ 55 | **/Properties/launchSettings.json 56 | 57 | # StyleCop 58 | StyleCopReport.xml 59 | 60 | # Files built by Visual Studio 61 | *_i.c 62 | *_p.c 63 | *_i.h 64 | *.ilk 65 | *.meta 66 | *.obj 67 | *.iobj 68 | *.pch 69 | *.pdb 70 | *.ipdb 71 | *.pgc 72 | *.pgd 73 | *.rsp 74 | *.sbr 75 | *.tlb 76 | *.tli 77 | *.tlh 78 | *.tmp 79 | *.tmp_proj 80 | *.log 81 | *.vspscc 82 | *.vssscc 83 | .builds 84 | *.pidb 85 | *.svclog 86 | *.scc 87 | 88 | # Chutzpah Test files 89 | _Chutzpah* 90 | 91 | # Visual C++ cache files 92 | ipch/ 93 | *.aps 94 | *.ncb 95 | *.opendb 96 | *.opensdf 97 | *.sdf 98 | *.cachefile 99 | *.VC.db 100 | *.VC.VC.opendb 101 | 102 | # Visual Studio profiler 103 | *.psess 104 | *.vsp 105 | *.vspx 106 | *.sap 107 | 108 | # Visual Studio Trace Files 109 | *.e2e 110 | 111 | # TFS 2012 Local Workspace 112 | $tf/ 113 | 114 | # Guidance Automation Toolkit 115 | *.gpState 116 | 117 | # ReSharper is a .NET coding add-in 118 | _ReSharper*/ 119 | *.[Rr]e[Ss]harper 120 | *.DotSettings.user 121 | 122 | # JustCode is a .NET coding add-in 123 | .JustCode 124 | 125 | # TeamCity is a build add-in 126 | _TeamCity* 127 | 128 | # DotCover is a Code Coverage Tool 129 | *.dotCover 130 | 131 | # AxoCover is a Code Coverage Tool 132 | .axoCover/* 133 | !.axoCover/settings.json 134 | 135 | # Visual Studio code coverage results 136 | *.coverage 137 | *.coveragexml 138 | 139 | # NCrunch 140 | _NCrunch_* 141 | .*crunch*.local.xml 142 | nCrunchTemp_* 143 | 144 | # MightyMoose 145 | *.mm.* 146 | AutoTest.Net/ 147 | 148 | # Web workbench (sass) 149 | .sass-cache/ 150 | 151 | # Installshield output folder 152 | [Ee]xpress/ 153 | 154 | # DocProject is a documentation generator add-in 155 | 
DocProject/buildhelp/ 156 | DocProject/Help/*.HxT 157 | DocProject/Help/*.HxC 158 | DocProject/Help/*.hhc 159 | DocProject/Help/*.hhk 160 | DocProject/Help/*.hhp 161 | DocProject/Help/Html2 162 | DocProject/Help/html 163 | 164 | # Click-Once directory 165 | publish/ 166 | 167 | # Publish Web Output 168 | *.[Pp]ublish.xml 169 | *.azurePubxml 170 | # Note: Comment the next line if you want to checkin your web deploy settings, 171 | # but database connection strings (with potential passwords) will be unencrypted 172 | *.pubxml 173 | *.publishproj 174 | 175 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 176 | # checkin your Azure Web App publish settings, but sensitive information contained 177 | # in these scripts will be unencrypted 178 | PublishScripts/ 179 | 180 | # NuGet Packages 181 | *.nupkg 182 | # The packages folder can be ignored because of Package Restore 183 | **/[Pp]ackages/* 184 | # except build/, which is used as an MSBuild target. 185 | !**/[Pp]ackages/build/ 186 | # Uncomment if necessary however generally it will be regenerated when needed 187 | #!**/[Pp]ackages/repositories.config 188 | # NuGet v3's project.json files produces more ignorable files 189 | *.nuget.props 190 | *.nuget.targets 191 | 192 | # Microsoft Azure Build Output 193 | csx/ 194 | *.build.csdef 195 | 196 | # Microsoft Azure Emulator 197 | ecf/ 198 | rcf/ 199 | 200 | # Windows Store app package directories and files 201 | AppPackages/ 202 | BundleArtifacts/ 203 | Package.StoreAssociation.xml 204 | _pkginfo.txt 205 | *.appx 206 | 207 | # Visual Studio cache files 208 | # files ending in .cache can be ignored 209 | *.[Cc]ache 210 | # but keep track of directories ending in .cache 211 | !*.[Cc]ache/ 212 | 213 | # Others 214 | ClientBin/ 215 | ~$* 216 | *~ 217 | *.dbmdl 218 | *.dbproj.schemaview 219 | *.jfm 220 | *.pfx 221 | *.publishsettings 222 | orleans.codegen.cs 223 | 224 | # Including strong name files can present a security risk 225 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 226 | #*.snk 227 | 228 | # Since there are multiple workflows, uncomment next line to ignore bower_components 229 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 230 | #bower_components/ 231 | 232 | # RIA/Silverlight projects 233 | Generated_Code/ 234 | 235 | # Backup & report files from converting an old project file 236 | # to a newer Visual Studio version. Backup files are not needed, 237 | # because we have git ;-) 238 | _UpgradeReport_Files/ 239 | Backup*/ 240 | UpgradeLog*.XML 241 | UpgradeLog*.htm 242 | ServiceFabricBackup/ 243 | *.rptproj.bak 244 | 245 | # SQL Server files 246 | *.mdf 247 | *.ldf 248 | *.ndf 249 | 250 | # Business Intelligence projects 251 | *.rdl.data 252 | *.bim.layout 253 | *.bim_*.settings 254 | *.rptproj.rsuser 255 | 256 | # Microsoft Fakes 257 | FakesAssemblies/ 258 | 259 | # GhostDoc plugin setting file 260 | *.GhostDoc.xml 261 | 262 | # Node.js Tools for Visual Studio 263 | .ntvs_analysis.dat 264 | node_modules/ 265 | 266 | # Visual Studio 6 build log 267 | *.plg 268 | 269 | # Visual Studio 6 workspace options file 270 | *.opt 271 | 272 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 
273 | *.vbw 274 | 275 | # Visual Studio LightSwitch build output 276 | **/*.HTMLClient/GeneratedArtifacts 277 | **/*.DesktopClient/GeneratedArtifacts 278 | **/*.DesktopClient/ModelManifest.xml 279 | **/*.Server/GeneratedArtifacts 280 | **/*.Server/ModelManifest.xml 281 | _Pvt_Extensions 282 | 283 | # Paket dependency manager 284 | .paket/paket.exe 285 | paket-files/ 286 | 287 | # FAKE - F# Make 288 | .fake/ 289 | 290 | # JetBrains Rider 291 | .idea/ 292 | *.sln.iml 293 | 294 | # CodeRush 295 | .cr/ 296 | 297 | # Python Tools for Visual Studio (PTVS) 298 | __pycache__/ 299 | *.pyc 300 | 301 | # Cake - Uncomment if you are using it 302 | # tools/** 303 | # !tools/packages.config 304 | 305 | # Tabs Studio 306 | *.tss 307 | 308 | # Telerik's JustMock configuration file 309 | *.jmconfig 310 | 311 | # BizTalk build output 312 | *.btp.cs 313 | *.btm.cs 314 | *.odx.cs 315 | *.xsd.cs 316 | 317 | # OpenCover UI analysis results 318 | OpenCover/ 319 | 320 | # Azure Stream Analytics local run output 321 | ASALocalRun/ 322 | 323 | # MSBuild Binary and Structured Log 324 | *.binlog 325 | 326 | # NVidia Nsight GPU debugger configuration file 327 | *.nvuser 328 | 329 | # MFractors (Xamarin productivity tool) working folder 330 | .mfractor/ 331 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | NOTICES AND INFORMATION 2 | Do Not Translate or Localize 3 | 4 | This software incorporates material from third parties. 
Microsoft makes certain 5 | open source code available at https3rdpartysource.microsoft.com, or you may 6 | send a check or money order for US $5.00, including the product name, the open 7 | source component name, platform, and version number, to 8 | 9 | Source Code Compliance Team 10 | Microsoft Corporation 11 | One Microsoft Way 12 | Redmond, WA 98052 13 | USA 14 | 15 | Notwithstanding any other terms, you may reverse engineer this software to the 16 | extent required to debug changes to any libraries licensed under the GNU Lesser 17 | General Public License. 18 | 19 | ======================================================================= 20 | Component. PyTorch 21 | 22 | Open Source LicenseCopyright Notice. 23 | 24 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 25 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 26 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 27 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 28 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 29 | Copyright (c) 2011-2013 NYU (Clement Farabet) 30 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 31 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 32 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 33 | 34 | From Caffe2 35 | 36 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 37 | 38 | All contributions by Facebook 39 | Copyright (c) 2016 Facebook Inc. 40 | 41 | All contributions by Google 42 | Copyright (c) 2015 Google Inc. 43 | All rights reserved. 44 | 45 | All contributions by Yangqing Jia 46 | Copyright (c) 2015 Yangqing Jia 47 | All rights reserved. 48 | 49 | All contributions from Caffe 50 | Copyright(c) 2013, 2014, 2015, the respective contributors 51 | All rights reserved. 52 | 53 | All other contributions 54 | Copyright(c) 2015, 2016 the respective contributors 55 | All rights reserved. 56 | 57 | Caffe2 uses a copyright model similar to Caffe each contributor holds 58 | copyright over their contributions to Caffe2. The project versioning records 59 | all such contribution and copyright details. If a contributor wants to further 60 | mark their specific copyright on a particular contribution, they should 61 | indicate their copyright solely in the commit message of the change when it is 62 | committed. 63 | 64 | All rights reserved. 65 | 66 | Redistribution and use in source and binary forms, with or without 67 | modification, are permitted provided that the following conditions are met 68 | 69 | 1. Redistributions of source code must retain the above copyright 70 | notice, this list of conditions and the following disclaimer. 71 | 72 | 2. Redistributions in binary form must reproduce the above copyright 73 | notice, this list of conditions and the following disclaimer in the 74 | documentation andor other materials provided with the distribution. 75 | 76 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 77 | and IDIAP Research Institute nor the names of its contributors may be 78 | used to endorse or promote products derived from this software without 79 | specific prior written permission. 80 | 81 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS AS IS 82 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 83 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 84 | ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 85 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 86 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 87 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 88 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 89 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 90 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 91 | POSSIBILITY OF SUCH DAMAGE. 92 | 93 | ======================================================================= 94 | Component. nltk 95 | 96 | Open Source LicenseCopyright Notice. 97 | 98 | Copyright (C) 2001-2019 NLTK Project 99 | 100 | Licensed under the Apache License, Version 2.0 (the 'License'); 101 | you may not use this file except in compliance with the License. 102 | You may obtain a copy of the License at 103 | 104 | httpwww.apache.orglicensesLICENSE-2.0 105 | 106 | Unless required by applicable law or agreed to in writing, software 107 | distributed under the License is distributed on an 'AS IS' BASIS, 108 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 109 | See the License for the specific language governing permissions and 110 | limitations under the License. 111 | 112 | ======================================================================= 113 | Component. Glove 114 | 115 | Open Source LicenseCopyright Notice. 116 | 117 | Preamble 118 | The Open Data Commons – Public Domain Dedication & Licence is a document intended to allow you to freely share, modify, and use this work for any purpose and without any restrictions. This licence is intended for use on databases or their contents (“data”), either together or individually. 119 | Many databases are covered by copyright. Some jurisdictions, mainly in Europe, have specific special rights that cover databases called the “sui generis” database right. Both of these sets of rights, as well as other legal rights used to protect databases and data, can create uncertainty or practical difficulty for those wishing to share databases and their underlying data but retain a limited amount of rights under a “some rights reserved” approach to licensing as outlined in the Science Commons Protocol for Implementing Open Access Data. As a result, this waiver and licence tries to the fullest extent possible to eliminate or fully license any rights that cover this database and data. Any Community Norms or similar statements of use of the database or data do not form a part of this document, and do not act as a contract for access or other terms of use for the database or data. 120 | The position of the recipient of the work 121 | Because this document places the database and its contents in or as close as possible within the public domain, there are no restrictions or requirements placed on the recipient by this document. Recipients may use this work commercially, use technical protection measures, combine this data or database with other databases or data, and share their changes and additions or keep them secret. It is not a requirement that recipients provide further users with a copy of this licence or attribute the original creator of the data or database as a source. The goal is to eliminate restrictions held by the original creator of the data and database on the use of it by others. 
122 | The position of the dedicator of the work 123 | Copyright law, as with most other law under the banner of “intellectual property”, is inherently national law. This means that there exists several differences in how copyright and other IP rights can be relinquished, waived or licensed in the many legal jurisdictions of the world. This is despite much harmonisation of minimum levels of protection. The internet and other communication technologies span these many disparate legal jurisdictions and thus pose special difficulties for a document relinquishing and waiving intellectual property rights, including copyright and database rights, for use by the global community. Because of this feature of intellectual property law, this document first relinquishes the rights and waives the relevant rights and claims. It then goes on to license these same rights for jurisdictions or areas of law that may make it difficult to relinquish or waive rights or claims. 124 | The purpose of this document is to enable rightsholders to place their work into the public domain. Unlike licences for free and open source software, free cultural works, or open content licences, rightsholders will not be able to “dual license” their work by releasing the same work under different licences. This is because they have allowed anyone to use the work in whatever way they choose. Rightsholders therefore can’t re-license it under copyright or database rights on different terms because they have nothing left to license. Doing so creates truly accessible data to build rich applications and advance the progress of science and the arts. 125 | This document can cover either or both of the database and its contents (the data). Because databases can have a wide variety of content – not just factual data – rightsholders should use the Open Data Commons – Public Domain Dedication & Licence for an entire database and its contents only if everything can be placed under the terms of this document. Because even factual data can sometimes have intellectual property rights, rightsholders should use this licence to cover both the database and its factual data when making material available under this document; even if it is likely that the data would not be covered by copyright or database rights. 126 | Rightsholders can also use this document to cover any copyright or database rights claims over only a database, and leave the contents to be covered by other licences or documents. They can do this because this document refers to the “Work”, which can be either – or both – the database and its contents. As a result, rightsholders need to clearly state what they are dedicating under this document when they dedicate it. 127 | Just like any licence or other document dealing with intellectual property, rightsholders should be aware that one can only license what one owns. Please ensure that the rights have been cleared to make this material available under this document. 128 | This document permanently and irrevocably makes the Work available to the public for any use of any kind, and it should not be used unless the rightsholder is prepared for this to happen. 129 | Part I Introduction 130 | The Rightsholder (the Person holding rights or claims over the Work) agrees as follows 131 | 1.0 Definitions of Capitalised Words 132 | “Copyright” – Includes rights under copyright and under neighbouring rights and similarly related sets of rights under the law of the relevant jurisdiction under Section 6.4. 
133 | “Data” – The contents of the Database, which includes the information, independent works, or other material collected into the Database offered under the terms of this Document. 134 | “Database” – A collection of Data arranged in a systematic or methodical way and individually accessible by electronic or other means offered under the terms of this Document. 135 | “Database Right” – Means rights over Data resulting from the Chapter III (“sui generis”) rights in the Database Directive (Directive 969EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases) and any future updates as well as any similar rights available in the relevant jurisdiction under Section 6.4. 136 | “Document” – means this relinquishment and waiver of rights and claims and back up licence agreement. 137 | “Person” – Means a natural or legal person or a body of persons corporate or incorporate. 138 | “Use” – As a verb, means doing any act that is restricted by Copyright or Database Rights whether in the original medium or any other; and includes modifying the Work as may be technically necessary to use it in a different mode or format. This includes the right to sublicense the Work. 139 | “Work” – Means either or both of the Database and Data offered under the terms of this Document. 140 | “You” – the Person acquiring rights under the licence elements of this Document. 141 | Words in the singular include the plural and vice versa. 142 | 2.0 What this document covers 143 | 2.1. Legal effect of this Document. This Document is 144 | a. A dedication to the public domain and waiver of Copyright and Database Rights over the Work; and 145 | b. A licence of Copyright and Database Rights over the Work in jurisdictions that do not allow for relinquishment or waiver. 146 | 2.2. Legal rights covered. 147 | a. Copyright. Any copyright or neighbouring rights in the Work. Copyright law varies between jurisdictions, but is likely to cover the Database model or schema, which is the structure, arrangement, and organisation of the Database, and can also include the Database tables and table indexes; the data entry and output sheets; and the Field names of Data stored in the Database. Copyright may also cover the Data depending on the jurisdiction and type of Data; and 148 | b. Database Rights. Database Rights only extend to the extraction and re-utilisation of the whole or a substantial part of the Data. Database Rights can apply even when there is no copyright over the Database. Database Rights can also apply when the Data is removed from the Database and is selected and arranged in a way that would not infringe any applicable copyright. 149 | 2.2 Rights not covered. 150 | a. This Document does not apply to computer programs used in the making or operation of the Database; 151 | b. This Document does not cover any patents over the Data or the Database. Please see Section 4.2 later in this Document for further details; and 152 | c. This Document does not cover any trade marks associated with the Database. Please see Section 4.3 later in this Document for further details. 153 | Users of this Database are cautioned that they may have to clear other rights or consult other licences. 154 | 2.3 Facts are free. The Rightsholder takes the position that factual information is not covered by Copyright. 
This Document however covers the Work in jurisdictions that may protect the factual information in the Work by Copyright, and to cover any information protected by Copyright that is contained in the Work. 155 | Part II Dedication to the public domain 156 | 3.0 Dedication, waiver, and licence of Copyright and Database Rights 157 | 3.1 Dedication of Copyright and Database Rights to the public domain. The Rightsholder by using this Document, dedicates the Work to the public domain for the benefit of the public and relinquishes all rights in Copyright and Database Rights over the Work. 158 | a. The Rightsholder realises that once these rights are relinquished, that the Rightsholder has no further rights in Copyright and Database Rights over the Work, and that the Work is free and open for others to Use. 159 | b. The Rightsholder intends for their relinquishment to cover all present and future rights in the Work under Copyright and Database Rights, whether they are vested or contingent rights, and that this relinquishment of rights covers all their heirs and successors. 160 | The above relinquishment of rights applies worldwide and includes media and formats now known or created in the future. 161 | 3.2 Waiver of rights and claims in Copyright and Database Rights when Section 3.1 dedication inapplicable. If the dedication in Section 3.1 does not apply in the relevant jurisdiction under Section 6.4, the Rightsholder waives any rights and claims that the Rightsholder may have or acquire in the future over the Work in 162 | a. Copyright; and 163 | b. Database Rights. 164 | To the extent possible in the relevant jurisdiction, the above waiver of rights and claims applies worldwide and includes media and formats now known or created in the future. The Rightsholder agrees not to assert the above rights and waives the right to enforce them over the Work. 165 | 3.3 Licence of Copyright and Database Rights when Sections 3.1 and 3.2 inapplicable. If the dedication and waiver in Sections 3.1 and 3.2 does not apply in the relevant jurisdiction under Section 6.4, the Rightsholder and You agree as follows 166 | a. The Licensor grants to You a worldwide, royalty-free, non-exclusive, licence to Use the Work for the duration of any applicable Copyright and Database Rights. These rights explicitly include commercial use, and do not exclude any field of endeavour. To the extent possible in the relevant jurisdiction, these rights may be exercised in all media and formats whether now known or created in the future. 167 | 3.4 Moral rights. This section covers moral rights, including the right to be identified as the author of the Work or to object to treatment that would otherwise prejudice the author’s honour and reputation, or any other derogatory treatment 168 | a. For jurisdictions allowing waiver of moral rights, Licensor waives all moral rights that Licensor may have in the Work to the fullest extent possible by the law of the relevant jurisdiction under Section 6.4; 169 | b. If waiver of moral rights under Section 3.4 a in the relevant jurisdiction is not possible, Licensor agrees not to assert any moral rights over the Work and waives all claims in moral rights to the fullest extent possible by the law of the relevant jurisdiction under Section 6.4; and 170 | c. For jurisdictions not allowing waiver or an agreement not to assert moral rights under Section 3.4 a and b, the author may retain their moral rights over the copyrighted aspects of the Work. 
171 | Please note that some jurisdictions do not allow for the waiver of moral rights, and so moral rights may still subsist over the work in some jurisdictions. 172 | 4.0 Relationship to other rights 173 | 4.1 No other contractual conditions. The Rightsholder makes this Work available to You without any other contractual obligations, either express or implied. Any Community Norms statement associated with the Work is not a contract and does not form part of this Document. 174 | 4.2 Relationship to patents. This Document does not grant You a licence for any patents that the Rightsholder may own. Users of this Database are cautioned that they may have to clear other rights or consult other licences. 175 | 4.3 Relationship to trade marks. This Document does not grant You a licence for any trade marks that the Rightsholder may own or that the Rightsholder may use to cover the Work. Users of this Database are cautioned that they may have to clear other rights or consult other licences. 176 | Part III General provisions 177 | 5.0 Warranties, disclaimer, and limitation of liability 178 | 5.1 The Work is provided by the Rightsholder “as is” and without any warranty of any kind, either express or implied, whether of title, of accuracy or completeness, of the presence of absence of errors, of fitness for purpose, or otherwise. Some jurisdictions do not allow the exclusion of implied warranties, so this exclusion may not apply to You. 179 | 5.2 Subject to any liability that may not be excluded or limited by law, the Rightsholder is not liable for, and expressly excludes, all liability for loss or damage however and whenever caused to anyone by any use under this Document, whether by You or by anyone else, and whether caused by any fault on the part of the Rightsholder or not. This exclusion of liability includes, but is not limited to, any special, incidental, consequential, punitive, or exemplary damages. This exclusion applies even if the Rightsholder has been advised of the possibility of such damages. 180 | 5.3 If liability may not be excluded by law, it is limited to actual and direct financial loss to the extent it is caused by proved negligence on the part of the Rightsholder. 181 | 6.0 General 182 | 6.1 If any provision of this Document is held to be invalid or unenforceable, that must not affect the validity or enforceability of the remainder of the terms of this Document. 183 | 6.2 This Document is the entire agreement between the parties with respect to the Work covered here. It replaces any earlier understandings, agreements or representations with respect to the Work not specified here. 184 | 6.3 This Document does not affect any rights that You or anyone else may independently have under any applicable law to make any use of this Work, including (for jurisdictions where this Document is a licence) fair dealing, fair use, database exceptions, or any other legally recognised limitation or exception to infringement of copyright or other applicable laws. 185 | 6.4 This Document takes effect in the relevant jurisdiction in which the Document terms are sought to be enforced. If the rights waived or granted under applicable law in the relevant jurisdiction includes additional rights not waived or granted under this Document, these additional rights are included in this Document in order to meet the intent of this Document. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IRNet 2 | Code for our ACL'19 accepted paper: [Towards Complex Text-to-SQL in Cross-Domain Database with Intermediate Representation](https://arxiv.org/pdf/1905.08205.pdf) 3 | 4 |
5 | 6 |
7 | 8 | ## Environment Setup 9 | 10 | * `Python3.6` 11 | * `Pytorch 0.4.0` or higher 12 | 13 | Install the Python dependencies via `pip install -r requirements.txt` once the Python and Pytorch environment is set up. 14 | 15 | ## Running Code 16 | 17 | #### Data preparation 18 | 19 | 20 | * Download the [Glove Embedding](https://nlp.stanford.edu/data/wordvecs/glove.42B.300d.zip) and put `glove.42B.300d` under the `./data/` directory 21 | * Download the [Pretrained IRNet](https://drive.google.com/open?id=1VoV28fneYss8HaZmoThGlvYU3A-aK31q) and put ` 22 | IRNet_pretrained.model` under the `./saved_model/` directory 23 | * Download the preprocessed train/dev datasets from [here](https://drive.google.com/open?id=1YFV1GoLivOMlmunKW0nkzefKULO4wtrn) and put `train.json`, `dev.json` and 24 | `tables.json` under the `./data/` directory 25 | 26 | ##### Generating train/dev data by yourself 27 | You can also process the original [Spider Data](https://drive.google.com/uc?export=download&id=11icoH_EA-NYb0OrPTdehRWm_d7-DIzWX) yourself. Download it, put `train.json`, `dev.json` and 28 | `tables.json` under the `./data/` directory, and follow the instructions in `./preprocess/`. 29 | 30 | #### Training 31 | 32 | Run `train.sh` to train IRNet. 33 | 34 | `sh train.sh [GPU_ID] [SAVE_FOLD]` 35 | 36 | #### Testing 37 | 38 | Run `eval.sh` to evaluate IRNet. 39 | 40 | `sh eval.sh [GPU_ID] [OUTPUT_FOLD]` 41 | 42 | 43 | #### Evaluation 44 | 45 | You can follow the general evaluation process described on the [Spider Page](https://github.com/taoyds/spider). 46 | 47 | 48 | ## Results 49 | | **Model** | Dev
Exact Set Match Accuracy | Test Exact Set Match
Accuracy | 50 | | ----------- | ------------------------------------- | -------------------------------------- | 51 | | IRNet | 53.2 | 46.7 | 52 | | IRNet+BERT(base) | 61.9 | **54.7** | 53 | 54 | 55 | ## Citation 56 | 57 | If you use IRNet, please cite the following work. 58 | 59 | ``` 60 | @inproceedings{GuoIRNet2019, 61 | author={Jiaqi Guo and Zecheng Zhan and Yan Gao and Yan Xiao and Jian-Guang Lou and Ting Liu and Dongmei Zhang}, 62 | title={Towards Complex Text-to-SQL in Cross-Domain Database with Intermediate Representation}, 63 | booktitle={Proceeding of the 57th Annual Meeting of the Association for Computational Linguistics (ACL)}, 64 | year={2019}, 65 | organization={Association for Computational Linguistics} 66 | } 67 | ``` 68 | 69 | ## Thanks 70 | We would like to thank [Tao Yu](https://taoyds.github.io/) and [Bo Pang](https://www.linkedin.com/in/bo-pang/) for running evaluations on our submitted models. 71 | We are also grateful to the flexible semantic parser [TranX](https://github.com/pcyin/tranX) that inspires our works. 72 | 73 | # Contributing 74 | 75 | This project welcomes contributions and suggestions. Most contributions require you to 76 | agree to a Contributor License Agreement (CLA) declaring that you have the right to, 77 | and actually do, grant us the rights to use your contribution. For details, visit 78 | https://cla.microsoft.com. 79 | 80 | When you submit a pull request, a CLA-bot will automatically determine whether you need 81 | to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the 82 | instructions provided by the bot. You will only need to do this once across all repositories using our CLA. 83 | 84 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 85 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 86 | or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 87 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets Microsoft's [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)) of a security vulnerability, please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). 
If possible, encrypt your message with our PGP key; please download it from the the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/5/25 7 | # @Author : Jiaqi&Zecheng 8 | # @File : __init__.py 9 | # @Software: PyCharm 10 | """ -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/5/27 7 | # @Author : Jiaqi&Zecheng 8 | # @File : eval.py 9 | # @Software: PyCharm 10 | """ 11 | 12 | 13 | import torch 14 | from src import args as arg 15 | from src import utils 16 | from src.models.model import IRNet 17 | from src.rule import semQL 18 | 19 | 20 | def evaluate(args): 21 | """ 22 | :param args: 23 | :return: 24 | """ 25 | 26 | grammar = semQL.Grammar() 27 | sql_data, table_data, val_sql_data,\ 28 | val_table_data= utils.load_dataset(args.dataset, use_small=args.toy) 29 | 30 | model = IRNet(args, grammar) 31 | 32 | if args.cuda: model.cuda() 33 | 34 | print('load pretrained model from %s'% (args.load_model)) 35 | pretrained_model = torch.load(args.load_model, 36 | map_location=lambda storage, loc: storage) 37 | import copy 38 | pretrained_modeled = copy.deepcopy(pretrained_model) 39 | for k in pretrained_model.keys(): 40 | if k not in model.state_dict().keys(): 41 | del pretrained_modeled[k] 42 | 43 | model.load_state_dict(pretrained_modeled) 44 | 45 | model.word_emb = utils.load_word_emb(args.glove_embed_path) 46 | 47 | json_datas, sketch_acc, acc = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data, 48 | beam_size=args.beam_size) 49 | print('Sketch Acc: %f, Acc: %f' % (sketch_acc, acc)) 50 | # utils.eval_acc(json_datas, val_sql_data) 51 | import json 52 | with open('./predict_lf.json', 'w') as f: 53 | json.dump(json_datas, f) 54 | 55 | if __name__ == '__main__': 56 | arg_parser = arg.init_arg_parser() 57 | args = arg.init_config(arg_parser) 58 | print(args) 59 | evaluate(args) 60 | -------------------------------------------------------------------------------- /eval.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | #!/bin/bash 5 | 6 | devices=$1 7 | save_name=$2 8 | 9 | CUDA_VISIBLE_DEVICES=$devices python -u eval.py --dataset ./data \ 10 | --glove_embed_path ./data/glove.42B.300d.txt \ 11 | --cuda \ 12 | --epoch 50 \ 13 | --loss_epoch_threshold 50 \ 14 | --sketch_loss_coefficie 1.0 \ 15 | --beam_size 5 \ 16 | --seed 90 \ 17 | --save ${save_name} \ 18 | --embed_size 300 \ 19 | --sentence_features \ 20 | --column_pointer \ 21 | --hidden_size 300 \ 22 | --lr_scheduler \ 23 | --lr_scheduler_gammar 0.5 \ 24 | --att_vec_size 300 \ 25 | --load_model ./saved_model/IRNet_pretrained.model 26 | 27 | python sem2SQL.py --data_path ./data --input_path predict_lf.json --output_path ${save_name} 28 | -------------------------------------------------------------------------------- /preprocess/data_process.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/5/24 7 | # @Author : Jiaqi&Zecheng 8 | # @File : data_process.py 9 | # @Software: PyCharm 10 | """ 11 | import json 12 | import argparse 13 | import nltk 14 | import os 15 | import pickle 16 | from utils import symbol_filter, re_lemma, fully_part_header, group_header, partial_header, num2year, group_symbol, group_values, group_digital 17 | from utils import AGG, wordnet_lemmatizer 18 | from utils import load_dataSets 19 | 20 | def process_datas(datas, args): 21 | """ 22 | 23 | :param datas: 24 | :param args: 25 | :return: 26 | """ 27 | with open(os.path.join(args.conceptNet, 'english_RelatedTo.pkl'), 'rb') as f: 28 | english_RelatedTo = pickle.load(f) 29 | 30 | with open(os.path.join(args.conceptNet, 'english_IsA.pkl'), 'rb') as f: 31 | english_IsA = pickle.load(f) 32 | 33 | # copy of the origin question_toks 34 | for d in datas: 35 | if 'origin_question_toks' not in d: 36 | d['origin_question_toks'] = d['question_toks'] 37 | 38 | for entry in datas: 39 | entry['question_toks'] = symbol_filter(entry['question_toks']) 40 | origin_question_toks = symbol_filter([x for x in entry['origin_question_toks'] if x.lower() != 'the']) 41 | question_toks = [wordnet_lemmatizer.lemmatize(x.lower()) for x in entry['question_toks'] if x.lower() != 'the'] 42 | 43 | entry['question_toks'] = question_toks 44 | 45 | table_names = [] 46 | table_names_pattern = [] 47 | 48 | for y in entry['table_names']: 49 | x = [wordnet_lemmatizer.lemmatize(x.lower()) for x in y.split(' ')] 50 | table_names.append(" ".join(x)) 51 | x = [re_lemma(x.lower()) for x in y.split(' ')] 52 | table_names_pattern.append(" ".join(x)) 53 | 54 | header_toks = [] 55 | header_toks_list = [] 56 | 57 | header_toks_pattern = [] 58 | header_toks_list_pattern = [] 59 | 60 | for y in entry['col_set']: 61 | x = [wordnet_lemmatizer.lemmatize(x.lower()) for x in y.split(' ')] 62 | header_toks.append(" ".join(x)) 63 | header_toks_list.append(x) 64 | 65 | x = [re_lemma(x.lower()) for x in y.split(' ')] 66 | header_toks_pattern.append(" ".join(x)) 67 | header_toks_list_pattern.append(x) 68 | 69 | num_toks = len(question_toks) 70 | idx = 0 71 | tok_concol = [] 72 | type_concol = [] 73 | nltk_result = nltk.pos_tag(question_toks) 74 | 75 | while idx < num_toks: 76 | 77 | # fully header 78 | end_idx, header = fully_part_header(question_toks, idx, num_toks, header_toks) 79 | if header: 80 | tok_concol.append(question_toks[idx: end_idx]) 81 | type_concol.append(["col"]) 82 | idx = end_idx 83 | continue 84 | 85 | # check for table 86 | end_idx, tname = group_header(question_toks, idx, num_toks, table_names) 87 | if tname: 88 | tok_concol.append(question_toks[idx: end_idx]) 89 | type_concol.append(["table"]) 90 | idx = end_idx 91 | continue 92 | 93 | # check for column 94 | end_idx, header = group_header(question_toks, idx, num_toks, header_toks) 95 | if header: 96 | tok_concol.append(question_toks[idx: end_idx]) 97 | type_concol.append(["col"]) 98 | idx = end_idx 99 | continue 100 | 101 | # check for partial column 102 | end_idx, tname = partial_header(question_toks, idx, header_toks_list) 103 | if tname: 104 | tok_concol.append(tname) 105 | type_concol.append(["col"]) 106 | idx = end_idx 107 | continue 108 | 109 | # check for aggregation 110 | end_idx, agg = group_header(question_toks, idx, num_toks, AGG) 111 | if agg: 112 | tok_concol.append(question_toks[idx: end_idx]) 113 | type_concol.append(["agg"]) 114 | idx = end_idx 115 | continue 116 | 117 | if nltk_result[idx][1] == 'RBR' or 
nltk_result[idx][1] == 'JJR': 118 | tok_concol.append([question_toks[idx]]) 119 | type_concol.append(['MORE']) 120 | idx += 1 121 | continue 122 | 123 | if nltk_result[idx][1] == 'RBS' or nltk_result[idx][1] == 'JJS': 124 | tok_concol.append([question_toks[idx]]) 125 | type_concol.append(['MOST']) 126 | idx += 1 127 | continue 128 | 129 | # string match for Time Format 130 | if num2year(question_toks[idx]): 131 | question_toks[idx] = 'year' 132 | end_idx, header = group_header(question_toks, idx, num_toks, header_toks) 133 | if header: 134 | tok_concol.append(question_toks[idx: end_idx]) 135 | type_concol.append(["col"]) 136 | idx = end_idx 137 | continue 138 | 139 | def get_concept_result(toks, graph): 140 | for begin_id in range(0, len(toks)): 141 | for r_ind in reversed(range(1, len(toks) + 1 - begin_id)): 142 | tmp_query = "_".join(toks[begin_id:r_ind]) 143 | if tmp_query in graph: 144 | mi = graph[tmp_query] 145 | for col in entry['col_set']: 146 | if col in mi: 147 | return col 148 | 149 | end_idx, symbol = group_symbol(question_toks, idx, num_toks) 150 | if symbol: 151 | tmp_toks = [x for x in question_toks[idx: end_idx]] 152 | assert len(tmp_toks) > 0, print(symbol, question_toks) 153 | pro_result = get_concept_result(tmp_toks, english_IsA) 154 | if pro_result is None: 155 | pro_result = get_concept_result(tmp_toks, english_RelatedTo) 156 | if pro_result is None: 157 | pro_result = "NONE" 158 | for tmp in tmp_toks: 159 | tok_concol.append([tmp]) 160 | type_concol.append([pro_result]) 161 | pro_result = "NONE" 162 | idx = end_idx 163 | continue 164 | 165 | end_idx, values = group_values(origin_question_toks, idx, num_toks) 166 | if values and (len(values) > 1 or question_toks[idx - 1] not in ['?', '.']): 167 | tmp_toks = [wordnet_lemmatizer.lemmatize(x) for x in question_toks[idx: end_idx] if x.isalnum() is True] 168 | assert len(tmp_toks) > 0, print(question_toks[idx: end_idx], values, question_toks, idx, end_idx) 169 | pro_result = get_concept_result(tmp_toks, english_IsA) 170 | if pro_result is None: 171 | pro_result = get_concept_result(tmp_toks, english_RelatedTo) 172 | if pro_result is None: 173 | pro_result = "NONE" 174 | for tmp in tmp_toks: 175 | tok_concol.append([tmp]) 176 | type_concol.append([pro_result]) 177 | pro_result = "NONE" 178 | idx = end_idx 179 | continue 180 | 181 | result = group_digital(question_toks, idx) 182 | if result is True: 183 | tok_concol.append(question_toks[idx: idx + 1]) 184 | type_concol.append(["value"]) 185 | idx += 1 186 | continue 187 | if question_toks[idx] == ['ha']: 188 | question_toks[idx] = ['have'] 189 | 190 | tok_concol.append([question_toks[idx]]) 191 | type_concol.append(['NONE']) 192 | idx += 1 193 | continue 194 | 195 | entry['question_arg'] = tok_concol 196 | entry['question_arg_type'] = type_concol 197 | entry['nltk_pos'] = nltk_result 198 | 199 | return datas 200 | 201 | 202 | if __name__ == '__main__': 203 | arg_parser = argparse.ArgumentParser() 204 | arg_parser.add_argument('--data_path', type=str, help='dataset', required=True) 205 | arg_parser.add_argument('--table_path', type=str, help='table dataset', required=True) 206 | arg_parser.add_argument('--output', type=str, help='output data') 207 | args = arg_parser.parse_args() 208 | args.conceptNet = './conceptNet' 209 | 210 | # loading dataSets 211 | datas, table = load_dataSets(args) 212 | 213 | # process datasets 214 | process_result = process_datas(datas, args) 215 | 216 | with open(args.output, 'w') as f: 217 | json.dump(datas, f) 218 | 219 | 220 | 
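As a brief illustration (not part of the original repository), the sketch below shows one way to inspect the fields that `process_datas` adds to each entry — `question_arg`, `question_arg_type` and `nltk_pos` — after this script has been run. The output path and the helper name are assumptions; substitute whatever was passed to `--output`.

```python
import json


def inspect_processed_example(path='./data/train_processed.json', index=0):
    """Print the grouped question tokens next to their assigned type tags."""
    # `path` is a hypothetical output of data_process.py (--output argument).
    with open(path) as f:
        processed = json.load(f)
    entry = processed[index]
    # question_arg holds grouped token spans; question_arg_type holds parallel
    # single-element tags such as ['col'], ['table'], ['agg'], ['MORE'], ['MOST'],
    # ['value'], ['NONE'], or a ConceptNet-matched column name.
    for toks, tag in zip(entry['question_arg'], entry['question_arg_type']):
        print(toks, tag)


if __name__ == '__main__':
    inspect_processed_example()
```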
-------------------------------------------------------------------------------- /preprocess/download_nltk.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/1/29 7 | # @Author : Jiaqi&Zecheng 8 | # @File : download_nltk.py 9 | # @Software: PyCharm 10 | """ 11 | import nltk 12 | nltk.download('averaged_perceptron_tagger') 13 | nltk.download('punkt') 14 | nltk.download('wordnet') 15 | 16 | -------------------------------------------------------------------------------- /preprocess/run_me.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | #!/bin/bash 5 | 6 | data=$1 7 | table_data=$2 8 | output=$3 9 | 10 | echo "Start download NLTK data" 11 | python download_nltk.py 12 | 13 | echo "Start process the origin Spider dataset" 14 | python data_process.py --data_path ${data} --table_path ${table_data} --output "process_data.json" 15 | 16 | echo "Start generate SemQL from SQL" 17 | python sql2SemQL.py --data_path process_data.json --table_path ${table_data} --output ${data} 18 | 19 | rm process_data.json 20 | -------------------------------------------------------------------------------- /preprocess/sql2SemQL.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/5/24 7 | # @Author : Jiaqi&Zecheng 8 | # @File : sql2SemQL.py 9 | # @Software: PyCharm 10 | """ 11 | 12 | import argparse 13 | import json 14 | import sys 15 | 16 | import copy 17 | from utils import load_dataSets 18 | 19 | sys.path.append("..") 20 | from src.rule.semQL import Root1, Root, N, A, C, T, Sel, Sup, Filter, Order 21 | 22 | class Parser: 23 | def __init__(self): 24 | self.copy_selec = None 25 | self.sel_result = [] 26 | self.colSet = set() 27 | 28 | def _init_rule(self): 29 | self.copy_selec = None 30 | self.colSet = set() 31 | 32 | def _parse_root(self, sql): 33 | """ 34 | parsing the sql by the grammar 35 | R ::= Select | Select Filter | Select Order | ... 
| 36 | :return: [R(), states] 37 | """ 38 | use_sup, use_ord, use_fil = True, True, False 39 | 40 | if sql['sql']['limit'] == None: 41 | use_sup = False 42 | 43 | if sql['sql']['orderBy'] == []: 44 | use_ord = False 45 | elif sql['sql']['limit'] != None: 46 | use_ord = False 47 | 48 | # check the where and having 49 | if sql['sql']['where'] != [] or \ 50 | sql['sql']['having'] != []: 51 | use_fil = True 52 | 53 | if use_fil and use_sup: 54 | return [Root(0)], ['FILTER', 'SUP', 'SEL'] 55 | elif use_fil and use_ord: 56 | return [Root(1)], ['ORDER', 'FILTER', 'SEL'] 57 | elif use_sup: 58 | return [Root(2)], ['SUP', 'SEL'] 59 | elif use_fil: 60 | return [Root(3)], ['FILTER', 'SEL'] 61 | elif use_ord: 62 | return [Root(4)], ['ORDER', 'SEL'] 63 | else: 64 | return [Root(5)], ['SEL'] 65 | 66 | def _parser_column0(self, sql, select): 67 | """ 68 | Find table of column '*' 69 | :return: T(table_id) 70 | """ 71 | if len(sql['sql']['from']['table_units']) == 1: 72 | return T(sql['sql']['from']['table_units'][0][1]) 73 | else: 74 | table_list = [] 75 | for tmp_t in sql['sql']['from']['table_units']: 76 | if type(tmp_t[1]) == int: 77 | table_list.append(tmp_t[1]) 78 | table_set, other_set = set(table_list), set() 79 | for sel_p in select: 80 | if sel_p[1][1][1] != 0: 81 | other_set.add(sql['col_table'][sel_p[1][1][1]]) 82 | 83 | if len(sql['sql']['where']) == 1: 84 | other_set.add(sql['col_table'][sql['sql']['where'][0][2][1][1]]) 85 | elif len(sql['sql']['where']) == 3: 86 | other_set.add(sql['col_table'][sql['sql']['where'][0][2][1][1]]) 87 | other_set.add(sql['col_table'][sql['sql']['where'][2][2][1][1]]) 88 | elif len(sql['sql']['where']) == 5: 89 | other_set.add(sql['col_table'][sql['sql']['where'][0][2][1][1]]) 90 | other_set.add(sql['col_table'][sql['sql']['where'][2][2][1][1]]) 91 | other_set.add(sql['col_table'][sql['sql']['where'][4][2][1][1]]) 92 | table_set = table_set - other_set 93 | if len(table_set) == 1: 94 | return T(list(table_set)[0]) 95 | elif len(table_set) == 0 and sql['sql']['groupBy'] != []: 96 | return T(sql['col_table'][sql['sql']['groupBy'][0][1]]) 97 | else: 98 | question = sql['question'] 99 | self.sel_result.append(question) 100 | print('column * table error') 101 | return T(sql['sql']['from']['table_units'][0][1]) 102 | 103 | def _parse_select(self, sql): 104 | """ 105 | parsing the sql by the grammar 106 | Select ::= A | AA | AAA | ... 
| 107 | A ::= agg column table 108 | :return: [Sel(), states] 109 | """ 110 | result = [] 111 | select = sql['sql']['select'][1] 112 | result.append(Sel(0)) 113 | result.append(N(len(select) - 1)) 114 | 115 | for sel in select: 116 | result.append(A(sel[0])) 117 | self.colSet.add(sql['col_set'].index(sql['names'][sel[1][1][1]])) 118 | result.append(C(sql['col_set'].index(sql['names'][sel[1][1][1]]))) 119 | # now check for the situation with * 120 | if sel[1][1][1] == 0: 121 | result.append(self._parser_column0(sql, select)) 122 | else: 123 | result.append(T(sql['col_table'][sel[1][1][1]])) 124 | if not self.copy_selec: 125 | self.copy_selec = [copy.deepcopy(result[-2]), copy.deepcopy(result[-1])] 126 | 127 | return result, None 128 | 129 | def _parse_sup(self, sql): 130 | """ 131 | parsing the sql by the grammar 132 | Sup ::= Most A | Least A 133 | A ::= agg column table 134 | :return: [Sup(), states] 135 | """ 136 | result = [] 137 | select = sql['sql']['select'][1] 138 | if sql['sql']['limit'] == None: 139 | return result, None 140 | if sql['sql']['orderBy'][0] == 'desc': 141 | result.append(Sup(0)) 142 | else: 143 | result.append(Sup(1)) 144 | 145 | result.append(A(sql['sql']['orderBy'][1][0][1][0])) 146 | self.colSet.add(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]])) 147 | result.append(C(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]]))) 148 | if sql['sql']['orderBy'][1][0][1][1] == 0: 149 | result.append(self._parser_column0(sql, select)) 150 | else: 151 | result.append(T(sql['col_table'][sql['sql']['orderBy'][1][0][1][1]])) 152 | return result, None 153 | 154 | def _parse_filter(self, sql): 155 | """ 156 | parsing the sql by the grammar 157 | Filter ::= and Filter Filter | ... | 158 | A ::= agg column table 159 | :return: [Filter(), states] 160 | """ 161 | result = [] 162 | # check the where 163 | if sql['sql']['where'] != [] and sql['sql']['having'] != []: 164 | result.append(Filter(0)) 165 | 166 | if sql['sql']['where'] != []: 167 | # check the not and/or 168 | if len(sql['sql']['where']) == 1: 169 | result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql)) 170 | elif len(sql['sql']['where']) == 3: 171 | if sql['sql']['where'][1] == 'or': 172 | result.append(Filter(1)) 173 | else: 174 | result.append(Filter(0)) 175 | result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql)) 176 | result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql)) 177 | else: 178 | if sql['sql']['where'][1] == 'and' and sql['sql']['where'][3] == 'and': 179 | result.append(Filter(0)) 180 | result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql)) 181 | result.append(Filter(0)) 182 | result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql)) 183 | result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql)) 184 | elif sql['sql']['where'][1] == 'and' and sql['sql']['where'][3] == 'or': 185 | result.append(Filter(1)) 186 | result.append(Filter(0)) 187 | result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql)) 188 | result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql)) 189 | result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql)) 190 | elif sql['sql']['where'][1] == 'or' and sql['sql']['where'][3] == 'and': 191 | result.append(Filter(1)) 192 | result.append(Filter(0)) 193 | result.extend(self.parse_one_condition(sql['sql']['where'][2], 
sql['names'], sql)) 194 | result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql)) 195 | result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql)) 196 | else: 197 | result.append(Filter(1)) 198 | result.append(Filter(1)) 199 | result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql)) 200 | result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql)) 201 | result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql)) 202 | 203 | # check having 204 | if sql['sql']['having'] != []: 205 | result.extend(self.parse_one_condition(sql['sql']['having'][0], sql['names'], sql)) 206 | return result, None 207 | 208 | def _parse_order(self, sql): 209 | """ 210 | parsing the sql by the grammar 211 | Order ::= asc A | desc A 212 | A ::= agg column table 213 | :return: [Order(), states] 214 | """ 215 | result = [] 216 | 217 | if 'order' not in sql['query_toks_no_value'] or 'by' not in sql['query_toks_no_value']: 218 | return result, None 219 | elif 'limit' in sql['query_toks_no_value']: 220 | return result, None 221 | else: 222 | if sql['sql']['orderBy'] == []: 223 | return result, None 224 | else: 225 | select = sql['sql']['select'][1] 226 | if sql['sql']['orderBy'][0] == 'desc': 227 | result.append(Order(0)) 228 | else: 229 | result.append(Order(1)) 230 | result.append(A(sql['sql']['orderBy'][1][0][1][0])) 231 | self.colSet.add(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]])) 232 | result.append(C(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]]))) 233 | if sql['sql']['orderBy'][1][0][1][1] == 0: 234 | result.append(self._parser_column0(sql, select)) 235 | else: 236 | result.append(T(sql['col_table'][sql['sql']['orderBy'][1][0][1][1]])) 237 | return result, None 238 | 239 | 240 | def parse_one_condition(self, sql_condit, names, sql): 241 | result = [] 242 | # check if V(root) 243 | nest_query = True 244 | if type(sql_condit[3]) != dict: 245 | nest_query = False 246 | 247 | if sql_condit[0] == True: 248 | if sql_condit[1] == 9: 249 | # not like only with values 250 | fil = Filter(10) 251 | elif sql_condit[1] == 8: 252 | # not in with Root 253 | fil = Filter(19) 254 | else: 255 | print(sql_condit[1]) 256 | raise NotImplementedError("not implement for the others FIL") 257 | else: 258 | # check for Filter (<,=,>,!=,between, >=, <=, ...) 
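            # single_map / nested_map translate the operator index stored in
            # sql_condit[1] (Spider WHERE_OPS order: between, =, >, <, >=, <=, !=)
            # into the corresponding SemQL Filter id; nested_map applies when the
            # condition value is itself a sub-query (nest_query above), single_map otherwise.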
259 | single_map = {1:8,2:2,3:5,4:4,5:7,6:6,7:3} 260 | nested_map = {1:15,2:11,3:13,4:12,5:16,6:17,7:14} 261 | if sql_condit[1] in [1, 2, 3, 4, 5, 6, 7]: 262 | if nest_query == False: 263 | fil = Filter(single_map[sql_condit[1]]) 264 | else: 265 | fil = Filter(nested_map[sql_condit[1]]) 266 | elif sql_condit[1] == 9: 267 | fil = Filter(9) 268 | elif sql_condit[1] == 8: 269 | fil = Filter(18) 270 | else: 271 | print(sql_condit[1]) 272 | raise NotImplementedError("not implement for the others FIL") 273 | 274 | result.append(fil) 275 | result.append(A(sql_condit[2][1][0])) 276 | self.colSet.add(sql['col_set'].index(sql['names'][sql_condit[2][1][1]])) 277 | result.append(C(sql['col_set'].index(sql['names'][sql_condit[2][1][1]]))) 278 | if sql_condit[2][1][1] == 0: 279 | select = sql['sql']['select'][1] 280 | result.append(self._parser_column0(sql, select)) 281 | else: 282 | result.append(T(sql['col_table'][sql_condit[2][1][1]])) 283 | 284 | # check for the nested value 285 | if type(sql_condit[3]) == dict: 286 | nest_query = {} 287 | nest_query['names'] = names 288 | nest_query['query_toks_no_value'] = "" 289 | nest_query['sql'] = sql_condit[3] 290 | nest_query['col_table'] = sql['col_table'] 291 | nest_query['col_set'] = sql['col_set'] 292 | nest_query['table_names'] = sql['table_names'] 293 | nest_query['question'] = sql['question'] 294 | nest_query['query'] = sql['query'] 295 | nest_query['keys'] = sql['keys'] 296 | result.extend(self.parser(nest_query)) 297 | 298 | return result 299 | 300 | def _parse_step(self, state, sql): 301 | 302 | if state == 'ROOT': 303 | return self._parse_root(sql) 304 | 305 | if state == 'SEL': 306 | return self._parse_select(sql) 307 | 308 | elif state == 'SUP': 309 | return self._parse_sup(sql) 310 | 311 | elif state == 'FILTER': 312 | return self._parse_filter(sql) 313 | 314 | elif state == 'ORDER': 315 | return self._parse_order(sql) 316 | else: 317 | raise NotImplementedError("Not the right state") 318 | 319 | def full_parse(self, query): 320 | sql = query['sql'] 321 | nest_query = {} 322 | nest_query['names'] = query['names'] 323 | nest_query['query_toks_no_value'] = "" 324 | nest_query['col_table'] = query['col_table'] 325 | nest_query['col_set'] = query['col_set'] 326 | nest_query['table_names'] = query['table_names'] 327 | nest_query['question'] = query['question'] 328 | nest_query['query'] = query['query'] 329 | nest_query['keys'] = query['keys'] 330 | 331 | if sql['intersect']: 332 | results = [Root1(0)] 333 | nest_query['sql'] = sql['intersect'] 334 | results.extend(self.parser(query)) 335 | results.extend(self.parser(nest_query)) 336 | return results 337 | 338 | if sql['union']: 339 | results = [Root1(1)] 340 | nest_query['sql'] = sql['union'] 341 | results.extend(self.parser(query)) 342 | results.extend(self.parser(nest_query)) 343 | return results 344 | 345 | if sql['except']: 346 | results = [Root1(2)] 347 | nest_query['sql'] = sql['except'] 348 | results.extend(self.parser(query)) 349 | results.extend(self.parser(nest_query)) 350 | return results 351 | 352 | results = [Root1(3)] 353 | results.extend(self.parser(query)) 354 | 355 | return results 356 | 357 | def parser(self, query): 358 | stack = ["ROOT"] 359 | result = [] 360 | while len(stack) > 0: 361 | state = stack.pop() 362 | step_result, step_state = self._parse_step(state, query) 363 | result.extend(step_result) 364 | if step_state: 365 | stack.extend(step_state) 366 | return result 367 | 368 | if __name__ == '__main__': 369 | arg_parser = argparse.ArgumentParser() 370 | 
arg_parser.add_argument('--data_path', type=str, help='dataset', required=True) 371 | arg_parser.add_argument('--table_path', type=str, help='table dataset', required=True) 372 | arg_parser.add_argument('--output', type=str, help='output data', required=True) 373 | args = arg_parser.parse_args() 374 | 375 | parser = Parser() 376 | 377 | # loading dataSets 378 | datas, table = load_dataSets(args) 379 | processed_data = [] 380 | 381 | for i, d in enumerate(datas): 382 | if len(datas[i]['sql']['select'][1]) > 5: 383 | continue 384 | r = parser.full_parse(datas[i]) 385 | datas[i]['rule_label'] = " ".join([str(x) for x in r]) 386 | processed_data.append(datas[i]) 387 | 388 | print('Finished %s datas and failed %s datas' % (len(processed_data), len(datas) - len(processed_data))) 389 | with open(args.output, 'w', encoding='utf8') as f: 390 | f.write(json.dumps(processed_data)) 391 | 392 | -------------------------------------------------------------------------------- /preprocess/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/5/24 7 | # @Author : Jiaqi&Zecheng 8 | # @File : utils.py 9 | # @Software: PyCharm 10 | """ 11 | import os 12 | import json 13 | from pattern.en import lemma 14 | from nltk.stem import WordNetLemmatizer 15 | 16 | VALUE_FILTER = ['what', 'how', 'list', 'give', 'show', 'find', 'id', 'order', 'when'] 17 | AGG = ['average', 'sum', 'max', 'min', 'minimum', 'maximum', 'between'] 18 | 19 | wordnet_lemmatizer = WordNetLemmatizer() 20 | 21 | def load_dataSets(args): 22 | with open(args.table_path, 'r', encoding='utf8') as f: 23 | table_datas = json.load(f) 24 | with open(args.data_path, 'r', encoding='utf8') as f: 25 | datas = json.load(f) 26 | 27 | output_tab = {} 28 | tables = {} 29 | tabel_name = set() 30 | for i in range(len(table_datas)): 31 | table = table_datas[i] 32 | temp = {} 33 | temp['col_map'] = table['column_names'] 34 | temp['table_names'] = table['table_names'] 35 | tmp_col = [] 36 | for cc in [x[1] for x in table['column_names']]: 37 | if cc not in tmp_col: 38 | tmp_col.append(cc) 39 | table['col_set'] = tmp_col 40 | db_name = table['db_id'] 41 | tabel_name.add(db_name) 42 | table['schema_content'] = [col[1] for col in table['column_names']] 43 | table['col_table'] = [col[0] for col in table['column_names']] 44 | output_tab[db_name] = temp 45 | tables[db_name] = table 46 | 47 | for d in datas: 48 | d['names'] = tables[d['db_id']]['schema_content'] 49 | d['table_names'] = tables[d['db_id']]['table_names'] 50 | d['col_set'] = tables[d['db_id']]['col_set'] 51 | d['col_table'] = tables[d['db_id']]['col_table'] 52 | keys = {} 53 | for kv in tables[d['db_id']]['foreign_keys']: 54 | keys[kv[0]] = kv[1] 55 | keys[kv[1]] = kv[0] 56 | for id_k in tables[d['db_id']]['primary_keys']: 57 | keys[id_k] = id_k 58 | d['keys'] = keys 59 | return datas, tables 60 | 61 | def group_header(toks, idx, num_toks, header_toks): 62 | for endIdx in reversed(range(idx + 1, num_toks+1)): 63 | sub_toks = toks[idx: endIdx] 64 | sub_toks = " ".join(sub_toks) 65 | if sub_toks in header_toks: 66 | return endIdx, sub_toks 67 | return idx, None 68 | 69 | def fully_part_header(toks, idx, num_toks, header_toks): 70 | for endIdx in reversed(range(idx + 1, num_toks+1)): 71 | sub_toks = toks[idx: endIdx] 72 | if len(sub_toks) > 1: 73 | sub_toks = " ".join(sub_toks) 74 | if sub_toks in header_toks: 75 | return endIdx, sub_toks 
76 |     return idx, None
77 | 
78 | def partial_header(toks, idx, header_toks):
79 |     def check_in(list_one, list_two):
80 |         if len(set(list_one) & set(list_two)) == len(list_one) and (len(list_two) <= 3):
81 |             return True
82 |     for endIdx in reversed(range(idx + 1, len(toks))):
83 |         sub_toks = toks[idx: min(endIdx, len(toks))]
84 |         if len(sub_toks) > 1:
85 |             flag_count = 0
86 |             tmp_heads = None
87 |             for heads in header_toks:
88 |                 if check_in(sub_toks, heads):
89 |                     flag_count += 1
90 |                     tmp_heads = heads
91 |             if flag_count == 1:
92 |                 return endIdx, tmp_heads
93 |     return idx, None
94 | 
95 | def symbol_filter(questions):
96 |     question_tmp_q = []
97 |     for q_id, q_val in enumerate(questions):
98 |         if len(q_val) > 2 and q_val[0] in ["'", '"', '`', '‘', '“'] and q_val[-1] in ["'", '"', '`', '’']:
99 |             question_tmp_q.append("'")
100 |             question_tmp_q += ["".join(q_val[1:-1])]
101 |             question_tmp_q.append("'")
102 |         elif len(q_val) > 2 and q_val[0] in ["'", '"', '`', '‘']:
103 |             question_tmp_q.append("'")
104 |             question_tmp_q += ["".join(q_val[1:])]
105 |         elif len(q_val) > 2 and q_val[-1] in ["'", '"', '`', '’']:
106 |             question_tmp_q += ["".join(q_val[0:-1])]
107 |             question_tmp_q.append("'")
108 |         elif q_val in ["'", '"', '`', '‘', '’', '``', "''"]:
109 |             question_tmp_q += ["'"]
110 |         else:
111 |             question_tmp_q += [q_val]
112 |     return question_tmp_q
113 | 
114 | 
115 | def group_values(toks, idx, num_toks):
116 |     def check_isupper(tok_lists):
117 |         for tok_one in tok_lists:
118 |             if tok_one[0].isupper() is False:
119 |                 return False
120 |         return True
121 | 
122 |     for endIdx in reversed(range(idx + 1, num_toks + 1)):
123 |         sub_toks = toks[idx: endIdx]
124 | 
125 |         if len(sub_toks) > 1 and check_isupper(sub_toks) is True:
126 |             return endIdx, sub_toks
127 |         if len(sub_toks) == 1:
128 |             if sub_toks[0][0].isupper() and sub_toks[0].lower() not in VALUE_FILTER and \
129 |                     sub_toks[0].lower().isalnum() is True:
130 |                 return endIdx, sub_toks
131 |     return idx, None
132 | 
133 | 
134 | def group_digital(toks, idx):
135 |     test = toks[idx].replace(':', '')
136 |     test = test.replace('.', '')
137 |     if test.isdigit():
138 |         return True
139 |     else:
140 |         return False
141 | 
142 | def group_symbol(toks, idx, num_toks):
143 |     if toks[idx-1] == "'":
144 |         for i in range(0, min(3, num_toks-idx)):
145 |             if toks[i + idx] == "'":
146 |                 return i + idx, toks[idx:i+idx]
147 |     return idx, None
148 | 
149 | 
150 | def num2year(tok):
151 |     if len(str(tok)) == 4 and str(tok).isdigit() and int(str(tok)[:2]) < 22 and int(str(tok)[:2]) > 15:
152 |         return True
153 |     return False
154 | 
155 | def set_header(toks, header_toks, tok_concol, idx, num_toks):
156 |     def check_in(list_one, list_two):
157 |         if set(list_one) == set(list_two):
158 |             return True
159 |     for endIdx in range(idx, num_toks):
160 |         toks += tok_concol[endIdx]
161 |         if len(tok_concol[endIdx]) > 1:
162 |             break
163 |     for heads in header_toks:
164 |         if check_in(toks, heads):
165 |             return heads
166 |     return None
167 | 
168 | def re_lemma(string):
169 |     lema = lemma(string.lower())
170 |     if len(lema) > 0:
171 |         return lema
172 |     else:
173 |         return string.lower()
174 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license.
3 | 4 | nltk==3.4 5 | pattern 6 | numpy==1.14.0 7 | pytorch-pretrained-bert==0.5.1 8 | tqdm==4.31.1 -------------------------------------------------------------------------------- /sem2SQL.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/5/27 7 | # @Author : Jiaqi&Zecheng 8 | # @File : sem2SQL.py 9 | # @Software: PyCharm 10 | """ 11 | 12 | import argparse 13 | import traceback 14 | 15 | from src.rule.graph import Graph 16 | from src.rule.semQL import Sup, Sel, Order, Root, Filter, A, N, C, T, Root1 17 | from src.rule.sem_utils import alter_inter, alter_not_in, alter_column0, load_dataSets 18 | 19 | 20 | def split_logical_form(lf): 21 | indexs = [i+1 for i, letter in enumerate(lf) if letter == ')'] 22 | indexs.insert(0, 0) 23 | components = list() 24 | for i in range(1, len(indexs)): 25 | components.append(lf[indexs[i-1]:indexs[i]].strip()) 26 | return components 27 | 28 | 29 | def pop_front(array): 30 | if len(array) == 0: 31 | return 'None' 32 | return array.pop(0) 33 | 34 | 35 | def is_end(components, transformed_sql, is_root_processed): 36 | end = False 37 | c = pop_front(components) 38 | c_instance = eval(c) 39 | 40 | if isinstance(c_instance, Root) and is_root_processed: 41 | # intersect, union, except 42 | end = True 43 | elif isinstance(c_instance, Filter): 44 | if 'where' not in transformed_sql: 45 | end = True 46 | else: 47 | num_conjunction = 0 48 | for f in transformed_sql['where']: 49 | if isinstance(f, str) and (f == 'and' or f == 'or'): 50 | num_conjunction += 1 51 | current_filters = len(transformed_sql['where']) 52 | valid_filters = current_filters - num_conjunction 53 | if valid_filters >= num_conjunction + 1: 54 | end = True 55 | elif isinstance(c_instance, Order): 56 | if 'order' not in transformed_sql: 57 | end = True 58 | elif len(transformed_sql['order']) == 0: 59 | end = False 60 | else: 61 | end = True 62 | elif isinstance(c_instance, Sup): 63 | if 'sup' not in transformed_sql: 64 | end = True 65 | elif len(transformed_sql['sup']) == 0: 66 | end = False 67 | else: 68 | end = True 69 | components.insert(0, c) 70 | return end 71 | 72 | 73 | def _transform(components, transformed_sql, col_set, table_names, schema): 74 | processed_root = False 75 | current_table = schema 76 | 77 | while len(components) > 0: 78 | if is_end(components, transformed_sql, processed_root): 79 | break 80 | c = pop_front(components) 81 | c_instance = eval(c) 82 | if isinstance(c_instance, Root): 83 | processed_root = True 84 | transformed_sql['select'] = list() 85 | if c_instance.id_c == 0: 86 | transformed_sql['where'] = list() 87 | transformed_sql['sup'] = list() 88 | elif c_instance.id_c == 1: 89 | transformed_sql['where'] = list() 90 | transformed_sql['order'] = list() 91 | elif c_instance.id_c == 2: 92 | transformed_sql['sup'] = list() 93 | elif c_instance.id_c == 3: 94 | transformed_sql['where'] = list() 95 | elif c_instance.id_c == 4: 96 | transformed_sql['order'] = list() 97 | elif isinstance(c_instance, Sel): 98 | continue 99 | elif isinstance(c_instance, N): 100 | for i in range(c_instance.id_c + 1): 101 | agg = eval(pop_front(components)) 102 | column = eval(pop_front(components)) 103 | _table = pop_front(components) 104 | table = eval(_table) 105 | if not isinstance(table, T): 106 | table = None 107 | components.insert(0, _table) 108 | assert isinstance(agg, A) and isinstance(column, C) 109 | 110 | 
transformed_sql['select'].append(( 111 | agg.production.split()[1], 112 | replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table) if table is not None else col_set[column.id_c], 113 | table_names[table.id_c] if table is not None else table 114 | )) 115 | 116 | elif isinstance(c_instance, Sup): 117 | transformed_sql['sup'].append(c_instance.production.split()[1]) 118 | agg = eval(pop_front(components)) 119 | column = eval(pop_front(components)) 120 | _table = pop_front(components) 121 | table = eval(_table) 122 | if not isinstance(table, T): 123 | table = None 124 | components.insert(0, _table) 125 | assert isinstance(agg, A) and isinstance(column, C) 126 | 127 | transformed_sql['sup'].append(agg.production.split()[1]) 128 | if table: 129 | fix_col_id = replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table) 130 | else: 131 | fix_col_id = col_set[column.id_c] 132 | raise RuntimeError('not found table !!!!') 133 | transformed_sql['sup'].append(fix_col_id) 134 | transformed_sql['sup'].append(table_names[table.id_c] if table is not None else table) 135 | 136 | elif isinstance(c_instance, Order): 137 | transformed_sql['order'].append(c_instance.production.split()[1]) 138 | agg = eval(pop_front(components)) 139 | column = eval(pop_front(components)) 140 | _table = pop_front(components) 141 | table = eval(_table) 142 | if not isinstance(table, T): 143 | table = None 144 | components.insert(0, _table) 145 | assert isinstance(agg, A) and isinstance(column, C) 146 | transformed_sql['order'].append(agg.production.split()[1]) 147 | transformed_sql['order'].append(replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table)) 148 | transformed_sql['order'].append(table_names[table.id_c] if table is not None else table) 149 | 150 | elif isinstance(c_instance, Filter): 151 | op = c_instance.production.split()[1] 152 | if op == 'and' or op == 'or': 153 | transformed_sql['where'].append(op) 154 | else: 155 | # No Supquery 156 | agg = eval(pop_front(components)) 157 | column = eval(pop_front(components)) 158 | _table = pop_front(components) 159 | table = eval(_table) 160 | if not isinstance(table, T): 161 | table = None 162 | components.insert(0, _table) 163 | assert isinstance(agg, A) and isinstance(column, C) 164 | if len(c_instance.production.split()) == 3: 165 | if table: 166 | fix_col_id = replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table) 167 | else: 168 | fix_col_id = col_set[column.id_c] 169 | raise RuntimeError('not found table !!!!') 170 | transformed_sql['where'].append(( 171 | op, 172 | agg.production.split()[1], 173 | fix_col_id, 174 | table_names[table.id_c] if table is not None else table, 175 | None 176 | )) 177 | else: 178 | # Subquery 179 | new_dict = dict() 180 | new_dict['sql'] = transformed_sql['sql'] 181 | transformed_sql['where'].append(( 182 | op, 183 | agg.production.split()[1], 184 | replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table), 185 | table_names[table.id_c] if table is not None else table, 186 | _transform(components, new_dict, col_set, table_names, schema) 187 | )) 188 | 189 | return transformed_sql 190 | 191 | 192 | def transform(query, schema, origin=None): 193 | preprocess_schema(schema) 194 | if origin is None: 195 | lf = query['model_result_replace'] 196 | else: 197 | lf = origin 198 | # lf = query['rule_label'] 199 | col_set = query['col_set'] 200 | table_names = query['table_names'] 201 
| current_table = schema 202 | 203 | current_table['schema_content_clean'] = [x[1] for x in current_table['column_names']] 204 | current_table['schema_content'] = [x[1] for x in current_table['column_names_original']] 205 | 206 | components = split_logical_form(lf) 207 | 208 | transformed_sql = dict() 209 | transformed_sql['sql'] = query 210 | c = pop_front(components) 211 | c_instance = eval(c) 212 | assert isinstance(c_instance, Root1) 213 | if c_instance.id_c == 0: 214 | transformed_sql['intersect'] = dict() 215 | transformed_sql['intersect']['sql'] = query 216 | 217 | _transform(components, transformed_sql, col_set, table_names, schema) 218 | _transform(components, transformed_sql['intersect'], col_set, table_names, schema) 219 | elif c_instance.id_c == 1: 220 | transformed_sql['union'] = dict() 221 | transformed_sql['union']['sql'] = query 222 | _transform(components, transformed_sql, col_set, table_names, schema) 223 | _transform(components, transformed_sql['union'], col_set, table_names, schema) 224 | elif c_instance.id_c == 2: 225 | transformed_sql['except'] = dict() 226 | transformed_sql['except']['sql'] = query 227 | _transform(components, transformed_sql, col_set, table_names, schema) 228 | _transform(components, transformed_sql['except'], col_set, table_names, schema) 229 | else: 230 | _transform(components, transformed_sql, col_set, table_names, schema) 231 | 232 | parse_result = to_str(transformed_sql, 1, schema) 233 | 234 | parse_result = parse_result.replace('\t', '') 235 | return [parse_result] 236 | 237 | def col_to_str(agg, col, tab, table_names, N=1): 238 | _col = col.replace(' ', '_') 239 | if agg == 'none': 240 | if tab not in table_names: 241 | table_names[tab] = 'T' + str(len(table_names) + N) 242 | table_alias = table_names[tab] 243 | if col == '*': 244 | return '*' 245 | return '%s.%s' % (table_alias, _col) 246 | else: 247 | if col == '*': 248 | if tab is not None and tab not in table_names: 249 | table_names[tab] = 'T' + str(len(table_names) + N) 250 | return '%s(%s)' % (agg, _col) 251 | else: 252 | if tab not in table_names: 253 | table_names[tab] = 'T' + str(len(table_names) + N) 254 | table_alias = table_names[tab] 255 | return '%s(%s.%s)' % (agg, table_alias, _col) 256 | 257 | 258 | def infer_from_clause(table_names, schema, columns): 259 | tables = list(table_names.keys()) 260 | # print(table_names) 261 | start_table = None 262 | end_table = None 263 | join_clause = list() 264 | if len(tables) == 1: 265 | join_clause.append((tables[0], table_names[tables[0]])) 266 | elif len(tables) == 2: 267 | use_graph = True 268 | # print(schema['graph'].vertices) 269 | for t in tables: 270 | if t not in schema['graph'].vertices: 271 | use_graph = False 272 | break 273 | if use_graph: 274 | start_table = tables[0] 275 | end_table = tables[1] 276 | _tables = list(schema['graph'].dijkstra(tables[0], tables[1])) 277 | # print('Two tables: ', _tables) 278 | max_key = 1 279 | for t, k in table_names.items(): 280 | _k = int(k[1:]) 281 | if _k > max_key: 282 | max_key = _k 283 | for t in _tables: 284 | if t not in table_names: 285 | table_names[t] = 'T' + str(max_key + 1) 286 | max_key += 1 287 | join_clause.append((t, table_names[t],)) 288 | else: 289 | join_clause = list() 290 | for t in tables: 291 | join_clause.append((t, table_names[t],)) 292 | else: 293 | # > 2 294 | # print('More than 2 table') 295 | for t in tables: 296 | join_clause.append((t, table_names[t],)) 297 | 298 | if len(join_clause) >= 3: 299 | star_table = None 300 | for agg, col, tab in columns: 301 | if col 
== '*': 302 | star_table = tab 303 | break 304 | if star_table is not None: 305 | star_table_count = 0 306 | for agg, col, tab in columns: 307 | if tab == star_table and col != '*': 308 | star_table_count += 1 309 | if star_table_count == 0 and ((end_table is None or end_table == star_table) or (start_table is None or start_table == star_table)): 310 | # Remove the table the rest tables still can join without star_table 311 | new_join_clause = list() 312 | for t in join_clause: 313 | if t[0] != star_table: 314 | new_join_clause.append(t) 315 | join_clause = new_join_clause 316 | 317 | join_clause = ' JOIN '.join(['%s AS %s' % (jc[0], jc[1]) for jc in join_clause]) 318 | return 'FROM ' + join_clause 319 | 320 | def replace_col_with_original_col(query, col, current_table): 321 | # print(query, col) 322 | if query == '*': 323 | return query 324 | 325 | cur_table = col 326 | cur_col = query 327 | single_final_col = None 328 | # print(query, col) 329 | for col_ind, col_name in enumerate(current_table['schema_content_clean']): 330 | if col_name == cur_col: 331 | assert cur_table in current_table['table_names'] 332 | if current_table['table_names'][current_table['col_table'][col_ind]] == cur_table: 333 | single_final_col = current_table['column_names_original'][col_ind][1] 334 | break 335 | 336 | assert single_final_col 337 | # if query != single_final_col: 338 | # print(query, single_final_col) 339 | return single_final_col 340 | 341 | 342 | def build_graph(schema): 343 | relations = list() 344 | foreign_keys = schema['foreign_keys'] 345 | for (fkey, pkey) in foreign_keys: 346 | fkey_table = schema['table_names_original'][schema['column_names'][fkey][0]] 347 | pkey_table = schema['table_names_original'][schema['column_names'][pkey][0]] 348 | relations.append((fkey_table, pkey_table)) 349 | relations.append((pkey_table, fkey_table)) 350 | return Graph(relations) 351 | 352 | 353 | def preprocess_schema(schema): 354 | tmp_col = [] 355 | for cc in [x[1] for x in schema['column_names']]: 356 | if cc not in tmp_col: 357 | tmp_col.append(cc) 358 | schema['col_set'] = tmp_col 359 | # print table 360 | schema['schema_content'] = [col[1] for col in schema['column_names']] 361 | schema['col_table'] = [col[0] for col in schema['column_names']] 362 | graph = build_graph(schema) 363 | schema['graph'] = graph 364 | 365 | 366 | 367 | 368 | def to_str(sql_json, N_T, schema, pre_table_names=None): 369 | all_columns = list() 370 | select_clause = list() 371 | table_names = dict() 372 | current_table = schema 373 | for (agg, col, tab) in sql_json['select']: 374 | all_columns.append((agg, col, tab)) 375 | select_clause.append(col_to_str(agg, col, tab, table_names, N_T)) 376 | select_clause_str = 'SELECT ' + ', '.join(select_clause).strip() 377 | 378 | sup_clause = '' 379 | order_clause = '' 380 | direction_map = {"des": 'DESC', 'asc': 'ASC'} 381 | 382 | if 'sup' in sql_json: 383 | (direction, agg, col, tab,) = sql_json['sup'] 384 | all_columns.append((agg, col, tab)) 385 | subject = col_to_str(agg, col, tab, table_names, N_T) 386 | sup_clause = ('ORDER BY %s %s LIMIT 1' % (subject, direction_map[direction])).strip() 387 | elif 'order' in sql_json: 388 | (direction, agg, col, tab,) = sql_json['order'] 389 | all_columns.append((agg, col, tab)) 390 | subject = col_to_str(agg, col, tab, table_names, N_T) 391 | order_clause = ('ORDER BY %s %s' % (subject, direction_map[direction])).strip() 392 | 393 | has_group_by = False 394 | where_clause = '' 395 | have_clause = '' 396 | if 'where' in sql_json: 397 | conjunctions = 
list() 398 | filters = list() 399 | # print(sql_json['where']) 400 | for f in sql_json['where']: 401 | if isinstance(f, str): 402 | conjunctions.append(f) 403 | else: 404 | op, agg, col, tab, value = f 405 | if value: 406 | value['sql'] = sql_json['sql'] 407 | all_columns.append((agg, col, tab)) 408 | subject = col_to_str(agg, col, tab, table_names, N_T) 409 | if value is None: 410 | where_value = '1' 411 | if op == 'between': 412 | where_value = '1 AND 2' 413 | filters.append('%s %s %s' % (subject, op, where_value)) 414 | else: 415 | if op == 'in' and len(value['select']) == 1 and value['select'][0][0] == 'none' \ 416 | and 'where' not in value and 'order' not in value and 'sup' not in value: 417 | # and value['select'][0][2] not in table_names: 418 | if value['select'][0][2] not in table_names: 419 | table_names[value['select'][0][2]] = 'T' + str(len(table_names) + N_T) 420 | filters.append(None) 421 | 422 | else: 423 | filters.append('%s %s %s' % (subject, op, '(' + to_str(value, len(table_names) + 1, schema) + ')')) 424 | if len(conjunctions): 425 | filters.append(conjunctions.pop()) 426 | 427 | aggs = ['count(', 'avg(', 'min(', 'max(', 'sum('] 428 | having_filters = list() 429 | idx = 0 430 | while idx < len(filters): 431 | _filter = filters[idx] 432 | if _filter is None: 433 | idx += 1 434 | continue 435 | for agg in aggs: 436 | if _filter.startswith(agg): 437 | having_filters.append(_filter) 438 | filters.pop(idx) 439 | # print(filters) 440 | if 0 < idx and (filters[idx - 1] in ['and', 'or']): 441 | filters.pop(idx - 1) 442 | # print(filters) 443 | break 444 | else: 445 | idx += 1 446 | if len(having_filters) > 0: 447 | have_clause = 'HAVING ' + ' '.join(having_filters).strip() 448 | if len(filters) > 0: 449 | # print(filters) 450 | filters = [_f for _f in filters if _f is not None] 451 | conjun_num = 0 452 | filter_num = 0 453 | for _f in filters: 454 | if _f in ['or', 'and']: 455 | conjun_num += 1 456 | else: 457 | filter_num += 1 458 | if conjun_num > 0 and filter_num != (conjun_num + 1): 459 | # assert 'and' in filters 460 | idx = 0 461 | while idx < len(filters): 462 | if filters[idx] == 'and': 463 | if idx - 1 == 0: 464 | filters.pop(idx) 465 | break 466 | if filters[idx - 1] in ['and', 'or']: 467 | filters.pop(idx) 468 | break 469 | if idx + 1 >= len(filters) - 1: 470 | filters.pop(idx) 471 | break 472 | if filters[idx + 1] in ['and', 'or']: 473 | filters.pop(idx) 474 | break 475 | idx += 1 476 | if len(filters) > 0: 477 | where_clause = 'WHERE ' + ' '.join(filters).strip() 478 | where_clause = where_clause.replace('not_in', 'NOT IN') 479 | else: 480 | where_clause = '' 481 | 482 | if len(having_filters) > 0: 483 | has_group_by = True 484 | 485 | for agg in ['count(', 'avg(', 'min(', 'max(', 'sum(']: 486 | if (len(sql_json['select']) > 1 and agg in select_clause_str)\ 487 | or agg in sup_clause or agg in order_clause: 488 | has_group_by = True 489 | break 490 | 491 | group_by_clause = '' 492 | if has_group_by: 493 | if len(table_names) == 1: 494 | # check none agg 495 | is_agg_flag = False 496 | for (agg, col, tab) in sql_json['select']: 497 | 498 | if agg == 'none': 499 | group_by_clause = 'GROUP BY ' + col_to_str(agg, col, tab, table_names, N_T) 500 | else: 501 | is_agg_flag = True 502 | 503 | if is_agg_flag is False and len(group_by_clause) > 5: 504 | group_by_clause = "GROUP BY" 505 | for (agg, col, tab) in sql_json['select']: 506 | group_by_clause = group_by_clause + ' ' + col_to_str(agg, col, tab, table_names, N_T) 507 | 508 | if len(group_by_clause) < 5: 509 | if 
'count(*)' in select_clause_str: 510 | current_table = schema 511 | for primary in current_table['primary_keys']: 512 | if current_table['table_names'][current_table['col_table'][primary]] in table_names : 513 | group_by_clause = 'GROUP BY ' + col_to_str('none', current_table['schema_content'][primary], 514 | current_table['table_names'][ 515 | current_table['col_table'][primary]], 516 | table_names, N_T) 517 | else: 518 | # if only one select 519 | if len(sql_json['select']) == 1: 520 | agg, col, tab = sql_json['select'][0] 521 | non_lists = [tab] 522 | fix_flag = False 523 | # add tab from other part 524 | for key, value in table_names.items(): 525 | if key not in non_lists: 526 | non_lists.append(key) 527 | 528 | a = non_lists[0] 529 | b = None 530 | for non in non_lists: 531 | if a != non: 532 | b = non 533 | if b: 534 | for pair in current_table['foreign_keys']: 535 | t1 = current_table['table_names'][current_table['col_table'][pair[0]]] 536 | t2 = current_table['table_names'][current_table['col_table'][pair[1]]] 537 | if t1 in [a, b] and t2 in [a, b]: 538 | if pre_table_names and t1 not in pre_table_names: 539 | assert t2 in pre_table_names 540 | t1 = t2 541 | group_by_clause = 'GROUP BY ' + col_to_str('none', 542 | current_table['schema_content'][pair[0]], 543 | t1, 544 | table_names, N_T) 545 | fix_flag = True 546 | break 547 | 548 | if fix_flag is False: 549 | agg, col, tab = sql_json['select'][0] 550 | group_by_clause = 'GROUP BY ' + col_to_str(agg, col, tab, table_names, N_T) 551 | 552 | else: 553 | # check if there are only one non agg 554 | non_agg, non_agg_count = None, 0 555 | non_lists = [] 556 | for (agg, col, tab) in sql_json['select']: 557 | if agg == 'none': 558 | non_agg = (agg, col, tab) 559 | non_lists.append(tab) 560 | non_agg_count += 1 561 | 562 | non_lists = list(set(non_lists)) 563 | # print(non_lists) 564 | if non_agg_count == 1: 565 | group_by_clause = 'GROUP BY ' + col_to_str(non_agg[0], non_agg[1], non_agg[2], table_names, N_T) 566 | elif non_agg: 567 | find_flag = False 568 | fix_flag = False 569 | find_primary = None 570 | if len(non_lists) <= 1: 571 | for key, value in table_names.items(): 572 | if key not in non_lists: 573 | non_lists.append(key) 574 | if len(non_lists) > 1: 575 | a = non_lists[0] 576 | b = None 577 | for non in non_lists: 578 | if a != non: 579 | b = non 580 | if b: 581 | for pair in current_table['foreign_keys']: 582 | t1 = current_table['table_names'][current_table['col_table'][pair[0]]] 583 | t2 = current_table['table_names'][current_table['col_table'][pair[1]]] 584 | if t1 in [a, b] and t2 in [a, b]: 585 | if pre_table_names and t1 not in pre_table_names: 586 | assert t2 in pre_table_names 587 | t1 = t2 588 | group_by_clause = 'GROUP BY ' + col_to_str('none', 589 | current_table['schema_content'][pair[0]], 590 | t1, 591 | table_names, N_T) 592 | fix_flag = True 593 | break 594 | tab = non_agg[2] 595 | assert tab in current_table['table_names'] 596 | 597 | for primary in current_table['primary_keys']: 598 | if current_table['table_names'][current_table['col_table'][primary]] == tab: 599 | find_flag = True 600 | find_primary = (current_table['schema_content'][primary], tab) 601 | if fix_flag is False: 602 | if find_flag is False: 603 | # rely on count * 604 | foreign = [] 605 | for pair in current_table['foreign_keys']: 606 | if current_table['table_names'][current_table['col_table'][pair[0]]] == tab: 607 | foreign.append(pair[1]) 608 | if current_table['table_names'][current_table['col_table'][pair[1]]] == tab: 609 | 
foreign.append(pair[0]) 610 | 611 | for pair in foreign: 612 | if current_table['table_names'][current_table['col_table'][pair]] in table_names: 613 | group_by_clause = 'GROUP BY ' + col_to_str('none', current_table['schema_content'][pair], 614 | current_table['table_names'][current_table['col_table'][pair]], 615 | table_names, N_T) 616 | find_flag = True 617 | break 618 | if find_flag is False: 619 | for (agg, col, tab) in sql_json['select']: 620 | if 'id' in col.lower(): 621 | group_by_clause = 'GROUP BY ' + col_to_str(agg, col, tab, table_names, N_T) 622 | break 623 | if len(group_by_clause) > 5: 624 | pass 625 | else: 626 | raise RuntimeError('fail to convert') 627 | else: 628 | group_by_clause = 'GROUP BY ' + col_to_str('none', find_primary[0], 629 | find_primary[1], 630 | table_names, N_T) 631 | intersect_clause = '' 632 | if 'intersect' in sql_json: 633 | sql_json['intersect']['sql'] = sql_json['sql'] 634 | intersect_clause = 'INTERSECT ' + to_str(sql_json['intersect'], len(table_names) + 1, schema, table_names) 635 | union_clause = '' 636 | if 'union' in sql_json: 637 | sql_json['union']['sql'] = sql_json['sql'] 638 | union_clause = 'UNION ' + to_str(sql_json['union'], len(table_names) + 1, schema, table_names) 639 | except_clause = '' 640 | if 'except' in sql_json: 641 | sql_json['except']['sql'] = sql_json['sql'] 642 | except_clause = 'EXCEPT ' + to_str(sql_json['except'], len(table_names) + 1, schema, table_names) 643 | 644 | # print(current_table['table_names_original']) 645 | table_names_replace = {} 646 | for a, b in zip(current_table['table_names_original'], current_table['table_names']): 647 | table_names_replace[b] = a 648 | new_table_names = {} 649 | for key, value in table_names.items(): 650 | if key is None: 651 | continue 652 | new_table_names[table_names_replace[key]] = value 653 | from_clause = infer_from_clause(new_table_names, schema, all_columns).strip() 654 | 655 | sql = ' '.join([select_clause_str, from_clause, where_clause, group_by_clause, have_clause, sup_clause, order_clause, 656 | intersect_clause, union_clause, except_clause]) 657 | 658 | return sql 659 | 660 | 661 | if __name__ == '__main__': 662 | 663 | arg_parser = argparse.ArgumentParser() 664 | arg_parser.add_argument('--data_path', type=str, help='dataset path', required=True) 665 | arg_parser.add_argument('--input_path', type=str, help='predicted logical form', required=True) 666 | arg_parser.add_argument('--output_path', type=str, help='output data') 667 | args = arg_parser.parse_args() 668 | 669 | # loading dataSets 670 | datas, schemas = load_dataSets(args) 671 | alter_not_in(datas, schemas=schemas) 672 | alter_inter(datas) 673 | alter_column0(datas) 674 | 675 | 676 | index = range(len(datas)) 677 | count = 0 678 | exception_count = 0 679 | with open(args.output_path, 'w', encoding='utf8') as d, open('gold.txt', 'w', encoding='utf8') as g: 680 | for i in index: 681 | try: 682 | result = transform(datas[i], schemas[datas[i]['db_id']]) 683 | d.write(result[0] + '\n') 684 | g.write("%s\t%s\t%s\n" % (datas[i]['query'], datas[i]["db_id"], datas[i]["question"])) 685 | count += 1 686 | except Exception as e: 687 | result = transform(datas[i], schemas[datas[i]['db_id']], origin='Root1(3) Root(5) Sel(0) N(0) A(3) C(0) T(0)') 688 | exception_count += 1 689 | d.write(result[0] + '\n') 690 | g.write("%s\t%s\t%s\n" % (datas[i]['query'], datas[i]["db_id"], datas[i]["question"])) 691 | count += 1 692 | print(e) 693 | print('Exception') 694 | print(traceback.format_exc()) 695 | print('===\n\n') 696 | 697 | 
print(count, exception_count) 698 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/5/25 7 | # @Author : Jiaqi&Zecheng 8 | # @File : __init__.py 9 | # @Software: PyCharm 10 | """ -------------------------------------------------------------------------------- /src/args.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/5/25 7 | # @Author : Jiaqi&Zecheng 8 | # @File : args.py 9 | # @Software: PyCharm 10 | """ 11 | 12 | import random 13 | import argparse 14 | import torch 15 | import numpy as np 16 | 17 | def init_arg_parser(): 18 | arg_parser = argparse.ArgumentParser() 19 | arg_parser.add_argument('--seed', default=5783287, type=int, help='random seed') 20 | arg_parser.add_argument('--cuda', action='store_true', help='use gpu') 21 | arg_parser.add_argument('--lr_scheduler', action='store_true', help='use learning rate scheduler') 22 | arg_parser.add_argument('--lr_scheduler_gammar', default=0.5, type=float, help='decay rate of learning rate scheduler') 23 | arg_parser.add_argument('--column_pointer', action='store_true', help='use column pointer') 24 | arg_parser.add_argument('--loss_epoch_threshold', default=20, type=int, help='loss epoch threshold') 25 | arg_parser.add_argument('--sketch_loss_coefficient', default=0.2, type=float, help='sketch loss coefficient') 26 | arg_parser.add_argument('--sentence_features', action='store_true', help='use sentence features') 27 | arg_parser.add_argument('--model_name', choices=['transformer', 'rnn', 'table', 'sketch'], default='rnn', 28 | help='model name') 29 | 30 | arg_parser.add_argument('--lstm', choices=['lstm', 'lstm_with_dropout', 'parent_feed'], default='lstm') 31 | 32 | arg_parser.add_argument('--load_model', default=None, type=str, help='load a pre-trained model') 33 | arg_parser.add_argument('--glove_embed_path', default="glove.42B.300d.txt", type=str) 34 | 35 | arg_parser.add_argument('--batch_size', default=64, type=int, help='batch size') 36 | arg_parser.add_argument('--beam_size', default=5, type=int, help='beam size for beam search') 37 | arg_parser.add_argument('--embed_size', default=300, type=int, help='size of word embeddings') 38 | arg_parser.add_argument('--col_embed_size', default=300, type=int, help='size of word embeddings') 39 | 40 | arg_parser.add_argument('--action_embed_size', default=128, type=int, help='size of word embeddings') 41 | arg_parser.add_argument('--type_embed_size', default=128, type=int, help='size of word embeddings') 42 | arg_parser.add_argument('--hidden_size', default=100, type=int, help='size of LSTM hidden states') 43 | arg_parser.add_argument('--att_vec_size', default=100, type=int, help='size of attentional vector') 44 | arg_parser.add_argument('--dropout', default=0.3, type=float, help='dropout rate') 45 | arg_parser.add_argument('--word_dropout', default=0.2, type=float, help='word dropout rate') 46 | 47 | # readout layer 48 | arg_parser.add_argument('--no_query_vec_to_action_map', default=False, action='store_true') 49 | arg_parser.add_argument('--readout', default='identity', choices=['identity', 'non_linear']) 50 | 
arg_parser.add_argument('--query_vec_to_action_diff_map', default=False, action='store_true') 51 | 52 | 53 | arg_parser.add_argument('--column_att', choices=['dot_prod', 'affine'], default='affine') 54 | 55 | arg_parser.add_argument('--decode_max_time_step', default=40, type=int, help='maximum number of time steps used ' 56 | 'in decoding and sampling') 57 | 58 | 59 | arg_parser.add_argument('--save_to', default='model', type=str, help='save trained model to') 60 | arg_parser.add_argument('--toy', action='store_true', 61 | help='If set, use small data; used for fast debugging.') 62 | arg_parser.add_argument('--clip_grad', default=5., type=float, help='clip gradients') 63 | arg_parser.add_argument('--max_epoch', default=-1, type=int, help='maximum number of training epoches') 64 | arg_parser.add_argument('--optimizer', default='Adam', type=str, help='optimizer') 65 | arg_parser.add_argument('--lr', default=0.001, type=float, help='learning rate') 66 | 67 | arg_parser.add_argument('--dataset', default="./data", type=str) 68 | 69 | arg_parser.add_argument('--epoch', default=50, type=int, help='Maximum Epoch') 70 | arg_parser.add_argument('--save', default='./', type=str, 71 | help="Path to save the checkpoint and logs of epoch") 72 | 73 | return arg_parser 74 | 75 | def init_config(arg_parser): 76 | args = arg_parser.parse_args() 77 | torch.manual_seed(args.seed) 78 | if args.cuda: 79 | torch.cuda.manual_seed(args.seed) 80 | np.random.seed(int(args.seed * 13 / 7)) 81 | random.seed(int(args.seed)) 82 | return args 83 | -------------------------------------------------------------------------------- /src/beam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/5/25 7 | # @Author : Jiaqi&Zecheng 8 | # @File : beam.py 9 | # @Software: PyCharm 10 | """ 11 | 12 | import copy 13 | 14 | from src.rule import semQL 15 | 16 | 17 | class ActionInfo(object): 18 | """sufficient statistics for making a prediction of an action at a time step""" 19 | 20 | def __init__(self, action=None): 21 | self.t = 0 22 | self.score = 0 23 | self.parent_t = -1 24 | self.action = action 25 | self.frontier_prod = None 26 | self.frontier_field = None 27 | 28 | # for GenToken actions only 29 | self.copy_from_src = False 30 | self.src_token_position = -1 31 | 32 | 33 | class Beams(object): 34 | def __init__(self, is_sketch=False): 35 | self.actions = [] 36 | self.action_infos = [] 37 | self.inputs = [] 38 | self.score = 0. 
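        # Beam-search hypothesis state: `actions` / `action_infos` record the SemQL
        # actions chosen so far, `score` holds the running hypothesis score, `t`
        # counts decoding steps, and the sketch_* fields track the sketch decoder.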
39 | self.t = 0 40 | self.is_sketch = is_sketch 41 | self.sketch_step = 0 42 | self.sketch_attention_history = list() 43 | 44 | def get_availableClass(self): 45 | """ 46 | return the available action class 47 | :return: 48 | """ 49 | 50 | # TODO: it could be update by speed 51 | # return the available class using rule 52 | # FIXME: now should change for these 11: "Filter 1 ROOT", 53 | def check_type(lists): 54 | for s in lists: 55 | if type(s) == int: 56 | return False 57 | return True 58 | 59 | stack = [semQL.Root1] 60 | for action in self.actions: 61 | infer_action = action.get_next_action(is_sketch=self.is_sketch) 62 | infer_action.reverse() 63 | if stack[-1] is type(action): 64 | stack.pop() 65 | # check if the are non-terminal 66 | if check_type(infer_action): 67 | stack.extend(infer_action) 68 | else: 69 | raise RuntimeError("Not the right action") 70 | 71 | result = stack[-1] if len(stack) > 0 else None 72 | 73 | return result 74 | 75 | @classmethod 76 | def get_parent_action(cls, actions): 77 | """ 78 | 79 | :param actions: 80 | :return: 81 | """ 82 | 83 | def check_type(lists): 84 | for s in lists: 85 | if type(s) == int: 86 | return False 87 | return True 88 | 89 | # check the origin state Root 90 | if len(actions) == 0: 91 | return None 92 | 93 | stack = [semQL.Root1] 94 | for id_x, action in enumerate(actions): 95 | infer_action = action.get_next_action() 96 | for ac in infer_action: 97 | ac.parent = action 98 | ac.pt = id_x 99 | infer_action.reverse() 100 | if stack[-1] is type(action): 101 | stack.pop() 102 | # check if the are non-terminal 103 | if check_type(infer_action): 104 | stack.extend(infer_action) 105 | else: 106 | for t in actions: 107 | if type(t) != semQL.C: 108 | print(t, end="") 109 | print('asd') 110 | print(action) 111 | print(stack[-1]) 112 | raise RuntimeError("Not the right action") 113 | result = stack[-1] if len(stack) > 0 else None 114 | 115 | return result 116 | 117 | def apply_action(self, action): 118 | # TODO: not finish implement yet 119 | self.t += 1 120 | self.actions.append(action) 121 | 122 | def clone_and_apply_action(self, action): 123 | new_hyp = self.copy() 124 | new_hyp.apply_action(action) 125 | 126 | return new_hyp 127 | 128 | def clone_and_apply_action_info(self, action_info): 129 | action = action_info.action 130 | action.score = action_info.score 131 | new_hyp = self.clone_and_apply_action(action) 132 | new_hyp.action_infos.append(action_info) 133 | new_hyp.sketch_step = self.sketch_step 134 | new_hyp.sketch_attention_history = copy.copy(self.sketch_attention_history) 135 | 136 | return new_hyp 137 | 138 | def copy(self): 139 | new_hyp = Beams(is_sketch=self.is_sketch) 140 | # if self.tree: 141 | # new_hyp.tree = self.tree.copy() 142 | 143 | new_hyp.actions = list(self.actions) 144 | new_hyp.score = self.score 145 | new_hyp.t = self.t 146 | new_hyp.sketch_step = self.sketch_step 147 | new_hyp.sketch_attention_history = copy.copy(self.sketch_attention_history) 148 | 149 | return new_hyp 150 | 151 | def infer_n(self): 152 | if len(self.actions) > 4: 153 | prev_action = self.actions[-3] 154 | if isinstance(prev_action, semQL.Filter): 155 | if prev_action.id_c > 11: 156 | # Nested Query, only select 1 column 157 | return ['N A'] 158 | if self.actions[0].id_c != 3: 159 | return [self.actions[3].production] 160 | return semQL.N._init_grammar() 161 | 162 | @property 163 | def completed(self): 164 | return True if self.get_availableClass() is None else False 165 | 166 | @property 167 | def is_valid(self): 168 | actions = self.actions 169 | 
return self.check_sel_valid(actions) 170 | 171 | def check_sel_valid(self, actions): 172 | find_sel = False 173 | sel_actions = list() 174 | for ac in actions: 175 | if type(ac) == semQL.Sel: 176 | find_sel = True 177 | elif find_sel and type(ac) in [semQL.N, semQL.T, semQL.C, semQL.A]: 178 | if type(ac) not in [semQL.N]: 179 | sel_actions.append(ac) 180 | elif find_sel and type(ac) not in [semQL.N, semQL.T, semQL.C, semQL.A]: 181 | break 182 | 183 | if find_sel is False: 184 | return True 185 | 186 | # not the complete sel lf 187 | if len(sel_actions) % 3 != 0: 188 | return True 189 | 190 | sel_string = list() 191 | for i in range(len(sel_actions) // 3): 192 | if (sel_actions[i * 3 + 0].id_c, sel_actions[i * 3 + 1].id_c, sel_actions[i * 3 + 2].id_c) in sel_string: 193 | return False 194 | else: 195 | sel_string.append( 196 | (sel_actions[i * 3 + 0].id_c, sel_actions[i * 3 + 1].id_c, sel_actions[i * 3 + 2].id_c)) 197 | return True 198 | 199 | 200 | if __name__ == '__main__': 201 | test = Beams(is_sketch=True) 202 | # print(semQL.Root1(1).get_next_action()) 203 | test.actions.append(semQL.Root1(3)) 204 | test.actions.append(semQL.Root(5)) 205 | 206 | print(str(test.get_availableClass())) 207 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/5/25 7 | # @Author : Jiaqi&Zecheng 8 | # @File : utils.py 9 | # @Software: PyCharm 10 | """ 11 | 12 | import copy 13 | 14 | import src.rule.semQL as define_rule 15 | from src.models import nn_utils 16 | 17 | 18 | class Example: 19 | """ 20 | 21 | """ 22 | def __init__(self, src_sent, tgt_actions=None, vis_seq=None, tab_cols=None, col_num=None, sql=None, 23 | one_hot_type=None, col_hot_type=None, schema_len=None, tab_ids=None, 24 | table_names=None, table_len=None, col_table_dict=None, cols=None, 25 | table_col_name=None, table_col_len=None, 26 | col_pred=None, tokenized_src_sent=None, 27 | ): 28 | 29 | self.src_sent = src_sent 30 | self.tokenized_src_sent = tokenized_src_sent 31 | self.vis_seq = vis_seq 32 | self.tab_cols = tab_cols 33 | self.col_num = col_num 34 | self.sql = sql 35 | self.one_hot_type=one_hot_type 36 | self.col_hot_type = col_hot_type 37 | self.schema_len = schema_len 38 | self.tab_ids = tab_ids 39 | self.table_names = table_names 40 | self.table_len = table_len 41 | self.col_table_dict = col_table_dict 42 | self.cols = cols 43 | self.table_col_name = table_col_name 44 | self.table_col_len = table_col_len 45 | self.col_pred = col_pred 46 | self.tgt_actions = tgt_actions 47 | self.truth_actions = copy.deepcopy(tgt_actions) 48 | 49 | self.sketch = list() 50 | if self.truth_actions: 51 | for ta in self.truth_actions: 52 | if isinstance(ta, define_rule.C) or isinstance(ta, define_rule.T) or isinstance(ta, define_rule.A): 53 | continue 54 | self.sketch.append(ta) 55 | 56 | 57 | class cached_property(object): 58 | """ A property that is only computed once per instance and then replaces 59 | itself with an ordinary attribute. Deleting the attribute resets the 60 | property. 
61 | 62 | Source: https://github.com/bottlepy/bottle/commit/fa7733e075da0d790d809aa3d2f53071897e6f76 63 | """ 64 | 65 | def __init__(self, func): 66 | self.__doc__ = getattr(func, '__doc__') 67 | self.func = func 68 | 69 | def __get__(self, obj, cls): 70 | if obj is None: 71 | return self 72 | value = obj.__dict__[self.func.__name__] = self.func(obj) 73 | return value 74 | 75 | 76 | class Batch(object): 77 | def __init__(self, examples, grammar, cuda=False): 78 | self.examples = examples 79 | 80 | if examples[0].tgt_actions: 81 | self.max_action_num = max(len(e.tgt_actions) for e in self.examples) 82 | self.max_sketch_num = max(len(e.sketch) for e in self.examples) 83 | 84 | self.src_sents = [e.src_sent for e in self.examples] 85 | self.src_sents_len = [len(e.src_sent) for e in self.examples] 86 | self.tokenized_src_sents = [e.tokenized_src_sent for e in self.examples] 87 | self.tokenized_src_sents_len = [len(e.tokenized_src_sent) for e in examples] 88 | self.src_sents_word = [e.src_sent for e in self.examples] 89 | self.table_sents_word = [[" ".join(x) for x in e.tab_cols] for e in self.examples] 90 | 91 | self.schema_sents_word = [[" ".join(x) for x in e.table_names] for e in self.examples] 92 | 93 | self.src_type = [e.one_hot_type for e in self.examples] 94 | self.col_hot_type = [e.col_hot_type for e in self.examples] 95 | self.table_sents = [e.tab_cols for e in self.examples] 96 | self.col_num = [e.col_num for e in self.examples] 97 | self.tab_ids = [e.tab_ids for e in self.examples] 98 | self.table_names = [e.table_names for e in self.examples] 99 | self.table_len = [e.table_len for e in examples] 100 | self.col_table_dict = [e.col_table_dict for e in examples] 101 | self.table_col_name = [e.table_col_name for e in examples] 102 | self.table_col_len = [e.table_col_len for e in examples] 103 | self.col_pred = [e.col_pred for e in examples] 104 | 105 | self.grammar = grammar 106 | self.cuda = cuda 107 | 108 | def __len__(self): 109 | return len(self.examples) 110 | 111 | 112 | def table_dict_mask(self, table_dict): 113 | return nn_utils.table_dict_to_mask_tensor(self.table_len, table_dict, cuda=self.cuda) 114 | 115 | @cached_property 116 | def pred_col_mask(self): 117 | return nn_utils.pred_col_mask(self.col_pred, self.col_num) 118 | 119 | @cached_property 120 | def schema_token_mask(self): 121 | return nn_utils.length_array_to_mask_tensor(self.table_len, cuda=self.cuda) 122 | 123 | @cached_property 124 | def table_token_mask(self): 125 | return nn_utils.length_array_to_mask_tensor(self.col_num, cuda=self.cuda) 126 | 127 | @cached_property 128 | def table_appear_mask(self): 129 | return nn_utils.appear_to_mask_tensor(self.col_num, cuda=self.cuda) 130 | 131 | @cached_property 132 | def table_unk_mask(self): 133 | return nn_utils.length_array_to_mask_tensor(self.col_num, cuda=self.cuda, value=None) 134 | 135 | @cached_property 136 | def src_token_mask(self): 137 | return nn_utils.length_array_to_mask_tensor(self.src_sents_len, 138 | cuda=self.cuda) 139 | 140 | 141 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/5/26 7 | # @Author : Jiaqi&Zecheng 8 | # @File : __init__.py.py 9 | # @Software: PyCharm 10 | """ -------------------------------------------------------------------------------- /src/models/basic_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/5/26 7 | # @Author : Jiaqi&Zecheng 8 | # @File : basic_model.py 9 | # @Software: PyCharm 10 | """ 11 | 12 | import numpy as np 13 | import os 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.nn.utils 18 | from torch.autograd import Variable 19 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 20 | 21 | from src.rule import semQL as define_rule 22 | 23 | 24 | class BasicModel(nn.Module): 25 | 26 | def __init__(self): 27 | super(BasicModel, self).__init__() 28 | pass 29 | 30 | def embedding_cosine(self, src_embedding, table_embedding, table_unk_mask): 31 | embedding_differ = [] 32 | for i in range(table_embedding.size(1)): 33 | one_table_embedding = table_embedding[:, i, :] 34 | one_table_embedding = one_table_embedding.unsqueeze(1).expand(table_embedding.size(0), 35 | src_embedding.size(1), 36 | table_embedding.size(2)) 37 | 38 | topk_val = F.cosine_similarity(one_table_embedding, src_embedding, dim=-1) 39 | 40 | embedding_differ.append(topk_val) 41 | embedding_differ = torch.stack(embedding_differ).transpose(1, 0) 42 | embedding_differ.data.masked_fill_(table_unk_mask.unsqueeze(2).expand( 43 | table_embedding.size(0), 44 | table_embedding.size(1), 45 | embedding_differ.size(2) 46 | ).bool(), 0) 47 | 48 | return embedding_differ 49 | 50 | def encode(self, src_sents_var, src_sents_len, q_onehot_project=None): 51 | """ 52 | encode the source sequence 53 | :return: 54 | src_encodings: Variable(batch_size, src_sent_len, hidden_size * 2) 55 | last_state, last_cell: Variable(batch_size, hidden_size) 56 | """ 57 | src_token_embed = self.gen_x_batch(src_sents_var) 58 | 59 | if q_onehot_project is not None: 60 | src_token_embed = torch.cat([src_token_embed, q_onehot_project], dim=-1) 61 | 62 | packed_src_token_embed = pack_padded_sequence(src_token_embed, src_sents_len, batch_first=True) 63 | # src_encodings: (tgt_query_len, batch_size, hidden_size) 64 | src_encodings, (last_state, last_cell) = self.encoder_lstm(packed_src_token_embed) 65 | src_encodings, _ = pad_packed_sequence(src_encodings, batch_first=True) 66 | # src_encodings: (batch_size, tgt_query_len, hidden_size) 67 | # src_encodings = src_encodings.permute(1, 0, 2) 68 | # (batch_size, hidden_size * 2) 69 | last_state = torch.cat([last_state[0], last_state[1]], -1) 70 | last_cell = torch.cat([last_cell[0], last_cell[1]], -1) 71 | 72 | return src_encodings, (last_state, last_cell) 73 | 74 | def input_type(self, values_list): 75 | B = len(values_list) 76 | val_len = [] 77 | for value in values_list: 78 | val_len.append(len(value)) 79 | max_len = max(val_len) 80 | # for the Begin and End 81 | val_emb_array = np.zeros((B, max_len, values_list[0].shape[1]), dtype=np.float32) 82 | for i in range(B): 83 | val_emb_array[i, :val_len[i], :] = values_list[i][:, :] 84 | 85 | val_inp = torch.from_numpy(val_emb_array) 86 | if self.args.cuda: 87 | val_inp = val_inp.cuda() 88 | val_inp_var = Variable(val_inp) 89 | return val_inp_var 90 | 91 | def padding_sketch(self, sketch): 92 | padding_result = [] 93 
| for action in sketch: 94 | padding_result.append(action) 95 | if type(action) == define_rule.N: 96 | for _ in range(action.id_c + 1): 97 | padding_result.append(define_rule.A(0)) 98 | padding_result.append(define_rule.C(0)) 99 | padding_result.append(define_rule.T(0)) 100 | elif type(action) == define_rule.Filter and 'A' in action.production: 101 | padding_result.append(define_rule.A(0)) 102 | padding_result.append(define_rule.C(0)) 103 | padding_result.append(define_rule.T(0)) 104 | elif type(action) == define_rule.Order or type(action) == define_rule.Sup: 105 | padding_result.append(define_rule.A(0)) 106 | padding_result.append(define_rule.C(0)) 107 | padding_result.append(define_rule.T(0)) 108 | 109 | return padding_result 110 | 111 | def gen_x_batch(self, q): 112 | B = len(q) 113 | val_embs = [] 114 | val_len = np.zeros(B, dtype=np.int64) 115 | is_list = False 116 | if type(q[0][0]) == list: 117 | is_list = True 118 | for i, one_q in enumerate(q): 119 | if not is_list: 120 | q_val = list( 121 | map(lambda x: self.word_emb.get(x, np.zeros(self.args.col_embed_size, dtype=np.float32)), one_q)) 122 | else: 123 | q_val = [] 124 | for ws in one_q: 125 | emb_list = [] 126 | ws_len = len(ws) 127 | for w in ws: 128 | emb_list.append(self.word_emb.get(w, self.word_emb['unk'])) 129 | if ws_len == 0: 130 | raise Exception("word list should not be empty!") 131 | elif ws_len == 1: 132 | q_val.append(emb_list[0]) 133 | else: 134 | q_val.append(sum(emb_list) / float(ws_len)) 135 | 136 | val_embs.append(q_val) 137 | val_len[i] = len(q_val) 138 | max_len = max(val_len) 139 | 140 | val_emb_array = np.zeros((B, max_len, self.args.col_embed_size), dtype=np.float32) 141 | for i in range(B): 142 | for t in range(len(val_embs[i])): 143 | val_emb_array[i, t, :] = val_embs[i][t] 144 | val_inp = torch.from_numpy(val_emb_array) 145 | if self.args.cuda: 146 | val_inp = val_inp.cuda() 147 | return val_inp 148 | 149 | def save(self, path): 150 | dir_name = os.path.dirname(path) 151 | if not os.path.exists(dir_name): 152 | os.makedirs(dir_name) 153 | 154 | params = { 155 | 'args': self.args, 156 | 'vocab': self.vocab, 157 | 'grammar': self.grammar, 158 | 'state_dict': self.state_dict() 159 | } 160 | torch.save(params, path) 161 | -------------------------------------------------------------------------------- /src/models/nn_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
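# --- Editor's note: illustrative usage sketch, not part of the original module ---
# dot_prod_attention() defined below computes a bilinear attention score and a
# context vector. A minimal, hedged example with made-up shapes (batch 4,
# source length 12, hidden size 300, bidirectional encodings of size 600):
#
#   import torch
#   h_t         = torch.randn(4, 300)        # decoder state (batch, hidden)
#   src_enc     = torch.randn(4, 12, 600)    # encoder outputs (batch, len, hidden * 2)
#   src_enc_att = torch.randn(4, 12, 300)    # encoder outputs projected to hidden size
#   ctx, att = dot_prod_attention(h_t, src_enc, src_enc_att)
#   # ctx: (4, 600) attended context vector; att: (4, 12) softmax attention weights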
3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/5/25 7 | # @Author : Jiaqi&Zecheng 8 | # @File : utils.py 9 | # @Software: PyCharm 10 | """ 11 | import torch.nn.functional as F 12 | import torch.nn.init as init 13 | import numpy as np 14 | import torch 15 | from torch.autograd import Variable 16 | from six.moves import xrange 17 | 18 | def dot_prod_attention(h_t, src_encoding, src_encoding_att_linear, mask=None): 19 | """ 20 | :param h_t: (batch_size, hidden_size) 21 | :param src_encoding: (batch_size, src_sent_len, hidden_size * 2) 22 | :param src_encoding_att_linear: (batch_size, src_sent_len, hidden_size) 23 | :param mask: (batch_size, src_sent_len) 24 | """ 25 | # (batch_size, src_sent_len) 26 | att_weight = torch.bmm(src_encoding_att_linear, h_t.unsqueeze(2)).squeeze(2) 27 | if mask is not None: 28 | att_weight.data.masked_fill_(mask.bool(), -float('inf')) 29 | att_weight = F.softmax(att_weight, dim=-1) 30 | 31 | att_view = (att_weight.size(0), 1, att_weight.size(1)) 32 | # (batch_size, hidden_size) 33 | ctx_vec = torch.bmm(att_weight.view(*att_view), src_encoding).squeeze(1) 34 | 35 | return ctx_vec, att_weight 36 | 37 | 38 | def length_array_to_mask_tensor(length_array, cuda=False, value=None): 39 | max_len = max(length_array) 40 | batch_size = len(length_array) 41 | 42 | mask = np.ones((batch_size, max_len), dtype=np.uint8) 43 | for i, seq_len in enumerate(length_array): 44 | mask[i][:seq_len] = 0 45 | 46 | if value != None: 47 | for b_id in range(len(value)): 48 | for c_id, c in enumerate(value[b_id]): 49 | if value[b_id][c_id] == [3]: 50 | mask[b_id][c_id] = 1 51 | 52 | mask = torch.ByteTensor(mask) 53 | return mask.cuda() if cuda else mask 54 | 55 | 56 | def table_dict_to_mask_tensor(length_array, table_dict, cuda=False ): 57 | max_len = max(length_array) 58 | batch_size = len(table_dict) 59 | 60 | mask = np.ones((batch_size, max_len), dtype=np.uint8) 61 | for i, ta_val in enumerate(table_dict): 62 | for tt in ta_val: 63 | mask[i][tt] = 0 64 | 65 | mask = torch.ByteTensor(mask) 66 | return mask.cuda() if cuda else mask 67 | 68 | 69 | def length_position_tensor(length_array, cuda=False, value=None): 70 | max_len = max(length_array) 71 | batch_size = len(length_array) 72 | 73 | mask = np.zeros((batch_size, max_len), dtype=np.float32) 74 | 75 | for b_id in range(batch_size): 76 | for len_c in range(length_array[b_id]): 77 | mask[b_id][len_c] = len_c + 1 78 | 79 | mask = torch.LongTensor(mask) 80 | return mask.cuda() if cuda else mask 81 | 82 | 83 | def appear_to_mask_tensor(length_array, cuda=False, value=None): 84 | max_len = max(length_array) 85 | batch_size = len(length_array) 86 | mask = np.zeros((batch_size, max_len), dtype=np.float32) 87 | return mask 88 | 89 | def pred_col_mask(value, max_len): 90 | max_len = max(max_len) 91 | batch_size = len(value) 92 | mask = np.ones((batch_size, max_len), dtype=np.uint8) 93 | for v_ind, v_val in enumerate(value): 94 | for v in v_val: 95 | mask[v_ind][v] = 0 96 | mask = torch.ByteTensor(mask) 97 | return mask.cuda() 98 | 99 | 100 | def input_transpose(sents, pad_token): 101 | """ 102 | transform the input List[sequence] of size (batch_size, max_sent_len) 103 | into a list of size (batch_size, max_sent_len), with proper padding 104 | """ 105 | max_len = max(len(s) for s in sents) 106 | batch_size = len(sents) 107 | sents_t = [] 108 | masks = [] 109 | for e_id in range(batch_size): 110 | if type(sents[0][0]) != list: 111 | sents_t.append([sents[e_id][i] if len(sents[e_id]) > i else pad_token for i in range(max_len)]) 112 | else: 
113 | sents_t.append([sents[e_id][i] if len(sents[e_id]) > i else [pad_token] for i in range(max_len)]) 114 | 115 | masks.append([1 if len(sents[e_id]) > i else 0 for i in range(max_len)]) 116 | 117 | return sents_t, masks 118 | 119 | 120 | def word2id(sents, vocab): 121 | if type(sents[0]) == list: 122 | if type(sents[0][0]) != list: 123 | return [[vocab[w] for w in s] for s in sents] 124 | else: 125 | return [[[vocab[w] for w in s] for s in v] for v in sents ] 126 | else: 127 | return [vocab[w] for w in sents] 128 | 129 | 130 | def id2word(sents, vocab): 131 | if type(sents[0]) == list: 132 | return [[vocab.id2word[w] for w in s] for s in sents] 133 | else: 134 | return [vocab.id2word[w] for w in sents] 135 | 136 | 137 | def to_input_variable(sequences, vocab, cuda=False, training=True): 138 | """ 139 | given a list of sequences, 140 | return a tensor of shape (max_sent_len, batch_size) 141 | """ 142 | word_ids = word2id(sequences, vocab) 143 | sents_t, masks = input_transpose(word_ids, vocab['']) 144 | 145 | if type(sents_t[0][0]) != list: 146 | with torch.no_grad(): 147 | sents_var = Variable(torch.LongTensor(sents_t), requires_grad=False) 148 | if cuda: 149 | sents_var = sents_var.cuda() 150 | else: 151 | sents_var = sents_t 152 | 153 | return sents_var 154 | 155 | 156 | def variable_constr(x, v, cuda=False): 157 | return Variable(torch.cuda.x(v)) if cuda else Variable(torch.x(v)) 158 | 159 | 160 | def batch_iter(examples, batch_size, shuffle=False): 161 | index_arr = np.arange(len(examples)) 162 | if shuffle: 163 | np.random.shuffle(index_arr) 164 | 165 | batch_num = int(np.ceil(len(examples) / float(batch_size))) 166 | for batch_id in xrange(batch_num): 167 | batch_ids = index_arr[batch_size * batch_id: batch_size * (batch_id + 1)] 168 | batch_examples = [examples[i] for i in batch_ids] 169 | 170 | yield batch_examples 171 | 172 | 173 | def isnan(data): 174 | data = data.cpu().numpy() 175 | return np.isnan(data).any() or np.isinf(data).any() 176 | 177 | 178 | def log_sum_exp(inputs, dim=None, keepdim=False): 179 | """Numerically stable logsumexp. 180 | source: https://github.com/pytorch/pytorch/issues/2591 181 | 182 | Args: 183 | inputs: A Variable with any shape. 184 | dim: An integer. 185 | keepdim: A boolean. 186 | 187 | Returns: 188 | Equivalent of log(sum(exp(inputs), dim=dim, keepdim=keepdim)). 189 | """ 190 | # For a 1-D array x (any array along a single dimension), 191 | # log sum exp(x) = s + log sum exp(x - s) 192 | # with s = max(x) being a common choice. 
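    # (Editor-added worked example, illustrative only.) For x = [1000., 1000.],
    # exp(x) overflows in float32/float64, but with s = 1000 the identity gives
    #     log_sum_exp(x) = 1000 + log(exp(0) + exp(0)) = 1000 + log(2), about 1000.693,
    # which the code below computes without overflow.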
193 | 194 | if dim is None: 195 | inputs = inputs.view(-1) 196 | dim = 0 197 | s, _ = torch.max(inputs, dim=dim, keepdim=True) 198 | outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log() 199 | if not keepdim: 200 | outputs = outputs.squeeze(dim) 201 | return outputs 202 | 203 | 204 | def uniform_init(lower, upper, params): 205 | for p in params: 206 | p.data.uniform_(lower, upper) 207 | 208 | 209 | def glorot_init(params): 210 | for p in params: 211 | if len(p.data.size()) > 1: 212 | init.xavier_normal(p.data) 213 | 214 | 215 | def identity(x): 216 | return x 217 | 218 | 219 | def pad_matrix(matrixs, cuda=False): 220 | """ 221 | :param matrixs: 222 | :return: [batch_size, max_shape, max_shape], [batch_size] 223 | """ 224 | shape = [m.shape[0] for m in matrixs] 225 | max_shape = max(shape) 226 | tensors = list() 227 | for s, m in zip(shape, matrixs): 228 | delta = max_shape - s 229 | if s > 0: 230 | tensors.append(torch.as_tensor(np.pad(m, [(0, delta), (0, delta)], mode='constant'), dtype=torch.float)) 231 | else: 232 | tensors.append(torch.as_tensor(m, dtype=torch.float)) 233 | tensors = torch.stack(tensors) 234 | if cuda: 235 | tensors = tensors.cuda() 236 | return tensors 237 | -------------------------------------------------------------------------------- /src/models/pointer_net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # coding=utf8 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.utils 9 | from torch.nn import Parameter 10 | 11 | 12 | class AuxiliaryPointerNet(nn.Module): 13 | 14 | def __init__(self, query_vec_size, src_encoding_size, attention_type='affine'): 15 | super(AuxiliaryPointerNet, self).__init__() 16 | 17 | assert attention_type in ('affine', 'dot_prod') 18 | if attention_type == 'affine': 19 | self.src_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False) 20 | self.auxiliary_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False) 21 | self.attention_type = attention_type 22 | 23 | def forward(self, src_encodings, src_context_encodings, src_token_mask, query_vec): 24 | """ 25 | :param src_context_encodings: Variable(batch_size, src_sent_len, src_encoding_size) 26 | :param src_encodings: Variable(batch_size, src_sent_len, src_encoding_size) 27 | :param src_token_mask: Variable(batch_size, src_sent_len) 28 | :param query_vec: Variable(tgt_action_num, batch_size, query_vec_size) 29 | :return: Variable(tgt_action_num, batch_size, src_sent_len) 30 | """ 31 | 32 | # (batch_size, 1, src_sent_len, query_vec_size) 33 | encodings = src_encodings.clone() 34 | context_encodings = src_context_encodings.clone() 35 | if self.attention_type == 'affine': 36 | encodings = self.src_encoding_linear(src_encodings) 37 | context_encodings = self.auxiliary_encoding_linear(src_context_encodings) 38 | encodings = encodings.unsqueeze(1) 39 | context_encodings = context_encodings.unsqueeze(1) 40 | 41 | # (batch_size, tgt_action_num, query_vec_size, 1) 42 | q = query_vec.permute(1, 0, 2).unsqueeze(3) 43 | 44 | # (batch_size, tgt_action_num, src_sent_len) 45 | weights = torch.matmul(encodings, q).squeeze(3) 46 | context_weights = torch.matmul(context_encodings, q).squeeze(3) 47 | 48 | # (tgt_action_num, batch_size, src_sent_len) 49 | weights = weights.permute(1, 0, 2) 50 | context_weights = context_weights.permute(1, 0, 2) 51 | 52 | if src_token_mask is not None: 53 | # (tgt_action_num, batch_size, 
src_sent_len) 54 | src_token_mask = src_token_mask.unsqueeze(0).expand_as(weights) 55 | weights.data.masked_fill_(src_token_mask.bool(), -float('inf')) 56 | context_weights.data.masked_fill_(src_token_mask.bool(), -float('inf')) 57 | 58 | sigma = 0.1 59 | return weights.squeeze(0) + sigma * context_weights.squeeze(0) 60 | 61 | 62 | class PointerNet(nn.Module): 63 | def __init__(self, query_vec_size, src_encoding_size, attention_type='affine'): 64 | super(PointerNet, self).__init__() 65 | 66 | assert attention_type in ('affine', 'dot_prod') 67 | if attention_type == 'affine': 68 | self.src_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False) 69 | 70 | self.attention_type = attention_type 71 | self.input_linear = nn.Linear(query_vec_size, query_vec_size) 72 | self.type_linear = nn.Linear(32, query_vec_size) 73 | self.V = Parameter(torch.FloatTensor(query_vec_size), requires_grad=True) 74 | self.tanh = nn.Tanh() 75 | self.context_linear = nn.Conv1d(src_encoding_size, query_vec_size, 1, 1) 76 | self.coverage_linear = nn.Conv1d(1, query_vec_size, 1, 1) 77 | 78 | 79 | nn.init.uniform_(self.V, -1, 1) 80 | 81 | def forward(self, src_encodings, src_token_mask, query_vec): 82 | """ 83 | :param src_encodings: Variable(batch_size, src_sent_len, hidden_size * 2) 84 | :param src_token_mask: Variable(batch_size, src_sent_len) 85 | :param query_vec: Variable(tgt_action_num, batch_size, query_vec_size) 86 | :return: Variable(tgt_action_num, batch_size, src_sent_len) 87 | """ 88 | 89 | # (batch_size, 1, src_sent_len, query_vec_size) 90 | 91 | if self.attention_type == 'affine': 92 | src_encodings = self.src_encoding_linear(src_encodings) 93 | src_encodings = src_encodings.unsqueeze(1) 94 | 95 | # (batch_size, tgt_action_num, query_vec_size, 1) 96 | q = query_vec.permute(1, 0, 2).unsqueeze(3) 97 | 98 | weights = torch.matmul(src_encodings, q).squeeze(3) 99 | 100 | weights = weights.permute(1, 0, 2) 101 | 102 | if src_token_mask is not None: 103 | src_token_mask = src_token_mask.unsqueeze(0).expand_as(weights) 104 | weights.data.masked_fill_(src_token_mask.bool(), -float('inf')) 105 | 106 | return weights.squeeze(0) 107 | -------------------------------------------------------------------------------- /src/rule/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/5/27 7 | # @Author : Jiaqi&Zecheng 8 | # @File : __init__.py.py 9 | # @Software: PyCharm 10 | """ -------------------------------------------------------------------------------- /src/rule/graph.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/5/25 7 | # @Author : Jiaqi&Zecheng 8 | # @File : utils.py 9 | # @Software: PyCharm 10 | """ 11 | 12 | from collections import deque, namedtuple 13 | 14 | 15 | # we'll use infinity as a default distance to nodes. 
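# --- Editor's note: illustrative usage sketch, not part of the original module ---
# The Graph class below stores directed, weighted edges and finds shortest
# paths with Dijkstra's algorithm. A minimal, hedged example (the __main__
# block at the bottom of this file runs a larger one):
#
#   g = Graph([('a', 'b', 2), ('b', 'c', 3)])   # edges given as (start, end, cost)
#   print(g.dijkstra('a', 'c'))                  # -> deque(['a', 'b', 'c'])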
16 | inf = float('inf') 17 | Edge = namedtuple('Edge', 'start, end, cost') 18 | 19 | 20 | def make_edge(start, end, cost=1): 21 | return Edge(start, end, cost) 22 | 23 | 24 | class Graph: 25 | def __init__(self, edges): 26 | # let's check that the data is right 27 | wrong_edges = [i for i in edges if len(i) not in [2, 3]] 28 | if wrong_edges: 29 | raise ValueError('Wrong edges data: {}'.format(wrong_edges)) 30 | 31 | self.edges = [make_edge(*edge) for edge in edges] 32 | 33 | @property 34 | def vertices(self): 35 | return set( 36 | # this piece of magic turns ([1,2], [3,4]) into [1, 2, 3, 4] 37 | # the set above makes it's elements unique. 38 | sum( 39 | ([edge.start, edge.end] for edge in self.edges), [] 40 | ) 41 | ) 42 | 43 | def get_node_pairs(self, n1, n2, both_ends=True): 44 | if both_ends: 45 | node_pairs = [[n1, n2], [n2, n1]] 46 | else: 47 | node_pairs = [[n1, n2]] 48 | return node_pairs 49 | 50 | def remove_edge(self, n1, n2, both_ends=True): 51 | node_pairs = self.get_node_pairs(n1, n2, both_ends) 52 | edges = self.edges[:] 53 | for edge in edges: 54 | if [edge.start, edge.end] in node_pairs: 55 | self.edges.remove(edge) 56 | 57 | def add_edge(self, n1, n2, cost=1, both_ends=True): 58 | node_pairs = self.get_node_pairs(n1, n2, both_ends) 59 | for edge in self.edges: 60 | if [edge.start, edge.end] in node_pairs: 61 | return ValueError('Edge {} {} already exists'.format(n1, n2)) 62 | 63 | self.edges.append(Edge(start=n1, end=n2, cost=cost)) 64 | if both_ends: 65 | self.edges.append(Edge(start=n2, end=n1, cost=cost)) 66 | 67 | @property 68 | def neighbours(self): 69 | neighbours = {vertex: set() for vertex in self.vertices} 70 | for edge in self.edges: 71 | neighbours[edge.start].add((edge.end, edge.cost)) 72 | 73 | return neighbours 74 | 75 | def dijkstra(self, source, dest): 76 | assert source in self.vertices, 'Such source node doesn\'t exist' 77 | assert dest in self.vertices, 'Such source node doesn\'t exis' 78 | 79 | # 1. Mark all nodes unvisited and store them. 80 | # 2. Set the distance to zero for our initial node 81 | # and to infinity for other nodes. 82 | distances = {vertex: inf for vertex in self.vertices} 83 | previous_vertices = { 84 | vertex: None for vertex in self.vertices 85 | } 86 | distances[source] = 0 87 | vertices = self.vertices.copy() 88 | 89 | while vertices: 90 | # 3. Select the unvisited node with the smallest distance, 91 | # it's current node now. 92 | current_vertex = min( 93 | vertices, key=lambda vertex: distances[vertex]) 94 | 95 | # 6. Stop, if the smallest distance 96 | # among the unvisited nodes is infinity. 97 | if distances[current_vertex] == inf: 98 | break 99 | 100 | # 4. Find unvisited neighbors for the current node 101 | # and calculate their distances through the current node. 102 | for neighbour, cost in self.neighbours[current_vertex]: 103 | alternative_route = distances[current_vertex] + cost 104 | 105 | # Compare the newly calculated distance to the assigned 106 | # and save the smaller one. 107 | if alternative_route < distances[neighbour]: 108 | distances[neighbour] = alternative_route 109 | previous_vertices[neighbour] = current_vertex 110 | 111 | # 5. Mark the current node as visited 112 | # and remove it from the unvisited set. 
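            # (Editor-added note) Selecting the next node with min() over the
            # remaining set makes this O(V^2 + E), which is fine for small
            # graphs; a priority queue would scale better on large ones.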
113 | vertices.remove(current_vertex) 114 | 115 | path, current_vertex = deque(), dest 116 | while previous_vertices[current_vertex] is not None: 117 | path.appendleft(current_vertex) 118 | current_vertex = previous_vertices[current_vertex] 119 | if path: 120 | path.appendleft(current_vertex) 121 | return path 122 | 123 | 124 | if __name__ == '__main__': 125 | graph = Graph([ 126 | ("a", "b", 7), ("a", "c", 9), ("a", "f", 14), ("b", "c", 10), 127 | ("b", "d", 15), ("c", "d", 11), ("c", "f", 2), ("d", "e", 6), 128 | ("e", "f", 9)]) 129 | 130 | print(graph.dijkstra("a", "e")) -------------------------------------------------------------------------------- /src/rule/lf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/5/25 7 | # @Author : Jiaqi&Zecheng 8 | # @File : utils.py 9 | # @Software: PyCharm 10 | """ 11 | import copy 12 | import json 13 | 14 | import numpy as np 15 | 16 | from src.rule import semQL as define_rule 17 | from src.rule.semQL import Sup, Sel, Order, Root, Filter, A, N, C, T, Root1 18 | 19 | 20 | def _build_single_filter(lf, f): 21 | # No conjunction 22 | agg = lf.pop(0) 23 | column = lf.pop(0) 24 | if len(lf) == 0: 25 | table = None 26 | else: 27 | table = lf.pop(0) 28 | if not isinstance(table, define_rule.T): 29 | lf.insert(0, table) 30 | table = None 31 | assert isinstance(agg, define_rule.A) and isinstance(column, define_rule.C) 32 | if len(f.production.split()) == 3: 33 | f.add_children(agg) 34 | agg.set_parent(f) 35 | agg.add_children(column) 36 | column.set_parent(agg) 37 | if table is not None: 38 | column.add_children(table) 39 | table.set_parent(column) 40 | else: 41 | # Subquery 42 | f.add_children(agg) 43 | agg.set_parent(f) 44 | agg.add_children(column) 45 | column.set_parent(agg) 46 | if table is not None: 47 | column.add_children(table) 48 | table.set_parent(column) 49 | _root = _build(lf) 50 | f.add_children(_root) 51 | _root.set_parent(f) 52 | 53 | 54 | def _build_filter(lf, root_filter): 55 | assert isinstance(root_filter, define_rule.Filter) 56 | op = root_filter.production.split()[1] 57 | if op == 'and' or op == 'or': 58 | for i in range(2): 59 | child = lf.pop(0) 60 | op = child.production.split()[1] 61 | if op == 'and' or op == 'or': 62 | _f = _build_filter(lf, child) 63 | root_filter.add_children(_f) 64 | _f.set_parent(root_filter) 65 | else: 66 | _build_single_filter(lf, child) 67 | root_filter.add_children(child) 68 | child.set_parent(root_filter) 69 | else: 70 | _build_single_filter(lf, root_filter) 71 | return root_filter 72 | 73 | 74 | def _build(lf): 75 | root = lf.pop(0) 76 | assert isinstance(root, define_rule.Root) 77 | length = len(root.production.split()) - 1 78 | while len(root.children) != length: 79 | c_instance = lf.pop(0) 80 | if isinstance(c_instance, define_rule.Sel): 81 | sel_instance = c_instance 82 | root.add_children(sel_instance) 83 | sel_instance.set_parent(root) 84 | 85 | # define_rule.N 86 | c_instance = lf.pop(0) 87 | c_instance.set_parent(sel_instance) 88 | sel_instance.add_children(c_instance) 89 | assert isinstance(c_instance, define_rule.N) 90 | for i in range(c_instance.id_c + 1): 91 | agg = lf.pop(0) 92 | column = lf.pop(0) 93 | if len(lf) == 0: 94 | table = None 95 | else: 96 | table = lf.pop(0) 97 | if not isinstance(table, define_rule.T): 98 | lf.insert(0, table) 99 | table = None 100 | assert isinstance(agg, define_rule.A) and 
isinstance(column, define_rule.C) 101 | c_instance.add_children(agg) 102 | agg.set_parent(c_instance) 103 | agg.add_children(column) 104 | column.set_parent(agg) 105 | if table is not None: 106 | column.add_children(table) 107 | table.set_parent(column) 108 | 109 | elif isinstance(c_instance, define_rule.Sup) or isinstance(c_instance, define_rule.Order): 110 | root.add_children(c_instance) 111 | c_instance.set_parent(root) 112 | 113 | agg = lf.pop(0) 114 | column = lf.pop(0) 115 | if len(lf) == 0: 116 | table = None 117 | else: 118 | table = lf.pop(0) 119 | if not isinstance(table, define_rule.T): 120 | lf.insert(0, table) 121 | table = None 122 | assert isinstance(agg, define_rule.A) and isinstance(column, define_rule.C) 123 | c_instance.add_children(agg) 124 | agg.set_parent(c_instance) 125 | agg.add_children(column) 126 | column.set_parent(agg) 127 | if table is not None: 128 | column.add_children(table) 129 | table.set_parent(column) 130 | 131 | elif isinstance(c_instance, define_rule.Filter): 132 | _build_filter(lf, c_instance) 133 | root.add_children(c_instance) 134 | c_instance.set_parent(root) 135 | 136 | return root 137 | 138 | 139 | def build_tree(lf): 140 | root = lf.pop(0) 141 | assert isinstance(root, define_rule.Root1) 142 | if root.id_c == 0 or root.id_c == 1 or root.id_c == 2: 143 | root_1 = _build(lf) 144 | root_2 = _build(lf) 145 | root.add_children(root_1) 146 | root.add_children(root_2) 147 | root_1.set_parent(root) 148 | root_2.set_parent(root) 149 | else: 150 | root_1 = _build(lf) 151 | root.add_children(root_1) 152 | root_1.set_parent(root) 153 | verify(root) 154 | # eliminate_parent(root) 155 | 156 | 157 | def eliminate_parent(node): 158 | for child in node.children: 159 | eliminate_parent(child) 160 | node.children = list() 161 | 162 | 163 | def verify(node): 164 | if isinstance(node, C) and len(node.children) > 0: 165 | table = node.children[0] 166 | assert table is None or isinstance(table, T) 167 | if isinstance(node, T): 168 | return 169 | children_num = len(node.children) 170 | if isinstance(node, Root1): 171 | if node.id_c == 0 or node.id_c == 1 or node.id_c == 2: 172 | assert children_num == 2 173 | else: 174 | assert children_num == 1 175 | elif isinstance(node, Root): 176 | assert children_num == len(node.production.split()) - 1 177 | elif isinstance(node, N): 178 | assert children_num == int(node.id_c) + 1 179 | elif isinstance(node, Sup) or isinstance(node, Order) or isinstance(node, Sel): 180 | assert children_num == 1 181 | elif isinstance(node, Filter): 182 | op = node.production.split()[1] 183 | if op == 'and' or op == 'or': 184 | assert children_num == 2 185 | else: 186 | if len(node.production.split()) == 3: 187 | assert children_num == 1 188 | else: 189 | assert children_num == 2 190 | for child in node.children: 191 | assert child.parent == node 192 | verify(child) 193 | 194 | 195 | def label_matrix(lf, matrix, node): 196 | nindex = lf.index(node) 197 | for child in node.children: 198 | if child not in lf: 199 | continue 200 | index = lf.index(child) 201 | matrix[nindex][index] = 1 202 | label_matrix(lf, matrix, child) 203 | 204 | 205 | def build_adjacency_matrix(lf, symmetry=False): 206 | _lf = list() 207 | for rule in lf: 208 | if isinstance(rule, A) or isinstance(rule, C) or isinstance(rule, T): 209 | continue 210 | _lf.append(rule) 211 | length = len(_lf) 212 | matrix = np.zeros((length, length,)) 213 | label_matrix(_lf, matrix, _lf[0]) 214 | if symmetry: 215 | matrix += matrix.T 216 | return matrix 217 | 218 | 219 | if __name__ == 
'__main__': 220 | with open(r'..\data\train.json', 'r') as f: 221 | data = json.load(f) 222 | for d in data: 223 | rule_label = [eval(x) for x in d['rule_label'].strip().split(' ')] 224 | print(d['question']) 225 | print(rule_label) 226 | build_tree(copy.copy(rule_label)) 227 | adjacency_matrix = build_adjacency_matrix(rule_label, symmetry=True) 228 | print(adjacency_matrix) 229 | print('===\n\n') 230 | -------------------------------------------------------------------------------- /src/rule/semQL.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/5/24 7 | # @Author : Jiaqi&Zecheng 8 | # @File : semQL.py 9 | # @Software: PyCharm 10 | """ 11 | 12 | Keywords = ['des', 'asc', 'and', 'or', 'sum', 'min', 'max', 'avg', 'none', '=', '!=', '<', '>', '<=', '>=', 'between', 'like', 'not_like'] + [ 13 | 'in', 'not_in', 'count', 'intersect', 'union', 'except' 14 | ] 15 | 16 | 17 | class Grammar(object): 18 | def __init__(self, is_sketch=False): 19 | self.begin = 0 20 | self.type_id = 0 21 | self.is_sketch = is_sketch 22 | self.prod2id = {} 23 | self.type2id = {} 24 | self._init_grammar(Sel) 25 | self._init_grammar(Root) 26 | self._init_grammar(Sup) 27 | self._init_grammar(Filter) 28 | self._init_grammar(Order) 29 | self._init_grammar(N) 30 | self._init_grammar(Root1) 31 | 32 | if not self.is_sketch: 33 | self._init_grammar(A) 34 | 35 | self._init_id2prod() 36 | self.type2id[C] = self.type_id 37 | self.type_id += 1 38 | self.type2id[T] = self.type_id 39 | 40 | def _init_grammar(self, Cls): 41 | """ 42 | get the production of class Cls 43 | :param Cls: 44 | :return: 45 | """ 46 | production = Cls._init_grammar() 47 | for p in production: 48 | self.prod2id[p] = self.begin 49 | self.begin += 1 50 | self.type2id[Cls] = self.type_id 51 | self.type_id += 1 52 | 53 | def _init_id2prod(self): 54 | self.id2prod = {} 55 | for key, value in self.prod2id.items(): 56 | self.id2prod[value] = key 57 | 58 | def get_production(self, Cls): 59 | return Cls._init_grammar() 60 | 61 | 62 | class Action(object): 63 | def __init__(self): 64 | self.pt = 0 65 | self.production = None 66 | self.children = list() 67 | 68 | def get_next_action(self, is_sketch=False): 69 | actions = list() 70 | for x in self.production.split(' ')[1:]: 71 | if x not in Keywords: 72 | rule_type = eval(x) 73 | if is_sketch: 74 | if rule_type is not A: 75 | actions.append(rule_type) 76 | else: 77 | actions.append(rule_type) 78 | return actions 79 | 80 | def set_parent(self, parent): 81 | self.parent = parent 82 | 83 | def add_children(self, child): 84 | self.children.append(child) 85 | 86 | 87 | class Root1(Action): 88 | def __init__(self, id_c, parent=None): 89 | super(Root1, self).__init__() 90 | self.parent = parent 91 | self.id_c = id_c 92 | self._init_grammar() 93 | self.production = self.grammar_dict[id_c] 94 | 95 | @classmethod 96 | def _init_grammar(self): 97 | # TODO: should add Root grammar to this 98 | self.grammar_dict = { 99 | 0: 'Root1 intersect Root Root', 100 | 1: 'Root1 union Root Root', 101 | 2: 'Root1 except Root Root', 102 | 3: 'Root1 Root', 103 | } 104 | self.production_id = {} 105 | for id_x, value in enumerate(self.grammar_dict.values()): 106 | self.production_id[value] = id_x 107 | 108 | return self.grammar_dict.values() 109 | 110 | def __str__(self): 111 | return 'Root1(' + str(self.id_c) + ')' 112 | 113 | def __repr__(self): 114 | return 'Root1(' + 
str(self.id_c) + ')' 115 | 116 | 117 | class Root(Action): 118 | def __init__(self, id_c, parent=None): 119 | super(Root, self).__init__() 120 | self.parent = parent 121 | self.id_c = id_c 122 | self._init_grammar() 123 | self.production = self.grammar_dict[id_c] 124 | 125 | @classmethod 126 | def _init_grammar(self): 127 | # TODO: should add Root grammar to this 128 | self.grammar_dict = { 129 | 0: 'Root Sel Sup Filter', 130 | 1: 'Root Sel Filter Order', 131 | 2: 'Root Sel Sup', 132 | 3: 'Root Sel Filter', 133 | 4: 'Root Sel Order', 134 | 5: 'Root Sel' 135 | } 136 | self.production_id = {} 137 | for id_x, value in enumerate(self.grammar_dict.values()): 138 | self.production_id[value] = id_x 139 | 140 | return self.grammar_dict.values() 141 | 142 | def __str__(self): 143 | return 'Root(' + str(self.id_c) + ')' 144 | 145 | def __repr__(self): 146 | return 'Root(' + str(self.id_c) + ')' 147 | 148 | 149 | class N(Action): 150 | """ 151 | Number of Columns 152 | """ 153 | def __init__(self, id_c, parent=None): 154 | super(N, self).__init__() 155 | self.parent = parent 156 | self.id_c = id_c 157 | self._init_grammar() 158 | self.production = self.grammar_dict[id_c] 159 | 160 | @classmethod 161 | def _init_grammar(self): 162 | self.grammar_dict = { 163 | 0: 'N A', 164 | 1: 'N A A', 165 | 2: 'N A A A', 166 | 3: 'N A A A A', 167 | 4: 'N A A A A A' 168 | } 169 | self.production_id = {} 170 | for id_x, value in enumerate(self.grammar_dict.values()): 171 | self.production_id[value] = id_x 172 | 173 | return self.grammar_dict.values() 174 | 175 | def __str__(self): 176 | return 'N(' + str(self.id_c) + ')' 177 | 178 | def __repr__(self): 179 | return 'N(' + str(self.id_c) + ')' 180 | 181 | class C(Action): 182 | """ 183 | Column 184 | """ 185 | def __init__(self, id_c, parent=None): 186 | super(C, self).__init__() 187 | self.parent = parent 188 | self.id_c = id_c 189 | self.production = 'C T' 190 | self.table = None 191 | 192 | def __str__(self): 193 | return 'C(' + str(self.id_c) + ')' 194 | 195 | def __repr__(self): 196 | return 'C(' + str(self.id_c) + ')' 197 | 198 | 199 | class T(Action): 200 | """ 201 | Table 202 | """ 203 | def __init__(self, id_c, parent=None): 204 | super(T, self).__init__() 205 | 206 | self.parent = parent 207 | self.id_c = id_c 208 | self.production = 'T min' 209 | self.table = None 210 | 211 | def __str__(self): 212 | return 'T(' + str(self.id_c) + ')' 213 | 214 | def __repr__(self): 215 | return 'T(' + str(self.id_c) + ')' 216 | 217 | 218 | class A(Action): 219 | """ 220 | Aggregator 221 | """ 222 | def __init__(self, id_c, parent=None): 223 | super(A, self).__init__() 224 | 225 | self.parent = parent 226 | self.id_c = id_c 227 | self._init_grammar() 228 | self.production = self.grammar_dict[id_c] 229 | 230 | @classmethod 231 | def _init_grammar(self): 232 | # TODO: should add Root grammar to this 233 | self.grammar_dict = { 234 | 0: 'A none C', 235 | 1: 'A max C', 236 | 2: "A min C", 237 | 3: "A count C", 238 | 4: "A sum C", 239 | 5: "A avg C" 240 | } 241 | self.production_id = {} 242 | for id_x, value in enumerate(self.grammar_dict.values()): 243 | self.production_id[value] = id_x 244 | 245 | return self.grammar_dict.values() 246 | 247 | def __str__(self): 248 | return 'A(' + str(self.id_c) + ')' 249 | 250 | def __repr__(self): 251 | return 'A(' + str(self.grammar_dict[self.id_c].split(' ')[1]) + ')' 252 | 253 | 254 | class Sel(Action): 255 | """ 256 | Select 257 | """ 258 | def __init__(self, id_c, parent=None): 259 | super(Sel, self).__init__() 260 | 261 | self.parent = 
parent 262 | self.id_c = id_c 263 | self._init_grammar() 264 | self.production = self.grammar_dict[id_c] 265 | 266 | @classmethod 267 | def _init_grammar(self): 268 | self.grammar_dict = { 269 | 0: 'Sel N', 270 | } 271 | self.production_id = {} 272 | for id_x, value in enumerate(self.grammar_dict.values()): 273 | self.production_id[value] = id_x 274 | 275 | return self.grammar_dict.values() 276 | 277 | def __str__(self): 278 | return 'Sel(' + str(self.id_c) + ')' 279 | 280 | def __repr__(self): 281 | return 'Sel(' + str(self.id_c) + ')' 282 | 283 | class Filter(Action): 284 | """ 285 | Filter 286 | """ 287 | def __init__(self, id_c, parent=None): 288 | super(Filter, self).__init__() 289 | 290 | self.parent = parent 291 | self.id_c = id_c 292 | self._init_grammar() 293 | self.production = self.grammar_dict[id_c] 294 | 295 | @classmethod 296 | def _init_grammar(self): 297 | self.grammar_dict = { 298 | # 0: "Filter 1" 299 | 0: 'Filter and Filter Filter', 300 | 1: 'Filter or Filter Filter', 301 | 2: 'Filter = A', 302 | 3: 'Filter != A', 303 | 4: 'Filter < A', 304 | 5: 'Filter > A', 305 | 6: 'Filter <= A', 306 | 7: 'Filter >= A', 307 | 8: 'Filter between A', 308 | 9: 'Filter like A', 309 | 10: 'Filter not_like A', 310 | # now begin root 311 | 11: 'Filter = A Root', 312 | 12: 'Filter < A Root', 313 | 13: 'Filter > A Root', 314 | 14: 'Filter != A Root', 315 | 15: 'Filter between A Root', 316 | 16: 'Filter >= A Root', 317 | 17: 'Filter <= A Root', 318 | # now for In 319 | 18: 'Filter in A Root', 320 | 19: 'Filter not_in A Root' 321 | 322 | } 323 | self.production_id = {} 324 | for id_x, value in enumerate(self.grammar_dict.values()): 325 | self.production_id[value] = id_x 326 | 327 | return self.grammar_dict.values() 328 | 329 | def __str__(self): 330 | return 'Filter(' + str(self.id_c) + ')' 331 | 332 | def __repr__(self): 333 | return 'Filter(' + str(self.grammar_dict[self.id_c]) + ')' 334 | 335 | 336 | class Sup(Action): 337 | """ 338 | Superlative 339 | """ 340 | def __init__(self, id_c, parent=None): 341 | super(Sup, self).__init__() 342 | 343 | self.parent = parent 344 | self.id_c = id_c 345 | self._init_grammar() 346 | self.production = self.grammar_dict[id_c] 347 | 348 | @classmethod 349 | def _init_grammar(self): 350 | self.grammar_dict = { 351 | 0: 'Sup des A', 352 | 1: 'Sup asc A', 353 | } 354 | self.production_id = {} 355 | for id_x, value in enumerate(self.grammar_dict.values()): 356 | self.production_id[value] = id_x 357 | 358 | return self.grammar_dict.values() 359 | 360 | def __str__(self): 361 | return 'Sup(' + str(self.id_c) + ')' 362 | 363 | def __repr__(self): 364 | return 'Sup(' + str(self.id_c) + ')' 365 | 366 | 367 | class Order(Action): 368 | """ 369 | Order 370 | """ 371 | def __init__(self, id_c, parent=None): 372 | super(Order, self).__init__() 373 | 374 | self.parent = parent 375 | self.id_c = id_c 376 | self._init_grammar() 377 | self.production = self.grammar_dict[id_c] 378 | 379 | @classmethod 380 | def _init_grammar(self): 381 | self.grammar_dict = { 382 | 0: 'Order des A', 383 | 1: 'Order asc A', 384 | } 385 | self.production_id = {} 386 | for id_x, value in enumerate(self.grammar_dict.values()): 387 | self.production_id[value] = id_x 388 | 389 | return self.grammar_dict.values() 390 | 391 | def __str__(self): 392 | return 'Order(' + str(self.id_c) + ')' 393 | 394 | def __repr__(self): 395 | return 'Order(' + str(self.id_c) + ')' 396 | 397 | 398 | if __name__ == '__main__': 399 | print(list(Root._init_grammar())) 400 | 
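# --- Editor's note: illustrative usage sketch, not part of the original module ---
# A hedged example of how these action classes compose into a SemQL program.
# The hand-written sequence below (not taken from the dataset) roughly means:
# select a count over column 0 of table 0, filtered by column 2 being greater
# than some value.
#
#   actions = [Root1(3), Root(3), Sel(0), N(0), A(3), C(0), T(0),
#              Filter(5), A(0), C(2), T(0)]
#   for a in actions:
#       print(a, '->', a.production)
#   # Root1(3)  -> Root1 Root
#   # Root(3)   -> Root Sel Filter
#   # Sel(0)    -> Sel N
#   # N(0)      -> N A
#   # A(3)      -> A count C
#   # C(0)      -> C T
#   # T(0)      -> T min
#   # Filter(5) -> Filter > A
#   # ...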
-------------------------------------------------------------------------------- /src/rule/sem_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/5/27 7 | # @Author : Jiaqi&Zecheng 8 | # @File : sem_utils.py 9 | # @Software: PyCharm 10 | """ 11 | 12 | import os 13 | import json 14 | import argparse 15 | import re as regex 16 | from nltk.stem import WordNetLemmatizer 17 | from pattern.en import lemma 18 | wordnet_lemmatizer = WordNetLemmatizer() 19 | 20 | 21 | def load_dataSets(args): 22 | with open(args.input_path, 'r') as f: 23 | datas = json.load(f) 24 | with open(os.path.join(args.data_path, 'tables.json'), 'r', encoding='utf8') as f: 25 | table_datas = json.load(f) 26 | schemas = dict() 27 | for i in range(len(table_datas)): 28 | schemas[table_datas[i]['db_id']] = table_datas[i] 29 | return datas, schemas 30 | 31 | 32 | def partial_match(query, table_name): 33 | query = [lemma(x) for x in query] 34 | table_name = [lemma(x) for x in table_name] 35 | if query in table_name: 36 | return True 37 | return False 38 | 39 | 40 | def is_partial_match(query, table_names): 41 | query = lemma(query) 42 | table_names = [[lemma(x) for x in names.split(' ') ] for names in table_names] 43 | same_count = 0 44 | result = None 45 | for names in table_names: 46 | if query in names: 47 | same_count += 1 48 | result = names 49 | return result if same_count == 1 else False 50 | 51 | 52 | def multi_option(question, q_ind, names, N): 53 | for i in range(q_ind + 1, q_ind + N + 1): 54 | if i < len(question): 55 | re = is_partial_match(question[i][0], names) 56 | if re is not False: 57 | return re 58 | return False 59 | 60 | 61 | def multi_equal(question, q_ind, names, N): 62 | for i in range(q_ind + 1, q_ind + N + 1): 63 | if i < len(question): 64 | if question[i] == names: 65 | return i 66 | return False 67 | 68 | 69 | def random_choice(question_arg, question_arg_type, names, ground_col_labels, q_ind, N, origin_name): 70 | # first try if there are other table 71 | for t_ind, t_val in enumerate(question_arg_type): 72 | if t_val == ['table']: 73 | return names[origin_name.index(question_arg[t_ind])] 74 | for i in range(q_ind + 1, q_ind + N + 1): 75 | if i < len(question_arg): 76 | if len(ground_col_labels) == 0: 77 | for n in names: 78 | if partial_match(question_arg[i][0], n) is True: 79 | return n 80 | else: 81 | for n_id, n in enumerate(names): 82 | if n_id in ground_col_labels and partial_match(question_arg[i][0], n) is True: 83 | return n 84 | if len(ground_col_labels) > 0: 85 | return names[ground_col_labels[0]] 86 | else: 87 | return names[0] 88 | 89 | 90 | def find_table(cur_table, origin_table_names, question_arg_type, question_arg): 91 | h_table = None 92 | for i in range(len(question_arg_type))[::-1]: 93 | if question_arg_type[i] == ['table']: 94 | h_table = question_arg[i] 95 | h_table = origin_table_names.index(h_table) 96 | if h_table != cur_table: 97 | break 98 | if h_table != cur_table: 99 | return h_table 100 | 101 | # find partial 102 | for i in range(len(question_arg_type))[::-1]: 103 | if question_arg_type[i] == ['NONE']: 104 | for t_id, table_name in enumerate(origin_table_names): 105 | if partial_match(question_arg[i], table_name) is True and t_id != h_table: 106 | return t_id 107 | 108 | # random return 109 | for i in range(len(question_arg_type))[::-1]: 110 | if question_arg_type[i] == ['table']: 111 | h_table = 
question_arg[i] 112 | h_table = origin_table_names.index(h_table) 113 | return h_table 114 | 115 | return cur_table 116 | 117 | 118 | def alter_not_in(datas, schemas): 119 | for d in datas: 120 | if 'Filter(19)' in d['model_result']: 121 | current_table = schemas[d['db_id']] 122 | current_table['schema_content_clean'] = [x[1] for x in current_table['column_names']] 123 | current_table['col_table'] = [col[0] for col in current_table['column_names']] 124 | origin_table_names = [[wordnet_lemmatizer.lemmatize(x.lower()) for x in names.split(' ')] for names in 125 | d['table_names']] 126 | question_arg_type = d['question_arg_type'] 127 | question_arg = d['question_arg'] 128 | pred_label = d['model_result'].split(' ') 129 | 130 | # get potiantial table 131 | cur_table = None 132 | for label_id, label_val in enumerate(pred_label): 133 | if label_val in ['Filter(19)']: 134 | cur_table = int(pred_label[label_id - 1][2:-1]) 135 | break 136 | 137 | h_table = find_table(cur_table, origin_table_names, question_arg_type, question_arg) 138 | 139 | for label_id, label_val in enumerate(pred_label): 140 | if label_val in ['Filter(19)']: 141 | for primary in current_table['primary_keys']: 142 | if int(current_table['col_table'][primary]) == int(pred_label[label_id - 1][2:-1]): 143 | pred_label[label_id + 2] = 'C(' + str( 144 | d['col_set'].index(current_table['schema_content_clean'][primary])) + ')' 145 | break 146 | for pair in current_table['foreign_keys']: 147 | if int(current_table['col_table'][pair[0]]) == h_table and d['col_set'].index( 148 | current_table['schema_content_clean'][pair[1]]) == int(pred_label[label_id + 2][2:-1]): 149 | pred_label[label_id + 8] = 'C(' + str( 150 | d['col_set'].index(current_table['schema_content_clean'][pair[0]])) + ')' 151 | pred_label[label_id + 9] = 'T(' + str(h_table) + ')' 152 | break 153 | elif int(current_table['col_table'][pair[1]]) == h_table and d['col_set'].index( 154 | current_table['schema_content_clean'][pair[0]]) == int(pred_label[label_id + 2][2:-1]): 155 | pred_label[label_id + 8] = 'C(' + str( 156 | d['col_set'].index(current_table['schema_content_clean'][pair[1]])) + ')' 157 | pred_label[label_id + 9] = 'T(' + str(h_table) + ')' 158 | break 159 | pred_label[label_id + 3] = pred_label[label_id - 1] 160 | 161 | d['model_result'] = " ".join(pred_label) 162 | 163 | 164 | def alter_inter(datas): 165 | for d in datas: 166 | if 'Filter(0)' in d['model_result']: 167 | now_result = d['model_result'].split(' ') 168 | index = now_result.index('Filter(0)') 169 | c1 = None 170 | c2 = None 171 | for i in range(index + 1, len(now_result)): 172 | if c1 is None and 'C(' in now_result[i]: 173 | c1 = now_result[i] 174 | elif c1 is not None and c2 is None and 'C(' in now_result[i]: 175 | c2 = now_result[i] 176 | 177 | if c1 != c2 or c1 is None or c2 is None: 178 | continue 179 | replace_result = ['Root1(0)'] + now_result[1:now_result.index('Filter(0)')] 180 | for r_id, r_val in enumerate(now_result[now_result.index('Filter(0)') + 2:]): 181 | if 'Filter' in r_val: 182 | break 183 | 184 | replace_result = replace_result + now_result[now_result.index('Filter(0)') + 1:r_id + now_result.index( 185 | 'Filter(0)') + 2] 186 | replace_result = replace_result + now_result[1:now_result.index('Filter(0)')] 187 | 188 | replace_result = replace_result + now_result[r_id + now_result.index('Filter(0)') + 2:] 189 | replace_result = " ".join(replace_result) 190 | d['model_result'] = replace_result 191 | 192 | 193 | def alter_column0(datas): 194 | """ 195 | Attach column * table 196 | 
:return: model_result_replace 197 | """ 198 | zero_count = 0 199 | count = 0 200 | result = [] 201 | for d in datas: 202 | if 'C(0)' in d['model_result']: 203 | pattern = regex.compile('C\(.*?\) T\(.*?\)') 204 | result_pattern = list(set(pattern.findall(d['model_result']))) 205 | ground_col_labels = [] 206 | for pa in result_pattern: 207 | pa = pa.split(' ') 208 | if pa[0] != 'C(0)': 209 | index = int(pa[1][2:-1]) 210 | ground_col_labels.append(index) 211 | 212 | ground_col_labels = list(set(ground_col_labels)) 213 | question_arg_type = d['question_arg_type'] 214 | question_arg = d['question_arg'] 215 | table_names = [[lemma(x) for x in names.split(' ')] for names in d['table_names']] 216 | origin_table_names = [[wordnet_lemmatizer.lemmatize(x.lower()) for x in names.split(' ')] for names in 217 | d['table_names']] 218 | count += 1 219 | easy_flag = False 220 | for q_ind, q in enumerate(d['question_arg']): 221 | q = [lemma(x) for x in q] 222 | q_str = " ".join(" ".join(x) for x in d['question_arg']) 223 | if 'how many' in q_str or 'number of' in q_str or 'count of' in q_str: 224 | easy_flag = True 225 | if easy_flag: 226 | # check for the last one is a table word 227 | for q_ind, q in enumerate(d['question_arg']): 228 | if (q_ind > 0 and q == ['many'] and d['question_arg'][q_ind - 1] == ['how']) or ( 229 | q_ind > 0 and q == ['of'] and d['question_arg'][q_ind - 1] == ['number']) or ( 230 | q_ind > 0 and q == ['of'] and d['question_arg'][q_ind - 1] == ['count']): 231 | re = multi_equal(question_arg_type, q_ind, ['table'], 2) 232 | if re is not False: 233 | # This step work for the number of [table] example 234 | table_result = table_names[origin_table_names.index(question_arg[re])] 235 | result.append((d['query'], d['question'], table_result, d)) 236 | break 237 | else: 238 | re = multi_option(question_arg, q_ind, d['table_names'], 2) 239 | if re is not False: 240 | table_result = re 241 | result.append((d['query'], d['question'], table_result, d)) 242 | pass 243 | else: 244 | re = multi_equal(question_arg_type, q_ind, ['table'], len(question_arg_type)) 245 | if re is not False: 246 | # This step work for the number of [table] example 247 | table_result = table_names[origin_table_names.index(question_arg[re])] 248 | result.append((d['query'], d['question'], table_result, d)) 249 | break 250 | pass 251 | table_result = random_choice(question_arg=question_arg, 252 | question_arg_type=question_arg_type, 253 | names=table_names, 254 | ground_col_labels=ground_col_labels, q_ind=q_ind, N=2, 255 | origin_name=origin_table_names) 256 | result.append((d['query'], d['question'], table_result, d)) 257 | 258 | zero_count += 1 259 | break 260 | 261 | else: 262 | M_OP = False 263 | for q_ind, q in enumerate(d['question_arg']): 264 | if M_OP is False and q in [['than'], ['least'], ['most'], ['msot'], ['fewest']] or \ 265 | question_arg_type[q_ind] == ['M_OP']: 266 | M_OP = True 267 | re = multi_equal(question_arg_type, q_ind, ['table'], 3) 268 | if re is not False: 269 | # This step work for the number of [table] example 270 | table_result = table_names[origin_table_names.index(question_arg[re])] 271 | result.append((d['query'], d['question'], table_result, d)) 272 | break 273 | else: 274 | re = multi_option(question_arg, q_ind, d['table_names'], 3) 275 | if re is not False: 276 | table_result = re 277 | # print(table_result) 278 | result.append((d['query'], d['question'], table_result, d)) 279 | pass 280 | else: 281 | # zero_count += 1 282 | re = multi_equal(question_arg_type, q_ind, ['table'], 
len(question_arg_type)) 283 | if re is not False: 284 | # This step work for the number of [table] example 285 | table_result = table_names[origin_table_names.index(question_arg[re])] 286 | result.append((d['query'], d['question'], table_result, d)) 287 | break 288 | 289 | table_result = random_choice(question_arg=question_arg, 290 | question_arg_type=question_arg_type, 291 | names=table_names, 292 | ground_col_labels=ground_col_labels, q_ind=q_ind, N=2, 293 | origin_name=origin_table_names) 294 | result.append((d['query'], d['question'], table_result, d)) 295 | 296 | pass 297 | if M_OP is False: 298 | table_result = random_choice(question_arg=question_arg, 299 | question_arg_type=question_arg_type, 300 | names=table_names, ground_col_labels=ground_col_labels, q_ind=q_ind, 301 | N=2, 302 | origin_name=origin_table_names) 303 | result.append((d['query'], d['question'], table_result, d)) 304 | 305 | for re in result: 306 | table_names = [[lemma(x) for x in names.split(' ')] for names in re[3]['table_names']] 307 | origin_table_names = [[x for x in names.split(' ')] for names in re[3]['table_names']] 308 | if re[2] in table_names: 309 | re[3]['rule_count'] = table_names.index(re[2]) 310 | else: 311 | re[3]['rule_count'] = origin_table_names.index(re[2]) 312 | 313 | for data in datas: 314 | if 'rule_count' in data: 315 | str_replace = 'C(0) T(' + str(data['rule_count']) + ')' 316 | replace_result = regex.sub('C\(0\) T\(.\)', str_replace, data['model_result']) 317 | data['model_result_replace'] = replace_result 318 | else: 319 | data['model_result_replace'] = data['model_result'] 320 | 321 | 322 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
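# --- Editor's note: illustrative, not part of the original module ---
# load_word_emb() below expects a GloVe-style text file: one token per line,
# followed by its space-separated float vector, e.g. (values made up):
#
#   the 0.418 0.249 -0.412 ...
#   of  0.708 0.570 -0.471 ...
#
# It returns a dict mapping each token to a numpy array of floats.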
3 | 4 | # -*- coding: utf-8 -*- 5 | """ 6 | # @Time : 2019/5/25 7 | # @Author : Jiaqi&Zecheng 8 | # @File : utils.py 9 | # @Software: PyCharm 10 | """ 11 | 12 | import json 13 | import time 14 | 15 | import copy 16 | import numpy as np 17 | import os 18 | import torch 19 | from nltk.stem import WordNetLemmatizer 20 | 21 | from src.dataset import Example 22 | from src.rule import lf 23 | from src.rule.semQL import Sup, Sel, Order, Root, Filter, A, N, C, T, Root1 24 | 25 | wordnet_lemmatizer = WordNetLemmatizer() 26 | 27 | 28 | def load_word_emb(file_name, use_small=False): 29 | print ('Loading word embedding from %s'%file_name) 30 | ret = {} 31 | with open(file_name) as inf: 32 | for idx, line in enumerate(inf): 33 | if (use_small and idx >= 500000): 34 | break 35 | info = line.strip().split(' ') 36 | if info[0].lower() not in ret: 37 | ret[info[0]] = np.array(list(map(lambda x:float(x), info[1:]))) 38 | return ret 39 | 40 | def lower_keys(x): 41 | if isinstance(x, list): 42 | return [lower_keys(v) for v in x] 43 | elif isinstance(x, dict): 44 | return dict((k.lower(), lower_keys(v)) for k, v in x.items()) 45 | else: 46 | return x 47 | 48 | def get_table_colNames(tab_ids, tab_cols): 49 | table_col_dict = {} 50 | for ci, cv in zip(tab_ids, tab_cols): 51 | if ci != -1: 52 | table_col_dict[ci] = table_col_dict.get(ci, []) + cv 53 | result = [] 54 | for ci in range(len(table_col_dict)): 55 | result.append(table_col_dict[ci]) 56 | return result 57 | 58 | def get_col_table_dict(tab_cols, tab_ids, sql): 59 | table_dict = {} 60 | for c_id, c_v in enumerate(sql['col_set']): 61 | for cor_id, cor_val in enumerate(tab_cols): 62 | if c_v == cor_val: 63 | table_dict[tab_ids[cor_id]] = table_dict.get(tab_ids[cor_id], []) + [c_id] 64 | 65 | col_table_dict = {} 66 | for key_item, value_item in table_dict.items(): 67 | for value in value_item: 68 | col_table_dict[value] = col_table_dict.get(value, []) + [key_item] 69 | col_table_dict[0] = [x for x in range(len(table_dict) - 1)] 70 | return col_table_dict 71 | 72 | 73 | def schema_linking(question_arg, question_arg_type, one_hot_type, col_set_type, col_set_iter, sql): 74 | 75 | for count_q, t_q in enumerate(question_arg_type): 76 | t = t_q[0] 77 | if t == 'NONE': 78 | continue 79 | elif t == 'table': 80 | one_hot_type[count_q][0] = 1 81 | question_arg[count_q] = ['table'] + question_arg[count_q] 82 | elif t == 'col': 83 | one_hot_type[count_q][1] = 1 84 | try: 85 | col_set_type[col_set_iter.index(question_arg[count_q])][1] = 5 86 | question_arg[count_q] = ['column'] + question_arg[count_q] 87 | except: 88 | print(col_set_iter, question_arg[count_q]) 89 | raise RuntimeError("not in col set") 90 | elif t == 'agg': 91 | one_hot_type[count_q][2] = 1 92 | elif t == 'MORE': 93 | one_hot_type[count_q][3] = 1 94 | elif t == 'MOST': 95 | one_hot_type[count_q][4] = 1 96 | elif t == 'value': 97 | one_hot_type[count_q][5] = 1 98 | question_arg[count_q] = ['value'] + question_arg[count_q] 99 | else: 100 | if len(t_q) == 1: 101 | for col_probase in t_q: 102 | if col_probase == 'asd': 103 | continue 104 | try: 105 | col_set_type[sql['col_set'].index(col_probase)][2] = 5 106 | question_arg[count_q] = ['value'] + question_arg[count_q] 107 | except: 108 | print(sql['col_set'], col_probase) 109 | raise RuntimeError('not in col') 110 | one_hot_type[count_q][5] = 1 111 | else: 112 | for col_probase in t_q: 113 | if col_probase == 'asd': 114 | continue 115 | col_set_type[sql['col_set'].index(col_probase)][3] += 1 116 | 117 | def process(sql, table): 118 | 119 | process_dict = {} 
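    # (Editor-added note) This function gathers, for a single example, the
    # lemmatized question / column / table tokens plus the zero-initialized
    # schema-linking feature arrays (one_hot_type, col_set_type) that
    # schema_linking() above fills in later, and returns them as one dict
    # consumed by to_batch_seq() below.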
120 | 121 | origin_sql = sql['question_toks'] 122 | table_names = [[wordnet_lemmatizer.lemmatize(v).lower() for v in x.split(' ')] for x in table['table_names']] 123 | 124 | sql['pre_sql'] = copy.deepcopy(sql) 125 | 126 | tab_cols = [col[1] for col in table['column_names']] 127 | tab_ids = [col[0] for col in table['column_names']] 128 | 129 | col_set_iter = [[wordnet_lemmatizer.lemmatize(v).lower() for v in x.split(' ')] for x in sql['col_set']] 130 | col_iter = [[wordnet_lemmatizer.lemmatize(v).lower() for v in x.split(" ")] for x in tab_cols] 131 | q_iter_small = [wordnet_lemmatizer.lemmatize(x).lower() for x in origin_sql] 132 | question_arg = copy.deepcopy(sql['question_arg']) 133 | question_arg_type = sql['question_arg_type'] 134 | one_hot_type = np.zeros((len(question_arg_type), 6)) 135 | 136 | col_set_type = np.zeros((len(col_set_iter), 4)) 137 | 138 | process_dict['col_set_iter'] = col_set_iter 139 | process_dict['q_iter_small'] = q_iter_small 140 | process_dict['col_set_type'] = col_set_type 141 | process_dict['question_arg'] = question_arg 142 | process_dict['question_arg_type'] = question_arg_type 143 | process_dict['one_hot_type'] = one_hot_type 144 | process_dict['tab_cols'] = tab_cols 145 | process_dict['tab_ids'] = tab_ids 146 | process_dict['col_iter'] = col_iter 147 | process_dict['table_names'] = table_names 148 | 149 | return process_dict 150 | 151 | def is_valid(rule_label, col_table_dict, sql): 152 | try: 153 | lf.build_tree(copy.copy(rule_label)) 154 | except: 155 | print(rule_label) 156 | 157 | flag = False 158 | for r_id, rule in enumerate(rule_label): 159 | if type(rule) == C: 160 | try: 161 | assert rule_label[r_id + 1].id_c in col_table_dict[rule.id_c], print(sql['question']) 162 | except: 163 | flag = True 164 | print(sql['question']) 165 | return flag is False 166 | 167 | 168 | def to_batch_seq(sql_data, table_data, idxes, st, ed, 169 | is_train=True): 170 | """ 171 | 172 | :return: 173 | """ 174 | examples = [] 175 | 176 | for i in range(st, ed): 177 | sql = sql_data[idxes[i]] 178 | table = table_data[sql['db_id']] 179 | 180 | process_dict = process(sql, table) 181 | 182 | for c_id, col_ in enumerate(process_dict['col_set_iter']): 183 | for q_id, ori in enumerate(process_dict['q_iter_small']): 184 | if ori in col_: 185 | process_dict['col_set_type'][c_id][0] += 1 186 | 187 | schema_linking(process_dict['question_arg'], process_dict['question_arg_type'], 188 | process_dict['one_hot_type'], process_dict['col_set_type'], process_dict['col_set_iter'], sql) 189 | 190 | col_table_dict = get_col_table_dict(process_dict['tab_cols'], process_dict['tab_ids'], sql) 191 | table_col_name = get_table_colNames(process_dict['tab_ids'], process_dict['col_iter']) 192 | 193 | process_dict['col_set_iter'][0] = ['count', 'number', 'many'] 194 | 195 | rule_label = None 196 | if 'rule_label' in sql: 197 | try: 198 | rule_label = [eval(x) for x in sql['rule_label'].strip().split(' ')] 199 | except: 200 | continue 201 | if is_valid(rule_label, col_table_dict=col_table_dict, sql=sql) is False: 202 | continue 203 | 204 | example = Example( 205 | src_sent=process_dict['question_arg'], 206 | col_num=len(process_dict['col_set_iter']), 207 | vis_seq=(sql['question'], process_dict['col_set_iter'], sql['query']), 208 | tab_cols=process_dict['col_set_iter'], 209 | sql=sql['query'], 210 | one_hot_type=process_dict['one_hot_type'], 211 | col_hot_type=process_dict['col_set_type'], 212 | table_names=process_dict['table_names'], 213 | table_len=len(process_dict['table_names']), 214 | 
col_table_dict=col_table_dict, 215 | cols=process_dict['tab_cols'], 216 | table_col_name=table_col_name, 217 | table_col_len=len(table_col_name), 218 | tokenized_src_sent=process_dict['col_set_type'], 219 | tgt_actions=rule_label 220 | ) 221 | example.sql_json = copy.deepcopy(sql) 222 | examples.append(example) 223 | 224 | if is_train: 225 | examples.sort(key=lambda e: -len(e.src_sent)) 226 | return examples 227 | else: 228 | return examples 229 | 230 | def epoch_train(model, optimizer, batch_size, sql_data, table_data, 231 | args, epoch=0, loss_epoch_threshold=20, sketch_loss_coefficient=0.2): 232 | model.train() 233 | # shuffe 234 | perm=np.random.permutation(len(sql_data)) 235 | cum_loss = 0.0 236 | st = 0 237 | while st < len(sql_data): 238 | ed = st+batch_size if st+batch_size < len(perm) else len(perm) 239 | examples = to_batch_seq(sql_data, table_data, perm, st, ed) 240 | optimizer.zero_grad() 241 | 242 | score = model.forward(examples) 243 | loss_sketch = -score[0] 244 | loss_lf = -score[1] 245 | 246 | loss_sketch = torch.mean(loss_sketch) 247 | loss_lf = torch.mean(loss_lf) 248 | 249 | if epoch > loss_epoch_threshold: 250 | loss = loss_lf + sketch_loss_coefficient * loss_sketch 251 | else: 252 | loss = loss_lf + loss_sketch 253 | 254 | loss.backward() 255 | if args.clip_grad > 0.: 256 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad) 257 | optimizer.step() 258 | cum_loss += loss.data.cpu().numpy()*(ed - st) 259 | st = ed 260 | return cum_loss / len(sql_data) 261 | 262 | def epoch_acc(model, batch_size, sql_data, table_data, beam_size=3): 263 | model.eval() 264 | perm = list(range(len(sql_data))) 265 | st = 0 266 | 267 | json_datas = [] 268 | sketch_correct, rule_label_correct, total = 0, 0, 0 269 | while st < len(sql_data): 270 | ed = st+batch_size if st+batch_size < len(perm) else len(perm) 271 | examples = to_batch_seq(sql_data, table_data, perm, st, ed, 272 | is_train=False) 273 | for example in examples: 274 | results_all = model.parse(example, beam_size=beam_size) 275 | results = results_all[0] 276 | list_preds = [] 277 | try: 278 | 279 | pred = " ".join([str(x) for x in results[0].actions]) 280 | for x in results: 281 | list_preds.append(" ".join(str(x.actions))) 282 | except Exception as e: 283 | # print('Epoch Acc: ', e) 284 | # print(results) 285 | # print(results_all) 286 | pred = "" 287 | 288 | simple_json = example.sql_json['pre_sql'] 289 | 290 | simple_json['sketch_result'] = " ".join(str(x) for x in results_all[1]) 291 | simple_json['model_result'] = pred 292 | 293 | truth_sketch = " ".join([str(x) for x in example.sketch]) 294 | truth_rule_label = " ".join([str(x) for x in example.tgt_actions]) 295 | 296 | if truth_sketch == simple_json['sketch_result']: 297 | sketch_correct += 1 298 | if truth_rule_label == simple_json['model_result']: 299 | rule_label_correct += 1 300 | total += 1 301 | 302 | json_datas.append(simple_json) 303 | st = ed 304 | return json_datas, float(sketch_correct)/float(total), float(rule_label_correct)/float(total) 305 | 306 | def eval_acc(preds, sqls): 307 | sketch_correct, best_correct = 0, 0 308 | for i, (pred, sql) in enumerate(zip(preds, sqls)): 309 | if pred['model_result'] == sql['rule_label']: 310 | best_correct += 1 311 | print(best_correct / len(preds)) 312 | return best_correct / len(preds) 313 | 314 | 315 | def load_data_new(sql_path, table_data, use_small=False): 316 | sql_data = [] 317 | 318 | print("Loading data from %s" % sql_path) 319 | with open(sql_path) as inf: 320 | data = 
lower_keys(json.load(inf))
321 |         sql_data += data
322 | 
323 |     table_data_new = {table['db_id']: table for table in table_data}
324 | 
325 |     if use_small:
326 |         return sql_data[:80], table_data_new
327 |     else:
328 |         return sql_data, table_data_new
329 | 
330 | 
331 | def load_dataset(dataset_dir, use_small=False):
332 |     print("Loading from datasets...")
333 | 
334 |     TABLE_PATH = os.path.join(dataset_dir, "tables.json")
335 |     TRAIN_PATH = os.path.join(dataset_dir, "train.json")
336 |     DEV_PATH = os.path.join(dataset_dir, "dev.json")
337 |     with open(TABLE_PATH) as inf:
338 |         print("Loading data from %s" % TABLE_PATH)
339 |         table_data = json.load(inf)
340 | 
341 |     train_sql_data, train_table_data = load_data_new(TRAIN_PATH, table_data, use_small=use_small)
342 |     val_sql_data, val_table_data = load_data_new(DEV_PATH, table_data, use_small=use_small)
343 | 
344 |     return train_sql_data, train_table_data, val_sql_data, val_table_data
345 | 
346 | 
347 | def save_checkpoint(model, checkpoint_name):
348 |     torch.save(model.state_dict(), checkpoint_name)
349 | 
350 | 
351 | def save_args(args, path):
352 |     with open(path, 'w') as f:
353 |         f.write(json.dumps(vars(args), indent=4))
354 | 
355 | def init_log_checkpoint_path(args):
356 |     save_path = args.save
357 |     dir_name = save_path + str(int(time.time()))
358 |     save_path = os.path.join(os.path.curdir, 'saved_model', dir_name)
359 |     if os.path.exists(save_path) is False:
360 |         os.makedirs(save_path)
361 |     return save_path
362 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | 
4 | 
5 | 
6 | # -*- coding: utf-8 -*- 7 | """ 8 | # @Time : 2019/5/25 9 | # @Author : Jiaqi&Zecheng 10 | # @File : train.py 11 | # @Software: PyCharm 12 | """ 13 | 14 | import time 15 | import traceback 16 | 17 | import os 18 | import torch 19 | import torch.optim as optim 20 | import tqdm 21 | import copy 22 | 23 | from src import args as arg 24 | from src import utils 25 | from src.models.model import IRNet 26 | from src.rule import semQL 27 | 28 | 29 | def train(args): 30 | """ 31 | :param args: 32 | :return: 33 | """ 34 | 35 | grammar = semQL.Grammar() 36 | sql_data, table_data, val_sql_data,\ 37 | val_table_data= utils.load_dataset(args.dataset, use_small=args.toy) 38 | 39 | model = IRNet(args, grammar) 40 | 41 | 42 | if args.cuda: model.cuda() 43 | 44 | # now get the optimizer 45 | optimizer_cls = eval('torch.optim.%s' % args.optimizer) 46 | optimizer = optimizer_cls(model.parameters(), lr=args.lr) 47 | print('Enable Learning Rate Scheduler: ', args.lr_scheduler) 48 | if args.lr_scheduler: 49 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[21, 41], gamma=args.lr_scheduler_gammar) 50 | else: 51 | scheduler = None 52 | 53 | print('Loss epoch threshold: %d' % args.loss_epoch_threshold) 54 | print('Sketch loss coefficient: %f' % args.sketch_loss_coefficient) 55 | 56 | if args.load_model: 57 | print('load pretrained model from %s'% (args.load_model)) 58 | pretrained_model = torch.load(args.load_model, 59 | map_location=lambda storage, loc: storage) 60 | pretrained_modeled = copy.deepcopy(pretrained_model) 61 | for k in pretrained_model.keys(): 62 | if k not in model.state_dict().keys(): 63 | del pretrained_modeled[k] 64 | 65 | model.load_state_dict(pretrained_modeled) 66 | 67 | model.word_emb = utils.load_word_emb(args.glove_embed_path) 68 | # begin train 69 | 70 | model_save_path = utils.init_log_checkpoint_path(args) 71 | utils.save_args(args, os.path.join(model_save_path, 'config.json')) 72 | best_dev_acc = .0 73 | 74 | try: 75 | with open(os.path.join(model_save_path, 'epoch.log'), 'w') as epoch_fd: 76 | for epoch in tqdm.tqdm(range(args.epoch)): 77 | if args.lr_scheduler: 78 | scheduler.step() 79 | epoch_begin = time.time() 80 | loss = utils.epoch_train(model, optimizer, args.batch_size, sql_data, table_data, args, 81 | loss_epoch_threshold=args.loss_epoch_threshold, 82 | sketch_loss_coefficient=args.sketch_loss_coefficient) 83 | epoch_end = time.time() 84 | json_datas, sketch_acc, acc = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data, 85 | beam_size=args.beam_size) 86 | # acc = utils.eval_acc(json_datas, val_sql_data) 87 | 88 | if acc > best_dev_acc: 89 | utils.save_checkpoint(model, os.path.join(model_save_path, 'best_model.model')) 90 | best_dev_acc = acc 91 | utils.save_checkpoint(model, os.path.join(model_save_path, '{%s}_{%s}.model') % (epoch, acc)) 92 | 93 | log_str = 'Epoch: %d, Loss: %f, Sketch Acc: %f, Acc: %f, time: %f\n' % ( 94 | epoch + 1, loss, sketch_acc, acc, epoch_end - epoch_begin) 95 | tqdm.tqdm.write(log_str) 96 | epoch_fd.write(log_str) 97 | epoch_fd.flush() 98 | except Exception as e: 99 | # Save model 100 | utils.save_checkpoint(model, os.path.join(model_save_path, 'end_model.model')) 101 | print(e) 102 | tb = traceback.format_exc() 103 | print(tb) 104 | else: 105 | utils.save_checkpoint(model, os.path.join(model_save_path, 'end_model.model')) 106 | json_datas, sketch_acc, acc = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data, 107 | beam_size=args.beam_size) 108 | # acc = utils.eval_acc(json_datas, 
val_sql_data)
109 | 
110 |         print("Sketch Acc: %f, Acc: %f, Beam Acc: %f" % (sketch_acc, acc, acc,))
111 | 
112 | 
113 | if __name__ == '__main__':
114 |     arg_parser = arg.init_arg_parser()
115 |     args = arg.init_config(arg_parser)
116 |     print(args)
117 |     train(args)
118 | 
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) Microsoft Corporation.
3 | # Licensed under the MIT license.
4 | 
5 | 
6 | devices=$1
7 | save_name=$2
8 | 
9 | CUDA_VISIBLE_DEVICES=$devices python -u train.py --dataset ./data \
10 | --glove_embed_path ./data/glove.42B.300d.txt \
11 | --cuda \
12 | --epoch 50 \
13 | --loss_epoch_threshold 50 \
14 | --sketch_loss_coefficient 1.0 \
15 | --beam_size 1 \
16 | --seed 90 \
17 | --save ${save_name} \
18 | --embed_size 300 \
19 | --sentence_features \
20 | --column_pointer \
21 | --hidden_size 300 \
22 | --lr_scheduler \
23 | --lr_scheduler_gammar 0.5 \
24 | --att_vec_size 300 > ${save_name}".log"
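25 | 
26 | # Usage sketch (added for illustration, not part of the original script): run
27 | # training on GPU 0 and save checkpoints/logs under a run name, e.g.
28 | #   bash train.sh 0 irnet_run
29 | # "irnet_run" is a hypothetical name; checkpoints land under ./saved_model/ with a
30 | # timestamp suffix (see utils.init_log_checkpoint_path), and the best checkpoint can
31 | # later be reloaded through train.py's load_model option (assumed flag: --load_model).
--------------------------------------------------------------------------------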