├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── NOTICE
├── README.md
├── SECURITY.md
├── __init__.py
├── eval.py
├── eval.sh
├── preprocess
│   ├── data_process.py
│   ├── download_nltk.py
│   ├── run_me.sh
│   ├── sql2SemQL.py
│   └── utils.py
├── requirements.txt
├── sem2SQL.py
├── src
│   ├── __init__.py
│   ├── args.py
│   ├── beam.py
│   ├── dataset.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── basic_model.py
│   │   ├── model.py
│   │   ├── nn_utils.py
│   │   └── pointer_net.py
│   ├── rule
│   │   ├── __init__.py
│   │   ├── graph.py
│   │   ├── lf.py
│   │   ├── semQL.py
│   │   └── sem_utils.py
│   └── utils.py
├── train.py
└── train.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Ignore Visual Studio temporary files, build results, and
2 | ## files generated by popular Visual Studio add-ons.
3 | ##
4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
5 |
6 | # User-specific files
7 | *.suo
8 | *.user
9 | *.userosscache
10 | *.sln.docstates
11 |
12 | # User-specific files (MonoDevelop/Xamarin Studio)
13 | *.userprefs
14 |
15 | # Build results
16 | [Dd]ebug/
17 | [Dd]ebugPublic/
18 | [Rr]elease/
19 | [Rr]eleases/
20 | x64/
21 | x86/
22 | bld/
23 | [Bb]in/
24 | [Oo]bj/
25 | [Ll]og/
26 |
27 | # Visual Studio 2015/2017 cache/options directory
28 | .vs/
29 | # Uncomment if you have tasks that create the project's static files in wwwroot
30 | #wwwroot/
31 |
32 | # Visual Studio 2017 auto generated files
33 | Generated\ Files/
34 |
35 | # MSTest test Results
36 | [Tt]est[Rr]esult*/
37 | [Bb]uild[Ll]og.*
38 |
39 | # NUNIT
40 | *.VisualState.xml
41 | TestResult.xml
42 |
43 | # Build Results of an ATL Project
44 | [Dd]ebugPS/
45 | [Rr]eleasePS/
46 | dlldata.c
47 |
48 | # Benchmark Results
49 | BenchmarkDotNet.Artifacts/
50 |
51 | # .NET Core
52 | project.lock.json
53 | project.fragment.lock.json
54 | artifacts/
55 | **/Properties/launchSettings.json
56 |
57 | # StyleCop
58 | StyleCopReport.xml
59 |
60 | # Files built by Visual Studio
61 | *_i.c
62 | *_p.c
63 | *_i.h
64 | *.ilk
65 | *.meta
66 | *.obj
67 | *.iobj
68 | *.pch
69 | *.pdb
70 | *.ipdb
71 | *.pgc
72 | *.pgd
73 | *.rsp
74 | *.sbr
75 | *.tlb
76 | *.tli
77 | *.tlh
78 | *.tmp
79 | *.tmp_proj
80 | *.log
81 | *.vspscc
82 | *.vssscc
83 | .builds
84 | *.pidb
85 | *.svclog
86 | *.scc
87 |
88 | # Chutzpah Test files
89 | _Chutzpah*
90 |
91 | # Visual C++ cache files
92 | ipch/
93 | *.aps
94 | *.ncb
95 | *.opendb
96 | *.opensdf
97 | *.sdf
98 | *.cachefile
99 | *.VC.db
100 | *.VC.VC.opendb
101 |
102 | # Visual Studio profiler
103 | *.psess
104 | *.vsp
105 | *.vspx
106 | *.sap
107 |
108 | # Visual Studio Trace Files
109 | *.e2e
110 |
111 | # TFS 2012 Local Workspace
112 | $tf/
113 |
114 | # Guidance Automation Toolkit
115 | *.gpState
116 |
117 | # ReSharper is a .NET coding add-in
118 | _ReSharper*/
119 | *.[Rr]e[Ss]harper
120 | *.DotSettings.user
121 |
122 | # JustCode is a .NET coding add-in
123 | .JustCode
124 |
125 | # TeamCity is a build add-in
126 | _TeamCity*
127 |
128 | # DotCover is a Code Coverage Tool
129 | *.dotCover
130 |
131 | # AxoCover is a Code Coverage Tool
132 | .axoCover/*
133 | !.axoCover/settings.json
134 |
135 | # Visual Studio code coverage results
136 | *.coverage
137 | *.coveragexml
138 |
139 | # NCrunch
140 | _NCrunch_*
141 | .*crunch*.local.xml
142 | nCrunchTemp_*
143 |
144 | # MightyMoose
145 | *.mm.*
146 | AutoTest.Net/
147 |
148 | # Web workbench (sass)
149 | .sass-cache/
150 |
151 | # Installshield output folder
152 | [Ee]xpress/
153 |
154 | # DocProject is a documentation generator add-in
155 | DocProject/buildhelp/
156 | DocProject/Help/*.HxT
157 | DocProject/Help/*.HxC
158 | DocProject/Help/*.hhc
159 | DocProject/Help/*.hhk
160 | DocProject/Help/*.hhp
161 | DocProject/Help/Html2
162 | DocProject/Help/html
163 |
164 | # Click-Once directory
165 | publish/
166 |
167 | # Publish Web Output
168 | *.[Pp]ublish.xml
169 | *.azurePubxml
170 | # Note: Comment the next line if you want to checkin your web deploy settings,
171 | # but database connection strings (with potential passwords) will be unencrypted
172 | *.pubxml
173 | *.publishproj
174 |
175 | # Microsoft Azure Web App publish settings. Comment the next line if you want to
176 | # checkin your Azure Web App publish settings, but sensitive information contained
177 | # in these scripts will be unencrypted
178 | PublishScripts/
179 |
180 | # NuGet Packages
181 | *.nupkg
182 | # The packages folder can be ignored because of Package Restore
183 | **/[Pp]ackages/*
184 | # except build/, which is used as an MSBuild target.
185 | !**/[Pp]ackages/build/
186 | # Uncomment if necessary however generally it will be regenerated when needed
187 | #!**/[Pp]ackages/repositories.config
188 | # NuGet v3's project.json files produces more ignorable files
189 | *.nuget.props
190 | *.nuget.targets
191 |
192 | # Microsoft Azure Build Output
193 | csx/
194 | *.build.csdef
195 |
196 | # Microsoft Azure Emulator
197 | ecf/
198 | rcf/
199 |
200 | # Windows Store app package directories and files
201 | AppPackages/
202 | BundleArtifacts/
203 | Package.StoreAssociation.xml
204 | _pkginfo.txt
205 | *.appx
206 |
207 | # Visual Studio cache files
208 | # files ending in .cache can be ignored
209 | *.[Cc]ache
210 | # but keep track of directories ending in .cache
211 | !*.[Cc]ache/
212 |
213 | # Others
214 | ClientBin/
215 | ~$*
216 | *~
217 | *.dbmdl
218 | *.dbproj.schemaview
219 | *.jfm
220 | *.pfx
221 | *.publishsettings
222 | orleans.codegen.cs
223 |
224 | # Including strong name files can present a security risk
225 | # (https://github.com/github/gitignore/pull/2483#issue-259490424)
226 | #*.snk
227 |
228 | # Since there are multiple workflows, uncomment next line to ignore bower_components
229 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
230 | #bower_components/
231 |
232 | # RIA/Silverlight projects
233 | Generated_Code/
234 |
235 | # Backup & report files from converting an old project file
236 | # to a newer Visual Studio version. Backup files are not needed,
237 | # because we have git ;-)
238 | _UpgradeReport_Files/
239 | Backup*/
240 | UpgradeLog*.XML
241 | UpgradeLog*.htm
242 | ServiceFabricBackup/
243 | *.rptproj.bak
244 |
245 | # SQL Server files
246 | *.mdf
247 | *.ldf
248 | *.ndf
249 |
250 | # Business Intelligence projects
251 | *.rdl.data
252 | *.bim.layout
253 | *.bim_*.settings
254 | *.rptproj.rsuser
255 |
256 | # Microsoft Fakes
257 | FakesAssemblies/
258 |
259 | # GhostDoc plugin setting file
260 | *.GhostDoc.xml
261 |
262 | # Node.js Tools for Visual Studio
263 | .ntvs_analysis.dat
264 | node_modules/
265 |
266 | # Visual Studio 6 build log
267 | *.plg
268 |
269 | # Visual Studio 6 workspace options file
270 | *.opt
271 |
272 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
273 | *.vbw
274 |
275 | # Visual Studio LightSwitch build output
276 | **/*.HTMLClient/GeneratedArtifacts
277 | **/*.DesktopClient/GeneratedArtifacts
278 | **/*.DesktopClient/ModelManifest.xml
279 | **/*.Server/GeneratedArtifacts
280 | **/*.Server/ModelManifest.xml
281 | _Pvt_Extensions
282 |
283 | # Paket dependency manager
284 | .paket/paket.exe
285 | paket-files/
286 |
287 | # FAKE - F# Make
288 | .fake/
289 |
290 | # JetBrains Rider
291 | .idea/
292 | *.sln.iml
293 |
294 | # CodeRush
295 | .cr/
296 |
297 | # Python Tools for Visual Studio (PTVS)
298 | __pycache__/
299 | *.pyc
300 |
301 | # Cake - Uncomment if you are using it
302 | # tools/**
303 | # !tools/packages.config
304 |
305 | # Tabs Studio
306 | *.tss
307 |
308 | # Telerik's JustMock configuration file
309 | *.jmconfig
310 |
311 | # BizTalk build output
312 | *.btp.cs
313 | *.btm.cs
314 | *.odx.cs
315 | *.xsd.cs
316 |
317 | # OpenCover UI analysis results
318 | OpenCover/
319 |
320 | # Azure Stream Analytics local run output
321 | ASALocalRun/
322 |
323 | # MSBuild Binary and Structured Log
324 | *.binlog
325 |
326 | # NVidia Nsight GPU debugger configuration file
327 | *.nvuser
328 |
329 | # MFractors (Xamarin productivity tool) working folder
330 | .mfractor/
331 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | NOTICES AND INFORMATION
2 | Do Not Translate or Localize
3 |
4 | This software incorporates material from third parties. Microsoft makes certain
5 | open source code available at https://3rdpartysource.microsoft.com, or you may
6 | send a check or money order for US $5.00, including the product name, the open
7 | source component name, platform, and version number, to
8 |
9 | Source Code Compliance Team
10 | Microsoft Corporation
11 | One Microsoft Way
12 | Redmond, WA 98052
13 | USA
14 |
15 | Notwithstanding any other terms, you may reverse engineer this software to the
16 | extent required to debug changes to any libraries licensed under the GNU Lesser
17 | General Public License.
18 |
19 | =======================================================================
20 | Component. PyTorch
21 |
22 | Open Source License/Copyright Notice.
23 |
24 | Copyright (c) 2016- Facebook, Inc (Adam Paszke)
25 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
26 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
27 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
28 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
29 | Copyright (c) 2011-2013 NYU (Clement Farabet)
30 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
31 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
32 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
33 |
34 | From Caffe2
35 |
36 | Copyright (c) 2016-present, Facebook Inc. All rights reserved.
37 |
38 | All contributions by Facebook
39 | Copyright (c) 2016 Facebook Inc.
40 |
41 | All contributions by Google
42 | Copyright (c) 2015 Google Inc.
43 | All rights reserved.
44 |
45 | All contributions by Yangqing Jia
46 | Copyright (c) 2015 Yangqing Jia
47 | All rights reserved.
48 |
49 | All contributions from Caffe
50 | Copyright(c) 2013, 2014, 2015, the respective contributors
51 | All rights reserved.
52 |
53 | All other contributions
54 | Copyright(c) 2015, 2016 the respective contributors
55 | All rights reserved.
56 |
57 | Caffe2 uses a copyright model similar to Caffe: each contributor holds
58 | copyright over their contributions to Caffe2. The project versioning records
59 | all such contribution and copyright details. If a contributor wants to further
60 | mark their specific copyright on a particular contribution, they should
61 | indicate their copyright solely in the commit message of the change when it is
62 | committed.
63 |
64 | All rights reserved.
65 |
66 | Redistribution and use in source and binary forms, with or without
67 | modification, are permitted provided that the following conditions are met
68 |
69 | 1. Redistributions of source code must retain the above copyright
70 | notice, this list of conditions and the following disclaimer.
71 |
72 | 2. Redistributions in binary form must reproduce the above copyright
73 | notice, this list of conditions and the following disclaimer in the
74 | documentation and/or other materials provided with the distribution.
75 |
76 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
77 | and IDIAP Research Institute nor the names of its contributors may be
78 | used to endorse or promote products derived from this software without
79 | specific prior written permission.
80 |
81 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
82 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
83 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
84 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
85 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
86 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
87 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
88 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
89 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
90 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
91 | POSSIBILITY OF SUCH DAMAGE.
92 |
93 | =======================================================================
94 | Component. nltk
95 |
96 | Open Source License/Copyright Notice.
97 |
98 | Copyright (C) 2001-2019 NLTK Project
99 |
100 | Licensed under the Apache License, Version 2.0 (the 'License');
101 | you may not use this file except in compliance with the License.
102 | You may obtain a copy of the License at
103 |
104 | http://www.apache.org/licenses/LICENSE-2.0
105 |
106 | Unless required by applicable law or agreed to in writing, software
107 | distributed under the License is distributed on an 'AS IS' BASIS,
108 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
109 | See the License for the specific language governing permissions and
110 | limitations under the License.
111 |
112 | =======================================================================
113 | Component. Glove
114 |
115 | Open Source License/Copyright Notice.
116 |
117 | Preamble
118 | The Open Data Commons – Public Domain Dedication & Licence is a document intended to allow you to freely share, modify, and use this work for any purpose and without any restrictions. This licence is intended for use on databases or their contents (“data”), either together or individually.
119 | Many databases are covered by copyright. Some jurisdictions, mainly in Europe, have specific special rights that cover databases called the “sui generis” database right. Both of these sets of rights, as well as other legal rights used to protect databases and data, can create uncertainty or practical difficulty for those wishing to share databases and their underlying data but retain a limited amount of rights under a “some rights reserved” approach to licensing as outlined in the Science Commons Protocol for Implementing Open Access Data. As a result, this waiver and licence tries to the fullest extent possible to eliminate or fully license any rights that cover this database and data. Any Community Norms or similar statements of use of the database or data do not form a part of this document, and do not act as a contract for access or other terms of use for the database or data.
120 | The position of the recipient of the work
121 | Because this document places the database and its contents in or as close as possible within the public domain, there are no restrictions or requirements placed on the recipient by this document. Recipients may use this work commercially, use technical protection measures, combine this data or database with other databases or data, and share their changes and additions or keep them secret. It is not a requirement that recipients provide further users with a copy of this licence or attribute the original creator of the data or database as a source. The goal is to eliminate restrictions held by the original creator of the data and database on the use of it by others.
122 | The position of the dedicator of the work
123 | Copyright law, as with most other law under the banner of “intellectual property”, is inherently national law. This means that there exists several differences in how copyright and other IP rights can be relinquished, waived or licensed in the many legal jurisdictions of the world. This is despite much harmonisation of minimum levels of protection. The internet and other communication technologies span these many disparate legal jurisdictions and thus pose special difficulties for a document relinquishing and waiving intellectual property rights, including copyright and database rights, for use by the global community. Because of this feature of intellectual property law, this document first relinquishes the rights and waives the relevant rights and claims. It then goes on to license these same rights for jurisdictions or areas of law that may make it difficult to relinquish or waive rights or claims.
124 | The purpose of this document is to enable rightsholders to place their work into the public domain. Unlike licences for free and open source software, free cultural works, or open content licences, rightsholders will not be able to “dual license” their work by releasing the same work under different licences. This is because they have allowed anyone to use the work in whatever way they choose. Rightsholders therefore can’t re-license it under copyright or database rights on different terms because they have nothing left to license. Doing so creates truly accessible data to build rich applications and advance the progress of science and the arts.
125 | This document can cover either or both of the database and its contents (the data). Because databases can have a wide variety of content – not just factual data – rightsholders should use the Open Data Commons – Public Domain Dedication & Licence for an entire database and its contents only if everything can be placed under the terms of this document. Because even factual data can sometimes have intellectual property rights, rightsholders should use this licence to cover both the database and its factual data when making material available under this document; even if it is likely that the data would not be covered by copyright or database rights.
126 | Rightsholders can also use this document to cover any copyright or database rights claims over only a database, and leave the contents to be covered by other licences or documents. They can do this because this document refers to the “Work”, which can be either – or both – the database and its contents. As a result, rightsholders need to clearly state what they are dedicating under this document when they dedicate it.
127 | Just like any licence or other document dealing with intellectual property, rightsholders should be aware that one can only license what one owns. Please ensure that the rights have been cleared to make this material available under this document.
128 | This document permanently and irrevocably makes the Work available to the public for any use of any kind, and it should not be used unless the rightsholder is prepared for this to happen.
129 | Part I Introduction
130 | The Rightsholder (the Person holding rights or claims over the Work) agrees as follows
131 | 1.0 Definitions of Capitalised Words
132 | “Copyright” – Includes rights under copyright and under neighbouring rights and similarly related sets of rights under the law of the relevant jurisdiction under Section 6.4.
133 | “Data” – The contents of the Database, which includes the information, independent works, or other material collected into the Database offered under the terms of this Document.
134 | “Database” – A collection of Data arranged in a systematic or methodical way and individually accessible by electronic or other means offered under the terms of this Document.
135 | “Database Right” – Means rights over Data resulting from the Chapter III (“sui generis”) rights in the Database Directive (Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases) and any future updates as well as any similar rights available in the relevant jurisdiction under Section 6.4.
136 | “Document” – means this relinquishment and waiver of rights and claims and back up licence agreement.
137 | “Person” – Means a natural or legal person or a body of persons corporate or incorporate.
138 | “Use” – As a verb, means doing any act that is restricted by Copyright or Database Rights whether in the original medium or any other; and includes modifying the Work as may be technically necessary to use it in a different mode or format. This includes the right to sublicense the Work.
139 | “Work” – Means either or both of the Database and Data offered under the terms of this Document.
140 | “You” – the Person acquiring rights under the licence elements of this Document.
141 | Words in the singular include the plural and vice versa.
142 | 2.0 What this document covers
143 | 2.1. Legal effect of this Document. This Document is
144 | a. A dedication to the public domain and waiver of Copyright and Database Rights over the Work; and
145 | b. A licence of Copyright and Database Rights over the Work in jurisdictions that do not allow for relinquishment or waiver.
146 | 2.2. Legal rights covered.
147 | a. Copyright. Any copyright or neighbouring rights in the Work. Copyright law varies between jurisdictions, but is likely to cover the Database model or schema, which is the structure, arrangement, and organisation of the Database, and can also include the Database tables and table indexes; the data entry and output sheets; and the Field names of Data stored in the Database. Copyright may also cover the Data depending on the jurisdiction and type of Data; and
148 | b. Database Rights. Database Rights only extend to the extraction and re-utilisation of the whole or a substantial part of the Data. Database Rights can apply even when there is no copyright over the Database. Database Rights can also apply when the Data is removed from the Database and is selected and arranged in a way that would not infringe any applicable copyright.
149 | 2.2 Rights not covered.
150 | a. This Document does not apply to computer programs used in the making or operation of the Database;
151 | b. This Document does not cover any patents over the Data or the Database. Please see Section 4.2 later in this Document for further details; and
152 | c. This Document does not cover any trade marks associated with the Database. Please see Section 4.3 later in this Document for further details.
153 | Users of this Database are cautioned that they may have to clear other rights or consult other licences.
154 | 2.3 Facts are free. The Rightsholder takes the position that factual information is not covered by Copyright. This Document however covers the Work in jurisdictions that may protect the factual information in the Work by Copyright, and to cover any information protected by Copyright that is contained in the Work.
155 | Part II Dedication to the public domain
156 | 3.0 Dedication, waiver, and licence of Copyright and Database Rights
157 | 3.1 Dedication of Copyright and Database Rights to the public domain. The Rightsholder by using this Document, dedicates the Work to the public domain for the benefit of the public and relinquishes all rights in Copyright and Database Rights over the Work.
158 | a. The Rightsholder realises that once these rights are relinquished, that the Rightsholder has no further rights in Copyright and Database Rights over the Work, and that the Work is free and open for others to Use.
159 | b. The Rightsholder intends for their relinquishment to cover all present and future rights in the Work under Copyright and Database Rights, whether they are vested or contingent rights, and that this relinquishment of rights covers all their heirs and successors.
160 | The above relinquishment of rights applies worldwide and includes media and formats now known or created in the future.
161 | 3.2 Waiver of rights and claims in Copyright and Database Rights when Section 3.1 dedication inapplicable. If the dedication in Section 3.1 does not apply in the relevant jurisdiction under Section 6.4, the Rightsholder waives any rights and claims that the Rightsholder may have or acquire in the future over the Work in
162 | a. Copyright; and
163 | b. Database Rights.
164 | To the extent possible in the relevant jurisdiction, the above waiver of rights and claims applies worldwide and includes media and formats now known or created in the future. The Rightsholder agrees not to assert the above rights and waives the right to enforce them over the Work.
165 | 3.3 Licence of Copyright and Database Rights when Sections 3.1 and 3.2 inapplicable. If the dedication and waiver in Sections 3.1 and 3.2 does not apply in the relevant jurisdiction under Section 6.4, the Rightsholder and You agree as follows
166 | a. The Licensor grants to You a worldwide, royalty-free, non-exclusive, licence to Use the Work for the duration of any applicable Copyright and Database Rights. These rights explicitly include commercial use, and do not exclude any field of endeavour. To the extent possible in the relevant jurisdiction, these rights may be exercised in all media and formats whether now known or created in the future.
167 | 3.4 Moral rights. This section covers moral rights, including the right to be identified as the author of the Work or to object to treatment that would otherwise prejudice the author’s honour and reputation, or any other derogatory treatment
168 | a. For jurisdictions allowing waiver of moral rights, Licensor waives all moral rights that Licensor may have in the Work to the fullest extent possible by the law of the relevant jurisdiction under Section 6.4;
169 | b. If waiver of moral rights under Section 3.4 a in the relevant jurisdiction is not possible, Licensor agrees not to assert any moral rights over the Work and waives all claims in moral rights to the fullest extent possible by the law of the relevant jurisdiction under Section 6.4; and
170 | c. For jurisdictions not allowing waiver or an agreement not to assert moral rights under Section 3.4 a and b, the author may retain their moral rights over the copyrighted aspects of the Work.
171 | Please note that some jurisdictions do not allow for the waiver of moral rights, and so moral rights may still subsist over the work in some jurisdictions.
172 | 4.0 Relationship to other rights
173 | 4.1 No other contractual conditions. The Rightsholder makes this Work available to You without any other contractual obligations, either express or implied. Any Community Norms statement associated with the Work is not a contract and does not form part of this Document.
174 | 4.2 Relationship to patents. This Document does not grant You a licence for any patents that the Rightsholder may own. Users of this Database are cautioned that they may have to clear other rights or consult other licences.
175 | 4.3 Relationship to trade marks. This Document does not grant You a licence for any trade marks that the Rightsholder may own or that the Rightsholder may use to cover the Work. Users of this Database are cautioned that they may have to clear other rights or consult other licences.
176 | Part III General provisions
177 | 5.0 Warranties, disclaimer, and limitation of liability
178 | 5.1 The Work is provided by the Rightsholder “as is” and without any warranty of any kind, either express or implied, whether of title, of accuracy or completeness, of the presence or absence of errors, of fitness for purpose, or otherwise. Some jurisdictions do not allow the exclusion of implied warranties, so this exclusion may not apply to You.
179 | 5.2 Subject to any liability that may not be excluded or limited by law, the Rightsholder is not liable for, and expressly excludes, all liability for loss or damage however and whenever caused to anyone by any use under this Document, whether by You or by anyone else, and whether caused by any fault on the part of the Rightsholder or not. This exclusion of liability includes, but is not limited to, any special, incidental, consequential, punitive, or exemplary damages. This exclusion applies even if the Rightsholder has been advised of the possibility of such damages.
180 | 5.3 If liability may not be excluded by law, it is limited to actual and direct financial loss to the extent it is caused by proved negligence on the part of the Rightsholder.
181 | 6.0 General
182 | 6.1 If any provision of this Document is held to be invalid or unenforceable, that must not affect the validity or enforceability of the remainder of the terms of this Document.
183 | 6.2 This Document is the entire agreement between the parties with respect to the Work covered here. It replaces any earlier understandings, agreements or representations with respect to the Work not specified here.
184 | 6.3 This Document does not affect any rights that You or anyone else may independently have under any applicable law to make any use of this Work, including (for jurisdictions where this Document is a licence) fair dealing, fair use, database exceptions, or any other legally recognised limitation or exception to infringement of copyright or other applicable laws.
185 | 6.4 This Document takes effect in the relevant jurisdiction in which the Document terms are sought to be enforced. If the rights waived or granted under applicable law in the relevant jurisdiction includes additional rights not waived or granted under this Document, these additional rights are included in this Document in order to meet the intent of this Document.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # IRNet
2 | Code for our ACL'19 accepted paper: [Towards Complex Text-to-SQL in Cross-Domain Database with Intermediate Representation](https://arxiv.org/pdf/1905.08205.pdf)
3 |
4 |
5 |
6 |
7 |
8 | ## Environment Setup
9 |
10 | * `Python3.6`
11 | * `Pytorch 0.4.0` or higher
12 |
13 | Install the Python dependencies via `pip install -r requirements.txt` once Python and PyTorch are set up.
14 |
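A minimal setup sketch (the environment manager here is illustrative; any Python 3.6 environment works):

```bash
# create and activate an isolated Python 3.6 environment (conda is just one option)
conda create -n irnet python=3.6
conda activate irnet

# install PyTorch 0.4.0 or higher for your platform first, then the project requirements
pip install -r requirements.txt
```
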
15 | ## Running Code
16 |
17 | #### Data preparation
18 |
19 |
20 | * Download the [GloVe embeddings](https://nlp.stanford.edu/data/wordvecs/glove.42B.300d.zip) and put `glove.42B.300d` under the `./data/` directory
21 | * Download the [pretrained IRNet model](https://drive.google.com/open?id=1VoV28fneYss8HaZmoThGlvYU3A-aK31q) and put
22 | `IRNet_pretrained.model` under the `./saved_model/` directory
23 | * Download the preprocessed train/dev datasets from [here](https://drive.google.com/open?id=1YFV1GoLivOMlmunKW0nkzefKULO4wtrn) and put `train.json`, `dev.json` and
24 | `tables.json` under the `./data/` directory
25 |
26 | ##### Generating train/dev data by yourself
27 | You can also process the original [Spider data](https://drive.google.com/uc?export=download&id=11icoH_EA-NYb0OrPTdehRWm_d7-DIzWX) yourself. Download it, put `train.json`, `dev.json` and
28 | `tables.json` under the `./data/` directory, and follow the instructions in `./preprocess/` (see the sketch below).
29 |
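As a rough sketch of the preprocessing run (paths and output names are illustrative; `run_me.sh` reads the dataset file, the tables file, and an output name as its three positional arguments):

```bash
cd preprocess
# downloads the required NLTK data, runs data_process.py, then sql2SemQL.py
sh run_me.sh ../data/train.json ../data/tables.json train_processed.json
sh run_me.sh ../data/dev.json ../data/tables.json dev_processed.json
```
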
30 | #### Training
31 |
32 | Run `train.sh` to train IRNet.
33 |
34 | `sh train.sh [GPU_ID] [SAVE_FOLD]`
35 |
36 | #### Testing
37 |
38 | Run `eval.sh` to evaluate IRNet.
39 |
40 | `sh eval.sh [GPU_ID] [OUTPUT_FOLD]`
41 |
42 |
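For reference, `eval.sh` first runs `eval.py`, which writes the predicted SemQL logical forms to `predict_lf.json`, and then calls `sem2SQL.py` to convert them into SQL. Roughly (most flags omitted; see `eval.sh` for the full list):

```bash
# 1) decode SemQL logical forms with the pretrained model
python eval.py --dataset ./data --load_model ./saved_model/IRNet_pretrained.model ...
# 2) translate the logical forms into SQL
python sem2SQL.py --data_path ./data --input_path predict_lf.json --output_path [OUTPUT_FOLD]
```
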
43 | #### Evaluation
44 |
45 | You can follow the general evaluation process described on the [Spider page](https://github.com/taoyds/spider).
46 |
47 |
48 | ## Results
49 | | **Model** | Dev Exact Set Match Accuracy | Test Exact Set Match Accuracy |
50 | | ----------- | ------------------------------------- | -------------------------------------- |
51 | | IRNet | 53.2 | 46.7 |
52 | | IRNet+BERT(base) | 61.9 | **54.7** |
53 |
54 |
55 | ## Citation
56 |
57 | If you use IRNet, please cite the following work.
58 |
59 | ```
60 | @inproceedings{GuoIRNet2019,
61 | author={Jiaqi Guo and Zecheng Zhan and Yan Gao and Yan Xiao and Jian-Guang Lou and Ting Liu and Dongmei Zhang},
62 | title={Towards Complex Text-to-SQL in Cross-Domain Database with Intermediate Representation},
63 | booktitle={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics (ACL)},
64 | year={2019},
65 | organization={Association for Computational Linguistics}
66 | }
67 | ```
68 |
69 | ## Thanks
70 | We would like to thank [Tao Yu](https://taoyds.github.io/) and [Bo Pang](https://www.linkedin.com/in/bo-pang/) for running evaluations on our submitted models.
71 | We are also grateful to the flexible semantic parser [TranX](https://github.com/pcyin/tranX), which inspired our work.
72 |
73 | # Contributing
74 |
75 | This project welcomes contributions and suggestions. Most contributions require you to
76 | agree to a Contributor License Agreement (CLA) declaring that you have the right to,
77 | and actually do, grant us the rights to use your contribution. For details, visit
78 | https://cla.microsoft.com.
79 |
80 | When you submit a pull request, a CLA-bot will automatically determine whether you need
81 | to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the
82 | instructions provided by the bot. You will only need to do this once across all repositories using our CLA.
83 |
84 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
85 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
86 | or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
87 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/5/25
7 | # @Author : Jiaqi&Zecheng
8 | # @File : __init__.py
9 | # @Software: PyCharm
10 | """
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/5/27
7 | # @Author : Jiaqi&Zecheng
8 | # @File : eval.py
9 | # @Software: PyCharm
10 | """
11 |
12 |
13 | import torch
14 | from src import args as arg
15 | from src import utils
16 | from src.models.model import IRNet
17 | from src.rule import semQL
18 |
19 |
20 | def evaluate(args):
21 | """
22 | :param args:
23 | :return:
24 | """
25 |
26 | grammar = semQL.Grammar()
27 | sql_data, table_data, val_sql_data,\
28 | val_table_data= utils.load_dataset(args.dataset, use_small=args.toy)
29 |
30 | model = IRNet(args, grammar)
31 |
32 | if args.cuda: model.cuda()
33 |
34 | print('load pretrained model from %s'% (args.load_model))
35 | pretrained_model = torch.load(args.load_model,
36 | map_location=lambda storage, loc: storage)
37 | import copy
38 | pretrained_modeled = copy.deepcopy(pretrained_model)
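    # keep only the checkpoint parameters that also exist in the current model;
    # any extra keys in the checkpoint are dropped before load_state_dict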
39 | for k in pretrained_model.keys():
40 | if k not in model.state_dict().keys():
41 | del pretrained_modeled[k]
42 |
43 | model.load_state_dict(pretrained_modeled)
44 |
45 | model.word_emb = utils.load_word_emb(args.glove_embed_path)
46 |
47 | json_datas, sketch_acc, acc = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data,
48 | beam_size=args.beam_size)
49 | print('Sketch Acc: %f, Acc: %f' % (sketch_acc, acc))
50 | # utils.eval_acc(json_datas, val_sql_data)
51 | import json
52 | with open('./predict_lf.json', 'w') as f:
53 | json.dump(json_datas, f)
54 |
55 | if __name__ == '__main__':
56 | arg_parser = arg.init_arg_parser()
57 | args = arg.init_config(arg_parser)
58 | print(args)
59 | evaluate(args)
60 |
--------------------------------------------------------------------------------
/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) Microsoft Corporation.
3 | # Licensed under the MIT license.
4 |
5 |
6 | devices=$1
7 | save_name=$2
8 |
9 | CUDA_VISIBLE_DEVICES=$devices python -u eval.py --dataset ./data \
10 | --glove_embed_path ./data/glove.42B.300d.txt \
11 | --cuda \
12 | --epoch 50 \
13 | --loss_epoch_threshold 50 \
14 | --sketch_loss_coefficie 1.0 \
15 | --beam_size 5 \
16 | --seed 90 \
17 | --save ${save_name} \
18 | --embed_size 300 \
19 | --sentence_features \
20 | --column_pointer \
21 | --hidden_size 300 \
22 | --lr_scheduler \
23 | --lr_scheduler_gammar 0.5 \
24 | --att_vec_size 300 \
25 | --load_model ./saved_model/IRNet_pretrained.model
26 |
27 | python sem2SQL.py --data_path ./data --input_path predict_lf.json --output_path ${save_name}
28 |
--------------------------------------------------------------------------------
/preprocess/data_process.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/5/24
7 | # @Author : Jiaqi&Zecheng
8 | # @File : data_process.py
9 | # @Software: PyCharm
10 | """
11 | import json
12 | import argparse
13 | import nltk
14 | import os
15 | import pickle
16 | from utils import symbol_filter, re_lemma, fully_part_header, group_header, partial_header, num2year, group_symbol, group_values, group_digital
17 | from utils import AGG, wordnet_lemmatizer
18 | from utils import load_dataSets
19 |
20 | def process_datas(datas, args):
21 | """
22 |
23 | :param datas:
24 | :param args:
25 | :return:
26 | """
27 | with open(os.path.join(args.conceptNet, 'english_RelatedTo.pkl'), 'rb') as f:
28 | english_RelatedTo = pickle.load(f)
29 |
30 | with open(os.path.join(args.conceptNet, 'english_IsA.pkl'), 'rb') as f:
31 | english_IsA = pickle.load(f)
32 |
33 | # copy of the origin question_toks
34 | for d in datas:
35 | if 'origin_question_toks' not in d:
36 | d['origin_question_toks'] = d['question_toks']
37 |
38 | for entry in datas:
39 | entry['question_toks'] = symbol_filter(entry['question_toks'])
40 | origin_question_toks = symbol_filter([x for x in entry['origin_question_toks'] if x.lower() != 'the'])
41 | question_toks = [wordnet_lemmatizer.lemmatize(x.lower()) for x in entry['question_toks'] if x.lower() != 'the']
42 |
43 | entry['question_toks'] = question_toks
44 |
45 | table_names = []
46 | table_names_pattern = []
47 |
48 | for y in entry['table_names']:
49 | x = [wordnet_lemmatizer.lemmatize(x.lower()) for x in y.split(' ')]
50 | table_names.append(" ".join(x))
51 | x = [re_lemma(x.lower()) for x in y.split(' ')]
52 | table_names_pattern.append(" ".join(x))
53 |
54 | header_toks = []
55 | header_toks_list = []
56 |
57 | header_toks_pattern = []
58 | header_toks_list_pattern = []
59 |
60 | for y in entry['col_set']:
61 | x = [wordnet_lemmatizer.lemmatize(x.lower()) for x in y.split(' ')]
62 | header_toks.append(" ".join(x))
63 | header_toks_list.append(x)
64 |
65 | x = [re_lemma(x.lower()) for x in y.split(' ')]
66 | header_toks_pattern.append(" ".join(x))
67 | header_toks_list_pattern.append(x)
68 |
69 | num_toks = len(question_toks)
70 | idx = 0
71 | tok_concol = []
72 | type_concol = []
73 | nltk_result = nltk.pos_tag(question_toks)
74 |
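        # Greedily scan the question left to right: group consecutive tokens into spans (tok_concol)
        # and assign each span a coarse schema-linking type (type_concol): column, table, aggregation,
        # MORE/MOST comparatives, a ConceptNet-derived column label, value, or NONE.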
75 | while idx < num_toks:
76 |
77 | # fully header
78 | end_idx, header = fully_part_header(question_toks, idx, num_toks, header_toks)
79 | if header:
80 | tok_concol.append(question_toks[idx: end_idx])
81 | type_concol.append(["col"])
82 | idx = end_idx
83 | continue
84 |
85 | # check for table
86 | end_idx, tname = group_header(question_toks, idx, num_toks, table_names)
87 | if tname:
88 | tok_concol.append(question_toks[idx: end_idx])
89 | type_concol.append(["table"])
90 | idx = end_idx
91 | continue
92 |
93 | # check for column
94 | end_idx, header = group_header(question_toks, idx, num_toks, header_toks)
95 | if header:
96 | tok_concol.append(question_toks[idx: end_idx])
97 | type_concol.append(["col"])
98 | idx = end_idx
99 | continue
100 |
101 | # check for partial column
102 | end_idx, tname = partial_header(question_toks, idx, header_toks_list)
103 | if tname:
104 | tok_concol.append(tname)
105 | type_concol.append(["col"])
106 | idx = end_idx
107 | continue
108 |
109 | # check for aggregation
110 | end_idx, agg = group_header(question_toks, idx, num_toks, AGG)
111 | if agg:
112 | tok_concol.append(question_toks[idx: end_idx])
113 | type_concol.append(["agg"])
114 | idx = end_idx
115 | continue
116 |
117 | if nltk_result[idx][1] == 'RBR' or nltk_result[idx][1] == 'JJR':
118 | tok_concol.append([question_toks[idx]])
119 | type_concol.append(['MORE'])
120 | idx += 1
121 | continue
122 |
123 | if nltk_result[idx][1] == 'RBS' or nltk_result[idx][1] == 'JJS':
124 | tok_concol.append([question_toks[idx]])
125 | type_concol.append(['MOST'])
126 | idx += 1
127 | continue
128 |
129 | # string match for Time Format
130 | if num2year(question_toks[idx]):
131 | question_toks[idx] = 'year'
132 | end_idx, header = group_header(question_toks, idx, num_toks, header_toks)
133 | if header:
134 | tok_concol.append(question_toks[idx: end_idx])
135 | type_concol.append(["col"])
136 | idx = end_idx
137 | continue
138 |
139 | def get_concept_result(toks, graph):
140 | for begin_id in range(0, len(toks)):
141 | for r_ind in reversed(range(1, len(toks) + 1 - begin_id)):
142 | tmp_query = "_".join(toks[begin_id:r_ind])
143 | if tmp_query in graph:
144 | mi = graph[tmp_query]
145 | for col in entry['col_set']:
146 | if col in mi:
147 | return col
148 |
149 | end_idx, symbol = group_symbol(question_toks, idx, num_toks)
150 | if symbol:
151 | tmp_toks = [x for x in question_toks[idx: end_idx]]
152 | assert len(tmp_toks) > 0, print(symbol, question_toks)
153 | pro_result = get_concept_result(tmp_toks, english_IsA)
154 | if pro_result is None:
155 | pro_result = get_concept_result(tmp_toks, english_RelatedTo)
156 | if pro_result is None:
157 | pro_result = "NONE"
158 | for tmp in tmp_toks:
159 | tok_concol.append([tmp])
160 | type_concol.append([pro_result])
161 | pro_result = "NONE"
162 | idx = end_idx
163 | continue
164 |
165 | end_idx, values = group_values(origin_question_toks, idx, num_toks)
166 | if values and (len(values) > 1 or question_toks[idx - 1] not in ['?', '.']):
167 | tmp_toks = [wordnet_lemmatizer.lemmatize(x) for x in question_toks[idx: end_idx] if x.isalnum() is True]
168 | assert len(tmp_toks) > 0, print(question_toks[idx: end_idx], values, question_toks, idx, end_idx)
169 | pro_result = get_concept_result(tmp_toks, english_IsA)
170 | if pro_result is None:
171 | pro_result = get_concept_result(tmp_toks, english_RelatedTo)
172 | if pro_result is None:
173 | pro_result = "NONE"
174 | for tmp in tmp_toks:
175 | tok_concol.append([tmp])
176 | type_concol.append([pro_result])
177 | pro_result = "NONE"
178 | idx = end_idx
179 | continue
180 |
181 | result = group_digital(question_toks, idx)
182 | if result is True:
183 | tok_concol.append(question_toks[idx: idx + 1])
184 | type_concol.append(["value"])
185 | idx += 1
186 | continue
187 |             if question_toks[idx] == 'ha':
188 |                 question_toks[idx] = 'have'
189 |
190 | tok_concol.append([question_toks[idx]])
191 | type_concol.append(['NONE'])
192 | idx += 1
193 | continue
194 |
195 | entry['question_arg'] = tok_concol
196 | entry['question_arg_type'] = type_concol
197 | entry['nltk_pos'] = nltk_result
198 |
199 | return datas
200 |
201 |
202 | if __name__ == '__main__':
203 | arg_parser = argparse.ArgumentParser()
204 | arg_parser.add_argument('--data_path', type=str, help='dataset', required=True)
205 | arg_parser.add_argument('--table_path', type=str, help='table dataset', required=True)
206 | arg_parser.add_argument('--output', type=str, help='output data')
207 | args = arg_parser.parse_args()
208 | args.conceptNet = './conceptNet'
209 |
210 | # loading dataSets
211 | datas, table = load_dataSets(args)
212 |
213 | # process datasets
214 | process_result = process_datas(datas, args)
215 |
216 | with open(args.output, 'w') as f:
217 | json.dump(datas, f)
218 |
219 |
220 |
--------------------------------------------------------------------------------
/preprocess/download_nltk.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/1/29
7 | # @Author : Jiaqi&Zecheng
8 | # @File : download_nltk.py
9 | # @Software: PyCharm
10 | """
11 | import nltk
12 | nltk.download('averaged_perceptron_tagger')
13 | nltk.download('punkt')
14 | nltk.download('wordnet')
15 |
16 |
--------------------------------------------------------------------------------
/preprocess/run_me.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) Microsoft Corporation.
3 | # Licensed under the MIT license.
4 |
5 |
6 | data=$1
7 | table_data=$2
8 | output=$3
9 |
10 | echo "Start download NLTK data"
11 | python download_nltk.py
12 |
13 | echo "Start process the origin Spider dataset"
14 | python data_process.py --data_path ${data} --table_path ${table_data} --output "process_data.json"
15 |
16 | echo "Start generate SemQL from SQL"
17 | python sql2SemQL.py --data_path process_data.json --table_path ${table_data} --output ${data}
18 |
19 | rm process_data.json
20 |
--------------------------------------------------------------------------------
/preprocess/sql2SemQL.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/5/24
7 | # @Author : Jiaqi&Zecheng
8 | # @File : sql2SemQL.py
9 | # @Software: PyCharm
10 | """
11 |
12 | import argparse
13 | import json
14 | import sys
15 |
16 | import copy
17 | from utils import load_dataSets
18 |
19 | sys.path.append("..")
20 | from src.rule.semQL import Root1, Root, N, A, C, T, Sel, Sup, Filter, Order
21 |
22 | class Parser:
23 | def __init__(self):
24 | self.copy_selec = None
25 | self.sel_result = []
26 | self.colSet = set()
27 |
28 | def _init_rule(self):
29 | self.copy_selec = None
30 | self.colSet = set()
31 |
32 | def _parse_root(self, sql):
33 | """
34 | parsing the sql by the grammar
35 | R ::= Select | Select Filter | Select Order | ... |
36 | :return: [R(), states]
37 | """
38 | use_sup, use_ord, use_fil = True, True, False
39 |
40 | if sql['sql']['limit'] == None:
41 | use_sup = False
42 |
43 | if sql['sql']['orderBy'] == []:
44 | use_ord = False
45 | elif sql['sql']['limit'] != None:
46 | use_ord = False
47 |
48 | # check the where and having
49 | if sql['sql']['where'] != [] or \
50 | sql['sql']['having'] != []:
51 | use_fil = True
52 |
53 | if use_fil and use_sup:
54 | return [Root(0)], ['FILTER', 'SUP', 'SEL']
55 | elif use_fil and use_ord:
56 | return [Root(1)], ['ORDER', 'FILTER', 'SEL']
57 | elif use_sup:
58 | return [Root(2)], ['SUP', 'SEL']
59 | elif use_fil:
60 | return [Root(3)], ['FILTER', 'SEL']
61 | elif use_ord:
62 | return [Root(4)], ['ORDER', 'SEL']
63 | else:
64 | return [Root(5)], ['SEL']
65 |
66 | def _parser_column0(self, sql, select):
67 | """
68 | Find table of column '*'
69 | :return: T(table_id)
70 | """
71 | if len(sql['sql']['from']['table_units']) == 1:
72 | return T(sql['sql']['from']['table_units'][0][1])
73 | else:
74 | table_list = []
75 | for tmp_t in sql['sql']['from']['table_units']:
76 | if type(tmp_t[1]) == int:
77 | table_list.append(tmp_t[1])
78 | table_set, other_set = set(table_list), set()
79 | for sel_p in select:
80 | if sel_p[1][1][1] != 0:
81 | other_set.add(sql['col_table'][sel_p[1][1][1]])
82 |
83 | if len(sql['sql']['where']) == 1:
84 | other_set.add(sql['col_table'][sql['sql']['where'][0][2][1][1]])
85 | elif len(sql['sql']['where']) == 3:
86 | other_set.add(sql['col_table'][sql['sql']['where'][0][2][1][1]])
87 | other_set.add(sql['col_table'][sql['sql']['where'][2][2][1][1]])
88 | elif len(sql['sql']['where']) == 5:
89 | other_set.add(sql['col_table'][sql['sql']['where'][0][2][1][1]])
90 | other_set.add(sql['col_table'][sql['sql']['where'][2][2][1][1]])
91 | other_set.add(sql['col_table'][sql['sql']['where'][4][2][1][1]])
92 | table_set = table_set - other_set
93 | if len(table_set) == 1:
94 | return T(list(table_set)[0])
95 | elif len(table_set) == 0 and sql['sql']['groupBy'] != []:
96 | return T(sql['col_table'][sql['sql']['groupBy'][0][1]])
97 | else:
98 | question = sql['question']
99 | self.sel_result.append(question)
100 | print('column * table error')
101 | return T(sql['sql']['from']['table_units'][0][1])
102 |
103 | def _parse_select(self, sql):
104 | """
105 | parsing the sql by the grammar
106 | Select ::= A | AA | AAA | ... |
107 | A ::= agg column table
108 | :return: [Sel(), states]
109 | """
110 | result = []
111 | select = sql['sql']['select'][1]
112 | result.append(Sel(0))
113 | result.append(N(len(select) - 1))
114 |
115 | for sel in select:
116 | result.append(A(sel[0]))
117 | self.colSet.add(sql['col_set'].index(sql['names'][sel[1][1][1]]))
118 | result.append(C(sql['col_set'].index(sql['names'][sel[1][1][1]])))
119 | # now check for the situation with *
120 | if sel[1][1][1] == 0:
121 | result.append(self._parser_column0(sql, select))
122 | else:
123 | result.append(T(sql['col_table'][sel[1][1][1]]))
124 | if not self.copy_selec:
125 | self.copy_selec = [copy.deepcopy(result[-2]), copy.deepcopy(result[-1])]
126 |
127 | return result, None
128 |
129 | def _parse_sup(self, sql):
130 | """
131 | parsing the sql by the grammar
132 | Sup ::= Most A | Least A
133 | A ::= agg column table
134 | :return: [Sup(), states]
135 | """
136 | result = []
137 | select = sql['sql']['select'][1]
138 | if sql['sql']['limit'] == None:
139 | return result, None
140 | if sql['sql']['orderBy'][0] == 'desc':
141 | result.append(Sup(0))
142 | else:
143 | result.append(Sup(1))
144 |
145 | result.append(A(sql['sql']['orderBy'][1][0][1][0]))
146 | self.colSet.add(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]]))
147 | result.append(C(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]])))
148 | if sql['sql']['orderBy'][1][0][1][1] == 0:
149 | result.append(self._parser_column0(sql, select))
150 | else:
151 | result.append(T(sql['col_table'][sql['sql']['orderBy'][1][0][1][1]]))
152 | return result, None
153 |
154 | def _parse_filter(self, sql):
155 | """
156 | parsing the sql by the grammar
157 | Filter ::= and Filter Filter | ... |
158 | A ::= agg column table
159 | :return: [Filter(), states]
160 | """
161 | result = []
162 | # check the where
163 | if sql['sql']['where'] != [] and sql['sql']['having'] != []:
164 | result.append(Filter(0))
165 |
166 | if sql['sql']['where'] != []:
167 | # check the not and/or
168 | if len(sql['sql']['where']) == 1:
169 | result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql))
170 | elif len(sql['sql']['where']) == 3:
171 | if sql['sql']['where'][1] == 'or':
172 | result.append(Filter(1))
173 | else:
174 | result.append(Filter(0))
175 | result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql))
176 | result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql))
177 | else:
178 | if sql['sql']['where'][1] == 'and' and sql['sql']['where'][3] == 'and':
179 | result.append(Filter(0))
180 | result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql))
181 | result.append(Filter(0))
182 | result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql))
183 | result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql))
184 | elif sql['sql']['where'][1] == 'and' and sql['sql']['where'][3] == 'or':
185 | result.append(Filter(1))
186 | result.append(Filter(0))
187 | result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql))
188 | result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql))
189 | result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql))
190 | elif sql['sql']['where'][1] == 'or' and sql['sql']['where'][3] == 'and':
191 | result.append(Filter(1))
192 | result.append(Filter(0))
193 | result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql))
194 | result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql))
195 | result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql))
196 | else:
197 | result.append(Filter(1))
198 | result.append(Filter(1))
199 | result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql))
200 | result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql))
201 | result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql))
202 |
203 | # check having
204 | if sql['sql']['having'] != []:
205 | result.extend(self.parse_one_condition(sql['sql']['having'][0], sql['names'], sql))
206 | return result, None
207 |
208 | def _parse_order(self, sql):
209 | """
210 | Parse the SQL according to the grammar:
211 | Order ::= asc A | desc A
212 | A ::= agg column table
213 | :return: [Order(), states]
214 | """
215 | result = []
216 |
217 | if 'order' not in sql['query_toks_no_value'] or 'by' not in sql['query_toks_no_value']:
218 | return result, None
219 | elif 'limit' in sql['query_toks_no_value']:
220 | return result, None
221 | else:
222 | if sql['sql']['orderBy'] == []:
223 | return result, None
224 | else:
225 | select = sql['sql']['select'][1]
226 | if sql['sql']['orderBy'][0] == 'desc':
227 | result.append(Order(0))
228 | else:
229 | result.append(Order(1))
230 | result.append(A(sql['sql']['orderBy'][1][0][1][0]))
231 | self.colSet.add(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]]))
232 | result.append(C(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]])))
233 | if sql['sql']['orderBy'][1][0][1][1] == 0:
234 | result.append(self._parser_column0(sql, select))
235 | else:
236 | result.append(T(sql['col_table'][sql['sql']['orderBy'][1][0][1][1]]))
237 | return result, None
238 |
239 |
240 | def parse_one_condition(self, sql_condit, names, sql):
241 | result = []
242 | # check if V(root)
243 | nest_query = True
244 | if type(sql_condit[3]) != dict:
245 | nest_query = False
246 |
247 | if sql_condit[0] == True:
248 | if sql_condit[1] == 9:
249 | # not like only with values
250 | fil = Filter(10)
251 | elif sql_condit[1] == 8:
252 | # not in with Root
253 | fil = Filter(19)
254 | else:
255 | print(sql_condit[1])
256 | raise NotImplementedError("not implement for the others FIL")
257 | else:
258 | # check for Filter (<,=,>,!=,between, >=, <=, ...)
259 | single_map = {1:8,2:2,3:5,4:4,5:7,6:6,7:3}
260 | nested_map = {1:15,2:11,3:13,4:12,5:16,6:17,7:14}
261 | if sql_condit[1] in [1, 2, 3, 4, 5, 6, 7]:
262 | if nest_query == False:
263 | fil = Filter(single_map[sql_condit[1]])
264 | else:
265 | fil = Filter(nested_map[sql_condit[1]])
266 | elif sql_condit[1] == 9:
267 | fil = Filter(9)
268 | elif sql_condit[1] == 8:
269 | fil = Filter(18)
270 | else:
271 | print(sql_condit[1])
272 | raise NotImplementedError("not implement for the others FIL")
273 |
274 | result.append(fil)
275 | result.append(A(sql_condit[2][1][0]))
276 | self.colSet.add(sql['col_set'].index(sql['names'][sql_condit[2][1][1]]))
277 | result.append(C(sql['col_set'].index(sql['names'][sql_condit[2][1][1]])))
278 | if sql_condit[2][1][1] == 0:
279 | select = sql['sql']['select'][1]
280 | result.append(self._parser_column0(sql, select))
281 | else:
282 | result.append(T(sql['col_table'][sql_condit[2][1][1]]))
283 |
284 | # check for the nested value
285 | if type(sql_condit[3]) == dict:
286 | nest_query = {}
287 | nest_query['names'] = names
288 | nest_query['query_toks_no_value'] = ""
289 | nest_query['sql'] = sql_condit[3]
290 | nest_query['col_table'] = sql['col_table']
291 | nest_query['col_set'] = sql['col_set']
292 | nest_query['table_names'] = sql['table_names']
293 | nest_query['question'] = sql['question']
294 | nest_query['query'] = sql['query']
295 | nest_query['keys'] = sql['keys']
296 | result.extend(self.parser(nest_query))
297 |
298 | return result
299 |
300 | def _parse_step(self, state, sql):
301 |
302 | if state == 'ROOT':
303 | return self._parse_root(sql)
304 |
305 | if state == 'SEL':
306 | return self._parse_select(sql)
307 |
308 | elif state == 'SUP':
309 | return self._parse_sup(sql)
310 |
311 | elif state == 'FILTER':
312 | return self._parse_filter(sql)
313 |
314 | elif state == 'ORDER':
315 | return self._parse_order(sql)
316 | else:
317 | raise NotImplementedError("Not the right state")
318 |
319 | def full_parse(self, query):
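    | # Root1 ids used below: 0 = intersect, 1 = union, 2 = except, 3 = single query.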
320 | sql = query['sql']
321 | nest_query = {}
322 | nest_query['names'] = query['names']
323 | nest_query['query_toks_no_value'] = ""
324 | nest_query['col_table'] = query['col_table']
325 | nest_query['col_set'] = query['col_set']
326 | nest_query['table_names'] = query['table_names']
327 | nest_query['question'] = query['question']
328 | nest_query['query'] = query['query']
329 | nest_query['keys'] = query['keys']
330 |
331 | if sql['intersect']:
332 | results = [Root1(0)]
333 | nest_query['sql'] = sql['intersect']
334 | results.extend(self.parser(query))
335 | results.extend(self.parser(nest_query))
336 | return results
337 |
338 | if sql['union']:
339 | results = [Root1(1)]
340 | nest_query['sql'] = sql['union']
341 | results.extend(self.parser(query))
342 | results.extend(self.parser(nest_query))
343 | return results
344 |
345 | if sql['except']:
346 | results = [Root1(2)]
347 | nest_query['sql'] = sql['except']
348 | results.extend(self.parser(query))
349 | results.extend(self.parser(nest_query))
350 | return results
351 |
352 | results = [Root1(3)]
353 | results.extend(self.parser(query))
354 |
355 | return results
356 |
357 | def parser(self, query):
358 | stack = ["ROOT"]
359 | result = []
360 | while len(stack) > 0:
361 | state = stack.pop()
362 | step_result, step_state = self._parse_step(state, query)
363 | result.extend(step_result)
364 | if step_state:
365 | stack.extend(step_state)
366 | return result
367 |
368 | if __name__ == '__main__':
369 | arg_parser = argparse.ArgumentParser()
370 | arg_parser.add_argument('--data_path', type=str, help='dataset', required=True)
371 | arg_parser.add_argument('--table_path', type=str, help='table dataset', required=True)
372 | arg_parser.add_argument('--output', type=str, help='output data', required=True)
373 | args = arg_parser.parse_args()
374 |
375 | parser = Parser()
376 |
377 | # loading dataSets
378 | datas, table = load_dataSets(args)
379 | processed_data = []
380 |
381 | for i, d in enumerate(datas):
382 | if len(datas[i]['sql']['select'][1]) > 5:
383 | continue
384 | r = parser.full_parse(datas[i])
385 | datas[i]['rule_label'] = " ".join([str(x) for x in r])
386 | processed_data.append(datas[i])
387 |
388 | print('Finished %s examples and failed on %s examples' % (len(processed_data), len(datas) - len(processed_data)))
389 | with open(args.output, 'w', encoding='utf8') as f:
390 | f.write(json.dumps(processed_data))
391 |
392 |
--------------------------------------------------------------------------------
/preprocess/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/5/24
7 | # @Author : Jiaqi&Zecheng
8 | # @File : utils.py
9 | # @Software: PyCharm
10 | """
11 | import os
12 | import json
13 | from pattern.en import lemma
14 | from nltk.stem import WordNetLemmatizer
15 |
16 | VALUE_FILTER = ['what', 'how', 'list', 'give', 'show', 'find', 'id', 'order', 'when']
17 | AGG = ['average', 'sum', 'max', 'min', 'minimum', 'maximum', 'between']
18 |
19 | wordnet_lemmatizer = WordNetLemmatizer()
20 |
21 | def load_dataSets(args):
22 | with open(args.table_path, 'r', encoding='utf8') as f:
23 | table_datas = json.load(f)
24 | with open(args.data_path, 'r', encoding='utf8') as f:
25 | datas = json.load(f)
26 |
27 | output_tab = {}
28 | tables = {}
29 | table_name = set()
30 | for i in range(len(table_datas)):
31 | table = table_datas[i]
32 | temp = {}
33 | temp['col_map'] = table['column_names']
34 | temp['table_names'] = table['table_names']
35 | tmp_col = []
36 | for cc in [x[1] for x in table['column_names']]:
37 | if cc not in tmp_col:
38 | tmp_col.append(cc)
39 | table['col_set'] = tmp_col
40 | db_name = table['db_id']
41 | table_name.add(db_name)
42 | table['schema_content'] = [col[1] for col in table['column_names']]
43 | table['col_table'] = [col[0] for col in table['column_names']]
44 | output_tab[db_name] = temp
45 | tables[db_name] = table
46 |
47 | for d in datas:
48 | d['names'] = tables[d['db_id']]['schema_content']
49 | d['table_names'] = tables[d['db_id']]['table_names']
50 | d['col_set'] = tables[d['db_id']]['col_set']
51 | d['col_table'] = tables[d['db_id']]['col_table']
52 | keys = {}
53 | for kv in tables[d['db_id']]['foreign_keys']:
54 | keys[kv[0]] = kv[1]
55 | keys[kv[1]] = kv[0]
56 | for id_k in tables[d['db_id']]['primary_keys']:
57 | keys[id_k] = id_k
58 | d['keys'] = keys
59 | return datas, tables
60 |
61 | def group_header(toks, idx, num_toks, header_toks):
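    | # Illustrative (hypothetical inputs): group_header(['student', 'name', 'of'], 0, 3,
    | # ['student name', 'age']) returns (2, 'student name'); (idx, None) if nothing matches.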
62 | for endIdx in reversed(range(idx + 1, num_toks+1)):
63 | sub_toks = toks[idx: endIdx]
64 | sub_toks = " ".join(sub_toks)
65 | if sub_toks in header_toks:
66 | return endIdx, sub_toks
67 | return idx, None
68 |
69 | def fully_part_header(toks, idx, num_toks, header_toks):
70 | for endIdx in reversed(range(idx + 1, num_toks+1)):
71 | sub_toks = toks[idx: endIdx]
72 | if len(sub_toks) > 1:
73 | sub_toks = " ".join(sub_toks)
74 | if sub_toks in header_toks:
75 | return endIdx, sub_toks
76 | return idx, None
77 |
78 | def partial_header(toks, idx, header_toks):
79 | def check_in(list_one, list_two):
80 | if len(set(list_one) & set(list_two)) == len(list_one) and (len(list_two) <= 3):
81 | return True
82 | for endIdx in reversed(range(idx + 1, len(toks))):
83 | sub_toks = toks[idx: min(endIdx, len(toks))]
84 | if len(sub_toks) > 1:
85 | flag_count = 0
86 | tmp_heads = None
87 | for heads in header_toks:
88 | if check_in(sub_toks, heads):
89 | flag_count += 1
90 | tmp_heads = heads
91 | if flag_count == 1:
92 | return endIdx, tmp_heads
93 | return idx, None
94 |
95 | def symbol_filter(questions):
96 | question_tmp_q = []
97 | for q_id, q_val in enumerate(questions):
98 | if len(q_val) > 2 and q_val[0] in ["'", '"', '`', '‘', '’'] and q_val[-1] in ["'", '"', '`', '’']:
99 | question_tmp_q.append("'")
100 | question_tmp_q += ["".join(q_val[1:-1])]
101 | question_tmp_q.append("'")
102 | elif len(q_val) > 2 and q_val[0] in ["'", '"', '`', '‘'] :
103 | question_tmp_q.append("'")
104 | question_tmp_q += ["".join(q_val[1:])]
105 | elif len(q_val) > 2 and q_val[-1] in ["'", '"', '`', '’']:
106 | question_tmp_q += ["".join(q_val[0:-1])]
107 | question_tmp_q.append("'")
108 | elif q_val in ["'", '"', '`', '‘', '’', '``', "''"]:
109 | question_tmp_q += ["'"]
110 | else:
111 | question_tmp_q += [q_val]
112 | return question_tmp_q
113 |
114 |
115 | def group_values(toks, idx, num_toks):
116 | def check_isupper(tok_lists):
117 | for tok_one in tok_lists:
118 | if tok_one[0].isupper() is False:
119 | return False
120 | return True
121 |
122 | for endIdx in reversed(range(idx + 1, num_toks + 1)):
123 | sub_toks = toks[idx: endIdx]
124 |
125 | if len(sub_toks) > 1 and check_isupper(sub_toks) is True:
126 | return endIdx, sub_toks
127 | if len(sub_toks) == 1:
128 | if sub_toks[0][0].isupper() and sub_toks[0].lower() not in VALUE_FILTER and \
129 | sub_toks[0].lower().isalnum() is True:
130 | return endIdx, sub_toks
131 | return idx, None
132 |
133 |
134 | def group_digital(toks, idx):
135 | test = toks[idx].replace(':', '')
136 | test = test.replace('.', '')
137 | if test.isdigit():
138 | return True
139 | else:
140 | return False
141 |
142 | def group_symbol(toks, idx, num_toks):
143 | if toks[idx-1] == "'":
144 | for i in range(0, min(3, num_toks-idx)):
145 | if toks[i + idx] == "'":
146 | return i + idx, toks[idx:i+idx]
147 | return idx, None
148 |
149 |
150 | def num2year(tok):
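    | # Illustrative: num2year('1998') -> True, num2year('2021') -> True, num2year('1234') -> False.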
151 | if len(str(tok)) == 4 and str(tok).isdigit() and int(str(tok)[:2]) < 22 and int(str(tok)[:2]) > 15:
152 | return True
153 | return False
154 |
155 | def set_header(toks, header_toks, tok_concol, idx, num_toks):
156 | def check_in(list_one, list_two):
157 | if set(list_one) == set(list_two):
158 | return True
159 | for endIdx in range(idx, num_toks):
160 | toks += tok_concol[endIdx]
161 | if len(tok_concol[endIdx]) > 1:
162 | break
163 | for heads in header_toks:
164 | if check_in(toks, heads):
165 | return heads
166 | return None
167 |
168 | def re_lemma(string):
169 | lema = lemma(string.lower())
170 | if len(lema) > 0:
171 | return lema
172 | else:
173 | return string.lower()
174 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | nltk==3.4
5 | pattern
6 | numpy==1.14.0
7 | pytorch-pretrained-bert==0.5.1
8 | tqdm==4.31.1
--------------------------------------------------------------------------------
/sem2SQL.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/5/27
7 | # @Author : Jiaqi&Zecheng
8 | # @File : sem2SQL.py
9 | # @Software: PyCharm
10 | """
11 |
12 | import argparse
13 | import traceback
14 |
15 | from src.rule.graph import Graph
16 | from src.rule.semQL import Sup, Sel, Order, Root, Filter, A, N, C, T, Root1
17 | from src.rule.sem_utils import alter_inter, alter_not_in, alter_column0, load_dataSets
18 |
19 |
20 | def split_logical_form(lf):
21 | indexs = [i+1 for i, letter in enumerate(lf) if letter == ')']
22 | indexs.insert(0, 0)
23 | components = list()
24 | for i in range(1, len(indexs)):
25 | components.append(lf[indexs[i-1]:indexs[i]].strip())
26 | return components
27 |
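    | # Illustrative: split_logical_form('Root1(3) Root(5) Sel(0) N(0) A(3) C(0) T(0)')
    | # returns ['Root1(3)', 'Root(5)', 'Sel(0)', 'N(0)', 'A(3)', 'C(0)', 'T(0)'].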
28 |
29 | def pop_front(array):
30 | if len(array) == 0:
31 | return 'None'
32 | return array.pop(0)
33 |
34 |
35 | def is_end(components, transformed_sql, is_root_processed):
36 | end = False
37 | c = pop_front(components)
38 | c_instance = eval(c)
39 |
40 | if isinstance(c_instance, Root) and is_root_processed:
41 | # intersect, union, except
42 | end = True
43 | elif isinstance(c_instance, Filter):
44 | if 'where' not in transformed_sql:
45 | end = True
46 | else:
47 | num_conjunction = 0
48 | for f in transformed_sql['where']:
49 | if isinstance(f, str) and (f == 'and' or f == 'or'):
50 | num_conjunction += 1
51 | current_filters = len(transformed_sql['where'])
52 | valid_filters = current_filters - num_conjunction
53 | if valid_filters >= num_conjunction + 1:
54 | end = True
55 | elif isinstance(c_instance, Order):
56 | if 'order' not in transformed_sql:
57 | end = True
58 | elif len(transformed_sql['order']) == 0:
59 | end = False
60 | else:
61 | end = True
62 | elif isinstance(c_instance, Sup):
63 | if 'sup' not in transformed_sql:
64 | end = True
65 | elif len(transformed_sql['sup']) == 0:
66 | end = False
67 | else:
68 | end = True
69 | components.insert(0, c)
70 | return end
71 |
72 |
73 | def _transform(components, transformed_sql, col_set, table_names, schema):
74 | processed_root = False
75 | current_table = schema
76 |
77 | while len(components) > 0:
78 | if is_end(components, transformed_sql, processed_root):
79 | break
80 | c = pop_front(components)
81 | c_instance = eval(c)
82 | if isinstance(c_instance, Root):
83 | processed_root = True
84 | transformed_sql['select'] = list()
85 | if c_instance.id_c == 0:
86 | transformed_sql['where'] = list()
87 | transformed_sql['sup'] = list()
88 | elif c_instance.id_c == 1:
89 | transformed_sql['where'] = list()
90 | transformed_sql['order'] = list()
91 | elif c_instance.id_c == 2:
92 | transformed_sql['sup'] = list()
93 | elif c_instance.id_c == 3:
94 | transformed_sql['where'] = list()
95 | elif c_instance.id_c == 4:
96 | transformed_sql['order'] = list()
97 | elif isinstance(c_instance, Sel):
98 | continue
99 | elif isinstance(c_instance, N):
100 | for i in range(c_instance.id_c + 1):
101 | agg = eval(pop_front(components))
102 | column = eval(pop_front(components))
103 | _table = pop_front(components)
104 | table = eval(_table)
105 | if not isinstance(table, T):
106 | table = None
107 | components.insert(0, _table)
108 | assert isinstance(agg, A) and isinstance(column, C)
109 |
110 | transformed_sql['select'].append((
111 | agg.production.split()[1],
112 | replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table) if table is not None else col_set[column.id_c],
113 | table_names[table.id_c] if table is not None else table
114 | ))
115 |
116 | elif isinstance(c_instance, Sup):
117 | transformed_sql['sup'].append(c_instance.production.split()[1])
118 | agg = eval(pop_front(components))
119 | column = eval(pop_front(components))
120 | _table = pop_front(components)
121 | table = eval(_table)
122 | if not isinstance(table, T):
123 | table = None
124 | components.insert(0, _table)
125 | assert isinstance(agg, A) and isinstance(column, C)
126 |
127 | transformed_sql['sup'].append(agg.production.split()[1])
128 | if table:
129 | fix_col_id = replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table)
130 | else:
131 | fix_col_id = col_set[column.id_c]
132 | raise RuntimeError('not found table !!!!')
133 | transformed_sql['sup'].append(fix_col_id)
134 | transformed_sql['sup'].append(table_names[table.id_c] if table is not None else table)
135 |
136 | elif isinstance(c_instance, Order):
137 | transformed_sql['order'].append(c_instance.production.split()[1])
138 | agg = eval(pop_front(components))
139 | column = eval(pop_front(components))
140 | _table = pop_front(components)
141 | table = eval(_table)
142 | if not isinstance(table, T):
143 | table = None
144 | components.insert(0, _table)
145 | assert isinstance(agg, A) and isinstance(column, C)
146 | transformed_sql['order'].append(agg.production.split()[1])
147 | transformed_sql['order'].append(replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table))
148 | transformed_sql['order'].append(table_names[table.id_c] if table is not None else table)
149 |
150 | elif isinstance(c_instance, Filter):
151 | op = c_instance.production.split()[1]
152 | if op == 'and' or op == 'or':
153 | transformed_sql['where'].append(op)
154 | else:
155 | # No subquery
156 | agg = eval(pop_front(components))
157 | column = eval(pop_front(components))
158 | _table = pop_front(components)
159 | table = eval(_table)
160 | if not isinstance(table, T):
161 | table = None
162 | components.insert(0, _table)
163 | assert isinstance(agg, A) and isinstance(column, C)
164 | if len(c_instance.production.split()) == 3:
165 | if table:
166 | fix_col_id = replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table)
167 | else:
168 | fix_col_id = col_set[column.id_c]
169 | raise RuntimeError('not found table !!!!')
170 | transformed_sql['where'].append((
171 | op,
172 | agg.production.split()[1],
173 | fix_col_id,
174 | table_names[table.id_c] if table is not None else table,
175 | None
176 | ))
177 | else:
178 | # Subquery
179 | new_dict = dict()
180 | new_dict['sql'] = transformed_sql['sql']
181 | transformed_sql['where'].append((
182 | op,
183 | agg.production.split()[1],
184 | replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table),
185 | table_names[table.id_c] if table is not None else table,
186 | _transform(components, new_dict, col_set, table_names, schema)
187 | ))
188 |
189 | return transformed_sql
190 |
191 |
192 | def transform(query, schema, origin=None):
193 | preprocess_schema(schema)
194 | if origin is None:
195 | lf = query['model_result_replace']
196 | else:
197 | lf = origin
198 | # lf = query['rule_label']
199 | col_set = query['col_set']
200 | table_names = query['table_names']
201 | current_table = schema
202 |
203 | current_table['schema_content_clean'] = [x[1] for x in current_table['column_names']]
204 | current_table['schema_content'] = [x[1] for x in current_table['column_names_original']]
205 |
206 | components = split_logical_form(lf)
207 |
208 | transformed_sql = dict()
209 | transformed_sql['sql'] = query
210 | c = pop_front(components)
211 | c_instance = eval(c)
212 | assert isinstance(c_instance, Root1)
213 | if c_instance.id_c == 0:
214 | transformed_sql['intersect'] = dict()
215 | transformed_sql['intersect']['sql'] = query
216 |
217 | _transform(components, transformed_sql, col_set, table_names, schema)
218 | _transform(components, transformed_sql['intersect'], col_set, table_names, schema)
219 | elif c_instance.id_c == 1:
220 | transformed_sql['union'] = dict()
221 | transformed_sql['union']['sql'] = query
222 | _transform(components, transformed_sql, col_set, table_names, schema)
223 | _transform(components, transformed_sql['union'], col_set, table_names, schema)
224 | elif c_instance.id_c == 2:
225 | transformed_sql['except'] = dict()
226 | transformed_sql['except']['sql'] = query
227 | _transform(components, transformed_sql, col_set, table_names, schema)
228 | _transform(components, transformed_sql['except'], col_set, table_names, schema)
229 | else:
230 | _transform(components, transformed_sql, col_set, table_names, schema)
231 |
232 | parse_result = to_str(transformed_sql, 1, schema)
233 |
234 | parse_result = parse_result.replace('\t', '')
235 | return [parse_result]
236 |
237 | def col_to_str(agg, col, tab, table_names, N=1):
238 | _col = col.replace(' ', '_')
239 | if agg == 'none':
240 | if tab not in table_names:
241 | table_names[tab] = 'T' + str(len(table_names) + N)
242 | table_alias = table_names[tab]
243 | if col == '*':
244 | return '*'
245 | return '%s.%s' % (table_alias, _col)
246 | else:
247 | if col == '*':
248 | if tab is not None and tab not in table_names:
249 | table_names[tab] = 'T' + str(len(table_names) + N)
250 | return '%s(%s)' % (agg, _col)
251 | else:
252 | if tab not in table_names:
253 | table_names[tab] = 'T' + str(len(table_names) + N)
254 | table_alias = table_names[tab]
255 | return '%s(%s.%s)' % (agg, table_alias, _col)
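    | # Illustrative (hypothetical schema): col_to_str('none', 'name', 'singer', {}, 1)
    | # returns 'T1.name' (registering {'singer': 'T1'}), while
    | # col_to_str('count', '*', 'singer', {}, 1) returns 'count(*)'.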
256 |
257 |
258 | def infer_from_clause(table_names, schema, columns):
259 | tables = list(table_names.keys())
260 | # print(table_names)
261 | start_table = None
262 | end_table = None
263 | join_clause = list()
264 | if len(tables) == 1:
265 | join_clause.append((tables[0], table_names[tables[0]]))
266 | elif len(tables) == 2:
267 | use_graph = True
268 | # print(schema['graph'].vertices)
269 | for t in tables:
270 | if t not in schema['graph'].vertices:
271 | use_graph = False
272 | break
273 | if use_graph:
274 | start_table = tables[0]
275 | end_table = tables[1]
276 | _tables = list(schema['graph'].dijkstra(tables[0], tables[1]))
277 | # print('Two tables: ', _tables)
278 | max_key = 1
279 | for t, k in table_names.items():
280 | _k = int(k[1:])
281 | if _k > max_key:
282 | max_key = _k
283 | for t in _tables:
284 | if t not in table_names:
285 | table_names[t] = 'T' + str(max_key + 1)
286 | max_key += 1
287 | join_clause.append((t, table_names[t],))
288 | else:
289 | join_clause = list()
290 | for t in tables:
291 | join_clause.append((t, table_names[t],))
292 | else:
293 | # > 2
294 | # print('More than 2 table')
295 | for t in tables:
296 | join_clause.append((t, table_names[t],))
297 |
298 | if len(join_clause) >= 3:
299 | star_table = None
300 | for agg, col, tab in columns:
301 | if col == '*':
302 | star_table = tab
303 | break
304 | if star_table is not None:
305 | star_table_count = 0
306 | for agg, col, tab in columns:
307 | if tab == star_table and col != '*':
308 | star_table_count += 1
309 | if star_table_count == 0 and ((end_table is None or end_table == star_table) or (start_table is None or start_table == star_table)):
310 | # Remove star_table if the remaining tables can still be joined without it
311 | new_join_clause = list()
312 | for t in join_clause:
313 | if t[0] != star_table:
314 | new_join_clause.append(t)
315 | join_clause = new_join_clause
316 |
317 | join_clause = ' JOIN '.join(['%s AS %s' % (jc[0], jc[1]) for jc in join_clause])
318 | return 'FROM ' + join_clause
319 |
320 | def replace_col_with_original_col(query, col, current_table):
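    | # Maps a col_set column name plus its table name back to the original schema column,
    | # e.g. (illustrative) replace_col_with_original_col('name', 'singer', schema) -> 'Name'
    | # when the schema's column_names_original stores 'Name'; '*' is returned unchanged.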
321 | # print(query, col)
322 | if query == '*':
323 | return query
324 |
325 | cur_table = col
326 | cur_col = query
327 | single_final_col = None
328 | # print(query, col)
329 | for col_ind, col_name in enumerate(current_table['schema_content_clean']):
330 | if col_name == cur_col:
331 | assert cur_table in current_table['table_names']
332 | if current_table['table_names'][current_table['col_table'][col_ind]] == cur_table:
333 | single_final_col = current_table['column_names_original'][col_ind][1]
334 | break
335 |
336 | assert single_final_col
337 | # if query != single_final_col:
338 | # print(query, single_final_col)
339 | return single_final_col
340 |
341 |
342 | def build_graph(schema):
343 | relations = list()
344 | foreign_keys = schema['foreign_keys']
345 | for (fkey, pkey) in foreign_keys:
346 | fkey_table = schema['table_names_original'][schema['column_names'][fkey][0]]
347 | pkey_table = schema['table_names_original'][schema['column_names'][pkey][0]]
348 | relations.append((fkey_table, pkey_table))
349 | relations.append((pkey_table, fkey_table))
350 | return Graph(relations)
351 |
352 |
353 | def preprocess_schema(schema):
354 | tmp_col = []
355 | for cc in [x[1] for x in schema['column_names']]:
356 | if cc not in tmp_col:
357 | tmp_col.append(cc)
358 | schema['col_set'] = tmp_col
359 | # print table
360 | schema['schema_content'] = [col[1] for col in schema['column_names']]
361 | schema['col_table'] = [col[0] for col in schema['column_names']]
362 | graph = build_graph(schema)
363 | schema['graph'] = graph
364 |
365 |
366 |
367 |
368 | def to_str(sql_json, N_T, schema, pre_table_names=None):
369 | all_columns = list()
370 | select_clause = list()
371 | table_names = dict()
372 | current_table = schema
373 | for (agg, col, tab) in sql_json['select']:
374 | all_columns.append((agg, col, tab))
375 | select_clause.append(col_to_str(agg, col, tab, table_names, N_T))
376 | select_clause_str = 'SELECT ' + ', '.join(select_clause).strip()
377 |
378 | sup_clause = ''
379 | order_clause = ''
380 | direction_map = {"des": 'DESC', 'asc': 'ASC'}
381 |
382 | if 'sup' in sql_json:
383 | (direction, agg, col, tab,) = sql_json['sup']
384 | all_columns.append((agg, col, tab))
385 | subject = col_to_str(agg, col, tab, table_names, N_T)
386 | sup_clause = ('ORDER BY %s %s LIMIT 1' % (subject, direction_map[direction])).strip()
387 | elif 'order' in sql_json:
388 | (direction, agg, col, tab,) = sql_json['order']
389 | all_columns.append((agg, col, tab))
390 | subject = col_to_str(agg, col, tab, table_names, N_T)
391 | order_clause = ('ORDER BY %s %s' % (subject, direction_map[direction])).strip()
392 |
393 | has_group_by = False
394 | where_clause = ''
395 | have_clause = ''
396 | if 'where' in sql_json:
397 | conjunctions = list()
398 | filters = list()
399 | # print(sql_json['where'])
400 | for f in sql_json['where']:
401 | if isinstance(f, str):
402 | conjunctions.append(f)
403 | else:
404 | op, agg, col, tab, value = f
405 | if value:
406 | value['sql'] = sql_json['sql']
407 | all_columns.append((agg, col, tab))
408 | subject = col_to_str(agg, col, tab, table_names, N_T)
409 | if value is None:
410 | where_value = '1'
411 | if op == 'between':
412 | where_value = '1 AND 2'
413 | filters.append('%s %s %s' % (subject, op, where_value))
414 | else:
415 | if op == 'in' and len(value['select']) == 1 and value['select'][0][0] == 'none' \
416 | and 'where' not in value and 'order' not in value and 'sup' not in value:
417 | # and value['select'][0][2] not in table_names:
418 | if value['select'][0][2] not in table_names:
419 | table_names[value['select'][0][2]] = 'T' + str(len(table_names) + N_T)
420 | filters.append(None)
421 |
422 | else:
423 | filters.append('%s %s %s' % (subject, op, '(' + to_str(value, len(table_names) + 1, schema) + ')'))
424 | if len(conjunctions):
425 | filters.append(conjunctions.pop())
426 |
427 | aggs = ['count(', 'avg(', 'min(', 'max(', 'sum(']
428 | having_filters = list()
429 | idx = 0
430 | while idx < len(filters):
431 | _filter = filters[idx]
432 | if _filter is None:
433 | idx += 1
434 | continue
435 | for agg in aggs:
436 | if _filter.startswith(agg):
437 | having_filters.append(_filter)
438 | filters.pop(idx)
439 | # print(filters)
440 | if 0 < idx and (filters[idx - 1] in ['and', 'or']):
441 | filters.pop(idx - 1)
442 | # print(filters)
443 | break
444 | else:
445 | idx += 1
446 | if len(having_filters) > 0:
447 | have_clause = 'HAVING ' + ' '.join(having_filters).strip()
448 | if len(filters) > 0:
449 | # print(filters)
450 | filters = [_f for _f in filters if _f is not None]
451 | conjun_num = 0
452 | filter_num = 0
453 | for _f in filters:
454 | if _f in ['or', 'and']:
455 | conjun_num += 1
456 | else:
457 | filter_num += 1
458 | if conjun_num > 0 and filter_num != (conjun_num + 1):
459 | # assert 'and' in filters
460 | idx = 0
461 | while idx < len(filters):
462 | if filters[idx] == 'and':
463 | if idx - 1 == 0:
464 | filters.pop(idx)
465 | break
466 | if filters[idx - 1] in ['and', 'or']:
467 | filters.pop(idx)
468 | break
469 | if idx + 1 >= len(filters) - 1:
470 | filters.pop(idx)
471 | break
472 | if filters[idx + 1] in ['and', 'or']:
473 | filters.pop(idx)
474 | break
475 | idx += 1
476 | if len(filters) > 0:
477 | where_clause = 'WHERE ' + ' '.join(filters).strip()
478 | where_clause = where_clause.replace('not_in', 'NOT IN')
479 | else:
480 | where_clause = ''
481 |
482 | if len(having_filters) > 0:
483 | has_group_by = True
484 |
485 | for agg in ['count(', 'avg(', 'min(', 'max(', 'sum(']:
486 | if (len(sql_json['select']) > 1 and agg in select_clause_str)\
487 | or agg in sup_clause or agg in order_clause:
488 | has_group_by = True
489 | break
490 |
491 | group_by_clause = ''
492 | if has_group_by:
493 | if len(table_names) == 1:
494 | # check none agg
495 | is_agg_flag = False
496 | for (agg, col, tab) in sql_json['select']:
497 |
498 | if agg == 'none':
499 | group_by_clause = 'GROUP BY ' + col_to_str(agg, col, tab, table_names, N_T)
500 | else:
501 | is_agg_flag = True
502 |
503 | if is_agg_flag is False and len(group_by_clause) > 5:
504 | group_by_clause = "GROUP BY"
505 | for (agg, col, tab) in sql_json['select']:
506 | group_by_clause = group_by_clause + ' ' + col_to_str(agg, col, tab, table_names, N_T)
507 |
508 | if len(group_by_clause) < 5:
509 | if 'count(*)' in select_clause_str:
510 | current_table = schema
511 | for primary in current_table['primary_keys']:
512 | if current_table['table_names'][current_table['col_table'][primary]] in table_names :
513 | group_by_clause = 'GROUP BY ' + col_to_str('none', current_table['schema_content'][primary],
514 | current_table['table_names'][
515 | current_table['col_table'][primary]],
516 | table_names, N_T)
517 | else:
518 | # if only one select
519 | if len(sql_json['select']) == 1:
520 | agg, col, tab = sql_json['select'][0]
521 | non_lists = [tab]
522 | fix_flag = False
523 | # add tab from other part
524 | for key, value in table_names.items():
525 | if key not in non_lists:
526 | non_lists.append(key)
527 |
528 | a = non_lists[0]
529 | b = None
530 | for non in non_lists:
531 | if a != non:
532 | b = non
533 | if b:
534 | for pair in current_table['foreign_keys']:
535 | t1 = current_table['table_names'][current_table['col_table'][pair[0]]]
536 | t2 = current_table['table_names'][current_table['col_table'][pair[1]]]
537 | if t1 in [a, b] and t2 in [a, b]:
538 | if pre_table_names and t1 not in pre_table_names:
539 | assert t2 in pre_table_names
540 | t1 = t2
541 | group_by_clause = 'GROUP BY ' + col_to_str('none',
542 | current_table['schema_content'][pair[0]],
543 | t1,
544 | table_names, N_T)
545 | fix_flag = True
546 | break
547 |
548 | if fix_flag is False:
549 | agg, col, tab = sql_json['select'][0]
550 | group_by_clause = 'GROUP BY ' + col_to_str(agg, col, tab, table_names, N_T)
551 |
552 | else:
553 | # check if there are only one non agg
554 | non_agg, non_agg_count = None, 0
555 | non_lists = []
556 | for (agg, col, tab) in sql_json['select']:
557 | if agg == 'none':
558 | non_agg = (agg, col, tab)
559 | non_lists.append(tab)
560 | non_agg_count += 1
561 |
562 | non_lists = list(set(non_lists))
563 | # print(non_lists)
564 | if non_agg_count == 1:
565 | group_by_clause = 'GROUP BY ' + col_to_str(non_agg[0], non_agg[1], non_agg[2], table_names, N_T)
566 | elif non_agg:
567 | find_flag = False
568 | fix_flag = False
569 | find_primary = None
570 | if len(non_lists) <= 1:
571 | for key, value in table_names.items():
572 | if key not in non_lists:
573 | non_lists.append(key)
574 | if len(non_lists) > 1:
575 | a = non_lists[0]
576 | b = None
577 | for non in non_lists:
578 | if a != non:
579 | b = non
580 | if b:
581 | for pair in current_table['foreign_keys']:
582 | t1 = current_table['table_names'][current_table['col_table'][pair[0]]]
583 | t2 = current_table['table_names'][current_table['col_table'][pair[1]]]
584 | if t1 in [a, b] and t2 in [a, b]:
585 | if pre_table_names and t1 not in pre_table_names:
586 | assert t2 in pre_table_names
587 | t1 = t2
588 | group_by_clause = 'GROUP BY ' + col_to_str('none',
589 | current_table['schema_content'][pair[0]],
590 | t1,
591 | table_names, N_T)
592 | fix_flag = True
593 | break
594 | tab = non_agg[2]
595 | assert tab in current_table['table_names']
596 |
597 | for primary in current_table['primary_keys']:
598 | if current_table['table_names'][current_table['col_table'][primary]] == tab:
599 | find_flag = True
600 | find_primary = (current_table['schema_content'][primary], tab)
601 | if fix_flag is False:
602 | if find_flag is False:
603 | # rely on count *
604 | foreign = []
605 | for pair in current_table['foreign_keys']:
606 | if current_table['table_names'][current_table['col_table'][pair[0]]] == tab:
607 | foreign.append(pair[1])
608 | if current_table['table_names'][current_table['col_table'][pair[1]]] == tab:
609 | foreign.append(pair[0])
610 |
611 | for pair in foreign:
612 | if current_table['table_names'][current_table['col_table'][pair]] in table_names:
613 | group_by_clause = 'GROUP BY ' + col_to_str('none', current_table['schema_content'][pair],
614 | current_table['table_names'][current_table['col_table'][pair]],
615 | table_names, N_T)
616 | find_flag = True
617 | break
618 | if find_flag is False:
619 | for (agg, col, tab) in sql_json['select']:
620 | if 'id' in col.lower():
621 | group_by_clause = 'GROUP BY ' + col_to_str(agg, col, tab, table_names, N_T)
622 | break
623 | if len(group_by_clause) > 5:
624 | pass
625 | else:
626 | raise RuntimeError('fail to convert')
627 | else:
628 | group_by_clause = 'GROUP BY ' + col_to_str('none', find_primary[0],
629 | find_primary[1],
630 | table_names, N_T)
631 | intersect_clause = ''
632 | if 'intersect' in sql_json:
633 | sql_json['intersect']['sql'] = sql_json['sql']
634 | intersect_clause = 'INTERSECT ' + to_str(sql_json['intersect'], len(table_names) + 1, schema, table_names)
635 | union_clause = ''
636 | if 'union' in sql_json:
637 | sql_json['union']['sql'] = sql_json['sql']
638 | union_clause = 'UNION ' + to_str(sql_json['union'], len(table_names) + 1, schema, table_names)
639 | except_clause = ''
640 | if 'except' in sql_json:
641 | sql_json['except']['sql'] = sql_json['sql']
642 | except_clause = 'EXCEPT ' + to_str(sql_json['except'], len(table_names) + 1, schema, table_names)
643 |
644 | # print(current_table['table_names_original'])
645 | table_names_replace = {}
646 | for a, b in zip(current_table['table_names_original'], current_table['table_names']):
647 | table_names_replace[b] = a
648 | new_table_names = {}
649 | for key, value in table_names.items():
650 | if key is None:
651 | continue
652 | new_table_names[table_names_replace[key]] = value
653 | from_clause = infer_from_clause(new_table_names, schema, all_columns).strip()
654 |
655 | sql = ' '.join([select_clause_str, from_clause, where_clause, group_by_clause, have_clause, sup_clause, order_clause,
656 | intersect_clause, union_clause, except_clause])
657 |
658 | return sql
659 |
660 |
661 | if __name__ == '__main__':
662 |
663 | arg_parser = argparse.ArgumentParser()
664 | arg_parser.add_argument('--data_path', type=str, help='dataset path', required=True)
665 | arg_parser.add_argument('--input_path', type=str, help='predicted logical form', required=True)
666 | arg_parser.add_argument('--output_path', type=str, help='output data')
667 | args = arg_parser.parse_args()
668 |
669 | # loading dataSets
670 | datas, schemas = load_dataSets(args)
671 | alter_not_in(datas, schemas=schemas)
672 | alter_inter(datas)
673 | alter_column0(datas)
674 |
675 |
676 | index = range(len(datas))
677 | count = 0
678 | exception_count = 0
679 | with open(args.output_path, 'w', encoding='utf8') as d, open('gold.txt', 'w', encoding='utf8') as g:
680 | for i in index:
681 | try:
682 | result = transform(datas[i], schemas[datas[i]['db_id']])
683 | d.write(result[0] + '\n')
684 | g.write("%s\t%s\t%s\n" % (datas[i]['query'], datas[i]["db_id"], datas[i]["question"]))
685 | count += 1
686 | except Exception as e:
687 | result = transform(datas[i], schemas[datas[i]['db_id']], origin='Root1(3) Root(5) Sel(0) N(0) A(3) C(0) T(0)')
688 | exception_count += 1
689 | d.write(result[0] + '\n')
690 | g.write("%s\t%s\t%s\n" % (datas[i]['query'], datas[i]["db_id"], datas[i]["question"]))
691 | count += 1
692 | print(e)
693 | print('Exception')
694 | print(traceback.format_exc())
695 | print('===\n\n')
696 |
697 | print(count, exception_count)
698 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/5/25
7 | # @Author : Jiaqi&Zecheng
8 | # @File : __init__.py
9 | # @Software: PyCharm
10 | """
--------------------------------------------------------------------------------
/src/args.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/5/25
7 | # @Author : Jiaqi&Zecheng
8 | # @File : args.py
9 | # @Software: PyCharm
10 | """
11 |
12 | import random
13 | import argparse
14 | import torch
15 | import numpy as np
16 |
17 | def init_arg_parser():
18 | arg_parser = argparse.ArgumentParser()
19 | arg_parser.add_argument('--seed', default=5783287, type=int, help='random seed')
20 | arg_parser.add_argument('--cuda', action='store_true', help='use gpu')
21 | arg_parser.add_argument('--lr_scheduler', action='store_true', help='use learning rate scheduler')
22 | arg_parser.add_argument('--lr_scheduler_gammar', default=0.5, type=float, help='decay rate of learning rate scheduler')
23 | arg_parser.add_argument('--column_pointer', action='store_true', help='use column pointer')
24 | arg_parser.add_argument('--loss_epoch_threshold', default=20, type=int, help='loss epoch threshold')
25 | arg_parser.add_argument('--sketch_loss_coefficient', default=0.2, type=float, help='sketch loss coefficient')
26 | arg_parser.add_argument('--sentence_features', action='store_true', help='use sentence features')
27 | arg_parser.add_argument('--model_name', choices=['transformer', 'rnn', 'table', 'sketch'], default='rnn',
28 | help='model name')
29 |
30 | arg_parser.add_argument('--lstm', choices=['lstm', 'lstm_with_dropout', 'parent_feed'], default='lstm')
31 |
32 | arg_parser.add_argument('--load_model', default=None, type=str, help='load a pre-trained model')
33 | arg_parser.add_argument('--glove_embed_path', default="glove.42B.300d.txt", type=str)
34 |
35 | arg_parser.add_argument('--batch_size', default=64, type=int, help='batch size')
36 | arg_parser.add_argument('--beam_size', default=5, type=int, help='beam size for beam search')
37 | arg_parser.add_argument('--embed_size', default=300, type=int, help='size of word embeddings')
38 | arg_parser.add_argument('--col_embed_size', default=300, type=int, help='size of column embeddings')
39 |
40 | arg_parser.add_argument('--action_embed_size', default=128, type=int, help='size of action embeddings')
41 | arg_parser.add_argument('--type_embed_size', default=128, type=int, help='size of type embeddings')
42 | arg_parser.add_argument('--hidden_size', default=100, type=int, help='size of LSTM hidden states')
43 | arg_parser.add_argument('--att_vec_size', default=100, type=int, help='size of attentional vector')
44 | arg_parser.add_argument('--dropout', default=0.3, type=float, help='dropout rate')
45 | arg_parser.add_argument('--word_dropout', default=0.2, type=float, help='word dropout rate')
46 |
47 | # readout layer
48 | arg_parser.add_argument('--no_query_vec_to_action_map', default=False, action='store_true')
49 | arg_parser.add_argument('--readout', default='identity', choices=['identity', 'non_linear'])
50 | arg_parser.add_argument('--query_vec_to_action_diff_map', default=False, action='store_true')
51 |
52 |
53 | arg_parser.add_argument('--column_att', choices=['dot_prod', 'affine'], default='affine')
54 |
55 | arg_parser.add_argument('--decode_max_time_step', default=40, type=int, help='maximum number of time steps used '
56 | 'in decoding and sampling')
57 |
58 |
59 | arg_parser.add_argument('--save_to', default='model', type=str, help='save trained model to')
60 | arg_parser.add_argument('--toy', action='store_true',
61 | help='If set, use small data; used for fast debugging.')
62 | arg_parser.add_argument('--clip_grad', default=5., type=float, help='clip gradients')
63 | arg_parser.add_argument('--max_epoch', default=-1, type=int, help='maximum number of training epochs')
64 | arg_parser.add_argument('--optimizer', default='Adam', type=str, help='optimizer')
65 | arg_parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
66 |
67 | arg_parser.add_argument('--dataset', default="./data", type=str)
68 |
69 | arg_parser.add_argument('--epoch', default=50, type=int, help='Maximum Epoch')
70 | arg_parser.add_argument('--save', default='./', type=str,
71 | help="Path to save the checkpoint and logs of epoch")
72 |
73 | return arg_parser
74 |
75 | def init_config(arg_parser):
76 | args = arg_parser.parse_args()
77 | torch.manual_seed(args.seed)
78 | if args.cuda:
79 | torch.cuda.manual_seed(args.seed)
80 | np.random.seed(int(args.seed * 13 / 7))
81 | random.seed(int(args.seed))
82 | return args
83 |
--------------------------------------------------------------------------------
/src/beam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/5/25
7 | # @Author : Jiaqi&Zecheng
8 | # @File : beam.py
9 | # @Software: PyCharm
10 | """
11 |
12 | import copy
13 |
14 | from src.rule import semQL
15 |
16 |
17 | class ActionInfo(object):
18 | """sufficient statistics for making a prediction of an action at a time step"""
19 |
20 | def __init__(self, action=None):
21 | self.t = 0
22 | self.score = 0
23 | self.parent_t = -1
24 | self.action = action
25 | self.frontier_prod = None
26 | self.frontier_field = None
27 |
28 | # for GenToken actions only
29 | self.copy_from_src = False
30 | self.src_token_position = -1
31 |
32 |
33 | class Beams(object):
34 | def __init__(self, is_sketch=False):
35 | self.actions = []
36 | self.action_infos = []
37 | self.inputs = []
38 | self.score = 0.
39 | self.t = 0
40 | self.is_sketch = is_sketch
41 | self.sketch_step = 0
42 | self.sketch_attention_history = list()
43 |
44 | def get_availableClass(self):
45 | """
46 | return the available action class
47 | :return:
48 | """
49 |
50 | # TODO: this could be optimized for speed
51 | # return the available class using rule
52 | # FIXME: now should change for these 11: "Filter 1 ROOT",
53 | def check_type(lists):
54 | for s in lists:
55 | if type(s) == int:
56 | return False
57 | return True
58 |
59 | stack = [semQL.Root1]
60 | for action in self.actions:
61 | infer_action = action.get_next_action(is_sketch=self.is_sketch)
62 | infer_action.reverse()
63 | if stack[-1] is type(action):
64 | stack.pop()
65 | # check if they are non-terminals
66 | if check_type(infer_action):
67 | stack.extend(infer_action)
68 | else:
69 | raise RuntimeError("Not the right action")
70 |
71 | result = stack[-1] if len(stack) > 0 else None
72 |
73 | return result
74 |
75 | @classmethod
76 | def get_parent_action(cls, actions):
77 | """
78 |
79 | :param actions:
80 | :return:
81 | """
82 |
83 | def check_type(lists):
84 | for s in lists:
85 | if type(s) == int:
86 | return False
87 | return True
88 |
89 | # check the origin state Root
90 | if len(actions) == 0:
91 | return None
92 |
93 | stack = [semQL.Root1]
94 | for id_x, action in enumerate(actions):
95 | infer_action = action.get_next_action()
96 | for ac in infer_action:
97 | ac.parent = action
98 | ac.pt = id_x
99 | infer_action.reverse()
100 | if stack[-1] is type(action):
101 | stack.pop()
102 | # check if they are non-terminals
103 | if check_type(infer_action):
104 | stack.extend(infer_action)
105 | else:
106 | for t in actions:
107 | if type(t) != semQL.C:
108 | print(t, end="")
109 | print()  # newline after the action dump
110 | print(action)
111 | print(stack[-1])
112 | raise RuntimeError("Not the right action")
113 | result = stack[-1] if len(stack) > 0 else None
114 |
115 | return result
116 |
117 | def apply_action(self, action):
118 | # TODO: not finish implement yet
119 | self.t += 1
120 | self.actions.append(action)
121 |
122 | def clone_and_apply_action(self, action):
123 | new_hyp = self.copy()
124 | new_hyp.apply_action(action)
125 |
126 | return new_hyp
127 |
128 | def clone_and_apply_action_info(self, action_info):
129 | action = action_info.action
130 | action.score = action_info.score
131 | new_hyp = self.clone_and_apply_action(action)
132 | new_hyp.action_infos.append(action_info)
133 | new_hyp.sketch_step = self.sketch_step
134 | new_hyp.sketch_attention_history = copy.copy(self.sketch_attention_history)
135 |
136 | return new_hyp
137 |
138 | def copy(self):
139 | new_hyp = Beams(is_sketch=self.is_sketch)
140 | # if self.tree:
141 | # new_hyp.tree = self.tree.copy()
142 |
143 | new_hyp.actions = list(self.actions)
144 | new_hyp.score = self.score
145 | new_hyp.t = self.t
146 | new_hyp.sketch_step = self.sketch_step
147 | new_hyp.sketch_attention_history = copy.copy(self.sketch_attention_history)
148 |
149 | return new_hyp
150 |
151 | def infer_n(self):
152 | if len(self.actions) > 4:
153 | prev_action = self.actions[-3]
154 | if isinstance(prev_action, semQL.Filter):
155 | if prev_action.id_c > 11:
156 | # Nested Query, only select 1 column
157 | return ['N A']
158 | if self.actions[0].id_c != 3:
159 | return [self.actions[3].production]
160 | return semQL.N._init_grammar()
161 |
162 | @property
163 | def completed(self):
164 | return True if self.get_availableClass() is None else False
165 |
166 | @property
167 | def is_valid(self):
168 | actions = self.actions
169 | return self.check_sel_valid(actions)
170 |
171 | def check_sel_valid(self, actions):
172 | find_sel = False
173 | sel_actions = list()
174 | for ac in actions:
175 | if type(ac) == semQL.Sel:
176 | find_sel = True
177 | elif find_sel and type(ac) in [semQL.N, semQL.T, semQL.C, semQL.A]:
178 | if type(ac) not in [semQL.N]:
179 | sel_actions.append(ac)
180 | elif find_sel and type(ac) not in [semQL.N, semQL.T, semQL.C, semQL.A]:
181 | break
182 |
183 | if find_sel is False:
184 | return True
185 |
186 | # not the complete sel lf
187 | if len(sel_actions) % 3 != 0:
188 | return True
189 |
190 | sel_string = list()
191 | for i in range(len(sel_actions) // 3):
192 | if (sel_actions[i * 3 + 0].id_c, sel_actions[i * 3 + 1].id_c, sel_actions[i * 3 + 2].id_c) in sel_string:
193 | return False
194 | else:
195 | sel_string.append(
196 | (sel_actions[i * 3 + 0].id_c, sel_actions[i * 3 + 1].id_c, sel_actions[i * 3 + 2].id_c))
197 | return True
198 |
199 |
200 | if __name__ == '__main__':
201 | test = Beams(is_sketch=True)
202 | # print(semQL.Root1(1).get_next_action())
203 | test.actions.append(semQL.Root1(3))
204 | test.actions.append(semQL.Root(5))
205 |
206 | print(str(test.get_availableClass()))
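    | # With the actions [Root1(3), Root(5)] above, the next expected action class under
    | # the SemQL grammar should be semQL.Sel (illustrative expectation).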
207 |
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/5/25
7 | # @Author : Jiaqi&Zecheng
8 | # @File : dataset.py
9 | # @Software: PyCharm
10 | """
11 |
12 | import copy
13 |
14 | import src.rule.semQL as define_rule
15 | from src.models import nn_utils
16 |
17 |
18 | class Example:
19 | """
20 |
21 | """
22 | def __init__(self, src_sent, tgt_actions=None, vis_seq=None, tab_cols=None, col_num=None, sql=None,
23 | one_hot_type=None, col_hot_type=None, schema_len=None, tab_ids=None,
24 | table_names=None, table_len=None, col_table_dict=None, cols=None,
25 | table_col_name=None, table_col_len=None,
26 | col_pred=None, tokenized_src_sent=None,
27 | ):
28 |
29 | self.src_sent = src_sent
30 | self.tokenized_src_sent = tokenized_src_sent
31 | self.vis_seq = vis_seq
32 | self.tab_cols = tab_cols
33 | self.col_num = col_num
34 | self.sql = sql
35 | self.one_hot_type=one_hot_type
36 | self.col_hot_type = col_hot_type
37 | self.schema_len = schema_len
38 | self.tab_ids = tab_ids
39 | self.table_names = table_names
40 | self.table_len = table_len
41 | self.col_table_dict = col_table_dict
42 | self.cols = cols
43 | self.table_col_name = table_col_name
44 | self.table_col_len = table_col_len
45 | self.col_pred = col_pred
46 | self.tgt_actions = tgt_actions
47 | self.truth_actions = copy.deepcopy(tgt_actions)
48 |
49 | self.sketch = list()
50 | if self.truth_actions:
51 | for ta in self.truth_actions:
52 | if isinstance(ta, define_rule.C) or isinstance(ta, define_rule.T) or isinstance(ta, define_rule.A):
53 | continue
54 | self.sketch.append(ta)
55 |
56 |
57 | class cached_property(object):
58 | """ A property that is only computed once per instance and then replaces
59 | itself with an ordinary attribute. Deleting the attribute resets the
60 | property.
61 |
62 | Source: https://github.com/bottlepy/bottle/commit/fa7733e075da0d790d809aa3d2f53071897e6f76
63 | """
64 |
65 | def __init__(self, func):
66 | self.__doc__ = getattr(func, '__doc__')
67 | self.func = func
68 |
69 | def __get__(self, obj, cls):
70 | if obj is None:
71 | return self
72 | value = obj.__dict__[self.func.__name__] = self.func(obj)
73 | return value
74 |
75 |
76 | class Batch(object):
77 | def __init__(self, examples, grammar, cuda=False):
78 | self.examples = examples
79 |
80 | if examples[0].tgt_actions:
81 | self.max_action_num = max(len(e.tgt_actions) for e in self.examples)
82 | self.max_sketch_num = max(len(e.sketch) for e in self.examples)
83 |
84 | self.src_sents = [e.src_sent for e in self.examples]
85 | self.src_sents_len = [len(e.src_sent) for e in self.examples]
86 | self.tokenized_src_sents = [e.tokenized_src_sent for e in self.examples]
87 | self.tokenized_src_sents_len = [len(e.tokenized_src_sent) for e in examples]
88 | self.src_sents_word = [e.src_sent for e in self.examples]
89 | self.table_sents_word = [[" ".join(x) for x in e.tab_cols] for e in self.examples]
90 |
91 | self.schema_sents_word = [[" ".join(x) for x in e.table_names] for e in self.examples]
92 |
93 | self.src_type = [e.one_hot_type for e in self.examples]
94 | self.col_hot_type = [e.col_hot_type for e in self.examples]
95 | self.table_sents = [e.tab_cols for e in self.examples]
96 | self.col_num = [e.col_num for e in self.examples]
97 | self.tab_ids = [e.tab_ids for e in self.examples]
98 | self.table_names = [e.table_names for e in self.examples]
99 | self.table_len = [e.table_len for e in examples]
100 | self.col_table_dict = [e.col_table_dict for e in examples]
101 | self.table_col_name = [e.table_col_name for e in examples]
102 | self.table_col_len = [e.table_col_len for e in examples]
103 | self.col_pred = [e.col_pred for e in examples]
104 |
105 | self.grammar = grammar
106 | self.cuda = cuda
107 |
108 | def __len__(self):
109 | return len(self.examples)
110 |
111 |
112 | def table_dict_mask(self, table_dict):
113 | return nn_utils.table_dict_to_mask_tensor(self.table_len, table_dict, cuda=self.cuda)
114 |
115 | @cached_property
116 | def pred_col_mask(self):
117 | return nn_utils.pred_col_mask(self.col_pred, self.col_num)
118 |
119 | @cached_property
120 | def schema_token_mask(self):
121 | return nn_utils.length_array_to_mask_tensor(self.table_len, cuda=self.cuda)
122 |
123 | @cached_property
124 | def table_token_mask(self):
125 | return nn_utils.length_array_to_mask_tensor(self.col_num, cuda=self.cuda)
126 |
127 | @cached_property
128 | def table_appear_mask(self):
129 | return nn_utils.appear_to_mask_tensor(self.col_num, cuda=self.cuda)
130 |
131 | @cached_property
132 | def table_unk_mask(self):
133 | return nn_utils.length_array_to_mask_tensor(self.col_num, cuda=self.cuda, value=None)
134 |
135 | @cached_property
136 | def src_token_mask(self):
137 | return nn_utils.length_array_to_mask_tensor(self.src_sents_len,
138 | cuda=self.cuda)
139 |
140 |
141 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/5/26
7 | # @Author : Jiaqi&Zecheng
8 | # @File : __init__.py
9 | # @Software: PyCharm
10 | """
--------------------------------------------------------------------------------
/src/models/basic_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/5/26
7 | # @Author : Jiaqi&Zecheng
8 | # @File : basic_model.py
9 | # @Software: PyCharm
10 | """
11 |
12 | import numpy as np
13 | import os
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | import torch.nn.utils
18 | from torch.autograd import Variable
19 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
20 |
21 | from src.rule import semQL as define_rule
22 |
23 |
24 | class BasicModel(nn.Module):
25 |
26 | def __init__(self):
27 | super(BasicModel, self).__init__()
28 | pass
29 |
30 | def embedding_cosine(self, src_embedding, table_embedding, table_unk_mask):
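    | # Computes a (batch, num_schema_items, src_len) tensor of cosine similarities between
    | # each column/table embedding and every question-token embedding; entries masked by
    | # table_unk_mask are zeroed out.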
31 | embedding_differ = []
32 | for i in range(table_embedding.size(1)):
33 | one_table_embedding = table_embedding[:, i, :]
34 | one_table_embedding = one_table_embedding.unsqueeze(1).expand(table_embedding.size(0),
35 | src_embedding.size(1),
36 | table_embedding.size(2))
37 |
38 | topk_val = F.cosine_similarity(one_table_embedding, src_embedding, dim=-1)
39 |
40 | embedding_differ.append(topk_val)
41 | embedding_differ = torch.stack(embedding_differ).transpose(1, 0)
42 | embedding_differ.data.masked_fill_(table_unk_mask.unsqueeze(2).expand(
43 | table_embedding.size(0),
44 | table_embedding.size(1),
45 | embedding_differ.size(2)
46 | ).bool(), 0)
47 |
48 | return embedding_differ
49 |
50 | def encode(self, src_sents_var, src_sents_len, q_onehot_project=None):
51 | """
52 | encode the source sequence
53 | :return:
54 | src_encodings: Variable(batch_size, src_sent_len, hidden_size * 2)
55 | last_state, last_cell: Variable(batch_size, hidden_size)
56 | """
57 | src_token_embed = self.gen_x_batch(src_sents_var)
58 |
59 | if q_onehot_project is not None:
60 | src_token_embed = torch.cat([src_token_embed, q_onehot_project], dim=-1)
61 |
62 | packed_src_token_embed = pack_padded_sequence(src_token_embed, src_sents_len, batch_first=True)
63 | # src_encodings: (tgt_query_len, batch_size, hidden_size)
64 | src_encodings, (last_state, last_cell) = self.encoder_lstm(packed_src_token_embed)
65 | src_encodings, _ = pad_packed_sequence(src_encodings, batch_first=True)
66 | # src_encodings: (batch_size, tgt_query_len, hidden_size)
67 | # src_encodings = src_encodings.permute(1, 0, 2)
68 | # (batch_size, hidden_size * 2)
69 | last_state = torch.cat([last_state[0], last_state[1]], -1)
70 | last_cell = torch.cat([last_cell[0], last_cell[1]], -1)
71 |
72 | return src_encodings, (last_state, last_cell)
73 |
74 | def input_type(self, values_list):
75 | B = len(values_list)
76 | val_len = []
77 | for value in values_list:
78 | val_len.append(len(value))
79 | max_len = max(val_len)
80 | # for the Begin and End
81 | val_emb_array = np.zeros((B, max_len, values_list[0].shape[1]), dtype=np.float32)
82 | for i in range(B):
83 | val_emb_array[i, :val_len[i], :] = values_list[i][:, :]
84 |
85 | val_inp = torch.from_numpy(val_emb_array)
86 | if self.args.cuda:
87 | val_inp = val_inp.cuda()
88 | val_inp_var = Variable(val_inp)
89 | return val_inp_var
90 |
91 | def padding_sketch(self, sketch):
92 | padding_result = []
93 | for action in sketch:
94 | padding_result.append(action)
95 | if type(action) == define_rule.N:
96 | for _ in range(action.id_c + 1):
97 | padding_result.append(define_rule.A(0))
98 | padding_result.append(define_rule.C(0))
99 | padding_result.append(define_rule.T(0))
100 | elif type(action) == define_rule.Filter and 'A' in action.production:
101 | padding_result.append(define_rule.A(0))
102 | padding_result.append(define_rule.C(0))
103 | padding_result.append(define_rule.T(0))
104 | elif type(action) == define_rule.Order or type(action) == define_rule.Sup:
105 | padding_result.append(define_rule.A(0))
106 | padding_result.append(define_rule.C(0))
107 | padding_result.append(define_rule.T(0))
108 |
109 | return padding_result
110 |
111 | def gen_x_batch(self, q):
112 | B = len(q)
113 | val_embs = []
114 | val_len = np.zeros(B, dtype=np.int64)
115 | is_list = False
116 | if type(q[0][0]) == list:
117 | is_list = True
118 | for i, one_q in enumerate(q):
119 | if not is_list:
120 | q_val = list(
121 | map(lambda x: self.word_emb.get(x, np.zeros(self.args.col_embed_size, dtype=np.float32)), one_q))
122 | else:
123 | q_val = []
124 | for ws in one_q:
125 | emb_list = []
126 | ws_len = len(ws)
127 | for w in ws:
128 | emb_list.append(self.word_emb.get(w, self.word_emb['unk']))
129 | if ws_len == 0:
130 | raise Exception("word list should not be empty!")
131 | elif ws_len == 1:
132 | q_val.append(emb_list[0])
133 | else:
134 | q_val.append(sum(emb_list) / float(ws_len))
135 |
136 | val_embs.append(q_val)
137 | val_len[i] = len(q_val)
138 | max_len = max(val_len)
139 |
140 | val_emb_array = np.zeros((B, max_len, self.args.col_embed_size), dtype=np.float32)
141 | for i in range(B):
142 | for t in range(len(val_embs[i])):
143 | val_emb_array[i, t, :] = val_embs[i][t]
144 | val_inp = torch.from_numpy(val_emb_array)
145 | if self.args.cuda:
146 | val_inp = val_inp.cuda()
147 | return val_inp
148 |
149 | def save(self, path):
150 | dir_name = os.path.dirname(path)
151 | if not os.path.exists(dir_name):
152 | os.makedirs(dir_name)
153 |
154 | params = {
155 | 'args': self.args,
156 | 'vocab': self.vocab,
157 | 'grammar': self.grammar,
158 | 'state_dict': self.state_dict()
159 | }
160 | torch.save(params, path)
161 |
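162 | 
163 | if __name__ == '__main__':
164 |     # Minimal usage sketch (illustrative only, with an arbitrary hand-made sketch):
165 |     # padding_sketch inserts placeholder A/C/T actions after N, Sup, Order and
166 |     # comparison Filters, so a sketch can later be decoded into a full program.
167 |     sketch = [define_rule.Root1(3), define_rule.Root(3), define_rule.Sel(0),
168 |               define_rule.N(0), define_rule.Filter(2)]
169 |     print(BasicModel().padding_sketch(sketch))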
--------------------------------------------------------------------------------
/src/models/nn_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/5/25
7 | # @Author : Jiaqi&Zecheng
8 | # @File    : nn_utils.py
9 | # @Software: PyCharm
10 | """
11 | import torch.nn.functional as F
12 | import torch.nn.init as init
13 | import numpy as np
14 | import torch
15 | from torch.autograd import Variable
16 | from six.moves import xrange
17 |
18 | def dot_prod_attention(h_t, src_encoding, src_encoding_att_linear, mask=None):
19 | """
20 | :param h_t: (batch_size, hidden_size)
21 | :param src_encoding: (batch_size, src_sent_len, hidden_size * 2)
22 | :param src_encoding_att_linear: (batch_size, src_sent_len, hidden_size)
23 | :param mask: (batch_size, src_sent_len)
24 | """
25 | # (batch_size, src_sent_len)
26 | att_weight = torch.bmm(src_encoding_att_linear, h_t.unsqueeze(2)).squeeze(2)
27 | if mask is not None:
28 | att_weight.data.masked_fill_(mask.bool(), -float('inf'))
29 | att_weight = F.softmax(att_weight, dim=-1)
30 |
31 | att_view = (att_weight.size(0), 1, att_weight.size(1))
32 | # (batch_size, hidden_size)
33 | ctx_vec = torch.bmm(att_weight.view(*att_view), src_encoding).squeeze(1)
34 |
35 | return ctx_vec, att_weight
36 |
37 |
38 | def length_array_to_mask_tensor(length_array, cuda=False, value=None):
39 | max_len = max(length_array)
40 | batch_size = len(length_array)
41 |
42 | mask = np.ones((batch_size, max_len), dtype=np.uint8)
43 | for i, seq_len in enumerate(length_array):
44 | mask[i][:seq_len] = 0
45 |
46 |     if value is not None:
47 | for b_id in range(len(value)):
48 | for c_id, c in enumerate(value[b_id]):
49 | if value[b_id][c_id] == [3]:
50 | mask[b_id][c_id] = 1
51 |
52 | mask = torch.ByteTensor(mask)
53 | return mask.cuda() if cuda else mask
54 |
55 |
56 | def table_dict_to_mask_tensor(length_array, table_dict, cuda=False ):
57 | max_len = max(length_array)
58 | batch_size = len(table_dict)
59 |
60 | mask = np.ones((batch_size, max_len), dtype=np.uint8)
61 | for i, ta_val in enumerate(table_dict):
62 | for tt in ta_val:
63 | mask[i][tt] = 0
64 |
65 | mask = torch.ByteTensor(mask)
66 | return mask.cuda() if cuda else mask
67 |
68 |
69 | def length_position_tensor(length_array, cuda=False, value=None):
70 | max_len = max(length_array)
71 | batch_size = len(length_array)
72 |
73 | mask = np.zeros((batch_size, max_len), dtype=np.float32)
74 |
75 | for b_id in range(batch_size):
76 | for len_c in range(length_array[b_id]):
77 | mask[b_id][len_c] = len_c + 1
78 |
79 | mask = torch.LongTensor(mask)
80 | return mask.cuda() if cuda else mask
81 |
82 |
83 | def appear_to_mask_tensor(length_array, cuda=False, value=None):
84 | max_len = max(length_array)
85 | batch_size = len(length_array)
86 | mask = np.zeros((batch_size, max_len), dtype=np.float32)
87 | return mask
88 |
89 | def pred_col_mask(value, max_len):
90 | max_len = max(max_len)
91 | batch_size = len(value)
92 | mask = np.ones((batch_size, max_len), dtype=np.uint8)
93 | for v_ind, v_val in enumerate(value):
94 | for v in v_val:
95 | mask[v_ind][v] = 0
96 | mask = torch.ByteTensor(mask)
97 | return mask.cuda()
98 |
99 |
100 | def input_transpose(sents, pad_token):
101 | """
102 |     pad the input list of variable-length sequences (batch_size entries)
103 |     into a list of size (batch_size, max_sent_len), together with 0/1 masks
104 | """
105 | max_len = max(len(s) for s in sents)
106 | batch_size = len(sents)
107 | sents_t = []
108 | masks = []
109 | for e_id in range(batch_size):
110 | if type(sents[0][0]) != list:
111 | sents_t.append([sents[e_id][i] if len(sents[e_id]) > i else pad_token for i in range(max_len)])
112 | else:
113 | sents_t.append([sents[e_id][i] if len(sents[e_id]) > i else [pad_token] for i in range(max_len)])
114 |
115 | masks.append([1 if len(sents[e_id]) > i else 0 for i in range(max_len)])
116 |
117 | return sents_t, masks
118 |
119 |
120 | def word2id(sents, vocab):
121 | if type(sents[0]) == list:
122 | if type(sents[0][0]) != list:
123 | return [[vocab[w] for w in s] for s in sents]
124 | else:
125 | return [[[vocab[w] for w in s] for s in v] for v in sents ]
126 | else:
127 | return [vocab[w] for w in sents]
128 |
129 |
130 | def id2word(sents, vocab):
131 | if type(sents[0]) == list:
132 | return [[vocab.id2word[w] for w in s] for s in sents]
133 | else:
134 | return [vocab.id2word[w] for w in sents]
135 |
136 |
137 | def to_input_variable(sequences, vocab, cuda=False, training=True):
138 | """
139 | given a list of sequences,
140 |     return a tensor of shape (batch_size, max_sent_len)
141 |     """
142 |     word_ids = word2id(sequences, vocab)
143 |     sents_t, masks = input_transpose(word_ids, vocab['<pad>'])
144 |
145 | if type(sents_t[0][0]) != list:
146 | with torch.no_grad():
147 | sents_var = Variable(torch.LongTensor(sents_t), requires_grad=False)
148 | if cuda:
149 | sents_var = sents_var.cuda()
150 | else:
151 | sents_var = sents_t
152 |
153 | return sents_var
154 |
155 |
156 | def variable_constr(x, v, cuda=False):
157 |     return Variable(getattr(torch.cuda, x)(v)) if cuda else Variable(getattr(torch, x)(v))  # x is a tensor type name, e.g. 'FloatTensor'
158 |
159 |
160 | def batch_iter(examples, batch_size, shuffle=False):
161 | index_arr = np.arange(len(examples))
162 | if shuffle:
163 | np.random.shuffle(index_arr)
164 |
165 | batch_num = int(np.ceil(len(examples) / float(batch_size)))
166 | for batch_id in xrange(batch_num):
167 | batch_ids = index_arr[batch_size * batch_id: batch_size * (batch_id + 1)]
168 | batch_examples = [examples[i] for i in batch_ids]
169 |
170 | yield batch_examples
171 |
172 |
173 | def isnan(data):
174 | data = data.cpu().numpy()
175 | return np.isnan(data).any() or np.isinf(data).any()
176 |
177 |
178 | def log_sum_exp(inputs, dim=None, keepdim=False):
179 | """Numerically stable logsumexp.
180 | source: https://github.com/pytorch/pytorch/issues/2591
181 |
182 | Args:
183 | inputs: A Variable with any shape.
184 | dim: An integer.
185 | keepdim: A boolean.
186 |
187 | Returns:
188 | Equivalent of log(sum(exp(inputs), dim=dim, keepdim=keepdim)).
189 | """
190 | # For a 1-D array x (any array along a single dimension),
191 | # log sum exp(x) = s + log sum exp(x - s)
192 | # with s = max(x) being a common choice.
193 |
194 | if dim is None:
195 | inputs = inputs.view(-1)
196 | dim = 0
197 | s, _ = torch.max(inputs, dim=dim, keepdim=True)
198 | outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
199 | if not keepdim:
200 | outputs = outputs.squeeze(dim)
201 | return outputs
202 |
203 |
204 | def uniform_init(lower, upper, params):
205 | for p in params:
206 | p.data.uniform_(lower, upper)
207 |
208 |
209 | def glorot_init(params):
210 | for p in params:
211 | if len(p.data.size()) > 1:
212 |             init.xavier_normal_(p.data)
213 |
214 |
215 | def identity(x):
216 | return x
217 |
218 |
219 | def pad_matrix(matrixs, cuda=False):
220 | """
221 |     :param matrixs: list of square numpy matrices of varying size
222 |     :return: tensor of shape [batch_size, max_shape, max_shape] with zero padding
223 | """
224 | shape = [m.shape[0] for m in matrixs]
225 | max_shape = max(shape)
226 | tensors = list()
227 | for s, m in zip(shape, matrixs):
228 | delta = max_shape - s
229 | if s > 0:
230 | tensors.append(torch.as_tensor(np.pad(m, [(0, delta), (0, delta)], mode='constant'), dtype=torch.float))
231 | else:
232 | tensors.append(torch.as_tensor(m, dtype=torch.float))
233 | tensors = torch.stack(tensors)
234 | if cuda:
235 | tensors = tensors.cuda()
236 | return tensors
237 |
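238 | 
239 | if __name__ == '__main__':
240 |     # Minimal self-check (illustrative only): build a padding mask for two
241 |     # sequences of length 3 and 1, then attend over random encodings with
242 |     # dot_prod_attention; shapes follow the docstrings above.
243 |     lengths = [3, 1]
244 |     mask = length_array_to_mask_tensor(lengths)
245 |     print(mask)  # 1 marks padded positions
246 |     h_t = torch.randn(2, 4)
247 |     src_encoding = torch.randn(2, 3, 8)
248 |     src_encoding_att_linear = torch.randn(2, 3, 4)
249 |     ctx_vec, att_weight = dot_prod_attention(h_t, src_encoding, src_encoding_att_linear, mask=mask)
250 |     print(ctx_vec.size(), att_weight.size())  # (2, 8) and (2, 3)
251 |     print(log_sum_exp(torch.tensor([1.0, 2.0, 3.0])))  # equals log(e^1 + e^2 + e^3)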
--------------------------------------------------------------------------------
/src/models/pointer_net.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # coding=utf8
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.utils
9 | from torch.nn import Parameter
10 |
11 |
12 | class AuxiliaryPointerNet(nn.Module):
13 |
14 | def __init__(self, query_vec_size, src_encoding_size, attention_type='affine'):
15 | super(AuxiliaryPointerNet, self).__init__()
16 |
17 | assert attention_type in ('affine', 'dot_prod')
18 | if attention_type == 'affine':
19 | self.src_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False)
20 | self.auxiliary_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False)
21 | self.attention_type = attention_type
22 |
23 | def forward(self, src_encodings, src_context_encodings, src_token_mask, query_vec):
24 | """
25 | :param src_context_encodings: Variable(batch_size, src_sent_len, src_encoding_size)
26 | :param src_encodings: Variable(batch_size, src_sent_len, src_encoding_size)
27 | :param src_token_mask: Variable(batch_size, src_sent_len)
28 | :param query_vec: Variable(tgt_action_num, batch_size, query_vec_size)
29 | :return: Variable(tgt_action_num, batch_size, src_sent_len)
30 | """
31 |
32 | # (batch_size, 1, src_sent_len, query_vec_size)
33 | encodings = src_encodings.clone()
34 | context_encodings = src_context_encodings.clone()
35 | if self.attention_type == 'affine':
36 | encodings = self.src_encoding_linear(src_encodings)
37 | context_encodings = self.auxiliary_encoding_linear(src_context_encodings)
38 | encodings = encodings.unsqueeze(1)
39 | context_encodings = context_encodings.unsqueeze(1)
40 |
41 | # (batch_size, tgt_action_num, query_vec_size, 1)
42 | q = query_vec.permute(1, 0, 2).unsqueeze(3)
43 |
44 | # (batch_size, tgt_action_num, src_sent_len)
45 | weights = torch.matmul(encodings, q).squeeze(3)
46 | context_weights = torch.matmul(context_encodings, q).squeeze(3)
47 |
48 | # (tgt_action_num, batch_size, src_sent_len)
49 | weights = weights.permute(1, 0, 2)
50 | context_weights = context_weights.permute(1, 0, 2)
51 |
52 | if src_token_mask is not None:
53 | # (tgt_action_num, batch_size, src_sent_len)
54 | src_token_mask = src_token_mask.unsqueeze(0).expand_as(weights)
55 | weights.data.masked_fill_(src_token_mask.bool(), -float('inf'))
56 | context_weights.data.masked_fill_(src_token_mask.bool(), -float('inf'))
57 |
58 | sigma = 0.1
59 | return weights.squeeze(0) + sigma * context_weights.squeeze(0)
60 |
61 |
62 | class PointerNet(nn.Module):
63 | def __init__(self, query_vec_size, src_encoding_size, attention_type='affine'):
64 | super(PointerNet, self).__init__()
65 |
66 | assert attention_type in ('affine', 'dot_prod')
67 | if attention_type == 'affine':
68 | self.src_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False)
69 |
70 | self.attention_type = attention_type
71 | self.input_linear = nn.Linear(query_vec_size, query_vec_size)
72 | self.type_linear = nn.Linear(32, query_vec_size)
73 | self.V = Parameter(torch.FloatTensor(query_vec_size), requires_grad=True)
74 | self.tanh = nn.Tanh()
75 | self.context_linear = nn.Conv1d(src_encoding_size, query_vec_size, 1, 1)
76 | self.coverage_linear = nn.Conv1d(1, query_vec_size, 1, 1)
77 |
78 |
79 | nn.init.uniform_(self.V, -1, 1)
80 |
81 | def forward(self, src_encodings, src_token_mask, query_vec):
82 | """
83 | :param src_encodings: Variable(batch_size, src_sent_len, hidden_size * 2)
84 | :param src_token_mask: Variable(batch_size, src_sent_len)
85 | :param query_vec: Variable(tgt_action_num, batch_size, query_vec_size)
86 | :return: Variable(tgt_action_num, batch_size, src_sent_len)
87 | """
88 |
89 | # (batch_size, 1, src_sent_len, query_vec_size)
90 |
91 | if self.attention_type == 'affine':
92 | src_encodings = self.src_encoding_linear(src_encodings)
93 | src_encodings = src_encodings.unsqueeze(1)
94 |
95 | # (batch_size, tgt_action_num, query_vec_size, 1)
96 | q = query_vec.permute(1, 0, 2).unsqueeze(3)
97 |
98 | weights = torch.matmul(src_encodings, q).squeeze(3)
99 |
100 | weights = weights.permute(1, 0, 2)
101 |
102 | if src_token_mask is not None:
103 | src_token_mask = src_token_mask.unsqueeze(0).expand_as(weights)
104 | weights.data.masked_fill_(src_token_mask.bool(), -float('inf'))
105 |
106 | return weights.squeeze(0)
107 |
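108 | 
109 | if __name__ == '__main__':
110 |     # Minimal shape check (illustrative only): score each source token for a
111 |     # single decoding step, masking the last (padded) position.
112 |     net = PointerNet(query_vec_size=4, src_encoding_size=6)
113 |     src_encodings = torch.randn(2, 5, 6)        # (batch, src_sent_len, encoding)
114 |     query_vec = torch.randn(1, 2, 4)            # (tgt_action_num, batch, query)
115 |     src_token_mask = torch.zeros(2, 5, dtype=torch.uint8)
116 |     src_token_mask[:, -1] = 1                   # pretend the last token is padding
117 |     weights = net(src_encodings, src_token_mask, query_vec)
118 |     print(weights.size())                       # (2, 5) after squeezing the step dim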
--------------------------------------------------------------------------------
/src/rule/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/5/27
7 | # @Author : Jiaqi&Zecheng
8 | # @File    : __init__.py
9 | # @Software: PyCharm
10 | """
--------------------------------------------------------------------------------
/src/rule/graph.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/5/25
7 | # @Author : Jiaqi&Zecheng
8 | # @File    : graph.py
9 | # @Software: PyCharm
10 | """
11 |
12 | from collections import deque, namedtuple
13 |
14 |
15 | # we'll use infinity as a default distance to nodes.
16 | inf = float('inf')
17 | Edge = namedtuple('Edge', 'start, end, cost')
18 |
19 |
20 | def make_edge(start, end, cost=1):
21 | return Edge(start, end, cost)
22 |
23 |
24 | class Graph:
25 | def __init__(self, edges):
26 | # let's check that the data is right
27 | wrong_edges = [i for i in edges if len(i) not in [2, 3]]
28 | if wrong_edges:
29 | raise ValueError('Wrong edges data: {}'.format(wrong_edges))
30 |
31 | self.edges = [make_edge(*edge) for edge in edges]
32 |
33 | @property
34 | def vertices(self):
35 | return set(
36 | # this piece of magic turns ([1,2], [3,4]) into [1, 2, 3, 4]
37 |             # the set above makes its elements unique.
38 | sum(
39 | ([edge.start, edge.end] for edge in self.edges), []
40 | )
41 | )
42 |
43 | def get_node_pairs(self, n1, n2, both_ends=True):
44 | if both_ends:
45 | node_pairs = [[n1, n2], [n2, n1]]
46 | else:
47 | node_pairs = [[n1, n2]]
48 | return node_pairs
49 |
50 | def remove_edge(self, n1, n2, both_ends=True):
51 | node_pairs = self.get_node_pairs(n1, n2, both_ends)
52 | edges = self.edges[:]
53 | for edge in edges:
54 | if [edge.start, edge.end] in node_pairs:
55 | self.edges.remove(edge)
56 |
57 | def add_edge(self, n1, n2, cost=1, both_ends=True):
58 | node_pairs = self.get_node_pairs(n1, n2, both_ends)
59 | for edge in self.edges:
60 | if [edge.start, edge.end] in node_pairs:
61 |                 raise ValueError('Edge {} {} already exists'.format(n1, n2))
62 |
63 | self.edges.append(Edge(start=n1, end=n2, cost=cost))
64 | if both_ends:
65 | self.edges.append(Edge(start=n2, end=n1, cost=cost))
66 |
67 | @property
68 | def neighbours(self):
69 | neighbours = {vertex: set() for vertex in self.vertices}
70 | for edge in self.edges:
71 | neighbours[edge.start].add((edge.end, edge.cost))
72 |
73 | return neighbours
74 |
75 | def dijkstra(self, source, dest):
76 | assert source in self.vertices, 'Such source node doesn\'t exist'
77 |         assert dest in self.vertices, 'Such dest node doesn\'t exist'
78 |
79 | # 1. Mark all nodes unvisited and store them.
80 | # 2. Set the distance to zero for our initial node
81 | # and to infinity for other nodes.
82 | distances = {vertex: inf for vertex in self.vertices}
83 | previous_vertices = {
84 | vertex: None for vertex in self.vertices
85 | }
86 | distances[source] = 0
87 | vertices = self.vertices.copy()
88 |
89 | while vertices:
90 | # 3. Select the unvisited node with the smallest distance,
91 |             # it becomes the current node.
92 | current_vertex = min(
93 | vertices, key=lambda vertex: distances[vertex])
94 |
95 | # 6. Stop, if the smallest distance
96 | # among the unvisited nodes is infinity.
97 | if distances[current_vertex] == inf:
98 | break
99 |
100 | # 4. Find unvisited neighbors for the current node
101 | # and calculate their distances through the current node.
102 | for neighbour, cost in self.neighbours[current_vertex]:
103 | alternative_route = distances[current_vertex] + cost
104 |
105 | # Compare the newly calculated distance to the assigned
106 | # and save the smaller one.
107 | if alternative_route < distances[neighbour]:
108 | distances[neighbour] = alternative_route
109 | previous_vertices[neighbour] = current_vertex
110 |
111 | # 5. Mark the current node as visited
112 | # and remove it from the unvisited set.
113 | vertices.remove(current_vertex)
114 |
115 | path, current_vertex = deque(), dest
116 | while previous_vertices[current_vertex] is not None:
117 | path.appendleft(current_vertex)
118 | current_vertex = previous_vertices[current_vertex]
119 | if path:
120 | path.appendleft(current_vertex)
121 | return path
122 |
123 |
124 | if __name__ == '__main__':
125 | graph = Graph([
126 | ("a", "b", 7), ("a", "c", 9), ("a", "f", 14), ("b", "c", 10),
127 | ("b", "d", 15), ("c", "d", 11), ("c", "f", 2), ("d", "e", 6),
128 | ("e", "f", 9)])
129 |
130 | print(graph.dijkstra("a", "e"))
--------------------------------------------------------------------------------
/src/rule/lf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/5/25
7 | # @Author : Jiaqi&Zecheng
8 | # @File    : lf.py
9 | # @Software: PyCharm
10 | """
11 | import copy
12 | import json
13 |
14 | import numpy as np
15 |
16 | from src.rule import semQL as define_rule
17 | from src.rule.semQL import Sup, Sel, Order, Root, Filter, A, N, C, T, Root1
18 |
19 |
20 | def _build_single_filter(lf, f):
21 | # No conjunction
22 | agg = lf.pop(0)
23 | column = lf.pop(0)
24 | if len(lf) == 0:
25 | table = None
26 | else:
27 | table = lf.pop(0)
28 | if not isinstance(table, define_rule.T):
29 | lf.insert(0, table)
30 | table = None
31 | assert isinstance(agg, define_rule.A) and isinstance(column, define_rule.C)
32 | if len(f.production.split()) == 3:
33 | f.add_children(agg)
34 | agg.set_parent(f)
35 | agg.add_children(column)
36 | column.set_parent(agg)
37 | if table is not None:
38 | column.add_children(table)
39 | table.set_parent(column)
40 | else:
41 | # Subquery
42 | f.add_children(agg)
43 | agg.set_parent(f)
44 | agg.add_children(column)
45 | column.set_parent(agg)
46 | if table is not None:
47 | column.add_children(table)
48 | table.set_parent(column)
49 | _root = _build(lf)
50 | f.add_children(_root)
51 | _root.set_parent(f)
52 |
53 |
54 | def _build_filter(lf, root_filter):
55 | assert isinstance(root_filter, define_rule.Filter)
56 | op = root_filter.production.split()[1]
57 | if op == 'and' or op == 'or':
58 | for i in range(2):
59 | child = lf.pop(0)
60 | op = child.production.split()[1]
61 | if op == 'and' or op == 'or':
62 | _f = _build_filter(lf, child)
63 | root_filter.add_children(_f)
64 | _f.set_parent(root_filter)
65 | else:
66 | _build_single_filter(lf, child)
67 | root_filter.add_children(child)
68 | child.set_parent(root_filter)
69 | else:
70 | _build_single_filter(lf, root_filter)
71 | return root_filter
72 |
73 |
74 | def _build(lf):
75 | root = lf.pop(0)
76 | assert isinstance(root, define_rule.Root)
77 | length = len(root.production.split()) - 1
78 | while len(root.children) != length:
79 | c_instance = lf.pop(0)
80 | if isinstance(c_instance, define_rule.Sel):
81 | sel_instance = c_instance
82 | root.add_children(sel_instance)
83 | sel_instance.set_parent(root)
84 |
85 | # define_rule.N
86 | c_instance = lf.pop(0)
87 | c_instance.set_parent(sel_instance)
88 | sel_instance.add_children(c_instance)
89 | assert isinstance(c_instance, define_rule.N)
90 | for i in range(c_instance.id_c + 1):
91 | agg = lf.pop(0)
92 | column = lf.pop(0)
93 | if len(lf) == 0:
94 | table = None
95 | else:
96 | table = lf.pop(0)
97 | if not isinstance(table, define_rule.T):
98 | lf.insert(0, table)
99 | table = None
100 | assert isinstance(agg, define_rule.A) and isinstance(column, define_rule.C)
101 | c_instance.add_children(agg)
102 | agg.set_parent(c_instance)
103 | agg.add_children(column)
104 | column.set_parent(agg)
105 | if table is not None:
106 | column.add_children(table)
107 | table.set_parent(column)
108 |
109 | elif isinstance(c_instance, define_rule.Sup) or isinstance(c_instance, define_rule.Order):
110 | root.add_children(c_instance)
111 | c_instance.set_parent(root)
112 |
113 | agg = lf.pop(0)
114 | column = lf.pop(0)
115 | if len(lf) == 0:
116 | table = None
117 | else:
118 | table = lf.pop(0)
119 | if not isinstance(table, define_rule.T):
120 | lf.insert(0, table)
121 | table = None
122 | assert isinstance(agg, define_rule.A) and isinstance(column, define_rule.C)
123 | c_instance.add_children(agg)
124 | agg.set_parent(c_instance)
125 | agg.add_children(column)
126 | column.set_parent(agg)
127 | if table is not None:
128 | column.add_children(table)
129 | table.set_parent(column)
130 |
131 | elif isinstance(c_instance, define_rule.Filter):
132 | _build_filter(lf, c_instance)
133 | root.add_children(c_instance)
134 | c_instance.set_parent(root)
135 |
136 | return root
137 |
138 |
139 | def build_tree(lf):
140 | root = lf.pop(0)
141 | assert isinstance(root, define_rule.Root1)
142 | if root.id_c == 0 or root.id_c == 1 or root.id_c == 2:
143 | root_1 = _build(lf)
144 | root_2 = _build(lf)
145 | root.add_children(root_1)
146 | root.add_children(root_2)
147 | root_1.set_parent(root)
148 | root_2.set_parent(root)
149 | else:
150 | root_1 = _build(lf)
151 | root.add_children(root_1)
152 | root_1.set_parent(root)
153 | verify(root)
154 | # eliminate_parent(root)
155 |
156 |
157 | def eliminate_parent(node):
158 | for child in node.children:
159 | eliminate_parent(child)
160 | node.children = list()
161 |
162 |
163 | def verify(node):
164 | if isinstance(node, C) and len(node.children) > 0:
165 | table = node.children[0]
166 | assert table is None or isinstance(table, T)
167 | if isinstance(node, T):
168 | return
169 | children_num = len(node.children)
170 | if isinstance(node, Root1):
171 | if node.id_c == 0 or node.id_c == 1 or node.id_c == 2:
172 | assert children_num == 2
173 | else:
174 | assert children_num == 1
175 | elif isinstance(node, Root):
176 | assert children_num == len(node.production.split()) - 1
177 | elif isinstance(node, N):
178 | assert children_num == int(node.id_c) + 1
179 | elif isinstance(node, Sup) or isinstance(node, Order) or isinstance(node, Sel):
180 | assert children_num == 1
181 | elif isinstance(node, Filter):
182 | op = node.production.split()[1]
183 | if op == 'and' or op == 'or':
184 | assert children_num == 2
185 | else:
186 | if len(node.production.split()) == 3:
187 | assert children_num == 1
188 | else:
189 | assert children_num == 2
190 | for child in node.children:
191 | assert child.parent == node
192 | verify(child)
193 |
194 |
195 | def label_matrix(lf, matrix, node):
196 | nindex = lf.index(node)
197 | for child in node.children:
198 | if child not in lf:
199 | continue
200 | index = lf.index(child)
201 | matrix[nindex][index] = 1
202 | label_matrix(lf, matrix, child)
203 |
204 |
205 | def build_adjacency_matrix(lf, symmetry=False):
206 | _lf = list()
207 | for rule in lf:
208 | if isinstance(rule, A) or isinstance(rule, C) or isinstance(rule, T):
209 | continue
210 | _lf.append(rule)
211 | length = len(_lf)
212 | matrix = np.zeros((length, length,))
213 | label_matrix(_lf, matrix, _lf[0])
214 | if symmetry:
215 | matrix += matrix.T
216 | return matrix
217 |
218 |
219 | if __name__ == '__main__':
220 | with open(r'..\data\train.json', 'r') as f:
221 | data = json.load(f)
222 | for d in data:
223 | rule_label = [eval(x) for x in d['rule_label'].strip().split(' ')]
224 | print(d['question'])
225 | print(rule_label)
226 | build_tree(copy.copy(rule_label))
227 | adjacency_matrix = build_adjacency_matrix(rule_label, symmetry=True)
228 | print(adjacency_matrix)
229 | print('===\n\n')
230 |
--------------------------------------------------------------------------------
/src/rule/semQL.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/5/24
7 | # @Author : Jiaqi&Zecheng
8 | # @File : semQL.py
9 | # @Software: PyCharm
10 | """
11 |
12 | Keywords = ['des', 'asc', 'and', 'or', 'sum', 'min', 'max', 'avg', 'none', '=', '!=', '<', '>', '<=', '>=', 'between', 'like', 'not_like'] + [
13 | 'in', 'not_in', 'count', 'intersect', 'union', 'except'
14 | ]
15 |
16 |
17 | class Grammar(object):
18 | def __init__(self, is_sketch=False):
19 | self.begin = 0
20 | self.type_id = 0
21 | self.is_sketch = is_sketch
22 | self.prod2id = {}
23 | self.type2id = {}
24 | self._init_grammar(Sel)
25 | self._init_grammar(Root)
26 | self._init_grammar(Sup)
27 | self._init_grammar(Filter)
28 | self._init_grammar(Order)
29 | self._init_grammar(N)
30 | self._init_grammar(Root1)
31 |
32 | if not self.is_sketch:
33 | self._init_grammar(A)
34 |
35 | self._init_id2prod()
36 | self.type2id[C] = self.type_id
37 | self.type_id += 1
38 | self.type2id[T] = self.type_id
39 |
40 | def _init_grammar(self, Cls):
41 | """
42 | get the production of class Cls
43 | :param Cls:
44 | :return:
45 | """
46 | production = Cls._init_grammar()
47 | for p in production:
48 | self.prod2id[p] = self.begin
49 | self.begin += 1
50 | self.type2id[Cls] = self.type_id
51 | self.type_id += 1
52 |
53 | def _init_id2prod(self):
54 | self.id2prod = {}
55 | for key, value in self.prod2id.items():
56 | self.id2prod[value] = key
57 |
58 | def get_production(self, Cls):
59 | return Cls._init_grammar()
60 |
61 |
62 | class Action(object):
63 | def __init__(self):
64 | self.pt = 0
65 | self.production = None
66 | self.children = list()
67 |
68 | def get_next_action(self, is_sketch=False):
69 | actions = list()
70 | for x in self.production.split(' ')[1:]:
71 | if x not in Keywords:
72 | rule_type = eval(x)
73 | if is_sketch:
74 | if rule_type is not A:
75 | actions.append(rule_type)
76 | else:
77 | actions.append(rule_type)
78 | return actions
79 |
80 | def set_parent(self, parent):
81 | self.parent = parent
82 |
83 | def add_children(self, child):
84 | self.children.append(child)
85 |
86 |
87 | class Root1(Action):
88 | def __init__(self, id_c, parent=None):
89 | super(Root1, self).__init__()
90 | self.parent = parent
91 | self.id_c = id_c
92 | self._init_grammar()
93 | self.production = self.grammar_dict[id_c]
94 |
95 | @classmethod
96 | def _init_grammar(self):
97 | # TODO: should add Root grammar to this
98 | self.grammar_dict = {
99 | 0: 'Root1 intersect Root Root',
100 | 1: 'Root1 union Root Root',
101 | 2: 'Root1 except Root Root',
102 | 3: 'Root1 Root',
103 | }
104 | self.production_id = {}
105 | for id_x, value in enumerate(self.grammar_dict.values()):
106 | self.production_id[value] = id_x
107 |
108 | return self.grammar_dict.values()
109 |
110 | def __str__(self):
111 | return 'Root1(' + str(self.id_c) + ')'
112 |
113 | def __repr__(self):
114 | return 'Root1(' + str(self.id_c) + ')'
115 |
116 |
117 | class Root(Action):
118 | def __init__(self, id_c, parent=None):
119 | super(Root, self).__init__()
120 | self.parent = parent
121 | self.id_c = id_c
122 | self._init_grammar()
123 | self.production = self.grammar_dict[id_c]
124 |
125 | @classmethod
126 | def _init_grammar(self):
127 | # TODO: should add Root grammar to this
128 | self.grammar_dict = {
129 | 0: 'Root Sel Sup Filter',
130 | 1: 'Root Sel Filter Order',
131 | 2: 'Root Sel Sup',
132 | 3: 'Root Sel Filter',
133 | 4: 'Root Sel Order',
134 | 5: 'Root Sel'
135 | }
136 | self.production_id = {}
137 | for id_x, value in enumerate(self.grammar_dict.values()):
138 | self.production_id[value] = id_x
139 |
140 | return self.grammar_dict.values()
141 |
142 | def __str__(self):
143 | return 'Root(' + str(self.id_c) + ')'
144 |
145 | def __repr__(self):
146 | return 'Root(' + str(self.id_c) + ')'
147 |
148 |
149 | class N(Action):
150 | """
151 | Number of Columns
152 | """
153 | def __init__(self, id_c, parent=None):
154 | super(N, self).__init__()
155 | self.parent = parent
156 | self.id_c = id_c
157 | self._init_grammar()
158 | self.production = self.grammar_dict[id_c]
159 |
160 | @classmethod
161 | def _init_grammar(self):
162 | self.grammar_dict = {
163 | 0: 'N A',
164 | 1: 'N A A',
165 | 2: 'N A A A',
166 | 3: 'N A A A A',
167 | 4: 'N A A A A A'
168 | }
169 | self.production_id = {}
170 | for id_x, value in enumerate(self.grammar_dict.values()):
171 | self.production_id[value] = id_x
172 |
173 | return self.grammar_dict.values()
174 |
175 | def __str__(self):
176 | return 'N(' + str(self.id_c) + ')'
177 |
178 | def __repr__(self):
179 | return 'N(' + str(self.id_c) + ')'
180 |
181 | class C(Action):
182 | """
183 | Column
184 | """
185 | def __init__(self, id_c, parent=None):
186 | super(C, self).__init__()
187 | self.parent = parent
188 | self.id_c = id_c
189 | self.production = 'C T'
190 | self.table = None
191 |
192 | def __str__(self):
193 | return 'C(' + str(self.id_c) + ')'
194 |
195 | def __repr__(self):
196 | return 'C(' + str(self.id_c) + ')'
197 |
198 |
199 | class T(Action):
200 | """
201 | Table
202 | """
203 | def __init__(self, id_c, parent=None):
204 | super(T, self).__init__()
205 |
206 | self.parent = parent
207 | self.id_c = id_c
208 | self.production = 'T min'
209 | self.table = None
210 |
211 | def __str__(self):
212 | return 'T(' + str(self.id_c) + ')'
213 |
214 | def __repr__(self):
215 | return 'T(' + str(self.id_c) + ')'
216 |
217 |
218 | class A(Action):
219 | """
220 | Aggregator
221 | """
222 | def __init__(self, id_c, parent=None):
223 | super(A, self).__init__()
224 |
225 | self.parent = parent
226 | self.id_c = id_c
227 | self._init_grammar()
228 | self.production = self.grammar_dict[id_c]
229 |
230 | @classmethod
231 | def _init_grammar(self):
232 | # TODO: should add Root grammar to this
233 | self.grammar_dict = {
234 | 0: 'A none C',
235 | 1: 'A max C',
236 | 2: "A min C",
237 | 3: "A count C",
238 | 4: "A sum C",
239 | 5: "A avg C"
240 | }
241 | self.production_id = {}
242 | for id_x, value in enumerate(self.grammar_dict.values()):
243 | self.production_id[value] = id_x
244 |
245 | return self.grammar_dict.values()
246 |
247 | def __str__(self):
248 | return 'A(' + str(self.id_c) + ')'
249 |
250 | def __repr__(self):
251 | return 'A(' + str(self.grammar_dict[self.id_c].split(' ')[1]) + ')'
252 |
253 |
254 | class Sel(Action):
255 | """
256 | Select
257 | """
258 | def __init__(self, id_c, parent=None):
259 | super(Sel, self).__init__()
260 |
261 | self.parent = parent
262 | self.id_c = id_c
263 | self._init_grammar()
264 | self.production = self.grammar_dict[id_c]
265 |
266 | @classmethod
267 | def _init_grammar(self):
268 | self.grammar_dict = {
269 | 0: 'Sel N',
270 | }
271 | self.production_id = {}
272 | for id_x, value in enumerate(self.grammar_dict.values()):
273 | self.production_id[value] = id_x
274 |
275 | return self.grammar_dict.values()
276 |
277 | def __str__(self):
278 | return 'Sel(' + str(self.id_c) + ')'
279 |
280 | def __repr__(self):
281 | return 'Sel(' + str(self.id_c) + ')'
282 |
283 | class Filter(Action):
284 | """
285 | Filter
286 | """
287 | def __init__(self, id_c, parent=None):
288 | super(Filter, self).__init__()
289 |
290 | self.parent = parent
291 | self.id_c = id_c
292 | self._init_grammar()
293 | self.production = self.grammar_dict[id_c]
294 |
295 | @classmethod
296 | def _init_grammar(self):
297 | self.grammar_dict = {
298 | # 0: "Filter 1"
299 | 0: 'Filter and Filter Filter',
300 | 1: 'Filter or Filter Filter',
301 | 2: 'Filter = A',
302 | 3: 'Filter != A',
303 | 4: 'Filter < A',
304 | 5: 'Filter > A',
305 | 6: 'Filter <= A',
306 | 7: 'Filter >= A',
307 | 8: 'Filter between A',
308 | 9: 'Filter like A',
309 | 10: 'Filter not_like A',
310 | # now begin root
311 | 11: 'Filter = A Root',
312 | 12: 'Filter < A Root',
313 | 13: 'Filter > A Root',
314 | 14: 'Filter != A Root',
315 | 15: 'Filter between A Root',
316 | 16: 'Filter >= A Root',
317 | 17: 'Filter <= A Root',
318 | # now for In
319 | 18: 'Filter in A Root',
320 | 19: 'Filter not_in A Root'
321 |
322 | }
323 | self.production_id = {}
324 | for id_x, value in enumerate(self.grammar_dict.values()):
325 | self.production_id[value] = id_x
326 |
327 | return self.grammar_dict.values()
328 |
329 | def __str__(self):
330 | return 'Filter(' + str(self.id_c) + ')'
331 |
332 | def __repr__(self):
333 | return 'Filter(' + str(self.grammar_dict[self.id_c]) + ')'
334 |
335 |
336 | class Sup(Action):
337 | """
338 | Superlative
339 | """
340 | def __init__(self, id_c, parent=None):
341 | super(Sup, self).__init__()
342 |
343 | self.parent = parent
344 | self.id_c = id_c
345 | self._init_grammar()
346 | self.production = self.grammar_dict[id_c]
347 |
348 | @classmethod
349 | def _init_grammar(self):
350 | self.grammar_dict = {
351 | 0: 'Sup des A',
352 | 1: 'Sup asc A',
353 | }
354 | self.production_id = {}
355 | for id_x, value in enumerate(self.grammar_dict.values()):
356 | self.production_id[value] = id_x
357 |
358 | return self.grammar_dict.values()
359 |
360 | def __str__(self):
361 | return 'Sup(' + str(self.id_c) + ')'
362 |
363 | def __repr__(self):
364 | return 'Sup(' + str(self.id_c) + ')'
365 |
366 |
367 | class Order(Action):
368 | """
369 | Order
370 | """
371 | def __init__(self, id_c, parent=None):
372 | super(Order, self).__init__()
373 |
374 | self.parent = parent
375 | self.id_c = id_c
376 | self._init_grammar()
377 | self.production = self.grammar_dict[id_c]
378 |
379 | @classmethod
380 | def _init_grammar(self):
381 | self.grammar_dict = {
382 | 0: 'Order des A',
383 | 1: 'Order asc A',
384 | }
385 | self.production_id = {}
386 | for id_x, value in enumerate(self.grammar_dict.values()):
387 | self.production_id[value] = id_x
388 |
389 | return self.grammar_dict.values()
390 |
391 | def __str__(self):
392 | return 'Order(' + str(self.id_c) + ')'
393 |
394 | def __repr__(self):
395 | return 'Order(' + str(self.id_c) + ')'
396 |
397 |
398 | if __name__ == '__main__':
399 | print(list(Root._init_grammar()))
400 |
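401 |     # Minimal usage sketch (illustrative only): a Grammar instance maps every
402 |     # production string to an integer id, which serves as the decoder's
403 |     # action vocabulary.
404 |     grammar = Grammar()
405 |     print(len(grammar.prod2id), 'productions')
406 |     print(grammar.prod2id['Sel N'], grammar.prod2id['Filter = A'])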
--------------------------------------------------------------------------------
/src/rule/sem_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/5/27
7 | # @Author : Jiaqi&Zecheng
8 | # @File : sem_utils.py
9 | # @Software: PyCharm
10 | """
11 |
12 | import os
13 | import json
14 | import argparse
15 | import re as regex
16 | from nltk.stem import WordNetLemmatizer
17 | from pattern.en import lemma
18 | wordnet_lemmatizer = WordNetLemmatizer()
19 |
20 |
21 | def load_dataSets(args):
22 | with open(args.input_path, 'r') as f:
23 | datas = json.load(f)
24 | with open(os.path.join(args.data_path, 'tables.json'), 'r', encoding='utf8') as f:
25 | table_datas = json.load(f)
26 | schemas = dict()
27 | for i in range(len(table_datas)):
28 | schemas[table_datas[i]['db_id']] = table_datas[i]
29 | return datas, schemas
30 |
31 |
32 | def partial_match(query, table_name):
33 | query = [lemma(x) for x in query]
34 | table_name = [lemma(x) for x in table_name]
35 | if query in table_name:
36 | return True
37 | return False
38 |
39 |
40 | def is_partial_match(query, table_names):
41 | query = lemma(query)
42 | table_names = [[lemma(x) for x in names.split(' ') ] for names in table_names]
43 | same_count = 0
44 | result = None
45 | for names in table_names:
46 | if query in names:
47 | same_count += 1
48 | result = names
49 | return result if same_count == 1 else False
50 |
51 |
52 | def multi_option(question, q_ind, names, N):
53 | for i in range(q_ind + 1, q_ind + N + 1):
54 | if i < len(question):
55 | re = is_partial_match(question[i][0], names)
56 | if re is not False:
57 | return re
58 | return False
59 |
60 |
61 | def multi_equal(question, q_ind, names, N):
62 | for i in range(q_ind + 1, q_ind + N + 1):
63 | if i < len(question):
64 | if question[i] == names:
65 | return i
66 | return False
67 |
68 |
69 | def random_choice(question_arg, question_arg_type, names, ground_col_labels, q_ind, N, origin_name):
70 |     # first, check whether the question mentions another table
71 | for t_ind, t_val in enumerate(question_arg_type):
72 | if t_val == ['table']:
73 | return names[origin_name.index(question_arg[t_ind])]
74 | for i in range(q_ind + 1, q_ind + N + 1):
75 | if i < len(question_arg):
76 | if len(ground_col_labels) == 0:
77 | for n in names:
78 | if partial_match(question_arg[i][0], n) is True:
79 | return n
80 | else:
81 | for n_id, n in enumerate(names):
82 | if n_id in ground_col_labels and partial_match(question_arg[i][0], n) is True:
83 | return n
84 | if len(ground_col_labels) > 0:
85 | return names[ground_col_labels[0]]
86 | else:
87 | return names[0]
88 |
89 |
90 | def find_table(cur_table, origin_table_names, question_arg_type, question_arg):
91 | h_table = None
92 | for i in range(len(question_arg_type))[::-1]:
93 | if question_arg_type[i] == ['table']:
94 | h_table = question_arg[i]
95 | h_table = origin_table_names.index(h_table)
96 | if h_table != cur_table:
97 | break
98 | if h_table != cur_table:
99 | return h_table
100 |
101 | # find partial
102 | for i in range(len(question_arg_type))[::-1]:
103 | if question_arg_type[i] == ['NONE']:
104 | for t_id, table_name in enumerate(origin_table_names):
105 | if partial_match(question_arg[i], table_name) is True and t_id != h_table:
106 | return t_id
107 |
108 | # random return
109 | for i in range(len(question_arg_type))[::-1]:
110 | if question_arg_type[i] == ['table']:
111 | h_table = question_arg[i]
112 | h_table = origin_table_names.index(h_table)
113 | return h_table
114 |
115 | return cur_table
116 |
117 |
118 | def alter_not_in(datas, schemas):
119 | for d in datas:
120 | if 'Filter(19)' in d['model_result']:
121 | current_table = schemas[d['db_id']]
122 | current_table['schema_content_clean'] = [x[1] for x in current_table['column_names']]
123 | current_table['col_table'] = [col[0] for col in current_table['column_names']]
124 | origin_table_names = [[wordnet_lemmatizer.lemmatize(x.lower()) for x in names.split(' ')] for names in
125 | d['table_names']]
126 | question_arg_type = d['question_arg_type']
127 | question_arg = d['question_arg']
128 | pred_label = d['model_result'].split(' ')
129 |
130 |             # get potential table
131 | cur_table = None
132 | for label_id, label_val in enumerate(pred_label):
133 | if label_val in ['Filter(19)']:
134 | cur_table = int(pred_label[label_id - 1][2:-1])
135 | break
136 |
137 | h_table = find_table(cur_table, origin_table_names, question_arg_type, question_arg)
138 |
139 | for label_id, label_val in enumerate(pred_label):
140 | if label_val in ['Filter(19)']:
141 | for primary in current_table['primary_keys']:
142 | if int(current_table['col_table'][primary]) == int(pred_label[label_id - 1][2:-1]):
143 | pred_label[label_id + 2] = 'C(' + str(
144 | d['col_set'].index(current_table['schema_content_clean'][primary])) + ')'
145 | break
146 | for pair in current_table['foreign_keys']:
147 | if int(current_table['col_table'][pair[0]]) == h_table and d['col_set'].index(
148 | current_table['schema_content_clean'][pair[1]]) == int(pred_label[label_id + 2][2:-1]):
149 | pred_label[label_id + 8] = 'C(' + str(
150 | d['col_set'].index(current_table['schema_content_clean'][pair[0]])) + ')'
151 | pred_label[label_id + 9] = 'T(' + str(h_table) + ')'
152 | break
153 | elif int(current_table['col_table'][pair[1]]) == h_table and d['col_set'].index(
154 | current_table['schema_content_clean'][pair[0]]) == int(pred_label[label_id + 2][2:-1]):
155 | pred_label[label_id + 8] = 'C(' + str(
156 | d['col_set'].index(current_table['schema_content_clean'][pair[1]])) + ')'
157 | pred_label[label_id + 9] = 'T(' + str(h_table) + ')'
158 | break
159 | pred_label[label_id + 3] = pred_label[label_id - 1]
160 |
161 | d['model_result'] = " ".join(pred_label)
162 |
163 |
164 | def alter_inter(datas):
165 | for d in datas:
166 | if 'Filter(0)' in d['model_result']:
167 | now_result = d['model_result'].split(' ')
168 | index = now_result.index('Filter(0)')
169 | c1 = None
170 | c2 = None
171 | for i in range(index + 1, len(now_result)):
172 | if c1 is None and 'C(' in now_result[i]:
173 | c1 = now_result[i]
174 | elif c1 is not None and c2 is None and 'C(' in now_result[i]:
175 | c2 = now_result[i]
176 |
177 | if c1 != c2 or c1 is None or c2 is None:
178 | continue
179 | replace_result = ['Root1(0)'] + now_result[1:now_result.index('Filter(0)')]
180 | for r_id, r_val in enumerate(now_result[now_result.index('Filter(0)') + 2:]):
181 | if 'Filter' in r_val:
182 | break
183 |
184 | replace_result = replace_result + now_result[now_result.index('Filter(0)') + 1:r_id + now_result.index(
185 | 'Filter(0)') + 2]
186 | replace_result = replace_result + now_result[1:now_result.index('Filter(0)')]
187 |
188 | replace_result = replace_result + now_result[r_id + now_result.index('Filter(0)') + 2:]
189 | replace_result = " ".join(replace_result)
190 | d['model_result'] = replace_result
191 |
192 |
193 | def alter_column0(datas):
194 | """
195 |     Attach a concrete table to the wildcard column C(0), i.e. '*'
196 |     :return: None; writes 'model_result_replace' into each data entry
197 | """
198 | zero_count = 0
199 | count = 0
200 | result = []
201 | for d in datas:
202 | if 'C(0)' in d['model_result']:
203 |             pattern = regex.compile(r'C\(.*?\) T\(.*?\)')
204 | result_pattern = list(set(pattern.findall(d['model_result'])))
205 | ground_col_labels = []
206 | for pa in result_pattern:
207 | pa = pa.split(' ')
208 | if pa[0] != 'C(0)':
209 | index = int(pa[1][2:-1])
210 | ground_col_labels.append(index)
211 |
212 | ground_col_labels = list(set(ground_col_labels))
213 | question_arg_type = d['question_arg_type']
214 | question_arg = d['question_arg']
215 | table_names = [[lemma(x) for x in names.split(' ')] for names in d['table_names']]
216 | origin_table_names = [[wordnet_lemmatizer.lemmatize(x.lower()) for x in names.split(' ')] for names in
217 | d['table_names']]
218 | count += 1
219 | easy_flag = False
220 | for q_ind, q in enumerate(d['question_arg']):
221 | q = [lemma(x) for x in q]
222 | q_str = " ".join(" ".join(x) for x in d['question_arg'])
223 | if 'how many' in q_str or 'number of' in q_str or 'count of' in q_str:
224 | easy_flag = True
225 | if easy_flag:
226 | # check for the last one is a table word
227 | for q_ind, q in enumerate(d['question_arg']):
228 | if (q_ind > 0 and q == ['many'] and d['question_arg'][q_ind - 1] == ['how']) or (
229 | q_ind > 0 and q == ['of'] and d['question_arg'][q_ind - 1] == ['number']) or (
230 | q_ind > 0 and q == ['of'] and d['question_arg'][q_ind - 1] == ['count']):
231 | re = multi_equal(question_arg_type, q_ind, ['table'], 2)
232 | if re is not False:
233 | # This step work for the number of [table] example
234 | table_result = table_names[origin_table_names.index(question_arg[re])]
235 | result.append((d['query'], d['question'], table_result, d))
236 | break
237 | else:
238 | re = multi_option(question_arg, q_ind, d['table_names'], 2)
239 | if re is not False:
240 | table_result = re
241 | result.append((d['query'], d['question'], table_result, d))
242 | pass
243 | else:
244 | re = multi_equal(question_arg_type, q_ind, ['table'], len(question_arg_type))
245 | if re is not False:
246 | # This step work for the number of [table] example
247 | table_result = table_names[origin_table_names.index(question_arg[re])]
248 | result.append((d['query'], d['question'], table_result, d))
249 | break
250 | pass
251 | table_result = random_choice(question_arg=question_arg,
252 | question_arg_type=question_arg_type,
253 | names=table_names,
254 | ground_col_labels=ground_col_labels, q_ind=q_ind, N=2,
255 | origin_name=origin_table_names)
256 | result.append((d['query'], d['question'], table_result, d))
257 |
258 | zero_count += 1
259 | break
260 |
261 | else:
262 | M_OP = False
263 | for q_ind, q in enumerate(d['question_arg']):
264 | if M_OP is False and q in [['than'], ['least'], ['most'], ['msot'], ['fewest']] or \
265 | question_arg_type[q_ind] == ['M_OP']:
266 | M_OP = True
267 | re = multi_equal(question_arg_type, q_ind, ['table'], 3)
268 | if re is not False:
269 | # This step work for the number of [table] example
270 | table_result = table_names[origin_table_names.index(question_arg[re])]
271 | result.append((d['query'], d['question'], table_result, d))
272 | break
273 | else:
274 | re = multi_option(question_arg, q_ind, d['table_names'], 3)
275 | if re is not False:
276 | table_result = re
277 | # print(table_result)
278 | result.append((d['query'], d['question'], table_result, d))
279 | pass
280 | else:
281 | # zero_count += 1
282 | re = multi_equal(question_arg_type, q_ind, ['table'], len(question_arg_type))
283 | if re is not False:
284 | # This step work for the number of [table] example
285 | table_result = table_names[origin_table_names.index(question_arg[re])]
286 | result.append((d['query'], d['question'], table_result, d))
287 | break
288 |
289 | table_result = random_choice(question_arg=question_arg,
290 | question_arg_type=question_arg_type,
291 | names=table_names,
292 | ground_col_labels=ground_col_labels, q_ind=q_ind, N=2,
293 | origin_name=origin_table_names)
294 | result.append((d['query'], d['question'], table_result, d))
295 |
296 | pass
297 | if M_OP is False:
298 | table_result = random_choice(question_arg=question_arg,
299 | question_arg_type=question_arg_type,
300 | names=table_names, ground_col_labels=ground_col_labels, q_ind=q_ind,
301 | N=2,
302 | origin_name=origin_table_names)
303 | result.append((d['query'], d['question'], table_result, d))
304 |
305 | for re in result:
306 | table_names = [[lemma(x) for x in names.split(' ')] for names in re[3]['table_names']]
307 | origin_table_names = [[x for x in names.split(' ')] for names in re[3]['table_names']]
308 | if re[2] in table_names:
309 | re[3]['rule_count'] = table_names.index(re[2])
310 | else:
311 | re[3]['rule_count'] = origin_table_names.index(re[2])
312 |
313 | for data in datas:
314 | if 'rule_count' in data:
315 | str_replace = 'C(0) T(' + str(data['rule_count']) + ')'
316 |             replace_result = regex.sub(r'C\(0\) T\(.\)', str_replace, data['model_result'])
317 | data['model_result_replace'] = replace_result
318 | else:
319 | data['model_result_replace'] = data['model_result']
320 |
321 |
322 |
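323 | if __name__ == '__main__':
324 |     # Minimal usage sketch (illustrative only): multi_equal scans a small window
325 |     # after position q_ind for an exact match and returns its index, or False.
326 |     toks = [['how'], ['many'], ['table'], ['?']]
327 |     print(multi_equal(toks, 0, ['table'], 3))  # -> 2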
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
5 | """
6 | # @Time : 2019/5/25
7 | # @Author : Jiaqi&Zecheng
8 | # @File : utils.py
9 | # @Software: PyCharm
10 | """
11 |
12 | import json
13 | import time
14 |
15 | import copy
16 | import numpy as np
17 | import os
18 | import torch
19 | from nltk.stem import WordNetLemmatizer
20 |
21 | from src.dataset import Example
22 | from src.rule import lf
23 | from src.rule.semQL import Sup, Sel, Order, Root, Filter, A, N, C, T, Root1
24 |
25 | wordnet_lemmatizer = WordNetLemmatizer()
26 |
27 |
28 | def load_word_emb(file_name, use_small=False):
29 |     print('Loading word embedding from %s' % file_name)
30 | ret = {}
31 | with open(file_name) as inf:
32 | for idx, line in enumerate(inf):
33 | if (use_small and idx >= 500000):
34 | break
35 | info = line.strip().split(' ')
36 | if info[0].lower() not in ret:
37 | ret[info[0]] = np.array(list(map(lambda x:float(x), info[1:])))
38 | return ret
39 |
40 | def lower_keys(x):
41 | if isinstance(x, list):
42 | return [lower_keys(v) for v in x]
43 | elif isinstance(x, dict):
44 | return dict((k.lower(), lower_keys(v)) for k, v in x.items())
45 | else:
46 | return x
47 |
48 | def get_table_colNames(tab_ids, tab_cols):
49 | table_col_dict = {}
50 | for ci, cv in zip(tab_ids, tab_cols):
51 | if ci != -1:
52 | table_col_dict[ci] = table_col_dict.get(ci, []) + cv
53 | result = []
54 | for ci in range(len(table_col_dict)):
55 | result.append(table_col_dict[ci])
56 | return result
57 |
58 | def get_col_table_dict(tab_cols, tab_ids, sql):
59 | table_dict = {}
60 | for c_id, c_v in enumerate(sql['col_set']):
61 | for cor_id, cor_val in enumerate(tab_cols):
62 | if c_v == cor_val:
63 | table_dict[tab_ids[cor_id]] = table_dict.get(tab_ids[cor_id], []) + [c_id]
64 |
65 | col_table_dict = {}
66 | for key_item, value_item in table_dict.items():
67 | for value in value_item:
68 | col_table_dict[value] = col_table_dict.get(value, []) + [key_item]
69 | col_table_dict[0] = [x for x in range(len(table_dict) - 1)]
70 | return col_table_dict
71 |
72 |
73 | def schema_linking(question_arg, question_arg_type, one_hot_type, col_set_type, col_set_iter, sql):
74 |
75 | for count_q, t_q in enumerate(question_arg_type):
76 | t = t_q[0]
77 | if t == 'NONE':
78 | continue
79 | elif t == 'table':
80 | one_hot_type[count_q][0] = 1
81 | question_arg[count_q] = ['table'] + question_arg[count_q]
82 | elif t == 'col':
83 | one_hot_type[count_q][1] = 1
84 | try:
85 | col_set_type[col_set_iter.index(question_arg[count_q])][1] = 5
86 | question_arg[count_q] = ['column'] + question_arg[count_q]
87 | except:
88 | print(col_set_iter, question_arg[count_q])
89 | raise RuntimeError("not in col set")
90 | elif t == 'agg':
91 | one_hot_type[count_q][2] = 1
92 | elif t == 'MORE':
93 | one_hot_type[count_q][3] = 1
94 | elif t == 'MOST':
95 | one_hot_type[count_q][4] = 1
96 | elif t == 'value':
97 | one_hot_type[count_q][5] = 1
98 | question_arg[count_q] = ['value'] + question_arg[count_q]
99 | else:
100 | if len(t_q) == 1:
101 | for col_probase in t_q:
102 | if col_probase == 'asd':
103 | continue
104 | try:
105 | col_set_type[sql['col_set'].index(col_probase)][2] = 5
106 | question_arg[count_q] = ['value'] + question_arg[count_q]
107 | except:
108 | print(sql['col_set'], col_probase)
109 | raise RuntimeError('not in col')
110 | one_hot_type[count_q][5] = 1
111 | else:
112 | for col_probase in t_q:
113 | if col_probase == 'asd':
114 | continue
115 | col_set_type[sql['col_set'].index(col_probase)][3] += 1
116 |
117 | def process(sql, table):
118 |
119 | process_dict = {}
120 |
121 | origin_sql = sql['question_toks']
122 | table_names = [[wordnet_lemmatizer.lemmatize(v).lower() for v in x.split(' ')] for x in table['table_names']]
123 |
124 | sql['pre_sql'] = copy.deepcopy(sql)
125 |
126 | tab_cols = [col[1] for col in table['column_names']]
127 | tab_ids = [col[0] for col in table['column_names']]
128 |
129 | col_set_iter = [[wordnet_lemmatizer.lemmatize(v).lower() for v in x.split(' ')] for x in sql['col_set']]
130 | col_iter = [[wordnet_lemmatizer.lemmatize(v).lower() for v in x.split(" ")] for x in tab_cols]
131 | q_iter_small = [wordnet_lemmatizer.lemmatize(x).lower() for x in origin_sql]
132 | question_arg = copy.deepcopy(sql['question_arg'])
133 | question_arg_type = sql['question_arg_type']
134 | one_hot_type = np.zeros((len(question_arg_type), 6))
135 |
136 | col_set_type = np.zeros((len(col_set_iter), 4))
137 |
138 | process_dict['col_set_iter'] = col_set_iter
139 | process_dict['q_iter_small'] = q_iter_small
140 | process_dict['col_set_type'] = col_set_type
141 | process_dict['question_arg'] = question_arg
142 | process_dict['question_arg_type'] = question_arg_type
143 | process_dict['one_hot_type'] = one_hot_type
144 | process_dict['tab_cols'] = tab_cols
145 | process_dict['tab_ids'] = tab_ids
146 | process_dict['col_iter'] = col_iter
147 | process_dict['table_names'] = table_names
148 |
149 | return process_dict
150 |
151 | def is_valid(rule_label, col_table_dict, sql):
152 | try:
153 | lf.build_tree(copy.copy(rule_label))
154 | except:
155 | print(rule_label)
156 |
157 | flag = False
158 | for r_id, rule in enumerate(rule_label):
159 | if type(rule) == C:
160 | try:
161 | assert rule_label[r_id + 1].id_c in col_table_dict[rule.id_c], print(sql['question'])
162 | except:
163 | flag = True
164 | print(sql['question'])
165 | return flag is False
166 |
167 |
168 | def to_batch_seq(sql_data, table_data, idxes, st, ed,
169 | is_train=True):
170 | """
171 |     build an Example for each entry in sql_data[idxes[st:ed]], pairing the question with its database schema
172 |     :return: list of Example objects (sorted by question length when is_train is True)
173 | """
174 | examples = []
175 |
176 | for i in range(st, ed):
177 | sql = sql_data[idxes[i]]
178 | table = table_data[sql['db_id']]
179 |
180 | process_dict = process(sql, table)
181 |
182 | for c_id, col_ in enumerate(process_dict['col_set_iter']):
183 | for q_id, ori in enumerate(process_dict['q_iter_small']):
184 | if ori in col_:
185 | process_dict['col_set_type'][c_id][0] += 1
186 |
187 | schema_linking(process_dict['question_arg'], process_dict['question_arg_type'],
188 | process_dict['one_hot_type'], process_dict['col_set_type'], process_dict['col_set_iter'], sql)
189 |
190 | col_table_dict = get_col_table_dict(process_dict['tab_cols'], process_dict['tab_ids'], sql)
191 | table_col_name = get_table_colNames(process_dict['tab_ids'], process_dict['col_iter'])
192 |
193 | process_dict['col_set_iter'][0] = ['count', 'number', 'many']
194 |
195 | rule_label = None
196 | if 'rule_label' in sql:
197 | try:
198 | rule_label = [eval(x) for x in sql['rule_label'].strip().split(' ')]
199 | except:
200 | continue
201 | if is_valid(rule_label, col_table_dict=col_table_dict, sql=sql) is False:
202 | continue
203 |
204 | example = Example(
205 | src_sent=process_dict['question_arg'],
206 | col_num=len(process_dict['col_set_iter']),
207 | vis_seq=(sql['question'], process_dict['col_set_iter'], sql['query']),
208 | tab_cols=process_dict['col_set_iter'],
209 | sql=sql['query'],
210 | one_hot_type=process_dict['one_hot_type'],
211 | col_hot_type=process_dict['col_set_type'],
212 | table_names=process_dict['table_names'],
213 | table_len=len(process_dict['table_names']),
214 | col_table_dict=col_table_dict,
215 | cols=process_dict['tab_cols'],
216 | table_col_name=table_col_name,
217 | table_col_len=len(table_col_name),
218 | tokenized_src_sent=process_dict['col_set_type'],
219 | tgt_actions=rule_label
220 | )
221 | example.sql_json = copy.deepcopy(sql)
222 | examples.append(example)
223 |
224 |     if is_train:
225 |         # sort training examples by descending question length
226 |         examples.sort(key=lambda e: -len(e.src_sent))
227 |         return examples
228 |     return examples
229 |
230 | def epoch_train(model, optimizer, batch_size, sql_data, table_data,
231 | args, epoch=0, loss_epoch_threshold=20, sketch_loss_coefficient=0.2):
232 | model.train()
233 |     # shuffle the training data for this epoch
234 |     perm = np.random.permutation(len(sql_data))
235 | cum_loss = 0.0
236 | st = 0
237 | while st < len(sql_data):
238 | ed = st+batch_size if st+batch_size < len(perm) else len(perm)
239 | examples = to_batch_seq(sql_data, table_data, perm, st, ed)
240 | optimizer.zero_grad()
241 |
242 | score = model.forward(examples)
243 | loss_sketch = -score[0]
244 | loss_lf = -score[1]
245 |
246 | loss_sketch = torch.mean(loss_sketch)
247 | loss_lf = torch.mean(loss_lf)
248 |
249 | if epoch > loss_epoch_threshold:
250 | loss = loss_lf + sketch_loss_coefficient * loss_sketch
251 | else:
252 | loss = loss_lf + loss_sketch
253 |
254 | loss.backward()
255 | if args.clip_grad > 0.:
256 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
257 | optimizer.step()
258 | cum_loss += loss.data.cpu().numpy()*(ed - st)
259 | st = ed
260 | return cum_loss / len(sql_data)
261 |
262 | def epoch_acc(model, batch_size, sql_data, table_data, beam_size=3):
263 | model.eval()
264 | perm = list(range(len(sql_data)))
265 | st = 0
266 |
267 | json_datas = []
268 | sketch_correct, rule_label_correct, total = 0, 0, 0
269 | while st < len(sql_data):
270 | ed = st+batch_size if st+batch_size < len(perm) else len(perm)
271 | examples = to_batch_seq(sql_data, table_data, perm, st, ed,
272 | is_train=False)
273 | for example in examples:
274 | results_all = model.parse(example, beam_size=beam_size)
275 | results = results_all[0]
276 | list_preds = []
277 | try:
278 |
279 | pred = " ".join([str(x) for x in results[0].actions])
280 | for x in results:
281 |                     list_preds.append(" ".join(str(act) for act in x.actions))
282 | except Exception as e:
283 | # print('Epoch Acc: ', e)
284 | # print(results)
285 | # print(results_all)
286 | pred = ""
287 |
288 | simple_json = example.sql_json['pre_sql']
289 |
290 | simple_json['sketch_result'] = " ".join(str(x) for x in results_all[1])
291 | simple_json['model_result'] = pred
292 |
293 | truth_sketch = " ".join([str(x) for x in example.sketch])
294 | truth_rule_label = " ".join([str(x) for x in example.tgt_actions])
295 |
296 | if truth_sketch == simple_json['sketch_result']:
297 | sketch_correct += 1
298 | if truth_rule_label == simple_json['model_result']:
299 | rule_label_correct += 1
300 | total += 1
301 |
302 | json_datas.append(simple_json)
303 | st = ed
304 | return json_datas, float(sketch_correct)/float(total), float(rule_label_correct)/float(total)
305 |
306 | def eval_acc(preds, sqls):
307 | sketch_correct, best_correct = 0, 0
308 | for i, (pred, sql) in enumerate(zip(preds, sqls)):
309 | if pred['model_result'] == sql['rule_label']:
310 | best_correct += 1
311 | print(best_correct / len(preds))
312 | return best_correct / len(preds)
313 |
314 |
315 | def load_data_new(sql_path, table_data, use_small=False):
316 | sql_data = []
317 |
318 | print("Loading data from %s" % sql_path)
319 | with open(sql_path) as inf:
320 | data = lower_keys(json.load(inf))
321 | sql_data += data
322 |
323 | table_data_new = {table['db_id']: table for table in table_data}
324 |
325 | if use_small:
326 | return sql_data[:80], table_data_new
327 | else:
328 | return sql_data, table_data_new
329 |
330 |
331 | def load_dataset(dataset_dir, use_small=False):
332 | print("Loading from datasets...")
333 |
334 | TABLE_PATH = os.path.join(dataset_dir, "tables.json")
335 | TRAIN_PATH = os.path.join(dataset_dir, "train.json")
336 | DEV_PATH = os.path.join(dataset_dir, "dev.json")
337 | with open(TABLE_PATH) as inf:
338 | print("Loading data from %s"%TABLE_PATH)
339 | table_data = json.load(inf)
340 |
341 | train_sql_data, train_table_data = load_data_new(TRAIN_PATH, table_data, use_small=use_small)
342 | val_sql_data, val_table_data = load_data_new(DEV_PATH, table_data, use_small=use_small)
343 |
344 | return train_sql_data, train_table_data, val_sql_data, val_table_data
345 |
346 |
347 | def save_checkpoint(model, checkpoint_name):
348 | torch.save(model.state_dict(), checkpoint_name)
349 |
350 |
351 | def save_args(args, path):
352 | with open(path, 'w') as f:
353 | f.write(json.dumps(vars(args), indent=4))
354 |
355 | def init_log_checkpoint_path(args):
356 | save_path = args.save
357 | dir_name = save_path + str(int(time.time()))
358 | save_path = os.path.join(os.path.curdir, 'saved_model', dir_name)
359 | if os.path.exists(save_path) is False:
360 | os.makedirs(save_path)
361 | return save_path
362 |
--------------------------------------------------------------------------------
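For intuition, a minimal self-contained sketch of the consistency check that is_valid performs above: every column action C must be followed by a table action T whose id appears in col_table_dict for that column. The import path, the C(id)/T(id) constructors, and the .id_c attribute follow their usage in the code above; the toy ids are made up for illustration.

from src.rule.semQL import C, T

# col_table_dict maps a column id to the ids of the tables that contain it,
# as produced by get_col_table_dict() in this file (toy values below).
col_table_dict = {0: [0, 1], 3: [1]}

valid_actions = [C(3), T(1)]      # column 3 belongs to table 1  -> consistent
invalid_actions = [C(3), T(0)]    # column 3 is not in table 0   -> inconsistent

for actions in (valid_actions, invalid_actions):
    ok = all(actions[i + 1].id_c in col_table_dict[action.id_c]
             for i, action in enumerate(actions) if type(action) == C)
    print(ok)  # True, then False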
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | # -*- coding: utf-8 -*-
7 | """
8 | # @Time : 2019/5/25
9 | # @Author : Jiaqi&Zecheng
10 | # @File : train.py
11 | # @Software: PyCharm
12 | """
13 |
14 | import time
15 | import traceback
16 |
17 | import os
18 | import torch
19 | import torch.optim as optim
20 | import tqdm
21 | import copy
22 |
23 | from src import args as arg
24 | from src import utils
25 | from src.models.model import IRNet
26 | from src.rule import semQL
27 |
28 |
29 | def train(args):
30 | """
31 |     :param args: parsed command-line arguments (see src/args.py)
32 |     :return: None; checkpoints and logs are written under ./saved_model/
33 | """
34 |
35 | grammar = semQL.Grammar()
36 | sql_data, table_data, val_sql_data,\
37 |         val_table_data = utils.load_dataset(args.dataset, use_small=args.toy)
38 |
39 | model = IRNet(args, grammar)
40 |
41 |
42 | if args.cuda: model.cuda()
43 |
44 | # now get the optimizer
45 |     optimizer_cls = getattr(optim, args.optimizer)
46 | optimizer = optimizer_cls(model.parameters(), lr=args.lr)
47 | print('Enable Learning Rate Scheduler: ', args.lr_scheduler)
48 | if args.lr_scheduler:
49 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[21, 41], gamma=args.lr_scheduler_gammar)
50 | else:
51 | scheduler = None
52 |
53 | print('Loss epoch threshold: %d' % args.loss_epoch_threshold)
54 | print('Sketch loss coefficient: %f' % args.sketch_loss_coefficient)
55 |
56 | if args.load_model:
57 |         print('load pretrained model from %s' % args.load_model)
58 | pretrained_model = torch.load(args.load_model,
59 | map_location=lambda storage, loc: storage)
60 |         pretrained_state = copy.deepcopy(pretrained_model)
61 |         # drop checkpoint entries that are not present in the current model
62 |         for k in pretrained_model.keys():
63 |             if k not in model.state_dict().keys():
64 |                 del pretrained_state[k]
65 |         model.load_state_dict(pretrained_state)
66 |
67 | model.word_emb = utils.load_word_emb(args.glove_embed_path)
68 |     # begin training
69 |
70 | model_save_path = utils.init_log_checkpoint_path(args)
71 | utils.save_args(args, os.path.join(model_save_path, 'config.json'))
72 | best_dev_acc = .0
73 |
74 | try:
75 | with open(os.path.join(model_save_path, 'epoch.log'), 'w') as epoch_fd:
76 | for epoch in tqdm.tqdm(range(args.epoch)):
77 | if args.lr_scheduler:
78 | scheduler.step()
79 | epoch_begin = time.time()
80 |                 loss = utils.epoch_train(model, optimizer, args.batch_size, sql_data, table_data, args,
81 |                                    epoch=epoch, loss_epoch_threshold=args.loss_epoch_threshold,
82 |                                    sketch_loss_coefficient=args.sketch_loss_coefficient)
83 | epoch_end = time.time()
84 | json_datas, sketch_acc, acc = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data,
85 | beam_size=args.beam_size)
86 | # acc = utils.eval_acc(json_datas, val_sql_data)
87 |
88 | if acc > best_dev_acc:
89 | utils.save_checkpoint(model, os.path.join(model_save_path, 'best_model.model'))
90 | best_dev_acc = acc
91 | utils.save_checkpoint(model, os.path.join(model_save_path, '{%s}_{%s}.model') % (epoch, acc))
92 |
93 | log_str = 'Epoch: %d, Loss: %f, Sketch Acc: %f, Acc: %f, time: %f\n' % (
94 | epoch + 1, loss, sketch_acc, acc, epoch_end - epoch_begin)
95 | tqdm.tqdm.write(log_str)
96 | epoch_fd.write(log_str)
97 | epoch_fd.flush()
98 | except Exception as e:
99 | # Save model
100 | utils.save_checkpoint(model, os.path.join(model_save_path, 'end_model.model'))
101 | print(e)
102 | tb = traceback.format_exc()
103 | print(tb)
104 | else:
105 | utils.save_checkpoint(model, os.path.join(model_save_path, 'end_model.model'))
106 | json_datas, sketch_acc, acc = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data,
107 | beam_size=args.beam_size)
108 | # acc = utils.eval_acc(json_datas, val_sql_data)
109 |
110 | print("Sketch Acc: %f, Acc: %f, Beam Acc: %f" % (sketch_acc, acc, acc,))
111 |
112 |
113 | if __name__ == '__main__':
114 | arg_parser = arg.init_arg_parser()
115 | args = arg.init_config(arg_parser)
116 | print(args)
117 | train(args)
118 |
--------------------------------------------------------------------------------
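For reference, a minimal sketch of reloading a checkpoint written by utils.save_checkpoint (for example the best_model.model produced above) and re-running the dev-set evaluation. The <run_dir> path is a placeholder; the rest mirrors the setup in train() and the helpers in src/utils.py.

import torch

from src import args as arg
from src import utils
from src.models.model import IRNet
from src.rule import semQL

arg_parser = arg.init_arg_parser()
args = arg.init_config(arg_parser)

model = IRNet(args, semQL.Grammar())
if args.cuda:
    model.cuda()

# save_checkpoint() stores a plain state_dict, so load_state_dict() restores it directly.
state = torch.load('./saved_model/<run_dir>/best_model.model',  # <run_dir> is a placeholder
                   map_location=lambda storage, loc: storage)
model.load_state_dict(state)
model.word_emb = utils.load_word_emb(args.glove_embed_path)

_, _, val_sql_data, val_table_data = utils.load_dataset(args.dataset, use_small=args.toy)
json_datas, sketch_acc, acc = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data,
                                              beam_size=args.beam_size)
print('Sketch Acc: %f, Acc: %f' % (sketch_acc, acc))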
/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) Microsoft Corporation.
3 | # Licensed under the MIT license.
4 | 
5 |
6 | devices=$1
7 | save_name=$2
8 |
9 | CUDA_VISIBLE_DEVICES=$devices python -u train.py --dataset ./data \
10 | --glove_embed_path ./data/glove.42B.300d.txt \
11 | --cuda \
12 | --epoch 50 \
13 | --loss_epoch_threshold 50 \
14 | --sketch_loss_coefficient 1.0 \
15 | --beam_size 1 \
16 | --seed 90 \
17 | --save ${save_name} \
18 | --embed_size 300 \
19 | --sentence_features \
20 | --column_pointer \
21 | --hidden_size 300 \
22 | --lr_scheduler \
23 | --lr_scheduler_gammar 0.5 \
24 | --att_vec_size 300 > ${save_name}".log"
--------------------------------------------------------------------------------
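train.sh takes two positional arguments: the CUDA device id(s) to expose and a run name that is used both for --save and for the log file. A typical invocation (the run name is only an example) looks like:

    bash train.sh 0 irnet_run

This trains on GPU 0, writes checkpoints under ./saved_model/irnet_run<timestamp>/ (via init_log_checkpoint_path in src/utils.py), and redirects stdout to irnet_run.log. The paths hard-coded above assume the Spider JSON files live in ./data and the GloVe vectors in ./data/glove.42B.300d.txt.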