├── .gitignore
├── .travis.yml
├── .travis
├── deploy.sh
└── keys.tar.enc
├── LICENSE
├── README.md
├── build.sbt
├── js
└── src
│ ├── main
│ └── scala
│ │ └── com
│ │ └── criteo
│ │ └── vizatra
│ │ └── vizsql
│ │ └── js
│ │ ├── Database.scala
│ │ ├── QueryParser.scala
│ │ ├── common
│ │ └── ParseResult.scala
│ │ └── json
│ │ ├── ColumnReader.scala
│ │ ├── DBReader.scala
│ │ ├── DialectReader.scala
│ │ ├── Reader.scala
│ │ ├── SchemaReader.scala
│ │ └── TableReader.scala
│ └── test
│ └── scala
│ └── com
│ └── criteo
│ └── vizatra
│ └── vizsql
│ └── js
│ ├── DatabaseSpec.scala
│ ├── QueryParserSpec.scala
│ ├── common
│ └── ParseResultSpec.scala
│ └── json
│ ├── ColumnReaderSpec.scala
│ ├── DBReaderSpec.scala
│ ├── DialectReaderSpec.scala
│ ├── SchemaReaderSpec.scala
│ └── TableReaderSpec.scala
├── jvm
└── src
│ └── test
│ └── scala
│ └── com
│ └── criteo
│ └── vizatra
│ └── vizsql
│ ├── ExtractColumnsSpec.scala
│ ├── ExtractPlaceholdersSpec.scala
│ ├── FormatSQLSpec.scala
│ ├── OlapSpec.scala
│ ├── OptimizeSpec.scala
│ ├── ParseSQL99Spec.scala
│ ├── ParseVerticaDialectSpec.scala
│ ├── ParsingErrorsSpec.scala
│ ├── SchemaErrorsSpec.scala
│ ├── ThreadSafetySpec.scala
│ ├── TypingErrorsSpec.scala
│ └── hive
│ ├── HiveParsingErrorsSpec.scala
│ ├── HiveTypeParserSpec.scala
│ └── ParseHiveQuerySpec.scala
├── project
└── plugins.sbt
└── shared
└── src
└── main
└── scala
└── com
└── criteo
└── vizatra
└── vizsql
├── AST.scala
├── Dialect.scala
├── Errors.scala
├── EvalHelper.scala
├── Functions.scala
├── Parser.scala
├── Schema.scala
├── Show.scala
├── Types.scala
├── VizSQL.scala
├── dialects
├── h2
│ └── H2Dialect.scala
├── hive
│ ├── HiveAST.scala
│ ├── HiveDialect.scala
│ ├── HiveFunctions.scala
│ ├── HiveTypes.scala
│ └── TypeParser.scala
├── hsqldb
│ └── HsqlDBDialect.scala
├── mysql
│ └── MySQLDialect.scala
├── postgresql
│ └── PostgresqlDialect.scala
├── sql99
│ └── SQL99.scala
└── vertica
│ └── VerticaDialect.scala
├── olap
└── Olap.scala
└── optimize
└── Optimizer.scala
/.gitignore:
--------------------------------------------------------------------------------
1 | logs/*
2 | target
3 | target/*
4 | *.DS_Store
5 | *.releaseBackup
6 | release.properties
7 | *.iml
8 | *.ipr
9 | *.iws
10 | .idea
11 | .idea/*
12 | src/main/resources/application.conf
13 | .checkstyle
14 | .classpath
15 | .project
16 | .settings/
17 | src/main/webapp/node/
18 | src/main/webapp/node_modules/
19 | src/main/webapp/lib/
20 | src/main/resources/public/
21 | *.bower.json
22 | dependency-reduced-pom.xml
23 | *.swp
24 | .devsettings
25 | effective-pom.xml
26 | drivers/
27 | project/*
28 | !project/plugins.sbt
29 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: scala
2 | scala:
3 | - 2.11.8
4 | env:
5 | global:
6 | secure: IpQPAMFrafCkbzq+/RgHbv7kaTS7yViw4ce6yTNKmdgQeVHBPeRjaKteyuGmaFyOMIvHCsABtk3QWKfOlxyuzv/ZMlfTy69TnoDa62jAFIJLhZ6pqiGswtXwso7EFW6NgBQJi1AEV9mtC0VOH6YklbUSsnZsI1saB2JHD9h9fafOoiH8ifDDGh+vrIiozrr4YjuT/5/fV90rt+Bha5nOVqVmxmUP9Fpmh0ca3hY0Nw/KhtzxP+VAnQd4MaP0CU8b8+H/IVBUnVBB3xj8pPKs+t/di4NBDfDb6IYmGBtHviJxoDxnJwfRlT8I6o/e2EMNs7wQHke29VsI9Vl+BK6oug0JDugTCgocNrvjqsFDu1nInP3D0yMJRsFBIJ+BhmRw5fEi85KMlu+InfQTRq+FzqMHcIuQMYNlCDgIsQK4an3rUmJ8yJ6nX2qTwZjw8kRTVoszlmZpJ6h1pjG95jHq42JObnPge+oJyl4sO4ngClxEiHr1YQBsDuJfgPl2skXn1p3AGnK6AYn8I+NdK0mtBnJCrlonpxYz3i5zL2mUnDTV19dMxAKXljQRckllVO+JDKZiBl8s+GDp/JT5tIKGozp4phRRUlXRW1RROLHU7qvO1SEvC8ERUQS9K0V/QxCPK9L12IkJYUi37aaN/Qw51MLScxwEniAjoOuCaYpxj4c=
7 | before_deploy:
8 | - openssl aes-256-cbc -K $encrypted_12ac5f6463c4_key -iv $encrypted_12ac5f6463c4_iv
9 | -in .travis/keys.tar.enc -out .travis/keys.tar -d
10 | - tar xvf .travis/keys.tar
11 |
12 | deploy:
13 | provider: script
14 | script: './.travis/deploy.sh'
15 | skip_cleanup: true
16 | on:
17 | tags: true
18 |
--------------------------------------------------------------------------------
/.travis/deploy.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Publish the signed JVM artifacts to Sonatype, then release the staging
# repository. Run by Travis' script deploy provider on tagged builds.
#
# Fail fast: without this, a failed publishSigned would still be followed
# by sonatypeRelease.
set -euo pipefail

sbt vizsqlJVM/publishSigned
sbt 'project vizsqlJVM' sonatypeRelease
3 |
--------------------------------------------------------------------------------
/.travis/keys.tar.enc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/criteo/vizsql/b6bcbae61acf8269ec21b786aae8bc27b8adc355/.travis/keys.tar.enc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VizSQL
2 |
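[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Build Status](https://travis-ci.org/criteo/vizsql.svg?branch=master)](https://travis-ci.org/criteo/vizsql)
[![Maven Central](https://maven-badges.herokuapp.com/maven-central/com.criteo/vizsql_2.11/badge.svg)](https://maven-badges.herokuapp.com/maven-central/com.criteo/vizsql_2.11)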
6 |
## VizSQL is a SQL parser & typer for Scala

VizSQL parses SQL statements into a Scala AST. This AST can then be used to support many interesting transformations: for example, we provide a static optimizer and an OLAP query rewriter. VizSQL can also type a query to retrieve the columns of the result set returned by a SQL statement.
10 |
11 | ### Installation
12 |
13 | #### SBT
14 |
15 | ```scala
16 | libraryDependencies += "com.criteo" % "vizsql_2.11" % "{latestVersion}"
17 | ```
18 |
19 | #### Maven
20 |
21 | ```xml
<dependency>
  <groupId>com.criteo</groupId>
  <artifactId>vizsql_2.11</artifactId>
  <version>{latestVersion}</version>
</dependency>
27 | ```
28 |
29 | ### tl;dr
30 |
31 | Here we use VizSQL to parse a SQL SELECT statement based on the SAKILA database, and to retrieve the column names and types of the returned resultset.
32 |
33 | ```scala
import com.criteo.vizatra.vizsql._
35 |
36 | val resultColumns =
37 | VizSQL.parseQuery(
38 | """SELECT country_id, max(last_update) as updated from City as x""",
39 | SAKILA
40 | )
41 | .fold(e => sys.error(s"SQL syntax error: $e"), identity)
42 | .columns
43 | .fold(e => sys.error(s"SQL error: $e"), identity)
44 |
45 | assert(
46 | resultColumns == List(
47 | Column("country_id", INTEGER(nullable = false)),
48 | Column("updated", TIMESTAMP(nullable = true))
49 | )
50 | )
51 | ```
52 |
Note that the query is not executed. It is only parsed, validated, analyzed and typed against the SAKILA database schema. The schema can be provided directly, or extracted at runtime from a JDBC connection if needed. For example, here is the minimal SAKILA schema VizSQL needs to type the previous query (`DB` also requires an implicit `Dialect` in scope):

```scala
implicit val dialect = sql99.dialect
val SAKILA = DB(schemas = List(
57 | Schema(
58 | "sakila",
59 | tables = List(
60 | Table(
61 | "City",
62 | columns = List(
63 | Column("city_id", INTEGER(nullable = false)),
64 | Column("city", STRING(nullable = true)),
65 | Column("country_id", INTEGER(nullable = false)),
66 | Column("last_update", TIMESTAMP(nullable = true))
67 | )
68 | )
69 | )
70 | )
71 | ))
72 | ```
73 |
74 | ### Dialects
75 |
VizSQL supports several dialects, allowing it to understand the SQL syntax and functions specific to different databases. The core dialect is based on the *SQL99* standard, and we enriched it to create specific dialects for **Vertica** and **Hive** (dialects for H2, HSQLDB, PostgreSQL and MySQL are also included). This is of course a work in progress (even for the SQL99 one).
77 |
78 | ### In-browser VizSQL
79 |
Yes, VizSQL can run in browsers: thanks to Scala.js, the project can be compiled to JavaScript, so your front-end code can use VizSQL too.
81 |
82 | To compile to JavaScript:
83 | ```sh
84 | sbt fullOptJS
85 | ```
86 |
This generates `vizsql-opt.js` in the `target` folder, packaged as a CommonJS (Node.js) module.
88 |
89 | To use VizSQL:
90 | ```javascript
91 | const Database = require('vizsql').Database;
92 | const db = Database().from({
93 | /* db definitions */
94 | });
95 |
96 | const parseResult = db.parse("SELECT * FROM table");
97 | ```
98 |
The parse result is an object of the form `{ error, select }`:

- `error` (object), `undefined` if no error is present
```javascript
{
  msg: 'err', // message of the error
  pos: 1      // position of the error
}
```
- `select` (object), `undefined` if there is an error
```javascript
{
  columns: [...], // result columns of the query, each { name, type, nullable }
  tables: [...]   // tables used by the query, each { name, columns, schema }
}
```
115 |
116 | ### License
117 |
118 | This project is licensed under the Apache 2.0 license.
119 |
120 | ### Copyright
121 |
122 | Copyright © [Criteo](http://labs.criteo.com), 2016.
123 |
--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
name := "vizsql"

scalaVersion in ThisBuild := "2.11.8"

// Aggregating root project: commands run here are forwarded to both the JVM and the JS builds.
lazy val root = project.in(file("."))
  .enablePlugins(ScalaJSPlugin)
  .aggregate(vizsqlJS, vizsqlJVM)

// The library is cross-built for the JVM and for Scala.js; common sources live under shared/.
lazy val vizsql = crossProject.in(file("."))
  .settings(
    libraryDependencies ++= Seq(
      "org.scalactic" %%% "scalactic" % "3.0.0",
      "org.scalatest" %%% "scalatest" % "3.0.0" % "test"
    )
  )
  .jvmSettings(
    organization := "com.criteo",
    version := "1.0.0",
    libraryDependencies += "org.scala-lang.modules" %% "scala-parser-combinators" % "1.0.4",
    // Sonatype credentials used when publishing; the password is read from the environment.
    credentials += Credentials(
      "Sonatype Nexus Repository Manager",
      "oss.sonatype.org",
      "criteo-oss",
      sys.env.getOrElse("SONATYPE_PASSWORD", "")
    )
  )
  .jsSettings(
    libraryDependencies += "org.scala-js" %%% "scala-parser-combinators" % "1.0.2",
    // Emit a CommonJS module so the generated JavaScript can be require()d.
    scalaJSModuleKind := ModuleKind.CommonJSModule
  )

// Platform-specific projects derived from the cross project.
lazy val vizsqlJVM = vizsql.jvm
lazy val vizsqlJS = vizsql.js
33 |
// To sync with Maven central, you need to supply the following information.
// This XML is appended verbatim to the generated POM (project URL, license, SCM, developers).
pomExtra in Global := {
  <url>https://github.com/criteo/vizsql</url>
  <licenses>
    <license>
      <name>Apache 2</name>
      <url>http://www.apache.org/licenses/LICENSE-2.0.txt</url>
    </license>
  </licenses>
  <scm>
    <connection>scm:git:github.com/criteo/vizsql.git</connection>
    <developerConnection>scm:git:git@github.com:criteo/vizsql.git</developerConnection>
    <url>github.com/criteo/vizsql</url>
  </scm>
  <developers>
    <developer>
      <name>Guillaume Bort</name>
      <email>g.bort@criteo.com</email>
      <url>https://github.com/guillaumebort</url>
      <organization>Criteo</organization>
      <organizationUrl>http://www.criteo.com</organizationUrl>
    </developer>
  </developers>
}
58 |
// PGP signing configuration for publishSigned. The keyrings are presumably unpacked
// into .travis/ by the Travis before_deploy step (keys.tar) — confirm if that changes.
usePgpKeyHex("755d525885532e9e")
pgpSecretRing := file(".travis/secring.gpg")
pgpPublicRing := file(".travis/pubring.gpg")
// The passphrase is taken from the same SONATYPE_PASSWORD environment variable.
pgpPassphrase := sys.env.get("SONATYPE_PASSWORD").map(_.toCharArray)
63 |
--------------------------------------------------------------------------------
/js/src/main/scala/com/criteo/vizatra/vizsql/js/Database.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.js
2 |
3 | import com.criteo.vizatra.vizsql.DB
4 | import com.criteo.vizatra.vizsql.js.json.DBReader
5 |
6 | import scala.scalajs.js.Dynamic
7 | import scala.scalajs.js
8 | import scala.scalajs.js.annotation.{JSExport, ScalaJSDefined}
9 |
@JSExport("Database")
object Database {
  /**
   * Reads a JavaScript database definition into a VizSQL [[DB]].
   *
   * @param input JS object describing the database (read by DBReader: `dialect`, `schemas`)
   * @return the equivalent VizSQL DB
   */
  @JSExport
  def parse(input: Dynamic): DB = DBReader(input)

  /**
   * Builds a JS-facing [[Database]] from a JavaScript database definition.
   *
   * @param input JS object describing the database
   * @return a Database that parses queries against that definition
   */
  @JSExport
  def from(input: Dynamic): Database = {
    val definition = parse(input)
    new Database(definition)
  }
}
28 |
/**
 * JavaScript-facing handle on a VizSQL database, used to parse queries against it.
 *
 * @param db the VizSQL database queries are resolved and typed against
 */
@ScalaJSDefined
class Database(db: DB) extends js.Object {
  /**
   * Parses and types `query` against this database.
   *
   * The return type is spelled out because this method is part of the exported JS API.
   *
   * @param query SQL text to parse
   * @return a ParseResult with either `error` or `select` set
   */
  def parse(query: String): common.ParseResult = QueryParser.parse(query, db)
}
35 |
--------------------------------------------------------------------------------
/js/src/main/scala/com/criteo/vizatra/vizsql/js/QueryParser.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.js
2 |
3 | import com.criteo.vizatra.vizsql.js.common._
4 | import com.criteo.vizatra.vizsql.{DB, Query, VizSQL}
5 |
6 | import scala.scalajs.js.JSConverters._
7 | import scala.scalajs.js.annotation.JSExport
8 |
@JSExport("QueryParser")
object QueryParser {
  /**
   * Parses `query` against `db` and, when parsing succeeds, types it.
   *
   * @param query SQL text to parse
   * @param db    database the query is resolved against
   * @return a ParseResult carrying either the parse/typing error or the typed select
   */
  @JSExport
  def parse(query: String, db: DB): ParseResult =
    VizSQL.parseQuery(query, db) match {
      case Right(parsedQuery) => convert(parsedQuery)
      case Left(err)          => new ParseResult(new ParseError(err.msg, err.pos))
    }

  /**
   * Resolves the result columns and the referenced tables of a parsed query.
   *
   * Columns are resolved before tables; the first failure is reported as the error.
   *
   * @param query an already-parsed query (carries its select and its db)
   * @return a ParseResult with `select` set on success, `error` set otherwise
   */
  def convert(query: Query): ParseResult = {
    val select = query.select
    val db = query.db
    val columnsAndTables = select.getColumns(db).right.flatMap { columns =>
      select.getTables(db).right.map(tables => (columns, tables))
    }
    columnsAndTables match {
      case Left(err) =>
        new ParseResult(new ParseError(err.msg, err.pos))
      case Right((columns, tables)) =>
        val jsColumns = columns.map(Column.from).toJSArray
        val jsTables = tables.map { case (maybeSchema, table) => Table.from(table, maybeSchema) }.toJSArray
        new ParseResult(select = new Select(jsColumns, jsTables))
    }
  }
}
34 |
--------------------------------------------------------------------------------
/js/src/main/scala/com/criteo/vizatra/vizsql/js/common/ParseResult.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.js.common
2 |
3 | import scala.scalajs.js
4 | import scala.scalajs.js.UndefOr
5 | import scala.scalajs.js.annotation.ScalaJSDefined
6 | import scala.scalajs.js.JSConverters._
7 | import com.criteo.vizatra.vizsql
8 |
/**
 * Outcome of a parse, exposed to JavaScript as `{ error, select }`.
 * Fields left unset are `undefined` on the JS side.
 */
@ScalaJSDefined
class ParseResult(
  val error: UndefOr[ParseError] = js.undefined,
  val select: UndefOr[Select] = js.undefined
) extends js.Object

/** A parse or typing error: its message and the position it was reported at. */
@ScalaJSDefined
class ParseError(
  val msg: String,
  val pos: Int
) extends js.Object

/** A typed SELECT statement: its result columns and the tables it uses. */
@ScalaJSDefined
class Select(
  val columns: js.Array[Column],
  val tables: js.Array[Table]
) extends js.Object
17 |
/** JavaScript view of a column: its name, its rendered type and its nullability. */
@ScalaJSDefined
class Column(
  val name: String,
  val `type`: String,
  val nullable: Boolean
) extends js.Object

object Column {
  /** Converts a VizSQL column, rendering its type with `show`. */
  def from(column: vizsql.Column): Column = {
    val typ = column.typ
    new Column(column.name, typ.show, typ.nullable)
  }
}

/** JavaScript view of a table: its name, its columns and, when known, its schema. */
@ScalaJSDefined
class Table(
  val name: String,
  val columns: js.Array[Column],
  val schema: UndefOr[String] = js.undefined
) extends js.Object

object Table {
  /**
   * Converts a VizSQL table.
   *
   * @param table  table to convert
   * @param schema owning schema name, if any; `None` becomes `undefined` in JS
   */
  def from(table: vizsql.Table, schema: Option[String] = None): Table = {
    val jsColumns = table.columns.map(Column.from).toJSArray
    new Table(table.name, jsColumns, schema.orUndefined)
  }
}
31 |
--------------------------------------------------------------------------------
/js/src/main/scala/com/criteo/vizatra/vizsql/js/json/ColumnReader.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.js.json
2 |
3 | import com.criteo.vizatra.vizsql._
4 | import com.criteo.vizatra.vizsql.hive.TypeParser
5 |
6 | import scala.scalajs.js.{Dynamic, UndefOr}
7 |
object ColumnReader extends Reader[Column] {
  /** Fallback parser for type names that `Type.from` does not recognise; built on first use. */
  lazy val hiveTypeParser = new TypeParser

  /**
   * Reads a column definition with fields `name`, `type` and optional `nullable`
   * (`false` when absent).
   */
  override def apply(dyn: Dynamic): Column = {
    val columnName = dyn.name.asInstanceOf[String]
    val isNullable = dyn.nullable.asInstanceOf[UndefOr[Boolean]].getOrElse(false)
    val typeName = dyn.`type`.asInstanceOf[String]
    Column(columnName, parseType(typeName, isNullable))
  }

  /**
   * Resolves a type name, trying `Type.from(nullable)` first and the Hive type parser second.
   *
   * @throws IllegalArgumentException when the Hive parser rejects `input`
   */
  def parseType(input: String, nullable: Boolean): Type =
    Type.from(nullable).applyOrElse(input, (unknown: String) => parseHiveType(unknown))

  /** Parses `input` with the Hive type parser, turning a parse failure into an exception. */
  private def parseHiveType(input: String): Type = hiveTypeParser.parseType(input) match {
    case Right(parsed) => parsed
    case Left(err)     => throw new IllegalArgumentException(err)
  }
}
25 |
--------------------------------------------------------------------------------
/js/src/main/scala/com/criteo/vizatra/vizsql/js/json/DBReader.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.js.json
2 |
3 | import com.criteo.vizatra.vizsql.{DB, Schemas}
4 |
5 | import scala.scalajs.js.{Array, Dynamic, UndefOr}
6 |
object DBReader extends Reader[DB] {
  /**
   * Reads a database definition with fields `dialect` and `schemas`.
   *
   * A missing `schemas` field is treated as an empty list, the same way SchemaReader
   * and TableReader treat missing `tables` / `columns`, instead of failing on `undefined`.
   */
  override def apply(dyn: Dynamic): DB = {
    // Picked up implicitly by DB(...).
    implicit val dialect = DialectReader(dyn.dialect)
    val schemas = dyn.schemas
      .asInstanceOf[UndefOr[Array[Dynamic]]]
      .getOrElse(Array())
      .toList map SchemaReader.apply
    DB(schemas)
  }
}
16 |
--------------------------------------------------------------------------------
/js/src/main/scala/com/criteo/vizatra/vizsql/js/json/DialectReader.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.js.json
2 |
3 | import com.criteo.vizatra.vizsql._
4 | import com.criteo.vizatra.vizsql.hive.HiveDialect
5 |
6 | import scala.scalajs.js.Dynamic
7 |
object DialectReader extends Reader[Dialect] {
  /**
   * Reads a dialect name from a JavaScript value.
   *
   * A missing (`undefined`) or `null` value falls back to SQL99, the same fallback `from`
   * uses for unrecognised names, instead of failing on the String cast / `toLowerCase`.
   */
  override def apply(dyn: Dynamic): Dialect =
    if (scala.scalajs.js.isUndefined(dyn) || (dyn eq null)) sql99.dialect
    else from(dyn.asInstanceOf[String])

  /**
   * Maps a case-insensitive dialect name to a VizSQL dialect; unknown names map to SQL99.
   *
   * NOTE(review): a MySQL dialect exists in the shared sources, but "mysql" is not mapped
   * here and currently falls back to SQL99 — confirm whether that is intended.
   */
  def from(input: String): Dialect = input.toLowerCase match {
    case "vertica"    => vertica.dialect
    case "hsql"       => hsqldb.dialect
    case "h2"         => h2.dialect
    case "postgresql" => postgresql.dialect
    case "hive"       => HiveDialect(Map.empty) // TODO: handle hive UDFs
    case _            => sql99.dialect
  }
}
20 |
--------------------------------------------------------------------------------
/js/src/main/scala/com/criteo/vizatra/vizsql/js/json/Reader.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.js.json
2 |
3 | import scala.scalajs.js.{Dynamic, JSON}
4 |
/** Decodes a JavaScript value (or JSON text) into a VizSQL model object of type `T`. */
trait Reader[T] {
  /** Decodes an already-parsed JavaScript value. */
  def apply(dyn: Dynamic): T

  /** Parses `input` as JSON text, then decodes it with the `Dynamic` overload. */
  def apply(input: String): T = {
    val parsed: Dynamic = JSON.parse(input)
    apply(parsed)
  }
}
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/js/src/main/scala/com/criteo/vizatra/vizsql/js/json/SchemaReader.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.js.json
2 |
3 | import com.criteo.vizatra.vizsql.{Schema, Table}
4 | import scalajs.js.Array
5 |
6 | import scala.scalajs.js.UndefOr
7 | import scala.scalajs.js.Dynamic
8 |
object SchemaReader extends Reader[Schema] {
  /**
   * Reads a schema definition: a `name` and an optional `tables` array.
   * A missing `tables` field yields a schema without tables.
   */
  override def apply(dyn: Dynamic): Schema = {
    val name = dyn.name.asInstanceOf[String]
    val rawTables: Array[Dynamic] = dyn.tables.asInstanceOf[UndefOr[Array[Dynamic]]].getOrElse(Array())
    val tables: List[Table] = rawTables.toList.map(TableReader.apply)
    Schema(name, tables)
  }
}
19 |
--------------------------------------------------------------------------------
/js/src/main/scala/com/criteo/vizatra/vizsql/js/json/TableReader.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.js.json
2 |
3 | import com.criteo.vizatra.vizsql.{Column, Table}
4 |
5 | import scala.scalajs.js.{Dynamic, UndefOr, Array}
6 |
object TableReader extends Reader[Table] {
  /**
   * Reads a table definition: a `name` and an optional `columns` array.
   * A missing `columns` field yields a table without columns.
   */
  override def apply(dyn: Dynamic): Table = {
    val tableName = dyn.name.asInstanceOf[String]
    val rawColumns: Array[Dynamic] = dyn.columns.asInstanceOf[UndefOr[Array[Dynamic]]].getOrElse(Array())
    val columns: List[Column] = rawColumns.toList.map(ColumnReader.apply)
    Table(tableName, columns)
  }
}
17 |
--------------------------------------------------------------------------------
/js/src/test/scala/com/criteo/vizatra/vizsql/js/DatabaseSpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.js
2 |
3 | import com.criteo.vizatra.vizsql.DB
4 | import com.criteo.vizatra.vizsql._
5 | import org.scalatest.{FlatSpec, Matchers}
6 |
class DatabaseSpec extends FlatSpec with Matchers {
  implicit val dialect = sql99.dialect

  // Minimal SAKILA fixture with a single City table.
  private val cityTable = Table(
    "City",
    columns = List(
      Column("city_id", INTEGER(nullable = false)),
      Column("city", STRING(nullable = true)),
      Column("country_id", INTEGER(nullable = false)),
      Column("last_update", TIMESTAMP(nullable = false))
    )
  )

  val db = DB(schemas = List(Schema("sakila", tables = List(cityTable))))

  "Database class" should "be able to parse queries" in {
    val result = new Database(db).parse("SELECT city_id FROM city")
    val firstColumn = result.select.get.columns.head
    firstColumn.name shouldEqual "city_id"
  }
}
32 |
--------------------------------------------------------------------------------
/js/src/test/scala/com/criteo/vizatra/vizsql/js/QueryParserSpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.js
2 |
3 | import com.criteo.vizatra.vizsql._
4 | import org.scalatest.{FlatSpec, Matchers}
5 |
class QueryParserSpec extends FlatSpec with Matchers {
  implicit val dialect = sql99.dialect

  // Fixture tables of the "sakila" schema; the queries below join them on country_id.
  private val cityTable = Table(
    "City",
    columns = List(
      Column("city_id", INTEGER(nullable = false)),
      Column("city", STRING(nullable = true)),
      Column("country_id", INTEGER(nullable = false)),
      Column("last_update", TIMESTAMP(nullable = false))
    )
  )

  private val countryTable = Table(
    "Country",
    columns = List(
      Column("country_id", INTEGER(nullable = false)),
      Column("country", STRING(nullable = true)),
      Column("last_update", TIMESTAMP(nullable = false))
    )
  )

  val db = DB(schemas = List(Schema("sakila", tables = List(cityTable, countryTable))))

  "parse()" should "return a result" in {
    val parsed = QueryParser.parse(
      """
        |SELECT country, city
        |FROM city JOIN country ON city.country_id = country.country_id
        |WHERE city IN ?{availableCities}
""".stripMargin, db)
    parsed.error.isDefined shouldBe false
    parsed.select.isDefined shouldBe true
    parsed.select.get.columns.length shouldBe 2
  }

  it should "handle errors" in {
    // A lone "S" is rejected at offset 0 with "select expected".
    val parsed = QueryParser.parse(
      """S
""".stripMargin, db)
    parsed.error.isDefined shouldBe true
    parsed.select.isDefined shouldBe false
    val failure = parsed.error.get
    failure.pos shouldBe 0
    failure.msg shouldBe "select expected"
  }

  it should "identify invalid columns" in {
    val parsed = QueryParser.parse(
      """
        |SELECT country1, city
        |FROM city JOIN country ON city.country_id = country.country_id
        |WHERE city IN ?{availableCities}
""".stripMargin, db)
    parsed.error.get.msg shouldBe "column not found country1"
  }
}
65 |
--------------------------------------------------------------------------------
/js/src/test/scala/com/criteo/vizatra/vizsql/js/common/ParseResultSpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.js.common
2 |
3 | import org.scalatest.{FunSpec, Matchers}
4 | import com.criteo.vizatra.vizsql
5 | import com.criteo.vizatra.vizsql.INTEGER
6 |
7 | import scala.scalajs.js.JSON
class ParseResultSpec extends FunSpec with Matchers {

  // Nullable integer column shared by both conversions below.
  private def sampleColumn = vizsql.Column("col1", INTEGER(true))

  describe("Column") {
    describe("from()") {
      it("converts scala Column to JS object") {
        val json = JSON.stringify(Column.from(sampleColumn))
        json shouldEqual """{"name":"col1","type":"integer","nullable":true}"""
      }
    }
  }

  describe("Table") {
    describe("from()") {
      it("converts scala Table to JS object") {
        val source = vizsql.Table("table1", List(sampleColumn))
        val json = JSON.stringify(Table.from(source, Some("schema1")))
        json shouldEqual """{"name":"table1","columns":[{"name":"col1","type":"integer","nullable":true}],"schema":"schema1"}"""
      }
    }
  }
}
27 |
--------------------------------------------------------------------------------
/js/src/test/scala/com/criteo/vizatra/vizsql/js/json/ColumnReaderSpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.js.json
2 |
3 | import com.criteo.vizatra.vizsql.hive.{HiveArray, HiveMap, HiveStruct, TypeParser}
4 | import com.criteo.vizatra.vizsql.{BOOLEAN, Column, INTEGER, STRING}
5 | import org.scalatest.{FlatSpec, Matchers}
6 |
class ColumnReaderSpec extends FlatSpec with Matchers {
  "apply()" should "return a column with nullable" in {
    val res = ColumnReader.apply("""{"name":"col","type":"int4","nullable":true}""")
    res shouldEqual Column("col", INTEGER(true))
  }
  "apply()" should "return a column without nullable" in {
    // With "nullable" omitted the decoded column is expected to be non-nullable.
    val res = ColumnReader.apply("""{"name":"col","type":"int4"}""")
    res shouldEqual Column("col", INTEGER(false))
  }

  // The Hive type signatures must spell out their type parameters: a bare "map" or
  // "array" carries no key/element types, so it cannot describe the expected
  // HiveMap / HiveArray / HiveStruct values asserted below.
  "parseType()" should "parse map type" in {
    val res = ColumnReader.parseType("""map<string,int>""", true)
    res shouldEqual HiveMap(STRING(true), INTEGER(true))
  }
  "parseType()" should "parse array type" in {
    ColumnReader.parseType("array<int>", true) shouldEqual HiveArray(INTEGER(true))
  }
  "parseType()" should "parse struct type" in {
    ColumnReader.parseType("struct<a:struct<b:boolean>,c:struct<d:string>>", true) shouldEqual HiveStruct(List(
      Column("a", HiveStruct(List(
        Column("b", BOOLEAN(true))
      ))),
      Column("c", HiveStruct(List(
        Column("d", STRING(true))
      )))
    ))
  }
}
34 |
--------------------------------------------------------------------------------
/js/src/test/scala/com/criteo/vizatra/vizsql/js/json/DBReaderSpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.js.json
2 |
3 | import com.criteo.vizatra.vizsql._
4 | import org.scalatest.{FlatSpec, Matchers}
5 |
class DBReaderSpec extends FlatSpec with Matchers {
  // Decodes a complete DB description: dialect name plus nested schemas/tables/columns.
  "apply()" should "return a DB" in {
    val db = DBReader.apply(
      """
        |{
        | "dialect": "vertica",
        | "schemas": [
        | {
        | "name":"schema1",
        | "tables": [
        | {
        | "name":"table1",
        | "columns": [
        | {"name": "col1", "type": "int4"}
        | ]
        | }
        | ]
        | }
        | ]
        |}
      """.stripMargin)
    db.dialect shouldBe vertica.dialect
    db.schemas shouldEqual Schemas(List(Schema("schema1", List(Table("table1", List(Column("col1", INTEGER())))))))
  }
}
31 |
--------------------------------------------------------------------------------
/js/src/test/scala/com/criteo/vizatra/vizsql/js/json/DialectReaderSpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.js.json
2 |
3 | import com.criteo.vizatra.vizsql._
4 | import com.criteo.vizatra.vizsql.hive.HiveDialect
5 | import org.scalatest.{FlatSpec, Matchers}
6 |
class DialectReaderSpec extends FlatSpec with Matchers {
  // Dialect names are given in mixed case here ("VERTICA", "PostgreSQL", "Hive");
  // "any" is expected to resolve to the generic SQL-99 dialect.
  "from()" should "return a dialect" in {
    DialectReader.from("VERTICA") shouldBe vertica.dialect
    DialectReader.from("PostgreSQL") shouldBe postgresql.dialect
    DialectReader.from("any") shouldBe sql99.dialect
    DialectReader.from("hsql") shouldBe hsqldb.dialect
    DialectReader.from("h2") shouldBe h2.dialect
    DialectReader.from("Hive") shouldBe HiveDialect(Map.empty)
  }
}
17 |
--------------------------------------------------------------------------------
/js/src/test/scala/com/criteo/vizatra/vizsql/js/json/SchemaReaderSpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.js.json
2 |
3 | import com.criteo.vizatra.vizsql.{Column, INTEGER, Schema, Table}
4 | import org.scalatest.{FlatSpec, Matchers}
5 |
class SchemaReaderSpec extends FlatSpec with Matchers {
  "apply()" should "return a schema" in {
    val json =
      """
        |{
        | "name":"schema1",
        | "tables": [
        | {
        | "name":"table1",
        | "columns": [
        | {"name": "col1", "type": "int4"}
        | ]
        | }
        | ]
        |}
""".stripMargin
    val decoded = SchemaReader(json)
    val expected = Schema("schema1", List(Table("table1", List(Column("col1", INTEGER())))))
    decoded shouldEqual expected
  }
}
30 |
--------------------------------------------------------------------------------
/js/src/test/scala/com/criteo/vizatra/vizsql/js/json/TableReaderSpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.js.json
2 |
3 | import com.criteo.vizatra.vizsql.{Column, DECIMAL, INTEGER}
4 | import org.scalatest.{FlatSpec, Matchers}
5 |
class TableReaderSpec extends FlatSpec with Matchers {
  "apply()" should "return a table" in {
    // "int4" decodes to INTEGER and "float4" to DECIMAL.
    val json =
      """
        |{
        |"name": "table_1",
        |"columns": [
        | {
        | "name": "col1",
        | "type": "int4"
        | },
        | {
        | "name": "col2",
        | "type": "float4"
        | }
        |]
        |}""".stripMargin
    val table = TableReader(json)
    table.name shouldBe "table_1"
    table.columns shouldBe List(Column("col1", INTEGER()), Column("col2", DECIMAL()))
  }
}
27 |
--------------------------------------------------------------------------------
/jvm/src/test/scala/com/criteo/vizatra/vizsql/ExtractColumnsSpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
3 | import org.scalatest.prop.TableDrivenPropertyChecks
4 | import org.scalatest.{Matchers, EitherValues, PropSpec}
5 | import sql99._
6 |
class ExtractColumnsSpec extends PropSpec with Matchers with EitherValues {

  // Column layouts of the two sakila tables. They are declared first because both the
  // expectation table and the SAKILA schema below are built from them.
  private val cityColumns = List(
    Column("city_id", INTEGER(nullable = false)),
    Column("city", STRING(nullable = false)),
    Column("country_id", INTEGER(nullable = false)),
    Column("last_update", TIMESTAMP(nullable = false))
  )

  private val countryColumns = List(
    Column("country_id", INTEGER(nullable = false)),
    Column("country", STRING(nullable = false)),
    Column("last_update", TIMESTAMP(nullable = false))
  )

  val validSQL99SelectStatements = TableDrivenPropertyChecks.Table(
    ("SQL", "Expected Columns"),

    ("""SELECT 1""", List(Column("1", INTEGER(false)))),

    ("""SELECT 1 as one""", List(Column("one", INTEGER(false)))),

    // Star expansions, bare, qualified by alias, and qualified by schema.
    ("""SELECT * from City""", cityColumns),
    ("""SELECT x.* from City as x""", cityColumns),
    ("""SELECT x.* from sakila.City x""", cityColumns),

    ("""SELECT country_id, max(last_update) from City as x""", List(
      Column("country_id", INTEGER(nullable = false)),
      Column("MAX(last_update)", TIMESTAMP(nullable = false))
    )),

    ("""SELECT cities.country_id, max(cities.last_update) from City as cities""", List(
      Column("cities.country_id", INTEGER(nullable = false)),
      Column("MAX(cities.last_update)", TIMESTAMP(nullable = false))
    )),

    // A star over several FROM items lists their columns in FROM order.
    ("""SELECT * from Country, City""", countryColumns ++ cityColumns),

    ("""SELECT city_id from city""", List(
      Column("city_id", INTEGER(nullable = false))
    )),

    ("""SELECT City.country_id from Country, City""", List(
      Column("city.country_id", INTEGER(nullable = false))
    )),

    ("""select blah.* from (select *, 1 as one from Country) as blah;""",
      countryColumns :+ Column("one", INTEGER(false))),

    ("""select * from City as v join Country as p on v.country_id = p.country_id""",
      cityColumns ++ countryColumns),

    ("SELECT now() as now, MAX(last_update) + 3599 / 86400 AS last_update FROM City", List(
      Column("now", TIMESTAMP(nullable = false)),
      Column("last_update", TIMESTAMP(nullable = false))
    )),

    ("""select case 3 when 1 then 2 when 3 then 5 else 0 end AS foo""", List(
      Column("foo", INTEGER(nullable = false))
    )),

    ("""select case 3 when 1 then 2 when 3 then 5 end AS foo""", List(
      Column("foo", INTEGER(nullable = true))
    )),

    ("""select case city when 'a' then 2.5 when 'b' then 5.0 end AS foo from City""", List(
      Column("foo", DECIMAL(nullable = true))
    )),

    ("""select coalesce(case when city = 'a' then 1 end, 3) x from City""", List(
      Column("x", INTEGER(nullable = false))
    )),

    ("""select coalesce(case when city = 'a' then 1 end, case when country_id = 2 then 3 end) x from City""", List(
      Column("x", INTEGER(nullable = true))
    ))
  )

  // --

  val SAKILA = DB(schemas = List(
    Schema(
      "sakila",
      tables = List(
        Table("City", columns = cityColumns),
        Table("Country", columns = countryColumns)
      )
    )
  ))

  // --

  property("extract SQL-99 SELECT statements columns") {
    TableDrivenPropertyChecks.forAll(validSQL99SelectStatements) { (sql, expectedColumns) =>
      val query = VizSQL.parseQuery(sql, SAKILA)
        .fold(err => sys.error(s"Query doesn't parse: $err"), parsed => parsed)
      val actualColumns = query.columns
        .fold(err => sys.error(s"Invalid query: $err"), cols => cols)
      actualColumns should be (expectedColumns)
    }
  }

}
--------------------------------------------------------------------------------
/jvm/src/test/scala/com/criteo/vizatra/vizsql/ExtractPlaceholdersSpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
3 | import sql99._
4 | import org.scalatest.prop.TableDrivenPropertyChecks
5 | import org.scalatest.{Matchers, EitherValues, PropSpec}
6 |
class ExtractPlaceholdersSpec extends PropSpec with Matchers with EitherValues {

  // Each row pairs a query with the (name, type) of every placeholder it contains, in
  // order of appearance. Unnamed placeholders have a None name.
  val validSQL99SelectStatements = TableDrivenPropertyChecks.Table(
    ("SQL", "Expected Placeholders"),

    ("""SELECT * FROM city WHERE ?""", List(
      (None, BOOLEAN())
    )),

    ("""SELECT * FROM city WHERE ?condition""", List(
      (Some("condition"), BOOLEAN())
    )),

    // Explicitly typed placeholder. (A verbatim duplicate of this row was removed.)
    ("""SELECT ?today:timestamp""", List(
      (Some("today"), TIMESTAMP())
    )),

    ("""SELECT * FROM country WHERE country_id = ?""", List(
      (None, INTEGER())
    )),

    ("""select *, country like ? from country where country_id = ? and last_update < ?;""", List(
      (None, STRING(nullable = true)),
      (None, INTEGER()),
      (None, TIMESTAMP())
    )),

    ("""SELECT ?a:DECIMAL = ?b:DECIMAL""", List(
      (Some("a"), DECIMAL()),
      (Some("b"), DECIMAL())
    )),

    ("""SELECT max(?today:timestamp)""", List(
      (Some("today"), TIMESTAMP())
    )),

    ("""SELECT ? = 3""", List(
      (None, INTEGER())
    )),

    ("""SELECT 3 = ?""", List(
      (None, INTEGER())
    )),

    ("""SELECT ?:varchar = ?""", List(
      (None, STRING()), (None, STRING())
    )),

    ("""SELECT ? = ?:varchar""", List(
      (None, STRING()), (None, STRING())
    )),

    (
"""
select
*, ?today:timestamp as TODAY, ?ratio * city_id
from city
where
1 + ?a > 10 + ?b:integer AND
?
OR ?blah like city;
""", List(
      (Some("today"), TIMESTAMP()),
      (Some("ratio"), DECIMAL()),
      (Some("a"), DECIMAL()),
      (Some("b"), INTEGER()),
      (None, BOOLEAN()),
      (Some("blah"), STRING(nullable = true))
    )),

    ("""select 1 between ? and ?""", List(
      (None, INTEGER()), (None, INTEGER())
    )),

    ("""select 1 between 0 and ?""", List(
      (None, INTEGER())
    )),

    ("""select 1 between ?[)""", List(
      (None, RANGE(INTEGER()))
    )),

    ("""select 1 between ?[:varchar)""", List(
      (None, RANGE(STRING()))
    )),

    ("""select ? between 0 and 5""", List(
      (None, INTEGER())
    )),

    ("""select ? between 'a' and ?""", List(
      (None, STRING()), (None, STRING())
    )),

    ("""select ? between ? and ?:timestamp""", List(
      (None, TIMESTAMP()), (None, TIMESTAMP()), (None, TIMESTAMP())
    )),

    ("""select * from city where last_update between ? and ?""", List(
      (None, TIMESTAMP()), (None, TIMESTAMP())
    )),

    ("""select ? between ? and (? + ?)""", List(
      (None, DECIMAL()), (None, DECIMAL()), (None, DECIMAL()), (None, DECIMAL())
    )),

    ("""select 1 in (?)""", List(
      (None, INTEGER())
    )),

    ("""select ? in (1,2,3,4)""", List(
      (None, INTEGER())
    )),

    ("""select ? in ('a','b',?,'d')""", List(
      (None, STRING()), (None, STRING())
    )),

    ("""select country_id in (?,?,?) from city""", List(
      (None, INTEGER()), (None, INTEGER()), (None, INTEGER())
    )),

    ("""select 1 in ?{}""", List(
      (None, SET(INTEGER()))
    )),

    (
"""
select country in ?{authorizedCountries}
from city join country on city.country_id = country.country_id
""", List(
      (Some("authorizedCountries"), SET(STRING(nullable = true)))
    )),

    (
"""
select city
from city join country on city.country_id = ?
""", List(
      (None, INTEGER())
    )),

    (
"""
select
case city_id
when ? then 1
when ? then 2
else 0
end
from city
""", List(
      (None, INTEGER()), (None, INTEGER())
    ))
  )

  // --

  val SAKILA = DB(schemas = List(
    Schema(
      "sakila",
      tables = List(
        Table(
          "City",
          columns = List(
            Column("city_id", INTEGER(nullable = false)),
            Column("city", STRING(nullable = true)),
            Column("country_id", INTEGER(nullable = false)),
            Column("last_update", TIMESTAMP(nullable = false))
          )
        ),
        Table(
          "Country",
          columns = List(
            Column("country_id", INTEGER(nullable = false)),
            Column("country", STRING(nullable = true)),
            Column("last_update", TIMESTAMP(nullable = false))
          )
        )
      )
    )
  ))

  // --

  property("compute SQL-99 SELECT statements placeholders") {
    TableDrivenPropertyChecks.forAll(validSQL99SelectStatements) {
      case (sql, expectedPlaceholders) =>
        VizSQL.parseQuery(sql, SAKILA)
          .fold(e => sys.error(s"Query doesn't parse: $e"), identity)
          .placeholders
          .fold(e => sys.error(s"Invalid query: $e"), _.namesAndTypes) should be (expectedPlaceholders)
    }
  }

}
--------------------------------------------------------------------------------
/jvm/src/test/scala/com/criteo/vizatra/vizsql/FormatSQLSpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
3 | import org.scalatest.prop.TableDrivenPropertyChecks
4 | import org.scalatest.{Matchers, EitherValues, PropSpec}
5 |
class FormatSQLSpec extends PropSpec with Matchers with EitherValues {

  // Pairs of (input SQL, expected pretty-printed SQL). The expected side is compared
  // after stripMargin + trim, so surrounding blank lines are irrelevant.
  val examples = TableDrivenPropertyChecks.Table(
    ("Input SQL", "Formatted SQL"),

    (
      """SELECT 1""",
      """
        |SELECT
        | 1
""".stripMargin
    ),

    (
      """select district, sum(population) from city""",
      """
        |SELECT
        | district,
        | SUM(population)
        |FROM
        | city
""".stripMargin
    ),

    (
      """select * from City as v join Country as p on v.country_id = p.country_id where city.name like ? AND population > 10000""",
      """
        |SELECT
        | *
        |FROM
        | city AS v
        | JOIN country AS p
        | ON v.country_id = p.country_id
        |WHERE
        | city.name LIKE ?
        | AND population > 10000
""".stripMargin
    ),

    (
"""
SELECT
d.device_name as device_name,
z.description as zone_name,
z.technology as technology,
z.affiliate_name as affiliate_name,
t.country_name as affiliate_country,
t.country_level_1_name as affiliate_region,
z.network_name as network_name,
f.time_id as hour,
CAST(date_trunc('day', f.time_id) as DATE) as day,
CAST(date_trunc('week', f.time_id) as DATE) as week,
CAST(date_trunc('month', f.time_id) as DATE) as month,
CAST(date_trunc('quarter', f.time_id) as DATE) as quarter,

SUM(displays) as displays,
SUM(clicks) as clicks,
SUM(sales) as sales,
SUM(order_value_euro * r.rate) as order_value,
SUM(revenue_euro * r.rate) as revenue,
SUM(tac_euro * r.rate) as tac,
SUM((revenue_euro - tac_euro) * r.rate) as revenue_ex_tac,
SUM(marketplace_revenue_euro * r.rate) as marketplace_revenue,
SUM((marketplace_revenue_euro - tac_euro) * r.rate) as marketplace_revenue_ex_tac,
ZEROIFNULL(SUM(clicks)/NULLIF(SUM(displays), 0.0)) as ctr,
ZEROIFNULL(SUM(sales)/NULLIF(SUM(clicks), 0.0)) as cr,
ZEROIFNULL(SUM(revenue_euro - tac_euro)/NULLIF(SUM(revenue_euro), 0.0)) as margin,
ZEROIFNULL(SUM(marketplace_revenue_euro - tac_euro)/NULLIF(SUM(marketplace_revenue_euro), 0.0)) as marketplace_margin,
ZEROIFNULL(SUM(revenue_euro * r.rate)/NULLIF(SUM(clicks), 0.0)) as cpc,
ZEROIFNULL(SUM(tac_euro * r.rate)/NULLIF(SUM(displays), 0.0)) * 1000 as cpm
FROM
wopr.fact_zone_device_stats_hourly f
JOIN wopr.dim_zone z
ON z.zone_id = f.zone_id
JOIN wopr.dim_device d
ON d.device_id = f.device_id
JOIN wopr.dim_country t
ON t.country_id = f.affiliate_country_id
JOIN wopr.fact_euro_rates_hourly r
ON r.currency_id = ?currency_id AND f.time_id = r.time_id

WHERE
CAST(f.time_id AS DATE) between ?[day)
AND t.country_code IN ?{publisher_countries}
AND d.device_name IN ?{device_name}
AND z.description IN ?{zone_name}
AND z.technology IN ?{technology}
AND z.affiliate_name IN ?{affiliate_name}
AND t.country_name IN ?{affiliate_country}
AND t.country_level_1_name IN ?{affiliate_region}
AND z.network_name IN ?{network_name}

GROUP BY ROLLUP((
d.device_name,
z.description,
z.technology,
z.affiliate_name,
t.country_name,
z.network_name,
t.country_level_1_name,
f.time_id,
CAST(date_trunc('day', f.time_id) as DATE),
CAST(date_trunc('week', f.time_id) as DATE),
CAST(date_trunc('month', f.time_id) as DATE),
CAST(date_trunc('quarter', f.time_id) as DATE)
))
HAVING SUM(clicks) > 0
""",
      """
        |SELECT
        | d.device_name AS device_name,
        | z.description AS zone_name,
        | z.technology AS technology,
        | z.affiliate_name AS affiliate_name,
        | t.country_name AS affiliate_country,
        | t.country_level_1_name AS affiliate_region,
        | z.network_name AS network_name,
        | f.time_id AS hour,
        | CAST(DATE_TRUNC('day', f.time_id) AS DATE) AS day,
        | CAST(DATE_TRUNC('week', f.time_id) AS DATE) AS week,
        | CAST(DATE_TRUNC('month', f.time_id) AS DATE) AS month,
        | CAST(DATE_TRUNC('quarter', f.time_id) AS DATE) AS quarter,
        | SUM(displays) AS displays,
        | SUM(clicks) AS clicks,
        | SUM(sales) AS sales,
        | SUM(order_value_euro * r.rate) AS order_value,
        | SUM(revenue_euro * r.rate) AS revenue,
        | SUM(tac_euro * r.rate) AS tac,
        | SUM((revenue_euro - tac_euro) * r.rate) AS revenue_ex_tac,
        | SUM(marketplace_revenue_euro * r.rate) AS marketplace_revenue,
        | SUM((marketplace_revenue_euro - tac_euro) * r.rate) AS marketplace_revenue_ex_tac,
        | ZEROIFNULL(SUM(clicks) / NULLIF(SUM(displays), 0.0)) AS ctr,
        | ZEROIFNULL(SUM(sales) / NULLIF(SUM(clicks), 0.0)) AS cr,
        | ZEROIFNULL(SUM(revenue_euro - tac_euro) / NULLIF(SUM(revenue_euro), 0.0)) AS margin,
        | ZEROIFNULL(SUM(marketplace_revenue_euro - tac_euro) / NULLIF(SUM(marketplace_revenue_euro), 0.0)) AS marketplace_margin,
        | ZEROIFNULL(SUM(revenue_euro * r.rate) / NULLIF(SUM(clicks), 0.0)) AS cpc,
        | ZEROIFNULL(SUM(tac_euro * r.rate) / NULLIF(SUM(displays), 0.0)) * 1000 AS cpm
        |FROM
        | wopr.fact_zone_device_stats_hourly AS f
        | JOIN wopr.dim_zone AS z
        | ON z.zone_id = f.zone_id
        | JOIN wopr.dim_device AS d
        | ON d.device_id = f.device_id
        | JOIN wopr.dim_country AS t
        | ON t.country_id = f.affiliate_country_id
        | JOIN wopr.fact_euro_rates_hourly AS r
        | ON r.currency_id = ?currency_id
        | AND f.time_id = r.time_id
        |WHERE
        | CAST(f.time_id AS DATE) BETWEEN ?day
        | AND t.country_code IN ?publisher_countries
        | AND d.device_name IN ?device_name
        | AND z.description IN ?zone_name
        | AND z.technology IN ?technology
        | AND z.affiliate_name IN ?affiliate_name
        | AND t.country_name IN ?affiliate_country
        | AND t.country_level_1_name IN ?affiliate_region
        | AND z.network_name IN ?network_name
        |GROUP BY
        | ROLLUP(
        | (
        | d.device_name,
        | z.description,
        | z.technology,
        | z.affiliate_name,
        | t.country_name,
        | z.network_name,
        | t.country_level_1_name,
        | f.time_id,
        | CAST(DATE_TRUNC('day', f.time_id) AS DATE),
        | CAST(DATE_TRUNC('week', f.time_id) AS DATE),
        | CAST(DATE_TRUNC('month', f.time_id) AS DATE),
        | CAST(DATE_TRUNC('quarter', f.time_id) AS DATE)
        | )
        | )
        |HAVING
        | SUM(clicks) > 0
""".stripMargin
    ),

    (
      """select case when a =b then 1 when b <> 2 then 2 else 0 end""",
      """
        |SELECT
        | CASE
        | WHEN a = b THEN 1
        | WHEN b <> 2 THEN 2
        | ELSE 0
        | END
""".stripMargin
    ),

    (
      """select 1,2 union all select 3,4 union all select 5,6""",
      """
        |SELECT
        | 1,
        | 2
        |UNION ALL
        |SELECT
        | 3,
        | 4
        |UNION ALL
        |SELECT
        | 5,
        | 6
""".stripMargin
    ),

    (
      """select count(distinct woot)""",
      """
        |SELECT
        | COUNT(DISTINCT woot)
""".stripMargin
    ),

    (
      """SELECT 1 LIMIT 10""",
      """
        |SELECT
        | 1
        |LIMIT 10
""".stripMargin
    )
  )

  // --

  property("print SQL") {
    TableDrivenPropertyChecks.forAll(examples) { (sql, expectedSQL) =>
      val parser = new SQL99Parser
      // On a parse failure, abort with the parser's rendering of the error against the input.
      val statement = parser.parseStatement(sql).fold(
        err => sys.error(s"\n\n${err.toString(sql)}\n"),
        parsed => parsed
      )
      statement.toSQL should be (expectedSQL.stripMargin.trim)
    }
  }

}
--------------------------------------------------------------------------------
/jvm/src/test/scala/com/criteo/vizatra/vizsql/OlapSpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
3 | import org.scalatest.prop.TableDrivenPropertyChecks
4 | import org.scalatest.{Matchers, EitherValues, FlatSpec}
5 |
6 | class OlapSpec extends FlatSpec with Matchers with EitherValues {
7 |
8 | implicit val VERTICA = new Dialect {
9 |
10 | def parser = new SQL99Parser
11 |
12 | val functions = SQLFunction.standard orElse {
13 | case "date_trunc" => new SQLFunction2 {
14 | def result = { case ((_, _), (_, t)) => Right(TIMESTAMP(nullable = t.nullable)) }
15 | }
16 | case "zeroifnull" => new SQLFunction1 {
17 | def result = { case (_, t) => Right(t.withNullable(false)) }
18 | }
19 | case "nullifzero" => new SQLFunction1 {
20 | def result = { case (_, t) => Right(t.withNullable(true)) }
21 | }
22 | }: PartialFunction[String,SQLFunction]
23 | }
24 |
25 | val BIDATA = DB(schemas = List(
26 | Schema(
27 | "wopr",
28 | tables = List(
29 | Table(
30 | "fact_zone_device_stats_hourly",
31 | columns = List(
32 | Column("time_id", TIMESTAMP(nullable = false)),
33 | Column("zone_id", INTEGER(nullable = false)),
34 | Column("device_id", INTEGER(nullable = false)),
35 | Column("affiliate_country_id", INTEGER(nullable = false)),
36 | Column("displays", INTEGER(nullable = true)),
37 | Column("clicks", INTEGER(nullable = true)),
38 | Column("revenue_euro", DECIMAL(nullable = true)),
39 | Column("order_value_euro", DECIMAL(nullable = true)),
40 | Column("sales", DECIMAL(nullable = true)),
41 | Column("marketplace_revenue_euro", DECIMAL(nullable = true)),
42 | Column("tac_euro", DECIMAL(nullable = true)),
43 | Column("network_id", INTEGER(nullable = false))
44 | )
45 | ),
46 | Table(
47 | "dim_zone",
48 | columns = List(
49 | Column("zone_id", INTEGER(nullable = false)),
50 | Column("description", STRING(nullable = true)),
51 | Column("network_name", STRING(nullable = true)),
52 | Column("affiliate_name", STRING(nullable = true)),
53 | Column("technology", STRING(nullable = true))
54 | )
55 | ),
56 | Table(
57 | "dim_device",
58 | columns = List(
59 | Column("device_id", INTEGER(nullable = false)),
60 | Column("device_name", STRING(nullable = true))
61 | )
62 | ),
63 | Table(
64 | "dim_country",
65 | columns = List(
66 | Column("country_id", INTEGER(nullable = false)),
67 | Column("country_code", STRING(nullable = true)),
68 | Column("country_name", STRING(nullable = true)),
69 | Column("country_level_1_name", STRING(nullable = true))
70 | )
71 | ),
72 | Table(
73 | "fact_euro_rates_hourly",
74 | columns = List(
75 | Column("time_id", TIMESTAMP(nullable = false)),
76 | Column("currency_id", INTEGER(nullable = false)),
77 | Column("rate", DECIMAL(nullable = true))
78 | )
79 | ),
80 | Table(
81 | "fact_portfolio",
82 | columns = List(
83 | Column("time_id", TIMESTAMP(nullable = false)),
84 | Column("affiliate_id", INTEGER(nullable = false)),
85 | Column("network_id", INTEGER(nullable = false)),
86 | Column("zone_id", INTEGER(nullable = false)),
87 | Column("device_id", INTEGER(nullable = false)),
88 | Column("affiliate_country_id", INTEGER(nullable = false)),
89 | Column("affiliate_region_name", STRING(nullable = false)),
90 | Column("displays", INTEGER(nullable = true)),
91 | Column("clicks", INTEGER(nullable = true)),
92 | Column("revenue_euro", DECIMAL(nullable = true)),
93 | Column("order_value_euro", DECIMAL(nullable = true)),
94 | Column("sales", DECIMAL(nullable = true)),
95 | Column("marketplace_revenue_euro", DECIMAL(nullable = true)),
96 | Column("tac_euro", DECIMAL(nullable = true))
97 | )
98 | )
99 | )
100 | )
101 | ))
102 |
103 | val QUERY = // OLAP fixture over BIDATA: 7 plain dimensions, 5 date_trunc time dimensions, 15 SUM-based metrics and ?-placeholders for parameters. NOTE(review): t's ON clause references t2 before t2 is joined; the rewrite expectations below reproduce that order, so keep it.
104 | """
105 | SELECT
106 | d.device_name as device_name,
107 | z.description as zone_name,
108 | z.technology as technology,
109 | z.affiliate_name as affiliate_name,
110 | t.country_name as affiliate_country,
111 | t.country_level_1_name as affiliate_region,
112 | z.network_name as network_name,
113 | date_trunc('hour', f.time_id) as hour,
114 | date_trunc('day', f.time_id) as day,
115 | date_trunc('week', f.time_id) as week,
116 | date_trunc('month', f.time_id) as month,
117 | date_trunc('quarter', f.time_id) as quarter,
118 |
119 | SUM(displays) as displays,
120 | SUM(clicks) as clicks,
121 | SUM(sales) as sales,
122 | SUM(order_value_euro * r.rate) as order_value,
123 | SUM(revenue_euro * r.rate) as revenue,
124 | SUM(tac_euro * r.rate) as tac,
125 | SUM((revenue_euro - tac_euro) * r.rate) as revenue_ex_tac,
126 | SUM(marketplace_revenue_euro * r.rate) as marketplace_revenue,
127 | SUM((marketplace_revenue_euro - tac_euro) * r.rate) as marketplace_revenue_ex_tac,
128 | ZEROIFNULL(SUM(clicks)/NULLIFZERO(SUM(displays))) as ctr,
129 | ZEROIFNULL(SUM(sales)/NULLIFZERO(SUM(clicks))) as cr,
130 | ZEROIFNULL(SUM(revenue_euro - tac_euro)/NULLIFZERO(SUM(revenue_euro))) as margin,
131 | ZEROIFNULL(SUM(marketplace_revenue_euro - tac_euro)/NULLIFZERO(SUM(marketplace_revenue_euro))) as marketplace_margin,
132 | ZEROIFNULL(SUM(revenue_euro * r.rate)/NULLIFZERO(SUM(clicks))) as cpc,
133 | ZEROIFNULL(SUM(tac_euro * r.rate)/NULLIFZERO(SUM(displays))) * 1000 as cpm
134 | FROM
135 | wopr.fact_zone_device_stats_hourly f
136 | JOIN wopr.dim_zone z
137 | ON z.zone_id = f.zone_id
138 | JOIN wopr.dim_device d
139 | ON d.device_id = f.device_id
140 | JOIN wopr.dim_country t
141 | ON t.country_name = t2.country_name
142 | JOIN wopr.fact_euro_rates_hourly r
143 | ON r.currency_id = ?currency_id AND f.time_id = r.time_id
144 | JOIN wopr.dim_country t2
145 | ON t2.country_id = f.affiliate_country_id
146 |
147 | WHERE
148 | CAST(f.time_id AS DATE) between ?[date_range)
149 | AND t.country_code IN ?{publisher_countries}
150 | AND d.device_name IN ?{device_name}
151 | AND z.description IN ?{zone_name}
152 | AND z.technology IN ?{technology}
153 | AND z.affiliate_name IN ?{affiliate_name}
154 | AND t.country_name IN ?{affiliate_country}
155 | AND t.country_level_1_name IN ?{affiliate_region}
156 | AND z.network_name IN ?{network_name}
157 |
158 | GROUP BY ROLLUP(
159 | (
160 | d.device_name,
161 | z.description,
162 | z.technology,
163 | z.affiliate_name,
164 | t.country_name,
165 | z.network_name,
166 | t.country_level_1_name,
167 | date_trunc('hour', f.time_id),
168 | date_trunc('day', f.time_id),
169 | date_trunc('week', f.time_id),
170 | date_trunc('month', f.time_id),
171 | date_trunc('quarter', f.time_id)
172 | )
173 | )
174 |
175 | ORDER BY
176 | date_trunc('hour', f.time_id),
177 | date_trunc('day', f.time_id),
178 | date_trunc('week', f.time_id),
179 | date_trunc('month', f.time_id),
180 | date_trunc('quarter', f.time_id)
181 | """
182 |
183 | // -- Actual tests
184 |
185 | val olapQuery = VizSQL.parseOlapQuery(QUERY, BIDATA).left.map(_.toString(QUERY)) // parsed once and shared by the tests below; a parse error is kept as a String rendered against QUERY
186 |
187 | "An OLAP query" should "recognize time dimensions" in { // the date_trunc projections of QUERY, listed in SELECT order
188 | val expectedTimeDimensions = List(
189 | "hour",
190 | "day",
191 | "week",
192 | "month",
193 | "quarter")
194 | olapQuery.right.flatMap(query => query.getTimeDimensions) shouldBe Right(expectedTimeDimensions)
195 | }
196 |
197 | it should "recognize other dimensions" in { // the plain column projections of QUERY, listed in SELECT order
198 | val expectedDimensions = List(
199 | "device_name",
200 | "zone_name",
201 | "technology",
202 | "affiliate_name",
203 | "affiliate_country",
204 | "affiliate_region",
205 | "network_name")
206 | olapQuery.right.flatMap(query => query.getDimensions) shouldBe Right(expectedDimensions)
207 | }
208 |
209 | it should "recognize metrics" in { // the SUM-based projections of QUERY, listed in SELECT order
210 | val expectedMetrics = List(
211 | "displays",
212 | "clicks",
213 | "sales",
214 | "order_value",
215 | "revenue",
216 | "tac",
217 | "revenue_ex_tac",
218 | "marketplace_revenue",
219 | "marketplace_revenue_ex_tac",
220 | "ctr",
221 | "cr",
222 | "margin",
223 | "marketplace_margin",
224 | "cpc",
225 | "cpm")
226 | olapQuery.right.flatMap(query => query.getMetrics) shouldBe Right(expectedMetrics)
227 | }
228 |
229 | it should "recognize input parameters" in { // placeholder names in order of appearance in the QUERY text
230 | val expectedParameters = List(
231 | "currency_id",
232 | "date_range",
233 | "publisher_countries",
234 | "device_name",
235 | "zone_name",
236 | "technology",
237 | "affiliate_name",
238 | "affiliate_country",
239 | "affiliate_region",
240 | "network_name")
241 | olapQuery.right.flatMap(query => query.getParameters) shouldBe Right(expectedParameters)
242 | }
243 |
244 | it should "retrieve the projection expression for dimensions/metrics" in { // getProjection returns the whole SELECT item, alias included
245 | val examples = TableDrivenPropertyChecks.Table( // toSQL normalises keywords: QUERY's lowercase `as` comes back as `AS`
246 | ("Dimension or metric", "Expression"),
247 |
248 | ("revenue", """SUM(revenue_euro * r.rate) AS revenue"""),
249 | ("affiliate_region", """t.country_level_1_name AS affiliate_region""")
250 | )
251 |
252 | TableDrivenPropertyChecks.forAll(examples) {
253 | case (dimensionOrMetric, expression) =>
254 | olapQuery.right.flatMap(_.getProjection(dimensionOrMetric).right.map(_.toSQL)) should be (Right(
255 | expression
256 | ))
257 | }
258 | }
259 |
260 | it should "retrieve the tables needed for dimensions/metrics" in { // tables are reported by alias; tables only reachable through ON clauses (t -> t2) are not included
261 | val examples = TableDrivenPropertyChecks.Table(
262 | ("Dimension or metric", "Used tables"),
263 |
264 | ("zone_name", List("z")),
265 | ("marketplace_margin", List("f")),
266 | ("cpc", List("f", "r")),
267 | ("affiliate_region", List("t"))
268 | )
269 |
270 | TableDrivenPropertyChecks.forAll(examples) {
271 | case (dimensionOrMetric, tables) =>
272 | olapQuery.right.flatMap(q =>
273 | q.getProjection(dimensionOrMetric).right.map(_.expression).right.flatMap(OlapQuery.tablesFor(q.query.select, q.query.db, _))
274 | ) should be (Right(tables))
275 | }
276 | }
277 |
278 | it should "rewrite FROM relations based on a given table set" in { // the join chain is pruned to the requested aliases; keeping t's JOIN also keeps t2, which its ON clause references
279 | val examples = TableDrivenPropertyChecks.Table(
280 | ("Tables", "FROM clause"),
281 |
282 | (Set.empty[String], ""),
283 |
284 | (Set("f"),
285 | """wopr.fact_zone_device_stats_hourly AS f"""
286 | ),
287 |
288 | (Set("t"),
289 | """wopr.dim_country AS t"""
290 | ),
291 |
292 | (Set("f", "t"),
293 | """
294 | |wopr.fact_zone_device_stats_hourly AS f
295 | |JOIN wopr.dim_country AS t
296 | | ON t.country_name = t2.country_name
297 | |JOIN wopr.dim_country AS t2
298 | | ON t2.country_id = f.affiliate_country_id
299 | """
300 | ),
301 |
302 | (Set("f", "z", "d"),
303 | """
304 | |wopr.fact_zone_device_stats_hourly AS f
305 | |JOIN wopr.dim_zone AS z
306 | | ON z.zone_id = f.zone_id
307 | |JOIN wopr.dim_device AS d
308 | | ON d.device_id = f.device_id
309 | """
310 | ),
311 |
312 | (Set("f", "r"),
313 | """
314 | |wopr.fact_zone_device_stats_hourly AS f
315 | |JOIN wopr.fact_euro_rates_hourly AS r
316 | | ON r.currency_id = ?currency_id
317 | | AND f.time_id = r.time_id
318 | """
319 | )
320 | )
321 |
322 | TableDrivenPropertyChecks.forAll(examples) { // placeholders are not substituted at this stage (?currency_id stays); an empty set yields no relation at all
323 | case (tables, expectedSQL) =>
324 | olapQuery.right.map(q => OlapQuery.rewriteRelations(q.query.select, BIDATA, tables)
325 | .map(_.toSQL).mkString(",\n")) should be (Right(expectedSQL.stripMargin.trim))
326 | }
327 | }
328 |
329 | it should "rewrite WHERE conditions based on available parameters" in { // keeps only conjuncts whose placeholder is in the set, in QUERY's WHERE order; ?[x) and ?{x} render as ?x
330 | val examples = TableDrivenPropertyChecks.Table(
331 | ("Parameters", "WHERE clause"),
332 |
333 | (Set.empty[String], ""),
334 |
335 | (Set("date_range", "device_name"),
336 | """
337 | |CAST(f.time_id AS DATE) BETWEEN ?date_range
338 | |AND d.device_name IN ?device_name
339 | """
340 | ),
341 |
342 | (Set("technology", "affiliate_region", "publisher_countries", "network_name"),
343 | """
344 | |t.country_code IN ?publisher_countries
345 | |AND z.technology IN ?technology
346 | |AND t.country_level_1_name IN ?affiliate_region
347 | |AND z.network_name IN ?network_name
348 | """
349 | )
350 | )
351 |
352 | TableDrivenPropertyChecks.forAll(examples) { // when no conjunct survives, the missing WHERE renders as ""
353 | case (parameters, expectedSQL) =>
354 | olapQuery.right.map(q =>
355 | OlapQuery.rewriteWhereCondition(q.query.select, parameters).map(_.toSQL).getOrElse("")
356 | ) should be (Right(expectedSQL.stripMargin.trim))
357 | }
358 | }
359 |
360 | it should "rewrite GROUP BY expressions for given dimensions" in { // the single ROLLUP group keeps only the requested expressions, in QUERY's GROUP BY order
361 | val examples = TableDrivenPropertyChecks.Table(
362 | ("Dimensions", "GROUP BY clause"),
363 |
364 | (Set.empty[String], ""),
365 |
366 | (Set("hour"),
367 | """
368 | |ROLLUP(
369 | | (
370 | | DATE_TRUNC('hour', f.time_id)
371 | | )
372 | |)
373 | """
374 | ),
375 |
376 | (Set("affiliate_name", "affiliate_country"),
377 | """
378 | |ROLLUP(
379 | | (
380 | | z.affiliate_name,
381 | | t.country_name
382 | | )
383 | |)
384 | """
385 | ),
386 |
387 | (Set("affiliate_country", "affiliate_name", "network_name", "zone_name", "day"),
388 | """
389 | |ROLLUP(
390 | | (
391 | | z.description,
392 | | z.affiliate_name,
393 | | t.country_name,
394 | | z.network_name,
395 | | DATE_TRUNC('day', f.time_id)
396 | | )
397 | |)
398 | """
399 | )
400 |
401 | )
402 |
403 | TableDrivenPropertyChecks.forAll(examples) { // dimension names are first resolved to their projection expressions; any lookup error short-circuits the fold
404 | case (dimensions, expectedSQL) =>
405 | olapQuery.right.flatMap(q =>
406 | dimensions.foldRight(Right(Nil):Either[Err,List[Expression]]) {
407 | (d, acc) => for(a <- acc.right; b <- q.getProjection(d).right.map(_.expression).right) yield b :: a
408 | }.right.map(expressions => OlapQuery.rewriteGroupBy(q.query.select, expressions).map(_.toSQL).mkString(",\n"))
409 | ) should be (Right(expectedSQL.stripMargin.trim))
410 | }
411 | }
412 |
413 | it should "compute the right query" in { // end to end: projection + selection -> final SQL, with placeholders replaced by literal values
414 | val examples = TableDrivenPropertyChecks.Table(
415 | ("Projection", "Selection", "SQL"),
416 |
417 | (
418 | OlapProjection(Set("affiliate_country"), Set("displays", "clicks")), // dimensions, then metrics
419 | OlapSelection(Map.empty, Map.empty), // empty selection: the expected SQL has no WHERE clause
420 | """
421 | |SELECT
422 | | t.country_name AS affiliate_country,
423 | | SUM(displays) AS displays,
424 | | SUM(clicks) AS clicks
425 | |FROM
426 | | wopr.fact_zone_device_stats_hourly AS f
427 | | JOIN wopr.dim_country AS t
428 | | ON t.country_name = t2.country_name
429 | | JOIN wopr.dim_country AS t2
430 | | ON t2.country_id = f.affiliate_country_id
431 | |GROUP BY
432 | | ROLLUP(
433 | | (
434 | | t.country_name
435 | | )
436 | | )
437 | """
438 | ),
439 |
440 | (
441 | OlapProjection(Set("affiliate_country", "day"), Set("displays", "clicks", "marketplace_revenue")),
442 | OlapSelection( // NOTE(review): first Map looks like raw parameter values, second like values keyed by dimension name; confirm against OlapSelection
443 | Map(
444 | "currency_id" -> 1,
445 | "publisher_countries" -> List("FR", "UK", "US", "DE"),
446 | "date_range" -> ("2015-09-10","2015-09-11")
447 | ),
448 | Map(
449 | "affiliate_country" -> List("FRANCE", "GERMANY"),
450 | "device_name" -> List("IPAD", "IPHONE")
451 | )
452 | ),
453 | """
454 | |SELECT
455 | | t.country_name AS affiliate_country,
456 | | DATE_TRUNC('day', f.time_id) AS day,
457 | | SUM(displays) AS displays,
458 | | SUM(clicks) AS clicks,
459 | | SUM(marketplace_revenue_euro * r.rate) AS marketplace_revenue
460 | |FROM
461 | | wopr.fact_zone_device_stats_hourly AS f
462 | | JOIN wopr.dim_device AS d
463 | | ON d.device_id = f.device_id
464 | | JOIN wopr.dim_country AS t
465 | | ON t.country_name = t2.country_name
466 | | JOIN wopr.fact_euro_rates_hourly AS r
467 | | ON r.currency_id = 1
468 | | AND f.time_id = r.time_id
469 | | JOIN wopr.dim_country AS t2
470 | | ON t2.country_id = f.affiliate_country_id
471 | |WHERE
472 | | CAST(f.time_id AS DATE) BETWEEN '2015-09-10' AND '2015-09-11'
473 | | AND t.country_code IN ('FR', 'UK', 'US', 'DE')
474 | | AND d.device_name IN ('IPAD', 'IPHONE')
475 | | AND t.country_name IN ('FRANCE', 'GERMANY')
476 | |GROUP BY
477 | | ROLLUP(
478 | | (
479 | | t.country_name,
480 | | DATE_TRUNC('day', f.time_id)
481 | | )
482 | | )
483 | |ORDER BY
484 | | DATE_TRUNC('day', f.time_id)
485 | """
486 | )
487 | )
488 |
489 | TableDrivenPropertyChecks.forAll(examples) { // d is joined only because of the device_name filter; ORDER BY keeps only projected time dimensions
490 | case (projection, selection, expectedSQL) =>
491 | olapQuery.right.flatMap(
492 | _.computeQuery(projection, selection)
493 | ) should be (Right(expectedSQL.stripMargin.trim))
494 | }
495 | }
496 |
497 | it should "rewrite metrics aggregate" in { // plain SUM metrics become SumPostAggregate; ratio metrics become DividePostAggregate with ZEROIFNULL/NULLIFZERO dropped
498 | val examples = TableDrivenPropertyChecks.Table(
499 | ("Metric", "Aggregate"),
500 |
501 | ("clicks", Some(
502 | SumPostAggregate(
503 | FunctionCallExpression("sum", None, args = List(
504 | ColumnExpression(ColumnIdent("clicks", None))
505 | ))
506 | )
507 | )),
508 |
509 | ("tac", Some(
510 | SumPostAggregate(
511 | FunctionCallExpression("sum", None, args = List(
512 | MathExpression(
513 | "*",
514 | ColumnExpression(ColumnIdent("tac_euro", None)),
515 | ColumnExpression(ColumnIdent("rate", Some(TableIdent("r", None))))
516 | )
517 | ))
518 | )
519 | )),
520 |
521 | ("ctr", Some(
522 | DividePostAggregate(
523 | SumPostAggregate(
524 | FunctionCallExpression("sum", None, args = List(
525 | ColumnExpression(ColumnIdent("clicks", None))
526 | ))
527 | ),
528 | SumPostAggregate(
529 | FunctionCallExpression("sum", None, args = List(
530 | ColumnExpression(ColumnIdent("displays", None))
531 | ))
532 | )
533 | )
534 | ))
535 |
536 | )
537 |
538 | TableDrivenPropertyChecks.forAll(examples) { // the expected function name is lower-case "sum" although QUERY writes SUM
539 | case (metric, aggregateExpression) =>
540 | olapQuery.right.flatMap(_.rewriteMetricAggregate(metric)) should be (Right(aggregateExpression))
541 | }
542 | }
543 |
544 | it should "work with a subselect as relation" in { // same metrics over a pre-aggregated subselect f, joined to rates r on the subselect's x column
545 | val QUERY = // shadows the class-level QUERY within this test
546 | """
547 | SELECT
548 | device as device,
549 | zone as zone,
550 | affiliate as affiliate,
551 | affiliate_country as affiliate_country,
552 | affiliate_region as affiliate_region,
553 | network as network,
554 |
555 | hour as hour,
556 | day as day,
557 | week as week,
558 | month as month,
559 | quarter as quarter,
560 |
561 | SUM(displays) as displays,
562 | SUM(clicks) as clicks,
563 | SUM(sales) as sales,
564 | SUM(order_value_euro * r.rate) as order_value,
565 | SUM(revenue_euro * r.rate) as revenue,
566 | SUM(tac_euro * r.rate) as tac,
567 | SUM((revenue_euro - tac_euro) * r.rate) as revenue_ex_tac,
568 | SUM(marketplace_revenue_euro * r.rate) as marketplace_revenue,
569 | SUM((marketplace_revenue_euro - tac_euro) * r.rate) as marketplace_revenue_ex_tac,
570 | ZEROIFNULL(SUM(clicks)/NULLIFZERO(SUM(displays))) as ctr,
571 | ZEROIFNULL(SUM(sales)/NULLIFZERO(SUM(clicks))) as cr,
572 | ZEROIFNULL(SUM(revenue_euro - tac_euro)/NULLIFZERO(SUM(revenue_euro))) as margin,
573 | ZEROIFNULL(SUM(marketplace_revenue_euro - tac_euro)/NULLIFZERO(SUM(marketplace_revenue_euro))) as marketplace_margin,
574 | ZEROIFNULL(SUM(revenue_euro * r.rate)/NULLIFZERO(SUM(clicks))) as cpc,
575 | ZEROIFNULL(SUM(tac_euro * r.rate)/NULLIFZERO(SUM(displays))) * 1000 as cpm
576 |
577 | FROM
578 | (
579 | SELECT
580 | date_trunc('month', time_id) AS x,
581 |
582 | device_id as device,
583 | zone_id as zone,
584 | affiliate_id as affiliate,
585 | affiliate_country_id as affiliate_country,
586 | affiliate_region_name as affiliate_region,
587 | network_id as network,
588 |
589 | date_trunc('hour', time_id) as hour,
590 | date_trunc('day', time_id) as day,
591 | date_trunc('week', time_id) as week,
592 | date_trunc('month', time_id) as month,
593 | date_trunc('quarter', time_id) as quarter,
594 |
595 | SUM(displays) as displays,
596 | SUM(clicks) as clicks,
597 | SUM(sales) as sales,
598 | SUM(order_value_euro) as order_value_euro,
599 | SUM(revenue_euro) as revenue_euro,
600 | SUM(tac_euro) as tac_euro,
601 | SUM(marketplace_revenue_euro) as marketplace_revenue_euro
602 |
603 | FROM
604 | wopr.fact_portfolio
605 |
606 | WHERE
607 | CAST(time_id AS DATE) BETWEEN ?[day)
608 | AND affiliate_country_id IN ?{publisher_countries}
609 | AND device_id IN ?{device}
610 | AND zone_id IN ?{zone}
611 | AND affiliate_id IN ?{affiliate}
612 | AND affiliate_country_id IN ?{affiliate_country}
613 | AND affiliate_region_name IN ?{affiliate_region}
614 | AND network_id IN ?{network}
615 |
616 | GROUP BY
617 | device_id,
618 | zone_id,
619 | affiliate_id,
620 | affiliate_country_id,
621 | network_id,
622 | affiliate_region_name,
623 | date_trunc('hour', time_id),
624 | date_trunc('day', time_id),
625 | date_trunc('week', time_id),
626 | date_trunc('month', time_id),
627 | date_trunc('quarter', time_id)
628 | ) as f
629 | JOIN wopr.fact_euro_rates_hourly as r
630 | ON r.currency_id = ?currency_id AND x = r.time_id
631 |
632 | GROUP BY ROLLUP((
633 | device,
634 | zone,
635 | affiliate,
636 | affiliate_country,
637 | network,
638 | affiliate_region,
639 | hour,
640 | day,
641 | week,
642 | month,
643 | quarter
644 | ))
645 |
646 | ORDER BY
647 | hour,
648 | day,
649 | week,
650 | month,
651 | quarter
652 | """
653 |
654 | val olapQuery = VizSQL.parseOlapQuery(QUERY, BIDATA).left.map(e => sys.error(e.toString(QUERY))).right.get // shadows the shared olapQuery; a parse error aborts the test through sys.error
655 |
656 | olapQuery.getTimeDimensions should be (Right(List(
657 | "hour",
658 | "day",
659 | "week",
660 | "month",
661 | "quarter"
662 | )))
663 |
664 | olapQuery.getDimensions should be (Right(List(
665 | "device",
666 | "zone",
667 | "affiliate",
668 | "affiliate_country",
669 | "affiliate_region",
670 | "network"
671 | )))
672 |
673 | olapQuery.getMetrics should be (Right(List(
674 | "displays",
675 | "clicks",
676 | "sales",
677 | "order_value",
678 | "revenue",
679 | "tac",
680 | "revenue_ex_tac",
681 | "marketplace_revenue",
682 | "marketplace_revenue_ex_tac",
683 | "ctr",
684 | "cr",
685 | "margin",
686 | "marketplace_margin",
687 | "cpc",
688 | "cpm"
689 | )))
690 |
691 | olapQuery.getParameters should be (Right(List( // inner-query placeholders first, then ?currency_id from the outer JOIN
692 | "day",
693 | "publisher_countries",
694 | "device",
695 | "zone",
696 | "affiliate",
697 | "affiliate_country",
698 | "affiliate_region",
699 | "network",
700 | "currency_id"
701 | )))
702 |
703 | olapQuery.computeQuery( // the subselect keeps only what the outer query needs: no metric uses r.rate here, so neither r nor x survives
704 | OlapProjection(Set("device"), Set("clicks")),
705 | OlapSelection(Map("currency_id" -> 1, "publisher_countries" -> List(1,2,3), "day" -> ("2015-09-10","2015-09-11")), Map.empty)
706 | ).left.map(_.toString(QUERY)) should be (Right(
707 | """
708 | |SELECT
709 | | device AS device,
710 | | SUM(clicks) AS clicks
711 | |FROM
712 | | (
713 | | SELECT
714 | | device_id AS device,
715 | | SUM(clicks) AS clicks
716 | | FROM
717 | | wopr.fact_portfolio
718 | | WHERE
719 | | CAST(time_id AS DATE) BETWEEN '2015-09-10' AND '2015-09-11'
720 | | AND affiliate_country_id IN (1, 2, 3)
721 | | GROUP BY
722 | | device_id
723 | | ) AS f
724 | |GROUP BY
725 | | ROLLUP(
726 | | (
727 | | device
728 | | )
729 | | )
730 | """.stripMargin.trim
731 | ))
732 |
733 | olapQuery.computeQuery(
734 | OlapProjection(Set("device", "hour"), Set("clicks")),
735 | OlapSelection(Map("currency_id" -> 1, "publisher_countries" -> List(1,2,3), "day" -> ("2015-09-10","2015-09-11")), Map.empty)
736 | ).left.map(_.toString(QUERY)) should be (Right(
737 | """
738 | |SELECT
739 | | device AS device,
740 | | hour AS hour,
741 | | SUM(clicks) AS clicks
742 | |FROM
743 | | (
744 | | SELECT
745 | | device_id AS device,
746 | | DATE_TRUNC('hour', time_id) AS hour,
747 | | SUM(clicks) AS clicks
748 | | FROM
749 | | wopr.fact_portfolio
750 | | WHERE
751 | | CAST(time_id AS DATE) BETWEEN '2015-09-10' AND '2015-09-11'
752 | | AND affiliate_country_id IN (1, 2, 3)
753 | | GROUP BY
754 | | device_id,
755 | | DATE_TRUNC('hour', time_id)
756 | | ) AS f
757 | |GROUP BY
758 | | ROLLUP(
759 | | (
760 | | device,
761 | | hour
762 | | )
763 | | )
764 | |ORDER BY
765 | | hour
766 | """.stripMargin.trim
767 | ))
768 |
769 | olapQuery.computeQuery( // marketplace_revenue uses r.rate, so the r join and its key column x are kept in the subselect
770 | OlapProjection(Set("device", "hour"), Set("clicks", "marketplace_revenue")),
771 | OlapSelection(Map("currency_id" -> 1, "publisher_countries" -> List(1,2,3), "day" -> ("2015-09-10","2015-09-11")), Map.empty)
772 | ).left.map(_.toString(QUERY)) should be (Right(
773 | """
774 | |SELECT
775 | | device AS device,
776 | | hour AS hour,
777 | | SUM(clicks) AS clicks,
778 | | SUM(marketplace_revenue_euro * r.rate) AS marketplace_revenue
779 | |FROM
780 | | (
781 | | SELECT
782 | | DATE_TRUNC('month', time_id) AS x,
783 | | device_id AS device,
784 | | DATE_TRUNC('hour', time_id) AS hour,
785 | | SUM(clicks) AS clicks,
786 | | SUM(marketplace_revenue_euro) AS marketplace_revenue_euro
787 | | FROM
788 | | wopr.fact_portfolio
789 | | WHERE
790 | | CAST(time_id AS DATE) BETWEEN '2015-09-10' AND '2015-09-11'
791 | | AND affiliate_country_id IN (1, 2, 3)
792 | | GROUP BY
793 | | device_id,
794 | | DATE_TRUNC('hour', time_id),
795 | | DATE_TRUNC('month', time_id)
796 | | ) AS f
797 | | JOIN wopr.fact_euro_rates_hourly AS r
798 | | ON r.currency_id = 1
799 | | AND x = r.time_id
800 | |GROUP BY
801 | | ROLLUP(
802 | | (
803 | | device,
804 | | hour
805 | | )
806 | | )
807 | |ORDER BY
808 | | hour
809 | """.stripMargin.trim
810 | ))
811 | }
812 |
813 | }
--------------------------------------------------------------------------------
/jvm/src/test/scala/com/criteo/vizatra/vizsql/OptimizeSpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
3 | import org.scalatest.{Matchers, FlatSpec}
4 |
5 | class OptimizeSpec extends FlatSpec with Matchers { // compares Optimizer.optimize(query).sql with the toSQL of a hand-written expected query parsed against testDb
6 |
7 | val CASE_QUERY = // CASE forms mixing literal and column operands, with and without ELSE
8 | """SELECT
9 | | country,
10 | | CASE 1 WHEN 1 THEN 13
11 | | WHEN clicks THEN 9 END,
12 | | CASE 1 WHEN clicks THEN 65
13 | | WHEN 1 THEN 12 END,
14 | | CASE clicks WHEN clicks THEN 1
15 | | WHEN 42 then 77
16 | | ELSE 18 END,
17 | | CASE WHEN 1 > 2 THEN 42
18 | | WHEN 2 > 1 THEN 38
19 | | WHEN CLICKS > DISPLAYS THEN 654 END,
20 | | CASE 5 WHEN 1 THEN 2
21 | | WHEN 3 THEN 4
22 | | ELSE 10 END,
23 | | CASE 5 WHEN 1 THEN 2
24 | | WHEN 3 THEN 4
25 | | END,
26 | | CASE 42 WHEN 10 THEN 2
27 | | WHEN displays THEN 12
28 | | ELSE 88
29 | | END
30 | | FROM facts
31 | """.stripMargin
32 |
33 | val OPERATOR_QUERY = // literal arithmetic and comparisons; 5 / 0 and 'Marco' + 'Polo' are expected to stay unfolded
34 | """SELECT
35 | |country,
36 | |3 + 2,
37 | |10/5,
38 | |NULL / 2,
39 | |3 * 4,
40 | |5 / 0,
41 | |'Marco' + 'Polo',
42 | |44 - 2,
43 | | -(-(-2))
44 | |FROM facts
45 | |WHERE (3 > 2) AND 'X' = 'Y'
46 | """.stripMargin
47 |
48 | val JOIN_QUERY = // constant conjunct (2 > 3) inside a JOIN ... ON condition
49 | """SELECT
50 | |f.region,
51 | |f.country,
52 | |f.date,
53 | |c.comment
54 | |FROM facts AS f
55 | |JOIN review AS c ON f.idfacts = c.idfacts AND 2 > 3
56 | """.stripMargin
57 |
58 | val SUB_QUERY= // constant expressions nested in scalar, FROM and WHERE subselects
59 | """SELECT
60 | |x.result,
61 | |(SELECT comment FROM review WHERE idfacts = 42 AND (2 + 2 = 4)) AS comment
62 | |FROM (SELECT (2 + 2) AS result) AS x
63 | |WHERE (SELECT 2 = 2)
64 | """.stripMargin
65 |
66 | val testDb = { // facts and review tables in an unnamed schema, both carrying an idfacts column
67 | val cols = for { (name, coltype) <- List(("idfacts", INTEGER(false)),
68 | ("REGION", STRING(false)),
69 | ("COUNTRY", STRING(false)),
70 | ("CLICKS", INTEGER(false)),
71 | ("DISPLAYS", INTEGER(false)),
72 | ("DATE", DATE(true)))}
73 | yield Column(name, coltype)
74 | val cols2 = for { (name, coltype) <- List(("idfacts", INTEGER(false)),
75 | ("COMMENT", STRING(false)),
76 | ("COMMENTER", STRING(false)),
77 | ("RATING", INTEGER(true)))}
78 | yield Column(name, coltype)
79 | implicit val dialect = sql99.dialect // presumably picked up implicitly by DB(...) below; confirm
80 | DB(List(Schema("", List(Table("facts", cols), Table("review", cols2)))))
81 | }
82 |
83 | val subQuery = VizSQL.parseQuery(SUB_QUERY, testDb)
84 | val caseQuery = VizSQL.parseQuery(CASE_QUERY, testDb)
85 | val opQuery = VizSQL.parseQuery(OPERATOR_QUERY, testDb)
86 | val joinQuery = VizSQL.parseQuery(JOIN_QUERY, testDb)
87 |
88 | val queries = List(subQuery, caseQuery, opQuery, joinQuery) // each is asserted to be Right in "be idempotent"
89 |
90 | "The Optimizer" should "Simplify \"CASE\" statements" in { // constant CASEs collapse to their value (no matching WHEN and no ELSE gives NULL); column-dependent CASEs stay, minus unmatched literal WHENs
91 | caseQuery.isRight shouldBe true
92 | val expected =
93 | """SELECT
94 | |country,
95 | |13,
96 | |CASE 1 WHEN clicks THEN 65
97 | | WHEN 1 THEN 12 END,
98 | |CASE clicks WHEN clicks THEN 1
99 | | WHEN 42 then 77
100 | | ELSE 18 END,
101 | |38,
102 | |10,
103 | |NULL,
104 | |CASE 42
105 | | WHEN displays THEN 12
106 | | ELSE 88
107 | | END
108 | |FROM facts
109 | """.stripMargin
110 | val query = caseQuery.right.get
111 | val expectedQ = VizSQL.parseQuery(expected, testDb).right.get
112 | Optimizer.optimize(query).sql should be (expectedQ.select.toSQL)
113 | }
114 |
115 | it should "pre-compute literal expressions" in { // division by zero and string '+' are left as-is; the constant WHERE folds to FALSE
116 | val expected =
117 | """SELECT
118 | |country,
119 | |5,
120 | |2,
121 | |NULL,
122 | |12,
123 | |5 / 0,
124 | |'Marco' + 'Polo',
125 | |42,
126 | |-2
127 | |FROM facts
128 | |WHERE FALSE
129 | """.stripMargin
130 | opQuery.isRight shouldBe true
131 | val query = opQuery.right.get
132 | val expectedQ = VizSQL.parseQuery(expected, testDb).right.get
133 | Optimizer.optimize(query).sql should be (expectedQ.select.toSQL)
134 | }
135 |
136 | it should "be idempotent" in { // optimizing an already optimized query must give an equal query
137 | for {q <- queries } {
138 | q.isRight shouldBe true
139 | val query = q.right.get
140 | val optimized = Optimizer.optimize(query)
141 | val superOptimized = Optimizer.optimize(optimized)
142 | superOptimized should be(optimized)
143 | }
144 | }
145 |
146 | it should "remove unnecessary tables" in { // relations not referenced by the SELECT list are dropped from FROM
147 | val QUERY =
148 | """SELECT
149 | |3.0 / 2
150 | |FROM facts
151 | """.stripMargin
152 | val expected =
153 | """SELECT
154 | |1.5
155 | """.stripMargin
156 | val query = VizSQL.parseQuery(QUERY, testDb).right.get
157 | val ref = VizSQL.parseQuery(expected, testDb).right.get
158 | Optimizer.optimize(query).sql should be (ref.select.toSQL)
159 | val QUERY2 = // NOTE(review): the ON clause uses idfact while both tables in testDb declare idfacts; confirm the typo is intended
160 | """SELECT
161 | |r.comment,
162 | |3.0 / 2
163 | |FROM facts as f
164 | |JOIN review as r ON r.idfact = f.idfact
165 | """.stripMargin
166 | val expected2 =
167 | """SELECT
168 | |r.comment,
169 | |1.5
170 | |FROM review as r
171 | """.stripMargin
172 | val query2 = VizSQL.parseQuery(QUERY2, testDb).right.get
173 | val ref2 = VizSQL.parseQuery(expected2, testDb).right.get
174 | Optimizer.optimize(query2).sql should be (ref2.select.toSQL)
175 | }
176 |
177 | it should "rewrite sons of expressions it can't optimize" in { // arguments are folded even when the enclosing SUM/CAST call is kept
178 | val QUERY = """SELECT
179 | |SUM(3.0 / 2),
180 | |CAST(3 + 2 AS INTEGER) / 2,
181 | |(5 > 3)
182 | |FROM facts
183 | """.stripMargin
184 | val expected =
185 | """SELECT
186 | |SUM(1.5),
187 | |CAST(5 AS INTEGER) / 2,
188 | |TRUE
189 | """.stripMargin
190 | val query = VizSQL.parseQuery(QUERY, testDb).right.get
191 | val ref = VizSQL.parseQuery(expected, testDb).right.get
192 | Optimizer.optimize(query).sql should be (ref.select.toSQL)
193 | }
194 |
195 | it should "optimize joins" in { // the constant conjunct inside ON folds to FALSE, but the join itself is kept
196 | val expected =
197 | """SELECT
198 | |f.region,
199 | |f.country,
200 | |f.date,
201 | |c.comment
202 | |FROM facts AS f
203 | |JOIN review AS c ON f.idfacts = c.idfacts AND FALSE
204 | """.stripMargin
205 | val query = joinQuery.right.get
206 | val expectedQ = VizSQL.parseQuery(expected, testDb).right.get
207 | Optimizer.optimize(query).sql should be (expectedQ.select.toSQL)
208 | }
209 |
210 | it should "handle Subselect" in { // folding also applies inside scalar, FROM and WHERE subselects
211 | val expected =
212 | """SELECT
213 | |x.result,
214 | |(SELECT comment FROM review WHERE idfacts = 42 AND TRUE) AS comment
215 | |FROM (SELECT 4 AS result) AS x
216 | |WHERE (SELECT TRUE)
217 | """.stripMargin
218 | val query = subQuery.right.get
219 | val expectedQ = VizSQL.parseQuery(expected, testDb).right.get
220 | Optimizer.optimize(query).sql should be (expectedQ.select.toSQL)
221 | }
222 |
223 | }
224 |
--------------------------------------------------------------------------------
/jvm/src/test/scala/com/criteo/vizatra/vizsql/ParseVerticaDialectSpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
3 | import com.criteo.vizatra.vizsql.vertica._
4 | import org.scalatest.prop.TableDrivenPropertyChecks
5 | import org.scalatest.{EitherValues, Matchers, PropSpec}
6 |
7 | class ParseVerticaDialectSpec extends PropSpec with Matchers with EitherValues { // checks the column names and types inferred for Vertica-specific SQL
8 |
9 | val validVerticaSelectStatements = TableDrivenPropertyChecks.Table( // rows: query text and its expected output columns (e.g. TIMESTAMP + numeric stays TIMESTAMP)
10 | ("SQL", "Expected Columns"),
11 | ("""SELECT
12 | | NOW() as now,
13 | | MAX(last_update) + 3599 / 86400 AS last_update,
14 | | CONCAT('The most recent update was on ', TO_CHAR(MAX(last_update) + 3599 / 86400, 'YYYY-MM-DD at HH:MI')) as content
15 | |FROM
16 | | City""".stripMargin,
17 | List(
18 | Column("now", TIMESTAMP(nullable = false)),
19 | Column("last_update", TIMESTAMP(nullable = false)),
20 | Column("content", STRING(nullable = false))
21 | ))
22 | )
23 |
24 | // --
25 |
26 | val SAKILA = DB(schemas = List( // two-table fixture; NOTE(review): the dialect presumably comes implicitly from the vertica._ import, confirm
27 | Schema(
28 | "sakila",
29 | tables = List(
30 | Table(
31 | "City",
32 | columns = List(
33 | Column("city_id", INTEGER(nullable = false)),
34 | Column("city", STRING(nullable = false)),
35 | Column("country_id", INTEGER(nullable = false)),
36 | Column("last_update", TIMESTAMP(nullable = false))
37 | )
38 | ),
39 | Table(
40 | "Country",
41 | columns = List(
42 | Column("country_id", INTEGER(nullable = false)),
43 | Column("country", STRING(nullable = false)),
44 | Column("last_update", TIMESTAMP(nullable = false))
45 | )
46 | )
47 | )
48 | )
49 | ))
50 |
51 | // --
52 |
53 | property("extract Vertica SELECT statements columns") {
54 | TableDrivenPropertyChecks.forAll(validVerticaSelectStatements) { (sql, expectedColumns) =>
55 | val parsed = VizSQL.parseQuery(sql, SAKILA)
56 | .fold(e => sys.error(s"Query doesn't parse: $e"), identity)
57 | val actualColumns = parsed.columns
58 | .fold(e => sys.error(s"Invalid query: $e"), identity)
59 | actualColumns shouldBe expectedColumns
60 | }
61 | }
62 |
63 | }
--------------------------------------------------------------------------------
/jvm/src/test/scala/com/criteo/vizatra/vizsql/ParsingErrorsSpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
3 | import sql99._
4 | import org.scalatest.prop.TableDrivenPropertyChecks
5 | import org.scalatest.{Matchers, EitherValues, PropSpec}
6 |
7 | class ParsingErrorsSpec extends PropSpec with Matchers with EitherValues { // each row: invalid SQL and its rendered error (echoed input, caret line, message)
8 |
9 | val invalidSQL99SelectStatements = TableDrivenPropertyChecks.Table(
10 | ("SQL", "Expected error"),
11 |
12 | (
13 | """xxx""",
14 | """|xxx
15 | |^
16 | |Error: select expected
17 | """
18 | ),
19 | (
20 | """select""",
21 | """|select
22 | | ^
23 | |Error: *, table or expression expected
24 | """
25 | ),
26 | (
27 | """select 1 +""",
28 | """|select 1 +
29 | | ^
30 | |Error: expression expected
31 | """
32 | ),
33 | (
34 | """select 1 + *""",
35 | """|select 1 + *
36 | | ^
37 | |Error: expression expected
38 | """
39 | ),
40 | (
41 | """select (1 + 3""",
42 | """|select (1 + 3
43 | | ^
44 | |Error: ) expected
45 | """
46 | ),
47 | (
48 | """select * from""",
49 | """|select * from
50 | | ^
51 | |Error: table, join or subselect expected
52 | """
53 | ),
54 | (
55 | """select * from (selet 1)""",
56 | """|select * from (selet 1)
57 | | ^
58 | |Error: select expected
59 | """
60 | ),
61 | (
62 | """select * from (select 1sh);""",
63 | """|select * from (select 1sh);
64 | | ^
65 | |Error: ident expected
66 | """
67 | ),
68 | (
69 | """select * from (select 1)sh)""",
70 | """|select * from (select 1)sh)
71 | | ^
72 | |Error: ; expected
73 | """
74 | ),
75 | (
76 | """SELECT CustomerName; City FROM Customers;""",
77 | """|SELECT CustomerName; City FROM Customers;
78 | | ^
79 | |Error: end of statement expected
80 | """
81 | ),
82 | (
83 | """SELECT CustomerName FROM Customers UNION ALL""",
84 | """|SELECT CustomerName FROM Customers UNION ALL
85 | | ^
86 | |Error: select expected
87 | """
88 | )
89 | )
90 |
91 | // --
92 |
93 | property("report parsing errors on invalid SQL-99 SELECT statements") {
94 | TableDrivenPropertyChecks.forAll(invalidSQL99SelectStatements) { // a successful parse renders as "[NO ERROR]" so it shows up as a readable mismatch
95 | case (sql, expectedError) => // expectedError is already a String: only its margin and surrounding whitespace are stripped
96 | (new SQL99Parser).parseStatement(sql)
97 | .fold(_.toString(sql, ' ').trim, _ => "[NO ERROR]") should be (expectedError.stripMargin.trim)
98 | }
99 | }
100 |
101 | }
--------------------------------------------------------------------------------
/jvm/src/test/scala/com/criteo/vizatra/vizsql/SchemaErrorsSpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
3 | import org.scalatest.prop.TableDrivenPropertyChecks
4 | import org.scalatest.{EitherValues, Matchers, PropSpec}
5 | import sql99._
6 |
7 | /**
8 | * Test cases for schema errors: queries that parse but reference an ambiguous or unknown column must report the expected SchemaError.
9 | */
10 | class SchemaErrorsSpec extends PropSpec with Matchers with EitherValues {
11 |
12 | val invalidSQL99SelectStatements = TableDrivenPropertyChecks.Table(
13 | ("SQL", "Expected error"),
14 | (
15 | "SELECT region from City as C1 JOIN Country as C2 ON C1.country_id = C2.country_id WHERE region < 42",
16 | SchemaError("ambiguous column region", 6)
17 | ),(
18 | "SELECT nonexistent, region from City",
19 | SchemaError("column not found nonexistent", 6)
20 | )
21 | )
22 |
23 | // --
24 |
25 | val SAKILA = DB(schemas = List( // both tables declare a region column, so an unqualified region is ambiguous once they are joined
26 | Schema(
27 | "sakila",
28 | tables = List(
29 | Table(
30 | "City",
31 | columns = List(
32 | Column("city_id", INTEGER(nullable = false)),
33 | Column("city", STRING(nullable = false)),
34 | Column("country_id", INTEGER(nullable = false)),
35 | Column("last_update", TIMESTAMP(nullable = false)),
36 | Column("region", INTEGER(nullable = false))
37 | )
38 | ),
39 | Table(
40 | "Country",
41 | columns = List(
42 | Column("country_id", INTEGER(nullable = false)),
43 | Column("country", STRING(nullable = false)),
44 | Column("last_update", TIMESTAMP(nullable = false)),
45 | Column("region", INTEGER(nullable = false))
46 | )
47 | )
48 | )
49 | )
50 | ))
51 |
52 | // --
53 |
54 | property("report schema errors on invalid SQL-99 SELECT statements") {
55 | TableDrivenPropertyChecks.forAll(invalidSQL99SelectStatements) {
56 | case (sql, expectedError) => // parsing must succeed; the schema error is read from the parsed query
57 | VizSQL.parseQuery(sql, SAKILA)
58 | .fold(
59 | e => sys.error(s"Query doesn't parse: $e"),
60 | _.error.getOrElse(sys.error("Query should not type!"))
61 | ) should be (expectedError)
62 | }
63 | }
64 |
65 | }
66 |
--------------------------------------------------------------------------------
/jvm/src/test/scala/com/criteo/vizatra/vizsql/ThreadSafetySpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
3 | import org.scalatest.prop.TableDrivenPropertyChecks
4 | import org.scalatest.{Matchers, EitherValues, FlatSpec}
5 |
/**
 * Checks that one SQL99Parser instance can be shared by concurrent tasks: every
 * query, parsed from any task, must format back to its expected SQL.
 */
class ThreadSafetySpec extends FlatSpec with Matchers with EitherValues {

  "An SQL parser" should "be thread safe" in {
    // The single parser instance shared by every concurrent task below.
    val uniqueParser = new SQL99Parser

    import scala.util._
    import concurrent._
    import duration._
    import ExecutionContext.Implicits.global

    // 10 batches of 100 (input SQL, expected formatted SQL) pairs, each drawn at random
    // from the four templates below.
    val randomDataSet = (1 to 10).map { _ =>
      (1 to 100).map(_ => Random.nextInt(4)).map {
        case 0 =>
          (
            """SELECT 1""",
            """
              |SELECT
              | 1
            """.stripMargin
          )
        case 1 =>
          (
            """select district, sum(population) from city""",
            """
              |SELECT
              | district,
              | SUM(population)
              |FROM
              | city
            """.stripMargin
          )
        case 2 =>
          (
            """select * from City as v join Country as p on v.country_id = p.country_id where city.name like ? AND population > 10000""",
            """
              |SELECT
              | *
              |FROM
              | city AS v
              | JOIN country AS p
              | ON v.country_id = p.country_id
              |WHERE
              | city.name LIKE ?
              | AND population > 10000
            """.stripMargin
          )
        case 3 =>
          (
            """
SELECT
d.device_name as device_name,
z.description as zone_name,
z.technology as technology,
z.affiliate_name as affiliate_name,
t.country_name as affiliate_country,
t.country_level_1_name as affiliate_region,
z.network_name as network_name,
f.time_id as hour,
CAST(date_trunc('day', f.time_id) as DATE) as day,
CAST(date_trunc('week', f.time_id) as DATE) as week,
CAST(date_trunc('month', f.time_id) as DATE) as month,
CAST(date_trunc('quarter', f.time_id) as DATE) as quarter,

SUM(displays) as displays,
SUM(clicks) as clicks,
SUM(sales) as sales,
SUM(order_value_euro * r.rate) as order_value,
SUM(revenue_euro * r.rate) as revenue,
SUM(tac_euro * r.rate) as tac,
SUM((revenue_euro - tac_euro) * r.rate) as revenue_ex_tac,
SUM(marketplace_revenue_euro * r.rate) as marketplace_revenue,
SUM((marketplace_revenue_euro - tac_euro) * r.rate) as marketplace_revenue_ex_tac,
ZEROIFNULL(SUM(clicks)/NULLIF(SUM(displays), 0.0)) as ctr,
ZEROIFNULL(SUM(sales)/NULLIF(SUM(clicks), 0.0)) as cr,
ZEROIFNULL(SUM(revenue_euro - tac_euro)/NULLIF(SUM(revenue_euro), 0.0)) as margin,
ZEROIFNULL(SUM(marketplace_revenue_euro - tac_euro)/NULLIF(SUM(marketplace_revenue_euro), 0.0)) as marketplace_margin,
ZEROIFNULL(SUM(revenue_euro * r.rate)/NULLIF(SUM(clicks), 0.0)) as cpc,
ZEROIFNULL(SUM(tac_euro * r.rate)/NULLIF(SUM(displays), 0.0)) * 1000 as cpm
FROM
wopr.fact_zone_device_stats_hourly f
JOIN wopr.dim_zone z
ON z.zone_id = f.zone_id
JOIN wopr.dim_device d
ON d.device_id = f.device_id
JOIN wopr.dim_country t
ON t.country_id = f.affiliate_country_id
JOIN wopr.fact_euro_rates_hourly r
ON r.currency_id = ?currency_id AND f.time_id = r.time_id

WHERE
CAST(f.time_id AS DATE) between ?[day)
AND t.country_code IN ?{publisher_countries}
AND d.device_name IN ?{device_name}
AND z.description IN ?{zone_name}
AND z.technology IN ?{technology}
AND z.affiliate_name IN ?{affiliate_name}
AND t.country_name IN ?{affiliate_country}
AND t.country_level_1_name IN ?{affiliate_region}
AND z.network_name IN ?{network_name}

GROUP BY ROLLUP((
d.device_name,
z.description,
z.technology,
z.affiliate_name,
t.country_name,
z.network_name,
t.country_level_1_name,
f.time_id,
CAST(date_trunc('day', f.time_id) as DATE),
CAST(date_trunc('week', f.time_id) as DATE),
CAST(date_trunc('month', f.time_id) as DATE),
CAST(date_trunc('quarter', f.time_id) as DATE)
))
HAVING SUM(clicks) > 0
""",
            """
              |SELECT
              | d.device_name AS device_name,
              | z.description AS zone_name,
              | z.technology AS technology,
              | z.affiliate_name AS affiliate_name,
              | t.country_name AS affiliate_country,
              | t.country_level_1_name AS affiliate_region,
              | z.network_name AS network_name,
              | f.time_id AS hour,
              | CAST(DATE_TRUNC('day', f.time_id) AS DATE) AS day,
              | CAST(DATE_TRUNC('week', f.time_id) AS DATE) AS week,
              | CAST(DATE_TRUNC('month', f.time_id) AS DATE) AS month,
              | CAST(DATE_TRUNC('quarter', f.time_id) AS DATE) AS quarter,
              | SUM(displays) AS displays,
              | SUM(clicks) AS clicks,
              | SUM(sales) AS sales,
              | SUM(order_value_euro * r.rate) AS order_value,
              | SUM(revenue_euro * r.rate) AS revenue,
              | SUM(tac_euro * r.rate) AS tac,
              | SUM((revenue_euro - tac_euro) * r.rate) AS revenue_ex_tac,
              | SUM(marketplace_revenue_euro * r.rate) AS marketplace_revenue,
              | SUM((marketplace_revenue_euro - tac_euro) * r.rate) AS marketplace_revenue_ex_tac,
              | ZEROIFNULL(SUM(clicks) / NULLIF(SUM(displays), 0.0)) AS ctr,
              | ZEROIFNULL(SUM(sales) / NULLIF(SUM(clicks), 0.0)) AS cr,
              | ZEROIFNULL(SUM(revenue_euro - tac_euro) / NULLIF(SUM(revenue_euro), 0.0)) AS margin,
              | ZEROIFNULL(SUM(marketplace_revenue_euro - tac_euro) / NULLIF(SUM(marketplace_revenue_euro), 0.0)) AS marketplace_margin,
              | ZEROIFNULL(SUM(revenue_euro * r.rate) / NULLIF(SUM(clicks), 0.0)) AS cpc,
              | ZEROIFNULL(SUM(tac_euro * r.rate) / NULLIF(SUM(displays), 0.0)) * 1000 AS cpm
              |FROM
              | wopr.fact_zone_device_stats_hourly AS f
              | JOIN wopr.dim_zone AS z
              | ON z.zone_id = f.zone_id
              | JOIN wopr.dim_device AS d
              | ON d.device_id = f.device_id
              | JOIN wopr.dim_country AS t
              | ON t.country_id = f.affiliate_country_id
              | JOIN wopr.fact_euro_rates_hourly AS r
              | ON r.currency_id = ?currency_id
              | AND f.time_id = r.time_id
              |WHERE
              | CAST(f.time_id AS DATE) BETWEEN ?day
              | AND t.country_code IN ?publisher_countries
              | AND d.device_name IN ?device_name
              | AND z.description IN ?zone_name
              | AND z.technology IN ?technology
              | AND z.affiliate_name IN ?affiliate_name
              | AND t.country_name IN ?affiliate_country
              | AND t.country_level_1_name IN ?affiliate_region
              | AND z.network_name IN ?network_name
              |GROUP BY
              | ROLLUP(
              | (
              | d.device_name,
              | z.description,
              | z.technology,
              | z.affiliate_name,
              | t.country_name,
              | z.network_name,
              | t.country_level_1_name,
              | f.time_id,
              | CAST(DATE_TRUNC('day', f.time_id) AS DATE),
              | CAST(DATE_TRUNC('week', f.time_id) AS DATE),
              | CAST(DATE_TRUNC('month', f.time_id) AS DATE),
              | CAST(DATE_TRUNC('quarter', f.time_id) AS DATE)
              | )
              | )
              |HAVING
              | SUM(clicks) > 0
            """.stripMargin
          )
      }
    }

    // Parse each batch on its own task. Instead of a bare boolean, every task returns a
    // human-readable report for each query whose formatted output differs from the
    // expectation. A parse error is reported as a mismatch instead of aborting the task.
    val eventuallyMismatches = Future.sequence {
      randomDataSet.map { queries =>
        Future {
          blocking {
            queries
              .map { case (sql, expectedOutput) =>
                val actual = uniqueParser
                  .parseStatement(sql)
                  .fold(e => s"parse error:\n${e.toString(sql)}", _.toSQL)
                (sql, expectedOutput.stripMargin.trim, actual)
              }
              .collect { case (sql, expected, actual) if actual != expected =>
                s"Query:\n$sql\nExpected:\n$expected\nGot:\n$actual"
              }
          }
        }
      }
    }

    val mismatches = Await.result(eventuallyMismatches, 5.minutes).flatten
    val total = randomDataSet.map(_.size).sum
    withClue(s"${mismatches.size} of $total queries mismatched; first one:\n${mismatches.headOption.getOrElse("-")}\n") {
      mismatches.size shouldBe 0
    }
  }

}
217 |
--------------------------------------------------------------------------------
/jvm/src/test/scala/com/criteo/vizatra/vizsql/TypingErrorsSpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
3 | import sql99._
4 | import org.scalatest.prop.TableDrivenPropertyChecks
5 | import org.scalatest.{Matchers, EitherValues, PropSpec}
6 |
/**
 * Queries that parse but fail typing. Each expected value is the margin-formatted
 * rendering of the error via `Err.toString(sql, ' ')`: the query, a caret line and
 * the message.
 */
class TypingErrorsSpec extends PropSpec with Matchers with EitherValues {

  val invalidSQL99SelectStatements = TableDrivenPropertyChecks.Table(
    ("SQL", "Expected error"),

    (
      """select ?""",
      """|select ?
         | ^
         |Error: parameter type cannot be inferred from context
      """
    ),
    (
      """select ? = ?""",
      """|select ? = ?
         | ^
         |Error: parameter type cannot be inferred from context
      """
    ),
    (
      """select ? = plop""",
      """|select ? = plop
         | ^
         |Error: column not found plop
      """
    ),
    (
      """select YOLO(88)""",
      """|select YOLO(88)
         | ^
         |Error: unknown function yolo
      """
    ),
    (
      """select min(12, 13)""",
      """|select min(12, 13)
         | ^
         |Error: too many arguments
      """
    ),
    (
      """select max()""",
      """|select max()
         | ^
         |Error: expected argument
      """
    ),
    (
      """select ? between ? and ?""",
      """|select ? between ? and ?
         | ^
         |Error: parameter type cannot be inferred from context
      """
    ),
    (
      """select ? between ? and coco""",
      """|select ? between ? and coco
         | ^
         |Error: column not found coco
      """
    ),
    (
      """select ? in (?)""",
      """|select ? in (?)
         | ^
         |Error: parameter type cannot be inferred from context
      """
    ),
    (
      """select case when city_id = 1 then 'foo' else 1 end from City""",
      """|select case when city_id = 1 then 'foo' else 1 end from City
         | ^
         |Error: expected string, found integer
      """
    ),
    (
      """select 1 = true""",
      """|select 1 = true
         | ^
         |Error: expected integer, found boolean
      """
    ),
    (
      """select * from (select 1 a union select true a) x""",
      """|select * from (select 1 a union select true a) x
         | ^
         |Error: expected integer, found boolean for column a
      """
    ),
    (
      """select * from (select 1, 2 union select 1) x""",
      """|select * from (select 1, 2 union select 1) x
         | ^
         |Error: expected same number of columns on both sides of the union
      """
    )
  )

  // --

  val SAKILA = DB(schemas = List(
    Schema(
      "sakila",
      tables = List(
        Table(
          "City",
          columns = List(
            Column("city_id", INTEGER(nullable = false)),
            Column("city", STRING(nullable = false)),
            Column("country_id", INTEGER(nullable = false)),
            Column("last_update", TIMESTAMP(nullable = false))
          )
        ),
        Table(
          "Country",
          columns = List(
            Column("country_id", INTEGER(nullable = false)),
            Column("country", STRING(nullable = false)),
            Column("last_update", TIMESTAMP(nullable = false))
          )
        )
      )
    )
  ))

  // --

  property("report typing errors on invalid SQL-99 SELECT statements") {
    TableDrivenPropertyChecks.forAll(invalidSQL99SelectStatements) {
      case (sql, expectedError) =>
        // `fail` raises a proper ScalaTest failure and names the offending query.
        val renderedError = VizSQL.parseQuery(sql, SAKILA).fold(
          e => fail(s"Query doesn't parse: $e"),
          query => query.error.map(_.toString(sql, ' ')).getOrElse(fail(s"Query should not type: $sql"))
        )
        renderedError should be (expectedError.stripMargin.trim)
    }
  }

}
--------------------------------------------------------------------------------
/jvm/src/test/scala/com/criteo/vizatra/vizsql/hive/HiveParsingErrorsSpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.hive
2 |
3 | import org.scalatest.prop.TableDrivenPropertyChecks
4 | import org.scalatest.{EitherValues, Matchers, PropSpec}
5 |
/**
 * Invalid Hive statements must fail to parse with the expected message. Expected
 * values are margin-formatted renderings of `Err.toString(sql, ' ')`. The margin is
 * stripped exactly once, in the property below, for every row.
 */
class HiveParsingErrorsSpec extends PropSpec with Matchers with EitherValues {

  val invalidSQL99SelectStatements = TableDrivenPropertyChecks.Table(
    ("SQL", "Expected error"),

    (
      "select bucket from t",
      """select bucket from t
        | ^
        |Error: *, table or expression expected
      """
    ),
    (
      "select foo from tbl limit 100 order by foo",
      """select foo from tbl limit 100 order by foo
        | ^
        |Error: ; expected
      """
    ),
    (
      "select foo from bar tablesample (bucket 2 out af 3)",
      """select foo from bar tablesample (bucket 2 out af 3)
        | ^
        |Error: of expected
      """
    )
  )

  // --

  property("report parsing errors on invalid Hive statements") {
    TableDrivenPropertyChecks.forAll(invalidSQL99SelectStatements) {
      case (sql, expectedError) =>
        val actualError = new HiveDialect(Map.empty).parser.parseStatement(sql)
          .fold(_.toString(sql, ' ').trim, _ => "[NO ERROR]")
        actualError should be (expectedError.stripMargin.trim)
    }
  }

}
45 |
--------------------------------------------------------------------------------
/jvm/src/test/scala/com/criteo/vizatra/vizsql/hive/HiveTypeParserSpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.hive
2 |
3 | import com.criteo.vizatra.vizsql._
4 | import org.scalatest.prop.TableDrivenPropertyChecks
5 | import org.scalatest.{Matchers, PropSpec}
6 |
/**
 * Hive type strings (as they appear in a metastore) must parse to the matching
 * vizsql types. All parsed types are nullable.
 *
 * NOTE(review): the angle-bracketed parts of these type strings had been lost
 * (e.g. "array" instead of "array<int>"). They are reconstructed here from the
 * expected types; confirm against the TypeParser grammar.
 */
class HiveTypeParserSpec extends PropSpec with Matchers {

  val types = TableDrivenPropertyChecks.Table(
    ("Type string", "Expected type"),

    ("double", DECIMAL(true)),
    ("array<int>", HiveArray(INTEGER(true))),
    ("map<string,struct<a:boolean,b:timestamp>>", HiveMap(STRING(true), HiveStruct(List(
      Column("a", BOOLEAN(true)),
      Column("b", TIMESTAMP(true))
    )))),
    ("array<struct<foo:array<map<string,string>>,bar:array<int>,`baz`:double>>", HiveArray(HiveStruct(List(
      Column("foo", HiveArray(HiveMap(STRING(true), STRING(true)))),
      Column("bar", HiveArray(INTEGER(true))),
      Column("baz", DECIMAL(true))
    )))),
    // A type name used as a struct field name.
    ("struct<timestamp:int>", HiveStruct(List(
      Column("timestamp", INTEGER(true))
    )))
  )

  // --

  property("parse to correct types") {
    val parser = new TypeParser
    TableDrivenPropertyChecks.forAll(types) {
      case (typeString, expectedType) =>
        parser.parseType(typeString) shouldEqual Right(expectedType)
    }
  }
}
38 |
--------------------------------------------------------------------------------
/jvm/src/test/scala/com/criteo/vizatra/vizsql/hive/ParseHiveQuerySpec.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.hive
2 |
3 | import com.criteo.vizatra.vizsql._
4 | import org.scalatest.prop.TableDrivenPropertyChecks
5 | import org.scalatest.{Matchers, PropSpec}
6 |
/** Valid Hive queries and the exact AST the Hive dialect's parser must produce for each. */
class ParseHiveQuerySpec extends PropSpec with Matchers {
  val queries = TableDrivenPropertyChecks.Table(
    ("Valid Hive query", "Expected AST"),

    ("select a.b[2].c.d.e['foo'].f.g.h.i[0] from t", {
      // The access chain is assembled from the innermost step outwards:
      // a.b[2], then .c.d.e, then ['foo'], then .f.g.h.i, then [0].
      val ab2 = MapOrArrayAccessExpression(
        ColumnOrStructAccessExpression(ColumnIdent("b", Some(TableIdent("a")))),
        LiteralExpression(IntegerLiteral(2)))
      val upToE = StructAccessExpr(StructAccessExpr(StructAccessExpr(ab2, "c"), "d"), "e")
      val upToFoo = MapOrArrayAccessExpression(upToE, LiteralExpression(StringLiteral("foo")))
      val upToI = StructAccessExpr(StructAccessExpr(StructAccessExpr(StructAccessExpr(upToFoo, "f"), "g"), "h"), "i")
      val wholeChain = MapOrArrayAccessExpression(upToI, LiteralExpression(IntegerLiteral(0)))
      SimpleSelect(
        projections = List(ExpressionProjection(wholeChain)),
        relations = List(SingleTableRelation(TableIdent("t")))
      )
    }),
    ("select a, b from table x lateral view explode (col) y as z",
      SimpleSelect(
        projections = List(
          ExpressionProjection(ColumnOrStructAccessExpression(ColumnIdent("a"))),
          ExpressionProjection(ColumnOrStructAccessExpression(ColumnIdent("b")))
        ),
        relations = List(LateralView(
          inner = SingleTableRelation(TableIdent("table"), Some("x")),
          explodeFunction = FunctionCallExpression(
            name = "explode",
            distinct = None,
            args = List(ColumnOrStructAccessExpression(ColumnIdent("col")))),
          tableAlias = "y",
          columnAliases = List("z")
        ))
      )
    ),
    ("select * from ta a left semi join tb b on a.id = b.id",
      SimpleSelect(
        projections = List(AllColumns),
        relations = List(JoinRelation(
          left = SingleTableRelation(TableIdent("ta"), Some("a")),
          join = LeftSemiJoin,
          right = SingleTableRelation(TableIdent("tb"), Some("b")),
          on = Some(ComparisonExpression(
            op = "=",
            left = ColumnOrStructAccessExpression(ColumnIdent("id", Some(TableIdent("a")))),
            right = ColumnOrStructAccessExpression(ColumnIdent("id", Some(TableIdent("b"))))
          ))
        ))
      )
    ),
    ("select derp(a) from t",
      SimpleSelect(
        projections = List(ExpressionProjection(
          FunctionCallExpression(
            name = "derp",
            distinct = None,
            args = ColumnOrStructAccessExpression(ColumnIdent("a")) :: Nil
          )
        )),
        relations = List(SingleTableRelation(TableIdent("t")))
      )
    ),
    ("select `select` s, `from` f from `join` j order by `select` desc cluster by `from` limit 100",
      // FIXME cluster by is clumped with order by.
      // NOTE(review): an older note also said "limit is lost", yet `limit` is asserted below — confirm.
      SimpleSelect(
        projections = List(
          ExpressionProjection(ColumnOrStructAccessExpression(ColumnIdent("select")), Some("s")),
          ExpressionProjection(ColumnOrStructAccessExpression(ColumnIdent("from")), Some("f"))
        ),
        relations = List(SingleTableRelation(TableIdent("join"), Some("j"))),
        orderBy = List(
          SortExpression(ColumnOrStructAccessExpression(ColumnIdent("select")), Some(SortDESC)),
          SortExpression(ColumnOrStructAccessExpression(ColumnIdent("from")), None)
        ),
        limit = Some(IntegerLiteral(100))
      )
    ),
    ("select foo from bar tablesample (bucket 2 out of 3 on baz)",
      // FIXME tablesample is lost
      SimpleSelect(
        projections = List(ExpressionProjection(ColumnOrStructAccessExpression(ColumnIdent("foo")))),
        relations = List(SingleTableRelation(TableIdent("bar")))
      )
    )
  )

  // --

  property("parse query") {
    TableDrivenPropertyChecks.forAll(queries) { case (query, ast) =>
      new HiveDialect(Map.empty).parser.parseStatement(query).fold(
        err => fail(err.toString(query)),
        parsed => parsed shouldEqual ast
      )
    }
  }
}
118 |
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
// Scala.js compiler plugin (presumably backing the js/ side of the js/jvm/shared cross-build).
addSbtPlugin("org.scala-js" % "sbt-scalajs" % "0.6.13")
// Sonatype repository publishing support.
addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "1.1")
// PGP signing of published artifacts.
addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0")
4 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/Dialect.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
/** A SQL flavour: the functions it knows about and the parser for its syntax. */
trait Dialect {
  /** Function catalogue, looked up by function name. */
  def functions: PartialFunction[String, SQLFunction]

  /** Parser turning this dialect's SQL text into statements. */
  def parser: SQLParser
}
7 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/Errors.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
trait Err {
  val msg: String
  val pos: Int

  /**
   * Renders this error against `sql`. Every SQL line is echoed; the line that holds
   * the error gets a caret line (`lineChar` padding followed by `^`) and an
   * `Error: <msg>` line beneath it, followed by a blank line when more SQL lines come
   * after. The caret skips any whitespace at `pos`, so it lands on the next
   * non-blank character. Leading and trailing blank lines are dropped from the result.
   */
  def toString(sql: String, lineChar: Char = '~'): String = {
    val target = sql.size - sql.drop(pos).dropWhile(_.toString.matches("""\s""")).size
    val lines = sql.split('\n').toList
    // Offset of each line's first character inside `sql`; the `+ 1` accounts for the '\n'.
    val starts = lines.scanLeft(0)((start, line) => start + line.size + 1)

    val annotated = lines.zip(starts).map { case (line, start) =>
      if (start <= target && start + line.size >= target) {
        val caret = lineChar.toString * (target - start) + "^"
        s"\n$line\n$caret\nError: $msg\n"
      } else {
        s"\n$line"
      }
    }.mkString

    annotated.split('\n').dropWhile(_.isEmpty).reverse.dropWhile(_.isEmpty).reverse.mkString("\n")
  }

  /** Merges with another error; by default the receiver wins. */
  def combine(err: Err): Err = this
}
21 |
/** A placeholder problem; when combined, schema and typing errors take precedence over it. */
case class PlaceholderError(msg: String, pos: Int) extends Err {
  override def combine(err: Err): Err = err match {
    case _: SchemaError | _: TypeError => err
    case _                             => this
  }
}
case class ParsingError(msg: String, pos: Int) extends Err
case class SchemaError(msg: String, pos: Int) extends Err
case class TypeError(msg: String, pos: Int) extends Err
case class SQLError(msg: String, pos: Int) extends Err
case class ParameterError(msg: String, pos: Int) extends Err

/** The neutral error: combining it with any other error yields that other error. */
case object NoError extends Err {
  val msg: String = "?WAT"
  val pos: Int = 0
  override def combine(err: Err): Err = err
}
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/EvalHelper.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
3 | import java.util.Date
4 |
/**
 * Constant-folding helpers for arithmetic and comparisons between literals.
 *
 * Each `*Eval` method returns `Left(literal)` when the operation could be folded,
 * or `Right(expression)` rebuilding the original operation when it could not.
 */
object EvalHelper {

  /**
   * Applies `+`, `-` or `*` through `Numeric`. `Numeric` has no division, so `/` is
   * handled directly in `mathEval`. Any other operator is a MatchError, which is why
   * callers guard with `ops`.
   */
  def applyOp[T](op: String, x : T, y : T)(implicit num : Numeric[T]) : T = op match {
    case "+" => num.plus(x, y)
    case "-" => num.minus(x, y)
    case "*" => num.times(x, y)
    // Numeric doesn't handle division
  }

  /** The operators `applyOp` supports. */
  val ops = Set("+", "-", "*")

  /**
   * Folds `x op y` for numeric literals. Integer division is folded only when exact,
   * division by zero is never folded, and a NULL operand folds to NULL.
   */
  def mathEval(op : String, x : Literal, y : Literal) : Either[Literal, Expression] = (x, y) match {
    case (DecimalLiteral(a), DecimalLiteral(b)) if op == "/" && (b != 0) => Left(DecimalLiteral(a / b))
    case (IntegerLiteral(a), DecimalLiteral(b)) if op == "/" && (b != 0) => Left(DecimalLiteral(a / b))
    case (DecimalLiteral(a), IntegerLiteral(b)) if op == "/" && (b != 0) => Left(DecimalLiteral(a / b))
    case (IntegerLiteral(a), IntegerLiteral(b)) if op == "/" && (b != 0) && (a % b == 0) => Left(IntegerLiteral(a/b))

    case (DecimalLiteral(a), DecimalLiteral(b)) if ops contains op => Left(DecimalLiteral(applyOp(op, a,b)))
    case (IntegerLiteral(a), DecimalLiteral(b)) if ops contains op => Left(DecimalLiteral(applyOp(op, a,b)))
    case (DecimalLiteral(a), IntegerLiteral(b)) if ops contains op => Left(DecimalLiteral(applyOp(op, a,b)))
    case (IntegerLiteral(a), IntegerLiteral(b)) if ops contains op => Left(IntegerLiteral(applyOp(op, a,b)))
    case (NullLiteral, _) => Left(NullLiteral)
    case (_, NullLiteral) => Left(NullLiteral)
    case _ => Right(MathExpression(op, LiteralExpression(x), LiteralExpression(y)))
  }

  /**
   * Folds the comparison `x op y`. Numbers compare with numbers, strings with strings
   * and booleans with booleans, and a NULL operand folds to NULL. Any other
   * combination, or an operator missing from `compOperators`, is returned
   * unevaluated instead of throwing (mirroring `mathEval`).
   */
  def compEval(op : String, x : Literal, y : Literal) : Either[Literal, Expression] = {
    def unevaluated: Either[Literal, Expression] =
      Right(ComparisonExpression(op, LiteralExpression(x), LiteralExpression(y)))
    def evaluate(a: Any, b: Any): Either[Literal, Expression] =
      compOperators.get(op).fold(unevaluated)(cmp => Left(bool2Literal(cmp(a, b))))

    (x, y) match {
      case (DecimalLiteral(a), DecimalLiteral(b)) => evaluate(a, b)
      case (IntegerLiteral(a), DecimalLiteral(b)) => evaluate(a, b)
      case (DecimalLiteral(a), IntegerLiteral(b)) => evaluate(a, b)
      case (IntegerLiteral(a), IntegerLiteral(b)) => evaluate(a, b)
      case (StringLiteral(a), StringLiteral(b)) => evaluate(a, b)
      case (NullLiteral, _) => Left(NullLiteral)
      case (_, NullLiteral) => Left(NullLiteral)
      case (l @ (TrueLiteral | FalseLiteral), r @ (TrueLiteral | FalseLiteral)) =>
        evaluate(literal2bool(l), literal2bool(r))
      case _ => unevaluated
    }
  }

  /**
   * Three-way comparison backing the ordering operators of `compOperators`.
   * Throws a MatchError for values of incomparable types.
   */
  def compare[T<: Any](v1 : T, v2: T) : Int = (v1, v2) match {
    case (x : Ordered[_], y) => x.asInstanceOf[Ordered[Any]] compareTo y.asInstanceOf[Any]
    // Exact for 64-bit integers; going through doubleValue loses precision beyond 2^53.
    case (x : java.lang.Long, y : java.lang.Long) => x compareTo y
    case (x : Number, y : Number) => x.doubleValue().compare(y.doubleValue())
    case (x : String, y : String) => x.compare(y)
    case (x : Date, y: Date) => x compareTo y
    case (x : Boolean, y : Boolean ) => x compareTo y
    case (x, y) if x == y => 0
  }

  /** Comparison operators that can be folded, keyed by their SQL spelling. Built once. */
  val compOperators: Map[String, (Any, Any) => Boolean] = Map(
    ">" -> ((a : Any, b : Any) => compare(a, b) > 0),
    "<" -> ((a : Any, b : Any) => compare(a, b) < 0),
    ">=" -> ((a : Any, b : Any) => compare(a, b) >= 0),
    "<=" -> ((a : Any, b : Any) => compare(a, b) <= 0),
    "=" -> ((a : Any, b : Any) => a == b),
    "<>"-> ((a : Any, b : Any) => a != b)
  )

  def bool2Literal(b : Boolean) = if (b) TrueLiteral else FalseLiteral
  def literal2bool(b: Literal) = (b == TrueLiteral)
}
65 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/Functions.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
/** Typing contract for a SQL function call. */
trait SQLFunction {

  /** Collects the placeholders of every argument, passing `expectedType` down to each one. */
  def getPlaceholders(call: FunctionCallExpression, db: DB, expectedType: Option[Type]) =
    Utils.allPlaceholders[Expression](call.args, arg => arg.getPlaceholders(db, expectedType))

  def resultType(call: FunctionCallExpression, db: DB, placeholders: Placeholders): Either[Err,Type]

  /**
   * Types every argument of `call`. When `nb` is non-negative, the call must have
   * exactly `nb` arguments; a negative `nb` disables the arity check.
   */
  def getArguments(call: FunctionCallExpression, db: DB, placeholders: Placeholders, nb: Int): Either[Err,List[Type]] = {
    val supplied = call.args.size
    val checkArity = nb >= 0

    if (checkArity && supplied < nb)
      Left(TypeError("expected argument", call.pos + call.name.size + 1))
    else if (checkArity && supplied > nb)
      Left(TypeError("too many arguments", call.args(nb).pos))
    else
      // Folding from the right: once a later argument has failed, earlier ones are not typed.
      call.args.foldRight[Either[Err, List[Type]]](Right(Nil)) { (arg, typedTail) =>
        typedTail.right.flatMap(tail => arg.resultType(db, placeholders).right.map(_ :: tail))
      }
  }
}
21 |
/** A function taking no arguments; `result` is its return type. */
trait SQLFunction0 extends SQLFunction {
  def resultType(call: FunctionCallExpression, db: DB, placeholders: Placeholders): Either[Err, Type] =
    getArguments(call, db, placeholders, 0) match {
      case Left(err) => Left(err)
      case Right(_)  => result
    }
  def result: Either[Err,Type]
}

/** A one-argument function; `result` receives the argument with its type. */
trait SQLFunction1 extends SQLFunction {
  def resultType(call: FunctionCallExpression, db: DB, placeholders: Placeholders): Either[Err, Type] =
    getArguments(call, db, placeholders, 1) match {
      case Left(err) => Left(err)
      case Right(types) =>
        val only = (call.args(0), types(0))
        result(only)
    }
  def result: PartialFunction[(Expression,Type),Either[Err,Type]]
}

/** A two-argument function; `result` receives both arguments with their types. */
trait SQLFunction2 extends SQLFunction {
  def resultType(call: FunctionCallExpression, db: DB, placeholders: Placeholders): Either[Err, Type] =
    getArguments(call, db, placeholders, 2) match {
      case Left(err) => Left(err)
      case Right(types) =>
        val first = (call.args(0), types(0))
        val second = (call.args(1), types(1))
        result((first, second))
    }
  def result: PartialFunction[((Expression,Type),(Expression,Type)),Either[Err,Type]]
}

/** A three-argument function; `result` receives all three arguments with their types. */
trait SQLFunction3 extends SQLFunction {
  def resultType(call: FunctionCallExpression, db: DB, placeholders: Placeholders): Either[Err, Type] =
    getArguments(call, db, placeholders, 3) match {
      case Left(err) => Left(err)
      case Right(types) =>
        val first = (call.args(0), types(0))
        val second = (call.args(1), types(1))
        val third = (call.args(2), types(2))
        result((first, second, third))
    }
  def result: PartialFunction[((Expression,Type),(Expression,Type),(Expression,Type)),Either[Err,Type]]
}
53 |
/**
 * A variadic function: arguments are typed without an arity check and the whole
 * (expression, type) list is handed to `result`.
 *
 * If the call has no arguments and `result` has no case for the empty list, a
 * TypeError("expected argument") is returned instead of letting a MatchError
 * escape. It uses the same message and position as the fixed-arity check in
 * `getArguments`.
 */
trait SQLFunctionX extends SQLFunction {
  def resultType(call: FunctionCallExpression, db: DB, placeholders: Placeholders): Either[Err, Type] =
    getArguments(call, db, placeholders, -1).right.flatMap { types =>
      val typedArgs = call.args.zip(types)
      val typer = result
      if (typedArgs.isEmpty && !typer.isDefinedAt(typedArgs))
        Left(TypeError("expected argument", call.pos + call.name.size + 1))
      else
        typer(typedArgs)
    }
  def result: PartialFunction[List[(Expression,Type)],Either[Err,Type]]
}
61 |
object SQLFunction {

  /** Built-in functions, matched on the lower-case names listed below. */
  def standard: PartialFunction[String,SQLFunction] = {
    case "min" | "max" =>
      // Result has the argument's own type.
      new SQLFunction1 {
        def result = {
          case (_, argType) => Right(argType)
        }
      }
    case "avg" | "sum" =>
      // Only numeric arguments are accepted; the result keeps the argument's type.
      new SQLFunction1 {
        def result = {
          case (_, numeric @ (INTEGER(_) | DECIMAL(_))) => Right(numeric)
          case (arg, _)                                 => Left(TypeError("expected numeric argument", arg.pos))
        }
      }
    case "now" =>
      new SQLFunction0 {
        def result = Right(TIMESTAMP())
      }
    case "concat" =>
      // A string whose nullability follows the first argument.
      new SQLFunctionX {
        def result = {
          case (_, firstType) :: _ => Right(STRING(firstType.nullable))
        }
      }
    case "coalesce" =>
      // Widen the first argument's type with each following one, stopping at the first conflict.
      new SQLFunctionX {
        def result = {
          case (_, firstType) :: others =>
            others.foldLeft[Either[Err, Type]](Right(firstType)) { (acc, next) =>
              acc.right.flatMap { widest =>
                val (expr, exprType) = next
                Utils.parentType(widest, exprType, TypeError(s"expected ${widest.show}, got ${exprType.show}", expr.pos))
              }
            }
        }
      }
    case "count" =>
      new SQLFunctionX {
        override def result = {
          case _ :: _ => Right(INTEGER(false))
        }
      }
  }

}
101 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/Parser.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
3 | import scala.util.parsing.combinator._
4 | import scala.util.parsing.combinator.lexical._
5 | import scala.util.parsing.combinator.syntactical._
6 | import scala.util.parsing.input.CharArrayReader.EofCh
7 |
/** Turns SQL text into a statement AST, or an error describing why it could not. */
trait SQLParser {
  def parseStatement(sql: String): Either[Err, Statement]
}
11 |
object SQL99Parser {

  /** Reserved words: SQL99Parser's lexer emits a lower-cased keyword token for any identifier listed here. */
  val keywords = Set(
    "all", "and", "as", "asc",
    "between", "by",
    "case", "cast", "cross", "cube",
    "desc", "distinct",
    "else", "end", "exists",
    "false", "from", "full",
    "group", "grouping",
    "having",
    "in", "inner", "is",
    "join",
    "left", "like", "limit",
    "not", "null",
    "on", "or", "order", "outer",
    "right", "rollup",
    "select", "sets",
    "then", "true",
    "union", "unknown",
    "when", "where"
  )

  /** Delimiter tokens: every single character of the first string, plus the multi-character symbols. */
  val delimiters: Set[String] =
    "()\"'%&*/+-,.:;<>?[]_|={}^".map(_.toString).toSet ++
      Set("??(", "??)", "<>", ">=", "<=", "||", "->", "=>")

  val comparisonOperators = Set("=", "<>", "<", ">", ">=", "<=")

  val likeOperators = Set("like")

  val orOperators = Set("or")

  val andOperators = Set("and")

  val additionOperators = Set("+", "-")

  val multiplicationOperators = Set("*", "/")

  val unaryOperators = Set("-", "+")

  /** Type names mapped to their type literal; several spellings share one literal. */
  val typeMap: Map[String, TypeLiteral] = Map(
    "timestamp" -> TimestampTypeLiteral,
    "datetime" -> TimestampTypeLiteral,
    "date" -> DateTypeLiteral,
    "boolean" -> BooleanTypeLiteral,
    "varchar" -> VarcharTypeLiteral,
    "integer" -> IntegerTypeLiteral,
    "numeric" -> NumericTypeLiteral,
    "decimal" -> DecimalTypeLiteral,
    "real" -> RealTypeLiteral
  )
}
60 |
/**
 * Parser for the SQL-99 subset supported by vizsql.
 *
 * Tokenisation is done by [[SQLLexical]]; the grammar is written with Scala parser combinators.
 * Alternatives are ordered (the first successful branch wins), so their order matters.
 * Expression operators are layered by `precedenceOrder`, from the tightest-binding level
 * (`simpleExpr`) to the loosest (`or`).
 */
class SQL99Parser extends SQLParser with TokenParsers with PackratParsers {

  type Tokens = SQLTokens

  /** Token kinds produced by [[SQLLexical]]. */
  trait SQLTokens extends token.Tokens {
    case class Keyword(chars: String) extends Token
    case class Identifier(chars: String) extends Token
    case class IntegerLit(chars: String) extends Token
    case class DecimalLit(chars: String) extends Token
    case class StringLit(chars: String) extends Token
    case class Delimiter(chars: String) extends Token
  }

  /**
   * Lexer: skips whitespace, block comments and `--` line comments, and recognises keywords
   * (case-insensitively; keyword tokens are lower-cased), identifiers, numbers, single-quoted
   * string literals, double-quoted identifiers and delimiters.
   */
  class SQLLexical extends Lexical with SQLTokens {

    lazy val whitespace = rep(whitespaceChar | blockComment | lineComment)

    val keywords = SQL99Parser.keywords
    val delimiters = SQL99Parser.delimiters

    // A `*` in the comment body is consumed only when it is not the start of the closing `*/`,
    // so comments ending in `**/` are terminated correctly (the previous `'*' ~ chrExcept('/')`
    // swallowed the star of the terminator).
    lazy val blockComment = '/' ~ '*' ~ rep(chrExcept('*', EofCh) | ('*' <~ guard(chrExcept('/', EofCh)))) ~ '*' ~ '/'
    lazy val lineComment = '-' ~ '-' ~ rep(chrExcept('\n', '\r', EofCh))

    // Delimiters are tried in descending lexicographic order, so every delimiter is attempted
    // before any of its proper prefixes (e.g. `<=` before `<`).
    lazy val delimiter = {
      delimiters.toList.sorted.map(s => accept(s.toList) ^^^ Delimiter(s)).foldRight(
        failure("no matching delimiter"): Parser[Token]
      )((x, y) => y | x)
    }

    override lazy val letter = elem("letter", _.isLetter)

    // Words matching a keyword (case-insensitively) become lower-cased Keyword tokens;
    // other words become identifiers with their original case.
    lazy val identifierOrKeyword = letter ~ rep(letter | digit | '_') ^^ {
      case first ~ rest =>
        val chars = first :: rest mkString ""
        if(keywords.contains(chars.toLowerCase)) Keyword(chars.toLowerCase) else Identifier(chars)
    }

    // Never matches here; presumably an extension point for dialect lexers (it is a `def`).
    def customToken: Parser[Token] = elem("should not exist", _ => false) ^^ { _ => sys.error("custom token should not exist")}

    // Exponent suffix such as `e10`, `E-3` or `e+5`.
    lazy val exp = (accept('e') | accept('E')) ~ opt(accept('-') | accept('+')) ~ rep1(digit) ^^ { case e ~ n ~ d => e.toString + n.mkString + d.mkString}

    // Body of a token enclosed in `delimiter`; a doubled delimiter stands for one literal
    // delimiter character. Line breaks are not allowed inside.
    def quoted(delimiter: Char) =
      delimiter ~> rep(
        ( (delimiter ~ delimiter) ^^^ delimiter
        | chrExcept(delimiter, '\n', EofCh)
        )
      ) <~ delimiter

    lazy val token =
      ( customToken
      | identifierOrKeyword
      | rep1(digit) ~ ('.' ~> rep(digit)) ~ opt(exp) ^^ { case i ~ d ~ mE => DecimalLit(i.mkString + "." + d.mkString + mE.mkString) }
      | rep1(digit) ~ exp ^^ { case i ~ e => DecimalLit(i.mkString + e) }
      | '.' ~> rep1(digit) ~ opt(exp) ^^ { case d ~ mE => DecimalLit("." + d.mkString + mE.mkString) }
      | rep1(digit) ^^ { case i => IntegerLit(i.mkString) }
      | quoted('\'') ^^ { case chars => StringLit(chars mkString "") }
      | quoted('\"') ^^ { case chars => Identifier(chars mkString "") }
      | EofCh ^^^ EOF
      | '\'' ~> failure("unclosed string literal")
      | '\"' ~> failure("unclosed identifier")
      | delimiter
      | failure("illegal character")
      )

  }

  override val lexical = new SQLLexical

  // Lets plain string literals ("select", "(", ...) stand for a parser of the matching keyword or
  // delimiter token. A string that is neither is a grammar bug and raises an error as soon as the
  // rule that uses it is built.
  implicit def stringLiteralToKeywordOrDelimiter(chars: String) = {
    if(lexical.keywords.contains(chars) || lexical.delimiters.contains(chars)) {
      (accept(lexical.Keyword(chars)) | accept(lexical.Delimiter(chars))) ^^ (_.chars) withFailureMessage(s"$chars expected")
    }
    else {
      sys.error(s"""!!! Invalid parser definition: $chars is not a valid SQL keyword or delimiter""")
    }
  }

  // --

  lazy val ident =
    elem("ident", _.isInstanceOf[lexical.Identifier]) ^^ (_.chars)

  lazy val booleanLiteral =
    ( "true" ^^^ TrueLiteral
    | "false" ^^^ FalseLiteral
    | "unknown" ^^^ UnknownLiteral
    )

  lazy val nullLiteral =
    "null" ^^^ NullLiteral

  lazy val stringLiteral =
    elem("string", _.isInstanceOf[lexical.StringLit]) ^^ { v => StringLiteral(v.chars) }

  // Integer tokens that do not fit in a Long are rejected as an ordinary parse failure instead
  // of letting `toLong` throw a NumberFormatException out of the parser.
  lazy val integerLiteral =
    elem("integer", t => t.isInstanceOf[lexical.IntegerLit] && BigInt(t.chars).isValidLong) ^^ { v => IntegerLiteral(v.chars.toLong) }

  lazy val decimalLiteral =
    elem("decimal", _.isInstanceOf[lexical.DecimalLit]) ^^ { v => DecimalLiteral(v.chars.toDouble) }

  lazy val literal = pos(
    ( decimalLiteral
    | integerLiteral
    | stringLiteral
    | booleanLiteral
    | nullLiteral
    )
  )

  // Case-insensitive lookup of a type name in `typeMap`.
  lazy val typeLiteral: Parser[TypeLiteral] =
    elem("type literal", { t => typeMap.contains(t.chars.toLowerCase) }) ^^ { t => typeMap(t.chars.toLowerCase) }

  val typeMap = SQL99Parser.typeMap

  // `schema.table.column`, `table.column` or bare `column`.
  lazy val column = pos(
    ( ident ~ "." ~ ident ~ "." ~ ident ^^ { case s ~ _ ~ t ~ _ ~ c => ColumnIdent(c, Some(TableIdent(t, Some(s)))) }
    | ident ~ "." ~ ident ^^ { case t ~ _ ~ c => ColumnIdent(c, Some(TableIdent(t, None))) }
    | ident ^^ { case c => ColumnIdent(c, None) }
    )
  )

  // `schema.table` or bare `table`.
  lazy val table =
    ( ident ~ "." ~ ident ^^ { case s ~ _ ~ t => TableIdent(t, Some(s)) }
    | ident ^^ { case t => TableIdent(t, None) }
    )

  // Function names are lower-cased.
  lazy val function =
    ident ~ ("(" ~> opt(distinct) ~ repsep(expr, ",") <~ ")") ^^ { case n ~ (d ~ a) => FunctionCallExpression(n.toLowerCase, d, a) }

  lazy val countStar =
    elem("count", _.chars.toLowerCase == "count") ~ "(" ~ "*" ~ ")" ^^^ CountStarExpression

  // Each operator level below takes the next tighter level (`precExpr`) and builds on top of it.

  lazy val or = (precExpr: Parser[Expression]) =>
    precExpr *
    orOperators.map { op => op ^^^ OrExpression.operator(op) _ }.reduce(_ | _)

  val orOperators = SQL99Parser.orOperators

  lazy val and = (precExpr: Parser[Expression]) =>
    precExpr *
    andOperators.map { op => op ^^^ AndExpression.operator(op) _ }.reduce(_ | _)

  val andOperators = SQL99Parser.andOperators

  lazy val not = (precExpr: Parser[Expression]) => {
    def thisExpr: Parser[Expression] =
      ( "not" ~> thisExpr ^^ NotExpression
      | precExpr
      )
    thisExpr
  }

  lazy val exists = (precExpr: Parser[Expression]) =>
    ( "exists" ~> "(" ~> select <~ ")" ^^ ExistsExpression
    | precExpr
    )

  lazy val comparator = (precExpr: Parser[Expression]) =>
    precExpr *
    comparisonOperators.map { op => op ^^^ ComparisonExpression.operator(op)_ }.reduce(_ | _)

  val comparisonOperators = SQL99Parser.comparisonOperators

  lazy val like = (precExpr: Parser[Expression]) =>
    precExpr *
    likeOperators.map { op =>
      opt("not") <~ op ^^ { case o => LikeExpression.operator(op, o.isDefined)_ }
    }.reduce(_ | _)

  val likeOperators = SQL99Parser.likeOperators

  lazy val limit = "limit" ~> integerLiteral

  lazy val between = (precExpr: Parser[Expression]) =>
    precExpr ~ rep(opt("not") ~ ("between" ~> precExpr ~ ("and" ~> precExpr))) ^^ {
      case l ~ r => r.foldLeft(l) { case (e, n ~ (lb ~ ub)) => IsBetweenExpression(e, n.isDefined, (lb, ub)) }
    }

  // `BETWEEN` whose bounds come from a range placeholder.
  lazy val between0 = (precExpr: Parser[Expression]) =>
    precExpr ~ rep(opt("not") ~ ("between" ~> rangePlaceholder)) ^^ {
      case l ~ r => r.foldLeft(l) { case (e, n ~ p) => IsBetweenExpression0(e, n.isDefined, p) }
    }

  lazy val in = (precExpr: Parser[Expression]) =>
    precExpr ~ rep(opt("not") ~ ("in" ~> ("(" ~> rep1sep(expr, ",") <~ ")"))) ^^ {
      case l ~ r => r.foldLeft(l) { case (e, n ~ l) => IsInExpression(e, n.isDefined, l) }
    }

  // `IN` whose value list comes from a set placeholder.
  lazy val in0 = (precExpr: Parser[Expression]) =>
    precExpr ~ rep(opt("not") ~ ("in" ~> setPlaceholder)) ^^ {
      case l ~ r => r.foldLeft(l) { case (e, n ~ p) => IsInExpression0(e, n.isDefined, p) }
    }

  lazy val is = (precExpr: Parser[Expression]) =>
    precExpr ~ rep("is" ~> opt("not") ~ (booleanLiteral | nullLiteral)) ^^ {
      case l ~ r => r.foldLeft(l) { case (e, n ~ l) => IsExpression(e, n.isDefined, l) }
    }

  lazy val add = (precExpr: Parser[Expression]) =>
    precExpr *
    additionOperators.map { op => op ^^^ MathExpression.operator(op)_ }.reduce(_ | _)

  val additionOperators = SQL99Parser.additionOperators

  lazy val multiply = (precExpr: Parser[Expression]) =>
    precExpr *
    multiplicationOperators.map { op => op ^^^ MathExpression.operator(op)_ }.reduce(_ | _)

  val multiplicationOperators = SQL99Parser.multiplicationOperators

  lazy val unary = (precExpr: Parser[Expression]) =>
    ((unaryOperators.map { op =>
      op ~ precExpr ^^ { case `op` ~ p => UnaryMathExpression(op, p) }
    }: Set[Parser[Expression]]) + precExpr).reduce(_ | _)

  val unaryOperators = SQL99Parser.unaryOperators

  // Optional name and optional `:type` annotation following a placeholder marker.
  lazy val placeholder =
    opt(ident) ~ opt(":" ~> typeLiteral) ^^ { case i ~ t => ExpressionPlaceholder(Placeholder(i), t) }

  lazy val expressionPlaceholder =
    "?" ~> placeholder

  // `?[name)`
  lazy val rangePlaceholder =
    ("?" ~ "[") ~> placeholder <~ ")"

  // `?{name}`
  lazy val setPlaceholder =
    ("?" ~ "{") ~> placeholder <~ "}"

  lazy val cast =
    ("cast" ~ "(") ~> expr ~ ("as" ~> typeLiteral <~ ")") ^^ { case e ~ t => CastExpression(e, t) }

  lazy val caseWhen =
    ("case" ~> opt(expr) ~ when ~ opt("else" ~> expr) <~ "end" ) ^^ { case maybeValue ~ whenList ~ maybeElse => CaseWhenExpression(maybeValue, whenList, maybeElse) }

  lazy val when: Parser[List[(Expression, Expression)]] =
    rep1("when" ~> expr ~ ("then" ~> expr)) ^^ { _.map { case a ~ b => (a, b) } }

  // Atomic expressions; the argument is ignored because this is the innermost level.
  lazy val simpleExpr = (_: Parser[Expression]) =>
    ( literal ^^ LiteralExpression
    | function
    | countStar
    | cast
    | caseWhen
    | column ^^ ColumnExpression
    | "(" ~> select <~ ")" ^^ SubSelectExpression
    | "(" ~> expr <~ ")" ^^ ParenthesedExpression
    | expressionPlaceholder
    )

  // From tightest-binding to loosest-binding.
  lazy val precedenceOrder =
    ( simpleExpr
    :: unary
    :: multiply
    :: add
    :: is
    :: in0
    :: in
    :: between0
    :: between
    :: comparator
    :: like
    :: exists
    :: not
    :: and
    :: or
    :: Nil
    )

  // Nests the levels so that `or` is outermost and `simpleExpr` innermost.
  lazy val expr: PackratParser[Expression] =
    precedenceOrder.reverse.foldRight(failure("Invalid expression"):Parser[Expression])((a,b) => a(pos(b) | failure("expression expected")))

  lazy val starProjection =
    "*" ^^^ AllColumns

  lazy val tableProjection =
    ( (ident ~ "." ~ ident) <~ ("." ~ "*") ^^ { case s ~ _ ~ t => AllTableColumns(TableIdent(t, Some(s))) }
    | ident <~ ("." ~ "*") ^^ { case t => AllTableColumns(TableIdent(t, None)) }
    )

  lazy val exprProjection =
    expr ~ opt(opt("as") ~> (ident | stringLiteral)) ^^ {
      case e ~ a => ExpressionProjection(e, a.collect { case StringLiteral(v) => v; case a: String => a })
    }

  lazy val projections = pos(
    ( starProjection
    | tableProjection
    | exprProjection

    | failure("*, table or expression expected")
    )
  )

  lazy val relations =
    "from" ~> rep1sep(relation, ",")

  lazy val relation: PackratParser[Relation] = pos(
    ( joinRelation
    | singleTableRelation
    | subSelectRelation

    | failure("table, join or subselect expected")
    )
  )

  lazy val singleTableRelation =
    table ~ opt(opt("as") ~> (ident | stringLiteral)) ^^ {
      case t ~ a => SingleTableRelation(t, a.collect { case StringLiteral(v) => v; case a: String => a })
    }

  lazy val subSelectRelation =
    ("(" ~> select <~ ")") ~ (opt("as") ~> ident) ^^ {
      case s ~ a => SubSelectRelation(s, a)
    }

  lazy val join =
    ( opt("inner") ~ "join" ^^^ InnerJoin
    | "left" ~ opt("outer") ~ "join" ^^^ LeftJoin
    | "right" ~ opt("outer") ~ "join" ^^^ RightJoin
    | "full" ~ opt("outer") ~ "join" ^^^ FullJoin
    | "cross" ~ "join" ^^^ CrossJoin
    )

  // Left-recursive; relies on `relation` being a PackratParser.
  lazy val joinRelation =
    relation ~ join ~ relation ~ opt("on" ~> expr) ^^ {
      case l ~ j ~ r ~ o => JoinRelation(l, j, r, o)
    }

  lazy val filters =
    "where" ~> expr

  lazy val groupingSet =
    ( "(" ~ ")" ^^^ GroupingSet(Nil)
    | "(" ~> repsep(expr, ",") <~ ")" ^^ GroupingSet
    )

  lazy val groupingSetOrExpr =
    ( groupingSet ^^ Right.apply
    | expr ^^ Left.apply
    )

  lazy val group =
    ( ("grouping" ~ "sets") ~> ("(" ~> repsep(groupingSet, ",") <~ ")") ^^ GroupByGroupingSets
    | "rollup" ~> ("(" ~> repsep(groupingSetOrExpr, ",") <~ ")") ^^ GroupByRollup
    | "cube" ~> ("(" ~> repsep(groupingSetOrExpr, ",") <~ ")") ^^ GroupByCube
    | expr ^^ GroupByExpression
    )

  lazy val groupBy =
    ("group" ~ "by") ~> repsep(group, ",")

  lazy val having =
    "having" ~> expr

  lazy val sortOrder =
    ( "asc" ^^^ SortASC
    | "desc" ^^^ SortDESC
    )

  lazy val sortExpr =
    expr ~ opt(sortOrder) ^^ { case e ~ o => SortExpression(e, o) }

  lazy val orderBy =
    ("order" ~ "by") ~> repsep(sortExpr, ",")

  lazy val distinct =
    ( "distinct" ^^^ SetDistinct
    | "all" ^^^ SetAll
    )

  lazy val unionSelect: Parser[UnionSelect] =
    select ~ ("union" ~> opt(distinct)) ~ select ^^ {
      case l ~ d ~ r => UnionSelect(l, d, r)
    }

  lazy val simpleSelect: Parser[SimpleSelect] =
    "select" ~> opt(distinct) ~ rep1sep(projections, ",") ~ opt(relations) ~ opt(filters) ~ opt(groupBy) ~ opt(having) ~ opt(orderBy) ~ opt(limit) ^^ {
      case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => SimpleSelect(d, p, r.getOrElse(Nil), f, g.getOrElse(Nil), h, o.getOrElse(Nil), l)
    }

  lazy val select: PackratParser[Select] = pos(
    ( unionSelect
    | simpleSelect
    )
  )

  lazy val statement = pos(
    ( select
    )
  )

  // --

  /** Runs `p` and, on success, records the input offset where it started as the node position. */
  def pos[T <: SQL](p: => Parser[T]) = Parser { in =>
    p(in) match {
      case Success(t, in1) => Success(t.setPos(in.offset), in1)
      case ns: NoSuccess => ns
    }
  }

  /**
   * Parses one statement, optionally terminated by `;`, consuming the whole input.
   *
   * @return the statement, or a `ParsingError` with the failure message and input offset
   */
  def parseStatement(sql: String): Either[Err,Statement] = {
    phrase(statement <~ opt(";"))(new lexical.Scanner(sql)) match {
      case Success(stmt, _) => Right(stmt)
      case NoSuccess(msg, rest) => Left(ParsingError(msg match {
        case "end of input expected" => "end of statement expected"
        case x => x
      }, rest.offset))
    }
  }

}
472 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/Schema.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
/**
 * A group of schemas searched together when resolving unqualified table or column names.
 * Schema names are matched case-insensitively.
 */
case class Schemas(schemas: List[Schema]) {
  /** Schemas keyed by lower-cased name. */
  val index = schemas.map(s => (s.name.toLowerCase, s)).toMap

  /** Looks up a schema by case-insensitive name. */
  def getSchema(name: String): Either[String, Schema] = {
    index.get(name.toLowerCase).map(Right.apply _).getOrElse {
      Left(s"""schema not found $name""")
    }
  }

  /** Finds the single schema holding table `name`; errors when none or several do. */
  def getNonAmbiguousTable(name: String): Either[String, (Schema, Table)] = {
    schemas.flatMap(s => s.getTable(name).right.toOption.map(s -> _)) match {
      case Nil => Left(s"table not found $name")
      case table :: Nil => Right(table)
      case _ => Left(s"ambiguous table $name")
    }
  }

  /**
   * Resolves an unqualified column name across every table of every schema.
   *
   * @return the unique (schema, table, column) holding `name`; `Left("column not found ...")`
   *         when no table has it, `Left("ambiguous column ...")` when more than one does
   *         (within one schema or across schemas)
   */
  def getNonAmbiguousColumn(name: String): Either[String, (Schema, Table, Column)] = {
    // Collect all matches before deciding. Looking only at the first schema's answer reported
    // "not found" whenever the column lived in a later schema.
    val matches = for {
      schema <- schemas
      table <- schema.tables
      column <- table.getColumn(name).right.toOption
    } yield (schema, table, column)
    matches match {
      case Nil => Left(s"column not found $name")
      case col :: Nil => Right(col)
      case _ => Left(s"ambiguous column $name")
    }
  }
}
28 |
/**
 * A database description: its SQL dialect, its schemas, and `view`, an extra set of schemas
 * that is empty by default (the Hive lateral view, for instance, fills it with the tables of
 * its inner relation).
 */
case class DB(dialect: Dialect, schemas: Schemas, view: Schemas = Schemas(Nil)) {

  /** Looks `name` up in the dialect's function table. */
  def function(name: String): Either[String, SQLFunction] =
    dialect.functions.lift(name).toRight(s"unknown function $name")

}
36 |
object DB {
  /** Builds a [[DB]] from in-memory schemas, using the implicit `dialect` in scope. */
  def apply(schemas: List[Schema])(implicit dialect: Dialect): DB = DB(dialect, Schemas(schemas))

  import java.sql.{Connection, ResultSet}

  /**
   * Introspects a JDBC connection and builds a [[DB]] from its column metadata.
   *
   * The dialect is picked from the database product name, falling back to SQL-99. Schema,
   * table and column names are lower-cased. A column is skipped when its schema, table, name
   * or nullability is missing, or when `Type.from` does not recognise its type name.
   *
   * The connection is left open; the metadata result set opened here is closed.
   */
  def apply(connection: Connection): DB = {
    val databaseMetaData = connection.getMetaData()

    val name = databaseMetaData.getDatabaseProductName()
    val version = databaseMetaData.getDatabaseProductVersion()

    // Only the product name is matched for now; the version is kept for per-version dialects.
    val dialect = (name, version) match {
      case ("Vertica Database", _) => vertica.dialect
      case ("HSQL Database Engine", _) => hsqldb.dialect
      case ("H2", _) => h2.dialect
      case ("PostgreSQL", _) => postgresql.dialect
      case _ => sql99.dialect
    }

    // Reads all rows into immutable maps keyed by lower-cased column label, leaving out SQL NULLs.
    // The metadata is fetched once, and the result set is always closed (it previously leaked).
    def consume(rs: ResultSet): List[Map[String, Any]] =
      try {
        val meta = rs.getMetaData
        val labels = (1 to meta.getColumnCount).map(i => i -> meta.getColumnLabel(i).toLowerCase)
        val rows = List.newBuilder[Map[String, Any]]
        while (rs.next()) {
          rows += labels.flatMap { case (i, label) => Option(rs.getObject(i)).map(label -> _) }.toMap
        }
        rows.result()
      } finally {
        rs.close()
      }

    val allColumns = consume(databaseMetaData.getColumns(null, null, null, null))

    def schemaOf(col: Map[String, Any]): Option[String] =
      col.get("table_schem").map(_.toString.toLowerCase)

    // Group once instead of re-filtering every column for each schema; groupBy keeps the
    // original column order inside each group.
    val columnsBySchema = allColumns.groupBy(schemaOf)

    DB(
      dialect,
      Schemas(
        allColumns.flatMap(schemaOf).distinct.map { schemaName =>
          val schemaColumns = columnsBySchema.getOrElse(Some(schemaName), Nil).flatMap { col =>
            for {
              table <- col.get("table_name").map(_.toString.toLowerCase)
              columnName <- col.get("column_name").map(_.toString.toLowerCase)
              nullable <- col.get("is_nullable").map(_.toString.toLowerCase).collect {
                case "yes" => true
                case "no" => false
              }
              typ <- col.get("type_name").map(_.toString.toLowerCase).collect(Type.from(nullable))
            } yield (table, Column(columnName, typ))
          }
          Schema(
            name = schemaName,
            tables = schemaColumns.groupBy(_._1).map {
              case (tableName, columns) => Table(tableName, columns.map(_._2))
            }.toList
          )
        }
      )
    )
  }
}
100 |
/** A named schema; table names are matched case-insensitively. */
case class Schema(name: String, tables: List[Table]) {
  /** Tables keyed by lower-cased name. */
  val index = tables.map(t => t.name.toLowerCase -> t).toMap

  /** Looks up a table of this schema by case-insensitive name. */
  def getTable(name: String): Either[String, Table] =
    index.get(name.toLowerCase).toRight(s"""table not found ${this.name}.$name""")

  /** Finds the single table of this schema that has column `name`; errors when none or several do. */
  def getNonAmbiguousColumn(name: String): Either[String, (Table, Column)] = {
    val candidates = for {
      table <- tables
      column <- table.getColumn(name).right.toOption
    } yield (table, column)
    candidates match {
      case Nil => Left(s"column not found $name")
      case single :: Nil => Right(single)
      case _ => Left(s"ambiguous column $name")
    }
  }

}

/** A named table; column names are matched case-insensitively. */
case class Table(name: String, columns: List[Column]) {
  /** Columns keyed by lower-cased name. */
  val index = columns.map(c => c.name.toLowerCase -> c).toMap

  /** Looks up a column of this table by case-insensitive name. */
  def getColumn(columnName: String): Either[String, Column] =
    index.get(columnName.toLowerCase) match {
      case Some(column) => Right(column)
      case None => Left(s"""no column $columnName in table $name""")
    }
}

/** A column and its SQL type (which carries its nullability). */
case class Column(name: String, typ: Type)
133 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/Show.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
/** Letter-case policy applied to keywords or identifiers when rendering SQL. */
sealed trait Case { def format(str: String): String }

/** Every letter in upper case. */
case object UpperCase extends Case {
  def format(str: String) = str.toUpperCase
}

/** Every letter in lower case. */
case object LowerCase extends Case {
  def format(str: String) = str.toLowerCase
}

/** First character upper-cased, the rest lower-cased (e.g. `Select`). */
case object CamelCase extends Case {
  def format(str: String) =
    if (str.isEmpty) "" else str.substring(0, 1).toUpperCase + str.substring(1).toLowerCase
}

/**
 * Rendering options.
 *
 * @param pretty      pretty-printing flag (NOTE(review): not read by `Show.toSQL`; presumably
 *                    consulted by `show` implementations — confirm)
 * @param keywords    casing applied to keywords
 * @param identifiers casing applied to identifiers
 */
case class Style(pretty: Boolean, keywords: Case, identifiers: Case)
object Style {
  /** Pretty, upper-case keywords, lower-case identifiers; the implicit default. */
  implicit val default: Style = Style(pretty = true, keywords = UpperCase, identifiers = LowerCase)
  /** Same casing as `default`, without pretty printing. */
  val compact: Style = Style(pretty = false, keywords = UpperCase, identifiers = LowerCase)
}
15 |
/**
 * A renderable SQL fragment, composed with the combinators below and turned into text by
 * `toSQL`.
 */
sealed trait Show {
  import Show.{Group, Indented, NewLine, Whitespace}

  /** `this` immediately followed by `o`. */
  def ~(o: Show): Group = Group(List(this, o))
  /** `this` immediately followed by `o` when present, else `this` alone. */
  def ~(o: Option[Show]): Show = o.fold(this: Show)(next => Group(List(this, next)))
  /** `this`, a space, then `o`. */
  def ~-(o: Show): Group = Group(List(this, Whitespace, o))
  /** `this`, a space, then `o` when present, else `this` alone. */
  def ~-(o: Option[Show]): Show = o.fold(this: Show)(next => Group(List(this, Whitespace, next)))
  /** `this`, a line break, then `o`. */
  def ~/(o: Show): Group = Group(List(this, NewLine, o))
  /** `this`, a line break, then `o` when present, else `this` alone. */
  def ~/(o: Option[Show]): Show = o.fold(this: Show)(next => Group(List(this, NewLine, next)))
  /** `this` followed by `o`, indented one level on new lines. */
  def ~|(o: Show*): Group = Group(List(this, Indented(Group(o.toList))))
  /** `this` followed by `o` indented when present, else `this` alone. */
  def ~|(o: Option[Show]): Show = o.fold(this: Show)(next => Group(List(this, Indented(Group(List(next))))))

  /** Renders with placeholders left as `?name`. */
  def toSQL(style: Style): String = Show.toSQL(this, style, None) match {
    case Right(sql) => sql
    case Left(_) => sys.error("WAT?")
  }
  /** Renders with placeholders replaced by the given named/anonymous parameter values. */
  def toSQL(style: Style, placeholders: Placeholders, namedParameters: Map[String,Any], anonymousParameters: List[Any]) =
    Show.toSQL(this, style, Some((placeholders, namedParameters, anonymousParameters)))
}
31 |
object Show {
  /**
   * Renders `show` to SQL text using `style`.
   *
   * With `placeholders = None`, each [[Parameter]] is printed as `?name` (or `?` when unnamed).
   * With `Some((declared, named, anonymous))`, each parameter is replaced by a literal:
   * - named placeholders take their value from `named`;
   * - unnamed ones consume `anonymous` from left to right;
   * - values are converted with `Type.convertParam`, using the type recorded in `declared`.
   *
   * @return the rendered, trimmed SQL, or the first parameter error: a missing value, a
   *         placeholder without a resolved type, or a value that `Type.convertParam` rejects
   */
  def toSQL(show: Show, style: Style, placeholders: Option[(Placeholders,Map[String,Any],List[Any])]): Either[Err,String] = {
    val INDENT = " "
    // Aborts rendering as soon as a parameter cannot be filled.
    case class MissingParameter(err: Err) extends Throwable

    // If the output ends with a line break followed only by indentation, drop them, so that
    // consecutive line breaks collapse into one.
    def trimRight(parts: Vector[String]): Vector[String] = {
      val lastIndex = parts.lastIndexWhere(_ != INDENT)
      if (lastIndex >= 0 && parts(lastIndex) == "\n") parts.take(lastIndex) else parts
    }

    // Position of the next anonymous parameter to consume.
    var pIndex = 0

    // Parts are accumulated in a Vector: appending to a List for every token was quadratic.
    def print(x: Show, indent: Int, parts: Vector[String]): Vector[String] = x match {
      case Keyword(k) => parts :+ style.keywords.format(k)
      case Identifier(i) => parts :+ style.identifiers.format(i)
      case Text(chars) => parts :+ chars
      case Whitespace => parts :+ " "
      case NewLine =>
        (trimRight(parts) :+ "\n") ++ Vector.fill(indent)(INDENT)
      case Indented(group) =>
        print(NewLine, indent, trimRight(
          print(group.copy(items = NewLine :: group.items), indent + 1, parts)
        ))
      case Group(items) =>
        items.foldLeft(parts) {
          case (acc, item) => print(item, indent, acc)
        }
      case Parameter(placeholder) =>
        placeholders match {
          case None =>
            parts :+ s"""?${placeholder.name.getOrElse("")}"""
          case Some((declared, namedParameters, anonymousParameters)) =>
            // Renders `value` as a SQL literal of the placeholder's resolved type.
            def param(paramType: Option[Type], value: Any): String = {
              val p = paramType.getOrElse {
                throw new MissingParameter(ParameterError(
                  "unresolved parameter", placeholder.pos
                ))
              }
              def rec(value: FilledParameter): String = value match {
                case StringParameter(s) => s"'${s.replace("'", "''")}'"
                case IntegerParameter(i) => i.toString
                case DateTimeParameter(t) => s"'${t.replace("'", "''")}'"
                case SetParameter(set) => "(" + set.map(rec).mkString(", ") + ")"
                case RangeParameter(low, high) => rec(low) + " AND " + rec(high)
                case other => throw new IllegalArgumentException(other.getClass.toString)
              }
              rec(Type.convertParam(p, value))
            }
            placeholder.name match {
              case Some(key) if namedParameters.contains(key) =>
                parts :+ param(declared.find(_._1.name.exists(_ == key)).map(_._2), namedParameters(key))
              case None if pIndex < anonymousParameters.size =>
                val s = param(declared.filterNot(_._1.name.isDefined).drop(pIndex).headOption.map(_._2), anonymousParameters(pIndex))
                pIndex = pIndex + 1
                parts :+ s
              case _ =>
                throw new MissingParameter(ParameterError(
                  s"""missing value for parameter ${placeholder.name.getOrElse("")}""", placeholder.pos
                ))
            }
        }
    }

    try {
      Right(print(show, 0, Vector.empty).mkString.trim)
    } catch {
      case MissingParameter(err) => Left(err)
      // Type.convertParam reports unconvertible values with its own exception class; without this
      // case that exception escaped toSQL instead of becoming a Left.
      case Type.MissingParameter(err) => Left(err)
    }
  }

  /** A keyword, cased with `Style.keywords`. */
  case class Keyword(keyword: String) extends Show
  /** An identifier, cased with `Style.identifiers`. */
  case class Identifier(identifier: String) extends Show
  /** Verbatim text. */
  case class Text(chars: String) extends Show
  /** A group rendered on new lines, one indentation level deeper. */
  case class Indented(group: Group) extends Show
  /** A placeholder: printed as `?name`, or replaced by a parameter value. */
  case class Parameter(placeholder: Placeholder) extends Show
  /** Fragments rendered one after another. */
  case class Group(items: List[Show]) extends Show
  /** A single space. */
  case object Whitespace extends Show
  /** A line break followed by the current indentation. */
  case object NewLine extends Show

  def line = NewLine
  def nest(show: Show*) = Indented(Group(show.toList))
  def keyword(str: String) = Keyword(str)
  def ident(str: String) = Identifier(str)
  /** Concatenates `items`, putting `separator` between consecutive items. */
  def join(items: List[Show], separator: Show) = {
    Group(items.dropRight(1).flatMap(_ :: separator :: Nil) ++ items.lastOption.map(_ :: Nil).getOrElse(Nil))
  }
  def ~?(placeholder: Placeholder) = Parameter(placeholder)

  implicit def toText(str: String) = Text(str)
}
120 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/Types.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
/**
 * A SQL value type as tracked by vizsql's type checker.
 *
 * Each type carries a nullability flag. `canBeCastTo` tells whether a value of this type is
 * accepted where `other` is expected, and `show` gives its lower-case display name.
 */
trait Type {
  val nullable: Boolean
  def withNullable(nullable: Boolean): this.type
  def canBeCastTo(other: Type): Boolean
  def show: String
}

/** Integer values; castable to INTEGER and DECIMAL. */
case class INTEGER(nullable: Boolean = false) extends Type {
  def withNullable(nullable: Boolean = nullable) = copy(nullable = nullable).asInstanceOf[this.type]
  def canBeCastTo(other: Type): Boolean = PartialFunction.cond(other) { case _: DECIMAL | _: INTEGER => true }
  def show = "integer"
}
/** Decimal values; castable only to DECIMAL. */
case class DECIMAL(nullable: Boolean = false) extends Type {
  def withNullable(nullable: Boolean = nullable) = copy(nullable = nullable).asInstanceOf[this.type]
  def canBeCastTo(other: Type): Boolean = PartialFunction.cond(other) { case _: DECIMAL => true }
  def show = "decimal"
}
/** Boolean values; castable only to BOOLEAN. */
case class BOOLEAN(nullable: Boolean = false) extends Type {
  def withNullable(nullable: Boolean = nullable) = copy(nullable = nullable).asInstanceOf[this.type]
  def canBeCastTo(other: Type): Boolean = PartialFunction.cond(other) { case _: BOOLEAN => true }
  def show = "boolean"
}
/** Character strings; castable only to STRING. */
case class STRING(nullable: Boolean = false) extends Type {
  def withNullable(nullable: Boolean = nullable) = copy(nullable = nullable).asInstanceOf[this.type]
  def canBeCastTo(other: Type): Boolean = PartialFunction.cond(other) { case _: STRING => true }
  def show = "string"
}
/** Timestamps; castable only to TIMESTAMP. */
case class TIMESTAMP(nullable: Boolean = false) extends Type {
  def withNullable(nullable: Boolean = nullable) = copy(nullable = nullable).asInstanceOf[this.type]
  def canBeCastTo(other: Type): Boolean = PartialFunction.cond(other) { case _: TIMESTAMP => true }
  def show = "timestamp"
}
/** Dates; castable to DATE and TIMESTAMP. */
case class DATE(nullable: Boolean = false) extends Type {
  def withNullable(nullable: Boolean = nullable) = copy(nullable = nullable).asInstanceOf[this.type]
  def canBeCastTo(other: Type): Boolean = PartialFunction.cond(other) { case _: DATE | _: TIMESTAMP => true }
  def show = "date"
}

/** Type of the NULL literal: always nullable and castable to any type. */
case object NULL extends Type {
  val nullable = true
  def withNullable(nullable: Boolean) = this
  def canBeCastTo(other: Type): Boolean = true
  def show = "null"
}

/** A set of values of type `of`; nullability is delegated to the element type. */
case class SET(of: Type) extends Type {
  val nullable = of.nullable
  def withNullable(nullable: Boolean = nullable) = copy(of = of.withNullable(nullable)).asInstanceOf[this.type]
  def canBeCastTo(other: Type): Boolean = PartialFunction.cond(other) { case SET(target) => of.canBeCastTo(target) }
  def show = "set(" + of.show + ")"
}

/** A range of values of type `of`; nullability is delegated to the bound type. */
case class RANGE(of: Type) extends Type {
  val nullable = of.nullable
  def withNullable(nullable: Boolean = nullable) = copy(of = of.withNullable(nullable)).asInstanceOf[this.type]
  def canBeCastTo(other: Type): Boolean = PartialFunction.cond(other) { case RANGE(target) => of.canBeCastTo(target) }
  def show = "range(" + of.show + ")"
}
85 |
object Type {
  /** Thrown by [[convertParam]] when a value cannot be converted to the requested SQL type. */
  case class MissingParameter(err: Err) extends Throwable

  /**
   * Converts each named parameter using the type of the matching placeholder in `query`.
   * Parameters for which `placeholders.typeOf` yields no type are skipped.
   *
   * @throws MissingParameter when a value cannot be converted
   */
  def convertParamList(parameters: Map[String, Any], query : Query) = query.placeholders.right.map { placeholders =>
    for {
      (name, value) <- parameters
      typ <- placeholders.typeOf(Placeholder(Some(name)))
    } yield {
      name -> convertParam(typ, value)
    }
  }

  // Raises the "unexpected value" error for `value`; tolerates null, which used to cause an NPE.
  private def unexpected(value: Any, sqlType: String): Nothing = {
    val className = value match {
      case null => "null"
      case v => v.getClass.getName
    }
    throw new MissingParameter(ParameterError(
      s"unexpected value $value ($className) for an SQL $sqlType parameter", -1
    ))
  }

  /**
   * Converts a raw parameter value to a [[FilledParameter]] for `pType`.
   *
   * For scalar types, a non-empty list is reduced to its first element. SET accepts a `Seq`
   * or a single value. RANGE accepts a pair or a list of at least two elements.
   *
   * @throws MissingParameter when the value does not fit the type, or when `pType` has no
   *                          parameter representation
   */
  def convertParam(pType : Type, value : Any) : FilledParameter = pType match {
    case STRING(_) => value match {
      case x :: _ => convertParam(pType, x)
      case s: String => StringParameter(s)
      case n: Number => StringParameter(n.toString)
      case b: Boolean => StringParameter(b.toString)
      case c: Char => StringParameter(c.toString)
      case x => unexpected(x, "STRING")
    }
    case INTEGER(_) => value match {
      case x :: _ => convertParam(pType, x)
      case x: Int => IntegerParameter(x)
      // Only longs that fit are accepted: `toInt` would silently wrap larger values.
      case x: Long if x >= Int.MinValue && x <= Int.MaxValue => IntegerParameter(x.toInt)
      case x: String => try {
        IntegerParameter(x.toInt)
      } catch {
        case _: NumberFormatException => unexpected(x, "INTEGER")
      }
      case x => unexpected(x, "INTEGER")
    }
    case DATE(_) => value match {
      case x :: _ => convertParam(pType, x)
      case x: String => DateTimeParameter(x)
      case x => unexpected(x, "DATE")
    }
    case TIMESTAMP(_) => value match {
      case x :: _ => convertParam(pType, x)
      case x: String => DateTimeParameter(x)
      case x => unexpected(x, "TIMESTAMP")
    }
    case SET(t) => value match {
      case x: Seq[_] => SetParameter(x.map(x =>convertParam(t, x)).toSet)
      case x => SetParameter(Set(convertParam(t, x)))
    }
    case RANGE(t) => value match {
      case (a,b) => RangeParameter(convertParam(t, a) ,convertParam(t, b))
      case a :: b :: _ => RangeParameter(convertParam(t, a), convertParam(t, b))
      case x => unexpected(x, "RANGE")
    }
    // The remaining types (e.g. DECIMAL, BOOLEAN, NULL) have no FilledParameter representation;
    // report a parameter error instead of failing with a MatchError.
    case other => throw new MissingParameter(ParameterError(
      s"unsupported parameter type ${other.show}", -1
    ))
  }

  /** Maps a lower-cased JDBC type name to a [[Type]] with the given nullability. */
  def from(nullable: Boolean): PartialFunction[String, Type] = {
    case "varchar" | "char" | "bpchar" | "string" => STRING(nullable)
    case x if x.contains("varchar") => STRING(nullable)
    case "int4" | "integer" => INTEGER(nullable)
    case "float" | "float4" | "numeric" | "decimal" => DECIMAL(nullable)
    case "timestamp" | "timestamptz" | "timestamp with time zone" => TIMESTAMP(nullable)
    case "date" => DATE(nullable)
    case "boolean" => BOOLEAN(nullable)
  }
}
/** A parameter value already converted by `Type.convertParam`, ready to be rendered into SQL. */
trait FilledParameter

/** A single string value. */
case class StringParameter(value : String) extends FilledParameter

/** A single integer value. */
case class IntegerParameter(value : Int) extends FilledParameter

/** A date or timestamp, kept in its textual form. */
case class DateTimeParameter(value : String) extends FilledParameter

/** A set of converted values; `Show` renders it as a parenthesised, comma-separated list. */
case class SetParameter(value : Set[FilledParameter]) extends FilledParameter

/** A pair of converted bounds; `Show` renders it as `low AND high`. The upper bound is named `High`. */
case class RangeParameter(low : FilledParameter, High : FilledParameter) extends FilledParameter
167 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/VizSQL.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
object VizSQL {

  /**
   * Parses `sql` with the dialect of `db`.
   *
   * @return the [[Query]], the parser's error, or `SQLError("select expected")` when the
   *         statement is not a simple SELECT
   */
  def parseQuery(sql: String, db: DB): Either[Err,Query] = {
    val statement = db.dialect.parser.parseStatement(sql)
    statement.right.flatMap {
      case select: SimpleSelect => Right(Query(sql, select, db))
      case other => Left(SQLError("select expected", other.pos))
    }
  }

  /**
   * Parses `sql` as an OLAP query. A query that parses but reports an error (see
   * `Query.error`) is returned as that error.
   */
  def parseOlapQuery(sql: String, db: DB): Either[Err,OlapQuery] =
    parseQuery(sql, db).right.flatMap { query =>
      val olap = OlapQuery(query)
      olap.query.error.toLeft(olap)
    }

}
18 |
/** A parsed SELECT statement bound to the database description it is checked against. */
case class Query(sql: String, select: SimpleSelect, db: DB) {

  /** `SimpleSelect.getTables` evaluated against `db`. */
  def tables = select.getTables(db)
  /** `SimpleSelect.getColumns` evaluated against `db`. */
  def columns = select.getColumns(db)
  /** `SimpleSelect.getPlaceholders` evaluated against `db`. */
  def placeholders = select.getPlaceholders(db)
  /** `SimpleSelect.getQueryView` evaluated against `db`. */
  def queryView = select.getQueryView(db)

  /**
   * The first failure among `tables`, `columns` and `placeholders`, checked in that order,
   * each only when the previous succeeded; `None` when all succeed.
   */
  def error = tables.left.toOption orElse columns.left.toOption orElse placeholders.left.toOption
}
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/h2/H2Dialect.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
object h2 {
  /** The H2 dialect: the generic SQL-99 parser with the standard function set. */
  implicit val dialect = new Dialect {
    override def toString = "H2"
    lazy val functions = SQLFunction.standard
    lazy val parser = new SQL99Parser
  }
}
10 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/hive/HiveAST.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.hive
2 |
3 | import com.criteo.vizatra.vizsql._
4 | import com.criteo.vizatra.vizsql.Show._
5 |
/**
 * `inner LATERAL VIEW [OUTER] udtf(...) tableAlias AS col1, col2, ...`
 *
 * Exposes every table of `inner` plus one generated table named `tableAlias` whose
 * columns are the UDTF's output columns, named after `columnAliases`.
 */
case class LateralView(inner: Relation, explodeFunction: FunctionCallExpression, tableAlias: String, columnAliases: List[String]) extends Relation {

  /**
   * Tables visible through this relation: those of `inner` followed by
   * `(None, Table(tableAlias, ...))`.
   *
   * The UDTF is typed against a copy of `db` whose view holds only the inner tables,
   * grouped by schema name ("" for unqualified tables). Fails with a ParsingError when
   * the function does not produce a UDTF result or the alias count does not match the
   * number of generated columns.
   */
  def getTables(db: DB) =
    inner.getTables(db).right.flatMap { innerTables =>
      inner.getPlaceholders(db).right.flatMap { placeholders =>
        val innerSchemas = innerTables.groupBy(_._1).toList.map { case (schemaName, entries) =>
          Schema(schemaName.getOrElse(""), entries.map(_._2))
        }
        val innerDb = db.copy(view = Schemas(innerSchemas))
        explodeFunction.resultType(innerDb, placeholders).right.flatMap {
          case HiveUDTFResult(types) if types.length == columnAliases.length =>
            val generatedColumns = columnAliases.zip(types).map { case (alias, typ) => Column(alias, typ) }
            Right(innerTables ++ Seq((None, Table(tableAlias, generatedColumns))))
          case HiveUDTFResult(types) =>
            Left(ParsingError(s"Expected ${types.size} aliases, got ${columnAliases.size}", pos))
          case _ =>
            Left(ParsingError(s"Expected a UDTF, got ${explodeFunction.show}", pos))
        }
      }
    }

  // Not implemented: calling it throws scala.NotImplementedError.
  def visit = ???

  /** Only the inner relation can contain placeholders. */
  def getPlaceholders(db: DB) = inner.getPlaceholders(db)

  // Not implemented: calling it throws scala.NotImplementedError.
  def show = ???
}
37 |
/** Hive subscript access: `map[key]` on a map, or `array[index]` on an array. */
case class MapOrArrayAccessExpression(map: Expression, key: Expression) extends Expression {

  /** Placeholders of the collection expression followed by those of the key expression. */
  def getPlaceholders(db: DB, expectedType: Option[Type]) = for {
    mapPlaceholders <- map.getPlaceholders(db, None).right
    keyPlaceholders <- key.getPlaceholders(db, None).right
  } yield mapPlaceholders ++ keyPlaceholders

  // Not implemented: calling it throws scala.NotImplementedError.
  def visit = ???

  /**
   * Type of the access:
   *  - `map<K,V>[key]` is `V` when `K.canBeCastTo(keyType)`;
   *  - `array<T>[index]` is `T` when the index is an INTEGER;
   *  - anything else is a TypeError at this expression's position.
   */
  def resultType(db: DB, placeholders: Placeholders) = (for {
    mapType <- map.resultType(db, placeholders).right
    keyType <- key.resultType(db, placeholders).right
  } yield (mapType, keyType) match {
    case (HiveMap(k, x), _) if k.canBeCastTo(keyType) => Right(x)
    // `k` is the map's declared key type (expected); `keyType` is the type actually supplied.
    case (HiveMap(k, _), _) => Left(TypeError(s"Expected key type ${k.show}, got ${keyType.show}", pos))
    case (HiveArray(x), INTEGER(_)) => Right(x)
    case (HiveArray(_), x) => Left(TypeError(s"Expected integer index, got ${x.show}", pos))
    case (x, _) => Left(TypeError(s"Expected map or array, got ${x.show}", pos))
  }).joinRight

  def show = map.show ~ "[" ~ key.show ~ "]"
}
59 |
/**
 * Hive struct field access `struct.field` (field name matched case-insensitively).
 * On an `array<struct<...>>` it yields an array of the field's type.
 */
case class StructAccessExpr(struct: Expression, field: String) extends Expression {

  /** Only the struct expression can hold placeholders; no expected type is propagated. */
  def getPlaceholders(db: DB, expectedType: Option[Type]) = struct.getPlaceholders(db, None)

  // Not implemented: calling it throws scala.NotImplementedError.
  def visit = ???

  /**
   * Type of the field. Fails with a SchemaError if the struct has no such field, and
   * with a TypeError if `struct` is neither a struct nor an array of structs.
   */
  def resultType(db: DB, placeholders: Placeholders) = {
    def fieldIn(cols: List[Column]) =
      cols.find(_.name.toLowerCase == field.toLowerCase).toRight(SchemaError(s"Field $field not found", pos))

    struct.resultType(db, placeholders).right.flatMap {
      case HiveStruct(cols) => fieldIn(cols).right.map(_.typ)
      case HiveArray(HiveStruct(cols)) => fieldIn(cols).right.map(col => HiveArray(col.typ))
      case other => Left(TypeError(s"Expected struct, got ${other.show}", pos))
    }
  }

  def show = struct.show ~ "." ~ field

  override def toColumnName = field
}
79 |
/**
 * A dotted identifier that Hive may read either as a (schema.)(table.)column reference
 * or as struct field access on a shorter column reference.
 */
case class ColumnOrStructAccessExpression(column: ColumnIdent) extends Expression {

  /** A plain identifier contains no placeholders. */
  def getPlaceholders(db: DB, expectedType: Option[Type]) = Right(Placeholders())

  /**
   * Resolves the identifier, trying the column interpretation first:
   *  - `a.b.c`: column `c` of table `b` in schema `a`, else field `c` of `a.b`;
   *  - `a.b`: column `b` of the unambiguous table `a`, else field `b` of column `a`;
   *  - `a`: the unambiguous column `a` (lookup failures become a SchemaError).
   */
  def resultType(db: DB, placeholders: Placeholders) = column match {
    case ColumnIdent(c3, Some(TableIdent(c2, Some(c1)))) =>
      val asSchemaTableColumn = for {
        schema <- db.view.getSchema(c1).right
        table <- schema.getTable(c2).right
        col <- table.getColumn(c3).right
      } yield col.typ
      asSchemaTableColumn.left.flatMap(_ => fieldOf(ColumnIdent(c2, Some(TableIdent(c1, None))), c3, db, placeholders))

    case ColumnIdent(c2, Some(TableIdent(c1, None))) =>
      val asTableColumn = for {
        table <- db.view.getNonAmbiguousTable(c1).right.map(_._2).right
        col <- table.getColumn(c2).right
      } yield col.typ
      asTableColumn.left.flatMap(_ => fieldOf(ColumnIdent(c1, None), c2, db, placeholders))

    case ColumnIdent(c, None) =>
      db.view.getNonAmbiguousColumn(c).right.map(_._3.typ).left.map(SchemaError(_, column.pos))
  }

  /**
   * Types `prefix.field` as a struct access. Both synthetic nodes take this expression's
   * position, so errors point at the original identifier.
   */
  private def fieldOf(prefix: ColumnIdent, field: String, db: DB, placeholders: Placeholders) = {
    val base = ColumnOrStructAccessExpression(prefix)
    base.pos = pos
    val access = StructAccessExpr(base, field)
    access.pos = pos
    access.resultType(db, placeholders)
  }

  def show = column.show

  def visit = Nil

  override def toColumnName = column.name
}
122 |
/** Hive `LEFT SEMI JOIN`, rendered as three space-separated keywords. */
case object LeftSemiJoin extends Join {
  def show = List(keyword("left"), keyword("semi"), keyword("join")).reduceLeft(_ ~- _)
}
126 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/hive/HiveDialect.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.hive
2 |
3 | import com.criteo.vizatra.vizsql._
4 |
5 | import scala.util.parsing.input.CharArrayReader.EofCh
6 |
/**
 * Hive dialect.
 *
 * The parser extends the SQL-99 grammar with Hive syntax: backtick-quoted identifiers,
 * double-quoted strings, 'NaN', extra operators (`!=`, `<=>`, `==`, `||`, `&&`, `%`,
 * bitwise operators, `rlike`/`regexp`), subscript (`m[k]`) and struct (`s.f`) access,
 * LATERAL VIEW, TABLESAMPLE, LEFT SEMI JOIN and SORT/DISTRIBUTE/CLUSTER BY.
 *
 * @param udfs user-defined functions. They are looked up before the built-ins below,
 *             so they can shadow them.
 */
case class HiveDialect(udfs: Map[String, SQLFunction]) extends Dialect {

  lazy val parser = new SQL99Parser {
    override val lexical = new SQLLexical {
      override val keywords = SQL99Parser.keywords ++ Set(
        "rlike", "regexp", "limit", "lateral", "view", "distribute", "sort", "cluster", "semi",
        "tablesample", "bucket", "out", "of", "percent")
      override val delimiters = SQL99Parser.delimiters ++ Set("!=", "<=>", "||", "&&", "%", "&", "|", "^", "~", "`", "==")
      // `text` -> identifier, "text" -> string literal, 'NaN' -> decimal literal.
      override val customToken =
        ( '`' ~> rep1(chrExcept('`', '\n', EofCh)) <~ '`' ^^ { x => Identifier(x.mkString) }
        | '\"' ~> rep1(chrExcept('\"', '\n', EofCh)) <~ '\"' ^^ { x => StringLit(x.mkString) }
        | '`' ~> failure("unclosed backtick")
        | '\'' ~ 'N' ~ 'a' ~ 'N' ~ '\'' ^^^ DecimalLit("NaN")
        )
    }
    override val comparisonOperators = SQL99Parser.comparisonOperators ++ Set("!=", "<=>", "==")
    override val likeOperators = SQL99Parser.likeOperators ++ Set("rlike", "regexp")
    override val orOperators = SQL99Parser.orOperators + "||"
    override val andOperators = SQL99Parser.andOperators + "&&"
    override val multiplicationOperators = SQL99Parser.multiplicationOperators ++ Set("%", "&", "|", "^")
    override val unaryOperators = SQL99Parser.unaryOperators + "~"
    override val typeMap: Map[String, TypeLiteral] = HiveDialect.typeMap

    // m[k] on maps, a[i] on arrays.
    lazy val mapOrArrayAccessExpr =
      expr ~ ("[" ~> expr <~ "]") ^^ { case m ~ k => MapOrArrayAccessExpression(m, k) }

    // s.field on structs.
    lazy val structAccessExpr =
      expr ~ ("." ~> ident) ^^ { case e ~ f => StructAccessExpr(e, f) }

    // Alternatives are tried in the order listed.
    override lazy val simpleExpr = (_: Parser[Expression]) =>
      ( literal ^^ LiteralExpression
      | structAccessExpr
      | mapOrArrayAccessExpr
      | function
      | countStar
      | cast
      | caseWhen
      | column ^^ ColumnOrStructAccessExpression
      | "(" ~> select <~ ")" ^^ SubSelectExpression
      | "(" ~> expr <~ ")" ^^ ParenthesedExpression
      | expressionPlaceholder
      )

    // relation TABLESAMPLE (...): the sampling clause is parsed but not kept in the AST.
    lazy val tablesampleRelation =
      relation <~ "tablesample" ~ "(" ~ tablesampleExpr ~ ")"

    // BUCKET x OUT OF y [ON expr] | n PERCENT | n followed by a byte-length unit (e.g. 100M).
    // Hive accepts b/B, k/K, m/M and g/G as byte-length suffixes.
    lazy val tablesampleExpr =
      ( "bucket" ~ integerLiteral ~ "out" ~ "of" ~ integerLiteral ~ opt("on" ~ expr)
      | integerLiteral ~ "percent"
      | integerLiteral ~ elem("bytelength", e => List("b", "B", "k", "K", "m", "M", "g", "G").contains(e.chars))
      )

    // relation LATERAL VIEW [OUTER] udtf(...) tableAlias AS col1, ... (OUTER is not recorded).
    lazy val lateralViewRelation =
      relation ~ ("lateral" ~ "view" ~ opt("outer") ~> function) ~ ident ~ ("as" ~> repsep(ident, ",")) ^^ {
        case r ~ f ~ tblAlias ~ colAliases => LateralView(r, f, tblAlias, colAliases)
      }

    override lazy val join =
      ( opt("inner") ~ "join" ^^^ InnerJoin
      | "left" ~ opt("outer") ~ "join" ^^^ LeftJoin
      | "right" ~ opt("outer") ~ "join" ^^^ RightJoin
      | "full" ~ opt("outer") ~ "join" ^^^ FullJoin
      | "cross" ~ "join" ^^^ CrossJoin
      | "left" ~ "semi" ~ "join" ^^^ LeftSemiJoin
      )

    override lazy val relation: PackratParser[Relation] = pos(
      ( lateralViewRelation
      | tablesampleRelation
      | joinRelation
      | singleTableRelation
      | subSelectRelation

      | failure("table, join or subselect expected")
      )
    )

    lazy val sortBy =
      "sort" ~ "by" ~> repsep(sortExpr, ",")

    lazy val distributeBy =
      "distribute" ~ "by" ~> repsep(expr, ",") ^^ { c => c.map(SortExpression(_, None)) }

    lazy val clusterBy =
      "cluster" ~ "by" ~> repsep(expr, ",") ^^ { c => c.map(SortExpression(_, None)) }

    lazy val xxxBy = // FIXME: losing lots of information by removing the keywords
      ( orderBy ~ clusterBy ^^ { case o ~ c => o ++ c }
      | distributeBy ~ sortBy ^^ { case d ~ s => d ++ s }
      | orderBy
      | sortBy
      | distributeBy
      | clusterBy
      )

    override lazy val simpleSelect: Parser[SimpleSelect] =
      "select" ~> opt(distinct) ~ rep1sep(projections, ",") ~ opt(relations) ~ opt(filters) ~ opt(groupBy) ~ opt(having) ~ opt(xxxBy) ~ opt(limit) ^^ {
        case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => SimpleSelect(d, p, r.getOrElse(Nil), f, g.getOrElse(Nil), h, o.getOrElse(Nil), l)
      }
  }

  /** User UDFs first, then Hive's built-in functions. */
  override def functions = udfs.orElse {
    case "min" | "max" => new SQLFunction1 {
      def result = { case (_, t) => Right(t) }
    }
    case "avg" | "sum" => new SQLFunction1 {
      def result = {
        case (_, t @ (INTEGER(_) | DECIMAL(_))) => Right(t)
        case (arg, _) => Left(TypeError("expected numeric argument", arg.pos))
      }
    }
    case "now" => new SQLFunction0 {
      def result = Right(TIMESTAMP())
    }
    case "concat" => new SQLFunctionX {
      def result = {
        case (_, t1) :: _ => Right(STRING(t1.nullable))
      }
    }
    case "coalesce" => new SQLFunctionX {
      def result = {
        case (_, t1) :: tail => Right(t1)
      }
    }
    case "count" => new SQLFunctionX {
      override def result = {
        case _ :: _ => Right(INTEGER(false))
      }
    }
    case "e" | "pi" =>
      SimpleFunction0(DECIMAL(true))
    case "current_date" =>
      SimpleFunction0(DATE(true))
    case "current_user" =>
      SimpleFunction0(STRING(true))
    case "current_timestamp" =>
      SimpleFunction0(TIMESTAMP(true))
    case "isnull" | "isnotnull" =>
      SimpleFunction1(BOOLEAN(true))
    case "variance" | "var_pop" | "var_samp" | "stddev_pop" | "stddev_samp" | "exp" | "ln" | "log10" |
         "log2" | "sqrt" | "abs" | "sin" | "asin" | "cos" | "acos" | "tan" | "atan" | "degrees" |
         "radians" | "sign" | "cbrt" =>
      SimpleFunction1(DECIMAL(true))
    case "year" | "quarter" | "month" | "day" | "dayofmonth" | "hour" | "minute" | "second" | "weekofyear" |
         "ascii" | "length" | "crc32" | "ntile" | "floor" | "ceil" | "ceiling" |
         "factorial" | "size" =>
      SimpleFunction1(INTEGER(true))
    case "to_date" | "last_day" | "base64" | "lower" | "lcase" | "ltrim" | "reverse" | "rtrim" | "space" |
         "trim" | "unbase64" | "upper" | "ucase" | "initcap" | "soundex" | "md5" | "sha" | "sha1" | "bin" |
         "hex" | "unhex" | "binary" =>
      SimpleFunction1(STRING(true))
    case "in_file" | "array_contains" =>
      SimpleFunction2(BOOLEAN(true))
    case "months_between" | "covar_pop" | "covar_samp" | "corr" | "log" | "pow" =>
      SimpleFunction2(DECIMAL(true))
    // levenshtein(a, b) and shiftleft/shiftright/shiftrightunsigned(a, b) take two arguments in Hive.
    case "datediff" | "find_in_set" | "instr" |
         "levenshtein" | "shiftleft" | "shiftright" | "shiftrightunsigned" =>
      SimpleFunction2(INTEGER(true))
    case "date_add" | "date_sub" | "add_months" | "next_day" | "trunc" | "date_format" | "decode" | "encode" |
         "format_number" | "get_json_object" | "repeat" | "sha2" | "aes_encrypt" | "aes_decrypt" =>
      SimpleFunction2(STRING(true))
    case "from_utc_timestamp" | "to_utc_timestamp" =>
      SimpleFunction2(TIMESTAMP(true))
    case "split" =>
      SimpleFunction2(HiveArray(STRING(true)))
    case "histogram_numeric" =>
      SimpleFunction2(HiveArray(HiveStruct(List(Column("x", DECIMAL(true)), Column("y", DECIMAL(true))))))
    case "round" | "bround" =>
      SimpleFunctionX(1, 2, INTEGER(true))
    // rand([seed]) returns a double uniformly distributed in [0, 1).
    case "rand" =>
      SimpleFunctionX(0, 1, DECIMAL(true))
    case "lpad" | "regexp_replace" | "rpad" | "translate" | "conv" =>
      SimpleFunctionX(3, 3, STRING(true))
    case "from_unixtime" =>
      SimpleFunctionX(1, 2, STRING(true))
    case "unix_timestamp" =>
      SimpleFunctionX(0, 2, INTEGER(true))
    case "context_ngrams" | "ngrams" =>
      SimpleFunctionX(4, 4, HiveArray(HiveStruct(List(Column("ngram", STRING(true)), Column("estfrequency", DECIMAL(true))))))
    case "concat_ws" | "printf" =>
      SimpleFunctionX(2, None, STRING(true))
    case "locate" =>
      SimpleFunctionX(2, 3, INTEGER(true))
    case "parse_url" | "substr" | "substring" | "regexp_extract" =>
      SimpleFunctionX(2, 3, STRING(true))
    case "sentences" =>
      SimpleFunctionX(1, 3, HiveArray(HiveArray(STRING(true))))
    case "str_to_map" =>
      SimpleFunctionX(1, 3, HiveMap(STRING(true), STRING(true)))
    case "substring_index" =>
      SimpleFunctionX(3, 3, STRING(true))
    case "hash" =>
      SimpleFunctionX(1, None, INTEGER(true))
    case "pmod" =>
      new SQLFunction2 {
        override def result = {
          case ((_, t @ (INTEGER(_) | DECIMAL(_))), _) => Right(t)
        }
      }
    case "positive" | "negative" =>
      new SQLFunction1 {
        override def result = {
          case (_, t @ (INTEGER(_) | DECIMAL(_))) => Right(t)
        }
      }
    case "percentile" =>
      new SQLFunction2 {
        override def result = {
          case (_, (_, a: HiveArray)) => Right(HiveArray(DECIMAL(true)))
          case _ => Right(DECIMAL(true))
        }
      }
    case "percentile_approx" =>
      new SQLFunctionX {
        override def result = {
          case l if l.length >= 2 && l.length <= 3 && l(1)._2.isInstanceOf[HiveArray] => Right(HiveArray(DECIMAL(true)))
          case _ => Right(DECIMAL(true))
        }
      }
    case "collect_set" | "collect_list" =>
      new SQLFunction1 {
        override def result = {
          case (_, t) => Right(HiveArray(t))
        }
      }
    case "if" =>
      new SQLFunctionX {
        override def result = {
          case (_, BOOLEAN(_)) :: (_, t1) :: (_, t2) :: Nil => Right(t1)
          case (arg, _) :: _ => Left(TypeError("Expected boolean argument", arg.pos))
        }
      }
    case "map_keys" =>
      new SQLFunction1 {
        override def result = {
          case (_, HiveMap(k, _)) => Right(HiveArray(k))
        }
      }
    case "map_values" =>
      new SQLFunction1 {
        override def result = {
          case (_, HiveMap(_, v)) => Right(HiveArray(v))
        }
      }
    case "sort_array" =>
      new SQLFunction1 {
        override def result = {
          case (_, a: HiveArray) => Right(a)
        }
      }
    case "map" =>
      new SQLFunctionX {
        override def result = {
          case (_, k) :: (_, v) :: _ => Right(HiveMap(k, v))
        }
      }
    case "struct" =>
      new SQLFunctionX {
        override def result = {
          case l =>
            val cols = l.zipWithIndex.map { case ((_, t), i) => Column(s"col$i", t) }
            Right(HiveStruct(cols))
        }
      }
    case "named_struct" =>
      new SQLFunctionX {
        override def result = {
          case l if l.length % 2 == 0 =>
            // NOTE(review): a field name that is not a string literal throws instead of returning a Left.
            val cols = l.grouped(2).toList.map {
              case (LiteralExpression(StringLiteral(name)), _) :: (_, t) :: Nil => Column(name, t)
              case x => sys.error(s"Unsupported expression in named_struct: $x")
            }
            Right(HiveStruct(cols))
        }
      }
    case "array" =>
      new SQLFunctionX {
        override def result = {
          case (_, t) :: _ => Right(HiveArray(t))
        }
      }
    // Table-generating functions (usable in LATERAL VIEW) yield a HiveUDTFResult.
    case "explode" =>
      new SQLFunction1 {
        override def result = {
          case (_, HiveArray(elem)) => Right(HiveUDTFResult(List(elem)))
          case (_, HiveMap(key, value)) => Right(HiveUDTFResult(List(key, value)))
        }
      }
    case "inline" =>
      new SQLFunction1 {
        override def result = {
          case (_, HiveArray(HiveStruct(cols))) => Right(HiveUDTFResult(cols.map(_.typ)))
        }
      }
    case "json_tuple" | "parse_url_tuple" =>
      new SQLFunctionX {
        override def result = {
          case l if l.length >= 2 => Right(HiveUDTFResult(l.tail.map(_ => STRING(true))))
        }
      }
    case "posexplode" =>
      new SQLFunction1 {
        override def result = {
          case (_, HiveArray(elem)) => Right(HiveUDTFResult(List(INTEGER(true), elem)))
        }
      }
    case "stack" =>
      new SQLFunctionX {
        override def result = {
          case l if l.length >= 2 => Right(HiveUDTFResult(l.tail.map(_._2)))
        }
      }
    case "java_method" | "reflect" =>
      new SQLFunctionX {
        override def result = {
          case l if l.length >= 2 => Right(STRING(true))
        }
      }
  }
}
326 |
object HiveDialect {

  /**
   * Hive type names mapped onto vizsql's coarser type literals; installed as the Hive
   * parser's `typeMap`. Precision is intentionally dropped: every integer width maps to
   * the integer literal, float/double/decimal map to decimal, and string/binary map to varchar.
   */
  val typeMap = Map( // vizsql types don't matter that much
    "boolean"   -> BooleanTypeLiteral,
    "tinyint"   -> IntegerTypeLiteral,
    "smallint"  -> IntegerTypeLiteral,
    "int"       -> IntegerTypeLiteral,
    "bigint"    -> IntegerTypeLiteral,
    "float"     -> DecimalTypeLiteral,
    "double"    -> DecimalTypeLiteral,
    "decimal"   -> DecimalTypeLiteral,
    "timestamp" -> TimestampTypeLiteral,
    "string"    -> VarcharTypeLiteral,
    "binary"    -> VarcharTypeLiteral
  )
}
342 | }
343 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/hive/HiveFunctions.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.hive
2 |
3 | import com.criteo.vizatra.vizsql.Show._
4 | import com.criteo.vizatra.vizsql._
5 |
/** Function called with no arguments whose result is always `resultType`. */
case class SimpleFunction0(resultType: Type) extends SQLFunction0 {
  override def result =
    Right(resultType)
}
9 |
/** One-argument function typed `resultType` whatever its argument is. */
case class SimpleFunction1(resultType: Type) extends SQLFunction1 {
  override def result = { case _ => Right(resultType) }
}
15 |
/** Two-argument function typed `resultType` whatever its arguments are. */
case class SimpleFunction2(resultType: Type) extends SQLFunction2 {
  override def result = { case _ => Right(resultType) }
}
21 |
/**
 * Variadic function typed `resultType` when called with at least `minArgs` and at most
 * `maxArgs` arguments (no upper bound when `maxArgs` is None). Any other arity yields a
 * ParsingError positioned at the first argument, or at 0 when there are none.
 */
case class SimpleFunctionX(minArgs: Int, maxArgs: Option[Int], resultType: Type) extends SQLFunctionX {
  override def result = {
    case args if args.length < minArgs || maxArgs.exists(args.length > _) =>
      val position = args.headOption.map(_._1.pos).getOrElse(0)
      Left(ParsingError(s"wrong argument list size ${args.length}", position))
    case _ =>
      Right(resultType)
  }
}

object SimpleFunctionX {
  /** Shorthand for a bounded arity range `[minArgs, maxArgs]`. */
  def apply(minArgs: Int, maxArgs: Int, resultType: Type): SimpleFunctionX =
    new SimpleFunctionX(minArgs, Some(maxArgs), resultType)
}
34 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/hive/HiveTypes.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.hive
2 |
3 | import com.criteo.vizatra.vizsql.{Column, Type}
4 |
/** Hive `array<elem>`. Always nullable; castable only to an equal array type. */
case class HiveArray(elem: Type) extends Type {
  val nullable: Boolean = true

  /** The nullability flag is ignored: this same instance is returned. */
  def withNullable(nullable: Boolean): HiveArray = this

  def canBeCastTo(other: Type): Boolean = equals(other)

  def show: String = s"array<${elem.show}>"
}
14 |
/** Hive `struct<name:type,...>`. Always nullable; castable only to an equal struct type. */
case class HiveStruct(elems: List[Column]) extends Type {
  val nullable: Boolean = true

  /** The nullability flag is ignored: this same instance is returned. */
  def withNullable(nullable: Boolean): HiveStruct = this

  def canBeCastTo(other: Type): Boolean = equals(other)

  def show: String = elems.map(col => s"${col.name}:${col.typ.show}").mkString("struct<", ",", ">")
}
24 |
/** Hive `map<key,value>`. Always nullable; castable only to an equal map type. */
case class HiveMap(key: Type, value: Type) extends Type {
  val nullable: Boolean = true

  /** The nullability flag is ignored: this same instance is returned. */
  def withNullable(nullable: Boolean): HiveMap = this

  def canBeCastTo(other: Type): Boolean = equals(other)

  def show: String = Seq(key.show, value.show).mkString("map<", ",", ">")
}
34 |
/**
 * Result of a table-generating function: the types of the columns it produces.
 * Always nullable; castable only to an equal result type.
 */
case class HiveUDTFResult(types: List[Type]) extends Type {
  val nullable: Boolean = true

  /** The nullability flag is ignored: this same instance is returned. */
  def withNullable(nullable: Boolean): HiveUDTFResult = this

  def canBeCastTo(other: Type): Boolean = equals(other)

  def show: String = types.map(_.show).mkString("udtfresult<", ",", ">")
}
44 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/hive/TypeParser.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql.hive
2 |
3 | import com.criteo.vizatra.vizsql._
4 |
5 | import scala.util.parsing.combinator.PackratParsers
6 | import scala.util.parsing.combinator.lexical.Lexical
7 | import scala.util.parsing.combinator.syntactical.TokenParsers
8 | import scala.util.parsing.input.CharArrayReader.EofCh
9 |
/**
 * Parser for Hive type strings such as `array<struct<a:int,b:map<string,double>>>`.
 * Every type it produces is nullable.
 */
class TypeParser extends TokenParsers with PackratParsers {
  type Tokens = TypeLexical
  override val lexical = new TypeLexical

  /**
   * Lexer with three token kinds: keywords (type names, matched case-insensitively and
   * stored lower-cased), the symbols `, : < >`, and column names (bare words or
   * backtick-quoted text).
   */
  class TypeLexical extends Lexical {
    case class ColumnName(chars: String) extends Token
    case class Keyword(chars: String) extends Token
    case class Symbol(chars: String) extends Token

    val keywords = Set(
      "array", "map", "struct", "tinyint", "smallint", "int", "bigint", "integer",
      "double", "float", "decimal", "string", "binary", "boolean", "timestamp"
    )
    val symbols = Set(",", ":", "<", ">")

    // A bare word becomes a Keyword when its lower-cased form is a keyword, otherwise a
    // ColumnName keeping its original case.
    // NOTE(review): whitespace is not excluded here, so spaces inside or after a word are kept
    // (e.g. "int >" lexes as ColumnName("int ")); only whitespace *before* a token is skipped.
    lazy val columnNameOrKeyword = rep1(chrExcept('`', ',', ':', '<', '>', EofCh)) ^^ { cs =>
      val word = cs.mkString
      val lowered = word.toLowerCase
      if (keywords.contains(lowered)) Keyword(lowered) else ColumnName(word)
    }

    // Alternatives are tried in order; the last one reports an unterminated backtick.
    lazy val token =
      ( columnNameOrKeyword
      | (elem(',') | elem(':') | elem('<') | elem('>')) ^^ { sym => Symbol(sym.toString) }
      | '`' ~> rep1(chrExcept('`', EofCh)) <~ '`' ^^ { quoted => ColumnName(quoted.mkString) }
      | '`' ~> failure("Unclosed backtick")
      )
    lazy val whitespace = rep(whitespaceChar)
  }

  // One parser per keyword/symbol string, built on first use and then reused.
  private val tokenParserCache = collection.mutable.HashMap.empty[String, Parser[String]]

  /**
   * Lets string literals stand for keyword or symbol tokens in the grammar below.
   * Any other string is a grammar bug and raises "Invalid parser definition".
   */
  implicit def string2KeywordOrSymbolParser(chars: String): Parser[String] =
    tokenParserCache.getOrElseUpdate(chars, {
      val expected =
        if (lexical.keywords(chars)) lexical.Keyword(chars)
        else if (lexical.symbols(chars)) lexical.Symbol(chars)
        else sys.error("Invalid parser definition")
      accept(expected) ^^ (_.chars) withFailureMessage s"$chars expected"
    })

  /** A struct field name: a column-name token, or a keyword used as a name. */
  lazy val name = elem("column name", tok => tok.isInstanceOf[lexical.Keyword] || tok.isInstanceOf[lexical.ColumnName]) ^^ (_.chars)

  /** `array<T>` */
  lazy val array = "array" ~ "<" ~> anyType <~ ">" ^^ { elemType => HiveArray(elemType) }

  /** `map<K,V>` */
  lazy val map = ("map" ~ "<" ~> anyType) ~ ("," ~> anyType <~ ">") ^^ { case keyType ~ valueType => HiveMap(keyType, valueType) }

  /** `struct<name:T,...>` with at least one field. */
  lazy val struct = "struct" ~ "<" ~> structFields <~ ">" ^^ { fields =>
    HiveStruct(fields.map { case fieldName ~ fieldType => Column(fieldName, fieldType) })
  }

  lazy val structFields = rep1sep(name ~ (":" ~> anyType), ",")

  lazy val simpleType: Parser[Type] =
    ( ("tinyint" | "smallint" | "int" | "bigint" | "integer") ^^^ INTEGER(true)
    | ("double" | "float" | "decimal") ^^^ DECIMAL(true)
    | ("string" | "binary") ^^^ STRING(true)
    | "boolean" ^^^ BOOLEAN(true)
    | "timestamp" ^^^ TIMESTAMP(true)
    | failure("expected type literal")
    )

  lazy val anyType: PackratParser[Type] =
    ( simpleType
    | array
    | struct
    | map
    )

  /**
   * Parses a complete type string.
   *
   * @return the type, or the parser's failure message suffixed with " for <typeName>"
   */
  def parseType(typeName: String): Either[String, Type] = {
    val scanner = new lexical.Scanner(typeName)
    phrase(anyType)(scanner) match {
      case NoSuccess(msg, _) => Left(s"$msg for $typeName")
      case Success(parsed, _) => Right(parsed)
    }
  }

}
83 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/hsqldb/HsqlDBDialect.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
object hsqldb {

  /**
   * HSQLDB dialect: SQL-99 parser and the standard functions, plus a two-argument
   * `timestamp(...)` whose TIMESTAMP result is nullable when either argument is.
   */
  implicit val dialect = new Dialect {
    override def toString = "HSQLDB"
    lazy val parser = new SQL99Parser
    lazy val functions = {
      val hsqldbOnly: PartialFunction[String,SQLFunction] = {
        case "timestamp" => new SQLFunction2 {
          def result = { case ((_, t1), (_, t2)) => Right(TIMESTAMP(nullable = t1.nullable || t2.nullable)) }
        }
      }
      SQLFunction.standard orElse hsqldbOnly
    }
  }
}
14 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/mysql/MySQLDialect.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
object mysql {

  /** MySQL dialect: the generic SQL-99 parser with the standard SQL function set. */
  implicit val dialect = new Dialect {
    override def toString = "MySQL"
    lazy val functions = SQLFunction.standard
    lazy val parser = new SQL99Parser
  }
}
10 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/postgresql/PostgresqlDialect.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
object postgresql {

  /**
   * PostgreSQL dialect: SQL-99 parser and the standard functions, plus
   * `date_trunc`, `zeroifnull`, `nullifzero`, `nullif` and `to_char`.
   */
  implicit val dialect = new Dialect {
    override def toString = "Postgres"
    lazy val parser = new SQL99Parser
    lazy val functions = {
      val postgresOnly: PartialFunction[String,SQLFunction] = {
        // TIMESTAMP, nullable exactly when the truncated value is.
        case "date_trunc" => new SQLFunction2 {
          def result = { case (_, (_, t)) => Right(TIMESTAMP(nullable = t.nullable)) }
        }
        // Argument type made non-nullable.
        case "zeroifnull" => new SQLFunction1 {
          def result = { case (_, t) => Right(t.withNullable(false)) }
        }
        // Argument type made nullable.
        case "nullifzero" => new SQLFunction1 {
          def result = { case (_, t) => Right(t.withNullable(true)) }
        }
        // First argument's type, made nullable.
        case "nullif" => new SQLFunction2 {
          def result = { case ((_, t1), (_, t2)) => Right(t1.withNullable(true)) }
        }
        case "to_char" => new SQLFunction2 {
          def result = { case ((_, t1), (_, t2)) => Right(STRING()) }
        }
      }
      SQLFunction.standard orElse postgresOnly
    }
  }
}
26 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/sql99/SQL99.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
object sql99 {

  /** Plain SQL-99 dialect: the generic parser with the standard SQL function set. */
  implicit val dialect = new Dialect {
    override def toString = "SQL-99"
    lazy val functions = SQLFunction.standard
    lazy val parser = new SQL99Parser
  }
}
10 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/vertica/VerticaDialect.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
object vertica {

  /**
   * Vertica dialect: SQL-99 parser and PostgreSQL's function set, plus a three-argument
   * `datediff` (INTEGER) and `to_timestamp` (TIMESTAMP).
   */
  implicit val dialect = new Dialect {
    override def toString = "Vertica"
    lazy val parser = new SQL99Parser
    lazy val functions = {
      val verticaOnly: PartialFunction[String,SQLFunction] = {
        // datediff(part, start, end): nullable when either date is.
        case "datediff" => new SQLFunction3 {
          def result = { case (_, (_, t1), (_, t2)) => Right(INTEGER(nullable = t1.nullable ||t2.nullable)) }
        }
        case "to_timestamp" => new SQLFunction1 {
          def result = { case (_, t1) => Right(TIMESTAMP(nullable = t1.nullable)) }
        }
      }
      postgresql.dialect.functions orElse verticaOnly
    }
  }
}
17 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/olap/Olap.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
/** Error from OLAP-specific validation; `pos` is the position of the offending AST node. */
case class OlapError(msg: String, pos: Int) extends Err
/**
 * Values supplied to an OLAP query, by placeholder name. When computing the query,
 * `filters` are merged over `parameters` (filters win on duplicate keys).
 */
case class OlapSelection(
  parameters: Map[String,Any],
  filters: Map[String,Any]
)

/** Requested dimension and metric labels; when both are empty every projection is kept. */
case class OlapProjection(
  dimensions: Set[String],
  metrics: Set[String]
)
/**
 * OLAP view of a parsed [[Query]].
 *
 * Labelled projections whose expression appears in GROUP BY are dimensions (time
 * dimensions when the matching output column is DATE or TIMESTAMP); all other labelled
 * projections are metrics.
 */
case class OlapQuery(query: Query) {

  /** All projections, provided each one is an expression with a label. */
  def getProjections: Either[Err,List[ExpressionProjection]] = {
    val labelled = query.select.projections.collect { case p @ ExpressionProjection(_, Some(_)) => p }
    (query.select.projections diff labelled).headOption match {
      case Some(unlabelled) => Left(OlapError("Please specify an expression with a label", unlabelled.pos))
      case None => Right(labelled)
    }
  }

  /**
   * Names of all placeholders in order. Fails if any placeholder is anonymous; because the
   * fold runs right-to-left, the error points at the last anonymous one.
   */
  def getParameters: Either[Err,List[String]] =
    query.placeholders.right.flatMap { placeholders =>
      val init: Either[Err,List[String]] = Right(Nil)
      placeholders.foldRight(init) {
        case ((Placeholder(Some(name)), _), acc) =>
          acc.right.map(name :: _)
        case ((unnamed, _), acc) =>
          acc.right.flatMap(_ => Left(OlapError("All parameters must be named", unnamed.pos)))
      }
    }

  /** (time dimensions, other dimensions), both in projection order. */
  private def getAllDimensions: Either[Err,(List[String],List[String])] =
    query.columns.right.flatMap { columns =>
      getProjections.right.map { projections =>
        val grouped = query.select.groupBy.flatMap(_.expressions)
        def isTemporal(label: String) = columns.find(_.name == label).exists { col =>
          col.typ match {
            case DATE(_) | TIMESTAMP(_) => true
            case _ => false
          }
        }
        projections
          .collect { case ExpressionProjection(e, Some(label)) if grouped.contains(e) => label }
          .partition(isTemporal)
      }
    }

  def getDimensions: Either[Err,List[String]] = getAllDimensions.right.map { case (_, regular) => regular }
  def getTimeDimensions: Either[Err,List[String]] = getAllDimensions.right.map { case (temporal, _) => temporal }

  /** Labels of the labelled projections that do not appear in GROUP BY. */
  def getMetrics: Either[Err,List[String]] = getProjections.right.map { projections =>
    val grouped = query.select.groupBy.flatMap(_.expressions)
    projections.collect { case ExpressionProjection(e, Some(label)) if !grouped.contains(e) => label }
  }

  /** The projection labelled `dimensionOrMetric`, or an error positioned at the select. */
  def getProjection(dimensionOrMetric: String): Either[Err,ExpressionProjection] =
    query.select.projections
      .collectFirst { case p @ ExpressionProjection(_, Some(label)) if label == dimensionOrMetric => p }
      .toRight(OlapError(s"Not found $dimensionOrMetric", query.select.pos))

  /**
   * Builds the SQL for the requested projection and selection. It keeps only the requested
   * projections (all of them when none are requested), rewrites the select, fills its
   * placeholders, then re-parses and optimizes the result.
   */
  def computeQuery(projection: OlapProjection, selection: OlapSelection): Either[Err,String] = {
    val requested = projection.dimensions ++ projection.metrics
    val keptProjections: Either[Err,List[ExpressionProjection]] =
      if (requested.isEmpty) Right(query.select.projections.collect { case e @ ExpressionProjection(_, _) => e })
      else requested.foldRight(Right(Nil): Either[Err,List[ExpressionProjection]]) { (label, acc) =>
        for (found <- acc.right; p <- getProjection(label).right) yield p :: found
      }
    for {
      keep <- keptProjections.right
      select <- OlapQuery.rewriteSelect(query.select, query.db, keep, selection.filters.keySet ++ selection.parameters.keySet).right
      sql <- select.fillParameters(query.db, selection.parameters ++ selection.filters).right
      filledQuery <- VizSQL.parseQuery(sql, query.db).right
    } yield Optimizer.optimize(filledQuery).sql
  }

  /**
   * Rewrites a metric as a post-aggregation when possible:
   * `sum(...)` -> SumPostAggregate; `zeroifnull(x)` / `nullifzero(x)` -> rewrite of x;
   * `a / b` -> DividePostAggregate when both sides rewrite; anything else -> None.
   */
  def rewriteMetricAggregate(metric: String): Either[Err,Option[PostAggregate]] = {
    def rewrite(expression: Expression): Option[PostAggregate] = expression match {
      case sum @ FunctionCallExpression("sum", _, _) => Some(SumPostAggregate(sum))
      case FunctionCallExpression("zeroifnull" | "nullifzero", _, inner :: Nil) => rewrite(inner)
      case MathExpression("/", numerator, denominator) =>
        for (n <- rewrite(numerator); d <- rewrite(denominator)) yield DividePostAggregate(n, d)
      case _ => None
    }
    getProjection(metric).right.map(p => rewrite(p.expression))
  }

}
94 |
/**
 * Rewriting helpers used to specialize a query for a single OLAP request.
 * They drop projections, ORDER BY / GROUP BY items, WHERE conjuncts and
 * relations that the request does not need.
 *
 * The `x <- Right(value).right` bindings below lift plain values into the
 * `Either` right projection so that every step can be chained in one
 * for-comprehension over `.right` projections.
 */
object OlapQuery {

  /**
   * Restricts `select` to `keepProjections`.
   *
   * On a [[SimpleSelect]], the rewrite:
   *  - filters projections, ORDER BY and GROUP BY against the kept projections;
   *  - drops WHERE conjuncts using placeholders absent from `availableParams`;
   *  - prunes FROM to the tables referenced by the kept projections and the
   *    remaining WHERE;
   *  - rewrites sub-selects found in FROM.
   *
   * Both sides of a [[UnionSelect]] are rewritten independently.
   *
   * @param select          select to rewrite
   * @param db              database used to resolve column references
   * @param keepProjections projections that must survive the rewrite
   * @param availableParams names of the placeholders that will be bound
   * @return the rewritten select, or the first resolution error
   */
  def rewriteSelect(select: Select, db: DB, keepProjections: List[ExpressionProjection], availableParams: Set[String]): Either[Err, Select] = {
    select match {
      case s: SimpleSelect => for {
        projections <- Right(s.projections.filter(x => keepProjections.contains(x))).right
        orderBy <- Right(s.orderBy.filter(f => keepProjections.map(_.expression).contains(f.expression))).right
        where <- Right(rewriteWhereCondition(s, availableParams)).right
        groupBy <- Right(rewriteGroupBy(s, keepProjections.map(_.expression))).right
        tables <- (keepProjections.filter(s.projections.contains).map(_.expression) ++ where.toList).foldRight(Right(Nil):Either[Err,List[String]]) {
          (p, acc) => for(a <- acc.right; b <- tablesFor(s, db, p).right) yield b ++ a
        }.right
        from <- Right(rewriteRelations(s, db, tables.toSet)).right
        rewrittenSelect <- Right(s.copy(
          projections = projections,
          relations = from,
          where = where,
          groupBy = groupBy,
          orderBy = orderBy)).right
        optimizedSelect <- optimizeSubSelect(rewrittenSelect, db, availableParams).right
      } yield {
        optimizedSelect
      }
      case UnionSelect(left, d, right) => for {
        rewrittenLeft <- rewriteSelect(left, db, keepProjections, availableParams).right
        rewrittenRight <- rewriteSelect(right, db, keepProjections, availableParams).right
      } yield {
        UnionSelect(rewrittenLeft, d, rewrittenRight)
      }
    }
  }

  /**
   * Resolves every column referenced by `expression` to its [[Table]].
   *
   * Resolution happens in the query view of `select`. Lookup depends on how the
   * column is written:
   *  - `schema.table.col`: the schema, then the table;
   *  - `table.col`: a non-ambiguous table lookup;
   *  - `col`: a non-ambiguous column lookup.
   *
   * Lookup failures become `SchemaError`s positioned on the column. The result
   * may contain duplicates.
   */
  def getTableReferences(select: Select, db: DB, expression: Expression): Either[Err, List[Table]] = {
    select.getQueryView(db).right.flatMap { db =>
      expression.getColumnReferences.foldRight(Right(Nil):Either[Err,List[Table]]) {
        (column, acc) => for {
          a <- acc.right
          b <- (column match {
            case ColumnIdent(_, Some(TableIdent(tableName, Some(schemaName)))) =>
              for {
                schema <- db.view.getSchema(schemaName).right
                table <- schema.getTable(tableName).right
              } yield table
            case ColumnIdent(_, Some(TableIdent(tableName, None))) =>
              db.view.getNonAmbiguousTable(tableName).right.map(_._2)
            case ColumnIdent(column, None) =>
              db.view.getNonAmbiguousColumn(column).right.map(_._2)
          }).left.map(SchemaError(_, column.pos)).right
        } yield b :: a
      }
    }
  }

  /** Lower-cased, de-duplicated names of the tables referenced by `expression`. */
  def tablesFor(select: Select, db: DB, expression: Expression): Either[Err,List[String]] = {
    getTableReferences(select, db, expression).right.map(_.map(_.name.toLowerCase).distinct)
  }

  /**
   * Shrinks the sub-selects used in the FROM clause of `select`, including those
   * nested in joins.
   *
   * Each sub-select keeps only the aliased projections that the outer
   * projections, join conditions or WHERE clause reference, and is then
   * rewritten with [[rewriteSelect]].
   */
  def optimizeSubSelect(select: SimpleSelect, db: DB, availableParams: Set[String]): Either[Err,Select] = {
    def optimizeRecursively(relation: Relation): Either[Err,Relation] = relation match {
      case rel @ SubSelectRelation(subSelect, table) =>
        for {
          // NOTE(review): this shadows `db` with the outer query view, which is also what
          // the sub-select is rewritten against below — confirm this is intended.
          db <- select.getQueryView(db).right
          optimized <- {
            val allExpressions = select.projections.collect {
              case ExpressionProjection(e, _) => e
            } ++ select.relations.flatMap(r => r :: r.visit).distinct.collect {
              case JoinRelation(_, _, _, Some(e)) => e
            } ++ select.where.toList

            val columnsUsed = allExpressions.flatMap(_.getColumnReferences).collect {
              case ColumnIdent(col, Some(TableIdent(`table`, None))) => col
              // A bare column counts only when it resolves unambiguously to this sub-select.
              // Ambiguous or unknown names are skipped instead of throwing.
              case ColumnIdent(col, None) if db.view.getNonAmbiguousColumn(col).right.exists(_._2.name == table) => col
            }

            val keepProjections = subSelect.projections.collect {
              case e @ ExpressionProjection(_, Some(alias)) if columnsUsed.contains(alias) => e
            }

            rewriteSelect(subSelect, db, keepProjections, availableParams)
          }.right
        } yield {
          rel.copy(select = optimized)
        }
      case rel @ JoinRelation(left, _, right, _) =>
        for {
          newLeft <- optimizeRecursively(left).right
          newRight <- optimizeRecursively(right).right
        } yield rel.copy(left = newLeft, right = newRight)
      case rel => Right(rel)
    }
    for {
      newRelations <- (select.relations.foldRight(Right(Nil):Either[Err,List[Relation]]) {
        (rel, acc) => for {
          a <- acc.right
          b <- optimizeRecursively(rel).right
        } yield b :: a
      }).right
    } yield select.copy(relations = newRelations)
  }

  /**
   * Keeps only the relations of `select` whose table name or alias, lower-cased,
   * is in `tables`.
   *
   * A join that loses one side collapses to its surviving side. Tables needed by
   * the ON conditions of the kept joins are added, and the pass is repeated until
   * the result stops changing.
   *
   * When `tables` is empty the relations are returned unchanged. Pruning would
   * otherwise drop the whole FROM clause, e.g. for `SELECT count(*) FROM t`.
   */
  def rewriteRelations(select: SimpleSelect, db: DB, tables: Set[String]): List[Relation] = {
    def rewritePass(relations: List[Relation], tables: Set[String]) = {
      def rewriteRecursively(relation: Relation): Option[Relation] = relation match {
        case rel @ SingleTableRelation(_, Some(table)) => if(tables.contains(table.toLowerCase)) Some(rel) else None
        case rel @ SingleTableRelation(TableIdent(table, _), None) => if(tables.contains(table.toLowerCase)) Some(rel) else None
        case rel @ SubSelectRelation(subSelect, table) => if(tables.contains(table.toLowerCase)) Some(rel) else None
        case rel @ JoinRelation(left, _, right, _) =>
          (rewriteRecursively(left) :: rewriteRecursively(right) :: Nil).flatten match {
            case Nil => None
            case x :: Nil => Some(x)
            case l :: r :: Nil => Some(rel.copy(left = l, right = r))
            case _ => sys.error("Unreacheable path")
          }
        case _ => None
      }
      relations.flatMap(rewriteRecursively)
    }

    // We start with the set of tables originally specified,
    // and after each rewrite we check that the rewritten joins
    // don't need more tables.
    def rewriteRecursively(previous: List[Relation], tableSet: Set[String]): List[Relation] = {
      def tablesUsedByRelations(relation: Relation): Set[String] = relation match {
        case JoinRelation(left, _, right, maybeOn) =>
          tablesUsedByRelations(left) ++ tablesUsedByRelations(right) ++ maybeOn.map { on =>
            tablesFor(select, db, on).right.map(_.toSet).right.getOrElse(Set.empty)
          }.getOrElse(Set.empty)
        case _ => Set.empty
      }
      val rewritten = rewritePass(select.relations, tableSet)
      if(rewritten != previous) {
        rewriteRecursively(rewritten, tableSet ++ rewritten.flatMap(tablesUsedByRelations))
      }
      else rewritten
    }

    if (tables.isEmpty) select.relations
    else rewriteRecursively(Nil, tables)
  }

  /**
   * Removes from `expression` every AND-conjunct that uses a named placeholder
   * missing from `parameters`.
   *
   * Any other expression, such as an OR, is kept or dropped as a whole.
   * Placeholders without a name never cause removal.
   *
   * @return the remaining expression, or `None` when nothing is left
   */
  def rewriteExpression(parameters: Set[String], expression: Expression): Option[Expression] = {
    def rewriteRecursively(expression: Expression): Option[Expression] = expression match {
      case AndExpression(op, left, right) =>
        (rewriteRecursively(left) :: rewriteRecursively(right) :: Nil).flatten match {
          case Nil => None
          case x :: Nil => Some(x)
          case l :: r :: Nil => Some(AndExpression(op, l, r))
          case _ => sys.error("Unreacheable path")
        }
      case expr =>
        (expr.visitPlaceholders.map(_.name).flatten.toSet &~ parameters).toList match {
          case Nil => Some(expr)
          case _ => None
        }
    }
    rewriteRecursively(expression)
  }

  /** WHERE clause of `select` restricted by [[rewriteExpression]] to the bound `parameters`. */
  def rewriteWhereCondition(select: SimpleSelect, parameters: Set[String]): Option[Expression] = {
    select.where.flatMap(rewriteExpression(parameters, _))
  }

  /**
   * Keeps the GROUP BY items of `select` whose expressions are in `expressions`.
   *
   * Inside ROLLUP, plain expressions and grouping sets are filtered the same way.
   * Grouping sets and rollups that end up empty are dropped.
   */
  def rewriteGroupBy(select: SimpleSelect, expressions: List[Expression]): List[GroupBy] = {
    def rewriteGroupingSet(groupingSet: GroupingSet): Option[GroupingSet] = {
      Option(GroupingSet(groupingSet.groups.filter(expressions.contains))).filterNot(_.groups.isEmpty)
    }
    def rewriteRecursively(groupBy: GroupBy): Option[GroupBy] = groupBy match {
      case g @ GroupByExpression(e) => if(expressions.contains(e)) Some(g) else None
      case g @ GroupByRollup(groups) =>
        Option(GroupByRollup(groups.map {
          case g @ Left(e) => if(expressions.contains(e)) Some(g) else None
          case g @ Right(gs) => rewriteGroupingSet(gs).map(Right.apply)
        }.flatten)).filterNot(_.groups.isEmpty)
    }
    select.groupBy.flatMap(rewriteRecursively)
  }

}
273 |
/** How a metric can be recomputed from already aggregated partial results. */
trait PostAggregate

/** A metric that is a plain `sum(...)` call; `expr` is that call. */
case class SumPostAggregate(
  expr: Expression
) extends PostAggregate

/** A metric that is the ratio of two recomputable metrics. */
case class DividePostAggregate(
  left: PostAggregate,
  right: PostAggregate
) extends PostAggregate
279 |
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/optimize/Optimizer.scala:
--------------------------------------------------------------------------------
1 | package com.criteo.vizatra.vizsql
2 |
/**
 * Rule-based optimizer for parsed selects.
 *
 * It constant-folds sub-expressions whose operands are literals (see
 * `preEvaluate`), recurses into sub-selects, and prunes the relations of a
 * [[SimpleSelect]] that no projection or WHERE expression references.
 *
 * @param db database used to resolve which tables an expression references
 */
class Optimizer(db : DB) {

  /**
   * Optimizes one simple select.
   *
   * It pre-evaluates the WHERE clause, the expression projections and the ORDER BY
   * expressions, optimizes nested relations, then drops unreferenced relations
   * (see `prunedRelations`).
   */
  def optimize(sel : SimpleSelect) : SimpleSelect = {
    val newWhere = sel.where.map(preEvaluate)
    val newProj = sel.projections.map {
      case ExpressionProjection(exp, alias) => ExpressionProjection(preEvaluate(exp), alias)
      case x => x
    }
    val newRel = sel.relations.map(apply)
    val newOrder = sel.orderBy.map {case SortExpression(exp, ord) => SortExpression(preEvaluate(exp), ord) }
    val sel2 = SimpleSelect(sel.distinct, newProj, newRel, newWhere,sel.groupBy, sel.having, newOrder, sel.limit)
    SimpleSelect(sel2.distinct, sel2.projections, prunedRelations(sel, sel2), sel2.where, sel2.groupBy, sel2.having, sel2.orderBy, sel2.limit)
  }

  /**
   * Relations of `rewritten` restricted to the tables referenced by its expression
   * projections and WHERE clause. Column references are resolved against `original`.
   *
   * The relations are returned unchanged when:
   *  - some projection is not an [[ExpressionProjection]], so the tables it needs
   *    cannot be derived from expressions;
   *  - resolving a reference fails;
   *  - no table is referenced at all (e.g. `SELECT count(*) FROM t`).
   *
   * In each of these cases, pruning against an empty table set would drop the
   * whole FROM clause.
   *
   * NOTE(review): GROUP BY / HAVING / ORDER BY expressions are not considered here.
   */
  private def prunedRelations(original: SimpleSelect, rewritten: SimpleSelect): List[Relation] = {
    val hasOtherProjection = rewritten.projections.exists {
      case _: ExpressionProjection => false
      case _ => true
    }
    if (hasOtherProjection) rewritten.relations
    else {
      val expressions = rewritten.projections.collect { case ExpressionProjection(exp, _) => exp } ++ rewritten.where.toList
      val referenced = expressions.foldRight(Right(Nil):Either[Err,List[String]]) {
        (p, acc) => for(a <- acc.right; b <- OlapQuery.tablesFor(original, db, p).right) yield b ++ a
      }
      referenced match {
        case Right(tables) if tables.nonEmpty => OlapQuery.rewriteRelations(rewritten, db, tables.toSet)
        case _ => rewritten.relations
      }
    }
  }

  /** Optimizes a select; both sides of a UNION are optimized independently. */
  def apply(select : Select) : Select = select match {
    case sel : SimpleSelect => optimize(sel)
    case UnionSelect(left, distinct, right) => UnionSelect(apply(left), distinct, apply(right))
  }

  /** Optimizes sub-selects and join conditions nested in `relation`; other relations are returned as-is. */
  def apply(relation: Relation) : Relation = relation match {
    case JoinRelation(left, join, right, on) => JoinRelation(apply(left), join, apply(right), on.map(preEvaluate))
    case SubSelectRelation(select, alias) => SubSelectRelation(apply(select), alias)
    case x => x
  }

  /**
   * Constant-folds `expr`.
   *
   * Internally, `Left(literal)` means "fully evaluated to this literal" and
   * `Right(expression)` means "could not be evaluated, rebuilt with folded
   * children". Only the node kinds matched below are folded; any other
   * expression is returned unchanged.
   *
   * A NULL operand of IN, NOT and LIKE is not folded, because SQL evaluates
   * those operations to NULL.
   */
  def preEvaluate(expr : Expression) : Expression = {
    def convert(v : Either[Literal, Expression]) = v match {case Left(x) => LiteralExpression(x); case Right(x) => x}
    def rec(expr: Expression) : Either[Literal, Expression] = expr match {
      case LiteralExpression(lit) => Left(lit)
      case ParenthesedExpression(exp) => rec(exp).right.map(x => ParenthesedExpression(x))
      case MathExpression(op, left, right) => (rec(left), rec(right)) match {
        case (Left(x), Left(y)) => EvalHelper.mathEval(op, x, y)
        case (l, r) => Right(MathExpression(op, convert(l), convert(r)))
      }
      case UnaryMathExpression(op, exp) => rec(exp) match {
        case x if op == "+" => x
        case Left(lit) if op == "-" => lit match {
          case DecimalLiteral(x) => Left(DecimalLiteral(-x))
          case IntegerLiteral(x) => Left(IntegerLiteral(-x))
          case _ => Right(UnaryMathExpression(op, LiteralExpression(lit)))
        }
        case x => Right(UnaryMathExpression(op, convert(x)))
      }
      case ComparisonExpression(op, left, right) => (rec(left), rec(right)) match {
        case (Left(x), Left(y)) if EvalHelper.compOperators.contains(op) => EvalHelper.compEval(op, x, y)
        case (l, r) => Right(ComparisonExpression(op, convert(l), convert(r)))
      }
      case AndExpression(op, left, right) => (rec(left), rec(right)) match {
        case (Left(x), Left(y)) => Left(EvalHelper.bool2Literal(EvalHelper.literal2bool(x) && EvalHelper.literal2bool(y)))
        case (l, r) => Right(AndExpression(op, convert(l), convert(r)))
      }
      case OrExpression(op, left, right) => (rec(left), rec(right)) match {
        case (Left(x), Left(y)) => Left(EvalHelper.bool2Literal(EvalHelper.literal2bool(x) || EvalHelper.literal2bool(y)))
        case (l, r) => Right(OrExpression(op, convert(l), convert(r)))
      }
      // Only a non-NULL literal found verbatim in the list is folded; NULL [NOT] IN (...) is NULL in SQL.
      case IsInExpression(left, not, right) => (rec(left), right.map(rec)) match {
        case (l @ Left(lit), r) if lit != NullLiteral && not && r.contains(l) => Left(FalseLiteral)
        case (l @ Left(lit), r) if lit != NullLiteral && r.contains(l) => Left(TrueLiteral)
        case (l, r) => Right(IsInExpression(convert(l), not,
          r.map(convert)))
      }
      case FunctionCallExpression(f, d, args) => Right(FunctionCallExpression(f, d, args.map(x => convert(rec(x)))))
      // NOT NULL is NULL in SQL, so a NULL operand is left unfolded.
      case NotExpression(exp) => rec(exp) match {
        case Left(l) if l != NullLiteral => Left(EvalHelper.bool2Literal(l == FalseLiteral))
        case other => Right(NotExpression(convert(other)))
      }
      case IsExpression(exp, not, lit) => rec(exp) match {
        // `not` flips the literal comparison (IS vs IS NOT).
        case Left(lit2) => Left(EvalHelper.bool2Literal(not ^ (lit == lit2)))
        case Right(ex) => Right(IsExpression(ex, not, lit))
      }
      case CastExpression(from, to) => Right(CastExpression(convert(rec(from)), to))
      case CaseWhenExpression(value, mapping,elseVal) =>
        val valueRes = value.map(rec)
        val elseRes = elseVal.map(rec)
        // Walks the WHEN branches in order. It returns the first branch that
        // statically matches; once a branch can only be decided at runtime, it
        // rebuilds a CASE with the remaining undecidable/matching branches. With
        // no branch left, the result is the ELSE value, or NULL without one.
        def recList(l : List[(Either[Literal, Expression],Either[Literal, Expression])]): Either[Literal, Expression] = (valueRes, l, elseRes) match {
          case (Some(matchValue), mapVals@(h::_), elseClause) if matchValue.isRight => Right(CaseWhenExpression(Some(convert(matchValue)),
            mapVals.map(x => (convert(x._1), convert(x._2))), elseClause.map(convert)))
          case (Some(matchValue), h :: t, _) if h._1 == matchValue => h._2
          case (Some(matchValue), mapVals @ (h :: t), elseClause) if h._1.isRight => Right(CaseWhenExpression(Some(convert(matchValue)),
            mapVals.filter(x => x._1.isRight || x._1 == matchValue).map(x => (convert(x._1), convert(x._2))), elseClause.map(convert)))
          case (None, h::t, _) if h._1 == Left(TrueLiteral) => h._2
          case (None, mapVals @ (h::t), elseClause) if h._1.isRight => Right(CaseWhenExpression(None,
            mapVals.toList.filter(x => x._1.isRight || x._1 == Left(TrueLiteral)).map(x => (convert(x._1), convert(x._2))), elseClause.map(convert)))
          case (matchValue, Nil, Some(elseClause)) => elseClause
          case (matchValue, Nil, None) => Left(NullLiteral)
          case (_, h :: t, _) => recList(t)
        }
        recList(mapping.map(x => (rec(x._1), rec(x._2))))
      // NOTE(review): assumes a non-NULL literal always matches itself as a pattern; this may
      // not hold for every `op` or escape convention. NULL LIKE NULL is NULL in SQL, so it is not folded.
      case LikeExpression(left, not, op, right) => (rec(left), rec(right)) match {
        case (Left(l), Left(r)) if l == r && !not && l != NullLiteral => Left(TrueLiteral)
        case (l, r) => Right(LikeExpression(convert(l), not, op, convert(r)))
      }
      case SubSelectExpression(select) => Right(SubSelectExpression(apply(select)))
      case _ => Right(expr)
    }
    convert(rec(expr))
  }
}
104 |
/** Entry point for optimizing a whole parsed [[Query]]. */
object Optimizer {
  /**
   * Optimizes `query.select` against `query.db`.
   *
   * @return a new query whose SQL text is regenerated from the optimized select
   */
  def optimize(query : Query) : Query = {
    val optimizer = new Optimizer(query.db)
    val optimizedSelect = optimizer.optimize(query.select)
    val regeneratedSql = optimizedSelect.toSQL
    Query(regeneratedSql, optimizedSelect, query.db)
  }
}
110 | }
111 |
--------------------------------------------------------------------------------