├── .git-blame-ignore-revs ├── .github └── workflows │ └── scala.yml ├── .gitignore ├── .scalafmt.conf ├── LICENSE ├── Readme.md ├── build.sbt ├── project ├── build.properties └── plugins.sbt ├── publish.sh └── src ├── main └── scala │ └── usql │ ├── Batch.scala │ ├── ConnectionProvider.scala │ ├── DataType.scala │ ├── Query.scala │ ├── RawSql.scala │ ├── RowDecoder.scala │ ├── RowEncoder.scala │ ├── Sql.scala │ ├── SqlBase.scala │ ├── SqlIdentifier.scala │ ├── SqlInterpoationParameter.scala │ ├── SqlReservedWords.scala │ ├── Update.scala │ ├── dao │ ├── Alias.scala │ ├── Annotations.scala │ ├── ColumnGroupMapping.scala │ ├── ColumnPath.scala │ ├── Crd.scala │ ├── Macros.scala │ ├── NameMapping.scala │ ├── SqlColumn.scala │ ├── SqlColumnar.scala │ ├── SqlFielded.scala │ └── SqlTabular.scala │ └── profiles │ ├── BasicProfile.scala │ ├── H2Profile.scala │ └── PostgresProfile.scala └── test └── scala ├── com └── example │ └── example.sc └── usql ├── AutoGeneratedUpdateTest.scala ├── HelloDbTest.scala ├── SqlIdentifierTest.scala ├── SqlInterpolationTest.scala ├── dao ├── ColumnPathTest.scala ├── KeyedCrudBaseTest.scala ├── NameMappingTest.scala ├── SimpleJoinTest.scala ├── SqlColumnarTest.scala ├── SqlCrdBaseTest.scala ├── SqlFieldedTest.scala └── SqlTabularTest.scala └── util ├── TestBase.scala └── TestBaseWithH2.scala /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # Scala Steward: Reformat with scalafmt 3.9.6 2 | 7d0d5e735cdc6303b4b33f692e02fd25c4fa2e6f 3 | -------------------------------------------------------------------------------- /.github/workflows/scala.yml: -------------------------------------------------------------------------------- 1 | name: Scala CI 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 
| - name: Set up JDK 17 20 | uses: actions/setup-java@v4 21 | with: 22 | java-version: '17' 23 | distribution: 'temurin' 24 | cache: 'sbt' 25 | - uses: sbt/setup-sbt@v1 26 | - name: Run tests 27 | run: sbt "test;scalafmtCheckAll" 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target 2 | .idea 3 | .bsp 4 | .bloop 5 | .metals 6 | .vscode 7 | metals.sbt 8 | -------------------------------------------------------------------------------- /.scalafmt.conf: -------------------------------------------------------------------------------- 1 | version=3.9.7 2 | maxColumn = 120 3 | assumeStandardLibraryStripMargin = true 4 | newlines.beforeCurlyLambdaParams = never 5 | newlines.implicitParamListModifierPrefer = before 6 | comments.wrap = no 7 | align.preset = most 8 | runner.dialect=scala3 9 | docstrings.style = Asterisk -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # usql micro JDBC toolkit for Scala 3 2 | 3 | usql is a small jdbc wrapper to automate recurring patterns and 4 | to simplify writing SQL typical Actions in the age of direct style scala. 5 | 6 | Note: this is Beta software. Only Postgres and H2 are supported yet (although it's 7 | easy to write more Profiles). 8 | 9 | ## Installation 10 | 11 | Version Matrix 12 | 13 | | Version | JVM Version | Scala Version | 14 | |---------|-------------|---------------| 15 | | 0.3.x | 21+ | 3.7.x+ | 16 | | 0.2.x | 17+ | 3.3.x+ | 17 | 18 | Add to build.sbt 19 | 20 | ```scala 21 | libraryDependencies += "net.reactivecore" %% "usql" % "CURRENT_VERSION" 22 | ``` 23 | 24 | Replace `CURRENT_VERSION` with current version (e.g. 
`0.2.0`) 25 | 26 | ## Features 27 | 28 | - No dependencies 29 | - Fast compile speed 30 | - Functional API 31 | - Extensible 32 | - SQL Interpolation 33 | - Simple CRUD (Create, Read, Update, Delete) / DAO-Object generation for your case classes. 34 | 35 | ## Non-Features 36 | 37 | - Not bound to effect system 38 | - No ORM 39 | - JDBC Only. 40 | - No Connection-Management, but easy to connect to [HikariCP](https://github.com/brettwooldridge/HikariCP) 41 | - No query validation (this should be done by testcases) 42 | - No DDL Generation 43 | 44 | ## Supported Databases 45 | 46 | - `BasicProfile` supports basic types for most JDBC-Compatible Databases 47 | - `H2Profile` for H2 48 | - `PostgresProfile` for Postgres 49 | 50 | The profiles can be incomplete, but should be easy to extend for your needs. 51 | 52 | ## Prior Art 53 | 54 | - A lot of ideas are from [Anorm](https://playframework.github.io/anorm/) 55 | - [Magnum](https://github.com/AugustNagro/magnum), quite similar but more advanced. 
56 | 57 | # Examples 58 | 59 | Also see the Example in [example.sc](src/test/scala/com/example/example.sc) 60 | 61 | ## Connecting to a Database 62 | 63 | To use usql you need to provide a given `ConnectionProvider`, this can be as easy as: 64 | 65 | ```scala 3 66 | import usql.* 67 | import usql.profiles.H2Profile.* 68 | 69 | val jdbcUrl = "" 70 | given cp: ConnectionProvider with { 71 | override def withConnection[T](f: Connection ?=> T): T = { 72 | Using.resource(DriverManager.getConnection(jdbcUrl)) { c => 73 | f(using c) 74 | } 75 | } 76 | } 77 | ``` 78 | 79 | ## Simple Actions 80 | 81 | ```scala 3 82 | sql"CREATE TABLE person (id INT PRIMARY KEY, name TEXT)" 83 | .execute() 84 | ``` 85 | 86 | Using Interpolation, which will be used as parameter for prepared statements 87 | 88 | ```scala 3 89 | sql"INSERT INTO person (id, name) VALUES (${1}, ${"Alice"})" 90 | .execute() 91 | 92 | sql"INSERT INTO person (id, name) VALUES (${2}, ${"Bob"})" 93 | .execute() 94 | ``` 95 | 96 | ## Queries and Interpolation 97 | 98 | Simple Queries: 99 | 100 | ```scala 3 101 | val all: Vector[(Int, String)] = sql"SELECT id, name FROM person".query.all[(Int, String)]() 102 | println(s"All=${all}") 103 | ``` 104 | 105 | ```scala 3 106 | val one: Option[(Int, String)] = sql"SELECT id, name FROM #${"person"} WHERE id = ${1}".query.one[(Int, String)]() 107 | println(s"One=${one}") 108 | ``` 109 | 110 | Encoding multiple Parameters (e.g. 
SQL-In-Operator): 111 | 112 | ```scala 3 113 | val ids = Seq(1,2,3) 114 | val names = sql"SELECT name FROM person WHERE id IN (${SqlParameters(ids)})".query.all[String]() 115 | println(s"Names=${names}") 116 | ``` 117 | 118 | ## Inserts 119 | 120 | ```scala 3 121 | // Single Insert 122 | sql"INSERT INTO person (id, name) VALUES(?, ?)".one((3, "Charly")).update.run() 123 | 124 | // Batch Insert 125 | sql"INSERT INTO person (id, name) VALUES(?, ?)" 126 | .batch( 127 | Seq( 128 | 4 -> "Dave", 129 | 5 -> "Emil" 130 | ) 131 | ) 132 | .run() 133 | 134 | sql"SELECT COUNT(*) FROM person".query.one[Int]().get 135 | // is 5 136 | ``` 137 | 138 | ## Reusable Parts 139 | 140 | You can concatenate sql parts: 141 | 142 | ```scala 3 143 | val select = sql"SELECT id, name FROM person" 144 | val selectAlice = (select + sql" WHERE id = ${1}").query.one[(Int, String)]() 145 | println(s"Alice: ${selectAlice}") 146 | ``` 147 | 148 | ## Transactions 149 | 150 | This fails because of the duplicate entry with id `100`, but at the end both are not inside: 151 | ```scala 3 152 | Try { 153 | transaction { 154 | sql"INSERT INTO person(id, name) VALUES(${100}, ${"Duplicate"})".execute() 155 | sql"INSERT INTO person(id, name) VALUES(${100}, ${"Duplicate 2"})".execute() 156 | } 157 | } 158 | ``` 159 | 160 | ## Automatic DAO Objects 161 | 162 | DAO (Data Access Objects) can be created using the base classes `CrdBase` and `KeyedCrudBase`. 163 | 164 | They are using a helper description object called `SqlColumnar` and `SqlTabular`. 
165 | 166 | ```scala 3 167 | import usql.dao.* 168 | 169 | case class Person( 170 | id: Int, 171 | name: String 172 | ) derives SqlTabular 173 | 174 | object Person extends KeyedCrudBase[Int, Person] { 175 | override def key: KeyColumnPath = cols.id 176 | 177 | override def keyOf(value: Person): Int = value.id 178 | 179 | override lazy val tabular: SqlTabular[Person] = summon 180 | } 181 | 182 | println(s"All Persons: ${Person.findAll()}") 183 | 184 | Person.insert(Person(6, "Fritz")) 185 | Person.update(Person(6, "Franziska")) 186 | println(Person.findByKey(6)) // Person(6, Franziska) 187 | ``` 188 | 189 | ## Scala 3.7.0+ Named Tuples 190 | 191 | ```scala 3 192 | // Person.col.id will be automatically checked. 193 | val allAgain: Vector[(Int, String)] = 194 | sql"SELECT ${Person.cols.id}, ${Person.cols.name} FROM ${Person}".query.all[(Int, String)]() 195 | 196 | println(s"allAgain=${allAgain}") 197 | ``` 198 | 199 | # Core Types 200 | 201 | - `DataType` a type class which derives how to fetch a Type `T` from a `ResultSet` and how to store it in a `PreparedStatement` 202 | - `RowDecoder` type class for fetching tuples / values from `ResultSet` 203 | - `RowEncoder` type class for filling tuples / values into a `PreparedStatement` 204 | - `SqlIdentifier` an SQL identifier, quoted if necessary. 205 | - `RawSql` Raw SQL Queries 206 | - `Sql` interpolated SQL Queries 207 | 208 | ## DAO Core Types 209 | 210 | - `SqlColumnar` describes the columns and codec for a case class `T`, macro generated 211 | - `SqlTabular` like `SqlColumnar`, but also contains a table name 212 | - `Crd` basic Create-Read-Delete operations 213 | - `KeyedCrud` Crd for single-keyed types 214 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | import xerial.sbt.Sonatype.GitHubHosting 2 | 3 | // If there is a Tag starting with v, e.g. 
v0.3.0 use it as the build artefact version (e.g. 0.3.0) 4 | val versionTag = sys.env 5 | .get("CI_COMMIT_TAG") 6 | .filter(_.startsWith("v")) 7 | .map(_.stripPrefix("v")) 8 | 9 | val snapshotVersion = "0.3-SNAPSHOT" 10 | val artefactVersion = versionTag.getOrElse(snapshotVersion) 11 | 12 | ThisBuild / scalacOptions ++= Seq("-feature") 13 | 14 | def publishSettings = Seq( 15 | publishTo := sonatypePublishToBundle.value, 16 | sonatypeBundleDirectory := (ThisBuild / baseDirectory).value / "target" / "sonatype-staging" / s"${version.value}", 17 | licenses := Seq("APL2" -> url("http://www.apache.org/licenses/LICENSE-2.0.txt")), 18 | homepage := Some(url("https://github.com/reactivecore/usql")), 19 | sonatypeProjectHosting := Some(GitHubHosting("reactivecore", "usql", "contact@reactivecore.de")), 20 | developers := List( 21 | Developer( 22 | id = "nob13", 23 | name = "Norbert Schultz", 24 | email = "norbert.schultz@reactivecore.de", 25 | url = url("https://www.reactivecore.de") 26 | ) 27 | ), 28 | publish / test := {}, 29 | publishLocal / test := {} 30 | ) 31 | 32 | usePgpKeyHex("77D0E9E04837F8CBBCD56429897A43978251C225") 33 | 34 | ThisBuild / version := artefactVersion 35 | ThisBuild / organization := "net.reactivecore" 36 | ThisBuild / scalaVersion := "3.7.1" 37 | ThisBuild / Test / fork := true 38 | ThisBuild / scalacOptions ++= Seq("-new-syntax", "-rewrite") 39 | 40 | val scalaTestVersion = "3.2.19" 41 | 42 | lazy val root = (project in file(".")) 43 | .settings( 44 | name := "usql", 45 | libraryDependencies ++= Seq( 46 | "org.scalatest" %% "scalatest" % scalaTestVersion % Test, 47 | "org.scalatest" %% "scalatest-flatspec" % scalaTestVersion % Test, 48 | "com.h2database" % "h2" % "2.3.232" % Test 49 | ), 50 | publishSettings 51 | ) 52 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version = 1.11.2 2 | 
-------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.timushev.sbt" % "sbt-updates" % "0.6.4") 2 | addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.12.2") 3 | addSbtPlugin("com.github.sbt" % "sbt-pgp" % "2.3.1") 4 | addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.4") 5 | -------------------------------------------------------------------------------- /publish.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | export CI_COMMIT_TAG=`git describe --tags` 4 | echo "Publishing for $CI_COMMIT_TAG" 5 | source ~/bin/java21.sh 6 | 7 | read -p "Press enter to continue" 8 | 9 | sbt publishSigned sonatypeBundleRelease 10 | 11 | 12 | -------------------------------------------------------------------------------- /src/main/scala/usql/Batch.scala: -------------------------------------------------------------------------------- 1 | package usql 2 | 3 | /** Encapsulates a batch (insert) */ 4 | case class Batch[T](sql: SqlBase, values: IterableOnce[T], filler: RowEncoder[T]) { 5 | def run()(using cp: ConnectionProvider): Seq[Int] = { 6 | sql.withPreparedStatement { ps => 7 | values.iterator.foreach { value => 8 | filler.fill(ps, value) 9 | ps.addBatch() 10 | } 11 | val results = ps.executeBatch() 12 | results.toSeq 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/scala/usql/ConnectionProvider.scala: -------------------------------------------------------------------------------- 1 | package usql 2 | 3 | import java.sql.Connection 4 | import scala.util.control.NonFatal 5 | 6 | /** Provider for JDBC Connections. */ 7 | trait ConnectionProvider { 8 | 9 | /** Run some code using this connection. 
*/ 10 | def withConnection[T](f: Connection ?=> T): T 11 | } 12 | 13 | /** Helper for building transactions. */ 14 | def transaction[T](using cp: ConnectionProvider)(f: ConnectionProvider ?=> T): T = { 15 | cp.withConnection { 16 | val c = summon[Connection] 17 | 18 | val oldAutoCommit = c.getAutoCommit 19 | c.setAutoCommit(false) 20 | try { 21 | val res = f(using ConnectionProvider.forConnection(using c)) 22 | c.commit() 23 | res 24 | } catch { 25 | case NonFatal(e) => 26 | c.rollback() 27 | throw e 28 | } finally { 29 | c.setAutoCommit(oldAutoCommit) 30 | } 31 | } 32 | } 33 | 34 | object ConnectionProvider { 35 | given forConnection(using c: Connection): ConnectionProvider with { 36 | override def withConnection[T](f: Connection ?=> T): T = { 37 | f(using c) 38 | } 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/main/scala/usql/DataType.scala: -------------------------------------------------------------------------------- 1 | package usql 2 | 3 | import java.sql.{Connection, JDBCType, PreparedStatement, ResultSet} 4 | 5 | /** Type class describing a type to use. */ 6 | trait DataType[T] { 7 | 8 | /** Serialize a value (e.g. Debugging) */ 9 | def serialize(value: T): String = value.toString 10 | 11 | /** The underlying jdbc type. 
*/ 12 | def jdbcType: JDBCType 13 | 14 | // Extractors from ResultSet 15 | 16 | def extractByZeroBasedIndex(idx: Int, rs: ResultSet): T = { 17 | extractBySqlIdx(idx + 1, rs) 18 | } 19 | 20 | def extractBySqlIdx(cIdx: Int, rs: ResultSet): T 21 | 22 | def extractOptionalBySqlIdx(cIdx: Int, rs: ResultSet): Option[T] = { 23 | val candidate = Option(extractBySqlIdx(cIdx, rs)) 24 | if rs.wasNull() then { 25 | None 26 | } else { 27 | candidate 28 | } 29 | } 30 | 31 | def extractByName(columnLabel: String, resultSet: ResultSet): T = { 32 | val sqlIdx = resultSet.findColumn(columnLabel) 33 | extractBySqlIdx(sqlIdx, resultSet) 34 | } 35 | 36 | // Fillers 37 | 38 | def fillBySqlIdx(pIdx: Int, ps: PreparedStatement, value: T): Unit 39 | 40 | def fillByZeroBasedIdx(idx: Int, ps: PreparedStatement, value: T): Unit = { 41 | fillBySqlIdx(idx + 1, ps, value) 42 | } 43 | 44 | /** Adapt to another type. */ 45 | def adapt[U](mapFn: T => U, contraMapFn: U => T): DataType[U] = { 46 | val me = this 47 | new DataType[U] { 48 | override def jdbcType: JDBCType = me.jdbcType 49 | 50 | override def extractBySqlIdx(cIdx: Int, rs: ResultSet): U = mapFn(me.extractBySqlIdx(cIdx, rs)) 51 | 52 | override def extractOptionalBySqlIdx(cIdx: Int, rs: ResultSet): Option[U] = { 53 | me.extractOptionalBySqlIdx(cIdx, rs).map(mapFn) 54 | } 55 | 56 | override def fillBySqlIdx(pIdx: Int, ps: PreparedStatement, value: U): Unit = { 57 | me.fillBySqlIdx(pIdx, ps, contraMapFn(value)) 58 | } 59 | } 60 | } 61 | 62 | /** Adapt to another type, also providing the prepared statement */ 63 | def adaptWithPs[U](mapFn: T => U, contraMapFn: (U, PreparedStatement) => T): DataType[U] = { 64 | val me = this 65 | new DataType[U] { 66 | override def jdbcType: JDBCType = me.jdbcType 67 | 68 | override def extractBySqlIdx(cIdx: Int, rs: ResultSet): U = { 69 | mapFn(me.extractBySqlIdx(cIdx, rs)) 70 | } 71 | 72 | override def fillBySqlIdx(pIdx: Int, ps: PreparedStatement, value: U): Unit = 73 | me.fillBySqlIdx(pIdx, ps, 
contraMapFn(value, ps)) 74 | } 75 | } 76 | } 77 | 78 | object DataType { 79 | def simple[T]( 80 | jdbc: JDBCType, 81 | rsExtractor: (ResultSet, Int) => T, 82 | filler: (PreparedStatement, Int, T) => Unit 83 | ): DataType[T] = new DataType[T] { 84 | override def jdbcType: JDBCType = jdbc 85 | 86 | override def extractBySqlIdx(cIdx: Int, rs: ResultSet): T = rsExtractor(rs, cIdx) 87 | 88 | override def fillBySqlIdx(pIdx: Int, ps: PreparedStatement, value: T): Unit = filler(ps, pIdx, value) 89 | } 90 | 91 | def get[T](using dt: DataType[T]): DataType[T] = dt 92 | } 93 | -------------------------------------------------------------------------------- /src/main/scala/usql/Query.scala: -------------------------------------------------------------------------------- 1 | package usql 2 | 3 | import java.sql.{Connection, ResultSet} 4 | import scala.util.Using 5 | 6 | /** The user wants to issue a query. */ 7 | case class Query(sql: SqlBase) { 8 | 9 | /** Run a query for one row. */ 10 | def one[T]()(using rowParser: RowDecoder[T], cp: ConnectionProvider): Option[T] = { 11 | run { resultSet => 12 | if resultSet.next() then { 13 | Some(rowParser.parseRow(resultSet)) 14 | } else { 15 | None 16 | } 17 | } 18 | } 19 | 20 | /** Run a query for all rows. */ 21 | def all[T]()(using rowParser: RowDecoder[T], cp: ConnectionProvider): Vector[T] = { 22 | run { resultSet => 23 | val builder = Vector.newBuilder[T] 24 | while resultSet.next() do { 25 | builder += rowParser.parseRow(resultSet) 26 | } 27 | builder.result() 28 | } 29 | } 30 | 31 | /** Run with some method decoding the result set. 
*/ 32 | private def run[T](f: ResultSet => T)(using cp: ConnectionProvider): T = { 33 | sql.withPreparedStatement { statement => 34 | Using.resource(statement.executeQuery()) { resultSet => 35 | f(resultSet) 36 | } 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /src/main/scala/usql/RawSql.scala: -------------------------------------------------------------------------------- 1 | package usql 2 | 3 | import java.sql.{Connection, PreparedStatement} 4 | import scala.util.Using 5 | 6 | /** Raw SQL Query string. */ 7 | case class RawSql(sql: String) extends SqlBase { 8 | override def withPreparedStatement[T]( 9 | f: PreparedStatement => T 10 | )(using cp: ConnectionProvider, sp: StatementPreparator): T = { 11 | cp.withConnection { 12 | val c = summon[Connection] 13 | Using.resource(sp.prepare(c, sql)) { statement => 14 | f(statement) 15 | } 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/scala/usql/RowDecoder.scala: -------------------------------------------------------------------------------- 1 | package usql 2 | 3 | import usql.dao.SqlColumnar 4 | import java.sql.ResultSet 5 | 6 | /** Decoder for single rows in a [[ResultSet]] */ 7 | trait RowDecoder[T] { 8 | 9 | /** Parse a single row. */ 10 | def parseRow(offset: Int, row: ResultSet): T 11 | 12 | /** Parse at offset 0 */ 13 | def parseRow(row: ResultSet): T = parseRow(0, row) 14 | 15 | def map[U](f: T => U): RowDecoder[U] = { 16 | val me = this 17 | new RowDecoder[U] { 18 | override def parseRow(offset: Int, row: ResultSet): U = { 19 | f(me.parseRow(offset, row)) 20 | } 21 | 22 | override def cardinality: Int = me.cardinality 23 | } 24 | } 25 | 26 | /** The number of elements consumed by the decoder. 
*/ 27 | def cardinality: Int 28 | } 29 | 30 | object RowDecoder { 31 | given forTuple[H, T <: Tuple]( 32 | using headDecoder: RowDecoder[H], 33 | tailDecoder: RowDecoder[T] 34 | ): RowDecoder[H *: T] = new RowDecoder[H *: T] { 35 | override def parseRow(offset: Int, row: ResultSet): H *: T = { 36 | val h = headDecoder.parseRow(offset, row) 37 | val t = tailDecoder.parseRow(offset + headDecoder.cardinality, row) 38 | (h *: t) 39 | } 40 | 41 | override def cardinality: Int = headDecoder.cardinality + tailDecoder.cardinality 42 | } 43 | 44 | given empty: RowDecoder[EmptyTuple] = new RowDecoder[EmptyTuple] { 45 | override def parseRow(offset: Int, row: ResultSet): EmptyTuple = EmptyTuple 46 | 47 | override def cardinality: Int = 0 48 | } 49 | 50 | given forDataType[T](using dt: DataType[T]): RowDecoder[T] = new RowDecoder[T] { 51 | override def parseRow(offset: Int, row: ResultSet): T = dt.extractByZeroBasedIndex(offset, row) 52 | 53 | override def cardinality: Int = 1 54 | } 55 | 56 | given forOptional[T](using rd: RowDecoder[T]): RowDecoder[Option[T]] = new RowDecoder[Option[T]] { 57 | override def parseRow(offset: Int, row: ResultSet): Option[T] = { 58 | val isNone = (0 until cardinality).forall { baseIdx => 59 | val cIdx = offset + baseIdx + 1 60 | val _ = row.getObject(cIdx) 61 | row.wasNull() 62 | } 63 | if isNone then { 64 | None 65 | } else { 66 | val inner = rd.parseRow(offset, row) 67 | Some(inner) 68 | } 69 | } 70 | 71 | override def cardinality: Int = rd.cardinality 72 | } 73 | 74 | given forColumnar[T](using c: SqlColumnar[T]): RowDecoder[T] = c.rowDecoder 75 | } 76 | -------------------------------------------------------------------------------- /src/main/scala/usql/RowEncoder.scala: -------------------------------------------------------------------------------- 1 | package usql 2 | 3 | import usql.dao.SqlColumnar 4 | 5 | import java.sql.PreparedStatement 6 | 7 | /** Responsible for filling arguments into prepared statements for batch operations. 
*/ 8 | trait RowEncoder[T] { 9 | 10 | /** Fill something at the zero-based position index into the prepared statement. */ 11 | def fill(offset: Int, ps: PreparedStatement, value: T): Unit 12 | 13 | /** Fill some value without type checking. */ 14 | private[usql] def fillUnchecked(offset: Int, ps: PreparedStatement, value: Any): Unit = { 15 | fill(offset, ps, value.asInstanceOf[T]) 16 | } 17 | 18 | /** Fill at position 0 */ 19 | def fill(ps: PreparedStatement, value: T): Unit = fill(0, ps, value) 20 | 21 | def contraMap[U](f: U => T): RowEncoder[U] = { 22 | val me = this 23 | new RowEncoder[U] { 24 | override def fill(offset: Int, ps: PreparedStatement, value: U): Unit = me.fill(offset, ps, f(value)) 25 | 26 | override def cardinality: Int = me.cardinality 27 | } 28 | } 29 | 30 | /** The number of elements set by this filler */ 31 | def cardinality: Int 32 | } 33 | 34 | object RowEncoder { 35 | 36 | given forTuple[H, T <: Tuple]( 37 | using headFiller: RowEncoder[H], 38 | tailFiller: RowEncoder[T] 39 | ): RowEncoder[H *: T] = new RowEncoder[H *: T] { 40 | override def fill(offset: Int, ps: PreparedStatement, value: H *: T): Unit = { 41 | headFiller.fill(offset, ps, value.head) 42 | tailFiller.fill(offset + headFiller.cardinality, ps, value.tail) 43 | } 44 | 45 | override def cardinality: Int = { 46 | headFiller.cardinality + tailFiller.cardinality 47 | } 48 | } 49 | 50 | given empty: RowEncoder[EmptyTuple] = new RowEncoder[EmptyTuple] { 51 | override def fill(offset: Int, ps: PreparedStatement, value: EmptyTuple): Unit = () 52 | 53 | override def cardinality: Int = 0 54 | } 55 | 56 | given forDataType[T](using dt: DataType[T]): RowEncoder[T] = new RowEncoder[T] { 57 | override def fill(offset: Int, ps: PreparedStatement, value: T): Unit = dt.fillByZeroBasedIdx(offset, ps, value) 58 | 59 | override def cardinality: Int = 1 60 | } 61 | 62 | given forColumnar[T](using c: SqlColumnar[T]): RowEncoder[T] = c.rowEncoder 63 | } 64 | 
-------------------------------------------------------------------------------- /src/main/scala/usql/Sql.scala: -------------------------------------------------------------------------------- 1 | package usql 2 | 3 | import SqlInterpolationParameter.{InnerSql, SqlParameter} 4 | 5 | import java.sql.{Connection, PreparedStatement} 6 | import scala.annotation.{tailrec, targetName} 7 | import scala.language.implicitConversions 8 | import scala.util.Using 9 | 10 | extension (sc: StringContext) { 11 | def sql(parameters: SqlInterpolationParameter*): Sql = { 12 | Sql(fixParameters(parameters)) 13 | } 14 | 15 | /** Bring parameters into a canonical format. */ 16 | private def fixParameters(parameters: Seq[SqlInterpolationParameter]): Seq[(String, SqlInterpolationParameter)] = { 17 | @tailrec 18 | def fix( 19 | parts: List[String], 20 | params: List[SqlInterpolationParameter], 21 | builder: List[(String, SqlInterpolationParameter)] 22 | ): List[(String, SqlInterpolationParameter)] = { 23 | (parts, params) match { 24 | case (Nil, _) => 25 | // No more parts 26 | builder 27 | case (part :: restParts, Nil) if part.isEmpty => 28 | // Skip it, empty part and no parameter 29 | fix(restParts, Nil, builder) 30 | case (part :: restParts, Nil) => 31 | // More Parts but no parameters 32 | fix(restParts, Nil, (part, SqlInterpolationParameter.Empty) :: builder) 33 | case (part :: restParts, (param: SqlParameter[?]) :: restParams) if part.endsWith("#") => 34 | // Getting #${..} parameters to work 35 | val replacedPart = part.stripSuffix("#") + param.dataType.serialize(param.value) 36 | fix(restParts, restParams, (replacedPart, SqlInterpolationParameter.Empty) :: builder) 37 | case (part :: restParts, (param: InnerSql) :: restParams) => 38 | // Inner Sql 39 | 40 | val inner = if part.isEmpty then { 41 | Nil 42 | } else { 43 | List(part -> SqlInterpolationParameter.Empty) 44 | } 45 | 46 | val combined = param.sql.parts.toList.reverse ++ inner ++ builder 47 | fix(restParts, restParams, 
combined) 48 | case (part :: restParts, param :: restParams) => 49 | // Regular Case 50 | fix(restParts, restParams, (part, param) :: builder) 51 | } 52 | } 53 | fix(sc.parts.toList, parameters.toList, Nil).reverse 54 | } 55 | } 56 | 57 | /** SQL with already embedded parameters. */ 58 | case class Sql(parts: Seq[(String, SqlInterpolationParameter)]) extends SqlBase { 59 | def sql = parts.iterator.map { case (part, param) => 60 | part + param.replacement 61 | }.mkString 62 | 63 | private def sqlParameters: Seq[SqlParameter[?]] = parts.collect { case (_, p: SqlParameter[?]) => 64 | p 65 | } 66 | 67 | override def withPreparedStatement[T]( 68 | f: PreparedStatement => T 69 | )(using cp: ConnectionProvider, sp: StatementPreparator): T = { 70 | cp.withConnection { 71 | val c = summon[Connection] 72 | Using.resource(sp.prepare(c, sql)) { statement => 73 | sqlParameters.zipWithIndex.foreach { case (param, idx) => 74 | param.dataType.fillByZeroBasedIdx(idx, statement, param.value) 75 | } 76 | f(statement) 77 | } 78 | } 79 | } 80 | 81 | def stripMargin: Sql = { 82 | stripMargin('|') 83 | } 84 | 85 | def stripMargin(marginChar: Char): Sql = { 86 | Sql( 87 | parts.map { case (s, p) => 88 | s.stripMargin(marginChar) -> p 89 | } 90 | ) 91 | } 92 | 93 | @targetName("concat") 94 | inline def +(other: Sql): Sql = concat(other) 95 | 96 | def concat(other: Sql): Sql = { 97 | Sql( 98 | this.parts ++ other.parts 99 | ) 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /src/main/scala/usql/SqlBase.scala: -------------------------------------------------------------------------------- 1 | package usql 2 | 3 | import java.sql.{Connection, PreparedStatement, Statement} 4 | 5 | /** Something which can create prepared statements. */ 6 | trait SqlBase { 7 | 8 | /** Prepares a statement which can then be further filled or executed. 
*/ 9 | def withPreparedStatement[T]( 10 | f: PreparedStatement => T 11 | )(using cp: ConnectionProvider, prep: StatementPreparator = StatementPreparator.default): T 12 | 13 | /** Turns into a query */ 14 | def query: Query = Query(this) 15 | 16 | /** Turns into an update. */ 17 | def update: Update = { 18 | Update(this) 19 | } 20 | 21 | /** Turns into a update on one value set. */ 22 | def one[T](value: T)(using p: RowEncoder[T]): AppliedSql[T] = { 23 | AppliedSql(this, value, p) 24 | } 25 | 26 | /** Turns into a batch operation */ 27 | def batch[T](values: Iterable[T])(using p: RowEncoder[T]): Batch[T] = { 28 | Batch(this, values, p) 29 | } 30 | 31 | /** Raw Executes this statement. */ 32 | def execute()(using ConnectionProvider): Boolean = { 33 | withPreparedStatement(_.execute()) 34 | } 35 | } 36 | 37 | /** Hook for changing the preparation of SQL. */ 38 | trait StatementPreparator { 39 | def prepare(connection: Connection, sql: String): PreparedStatement 40 | } 41 | 42 | object StatementPreparator { 43 | 44 | /** Default Implementation */ 45 | object default extends StatementPreparator { 46 | override def prepare(connection: Connection, sql: String): PreparedStatement = { 47 | connection.prepareStatement(sql) 48 | } 49 | } 50 | 51 | /** Statement should return generated keys */ 52 | object withGeneratedKeys extends StatementPreparator { 53 | override def prepare(connection: Connection, sql: String): PreparedStatement = { 54 | connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS) 55 | } 56 | } 57 | } 58 | 59 | /** With supplied arguments */ 60 | case class AppliedSql[T](base: SqlBase, parameter: T, rowEncoder: RowEncoder[T]) extends SqlBase { 61 | override def withPreparedStatement[T]( 62 | f: PreparedStatement => T 63 | )(using cp: ConnectionProvider, sp: StatementPreparator): T = { 64 | base.withPreparedStatement { ps => 65 | rowEncoder.fill(ps, parameter) 66 | f(ps) 67 | } 68 | } 69 | } 70 | 
// ---- src/main/scala/usql/SqlIdentifier.scala ----
package usql

/** Something which can produce an identifier. */
trait SqlIdentifying {

  /** Build the identifier. */
  def buildIdentifier: SqlIdentifier
}

/**
 * An SQL identifier (table or column name).
 *
 * @param name
 *   raw name
 * @param quoted
 *   if true, the identifier will be quoted.
 * @param alias
 *   optional alias prefix; emitted unquoted in front of the (possibly quoted) name.
 */
@throws[IllegalArgumentException]("If name contains a \"")
case class SqlIdentifier(name: String, quoted: Boolean, alias: Option[String] = None) extends SqlIdentifying {
  require(!name.contains("\""), "Identifiers may not contain \"")

  /** Serialize the identifier, e.g. `alias."name"`. */
  def serialize: String = {
    val quotedName = if quoted then "\"" + name + "\"" else name
    alias match {
      case Some(a) => a + "." + quotedName
      case None    => quotedName
    }
  }

  /** Placeholder for select query */
  def placeholder: SqlRawPart = SqlRawPart("?")

  /** Named placeholder for update query */
  def namedPlaceholder: SqlRawPart = SqlRawPart(serialize + " = ?")

  override def toString: String = serialize

  override def buildIdentifier: SqlIdentifier = this
}

object SqlIdentifier {
  given stringToIdentifier: Conversion[String, SqlIdentifier] with {
    override def apply(x: String): SqlIdentifier = fromString(x)
  }

  /**
   * Parse an identifier from a string: explicit surrounding quotes force quoting, otherwise reserved words are quoted
   * automatically.
   */
  def fromString(s: String): SqlIdentifier = {
    if s.length >= 2 && s.startsWith("\"") && s.endsWith("\"") then {
      SqlIdentifier(s.drop(1).dropRight(1), quoted = true)
    } else {
      SqlIdentifier(s, quoted = SqlReservedWords.isReserved(s))
    }
  }
}
-------------------------------------------------------------------------------- /src/main/scala/usql/SqlInterpoationParameter.scala: -------------------------------------------------------------------------------- 1 | package usql 2 | 3 | import usql.dao.{Alias, ColumnPath, Crd, CrdBase, SqlColumn} 4 | 5 | import scala.language.implicitConversions 6 | 7 | /** Parameters available in sql""-Interpolation. */ 8 | sealed trait SqlInterpolationParameter { 9 | /// Replacement, e.g. '?' 10 | def replacement: String 11 | } 12 | 13 | object SqlInterpolationParameter { 14 | 15 | /** A Parameter which will be filled using '?' and parameter filler */ 16 | class SqlParameter[T](val value: T, val dataType: DataType[T]) extends SqlInterpolationParameter { 17 | override def equals(obj: Any): Boolean = { 18 | obj match { 19 | case s: SqlParameter[_] if value == s.value && dataType == s.dataType => true 20 | case _ => false 21 | } 22 | } 23 | 24 | override def hashCode(): Int = { 25 | value.hashCode() 26 | } 27 | 28 | override def replacement: String = "?" 29 | 30 | override def toString: String = { 31 | s"SqlParameter(${value} of type ${dataType.jdbcType.getName})" 32 | } 33 | } 34 | 35 | object SqlParameter { 36 | def apply[T](value: T)(using dataType: DataType[T]): SqlParameter[T] = new SqlParameter(value, dataType) 37 | } 38 | 39 | /** A single identifier. */ 40 | case class IdentifierParameter(i: SqlIdentifier) extends SqlInterpolationParameter { 41 | override def replacement: String = i.serialize 42 | } 43 | 44 | /** Multiple identifiers. */ 45 | case class IdentifiersParameter(i: Seq[SqlIdentifier]) extends SqlInterpolationParameter { 46 | override def replacement: String = { 47 | i.iterator.map(_.serialize).mkString(",") 48 | } 49 | } 50 | 51 | /** Some unchecked raw block. 
*/ 52 | case class RawBlockParameter(s: String) extends SqlInterpolationParameter { 53 | override def replacement: String = s 54 | } 55 | 56 | case class InnerSql(sql: Sql) extends SqlInterpolationParameter { 57 | // Not used 58 | override def replacement: String = sql.sql 59 | } 60 | 61 | /** Empty leaf, so that we have exactly as much interpolation parameters as string parts. */ 62 | object Empty extends SqlInterpolationParameter { 63 | override def replacement: String = "" 64 | } 65 | 66 | implicit def toSqlParameter[T](value: T)(using dataType: DataType[T]): SqlParameter[T] = { 67 | new SqlParameter(value, dataType) 68 | } 69 | 70 | implicit def toIdentifierParameter(i: SqlIdentifying): IdentifierParameter = IdentifierParameter(i.buildIdentifier) 71 | implicit def toIdentifiersParameter(i: Seq[SqlIdentifying]): IdentifiersParameter = IdentifiersParameter( 72 | i.map(_.buildIdentifier) 73 | ) 74 | implicit def columnsParameter(c: Seq[SqlColumn[?]]): IdentifiersParameter = IdentifiersParameter(c.map(_.id)) 75 | implicit def rawBlockParameter(rawPart: SqlRawPart): RawBlockParameter = RawBlockParameter(rawPart.s) 76 | implicit def innerSql(sql: Sql): InnerSql = InnerSql(sql) 77 | implicit def alias(alias: Alias[?]): RawBlockParameter = RawBlockParameter( 78 | s"${alias.tabular.tableName} ${alias.aliasName}" 79 | ) 80 | 81 | implicit def sqlParameters[T](sqlParameters: SqlParameters[T])(using dataType: DataType[T]): InnerSql = { 82 | val builder = Seq.newBuilder[(String, SqlInterpolationParameter)] 83 | sqlParameters.values.headOption.foreach { first => 84 | builder += (("", SqlParameter(first))) 85 | sqlParameters.values.tail.foreach { next => 86 | builder += ((",", SqlParameter(next))) 87 | } 88 | } 89 | InnerSql(Sql(builder.result())) 90 | } 91 | 92 | implicit def crd(crd: CrdBase[?]): RawBlockParameter = RawBlockParameter(s"${crd.tabular.tableName}") 93 | } 94 | 95 | /** Something which can be added to sql""-interpolation without further checking. 
*/ 96 | case class SqlRawPart(s: String) { 97 | override def toString: String = s 98 | } 99 | 100 | /** Marker for a sequence of elements like in SQL IN Clause, will be encoded as `?,...,?` and filled with values */ 101 | case class SqlParameters[T](values: Seq[T]) 102 | -------------------------------------------------------------------------------- /src/main/scala/usql/SqlReservedWords.scala: -------------------------------------------------------------------------------- 1 | package usql 2 | 3 | /** 4 | * SQL Reserved Words. Can be used to escape identifiers if needed. 5 | */ 6 | private[usql] object SqlReservedWords { 7 | 8 | /** Check if some identifier is a reserved word. */ 9 | def isReserved(name: String): Boolean = { 10 | wordList.contains(name.toUpperCase) 11 | } 12 | 13 | // Source https://www.postgresql.org/docs/current/sql-keywords-appendix.html 14 | // Only reserved in postgres (otherwise we would also escape words like `id`) 15 | private val wordList: Set[String] = Set( 16 | "ALL", 17 | "ANALYSE", 18 | "ANALYZE", 19 | "AND", 20 | "ANY", 21 | "ARRAY", 22 | "AS", 23 | "ASC", 24 | "ASYMMETRIC", 25 | "AUTHORIZATION", 26 | "BINARY", 27 | "BOTH", 28 | "CASE", 29 | "CAST", 30 | "CHECK", 31 | "COLLATE", 32 | "COLLATION", 33 | "COLUMN", 34 | "CONCURRENTLY", 35 | "CONSTRAINT", 36 | "CREATE", 37 | "CROSS", 38 | "CURRENT_CATALOG", 39 | "CURRENT_DATE", 40 | "CURRENT_ROLE", 41 | "CURRENT_TIME", 42 | "CURRENT_TIMESTAMP", 43 | "CURRENT_USER", 44 | "DEFAULT", 45 | "DEFERRABLE", 46 | "DESC", 47 | "DISTINCT", 48 | "DO", 49 | "ELSE", 50 | "END", 51 | "EXCEPT", 52 | "FALSE", 53 | "FETCH", 54 | "FOR", 55 | "FOREIGN", 56 | "FREEZE", 57 | "FROM", 58 | "FULL", 59 | "GRANT", 60 | "GROUP", 61 | "HAVING", 62 | "ILIKE", 63 | "IN", 64 | "INITIALLY", 65 | "INNER", 66 | "INTERSECT", 67 | "INTO", 68 | "IS", 69 | "ISNULL", 70 | "JOIN", 71 | "LATERAL", 72 | "LEADING", 73 | "LEFT", 74 | "LIKE", 75 | "LIMIT", 76 | "LOCALTIME", 77 | "LOCALTIMESTAMP", 78 | "NATURAL", 79 | "NOT", 80 | 
"NULL", 81 | "OFFSET", 82 | "ON", 83 | "ONLY", 84 | "OR", 85 | "ORDER", 86 | "OUTER", 87 | "OVERLAPS", 88 | "PLACING", 89 | "PRIMARY", 90 | "REFERENCES", 91 | "RETURNING", 92 | "RIGHT", 93 | "SELECT", 94 | "SESSION_USER", 95 | "SIMILAR", 96 | "SOME", 97 | "SYMMETRIC", 98 | "TABLE", 99 | "TABLESAMPLE", 100 | "THEN", 101 | "TO", 102 | "TRAILING", 103 | "TRUE", 104 | "UNION", 105 | "UNIQUE", 106 | "USER", 107 | "USING", 108 | "VARIADIC", 109 | "VERBOSE", 110 | "WHEN", 111 | "WHERE", 112 | "WINDOW", 113 | "WITH" 114 | ) 115 | } 116 | -------------------------------------------------------------------------------- /src/main/scala/usql/Update.scala: -------------------------------------------------------------------------------- 1 | package usql 2 | 3 | import usql.Update.SqlResultMissingGenerated 4 | 5 | import java.sql.SQLException 6 | import scala.util.Using 7 | 8 | /** Encapsulates an update statement */ 9 | case class Update(sql: SqlBase) { 10 | 11 | /** Run the update statement */ 12 | def run()(using c: ConnectionProvider): Int = { 13 | sql.withPreparedStatement(_.executeUpdate()) 14 | } 15 | 16 | /** 17 | * Run the update statement and get generated values. See [[java.sql.PreparedStatement.getGeneratedKeys()]] 18 | */ 19 | def runAndGetGenerated[T]()(using d: RowDecoder[T], c: ConnectionProvider): T = { 20 | given sp: StatementPreparator = StatementPreparator.withGeneratedKeys 21 | sql.withPreparedStatement { statement => 22 | statement.executeUpdate() 23 | Using.resource(statement.getGeneratedKeys) { resultSet => 24 | if resultSet.next() then { 25 | d.parseRow(resultSet) 26 | } else { 27 | throw new SqlResultMissingGenerated("Missing row for getGeneratedKeys") 28 | } 29 | } 30 | } 31 | } 32 | } 33 | 34 | object Update { 35 | 36 | /** Exception thrown if the result set has no generated data. 
*/ 37 | class SqlResultMissingGenerated(msg: String) extends SQLException(msg) 38 | } 39 | -------------------------------------------------------------------------------- /src/main/scala/usql/dao/Alias.scala: -------------------------------------------------------------------------------- 1 | package usql.dao 2 | 3 | import usql.{SqlIdentifier, SqlRawPart} 4 | 5 | /** Experimental helper for building aliases used in Join Statements */ 6 | case class Alias[T]( 7 | aliasName: String, 8 | tabular: SqlTabular[T] 9 | ) { 10 | 11 | /** Alias one identifier */ 12 | def apply(c: SqlIdentifier): SqlRawPart = { 13 | SqlRawPart(this.aliasName + "." + c.serialize) 14 | } 15 | 16 | /** Refers to all aliased columns */ 17 | def columns: SqlRawPart = { 18 | SqlRawPart( 19 | tabular.columns 20 | .map { c => apply(c.id).s } 21 | .mkString(",") 22 | ) 23 | } 24 | 25 | /** Access to aliased cols. */ 26 | def col: ColumnPath[T, T] = ColumnPath(tabular, Nil, alias = Some(aliasName)) 27 | } 28 | -------------------------------------------------------------------------------- /src/main/scala/usql/dao/Annotations.scala: -------------------------------------------------------------------------------- 1 | package usql.dao 2 | 3 | import usql.SqlIdentifier 4 | 5 | import scala.annotation.StaticAnnotation 6 | 7 | /** Annotation to override the default table name in [[SqlTabular]] */ 8 | case class TableName(name: String) extends StaticAnnotation 9 | 10 | /** Annotation to override the default column name in [[SqlColumnar]] */ 11 | case class ColumnName(name: String) extends StaticAnnotation { 12 | def id: SqlIdentifier = SqlIdentifier.fromString(name) 13 | } 14 | 15 | /** 16 | * Controls the way nested column group names are generated. 
17 | * 18 | * @param mapping 19 | * the mapping which will be applied 20 | */ 21 | case class ColumnGroup(mapping: ColumnGroupMapping = ColumnGroupMapping.Pattern()) extends StaticAnnotation 22 | -------------------------------------------------------------------------------- /src/main/scala/usql/dao/ColumnGroupMapping.scala: -------------------------------------------------------------------------------- 1 | package usql.dao 2 | 3 | import usql.SqlIdentifier 4 | 5 | /** Maps an inner column name inside a ColumnGroup. */ 6 | trait ColumnGroupMapping { 7 | def map(columnBaseName: SqlIdentifier, childId: SqlIdentifier): SqlIdentifier 8 | } 9 | 10 | object ColumnGroupMapping { 11 | 12 | /** Simple Pattern based column group mapping. */ 13 | case class Pattern(pattern: String = "%m_%c") extends ColumnGroupMapping { 14 | override def map(columnBaseName: SqlIdentifier, childId: SqlIdentifier): SqlIdentifier = { 15 | val applied = pattern 16 | .replace("%m", columnBaseName.name) 17 | .replace("%c", childId.name) 18 | // Do not take escaping from the field or parent as this can lead to strange situations (still hacky) 19 | SqlIdentifier.fromString(applied) 20 | } 21 | } 22 | 23 | case object Anonymous extends ColumnGroupMapping { 24 | override def map(columnBaseName: SqlIdentifier, childId: SqlIdentifier): SqlIdentifier = childId 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/scala/usql/dao/ColumnPath.scala: -------------------------------------------------------------------------------- 1 | package usql.dao 2 | 3 | import usql.{SqlIdentifier, SqlIdentifying} 4 | 5 | /** 6 | * Helper for going through the field path of SqlFielded. 7 | * 8 | * They can provide Identifiers and build getters like lenses do. 
9 | * 10 | * @tparam R 11 | * root model 12 | * @tparam T 13 | * end path 14 | */ 15 | case class ColumnPath[R, T](root: SqlFielded[R], fields: List[String] = Nil, alias: Option[String] = None) 16 | extends Selectable 17 | with SqlIdentifying { 18 | 19 | final type Child[X] = ColumnPath[R, X] 20 | 21 | type Fields = NamedTuple.Map[NamedTuple.From[T], Child] 22 | 23 | def selectDynamic(name: String): ColumnPath[R, ?] = { 24 | ColumnPath(root, name :: fields, alias) 25 | } 26 | 27 | private lazy val walker: ColumnPath.Walker[R, T] = { 28 | val reversed = fields.reverse 29 | reversed 30 | .foldLeft( 31 | ColumnPath.FieldedWalker[R, R]( 32 | root, 33 | mapping = identity, 34 | getter = identity 35 | ): ColumnPath.Walker[?, ?] 36 | )(_.select(_)) 37 | .asInstanceOf[ColumnPath.Walker[R, T]] 38 | } 39 | 40 | override def buildIdentifier: SqlIdentifier = { 41 | walker.id.copy(alias = alias) 42 | } 43 | 44 | def buildGetter: R => T = { 45 | walker.get 46 | } 47 | } 48 | 49 | object ColumnPath { 50 | 51 | def make[T](using f: SqlFielded[T]): ColumnPath[T, T] = ColumnPath(f) 52 | 53 | trait Walker[R, T] { 54 | def select(field: String): Walker[R, ?] 55 | def id: SqlIdentifier 56 | def get(root: R): T 57 | } 58 | 59 | case class FieldedWalker[R, T]( 60 | model: SqlFielded[T], 61 | mapping: SqlIdentifier => SqlIdentifier = identity, 62 | getter: R => T = identity 63 | ) extends Walker[R, T] { 64 | override def select(field: String): Walker[R, ?] 
= {
      // Look up the field by name; walking an unknown field name is a programming error.
      model.fields.view.zipWithIndex
        .collectFirst {
          case (f, idx) if f.fieldName == field =>
            selectField(idx, f)
        }
        .getOrElse {
          throw new IllegalStateException(s"Cannot find field named ${field}")
        }
    }

    /** Descend into field idx, composing the value getter and the identifier mapping. */
    private def selectField[X](idx: Int, f: Field[X]): Walker[R, X] = {
      val subGetter: T => X = value => {
        val splitted = model.split(value)
        splitted.apply(idx).asInstanceOf[X]
      }
      val newFetcher: R => X = getter.andThen(subGetter)
      f match {
        case f: Field.Column[X] => ColumnWalker[R, X](f, mapping, newFetcher)
        case g: Field.Group[X]  =>
          // Nested groups: first apply the group's own mapping, then the outer mapping.
          val subMapping: SqlIdentifier => SqlIdentifier = in => mapping(g.mapping.map(g.columnBaseName, in))
          FieldedWalker(g.fielded, subMapping, newFetcher)
      }
    }

    override def id: SqlIdentifier = {
      throw new IllegalStateException("Not at a final field")
    }

    override def get(root: R): T = {
      getter(root)
    }
  }

  /** Walker positioned at a terminal column; cannot be descended into any further. */
  case class ColumnWalker[R, T](
      column: Field.Column[T],
      mapping: SqlIdentifier => SqlIdentifier = identity,
      getter: R => T
  ) extends Walker[R, T] {
    override def select(field: String): Walker[R, ?] = {
      throw new IllegalStateException("Cannot walk further than a column")
    }

    override def id: SqlIdentifier = mapping(column.column.id)

    override def get(root: R): T = getter(root)
  }
}

// ---- src/main/scala/usql/dao/Crd.scala ----
package usql.dao

import usql.*

/** Simple create/read/delete interface */
trait Crd[T] {

  /** Insert into database.
*/ 9 | def insert(value: T)(using ConnectionProvider): Int = insert(Seq(value)) 10 | 11 | /** Insert many elements */ 12 | def insert(value1: T, value2: T, values: T*)(using ConnectionProvider): Int = insert(value1 +: value2 +: values) 13 | 14 | /** Insert many elements. */ 15 | def insert(values: Seq[T])(using ConnectionProvider): Int 16 | 17 | /** Find all instances */ 18 | def findAll()(using ConnectionProvider): Seq[T] 19 | 20 | /** Count all instances. */ 21 | def countAll()(using ConnectionProvider): Int 22 | 23 | /** Delete all instances. */ 24 | def deleteAll()(using ConnectionProvider): Int 25 | 26 | } 27 | 28 | /** CRUD (Create, Retrieve, Update, Delete) for keyed data. */ 29 | trait KeyedCrud[T] extends Crd[T] { 30 | 31 | /** Type of the key */ 32 | type Key 33 | 34 | /** Returns the key of a value. */ 35 | def keyOf(value: T): Key 36 | 37 | /** Update some value. */ 38 | def update(value: T)(using ConnectionProvider): Int 39 | 40 | /** Find one by key. */ 41 | def findByKey(key: Key)(using ConnectionProvider): Option[T] 42 | 43 | /** Load some value again based upon key. */ 44 | def findAgain(value: T)(using ConnectionProvider): Option[T] = findByKey(keyOf(value)) 45 | 46 | /** Delete by key. */ 47 | def deleteByKey(key: Key)(using ConnectionProvider): Int 48 | } 49 | 50 | /** Implementation of Crd for Tabular data. */ 51 | abstract class CrdBase[T] extends Crd[T] { 52 | 53 | protected given pf: RowEncoder[T] = tabular.rowEncoder 54 | 55 | protected given rd: RowDecoder[T] = tabular.rowDecoder 56 | 57 | /** 58 | * Define the referenced tabular, usually implemented using `summon`. We would like to have it as a parameter, but 59 | * this leads to this error https://github.com/scala/scala3/issues/22704 even when using lazy parameters. 60 | */ 61 | lazy val tabular: SqlTabular[T] 62 | 63 | /** Gives access to an aliased view. 
*/ 64 | def alias(name: String): Alias[T] = tabular.alias(name) 65 | 66 | /** Gives access to the columns */ 67 | def cols: ColumnPath[T, T] = tabular.cols 68 | 69 | private lazy val insertStatement = { 70 | val placeholders = SqlRawPart(tabular.columns.map(_.id.placeholder.s).mkString(",")) 71 | sql"INSERT INTO ${tabular.tableName} (${tabular.columns}) VALUES ($placeholders)" 72 | } 73 | 74 | override def insert(value: T)(using ConnectionProvider): Int = { 75 | insertStatement.one(value).update.run() 76 | } 77 | 78 | override def insert(values: Seq[T])(using ConnectionProvider): Int = { 79 | insertStatement.batch(values).run().sum 80 | } 81 | 82 | /** Select All Statement, may be reused. */ 83 | protected lazy val selectAll = sql"SELECT ${tabular.columns} FROM ${tabular.tableName}" 84 | 85 | override def findAll()(using ConnectionProvider): Seq[T] = { 86 | selectAll.query.all() 87 | } 88 | 89 | private lazy val countAllStatement = sql"SELECT COUNT(*) FROM ${tabular.tableName}" 90 | 91 | override def countAll()(using ConnectionProvider): Int = { 92 | import usql.profiles.BasicProfile.intType 93 | countAllStatement.query.one[Int]().getOrElse(0) 94 | } 95 | 96 | private lazy val deleteAllStatement = sql"DELETE FROM ${tabular.tableName}" 97 | 98 | override def deleteAll()(using ConnectionProvider): Int = { 99 | deleteAllStatement.update.run() 100 | } 101 | } 102 | 103 | /** Implementation of KeyedCrd for KeyedTabular data. 
*/ 104 | abstract class KeyedCrudBase[K, T](using keyDataType: DataType[K]) extends CrdBase[T] with KeyedCrud[T] { 105 | 106 | override type Key = K 107 | 108 | final type KeyColumnPath = ColumnPath[T, K] 109 | 110 | /** The column of the key */ 111 | lazy val keyColumn: SqlIdentifying = key.buildIdentifier 112 | 113 | def keyOf(value: T): K = cachedKeyGetter(value) 114 | 115 | def key: KeyColumnPath 116 | 117 | private lazy val cachedKeyGetter: T => K = key.buildGetter 118 | 119 | private lazy val updateStatement = { 120 | val namedPlaceholders = SqlRawPart(tabular.columns.map(_.id.namedPlaceholder.s).mkString(",")) 121 | sql"UPDATE ${tabular.tableName} SET $namedPlaceholders WHERE ${keyColumn} = ?" 122 | } 123 | 124 | override def update(value: T)(using ConnectionProvider): Int = { 125 | val key = keyOf(value) 126 | updateStatement.one((value, key)).update.run() 127 | } 128 | 129 | private lazy val findByKeyQuery = 130 | sql"${selectAll} WHERE ${keyColumn} = ?" 131 | 132 | override def findByKey(key: K)(using ConnectionProvider): Option[T] = { 133 | findByKeyQuery.one(key).query.one() 134 | } 135 | 136 | private lazy val deleteByKeyStatement = 137 | sql"DELETE FROM ${tabular.tableName} WHERE ${keyColumn} = ?" 
138 | 139 | override def deleteByKey(key: K)(using ConnectionProvider): Int = { 140 | deleteByKeyStatement.one(key).update.run() 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /src/main/scala/usql/dao/Macros.scala: -------------------------------------------------------------------------------- 1 | package usql.dao 2 | 3 | import usql.{DataType, RowEncoder, RowDecoder, SqlIdentifier} 4 | 5 | import scala.annotation.Annotation 6 | import scala.compiletime.{erasedValue, summonInline} 7 | import scala.deriving.Mirror 8 | import scala.quoted.{Expr, Quotes, Type} 9 | import scala.reflect.ClassTag 10 | 11 | object Macros { 12 | 13 | def getMaxOneAnnotation[T: ClassTag](in: List[Annotation]): Option[T] = { 14 | in.collect { case a: T => 15 | a 16 | } match { 17 | case Nil => None 18 | case List(one) => Some(one) 19 | case multiple => throw new IllegalArgumentException(s"More than one annotation of same type found: ${multiple}") 20 | } 21 | } 22 | 23 | /** Type info for each member, to differentiate between columnar and scalar types. */ 24 | sealed trait TypeInfo[T] 25 | 26 | object TypeInfo { 27 | case class Scalar[T](dataType: DataType[T]) extends TypeInfo[T] 28 | 29 | case class Columnar[T](columnar: SqlColumnar[T]) extends TypeInfo[T] 30 | 31 | given scalar[T](using dt: DataType[T]): TypeInfo[T] = Scalar(dt) 32 | given columnar[T](using c: SqlColumnar[T]): TypeInfo[T] = Columnar(c) 33 | } 34 | 35 | /** Combined TypeInfos for a tuple. 
*/
case class TypeInfos[T](infos: List[TypeInfo[?]], builder: List[Any] => T)

object TypeInfos {

  // Inductive step: prepend the TypeInfo for the tuple head, build the tail recursively.
  given forTuple[H, T <: Tuple](
      using typeInfo: TypeInfo[H],
      tailInfos: TypeInfos[T]
  ): TypeInfos[H *: T] = TypeInfos(
    typeInfo :: tailInfos.infos,
    builder = values => {
      // assumes `values` carries one entry per collected TypeInfo, in field order — TODO confirm at call sites
      values.head.asInstanceOf[H] *: tailInfos.builder(values.tail)
    }
  )

  // Base case: an empty tuple needs no infos and ignores its input values.
  given empty: TypeInfos[EmptyTuple] = TypeInfos(Nil, _ => EmptyTuple)
}

/** Builds a [[SqlTabular]] for a case class: derives the fields and resolves the table name
  * from an optional [[TableName]] annotation, falling back to the [[NameMapping]] strategy.
  */
inline def buildTabular[T <: Product](using nm: NameMapping, mirror: Mirror.ProductOf[T]): SqlTabular[T] = {
  val fielded = buildFielded[T]

  val tableName: SqlIdentifier = tableNameAnnotation[T]
    .map { tn =>
      SqlIdentifier.fromString(tn.name)
    }
    .getOrElse {
      nm.caseClassToTableName(typeName[T])
    }

  SqlTabular.SimpleTabular(
    tableName = tableName,
    fielded = fielded
  )
}

/** Compiler-rendered name of the type `T`. */
inline def typeName[T]: String = {
  ${ typeNameImpl[T] }
}

def typeNameImpl[T](using types: Type[T], quotes: Quotes): Expr[String] = {
  Expr(Type.show[T])
}

/** Field labels of a case class, in declaration order. */
inline def deriveLabels[T](using m: Mirror.Of[T]): List[String] = {
  // Also See https://stackoverflow.com/a/70416544/335385
  summonLabels[m.MirroredElemLabels]
}

inline def summonLabels[T <: Tuple]: List[String] = {
  inline erasedValue[T] match {
    case _: EmptyTuple => Nil
    case _: (t *: ts) => summonInline[ValueOf[t]].value.asInstanceOf[String] :: summonLabels[ts]
  }
}

/** Extract table name annotation for the type.
  */
inline def tableNameAnnotation[T]: Option[TableName] = {
  ${ tableNameAnnotationImpl[T] }
}

def tableNameAnnotationImpl[T](using quotes: Quotes, t: Type[T]): Expr[Option[TableName]] = {
  import quotes.reflect.*
  val tree = TypeRepr.of[T]
  val symbol = tree.typeSymbol
  // Picks the first annotation whose type conforms to TableName, if any.
  symbol.annotations.collectFirst {
    case term if (term.tpe <:< TypeRepr.of[TableName]) =>
      term.asExprOf[TableName]
  } match {
    case None => '{ None }
    case Some(e) => '{ Some(${ e }) }
  }
}

/** Extracts the [[Annotation]]s of each primary-constructor parameter (one list per field). */
inline def annotationsExtractor[T]: List[List[Annotation]] = {
  ${ annotationsExtractorImpl[T] }
}

def annotationsExtractorImpl[T](using quotes: Quotes, t: Type[T]): Expr[List[List[Annotation]]] = {
  import quotes.reflect.*
  val tree = TypeRepr.of[T]
  val symbol = tree.typeSymbol

  // Note: symbol.caseFields.map(_.annotations) does not work, but using the primaryConstructor works
  // Also see https://august.nagro.us/read-annotations-from-macro.html

  Expr.ofList(
    symbol.primaryConstructor.paramSymss.flatten
      .map { sym =>
        Expr.ofList {
          sym.annotations.collect {
            case term if (term.tpe <:< TypeRepr.of[Annotation]) =>
              term.asExprOf[Annotation]
          }
        }
      }
  )
}

/** Builds a [[SqlFielded]] for a case class by zipping field labels, their annotations and
  * their [[TypeInfo]]s. Scalar fields become single columns, nested case classes become
  * column groups.
  */
inline def buildFielded[T <: Product](
    using nm: NameMapping,
    mirror: Mirror.ProductOf[T]
): SqlFielded[T] = {
  val labels: List[String] = deriveLabels[T]
  val annotations: List[List[Annotation]] = annotationsExtractor[T]
  val typeInfos = summonInline[TypeInfos[mirror.MirroredElemTypes]]
  val splitter: T => List[Any] = v => v.productIterator.toList

  val fields =
    labels.zip(annotations).zip(typeInfos.infos).map {
      case ((label, annotations), typeInfo: TypeInfo.Scalar[?]) =>
        // An explicit @ColumnName annotation wins over the configured name mapping.
        val nameAnnotation = getMaxOneAnnotation[ColumnName](annotations)
        val id = nameAnnotation.map(a => SqlIdentifier.fromString(a.name)).getOrElse(nm.columnToSql(label))
        val column = SqlColumn(id, typeInfo.dataType)
        Field.Column(label, column)
      case ((label, annotations), c: TypeInfo.Columnar[?]) =>
        val nameAnnotation = getMaxOneAnnotation[ColumnName](annotations)
        val columnGroup = getMaxOneAnnotation[ColumnGroup](annotations)
        val mapping = columnGroup.map(_.mapping).getOrElse(ColumnGroupMapping.Pattern())
        val columnBaseName =
          nameAnnotation.map(a => SqlIdentifier.fromString(a.name)).getOrElse(nm.columnToSql(label))
        // NOTE(review): cast assumes every Columnar TypeInfo carries a SqlFielded — TODO confirm
        Field.Group(label, mapping, columnBaseName, c.columnar.asInstanceOf[SqlFielded[?]])
    }
  SqlFielded.SimpleSqlFielded(
    fields = fields,
    splitter = splitter,
    builder = typeInfos.builder.andThen(mirror.fromTuple)
  )
}

}
-------------------------------------------------------------------------------- /src/main/scala/usql/dao/NameMapping.scala: --------------------------------------------------------------------------------
package usql.dao

import usql.SqlIdentifier

/** Maps Column / Table names. */
trait NameMapping {

  /** Converts a column name to SQL identifiers */
  def columnToSql(name: String): SqlIdentifier

  /** Converts a case class (full qualified) name to SQL. */
  def caseClassToTableName(name: String): SqlIdentifier
}

object NameMapping {

  /** Simple Snake Case Conversion with checking against escaping. */
  object Default extends NameMapping {

    override def columnToSql(name: String): SqlIdentifier = SqlIdentifier.fromString(snakeCase(name))

    override def caseClassToTableName(name: String): SqlIdentifier = {
      SqlIdentifier.fromString(snakeCase(getSimpleClassName(name)))
    }
  }

  /** Returns the simple class name from full qualified name.
*/
  def getSimpleClassName(s: String): String = {
    s.lastIndexOf('.') match {
      case -1 => s
      case n => s.drop(n + 1)
    }
  }

  /** Converts a string to snake case. */
  def snakeCase(s: String): String = {
    val builder = StringBuilder()
    var lastIsUpper = false
    var first = true
    s.foreach { c =>
      // Underscore only on a lower->upper transition, so runs of capitals
      // (e.g. "HTTPServer") do not get an underscore between every letter.
      if c.isUpper && !lastIsUpper && !first then {
        builder += '_'
      }
      builder += c.toLower
      lastIsUpper = c.isUpper
      first = false
    }
    builder.result()
  }
}
-------------------------------------------------------------------------------- /src/main/scala/usql/dao/SqlColumn.scala: --------------------------------------------------------------------------------
package usql.dao

import usql.{DataType, SqlIdentifier, SqlRawPart}

/** A Single Column */
case class SqlColumn[T](
    id: SqlIdentifier,
    dataType: DataType[T]
)
-------------------------------------------------------------------------------- /src/main/scala/usql/dao/SqlColumnar.scala: --------------------------------------------------------------------------------
package usql.dao

import usql.{RowEncoder, RowDecoder}

import scala.deriving.Mirror

/**
 * Encapsulates column data and codecs for a product type.
 *
 * Note: for case classes, this is usually presented by [[SqlFielded]]
 */
trait SqlColumnar[T] {

  /** The columns */
  def columns: Seq[SqlColumn[?]]

  /** Count of columns */
  def cardinality: Int = columns.size

  /** Decoder for a full row. */
  def rowDecoder: RowDecoder[T]

  /** Filler for a full row.
    */
  def rowEncoder: RowEncoder[T]
}
-------------------------------------------------------------------------------- /src/main/scala/usql/dao/SqlFielded.scala: --------------------------------------------------------------------------------
package usql.dao

import usql.{RowEncoder, RowDecoder, SqlIdentifier}

import java.sql.{PreparedStatement, ResultSet}
import scala.deriving.Mirror

/** Something which has fields (e.g. a case class) */
trait SqlFielded[T] extends SqlColumnar[T] {

  /** Returns the available fields. */
  def fields: Seq[Field[?]]

  /** Access to the columns */
  def cols: ColumnPath[T, T] = ColumnPath(this, Nil)

  /** Split an instance into its fields */
  protected[dao] def split(value: T): Seq[Any]

  /** Build from field values. */
  protected[dao] def build(fieldValues: Seq[Any]): T

  // All columns of all fields, flattened in field order (a group contributes several columns).
  override lazy val columns: Seq[SqlColumn[?]] =
    fields.flatMap { field =>
      field.columns
    }

  override def rowDecoder: RowDecoder[T] = new RowDecoder {
    override def parseRow(offset: Int, row: ResultSet): T = {
      val fieldValues = Seq.newBuilder[Any]
      // Each field consumes `cardinality` consecutive columns, starting at `offset`.
      var currentOffset = offset
      fields.foreach { field =>
        fieldValues += field.decoder.parseRow(currentOffset, row)
        currentOffset += field.decoder.cardinality
      }
      build(fieldValues.result())
    }

    override def cardinality: Int = SqlFielded.this.cardinality
  }

  override def rowEncoder: RowEncoder[T] = new RowEncoder[T] {
    override def fill(offset: Int, ps: PreparedStatement, value: T): Unit = {
      var currentOffset = offset
      val fieldValues = split(value)
      // Mirrors rowDecoder: fields are written in order, each advancing the
      // offset by its own cardinality.
      fieldValues.zip(fields).foreach { case (fieldValue, field) =>
        field.filler.fillUnchecked(currentOffset, ps, fieldValue)
        currentOffset += field.filler.cardinality
      }
    }

    override def cardinality: Int = SqlFielded.this.cardinality
  }
}

object SqlFielded {

  /** Plain data implementation of [[SqlFielded]], backed by explicit splitter/builder functions. */
  case class SimpleSqlFielded[T](
      fields: Seq[Field[?]],
      splitter: T => List[Any],
      builder: List[Any] => T
  ) extends SqlFielded[T] {

    override protected[dao] def split(value: T): Seq[Any] =
      splitter(value)

    override protected[dao] def build(fieldValues: Seq[Any]): T =
      builder(fieldValues.toList)
  }

  /** Derives an instance for a case class via the derivation macro. */
  inline def derived[T <: Product: Mirror.ProductOf](using nm: NameMapping = NameMapping.Default): SqlFielded[T] =
    Macros.buildFielded[T]

}

/** A Field of a case class. */
sealed trait Field[T] {

  /** Name of the field (case class member) */
  def fieldName: String

  /** Columns represented by this field. */
  def columns: Seq[SqlColumn[?]]

  /** Decoder for this field. */
  def decoder: RowDecoder[T]

  /** Filler for this field. */
  def filler: RowEncoder[T]
}

object Field {

  /** A field backed by exactly one database column. */
  case class Column[T](fieldName: String, column: SqlColumn[T]) extends Field[T] {
    override def columns: Seq[SqlColumn[?]] = Seq(column)

    override def decoder: RowDecoder[T] = RowDecoder.forDataType[T](using column.dataType)

    override def filler: RowEncoder[T] = RowEncoder.forDataType[T](using column.dataType)
  }

  /** A field backed by a nested case class; its columns are renamed through the group mapping. */
  case class Group[T](
      fieldName: String,
      mapping: ColumnGroupMapping,
      columnBaseName: SqlIdentifier,
      fielded: SqlFielded[T]
  ) extends Field[T] {

    override def columns: Seq[SqlColumn[?]] =
      for (inner <- fielded.columns)
        yield inner.copy(id = mapping.map(columnBaseName, inner.id))

    override def decoder: RowDecoder[T] = fielded.rowDecoder

    override def filler: RowEncoder[T] = fielded.rowEncoder
  }
}
-------------------------------------------------------------------------------- /src/main/scala/usql/dao/SqlTabular.scala: --------------------------------------------------------------------------------
package usql.dao

import usql.{RowEncoder, RowDecoder, SqlIdentifier}

import scala.deriving.Mirror

/** Maps some thing to a whole table */
trait SqlTabular[T] extends SqlFielded[T] {

  /** Name of the table. */
  def tableName: SqlIdentifier

  /** Alias this table to be used in joins */
  def alias(name: String): Alias[T] = Alias(name, this)
}

object SqlTabular {

  /**
   * Derive an instance for a case class.
   *
   * Use [[ColumnName]] to control column names.
   *
   * Use [[TableName]] to control table names.
   *
   * @param nm
   *   name mapping strategy.
   */
  inline def derived[T <: Product: Mirror.ProductOf](using nm: NameMapping = NameMapping.Default): SqlTabular[T] =
    Macros.buildTabular[T]

  /** Pairs a table name with an existing [[SqlFielded]], delegating all field handling to it. */
  case class SimpleTabular[T](
      tableName: SqlIdentifier,
      fielded: SqlFielded[T]
  ) extends SqlTabular[T] {
    override def fields: Seq[Field[?]] = fielded.fields

    override protected[dao] def split(value: T): Seq[Any] = fielded.split(value)

    override protected[dao] def build(fieldValues: Seq[Any]): T = fielded.build(fieldValues)
  }
}
-------------------------------------------------------------------------------- /src/main/scala/usql/profiles/BasicProfile.scala: --------------------------------------------------------------------------------
package usql.profiles

import usql.DataType

import java.sql.{JDBCType, PreparedStatement, ResultSet, Timestamp}
import java.time.Instant
import java.util
import scala.language.implicitConversions
import scala.reflect.ClassTag

/** JDBC [[DataType]] instances for common primitive and collection types. */
trait BasicProfile {
  implicit val intType: DataType[Int] = DataType.simple(JDBCType.INTEGER,
_.getInt(_), _.setInt(_, _))

  implicit val longType: DataType[Long] = DataType.simple(JDBCType.BIGINT, _.getLong(_), _.setLong(_, _))

  implicit val shortType: DataType[Short] = DataType.simple(JDBCType.SMALLINT, _.getShort(_), _.setShort(_, _))

  implicit val byteType: DataType[Byte] = DataType.simple(JDBCType.TINYINT, _.getByte(_), _.setByte(_, _))

  implicit val booleanType: DataType[Boolean] = DataType.simple(JDBCType.BOOLEAN, _.getBoolean(_), _.setBoolean(_, _))

  implicit val floatType: DataType[Float] = DataType.simple(JDBCType.FLOAT, _.getFloat(_), _.setFloat(_, _))

  implicit val doubleType: DataType[Double] = DataType.simple(JDBCType.DOUBLE, _.getDouble(_), _.setDouble(_, _))

  // Scala BigDecimal is unwrapped to java.math.BigDecimal for the JDBC setter.
  implicit val bigDecimalType: DataType[BigDecimal] =
    DataType.simple(JDBCType.DECIMAL, _.getBigDecimal(_), (ps, idx, v) => ps.setBigDecimal(idx, v.underlying()))

  implicit val stringType: DataType[String] = DataType.simple(JDBCType.VARCHAR, _.getString(_), _.setString(_, _))

  implicit val timestampType: DataType[Timestamp] =
    DataType.simple(JDBCType.TIMESTAMP, _.getTimestamp(_), _.setTimestamp(_, _))

  // Instants are stored as SQL timestamps.
  implicit val instantType: DataType[Instant] = timestampType.adapt[Instant](_.toInstant, Timestamp.from)

  /** Wraps any DataType into an optional one; None is written as SQL NULL. */
  implicit def optionType[T](using dt: DataType[T]): DataType[Option[T]] = new DataType[Option[T]] {
    override def extractBySqlIdx(cIdx: Int, rs: ResultSet): Option[T] = {
      dt.extractOptionalBySqlIdx(cIdx, rs)
    }

    override def fillBySqlIdx(pIdx: Int, ps: PreparedStatement, value: Option[T]): Unit = {
      value match {
        case None => ps.setNull(pIdx, jdbcType.getVendorTypeNumber)
        case Some(v) => dt.fillBySqlIdx(pIdx, ps, v)
      }
    }

    // Reuses the underlying type's JDBC type (also used for setNull above).
    override def jdbcType: JDBCType = dt.jdbcType
  }

  implicit val arrayType: DataType[java.sql.Array] = new DataType[java.sql.Array] {
    override def jdbcType: JDBCType = JDBCType.ARRAY

override def extractBySqlIdx(cIdx: Int, rs: ResultSet): java.sql.Array = { 55 | rs.getArray(cIdx) 56 | } 57 | 58 | override def fillBySqlIdx(pIdx: Int, ps: PreparedStatement, value: java.sql.Array): Unit = { 59 | ps.setArray(pIdx, value) 60 | } 61 | } 62 | 63 | implicit val stringArray: DataType[Seq[String]] = arrayType.adaptWithPs( 64 | _.getArray.asInstanceOf[Array[String]].toSeq, 65 | (v, ps) => { 66 | val array = ps.getConnection.createArrayOf(JDBCType.VARCHAR.toString, v.toArray) 67 | array 68 | } 69 | ) 70 | 71 | implicit val stringList: DataType[List[String]] = stringArray.adapt(_.toList, identity) 72 | 73 | implicit val intArray: DataType[Seq[Int]] = arrayType.adaptWithPs( 74 | _.getArray.asInstanceOf[Array[Int]].toSeq, 75 | (v, ps) => { 76 | val array = ps.getConnection.createArrayOf(JDBCType.INTEGER.toString, v.toArray) 77 | array 78 | } 79 | ) 80 | } 81 | 82 | object BasicProfile extends BasicProfile 83 | -------------------------------------------------------------------------------- /src/main/scala/usql/profiles/H2Profile.scala: -------------------------------------------------------------------------------- 1 | package usql.profiles 2 | 3 | trait H2Profile extends BasicProfile {} 4 | 5 | object H2Profile extends H2Profile 6 | -------------------------------------------------------------------------------- /src/main/scala/usql/profiles/PostgresProfile.scala: -------------------------------------------------------------------------------- 1 | package usql.profiles 2 | 3 | import usql.DataType 4 | import java.sql.{JDBCType, PreparedStatement, ResultSet} 5 | import java.util.UUID 6 | 7 | trait PostgresProfile extends BasicProfile { 8 | implicit val uuidType: DataType[UUID] = new DataType[UUID] { 9 | override def jdbcType: JDBCType = JDBCType.OTHER 10 | 11 | override def extractBySqlIdx(cIdx: Int, rs: ResultSet): UUID = rs.getObject(cIdx, classOf[UUID]) 12 | 13 | override def fillBySqlIdx(pIdx: Int, ps: PreparedStatement, value: UUID): Unit = 
ps.setObject(pIdx, value)
  }
}

object PostgresProfile extends PostgresProfile
-------------------------------------------------------------------------------- /src/test/scala/com/example/example.sc: --------------------------------------------------------------------------------
import usql.*
import usql.dao.*
import usql.profiles.H2Profile.given

import java.sql.{Connection, DriverManager}
import scala.util.{Try, Using}

Class.forName("org.h2.Driver")

// In-memory database that survives until the VM exits.
val jdbcUrl = "jdbc:h2:mem:hello;DB_CLOSE_DELAY=-1"
given cp: ConnectionProvider with {
  override def withConnection[T](f: Connection ?=> T): T = {
    Using.resource(DriverManager.getConnection(jdbcUrl)) { c =>
      f(using c)
    }
  }
}

// Simple Actions

sql"CREATE TABLE person (id INT PRIMARY KEY, name TEXT)"
  .execute()

sql"INSERT INTO person (id, name) VALUES (${1}, ${"Alice"})"
  .execute()

sql"INSERT INTO person (id, name) VALUES (${2}, ${"Bob"})"
  .execute()

// Simple Queries

val all: Vector[(Int, String)] = sql"SELECT id, name FROM person".query.all[(Int, String)]()
println(s"All=${all}")

// Constant Parts of the query

val one: Option[(Int, String)] = sql"SELECT id, name FROM #${"person"} WHERE id = ${1}".query.one[(Int, String)]()
println(s"One=${one}")

val ids = Seq(1, 2, 3)
val names = sql"SELECT name FROM person WHERE id IN (${SqlParameters(ids)})".query.all[String]()
println(s"Names=${names}")

// Inserts

sql"INSERT INTO person (id, name) VALUES(?, ?)".one((3, "Charly")).update.run()
sql"INSERT INTO person (id, name) VALUES(?, ?)"
  .batch(
    Seq(
      4 -> "Dave",
      5 -> "Emil"
    )
  )
  .run()

// .get is acceptable here: demo script, COUNT(*) always returns a row.
sql"SELECT COUNT(*) FROM person".query.one[Int]().get

// Reusable Parts
val select = sql"SELECT id, name FROM person"
val selectAlice = (select + sql" WHERE id = ${1}").query.one[(Int, String)]()
println(s"Alice: ${selectAlice}")

// Transactions

// The second insert violates the primary key, rolling back the whole transaction.
Try {
  transaction {
    sql"INSERT INTO person(id, name) VALUES(${100}, ${"Duplicate"})".execute()
    sql"INSERT INTO person(id, name) VALUES(${100}, ${"Duplicate 2"})".execute()
  }
}

// Dao

case class Person(
    id: Int,
    name: String
) derives SqlTabular

object Person extends KeyedCrudBase[Int, Person] {
  override def key: KeyColumnPath = cols.id

  override def keyOf(value: Person): Int = value.id

  override lazy val tabular: SqlTabular[Person] = summon
}

println(s"All Persons: ${Person.findAll()}")

Person.insert(Person(6, "Fritz"))
Person.update(Person(6, "Franziska"))
println(Person.findByKey(6)) // Person(6, Franziska)

// Simple Queries, interpolating columns and the table directly

val allAgain: Vector[(Int, String)] =
  sql"SELECT ${Person.cols.id}, ${Person.cols.name} FROM ${Person}".query.all[(Int, String)]()

println(s"allAgain=${allAgain}")
-------------------------------------------------------------------------------- /src/test/scala/usql/AutoGeneratedUpdateTest.scala: --------------------------------------------------------------------------------
package usql

import usql.dao.{ColumnPath, KeyedCrudBase, SqlColumnar, SqlTabular}
import usql.util.TestBaseWithH2

class AutoGeneratedUpdateTest extends TestBaseWithH2 {

  override protected def baseSql: String =
    """
      |CREATE TABLE tenant (
      | id SERIAL NOT NULL PRIMARY KEY,
      | name TEXT
      |);
      |""".stripMargin

  case class Tenant(
      id: Int,
      name: Option[String]
  ) derives SqlTabular

  object Tenant extends KeyedCrudBase[Int, Tenant] {
    override def key: KeyColumnPath = cols.id

    override lazy val tabular: SqlTabular[Tenant] = summon
  }

it should "be possible to insert values" in {
    Tenant.findAll() shouldBe empty
    val sample = Tenant(1, Some("Hello World"))
    Tenant.insert(sample)
    Tenant.findAll() shouldBe Seq(sample)
  }

  it should "be possible to use auto generated keys" in {
    // id column is SERIAL: the database assigns 1 and 2.
    sql"INSERT INTO tenant (name) VALUES (${"Alice"})".update.run()
    sql"INSERT INTO tenant (name) VALUES (${"Bob"})".update.run()
    Tenant.findAll() should contain theSameElementsAs Seq(
      Tenant(1, Some("Alice")),
      Tenant(2, Some("Bob"))
    )
  }

  it should "be possible to return auto generated keys" in {
    val id1 = sql"INSERT INTO tenant (name) VALUES (${"Alice"})".update.runAndGetGenerated[Int]()
    val id2 = sql"INSERT INTO tenant (name) VALUES (${"Bob"})".update.runAndGetGenerated[Int]()
    Tenant.findAll() should contain theSameElementsAs Seq(
      Tenant(id1, Some("Alice")),
      Tenant(id2, Some("Bob"))
    )
  }
}
-------------------------------------------------------------------------------- /src/test/scala/usql/HelloDbTest.scala: --------------------------------------------------------------------------------
package usql

import usql.util.TestBaseWithH2

import java.sql.SQLException

class HelloDbTest extends TestBaseWithH2 {

  override protected def baseSql: String =
    """
      |CREATE TABLE "user" (id INT PRIMARY KEY, name VARCHAR);
      |""".stripMargin

  val tableName = SqlIdentifier.fromString("user")

  it should "work" in {
    sql"""INSERT INTO "user" (id, name) VALUES (${1}, ${"Hello World"})""".update.run()
    sql"""INSERT INTO "user" (id, name) VALUES (${3}, ${"How are you?"})""".update.run()

    withClue("it should be possible to build various result row parsers") {
      summon[RowDecoder[EmptyTuple]]
      summon[RowDecoder[Int *: EmptyTuple]]
      summon[RowDecoder[Int]]
      summon[RowDecoder[(Int, String)]]
    }

    sql"""SELECT id, name FROM "user" WHERE id=${1}""".query.one[(Int, String)]() shouldBe Some(1 -> "Hello World")

    sql"""SELECT id, name FROM "user" WHERE id=${2}""".query.one[(Int, String)]() shouldBe None

    sql"""SELECT id, name FROM "user" ORDER BY id""".query.all[(Int, String)]() shouldBe Seq(
      1 -> "Hello World",
      3 -> "How are you?"
    )

    withClue("It should allow inferenced return types") {
      val result: Seq[(Int, String)] = sql"""SELECT id, name FROM "user" ORDER BY id""".query.all()
      result shouldBe Seq(
        1 -> "Hello World",
        3 -> "How are you?"
      )
    }
  }

  it should "allow hash replacements" in {
    // #${...} splices raw (unparameterized) SQL.
    sql"""SELECT id, name FROM #${"\"user\""} WHERE id=${1}""".query.one[(Int, String)]() shouldBe empty
  }

  it should "allow identifiers" in {
    val userTable = SqlIdentifier.fromString("user")
    sql"""SELECT id, name FROM ${userTable}""".query.one[(Int, String)]() shouldBe empty
  }

  it should "allow batch inserts" in {
    val batchInsert = sql"""INSERT INTO "user" (id, name) VALUES(?,?)""".batch(
      Seq(
        1 -> "Hello",
        2 -> "World"
      )
    )
    val response = batchInsert.run()
    response shouldBe Seq(1, 1)

    val got = sql"""SELECT id, name FROM "user" ORDER BY ID""".query.all[(Int, String)]()
    got shouldBe Seq(
      1 -> "Hello",
      2 -> "World"
    )
  }

  it should "allow transactions" in {
    val insertCall = sql"INSERT INTO ${tableName} (id, name) VALUES(${1}, ${"Alice"})"
    // Second insert violates the primary key, the transaction must roll back completely.
    intercept[SQLException] {
      transaction {
        insertCall.execute()
        insertCall.execute()
      }
    }
    sql"SELECT COUNT(*) FROM ${tableName}".query.one[Int]() shouldBe Some(0)

    transaction {
      insertCall.execute()
    }

    sql"SELECT COUNT(*) FROM ${tableName}".query.one[Int]() shouldBe Some(1)
  }

  it should "allow in queries" in {
    sql"""INSERT INTO "user" (id, name) VALUES (${1}, ${"Alice"})""".update.run()
    sql"""INSERT INTO "user" (id, name) VALUES (${3}, ${"Bob"})""".update.run()

    val ids = Seq(1, 2, 3)

    val got =
      sql"""
        SELECT name FROM "user" WHERE id IN (${SqlParameters(ids)})
      """.query.all[String]()

    got should contain theSameElementsAs Seq("Alice", "Bob")

    sql"""
      SELECT name FROM "user" WHERE id IN (${SqlParameters(Seq(9, 8, 7, 6))})
    """.query.all[String]() shouldBe empty

    sql"""
      SELECT name FROM "user" WHERE id IN (${SqlParameters(Nil: Seq[Int])})
    """.query.all[String]() shouldBe empty
  }
}
-------------------------------------------------------------------------------- /src/test/scala/usql/SqlIdentifierTest.scala: --------------------------------------------------------------------------------
package usql

import usql.util.TestBase

class SqlIdentifierTest extends TestBase {
  "fromString" should "automatically quote" in {
    SqlIdentifier.fromString("foo") shouldBe SqlIdentifier("foo", false)
    SqlIdentifier.fromString("id") shouldBe SqlIdentifier("id", false)
    // "user" is a reserved word and must come out quoted.
    SqlIdentifier.fromString("user") shouldBe SqlIdentifier("user", true)
    SqlIdentifier.fromString("\"foo\"") shouldBe SqlIdentifier("foo", true)
    intercept[IllegalArgumentException] {
      SqlIdentifier.fromString("\"id\"\"")
    }
  }
}
-------------------------------------------------------------------------------- /src/test/scala/usql/SqlInterpolationTest.scala: --------------------------------------------------------------------------------
package usql

import SqlInterpolationParameter.{Empty, SqlParameter}
import usql.profiles.BasicProfile.*
import usql.util.TestBase

class SqlInterpolationTest extends TestBase {
  it should "work" in {
    val baz = 123
    val buz = "Hello"
    sql"foo ${baz} bar" shouldBe Sql(
      Seq(("foo ", SqlParameter(baz)), (" bar", SqlInterpolationParameter.Empty))
    )
    val withParams = sql"foo ${baz} bar ${buz}"
    withParams.sql shouldBe "foo ? bar ?"

    withParams shouldBe Sql(
      Seq(("foo ", SqlParameter(baz)), (" bar ", SqlParameter("Hello")))
    )
    sql"${baz}" shouldBe Sql(
      Seq(
        ("", SqlParameter(baz))
      )
    )

    // Identifiers are rendered inline, not as ? placeholders.
    val identifier = SqlIdentifier.fromString("table1")
    val withSingle = sql"select * from ${identifier}"
    withSingle shouldBe Sql(
      Seq("select * from " -> SqlInterpolationParameter.IdentifierParameter(identifier))
    )
    withSingle.sql shouldBe "select * from table1"

    val identifiers =
      Seq(
        SqlIdentifier.fromString("a"),
        SqlIdentifier.fromString("b")
      )

    val withIdentifiers = sql"select ${identifiers} from ${identifier} where id = ${2}"
    withIdentifiers shouldBe Sql(
      Seq(
        "select " -> SqlInterpolationParameter.IdentifiersParameter(identifiers),
        " from " -> SqlInterpolationParameter.IdentifierParameter(identifier),
        " where id = " -> SqlInterpolationParameter.SqlParameter(2)
      )
    )
    withIdentifiers.sql shouldBe "select a,b from table1 where id = ?"
  }

  it should "allow stripMargin" in {
    sql"""
         |Hello ${1}
         |World ${2}
         |""".stripMargin shouldBe Sql(
      Seq(
        "\nHello " -> SqlParameter(1),
        "\nWorld " -> SqlParameter(2),
        "\n" -> Empty
      )
    )
  }

  it should "allow concatenation" in {
    sql"HELLO ${1}" + sql"WORLD ${2}" shouldBe Sql(
      Seq(
        "HELLO " -> SqlParameter(1),
        "WORLD " -> SqlParameter(2)
      )
    )
  }

  it should "allow embedded frags" in {
    val inner = sql"SELECT * FROM foo"
    val combined = sql"${inner} WHERE id = ${1} AND bar = ${2}"
    combined shouldBe Sql(
      Seq(
        "SELECT * FROM foo" -> Empty,
        " WHERE id = " -> SqlParameter(1),
        " AND bar = " -> SqlParameter(2)
      )
    )
  }

  it should "also work in another case" in {
    val inner = sql"C = ${2}"
    val foo = sql"HELLO a = ${1} AND"
    val combined = (sql"HELLO a = ${1} AND ${inner}")
    combined shouldBe Sql(
      Seq(
        "HELLO a = " -> SqlParameter(1),
        " AND " -> Empty,
        "C = " -> SqlParameter(2)
      )
    )
    combined.sql shouldBe "HELLO a = ? AND C = ?"
  }
}
-------------------------------------------------------------------------------- /src/test/scala/usql/dao/ColumnPathTest.scala: --------------------------------------------------------------------------------
package usql.dao

import usql.SqlIdentifier
import usql.util.TestBase
import usql.profiles.BasicProfile.*

class ColumnPathTest extends TestBase {

  case class SubSubElement(
      foo: Boolean
  ) derives SqlFielded

  case class SubElement(
      a: Int,
      b: String,
      @ColumnGroup
      sub2: SubSubElement
  ) derives SqlFielded

  case class Sample(
      x: Int,
      @ColumnGroup(ColumnGroupMapping.Anonymous)
      sub: SubElement
  ) derives SqlFielded

  val path: ColumnPath[Sample, Sample] = ColumnPath.make

  it should "fetch identifiers" in {
    path.x.buildIdentifier shouldBe SqlIdentifier.fromString("x")
    // Anonymous group mapping: members of `sub` get no prefix.
    path.sub.a.buildIdentifier shouldBe SqlIdentifier.fromString("a")
    // A group itself has no single identifier.
    intercept[IllegalStateException] {
      path.sub.sub2.buildIdentifier
    }
    path.sub.sub2.foo.buildIdentifier shouldBe SqlIdentifier.fromString("sub2_foo")
  }

  it should "fetch elements" in {
    val sample = Sample(
      100,
      sub = SubElement(
        a = 101,
        b = "Hello",
        sub2 = SubSubElement(
          true
        )
      )
    )
    val getter1 = path.x.buildGetter
    val getter2 = path.sub.sub2.foo.buildGetter
    getter1(sample) shouldBe 100
    getter2(sample) shouldBe true
  }
}
-------------------------------------------------------------------------------- /src/test/scala/usql/dao/KeyedCrudBaseTest.scala: --------------------------------------------------------------------------------
package usql.dao

import usql.SqlIdentifier
import usql.profiles.BasicProfile.*
import usql.util.TestBaseWithH2
import scala.language.implicitConversions

class KeyedCrudBaseTest extends TestBaseWithH2 {
  override
protected def baseSql: String = 10 | """ 11 | |CREATE TABLE "user" ( 12 | | id INT PRIMARY KEY, 13 | | name TEXT, 14 | | age INT 15 | |) 16 | |""".stripMargin 17 | 18 | case class User( 19 | id: Int, 20 | name: Option[String], 21 | age: Option[Int] 22 | ) derives SqlTabular 23 | 24 | object UserCrd extends KeyedCrudBase[Int, User] { 25 | override def key: KeyColumnPath = cols.id 26 | 27 | override lazy val tabular: SqlTabular[User] = summon 28 | } 29 | 30 | val sample1 = User(1, Some("Alice"), Some(42)) 31 | val sample2 = User(2, Some("Bob"), None) 32 | val sample3 = User(3, None, None) 33 | 34 | it should "properly escape" in { 35 | UserCrd.tabular.tableName shouldBe SqlIdentifier.fromString("user") 36 | UserCrd.tabular.columns.map(_.id) shouldBe Seq( 37 | "id", 38 | "name", 39 | "age" 40 | ).map(SqlIdentifier.fromString) 41 | } 42 | 43 | it should "provide basic crd features" in { 44 | UserCrd.countAll() shouldBe 0 45 | UserCrd.insert(sample1) 46 | UserCrd.insert(sample2) 47 | UserCrd.countAll() shouldBe 2 48 | UserCrd.findAll() should contain theSameElementsAs Seq(sample1, sample2) 49 | UserCrd.findByKey(1) shouldBe Some(sample1) 50 | UserCrd.findByKey(0) shouldBe empty 51 | UserCrd.deleteByKey(0) shouldBe 0 52 | UserCrd.deleteByKey(1) shouldBe 1 53 | UserCrd.findAll() should contain theSameElementsAs Seq(sample2) 54 | UserCrd.deleteAll() shouldBe 1 55 | UserCrd.countAll() shouldBe 0 56 | UserCrd.findAll() shouldBe empty 57 | } 58 | 59 | it should "provide updates" in { 60 | UserCrd.insert(Seq(sample1, sample2)) 61 | val sample2x = sample2.copy( 62 | age = Some(100) 63 | ) 64 | UserCrd.update(sample3) shouldBe 0 // was not existant 65 | UserCrd.update(sample2x) shouldBe 1 66 | UserCrd.findAll() should contain theSameElementsAs Seq( 67 | sample1, 68 | sample2x 69 | ) 70 | 71 | val sample1x = sample1.copy(name = None) 72 | UserCrd.update(sample1x) shouldBe 1 73 | 74 | UserCrd.findAll() should contain theSameElementsAs Seq( 75 | sample1x, 76 | sample2x 77 | ) 78 | 
} 79 | } 80 | -------------------------------------------------------------------------------- /src/test/scala/usql/dao/NameMappingTest.scala: -------------------------------------------------------------------------------- 1 | package usql.dao 2 | 3 | import NameMapping.Default 4 | import usql.SqlIdentifier 5 | import usql.util.TestBase 6 | 7 | class NameMappingTest extends TestBase { 8 | "Default" should "work" in { 9 | Default.caseClassToTableName("foo.bar.MySuperClass") shouldBe SqlIdentifier.fromString("my_super_class") 10 | Default.caseClassToTableName("foo.bar.User") shouldBe SqlIdentifier.fromString("user") 11 | Default.columnToSql("id") shouldBe SqlIdentifier.fromString("id") 12 | Default.columnToSql("myData") shouldBe SqlIdentifier.fromString("my_data") 13 | 14 | } 15 | 16 | "snakeCase" should "work" in { 17 | NameMapping.snakeCase("foo") shouldBe "foo" 18 | NameMapping.snakeCase("fooBar") shouldBe "foo_bar" 19 | NameMapping.snakeCase("XyzTCPStream") shouldBe "xyz_tcpstream" 20 | NameMapping.snakeCase("BOOM") shouldBe "boom" 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/test/scala/usql/dao/SimpleJoinTest.scala: -------------------------------------------------------------------------------- 1 | package usql.dao 2 | 3 | import usql.* 4 | import usql.util.TestBaseWithH2 5 | import scala.language.implicitConversions 6 | 7 | class SimpleJoinTest extends TestBaseWithH2 { 8 | 9 | override protected def baseSql: String = 10 | """ 11 | |CREATE TABLE person ( 12 | | id INT PRIMARY KEY, 13 | | name TEXT NOT NULL, 14 | | level_id INT 15 | |); 16 | | 17 | |CREATE TABLE level( 18 | | id INT PRIMARY KEY, 19 | | level_name TEXT 20 | |); 21 | |""".stripMargin 22 | 23 | case class Person( 24 | id: Int, 25 | name: String, 26 | levelId: Option[Int] = None 27 | ) derives SqlTabular 28 | 29 | object Person extends KeyedCrudBase[Int, Person] { 30 | override def key: KeyColumnPath = cols.id 31 | 32 | override lazy val 
tabular: SqlTabular[Person] = summon
  }

  case class Level(
      id: Int,
      levelName: String
  ) derives SqlTabular

  object Level extends KeyedCrudBase[Int, Level] {
    override def key: KeyColumnPath = cols.id

    override lazy val tabular: SqlTabular[Level] = summon
  }

  /** Shared fixture: four persons (one with a dangling level_id) and three levels. */
  trait Env {
    val person1 = Person(1, "Alice")
    val person2 = Person(2, "Bob", Some(1))
    val person3 = Person(3, "Charly", Some(2))
    val person4 = Person(4, "Secret", Some(999))

    Person.insert(person1)
    Person.insert(person2)
    Person.insert(person3)
    Person.insert(person4)

    val level1 = Level(1, "Administrator")
    val level2 = Level(2, "Regular")
    val level3 = Level(3, "Nobody")

    Level.insert(level1)
    Level.insert(level2)
    Level.insert(level3)
  }

  // Table aliases used inside the raw SQL below ("p" and "l").
  val person = Person.alias("p")
  val level  = Level.alias("l")

  it should "do an easy inner join" in new Env {
    val joined =
      sql"""SELECT ${person.columns}, ${level.columns}
            FROM ${person} INNER JOIN ${level}
            WHERE p.level_id = l.id
         """.query.all[(Person, Level)]()

    // person1 (no level) and person4 (dangling level_id) are filtered out.
    joined should contain theSameElementsAs Seq(
      (person2, level1),
      (person3, level2)
    )
  }

  it should "do an easy left join" in new Env {
    val joined =
      sql"""SELECT ${person.columns}, ${level.columns}
            FROM ${person} LEFT JOIN ${level} ON p.level_id = l.id
         """.query.all[(Person, Option[Level])]()

    // Left join keeps all persons; unmatched rows decode as None.
    joined should contain theSameElementsAs Seq(
      (person1, None),
      (person2, Some(level1)),
      (person3, Some(level2)),
      (person4, None)
    )
  }

  it should "provide access to aliased column names" in new Env {
    person.col.name.buildIdentifier.toString shouldBe "p.name"
    val selected =
      sql"""
         SELECT ${person.col.name} FROM ${person}
         """.query.all[String]()
    selected should contain theSameElementsAs Person.findAll().map(_.name)
  }
}
// --------------------------------------------------------------------------
// src/test/scala/usql/dao/SqlColumnarTest.scala
package usql.dao

import usql.{SqlIdentifier, dao}
import usql.profiles.BasicProfile.given
import usql.util.TestBase

/** Derivation of [[SqlTabular]] with and without naming annotations. */
class SqlColumnarTest extends TestBase {
  case class Sample(
      name: String,
      age: Int
  )

  "Tabular" should "be derivable" in {
    val derived = SqlTabular.derived[Sample]
    derived.columns.map(_.id) shouldBe Seq(SqlIdentifier.fromString("name"), SqlIdentifier.fromString("age"))
    derived.tableName shouldBe SqlIdentifier.fromString("sample")
  }

  @TableName("samplename")
  case class SampleWithAnnotations(
      @ColumnName("my_name") name: String,
      age: Int
  )

  it should "work with annotations" in {
    val derived = SqlTabular.derived[SampleWithAnnotations]
    // Annotations override the default snake_case naming.
    derived.tableName shouldBe SqlIdentifier.fromString("samplename")
    derived.columns.map(_.id) shouldBe Seq(SqlIdentifier.fromString("my_name"), SqlIdentifier.fromString("age"))
  }
}
// --------------------------------------------------------------------------
// src/test/scala/usql/dao/SqlCrdBaseTest.scala
package usql.dao

import usql.SqlIdentifier
import usql.profiles.BasicProfile.*
import usql.util.TestBaseWithH2

class SqlCrdBaseTest extends TestBaseWithH2 {
  override protected def baseSql: String =
    """
      |CREATE TABLE coordinate(
      | id INT PRIMARY KEY,
      | x INT,
      | y INT
      |);
      |
      |
      |CREATE TABLE subcoord (
      | id INT PRIMARY KEY,
      | from_x DOUBLE,
      | from_y DOUBLE,
      | x_to DOUBLE,
      | y_to DOUBLE
      |);
      |""".stripMargin

  case class Coordinate(id: Int, x: Int, y: Int) derives SqlTabular

  object CoordinateCrd extends CrdBase[Coordinate] {
    override lazy val tabular: SqlTabular[Coordinate] = summon
  }

  val sample  = Coordinate(0, 5, 6)
  val samples = Seq(
    Coordinate(1, 10, 20),
    Coordinate(2, 20, 30)
  )

  it should "do the usual operations with one item" in {
    CoordinateCrd.countAll() shouldBe 0
    CoordinateCrd.findAll() shouldBe empty
    CoordinateCrd.insert(sample) shouldBe 1
    CoordinateCrd.countAll() shouldBe 1
    CoordinateCrd.findAll() shouldBe Seq(sample)
    CoordinateCrd.deleteAll() shouldBe 1
    CoordinateCrd.findAll() shouldBe empty
  }

  it should "do the usual operations with many items" in {
    CoordinateCrd.insert(samples) shouldBe samples.size
    CoordinateCrd.countAll() shouldBe 2
    CoordinateCrd.findAll() should contain theSameElementsAs samples
    CoordinateCrd.deleteAll() shouldBe 2
    CoordinateCrd.countAll() shouldBe 0
    CoordinateCrd.findAll() shouldBe empty
  }

  case class SubCoord(x: Double, y: Double) derives SqlFielded

  @TableName("subcoord")
  case class WithSubCoords(
      id: Int,
      from: SubCoord,
      @ColumnGroup(ColumnGroupMapping.Pattern("%c_to"))
      to: SubCoord
  ) derives SqlTabular

  object WithSubCoords extends KeyedCrudBase[Int, WithSubCoords] {
    override def key: KeyColumnPath = cols.id

    override lazy val tabular: SqlTabular[WithSubCoords] = summon
  }

  it should "work for nested columns" in {
    val examples = Seq(
      WithSubCoords(3, SubCoord(3.4, 5.6), SubCoord(7.8, 9.7)),
      WithSubCoords(4, SubCoord(2.4, 1.6), SubCoord(2.8, 1.7))
    )
    WithSubCoords.insert(examples)
    WithSubCoords.findAll() should contain theSameElementsAs examples
    // Key 1 does not exist, so nothing changes.
    WithSubCoords.deleteByKey(1)
    WithSubCoords.findAll() should contain theSameElementsAs examples
    WithSubCoords.deleteByKey(3)
    WithSubCoords.findAll() should contain theSameElementsAs Seq(examples(1))
    WithSubCoords.findByKey(3) shouldBe None
    WithSubCoords.findByKey(4) shouldBe Some(examples(1))
  }
}
// --------------------------------------------------------------------------
// src/test/scala/usql/dao/SqlFieldedTest.scala
package usql.dao

import usql.SqlIdentifier
import usql.util.TestBase
import usql.profiles.BasicProfile.*

/** Field/column naming for a fielded type with annotations and a column group. */
class SqlFieldedTest extends TestBase {
  case class Coordinate(
      x: Int,
      y: Int
  ) derives SqlFielded

  @TableName("test_person")
  case class Person(
      id: Int,
      @ColumnName("long_name")
      name: String,
      age: Option[Int],
      @ColumnGroup
      coordinate: Coordinate
  ) derives SqlTabular

  object Person extends KeyedCrudBase[Int, Person] {
    override def key: KeyColumnPath = cols.id

    override lazy val tabular: SqlTabular[Person] = summon
  }

  it should "work" in {
    val adapter = summon[SqlFielded[Person]]
    adapter.fields.map(_.fieldName) shouldBe Seq("id", "name", "age", "coordinate")
    adapter.columns
      .map(_.id) shouldBe Seq("id", "long_name", "age", "coordinate_x", "coordinate_y").map(SqlIdentifier.fromString)

    // The root path is a group, not a column — building its identifier must fail.
    intercept[IllegalStateException] {
      adapter.cols.buildIdentifier
    }
    adapter.cols.name.buildIdentifier shouldBe SqlIdentifier.fromString("long_name")
    Person.cols.name.buildIdentifier shouldBe SqlIdentifier.fromString("long_name")

    adapter.cols.coordinate.x.buildIdentifier shouldBe SqlIdentifier.fromString("coordinate_x")
    Person.cols.coordinate.x.buildIdentifier shouldBe SqlIdentifier.fromString("coordinate_x")
  }
}
// --------------------------------------------------------------------------
// src/test/scala/usql/dao/SqlTabularTest.scala
package usql.dao

import usql.*
import usql.util.TestBase
import usql.profiles.BasicProfile.*

/** Column naming for nested (grouped) fields with pattern mappings. */
class SqlTabularTest extends TestBase {
  case class Nested(
      x: Double,
      y: Double
  ) derives SqlFielded

  case class WithNested(
      @ColumnGroup(ColumnGroupMapping.Pattern(pattern = "a_%c"))
      a: Nested,
      @ColumnGroup(ColumnGroupMapping.Pattern(pattern = "%c_s"))
      b: Nested,
      c: Nested
  )

  it should "work for nested" in {
    val tabular = SqlTabular.derived[WithNested]
    // Three groups of two columns each.
    tabular.rowEncoder.cardinality shouldBe 6
    tabular.rowDecoder.cardinality shouldBe 6
    // "%c" is replaced by the child column name; `c` uses the default prefix mapping.
    tabular.columns.map(_.id) shouldBe Seq("a_x", "a_y", "x_s", "y_s", "c_x", "c_y").map(SqlIdentifier.fromString)
  }
}
// --------------------------------------------------------------------------
// src/test/scala/usql/util/TestBase.scala
package usql.util

import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

/** Common base for all test suites in this project. */
abstract class TestBase extends AnyFlatSpec with Matchers with BeforeAndAfterEach with BeforeAndAfterAll {}
// --------------------------------------------------------------------------
// src/test/scala/usql/util/TestBaseWithH2.scala
package usql.util

import usql.ConnectionProvider
import usql.profiles.H2Profile

import java.sql.{Connection, DriverManager}
import scala.util.{Random, Using}

/**
 * Test base that creates a fresh, randomly named in-memory H2 database before
 * each test case and closes it afterwards. Subclasses provide [[baseSql]] to
 * set up their schema.
 */
abstract class TestBaseWithH2 extends TestBase with H2Profile {

  /** Schema/setup SQL run before each test; may contain multiple `;`-separated statements. */
  protected def baseSql: String = ""

  // Kept open for the lifetime of one test so the in-memory DB survives
  // between the short-lived connections handed out by `cp`.
  private var _rootConnection: Option[Connection] = None
  private var _url: Option[String]                = None

  /** JDBC URL of the current per-test database; fails outside a test. */
  protected def jdbcUrl: String = _url.getOrElse {
    throw new IllegalStateException(s"No connection")
  }

  /** Hands out a fresh connection per call and closes it when `f` returns. */
  given cp: ConnectionProvider with {
    override def withConnection[T](f: Connection ?=> T): T = {
      Using.resource(DriverManager.getConnection(jdbcUrl)) { c =>
        f(using c)
      }
    }
  }

  override protected def beforeEach(): Unit = {
    super.beforeEach()
    val name = "db" + Math.abs(Random.nextLong())
    // Touch the driver class so it is loaded and registers with DriverManager.
    classOf[org.h2.Driver].toString
    // DB_CLOSE_DELAY=-1 keeps the in-memory DB alive while _rootConnection is open.
    val url        = s"jdbc:h2:mem:${name};DB_CLOSE_DELAY=-1"
    val connection = DriverManager.getConnection(url)
    _rootConnection = Some(connection)
    _url = Some(url)

    runBaseSql()
  }

  override protected def afterEach(): Unit = {
    // Closing the root connection discards the in-memory database.
    _rootConnection.foreach(_.close())
    super.afterEach()
  }

  /** Executes a single SQL statement on a fresh connection. */
  protected def runSql(sql: String): Unit = {
    cp.withConnection {
      val c = summon[Connection]
      c.prepareStatement(sql).execute()
    }
  }

  /**
   * Splits `sql` into individual statements and executes them one by one.
   *
   * Fix: previously this ignored its parameter and always split [[baseSql]],
   * so callers passing any other script silently ran the base schema instead.
   * Behavior is unchanged for the existing caller ([[runBaseSql]]).
   */
  protected def runSqlMultiline(sql: String): Unit = {
    splitSql(sql).foreach(runSql)
  }

  /** Runs the subclass-provided setup SQL. */
  private def runBaseSql(): Unit = {
    runSqlMultiline(baseSql)
  }

  /** Splits a script on `;` followed by whitespace. Note: very rough, no quote/comment awareness. */
  protected def splitSql(s: String): Seq[String] = {
    s.split("(?<=;)\\s+").toSeq.map(_.trim.stripSuffix(";")).filter(_.nonEmpty)
  }

}