├── .gitignore ├── .travis.yml ├── .travis ├── deploy.sh └── keys.tar.enc ├── LICENSE ├── README.md ├── build.sbt ├── js └── src │ ├── main │ └── scala │ │ └── com │ │ └── criteo │ │ └── vizatra │ │ └── vizsql │ │ └── js │ │ ├── Database.scala │ │ ├── QueryParser.scala │ │ ├── common │ │ └── ParseResult.scala │ │ └── json │ │ ├── ColumnReader.scala │ │ ├── DBReader.scala │ │ ├── DialectReader.scala │ │ ├── Reader.scala │ │ ├── SchemaReader.scala │ │ └── TableReader.scala │ └── test │ └── scala │ └── com │ └── criteo │ └── vizatra │ └── vizsql │ └── js │ ├── DatabaseSpec.scala │ ├── QueryParserSpec.scala │ ├── common │ └── ParseResultSpec.scala │ └── json │ ├── ColumnReaderSpec.scala │ ├── DBReaderSpec.scala │ ├── DialectReaderSpec.scala │ ├── SchemaReaderSpec.scala │ └── TableReaderSpec.scala ├── jvm └── src │ └── test │ └── scala │ └── com │ └── criteo │ └── vizatra │ └── vizsql │ ├── ExtractColumnsSpec.scala │ ├── ExtractPlaceholdersSpec.scala │ ├── FormatSQLSpec.scala │ ├── OlapSpec.scala │ ├── OptimizeSpec.scala │ ├── ParseSQL99Spec.scala │ ├── ParseVerticaDialectSpec.scala │ ├── ParsingErrorsSpec.scala │ ├── SchemaErrorsSpec.scala │ ├── ThreadSafetySpec.scala │ ├── TypingErrorsSpec.scala │ └── hive │ ├── HiveParsingErrorsSpec.scala │ ├── HiveTypeParserSpec.scala │ └── ParseHiveQuerySpec.scala ├── project └── plugins.sbt └── shared └── src └── main └── scala └── com └── criteo └── vizatra └── vizsql ├── AST.scala ├── Dialect.scala ├── Errors.scala ├── EvalHelper.scala ├── Functions.scala ├── Parser.scala ├── Schema.scala ├── Show.scala ├── Types.scala ├── VizSQL.scala ├── dialects ├── h2 │ └── H2Dialect.scala ├── hive │ ├── HiveAST.scala │ ├── HiveDialect.scala │ ├── HiveFunctions.scala │ ├── HiveTypes.scala │ └── TypeParser.scala ├── hsqldb │ └── HsqlDBDialect.scala ├── mysql │ └── MySQLDialect.scala ├── postgresql │ └── PostgresqlDialect.scala ├── sql99 │ └── SQL99.scala └── vertica │ └── VerticaDialect.scala ├── olap └── Olap.scala └── optimize └── 
Optimizer.scala /.gitignore: -------------------------------------------------------------------------------- 1 | logs/* 2 | target 3 | target/* 4 | *.DS_Store 5 | *.releaseBackup 6 | release.properties 7 | *.iml 8 | *.ipr 9 | *.iws 10 | .idea 11 | .idea/* 12 | src/main/resources/application.conf 13 | .checkstyle 14 | .classpath 15 | .project 16 | .settings/ 17 | src/main/webapp/node/ 18 | src/main/webapp/node_modules/ 19 | src/main/webapp/lib/ 20 | src/main/resources/public/ 21 | *.bower.json 22 | dependency-reduced-pom.xml 23 | *.swp 24 | .devsettings 25 | effective-pom.xml 26 | drivers/ 27 | project/* 28 | !project/plugins.sbt 29 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: scala 2 | scala: 3 | - 2.11.8 4 | env: 5 | global: 6 | secure: IpQPAMFrafCkbzq+/RgHbv7kaTS7yViw4ce6yTNKmdgQeVHBPeRjaKteyuGmaFyOMIvHCsABtk3QWKfOlxyuzv/ZMlfTy69TnoDa62jAFIJLhZ6pqiGswtXwso7EFW6NgBQJi1AEV9mtC0VOH6YklbUSsnZsI1saB2JHD9h9fafOoiH8ifDDGh+vrIiozrr4YjuT/5/fV90rt+Bha5nOVqVmxmUP9Fpmh0ca3hY0Nw/KhtzxP+VAnQd4MaP0CU8b8+H/IVBUnVBB3xj8pPKs+t/di4NBDfDb6IYmGBtHviJxoDxnJwfRlT8I6o/e2EMNs7wQHke29VsI9Vl+BK6oug0JDugTCgocNrvjqsFDu1nInP3D0yMJRsFBIJ+BhmRw5fEi85KMlu+InfQTRq+FzqMHcIuQMYNlCDgIsQK4an3rUmJ8yJ6nX2qTwZjw8kRTVoszlmZpJ6h1pjG95jHq42JObnPge+oJyl4sO4ngClxEiHr1YQBsDuJfgPl2skXn1p3AGnK6AYn8I+NdK0mtBnJCrlonpxYz3i5zL2mUnDTV19dMxAKXljQRckllVO+JDKZiBl8s+GDp/JT5tIKGozp4phRRUlXRW1RROLHU7qvO1SEvC8ERUQS9K0V/QxCPK9L12IkJYUi37aaN/Qw51MLScxwEniAjoOuCaYpxj4c= 7 | before_deploy: 8 | - openssl aes-256-cbc -K $encrypted_12ac5f6463c4_key -iv $encrypted_12ac5f6463c4_iv 9 | -in .travis/keys.tar.enc -out .travis/keys.tar -d 10 | - tar xvf .travis/keys.tar 11 | 12 | deploy: 13 | provider: script 14 | script: './.travis/deploy.sh' 15 | skip_cleanup: true 16 | on: 17 | tags: true 18 | -------------------------------------------------------------------------------- 
/.travis/deploy.sh: -------------------------------------------------------------------------------- 1 | sbt vizsqlJVM/publishSigned 2 | sbt 'project vizsqlJVM' sonatypeRelease 3 | -------------------------------------------------------------------------------- /.travis/keys.tar.enc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/criteo/vizsql/b6bcbae61acf8269ec21b786aae8bc27b8adc355/.travis/keys.tar.enc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VizSQL 2 | 3 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 4 | [![Build Status](https://travis-ci.org/criteo/vizsql.svg?branch=master)](https://travis-ci.org/criteo/vizsql) 5 | [![Maven Central](https://maven-badges.herokuapp.com/maven-central/com.criteo/vizsql_2.11/badge.svg)](https://maven-badges.herokuapp.com/maven-central/com.criteo/vizsql_2.11) 6 | 7 | ## VizSQL is a SQL parser & typer for Scala 8 | 9 | VizSQL provides support for parsing SQL statements into a Scala AST. This AST can then be used to support many interesting transformations; for example, we provide a static optimizer and an OLAP query rewriter. It also supports typing a query to retrieve the resultset columns returned by a SQL statement.
10 | 11 | ### Installation 12 | 13 | #### SBT 14 | 15 | ```scala 16 | libraryDependencies += "com.criteo" % "vizsql_2.11" % "{latestVersion}" 17 | ``` 18 | 19 | #### Maven 20 | 21 | ```xml 22 | <dependency> 23 | <groupId>com.criteo</groupId> 24 | <artifactId>vizsql_2.11</artifactId> 25 | <version>{latestVersion}</version> 26 | </dependency> 27 | ``` 28 | 29 | ### tl;dr 30 | 31 | Here we use VizSQL to parse a SQL SELECT statement based on the SAKILA database, and to retrieve the column names and types of the returned resultset. 32 | 33 | ```scala 34 | import com.criteo.vizatra.vizsql._ 35 | 36 | val resultColumns = 37 | VizSQL.parseQuery( 38 | """SELECT country_id, max(last_update) as updated from City as x""", 39 | SAKILA 40 | ) 41 | .fold(e => sys.error(s"SQL syntax error: $e"), identity) 42 | .columns 43 | .fold(e => sys.error(s"SQL error: $e"), identity) 44 | 45 | assert( 46 | resultColumns == List( 47 | Column("country_id", INTEGER(nullable = false)), 48 | Column("updated", TIMESTAMP(nullable = true)) 49 | ) 50 | ) 51 | ``` 52 | 53 | Note that the query is not executed. It is just parsed, validated, analyzed and typed using the SAKILA database schema. The schema can be provided directly or extracted at runtime from a JDBC connection if needed. Here is, for example, the minimal SAKILA schema VizSQL needs to type the previous query (building a `DB` also requires an implicit `Dialect` in scope, e.g. `implicit val dialect = sql99.dialect`): 54 | 55 | ```scala 56 | val SAKILA = DB(schemas = List( 57 | Schema( 58 | "sakila", 59 | tables = List( 60 | Table( 61 | "City", 62 | columns = List( 63 | Column("city_id", INTEGER(nullable = false)), 64 | Column("city", STRING(nullable = true)), 65 | Column("country_id", INTEGER(nullable = false)), 66 | Column("last_update", TIMESTAMP(nullable = true)) 67 | ) 68 | ) 69 | ) 70 | ) 71 | )) 72 | ``` 73 | 74 | ### Dialects 75 | 76 | VizSQL supports several dialects, allowing it to understand the specific SQL syntax or functions of different databases. We started with a core dialect based on the *SQL99* standard, then enriched it to create specific dialects for **Vertica** and **Hive**.
This is, of course, a work in progress (even the SQL99 one). 77 | 78 | ### In-browser VizSQL 79 | 80 | Yes, VizSQL can run in browsers: thanks to Scala.js, the project can be compiled to JavaScript, so your front-end code can use VizSQL too. 81 | 82 | To compile to JavaScript: 83 | ```sh 84 | sbt fullOptJS 85 | ``` 86 | 87 | Then `vizsql-opt.js` is generated in the `target` folder; it is packaged as a CommonJS (node.js) module. 88 | 89 | To use VizSQL: 90 | ```javascript 91 | const Database = require('vizsql').Database; 92 | const db = Database().from({ 93 | /* db definition: dialect, schemas: [{ name, tables: [{ name, columns: [{ name, type, nullable }] }] }] */ 94 | }); 95 | 96 | const parseResult = db.parse("SELECT * FROM table"); 97 | ``` 98 | 99 | The parse result contains `{ error, select }`: 100 | 101 | - error (object), `undefined` if no error is present 102 | ```javascript 103 | { 104 | msg: 'err', // message of the error 105 | pos: 1 // position of the error 106 | } 107 | ``` 108 | - select (object), `undefined` if there's an error 109 | ```javascript 110 | { 111 | columns: [...], // columns in the query 112 | tables: [...] // tables in the query 113 | } 114 | ``` 115 | 116 | ### License 117 | 118 | This project is licensed under the Apache 2.0 license. 119 | 120 | ### Copyright 121 | 122 | Copyright © [Criteo](http://labs.criteo.com), 2016.
123 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | name := "vizsql" 2 | 3 | scalaVersion in ThisBuild := "2.11.8" 4 | 5 | lazy val root = project.in(file(".")) // aggregate root: building it builds and tests both the JS and JVM projects 6 | .enablePlugins(ScalaJSPlugin) 7 | .aggregate(vizsqlJS, vizsqlJVM) 8 | .settings() 9 | 10 | lazy val vizsql = crossProject.in(file(".")) // one cross-built project with JVM and Scala.js variants 11 | .settings( 12 | libraryDependencies += "org.scalactic" %%% "scalactic" % "3.0.0", 13 | libraryDependencies += "org.scalatest" %%% "scalatest" % "3.0.0" % "test" 14 | ) 15 | .jvmSettings( // NOTE(review): organization/version are only set for the JVM artifact (the one deploy.sh publishes) 16 | organization := "com.criteo", 17 | version := "1.0.0", 18 | libraryDependencies += "org.scala-lang.modules" %% "scala-parser-combinators" % "1.0.4", 19 | credentials += Credentials( 20 | "Sonatype Nexus Repository Manager", 21 | "oss.sonatype.org", 22 | "criteo-oss", 23 | sys.env.getOrElse("SONATYPE_PASSWORD", "") 24 | ) 25 | ) 26 | .jsSettings( 27 | libraryDependencies += "org.scala-js" %%% "scala-parser-combinators" % "1.0.2", 28 | scalaJSModuleKind := ModuleKind.CommonJSModule 29 | ) 30 | 31 | lazy val vizsqlJVM = vizsql.jvm 32 | lazy val vizsqlJS = vizsql.js 33 | 34 | // To sync with Maven central, you need to supply the following information: 35 | pomExtra in Global := { 36 | <url>https://github.com/criteo/vizsql</url> 37 | <licenses> 38 | <license> 39 | <name>Apache 2</name> 40 | <url>http://www.apache.org/licenses/LICENSE-2.0.txt</url> 41 | </license> 42 | </licenses> 43 | <scm> 44 | <connection>scm:git:github.com/criteo/vizsql.git</connection> 45 | <developerConnection>scm:git:git@github.com:criteo/vizsql.git</developerConnection> 46 | <url>github.com/criteo/vizsql</url> 47 | </scm> 48 | <developers> 49 | <developer> 50 | <name>Guillaume Bort</name> 51 | <email>g.bort@criteo.com</email> 52 | <url>https://github.com/guillaumebort</url> 53 | <organization>Criteo</organization> 54 | <organizationUrl>http://www.criteo.com</organizationUrl> 55 | </developer> 56 | </developers> 57 | } 58 | 59 | pgpPassphrase := sys.env.get("SONATYPE_PASSWORD").map(_.toArray) // NOTE(review): the PGP passphrase reuses SONATYPE_PASSWORD; confirm this is intended 60 | pgpSecretRing := file(".travis/secring.gpg") 61 | pgpPublicRing := file(".travis/pubring.gpg") 62 | usePgpKeyHex("755d525885532e9e") 63 |
-------------------------------------------------------------------------------- /js/src/main/scala/com/criteo/vizatra/vizsql/js/Database.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.js 2 | 3 | import com.criteo.vizatra.vizsql.DB 4 | import com.criteo.vizatra.vizsql.js.common.ParseResult 5 | import com.criteo.vizatra.vizsql.js.json.DBReader 6 | 7 | import scala.scalajs.js.Dynamic 8 | import scala.scalajs.js 9 | import scala.scalajs.js.annotation.{JSExport, ScalaJSDefined} 10 | 11 | /** JS entry point, exported as `Database`, for turning a JS database definition into a queryable [[Database]]. */ 12 | @JSExport("Database") 13 | object Database { 14 | /** 15 | * Parses a JS database definition into a VizSQL DB object. 16 | * @param input a JS object of the database definition 17 | * @return DB 18 | */ 19 | @JSExport 20 | def parse(input: Dynamic): DB = DBReader.apply(input) 21 | 22 | /** 23 | * Constructs a Database object from the database definition. 24 | * @param input a JS object of the database definition 25 | * @return Database 26 | */ 27 | @JSExport 28 | def from(input: Dynamic): Database = new Database(parse(input)) 29 | } 30 | 31 | /** 32 | * JS-visible wrapper around a VizSQL DB that parses queries against it. 33 | * @param db the database (dialect and schemas) queries are parsed and typed against 34 | */ 35 | @ScalaJSDefined 36 | class Database(db: DB) extends js.Object { 37 | /** 38 | * Parses and types a SQL query against this database. 39 | * @param query the SQL text 40 | * @return a ParseResult with either `error` or `select` set 41 | */ 42 | def parse(query: String): ParseResult = { 43 | QueryParser.parse(query, db) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /js/src/main/scala/com/criteo/vizatra/vizsql/js/QueryParser.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.js 2 | 3 | import com.criteo.vizatra.vizsql.js.common._ 4 | import com.criteo.vizatra.vizsql.{DB, Query, VizSQL} 5 | 6 | import scala.scalajs.js.JSConverters._ 7 | import scala.scalajs.js.annotation.JSExport 8 | /** Parses SQL text against a DB and converts the outcome into JS-facing ParseResult objects. */ 9 | @JSExport("QueryParser") 10 | object QueryParser { 11 | /** Parses `query` against `db`: a Left from VizSQL.parseQuery becomes a ParseResult with `error` set, otherwise the query is typed by `convert`. */ @JSExport 12 | def parse(query: String, db: DB): ParseResult = 13 | VizSQL.parseQuery(query, db) match { 14 | case Left(err) => new ParseResult(new ParseError(err.msg, err.pos)) 15 | case Right(parsed) => convert(parsed) 16 | } 17 | /** Types the parsed SELECT, resolving its columns and then its tables; the first failure becomes the result's `error`. */ 18 | def convert(query: Query): ParseResult = { 19
| val select = query.select 20 | val db = query.db 21 | val result = for { // .right projections: a Left from getColumns short-circuits before getTables runs 22 | columns <- select.getColumns(db).right 23 | tables <- select.getTables(db).right 24 | } yield (columns, tables) 25 | result fold ( 26 | err => new ParseResult(new ParseError(err.msg, err.pos)), { case (columns, tables) => 27 | val cols = columns map Column.from 28 | val tbls = tables map { case (maybeSchema, table) => Table.from(table, maybeSchema) } // getTables yields (optional schema name, table) pairs 29 | new ParseResult(select = new Select(cols.toJSArray, tbls.toJSArray)) 30 | } 31 | ) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /js/src/main/scala/com/criteo/vizatra/vizsql/js/common/ParseResult.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.js.common 2 | 3 | import scala.scalajs.js 4 | import scala.scalajs.js.UndefOr 5 | import scala.scalajs.js.annotation.ScalaJSDefined 6 | import scala.scalajs.js.JSConverters._ 7 | import com.criteo.vizatra.vizsql 8 | /** Result of QueryParser.parse as seen from JS: `error` is set on failure and `select` on success; the other field stays `undefined`. */ 9 | @ScalaJSDefined 10 | class ParseResult(val error: UndefOr[ParseError] = js.undefined, val select: UndefOr[Select] = js.undefined) extends js.Object 11 | /** A parse or typing error: `msg` is the message and `pos` the position reported by VizSQL. */ 12 | @ScalaJSDefined 13 | class ParseError(val msg: String, val pos: Int) extends js.Object 14 | /** Columns and tables of a successfully typed SELECT. */ 15 | @ScalaJSDefined 16 | class Select(val columns: js.Array[Column], val tables: js.Array[Table]) extends js.Object 17 | /** JS view of a column; `type` is the VizSQL type rendered with `show`. */ 18 | @ScalaJSDefined 19 | class Column(val name: String, val `type`: String, val nullable: Boolean) extends js.Object 20 | // Converts a VizSQL column, taking its nullability from the column's type. 21 | object Column { 22 | def from(column: vizsql.Column): Column = new Column(column.name, column.typ.show, column.typ.nullable) 23 | } 24 | /** JS view of a table; `schema` stays `undefined` when no schema name is given. */ 25 | @ScalaJSDefined 26 | class Table(val name: String, val columns: js.Array[Column], val schema: UndefOr[String] = js.undefined) extends js.Object 27 | // Converts a VizSQL table together with its columns. 28 | object Table { 29 | def from(table: vizsql.Table, schema: Option[String] = None): Table = new Table(table.name, table.columns.map(Column.from).toJSArray,
schema.orUndefined) 30 | } 31 | -------------------------------------------------------------------------------- /js/src/main/scala/com/criteo/vizatra/vizsql/js/json/ColumnReader.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.js.json 2 | 3 | import com.criteo.vizatra.vizsql._ 4 | import com.criteo.vizatra.vizsql.hive.TypeParser 5 | 6 | import scala.scalajs.js.{Dynamic, UndefOr} 7 | /** Reads a column definition `{ name, type, nullable? }`; `nullable` defaults to false when absent. */ 8 | object ColumnReader extends Reader[Column] { 9 | lazy val hiveTypeParser = new TypeParser // created lazily: only used for type names Type.from does not match 10 | 11 | override def apply(dyn: Dynamic): Column = { 12 | val name = dyn.name.asInstanceOf[String] 13 | val nullable = dyn.nullable.asInstanceOf[UndefOr[Boolean]].getOrElse(false) 14 | val typ = parseType(dyn.`type`.asInstanceOf[String], nullable) 15 | Column(name, typ) 16 | } 17 | /** Resolves a type name via Type.from, falling back to the Hive type parser; throws IllegalArgumentException if the Hive parser rejects it. */ 18 | def parseType(input: String, nullable: Boolean): Type = Type.from(nullable).applyOrElse(input, (rest: String) => 19 | hiveTypeParser.parseType(rest) match { 20 | case Left(err) => throw new IllegalArgumentException(err) 21 | case Right(t) => t 22 | } 23 | ) 24 | } 25 | -------------------------------------------------------------------------------- /js/src/main/scala/com/criteo/vizatra/vizsql/js/json/DBReader.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.js.json 2 | 3 | import com.criteo.vizatra.vizsql.{DB, Schemas} 4 | 5 | import scala.scalajs.js.{Array, Dynamic, UndefOr} 6 | /** Reads a database definition `{ dialect, schemas? }`; a missing `schemas` field yields a DB without schemas. */ 7 | object DBReader extends Reader[DB] { 8 | override def apply(dyn: Dynamic): DB = { 9 | implicit val dialect = DialectReader(dyn.dialect) 10 | // Treat a missing `schemas` array as empty, as SchemaReader and TableReader do for their arrays. 11 | val schemas = dyn.schemas 12 | .asInstanceOf[UndefOr[Array[Dynamic]]] 13 | .getOrElse(Array[Dynamic]()) 14 | .toList map SchemaReader.apply 15 | DB(schemas) 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /js/src/main/scala/com/criteo/vizatra/vizsql/js/json/DialectReader.scala: --------------------------------------------------------------------------------
-------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.js.json 2 | 3 | import com.criteo.vizatra.vizsql._ 4 | import com.criteo.vizatra.vizsql.hive.HiveDialect 5 | 6 | import scala.scalajs.js.{Dynamic, UndefOr} 7 | 8 | /** Reads the `dialect` field of a database definition; missing, null or unrecognised names fall back to SQL99. */ 9 | object DialectReader extends Reader[Dialect] { 10 | // An absent or null `dialect` falls back to SQL99 (as unknown names do) instead of failing in `toLowerCase`. 11 | override def apply(dyn: Dynamic): Dialect = 12 | dyn.asInstanceOf[UndefOr[String]].toOption.filter(_ != null).map(from).getOrElse(sql99.dialect) 13 | 14 | /** Maps a dialect name (case-insensitive) to its Dialect; unrecognised names map to SQL99. */ 15 | def from(input: String): Dialect = input.toLowerCase match { 16 | case "vertica" => vertica.dialect 17 | case "hsql" | "hsqldb" => hsqldb.dialect 18 | case "h2" => h2.dialect 19 | case "postgresql" => postgresql.dialect 20 | case "hive" => HiveDialect(Map.empty) // TODO: handle hive UDFs 21 | // NOTE(review): the source tree has a dialects/mysql package but "mysql" is not mapped here; confirm whether it should be. 22 | case _ => sql99.dialect 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /js/src/main/scala/com/criteo/vizatra/vizsql/js/json/Reader.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.js.json 2 | 3 | import scala.scalajs.js.{Dynamic, JSON} 4 | /** Decodes a T from a JS object, or from a JSON string that is first parsed with JSON.parse. */ 5 | trait Reader[T] { 6 | def apply(dyn: Dynamic): T 7 | def apply(input: String): T = apply(JSON.parse(input)) 8 | } 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /js/src/main/scala/com/criteo/vizatra/vizsql/js/json/SchemaReader.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.js.json 2 | 3 | import com.criteo.vizatra.vizsql.{Schema, Table} 4 | import scalajs.js.Array 5 | 6 | import scala.scalajs.js.UndefOr 7 | import scala.scalajs.js.Dynamic 8 | /** Reads a schema definition `{ name, tables? }`; a missing `tables` field yields a schema without tables. */ 9 | object SchemaReader extends Reader[Schema] { 10 | override def apply(dyn: Dynamic): Schema = { 11 | val name = dyn.name.asInstanceOf[String] 12 | val tables: List[Table] = dyn.tables 13 | .asInstanceOf[UndefOr[Array[Dynamic]]] 14 | .getOrElse(Array()) 15 | .toList map TableReader.apply 16 | Schema(name, tables) 17 | } 18 | }
19 | -------------------------------------------------------------------------------- /js/src/main/scala/com/criteo/vizatra/vizsql/js/json/TableReader.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.js.json 2 | 3 | import com.criteo.vizatra.vizsql.{Column, Table} 4 | 5 | import scala.scalajs.js.{Dynamic, UndefOr, Array} 6 | /** Reads a table definition `{ name, columns? }`; a missing `columns` field yields a table without columns. */ 7 | object TableReader extends Reader[Table] { 8 | override def apply(dyn: Dynamic): Table = { 9 | val table = dyn.name.asInstanceOf[String] 10 | val columns: List[Column] = dyn.columns 11 | .asInstanceOf[UndefOr[Array[Dynamic]]] 12 | .getOrElse(Array()) 13 | .toList map ColumnReader.apply 14 | Table(table, columns) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /js/src/test/scala/com/criteo/vizatra/vizsql/js/DatabaseSpec.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.js 2 | 3 | import com.criteo.vizatra.vizsql.DB 4 | import com.criteo.vizatra.vizsql._ 5 | import org.scalatest.{FlatSpec, Matchers} 6 | // Checks the JS-facing Database wrapper end to end against a one-table SAKILA schema. 7 | class DatabaseSpec extends FlatSpec with Matchers { 8 | implicit val dialect = sql99.dialect 9 | val db = DB(schemas = List( 10 | Schema( 11 | "sakila", 12 | tables = List( 13 | Table( 14 | "City", 15 | columns = List( 16 | Column("city_id", INTEGER(nullable = false)), 17 | Column("city", STRING(nullable = true)), 18 | Column("country_id", INTEGER(nullable = false)), 19 | Column("last_update", TIMESTAMP(nullable = false)) 20 | ) 21 | ) 22 | ) 23 | ) 24 | )) 25 | // The query names the table `city` while the schema declares it as `City`. 26 | "Database class" should "be able to parse queries" in { 27 | val database = new Database(db) 28 | val res = database.parse("SELECT city_id FROM city") 29 | res.select.get.columns.head.name shouldEqual "city_id" 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /js/src/test/scala/com/criteo/vizatra/vizsql/js/QueryParserSpec.scala: --------------------------------------------------------------------------------
/js/src/test/scala/com/criteo/vizatra/vizsql/js/QueryParserSpec.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.js 2 | 3 | import com.criteo.vizatra.vizsql._ 4 | import org.scalatest.{FlatSpec, Matchers} 5 | 6 | class QueryParserSpec extends FlatSpec with Matchers { 7 | implicit val dialect = sql99.dialect 8 | val db = DB(schemas = List( 9 | Schema( 10 | "sakila", 11 | tables = List( 12 | Table( 13 | "City", 14 | columns = List( 15 | Column("city_id", INTEGER(nullable = false)), 16 | Column("city", STRING(nullable = true)), 17 | Column("country_id", INTEGER(nullable = false)), 18 | Column("last_update", TIMESTAMP(nullable = false)) 19 | ) 20 | ), 21 | Table( 22 | "Country", 23 | columns = List( 24 | Column("country_id", INTEGER(nullable = false)), 25 | Column("country", STRING(nullable = true)), 26 | Column("last_update", TIMESTAMP(nullable = false)) 27 | ) 28 | ) 29 | ) 30 | ) 31 | )) 32 | "parse()" should "return a result" in { 33 | val result = QueryParser.parse( 34 | s""" 35 | |SELECT country, city 36 | |FROM city JOIN country ON city.country_id = country.country_id 37 | |WHERE city IN ?{availableCities} 38 | """.stripMargin, db) 39 | result.error.isDefined shouldBe false 40 | result.select.isDefined shouldBe true 41 | val select = result.select.get 42 | select.columns.length shouldBe 2 43 | } 44 | 45 | "parse()" should "handle errors" in { 46 | val result = QueryParser.parse( 47 | s"""S 48 | """.stripMargin, db) 49 | result.error.isDefined shouldBe true 50 | result.select.isDefined shouldBe false 51 | val error = result.error.get 52 | error.pos shouldBe 0 53 | error.msg shouldBe "select expected" 54 | } 55 | "parse()" should "identify invalid columns" in { 56 | val result = QueryParser.parse( 57 | s""" 58 | |SELECT country1, city 59 | |FROM city JOIN country ON city.country_id = country.country_id 60 | |WHERE city IN ?{availableCities} 61 | """.stripMargin, db) 62 | 
result.error.get.msg shouldBe "column not found country1" 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /js/src/test/scala/com/criteo/vizatra/vizsql/js/common/ParseResultSpec.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.js.common 2 | 3 | import org.scalatest.{FunSpec, Matchers} 4 | import com.criteo.vizatra.vizsql 5 | import com.criteo.vizatra.vizsql.INTEGER 6 | 7 | import scala.scalajs.js.JSON 8 | class ParseResultSpec extends FunSpec with Matchers { 9 | describe("Column") { 10 | describe("from()") { 11 | it("converts scala Column to JS object") { 12 | val col = Column.from(vizsql.Column("col1", INTEGER(true))) 13 | JSON.stringify(col) shouldEqual """{"name":"col1","type":"integer","nullable":true}""" 14 | } 15 | } 16 | } 17 | 18 | describe("Table") { 19 | describe("from()") { 20 | it("converts scala Table to JS object") { 21 | val table = Table.from(vizsql.Table("table1", List(vizsql.Column("col1", INTEGER(true)))), Some("schema1")) 22 | JSON.stringify(table) shouldEqual """{"name":"table1","columns":[{"name":"col1","type":"integer","nullable":true}],"schema":"schema1"}""" 23 | } 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /js/src/test/scala/com/criteo/vizatra/vizsql/js/json/ColumnReaderSpec.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.js.json 2 | 3 | import com.criteo.vizatra.vizsql.hive.{HiveArray, HiveMap, HiveStruct, TypeParser} 4 | import com.criteo.vizatra.vizsql.{BOOLEAN, Column, INTEGER, STRING} 5 | import org.scalatest.{FlatSpec, Matchers} 6 | 7 | class ColumnReaderSpec extends FlatSpec with Matchers { 8 | "apply()" should "return a column with nullable" in { 9 | val res = ColumnReader.apply("""{"name":"col","type":"int4","nullable":true}""") 10 | res shouldEqual Column("col", 
INTEGER(true)) 11 | } 12 | "apply()" should "return a column without nullable" in { 13 | val res = ColumnReader.apply("""{"name":"col","type":"int4"}""") 14 | res shouldEqual Column("col", INTEGER(false)) 15 | } 16 | "parseType()" should "parse map type" in { 17 | val res = ColumnReader.parseType("""map""", true) 18 | res shouldEqual HiveMap(STRING(true), INTEGER(true)) 19 | } 20 | "parseType()" should "parse array type" in { 21 | ColumnReader.parseType("array", true) shouldEqual HiveArray(INTEGER(true)) 22 | } 23 | "parseType()" should "parse struct type" in { 24 | ColumnReader.parseType("struct,c:struct>", true) shouldEqual HiveStruct(List( 25 | Column("a", HiveStruct(List( 26 | Column("b", BOOLEAN(true)) 27 | ))), 28 | Column("c", HiveStruct(List( 29 | Column("d", STRING(true)) 30 | ))) 31 | )) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /js/src/test/scala/com/criteo/vizatra/vizsql/js/json/DBReaderSpec.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.js.json 2 | 3 | import com.criteo.vizatra.vizsql._ 4 | import org.scalatest.{FlatSpec, Matchers} 5 | 6 | class DBReaderSpec extends FlatSpec with Matchers { 7 | "apply()" should "returns a DB" in { 8 | val db = DBReader.apply( 9 | """ 10 | |{ 11 | | "dialect": "vertica", 12 | | "schemas": [ 13 | | { 14 | | "name":"schema1", 15 | | "tables": [ 16 | | { 17 | | "name":"table1", 18 | | "columns": [ 19 | | {"name": "col1", "type": "int4"} 20 | | ] 21 | | } 22 | | ] 23 | | } 24 | | ] 25 | |} 26 | """.stripMargin) 27 | db.dialect shouldBe vertica.dialect 28 | db.schemas shouldEqual Schemas(List(Schema("schema1", List(Table("table1", List(Column("col1", INTEGER()))))))) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /js/src/test/scala/com/criteo/vizatra/vizsql/js/json/DialectReaderSpec.scala: 
-------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.js.json 2 | 3 | import com.criteo.vizatra.vizsql._ 4 | import com.criteo.vizatra.vizsql.hive.HiveDialect 5 | import org.scalatest.{FlatSpec, Matchers} 6 | /** Checks that [[DialectReader.from]] resolves dialect names regardless of letter case ("any" resolves to SQL-99). */ 7 | class DialectReaderSpec extends FlatSpec with Matchers { 8 | "from()" should "return a dialect" in { 9 | DialectReader.from("VERTICA") shouldBe vertica.dialect 10 | DialectReader.from("PostgreSQL") shouldBe postgresql.dialect 11 | DialectReader.from("any") shouldBe sql99.dialect 12 | DialectReader.from("hsql") shouldBe hsqldb.dialect 13 | DialectReader.from("h2") shouldBe h2.dialect 14 | DialectReader.from("Hive") shouldBe HiveDialect(Map.empty) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /js/src/test/scala/com/criteo/vizatra/vizsql/js/json/SchemaReaderSpec.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.js.json 2 | 3 | import com.criteo.vizatra.vizsql.{Column, INTEGER, Schema, Table} 4 | import org.scalatest.{FlatSpec, Matchers} 5 | /** Checks that [[SchemaReader]] decodes a schema, its tables and their columns from JSON. */ 6 | class SchemaReaderSpec extends FlatSpec with Matchers { 7 | "apply()" should "return a schema" in { 8 | val res = SchemaReader( 9 | """ 10 | |{ 11 | | "name":"schema1", 12 | | "tables": [ 13 | | { 14 | | "name":"table1", 15 | | "columns": [ 16 | | {"name": "col1", "type": "int4"} 17 | | ] 18 | | } 19 | | ] 20 | |} 21 | """.stripMargin) 22 | res shouldEqual Schema( 23 | "schema1", 24 | List( 25 | Table("table1", List(Column("col1", INTEGER()))) 26 | ) 27 | ) 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /js/src/test/scala/com/criteo/vizatra/vizsql/js/json/TableReaderSpec.scala: 
-------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.js.json 2 | 3 | import com.criteo.vizatra.vizsql.{Column, DECIMAL, INTEGER} 4 |
import org.scalatest.{FlatSpec, Matchers} 5 | /** Checks that [[TableReader]] decodes a table name and its columns (int4 as INTEGER, float4 as DECIMAL) from JSON. */ 6 | class TableReaderSpec extends FlatSpec with Matchers { 7 | "apply()" should "return a table" in { 8 | val res = TableReader.apply( 9 | """ 10 | |{ 11 | |"name": "table_1", 12 | |"columns": [ 13 | | { 14 | | "name": "col1", 15 | | "type": "int4" 16 | | }, 17 | | { 18 | | "name": "col2", 19 | | "type": "float4" 20 | | } 21 | |] 22 | |}""".stripMargin) 23 | res.name shouldBe "table_1" 24 | res.columns shouldBe List(Column("col1", INTEGER()), Column("col2", DECIMAL())) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /jvm/src/test/scala/com/criteo/vizatra/vizsql/ExtractColumnsSpec.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | import org.scalatest.prop.TableDrivenPropertyChecks 4 | import org.scalatest.{Matchers, PropSpec} 5 | import sql99._ 6 | /** Table-driven check that SQL-99 SELECT statements resolve to the expected output columns against the SAKILA schema defined below. */ 7 | class ExtractColumnsSpec extends PropSpec with Matchers { 8 | 9 | val validSQL99SelectStatements = TableDrivenPropertyChecks.Table( 10 | ("SQL", "Expected Columns"), 11 | 12 | ("""SELECT 1""", List( 13 | Column("1", INTEGER(false)) 14 | )), 15 | 16 | ("""SELECT 1 as one""", List( 17 | Column("one", INTEGER(false)) 18 | )), 19 | 20 | ("""SELECT * from City""", List( 21 | Column("city_id", INTEGER(nullable = false)), 22 | Column("city", STRING(nullable = false)), 23 | Column("country_id", INTEGER(nullable = false)), 24 | Column("last_update", TIMESTAMP(nullable = false)) 25 | )), 26 | 27 | ("""SELECT x.* from City as x""", List( 28 | Column("city_id", INTEGER(nullable = false)), 29 | Column("city", STRING(nullable = false)), 30 | Column("country_id", INTEGER(nullable = false)), 31 | Column("last_update", TIMESTAMP(nullable = false)) 32 | )), 
33 | 34 | ("""SELECT x.* from sakila.City x""", List( 35 | Column("city_id", INTEGER(nullable = false)), 36 | Column("city", STRING(nullable = false)), 37 |
Column("country_id", INTEGER(nullable = false)), 38 | Column("last_update", TIMESTAMP(nullable = false)) 39 | )), 40 | 41 | ("""SELECT country_id, max(last_update) from City as x""", List( 42 | Column("country_id", INTEGER(nullable = false)), 43 | Column("MAX(last_update)", TIMESTAMP(nullable = false)) 44 | )), 45 | 46 | ("""SELECT cities.country_id, max(cities.last_update) from City as cities""", List( 47 | Column("cities.country_id", INTEGER(nullable = false)), 48 | Column("MAX(cities.last_update)", TIMESTAMP(nullable = false)) 49 | )), 50 | 51 | ("""SELECT * from Country, City""", List( 52 | Column("country_id", INTEGER(nullable = false)), 53 | Column("country", STRING(nullable = false)), 54 | Column("last_update", TIMESTAMP(nullable = false)), 55 | Column("city_id", INTEGER(nullable = false)), 56 | Column("city", STRING(nullable = false)), 57 | Column("country_id", INTEGER(nullable = false)), 58 | Column("last_update", TIMESTAMP(nullable = false)) 59 | )), 60 | 61 | ("""SELECT city_id from city""", List( 62 | Column("city_id", INTEGER(nullable = false)) 63 | )), 64 | 65 | ("""SELECT City.country_id from Country, City""", List( 66 | Column("city.country_id", INTEGER(nullable = false)) 67 | )), 68 | 69 | ("""select blah.* from (select *, 1 as one from Country) as blah;""", List( 70 | Column("country_id", INTEGER(nullable = false)), 71 | Column("country", STRING(nullable = false)), 72 | Column("last_update", TIMESTAMP(nullable = false)), 73 | Column("one", INTEGER(false)) 74 | )), 75 | 76 | ("""select * from City as v join Country as p on v.country_id = p.country_id""", List( 77 | Column("city_id", INTEGER(nullable = false)), 78 | Column("city", STRING(nullable = false)), 79 | Column("country_id", INTEGER(nullable = false)), 80 | Column("last_update", TIMESTAMP(nullable = false)), 81 | Column("country_id", INTEGER(nullable = false)), 82 | Column("country", STRING(nullable = false)), 83 | Column("last_update", TIMESTAMP(nullable = false)) 84 | )), 85 | 86 |
("SELECT now() as now, MAX(last_update) + 3599 / 86400 AS last_update FROM City", List( 87 | Column("now", TIMESTAMP(nullable = false)), 88 | Column("last_update", TIMESTAMP(nullable = false)) 89 | )), 90 | 91 | ("""select case 3 when 1 then 2 when 3 then 5 else 0 end AS foo""", List( 92 | Column("foo", INTEGER(nullable = false)) 93 | )), 94 | 95 | ("""select case 3 when 1 then 2 when 3 then 5 end AS foo""", List( 96 | Column("foo", INTEGER(nullable = true)) 97 | )), 98 | 99 | ("""select case city when 'a' then 2.5 when 'b' then 5.0 end AS foo from City""", List( 100 | Column("foo", DECIMAL(nullable = true)) 101 | )), 102 | 103 | ("""select coalesce(case when city = 'a' then 1 end, 3) x from City""", List( 104 | Column("x", INTEGER(nullable = false)) 105 | )), 106 | 107 | ("""select coalesce(case when city = 'a' then 1 end, case when country_id = 2 then 3 end) x from City""", List( 108 | Column("x", INTEGER(nullable = true)) 109 | )) 110 | ) 111 | 112 | // -- 113 | 114 | val SAKILA = DB(schemas = List( 115 | Schema( 116 | "sakila", 117 | tables = List( 118 | Table( 119 | "City", 120 | columns = List( 121 | Column("city_id", INTEGER(nullable = false)), 122 | Column("city", STRING(nullable = false)), 123 | Column("country_id", INTEGER(nullable = false)), 124 | Column("last_update", TIMESTAMP(nullable = false)) 125 | ) 126 | ), 127 | Table( 128 | "Country", 129 | columns = List( 130 | Column("country_id", INTEGER(nullable = false)), 131 | Column("country", STRING(nullable = false)), 132 | Column("last_update", TIMESTAMP(nullable = false)) 133 | ) 134 | ) 135 | ) 136 | ) 137 | )) 138 | 139 | // -- 140 | 141 | property("extract SQL-99 SELECT statements columns") { 142 | TableDrivenPropertyChecks.forAll(validSQL99SelectStatements) { 143 | case (sql, expectedColumns) => 144 | VizSQL.parseQuery(sql, SAKILA) 145 | .fold(e => sys.error(s"Query doesn't parse: $e"), identity) 146 | .columns 147 | .fold(e => sys.error(s"Invalid query: $e"), identity) should be (expectedColumns)
148 | } 149 | } 150 | 151 | } -------------------------------------------------------------------------------- /jvm/src/test/scala/com/criteo/vizatra/vizsql/ExtractPlaceholdersSpec.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | import sql99._ 4 | import org.scalatest.prop.TableDrivenPropertyChecks 5 | import org.scalatest.{Matchers, PropSpec} 6 | /** Table-driven check of the names and inferred types of `?` placeholders in SQL-99 SELECT statements, resolved against the SAKILA schema defined below. */ 7 | class ExtractPlaceholdersSpec extends PropSpec with Matchers { 8 | 9 | val validSQL99SelectStatements = TableDrivenPropertyChecks.Table( 10 | ("SQL", "Expected Placeholders"), 11 | 12 | ("""SELECT * FROM city WHERE ?""", List( 13 | (None, BOOLEAN()) 14 | )), 15 | 16 | ("""SELECT * FROM city WHERE ?condition""", List( 17 | (Some("condition"), BOOLEAN()) 18 | )), 19 | 20 | ("""SELECT ?today:timestamp""", List( 21 | (Some("today"), TIMESTAMP()) 22 | )), 23 | 24 | ("""SELECT ?today:TIMESTAMP""", List( 25 | (Some("today"), TIMESTAMP()) 26 | )), 27 | 28 | ("""SELECT * FROM country WHERE country_id = ?""", List( 29 | (None, INTEGER()) 30 | )), 31 | 32 | ("""select *, country like ? from country where country_id = ? and last_update < ?;""", List( 33 | (None, STRING(nullable = true)), 34 | (None, INTEGER()), 35 | (None, TIMESTAMP()) 36 | )), 37 | 38 | ("""SELECT ?a:DECIMAL = ?b:DECIMAL""", List( 39 | (Some("a"), DECIMAL()), 40 | (Some("b"), DECIMAL()) 41 | )), 42 | 43 | ("""SELECT max(?today:timestamp)""", List( 44 | (Some("today"), TIMESTAMP()) 45 | )), 46 | 47 | ("""SELECT ? = 3""", List( 48 | (None, INTEGER()) 49 | )), 50 | 51 | ("""SELECT 3 = ?""", List( 52 | (None, INTEGER()) 53 | )), 54 | 55 | ("""SELECT ?:varchar = ?""", List( 56 | (None, STRING()), (None, STRING()) 57 | )), 58 | 59 | ("""SELECT ?
= ?:varchar""", List( 60 | (None, STRING()), (None, STRING()) 61 | )), 62 | 63 | ( 64 | """ 65 | select 66 | *, ?today:timestamp as TODAY, ?ratio * city_id 67 | from city 68 | where 69 | 1 + ?a > 10 + ?b:integer AND 70 | ? 71 | OR ?blah like city; 72 | """, List( 73 | (Some("today"), TIMESTAMP()), 74 | (Some("ratio"), DECIMAL()), 75 | (Some("a"), DECIMAL()), 76 | (Some("b"), INTEGER()), 77 | (None, BOOLEAN()), 78 | (Some("blah"), STRING(nullable = true)) 79 | )), 80 | 81 | ("""select 1 between ? and ?""", List( 82 | (None, INTEGER()), (None, INTEGER()) 83 | )), 84 | 85 | ("""select 1 between 0 and ?""", List( 86 | (None, INTEGER()) 87 | )), 88 | 89 | ("""select 1 between ?[)""", List( 90 | (None, RANGE(INTEGER())) 91 | )), 92 | 93 | ("""select 1 between ?[:varchar)""", List( 94 | (None, RANGE(STRING())) 95 | )), 96 | 97 | ("""select ? between 0 and 5""", List( 98 | (None, INTEGER()) 99 | )), 100 | 101 | ("""select ? between 'a' and ?""", List( 102 | (None, STRING()), (None, STRING()) 103 | )), 104 | 105 | ("""select ? between ? and ?:timestamp""", List( 106 | (None, TIMESTAMP()), (None, TIMESTAMP()), (None, TIMESTAMP()) 107 | )), 108 | 109 | ("""select * from city where last_update between ? and ?""", List( 110 | (None, TIMESTAMP()), (None, TIMESTAMP()) 111 | )), 112 | 113 | ("""select ? between ? and (? + ?)""", List( 114 | (None, DECIMAL()), (None, DECIMAL()), (None, DECIMAL()), (None, DECIMAL()) 115 | )), 116 | 117 | ("""select 1 in (?)""", List( 118 | (None, INTEGER()) 119 | )), 120 | 121 | ("""select ? in (1,2,3,4)""", List( 122 | (None, INTEGER()) 123 | )), 124 | 125 | ("""select ? in ('a','b',?,'d')""", List( 126 | (None, STRING()), (None, STRING()) 127 | )), 128 | 129 | ("""select country_id in (?,?,?)
from city""", List( 130 | (None, INTEGER()), (None, INTEGER()), (None, INTEGER()) 131 | )), 132 | 133 | ("""select 1 in ?{}""", List( 134 | (None, SET(INTEGER())) 135 | )), 136 | 137 | ( 138 | """ 139 | select country in ?{authorizedCountries} 140 | from city join country on city.country_id = country.country_id 141 | """, List( 142 | (Some("authorizedCountries"), SET(STRING(nullable = true))) 143 | )), 144 | 145 | ( 146 | """ 147 | select city 148 | from city join country on city.country_id = ? 149 | """, List( 150 | (None, INTEGER()) 151 | )), 152 | 153 | ( 154 | """ 155 | select 156 | case city_id 157 | when ? then 1 158 | when ? then 2 159 | else 0 160 | end 161 | from city 162 | """, List( 163 | (None, INTEGER()), (None, INTEGER()) 164 | )) 165 | ) 166 | 167 | // -- 168 | 169 | val SAKILA = DB(schemas = List( 170 | Schema( 171 | "sakila", 172 | tables = List( 173 | Table( 174 | "City", 175 | columns = List( 176 | Column("city_id", INTEGER(nullable = false)), 177 | Column("city", STRING(nullable = true)), 178 | Column("country_id", INTEGER(nullable = false)), 179 | Column("last_update", TIMESTAMP(nullable = false)) 180 | ) 181 | ), 182 | Table( 183 | "Country", 184 | columns = List( 185 | Column("country_id", INTEGER(nullable = false)), 186 | Column("country", STRING(nullable = true)), 187 | Column("last_update", TIMESTAMP(nullable = false)) 188 | ) 189 | ) 190 | ) 191 | ) 192 | )) 193 | 194 | // -- 195 | 196 | property("compute SQL-99 SELECT statements placeholders") { 197 | TableDrivenPropertyChecks.forAll(validSQL99SelectStatements) { 198 | case (sql, expectedPlaceholders) => 199 | VizSQL.parseQuery(sql, SAKILA) 200 | .fold(e => sys.error(s"Query doesn't parse: $e"), identity) 201 | .placeholders 202 | .fold(e => sys.error(s"Invalid query: $e"), _.namesAndTypes) should be (expectedPlaceholders) 203 | } 204 | } 205 | 206 | } --------------------------------------------------------------------------------
/jvm/src/test/scala/com/criteo/vizatra/vizsql/FormatSQLSpec.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | import org.scalatest.prop.TableDrivenPropertyChecks 4 | import org.scalatest.{Matchers, EitherValues, PropSpec} 5 | 6 | class FormatSQLSpec extends PropSpec with Matchers with EitherValues { 7 | 8 | val examples = TableDrivenPropertyChecks.Table( 9 | ("Input SQL", "Formatted SQL"), 10 | 11 | ( 12 | """SELECT 1""", 13 | """ 14 | |SELECT 15 | | 1 16 | """.stripMargin 17 | ), 18 | 19 | ( 20 | """select district, sum(population) from city""", 21 | """ 22 | |SELECT 23 | | district, 24 | | SUM(population) 25 | |FROM 26 | | city 27 | """.stripMargin 28 | ), 29 | 30 | ( 31 | """select * from City as v join Country as p on v.country_id = p.country_id where city.name like ? AND population > 10000""", 32 | """ 33 | |SELECT 34 | | * 35 | |FROM 36 | | city AS v 37 | | JOIN country AS p 38 | | ON v.country_id = p.country_id 39 | |WHERE 40 | | city.name LIKE ? 
41 | | AND population > 10000 42 | """.stripMargin 43 | ), 44 | 45 | ( 46 | """ 47 | SELECT 48 | d.device_name as device_name, 49 | z.description as zone_name, 50 | z.technology as technology, 51 | z.affiliate_name as affiliate_name, 52 | t.country_name as affiliate_country, 53 | t.country_level_1_name as affiliate_region, 54 | z.network_name as network_name, 55 | f.time_id as hour, 56 | CAST(date_trunc('day', f.time_id) as DATE) as day, 57 | CAST(date_trunc('week', f.time_id) as DATE) as week, 58 | CAST(date_trunc('month', f.time_id) as DATE) as month, 59 | CAST(date_trunc('quarter', f.time_id) as DATE) as quarter, 60 | 61 | SUM(displays) as displays, 62 | SUM(clicks) as clicks, 63 | SUM(sales) as sales, 64 | SUM(order_value_euro * r.rate) as order_value, 65 | SUM(revenue_euro * r.rate) as revenue, 66 | SUM(tac_euro * r.rate) as tac, 67 | SUM((revenue_euro - tac_euro) * r.rate) as revenue_ex_tac, 68 | SUM(marketplace_revenue_euro * r.rate) as marketplace_revenue, 69 | SUM((marketplace_revenue_euro - tac_euro) * r.rate) as marketplace_revenue_ex_tac, 70 | ZEROIFNULL(SUM(clicks)/NULLIF(SUM(displays), 0.0)) as ctr, 71 | ZEROIFNULL(SUM(sales)/NULLIF(SUM(clicks), 0.0)) as cr, 72 | ZEROIFNULL(SUM(revenue_euro - tac_euro)/NULLIF(SUM(revenue_euro), 0.0)) as margin, 73 | ZEROIFNULL(SUM(marketplace_revenue_euro - tac_euro)/NULLIF(SUM(marketplace_revenue_euro), 0.0)) as marketplace_margin, 74 | ZEROIFNULL(SUM(revenue_euro * r.rate)/NULLIF(SUM(clicks), 0.0)) as cpc, 75 | ZEROIFNULL(SUM(tac_euro * r.rate)/NULLIF(SUM(displays), 0.0)) * 1000 as cpm 76 | FROM 77 | wopr.fact_zone_device_stats_hourly f 78 | JOIN wopr.dim_zone z 79 | ON z.zone_id = f.zone_id 80 | JOIN wopr.dim_device d 81 | ON d.device_id = f.device_id 82 | JOIN wopr.dim_country t 83 | ON t.country_id = f.affiliate_country_id 84 | JOIN wopr.fact_euro_rates_hourly r 85 | ON r.currency_id = ?currency_id AND f.time_id = r.time_id 86 | 87 | WHERE 88 | CAST(f.time_id AS DATE) between ?[day) 89 | AND t.country_code IN 
?{publisher_countries} 90 | AND d.device_name IN ?{device_name} 91 | AND z.description IN ?{zone_name} 92 | AND z.technology IN ?{technology} 93 | AND z.affiliate_name IN ?{affiliate_name} 94 | AND t.country_name IN ?{affiliate_country} 95 | AND t.country_level_1_name IN ?{affiliate_region} 96 | AND z.network_name IN ?{network_name} 97 | 98 | GROUP BY ROLLUP(( 99 | d.device_name, 100 | z.description, 101 | z.technology, 102 | z.affiliate_name, 103 | t.country_name, 104 | z.network_name, 105 | t.country_level_1_name, 106 | f.time_id, 107 | CAST(date_trunc('day', f.time_id) as DATE), 108 | CAST(date_trunc('week', f.time_id) as DATE), 109 | CAST(date_trunc('month', f.time_id) as DATE), 110 | CAST(date_trunc('quarter', f.time_id) as DATE) 111 | )) 112 | HAVING SUM(clicks) > 0 113 | """, 114 | """ 115 | |SELECT 116 | | d.device_name AS device_name, 117 | | z.description AS zone_name, 118 | | z.technology AS technology, 119 | | z.affiliate_name AS affiliate_name, 120 | | t.country_name AS affiliate_country, 121 | | t.country_level_1_name AS affiliate_region, 122 | | z.network_name AS network_name, 123 | | f.time_id AS hour, 124 | | CAST(DATE_TRUNC('day', f.time_id) AS DATE) AS day, 125 | | CAST(DATE_TRUNC('week', f.time_id) AS DATE) AS week, 126 | | CAST(DATE_TRUNC('month', f.time_id) AS DATE) AS month, 127 | | CAST(DATE_TRUNC('quarter', f.time_id) AS DATE) AS quarter, 128 | | SUM(displays) AS displays, 129 | | SUM(clicks) AS clicks, 130 | | SUM(sales) AS sales, 131 | | SUM(order_value_euro * r.rate) AS order_value, 132 | | SUM(revenue_euro * r.rate) AS revenue, 133 | | SUM(tac_euro * r.rate) AS tac, 134 | | SUM((revenue_euro - tac_euro) * r.rate) AS revenue_ex_tac, 135 | | SUM(marketplace_revenue_euro * r.rate) AS marketplace_revenue, 136 | | SUM((marketplace_revenue_euro - tac_euro) * r.rate) AS marketplace_revenue_ex_tac, 137 | | ZEROIFNULL(SUM(clicks) / NULLIF(SUM(displays), 0.0)) AS ctr, 138 | | ZEROIFNULL(SUM(sales) / NULLIF(SUM(clicks), 0.0)) AS cr, 139 | | 
ZEROIFNULL(SUM(revenue_euro - tac_euro) / NULLIF(SUM(revenue_euro), 0.0)) AS margin, 140 | | ZEROIFNULL(SUM(marketplace_revenue_euro - tac_euro) / NULLIF(SUM(marketplace_revenue_euro), 0.0)) AS marketplace_margin, 141 | | ZEROIFNULL(SUM(revenue_euro * r.rate) / NULLIF(SUM(clicks), 0.0)) AS cpc, 142 | | ZEROIFNULL(SUM(tac_euro * r.rate) / NULLIF(SUM(displays), 0.0)) * 1000 AS cpm 143 | |FROM 144 | | wopr.fact_zone_device_stats_hourly AS f 145 | | JOIN wopr.dim_zone AS z 146 | | ON z.zone_id = f.zone_id 147 | | JOIN wopr.dim_device AS d 148 | | ON d.device_id = f.device_id 149 | | JOIN wopr.dim_country AS t 150 | | ON t.country_id = f.affiliate_country_id 151 | | JOIN wopr.fact_euro_rates_hourly AS r 152 | | ON r.currency_id = ?currency_id 153 | | AND f.time_id = r.time_id 154 | |WHERE 155 | | CAST(f.time_id AS DATE) BETWEEN ?day 156 | | AND t.country_code IN ?publisher_countries 157 | | AND d.device_name IN ?device_name 158 | | AND z.description IN ?zone_name 159 | | AND z.technology IN ?technology 160 | | AND z.affiliate_name IN ?affiliate_name 161 | | AND t.country_name IN ?affiliate_country 162 | | AND t.country_level_1_name IN ?affiliate_region 163 | | AND z.network_name IN ?network_name 164 | |GROUP BY 165 | | ROLLUP( 166 | | ( 167 | | d.device_name, 168 | | z.description, 169 | | z.technology, 170 | | z.affiliate_name, 171 | | t.country_name, 172 | | z.network_name, 173 | | t.country_level_1_name, 174 | | f.time_id, 175 | | CAST(DATE_TRUNC('day', f.time_id) AS DATE), 176 | | CAST(DATE_TRUNC('week', f.time_id) AS DATE), 177 | | CAST(DATE_TRUNC('month', f.time_id) AS DATE), 178 | | CAST(DATE_TRUNC('quarter', f.time_id) AS DATE) 179 | | ) 180 | | ) 181 | |HAVING 182 | | SUM(clicks) > 0 183 | """.stripMargin 184 | ), 185 | 186 | ( 187 | """select case when a =b then 1 when b <> 2 then 2 else 0 end""", 188 | """ 189 | |SELECT 190 | | CASE 191 | | WHEN a = b THEN 1 192 | | WHEN b <> 2 THEN 2 193 | | ELSE 0 194 | | END 195 | """.stripMargin 196 | ), 197 | 198 | ( 199 
| """select 1,2 union all select 3,4 union all select 5,6""", 200 | """ 201 | |SELECT 202 | | 1, 203 | | 2 204 | |UNION ALL 205 | |SELECT 206 | | 3, 207 | | 4 208 | |UNION ALL 209 | |SELECT 210 | | 5, 211 | | 6 212 | """.stripMargin 213 | ), 214 | 215 | ( 216 | """select count(distinct woot)""", 217 | """ 218 | |SELECT 219 | | COUNT(DISTINCT woot) 220 | """.stripMargin 221 | ), 222 | 223 | ( 224 | """SELECT 1 LIMIT 10""", 225 | """ 226 | |SELECT 227 | | 1 228 | |LIMIT 10 229 | """.stripMargin 230 | ) 231 | ) 232 | 233 | // -- 234 | 235 | property("print SQL") { 236 | TableDrivenPropertyChecks.forAll(examples) { 237 | case (sql, expectedSQL) => 238 | (new SQL99Parser).parseStatement(sql) 239 | .fold(e => sys.error(s"\n\n${e.toString(sql)}\n"), identity).toSQL should be (expectedSQL.stripMargin.trim) 240 | } 241 | } 242 | 243 | } -------------------------------------------------------------------------------- /jvm/src/test/scala/com/criteo/vizatra/vizsql/OlapSpec.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | import org.scalatest.prop.TableDrivenPropertyChecks 4 | import org.scalatest.{Matchers, EitherValues, FlatSpec} 5 | 6 | class OlapSpec extends FlatSpec with Matchers with EitherValues { 7 | 8 | implicit val VERTICA = new Dialect { 9 | 10 | def parser = new SQL99Parser 11 | 12 | val functions = SQLFunction.standard orElse { 13 | case "date_trunc" => new SQLFunction2 { 14 | def result = { case ((_, _), (_, t)) => Right(TIMESTAMP(nullable = t.nullable)) } 15 | } 16 | case "zeroifnull" => new SQLFunction1 { 17 | def result = { case (_, t) => Right(t.withNullable(false)) } 18 | } 19 | case "nullifzero" => new SQLFunction1 { 20 | def result = { case (_, t) => Right(t.withNullable(true)) } 21 | } 22 | }: PartialFunction[String,SQLFunction] 23 | } 24 | 25 | val BIDATA = DB(schemas = List( 26 | Schema( 27 | "wopr", 28 | tables = List( 29 | Table( 30 | "fact_zone_device_stats_hourly", 
31 | columns = List( 32 | Column("time_id", TIMESTAMP(nullable = false)), 33 | Column("zone_id", INTEGER(nullable = false)), 34 | Column("device_id", INTEGER(nullable = false)), 35 | Column("affiliate_country_id", INTEGER(nullable = false)), 36 | Column("displays", INTEGER(nullable = true)), 37 | Column("clicks", INTEGER(nullable = true)), 38 | Column("revenue_euro", DECIMAL(nullable = true)), 39 | Column("order_value_euro", DECIMAL(nullable = true)), 40 | Column("sales", DECIMAL(nullable = true)), 41 | Column("marketplace_revenue_euro", DECIMAL(nullable = true)), 42 | Column("tac_euro", DECIMAL(nullable = true)), 43 | Column("network_id", INTEGER(nullable = false)) 44 | ) 45 | ), 46 | Table( 47 | "dim_zone", 48 | columns = List( 49 | Column("zone_id", INTEGER(nullable = false)), 50 | Column("description", STRING(nullable = true)), 51 | Column("network_name", STRING(nullable = true)), 52 | Column("affiliate_name", STRING(nullable = true)), 53 | Column("technology", STRING(nullable = true)) 54 | ) 55 | ), 56 | Table( 57 | "dim_device", 58 | columns = List( 59 | Column("device_id", INTEGER(nullable = false)), 60 | Column("device_name", STRING(nullable = true)) 61 | ) 62 | ), 63 | Table( 64 | "dim_country", 65 | columns = List( 66 | Column("country_id", INTEGER(nullable = false)), 67 | Column("country_code", STRING(nullable = true)), 68 | Column("country_name", STRING(nullable = true)), 69 | Column("country_level_1_name", STRING(nullable = true)) 70 | ) 71 | ), 72 | Table( 73 | "fact_euro_rates_hourly", 74 | columns = List( 75 | Column("time_id", TIMESTAMP(nullable = false)), 76 | Column("currency_id", INTEGER(nullable = false)), 77 | Column("rate", DECIMAL(nullable = true)) 78 | ) 79 | ), 80 | Table( 81 | "fact_portfolio", 82 | columns = List( 83 | Column("time_id", TIMESTAMP(nullable = false)), 84 | Column("affiliate_id", INTEGER(nullable = false)), 85 | Column("network_id", INTEGER(nullable = false)), 86 | Column("zone_id", INTEGER(nullable = false)), 87 | 
Column("device_id", INTEGER(nullable = false)), 88 | Column("affiliate_country_id", INTEGER(nullable = false)), 89 | Column("affiliate_region_name", STRING(nullable = false)), 90 | Column("displays", INTEGER(nullable = true)), 91 | Column("clicks", INTEGER(nullable = true)), 92 | Column("revenue_euro", DECIMAL(nullable = true)), 93 | Column("order_value_euro", DECIMAL(nullable = true)), 94 | Column("sales", DECIMAL(nullable = true)), 95 | Column("marketplace_revenue_euro", DECIMAL(nullable = true)), 96 | Column("tac_euro", DECIMAL(nullable = true)) 97 | ) 98 | ) 99 | ) 100 | ) 101 | )) 102 | 103 | val QUERY = 104 | """ 105 | SELECT 106 | d.device_name as device_name, 107 | z.description as zone_name, 108 | z.technology as technology, 109 | z.affiliate_name as affiliate_name, 110 | t.country_name as affiliate_country, 111 | t.country_level_1_name as affiliate_region, 112 | z.network_name as network_name, 113 | date_trunc('hour', f.time_id) as hour, 114 | date_trunc('day', f.time_id) as day, 115 | date_trunc('week', f.time_id) as week, 116 | date_trunc('month', f.time_id) as month, 117 | date_trunc('quarter', f.time_id) as quarter, 118 | 119 | SUM(displays) as displays, 120 | SUM(clicks) as clicks, 121 | SUM(sales) as sales, 122 | SUM(order_value_euro * r.rate) as order_value, 123 | SUM(revenue_euro * r.rate) as revenue, 124 | SUM(tac_euro * r.rate) as tac, 125 | SUM((revenue_euro - tac_euro) * r.rate) as revenue_ex_tac, 126 | SUM(marketplace_revenue_euro * r.rate) as marketplace_revenue, 127 | SUM((marketplace_revenue_euro - tac_euro) * r.rate) as marketplace_revenue_ex_tac, 128 | ZEROIFNULL(SUM(clicks)/NULLIFZERO(SUM(displays))) as ctr, 129 | ZEROIFNULL(SUM(sales)/NULLIFZERO(SUM(clicks))) as cr, 130 | ZEROIFNULL(SUM(revenue_euro - tac_euro)/NULLIFZERO(SUM(revenue_euro))) as margin, 131 | ZEROIFNULL(SUM(marketplace_revenue_euro - tac_euro)/NULLIFZERO(SUM(marketplace_revenue_euro))) as marketplace_margin, 132 | ZEROIFNULL(SUM(revenue_euro * 
r.rate)/NULLIFZERO(SUM(clicks))) as cpc, 133 | ZEROIFNULL(SUM(tac_euro * r.rate)/NULLIFZERO(SUM(displays))) * 1000 as cpm 134 | FROM 135 | wopr.fact_zone_device_stats_hourly f 136 | JOIN wopr.dim_zone z 137 | ON z.zone_id = f.zone_id 138 | JOIN wopr.dim_device d 139 | ON d.device_id = f.device_id 140 | JOIN wopr.dim_country t 141 | ON t.country_name = t2.country_name 142 | JOIN wopr.fact_euro_rates_hourly r 143 | ON r.currency_id = ?currency_id AND f.time_id = r.time_id 144 | JOIN wopr.dim_country t2 145 | ON t2.country_id = f.affiliate_country_id 146 | 147 | WHERE 148 | CAST(f.time_id AS DATE) between ?[date_range) 149 | AND t.country_code IN ?{publisher_countries} 150 | AND d.device_name IN ?{device_name} 151 | AND z.description IN ?{zone_name} 152 | AND z.technology IN ?{technology} 153 | AND z.affiliate_name IN ?{affiliate_name} 154 | AND t.country_name IN ?{affiliate_country} 155 | AND t.country_level_1_name IN ?{affiliate_region} 156 | AND z.network_name IN ?{network_name} 157 | 158 | GROUP BY ROLLUP( 159 | ( 160 | d.device_name, 161 | z.description, 162 | z.technology, 163 | z.affiliate_name, 164 | t.country_name, 165 | z.network_name, 166 | t.country_level_1_name, 167 | date_trunc('hour', f.time_id), 168 | date_trunc('day', f.time_id), 169 | date_trunc('week', f.time_id), 170 | date_trunc('month', f.time_id), 171 | date_trunc('quarter', f.time_id) 172 | ) 173 | ) 174 | 175 | ORDER BY 176 | date_trunc('hour', f.time_id), 177 | date_trunc('day', f.time_id), 178 | date_trunc('week', f.time_id), 179 | date_trunc('month', f.time_id), 180 | date_trunc('quarter', f.time_id) 181 | """ 182 | 183 | // -- Actual tests 184 | 185 | val olapQuery = VizSQL.parseOlapQuery(QUERY, BIDATA).left.map(_.toString(QUERY)) 186 | 187 | "An OLAP query" should "recognize time dimensions" in { 188 | olapQuery.right.flatMap(_.getTimeDimensions) should be (Right(List( 189 | "hour", 190 | "day", 191 | "week", 192 | "month", 193 | "quarter" 194 | ))) 195 | } 196 | 197 | it should 
"recognize other dimensions" in { 198 | olapQuery.right.flatMap(_.getDimensions) should be (Right(List( 199 | "device_name", 200 | "zone_name", 201 | "technology", 202 | "affiliate_name", 203 | "affiliate_country", 204 | "affiliate_region", 205 | "network_name" 206 | ))) 207 | } 208 | 209 | it should "recognize metrics" in { 210 | olapQuery.right.flatMap(_.getMetrics) should be (Right(List( 211 | "displays", 212 | "clicks", 213 | "sales", 214 | "order_value", 215 | "revenue", 216 | "tac", 217 | "revenue_ex_tac", 218 | "marketplace_revenue", 219 | "marketplace_revenue_ex_tac", 220 | "ctr", 221 | "cr", 222 | "margin", 223 | "marketplace_margin", 224 | "cpc", 225 | "cpm" 226 | ))) 227 | } 228 | 229 | it should "recognize input parameters" in { 230 | olapQuery.right.flatMap(_.getParameters) should be (Right(List( 231 | "currency_id", 232 | "date_range", 233 | "publisher_countries", 234 | "device_name", 235 | "zone_name", 236 | "technology", 237 | "affiliate_name", 238 | "affiliate_country", 239 | "affiliate_region", 240 | "network_name" 241 | ))) 242 | } 243 | 244 | it should "retrieve the projection expression for dimensions/metrics" in { 245 | val examples = TableDrivenPropertyChecks.Table( 246 | ("Dimension or metric", "Expression"), 247 | 248 | ("revenue", """SUM(revenue_euro * r.rate) AS revenue"""), 249 | ("affiliate_region", """t.country_level_1_name AS affiliate_region""") 250 | ) 251 | 252 | TableDrivenPropertyChecks.forAll(examples) { 253 | case (dimensionOrMetric, expression) => 254 | olapQuery.right.flatMap(_.getProjection(dimensionOrMetric).right.map(_.toSQL)) should be (Right( 255 | expression 256 | )) 257 | } 258 | } 259 | 260 | it should "retrieve the tables needed for dimensions/metrics" in { 261 | val examples = TableDrivenPropertyChecks.Table( 262 | ("Dimension", "Used tables"), 263 | 264 | ("zone_name", List("z")), 265 | ("marketplace_margin", List("f")), 266 | ("cpc", List("f", "r")), 267 | ("affiliate_region", List("t")) 268 | ) 269 | 270 | 
TableDrivenPropertyChecks.forAll(examples) { 271 | case (dim, tables) => 272 | olapQuery.right.flatMap(q => 273 | q.getProjection(dim).right.map(_.expression).right.flatMap(OlapQuery.tablesFor(q.query.select, q.query.db, _)) 274 | ) should be (Right(tables)) 275 | } 276 | } 277 | 278 | it should "rewrite FROM relations based on a given table set" in { 279 | val examples = TableDrivenPropertyChecks.Table( 280 | ("Tables", "FROM clause"), 281 | 282 | (Set.empty[String], ""), 283 | 284 | (Set("f"), 285 | """wopr.fact_zone_device_stats_hourly AS f""" 286 | ), 287 | 288 | (Set("t"), 289 | """wopr.dim_country AS t""" 290 | ), 291 | 292 | (Set("f", "t"), 293 | """ 294 | |wopr.fact_zone_device_stats_hourly AS f 295 | |JOIN wopr.dim_country AS t 296 | | ON t.country_name = t2.country_name 297 | |JOIN wopr.dim_country AS t2 298 | | ON t2.country_id = f.affiliate_country_id 299 | """ 300 | ), 301 | 302 | (Set("f", "z", "d"), 303 | """ 304 | |wopr.fact_zone_device_stats_hourly AS f 305 | |JOIN wopr.dim_zone AS z 306 | | ON z.zone_id = f.zone_id 307 | |JOIN wopr.dim_device AS d 308 | | ON d.device_id = f.device_id 309 | """ 310 | ), 311 | 312 | (Set("f", "r"), 313 | """ 314 | |wopr.fact_zone_device_stats_hourly AS f 315 | |JOIN wopr.fact_euro_rates_hourly AS r 316 | | ON r.currency_id = ?currency_id 317 | | AND f.time_id = r.time_id 318 | """ 319 | ) 320 | ) 321 | 322 | TableDrivenPropertyChecks.forAll(examples) { 323 | case (tables, expectedSQL) => 324 | olapQuery.right.map(q => OlapQuery.rewriteRelations(q.query.select, BIDATA, tables) 325 | .map(_.toSQL).mkString(",\n")) should be (Right(expectedSQL.stripMargin.trim)) 326 | } 327 | } 328 | 329 | it should "rewrite WHERE conditions based on available parameters" in { 330 | val examples = TableDrivenPropertyChecks.Table( 331 | ("Parameters", "WHERE clause"), 332 | 333 | (Set.empty[String], ""), 334 | 335 | (Set("date_range", "device_name"), 336 | """ 337 | |CAST(f.time_id AS DATE) BETWEEN ?date_range 338 | |AND d.device_name 
IN ?device_name 339 | """ 340 | ), 341 | 342 | (Set("technology", "affiliate_region", "publisher_countries", "network_name"), 343 | """ 344 | |t.country_code IN ?publisher_countries 345 | |AND z.technology IN ?technology 346 | |AND t.country_level_1_name IN ?affiliate_region 347 | |AND z.network_name IN ?network_name 348 | """ 349 | ) 350 | ) 351 | 352 | TableDrivenPropertyChecks.forAll(examples) { 353 | case (parameters, expectedSQL) => 354 | olapQuery.right.map(q => 355 | OlapQuery.rewriteWhereCondition(q.query.select, parameters).map(_.toSQL).getOrElse("") 356 | ) should be (Right(expectedSQL.stripMargin.trim)) 357 | } 358 | } 359 | 360 | it should "rewrite GROUP BY expressions for given dimensions" in { 361 | val examples = TableDrivenPropertyChecks.Table( 362 | ("Dimensions", "GROUP BY clause"), 363 | 364 | (Set.empty[String], ""), 365 | 366 | (Set("hour"), 367 | """ 368 | |ROLLUP( 369 | | ( 370 | | DATE_TRUNC('hour', f.time_id) 371 | | ) 372 | |) 373 | """ 374 | ), 375 | 376 | (Set("affiliate_name", "affiliate_country"), 377 | """ 378 | |ROLLUP( 379 | | ( 380 | | z.affiliate_name, 381 | | t.country_name 382 | | ) 383 | |) 384 | """ 385 | ), 386 | 387 | (Set("affiliate_country", "affiliate_name", "network_name", "zone_name", "day"), 388 | """ 389 | |ROLLUP( 390 | | ( 391 | | z.description, 392 | | z.affiliate_name, 393 | | t.country_name, 394 | | z.network_name, 395 | | DATE_TRUNC('day', f.time_id) 396 | | ) 397 | |) 398 | """ 399 | ) 400 | 401 | ) 402 | 403 | TableDrivenPropertyChecks.forAll(examples) { 404 | case (dimensions, expectedSQL) => 405 | olapQuery.right.flatMap(q => 406 | dimensions.foldRight(Right(Nil):Either[Err,List[Expression]]) { 407 | (d, acc) => for(a <- acc.right; b <- q.getProjection(d).right.map(_.expression).right) yield b :: a 408 | }.right.map(expressions => OlapQuery.rewriteGroupBy(q.query.select, expressions).map(_.toSQL).mkString(",\n")) 409 | ) should be (Right(expectedSQL.stripMargin.trim)) 410 | } 411 | } 412 | 413 | it should 
"compute the right query" in { 414 | val examples = TableDrivenPropertyChecks.Table( 415 | ("Projection", "Selection", "SQL"), 416 | 417 | ( 418 | OlapProjection(Set("affiliate_country"), Set("displays", "clicks")), 419 | OlapSelection(Map.empty, Map.empty), 420 | """ 421 | |SELECT 422 | | t.country_name AS affiliate_country, 423 | | SUM(displays) AS displays, 424 | | SUM(clicks) AS clicks 425 | |FROM 426 | | wopr.fact_zone_device_stats_hourly AS f 427 | | JOIN wopr.dim_country AS t 428 | | ON t.country_name = t2.country_name 429 | | JOIN wopr.dim_country AS t2 430 | | ON t2.country_id = f.affiliate_country_id 431 | |GROUP BY 432 | | ROLLUP( 433 | | ( 434 | | t.country_name 435 | | ) 436 | | ) 437 | """ 438 | ), 439 | 440 | ( 441 | OlapProjection(Set("affiliate_country", "day"), Set("displays", "clicks", "marketplace_revenue")), 442 | OlapSelection( 443 | Map( 444 | "currency_id" -> 1, 445 | "publisher_countries" -> List("FR", "UK", "US", "DE"), 446 | "date_range" -> ("2015-09-10","2015-09-11") 447 | ), 448 | Map( 449 | "affiliate_country" -> List("FRANCE", "GERMANY"), 450 | "device_name" -> List("IPAD", "IPHONE") 451 | ) 452 | ), 453 | """ 454 | |SELECT 455 | | t.country_name AS affiliate_country, 456 | | DATE_TRUNC('day', f.time_id) AS day, 457 | | SUM(displays) AS displays, 458 | | SUM(clicks) AS clicks, 459 | | SUM(marketplace_revenue_euro * r.rate) AS marketplace_revenue 460 | |FROM 461 | | wopr.fact_zone_device_stats_hourly AS f 462 | | JOIN wopr.dim_device AS d 463 | | ON d.device_id = f.device_id 464 | | JOIN wopr.dim_country AS t 465 | | ON t.country_name = t2.country_name 466 | | JOIN wopr.fact_euro_rates_hourly AS r 467 | | ON r.currency_id = 1 468 | | AND f.time_id = r.time_id 469 | | JOIN wopr.dim_country AS t2 470 | | ON t2.country_id = f.affiliate_country_id 471 | |WHERE 472 | | CAST(f.time_id AS DATE) BETWEEN '2015-09-10' AND '2015-09-11' 473 | | AND t.country_code IN ('FR', 'UK', 'US', 'DE') 474 | | AND d.device_name IN ('IPAD', 'IPHONE') 475 | | 
AND t.country_name IN ('FRANCE', 'GERMANY') 476 | |GROUP BY 477 | | ROLLUP( 478 | | ( 479 | | t.country_name, 480 | | DATE_TRUNC('day', f.time_id) 481 | | ) 482 | | ) 483 | |ORDER BY 484 | | DATE_TRUNC('day', f.time_id) 485 | """ 486 | ) 487 | ) 488 | 489 | TableDrivenPropertyChecks.forAll(examples) { 490 | case (projection, selection, expectedSQL) => 491 | olapQuery.right.flatMap( 492 | _.computeQuery(projection, selection) 493 | ) should be (Right(expectedSQL.stripMargin.trim)) 494 | } 495 | } 496 | 497 | it should "rewrite metrics aggregate" in { 498 | val examples = TableDrivenPropertyChecks.Table( 499 | ("Metric", "Aggregate"), 500 | 501 | ("clicks", Some( 502 | SumPostAggregate( 503 | FunctionCallExpression("sum", None, args = List( 504 | ColumnExpression(ColumnIdent("clicks", None)) 505 | )) 506 | ) 507 | )), 508 | 509 | ("tac", Some( 510 | SumPostAggregate( 511 | FunctionCallExpression("sum", None, args = List( 512 | MathExpression( 513 | "*", 514 | ColumnExpression(ColumnIdent("tac_euro", None)), 515 | ColumnExpression(ColumnIdent("rate", Some(TableIdent("r", None)))) 516 | ) 517 | )) 518 | ) 519 | )), 520 | 521 | ("ctr", Some( 522 | DividePostAggregate( 523 | SumPostAggregate( 524 | FunctionCallExpression("sum", None, args = List( 525 | ColumnExpression(ColumnIdent("clicks", None)) 526 | )) 527 | ), 528 | SumPostAggregate( 529 | FunctionCallExpression("sum", None, args = List( 530 | ColumnExpression(ColumnIdent("displays", None)) 531 | )) 532 | ) 533 | ) 534 | )) 535 | 536 | ) 537 | 538 | TableDrivenPropertyChecks.forAll(examples) { 539 | case (metric, aggregateExpression) => 540 | olapQuery.right.flatMap(_.rewriteMetricAggregate(metric)) should be (Right(aggregateExpression)) 541 | } 542 | } 543 | 544 | it should "work with a subselect as relation" in { 545 | val QUERY = 546 | """ 547 | SELECT 548 | device as device, 549 | zone as zone, 550 | affiliate as affiliate, 551 | affiliate_country as affiliate_country, 552 | affiliate_region as affiliate_region, 
553 | network as network, 554 | 555 | hour as hour, 556 | day as day, 557 | week as week, 558 | month as month, 559 | quarter as quarter, 560 | 561 | SUM(displays) as displays, 562 | SUM(clicks) as clicks, 563 | SUM(sales) as sales, 564 | SUM(order_value_euro * r.rate) as order_value, 565 | SUM(revenue_euro * r.rate) as revenue, 566 | SUM(tac_euro * r.rate) as tac, 567 | SUM((revenue_euro - tac_euro) * r.rate) as revenue_ex_tac, 568 | SUM(marketplace_revenue_euro * r.rate) as marketplace_revenue, 569 | SUM((marketplace_revenue_euro - tac_euro) * r.rate) as marketplace_revenue_ex_tac, 570 | ZEROIFNULL(SUM(clicks)/NULLIFZERO(SUM(displays))) as ctr, 571 | ZEROIFNULL(SUM(sales)/NULLIFZERO(SUM(clicks))) as cr, 572 | ZEROIFNULL(SUM(revenue_euro - tac_euro)/NULLIFZERO(SUM(revenue_euro))) as margin, 573 | ZEROIFNULL(SUM(marketplace_revenue_euro - tac_euro)/NULLIFZERO(SUM(marketplace_revenue_euro))) as marketplace_margin, 574 | ZEROIFNULL(SUM(revenue_euro * r.rate)/NULLIFZERO(SUM(clicks))) as cpc, 575 | ZEROIFNULL(SUM(tac_euro * r.rate)/NULLIFZERO(SUM(displays))) * 1000 as cpm 576 | 577 | FROM 578 | ( 579 | SELECT 580 | date_trunc('month', time_id) AS x, 581 | 582 | device_id as device, 583 | zone_id as zone, 584 | affiliate_id as affiliate, 585 | affiliate_country_id as affiliate_country, 586 | affiliate_region_name as affiliate_region, 587 | network_id as network, 588 | 589 | date_trunc('hour', time_id) as hour, 590 | date_trunc('day', time_id) as day, 591 | date_trunc('week', time_id) as week, 592 | date_trunc('month', time_id) as month, 593 | date_trunc('quarter', time_id) as quarter, 594 | 595 | SUM(displays) as displays, 596 | SUM(clicks) as clicks, 597 | SUM(sales) as sales, 598 | SUM(order_value_euro) as order_value_euro, 599 | SUM(revenue_euro) as revenue_euro, 600 | SUM(tac_euro) as tac_euro, 601 | SUM(marketplace_revenue_euro) as marketplace_revenue_euro 602 | 603 | FROM 604 | wopr.fact_portfolio 605 | 606 | WHERE 607 | CAST(time_id AS DATE) BETWEEN ?[day) 608 | 
AND affiliate_country_id IN ?{publisher_countries} 609 | AND device_id IN ?{device} 610 | AND zone_id IN ?{zone} 611 | AND affiliate_id IN ?{affiliate} 612 | AND affiliate_country_id IN ?{affiliate_country} 613 | AND affiliate_region_name IN ?{affiliate_region} 614 | AND network_id IN ?{network} 615 | 616 | GROUP BY 617 | device_id, 618 | zone_id, 619 | affiliate_id, 620 | affiliate_country_id, 621 | network_id, 622 | affiliate_region_name, 623 | date_trunc('hour', time_id), 624 | date_trunc('day', time_id), 625 | date_trunc('week', time_id), 626 | date_trunc('month', time_id), 627 | date_trunc('quarter', time_id) 628 | ) as f 629 | JOIN wopr.fact_euro_rates_hourly as r 630 | ON r.currency_id = ?currency_id AND x = r.time_id 631 | 632 | GROUP BY ROLLUP(( 633 | device, 634 | zone, 635 | affiliate, 636 | affiliate_country, 637 | network, 638 | affiliate_region, 639 | hour, 640 | day, 641 | week, 642 | month, 643 | quarter 644 | )) 645 | 646 | ORDER BY 647 | hour, 648 | day, 649 | week, 650 | month, 651 | quarter 652 | """ 653 | 654 | val olapQuery = VizSQL.parseOlapQuery(QUERY, BIDATA).left.map(e => sys.error(e.toString(QUERY))).right.get 655 | 656 | olapQuery.getTimeDimensions should be (Right(List( 657 | "hour", 658 | "day", 659 | "week", 660 | "month", 661 | "quarter" 662 | ))) 663 | 664 | olapQuery.getDimensions should be (Right(List( 665 | "device", 666 | "zone", 667 | "affiliate", 668 | "affiliate_country", 669 | "affiliate_region", 670 | "network" 671 | ))) 672 | 673 | olapQuery.getMetrics should be (Right(List( 674 | "displays", 675 | "clicks", 676 | "sales", 677 | "order_value", 678 | "revenue", 679 | "tac", 680 | "revenue_ex_tac", 681 | "marketplace_revenue", 682 | "marketplace_revenue_ex_tac", 683 | "ctr", 684 | "cr", 685 | "margin", 686 | "marketplace_margin", 687 | "cpc", 688 | "cpm" 689 | ))) 690 | 691 | olapQuery.getParameters should be (Right(List( 692 | "day", 693 | "publisher_countries", 694 | "device", 695 | "zone", 696 | "affiliate", 697 | 
"affiliate_country", 698 | "affiliate_region", 699 | "network", 700 | "currency_id" 701 | ))) 702 | 703 | olapQuery.computeQuery( 704 | OlapProjection(Set("device"), Set("clicks")), 705 | OlapSelection(Map("currency_id" -> 1, "publisher_countries" -> List(1,2,3), "day" -> ("2015-09-10","2015-09-11")), Map.empty) 706 | ).left.map(_.toString(QUERY)) should be (Right( 707 | """ 708 | |SELECT 709 | | device AS device, 710 | | SUM(clicks) AS clicks 711 | |FROM 712 | | ( 713 | | SELECT 714 | | device_id AS device, 715 | | SUM(clicks) AS clicks 716 | | FROM 717 | | wopr.fact_portfolio 718 | | WHERE 719 | | CAST(time_id AS DATE) BETWEEN '2015-09-10' AND '2015-09-11' 720 | | AND affiliate_country_id IN (1, 2, 3) 721 | | GROUP BY 722 | | device_id 723 | | ) AS f 724 | |GROUP BY 725 | | ROLLUP( 726 | | ( 727 | | device 728 | | ) 729 | | ) 730 | """.stripMargin.trim 731 | )) 732 | 733 | olapQuery.computeQuery( 734 | OlapProjection(Set("device", "hour"), Set("clicks")), 735 | OlapSelection(Map("currency_id" -> 1, "publisher_countries" -> List(1,2,3), "day" -> ("2015-09-10","2015-09-11")), Map.empty) 736 | ).left.map(_.toString(QUERY)) should be (Right( 737 | """ 738 | |SELECT 739 | | device AS device, 740 | | hour AS hour, 741 | | SUM(clicks) AS clicks 742 | |FROM 743 | | ( 744 | | SELECT 745 | | device_id AS device, 746 | | DATE_TRUNC('hour', time_id) AS hour, 747 | | SUM(clicks) AS clicks 748 | | FROM 749 | | wopr.fact_portfolio 750 | | WHERE 751 | | CAST(time_id AS DATE) BETWEEN '2015-09-10' AND '2015-09-11' 752 | | AND affiliate_country_id IN (1, 2, 3) 753 | | GROUP BY 754 | | device_id, 755 | | DATE_TRUNC('hour', time_id) 756 | | ) AS f 757 | |GROUP BY 758 | | ROLLUP( 759 | | ( 760 | | device, 761 | | hour 762 | | ) 763 | | ) 764 | |ORDER BY 765 | | hour 766 | """.stripMargin.trim 767 | )) 768 | 769 | olapQuery.computeQuery( 770 | OlapProjection(Set("device", "hour"), Set("clicks", "marketplace_revenue")), 771 | OlapSelection(Map("currency_id" -> 1, "publisher_countries" 
-> List(1,2,3), "day" -> ("2015-09-10","2015-09-11")), Map.empty) 772 | ).left.map(_.toString(QUERY)) should be (Right( 773 | """ 774 | |SELECT 775 | | device AS device, 776 | | hour AS hour, 777 | | SUM(clicks) AS clicks, 778 | | SUM(marketplace_revenue_euro * r.rate) AS marketplace_revenue 779 | |FROM 780 | | ( 781 | | SELECT 782 | | DATE_TRUNC('month', time_id) AS x, 783 | | device_id AS device, 784 | | DATE_TRUNC('hour', time_id) AS hour, 785 | | SUM(clicks) AS clicks, 786 | | SUM(marketplace_revenue_euro) AS marketplace_revenue_euro 787 | | FROM 788 | | wopr.fact_portfolio 789 | | WHERE 790 | | CAST(time_id AS DATE) BETWEEN '2015-09-10' AND '2015-09-11' 791 | | AND affiliate_country_id IN (1, 2, 3) 792 | | GROUP BY 793 | | device_id, 794 | | DATE_TRUNC('hour', time_id), 795 | | DATE_TRUNC('month', time_id) 796 | | ) AS f 797 | | JOIN wopr.fact_euro_rates_hourly AS r 798 | | ON r.currency_id = 1 799 | | AND x = r.time_id 800 | |GROUP BY 801 | | ROLLUP( 802 | | ( 803 | | device, 804 | | hour 805 | | ) 806 | | ) 807 | |ORDER BY 808 | | hour 809 | """.stripMargin.trim 810 | )) 811 | } 812 | 813 | } -------------------------------------------------------------------------------- /jvm/src/test/scala/com/criteo/vizatra/vizsql/OptimizeSpec.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | import org.scalatest.{Matchers, FlatSpec} 4 | 5 | class OptimizeSpec extends FlatSpec with Matchers { 6 | 7 | val CASE_QUERY = 8 | """SELECT 9 | | country, 10 | | CASE 1 WHEN 1 THEN 13 11 | | WHEN clicks THEN 9 END, 12 | | CASE 1 WHEN clicks THEN 65 13 | | WHEN 1 THEN 12 END, 14 | | CASE clicks WHEN clicks THEN 1 15 | | WHEN 42 then 77 16 | | ELSE 18 END, 17 | | CASE WHEN 1 > 2 THEN 42 18 | | WHEN 2 > 1 THEN 38 19 | | WHEN CLICKS > DISPLAYS THEN 654 END, 20 | | CASE 5 WHEN 1 THEN 2 21 | | WHEN 3 THEN 4 22 | | ELSE 10 END, 23 | | CASE 5 WHEN 1 THEN 2 24 | | WHEN 3 THEN 4 25 | | END, 26 | | CASE 42 WHEN 
10 THEN 2 27 | | WHEN displays THEN 12 28 | | ELSE 88 29 | | END 30 | | FROM facts 31 | """.stripMargin 32 | 33 | val OPERATOR_QUERY = 34 | """SELECT 35 | |country, 36 | |3 + 2, 37 | |10/5, 38 | |NULL / 2, 39 | |3 * 4, 40 | |5 / 0, 41 | |'Marco' + 'Polo', 42 | |44 - 2, 43 | | -(-(-2)) 44 | |FROM facts 45 | |WHERE (3 > 2) AND 'X' = 'Y' 46 | """.stripMargin 47 | 48 | val JOIN_QUERY = 49 | """SELECT 50 | |f.region, 51 | |f.country, 52 | |f.date, 53 | |c.comment 54 | |FROM facts AS f 55 | |JOIN review AS c ON f.idfacts = c.idfacts AND 2 > 3 56 | """.stripMargin 57 | 58 | val SUB_QUERY= 59 | """SELECT 60 | |x.result, 61 | |(SELECT comment FROM review WHERE idfacts = 42 AND (2 + 2 = 4)) AS comment 62 | |FROM (SELECT (2 + 2) AS result) AS x 63 | |WHERE (SELECT 2 = 2) 64 | """.stripMargin 65 | 66 | val testDb = { 67 | val cols = for { (name, coltype) <- List(("idfacts", INTEGER(false)), 68 | ("REGION", STRING(false)), 69 | ("COUNTRY", STRING(false)), 70 | ("CLICKS", INTEGER(false)), 71 | ("DISPLAYS", INTEGER(false)), 72 | ("DATE", DATE(true)))} 73 | yield Column(name, coltype) 74 | val cols2 = for { (name, coltype) <- List(("idfacts", INTEGER(false)), 75 | ("COMMENT", STRING(false)), 76 | ("COMMENTER", STRING(false)), 77 | ("RATING", INTEGER(true)))} 78 | yield Column(name, coltype) 79 | implicit val dialect = sql99.dialect 80 | DB(List(Schema("", List(Table("facts", cols), Table("review", cols2))))) 81 | } 82 | 83 | val subQuery = VizSQL.parseQuery(SUB_QUERY, testDb) 84 | val caseQuery = VizSQL.parseQuery(CASE_QUERY, testDb) 85 | val opQuery = VizSQL.parseQuery(OPERATOR_QUERY, testDb) 86 | val joinQuery = VizSQL.parseQuery(JOIN_QUERY, testDb) 87 | 88 | val queries = List(subQuery, caseQuery, opQuery, joinQuery) 89 | 90 | "The Optimizer" should "Simplify \"CASE\" statements" in { 91 | caseQuery.isRight shouldBe true 92 | val expected = 93 | """SELECT 94 | |country, 95 | |13, 96 | |CASE 1 WHEN clicks THEN 65 97 | | WHEN 1 THEN 12 END, 98 | |CASE clicks WHEN clicks THEN 1 99 
| | WHEN 42 then 77 100 | | ELSE 18 END, 101 | |38, 102 | |10, 103 | |NULL, 104 | |CASE 42 105 | | WHEN displays THEN 12 106 | | ELSE 88 107 | | END 108 | |FROM facts 109 | """.stripMargin 110 | val query = caseQuery.right.get 111 | val expectedQ = VizSQL.parseQuery(expected, testDb).right.get 112 | Optimizer.optimize(query).sql should be (expectedQ.select.toSQL) 113 | } 114 | 115 | it should "pre-compute literal expressions" in { 116 | val expected = 117 | """SELECT 118 | |country, 119 | |5, 120 | |2, 121 | |NULL, 122 | |12, 123 | |5 / 0, 124 | |'Marco' + 'Polo', 125 | |42, 126 | |-2 127 | |FROM facts 128 | |WHERE FALSE 129 | """.stripMargin 130 | opQuery.isRight shouldBe true 131 | val query = opQuery.right.get 132 | val expectedQ = VizSQL.parseQuery(expected, testDb).right.get 133 | Optimizer.optimize(query).sql should be (expectedQ.select.toSQL) 134 | } 135 | 136 | it should "be idempotent" in { 137 | for {q <- queries } { 138 | q.isRight shouldBe true 139 | val query = q.right.get 140 | val optimized = Optimizer.optimize(query) 141 | val superOptimized = Optimizer.optimize(optimized) 142 | superOptimized should be(optimized) 143 | } 144 | } 145 | 146 | it should "remove unnecessary tables" in { 147 | val QUERY = 148 | """SELECT 149 | |3.0 / 2 150 | |FROM facts 151 | """.stripMargin 152 | val expected = 153 | """SELECT 154 | |1.5 155 | """.stripMargin 156 | val query = VizSQL.parseQuery(QUERY, testDb).right.get 157 | val ref = VizSQL.parseQuery(expected, testDb).right.get 158 | Optimizer.optimize(query).sql should be (ref.select.toSQL) 159 | val QUERY2 = 160 | """SELECT 161 | |r.comment, 162 | |3.0 / 2 163 | |FROM facts as f 164 | |JOIN review as r ON r.idfact = f.idfact 165 | """.stripMargin 166 | val expected2 = 167 | """SELECT 168 | |r.comment, 169 | |1.5 170 | |FROM review as r 171 | """.stripMargin 172 | val query2 = VizSQL.parseQuery(QUERY2, testDb).right.get 173 | val ref2 = VizSQL.parseQuery(expected2, testDb).right.get 174 | 
Optimizer.optimize(query2).sql should be (ref2.select.toSQL) 175 | } 176 | 177 | it should "rewrite sons of expressions it can't optimize" in { 178 | val QUERY = """SELECT 179 | |SUM(3.0 / 2), 180 | |CAST(3 + 2 AS INTEGER) / 2, 181 | |(5 > 3) 182 | |FROM facts 183 | """.stripMargin 184 | val expected = 185 | """SELECT 186 | |SUM(1.5), 187 | |CAST(5 AS INTEGER) / 2, 188 | |TRUE 189 | """.stripMargin 190 | val query = VizSQL.parseQuery(QUERY, testDb).right.get 191 | val ref = VizSQL.parseQuery(expected, testDb).right.get 192 | Optimizer.optimize(query).sql should be (ref.select.toSQL) 193 | } 194 | 195 | it should "optimize joins" in { 196 | val expected = 197 | """SELECT 198 | |f.region, 199 | |f.country, 200 | |f.date, 201 | |c.comment 202 | |FROM facts AS f 203 | |JOIN review AS c ON f.idfacts = c.idfacts AND FALSE 204 | """.stripMargin 205 | val query = joinQuery.right.get 206 | val expectedQ = VizSQL.parseQuery(expected, testDb).right.get 207 | Optimizer.optimize(query).sql should be (expectedQ.select.toSQL) 208 | } 209 | 210 | it should "handle Subselect" in { 211 | val expected = 212 | """SELECT 213 | |x.result, 214 | |(SELECT comment FROM review WHERE idfacts = 42 AND TRUE) AS comment 215 | |FROM (SELECT 4 AS result) AS x 216 | |WHERE (SELECT TRUE) 217 | """.stripMargin 218 | val query = subQuery.right.get 219 | val expectedQ = VizSQL.parseQuery(expected, testDb).right.get 220 | Optimizer.optimize(query).sql should be (expectedQ.select.toSQL) 221 | } 222 | 223 | } 224 | -------------------------------------------------------------------------------- /jvm/src/test/scala/com/criteo/vizatra/vizsql/ParseVerticaDialectSpec.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | import com.criteo.vizatra.vizsql.vertica._ 4 | import org.scalatest.prop.TableDrivenPropertyChecks 5 | import org.scalatest.{EitherValues, Matchers, PropSpec} 6 | 7 | class ParseVerticaDialectSpec extends PropSpec 
with Matchers with EitherValues { 8 | 9 | val validVerticaSelectStatements = TableDrivenPropertyChecks.Table( 10 | ("SQL", "Expected Columns"), 11 | ("""SELECT 12 | | NOW() as now, 13 | | MAX(last_update) + 3599 / 86400 AS last_update, 14 | | CONCAT('The most recent update was on ', TO_CHAR(MAX(last_update) + 3599 / 86400, 'YYYY-MM-DD at HH:MI')) as content 15 | |FROM 16 | | City""".stripMargin, 17 | List( 18 | Column("now", TIMESTAMP(nullable = false)), 19 | Column("last_update", TIMESTAMP(nullable = false)), 20 | Column("content", STRING(nullable = false)) 21 | )) 22 | ) 23 | 24 | // -- 25 | 26 | val SAKILA = DB(schemas = List( 27 | Schema( 28 | "sakila", 29 | tables = List( 30 | Table( 31 | "City", 32 | columns = List( 33 | Column("city_id", INTEGER(nullable = false)), 34 | Column("city", STRING(nullable = false)), 35 | Column("country_id", INTEGER(nullable = false)), 36 | Column("last_update", TIMESTAMP(nullable = false)) 37 | ) 38 | ), 39 | Table( 40 | "Country", 41 | columns = List( 42 | Column("country_id", INTEGER(nullable = false)), 43 | Column("country", STRING(nullable = false)), 44 | Column("last_update", TIMESTAMP(nullable = false)) 45 | ) 46 | ) 47 | ) 48 | ) 49 | )) 50 | 51 | // -- 52 | 53 | property("extract Vertica SELECT statements columns") { 54 | TableDrivenPropertyChecks.forAll(validVerticaSelectStatements) { 55 | case (sql, expectedColumns) => 56 | VizSQL.parseQuery(sql, SAKILA) 57 | .fold(e => sys.error(s"Query doesn't parse: $e"), identity) 58 | .columns 59 | .fold(e => sys.error(s"Invalid query: $e"), identity) should be (expectedColumns) 60 | } 61 | } 62 | 63 | } -------------------------------------------------------------------------------- /jvm/src/test/scala/com/criteo/vizatra/vizsql/ParsingErrorsSpec.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | import sql99._ 4 | import org.scalatest.prop.TableDrivenPropertyChecks 5 | import 
org.scalatest.{Matchers, EitherValues, PropSpec} 6 | 7 | class ParsingErrorsSpec extends PropSpec with Matchers with EitherValues { 8 | 9 | val invalidSQL99SelectStatements = TableDrivenPropertyChecks.Table( 10 | ("SQL", "Expected error"), 11 | 12 | ( 13 | """xxx""", 14 | """|xxx 15 | |^ 16 | |Error: select expected 17 | """ 18 | ), 19 | ( 20 | """select""", 21 | """|select 22 | | ^ 23 | |Error: *, table or expression expected 24 | """ 25 | ), 26 | ( 27 | """select 1 +""", 28 | """|select 1 + 29 | | ^ 30 | |Error: expression expected 31 | """ 32 | ), 33 | ( 34 | """select 1 + *""", 35 | """|select 1 + * 36 | | ^ 37 | |Error: expression expected 38 | """ 39 | ), 40 | ( 41 | """select (1 + 3""", 42 | """|select (1 + 3 43 | | ^ 44 | |Error: ) expected 45 | """ 46 | ), 47 | ( 48 | """select * from""", 49 | """|select * from 50 | | ^ 51 | |Error: table, join or subselect expected 52 | """ 53 | ), 54 | ( 55 | """select * from (selet 1)""", 56 | """|select * from (selet 1) 57 | | ^ 58 | |Error: select expected 59 | """ 60 | ), 61 | ( 62 | """select * from (select 1sh);""", 63 | """|select * from (select 1sh); 64 | | ^ 65 | |Error: ident expected 66 | """ 67 | ), 68 | ( 69 | """select * from (select 1)sh)""", 70 | """|select * from (select 1)sh) 71 | | ^ 72 | |Error: ; expected 73 | """ 74 | ), 75 | ( 76 | """SELECT CustomerName; City FROM Customers;""", 77 | """|SELECT CustomerName; City FROM Customers; 78 | | ^ 79 | |Error: end of statement expected 80 | """ 81 | ), 82 | ( 83 | """SELECT CustomerName FROM Customers UNION ALL""", 84 | """|SELECT CustomerName FROM Customers UNION ALL 85 | | ^ 86 | |Error: select expected 87 | """ 88 | ) 89 | ) 90 | 91 | // -- 92 | 93 | property("report parsing errors on invalid SQL-99 SELECT statements") { 94 | TableDrivenPropertyChecks.forAll(invalidSQL99SelectStatements) { 95 | case (sql, expectedError) => 96 | (new SQL99Parser).parseStatement(sql) 97 | .fold(_.toString(sql, ' ').trim, _ => "[NO ERROR]") should be 
(expectedError.toString.stripMargin.trim) 98 | } 99 | } 100 | 101 | } -------------------------------------------------------------------------------- /jvm/src/test/scala/com/criteo/vizatra/vizsql/SchemaErrorsSpec.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | import org.scalatest.prop.TableDrivenPropertyChecks 4 | import org.scalatest.{EitherValues, Matchers, PropSpec} 5 | import sql99._ 6 | 7 | /** 8 | * Test cases for schema errors 9 | */ 10 | class SchemaErrorsSpec extends PropSpec with Matchers with EitherValues { 11 | 12 | val invalidSQL99SelectStatements = TableDrivenPropertyChecks.Table( 13 | ("SQL", "Expected error"), 14 | ( 15 | "SELECT region from City as C1 JOIN Country as C2 ON C1.country_id = C2.country_id WHERE region < 42", 16 | SchemaError("ambiguous column region", 6) 17 | ),( 18 | "SELECT nonexistent, region from City", 19 | SchemaError("column not found nonexistent", 6) 20 | ) 21 | ) 22 | 23 | // -- 24 | 25 | val SAKILA = DB(schemas = List( 26 | Schema( 27 | "sakila", 28 | tables = List( 29 | Table( 30 | "City", 31 | columns = List( 32 | Column("city_id", INTEGER(nullable = false)), 33 | Column("city", STRING(nullable = false)), 34 | Column("country_id", INTEGER(nullable = false)), 35 | Column("last_update", TIMESTAMP(nullable = false)), 36 | Column("region", INTEGER(nullable = false)) 37 | ) 38 | ), 39 | Table( 40 | "Country", 41 | columns = List( 42 | Column("country_id", INTEGER(nullable = false)), 43 | Column("country", STRING(nullable = false)), 44 | Column("last_update", TIMESTAMP(nullable = false)), 45 | Column("region", INTEGER(nullable = false)) 46 | ) 47 | ) 48 | ) 49 | ) 50 | )) 51 | 52 | // -- 53 | 54 | property("report schema errors on invalid SQL-99 SELECT statements") { 55 | TableDrivenPropertyChecks.forAll(invalidSQL99SelectStatements) { 56 | case (sql, expectedError) => 57 | VizSQL.parseQuery(sql, SAKILA) 58 | .fold( 59 | e => 
sys.error(s"Query doesn't parse: $e"), 60 | _.error.getOrElse(sys.error(s"Query should not type!")) 61 | ) should be (expectedError) 62 | } 63 | } 64 | 65 | } 66 | -------------------------------------------------------------------------------- /jvm/src/test/scala/com/criteo/vizatra/vizsql/ThreadSafetySpec.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | import org.scalatest.prop.TableDrivenPropertyChecks 4 | import org.scalatest.{Matchers, EitherValues, FlatSpec} 5 | 6 | class ThreadSafetySpec extends FlatSpec with Matchers with EitherValues { 7 | 8 | "An SQL parser" should "be thread safe" in { 9 | val uniqueParser = new SQL99Parser 10 | 11 | import scala.util._ 12 | import concurrent._ 13 | import duration._ 14 | import ExecutionContext.Implicits.global 15 | 16 | val randomDataSet = (1 to 10).map { _ => 17 | (1 to 100).map(_ => Random.nextInt(4)).map { 18 | case 0 => 19 | ( 20 | """SELECT 1""", 21 | """ 22 | |SELECT 23 | | 1 24 | """.stripMargin 25 | ) 26 | case 1 => 27 | ( 28 | """select district, sum(population) from city""", 29 | """ 30 | |SELECT 31 | | district, 32 | | SUM(population) 33 | |FROM 34 | | city 35 | """.stripMargin 36 | ) 37 | case 2 => 38 | ( 39 | """select * from City as v join Country as p on v.country_id = p.country_id where city.name like ? AND population > 10000""", 40 | """ 41 | |SELECT 42 | | * 43 | |FROM 44 | | city AS v 45 | | JOIN country AS p 46 | | ON v.country_id = p.country_id 47 | |WHERE 48 | | city.name LIKE ? 
49 | | AND population > 10000 50 | """.stripMargin 51 | ) 52 | case 3 => 53 | ( 54 | """ 55 | SELECT 56 | d.device_name as device_name, 57 | z.description as zone_name, 58 | z.technology as technology, 59 | z.affiliate_name as affiliate_name, 60 | t.country_name as affiliate_country, 61 | t.country_level_1_name as affiliate_region, 62 | z.network_name as network_name, 63 | f.time_id as hour, 64 | CAST(date_trunc('day', f.time_id) as DATE) as day, 65 | CAST(date_trunc('week', f.time_id) as DATE) as week, 66 | CAST(date_trunc('month', f.time_id) as DATE) as month, 67 | CAST(date_trunc('quarter', f.time_id) as DATE) as quarter, 68 | 69 | SUM(displays) as displays, 70 | SUM(clicks) as clicks, 71 | SUM(sales) as sales, 72 | SUM(order_value_euro * r.rate) as order_value, 73 | SUM(revenue_euro * r.rate) as revenue, 74 | SUM(tac_euro * r.rate) as tac, 75 | SUM((revenue_euro - tac_euro) * r.rate) as revenue_ex_tac, 76 | SUM(marketplace_revenue_euro * r.rate) as marketplace_revenue, 77 | SUM((marketplace_revenue_euro - tac_euro) * r.rate) as marketplace_revenue_ex_tac, 78 | ZEROIFNULL(SUM(clicks)/NULLIF(SUM(displays), 0.0)) as ctr, 79 | ZEROIFNULL(SUM(sales)/NULLIF(SUM(clicks), 0.0)) as cr, 80 | ZEROIFNULL(SUM(revenue_euro - tac_euro)/NULLIF(SUM(revenue_euro), 0.0)) as margin, 81 | ZEROIFNULL(SUM(marketplace_revenue_euro - tac_euro)/NULLIF(SUM(marketplace_revenue_euro), 0.0)) as marketplace_margin, 82 | ZEROIFNULL(SUM(revenue_euro * r.rate)/NULLIF(SUM(clicks), 0.0)) as cpc, 83 | ZEROIFNULL(SUM(tac_euro * r.rate)/NULLIF(SUM(displays), 0.0)) * 1000 as cpm 84 | FROM 85 | wopr.fact_zone_device_stats_hourly f 86 | JOIN wopr.dim_zone z 87 | ON z.zone_id = f.zone_id 88 | JOIN wopr.dim_device d 89 | ON d.device_id = f.device_id 90 | JOIN wopr.dim_country t 91 | ON t.country_id = f.affiliate_country_id 92 | JOIN wopr.fact_euro_rates_hourly r 93 | ON r.currency_id = ?currency_id AND f.time_id = r.time_id 94 | 95 | WHERE 96 | CAST(f.time_id AS DATE) between ?[day) 97 | AND 
t.country_code IN ?{publisher_countries} 98 | AND d.device_name IN ?{device_name} 99 | AND z.description IN ?{zone_name} 100 | AND z.technology IN ?{technology} 101 | AND z.affiliate_name IN ?{affiliate_name} 102 | AND t.country_name IN ?{affiliate_country} 103 | AND t.country_level_1_name IN ?{affiliate_region} 104 | AND z.network_name IN ?{network_name} 105 | 106 | GROUP BY ROLLUP(( 107 | d.device_name, 108 | z.description, 109 | z.technology, 110 | z.affiliate_name, 111 | t.country_name, 112 | z.network_name, 113 | t.country_level_1_name, 114 | f.time_id, 115 | CAST(date_trunc('day', f.time_id) as DATE), 116 | CAST(date_trunc('week', f.time_id) as DATE), 117 | CAST(date_trunc('month', f.time_id) as DATE), 118 | CAST(date_trunc('quarter', f.time_id) as DATE) 119 | )) 120 | HAVING SUM(clicks) > 0 121 | """, 122 | """ 123 | |SELECT 124 | | d.device_name AS device_name, 125 | | z.description AS zone_name, 126 | | z.technology AS technology, 127 | | z.affiliate_name AS affiliate_name, 128 | | t.country_name AS affiliate_country, 129 | | t.country_level_1_name AS affiliate_region, 130 | | z.network_name AS network_name, 131 | | f.time_id AS hour, 132 | | CAST(DATE_TRUNC('day', f.time_id) AS DATE) AS day, 133 | | CAST(DATE_TRUNC('week', f.time_id) AS DATE) AS week, 134 | | CAST(DATE_TRUNC('month', f.time_id) AS DATE) AS month, 135 | | CAST(DATE_TRUNC('quarter', f.time_id) AS DATE) AS quarter, 136 | | SUM(displays) AS displays, 137 | | SUM(clicks) AS clicks, 138 | | SUM(sales) AS sales, 139 | | SUM(order_value_euro * r.rate) AS order_value, 140 | | SUM(revenue_euro * r.rate) AS revenue, 141 | | SUM(tac_euro * r.rate) AS tac, 142 | | SUM((revenue_euro - tac_euro) * r.rate) AS revenue_ex_tac, 143 | | SUM(marketplace_revenue_euro * r.rate) AS marketplace_revenue, 144 | | SUM((marketplace_revenue_euro - tac_euro) * r.rate) AS marketplace_revenue_ex_tac, 145 | | ZEROIFNULL(SUM(clicks) / NULLIF(SUM(displays), 0.0)) AS ctr, 146 | | ZEROIFNULL(SUM(sales) / NULLIF(SUM(clicks), 
0.0)) AS cr, 147 | | ZEROIFNULL(SUM(revenue_euro - tac_euro) / NULLIF(SUM(revenue_euro), 0.0)) AS margin, 148 | | ZEROIFNULL(SUM(marketplace_revenue_euro - tac_euro) / NULLIF(SUM(marketplace_revenue_euro), 0.0)) AS marketplace_margin, 149 | | ZEROIFNULL(SUM(revenue_euro * r.rate) / NULLIF(SUM(clicks), 0.0)) AS cpc, 150 | | ZEROIFNULL(SUM(tac_euro * r.rate) / NULLIF(SUM(displays), 0.0)) * 1000 AS cpm 151 | |FROM 152 | | wopr.fact_zone_device_stats_hourly AS f 153 | | JOIN wopr.dim_zone AS z 154 | | ON z.zone_id = f.zone_id 155 | | JOIN wopr.dim_device AS d 156 | | ON d.device_id = f.device_id 157 | | JOIN wopr.dim_country AS t 158 | | ON t.country_id = f.affiliate_country_id 159 | | JOIN wopr.fact_euro_rates_hourly AS r 160 | | ON r.currency_id = ?currency_id 161 | | AND f.time_id = r.time_id 162 | |WHERE 163 | | CAST(f.time_id AS DATE) BETWEEN ?day 164 | | AND t.country_code IN ?publisher_countries 165 | | AND d.device_name IN ?device_name 166 | | AND z.description IN ?zone_name 167 | | AND z.technology IN ?technology 168 | | AND z.affiliate_name IN ?affiliate_name 169 | | AND t.country_name IN ?affiliate_country 170 | | AND t.country_level_1_name IN ?affiliate_region 171 | | AND z.network_name IN ?network_name 172 | |GROUP BY 173 | | ROLLUP( 174 | | ( 175 | | d.device_name, 176 | | z.description, 177 | | z.technology, 178 | | z.affiliate_name, 179 | | t.country_name, 180 | | z.network_name, 181 | | t.country_level_1_name, 182 | | f.time_id, 183 | | CAST(DATE_TRUNC('day', f.time_id) AS DATE), 184 | | CAST(DATE_TRUNC('week', f.time_id) AS DATE), 185 | | CAST(DATE_TRUNC('month', f.time_id) AS DATE), 186 | | CAST(DATE_TRUNC('quarter', f.time_id) AS DATE) 187 | | ) 188 | | ) 189 | |HAVING 190 | | SUM(clicks) > 0 191 | """.stripMargin 192 | ) 193 | } 194 | } 195 | 196 | val eventuallyResult = Future.sequence { 197 | randomDataSet.zipWithIndex.map { 198 | case (queries, threadX) => 199 | Future { 200 | blocking { 201 | queries.zipWithIndex.map { 202 | case ((sql, 
expectedOutput), itemX) => 203 | uniqueParser 204 | .parseStatement(sql) 205 | .fold(e => sys.error(s"\n\n${e.toString(sql)}\n"), identity) 206 | .toSQL == expectedOutput.stripMargin.trim 207 | } 208 | } 209 | } 210 | } 211 | } 212 | 213 | Await.result(eventuallyResult, 5 minutes).flatten.forall(identity) shouldBe(true) 214 | } 215 | 216 | } 217 | -------------------------------------------------------------------------------- /jvm/src/test/scala/com/criteo/vizatra/vizsql/TypingErrorsSpec.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | import sql99._ 4 | import org.scalatest.prop.TableDrivenPropertyChecks 5 | import org.scalatest.{Matchers, EitherValues, PropSpec} 6 | 7 | class TypingErrorsSpec extends PropSpec with Matchers with EitherValues { 8 | 9 | val invalidSQL99SelectStatements = TableDrivenPropertyChecks.Table( 10 | ("SQL", "Expected error"), 11 | 12 | ( 13 | """select ?""", 14 | """|select ? 15 | | ^ 16 | |Error: parameter type cannot be inferred from context 17 | """ 18 | ), 19 | ( 20 | """select ? = ?""", 21 | """|select ? = ? 22 | | ^ 23 | |Error: parameter type cannot be inferred from context 24 | """ 25 | ), 26 | ( 27 | """select ? = plop""", 28 | """|select ? = plop 29 | | ^ 30 | |Error: column not found plop 31 | """ 32 | ), 33 | ( 34 | """select YOLO(88)""", 35 | """|select YOLO(88) 36 | | ^ 37 | |Error: unknown function yolo 38 | """ 39 | ), 40 | ( 41 | """select min(12, 13)""", 42 | """|select min(12, 13) 43 | | ^ 44 | |Error: too many arguments 45 | """ 46 | ), 47 | ( 48 | """select max()""", 49 | """|select max() 50 | | ^ 51 | |Error: expected argument 52 | """ 53 | ), 54 | ( 55 | """select ? between ? and ?""", 56 | """|select ? between ? and ? 57 | | ^ 58 | |Error: parameter type cannot be inferred from context 59 | """ 60 | ), 61 | ( 62 | """select ? between ? and coco""", 63 | """|select ? between ? 
and coco 64 | | ^ 65 | |Error: column not found coco 66 | """ 67 | ), 68 | ( 69 | """select ? in (?)""", 70 | """|select ? in (?) 71 | | ^ 72 | |Error: parameter type cannot be inferred from context 73 | """ 74 | ), 75 | ( 76 | """select case when city_id = 1 then 'foo' else 1 end from City""", 77 | """|select case when city_id = 1 then 'foo' else 1 end from City 78 | | ^ 79 | |Error: expected string, found integer 80 | """ 81 | ), 82 | ( 83 | """select 1 = true""", 84 | """|select 1 = true 85 | | ^ 86 | |Error: expected integer, found boolean 87 | """ 88 | ), 89 | ( 90 | """select * from (select 1 a union select true a) x""", 91 | """|select * from (select 1 a union select true a) x 92 | | ^ 93 | |Error: expected integer, found boolean for column a 94 | """ 95 | ), 96 | ( 97 | """select * from (select 1, 2 union select 1) x""", 98 | """|select * from (select 1, 2 union select 1) x 99 | | ^ 100 | |Error: expected same number of columns on both sides of the union 101 | """ 102 | ) 103 | ) 104 | 105 | // -- 106 | 107 | val SAKILA = DB(schemas = List( 108 | Schema( 109 | "sakila", 110 | tables = List( 111 | Table( 112 | "City", 113 | columns = List( 114 | Column("city_id", INTEGER(nullable = false)), 115 | Column("city", STRING(nullable = false)), 116 | Column("country_id", INTEGER(nullable = false)), 117 | Column("last_update", TIMESTAMP(nullable = false)) 118 | ) 119 | ), 120 | Table( 121 | "Country", 122 | columns = List( 123 | Column("country_id", INTEGER(nullable = false)), 124 | Column("country", STRING(nullable = false)), 125 | Column("last_update", TIMESTAMP(nullable = false)) 126 | ) 127 | ) 128 | ) 129 | ) 130 | )) 131 | 132 | // -- 133 | 134 | property("report typing errors on invalid SQL-99 SELECT statements") { 135 | TableDrivenPropertyChecks.forAll(invalidSQL99SelectStatements) { 136 | case (sql, expectedError) => 137 | VizSQL.parseQuery(sql, SAKILA) 138 | .fold( 139 | e => sys.error(s"Query doesn't parse: $e"), 140 | _.error.map(_.toString(sql, ' 
')).getOrElse(sys.error(s"Query should not type!")) 141 | ) should be (expectedError.stripMargin.trim) 142 | } 143 | } 144 | 145 | } -------------------------------------------------------------------------------- /jvm/src/test/scala/com/criteo/vizatra/vizsql/hive/HiveParsingErrorsSpec.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.hive 2 | 3 | import org.scalatest.prop.TableDrivenPropertyChecks 4 | import org.scalatest.{EitherValues, Matchers, PropSpec} 5 | 6 | class HiveParsingErrorsSpec extends PropSpec with Matchers with EitherValues { 7 | 8 | val invalidSQL99SelectStatements = TableDrivenPropertyChecks.Table( 9 | ("SQL", "Expected error"), 10 | 11 | ( 12 | "select bucket from t", 13 | """select bucket from t 14 | | ^ 15 | |Error: *, table or expression expected 16 | """ 17 | ), 18 | ( 19 | "select foo from tbl limit 100 order by foo", 20 | """select foo from tbl limit 100 order by foo 21 | | ^ 22 | |Error: ; expected 23 | """ 24 | ), 25 | ( 26 | "select foo from bar tablesample (bucket 2 out af 3)", 27 | """select foo from bar tablesample (bucket 2 out af 3) 28 | | ^ 29 | |Error: of expected 30 | """.stripMargin 31 | ) 32 | ) 33 | 34 | // -- 35 | 36 | property("report parsing errors on invalid Hive statements") { 37 | TableDrivenPropertyChecks.forAll(invalidSQL99SelectStatements) { 38 | case (sql, expectedError) => 39 | new HiveDialect(Map.empty).parser.parseStatement(sql) 40 | .fold(_.toString(sql, ' ').trim, _ => "[NO ERROR]") should be (expectedError.toString.stripMargin.trim) 41 | } 42 | } 43 | 44 | } 45 | -------------------------------------------------------------------------------- /jvm/src/test/scala/com/criteo/vizatra/vizsql/hive/HiveTypeParserSpec.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.hive 2 | 3 | import com.criteo.vizatra.vizsql._ 4 | import 
org.scalatest.prop.TableDrivenPropertyChecks 5 | import org.scalatest.{Matchers, PropSpec} 6 | 7 | class HiveTypeParserSpec extends PropSpec with Matchers { 8 | 9 | val types = TableDrivenPropertyChecks.Table( 10 | ("Type string", "Expected type"), 11 | 12 | ("double", DECIMAL(true)), 13 | ("array", HiveArray(INTEGER(true))), 14 | ("map>", HiveMap(STRING(true), HiveStruct(List( 15 | Column("a", BOOLEAN(true)), 16 | Column("b", TIMESTAMP(true)) 17 | )))), 18 | ("array>,bar:array,`baz`:double>>", HiveArray(HiveStruct(List( 19 | Column("foo", HiveArray(HiveMap(STRING(true), STRING(true)))), 20 | Column("bar", HiveArray(INTEGER(true))), 21 | Column("baz", DECIMAL(true)) 22 | )))), 23 | ("struct", HiveStruct(List( 24 | Column("timestamp", INTEGER(true)) 25 | ))) 26 | ) 27 | 28 | // -- 29 | 30 | property("parse to correct types") { 31 | val parser = new TypeParser 32 | TableDrivenPropertyChecks.forAll(types) { 33 | case (typeString, expectedType) => 34 | parser.parseType(typeString) shouldEqual Right(expectedType) 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /jvm/src/test/scala/com/criteo/vizatra/vizsql/hive/ParseHiveQuerySpec.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.hive 2 | 3 | import com.criteo.vizatra.vizsql._ 4 | import org.scalatest.prop.TableDrivenPropertyChecks 5 | import org.scalatest.{Matchers, PropSpec} 6 | 7 | class ParseHiveQuerySpec extends PropSpec with Matchers { 8 | val queries = TableDrivenPropertyChecks.Table( 9 | ("Valid Hive query", "Expected AST"), 10 | 11 | ("select a.b[2].c.d.e['foo'].f.g.h.i[0] from t", 12 | SimpleSelect( 13 | projections = List(ExpressionProjection( 14 | MapOrArrayAccessExpression( 15 | StructAccessExpr( 16 | StructAccessExpr( 17 | StructAccessExpr( 18 | StructAccessExpr( 19 | MapOrArrayAccessExpression( 20 | StructAccessExpr( 21 | StructAccessExpr( 22 | StructAccessExpr( 23 | 
MapOrArrayAccessExpression( 24 | ColumnOrStructAccessExpression(ColumnIdent("b", Some(TableIdent("a")))), 25 | LiteralExpression(IntegerLiteral(2))), 26 | "c"), 27 | "d"), 28 | "e"), 29 | LiteralExpression(StringLiteral("foo"))), 30 | "f"), 31 | "g"), 32 | "h"), 33 | "i"), 34 | LiteralExpression(IntegerLiteral(0))) 35 | )), 36 | relations = List(SingleTableRelation(TableIdent("t"))) 37 | ) 38 | ), 39 | ("select a, b from table x lateral view explode (col) y as z", 40 | SimpleSelect( 41 | projections = List( 42 | ExpressionProjection(ColumnOrStructAccessExpression(ColumnIdent("a"))), 43 | ExpressionProjection(ColumnOrStructAccessExpression(ColumnIdent("b"))) 44 | ), 45 | relations = List(LateralView( 46 | inner = SingleTableRelation(TableIdent("table"), Some("x")), 47 | explodeFunction = FunctionCallExpression( 48 | name = "explode", 49 | distinct = None, 50 | args = List(ColumnOrStructAccessExpression(ColumnIdent("col")))), 51 | tableAlias = "y", 52 | columnAliases = List("z") 53 | )) 54 | ) 55 | ), 56 | ("select * from ta a left semi join tb b on a.id = b.id", 57 | SimpleSelect( 58 | projections = List(AllColumns), 59 | relations = List(JoinRelation( 60 | left = SingleTableRelation(TableIdent("ta"), Some("a")), 61 | join = LeftSemiJoin, 62 | right = SingleTableRelation(TableIdent("tb"), Some("b")), 63 | on = Some(ComparisonExpression( 64 | op = "=", 65 | left = ColumnOrStructAccessExpression(ColumnIdent("id", Some(TableIdent("a")))), 66 | right = ColumnOrStructAccessExpression(ColumnIdent("id", Some(TableIdent("b")))) 67 | )) 68 | )) 69 | ) 70 | ), 71 | ("select derp(a) from t", 72 | SimpleSelect( 73 | projections = List(ExpressionProjection( 74 | FunctionCallExpression( 75 | name = "derp", 76 | distinct = None, 77 | args = ColumnOrStructAccessExpression(ColumnIdent("a")) :: Nil 78 | ) 79 | )), 80 | relations = List(SingleTableRelation(TableIdent("t"))) 81 | ) 82 | ), 83 | ("select `select` s, `from` f from `join` j order by `select` desc cluster by `from` limit 
100", 84 | //FIXME limit is lost, and cluster by is clumped with order by 85 | SimpleSelect( 86 | projections = List( 87 | ExpressionProjection(ColumnOrStructAccessExpression(ColumnIdent("select")), Some("s")), 88 | ExpressionProjection(ColumnOrStructAccessExpression(ColumnIdent("from")), Some("f")) 89 | ), 90 | relations = List(SingleTableRelation(TableIdent("join"), Some("j"))), 91 | orderBy = List( 92 | SortExpression(ColumnOrStructAccessExpression(ColumnIdent("select")), Some(SortDESC)), 93 | SortExpression(ColumnOrStructAccessExpression(ColumnIdent("from")), None) 94 | ), 95 | limit = Some(IntegerLiteral(100)) 96 | ) 97 | ), 98 | ("select foo from bar tablesample (bucket 2 out of 3 on baz)", 99 | //FIXME tablesample is lost 100 | SimpleSelect( 101 | projections = List(ExpressionProjection (ColumnOrStructAccessExpression(ColumnIdent("foo")))), 102 | relations = List(SingleTableRelation(TableIdent("bar"))) 103 | ) 104 | ) 105 | ) 106 | 107 | // -- 108 | 109 | property("parse query") { 110 | TableDrivenPropertyChecks.forAll(queries) { case (query, ast) => 111 | new HiveDialect(Map.empty).parser.parseStatement(query) match { 112 | case Left(err) => fail(err.toString(query)) 113 | case Right(result) => result shouldEqual ast 114 | } 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("org.scala-js" % "sbt-scalajs" % "0.6.13") 2 | addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "1.1") 3 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0") 4 | -------------------------------------------------------------------------------- /shared/src/main/scala/com/criteo/vizatra/vizsql/Dialect.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | trait Dialect { 4 | def parser: SQLParser 5 | def functions: 
PartialFunction[String,SQLFunction] 6 | } 7 | -------------------------------------------------------------------------------- /shared/src/main/scala/com/criteo/vizatra/vizsql/Errors.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | trait Err { 4 | val msg: String 5 | val pos: Int 6 | 7 | def toString(sql: String, lineChar: Char = '~') = { 8 | val carretPosition = sql.size - sql.drop(pos).dropWhile(_.toString.matches("""\s""")).size 9 | 10 | sql.split('\n').foldLeft("") { 11 | case (sql, line) if sql.size <= carretPosition && (sql.size + line.size) >= carretPosition => { 12 | val carret = (1 to (carretPosition - sql.size)).map(_ => lineChar).mkString + "^" 13 | s"$sql\n$line\n$carret\nError: $msg\n" 14 | } 15 | case (sql, line) => s"$sql\n$line" 16 | }.split('\n').dropWhile(_.isEmpty).reverse.dropWhile(_.isEmpty).reverse.mkString("\n") 17 | } 18 | 19 | def combine(err: Err): Err = this 20 | } 21 | 22 | case class PlaceholderError(msg: String, pos: Int) extends Err { 23 | override def combine(err: Err) = err match { 24 | case SchemaError(_, _) | TypeError(_, _) => err 25 | case _ => this 26 | } 27 | } 28 | case class ParsingError(msg: String, pos: Int) extends Err 29 | case class SchemaError(msg: String, pos: Int) extends Err 30 | case class TypeError(msg: String, pos: Int) extends Err 31 | case class SQLError(msg: String, pos: Int) extends Err 32 | case class ParameterError(msg: String, pos: Int) extends Err 33 | case object NoError extends Err { 34 | val msg = "?WAT" 35 | val pos = 0 36 | override def combine(err: Err) = err 37 | } -------------------------------------------------------------------------------- /shared/src/main/scala/com/criteo/vizatra/vizsql/EvalHelper.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | import java.util.Date 4 | 5 | object EvalHelper { 6 | 7 | def applyOp[T](op: 
String, x : T, y : T)(implicit num : Numeric[T]) : T = op match { 8 | case "+" => num.plus(x, y) 9 | case "-" => num.minus(x, y) 10 | case "*" => num.times(x, y) 11 | // Numeric doesn't handle division 12 | } 13 | 14 | val ops = Set("+", "-", "*") 15 | 16 | def mathEval(op : String, x : Literal, y : Literal) : Either[Literal, Expression] = (x, y) match { 17 | case (DecimalLiteral(a), DecimalLiteral(b)) if op == "/" && (b != 0) => Left(DecimalLiteral(a / b)) 18 | case (IntegerLiteral(a), DecimalLiteral(b)) if op == "/" && (b != 0) => Left(DecimalLiteral(a / b)) 19 | case (DecimalLiteral(a), IntegerLiteral(b)) if op == "/" && (b != 0) => Left(DecimalLiteral(a / b)) 20 | case (IntegerLiteral(a), IntegerLiteral(b)) if op == "/" && (b != 0) && (a % b == 0) => Left(IntegerLiteral(a/b)) 21 | 22 | case (DecimalLiteral(a), DecimalLiteral(b)) if ops contains op => Left(DecimalLiteral(applyOp(op, a,b))) 23 | case (IntegerLiteral(a), DecimalLiteral(b)) if ops contains op => Left(DecimalLiteral(applyOp(op, a,b))) 24 | case (DecimalLiteral(a), IntegerLiteral(b)) if ops contains op => Left(DecimalLiteral(applyOp(op, a,b))) 25 | case (IntegerLiteral(a), IntegerLiteral(b)) if ops contains op => Left(IntegerLiteral(applyOp(op, a,b))) 26 | case (NullLiteral, _) => Left(NullLiteral) 27 | case (_, NullLiteral) => Left(NullLiteral) 28 | case _ => Right(MathExpression(op, LiteralExpression(x), LiteralExpression(y))) 29 | } 30 | 31 | def compEval(op : String, x : Literal, y : Literal) : Either[Literal, Expression] = (x, y) match { 32 | case (DecimalLiteral(a), DecimalLiteral(b)) => Left(bool2Literal(compOperators(op)(a, b))) 33 | case (IntegerLiteral(a), DecimalLiteral(b)) => Left(bool2Literal(compOperators(op)(a, b))) 34 | case (DecimalLiteral(a), IntegerLiteral(b)) => Left(bool2Literal(compOperators(op)(a, b))) 35 | case (IntegerLiteral(a), IntegerLiteral(b)) => Left(bool2Literal(compOperators(op)(a, b))) 36 | case (StringLiteral(a), StringLiteral(b)) => 
Left(bool2Literal(compOperators(op)(a, b))) 37 | case (NullLiteral, _) => Left(NullLiteral) 38 | case (_, NullLiteral) => Left(NullLiteral) 39 | case (l, r) if (l == TrueLiteral || l== FalseLiteral) && 40 | (r == TrueLiteral || r == FalseLiteral) => Left(bool2Literal(compOperators(op)(literal2bool(l), literal2bool(r)))) 41 | case _ => Right(ComparisonExpression(op, LiteralExpression(x), LiteralExpression(y))) 42 | } 43 | 44 | def compare[T<: Any](v1 : T, v2: T) : Int = (v1, v2) match { 45 | case (x : Ordered[_], y) => x.asInstanceOf[Ordered[Any]] compareTo y.asInstanceOf[Any] 46 | case (x : Number, y : Number) => x.doubleValue().compare(y.doubleValue()) 47 | case (x : String, y : String) => x.compare(y) 48 | case (x : Date, y: Date) => x compareTo y 49 | case (x : Boolean, y : Boolean ) => x compareTo y 50 | case (x, y) if x == y => 0 51 | } 52 | 53 | def compOperators = Map( 54 | ">" -> ((a : Any, b : Any) => compare(a, b) > 0), 55 | "<" -> ((a : Any, b : Any) => compare(a, b) < 0), 56 | ">=" -> ((a : Any, b : Any) => compare(a, b) >= 0), 57 | "<=" -> ((a : Any, b : Any) => compare(a, b) <= 0), 58 | "=" -> ((a : Any, b : Any) => a == b), 59 | "<>"-> ((a : Any, b : Any) => a != b) 60 | ) 61 | 62 | def bool2Literal(b : Boolean) = if (b) TrueLiteral else FalseLiteral 63 | def literal2bool(b: Literal) = (b == TrueLiteral) 64 | } 65 | -------------------------------------------------------------------------------- /shared/src/main/scala/com/criteo/vizatra/vizsql/Functions.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | trait SQLFunction { 4 | def getPlaceholders(call: FunctionCallExpression, db: DB, expectedType: Option[Type]) = 5 | Utils.allPlaceholders[Expression](call.args, _.getPlaceholders(db, expectedType)) 6 | def resultType(call: FunctionCallExpression, db: DB, placeholders: Placeholders): Either[Err,Type] 7 | def getArguments(call: FunctionCallExpression, db: DB, 
placeholders: Placeholders, nb: Int): Either[Err,List[Type]] = { 8 | if(nb >= 0 && nb > call.args.size) { 9 | Left(TypeError("expected argument", call.pos + call.name.size + 1)) 10 | } 11 | else if(nb >= 0 && call.args.size > nb) { 12 | Left(TypeError("too many arguments", call.args.drop(nb).head.pos)) 13 | } 14 | else { 15 | call.args.foldRight(Right(Nil):Either[Err,List[Type]]) { 16 | (arg, acc) => for(a <- acc.right; b <- arg.resultType(db, placeholders).right) yield b :: a 17 | } 18 | } 19 | } 20 | } 21 | 22 | trait SQLFunction0 extends SQLFunction { 23 | def resultType(call: FunctionCallExpression, db: DB, placeholders: Placeholders) = 24 | getArguments(call, db, placeholders, 0).right.flatMap { _ => 25 | result 26 | } 27 | def result: Either[Err,Type] 28 | } 29 | 30 | trait SQLFunction1 extends SQLFunction { 31 | def resultType(call: FunctionCallExpression, db: DB, placeholders: Placeholders) = 32 | getArguments(call, db, placeholders, 1).right.flatMap { types => 33 | result(call.args(0) -> types(0)) 34 | } 35 | def result: PartialFunction[(Expression,Type),Either[Err,Type]] 36 | } 37 | 38 | trait SQLFunction2 extends SQLFunction { 39 | def resultType(call: FunctionCallExpression, db: DB, placeholders: Placeholders) = 40 | getArguments(call, db, placeholders, 2).right.flatMap { types => 41 | result((call.args(0) -> types(0), call.args(1) -> types(1))) 42 | } 43 | def result: PartialFunction[((Expression,Type),(Expression,Type)),Either[Err,Type]] 44 | } 45 | 46 | trait SQLFunction3 extends SQLFunction { 47 | def resultType(call: FunctionCallExpression, db: DB, placeholders: Placeholders) = 48 | getArguments(call, db, placeholders, 3).right.flatMap { types => 49 | result((call.args(0) -> types(0), call.args(1) -> types(1), call.args(2) -> types(2))) 50 | } 51 | def result: PartialFunction[((Expression,Type),(Expression,Type),(Expression,Type)),Either[Err,Type]] 52 | } 53 | 54 | trait SQLFunctionX extends SQLFunction { 55 | def resultType(call: 
FunctionCallExpression, db: DB, placeholders: Placeholders) = 56 | getArguments(call, db, placeholders, -1).right.flatMap { types => 57 | result(call.args.zip(types)) 58 | } 59 | def result: PartialFunction[List[(Expression,Type)],Either[Err,Type]] 60 | } 61 | 62 | object SQLFunction { 63 | 64 | def standard: PartialFunction[String,SQLFunction] = { 65 | case "min" | "max" => new SQLFunction1 { 66 | def result = { case (_, t) => Right(t) } 67 | } 68 | case "avg" | "sum" => new SQLFunction1 { 69 | def result = { 70 | case (_, t @ (INTEGER(_) | DECIMAL(_))) => Right(t) 71 | case (arg, _) => Left(TypeError("expected numeric argument", arg.pos)) 72 | } 73 | } 74 | case "now" => new SQLFunction0 { 75 | def result = Right(TIMESTAMP()) 76 | } 77 | case "concat" => new SQLFunctionX { 78 | def result = { 79 | case (_, t1) :: _ => Right(STRING(t1.nullable)) 80 | } 81 | } 82 | case "coalesce" => new SQLFunctionX { 83 | def result = { 84 | case (_, t1) :: tail => 85 | tail.foldLeft[Either[Err, Type]](Right(t1)) { 86 | case (Right(t), (curExpr, curType)) => 87 | Utils.parentType(t, curType, TypeError(s"expected ${t.show}, got ${curType.show}", curExpr.pos)) 88 | case (Left(err), _) => 89 | Left(err) 90 | } 91 | } 92 | } 93 | case "count" => new SQLFunctionX { 94 | override def result = { 95 | case _ :: _ => Right(INTEGER(false)) 96 | } 97 | } 98 | } 99 | 100 | } 101 | -------------------------------------------------------------------------------- /shared/src/main/scala/com/criteo/vizatra/vizsql/Parser.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | import scala.util.parsing.combinator._ 4 | import scala.util.parsing.combinator.lexical._ 5 | import scala.util.parsing.combinator.syntactical._ 6 | import scala.util.parsing.input.CharArrayReader.EofCh 7 | 8 | trait SQLParser { 9 | def parseStatement(sql: String): Either[Err,Statement] 10 | } 11 | 12 | object SQL99Parser { 13 | val keywords = Set( 
14 | "all", "and", "as", "asc", "between", "by", "case", "cast", "cross", "cube", 15 | "desc", "distinct", "else", "end", "exists", "false", "from", "full", "group", "grouping", 16 | "having", "in", "inner", "is", "join", "left", "like", "limit", 17 | "not", "null", "on", "or", "order", "outer", "right", "rollup", "select", 18 | "sets", "then", "true", "union", "unknown", "when", "where" 19 | ) 20 | 21 | val delimiters = Set( 22 | "(", ")", "\"", 23 | "'", "%", "&", 24 | "*", "/", "+", 25 | "-", ",", ".", 26 | ":", ";", "<", 27 | ">", "?", "[", 28 | "]", "_", "|", 29 | "=", "{", "}", 30 | "^", "??(", "??)", 31 | "<>", ">=", "<=", 32 | "||", "->", "=>" 33 | ) 34 | val comparisonOperators = Set("=", "<>", "<", ">", ">=", "<=") 35 | 36 | val likeOperators = Set("like") 37 | 38 | val orOperators = Set("or") 39 | 40 | val andOperators = Set("and") 41 | 42 | val additionOperators = Set("+", "-") 43 | 44 | val multiplicationOperators = Set("*", "/") 45 | 46 | val unaryOperators = Set("-", "+") 47 | 48 | val typeMap: Map[String, TypeLiteral] = Map( 49 | "timestamp" -> TimestampTypeLiteral, 50 | "datetime" -> TimestampTypeLiteral, 51 | "date" -> DateTypeLiteral, 52 | "boolean" -> BooleanTypeLiteral, 53 | "varchar" -> VarcharTypeLiteral, 54 | "integer" -> IntegerTypeLiteral, 55 | "numeric" -> NumericTypeLiteral, 56 | "decimal" -> DecimalTypeLiteral, 57 | "real" -> RealTypeLiteral 58 | ) 59 | } 60 | 61 | class SQL99Parser extends SQLParser with TokenParsers with PackratParsers { 62 | 63 | type Tokens = SQLTokens 64 | 65 | trait SQLTokens extends token.Tokens { 66 | case class Keyword(chars: String) extends Token 67 | case class Identifier(chars: String) extends Token 68 | case class IntegerLit(chars: String) extends Token 69 | case class DecimalLit(chars: String) extends Token 70 | case class StringLit(chars: String) extends Token 71 | case class Delimiter(chars: String) extends Token 72 | } 73 | 74 | class SQLLexical extends Lexical with SQLTokens { 75 | 76 | lazy val 
whitespace = rep(whitespaceChar | blockComment | lineComment) 77 | 78 | val keywords = SQL99Parser.keywords 79 | val delimiters = SQL99Parser.delimiters 80 | 81 | lazy val blockComment = '/' ~ '*' ~ rep(chrExcept('*', EofCh) | ('*' ~ chrExcept('/', EofCh))) ~ '*' ~ '/' 82 | lazy val lineComment = '-' ~ '-' ~ rep(chrExcept('\n', '\r', EofCh)) 83 | 84 | lazy val delimiter = { 85 | delimiters.toList.sorted.map(s => accept(s.toList) ^^^ Delimiter(s)).foldRight( 86 | failure("no matching delimiter"): Parser[Token] 87 | )((x, y) => y | x) 88 | } 89 | 90 | override lazy val letter = elem("letter", _.isLetter) 91 | lazy val identifierOrKeyword = letter ~ rep(letter | digit | '_') ^^ { 92 | case first ~ rest => 93 | val chars = first :: rest mkString "" 94 | if(keywords.contains(chars.toLowerCase)) Keyword(chars.toLowerCase) else Identifier(chars) 95 | } 96 | 97 | def customToken: Parser[Token] = elem("should not exist", _ => false) ^^ { _ => sys.error("custom token should not exist")} 98 | 99 | lazy val exp = (accept('e') | accept('E')) ~ opt('-') ~ rep1(digit) ^^ { case e ~ n ~ d => e.toString + n.mkString + d.mkString} 100 | 101 | def quoted(delimiter: Char) = 102 | delimiter ~> rep( 103 | ( (delimiter ~ delimiter) ^^^ delimiter 104 | | chrExcept(delimiter, '\n', EofCh) 105 | ) 106 | ) <~ delimiter 107 | 108 | lazy val token = 109 | ( customToken 110 | | identifierOrKeyword 111 | | rep1(digit) ~ ('.' ~> rep(digit)) ~ opt(exp) ^^ { case i ~ d ~ mE => DecimalLit(i.mkString + "." + d.mkString + mE.mkString) } 112 | | rep1(digit) ~ exp ^^ { case i ~ e => DecimalLit(i.mkString + e) } 113 | | '.' ~> rep1(digit) ~ opt(exp) ^^ { case d ~ mE => DecimalLit("." 
+ d.mkString + mE.mkString) } 114 | | rep1(digit) ^^ { case i => IntegerLit(i.mkString) } 115 | | quoted('\'') ^^ { case chars => StringLit(chars mkString "") } 116 | | quoted('\"') ^^ { case chars => Identifier(chars mkString "") } 117 | | EofCh ^^^ EOF 118 | | '\'' ~> failure("unclosed string literal") 119 | | '\"' ~> failure("unclosed identifier") 120 | | delimiter 121 | | failure("illegal character") 122 | ) 123 | 124 | } 125 | 126 | override val lexical = new SQLLexical 127 | 128 | implicit def stringLiteralToKeywordOrDelimiter(chars: String) = { 129 | if(lexical.keywords.contains(chars) || lexical.delimiters.contains(chars)) { 130 | (accept(lexical.Keyword(chars)) | accept(lexical.Delimiter(chars))) ^^ (_.chars) withFailureMessage(s"$chars expected") 131 | } 132 | else { 133 | sys.error(s"""!!! Invalid parser definition: $chars is not a valid SQL keyword or delimiter""") 134 | } 135 | } 136 | 137 | // -- 138 | 139 | lazy val ident = 140 | elem("ident", _.isInstanceOf[lexical.Identifier]) ^^ (_.chars) 141 | 142 | lazy val booleanLiteral = 143 | ( "true" ^^^ TrueLiteral 144 | | "false" ^^^ FalseLiteral 145 | | "unknown" ^^^ UnknownLiteral 146 | ) 147 | 148 | lazy val nullLiteral = 149 | "null" ^^^ NullLiteral 150 | 151 | lazy val stringLiteral = 152 | elem("string", _.isInstanceOf[lexical.StringLit]) ^^ { v => StringLiteral(v.chars) } 153 | 154 | lazy val integerLiteral = 155 | elem("integer", _.isInstanceOf[lexical.IntegerLit]) ^^ { v => IntegerLiteral(v.chars.toLong) } 156 | 157 | lazy val decimalLiteral = 158 | elem("decimal", _.isInstanceOf[lexical.DecimalLit]) ^^ { v => DecimalLiteral(v.chars.toDouble) } 159 | 160 | lazy val literal = pos( 161 | ( decimalLiteral 162 | | integerLiteral 163 | | stringLiteral 164 | | booleanLiteral 165 | | nullLiteral 166 | ) 167 | ) 168 | 169 | lazy val typeLiteral: Parser[TypeLiteral] = 170 | elem("type literal", { t => typeMap.contains(t.chars.toLowerCase) }) ^^ { t => typeMap(t.chars.toLowerCase) } 171 | 172 | val typeMap 
= SQL99Parser.typeMap 173 | 174 | lazy val column = pos( 175 | ( ident ~ "." ~ ident ~ "." ~ ident ^^ { case s ~ _ ~ t ~ _ ~ c => ColumnIdent(c, Some(TableIdent(t, Some(s)))) } 176 | | ident ~ "." ~ ident ^^ { case t ~ _ ~ c => ColumnIdent(c, Some(TableIdent(t, None))) } 177 | | ident ^^ { case c => ColumnIdent(c, None) } 178 | ) 179 | ) 180 | 181 | lazy val table = 182 | ( ident ~ "." ~ ident ^^ { case s ~ _ ~ t => TableIdent(t, Some(s)) } 183 | | ident ^^ { case t => TableIdent(t, None) } 184 | ) 185 | 186 | lazy val function = 187 | ident ~ ("(" ~> opt(distinct) ~ repsep(expr, ",") <~ ")") ^^ { case n ~ (d ~ a) => FunctionCallExpression(n.toLowerCase, d, a) } 188 | 189 | lazy val countStar = 190 | elem("count", _.chars.toLowerCase == "count") ~ "(" ~ "*" ~ ")" ^^^ CountStarExpression 191 | 192 | lazy val or = (precExpr: Parser[Expression]) => 193 | precExpr * 194 | orOperators.map { op => op ^^^ OrExpression.operator(op) _ }.reduce(_ | _) 195 | 196 | val orOperators = SQL99Parser.orOperators 197 | 198 | lazy val and = (precExpr: Parser[Expression]) => 199 | precExpr * 200 | andOperators.map { op => op ^^^ AndExpression.operator(op) _ }.reduce(_ | _) 201 | 202 | val andOperators = SQL99Parser.andOperators 203 | 204 | lazy val not = (precExpr: Parser[Expression]) => { 205 | def thisExpr: Parser[Expression] = 206 | ( "not" ~> thisExpr ^^ NotExpression 207 | | precExpr 208 | ) 209 | thisExpr 210 | } 211 | 212 | lazy val exists = (precExpr: Parser[Expression]) => 213 | ( "exists" ~> "(" ~> select <~ ")" ^^ ExistsExpression 214 | | precExpr 215 | ) 216 | 217 | lazy val comparator = (precExpr: Parser[Expression]) => 218 | precExpr * 219 | comparisonOperators.map { op => op ^^^ ComparisonExpression.operator(op)_ }.reduce(_ | _) 220 | 221 | val comparisonOperators = SQL99Parser.comparisonOperators 222 | 223 | lazy val like = (precExpr: Parser[Expression]) => 224 | precExpr * 225 | likeOperators.map { op => 226 | opt("not") <~ op ^^ { case o => LikeExpression.operator(op, 
o.isDefined)_ } 227 | }.reduce(_ | _) 228 | 229 | val likeOperators = SQL99Parser.likeOperators 230 | 231 | lazy val limit = "limit" ~> integerLiteral 232 | 233 | lazy val between = (precExpr: Parser[Expression]) => 234 | precExpr ~ rep(opt("not") ~ ("between" ~> precExpr ~ ("and" ~> precExpr))) ^^ { 235 | case l ~ r => r.foldLeft(l) { case (e, n ~ (lb ~ ub)) => IsBetweenExpression(e, n.isDefined, (lb, ub)) } 236 | } 237 | 238 | lazy val between0 = (precExpr: Parser[Expression]) => 239 | precExpr ~ rep(opt("not") ~ ("between" ~> rangePlaceholder)) ^^ { 240 | case l ~ r => r.foldLeft(l) { case (e, n ~ p) => IsBetweenExpression0(e, n.isDefined, p) } 241 | } 242 | 243 | lazy val in = (precExpr: Parser[Expression]) => 244 | precExpr ~ rep(opt("not") ~ ("in" ~> ("(" ~> rep1sep(expr, ",") <~ ")"))) ^^ { 245 | case l ~ r => r.foldLeft(l) { case (e, n ~ l) => IsInExpression(e, n.isDefined, l) } 246 | } 247 | 248 | lazy val in0 = (precExpr: Parser[Expression]) => 249 | precExpr ~ rep(opt("not") ~ ("in" ~> setPlaceholder)) ^^ { 250 | case l ~ r => r.foldLeft(l) { case (e, n ~ p) => IsInExpression0(e, n.isDefined, p) } 251 | } 252 | 253 | lazy val is = (precExpr: Parser[Expression]) => 254 | precExpr ~ rep("is" ~> opt("not") ~ (booleanLiteral | nullLiteral)) ^^ { 255 | case l ~ r => r.foldLeft(l) { case (e, n ~ l) => IsExpression(e, n.isDefined, l) } 256 | } 257 | 258 | lazy val add = (precExpr: Parser[Expression]) => 259 | precExpr * 260 | additionOperators.map { op => op ^^^ MathExpression.operator(op)_ }.reduce(_ | _) 261 | 262 | val additionOperators = SQL99Parser.additionOperators 263 | 264 | lazy val multiply = (precExpr: Parser[Expression]) => 265 | precExpr * 266 | multiplicationOperators.map { op => op ^^^ MathExpression.operator(op)_ }.reduce(_ | _) 267 | 268 | val multiplicationOperators = SQL99Parser.multiplicationOperators 269 | 270 | lazy val unary = (precExpr: Parser[Expression]) => 271 | ((unaryOperators.map { op => 272 | op ~ precExpr ^^ { case `op` ~ p => 
UnaryMathExpression(op, p) } 273 | }: Set[Parser[Expression]]) + precExpr).reduce(_ | _) 274 | 275 | val unaryOperators = SQL99Parser.unaryOperators 276 | 277 | lazy val placeholder = 278 | opt(ident) ~ opt(":" ~> typeLiteral) ^^ { case i ~ t => ExpressionPlaceholder(Placeholder(i), t) } 279 | 280 | lazy val expressionPlaceholder = 281 | "?" ~> placeholder 282 | 283 | lazy val rangePlaceholder = 284 | ("?" ~ "[") ~> placeholder <~ ")" 285 | 286 | lazy val setPlaceholder = 287 | ("?" ~ "{") ~> placeholder <~ "}" 288 | 289 | lazy val cast = 290 | ("cast" ~ "(") ~> expr ~ ("as" ~> typeLiteral <~ ")") ^^ { case e ~ t => CastExpression(e, t) } 291 | 292 | lazy val caseWhen = 293 | ("case" ~> opt(expr) ~ when ~ opt("else" ~> expr) <~ "end" ) ^^ { case maybeValue ~ whenList ~ maybeElse => CaseWhenExpression(maybeValue, whenList, maybeElse) } 294 | 295 | lazy val when: Parser[List[(Expression, Expression)]] = 296 | rep1("when" ~> expr ~ ("then" ~> expr)) ^^ { _.map { case a ~ b => (a, b) } } 297 | 298 | lazy val simpleExpr = (_: Parser[Expression]) => 299 | ( literal ^^ LiteralExpression 300 | | function 301 | | countStar 302 | | cast 303 | | caseWhen 304 | | column ^^ ColumnExpression 305 | | "(" ~> select <~ ")" ^^ SubSelectExpression 306 | | "(" ~> expr <~ ")" ^^ ParenthesedExpression 307 | | expressionPlaceholder 308 | ) 309 | 310 | lazy val precedenceOrder = 311 | ( simpleExpr 312 | :: unary 313 | :: multiply 314 | :: add 315 | :: is 316 | :: in0 317 | :: in 318 | :: between0 319 | :: between 320 | :: comparator 321 | :: like 322 | :: exists 323 | :: not 324 | :: and 325 | :: or 326 | :: Nil 327 | ) 328 | 329 | lazy val expr: PackratParser[Expression] = 330 | precedenceOrder.reverse.foldRight(failure("Invalid expression"):Parser[Expression])((a,b) => a(pos(b) | failure("expression expected"))) 331 | 332 | lazy val starProjection = 333 | "*" ^^^ AllColumns 334 | 335 | lazy val tableProjection = 336 | ( (ident ~ "." ~ ident) <~ ("." 
~ "*") ^^ { case s ~ _ ~ t => AllTableColumns(TableIdent(t, Some(s))) } 337 | | ident <~ ("." ~ "*") ^^ { case t => AllTableColumns(TableIdent(t, None)) } 338 | ) 339 | 340 | lazy val exprProjection = 341 | expr ~ opt(opt("as") ~> (ident | stringLiteral)) ^^ { 342 | case e ~ a => ExpressionProjection(e, a.collect { case StringLiteral(v) => v; case a: String => a }) 343 | } 344 | 345 | lazy val projections = pos( 346 | ( starProjection 347 | | tableProjection 348 | | exprProjection 349 | 350 | | failure("*, table or expression expected") 351 | ) 352 | ) 353 | 354 | lazy val relations = 355 | "from" ~> rep1sep(relation, ",") 356 | 357 | lazy val relation: PackratParser[Relation] = pos( 358 | ( joinRelation 359 | | singleTableRelation 360 | | subSelectRelation 361 | 362 | | failure("table, join or subselect expected") 363 | ) 364 | ) 365 | 366 | lazy val singleTableRelation = 367 | table ~ opt(opt("as") ~> (ident | stringLiteral)) ^^ { 368 | case t ~ a => SingleTableRelation(t, a.collect { case StringLiteral(v) => v; case a: String => a }) 369 | } 370 | 371 | lazy val subSelectRelation = 372 | ("(" ~> select <~ ")") ~ (opt("as") ~> ident) ^^ { 373 | case s ~ a => SubSelectRelation(s, a) 374 | } 375 | 376 | lazy val join = 377 | ( opt("inner") ~ "join" ^^^ InnerJoin 378 | | "left" ~ opt("outer") ~ "join" ^^^ LeftJoin 379 | | "right" ~ opt("outer") ~ "join" ^^^ RightJoin 380 | | "full" ~ opt("outer") ~ "join" ^^^ FullJoin 381 | | "cross" ~ "join" ^^^ CrossJoin 382 | ) 383 | 384 | lazy val joinRelation = 385 | relation ~ join ~ relation ~ opt("on" ~> expr) ^^ { 386 | case l ~ j ~ r ~ o => JoinRelation(l, j, r, o) 387 | } 388 | 389 | lazy val filters = 390 | "where" ~> expr 391 | 392 | lazy val groupingSet = 393 | ( "(" ~ ")" ^^^ GroupingSet(Nil) 394 | | "(" ~> repsep(expr, ",") <~ ")" ^^ GroupingSet 395 | ) 396 | 397 | lazy val groupingSetOrExpr = 398 | ( groupingSet ^^ Right.apply 399 | | expr ^^ Left.apply 400 | ) 401 | 402 | lazy val group = 403 | ( ("grouping" ~ 
"sets") ~> ("(" ~> repsep(groupingSet, ",") <~ ")") ^^ GroupByGroupingSets 404 | | "rollup" ~> ("(" ~> repsep(groupingSetOrExpr, ",") <~ ")") ^^ GroupByRollup 405 | | "cube" ~> ("(" ~> repsep(groupingSetOrExpr, ",") <~ ")") ^^ GroupByCube 406 | | expr ^^ GroupByExpression 407 | ) 408 | 409 | lazy val groupBy = 410 | ("group" ~ "by") ~> repsep(group, ",") 411 | 412 | lazy val having = 413 | "having" ~> expr 414 | 415 | lazy val sortOrder = 416 | ( "asc" ^^^ SortASC 417 | | "desc" ^^^ SortDESC 418 | ) 419 | 420 | lazy val sortExpr = 421 | expr ~ opt(sortOrder) ^^ { case e ~ o => SortExpression(e, o) } 422 | 423 | lazy val orderBy = 424 | ("order" ~ "by") ~> repsep(sortExpr, ",") 425 | 426 | lazy val distinct = 427 | ( "distinct" ^^^ SetDistinct 428 | | "all" ^^^ SetAll 429 | ) 430 | 431 | lazy val unionSelect: Parser[UnionSelect] = 432 | select ~ ("union" ~> opt(distinct)) ~ select ^^ { 433 | case l ~ d ~ r => UnionSelect(l, d, r) 434 | } 435 | 436 | lazy val simpleSelect: Parser[SimpleSelect] = 437 | "select" ~> opt(distinct) ~ rep1sep(projections, ",") ~ opt(relations) ~ opt(filters) ~ opt(groupBy) ~ opt(having) ~ opt(orderBy) ~ opt(limit) ^^ { 438 | case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => SimpleSelect(d, p, r.getOrElse(Nil), f, g.getOrElse(Nil), h, o.getOrElse(Nil), l) 439 | } 440 | 441 | lazy val select: PackratParser[Select] = pos( 442 | ( unionSelect 443 | | simpleSelect 444 | ) 445 | ) 446 | 447 | lazy val statement = pos( 448 | ( select 449 | ) 450 | ) 451 | 452 | // -- 453 | 454 | def pos[T <: SQL](p: => Parser[T]) = Parser { in => 455 | p(in) match { 456 | case Success(t, in1) => Success(t.setPos(in.offset), in1) 457 | case ns: NoSuccess => ns 458 | } 459 | } 460 | 461 | def parseStatement(sql: String): Either[Err,Statement] = { 462 | phrase(statement <~ opt(";"))(new lexical.Scanner(sql)) match { 463 | case Success(stmt, _) => Right(stmt) 464 | case NoSuccess(msg, rest) => Left(ParsingError(msg match { 465 | case "end of input expected" => "end of statement 
expected" 466 | case x => x 467 | }, rest.offset)) 468 | } 469 | } 470 | 471 | } 472 | -------------------------------------------------------------------------------- /shared/src/main/scala/com/criteo/vizatra/vizsql/Schema.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | case class Schemas(schemas: List[Schema]) { 4 | val index = schemas.map(s => (s.name.toLowerCase, s)).toMap 5 | 6 | def getSchema(name: String): Either[String, Schema] = { 7 | index.get(name.toLowerCase).map(Right.apply _).getOrElse { 8 | Left(s"""schema not found $name""") 9 | } 10 | } 11 | 12 | def getNonAmbiguousTable(name: String): Either[String, (Schema, Table)] = { 13 | schemas.flatMap(s => s.getTable(name).right.toOption.map(s -> _)) match { 14 | case Nil => Left(s"table not found $name") 15 | case table :: Nil => Right(table) 16 | case _ => Left(s"ambiguous table $name") 17 | } 18 | } 19 | 20 | def getNonAmbiguousColumn(name: String): Either[String, (Schema, Table, Column)] = { 21 | schemas.map(s => s.getNonAmbiguousColumn(name).right.map { case (t, c) => (s, t, c) }) match { 22 | case Right(col) :: Nil => Right(col) 23 | case Left(err) :: _ => Left(err) 24 | case _ => Left(s"column not found $name") 25 | } 26 | } 27 | } 28 | 29 | case class DB(dialect: Dialect, schemas: Schemas, view: Schemas = Schemas(Nil)) { 30 | 31 | def function(name: String): Either[String, SQLFunction] = { 32 | dialect.functions.lift(name).map(Right.apply).getOrElse(Left(s"unknown function $name")) 33 | } 34 | 35 | } 36 | 37 | object DB { 38 | def apply(schemas: List[Schema])(implicit dialect: Dialect): DB = DB(dialect, Schemas(schemas)) 39 | 40 | import java.sql.{Connection, ResultSet} 41 | 42 | def apply(connection: Connection): DB = { 43 | val databaseMetaData = connection.getMetaData() 44 | 45 | val name = databaseMetaData.getDatabaseProductName() 46 | val version = databaseMetaData.getDatabaseProductVersion() 47 | 48 | val 
dialect = (name, version) match { 49 | case ("Vertica Database", _) => vertica.dialect 50 | case ("HSQL Database Engine", _) => hsqldb.dialect 51 | case ("H2", _) => h2.dialect 52 | case ("PostgreSQL", _) => postgresql.dialect 53 | case x => sql99.dialect 54 | } 55 | 56 | def consume(rs: ResultSet): List[Map[String, Any]] = { 57 | import collection.mutable._ 58 | val data = ListBuffer.empty[Map[String, Any]] 59 | while (rs.next()) { 60 | val row = HashMap.empty[String, Any] 61 | (1 to rs.getMetaData.getColumnCount).foreach { i => 62 | rs.getObject(i).asInstanceOf[Any] match { 63 | case null => 64 | case x => row.put(rs.getMetaData.getColumnLabel(i).toLowerCase, x) 65 | } 66 | } 67 | data += row 68 | } 69 | data.toList.map(_.toMap) 70 | } 71 | 72 | val allColumns = consume(databaseMetaData.getColumns(null, null, null, null)) 73 | 74 | DB( 75 | dialect, 76 | Schemas( 77 | allColumns.flatMap(_.get("table_schem").map(_.toString.toLowerCase)).distinct.map { schemaName => 78 | val schemaColumns = allColumns.filter(_.get("table_schem").map(_.toString.toLowerCase) == Some(schemaName)).flatMap { col => 79 | for { 80 | table <- col.get("table_name").map(_.toString.toLowerCase) 81 | name <- col.get("column_name").map(_.toString.toLowerCase) 82 | nullable <- col.get("is_nullable").map(_.toString.toLowerCase).collect { 83 | case "yes" => true 84 | case "no" => false 85 | } 86 | typ <- col.get("type_name").map(_.toString.toLowerCase).collect(Type.from(nullable)) 87 | } yield (table, Column(name, typ)) 88 | } 89 | Schema( 90 | name = schemaName, 91 | tables = schemaColumns.groupBy(_._1).map { 92 | case (tableName, columns) => Table(tableName, columns.map(_._2)) 93 | }.toList 94 | ) 95 | } 96 | ) 97 | ) 98 | } 99 | } 100 | 101 | case class Schema(name: String, tables: List[Table]) { 102 | val index = tables.map(t => (t.name.toLowerCase, t)).toMap 103 | 104 | def getTable(name: String): Either[String, Table] = { 105 | index.get(name.toLowerCase).map(Right.apply _).getOrElse { 106 | 
Left(s"""table not found ${this.name}.$name""") 107 | } 108 | } 109 | 110 | def getNonAmbiguousColumn(name: String): Either[String, (Table, Column)] = { 111 | tables.flatMap(t => t.getColumn(name).right.toOption.map(t -> _)) match { 112 | case Nil => Left(s"column not found $name") 113 | case col :: Nil => Right(col) 114 | case _ => Left(s"ambiguous column $name") 115 | } 116 | } 117 | 118 | } 119 | 120 | case class Table(name: String, columns: List[Column]) { 121 | val index = columns.map(t => (t.name.toLowerCase, t)).toMap 122 | 123 | def getColumn(columnName: String): Either[String, Column] = { 124 | index.get(columnName.toLowerCase).map { 125 | column => Right(column) 126 | }.getOrElse { 127 | Left(s"""no column $columnName in table $name""") 128 | } 129 | } 130 | } 131 | 132 | case class Column(name: String, typ: Type) 133 | -------------------------------------------------------------------------------- /shared/src/main/scala/com/criteo/vizatra/vizsql/Show.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | sealed trait Case { def format(str: String): String } 4 | case object UpperCase extends Case { def format(str: String) = str.toUpperCase } 5 | case object LowerCase extends Case { def format(str: String) = str.toLowerCase } 6 | case object CamelCase extends Case { 7 | def format(str: String) = str.headOption.map(_.toString.toUpperCase).getOrElse("") + str.drop(1).toLowerCase 8 | } 9 | 10 | case class Style(pretty: Boolean, keywords: Case, identifiers: Case) 11 | object Style { 12 | implicit val default = Style(true, UpperCase, LowerCase) 13 | val compact = Style(false, UpperCase, LowerCase) 14 | } 15 | 16 | sealed trait Show { 17 | def ~(o: Show) = Show.Group(this :: o :: Nil) 18 | def ~(o: Option[Show]) = o.map(o => Show.Group(this :: o :: Nil)).getOrElse(this) 19 | def ~-(o: Show) = Show.Group(this :: Show.Whitespace :: o :: Nil) 20 | def ~-(o: Option[Show]) = o.map(o => 
Show.Group(this :: Show.Whitespace :: o :: Nil)).getOrElse(this) 21 | def ~/(o: Show) = Show.Group(this :: Show.NewLine :: o :: Nil) 22 | def ~/(o: Option[Show]) = o.map(o => Show.Group(this :: Show.NewLine :: o :: Nil)).getOrElse(this) 23 | def ~|(o: Show*) = Show.Group(this :: Show.Indented(Show.Group(o.toList)) :: Nil) 24 | def ~|(o: Option[Show]) = o.map(o => Show.Group(this :: Show.Indented(Show.Group(List(o))) :: Nil)).getOrElse(this) 25 | 26 | def toSQL(style: Style): String = Show.toSQL(this, style, None).right.getOrElse(sys.error("WAT?")) 27 | def toSQL(style: Style, placeholders: Placeholders, namedParameters: Map[String,Any], anonymousParameters: List[Any]) = { 28 | Show.toSQL(this, style, Some((placeholders, namedParameters, anonymousParameters))) 29 | } 30 | } 31 | 32 | object Show { 33 | def toSQL(show: Show, style: Style, placeholders: Option[(Placeholders,Map[String,Any],List[Any])]): Either[Err,String] = { 34 | val INDENT = " " 35 | case class MissingParameter(err: Err) extends Throwable 36 | def trimRight(parts: List[String]) = { 37 | val maybeT = parts.reverse.dropWhile(_ == INDENT) 38 | if(maybeT.headOption.exists(_ == "\n")) { 39 | maybeT.tail.reverse 40 | } else parts 41 | } 42 | var pIndex = 0 43 | def print(x: Show, indent: Int, parts: List[String]): List[String] = x match { 44 | case Keyword(k) => parts ++ (style.keywords.format(k) :: Nil) 45 | case Identifier(i) => parts ++ (style.identifiers.format(i) :: Nil) 46 | case Text(x) => parts ++ (x :: Nil) 47 | case Whitespace => parts ++ (" " :: Nil) 48 | case NewLine => 49 | trimRight(parts) ++ ("\n" :: (0 until indent).map(_ => INDENT).toList) 50 | case Indented(group) => 51 | print(NewLine, indent, trimRight( 52 | print(group.copy(items = NewLine :: group.items), indent + 1, parts) 53 | )) 54 | case Group(items) => 55 | items.foldLeft(parts) { 56 | case (parts, i) => print(i, indent, parts) 57 | } 58 | case Parameter(placeholder) => 59 | placeholders.map { 60 | case (placeholders, 
namedParameters, anonymousParameters) => 61 | def param(paramType: Option[Type], value: Any): String = paramType.map { p => 62 | def rec(value: FilledParameter) : String = value match { 63 | case StringParameter(s) => s"'${s.replace("'", "''")}'" 64 | case IntegerParameter(x) => x.toString 65 | case DateTimeParameter(t) => s"'${t.replace("'", "''")}'" 66 | case SetParameter(set) => "(" + set.map(rec).mkString(", ") + ")" 67 | case RangeParameter(low, high) => rec(low) + " AND " + rec(high) 68 | case x => throw new IllegalArgumentException(x.getClass.toString) 69 | } 70 | rec(Type.convertParam(p, value)) 71 | } 72 | .getOrElse { 73 | throw new MissingParameter(ParameterError( 74 | "unresolved parameter", placeholder.pos 75 | )) 76 | } 77 | placeholder.name match { 78 | case Some(key) if namedParameters.contains(key) => 79 | parts ++ (param(placeholders.find(_._1.name.exists(_ == key)).map(_._2), namedParameters(key)) :: Nil) 80 | case None if pIndex < anonymousParameters.size => 81 | val s = param(placeholders.filterNot(_._1.name.isDefined).drop(pIndex).headOption.map(_._2), anonymousParameters(pIndex)) 82 | pIndex = pIndex + 1 83 | parts ++ (s :: Nil) 84 | case x => 85 | throw new MissingParameter(ParameterError( 86 | s"""missing value for parameter ${placeholder.name.getOrElse("")}""", placeholder.pos 87 | )) 88 | } 89 | }.getOrElse { 90 | parts ++ (s"""?${placeholder.name.getOrElse("")}""" :: Nil) 91 | } 92 | } 93 | try { 94 | Right(print(show, 0, Nil).mkString.trim) 95 | } catch { 96 | case MissingParameter(err) => Left(err) 97 | } 98 | } 99 | 100 | case class Keyword(keyword: String) extends Show 101 | case class Identifier(identifier: String) extends Show 102 | case class Text(chars: String) extends Show 103 | case class Indented(group: Group) extends Show 104 | case class Parameter(placeholder: Placeholder) extends Show 105 | case class Group(items: List[Show]) extends Show 106 | case object Whitespace extends Show 107 | case object NewLine extends Show 108 | 
109 | def line = NewLine 110 | def nest(show: Show*) = Indented(Group(show.toList)) 111 | def keyword(str: String) = Keyword(str) 112 | def ident(str: String) = Identifier(str) 113 | def join(items: List[Show], separator: Show) = { 114 | Group(items.dropRight(1).flatMap(_ :: separator :: Nil) ++ items.lastOption.map(_ :: Nil).getOrElse(Nil)) 115 | } 116 | def ~?(placeholder: Placeholder) = Parameter(placeholder) 117 | 118 | implicit def toText(str: String) = Text(str) 119 | } 120 | -------------------------------------------------------------------------------- /shared/src/main/scala/com/criteo/vizatra/vizsql/Types.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | trait Type { 4 | val nullable: Boolean 5 | def withNullable(nullable: Boolean): this.type 6 | def canBeCastTo(other: Type): Boolean 7 | def show: String 8 | } 9 | 10 | case class INTEGER(nullable: Boolean = false) extends Type { 11 | def withNullable(nullable: Boolean = nullable) = this.copy(nullable).asInstanceOf[this.type] 12 | def canBeCastTo(other: Type): Boolean = other match { 13 | case DECIMAL(_) | INTEGER(_) => true 14 | case _ => false 15 | } 16 | def show = "integer" 17 | } 18 | case class DECIMAL(nullable: Boolean = false) extends Type { 19 | def withNullable(nullable: Boolean = nullable) = this.copy(nullable).asInstanceOf[this.type] 20 | def canBeCastTo(other: Type): Boolean = other match { 21 | case DECIMAL(_) => true 22 | case _ => false 23 | } 24 | def show = "decimal" 25 | } 26 | case class BOOLEAN(nullable: Boolean = false) extends Type { 27 | def withNullable(nullable: Boolean = nullable) = this.copy(nullable).asInstanceOf[this.type] 28 | def canBeCastTo(other: Type): Boolean = other match { 29 | case BOOLEAN(_) => true 30 | case _ => false 31 | } 32 | def show = "boolean" 33 | } 34 | case class STRING(nullable: Boolean = false) extends Type { 35 | def withNullable(nullable: Boolean = nullable) = 
this.copy(nullable).asInstanceOf[this.type] 36 | def canBeCastTo(other: Type): Boolean = other match { 37 | case STRING(_) => true 38 | case _ => false 39 | } 40 | def show = "string" 41 | } 42 | case class TIMESTAMP(nullable: Boolean = false) extends Type { 43 | def withNullable(nullable: Boolean = nullable) = this.copy(nullable).asInstanceOf[this.type] 44 | def canBeCastTo(other: Type): Boolean = other match { 45 | case TIMESTAMP(_) => true 46 | case _ => false 47 | } 48 | def show = "timestamp" 49 | } 50 | case class DATE(nullable: Boolean = false) extends Type { 51 | def withNullable(nullable: Boolean = nullable) = this.copy(nullable).asInstanceOf[this.type] 52 | def canBeCastTo(other: Type): Boolean = other match { 53 | case DATE(_)|TIMESTAMP(_) => true 54 | case _ => false 55 | } 56 | def show = "date" 57 | } 58 | 59 | case object NULL extends Type { 60 | val nullable = true 61 | def withNullable(nullable: Boolean) = this 62 | def canBeCastTo(other: Type): Boolean = true 63 | def show = "null" 64 | } 65 | 66 | case class SET(of: Type) extends Type { 67 | val nullable = of.nullable 68 | def withNullable(nullable: Boolean = nullable) = SET(of.withNullable(nullable)).asInstanceOf[this.type] 69 | def canBeCastTo(other: Type): Boolean = other match { 70 | case SET(x) => of.canBeCastTo(x) 71 | case _ => false 72 | } 73 | def show = s"set(${of.show})" 74 | } 75 | 76 | case class RANGE(of: Type) extends Type { 77 | val nullable = of.nullable 78 | def withNullable(nullable: Boolean = nullable) = RANGE(of.withNullable(nullable)).asInstanceOf[this.type] 79 | def canBeCastTo(other: Type): Boolean = other match { 80 | case RANGE(x) => of.canBeCastTo(x) 81 | case _ => false 82 | } 83 | def show = s"range(${of.show})" 84 | } 85 | 86 | object Type { 87 | case class MissingParameter(err: Err) extends Throwable 88 | 89 | def convertParamList(parameters: Map[String, Any], query : Query) = query.placeholders.right.map { placeholders => 90 | for { 91 | (name, value) <- parameters 
92 | typ <- placeholders.typeOf(Placeholder(Some(name))) 93 | } yield { 94 | name -> convertParam(typ, value) 95 | } 96 | } 97 | 98 | def convertParam(pType : Type, value : Any) : FilledParameter = pType match { 99 | case STRING(_) => value match { 100 | case x :: _ => convertParam(pType, x) 101 | case s: String => StringParameter(s) 102 | case n: Number => StringParameter(n.toString) 103 | case b: Boolean => StringParameter(b.toString) 104 | case c: Char => StringParameter(c.toString) 105 | case x => throw new MissingParameter(ParameterError( 106 | s"unexpected value $x (${x.getClass.getName}) for an SQL STRING parameter", -1 107 | )) 108 | } 109 | case INTEGER(_) => value match { 110 | case x :: _ => convertParam(pType, x) 111 | case x: Int => IntegerParameter(x) 112 | case x: Long => IntegerParameter(x.toInt) 113 | case x: String => try { 114 | IntegerParameter(x.toInt) 115 | } catch { 116 | case _: Throwable => throw new MissingParameter(ParameterError( 117 | s"unexpected value $x (${x.getClass.getName}) for an SQL INTEGER parameter", -1 118 | )) 119 | } 120 | case x => throw new MissingParameter(ParameterError( 121 | s"unexpected value $x (${x.getClass.getName}) for an SQL INTEGER parameter", -1 122 | )) 123 | } 124 | case DATE(_) => value match { 125 | case x :: _ => convertParam(pType, x) 126 | case x: String => DateTimeParameter(x) 127 | case x => throw new MissingParameter(ParameterError( 128 | s"unexpected value $x (${x.getClass.getName}) for an SQL DATE parameter", -1 129 | )) 130 | } 131 | case TIMESTAMP(_) => value match { 132 | case x :: _ => convertParam(pType, x) 133 | case x: String => DateTimeParameter(x) 134 | case x => throw new MissingParameter(ParameterError( 135 | s"unexpected value $x (${x.getClass.getName}) for an SQL TIMESTAMP parameter", -1 136 | )) 137 | } 138 | case SET(t) => value match { 139 | case x: Seq[_] => SetParameter(x.map(x =>convertParam(t, x)).toSet) 140 | case x => SetParameter(Set(convertParam(t, x))) 141 | } 142 | case 
RANGE(t) => value match { 143 | case (a,b) => RangeParameter(convertParam(t, a) ,convertParam(t, b)) 144 | case a :: b :: _ => RangeParameter(convertParam(t, a), convertParam(t, b)) 145 | case x => throw new MissingParameter(ParameterError( 146 | s"unexpected value $x (${x.getClass.getName}) for an SQL RANGE parameter", -1 147 | )) 148 | } 149 | } 150 | 151 | def from(nullable: Boolean): PartialFunction[String, Type] = { 152 | case "varchar" | "char" | "bpchar" | "string" => STRING(nullable) 153 | case x if x.contains("varchar") => STRING(nullable) 154 | case "int4" | "integer" => INTEGER(nullable) 155 | case "float" | "float4" | "numeric" | "decimal" => DECIMAL(nullable) 156 | case "timestamp" | "timestamptz" | "timestamp with time zone" => TIMESTAMP(nullable) 157 | case "date" => DATE(nullable) 158 | case "boolean" => BOOLEAN(nullable) 159 | } 160 | } 161 | trait FilledParameter 162 | case class DateTimeParameter(value : String) extends FilledParameter 163 | case class SetParameter(value : Set[FilledParameter]) extends FilledParameter 164 | case class RangeParameter(low : FilledParameter, High : FilledParameter) extends FilledParameter 165 | case class StringParameter(value : String) extends FilledParameter 166 | case class IntegerParameter(value : Int) extends FilledParameter 167 | -------------------------------------------------------------------------------- /shared/src/main/scala/com/criteo/vizatra/vizsql/VizSQL.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | object VizSQL { 4 | 5 | def parseQuery(sql: String, db: DB): Either[Err,Query] = 6 | (db.dialect.parser).parseStatement(sql).right.flatMap { 7 | case select @ SimpleSelect(_, _, _, _, _, _, _, _) => Right(Query(sql, select, db)) 8 | case stmt => Left(SQLError("select expected", stmt.pos)) 9 | } 10 | 11 | def parseOlapQuery(sql: String, db: DB): Either[Err,OlapQuery] = (for { 12 | query <- parseQuery(sql, db).right 13 | 
} yield OlapQuery(query)).right.flatMap { q => 14 | q.query.error.map(err => Left(err)).getOrElse(Right(q)) 15 | } 16 | 17 | } 18 | 19 | case class Query(sql: String, select: SimpleSelect, db: DB) { 20 | 21 | def tables = select.getTables(db) 22 | def columns = select.getColumns(db) 23 | def placeholders = select.getPlaceholders(db) 24 | def queryView = select.getQueryView(db) 25 | 26 | def error = (for { 27 | _ <- tables.right 28 | _ <- columns.right 29 | _ <- placeholders.right 30 | } yield ()).fold(Some.apply _, _ => None) 31 | } -------------------------------------------------------------------------------- /shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/h2/H2Dialect.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql 2 | 3 | object h2 { 4 | implicit val dialect = new Dialect { 5 | lazy val parser = new SQL99Parser 6 | lazy val functions = SQLFunction.standard 7 | override def toString = "H2" 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/hive/HiveAST.scala: -------------------------------------------------------------------------------- 1 | package com.criteo.vizatra.vizsql.hive 2 | 3 | import com.criteo.vizatra.vizsql._ 4 | import com.criteo.vizatra.vizsql.Show._ 5 | 6 | case class LateralView(inner: Relation, explodeFunction: FunctionCallExpression, tableAlias: String, columnAliases: List[String]) extends Relation { 7 | def getTables(db: DB) = { 8 | val result = for { 9 | innerTables <- inner.getTables(db).right 10 | placeholders <- inner.getPlaceholders(db).right 11 | } yield { 12 | val schemas = Schemas(innerTables.groupBy(_._1).map { case (maybeSchema, tableList) => 13 | Schema(maybeSchema.getOrElse(""), tableList.map(_._2)) 14 | }.toList) 15 | val newDb = db.copy(view = schemas) 16 | explodeFunction.resultType(newDb, placeholders).right.flatMap { 17 | case 
HiveUDTFResult(types) if columnAliases.length == types.length => 18 | val columns = columnAliases.zip(types).map { case (alias, typ) => 19 | Column(alias, typ) 20 | } 21 | Right(innerTables ++ Seq((None, Table(tableAlias, columns)))) 22 | case HiveUDTFResult(types) => 23 | Left(ParsingError(s"Expected ${types.size} aliases, got ${columnAliases.size}", pos)) 24 | case _ => 25 | Left(ParsingError(s"Expected a UDTF, got ${explodeFunction.show}", pos)) 26 | } 27 | } 28 | result.joinRight 29 | } 30 | 31 | def visit = ??? 32 | 33 | def getPlaceholders(db: DB) = inner.getPlaceholders(db) 34 | 35 | def show = ??? 36 | } 37 | 38 | case class MapOrArrayAccessExpression(map: Expression, key: Expression) extends Expression { 39 | def getPlaceholders(db: DB, expectedType: Option[Type]) = for { 40 | mapPlaceholders <- map.getPlaceholders(db, None).right 41 | keyPlaceholders <- key.getPlaceholders(db, None).right 42 | } yield mapPlaceholders ++ keyPlaceholders 43 | 44 | def visit = ??? 45 | 46 | def resultType(db: DB, placeholders: Placeholders) = (for { 47 | mapType <- map.resultType(db, placeholders).right 48 | keyType <- key.resultType(db, placeholders).right 49 | } yield (mapType, keyType) match { 50 | case (HiveMap(k, x), _) if k.canBeCastTo(keyType) => Right(x) 51 | case (HiveMap(k, _), _) => Left(TypeError(s"Expected key type ${keyType.show}, got ${k.show}", pos)) 52 | case (HiveArray(x), INTEGER(_)) => Right(x) 53 | case (HiveArray(_), x) => Left(TypeError(s"Expected integer index, got ${x.show}", pos)) 54 | case (x, _) => Left(TypeError(s"Expected map or array, got ${x.show}", pos)) 55 | }).joinRight 56 | 57 | def show = map.show ~ "[" ~ key.show ~ "]" 58 | } 59 | 60 | case class StructAccessExpr(struct: Expression, field: String) extends Expression { 61 | def getPlaceholders(db: DB, expectedType: Option[Type]) = struct.getPlaceholders(db, None) 62 | 63 | def visit = ??? 
// NOTE(review): the members below continue a case class whose header precedes this chunk;
// from its use of `struct` and `field` it is presumably StructAccessExpr — confirm.
  def resultType(db: DB, placeholders: Placeholders) =
    struct.resultType(db, placeholders).right.map {
      case HiveStruct(cols) =>
        cols.find(_.name.toLowerCase == field.toLowerCase).map(_.typ).toRight(SchemaError(s"Field $field not found", pos))
      case HiveArray(HiveStruct(cols)) =>
        cols.find(_.name.toLowerCase == field.toLowerCase).map(c => HiveArray(c.typ)).toRight(SchemaError(s"Field $field not found", pos))
      case x =>
        Left(TypeError(s"Expected struct, got ${x.show}", pos))
    }.joinRight

  def show = struct.show ~ "." ~ field

  override def toColumnName = field
}

/**
 * A dotted identifier that is either a column reference (`[schema.]table.column`) or
 * a struct field access on a column (`column.field`, `table.column.field`).
 *
 * Typing first tries the plain column interpretation; when that lookup fails, the last
 * component is re-read as a struct field and typing is delegated to `StructAccessExpr`.
 */
case class ColumnOrStructAccessExpression(column: ColumnIdent) extends Expression {
  def getPlaceholders(db: DB, expectedType: Option[Type]) = Right(Placeholders())

  def resultType(db: DB, placeholders: Placeholders) = column match {
    // c1.c2.c3: schema.table.column, otherwise table.column followed by struct field c3
    case ColumnIdent(c3, Some(TableIdent(c2, Some(c1)))) =>
      (for {
        schema <- db.view.getSchema(c1).right
        table <- schema.getTable(c2).right
        col <- table.getColumn(c3).right
      } yield col.typ
      ).left.flatMap { _ =>
        val newCol = ColumnOrStructAccessExpression(ColumnIdent(c2, Some(TableIdent(c1, None))))
        newCol.pos = pos
        val structAccess = StructAccessExpr(newCol, c3)
        structAccess.pos = pos
        structAccess.resultType(db, placeholders)
      }
    // c1.c2: table.column, otherwise column c1 followed by struct field c2
    case ColumnIdent(c2, Some(TableIdent(c1, None))) =>
      (for {
        table <- db.view.getNonAmbiguousTable(c1).right.map(_._2).right
        col <- table.getColumn(c2).right
      } yield col.typ
      ).left.flatMap { _ =>
        val newCol = ColumnOrStructAccessExpression(ColumnIdent(c1, None))
        newCol.pos = pos
        val structAccess = StructAccessExpr(newCol, c2)
        structAccess.pos = pos
        structAccess.resultType(db, placeholders)
      }
    // bare column name: must resolve to exactly one column in scope
    case ColumnIdent(c, None) =>
      (for {
        col <- db.view.getNonAmbiguousColumn(c).right.map(_._3).right
      } yield col.typ
      ).left.map(SchemaError(_, column.pos))
  }

  def show = column.show

  def visit = Nil

  override def toColumnName = column.name
}

/** Hive `LEFT SEMI JOIN`. */
case object LeftSemiJoin extends Join {
  def show = keyword("left") ~- keyword("semi") ~- keyword("join")
}
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/hive/HiveDialect.scala:
--------------------------------------------------------------------------------
package com.criteo.vizatra.vizsql.hive

import com.criteo.vizatra.vizsql._

import scala.util.parsing.input.CharArrayReader.EofCh

/**
 * Hive SQL dialect.
 *
 * Extends the SQL-99 parser with Hive syntax: backtick identifiers, double-quoted strings,
 * `'NaN'`, struct/map/array access, LATERAL VIEW, TABLESAMPLE, LEFT SEMI JOIN and
 * SORT/DISTRIBUTE/CLUSTER BY. It also provides result typing for Hive built-in functions.
 *
 * @param udfs user-defined functions. They are consulted before the built-ins (`udfs.orElse`),
 *             so a UDF with a built-in's name takes precedence.
 */
case class HiveDialect(udfs: Map[String, SQLFunction]) extends Dialect {

  lazy val parser = new SQL99Parser {
    override val lexical = new SQLLexical {
      override val keywords = SQL99Parser.keywords ++ Set(
        "rlike", "regexp", "limit", "lateral", "view", "distribute", "sort", "cluster", "semi",
        "tablesample", "bucket", "out", "of", "percent")
      override val delimiters = SQL99Parser.delimiters ++ Set("!=", "<=>", "||", "&&", "%", "&", "|", "^", "~", "`", "==")
      // `name` -> identifier; "text" -> string literal (may be empty: "" is valid in Hive);
      // 'NaN' -> decimal NaN literal.
      override val customToken =
        ( '`' ~> rep1(chrExcept('`', '\n', EofCh)) <~ '`' ^^ { x => Identifier(x.mkString) }
        | '\"' ~> rep(chrExcept('\"', '\n', EofCh)) <~ '\"' ^^ { x => StringLit(x.mkString) }
        | '`' ~> failure("unclosed backtick")
        | '\'' ~ 'N' ~ 'a' ~ 'N' ~ '\'' ^^^ DecimalLit("NaN")
        )
    }
    override val comparisonOperators = SQL99Parser.comparisonOperators ++ Set("!=", "<=>", "==")
    override val likeOperators = SQL99Parser.likeOperators ++ Set("rlike", "regexp")
    override val orOperators = SQL99Parser.orOperators + "||"
    override val andOperators = SQL99Parser.andOperators + "&&"
    override val multiplicationOperators = SQL99Parser.multiplicationOperators ++ Set("%", "&", "|", "^")
    override val unaryOperators = SQL99Parser.unaryOperators + "~"
    override val typeMap: Map[String, TypeLiteral] = HiveDialect.typeMap

    /** `expr[key]`: map lookup or array indexing. */
    lazy val mapOrArrayAccessExpr =
      expr ~ ("[" ~> expr <~ "]") ^^ { case m ~ k => MapOrArrayAccessExpression(m, k) }

    /** `expr.field`: struct field access on an arbitrary expression. */
    lazy val structAccessExpr =
      expr ~ ("." ~> ident) ^^ { case e ~ f => StructAccessExpr(e, f) }

    override lazy val simpleExpr = (_: Parser[Expression]) =>
      ( literal ^^ LiteralExpression
      | structAccessExpr
      | mapOrArrayAccessExpr
      | function
      | countStar
      | cast
      | caseWhen
      | column ^^ ColumnOrStructAccessExpression
      | "(" ~> select <~ ")" ^^ SubSelectExpression
      | "(" ~> expr <~ ")" ^^ ParenthesedExpression
      | expressionPlaceholder
      )

    /** `relation TABLESAMPLE (...)`: the sampling clause is parsed and then discarded. */
    lazy val tablesampleRelation =
      relation <~ "tablesample" ~ "(" ~ tablesampleExpr ~ ")"

    /**
     * Sampling clause: `BUCKET x OUT OF y [ON expr]`, `n PERCENT` or `n<unit>`.
     * Hive's ByteLengthLiteral accepts the units b, k, m and g in either case.
     */
    lazy val tablesampleExpr =
      ( "bucket" ~ integerLiteral ~ "out" ~ "of" ~ integerLiteral ~ opt("on" ~ expr)
      | integerLiteral ~ "percent"
      | integerLiteral ~ elem("bytelength", e => List("b", "B", "k", "K", "m", "M", "g", "G").contains(e.chars))
      )

    /** `relation LATERAL VIEW [OUTER] udtf(...) tableAlias AS colAlias, ...` (OUTER is not recorded). */
    lazy val lateralViewRelation =
      relation ~ ("lateral" ~ "view" ~ opt("outer") ~> function) ~ ident ~ ("as" ~> repsep(ident, ",")) ^^ {
        case r ~ f ~ tblAlias ~ colAliases => LateralView(r, f, tblAlias, colAliases)
      }

    override lazy val join =
      ( opt("inner") ~ "join" ^^^ InnerJoin
      | "left" ~ opt("outer") ~ "join" ^^^ LeftJoin
      | "right" ~ opt("outer") ~ "join" ^^^ RightJoin
      | "full" ~ opt("outer") ~ "join" ^^^ FullJoin
      | "cross" ~ "join" ^^^ CrossJoin
      | "left" ~ "semi" ~ "join" ^^^ LeftSemiJoin
      )

    override lazy val relation: PackratParser[Relation] = pos(
      ( lateralViewRelation
      | tablesampleRelation
      | joinRelation
      | singleTableRelation
      | subSelectRelation
      | failure("table, join or subselect expected")
      )
    )

    lazy val sortBy =
      "sort" ~ "by" ~> repsep(sortExpr, ",")

    lazy val distributeBy =
      "distribute" ~ "by" ~> repsep(expr, ",") ^^ { c => c.map(SortExpression(_, None)) }

    lazy val clusterBy =
      "cluster" ~ "by" ~> repsep(expr, ",") ^^ { c => c.map(SortExpression(_, None)) }

    lazy val xxxBy = // FIXME: losing lots of information by removing the keywords
      ( orderBy ~ clusterBy ^^ { case o ~ c => o ++ c }
      | distributeBy ~ sortBy ^^ { case d ~ s => d ++ s }
      | orderBy
      | sortBy
      | distributeBy
      | clusterBy
      )

    override lazy val simpleSelect: Parser[SimpleSelect] =
      "select" ~> opt(distinct) ~ rep1sep(projections, ",") ~ opt(relations) ~ opt(filters) ~ opt(groupBy) ~ opt(having) ~ opt(xxxBy) ~ opt(limit) ^^ {
        case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => SimpleSelect(d, p, r.getOrElse(Nil), f, g.getOrElse(Nil), h, o.getOrElse(Nil), l)
      }
  }

  /** Result typing for Hive built-in functions, with `udfs` taking precedence. */
  override def functions = udfs.orElse {
    case "min" | "max" => new SQLFunction1 {
      def result = { case (_, t) => Right(t) }
    }
    case "avg" | "sum" => new SQLFunction1 {
      def result = {
        case (_, t @ (INTEGER(_) | DECIMAL(_))) => Right(t)
        case (arg, _) => Left(TypeError("expected numeric argument", arg.pos))
      }
    }
    case "now" => new SQLFunction0 {
      def result = Right(TIMESTAMP())
    }
    case "concat" => new SQLFunctionX {
      def result = {
        case (_, t1) :: _ => Right(STRING(t1.nullable))
      }
    }
    case "coalesce" => new SQLFunctionX {
      def result = {
        case (_, t1) :: tail => Right(t1)
      }
    }
    case "count" => new SQLFunctionX {
      override def result = {
        case _ :: _ => Right(INTEGER(false))
      }
    }
    case "e" | "pi" =>
      SimpleFunction0(DECIMAL(true))
    case "current_date" =>
      SimpleFunction0(DATE(true))
    case "current_user" =>
      SimpleFunction0(STRING(true))
    case "current_timestamp" =>
      SimpleFunction0(TIMESTAMP(true))
    case "isnull" | "isnotnull" =>
      SimpleFunction1(BOOLEAN(true))
    case "variance" | "var_pop" | "var_samp" | "stddev_pop" | "stddev_samp" | "exp" | "ln" | "log10" |
         "log2" | "sqrt" | "abs" | "sin" | "asin" | "cos" | "acos" | "tan" | "atan" | "degrees" |
         "radians" | "sign" | "cbrt" =>
      SimpleFunction1(DECIMAL(true))
    case "year" | "quarter" | "month" | "day" | "dayofmonth" | "hour" | "minute" | "second" | "weekofyear" |
         "ascii" | "length" | "crc32" | "ntile" | "floor" | "ceil" | "ceiling" |
         "factorial" | "size" =>
      SimpleFunction1(INTEGER(true))
    case "to_date" | "last_day" | "base64" | "lower" | "lcase" | "ltrim" | "reverse" | "rtrim" | "space" |
         "trim" | "unbase64" | "upper" | "ucase" | "initcap" | "soundex" | "md5" | "sha" | "sha1" | "bin" |
         "hex" | "unhex" | "binary" =>
      SimpleFunction1(STRING(true))
    case "in_file" | "array_contains" =>
      SimpleFunction2(BOOLEAN(true))
    case "months_between" | "covar_pop" | "covar_samp" | "corr" | "log" | "pow" =>
      SimpleFunction2(DECIMAL(true))
    // levenshtein(a, b) and shiftleft/shiftright/shiftrightunsigned(a, b) take two arguments in Hive
    case "datediff" | "find_in_set" | "instr" | "levenshtein" |
         "shiftleft" | "shiftright" | "shiftrightunsigned" =>
      SimpleFunction2(INTEGER(true))
    case "date_add" | "date_sub" | "add_months" | "next_day" | "trunc" | "date_format" | "decode" | "encode" |
         "format_number" | "get_json_object" | "repeat" | "sha2" | "aes_encrypt" | "aes_decrypt" =>
      SimpleFunction2(STRING(true))
    case "from_utc_timestamp" | "to_utc_timestamp" =>
      SimpleFunction2(TIMESTAMP(true))
    case "split" =>
      SimpleFunction2(HiveArray(STRING(true)))
    case "histogram_numeric" =>
      SimpleFunction2(HiveArray(HiveStruct(List(Column("x", DECIMAL(true)), Column("y", DECIMAL(true))))))
    // round(x[, d]), bround(x[, d]) and rand([seed]) return DOUBLE in Hive
    case "round" | "bround" =>
      SimpleFunctionX(1, 2, DECIMAL(true))
    case "rand" =>
      SimpleFunctionX(0, 1, DECIMAL(true))
    case "lpad" | "regexp_replace" | "rpad" | "translate" | "conv" =>
      SimpleFunctionX(3, 3, STRING(true))
    case "from_unixtime" =>
      SimpleFunctionX(1, 2, STRING(true))
    case "unix_timestamp" =>
      SimpleFunctionX(0, 2, INTEGER(true))
    case "context_ngrams" | "ngrams" =>
      SimpleFunctionX(4, 4, HiveArray(HiveStruct(List(Column("ngram", STRING(true)), Column("estfrequency", DECIMAL(true))))))
    case "concat_ws" | "printf" =>
      SimpleFunctionX(2, None, STRING(true))
    case "locate" =>
      SimpleFunctionX(2, 3, INTEGER(true))
    case "parse_url" | "substr" | "substring" | "regexp_extract" =>
      SimpleFunctionX(2, 3, STRING(true))
    case "sentences" =>
      SimpleFunctionX(1, 3, HiveArray(HiveArray(STRING(true))))
    case "str_to_map" =>
      SimpleFunctionX(1, 3, HiveMap(STRING(true), STRING(true)))
    case "substring_index" =>
      SimpleFunctionX(3, 3, STRING(true))
    case "hash" =>
      SimpleFunctionX(1, None, INTEGER(true))
    case "pmod" =>
      new SQLFunction2 {
        override def result = {
          case ((_, t @ (INTEGER(_) | DECIMAL(_))), _) => Right(t)
        }
      }
    case "positive" | "negative" =>
      new SQLFunction1 {
        override def result = {
          case (_, t @ (INTEGER(_) | DECIMAL(_))) => Right(t)
        }
      }
    case "percentile" =>
      new SQLFunction2 {
        override def result = {
          case (_, (_, a: HiveArray)) => Right(HiveArray(DECIMAL(true)))
          case _ => Right(DECIMAL(true))
        }
      }
    case "percentile_approx" =>
      new SQLFunctionX {
        override def result = {
          case l if l.length >= 2 && l.length <= 3 && l(1)._2.isInstanceOf[HiveArray] => Right(HiveArray(DECIMAL(true)))
          case _ => Right(DECIMAL(true))
        }
      }
    case "collect_set" | "collect_list" =>
      new SQLFunction1 {
        override def result = {
          case (_, t) => Right(HiveArray(t))
        }
      }
    case "if" =>
      new SQLFunctionX {
        override def result = {
          case (_, BOOLEAN(_)) :: (_, t1) :: (_, t2) :: Nil => Right(t1)
          case (arg, _) :: _ => Left(TypeError("Expected boolean argument", arg.pos))
        }
      }
    case "map_keys" =>
      new SQLFunction1 {
        override def result = {
          case (_, HiveMap(k, _)) => Right(HiveArray(k))
        }
      }
    case "map_values" =>
      new SQLFunction1 {
        override def result = {
          case (_, HiveMap(_, v)) => Right(HiveArray(v))
        }
      }
    case "sort_array" =>
      new SQLFunction1 {
        override def result = {
          case (_, a: HiveArray) => Right(a)
        }
      }
    case "map" =>
      new SQLFunctionX {
        override def result = {
          case (_, k) :: (_, v) :: _ => Right(HiveMap(k, v))
        }
      }
    case "struct" =>
      new SQLFunctionX {
        override def result = {
          case l =>
            val cols = l.zipWithIndex.map { case ((_, t), i) => Column(s"col$i", t) }
            Right(HiveStruct(cols))
        }
      }
    case "named_struct" =>
      new SQLFunctionX {
        override def result = {
          case l if l.length % 2 == 0 =>
            // Arguments come in (name, value) pairs. A non-literal name is reported as a typing
            // error instead of throwing, so callers receive a Left like every other failure.
            val fields = l.grouped(2).toList.map {
              case (LiteralExpression(StringLiteral(name)), _) :: (_, t) :: Nil => Right(Column(name, t))
              case pair => Left(TypeError("Unsupported expression in named_struct: expected a string literal field name", pair.headOption.fold(0)(_._1.pos)))
            }
            fields.collectFirst { case Left(err) => err }.toLeft(HiveStruct(fields.collect { case Right(col) => col }))
        }
      }
    case "array" =>
      new SQLFunctionX {
        override def result = {
          case (_, t) :: _ => Right(HiveArray(t))
        }
      }
    case "explode" =>
      new SQLFunction1 {
        override def result = {
          case (_, HiveArray(elem)) => Right(HiveUDTFResult(List(elem)))
          case (_, HiveMap(key, value)) => Right(HiveUDTFResult(List(key, value)))
        }
      }
    case "inline" =>
      new SQLFunction1 {
        override def result = {
          case (_, HiveArray(HiveStruct(cols))) => Right(HiveUDTFResult(cols.map(_.typ)))
        }
      }
    case "json_tuple" | "parse_url_tuple" =>
      new SQLFunctionX {
        override def result = {
          case l if l.length >= 2 => Right(HiveUDTFResult(l.tail.map(_ => STRING(true))))
        }
      }
    case "posexplode" =>
      new SQLFunction1 {
        override def result = {
          case (_, HiveArray(elem)) => Right(HiveUDTFResult(List(INTEGER(true), elem)))
        }
      }
    case "stack" =>
      new SQLFunctionX {
        override def result = {
          case l if l.length >= 2 => Right(HiveUDTFResult(l.tail.map(_._2)))
        }
      }
    case "java_method" | "reflect" =>
      new SQLFunctionX {
        override def result = {
          case l if l.length >= 2 => Right(STRING(true))
        }
      }
  }
}

object HiveDialect {

  /** Hive type names accepted in CAST, mapped to vizsql type literals. */
  val typeMap = Map( // vizsql types don't matter that much
    "tinyint" -> IntegerTypeLiteral,
    "smallint" -> IntegerTypeLiteral,
    "int" -> IntegerTypeLiteral,
    "bigint" -> IntegerTypeLiteral,
    "float" -> DecimalTypeLiteral,
    "double" -> DecimalTypeLiteral,
    "decimal" -> DecimalTypeLiteral,
    "timestamp" -> TimestampTypeLiteral,
    "string" -> VarcharTypeLiteral,
    "boolean" -> BooleanTypeLiteral,
    "binary" -> VarcharTypeLiteral
  )
}
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/hive/HiveFunctions.scala:
--------------------------------------------------------------------------------
package com.criteo.vizatra.vizsql.hive

import com.criteo.vizatra.vizsql.Show._
import com.criteo.vizatra.vizsql._

/** Zero-argument function with a fixed result type. */
case class SimpleFunction0(resultType: Type) extends SQLFunction0 {
  override def result = Right(resultType)
}

/** One-argument function with a fixed result type, whatever the argument's type. */
case class SimpleFunction1(resultType: Type) extends SQLFunction1 {
  override def result = {
    case _ => Right(resultType)
  }
}

/** Two-argument function with a fixed result type, whatever the argument types. */
case class SimpleFunction2(resultType: Type) extends SQLFunction2 {
  override def result = {
    case _ => Right(resultType)
  }
}

/**
 * Variadic function with a fixed result type.
 *
 * @param minArgs minimum number of arguments (inclusive)
 * @param maxArgs maximum number of arguments (inclusive); `None` means unbounded
 * @return the fixed type when the arity fits, otherwise a ParsingError positioned at the
 *         first argument (position 0 when there are no arguments)
 */
case class SimpleFunctionX(minArgs: Int, maxArgs: Option[Int], resultType: Type) extends SQLFunctionX {
  override def result = {
    case l if l.length >= minArgs && maxArgs.forall(l.length <= _) =>
      Right(resultType)
    case l =>
      Left(ParsingError(s"wrong argument list size ${l.length}", l.headOption.fold(0)(_._1.pos)))
  }
}

object SimpleFunctionX {
  /** Convenience constructor for a bounded arity range. */
  def apply(minArgs: Int, maxArgs: Int, resultType: Type): SimpleFunctionX = apply(minArgs, Some(maxArgs), resultType)
}
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/hive/HiveTypes.scala:
--------------------------------------------------------------------------------
package com.criteo.vizatra.vizsql.hive

import com.criteo.vizatra.vizsql.{Column, Type}

// Hive complex types. All of them are always nullable (`withNullable` returns `this`)
// and can only be cast to an identical type.

/** Hive `array<elem>`. */
case class HiveArray(elem: Type) extends Type {
  val nullable = true

  def withNullable(nullable: Boolean) = this

  def canBeCastTo(other: Type) = this == other

  def show = s"array<${elem.show}>"
}

/** Hive `struct<name:type,...>`. */
case class HiveStruct(elems: List[Column]) extends Type {
  val nullable = true

  def withNullable(nullable: Boolean) = this

  def canBeCastTo(other: Type) = this == other

  def show = s"struct<${elems.map { col => s"${col.name}:${col.typ.show}" }.mkString(",")}>"
}

/** Hive `map<key,value>`. */
case class HiveMap(key: Type, value: Type) extends Type {
  val nullable = true

  def withNullable(nullable: Boolean) = this

  def canBeCastTo(other: Type) = this == other

  def show = s"map<${key.show},${value.show}>"
}

/** Row type produced by a table-generating function (explode, inline, stack, ...). */
case class HiveUDTFResult(types: List[Type]) extends Type {
  val nullable = true

  def withNullable(nullable: Boolean) = this

  def canBeCastTo(other: Type) = this == other

  def show = s"udtfresult<${types.map(_.show).mkString(",")}>"
}
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/hive/TypeParser.scala:
--------------------------------------------------------------------------------
package com.criteo.vizatra.vizsql.hive

import com.criteo.vizatra.vizsql._

import scala.util.parsing.combinator.PackratParsers
import scala.util.parsing.combinator.lexical.Lexical
import scala.util.parsing.combinator.syntactical.TokenParsers
import scala.util.parsing.input.CharArrayReader.EofCh

/**
 * Parses Hive type signatures such as `array<struct<a:int,b:map<string,double>>>`,
 * `decimal(10,2)` or `varchar(255)` into vizsql types.
 */
class TypeParser extends TokenParsers with PackratParsers {
  type Tokens = TypeLexical
  override val lexical = new TypeLexical

  class TypeLexical extends Lexical {
    case class ColumnName(chars: String) extends Token
    case class Keyword(chars: String) extends Token
    case class Symbol(chars: String) extends Token

    val keywords = Set(
      "array", "map", "struct", "tinyint", "smallint", "int", "bigint", "integer",
      "double", "float", "decimal", "string", "binary", "boolean", "timestamp",
      "date", "varchar", "char"
    )
    val symbols = Set(",", ":", "<", ">", "(", ")")

    // A bare name stops at a backtick, a symbol, whitespace or EOF (EofCh and all
    // whitespace characters are <= ' '). Without the whitespace stop, "a : int" would
    // yield the field name "a " and "int >" would not be recognised as a keyword.
    private def isNameChar(ch: Char) = ch > ' ' && "`,:<>()".indexOf(ch) < 0

    lazy val columnNameOrKeyword = rep1(elem("name character", isNameChar)) ^^ { chrs =>
      val s = chrs.mkString
      if (keywords(s.toLowerCase)) Keyword(s.toLowerCase)
      else ColumnName(s)
    }

    lazy val token =
      ( columnNameOrKeyword
      | (elem(',') | elem(':') | elem('<') | elem('>') | elem('(') | elem(')')) ^^ { c => Symbol(c.toString) }
      | '`' ~> rep1(chrExcept('`', EofCh)) <~ '`' ^^ { chrs => ColumnName(chrs.mkString) }
      | '`' ~> failure("Unclosed backtick")
      )
    lazy val whitespace = rep(whitespaceChar)
  }

  private val tokenParserCache = collection.mutable.HashMap.empty[String, Parser[String]]
  /** Lifts a keyword or symbol string into a parser accepting exactly that token. */
  implicit def string2KeywordOrSymbolParser(chars: String): Parser[String] = tokenParserCache.getOrElseUpdate(
    chars,
    if (lexical.keywords(chars)) accept(lexical.Keyword(chars)) ^^ (_.chars) withFailureMessage s"$chars expected"
    else if (lexical.symbols(chars)) accept(lexical.Symbol(chars)) ^^ (_.chars) withFailureMessage s"$chars expected"
    else sys.error("Invalid parser definition")
  )

  /** A struct field name (keywords are accepted as names too). */
  lazy val name = elem("column name", token => token.isInstanceOf[lexical.ColumnName] || token.isInstanceOf[lexical.Keyword]) ^^ (_.chars)

  lazy val array = "array" ~ "<" ~> anyType <~ ">" ^^ { case el => HiveArray(el) }

  lazy val map = ("map" ~ "<" ~> anyType) ~ ("," ~> anyType <~ ">") ^^ { case k ~ v => HiveMap(k, v) }

  lazy val struct = "struct" ~ "<" ~> structFields <~ ">" ^^ { case l =>
    HiveStruct(l.map { case n ~ t => Column(n, t) })
  }

  lazy val structFields = rep1sep(name ~ (":" ~> anyType), ",")

  /** Optional `(n)` / `(p, s)` qualifier of parameterized types; the values are not kept. */
  lazy val typeParameters = opt("(" ~ rep1sep(name, ",") ~ ")")

  lazy val simpleType: Parser[Type] =
    ( ("tinyint" | "smallint" | "int" | "bigint" | "integer") ^^^ INTEGER(true)
    | ("double" | "float") ^^^ DECIMAL(true)
    | "decimal" ~ typeParameters ^^^ DECIMAL(true)
    | ("string" | "binary") ^^^ STRING(true)
    | ("varchar" | "char") ~ typeParameters ^^^ STRING(true)
    | "boolean" ^^^ BOOLEAN(true)
    | "timestamp" ^^^ TIMESTAMP(true)
    | "date" ^^^ DATE(true)
    | failure("expected type literal")
    )

  lazy val anyType: PackratParser[Type] =
    ( simpleType
    | array
    | struct
    | map
    )

  /**
   * Parses a complete Hive type signature.
   *
   * @return the parsed type, or a message suffixed with `for <typeName>` on failure
   */
  def parseType(typeName: String): Either[String, Type] =
    phrase(anyType)(new lexical.Scanner(typeName)) match {
      case Success(t, _) => Right(t)
      case NoSuccess(msg, _) => Left(s"$msg for $typeName")
    }

}
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/hsqldb/HsqlDBDialect.scala:
--------------------------------------------------------------------------------
package com.criteo.vizatra.vizsql

/** HSQLDB dialect: SQL-99 parser, standard functions plus `timestamp(a, b)`. */
object hsqldb {
  implicit val dialect = new Dialect {
    lazy val parser = new SQL99Parser
    lazy val functions = SQLFunction.standard orElse {
      case "timestamp" => new SQLFunction2 {
        def result = { case ((_, t1), (_, t2)) => Right(TIMESTAMP(nullable = t1.nullable || t2.nullable)) }
      }
    }: PartialFunction[String,SQLFunction]
    override def toString = "HSQLDB"
  }
}
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/mysql/MySQLDialect.scala:
--------------------------------------------------------------------------------
package com.criteo.vizatra.vizsql

/** MySQL dialect: SQL-99 parser with the standard function set. */
object mysql {
  implicit val dialect = new Dialect {
    lazy val parser = new SQL99Parser
    lazy val functions = SQLFunction.standard
    override def toString = "MySQL"
  }
}
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/postgresql/PostgresqlDialect.scala:
--------------------------------------------------------------------------------
package com.criteo.vizatra.vizsql

/** PostgreSQL dialect: SQL-99 parser, standard functions plus a few Postgres-specific ones. */
object postgresql {
  implicit val dialect = new Dialect {
    lazy val parser = new SQL99Parser
    lazy val functions = SQLFunction.standard orElse {
      case "date_trunc" => new SQLFunction2 {
        def result = { case (_, (_, t)) => Right(TIMESTAMP(nullable = t.nullable)) }
      }
      case "zeroifnull" => new SQLFunction1 {
        def result = { case (_, t) => Right(t.withNullable(false)) }
      }
      case "nullifzero" => new SQLFunction1 {
        def result = { case (_, t) => Right(t.withNullable(true)) }
      }
      case "nullif" => new SQLFunction2 {
        def result = { case ((_, t1), (_, t2)) => Right(t1.withNullable(true)) }
      }
      case "to_char" => new SQLFunction2 {
        def result = { case ((_, t1), (_, t2)) => Right(STRING()) }
      }
    }: PartialFunction[String,SQLFunction]
    override def toString = "Postgres"
  }
}
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/sql99/SQL99.scala:
--------------------------------------------------------------------------------
package com.criteo.vizatra.vizsql

/** Plain SQL-99 dialect with the standard function set. */
object sql99 {
  implicit val dialect = new Dialect {
    lazy val parser = new SQL99Parser
    lazy val functions = SQLFunction.standard
    override def toString = "SQL-99"
  }
}
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/dialects/vertica/VerticaDialect.scala:
--------------------------------------------------------------------------------
package com.criteo.vizatra.vizsql

/** Vertica dialect: PostgreSQL functions plus 3-argument `datediff` and `to_timestamp`. */
object vertica {
  implicit val dialect = new Dialect {
    lazy val parser = new SQL99Parser
    lazy val functions = postgresql.dialect.functions orElse {
      case "datediff" => new SQLFunction3 {
        def result = { case (_, (_, t1), (_, t2)) => Right(INTEGER(nullable = t1.nullable || t2.nullable)) }
      }
      case "to_timestamp" => new SQLFunction1 {
        def result = { case (_, t1) => Right(TIMESTAMP(nullable = t1.nullable)) }
      }
    }: PartialFunction[String,SQLFunction]
    override def toString = "Vertica"
  }
}
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/olap/Olap.scala:
--------------------------------------------------------------------------------
package com.criteo.vizatra.vizsql

/** Error produced while analysing or rewriting an OLAP query; `pos` locates the offending SQL fragment. */
case class OlapError(val msg: String, val pos: Int) extends Err

/**
 * Values supplied for an OLAP query. Both maps fill the query's named placeholders, and
 * their keys decide which WHERE conditions survive the rewrite.
 */
case class OlapSelection(parameters: Map[String,Any], filters: Map[String,Any])

/** Labels of the dimensions and metrics to project; when both are empty every projection is kept. */
case class OlapProjection(dimensions: Set[String], metrics: Set[String])

/**
 * OLAP view over a query. Every projection must carry a label. Labelled projections whose
 * expression appears in GROUP BY are dimensions; all other labelled projections are metrics.
 */
case class OlapQuery(query: Query) {

  /** All projections, provided each of them is an expression with a label. */
  def getProjections: Either[Err,List[ExpressionProjection]] = {
    val validProjections = query.select.projections.collect { case e @ ExpressionProjection(_, Some(_)) => e }
    (query.select.projections diff validProjections) match {
      case Nil => Right(validProjections)
      case oops :: _ => Left(OlapError("Please specify an expression with a label", oops.pos))
    }
  }

  /** Names of the query's placeholders; fails if any placeholder is anonymous. */
  def getParameters: Either[Err,List[String]] = {
    query.placeholders.right.flatMap { placeholders =>
      placeholders.foldRight(Right(Nil):Either[Err,List[String]]) { (p, acc) =>
        acc.right.flatMap { params =>
          p match {
            case (Placeholder(Some(name)), _) => Right(name :: params)
            case (p, _) => Left(OlapError("All parameters must be named", p.pos))
          }
        }
      }
    }
  }

  // (time dimensions, other dimensions): a dimension is a time dimension when its column is DATE or TIMESTAMP.
  private def getAllDimensions: Either[Err,(List[String],List[String])] = for {
    columns <- query.columns.right
    projections <- getProjections.right
  } yield projections.collect {
    case ExpressionProjection(e, Some(name)) if query.select.groupBy.flatMap(_.expressions).contains(e) => name
  }.partition { dim =>
    columns.find(_.name == dim).map(_.typ).collect({
      case DATE(_) | TIMESTAMP(_) => true
    }).getOrElse(false)
  }

  def getDimensions: Either[Err,List[String]] = getAllDimensions.right.map(_._2)
  def getTimeDimensions: Either[Err,List[String]] = getAllDimensions.right.map(_._1)

  /** Labels of projections whose expression is not part of GROUP BY. */
  def getMetrics: Either[Err,List[String]] = for {
    projections <- getProjections.right
  } yield projections.collect {
    case ExpressionProjection(e, Some(name)) if !query.select.groupBy.flatMap(_.expressions).contains(e) => name
  }

  /** The projection labelled `dimensionOrMetric`. */
  def getProjection(dimensionOrMetric: String): Either[Err,ExpressionProjection] = {
    query.select.projections.collectFirst {
      case e @ ExpressionProjection(_, Some(`dimensionOrMetric`)) => e
    }.toRight(OlapError(s"Not found $dimensionOrMetric", query.select.pos))
  }

  /**
   * Builds the SQL for `projection` under `selection`. The select is pruned to the requested
   * projections, the supplied values fill the placeholders, and the result is re-parsed and
   * optimized.
   */
  def computeQuery(projection: OlapProjection, selection: OlapSelection): Either[Err,String] = for {
    projectionExpressions <- {
      (if((projection.dimensions ++ projection.metrics).isEmpty) {
        Right[Err,List[ExpressionProjection]]((query.select.projections.collect {
          case e @ ExpressionProjection(_, _) => e
        }))
      } else {
        (projection.dimensions ++ projection.metrics).foldRight(Right(Nil):Either[Err,List[ExpressionProjection]]) {
          (x, acc) => for(a <- acc.right; b <- getProjection(x).right) yield b :: a
        }
      }).right
    }
    select <- OlapQuery.rewriteSelect(query.select, query.db, projectionExpressions, selection.filters.keySet ++ selection.parameters.keySet).right
    sql <- select.fillParameters(query.db, selection.parameters ++ selection.filters).right
    filledQuery <- VizSQL.parseQuery(sql, query.db).right
  } yield {
    Optimizer.optimize(filledQuery).sql
  }

  /**
   * Expresses `metric` as a post-aggregation. Supported forms are `sum(...)`, optionally
   * wrapped in zeroifnull/nullifzero, and divisions of supported forms. Any other shape
   * yields None.
   */
  def rewriteMetricAggregate(metric: String): Either[Err,Option[PostAggregate]] = for {
    expression <- getProjection(metric).right.map(_.expression).right
  } yield {
    def rewriteRecursively(expression: Expression): Option[PostAggregate] = expression match {
      case expr @ FunctionCallExpression("sum", _, _) => Some(SumPostAggregate(expr))
      case FunctionCallExpression("zeroifnull", _, expr :: Nil) => rewriteRecursively(expr)
      case FunctionCallExpression("nullifzero", _, expr :: Nil) => rewriteRecursively(expr)
      case MathExpression("/", left, right) =>
        for {
          leftPA <- rewriteRecursively(left)
          rightPA <- rewriteRecursively(right)
        } yield {
          DividePostAggregate(leftPA, rightPA)
        }
      case _ => None
    }
    rewriteRecursively(expression)
  }

}

object OlapQuery {

  /**
   * Prunes `select` down to `keepProjections`. ORDER BY, GROUP BY and relations are
   * restricted to what those projections need. WHERE keeps only the conditions whose
   * placeholders are all in `availableParams`. Sub-selects are pruned recursively, and
   * both sides of a UNION are rewritten.
   */
  def rewriteSelect(select: Select, db: DB, keepProjections: List[ExpressionProjection], availableParams: Set[String]): Either[Err, Select] = {
    select match {
      case s: SimpleSelect => for {
        projections <- Right(s.projections.filter(x => keepProjections.contains(x))).right
        orderBy <- Right(s.orderBy.filter(f => keepProjections.map(_.expression).contains(f.expression))).right
        where <- Right(rewriteWhereCondition(s, availableParams)).right
        groupBy <- Right(rewriteGroupBy(s, keepProjections.map(_.expression))).right
        tables <- (keepProjections.filter(s.projections.contains).map(_.expression) ++ where.toList).foldRight(Right(Nil):Either[Err,List[String]]) {
          (p, acc) => for(a <- acc.right; b <- tablesFor(s, db, p).right) yield b ++ a
        }.right
        from <- Right(rewriteRelations(s, db, tables.toSet)).right
        rewrittenSelect <- Right(s.copy(
          projections = projections,
          relations = from,
          where = where,
          groupBy = groupBy,
          orderBy = orderBy)).right
        optimizedSelect <- optimizeSubSelect(rewrittenSelect, db, availableParams).right
      } yield {
        optimizedSelect
      }
      case UnionSelect(left, d, right) => for {
        rewrittenLeft <- rewriteSelect(left, db, keepProjections, availableParams).right
        rewrittenRight <- rewriteSelect(right, db, keepProjections, availableParams).right
      } yield {
        UnionSelect(rewrittenLeft, d, rewrittenRight)
      }
    }
  }

  /** Tables referenced by the columns of `expression`, resolved against the query view of `select`. */
  def getTableReferences(select: Select, db: DB, expression: Expression) = {
    select.getQueryView(db).right.flatMap { db =>
      expression.getColumnReferences.foldRight(Right(Nil):Either[Err,List[Table]]) {
        (column, acc) => for {
          a <- acc.right
          b <- (column match {
            case ColumnIdent(_, Some(TableIdent(tableName, Some(schemaName)))) =>
              for {
                schema <- db.view.getSchema(schemaName).right
                table <- schema.getTable(tableName).right
              } yield table
            case ColumnIdent(_, Some(TableIdent(tableName, None))) =>
              db.view.getNonAmbiguousTable(tableName).right.map(_._2)
            case ColumnIdent(column, None) =>
              db.view.getNonAmbiguousColumn(column).right.map(_._2)
          }).left.map(SchemaError(_, column.pos)).right
        } yield b :: a
      }
    }
  }

  /** Distinct, lower-cased names of the tables `expression` depends on. */
  def tablesFor(select: Select, db: DB, expression: Expression): Either[Err,List[String]] = {
    getTableReferences(select, db, expression).right.map(_.map(_.name.toLowerCase).distinct)
  }

  /**
   * Prunes each sub-select relation of `select` (including those nested in joins) down to
   * the aliased projections that the outer projections, join conditions or WHERE reference.
   */
  def optimizeSubSelect(select: SimpleSelect, db: DB, availableParams: Set[String]): Either[Err,Select] = {
    def optimizeRecursively(relation: Relation): Either[Err,Relation] = relation match {
      case rel @ SubSelectRelation(subSelect, table) =>
        for {
          db <- select.getQueryView(db).right
          optimized <- {
            val allExpressions = select.projections.collect {
              case ExpressionProjection(e, _) => e
            } ++ select.relations.flatMap(r => r :: r.visit).distinct.collect {
              case JoinRelation(_, _, _, Some(e)) => e
            } ++ select.where.toList

            val columnsUsed = allExpressions.flatMap(_.getColumnReferences).collect {
              case ColumnIdent(col, Some(TableIdent(`table`, None))) => col
              // Unqualified column: treat it as used unless it resolves to a different table.
              // This is conservative: if resolution fails (ambiguous or unknown name), the
              // projection is kept rather than letting `.right.get` throw.
              case ColumnIdent(col, None) if db.view.getNonAmbiguousColumn(col).right.forall(_._2.name == table) => col
            }

            val keepProjections = subSelect.projections.collect {
              case e @ ExpressionProjection(_, Some(alias)) if columnsUsed.contains(alias) => e
            }

            rewriteSelect(subSelect, db, keepProjections, availableParams)
          }.right
        } yield {
          rel.copy(select = optimized)
        }
      case rel @ JoinRelation(left, _, right, _) =>
        for {
          newLeft <- optimizeRecursively(left).right
          newRight <- optimizeRecursively(right).right
        } yield rel.copy(left = newLeft, right = newRight)
      case rel => Right(rel)
    }
    for {
      newRelations <- (select.relations.foldRight(Right(Nil):Either[Err,List[Relation]]) {
        (rel, acc) => for {
          a <- acc.right
          b <- optimizeRecursively(rel).right
        } yield b :: a
      }).right
    } yield select.copy(relations = newRelations)
  }

  /**
   * Drops relations whose table (or alias) is not in `tables`. This is repeated until a
   * fixpoint, adding the tables that the kept join conditions require.
   */
  def rewriteRelations(select: SimpleSelect, db: DB, tables: Set[String]): List[Relation] = {
    def rewritePass(relations: List[Relation], tables: Set[String]) = {
      def rewriteRecursively(relation: Relation): Option[Relation] = relation match {
        case rel @ SingleTableRelation(_, Some(table)) => if(tables.contains(table.toLowerCase)) Some(rel) else None
        case rel @ SingleTableRelation(TableIdent(table, _), None) => if(tables.contains(table.toLowerCase)) Some(rel) else None
        case rel @ SubSelectRelation(subSelect, table) => if(tables.contains(table.toLowerCase)) Some(rel) else None
        case rel @ JoinRelation(left, _, right, _) =>
          (rewriteRecursively(left) :: rewriteRecursively(right) :: Nil).flatten match {
            case Nil => None
            case x :: Nil => Some(x)
            case l :: r :: Nil => Some(rel.copy(left = l, right = r))
            case _ => sys.error("Unreacheable path")
          }
        case _ => None
      }
      relations.flatMap(rewriteRecursively)
    }

    // We start with the set of tables originally specified,
    // and after each rewrite we check that the rewritten joins
    // don't need more tables.
    def rewriteRecursively(previous: List[Relation], tableSet: Set[String]): List[Relation] = {
      def tablesUsedByRelations(relation: Relation): Set[String] = relation match {
        case JoinRelation(left, _, right, maybeOn) =>
          tablesUsedByRelations(left) ++ tablesUsedByRelations(right) ++ maybeOn.map { on =>
            tablesFor(select, db, on).right.map(_.toSet).right.getOrElse(Set.empty)
          }.getOrElse(Set.empty)
        case _ => Set.empty
      }
      val rewritten = rewritePass(select.relations, tableSet)
      if(rewritten != previous) {
        rewriteRecursively(rewritten, tableSet ++ rewritten.flatMap(tablesUsedByRelations))
      }
      else rewritten
    }

    rewriteRecursively(Nil, tables)
  }

  /**
   * Keeps the AND-conjuncts of `expression` whose named placeholders are all in `parameters`.
   * Returns None when no conjunct survives.
   */
  def rewriteExpression(parameters: Set[String], expression: Expression): Option[Expression] = {
    def rewriteRecursively(expression: Expression): Option[Expression] = expression match {
      case AndExpression(op, left, right) =>
        (rewriteRecursively(left) :: rewriteRecursively(right) :: Nil).flatten match {
          case Nil => None
          case x :: Nil => Some(x)
          case l :: r :: Nil => Some(AndExpression(op, l, r))
          case _ => sys.error("Unreacheable path")
        }
      case expr =>
        (expr.visitPlaceholders.map(_.name).flatten.toSet &~ parameters).toList match {
          case Nil => Some(expr)
          case _ => None
        }
    }
    rewriteRecursively(expression)
  }

  def rewriteWhereCondition(select: SimpleSelect, parameters: Set[String]): Option[Expression] = {
    select.where.flatMap(rewriteExpression(parameters, _))
  }

  /** Restricts GROUP BY (plain expressions and ROLLUP groups) to `expressions`, dropping emptied groups. */
  def rewriteGroupBy(select: SimpleSelect, expressions: List[Expression]): List[GroupBy] = {
    def rewriteGroupingSet(groupingSet: GroupingSet): Option[GroupingSet] = {
      Option(GroupingSet(groupingSet.groups.filter(expressions.contains))).filterNot(_.groups.isEmpty)
    }
    def rewriteRecursively(groupBy: GroupBy): Option[GroupBy] = groupBy match {
      case g @ GroupByExpression(e) => if(expressions.contains(e)) Some(g) else None
      case g @ GroupByRollup(groups) =>
        Option(GroupByRollup(groups.map {
          case g @ Left(e) => if(expressions.contains(e)) Some(g) else None
          case g @ Right(gs) => rewriteGroupingSet(gs).map(Right.apply)
        }.flatten)).filterNot(_.groups.isEmpty)
    }
    select.groupBy.flatMap(rewriteRecursively)
  }

}

//

/** Aggregation that can be recombined after the fact (sums and ratios of sums). */
trait PostAggregate
case class SumPostAggregate(expr: Expression) extends PostAggregate
case class DividePostAggregate(left: PostAggregate, right: PostAggregate) extends PostAggregate
--------------------------------------------------------------------------------
/shared/src/main/scala/com/criteo/vizatra/vizsql/optimize/Optimizer.scala:
--------------------------------------------------------------------------------
package com.criteo.vizatra.vizsql

class Optimizer(db : DB) {

  def optimize(sel : SimpleSelect) : SimpleSelect = {
    val newWhere = sel.where.map(preEvaluate)
    val newProj = sel.projections.map {
      case ExpressionProjection(exp, alias) => ExpressionProjection(preEvaluate(exp), alias)
      case x => x
    }
    val newRel = sel.relations.map(apply)
    val newOrder = sel.orderBy.map {case SortExpression(exp, ord) => SortExpression(preEvaluate(exp), ord) }
    val sel2 = SimpleSelect(sel.distinct, newProj, newRel, newWhere,sel.groupBy, sel.having, newOrder, sel.limit)
    val tables = (sel2.projections.collect {case ExpressionProjection(exp, _) => exp} ++ sel2.where.toList).foldRight(Right(Nil):Either[Err,List[String]]) {
      (p, acc) => for(a <- acc.right; b <- OlapQuery.tablesFor(sel, db, p).right) yield b ++ a
    }.right.getOrElse(Nil)
    SimpleSelect(sel2.distinct, sel2.projections, OlapQuery.rewriteRelations(sel2, db,tables.toSet), sel2.where, sel2.groupBy, sel2.having, sel2.orderBy, sel2.limit)
  }

  def apply(select : Select) :
Select = select match { 21 | case sel : SimpleSelect => optimize(sel) 22 | case UnionSelect(left, distinct, right) => UnionSelect(this.apply(left), distinct, apply(right)) 23 | } 24 | 25 | def apply(relation: Relation) : Relation = relation match { 26 | case JoinRelation(left, join, right, on) => JoinRelation(apply(left), join, apply(right), on.map(preEvaluate)) 27 | case SubSelectRelation(select, alias) => SubSelectRelation(apply(select), alias) 28 | case x => x 29 | } 30 | 31 | def preEvaluate(expr : Expression) : Expression = { 32 | def convert(v : Either[Literal, Expression]) = v match {case Left(x) => LiteralExpression(x); case Right(x) => x} 33 | def rec(expr: Expression) : Either[Literal, Expression] = expr match { 34 | case LiteralExpression(lit) => Left(lit) 35 | case ParenthesedExpression(exp) => rec(exp).right.map(x => ParenthesedExpression(x)) 36 | case MathExpression(op, left, right) => (rec(left), rec(right)) match { 37 | case (Left(x), Left(y)) => EvalHelper.mathEval(op, x, y) 38 | case (l, r) => Right(MathExpression(op, convert(l), convert(r))) 39 | } 40 | case UnaryMathExpression(op, exp) => rec(exp) match { 41 | case x if op == "+" => x 42 | case Left(lit) if op == "-" => lit match { 43 | case DecimalLiteral(x) => Left(DecimalLiteral(-x)) 44 | case IntegerLiteral(x) => Left(IntegerLiteral(-x)) 45 | case _ => Right(UnaryMathExpression(op, LiteralExpression(lit))) 46 | } 47 | case x => Right(UnaryMathExpression(op, convert(x))) 48 | } 49 | case ComparisonExpression(op, left, right) => (rec(left), rec(right)) match { 50 | case (Left(x), Left(y)) if EvalHelper.compOperators.contains(op) => EvalHelper.compEval(op, x, y) 51 | case (l, r) => Right(ComparisonExpression(op, convert(l), convert(r))) 52 | } 53 | case AndExpression(op, left, right) => (rec(left), rec(right)) match { 54 | case (Left(x), Left(y)) => Left(EvalHelper.bool2Literal(EvalHelper.literal2bool(x) && EvalHelper.literal2bool(y))) 55 | case (l, r) => Right(AndExpression(op, convert(l), 
convert(r))) 56 | } 57 | case OrExpression(op, left, right) => (rec(left), rec(right)) match { 58 | case (Left(x), Left(y)) => Left(EvalHelper.bool2Literal(EvalHelper.literal2bool(x) || EvalHelper.literal2bool(y))) 59 | case (l, r) => Right(OrExpression(op, convert(l), convert(r))) 60 | } 61 | case IsInExpression(left, not, right) => (rec(left), right.map(rec)) match { 62 | case (l @ Left(_), r) if not && r.contains(l) => Left(FalseLiteral) 63 | case (l @ Left(_), r) if r.contains(l) => Left(TrueLiteral) 64 | case (l, r) => Right(IsInExpression(convert(l), not, 65 | r.map(convert))) 66 | } 67 | case FunctionCallExpression(f, d, args) => Right(FunctionCallExpression(f, d, args.map(x => convert(rec(x))))) 68 | case NotExpression(exp) => rec(exp) match { 69 | case Left(l) => Left(EvalHelper.bool2Literal(l == FalseLiteral)) 70 | case Right(ex) => Right(NotExpression(ex)) 71 | } 72 | case IsExpression(exp, not, lit) => rec(exp) match { 73 | case Left(lit2) => Left(EvalHelper.bool2Literal(not ^ lit == lit2)) 74 | case Right(ex) => Right(IsExpression(ex, not, lit)) 75 | } 76 | case CastExpression(from, to) => Right(CastExpression(convert(rec(from)), to)) 77 | case CaseWhenExpression(value, mapping,elseVal) => 78 | val valueRes = value.map(rec) 79 | val elseRes = elseVal.map(rec) 80 | def recList(l : List[(Either[Literal, Expression],Either[Literal, Expression])]): Either[Literal, Expression] = (valueRes, l, elseRes) match { 81 | case (Some(matchValue), mapVals@(h::_), elseClause) if matchValue.isRight => Right(CaseWhenExpression(Some(convert(matchValue)), 82 | mapVals.map(x => (convert(x._1), convert(x._2))), elseClause.map(convert))) 83 | case (Some(matchValue), h :: t, _) if h._1 == matchValue => h._2 84 | case (Some(matchValue), mapVals @ (h :: t), elseClause) if h._1.isRight => Right(CaseWhenExpression(Some(convert(matchValue)), 85 | mapVals.filter(x => x._1.isRight || x._1 == matchValue).map(x => (convert(x._1), convert(x._2))), elseClause.map(convert))) 86 | case 
(None, h::t, _) if h._1 == Left(TrueLiteral) => h._2 87 | case (None, mapVals @ (h::t), elseClause) if h._1.isRight => Right(CaseWhenExpression(None, 88 | mapVals.toList.filter(x => x._1.isRight || x._1 == Left(TrueLiteral)).map(x => (convert(x._1), convert(x._2))), elseClause.map(convert))) 89 | case (matchValue, Nil, Some(elseClause)) => elseClause 90 | case (matchValue, Nil, None) => Left(NullLiteral) 91 | case (_, h :: t, _) => recList(t) 92 | } 93 | recList(mapping.map(x => (rec(x._1), rec(x._2)))) 94 | case LikeExpression(left, not, op, right) => (rec(left), rec(right)) match { 95 | case (Left(l), Left(r)) if l == r && !not=> Left(TrueLiteral) 96 | case (l, r) => Right(LikeExpression(convert(l), not, op, convert(r))) 97 | } 98 | case SubSelectExpression(select) => Right(SubSelectExpression(apply(select))) 99 | case _ => Right(expr) 100 | } 101 | convert(rec(expr)) 102 | } 103 | } 104 | 105 | object Optimizer { 106 | def optimize(query : Query) : Query = { 107 | val sel = new Optimizer(query.db).optimize(query.select) 108 | Query(sel.toSQL, sel, query.db) 109 | } 110 | } 111 | --------------------------------------------------------------------------------