├── .editorconfig ├── .gitignore ├── LICENSE ├── NOTES.md ├── README.md ├── bin └── test ├── examples ├── pg.cr └── sqlite.cr ├── shard.yml ├── src ├── focus.cr ├── focus │ ├── assignments_builder.cr │ ├── cached_column.cr │ ├── cached_result_set.cr │ ├── cached_row.cr │ ├── column.cr │ ├── column_declaring.cr │ ├── database.cr │ ├── dsl │ │ ├── aggregation.cr │ │ └── operators.cr │ ├── query.cr │ ├── query_source.cr │ ├── sql_expression.cr │ ├── sql_expressions.cr │ ├── sql_formatter.cr │ ├── sql_visitor.cr │ ├── table.cr │ ├── transaction_manager.cr │ └── update_statement_builder.cr ├── mysql.cr ├── mysql │ ├── mysql_database.cr │ └── mysql_formatter.cr ├── pg.cr ├── pg │ ├── i_like.cr │ ├── insert_returning_expression.cr │ ├── pg_database.cr │ └── pg_formatter.cr ├── sqlite.cr └── sqlite │ ├── sqlite_database.cr │ └── sqlite_formatter.cr └── test ├── mysql ├── mysql_database_test.cr └── mysql_test_base.cr ├── pg ├── pg_database_test.cr └── pg_test_base.cr ├── sqlite ├── sqlite_database_test.cr └── sqlite_test_base.cr ├── support ├── drop-mysql-data.sql ├── drop-pg-data.sql ├── drop-sqlite-data.sql ├── init-mysql-data.sql ├── init-pg-data.sql ├── init-sqlite-data.sql └── tables.cr └── test_base.cr /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*.cr] 4 | charset = utf-8 5 | end_of_line = lf 6 | insert_final_newline = true 7 | indent_style = space 8 | indent_size = 2 9 | trim_trailing_whitespace = true 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /docs/ 2 | /lib/ 3 | bin/* 4 | !bin/test 5 | /.shards/ 6 | *.dwarf 7 | 8 | # Libraries don't need dependency lock 9 | # Dependencies will be locked in applications that use them 10 | /shard.lock 11 | 12 | # Ignore any sqlite databases 13 | *.db 14 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 matthewmcgarvey 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /NOTES.md: -------------------------------------------------------------------------------- 1 | # Notes 2 | 3 | These are for me to jot down some things I'm learning as I reference https://www.ktorm.org and figuring out how to implement something similar in Crystal. 4 | 5 | ## Why? 6 | 7 | What are the motivating reasons for making another ORM? 
8 | 9 | - Avram connects to Lucky params and the view layer which leads the standard way of writing code using Lucky to be very flat 10 | - No talk about repository layers or a layout around domains 11 | - With the way Avram Operations work, it's meant to encompass all of your code instead of being used by your code for database operations 12 | - Other ORMs try to copy dynamic language ORMs (ActiveRecord, Ecto) and have limitations 13 | - I want POCO models that don't carry a lot of baggage and are trivial to make as many as necessary 14 | - With the way the current ORMs work, there's a whole lot that is connected to models 15 | - it's pretty much 1 to 1 model to table and it's not possible to do things like selecting only a subset of fields, joining other tables, etc. 16 | 17 | ## Other 18 | 19 | ### Storing the result set results 20 | 21 | Because the interface is chainable, the result set needs to be pulled out and cached so that it can be iterated over as many times as needed 22 | 23 | ### Visitor Pattern 24 | 25 | ktorm has a visitor pattern for turning sql expressions into a sql statement. 26 | Except it still does what the visitor pattern is trying to avoid which is having conditionals for every type being visited. 27 | https://github.com/kotlin-orm/ktorm/blob/85647c01ed6504ea7a13a9e4ee4b50377dfd8e6a/ktorm-core/src/main/kotlin/org/ktorm/expression/SqlExpressionVisitor.kt#L43-L56 28 | I've long been confused why the objects being visited need an "accept" method when they all call the same thing. 29 | I even made sure that you could just have overloads of a "visit" method on the visitor and pass in the classes that way and it works fine. 30 | The one reason I found for why you would still want the "accept" method is to restrict the types being visited to one base class.
31 | So you can have a visitor with a visit overload for an array and a struct, but if you limit the entrance to the visitor pattern to an accept method on a shared interface then you limit it to the classes that implement the interface. 32 | In this case, I will limit it to classes that implement `SqlExpression` by putting the "accept" method requirement there. 33 | I did verify that it still works with the "accept" method implemented in `SqlExpression` so I don't have to have the same method implemented in every subclass. 34 | 35 | All of this to say, I want to use the visitor pattern, but I'm going to implement it correctly whereas I don't think ktorm did. 36 | 37 | ## Generics 38 | 39 | Frustrating trying to work with generics in this library. 40 | Had to add the BaseColumn and BaseColumnExpression classes to cope with the need for arrays of these things. 41 | **NOTE** The code works as it is right now with the `abstract def as_expression` commented out in `Focus::BaseColumn` which seems... broken? 42 | 43 | One thing I just realized with Ktorm's generics usage is that Kotlin is able to conditionally add methods to an object based on generics 44 | 45 | ```kotlin 46 | public operator fun ColumnDeclaring<Boolean>.not(): UnaryExpression<Boolean> { 47 | return UnaryExpression(UnaryExpressionType.NOT, asExpression(), BooleanSqlType) 48 | } 49 | ``` 50 | 51 | This means that only boolean columns (`ColumnDeclaring` is a bit more than just columns but w/e) have access to the `not` method. 52 | I don't think there is an equivalent in Crystal 53 | 54 | UPDATE: After talking about it in the Discord, I found a hack solution for this.
55 | 56 | ```crystal 57 | def not : UnaryExpression(Bool) 58 | {% raise "#{@type.name}##{@def.name} may only be used with Bool columns" %} 59 | UnaryExpression.new( 60 | Focus::UnaryExpressionType::NOT, 61 | operand: as_expression, 62 | sql_type: Bool 63 | ) 64 | end 65 | ``` 66 | 67 | The macro `raise` call means that it fails at compile time if you call the method and the generic isn't `Bool`. 68 | I'm not going to do it right now because there's so much more important work to do right now. 69 | This only keeps devs from building incorrect sql, so we're going to go without this safety feature until later on. 70 | 71 | ### Query results by field 72 | 73 | Traced back java's postgres library and java's sqlite library. 74 | The sqlite one seems to do some sort of interaction with the native library to get a column name from the result set. 75 | Postgres, on the other hand, passes the fields from the query to the result set. 76 | I think the postgres implementation is exactly what I was thinking I could do, so I'm glad to find an example. 77 | 78 | Side note... tracing java libraries (and especially ones as complicated as SQL integrations) is ridiculously difficult. 79 | 80 | I was happy with the current implementation of how you can get results by passing in the field... 81 | that is, until I tried to implement `SELECT * FROM users`. When you provide a way to select "*" then you aren't passed any of the column information. 82 | I don't really want to trace java code right now, so I'm making this note and removing that functionality. 83 | 84 | I was wrong about how the Java Postgres library gets the field name. 85 | https://github.com/pgjdbc/pgjdbc/blob/d5ed52ef391670e83ae5265af2f7301c615ce4ca/pgjdbc/src/main/java/org/postgresql/core/v3/QueryExecutorImpl.java#L2619-L2644 86 | 87 | How do they do it? Don't know 88 | 89 | I'm...so...dumb. After finding out about ^ I looked at crystal-pg and crystal-sqlite and they both implemented stuff around column name... 
so I looked at crystal-db. Right in front of my face the whole time is the column_name method. 90 | 91 | Still a problem though. If I join two tables that have overlapping column names, there's no way to determine the difference with the way things are right now. 92 | Looking into postgres, they issue a separate query. https://github.com/pgjdbc/pgjdbc/blob/d5ed52ef391670e83ae5265af2f7301c615ce4ca/pgjdbc/src/main/java/org/postgresql/jdbc/PgResultSetMetaData.java#L187 93 | What should I do? 94 | The reason it's not a problem for anyone else is that they don't serialize joins often, and the ordering of how you extract data from the result set matters. 95 | https://github.com/crystal-lang/crystal-db/issues/175 96 | 97 | ## Table Macro 98 | 99 | I'm really hoping this will be the only macro and I plan to document exactly what the table looks like so that people aren't scared of the macro or confused. 100 | They wouldn't even have to use it if they don't want to. 101 | 102 | ## Entity Binding 103 | 104 | I have not implemented this yet as I was focusing solely on the query building. 105 | 106 | One thing I don't like about Ktorm is how it implemented entity binding https://www.ktorm.org/en/entities-and-column-binding.html 107 | 108 | It goes against my stated goal which is to not tie the database table down to one singular model. 109 | Ktorm has you define the entity on the table which does exactly that. 110 | 111 | Looked into this more, seems like we don't need the entity sequence as much as I thought. 112 | Stuff around grouping might be wanted, but the main thing was getting an iterable list of the entities and getting just one 113 | 114 | ## Ktorm Blocks 115 | 116 | For a long time, I've wondered why a lot of methods take in a block when it always seemed like it could just take in one argument instead. 117 | The blocks always pass in the table, but all the doc examples I've seen never use it. 118 | In my code I didn't add block overloads, `#where` is the best example.
119 | I finally realized why you would want the block, though. 120 | With a block you could have helper functions that don't care what the table is, but use methods that are on all tables. 121 | The best example I can come up with is pagination. I can't do it without telling the pagination which table to use. 122 | I'm sure there's alternatives to the block, so I'm not sure I want to add in the block stuff right now anyways. 123 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Focus 2 | 3 | SQL query builder for multiple databases. Heavily inspired by Kotlin's [Ktorm](https://www.ktorm.org). 4 | Do you want to learn an ORM or do you want to be able to focus? 5 | 6 | Not at all ready for real use. (There's not even any tests, so don't be crazy and try to use it for real.) 7 | 8 | ## Goals 9 | 10 | ### Create an easy to understand library 11 | 12 | I don't want this library to take 10 months to feel like you know how to do everything. 13 | I want it to be made in a way that's not only easy to use, but easy to dig into internally. 14 | That means two things: 15 | 16 | 1. This will be a clean, well documented DSL 17 | 2. There will be minimal macro usages 18 | 19 | I want to avoid macro usages because, while they can make code simpler to write, they can cause confusion for maintainers and developers when they run into bugs. 20 | By trying to avoid macros, I also have to think about how I can use regular Crystal to make a pleasant API rather than fall back to macros. 21 | As of right now, I only have one macro which is `Focus::Table.column`. It's only used when defining tables and I believe it's necessary to avoid users immediately running into hand cramps when defining tables. 22 | Maybe one day this library could have no macros?
23 | 24 | ### Separate data models from database tables 25 | 26 | This will probably be the most unusual goal of this project. 27 | The vast majority of ORMs have you define your data model and equate that to a table in the database. 28 | When you create a `User` model, it wraps the `users` table in the database and that class is how you fetch and manipulate data. 29 | So why change that? 30 | Well, from my experience in Crystal over the past few years, maintaining that style fundamentally limits the database queries that can be safely constructed (or else they provide a backdoor way to do it that feels like you're subverting the whole point of the ORM) and places quite a burden on the maintainers of the project to add increasing complexity to manage the codebase and add more and more features. 31 | We've spent enough time trying to copy ActiveRecord, and it's just not going to be possible to provide the same flexibility that it does. 32 | So I'm trying a different path. One where you define your table completely separately from your data models. The table is used to build queries and the results can be parsed just as they are or bound to a data model. 33 | This way, you can have as many data models as you want connected to the same table, you can build much more customized SQL queries, and the internals of the library are much simpler to understand. 34 | 35 | ### Work with multiple database types 36 | 37 | I don't want to limit this library to just PostgreSQL. I want developers to be able to fully use different databases even within the same project. 38 | By "fully use" I do mean that I want to provide accessible DSLs or extension points to use all the features of a particular database (like jsonb in PostgreSQL). 39 | 40 | ## Installation 41 | 42 | 1. Add the dependency to your `shard.yml`: 43 | 44 | ```yaml 45 | dependencies: 46 | focus: 47 | github: matthewmcgarvey/focus 48 | ``` 49 | 50 | 2. 
Run `shards install` 51 | 52 | ## Usage 53 | 54 | ### Connect to a database 55 | 56 | The aim of focus is to provide fluent access to many different types of databases. 57 | Focus will provide any extra functionality for each supported database or ways of providing it yourself. 58 | 59 | ```crystal 60 | require "focus" 61 | require "focus/sqlite" 62 | 63 | database = Focus::SQLiteDatabase.connect("sqlite3://./data.db") 64 | ``` 65 | 66 | Databases currently supported: 67 | 68 | - SQLite3 69 | - PostgreSQL 70 | - MySQL 71 | 72 | ### Define a table 73 | 74 | Tables are where we connect Crystal code to the database tables. 75 | It's important to understand that these are not our data models. 76 | They are used to build queries. 77 | 78 | ```crystal 79 | class UsersTable < Focus::Table 80 | @table_name = "users" 81 | 82 | column id : Int64 83 | column name : String 84 | column role : String 85 | end 86 | 87 | Users = UsersTable.new 88 | ``` 89 | 90 | We define the table `UsersTable` with the table name and the columns. 91 | We then create an instance and assign it to `Users` for a nice API. 92 | 93 | ### Make a query 94 | 95 | ```crystal 96 | database.from(Users) 97 | .select(Users.id) 98 | .where(Users.role.eq("admin")) 99 | .map(&.get(Users.id)) 100 | ``` 101 | 102 | ### Bind rows to Crystal objects 103 | 104 | Focus cleanly integrates with `DB::Serializable`.
105 | 106 | ```crystal 107 | struct User 108 | include DB::Serializable 109 | 110 | property id : Int64 111 | property name : String 112 | property role : String 113 | end 114 | 115 | users = database.from(Users) 116 | .select 117 | .bind_to(User) 118 | ``` 119 | 120 | ### Insert data 121 | 122 | ```crystal 123 | database.insert(Users) do 124 | set(Users.name, "bobby") 125 | set(Users.role, "user") 126 | end 127 | ``` 128 | 129 | ### Update data 130 | 131 | ```crystal 132 | database.update(Users) do 133 | set(Users.role, "admin") 134 | where(Users.name.eq("bobby")) 135 | end 136 | ``` 137 | 138 | ## Development 139 | 140 | TODO: Write development instructions here 141 | 142 | ## TODO 143 | 144 | - Int32 vs Int64 primary keys 145 | - Error when using Int32 keys is very confusing, solution is to switch to Int64 but you wouldn't know it 146 | - Write good tests 147 | - Custom data types (e.g. postgis) 148 | - Custom queries (e.g. jsonb queries) 149 | - Add overloads to query methods that can be given a block 150 | - The blocks will be passed the table being used 151 | - They must return the expected criteria 152 | - This is so that you can have helpers that don't care about the specifics of the table but can still do common things between them 153 | - The most obvious example I can think of is for having an agnostic pagination helper 154 | - Seriously consider whether that's actually beneficial or it can be implemented cleanly the way it is right now 155 | 156 | Take like a month and a half and you pretty much forget everything! 157 | There was something about the table definitions I wanted to change but I don't remember what. 158 | 159 | ## Contributing 160 | 161 | 1. Fork it (<https://github.com/matthewmcgarvey/focus/fork>) 162 | 2. Create your feature branch (`git checkout -b my-new-feature`) 163 | 3. Commit your changes (`git commit -am 'Add some feature'`) 164 | 4. Push to the branch (`git push origin my-new-feature`) 165 | 5.
Create a new Pull Request 166 | 167 | ## Contributors 168 | 169 | - [matthewmcgarvey](https://github.com/matthewmcgarvey) - creator and maintainer 170 | -------------------------------------------------------------------------------- /bin/test: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | files=$1 4 | 5 | if [ -z "$files" ] 6 | then 7 | files=$(find ./test -iname "*_test.cr") 8 | fi 9 | 10 | crystal run $files 11 | -------------------------------------------------------------------------------- /examples/pg.cr: -------------------------------------------------------------------------------- 1 | require "../src/focus" 2 | require "../src/pg" 3 | 4 | database = Focus::PGDatabase.connect("postgresql://lucky@localhost:5432/avram_dev") 5 | 6 | class UsersTable < Focus::Table 7 | @table_name = "users" 8 | 9 | column id : Int32 10 | column name : String 11 | column age : Int32 12 | column year_born : Int16 13 | column nickname : String 14 | column joined_at : Time 15 | column total_score : Int64 16 | column average_score : Float32 17 | column available_for_hire : Bool 18 | column created_at : Time 19 | column updated_at : Time 20 | end 21 | 22 | Users = UsersTable.new 23 | 24 | struct User 25 | include DB::Serializable 26 | 27 | property id : Int64 28 | property name : String 29 | property age : Int32 30 | property year_born : Int16 31 | end 32 | 33 | database.insert(Users) do 34 | set(Users.name, "William") 35 | set(Users.age, 39) 36 | set(Users.year_born, 1983) 37 | set(Users.nickname, "Bob") 38 | set(Users.joined_at, Time.utc) 39 | set(Users.total_score, 100) 40 | set(Users.created_at, Time.utc) 41 | set(Users.updated_at, Time.utc) 42 | set(Users.available_for_hire, true) 43 | set(Users.average_score, 45.78) 44 | end 45 | 46 | query = database.from(Users).select(Users.id, Users.name, Users.age, Users.year_born) 47 | pp query.bind_to_last(User) 48 | 49 | database.close 50 | 
-------------------------------------------------------------------------------- /examples/sqlite.cr: -------------------------------------------------------------------------------- 1 | require "../src/focus" 2 | require "../src/sqlite" 3 | 4 | database = Focus::SQLiteDatabase.connect("sqlite3://./data.db") 5 | 6 | class UsersTable < Focus::Table 7 | @table_name = "users" 8 | 9 | column id : Int32 10 | column name : String 11 | column role : String 12 | end 13 | 14 | Users = UsersTable.new 15 | 16 | class TodosTable < Focus::Table 17 | @table_name = "todos" 18 | 19 | column id : Int32 20 | column name : String 21 | column user_id : Int32 22 | end 23 | 24 | Todos = TodosTable.new 25 | 26 | class TodoWithUser 27 | include DB::Serializable 28 | 29 | property id : Int64 30 | property name : String 31 | property user_id : Int64 32 | property user_name : String 33 | property user_role : String 34 | end 35 | 36 | class User 37 | include DB::Serializable 38 | 39 | property id : Int64 40 | property name : String 41 | property role : String 42 | end 43 | 44 | database.with_connection do |conn| 45 | conn.exec(<<-SQL) 46 | create table if not exists 47 | users( 48 | id INTEGER PRIMARY KEY AUTOINCREMENT, 49 | name varchar(128) not null, 50 | role varchar(64) not null 51 | ); 52 | SQL 53 | 54 | conn.exec(<<-SQL) 55 | create table if not exists 56 | todos( 57 | id INTEGER PRIMARY KEY AUTOINCREMENT, 58 | name varchar(128) not null, 59 | user_id INTEGER not null 60 | ); 61 | SQL 62 | end 63 | 64 | database.insert(Users) do 65 | set(Users.name, "bobby") 66 | set(Users.role, "user") 67 | end 68 | 69 | database.insert(Users) do 70 | set(Users.name, "billy") 71 | set(Users.role, "admin") 72 | end 73 | 74 | database.insert(Todos) do 75 | set(Todos.name, "Take out trash") 76 | set(Todos.user_id, 2) 77 | end 78 | 79 | pp database.from(Todos) 80 | .left_join(Users, on: Todos.user_id.eq(Users.id)) 81 | .select(Todos.id, Todos.name, Todos.user_id, Users.name.aliased("user_name"), 
Users.role.aliased("user_role")) 82 | .bind_to(TodoWithUser) 83 | 84 | database.close 85 | -------------------------------------------------------------------------------- /shard.yml: -------------------------------------------------------------------------------- 1 | name: focus 2 | version: 0.1.0 3 | 4 | authors: 5 | - matthewmcgarvey 6 | 7 | crystal: 1.6.0 8 | 9 | license: MIT 10 | 11 | dependencies: 12 | db: 13 | github: crystal-lang/crystal-db 14 | version: 0.11.0 15 | development_dependencies: 16 | sqlite3: 17 | github: crystal-lang/crystal-sqlite3 18 | pg: 19 | github: will/crystal-pg 20 | ameba: 21 | github: crystal-ameba/ameba 22 | mysql: 23 | github: crystal-lang/crystal-mysql 24 | minitest: 25 | github: ysbaddaden/minitest.cr 26 | -------------------------------------------------------------------------------- /src/focus.cr: -------------------------------------------------------------------------------- 1 | require "db" 2 | 3 | require "./focus/sql_expression" 4 | require "./focus/column_declaring" 5 | require "./focus/*" 6 | require "./focus/dsl/*" 7 | 8 | module Focus 9 | extend Focus::Dsl::Aggregation 10 | extend Focus::Dsl::TopLevelOperators 11 | VERSION = "0.1.0" 12 | 13 | # TODO: Put your code here 14 | end 15 | -------------------------------------------------------------------------------- /src/focus/assignments_builder.cr: -------------------------------------------------------------------------------- 1 | class Focus::AssignmentsBuilder 2 | getter assignments = [] of Focus::BaseColumnAssignmentExpression 3 | 4 | def set(column : Focus::Column(T), value : T?) 
forall T 5 | assignments << Focus::ColumnAssignmentExpression.new(column.as_expression, column.wrap_argument(value)) 6 | end 7 | end 8 | -------------------------------------------------------------------------------- /src/focus/cached_column.cr: -------------------------------------------------------------------------------- 1 | module Focus::BaseCachedColumn 2 | abstract def value 3 | abstract def name : String 4 | end 5 | 6 | class Focus::CachedColumn(T) 7 | include Focus::BaseCachedColumn 8 | getter value : T 9 | getter name : String 10 | 11 | def initialize(@value, @name) 12 | end 13 | end 14 | -------------------------------------------------------------------------------- /src/focus/cached_result_set.cr: -------------------------------------------------------------------------------- 1 | class Focus::CachedResultSet < DB::ResultSet 2 | private property current_row_count = -1 3 | private property current_column_count = -1 4 | private getter cached_rows 5 | private getter inner : DB::ResultSet 6 | 7 | def initialize(@cached_rows : Array(CachedRow), @inner) 8 | super(inner.statement) 9 | end 10 | 11 | def move_next : Bool 12 | if cached_rows.size == current_row_count + 1 13 | false 14 | else 15 | self.current_row_count += 1 16 | self.current_column_count = -1 17 | true 18 | end 19 | end 20 | 21 | def column_count : Int32 22 | inner.column_count 23 | end 24 | 25 | def column_name(index : Int32) : String 26 | inner.column_name(index) 27 | end 28 | 29 | def read 30 | self.current_column_count += 1 31 | value = current_row.columns[current_column_count].value 32 | value 33 | end 34 | 35 | def next_column_index : Int32 36 | current_column_count + 1 37 | end 38 | 39 | private def current_row : CachedRow 40 | cached_rows[current_row_count] 41 | end 42 | end 43 | -------------------------------------------------------------------------------- /src/focus/cached_row.cr: -------------------------------------------------------------------------------- 1 | class 
Focus::CachedRow 2 | # Taken from https://github.com/luckyframework/avram/blob/f0148f4274798124f5457c85fa35f7ba985636b6/src/avram/charms/time_extensions.cr#L10-L23 3 | TIME_FORMATS = [ 4 | Time::Format::ISO_8601_DATE_TIME, 5 | Time::Format::RFC_2822, 6 | Time::Format::RFC_3339, 7 | # HTML datetime-local inputs are basically RFC 3339 without the timezone: 8 | # https://developer.mozilla.org/en-US/docs/Web/HTML/Element/input/datetime-local 9 | Time::Format.new("%Y-%m-%dT%H:%M:%S", Time::Location::UTC), 10 | Time::Format.new("%Y-%m-%dT%H:%M", Time::Location::UTC), 11 | # Dates and times go last, otherwise it will parse strings with both 12 | # dates *and* times incorrectly. 13 | Time::Format::HTTP_DATE, 14 | Time::Format::ISO_8601_DATE, 15 | Time::Format::ISO_8601_TIME, 16 | ] 17 | 18 | def self.build(result_set : DB::ResultSet) : Focus::CachedRow 19 | columns = [] of Focus::BaseCachedColumn 20 | result_set.column_count.times do 21 | index = result_set.next_column_index 22 | name = result_set.column_name(index) 23 | columns << Focus::CachedColumn.new(value: result_set.read, name: name) 24 | end 25 | new(columns) 26 | end 27 | 28 | getter columns : Array(Focus::BaseCachedColumn) 29 | 30 | def initialize(@columns) 31 | end 32 | 33 | def get?(column : Column(C)) : C? forall C 34 | index = -1 35 | col = columns.find do |c| 36 | index += 1 37 | c.name == column.name 38 | end 39 | return nil if col.nil? 40 | 41 | get?(index, type: C) 42 | end 43 | 44 | def get(column : Column(C)) : C forall C 45 | get?(column).not_nil! 46 | end 47 | 48 | def get?(column : ColumnDeclaringExpression(C)) : C? forall C 49 | declared_name = column.declared_name 50 | if declared_name.nil? || declared_name.blank? 51 | raise "TODO: Label of the specified column cannot be null or blank." 
52 | end 53 | 54 | columns.each_with_index do |col, idx| 55 | return get?(idx, type: C) if col.name == declared_name 56 | end 57 | end 58 | 59 | def get(column : ColumnDeclaringExpression(C)) : C forall C 60 | get?(column).not_nil! 61 | end 62 | 63 | def get?(column_label : String, type : T.class) : T? forall T 64 | get?(find_column(column_label), type) 65 | end 66 | 67 | def get(column_label : String, type : T.class) : T forall T 68 | get(find_column(column_label), type) 69 | end 70 | 71 | def get(column_index : Int32, type : T.class) : T forall T 72 | get?(column_index, type).not_nil! 73 | end 74 | 75 | def get?(column_index : Int32, type : Int16.class) : Int16? 76 | val = columns[column_index].value 77 | case val 78 | when Int16, Int32, Int64, Float32, Float64 79 | val.to_i16 80 | when Bool 81 | val ? 1_i16 : 0_i16 82 | else 83 | val.try(&.to_s.to_i16) 84 | end 85 | end 86 | 87 | def get?(column_index : Int32, type : Int32.class) : Int32? 88 | val = columns[column_index].value 89 | case val 90 | when Int16, Int32, Int64, Float32, Float64 91 | val.to_i 92 | when Bool 93 | val ? 1 : 0 94 | when Nil 95 | nil 96 | else 97 | val.to_s.to_i 98 | end 99 | end 100 | 101 | def get?(column_index : Int32, type : Int64.class) : Int64? 102 | val = columns[column_index].value 103 | case val 104 | when Int16, Int32, Int64, Float32, Float64 105 | val.to_i64 106 | when Bool 107 | val ? 1_i64 : 0_i64 108 | when Nil 109 | nil 110 | else 111 | val.to_s.to_i64 112 | end 113 | end 114 | 115 | def get?(column_index : Int32, type : Float32.class) : Float32? 116 | val = columns[column_index].value 117 | case val 118 | when Int32, Int64, Float32, Float64 119 | val.to_f32 120 | when Bool 121 | val ? 1.0_f32 : 0.0_f32 122 | when Nil 123 | nil 124 | else 125 | val.to_s.to_f32 126 | end 127 | end 128 | 129 | def get?(column_index : Int32, type : Float64.class) : Float64? 
# NOTE(review): the lines below finish a method whose signature sits above this point.
    val = columns[column_index].value
    case val
    when Int32, Int64, Float32, Float64
      val.to_f
    when Bool
      val ? 1.0 : 0.0
    when Nil
      nil
    else
      val.to_s.to_f64
    end
  end

  # Reads the column at `column_index` as a `String`.
  #
  # Strings are returned as-is, any other non-nil value is converted with
  # `to_s`, and nil stays nil.
  def get?(column_index : Int32, type : String.class) : String?
    val = columns[column_index].value
    case val
    when String
      val
    else
      val.try(&.to_s)
    end
  end

  # Reads the column at `column_index` as a `Bool`.
  #
  # Booleans are returned as-is and numeric values are true when non-zero.
  # A nil value yields nil (matching the String and Time getters) instead of
  # being coerced to `false`; any other value is treated as true.
  def get?(column_index : Int32, type : Bool.class) : Bool?
    val = columns[column_index].value
    case val
    when Nil
      nil
    when Bool
      val
    when Int16, Int32, Int64, Float32, Float64
      !val.zero?
    else
      !!val
    end
  end

  # Reads the column at `column_index` as a `Time`.
  #
  # A `Time` value is returned directly. Any other non-nil value is converted
  # to a string and parsed with each entry of `TIME_FORMATS` in order; the
  # first successful parse wins. Returns nil when the value is nil or when no
  # format matches.
  def get?(column_index : Int32, type : Time.class) : Time?
    val = columns[column_index].value
    return val if val.is_a? Time
    return if val.nil?

    str = val.to_s
    TIME_FORMATS.each do |format|
      begin
        return format.parse(str)
      rescue Time::Format::Error
        # this format did not match; try the next one
      end
    end
    # no format matched
    nil
  end

  # Returns the index of the first column whose name equals `column_label`.
  #
  # Raises when no column has that name.
  def find_column(column_label : String) : Int32
    columns.each_with_index do |col, idx|
      return idx if col.name == column_label
    end
    raise "Invalid column name: #{column_label}"
  end
end

--------------------------------------------------------------------------------
/src/focus/column.cr:
--------------------------------------------------------------------------------
require "./column_declaring"

# Untyped view of a table column: the owning table and the column name.
module Focus::BaseColumn
  include Focus::BaseColumnDeclaring

  getter table : Focus::Table
  getter name : String

  abstract def as_expression : Focus::BaseColumnExpression

  # Label used when the column is declared in a select list; it is the
  # column name.
  def label : String
    name
  end
end

# A column of type `T` belonging to a `Focus::Table`.
class Focus::Column(T)
  include Focus::BaseColumn
  include Focus::ColumnDeclaring(T)

  def initialize(@table : Focus::Table, @name : String)
  end

  # Builds the column expression, qualified with the table's expression.
  def as_expression : Focus::ColumnExpression(T)
    Focus::ColumnExpression(T).new(table.as_expression, name)
  end

  # Wraps a literal value (or nil) as an argument expression of this
  # column's type.
  def wrap_argument(argument : T?) : Focus::ArgumentExpression(T)
    Focus::ArgumentExpression(T).new(argument)
  end

  # Declares this column under `label` (no declared name when nil).
  def aliased(label : String? = nil) : Focus::ColumnDeclaringExpression(T)
    Focus::ColumnDeclaringExpression(T).new(as_expression, label)
  end

  # Declares this column under its own `label`.
  def as_declaring_expression : Focus::ColumnDeclaringExpression(T)
    aliased(label)
  end
end

--------------------------------------------------------------------------------
/src/focus/column_declaring.cr:
--------------------------------------------------------------------------------
require "./dsl/operators"

# Untyped contract for anything that can be turned into a column declaration.
module Focus::BaseColumnDeclaring
  abstract def as_declaring_expression : Focus::BaseColumnDeclaringExpression
end

# Typed contract for anything that evaluates to a SQL value of type `T`.
# Including it also brings in the operator DSL from `Focus::Dsl::Operators(T)`.
module Focus::ColumnDeclaring(T)
  include Focus::BaseColumnDeclaring
  include Focus::Dsl::Operators(T)

  abstract def as_expression : Focus::ScalarExpression(T)
  abstract def wrap_argument(argument : T?) : Focus::ArgumentExpression(T)
  abstract def aliased(label : String? = nil) : Focus::ColumnDeclaringExpression(T)

  # Ascending ORDER BY entry for this expression.
  def asc : OrderByExpression
    OrderByExpression.new(as_expression, OrderType::ASCENDING)
  end

  # Descending ORDER BY entry for this expression.
  def desc : OrderByExpression
    OrderByExpression.new(as_expression, OrderType::DESCENDING)
  end

  # The Crystal type this declaration is parameterized with.
  def sql_type
    T
  end
end

--------------------------------------------------------------------------------
/src/focus/database.cr:
--------------------------------------------------------------------------------
# Entry point for building and executing statements against a `DB::Database`.
#
# Subclasses supply the SQL formatting (`format_expression`) and the
# generated-key insert (`execute_insert_and_return_generated_key`).
abstract class Focus::Database
  # Opens a database from a connection URL.
  def self.connect(url : String) : Database
    new(raw_db: DB.open(url))
  end

  # Wraps an already opened `DB::Database`.
  def self.connect(db : DB::Database) : Database
    new(raw_db: db)
  end

  private getter raw_db : DB::Database
  private getter transaction_manager : Focus::TransactionManager

  def initialize(@raw_db : DB::Database)
    @transaction_manager = Focus::TransactionManager.new(raw_db)
  end

  # Forwards `block` to `DB::Database#setup_connection` of the wrapped database.
  def setup_connection(&block : DB::Connection -> _)
    raw_db.setup_connection do |conn|
      block.call conn
    end
  end

  # Starts a query source rooted at `table`.
  def from(table : Focus::Table) : Focus::QuerySource
    Focus::QuerySource.new(self, table, table.as_expression)
  end

  # Inserts into `table`. The block is evaluated with a
  # `Focus::AssignmentsBuilder` as its implicit receiver to collect the
  # assignments. Returns the affected row count.
  def insert(table : Focus::Table) : Int64
    builder = Focus::AssignmentsBuilder.new
    with builder yield
    expression = InsertExpression.new(table.as_expression, builder.assignments)
    execute_update(expression)
  end

  # Inserts into `table` like `#insert` and returns the value the database
  # reports for `column`, read from the first column of the first returned row.
  #
  # Raises when the database returns no row.
  def insert_returning_generated_key(table : Focus::Table, column : Focus::Column(T)) : T forall T
    builder = Focus::AssignmentsBuilder.new
    with builder yield
    expression = InsertExpression.new(table.as_expression, builder.assignments)
    result_set = execute_insert_and_return_generated_key(expression, column)
    rows = [] of Focus::CachedRow
    begin
      result_set.each do
        rows << Focus::CachedRow.build(result_set)
      end
    ensure
      # always release the result set, even if caching a row fails
      result_set.close
    end
    if row = rows.first?
      row.get(0, type: T)
    else
      raise "Expected a key to be returned by the database"
    end
  end

  # Updates `table`. The block is evaluated with a
  # `Focus::UpdateStatementBuilder` as its implicit receiver to collect the
  # assignments and the optional where clause. Returns the affected row count.
  def update(table : Focus::Table) : Int64
    builder = Focus::UpdateStatementBuilder.new
    with builder yield
    expression = Focus::UpdateExpression.new(
      table.as_expression,
      builder.assignments,
      builder.where.try(&.as_expression)
    )
    execute_update(expression)
  end

  # Deletes the rows of `table` matching `where`. Returns the affected row count.
  def delete(table : Focus::Table, where : ColumnDeclaring(Bool)) : Int64
    expression = DeleteExpression.new(table.as_expression, where.as_expression)
    execute_update(expression)
  end

  # Deletes from `table` without a where clause. Returns the affected row count.
  def delete_all(table : Focus::Table) : Int64
    expression = DeleteExpression.new(table.as_expression, where: nil)
    execute_update(expression)
  end

  # Closes the wrapped database.
  def close : Nil
    raw_db.close
  end

  # Formats `expression` and runs it as a query, passing the argument values
  # as query parameters.
  def execute_query(expression : Focus::SqlExpression) : DB::ResultSet
    sql, args = format_expression(expression)
    with_connection do |conn|
      conn.query(sql, args: args.map(&.value))
    end
  end

  # Formats `expression` and executes it, returning the affected row count.
  def execute_update(expression : Focus::SqlExpression) : Int64
    sql, args = format_expression(expression)
    with_connection do |conn|
      conn.exec(sql, args: args.map(&.value)).rows_affected
    end
  end

  # Yields a connection obtained through the transaction manager.
  def with_connection(&block : DB::Connection -> T) : T forall T
    transaction_manager.with_connection do |conn|
      yield conn
    end
  end

  # Yields a transaction obtained through the transaction manager.
  def with_transaction(&block : DB::Transaction -> T) : T? forall T
    transaction_manager.with_transaction do |txn|
      yield txn
    end
  end

  abstract def format_expression(expression : Focus::SqlExpression) : Tuple(String, Array(Focus::BaseArgumentExpression))
  abstract def execute_insert_and_return_generated_key(expression : Focus::InsertExpression, column : Focus::BaseColumn) : DB::ResultSet
end

--------------------------------------------------------------------------------
/src/focus/dsl/aggregation.cr:
--------------------------------------------------------------------------------
# Helpers that build `AggregateExpression` nodes (min, max, avg, sum, count),
# each with a `*_distinct` variant that sets `is_distinct`.
module Focus::Dsl::Aggregation
  def min(column : ColumnDeclaring(Comparable)) : AggregateExpression(Comparable)
    AggregateExpression(Comparable).new(AggregateType::MIN, column.as_expression, is_distinct: false)
  end

  def min_distinct(column : ColumnDeclaring(Comparable)) : AggregateExpression(Comparable)
    AggregateExpression(Comparable).new(AggregateType::MIN, column.as_expression, is_distinct: true)
  end

  def max(column : ColumnDeclaring(Comparable)) : AggregateExpression(Comparable)
    AggregateExpression(Comparable).new(AggregateType::MAX, column.as_expression, is_distinct: false)
  end

  def max_distinct(column : ColumnDeclaring(Comparable)) : AggregateExpression(Comparable)
    AggregateExpression(Comparable).new(AggregateType::MAX, column.as_expression, is_distinct: true)
  end

  def avg(column : ColumnDeclaring(Comparable)) : AggregateExpression(Float32)
    AggregateExpression(Float32).new(AggregateType::AVG, column.as_expression, is_distinct: false)
  end

  def avg_distinct(column : ColumnDeclaring(Comparable)) : AggregateExpression(Float32)
    AggregateExpression(Float32).new(AggregateType::AVG, column.as_expression, is_distinct: true)
  end

  def sum(column : ColumnDeclaring(Comparable)) : AggregateExpression(Comparable)
    AggregateExpression(Comparable).new(AggregateType::SUM, column.as_expression, is_distinct: false)
  end

  def sum_distinct(column : ColumnDeclaring(Comparable)) : AggregateExpression(Comparable)
    AggregateExpression(Comparable).new(AggregateType::SUM, column.as_expression, is_distinct: true)
  end

  # With no column the aggregate carries a nil argument.
  # NOTE(review): presumably the formatter renders that as `count(*)` — confirm.
  def count(column : ColumnDeclaring(Comparable)? = nil) : AggregateExpression(Int32)
    AggregateExpression(Int32).new(AggregateType::COUNT, column.try(&.as_expression), is_distinct: false)
  end

  def count_distinct(column : ColumnDeclaring(Comparable)) : AggregateExpression(Int32)
    AggregateExpression(Int32).new(AggregateType::COUNT, column.as_expression, is_distinct: true)
  end
end

--------------------------------------------------------------------------------
/src/focus/dsl/operators.cr:
--------------------------------------------------------------------------------
# Operator DSL mixed into everything that declares a value of type `T`.
#
# Each operation comes in up to four flavours: a named method taking another
# expression, the same method taking a plain value (wrapped with
# `wrap_argument`), and the matching Crystal operator for both.
# The includer must provide `as_expression` and `wrap_argument`.
module Focus::Dsl::Operators(T)
  # --- equality -------------------------------------------------------------

  def eq(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    Focus::BinaryExpression(Bool, T).new(
      Focus::BinaryExpressionType::EQUAL,
      left: as_expression,
      right: expr.as_expression
    )
  end

  def eq(val : T) : Focus::BinaryExpression(Bool, T)
    eq(wrap_argument(val))
  end

  def ==(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    eq(expr)
  end

  def ==(val : T) : Focus::BinaryExpression(Bool, T)
    eq(wrap_argument(val))
  end

  def not_eq(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    Focus::BinaryExpression(Bool, T).new(
      Focus::BinaryExpressionType::NOT_EQUAL,
      left: as_expression,
      right: expr.as_expression
    )
  end

  def not_eq(val : T) : Focus::BinaryExpression(Bool, T)
    not_eq(wrap_argument(val))
  end

  def !=(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    not_eq(expr)
  end

  def !=(val : T) : Focus::BinaryExpression(Bool, T)
    not_eq(wrap_argument(val))
  end

  # --- range and membership -------------------------------------------------

  # Uses the range's `begin` and `end` as the lower and upper bounds.
  def between(range : Range(T, T)) : BetweenExpression(T)
    BetweenExpression(T).new(as_expression, wrap_argument(range.begin), wrap_argument(range.end))
  end

  def not_between(range : Range(T, T)) : BetweenExpression(T)
    BetweenExpression(T).new(as_expression, wrap_argument(range.begin), wrap_argument(range.end), not_between: true)
  end

  def in_list(*list : T) : InListExpression(T)
    in_list(list.to_a)
  end

  def in_list(list : Array(T)) : InListExpression(T)
    values = list.map { |value| wrap_argument(value) }
    InListExpression(T).new(left: as_expression, values: values)
  end

  # Membership test against the rows produced by a sub-query.
  def in_list(query : Query) : InListExpression(T)
    InListExpression(T).new(left: as_expression, query: query.expression)
  end

  def not_in_list(*list : T) : InListExpression(T)
    values = list.map { |value| wrap_argument(value) }.to_a
    InListExpression(T).new(left: as_expression, values: values, not_in_list: true)
  end

  def not_in_list(list : Array(T)) : InListExpression(T)
    values = list.map { |value| wrap_argument(value) }
    InListExpression(T).new(left: as_expression, values: values, not_in_list: true)
  end

  def not_in_list(query : Query) : InListExpression(T)
    InListExpression(T).new(left: as_expression, query: query.expression, not_in_list: true)
  end

  # --- unary operators ------------------------------------------------------

  def is_null : UnaryExpression(Bool)
    UnaryExpression(Bool).new(
      Focus::UnaryExpressionType::IS_NULL,
      operand: as_expression
    )
  end

  def is_not_null : UnaryExpression(Bool)
    UnaryExpression(Bool).new(
      Focus::UnaryExpressionType::IS_NOT_NULL,
      operand: as_expression
    )
  end

  def unary_minus : UnaryExpression(T)
    UnaryExpression(T).new(
      Focus::UnaryExpressionType::UNARY_MINUS,
      operand: as_expression
    )
  end

  def unary_plus : UnaryExpression(T)
    UnaryExpression(T).new(
      Focus::UnaryExpressionType::UNARY_PLUS,
      operand: as_expression
    )
  end

  def not : UnaryExpression(Bool)
    UnaryExpression(Bool).new(
      Focus::UnaryExpressionType::NOT,
      operand: as_expression
    )
  end

  # --- arithmetic -----------------------------------------------------------

  def plus(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(T, T)
    Focus::BinaryExpression(T, T).new(
      BinaryExpressionType::PLUS,
      as_expression,
      expr.as_expression
    )
  end

  def plus(value : T) : Focus::BinaryExpression(T, T)
    plus(wrap_argument(value))
  end

  def +(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(T, T)
    plus(expr)
  end

  def +(value : T) : Focus::BinaryExpression(T, T)
    plus(wrap_argument(value))
  end

  def minus(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(T, T)
    Focus::BinaryExpression(T, T).new(
      BinaryExpressionType::MINUS,
      as_expression,
      expr.as_expression
    )
  end

  def minus(value : T) : Focus::BinaryExpression(T, T)
    minus(wrap_argument(value))
  end

  def -(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(T, T)
    minus(expr)
  end

  def -(value : T) : Focus::BinaryExpression(T, T)
    minus(wrap_argument(value))
  end

  def times(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(T, T)
    Focus::BinaryExpression(T, T).new(
      BinaryExpressionType::TIMES,
      as_expression,
      expr.as_expression
    )
  end

  def times(value : T) : Focus::BinaryExpression(T, T)
    times(wrap_argument(value))
  end

  def *(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(T, T)
    times(expr)
  end

  def *(value : T) : Focus::BinaryExpression(T, T)
    times(wrap_argument(value))
  end

  def div(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(T, T)
    Focus::BinaryExpression(T, T).new(
      BinaryExpressionType::DIV,
      as_expression,
      expr.as_expression
    )
  end

  def div(value : T) : Focus::BinaryExpression(T, T)
    div(wrap_argument(value))
  end

  def /(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(T, T)
    div(expr)
  end

  def /(value : T) : Focus::BinaryExpression(T, T)
    div(wrap_argument(value))
  end

  def rem(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(T, T)
    Focus::BinaryExpression(T, T).new(
      BinaryExpressionType::REM,
      as_expression,
      expr.as_expression
    )
  end

  def rem(value : T) : Focus::BinaryExpression(T, T)
    rem(wrap_argument(value))
  end

  def %(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(T, T)
    rem(expr)
  end

  def %(value : T) : Focus::BinaryExpression(T, T)
    rem(wrap_argument(value))
  end

  # --- pattern matching -----------------------------------------------------

  def like(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    Focus::BinaryExpression(Bool, T).new(
      BinaryExpressionType::LIKE,
      as_expression,
      expr.as_expression
    )
  end

  def like(value : T) : Focus::BinaryExpression(Bool, T)
    like(wrap_argument(value))
  end

  def not_like(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    Focus::BinaryExpression(Bool, T).new(
      BinaryExpressionType::NOT_LIKE,
      as_expression,
      expr.as_expression
    )
  end

  def not_like(value : T) : Focus::BinaryExpression(Bool, T)
    not_like(wrap_argument(value))
  end

  # --- logical --------------------------------------------------------------

  def and(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    Focus::BinaryExpression(Bool, T).new(
      BinaryExpressionType::AND,
      as_expression,
      expr.as_expression
    )
  end

  def and(value : T) : Focus::BinaryExpression(Bool, T)
    and(wrap_argument(value))
  end

  def &(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    and(expr)
  end

  def &(value : T) : Focus::BinaryExpression(Bool, T)
    and(wrap_argument(value))
  end

  def or(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    Focus::BinaryExpression(Bool, T).new(
      BinaryExpressionType::OR,
      as_expression,
      expr.as_expression
    )
  end

  def or(value : T) : Focus::BinaryExpression(Bool, T)
    or(wrap_argument(value))
  end

  def |(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    or(expr)
  end

  def |(value : T) : Focus::BinaryExpression(Bool, T)
    or(wrap_argument(value))
  end

  def xor(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    Focus::BinaryExpression(Bool, T).new(
      BinaryExpressionType::XOR,
      as_expression,
      expr.as_expression
    )
  end

  def xor(value : T) : Focus::BinaryExpression(Bool, T)
    xor(wrap_argument(value))
  end

  def ^(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    xor(expr)
  end

  def ^(value : T) : Focus::BinaryExpression(Bool, T)
    xor(wrap_argument(value))
  end

  # --- ordering comparisons -------------------------------------------------

  def less_than(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    Focus::BinaryExpression(Bool, T).new(
      BinaryExpressionType::LESS_THAN,
      as_expression,
      expr.as_expression
    )
  end

  def less_than(value : T) : Focus::BinaryExpression(Bool, T)
    less_than(wrap_argument(value))
  end

  def <(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    less_than(expr)
  end

  def <(value : T) : Focus::BinaryExpression(Bool, T)
    less_than(wrap_argument(value))
  end

  def greater_than(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    Focus::BinaryExpression(Bool, T).new(
      BinaryExpressionType::GREATER_THAN,
      as_expression,
      expr.as_expression
    )
  end

  def greater_than(value : T) : Focus::BinaryExpression(Bool, T)
    greater_than(wrap_argument(value))
  end

  def >(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    greater_than(expr)
  end

  def >(value : T) : Focus::BinaryExpression(Bool, T)
    greater_than(wrap_argument(value))
  end

  def less_than_or_equal(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    Focus::BinaryExpression(Bool, T).new(
      BinaryExpressionType::LESS_THAN_OR_EQUAL,
      as_expression,
      expr.as_expression
    )
  end

  def less_than_or_equal(value : T) : Focus::BinaryExpression(Bool, T)
    less_than_or_equal(wrap_argument(value))
  end

  def <=(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    less_than_or_equal(expr)
  end

  def <=(value : T) : Focus::BinaryExpression(Bool, T)
    less_than_or_equal(wrap_argument(value))
  end

  def greater_than_or_equal(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    Focus::BinaryExpression(Bool, T).new(
      BinaryExpressionType::GREATER_THAN_OR_EQUAL,
      as_expression,
      expr.as_expression
    )
  end

  def greater_than_or_equal(value : T) : Focus::BinaryExpression(Bool, T)
    greater_than_or_equal(wrap_argument(value))
  end

  def >=(expr : ColumnDeclaring(T)) : Focus::BinaryExpression(Bool, T)
    greater_than_or_equal(expr)
  end

  def >=(value : T) : Focus::BinaryExpression(Bool, T)
    greater_than_or_equal(wrap_argument(value))
  end
end

# Operators that do not hang off a column: EXISTS / NOT EXISTS over a sub-query.
module Focus::Dsl::TopLevelOperators
  def exists(query : Query) : ExistsExpression
    ExistsExpression.new(query.expression)
  end

  def not_exists(query : Query) : ExistsExpression
    ExistsExpression.new(query.expression, not_exists: true)
  end
end
--------------------------------------------------------------------------------
/src/focus/query.cr:
--------------------------------------------------------------------------------
# An immutable SELECT builder bound to a database.
#
# Builder methods (`where`, `order_by`, `limit`, ...) return a new `Query`
# wrapping a copied `SelectExpression`. Rows are fetched on first access and
# cached on the instance.
class Focus::Query
  include Enumerable(Focus::CachedRow)

  getter database : Focus::Database
  getter expression : Focus::SelectExpression

  def initialize(@database : Focus::Database, @expression : Focus::SelectExpression)
  end

  # Yields every cached row (executing the query on first use).
  def each(&block : Focus::CachedRow -> Nil)
    rows.each do |row|
      yield row
    end
  end

  @rows : Array(Focus::CachedRow)?

  # Executes the query once, copies every row into a `Focus::CachedRow`,
  # closes the underlying result set and memoizes the array.
  def rows : Array(Focus::CachedRow)
    @rows ||= begin
      rows = [] of Focus::CachedRow
      result_set = inner_result_set
      begin
        result_set.each do
          rows << Focus::CachedRow.build(result_set)
        end
      ensure
        result_set.close
      end
      rows
    end
  end

  # A `DB::ResultSet` backed by the cached rows.
  def result_set : DB::ResultSet
    CachedResultSet.new(rows, inner_result_set)
  end

  @inner_result_set : DB::ResultSet?

  # only meant for use between #rows and #result_set
  # it can only be used once
  private def inner_result_set : DB::ResultSet
    @inner_result_set ||= database.execute_query(expression)
  end

  # Maps every row onto `entity` via `entity.from_rs`.
  def bind_to(entity : T.class) : Array(T) forall T
    entity.from_rs(result_set)
  end

  # Maps the row at `index` onto `entity` by running a query that skips
  # `index` rows and takes one; nil when there is no such row.
  def bind_to_one?(entity : T.class, at index : Int32) : T? forall T
    drop(index)
      .take(1)
      .bind_to(entity)
      .first?
  end

  def bind_to_one(entity : T.class, at index : Int32) : T forall T
    bind_to_one?(entity, index).not_nil!
  end

  def bind_to_first?(entity : T.class) : T? forall T
    bind_to_one?(entity, at: 0)
  end

  def bind_to_first(entity : T.class) : T forall T
    bind_to_one(entity, at: 0)
  end

  # Maps the last cached row onto `entity`; nil when the query has no rows.
  def bind_to_last?(entity : T.class) : T? forall T
    result_set = CachedResultSet.new(rows.last(1), inner_result_set)
    entity.from_rs(result_set).pop?
  end

  def bind_to_last(entity : T.class) : T forall T
    bind_to_last?(entity).not_nil!
  end

  # The formatted SQL followed by the inspected list of argument values.
  def to_sql : String
    result = database.format_expression(expression)
    "#{result.first} #{result[1].map(&.value)}"
  end

  def where(condition : Focus::ScalarExpression(Bool)) : Focus::Query
    new_expression = expression.copy(where: condition)
    Focus::Query.new(database, new_expression)
  end

  # Yields an empty array for the caller to fill with conditions, then ANDs
  # them together. Returns `self` unchanged when no condition was added.
  def where_with_conditions(&block : Array(ColumnDeclaring(Bool)) -> Nil) : Focus::Query
    conditions = [] of ColumnDeclaring(Bool)
    yield conditions
    return self if conditions.empty?

    condition = conditions.reduce { |a, b| a.and b }
    where condition.as(Focus::ScalarExpression(Bool))
  end

  # Same as `#where_with_conditions` but the conditions are ORed together.
  def where_with_or_conditions(&block : Array(ColumnDeclaring(Bool)) -> Nil) : Focus::Query
    conditions = [] of ColumnDeclaring(Bool)
    yield conditions
    return self if conditions.empty?

    condition = conditions.reduce { |a, b| a.or b }
    where condition.as(Focus::ScalarExpression(Bool))
  end

  def group_by(*columns : BaseColumnDeclaring) : Query
    group_by(columns.to_a)
  end

  def group_by(columns : Array(BaseColumnDeclaring)) : Query
    new_expression = expression.copy(group_by: columns.map(&.as_expression.as(BaseScalarExpression)))
    Focus::Query.new(database, new_expression)
  end

  def having(condition : ColumnDeclaring(Bool)) : Query
    new_expression = expression.copy(having: condition.as_expression)
    Focus::Query.new(database, new_expression)
  end

  def order_by(*orders : OrderByExpression) : Query
    order_by(orders.to_a)
  end

  def order_by(orders : Array(OrderByExpression)) : Query
    new_expression = expression.copy(order_by: orders)
    Focus::Query.new(database, new_expression)
  end

  # Sets limit and offset. A nil or non-positive value keeps whatever the
  # current expression already has for that field.
  def limit(offset : Int32?, limit : Int32?) : Query
    new_limit = limit.try { |lim| lim > 0 ? lim : nil } || expression.limit
    new_offset = offset.try { |off| off > 0 ? off : nil } || expression.offset
    new_expression = expression.copy(limit: new_limit, offset: new_offset)
    Focus::Query.new(database, new_expression)
  end

  def limit(limit : Int32) : Query
    limit(limit: limit, offset: nil)
  end

  def offset(offset : Int32) : Query
    limit(limit: nil, offset: offset)
  end

  # Skips `n` more rows by adding `n` to the current offset.
  def drop(n : Int32) : Query
    if n.zero?
      self
    else
      offset = expression.offset || 0
      new_expression = expression.copy(offset: offset + n)
      Query.new(database, new_expression)
    end
  end

  # Caps the result at `n` rows; an existing smaller limit is kept.
  def take(n : Int32) : Query
    limit = expression.limit || Int32::MAX
    new_expression = expression.copy(limit: Math.min(limit, n))
    Query.new(database, new_expression)
  end

  def any? : Bool
    count > 0
  end

  def none? : Bool
    count.zero?
  end

  def count : Int32
    aggregate_columns(Focus.count).not_nil!
  end

  def sum_by(selector : ColumnDeclaring(T)) : T? forall T
    aggregate_columns(Focus.sum(selector))
  end

  def max_by(selector : ColumnDeclaring(T)) : T? forall T
    aggregate_columns(Focus.max(selector))
  end

  def min_by(selector : ColumnDeclaring(T)) : T? forall T
    aggregate_columns(Focus.min(selector))
  end

  def average_by(selector : BaseColumnDeclaring) : Float32?
    aggregate_columns(Focus.avg(selector))
  end

  # Runs this query with its select list replaced by `aggregation` and reads
  # the first column of the first row.
  #
  # Returns nil when the query produces no row, instead of raising from
  # `Array#first`, so the declared `T?` return type holds.
  def aggregate_columns(aggregation : ColumnDeclaring(T)) : T? forall T
    new_expression = expression.copy(columns: [aggregation.aliased(nil).as(BaseColumnDeclaringExpression)])
    new_query = Query.new(database, new_expression)
    new_query.rows.first?.try(&.get?(0, T))
  end
end

--------------------------------------------------------------------------------
/src/focus/query_source.cr:
--------------------------------------------------------------------------------
# The FROM part of a statement: a table, optionally joined with others.
# `select*` turns it into a `Focus::Query`; `*_join` returns a new source.
class Focus::QuerySource
  getter database : Focus::Database
  getter table : Focus::Table
  getter expression : Focus::QuerySourceExpression

  def initialize(@database, @table, @expression)
  end

  def select(*columns : Focus::BaseColumnDeclaring) : Focus::Query
    self.select(columns.to_a)
  end

  def select(columns : Enumerable(Focus::BaseColumnDeclaring)) : Focus::Query
    columns = columns.map(&.as_declaring_expression).select(Focus::BaseColumnDeclaringExpression) # cuz generics
    select_expression = Focus::SelectExpression.new(columns: columns, from: expression)
    Focus::Query.new(database: database, expression: select_expression)
  end

  # Select without an explicit column list.
  def select : Focus::Query
    select_expression = Focus::SelectExpression.new(from: expression)
    Focus::Query.new(database: database, expression: select_expression)
  end

  def select_distinct(*columns : Focus::BaseColumnDeclaring) : Focus::Query
    self.select_distinct(columns.to_a)
  end

  def select_distinct(columns : Enumerable(Focus::BaseColumnDeclaring)) : Focus::Query
    columns = columns.map(&.as_declaring_expression).select(Focus::BaseColumnDeclaringExpression) # cuz generics
    select_expression = Focus::SelectExpression.new(columns: columns, from: expression, is_distinct: true)
    Focus::Query.new(database: database, expression: select_expression)
  end

  def select_distinct : Focus::Query
    select_expression = Focus::SelectExpression.new(from: expression, is_distinct: true)
    Focus::Query.new(database: database, expression: select_expression)
  end

  def cross_join(right : Focus::Table, on : ColumnDeclaring(Bool)? = nil) : QuerySource
    build_join(JoinType::CROSS_JOIN, right, on)
  end

  def inner_join(right : Focus::Table, on : ColumnDeclaring(Bool)? = nil) : QuerySource
    build_join(JoinType::INNER_JOIN, right, on)
  end

  def left_join(right : Focus::Table, on : ColumnDeclaring(Bool)? = nil) : QuerySource
    build_join(JoinType::LEFT_JOIN, right, on)
  end

  def right_join(right : Focus::Table, on : ColumnDeclaring(Bool)? = nil) : QuerySource
    build_join(JoinType::RIGHT_JOIN, right, on)
  end

  # Joins the current source (left) with `right` using `type` and the optional
  # `on` condition; the new source keeps the original `table`.
  private def build_join(type : JoinType, right : Focus::Table, on : ColumnDeclaring(Bool)?) : QuerySource
    new_expression = JoinExpression.new(
      type: type,
      left: expression,
      right: right.as_expression,
      condition: on.try(&.as_expression)
    )
    QuerySource.new(database, table, new_expression)
  end
end

--------------------------------------------------------------------------------
/src/focus/sql_expression.cr:
--------------------------------------------------------------------------------
# Root of the expression tree. Every node can be visited and states whether
# it should be wrapped in parentheses (true unless a node overrides it).
module Focus::SqlExpression
  def accept(visitor : Focus::SqlVisitor) : Nil
    visitor.visit(self)
  end

  def wrap_in_parens? : Bool
    true
  end
end

--------------------------------------------------------------------------------
/src/focus/sql_expressions.cr:
--------------------------------------------------------------------------------
# Untyped marker for expressions that evaluate to a single value.
module Focus::BaseScalarExpression
  include Focus::SqlExpression
end

# A value expression of type `T`. It is its own `as_expression` and gains the
# operator DSL through `Focus::ColumnDeclaring(T)`.
module Focus::ScalarExpression(T)
  include Focus::BaseScalarExpression
  include Focus::ColumnDeclaring(T)

  def as_expression : self
    self
  end

  def wrap_argument(argument : T?) : Focus::ArgumentExpression(T)
    ArgumentExpression(T).new(argument)
  end

  # Declares this expression under `label` (no declared name when nil).
  def aliased(label : String? = nil) : Focus::ColumnDeclaringExpression(T)
    Focus::ColumnDeclaringExpression(T).new(self, label)
  end

  def as_declaring_expression : Focus::ColumnDeclaringExpression(T)
    aliased(nil)
  end
end

# Anything usable as the source of a query (tables, joins, sub-selects).
module Focus::QuerySourceExpression
  include Focus::SqlExpression
end

# A query that can itself be used as a query source.
module Focus::QueryExpression
  include Focus::QuerySourceExpression

  getter table_alias : String?
end

# Reference to a table by name, with optional alias, catalog and schema.
class Focus::TableExpression
  include Focus::QuerySourceExpression

  getter name : String
  getter table_alias : String?
  getter catalog : String?
  getter schema : String?

  def initialize(@name, @table_alias = nil, @catalog = nil, @schema = nil)
  end

  def wrap_in_parens? : Bool
    false
  end
end

# A SELECT statement. Instances are not mutated by the query builder; `#copy`
# produces a modified clone.
class Focus::SelectExpression
  include Focus::QueryExpression

  getter columns : Array(BaseColumnDeclaringExpression)
  getter from : QuerySourceExpression
  getter where : ScalarExpression(Bool)?
  getter group_by : Array(BaseScalarExpression)
  getter having : ScalarExpression(Bool)?
  getter order_by : Array(OrderByExpression)
  getter is_distinct : Bool
  getter limit : Int32?
  getter offset : Int32?

  def initialize(
    @from,
    @columns = [] of BaseColumnDeclaringExpression,
    @where = nil,
    @group_by = [] of BaseScalarExpression,
    @having = nil,
    @order_by = [] of OrderByExpression,
    @is_distinct = false,
    @limit = nil,
    @offset = nil
  )
  end

  # Returns a new expression in which every argument left at its default
  # keeps the current value.
  def copy(
    columns = self.columns,
    from = self.from,
    where = self.where,
    group_by = self.group_by,
    having = self.having,
    order_by = self.order_by,
    is_distinct = self.is_distinct,
    limit = self.limit,
    offset = self.offset
  )
    SelectExpression.new(
      from,
      columns,
      where,
      group_by,
      having,
      order_by,
      is_distinct,
      limit,
      offset
    )
  end
end

# An aggregate function call producing a value of type `T`.
class Focus::AggregateExpression(T)
  include Focus::ScalarExpression(T)

  getter type : Focus::AggregateType
  getter argument : Focus::BaseScalarExpression?
  getter is_distinct : Bool

  def initialize(@type, @argument, @is_distinct)
  end

  # SQL function name for this aggregate.
  def method : String
    type.method
  end

  def wrap_in_parens? : Bool
    false
  end
end

enum Focus::AggregateType
  MIN
  MAX
  AVG
  SUM
  COUNT

  # Lower-case SQL function name.
  def method : String
    case self
    when MIN
      "min"
    when MAX
      "max"
    when AVG
      "avg"
    when SUM
      "sum"
    when COUNT
      "count"
    else
      raise "missing a case statement for #{self}"
    end
  end
end

# Untyped access to a bound argument's value.
module Focus::BaseArgumentExpression
  abstract def value
end

# A literal value of type `T` (or nil) passed to the statement as an argument.
class Focus::ArgumentExpression(T)
  include Focus::ScalarExpression(T)
  include Focus::BaseArgumentExpression

  getter value : T?

  def initialize(@value)
  end
end

# `expression [not] between lower and upper`.
class Focus::BetweenExpression(T)
  include Focus::ScalarExpression(Bool)

  getter expression : Focus::ScalarExpression(T)
  getter lower : Focus::ScalarExpression(T)
  getter upper : Focus::ScalarExpression(T)
  getter not_between : Bool

  def initialize(@expression, @lower, @upper, @not_between = false)
  end

  def wrap_in_parens? : Bool
    false
  end
end

# A binary operation over two operands of type `V` yielding a value of type `T`.
class Focus::BinaryExpression(T, V)
  include Focus::ScalarExpression(T)

  getter type : Focus::BinaryExpressionType
  getter left : Focus::ScalarExpression(V)
  getter right : Focus::ScalarExpression(V)

  def initialize(@type, @left, @right)
  end

  # SQL operator text for this expression.
  def operator : String
    type.operator
  end
end

enum Focus::BinaryExpressionType
  PLUS
  MINUS
  TIMES
  DIV
  REM
  LIKE
  NOT_LIKE
  AND
  OR
  XOR
  LESS_THAN
  LESS_THAN_OR_EQUAL
  GREATER_THAN
  GREATER_THAN_OR_EQUAL
  EQUAL
  NOT_EQUAL

  # SQL operator text.
  def operator : String
    case self
    when PLUS
      "+"
    when MINUS
      "-"
    when TIMES
      "*"
    when DIV
      "/"
    when REM
      "%"
    when LIKE
      "like"
    when NOT_LIKE
      "not like"
    when AND
      "and"
    when OR
      "or"
    when XOR
      "xor"
    when LESS_THAN
      "<"
    when LESS_THAN_OR_EQUAL
      "<="
    when GREATER_THAN
      ">"
    when GREATER_THAN_OR_EQUAL
      ">="
    when EQUAL
      "="
    when NOT_EQUAL
      "<>"
    else
      raise "missing a case statement for #{self}"
    end
  end
end

# Untyped view of a column reference: optional table plus column name.
module Focus::BaseColumnExpression
  include Focus::SqlExpression

  getter table : Focus::TableExpression?
  getter name : String

  def wrap_in_parens? : Bool
    false
  end
end

# A column reference of type `T`, optionally qualified by a table.
class Focus::ColumnExpression(T)
  include Focus::ScalarExpression(T)
  include Focus::BaseColumnExpression

  def initialize(@table, @name)
  end

  # Unqualified column reference.
  def initialize(@name)
    @table = nil
  end
end

# Untyped view of a select-list entry and its optional declared name.
module Focus::BaseColumnDeclaringExpression
  include Focus::SqlExpression

  getter declared_name : String?

  def wrap_in_parens? : Bool
    false
  end
end

# A select-list entry: an expression plus its optional declared name.
class Focus::ColumnDeclaringExpression(T)
  include Focus::ScalarExpression(T)
  include Focus::BaseColumnDeclaringExpression

  getter expression : ScalarExpression(T)

  def initialize(@expression, @declared_name)
  end
end

# Untyped view of a `column = expression` assignment.
module Focus::BaseColumnAssignmentExpression
  include Focus::SqlExpression

  abstract def column : Focus::BaseColumnExpression
  abstract def expression : Focus::BaseScalarExpression
end

# Assignment of an expression of type `T` to a column of type `T`.
class Focus::ColumnAssignmentExpression(T)
  include Focus::BaseColumnAssignmentExpression

  getter column : Focus::ColumnExpression(T)
  getter expression : Focus::ScalarExpression(T)

  def initialize(@column, @expression)
  end
end

# A DELETE statement with an optional where clause.
class Focus::DeleteExpression
  include Focus::SqlExpression
  getter table : TableExpression
  getter where : ScalarExpression(Bool)?

  def initialize(@table, @where)
  end
end

# `[not] exists (query)`.
class Focus::ExistsExpression
  include Focus::ScalarExpression(Bool)

  getter query : QueryExpression
  getter not_exists : Bool

  def initialize(@query, @not_exists = false)
  end
end

# `left [not] in (...)`, where the right side is either a sub-query or a list
# of argument values.
class Focus::InListExpression(T)
  include Focus::ScalarExpression(Bool)

  getter left : ScalarExpression(T)
  getter query : QueryExpression?
  getter values : Array(ArgumentExpression(T))?
  getter not_in_list : Bool

  def initialize(@left, @query = nil, @values = nil, @not_in_list = false)
  end
end

# An INSERT statement: target table plus column assignments.
class Focus::InsertExpression
  include Focus::SqlExpression

  getter table : Focus::TableExpression
  getter assignments : Array(BaseColumnAssignmentExpression)

  def initialize(@table, @assignments)
  end
end

# A join of two query sources with an optional condition.
class Focus::JoinExpression
  include Focus::QuerySourceExpression

  getter type : JoinType
  getter left : QuerySourceExpression
  getter right : QuerySourceExpression
  getter condition : ScalarExpression(Bool)?

  def initialize(@type, @left, @right, @condition = nil)
  end

  # SQL keyword(s) for this join.
  def join_type : String
    type.join_type
  end
end

enum Focus::JoinType
  CROSS_JOIN
  INNER_JOIN
  LEFT_JOIN
  RIGHT_JOIN

  # SQL keyword(s) for the join type.
  def join_type : String
    case self
    when CROSS_JOIN
      "cross join"
    when INNER_JOIN
      "inner join"
    when LEFT_JOIN
      "left join"
    when RIGHT_JOIN
      "right join"
    else
      raise "missing a case statement for #{self}"
    end
  end
end

# One ORDER BY entry: the expression to sort on and the direction.
class Focus::OrderByExpression
  include Focus::SqlExpression

  getter expression : BaseScalarExpression
  getter order_type : OrderType

  def initialize(@expression, @order_type)
  end

  # SQL keyword for the direction.
  def order : String
    order_type.order
  end
end

enum Focus::OrderType
  ASCENDING
  DESCENDING

  # SQL keyword for the direction.
  def order : String
    case self
    when ASCENDING
      "asc"
    when DESCENDING
      "desc"
    else
      raise "missing a case statement for #{self}"
    end
  end
end

class Focus::UnaryExpression(T)
  include Focus::ScalarExpression(T)

  getter type : Focus::UnaryExpressionType
  getter operand : Focus::BaseScalarExpression

  def
initialize(@type, @operand) 426 | end 427 | 428 | def operator : String 429 | type.operator 430 | end 431 | end 432 | 433 | enum Focus::UnaryExpressionType 434 | IS_NULL 435 | IS_NOT_NULL 436 | UNARY_MINUS 437 | UNARY_PLUS 438 | NOT 439 | 440 | def operator : String 441 | case self 442 | when IS_NULL 443 | "is null" 444 | when IS_NOT_NULL 445 | "is not null" 446 | when UNARY_MINUS 447 | "-" 448 | when UNARY_PLUS 449 | "+" 450 | when NOT 451 | "not" 452 | else 453 | raise "missing a case statement for #{self}" 454 | end 455 | end 456 | end 457 | 458 | class Focus::UpdateExpression 459 | include Focus::SqlExpression 460 | 461 | getter table : Focus::TableExpression 462 | getter assignments : Array(Focus::BaseColumnAssignmentExpression) 463 | getter where : Focus::ScalarExpression(Bool)? 464 | 465 | def initialize(@table, @assignments, @where = nil) 466 | end 467 | end 468 | -------------------------------------------------------------------------------- /src/focus/sql_formatter.cr: -------------------------------------------------------------------------------- 1 | require "./sql_visitor" 2 | 3 | abstract class Focus::SqlFormatter < Focus::SqlVisitor 4 | WHITESPACE_BYTE = 32_u8 5 | 6 | private getter sql_string_builder = String::Builder.new 7 | getter parameters = [] of Focus::BaseArgumentExpression 8 | 9 | def visit(expression : Focus::SelectExpression) 10 | write "select " 11 | write "distinct " if expression.is_distinct 12 | if expression.columns.empty? 13 | write "* " 14 | else 15 | visit_list(expression.columns) 16 | end 17 | write "from " 18 | visit_query_source(expression.from) 19 | if where = expression.where 20 | write "where " 21 | where.accept(self) 22 | end 23 | if !expression.group_by.empty? 24 | write "group by " 25 | visit_list expression.group_by 26 | end 27 | if having = expression.having 28 | write "having " 29 | having.accept(self) 30 | end 31 | if !expression.order_by.empty? 
32 | write "order by " 33 | visit_list expression.order_by 34 | end 35 | 36 | if expression.limit || expression.offset 37 | write_pagination(expression) 38 | end 39 | end 40 | 41 | def visit(expression : Focus::BaseColumnExpression) 42 | if table = expression.table 43 | if table_alias = table.table_alias.presence 44 | write "#{quoted(table_alias)}." 45 | else 46 | if catalog = table.catalog.presence 47 | write "#{quoted(catalog)}." 48 | end 49 | if schema = table.schema.presence 50 | write "#{quoted(schema)}." 51 | end 52 | write "#{quoted(table.name)}." 53 | end 54 | end 55 | write "#{quoted(expression.name)} " 56 | end 57 | 58 | def visit(expression : Focus::TableExpression) 59 | if catalog = expression.catalog.presence 60 | write "#{quoted(catalog)}." 61 | end 62 | if schema = expression.schema.presence 63 | write "#{quoted(schema)}." 64 | end 65 | write "#{quoted(expression.name)} " 66 | 67 | if table_alias = expression.table_alias.presence 68 | write "#{quoted(table_alias)} " 69 | end 70 | end 71 | 72 | def visit(expression : Focus::BinaryExpression) 73 | if expression.left.wrap_in_parens? 74 | wrap_in_parens do 75 | expression.left.accept(self) 76 | end 77 | else 78 | expression.left.accept(self) 79 | end 80 | 81 | write "#{expression.operator} " 82 | 83 | if expression.right.wrap_in_parens? 84 | wrap_in_parens do 85 | expression.right.accept(self) 86 | end 87 | else 88 | expression.right.accept(self) 89 | end 90 | end 91 | 92 | def visit(expression : Focus::UnaryExpression(_)) 93 | case expression.type 94 | when UnaryExpressionType::IS_NULL, UnaryExpressionType::IS_NOT_NULL 95 | if expression.operand.wrap_in_parens? 96 | wrap_in_parens do 97 | expression.operand.accept(self) 98 | end 99 | else 100 | expression.operand.accept(self) 101 | end 102 | write "#{expression.operator} " 103 | else 104 | write "#{expression.operator} " 105 | 106 | if expression.operand.wrap_in_parens? 
107 | wrap_in_parens do 108 | expression.operand.accept(self) 109 | end 110 | else 111 | expression.operand.accept(self) 112 | end 113 | end 114 | end 115 | 116 | abstract def visit(expression : Focus::ArgumentExpression) 117 | 118 | def visit(expression : Focus::BetweenExpression(_)) 119 | expression.expression.accept(self) 120 | 121 | if expression.not_between 122 | write "not between " 123 | else 124 | write "between " 125 | end 126 | 127 | expression.lower.accept(self) 128 | write "and " 129 | expression.upper.accept(self) 130 | end 131 | 132 | def visit(expression : Focus::ColumnDeclaringExpression(_)) 133 | expression.expression.accept(self) 134 | declared_name = expression.declared_name.presence 135 | column_expression = expression.expression.as?(Focus::ColumnExpression) 136 | if declared_name && (column_expression.nil? || column_expression.name != declared_name) 137 | write "as #{quoted(declared_name)} " 138 | end 139 | end 140 | 141 | def visit(expression : Focus::AggregateExpression(_)) 142 | write "#{expression.method}(" 143 | if expression.is_distinct 144 | write "distinct " 145 | end 146 | 147 | if arg = expression.argument 148 | arg.accept(self) 149 | else 150 | write "*" 151 | end 152 | 153 | remove_last_blank 154 | write ") " 155 | end 156 | 157 | def visit(expression : Focus::InsertExpression) 158 | write "insert into " 159 | expression.table.accept(self) 160 | write_insert_column_names(expression.assignments.map(&.column.as(BaseColumnExpression))) 161 | write "values " 162 | write_insert_values(expression.assignments) 163 | end 164 | 165 | def visit(expression : Focus::UpdateExpression) 166 | write "update " 167 | expression.table.accept(self) 168 | write "set " 169 | visit_column_assignments(expression.assignments) 170 | if where = expression.where 171 | write "where " 172 | where.accept(self) 173 | end 174 | end 175 | 176 | def visit(expression : Focus::InListExpression) 177 | expression.left.accept(self) 178 | 179 | if expression.not_in_list 
180 | write "not in " 181 | else 182 | write "in " 183 | end 184 | 185 | if query = expression.query 186 | visit_query_source(query) 187 | end 188 | if values = expression.values 189 | write "(" 190 | visit_list(values) 191 | remove_last_blank 192 | write ") " 193 | end 194 | end 195 | 196 | def visit(expression : Focus::JoinExpression) 197 | visit_query_source(expression.left) 198 | write "#{expression.join_type} " 199 | visit_query_source(expression.right) 200 | 201 | if condition = expression.condition 202 | write "on " 203 | condition.accept(self) 204 | end 205 | end 206 | 207 | def visit(expression : Focus::OrderByExpression) 208 | expression.expression.accept(self) 209 | if expression.order_type == OrderType::DESCENDING 210 | write "desc " 211 | end 212 | end 213 | 214 | def visit(expression : Focus::DeleteExpression) 215 | write "delete from " 216 | expression.table.accept(self) 217 | 218 | if where = expression.where 219 | write "where " 220 | where.accept(self) 221 | end 222 | end 223 | 224 | def visit(expression : ExistsExpression) 225 | if expression.not_exists 226 | write "not exists " # trailing blank, same as the "exists " branch below 227 | else 228 | write "exists " 229 | end 230 | 231 | visit_query_source(expression.query) 232 | end 233 | 234 | # TODO: figure out a good way to handle formatters not 235 | # providing all expected overloads 236 | def visit(expression : Focus::SqlExpression) 237 | raise "No visit method found for #{expression.class.name}" 238 | end 239 | 240 | def to_sql : String 241 | sql_string_builder.to_s 242 | end 243 | 244 | protected def visit_list(expressions : Array(Focus::SqlExpression)) 245 | expressions.each_with_index do |expression, idx| 246 | if idx > 0 247 | remove_last_blank 248 | write ", " 249 | end 250 | 251 | expression.accept(self) 252 | end 253 | end 254 | 255 | protected def visit_column_assignments(assignments : Array(BaseColumnAssignmentExpression)) 256 | assignments.each_with_index do |assignment, idx| 257 | if idx > 0 258 | remove_last_blank 259 | write ", " 
260 | end 261 | 262 | write "#{quoted(assignment.column.name)} = " 263 | assignment.expression.accept(self) 264 | end 265 | end 266 | 267 | protected def visit_query_source(expression : QuerySourceExpression) 268 | case expression 269 | when TableExpression, JoinExpression 270 | expression.accept(self) 271 | when QueryExpression 272 | write "(" 273 | expression.accept(self) 274 | remove_last_blank 275 | write ") " # trailing blank, as wrap_in_parens and write_insert_values emit 276 | expression.table_alias.try { |it| write "#{quoted(it)} " } 277 | end 278 | end 279 | 280 | protected def write_insert_column_names(columns : Array(BaseColumnExpression)) 281 | write "(" 282 | columns.each_with_index do |column, idx| 283 | write ", " if idx > 0 284 | write quoted(column.name) 285 | end 286 | write ") " 287 | end 288 | 289 | protected def write_insert_values(assignments : Array(BaseColumnAssignmentExpression)) 290 | write "(" 291 | visit_list(assignments.map(&.expression.as(BaseScalarExpression))) 292 | remove_last_blank 293 | write ") " 294 | end 295 | 296 | protected abstract def write_pagination(expr : QueryExpression) 297 | 298 | protected def remove_last_blank 299 | sql_string_builder.chomp!(WHITESPACE_BYTE) 300 | end 301 | 302 | protected def write(str : String) 303 | sql_string_builder << str 304 | end 305 | 306 | protected def quoted(str : String) 307 | str 308 | end 309 | 310 | protected def wrap_in_parens 311 | write "(" 312 | yield 313 | remove_last_blank 314 | write ") " 315 | end 316 | end 317 | -------------------------------------------------------------------------------- /src/focus/sql_visitor.cr: -------------------------------------------------------------------------------- 1 | abstract class Focus::SqlVisitor 2 | end 3 | -------------------------------------------------------------------------------- /src/focus/table.cr: -------------------------------------------------------------------------------- 1 | abstract class Focus::Table 2 | annotation ColumnLabel 3 | end 4 | 5 | private macro column(type_declaration) 6 
| {% 7 | name = type_declaration.var 8 | name_str = name.stringify 9 | type = type_declaration.type 10 | %} 11 | 12 | @[ColumnLabel] 13 | getter {{ name }} : Focus::Column({{ type }}) do 14 | Focus::Column({{ type }}).new(table: self, name: {{ name_str }}) 15 | end 16 | end 17 | 18 | getter table_name : String 19 | getter columns : Array(Focus::BaseColumn) = [] of Focus::BaseColumn 20 | 21 | def initialize 22 | {% begin %} 23 | {% for ivar in @type.instance_vars %} 24 | {% ann = ivar.annotation(::Focus::Table::ColumnLabel) %} 25 | {% if ann %} 26 | columns << {{ ivar.id }} 27 | {% end %} 28 | {% end %} 29 | {% end %} 30 | end 31 | 32 | def as_expression : Focus::TableExpression 33 | Focus::TableExpression.new(name: table_name) 34 | end 35 | end 36 | -------------------------------------------------------------------------------- /src/focus/transaction_manager.cr: -------------------------------------------------------------------------------- 1 | # As long as I limit usage to block forms I can lean on crystal-db's connection and transaction handling 2 | class Focus::TransactionManager 3 | private getter raw_db : DB::Database 4 | private getter current_transaction : DB::Transaction? 5 | 6 | def initialize(@raw_db) 7 | end 8 | 9 | def with_connection(&block : DB::Connection -> T) : T forall T 10 | if transaction = current_transaction 11 | yield transaction.connection 12 | else 13 | raw_db.using_connection do |conn| 14 | yield conn 15 | end 16 | end 17 | end 18 | 19 | def with_transaction(&block : DB::Transaction -> T) : T? 
forall T 20 | if transaction = current_transaction 21 | yield transaction 22 | else 23 | raw_db.transaction do |txn| 24 | begin 25 | @current_transaction = txn 26 | yield txn 27 | ensure 28 | @current_transaction = nil 29 | end 30 | end 31 | end 32 | end 33 | end 34 | -------------------------------------------------------------------------------- /src/focus/update_statement_builder.cr: -------------------------------------------------------------------------------- 1 | class Focus::UpdateStatementBuilder < Focus::AssignmentsBuilder 2 | getter where : Focus::ColumnDeclaring(Bool)? 3 | 4 | def where(@where) 5 | end 6 | end 7 | -------------------------------------------------------------------------------- /src/mysql.cr: -------------------------------------------------------------------------------- 1 | require "mysql" 2 | require "./mysql/*" 3 | -------------------------------------------------------------------------------- /src/mysql/mysql_database.cr: -------------------------------------------------------------------------------- 1 | class Focus::MySqlDatabase < Focus::Database 2 | def self.connect(url : String) : MySqlDatabase 3 | new(raw_db: DB::Database.new(MySql::Driver.new, URI.parse(url))) 4 | end 5 | 6 | def self.connect(db : DB::Database) : MySqlDatabase 7 | new(raw_db: db) 8 | end 9 | 10 | def format_expression(expression : Focus::SqlExpression) : Tuple(String, Array(Focus::BaseArgumentExpression)) 11 | visitor = Focus::MySqlFormatter.new 12 | expression.accept(visitor) 13 | {visitor.to_sql, visitor.parameters} 14 | end 15 | 16 | def execute_insert_and_return_generated_key(expression : Focus::InsertExpression, column : Focus::BaseColumn) : DB::ResultSet 17 | with_connection do |conn| 18 | sql, args = format_expression(expression) 19 | conn.exec(sql, args: args.map(&.value)) 20 | conn.query("select last_insert_id()") 21 | end 22 | end 23 | end 24 | -------------------------------------------------------------------------------- 
/src/mysql/mysql_formatter.cr: -------------------------------------------------------------------------------- 1 | class Focus::MySqlFormatter < Focus::SqlFormatter 2 | def visit(expression : Focus::ArgumentExpression) 3 | write "? " 4 | parameters << expression 5 | end 6 | 7 | protected def write_pagination(expr : QueryExpression) 8 | write "limit ?, ? " 9 | parameters << ArgumentExpression(Int32).new(expr.offset || 0) 10 | parameters << ArgumentExpression(Int32).new(expr.limit || Int32::MAX) 11 | end 12 | end 13 | -------------------------------------------------------------------------------- /src/pg.cr: -------------------------------------------------------------------------------- 1 | require "pg" 2 | require "./pg/*" 3 | -------------------------------------------------------------------------------- /src/pg/i_like.cr: -------------------------------------------------------------------------------- 1 | class Focus::ILikeExpression 2 | include ScalarExpression(Bool) 3 | 4 | getter left : BaseScalarExpression 5 | getter right : BaseScalarExpression 6 | 7 | def initialize(@left, @right) 8 | end 9 | end 10 | 11 | module Focus::ColumnDeclaring(T) 12 | def i_like(expr : ColumnDeclaring(String)) : ILikeExpression 13 | ILikeExpression.new(as_expression, expr.as_expression) 14 | end 15 | 16 | def i_like(argument : String) : ILikeExpression 17 | ILikeExpression.new(as_expression, ArgumentExpression(String).new(argument)) # ArgumentExpression#initialize takes only the value 18 | end 19 | end 20 | -------------------------------------------------------------------------------- /src/pg/insert_returning_expression.cr: -------------------------------------------------------------------------------- 1 | class Focus::InsertOrUpdateExpression 2 | include Focus::SqlExpression 3 | 4 | getter table : Focus::TableExpression 5 | getter assignments : Array(BaseColumnAssignmentExpression) 6 | getter conflict_columns : Array(BaseColumnExpression) 7 | getter update_assignments : Array(BaseColumnAssignmentExpression) 8 | getter returning_columns : 
Array(BaseColumnExpression) 9 | 10 | def initialize( 11 | @table, 12 | @assignments, 13 | @conflict_columns = [] of BaseColumnExpression, 14 | @update_assignments = [] of BaseColumnAssignmentExpression, 15 | @returning_columns = [] of BaseColumnExpression 16 | ) 17 | end 18 | end 19 | -------------------------------------------------------------------------------- /src/pg/pg_database.cr: -------------------------------------------------------------------------------- 1 | class Focus::PGDatabase < Focus::Database 2 | def self.connect(url : String) : PGDatabase 3 | new(raw_db: DB::Database.new(PG::Driver.new, URI.parse(url))) 4 | end 5 | 6 | def self.connect(db : DB::Database) : PGDatabase 7 | new(raw_db: db) 8 | end 9 | 10 | def format_expression(expression : Focus::SqlExpression) : Tuple(String, Array(Focus::BaseArgumentExpression)) 11 | visitor = Focus::PGFormatter.new 12 | expression.accept(visitor) 13 | {visitor.to_sql, visitor.parameters} 14 | end 15 | 16 | def execute_insert_and_return_generated_key(expression : Focus::InsertExpression, column : Focus::BaseColumn) : DB::ResultSet 17 | returning_expression = InsertOrUpdateExpression.new( 18 | table: expression.table, 19 | assignments: expression.assignments, 20 | returning_columns: [column.as_expression] of BaseColumnExpression 21 | ) 22 | execute_query(returning_expression) 23 | end 24 | end 25 | -------------------------------------------------------------------------------- /src/pg/pg_formatter.cr: -------------------------------------------------------------------------------- 1 | class Focus::PGFormatter < Focus::SqlFormatter 2 | property argument_counter = 1 3 | 4 | def visit(expression : TableExpression) 5 | if catalog = expression.catalog.presence 6 | write "#{quoted(catalog)}." 7 | end 8 | if schema = expression.schema.presence 9 | write "#{quoted(schema)}." 
10 | end 11 | write "#{quoted(expression.name)} " 12 | 13 | if table_alias = expression.table_alias.presence 14 | write "as #{quoted(table_alias)} " 15 | end 16 | end 17 | 18 | def visit(expression : Focus::ArgumentExpression) 19 | write "$#{argument_counter} " 20 | parameters << expression 21 | self.argument_counter += 1 22 | end 23 | 24 | def visit(expression : Focus::ILikeExpression) 25 | if expression.left.wrap_in_parens? 26 | wrap_in_parens do 27 | expression.left.accept(self) 28 | end 29 | else 30 | expression.left.accept(self) 31 | end 32 | write "ilike " 33 | if expression.right.wrap_in_parens? 34 | wrap_in_parens do 35 | expression.right.accept(self) 36 | end 37 | else 38 | expression.right.accept(self) 39 | end 40 | end 41 | 42 | def visit(expression : Focus::InsertOrUpdateExpression) 43 | write "insert into " 44 | expression.table.accept(self) 45 | write_insert_column_names(expression.assignments.map(&.column.as(BaseColumnExpression))) 46 | write "values " 47 | write_insert_values(expression.assignments) 48 | 49 | if expression.conflict_columns.any? 50 | write "on conflict " 51 | write_insert_column_names(expression.conflict_columns) 52 | 53 | if expression.update_assignments.any? 54 | write "do update set " 55 | visit_column_assignments(expression.update_assignments) 56 | else 57 | write "do nothing " 58 | end 59 | end 60 | 61 | if expression.returning_columns.any? 
62 | write "returning " 63 | 64 | expression.returning_columns.each_with_index do |returning_column, idx| 65 | write ", " if idx > 0 66 | write quoted(returning_column.name) 67 | end 68 | end 69 | end 70 | 71 | protected def write_pagination(expr : QueryExpression) 72 | if limit = expr.limit 73 | write "limit " 74 | ArgumentExpression(Int32).new(limit).accept(self) 75 | end 76 | if offset = expr.offset 77 | write "offset " 78 | ArgumentExpression(Int32).new(offset).accept(self) 79 | end 80 | end 81 | end 82 | -------------------------------------------------------------------------------- /src/sqlite.cr: -------------------------------------------------------------------------------- 1 | require "sqlite3" 2 | require "./sqlite/*" 3 | -------------------------------------------------------------------------------- /src/sqlite/sqlite_database.cr: -------------------------------------------------------------------------------- 1 | class Focus::SQLiteDatabase < Focus::Database 2 | def self.connect(url : String) : SQLiteDatabase 3 | new(raw_db: DB::Database.new(SQLite3::Driver.new, URI.parse(url))) 4 | end 5 | 6 | def self.connect(db : DB::Database) : SQLiteDatabase 7 | new(raw_db: db) 8 | end 9 | 10 | def format_expression(expression : Focus::SqlExpression) : Tuple(String, Array(Focus::BaseArgumentExpression)) 11 | visitor = Focus::SQLiteFormatter.new 12 | expression.accept(visitor) 13 | {visitor.to_sql, visitor.parameters} 14 | end 15 | 16 | def execute_insert_and_return_generated_key(expression : Focus::InsertExpression, column : Focus::BaseColumn) : DB::ResultSet 17 | with_connection do |conn| 18 | sql, args = format_expression(expression) 19 | conn.exec(sql, args: args.map(&.value)) 20 | conn.query("select last_insert_rowid()") 21 | end 22 | end 23 | end 24 | -------------------------------------------------------------------------------- /src/sqlite/sqlite_formatter.cr: -------------------------------------------------------------------------------- 1 | class 
Focus::SQLiteFormatter < Focus::SqlFormatter 2 | def visit(expression : Focus::ArgumentExpression) 3 | write "? " 4 | parameters << expression 5 | end 6 | 7 | protected def write_pagination(expr : QueryExpression) 8 | write "limit ?, ? " 9 | parameters << ArgumentExpression(Int32).new(expr.offset || 0) 10 | parameters << ArgumentExpression(Int32).new(expr.limit || Int32::MAX) 11 | end 12 | end 13 | -------------------------------------------------------------------------------- /test/mysql/mysql_database_test.cr: -------------------------------------------------------------------------------- 1 | require "./mysql_test_base" 2 | 3 | class MySqlDatabaseTest < MySqlTestBase 4 | def test_it_works 5 | database.insert(Departments) do 6 | set(Departments.name, "r&d") 7 | set(Departments.location, "Boston") 8 | end 9 | 10 | count = database.from(Departments) 11 | .select(Focus.count(Departments.id)) 12 | .where(Departments.name.eq("r&d")) 13 | .first 14 | .get(0, Int32) 15 | 16 | assert_equal 1, count 17 | 18 | database.delete(Departments, where: Departments.name.eq("r&d")) 19 | 20 | count = database.from(Departments) 21 | .select(Focus.count(Departments.id)) 22 | .where(Departments.name.eq("r&d")) 23 | .first 24 | .get(0, Int32) 25 | 26 | assert_equal 0, count 27 | end 28 | 29 | def test_limit 30 | query = database.from(Employees).select.order_by(Employees.id.desc).limit(offset: 0, limit: 2) 31 | 32 | assert_equal [4, 3], query.map(&.get(Employees.id)) 33 | end 34 | 35 | def test_both_limit_and_offset_are_not_positive 36 | query = database.from(Employees).select.order_by(Employees.id.desc).limit(offset: 0, limit: -1) 37 | 38 | assert_equal [4, 3, 2, 1], query.map(&.get(Employees.id)) 39 | end 40 | 41 | def test_limit_without_offset 42 | query = database.from(Employees).select.order_by(Employees.id.desc).limit(2) 43 | 44 | assert_equal [4, 3], query.map(&.get(Employees.id)) 45 | end 46 | 47 | def test_offset_without_limit 48 | query = 
database.from(Employees).select.order_by(Employees.id.desc).offset(2) 49 | 50 | assert_equal [2, 1], query.map(&.get(Employees.id)) 51 | end 52 | 53 | def test_offset_with_limit 54 | query = database.from(Employees).select.order_by(Employees.id.desc).offset(2).limit(1) 55 | 56 | assert_equal [2], query.map(&.get(Employees.id)) 57 | end 58 | end 59 | -------------------------------------------------------------------------------- /test/mysql/mysql_test_base.cr: -------------------------------------------------------------------------------- 1 | require "../test_base" 2 | require "../../src/mysql" 3 | 4 | abstract class MySqlTestBase < TestBase 5 | @database : Focus::MySqlDatabase? 6 | 7 | def database : Focus::MySqlDatabase 8 | @database ||= Focus::MySqlDatabase.connect("mysql://root:password@localhost/test") 9 | end 10 | 11 | def setup 12 | exec_sql_script("./test/support/init-mysql-data.sql") 13 | end 14 | 15 | def teardown 16 | exec_sql_script("./test/support/drop-mysql-data.sql") 17 | end 18 | end 19 | -------------------------------------------------------------------------------- /test/pg/pg_database_test.cr: -------------------------------------------------------------------------------- 1 | require "./pg_test_base" 2 | 3 | class PGDatabaseTest < PGTestBase 4 | def test_it_works 5 | id = database.insert_returning_generated_key(Departments, Departments.id) do 6 | set(Departments.name, "r&d") 7 | set(Departments.location, "Boston") 8 | end 9 | 10 | department = database.from(Departments).select.where(Departments.id.eq(id)).first 11 | assert_equal "r&d", department.get(Departments.name) 12 | 13 | count = database.from(Departments) 14 | .select(Focus.count(Departments.id)) 15 | .where(Departments.name.eq("r&d")) 16 | .first 17 | .get(0, Int32) 18 | 19 | assert_equal 1, count 20 | 21 | database.delete(Departments, where: Departments.name.eq("r&d")) 22 | 23 | count = database.from(Departments) 24 | .select(Focus.count(Departments.id)) 25 | 
.where(Departments.name.eq("r&d")) 26 | .first 27 | .get(0, Int32) 28 | 29 | assert_equal 0, count 30 | end 31 | 32 | def test_limit 33 | query = database.from(Employees).select.order_by(Employees.id.desc).limit(offset: 0, limit: 2) 34 | 35 | assert_equal [4, 3], query.map(&.get(Employees.id)) 36 | end 37 | 38 | def test_both_limit_and_offset_are_not_positive 39 | query = database.from(Employees).select.order_by(Employees.id.desc).limit(offset: 0, limit: -1) 40 | 41 | assert_equal [4, 3, 2, 1], query.map(&.get(Employees.id)) 42 | end 43 | 44 | def test_limit_without_offset 45 | query = database.from(Employees).select.order_by(Employees.id.desc).limit(2) 46 | 47 | assert_equal [4, 3], query.map(&.get(Employees.id)) 48 | end 49 | 50 | def test_offset_without_limit 51 | query = database.from(Employees).select.order_by(Employees.id.desc).offset(2) 52 | 53 | assert_equal [2, 1], query.map(&.get(Employees.id)) 54 | end 55 | 56 | def test_offset_with_limit 57 | query = database.from(Employees).select.order_by(Employees.id.desc).offset(2).limit(1) 58 | 59 | assert_equal [2], query.map(&.get(Employees.id)) 60 | end 61 | end 62 | -------------------------------------------------------------------------------- /test/pg/pg_test_base.cr: -------------------------------------------------------------------------------- 1 | require "../test_base" 2 | require "../../src/pg" 3 | 4 | abstract class PGTestBase < TestBase 5 | @database : Focus::PGDatabase? 
6 | 7 | def database : Focus::PGDatabase 8 | @database ||= Focus::PGDatabase.connect("postgres://postgres:postgres@localhost/test") 9 | end 10 | 11 | def setup 12 | exec_sql_script("./test/support/init-pg-data.sql") 13 | end 14 | 15 | def teardown 16 | exec_sql_script("./test/support/drop-pg-data.sql") 17 | end 18 | end 19 | -------------------------------------------------------------------------------- /test/sqlite/sqlite_database_test.cr: -------------------------------------------------------------------------------- 1 | require "./sqlite_test_base" 2 | 3 | class SQLiteDatabaseTest < SQLiteTestBase 4 | def test_it_works 5 | id = database.insert_returning_generated_key(Departments, Departments.id) do 6 | set(Departments.name, "r&d") 7 | set(Departments.location, "Boston") 8 | end 9 | 10 | department = database.from(Departments).select.where(Departments.id.eq(id)).first 11 | assert_equal "r&d", department.get(Departments.name) 12 | 13 | count = database.from(Departments) 14 | .select(Focus.count(Departments.id)) 15 | .where(Departments.name.eq("r&d")) 16 | .first 17 | .get(0, Int32) 18 | 19 | assert_equal 1, count 20 | 21 | database.delete(Departments, where: Departments.name.eq("r&d")) 22 | 23 | count = database.from(Departments) 24 | .select(Focus.count(Departments.id)) 25 | .where(Departments.name.eq("r&d")) 26 | .first 27 | .get(0, Int32) 28 | 29 | assert_equal 0, count 30 | end 31 | 32 | def test_limit 33 | query = database.from(Employees).select.order_by(Employees.id.desc).limit(offset: 0, limit: 2) 34 | 35 | assert_equal [4, 3], query.map(&.get(Employees.id)) 36 | end 37 | 38 | def test_both_limit_and_offset_are_not_positive 39 | query = database.from(Employees).select.order_by(Employees.id.desc).limit(offset: 0, limit: -1) 40 | 41 | assert_equal [4, 3, 2, 1], query.map(&.get(Employees.id)) 42 | end 43 | 44 | def test_limit_without_offset 45 | query = database.from(Employees).select.order_by(Employees.id.desc).limit(2) 46 | 47 | assert_equal [4, 3], 
query.map(&.get(Employees.id)) 48 | end 49 | 50 | def test_offset_without_limit 51 | query = database.from(Employees).select.order_by(Employees.id.desc).offset(2) 52 | 53 | assert_equal [2, 1], query.map(&.get(Employees.id)) 54 | end 55 | 56 | def test_offset_with_limit 57 | query = database.from(Employees).select.order_by(Employees.id.desc).offset(2).limit(1) 58 | 59 | assert_equal [2], query.map(&.get(Employees.id)) 60 | end 61 | end 62 | -------------------------------------------------------------------------------- /test/sqlite/sqlite_test_base.cr: -------------------------------------------------------------------------------- 1 | require "../test_base" 2 | require "../../src/sqlite" 3 | 4 | abstract class SQLiteTestBase < TestBase 5 | @database : Focus::SQLiteDatabase? 6 | 7 | def database : Focus::SQLiteDatabase 8 | @database ||= Focus::SQLiteDatabase.connect("sqlite3://%3Amemory%3A") 9 | end 10 | 11 | def setup 12 | exec_sql_script("./test/support/init-sqlite-data.sql") 13 | end 14 | 15 | def teardown 16 | exec_sql_script("./test/support/drop-sqlite-data.sql") 17 | end 18 | end 19 | -------------------------------------------------------------------------------- /test/support/drop-mysql-data.sql: -------------------------------------------------------------------------------- 1 | drop table if exists departments; 2 | drop table if exists employees; 3 | -------------------------------------------------------------------------------- /test/support/drop-pg-data.sql: -------------------------------------------------------------------------------- 1 | drop table if exists departments; 2 | drop table if exists employees; 3 | -------------------------------------------------------------------------------- /test/support/drop-sqlite-data.sql: -------------------------------------------------------------------------------- 1 | drop table if exists "departments"; 2 | drop table if exists "employees"; 3 | 
--------------------------------------------------------------------------------
/test/support/init-mysql-data.sql:
--------------------------------------------------------------------------------
1 | create table departments( -- MySQL fixture: schema and seed rows for the departments/employees test tables
2 |   id int not null primary key auto_increment,
3 |   name varchar(128) not null,
4 |   location varchar(128) not null,
5 |   mixedCase varchar(128)
6 | );
7 | -- manager_id is nullable; neither manager_id nor department_id is declared as a foreign key.
8 | create table employees(
9 |   id int not null primary key auto_increment,
10 |   name varchar(128) not null,
11 |   job varchar(128) not null,
12 |   manager_id int null,
13 |   hire_date date not null,
14 |   salary bigint not null,
15 |   department_id int not null
16 | );
17 | -- Full-text index over (name, job); only the MySQL fixture creates one.
18 | create fulltext index employee_name_job on employees(name, job);
19 | -- Seed data: two departments, then two employees per department.
20 | insert into departments(name, location) values ('tech', 'Guangzhou');
21 | insert into departments(name, location) values ('finance', 'Beijing');
22 | -- The literal department_id / manager_id values below assume the generated ids start at 1 and follow insertion order.
23 | insert into employees(name, job, manager_id, hire_date, salary, department_id)
24 | values ('vince', 'engineer', null, '2018-01-01', 100, 1);
25 | insert into employees(name, job, manager_id, hire_date, salary, department_id)
26 | values ('marry', 'trainee', 1, '2019-01-01', 50, 1);
27 |
28 | insert into employees(name, job, manager_id, hire_date, salary, department_id)
29 | values ('tom', 'director', null, '2018-01-01', 200, 2);
30 | insert into employees(name, job, manager_id, hire_date, salary, department_id)
31 | values ('penny', 'assistant', 3, '2019-01-01', 100, 2);
32 |
33 |
34 |
--------------------------------------------------------------------------------
/test/support/init-pg-data.sql:
--------------------------------------------------------------------------------
1 | create table departments( -- PostgreSQL fixture: same tables and seed rows as the MySQL script, using serial ids
2 |   id serial primary key,
3 |   name varchar(128) not null,
4 |   location varchar(128) not null,
5 |   "mixedCase" varchar(128) -- quoted so PostgreSQL keeps the mixed case (unquoted identifiers are folded to lower case)
6 | );
7 | -- manager_id is nullable; neither manager_id nor department_id is declared as a foreign key.
8 | create table employees(
9 |   id serial primary key,
10 |   name varchar(128) not null,
11 |   job varchar(128) not null,
12 |   manager_id int null,
13 |   hire_date date not null,
14 |   salary bigint not null,
15 |   department_id int not null
16 | );
17 | -- Seed data: two departments, then two employees per department.
18 | insert into departments(name, location) values ('tech', 'Guangzhou');
19 | insert into departments(name, location) values ('finance', 'Beijing');
20 | -- The literal department_id / manager_id values below assume the generated ids start at 1 and follow insertion order.
21 | insert into employees(name, job, manager_id, hire_date, salary, department_id)
22 | values ('vince', 'engineer', null, '2018-01-01', 100, 1);
23 | insert into employees(name, job, manager_id, hire_date, salary, department_id)
24 | values ('marry', 'trainee', 1, '2019-01-01', 50, 1);
25 |
26 | insert into employees(name, job, manager_id, hire_date, salary, department_id)
27 | values ('tom', 'director', null, '2018-01-01', 200, 2);
28 | insert into employees(name, job, manager_id, hire_date, salary, department_id)
29 | values ('penny', 'assistant', 3, '2019-01-01', 100, 2);
30 |
31 |
32 |
--------------------------------------------------------------------------------
/test/support/init-sqlite-data.sql:
--------------------------------------------------------------------------------
1 | create table departments( -- SQLite fixture: same tables and seed rows, using integer/text column types
2 |   id integer primary key autoincrement,
3 |   name text not null,
4 |   location text not null,
5 |   mixedCase text
6 | );
7 | -- hire_date is an integer column here (the MySQL/PostgreSQL fixtures use date).
8 | create table employees(
9 |   id integer primary key autoincrement,
10 |   name text not null,
11 |   job text not null,
12 |   manager_id integer null,
13 |   hire_date integer not null,
14 |   salary integer not null,
15 |   department_id integer not null
16 | );
17 | -- Seed data: two departments, then two employees per department.
18 | insert into departments(name, location) values ('tech', 'Guangzhou');
19 | insert into departments(name, location) values ('finance', 'Beijing');
20 | -- NOTE(review): the hire_date values look like epoch milliseconds (1514736000000 = 2018-01-01 00:00 at UTC+8) - confirm the unit and time zone.
21 | insert into employees(name, job, manager_id, hire_date, salary, department_id)
22 | values ('vince', 'engineer', null, 1514736000000, 100, 1);
23 | insert into employees(name, job, manager_id, hire_date, salary, department_id)
24 | values ('marry', 'trainee', 1, 1546272000000, 50, 1);
25 |
26 | insert into employees(name, job, manager_id, hire_date, salary,
department_id)
27 | values ('tom', 'director', null, 1514736000000, 200, 2);
28 | insert into employees(name, job, manager_id, hire_date, salary, department_id)
29 | values ('penny', 'assistant', 3, 1546272000000, 100, 2);
30 |
--------------------------------------------------------------------------------
/test/support/tables.cr:
--------------------------------------------------------------------------------
1 | class DepartmentsTable < Focus::Table # Focus table definition for the "departments" fixture table
2 |   @table_name = "departments"
3 |   # The fixture scripts also create a mixedCase column; it is not mapped here.
4 |   column id : Int32
5 |   column name : String
6 |   column location : String
7 | end
8 | # Shared instance the tests reference when building queries.
9 | Departments = DepartmentsTable.new
10 | # Focus table definition for the "employees" fixture table.
11 | class EmployeesTable < Focus::Table
12 |   @table_name = "employees"
13 |   # NOTE(review): the fixtures declare manager_id as nullable and (MySQL/PostgreSQL) salary as bigint - confirm the non-nil Int32 mappings are intended.
14 |   column id : Int32
15 |   column name : String
16 |   column job : String
17 |   column manager_id : Int32
18 |   column hire_date : Time
19 |   column salary : Int32
20 |   column department_id : Int32
21 | end
22 | # Shared instance the tests reference when building queries.
23 | Employees = EmployeesTable.new
24 |
--------------------------------------------------------------------------------
/test/test_base.cr:
--------------------------------------------------------------------------------
1 | require "minitest"
2 | require "minitest/focus"
3 | require "../src/focus"
4 | require "./support/tables"
5 | # Shared Minitest base class: subclasses supply `database` and call `exec_sql_script` from their setup/teardown hooks.
6 | abstract class TestBase < Minitest::Test
7 |   abstract def database : Focus::Database
8 |   # Runs a SQL script: reads `filename`, splits the text on ';' and executes every non-blank fragment as its own statement on a single connection.
9 |   def exec_sql_script(filename : String)
10 |     database.with_connection do |conn|
11 |       raw_sql = File.read(filename)
12 |       raw_sql.split(';') # naive split: a ';' inside a string literal would also end the statement
13 |         .compact_map(&.presence) # drop blank fragments, e.g. the whitespace after the final ';'
14 |         .each { |sql_stmt| conn.exec sql_stmt }
15 |     end
16 |   end
17 | end
18 |
19 | require "minitest/autorun"
20 |
--------------------------------------------------------------------------------