├── .editorconfig ├── .github ├── ISSUE_TEMPLATE.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ └── haskell.yml ├── .gitignore ├── .stylish-haskell.yaml ├── CHANGELOG.md ├── LICENSE ├── README.lhs ├── README.md ├── Setup.hs ├── persistent-typed-db.cabal ├── src └── Database │ └── Persist │ └── Typed.hs ├── stack-lts-12.yaml ├── stack.yaml └── test ├── EsqueletoSpec.hs └── Spec.hs /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | root = true 3 | 4 | [Makefile] 5 | indent_style = tab 6 | indent_size = 8 7 | end_of_line = lf 8 | charset = utf-8 9 | trim_trailing_whitespace = true 10 | insert_final_newline = true 11 | 12 | [*.yaml] 13 | indent_style = space 14 | indent_size = 2 15 | tab_width = 2 16 | end_of_line = lf 17 | charset = utf-8 18 | trim_trailing_whitespace = true 19 | insert_final_newline = true 20 | max_line_length = 80 21 | 22 | 23 | [*.{hs,md,php}] 24 | indent_style = space 25 | indent_size = 4 26 | tab_width = 4 27 | end_of_line = lf 28 | charset = utf-8 29 | trim_trailing_whitespace = true 30 | insert_final_newline = true 31 | max_line_length = 80 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 33 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | Before submitting your PR, check that you've: 2 | 3 | - [ ] Documented new APIs with [Haddock markup](https://www.haskell.org/haddock/doc/html/index.html) 4 | - [ ] Added [`@since` declarations](http://haskell-haddock.readthedocs.io/en/latest/markup.html#since) to the Haddock 5 | - [ ] Ran `stylish-haskell` on any changed files.
6 | - [ ] Adhered to the code style (see the `.editorconfig` file for details) 7 | 8 | After submitting your PR: 9 | 10 | - [ ] Update the Changelog.md file with a link to your PR 11 | - [ ] Bumped the version number if there isn't an `(unreleased)` on the Changelog 12 | - [ ] Check that CI passes (or if it fails, for reasons unrelated to your change, like CI timeouts) 13 | 14 | 17 | -------------------------------------------------------------------------------- /.github/workflows/haskell.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | push: 4 | branches: 5 | - master 6 | pull_request: 7 | types: 8 | - opened 9 | - synchronize 10 | jobs: 11 | build: 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | cabal: ["3.12"] 16 | ghc: 17 | - "8.8" 18 | - "8.10" 19 | - "9.0" 20 | - "9.2" 21 | - "9.4" 22 | - "9.6" 23 | - "9.8" 24 | - "9.10" 25 | - "9.12" 26 | env: 27 | CONFIG: "--enable-tests" 28 | steps: 29 | - uses: actions/checkout@v4 30 | - uses: haskell-actions/setup@v2 31 | id: setup-haskell-cabal 32 | with: 33 | ghc-version: ${{ matrix.ghc }} 34 | cabal-version: ${{ matrix.cabal }} 35 | - run: cabal v2-update 36 | - run: cabal v2-freeze $CONFIG 37 | - uses: actions/cache@v3 38 | with: 39 | path: | 40 | ${{ steps.setup-haskell-cabal.outputs.cabal-store }} 41 | dist-newstyle 42 | key: ${{ runner.os }}-${{ matrix.ghc }}-${{ hashFiles('cabal.project.freeze') }} 43 | restore-keys: | 44 | ${{ runner.os }}-${{ matrix.ghc }}-${{ hashFiles('cabal.project.freeze') }} 45 | ${{ runner.os }}-${{ matrix.ghc }}- 46 | - run: cabal v2-build --disable-optimization -j $CONFIG 47 | - run: cabal v2-test --disable-optimization -j $CONFIG 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | dist/ 2 | *.yaml.lock 3 | .cabal-sandbox/ 4 | cabal.sandbox.config 5 | .stack-work/ 6 | tarballs/ 7 | 
dist-newstyle/ 8 | -------------------------------------------------------------------------------- /.stylish-haskell.yaml: -------------------------------------------------------------------------------- 1 | steps: 2 | - imports: 3 | align: none 4 | list_align: with_module_name 5 | pad_module_names: false 6 | long_list_align: new_line_multiline 7 | empty_list_align: inherit 8 | list_padding: 7 # length "import " 9 | separate_lists: false 10 | space_surround: false 11 | - language_pragmas: 12 | style: vertical 13 | align: false 14 | remove_redundant: true 15 | - simple_align: 16 | cases: false 17 | top_level_patterns: false 18 | records: false 19 | - trailing_whitespace: {} 20 | indent: 4 21 | columns: 80 22 | newline: native 23 | language_extensions: 24 | - BlockArguments 25 | - DataKinds 26 | - DeriveGeneric 27 | - DerivingStrategies 28 | - DerivingVia 29 | - ExplicitForAll 30 | - FlexibleContexts 31 | - MultiParamTypeClasses 32 | - NamedFieldPuns 33 | - OverloadedStrings 34 | - QuantifiedConstraints 35 | - RecordWildCards 36 | - ScopedTypeVariables 37 | - TemplateHaskell 38 | - TypeApplications 39 | - ViewPatterns 40 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # CHANGELOG 2 | 3 | # 0.1.0.7 4 | 5 | - Support `persistent-2.14` [#18](https://github.com/parsonsmatt/persistent-typed-db/pull/18) 6 | 7 | # 0.1.0.6 8 | 9 | - Support `aeson-2` 10 | 11 | # 0.1.0.5 12 | 13 | - Support `persistent-2.13.2` 14 | 15 | # 0.1.0.4 16 | 17 | - Support `persistent-2.13` 18 | 19 | # 0.1.0.3 20 | 21 | - Support `persistent-2.12` 22 | 23 | # 0.1.0.2 24 | 25 | - Fix test suite to build with `persistent-2.11` 26 | 27 | # 0.1.0.1 28 | 29 | - Fix test suite to build for stackage 30 | 31 | # v0.1.0.0 32 | 33 | - Add support for `persistent` v2.10.0 and `persistent-template` 2.7.0 34 | 35 | # v0.0.1.1 36 | 37 | - Fix build on earlier versions of GHC. 
38 | - Use the bulk insertMany functions where possible. 39 | 40 | # v0.0.1.0 41 | 42 | - Initial Release 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Author name here (c) 2017 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above 12 | copyright notice, this list of conditions and the following 13 | disclaimer in the documentation and/or other materials provided 14 | with the distribution. 15 | 16 | * Neither the name of Author name here nor the names of other 17 | contributors may be used to endorse or promote products derived 18 | from this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /README.lhs: -------------------------------------------------------------------------------- 1 | README.md -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `persistent-typed-db` 2 | 3 | [![Build Status](https://travis-ci.org/parsonsmatt/persistent-typed-db.svg?branch=master)](https://travis-ci.org/parsonsmatt/persistent-typed-db) 4 | 5 | This library defines an alternate `SqlBackend` type for the Haskell [persistent](https://hackage.haskell.org/package/persistent) database library. 6 | The type has a phantom type parameter which allows you to write queries against multiple databases safely. 7 | 8 | # The Problem 9 | 10 | The `persistent` library uses a "handle pattern" with the `SqlBackend` type. 11 | The `SqlBackend` (or `Pool SqlBackend`) is used to provide access to "The Database." 12 | To access the database, you generally use a function like: 13 | 14 | ```haskell 15 | runDB :: SqlPersistT m a -> App a 16 | ``` 17 | 18 | However, you may have two (or more!) databases. 19 | You might have multiple database "runner" functions, like this: 20 | 21 | ```haskell 22 | runMainDB 23 | :: SqlPersistT m a -> App a 24 | runAuxDB 25 | :: SqlPersistT m a -> App a 26 | ``` 27 | 28 | Unfortunately, this isn't safe. 29 | The schemas differ, and there's nothing preventing you from using the wrong runner for the query at hand. 30 | 31 | This library allows you to differentiate between database schemata. 
32 | To demonstrate usage, we'll start with a pair of ordinary `persistent` quasiquoter blocks: 33 | 34 | ```haskell 35 | share [ mkPersist sqlSettings, mkMigrate "migrateAll" ] [persistLowerCase| 36 | 37 | User 38 | name Text 39 | age Int 40 | 41 | deriving Show Eq 42 | |] 43 | 44 | share [ mkPersist sqlSettings, mkMigrate "migrateAll" ] [persistLowerCase| 45 | 46 | AuxRecord 47 | createdAt UTCTime 48 | reason Text 49 | 50 | deriving Show Eq 51 | |] 52 | ``` 53 | 54 | These two definitions correspond to different databases. 55 | The `persistent` library uses the `SqlPersistT m` monad transformer for queries. 56 | This type is a synonym for `ReaderT SqlBackend m`. 57 | Many of the functions defined in `persistent` have a signature like this: 58 | 59 | ```haskell 60 | get :: (MonadIO m, PersistEntityBackend record ~ backend) 61 | => Key record 62 | -> ReaderT backend m (Maybe record) 63 | ``` 64 | 65 | It requires that the entity is compatible with the query monad. 66 | We're going to substitute `User` for the type variable `record`. 67 | In the initial schema definition, the `PersistEntityBackend User` is defined as `SqlBackend`. 
68 | So the type of `get`, in the original definition, is this: 69 | 70 | ```haskell 71 | get :: (MonadIO m, PersistEntityBackend User ~ SqlBackend) 72 | => Key User 73 | -> ReaderT SqlBackend m (Maybe User) 74 | ``` 75 | 76 | If we look at the type of `get` specialized to `AuxRecord`, we see this: 77 | 78 | ```haskell 79 | get :: (MonadIO m, PersistEntityBackend AuxRecord ~ SqlBackend) 80 | => Key AuxRecord 81 | -> ReaderT SqlBackend m (Maybe AuxRecord) 82 | ``` 83 | 84 | This means that we might be able to write a query like this: 85 | 86 | ```haskell 87 | impossibleQuery 88 | :: MonadIO m 89 | => SqlPersistT m (Maybe User, Maybe AuxRecord) 90 | impossibleQuery = do 91 | muser <- get (UserKey 1) 92 | maux <- get (AuxRecordKey 1) 93 | pure (muser, maux) 94 | ``` 95 | 96 | This query will fail at runtime, since the entities exist on different schemata. 97 | Likewise, there's nothing in the types to stop you from running a query against 98 | the wrong backend: 99 | 100 | ```haskell 101 | app = do 102 | runMainDB $ get (AuxRecordKey 3) 103 | runAuxDB $ get (UserKey 3) 104 | ``` 105 | 106 | # The Solution 107 | 108 | Let's solve this problem. 109 | 110 | ## Declaring the Schema 111 | 112 | We are going to create an empty datatype tag for each schema, and then we're going to use `mkSqlSettingsFor` instead of `sqlSettings`. 113 | 114 | ```haskell 115 | data MainDb 116 | data AuxDb 117 | 118 | share [ mkPersist (mkSqlSettingsFor ''MainDb), mkMigrate "migrateAll" ] [persistLowerCase| 119 | 120 | User 121 | name Text 122 | age Int 123 | 124 | deriving Show Eq 125 | |] 126 | 127 | share [ mkPersist (mkSqlSettingsFor ''AuxDb), mkMigrate "migrateAll" ] [persistLowerCase| 128 | 129 | AuxRecord 130 | createdAt UTCTime 131 | reason Text 132 | 133 | deriving Show Eq 134 | |] 135 | ``` 136 | 137 | This changes the type of the `PersistEntityBackend record` for each entity defined in the QuasiQuoter.
138 | The previous type of `PersistEntityBackend User` was `SqlBackend`, but with this change, it is now `SqlFor MainDb`. 139 | Likewise, the type of `PersistEntityBackend AuxRecord` has become `SqlFor AuxDb`. 140 | 141 | ## Using the Schema 142 | 143 | Let's look at the new type of `get` for these two records: 144 | 145 | ```haskell 146 | get :: (MonadIO m, PersistEntityBackend User ~ SqlFor MainDb) 147 | => Key User 148 | -> ReaderT (SqlFor MainDb) m (Maybe User) 149 | 150 | get :: (MonadIO m, PersistEntityBackend AuxRecord ~ SqlFor AuxDb) 151 | => Key AuxRecord 152 | -> ReaderT (SqlFor AuxDb) m (Maybe AuxRecord) 153 | ``` 154 | 155 | Now that the monad type is different, we can't use them in the same query. 156 | Our previous `impossibleQuery` now fails with a type error. 157 | 158 | The `persistent-typed-db` library defines a type synonym for `ReaderT`. 159 | It is similar to the `SqlPersistT` synonym: 160 | 161 | ```haskell 162 | type SqlPersistT = ReaderT SqlBackend 163 | type SqlPersistTFor db = ReaderT (SqlFor db) 164 | ``` 165 | 166 | When using this library, it is a good idea to define a type synonym for your databases as well. 167 | So we might also write: 168 | 169 | ```haskell 170 | type MainDbT = SqlPersistTFor MainDb 171 | type AuxDbT = SqlPersistTFor AuxDb 172 | ``` 173 | 174 | The type of our runner functions has changed, as well. 175 | Before, we accepted a `SqlPersistT`, but now, we'll accept the right query type for the database: 176 | 177 | ```haskell 178 | runMainDb :: MainDbT m a -> App a 179 | 180 | runAuxDb :: AuxDbT m a -> App a 181 | ``` 182 | 183 | We'll cover how to define these runner functions soon. 184 | 185 | ## Defining the Runner Function 186 | 187 | `persistent` defines a function `runSqlPool` that is useful for running a SQL action.
188 | The type is essentially this: 189 | 190 | ```haskell 191 | runSqlPool 192 | :: (MonadUnliftIO m, IsSqlBackend backend) 193 | => ReaderT backend m a 194 | -> Pool backend 195 | -> m a 196 | ``` 197 | 198 | `persistent-typed-db` defines a function that is a drop in replacement for this, called `runSqlPoolFor`. 199 | 200 | ```haskell 201 | runSqlPoolFor 202 | :: (MonadUnliftIO m) 203 | => SqlPersistTFor db m a 204 | -> ConnectionPoolFor db 205 | -> m a 206 | ``` 207 | 208 | It is defined by generalizing the input query and pool, and delegating to `runSqlPool`. 209 | 210 | ```haskell 211 | runSqlPoolFor query conn = 212 | runSqlPool (generalizeQuery query) (generalizePool conn) 213 | ``` 214 | 215 | Sometimes, you'll have some function that is in `SqlPersistT` that you want to use on a specialized database. 216 | This can occur with raw queries, like `rawSql` and friends, or other queries/actions that are not tied to a `PersistEntityBackend` type. 217 | In this case, you'll want to use `specializeQuery`. 218 | You will likely want to define type-specified helpers that are aliases for `specializeQuery`: 219 | 220 | ```haskell 221 | toMainQuery :: SqlPersistT m a -> MainDbT m a 222 | toMainQuery = specializeQuery 223 | 224 | toAuxQuery :: SqlPersistT m a -> AuxDbT m a 225 | toAuxQuery = specializeQuery 226 | ``` 227 | 228 | ## Constructing the Pools 229 | 230 | `persistent` (and the relevant database-specific libraries) define many functions for creating connections. 231 | We'll use [`createPostgresqlPool`](https://hackage.haskell.org/package/persistent-postgresql-2.9.0/docs/Database-Persist-Postgresql.html#v:createPostgresqlPool) as an example. 232 | This is one place where you do need to be careful, as you are tagging the database pool with the database type. 
233 | 234 | To create a specific database pool, we'll map `specializePool` over the result 235 | of `createPostgresqlPool`: 236 | 237 | ```haskell 238 | createPostgresqlPoolFor 239 | :: ConnectionString 240 | -> Int 241 | -> IO (ConnectionPoolFor db) 242 | createPostgresqlPoolFor connStr i = 243 | specializePool <$> createPostgresqlPool connStr i 244 | ``` 245 | 246 | It is a good idea to make specialized variants of this function to improve type 247 | inference and errors: 248 | 249 | ```haskell 250 | createMainPool :: ConnectionString -> Int -> IO (ConnectionPoolFor MainDb) 251 | createMainPool = createPostgresqlPoolFor 252 | 253 | createAuxPool :: ConnectionString -> Int -> IO (ConnectionPoolFor AuxDb) 254 | createAuxPool = createPostgresqlPoolFor 255 | ``` 256 | 257 | It is common to use `with`-style functions with these, as well. 258 | These functions automate closure of the database resources. 259 | We can specialize these functions similarly: 260 | 261 | ```haskell 262 | withPoolFor 263 | :: forall db m a 264 | . (MonadLogger m, MonadUnliftIO m) 265 | => ConnectionString 266 | -> Int 267 | -> (ConnectionPoolFor db -> m a) 268 | -> m a 269 | withPoolFor connStr conns action = 270 | withPostgresqlPool connStr conns $ \genericPool -> 271 | action (specializePool genericPool) 272 | ``` 273 | -------------------------------------------------------------------------------- /Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | main = defaultMain 3 | -------------------------------------------------------------------------------- /persistent-typed-db.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 1.12 2 | name: persistent-typed-db 3 | version: 0.1.0.7 4 | synopsis: Type safe access to multiple database schemata. 5 | description: See README.md for more details, examples, and fun. 
6 | category: Web 7 | homepage: https://github.com/parsonsmatt/persistent-typed-db#readme 8 | bug-reports: https://github.com/parsonsmatt/persistent-typed-db/issues 9 | author: Matt Parsons 10 | maintainer: parsonsmatt@gmail.com 11 | copyright: 2017 Matt Parsons 12 | license: BSD3 13 | license-file: LICENSE 14 | build-type: Simple 15 | extra-source-files: 16 | README.md 17 | CHANGELOG.md 18 | 19 | source-repository head 20 | type: git 21 | location: https://github.com/parsonsmatt/persistent-typed-db 22 | 23 | library 24 | exposed-modules: 25 | Database.Persist.Typed 26 | other-modules: 27 | Paths_persistent_typed_db 28 | hs-source-dirs: 29 | src 30 | ghc-options: -Wall -Wcompat -Wincomplete-record-updates -Wcompat -Wincomplete-uni-patterns 31 | build-depends: 32 | aeson 33 | , base >=4.7 && <5 34 | , bytestring 35 | , conduit >=1.3.0 36 | , http-api-data 37 | , monad-logger 38 | , path-pieces 39 | , persistent >=2.13.0 && < 3 40 | , resource-pool 41 | , resourcet >=1.2.0 42 | , template-haskell 43 | , text 44 | , transformers 45 | default-language: Haskell2010 46 | 47 | test-suite specs 48 | type: exitcode-stdio-1.0 49 | main-is: Spec.hs 50 | other-modules: 51 | EsqueletoSpec 52 | Paths_persistent_typed_db 53 | hs-source-dirs: 54 | test 55 | ghc-options: -threaded -rtsopts -with-rtsopts=-N 56 | build-depends: 57 | aeson 58 | , base >=4.10 59 | , bytestring 60 | , conduit >=1.3.0 61 | , esqueleto 62 | , hspec 63 | , http-api-data 64 | , monad-logger 65 | , path-pieces 66 | , persistent >= 2.13 67 | , persistent-typed-db 68 | , resource-pool 69 | , resourcet >=1.2.0 70 | , template-haskell 71 | , text 72 | , transformers 73 | build-tool-depends: 74 | hspec-discover:hspec-discover 75 | default-language: Haskell2010 76 | -------------------------------------------------------------------------------- /src/Database/Persist/Typed.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleContexts #-} 2 | {-# LANGUAGE 
ConstraintKinds #-} 3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 4 | {-# LANGUAGE InstanceSigs #-} 5 | {-# LANGUAGE MultiParamTypeClasses #-} 6 | {-# LANGUAGE OverloadedStrings #-} 7 | {-# LANGUAGE QuasiQuotes #-} 8 | {-# LANGUAGE RankNTypes #-} 9 | {-# LANGUAGE TemplateHaskell #-} 10 | {-# LANGUAGE TypeFamilies #-} 11 | {-# language CPP #-} 12 | 13 | -- | This module defines types and helpers for type-safe access to multiple 14 | -- database schema. 15 | module Database.Persist.Typed 16 | ( -- * Schema Definition 17 | mkSqlSettingsFor 18 | , SqlFor(..) 19 | , BackendKey(..) 20 | -- * Specialized aliases 21 | , SqlPersistTFor 22 | , ConnectionPoolFor 23 | , SqlPersistMFor 24 | -- * Running specialized queries 25 | , runSqlPoolFor 26 | , runSqlConnFor 27 | -- * Specializing and generalizing 28 | , generalizePool 29 | , specializePool 30 | , generalizeQuery 31 | , specializeQuery 32 | , generalizeSqlBackend 33 | , specializeSqlBackend 34 | -- * Key functions 35 | , toSqlKeyFor 36 | , fromSqlKeyFor 37 | ) where 38 | 39 | import Control.Exception hiding (throw) 40 | import Control.Monad.IO.Class (MonadIO(..)) 41 | import Control.Monad.Logger (NoLoggingT) 42 | import Control.Monad.Trans.Reader (ReaderT(..), ask, asks, withReaderT) 43 | import Control.Monad.Trans.Resource (MonadUnliftIO, ResourceT) 44 | import qualified Data.Aeson as A 45 | import Data.ByteString.Char8 (readInteger) 46 | import Data.Coerce (coerce) 47 | import Data.Conduit ((.|)) 48 | import qualified Data.Conduit.List as CL 49 | import qualified Data.Foldable as Foldable 50 | import Data.Foldable (toList) 51 | import Data.Int (Int64) 52 | import Data.List (find, inits, transpose) 53 | import Data.Maybe (isJust) 54 | import Data.Monoid (mappend) 55 | import Data.Pool (Pool) 56 | import Data.Text (Text) 57 | import qualified Data.Text as Text 58 | import Database.Persist.Sql hiding (deleteWhereCount, orderClause, updateWhereCount) 59 | import Database.Persist.Sql.Types.Internal (IsPersistBackend(..)) 
60 | import Database.Persist.Sql.Util 61 | import Database.Persist.SqlBackend.Internal 62 | import Database.Persist.TH (MkPersistSettings, mkPersistSettings) 63 | import Language.Haskell.TH (Name, Type(..)) 64 | import Web.HttpApiData (FromHttpApiData, ToHttpApiData) 65 | import Web.PathPieces (PathPiece) 66 | 67 | #if MIN_VERSION_persistent(2,14,0) 68 | import Database.Persist.Class.PersistEntity (SafeToInsert) 69 | #else 70 | import GHC.Exts (Constraint) 71 | #endif 72 | 73 | -- | An untyped 'SqlBackend' carrying a phantom @db@ tag. Choose a concrete 74 | -- type for @db@ to record which database a connection belongs to. 75 | -- 76 | -- @since 0.0.1.0 77 | newtype SqlFor db = SqlFor { unSqlFor :: SqlBackend } 78 | -- Stripping the tag recovers the plain 'SqlBackend' underneath. 79 | instance BackendCompatible SqlBackend (SqlFor db) where 80 | projectBackend (SqlFor backend) = backend 81 | 82 | -- | A query that may only run against the database tagged @db@. Because the 83 | -- tag tends to repeat across many signatures, it pays to give each database 84 | -- its own alias: 85 | -- 86 | -- @ 87 | -- data MainDb 88 | -- 89 | -- type MainQueryT = 'SqlPersistTFor' MainDb 90 | -- 91 | -- getStuff :: 'MonadIO' m => StuffId -> MainQueryT m (Maybe Stuff) 92 | -- @ 93 | -- 94 | -- @since 0.0.1.0 95 | type SqlPersistTFor db = ReaderT (SqlFor db) 96 | 97 | -- | A connection 'Pool' in which every connection is tagged for the 98 | -- database @db@. 99 | -- 100 | -- @since 0.0.1.0 101 | type ConnectionPoolFor db = Pool (SqlFor db) 102 | 103 | -- | The 'SqlPersistM' counterpart for a tagged backend: a @db@-specific query 104 | -- running over @'NoLoggingT' ('ResourceT' 'IO')@. 105 | -- 106 | -- @since 0.0.1.0 107 | type SqlPersistMFor db = SqlPersistTFor db (NoLoggingT (ResourceT IO)) 108 | 109 | -- | Specialize a query to a specific database. You should define aliases for 110 | -- this function for each database you use.
111 | -- 112 | -- @ 113 | -- data MainDb 114 | -- 115 | -- data AccountDb 116 | -- 117 | -- mainQuery :: 'ReaderT' 'SqlBackend' m a -> 'ReaderT' ('SqlFor' MainDb) m a 118 | -- mainQuery = 'specializeQuery' 119 | -- 120 | -- accountQuery :: 'ReaderT' 'SqlBackend' m a -> 'ReaderT' ('SqlFor' AccountDb) m a 121 | -- accountQuery = 'specializeQuery' 122 | -- @ 123 | -- 124 | -- @since 0.0.1.0 125 | specializeQuery :: forall db m a. SqlPersistT m a -> SqlPersistTFor db m a 126 | specializeQuery query = ReaderT (runReaderT query . unSqlFor) 127 | 128 | -- | Forget which database a query was written for, turning it into a query 129 | -- that runs against any plain 'SqlBackend'. 130 | -- 131 | -- @since 0.0.1.0 132 | generalizeQuery :: forall db m a. SqlPersistTFor db m a -> SqlPersistT m a 133 | generalizeQuery query = ReaderT (runReaderT query . SqlFor) 134 | 135 | -- | A replacement for 'sqlSettings' that takes the quoted name of a database 136 | -- tag type and makes the generated entities use @'SqlFor' tag@ as backend. 137 | -- 138 | -- @ 139 | -- data MainDb 140 | -- 141 | -- share [ mkPersist (mkSqlSettingsFor ''MainDb), mkMigrate "migrateAll" ] [persistLowerCase| 142 | -- 143 | -- User 144 | -- name Text 145 | -- age Int 146 | -- 147 | -- deriving Show Eq 148 | -- |] 149 | -- @ 150 | -- 151 | -- Each generated entity gets @'SqlFor' MainDb@ as its 'PersistEntityBackend' 152 | -- rather than 'SqlBackend', so the compiler rejects mixing it into queries 153 | -- for another database. 154 | -- 155 | -- @since 0.0.1.0 156 | mkSqlSettingsFor :: Name -> MkPersistSettings 157 | mkSqlSettingsFor dbName = mkPersistSettings (ConT ''SqlFor `AppT` ConT dbName) 158 | 159 | -- | Build a 'Key' for a @'SqlFor' a@ entity from its raw 'Int64' id. This 160 | -- stands in for persistent's @toSqlKey@, which hardcodes 'SqlBackend'. 161 | -- 162 | -- @since 0.0.1.0 163 | toSqlKeyFor :: (ToBackendKey (SqlFor a) record) => Int64 -> Key record 164 | toSqlKeyFor rawId = fromBackendKey (SqlForKey (SqlBackendKey rawId)) 165 | 166 | -- | Persistent's @toSqlKey@ and @fromSqlKey@ hardcode the 'SqlBackend', so we 167 | -- have to reimplement them here.
168 | -- 169 | -- @since 0.0.1.0 170 | fromSqlKeyFor :: ToBackendKey (SqlFor a) record => Key record -> Int64 171 | fromSqlKeyFor key = unSqlBackendKey (unSqlForKey (toBackendKey key)) 172 | 173 | -- | Tag an untyped 'ConnectionPool' as a @'Pool' ('SqlFor' db)@. Do this right 174 | -- where the pool is created or initialized, so that pools for different 175 | -- databases cannot be confused later on. 176 | -- 177 | -- @since 0.0.1.0 178 | specializePool :: ConnectionPool -> ConnectionPoolFor db 179 | specializePool pool = coerce pool 180 | 181 | -- | Drop the database tag from a @'Pool' ('SqlFor' db)@, yielding a plain 182 | -- 'ConnectionPool'. The result suits general-purpose SQL, but no longer 183 | -- offers the type safety that model-specific code relies on. 184 | -- 185 | -- @since 0.0.1.0 186 | generalizePool :: ConnectionPoolFor db -> ConnectionPool 187 | generalizePool pool = coerce pool 188 | 189 | -- | Tag a plain 'SqlBackend' as belonging to the database @db@. 190 | -- 191 | -- @since 0.0.1.0 192 | specializeSqlBackend :: SqlBackend -> SqlFor db 193 | specializeSqlBackend backend = SqlFor backend 194 | 195 | -- | Strip the database tag from a 'SqlFor' backend. 196 | -- 197 | -- @since 0.0.1.0 198 | generalizeSqlBackend :: SqlFor db -> SqlBackend 199 | generalizeSqlBackend (SqlFor backend) = backend 200 | 201 | -- | Run a 'SqlPersistTFor' query on a pool tagged for the same database. 202 | -- 203 | -- @since 0.0.1.0 204 | runSqlPoolFor 205 | :: MonadUnliftIO m 206 | => SqlPersistTFor db m a 207 | -> ConnectionPoolFor db 208 | -> m a 209 | runSqlPoolFor query pool = 210 | runSqlPool (generalizeQuery query) $ generalizePool pool 211 | 212 | -- | Run a 'SqlPersistTFor' action on the appropriate database connection.
213 | -- 214 | -- @since 0.0.1.0 215 | runSqlConnFor 216 | :: MonadUnliftIO m 217 | => SqlPersistTFor db m a 218 | -> SqlFor db 219 | -> m a 220 | runSqlConnFor query conn = 221 | runSqlConn (generalizeQuery query) (generalizeSqlBackend conn) 222 | 223 | -- The following instances are almost entirely copy-pasted from the Persistent 224 | -- library for SqlBackend. 225 | instance HasPersistBackend (SqlFor a) where 226 | type BaseBackend (SqlFor a) = SqlFor a 227 | persistBackend = id 228 | 229 | instance IsPersistBackend (SqlFor a) where 230 | mkPersistBackend = id 231 | 232 | instance PersistCore (SqlFor a) where 233 | newtype BackendKey (SqlFor a) = 234 | SqlForKey { unSqlForKey :: BackendKey SqlBackend } 235 | deriving ( Show, Read, Eq, Ord, Num, Integral, PersistField 236 | , PersistFieldSql, PathPiece, ToHttpApiData, FromHttpApiData 237 | , Real, Enum, Bounded, A.ToJSON, A.FromJSON 238 | ) 239 | 240 | instance PersistStoreRead (SqlFor a) where 241 | get k = do 242 | conn <- asks unSqlFor 243 | let t = entityDef $ dummyFromKey k 244 | let cols = Text.intercalate "," 245 | $ map (connEscapeRawName conn . unFieldNameDB . 
fieldDB) $ getEntityFieldsDatabase t 246 | noColumns :: Bool 247 | noColumns = null $ getEntityFieldsDatabase t 248 | let wher = whereStmtForKey conn k 249 | let sql = Text.concat 250 | [ "SELECT " 251 | , if noColumns then "*" else cols 252 | , " FROM " 253 | , connEscapeRawName conn $ unEntityNameDB $ getEntityDBName t 254 | , " WHERE " 255 | , wher 256 | ] 257 | flip runReaderT conn $ withRawQuery sql (keyToValues k) $ do 258 | res <- CL.head 259 | case res of 260 | Nothing -> return Nothing 261 | Just vals -> 262 | case fromPersistValues $ if noColumns then [] else vals of 263 | Left e -> error $ "get " ++ show k ++ ": " ++ Text.unpack e 264 | Right v -> return $ Just v 265 | 266 | instance PersistStoreWrite (SqlFor a) where 267 | update _ [] = return () 268 | update k upds = specializeQuery $ do 269 | conn <- ask 270 | let go'' n Assign = n <> "=?" 271 | go'' n Add = Text.concat [n, "=", n, "+?"] 272 | go'' n Subtract = Text.concat [n, "=", n, "-?"] 273 | go'' n Multiply = Text.concat [n, "=", n, "*?"] 274 | go'' n Divide = Text.concat [n, "=", n, "/?"] 275 | go'' _ (BackendSpecificUpdate up) = error $ Text.unpack $ "BackendSpecificUpdate" `Data.Monoid.mappend` up `mappend` "not supported" 276 | let go' (x, pu) = go'' (connEscapeRawName conn x) pu 277 | let wher = whereStmtForKey conn k 278 | let sql = Text.concat 279 | [ "UPDATE " 280 | , connEscapeRawName conn $ unEntityNameDB $ tableDBName $ recordTypeFromKey k 281 | , " SET " 282 | , Text.intercalate "," $ map (go' . 
go) upds 283 | , " WHERE " 284 | , wher 285 | ] 286 | rawExecute sql $ 287 | map updatePersistValue upds `mappend` keyToValues k 288 | where 289 | go x = (unFieldNameDB $ fieldDB $ updateFieldDef x, updateUpdate x) 290 | 291 | insert val = specializeQuery $ do 292 | conn <- ask 293 | case connInsertSql conn t vals of 294 | ISRSingle sql -> withRawQuery sql vals $ do 295 | x <- CL.head 296 | case x of 297 | Just [PersistInt64 i] -> case keyFromValues [PersistInt64 i] of 298 | Left err -> error $ "SQL insert: keyFromValues: PersistInt64 " `mappend` show i `mappend` " " `mappend` Text.unpack err 299 | Right k -> return k 300 | Nothing -> error "SQL insert did not return a result giving the generated ID" 301 | Just vals' -> case keyFromValues vals' of 302 | Left _ -> error $ "Invalid result from a SQL insert, got: " ++ show vals' 303 | Right k -> return k 304 | 305 | ISRInsertGet sql1 sql2 -> do 306 | rawExecute sql1 vals 307 | withRawQuery sql2 [] $ do 308 | mm <- CL.head 309 | let m = maybe 310 | (Left $ "No results from ISRInsertGet: " `mappend` tshow (sql1, sql2)) 311 | Right mm 312 | 313 | -- TODO: figure out something better for MySQL 314 | let convert x = 315 | case x of 316 | [PersistByteString i] -> case readInteger i of -- mssql 317 | Just (ret,"") -> [PersistInt64 $ fromIntegral ret] 318 | _ -> x 319 | _ -> x 320 | -- Yes, it's just <|>. Older bases don't have the 321 | -- instance for Either. 
322 | onLeft Left{} x = x 323 | onLeft x _ = x 324 | 325 | case m >>= (\x -> keyFromValues x `onLeft` keyFromValues (convert x)) of 326 | Right k -> return k 327 | Left err -> throw $ "ISRInsertGet: keyFromValues failed: " `mappend` err 328 | ISRManyKeys sql fs -> do 329 | rawExecute sql vals 330 | case entityPrimary t of 331 | Nothing -> error $ "ISRManyKeys is used when Primary is defined " ++ show sql 332 | Just pdef -> 333 | let pks = map fieldHaskell $ toList $ compositeFields pdef 334 | keyvals = map snd $ filter (\(a, _) -> let ret=isJust (find (== a) pks) in ret) $ zip (map fieldHaskell $ getEntityFieldsDatabase t) fs 335 | in case keyFromValues keyvals of 336 | Right k -> return k 337 | Left e -> error $ "ISRManyKeys: unexpected keyvals result: " `mappend` Text.unpack e 338 | where 339 | tshow :: Show a => a -> Text 340 | tshow = Text.pack . show 341 | throw = liftIO . throwIO . userError . Text.unpack 342 | t = entityDef $ Just val 343 | vals = map toPersistValue $ toPersistFields val 344 | 345 | insertMany [] = return [] 346 | insertMany vals = specializeQuery $ do 347 | conn <- ask 348 | 349 | case connInsertManySql conn of 350 | Nothing -> withReaderT SqlFor $ mapM insert vals 351 | Just insertManyFn -> 352 | case insertManyFn ent valss of 353 | ISRSingle sql -> rawSql sql (concat valss) 354 | _ -> error "ISRSingle is expected from the connInsertManySql function" 355 | where 356 | ent = entityDef vals 357 | valss = map (map toPersistValue . 
toPersistFields) vals 358 | 359 | insertEntityMany es' = specializeQuery $ do 360 | conn <- ask 361 | let entDef = entityDef $ map entityVal es' 362 | let columnNames = keyAndEntityColumnNames entDef conn 363 | runChunked (length columnNames) go es' 364 | where 365 | go = insrepHelper "INSERT" 366 | 367 | 368 | insertMany_ [] = return () 369 | insertMany_ vals0 = specializeQuery $ do 370 | conn <- ask 371 | case connMaxParams conn of 372 | Nothing -> insertMany_' vals0 373 | Just maxParams -> do 374 | let chunkSize = maxParams `div` length (getEntityFieldsDatabase t) 375 | mapM_ insertMany_' (chunksOf chunkSize vals0) 376 | where 377 | insertMany_' vals = do 378 | conn <- ask 379 | let valss = map (map toPersistValue . toPersistFields) vals 380 | let sql = Text.concat 381 | [ "INSERT INTO " 382 | , connEscapeRawName conn (unEntityNameDB $ getEntityDBName t) 383 | , "(" 384 | , Text.intercalate "," $ map (connEscapeRawName conn . unFieldNameDB . fieldDB) $ getEntityFieldsDatabase t 385 | , ") VALUES (" 386 | , Text.intercalate "),(" $ replicate (length valss) $ Text.intercalate "," $ map (const "?") (getEntityFieldsDatabase t) 387 | , ")" 388 | ] 389 | rawExecute sql (concat valss) 390 | 391 | t = entityDef vals0 392 | 393 | replace k val = do 394 | conn <- asks unSqlFor 395 | let t = entityDef $ Just val 396 | let wher = whereStmtForKey conn k 397 | let sql = Text.concat 398 | [ "UPDATE " 399 | , connEscapeRawName conn (unEntityNameDB $ getEntityDBName t) 400 | , " SET " 401 | , Text.intercalate "," (map (go conn . unFieldNameDB . fieldDB) $ getEntityFieldsDatabase t) 402 | , " WHERE " 403 | , wher 404 | ] 405 | vals = map toPersistValue (toPersistFields val) `mappend` keyToValues k 406 | specializeQuery $ rawExecute sql vals 407 | where 408 | go conn x = connEscapeRawName conn x `Text.append` "=?" 
-- PersistStoreWrite methods, continued.

    -- Insert a record under a caller-supplied key, reusing 'insrepHelper'.
    insertKey k v = specializeQuery $ insrepHelper "INSERT" [Entity k v]

    -- Replace the row at @key@ if one exists, otherwise insert it under that
    -- key. This is a 'get' followed by a separate write, not one atomic
    -- statement.
    repsert key value =
        get key >>= maybe (insertKey key value) (const $ replace key value)

    -- Delete the row whose key columns match @k@.
    delete k = do
        conn <- asks unSqlFor
        specializeQuery $ rawExecute (sql conn) (keyToValues k)
      where
        sql conn = Text.concat
            [ "DELETE FROM "
            , connEscapeRawName conn $ unEntityNameDB $ tableDBName $ recordTypeFromKey k
            , " WHERE "
            , whereStmtForKey conn k
            ]

-- orphaned instance for convenience of modularity
instance PersistQueryRead (SqlFor a) where
    exists filts =
        (> 0) <$> count filts

    -- Run @SELECT COUNT(*)@ and decode the single result column. Backends
    -- report it with different types, so Int64, Double (Oracle) and
    -- ByteString (MSSQL) results are all accepted.
    count filts = specializeQuery $ do
        conn <- ask
        let wher = if null filts
                then ""
                else filterClause Nothing conn filts
        let sql = mconcat
                [ "SELECT COUNT(*) FROM "
                , connEscapeRawName conn $ unEntityNameDB $ getEntityDBName t
                , wher
                ]
        withRawQuery sql (getFiltsValues (SqlFor conn) filts) $ do
            mm <- CL.head
            case mm of
                Just [PersistInt64 i] -> return $ fromIntegral i
                Just [PersistDouble i] -> return $ fromIntegral (truncate i :: Int64) -- gb oracle
                Just [PersistByteString i] -> case readInteger i of -- gb mssql
                    Just (ret, "") -> return $ fromIntegral ret
                    xs -> error $ "invalid number i[" ++ show i ++ "] xs[" ++ show xs ++ "]"
                Just xs -> error $ "count:invalid sql return xs[" ++ show xs ++ "] sql[" ++ show sql ++ "]"
                Nothing -> error $ "count:invalid sql returned nothing sql[" ++ show sql ++ "]"
      where
        t = entityDef $ dummyFromFilts filts

    -- Stream the entities matching @filts@, applying the LIMIT/OFFSET and
    -- ORDER BY options carried in @opts@.
    selectSourceRes filts opts = specializeQuery $ do
        conn <- ask
        srcRes <- rawQueryRes (sql conn) (getFiltsValues (SqlFor conn) filts)
        return $ fmap (.| CL.mapM parse) srcRes
      where
        (limit, offset, orders) = limitOffsetOrder opts

        parse vals = case parseEntityValues t vals of
            Left s -> liftIO $ throwIO $ PersistMarshalError s
            Right row -> return row
        t = entityDef $ dummyFromFilts filts
        wher conn = if null filts
            then ""
            else filterClause Nothing conn filts
        ord conn =
            case map (orderClause False conn) orders of
                [] -> ""
                ords -> " ORDER BY " <> Text.intercalate "," ords
        cols = Text.intercalate ", " . toList . keyAndEntityColumnNames t
        sql conn = connLimitOffset conn (limit, offset) $ mconcat
            [ "SELECT "
            , cols conn
            , " FROM "
            , connEscapeRawName conn $ unEntityNameDB $ getEntityDBName t
            , wher conn
            , ord (SqlFor conn)
            ]

    -- Stream only the keys of the matching rows.
    selectKeysRes filts opts = specializeQuery $ do
        conn <- ask
        srcRes <- rawQueryRes (sql conn) (getFiltsValues (SqlFor conn) filts)
        return $ fmap (.| CL.mapM parse) srcRes
      where
        t = entityDef $ dummyFromFilts filts
        (limit, offset, orders) = limitOffsetOrder opts

        cols conn = Text.intercalate "," $ toList $ dbIdColumns conn t
        wher conn = if null filts
            then ""
            else filterClause Nothing conn filts
        ord conn =
            case map (orderClause False (SqlFor conn)) orders of
                [] -> ""
                ords -> " ORDER BY " <> Text.intercalate "," ords
        sql conn = connLimitOffset conn (limit, offset) $ mconcat
            [ "SELECT "
            , cols conn
            , " FROM "
            , connEscapeRawName conn $ unEntityNameDB $ getEntityDBName t
            , wher conn
            , ord conn
            ]

        -- Turn one result row into the values handed to 'keyFromValues'.
        keyValues xs = case entityPrimary t of
            Nothing -> case xs of
                -- Oracle returns integer keys as Double.
                [PersistDouble x] -> [PersistInt64 (truncate x)]
                _ -> xs
            -- NOTE(review): the query selects only 'dbIdColumns', yet the
            -- row is zipped against *all* database fields here; this only
            -- lines up when the key fields come first — confirm.
            Just pdef ->
                let pks = map fieldHaskell $ toList $ compositeFields pdef
                    names = map fieldHaskell $ getEntityFieldsDatabase t
                in map snd $ filter ((`elem` pks) . fst) $ zip names xs

        parse xs = case keyFromValues (keyValues xs) of
            Right k -> return k
            Left err -> error $ "selectKeysImpl: keyFromValues failed: " <> show err

instance PersistUniqueWrite (SqlFor db) where
    -- Use the backend's native upsert statement ('connUpsertSql') when it has
    -- one and there is at least one update to apply; otherwise fall back to
    -- the generic 'defaultUpsertBy'.
    upsertBy uniqueKey record updates = specializeQuery $ do
        conn <- ask
        let escape = connEscapeRawName conn
        let refCol n = Text.concat [escape (unEntityNameDB $ getEntityDBName t), ".", n]
        let mkUpdateFieldText = mkUpdateText' (escape . unFieldNameDB) refCol
        case (connUpsertSql conn, updates) of
            (Just upsertSql, _ : _) -> do
                let upds = Text.intercalate "," $ map mkUpdateFieldText updates
                    sql = upsertSql t (persistUniqueToFieldNames uniqueKey) upds
                    vals = map toPersistValue (toPersistFields record)
                        ++ map updatePersistValue updates
                        ++ persistUniqueToValues uniqueKey
                rows <- rawSql sql vals
                -- The statement is expected to yield the resulting row;
                -- report a clear error instead of failing inside 'head'.
                case rows of
                    entity : _ -> return entity
                    [] -> liftIO $ throwIO $ userError
                        "upsertBy: the upsert statement returned no rows"
            _ -> generalizeQuery $ defaultUpsertBy uniqueKey record updates
      where
        t = entityDef $ Just record

    -- Delete the row(s) matching every column of the unique constraint.
    deleteBy uniq = specializeQuery $ do
        conn <- ask
        rawExecute (sql conn) (persistUniqueToValues uniq)
      where
        t = entityDef $ dummyFromUnique uniq
        columns = map snd . toList $ persistUniqueToFieldNames uniq
        condition conn x = connEscapeRawName conn (unFieldNameDB x) `mappend` "=?"
        sql conn =
            Text.concat
                [ "DELETE FROM "
                , connEscapeRawName conn $ unEntityNameDB $ getEntityDBName t
                , " WHERE "
                , Text.intercalate " AND " $ map (condition conn) columns
                ]

instance PersistUniqueRead (SqlFor a) where
    -- Fetch the entity matching the unique constraint @uniq@, if any; only
    -- the first result row is read.
    getBy uniq = specializeQuery $ do
        conn <- ask
        let sql =
                Text.concat
                    [ "SELECT "
                    , Text.intercalate "," $ toList $ dbColumns conn t
                    , " FROM "
                    , connEscapeRawName conn $ unEntityNameDB $ getEntityDBName t
                    , " WHERE "
                    , sqlClause conn
                    ]
            uvals = persistUniqueToValues uniq
        withRawQuery sql uvals $ do
            row <- CL.head
            case row of
                Nothing -> return Nothing
                Just [] -> error "getBy: empty row"
                Just vals ->
                    case parseEntityValues t vals of
                        Left err ->
                            liftIO $ throwIO $ PersistMarshalError err
                        Right r -> return $ Just r
      where
        sqlClause conn =
            Text.intercalate " AND " $ map (go conn . unFieldNameDB) $ toFieldNames' uniq
        go conn x = connEscapeRawName conn x `mappend` "=?"
        t = entityDef $ dummyFromUnique uniq
        toFieldNames' = map snd . toList . persistUniqueToFieldNames

instance PersistQueryWrite (SqlFor db) where
    -- Both methods delegate to the counting variants and drop the count.
    deleteWhere filts = () <$ deleteWhereCount filts
    updateWhere filts upds = () <$ updateWhereCount filts upds
--
-- Here be dragons! These are functions, types, and helpers that were vendored
-- from Persistent.

-- | Same as 'deleteWhere', but returns the number of rows affected.
--
-- The filters are rendered by 'filterClause'; an empty filter list produces
-- no WHERE clause, so every row of the table is deleted.
deleteWhereCount :: (PersistEntity val, MonadIO m, PersistEntityBackend val ~ SqlFor db)
                 => [Filter val]
                 -> ReaderT (SqlFor db) m Int64
deleteWhereCount filts = withReaderT unSqlFor $ do
    conn <- ask
    let t = entityDef $ dummyFromFilts filts
        wher = if null filts
            then ""
            else filterClause Nothing conn filts
        sql = mconcat
            [ "DELETE FROM "
            , connEscapeRawName conn $ unEntityNameDB $ getEntityDBName t
            , wher
            ]
    rawExecuteCount sql $ getFiltsValues (SqlFor conn) filts

-- | Same as 'updateWhere', but returns the number of rows affected.
--
-- An empty update list is a no-op that reports 0 rows without touching the
-- database.
--
-- @since 1.1.5
updateWhereCount :: (PersistEntity val, MonadIO m, SqlFor db ~ PersistEntityBackend val)
                 => [Filter val]
                 -> [Update val]
                 -> ReaderT (SqlFor db) m Int64
updateWhereCount _ [] = return 0
updateWhereCount filts upds = withReaderT unSqlFor $ do
    conn <- ask
    let wher = if null filts
            then ""
            else filterClause Nothing conn filts
    let sql = mconcat
            [ "UPDATE "
            , connEscapeRawName conn $ unEntityNameDB $ getEntityDBName t
            , " SET "
            , Text.intercalate "," $ map (render conn . fieldAndOp) upds
            , wher
            ]
    -- Placeholder order: the SET values first, then the WHERE values.
    let dat = map updatePersistValue upds `Data.Monoid.mappend`
            getFiltsValues (SqlFor conn) filts
    rawExecuteCount sql dat
  where
    t = entityDef $ dummyFromFilts filts

    -- Render one @col <op> ?@ assignment of the SET clause.
    assignment n Assign = n <> "=?"
    assignment n Add = mconcat [n, "=", n, "+?"]
    assignment n Subtract = mconcat [n, "=", n, "-?"]
    assignment n Multiply = mconcat [n, "=", n, "*?"]
    assignment n Divide = mconcat [n, "=", n, "/?"]
    assignment _ (BackendSpecificUpdate up) =
        error $ Text.unpack $ "BackendSpecificUpdate " `mappend` up `mappend` " not supported"

    render conn (x, pu) = assignment (connEscapeRawName conn x) pu
    fieldAndOp x = (updateField' x, updateUpdate x)

    updateField' (Update f _ _) = fieldName f
    updateField' _ = error "BackendUpdate not implemented"

-- | A @Maybe record@ proxy for 'entityDef'. The wrapped value is
-- 'recordTypeFromKey', which must never be forced; only its type is used.
dummyFromKey :: Key record -> Maybe record
dummyFromKey = Just . recordTypeFromKey

recordTypeFromKey :: Key record -> record
recordTypeFromKey _ = error "dummyFromKey"

-- | Render @col1=? AND col2=? ...@ over the key columns of @k@'s entity.
-- The key values must be supplied separately, in the same order (e.g. via
-- 'keyToValues').
whereStmtForKey :: PersistEntity record => SqlBackend -> Key record -> Text
whereStmtForKey conn k =
    Text.intercalate " AND "
        $ map (<> "=? ")
        $ toList $ dbIdColumns conn entDef
  where
    entDef = entityDef $ dummyFromKey k

-- | Build and run a single @\<command\> INTO table(cols) VALUES (..),(..)@
-- statement covering every entity, keys included. An empty list is a no-op.
insrepHelper :: (MonadIO m, PersistEntity val)
             => Text
             -> [Entity val]
             -> ReaderT SqlBackend m ()
insrepHelper _ [] = return ()
insrepHelper command es = do
    conn <- ask
    let columnNames = toList $ keyAndEntityColumnNames entDef conn
    rawExecute (sql conn columnNames) vals
  where
    entDef = entityDef $ map entityVal es
    sql conn columnNames = Text.concat
        [ command
        , " INTO "
        , connEscapeRawName conn (unEntityNameDB $ getEntityDBName entDef)
        , "("
        , Text.intercalate "," columnNames
        , ") VALUES ("
        , Text.intercalate "),(" $ replicate (length es) $ Text.intercalate "," $ map (const "?") columnNames
        , ")"
        ]
    vals = Foldable.foldMap entityValues es

-- | Whether a generated condition also accepts NULL in the column, by
-- appending @ OR col IS NULL@ (see @orNullSuffix@ in 'filterClauseHelper').
data OrNull = OrNullYes | OrNullNo

-- | Render a filter list as SQL plus the values for its @?@ placeholders, in
-- order. Top-level filters are joined with @AND@, each wrapped in
-- parentheses.
--
-- NULL comparisons are rewritten explicitly. For example, @Eq@ against a NULL
-- value becomes @IS NULL@, and an empty @In@ list becomes @1=2@.
filterClauseHelper :: (PersistEntity val, PersistEntityBackend val ~ SqlFor a)
                   => Bool -- ^ include table name?
                   -> Bool -- ^ include WHERE?
                   -> SqlFor a
                   -> OrNull
                   -> [Filter val]
                   -> (Text, [PersistValue])
filterClauseHelper includeTable includeWhere (SqlFor conn) orNull filters =
    (if not (Text.null sql) && includeWhere
        then " WHERE " <> sql
        else sql, vals)
  where
    (sql, vals) = combineAND filters
    combineAND = combine " AND "

    combine s fs =
        (Text.intercalate s $ map wrapP a, mconcat b)
      where
        (a, b) = unzip $ map go fs
        wrapP x = Text.concat ["(", x, ")"]

    go (BackendFilter _) = error "BackendFilter not expected"
    go (FilterAnd []) = ("1=1", [])
    go (FilterAnd fs) = combineAND fs
    go (FilterOr []) = ("1=0", [])
    go (FilterOr fs) = combine " OR " fs
    go (Filter field value pfilter) =
        let t = entityDef $ dummyFromFilts [Filter field value pfilter]
        in case (isIdField field, entityPrimary t, allVals) of
            (True, Just pdef, PersistList ys:_) ->
                if length (compositeFields pdef) /= length ys
                    then error $ "wrong number of entries in compositeFields vs PersistList allVals=" ++ show allVals
                    else
                        case (allVals, pfilter, isCompFilter pfilter) of
                            ([PersistList xs], Eq, _) ->
                                let sqlcl = Text.intercalate " and " (map (\a -> connEscapeRawName conn (unFieldNameDB $ fieldDB a) <> showSqlFilter pfilter <> "? ") (toList $ compositeFields pdef))
                                in (wrapSql sqlcl, xs)
                            ([PersistList xs], Ne, _) ->
                                let sqlcl = Text.intercalate " or " (map (\a -> connEscapeRawName conn (unFieldNameDB $ fieldDB a) <> showSqlFilter pfilter <> "? ") (toList $ compositeFields pdef))
                                in (wrapSql sqlcl, xs)
                            (_, In, _) ->
                                let xxs = transpose (map fromPersistList allVals)
                                    sqls = map (\(a, xs) -> connEscapeRawName conn (unFieldNameDB $ fieldDB a) <> showSqlFilter pfilter <> "(" <> Text.intercalate "," (replicate (length xs) " ?") <> ") ") (zip (toList $ compositeFields pdef) xxs)
                                in (wrapSql (Text.intercalate " and " (map wrapSql sqls)), concat xxs)
                            (_, NotIn, _) ->
                                let xxs = transpose (map fromPersistList allVals)
                                    sqls = map (\(a, xs) -> connEscapeRawName conn (unFieldNameDB $ fieldDB a) <> showSqlFilter pfilter <> "(" <> Text.intercalate "," (replicate (length xs) " ?") <> ") ") (zip (toList $ compositeFields pdef) xxs)
                                in (wrapSql (Text.intercalate " or " (map wrapSql sqls)), concat xxs)
                            -- Lexicographic comparison over the composite
                            -- key: (k1 op ?) or (k1 = ? and k2 op ?) or ...
                            ([PersistList xs], _, True) ->
                                let zs = tail (inits (toList $ compositeFields pdef))
                                    sql1 = map (\b -> wrapSql (Text.intercalate " and " (map (\(i, a) -> sql2 (i == length b) a) (zip [1..] b)))) zs
                                    sql2 islast a = connEscapeRawName conn (unFieldNameDB $ fieldDB a) <> (if islast then showSqlFilter pfilter else showSqlFilter Eq) <> "? "
                                    sqlcl = Text.intercalate " or " sql1
                                in (wrapSql sqlcl, concat (tail (inits xs)))
                            (_, BackendSpecificFilter _, _) -> error "unhandled type BackendSpecificFilter for composite/non id primary keys"
                            _ -> error $ "unhandled type/filter for composite/non id primary keys pfilter=" ++ show pfilter ++ " persistList=" ++ show allVals
            (True, Just pdef, []) ->
                error $ "empty list given as filter value filter=" ++ show pfilter ++ " persistList=" ++ show allVals ++ " pdef=" ++ show pdef
            (True, Just pdef, _) ->
                error $ "unhandled error for composite/non id primary keys filter=" ++ show pfilter ++ " persistList=" ++ show allVals ++ " pdef=" ++ show pdef

            _ -> case (isNull, pfilter, length notNullVals) of
                (True, Eq, _) -> (name <> " IS NULL", [])
                (True, Ne, _) -> (name <> " IS NOT NULL", [])
                (False, Ne, _) -> (Text.concat
                    [ "("
                    , name
                    , " IS NULL OR "
                    , name
                    , " <> "
                    , qmarks
                    , ")"
                    ], notNullVals)
                -- We use 1=2 (and below 1=1) to avoid using TRUE and FALSE, since
                -- not all databases support those words directly.
                (_, In, 0) -> ("1=2" <> orNullSuffix, [])
                (False, In, _) -> (name <> " IN " <> qmarks <> orNullSuffix, allVals)
                (True, In, _) -> (Text.concat
                    [ "("
                    , name
                    , " IS NULL OR "
                    , name
                    , " IN "
                    , qmarks
                    , ")"
                    ], notNullVals)
                (False, NotIn, 0) -> ("1=1", [])
                (True, NotIn, 0) -> (name <> " IS NOT NULL", [])
                (False, NotIn, _) -> (Text.concat
                    [ "("
                    , name
                    , " IS NULL OR "
                    , name
                    , " NOT IN "
                    , qmarks
                    , ")"
                    ], notNullVals)
                (True, NotIn, _) -> (Text.concat
                    [ "("
                    , name
                    , " IS NOT NULL AND "
                    , name
                    , " NOT IN "
                    , qmarks
                    , ")"
                    ], notNullVals)
                _ -> (name <> showSqlFilter pfilter <> "?" <> orNullSuffix, allVals)

      where
        isCompFilter Lt = True
        isCompFilter Le = True
        isCompFilter Gt = True
        isCompFilter Ge = True
        isCompFilter _ = False

        wrapSql sqlcl = "(" <> sqlcl <> ")"
        fromPersistList (PersistList xs) = xs
        fromPersistList other = error $ "expected PersistList but found " ++ show other

        filterValueToPersistValues :: forall a. PersistField a => FilterValue a -> [PersistValue]
        filterValueToPersistValues v = case v of
            FilterValue a -> map toPersistValue [a]
            FilterValues as -> map toPersistValue as
            UnsafeValue a -> map toPersistValue [a]

        orNullSuffix =
            case orNull of
                OrNullYes -> mconcat [" OR ", name, " IS NULL"]
                OrNullNo -> ""

        isNull = PersistNull `elem` allVals
        notNullVals = filter (/= PersistNull) allVals
        allVals = filterValueToPersistValues value
        tn = connEscapeRawName conn $ unEntityNameDB $ getEntityDBName
            $ entityDef $ dummyFromFilts [Filter field value pfilter]
        name =
            (if includeTable
                then ((tn <> ".") <>)
                else id)
            $ connEscapeRawName conn $ fieldName field
        qmarks = case value of
            FilterValues x ->
                let x' = filter (/= PersistNull) $ map toPersistValue x
                in "(" <> Text.intercalate "," (map (const "?") x') <> ")"
            _ -> "?"
        showSqlFilter Eq = "="
        showSqlFilter Ne = "<>"
        showSqlFilter Gt = ">"
        showSqlFilter Lt = "<"
        showSqlFilter Ge = ">="
        showSqlFilter Le = "<="
        showSqlFilter In = " IN "
        showSqlFilter NotIn = " NOT IN "
        showSqlFilter (BackendSpecificFilter s) = s

-- | Type proxy for 'entityDef'; the filters themselves are not inspected.
dummyFromFilts :: [Filter v] -> Maybe v
dummyFromFilts _ = Nothing

-- | The unescaped database column name of an entity field.
fieldName :: forall record typ a. (PersistEntity record, PersistEntityBackend record ~ SqlFor a) => EntityField record typ -> Text
fieldName f = unFieldNameDB $ fieldDB $ persistFieldDef f


-- | Only the placeholder values of 'filterClauseHelper', in SQL order.
getFiltsValues :: forall val a. (PersistEntity val, PersistEntityBackend val ~ SqlFor a)
               => SqlFor a -> [Filter val] -> [PersistValue]
getFiltsValues conn = snd . filterClauseHelper False False conn OrNullNo

-- | Render one ORDER BY term; only 'Asc' and 'Desc' are accepted.
orderClause :: (PersistEntity val, PersistEntityBackend val ~ SqlFor a)
            => Bool -- ^ include the table name
            -> SqlFor a
            -> SelectOpt val
            -> Text
orderClause includeTable (SqlFor conn) o =
    case o of
        Asc x -> name x
        Desc x -> name x <> " DESC"
        _ -> error "orderClause: expected Asc or Desc, not limit or offset"
  where
    dummyFromOrder :: SelectOpt a -> Maybe a
    dummyFromOrder _ = Nothing

    tn = connEscapeRawName conn $ unEntityNameDB $ getEntityDBName $ entityDef $ dummyFromOrder o

    name :: (PersistEntityBackend record ~ SqlFor a, PersistEntity record)
         => EntityField record typ -> Text
    name x =
        (if includeTable
            then ((tn <> ".") <>)
            else id)
        $ connEscapeRawName conn $ fieldName x

-- | Type proxy for 'entityDef'; the unique value itself is not inspected.
dummyFromUnique :: Unique v -> Maybe v
dummyFromUnique _ = Nothing

-- escape :: DBName -> Text.Text
-- escape (DBName s) = Text.pack $ '"' : escapeQuote (Text.unpack s) ++ "\""
--   where
--     escapeQuote ""       = ""
--     escapeQuote ('"':xs) = "\"\"" ++ escapeQuote xs
--     escapeQuote (x:xs)   = x : escapeQuote xs

-- | Run @m@ over @xs@ in chunks sized so that each statement binds at most
-- 'connMaxParams' parameters, assuming @width@ parameters per element.
-- Without a limit, @m xs@ runs once. An empty list runs nothing.
runChunked
    :: (Monad m)
    => Int
    -> ([a] -> ReaderT SqlBackend m ())
    -> [a]
    -> ReaderT SqlBackend m ()
runChunked _ _ [] = return ()
runChunked width m xs = do
    conn <- ask
    case connMaxParams conn of
        Nothing -> m xs
        -- chunkSize is 0 when width > maxParams; 'chunksOf' then falls back
        -- to one element per chunk instead of looping forever.
        Just maxParams ->
            let chunkSize = maxParams `div` width
            in mapM_ m (chunksOf chunkSize xs)

-- | Split a list into consecutive chunks of at most @size@ elements (the last
-- chunk may be shorter). Implemented here to avoid depending on the split
-- package.
--
-- A non-positive @size@ is treated as 1. Otherwise 'splitAt' would yield an
-- empty chunk plus the unchanged list forever. 'runChunked' and
-- 'insertMany_' compute their size with @div@ and can pass 0, so they would
-- never terminate.
chunksOf :: Int -> [a] -> [[a]]
chunksOf size = go
  where
    step = max 1 size
    go [] = []
    go ys = let (chunk, rest) = splitAt step ys in chunk : go rest

-- | The slow but generic 'upsertBy' implementation for any 'PersistUniqueRead'.
--
-- * Look up the existing entity (if any) with 'getBy'.
-- * If the record exists, update it using 'updateGet'.
-- * If it does not exist, insert it using 'insertEntity'.
--
-- @since 2.11
defaultUpsertBy
    :: ( PersistEntityBackend record ~ backend
       , PersistEntity record
       , BaseBackend backend ~ backend
       , BackendCompatible SqlBackend backend
       , MonadIO m
       , PersistStoreWrite backend
       , PersistUniqueRead backend
       , MySafeToInsert record
       )
    => Unique record   -- ^ uniqueness constraint to find by
    -> record          -- ^ new record to insert
    -> [Update record] -- ^ updates to perform if the record already exists
    -> ReaderT backend m (Entity record) -- ^ the record in the database after the operation
defaultUpsertBy uniqueKey record updates = do
    mrecord <- getBy uniqueKey
    maybe (insertEntity record) (`updateGetEntity` updates) mrecord
  where
    updateGetEntity (Entity k _) upds =
        Entity k <$> updateGet k upds

-- | 'SafeToInsert' on persistent >= 2.14; an empty constraint on older
-- versions, which lack that class.
type MySafeToInsert a =
#if MIN_VERSION_persistent(2,14,0)
    SafeToInsert a
#else
    () :: Constraint
#endif

--------------------------------------------------------------------------------
/stack-lts-12.yaml:
--------------------------------------------------------------------------------
1 | resolver: lts-13.12
2 | 
3 | packages:
4 | - .
5 | 
--------------------------------------------------------------------------------
/stack.yaml:
--------------------------------------------------------------------------------
1 | resolver: nightly-2022-08-04
2 | 
3 | packages:
4 | - .
5 | 
--------------------------------------------------------------------------------
/test/EsqueletoSpec.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Compile-time checks that esqueleto queries can be written against a
-- 'SqlFor' backend. The query values are only type checked, never run.
module EsqueletoSpec where

import Database.Esqueleto
import Database.Persist.TH
import Database.Persist.Typed
import Test.Hspec

data TestDb

share [mkPersist (mkSqlSettingsFor ''TestDb)] [persistLowerCase|

Person
    name String
    age Int
    deriving Show Eq

Dog
    name String
    owner PersonId
    deriving Show Eq

Foo
    Id sql=other_id
    other_id Int
|]

instance ToBackendKey (SqlFor TestDb) Foo where
    toBackendKey =
        unFooKey
    fromBackendKey =
        FooKey

spec :: Spec
spec = do
    describe "select" $
        it "type checks" $ do
            let q :: SqlPersistMFor TestDb [(Entity Person, Entity Dog)]
                q = select $
                    from $ \(p `InnerJoin` d) -> do
                        on (p ^. PersonId ==. d ^. DogOwner)
                        pure (p, d)
            typeChecks q

    describe "update" $
        it "type checks" $ do
            let q :: SqlPersistMFor TestDb ()
                q = update $ \p -> do
                    set p [ PersonName =. val "world" ]
                    where_ (p ^. PersonName ==. val "hello")
            typeChecks q

    describe "delete" $
        it "type checks" $ do
            let q :: SqlPersistMFor TestDb ()
                q = delete $ from $ \p -> where_ (p ^. PersonName ==. val "world")
            typeChecks q

    describe "issue #2" $
        it "type checks" $ do
            let k = toSqlKeyFor 3 :: Key Foo
            typeChecks k
  where
    -- Always succeeds. The argument exists only so that it has to type check,
    -- and so it counts as used. It is never evaluated, so no database is
    -- needed.
    typeChecks :: a -> Expectation
    typeChecks _ = True `shouldBe` True

--------------------------------------------------------------------------------
/test/Spec.hs:
--------------------------------------------------------------------------------
{-# OPTIONS_GHC -F -pgmF hspec-discover #-}

--------------------------------------------------------------------------------