├── .circleci └── config.yml ├── .github ├── ISSUE_TEMPLATE.md └── workflows │ └── test.yaml ├── .gitignore ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── boil.sh ├── boil ├── columns.go ├── columns_test.go ├── context_keys.go ├── db.go ├── db_test.go ├── debug.go ├── errors.go ├── errors_test.go ├── global.go ├── hooks.go └── hooks_test.go ├── boilingcore ├── aliases.go ├── aliases_test.go ├── boilingcore.go ├── boilingcore_test.go ├── config.go ├── config_test.go ├── output.go ├── output_test.go ├── templates.go ├── templates_test.go ├── text_helpers.go └── text_helpers_test.go ├── drivers ├── binary_driver.go ├── binary_driver_test.go ├── column.go ├── column_test.go ├── config.go ├── config_test.go ├── driver_main.go ├── interface.go ├── interface_test.go ├── keys.go ├── keys_test.go ├── mocks │ └── mock.go ├── registration.go ├── registration_test.go ├── relationships.go ├── relationships_test.go ├── sqlboiler-mssql │ ├── driver │ │ ├── mssql.go │ │ ├── mssql.golden.json │ │ ├── mssql_test.go │ │ ├── override │ │ │ ├── main │ │ │ │ ├── 17_upsert.go.tpl │ │ │ │ └── singleton │ │ │ │ │ └── mssql_upsert.go.tpl │ │ │ └── test │ │ │ │ ├── singleton │ │ │ │ ├── mssql_main_test.go.tpl │ │ │ │ └── mssql_suites_test.go.tpl │ │ │ │ └── upsert.go.tpl │ │ └── testdatabase.sql │ └── main.go ├── sqlboiler-mysql │ ├── driver │ │ ├── mysql.go │ │ ├── mysql.golden.enums.json │ │ ├── mysql.golden.json │ │ ├── mysql_test.go │ │ ├── override │ │ │ ├── main │ │ │ │ ├── 17_upsert.go.tpl │ │ │ │ └── singleton │ │ │ │ │ └── mysql_upsert.go.tpl │ │ │ └── test │ │ │ │ ├── singleton │ │ │ │ ├── mysql_main_test.go.tpl │ │ │ │ └── mysql_suites_test.go.tpl │ │ │ │ └── upsert.go.tpl │ │ └── testdatabase.sql │ └── main.go ├── sqlboiler-psql │ ├── driver │ │ ├── override │ │ │ ├── main │ │ │ │ ├── 17_upsert.go.tpl │ │ │ │ ├── 22_ilike.go.tpl │ │ │ │ ├── 23_similarto.go.tpl │ │ │ │ └── singleton │ │ │ │ │ └── psql_upsert.go.tpl │ │ │ └── test │ │ │ │ ├── singleton │ │ │ │ 
├── psql_main_test.go.tpl │ │ │ │ └── psql_suites_test.go.tpl │ │ │ │ └── upsert.go.tpl │ │ ├── psql.go │ │ ├── psql.golden.enums.json │ │ ├── psql.golden.json │ │ ├── psql_test.go │ │ └── testdatabase.sql │ └── main.go ├── sqlboiler-sqlite3 │ ├── README.md │ ├── driver │ │ ├── override │ │ │ ├── main │ │ │ │ ├── 17_upsert.go.tpl │ │ │ │ └── singleton │ │ │ │ │ └── sqlite_upsert.go.tpl │ │ │ └── test │ │ │ │ ├── singleton │ │ │ │ ├── sqlite3_main_test.go.tpl │ │ │ │ └── sqlite3_suites_test.go.tpl │ │ │ │ └── upsert.go.tpl │ │ ├── sqlite3.go │ │ ├── sqlite3.golden.json │ │ ├── sqlite3_test.go │ │ └── testdatabase.sql │ └── main.go ├── table.go └── table_test.go ├── go.mod ├── go.sum ├── importers ├── imports.go └── imports_test.go ├── main.go ├── queries ├── _fixtures │ ├── 00.sql │ ├── 01.sql │ ├── 02.sql │ ├── 03.sql │ ├── 04.sql │ ├── 05.sql │ ├── 06.sql │ ├── 07.sql │ ├── 08.sql │ ├── 09.sql │ ├── 10.sql │ ├── 11.sql │ ├── 12.sql │ ├── 13.sql │ ├── 14.sql │ ├── 15.sql │ ├── 16.sql │ ├── 17.sql │ ├── 18.sql │ ├── 19.sql │ ├── 20.sql │ ├── 21.sql │ ├── 22.sql │ ├── 23.sql │ ├── 24.sql │ ├── 25.sql │ ├── 26.sql │ ├── 27.sql │ ├── 28.sql │ ├── 29.sql │ ├── 30.sql │ ├── 31.sql │ ├── 32.sql │ └── 33.sql ├── eager_load.go ├── eager_load_test.go ├── helpers.go ├── helpers_test.go ├── qm │ └── query_mods.go ├── qmhelper │ └── qmhelper.go ├── query.go ├── query_builders.go ├── query_builders_test.go ├── query_test.go ├── reflect.go └── reflect_test.go ├── templates ├── embed.go ├── main │ ├── 00_struct.go.tpl │ ├── 01_types.go.tpl │ ├── 02_hooks.go.tpl │ ├── 03_finishers.go.tpl │ ├── 04_relationship_to_one.go.tpl │ ├── 05_relationship_one_to_one.go.tpl │ ├── 06_relationship_to_many.go.tpl │ ├── 07_relationship_to_one_eager.go.tpl │ ├── 08_relationship_one_to_one_eager.go.tpl │ ├── 09_relationship_to_many_eager.go.tpl │ ├── 10_relationship_to_one_setops.go.tpl │ ├── 11_relationship_one_to_one_setops.go.tpl │ ├── 12_relationship_to_many_setops.go.tpl │ ├── 13_all.go.tpl │ 
├── 14_find.go.tpl │ ├── 15_insert.go.tpl │ ├── 16_update.go.tpl │ ├── 18_delete.go.tpl │ ├── 19_reload.go.tpl │ ├── 20_exists.go.tpl │ ├── 21_auto_timestamps.go.tpl │ └── singleton │ │ ├── boil_queries.go.tpl │ │ ├── boil_table_names.go.tpl │ │ ├── boil_types.go.tpl │ │ └── boil_view_names.go.tpl └── test │ ├── 00_types.go.tpl │ ├── all.go.tpl │ ├── delete.go.tpl │ ├── exists.go.tpl │ ├── find.go.tpl │ ├── finishers.go.tpl │ ├── hooks.go.tpl │ ├── insert.go.tpl │ ├── relationship_one_to_one.go.tpl │ ├── relationship_one_to_one_setops.go.tpl │ ├── relationship_to_many.go.tpl │ ├── relationship_to_many_setops.go.tpl │ ├── relationship_to_one.go.tpl │ ├── relationship_to_one_setops.go.tpl │ ├── reload.go.tpl │ ├── select.go.tpl │ ├── singleton │ ├── boil_main_test.go.tpl │ ├── boil_queries_test.go.tpl │ ├── boil_relationship_test.go.tpl │ └── boil_suites_test.go.tpl │ ├── types.go.tpl │ └── update.go.tpl ├── testdata ├── Dockerfile ├── mssql_test_schema.sql ├── mysql_test_schema.sql └── psql_test_schema.sql └── types ├── array.go ├── array_test.go ├── byte.go ├── byte_test.go ├── decimal.go ├── decimal_test.go ├── hstore.go ├── json.go ├── json_test.go └── pgeo ├── box.go ├── circle.go ├── general.go ├── general_test.go ├── line.go ├── lseg.go ├── main.go ├── nullBox.go ├── nullCircle.go ├── nullLine.go ├── nullLseg.go ├── nullPath.go ├── nullPoint.go ├── nullPolygon.go ├── path.go ├── point.go └── polygon.go /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | jobs: 3 | build: 4 | working_directory: /root 5 | docker: 6 | - image: aarondl0/sqlboiler-test:v3 7 | 8 | - image: postgres:9.6 9 | environment: 10 | POSTGRES_PASSWORD: psqlpassword 11 | 12 | - image: mysql:5.7 13 | environment: 14 | MYSQL_ROOT_PASSWORD: mysqlpassword 15 | 16 | - image: microsoft/mssql-server-linux:2017-GDR 17 | environment: 18 | ACCEPT_EULA: 'Y' 19 | SA_PASSWORD: 'Sqlboiler@1234' 20 | 21 | environment: 22 | GOPATH: /go 23 | 
ROOTPATH: /go/src/github.com/volatiletech/sqlboiler 24 | 25 | steps: 26 | - run: 27 | name: 'Make GOPATH' 28 | command: mkdir -p $ROOTPATH 29 | 30 | - checkout: 31 | name: 'Checkout' 32 | path: /go/src/github.com/volatiletech/sqlboiler 33 | 34 | # Workaround to allow the use of the circleci local cli. 35 | - run: 36 | name: 'Checkout (local)' 37 | command: | 38 | if [ ! -z "$ROOTPATH" ]; then rmdir $ROOTPATH; ln -s /root $ROOTPATH; fi 39 | 40 | - run: 41 | name: 'Add PSQL Credentials' 42 | command: | 43 | echo "*:*:*:*:psqlpassword" > /root/.pgpass 44 | chmod 600 /root/.pgpass 45 | 46 | - run: 47 | name: 'Add MySQL Credentials' 48 | command: | 49 | echo -e "[client]\nuser = root\npassword = mysqlpassword\nhost = localhost\nprotocol = tcp" > /root/.my.cnf 50 | chmod 600 /root/.my.cnf 51 | 52 | - run: 53 | name: 'Wait for PSQL' 54 | command: > 55 | c=0; 56 | for i in `seq 30`; do 57 | echo "Waiting for psql" 58 | psql --host localhost --username postgres --dbname template1 -c 'select * from information_schema.tables;' > /dev/null && c=0 && break || c=$? && sleep 1 59 | done; 60 | exit $c 61 | 62 | - run: 63 | name: 'Wait for MySQL' 64 | command: > 65 | c=0; 66 | for i in `seq 30`; do 67 | echo "Waiting for mysql" 68 | mysql --execute 'select * from information_schema.tables;' > /dev/null && c=0 && break || c=$? && sleep 1 69 | done; 70 | exit $c 71 | # sqlcmd: -S selects the server to connect to; -H would only set the workstation name. 72 | - run: 73 | name: 'Wait for MSSQL' 74 | command: > 75 | c=0; 76 | for i in `seq 30`; do 77 | echo "Waiting for mssql" 78 | sqlcmd -S localhost -U sa -P Sqlboiler@1234 -Q "select * from information_schema.tables;" > /dev/null && c=0 && break || c=$? && sleep 1 79 | done; 80 | exit $c 81 | 82 | - run: 83 | name: 'Download dependencies (core, driver, test, generated)' 84 | command: | 85 | cd $ROOTPATH; go get -v -t ./...
86 | 87 | - run: 88 | name: 'Build SQLBoiler core and drivers' 89 | command: | 90 | cd $ROOTPATH; make build 91 | cd $ROOTPATH; make build-{psql,mysql,mssql} 92 | 93 | - run: 94 | name: 'Prepare for tests' 95 | command: | 96 | mkdir -p $HOME/test_results 97 | 98 | - run: 99 | name: 'Tests: All (except drivers,vendor)' 100 | command: | 101 | cd $ROOTPATH 102 | make test | tee $HOME/test_results/results.txt 103 | for engine in psql mysql mssql; do 104 | make test-user-${engine} 105 | make test-db-${engine} 106 | make test-generate-${engine} 107 | # workaround to fix failing tests due to the absence of 'table_schema.sql' 108 | if [ "${engine}" != "mssql" ]; then 109 | make test-${engine} | tee $HOME/test_results/results.${engine}.txt 110 | fi 111 | done 112 | 113 | - run: 114 | name: 'Tests: Drivers' 115 | command: | 116 | cd $ROOTPATH 117 | for engine in psql mysql mssql; do 118 | make driver-db-${engine} 119 | make driver-user-${engine} 120 | make driver-test-${engine} | tee $HOME/test_results/results.driver-${engine}.txt 121 | done 122 | 123 | - run: 124 | name: 'Tests: Convert from plain to JUnit' 125 | command: | 126 | for file in $HOME/test_results/*.txt; do 127 | cat ${file} | go-junit-report > "${file%.txt}.xml" 128 | done 129 | 130 | - store_test_results: 131 | name: 'Store test results' 132 | path: test_results 133 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 5 | 6 | If you're having a generation problem please answer these questions before submitting your issue. Thanks! 7 | 8 | ### What version of SQLBoiler are you using (`sqlboiler --version`)? 9 | 10 | 11 | ### What is your database and version (eg. Postgresql 10) 12 | 13 | 14 | ### If this happened at generation time what was the full SQLBoiler command you used to generate your models? 
(if not applicable, leave blank) 15 | 16 | 17 | ### If this happened at runtime, what code produced the issue? (if not applicable, leave blank) 18 | 19 | 20 | ### What is the output of the command above with the `-d` flag added to it? (Provided you are comfortable sharing it, since it contains a blueprint of your schema) 21 | 22 | 23 | ### Please provide a relevant database schema so we can replicate your issue (Provided you are comfortable sharing this) 24 | 25 | 26 | ### Further information: what did you do, and what did you expect? 27 | 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /sqlboiler 2 | /sqlboiler-psql 3 | /sqlboiler-mysql 4 | /sqlboiler-mssql 5 | /cmd/sqlboiler/sqlboiler 6 | sqlboiler.toml 7 | models/ 8 | testschema.sql 9 | .cover 10 | *.sqlite3 11 | issue*.sql 12 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Thanks for your interest in contributing to SQLBoiler! 4 | 5 | We have a very lightweight process and aim to keep it that way. 6 | Read the sections for the piece you're interested in and go from 7 | there. 8 | 9 | If you need quick communication, we're usually on [Slack](https://sqlboiler.from-the.cloud). 10 | 11 | # New Code / Features 12 | 13 | ## Small Change 14 | 15 | #### TLDR 16 | 17 | 1. Open a PR against the **master** branch with an explanation 18 | 1. Participate in GitHub Code Review 19 | 20 | #### Long version 21 | 22 | For code that requires little to no discussion, please just open a pull request with some 23 | explanation against the **master** branch. 24 | 25 | ## Bigger Change 26 | 27 | #### TLDR 28 | 29 | 1. Propose the idea in a GitHub issue 30 | 1. After design consensus, open a PR with the work against the **master** branch 31 | 1.
Participate in GitHub Code Review 32 | 33 | #### Long version 34 | 35 | If, however, you're working on something bigger, it's usually better to check with us on the idea 36 | before starting on a pull request, just so there's no time wasted in redoing/refactoring or being 37 | outright rejected because the PR is at odds with the design. The best way to accomplish this is to 38 | open an issue to discuss it. It can always start as a Slack conversation but should eventually end 39 | up as an issue to avoid penalizing the rest of the users for not being on Slack. Once we agree on 40 | the way to do something, open the PR against the **master** branch and we'll commence code review 41 | with the GitHub code review tools. Then it will be merged into master, and later go out in a release. 42 | 43 | ## Developer getting started 44 | 45 | 1. Add a [configuration file](https://github.com/volatiletech/sqlboiler#configuration). 46 | 1. Write your changes. 47 | 1. Build the executables. Re-run this whenever you change core or driver code. 48 | ``` 49 | ./boil.sh build all 50 | ``` 51 | 52 | 1. If you changed driver code, also move the built sqlboiler-[driver] binary into the bin directory of your GOPATH. 53 | 54 | 1. Generate your models from existing tables. 55 | 56 | ``` 57 | ./boil.sh gen [driver] 58 | ``` 59 | 60 | 1. You may need to install the following package before you can run the tests. 61 | 62 | ``` 63 | go get -u github.com/volatiletech/null 64 | ``` 65 | 66 | 1. Test the output. 67 | 68 | ``` 69 | ./boil.sh test 70 | ``` 71 | 72 | 73 | # Bugs 74 | 75 | Issues should be filed on GitHub; simply use the provided template and fill in the details. If there's 76 | more information you feel you should give, use your best judgement and add it in; the more the better. 77 | See the section below for information on providing database schemas.
78 | 79 | Bugs that have responses from contributors but no action from those who opened them after a time 80 | will be closed with the comment: "Stale" 81 | 82 | ## Schemas 83 | 84 | A database schema can help us fix generation issues very quickly. However not everyone is willing to part 85 | with their database schema for various reasons and that's fine. Instead of providing the schema please 86 | then provide a subset of your database (you can munge the names so as to be unrecognizable) that can 87 | help us reproduce the problem. 88 | 89 | _Note:_ Your schema information is included in the output from `--debug`, so be careful giving this 90 | information out publicly on a Github issue if you're sensitive about this. 91 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Volatile Technologies Inc. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are 5 | met: 6 | 7 | * Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above 10 | copyright notice, this list of conditions and the following disclaimer 11 | in the documentation and/or other materials provided with the 12 | distribution. 13 | * Neither the name of Vattle or Volatile Technologies Inc. nor the 14 | names of its contributors may be used to endorse or promote products 15 | derived from this software without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 18 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 19 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 20 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 21 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 22 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 23 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 24 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 25 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /boil/columns_test.go: -------------------------------------------------------------------------------- 1 | package boil 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestColumns(t *testing.T) { 9 | t.Parallel() 10 | 11 | list := Whitelist("a", "b") 12 | if list.Kind != columnsWhitelist || !list.IsWhitelist() { 13 | t.Error(list.Kind) 14 | } 15 | if list.Cols[0] != "a" || list.Cols[1] != "b" { 16 | t.Error("columns were wrong") 17 | } 18 | list = Blacklist("a", "b") 19 | if list.Kind != columnsBlacklist || !list.IsBlacklist() { 20 | t.Error(list.Kind) 21 | } 22 | if list.Cols[0] != "a" || list.Cols[1] != "b" { 23 | t.Error("columns were wrong") 24 | } 25 | list = Greylist("a", "b") 26 | if list.Kind != columnsGreylist || !list.IsGreylist() { 27 | t.Error(list.Kind) 28 | } 29 | if list.Cols[0] != "a" || list.Cols[1] != "b" { 30 | t.Error("columns were wrong") 31 | } 32 | 33 | list = Infer() 34 | if list.Kind != columnsInfer || !list.IsInfer() { 35 | t.Error(list.Kind) 36 | } 37 | if len(list.Cols) != 0 { 38 | t.Error("non zero length columns") 39 | } 40 | } 41 | 42 | func TestInsertColumnSet(t *testing.T) { 43 | t.Parallel() 44 | 45 | columns := []string{"a", "b", "c"} 46 | defaults := []string{"a", "c"} 47 | nodefaults := []string{"b"} 48 | 49 | tests := []struct { 50 | Columns Columns 51 | Cols []string 52 | Defaults 
[]string 53 | NoDefaults []string 54 | NonZeroDefaults []string 55 | Set []string 56 | Ret []string 57 | }{ 58 | // Infer 59 | {Columns: Infer(), Set: []string{"b"}, Ret: []string{"a", "c"}}, 60 | {Columns: Infer(), Defaults: []string{}, NoDefaults: []string{"a", "b", "c"}, Set: []string{"a", "b", "c"}, Ret: []string{}}, 61 | 62 | // Infer with non-zero defaults 63 | {Columns: Infer(), NonZeroDefaults: []string{"a"}, Set: []string{"a", "b"}, Ret: []string{"c"}}, 64 | {Columns: Infer(), NonZeroDefaults: []string{"c"}, Set: []string{"b", "c"}, Ret: []string{"a"}}, 65 | 66 | // Whitelist 67 | {Columns: Whitelist("a"), Set: []string{"a"}, Ret: []string{"c"}}, 68 | {Columns: Whitelist("c"), Set: []string{"c"}, Ret: []string{"a"}}, 69 | {Columns: Whitelist("a", "c"), Set: []string{"a", "c"}, Ret: []string{}}, 70 | {Columns: Whitelist("a", "b", "c"), Set: []string{"a", "b", "c"}, Ret: []string{}}, 71 | 72 | // Whitelist + Nonzero defaults (shouldn't care, same results as above) 73 | {Columns: Whitelist("a"), NonZeroDefaults: []string{"c"}, Set: []string{"a"}, Ret: []string{"c"}}, 74 | {Columns: Whitelist("c"), NonZeroDefaults: []string{"b"}, Set: []string{"c"}, Ret: []string{"a"}}, 75 | 76 | // Blacklist 77 | {Columns: Blacklist("b"), NonZeroDefaults: []string{"c"}, Set: []string{"c"}, Ret: []string{"a"}}, 78 | {Columns: Blacklist("c"), NonZeroDefaults: []string{"c"}, Set: []string{"b"}, Ret: []string{"a", "c"}}, 79 | 80 | // Greylist 81 | {Columns: Greylist("c"), NonZeroDefaults: []string{}, Set: []string{"b", "c"}, Ret: []string{"a"}}, 82 | {Columns: Greylist("a"), NonZeroDefaults: []string{}, Set: []string{"a", "b"}, Ret: []string{"c"}}, 83 | } 84 | 85 | for i, test := range tests { 86 | if test.Cols == nil { 87 | test.Cols = columns 88 | } 89 | if test.Defaults == nil { 90 | test.Defaults = defaults 91 | } 92 | if test.NoDefaults == nil { 93 | test.NoDefaults = nodefaults 94 | } 95 | 96 | set, ret := test.Columns.InsertColumnSet(test.Cols, test.Defaults, 
test.NoDefaults, test.NonZeroDefaults) 97 | 98 | if !reflect.DeepEqual(set, test.Set) { 99 | t.Errorf("%d) set was wrong\nwant: %v\ngot: %v", i, test.Set, set) 100 | } 101 | if !reflect.DeepEqual(ret, test.Ret) { 102 | t.Errorf("%d) ret was wrong\nwant: %v\ngot: %v", i, test.Ret, ret) 103 | } 104 | } 105 | } 106 | 107 | func TestUpdateColumnSet(t *testing.T) { 108 | t.Parallel() 109 | 110 | tests := []struct { 111 | Columns Columns 112 | Cols []string 113 | PKeys []string 114 | Out []string 115 | }{ 116 | // Infer 117 | {Columns: Infer(), Cols: []string{"a", "b"}, PKeys: []string{"a"}, Out: []string{"b"}}, 118 | 119 | // Whitelist 120 | {Columns: Whitelist("a"), Cols: []string{"a", "b"}, PKeys: []string{"a"}, Out: []string{"a"}}, 121 | {Columns: Whitelist("a", "b"), Cols: []string{"a", "b"}, PKeys: []string{"a"}, Out: []string{"a", "b"}}, 122 | 123 | // Blacklist 124 | {Columns: Blacklist("b"), Cols: []string{"a", "b"}, PKeys: []string{"a"}, Out: []string{}}, 125 | 126 | // Greylist 127 | {Columns: Greylist("a"), Cols: []string{"a", "b"}, PKeys: []string{"a"}, Out: []string{"a", "b"}}, 128 | } 129 | 130 | for i, test := range tests { 131 | set := test.Columns.UpdateColumnSet(test.Cols, test.PKeys) 132 | 133 | if !reflect.DeepEqual(set, test.Out) { 134 | t.Errorf("%d) set was wrong\nwant: %v\ngot: %v", i, test.Out, set) 135 | } 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /boil/context_keys.go: -------------------------------------------------------------------------------- 1 | package boil 2 | 3 | type contextType int 4 | 5 | const ( 6 | ctxSkipHooks contextType = iota 7 | ctxSkipTimestamps 8 | ctxDebug 9 | ctxDebugWriter 10 | ) 11 | -------------------------------------------------------------------------------- /boil/db.go: -------------------------------------------------------------------------------- 1 | package boil 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | ) 7 | 8 | // Executor can perform 
SQL queries. 9 | type Executor interface { 10 | Exec(query string, args ...interface{}) (sql.Result, error) 11 | Query(query string, args ...interface{}) (*sql.Rows, error) 12 | QueryRow(query string, args ...interface{}) *sql.Row 13 | } 14 | 15 | // ContextExecutor can perform SQL queries with context 16 | type ContextExecutor interface { 17 | Executor 18 | 19 | ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) 20 | QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) 21 | QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row 22 | } 23 | 24 | // Transactor can commit and rollback, on top of being able to execute queries. 25 | type Transactor interface { 26 | Commit() error 27 | Rollback() error 28 | 29 | Executor 30 | } 31 | 32 | // Beginner begins transactions. 33 | type Beginner interface { 34 | Begin() (*sql.Tx, error) 35 | } 36 | 37 | // Begin a transaction with the current global database handle. 38 | // It panics if the global handle does not implement Beginner. 39 | func Begin() (Transactor, error) { 40 | creator, ok := currentDB.(Beginner) 41 | if !ok { 42 | panic("database does not support transactions") 43 | } 44 | 45 | tx, err := creator.Begin() 46 | if err != nil { 47 | // Return an untyped nil: a nil *sql.Tx stored in the Transactor 48 | // interface would compare != nil and mislead callers. 49 | return nil, err 50 | } 51 | return tx, nil 52 | } 53 | 54 | // ContextTransactor can commit and rollback, on top of being able to execute 55 | // context-aware queries. 56 | type ContextTransactor interface { 57 | Commit() error 58 | Rollback() error 59 | 60 | ContextExecutor 61 | } 62 | 63 | // ContextBeginner allows creation of context aware transactions with options. 64 | type ContextBeginner interface { 65 | BeginTx(context.Context, *sql.TxOptions) (*sql.Tx, error) 66 | } 67 | 68 | // BeginTx begins a transaction with the current global database handle. It panics if the global handle does not implement ContextBeginner.
69 | func BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { 70 | creator, ok := currentDB.(ContextBeginner) 71 | if !ok { 72 | panic("database does not support context-aware transactions") 73 | } 74 | 75 | return creator.BeginTx(ctx, opts) 76 | } 77 | -------------------------------------------------------------------------------- /boil/db_test.go: -------------------------------------------------------------------------------- 1 | package boil 2 | 3 | import ( 4 | "database/sql" 5 | "testing" 6 | ) 7 | 8 | func TestGetSetDB(t *testing.T) { 9 | t.Parallel() 10 | 11 | SetDB(&sql.DB{}) 12 | 13 | if GetDB() == nil { 14 | t.Errorf("Expected GetDB to return a database handle, got nil") 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /boil/debug.go: -------------------------------------------------------------------------------- 1 | package boil 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "os" 7 | ) 8 | 9 | // DebugMode is a flag controlling whether generated sql statements and 10 | // debug information is outputted to the DebugWriter handle 11 | // 12 | // NOTE: This should be disabled in production to avoid leaking sensitive data 13 | var DebugMode = false 14 | 15 | // DebugWriter is where the debug output will be sent if DebugMode is true 16 | var DebugWriter io.Writer = os.Stdout 17 | 18 | // WithDebug modifies a context to configure debug writing. If true, 19 | // all queries made using this context will be outputted to the io.Writer 20 | // returned by DebugWriterFrom. 21 | func WithDebug(ctx context.Context, debug bool) context.Context { 22 | return context.WithValue(ctx, ctxDebug, debug) 23 | } 24 | 25 | // IsDebug returns true if the context has debugging enabled, or 26 | // the value of DebugMode if not set.
27 | func IsDebug(ctx context.Context) bool { 28 | debug, ok := ctx.Value(ctxDebug).(bool) 29 | if ok { 30 | return debug 31 | } 32 | return DebugMode 33 | } 34 | 35 | // WithDebugWriter modifies a context to configure the writer written to 36 | // when debugging is enabled. 37 | func WithDebugWriter(ctx context.Context, writer io.Writer) context.Context { 38 | return context.WithValue(ctx, ctxDebugWriter, writer) 39 | } 40 | 41 | // DebugWriterFrom returns the debug writer for the context, or DebugWriter 42 | // if not set. 43 | func DebugWriterFrom(ctx context.Context) io.Writer { 44 | writer, ok := ctx.Value(ctxDebugWriter).(io.Writer) 45 | if ok { 46 | return writer 47 | } 48 | return DebugWriter 49 | } 50 | -------------------------------------------------------------------------------- /boil/errors.go: -------------------------------------------------------------------------------- 1 | package boil 2 | 3 | type boilErr struct { 4 | error 5 | } 6 | 7 | // WrapErr wraps err in a boilErr 8 | func WrapErr(err error) error { 9 | return boilErr{ 10 | error: err, 11 | } 12 | } 13 | 14 | // Error returns the underlying error string 15 | func (e boilErr) Error() string { 16 | return e.error.Error() 17 | } 18 | 19 | // IsBoilErr checks if err is a boilErr 20 | func IsBoilErr(err error) bool { 21 | _, ok := err.(boilErr) 22 | return ok 23 | } 24 | -------------------------------------------------------------------------------- /boil/errors_test.go: -------------------------------------------------------------------------------- 1 | package boil 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | ) 7 | 8 | func TestErrors(t *testing.T) { 9 | t.Parallel() 10 | 11 | err := errors.New("test error") 12 | if IsBoilErr(err) == true { 13 | t.Errorf("Expected false") 14 | } 15 | 16 | err = WrapErr(errors.New("test error")) 17 | if err.Error() != "test error" { 18 | t.Errorf(`Expected "test error", got %v`, err.Error()) 19 | } 20 | 21 | if IsBoilErr(err) != true { 22 | 
t.Errorf("Expected true") 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /boil/global.go: -------------------------------------------------------------------------------- 1 | package boil 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | var ( 8 | // currentDB is a global database handle for the package 9 | currentDB Executor 10 | currentContextDB ContextExecutor 11 | // timestampLocation is the timezone used for the 12 | // automated setting of created_at/updated_at columns 13 | timestampLocation = time.UTC 14 | ) 15 | 16 | // SetDB initializes the database handle for all template db interactions. 17 | func SetDB(db Executor) { 18 | currentDB = db 19 | // Always overwrite the context handle so a previous ContextExecutor is 20 | // not left behind when db does not implement ContextExecutor. 21 | currentContextDB, _ = db.(ContextExecutor) 22 | } 23 | 24 | // GetDB retrieves the global state database handle 25 | func GetDB() Executor { 26 | return currentDB 27 | } 28 | 29 | // GetContextDB retrieves the global state database handle as a context executor, or nil if the handle passed to SetDB does not implement ContextExecutor. 30 | func GetContextDB() ContextExecutor { 31 | return currentContextDB 32 | } 33 | 34 | // SetLocation sets the global timestamp Location. 35 | // This is the timezone used by the generated package for the 36 | // automated setting of created_at and updated_at columns. 37 | // If the package was generated with the --no-auto-timestamps flag 38 | // then this function has no effect. 39 | func SetLocation(loc *time.Location) { 40 | timestampLocation = loc 41 | } 42 | 43 | // GetLocation retrieves the global timestamp Location. 44 | // This is the timezone used by the generated package for the 45 | // automated setting of created_at and updated_at columns 46 | // if the package was not generated with the --no-auto-timestamps flag.
47 | func GetLocation() *time.Location { 48 | return timestampLocation 49 | } 50 | -------------------------------------------------------------------------------- /boil/hooks.go: -------------------------------------------------------------------------------- 1 | package boil 2 | 3 | import "context" 4 | 5 | // SkipHooks modifies a context to prevent hooks from running for any query 6 | // it encounters. 7 | func SkipHooks(ctx context.Context) context.Context { 8 | return context.WithValue(ctx, ctxSkipHooks, true) 9 | } 10 | 11 | // HooksAreSkipped returns true if the context was marked by SkipHooks. 12 | func HooksAreSkipped(ctx context.Context) bool { 13 | skip := ctx.Value(ctxSkipHooks) 14 | return skip != nil && skip.(bool) 15 | } 16 | 17 | // SkipTimestamps modifies a context to prevent automatic timestamp handling 18 | // (e.g. setting created_at/updated_at) for any query it encounters. 19 | func SkipTimestamps(ctx context.Context) context.Context { 20 | return context.WithValue(ctx, ctxSkipTimestamps, true) 21 | } 22 | 23 | // TimestampsAreSkipped returns true if the context was marked by SkipTimestamps. 24 | func TimestampsAreSkipped(ctx context.Context) bool { 25 | skip := ctx.Value(ctxSkipTimestamps) 26 | return skip != nil && skip.(bool) 27 | } 28 | 29 | // HookPoint is the point in time at which we hook 30 | type HookPoint int 31 | 32 | // HookPoint values start at 1 (iota + 1), so the zero HookPoint matches none of them. 33 | const ( 34 | BeforeInsertHook HookPoint = iota + 1 35 | BeforeUpdateHook 36 | BeforeDeleteHook 37 | BeforeUpsertHook 38 | AfterInsertHook 39 | AfterSelectHook 40 | AfterUpdateHook 41 | AfterDeleteHook 42 | AfterUpsertHook 43 | ) 44 | -------------------------------------------------------------------------------- /boil/hooks_test.go: -------------------------------------------------------------------------------- 1 | package boil 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | ) 7 | 8 | func TestSkipHooks(t *testing.T) { 9 | t.Parallel() 10 | 11 | ctx := context.Background() 12 | if HooksAreSkipped(ctx) { 13 | t.Error("they should not be skipped") 14 | } 15 | 16
| ctx = SkipHooks(ctx) 17 | 18 | if !HooksAreSkipped(ctx) { 19 | t.Error("they should be skipped") 20 | } 21 | } 22 | 23 | func TestSkipTimestamps(t *testing.T) { 24 | t.Parallel() 25 | 26 | ctx := context.Background() 27 | if TimestampsAreSkipped(ctx) { 28 | t.Error("they should not be skipped") 29 | } 30 | 31 | ctx = SkipTimestamps(ctx) 32 | 33 | if !TimestampsAreSkipped(ctx) { 34 | t.Error("they should be skipped") 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /boilingcore/output_test.go: -------------------------------------------------------------------------------- 1 | package boilingcore 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "os" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/google/go-cmp/cmp" 12 | ) 13 | 14 | type NopWriteCloser struct { 15 | io.Writer 16 | } 17 | 18 | func (NopWriteCloser) Close() error { 19 | return nil 20 | } 21 | 22 | func nopCloser(w io.Writer) io.WriteCloser { 23 | return NopWriteCloser{w} 24 | } 25 | 26 | func TestWriteFile(t *testing.T) { 27 | // t.Parallel() cannot be used 28 | 29 | // set the function pointer back to its original value 30 | // after we modify it for the test 31 | saveTestHarnessWriteFile := testHarnessWriteFile 32 | defer func() { 33 | testHarnessWriteFile = saveTestHarnessWriteFile 34 | }() 35 | 36 | var output []byte 37 | testHarnessWriteFile = func(_ string, in []byte, _ os.FileMode) error { 38 | output = in 39 | return nil 40 | } 41 | 42 | buf := &bytes.Buffer{} 43 | writePackageName(buf, "pkg") 44 | fmt.Fprintf(buf, "func hello() {}\n\n\nfunc world() {\nreturn\n}\n\n\n\n") 45 | 46 | if err := writeFile("", "", buf, true); err != nil { 47 | t.Error(err) 48 | } 49 | 50 | if string(output) != "package pkg\n\nfunc hello() {}\n\nfunc world() {\n\treturn\n}\n" { 51 | t.Errorf("Wrong output: %q", output) 52 | } 53 | } 54 | 55 | func TestFormatBuffer(t *testing.T) { 56 | t.Parallel() 57 | 58 | buf := &bytes.Buffer{} 59 | 60 | fmt.Fprintf(buf, 
"package pkg\n\nfunc() {a}\n") 61 | 62 | // Only test error case - happy case is taken care of by template test 63 | _, err := formatBuffer(buf) 64 | if err == nil { 65 | t.Error("want an error") 66 | } 67 | 68 | if txt := err.Error(); !strings.Contains(txt, ">>>> func() {a}") { 69 | t.Error("got:\n", txt) 70 | } 71 | } 72 | 73 | func TestOutputFilenameParts(t *testing.T) { 74 | t.Parallel() 75 | 76 | tests := []struct { 77 | Filename string 78 | 79 | FirstDir string 80 | Normalized string 81 | IsSingleton bool 82 | IsGo bool 83 | UsePkg bool 84 | }{ 85 | {"templates/00_struct.go.tpl", "templates", "struct.go", false, true, true}, 86 | {"templates/singleton/00_struct.go.tpl", "templates", "struct.go", true, true, true}, 87 | {"templates/notpkg/00_struct.go.tpl", "templates", "notpkg/struct.go", false, true, false}, 88 | {"templates/js/singleton/00_struct.js.tpl", "templates", "js/struct.js", true, false, false}, 89 | {"templates/js/00_struct.js.tpl", "templates", "js/struct.js", false, false, false}, 90 | } 91 | 92 | for i, test := range tests { 93 | normalized, isSingleton, isGo, usePkg := outputFilenameParts(test.Filename) 94 | 95 | if normalized != test.Normalized { 96 | t.Errorf("%d) normalized wrong, want: %s, got: %s", i, test.Normalized, normalized) 97 | } 98 | if isSingleton != test.IsSingleton { 99 | t.Errorf("%d) isSingleton wrong, want: %t, got: %t", i, test.IsSingleton, isSingleton) 100 | } 101 | if isGo != test.IsGo { 102 | t.Errorf("%d) isGo wrong, want: %t, got: %t", i, test.IsGo, isGo) 103 | } 104 | if usePkg != test.UsePkg { 105 | t.Errorf("%d) usePkg wrong, want: %t, got: %t", i, test.UsePkg, usePkg) 106 | } 107 | } 108 | } 109 | 110 | func TestGetOutputFilename(t *testing.T) { 111 | t.Parallel() 112 | 113 | tests := map[string]struct { 114 | TableName string 115 | IsTest bool 116 | IsGo bool 117 | Expected string 118 | }{ 119 | "regular": { 120 | TableName: "hello", 121 | IsTest: false, 122 | IsGo: true, 123 | Expected: "hello", 124 | }, 125 | 
"contains_forward_slash": { 126 | TableName: "slash/test", 127 | IsTest: false, 128 | IsGo: true, 129 | Expected: "slash_test_model", 130 | }, 131 | "begins with underscore": { 132 | TableName: "_hello", 133 | IsTest: false, 134 | IsGo: true, 135 | Expected: "und_hello", 136 | }, 137 | "ends with _test": { 138 | TableName: "hello_test", 139 | IsTest: false, 140 | IsGo: true, 141 | Expected: "hello_test_model", 142 | }, 143 | "ends with _js": { 144 | TableName: "hello_js", 145 | IsTest: false, 146 | IsGo: true, 147 | Expected: "hello_js_model", 148 | }, 149 | "ends with _windows": { 150 | TableName: "hello_windows", 151 | IsTest: false, 152 | IsGo: true, 153 | Expected: "hello_windows_model", 154 | }, 155 | "ends with _arm64": { 156 | TableName: "hello_arm64", 157 | IsTest: false, 158 | IsGo: true, 159 | Expected: "hello_arm64_model", 160 | }, 161 | "non-go ends with _arm64": { 162 | TableName: "hello_arm64", 163 | IsTest: false, 164 | IsGo: false, 165 | Expected: "hello_arm64", 166 | }, 167 | } 168 | 169 | for name, tc := range tests { 170 | t.Run(name, func(t *testing.T) { 171 | notTest := getOutputFilename(tc.TableName, false, tc.IsGo) 172 | if diff := cmp.Diff(tc.Expected, notTest); diff != "" { 173 | t.Fatalf(diff) 174 | } 175 | 176 | isTest := getOutputFilename(tc.TableName, true, tc.IsGo) 177 | if diff := cmp.Diff(tc.Expected+"_test", isTest); diff != "" { 178 | t.Fatalf(diff) 179 | } 180 | }) 181 | } 182 | } 183 | -------------------------------------------------------------------------------- /boilingcore/templates_test.go: -------------------------------------------------------------------------------- 1 | package boilingcore 2 | 3 | import ( 4 | "sort" 5 | "testing" 6 | "text/template" 7 | ) 8 | 9 | func TestTemplateNameListSort(t *testing.T) { 10 | t.Parallel() 11 | 12 | templs := templateNameList{ 13 | "bob.tpl", 14 | "all.tpl", 15 | "struct.tpl", 16 | "ttt.tpl", 17 | } 18 | 19 | expected := []string{"bob.tpl", "all.tpl", "struct.tpl", "ttt.tpl"} 20 | 
21 | for i, v := range templs { 22 | if v != expected[i] { 23 | t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v) 24 | } 25 | } 26 | 27 | expected = []string{"struct.tpl", "all.tpl", "bob.tpl", "ttt.tpl"} 28 | 29 | sort.Sort(templs) 30 | 31 | for i, v := range templs { 32 | if v != expected[i] { 33 | t.Errorf("Order mismatch, expected: %s, got: %s", expected[i], v) 34 | } 35 | } 36 | } 37 | 38 | func TestTemplateList_Templates(t *testing.T) { 39 | t.Parallel() 40 | 41 | tpl := template.New("") 42 | tpl.New("wat.tpl").Parse("hello") 43 | tpl.New("que.tpl").Parse("there") 44 | tpl.New("not").Parse("hello") 45 | 46 | tplList := templateList{tpl} 47 | foundWat, foundQue, foundNot := false, false, false 48 | for _, n := range tplList.Templates() { 49 | switch n { 50 | case "wat.tpl": 51 | foundWat = true 52 | case "que.tpl": 53 | foundQue = true 54 | case "not": 55 | foundNot = true 56 | } 57 | } 58 | 59 | if !foundWat { 60 | t.Error("want wat") 61 | } 62 | if !foundQue { 63 | t.Error("want que") 64 | } 65 | if foundNot { 66 | t.Error("don't want not") 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /boilingcore/text_helpers_test.go: -------------------------------------------------------------------------------- 1 | package boilingcore 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/volatiletech/sqlboiler/v4/drivers" 7 | ) 8 | 9 | func TestTxtNameToOne(t *testing.T) { 10 | t.Parallel() 11 | 12 | tests := []struct { 13 | Table string 14 | Column string 15 | Unique bool 16 | ForeignTable string 17 | ForeignColumn string 18 | ForeignColumnUnique bool 19 | 20 | LocalFn string 21 | ForeignFn string 22 | }{ 23 | {"jets", "airport_id", false, "airports", "id", true, "Jets", "Airport"}, 24 | {"jets", "airport_id", true, "airports", "id", true, "Jet", "Airport"}, 25 | 26 | {"jets", "holiday_id", false, "airports", "id", true, "HolidayJets", "Holiday"}, 27 | {"jets", "holiday_id", true, "airports", "id", 
true, "HolidayJet", "Holiday"}, 28 | 29 | {"jets", "holiday_airport_id", false, "airports", "id", true, "HolidayAirportJets", "HolidayAirport"}, 30 | {"jets", "holiday_airport_id", true, "airports", "id", true, "HolidayAirportJet", "HolidayAirport"}, 31 | 32 | {"jets", "jet_id", false, "jets", "id", true, "Jets", "Jet"}, 33 | {"jets", "jet_id", true, "jets", "id", true, "Jet", "Jet"}, 34 | {"jets", "plane_id", false, "jets", "id", true, "PlaneJets", "Plane"}, 35 | {"jets", "plane_id", true, "jets", "id", true, "PlaneJet", "Plane"}, 36 | 37 | {"videos", "user_id", false, "users", "id", true, "Videos", "User"}, 38 | {"videos", "producer_id", false, "users", "id", true, "ProducerVideos", "Producer"}, 39 | {"videos", "user_id", true, "users", "id", true, "Video", "User"}, 40 | {"videos", "producer_id", true, "users", "id", true, "ProducerVideo", "Producer"}, 41 | 42 | {"videos", "user", false, "users", "id", true, "Videos", "VideoUser"}, 43 | {"videos", "created_by", false, "users", "id", true, "CreatedByVideos", "CreatedByUser"}, 44 | {"videos", "director", false, "users", "id", true, "DirectorVideos", "DirectorUser"}, 45 | {"videos", "user", true, "users", "id", true, "Video", "VideoUser"}, 46 | {"videos", "created_by", true, "users", "id", true, "CreatedByVideo", "CreatedByUser"}, 47 | {"videos", "director", true, "users", "id", true, "DirectorVideo", "DirectorUser"}, 48 | 49 | {"industries", "industry_id", false, "industries", "id", true, "Industries", "Industry"}, 50 | {"industries", "parent_id", false, "industries", "id", true, "ParentIndustries", "Parent"}, 51 | {"industries", "industry_id", true, "industries", "id", true, "Industry", "Industry"}, 52 | {"industries", "parent_id", true, "industries", "id", true, "ParentIndustry", "Parent"}, 53 | 54 | {"race_result_scratchings", "results_id", false, "race_results", "id", true, "ResultRaceResultScratchings", "Result"}, 55 | } 56 | 57 | for i, test := range tests { 58 | fk := drivers.ForeignKey{ 59 | Table: 
test.Table, Column: test.Column, Unique: test.Unique, 60 | ForeignTable: test.ForeignTable, ForeignColumn: test.ForeignColumn, ForeignColumnUnique: test.ForeignColumnUnique, 61 | } 62 | 63 | local, foreign := txtNameToOne(fk) 64 | if local != test.LocalFn { 65 | t.Error(i, "local wrong:", local, "want:", test.LocalFn) 66 | } 67 | if foreign != test.ForeignFn { 68 | t.Error(i, "foreign wrong:", foreign, "want:", test.ForeignFn) 69 | } 70 | } 71 | } 72 | 73 | func TestTxtNameToMany(t *testing.T) { 74 | t.Parallel() 75 | 76 | tests := []struct { 77 | LHSTable string 78 | LHSColumn string 79 | 80 | RHSTable string 81 | RHSColumn string 82 | 83 | LHSFn string 84 | RHSFn string 85 | }{ 86 | {"pilots", "pilot_id", "languages", "language_id", "Pilots", "Languages"}, 87 | {"pilots", "captain_id", "languages", "lingo_id", "CaptainPilots", "LingoLanguages"}, 88 | 89 | {"pilots", "pilot_id", "pilots", "mentor_id", "Pilots", "MentorPilots"}, 90 | {"pilots", "mentor_id", "pilots", "pilot_id", "MentorPilots", "Pilots"}, 91 | {"pilots", "captain_id", "pilots", "mentor_id", "CaptainPilots", "MentorPilots"}, 92 | 93 | {"videos", "video_id", "tags", "tag_id", "Videos", "Tags"}, 94 | {"tags", "tag_id", "videos", "video_id", "Tags", "Videos"}, 95 | } 96 | 97 | for i, test := range tests { 98 | lhsFk := drivers.ForeignKey{ 99 | ForeignTable: test.LHSTable, 100 | Column: test.LHSColumn, 101 | } 102 | rhsFk := drivers.ForeignKey{ 103 | ForeignTable: test.RHSTable, 104 | Column: test.RHSColumn, 105 | } 106 | 107 | lhs, rhs := txtNameToMany(lhsFk, rhsFk) 108 | if lhs != test.LHSFn { 109 | t.Error(i, "local wrong:", lhs, "want:", test.LHSFn) 110 | } 111 | if rhs != test.RHSFn { 112 | t.Error(i, "foreign wrong:", rhs, "want:", test.RHSFn) 113 | } 114 | } 115 | } 116 | 117 | func TestTrimSuffixes(t *testing.T) { 118 | t.Parallel() 119 | 120 | for _, s := range identifierSuffixes { 121 | a := "hello" + s 122 | 123 | if z := trimSuffixes(a); z != "hello" { 124 | t.Errorf("got %s", z) 125 | } 126 
| } 127 | } 128 | -------------------------------------------------------------------------------- /drivers/binary_driver.go: -------------------------------------------------------------------------------- 1 | package drivers 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "io" 7 | "os" 8 | "os/exec" 9 | 10 | "github.com/friendsofgo/errors" 11 | "github.com/volatiletech/sqlboiler/v4/importers" 12 | ) 13 | 14 | type binaryDriver string 15 | 16 | // Assemble calls out to the binary with JSON 17 | // The contract for error messages is that a plain text error message is delivered 18 | // and the exit status of the process is non-zero 19 | func (b binaryDriver) Assemble(config Config) (*DBInfo, error) { 20 | var dbInfo DBInfo 21 | err := execute(string(b), "assemble", config, &dbInfo, os.Stderr) 22 | if err != nil { 23 | return nil, err 24 | } 25 | 26 | return &dbInfo, nil 27 | } 28 | 29 | // Templates calls the templates function to get a map of overidden file names 30 | // and their contents in base64 31 | func (b binaryDriver) Templates() (map[string]string, error) { 32 | var templates map[string]string 33 | err := execute(string(b), "templates", nil, &templates, os.Stderr) 34 | if err != nil { 35 | return nil, err 36 | } 37 | 38 | return templates, nil 39 | } 40 | 41 | // Imports calls the imports function to get imports from the driver 42 | func (b binaryDriver) Imports() (col importers.Collection, err error) { 43 | err = execute(string(b), "imports", nil, &col, os.Stderr) 44 | if err != nil { 45 | return col, err 46 | } 47 | 48 | return col, nil 49 | } 50 | 51 | func execute(executable, method string, input interface{}, output interface{}, errStream io.Writer) error { 52 | var err error 53 | var inputBytes []byte 54 | if input != nil { 55 | inputBytes, err = json.Marshal(input) 56 | if err != nil { 57 | return errors.Wrap(err, "failed to json-ify driver configuration") 58 | } 59 | } 60 | 61 | outputBytes := &bytes.Buffer{} 62 | cmd := 
exec.Command(executable, method) 63 | cmd.Stdout = outputBytes 64 | cmd.Stderr = errStream 65 | if inputBytes != nil { 66 | cmd.Stdin = bytes.NewReader(inputBytes) 67 | } 68 | err = cmd.Run() 69 | 70 | if err != nil { 71 | if ee, ok := err.(*exec.ExitError); ok { 72 | if ee.ProcessState.Exited() && !ee.ProcessState.Success() { 73 | return errors.Wrapf(err, "driver (%s) exited non-zero", executable) 74 | } 75 | } 76 | 77 | return errors.Wrapf(err, "something totally unexpected happened when running the binary driver %s", executable) 78 | } 79 | 80 | if err = json.Unmarshal(outputBytes.Bytes(), &output); err != nil { 81 | return errors.Wrap(err, "failed to marshal json from binary") 82 | } 83 | 84 | return nil 85 | } 86 | -------------------------------------------------------------------------------- /drivers/binary_driver_test.go: -------------------------------------------------------------------------------- 1 | package drivers 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "os" 8 | "reflect" 9 | "runtime" 10 | "strings" 11 | "testing" 12 | ) 13 | 14 | var testBinaryDriver = fmt.Sprintf("#!/bin/sh\ncat <&2 17 | echo "{}" 18 | ` 19 | var testBadBinaryDriver = `#!/bin/sh 20 | echo "bad binary" 1>&2 21 | exit 1 22 | ` 23 | 24 | func TestBinaryDriver(t *testing.T) { 25 | if runtime.GOOS == "windows" { 26 | t.Skip("cannot run binary test on windows (needs bin/sh)") 27 | } 28 | 29 | var want, got *DBInfo 30 | if err := json.Unmarshal([]byte(testBinaryJSON), &want); err != nil { 31 | t.Fatal(err) 32 | } 33 | 34 | bin, err := os.CreateTemp("", "test_binary_driver") 35 | if err != nil { 36 | t.Fatal(err) 37 | } 38 | fmt.Fprint(bin, testBinaryDriver) 39 | if err := bin.Chmod(0774); err != nil { 40 | t.Fatal(err) 41 | } 42 | if err := bin.Close(); err != nil { 43 | t.Fatal(err) 44 | } 45 | 46 | name := bin.Name() 47 | 48 | exe := binaryDriver(name) 49 | got, err = exe.Assemble(nil) 50 | if err != nil { 51 | t.Fatal(err) 52 | } 53 | 54 | if 
!reflect.DeepEqual(want, got) { 55 | t.Errorf("want:\n%#v\ngot:\n%#v\n", want, got) 56 | } 57 | } 58 | 59 | func TestBinaryWarningDriver(t *testing.T) { 60 | if runtime.GOOS == "windows" { 61 | t.Skip("cannot run binary test on windows (needs bin/sh)") 62 | } 63 | 64 | bin, err := os.CreateTemp("", "test_binary_driver") 65 | if err != nil { 66 | t.Fatal(err) 67 | } 68 | fmt.Fprint(bin, testWarningBinaryDriver) 69 | if err := bin.Chmod(0774); err != nil { 70 | t.Fatal(err) 71 | } 72 | if err := bin.Close(); err != nil { 73 | t.Fatal(err) 74 | } 75 | 76 | stderr := &bytes.Buffer{} 77 | err = execute(bin.Name(), "method", nil, nil, stderr) 78 | if err != nil { 79 | t.Error(err) 80 | } else if !strings.Contains(stderr.String(), "warning binary") { 81 | t.Error("it should have written to stderr") 82 | } 83 | } 84 | 85 | func TestBinaryBadDriver(t *testing.T) { 86 | if runtime.GOOS == "windows" { 87 | t.Skip("cannot run binary test on windows (needs bin/sh)") 88 | } 89 | 90 | bin, err := os.CreateTemp("", "test_binary_driver") 91 | if err != nil { 92 | t.Fatal(err) 93 | } 94 | fmt.Fprint(bin, testBadBinaryDriver) 95 | if err := bin.Chmod(0774); err != nil { 96 | t.Fatal(err) 97 | } 98 | if err := bin.Close(); err != nil { 99 | t.Fatal(err) 100 | } 101 | 102 | stderr := &bytes.Buffer{} 103 | err = execute(bin.Name(), "method", nil, nil, stderr) 104 | if err == nil { 105 | t.Error("it should have failed when the program exited 1") 106 | } else if !strings.Contains(stderr.String(), "bad binary") { 107 | t.Error("it should have written to stderr") 108 | } else if !strings.Contains(err.Error(), "non-zero") { 109 | t.Error("it should have reported non-zero exit") 110 | } 111 | } 112 | 113 | var testBinaryJSON = ` 114 | { 115 | "tables": [ 116 | { 117 | "name": "users", 118 | "schema_name": "dbo", 119 | "columns": [ 120 | { 121 | "name": "id", 122 | "type": "int", 123 | "db_type": "integer", 124 | "default": "", 125 | "nullable": false, 126 | "unique": true, 127 | "validated": 
false, 128 | "arr_type": null, 129 | "udt_name": "", 130 | "full_db_type": "", 131 | "auto_generated": false 132 | }, 133 | { 134 | "name": "profile_id", 135 | "type": "int", 136 | "db_type": "integer", 137 | "default": "", 138 | "nullable": false, 139 | "unique": true, 140 | "validated": false, 141 | "arr_type": null, 142 | "udt_name": "", 143 | "full_db_type": "", 144 | "auto_generated": false 145 | } 146 | ], 147 | "p_key": { 148 | "name": "pk_users", 149 | "columns": ["id"] 150 | }, 151 | "f_keys": [ 152 | { 153 | "table": "users", 154 | "name": "fk_users_profile", 155 | "column": "profile_id", 156 | "nullable": false, 157 | "unique": true, 158 | "foreign_table": "profiles", 159 | "foreign_column": "id", 160 | "foreign_column_nullable": false, 161 | "foreign_column_unique": true 162 | } 163 | ], 164 | "is_join_table": false, 165 | "to_one_relationships": [ 166 | { 167 | "table": "users", 168 | "name": "fk_users_profile", 169 | "column": "profile_id", 170 | "nullable": false, 171 | "unique": true, 172 | "foreign_table": "profiles", 173 | "foreign_column": "id", 174 | "foreign_column_nullable": false, 175 | "foreign_column_unique": true 176 | } 177 | ] 178 | } 179 | ], 180 | "dialect": { 181 | "lq": 91, 182 | "rq": 93, 183 | 184 | "use_index_placeholders": false, 185 | "use_last_insert_id": true, 186 | "use_top_clause": false 187 | } 188 | } 189 | ` 190 | -------------------------------------------------------------------------------- /drivers/column.go: -------------------------------------------------------------------------------- 1 | package drivers 2 | 3 | import ( 4 | "regexp" 5 | 6 | "github.com/volatiletech/strmangle" 7 | ) 8 | 9 | var rgxEnum = regexp.MustCompile(`^enum(\.\w+)?\([^)]+\)$`) 10 | 11 | // Column holds information about a database column. 12 | // Types are Go types, converted by TranslateColumnType. 
type Column struct {
	Name          string `json:"name" toml:"name"`
	Type          string `json:"type" toml:"type"`
	DBType        string `json:"db_type" toml:"db_type"`
	Default       string `json:"default" toml:"default"`
	Comment       string `json:"comment" toml:"comment"`
	Nullable      bool   `json:"nullable" toml:"nullable"`
	Unique        bool   `json:"unique" toml:"unique"`
	Validated     bool   `json:"validated" toml:"validated"`
	AutoGenerated bool   `json:"auto_generated" toml:"auto_generated"`

	// Postgres-only fields.
	// ArrType is the element data type of a Postgres ARRAY column. See:
	// https://www.postgresql.org/docs/9.1/static/infoschema-element-types.html
	ArrType *string `json:"arr_type" toml:"arr_type"`
	UDTName string  `json:"udt_name" toml:"udt_name"`
	// DomainName is the name of the domain type backing the column, if any. See:
	// https://www.postgresql.org/docs/10/extend-type-system.html#EXTEND-TYPE-SYSTEM-DOMAINS
	DomainName *string `json:"domain_name" toml:"domain_name"`

	// MySQL-only fields.
	// FullDBType is the complete column type, e.g. tinyint(1) rather than
	// plain tinyint; the "tinyint-as-bool" flag relies on it.
	FullDBType string `json:"full_db_type" toml:"full_db_type"`
}

// ColumnNames returns the Name of every column, preserving the order of cols.
// The result is never nil, even for an empty input.
func ColumnNames(cols []Column) []string {
	out := make([]string, 0, len(cols))
	for idx := range cols {
		out = append(out, cols[idx].Name)
	}
	return out
}

// ColumnDBTypes maps each column's TitleCase name to its DBType.
52 | func ColumnDBTypes(cols []Column) map[string]string { 53 | types := map[string]string{} 54 | 55 | for _, c := range cols { 56 | types[strmangle.TitleCase(c.Name)] = c.DBType 57 | } 58 | 59 | return types 60 | } 61 | 62 | // FilterColumnsByAuto generates the list of columns that have autogenerated values 63 | func FilterColumnsByAuto(auto bool, columns []Column) []Column { 64 | var cols []Column 65 | 66 | for _, c := range columns { 67 | if (auto && c.AutoGenerated) || (!auto && !c.AutoGenerated) { 68 | cols = append(cols, c) 69 | } 70 | } 71 | 72 | return cols 73 | } 74 | 75 | // FilterColumnsByDefault generates the list of columns that have default values 76 | func FilterColumnsByDefault(defaults bool, columns []Column) []Column { 77 | var cols []Column 78 | 79 | for _, c := range columns { 80 | if (defaults && len(c.Default) != 0) || (!defaults && len(c.Default) == 0) { 81 | cols = append(cols, c) 82 | } 83 | } 84 | 85 | return cols 86 | } 87 | 88 | // FilterColumnsByEnum generates the list of columns that are enum values. 
89 | func FilterColumnsByEnum(columns []Column) []Column { 90 | var cols []Column 91 | 92 | for _, c := range columns { 93 | if rgxEnum.MatchString(c.DBType) { 94 | cols = append(cols, c) 95 | } 96 | } 97 | 98 | return cols 99 | } 100 | 101 | // IsEnumDBType reports whether the column type is Enum 102 | func IsEnumDBType(dbType string) bool { 103 | return rgxEnum.MatchString(dbType) 104 | } 105 | -------------------------------------------------------------------------------- /drivers/column_test.go: -------------------------------------------------------------------------------- 1 | package drivers 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | ) 7 | 8 | func TestColumnNames(t *testing.T) { 9 | t.Parallel() 10 | 11 | cols := []Column{ 12 | {Name: "one"}, 13 | {Name: "two"}, 14 | {Name: "three"}, 15 | } 16 | 17 | out := strings.Join(ColumnNames(cols), " ") 18 | if out != "one two three" { 19 | t.Error("output was wrong:", out) 20 | } 21 | } 22 | 23 | func TestColumnDBTypes(t *testing.T) { 24 | cols := []Column{ 25 | {Name: "test_one", DBType: "integer"}, 26 | {Name: "test_two", DBType: "interval"}, 27 | } 28 | 29 | res := ColumnDBTypes(cols) 30 | if res["TestOne"] != "integer" { 31 | t.Errorf(`Expected res["TestOne"]="integer", got: %s`, res["TestOne"]) 32 | } 33 | if res["TestTwo"] != "interval" { 34 | t.Errorf(`Expected res["TestOne"]="interval", got: %s`, res["TestOne"]) 35 | } 36 | } 37 | 38 | func TestFilterColumnsByDefault(t *testing.T) { 39 | t.Parallel() 40 | 41 | cols := []Column{ 42 | {Name: "col1", Default: ""}, 43 | {Name: "col2", Default: "things"}, 44 | {Name: "col3", Default: ""}, 45 | {Name: "col4", Default: "things2"}, 46 | } 47 | 48 | res := FilterColumnsByDefault(false, cols) 49 | if res[0].Name != `col1` { 50 | t.Errorf("Invalid result: %#v", res) 51 | } 52 | if res[1].Name != `col3` { 53 | t.Errorf("Invalid result: %#v", res) 54 | } 55 | 56 | res = FilterColumnsByDefault(true, cols) 57 | if res[0].Name != `col2` { 58 | t.Errorf("Invalid 
result: %#v", res) 59 | } 60 | if res[1].Name != `col4` { 61 | t.Errorf("Invalid result: %#v", res) 62 | } 63 | 64 | res = FilterColumnsByDefault(false, []Column{}) 65 | if res != nil { 66 | t.Errorf("Invalid result: %#v", res) 67 | } 68 | } 69 | 70 | func TestFilterColumnsByEnum(t *testing.T) { 71 | t.Parallel() 72 | 73 | cols := []Column{ 74 | {Name: "col1", DBType: "enum('hello')"}, 75 | {Name: "col2", DBType: "enum('hello','there')"}, 76 | {Name: "col3", DBType: "enum"}, 77 | {Name: "col4", DBType: ""}, 78 | {Name: "col5", DBType: "int"}, 79 | } 80 | 81 | res := FilterColumnsByEnum(cols) 82 | if res[0].Name != `col1` { 83 | t.Errorf("Invalid result: %#v", res) 84 | } 85 | if res[1].Name != `col2` { 86 | t.Errorf("Invalid result: %#v", res) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /drivers/driver_main.go: -------------------------------------------------------------------------------- 1 | package drivers 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io" 7 | "os" 8 | ) 9 | 10 | // DriverMain helps dry up the implementation of main.go for drivers 11 | func DriverMain(driver Interface) { 12 | method := os.Args[1] 13 | var config Config 14 | 15 | switch method { 16 | case "assemble": 17 | b, err := io.ReadAll(os.Stdin) 18 | if err != nil { 19 | fmt.Fprintln(os.Stderr, "failed to read from stdin") 20 | os.Exit(1) 21 | } 22 | 23 | err = json.Unmarshal(b, &config) 24 | if err != nil { 25 | fmt.Fprintf(os.Stderr, "failed to parse json from stdin: %v\n%s\n", err, b) 26 | os.Exit(1) 27 | } 28 | case "templates": 29 | // No input for this method 30 | case "imports": 31 | // No input for this method 32 | } 33 | 34 | var output interface{} 35 | switch method { 36 | case "assemble": 37 | dinfo, err := driver.Assemble(config) 38 | if err != nil { 39 | fmt.Fprintln(os.Stderr, err) 40 | os.Exit(1) 41 | } 42 | output = dinfo 43 | case "templates": 44 | templates, err := driver.Templates() 45 | if err != nil { 46 | 
fmt.Fprintln(os.Stderr, err) 47 | os.Exit(1) 48 | } 49 | output = templates 50 | case "imports": 51 | collection, err := driver.Imports() 52 | if err != nil { 53 | fmt.Fprintln(os.Stderr, err) 54 | os.Exit(1) 55 | } 56 | output = collection 57 | } 58 | 59 | b, err := json.Marshal(output) 60 | if err != nil { 61 | fmt.Fprintln(os.Stderr, "failed to marshal json:", err) 62 | os.Exit(1) 63 | } 64 | 65 | os.Stdout.Write(b) 66 | } 67 | -------------------------------------------------------------------------------- /drivers/keys.go: -------------------------------------------------------------------------------- 1 | package drivers 2 | 3 | import "fmt" 4 | 5 | // PrimaryKey represents a primary key constraint in a database 6 | type PrimaryKey struct { 7 | Name string `json:"name"` 8 | Columns []string `json:"columns"` 9 | } 10 | 11 | // ForeignKey represents a foreign key constraint in a database 12 | type ForeignKey struct { 13 | Table string `json:"table"` 14 | Name string `json:"name"` 15 | Column string `json:"column"` 16 | Nullable bool `json:"nullable"` 17 | Unique bool `json:"unique"` 18 | 19 | ForeignTable string `json:"foreign_table"` 20 | ForeignColumn string `json:"foreign_column"` 21 | ForeignColumnNullable bool `json:"foreign_column_nullable"` 22 | ForeignColumnUnique bool `json:"foreign_column_unique"` 23 | } 24 | 25 | // SQLColumnDef formats a column name and type like an SQL column definition. 
26 | type SQLColumnDef struct { 27 | Name string 28 | Type string 29 | } 30 | 31 | // String for fmt.Stringer 32 | func (s SQLColumnDef) String() string { 33 | return fmt.Sprintf("%s %s", s.Name, s.Type) 34 | } 35 | 36 | // SQLColumnDefs has small helper functions 37 | type SQLColumnDefs []SQLColumnDef 38 | 39 | // Names returns all the names 40 | func (s SQLColumnDefs) Names() []string { 41 | names := make([]string, len(s)) 42 | 43 | for i, sqlDef := range s { 44 | names[i] = sqlDef.Name 45 | } 46 | 47 | return names 48 | } 49 | 50 | // Types returns all the types 51 | func (s SQLColumnDefs) Types() []string { 52 | types := make([]string, len(s)) 53 | 54 | for i, sqlDef := range s { 55 | types[i] = sqlDef.Type 56 | } 57 | 58 | return types 59 | } 60 | 61 | // SQLColDefinitions creates a definition in sql format for a column 62 | func SQLColDefinitions(cols []Column, names []string) SQLColumnDefs { 63 | ret := make([]SQLColumnDef, len(names)) 64 | 65 | for i, n := range names { 66 | for _, c := range cols { 67 | if n != c.Name { 68 | continue 69 | } 70 | 71 | ret[i] = SQLColumnDef{Name: n, Type: c.Type} 72 | } 73 | } 74 | 75 | return ret 76 | } 77 | -------------------------------------------------------------------------------- /drivers/keys_test.go: -------------------------------------------------------------------------------- 1 | package drivers 2 | 3 | import "testing" 4 | 5 | func TestSQLColDefinitions(t *testing.T) { 6 | t.Parallel() 7 | 8 | cols := []Column{ 9 | {Name: "one", Type: "int64"}, 10 | {Name: "two", Type: "string"}, 11 | {Name: "three", Type: "string"}, 12 | } 13 | 14 | defs := SQLColDefinitions(cols, []string{"one"}) 15 | if len(defs) != 1 { 16 | t.Error("wrong number of defs:", len(defs)) 17 | } 18 | if got := defs[0].String(); got != "one int64" { 19 | t.Error("wrong def:", got) 20 | } 21 | 22 | defs = SQLColDefinitions(cols, []string{"one", "three"}) 23 | if len(defs) != 2 { 24 | t.Error("wrong number of defs:", len(defs)) 25 | } 26 | if got 
:= defs[0].String(); got != "one int64" { 27 | t.Error("wrong def:", got) 28 | } 29 | if got := defs[1].String(); got != "three string" { 30 | t.Error("wrong def:", got) 31 | } 32 | } 33 | 34 | func TestTypes(t *testing.T) { 35 | t.Parallel() 36 | 37 | defs := SQLColumnDefs{ 38 | {Type: "thing1"}, 39 | {Type: "thing2"}, 40 | } 41 | 42 | ret := defs.Types() 43 | if ret[0] != "thing1" { 44 | t.Error("wrong type:", ret[0]) 45 | } 46 | if ret[1] != "thing2" { 47 | t.Error("wrong type:", ret[1]) 48 | } 49 | } 50 | 51 | func TestNames(t *testing.T) { 52 | t.Parallel() 53 | 54 | defs := SQLColumnDefs{ 55 | {Name: "thing1"}, 56 | {Name: "thing2"}, 57 | } 58 | 59 | ret := defs.Names() 60 | if ret[0] != "thing1" { 61 | t.Error("wrong type:", ret[0]) 62 | } 63 | if ret[1] != "thing2" { 64 | t.Error("wrong type:", ret[1]) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /drivers/registration.go: -------------------------------------------------------------------------------- 1 | package drivers 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "os/exec" 7 | "path/filepath" 8 | "strings" 9 | ) 10 | 11 | // registeredDrivers are all the drivers which are currently registered 12 | var registeredDrivers = map[string]Interface{} 13 | 14 | // RegisterBinary is used to register drivers that are binaries. 15 | // Panics if a driver with the same name has been previously loaded. 16 | func RegisterBinary(name, path string) { 17 | register(name, binaryDriver(path)) 18 | } 19 | 20 | // RegisterFromInit is typically called by a side-effect loaded driver 21 | // during init time. 22 | // Panics if a driver with the same name has been previously loaded. 
23 | func RegisterFromInit(name string, driver Interface) { 24 | register(name, driver) 25 | } 26 | 27 | // GetDriver retrieves the driver by name 28 | func GetDriver(name string) Interface { 29 | if d, ok := registeredDrivers[name]; ok { 30 | return d 31 | } 32 | 33 | panic(fmt.Sprintf("drivers: sqlboiler driver %s has not been registered", name)) 34 | } 35 | 36 | func register(name string, driver Interface) { 37 | if _, ok := registeredDrivers[name]; ok { 38 | panic(fmt.Sprintf("drivers: sqlboiler driver %s already loaded", name)) 39 | } 40 | 41 | registeredDrivers[name] = driver 42 | } 43 | 44 | // RegisterBinaryFromCmdArg is used to register drivers from a command line argument 45 | // The argument is either just the driver name or a path to a specific driver 46 | // Panics if a driver with the same name has been previously loaded. 47 | func RegisterBinaryFromCmdArg(arg string) (name, path string, err error) { 48 | path, err = getFullPath(arg) 49 | if err != nil { 50 | return name, path, err 51 | } 52 | 53 | name = getNameFromPath(path) 54 | 55 | RegisterBinary(name, path) 56 | 57 | return name, path, nil 58 | } 59 | 60 | // Get the full path to the driver binary from the given path 61 | // the path can also be just the driver name e.g. "psql" 62 | func getFullPath(path string) (string, error) { 63 | var err error 64 | 65 | if strings.ContainsRune(path, os.PathSeparator) { 66 | return path, nil 67 | } 68 | 69 | path, err = exec.LookPath("sqlboiler-" + path) 70 | if err != nil { 71 | return path, fmt.Errorf("could not find driver executable: %w", err) 72 | } 73 | 74 | path, err = filepath.Abs(path) 75 | if err != nil { 76 | return path, fmt.Errorf("could not find absolute path to driver: %w", err) 77 | } 78 | 79 | return path, nil 80 | } 81 | 82 | // Get the driver name from the path. 
83 | // strips the "sqlboiler-" prefix if it exists 84 | // strips the ".exe" suffix if it exits 85 | func getNameFromPath(name string) string { 86 | name = strings.Replace(filepath.Base(name), "sqlboiler-", "", 1) 87 | name = strings.Replace(name, ".exe", "", 1) 88 | 89 | return name 90 | } 91 | -------------------------------------------------------------------------------- /drivers/registration_test.go: -------------------------------------------------------------------------------- 1 | package drivers 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/volatiletech/sqlboiler/v4/importers" 7 | ) 8 | 9 | type testRegistrationDriver struct{} 10 | 11 | func (t testRegistrationDriver) Assemble(config Config) (*DBInfo, error) { 12 | return &DBInfo{ 13 | Tables: nil, 14 | Dialect: Dialect{}, 15 | }, nil 16 | } 17 | 18 | func (t testRegistrationDriver) Templates() (map[string]string, error) { 19 | return nil, nil 20 | } 21 | 22 | func (t testRegistrationDriver) Imports() (importers.Collection, error) { 23 | return importers.Collection{}, nil 24 | } 25 | 26 | func TestRegistration(t *testing.T) { 27 | mock := testRegistrationDriver{} 28 | RegisterFromInit("mock1", mock) 29 | 30 | if d, ok := registeredDrivers["mock1"]; !ok { 31 | t.Error("driver was not found") 32 | } else if d != mock { 33 | t.Error("got the wrong driver back") 34 | } 35 | } 36 | 37 | func TestBinaryRegistration(t *testing.T) { 38 | RegisterBinary("mock2", "/bin/true") 39 | 40 | if d, ok := registeredDrivers["mock2"]; !ok { 41 | t.Error("driver was not found") 42 | } else if string(d.(binaryDriver)) != "/bin/true" { 43 | t.Error("got the wrong driver back") 44 | } 45 | } 46 | 47 | func TestBinaryFromArgRegistration(t *testing.T) { 48 | RegisterBinaryFromCmdArg("/bin/true/mock5") 49 | 50 | if d, ok := registeredDrivers["mock5"]; !ok { 51 | t.Error("driver was not found") 52 | } else if string(d.(binaryDriver)) != "/bin/true/mock5" { 53 | t.Error("got the wrong driver back") 54 | } 55 | } 56 | 57 | func 
TestGetDriver(t *testing.T) { 58 | didYouPanic := false 59 | 60 | RegisterBinary("mock4", "/bin/true") 61 | 62 | func() { 63 | defer func() { 64 | if r := recover(); r != nil { 65 | didYouPanic = true 66 | } 67 | }() 68 | 69 | _ = GetDriver("mock4") 70 | }() 71 | 72 | if didYouPanic { 73 | t.Error("expected not to panic when fetching a driver that's known") 74 | } 75 | 76 | func() { 77 | defer func() { 78 | if r := recover(); r != nil { 79 | didYouPanic = true 80 | } 81 | }() 82 | 83 | _ = GetDriver("notpresentdriver") 84 | }() 85 | 86 | if !didYouPanic { 87 | t.Error("expected to recover from a panic") 88 | } 89 | } 90 | 91 | func TestReregister(t *testing.T) { 92 | didYouPanic := false 93 | 94 | func() { 95 | defer func() { 96 | if r := recover(); r != nil { 97 | didYouPanic = true 98 | } 99 | }() 100 | 101 | RegisterBinary("mock3", "/bin/true") 102 | RegisterBinary("mock3", "/bin/true") 103 | }() 104 | 105 | if !didYouPanic { 106 | t.Error("expected to recover from a panic") 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /drivers/sqlboiler-mssql/driver/mssql_test.go: -------------------------------------------------------------------------------- 1 | // These tests assume there is a user sqlboiler_driver_user and a database 2 | // by the name of sqlboiler_driver_test that it has full R/W rights to. 
3 | // In order to create this you can use the following steps from a root 4 | // mssql account: 5 | // 6 | // create database sqlboiler_driver_test; 7 | // go 8 | // use sqlboiler_driver_test; 9 | // go 10 | // create user sqlboiler_driver_user with password = 'sqlboiler'; 11 | // go 12 | // exec sp_configure 'contained database authentication', 1; 13 | // go 14 | // reconfigure 15 | // go 16 | // alter database sqlboiler_driver_test set containment = partial; 17 | // go 18 | // create user sqlboiler_driver_user with password = 'Sqlboiler@1234'; 19 | // go 20 | // grant alter, control to sqlboiler_driver_user; 21 | // go 22 | 23 | package driver 24 | 25 | import ( 26 | "bytes" 27 | "encoding/json" 28 | "flag" 29 | "os" 30 | "os/exec" 31 | "regexp" 32 | "testing" 33 | 34 | "github.com/volatiletech/sqlboiler/v4/drivers" 35 | ) 36 | 37 | var ( 38 | flagOverwriteGolden = flag.Bool("overwrite-golden", false, "Overwrite the golden file with the current execution results") 39 | 40 | envHostname = drivers.DefaultEnv("DRIVER_HOSTNAME", "localhost") 41 | envPort = drivers.DefaultEnv("DRIVER_PORT", "1433") 42 | envUsername = drivers.DefaultEnv("DRIVER_USER", "sqlboiler_driver_user") 43 | envPassword = drivers.DefaultEnv("DRIVER_PASS", "Sqlboiler@1234") 44 | envDatabase = drivers.DefaultEnv("DRIVER_DB", "sqlboiler_driver_test") 45 | 46 | rgxKeyIDs = regexp.MustCompile(`__[A-F0-9]+$`) 47 | ) 48 | 49 | func TestDriver(t *testing.T) { 50 | out := &bytes.Buffer{} 51 | createDB := exec.Command("sqlcmd", "-S", envHostname, "-U", envUsername, "-P", envPassword, "-d", envDatabase, "-b", "-i", "testdatabase.sql") 52 | createDB.Stdout = out 53 | createDB.Stderr = out 54 | 55 | if err := createDB.Run(); err != nil { 56 | t.Logf("mssql output:\n%s\n", out.Bytes()) 57 | t.Fatal(err) 58 | } 59 | t.Logf("mssql output:\n%s\n", out.Bytes()) 60 | 61 | config := drivers.Config{ 62 | "user": envUsername, 63 | "pass": envPassword, 64 | "dbname": envDatabase, 65 | "host": envHostname, 66 | "port": 
envPort, 67 | "sslmode": "disable", 68 | "schema": "dbo", 69 | } 70 | 71 | p := &MSSQLDriver{} 72 | info, err := p.Assemble(config) 73 | if err != nil { 74 | t.Fatal(err) 75 | } 76 | 77 | for _, t := range info.Tables { 78 | if t.IsView { 79 | continue 80 | } 81 | 82 | t.PKey.Name = rgxKeyIDs.ReplaceAllString(t.PKey.Name, "") 83 | for i := range t.FKeys { 84 | t.FKeys[i].Name = rgxKeyIDs.ReplaceAllString(t.FKeys[i].Name, "") 85 | } 86 | } 87 | 88 | got, err := json.MarshalIndent(info, "", "\t") 89 | if err != nil { 90 | t.Fatal(err) 91 | } 92 | 93 | if *flagOverwriteGolden { 94 | if err = os.WriteFile("mssql.golden.json", got, 0664); err != nil { 95 | t.Fatal(err) 96 | } 97 | t.Log("wrote:", string(got)) 98 | return 99 | } 100 | 101 | want, err := os.ReadFile("mssql.golden.json") 102 | if err != nil { 103 | t.Fatal(err) 104 | } 105 | 106 | if bytes.Compare(want, got) != 0 { 107 | t.Errorf("want:\n%s\ngot:\n%s\n", want, got) 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /drivers/sqlboiler-mssql/driver/override/main/singleton/mssql_upsert.go.tpl: -------------------------------------------------------------------------------- 1 | // buildUpsertQueryMSSQL builds a SQL statement string using the upsertData provided. 
2 | func buildUpsertQueryMSSQL(dia drivers.Dialect, tableName string, primary, update, insert []string, output []string) string { 3 | insert = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, insert) 4 | 5 | buf := strmangle.GetBuffer() 6 | defer strmangle.PutBuffer(buf) 7 | 8 | startIndex := 1 9 | 10 | fmt.Fprintf(buf, "MERGE INTO %s as [t]\n", tableName) 11 | fmt.Fprintf(buf, "USING (SELECT %s) as [s] ([%s])\n", 12 | strmangle.Placeholders(dia.UseIndexPlaceholders, len(primary), startIndex, 1), 13 | strings.Join(primary, string(dia.RQ)+","+string(dia.LQ))) 14 | fmt.Fprint(buf, "ON (") 15 | for i, v := range primary { 16 | if i != 0 { 17 | fmt.Fprint(buf, " AND ") 18 | } 19 | fmt.Fprintf(buf, "[s].[%s] = [t].[%s]", v, v) 20 | } 21 | fmt.Fprint(buf, ")\n") 22 | 23 | startIndex += len(primary) 24 | 25 | if len(update) > 0 { 26 | fmt.Fprint(buf, "WHEN MATCHED THEN ") 27 | fmt.Fprintf(buf, "UPDATE SET %s\n", strmangle.SetParamNames(string(dia.LQ), string(dia.RQ), startIndex, update)) 28 | 29 | startIndex += len(update) 30 | } 31 | 32 | fmt.Fprint(buf, "WHEN NOT MATCHED THEN ") 33 | fmt.Fprintf(buf, "INSERT (%s) VALUES (%s)", 34 | strings.Join(insert, ", "), 35 | strmangle.Placeholders(dia.UseIndexPlaceholders, len(insert), startIndex, 1)) 36 | 37 | if len(output) > 0 { 38 | fmt.Fprintf(buf, "\nOUTPUT INSERTED.[%s];", strings.Join(output, "],INSERTED.[")) 39 | } else { 40 | fmt.Fprint(buf, ";") 41 | } 42 | 43 | return buf.String() 44 | } 45 | -------------------------------------------------------------------------------- /drivers/sqlboiler-mssql/driver/override/test/singleton/mssql_main_test.go.tpl: -------------------------------------------------------------------------------- 1 | var rgxMSSQLkey = regexp.MustCompile(`(?m)^ALTER TABLE .*ADD\s+CONSTRAINT .* FOREIGN KEY.*?.*\n?REFERENCES.*`) 2 | 3 | type mssqlTester struct { 4 | dbConn *sql.DB 5 | dbName string 6 | host string 7 | user string 8 | pass string 9 | sslmode string 10 | port int 11 | testDBName string 12 | 
skipSQLCmd bool 13 | } 14 | 15 | func init() { 16 | dbMain = &mssqlTester{} 17 | } 18 | 19 | func (m *mssqlTester) setup() error { 20 | var err error 21 | 22 | viper.SetDefault("mssql.schema", "dbo") 23 | viper.SetDefault("mssql.sslmode", "true") 24 | viper.SetDefault("mssql.port", 1433) 25 | 26 | m.dbName = viper.GetString("mssql.dbname") 27 | m.host = viper.GetString("mssql.host") 28 | m.user = viper.GetString("mssql.user") 29 | m.pass = viper.GetString("mssql.pass") 30 | m.port = viper.GetInt("mssql.port") 31 | m.sslmode = viper.GetString("mssql.sslmode") 32 | m.testDBName = viper.GetString("mssql.testdbname") 33 | m.skipSQLCmd = viper.GetBool("mssql.skipsqlcmd") 34 | 35 | err = vala.BeginValidation().Validate( 36 | vala.StringNotEmpty(viper.GetString("mssql.user"), "mssql.user"), 37 | vala.StringNotEmpty(viper.GetString("mssql.host"), "mssql.host"), 38 | vala.Not(vala.Equals(viper.GetInt("mssql.port"), 0, "mssql.port")), 39 | vala.StringNotEmpty(viper.GetString("mssql.dbname"), "mssql.dbname"), 40 | vala.StringNotEmpty(viper.GetString("mssql.sslmode"), "mssql.sslmode"), 41 | ).Check() 42 | 43 | if err != nil { 44 | return err 45 | } 46 | 47 | // Create a randomized db name. 
48 | if len(m.testDBName) == 0 { 49 | m.testDBName = randomize.StableDBName(m.dbName) 50 | } 51 | 52 | if !m.skipSQLCmd { 53 | if err = m.dropTestDB(); err != nil { 54 | return err 55 | } 56 | if err = m.createTestDB(); err != nil { 57 | return err 58 | } 59 | 60 | createCmd := exec.Command("sqlcmd", "-S", m.host, "-U", m.user, "-P", m.pass, "-d", m.testDBName) 61 | 62 | f, err := os.Open("tables_schema.sql") 63 | if err != nil { 64 | return errors.Wrap(err, "failed to open tables_schema.sql file") 65 | } 66 | 67 | defer func() { _ = f.Close() }() 68 | 69 | stderr := &bytes.Buffer{} 70 | createCmd.Stdin = newFKeyDestroyer(rgxMSSQLkey, f) 71 | createCmd.Stderr = stderr 72 | 73 | if err = createCmd.Start(); err != nil { 74 | return errors.Wrap(err, "failed to start sqlcmd command") 75 | } 76 | 77 | if err = createCmd.Wait(); err != nil { 78 | fmt.Println(err) 79 | fmt.Println(stderr.String()) 80 | return errors.Wrap(err, "failed to wait for sqlcmd command") 81 | } 82 | } 83 | 84 | return nil 85 | } 86 | 87 | func (m *mssqlTester) sslMode(mode string) string { 88 | switch mode { 89 | case "true": 90 | return "true" 91 | case "false": 92 | return "false" 93 | default: 94 | return "disable" 95 | } 96 | } 97 | 98 | func (m *mssqlTester) createTestDB() error { 99 | sql := fmt.Sprintf(` 100 | CREATE DATABASE %s; 101 | GO 102 | ALTER DATABASE %[1]s 103 | SET READ_COMMITTED_SNAPSHOT ON; 104 | GO`, m.testDBName) 105 | return m.runCmd(sql, "sqlcmd", "-S", m.host, "-U", m.user, "-P", m.pass) 106 | } 107 | 108 | func (m *mssqlTester) dropTestDB() error { 109 | // Since MS SQL 2016 it can be done with 110 | // DROP DATABASE [ IF EXISTS ] { database_name | database_snapshot_name } [ ,...n ] [;] 111 | sql := fmt.Sprintf(` 112 | IF EXISTS(SELECT name FROM sys.databases 113 | WHERE name = '%s') 114 | DROP DATABASE %s 115 | GO`, m.testDBName, m.testDBName) 116 | return m.runCmd(sql, "sqlcmd", "-S", m.host, "-U", m.user, "-P", m.pass) 117 | } 118 | 119 | func (m *mssqlTester) 
teardown() error { 120 | if m.dbConn != nil { 121 | m.dbConn.Close() 122 | } 123 | 124 | if !m.skipSQLCmd { 125 | if err := m.dropTestDB(); err != nil { 126 | return err 127 | } 128 | } 129 | 130 | return nil 131 | } 132 | 133 | func (m *mssqlTester) runCmd(stdin, command string, args ...string) error { 134 | cmd := exec.Command(command, args...) 135 | cmd.Stdin = strings.NewReader(stdin) 136 | 137 | stdout := &bytes.Buffer{} 138 | stderr := &bytes.Buffer{} 139 | cmd.Stdout = stdout 140 | cmd.Stderr = stderr 141 | if err := cmd.Run(); err != nil { 142 | fmt.Println("failed running:", command, args) 143 | fmt.Println(stdout.String()) 144 | fmt.Println(stderr.String()) 145 | return err 146 | } 147 | 148 | return nil 149 | } 150 | 151 | func (m *mssqlTester) conn() (*sql.DB, error) { 152 | if m.dbConn != nil { 153 | return m.dbConn, nil 154 | } 155 | 156 | var err error 157 | m.dbConn, err = sql.Open("mssql", driver.MSSQLBuildQueryString(m.user, m.pass, m.testDBName, m.host, m.port, m.sslmode)) 158 | if err != nil { 159 | return nil, err 160 | } 161 | 162 | return m.dbConn, nil 163 | } 164 | -------------------------------------------------------------------------------- /drivers/sqlboiler-mssql/driver/override/test/singleton/mssql_suites_test.go.tpl: -------------------------------------------------------------------------------- 1 | func TestUpsert(t *testing.T) { 2 | {{- range $index, $table := .Tables}} 3 | {{- if or $table.IsJoinTable $table.IsView -}} 4 | {{- else -}} 5 | {{- $alias := $.Aliases.Table $table.Name}} 6 | t.Run("{{$alias.UpPlural}}", test{{$alias.UpPlural}}Upsert) 7 | {{end -}} 8 | {{- end -}} 9 | } 10 | -------------------------------------------------------------------------------- /drivers/sqlboiler-mssql/driver/override/test/upsert.go.tpl: -------------------------------------------------------------------------------- 1 | {{- $alias := .Aliases.Table .Table.Name}} 2 | func test{{$alias.UpPlural}}Upsert(t *testing.T) { 3 | t.Parallel() 4 | 5 | 
if len({{$alias.DownSingular}}AllColumns) == len({{$alias.DownSingular}}PrimaryKeyColumns) { 6 | t.Skip("Skipping table with only primary key columns") 7 | } 8 | 9 | seed := randomize.NewSeed() 10 | var err error 11 | // Attempt the INSERT side of an UPSERT 12 | o := {{$alias.UpSingular}}{} 13 | if err = randomize.Struct(seed, &o, {{$alias.DownSingular}}DBTypes, true); err != nil { 14 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 15 | } 16 | 17 | {{if not .NoContext}}ctx := context.Background(){{end}} 18 | tx := MustTx({{if .NoContext}}{{if .NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}}{{else}}boil.BeginTx(ctx, nil){{end}}) 19 | defer func() { _ = tx.Rollback() }() 20 | if err = o.Upsert({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer(), boil.Infer()); err != nil { 21 | t.Errorf("Unable to upsert {{$alias.UpSingular}}: %s", err) 22 | } 23 | 24 | count, err := {{$alias.UpPlural}}().Count({{if not .NoContext}}ctx, {{end -}} tx) 25 | if err != nil { 26 | t.Error(err) 27 | } 28 | if count != 1 { 29 | t.Error("want one record, got:", count) 30 | } 31 | 32 | // Attempt the UPDATE side of an UPSERT 33 | if err = randomize.Struct(seed, &o, {{$alias.DownSingular}}DBTypes, false, {{$alias.DownSingular}}PrimaryKeyColumns...); err != nil { 34 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 35 | } 36 | 37 | if err = o.Upsert({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer(), boil.Infer()); err != nil { 38 | t.Errorf("Unable to upsert {{$alias.UpSingular}}: %s", err) 39 | } 40 | 41 | count, err = {{$alias.UpPlural}}().Count({{if not .NoContext}}ctx, {{end -}} tx) 42 | if err != nil { 43 | t.Error(err) 44 | } 45 | if count != 1 { 46 | t.Error("want one record, got:", count) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /drivers/sqlboiler-mssql/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | 
import ( 4 | "github.com/volatiletech/sqlboiler/v4/drivers" 5 | "github.com/volatiletech/sqlboiler/v4/drivers/sqlboiler-mssql/driver" 6 | ) 7 | 8 | func main() { 9 | drivers.DriverMain(&driver.MSSQLDriver{}) 10 | } 11 | -------------------------------------------------------------------------------- /drivers/sqlboiler-mysql/driver/mysql_test.go: -------------------------------------------------------------------------------- 1 | // These tests assume there is a user sqlboiler_test_user and a database 2 | // by the name of sqlboiler_test that it has full R/W rights to. 3 | // In order to create this you can use the following steps from a root 4 | // mysql account: 5 | // 6 | // create user sqlboiler_driver_user identified by 'sqlboiler'; 7 | // create database sqlboiler_driver_test; 8 | // grant all privileges on sqlboiler_driver_test.* to sqlboiler_driver_user; 9 | 10 | package driver 11 | 12 | import ( 13 | "bytes" 14 | "encoding/json" 15 | "flag" 16 | "fmt" 17 | "os" 18 | "os/exec" 19 | "testing" 20 | 21 | "github.com/stretchr/testify/require" 22 | "github.com/volatiletech/sqlboiler/v4/drivers" 23 | ) 24 | 25 | var ( 26 | flagOverwriteGolden = flag.Bool("overwrite-golden", false, "Overwrite the golden file with the current execution results") 27 | 28 | envHostname = drivers.DefaultEnv("DRIVER_HOSTNAME", "localhost") 29 | envPort = drivers.DefaultEnv("DRIVER_PORT", "3306") 30 | envUsername = drivers.DefaultEnv("DRIVER_USER", "sqlboiler_driver_user") 31 | envPassword = drivers.DefaultEnv("DRIVER_PASS", "sqlboiler") 32 | envDatabase = drivers.DefaultEnv("DRIVER_DB", "sqlboiler_driver_test") 33 | ) 34 | 35 | func TestDriver(t *testing.T) { 36 | b, err := os.ReadFile("testdatabase.sql") 37 | if err != nil { 38 | t.Fatal(err) 39 | } 40 | 41 | out := &bytes.Buffer{} 42 | createDB := exec.Command("mysql", "-h", envHostname, "-P", envPort, "-u", envUsername, fmt.Sprintf("-p%s", envPassword), envDatabase) 43 | createDB.Stdout = out 44 | createDB.Stderr = out 45 | 
createDB.Stdin = bytes.NewReader(b) 46 | 47 | if err := createDB.Run(); err != nil { 48 | t.Logf("mysql output:\n%s\n", out.Bytes()) 49 | t.Fatal(err) 50 | } 51 | t.Logf("mysql output:\n%s\n", out.Bytes()) 52 | 53 | tests := []struct { 54 | name string 55 | config drivers.Config 56 | goldenJson string 57 | }{ 58 | { 59 | name: "default", 60 | config: drivers.Config{ 61 | "user": envUsername, 62 | "pass": envPassword, 63 | "dbname": envDatabase, 64 | "host": envHostname, 65 | "port": envPort, 66 | "sslmode": "false", 67 | "schema": envDatabase, 68 | }, 69 | goldenJson: "mysql.golden.json", 70 | }, 71 | { 72 | name: "enum_types", 73 | config: drivers.Config{ 74 | "user": envUsername, 75 | "pass": envPassword, 76 | "dbname": envDatabase, 77 | "host": envHostname, 78 | "port": envPort, 79 | "sslmode": "false", 80 | "schema": envDatabase, 81 | "add-enum-types": true, 82 | }, 83 | goldenJson: "mysql.golden.enums.json", 84 | }, 85 | } 86 | 87 | for _, tt := range tests { 88 | t.Run(tt.name, func(t *testing.T) { 89 | p := &MySQLDriver{} 90 | info, err := p.Assemble(tt.config) 91 | if err != nil { 92 | t.Fatal(err) 93 | } 94 | 95 | got, err := json.MarshalIndent(info, "", "\t") 96 | if err != nil { 97 | t.Fatal(err) 98 | } 99 | 100 | if *flagOverwriteGolden { 101 | if err = os.WriteFile(tt.goldenJson, got, 0664); err != nil { 102 | t.Fatal(err) 103 | } 104 | t.Log("wrote:", string(got)) 105 | return 106 | } 107 | 108 | want, err := os.ReadFile(tt.goldenJson) 109 | if err != nil { 110 | t.Fatal(err) 111 | } 112 | 113 | require.JSONEq(t, string(want), string(got)) 114 | }) 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /drivers/sqlboiler-mysql/driver/override/main/singleton/mysql_upsert.go.tpl: -------------------------------------------------------------------------------- 1 | // buildUpsertQueryMySQL builds a SQL statement string using the upsertData provided. 
2 | func buildUpsertQueryMySQL(dia drivers.Dialect, tableName string, update, whitelist []string) string { 3 | whitelist = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, whitelist) 4 | tableName = strmangle.IdentQuote(dia.LQ, dia.RQ, tableName) 5 | 6 | buf := strmangle.GetBuffer() 7 | defer strmangle.PutBuffer(buf) 8 | 9 | var columns string 10 | if len(whitelist) != 0 { 11 | columns = strings.Join(whitelist, ",") 12 | } 13 | 14 | if len(update) == 0 { 15 | fmt.Fprintf( 16 | buf, 17 | "INSERT IGNORE INTO %s (%s) VALUES (%s)", 18 | tableName, 19 | columns, 20 | strmangle.Placeholders(dia.UseIndexPlaceholders, len(whitelist), 1, 1), 21 | ) 22 | return buf.String() 23 | } 24 | 25 | fmt.Fprintf( 26 | buf, 27 | "INSERT INTO %s (%s) VALUES (%s) ON DUPLICATE KEY UPDATE ", 28 | tableName, 29 | columns, 30 | strmangle.Placeholders(dia.UseIndexPlaceholders, len(whitelist), 1, 1), 31 | ) 32 | 33 | for i, v := range update { 34 | if i != 0 { 35 | buf.WriteByte(',') 36 | } 37 | quoted := strmangle.IdentQuote(dia.LQ, dia.RQ, v) 38 | buf.WriteString(quoted) 39 | buf.WriteString(" = VALUES(") 40 | buf.WriteString(quoted) 41 | buf.WriteByte(')') 42 | } 43 | 44 | return buf.String() 45 | } 46 | -------------------------------------------------------------------------------- /drivers/sqlboiler-mysql/driver/override/test/singleton/mysql_suites_test.go.tpl: -------------------------------------------------------------------------------- 1 | func TestUpsert(t *testing.T) { 2 | {{- range $index, $table := .Tables}} 3 | {{- if or $table.IsJoinTable $table.IsView -}} 4 | {{- else -}} 5 | {{- $alias := $.Aliases.Table $table.Name}} 6 | t.Run("{{$alias.UpPlural}}", test{{$alias.UpPlural}}Upsert) 7 | {{end -}} 8 | {{- end -}} 9 | } 10 | -------------------------------------------------------------------------------- /drivers/sqlboiler-mysql/driver/override/test/upsert.go.tpl: -------------------------------------------------------------------------------- 1 | {{- $alias := .Aliases.Table 
.Table.Name}} 2 | func test{{$alias.UpPlural}}Upsert(t *testing.T) { 3 | t.Parallel() 4 | 5 | if len({{$alias.DownSingular}}AllColumns) == len({{$alias.DownSingular}}PrimaryKeyColumns) { 6 | t.Skip("Skipping table with only primary key columns") 7 | } 8 | if len(mySQL{{$alias.UpSingular}}UniqueColumns) == 0 { 9 | t.Skip("Skipping table with no unique columns to conflict on") 10 | } 11 | 12 | seed := randomize.NewSeed() 13 | var err error 14 | // Attempt the INSERT side of an UPSERT 15 | o := {{$alias.UpSingular}}{} 16 | if err = randomize.Struct(seed, &o, {{$alias.DownSingular}}DBTypes, false); err != nil { 17 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 18 | } 19 | 20 | {{if not .NoContext}}ctx := context.Background(){{end}} 21 | tx := MustTx({{if .NoContext}}{{if .NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}}{{else}}boil.BeginTx(ctx, nil){{end}}) 22 | defer func() { _ = tx.Rollback() }() 23 | if err = o.Upsert({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer(), boil.Infer()); err != nil { 24 | t.Errorf("Unable to upsert {{$alias.UpSingular}}: %s", err) 25 | } 26 | 27 | count, err := {{$alias.UpPlural}}().Count({{if not .NoContext}}ctx, {{end -}} tx) 28 | if err != nil { 29 | t.Error(err) 30 | } 31 | if count != 1 { 32 | t.Error("want one record, got:", count) 33 | } 34 | 35 | // Attempt the UPDATE side of an UPSERT 36 | if err = randomize.Struct(seed, &o, {{$alias.DownSingular}}DBTypes, false, {{$alias.DownSingular}}PrimaryKeyColumns...); err != nil { 37 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 38 | } 39 | 40 | if err = o.Upsert({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer(), boil.Infer()); err != nil { 41 | t.Errorf("Unable to upsert {{$alias.UpSingular}}: %s", err) 42 | } 43 | 44 | count, err = {{$alias.UpPlural}}().Count({{if not .NoContext}}ctx, {{end -}} tx) 45 | if err != nil { 46 | t.Error(err) 47 | } 48 | if count != 1 { 49 | t.Error("want one record, got:", count) 50 | } 51 
| } 52 | -------------------------------------------------------------------------------- /drivers/sqlboiler-mysql/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/volatiletech/sqlboiler/v4/drivers" 5 | "github.com/volatiletech/sqlboiler/v4/drivers/sqlboiler-mysql/driver" 6 | ) 7 | 8 | func main() { 9 | drivers.DriverMain(&driver.MySQLDriver{}) 10 | } 11 | -------------------------------------------------------------------------------- /drivers/sqlboiler-psql/driver/override/main/22_ilike.go.tpl: -------------------------------------------------------------------------------- 1 | {{- define "where_ilike_override" -}} 2 | {{$name := printf "whereHelper%s" (goVarname .Type)}} 3 | func (w {{$name}}) ILIKE(x {{.Type}}) qm.QueryMod { return qm.Where(w.field+" ILIKE ?", x) } 4 | func (w {{$name}}) NILIKE(x {{.Type}}) qm.QueryMod { return qm.Where(w.field+" NOT ILIKE ?", x) } 5 | {{- end -}} -------------------------------------------------------------------------------- /drivers/sqlboiler-psql/driver/override/main/23_similarto.go.tpl: -------------------------------------------------------------------------------- 1 | {{- define "where_similarto_override" -}} 2 | {{$name := printf "whereHelper%s" (goVarname .Type)}} 3 | func (w {{$name}}) SIMILAR(x {{.Type}}) qm.QueryMod { return qm.Where(w.field+" SIMILAR TO ?", x) } 4 | func (w {{$name}}) NSIMILAR(x {{.Type}}) qm.QueryMod { return qm.Where(w.field+" NOT SIMILAR TO ?", x) } 5 | {{- end -}} -------------------------------------------------------------------------------- /drivers/sqlboiler-psql/driver/override/main/singleton/psql_upsert.go.tpl: -------------------------------------------------------------------------------- 1 | type UpsertOptions struct { 2 | conflictTarget string 3 | updateSet string 4 | } 5 | 6 | type UpsertOptionFunc func(o *UpsertOptions) 7 | 8 | func UpsertConflictTarget(conflictTarget string) UpsertOptionFunc 
{ 9 | return func(o *UpsertOptions) { 10 | o.conflictTarget = conflictTarget 11 | } 12 | } 13 | 14 | func UpsertUpdateSet(updateSet string) UpsertOptionFunc { 15 | return func(o *UpsertOptions) { 16 | o.updateSet = updateSet 17 | } 18 | } 19 | 20 | // buildUpsertQueryPostgres builds a SQL statement string using the upsertData provided. 21 | func buildUpsertQueryPostgres(dia drivers.Dialect, tableName string, updateOnConflict bool, ret, update, conflict, whitelist []string, opts ...UpsertOptionFunc) string { 22 | conflict = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, conflict) 23 | whitelist = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, whitelist) 24 | ret = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, ret) 25 | 26 | upsertOpts := &UpsertOptions{} 27 | for _, o := range opts { 28 | o(upsertOpts) 29 | } 30 | 31 | buf := strmangle.GetBuffer() 32 | defer strmangle.PutBuffer(buf) 33 | 34 | columns := "DEFAULT VALUES" 35 | if len(whitelist) != 0 { 36 | columns = fmt.Sprintf("(%s) VALUES (%s)", 37 | strings.Join(whitelist, ", "), 38 | strmangle.Placeholders(dia.UseIndexPlaceholders, len(whitelist), 1, 1)) 39 | } 40 | 41 | fmt.Fprintf( 42 | buf, 43 | "INSERT INTO %s %s ON CONFLICT ", 44 | tableName, 45 | columns, 46 | ) 47 | 48 | if upsertOpts.conflictTarget != "" { 49 | buf.WriteString(upsertOpts.conflictTarget) 50 | } else if len(conflict) != 0 { 51 | buf.WriteByte('(') 52 | buf.WriteString(strings.Join(conflict, ", ")) 53 | buf.WriteByte(')') 54 | } 55 | buf.WriteByte(' ') 56 | 57 | if !updateOnConflict || len(update) == 0 { 58 | buf.WriteString("DO NOTHING") 59 | } else { 60 | buf.WriteString("DO UPDATE SET ") 61 | 62 | if upsertOpts.updateSet != "" { 63 | buf.WriteString(upsertOpts.updateSet) 64 | } else { 65 | for i, v := range update { 66 | if len(v) == 0 { 67 | continue 68 | } 69 | if i != 0 { 70 | buf.WriteByte(',') 71 | } 72 | quoted := strmangle.IdentQuote(dia.LQ, dia.RQ, v) 73 | buf.WriteString(quoted) 74 | buf.WriteString(" = EXCLUDED.") 75 | buf.WriteString(quoted) 
76 | } 77 | } 78 | } 79 | 80 | if len(ret) != 0 { 81 | buf.WriteString(" RETURNING ") 82 | buf.WriteString(strings.Join(ret, ", ")) 83 | } 84 | 85 | return buf.String() 86 | } 87 | -------------------------------------------------------------------------------- /drivers/sqlboiler-psql/driver/override/test/singleton/psql_suites_test.go.tpl: -------------------------------------------------------------------------------- 1 | func TestUpsert(t *testing.T) { 2 | {{- range $index, $table := .Tables}} 3 | {{- if or $table.IsJoinTable $table.IsView -}} 4 | {{- else -}} 5 | {{- $alias := $.Aliases.Table $table.Name}} 6 | t.Run("{{$alias.UpPlural}}", test{{$alias.UpPlural}}Upsert) 7 | {{end -}} 8 | {{- end -}} 9 | } 10 | -------------------------------------------------------------------------------- /drivers/sqlboiler-psql/driver/override/test/upsert.go.tpl: -------------------------------------------------------------------------------- 1 | {{- $alias := .Aliases.Table .Table.Name}} 2 | func test{{$alias.UpPlural}}Upsert(t *testing.T) { 3 | t.Parallel() 4 | 5 | if len({{$alias.DownSingular}}AllColumns) == len({{$alias.DownSingular}}PrimaryKeyColumns) { 6 | t.Skip("Skipping table with only primary key columns") 7 | } 8 | 9 | seed := randomize.NewSeed() 10 | var err error 11 | // Attempt the INSERT side of an UPSERT 12 | o := {{$alias.UpSingular}}{} 13 | if err = randomize.Struct(seed, &o, {{$alias.DownSingular}}DBTypes, true); err != nil { 14 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 15 | } 16 | 17 | {{if not .NoContext}}ctx := context.Background(){{end}} 18 | tx := MustTx({{if .NoContext}}{{if .NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}}{{else}}boil.BeginTx(ctx, nil){{end}}) 19 | defer func() { _ = tx.Rollback() }() 20 | if err = o.Upsert({{if not .NoContext}}ctx, {{end -}} tx, false, nil, boil.Infer(), boil.Infer()); err != nil { 21 | t.Errorf("Unable to upsert {{$alias.UpSingular}}: %s", err) 22 | } 23 | 24 | count, err := 
{{$alias.UpPlural}}().Count({{if not .NoContext}}ctx, {{end -}} tx) 25 | if err != nil { 26 | t.Error(err) 27 | } 28 | if count != 1 { 29 | t.Error("want one record, got:", count) 30 | } 31 | 32 | // Attempt the UPDATE side of an UPSERT 33 | if err = randomize.Struct(seed, &o, {{$alias.DownSingular}}DBTypes, false, {{$alias.DownSingular}}PrimaryKeyColumns...); err != nil { 34 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 35 | } 36 | 37 | if err = o.Upsert({{if not .NoContext}}ctx, {{end -}} tx, true, nil, boil.Infer(), boil.Infer()); err != nil { 38 | t.Errorf("Unable to upsert {{$alias.UpSingular}}: %s", err) 39 | } 40 | 41 | count, err = {{$alias.UpPlural}}().Count({{if not .NoContext}}ctx, {{end -}} tx) 42 | if err != nil { 43 | t.Error(err) 44 | } 45 | if count != 1 { 46 | t.Error("want one record, got:", count) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /drivers/sqlboiler-psql/driver/psql_test.go: -------------------------------------------------------------------------------- 1 | // These tests assume there is a user sqlboiler_driver_user and a database 2 | // by the name of sqlboiler_driver_test that it has full R/W rights to. 
3 | // In order to create this you can use the following steps from a root 4 | // psql account: 5 | // 6 | // create role sqlboiler_driver_user login nocreatedb nocreaterole nocreateuser password 'sqlboiler'; 7 | // create database sqlboiler_driver_test owner = sqlboiler_driver_user; 8 | 9 | package driver 10 | 11 | import ( 12 | "bytes" 13 | "encoding/json" 14 | "flag" 15 | "fmt" 16 | "os" 17 | "os/exec" 18 | "testing" 19 | 20 | "github.com/stretchr/testify/require" 21 | 22 | "github.com/volatiletech/sqlboiler/v4/drivers" 23 | ) 24 | 25 | var ( 26 | flagOverwriteGolden = flag.Bool("overwrite-golden", false, "Overwrite the golden file with the current execution results") 27 | 28 | envHostname = drivers.DefaultEnv("DRIVER_HOSTNAME", "localhost") 29 | envPort = drivers.DefaultEnv("DRIVER_PORT", "5432") 30 | envUsername = drivers.DefaultEnv("DRIVER_USER", "sqlboiler_driver_user") 31 | envPassword = drivers.DefaultEnv("DRIVER_PASS", "sqlboiler") 32 | envDatabase = drivers.DefaultEnv("DRIVER_DB", "sqlboiler_driver_test") 33 | ) 34 | 35 | func TestAssemble(t *testing.T) { 36 | b, err := os.ReadFile("testdatabase.sql") 37 | if err != nil { 38 | t.Fatal(err) 39 | } 40 | 41 | out := &bytes.Buffer{} 42 | createDB := exec.Command("psql", "-h", envHostname, "-U", envUsername, envDatabase) 43 | createDB.Env = append([]string{fmt.Sprintf("PGPASSWORD=%s", envPassword)}, os.Environ()...) 
44 | createDB.Stdout = out 45 | createDB.Stderr = out 46 | createDB.Stdin = bytes.NewReader(b) 47 | 48 | if err := createDB.Run(); err != nil { 49 | t.Logf("psql output:\n%s\n", out.Bytes()) 50 | t.Fatal(err) 51 | } 52 | t.Logf("psql output:\n%s\n", out.Bytes()) 53 | 54 | tests := []struct { 55 | name string 56 | config drivers.Config 57 | goldenJson string 58 | }{ 59 | { 60 | name: "default", 61 | config: drivers.Config{ 62 | "user": envUsername, 63 | "pass": envPassword, 64 | "dbname": envDatabase, 65 | "host": envHostname, 66 | "port": envPort, 67 | "sslmode": "disable", 68 | "schema": "public", 69 | }, 70 | goldenJson: "psql.golden.json", 71 | }, 72 | { 73 | name: "enum_types", 74 | config: drivers.Config{ 75 | "user": envUsername, 76 | "pass": envPassword, 77 | "dbname": envDatabase, 78 | "host": envHostname, 79 | "port": envPort, 80 | "sslmode": "disable", 81 | "schema": "public", 82 | "add-enum-types": true, 83 | }, 84 | goldenJson: "psql.golden.enums.json", 85 | }, 86 | } 87 | 88 | for _, tt := range tests { 89 | t.Run(tt.name, func(t *testing.T) { 90 | p := PostgresDriver{} 91 | info, err := p.Assemble(tt.config) 92 | if err != nil { 93 | t.Fatal(err) 94 | } 95 | 96 | got, err := json.MarshalIndent(info, "", "\t") 97 | if err != nil { 98 | t.Fatal(err) 99 | } 100 | 101 | if *flagOverwriteGolden { 102 | if err = os.WriteFile(tt.goldenJson, got, 0664); err != nil { 103 | t.Fatal(err) 104 | } 105 | t.Log("wrote:", string(got)) 106 | return 107 | } 108 | 109 | want, err := os.ReadFile(tt.goldenJson) 110 | if err != nil { 111 | t.Fatal(err) 112 | } 113 | 114 | require.JSONEq(t, string(want), string(got)) 115 | }) 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /drivers/sqlboiler-psql/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/volatiletech/sqlboiler/v4/drivers" 5 | 
"github.com/volatiletech/sqlboiler/v4/drivers/sqlboiler-psql/driver" 6 | ) 7 | 8 | func main() { 9 | drivers.DriverMain(&driver.PostgresDriver{}) 10 | } 11 | -------------------------------------------------------------------------------- /drivers/sqlboiler-sqlite3/README.md: -------------------------------------------------------------------------------- 1 | # sqlboiler-sqlite3 2 | 3 | ## Configuration 4 | 5 | ```toml 6 | # Absolute path is recommended since the location 7 | # sqlite3 is being run can change. 8 | # For example generation time and model test time. 9 | [sqlite3] 10 | dbname = "/path/to/file" 11 | ``` 12 | -------------------------------------------------------------------------------- /drivers/sqlboiler-sqlite3/driver/override/main/singleton/sqlite_upsert.go.tpl: -------------------------------------------------------------------------------- 1 | // buildUpsertQuerySQLite builds a SQL statement string using the upsertData provided. 2 | func buildUpsertQuerySQLite(dia drivers.Dialect, tableName string, updateOnConflict bool, ret, update, conflict, whitelist []string) string { 3 | conflict = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, conflict) 4 | whitelist = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, whitelist) 5 | ret = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, ret) 6 | 7 | buf := strmangle.GetBuffer() 8 | defer strmangle.PutBuffer(buf) 9 | 10 | columns := "DEFAULT VALUES" 11 | if len(whitelist) != 0 { 12 | columns = fmt.Sprintf("(%s) VALUES (%s)", 13 | strings.Join(whitelist, ", "), 14 | strmangle.Placeholders(dia.UseIndexPlaceholders, len(whitelist), 1, 1)) 15 | } 16 | 17 | fmt.Fprintf( 18 | buf, 19 | "INSERT INTO %s %s ON CONFLICT ", 20 | tableName, 21 | columns, 22 | ) 23 | 24 | if !updateOnConflict || len(update) == 0 { 25 | buf.WriteString("DO NOTHING") 26 | } else { 27 | buf.WriteByte('(') 28 | buf.WriteString(strings.Join(conflict, ", ")) 29 | buf.WriteString(") DO UPDATE SET ") 30 | 31 | for i, v := range update { 32 | if i != 0 { 33 | 
buf.WriteByte(',') 34 | } 35 | quoted := strmangle.IdentQuote(dia.LQ, dia.RQ, v) 36 | buf.WriteString(quoted) 37 | buf.WriteString(" = EXCLUDED.") 38 | buf.WriteString(quoted) 39 | } 40 | } 41 | 42 | if len(ret) != 0 { 43 | buf.WriteString(" RETURNING ") 44 | buf.WriteString(strings.Join(ret, ", ")) 45 | } 46 | 47 | return buf.String() 48 | } -------------------------------------------------------------------------------- /drivers/sqlboiler-sqlite3/driver/override/test/singleton/sqlite3_main_test.go.tpl: -------------------------------------------------------------------------------- 1 | var rgxSQLitekey = regexp.MustCompile(`(?mi)((,\n)?\s+foreign key.*?\n)+`) 2 | 3 | type sqliteTester struct { 4 | dbConn *sql.DB 5 | 6 | dbName string 7 | testDBName string 8 | } 9 | 10 | func init() { 11 | dbMain = &sqliteTester{} 12 | } 13 | 14 | func (s *sqliteTester) setup() error { 15 | var err error 16 | 17 | s.dbName = viper.GetString("sqlite3.dbname") 18 | if len(s.dbName) == 0 { 19 | return errors.New("no dbname specified") 20 | } 21 | 22 | s.testDBName = filepath.Join(os.TempDir(), fmt.Sprintf("boil-sqlite3-%d.sql", rand.Int())) 23 | 24 | dumpCmd := exec.Command("sqlite3", "-cmd", ".dump", s.dbName) 25 | createCmd := exec.Command("sqlite3", s.testDBName) 26 | 27 | r, w := io.Pipe() 28 | dumpCmd.Stdout = w 29 | createCmd.Stdin = newFKeyDestroyer(rgxSQLitekey, r) 30 | 31 | if err = dumpCmd.Start(); err != nil { 32 | return errors.Wrap(err, "failed to start sqlite3 dump command") 33 | } 34 | if err = createCmd.Start(); err != nil { 35 | _ = r.Close() // nothing will ever read the pipe now; unblock the dump so it can exit 36 | _ = dumpCmd.Wait() 37 | return errors.Wrap(err, "failed to start sqlite3 create command") 38 | } 39 | 40 | dumpErr := dumpCmd.Wait() 41 | w.Close() // Close even when the dump failed, otherwise createCmd never sees EOF on stdin and its Wait blocks forever 42 | createErr := createCmd.Wait() 43 | 44 | if dumpErr != nil { 45 | return errors.Wrap(dumpErr, "failed to wait for sqlite3 dump command") 46 | } 47 | if createErr != nil { 48 | return errors.Wrap(createErr, "failed to wait for sqlite3 create command")
49 | } 50 | return nil 51 | } 52 | 53 | func (s *sqliteTester) teardown() error { 54 | if s.dbConn != nil { 55 | s.dbConn.Close() 56 | } 57 | 58 | return os.Remove(s.testDBName) 59 | } 60 | 61 | func (s *sqliteTester) conn() (*sql.DB, error) { 62 | if s.dbConn != nil { 63 | return s.dbConn, nil 64 | } 65 | 66 | var err error 67 | s.dbConn, err = sql.Open("sqlite", fmt.Sprintf("file:%s?cache=shared&_loc=UTC", s.testDBName)) 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | return s.dbConn, nil 73 | } 74 | -------------------------------------------------------------------------------- /drivers/sqlboiler-sqlite3/driver/override/test/singleton/sqlite3_suites_test.go.tpl: -------------------------------------------------------------------------------- 1 | func TestUpsert(t *testing.T) { 2 | {{- range $index, $table := .Tables}} 3 | {{- if or $table.IsJoinTable $table.IsView -}} 4 | {{- else -}} 5 | {{- $alias := $.Aliases.Table $table.Name}} 6 | t.Run("{{$alias.UpPlural}}", test{{$alias.UpPlural}}Upsert) 7 | {{end -}} 8 | {{- end -}} 9 | } 10 | -------------------------------------------------------------------------------- /drivers/sqlboiler-sqlite3/driver/override/test/upsert.go.tpl: -------------------------------------------------------------------------------- 1 | {{- $alias := .Aliases.Table .Table.Name}} 2 | func test{{$alias.UpPlural}}Upsert(t *testing.T) { 3 | t.Parallel() 4 | if len({{$alias.DownSingular}}AllColumns) == len({{$alias.DownSingular}}PrimaryKeyColumns) { 5 | t.Skip("Skipping table with only primary key columns") 6 | } 7 | 8 | seed := randomize.NewSeed() 9 | var err error 10 | // Attempt the INSERT side of an UPSERT 11 | o := {{$alias.UpSingular}}{} 12 | if err = randomize.Struct(seed, &o, {{$alias.DownSingular}}DBTypes, true); err != nil { 13 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 14 | } 15 | 16 | {{if not .NoContext}}ctx := context.Background(){{end}} 17 | tx := MustTx({{if .NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}})
18 | defer func() { _ = tx.Rollback() }() 19 | if err = o.Upsert({{if not .NoContext}}ctx, {{end -}} tx, false, nil, boil.Infer(), boil.Infer()); err != nil { 20 | t.Errorf("Unable to upsert {{$alias.UpSingular}}: %s", err) 21 | } 22 | 23 | count, err := {{$alias.UpPlural}}().Count({{if not .NoContext}}ctx, {{end -}} tx) 24 | if err != nil { 25 | t.Error(err) 26 | } 27 | if count != 1 { 28 | t.Error("want one record, got:", count) 29 | } 30 | 31 | // Attempt the UPDATE side of an UPSERT 32 | if err = randomize.Struct(seed, &o, {{$alias.DownSingular}}DBTypes, false, {{$alias.DownSingular}}PrimaryKeyColumns...); err != nil { 33 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 34 | } 35 | 36 | if err = o.Upsert({{if not .NoContext}}ctx, {{end -}} tx, true, nil, boil.Infer(), boil.Infer()); err != nil { 37 | t.Errorf("Unable to upsert {{$alias.UpSingular}}: %s", err) 38 | } 39 | 40 | count, err = {{$alias.UpPlural}}().Count({{if not .NoContext}}ctx, {{end -}} tx) 41 | if err != nil { 42 | t.Error(err) 43 | } 44 | if count != 1 { 45 | t.Error("want one record, got:", count) 46 | } 47 | } 48 | 49 | -------------------------------------------------------------------------------- /drivers/sqlboiler-sqlite3/driver/sqlite3_test.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "flag" 7 | "fmt" 8 | "math/rand" 9 | "os" 10 | "os/exec" 11 | "path/filepath" 12 | "testing" 13 | "time" 14 | 15 | "github.com/stretchr/testify/require" 16 | "github.com/volatiletech/sqlboiler/v4/drivers" 17 | _ "modernc.org/sqlite" 18 | ) 19 | 20 | var ( 21 | flagOverwriteGolden = flag.Bool("overwrite-golden", false, "Overwrite the golden file with the current execution results") 22 | ) 23 | 24 | func TestDriver(t *testing.T) { 25 | rng := rand.New(rand.NewSource(time.Now().UnixNano())) // explicitly seeded generator used for the temp file name below
26 | b, err := os.ReadFile("testdatabase.sql") 27 | if err != nil { 28 | t.Fatal(err) 29 | } 30 | 31 | tmpName := filepath.Join(os.TempDir(), fmt.Sprintf("sqlboiler-sqlite3-%d.sql", rng.Int())) 32 | t.Cleanup(func() { if !t.Failed() { _ = os.Remove(tmpName) } }) // keep the database file for inspection only when the test fails 33 | out := &bytes.Buffer{} 34 | createDB := exec.Command("sqlite3", tmpName) 35 | createDB.Stdout = out 36 | createDB.Stderr = out 37 | createDB.Stdin = bytes.NewReader(b) 38 | 39 | t.Log("sqlite file:", tmpName) 40 | if err := createDB.Run(); err != nil { 41 | t.Logf("sqlite output:\n%s\n", out.Bytes()) 42 | t.Fatal(err) 43 | } 44 | t.Logf("sqlite output:\n%s\n", out.Bytes()) 45 | 46 | tests := []struct { 47 | name string 48 | config drivers.Config 49 | goldenJson string 50 | }{ 51 | { 52 | name: "default", 53 | config: drivers.Config{ 54 | "dbname": tmpName, 55 | }, 56 | goldenJson: "sqlite3.golden.json", 57 | }, 58 | } 59 | 60 | for _, tt := range tests { 61 | t.Run(tt.name, func(t *testing.T) { 62 | s := &SQLiteDriver{} 63 | info, err := s.Assemble(tt.config) 64 | if err != nil { 65 | t.Fatal(err) 66 | } 67 | 68 | got, err := json.MarshalIndent(info, "", "\t") 69 | if err != nil { 70 | t.Fatal(err) 71 | } 72 | 73 | if *flagOverwriteGolden { 74 | if err = os.WriteFile(tt.goldenJson, got, 0664); err != nil { 75 | t.Fatal(err) 76 | } 77 | t.Log("wrote:", string(got)) 78 | return 79 | } 80 | 81 | want, err := os.ReadFile(tt.goldenJson) 82 | if err != nil { 83 | t.Fatal(err) 84 | } 85 | 86 | require.JSONEq(t, string(want), string(got)) 87 | }) 88 | } 89 | 90 | } 91 | -------------------------------------------------------------------------------- /drivers/sqlboiler-sqlite3/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/volatiletech/sqlboiler/v4/drivers" 5 | "github.com/volatiletech/sqlboiler/v4/drivers/sqlboiler-sqlite3/driver" 6 | ) 7 | 8 | func main() { 9 | drivers.DriverMain(&driver.SQLiteDriver{}) 10 | } 11 |
-------------------------------------------------------------------------------- /drivers/table.go: -------------------------------------------------------------------------------- 1 | package drivers 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // Table metadata from the database schema. 8 | type Table struct { 9 | Name string `json:"name"` 10 | // For dbs with real schemas, like Postgres. 11 | // Example value: "schema_name"."table_name" 12 | SchemaName string `json:"schema_name"` 13 | Columns []Column `json:"columns"` 14 | 15 | PKey *PrimaryKey `json:"p_key"` 16 | FKeys []ForeignKey `json:"f_keys"` 17 | 18 | IsJoinTable bool `json:"is_join_table"` 19 | 20 | ToOneRelationships []ToOneRelationship `json:"to_one_relationships"` 21 | ToManyRelationships []ToManyRelationship `json:"to_many_relationships"` 22 | 23 | // For views 24 | IsView bool `json:"is_view"` 25 | ViewCapabilities ViewCapabilities `json:"view_capabilities"` 26 | } 27 | 28 | type ViewCapabilities struct { 29 | CanInsert bool `json:"can_insert"` 30 | CanUpsert bool `json:"can_upsert"` 31 | } 32 | 33 | // GetTable by name. Panics if not found (for use in templates mostly). 34 | func GetTable(tables []Table, name string) (tbl Table) { 35 | for _, t := range tables { 36 | if t.Name == name { 37 | return t 38 | } 39 | } 40 | 41 | panic(fmt.Sprintf("could not find table name: %s", name)) 42 | } 43 | 44 | // GetColumn by name. Panics if not found (for use in templates mostly). 45 | func (t Table) GetColumn(name string) (col Column) { 46 | for _, c := range t.Columns { 47 | if c.Name == name { 48 | return c 49 | } 50 | } 51 | 52 | panic(fmt.Sprintf("could not find column name: %s", name)) 53 | } 54 | 55 | // CanLastInsertID checks the following: 56 | // 1. Is there only one primary key? 57 | // 2. Does the primary key column have a default value? 58 | // 3. Is the primary key column type one of uintX/intX? 
59 | // If the above is all true, this table can use LastInsertId 60 | func (t Table) CanLastInsertID() bool { 61 | if t.PKey == nil || len(t.PKey.Columns) != 1 { 62 | return false 63 | } 64 | 65 | col := t.GetColumn(t.PKey.Columns[0]) 66 | if len(col.Default) == 0 { 67 | return false 68 | } 69 | 70 | switch col.Type { 71 | case "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64": 72 | default: 73 | return false 74 | } 75 | 76 | return true 77 | } 78 | 79 | // CanSoftDelete reports whether the table has a column named deleteColumn 80 | // (or "deleted_at" when deleteColumn is empty) whose Go type is null.Time. 81 | func (t Table) CanSoftDelete(deleteColumn string) bool { 82 | if deleteColumn == "" { 83 | deleteColumn = "deleted_at" 84 | } 85 | 86 | for _, column := range t.Columns { 87 | if column.Name == deleteColumn && column.Type == "null.Time" { 88 | return true 89 | } 90 | } 91 | return false 92 | } 93 | 94 | // TablesHaveNullableEnums reports whether any column of any table in tables 95 | // is nullable and has a DB type that IsEnumDBType recognizes as an enum. 96 | func TablesHaveNullableEnums(tables []Table) bool { 97 | for _, table := range tables { 98 | for _, col := range table.Columns { 99 | if col.Nullable && IsEnumDBType(col.DBType) { 100 | return true 101 | } 102 | } 103 | } 104 | return false 105 | } 106 | -------------------------------------------------------------------------------- /drivers/table_test.go: -------------------------------------------------------------------------------- 1 | package drivers 2 | 3 | import "testing" 4 | 5 | func TestGetTable(t *testing.T) { 6 | t.Parallel() 7 | 8 | tables := []Table{ 9 | {Name: "one"}, 10 | } 11 | 12 | tbl := GetTable(tables, "one") 13 | 14 | if tbl.Name != "one" { 15 | t.Error("didn't get table") 16 | } 17 | } 18 | 19 | func TestGetTableMissing(t *testing.T) { 20 | t.Parallel() 21 | 22 | tables := []Table{ 23 | {Name: "one"}, 24 | } 25 | 26 | defer func() { 27 | if r := recover(); r == nil { 28 | t.Error("expected a panic failure") 29 | } 30 | }() 31 | 32 | GetTable(tables, "missing") 33 | } 34 | 35 | func TestGetColumn(t *testing.T) { 36 | t.Parallel() 37 | 38 | table := Table{ 39 | Columns: []Column{ 40 | {Name: "one"}, 41 | }, 42 | } 43 | 44 | c := table.GetColumn("one") 45 |
46 | if c.Name != "one" { 47 | t.Error("didn't get column") 48 | } 49 | } 50 | 51 | func TestGetColumnMissing(t *testing.T) { 52 | t.Parallel() 53 | 54 | table := Table{ 55 | Columns: []Column{ 56 | {Name: "one"}, 57 | }, 58 | } 59 | 60 | defer func() { 61 | if r := recover(); r == nil { 62 | t.Error("expected a panic failure") 63 | } 64 | }() 65 | 66 | table.GetColumn("missing") 67 | } 68 | 69 | func TestCanLastInsertID(t *testing.T) { 70 | t.Parallel() 71 | 72 | tests := []struct { 73 | Can bool 74 | PKeys []Column 75 | }{ 76 | {true, []Column{ 77 | {Name: "id", Type: "int64", Default: "a"}, 78 | }}, 79 | {true, []Column{ 80 | {Name: "id", Type: "uint64", Default: "a"}, 81 | }}, 82 | {true, []Column{ 83 | {Name: "id", Type: "int", Default: "a"}, 84 | }}, 85 | {true, []Column{ 86 | {Name: "id", Type: "uint", Default: "a"}, 87 | }}, 88 | {true, []Column{ 89 | {Name: "id", Type: "int32", Default: "a"}, 90 | }}, 91 | {false, []Column{ 92 | {Name: "id", Type: "uint", Default: "a"}, 93 | {Name: "id2", Type: "uint", Default: "a"}, 94 | }}, 95 | {false, []Column{ 96 | {Name: "id", Type: "string", Default: "a"}, 97 | }}, 98 | {false, []Column{ 99 | {Name: "id", Type: "int", Default: ""}, 100 | }}, 101 | {false, nil}, 102 | } 103 | 104 | for i, test := range tests { 105 | table := Table{ 106 | Columns: test.PKeys, 107 | PKey: &PrimaryKey{}, 108 | } 109 | 110 | var pkeyNames []string 111 | for _, pk := range test.PKeys { 112 | pkeyNames = append(pkeyNames, pk.Name) 113 | } 114 | table.PKey.Columns = pkeyNames 115 | 116 | if got := table.CanLastInsertID(); got != test.Can { 117 | t.Errorf("%d) wrong: %t", i, got) 118 | } 119 | } 120 | } 121 | 122 | func TestCanSoftDelete(t *testing.T) { 123 | t.Parallel() 124 | 125 | tests := []struct { 126 | Can bool 127 | Columns []Column 128 | }{ 129 | {true, []Column{ 130 | {Name: "deleted_at", Type: "null.Time"}, 131 | }}, 132 | {false, []Column{ 133 | {Name: "deleted_at", Type: "time.Time"}, 134 | }}, 135 | {false, []Column{ 136 | {Name:
"deleted_at", Type: "int"}, 137 | }}, 138 | {false, nil}, 139 | } 140 | 141 | for i, test := range tests { 142 | table := Table{ 143 | Columns: test.Columns, 144 | } 145 | 146 | if got := table.CanSoftDelete("deleted_at"); got != test.Can { 147 | t.Errorf("%d) wrong: %t", i, got) 148 | } 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/volatiletech/sqlboiler/v4 2 | 3 | go 1.23.0 4 | 5 | require ( 6 | github.com/DATA-DOG/go-sqlmock v1.4.1 7 | github.com/Masterminds/sprig/v3 v3.2.2 8 | github.com/davecgh/go-spew v1.1.1 9 | github.com/ericlagergren/decimal v0.0.0-20190420051523-6335edbaa640 10 | github.com/friendsofgo/errors v0.9.2 11 | github.com/go-sql-driver/mysql v1.6.0 12 | github.com/google/go-cmp v0.6.0 13 | github.com/lib/pq v1.10.6 14 | github.com/microsoft/go-mssqldb v0.17.0 15 | github.com/spf13/cast v1.5.0 16 | github.com/spf13/cobra v1.5.0 17 | github.com/spf13/viper v1.12.0 18 | github.com/stretchr/testify v1.8.0 19 | github.com/volatiletech/null/v8 v8.1.2 20 | github.com/volatiletech/randomize v0.0.1 21 | github.com/volatiletech/strmangle v0.0.6 22 | modernc.org/sqlite v1.18.1 23 | ) 24 | 25 | require ( 26 | github.com/Masterminds/goutils v1.1.1 // indirect 27 | github.com/Masterminds/semver/v3 v3.1.1 // indirect 28 | github.com/fsnotify/fsnotify v1.5.4 // indirect 29 | github.com/gofrs/uuid v4.2.0+incompatible // indirect 30 | github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect 31 | github.com/golang-sql/sqlexp v0.1.0 // indirect 32 | github.com/google/uuid v1.3.0 // indirect 33 | github.com/hashicorp/hcl v1.0.0 // indirect 34 | github.com/huandu/xstrings v1.3.2 // indirect 35 | github.com/imdario/mergo v0.3.13 // indirect 36 | github.com/inconshreveable/mousetrap v1.0.1 // indirect 37 | github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect 
38 | github.com/magiconair/properties v1.8.6 // indirect 39 | github.com/mattn/go-isatty v0.0.16 // indirect 40 | github.com/mitchellh/copystructure v1.2.0 // indirect 41 | github.com/mitchellh/mapstructure v1.5.0 // indirect 42 | github.com/mitchellh/reflectwalk v1.0.2 // indirect 43 | github.com/pelletier/go-toml v1.9.5 // indirect 44 | github.com/pelletier/go-toml/v2 v2.0.5 // indirect 45 | github.com/pmezard/go-difflib v1.0.0 // indirect 46 | github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect 47 | github.com/shopspring/decimal v1.3.1 // indirect 48 | github.com/spf13/afero v1.9.2 // indirect 49 | github.com/spf13/jwalterweatherman v1.1.0 // indirect 50 | github.com/spf13/pflag v1.0.5 // indirect 51 | github.com/subosito/gotenv v1.4.1 // indirect 52 | github.com/volatiletech/inflect v0.0.1 // indirect 53 | golang.org/x/crypto v0.38.0 // indirect 54 | golang.org/x/mod v0.17.0 // indirect 55 | golang.org/x/sync v0.14.0 // indirect 56 | golang.org/x/sys v0.33.0 // indirect 57 | golang.org/x/text v0.25.0 // indirect 58 | golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect 59 | golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f // indirect 60 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect 61 | gopkg.in/ini.v1 v1.67.0 // indirect 62 | gopkg.in/yaml.v2 v2.4.0 // indirect 63 | gopkg.in/yaml.v3 v3.0.1 // indirect 64 | lukechampine.com/uint128 v1.2.0 // indirect 65 | modernc.org/cc/v3 v3.36.3 // indirect 66 | modernc.org/ccgo/v3 v3.16.9 // indirect 67 | modernc.org/libc v1.17.1 // indirect 68 | modernc.org/mathutil v1.5.0 // indirect 69 | modernc.org/memory v1.2.1 // indirect 70 | modernc.org/opt v0.1.3 // indirect 71 | modernc.org/strutil v1.1.3 // indirect 72 | modernc.org/token v1.0.1 // indirect 73 | ) 74 | 75 | retract ( 76 | v4.19.0 // Performance issue due to cleaning up unused imports in generated code 77 | v4.17.1 // Generates faulty code for DeleteAll if the table has multiple foreign keys 78 
| v4.17.0 // Generates faulty code for DeleteAll if the table has multiple foreign keys 79 | v4.10.0 // Generated models are invalid due to a wrong assignment 80 | v4.9.0 // Generated code shows v4.8.6, messed up commit tagging and untidy go.mod 81 | ) 82 | -------------------------------------------------------------------------------- /queries/_fixtures/00.sql: -------------------------------------------------------------------------------- 1 | SELECT * FROM "t"; -------------------------------------------------------------------------------- /queries/_fixtures/01.sql: -------------------------------------------------------------------------------- 1 | SELECT * FROM "q" LIMIT 5 OFFSET 6; -------------------------------------------------------------------------------- /queries/_fixtures/02.sql: -------------------------------------------------------------------------------- 1 | SELECT * FROM "q" ORDER BY a ASC, b like $1 DESC; -------------------------------------------------------------------------------- /queries/_fixtures/03.sql: -------------------------------------------------------------------------------- 1 | SELECT count(*) as ab, thing as bd, "stuff" FROM "t"; -------------------------------------------------------------------------------- /queries/_fixtures/04.sql: -------------------------------------------------------------------------------- 1 | SELECT count(*) as ab, thing as bd, "stuff" FROM "a", "b"; -------------------------------------------------------------------------------- /queries/_fixtures/05.sql: -------------------------------------------------------------------------------- 1 | SELECT "a"."happy" as "a.happy", "r"."fun" as "r.fun", "q" FROM happiness as a INNER JOIN rainbows r on a.id = r.happy_id; -------------------------------------------------------------------------------- /queries/_fixtures/06.sql: -------------------------------------------------------------------------------- 1 | SELECT "a".* FROM happiness as a INNER JOIN 
rainbows r on a.id = r.happy_id; -------------------------------------------------------------------------------- /queries/_fixtures/07.sql: -------------------------------------------------------------------------------- 1 | SELECT "videos".* FROM "videos" INNER JOIN (select id from users where deleted = $1) u on u.id = videos.user_id WHERE (videos.deleted = $2); -------------------------------------------------------------------------------- /queries/_fixtures/08.sql: -------------------------------------------------------------------------------- 1 | SELECT * FROM "a" WHERE (a=$1 or b=$2) AND (c=$3) GROUP BY id, name HAVING id <> $4 AND length(name, $5) > $6; -------------------------------------------------------------------------------- /queries/_fixtures/09.sql: -------------------------------------------------------------------------------- 1 | DELETE FROM thing happy, upset as "sad", "fun", thing as stuff, "angry" as mad WHERE (a=$1) AND (b=$2) AND (c=$3); -------------------------------------------------------------------------------- /queries/_fixtures/10.sql: -------------------------------------------------------------------------------- 1 | DELETE FROM thing happy, upset as "sad", "fun", thing as stuff, "angry" as mad WHERE ((id=$1 and thing=$2) or stuff=$3) LIMIT 5; -------------------------------------------------------------------------------- /queries/_fixtures/11.sql: -------------------------------------------------------------------------------- 1 | UPDATE thing happy, "fun", "stuff" SET "col2" = $1, "fun"."col3" = $2, "col1" = $3 WHERE (aa=$4 or bb=$5 or cc=$6) AND (dd=$7 or ee=$8 or ff=$9 and gg=$10) LIMIT 5; -------------------------------------------------------------------------------- /queries/_fixtures/12.sql: -------------------------------------------------------------------------------- 1 | SELECT "cats".* FROM "cats" INNER JOIN dogs d on d.cat_id = cats.id; 
-------------------------------------------------------------------------------- /queries/_fixtures/13.sql: -------------------------------------------------------------------------------- 1 | SELECT "c".* FROM cats c INNER JOIN dogs d on d.cat_id = cats.id; -------------------------------------------------------------------------------- /queries/_fixtures/14.sql: -------------------------------------------------------------------------------- 1 | SELECT "c".* FROM cats as c INNER JOIN dogs d on d.cat_id = cats.id; -------------------------------------------------------------------------------- /queries/_fixtures/15.sql: -------------------------------------------------------------------------------- 1 | SELECT "c".*, "d".* FROM cats as c, dogs as d INNER JOIN dogs d on d.cat_id = cats.id; -------------------------------------------------------------------------------- /queries/_fixtures/16.sql: -------------------------------------------------------------------------------- 1 | SELECT "cats".* FROM "cats" LEFT JOIN dogs d on d.cat_id = cats.id; -------------------------------------------------------------------------------- /queries/_fixtures/17.sql: -------------------------------------------------------------------------------- 1 | SELECT "c".* FROM cats c LEFT JOIN dogs d on d.cat_id = cats.id; -------------------------------------------------------------------------------- /queries/_fixtures/18.sql: -------------------------------------------------------------------------------- 1 | SELECT "c".* FROM cats as c LEFT JOIN dogs d on d.cat_id = cats.id; -------------------------------------------------------------------------------- /queries/_fixtures/19.sql: -------------------------------------------------------------------------------- 1 | SELECT "c".*, "d".* FROM cats as c, dogs as d LEFT JOIN dogs d on d.cat_id = cats.id; -------------------------------------------------------------------------------- /queries/_fixtures/20.sql: 
-------------------------------------------------------------------------------- 1 | SELECT "cats".* FROM "cats" RIGHT JOIN dogs d on d.cat_id = cats.id; -------------------------------------------------------------------------------- /queries/_fixtures/21.sql: -------------------------------------------------------------------------------- 1 | SELECT "c".* FROM cats c RIGHT JOIN dogs d on d.cat_id = cats.id; -------------------------------------------------------------------------------- /queries/_fixtures/22.sql: -------------------------------------------------------------------------------- 1 | SELECT "c".* FROM cats as c RIGHT JOIN dogs d on d.cat_id = cats.id; -------------------------------------------------------------------------------- /queries/_fixtures/23.sql: -------------------------------------------------------------------------------- 1 | SELECT "c".*, "d".* FROM cats as c, dogs as d RIGHT JOIN dogs d on d.cat_id = cats.id; -------------------------------------------------------------------------------- /queries/_fixtures/24.sql: -------------------------------------------------------------------------------- 1 | SELECT "cats".* FROM "cats" FULL JOIN dogs d on d.cat_id = cats.id; -------------------------------------------------------------------------------- /queries/_fixtures/25.sql: -------------------------------------------------------------------------------- 1 | SELECT "c".* FROM cats c FULL JOIN dogs d on d.cat_id = cats.id; -------------------------------------------------------------------------------- /queries/_fixtures/26.sql: -------------------------------------------------------------------------------- 1 | SELECT "c".* FROM cats as c FULL JOIN dogs d on d.cat_id = cats.id; -------------------------------------------------------------------------------- /queries/_fixtures/27.sql: -------------------------------------------------------------------------------- 1 | SELECT "c".*, "d".* FROM cats as c, dogs as d FULL JOIN dogs d on 
d.cat_id = cats.id; -------------------------------------------------------------------------------- /queries/_fixtures/28.sql: -------------------------------------------------------------------------------- 1 | WITH cte_0 AS (SELECT * FROM other_t0), cte_1 AS (SELECT * FROM other_t1 WHERE thing=$1 AND stuff=$2) SELECT * FROM "t"; -------------------------------------------------------------------------------- /queries/_fixtures/29.sql: -------------------------------------------------------------------------------- 1 | SELECT DISTINCT id FROM "t"; -------------------------------------------------------------------------------- /queries/_fixtures/30.sql: -------------------------------------------------------------------------------- 1 | SELECT COUNT(DISTINCT (id)) FROM "t"; -------------------------------------------------------------------------------- /queries/_fixtures/31.sql: -------------------------------------------------------------------------------- 1 | SELECT DISTINCT id, t.* FROM "t" INNER JOIN dogs d on d.cat_id = t.id; -------------------------------------------------------------------------------- /queries/_fixtures/32.sql: -------------------------------------------------------------------------------- 1 | SELECT COUNT(DISTINCT (id, t.*)) FROM "t" INNER JOIN dogs d on d.cat_id = t.id; -------------------------------------------------------------------------------- /queries/_fixtures/33.sql: -------------------------------------------------------------------------------- 1 | SELECT * FROM "t" WHERE (deleted_at = survives); -------------------------------------------------------------------------------- /queries/helpers.go: -------------------------------------------------------------------------------- 1 | package queries 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | ) 7 | 8 | // NonZeroDefaultSet returns the fields included in the 9 | // defaults slice that are non zero values 10 | func NonZeroDefaultSet(defaults []string, obj interface{}) []string 
{ 11 | c := make([]string, 0, len(defaults)) 12 | 13 | val := reflect.Indirect(reflect.ValueOf(obj)) 14 | typ := val.Type() 15 | nf := typ.NumField() 16 | 17 | for _, def := range defaults { 18 | found := false 19 | for i := 0; i < nf; i++ { 20 | field := typ.Field(i) 21 | name, _ := getBoilTag(field) 22 | 23 | if name != def { 24 | continue 25 | } 26 | 27 | found = true 28 | fieldVal := val.Field(i) 29 | 30 | zero := reflect.Zero(fieldVal.Type()) 31 | if !reflect.DeepEqual(zero.Interface(), fieldVal.Interface()) { 32 | c = append(c, def) 33 | } 34 | break 35 | } 36 | 37 | if !found { 38 | panic(fmt.Sprintf("could not find field name %s in type %T", def, obj)) 39 | } 40 | } 41 | 42 | return c 43 | } 44 | -------------------------------------------------------------------------------- /queries/helpers_test.go: -------------------------------------------------------------------------------- 1 | package queries 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | "time" 7 | 8 | "github.com/volatiletech/null/v8" 9 | ) 10 | 11 | type testObj struct { 12 | ID int 13 | Name string `db:"TestHello"` 14 | HeadSize int 15 | } 16 | 17 | func TestNonZeroDefaultSet(t *testing.T) { 18 | t.Parallel() 19 | 20 | type Anything struct { 21 | ID int `boil:"id"` 22 | Name string `boil:"name"` 23 | CreatedAt *time.Time `boil:"created_at"` 24 | UpdatedAt null.Time `boil:"updated_at"` 25 | } 26 | 27 | now := time.Now() 28 | 29 | tests := []struct { 30 | Defaults []string 31 | Obj interface{} 32 | Ret []string 33 | }{ 34 | { 35 | []string{"id"}, 36 | Anything{Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}}, 37 | []string{}, 38 | }, 39 | { 40 | []string{"id"}, 41 | Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}}, 42 | []string{"id"}, 43 | }, 44 | { 45 | []string{}, 46 | Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}}, 47 | []string{}, 48 | }, 49 | { 50 | []string{"id", "created_at", "updated_at"}, 51 | Anything{ID: 
5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}}, 52 | []string{"id"}, 53 | }, 54 | { 55 | []string{"id", "created_at", "updated_at"}, 56 | Anything{ID: 5, Name: "hi", CreatedAt: &now, UpdatedAt: null.Time{Valid: true, Time: time.Now()}}, 57 | []string{"id", "created_at", "updated_at"}, 58 | }, 59 | } 60 | 61 | for i, test := range tests { 62 | z := NonZeroDefaultSet(test.Defaults, test.Obj) 63 | if !reflect.DeepEqual(test.Ret, z) { 64 | t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.Ret, z) 65 | } 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /queries/qmhelper/qmhelper.go: -------------------------------------------------------------------------------- 1 | package qmhelper 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | 7 | "github.com/volatiletech/sqlboiler/v4/queries" 8 | ) 9 | 10 | // Nullable object 11 | type Nullable interface { 12 | IsZero() bool 13 | } 14 | 15 | // WhereQueryMod allows construction of where clauses 16 | type WhereQueryMod struct { 17 | Clause string 18 | Args []interface{} 19 | } 20 | 21 | // Apply implements QueryMod.Apply. 22 | func (qm WhereQueryMod) Apply(q *queries.Query) { 23 | queries.AppendWhere(q, qm.Clause, qm.Args...) 
24 | } 25 | 26 | // WhereNullEQ is a helper for doing equality with null types 27 | func WhereNullEQ(name string, negated bool, value interface{}) WhereQueryMod { 28 | isNull := false 29 | if nullable, ok := value.(Nullable); ok { 30 | isNull = nullable.IsZero() 31 | } else { 32 | isNull = reflect.ValueOf(value).IsNil() 33 | } 34 | 35 | if isNull { 36 | var not string 37 | if negated { 38 | not = "not " 39 | } 40 | return WhereQueryMod{ 41 | Clause: fmt.Sprintf("%s is %snull", name, not), 42 | } 43 | } 44 | 45 | op := "=" 46 | if negated { 47 | op = "!=" 48 | } 49 | 50 | return WhereQueryMod{ 51 | Clause: fmt.Sprintf("%s %s ?", name, op), 52 | Args: []interface{}{value}, 53 | } 54 | } 55 | 56 | // WhereIsNull is a helper that just returns "name is null" 57 | func WhereIsNull(name string) WhereQueryMod { 58 | return WhereQueryMod{Clause: fmt.Sprintf("%s is null", name)} 59 | } 60 | 61 | // WhereIsNotNull is a helper that just returns "name is not null" 62 | func WhereIsNotNull(name string) WhereQueryMod { 63 | return WhereQueryMod{Clause: fmt.Sprintf("%s is not null", name)} 64 | } 65 | 66 | type operator string 67 | 68 | // Supported operations 69 | const ( 70 | EQ operator = "=" 71 | NEQ operator = "!=" 72 | LT operator = "<" 73 | LTE operator = "<=" 74 | GT operator = ">" 75 | GTE operator = ">=" 76 | ) 77 | 78 | // Where is a helper for doing operations on primitive types 79 | func Where(name string, operator operator, value interface{}) WhereQueryMod { 80 | return WhereQueryMod{ 81 | Clause: fmt.Sprintf("%s %s ?", name, string(operator)), 82 | Args: []interface{}{value}, 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /templates/embed.go: -------------------------------------------------------------------------------- 1 | // Package templates is an empty package strictly for embedding sqlboiler 2 | // default templates. 
3 | package templates 4 | 5 | import "embed" 6 | 7 | // Builtin sqlboiler templates 8 | //go:embed main test 9 | var Builtin embed.FS 10 | -------------------------------------------------------------------------------- /templates/main/01_types.go.tpl: -------------------------------------------------------------------------------- 1 | {{if .Table.IsJoinTable -}} 2 | {{else -}} 3 | {{- $alias := .Aliases.Table .Table.Name -}} 4 | var ( 5 | {{$alias.DownSingular}}AllColumns = []string{{"{"}}{{.Table.Columns | columnNames | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}} 6 | {{$alias.DownSingular}}ColumnsWithoutDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault false | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}} 7 | {{$alias.DownSingular}}ColumnsWithDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault true | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}} 8 | {{if .Table.IsView -}} 9 | {{$alias.DownSingular}}PrimaryKeyColumns = []string{} 10 | {{else -}} 11 | {{$alias.DownSingular}}PrimaryKeyColumns = []string{{"{"}}{{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}} 12 | {{end -}} 13 | {{$alias.DownSingular}}GeneratedColumns = []string{{"{"}}{{.Table.Columns | filterColumnsByAuto true | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}} 14 | ) 15 | 16 | type ( 17 | // {{$alias.UpSingular}}Slice is an alias for a slice of pointers to {{$alias.UpSingular}}. 18 | // This should almost always be used instead of []{{$alias.UpSingular}}. 
19 | {{$alias.UpSingular}}Slice []*{{$alias.UpSingular}} 20 | {{if not .NoHooks -}} 21 | // {{$alias.UpSingular}}Hook is the signature for custom {{$alias.UpSingular}} hook methods 22 | {{$alias.UpSingular}}Hook func({{if .NoContext}}boil.Executor{{else}}context.Context, boil.ContextExecutor{{end}}, *{{$alias.UpSingular}}) error 23 | {{- end}} 24 | 25 | {{$alias.DownSingular}}Query struct { 26 | *queries.Query 27 | } 28 | ) 29 | 30 | // Cache for insert, update and upsert 31 | var ( 32 | {{$alias.DownSingular}}Type = reflect.TypeOf(&{{$alias.UpSingular}}{}) 33 | {{$alias.DownSingular}}Mapping = queries.MakeStructMapping({{$alias.DownSingular}}Type) 34 | {{if not .Table.IsView -}} 35 | {{$alias.DownSingular}}PrimaryKeyMapping, _ = queries.BindMapping({{$alias.DownSingular}}Type, {{$alias.DownSingular}}Mapping, {{$alias.DownSingular}}PrimaryKeyColumns) 36 | {{end -}} 37 | {{$alias.DownSingular}}InsertCacheMut sync.RWMutex 38 | {{$alias.DownSingular}}InsertCache = make(map[string]insertCache) 39 | {{$alias.DownSingular}}UpdateCacheMut sync.RWMutex 40 | {{$alias.DownSingular}}UpdateCache = make(map[string]updateCache) 41 | {{$alias.DownSingular}}UpsertCacheMut sync.RWMutex 42 | {{$alias.DownSingular}}UpsertCache = make(map[string]insertCache) 43 | ) 44 | 45 | var ( 46 | // Force time package dependency for automated UpdatedAt/CreatedAt. 
47 | _ = time.Second 48 | // Force qmhelper dependency for where clause generation (which doesn't 49 | // always happen) 50 | _ = qmhelper.Where 51 | {{if .Table.IsView -}} 52 | // These are used in some views 53 | _ = fmt.Sprintln("") 54 | _ = reflect.Int 55 | _ = strings.Builder{} 56 | _ = sync.Mutex{} 57 | _ = strmangle.Plural("") 58 | _ = strconv.IntSize 59 | {{- end}} 60 | ) 61 | {{end -}} 62 | -------------------------------------------------------------------------------- /templates/main/04_relationship_to_one.go.tpl: -------------------------------------------------------------------------------- 1 | {{- if or .Table.IsJoinTable .Table.IsView -}} 2 | {{- else -}} 3 | {{- range $fkey := .Table.FKeys -}} 4 | {{- $ltable := $.Aliases.Table $fkey.Table -}} 5 | {{- $ftable := $.Aliases.Table $fkey.ForeignTable -}} 6 | {{- $rel := $ltable.Relationship $fkey.Name -}} 7 | {{- $canSoftDelete := (getTable $.Tables $fkey.ForeignTable).CanSoftDelete $.AutoColumns.Deleted }} 8 | // {{$rel.Foreign}} pointed to by the foreign key. 9 | func (o *{{$ltable.UpSingular}}) {{$rel.Foreign}}(mods ...qm.QueryMod) ({{$ftable.DownSingular}}Query) { 10 | queryMods := []qm.QueryMod{ 11 | qm.Where("{{$fkey.ForeignColumn | $.Quotes}} = ?", o.{{$ltable.Column $fkey.Column}}), 12 | } 13 | 14 | queryMods = append(queryMods, mods...) 15 | 16 | return {{$ftable.UpPlural}}(queryMods...) 
17 | } 18 | {{- end -}} 19 | {{- end -}} 20 | -------------------------------------------------------------------------------- /templates/main/05_relationship_one_to_one.go.tpl: -------------------------------------------------------------------------------- 1 | {{- if or .Table.IsJoinTable .Table.IsView -}} 2 | {{- else -}} 3 | {{- range $rel := .Table.ToOneRelationships -}} 4 | {{- $ltable := $.Aliases.Table $rel.Table -}} 5 | {{- $ftable := $.Aliases.Table $rel.ForeignTable -}} 6 | {{- $relAlias := $ftable.Relationship $rel.Name -}} 7 | {{- $canSoftDelete := (getTable $.Tables $rel.ForeignTable).CanSoftDelete $.AutoColumns.Deleted }} 8 | // {{$relAlias.Local}} pointed to by the foreign key. 9 | func (o *{{$ltable.UpSingular}}) {{$relAlias.Local}}(mods ...qm.QueryMod) ({{$ftable.DownSingular}}Query) { 10 | queryMods := []qm.QueryMod{ 11 | qm.Where("{{$rel.ForeignColumn | $.Quotes}} = ?", o.{{$ltable.Column $rel.Column}}), 12 | } 13 | 14 | queryMods = append(queryMods, mods...) 15 | 16 | return {{$ftable.UpPlural}}(queryMods...) 
17 | } 18 | {{- end -}} 19 | {{- end -}} 20 | -------------------------------------------------------------------------------- /templates/main/06_relationship_to_many.go.tpl: -------------------------------------------------------------------------------- 1 | {{- if or .Table.IsJoinTable .Table.IsView -}} 2 | {{- else -}} 3 | {{- range $rel := .Table.ToManyRelationships -}} 4 | {{- $ltable := $.Aliases.Table $rel.Table -}} 5 | {{- $ftable := $.Aliases.Table $rel.ForeignTable -}} 6 | {{- $relAlias := $.Aliases.ManyRelationship $rel.ForeignTable $rel.Name $rel.JoinTable $rel.JoinLocalFKeyName -}} 7 | {{- $schemaForeignTable := .ForeignTable | $.SchemaTable -}} 8 | {{- $canSoftDelete := (getTable $.Tables .ForeignTable).CanSoftDelete $.AutoColumns.Deleted}} 9 | // {{$relAlias.Local}} retrieves all the {{.ForeignTable | singular}}'s {{$ftable.UpPlural}} with an executor 10 | {{- if not (eq $relAlias.Local $ftable.UpPlural)}} via {{$rel.ForeignColumn}} column{{- end}}. 11 | func (o *{{$ltable.UpSingular}}) {{$relAlias.Local}}(mods ...qm.QueryMod) {{$ftable.DownSingular}}Query { 12 | var queryMods []qm.QueryMod 13 | if len(mods) != 0 { 14 | queryMods = append(queryMods, mods...) 15 | } 16 | 17 | {{if $rel.ToJoinTable -}} 18 | queryMods = append(queryMods, 19 | {{$schemaJoinTable := $rel.JoinTable | $.SchemaTable -}} 20 | qm.InnerJoin("{{$schemaJoinTable}} on {{$schemaForeignTable}}.{{$rel.ForeignColumn | $.Quotes}} = {{$schemaJoinTable}}.{{$rel.JoinForeignColumn | $.Quotes}}"), 21 | qm.Where("{{$schemaJoinTable}}.{{$rel.JoinLocalColumn | $.Quotes}}=?", o.{{$ltable.Column $rel.Column}}), 22 | ) 23 | {{else -}} 24 | queryMods = append(queryMods, 25 | qm.Where("{{$schemaForeignTable}}.{{$rel.ForeignColumn | $.Quotes}}=?", o.{{$ltable.Column $rel.Column}}), 26 | ) 27 | {{end}} 28 | 29 | return {{$ftable.UpPlural}}(queryMods...) 
30 | } 31 | 32 | {{end -}}{{- /* range relationships */ -}} 33 | {{- end -}}{{- /* if isJoinTable */ -}} 34 | -------------------------------------------------------------------------------- /templates/main/08_relationship_one_to_one_eager.go.tpl: -------------------------------------------------------------------------------- 1 | {{- if or .Table.IsJoinTable .Table.IsView -}} 2 | {{- else -}} 3 | {{- range $rel := .Table.ToOneRelationships -}} 4 | {{- $ltable := $.Aliases.Table $rel.Table -}} 5 | {{- $ftable := $.Aliases.Table $rel.ForeignTable -}} 6 | {{- $relAlias := $ftable.Relationship $rel.Name -}} 7 | {{- $col := $ltable.Column $rel.Column -}} 8 | {{- $fcol := $ftable.Column $rel.ForeignColumn -}} 9 | {{- $usesPrimitives := usesPrimitives $.Tables $rel.Table $rel.Column $rel.ForeignTable $rel.ForeignColumn -}} 10 | {{- $arg := printf "maybe%s" $ltable.UpSingular -}} 11 | {{- $canSoftDelete := (getTable $.Tables $rel.ForeignTable).CanSoftDelete $.AutoColumns.Deleted }} 12 | // Load{{$relAlias.Local}} allows an eager lookup of values, cached into the 13 | // loaded structs of the objects. This is for a 1-1 relationship. 
14 | func ({{$ltable.DownSingular}}L) Load{{$relAlias.Local}}({{if $.NoContext}}e boil.Executor{{else}}ctx context.Context, e boil.ContextExecutor{{end}}, singular bool, {{$arg}} interface{}, mods queries.Applicator) error { 15 | var slice []*{{$ltable.UpSingular}} 16 | var object *{{$ltable.UpSingular}} 17 | 18 | if singular { 19 | var ok bool 20 | object, ok = {{$arg}}.(*{{$ltable.UpSingular}}) 21 | if !ok { 22 | object = new({{$ltable.UpSingular}}) 23 | ok = queries.SetFromEmbeddedStruct(&object, &{{$arg}}) 24 | if !ok { 25 | return errors.New(fmt.Sprintf("failed to set %T from embedded struct %T", object, {{$arg}})) 26 | } 27 | } 28 | } else { 29 | s, ok := {{$arg}}.(*[]*{{$ltable.UpSingular}}) 30 | if ok { 31 | slice = *s 32 | } else { 33 | ok = queries.SetFromEmbeddedStruct(&slice, {{$arg}}) 34 | if !ok { 35 | return errors.New(fmt.Sprintf("failed to set %T from embedded struct %T", slice, {{$arg}})) 36 | } 37 | } 38 | } 39 | 40 | args := make(map[interface{}]struct{}) 41 | if singular { 42 | if object.R == nil { 43 | object.R = &{{$ltable.DownSingular}}R{} 44 | } 45 | args[object.{{$col}}] = struct{}{} 46 | } else { 47 | for _, obj := range slice { 48 | if obj.R == nil { 49 | obj.R = &{{$ltable.DownSingular}}R{} 50 | } 51 | 52 | args[obj.{{$col}}] = struct{}{} 53 | } 54 | } 55 | 56 | if len(args) == 0 { 57 | return nil 58 | } 59 | 60 | argsSlice := make([]interface{}, len(args)) 61 | i := 0 62 | for arg := range args { 63 | argsSlice[i] = arg 64 | i++ 65 | } 66 | 67 | query := NewQuery( 68 | qm.From(`{{if $.Dialect.UseSchema}}{{$.Schema}}.{{end}}{{.ForeignTable}}`), 69 | qm.WhereIn(`{{if $.Dialect.UseSchema}}{{$.Schema}}.{{end}}{{.ForeignTable}}.{{.ForeignColumn}} in ?`, argsSlice...), 70 | {{if and $.AddSoftDeletes $canSoftDelete -}} 71 | qmhelper.WhereIsNull(`{{if $.Dialect.UseSchema}}{{$.Schema}}.{{end}}{{.ForeignTable}}.{{or $.AutoColumns.Deleted "deleted_at"}}`), 72 | {{- end}} 73 | ) 74 | if mods != nil { 75 | mods.Apply(query) 76 | } 77 | 78 | {{if 
$.NoContext -}} 79 | results, err := query.Query(e) 80 | {{else -}} 81 | results, err := query.QueryContext(ctx, e) 82 | {{end -}} 83 | if err != nil { 84 | return errors.Wrap(err, "failed to eager load {{$ftable.UpSingular}}") 85 | } 86 | 87 | var resultSlice []*{{$ftable.UpSingular}} 88 | if err = queries.Bind(results, &resultSlice); err != nil { 89 | return errors.Wrap(err, "failed to bind eager loaded slice {{$ftable.UpSingular}}") 90 | } 91 | 92 | if err = results.Close(); err != nil { 93 | return errors.Wrap(err, "failed to close results of eager load for {{.ForeignTable}}") 94 | } 95 | if err = results.Err(); err != nil { 96 | return errors.Wrap(err, "error occurred during iteration of eager loaded relations for {{.ForeignTable}}") 97 | } 98 | 99 | {{if not $.NoHooks -}} 100 | if len({{$ftable.DownSingular}}AfterSelectHooks) != 0 { 101 | for _, obj := range resultSlice { 102 | if err := obj.doAfterSelectHooks({{if $.NoContext}}e{{else}}ctx, e{{end}}); err != nil { 103 | return err 104 | } 105 | } 106 | } 107 | {{- end}} 108 | 109 | if len(resultSlice) == 0 { 110 | return nil 111 | } 112 | 113 | if singular { 114 | foreign := resultSlice[0] 115 | object.R.{{$relAlias.Local}} = foreign 116 | {{if not $.NoBackReferencing -}} 117 | if foreign.R == nil { 118 | foreign.R = &{{$ftable.DownSingular}}R{} 119 | } 120 | foreign.R.{{$relAlias.Foreign}} = object 121 | {{end -}} 122 | } 123 | 124 | for _, local := range slice { 125 | for _, foreign := range resultSlice { 126 | {{if $usesPrimitives -}} 127 | if local.{{$col}} == foreign.{{$fcol}} { 128 | {{else -}} 129 | if queries.Equal(local.{{$col}}, foreign.{{$fcol}}) { 130 | {{end -}} 131 | local.R.{{$relAlias.Local}} = foreign 132 | {{if not $.NoBackReferencing -}} 133 | if foreign.R == nil { 134 | foreign.R = &{{$ftable.DownSingular}}R{} 135 | } 136 | foreign.R.{{$relAlias.Foreign}} = local 137 | {{end -}} 138 | break 139 | } 140 | } 141 | } 142 | 143 | return nil 144 | } 145 | {{end -}}{{/* range */}} 146 | 
{{end}}{{/* join table */}} 147 | -------------------------------------------------------------------------------- /templates/main/13_all.go.tpl: -------------------------------------------------------------------------------- 1 | {{- $alias := .Aliases.Table .Table.Name}} 2 | {{- $schemaTable := .Table.Name | .SchemaTable}} 3 | {{- $canSoftDelete := .Table.CanSoftDelete $.AutoColumns.Deleted }} 4 | // {{$alias.UpPlural}} retrieves all the records using an executor. 5 | func {{$alias.UpPlural}}(mods ...qm.QueryMod) {{$alias.DownSingular}}Query { 6 | {{if and .AddSoftDeletes $canSoftDelete -}} 7 | mods = append(mods, qm.From("{{$schemaTable}}"), qmhelper.WhereIsNull("{{$schemaTable}}.{{or $.AutoColumns.Deleted "deleted_at" | $.Quotes}}")) 8 | {{else -}} 9 | mods = append(mods, qm.From("{{$schemaTable}}")) 10 | {{end -}} 11 | 12 | q := NewQuery(mods...) 13 | if len(queries.GetSelect(q)) == 0 { 14 | queries.SetSelect(q, []string{"{{$schemaTable}}.*"}) 15 | } 16 | 17 | return {{$alias.DownSingular}}Query{q} 18 | } 19 | -------------------------------------------------------------------------------- /templates/main/14_find.go.tpl: -------------------------------------------------------------------------------- 1 | {{- if .Table.IsView -}} 2 | {{- else -}} 3 | {{- $alias := .Aliases.Table .Table.Name -}} 4 | {{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}} 5 | {{- $pkNames := $colDefs.Names | stringMap (aliasCols $alias) | stringMap .StringFuncs.camelCase | stringMap .StringFuncs.replaceReserved -}} 6 | {{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", " -}} 7 | {{- $canSoftDelete := .Table.CanSoftDelete $.AutoColumns.Deleted }} 8 | {{if .AddGlobal -}} 9 | // Find{{$alias.UpSingular}}G retrieves a single record by ID. 
10 | func Find{{$alias.UpSingular}}G({{if not .NoContext}}ctx context.Context, {{end -}} {{$pkArgs}}, selectCols ...string) (*{{$alias.UpSingular}}, error) { 11 | return Find{{$alias.UpSingular}}({{if .NoContext}}boil.GetDB(){{else}}ctx, boil.GetContextDB(){{end}}, {{$pkNames | join ", "}}, selectCols...) 12 | } 13 | 14 | {{end -}} 15 | 16 | {{if .AddPanic -}} 17 | // Find{{$alias.UpSingular}}P retrieves a single record by ID with an executor, and panics on error. 18 | func Find{{$alias.UpSingular}}P({{if .NoContext}}exec boil.Executor{{else}}ctx context.Context, exec boil.ContextExecutor{{end}}, {{$pkArgs}}, selectCols ...string) *{{$alias.UpSingular}} { 19 | retobj, err := Find{{$alias.UpSingular}}({{if not .NoContext}}ctx, {{end -}} exec, {{$pkNames | join ", "}}, selectCols...) 20 | if err != nil { 21 | panic(boil.WrapErr(err)) 22 | } 23 | 24 | return retobj 25 | } 26 | 27 | {{end -}} 28 | 29 | {{if and .AddGlobal .AddPanic -}} 30 | // Find{{$alias.UpSingular}}GP retrieves a single record by ID, and panics on error. 31 | func Find{{$alias.UpSingular}}GP({{if not .NoContext}}ctx context.Context, {{end -}} {{$pkArgs}}, selectCols ...string) *{{$alias.UpSingular}} { 32 | retobj, err := Find{{$alias.UpSingular}}({{if .NoContext}}boil.GetDB(){{else}}ctx, boil.GetContextDB(){{end}}, {{$pkNames | join ", "}}, selectCols...) 33 | if err != nil { 34 | panic(boil.WrapErr(err)) 35 | } 36 | 37 | return retobj 38 | } 39 | 40 | {{end -}} 41 | 42 | // Find{{$alias.UpSingular}} retrieves a single record by ID with an executor. 43 | // If selectCols is empty Find will return all columns. 
44 | func Find{{$alias.UpSingular}}({{if .NoContext}}exec boil.Executor{{else}}ctx context.Context, exec boil.ContextExecutor{{end}}, {{$pkArgs}}, selectCols ...string) (*{{$alias.UpSingular}}, error) { 45 | {{$alias.DownSingular}}Obj := &{{$alias.UpSingular}}{} 46 | 47 | sel := "*" 48 | if len(selectCols) > 0 { 49 | sel = strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, selectCols), ",") 50 | } 51 | query := fmt.Sprintf( 52 | "select %s from {{.Table.Name | .SchemaTable}} where {{if .Dialect.UseIndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}}{{if and .AddSoftDeletes $canSoftDelete}} and {{or $.AutoColumns.Deleted "deleted_at" | $.Quotes}} is null{{end}}", sel, 53 | ) 54 | 55 | q := queries.Raw(query, {{$pkNames | join ", "}}) 56 | 57 | err := q.Bind({{if not .NoContext}}ctx{{else}}nil{{end}}, exec, {{$alias.DownSingular}}Obj) 58 | if err != nil { 59 | {{if not .AlwaysWrapErrors -}} 60 | if errors.Is(err, sql.ErrNoRows) { 61 | return nil, sql.ErrNoRows 62 | } 63 | {{end -}} 64 | return nil, errors.Wrap(err, "{{.PkgName}}: unable to select from {{.Table.Name}}") 65 | } 66 | 67 | {{if not .NoHooks -}} 68 | if err = {{$alias.DownSingular}}Obj.doAfterSelectHooks({{if not .NoContext}}ctx, {{end -}} exec); err != nil { 69 | return {{$alias.DownSingular}}Obj, err 70 | } 71 | {{- end}} 72 | 73 | return {{$alias.DownSingular}}Obj, nil 74 | } 75 | 76 | {{- end -}} 77 | -------------------------------------------------------------------------------- /templates/main/19_reload.go.tpl: -------------------------------------------------------------------------------- 1 | {{- if .Table.IsView -}} 2 | {{- else -}} 3 | {{- $alias := .Aliases.Table .Table.Name -}} 4 | {{- $schemaTable := .Table.Name | .SchemaTable -}} 5 | {{- $canSoftDelete := .Table.CanSoftDelete $.AutoColumns.Deleted }} 6 | {{if .AddGlobal -}} 7 | // ReloadG refetches the object from the database using the primary keys. 
8 | func (o *{{$alias.UpSingular}}) ReloadG({{if not .NoContext}}ctx context.Context{{end}}) error { 9 | if o == nil { 10 | return errors.New("{{.PkgName}}: no {{$alias.UpSingular}} provided for reload") 11 | } 12 | 13 | return o.Reload({{if .NoContext}}boil.GetDB(){{else}}ctx, boil.GetContextDB(){{end}}) 14 | } 15 | 16 | {{end -}} 17 | 18 | {{if .AddPanic -}} 19 | // ReloadP refetches the object from the database with an executor. Panics on error. 20 | func (o *{{$alias.UpSingular}}) ReloadP({{if .NoContext}}exec boil.Executor{{else}}ctx context.Context, exec boil.ContextExecutor{{end}}) { 21 | if err := o.Reload({{if not .NoContext}}ctx, {{end -}} exec); err != nil { 22 | panic(boil.WrapErr(err)) 23 | } 24 | } 25 | 26 | {{end -}} 27 | 28 | {{if and .AddGlobal .AddPanic -}} 29 | // ReloadGP refetches the object from the database and panics on error. 30 | func (o *{{$alias.UpSingular}}) ReloadGP({{if not .NoContext}}ctx context.Context{{end}}) { 31 | if err := o.Reload({{if .NoContext}}boil.GetDB(){{else}}ctx, boil.GetContextDB(){{end}}); err != nil { 32 | panic(boil.WrapErr(err)) 33 | } 34 | } 35 | 36 | {{end -}} 37 | 38 | // Reload refetches the object from the database 39 | // using the primary keys with an executor. 40 | func (o *{{$alias.UpSingular}}) Reload({{if .NoContext}}exec boil.Executor{{else}}ctx context.Context, exec boil.ContextExecutor{{end}}) error { 41 | ret, err := Find{{$alias.UpSingular}}({{if not .NoContext}}ctx, {{end -}} exec, {{.Table.PKey.Columns | stringMap (aliasCols $alias) | prefixStringSlice "o." | join ", "}}) 42 | if err != nil { 43 | return err 44 | } 45 | 46 | *o = *ret 47 | return nil 48 | } 49 | 50 | {{if .AddGlobal -}} 51 | // ReloadAllG refetches every row with matching primary key column values 52 | // and overwrites the original object slice with the newly updated slice. 
53 | func (o *{{$alias.UpSingular}}Slice) ReloadAllG({{if not .NoContext}}ctx context.Context{{end}}) error { 54 | if o == nil { 55 | return errors.New("{{.PkgName}}: empty {{$alias.UpSingular}}Slice provided for reload all") 56 | } 57 | 58 | return o.ReloadAll({{if .NoContext}}boil.GetDB(){{else}}ctx, boil.GetContextDB(){{end}}) 59 | } 60 | 61 | {{end -}} 62 | 63 | {{if .AddPanic -}} 64 | // ReloadAllP refetches every row with matching primary key column values 65 | // and overwrites the original object slice with the newly updated slice. 66 | // Panics on error. 67 | func (o *{{$alias.UpSingular}}Slice) ReloadAllP({{if .NoContext}}exec boil.Executor{{else}}ctx context.Context, exec boil.ContextExecutor{{end}}) { 68 | if err := o.ReloadAll({{if not .NoContext}}ctx, {{end -}} exec); err != nil { 69 | panic(boil.WrapErr(err)) 70 | } 71 | } 72 | 73 | {{end -}} 74 | 75 | {{if and .AddGlobal .AddPanic -}} 76 | // ReloadAllGP refetches every row with matching primary key column values 77 | // and overwrites the original object slice with the newly updated slice. 78 | // Panics on error. 79 | func (o *{{$alias.UpSingular}}Slice) ReloadAllGP({{if not .NoContext}}ctx context.Context{{end}}) { 80 | if err := o.ReloadAll({{if .NoContext}}boil.GetDB(){{else}}ctx, boil.GetContextDB(){{end}}); err != nil { 81 | panic(boil.WrapErr(err)) 82 | } 83 | } 84 | 85 | {{end -}} 86 | 87 | // ReloadAll refetches every row with matching primary key column values 88 | // and overwrites the original object slice with the newly updated slice. 
89 | func (o *{{$alias.UpSingular}}Slice) ReloadAll({{if .NoContext}}exec boil.Executor{{else}}ctx context.Context, exec boil.ContextExecutor{{end}}) error { 90 | if o == nil || len(*o) == 0 { 91 | return nil 92 | } 93 | 94 | slice := {{$alias.UpSingular}}Slice{} 95 | var args []interface{} 96 | for _, obj := range *o { 97 | pkeyArgs := queries.ValuesFromMapping(reflect.Indirect(reflect.ValueOf(obj)), {{$alias.DownSingular}}PrimaryKeyMapping) 98 | args = append(args, pkeyArgs...) 99 | } 100 | 101 | sql := "SELECT {{$schemaTable}}.* FROM {{$schemaTable}} WHERE " + 102 | strmangle.WhereClauseRepeated(string(dialect.LQ), string(dialect.RQ), {{if .Dialect.UseIndexPlaceholders}}1{{else}}0{{end}}, {{$alias.DownSingular}}PrimaryKeyColumns, len(*o)){{if and .AddSoftDeletes $canSoftDelete}} + 103 | "and {{or $.AutoColumns.Deleted "deleted_at" | $.Quotes}} is null" 104 | {{- end}} 105 | 106 | q := queries.Raw(sql, args...) 107 | 108 | err := q.Bind({{if .NoContext}}nil{{else}}ctx{{end}}, exec, &slice) 109 | if err != nil { 110 | return errors.Wrap(err, "{{.PkgName}}: unable to reload all in {{$alias.UpSingular}}Slice") 111 | } 112 | 113 | *o = slice 114 | 115 | return nil 116 | } 117 | 118 | {{- end -}} 119 | -------------------------------------------------------------------------------- /templates/main/20_exists.go.tpl: -------------------------------------------------------------------------------- 1 | {{- if .Table.IsView -}} 2 | {{- else -}} 3 | {{- $alias := .Aliases.Table .Table.Name -}} 4 | {{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}} 5 | {{- $pkNames := $colDefs.Names | stringMap (aliasCols $alias) | stringMap .StringFuncs.camelCase | stringMap .StringFuncs.replaceReserved -}} 6 | {{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", " -}} 7 | {{- $schemaTable := .Table.Name | .SchemaTable -}} 8 | {{- $canSoftDelete := .Table.CanSoftDelete $.AutoColumns.Deleted }} 9 | {{if .AddGlobal -}} 10 | // {{$alias.UpSingular}}ExistsG 
checks if the {{$alias.UpSingular}} row exists. 11 | func {{$alias.UpSingular}}ExistsG({{if not .NoContext}}ctx context.Context, {{end -}} {{$pkArgs}}) (bool, error) { 12 | return {{$alias.UpSingular}}Exists({{if .NoContext}}boil.GetDB(){{else}}ctx, boil.GetContextDB(){{end}}, {{$pkNames | join ", "}}) 13 | } 14 | 15 | {{end -}} 16 | 17 | {{if .AddPanic -}} 18 | // {{$alias.UpSingular}}ExistsP checks if the {{$alias.UpSingular}} row exists. Panics on error. 19 | func {{$alias.UpSingular}}ExistsP({{if .NoContext}}exec boil.Executor{{else}}ctx context.Context, exec boil.ContextExecutor{{end}}, {{$pkArgs}}) bool { 20 | e, err := {{$alias.UpSingular}}Exists({{if not .NoContext}}ctx, {{end -}} exec, {{$pkNames | join ", "}}) 21 | if err != nil { 22 | panic(boil.WrapErr(err)) 23 | } 24 | 25 | return e 26 | } 27 | 28 | {{end -}} 29 | 30 | {{if and .AddGlobal .AddPanic -}} 31 | // {{$alias.UpSingular}}ExistsGP checks if the {{$alias.UpSingular}} row exists. Panics on error. 32 | func {{$alias.UpSingular}}ExistsGP({{if not .NoContext}}ctx context.Context, {{end -}} {{$pkArgs}}) bool { 33 | e, err := {{$alias.UpSingular}}Exists({{if .NoContext}}boil.GetDB(){{else}}ctx, boil.GetContextDB(){{end}}, {{$pkNames | join ", "}}) 34 | if err != nil { 35 | panic(boil.WrapErr(err)) 36 | } 37 | 38 | return e 39 | } 40 | 41 | {{end -}} 42 | 43 | // {{$alias.UpSingular}}Exists checks if the {{$alias.UpSingular}} row exists. 
44 | func {{$alias.UpSingular}}Exists({{if .NoContext}}exec boil.Executor{{else}}ctx context.Context, exec boil.ContextExecutor{{end}}, {{$pkArgs}}) (bool, error) { 45 | var exists bool 46 | {{if .Dialect.UseCaseWhenExistsClause -}} 47 | sql := "select case when exists(select top(1) 1 from {{$schemaTable}} where {{if .Dialect.UseIndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}}) then 1 else 0 end" 48 | {{- else -}} 49 | sql := "select exists(select 1 from {{$schemaTable}} where {{if .Dialect.UseIndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}}{{if and .AddSoftDeletes $canSoftDelete}} and {{or $.AutoColumns.Deleted "deleted_at" | $.Quotes}} is null{{end}} limit 1)" 50 | {{- end}} 51 | 52 | {{if .NoContext -}} 53 | if boil.DebugMode { 54 | fmt.Fprintln(boil.DebugWriter, sql) 55 | fmt.Fprintln(boil.DebugWriter, {{$pkNames | join ", "}}) 56 | } 57 | {{else -}} 58 | if boil.IsDebug(ctx) { 59 | writer := boil.DebugWriterFrom(ctx) 60 | fmt.Fprintln(writer, sql) 61 | fmt.Fprintln(writer, {{$pkNames | join ", "}}) 62 | } 63 | {{end -}} 64 | 65 | {{if .NoContext -}} 66 | row := exec.QueryRow(sql, {{$pkNames | join ", "}}) 67 | {{else -}} 68 | row := exec.QueryRowContext(ctx, sql, {{$pkNames | join ", "}}) 69 | {{- end}} 70 | 71 | err := row.Scan(&exists) 72 | if err != nil { 73 | return false, errors.Wrap(err, "{{.PkgName}}: unable to check if {{.Table.Name}} exists") 74 | } 75 | 76 | return exists, nil 77 | } 78 | 79 | // Exists checks if the {{$alias.UpSingular}} row exists. 
80 | func (o *{{$alias.UpSingular}}) Exists({{if .NoContext}}exec boil.Executor{{else}}ctx context.Context, exec boil.ContextExecutor{{end}}) (bool, error) { 81 | return {{$alias.UpSingular}}Exists({{if .NoContext}}exec{{else}}ctx, exec{{end}}, o.{{$.Table.PKey.Columns | stringMap (aliasCols $alias) | join ", o."}}) 82 | } 83 | 84 | {{- end -}} 85 | -------------------------------------------------------------------------------- /templates/main/21_auto_timestamps.go.tpl: -------------------------------------------------------------------------------- 1 | {{- define "timestamp_insert_helper" -}} 2 | {{- if not .NoAutoTimestamps -}} 3 | {{- $alias := .Aliases.Table .Table.Name -}} 4 | {{- $colNames := .Table.Columns | columnNames -}} 5 | {{if containsAny $colNames (or $.AutoColumns.Created "created_at") (or $.AutoColumns.Updated "updated_at")}} 6 | {{if not .NoContext -}} 7 | if !boil.TimestampsAreSkipped(ctx) { 8 | {{end -}} 9 | currTime := time.Now().In(boil.GetLocation()) 10 | {{range $ind, $col := .Table.Columns}} 11 | {{- $colAlias := $alias.Column $col.Name -}} 12 | {{- if eq $col.Name (or $.AutoColumns.Created "created_at") -}} 13 | {{- if eq $col.Type "time.Time" }} 14 | if o.{{$colAlias}}.IsZero() { 15 | o.{{$colAlias}} = currTime 16 | } 17 | {{- else}} 18 | if queries.MustTime(o.{{$colAlias}}).IsZero() { 19 | queries.SetScanner(&o.{{$colAlias}}, currTime) 20 | } 21 | {{- end -}} 22 | {{- end -}} 23 | {{- if eq $col.Name (or $.AutoColumns.Updated "updated_at") -}} 24 | {{- if eq $col.Type "time.Time"}} 25 | if o.{{$colAlias}}.IsZero() { 26 | o.{{$colAlias}} = currTime 27 | } 28 | {{- else}} 29 | if queries.MustTime(o.{{$colAlias}}).IsZero() { 30 | queries.SetScanner(&o.{{$colAlias}}, currTime) 31 | } 32 | {{- end -}} 33 | {{- end -}} 34 | {{end}} 35 | {{if not .NoContext -}} 36 | } 37 | {{end -}} 38 | {{end}} 39 | {{- end}} 40 | {{- end -}} 41 | {{- define "timestamp_update_helper" -}} 42 | {{- if not .NoAutoTimestamps -}} 43 | {{- $alias := .Aliases.Table 
.Table.Name -}} 44 | {{- $colNames := .Table.Columns | columnNames -}} 45 | {{if containsAny $colNames (or $.AutoColumns.Updated "updated_at")}} 46 | {{if not .NoContext -}} 47 | if !boil.TimestampsAreSkipped(ctx) { 48 | {{end -}} 49 | currTime := time.Now().In(boil.GetLocation()) 50 | {{range $ind, $col := .Table.Columns}} 51 | {{- $colAlias := $alias.Column $col.Name -}} 52 | {{- if eq $col.Name (or $.AutoColumns.Updated "updated_at") -}} 53 | {{- if eq $col.Type "time.Time"}} 54 | o.{{$colAlias}} = currTime 55 | {{- else}} 56 | queries.SetScanner(&o.{{$colAlias}}, currTime) 57 | {{- end -}} 58 | {{- end -}} 59 | {{end}} 60 | {{if not .NoContext -}} 61 | } 62 | {{end -}} 63 | {{end}} 64 | {{- end}} 65 | {{end -}} 66 | {{- define "timestamp_upsert_helper" -}} 67 | {{- if not .NoAutoTimestamps -}} 68 | {{- $alias := .Aliases.Table .Table.Name -}} 69 | {{- $colNames := .Table.Columns | columnNames -}} 70 | {{if containsAny $colNames (or $.AutoColumns.Created "created_at") (or $.AutoColumns.Updated "updated_at")}} 71 | {{if not .NoContext -}} 72 | if !boil.TimestampsAreSkipped(ctx) { 73 | {{end -}} 74 | currTime := time.Now().In(boil.GetLocation()) 75 | {{range $ind, $col := .Table.Columns}} 76 | {{- $colAlias := $alias.Column $col.Name -}} 77 | {{- if eq $col.Name (or $.AutoColumns.Created "created_at") -}} 78 | {{- if eq $col.Type "time.Time"}} 79 | if o.{{$colAlias}}.IsZero() { 80 | o.{{$colAlias}} = currTime 81 | } 82 | {{- else}} 83 | if queries.MustTime(o.{{$colAlias}}).IsZero() { 84 | queries.SetScanner(&o.{{$colAlias}}, currTime) 85 | } 86 | {{- end -}} 87 | {{- end -}} 88 | {{- if eq $col.Name (or $.AutoColumns.Updated "updated_at") -}} 89 | {{- if eq $col.Type "time.Time"}} 90 | o.{{$colAlias}} = currTime 91 | {{- else}} 92 | queries.SetScanner(&o.{{$colAlias}}, currTime) 93 | {{- end -}} 94 | {{- end -}} 95 | {{end}} 96 | {{if not .NoContext -}} 97 | } 98 | {{end -}} 99 | {{end}} 100 | {{- end}} 101 | {{end -}} 102 | 
-------------------------------------------------------------------------------- /templates/main/singleton/boil_queries.go.tpl: -------------------------------------------------------------------------------- 1 | var dialect = drivers.Dialect{ 2 | LQ: 0x{{printf "%x" .Dialect.LQ}}, 3 | RQ: 0x{{printf "%x" .Dialect.RQ}}, 4 | 5 | UseIndexPlaceholders: {{.Dialect.UseIndexPlaceholders}}, 6 | UseLastInsertID: {{.Dialect.UseLastInsertID}}, 7 | UseSchema: {{.Dialect.UseSchema}}, 8 | UseDefaultKeyword: {{.Dialect.UseDefaultKeyword}}, 9 | UseAutoColumns: {{.Dialect.UseAutoColumns}}, 10 | UseTopClause: {{.Dialect.UseTopClause}}, 11 | UseOutputClause: {{.Dialect.UseOutputClause}}, 12 | UseCaseWhenExistsClause: {{.Dialect.UseCaseWhenExistsClause}}, 13 | } 14 | 15 | {{- if not .AutoColumns.Deleted }} 16 | // This is a dummy variable to prevent unused regexp import error 17 | var _ = &regexp.Regexp{} 18 | {{- end }} 19 | 20 | {{- if and (.AutoColumns.Deleted) (ne $.AutoColumns.Deleted "deleted_at") }} 21 | func init() { 22 | queries.SetRemoveSoftDeleteRgx(regexp.MustCompile("{{$.AutoColumns.Deleted}}[\"'`]? is null")) 23 | } 24 | {{- end }} 25 | 26 | // NewQuery initializes a new Query using the passed in QueryMods 27 | func NewQuery(mods ...qm.QueryMod) *queries.Query { 28 | q := &queries.Query{} 29 | queries.SetDialect(q, &dialect) 30 | qm.Apply(q, mods...)
31 | 32 | return q 33 | } 34 | -------------------------------------------------------------------------------- /templates/main/singleton/boil_table_names.go.tpl: -------------------------------------------------------------------------------- 1 | var TableNames = struct { 2 | {{range $table := .Tables}}{{if not $table.IsView -}} 3 | {{titleCase $table.Name}} string 4 | {{end}}{{end -}} 5 | }{ 6 | {{range $table := .Tables}}{{if not $table.IsView -}} 7 | {{titleCase $table.Name}}: "{{$table.Name}}", 8 | {{end}}{{end -}} 9 | } 10 | -------------------------------------------------------------------------------- /templates/main/singleton/boil_view_names.go.tpl: -------------------------------------------------------------------------------- 1 | var ViewNames = struct { 2 | {{range $table := .Tables}}{{if $table.IsView -}} 3 | {{titleCase $table.Name}} string 4 | {{end}}{{end -}} 5 | }{ 6 | {{range $table := .Tables}}{{if $table.IsView -}} 7 | {{titleCase $table.Name}}: "{{$table.Name}}", 8 | {{end}}{{end -}} 9 | } 10 | 11 | -------------------------------------------------------------------------------- /templates/test/00_types.go.tpl: -------------------------------------------------------------------------------- 1 | var ( 2 | // Relationships sometimes use the reflection helper queries.Equal/queries.Assign 3 | // so force a package dependency in case they don't.
4 | _ = queries.Equal 5 | ) 6 | -------------------------------------------------------------------------------- /templates/test/all.go.tpl: -------------------------------------------------------------------------------- 1 | {{- $alias := .Aliases.Table .Table.Name -}} 2 | func test{{$alias.UpPlural}}(t *testing.T) { 3 | t.Parallel() 4 | 5 | query := {{$alias.UpPlural}}() 6 | 7 | if query.Query == nil { 8 | t.Error("expected a query, got nothing") 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /templates/test/exists.go.tpl: -------------------------------------------------------------------------------- 1 | {{- $alias := .Aliases.Table .Table.Name}} 2 | func test{{$alias.UpPlural}}Exists(t *testing.T) { 3 | t.Parallel() 4 | 5 | seed := randomize.NewSeed() 6 | var err error 7 | o := &{{$alias.UpSingular}}{} 8 | if err = randomize.Struct(seed, o, {{$alias.DownSingular}}DBTypes, true, {{$alias.DownSingular}}ColumnsWithDefault...); err != nil { 9 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 10 | } 11 | 12 | {{if not .NoContext}}ctx := context.Background(){{end}} 13 | tx := MustTx({{if .NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}}) 14 | defer func() { _ = tx.Rollback() }() 15 | if err = o.Insert({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 16 | t.Error(err) 17 | } 18 | 19 | {{$pkeyArgs := .Table.PKey.Columns | stringMap (aliasCols $alias) | prefixStringSlice (printf "%s."
"o") | join ", " -}} 20 | e, err := {{$alias.UpSingular}}Exists({{if not .NoContext}}ctx, {{end -}} tx, {{$pkeyArgs}}) 21 | if err != nil { 22 | t.Errorf("Unable to check if {{$alias.UpSingular}} exists: %s", err) 23 | } 24 | if !e { 25 | t.Errorf("Expected {{$alias.UpSingular}}Exists to return true, but got false.") 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /templates/test/find.go.tpl: -------------------------------------------------------------------------------- 1 | {{- $alias := .Aliases.Table .Table.Name}} 2 | func test{{$alias.UpPlural}}Find(t *testing.T) { 3 | t.Parallel() 4 | 5 | seed := randomize.NewSeed() 6 | var err error 7 | o := &{{$alias.UpSingular}}{} 8 | if err = randomize.Struct(seed, o, {{$alias.DownSingular}}DBTypes, true, {{$alias.DownSingular}}ColumnsWithDefault...); err != nil { 9 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 10 | } 11 | 12 | {{if not .NoContext}}ctx := context.Background(){{end}} 13 | tx := MustTx({{if .NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}}) 14 | defer func() { _ = tx.Rollback() }() 15 | if err = o.Insert({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 16 | t.Error(err) 17 | } 18 | 19 | {{$alias.DownSingular}}Found, err := Find{{$alias.UpSingular}}({{if not .NoContext}}ctx, {{end -}} tx, {{.Table.PKey.Columns | stringMap (aliasCols $alias) | prefixStringSlice (printf "%s."
"o") | join ", "}}) 20 | if err != nil { 21 | t.Error(err) 22 | } 23 | 24 | if {{$alias.DownSingular}}Found == nil { 25 | t.Error("want a record, got nil") 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /templates/test/finishers.go.tpl: -------------------------------------------------------------------------------- 1 | {{- $alias := .Aliases.Table .Table.Name}} 2 | func test{{$alias.UpPlural}}Bind(t *testing.T) { 3 | t.Parallel() 4 | 5 | seed := randomize.NewSeed() 6 | var err error 7 | o := &{{$alias.UpSingular}}{} 8 | if err = randomize.Struct(seed, o, {{$alias.DownSingular}}DBTypes, true, {{$alias.DownSingular}}ColumnsWithDefault...); err != nil { 9 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 10 | } 11 | 12 | {{if not .NoContext}}ctx := context.Background(){{end}} 13 | tx := MustTx({{if .NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}}) 14 | defer func() { _ = tx.Rollback() }() 15 | if err = o.Insert({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 16 | t.Error(err) 17 | } 18 | 19 | if err = {{$alias.UpPlural}}().Bind({{if .NoContext}}nil{{else}}ctx{{end}}, tx, o); err != nil { 20 | t.Error(err) 21 | } 22 | } 23 | 24 | func test{{$alias.UpPlural}}One(t *testing.T) { 25 | t.Parallel() 26 | 27 | seed := randomize.NewSeed() 28 | var err error 29 | o := &{{$alias.UpSingular}}{} 30 | if err = randomize.Struct(seed, o, {{$alias.DownSingular}}DBTypes, true, {{$alias.DownSingular}}ColumnsWithDefault...); err != nil { 31 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 32 | } 33 | 34 | {{if not .NoContext}}ctx := context.Background(){{end}} 35 | tx := MustTx({{if .NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}}) 36 | defer func() { _ = tx.Rollback() }() 37 | if err = o.Insert({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 38 | t.Error(err) 39 | } 40 | 41 | if x, err := {{$alias.UpPlural}}().One({{if not
.NoContext}}ctx, {{end -}} tx); err != nil { 42 | t.Error(err) 43 | } else if x == nil { 44 | t.Error("expected to get a non nil record") 45 | } 46 | } 47 | 48 | func test{{$alias.UpPlural}}All(t *testing.T) { 49 | t.Parallel() 50 | 51 | seed := randomize.NewSeed() 52 | var err error 53 | {{$alias.DownSingular}}One := &{{$alias.UpSingular}}{} 54 | {{$alias.DownSingular}}Two := &{{$alias.UpSingular}}{} 55 | if err = randomize.Struct(seed, {{$alias.DownSingular}}One, {{$alias.DownSingular}}DBTypes, false, {{$alias.DownSingular}}ColumnsWithDefault...); err != nil { 56 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 57 | } 58 | if err = randomize.Struct(seed, {{$alias.DownSingular}}Two, {{$alias.DownSingular}}DBTypes, false, {{$alias.DownSingular}}ColumnsWithDefault...); err != nil { 59 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 60 | } 61 | 62 | {{if not .NoContext}}ctx := context.Background(){{end}} 63 | tx := MustTx({{if .NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}}) 64 | defer func() { _ = tx.Rollback() }() 65 | if err = {{$alias.DownSingular}}One.Insert({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 66 | t.Error(err) 67 | } 68 | if err = {{$alias.DownSingular}}Two.Insert({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 69 | t.Error(err) 70 | } 71 | 72 | slice, err := {{$alias.UpPlural}}().All({{if not .NoContext}}ctx, {{end -}} tx) 73 | if err != nil { 74 | t.Error(err) 75 | } 76 | 77 | if len(slice) != 2 { 78 | t.Error("want 2 records, got:", len(slice)) 79 | } 80 | } 81 | 82 | func test{{$alias.UpPlural}}Count(t *testing.T) { 83 | t.Parallel() 84 | 85 | var err error 86 | seed := randomize.NewSeed() 87 | {{$alias.DownSingular}}One := &{{$alias.UpSingular}}{} 88 | {{$alias.DownSingular}}Two := &{{$alias.UpSingular}}{} 89 | if err = randomize.Struct(seed, {{$alias.DownSingular}}One, {{$alias.DownSingular}}DBTypes, false,
{{$alias.DownSingular}}ColumnsWithDefault...); err != nil { 90 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 91 | } 92 | if err = randomize.Struct(seed, {{$alias.DownSingular}}Two, {{$alias.DownSingular}}DBTypes, false, {{$alias.DownSingular}}ColumnsWithDefault...); err != nil { 93 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 94 | } 95 | 96 | {{if not .NoContext}}ctx := context.Background(){{end}} 97 | tx := MustTx({{if .NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}}) 98 | defer func() { _ = tx.Rollback() }() 99 | if err = {{$alias.DownSingular}}One.Insert({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 100 | t.Error(err) 101 | } 102 | if err = {{$alias.DownSingular}}Two.Insert({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 103 | t.Error(err) 104 | } 105 | 106 | count, err := {{$alias.UpPlural}}().Count({{if not .NoContext}}ctx, {{end -}} tx) 107 | if err != nil { 108 | t.Error(err) 109 | } 110 | 111 | if count != 2 { 112 | t.Error("want 2 records, got:", count) 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /templates/test/insert.go.tpl: -------------------------------------------------------------------------------- 1 | {{- $alias := .Aliases.Table .Table.Name}} 2 | func test{{$alias.UpPlural}}Insert(t *testing.T) { 3 | t.Parallel() 4 | 5 | seed := randomize.NewSeed() 6 | var err error 7 | o := &{{$alias.UpSingular}}{} 8 | if err = randomize.Struct(seed, o, {{$alias.DownSingular}}DBTypes, true, {{$alias.DownSingular}}ColumnsWithDefault...); err != nil { 9 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 10 | } 11 | 12 | {{if not .NoContext}}ctx := context.Background(){{end}} 13 | tx := MustTx({{if .NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}}) 14 | defer func() { _ = tx.Rollback() }() 15 | if err = o.Insert({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer());
err != nil { 16 | t.Error(err) 17 | } 18 | 19 | count, err := {{$alias.UpPlural}}().Count({{if not .NoContext}}ctx, {{end -}} tx) 20 | if err != nil { 21 | t.Error(err) 22 | } 23 | 24 | if count != 1 { 25 | t.Error("want one record, got:", count) 26 | } 27 | } 28 | 29 | func test{{$alias.UpPlural}}InsertWhitelist(t *testing.T) { 30 | t.Parallel() 31 | 32 | seed := randomize.NewSeed() 33 | var err error 34 | o := &{{$alias.UpSingular}}{} 35 | if err = randomize.Struct(seed, o, {{$alias.DownSingular}}DBTypes, true); err != nil { 36 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 37 | } 38 | 39 | {{if not .NoContext}}ctx := context.Background(){{end}} 40 | tx := MustTx({{if .NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}}) 41 | defer func() { _ = tx.Rollback() }() 42 | if err = o.Insert({{if not .NoContext}}ctx, {{end -}} tx, boil.Whitelist(strmangle.SetMerge({{$alias.DownSingular}}PrimaryKeyColumns, {{$alias.DownSingular}}ColumnsWithoutDefault)...)); err != nil { 43 | t.Error(err) 44 | } 45 | 46 | count, err := {{$alias.UpPlural}}().Count({{if not .NoContext}}ctx, {{end -}} tx) 47 | if err != nil { 48 | t.Error(err) 49 | } 50 | 51 | if count != 1 { 52 | t.Error("want one record, got:", count) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /templates/test/relationship_one_to_one.go.tpl: -------------------------------------------------------------------------------- 1 | {{- if .Table.IsJoinTable -}} 2 | {{- else -}} 3 | {{- range $rel := .Table.ToOneRelationships -}} 4 | {{- $ltable := $.Aliases.Table $rel.Table -}} 5 | {{- $ftable := $.Aliases.Table $rel.ForeignTable -}} 6 | {{- $relAlias := $ftable.Relationship $rel.Name -}} 7 | {{- $usesPrimitives := usesPrimitives $.Tables $rel.Table $rel.Column $rel.ForeignTable $rel.ForeignColumn -}} 8 | {{- $colField := $ltable.Column $rel.Column -}} 9 | {{- $fcolField := $ftable.Column $rel.ForeignColumn }} 10 | func
test{{$ltable.UpSingular}}OneToOne{{$ftable.UpSingular}}Using{{$relAlias.Local}}(t *testing.T) { 11 | {{if not $.NoContext}}ctx := context.Background(){{end}} 12 | tx := MustTx({{if $.NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}}) 13 | defer func() { _ = tx.Rollback() }() 14 | 15 | var foreign {{$ftable.UpSingular}} 16 | var local {{$ltable.UpSingular}} 17 | 18 | seed := randomize.NewSeed() 19 | if err := randomize.Struct(seed, &foreign, {{$ftable.DownSingular}}DBTypes, true, {{$ftable.DownSingular}}ColumnsWithDefault...); err != nil { 20 | t.Errorf("Unable to randomize {{$ftable.UpSingular}} struct: %s", err) 21 | } 22 | if err := randomize.Struct(seed, &local, {{$ltable.DownSingular}}DBTypes, true, {{$ltable.DownSingular}}ColumnsWithDefault...); err != nil { 23 | t.Errorf("Unable to randomize {{$ltable.UpSingular}} struct: %s", err) 24 | } 25 | 26 | if err := local.Insert({{if not $.NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 27 | t.Fatal(err) 28 | } 29 | 30 | {{if $usesPrimitives -}} 31 | foreign.{{$fcolField}} = local.{{$colField}} 32 | {{else -}} 33 | queries.Assign(&foreign.{{$fcolField}}, local.{{$colField}}) 34 | {{end -}} 35 | if err := foreign.Insert({{if not $.NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 36 | t.Fatal(err) 37 | } 38 | 39 | check, err := local.{{$relAlias.Local}}().One({{if not $.NoContext}}ctx, {{end -}} tx) 40 | if err != nil { 41 | t.Fatal(err) 42 | } 43 | 44 | {{if $usesPrimitives -}} 45 | if check.{{$fcolField}} != foreign.{{$fcolField}} { 46 | {{else -}} 47 | if !queries.Equal(check.{{$fcolField}}, foreign.{{$fcolField}}) { 48 | {{end -}} 49 | t.Errorf("want: %v, got %v", foreign.{{$fcolField}}, check.{{$fcolField}}) 50 | } 51 | 52 | {{if not $.NoHooks -}} 53 | ranAfterSelectHook := false 54 | Add{{$ftable.UpSingular}}Hook(boil.AfterSelectHook, func({{if not $.NoContext}}ctx context.Context, e boil.ContextExecutor{{else}}e boil.Executor{{end}}, o *{{$ftable.UpSingular}}) error { 55 |
ranAfterSelectHook = true 56 | return nil 57 | }) 58 | {{- end}} 59 | 60 | slice := {{$ltable.UpSingular}}Slice{&local} 61 | if err = local.L.Load{{$relAlias.Local}}({{if not $.NoContext}}ctx, {{end -}} tx, false, (*[]*{{$ltable.UpSingular}})(&slice), nil); err != nil { 62 | t.Fatal(err) 63 | } 64 | if local.R.{{$relAlias.Local}} == nil { 65 | t.Error("struct should have been eager loaded") 66 | } 67 | 68 | local.R.{{$relAlias.Local}} = nil 69 | if err = local.L.Load{{$relAlias.Local}}({{if not $.NoContext}}ctx, {{end -}} tx, true, &local, nil); err != nil { 70 | t.Fatal(err) 71 | } 72 | if local.R.{{$relAlias.Local}} == nil { 73 | t.Error("struct should have been eager loaded") 74 | } 75 | 76 | {{if not $.NoHooks -}} 77 | if !ranAfterSelectHook { 78 | t.Error("failed to run AfterSelect hook for relationship") 79 | } 80 | {{- end}} 81 | } 82 | 83 | {{end -}}{{/* range */}} 84 | {{- end -}}{{/* join table */}} 85 | -------------------------------------------------------------------------------- /templates/test/relationship_to_many.go.tpl: -------------------------------------------------------------------------------- 1 | {{- if .Table.IsJoinTable -}} 2 | {{- else -}} 3 | {{- $table := .Table }} 4 | {{- range $rel := .Table.ToManyRelationships -}} 5 | {{- $ltable := $.Aliases.Table $rel.Table -}} 6 | {{- $ftable := $.Aliases.Table $rel.ForeignTable -}} 7 | {{- $relAlias := $.Aliases.ManyRelationship $rel.ForeignTable $rel.Name $rel.JoinTable $rel.JoinLocalFKeyName -}} 8 | {{- $colField := $ltable.Column $rel.Column -}} 9 | {{- $fcolField := $ftable.Column $rel.ForeignColumn -}} 10 | {{- $usesPrimitives := usesPrimitives $.Tables $rel.Table $rel.Column $rel.ForeignTable $rel.ForeignColumn -}} 11 | {{- $schemaForeignTable := .ForeignTable | $.SchemaTable }} 12 | func test{{$ltable.UpSingular}}ToMany{{$relAlias.Local}}(t *testing.T) { 13 | var err error 14 | {{if not $.NoContext}}ctx := context.Background(){{end}} 15 | tx := MustTx({{if
$.NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}}) 16 | defer func() { _ = tx.Rollback() }() 17 | 18 | var a {{$ltable.UpSingular}} 19 | var b, c {{$ftable.UpSingular}} 20 | 21 | seed := randomize.NewSeed() 22 | if err = randomize.Struct(seed, &a, {{$ltable.DownSingular}}DBTypes, true, {{$ltable.DownSingular}}ColumnsWithDefault...); err != nil { 23 | t.Errorf("Unable to randomize {{$ltable.UpSingular}} struct: %s", err) 24 | } 25 | 26 | if err = a.Insert({{if not $.NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 27 | t.Fatal(err) 28 | } 29 | 30 | if err = randomize.Struct(seed, &b, {{$ftable.DownSingular}}DBTypes, false, {{$ftable.DownSingular}}ColumnsWithDefault...); err != nil { 31 | t.Fatal(err) 32 | } 33 | if err = randomize.Struct(seed, &c, {{$ftable.DownSingular}}DBTypes, false, {{$ftable.DownSingular}}ColumnsWithDefault...); err != nil { 34 | t.Fatal(err) 35 | } 36 | 37 | {{if not .ToJoinTable -}} 38 | {{if $usesPrimitives}} 39 | b.{{$fcolField}} = a.{{$colField}} 40 | c.{{$fcolField}} = a.{{$colField}} 41 | {{else -}} 42 | queries.Assign(&b.{{$fcolField}}, a.{{$colField}}) 43 | queries.Assign(&c.{{$fcolField}}, a.{{$colField}}) 44 | {{- end}} 45 | {{- end}} 46 | if err = b.Insert({{if not $.NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 47 | t.Fatal(err) 48 | } 49 | if err = c.Insert({{if not $.NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 50 | t.Fatal(err) 51 | } 52 | 53 | {{if .ToJoinTable -}} 54 | _, err = tx.Exec("insert into {{.JoinTable | $.SchemaTable}} ({{.JoinLocalColumn | $.Quotes}}, {{.JoinForeignColumn | $.Quotes}}) values {{if $.Dialect.UseIndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}", a.{{$colField}}, b.{{$fcolField}}) 55 | if err != nil { 56 | t.Fatal(err) 57 | } 58 | _, err = tx.Exec("insert into {{.JoinTable | $.SchemaTable}} ({{.JoinLocalColumn | $.Quotes}}, {{.JoinForeignColumn | $.Quotes}}) values {{if $.Dialect.UseIndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}", a.{{$colField}},
c.{{$fcolField}}) 59 | if err != nil { 60 | t.Fatal(err) 61 | } 62 | {{end}} 63 | 64 | check, err := a.{{$relAlias.Local}}().All({{if not $.NoContext}}ctx, {{end -}} tx) 65 | if err != nil { 66 | t.Fatal(err) 67 | } 68 | 69 | bFound, cFound := false, false 70 | for _, v := range check { 71 | {{if $usesPrimitives -}} 72 | if v.{{$fcolField}} == b.{{$fcolField}} { 73 | bFound = true 74 | } 75 | if v.{{$fcolField}} == c.{{$fcolField}} { 76 | cFound = true 77 | } 78 | {{else -}} 79 | if queries.Equal(v.{{$fcolField}}, b.{{$fcolField}}) { 80 | bFound = true 81 | } 82 | if queries.Equal(v.{{$fcolField}}, c.{{$fcolField}}) { 83 | cFound = true 84 | } 85 | {{end -}} 86 | } 87 | 88 | if !bFound { 89 | t.Error("expected to find b") 90 | } 91 | if !cFound { 92 | t.Error("expected to find c") 93 | } 94 | 95 | slice := {{$ltable.UpSingular}}Slice{&a} 96 | if err = a.L.Load{{$relAlias.Local}}({{if not $.NoContext}}ctx, {{end -}} tx, false, (*[]*{{$ltable.UpSingular}})(&slice), nil); err != nil { 97 | t.Fatal(err) 98 | } 99 | if got := len(a.R.{{$relAlias.Local}}); got != 2 { 100 | t.Error("number of eager loaded records wrong, got:", got) 101 | } 102 | 103 | a.R.{{$relAlias.Local}} = nil 104 | if err = a.L.Load{{$relAlias.Local}}({{if not $.NoContext}}ctx, {{end -}} tx, true, &a, nil); err != nil { 105 | t.Fatal(err) 106 | } 107 | if got := len(a.R.{{$relAlias.Local}}); got != 2 { 108 | t.Error("number of eager loaded records wrong, got:", got) 109 | } 110 | 111 | if t.Failed() { 112 | t.Logf("%#v", check) 113 | } 114 | } 115 | 116 | {{end -}}{{- /* range */ -}} 117 | {{- end -}}{{- /* outer if join table */ -}} 118 | -------------------------------------------------------------------------------- /templates/test/relationship_to_one.go.tpl: -------------------------------------------------------------------------------- 1 | {{- if .Table.IsJoinTable -}} 2 | {{- else -}} 3 | {{- range $fkey := .Table.FKeys -}} 4 | {{- $ltable := $.Aliases.Table $fkey.Table -}} 5 | {{- $ftable :=
$.Aliases.Table $fkey.ForeignTable -}} 6 | {{- $rel := $ltable.Relationship $fkey.Name -}} 7 | {{- $colField := $ltable.Column $fkey.Column -}} 8 | {{- $fcolField := $ftable.Column $fkey.ForeignColumn -}} 9 | {{- $usesPrimitives := usesPrimitives $.Tables $fkey.Table $fkey.Column $fkey.ForeignTable $fkey.ForeignColumn }} 10 | func test{{$ltable.UpSingular}}ToOne{{$ftable.UpSingular}}Using{{$rel.Foreign}}(t *testing.T) { 11 | {{if not $.NoContext}}ctx := context.Background(){{end}} 12 | tx := MustTx({{if $.NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}}) 13 | defer func() { _ = tx.Rollback() }() 14 | 15 | var local {{$ltable.UpSingular}} 16 | var foreign {{$ftable.UpSingular}} 17 | 18 | seed := randomize.NewSeed() 19 | if err := randomize.Struct(seed, &local, {{$ltable.DownSingular}}DBTypes, {{if $fkey.Nullable}}true{{else}}false{{end}}, {{$ltable.DownSingular}}ColumnsWithDefault...); err != nil { 20 | t.Errorf("Unable to randomize {{$ltable.UpSingular}} struct: %s", err) 21 | } 22 | if err := randomize.Struct(seed, &foreign, {{$ftable.DownSingular}}DBTypes, {{if $fkey.ForeignColumnNullable}}true{{else}}false{{end}}, {{$ftable.DownSingular}}ColumnsWithDefault...); err != nil { 23 | t.Errorf("Unable to randomize {{$ftable.UpSingular}} struct: %s", err) 24 | } 25 | 26 | if err := foreign.Insert({{if not $.NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 27 | t.Fatal(err) 28 | } 29 | 30 | {{if $usesPrimitives -}} 31 | local.{{$colField}} = foreign.{{$fcolField}} 32 | {{else -}} 33 | queries.Assign(&local.{{$colField}}, foreign.{{$fcolField}}) 34 | {{end -}} 35 | if err := local.Insert({{if not $.NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 36 | t.Fatal(err) 37 | } 38 | 39 | check, err := local.{{$rel.Foreign}}().One({{if not $.NoContext}}ctx, {{end -}} tx) 40 | if err != nil { 41 | t.Fatal(err) 42 | } 43 | 44 | {{if $usesPrimitives -}} 45 | if check.{{$fcolField}} != foreign.{{$fcolField}} { 46 | {{else -}} 47 | if
!queries.Equal(check.{{$fcolField}}, foreign.{{$fcolField}}) { 48 | {{end -}} 49 | t.Errorf("want: %v, got %v", foreign.{{$fcolField}}, check.{{$fcolField}}) 50 | } 51 | 52 | {{if not $.NoHooks -}} 53 | ranAfterSelectHook := false 54 | Add{{$ftable.UpSingular}}Hook(boil.AfterSelectHook, func({{if not $.NoContext}}ctx context.Context, e boil.ContextExecutor{{else}}e boil.Executor{{end}}, o *{{$ftable.UpSingular}}) error { 55 | ranAfterSelectHook = true 56 | return nil 57 | }) 58 | {{- end}} 59 | 60 | slice := {{$ltable.UpSingular}}Slice{&local} 61 | if err = local.L.Load{{$rel.Foreign}}({{if not $.NoContext}}ctx, {{end -}} tx, false, (*[]*{{$ltable.UpSingular}})(&slice), nil); err != nil { 62 | t.Fatal(err) 63 | } 64 | if local.R.{{$rel.Foreign}} == nil { 65 | t.Error("struct should have been eager loaded") 66 | } 67 | 68 | local.R.{{$rel.Foreign}} = nil 69 | if err = local.L.Load{{$rel.Foreign}}({{if not $.NoContext}}ctx, {{end -}} tx, true, &local, nil); err != nil { 70 | t.Fatal(err) 71 | } 72 | if local.R.{{$rel.Foreign}} == nil { 73 | t.Error("struct should have been eager loaded") 74 | } 75 | 76 | {{if not $.NoHooks -}} 77 | if !ranAfterSelectHook { 78 | t.Error("failed to run AfterSelect hook for relationship") 79 | } 80 | {{- end}} 81 | } 82 | 83 | {{end -}}{{/* range */}} 84 | {{- end -}}{{/* join table */}} 85 | -------------------------------------------------------------------------------- /templates/test/reload.go.tpl: -------------------------------------------------------------------------------- 1 | {{- $alias := .Aliases.Table .Table.Name}} 2 | func test{{$alias.UpPlural}}Reload(t *testing.T) { 3 | t.Parallel() 4 | 5 | seed := randomize.NewSeed() 6 | var err error 7 | o := &{{$alias.UpSingular}}{} 8 | if err = randomize.Struct(seed, o, {{$alias.DownSingular}}DBTypes, true, {{$alias.DownSingular}}ColumnsWithDefault...); err != nil { 9 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 10 | } 11 | 12 | {{if not .NoContext}}ctx :=
context.Background(){{end}} 13 | tx := MustTx({{if .NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}}) 14 | defer func() { _ = tx.Rollback() }() 15 | if err = o.Insert({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 16 | t.Error(err) 17 | } 18 | 19 | if err = o.Reload({{if not .NoContext}}ctx, {{end -}} tx); err != nil { 20 | t.Error(err) 21 | } 22 | } 23 | 24 | func test{{$alias.UpPlural}}ReloadAll(t *testing.T) { 25 | t.Parallel() 26 | 27 | seed := randomize.NewSeed() 28 | var err error 29 | o := &{{$alias.UpSingular}}{} 30 | if err = randomize.Struct(seed, o, {{$alias.DownSingular}}DBTypes, true, {{$alias.DownSingular}}ColumnsWithDefault...); err != nil { 31 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 32 | } 33 | 34 | {{if not .NoContext}}ctx := context.Background(){{end}} 35 | tx := MustTx({{if .NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}}) 36 | defer func() { _ = tx.Rollback() }() 37 | if err = o.Insert({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 38 | t.Error(err) 39 | } 40 | 41 | slice := {{$alias.UpSingular}}Slice{{"{"}}o{{"}"}} 42 | 43 | if err = slice.ReloadAll({{if not .NoContext}}ctx, {{end -}} tx); err != nil { 44 | t.Error(err) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /templates/test/select.go.tpl: -------------------------------------------------------------------------------- 1 | {{- $alias := .Aliases.Table .Table.Name}} 2 | func test{{$alias.UpPlural}}Select(t *testing.T) { 3 | t.Parallel() 4 | 5 | seed := randomize.NewSeed() 6 | var err error 7 | o := &{{$alias.UpSingular}}{} 8 | if err = randomize.Struct(seed, o, {{$alias.DownSingular}}DBTypes, true, {{$alias.DownSingular}}ColumnsWithDefault...); err != nil { 9 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 10 | } 11 | 12 | {{if not .NoContext}}ctx := context.Background(){{end}} 13 | tx := MustTx({{if
.NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}}) 14 | defer func() { _ = tx.Rollback() }() 15 | if err = o.Insert({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 16 | t.Error(err) 17 | } 18 | 19 | slice, err := {{$alias.UpPlural}}().All({{if not .NoContext}}ctx, {{end -}} tx) 20 | if err != nil { 21 | t.Error(err) 22 | } 23 | 24 | if len(slice) != 1 { 25 | t.Error("want one record, got:", len(slice)) 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /templates/test/singleton/boil_main_test.go.tpl: -------------------------------------------------------------------------------- 1 | var flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements") 2 | var flagConfigFile = flag.String("test.config", "", "Overrides the default config") 3 | 4 | const outputDirDepth = {{.OutputDirDepth}} 5 | 6 | var ( 7 | dbMain tester 8 | ) 9 | 10 | type tester interface { 11 | setup() error 12 | conn() (*sql.DB, error) 13 | teardown() error 14 | } 15 | 16 | func TestMain(m *testing.M) { 17 | if dbMain == nil { 18 | fmt.Println("no dbMain tester interface was ready") 19 | os.Exit(-1) 20 | } 21 | 22 | rand.New(rand.NewSource(time.Now().UnixNano())) // NOTE: the returned *rand.Rand is discarded, so this seeds nothing; the global math/rand source is seeded automatically since Go 1.20. 23 | 24 | flag.Parse() 25 | 26 | var err error 27 | 28 | // Load configuration 29 | err = initViper() 30 | if err != nil { 31 | fmt.Println("unable to load config file:", err) 32 | os.Exit(-2) 33 | } 34 | 35 | // Set DebugMode so we can see generated sql statements 36 | boil.DebugMode = *flagDebugMode 37 | 38 | if err = dbMain.setup(); err != nil { 39 | fmt.Println("Unable to execute setup:", err) 40 | os.Exit(-4) 41 | } 42 | 43 | code := -3 // exit code used when no connection could be obtained 44 | conn, err := dbMain.conn() 45 | if err != nil { 46 | fmt.Println("failed to get connection:", err) 47 | } else { 48 | boil.SetDB(conn) 49 | code = m.Run() 50 | } 51 | // Setup succeeded, so always tear down, even when no connection was obtained and the suite was skipped. 52 | if err = dbMain.teardown(); err != nil { 53 | fmt.Println("Unable to execute teardown:", err) 54 | os.Exit(-5) 55 | } 56 | 57 |
os.Exit(code) 58 | } 59 | 60 | func initViper() error { 61 | if flagConfigFile != nil && *flagConfigFile != "" { 62 | viper.SetConfigFile(*flagConfigFile) 63 | if err := viper.ReadInConfig(); err != nil { 64 | return err 65 | } 66 | return nil 67 | } 68 | 69 | var err error 70 | 71 | viper.SetConfigName("sqlboiler") 72 | 73 | configHome := os.Getenv("XDG_CONFIG_HOME") 74 | homePath := os.Getenv("HOME") 75 | wd, err := os.Getwd() 76 | if err != nil { 77 | wd = strings.Repeat("../", outputDirDepth) 78 | } else { 79 | wd = wd + strings.Repeat("/..", outputDirDepth) 80 | } 81 | 82 | configPaths := []string{wd} 83 | if len(configHome) > 0 { 84 | configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler")) 85 | } else { 86 | configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler")) 87 | } 88 | 89 | for _, p := range configPaths { 90 | viper.AddConfigPath(p) 91 | } 92 | 93 | // Ignore errors here, fall back to defaults and validation to provide errs 94 | _ = viper.ReadInConfig() 95 | viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) 96 | viper.AutomaticEnv() 97 | 98 | return nil 99 | } 100 | -------------------------------------------------------------------------------- /templates/test/singleton/boil_queries_test.go.tpl: -------------------------------------------------------------------------------- 1 | var dbNameRand *rand.Rand 2 | 3 | {{if .NoContext -}} 4 | func MustTx(transactor boil.Transactor, err error) boil.Transactor { 5 | if err != nil { 6 | panic(fmt.Sprintf("Cannot create a transactor: %s", err)) 7 | } 8 | return transactor 9 | } 10 | {{- else -}} 11 | func MustTx(transactor boil.ContextTransactor, err error) boil.ContextTransactor { 12 | if err != nil { 13 | panic(fmt.Sprintf("Cannot create a transactor: %s", err)) 14 | } 15 | return transactor 16 | } 17 | {{- end}} 18 | 19 | func newFKeyDestroyer(regex *regexp.Regexp, reader io.Reader) io.Reader { 20 | return &fKeyDestroyer{ 21 | reader: reader, 22 | rgx: regex, 23 |
} 24 | } 25 | 26 | type fKeyDestroyer struct { 27 | reader io.Reader 28 | buf *bytes.Buffer 29 | rgx *regexp.Regexp 30 | } 31 | 32 | func (f *fKeyDestroyer) Read(b []byte) (int, error) { 33 | if f.buf == nil { 34 | all, err := io.ReadAll(f.reader) 35 | if err != nil { 36 | return 0, err 37 | } 38 | // Normalize CRLF to LF, then strip every regex match; the result is buffered once and served by this and later Reads. 39 | all = bytes.ReplaceAll(all, []byte("\r\n"), []byte("\n")) 40 | all = f.rgx.ReplaceAll(all, []byte{}) 41 | f.buf = bytes.NewBuffer(all) 42 | } 43 | 44 | return f.buf.Read(b) 45 | } 46 | 47 | -------------------------------------------------------------------------------- /templates/test/types.go.tpl: -------------------------------------------------------------------------------- 1 | {{- $alias := .Aliases.Table .Table.Name}} 2 | var ( 3 | {{$alias.DownSingular}}DBTypes = map[string]string{{"{"}}{{range $i, $col := .Table.Columns -}}{{- if ne $i 0}},{{end}}`{{$alias.Column $col.Name}}`: `{{$col.DBType}}`{{end}}{{"}"}} 4 | _ = bytes.MinRead 5 | ) 6 | -------------------------------------------------------------------------------- /templates/test/update.go.tpl: -------------------------------------------------------------------------------- 1 | {{- $alias := .Aliases.Table .Table.Name}} 2 | func test{{$alias.UpPlural}}Update(t *testing.T) { 3 | t.Parallel() 4 | 5 | if len({{$alias.DownSingular}}PrimaryKeyColumns) == 0 { 6 | t.Skip("Skipping table with no primary key columns") 7 | } 8 | if len({{$alias.DownSingular}}AllColumns) == len({{$alias.DownSingular}}PrimaryKeyColumns) { 9 | t.Skip("Skipping table with only primary key columns") 10 | } 11 | 12 | seed := randomize.NewSeed() 13 | var err error 14 | o := &{{$alias.UpSingular}}{} 15 | if err = randomize.Struct(seed, o, {{$alias.DownSingular}}DBTypes, true, {{$alias.DownSingular}}ColumnsWithDefault...); err != nil { 16 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 17 | } 18 | 19 | {{if not .NoContext}}ctx := context.Background(){{end}} 20 | tx := MustTx({{if
.NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}}) 21 | defer func() { _ = tx.Rollback() }() 22 | if err = o.Insert({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 23 | t.Error(err) 24 | } 25 | 26 | count, err := {{$alias.UpPlural}}().Count({{if not .NoContext}}ctx, {{end -}} tx) 27 | if err != nil { 28 | t.Error(err) 29 | } 30 | 31 | if count != 1 { 32 | t.Error("want one record, got:", count) 33 | } 34 | 35 | if err = randomize.Struct(seed, o, {{$alias.DownSingular}}DBTypes, true, {{$alias.DownSingular}}PrimaryKeyColumns...); err != nil { 36 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 37 | } 38 | 39 | {{if .NoRowsAffected -}} 40 | if err = o.Update({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 41 | t.Error(err) 42 | } 43 | {{else -}} 44 | if rowsAff, err := o.Update({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 45 | t.Error(err) 46 | } else if rowsAff != 1 { 47 | t.Error("should only affect one row but affected", rowsAff) 48 | } 49 | {{end -}} 50 | } 51 | 52 | func test{{$alias.UpPlural}}SliceUpdateAll(t *testing.T) { 53 | t.Parallel() 54 | 55 | if len({{$alias.DownSingular}}AllColumns) == len({{$alias.DownSingular}}PrimaryKeyColumns) { 56 | t.Skip("Skipping table with only primary key columns") 57 | } 58 | 59 | seed := randomize.NewSeed() 60 | var err error 61 | o := &{{$alias.UpSingular}}{} 62 | if err = randomize.Struct(seed, o, {{$alias.DownSingular}}DBTypes, true, {{$alias.DownSingular}}ColumnsWithDefault...); err != nil { 63 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 64 | } 65 | 66 | {{if not .NoContext}}ctx := context.Background(){{end}} 67 | tx := MustTx({{if .NoContext}}boil.Begin(){{else}}boil.BeginTx(ctx, nil){{end}}) 68 | defer func() { _ = tx.Rollback() }() 69 | if err = o.Insert({{if not .NoContext}}ctx, {{end -}} tx, boil.Infer()); err != nil { 70 | t.Error(err) 71 | } 72 | 73 | count, err :=
{{$alias.UpPlural}}().Count({{if not .NoContext}}ctx, {{end -}} tx) 74 | if err != nil { 75 | t.Error(err) 76 | } 77 | 78 | if count != 1 { 79 | t.Error("want one record, got:", count) 80 | } 81 | 82 | if err = randomize.Struct(seed, o, {{$alias.DownSingular}}DBTypes, true, {{$alias.DownSingular}}PrimaryKeyColumns...); err != nil { 83 | t.Errorf("Unable to randomize {{$alias.UpSingular}} struct: %s", err) 84 | } 85 | 86 | // Remove primary key columns (and, when present, generated columns) from what we plan to update 87 | var fields []string 88 | if strmangle.StringSliceMatch({{$alias.DownSingular}}AllColumns, {{$alias.DownSingular}}PrimaryKeyColumns) { 89 | fields = {{$alias.DownSingular}}AllColumns 90 | } else { 91 | fields = strmangle.SetComplement( 92 | {{$alias.DownSingular}}AllColumns, 93 | {{$alias.DownSingular}}PrimaryKeyColumns, 94 | ) 95 | {{- if filterColumnsByAuto true .Table.Columns }} 96 | fields = strmangle.SetComplement(fields, {{$alias.DownSingular}}GeneratedColumns) 97 | {{- end}} 98 | } 99 | 100 | value := reflect.Indirect(reflect.ValueOf(o)) 101 | typ := reflect.TypeOf(o).Elem() 102 | n := typ.NumField() 103 | 104 | updateMap := M{} 105 | for _, col := range fields { 106 | for i := 0; i < n; i++ { 107 | f := typ.Field(i) 108 | if f.Tag.Get("boil") == col { 109 | updateMap[col] = value.Field(i).Interface() 110 | } 111 | } 112 | } 113 | 114 | slice := {{$alias.UpSingular}}Slice{{"{"}}o{{"}"}} 115 | {{if .NoRowsAffected -}} 116 | if err = slice.UpdateAll({{if not .NoContext}}ctx, {{end -}} tx, updateMap); err != nil { 117 | t.Error(err) 118 | } 119 | {{else -}} 120 | if rowsAff, err := slice.UpdateAll({{if not .NoContext}}ctx, {{end -}} tx, updateMap); err != nil { 121 | t.Error(err) 122 | } else if rowsAff != 1 { 123 | t.Error("wanted one record updated but got", rowsAff) 124 | } 125 | {{end -}} 126 | } 127 | -------------------------------------------------------------------------------- /testdata/Dockerfile:
--------------------------------------------------------------------------------
# This Dockerfile builds the image used for CI/testing.
# NOTE(review): ubuntu:16.04, Go 1.10, postgresql-client-9.6 and
# mysql-client-5.7 are all old releases; upgrading them is a separate change.
FROM ubuntu:16.04

# Set PATH
# /usr/local/go/bin is where the Go tarball below is unpacked;
# /opt/mssql-tools/bin is presumably where the mssql-tools package puts sqlcmd.
ENV PATH /usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin:/opt/mssql-tools/bin

# Install bootstrap-y tools
# (the git-core PPA is added to get a newer git than the distro ships)
RUN apt-get update \
    && apt-get install -y apt-transport-https software-properties-common python3-software-properties \
    && apt-add-repository ppa:git-core/ppa \
    && apt-get update \
    && apt-get install -y curl git make locales

# Set up locales for sqlcmd (otherwise it breaks)
RUN locale-gen en_US.UTF-8 \
    && echo "LC_ALL=en_US.UTF-8" >> /etc/default/locale \
    && echo "LANG=en_US.UTF-8" >> /etc/default/locale

# Install database clients
# Historical note: when this was written MySQL 8.0 was still in development,
# so the 5.7 client already available in Ubuntu 16.04 is used.
# The PostgreSQL and Microsoft apt repositories are registered first so that
# postgresql-client-9.6 and mssql-tools can be installed; ACCEPT_EULA=Y
# accepts the Microsoft license non-interactively.
RUN curl https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add - \
    && echo 'deb http://apt.postgresql.org/pub/repos/apt/ xenial-pgdg main' > /etc/apt/sources.list.d/psql.list \
    && curl https://packages.microsoft.com/keys/microsoft.asc | apt-key add - \
    && curl https://packages.microsoft.com/config/ubuntu/16.04/prod.list > /etc/apt/sources.list.d/msprod.list \
    && apt-get update \
    && env ACCEPT_EULA=Y apt-get install -y git postgresql-client-9.6 mysql-client-5.7 mssql-tools unixodbc-dev

# Install Go + Go based tooling
# The tarball replaces /usr/local/go (already on PATH). `go get` builds
# go-junit-report into root's GOPATH bin (/root/go/bin), from where it is
# moved onto the default PATH.
ENV GOLANG_VERSION 1.10
RUN curl -o go.tar.gz "https://storage.googleapis.com/golang/go${GOLANG_VERSION}.linux-amd64.tar.gz" \
    && rm -rf /usr/local/go \
    && tar -C /usr/local -xzf go.tar.gz \
    && go get -u -v github.com/jstemmer/go-junit-report \
    && mv /root/go/bin/go-junit-report /usr/bin/go-junit-report

--------------------------------------------------------------------------------
/types/byte.go:
-------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "database/sql/driver" 5 | "encoding/json" 6 | "errors" 7 | ) 8 | 9 | // Byte is an alias for byte. 10 | // Byte implements Marshal and Unmarshal. 11 | type Byte byte 12 | 13 | // String output your byte. 14 | func (b Byte) String() string { 15 | return string(b) 16 | } 17 | 18 | // UnmarshalJSON sets *b to a copy of data. 19 | func (b *Byte) UnmarshalJSON(data []byte) error { 20 | if b == nil { 21 | return errors.New("json: unmarshal json on nil pointer to byte") 22 | } 23 | 24 | var x string 25 | if err := json.Unmarshal(data, &x); err != nil { 26 | return err 27 | } 28 | 29 | if len(x) > 1 { 30 | return errors.New("json: cannot convert to byte, text len is greater than one") 31 | } 32 | 33 | *b = Byte(x[0]) 34 | return nil 35 | } 36 | 37 | // MarshalJSON returns the JSON encoding of b. 38 | func (b Byte) MarshalJSON() ([]byte, error) { 39 | return []byte{'"', byte(b), '"'}, nil 40 | } 41 | 42 | // Value returns b as a driver.Value. 43 | func (b Byte) Value() (driver.Value, error) { 44 | return []byte{byte(b)}, nil 45 | } 46 | 47 | // Scan stores the src in *b. 
48 | func (b *Byte) Scan(src interface{}) error { 49 | switch src.(type) { 50 | case uint8: 51 | *b = Byte(src.(uint8)) 52 | case string: 53 | *b = Byte(src.(string)[0]) 54 | case []byte: 55 | *b = Byte(src.([]byte)[0]) 56 | default: 57 | return errors.New("incompatible type for byte") 58 | } 59 | 60 | return nil 61 | } 62 | 63 | // Randomize for sqlboiler 64 | func (b *Byte) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 65 | if shouldBeNull { 66 | *b = Byte(65) // Can't deal with a true 0-value 67 | } 68 | 69 | *b = Byte(nextInt()%60 + 65) // Can't deal with non-ascii characters in some databases 70 | } 71 | -------------------------------------------------------------------------------- /types/byte_test.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "testing" 7 | ) 8 | 9 | func TestByteString(t *testing.T) { 10 | t.Parallel() 11 | 12 | b := Byte('b') 13 | if b.String() != "b" { 14 | t.Errorf("Expected %q, got %s", "b", b.String()) 15 | } 16 | } 17 | 18 | func TestByteUnmarshal(t *testing.T) { 19 | t.Parallel() 20 | 21 | var b Byte 22 | err := json.Unmarshal([]byte(`"b"`), &b) 23 | if err != nil { 24 | t.Error(err) 25 | } 26 | 27 | if b != 'b' { 28 | t.Errorf("Expected %q, got %s", "b", b) 29 | } 30 | } 31 | 32 | func TestByteMarshal(t *testing.T) { 33 | t.Parallel() 34 | 35 | b := Byte('b') 36 | res, err := json.Marshal(&b) 37 | if err != nil { 38 | t.Error(err) 39 | } 40 | 41 | if !bytes.Equal(res, []byte(`"b"`)) { 42 | t.Errorf("expected %s, got %s", `"b"`, b.String()) 43 | } 44 | } 45 | 46 | func TestByteValue(t *testing.T) { 47 | t.Parallel() 48 | 49 | b := Byte('b') 50 | v, err := b.Value() 51 | if err != nil { 52 | t.Error(err) 53 | } 54 | 55 | if !bytes.Equal([]byte{byte(b)}, v.([]byte)) { 56 | t.Errorf("byte mismatch, %v %v", b, v) 57 | } 58 | } 59 | 60 | func TestByteScan(t *testing.T) { 61 | t.Parallel() 62 | 63 | var b 
Byte 64 | 65 | s := "b" 66 | err := b.Scan(s) 67 | if err != nil { 68 | t.Error(err) 69 | } 70 | 71 | if !bytes.Equal([]byte{byte(b)}, []byte{'b'}) { 72 | t.Errorf("bad []byte: %#v ≠ %#v\n", b, "b") 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /types/hstore.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany. MIT license. 2 | // 3 | // Permission is hereby granted, free of charge, to any person obtaining 4 | // a copy of this software and associated documentation files (the "Software"), 5 | // to deal in the Software without restriction, including without limitation the 6 | // rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | // copies of the Software, and to permit persons to whom the Software 8 | // is furnished to do so, subject to the following conditions: 9 | // 10 | // The above copyright notice and this permission notice shall be included 11 | // in all copies or substantial portions of the Software. 12 | // 13 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | // INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 15 | // PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 16 | // HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 17 | // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 18 | // SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | package types 21 | 22 | import ( 23 | "database/sql" 24 | "database/sql/driver" 25 | "strings" 26 | 27 | "github.com/volatiletech/null/v8" 28 | "github.com/volatiletech/randomize" 29 | ) 30 | 31 | // HStore is a wrapper for transferring HStore values back and forth easily. 
32 | type HStore map[string]null.String 33 | 34 | // escapes and quotes hstore keys/values 35 | // s should be a sql.NullString or string 36 | func hQuote(s interface{}) string { 37 | var str string 38 | switch v := s.(type) { 39 | case null.String: 40 | if !v.Valid { 41 | return "NULL" 42 | } 43 | str = v.String 44 | case sql.NullString: 45 | if !v.Valid { 46 | return "NULL" 47 | } 48 | str = v.String 49 | case string: 50 | str = v 51 | default: 52 | panic("not a string or sql.NullString") 53 | } 54 | 55 | str = strings.ReplaceAll(str, "\\", "\\\\") 56 | return `"` + strings.ReplaceAll(str, "\"", "\\\"") + `"` 57 | } 58 | 59 | // Scan implements the Scanner interface. 60 | // 61 | // Note h is reallocated before the scan to clear existing values. If the 62 | // hstore column's database value is NULL, then h is set to nil instead. 63 | func (h *HStore) Scan(value interface{}) error { 64 | if value == nil { 65 | h = nil 66 | return nil 67 | } 68 | *h = make(map[string]null.String) 69 | var b byte 70 | pair := [][]byte{{}, {}} 71 | pi := 0 72 | inQuote := false 73 | didQuote := false 74 | sawSlash := false 75 | bindex := 0 76 | for bindex, b = range value.([]byte) { 77 | if sawSlash { 78 | pair[pi] = append(pair[pi], b) 79 | sawSlash = false 80 | continue 81 | } 82 | 83 | switch b { 84 | case '\\': 85 | sawSlash = true 86 | continue 87 | case '"': 88 | inQuote = !inQuote 89 | if !didQuote { 90 | didQuote = true 91 | } 92 | continue 93 | default: 94 | if !inQuote { 95 | switch b { 96 | case ' ', '\t', '\n', '\r': 97 | continue 98 | case '=': 99 | continue 100 | case '>': 101 | pi = 1 102 | didQuote = false 103 | continue 104 | case ',': 105 | s := string(pair[1]) 106 | if !didQuote && len(s) == 4 && strings.EqualFold(s, "null") { 107 | (*h)[string(pair[0])] = null.String{String: "", Valid: false} 108 | } else { 109 | (*h)[string(pair[0])] = null.String{String: string(pair[1]), Valid: true} 110 | } 111 | pair[0] = []byte{} 112 | pair[1] = []byte{} 113 | pi = 0 114 | 
continue 115 | } 116 | } 117 | } 118 | pair[pi] = append(pair[pi], b) 119 | } 120 | if bindex > 0 { 121 | s := string(pair[1]) 122 | if !didQuote && len(s) == 4 && strings.EqualFold(s, "null") { 123 | (*h)[string(pair[0])] = null.String{String: "", Valid: false} 124 | } else { 125 | (*h)[string(pair[0])] = null.String{String: string(pair[1]), Valid: true} 126 | } 127 | } 128 | return nil 129 | } 130 | 131 | // Value implements the driver Valuer interface. Note if h is nil, the 132 | // database column value will be set to NULL. 133 | func (h HStore) Value() (driver.Value, error) { 134 | if h == nil { 135 | return nil, nil 136 | } 137 | parts := []string{} 138 | for key, val := range h { 139 | thispart := hQuote(key) + "=>" + hQuote(val) 140 | parts = append(parts, thispart) 141 | } 142 | return []byte(strings.Join(parts, ",")), nil 143 | } 144 | 145 | // Randomize for sqlboiler 146 | func (h *HStore) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 147 | if shouldBeNull { 148 | *h = nil 149 | return 150 | } 151 | 152 | *h = make(map[string]null.String) 153 | (*h)[randomize.Str(nextInt, 3)] = null.String{String: randomize.Str(nextInt, 3), Valid: nextInt()%3 == 0} 154 | (*h)[randomize.Str(nextInt, 3)] = null.String{String: randomize.Str(nextInt, 3), Valid: nextInt()%3 == 0} 155 | } 156 | -------------------------------------------------------------------------------- /types/json.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "database/sql/driver" 5 | "encoding/json" 6 | "errors" 7 | 8 | "github.com/volatiletech/randomize" 9 | ) 10 | 11 | // JSON is an alias for json.RawMessage, which is 12 | // a []byte underneath. 13 | // JSON implements Marshal and Unmarshal. 14 | type JSON json.RawMessage 15 | 16 | // String output your JSON. 17 | func (j JSON) String() string { 18 | return string(j) 19 | } 20 | 21 | // Unmarshal your JSON variable into dest. 
22 | func (j JSON) Unmarshal(dest interface{}) error { 23 | return json.Unmarshal(j, dest) 24 | } 25 | 26 | // Marshal obj into your JSON variable. 27 | func (j *JSON) Marshal(obj interface{}) error { 28 | res, err := json.Marshal(obj) 29 | if err != nil { 30 | return err 31 | } 32 | 33 | *j = res 34 | return nil 35 | } 36 | 37 | // UnmarshalJSON sets *j to a copy of data. 38 | func (j *JSON) UnmarshalJSON(data []byte) error { 39 | if j == nil { 40 | return errors.New("json: unmarshal json on nil pointer to json") 41 | } 42 | 43 | *j = append((*j)[0:0], data...) 44 | return nil 45 | } 46 | 47 | // MarshalJSON returns j as the JSON encoding of j. 48 | func (j JSON) MarshalJSON() ([]byte, error) { 49 | if j == nil { 50 | return []byte("null"), nil 51 | } 52 | return j, nil 53 | } 54 | 55 | // Value returns j as a value. 56 | // Unmarshal into RawMessage for validation. 57 | func (j JSON) Value() (driver.Value, error) { 58 | var r json.RawMessage 59 | if err := j.Unmarshal(&r); err != nil { 60 | return nil, err 61 | } 62 | 63 | return []byte(r), nil 64 | } 65 | 66 | // Scan stores the src in *j. 67 | func (j *JSON) Scan(src interface{}) error { 68 | switch source := src.(type) { 69 | case string: 70 | *j = append((*j)[0:0], source...) 71 | return nil 72 | case []byte: 73 | *j = append((*j)[0:0], source...) 
74 | return nil 75 | default: 76 | return errors.New("incompatible type for json") 77 | } 78 | } 79 | 80 | // Randomize for sqlboiler 81 | func (j *JSON) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 82 | *j = []byte(`"` + randomize.Str(nextInt, 1) + `"`) 83 | } 84 | -------------------------------------------------------------------------------- /types/json_test.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "bytes" 5 | "strings" 6 | "testing" 7 | ) 8 | 9 | func TestJSONString(t *testing.T) { 10 | t.Parallel() 11 | 12 | j := JSON("hello") 13 | if j.String() != "hello" { 14 | t.Errorf("Expected %q, got %s", "hello", j.String()) 15 | } 16 | } 17 | 18 | func TestJSONUnmarshal(t *testing.T) { 19 | t.Parallel() 20 | 21 | type JSONTest struct { 22 | Name string 23 | Age int 24 | } 25 | var jt JSONTest 26 | 27 | j := JSON(`{"Name":"hi","Age":15}`) 28 | err := j.Unmarshal(&jt) 29 | if err != nil { 30 | t.Error(err) 31 | } 32 | 33 | if jt.Name != "hi" { 34 | t.Errorf("Expected %q, got %s", "hi", jt.Name) 35 | } 36 | if jt.Age != 15 { 37 | t.Errorf("Expected %v, got %v", 15, jt.Age) 38 | } 39 | } 40 | 41 | func TestJSONMarshal(t *testing.T) { 42 | t.Parallel() 43 | 44 | type JSONTest struct { 45 | Name string 46 | Age int 47 | } 48 | jt := JSONTest{ 49 | Name: "hi", 50 | Age: 15, 51 | } 52 | 53 | var j JSON 54 | err := j.Marshal(jt) 55 | if err != nil { 56 | t.Error(err) 57 | } 58 | 59 | if j.String() != `{"Name":"hi","Age":15}` { 60 | t.Errorf("expected %s, got %s", `{"Name":"hi","Age":15}`, j.String()) 61 | } 62 | } 63 | 64 | func TestJSONUnmarshalJSON(t *testing.T) { 65 | t.Parallel() 66 | 67 | j := JSON(nil) 68 | 69 | err := j.UnmarshalJSON(JSON(`"hi"`)) 70 | if err != nil { 71 | t.Error(err) 72 | } 73 | 74 | if j.String() != `"hi"` { 75 | t.Errorf("Expected %q, got %s", "hi", j.String()) 76 | } 77 | } 78 | 79 | func TestJSONMarshalJSON_Null(t *testing.T) { 80 | 
t.Parallel() 81 | 82 | var j JSON 83 | res, err := j.MarshalJSON() 84 | if err != nil { 85 | t.Error(err) 86 | } 87 | 88 | if !bytes.Equal(res, []byte(`null`)) { 89 | t.Errorf("Expected %q, got %v", `null`, res) 90 | } 91 | } 92 | 93 | func TestJSONMarshalJSON(t *testing.T) { 94 | t.Parallel() 95 | 96 | j := JSON(`"hi"`) 97 | res, err := j.MarshalJSON() 98 | if err != nil { 99 | t.Error(err) 100 | } 101 | 102 | if !bytes.Equal(res, []byte(`"hi"`)) { 103 | t.Errorf("Expected %q, got %v", `"hi"`, res) 104 | } 105 | } 106 | 107 | func TestJSONValue(t *testing.T) { 108 | t.Parallel() 109 | 110 | j := JSON(`{"Name":"hi","Age":15}`) 111 | v, err := j.Value() 112 | if err != nil { 113 | t.Error(err) 114 | } 115 | 116 | if !bytes.Equal(j, v.([]byte)) { 117 | t.Errorf("byte mismatch, %v %v", j, v) 118 | } 119 | } 120 | 121 | func TestJSONScan(t *testing.T) { 122 | t.Parallel() 123 | 124 | j := JSON{} 125 | 126 | err := j.Scan(`"hello"`) 127 | if err != nil { 128 | t.Error(err) 129 | } 130 | 131 | if !bytes.Equal(j, []byte(`"hello"`)) { 132 | t.Errorf("bad []byte: %#v ≠ %#v\n", j, string([]byte(`"hello"`))) 133 | } 134 | } 135 | 136 | func BenchmarkJSON_Scan(b *testing.B) { 137 | data := `"` + strings.Repeat("A", 1024) + `"` 138 | for i := 0; i < b.N; i++ { 139 | var j JSON 140 | err := j.Scan(data) 141 | if err != nil { 142 | b.Error(err) 143 | } 144 | } 145 | } 146 | -------------------------------------------------------------------------------- /types/pgeo/box.go: -------------------------------------------------------------------------------- 1 | package pgeo 2 | 3 | import ( 4 | "database/sql/driver" 5 | "errors" 6 | "fmt" 7 | ) 8 | 9 | // Box is represented by pairs of points that are opposite corners of the box. 
10 | type Box [2]Point 11 | 12 | // Value for the database 13 | func (b Box) Value() (driver.Value, error) { 14 | return valueBox(b) 15 | } 16 | 17 | // Scan from sql query 18 | func (b *Box) Scan(src interface{}) error { 19 | return scanBox(b, src) 20 | } 21 | 22 | func valueBox(b Box) (driver.Value, error) { 23 | return fmt.Sprintf(`(%s)`, formatPoints(b[:])), nil 24 | } 25 | 26 | func scanBox(b *Box, src interface{}) error { 27 | if src == nil { 28 | *b = NewBox(Point{}, Point{}) 29 | return nil 30 | } 31 | 32 | points, err := parsePointsSrc(src) 33 | if err != nil { 34 | return err 35 | } 36 | 37 | if len(points) != 2 { 38 | return errors.New("wrong box") 39 | } 40 | 41 | *b = NewBox(points[0], points[1]) 42 | 43 | return nil 44 | } 45 | 46 | func randBox(nextInt func() int64) Box { 47 | return Box([2]Point{randPoint(nextInt), randPoint(nextInt)}) 48 | } 49 | 50 | // Randomize for sqlboiler 51 | func (b *Box) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 52 | *b = randBox(nextInt) 53 | } 54 | -------------------------------------------------------------------------------- /types/pgeo/circle.go: -------------------------------------------------------------------------------- 1 | package pgeo 2 | 3 | import ( 4 | "database/sql/driver" 5 | "errors" 6 | "fmt" 7 | "strconv" 8 | "strings" 9 | ) 10 | 11 | // Circle is represented by a center point and radius. 
12 | type Circle struct { 13 | Point 14 | Radius float64 `json:"radius"` 15 | } 16 | 17 | // Value for the database 18 | func (c Circle) Value() (driver.Value, error) { 19 | return valueCircle(c) 20 | } 21 | 22 | // Scan from sql query 23 | func (c *Circle) Scan(src interface{}) error { 24 | return scanCircle(c, src) 25 | } 26 | 27 | func valueCircle(c Circle) (driver.Value, error) { 28 | return fmt.Sprintf(`<%s,%v>`, formatPoint(c.Point), c.Radius), nil 29 | } 30 | 31 | func scanCircle(c *Circle, src interface{}) error { 32 | if src == nil { 33 | *c = NewCircle(Point{}, 0) 34 | return nil 35 | } 36 | 37 | val, err := iToS(src) 38 | if err != nil { 39 | return err 40 | } 41 | 42 | points, err := parsePoints(val) 43 | if err != nil { 44 | return err 45 | } 46 | 47 | pdzs := strings.Split(val, "),") 48 | 49 | if len(points) != 1 || len(pdzs) != 2 { 50 | return errors.New("wrong circle") 51 | } 52 | 53 | r, err := strconv.ParseFloat(strings.Trim(pdzs[1], ">"), 64) 54 | if err != nil { 55 | return err 56 | } 57 | 58 | *c = NewCircle(points[0], r) 59 | 60 | return nil 61 | } 62 | 63 | func randCircle(nextInt func() int64) Circle { 64 | return Circle{randPoint(nextInt), newRandNum(nextInt)} 65 | } 66 | 67 | // Randomize for sqlboiler 68 | func (c *Circle) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 69 | *c = randCircle(nextInt) 70 | } 71 | -------------------------------------------------------------------------------- /types/pgeo/general.go: -------------------------------------------------------------------------------- 1 | package pgeo 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "reflect" 7 | "regexp" 8 | "strconv" 9 | "strings" 10 | ) 11 | 12 | func iToS(src interface{}) (string, error) { 13 | var val string 14 | var err error 15 | 16 | switch src.(type) { 17 | case string: 18 | val = src.(string) 19 | case []byte: 20 | val = string(src.([]byte)) 21 | default: 22 | err = fmt.Errorf("incompatible type %v", 
reflect.ValueOf(src).Kind().String()) 23 | } 24 | 25 | return val, err 26 | } 27 | 28 | func parseNums(s []string) ([]float64, error) { 29 | var pts = []float64{} 30 | for _, p := range s { 31 | pt, err := strconv.ParseFloat(p, 64) 32 | if err != nil { 33 | return pts, err 34 | } 35 | 36 | pts = append(pts, pt) 37 | } 38 | 39 | return pts, nil 40 | } 41 | 42 | func formatPoint(point Point) string { 43 | return fmt.Sprintf(`(%v,%v)`, point.X, point.Y) 44 | } 45 | 46 | func formatPoints(points []Point) string { 47 | var pts = []string{} 48 | for _, p := range points { 49 | pts = append(pts, formatPoint(p)) 50 | } 51 | return strings.Join(pts, ",") 52 | } 53 | 54 | var parsePointRegexp = regexp.MustCompile(`^\(([0-9\.Ee-]+),([0-9\.Ee-]+)\)$`) 55 | 56 | func parsePoint(pt string) (Point, error) { 57 | var point = Point{} 58 | var err error 59 | 60 | pdzs := parsePointRegexp.FindStringSubmatch(pt) 61 | if len(pdzs) != 3 { 62 | return point, errors.New("wrong point") 63 | } 64 | 65 | nums, err := parseNums(pdzs[1:3]) 66 | if err != nil { 67 | return point, err 68 | } 69 | 70 | point.X = nums[0] 71 | point.Y = nums[1] 72 | 73 | return point, nil 74 | } 75 | 76 | var parsePointsRegexp = regexp.MustCompile(`\(([0-9\.Ee-]+),([0-9\.Ee-]+)\)`) 77 | 78 | func parsePoints(pts string) ([]Point, error) { 79 | var points = []Point{} 80 | 81 | pdzs := parsePointsRegexp.FindAllString(pts, -1) 82 | for _, pt := range pdzs { 83 | point, err := parsePoint(pt) 84 | if err != nil { 85 | return points, err 86 | } 87 | 88 | points = append(points, point) 89 | } 90 | 91 | return points, nil 92 | } 93 | 94 | func parsePointsSrc(src interface{}) ([]Point, error) { 95 | val, err := iToS(src) 96 | if err != nil { 97 | return []Point{}, err 98 | } 99 | 100 | return parsePoints(val) 101 | } 102 | 103 | func newRandNum(nextInt func() int64) float64 { 104 | return float64(nextInt()) 105 | } 106 | -------------------------------------------------------------------------------- 
/types/pgeo/general_test.go: -------------------------------------------------------------------------------- 1 | package pgeo 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func BenchmarkParsePoint(b *testing.B) { 8 | pString := "(-13.735219157895635,-72.7159127785469)" 9 | for i := 0; i < 100000; i++ { 10 | _, err := parsePoint(pString) 11 | if err != nil { 12 | b.Error("parsePoint failed", err) 13 | } 14 | } 15 | } 16 | 17 | func BenchmarkParsePointScientificNotation(b *testing.B) { 18 | pString := "(-1.73521E-5,-2.7159127785469e-7)" 19 | for i := 0; i < 100000; i++ { 20 | _, err := parsePoint(pString) 21 | if err != nil { 22 | b.Error("parsePoint failed", err) 23 | } 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /types/pgeo/line.go: -------------------------------------------------------------------------------- 1 | package pgeo 2 | 3 | import ( 4 | "database/sql/driver" 5 | "errors" 6 | "fmt" 7 | "regexp" 8 | ) 9 | 10 | var parseLineRegexp = regexp.MustCompile(`^\{(-?[0-9]+(?:\.[0-9]+)?),(-?[0-9]+(?:\.[0-9]+)?),(-?[0-9]+(?:\.[0-9]+)?)\}$`) 11 | 12 | // Line represents a infinite line with the linear equation Ax + By + C = 0, where A and B are not both zero. 
13 | type Line struct { 14 | A float64 `json:"a"` 15 | B float64 `json:"b"` 16 | C float64 `json:"c"` 17 | } 18 | 19 | // Value for database 20 | func (l Line) Value() (driver.Value, error) { 21 | return valueLine(l) 22 | } 23 | 24 | // Scan from sql query 25 | func (l *Line) Scan(src interface{}) error { 26 | return scanLine(l, src) 27 | } 28 | 29 | func valueLine(l Line) (driver.Value, error) { 30 | return fmt.Sprintf(`{%[1]v,%[2]v,%[3]v}`, l.A, l.B, l.C), nil 31 | } 32 | 33 | func scanLine(l *Line, src interface{}) error { 34 | if src == nil { 35 | *l = NewLine(0, 0, 0) 36 | return nil 37 | } 38 | 39 | val, err := iToS(src) 40 | if err != nil { 41 | return err 42 | } 43 | 44 | pdzs := parseLineRegexp.FindStringSubmatch(val) 45 | if len(pdzs) != 4 { 46 | return errors.New("wrong line") 47 | } 48 | 49 | nums, err := parseNums(pdzs[1:4]) 50 | if err != nil { 51 | return err 52 | } 53 | 54 | *l = NewLine(nums[0], nums[1], nums[2]) 55 | 56 | return nil 57 | } 58 | 59 | func randLine(nextInt func() int64) Line { 60 | return Line{newRandNum(nextInt), newRandNum(nextInt), 0} 61 | } 62 | 63 | // Randomize for sqlboiler 64 | func (l *Line) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 65 | *l = randLine(nextInt) 66 | } 67 | -------------------------------------------------------------------------------- /types/pgeo/lseg.go: -------------------------------------------------------------------------------- 1 | package pgeo 2 | 3 | import ( 4 | "database/sql/driver" 5 | "errors" 6 | "fmt" 7 | ) 8 | 9 | // Lseg is a line segment and is represented by pairs of points that are the endpoints of the segment. 
10 | type Lseg [2]Point 11 | 12 | // Value for the database 13 | func (l Lseg) Value() (driver.Value, error) { 14 | return valueLseg(l) 15 | } 16 | 17 | // Scan from sql query 18 | func (l *Lseg) Scan(src interface{}) error { 19 | return scanLseg(l, src) 20 | } 21 | 22 | func valueLseg(l Lseg) (driver.Value, error) { 23 | return fmt.Sprintf(`[%s]`, formatPoints(l[:])), nil 24 | } 25 | 26 | func scanLseg(l *Lseg, src interface{}) error { 27 | if src == nil { 28 | *l = NewLseg(Point{}, Point{}) 29 | return nil 30 | } 31 | 32 | points, err := parsePointsSrc(src) 33 | if err != nil { 34 | return err 35 | } 36 | 37 | if len(points) != 2 { 38 | return errors.New("wrong lseg") 39 | } 40 | 41 | *l = NewLseg(points[0], points[1]) 42 | 43 | return nil 44 | } 45 | 46 | func randLseg(nextInt func() int64) Lseg { 47 | return Lseg([2]Point{randPoint(nextInt), randPoint(nextInt)}) 48 | } 49 | 50 | // Randomize for sqlboiler 51 | func (l *Lseg) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 52 | *l = randLseg(nextInt) 53 | } 54 | -------------------------------------------------------------------------------- /types/pgeo/main.go: -------------------------------------------------------------------------------- 1 | // Package pgeo implements geometric types for Postgres 2 | // 3 | // Geometric types: 4 | // https://www.postgresql.org/docs/current/static/datatype-geometric.html 5 | // 6 | // Description: 7 | // https://github.com/saulortega/pgeo 8 | package pgeo 9 | 10 | // NewPoint creates a point 11 | func NewPoint(X, Y float64) Point { 12 | return Point{X, Y} 13 | } 14 | 15 | // NewLine creates a line 16 | func NewLine(A, B, C float64) Line { 17 | return Line{A, B, C} 18 | } 19 | 20 | // NewLseg creates a line segment 21 | func NewLseg(A, B Point) Lseg { 22 | return Lseg([2]Point{A, B}) 23 | } 24 | 25 | // NewBox creates a box 26 | func NewBox(A, B Point) Box { 27 | return Box([2]Point{A, B}) 28 | } 29 | 30 | // NewPath creates a path 31 | func NewPath(P 
[]Point, C bool) Path { 32 | return Path{P, C} 33 | } 34 | 35 | // NewPolygon creates a polygon 36 | func NewPolygon(P []Point) Polygon { 37 | return Polygon(P) 38 | } 39 | 40 | // NewCircle creates a circle from a radius and a point 41 | func NewCircle(P Point, R float64) Circle { 42 | return Circle{P, R} 43 | } 44 | 45 | // NewNullPoint creates a point which can be null 46 | func NewNullPoint(P Point, v bool) NullPoint { 47 | return NullPoint{P, v} 48 | } 49 | 50 | // NewNullLine creates a line which can be null 51 | func NewNullLine(L Line, v bool) NullLine { 52 | return NullLine{L, v} 53 | } 54 | 55 | // NewNullLseg creates a line segment which can be null 56 | func NewNullLseg(L Lseg, v bool) NullLseg { 57 | return NullLseg{L, v} 58 | } 59 | 60 | // NewNullBox creates a box which can be null 61 | func NewNullBox(B Box, v bool) NullBox { 62 | return NullBox{B, v} 63 | } 64 | 65 | // NewNullPath creates a path which can be null 66 | func NewNullPath(P Path, v bool) NullPath { 67 | return NullPath{P, v} 68 | } 69 | 70 | // NewNullPolygon creates a polygon which can be null 71 | func NewNullPolygon(P Polygon, v bool) NullPolygon { 72 | return NullPolygon{P, v} 73 | } 74 | 75 | // NewNullCircle creates a circle which can be null 76 | func NewNullCircle(C Circle, v bool) NullCircle { 77 | return NullCircle{C, v} 78 | } 79 | -------------------------------------------------------------------------------- /types/pgeo/nullBox.go: -------------------------------------------------------------------------------- 1 | package pgeo 2 | 3 | import ( 4 | "database/sql/driver" 5 | ) 6 | 7 | // NullBox allows a box to be null 8 | type NullBox struct { 9 | Box 10 | Valid bool `json:"valid"` 11 | } 12 | 13 | // Value for the database 14 | func (b NullBox) Value() (driver.Value, error) { 15 | if !b.Valid { 16 | return nil, nil 17 | } 18 | 19 | return valueBox(b.Box) 20 | } 21 | 22 | // Scan from sql query 23 | func (b *NullBox) Scan(src interface{}) error { 24 | if src == nil { 25 
| b.Box, b.Valid = NewBox(Point{}, Point{}), false 26 | return nil 27 | } 28 | 29 | b.Valid = true 30 | return scanBox(&b.Box, src) 31 | } 32 | 33 | // Randomize for sqlboiler 34 | func (b *NullBox) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 35 | if shouldBeNull { 36 | b.Valid = false 37 | return 38 | } 39 | 40 | b.Valid = true 41 | b.Box = randBox(nextInt) 42 | } 43 | -------------------------------------------------------------------------------- /types/pgeo/nullCircle.go: -------------------------------------------------------------------------------- 1 | package pgeo 2 | 3 | import ( 4 | "database/sql/driver" 5 | ) 6 | 7 | // NullCircle allows circle to be null 8 | type NullCircle struct { 9 | Circle 10 | Valid bool `json:"valid"` 11 | } 12 | 13 | // Value for database 14 | func (c NullCircle) Value() (driver.Value, error) { 15 | if !c.Valid { 16 | return nil, nil 17 | } 18 | 19 | return valueCircle(c.Circle) 20 | } 21 | 22 | // Scan from sql query 23 | func (c *NullCircle) Scan(src interface{}) error { 24 | if src == nil { 25 | c.Circle, c.Valid = NewCircle(Point{}, 0), false 26 | return nil 27 | } 28 | 29 | c.Valid = true 30 | return scanCircle(&c.Circle, src) 31 | } 32 | 33 | // Randomize for sqlboiler 34 | func (c *NullCircle) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 35 | if shouldBeNull { 36 | c.Valid = false 37 | return 38 | } 39 | 40 | c.Valid = true 41 | c.Circle = randCircle(nextInt) 42 | } 43 | -------------------------------------------------------------------------------- /types/pgeo/nullLine.go: -------------------------------------------------------------------------------- 1 | package pgeo 2 | 3 | import ( 4 | "database/sql/driver" 5 | ) 6 | 7 | // NullLine allows line to be null 8 | type NullLine struct { 9 | Line 10 | Valid bool `json:"valid"` 11 | } 12 | 13 | // Value for database 14 | func (l NullLine) Value() (driver.Value, error) { 15 | if !l.Valid { 16 | return nil, nil 17 | } 18 | 
19 | return valueLine(l.Line) 20 | } 21 | 22 | // Scan from sql query 23 | func (l *NullLine) Scan(src interface{}) error { 24 | if src == nil { 25 | l.Line, l.Valid = NewLine(0, 0, 0), false 26 | return nil 27 | } 28 | 29 | l.Valid = true 30 | return scanLine(&l.Line, src) 31 | } 32 | 33 | // Randomize for sqlboiler 34 | func (l *NullLine) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 35 | if shouldBeNull { 36 | l.Valid = false 37 | return 38 | } 39 | 40 | l.Valid = true 41 | l.Line = randLine(nextInt) 42 | } 43 | -------------------------------------------------------------------------------- /types/pgeo/nullLseg.go: -------------------------------------------------------------------------------- 1 | package pgeo 2 | 3 | import ( 4 | "database/sql/driver" 5 | ) 6 | 7 | // NullLseg allows line segment to be null 8 | type NullLseg struct { 9 | Lseg 10 | Valid bool `json:"valid"` 11 | } 12 | 13 | // Value for database 14 | func (l NullLseg) Value() (driver.Value, error) { 15 | if !l.Valid { 16 | return nil, nil 17 | } 18 | 19 | return valueLseg(l.Lseg) 20 | } 21 | 22 | // Scan from sql query 23 | func (l *NullLseg) Scan(src interface{}) error { 24 | if src == nil { 25 | l.Lseg, l.Valid = NewLseg(Point{}, Point{}), false 26 | return nil 27 | } 28 | 29 | l.Valid = true 30 | return scanLseg(&l.Lseg, src) 31 | } 32 | 33 | // Randomize for sqlboiler 34 | func (l *NullLseg) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 35 | if shouldBeNull { 36 | l.Valid = false 37 | return 38 | } 39 | 40 | l.Valid = true 41 | l.Lseg = randLseg(nextInt) 42 | } 43 | -------------------------------------------------------------------------------- /types/pgeo/nullPath.go: -------------------------------------------------------------------------------- 1 | package pgeo 2 | 3 | import ( 4 | "database/sql/driver" 5 | ) 6 | 7 | // NullPath allows path to be null 8 | type NullPath struct { 9 | Path 10 | Valid bool `json:"valid"` 11 | } 12 | 13 | 
// Value for database 14 | func (p NullPath) Value() (driver.Value, error) { 15 | if !p.Valid { 16 | return nil, nil 17 | } 18 | 19 | return valuePath(p.Path) 20 | } 21 | 22 | // Scan from sql query 23 | func (p *NullPath) Scan(src interface{}) error { 24 | if src == nil { 25 | p.Path, p.Valid = NewPath([]Point{Point{}, Point{}}, false), false 26 | return nil 27 | } 28 | 29 | p.Valid = true 30 | return scanPath(&p.Path, src) 31 | } 32 | 33 | // Randomize for sqlboiler 34 | func (p *NullPath) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 35 | if shouldBeNull { 36 | p.Valid = false 37 | return 38 | } 39 | 40 | p.Valid = true 41 | p.Path = randPath(nextInt) 42 | } 43 | -------------------------------------------------------------------------------- /types/pgeo/nullPoint.go: -------------------------------------------------------------------------------- 1 | package pgeo 2 | 3 | import ( 4 | "database/sql/driver" 5 | ) 6 | 7 | // NullPoint allows point to be null 8 | type NullPoint struct { 9 | Point 10 | Valid bool `json:"valid"` 11 | } 12 | 13 | // Value for database 14 | func (p NullPoint) Value() (driver.Value, error) { 15 | if !p.Valid { 16 | return nil, nil 17 | } 18 | 19 | return valuePoint(p.Point) 20 | } 21 | 22 | // Scan from sql query 23 | func (p *NullPoint) Scan(src interface{}) error { 24 | if src == nil { 25 | p.Point, p.Valid = NewPoint(0, 0), false 26 | return nil 27 | } 28 | 29 | p.Valid = true 30 | return scanPoint(&p.Point, src) 31 | } 32 | 33 | // Randomize for sqlboiler 34 | func (p *NullPoint) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 35 | if shouldBeNull { 36 | p.Valid = false 37 | return 38 | } 39 | 40 | p.Valid = true 41 | p.Point = randPoint(nextInt) 42 | } 43 | -------------------------------------------------------------------------------- /types/pgeo/nullPolygon.go: -------------------------------------------------------------------------------- 1 | package pgeo 2 | 3 | import ( 4 | 
"database/sql/driver" 5 | ) 6 | 7 | // NullPolygon allows polygon to be null 8 | type NullPolygon struct { 9 | Polygon 10 | Valid bool `json:"valid"` 11 | } 12 | 13 | // Value for database 14 | func (p NullPolygon) Value() (driver.Value, error) { 15 | if !p.Valid { 16 | return nil, nil 17 | } 18 | 19 | return valuePolygon(p.Polygon) 20 | } 21 | 22 | // Scan from sql query 23 | func (p *NullPolygon) Scan(src interface{}) error { 24 | if src == nil { 25 | p.Polygon, p.Valid = NewPolygon([]Point{Point{}, Point{}, Point{}, Point{}}), false 26 | return nil 27 | } 28 | 29 | p.Valid = true 30 | return scanPolygon(&p.Polygon, src) 31 | } 32 | 33 | // Randomize for sqlboiler 34 | func (p *NullPolygon) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 35 | if shouldBeNull { 36 | p.Valid = false 37 | return 38 | } 39 | 40 | p.Valid = true 41 | p.Polygon = randPolygon(nextInt) 42 | } 43 | -------------------------------------------------------------------------------- /types/pgeo/path.go: -------------------------------------------------------------------------------- 1 | package pgeo 2 | 3 | import ( 4 | "database/sql/driver" 5 | "errors" 6 | "fmt" 7 | "regexp" 8 | ) 9 | 10 | var closedPathRegexp = regexp.MustCompile(`^\(\(`) 11 | 12 | // Path is represented by lists of connected points. 13 | // Paths can be open, where the first and last points in the list are considered not connected, 14 | // or closed, where the first and last points are considered connected. 
15 | type Path struct { 16 | Points []Point 17 | Closed bool `json:"closed"` 18 | } 19 | 20 | // Value for database 21 | func (p Path) Value() (driver.Value, error) { 22 | return valuePath(p) 23 | } 24 | 25 | // Scan from sql query 26 | func (p *Path) Scan(src interface{}) error { 27 | return scanPath(p, src) 28 | } 29 | 30 | func valuePath(p Path) (driver.Value, error) { 31 | var val string 32 | if p.Closed { 33 | val = fmt.Sprintf(`(%s)`, formatPoints(p.Points)) 34 | } else { 35 | val = fmt.Sprintf(`[%s]`, formatPoints(p.Points)) 36 | } 37 | return val, nil 38 | } 39 | 40 | func scanPath(p *Path, src interface{}) error { 41 | if src == nil { 42 | return nil 43 | } 44 | 45 | val, err := iToS(src) 46 | if err != nil { 47 | return err 48 | } 49 | 50 | (*p).Points, err = parsePoints(val) 51 | if err != nil { 52 | return err 53 | } 54 | 55 | if len((*p).Points) < 1 { 56 | return errors.New("wrong path") 57 | } 58 | 59 | (*p).Closed = closedPathRegexp.MatchString(val) 60 | 61 | return nil 62 | } 63 | 64 | func randPath(nextInt func() int64) Path { 65 | return Path{randPoints(nextInt, 3), newRandNum(nextInt) < 40} 66 | } 67 | 68 | // Randomize for sqlboiler 69 | func (p *Path) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 70 | *p = randPath(nextInt) 71 | } 72 | -------------------------------------------------------------------------------- /types/pgeo/point.go: -------------------------------------------------------------------------------- 1 | package pgeo 2 | 3 | import ( 4 | "database/sql/driver" 5 | ) 6 | 7 | // Point is the fundamental two-dimensional building block for geometric types. 
8 | // X and Y are the respective coordinates, as floating-point numbers 9 | type Point struct { 10 | X float64 `json:"x"` 11 | Y float64 `json:"y"` 12 | } 13 | 14 | // Value representation for database 15 | func (p Point) Value() (driver.Value, error) { 16 | return valuePoint(p) 17 | } 18 | 19 | // Scan from query 20 | func (p *Point) Scan(src interface{}) error { 21 | return scanPoint(p, src) 22 | } 23 | 24 | func valuePoint(p Point) (driver.Value, error) { 25 | return formatPoint(p), nil 26 | } 27 | 28 | func scanPoint(p *Point, src interface{}) error { 29 | if src == nil { 30 | *p = NewPoint(0, 0) 31 | return nil 32 | } 33 | 34 | val, err := iToS(src) 35 | if err != nil { 36 | return err 37 | } 38 | 39 | *p, err = parsePoint(val) 40 | if err != nil { 41 | return err 42 | } 43 | 44 | return nil 45 | 46 | } 47 | 48 | func randPoint(nextInt func() int64) Point { 49 | return Point{newRandNum(nextInt), newRandNum(nextInt)} 50 | } 51 | 52 | func randPoints(nextInt func() int64, n int) []Point { 53 | var points = []Point{} 54 | if n <= 0 { 55 | return points 56 | } 57 | 58 | for i := 0; i < n; i++ { 59 | points = append(points, randPoint(nextInt)) 60 | } 61 | 62 | return points 63 | } 64 | 65 | // Randomize for sqlboiler 66 | func (p *Point) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 67 | *p = randPoint(nextInt) 68 | } 69 | -------------------------------------------------------------------------------- /types/pgeo/polygon.go: -------------------------------------------------------------------------------- 1 | package pgeo 2 | 3 | import ( 4 | "database/sql/driver" 5 | "errors" 6 | "fmt" 7 | ) 8 | 9 | // Polygon is represented by lists of points (the vertexes of the polygon). 
10 | type Polygon []Point 11 | 12 | // Value for database 13 | func (p Polygon) Value() (driver.Value, error) { 14 | return valuePolygon(p) 15 | } 16 | 17 | // Scan from sql query 18 | func (p *Polygon) Scan(src interface{}) error { 19 | return scanPolygon(p, src) 20 | } 21 | 22 | func valuePolygon(p Polygon) (driver.Value, error) { 23 | return fmt.Sprintf(`(%s)`, formatPoints(p[:])), nil 24 | } 25 | 26 | func scanPolygon(p *Polygon, src interface{}) error { 27 | if src == nil { 28 | return nil 29 | } 30 | 31 | var err error 32 | *p, err = parsePointsSrc(src) 33 | if err != nil { 34 | return err 35 | } 36 | 37 | if len(*p) < 1 { 38 | return errors.New("wrong polygon") 39 | } 40 | 41 | return nil 42 | } 43 | 44 | func randPolygon(nextInt func() int64) Polygon { 45 | return Polygon(randPoints(nextInt, 3)) 46 | } 47 | 48 | // Randomize for sqlboiler 49 | func (p *Polygon) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 50 | *p = randPolygon(nextInt) 51 | } 52 | --------------------------------------------------------------------------------