├── .gitignore ├── LICENSE.txt ├── README.md ├── TODO ├── core └── src │ ├── main │ └── scala │ │ ├── analysis.scala │ │ ├── ast.scala │ │ ├── csv.scala │ │ ├── database.scala │ │ ├── dialects.scala │ │ ├── jdbc.scala │ │ ├── macro.scala │ │ ├── package.scala │ │ ├── parser.scala │ │ ├── record.scala │ │ ├── timer.scala │ │ ├── typer.scala │ │ ├── typesigdsl.scala │ │ └── validator.scala │ └── test │ ├── resources │ ├── test-postgresql.sql │ └── test.sql │ └── scala │ ├── dynamicexamples.scala │ ├── examples.scala │ ├── failures.scala │ ├── mysqlexamples.scala │ ├── postgreexamples.scala │ └── recordexamples.scala ├── demo ├── README.md ├── project │ ├── build.properties │ └── build.scala └── src │ └── main │ ├── resources │ └── schema.sql │ └── scala │ ├── db.scala │ ├── package.scala │ ├── server.scala │ └── testdata.scala ├── docs ├── phases.dot └── phases.png ├── json4s └── src │ ├── main │ └── scala │ │ └── json.scala │ └── test │ └── scala │ └── jsonexample.scala ├── notes ├── 0.1.0.markdown ├── 0.2.0.markdown ├── 0.3.0.markdown ├── 0.4.0.markdown └── about.markdown ├── project ├── build.properties ├── build.scala ├── plugins.sbt └── publish.scala └── slick-integration └── src └── test └── scala └── slickexample.scala /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | target 3 | project/target -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 
12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | sqlτyped - a macro which infers Scala types by analysing SQL statements 2 | ======================================================================= 3 | 4 | 5 | _Towards a perfect impedance match..._ 6 | 7 | * The types and column names are already defined in database schema and SQL query. Why not use those and infer types and accessor functions? 8 | 9 | * SQL is a fine DSL for many queries. It is the native DSL of relational databases and wrapping it with another DSL is often unnecessary (SQL sucks when one has to compose queries, or if you have to be database agnostic). 10 | 11 | 12 | **sqlτyped converts SQL string literals into typed functions at compile time.** 13 | 14 | ```sql 15 | select age, name from person where age > ? 
16 | ``` 17 | 18 | ==> 19 | 20 | ```scala 21 | Int => List[{ age: Int, name: String }] 22 | ``` 23 | 24 | 25 | Examples 26 | -------- 27 | 28 | The following examples use schema and data from [test.sql](https://github.com/jonifreeman/sqltyped/blob/master/core/src/test/resources/test.sql) 29 | 30 | First some boring initialization... 31 | 32 | Start console: ```sbt```, then ```project sqltyped``` and ```test:console```. 33 | 34 | ```scala 35 | import java.sql._ 36 | import sqltyped._ 37 | Class.forName("com.mysql.jdbc.Driver") 38 | implicit def conn = DriverManager.getConnection("jdbc:mysql://localhost:3306/sqltyped", 39 | "root", "") 40 | ``` 41 | 42 | Now we are ready to query the data. 43 | 44 | ```scala 45 | scala> val q = sql("select name, age from person") 46 | scala> q() map (_ get "age") 47 | res0: List[Int] = List(36, 14) 48 | ``` 49 | 50 | Notice how the type of 'age' was infered to be Int. 51 | 52 | ```scala 53 | scala> q() map (_ get "salary") 54 | :24: error: No field String("salary") in record ... 55 | q() map (_ get "salary") 56 | ``` 57 | 58 | Oops, a compilation failure. Can't access 'salary', it was not selected in the query. 59 | 60 | Query results are returned as List of type safe records (think ```List[{name:String, age:Int}]```). 61 | As the above examples showed a field of a record can be accessed with get function: ```row.get(name)```. 62 | Functions ```values``` and ```tuples``` can be used to drop record names and get just the query values. 63 | 64 | ```scala 65 | scala> q().values 66 | res1: List[shapeless.::[String,shapeless.::[Int,shapeless.HNil]]] = 67 | List(joe :: 36 :: HNil, moe :: 14 :: HNil) 68 | 69 | scala> q().tuples 70 | res2: List[(String, Int)] = List((joe,36), (moe,14)) 71 | ``` 72 | 73 | Input parameters are parsed and typed. 
74 | 75 | ```scala 76 | scala> val q = sql("select name, age from person where age > ?") 77 | 78 | scala> q("30") map (_ get "name") 79 | :24: error: type mismatch; 80 | found : String("30") 81 | required: Int 82 | q("30") map (_ get name) 83 | 84 | scala> q(30) map (_ get "name") 85 | res4: List[String] = List(joe) 86 | ``` 87 | 88 | Nullable columns are inferred to be Scala Options. 89 | 90 | ```scala 91 | scala> val q = sql("""select p.name, j.name as employer, j.started, j.resigned 92 | from person p join job_history j on p.id=j.person order by employer""") 93 | scala> q().tuples 94 | res5: List[(String, String, java.sql.Timestamp, Option[java.sql.Timestamp])] = 95 | List((joe,Enron,2002-08-02 12:00:00.0,Some(2004-06-22 18:00:00.0)), 96 | (joe,IBM,2004-07-13 11:00:00.0,None)) 97 | ``` 98 | 99 | Functions are supported too. Note how function 'max' is polymorphic on its argument. For String 100 | column it is typed as String => String etc. 101 | 102 | ```scala 103 | scala> val q = sql("select max(name) as name, max(age) as age from person where age > ?") 104 | scala> q(10).tupled 105 | res6: (Option[String], Option[Int]) = (Some(moe),Some(36)) 106 | ``` 107 | 108 | ### Analysis ### 109 | 110 | So far all the examples have returned results as Lists of records. But with a little bit of query 111 | analysis we can do better. Like, it is quite unnecessary to box the values as records if just one 112 | column is selected. 113 | 114 | ```scala 115 | scala> sql("select name from person").apply 116 | res7: List[String] = List(joe, moe) 117 | 118 | scala> sql("select age from person").apply 119 | res8: List[Int] = List(36, 14) 120 | ``` 121 | 122 | Then, some queries are known to return just 0 or 1 values, a perfect match for Option type. 123 | The following queries return possible result as an Option instead of List. The first query uses 124 | a uniquely constraint column in its where clause. The second one explicitly wants at most one row. 
125 | 126 | ```scala 127 | scala> sql("select name from person where id=?").apply(1) 128 | res9: Some[String] = Some(joe) 129 | 130 | scala> sql("select age from person order by age desc limit 1").apply 131 | res10: Some[Int] = Some(36) 132 | ``` 133 | 134 | ### Inserting data ### 135 | 136 | ```scala 137 | scala> sql("insert into person(name, age, salary) values (?, ?, ?)").apply("bill", 45, 30000) 138 | res1: Int = 1 139 | ``` 140 | 141 | Return value was 1, which means that one row was added. However, often a more useful return value 142 | is the generated primary key. Table 'person' has an autogenerated primary key column named 'id'. To get 143 | the generated value use a function ```sqlk``` (will be changed to ```sql(..., keys = true)``` once 144 | [Scala macros](https://issues.scala-lang.org/browse/SI-5920) support default and named arguments). 145 | 146 | ```scala 147 | scala> sqlk("insert into person(name, age, salary) values (?, ?, ?)").apply("jill", 45, 30000) 148 | res2: Long = 3 149 | ``` 150 | 151 | Inserting multiple values is supported too. 152 | 153 | ```scala 154 | scala> sqlk("insert into person(name, age, salary) select name, age, salary from person").apply 155 | res3: List[Long] = List(4, 5, 6) 156 | ``` 157 | 158 | Updates work as expected. 159 | 160 | ```scala 161 | scala> sql("update person set name=? where age >= ?").apply("joe2", 30) 162 | res4: Int = 1 163 | ``` 164 | 165 | 166 | Documentation 167 | ------------- 168 | 169 | See [wiki](https://github.com/jonifreeman/sqltyped/wiki). 170 | 171 | [Demo app](https://github.com/jonifreeman/sqltyped/tree/master/demo) 172 | 173 | How to try it? 174 | -------------- 175 | 176 | ### Install ### 177 | 178 | Requires at least Scala 2.10.2 and SBT 0.13. 179 | 180 | sqlτyped is published to Sonatype repositories. 
181 | 182 | ```scala 183 | "fi.reaktor" %% "sqltyped" % "0.4.3" 184 | ``` 185 | 186 | ### Build ### 187 | 188 | git clone https://github.com/jonifreeman/sqltyped.git 189 | cd sqltyped 190 | 191 | Then either: 192 | 193 | mysql -u root -e 'create database sqltyped' 194 | mysql -u root sqltyped < core/src/test/resources/test.sql 195 | 196 | or: 197 | 198 | sudo -u postgres createuser -P sqltypedtest // Note, change the password from project/build.scala 199 | sudo -u postgres createdb -O sqltypedtest sqltyped 200 | sudo -u postgres psql sqltyped < core/src/test/resources/test-postgresql.sql 201 | 202 | To run the tests you need to setup both databases. 203 | 204 | Credits 205 | ------- 206 | 207 | *(in order of appearance)* 208 | 209 | * Joni Freeman 210 | * Dylan Alex Simon 211 | * Vassil Dichev 212 | -------------------------------------------------------------------------------- /TODO: -------------------------------------------------------------------------------- 1 | ======== 0.4 ======== 2 | 3 | - Analysis: aggregate function over non-nullable column can't be null if GROUP BY? 4 | 5 | ======== 0.5 ======== 6 | 7 | - Schemaprefixes 8 | - Recursive unions (postgresql) 9 | - Views? 10 | - Bug: select max(id) from jackpot where `group` = ? ('group' works) 11 | 12 | ======== Backlog ======== 13 | 14 | - 2.11: Replace Query(n) with SqlF 15 | - 2.11: Batch insert and update (after SqlF) 16 | - Infer function types from schema 17 | - Handle dialect specific keywords properly (e.g. http://www.postgresql.org/docs/9.2/static/sql-keywords-appendix.html) 18 | - 2.11: Better interpolation support (table + column names etc.) 
19 | - DB REPL 20 | -------------------------------------------------------------------------------- /core/src/main/scala/analysis.scala: -------------------------------------------------------------------------------- 1 | package sqltyped 2 | 3 | import Ast._ 4 | import NumOfResults._ 5 | 6 | class Analyzer(typer: Typer) extends Ast.Resolved { 7 | def refine(stmt: Statement, typed: TypedStatement): ?[TypedStatement] = 8 | analyzeSelection(stmt, typed.copy(numOfResults = analyzeResults(stmt, typed))).ok 9 | 10 | /** 11 | * Statement returns 0 - 1 rows if, 12 | * 13 | * - It is SQL insert 14 | * - It has no joins and it contains only 'and' expressions in its where clause and at least one of 15 | * those targets unique constraint with '=' operator 16 | * - It has LIMIT 1 clause 17 | * 18 | * Statement returns 1 row if, 19 | * 20 | * - The projection contains aggregate function and there's no group by 21 | */ 22 | private def analyzeResults(stmt: Statement, typed: TypedStatement): NumOfResults = { 23 | import scala.math.Ordering.Implicits._ 24 | 25 | def inWhereClause(s: Select, cols: List[Column]) = { 26 | def inExpr(e: Expr, col: Column): Boolean = e match { 27 | // note, column comparision works since we only examine statements with one table 28 | case Comparison1(_, _) => false 29 | case Comparison2(Column(n, _), Eq, _) => col.name == n 30 | case Comparison2(_, Eq, Column(n, _)) => col.name == n 31 | case Comparison2(_, _, _) => false 32 | case Comparison3(_, _, _, _) => false 33 | case And(e1, e2) => inExpr(e1, col) || inExpr(e2, col) 34 | case Or(e1, e2) => inExpr(e1, col) || inExpr(e2, col) 35 | case Not(e) => inExpr(e, col) 36 | } 37 | s.where.map(w => cols.map(col => inExpr(w.expr, col)).forall(identity)).getOrElse(false) 38 | } 39 | 40 | def hasLimit1(s: Select) = s.limit.map { 41 | _.count match { 42 | case Left(x) => x == 1 43 | case _ => false 44 | } 45 | } getOrElse false 46 | 47 | def hasAggregate(projection: List[Named]) = { 48 | def collectFs(term: 
Term): List[Function] = term match { 49 | case f@Function(_, _) => f :: Nil 50 | case Comparison1(t, _) => collectFs(t) 51 | case Comparison2(lhs, _, rhs) => collectFs(lhs) ::: collectFs(rhs) 52 | case Comparison3(t1, _, t2, t3) => collectFs(t1) ::: collectFs(t2) ::: collectFs(t3) 53 | case ArithExpr(lhs, _, rhs) => collectFs(lhs) ::: collectFs(rhs) 54 | case _ => Nil 55 | } 56 | 57 | projection map (_.term) flatMap collectFs map (_.name) exists typer.isAggregate 58 | } 59 | 60 | def hasJoin(t: TableReference) = t match { 61 | case ConcreteTable(_, join) => join.length > 0 62 | case DerivedTable(_, s, join) => join.length > 0 63 | } 64 | 65 | stmt match { 66 | case s@Select(projection, _, _, None, _, _) if hasAggregate(projection) => One 67 | case s@Select(_, tableRefs, where, _, _, _) => 68 | if ((tableRefs.length == 1 && !hasJoin(tableRefs.head) && 69 | where.isDefined && hasNoOrExprs(s) && 70 | typed.uniqueConstraints(tableRefs.head.tables.head).exists(c => inWhereClause(s, c))) || 71 | hasLimit1(s)) 72 | ZeroOrOne 73 | else 74 | Many 75 | case Insert(_, _, SelectedInput(s)) => analyzeResults(s, typed) 76 | case Insert(_, _, _) => One 77 | case Update(_, _, _, _, _) => One 78 | case Delete(_, _) => One 79 | case Create() => One 80 | case SetStatement(s1, _, s2, _, _) => 81 | analyzeResults(s1, typed) max analyzeResults(s2, typed) 82 | case Composed(s1, s2) => 83 | analyzeResults(s1, typed) max analyzeResults(s2, typed) 84 | } 85 | } 86 | 87 | /** 88 | * Selected column is not optional if, 89 | * 90 | * - it is restricted with IS NOT NULL and the expression contains only 'and' operators 91 | */ 92 | private def analyzeSelection(stmt: Statement, typed: TypedStatement): TypedStatement = { 93 | def isNotNull(col: Term, where: Where) = (where.expr find { 94 | case Comparison1(t, IsNotNull) if t == col => 95 | true 96 | case _ => 97 | false 98 | }).isDefined 99 | 100 | stmt match { 101 | case s@Select(projection, _, Some(where), _, _, _) if hasNoOrExprs(s) => 102 | 
typed.copy(output = typed.output map { case t => 103 | if (t.nullable && isNotNull(t.term, where)) t.copy(nullable = false) 104 | else t 105 | }) 106 | case _ => typed 107 | } 108 | } 109 | 110 | private def hasNoOrExprs(s: Select) = 111 | s.where.map(w => !w.expr.find { 112 | case Or(_, _) => true 113 | case _ => false 114 | }.isDefined) getOrElse false 115 | } 116 | -------------------------------------------------------------------------------- /core/src/main/scala/ast.scala: -------------------------------------------------------------------------------- 1 | package sqltyped 2 | 3 | import schemacrawler.schema.Schema 4 | import scala.reflect.runtime.universe.{Type, typeOf} 5 | 6 | private[sqltyped] object Ast { 7 | // Types used for AST when references to tables are not yet resolved 8 | // (table is optional string reference). 9 | trait Unresolved { 10 | type Expr = Ast.Expr[Option[String]] 11 | type Term = Ast.Term[Option[String]] 12 | type Named = Ast.Named[Option[String]] 13 | type Statement = Ast.Statement[Option[String]] 14 | type ArithExpr = Ast.ArithExpr[Option[String]] 15 | type Comparison = Ast.Comparison[Option[String]] 16 | type Column = Ast.Column[Option[String]] 17 | type Function = Ast.Function[Option[String]] 18 | type Constant = Ast.Constant[Option[String]] 19 | type Case = Ast.Case[Option[String]] 20 | type Select = Ast.Select[Option[String]] 21 | type Join = Ast.Join[Option[String]] 22 | type JoinType = Ast.JoinType[Option[String]] 23 | type TableReference = Ast.TableReference[Option[String]] 24 | type ConcreteTable = Ast.ConcreteTable[Option[String]] 25 | type DerivedTable = Ast.DerivedTable[Option[String]] 26 | type Where = Ast.Where[Option[String]] 27 | type OrderBy = Ast.OrderBy[Option[String]] 28 | type Limit = Ast.Limit[Option[String]] 29 | } 30 | object Unresolved extends Unresolved 31 | 32 | // Types used for AST when references to tables are resolved 33 | trait Resolved { 34 | type Expr = Ast.Expr[Table] 35 | type Term = 
Ast.Term[Table] 36 | type Named = Ast.Named[Table] 37 | type Statement = Ast.Statement[Table] 38 | type ArithExpr = Ast.ArithExpr[Table] 39 | type Comparison = Ast.Comparison[Table] 40 | type Column = Ast.Column[Table] 41 | type Function = Ast.Function[Table] 42 | type Constant = Ast.Constant[Table] 43 | type Case = Ast.Case[Table] 44 | type Select = Ast.Select[Table] 45 | type Join = Ast.Join[Table] 46 | type JoinType = Ast.JoinType[Table] 47 | type TableReference = Ast.TableReference[Table] 48 | type ConcreteTable = Ast.ConcreteTable[Table] 49 | type DerivedTable = Ast.DerivedTable[Table] 50 | type Where = Ast.Where[Table] 51 | type OrderBy = Ast.OrderBy[Table] 52 | type Limit = Ast.Limit[Table] 53 | } 54 | object Resolved extends Resolved 55 | 56 | sealed trait Term[T] 57 | 58 | case class Named[T](name: String, alias: Option[String], term: Term[T]) { 59 | def aname = alias getOrElse name 60 | } 61 | 62 | case class Constant[T](tpe: (Type, Int), value: Any) extends Term[T] 63 | case class Column[T](name: String, table: T) extends Term[T] 64 | case class AllColumns[T](table: T) extends Term[T] 65 | case class Function[T](name: String, params: List[Expr[T]]) extends Term[T] 66 | case class ArithExpr[T](lhs: Term[T], op: String, rhs: Term[T]) extends Term[T] 67 | case class Input[T]() extends Term[T] 68 | case class Subselect[T](select: Select[T]) extends Term[T] 69 | case class TermList[T](terms: List[Term[T]]) extends Term[T] 70 | case class Case[T](conditions: List[(Expr[T], Term[T])], elze: Option[Term[T]]) extends Term[T] 71 | 72 | case class Table(name: String, alias: Option[String], schema: Option[String]) 73 | 74 | sealed trait Operator1 75 | case object IsNull extends Operator1 76 | case object IsNotNull extends Operator1 77 | case object Exists extends Operator1 78 | case object NotExists extends Operator1 79 | 80 | sealed trait Operator2 81 | case object Eq extends Operator2 82 | case object Neq extends Operator2 83 | case object Lt extends Operator2 84 
case object Gt extends Operator2
case object Le extends Operator2
case object Ge extends Operator2
case object In extends Operator2
case object NotIn extends Operator2
case object Like extends Operator2

sealed trait Operator3
case object Between extends Operator3
case object NotBetween extends Operator3

/** Boolean expression tree over terms of type T (T is the table-reference type). */
sealed trait Expr[T] {
  /** Pre-order search for the first sub-expression satisfying p (descends only And/Or). */
  def find(p: Expr[T] => Boolean): Option[Expr[T]] =
    if (p(this)) Some(this)
    else this match {
      case And(e1, e2) => e1.find(p) orElse e2.find(p)
      case Or(e1, e2) => e1.find(p) orElse e2.find(p)
      case _ => None
    }
}

sealed trait Comparison[T] extends Expr[T] with Term[T]

case class DataType(name: String, precision: List[Int] = Nil)

case class SimpleExpr[T](term: Term[T]) extends Expr[T]
case class Comparison1[T](term: Term[T], op: Operator1) extends Comparison[T]
case class Comparison2[T](lhs: Term[T], op: Operator2, rhs: Term[T]) extends Comparison[T]
case class Comparison3[T](t1: Term[T], op: Operator3, t2: Term[T], t3: Term[T]) extends Comparison[T]
case class And[T](e1: Expr[T], e2: Expr[T]) extends Expr[T]
case class Or[T](e1: Expr[T], e2: Expr[T]) extends Expr[T]
case class Not[T](e: Expr[T]) extends Expr[T]
case class TypeExpr[T](dataType: DataType) extends Expr[T]

// Parametrized by Table type (Option[String] or Table)
sealed trait Statement[T] {
  def tables: List[Table]
  def isQuery = false
}

/** Returns the Join that brings `col`'s table into `stmt`'s projection, if any. */
def isProjectedByJoin(stmt: Statement[Table], col: Column[Table]): Option[Join[Table]] = stmt match {
  case Select(_, tableRefs, _, _, _, _) => (tableRefs flatMap {
    case ConcreteTable(t, joins) => if (col.table == t) None else isProjectedByJoin(joins, col)
    case DerivedTable(_, _, joins) => isProjectedByJoin(joins, col)
  }).headOption
  case Composed(l, r) => isProjectedByJoin(l, col) orElse isProjectedByJoin(r, col)
  case SetStatement(l, _, r, _, _) => isProjectedByJoin(l, col) orElse isProjectedByJoin(r, col)
  case _ => None
}

def isProjectedByJoin(joins: List[Join[Table]], col: Column[Table]): Option[Join[Table]] =
  joins.flatMap(j => isProjectedByJoin(j, col)).headOption

def isProjectedByJoin(join: Join[Table], col: Column[Table]): Option[Join[Table]] = join.table match {
  case ConcreteTable(t, joins) => if (col.table == t) Some(join) else isProjectedByJoin(joins, col)
  case DerivedTable(_, _, joins) => isProjectedByJoin(joins, col)
}

/**
 * Returns a Statement where all columns have their tables resolved.
 */
def resolveTables(stmt: Statement[Option[String]]): ?[Statement[Table]] = stmt match {
  case s@Select(_, _, _, _, _, _) => resolveSelect(s)()
  case d@Delete(_, _) => resolveDelete(d)()
  case u@Update(_, _, _, _, _) => resolveUpdate(u)()
  case Create() => Create[Table]().ok
  case i@Insert(_, _, _) => resolveInsert(i)()
  case s@SetStatement(_, _, _, _, _) => resolveSetStatement(s)()
  case c@Composed(_, _) => resolveComposed(c)()
}

/** Resolution environment: the tables in scope for name/alias lookup. */
private class ResolveEnv(env: List[Table]) {
  def resolve(term: Term[Option[String]]): ?[Term[Table]] = term match {
    case col@Column(_, _) => resolveColumn(col)
    case AllColumns(t) => resolveAllColumns(t)
    case f@Function(_, ps) => resolveFunc(f)
    // Subselects see their own tables plus the enclosing environment.
    case Subselect(select) => resolveSelect(select)(select.tables ::: env) map (s => Subselect(s))
    case ArithExpr(lhs, op, rhs) =>
      for { l <- resolve(lhs); r <- resolve(rhs) } yield ArithExpr(l, op, r)
    case c: Comparison[Option[String]] => resolveComparison(c)
    case Constant(tpe, value) => Constant[Table](tpe, value).ok
    case TermList(terms) => sequence(terms map resolve) map (ts => TermList[Table](ts))
    case Input() => Input[Table]().ok
    case c@Case(_, _) => resolveCase(c)
  }

  // A column with no table qualifier resolves to the first table in scope.
  def resolveColumn(col: Column[Option[String]]) =
    env find { t =>
      (col.table, t.alias) match {
        case (Some(ref), None) => t.name == ref
        case (Some(ref), Some(a)) => t.name == ref || a == ref
        case (None, _) => true
      }
    } map (t => col.copy(table = t)) orFail ("Column references unknown table " + col)

  def resolveAllColumns(tableRef: Option[String]) = tableRef match {
    case Some(ref) =>
      (env.find(t => t.name == ref || t.alias.map(_ == ref).getOrElse(false)) orFail
        ("Unknown table '" + ref + "'")) map (r => AllColumns(r))
    case None =>
      AllColumns(env.head).ok
  }

  def resolveCase(c: Case[Option[String]]) = for {
    exprs <- sequence(c.conditions map { case (x, _) => resolveExpr(x) })
    results <- sequence(c.conditions map { case (_, x) => resolve(x) })
    elze <- sequenceO(c.elze map resolve)
  } yield Case(exprs zip results, elze)

  def resolveNamed(n: Named[Option[String]]) = resolve(n.term) map (t => n.copy(term = t))
  def resolveFunc(f: Function[Option[String]]) = sequence(f.params map resolveExpr) map (ps => f.copy(params = ps))
  def resolveProj(proj: List[Named[Option[String]]]) = sequence(proj map resolveNamed)
  def resolveTableRefs(ts: List[TableReference[Option[String]]]) = sequence(ts map resolveTableRef)
  def resolveTableRef(t: TableReference[Option[String]]): ?[TableReference[Table]] = t match {
    case f@ConcreteTable(_, join) => sequence(join map resolveJoin) map (j => f.copy(join = j))
    case f@DerivedTable(_, select, join) => for {
      s <- resolveSelect(select)()
      j <- sequence(join map resolveJoin)
    } yield f.copy(subselect = s, join = j)
  }
  def resolveJoin(join: Join[Option[String]]) = for {
    t <- resolveTableRef(join.table)
    j <- sequenceO(join.joinType map resolveJoinType)
  } yield join.copy(table = t, joinType = j)
  def resolveJoinType(t: JoinType[Option[String]]): ?[JoinType[Table]] = t match {
    case QualifiedJoin(e) => resolveExpr(e) map QualifiedJoin.apply
    case NamedColumnsJoin(cs) => NamedColumnsJoin(cs).ok
  }
  def resolveWhere(where: Where[Option[String]]) = resolveExpr(where.expr) map Where.apply
  def resolveWhereOpt(where: Option[Where[Option[String]]]) = sequenceO(where map resolveWhere)
  def resolveGroupBy(groupBy: GroupBy[Option[String]]) = for {
    t <- sequence(groupBy.terms map resolve)
    h <- resolveHavingOpt(groupBy.having)
  } yield groupBy.copy(terms = t, having = h)
  def resolveGroupByOpt(groupBy: Option[GroupBy[Option[String]]]) = sequenceO(groupBy map resolveGroupBy)
  def resolveHaving(having: Having[Option[String]]) = resolveExpr(having.expr) map Having.apply
  def resolveHavingOpt(having: Option[Having[Option[String]]]) = sequenceO(having map resolveHaving)
  def resolveOrderBy(orderBy: OrderBy[Option[String]]) = sequence(orderBy.sort map resolve) map (s => orderBy.copy(sort = s))
  def resolveOrderByOpt(orderBy: Option[OrderBy[Option[String]]]) = sequenceO(orderBy map resolveOrderBy)
  def resolveLimit(limit: Limit[Option[String]]) =
    Limit[Table](
      limit.count.right map (_ => Input[Table]()),
      limit.offset.map(_.right map (_ => Input[Table]()))
    ).ok
  def resolveLimitOpt(limit: Option[Limit[Option[String]]]) = sequenceO(limit map resolveLimit)

  def resolveExpr(e: Expr[Option[String]]): ?[Expr[Table]] = e match {
    case SimpleExpr(t) => resolve(t) map SimpleExpr.apply
    case c: Comparison[Option[String]] => resolveComparison(c)
    case And(e1, e2) =>
      for { r1 <- resolveExpr(e1); r2 <- resolveExpr(e2) } yield And(r1, r2)
    case Or(e1, e2) =>
      for { r1 <- resolveExpr(e1); r2 <- resolveExpr(e2) } yield Or(r1, r2)
    case Not(e) =>
      for { r <- resolveExpr(e) } yield Not(r)
    case TypeExpr(d) => TypeExpr(d).ok
  }

  def resolveComparison(c: Comparison[Option[String]]) = c match {
    case p@Comparison1(t1, op) =>
      resolve(t1) map (t => p.copy(term = t))
    case p@Comparison2(t1, op, t2) =>
      for { l <- resolve(t1); r <- resolve(t2) } yield p.copy(lhs = l, rhs = r)
    case p@Comparison3(t1, op, t2, t3) =>
      for { r1 <- resolve(t1); r2 <- resolve(t2); r3 <- resolve(t3) } yield p.copy(t1 = r1, t2 = r2, t3 = r3)
  }
}

private def resolveSelect(s: Select[Option[String]])(env: List[Table] = s.tables): ?[Select[Table]] = {
  val r = new ResolveEnv(env)
  for {
    p <- r.resolveProj(s.projection)
    t <- r.resolveTableRefs(s.tableReferences)
    w <- r.resolveWhereOpt(s.where)
    g <- r.resolveGroupByOpt(s.groupBy)
    o <- r.resolveOrderByOpt(s.orderBy)
    l <- r.resolveLimitOpt(s.limit)
  } yield s.copy(projection = p, tableReferences = t, where = w, groupBy = g, orderBy = o, limit = l)
}

private def resolveInsert(i: Insert[Option[String]])(env: List[Table] = i.tables): ?[Insert[Table]] = {
  val r = new ResolveEnv(env)
  (i.insertInput match {
    case SelectedInput(select) => resolveSelect(select)() map SelectedInput.apply
    case ListedInput(vals) =>
      if (i.colNames map (_.length != vals.length) getOrElse false)
        fail("Number of column names do not match with number of inputs")
      else sequence(vals map r.resolve) map ListedInput.apply
  }) map (in => i.copy(insertInput = in))
}

private def resolveSetStatement(s: SetStatement[Option[String]])(env: List[Table] = s.tables): ?[SetStatement[Table]] = {
  val r = new ResolveEnv(env)
  for {
    le <- resolveTables(s.left)
    ri <- resolveTables(s.right)
    o <- r.resolveOrderByOpt(s.orderBy)
    l <- r.resolveLimitOpt(s.limit)
  } yield SetStatement(le, s.op, ri, o, l)
}

private def resolveComposed(c: Composed[Option[String]])(env: List[Table] = c.tables): ?[Composed[Table]] = {
  // NOTE(review): this ResolveEnv is unused, and the `r` generator below shadows it.
  val r = new ResolveEnv(env)
  for {
    l <- resolveTables(c.left)
    r <- resolveTables(c.right)
  } yield Composed(l, r)
}

private def resolveDelete(d: Delete[Option[String]])(env: List[Table] = d.tables): ?[Delete[Table]] = {
  val r = new ResolveEnv(env)
  for {
    w <- r.resolveWhereOpt(d.where)
  } yield d.copy(where = w)
}

private def resolveUpdate(u: Update[Option[String]])(env: List[Table] = u.tables): ?[Update[Table]] = {
  val r = new ResolveEnv(env)

  def resolveSet(c: Column[Option[String]], t: Term[Option[String]]) = for {
    rc <- r.resolveColumn(c)
    rt <- r.resolve(t)
  } yield (rc, rt)

  for {
    s <- sequence(u.set map { case (c, t) => resolveSet(c, t) })
    w <- r.resolveWhereOpt(u.where)
    o <- r.resolveOrderByOpt(u.orderBy)
    l <- r.resolveLimitOpt(u.limit)
  } yield u.copy(set = s, where = w, orderBy = o, limit = l)
}

case class Delete[T](tables: List[Table], where: Option[Where[T]]) extends Statement[T]

sealed trait InsertInput[T]
case class ListedInput[T](values: List[Term[T]]) extends InsertInput[T]
case class SelectedInput[T](select: Select[T]) extends InsertInput[T]

case class Insert[T](table: Table, colNames: Option[List[String]], insertInput: InsertInput[T]) extends Statement[T] {
  def output = Nil
  def tables = table :: Nil
}

case class SetStatement[T](left: Statement[T], op: String, right: Statement[T],
                           orderBy: Option[OrderBy[T]], limit: Option[Limit[T]]) extends Statement[T] {
  def tables = left.tables ::: right.tables
  override def isQuery = true
}

case class Update[T](tables: List[Table], set: List[(Column[T], Term[T])], where: Option[Where[T]],
                     orderBy: Option[OrderBy[T]], limit: Option[Limit[T]]) extends Statement[T]

case class Create[T]() extends Statement[T] {
  def tables = Nil
}

case class Composed[T](left: Statement[T], right: Statement[T]) extends Statement[T] {
  def tables = left.tables ::: right.tables
  override def isQuery = left.isQuery || right.isQuery
}

case class Select[T](projection: List[Named[T]],
                     tableReferences: List[TableReference[T]], // should be NonEmptyList
                     where: Option[Where[T]],
                     groupBy: Option[GroupBy[T]],
                     orderBy: Option[OrderBy[T]],
                     limit: Option[Limit[T]]) extends Statement[T] {
  def tables = tableReferences flatMap (_.tables)
  override def isQuery = true
}

sealed trait TableReference[T] {
  def tables: List[Table]
  def name: String
}
case class ConcreteTable[T](table: Table, join: List[Join[T]]) extends TableReference[T] {
  def tables = table :: join.flatMap(_.table.tables)
  def name = table.name
}
case class DerivedTable[T](name: String, subselect: Select[T], join: List[Join[T]]) extends TableReference[T] {
  def tables = Table(name, None, None) :: join.flatMap(_.table.tables)
}

case class Where[T](expr: Expr[T])

case class Join[T](table: TableReference[T], joinType: Option[JoinType[T]], joinDesc: JoinDesc)

sealed trait JoinDesc
case object Inner extends JoinDesc
case object LeftOuter extends JoinDesc
case object RightOuter extends JoinDesc
case object FullOuter extends JoinDesc
case object Cross extends JoinDesc

trait JoinType[T]
case class QualifiedJoin[T](expr: Expr[T]) extends JoinType[T]
case class NamedColumnsJoin[T](columns: List[String]) extends JoinType[T]

case class GroupBy[T](terms: List[Term[T]], withRollup: Boolean, having: Option[Having[T]])

case class Having[T](expr: Expr[T])

case class OrderBy[T](sort: List[Term[T]], orders: List[Option[Order]])

sealed trait Order
case object Asc extends Order
case object Desc extends Order

case class Limit[T](count: Either[Int, Input[T]], offset: Option[Either[Int, Input[T]]])
}

// --------------------------- /core/src/main/scala/csv.scala ---------------------------

package sqltyped

import shapeless._, ops.hlist._, tag.@@

/** Type class for rendering a value as a CSV cell. */
trait Show[A] {
  def show(a: A): String
}

object Show {
  implicit object ShowByte extends ToStringShow[Byte]
  implicit object ShowInt extends ToStringShow[Int]
  implicit object ShowShort extends ToStringShow[Short]
  implicit object ShowDouble extends ToStringShow[Double]
  implicit object ShowLong extends ToStringShow[Long]
  implicit object ShowFloat extends ToStringShow[Float]
  implicit object ShowString extends ToStringShow[String]
  implicit object ShowBool extends ToStringShow[Boolean]
  // NOTE(review): renders *every* Option (Some included) as an empty cell, and the
  // signature has no Show[A] evidence to do otherwise — confirm this is intentional.
  implicit def ShowOption[A]: Show[Option[A]] = new Show[Option[A]] {
    def show(a: Option[A]) = ""
  }
  implicit def ShowTagged[A: Show, T]: Show[A @@ T] = new Show[A @@ T] {
    def show(a: A @@ T) = implicitly[Show[A]].show(a: A)
  }

  class ToStringShow[A] extends Show[A] {
    def show(a: A) = a.toString
  }
}

object CSV {
  def fromList[R <: HList](rs: List[R], separator: String = ",")(implicit foldMap: MapFolder[R, List[String], toCSV.type]) =
    (rs map (r => fromRow(r, separator))).mkString("\n")

  def fromRow[R <: HList](r: R, separator: String = ",")(implicit foldMap: MapFolder[R, List[String], toCSV.type]) =
    (columnsFromRow(r) map escape).mkString(separator)

  def columnsFromRow[R <: HList](r: R)(implicit foldMap: MapFolder[R, List[String], toCSV.type]) =
    r.foldMap(Nil: List[String])(toCSV)(_ ::: _)

  // RFC 4180 style quoting: wrap in quotes, double any embedded quotes.
  private def escape(s: String) = "\"" + s.replaceAll("\"","\"\"") + "\""
}

object
toCSV extends Poly1 {
  implicit def valueToCsv[V: Show] = at[V](v => List(implicitly[Show[V]].show(v)))
}

// --------------------------- /core/src/main/scala/database.scala ---------------------------

package sqltyped

import schemacrawler.schemacrawler._
import schemacrawler.schema.Schema
import schemacrawler.utility.SchemaCrawlerUtility

case class DbConfig(url: String, driver: String, username: String, password: String, schema: Option[String]) {
  def getConnection = java.sql.DriverManager.getConnection(url, username, password)
}

object DbSchema {
  /**
   * Reads the database schema via SchemaCrawler. The schema name defaults to the
   * last path segment of the JDBC URL (query string stripped) unless configured.
   */
  def read(config: DbConfig): ?[Schema] = try {
    Class.forName(config.driver)
    val options = new SchemaCrawlerOptions
    val level = new SchemaInfoLevel
    level.setRetrieveTables(true)
    level.setRetrieveColumnDataTypes(true)
    level.setRetrieveTableColumns(true)
    level.setRetrieveIndices(true)
    level.setRetrieveForeignKeys(true)
    options.setSchemaInfoLevel(level)
    val schemaName = config.schema getOrElse config.url.split('?')(0).split('/').reverse.head
    options.setSchemaInclusionRule(new InclusionRule(schemaName, ""))
    val conn = config.getConnection
    try {
      // BUG FIX: the connection was previously never closed (resource leak).
      // Assumes SchemaCrawler materializes the schema eagerly in getDatabase — TODO confirm.
      val database = SchemaCrawlerUtility.getDatabase(conn, options)
      Option(database.getSchema(schemaName)) orFail
        s"Can't read schema '$schemaName'. Schema name can be configured with system property 'sqltyped.schema'."
    } finally conn.close
  } catch {
    case e: Exception => fail(e.getMessage)
  }
}

// --------------------------- /core/src/main/scala/dialects.scala ---------------------------

package sqltyped

import schemacrawler.schema.Schema
import scala.reflect.runtime.universe.{Type, typeOf, appliedType}
import Ast._
import java.sql.{Types => JdbcTypes}

/** A SQL dialect bundles a parser, a validator and a typer. */
trait Dialect {
  def parser: SqlParser
  def validator: Validator
  def typer(schema: Schema, stmt: Statement[Table], dbConfig: DbConfig): Typer
}

object Dialect {
  /** Picks a dialect from the JDBC driver class name; falls back to the generic dialect. */
  def choose(driver: String): Dialect = {
    if (driver.toLowerCase.contains("mysql")) MysqlDialect
    else if (driver.toLowerCase.contains("postgresql")) PostgresqlDialect
    else GenericDialect
  }
}

object GenericDialect extends Dialect {
  val parser = new SqlParser {}
  def validator = JdbcValidator
  def typer(schema: Schema, stmt: Statement[Table], dbConfig: DbConfig) = new Typer(schema, stmt, dbConfig)
}

object MysqlDialect extends Dialect {
  def validator = MySQLValidator

  def typer(schema: Schema, stmt: Statement[Table], dbConfig: DbConfig) = new Typer(schema, stmt, dbConfig) {
    import dsl._

    // MySQL-specific scalar functions and their typing rules.
    override def extraScalarFunctions = Map(
      "datediff" -> datediff _
    , "ifnull" -> ifnull _
    , "coalesce" -> ifnull _
    , "if" -> iff _
    , "binary" -> binary _
    , "convert" -> convert _
    , "concat" -> concat _
    )

    override def typeSpecifyTerm(v: Variable) = v.comparisonTerm flatMap {
      // Special case for: WHERE col IN (?)
case Comparison2(t1, In | NotIn, TermList(Input() :: Nil)) => Some(for {
        tpe <- typeTerm(false)(Variable(Named("", None, t1)))
      } yield List(TypedValue(v.term.aname, (appliedType(typeOf[Seq[_]].typeConstructor, tpe.head.tpe._1 :: Nil), JdbcTypes.JAVA_OBJECT), isNullable(t1), None, v.term.term)))
      case _ => None
    }

    def datediff(fname: String, params: List[Expr], ct: Option[Term]): ?[SqlFType] =
      if (params.length != 2) fail("Expected 2 parameters " + params)
      else for {
        (_tpe0, opt0) <- tpeOf(params(0), ct)
        (_tpe1, opt1) <- tpeOf(params(1), ct)
        // Prefer whichever argument is a date type for the parameter types; result is Int.
        tpe0 = if (_tpe0._1 <:< typeOf[java.util.Date]) _tpe0 else _tpe1
        tpe1 = if (_tpe1._1 <:< typeOf[java.util.Date]) _tpe1 else _tpe0
      } yield (List((tpe0, opt0), (tpe1, opt1)), ((typeOf[Int], JdbcTypes.INTEGER), true))

    def ifnull(fname: String, params: List[Expr], ct: Option[Term]): ?[SqlFType] =
      if (params.length != 2) fail("Expected 2 parameters " + params)
      else for {
        (tpe0, opt0) <- tpeOf(params(0), ct)
        (tpe1, opt1) <- tpeOf(params(1), ct)
      } yield (List((tpe0, opt0), (tpe1, opt1)), (tpe0, opt1))

    def iff(fname: String, params: List[Expr], ct: Option[Term]): ?[SqlFType] =
      if (params.length != 3) fail("Expected 3 parameters " + params)
      else for {
        (tpe0, opt0) <- tpeOf(params(0), ct)
        (tpe1, opt1) <- tpeOf(params(1), ct)
        (tpe2, opt2) <- tpeOf(params(2), ct)
      } yield (List((tpe0, opt0), (tpe1, opt1), (tpe2, opt2)), (tpe1, opt1 || opt2))

    def binary(fname: String, params: List[Expr], ct: Option[Term]): ?[SqlFType] =
      if (params.length != 1) fail("Expected 1 parameter " + params)
      else for {
        (tpe0, opt0) <- tpeOf(params(0), ct)
      } yield (List((tpe0, opt0)), (tpe0, opt0))

    def convert(fname: String, params: List[Expr], ct: Option[Term]): ?[SqlFType] =
      if (params.length != 2) fail("Expected 2 parameters " + params)
      else for {
        (tpe0, opt0) <- tpeOf(params(0), ct)
        (tpe1, opt1) <- tpeOf(params(1), ct)
        (tpe, opt) <- castToType(tpe0._1, params(1))
      } yield (List((tpe0, opt0), (tpe1, opt1)), (tpe, opt0 || opt))

    def concat(fname: String, params: List[Expr], ct: Option[Term]): ?[SqlFType] =
      if (params.length < 1) fail("Expected at least 1 parameter")
      else for {
        in <- sequence(params map (p => tpeOf(p, ct)))
      } yield (in, ((typeOf[String], JdbcTypes.VARCHAR), in.map(_._2).forall(identity)))

    // Maps a CAST/CONVERT target data type to (Scala type, JDBC type, nullability).
    private def castToType(orig: Type, target: Expr) = target match {
      case TypeExpr(d) => d.name match {
        case "date" => ((typeOf[java.sql.Date], JdbcTypes.DATE), true).ok
        case "datetime" => ((typeOf[java.sql.Timestamp], JdbcTypes.TIMESTAMP), true).ok
        case "time" => ((typeOf[java.sql.Time], JdbcTypes.TIME), true).ok
        case "char" => ((typeOf[String], JdbcTypes.CHAR), false).ok
        case "binary" => ((typeOf[String], JdbcTypes.BINARY), false).ok
        case "decimal" => ((typeOf[Double], JdbcTypes.DECIMAL), false).ok
        case "signed" if orig == typeOf[Long] => ((typeOf[Long], JdbcTypes.BIGINT), false).ok
        case "signed" => ((typeOf[Int], JdbcTypes.INTEGER), false).ok
        case "unsigned" if orig == typeOf[Long] => ((typeOf[Long], JdbcTypes.BIGINT), false).ok
        case "unsigned" => ((typeOf[Int], JdbcTypes.INTEGER), false).ok
        case x => fail(s"Unsupported type '$target' in cast operation")
      }
      case e => fail(s"Expected a data type, got '$e'")
    }
  }

  val parser = MysqlParser

  object MysqlParser extends SqlParser {
    import scala.reflect.runtime.universe.typeOf

    // MySQL allows INSERT IGNORE / UPDATE IGNORE.
    override def insert = "insert".i <~ opt("ignore".i)
    override def update = "update".i <~ opt("ignore".i)

    // INSERT ... ON DUPLICATE KEY UPDATE is modeled as an Insert composed with an Update.
    override lazy val insertStmt = insertSyntax ~ opt(onDuplicateKey) ^^ {
      case t ~ cols ~ vals ~ None => Insert(t, cols, vals)
      case t ~ cols ~ vals ~ Some(as) =>
        Composed(Insert(t, cols, vals),
                 Update(t :: Nil, as, None, None, None))
    }

    lazy val onDuplicateKey =
      "on".i ~> "duplicate".i ~> "key".i ~> "update".i ~> repsep(assignment, ",")

    override def quoteChar = ("\"" | "`")

    override def extraTerms = MysqlParser.interval

    override def dataTypes = List(
      precision1("binary")
    , precision1("char")
    , precision1("varchar")
    , precision0("tinytext")
    , precision0("text")
    , precision0("blob")
    , precision0("mediumtext")
    , precision0("mediumblob")
    , precision0("longtext")
    , precision0("longblob")
    , precision1("tinyint")
    , precision1("smallint")
    , precision1("mediumint")
    , precision1("int")
    , precision1("bigint")
    , precision0("float")
    , precision2("double")
    , precision2("decimal")
    , precision0("date")
    , precision0("datetime")
    , precision0("timestamp")
    , precision0("time")
    , precision0("signed")
    , precision0("unsigned")
    )

    // precisionN = data type with up to N precision arguments, e.g. decimal(10, 2).
    def precision0(name: String) = name.i ^^ (n => DataType(n))
    def precision1(name: String) = (
        name.i ~ "(" ~ integer ~ ")" ^^ { case n ~ _ ~ l ~ _ => DataType(n, List(l)) }
      | precision0(name)
    )
    def precision2(name: String) = (
        name.i ~ "(" ~ integer ~ "," ~ integer ~ ")" ^^ { case n ~ _ ~ l1 ~ _ ~ l2 ~ _ => DataType(n, List(l1, l2)) }
      | precision1(name)
      | precision0(name)
    )

    lazy val intervalAmount = opt("'") ~> numericLit <~ opt("'")
    lazy val interval = "interval".i ~> intervalAmount ~ timeUnit ^^ { case x ~ _ => const((typeOf[java.util.Date], JdbcTypes.TIMESTAMP), x) }

    lazy val timeUnit = (
        "microsecond".i
      | "second".i
      | "minute".i
      | "hour".i
      | "day".i
      | "week".i
      | "month".i
      | "quarter".i
      | "year".i
    )
  }
}

object PostgresqlDialect extends Dialect {
  val parser = new
SqlParser {}
  def validator = JdbcValidator

  def typer(schema: Schema, stmt: Statement[Table], dbConfig: DbConfig) = new Typer(schema, stmt, dbConfig) {
    import dsl._

    // PostgreSQL array predicates ANY/SOME/ALL take a single array-valued argument.
    override def extraScalarFunctions = Map(
      "any" -> arrayExpr _
    , "some" -> arrayExpr _
    , "all" -> arrayExpr _
    )

    def arrayExpr(fname: String, params: List[Expr], ct: Option[Term]): ?[SqlFType] =
      if (params.length != 1) fail("Expected 1 parameter " + params)
      else for {
        (tpe0, opt0) <- tpeOf(params(0), ct)
      } yield (List(((appliedType(typeOf[Seq[_]].typeConstructor, tpe0._1 :: Nil), JdbcTypes.ARRAY), false)), (tpe0, opt0))
  }
}

// --------------------------- /core/src/main/scala/jdbc.scala ---------------------------

package sqltyped

import java.sql._
import scala.reflect.runtime.universe.{Type, typeOf}

private[sqltyped] object Jdbc {
  /** Infers statement metadata purely from JDBC driver metadata (the fallback path). */
  def infer(db: DbConfig, sql: String): ?[TypedStatement] =
    withConnection(db.getConnection) { conn =>
      val stmt = conn.prepareStatement(sql)
      for {
        out <- (Option(stmt.getMetaData) map inferOutput getOrElse Nil).ok
        in <- Option(stmt.getParameterMetaData) map inferInput orFail "Input metadata not available"
        isQuery = !out.isEmpty
      } yield TypedStatement(in, out, isQuery, Map(), Nil, if (isQuery) NumOfResults.Many else NumOfResults.One)
    } flatMap identity

  def inferInput(meta: ParameterMetaData) =
    (1 to meta.getParameterCount).toList map { i =>
      try {
        TypedValue("a" + i, (mkType(meta.getParameterClassName(i)), meta.getParameterType(i)),
                   meta.isNullable(i) == ParameterMetaData.parameterNullable, None, unknownTerm)
      } catch {
        // Some drivers cannot describe a parameter; fall back to Any / JAVA_OBJECT.
        case e: SQLException => TypedValue("a" + i, (typeOf[Any], Types.JAVA_OBJECT), false, None, unknownTerm)
      }
    }

  def inferOutput(meta: ResultSetMetaData) =
    (1 to meta.getColumnCount).toList map { i =>
      TypedValue(meta.getColumnLabel(i), (mkType(meta.getColumnClassName(i)), meta.getColumnType(i)),
                 meta.isNullable(i) != ResultSetMetaData.columnNoNulls, None, unknownTerm)
    }

  def unknownTerm = Ast.Column("unknown", Ast.Table("unknown", None, None))

  def withConnection[A](conn: Connection)(a: Connection => A): ?[A] = try {
    a(conn).ok
  } catch {
    case e: SQLException => fail(e.getMessage)
  } finally {
    conn.close
  }

  // FIXME move to TypeMappings
  def mkType(className: String): Type = className match {
    case "java.lang.String" => typeOf[String]
    case "java.lang.Short" => typeOf[Short]
    case "java.lang.Integer" => typeOf[Int]
    case "java.lang.Long" => typeOf[Long]
    case "java.lang.Float" => typeOf[Float]
    case "java.lang.Double" => typeOf[Double]
    case "java.lang.Boolean" => typeOf[Boolean]
    case "java.lang.Byte" => typeOf[Byte]
    case "java.sql.Timestamp" => typeOf[java.sql.Timestamp]
    case "java.sql.Date" => typeOf[java.sql.Date]
    case "java.sql.Time" => typeOf[java.sql.Time]
    case "byte[]" => typeOf[java.sql.Blob]
    case "[B" => typeOf[java.sql.Blob]
    case "byte" => typeOf[Byte]
    case "java.math.BigDecimal" => typeOf[scala.math.BigDecimal]
    case x => sys.error("Unknown type " + x)
  }
}

private [sqltyped] object TypeMappings {
  import java.sql.Types._

  /* a mapping from java.sql.Types.* values to their getFoo/setFoo names */
  final val setterGetterNames = Map(
    ARRAY -> "Array"
  , BIGINT -> "Long"
  , BINARY -> "Bytes"
  , BIT -> "Boolean"
  , BLOB -> "Blob"
  , BOOLEAN -> "Boolean"
  , CHAR -> "String"
  , CLOB -> "Clob"
  , DATALINK -> "URL"
  , DATE -> "Date"
  , DECIMAL -> "BigDecimal"
  , DOUBLE -> "Double"
  , FLOAT -> "Float"
  , INTEGER -> "Int"
  , JAVA_OBJECT -> "Object"
  , LONGNVARCHAR -> "String"
  , LONGVARBINARY -> "Blob" // FIXME should be Bytes?
  , LONGVARCHAR -> "String"
  , NCHAR -> "String"
  , NCLOB -> "NClob"
  , NUMERIC -> "BigDecimal"
  , NVARCHAR -> "String"
  , REAL -> "Float"
  , REF -> "Ref"
  , ROWID -> "RowId"
  , SMALLINT -> "Short"
  , SQLXML -> "SQLXML"
  , TIME -> "Time"
  , TIMESTAMP -> "Timestamp"
  , TINYINT -> "Byte"
  , VARBINARY -> "Bytes"
  , VARCHAR -> "String"
  )

  // FIXME this is dialect specific, move there. Or perhaps schemacrawler provides these?
  def arrayTypeName(tpe: Type) =
    if (tpe =:= typeOf[String]) "varchar"
    else if (tpe =:= typeOf[Int]) "integer"
    else if (tpe =:= typeOf[Long]) "bigint"
    else sys.error("Unsupported array type " + tpe)
}

// --------------------------- /core/src/main/scala/macro.scala ---------------------------

package sqltyped

import java.sql._
import scala.util.Properties
import schemacrawler.schema.Schema
import NumOfResults._

trait ConfigurationName

object EnableTagging

object SqlMacro {
  import shapeless._
  import scala.reflect.macros._

  // Schema is read once per compiler run; keyed weakly so runs can be collected.
  private val schemaCache = new java.util.WeakHashMap[Context#Run, ?[Schema]]()

  def withResultSet[A](stmt: PreparedStatement)(f: ResultSet => A) = {
    var rs: ResultSet = null
    try {
      // BUG FIX: `rs` was never assigned (the original passed stmt.executeQuery straight
      // to `f`), so the `rs != null` cleanup below was dead code and the ResultSet was
      // only ever closed indirectly via stmt.close.
      rs = stmt.executeQuery
      f(rs)
    } finally {
      if (rs != null) try rs.close catch { case e: Exception => }
      stmt.close
    }
  }

  def withStatement(stmt: PreparedStatement) =
    try {
      stmt.executeUpdate
    } finally {
      stmt.close
    }

  def withStatementF[A](stmt: PreparedStatement)(f: => A): A =
    try {
      stmt.executeUpdate
      f
    } finally {
      stmt.close
    }

  def sqlImpl
      (c: Context)
      (s:
c.Expr[String]): c.Expr[Any] = {

    import c.universe._

    // The macro argument must be a compile-time String literal: the SQL is parsed,
    // validated and typed during compilation.
    val sql = s.tree match {
      case Literal(Constant(sql: String)) => sql
      case _ => c.abort(c.enclosingPosition, "Argument to macro must be a String literal")
    }
    compile(c, inputsInferred = true, validate = true,
            analyze = true,
            sql, (p, s) => p.parseAllWith(p.stmt, s))(Literal(Constant(sql)), Nil)
  }

  def dynsqlImpl
      (c: Context)(exprs: c.Expr[Any]*): c.Expr[Any] = {

    import c.universe._

    def append(t1: Tree, t2: Tree) = Apply(Select(t1, newTermName("+").encoded), List(t2))

    val Expr(Apply(_, List(Apply(_, parts)))) = c.prefix

    // Rebuild the runtime SQL string by concatenating the literal parts with the
    // interpolated expressions; only the first (literal) part is typechecked.
    val select = parts.head
    val sqlExpr = exprs.zip(parts.tail).foldLeft(select) {
      case (acc, (Expr(expr), part)) => append(acc, append(expr, part))
    }

    val sql = select match {
      case Literal(Constant(sql: String)) => sql
      case _ => c.abort(c.enclosingPosition, "Expected String literal as first part of interpolation")
    }

    compile(c, inputsInferred = false,
            validate = false, analyze = false,
            sql, (p, s) => p.parseWith(p.selectStmt, s))(sqlExpr, Nil)
  }

  def paramDynsqlImpl
      (c: Context)(exprs: c.Expr[Any]*): c.Expr[Any] = {

    import c.universe._

    val Expr(Apply(_, List(Apply(_, parts)))) = c.prefix

    // Each interpolation hole becomes a '?' placeholder; the expressions are passed as args.
    val sql = parts.map { case Literal(Constant(sql: String)) => sql } mkString "?"
    compile(c, inputsInferred = true,
            validate = true, analyze = true,
            sql, (p, s) => p.parseAllWith(p.stmt, s))(Literal(Constant(sql)), exprs.map(_.tree).toList)
  }

  /**
   * Compile-time pipeline: validate -> parse -> resolve tables -> type -> analyze,
   * falling back to plain JDBC metadata inference if any phase fails.
   */
  def compile
      (c: Context, inputsInferred: Boolean, validate: Boolean, analyze: Boolean,
       sql: String, parse: (SqlParser, String) => ?[Ast.Statement[Option[String]]])
      (sqlExpr: c.Tree, args: List[c.Tree]): c.Expr[Any] = {

    import c.universe._

    val annotations = c.macroApplication.symbol.annotations
    val jdbcOnly = annotations.exists(
      _.tree.tpe <:< typeOf[jdbcOnly]
    )

    val useInputTags = annotations.exists(
      _.tree.tpe <:< typeOf[useInputTags]
    )

    val returnKeys = annotations.exists(
      _.tree.tpe <:< typeOf[returnKeys]
    )

    val useSymbolKeyRecords = c.macroApplication.symbol.annotations.exists(
      _.tree.tpe <:< typeOf[useSymbolKeyRecords]
    )

    def sysProp(n: String) = Properties.propOrNone(n) orFail
      "System property '" + n + "' is required to get a compile time connection to the database"

    def cachedSchema(config: DbConfig) = {
      val cached = schemaCache.get(c.enclosingRun)
      if (cached != null) cached else {
        val s = DbSchema.read(config)
        schemaCache.put(c.enclosingRun, s)
        s
      }
    }

    // Map a Failure's (line, column) inside the SQL literal to a compiler position.
    def toPosition(f: Failure[_]) = {
      val lineOffset = sql.split("\n").take(f.line - 1).map(_.length).sum
      c.enclosingPosition.withPoint(wrappingPos(List(c.prefix.tree)).startOrPoint + f.column + lineOffset)
    }

    // A rather kludgey way to pass config name to macro. The concrete name of the type of implicit
    // value is used as a config name. E.g. implicit object postgresql extends ConfigurationName
    val configName = c.inferImplicitValue(typeOf[ConfigurationName], silent = true) match {
      case EmptyTree => None
      case tree =>
        val tpeName = tree.tpe.toString
        val s = tpeName.substring(0, tpeName.lastIndexOf(".type"))
        Some(s.substring(s.lastIndexOf(".") + 1))
    }

    def propName(suffix: String) = "sqltyped." + configName.map(_ + ".").getOrElse("") + suffix

    def dbConfig = for {
      url <- sysProp(propName("url"))
      driver <- sysProp(propName("driver"))
      username <- sysProp(propName("username"))
      password <- sysProp(propName("password"))
    } yield DbConfig(url, driver, username, password, Properties.propOrNone(propName("schema")))

    def generateCode(meta: TypedStatement) =
      codeGen(meta, sql, c, returnKeys, inputsInferred, useSymbolKeyRecords)(sqlExpr, args)

    // Fallback path: infer everything from JDBC driver metadata only.
    def fallback = for {
      db <- dbConfig
      _ = Class.forName(db.driver)
      meta <- Jdbc.infer(db, sql)
    } yield meta

    val timer = Timer(Properties.propIsSet("sqltyped.enable-timer"))

    timer("SQL: " + sql.replace("\n", " ").trim, 0, (if (jdbcOnly) fallback else {
      for {
        db <- dbConfig
        _ = Class.forName(db.driver)
        dialect = Dialect.choose(db.driver)
        parser = dialect.parser
        schema <- cachedSchema(db)
        validator = if (validate) dialect.validator else NOPValidator
        _ <- timer("validating", 2, validator.validate(db, sql))
        stmt <- timer("parsing", 2, parse(parser, sql))
        resolved <- timer("resolving tables", 2, Ast.resolveTables(stmt))
        typer = dialect.typer(schema, resolved, db)
        typed <- timer("typing", 2, typer.infer(useInputTags))
        meta <- timer("analyzing", 2, if (analyze) new Analyzer(typer).refine(resolved, typed) else typed.ok)
      } yield meta }) fold (
        fail => fallback fold (
          fail2 => c.abort(toPosition(fail2), fail2.message),
          meta => {
            c.warning(toPosition(fail), fail.message + "\nFallback to JDBC metadata. Please file a bug at https://github.com/jonifreeman/sqltyped/issues")
            timer("codegen", 2, generateCode(meta))
          }
        ),
        meta => timer("codegen", 2, generateCode(meta))
      ))
  }

  /** Generates the runtime query-execution AST for an already-typed statement. */
  def codeGen[A: c.WeakTypeTag]
      (meta: TypedStatement, sql: String, c: Context, keys: Boolean, inputsInferred: Boolean, useSymbolKeyRecords: Boolean)
      (sqlExpr: c.Tree, args: List[c.Tree] = Nil): c.Expr[Any] = {

    import c.universe._

    val enableTagging = c.inferImplicitValue(typeOf[EnableTagging.type], silent = true) match {
      case EmptyTree => false
      case _ => true
    }

    // Optional user-supplied column-name -> record-key naming strategy, loaded reflectively.
    val namingStrategy: String => String = Properties.propOrNone("sqltyped.naming_strategy") match {
      case None => identity _
      case Some(cl) =>
        val constructor = this.getClass.getClassLoader.loadClass(cl).getDeclaredConstructors()(0)
        constructor.setAccessible(true)
        constructor.newInstance().asInstanceOf[String => String]
    }

    // Read column `pos` from the ResultSet; nullable columns become Option via wasNull.
    def rs(x: TypedValue, pos: Int) =
      if (x.nullable) {
        Block(
          List(ValDef(Modifiers(), newTermName("x"), TypeTree(), getValue(x, pos))),
          If(Apply(Select(Ident(newTermName("rs")), newTermName("wasNull")), List()),
             Select(Ident(newTermName("scala")), newTermName("None")),
             Apply(Select(Select(Ident(newTermName("scala")), newTermName("Some")), newTermName("apply")), List(getTyped(x, Ident(newTermName("x")))))))
      } else getTyped(x, getValue(x, pos))

    def getValue(x: TypedValue, pos: Int) =
      Apply(Select(Ident(newTermName("rs")), newTermName(rsGetterName(x))), List(Literal(Constant(pos))))

    def getTyped(x: TypedValue, r: Tree) = {
      def baseValue = Typed(r, scalaBaseType(x))

      (if (enableTagging) x.tag else None) map(t => tagType(t)) map (tagged =>
        Apply(
          Select(
            TypeApply(
              Select(Select(Ident(newTermName("shapeless")), newTermName("tag")), newTermName("apply")),
              List(tagged)), newTermName("apply")), List(baseValue))
      ) getOrElse baseValue
    }

    def firstTypeParamOf(tpe: reflect.runtime.universe.Type): reflect.runtime.universe.Type =
      tpe.asInstanceOf[reflect.runtime.universe.TypeRefApi].args.headOption getOrElse reflect.runtime.universe.typeOf[Any]

    def scalaBaseType(x: TypedValue) =
      if (x.tpe._1 <:< reflect.runtime.universe.typeOf[Seq[_]]) {
        // We'd like to say here just TypeTree(x.tpe) but types don't match.
        // x.tpe is from runtime.universe and we need c.universe.
        // Making x.tpe to be c.universe.Type would be correct but rather complicated
        // as explicit passing of Context everywhere just pollutes the code.
        val typeParam = firstTypeParamOf(x.tpe._1)
        AppliedTypeTree(Ident(c.mirror.staticClass("scala.collection.Seq")),
                        List(Ident(c.mirror.staticClass(typeParam.typeSymbol.fullName))))
      } else Ident(c.mirror.staticClass(x.tpe._1.typeSymbol.fullName))

    def scalaType(x: TypedValue) = {
      (if (enableTagging) x.tag else None) map (t => tagType(t)) map (tagged =>
        AppliedTypeTree(
          Select(Select(Ident(newTermName("shapeless")), newTermName("tag")), newTypeName("$at$at")),
          List(scalaBaseType(x), tagged))
      ) getOrElse scalaBaseType(x)
    }

    def tagType(tag: String) = Select(Ident(newTermName(tag)), newTypeName("T"))

    def stmtSetterName(x: TypedValue) = "set" + TypeMappings.setterGetterNames(x.tpe._2)
    def rsGetterName(x: TypedValue) = "get" + TypeMappings.setterGetterNames(x.tpe._2)

    // Bind input parameter `pos`; nullable inputs set NULL when the Option is empty.
    def setParam(x: TypedValue, pos: Int) = {
      val param = Ident(newTermName("i" + pos))
      if (x.nullable)
        If(Select(param, newTermName("isDefined")),
           Apply(Select(Ident(newTermName("stmt")), newTermName(stmtSetterName(x))), List(Literal(Constant(pos+1)), coerce(x,
Select(param, newTermName("get"))))), 267 | Apply(Select(Ident(newTermName("stmt")), newTermName("setObject")), List(Literal(Constant(pos+1)), Literal(Constant(null))))) 268 | else 269 | Apply(Select(Ident(newTermName("stmt")), newTermName(stmtSetterName(x))), 270 | List(Literal(Constant(pos+1)), coerce(x, param))) 271 | } 272 | 273 | def coerce(x: TypedValue, i: Tree) = 274 | if (x.tpe._2 == java.sql.Types.ARRAY) 275 | createArray(TypeMappings.arrayTypeName(firstTypeParamOf(x.tpe._1)), i) 276 | else if (x.tpe._1 =:= reflect.runtime.universe.typeOf[BigDecimal]) 277 | Apply(Select(i, newTermName("underlying")), List()) 278 | else i 279 | 280 | def inputParam(x: TypedValue, pos: Int) = 281 | ValDef(Modifiers(Flag.PARAM), newTermName("i" + pos), possiblyOptional(x, scalaType(x)), EmptyTree) 282 | 283 | def inputTypeSig = 284 | if (inputsInferred) meta.input.map(col => possiblyOptional(col, scalaType(col))) 285 | else List(AppliedTypeTree(Ident(c.mirror.staticClass("scala.collection.Seq")), 286 | List(Ident(c.mirror.staticClass("scala.Any"))))) 287 | 288 | def possiblyOptional(x: TypedValue, tpe: Tree) = 289 | if (x.nullable) AppliedTypeTree(Ident(c.mirror.staticClass("scala.Option")), List(tpe)) 290 | else tpe 291 | 292 | def returnTypeSig = 293 | if (meta.output.length == 0) List(Ident(c.mirror.staticClass("scala.Int"))) 294 | else if (meta.output.length == 1) returnTypeSigScalar 295 | else returnTypeSigRecord 296 | 297 | def resultTypeSig = 298 | if (keys && (meta.numOfResults != Many)) scalaType(meta.generatedKeyTypes.head) 299 | else if (keys) AppliedTypeTree(Ident(newTypeName("List")), List(scalaType(meta.generatedKeyTypes.head))) 300 | else if (meta.output.length == 0 || meta.numOfResults == One) returnTypeSig.head 301 | else if (meta.numOfResults == ZeroOrOne) AppliedTypeTree(Ident(newTypeName("Option")), returnTypeSig) 302 | else AppliedTypeTree(Ident(newTypeName("List")), returnTypeSig) 303 | 304 | def appendRow = 305 | if (meta.output.length == 1) 
appendRowScalar 306 | else appendRowRecord 307 | 308 | def returnTypeSigRecord = List(meta.output.foldRight(Ident(c.mirror.staticClass("shapeless.HNil")): Tree) { (x, sig) => 309 | val keyType = if (useSymbolKeyRecords) { 310 | CompoundTypeTree(Template(List((Ident(c.mirror.staticClass("scala.Symbol"))), AppliedTypeTree(Select(Ident(c.mirror.staticModule("shapeless.tag")), TypeName("Tagged")), List(Select(Ident(newTermName(keyName(x))), newTypeName("T"))))), noSelfType, Nil)) 311 | } else { 312 | Select(Ident(newTermName(keyName(x))), newTypeName("T")) 313 | } 314 | AppliedTypeTree( 315 | Ident(c.mirror.staticClass("shapeless.$colon$colon")), 316 | List(AppliedTypeTree(Select(Ident(c.mirror.staticModule("shapeless.labelled")), newTypeName("FieldType")), List(keyType, possiblyOptional(x, scalaType(x)))), sig) 317 | ) 318 | }) 319 | 320 | def keyName(x: TypedValue) = namingStrategy(x.name) 321 | 322 | def returnTypeSigScalar = List(possiblyOptional(meta.output.head, scalaType(meta.output.head))) 323 | 324 | def appendRowRecord = { 325 | def processRow(x: TypedValue, i: Int): Tree = { 326 | val key = if (useSymbolKeyRecords) { 327 | Apply(Select(Select(Ident(newTermName("scala")), newTermName("Symbol")), newTermName("apply")), List(Literal(Constant(keyName(x))))) 328 | } else { 329 | Literal(Constant(keyName(x))) 330 | } 331 | 332 | ValDef(Modifiers(/*Flag.SYNTHETIC*/), 333 | newTermName("x$" + (i+1)), 334 | TypeTree(), 335 | Apply(Select(Apply( 336 | Select(Ident(c.mirror.staticModule("shapeless.syntax.singleton")), newTermName("mkSingletonOps")), 337 | List(key)), newTermName("->>").encoded), List(rs(x, meta.output.length - i)))) 338 | } 339 | 340 | val init: Tree = 341 | Block(List( 342 | processRow(meta.output.last, 0)), 343 | Apply( 344 | Select(Select(Ident(newTermName("shapeless")), newTermName("HNil")), newTermName("::").encoded), 345 | List(Ident(newTermName("x$1"))) 346 | )) 347 | 348 | List(meta.output.reverse.drop(1).zipWithIndex.foldLeft(init) { case 
(ast, (x, i)) => 349 | Block( 350 | processRow(x, i+1), 351 | Apply( 352 | Select( 353 | Apply( 354 | Select(Ident(c.mirror.staticModule("shapeless.HList")), newTermName("hlistOps")), 355 | List(Block(ast)) 356 | ), 357 | newTermName("::").encoded), List(Ident(newTermName("x$" + (i+2)))))) 358 | }) 359 | } 360 | 361 | def appendRowScalar = List(rs(meta.output.head, 1)) 362 | 363 | def readRows = 364 | List( 365 | ValDef( 366 | Modifiers(), newTermName("rows"), TypeTree(), 367 | Apply(TypeApply(Select(Select(Select(Select(Ident(newTermName("scala")), newTermName("collection")), newTermName("mutable")), newTermName("ListBuffer")), newTermName("apply")), returnTypeSig), List())), 368 | LabelDef(newTermName("while$1"), List(), 369 | If(Apply(Select(Ident(newTermName("rs")), newTermName("next")), List()), 370 | Block(List(Apply(Select(Ident(newTermName("rows")), newTermName("append")), appendRow)), 371 | Apply(Ident(newTermName("while$1")), List())), Literal(Constant(()))))) 372 | 373 | def returnRows = 374 | if (meta.numOfResults == Many) 375 | Select(Ident(newTermName("rows")), newTermName("toList")) 376 | else if (meta.numOfResults == ZeroOrOne) 377 | Select(Select(Ident(newTermName("rows")), newTermName("toList")), newTermName("headOption")) 378 | else 379 | Select(Select(Ident(newTermName("rows")), newTermName("toList")), newTermName("head")) 380 | 381 | def processStmt = 382 | if (meta.isQuery) { 383 | Apply( 384 | Apply(Select(Select(Ident(newTermName("sqltyped")), newTermName("SqlMacro")), newTermName("withResultSet")), List(Ident(newTermName("stmt")))), 385 | List(Function(List(ValDef(Modifiers(Flag.PARAM), newTermName("rs"), TypeTree(), EmptyTree)), 386 | Block(readRows, returnRows)))) 387 | } else if (keys && meta.numOfResults == Many) { 388 | processStmtWithKeys(meta.generatedKeyTypes.head) 389 | } else if (keys) { 390 | Select(processStmtWithKeys(meta.generatedKeyTypes.head), newTermName("head")) 391 | } else { 392 | 
Apply(Select(Select(Ident(newTermName("sqltyped")), newTermName("SqlMacro")), newTermName("withStatement")), List(Ident(newTermName("stmt")))) 393 | } 394 | 395 | def processStmtWithKeys(keyType: TypedValue) = 396 | Apply( 397 | Apply(Select(Select(Ident(newTermName("sqltyped")), newTermName("SqlMacro")), newTermName("withStatementF")), List(Ident(newTermName("stmt")))), 398 | List( 399 | Block( 400 | List( 401 | ValDef(Modifiers(), newTermName("rs"), TypeTree(), Apply(Select(Ident(newTermName("stmt")), newTermName("getGeneratedKeys")), List())), 402 | ValDef(Modifiers(), newTermName("keys"), TypeTree(), Apply( 403 | TypeApply(Select(Select(Select(Select(Ident(newTermName("scala")), newTermName("collection")), newTermName("mutable")), newTermName("ListBuffer")), newTermName("apply")), List(scalaType(keyType))), List())), 404 | LabelDef(newTermName("while$1"), List(), 405 | If(Apply(Select(Ident(newTermName("rs")), newTermName("next")), List()), 406 | Block( 407 | List( 408 | Apply(Select(Ident(newTermName("keys")), newTermName("append")), List(getTyped(keyType, getValue(keyType, 1))))), 409 | Apply(Ident(newTermName("while$1")), List())), Literal(Constant(())))), 410 | Apply(Select(Ident(newTermName("rs")), newTermName("close")), List())), 411 | Select(Ident(newTermName("keys")), newTermName("toList"))))) 412 | 413 | def createArray(elemTypeName: String, seq: Tree) = { 414 | val castedSeq = 415 | TypeApply( 416 | Select(seq, newTermName("asInstanceOf")), 417 | List(AppliedTypeTree(Ident(c.mirror.staticClass("scala.collection.Seq")), 418 | List(Ident(c.mirror.staticClass("java.lang.Object")))))) 419 | 420 | Apply( 421 | Select(Ident(newTermName("conn")), newTermName("createArrayOf")), 422 | List( 423 | Literal(Constant(elemTypeName)), 424 | Apply(Select(castedSeq, newTermName("toArray")), List(Select(Ident(c.mirror.staticModule("scala.Predef")), newTermName("implicitly")))))) 425 | } 426 | 427 | def prepareStatement = 428 | Apply( 429 | 
Select(Ident(newTermName("conn")), newTermName("prepareStatement")), 430 | sqlExpr :: (if (keys) List(Literal(Constant(Statement.RETURN_GENERATED_KEYS))) else Nil)) 431 | 432 | def queryF = { 433 | def argList = 434 | if (inputsInferred) meta.input.zipWithIndex.map { case (c, i) => inputParam(c, i) } 435 | else List(ValDef(Modifiers(Flag.PARAM), newTermName("args$"), 436 | AppliedTypeTree(Ident(c.mirror.staticClass("scala.collection.Seq")), 437 | List(Ident(c.mirror.staticClass("scala.Any")))), 438 | EmptyTree)) 439 | 440 | def processArgs = 441 | if (inputsInferred) meta.input.zipWithIndex.map { case (c, i) => setParam(c, i) } 442 | else 443 | List(Block( 444 | List(ValDef(Modifiers(Flag.MUTABLE), newTermName("count$"), TypeTree(), Literal(Constant(0)))), 445 | LabelDef( 446 | newTermName("while$1"), 447 | List(), 448 | If(Apply(Select(Ident(newTermName("count$")), newTermName("$less")), 449 | List(Select(Ident(newTermName("args$")), newTermName("length")))), 450 | Block( 451 | List( 452 | Block( 453 | List( 454 | Apply(Select(Ident(newTermName("stmt")), newTermName("setObject")), 455 | List(Apply(Select(Ident(newTermName("count$")), newTermName("$plus")), List(Literal(Constant(1)))), 456 | Apply(Select(Ident(newTermName("args$")), newTermName("apply")), List(Ident(newTermName("count$"))))))), 457 | Assign(Ident(newTermName("count$")), 458 | Apply( 459 | Select(Ident(newTermName("count$")), newTermName("$plus")), 460 | List(Literal(Constant(1))))))), 461 | Apply(Ident(newTermName("while$1")), List())), Literal(Constant(())))))) 462 | 463 | 464 | DefDef( 465 | Modifiers(), newTermName("apply"), List(), 466 | List( 467 | argList, 468 | List(ValDef(Modifiers(Flag.IMPLICIT | Flag.PARAM), newTermName("conn"), Ident(c.mirror.staticClass("java.sql.Connection")), EmptyTree))), 469 | TypeTree(), 470 | Block( 471 | ValDef(Modifiers(), newTermName("stmt"), TypeTree(), prepareStatement) :: processArgs, 472 | processStmt 473 | ) 474 | ) 475 | } 476 | 477 | /* 478 | sql("select 
name, age from person where age > ?") 479 | 480 | Generates following code: 481 | 482 | identity { 483 | val name = Witness("name") 484 | val age = Witness("age") 485 | 486 | new Query1[Int, FieldType[name.T, String] :: FieldType[age.T, Int] :: HNil] { 487 | def apply(i1: Int)(implicit conn: Connection) = { 488 | val stmt = conn.prepareStatement("select name, age from person where age > ?") 489 | stmt.setInt(1, i1) 490 | withResultSet(stmt) { rs => 491 | val rows = collection.mutable.ListBuffer[FieldType[name.T, String] :: FieldType[age.T, Int] :: HNil]() 492 | while (rs.next) { 493 | rows.append("name" ->> rs.getString(1) :: "age" ->> rs.getInt(2) :: HNil) 494 | } 495 | rows.toList 496 | } 497 | } 498 | } 499 | } 500 | 501 | */ 502 | 503 | val inputLen = if (inputsInferred) meta.input.length else 1 504 | 505 | def witnesses = ( 506 | (meta.output map (keyName)) ::: 507 | (meta.input flatMap (_.tag)) ::: 508 | (meta.output flatMap (_.tag)) ::: 509 | (if (keys) { meta.generatedKeyTypes flatMap (_.tag) } else Nil) 510 | ).distinct 511 | 512 | def mkWitness(name: String) = 513 | ValDef( 514 | Modifiers(), 515 | newTermName(name), 516 | TypeTree(), 517 | Apply(Select(Ident(c.mirror.staticModule("shapeless.Witness")), newTermName("apply")), 518 | List(Literal(Constant(name))))) 519 | 520 | def genQueryClass(inputLen: Int, methodSig: List[Tree], impl: Tree) = { 521 | ClassDef(Modifiers(Flag.FINAL), newTypeName("$anon"), List(), 522 | Template(List( 523 | AppliedTypeTree( 524 | Ident(c.mirror.staticClass("sqltyped.Query" + inputLen)), methodSig)), 525 | emptyValDef, List( 526 | ValDef(Modifiers(), TermName("sql"), TypeTree(), Literal(Constant(sql))), 527 | DefDef( 528 | Modifiers(), 529 | nme.CONSTRUCTOR, 530 | List(), 531 | List(List()), 532 | TypeTree(), 533 | Block( 534 | List( 535 | Apply( 536 | Select(Super(This(""), ""), nme.CONSTRUCTOR), Nil)), 537 | Literal(Constant(())))), 538 | TypeDef(Modifiers(), newTypeName("ReturnType"), List(), returnTypeSig.head), impl))) 
539 | } 540 | 541 | def mkQuery = 542 | Block( 543 | List(genQueryClass(inputLen, inputTypeSig ::: List(resultTypeSig), queryF)), 544 | Apply(Select(New(Ident(newTypeName("$anon"))), nme.CONSTRUCTOR), List()) 545 | ) 546 | 547 | c.Expr { 548 | Apply( 549 | Select(Ident(c.mirror.staticModule("scala.Predef")), newTermName("identity")), 550 | List( 551 | Block( 552 | witnesses map (i => mkWitness(i)), 553 | if (args.nonEmpty) { 554 | val impl = DefDef( 555 | Modifiers(), newTermName("apply"), List(), 556 | List(Nil, List(ValDef(Modifiers(Flag.IMPLICIT | Flag.PARAM), newTermName("conn"), Ident(c.mirror.staticClass("java.sql.Connection")), EmptyTree))), 557 | TypeTree(), 558 | Block(Nil, Apply(Select(mkQuery, newTermName("apply")), args))) 559 | Block( 560 | List( 561 | genQueryClass(0, List(resultTypeSig), impl)), 562 | Apply(Select(New(Ident(newTypeName("$anon"))), nme.CONSTRUCTOR), List()) 563 | ) 564 | } else mkQuery 565 | ) 566 | ) 567 | ) 568 | } 569 | } 570 | } 571 | 572 | // FIXME Replace all these with 'trait SqlF[R]' once Scala macros can create public members 573 | // (apply must be public) 574 | trait Query0[R] { 575 | type ReturnType 576 | val sql: String 577 | def apply()(implicit conn: Connection): R 578 | } 579 | trait Query1[I1, R] { 580 | type ReturnType 581 | val sql: String 582 | def apply(i1: I1)(implicit conn: Connection): R 583 | } 584 | trait Query2[I1, I2, R] { 585 | type ReturnType 586 | val sql: String 587 | def apply(i1: I1, i2: I2)(implicit conn: Connection): R 588 | } 589 | trait Query3[I1, I2, I3, R] { 590 | type ReturnType 591 | val sql: String 592 | def apply(i1: I1, i2: I2, i3: I3)(implicit conn: Connection): R 593 | } 594 | trait Query4[I1, I2, I3, I4, R] { 595 | type ReturnType 596 | val sql: String 597 | def apply(i1: I1, i2: I2, i3: I3, i4: I4)(implicit conn: Connection): R 598 | } 599 | trait Query5[I1, I2, I3, I4, I5, R] { 600 | type ReturnType 601 | val sql: String 602 | def apply(i1: I1, i2: I2, i3: I3, i4: I4, i5: I5)(implicit 
conn: Connection): R 603 | } 604 | trait Query6[I1, I2, I3, I4, I5, I6, R] { 605 | type ReturnType 606 | val sql: String 607 | def apply(i1: I1, i2: I2, i3: I3, i4: I4, i5: I5, i6: I6)(implicit conn: Connection): R 608 | } 609 | trait Query7[I1, I2, I3, I4, I5, I6, I7, R] { 610 | type ReturnType 611 | val sql: String 612 | def apply(i1: I1, i2: I2, i3: I3, i4: I4, i5: I5, i6: I6, i7: I7)(implicit conn: Connection): R 613 | } 614 | trait Query8[I1, I2, I3, I4, I5, I6, I7, I8, R] { 615 | type ReturnType 616 | val sql: String 617 | def apply(i1: I1, i2: I2, i3: I3, i4: I4, i5: I5, i6: I6, i7: I7, i8: I8)(implicit conn: Connection): R 618 | } 619 | trait Query9[I1, I2, I3, I4, I5, I6, I7, I8, I9, R] { 620 | type ReturnType 621 | val sql: String 622 | def apply(i1: I1, i2: I2, i3: I3, i4: I4, i5: I5, i6: I6, i7: I7, i8: I8, i9: I9)(implicit conn: Connection): R 623 | } 624 | trait Query10[I1, I2, I3, I4, I5, I6, I7, I8, I9, I10, R] { 625 | type ReturnType 626 | val sql: String 627 | def apply(i1: I1, i2: I2, i3: I3, i4: I4, i5: I5, i6: I6, i7: I7, i8: I8, i9: I9, i10: I10)(implicit conn: Connection): R 628 | } 629 | trait Query11[I1, I2, I3, I4, I5, I6, I7, I8, I9, I10, I11, R] { 630 | type ReturnType 631 | val sql: String 632 | def apply(i1: I1, i2: I2, i3: I3, i4: I4, i5: I5, i6: I6, i7: I7, i8: I8, i9: I9, i10: I10, i11: I11)(implicit conn: Connection): R 633 | } 634 | trait Query12[I1, I2, I3, I4, I5, I6, I7, I8, I9, I10, I11, I12, R] { 635 | type ReturnType 636 | val sql: String 637 | def apply(i1: I1, i2: I2, i3: I3, i4: I4, i5: I5, i6: I6, i7: I7, i8: I8, i9: I9, i10: I10, i11: I11, i12: I12)(implicit conn: Connection): R 638 | } 639 | trait Query13[I1, I2, I3, I4, I5, I6, I7, I8, I9, I10, I11, I12, I13, R] { 640 | type ReturnType 641 | val sql: String 642 | def apply(i1: I1, i2: I2, i3: I3, i4: I4, i5: I5, i6: I6, i7: I7, i8: I8, i9: I9, i10: I10, i11: I11, i12: I12, i13: I13)(implicit conn: Connection): R 643 | } 644 | trait Query14[I1, I2, I3, I4, I5, I6, I7, 
I8, I9, I10, I11, I12, I13, I14, R] { 645 | type ReturnType 646 | val sql: String 647 | def apply(i1: I1, i2: I2, i3: I3, i4: I4, i5: I5, i6: I6, i7: I7, i8: I8, i9: I9, i10: I10, i11: I11, i12: I12, i13: I13, i14: I14)(implicit conn: Connection): R 648 | } 649 | trait Query15[I1, I2, I3, I4, I5, I6, I7, I8, I9, I10, I11, I12, I13, I14, I15, R] { 650 | type ReturnType 651 | val sql: String 652 | def apply(i1: I1, i2: I2, i3: I3, i4: I4, i5: I5, i6: I6, i7: I7, i8: I8, i9: I9, i10: I10, i11: I11, i12: I12, i13: I13, i14: I14, i15: I15)(implicit conn: Connection): R 653 | } 654 | trait Query16[I1, I2, I3, I4, I5, I6, I7, I8, I9, I10, I11, I12, I13, I14, I15, I16, R] { 655 | type ReturnType 656 | val sql: String 657 | def apply(i1: I1, i2: I2, i3: I3, i4: I4, i5: I5, i6: I6, i7: I7, i8: I8, i9: I9, i10: I10, i11: I11, i12: I12, i13: I13, i14: I14, i15: I15, i16: I16)(implicit conn: Connection): R 658 | } 659 | trait Query17[I1, I2, I3, I4, I5, I6, I7, I8, I9, I10, I11, I12, I13, I14, I15, I16, I17, R] { 660 | type ReturnType 661 | val sql: String 662 | def apply(i1: I1, i2: I2, i3: I3, i4: I4, i5: I5, i6: I6, i7: I7, i8: I8, i9: I9, i10: I10, i11: I11, i12: I12, i13: I13, i14: I14, i15: I15, i16: I16, i17: I17)(implicit conn: Connection): R 663 | } 664 | trait Query18[I1, I2, I3, I4, I5, I6, I7, I8, I9, I10, I11, I12, I13, I14, I15, I16, I17, I18, R] { 665 | type ReturnType 666 | val sql: String 667 | def apply(i1: I1, i2: I2, i3: I3, i4: I4, i5: I5, i6: I6, i7: I7, i8: I8, i9: I9, i10: I10, i11: I11, i12: I12, i13: I13, i14: I14, i15: I15, i16: I16, i17: I17, i18: I18)(implicit conn: Connection): R 668 | } 669 | trait Query19[I1, I2, I3, I4, I5, I6, I7, I8, I9, I10, I11, I12, I13, I14, I15, I16, I17, I18, I19, R] { 670 | type ReturnType 671 | val sql: String 672 | def apply(i1: I1, i2: I2, i3: I3, i4: I4, i5: I5, i6: I6, i7: I7, i8: I8, i9: I9, i10: I10, i11: I11, i12: I12, i13: I13, i14: I14, i15: I15, i16: I16, i17: I17, i18: I18, i19: I19)(implicit conn: 
Connection): R 673 | } 674 | trait Query20[I1, I2, I3, I4, I5, I6, I7, I8, I9, I10, I11, I12, I13, I14, I15, I16, I17, I18, I19, I20, R] { 675 | type ReturnType 676 | val sql: String 677 | def apply(i1: I1, i2: I2, i3: I3, i4: I4, i5: I5, i6: I6, i7: I7, i8: I8, i9: I9, i10: I10, i11: I11, i12: I12, i13: I13, i14: I14, i15: I15, i16: I16, i17: I17, i18: I18, i19: I19, i20: I20)(implicit conn: Connection): R 678 | } 679 | trait Query21[I1, I2, I3, I4, I5, I6, I7, I8, I9, I10, I11, I12, I13, I14, I15, I16, I17, I18, I19, I20, I21, R] { 680 | type ReturnType 681 | val sql: String 682 | def apply(i1: I1, i2: I2, i3: I3, i4: I4, i5: I5, i6: I6, i7: I7, i8: I8, i9: I9, i10: I10, i11: I11, i12: I12, i13: I13, i14: I14, i15: I15, i16: I16, i17: I17, i18: I18, i19: I19, i20: I20, i21: I21)(implicit conn: Connection): R 683 | } 684 | trait Query22[I1, I2, I3, I4, I5, I6, I7, I8, I9, I10, I11, I12, I13, I14, I15, I16, I17, I18, I19, I20, I21, I22, R] { 685 | type ReturnType 686 | val sql: String 687 | def apply(i1: I1, i2: I2, i3: I3, i4: I4, i5: I5, i6: I6, i7: I7, i8: I8, i9: I9, i10: I10, i11: I11, i12: I12, i13: I13, i14: I14, i15: I15, i16: I16, i17: I17, i18: I18, i19: I19, i20: I20, i21: I21, i22: I22)(implicit conn: Connection): R 688 | } 689 | -------------------------------------------------------------------------------- /core/src/main/scala/package.scala: -------------------------------------------------------------------------------- 1 | import shapeless._ 2 | 3 | package object sqltyped { 4 | import scala.annotation.StaticAnnotation 5 | import scala.language.experimental.macros 6 | import scala.language.implicitConversions 7 | 8 | class useInputTags extends StaticAnnotation 9 | class jdbcOnly extends StaticAnnotation 10 | class returnKeys extends StaticAnnotation 11 | class useSymbolKeyRecords extends StaticAnnotation 12 | 13 | def sql(s: String) = macro SqlMacro.sqlImpl 14 | 15 | @useInputTags def sqlt(s: String) = macro SqlMacro.sqlImpl 16 | 17 | // FIXME 
// Public macro-backed SQL entry points of the sqltyped package object.
switch to sql("select ...", keys = true) after; 18 | // https://issues.scala-lang.org/browse/SI-5920
// sqlk: variant returning generated keys — SqlMacro.sqlImpl reads the @returnKeys annotation.
 19 | @returnKeys def sqlk(s: String) = macro SqlMacro.sqlImpl 20 | 
// sqlj: bypass the SQL parser/typer and rely on JDBC metadata only (@jdbcOnly).
 21 | @jdbcOnly def sqlj(s: String) = macro SqlMacro.sqlImpl 22 | 
// sqls: generate records keyed by Symbols instead of string singletons (@useSymbolKeyRecords).
 23 | @useSymbolKeyRecords def sqls(s: String) = macro SqlMacro.sqlImpl 24 | 
// String-interpolator entry points: sql"..." / sqlp"..." expand through the dynamic-SQL macros.
 25 | implicit class DynSQLContext(sc: StringContext) { 26 | def sql(exprs: Any*) = macro SqlMacro.dynsqlImpl 27 | } 28 | 29 | implicit class DynParamSQLContext(sc: StringContext) { 30 | def sqlp(exprs: Any*) = macro SqlMacro.paramDynsqlImpl 31 | } 32 | 
// Enrichments exposing record/list/option operations on shapeless HLists.
 33 | implicit def recordOps[R <: HList](r: R): RecordOps[R] = new RecordOps(r) 34 | 35 | implicit def listOps[L <: HList](l: List[L]): ListOps[L] = new ListOps(l) 36 | 37 | implicit def optionOps[L <: HList](l: Option[L]): OptionOps[L] = new OptionOps(l) 38 | 39 | // To reduce importing when using records... 
// NOTE(review): an implicit conversion from Any is extremely wide in scope; it is macro-backed
// (SingletonTypeMacros.mkSingletonOps), so presumably only singleton-typed arguments expand — confirm.
 40 | implicit def mkSingletonOps(t: Any): syntax.SingletonOps = macro SingletonTypeMacros.mkSingletonOps 41 | 42 | // Internally ? is used to denote computations that may fail. 
// Internal result helpers: ?[A] is sqltyped's Either-like success/failure type (defined below);
// fail/ok construct the two cases, ResultOps/ResultOptionOps lift plain values and Options.
43 | private[sqltyped] def fail[A](s: String, column: Int = 0, line: Int = 0): ?[A] = 44 | sqltyped.Failure(s, column, line) 45 | private[sqltyped] def ok[A](a: A): ?[A] = sqltyped.Ok(a) 46 | private[sqltyped] implicit class ResultOps[A](a: A) { 47 | def ok = sqltyped.ok(a) 48 | } 49 | private[sqltyped] implicit class ResultOptionOps[A](a: Option[A]) { 50 | def orFail(s: => String) = a map sqltyped.ok getOrElse fail(s) 51 | } 52 | 
// sequence: collect List[?[A]] into ?[List[A]]; the first Failure short-circuits via flatMap.
// sequenceO: same for Option — None yields Ok(None) (foldRight on an empty Option returns the seed).
private[sqltyped] def sequence[A](rs: List[?[A]]): ?[List[A]] = 53 | rs.foldRight(List[A]().ok) { (ra, ras) => for { as <- ras; a <- ra } yield a :: as } 54 | private[sqltyped] def sequenceO[A](rs: Option[?[A]]): ?[Option[A]] = 55 | rs.foldRight(None.ok: ?[Option[A]]) { (ra, _) => for { a <- ra } yield Some(a) } 56 | } 57 | 58 | package sqltyped { 
// Minimal sealed result monad with exactly two cases, Ok and Failure (Failure carries a message
// plus source position). map/flatMap/withFilter make it usable in for-comprehensions; the inner
// WithFilter class supports guards without eagerly forcing the value.
private[sqltyped] abstract sealed class ?[+A] { self => 60 | def map[B](f: A => B): ?[B] 61 | def flatMap[B](f: A => ?[B]): ?[B] 62 | def foreach[U](f: A => U): Unit 63 | def fold[B](ifFail: Failure[A] => B, f: A => B): B 64 | def getOrElse[B >: A](default: => B): B 65 | def filter(p: A => Boolean): ?[A] 66 | def withFilter(p: A => Boolean): WithFilter = new WithFilter(p) 67 | class WithFilter(p: A => Boolean) { 68 | def map[B](f: A => B): ?[B] = self filter p map f 69 | def flatMap[B](f: A => ?[B]): ?[B] = self filter p flatMap f 70 | def foreach[U](f: A => U): Unit = self filter p foreach f 71 | def withFilter(q: A => Boolean): WithFilter = new WithFilter(x => p(x) && q(x)) 72 | } 73 | } 74 | private[sqltyped] final case class Ok[+A](a: A) extends ?[A] { 75 | def map[B](f: A => B) = Ok(f(a)) 76 | def flatMap[B](f: A => ?[B]) = f(a) 77 | def foreach[U](f: A => U) = { f(a); () } 78 | def fold[B](ifFail: Failure[A] => B, f: A => B) = f(a) 79 | def getOrElse[B >: A](default: => B) = a 80 | def filter(p: A => Boolean) = if (p(a)) this else fail("filter on ?[_] failed") 81 | } 82 | private[sqltyped] final case class Failure[+A](message: String, column: Int, line: Int) extends ?[A] { 83 | def 
map[B](f: A => B) = Failure(message, column, line) 84 | def flatMap[B](f: A => ?[B]) = Failure(message, column, line) 85 | def foreach[U](f: A => U) = () 86 | def fold[B](ifFail: Failure[A] => B, f: A => B) = ifFail(this) 87 | def getOrElse[B >: A](default: => B) = default 88 | def filter(p: A => Boolean) = this 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /core/src/main/scala/parser.scala: -------------------------------------------------------------------------------- 1 | package sqltyped 2 | 3 | import scala.util.parsing.combinator._ 4 | import scala.reflect.runtime.universe.{Type, typeOf} 5 | import java.sql.{Types => JdbcTypes} 6 | 7 | trait SqlParser extends RegexParsers with Ast.Unresolved with PackratParsers { 8 | import Ast._ 9 | 10 | def parseAllWith(p: Parser[Statement], sql: String) = ok_?(parseAll(p, sql)) 11 | 12 | def parseWith(p: Parser[Statement], sql: String) = ok_?(parse(p, input(sql))) 13 | 14 | def input(s: String) = new PackratReader(new scala.util.parsing.input.CharArrayReader(s.toCharArray)) 15 | 16 | def ok_?(res: ParseResult[Statement]) = res match { 17 | case Success(r, q) => ok(r) 18 | case err: NoSuccess => fail(err.msg, err.next.pos.column, err.next.pos.line) 19 | } 20 | 21 | lazy val stmt = (setStmt | selectStmt | insertStmt | updateStmt | deleteStmt | createStmt) 22 | 23 | lazy val selectStmt = select ~ from ~ opt(where) ~ opt(groupBy) ~ opt(orderBy) ~ opt(limit) <~ opt("for".i ~ "update".i) ^^ { 24 | case s ~ f ~ w ~ g ~ o ~ l => Select(s, f, w, g, o, l) 25 | } 26 | 27 | lazy val setOperator = ("union".i | "except".i | "intersect".i) 28 | 29 | lazy val setStmt = optParens(selectStmt) ~ setOperator ~ opt("all".i) ~ optParens(selectStmt) ~ opt(orderBy) ~ opt(limit) ^^ { 30 | case s1 ~ op ~ _ ~ s2 ~ o ~ l => SetStatement(s1, op, s2, o, l) 31 | } 32 | 33 | lazy val insertSyntax = insert ~> "into".i ~> table ~ opt(colNames) ~ (listValues | selectValues) 34 | 35 | lazy val insertStmt: 
Parser[Statement] = insertSyntax ^^ { 36 | case t ~ cols ~ vals => Insert(t, cols, vals) 37 | } 38 | 39 | lazy val colNames = "(" ~> repsep(ident, ",") <~ ")" 40 | 41 | lazy val listValues = "values".i ~> "(" ~> repsep(term, ",") <~ ")" ^^ ListedInput.apply 42 | 43 | lazy val selectValues = optParens(selectStmt) ^^ SelectedInput.apply 44 | 45 | lazy val updateStmt = update ~ rep1sep(table, ",") ~ "set".i ~ rep1sep(assignment, ",") ~ opt(where) ~ opt(orderBy) ~ opt(limit) ^^ { 46 | case _ ~ t ~ _ ~ a ~ w ~ o ~ l => Update(t, a, w, o, l) 47 | } 48 | 49 | def insert = "insert".i 50 | def update = "update".i 51 | 52 | lazy val assignment = column ~ "=" ~ term ^^ { case c ~ _ ~ t => (c, t) } 53 | 54 | lazy val deleteStmt = 55 | "delete".i ~ opt(repsep(ident, ",")) ~ "from".i ~ rep1sep(table, ",") ~ opt(where) ^^ { 56 | case _ ~ _ ~ t ~ w => Delete(t, w) 57 | } 58 | 59 | lazy val createStmt = "create".i ^^^ Create[Option[String]]() 60 | 61 | lazy val select = "select".i ~> repsep((opt("all".i) ~> named), ",") 62 | 63 | lazy val from = "from".i ~> rep1sep(tableReference, ",") 64 | 65 | lazy val tableReference: Parser[TableReference] = ( 66 | joinedTable 67 | | derivedTable 68 | | table ^^ (t => ConcreteTable(t, Nil)) 69 | ) 70 | 71 | lazy val joinedTable = table ~ rep(joinType) ^^ { case t ~ j => ConcreteTable(t, j) } 72 | 73 | lazy val joinType = (crossJoin | qualifiedJoin) 74 | 75 | lazy val crossJoin = "cross".i ~ "join".i ~ optParens(tableReference) ^^ { 76 | case _ ~ _ ~ table => Join(table, None, Cross) 77 | } 78 | 79 | lazy val joinDesc = ( 80 | "left".i ~ opt("outer".i) ^^^ LeftOuter 81 | | "right".i ~ opt("outer".i) ^^^ RightOuter 82 | | "full".i ~ opt("outer".i) ^^^ FullOuter 83 | | "inner".i ^^^ Inner 84 | ) 85 | 86 | lazy val qualifiedJoin = opt(joinDesc) ~ "join".i ~ optParens(tableReference) ~ opt(joinSpec) ^^ { 87 | case joinDesc ~ _ ~ table ~ spec => Join(table, spec, joinDesc getOrElse Inner) 88 | } 89 | 90 | lazy val joinSpec = (joinCondition | 
// SqlParser fragment: join specifications, table references, WHERE, and the boolean/comparison grammar.
namedColumnsJoin) 91 | 92 | lazy val joinCondition = "on".i ~> expr ^^ QualifiedJoin.apply 93 | 94 | lazy val namedColumnsJoin = "using".i ~> "(" ~> rep1sep(ident, ",") <~ ")" ^^ { 95 | cols => NamedColumnsJoin[Option[String]](cols) 96 | } 97 | 98 | lazy val derivedTable = subselect ~ opt("as".i) ~ ident ~ rep(joinType) ^^ { 99 | case s ~ _ ~ a ~ j => DerivedTable(a, s.select, j) 100 | } 101 | 102 | lazy val table = optParens(opt(ident <~ ".") ~ ident ~ opt(opt("as".i) ~> ident)) ^^ { 103 | case s ~ n ~ a => Table(n, a, s.map(_.mkString)) 104 | } 105 | 106 | lazy val where = "where".i ~> expr ^^ Where.apply 107 | 
// expr: AND/OR chains built with the postfix '*' (chainl) combinator — combines operands
// left-associatively; PackratParser memoization keeps the mutually recursive grammar efficient.
108 | lazy val expr: PackratParser[Expr] = (comparison | parens | notExpr)* ( 109 | "and".i ^^^ { (e1: Expr, e2: Expr) => And(e1, e2) } 110 | | "or".i ^^^ { (e1: Expr, e2: Expr) => Or(e1, e2) } 111 | ) 112 | 113 | lazy val parens: PackratParser[Expr] = "(" ~> expr <~ ")" 114 | lazy val notExpr: PackratParser[Expr] = "not".i ~> expr ^^ Not.apply 115 | 
// comparison: ordered alternatives; multi-token forms (NOT IN, NOT BETWEEN, IS NOT NULL) must
// appear so that a failed longer match backtracks to the correct shorter one.
116 | lazy val comparison: PackratParser[Comparison] = ( 117 | term ~ "=" ~ term ^^ { case lhs ~ _ ~ rhs => Comparison2(lhs, Eq, rhs) } 118 | | term ~ "!=" ~ term ^^ { case lhs ~ _ ~ rhs => Comparison2(lhs, Neq, rhs) } 119 | | term ~ "<>" ~ term ^^ { case lhs ~ _ ~ rhs => Comparison2(lhs, Neq, rhs) } 120 | | term ~ "<" ~ term ^^ { case lhs ~ _ ~ rhs => Comparison2(lhs, Lt, rhs) } 121 | | term ~ ">" ~ term ^^ { case lhs ~ _ ~ rhs => Comparison2(lhs, Gt, rhs) } 122 | | term ~ "<=" ~ term ^^ { case lhs ~ _ ~ rhs => Comparison2(lhs, Le, rhs) } 123 | | term ~ ">=" ~ term ^^ { case lhs ~ _ ~ rhs => Comparison2(lhs, Ge, rhs) } 124 | | term ~ "like".i ~ term ^^ { case lhs ~ _ ~ rhs => Comparison2(lhs, Like, rhs) } 125 | | term ~ "in".i ~ (terms | subselect) ^^ { case lhs ~ _ ~ rhs => Comparison2(lhs, In, rhs) } 126 | | term ~ "not".i ~ "in".i ~ (terms | subselect) ^^ { case lhs ~ _ ~ _ ~ rhs => Comparison2(lhs, NotIn, rhs) } 127 | | term ~ "between".i ~ term ~ "and".i ~ term ^^ { case t1 ~ _ ~ t2 ~ _ ~ t3 =>
Comparison3(t1, Between, t2, t3) } 128 | | term ~ "not".i ~ "between".i ~ term ~ "and".i ~ term ^^ { case t1 ~ _ ~ _ ~ t2 ~ _ ~ t3 => Comparison3(t1, NotBetween, t2, t3) } 129 | | term <~ "is".i ~ "null".i ^^ { t => Comparison1(t, IsNull) } 130 | | term <~ "is".i ~ "not".i ~ "null".i ^^ { t => Comparison1(t, IsNotNull) } 131 | | "exists".i ~> subselect ^^ { t => Comparison1(t, Exists) } 
// NOTE(review): BUG? — bare "not" below lacks the case-insensitive .i suffix used by every other
// keyword in this grammar ("not".i), so e.g. upper-case "NOT EXISTS (...)" will fail to parse.
// Likely should read: "not".i ~> "exists".i ~> subselect — confirm and fix.
132 | | "not" ~> "exists".i ~> subselect ^^ { t => Comparison1(t, NotExists) } 133 | ) 134 | 135 | lazy val subselect = "(" ~> selectStmt <~ ")" ^^ Subselect.apply 136 | 137 | lazy val term = (arith | simpleTerm) 138 | 139 | lazy val terms: PackratParser[Term] = "(" ~> repsep(term, ",") <~ ")" ^^ TermList.apply 140 | 
// simpleTerm alternatives are ordered: literals and functions before bare column references,
// with '?' parsed as a positional input placeholder.
141 | lazy val simpleTerm: PackratParser[Term] = ( 142 | subselect 143 | | caseExpr 144 | | function 145 | | boolean 146 | | nullLit ^^^ constNull 147 | | stringLit ^^ constS 148 | | numericLit ^^ (n => if (n.contains(".")) constD(n.toDouble) else constL(n.toLong)) 149 | | extraTerms 150 | | allColumns 151 | | column 152 | | "?"
^^^ Input[Option[String]]() 153 | | optParens(simpleTerm) 154 | ) 155 | 156 | lazy val named = opt("distinct".i) ~> (comparison | arith | simpleTerm) ~ opt(opt("as".i) ~> ident) ^^ { 157 | case (c@Constant(_, _)) ~ a => Named("", a, c) 158 | case (f@Function(n, _)) ~ a => Named(n, a, f) 159 | case (c@Column(n, _)) ~ a => Named(n, a, c) 160 | case (i@Input()) ~ a => Named("?", a, i) 161 | case (c@AllColumns(_)) ~ a => Named("*", a, c) 162 | case (e@ArithExpr(_, _, _)) ~ a => Named("", a, e) 163 | case (c@Comparison1(_, _)) ~ a => Named("", a, c) 164 | case (c@Comparison2(_, _, _)) ~ a => Named("", a, c) 165 | case (c@Comparison3(_, _, _, _)) ~ a => Named("", a, c) 166 | case (s@Subselect(_)) ~ a => Named("subselect", a, s) 167 | case (c@Case(_, _)) ~ a => Named("case", a, c) 168 | } 169 | 170 | def extraTerms: PackratParser[Term] = failure("expected a term") 171 | 172 | def dataTypes: List[Parser[DataType]] = Nil 173 | 174 | lazy val dataType: Parser[Expr] = 175 | dataTypes.foldLeft(failure("expected data type"): Parser[DataType])(_ | _) ^^ TypeExpr.apply 176 | 177 | lazy val column = ( 178 | ident ~ "." 
~ ident ^^ { case t ~ _ ~ c => col(c, Some(t)) } 179 | | ident ^^ (c => col(c, None)) 180 | ) 181 | 182 | lazy val allColumns = 183 | opt(ident <~ ".") <~ "*" ^^ (t => AllColumns(t)) 184 | 185 | lazy val caseExpr = "case".i ~ rep(caseCond) ~ opt(caseElse) ~ "end".i ^^ { 186 | case _ ~ conds ~ elze ~ _ => Case(conds, elze) 187 | } 188 | 189 | lazy val caseCond = "when".i ~ expr ~ "then".i ~ term ^^ { 190 | case _ ~ e ~ _ ~ result => (e, result) 191 | } 192 | 193 | lazy val caseElse = "else".i ~> term 194 | 195 | lazy val functionArg: PackratParser[Expr] = opt("distinct".i) ~> (expr | dataType | term ^^ SimpleExpr.apply) 196 | lazy val infixFunctionArg = term ^^ SimpleExpr.apply 197 | 198 | lazy val function = (prefixFunction | infixFunction) 199 | 200 | lazy val prefixFunction: PackratParser[Function] = 201 | ident ~ "(" ~ repsep(functionArg, ",") ~ ")" ^^ { 202 | case name ~ _ ~ params ~ _ => Function(name, params) 203 | } 204 | 205 | lazy val infixFunction: PackratParser[Function] = ( 206 | infixFunctionArg ~ "|" ~ infixFunctionArg 207 | | infixFunctionArg ~ "&" ~ infixFunctionArg 208 | | infixFunctionArg ~ "^" ~ infixFunctionArg 209 | | infixFunctionArg ~ "<<" ~ infixFunctionArg 210 | | infixFunctionArg ~ ">>" ~ infixFunctionArg 211 | ) ^^ { 212 | case lhs ~ name ~ rhs => Function(name, List(lhs, rhs)) 213 | } 214 | 215 | lazy val arith: PackratParser[Term] = (simpleTerm | arithParens)* ( 216 | "+" ^^^ { (lhs: Term, rhs: Term) => ArithExpr(lhs, "+", rhs) } 217 | | "-" ^^^ { (lhs: Term, rhs: Term) => ArithExpr(lhs, "-", rhs) } 218 | | "*" ^^^ { (lhs: Term, rhs: Term) => ArithExpr(lhs, "*", rhs) } 219 | | "/" ^^^ { (lhs: Term, rhs: Term) => ArithExpr(lhs, "/", rhs) } 220 | | "%" ^^^ { (lhs: Term, rhs: Term) => ArithExpr(lhs, "%", rhs) } 221 | ) 222 | 223 | lazy val arithParens = "(" ~> arith <~ ")" 224 | 225 | lazy val boolean = (booleanFactor | booleanTerm) 226 | 227 | lazy val booleanTerm = ("true".i ^^^ true | "false".i ^^^ false) ^^ constB 228 | 229 | lazy val 
booleanFactor = "not".i ~> term 230 | 231 | lazy val nullLit = "null".i 232 | 233 | lazy val collate = "collate".i ~> ident 234 | 235 | lazy val orderBy = "order".i ~> "by".i ~> rep1sep(orderSpec, ",") ^^ { 236 | orderSpecs => OrderBy(orderSpecs.unzip._1, orderSpecs.unzip._2) 237 | } 238 | 239 | lazy val orderSpec = optParens(term) ~ opt(collate) ~ opt("asc".i ^^^ Asc | "desc".i ^^^ Desc) ^^ { 240 | case s ~ _ ~ o => (s, o) 241 | } 242 | 243 | lazy val groupBy = "group".i ~> "by".i ~> rep1sep(term <~ opt(collate), ",") ~ opt(withRollup) ~ opt(having) ^^ { 244 | case cols ~ withRollup ~ having => GroupBy(cols, withRollup map (_ => true) getOrElse false, having) 245 | } 246 | 247 | lazy val withRollup = "with rollup".i 248 | 249 | lazy val having = "having".i ~> expr ^^ Having.apply 250 | 251 | lazy val limit = "limit".i ~> intOrInput ~ opt("offset".i ~> intOrInput) ^^ { 252 | case count ~ offset => Limit(count, offset) 253 | } 254 | 255 | lazy val intOrInput = ( 256 | "?" ^^^ Right(Input[Option[String]]()) 257 | | numericLit ^^ (n => Left(n.toInt)) 258 | ) 259 | 260 | def optParens[A](p: PackratParser[A]): PackratParser[A] = ( 261 | "(" ~> p <~ ")" 262 | | p 263 | ) 264 | 265 | lazy val reserved = 266 | ("select".i | "delete".i | "insert".i | "update".i | "from".i | "into".i | "where".i | "as".i | 267 | "and".i | "or".i | "join".i | "inner".i | "outer".i | "left".i | "right".i | "on".i | "group".i | 268 | "by".i | "having".i | "limit".i | "offset".i | "order".i | "asc".i | "desc".i | "distinct".i | 269 | "is".i | "not".i | "null".i | "between".i | "in".i | "exists".i | "values".i | "create".i | 270 | "set".i | "union".i | "except".i | "intersect".i) 271 | 272 | private def col(name: String, table: Option[String]) = Column(name, table) 273 | 274 | def constB(b: Boolean) = const((typeOf[Boolean], JdbcTypes.BOOLEAN), b) 275 | def constS(s: String) = const((typeOf[String], JdbcTypes.VARCHAR), s) 276 | def constD(d: Double) = const((typeOf[Double], JdbcTypes.DOUBLE), d) 
277 | def constL(l: Long) = const((typeOf[Long], JdbcTypes.BIGINT), l) 278 | def constNull = const((typeOf[AnyRef], JdbcTypes.JAVA_OBJECT), null) 279 | def const(tpe: (Type, Int), x: Any) = Constant[Option[String]](tpe, x) 280 | 281 | implicit class KeywordOps(kw: String) { 282 | def i = keyword(kw) 283 | } 284 | 285 | def keyword(kw: String): Parser[String] = ("(?i)" + kw + "\\b").r 286 | 287 | def quoteChar: Parser[String] = "\"" 288 | 289 | lazy val ident = (quotedIdent | rawIdent) 290 | 291 | lazy val rawIdent = not(reserved) ~> identValue 292 | lazy val quotedIdent = quoteChar ~> identValue <~ quoteChar 293 | 294 | lazy val stringLit = 295 | "'" ~ """([^'\p{Cntrl}\\]|\\[\\/bfnrt]|\\u[a-fA-F0-9]{4})*""".r ~ "'" ^^ { case _ ~ s ~ _ => s } 296 | 297 | lazy val identValue: Parser[String] = "[a-zA-Z][a-zA-Z0-9_-]*".r 298 | lazy val numericLit: Parser[String] = """(-)?(\d+(\.\d*)?|\d*\.\d+)""".r 299 | lazy val integer: Parser[Int] = """\d+""".r ^^ (s => s.toInt) 300 | } 301 | -------------------------------------------------------------------------------- /core/src/main/scala/record.scala: -------------------------------------------------------------------------------- 1 | package sqltyped 2 | 3 | import shapeless._, ops.hlist._, ops.record._, labelled.FieldType 4 | 5 | object Record { 6 | private object fieldToUntyped extends Poly1 { 7 | implicit def f[F, V](implicit wk: shapeless.Witness.Aux[F]) = at[FieldType[F, V]] { 8 | f => (wk.value.toString, f: Any) :: Nil 9 | } 10 | } 11 | 12 | def toTupleLists[R <: HList, F, V](rs: List[R]) 13 | (implicit folder: MapFolder[R, List[(String, Any)], fieldToUntyped.type]): List[List[(String, Any)]] = 14 | rs map (r => toTupleList(r)(folder)) 15 | 16 | def toTupleList[R <: HList, F, V](r: R) 17 | (implicit folder: MapFolder[R, List[(String, Any)], fieldToUntyped.type]): List[(String, Any)] = 18 | r.foldMap(Nil: List[(String, Any)])(fieldToUntyped)(_ ::: _) 19 | } 20 | 21 | private[sqltyped] object showField extends Poly1 { 22 | 
implicit def f[F, V](implicit wk: shapeless.Witness.Aux[F]) = at[FieldType[F, V]] { 23 | f => wk.value.toString + " = " + f.toString 24 | } 25 | } 26 | 27 | final class RecordOps[R <: HList](r: R) { 28 | 29 | def show(implicit folder: ops.hlist.MapFolder[R, String, showField.type]): String = { 30 | val concat = (s1: String, s2: String) => if (s2 != "") s1 + ", " + s2 else s1 31 | "{ " + r.foldMap("")(showField)(concat) + " }" 32 | } 33 | } 34 | 35 | final class ListOps[L <: HList](l: List[L]) { 36 | def values(implicit values: Values[L]): List[values.Out] = l map (r => values(r)) 37 | 38 | def tuples[Out0 <: HList, Out <: Product] 39 | (implicit 40 | values: Values.Aux[L, Out0], 41 | tupler: Tupler.Aux[Out0, Out]): List[Out] = l map (r => tupler(values(r))) 42 | } 43 | 44 | final class OptionOps[L <: HList](o: Option[L]) { 45 | def values(implicit values: Values[L]): Option[values.Out] = o map (r => values(r)) 46 | 47 | def tuples[Out0 <: HList, Out <: Product] 48 | (implicit 49 | values: Values.Aux[L, Out0], 50 | tupler: Tupler.Aux[Out0, Out]): Option[Out] = o map (r => tupler(values(r))) 51 | } 52 | -------------------------------------------------------------------------------- /core/src/main/scala/timer.scala: -------------------------------------------------------------------------------- 1 | package sqltyped 2 | 3 | private[sqltyped] class Timer(enabled: Boolean) { 4 | def apply[A](msg: => String, indent: Int, a: => A) = 5 | if (enabled) { 6 | val start = System.currentTimeMillis 7 | println((" " * indent) + msg) 8 | val aa = a 9 | println((" " * indent) + (System.currentTimeMillis - start) + "ms") 10 | aa 11 | } else a 12 | } 13 | 14 | private[sqltyped] object Timer { 15 | def apply(enabled: Boolean) = new Timer(enabled) 16 | } 17 | -------------------------------------------------------------------------------- /core/src/main/scala/typer.scala: -------------------------------------------------------------------------------- 1 | package sqltyped 2 | 3 | 
import schemacrawler.schema.{ColumnDataType, Schema} 4 | import scala.reflect.runtime.universe.{Type, typeOf} 5 | import Ast._ 6 | 7 | case class TypedValue(name: String, tpe: (Type, Int), nullable: Boolean, tag: Option[String], term: Term[Table]) 8 | 9 | case class TypedStatement( 10 | input: List[TypedValue] 11 | , output: List[TypedValue] 12 | , isQuery: Boolean 13 | , uniqueConstraints: Map[Table, List[List[Column[Table]]]] 14 | , generatedKeyTypes: List[TypedValue] 15 | , numOfResults: NumOfResults.NumOfResults = NumOfResults.Many) 16 | 17 | object NumOfResults extends Enumeration { 18 | type NumOfResults = Value 19 | val ZeroOrOne, One, Many = Value 20 | } 21 | 22 | /** 23 | * Variable is a placeholder in SQL statement (ie. ?-char). 24 | * 'comparisonTerm' is the outer context of a variable. For instance in the stmt below: 25 | * where name = upper(?) 26 | * the comparison term of Function("upper") is Comparison2(Column("name"), Eq, f) 27 | */ 28 | case class Variable(term: Named[Table], comparisonTerm: Option[Term[Table]] = None) 29 | 30 | class Variables(typer: Typer) extends Ast.Resolved { 31 | def input(schema: Schema, stmt: Statement): List[Variable] = stmt match { 32 | case Delete(from, where) => where map (w => input(w.expr, None)) getOrElse Nil 33 | 34 | case Insert(table, colNames, insertInput) => 35 | val realSchema = table.schema.flatMap { schemaName => 36 | typer.cachedSchema(schemaName).fold(_ => None, Option.apply) 37 | } getOrElse schema 38 | def colNamesFromSchema = realSchema.getTable(table.name).getColumns.toList.map(_.getName) 39 | 40 | insertInput match { 41 | case ListedInput(values) => 42 | (colNames getOrElse colNamesFromSchema zip values collect { 43 | case (name, Input()) => List(Variable(Named(name, None, Column(name, table)))) 44 | case (name, Subselect(s)) => input(schema, s) 45 | }).flatten 46 | case SelectedInput(select) => input(schema, select) 47 | } 48 | 49 | case SetStatement(l, op, r, orderBy, limit) => 50 | input(schema, l) 
::: input(schema, r) ::: (orderBy map input).getOrElse(Nil) ::: limitInput(limit) 51 | 52 | case Composed(left, right) => 53 | input(schema, left) ::: input(schema, right) 54 | 55 | case Update(tables, set, where, orderBy, limit) => 56 | set.flatMap { 57 | case (col, Input()) => Variable(Named(col.name, None, col)) :: Nil 58 | case (col, t) => inputTerm(t, None) 59 | } ::: 60 | where.map(w => input(w.expr, None)).getOrElse(Nil) ::: 61 | (orderBy map input).getOrElse(Nil) ::: 62 | limitInput(limit) 63 | 64 | case Create() => Nil 65 | 66 | case s@Select(_, _, _, _, _, _) => input(s) 67 | } 68 | 69 | def input(s: Select): List[Variable] = 70 | s.projection.collect { 71 | case Named(n, a, f@Function(_, _)) => input(f, None) 72 | case n@Named(_, _, Input()) => Variable(n) :: Nil 73 | case Named(_, _, Subselect(s)) => input(s) 74 | case Named(_, _, e: Expr) => input(e, None) 75 | case Named(_, _, c@Case(_, _)) => inputTerm(c, None) 76 | }.flatten ::: 77 | s.tableReferences.flatMap(input) ::: 78 | s.where.map(w => input(w.expr, None)).getOrElse(Nil) ::: 79 | s.groupBy.toList.flatMap(g => (g.terms flatMap (t => inputTerm(t, None))) ::: g.having.toList.flatMap(h => input(h.expr, None))) ::: 80 | s.orderBy.map(input).getOrElse(Nil) ::: 81 | limitInput(s.limit) 82 | 83 | def input(t: TableReference): List[Variable] = t match { 84 | case ConcreteTable(_, join) => join flatMap input 85 | case DerivedTable(_, s, join) => input(s) ::: (join flatMap input) 86 | } 87 | 88 | def input(j: Join): List[Variable] = 89 | input(j.table) ::: (j.joinType map input getOrElse Nil) 90 | 91 | def input(o: OrderBy): List[Variable] = o.sort flatMap (s => inputTerm(s, None)) 92 | 93 | def input(j: JoinType): List[Variable] = j match { 94 | case QualifiedJoin(e) => input(e, None) 95 | case _ => Nil 96 | } 97 | 98 | def input(f: Function, comparisonTerm: Option[Term]): List[Variable] = 99 | f.params zip typer.inferArguments(f, comparisonTerm).getOrElse(Nil) flatMap { 100 | case (SimpleExpr(t), (tpe, 
_)) => t match { 101 | case Input() => Variable(Named("", None, Constant[Table](tpe, ())), comparisonTerm) :: Nil 102 | case _ => inputTerm(t, comparisonTerm) 103 | } 104 | case (e, _) => input(e, comparisonTerm) 105 | } 106 | 107 | def nameVar(t: Term, comparisonTerm: Option[Term]) = t match { 108 | case c@Constant(_, _) => Variable(Named("", None, c), comparisonTerm) 109 | case f@Function(n, _) => Variable(Named(n, None, f), comparisonTerm) 110 | case c@Column(n, _) => Variable(Named(n, None, c), comparisonTerm) 111 | case c@AllColumns(_) => Variable(Named("*", None, c), comparisonTerm) 112 | case Subselect(s) => Variable(s.projection.head, comparisonTerm) 113 | case i@Input() => Variable(Named("_?", None, i), comparisonTerm) 114 | case _ => sys.error("Invalid term " + t) 115 | } 116 | 117 | def inputTerm(t: Term, comparisonTerm: Option[Term]): List[Variable] = t match { 118 | case f@Function(_, _) => input(f, comparisonTerm) 119 | case Subselect(s) => input(s) 120 | case Input() => nameVar(t, comparisonTerm) :: Nil 121 | case ArithExpr(Input(), _, t) => nameVar(t, comparisonTerm) :: inputTerm(t, comparisonTerm) 122 | case ArithExpr(t, _, Input()) => inputTerm(t, comparisonTerm) ::: List(nameVar(t, comparisonTerm)) 123 | case ArithExpr(lhs, _, rhs) => inputTerm(lhs, comparisonTerm) ::: inputTerm(rhs, comparisonTerm) 124 | case Case(conds, elze) => (conds flatMap { case (expr, result) => 125 | input(expr, comparisonTerm) ::: inputTerm(result, comparisonTerm) 126 | }) ::: (elze map (t => inputTerm(t, comparisonTerm)) getOrElse Nil) 127 | case _ => Nil 128 | } 129 | 130 | object Inputs { 131 | def unapply(t: Term) = t match { 132 | case TermList(ts) => Some(ts collect { case Input() => Input() }) 133 | case _ => None 134 | } 135 | } 136 | 137 | def input(e: Expr, ct: Option[Term]): List[Variable] = e match { 138 | case SimpleExpr(t) => inputTerm(t, ct) 139 | case c@Comparison1(t, _) => inputTerm(t, Some(c)) 140 | case c@Comparison2(Input(), op, t) => nameVar(t, 
Some(c)) :: inputTerm(t, Some(c)) 141 | case c@Comparison2(t, op, Input()) => inputTerm(t, Some(c)) ::: List(nameVar(t, Some(c))) 142 | case c@Comparison2(Inputs(is), op, t) => (is map (_ => nameVar(t, Some(c)))) ::: inputTerm(t, Some(c)) 143 | case c@Comparison2(t, op, Inputs(is)) => inputTerm(t, Some(c)) ::: (is map (_ => nameVar(t, Some(c)))) 144 | case c@Comparison2(t, op, Subselect(s)) => inputTerm(t, Some(c)) ::: input(s) 145 | case c@Comparison2(t1, op, t2) => inputTerm(t1, Some(c)) ::: inputTerm(t2, Some(c)) 146 | case c@Comparison3(t, op, Input(), Input()) => inputTerm(t, Some(c)) ::: (nameVar(t, Some(c)) :: nameVar(t, Some(c)) :: Nil) 147 | case c@Comparison3(t1, op, Input(), t2) => inputTerm(t1, Some(c)) ::: (nameVar(t1, Some(c)) :: inputTerm(t2, Some(c))) 148 | case c@Comparison3(t1, op, t2, Input()) => inputTerm(t1, Some(c)) ::: inputTerm(t2, Some(c)) ::: List(nameVar(t1, Some(c))) 149 | case c@Comparison3(t1, op, t2, t3) => inputTerm(t1, Some(c)) ::: inputTerm(t2, Some(c)) ::: inputTerm(t3, Some(c)) 150 | case And(e1, e2) => input(e1, ct) ::: input(e2, ct) 151 | case Or(e1, e2) => input(e1, ct) ::: input(e2, ct) 152 | case Not(e) => input(e, ct) 153 | case TypeExpr(d) => Nil 154 | } 155 | 156 | def limitInput(limit: Option[Limit]) = 157 | limit.map(l => l.count.right.toSeq.toList ::: l.offset.map(_.right.toSeq.toList).getOrElse(Nil)).getOrElse(Nil).map { _ => 158 | Variable(Named("", None, Constant[Table]((typeOf[Long], java.sql.Types.BIGINT), None))) 159 | } 160 | 161 | def output(stmt: Statement): List[Variable] = stmt match { 162 | case Delete(_, _) => Nil 163 | case Insert(_, _, _) => Nil 164 | case SetStatement(left, _, _, _, _) => output(left) 165 | case Composed(left, right) => output(left) ::: output(right) 166 | case Update(_, _, _, _, _) => Nil 167 | case Create() => Nil 168 | case Select(projection, _, _, _, _, _) => projection map (t => Variable(t)) 169 | } 170 | } 171 | 172 | class Typer(schema: Schema, stmt: Ast.Statement[Table], 
dbConfig: DbConfig) extends Ast.Resolved { 173 | private val schemaCache = new java.util.WeakHashMap[String, ?[Schema]]() 174 | def cachedSchema(schemaName: String) = { 175 | val cached = schemaCache.get(schemaName) 176 | if (cached != null) cached else { 177 | val s = DbSchema.read(dbConfig.copy(schema = Some(schemaName))) 178 | schemaCache.put(schemaName, s) 179 | s 180 | } 181 | } 182 | 183 | import java.sql.{Types => JdbcTypes} 184 | 185 | type SqlType = ((Type, Int), Boolean) 186 | type SqlFType = (List[SqlType], SqlType) 187 | 188 | def typeSpecifyTerm(v: Variable): Option[?[List[TypedValue]]] = None 189 | 190 | def infer(useInputTags: Boolean): ?[TypedStatement] = { 191 | def uniqueConstraints = { 192 | val constraints = stmt.tables map { t => 193 | (tableSchema(t) map { table => 194 | val indices = Option(table.getPrimaryKey).map(List(_)).getOrElse(Nil) ::: table.getIndices.toList 195 | val uniques = indices filter (_.isUnique) map { i => 196 | i.getColumns.toList.map(col => Column(col.getName, t)) 197 | } 198 | (t, uniques) 199 | }) fold (_ => (t, List[List[Column]]()), identity) 200 | } 201 | 202 | Map[Table, List[List[Column]]]().withDefault(_ => Nil) ++ constraints 203 | } 204 | 205 | def generatedKeyTypes(table: Table) = (for { 206 | t <- tableSchema(table) 207 | } yield { 208 | def tag(c: schemacrawler.schema.Column) = 209 | Option(t.getPrimaryKey).flatMap(_.getColumns.find(_.getName == c.getName)).map(_ => t.getName) 210 | 211 | t.getColumns.toList 212 | .filter(c => c.getType.isAutoIncrementable) 213 | .map(c => TypedValue(c.getName, mkType(c.getType), false, tag(c), Column(c.getName, table))) 214 | }) fold (_ => Nil, identity) 215 | 216 | val vars = new Variables(this) 217 | for { 218 | in <- sequence(vars.input(schema, stmt) map typeTerm(useTags = useInputTags)) 219 | out <- sequence(vars.output(stmt) map typeTerm(useTags = true)) 220 | } yield TypedStatement(in.flatten, out.flatten, stmt.isQuery, uniqueConstraints, 
generatedKeyTypes(stmt.tables.head)) 221 | } 222 | 223 | def tag(col: Column) = 224 | tableSchema(col.table) map { t => 225 | def findFK = t.getForeignKeys 226 | .flatMap(_.getColumnPairs.map(_.getForeignKeyColumn)) 227 | .find(_.getName == col.name) 228 | .map(_.getReferencedColumn.getParent.getName) 229 | 230 | if (t.getPrimaryKey != null && t.getPrimaryKey.getColumns.exists(_.getName == col.name)) 231 | Some(col.table.name) 232 | else findFK orElse None 233 | } fold (_ => None, identity) 234 | 235 | def typeTerm(useTags: Boolean)(v: Variable): ?[List[TypedValue]] = typeSpecifyTerm(v) getOrElse { 236 | val x = v.term 237 | x.term match { 238 | case col@Column(_, _) => 239 | for { 240 | (tpe, opt) <- inferColumnType(col) 241 | } yield List(TypedValue(x.aname, tpe, opt, if (useTags) tag(col) else None, x.term)) 242 | case AllColumns(t) => 243 | for { 244 | tbl <- tableSchema(t) 245 | cs <- sequence(tbl.getColumns.toList map { c => typeTerm(useTags)(Variable(Named(c.getName, None, Column(c.getName, t)))) }) 246 | } yield cs.flatten 247 | case f@Function(_, _) => 248 | inferReturnType(f, v.comparisonTerm) map { case (tpe, opt) => 249 | List(TypedValue(x.aname, tpe, opt, None, x.term)) 250 | } 251 | case Constant(tpe, _) => List(TypedValue(x.aname, tpe, false, None, x.term)).ok 252 | case Input() => 253 | List(TypedValue(x.aname, (typeOf[Any], JdbcTypes.JAVA_OBJECT), false, None, x.term)).ok 254 | case ArithExpr(_, "/", _) => 255 | List(TypedValue(x.aname, (typeOf[Double], JdbcTypes.DOUBLE), true, None, x.term)).ok 256 | case ArithExpr(lhs, _, rhs) => 257 | (lhs, rhs) match { 258 | case (c@Column(_, _), _) => typeTerm(useTags)(Variable(Named(c.name, x.alias, c), v.comparisonTerm)) 259 | case (_, c@Column(_, _)) => typeTerm(useTags)(Variable(Named(c.name, x.alias, c), v.comparisonTerm)) 260 | case _ => typeTerm(useTags)(Variable(Named(x.name, x.alias, lhs), v.comparisonTerm)) 261 | } 262 | case Comparison1(_, IsNull) | Comparison1(_, IsNotNull) => 263 | 
List(TypedValue(x.aname, (typeOf[Boolean], JdbcTypes.BOOLEAN), false, None, x.term)).ok 264 | case Comparison1(t, _) => 265 | List(TypedValue(x.aname, (typeOf[Boolean], JdbcTypes.BOOLEAN), isNullable(t), None, x.term)).ok 266 | case Comparison2(t1, _, t2) => 267 | List(TypedValue(x.aname, (typeOf[Boolean], JdbcTypes.BOOLEAN), isNullable(t1) || isNullable(t2), None, x.term)).ok 268 | case Comparison3(t1, _, t2, t3) => 269 | List(TypedValue(x.aname, (typeOf[Boolean], JdbcTypes.BOOLEAN), isNullable(t1) || isNullable(t2) || isNullable(t3), None, x.term)).ok 270 | case Subselect(s) => 271 | sequence(s.projection map (t => Variable(t)) map typeTerm(useTags)) map (_.flatten) map (_ map makeNullable) 272 | case TermList(t) => 273 | sequence(t.map(t => typeTerm(useTags)(Variable(Named("elem", None, t), v.comparisonTerm)))).map(_.flatten) 274 | case Case(conds, elze) => 275 | typeTerm(useTags)(Variable(Named("case", None, conds.head._2), v.comparisonTerm)) 276 | } 277 | } 278 | 279 | def makeNullable(x: TypedValue) = x.copy(nullable = true) 280 | 281 | def isNullable(t: Term) = tpeOf(SimpleExpr(t), None) map { case (_, opt) => opt } getOrElse false 282 | 283 | def isAggregate(fname: String): Boolean = aggregateFunctions.contains(fname.toLowerCase) 284 | 285 | val dsl = new TypeSigDSL(this) 286 | import dsl._ 287 | 288 | val aggregateFunctions = Map( 289 | "avg" -> (f(a) -> option(double)) 290 | , "count" -> (f(a) -> long) 291 | , "min" -> (f(a) -> a) 292 | , "max" -> (f(a) -> a) 293 | , "sum" -> (f(a) -> a) 294 | ) ++ extraAggregateFunctions 295 | 296 | val scalarFunctions = Map( 297 | "abs" -> (f(a) -> a) 298 | , "lower" -> (f(a) -> a) 299 | , "upper" -> (f(a) -> a) 300 | , "|" -> (f2(a, a) -> a) 301 | , "&" -> (f2(a, a) -> a) 302 | , "^" -> (f2(a, a) -> a) 303 | , ">>" -> (f2(a, a) -> a) 304 | , "<<" -> (f2(a, a) -> a) 305 | ) ++ extraScalarFunctions 306 | 307 | val knownFunctions = aggregateFunctions ++ scalarFunctions 308 | 309 | def extraAggregateFunctions: Map[String, 
(String, List[Expr], Option[Term]) => ?[SqlFType]] = Map() 310 | def extraScalarFunctions: Map[String, (String, List[Expr], Option[Term]) => ?[SqlFType]] = Map() 311 | 312 | def tpeOf(e: Expr, comparisonTerm: Option[Term]): ?[SqlType] = e match { 313 | case SimpleExpr(t) => t match { 314 | case Constant(tpe, x) if x == null => (tpe, true).ok 315 | case Constant(tpe, _) => (tpe, false).ok 316 | case col@Column(_, _) => inferColumnType(col) 317 | case f@Function(_, _) => inferReturnType(f, comparisonTerm) 318 | case Input() => (comparisonTerm map typeFromComparison) getOrElse ((typeOf[Any], JdbcTypes.JAVA_OBJECT), false).ok 319 | case TermList(terms) => tpeOf(SimpleExpr(terms.head), comparisonTerm) 320 | case ArithExpr(Input(), op, rhs) => tpeOf(SimpleExpr(rhs), comparisonTerm) 321 | case ArithExpr(lhs, op, rhs) => tpeOf(SimpleExpr(lhs), comparisonTerm) 322 | case x => ((typeOf[Any], JdbcTypes.JAVA_OBJECT), false).ok 323 | } 324 | 325 | case _ => ((typeOf[Boolean], JdbcTypes.BOOLEAN), false).ok 326 | } 327 | 328 | def typeFromComparison(term: Term) = term match { 329 | case Comparison2(t, _, _) => typeTerm(false)(Variable(Named("", None, t))) map (ts => (ts.head.tpe, ts.head.nullable)) 330 | case x => ((typeOf[Any], JdbcTypes.JAVA_OBJECT), false).ok 331 | } 332 | 333 | def inferReturnType(f: Function, comparisonTerm: Option[Term]) = 334 | knownFunctions.get(f.name.toLowerCase) match { 335 | case Some(func) => func(f.name, f.params, comparisonTerm).map(_._2) 336 | case None => ((typeOf[Any], JdbcTypes.JAVA_OBJECT), true).ok 337 | } 338 | 339 | def inferArguments(f: Function, comparisonTerm: Option[Term]) = 340 | knownFunctions.get(f.name.toLowerCase) match { 341 | case Some(func) => func(f.name, f.params, comparisonTerm).map(_._1) 342 | case None => f.params.map(_ => ((typeOf[Any], JdbcTypes.JAVA_OBJECT), true)).ok 343 | } 344 | 345 | def inferColumnType(col: Column) = (for { 346 | t <- tableSchema(col.table) 347 | c <- Option(t.getColumn(col.name)) orFail ("No such 
column " + col) 348 | } yield (mkType(c.getType), c.isNullable || isNullableByJoin(col) || isNullableByGroupBy(col))) fold ( 349 | _ => inferFromDerivedTable(col), x => x.ok 350 | ) 351 | 352 | def inferFromDerivedTable(col: Column) = for { 353 | t <- derivedTable(col.table) 354 | c <- t.output.find(_.name == col.name) orFail ("No such column " + col) 355 | } yield (c.tpe, c.nullable) 356 | 357 | def isNullableByJoin(col: Column) = isProjectedByJoin(stmt, col) map (_.joinDesc) exists { 358 | case LeftOuter | RightOuter | FullOuter => true 359 | case Inner | Cross => false 360 | } 361 | 362 | def isNullableByGroupBy(col: Column) = stmt match { 363 | case Select(_, _, _, Some(GroupBy(terms, true, _)), _, _) => terms.contains(col) 364 | case _ => false 365 | } 366 | 367 | private def tableSchema(tbl: Table) = 368 | tbl.schema.map { schemaName => 369 | val schema = cachedSchema(schemaName) 370 | schema.flatMap { schema => 371 | Option(schema.getTable(tbl.name)) orFail ("Unknown table " + schemaName + "." 
+ tbl.name) 372 | } 373 | } getOrElse { 374 | if (tbl.name.toLowerCase == "dual") DualTable(schema).ok 375 | else Option(schema.getTable(tbl.name)) orFail ("Unknown table " + tbl.name) 376 | } 377 | 378 | private def derivedTable(tbl: Table) = for { 379 | t <- DerivedTables(schema, stmt, tbl.name) orFail ("Unknown table XXX " + tbl.name) 380 | typed <- new Typer(schema, t, dbConfig).infer(false) 381 | } yield typed 382 | 383 | private def mkType(t: ColumnDataType) = (Jdbc.mkType(t.getTypeClassName), t.getType) 384 | } 385 | 386 | object DualTable { 387 | def apply(schema: Schema) = { 388 | val cstr = schema.getClass.getClassLoader.loadClass("schemacrawler.crawl.MutableTable") 389 | .getDeclaredConstructor(classOf[Schema], classOf[String]) 390 | cstr.setAccessible(true) 391 | cstr.newInstance(schema, "dual").asInstanceOf[schemacrawler.schema.Table] 392 | } 393 | } 394 | 395 | object DerivedTables extends Ast.Resolved { 396 | def apply(schema: Schema, stmt: Statement, name: String): Option[Select] = 397 | derivedTables(stmt) find (_.name == name) map (_.subselect) 398 | 399 | private def derivedTables(stmt: Statement): List[DerivedTable] = stmt match { 400 | case Select(proj, tableRefs, where, groupBy, orderBy, limit) => tableRefs flatMap referencedTables 401 | case _ => Nil 402 | } 403 | 404 | private def joinedTables(j: Join) = referencedTables(j.table) 405 | 406 | private def referencedTables(table: TableReference): List[DerivedTable] = table match { 407 | case ConcreteTable(_, join) => join flatMap joinedTables 408 | case t@DerivedTable(_, sub, join) => t :: derivedTables(sub) ::: (join flatMap joinedTables) 409 | } 410 | } 411 | -------------------------------------------------------------------------------- /core/src/main/scala/typesigdsl.scala: -------------------------------------------------------------------------------- 1 | package sqltyped 2 | 3 | import scala.reflect.runtime.universe.{Type, typeOf} 4 | import Ast.Resolved._ 5 | import java.sql.{Types => 
JdbcTypes} 6 | 7 | class TypeSigDSL(typer: Typer) { 8 | case class f[A: Typed](a: A) { 9 | def ->[R: Typed](r: R) = (fname: String, params: List[Expr], comparisonTerm: Option[Term]) => 10 | if (params.length != 1) fail("Expected 1 parameter " + params) 11 | else for { 12 | a1 <- implicitly[Typed[A]].tpe(fname, params(0), comparisonTerm) 13 | r <- implicitly[Typed[R]].tpe(fname, params(0), comparisonTerm) 14 | } yield (List(a1), r) 15 | } 16 | 17 | case class f2[A: Typed, B: Typed](a: A, b: B) { 18 | def ->[R: Typed](r: R) = (fname: String, params: List[Expr], comparisonTerm: Option[Term]) => 19 | if (params.length != 2) fail("Expected 2 parameters " + params) 20 | else for { 21 | a1 <- implicitly[Typed[A]].tpe(fname, params(0), comparisonTerm) 22 | a2 <- implicitly[Typed[B]].tpe(fname, params(1), comparisonTerm) 23 | r <- implicitly[Typed[R]].tpe(fname, if (a == r) params(0) else params(1), comparisonTerm) 24 | } yield (List(a1, a2), r) 25 | } 26 | 27 | trait Typed[A] { 28 | def tpe(fname: String, e: Expr, comparisonTerm: Option[Term]): ?[((Type, Int), Boolean)] 29 | } 30 | 31 | trait TypeParam 32 | object a extends TypeParam 33 | object b extends TypeParam 34 | object c extends TypeParam 35 | object d extends TypeParam 36 | 37 | object int 38 | object long 39 | object double 40 | object date 41 | case class option[A: Typed](a: A) 42 | 43 | implicit def optionTyped[A: Typed]: Typed[option[A]] = new Typed[option[A]] { 44 | def tpe(fname: String, e: Expr, ct: Option[Term]) = implicitly[Typed[A]].tpe(fname, e, ct) map { 45 | case (tpe, opt) => (tpe, true) 46 | } 47 | } 48 | 49 | implicit def typeParamTyped[A <: TypeParam]: Typed[A] = new Typed[A] { 50 | def tpe(fname: String, e: Expr, ct: Option[Term]) = typer.tpeOf(e, ct) map { 51 | case (tpe, opt) => (tpe, typer.isAggregate(fname) || opt) 52 | } 53 | } 54 | 55 | implicit def intTyped: Typed[int.type] = new Const[int.type]((typeOf[Int], JdbcTypes.INTEGER)) 56 | implicit def longTyped: Typed[long.type] = new 
Const[long.type]((typeOf[Long], JdbcTypes.BIGINT)) 57 | implicit def doubleTyped: Typed[double.type] = new Const[double.type]((typeOf[Double], JdbcTypes.DOUBLE)) 58 | implicit def dateTyped: Typed[date.type] = new Const[date.type]((typeOf[java.sql.Date], JdbcTypes.TIMESTAMP)) 59 | 60 | class Const[A](tpe: (Type, Int)) extends Typed[A] { 61 | def tpe(fname: String, e: Expr, comparisonTerm: Option[Term]) = (tpe, false).ok 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /core/src/main/scala/validator.scala: -------------------------------------------------------------------------------- 1 | package sqltyped 2 | 3 | import schemacrawler.schema.Schema 4 | import Ast._ 5 | 6 | trait Validator { 7 | def validate(db: DbConfig, sql: String): ?[Unit] 8 | } 9 | 10 | object NOPValidator extends Validator { 11 | def validate(db: DbConfig, sql: String) = ().ok 12 | } 13 | 14 | object JdbcValidator extends Validator { 15 | def validate(db: DbConfig, sql: String) = 16 | Jdbc.withConnection(db.getConnection) { conn => 17 | val stmt = conn.prepareStatement(sql) 18 | stmt.getMetaData // some JDBC drivers do round trip to DB here and validates the statement 19 | } 20 | } 21 | 22 | /** 23 | * For MySQL we use its internal API to get better validation. 
24 | */ 25 | object MySQLValidator extends Validator { 26 | def validate(db: DbConfig, sql: String) = try { 27 | Jdbc.withConnection(db.getConnection) { conn => 28 | val m = Class.forName("com.mysql.jdbc.ServerPreparedStatement").getDeclaredMethod( 29 | "getInstance", 30 | Class.forName("com.mysql.jdbc.MySQLConnection"), 31 | classOf[String], 32 | classOf[String], 33 | classOf[Int], 34 | classOf[Int]) 35 | m.setAccessible(true) 36 | m.invoke(null, conn, sql, "", 0: java.lang.Integer, 0: java.lang.Integer).ok 37 | } 38 | } catch { 39 | case e: Exception if e.getClass.getName.endsWith("MySQLSyntaxErrorException") => fail(e.getMessage) 40 | case e: Exception => JdbcValidator.validate(db, sql) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /core/src/test/resources/test-postgresql.sql: -------------------------------------------------------------------------------- 1 | CREATE SCHEMA sqltyped AUTHORIZATION sqltypedtest; 2 | SET search_path TO sqltyped; 3 | ALTER USER sqltypedtest SET search_path to sqltyped; 4 | 5 | create table person( 6 | id bigserial NOT NULL, 7 | name varchar(255) NOT NULL, 8 | age int NOT NULL, 9 | salary int NOT NULL, 10 | img bytea, 11 | PRIMARY KEY (id) 12 | ); 13 | 14 | create table job_history( 15 | person bigint references person(id) NOT NULL, 16 | name varchar(255) NOT NULL, 17 | started timestamp NOT NULL, 18 | resigned timestamp NULL 19 | ); 20 | 21 | create table jobs(person varchar(255) NOT NULL,job varchar(255) NOT NULL); 22 | 23 | insert into person values (1, 'joe', 36, 9500); 24 | insert into person values (2, 'moe', 14, 8000); 25 | 26 | insert into job_history values (1, 'Enron', '2002-08-02 08:00:00', '2004-06-22 18:00:00'); 27 | insert into job_history values (1, 'IBM', '2004-07-13 11:00:00', NULL); 28 | insert into job_history values (2, 'IBM', '2005-08-10 11:00:00', NULL); 29 | -------------------------------------------------------------------------------- 
/core/src/test/resources/test.sql: -------------------------------------------------------------------------------- 1 | create table person( 2 | id bigint(20) NOT NULL auto_increment, 3 | name varchar(255) NOT NULL, 4 | age INT NOT NULL, 5 | salary INT NOT NULL, 6 | img BLOB, 7 | PRIMARY KEY (id) 8 | ) ENGINE=InnoDB; 9 | 10 | create table job_history( 11 | person bigint(20) NOT NULL, 12 | name varchar(255) NOT NULL, 13 | started timestamp NOT NULL, 14 | resigned timestamp NULL, 15 | FOREIGN KEY person_id_fk (person) 16 | REFERENCES person (id) 17 | ON DELETE CASCADE 18 | ON UPDATE NO ACTION 19 | ) ENGINE=InnoDB; 20 | 21 | create table jobs(person varchar(255) NOT NULL,job varchar(255) NOT NULL) ENGINE=InnoDB; 22 | 23 | create table alltypes( 24 | a TINYINT NOT NULL, 25 | b SMALLINT NOT NULL, 26 | c MEDIUMINT NOT NULL, 27 | d INT NOT NULL, 28 | e BIGINT NOT NULL, 29 | f FLOAT(24) NOT NULL, 30 | g FLOAT(53) NOT NULL, 31 | h DOUBLE NOT NULL, 32 | i BIT(2) NOT NULL, 33 | j DATE NOT NULL, 34 | k TIME NOT NULL, 35 | l DATETIME NOT NULL, 36 | m TIMESTAMP NOT NULL, 37 | n YEAR NOT NULL, 38 | o CHAR(255) NOT NULL, 39 | p VARCHAR(255) NOT NULL, 40 | q TEXT NOT NULL, 41 | r ENUM('v1','v2') NOT NULL, 42 | s SET('v1','v2') NOT NULL, 43 | t DECIMAL NOT NULL 44 | ) ENGINE=InnoDB; 45 | 46 | insert into person values (1, 'joe', 36, 9500, NULL); 47 | insert into person values (2, 'moe', 14, 8000, NULL); 48 | 49 | insert into job_history values (1, 'Enron', '2002-08-02 08:00:00', '2004-06-22 18:00:00'); 50 | insert into job_history values (1, 'IBM', '2004-07-13 11:00:00', NULL); 51 | insert into job_history values (2, 'IBM', '2005-08-10 11:00:00', NULL); 52 | 53 | insert into alltypes values( 54 | 1, 55 | 1, 56 | 1, 57 | 1, 58 | 1, 59 | 1.0, 60 | 1.0, 61 | 1.0, 62 | 1, 63 | '2012-10-10', 64 | '14:00:00', 65 | '2012-10-10', 66 | '2012-10-10', 67 | 2012, 68 | 'a', 69 | 'a', 70 | 'a', 71 | 'v1', 72 | 'v1', 73 | 1.0 74 | ); 75 | 
-------------------------------------------------------------------------------- /core/src/test/scala/dynamicexamples.scala: -------------------------------------------------------------------------------- 1 | package sqltyped 2 | 3 | import org.scalatest._ 4 | 5 | class DynamicExamples extends MySQLConfig { 6 | test("Runtime query building") { 7 | val where = "age > ?" + " or " + "1 > 2" 8 | 9 | sql"select name from person where $where order by age".apply(Seq(5)) === 10 | List("moe", "joe") 11 | 12 | sql"select name from person where $where order by age".apply(Seq(25)) === 13 | List("joe") 14 | 15 | sql"select name, age from person where $where order by $orderBy".apply(Seq(5)).tuples === 16 | List(("moe", 14), ("joe", 36)) 17 | 18 | sql"select j.name, p.name from person p join job_history j on p.id=j.person where $where".apply(Seq(15)).tuples === 19 | List(("Enron", "joe"), ("IBM", "joe")) 20 | } 21 | 22 | test("Analysis is skipped since statement is not fully known") { 23 | sql"select name from person where id=?".apply(Seq(1)) === 24 | List("joe") 25 | } 26 | 27 | test("Simple expr library") { 28 | import ExprLib._ 29 | 30 | val p1 = pred("age > ?", 15) 31 | val p2 = pred("age < ?", 2) 32 | val p3 = pred("length(name) < ?", 6) 33 | 34 | val expr = (p1 or p2) and p3 35 | 36 | sql"select name from person where ${expr.sql}".apply(expr.args) === 37 | List("joe") 38 | } 39 | 40 | def orderBy = "age" 41 | 42 | object ExprLib { 43 | sealed trait Expr { 44 | def sql: String = this match { 45 | case Predicate(e, _) => e 46 | case And(l, r) => "(" + l.sql + " and " + r.sql + ")" 47 | case Or(l, r) => "(" + l.sql + " or " + r.sql + ")" 48 | } 49 | 50 | def args: Seq[Any] = this match { 51 | case Predicate(_, as) => as 52 | case And(l, r) => l.args ++ r.args 53 | case Or(l, r) => l.args ++ r.args 54 | } 55 | 56 | def and(other: Expr) = And(this, other) 57 | def or(other: Expr) = Or(this, other) 58 | } 59 | 60 | case class Predicate(sqlExpr: String, arguments: Seq[Any]) 
extends Expr 61 | case class And(l: Expr, r: Expr) extends Expr 62 | case class Or(l: Expr, r: Expr) extends Expr 63 | 64 | def pred(sql: String, args: Any*) = Predicate(sql, args) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /core/src/test/scala/examples.scala: -------------------------------------------------------------------------------- 1 | package sqltyped 2 | 3 | import java.sql._ 4 | import org.scalatest._ 5 | import shapeless._, tag.@@ 6 | 7 | trait Example extends FunSuite with BeforeAndAfterEach with matchers.ShouldMatchers { 8 | def beforeEachWithConfig[A](implicit conn: Connection) { 9 | val newPerson = sql("insert into person(id, name, age, salary) values (?, ?, ?, ?)") 10 | val jobHistory = sql("insert into job_history values (?, ?, ?, ?)") 11 | 12 | sql("delete from job_history").apply 13 | sql("delete from person").apply 14 | 15 | newPerson(1, "joe", 36, 9500) 16 | newPerson(2, "moe", 14, 8000) 17 | 18 | jobHistory(1, "Enron", tstamp("2002-08-02 08:00:00.0"), Some(tstamp("2004-06-22 18:00:00.0"))) 19 | jobHistory(1, "IBM", tstamp("2004-07-13 11:00:00.0"), None) 20 | jobHistory(2, "IBM", tstamp("2005-08-10 11:00:00.0"), None) 21 | } 22 | 23 | def tstamp(s: String) = 24 | new java.sql.Timestamp(new java.text.SimpleDateFormat("yyyy-MM-dd HH:mm:ss.S").parse(s).getTime) 25 | 26 | def datetime(s: String) = 27 | new java.sql.Timestamp(new java.text.SimpleDateFormat("yyyy-MM-dd HH:mm:ss.S").parse(s).getTime) 28 | 29 | def date(s: String) = 30 | new java.sql.Date(new java.text.SimpleDateFormat("yyyy-MM-dd").parse(s).getTime) 31 | 32 | def time(s: String) = 33 | new java.sql.Time(new java.text.SimpleDateFormat("HH:mm:ss.S").parse(s).getTime) 34 | 35 | def year(y: Int) = tstamp(y + "-01-01 00:00:00.0") 36 | 37 | implicit class TypeSafeEquals[A](a: A) { 38 | def ===(other: A) = a should equal(other) 39 | } 40 | } 41 | 42 | trait PostgreSQLConfig extends Example { 43 | Class.forName("org.postgresql.Driver") 
44 | 45 | implicit object postgresql extends ConfigurationName 46 | implicit val conn = DriverManager.getConnection("jdbc:postgresql://localhost/sqltyped", "sqltypedtest", "secret") 47 | 48 | override def beforeEach() = beforeEachWithConfig 49 | } 50 | 51 | trait MySQLConfig extends Example { 52 | Class.forName("com.mysql.jdbc.Driver") 53 | 54 | implicit val enableTagging = EnableTagging 55 | implicit val conn = DriverManager.getConnection("jdbc:mysql://localhost:3306/sqltyped", "root", "") 56 | 57 | override def beforeEach() = beforeEachWithConfig 58 | } 59 | 60 | class ExampleSuite extends MySQLConfig { 61 | test("Simple query") { 62 | val q1 = sql("select name, age from person") 63 | q1().map(_.get("age")).sum === 50 64 | 65 | val q2 = sql("select * from person") 66 | q2().map(_.get("age")).sum === 50 67 | 68 | sql("select p.* from person p").apply.map(_.get("age")).sum === 50 69 | 70 | sql("select (name) n, (age) as a from person").apply.tuples === 71 | List(("joe", 36), ("moe", 14)) 72 | 73 | sql("select 'success' as status from person").apply === 74 | List("success", "success") 75 | } 76 | 77 | test("Query with input") { 78 | val q = sql("select name, age from person where age > ? order by name") 79 | q(30).map(_.get("name")) === List("joe") 80 | q(10).map(_.get("name")) === List("joe", "moe") 81 | 82 | val q2 = sql("select name, age from person where age > ? and name != ? order by name") 83 | q2(10, "joe").map(_.get("name")) === List("moe") 84 | 85 | // FIXME: see https://github.com/milessabin/shapeless/issues/44 86 | // sql("select name, ? 
from person").apply("x").tuples === 87 | // List(("joe", "x"), ("moe", "x")) 88 | } 89 | 90 | test("Joins") { 91 | sql("SELECT distinct p.name FROM person p JOIN job_history j ON p.id=j.person").apply === 92 | List("joe", "moe") 93 | 94 | sql("SELECT distinct p.name FROM (person p) INNER JOIN job_history j").apply === 95 | List("joe", "moe") 96 | 97 | sql("SELECT distinct p.name FROM person p CROSS JOIN job_history j").apply === 98 | List("joe", "moe") 99 | 100 | sql("SELECT p.name FROM person p JOIN (job_history j) ON (p.id=j.person and j.name=?)").apply("IBM") === 101 | List("joe", "moe") 102 | 103 | sql("select p.name from person p join person p2 using (id)").apply === List("joe", "moe") 104 | 105 | val qNullable = sql(""" 106 | select j.resigned from person p left join job_history j on p.id=j.person 107 | where p.name=? and j.name=? LIMIT 1 108 | """) 109 | 110 | qNullable.apply("unknown", Some("IBM")) === None 111 | qNullable.apply("joe", Some("IBM")) === Some(None) 112 | qNullable.apply("joe", Some("Enron")) === Some(Some(tstamp("2004-06-22 18:00:00.0"))) 113 | 114 | val qNonNullable = sql(""" 115 | select j.started from person p left join job_history j on p.id=j.person 116 | where p.name=? and j.name=? LIMIT 1 117 | """) 118 | 119 | qNonNullable.apply("unknown", Some("IBM")) === None 120 | qNonNullable.apply("joe", Some("unknown")) === None 121 | qNonNullable.apply("joe", Some("Enron")) === Some(Some(tstamp("2002-08-02 08:00:00.0"))) 122 | 123 | sql(""" 124 | SELECT p.name 125 | FROM person p 126 | JOIN ( 127 | SELECT * FROM person WHERE age>? 128 | ) AS p2 ON p.id=p2.id 129 | """).apply(10) === List("joe", "moe") 130 | } 131 | 132 | test("Query with join and column alias") { 133 | val q = sql("select p.name, j.name as employer, p.age from person p join job_history j on p.id=j.person where id=? 
order by employer") 134 | 135 | q(1).values === List("joe" :: "Enron" :: 36 :: HNil, "joe" :: "IBM" :: 36 :: HNil) 136 | q(1).tuples === List(("joe", "Enron", 36), ("joe", "IBM", 36)) 137 | } 138 | 139 | test("Query with optional column") { 140 | val q = sql("select p.name, j.name as employer, j.started, j.resigned from person p join job_history j on p.id=j.person order by j.started") 141 | 142 | q().tuples === List( 143 | ("joe", "Enron", tstamp("2002-08-02 08:00:00.0"), Some(tstamp("2004-06-22 18:00:00.0"))), 144 | ("joe", "IBM", tstamp("2004-07-13 11:00:00.0"), None), 145 | ("moe", "IBM", tstamp("2005-08-10 11:00:00.0"), None)) 146 | } 147 | 148 | test("Group by and order by") { 149 | sql("select p.name from person p where age > ? group by p.id, p.age order by p.name").apply(1) === 150 | List("joe", "moe") 151 | 152 | sql("select p.name from person p where age > ? order by abs(salary - age)").apply(1) === 153 | List("moe", "joe") 154 | 155 | sql("select p.name from person p where age > ? order by ? desc").apply(5, 1) === 156 | List("moe", "joe") 157 | 158 | sql("select p.name from person p where age > ? order by abs(salary - ?)").apply(1, 500) === 159 | List("moe", "joe") 160 | 161 | sql("select p.name, p.age from person p where age > ? order by 2 desc").apply(1).tuples === 162 | List(("joe", 36), ("moe", 14)) 163 | 164 | sql("select p.name, p.age from person p where age > ? order by (2) desc, (p.name)").apply(1).tuples === 165 | List(("joe", 36), ("moe", 14)) 166 | 167 | sql("select p.name from person p where age > ? group by p.name collate latin1_swedish_ci order by p.name").apply(10) === 168 | List("joe", "moe") 169 | 170 | sql("select p.name from person p where age > ? group by p.name collate latin1_swedish_ci order by p.name collate latin1_swedish_ci asc limit 2").apply(10) === 171 | List("joe", "moe") 172 | 173 | sql("select p.name from person p where age > ? 
group by ucase(p.name)").apply(10) === 174 | List("joe", "moe") 175 | } 176 | 177 | test("Query with functions") { 178 | val q = sql("select avg(age), sum(salary) as salary, count(1) from person where abs(age) > ?") 179 | val res = q(10) 180 | res.get("avg") === Some(25.0) 181 | res.get("salary") === Some(17500) 182 | res.get("count") === 2 183 | 184 | val q2 = sql("select min(name) as name, max(age) as age from person where age > ?") 185 | val res2 = q2(10) 186 | res2.get("name") === Some("joe") 187 | res2.get("age") === Some(36) 188 | 189 | val res3 = q2(100) 190 | res3.get("name") === None 191 | res3.get("age") === None 192 | 193 | sql("select min(?) from person").apply(10) should equal(Some(10)) 194 | 195 | sql("select max(age) from person").apply === Some(36) 196 | 197 | sql("select max(age) + 1 from person").apply === Some(37) 198 | 199 | sql("select count(id) from person").apply === 2 200 | 201 | sql("select max(id) from person where age > ?").apply(100) === None 202 | 203 | sql("select age > 20 from person order by age").apply === 204 | List(false, true) 205 | 206 | sql("select resigned is not null from job_history order by started").apply === 207 | List(true, false, false) 208 | 209 | sql("select age in (1,36) from person order by age desc").apply === 210 | List(true, false) 211 | 212 | sql("select resigned < now() from job_history order by started").apply === 213 | List(Some(true), None, None) 214 | 215 | sql("select age from person where age|10=46").apply === 216 | List(36) 217 | 218 | sql("select age from person where age|?=?").apply(10, 46) === 219 | List(36) 220 | 221 | sql("select age|? from person where age&?=0").apply(10, 2) === 222 | List(46) 223 | 224 | sql("select ? > 2 from person").apply(3) === 225 | List(true, true) 226 | 227 | sql("select age > ? from person order by age").apply(18) === 228 | List(false, true) 229 | 230 | sql("select count(age>?) 
as a2 from person").apply(30) === 2L 231 | 232 | sql("select age/(age*10) from person").apply === List(Some(0.1), Some(0.1)) 233 | sql("select age/0 from person").apply === List(None, None) 234 | 235 | sql("select count(distinct name) from person").apply === 2 236 | } 237 | 238 | test("Query with just one selected column") { 239 | sql("select name from person where age > ? order by name").apply(10) === 240 | List("joe", "moe") 241 | 242 | sql("select name from person where age > ? order by name for update").apply(10) === 243 | List("joe", "moe") 244 | 245 | sql("select name from person where name LIKE ? order by name").apply("j%") === 246 | List("joe") 247 | 248 | sql("select name from person where not(age > ? and name=?)").apply(10, "joe") === 249 | List("moe") 250 | } 251 | 252 | test("Query with constraint by unique column") { 253 | val q = sql("select age, name from person where id=?") 254 | q(1) === Some(("age" ->> 36) :: ("name" ->> "joe") :: HNil) 255 | q(1).tuples === Some((36, "joe")) 256 | 257 | val q2 = sql("select name from person where id=?") 258 | q2(1) === Some("joe") 259 | 260 | val q3 = sql("select name from person where id=? and age>?") 261 | q3(1, 10) === Some("joe") 262 | 263 | val q4 = sql("select name from person where id=? or age>?") 264 | q4(1, 10) === List("joe", "moe") 265 | 266 | val q5 = sql("select age from person order by age desc limit 1") 267 | q5() === Some(36) 268 | 269 | sql("select count(1) from person where id=?").apply(1) === 1 270 | sql("select count(1) from person where id=?").apply(999) === 0 271 | sql("select count(1) > 0 from person where id=?").apply(1) === true 272 | } 273 | 274 | test("Query with is not null") { 275 | val q1 = 276 | sql(""" 277 | SELECT j.resigned FROM person p LEFT JOIN job_history j ON p.id=j.person 278 | WHERE p.name=? 
AND j.resigned IS NOT NULL 279 | """) 280 | q1.apply("joe") === List(tstamp("2004-06-22 18:00:00.0")) 281 | 282 | val q2 = 283 | sql(""" 284 | SELECT j.resigned FROM person p LEFT JOIN job_history j ON p.id=j.person 285 | WHERE (p.name=? AND j.resigned IS NOT NULL) OR p.age>? 286 | """) 287 | q2.apply("joe", 100) === List(Some(tstamp("2004-06-22 18:00:00.0"))) 288 | } 289 | 290 | test("Query with limit") { 291 | sql("select age from person order by age limit ?").apply(2) === List(14, 36) 292 | sql("select age from person order by age desc, name asc limit ?").apply(2) === List(36, 14) 293 | sql("select age from person order by age limit ?").apply(1) === List(14) 294 | sql("select age from person order by age limit ? offset 1").apply(1) === List(36) 295 | sql("select age from person order by age limit ? offset ?").apply(1, 1) === List(36) 296 | 297 | val q = sql("select age from person where name between ? and ? limit ?") 298 | q("i", "k", 2) === List(36) 299 | 300 | sql("select age from person where name not between ? 
and ?").apply("l", "z") === 301 | List(36) 302 | } 303 | 304 | test("Tagging") { 305 | val person = Witness("person") 306 | def findName(id: Long @@ person.T) = sql("select name from person where id=?").apply(id) 307 | 308 | val names = sql("select distinct person from job_history").apply map findName 309 | names === List(Some("joe"), Some("moe")) 310 | 311 | sqlt("select name,age from person where id=?").apply(tag[person.T](1)).tuples === Some("joe", 36) 312 | } 313 | 314 | test("Subselects") { 315 | sql("select distinct name from person where id = (select person from job_history limit 1)").apply === 316 | Some("joe") 317 | 318 | sql("select distinct name from person where id in (select person from job_history)").apply === 319 | List("joe", "moe") 320 | 321 | sql("select distinct name from person where id not in (select person from job_history)").apply === 322 | Nil 323 | 324 | sql("""select distinct name from person where id in 325 | (select person from job_history where started > ?)""").apply(year(2003)) === 326 | List("joe", "moe") 327 | 328 | sql("""select name from person p where exists 329 | (select person from job_history j where resigned is not null and p.id=j.person)""").apply === 330 | List("joe") 331 | 332 | sql("""select name from person p where not exists 333 | (select person from job_history j where resigned is not null and p.id=j.person)""").apply === 334 | List("moe") 335 | 336 | sql(""" 337 | select p.name from person p where p.age <= ? and ? 
< 338 | (select j.started from job_history j where p.id=j.person limit 1) order by p.name 339 | """).apply(50, tstamp("2002-08-02 08:00:00.0")) === 340 | List("moe") 341 | } 342 | 343 | test("Insert, delete") { 344 | sql("delete from jobs").apply 345 | sql("insert into jobs(person, job) select p.name, j.name from person p, job_history j where p.id=j.person and p.age>?").apply(30) 346 | sql("select person, job from jobs").apply.tuples === 347 | List(("joe", "Enron"), ("joe", "IBM")) 348 | 349 | sql("delete from jobs where job=?").apply("Enron") 350 | sql("select person, job from jobs").apply.tuples === List(("joe", "IBM")) 351 | 352 | sql("delete p from person p where p.age < ?").apply(40) 353 | sql("select name from person").apply === Nil 354 | 355 | sql(""" 356 | insert into person(salary, name, age) 357 | values ((select count(1) from job_history where name = ? LIMIT 1), ?, ?) 358 | """).apply("IBM", "foo", 50) 359 | } 360 | 361 | test("Multidelete") { 362 | sql("delete p, j from person p, job_history j where p.id=j.person and p.id=?").apply(1) 363 | sql("select name from person").apply === List("moe") 364 | } 365 | 366 | test("Get generated keys (currently supports just one generated key per row)") { 367 | val newPerson = sqlk("insert into person(name, age, salary) values (?, ?, ?)") 368 | val newPersons = sqlk("insert into person(name, age, salary) select name, age, salary from person") 369 | 370 | val key = newPerson.apply("bill", 45, 3000) 371 | val maxId = sql("select id from person where name=?").apply("bill").head 372 | assert(key.equals(maxId)) 373 | sqlt("delete from person where id=?").apply(key) 374 | 375 | newPersons.apply should equal(List(maxId + 1, maxId + 2)) 376 | } 377 | 378 | test("Update") { 379 | sql("update person set name=? where age >= ?").apply("joe2", 30) 380 | sql("select name from person order by age").apply === List("moe", "joe2") 381 | 382 | sql("update person set name=? where age >= ? 
order by age asc limit 1").apply("moe2", 10) 383 | sql("select name from person order by age").apply === List("moe2", "joe2") 384 | 385 | sql("update person set name=upper(name)").apply 386 | sql("select name from person order by age").apply === List("MOE2", "JOE2") 387 | 388 | sql("update person p, job_history j set p.name=?, j.name=? where p.id=j.person and p.age > ?").apply("joe2", "x", 30) 389 | sql("select p.name, j.name from person p, job_history j where p.id=j.person order by age").apply.tuples === 390 | List(("MOE2", "IBM"), ("joe2", "x"), ("joe2", "x")) 391 | 392 | sql("UPDATE person SET age=age&? WHERE name=?").apply(0, "joe2") 393 | sql("SELECT age FROM person WHERE name=?").apply("joe2").head === 0 394 | 395 | sql("update alltypes set i=not(i) where a > 100").apply 396 | } 397 | 398 | test("Blob") { 399 | import javax.sql.rowset.serial.SerialBlob 400 | 401 | val img = new SerialBlob("fake img".getBytes("UTF-8")) 402 | sql("update person set img=? where name=?").apply(Some(img), "joe") 403 | val savedImg = sql("select img from person where name=?").apply("joe").head.get.getBytes(1, 8) 404 | new String(savedImg, "UTF-8") === "fake img" 405 | } 406 | 407 | test("Union") { 408 | sql(""" 409 | select name,age from person where age < ? 410 | union 411 | select name,age from person where age > ? 412 | """).apply(15, 20).tuples === List(("moe", 14), ("joe", 36)) 413 | 414 | sql(""" 415 | (select name from person where age < ?) 416 | union all 417 | (select name from person where age > ?) 418 | order by name desc limit ? 
419 | """).apply(15, 20, 5) === List("moe", "joe") 420 | } 421 | 422 | test("Arithmetic") { 423 | sql("select SUM(age + age) from person where age<>?").apply(40) === 424 | Some(100) 425 | 426 | sql("update person set age = age + 1").apply 427 | sql("select age - 1, age, age * 2, (age % 10) - 1, -10 from person order by age").apply.tuples === 428 | List((14, 15, 30, 4, -10), (36, 37, 74, 6, -10)) 429 | } 430 | 431 | test("Quoting") { 432 | val rows = sql(""" select `in`.name as "select", age > 20 as "type" from person `in` order by age """).apply 433 | 434 | rows.head.get("select") === "moe" 435 | rows.head.get("type") === false 436 | } 437 | 438 | test("DUAL table") { 439 | sql("select count(1) from dual").apply === 1 440 | } 441 | 442 | test("Derived tables") { 443 | sql("select id from (select id, name from person) AS data where data.name = ?").apply("joe") === 444 | List(1) 445 | 446 | sql("select id from (select id, name from person where age>?) data where data.name = ?").apply(20, "joe") === 447 | List(1) 448 | 449 | sql(""" 450 | SELECT data.id, j.name 451 | FROM (select id, name from person) AS data join job_history j on (data.id = j.person) 452 | WHERE data.name = ? 453 | """).apply("joe").tuples === List((1, "Enron"), (1, "IBM")) 454 | 455 | sql(""" 456 | SELECT data.id, j.name 457 | FROM 458 | (select id, name from person where age>20) AS data join 459 | (select person,name from job_history) AS j on (data.id=j.person) 460 | WHERE data.name = ? 461 | """).apply("joe").tuples === List((1, "Enron"), (1, "IBM")) 462 | 463 | sql(""" 464 | SELECT p.name, j.c AS cnt 465 | FROM person p 466 | LEFT JOIN ( 467 | SELECT person, count(1) AS c FROM job_history 468 | GROUP BY person 469 | ) AS j ON p.id=j.person 470 | WHERE p.name = ? 
471 | """).apply("joe").tuples === List(("joe", 2)) 472 | 473 | sql(""" 474 | SELECT p.name, coalesce(j.c, 0) AS cnt 475 | FROM person p 476 | LEFT JOIN ( 477 | SELECT person, count(1) AS c FROM job_history 478 | GROUP BY person 479 | ) AS j ON p.id=j.person 480 | WHERE p.name = ? 481 | """).apply("joe").tuples === List(("joe", 2)) 482 | } 483 | 484 | test("Subselect in projection") { 485 | sql(""" 486 | select p.name, (select name from job_history where p.age > ? and person=p.id limit 1) 487 | from person p order by p.name 488 | """).apply(20).tuples === List(("joe", Some("Enron")), ("moe", None)) 489 | } 490 | 491 | test("Case expr") { 492 | sql(""" 493 | SELECT name, CASE 494 | WHEN salary > 7000 THEN 'rich' 495 | WHEN salary > 2000 THEN 'proletarian' 496 | ELSE 'unemployed' 497 | END AS x 498 | FROM person p 499 | """).apply.tuples === List(("joe", "rich"), ("moe", "rich")) 500 | 501 | sql(""" 502 | SELECT name, CASE 503 | WHEN salary > ? THEN 'rich' 504 | WHEN salary > ? THEN 'proletarian' 505 | ELSE 'unemployed' 506 | END AS x 507 | FROM person p 508 | """).apply(7000, 2000).tuples === List(("joe", "rich"), ("moe", "rich")) 509 | 510 | sql(""" 511 | SELECT count(1) 512 | FROM job_history 513 | GROUP BY CASE 514 | WHEN ? = 'person' THEN person 515 | ELSE name 516 | END 517 | """).apply("person") === List(2, 1) 518 | 519 | sql(""" 520 | SELECT distinct name 521 | FROM job_history 522 | ORDER BY CASE 523 | WHEN ? 
= 'name' THEN name 524 | ELSE person 525 | END 526 | """).apply("name") === List("Enron", "IBM") 527 | } 528 | 529 | test("with rollup") { 530 | sql(""" 531 | SELECT j.name, age, sum(salary) 532 | FROM person p 533 | JOIN job_history j 534 | ON p.id = j.person 535 | GROUP BY j.name, age 536 | WITH ROLLUP 537 | """).apply.tuples === 538 | (Option("Enron"), Option(36), Option(9500)) :: 539 | (Option("Enron"), None: Option[Int], Option(9500)) :: 540 | (Option("IBM"), Option(14), Option(8000)) :: 541 | (Option("IBM"), Option(36), Option(9500)) :: 542 | (Option("IBM"), None: Option[Int], Option(17500)) :: 543 | (None: Option[String], None: Option[Int], Option(27000)) :: Nil 544 | } 545 | 546 | test("Fallback by using unsupported MySQL syntax") { 547 | sql("select 1").apply === List(1) 548 | } 549 | 550 | test("JDBC based inference sqlj") { 551 | sqlj("select name from person where id=?").apply(1) === List("joe") 552 | } 553 | } 554 | -------------------------------------------------------------------------------- /core/src/test/scala/failures.scala: -------------------------------------------------------------------------------- 1 | package sqltyped 2 | 3 | import java.sql._ 4 | import org.scalatest._ 5 | import shapeless.test._ 6 | 7 | class FailureSuite extends Example { 8 | test("ORDER BY references unknown column") { 9 | illTyped(""" 10 | sql("select name, age from person order by unknown_column").apply 11 | """) 12 | } 13 | 14 | test("INSERT references unknown table") { 15 | illTyped(""" 16 | sql("insert into peson(id, name, age, salary) values (?, ?, ?, ?)").apply(1, "joe", 10, 0) 17 | """) 18 | } 19 | 20 | test("INSERT has unmatching number of listed columns and input") { 21 | illTyped(""" 22 | sql("insert into person(id, name, age) values (?, ?)").apply(1, "joe") 23 | """) 24 | 25 | illTyped(""" 26 | sql("insert into person(id, name, age) values (?, ?, 10, ?)").apply(1, "joe", 10) 27 | """) 28 | } 29 | } 30 | 
-------------------------------------------------------------------------------- /core/src/test/scala/mysqlexamples.scala: -------------------------------------------------------------------------------- 1 | package sqltyped 2 | 3 | import java.sql._ 4 | import org.scalatest._ 5 | import shapeless._ 6 | 7 | class MySQLExamples extends MySQLConfig { 8 | test("Interval") { 9 | sql("select started + interval 1 month from job_history order by started").apply should 10 | equal(List(tstamp("2002-09-02 08:00:00.0"), tstamp("2004-08-13 11:00:00.0"), tstamp("2005-09-10 11:00:00.0"))) 11 | } 12 | 13 | test("Functions") { 14 | val d = sql("select datediff(resigned, '2010-10-10') from job_history where resigned IS NOT NULL").apply.head 15 | (d map math.abs) === Some(2301) 16 | 17 | val resignedQ = sql("select name from job_history where datediff(resigned, ?) < ?") 18 | resignedQ.apply(tstamp("2004-08-13 11:00:00.0"), Some(60)) === List("Enron") 19 | 20 | sql("select coalesce(resigned, '1990-01-01 12:00:00') from job_history order by resigned").apply === 21 | List(tstamp("1990-01-01 12:00:00.0"), tstamp("1990-01-01 12:00:00.0"), tstamp("2004-06-22 18:00:00.0")) 22 | 23 | sql("select coalesce(resigned, NULL) from job_history order by resigned").apply === 24 | List(None, None, Some(tstamp("2004-06-22 18:00:00.0"))) 25 | 26 | sql("select ifnull(resigned, resigned) from job_history order by resigned").apply === 27 | List(None, None, Some(tstamp("2004-06-22 18:00:00.0"))) 28 | 29 | sql("select ifnull(resigned, started) from job_history order by resigned").apply === 30 | List(tstamp("2004-07-13 11:00:00.0"), tstamp("2005-08-10 11:00:00.0"), tstamp("2004-06-22 18:00:00.0")) 31 | 32 | sql("select coalesce(resigned, ?) from job_history order by resigned").apply(tstamp("1990-01-01 12:00:00.0")) === 33 | List(tstamp("1990-01-01 12:00:00.0"), tstamp("1990-01-01 12:00:00.0"), tstamp("2004-06-22 18:00:00.0")) 34 | 35 | sql("select IF(age<18 or age>100, 18, age) from person where age > ? 
order by age").apply(5) === 36 | List(18, 36) 37 | } 38 | 39 | test("String functions") { 40 | sql("select concat('hello ', name, ?) from person").apply("!") === 41 | List("hello joe!", "hello moe!") 42 | } 43 | 44 | test("Insert/update ignore") { 45 | val addPerson = sql("insert ignore into person(id, name, age, salary) values (?, ?, ?, ?)") 46 | val updateId = sql("update ignore person set id=? where id=?") 47 | 48 | addPerson(1, "tom", 40, 1000) === 0 49 | } 50 | 51 | test("ON DUPLICATE KEY") { 52 | val addOrUpdate = sql(""" 53 | insert into person(id, name, age, salary) values (?, ?, ?, ?) 54 | on duplicate key update name=?, age=age+1, salary=? 55 | """) 56 | 57 | addOrUpdate(1, "tom", 40, 1000, "tommy", 2000) 58 | sql("select name, age, salary from person where id=1").apply.tuples === 59 | Some(("tommy", 37, 2000)) 60 | } 61 | 62 | test("Types") { 63 | val q = sql("select * from alltypes LIMIT 1").apply.tuples 64 | q === Some((1, 1, 1, 1, 1, 1.0f, 1.0, 1.0, true, 65 | date("2012-10-10"), 66 | time("14:00:00.0"), 67 | datetime("2012-10-10 00:00:00.0"), 68 | tstamp("2012-10-10 00:00:00.0"), 69 | date("2012-01-01"), 70 | "a", "a", "a", "v1", "v1", BigDecimal(1.0))) 71 | 72 | sql("update alltypes set t=? 
where a>100").apply(BigDecimal(1.0)) 73 | } 74 | 75 | test("Cast functions") { 76 | sql("select binary(name) from person").apply === List("joe", "moe") 77 | 78 | sql("select binary(age) from person").apply === List(36, 14) 79 | 80 | val res1: List[String] = sql("select convert(age, char(10)) from person").apply 81 | res1 === List("36", "14") 82 | 83 | val res2: List[Int] = sql("select convert(name, signed) from person").apply 84 | res2 === List(0, 0) 85 | 86 | sql("select convert(name using utf8) as name from person").apply === List(Option("joe"), Option("moe")) 87 | } 88 | 89 | test("IN operator") { 90 | sql("select age from person where name in (?)").apply(List("joe", "moe")) == List(36, 14) 91 | 92 | sql("select age from person where name in (?, ?)").apply("joe", "moe") == List(36, 14) 93 | 94 | sql("select age from person where name in (?, 'moe')").apply("joe") == List(36, 14) 95 | 96 | sql("select age from person where name not in (?)").apply(List("joe", "zoe")) == List(14) 97 | 98 | sql("select name from person where age in (?)").apply(List(1, 36)) == List("joe") 99 | } 100 | } 101 | 102 | -------------------------------------------------------------------------------- /core/src/test/scala/postgreexamples.scala: -------------------------------------------------------------------------------- 1 | package sqltyped 2 | 3 | import java.sql._ 4 | import org.scalatest._ 5 | import shapeless._ 6 | 7 | class PostgreSQLExamples extends PostgreSQLConfig { 8 | test("Simple query") { 9 | sql("select name from person").apply === List("joe", "moe") 10 | 11 | sql("select sum(age) from person").apply === Some(50) 12 | } 13 | 14 | test("any, some and all") { 15 | sql("select age from person where name = any(?)").apply(Seq("joe", "moe")) === List(36, 14) 16 | sql("select age from person where name = some(?)").apply(Seq("joe", "moe")) === List(36, 14) 17 | sql("select age from person where name = all(?)").apply(Seq("joe")) === List(36) 18 | 19 | sql("select name from person 
where age = any(?)").apply(Seq(1, 36)) === List("joe") 20 | sql("select name from person where age+1 = any(?)").apply(Seq(1, 37)) === List("joe") 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /core/src/test/scala/recordexamples.scala: -------------------------------------------------------------------------------- 1 | package sqltyped 2 | 3 | import java.sql._ 4 | import org.scalatest._ 5 | import shapeless._ 6 | 7 | class RecordExampleSuite extends MySQLConfig { 8 | test("Query to CSV") { 9 | val rows = sql("select name as fname, age, img from person limit 100").apply.values 10 | 11 | CSV.fromList(rows) === """ "joe","36","" 12 | |"moe","14","" """.stripMargin.trim 13 | } 14 | 15 | test("Query to untyped tuples") { 16 | val rows = sql("select name, age from person limit 100").apply 17 | Record.toTupleList(rows.head) === List(("name", "joe"), ("age", 36)) 18 | Record.toTupleLists(rows) === 19 | List(List(("name", "joe"), ("age", 36)), List(("name", "moe"), ("age", 14))) 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | Demo app 2 | ======== 3 | 4 | Demo app is a small REST server. The stack is: 5 | 6 | * Unfiltered (REST API) 7 | * json4s (JSON rendering) 8 | * sqlτyped (Database access) 9 | * Shapeless (Records) 10 | * Slick (Database connection handling) 11 | * MySQL (Database) 12 | 13 | Start 14 | ----- 15 | 16 | Start MySQL database. 
17 | 18 | Then in the directory 'demo': 19 | 20 | ``` 21 | mysql -u root -e 'create database sqltyped_demo' 22 | mysql -u root sqltyped_demo < src/main/resources/schema.sql 23 | sbt run 24 | ``` 25 | 26 | * List persons 27 | 28 | ```curl http://localhost:8080/people``` 29 | 30 | * View person details 31 | 32 | ```curl http://localhost:8080/person/3``` 33 | 34 | * Add a comment 35 | 36 | ```curl -X PUT http://localhost:8080/person/3/comment?text=Hello``` 37 | 38 | Breakdown 39 | --------- 40 | 41 | ### [schema.sql](https://github.com/jonifreeman/sqltyped/blob/master/demo/src/main/resources/schema.sql) ### 42 | 43 | That's where the demo schema is defined. sqlτyped compiler will get an access to those definitions at compile time, which it then uses to infer Scala types of the SQL statements. 44 | 45 | ![Schema](http://yuml.me/d0e5d450) 46 | 47 | 48 | ### [package.scala](https://github.com/jonifreeman/sqltyped/blob/master/demo/src/main/scala/package.scala) ### 49 | 50 | Database connection and sqlτyped is configured at package object. Fairly trivial stuff. 51 | 52 | ### [testdata.scala](https://github.com/jonifreeman/sqltyped/blob/master/demo/src/main/scala/testdata.scala) ### 53 | 54 | Initial testdata is created by executing SQL statements with function ```sqlk```. The SQL string literal is converted into a function by sqlτyped compiler. Why ```sqlk``` and not just ```sql```? Well, ```sqlk``` is a specialized version of function ```sql``` which returns generated keys instead of updated rows. 55 | 56 | ### [server.scala](https://github.com/jonifreeman/sqltyped/blob/master/demo/src/main/scala/server.scala) ### 57 | 58 | REST endpoints are defined as Unfiltered pattern matches. The server is booted to port 8080 after test data is initialized. Note, some sloppy error handling but this is not an Unfiltered demo after all. 
59 | 60 | ### [db.scala](https://github.com/jonifreeman/sqltyped/blob/master/demo/src/main/scala/db.scala) ### 61 | 62 | This is where the meat of the demo is. sqlτyped promotes a style where SQL is used directly to define data access functions. It lets the programmer use all the available database features in a native form. It is the job of the compiler to integrate these two worlds as seamlessly as possible. 63 | 64 | Note the complete lack of type annotations in defined ```sql``` functions. The types are inferred from the database. If you know SQL, you know how to define data access functions with sqlτyped. 65 | 66 | Processing results is slightly more involved and has some rough edges. sqlτyped returns results as a list of [extensible records](https://github.com/jonifreeman/sqltyped/wiki/User-guide#wiki-records). The record system is built on top of HList, which sometimes leaks through as very cryptic compiler error messages. Nevertheless, an extensible record is a nice abstraction for a database row and it is very easy to convert a record to a more familiar tuple when needed (```record.values.tupled```). Function ```personWithInterviews``` shows an example usage. ```personById``` returns a record: 67 | 68 | ```scala 69 | { 70 | id: Long 71 | , name: String 72 | , interview: Option[Long] 73 | , rating: Option[Double] 74 | , held_by: Option[String] 75 | } 76 | ``` 77 | 78 | We could convert that directly to JSON with function ```sqltyped.json4s.JSON.compact```. However, to offer a nicer API some structure should be added to that flat row. We can modify the value of a field ```interview``` with function ```updateWith```. It is a function from the original value (of type ```Option[Long]``` here) to a new value. The new value is a new record. Finally, fields which were added to the newly created record are removed from the original record. 
79 | 80 | ```scala 81 | personById(id) map { p => 82 | p.updateWith("interview") { _ map (i => 83 | "rating" ->> p.get("rating") :: "held_by" ->> p.get("held_by") :: "comments" ->> comments(i) :: HNil 84 | )} - "rating" - "held_by" 85 | } 86 | ``` 87 | 88 | The result is a following record which can be directly rendered as a nice JSON document. 89 | 90 | ```scala 91 | { 92 | id: Long 93 | , name: String 94 | , interview: Option[{ 95 | rating: Option[Double] 96 | , held_by: Option[String] 97 | , comments: List[{text: String, created: Timestamp, author: String}] 98 | }] 99 | } 100 | ``` 101 | -------------------------------------------------------------------------------- /demo/project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=0.13.0 -------------------------------------------------------------------------------- /demo/project/build.scala: -------------------------------------------------------------------------------- 1 | import sbt._ 2 | import Keys._ 3 | 4 | object DemoBuild extends Build { 5 | lazy val demoSettings = Defaults.defaultSettings ++ Seq( 6 | organization := "com.example", 7 | version := "0.4.1", 8 | scalaVersion := "2.11.7", 9 | scalacOptions ++= Seq("-unchecked", "-deprecation"), 10 | javacOptions ++= Seq("-target", "1.6", "-source", "1.6"), 11 | crossPaths := false, 12 | libraryDependencies ++= Seq( 13 | "fi.reaktor" %% "sqltyped" % "0.4.1", 14 | "fi.reaktor" %% "sqltyped-json4s" % "0.4.0", 15 | "com.typesafe" %% "slick" % "1.0.0-RC1", 16 | "net.databinder" %% "unfiltered" % "0.6.5", 17 | "net.databinder" %% "unfiltered-netty" % "0.6.5", 18 | "net.databinder" %% "unfiltered-netty-server" % "0.6.5", 19 | "mysql" % "mysql-connector-java" % "5.1.21" 20 | ), 21 | initialize ~= { _ => initSqltyped }, 22 | resolvers ++= Seq(sonatypeNexusSnapshots, sonatypeNexusReleases) 23 | ) 24 | 25 | lazy val demo = Project( 26 | id = "sqltyped-demo", 27 | base = file("."), 28 | settings = 
demoSettings 29 | ) 30 | 31 | def initSqltyped { 32 | System.setProperty("sqltyped.url", "jdbc:mysql://localhost:3306/sqltyped_demo") 33 | System.setProperty("sqltyped.driver", "com.mysql.jdbc.Driver") 34 | System.setProperty("sqltyped.username", "root") 35 | System.setProperty("sqltyped.password", "") 36 | } 37 | 38 | val sonatypeNexusSnapshots = "Sonatype Nexus Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots" 39 | val sonatypeNexusReleases = "Sonatype Nexus Releases" at "https://oss.sonatype.org/content/repositories/releases" 40 | } 41 | -------------------------------------------------------------------------------- /demo/src/main/resources/schema.sql: -------------------------------------------------------------------------------- 1 | create table person( 2 | id bigint(20) NOT NULL auto_increment, 3 | name varchar(255) NOT NULL, 4 | secret varchar(255), 5 | interview bigint(20), 6 | PRIMARY KEY (id), 7 | UNIQUE (name) 8 | ) ENGINE=InnoDB; 9 | 10 | create table interview( 11 | id bigint(20) NOT NULL auto_increment, 12 | held_by bigint(20) NOT NULL, 13 | rating DOUBLE, 14 | FOREIGN KEY held_by_fk (held_by) 15 | REFERENCES person (id) 16 | ON DELETE CASCADE 17 | ON UPDATE NO ACTION, 18 | PRIMARY KEY (id) 19 | ) ENGINE=InnoDB; 20 | 21 | alter table person 22 | ADD FOREIGN KEY interview_fk (interview) 23 | REFERENCES interview (id) 24 | ON DELETE CASCADE 25 | ON UPDATE NO ACTION; 26 | 27 | create table comment( 28 | text TEXT, 29 | created TIMESTAMP NOT NULL, 30 | author bigint(20) NOT NULL, 31 | interview bigint(20) NOT NULL, 32 | FOREIGN KEY author_fk (author) 33 | REFERENCES person (id) 34 | ON DELETE CASCADE 35 | ON UPDATE NO ACTION, 36 | FOREIGN KEY interview_fk (interview) 37 | REFERENCES interview (id) 38 | ON DELETE CASCADE 39 | ON UPDATE NO ACTION 40 | ) ENGINE=InnoDB; 41 | -------------------------------------------------------------------------------- /demo/src/main/scala/db.scala: 
-------------------------------------------------------------------------------- 1 | package demo 2 | 3 | import sqltyped._ 4 | import shapeless._ 5 | 6 | object Db { 7 | val personNames = sql("SELECT name FROM person") 8 | val personIdByName = sql("SELECT id FROM person WHERE name=?") 9 | 10 | val personById = 11 | sql("""SELECT p.id, p.name, p.interview, i.rating, p2.name AS held_by 12 | FROM person p LEFT JOIN interview i ON p.interview=i.id 13 | LEFT JOIN person p2 ON i.held_by=p2.id 14 | WHERE p.id = ? LIMIT 1""") 15 | 16 | val comments = 17 | sql("""SELECT c.text, c.created, p.name AS author 18 | FROM comment c JOIN person p ON c.author=p.id 19 | WHERE c.interview = ?""") 20 | 21 | val newComment = 22 | sql("""INSERT INTO comment (text, created, author, interview) 23 | SELECT ?, now(), ?, p.interview FROM person p 24 | WHERE p.id = ? LIMIT 1""") 25 | 26 | def personWithInterviews(id: Long) = personById(id) map { p => 27 | p.updateWith("interview") { _ map (i => 28 | "rating" ->> p.get("rating") :: "held_by" ->> p.get("held_by") :: "comments" ->> comments(i) :: HNil 29 | )} - "rating" - "held_by" 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /demo/src/main/scala/package.scala: -------------------------------------------------------------------------------- 1 | import scala.slick.session.Database 2 | import sqltyped._ 3 | 4 | package object demo { 5 | Class.forName("com.mysql.jdbc.Driver") 6 | val db = Database.forURL("jdbc:mysql://localhost:3306/sqltyped_demo", 7 | driver = "com.mysql.jdbc.Driver", user = "root", password = "") 8 | 9 | implicit val formats = org.json4s.DefaultFormats 10 | implicit def conn = Database.threadLocalSession.conn 11 | } 12 | -------------------------------------------------------------------------------- /demo/src/main/scala/server.scala: -------------------------------------------------------------------------------- 1 | package demo 2 | 3 | import unfiltered.request._ 4 | 
import unfiltered.response._ 5 | import sqltyped.json4s.JSON.compact 6 | 7 | object Server extends scala.App { 8 | val api = unfiltered.netty.cycle.Planify { 9 | case GET(Path("/people")) => 10 | Ok ~> Json(compact(Db.personNames.apply)) 11 | 12 | case GET(Path(Seg("person" :: id :: Nil))) => 13 | Db.personWithInterviews(id.toLong).headOption match { 14 | case None => NotFound ~> ResponseString("No such person") 15 | case Some(p) => Ok ~> Json(compact(p)) 16 | } 17 | 18 | case PUT(Path(Seg("person" :: id :: "comment" :: Nil)) & Params(params)) => 19 | val author = Db.personIdByName("Admin") getOrElse sys.error("Initialization error: no admin") 20 | Db.newComment(params("text").head, author, id.toLong) 21 | Ok ~> ResponseString("OK") 22 | } 23 | 24 | db.withSession { TestData.drop; TestData.create } 25 | unfiltered.netty.Http(8080) 26 | .plan(unfiltered.netty.cycle.Planify { case x => db.withSession(api.intent(x)) }) 27 | .run(svr => println("Running: " + svr.url), svr => println("Shutting down.")) 28 | 29 | def Json(s: String) = JsonContent ~> ResponseString(s) 30 | } 31 | -------------------------------------------------------------------------------- /demo/src/main/scala/testdata.scala: -------------------------------------------------------------------------------- 1 | package demo 2 | 3 | import sqltyped._ 4 | 5 | object TestData { 6 | def drop = sql("DELETE FROM person").apply 7 | 8 | def create = { 9 | val newPerson = sqlk("INSERT INTO person (name, secret, interview) VALUES (?, ?, ?)") 10 | val admin = newPerson("Admin", None, None) 11 | val other = newPerson("Some other guy", None, None) 12 | val interview = sqlk("INSERT INTO interview (held_by, rating) VALUES (?, ?)").apply(admin, Some(4.5)) 13 | val dick = newPerson("Dick Tracy", Some("secret"), Some(interview)) 14 | Db.newComment("My first comment.", admin, dick) 15 | Db.newComment("My second comment.", other, dick) 16 | } 17 | } 18 | 
-------------------------------------------------------------------------------- /docs/phases.dot: -------------------------------------------------------------------------------- 1 | digraph phases { 2 | rankdir="LR" 3 | parse -> resolve [label="Statement[Option[String]]"] 4 | resolve -> type [label="Statement[Table]"] 5 | type -> analyze [label="TypedStatement"] 6 | analyze -> codegen [label="TypedStatement"] 7 | codegen -> embed [label="Scala AST"] 8 | } 9 | -------------------------------------------------------------------------------- /docs/phases.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jonifreeman/sqltyped/270496b2cc5f4581c3a4c77a939fd9d894943163/docs/phases.png -------------------------------------------------------------------------------- /json4s/src/main/scala/json.scala: -------------------------------------------------------------------------------- 1 | package sqltyped.json4s 2 | 3 | import org.json4s._ 4 | import native.JsonMethods 5 | import JsonMethods._ 6 | 7 | import shapeless._ 8 | import labelled._ 9 | import poly._ 10 | import ops.hlist._ 11 | import syntax.singleton._ 12 | import record._ 13 | import shapeless.tag.@@ 14 | 15 | /** 16 | * Note, copy-pasted from https://github.com/rrmckinley/jsonless 17 | * 18 | * Remove this once 'jsonless' is published to Sonatype repos. 
19 | */ 20 | 21 | object Record { 22 | def keyAsString[F, V](f: FieldType[F, V])(implicit wk: shapeless.Witness.Aux[F]) = 23 | wk.value.toString 24 | } 25 | 26 | object JSON { 27 | def compact[A](a: A)(implicit ctoj: toJSON.Case[A] { type Result = JValue }): String = 28 | JsonMethods.compact(render(ctoj(a))) 29 | 30 | def pretty[A](a: A)(implicit ptoj: toJSON.Case[A] { type Result = JValue }): String = 31 | JsonMethods.pretty(render(ptoj(a))) 32 | } 33 | 34 | object toJSON extends Poly1 { 35 | implicit def atNull = at[Null](_ => JNull : JValue) 36 | implicit def atDouble = at[Double](JDouble(_) : JValue) 37 | implicit def atBigInt = at[BigInt](JInt(_) : JValue) 38 | implicit def atBigDecimal = at[BigDecimal](JDecimal(_) : JValue) 39 | implicit def atNumber[V <% Long] = at[V](i => JInt(BigInt(i)) : JValue) 40 | implicit def atString = at[String](s => (if (s == null) JNull else JString(s)) : JValue) 41 | implicit def atBoolean = at[Boolean](JBool(_) : JValue) 42 | implicit def atJSON[V <: JValue] = at[V](json => json : JValue) 43 | 44 | implicit def atTagged[V, T](implicit ttoj: toJSON.Case[V] { type Result = JValue }) = 45 | at[V @@ T](v => ttoj(v: V) : JValue) 46 | 47 | implicit def atOption[V](implicit otoj: toJSON.Case[V] { type Result = JValue }) = 48 | at[Option[V]] { 49 | case Some(v) => otoj(v : V) : JValue 50 | case None => JNull : JValue 51 | } 52 | 53 | implicit def atDate[V <: java.util.Date](implicit f: Formats) = 54 | at[V](s => (if (s == null) JNull else JString(f.dateFormat.format(s))) : JValue) 55 | 56 | implicit def atTraversable[V, C[V] <: Traversable[V]](implicit ttoj: toJSON.Case[V] { type Result = JValue }) = 57 | at[C[V]](l => JArray(l.toList.map(v => ttoj(v : V) : JValue)) : JValue) 58 | 59 | implicit def atMap[K, V](implicit mtoj: toJSON.Case[V] { type Result = JValue }) = 60 | at[Map[K, V]](m => JObject(m.toList.map(e => JField(e._1.toString, mtoj(e._2 : V) : JValue))) : JValue) 61 | 62 | implicit def atRecord[R <: HList, F, V](implicit 
folder: MapFolder[R, List[JField], fieldToJSON.type]) = 63 | at[R](r => JObject(r.foldMap(Nil: List[JField])(fieldToJSON)(_ ::: _)) : JValue) 64 | } 65 | 66 | object fieldToJSON extends Poly1 { 67 | implicit def atFieldType[F, V](implicit ftoj: toJSON.Case[V] { type Result = JValue }, wk: shapeless.Witness.Aux[F]) = at[FieldType[F, V]] { 68 | f => (Record.keyAsString(f), ftoj(f : V) : JValue) :: Nil 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /json4s/src/test/scala/jsonexample.scala: -------------------------------------------------------------------------------- 1 | package sqltyped.json4s 2 | 3 | import org.scalatest._ 4 | import shapeless._ 5 | import sqltyped._ 6 | 7 | class JSONExampleSuite extends FunSuite with matchers.ShouldMatchers { 8 | test("Record to JSON") { 9 | val SomeTag = Witness("SomeTag") 10 | 11 | val addr = ("street" ->> "Boulevard") :: ("city" ->> tag[SomeTag.T]("Helsinki")) :: HNil 12 | val child1 = ("name" ->> "ella") :: ("toys" ->> List("paperdoll", "jump rope")) :: ("age" ->> (Some(4): Option[Int])) :: HNil 13 | val child2 = ("name" ->> "moe") :: ("toys" ->> List("tin train")) :: ("age" ->> (None: Option[Int])) :: HNil 14 | val person = ("name" ->> "joe") :: ("age" ->> 36) :: ("address" ->> addr) :: ("children" ->> Seq(child1, child2)) :: HNil 15 | 16 | JSON.compact(addr) should equal("""{"street":"Boulevard","city":"Helsinki"}""") 17 | 18 | JSON.compact(child1) should equal("""{"name":"ella","toys":["paperdoll","jump rope"],"age":4}""") 19 | 20 | JSON.compact(person) should equal("""{"name":"joe","age":36,"address":{"street":"Boulevard","city":"Helsinki"},"children":[{"name":"ella","toys":["paperdoll","jump rope"],"age":4},{"name":"moe","toys":["tin train"],"age":null}]}""") 21 | } 22 | 23 | /* 24 | test("Nulls") { 25 | val foo: String = null 26 | val bad = ("key1" ->> null) :: ("key2" ->> foo) :: HNil 27 | JSON.compact(bad) should equal("""{"key1":null,"key2":null}""") 28 | } 29 | 
*/ 30 | 31 | test("Date") { 32 | implicit val formats = org.json4s.DefaultFormats 33 | val p = ("name" ->> "Joe") :: ("birthdate" ->> new java.util.Date(0)) :: HNil 34 | JSON.compact(p) should equal("""{"name":"Joe","birthdate":"1970-01-01T00:00:00Z"}""") 35 | } 36 | 37 | test("Data which is already in JSON format") { 38 | import org.json4s._ 39 | 40 | val p = ("name" ->> "Joe") :: ("addr" ->> JObject(List("street" -> JString("Boulevard"), "city" -> JString("Helsinki")))) :: HNil 41 | JSON.compact(p) should equal("""{"name":"Joe","addr":{"street":"Boulevard","city":"Helsinki"}}""") 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /notes/0.1.0.markdown: -------------------------------------------------------------------------------- 1 | [sqlτyped](https://github.com/jonifreeman/sqltyped) - a macro which infers Scala types from database. 2 | 3 | Initial release: 4 | 5 | ## Infers Scala types from SQL statements. 6 | 7 | scala> sql("select age from person").apply 8 | res0: List[Int] = List(36, 14) 9 | 10 | ## Uses type tags for PKs and FKs. 11 | 12 | scala> sql("select id from person").apply 13 | res0: List[Long @@ Tables.person] = List(1, 2) 14 | 15 | ## Analyzes the query and infers more exact type (e.g. Option instead of List). 16 | 17 | scala> sql("select name from person where id=?").apply(123) 18 | res0: Some[String] = Some(joe) 19 | 20 | ## Can return generated keys of a statement instead of updated rows. 21 | 22 | scala> sqlk("insert into person(name, age, salary) values (?, ?, ?)").apply("jill", 45, 30000) 23 | res0: Long @@ Tables.person = 3 24 | 25 | ## Includes an extensible record system and returns query results as a List of records. 
26 | 27 | scala> object name; object age 28 | scala> val p = (name -> "Joe") :: (age -> 36) :: HNil 29 | scala> p.show 30 | res0: String = { name = Joe, age = 36 } 31 | scala> p get name 32 | res1: String = Joe 33 | scala> p get age 34 | res2: Int = 36 35 | 36 | ## Includes a function which can convert records to JSON. 37 | 38 | scala> val p = (name -> "Joe") :: (age -> 36) :: HNil 39 | scala> sqltyped.json4s.JSON.compact(p) 40 | res0: String = {"name":"Joe","age":36} 41 | 42 | ## Parses a subset of standard SQL syntax. 43 | 44 | ## Has some MySQL specific extensions (ON DUPLICATE KEY UPDATE, INTERVAL, ...) 45 | -------------------------------------------------------------------------------- /notes/0.2.0.markdown: -------------------------------------------------------------------------------- 1 | [sqlτyped](https://github.com/jonifreeman/sqltyped) - a macro which infers Scala types from database. 2 | 3 | This release adds following new features and improvements: 4 | 5 | ## [Runtime query building](https://github.com/jonifreeman/sqltyped/wiki/User-guide#wiki-runtime) with interpolation syntax 6 | 7 | scala> sql"select name from person where $where order by age".apply(Seq(5)) 8 | res0: List[String] = List("moe", "joe") 9 | 10 | ## Fallback to JDBC based inference when SQL parsing fails 11 | 12 | sqlτyped uses custom SQL parsers to parse the SQL statements. The advantage of custom parsing is that it enables better type inference. It is possible to do more thorough query analysis compared to what JDBC API provides. In addition, some JDBC drivers are notoriously bad when it comes to query analysis (MySQL, I'm looking at you ;). The disadvantage of custom parsing is that it will take some time to polish parsers to support all quirks and nonstandard syntax of all SQL dialects. To get a best of both worlds, sqlτyped first tries its more exact inference analysis and if it fails fallsback to JDBC based analysis. 
13 | 14 | -------------------------------------------------------------------------------- /notes/0.3.0.markdown: -------------------------------------------------------------------------------- 1 | [sqlτyped](https://github.com/jonifreeman/sqltyped) - a macro which infers Scala types by analysing SQL statements. 2 | 3 | This release adds the following new features and improvements: 4 | 5 | ## Support for multiple compile time datasources 6 | 7 | In Scala code, bring the correct configuration into scope by: 8 | 9 | implicit object postgresql extends ConfigurationName 10 | 11 | The above configuration reads the following system properties to establish a connection to the database at compile time: 12 | 13 | sqltyped.postgresql.url 14 | sqltyped.postgresql.driver 15 | sqltyped.postgresql.username 16 | sqltyped.postgresql.password 17 | 18 | ## Support for collections in IN-clause 19 | 20 | JDBC does not have a standard way to support queries like: 21 | 22 | select age from person where name in (?) 23 | 24 | where the parameter is a collection of values. Some drivers support PreparedStatement#setArray(), MySQL implicitly supports this through PreparedStatement#setObject() etc. 25 | 26 | This version adds the following support: 27 | 28 | ### MySQL 29 | 30 | select age from person where name in (?) 31 | 32 | is typed as 33 | 34 | Seq[String] => List[Int] 35 | 36 | ### PostgreSQL 37 | 38 | select age from person where name = any(?) 
39 | 40 | is typed as 41 | 42 | Seq[String] => List[Int] 43 | 44 | ## MySQL type conversion support 45 | 46 | scala> sql("select convert(age, char(10)) from person").apply 47 | res0: List[String] = List("36", "14") 48 | 49 | ## Results to CSV conversion 50 | 51 | scala> val rows = sql("select name as fname, age from person limit 100").apply.values 52 | scala> CSV.fromList(rows) 53 | res1: String = 54 | "joe","36" 55 | "moe","14" 56 | 57 | ## Results to untyped List conversion 58 | 59 | scala> val rows = sql("select name, age from person limit 100").apply 60 | scala> Record.toTupleLists(rows) 61 | res2: List[List[(String, Any)]] = List(List(("name", "joe"), ("age", 36)), List(("name", "moe"), ("age", 14))) 62 | 63 | ## Improved nullability analysis 64 | 65 | In previous versions, if a selected column was nullable it was always boxed to Option[A]. However, if the column is restricted with `WHERE x IS NOT NULL` and the expression contains only 'and' operators then it obviously can't be null and boxing is unnecessary. 66 | -------------------------------------------------------------------------------- /notes/0.4.0.markdown: -------------------------------------------------------------------------------- 1 | [sqlτyped](https://github.com/jonifreeman/sqltyped) - a macro which infers Scala types by analysing SQL statements. 2 | 3 | This release adds following new features and improvements: 4 | 5 | ## Switching to Shapeless 2.0 records 6 | 7 | Previous version encoded records as HList of pairs (key, value). A big downside of this was the need to define record keys explicitely before use. Shapeless 2.0 comes with improved record encoding where record is a HList of values tagged by singleton types representing keys. 8 | 9 | scala> val r = ("name" ->> "Joe") :: ("age" ->> 13) :: HNil 10 | 11 | scala> r get "age" 12 | res0: Int = 13 13 | 14 | scala> r get "bzzzt" 15 | :22: error: No field String("bzzzt") in record ... 
16 | 17 | As a consequence names of some common functions to manipulate records has been changed. 18 | 19 | scala> r - "name" // was 'removeKey' 20 | scala> r.renameField("name", "nme") // was 'renameKey' 21 | scala> r.updateWith("age")(_ + 1) // was 'modify' 22 | 23 | ## Configurable naming strategy for record field names 24 | 25 | By default the record fields are named identically to database column names. This can 26 | be altered by providing a function String => String. This function must be put to 27 | compiler's classpath and passed to macro with environment property: 28 | 29 | sqltyped.naming_strategy 30 | 31 | For example: 32 | 33 | object MyNamingStrategy extends (String => String) { 34 | def apply(s: String) = .... 35 | } 36 | 37 | Then: 38 | 39 | System.setProperty("sqltyped.naming_strategy", "MyNamingStrategy$") 40 | 41 | -------------------------------------------------------------------------------- /notes/about.markdown: -------------------------------------------------------------------------------- 1 | [sqlτyped](https://github.com/jonifreeman/sqltyped) converts SQL string literals into typed functions at compile time. 2 | 3 | select age, name from person where age > ? 
4 | 5 | ==> 6 | 7 | Int => List[{ age: Int, name: String }] 8 | 9 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=0.13.0 2 | -------------------------------------------------------------------------------- /project/build.scala: -------------------------------------------------------------------------------- 1 | import sbt._ 2 | import Keys._ 3 | 4 | object SqltypedBuild extends Build with Publish { 5 | import Resolvers._ 6 | 7 | //lazy val versionFormat = "%s" 8 | lazy val majorVersion = "0.4.3" 9 | lazy val versionFormat = "%s-SNAPSHOT" 10 | 11 | lazy val sqltypedSettings = Defaults.defaultSettings ++ publishSettings ++ Seq( 12 | organization := "fi.reaktor", 13 | version := versionFormat format majorVersion, 14 | scalaVersion := "2.11.7", 15 | scalacOptions ++= Seq("-unchecked", "-deprecation", "-feature"), 16 | javacOptions ++= Seq("-target", "1.6", "-source", "1.6"), 17 | crossScalaVersions := Seq("2.11"), 18 | parallelExecution in Test := false, 19 | resolvers ++= Seq(sonatypeNexusSnapshots, sonatypeNexusReleases) 20 | ) 21 | 22 | lazy val root = Project("root", file(".")) aggregate(core) 23 | 24 | lazy val core = Project( 25 | id = "sqltyped", 26 | base = file("core"), 27 | settings = sqltypedSettings ++ Seq( 28 | libraryDependencies ++= Seq( 29 | "com.chuusai" %% "shapeless" % "2.3.1", 30 | "net.sourceforge.schemacrawler" % "schemacrawler" % "8.17", 31 | "org.scala-lang" % "scala-reflect" % "2.11.7", 32 | "org.scalatest" %% "scalatest" % "2.2.6" % "test", 33 | "org.scala-lang" % "scala-actors" % "2.11.7" % "test", 34 | "org.scala-lang.modules" %% "scala-parser-combinators" % "1.0.4", 35 | "mysql" % "mysql-connector-java" % "5.1.21" % "test", 36 | "postgresql" % "postgresql" % "9.1-901.jdbc4" % "test" 37 | ), 38 | initialize ~= { _ => initSqltyped } 39 | ) 40 | ) 41 | 42 | lazy val json4s = Project( 43 
| id = "sqltyped-json4s", 44 | base = file("json4s"), 45 | settings = sqltypedSettings ++ Seq( 46 | libraryDependencies ++= Seq( 47 | "org.json4s" %% "json4s-native" % "3.3.0" 48 | ) 49 | ) 50 | ) dependsOn(core % "compile;test->test;provided->provided") 51 | 52 | lazy val slickIntegration = Project( 53 | id = "sqltyped-slick", 54 | base = file("slick-integration"), 55 | settings = sqltypedSettings ++ Seq( 56 | libraryDependencies ++= Seq( 57 | "com.typesafe.slick" %% "slick" % "3.1.1" 58 | ), 59 | initialize ~= { _ => initSqltyped } 60 | ) 61 | ) dependsOn(core % "compile;test->test;provided->provided") 62 | 63 | def initSqltyped { 64 | System.setProperty("sqltyped.url", "jdbc:mysql://localhost:3306/sqltyped") 65 | System.setProperty("sqltyped.driver", "com.mysql.jdbc.Driver") 66 | System.setProperty("sqltyped.username", "root") 67 | System.setProperty("sqltyped.password", "") 68 | 69 | System.setProperty("sqltyped.postgresql.url", "jdbc:postgresql://localhost/sqltyped") 70 | System.setProperty("sqltyped.postgresql.driver", "org.postgresql.Driver") 71 | System.setProperty("sqltyped.postgresql.username", "sqltypedtest") 72 | System.setProperty("sqltyped.postgresql.password", "secret") 73 | System.setProperty("sqltyped.postgresql.schema", "sqltyped") 74 | } 75 | 76 | object Resolvers { 77 | val sonatypeNexusSnapshots = "Sonatype Nexus Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots" 78 | val sonatypeNexusReleases = "Sonatype Nexus Releases" at "https://oss.sonatype.org/content/repositories/releases" 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.typesafe.sbt" % "sbt-pgp" % "0.8.3") 2 | -------------------------------------------------------------------------------- /project/publish.scala: -------------------------------------------------------------------------------- 
1 | import sbt._ 2 | import Keys._ 3 | 4 | trait Publish { 5 | val nexus = "https://oss.sonatype.org/" 6 | val snapshots = "snapshots" at nexus + "content/repositories/snapshots" 7 | val releases = "releases" at nexus + "service/local/staging/deploy/maven2" 8 | 9 | lazy val publishSettings = Seq( 10 | publishMavenStyle := true, 11 | publishTo <<= version((v: String) => Some(if (v.trim endsWith "SNAPSHOT") snapshots else releases)), 12 | publishArtifact in Test := false, 13 | pomIncludeRepository := (_ => false), 14 | pomExtra := projectPomExtra 15 | ) 16 | 17 | val projectPomExtra = 18 | https://github.com/jonifreeman/sqltyped 19 | 20 | 21 | Apache License 22 | http://www.apache.org/licenses/ 23 | repo 24 | 25 | 26 | 27 | git@github.com:jonifreeman/sqltyped.git 28 | scm:git:git@github.com:jonifreeman/sqltyped.git 29 | 30 | 31 | 32 | jonifreeman 33 | Joni Freeman 34 | https://twitter.com/jonifreeman 35 | 36 | 37 | } 38 | -------------------------------------------------------------------------------- /slick-integration/src/test/scala/slickexample.scala: -------------------------------------------------------------------------------- 1 | package sqltyped.slick 2 | 3 | import scala.slick.session.Database 4 | import org.scalatest._ 5 | import sqltyped._ 6 | 7 | class SlickExample extends FunSuite with BeforeAndAfterEach with matchers.ShouldMatchers { 8 | val db = Database.forURL("jdbc:mysql://localhost:3306/sqltyped", 9 | driver = "com.mysql.jdbc.Driver", user = "root", password = "") 10 | 11 | implicit val c = Configuration() 12 | implicit def conn = Database.threadLocalSession.conn 13 | 14 | override def beforeEach() { 15 | val newPerson = sql("insert into person(id, name, age, salary) values (?, ?, ?, ?)") 16 | 17 | db withSession { 18 | sql("delete from person").apply 19 | 20 | newPerson(1, "joe", 36, 9500) 21 | newPerson(2, "moe", 14, 8000) 22 | } 23 | } 24 | 25 | test("with session") { 26 | val q = sql("select name, age from person where age > ?") 27 | 28 | db 
withSession { 29 | q(30).tuples should equal(List(("joe", 36))) 30 | } 31 | } 32 | 33 | test("transaction") { 34 | db withTransaction { 35 | sql("update person set name=? where id=?").apply("danny", 1) 36 | Database.threadLocalSession.rollback 37 | } 38 | 39 | db withSession { 40 | sql("select name from person where id=?").apply(1) should equal(Some("joe")) 41 | } 42 | } 43 | } 44 | 45 | --------------------------------------------------------------------------------