├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── aggregate.go ├── aggregate_test.go ├── clause.go ├── clauses.go ├── clauses_test.go ├── column.go ├── column_test.go ├── combiner.go ├── combiner_test.go ├── compiler.go ├── compiler_test.go ├── conditional.go ├── conditional_test.go ├── constraint.go ├── constraint_test.go ├── delete.go ├── delete_test.go ├── dialect.go ├── dialect_default.go ├── dialect_test.go ├── dialects ├── mysql │ ├── README.md │ ├── errors.go │ ├── mysql.go │ ├── mysql_test.go │ └── tools │ │ └── generrors.go ├── postgres │ ├── postgres.go │ └── postgres_test.go └── sqlite │ ├── sqlite.go │ └── sqlite_test.go ├── docker-compose.yml ├── engine.go ├── engine_test.go ├── errors.go ├── errors_test.go ├── go.mod ├── go.sum ├── index.go ├── insert.go ├── insert_test.go ├── logger.go ├── logger_test.go ├── metadata.go ├── qb_logo_128.png ├── select.go ├── select_test.go ├── statement.go ├── statement_test.go ├── table.go ├── table_test.go ├── testutils_test.go ├── type.go ├── type_test.go ├── update.go ├── update_test.go ├── upsert.go ├── upsert_test.go ├── where.go └── where_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | 26 | .idea/* 27 | vendor/* 28 | qb.iml 29 | qb_test.db 30 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | sudo: required 4 | 5 | services: 6 | - mysql 7 | - postgresql 8 | 9 | addons: 10 | postgresql: 9.5 11 | 12 | go: 13 | - tip 14 | 15 | 
before_install: 16 | - go get github.com/mattn/goveralls 17 | 18 | install: 19 | - go get -v github.com/lib/pq 20 | - go get -v github.com/go-sql-driver/mysql 21 | - go get -v github.com/mattn/go-sqlite3 22 | - go get -t -v ./... 23 | 24 | script: 25 | - go test -v -covermode=count -coverprofile=coverage.out 26 | - go test -v -covermode=count -coverprofile=sqlite.out ./dialects/sqlite 27 | - go test -v -covermode=count -coverprofile=postgres.out ./dialects/postgres 28 | - go test -v -covermode=count -coverprofile=mysql.out ./dialects/mysql 29 | - tail --lines +2 sqlite.out >> coverage.out 30 | - tail --lines +2 postgres.out >> coverage.out 31 | - tail --lines +2 mysql.out >> coverage.out 32 | - $GOPATH/bin/goveralls -coverprofile=coverage.out -service=travis-ci -repotoken 0yIEy3NVX2lXn3KxYzHjkla7EWGjvmLAp 33 | 34 | before_script: 35 | - mysql -e 'create database IF NOT EXISTS qb_test;' 36 | - psql -U postgres -c 'CREATE DATABASE qb_test;' 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 2.1, February 1999 3 | 4 | Copyright (C) 1991, 1999 Free Software Foundation, Inc. 5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 6 | Everyone is permitted to copy and distribute verbatim copies 7 | of this license document, but changing it is not allowed. 8 | 9 | [This is the first released version of the Lesser GPL. It also counts 10 | as the successor of the GNU Library Public License, version 2, hence 11 | the version number 2.1.] 12 | 13 | Preamble 14 | 15 | The licenses for most software are designed to take away your 16 | freedom to share and change it. By contrast, the GNU General Public 17 | Licenses are intended to guarantee your freedom to share and change 18 | free software--to make sure the software is free for all its users. 
19 | 20 | This license, the Lesser General Public License, applies to some 21 | specially designated software packages--typically libraries--of the 22 | Free Software Foundation and other authors who decide to use it. You 23 | can use it too, but we suggest you first think carefully about whether 24 | this license or the ordinary General Public License is the better 25 | strategy to use in any particular case, based on the explanations below. 26 | 27 | When we speak of free software, we are referring to freedom of use, 28 | not price. Our General Public Licenses are designed to make sure that 29 | you have the freedom to distribute copies of free software (and charge 30 | for this service if you wish); that you receive source code or can get 31 | it if you want it; that you can change the software and use pieces of 32 | it in new free programs; and that you are informed that you can do 33 | these things. 34 | 35 | To protect your rights, we need to make restrictions that forbid 36 | distributors to deny you these rights or to ask you to surrender these 37 | rights. These restrictions translate to certain responsibilities for 38 | you if you distribute copies of the library or if you modify it. 39 | 40 | For example, if you distribute copies of the library, whether gratis 41 | or for a fee, you must give the recipients all the rights that we gave 42 | you. You must make sure that they, too, receive or can get the source 43 | code. If you link other code with the library, you must provide 44 | complete object files to the recipients, so that they can relink them 45 | with the library after making changes to the library and recompiling 46 | it. And you must show them these terms so they know their rights. 47 | 48 | We protect your rights with a two-step method: (1) we copyright the 49 | library, and (2) we offer you this license, which gives you legal 50 | permission to copy, distribute and/or modify the library. 
51 | 52 | To protect each distributor, we want to make it very clear that 53 | there is no warranty for the free library. Also, if the library is 54 | modified by someone else and passed on, the recipients should know 55 | that what they have is not the original version, so that the original 56 | author's reputation will not be affected by problems that might be 57 | introduced by others. 58 | 59 | Finally, software patents pose a constant threat to the existence of 60 | any free program. We wish to make sure that a company cannot 61 | effectively restrict the users of a free program by obtaining a 62 | restrictive license from a patent holder. Therefore, we insist that 63 | any patent license obtained for a version of the library must be 64 | consistent with the full freedom of use specified in this license. 65 | 66 | Most GNU software, including some libraries, is covered by the 67 | ordinary GNU General Public License. This license, the GNU Lesser 68 | General Public License, applies to certain designated libraries, and 69 | is quite different from the ordinary General Public License. We use 70 | this license for certain libraries in order to permit linking those 71 | libraries into non-free programs. 72 | 73 | When a program is linked with a library, whether statically or using 74 | a shared library, the combination of the two is legally speaking a 75 | combined work, a derivative of the original library. The ordinary 76 | General Public License therefore permits such linking only if the 77 | entire combination fits its criteria of freedom. The Lesser General 78 | Public License permits more lax criteria for linking other code with 79 | the library. 80 | 81 | We call this license the "Lesser" General Public License because it 82 | does Less to protect the user's freedom than the ordinary General 83 | Public License. It also provides other free software developers Less 84 | of an advantage over competing non-free programs. 
These disadvantages 85 | are the reason we use the ordinary General Public License for many 86 | libraries. However, the Lesser license provides advantages in certain 87 | special circumstances. 88 | 89 | For example, on rare occasions, there may be a special need to 90 | encourage the widest possible use of a certain library, so that it becomes 91 | a de-facto standard. To achieve this, non-free programs must be 92 | allowed to use the library. A more frequent case is that a free 93 | library does the same job as widely used non-free libraries. In this 94 | case, there is little to gain by limiting the free library to free 95 | software only, so we use the Lesser General Public License. 96 | 97 | In other cases, permission to use a particular library in non-free 98 | programs enables a greater number of people to use a large body of 99 | free software. For example, permission to use the GNU C Library in 100 | non-free programs enables many more people to use the whole GNU 101 | operating system, as well as its variant, the GNU/Linux operating 102 | system. 103 | 104 | Although the Lesser General Public License is Less protective of the 105 | users' freedom, it does ensure that the user of a program that is 106 | linked with the Library has the freedom and the wherewithal to run 107 | that program using a modified version of the Library. 108 | 109 | The precise terms and conditions for copying, distribution and 110 | modification follow. Pay close attention to the difference between a 111 | "work based on the library" and a "work that uses the library". The 112 | former contains code derived from the library, whereas the latter must 113 | be combined with the library in order to run. 114 | 115 | GNU LESSER GENERAL PUBLIC LICENSE 116 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 117 | 118 | 0. 
This License Agreement applies to any software library or other 119 | program which contains a notice placed by the copyright holder or 120 | other authorized party saying it may be distributed under the terms of 121 | this Lesser General Public License (also called "this License"). 122 | Each licensee is addressed as "you". 123 | 124 | A "library" means a collection of software functions and/or data 125 | prepared so as to be conveniently linked with application programs 126 | (which use some of those functions and data) to form executables. 127 | 128 | The "Library", below, refers to any such software library or work 129 | which has been distributed under these terms. A "work based on the 130 | Library" means either the Library or any derivative work under 131 | copyright law: that is to say, a work containing the Library or a 132 | portion of it, either verbatim or with modifications and/or translated 133 | straightforwardly into another language. (Hereinafter, translation is 134 | included without limitation in the term "modification".) 135 | 136 | "Source code" for a work means the preferred form of the work for 137 | making modifications to it. For a library, complete source code means 138 | all the source code for all modules it contains, plus any associated 139 | interface definition files, plus the scripts used to control compilation 140 | and installation of the library. 141 | 142 | Activities other than copying, distribution and modification are not 143 | covered by this License; they are outside its scope. The act of 144 | running a program using the Library is not restricted, and output from 145 | such a program is covered only if its contents constitute a work based 146 | on the Library (independent of the use of the Library in a tool for 147 | writing it). Whether that is true depends on what the Library does 148 | and what the program that uses the Library does. 149 | 150 | 1. 
You may copy and distribute verbatim copies of the Library's 151 | complete source code as you receive it, in any medium, provided that 152 | you conspicuously and appropriately publish on each copy an 153 | appropriate copyright notice and disclaimer of warranty; keep intact 154 | all the notices that refer to this License and to the absence of any 155 | warranty; and distribute a copy of this License along with the 156 | Library. 157 | 158 | You may charge a fee for the physical act of transferring a copy, 159 | and you may at your option offer warranty protection in exchange for a 160 | fee. 161 | 162 | 2. You may modify your copy or copies of the Library or any portion 163 | of it, thus forming a work based on the Library, and copy and 164 | distribute such modifications or work under the terms of Section 1 165 | above, provided that you also meet all of these conditions: 166 | 167 | a) The modified work must itself be a software library. 168 | 169 | b) You must cause the files modified to carry prominent notices 170 | stating that you changed the files and the date of any change. 171 | 172 | c) You must cause the whole of the work to be licensed at no 173 | charge to all third parties under the terms of this License. 174 | 175 | d) If a facility in the modified Library refers to a function or a 176 | table of data to be supplied by an application program that uses 177 | the facility, other than as an argument passed when the facility 178 | is invoked, then you must make a good faith effort to ensure that, 179 | in the event an application does not supply such function or 180 | table, the facility still operates, and performs whatever part of 181 | its purpose remains meaningful. 182 | 183 | (For example, a function in a library to compute square roots has 184 | a purpose that is entirely well-defined independent of the 185 | application. 
Therefore, Subsection 2d requires that any 186 | application-supplied function or table used by this function must 187 | be optional: if the application does not supply it, the square 188 | root function must still compute square roots.) 189 | 190 | These requirements apply to the modified work as a whole. If 191 | identifiable sections of that work are not derived from the Library, 192 | and can be reasonably considered independent and separate works in 193 | themselves, then this License, and its terms, do not apply to those 194 | sections when you distribute them as separate works. But when you 195 | distribute the same sections as part of a whole which is a work based 196 | on the Library, the distribution of the whole must be on the terms of 197 | this License, whose permissions for other licensees extend to the 198 | entire whole, and thus to each and every part regardless of who wrote 199 | it. 200 | 201 | Thus, it is not the intent of this section to claim rights or contest 202 | your rights to work written entirely by you; rather, the intent is to 203 | exercise the right to control the distribution of derivative or 204 | collective works based on the Library. 205 | 206 | In addition, mere aggregation of another work not based on the Library 207 | with the Library (or with a work based on the Library) on a volume of 208 | a storage or distribution medium does not bring the other work under 209 | the scope of this License. 210 | 211 | 3. You may opt to apply the terms of the ordinary GNU General Public 212 | License instead of this License to a given copy of the Library. To do 213 | this, you must alter all the notices that refer to this License, so 214 | that they refer to the ordinary GNU General Public License, version 2, 215 | instead of to this License. (If a newer version than version 2 of the 216 | ordinary GNU General Public License has appeared, then you can specify 217 | that version instead if you wish.) 
Do not make any other change in 218 | these notices. 219 | 220 | Once this change is made in a given copy, it is irreversible for 221 | that copy, so the ordinary GNU General Public License applies to all 222 | subsequent copies and derivative works made from that copy. 223 | 224 | This option is useful when you wish to copy part of the code of 225 | the Library into a program that is not a library. 226 | 227 | 4. You may copy and distribute the Library (or a portion or 228 | derivative of it, under Section 2) in object code or executable form 229 | under the terms of Sections 1 and 2 above provided that you accompany 230 | it with the complete corresponding machine-readable source code, which 231 | must be distributed under the terms of Sections 1 and 2 above on a 232 | medium customarily used for software interchange. 233 | 234 | If distribution of object code is made by offering access to copy 235 | from a designated place, then offering equivalent access to copy the 236 | source code from the same place satisfies the requirement to 237 | distribute the source code, even though third parties are not 238 | compelled to copy the source along with the object code. 239 | 240 | 5. A program that contains no derivative of any portion of the 241 | Library, but is designed to work with the Library by being compiled or 242 | linked with it, is called a "work that uses the Library". Such a 243 | work, in isolation, is not a derivative work of the Library, and 244 | therefore falls outside the scope of this License. 245 | 246 | However, linking a "work that uses the Library" with the Library 247 | creates an executable that is a derivative of the Library (because it 248 | contains portions of the Library), rather than a "work that uses the 249 | library". The executable is therefore covered by this License. 250 | Section 6 states terms for distribution of such executables. 
251 | 252 | When a "work that uses the Library" uses material from a header file 253 | that is part of the Library, the object code for the work may be a 254 | derivative work of the Library even though the source code is not. 255 | Whether this is true is especially significant if the work can be 256 | linked without the Library, or if the work is itself a library. The 257 | threshold for this to be true is not precisely defined by law. 258 | 259 | If such an object file uses only numerical parameters, data 260 | structure layouts and accessors, and small macros and small inline 261 | functions (ten lines or less in length), then the use of the object 262 | file is unrestricted, regardless of whether it is legally a derivative 263 | work. (Executables containing this object code plus portions of the 264 | Library will still fall under Section 6.) 265 | 266 | Otherwise, if the work is a derivative of the Library, you may 267 | distribute the object code for the work under the terms of Section 6. 268 | Any executables containing that work also fall under Section 6, 269 | whether or not they are linked directly with the Library itself. 270 | 271 | 6. As an exception to the Sections above, you may also combine or 272 | link a "work that uses the Library" with the Library to produce a 273 | work containing portions of the Library, and distribute that work 274 | under terms of your choice, provided that the terms permit 275 | modification of the work for the customer's own use and reverse 276 | engineering for debugging such modifications. 277 | 278 | You must give prominent notice with each copy of the work that the 279 | Library is used in it and that the Library and its use are covered by 280 | this License. You must supply a copy of this License. If the work 281 | during execution displays copyright notices, you must include the 282 | copyright notice for the Library among them, as well as a reference 283 | directing the user to the copy of this License. 
Also, you must do one 284 | of these things: 285 | 286 | a) Accompany the work with the complete corresponding 287 | machine-readable source code for the Library including whatever 288 | changes were used in the work (which must be distributed under 289 | Sections 1 and 2 above); and, if the work is an executable linked 290 | with the Library, with the complete machine-readable "work that 291 | uses the Library", as object code and/or source code, so that the 292 | user can modify the Library and then relink to produce a modified 293 | executable containing the modified Library. (It is understood 294 | that the user who changes the contents of definitions files in the 295 | Library will not necessarily be able to recompile the application 296 | to use the modified definitions.) 297 | 298 | b) Use a suitable shared library mechanism for linking with the 299 | Library. A suitable mechanism is one that (1) uses at run time a 300 | copy of the library already present on the user's computer system, 301 | rather than copying library functions into the executable, and (2) 302 | will operate properly with a modified version of the library, if 303 | the user installs one, as long as the modified version is 304 | interface-compatible with the version that the work was made with. 305 | 306 | c) Accompany the work with a written offer, valid for at 307 | least three years, to give the same user the materials 308 | specified in Subsection 6a, above, for a charge no more 309 | than the cost of performing this distribution. 310 | 311 | d) If distribution of the work is made by offering access to copy 312 | from a designated place, offer equivalent access to copy the above 313 | specified materials from the same place. 314 | 315 | e) Verify that the user has already received a copy of these 316 | materials or that you have already sent this user a copy. 
317 | 318 | For an executable, the required form of the "work that uses the 319 | Library" must include any data and utility programs needed for 320 | reproducing the executable from it. However, as a special exception, 321 | the materials to be distributed need not include anything that is 322 | normally distributed (in either source or binary form) with the major 323 | components (compiler, kernel, and so on) of the operating system on 324 | which the executable runs, unless that component itself accompanies 325 | the executable. 326 | 327 | It may happen that this requirement contradicts the license 328 | restrictions of other proprietary libraries that do not normally 329 | accompany the operating system. Such a contradiction means you cannot 330 | use both them and the Library together in an executable that you 331 | distribute. 332 | 333 | 7. You may place library facilities that are a work based on the 334 | Library side-by-side in a single library together with other library 335 | facilities not covered by this License, and distribute such a combined 336 | library, provided that the separate distribution of the work based on 337 | the Library and of the other library facilities is otherwise 338 | permitted, and provided that you do these two things: 339 | 340 | a) Accompany the combined library with a copy of the same work 341 | based on the Library, uncombined with any other library 342 | facilities. This must be distributed under the terms of the 343 | Sections above. 344 | 345 | b) Give prominent notice with the combined library of the fact 346 | that part of it is a work based on the Library, and explaining 347 | where to find the accompanying uncombined form of the same work. 348 | 349 | 8. You may not copy, modify, sublicense, link with, or distribute 350 | the Library except as expressly provided under this License. 
Any 351 | attempt otherwise to copy, modify, sublicense, link with, or 352 | distribute the Library is void, and will automatically terminate your 353 | rights under this License. However, parties who have received copies, 354 | or rights, from you under this License will not have their licenses 355 | terminated so long as such parties remain in full compliance. 356 | 357 | 9. You are not required to accept this License, since you have not 358 | signed it. However, nothing else grants you permission to modify or 359 | distribute the Library or its derivative works. These actions are 360 | prohibited by law if you do not accept this License. Therefore, by 361 | modifying or distributing the Library (or any work based on the 362 | Library), you indicate your acceptance of this License to do so, and 363 | all its terms and conditions for copying, distributing or modifying 364 | the Library or works based on it. 365 | 366 | 10. Each time you redistribute the Library (or any work based on the 367 | Library), the recipient automatically receives a license from the 368 | original licensor to copy, distribute, link with or modify the Library 369 | subject to these terms and conditions. You may not impose any further 370 | restrictions on the recipients' exercise of the rights granted herein. 371 | You are not responsible for enforcing compliance by third parties with 372 | this License. 373 | 374 | 11. If, as a consequence of a court judgment or allegation of patent 375 | infringement or for any other reason (not limited to patent issues), 376 | conditions are imposed on you (whether by court order, agreement or 377 | otherwise) that contradict the conditions of this License, they do not 378 | excuse you from the conditions of this License. If you cannot 379 | distribute so as to satisfy simultaneously your obligations under this 380 | License and any other pertinent obligations, then as a consequence you 381 | may not distribute the Library at all. 
For example, if a patent 382 | license would not permit royalty-free redistribution of the Library by 383 | all those who receive copies directly or indirectly through you, then 384 | the only way you could satisfy both it and this License would be to 385 | refrain entirely from distribution of the Library. 386 | 387 | If any portion of this section is held invalid or unenforceable under any 388 | particular circumstance, the balance of the section is intended to apply, 389 | and the section as a whole is intended to apply in other circumstances. 390 | 391 | It is not the purpose of this section to induce you to infringe any 392 | patents or other property right claims or to contest validity of any 393 | such claims; this section has the sole purpose of protecting the 394 | integrity of the free software distribution system which is 395 | implemented by public license practices. Many people have made 396 | generous contributions to the wide range of software distributed 397 | through that system in reliance on consistent application of that 398 | system; it is up to the author/donor to decide if he or she is willing 399 | to distribute software through any other system and a licensee cannot 400 | impose that choice. 401 | 402 | This section is intended to make thoroughly clear what is believed to 403 | be a consequence of the rest of this License. 404 | 405 | 12. If the distribution and/or use of the Library is restricted in 406 | certain countries either by patents or by copyrighted interfaces, the 407 | original copyright holder who places the Library under this License may add 408 | an explicit geographical distribution limitation excluding those countries, 409 | so that distribution is permitted only in or among countries not thus 410 | excluded. In such case, this License incorporates the limitation as if 411 | written in the body of this License. 412 | 413 | 13. 
The Free Software Foundation may publish revised and/or new 414 | versions of the Lesser General Public License from time to time. 415 | Such new versions will be similar in spirit to the present version, 416 | but may differ in detail to address new problems or concerns. 417 | 418 | Each version is given a distinguishing version number. If the Library 419 | specifies a version number of this License which applies to it and 420 | "any later version", you have the option of following the terms and 421 | conditions either of that version or of any later version published by 422 | the Free Software Foundation. If the Library does not specify a 423 | license version number, you may choose any version ever published by 424 | the Free Software Foundation. 425 | 426 | 14. If you wish to incorporate parts of the Library into other free 427 | programs whose distribution conditions are incompatible with these, 428 | write to the author to ask for permission. For software which is 429 | copyrighted by the Free Software Foundation, write to the Free 430 | Software Foundation; we sometimes make exceptions for this. Our 431 | decision will be guided by the two goals of preserving the free status 432 | of all derivatives of our free software and of promoting the sharing 433 | and reuse of software generally. 434 | 435 | NO WARRANTY 436 | 437 | 15. BECAUSE THE LIBRARY IS LICENSED FREE OF CHARGE, THERE IS NO 438 | WARRANTY FOR THE LIBRARY, TO THE EXTENT PERMITTED BY APPLICABLE LAW. 439 | EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR 440 | OTHER PARTIES PROVIDE THE LIBRARY "AS IS" WITHOUT WARRANTY OF ANY 441 | KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE 442 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 443 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE 444 | LIBRARY IS WITH YOU. SHOULD THE LIBRARY PROVE DEFECTIVE, YOU ASSUME 445 | THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 
446 | 447 | 16. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN 448 | WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY 449 | AND/OR REDISTRIBUTE THE LIBRARY AS PERMITTED ABOVE, BE LIABLE TO YOU 450 | FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR 451 | CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE 452 | LIBRARY (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING 453 | RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A 454 | FAILURE OF THE LIBRARY TO OPERATE WITH ANY OTHER SOFTWARE), EVEN IF 455 | SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH 456 | DAMAGES. 457 | 458 | END OF TERMS AND CONDITIONS 459 | 460 | How to Apply These Terms to Your New Libraries 461 | 462 | If you develop a new library, and you want it to be of the greatest 463 | possible use to the public, we recommend making it free software that 464 | everyone can redistribute and change. You can do so by permitting 465 | redistribution under these terms (or, alternatively, under the terms of the 466 | ordinary General Public License). 467 | 468 | To apply these terms, attach the following notices to the library. It is 469 | safest to attach them to the start of each source file to most effectively 470 | convey the exclusion of warranty; and each file should have at least the 471 | "copyright" line and a pointer to where the full notice is found. 472 | 473 | 474 | Copyright (C) 475 | 476 | This library is free software; you can redistribute it and/or 477 | modify it under the terms of the GNU Lesser General Public 478 | License as published by the Free Software Foundation; either 479 | version 2.1 of the License, or (at your option) any later version. 480 | 481 | This library is distributed in the hope that it will be useful, 482 | but WITHOUT ANY WARRANTY; without even the implied warranty of 483 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU 484 | Lesser General Public License for more details. 485 | 486 | You should have received a copy of the GNU Lesser General Public 487 | License along with this library; if not, write to the Free Software 488 | Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 489 | 490 | Also add information on how to contact you by electronic and paper mail. 491 | 492 | You should also get your employer (if you work as a programmer) or your 493 | school, if any, to sign a "copyright disclaimer" for the library, if 494 | necessary. Here is a sample; alter the names: 495 | 496 | Yoyodyne, Inc., hereby disclaims all copyright interest in the 497 | library `Frob' (a library for tweaking knobs) written by James Random Hacker. 498 | 499 | , 1 April 1990 500 | Ty Coon, President of Vice 501 | 502 | That's all there is to it! 503 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![alt text](https://github.com/slicebit/qb/raw/master/qb_logo_128.png "qb: the database toolkit for go") 2 | 3 | # qb - the database toolkit for go 4 | 5 | [![Build Status](https://travis-ci.org/slicebit/qb.svg?branch=master)](https://travis-ci.org/slicebit/qb) 6 | [![Coverage Status](https://coveralls.io/repos/github/slicebit/qb/badge.svg?branch=master)](https://coveralls.io/github/slicebit/qb?branch=master) 7 | [![License (LGPL version 2.1)](https://img.shields.io/badge/license-GNU%20LGPL%20version%202.1-brightgreen.svg?style=flat)](http://opensource.org/licenses/LGPL-2.1) 8 | [![Go Report Card](https://goreportcard.com/badge/github.com/slicebit/qb)](https://goreportcard.com/report/github.com/slicebit/qb) 9 | [![GoDoc](https://godoc.org/github.com/golang/gddo?status.svg)](http://godoc.org/github.com/slicebit/qb) 10 | 11 | **This project is currently pre 1.** 12 | 13 | Currently, it's not feature complete. It can have potential bugs. 
There are no tests covering concurrency race conditions, so it may crash, especially under concurrent use. 14 | Before the 1.x releases, each major release may break backwards compatibility. 15 | 16 | About qb 17 | -------- 18 | qb is a database toolkit for easier db queries in Go. It is inspired by Python's best ORM, namely SQLAlchemy. qb is an ORM (sqlx) as well as a query builder. It is modular enough to be used for just its expression API and query building. 19 | 20 | [Documentation](https://qb.readme.io) 21 | ------------- 22 | The documentation is hosted on [readme.io](https://qb.readme.io), which has great support for markdown docs. Currently, the docs are about 80% - 90% complete. The doc files will be added to this repo soon. You can also browse the godoc [here](https://godoc.org/github.com/slicebit/qb). Contributions and feedback on the docs are welcome. 23 | 24 | Features 25 | -------- 26 | - Support for postgres (9.5+), mysql & sqlite3 27 | - Powerful expression API for building queries & table DDLs 28 | - Struct to table DDL mapper where initial table migrations can happen 29 | - Transactional session API that auto-maps structs to queries 30 | - Foreign key definitions 31 | - Single & composite column indices 32 | - Relationships (soon..
probably in the 0.3 milestone) 33 | 34 | Installation 35 | ------------ 36 | ```sh 37 | go get -u github.com/slicebit/qb 38 | ``` 39 | To also install the test dependencies, run: 40 | ```sh 41 | go get -u -t github.com/slicebit/qb 42 | ``` 43 | 44 | Quick Start 45 | ----------- 46 | ```go 47 | package main 48 | 49 | import ( 50 | "fmt" 51 | "github.com/slicebit/qb" 52 | _ "github.com/mattn/go-sqlite3" 53 | _ "github.com/slicebit/qb/dialects/sqlite" 54 | ) 55 | 56 | type User struct { 57 | ID string `db:"id"` 58 | Email string `db:"email"` 59 | FullName string `db:"full_name"` 60 | Oscars int `db:"oscars"` 61 | } 62 | 63 | func main() { 64 | 65 | users := qb.Table( 66 | "users", 67 | qb.Column("id", qb.Varchar().Size(40)), 68 | qb.Column("email", qb.Varchar()).NotNull().Unique(), 69 | qb.Column("full_name", qb.Varchar()).NotNull(), 70 | qb.Column("oscars", qb.Int()).NotNull().Default(0), 71 | qb.PrimaryKey("id"), 72 | ) 73 | 74 | db, err := qb.New("sqlite3", "./qb_test.db") 75 | if err != nil { 76 | panic(err) 77 | } 78 | 79 | defer db.Close() 80 | 81 | metadata := qb.MetaData() 82 | 83 | // add table to metadata 84 | metadata.AddTable(users) 85 | 86 | // create all tables registered to metadata 87 | metadata.CreateAll(db) 88 | defer metadata.DropAll(db) // drops all tables 89 | 90 | ins := qb.Insert(users).Values(map[string]interface{}{ 91 | "id": "b6f8bfe3-a830-441a-a097-1777e6bfae95", 92 | "email": "jack@nicholson.com", 93 | "full_name": "Jack Nicholson", 94 | }) 95 | 96 | _, err = db.Exec(ins) 97 | if err != nil { 98 | panic(err) 99 | } 100 | 101 | // find user 102 | var user User 103 | 104 | sel := qb.Select(users.C("id"), users.C("email"), users.C("full_name")). 105 | From(users).
106 | Where(users.C("id").Eq("b6f8bfe3-a830-441a-a097-1777e6bfae95")) 107 | 108 | err = db.Get(sel, &user) 109 | fmt.Printf("%+v\n", user) 110 | } 111 | ``` 112 | 113 | Credits 114 | ------- 115 | - [Aras Can Akın](https://github.com/aacanakin) 116 | - [Christophe de Vienne](https://github.com/cdevienne) 117 | - [Onur Şentüre](https://github.com/onursenture) 118 | - [Aaron O. Ellis](https://github.com/aodin) 119 | - [Shawn Smith](https://github.com/shawnps) 120 | -------------------------------------------------------------------------------- /aggregate.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | // Avg function generates "avg(%s)" statement for clause 4 | func Avg(clause Clause) AggregateClause { 5 | return Aggregate("AVG", clause) 6 | } 7 | 8 | // Count function generates "count(%s)" statement for clause 9 | func Count(clause Clause) AggregateClause { 10 | return Aggregate("COUNT", clause) 11 | } 12 | 13 | // Sum function generates "sum(%s)" statement for clause 14 | func Sum(clause Clause) AggregateClause { 15 | return Aggregate("SUM", clause) 16 | } 17 | 18 | // Min function generates "min(%s)" statement for clause 19 | func Min(clause Clause) AggregateClause { 20 | return Aggregate("MIN", clause) 21 | } 22 | 23 | // Max function generates "max(%s)" statement for clause 24 | func Max(clause Clause) AggregateClause { 25 | return Aggregate("MAX", clause) 26 | } 27 | 28 | // Aggregate generates a new aggregate clause given function & clause 29 | func Aggregate(fn string, clause Clause) AggregateClause { 30 | return AggregateClause{fn, clause} 31 | } 32 | 33 | // AggregateClause is the base struct for building aggregate functions 34 | type AggregateClause struct { 35 | fn string 36 | clause Clause 37 | } 38 | 39 | // Accept calls the compiler VisitAggregate function 40 | func (c AggregateClause) Accept(context *CompilerContext) string { 41 | return context.Compiler.VisitAggregate(context, c) 42 | } 43 
| -------------------------------------------------------------------------------- /aggregate_test.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | ) 7 | 8 | func TestAggregates(t *testing.T) { 9 | col := Column("id", Varchar().Size(36)) 10 | assert.Equal(t, Aggregate("AVG", col), Avg(col)) 11 | assert.Equal(t, Aggregate("COUNT", col), Count(col)) 12 | assert.Equal(t, Aggregate("SUM", col), Sum(col)) 13 | assert.Equal(t, Aggregate("MIN", col), Min(col)) 14 | assert.Equal(t, Aggregate("MAX", col), Max(col)) 15 | } 16 | -------------------------------------------------------------------------------- /clause.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | // Clause is the base interface of all clauses that will get 4 | // compiled to SQL by Compiler 5 | type Clause interface { 6 | Accept(context *CompilerContext) string 7 | } 8 | 9 | // TableSQLClause is the common interface for ddl generators such as Column(), PrimaryKey(), ForeignKey().Ref(), etc. 
10 | type TableSQLClause interface { 11 | // String takes the dialect and returns the ddl as an sql string 12 | String(dialect Dialect) string 13 | } 14 | 15 | // Builder is the common interface for any statement builder in qb such as Insert(), Update(), Delete(), Select() query starters 16 | type Builder interface { 17 | // Build takes a dialect and returns a stmt 18 | Build(dialect Dialect) *Stmt 19 | } 20 | -------------------------------------------------------------------------------- /clauses.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | // SQLText returns a raw SQL clause 4 | func SQLText(text string) TextClause { 5 | return TextClause{Text: text} 6 | } 7 | 8 | // TextClause is a raw SQL clause 9 | type TextClause struct { 10 | Text string 11 | } 12 | 13 | // Accept calls the compiler VisitText method 14 | func (c TextClause) Accept(context *CompilerContext) string { 15 | return context.Compiler.VisitText(context, c) 16 | } 17 | 18 | // List returns a list-of-clauses clause 19 | func List(clauses ...Clause) ListClause { 20 | return ListClause{ 21 | Clauses: clauses, 22 | } 23 | } 24 | 25 | // ListClause is a list of clause elements (for IN operator for example) 26 | type ListClause struct { 27 | Clauses []Clause 28 | } 29 | 30 | // Accept calls the compiler VisitList method 31 | func (c ListClause) Accept(context *CompilerContext) string { 32 | return context.Compiler.VisitList(context, c) 33 | } 34 | 35 | // Bind a value 36 | func Bind(value interface{}) BindClause { 37 | return BindClause{ 38 | Value: value, 39 | } 40 | } 41 | 42 | // BindClause binds a value to a placeholder 43 | type BindClause struct { 44 | Value interface{} 45 | } 46 | 47 | // Accept calls the compiler VisitBind method 48 | func (c BindClause) Accept(context *CompilerContext) string { 49 | return context.Compiler.VisitBind(context, c) 50 | } 51 | 52 | // GetClauseFrom returns the value if already a Clause, or make one 53 | 
// if it is a scalar value 54 | func GetClauseFrom(value interface{}) Clause { 55 | if clause, ok := value.(Clause); ok { 56 | return clause 57 | } 58 | // For now we assume any non-clause is a Value: 59 | return Bind(value) 60 | } 61 | 62 | // GetListFrom returns a list clause from any list 63 | // 64 | // If only one value is passed and is a ListClause, it is returned 65 | // as-is. 66 | // In any other case, a ListClause is built with each value wrapped 67 | // by a Bind() if not already a Clause 68 | func GetListFrom(values ...interface{}) Clause { 69 | if len(values) == 1 { 70 | if clause, ok := values[0].(ListClause); ok { 71 | return clause 72 | } 73 | } 74 | 75 | var clauses []Clause 76 | for _, value := range values { 77 | clauses = append(clauses, GetClauseFrom(value)) 78 | } 79 | return List(clauses...) 80 | } 81 | 82 | // Exists returns an EXISTS clause 83 | func Exists(sel SelectStmt) ExistsClause { 84 | return ExistsClause{ 85 | Select: sel, 86 | Not: false, 87 | } 88 | } 89 | 90 | // NotExists returns a NOT EXISTS clause 91 | func NotExists(sel SelectStmt) ExistsClause { 92 | return ExistsClause{ 93 | Select: sel, 94 | Not: true, 95 | } 96 | } 97 | 98 | // ExistsClause is an EXISTS clause 99 | type ExistsClause struct { 100 | Select SelectStmt 101 | Not bool 102 | } 103 | 104 | // Accept calls the compiler VisitExists method 105 | func (c ExistsClause) Accept(context *CompilerContext) string { 106 | return context.Compiler.VisitExists(context, c) 107 | } 108 | -------------------------------------------------------------------------------- /clauses_test.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | ) 7 | 8 | func TestSQLText(t *testing.T) { 9 | text := SQLText("1") 10 | assert.Equal(t, "1", text.Text) 11 | } 12 | 13 | func TestGetClauseFrom(t *testing.T) { 14 | var c Clause 15 | c = SQLText("1") 16 | assert.Equal(t, c,
GetClauseFrom(c)) 17 | 18 | c = GetClauseFrom(2) 19 | b, ok := c.(BindClause) 20 | assert.True(t, ok, "Should have returned a BindClause") 21 | assert.Equal(t, 2, b.Value) 22 | } 23 | 24 | func TestGetListFrom(t *testing.T) { 25 | var c Clause 26 | c = ListClause{} 27 | assert.Equal(t, c, GetListFrom(c)) 28 | 29 | text := SQLText("SOME SQL") 30 | c = GetListFrom(text) 31 | l, ok := c.(ListClause) 32 | assert.True(t, ok, "Should have returned a ListClause") 33 | assert.Equal(t, 1, len(l.Clauses)) 34 | assert.Equal(t, text, l.Clauses[0]) 35 | 36 | c = GetListFrom(2) 37 | l, ok = c.(ListClause) 38 | assert.True(t, ok, "Should have returned a ListClause") 39 | assert.Equal(t, 1, len(l.Clauses)) 40 | assert.Equal(t, 2, l.Clauses[0].(BindClause).Value) 41 | 42 | c = GetListFrom(2, Bind(4)) 43 | l, ok = c.(ListClause) 44 | assert.True(t, ok, "Should have returned a ListClause") 45 | assert.Equal(t, 2, len(l.Clauses)) 46 | assert.Equal(t, 2, l.Clauses[0].(BindClause).Value) 47 | assert.Equal(t, 4, l.Clauses[1].(BindClause).Value) 48 | } 49 | 50 | func TestExists(t *testing.T) { 51 | s := Select() 52 | 53 | e := Exists(s) 54 | assert.False(t, e.Not) 55 | assert.Equal(t, s, e.Select) 56 | 57 | ne := NotExists(s) 58 | assert.True(t, ne.Not) 59 | assert.Equal(t, s, ne.Select) 60 | } 61 | -------------------------------------------------------------------------------- /column.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | // Column generates a ColumnElem given name and type 9 | func Column(name string, t TypeElem) ColumnElem { 10 | return ColumnElem{ 11 | Name: name, 12 | Type: t, 13 | } 14 | } 15 | 16 | // ColumnOptions holds options for a column 17 | type ColumnOptions struct { 18 | AutoIncrement bool 19 | PrimaryKey bool 20 | InlinePrimaryKey bool 21 | Unique bool 22 | } 23 | 24 | // ColumnElem is the definition of any columns defined in a table 25 | type ColumnElem 
struct { 26 | Name string 27 | Type TypeElem 28 | Table string // This field should be lazily set by Table() function 29 | Constraints []ConstraintElem 30 | Options ColumnOptions 31 | } 32 | 33 | // AutoIncrement set up “auto increment” semantics for an integer column. 34 | // Depending on the dialect, the column may be required to be a PrimaryKey too. 35 | func (c ColumnElem) AutoIncrement() ColumnElem { 36 | c.Options.AutoIncrement = true 37 | return c 38 | } 39 | 40 | // PrimaryKey add the column to the primary key 41 | func (c ColumnElem) PrimaryKey() ColumnElem { 42 | c.Options.PrimaryKey = true 43 | return c 44 | } 45 | 46 | // inlinePrimaryKey flags the column so it will inline the primary key constraint 47 | func (c ColumnElem) inlinePrimaryKey() ColumnElem { 48 | c.Options.InlinePrimaryKey = true 49 | return c 50 | } 51 | 52 | // String returns the column element as an sql clause 53 | // It satisfies the TableSQLClause interface 54 | func (c ColumnElem) String(dialect Dialect) string { 55 | colSpec := "" 56 | if c.Options.AutoIncrement { 57 | colSpec = dialect.AutoIncrement(&c) 58 | } 59 | if colSpec == "" { 60 | colSpec = dialect.CompileType(c.Type) 61 | constraintNames := []string{} 62 | for _, constraint := range c.Constraints { 63 | constraintNames = append(constraintNames, constraint.String()) 64 | } 65 | if len(constraintNames) != 0 { 66 | colSpec = fmt.Sprintf("%s %s", colSpec, strings.Join(constraintNames, " ")) 67 | } 68 | if c.Options.InlinePrimaryKey { 69 | colSpec += " PRIMARY KEY" 70 | } 71 | } 72 | res := fmt.Sprintf("%s %s", dialect.Escape(c.Name), colSpec) 73 | return res 74 | } 75 | 76 | // Accept calls the compiler VisitColumn function 77 | func (c ColumnElem) Accept(context *CompilerContext) string { 78 | return context.Compiler.VisitColumn(context, c) 79 | } 80 | 81 | // constraints setters 82 | 83 | // Default adds a default constraint to column type 84 | func (c ColumnElem) Default(def interface{}) ColumnElem { 85 | c.Constraints = 
append(c.Constraints, Default(def)) 86 | return c 87 | } 88 | 89 | // Null adds null constraint to column type 90 | func (c ColumnElem) Null() ColumnElem { 91 | c.Constraints = append(c.Constraints, Null()) 92 | return c 93 | } 94 | 95 | // NotNull adds not null constraint to column type 96 | func (c ColumnElem) NotNull() ColumnElem { 97 | c.Constraints = append(c.Constraints, NotNull()) 98 | return c 99 | } 100 | 101 | // Unique adds a unique constraint to column type 102 | func (c ColumnElem) Unique() ColumnElem { 103 | c.Constraints = append(c.Constraints, Unique()) 104 | c.Options.Unique = true 105 | return c 106 | } 107 | 108 | // Constraint adds a custom constraint to column type 109 | func (c ColumnElem) Constraint(name string) ColumnElem { 110 | c.Constraints = append(c.Constraints, Constraint(name)) 111 | return c 112 | } 113 | 114 | // conditional wrappers 115 | 116 | // Like wraps the Like(col ColumnElem, pattern string) 117 | func (c ColumnElem) Like(pattern string) Clause { 118 | return Like(c, pattern) 119 | } 120 | 121 | // NotIn wraps the NotIn(col ColumnElem, values ...interface{}) 122 | func (c ColumnElem) NotIn(values ...interface{}) Clause { 123 | return NotIn(c, values...) 124 | } 125 | 126 | // In wraps the In(col ColumnElem, values ...interface{}) 127 | func (c ColumnElem) In(values ...interface{}) Clause { 128 | return In(c, values...) 
129 | } 130 | 131 | // NotEq wraps the NotEq(col ColumnElem, value interface{}) 132 | func (c ColumnElem) NotEq(value interface{}) Clause { 133 | return NotEq(c, value) 134 | } 135 | 136 | // Eq wraps the Eq(col ColumnElem, value interface{}) 137 | func (c ColumnElem) Eq(value interface{}) Clause { 138 | return Eq(c, value) 139 | } 140 | 141 | // Gt wraps the Gt(col ColumnElem, value interface{}) 142 | func (c ColumnElem) Gt(value interface{}) Clause { 143 | return Gt(c, value) 144 | } 145 | 146 | // Lt wraps the Lt(col ColumnElem, value interface{}) 147 | func (c ColumnElem) Lt(value interface{}) Clause { 148 | return Lt(c, value) 149 | } 150 | 151 | // Gte wraps the Gte(col ColumnElem, value interface{}) 152 | func (c ColumnElem) Gte(value interface{}) Clause { 153 | return Gte(c, value) 154 | } 155 | 156 | // Lte wraps the Lte(col ColumnElem, value interface{}) 157 | func (c ColumnElem) Lte(value interface{}) Clause { 158 | return Lte(c, value) 159 | } 160 | -------------------------------------------------------------------------------- /column_test.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/suite" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | type ColumnTestSuite struct { 12 | suite.Suite 13 | dialect Dialect 14 | ctx *CompilerContext 15 | } 16 | 17 | func (suite *ColumnTestSuite) SetupTest() { 18 | suite.dialect = NewDefaultDialect() 19 | suite.ctx = NewCompilerContext(suite.dialect) 20 | } 21 | 22 | func (suite *ColumnTestSuite) TestColumnVarcharSpecificSize() { 23 | col := Column("id", Varchar().Size(40)) 24 | assert.Equal(suite.T(), "id", col.Name) 25 | assert.Equal(suite.T(), Varchar().Size(40), col.Type) 26 | assert.Equal(suite.T(), "id VARCHAR(40)", col.String(suite.dialect)) 27 | } 28 | 29 | func (suite *ColumnTestSuite) TestColumnVarcharUniqueNotNullDefault() { 30 | col := Column("s", 
Varchar().Size(255)).Unique().NotNull().Default("hello") 31 | assert.Equal(suite.T(), "s VARCHAR(255) UNIQUE NOT NULL DEFAULT 'hello'", col.String(suite.dialect)) 32 | } 33 | 34 | func (suite *ColumnTestSuite) TestColumnFloatPrecision() { 35 | col := Column("f", Type("FLOAT").Precision(2, 5)).Null() 36 | assert.Equal(suite.T(), "f FLOAT(2, 5) NULL", col.String(suite.dialect)) 37 | } 38 | 39 | func (suite *ColumnTestSuite) TestColumnIntInlinePrimaryKeyAutoIncrement() { 40 | col := Column("id", Int()).PrimaryKey().AutoIncrement().inlinePrimaryKey() 41 | assert.Equal(suite.T(), "id INT PRIMARY KEY AUTO INCREMENT", col.String(suite.dialect)) 42 | assert.Equal(suite.T(), "c INT TEST", Column("c", Int()).Constraint("TEST").String(suite.dialect)) 43 | } 44 | 45 | func (suite *ColumnTestSuite) TestColumnLike() { 46 | col := Column("id", Int()).PrimaryKey().AutoIncrement().inlinePrimaryKey() 47 | like := col.Like("s%") 48 | 49 | sql := like.Accept(suite.ctx) 50 | binds := suite.ctx.Binds 51 | 52 | assert.Equal(suite.T(), "id LIKE ?", sql) 53 | assert.Equal(suite.T(), []interface{}{"s%"}, binds) 54 | } 55 | 56 | func (suite *ColumnTestSuite) TestColumnNotIn() { 57 | col := Column("id", Int()).PrimaryKey().AutoIncrement().inlinePrimaryKey() 58 | notIn := col.NotIn("id1", "id2") 59 | sql := notIn.Accept(suite.ctx) 60 | binds := suite.ctx.Binds 61 | assert.Equal(suite.T(), "id NOT IN (?, ?)", sql) 62 | assert.Equal(suite.T(), []interface{}{"id1", "id2"}, binds) 63 | } 64 | 65 | func (suite *ColumnTestSuite) TestColumnIn() { 66 | col := Column("id", Int()).PrimaryKey().AutoIncrement().inlinePrimaryKey() 67 | in := col.In("id1", "id2") 68 | sql := in.Accept(suite.ctx) 69 | binds := suite.ctx.Binds 70 | 71 | assert.Equal(suite.T(), "id IN (?, ?)", sql) 72 | assert.Equal(suite.T(), []interface{}{"id1", "id2"}, binds) 73 | } 74 | 75 | func (suite *ColumnTestSuite) TestColumnNotEq() { 76 | col := Column("id", Int()).PrimaryKey().AutoIncrement().inlinePrimaryKey() 77 | notEq := 
col.NotEq("id1") 78 | sql := notEq.Accept(suite.ctx) 79 | binds := suite.ctx.Binds 80 | 81 | assert.Equal(suite.T(), "id != ?", sql) 82 | assert.Equal(suite.T(), []interface{}{"id1"}, binds) 83 | } 84 | 85 | func (suite *ColumnTestSuite) TestColumnEq() { 86 | col := Column("id", Int()).PrimaryKey().AutoIncrement().inlinePrimaryKey() 87 | eq := col.Eq("id1") 88 | sql := eq.Accept(suite.ctx) 89 | binds := suite.ctx.Binds 90 | 91 | assert.Equal(suite.T(), "id = ?", sql) 92 | assert.Equal(suite.T(), []interface{}{"id1"}, binds) 93 | } 94 | 95 | func (suite *ColumnTestSuite) TestColumnGt() { 96 | col := Column("id", Int()).PrimaryKey().AutoIncrement().inlinePrimaryKey() 97 | gt := col.Gt("id1") 98 | sql := gt.Accept(suite.ctx) 99 | binds := suite.ctx.Binds 100 | 101 | assert.Equal(suite.T(), "id > ?", sql) 102 | assert.Equal(suite.T(), []interface{}{"id1"}, binds) 103 | } 104 | 105 | func (suite *ColumnTestSuite) TestColumnLt() { 106 | col := Column("id", Int()).PrimaryKey().AutoIncrement().inlinePrimaryKey() 107 | lt := col.Lt("id1") 108 | sql := lt.Accept(suite.ctx) 109 | binds := suite.ctx.Binds 110 | 111 | assert.Equal(suite.T(), "id < ?", sql) 112 | assert.Equal(suite.T(), []interface{}{"id1"}, binds) 113 | } 114 | 115 | func (suite *ColumnTestSuite) TestcolumnGte() { 116 | col := Column("id", Int()).PrimaryKey().AutoIncrement().inlinePrimaryKey() 117 | gte := col.Gte("id1") 118 | sql := gte.Accept(suite.ctx) 119 | binds := suite.ctx.Binds 120 | 121 | assert.Equal(suite.T(), "id >= ?", sql) 122 | assert.Equal(suite.T(), []interface{}{"id1"}, binds) 123 | } 124 | 125 | func (suite *ColumnTestSuite) TestColumnLte() { 126 | col := Column("id", Int()).PrimaryKey().AutoIncrement().inlinePrimaryKey() 127 | lte := col.Lte("id1") 128 | sql := lte.Accept(suite.ctx) 129 | binds := suite.ctx.Binds 130 | 131 | assert.Equal(suite.T(), "id <= ?", sql) 132 | assert.Equal(suite.T(), []interface{}{"id1"}, binds) 133 | } 134 | 135 | func TestColumnTestSuite(t *testing.T) { 136 | 
suite.Run(t, new(ColumnTestSuite)) 137 | } 138 | -------------------------------------------------------------------------------- /combiner.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | // And combines the given conditional clauses with AND into a CombinerClause 4 | func And(clauses ...Clause) CombinerClause { 5 | return CombinerClause{"AND", clauses} 6 | } 7 | 8 | // Or combines the given conditional clauses with OR into a CombinerClause 9 | func Or(clauses ...Clause) CombinerClause { 10 | return CombinerClause{"OR", clauses} 11 | } 12 | 13 | // CombinerClause is for OR and AND clauses 14 | type CombinerClause struct { 15 | operator string 16 | clauses []Clause 17 | } 18 | 19 | // Accept calls the compiler VisitCombiner entry point 20 | func (c CombinerClause) Accept(context *CompilerContext) string { 21 | return context.Compiler.VisitCombiner(context, c) 22 | } 23 | -------------------------------------------------------------------------------- /combiner_test.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/suite" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | type CombinerTestSuite struct { 12 | suite.Suite 13 | dialect Dialect 14 | ctx *CompilerContext 15 | } 16 | 17 | func (suite *CombinerTestSuite) SetupTest() { 18 | suite.dialect = NewDefaultDialect() 19 | suite.ctx = NewCompilerContext(suite.dialect) 20 | } 21 | 22 | func (suite *CombinerTestSuite) TestCombinerAnd() { 23 | email := Column("email", Varchar()).NotNull().Unique() 24 | id := Column("id", Int()).NotNull() 25 | 26 | and := And(Eq(email, "al@pacino.com"), NotEq(id, 1)) 27 | sql := and.Accept(suite.ctx) 28 | binds := suite.ctx.Binds 29 | 30 | assert.Equal(suite.T(), "(email = ?
AND id != ?)", sql) 31 | assert.Equal(suite.T(), []interface{}{"al@pacino.com", 1}, binds) 32 | } 33 | 34 | func (suite *CombinerTestSuite) TestCombinerOr() { 35 | email := Column("email", Varchar()).NotNull().Unique() 36 | id := Column("id", Int()).NotNull() 37 | 38 | or := Or(Eq(email, "al@pacino.com"), NotEq(id, 1)) 39 | sql := or.Accept(suite.ctx) 40 | binds := suite.ctx.Binds 41 | 42 | assert.Equal(suite.T(), "(email = ? OR id != ?)", sql) 43 | assert.Equal(suite.T(), []interface{}{"al@pacino.com", 1}, binds) 44 | } 45 | 46 | func TestCombinerTestSuite(t *testing.T) { 47 | suite.Run(t, new(CombinerTestSuite)) 48 | } 49 | -------------------------------------------------------------------------------- /compiler.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | // NewCompilerContext initializes a new compiler context 9 | func NewCompilerContext(dialect Dialect) *CompilerContext { 10 | return &CompilerContext{ 11 | Dialect: dialect, 12 | Compiler: dialect.GetCompiler(), 13 | Vars: make(map[string]interface{}), 14 | Binds: []interface{}{}, 15 | } 16 | } 17 | 18 | // CompilerContext is a data structure passed to all the Compiler visit 19 | // functions. It contains the bindings, links to the Dialect and Compiler 20 | // being used, and some contextual information that can be used by the 21 | // compiler functions to communicate during the compilation.
22 | type CompilerContext struct { 23 | Binds []interface{} 24 | DefaultTableName string 25 | InSubQuery bool 26 | Vars map[string]interface{} 27 | 28 | Dialect Dialect 29 | Compiler Compiler 30 | } 31 | 32 | // Compiler is a visitor that produces SQL from various types of Clause 33 | type Compiler interface { 34 | VisitAggregate(*CompilerContext, AggregateClause) string 35 | VisitAlias(*CompilerContext, AliasClause) string 36 | VisitBinary(*CompilerContext, BinaryExpressionClause) string 37 | VisitBind(*CompilerContext, BindClause) string 38 | VisitColumn(*CompilerContext, ColumnElem) string 39 | VisitCombiner(*CompilerContext, CombinerClause) string 40 | VisitDelete(*CompilerContext, DeleteStmt) string 41 | VisitExists(*CompilerContext, ExistsClause) string 42 | VisitForUpdate(*CompilerContext, ForUpdateClause) string 43 | VisitHaving(*CompilerContext, HavingClause) string 44 | VisitIn(*CompilerContext, InClause) string 45 | VisitInsert(*CompilerContext, InsertStmt) string 46 | VisitJoin(*CompilerContext, JoinClause) string 47 | VisitLabel(*CompilerContext, string) string 48 | VisitList(*CompilerContext, ListClause) string 49 | VisitOrderBy(*CompilerContext, OrderByClause) string 50 | VisitSelect(*CompilerContext, SelectStmt) string 51 | VisitTable(*CompilerContext, TableElem) string 52 | VisitText(*CompilerContext, TextClause) string 53 | VisitUpdate(*CompilerContext, UpdateStmt) string 54 | VisitUpsert(*CompilerContext, UpsertStmt) string 55 | VisitWhere(*CompilerContext, WhereClause) string 56 | } 57 | 58 | // NewSQLCompiler returns a new SQLCompiler 59 | func NewSQLCompiler(dialect Dialect) SQLCompiler { 60 | return SQLCompiler{Dialect: dialect} 61 | } 62 | 63 | // SQLCompiler aims to provide an ANSI SQL-92 implementation of Compiler 64 | type SQLCompiler struct { 65 | Dialect Dialect 66 | } 67 | 68 | // VisitAggregate compiles aggregate functions (COUNT, SUM...)
69 | func (c SQLCompiler) VisitAggregate(context *CompilerContext, aggregate AggregateClause) string { 70 | return fmt.Sprintf("%s(%s)", aggregate.fn, aggregate.clause.Accept(context)) 71 | } 72 | 73 | // VisitAlias compiles a 'SELECTABLE AS NAME' SQL clause, escaping NAME 74 | func (SQLCompiler) VisitAlias(context *CompilerContext, alias AliasClause) string { 75 | return fmt.Sprintf( 76 | "%s AS %s", 77 | alias.Selectable.Accept(context), 78 | context.Dialect.Escape(alias.Name), 79 | ) 80 | } 81 | 82 | // VisitBinary compiles 'LEFT OP RIGHT' expressions 83 | func (c SQLCompiler) VisitBinary(context *CompilerContext, binary BinaryExpressionClause) string { 84 | return fmt.Sprintf( 85 | "%s %s %s", 86 | binary.Left.Accept(context), 87 | binary.Op, 88 | binary.Right.Accept(context), 89 | ) 90 | } 91 | 92 | // VisitBind renders a bound value as a '?' placeholder and appends the value to context.Binds 93 | func (SQLCompiler) VisitBind(context *CompilerContext, bind BindClause) string { 94 | context.Binds = append(context.Binds, bind.Value) 95 | return "?" 96 | } 97 | 98 | // VisitColumn returns a column name, optionally escaped depending on the dialect 99 | // configuration, prefixed with its table name inside subqueries or when the table differs from context.DefaultTableName 100 | func (c SQLCompiler) VisitColumn(context *CompilerContext, column ColumnElem) string { 101 | sql := "" 102 | if context.InSubQuery || context.DefaultTableName != column.Table { 103 | sql += c.Dialect.Escape(column.Table) + "."
104 | } 105 | sql += c.Dialect.Escape(column.Name) 106 | return sql 107 | } 108 | 109 | // VisitCombiner compiles AND and OR sql clauses 110 | func (c SQLCompiler) VisitCombiner(context *CompilerContext, combiner CombinerClause) string { 111 | sqls := []string{} 112 | for _, c := range combiner.clauses { 113 | sql := c.Accept(context) 114 | sqls = append(sqls, sql) 115 | } 116 | 117 | return fmt.Sprintf("(%s)", strings.Join(sqls, fmt.Sprintf(" %s ", combiner.operator))) 118 | } 119 | 120 | // VisitDelete compiles a DELETE statement 121 | func (c SQLCompiler) VisitDelete(context *CompilerContext, delete DeleteStmt) string { 122 | sql := "DELETE FROM " + delete.table.Accept(context) 123 | 124 | if delete.where != nil { 125 | sql += "\n" + delete.where.Accept(context) 126 | } 127 | 128 | returning := []string{} 129 | for _, c := range delete.returning { 130 | returning = append(returning, context.Dialect.Escape(c.Name)) 131 | } 132 | 133 | if len(returning) > 0 { 134 | sql += "\nRETURNING " + strings.Join(returning, ", ") 135 | } 136 | 137 | return sql 138 | } 139 | 140 | // VisitExists compile a EXISTS clause 141 | func (SQLCompiler) VisitExists(context *CompilerContext, exists ExistsClause) string { 142 | var sql string 143 | if exists.Not { 144 | sql = "NOT " 145 | } 146 | sql += "EXISTS(%s)" 147 | context.InSubQuery = true 148 | defer func() { context.InSubQuery = false }() 149 | return fmt.Sprintf(sql, exists.Select.Accept(context)) 150 | } 151 | 152 | // VisitForUpdate compiles a 'FOR UPDATE' clause 153 | func (c SQLCompiler) VisitForUpdate(context *CompilerContext, forUpdate ForUpdateClause) string { 154 | var sql = "FOR UPDATE" 155 | if len(forUpdate.Tables) != 0 { 156 | var tablenames []string 157 | for _, table := range forUpdate.Tables { 158 | tablenames = append(tablenames, table.Name) 159 | } 160 | sql += " OF " + strings.Join(tablenames, ", ") 161 | } 162 | return sql 163 | } 164 | 165 | // VisitHaving compiles a HAVING clause 166 | func (c SQLCompiler) 
VisitHaving(context *CompilerContext, having HavingClause) string { 167 | aggSQL := having.aggregate.Accept(context) 168 | return fmt.Sprintf("HAVING %s %s %s", aggSQL, having.op, Bind(having.value).Accept(context)) 169 | } 170 | 171 | // VisitIn compiles a (NOT) IN () 172 | func (c SQLCompiler) VisitIn(context *CompilerContext, in InClause) string { 173 | return fmt.Sprintf( 174 | "%s %s (%s)", 175 | in.Left.Accept(context), 176 | in.Op, 177 | in.Right.Accept(context), 178 | ) 179 | } 180 | 181 | // VisitInsert compiles a INSERT statement 182 | func (c SQLCompiler) VisitInsert(context *CompilerContext, insert InsertStmt) string { 183 | context.DefaultTableName = insert.table.Name 184 | defer func() { context.DefaultTableName = "" }() 185 | 186 | cols := List() 187 | values := List() 188 | for k, v := range insert.values { 189 | cols.Clauses = append(cols.Clauses, insert.table.C(k)) 190 | values.Clauses = append(values.Clauses, Bind(v)) 191 | } 192 | 193 | sql := fmt.Sprintf( 194 | "INSERT INTO %s(%s)\nVALUES(%s)", 195 | insert.table.Accept(context), 196 | cols.Accept(context), 197 | values.Accept(context), 198 | ) 199 | 200 | returning := []string{} 201 | for _, r := range insert.returning { 202 | returning = append(returning, r.Accept(context)) 203 | } 204 | if len(insert.returning) > 0 { 205 | sql += fmt.Sprintf( 206 | "\nRETURNING %s", 207 | strings.Join(returning, ", "), 208 | ) 209 | } 210 | 211 | return sql 212 | } 213 | 214 | // VisitJoin compiles a JOIN (ON) clause 215 | func (c SQLCompiler) VisitJoin(context *CompilerContext, join JoinClause) string { 216 | sql := fmt.Sprintf( 217 | "%s\n%s %s", 218 | join.Left.Accept(context), 219 | join.JoinType, 220 | join.Right.Accept(context), 221 | ) 222 | if join.OnClause != nil { 223 | sql += " ON " + join.OnClause.Accept(context) 224 | } 225 | 226 | return sql 227 | } 228 | 229 | // VisitLabel returns a single label, optionally escaped 230 | func (c SQLCompiler) VisitLabel(context *CompilerContext, label string) 
string { 231 | return c.Dialect.Escape(label) 232 | } 233 | 234 | // VisitList compiles a list of values 235 | func (c SQLCompiler) VisitList(context *CompilerContext, list ListClause) string { 236 | var clauses []string 237 | for _, clause := range list.Clauses { 238 | clauses = append(clauses, clause.Accept(context)) 239 | } 240 | return strings.Join(clauses, ", ") 241 | } 242 | 243 | // VisitOrderBy compiles a ORDER BY sql clause 244 | func (c SQLCompiler) VisitOrderBy(context *CompilerContext, OrderByClause OrderByClause) string { 245 | cols := []string{} 246 | for _, c := range OrderByClause.columns { 247 | cols = append(cols, c.Accept(context)) 248 | } 249 | 250 | return fmt.Sprintf("ORDER BY %s %s", strings.Join(cols, ", "), OrderByClause.t) 251 | } 252 | 253 | // VisitSelect compiles a SELECT statement 254 | func (c SQLCompiler) VisitSelect(context *CompilerContext, selectStmt SelectStmt) string { 255 | lines := []string{} 256 | addLine := func(s string) { 257 | lines = append(lines, s) 258 | } 259 | if !context.InSubQuery && selectStmt.FromClause != nil { 260 | context.DefaultTableName = selectStmt.FromClause.DefaultName() 261 | } 262 | 263 | // select 264 | columns := []string{} 265 | for _, c := range selectStmt.SelectList { 266 | sql := c.Accept(context) 267 | columns = append(columns, sql) 268 | } 269 | addLine(fmt.Sprintf("SELECT %s", strings.Join(columns, ", "))) 270 | 271 | // from 272 | if selectStmt.FromClause != nil { 273 | addLine(fmt.Sprintf("FROM %s", selectStmt.FromClause.Accept(context))) 274 | } 275 | 276 | // where 277 | if selectStmt.WhereClause != nil { 278 | addLine(selectStmt.WhereClause.Accept(context)) 279 | } 280 | 281 | // group by 282 | groupByCols := []string{} 283 | for _, c := range selectStmt.GroupByClause { 284 | groupByCols = append(groupByCols, context.Dialect.Escape(c.Name)) 285 | } 286 | if len(groupByCols) > 0 { 287 | addLine(fmt.Sprintf("GROUP BY %s", strings.Join(groupByCols, ", "))) 288 | } 289 | 290 | // having 291 | 
for _, h := range selectStmt.HavingClause { 292 | sql := h.Accept(context) 293 | addLine(sql) 294 | } 295 | 296 | // order by 297 | if selectStmt.OrderByClause != nil { 298 | sql := selectStmt.OrderByClause.Accept(context) 299 | addLine(sql) 300 | } 301 | 302 | if (selectStmt.OffsetValue != nil) || (selectStmt.LimitValue != nil) { 303 | var tokens []string 304 | if selectStmt.LimitValue != nil { 305 | tokens = append(tokens, fmt.Sprintf("LIMIT %d", *selectStmt.LimitValue)) 306 | } 307 | if selectStmt.OffsetValue != nil { 308 | tokens = append(tokens, fmt.Sprintf("OFFSET %d", *selectStmt.OffsetValue)) 309 | } 310 | addLine(strings.Join(tokens, " ")) 311 | } 312 | 313 | if selectStmt.ForUpdateClause != nil { 314 | addLine(selectStmt.ForUpdateClause.Accept(context)) 315 | } 316 | 317 | return strings.Join(lines, "\n") 318 | } 319 | 320 | // VisitTable returns a table name, optionally escaped 321 | func (SQLCompiler) VisitTable(context *CompilerContext, table TableElem) string { 322 | return context.Compiler.VisitLabel(context, table.Name) 323 | } 324 | 325 | // VisitText return a raw SQL clause as is 326 | func (SQLCompiler) VisitText(context *CompilerContext, text TextClause) string { 327 | return text.Text 328 | } 329 | 330 | // VisitUpdate compiles a UPDATE statement 331 | func (c SQLCompiler) VisitUpdate(context *CompilerContext, update UpdateStmt) string { 332 | context.DefaultTableName = update.table.Name 333 | defer func() { context.DefaultTableName = "" }() 334 | 335 | sql := "UPDATE " + update.table.Accept(context) 336 | 337 | sets := List() 338 | 339 | for k, v := range update.values { 340 | sets.Clauses = append(sets.Clauses, 341 | Eq(update.table.C(k), Bind(v))) 342 | } 343 | 344 | if len(sets.Clauses) > 0 { 345 | sql += "\nSET " + sets.Accept(context) 346 | } 347 | 348 | if update.where != nil { 349 | sql += "\n" + update.where.Accept(context) 350 | } 351 | 352 | returning := []string{} 353 | for _, c := range update.returning { 354 | returning = 
append(returning, context.Dialect.Escape(c.Name)) 355 | } 356 | 357 | if len(returning) > 0 { 358 | sql += "\nRETURNING " + strings.Join(returning, ", ") 359 | } 360 | 361 | return sql 362 | } 363 | 364 | // VisitUpsert is not implemented and will panic. 365 | // It should be implemented in each dialect 366 | func (c SQLCompiler) VisitUpsert(context *CompilerContext, upsert UpsertStmt) string { 367 | panic("Upsert is not Implemented in this compiler") 368 | } 369 | 370 | // VisitWhere compiles a WHERE clause 371 | func (c SQLCompiler) VisitWhere(context *CompilerContext, where WhereClause) string { 372 | return fmt.Sprintf("WHERE %s", where.clause.Accept(context)) 373 | } 374 | -------------------------------------------------------------------------------- /compiler_test.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | ) 7 | 8 | var emptyBinds = []interface{}{} 9 | 10 | var ( 11 | TTGroup = Table( 12 | "group", 13 | Column("id", Int()).AutoIncrement().PrimaryKey(), 14 | Column("name", Text()).Unique(), 15 | ) 16 | 17 | TTUser = Table( 18 | "user", 19 | Column("id", Int()).AutoIncrement().PrimaryKey(), 20 | Column("name", Text()).Unique(), 21 | Column("main_group_id", Int()), 22 | ForeignKey("main_group_id").References("group", "id"), 23 | ) 24 | ) 25 | 26 | var compileTests = []struct { 27 | clause Clause 28 | expect string 29 | binds []interface{} 30 | }{ 31 | {SQLText("1"), "1", emptyBinds}, 32 | { 33 | Join("LEFT JOIN", TTGroup, TTUser), 34 | "group\nLEFT JOIN user ON user.main_group_id = group.id", 35 | emptyBinds, 36 | }, 37 | { 38 | Join("LEFT JOIN", TTGroup, TTUser, TTGroup.C("id").Eq(TTUser.C("id"))), 39 | "group\nLEFT JOIN user ON group.id = user.id", 40 | emptyBinds, 41 | }, 42 | { 43 | Join("LEFT JOIN", TTGroup, TTUser, TTGroup.C("id"), TTUser.C("id")), 44 | "group\nLEFT JOIN user ON group.id = user.id", 45 | emptyBinds, 46 | }, 
47 | { 48 | Exists(Select(TTGroup.C("name")).From(TTGroup).Where(TTGroup.C("id").Eq(TTUser.C("main_group_id")))), 49 | "EXISTS(SELECT group.name\nFROM group\nWHERE group.id = user.main_group_id)", 50 | emptyBinds, 51 | }, 52 | { 53 | NotExists(Select(TTGroup.C("name")).From(TTGroup).Where(TTGroup.C("id").Eq(TTUser.C("main_group_id")))), 54 | "NOT EXISTS(SELECT group.name\nFROM group\nWHERE group.id = user.main_group_id)", 55 | emptyBinds, 56 | }, 57 | { 58 | Select(Exists(Select(SQLText("1")).From(TTGroup).Where(TTGroup.C("id").Eq(TTUser.C("main_group_id"))))), 59 | "SELECT EXISTS(SELECT 1\nFROM group\nWHERE group.id = user.main_group_id)", 60 | emptyBinds, 61 | }, 62 | { 63 | Select(SQLText("1")).From(TTGroup).ForUpdate(), 64 | "SELECT 1\nFROM group\nFOR UPDATE", 65 | emptyBinds, 66 | }, 67 | { 68 | Select(SQLText("1")).From(TTGroup).ForUpdate(TTUser, TTGroup), 69 | "SELECT 1\nFROM group\nFOR UPDATE OF user, group", 70 | emptyBinds, 71 | }, 72 | } 73 | 74 | func TestCompile(t *testing.T) { 75 | compile := func(clause Clause) (string, []interface{}) { 76 | context := NewCompilerContext(NewDialect("default")) 77 | return clause.Accept(context), context.Binds 78 | } 79 | 80 | for _, tt := range compileTests { 81 | actual, binds := compile(tt.clause) 82 | assert.Equal(t, tt.expect, actual) 83 | assert.Equal(t, tt.binds, binds) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /conditional.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | // conditional generators, comparator functions 4 | 5 | // Like generates a like conditional sql clause 6 | func Like(left Clause, right interface{}) BinaryExpressionClause { 7 | return BinaryExpression(left, "LIKE", GetClauseFrom(right)) 8 | } 9 | 10 | // In generates an IN conditional sql clause 11 | func In(left Clause, values ...interface{}) InClause { 12 | return InClause{BinaryExpressionClause{ 13 | Left: left, 14 | Op: 
"IN", 15 | Right: GetListFrom(values...), 16 | }} 17 | } 18 | 19 | // NotIn generates an NOT IN conditional sql clause 20 | func NotIn(left Clause, values ...interface{}) InClause { 21 | return InClause{BinaryExpressionClause{ 22 | Left: left, 23 | Op: "NOT IN", 24 | Right: GetListFrom(values...), 25 | }} 26 | } 27 | 28 | // NotEq generates a not equal conditional sql clause 29 | func NotEq(left Clause, right interface{}) BinaryExpressionClause { 30 | return BinaryExpression(left, "!=", GetClauseFrom(right)) 31 | } 32 | 33 | // Eq generates a equals conditional sql clause 34 | func Eq(left Clause, right interface{}) BinaryExpressionClause { 35 | return BinaryExpression(left, "=", GetClauseFrom(right)) 36 | } 37 | 38 | // Gt generates a greater than conditional sql clause 39 | func Gt(left Clause, right interface{}) BinaryExpressionClause { 40 | return BinaryExpression(left, ">", GetClauseFrom(right)) 41 | } 42 | 43 | // Lt generates a less than conditional sql clause 44 | func Lt(left Clause, right interface{}) BinaryExpressionClause { 45 | return BinaryExpression(left, "<", GetClauseFrom(right)) 46 | } 47 | 48 | // Gte generates a greater than or equal to conditional sql clause 49 | func Gte(left Clause, right interface{}) BinaryExpressionClause { 50 | return BinaryExpression(left, ">=", GetClauseFrom(right)) 51 | } 52 | 53 | // Lte generates a less than or equal to conditional sql clause 54 | func Lte(left Clause, right interface{}) BinaryExpressionClause { 55 | return BinaryExpression(left, "<=", GetClauseFrom(right)) 56 | } 57 | 58 | // BinaryExpression generates a condition object to use in update, delete & select statements 59 | func BinaryExpression(left Clause, op string, right Clause) BinaryExpressionClause { 60 | return BinaryExpressionClause{ 61 | Left: left, 62 | Right: right, 63 | Op: op, 64 | } 65 | } 66 | 67 | // BinaryExpressionClause is the base struct for any conditional statements in sql clauses 68 | type BinaryExpressionClause struct { 69 | Left 
Clause 70 | Right Clause 71 | Op string 72 | } 73 | 74 | // Accept calls the compiler VisitBinary method 75 | func (c BinaryExpressionClause) Accept(context *CompilerContext) string { 76 | return context.Compiler.VisitBinary(context, c) 77 | } 78 | 79 | // InClause is an IN or NOT IN binary expression 80 | type InClause struct { 81 | BinaryExpressionClause 82 | } 83 | 84 | // Accept calls the compiler VisitIn method 85 | func (c InClause) Accept(context *CompilerContext) string { 86 | return context.Compiler.VisitIn(context, c) 87 | } 88 | -------------------------------------------------------------------------------- /conditional_test.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/suite" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | type ConditionalTestSuite struct { 12 | suite.Suite 13 | dialect Dialect 14 | ctx *CompilerContext 15 | country ColumnElem 16 | score ColumnElem 17 | } 18 | 19 | func (suite *ConditionalTestSuite) SetupTest() { 20 | suite.dialect = NewDefaultDialect() 21 | suite.ctx = NewCompilerContext(suite.dialect) 22 | suite.country = Column("country", Varchar()).NotNull() 23 | suite.score = Column("score", BigInt()).NotNull() 24 | } 25 | 26 | func (suite *ConditionalTestSuite) TestConditionalLike() { 27 | like := Like(suite.country, "%land%") 28 | sql := like.Accept(suite.ctx) 29 | bindings := suite.ctx.Binds 30 | 31 | assert.Equal(suite.T(), "country LIKE ?", sql) 32 | assert.Equal(suite.T(), []interface{}{"%land%"}, bindings) 33 | } 34 | 35 | func (suite *ConditionalTestSuite) TestConditionalNotIn() { 36 | notIn := NotIn(suite.country, "USA", "England", "Sweden") 37 | sql := notIn.Accept(suite.ctx) 38 | bindings := suite.ctx.Binds 39 | 40 | assert.Equal(suite.T(), "country NOT IN (?, ?, ?)", sql) 41 | assert.Equal(suite.T(), []interface{}{"USA", "England", "Sweden"}, bindings) 42 | } 43 | 44 | func (suite
*ConditionalTestSuite) TestConditionalIn() { 45 | in := In(suite.country, "USA", "England", "Sweden") 46 | sql := in.Accept(suite.ctx) 47 | bindings := suite.ctx.Binds 48 | assert.Equal(suite.T(), "country IN (?, ?, ?)", sql) 49 | assert.Equal(suite.T(), []interface{}{"USA", "England", "Sweden"}, bindings) 50 | } 51 | 52 | func (suite *ConditionalTestSuite) TestConditionalNotEq() { 53 | notEq := NotEq(suite.country, "USA") 54 | 55 | sql := notEq.Accept(suite.ctx) 56 | bindings := suite.ctx.Binds 57 | 58 | assert.Equal(suite.T(), "country != ?", sql) 59 | assert.Equal(suite.T(), []interface{}{"USA"}, bindings) 60 | } 61 | 62 | func (suite *ConditionalTestSuite) TestConditionalEq() { 63 | eq := Eq(suite.country, "Turkey") 64 | 65 | sql := eq.Accept(suite.ctx) 66 | bindings := suite.ctx.Binds 67 | 68 | assert.Equal(suite.T(), "country = ?", sql) 69 | assert.Equal(suite.T(), []interface{}{"Turkey"}, bindings) 70 | } 71 | 72 | func (suite *ConditionalTestSuite) TestConditionalGt() { 73 | gt := Gt(suite.score, 1500) 74 | 75 | sql := gt.Accept(suite.ctx) 76 | bindings := suite.ctx.Binds 77 | 78 | assert.Equal(suite.T(), "score > ?", sql) 79 | assert.Equal(suite.T(), []interface{}{1500}, bindings) 80 | } 81 | 82 | func (suite *ConditionalTestSuite) TestConditionalLt() { 83 | lt := Lt(suite.score, 1500) 84 | 85 | sql := lt.Accept(suite.ctx) 86 | bindings := suite.ctx.Binds 87 | 88 | assert.Equal(suite.T(), "score < ?", sql) 89 | assert.Equal(suite.T(), []interface{}{1500}, bindings) 90 | } 91 | 92 | func (suite *ConditionalTestSuite) TestConditionalGte() { 93 | gte := Gte(suite.score, 1500) 94 | 95 | sql := gte.Accept(suite.ctx) 96 | bindings := suite.ctx.Binds 97 | 98 | assert.Equal(suite.T(), "score >= ?", sql) 99 | assert.Equal(suite.T(), []interface{}{1500}, bindings) 100 | } 101 | 102 | func (suite *ConditionalTestSuite) TestConditionalLte() { 103 | lte := Lte(suite.score, 1500) 104 | 105 | sql := lte.Accept(suite.ctx) 106 | bindings := suite.ctx.Binds 107 | 108 | 
assert.Equal(suite.T(), "score <= ?", sql) 109 | assert.Equal(suite.T(), []interface{}{1500}, bindings) 110 | } 111 | 112 | func TestConditionalTestSuite(t *testing.T) { 113 | suite.Run(t, new(ConditionalTestSuite)) 114 | } 115 | -------------------------------------------------------------------------------- /constraint.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | // Null generates generic null constraint 9 | func Null() ConstraintElem { 10 | return ConstraintElem{"NULL"} 11 | } 12 | 13 | // NotNull generates generic not null constraint 14 | func NotNull() ConstraintElem { 15 | return ConstraintElem{"NOT NULL"} 16 | } 17 | 18 | // Default generates a DEFAULT '<value>' constraint (NOTE: value is quoted but not escaped) 19 | func Default(value interface{}) ConstraintElem { 20 | return ConstraintElem{fmt.Sprintf("DEFAULT '%v'", value)} 21 | } 22 | 23 | // Unique generates a generic unique constraint; it takes no columns 24 | // (use UniqueKey to build a composite unique constraint over several columns) 25 | func Unique() ConstraintElem { 26 | return ConstraintElem{"UNIQUE"} 27 | } 28 | 29 | // Constraint generates a custom constraint, for dialect-specific variations 30 | func Constraint(name string) ConstraintElem { 31 | return ConstraintElem{name} 32 | } 33 | 34 | // ConstraintElem is the definition of column & table constraints 35 | type ConstraintElem struct { 36 | Name string 37 | } 38 | 39 | // String returns the constraint as an sql clause 40 | func (c ConstraintElem) String() string { 41 | return c.Name 42 | } 43 | 44 | // PrimaryKey generates a primary key constraint of any table 45 | func PrimaryKey(cols ...string) PrimaryKeyConstraint { 46 | return PrimaryKeyConstraint{cols} 47 | } 48 | 49 | // PrimaryKeyConstraint is the definition of primary key constraints of any table 50 | type PrimaryKeyConstraint struct { 51 | Columns []string 52 | } 53 | 54 | // String returns the primary key constraints as an sql clause 55 | func (c
PrimaryKeyConstraint) String(dialect Dialect) string { 56 | cols := []string{} 57 | for _, col := range c.Columns { 58 | cols = append(cols, dialect.Escape(col)) 59 | } 60 | 61 | return fmt.Sprintf("PRIMARY KEY(%s)", strings.Join(cols, ", ")) 62 | } 63 | 64 | // ForeignKey generates a foreign key for table constraint definitions 65 | func ForeignKey(cols ...string) ForeignKeyConstraint { 66 | return ForeignKeyConstraint{Cols: cols} 67 | } 68 | 69 | // ForeignKeyConstraints is the definition of foreign keys in any table 70 | type ForeignKeyConstraints struct { 71 | FKeys []ForeignKeyConstraint 72 | } 73 | 74 | func (c ForeignKeyConstraints) String(dialect Dialect) string { 75 | clauses := []string{} 76 | for _, fkey := range c.FKeys { 77 | clauses = append(clauses, fkey.String(dialect)) 78 | } 79 | 80 | return strings.Join(clauses, ",\n") 81 | } 82 | 83 | // ForeignKeyConstraint is the main struct for defining foreign key references 84 | type ForeignKeyConstraint struct { 85 | Cols []string 86 | RefTable string 87 | RefCols []string 88 | ActionOnUpdate string 89 | ActionOnDelete string 90 | } 91 | 92 | func (fkey ForeignKeyConstraint) String(dialect Dialect) string { 93 | ddl := fmt.Sprintf( 94 | "\tFOREIGN KEY(%s) REFERENCES %s(%s)", 95 | strings.Join(dialect.EscapeAll(fkey.Cols), ", "), 96 | dialect.Escape(fkey.RefTable), 97 | strings.Join(dialect.EscapeAll(fkey.RefCols), ", "), 98 | ) 99 | if fkey.ActionOnUpdate != "" { 100 | ddl += " ON UPDATE " + fkey.ActionOnUpdate 101 | } 102 | if fkey.ActionOnDelete != "" { 103 | ddl += " ON DELETE " + fkey.ActionOnDelete 104 | } 105 | return ddl 106 | } 107 | 108 | func checkFKeyCascadeAction(action string) string { 109 | actionUp := strings.ToUpper(action) 110 | if actionUp != "" && 111 | actionUp != "CASCADE" && 112 | actionUp != "NO ACTION" && 113 | actionUp != "RESTRICT" && 114 | actionUp != "SET NULL" { 115 | panic("Invalid cascading action: " + actionUp) 116 | } 117 | return actionUp 118 | } 119 | 120 | // References 
sets the reference part of the foreign key 121 | func (fkey ForeignKeyConstraint) References(refTable string, refCols ...string) ForeignKeyConstraint { 122 | fkey.RefTable = refTable 123 | fkey.RefCols = refCols 124 | return fkey 125 | } 126 | 127 | // OnUpdate sets the ON UPDATE action 128 | func (fkey ForeignKeyConstraint) OnUpdate(action string) ForeignKeyConstraint { 129 | fkey.ActionOnUpdate = checkFKeyCascadeAction(action) 130 | return fkey 131 | } 132 | 133 | // OnDelete sets the ON DELETE action 134 | func (fkey ForeignKeyConstraint) OnDelete(action string) ForeignKeyConstraint { 135 | fkey.ActionOnDelete = checkFKeyCascadeAction(action) 136 | return fkey 137 | } 138 | 139 | // UniqueKey generates UniqueKeyConstraint given columns as strings 140 | func UniqueKey(cols ...string) UniqueKeyConstraint { 141 | return UniqueKeyConstraint{ 142 | "", 143 | cols, 144 | } 145 | } 146 | 147 | // UniqueKeyConstraint is the base struct to define composite unique indexes of tables 148 | type UniqueKeyConstraint struct { 149 | name string 150 | cols []string 151 | } 152 | 153 | // String generates composite unique indices as sql clause 154 | func (c UniqueKeyConstraint) String(dialect Dialect) string { 155 | return fmt.Sprintf("CONSTRAINT %s UNIQUE(%s)", c.name, strings.Join(dialect.EscapeAll(c.cols), ", ")) 156 | } 157 | 158 | // Table sets the constraint name to u_<table>_<col1>_<col2>... built from the table name and columns 159 | // NOTE(review): any name set earlier via Name is overwritten, not kept — confirm this is intended 160 | func (c UniqueKeyConstraint) Table(name string) UniqueKeyConstraint { 161 | return c.Name( 162 | fmt.Sprintf("u_%s_%s", name, strings.Join(c.cols, "_")), 163 | ) 164 | } 165 | 166 | // Name sets the constraint name 167 | func (c UniqueKeyConstraint) Name(name string) UniqueKeyConstraint { 168 | c.name = name 169 | return c 170 | } 171 | -------------------------------------------------------------------------------- /constraint_test.go: -------------------------------------------------------------------------------- 1 |
package qb 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | ) 7 | 8 | func TestConstraints(t *testing.T) { 9 | dialect := NewDialect("default") 10 | 11 | assert.Equal(t, Constraint("NULL"), Null()) 12 | assert.Equal(t, Constraint("NOT NULL"), NotNull()) 13 | assert.Equal(t, Constraint("DEFAULT '5'"), Default(5)) 14 | assert.Equal(t, Constraint("UNIQUE"), Unique()) 15 | assert.Equal(t, ConstraintElem{"CHECK id > 5"}, Constraint("CHECK id > 5")) 16 | assert.Equal(t, "NOT NULL", NotNull().String()) 17 | 18 | assert.Equal(t, "PRIMARY KEY(id)", PrimaryKey("id").String(dialect)) 19 | assert.Equal(t, "PRIMARY KEY(id, email)", PrimaryKey("id", "email").String(dialect)) 20 | 21 | assert.Contains(t, 22 | ForeignKey("user_id").References("users", "id").String(dialect), 23 | "FOREIGN KEY(user_id) REFERENCES users(id)") 24 | assert.Contains(t, 25 | ForeignKey("user_id", "user_email").References("users", "id", "email").String(dialect), 26 | "FOREIGN KEY(user_id, user_email) REFERENCES users(id, email)") 27 | 28 | assert.Panics(t, func() { 29 | ForeignKey().OnUpdate("invalid") 30 | }) 31 | assert.Panics(t, func() { 32 | ForeignKey().OnDelete("invalid") 33 | }) 34 | assert.Equal(t, 35 | "\tFOREIGN KEY(user_id) REFERENCES users(id) ON DELETE SET NULL", 36 | ForeignKey("user_id").References("users", "id").OnDelete("SET NULL").String(dialect), 37 | ) 38 | assert.Equal(t, 39 | "\tFOREIGN KEY(user_id) REFERENCES users(id) ON UPDATE CASCADE ON DELETE CASCADE", 40 | ForeignKey("user_id").References("users", "id").OnUpdate("CASCADE").OnDelete("CASCADE").String(dialect), 41 | ) 42 | 43 | assert.Equal(t, 44 | "CONSTRAINT u_users_id_email UNIQUE(id, email)", 45 | UniqueKey("id", "email").Table("users").String(dialect)) 46 | } 47 | -------------------------------------------------------------------------------- /delete.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | // Delete generates a delete statement 
and returns it for chaining 4 | // qb.Delete(usersTable).Where(qb.Eq(usersTable.C("id"), 5)) 5 | func Delete(table TableElem) DeleteStmt { 6 | return DeleteStmt{ 7 | table: table, 8 | returning: []ColumnElem{}, 9 | } 10 | } 11 | 12 | // DeleteStmt is the base struct for building delete queries 13 | type DeleteStmt struct { 14 | table TableElem 15 | where *WhereClause 16 | returning []ColumnElem 17 | } 18 | 19 | // Where adds a where clause to the current delete statement 20 | func (s DeleteStmt) Where(clause Clause) DeleteStmt { 21 | s.where = &WhereClause{clause} 22 | return s 23 | } 24 | 25 | // Returning appends the given columns to the RETURNING list of the delete statement 26 | // NOTE: Please use it in only postgres dialect, otherwise it'll crash 27 | func (s DeleteStmt) Returning(cols ...ColumnElem) DeleteStmt { 28 | s.returning = append(s.returning, cols...) 29 | return s 30 | } 31 | 32 | // Accept implements Clause.Accept 33 | func (s DeleteStmt) Accept(context *CompilerContext) string { 34 | return context.Compiler.VisitDelete(context, s) 35 | } 36 | 37 | // Build generates a statement out of DeleteStmt object 38 | func (s DeleteStmt) Build(dialect Dialect) *Stmt { 39 | context := NewCompilerContext(dialect) 40 | statement := Statement() 41 | statement.AddSQLClause(s.Accept(context)) 42 | statement.AddBinding(context.Binds...) 43 | 44 | return statement 45 | } 46 | -------------------------------------------------------------------------------- /delete_test.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | ) 7 | 8 | func TestDelete(t *testing.T) { 9 | dialect := NewDialect("default") 10 | 11 | users := Table( 12 | "users", 13 | Column("id", Varchar().Size(36)), 14 | Column("email", Varchar()).Unique(), 15 | ) 16 | 17 | var statement *Stmt 18 | 19 | statement = Delete(users). 20 | Where(Eq(users.C("id"), 5)).
21 | Build(dialect) 22 | 23 | assert.Equal(t, "DELETE FROM users\nWHERE users.id = ?;", statement.SQL()) 24 | assert.Equal(t, []interface{}{5}, statement.Bindings()) 25 | 26 | statement = Delete(users). 27 | Where(Eq(users.C("id"), 5)). 28 | Returning(users.C("id")). 29 | Build(dialect) 30 | 31 | assert.Equal(t, "DELETE FROM users\nWHERE users.id = ?\nRETURNING id;", statement.SQL()) 32 | assert.Equal(t, []interface{}{5}, statement.Bindings()) 33 | 34 | statement = Delete(users).Build(dialect) 35 | assert.Equal(t, "DELETE FROM users;", statement.SQL()) 36 | } 37 | -------------------------------------------------------------------------------- /dialect.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | // NewDialect returns a dialect pointer given driver 4 | func NewDialect(driver string) Dialect { 5 | dialect, ok := DialectRegistry[driver] 6 | if ok { 7 | return dialect 8 | } 9 | panic("No such dialect: " + driver) 10 | } 11 | 12 | // DialectRegistry is a global registry of dialects 13 | var DialectRegistry = make(map[string]Dialect) 14 | 15 | // RegisterDialect add a new dialect to the registry 16 | func RegisterDialect(name string, dialect Dialect) { 17 | DialectRegistry[name] = dialect 18 | } 19 | 20 | // Dialect is the common interface for driver changes 21 | // It is for fixing compatibility issues of different drivers 22 | type Dialect interface { 23 | GetCompiler() Compiler 24 | CompileType(t TypeElem) string 25 | Escape(str string) string 26 | EscapeAll([]string) []string 27 | SetEscaping(escaping bool) 28 | Escaping() bool 29 | AutoIncrement(column *ColumnElem) string 30 | SupportsUnsigned() bool 31 | Driver() string 32 | WrapError(err error) Error 33 | } 34 | 35 | // EscapeAll common escape all 36 | func EscapeAll(dialect Dialect, strings []string) []string { 37 | for k, v := range strings { 38 | strings[k] = dialect.Escape(v) 39 | } 40 | 41 | return strings 42 | } 43 | 
-------------------------------------------------------------------------------- /dialect_default.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import "fmt" 4 | 5 | // DefaultDialect is a type of dialect that can be used with unsupported sql drivers 6 | type DefaultDialect struct { 7 | escaping bool 8 | } 9 | 10 | // NewDefaultDialect instantiates a DefaultDialect 11 | func NewDefaultDialect() Dialect { 12 | return &DefaultDialect{false} 13 | } 14 | 15 | // CompileType compiles a type into its DDL 16 | func (d *DefaultDialect) CompileType(t TypeElem) string { 17 | return DefaultCompileType(t, d.SupportsUnsigned()) 18 | } 19 | 20 | // Escape wraps the string with escape characters of the dialect 21 | func (d *DefaultDialect) Escape(str string) string { 22 | if d.escaping { 23 | return fmt.Sprintf("`%s`", str) 24 | } 25 | return str 26 | } 27 | 28 | // EscapeAll escapes every element of the string slice in place (the caller's slice is modified) and returns it 29 | func (d *DefaultDialect) EscapeAll(strings []string) []string { 30 | return EscapeAll(d, strings[0:]) 31 | } 32 | 33 | // SetEscaping sets the escaping parameter of dialect 34 | func (d *DefaultDialect) SetEscaping(escaping bool) { 35 | d.escaping = escaping 36 | } 37 | 38 | // Escaping gets the escaping parameter of dialect 39 | func (d *DefaultDialect) Escaping() bool { 40 | return d.escaping 41 | } 42 | 43 | // AutoIncrement generates auto increment sql of current dialect 44 | func (d *DefaultDialect) AutoIncrement(column *ColumnElem) string { 45 | colSpec := d.CompileType(column.Type) 46 | if column.Options.PrimaryKey { 47 | colSpec += " PRIMARY KEY" 48 | } 49 | colSpec += " AUTO INCREMENT" 50 | return colSpec 51 | } 52 | 53 | // SupportsUnsigned returns whether driver supports unsigned type mappings or not 54 | func (d *DefaultDialect) SupportsUnsigned() bool { return false } 55 | 56 | // Driver returns the current driver of dialect 57 | func (d *DefaultDialect) Driver() string { 58 | return "" 59 | } 60
| 61 | // GetCompiler returns the default SQLCompiler 62 | func (d *DefaultDialect) GetCompiler() Compiler { 63 | return SQLCompiler{d} 64 | } 65 | 66 | // WrapError wraps a native error in a qb Error 67 | func (d *DefaultDialect) WrapError(err error) Error { 68 | return Error{Orig: err} 69 | } 70 | 71 | func init() { 72 | RegisterDialect("default", NewDefaultDialect()) 73 | RegisterDialect("", NewDefaultDialect()) 74 | } 75 | -------------------------------------------------------------------------------- /dialect_test.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "errors" 5 | "github.com/stretchr/testify/assert" 6 | "testing" 7 | ) 8 | 9 | func TestDefaultDialect(t *testing.T) { 10 | dialect := NewDialect("default") 11 | assert.Implements(t, (*Compiler)(nil), dialect.GetCompiler()) 12 | assert.Equal(t, false, dialect.SupportsUnsigned()) 13 | assert.Equal(t, "test", dialect.Escape("test")) 14 | assert.Equal(t, false, dialect.Escaping()) 15 | dialect.SetEscaping(true) 16 | assert.Equal(t, true, dialect.Escaping()) 17 | assert.Equal(t, "`test`", dialect.Escape("test")) 18 | assert.Equal(t, []string{"`test`"}, dialect.EscapeAll([]string{"test"})) 19 | assert.Equal(t, "", dialect.Driver()) 20 | 21 | autoincCol := Column("id", Int()).PrimaryKey().AutoIncrement() 22 | assert.Equal(t, 23 | "INT PRIMARY KEY AUTO INCREMENT", 24 | dialect.AutoIncrement(&autoincCol)) 25 | 26 | err := errors.New("xxx") 27 | qbErr := dialect.WrapError(err) 28 | assert.Equal(t, err, qbErr.Orig) 29 | } 30 | 31 | func TestGetDialect(t *testing.T) { 32 | assert.Panics(t, func() { 33 | NewDialect("unknown") 34 | }) 35 | } 36 | -------------------------------------------------------------------------------- /dialects/mysql/README.md: -------------------------------------------------------------------------------- 1 | # Mysql dialect 2 | 3 | Implements the Dialect interface for a MySQL database, using 4 | the following sql 
driver: 5 | 6 | github.com/go-sql-driver/mysql 7 | 8 | ## Error translation 9 | 10 | Because the driver only passes the numeric error codes, we had to redefine 11 | all the error constants in errors.go. 12 | This is done automatically from the MySQL headers by running the following 13 | command in the current directory: 14 | 15 | go generate 16 | 17 | The script "tools/generrors.go" does the actual work of writing the source 18 | file. 19 | It requires the headers to be installed, and was last executed with 20 | libmysqlclient-dev 5.7.16-0ubuntu0.16.04.1 on Ubuntu 16.04. 21 | -------------------------------------------------------------------------------- /dialects/mysql/mysql.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/go-sql-driver/mysql" 8 | "github.com/slicebit/qb" 9 | ) 10 | 11 | //go:generate go run ./tools/generrors.go 12 | //go:generate gofmt -w errors.go 13 | 14 | const () 15 | 16 | // Dialect is a type of dialect that can be used with mysql driver 17 | type Dialect struct { 18 | escaping bool 19 | } 20 | 21 | // NewDialect returns a new mysql Dialect 22 | func NewDialect() qb.Dialect { 23 | return &Dialect{false} 24 | } 25 | 26 | func init() { 27 | qb.RegisterDialect("mysql", NewDialect()) 28 | } 29 | 30 | // CompileType compiles a type into its DDL 31 | func (d *Dialect) CompileType(t qb.TypeElem) string { 32 | if t.Name == "UUID" { 33 | return "VARCHAR(36)" 34 | } 35 | return qb.DefaultCompileType(t, d.SupportsUnsigned()) 36 | } 37 | 38 | // Escape wraps the string with escape characters of the dialect 39 | func (d *Dialect) Escape(str string) string { 40 | if d.escaping { 41 | return fmt.Sprintf("`%s`", str) 42 | } 43 | return str 44 | } 45 | 46 | // EscapeAll escapes every element of the string slice in place (the caller's slice is modified) and returns it 47 | func (d *Dialect) EscapeAll(strings []string) []string { 48 | return qb.EscapeAll(d, strings[0:]) 49 | } 50 | 51 | // SetEscaping
sets the escaping parameter of dialect 52 | func (d *Dialect) SetEscaping(escaping bool) { 53 | d.escaping = escaping 54 | } 55 | 56 | // Escaping gets the escaping parameter of dialect 57 | func (d *Dialect) Escaping() bool { 58 | return d.escaping 59 | } 60 | 61 | // AutoIncrement generates auto increment sql of current dialect 62 | func (d *Dialect) AutoIncrement(column *qb.ColumnElem) string { 63 | colSpec := d.CompileType(column.Type) 64 | if column.Options.InlinePrimaryKey { 65 | colSpec += " PRIMARY KEY" 66 | } 67 | colSpec += " AUTO_INCREMENT" 68 | return colSpec 69 | } 70 | 71 | // SupportsUnsigned returns whether driver supports unsigned type mappings or not 72 | func (d *Dialect) SupportsUnsigned() bool { return true } 73 | 74 | // Driver returns the current driver of dialect 75 | func (d *Dialect) Driver() string { 76 | return "mysql" 77 | } 78 | 79 | // GetCompiler returns a MysqlCompiler 80 | func (d *Dialect) GetCompiler() qb.Compiler { 81 | return MysqlCompiler{qb.NewSQLCompiler(d)} 82 | } 83 | 84 | // WrapError wraps a native error in a qb Error 85 | func (d *Dialect) WrapError(err error) qb.Error { 86 | qbErr := qb.Error{Orig: err} 87 | mErr, ok := err.(*mysql.MySQLError) 88 | if !ok { 89 | return qbErr 90 | } 91 | // Error mapping logic is copied from MySQL-python-1.2.5 92 | switch mErr.Number { 93 | case CR_COMMANDS_OUT_OF_SYNC, 94 | ER_DB_CREATE_EXISTS, 95 | ER_SYNTAX_ERROR, 96 | ER_PARSE_ERROR, 97 | ER_NO_SUCH_TABLE, 98 | ER_WRONG_DB_NAME, 99 | ER_WRONG_TABLE_NAME, 100 | ER_FIELD_SPECIFIED_TWICE, 101 | ER_INVALID_GROUP_FUNC_USE, 102 | ER_UNSUPPORTED_EXTENSION, 103 | ER_TABLE_MUST_HAVE_COLUMNS, 104 | ER_CANT_DO_THIS_DURING_AN_TRANSACTION: 105 | qbErr.Code = qb.ErrProgramming 106 | case WARN_DATA_TRUNCATED, 107 | ER_WARN_DATA_OUT_OF_RANGE, 108 | ER_NO_DEFAULT, 109 | ER_PRIMARY_CANT_HAVE_NULL, 110 | ER_DATA_TOO_LONG, 111 | ER_DATETIME_FUNCTION_OVERFLOW: 112 | qbErr.Code = qb.ErrData 113 | case ER_DUP_ENTRY, 114 | ER_DUP_UNIQUE, 115 | 
ER_NO_REFERENCED_ROW, 116 | ER_NO_REFERENCED_ROW_2, 117 | ER_ROW_IS_REFERENCED, 118 | ER_ROW_IS_REFERENCED_2, 119 | ER_CANNOT_ADD_FOREIGN: 120 | qbErr.Code = qb.ErrIntegrity 121 | case ER_WARNING_NOT_COMPLETE_ROLLBACK, 122 | ER_NOT_SUPPORTED_YET, 123 | ER_FEATURE_DISABLED, 124 | ER_UNKNOWN_STORAGE_ENGINE: 125 | qbErr.Code = qb.ErrNotSupported 126 | default: 127 | if mErr.Number < 1000 { 128 | qbErr.Code = qb.ErrInternal 129 | } else { 130 | qbErr.Code = qb.ErrOperational 131 | } 132 | } 133 | return qbErr 134 | } 135 | 136 | // MysqlCompiler is a SQLCompiler specialised for Mysql 137 | type MysqlCompiler struct { 138 | qb.SQLCompiler 139 | } 140 | 141 | // VisitUpsert generates INSERT INTO ... VALUES ... ON DUPLICATE KEY UPDATE ... 142 | func (MysqlCompiler) VisitUpsert(context *qb.CompilerContext, upsert qb.UpsertStmt) string { 143 | var ( 144 | colNames []string 145 | values []string 146 | ) 147 | 148 | for k, v := range upsert.ValuesMap { 149 | colNames = append(colNames, context.Compiler.VisitLabel(context, k)) 150 | context.Binds = append(context.Binds, v) 151 | values = append(values, "?") 152 | } 153 | 154 | updates := []string{} 155 | for k, v := range upsert.ValuesMap { 156 | updates = append(updates, fmt.Sprintf( 157 | "%s = %s", 158 | context.Dialect.Escape(k), 159 | "?", 160 | )) 161 | context.Binds = append(context.Binds, v) 162 | } 163 | 164 | sql := fmt.Sprintf( 165 | "INSERT INTO %s(%s)\nVALUES(%s)\nON DUPLICATE KEY UPDATE %s", 166 | context.Dialect.Escape(upsert.Table.Name), 167 | strings.Join(colNames, ", "), 168 | strings.Join(values, ", "), 169 | strings.Join(updates, ", "), 170 | ) 171 | 172 | return sql 173 | } 174 | -------------------------------------------------------------------------------- /dialects/mysql/mysql_test.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "database/sql" 5 | "errors" 6 | "os" 7 | "testing" 8 | "time" 9 | 10 | 
"github.com/go-sql-driver/mysql" 11 | "github.com/slicebit/qb" 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/suite" 14 | ) 15 | 16 | var mysqlDsn = "root:@tcp(localhost:3306)/qb_test?charset=utf8" 17 | 18 | type MysqlTestSuite struct { 19 | suite.Suite 20 | engine *qb.Engine 21 | metadata *qb.MetaDataElem 22 | } 23 | 24 | func (suite *MysqlTestSuite) SetupTest() { 25 | var err error 26 | suite.engine, err = qb.New("mysql", mysqlDsn) 27 | 28 | assert.Nil(suite.T(), err) 29 | err = suite.engine.Ping() 30 | 31 | assert.Nil(suite.T(), err) 32 | suite.metadata = qb.MetaData() 33 | 34 | assert.Nil(suite.T(), err) 35 | assert.NotNil(suite.T(), suite.engine) 36 | 37 | suite.engine.DB().Exec("DROP TABLE IF EXISTS user") 38 | suite.engine.DB().Exec("DROP TABLE IF EXISTS session") 39 | } 40 | 41 | func (suite *MysqlTestSuite) TestUUID() { 42 | assert.Equal(suite.T(), "VARCHAR(36)", suite.engine.Dialect().CompileType(qb.UUID())) 43 | } 44 | 45 | func (suite *MysqlTestSuite) TestDialect() { 46 | dialect := qb.NewDialect("mysql") 47 | assert.Equal(suite.T(), true, dialect.SupportsUnsigned()) 48 | assert.Equal(suite.T(), "test", dialect.Escape("test")) 49 | assert.Equal(suite.T(), false, dialect.Escaping()) 50 | dialect.SetEscaping(true) 51 | assert.Equal(suite.T(), true, dialect.Escaping()) 52 | assert.Equal(suite.T(), "`test`", dialect.Escape("test")) 53 | assert.Equal(suite.T(), []string{"`test`"}, dialect.EscapeAll([]string{"test"})) 54 | assert.Equal(suite.T(), "mysql", dialect.Driver()) 55 | } 56 | 57 | func (suite *MysqlTestSuite) TestWrapError() { 58 | dialect := qb.NewDialect("mysql") 59 | err := errors.New("xxx") 60 | qbErr := dialect.WrapError(err) 61 | assert.Equal(suite.T(), err, qbErr.Orig) 62 | 63 | for _, tt := range []struct { 64 | mErr uint16 65 | qbCode qb.ErrorCode 66 | }{ 67 | {ER_SYNTAX_ERROR, qb.ErrProgramming}, 68 | {ER_DATA_TOO_LONG, qb.ErrData}, 69 | {ER_CANNOT_ADD_FOREIGN, qb.ErrIntegrity}, 70 | {ER_FEATURE_DISABLED, 
qb.ErrNotSupported}, 71 | {ER_CHECKREAD, qb.ErrOperational}, 72 | {999, qb.ErrInternal}, 73 | } { 74 | mErr := mysql.MySQLError{Number: tt.mErr} 75 | qbErr := dialect.WrapError(&mErr) 76 | assert.Equal(suite.T(), tt.qbCode, qbErr.Code) 77 | } 78 | } 79 | 80 | func (suite *MysqlTestSuite) TestMysql() { 81 | type User struct { 82 | ID string `db:"id"` 83 | Email string `db:"email"` 84 | FullName string `db:"full_name"` 85 | Bio sql.NullString `db:"bio"` 86 | Oscars int `db:"oscars"` 87 | } 88 | 89 | type Session struct { 90 | ID int64 `db:"id"` 91 | UserID string `db:"user_id"` 92 | AuthToken string `db:"auth_token"` 93 | CreatedAt *time.Time `db:"created_at"` 94 | ExpiresAt *time.Time `db:"expires_at"` 95 | } 96 | 97 | users := qb.Table( 98 | "user", 99 | qb.Column("id", qb.Varchar().Size(40)), 100 | qb.Column("email", qb.Varchar()).Unique().NotNull(), 101 | qb.Column("full_name", qb.Varchar()).NotNull(), 102 | qb.Column("bio", qb.Text()).Null(), 103 | qb.Column("oscars", qb.Int()).Default(0), 104 | qb.PrimaryKey("id"), 105 | ) 106 | 107 | sessions := qb.Table( 108 | "session", 109 | qb.Column("id", qb.BigInt()).AutoIncrement(), 110 | qb.Column("user_id", qb.Varchar().Size(40)).NotNull(), 111 | qb.Column("auth_token", qb.Varchar().Size(40)).NotNull().Unique(), 112 | qb.Column("created_at", qb.Timestamp()).Null(), 113 | qb.Column("expires_at", qb.Timestamp()).Null(), 114 | qb.PrimaryKey("id"), 115 | qb.ForeignKey("user_id").References("user", "id"), 116 | ) 117 | 118 | var err error 119 | 120 | suite.metadata.AddTable(users) 121 | suite.metadata.AddTable(sessions) 122 | 123 | err = suite.metadata.CreateAll(suite.engine) 124 | assert.Nil(suite.T(), err) 125 | 126 | ins := qb.Insert(users).Values(map[string]interface{}{ 127 | "id": "b6f8bfe3-a830-441a-a097-1777e6bfae95", 128 | "email": "jack@nicholson.com", 129 | "full_name": "Jack Nicholson", 130 | "bio": sql.NullString{String: "Jack Nicholson, an American actor, producer, screen-writer and director, is a three-time 
Academy Award winner and twelve-time nominee.", Valid: true}, 131 | }) 132 | 133 | _, err = suite.engine.Exec(ins) 134 | assert.Nil(suite.T(), err) 135 | 136 | ins = qb.Insert(sessions).Values(map[string]interface{}{ 137 | "user_id": "b6f8bfe3-a830-441a-a097-1777e6bfae95", 138 | "auth_token": "e4968197-6137-47a4-ba79-690d8c552248", 139 | "created_at": time.Now(), 140 | "expires_at": time.Now().Add(24 * time.Hour), 141 | }) 142 | 143 | res, err := suite.engine.Exec(ins) 144 | assert.Nil(suite.T(), err) 145 | 146 | lastInsertID, err := res.LastInsertId() 147 | assert.Nil(suite.T(), err) 148 | assert.Equal(suite.T(), lastInsertID, int64(1)) 149 | 150 | rowsAffected, err := res.RowsAffected() 151 | assert.Equal(suite.T(), rowsAffected, int64(1)) 152 | 153 | // find user 154 | var user User 155 | 156 | sel := qb.Select(users.C("id"), users.C("email"), users.C("full_name"), users.C("bio")). 157 | From(users). 158 | Where(users.C("id").Eq("b6f8bfe3-a830-441a-a097-1777e6bfae95")) 159 | 160 | err = suite.engine.Get(sel, &user) 161 | assert.Nil(suite.T(), err) 162 | 163 | assert.Equal(suite.T(), "jack@nicholson.com", user.Email) 164 | assert.Equal(suite.T(), "Jack Nicholson", user.FullName) 165 | assert.Equal(suite.T(), "Jack Nicholson, an American actor, producer, screen-writer and director, is a three-time Academy Award winner and twelve-time nominee.", user.Bio.String) 166 | 167 | // select using join 168 | sessionSlice := []Session{} 169 | sel = qb.Select(sessions.C("id"), sessions.C("user_id"), sessions.C("auth_token")). 170 | From(sessions). 171 | InnerJoin(users, sessions.C("user_id"), users.C("id")). 
172 | Where(users.C("id").Eq("b6f8bfe3-a830-441a-a097-1777e6bfae95")) 173 | 174 | err = suite.engine.Select(sel, &sessionSlice) 175 | 176 | assert.Nil(suite.T(), err) 177 | assert.Equal(suite.T(), len(sessionSlice), 1) 178 | 179 | assert.Equal(suite.T(), sessionSlice[0].ID, int64(1)) 180 | assert.Equal(suite.T(), sessionSlice[0].UserID, "b6f8bfe3-a830-441a-a097-1777e6bfae95") 181 | assert.Equal(suite.T(), sessionSlice[0].AuthToken, "e4968197-6137-47a4-ba79-690d8c552248") 182 | 183 | upd := qb.Update(users). 184 | Values(map[string]interface{}{ 185 | "bio": sql.NullString{String: "nil", Valid: false}, 186 | }).Where(users.C("id").Eq("b6f8bfe3-a830-441a-a097-1777e6bfae95")) 187 | 188 | _, err = suite.engine.Exec(upd) 189 | assert.Nil(suite.T(), err) 190 | 191 | sel = qb.Select(users.C("id"), users.C("email"), users.C("full_name"), users.C("bio")). 192 | From(users). 193 | Where(users.C("id").Eq("b6f8bfe3-a830-441a-a097-1777e6bfae95")) 194 | 195 | err = suite.engine.Get(sel, &user) 196 | assert.Equal(suite.T(), user.Bio, sql.NullString{String: "", Valid: false}) 197 | 198 | del := qb.Delete(sessions).Where(sessions.C("auth_token").Eq("99e591f8-1025-41ef-a833-6904a0f89a38")) 199 | _, err = suite.engine.Exec(del) 200 | assert.Nil(suite.T(), err) 201 | 202 | // drop tables 203 | assert.Nil(suite.T(), suite.metadata.DropAll(suite.engine)) 204 | } 205 | 206 | func (suite *MysqlTestSuite) TestUpsert() { 207 | users := qb.Table( 208 | "users", 209 | qb.Column("id", qb.Varchar().Size(36)), 210 | qb.Column("email", qb.Varchar()).Unique(), 211 | qb.Column("created_at", qb.Timestamp()).NotNull(), 212 | qb.PrimaryKey("id"), 213 | ) 214 | 215 | now := time.Now().UTC().String() 216 | 217 | ups := qb.Upsert(users).Values(map[string]interface{}{ 218 | "id": "9883cf81-3b56-4151-ae4e-3903c5bc436d", 219 | "email": "al@pacino.com", 220 | "created_at": now, 221 | }) 222 | 223 | ctx := qb.NewCompilerContext(NewDialect()) 224 | sql := ups.Accept(ctx) 225 | binds := ctx.Binds 226 | 227 | 
assert.Contains(suite.T(), sql, "INSERT INTO users") 228 | assert.Contains(suite.T(), sql, "id", "email", "created_at") 229 | assert.Contains(suite.T(), sql, "VALUES(?, ?, ?)") 230 | assert.Contains(suite.T(), sql, "ON DUPLICATE KEY UPDATE") 231 | assert.Contains(suite.T(), sql, "id = ?", "email = ?", "created_at = ?") 232 | assert.Contains(suite.T(), binds, "9883cf81-3b56-4151-ae4e-3903c5bc436d") 233 | assert.Contains(suite.T(), binds, "al@pacino.com") 234 | assert.Equal(suite.T(), 6, len(binds)) 235 | } 236 | 237 | func TestMysqlTestSuite(t *testing.T) { 238 | suite.Run(t, new(MysqlTestSuite)) 239 | } 240 | 241 | func init() { 242 | dsn := os.Getenv("QBTEST_MYSQL") 243 | if dsn != "" { 244 | mysqlDsn = dsn 245 | } 246 | } 247 | -------------------------------------------------------------------------------- /dialects/mysql/tools/generrors.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "io" 7 | "os" 8 | "regexp" 9 | "strconv" 10 | "strings" 11 | ) 12 | 13 | type constant struct { 14 | Name string 15 | Value int 16 | } 17 | 18 | func readConstants(filename string, prefix string) []constant { 19 | var constants []constant 20 | f, err := os.Open(filename) 21 | if err != nil { 22 | panic(err) 23 | } 24 | defer f.Close() 25 | scanner := bufio.NewScanner(f) 26 | splitter := regexp.MustCompile("[ \t]+") 27 | for scanner.Scan() { 28 | line := scanner.Text() 29 | if strings.Contains(line, "#define "+prefix) { 30 | tokens := splitter.Split(line, -1) 31 | if tokens[0] != "#define" { 32 | panic(tokens[0]) 33 | } 34 | value, err := strconv.Atoi(tokens[2]) 35 | if err != nil { 36 | // Values that are not plain integers cannot become int constants: skip 37 | // them instead of emitting the zero value under their name. 38 | os.Stderr.WriteString("Skipped " + line + "\n") 39 | continue 40 | } 41 | constants = append(constants, constant{Name: tokens[1], Value: value}) 42 | } 43 | } 44 | if err := scanner.Err(); err != nil { 45 | panic(err) 46 | } 47 | return constants 48 | } 49 | 50 | func writeConstant(w io.Writer, c
constant) { 51 | fmt.Fprintf(w, " %s = %d\n", c.Name, c.Value) 52 | } 53 | 54 | func main() { 55 | f, err := os.OpenFile("errors.go", os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0666) 56 | if err != nil { 57 | panic(err) 58 | } 59 | defer f.Close() 60 | f.WriteString(`package mysql 61 | 62 | // Errors defined in errmsg.h 63 | const ( 64 | `) 65 | 66 | for _, c := range readConstants("/usr/include/mysql/errmsg.h", "CR_") { 67 | writeConstant(f, c) 68 | } 69 | f.WriteString(`) 70 | 71 | // Errors defined in mysqld_error.h 72 | const ( 73 | `) 74 | for _, c := range readConstants("/usr/include/mysql/mysqld_error.h", "ER_") { 75 | writeConstant(f, c) 76 | } 77 | for _, c := range readConstants("/usr/include/mysql/mysqld_error.h", "WARN_") { 78 | writeConstant(f, c) 79 | } 80 | f.WriteString(")\n") 81 | } 82 | -------------------------------------------------------------------------------- /dialects/postgres/postgres.go: -------------------------------------------------------------------------------- 1 | package postgres 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/lib/pq" 8 | "github.com/slicebit/qb" 9 | ) 10 | 11 | // Dialect is a type of dialect that can be used with postgres driver 12 | type Dialect struct { 13 | bindingIndex int 14 | escaping bool 15 | } 16 | 17 | // NewDialect returns a new PostgresDialect 18 | func NewDialect() qb.Dialect { 19 | return &Dialect{escaping: false, bindingIndex: 0} 20 | } 21 | 22 | func init() { 23 | qb.RegisterDialect("postgres", NewDialect()) 24 | } 25 | 26 | // CompileType compiles a type into its DDL 27 | func (d *Dialect) CompileType(t qb.TypeElem) string { 28 | if t.Name == "BLOB" { 29 | return "bytea" 30 | } 31 | return qb.DefaultCompileType(t, d.SupportsUnsigned()) 32 | } 33 | 34 | // Escape wraps the string with escape characters of the dialect 35 | func (d *Dialect) Escape(str string) string { 36 | if d.escaping { 37 | return fmt.Sprintf("\"%s\"", str) 38 | } 39 | return str 40 | } 41 | 42 | // EscapeAll wraps all 
elements of string array 43 | func (d *Dialect) EscapeAll(strings []string) []string { 44 | return qb.EscapeAll(d, strings[0:]) 45 | } 46 | 47 | // SetEscaping sets the escaping parameter of dialect 48 | func (d *Dialect) SetEscaping(escaping bool) { 49 | d.escaping = escaping 50 | } 51 | 52 | // Escaping gets the escaping parameter of dialect 53 | func (d *Dialect) Escaping() bool { 54 | return d.escaping 55 | } 56 | 57 | // AutoIncrement generates auto increment sql of current dialect 58 | func (d *Dialect) AutoIncrement(column *qb.ColumnElem) string { 59 | var colSpec string 60 | if column.Type.Name == "BIGINT" { 61 | colSpec = "BIGSERIAL" 62 | } else if column.Type.Name == "SMALLINT" { 63 | colSpec = "SMALLSERIAL" 64 | } else { 65 | colSpec = "SERIAL" 66 | } 67 | if column.Options.InlinePrimaryKey { 68 | colSpec += " PRIMARY KEY" 69 | } 70 | return colSpec 71 | } 72 | 73 | // SupportsUnsigned returns whether driver supports unsigned type mappings or not 74 | func (d *Dialect) SupportsUnsigned() bool { return false } 75 | 76 | // Driver returns the current driver of dialect 77 | func (d *Dialect) Driver() string { 78 | return "postgres" 79 | } 80 | 81 | // GetCompiler returns a PostgresCompiler 82 | func (d *Dialect) GetCompiler() qb.Compiler { 83 | return PostgresCompiler{qb.NewSQLCompiler(d)} 84 | } 85 | 86 | // WrapError wraps a native error in a qb Error 87 | func (d *Dialect) WrapError(err error) (qbErr qb.Error) { 88 | qbErr.Orig = err 89 | pgErr, ok := err.(*pq.Error) 90 | if !ok { 91 | return 92 | } 93 | switch pgErr.Code.Class() { 94 | case "0A": // Class 0A - Feature Not Supported 95 | qbErr.Code = qb.ErrNotSupported 96 | case "20", // Class 20 - Case Not Found 97 | "21": // Class 21 - Cardinality Violation 98 | qbErr.Code = qb.ErrProgramming 99 | case "22": // Class 22 - Data Exception 100 | qbErr.Code = qb.ErrData 101 | case "23": // Class 23 - Integrity Constraint Violation 102 | qbErr.Code = qb.ErrIntegrity 103 | case "24", // Class 24 - Invalid 
Cursor State 104 | "25": // Class 25 - Invalid Transaction State 105 | qbErr.Code = qb.ErrInternal 106 | case "26", // Class 26 - Invalid SQL Statement Name 107 | "27", // Class 27 - Triggered Data Change Violation 108 | "28": // Class 28 - Invalid Authorization Specification 109 | qbErr.Code = qb.ErrOperational 110 | case "2B", // Class 2B - Dependent Privilege Descriptors Still Exist 111 | "2D", // Class 2D - Invalid Transaction Termination 112 | "2F": // Class 2F - SQL Routine Exception 113 | qbErr.Code = qb.ErrInternal 114 | case "34": // Class 34 - Invalid Cursor Name 115 | qbErr.Code = qb.ErrOperational 116 | case "38", // Class 38 - External Routine Exception 117 | "39", // Class 39 - External Routine Invocation Exception 118 | "3B": // Class 3B - Savepoint Exception 119 | qbErr.Code = qb.ErrInternal 120 | case "3D", // Class 3D - Invalid Catalog Name 121 | "3F": // Class 3F - Invalid Schema Name 122 | qbErr.Code = qb.ErrProgramming 123 | case "40": // Class 40 - Transaction Rollback 124 | qbErr.Code = qb.ErrOperational 125 | case "42", // Class 42 - Syntax Error or Access Rule Violation 126 | "44": // Class 44 - WITH CHECK OPTION Violation 127 | qbErr.Code = qb.ErrProgramming 128 | case "53", // Class 53 - Insufficient Resources 129 | "54", // Class 54 - Program Limit Exceeded 130 | "55", // Class 55 - Object Not In Prerequisite State 131 | "57", // Class 57 - Operator Intervention 132 | "58": // Class 58 - System Error (errors external to PostgreSQL itself) 133 | qbErr.Code = qb.ErrOperational 134 | 135 | case "F0": // Class F0 - Configuration File Error 136 | qbErr.Code = qb.ErrInternal 137 | case "HV": // Class HV - Foreign Data Wrapper Error (SQL/MED) 138 | qbErr.Code = qb.ErrOperational 139 | case "P0", // Class P0 - PL/pgSQL Error 140 | "XX": // Class XX - Internal Error 141 | qbErr.Code = qb.ErrInternal 142 | default: 143 | qbErr.Code = qb.ErrDatabase 144 | } 145 | return 146 | } 147 | 148 | // PostgresCompiler is a SQLCompiler specialised for 
PostgreSQL 149 | type PostgresCompiler struct { 150 | qb.SQLCompiler 151 | } 152 | 153 | // VisitBind records a bound value in context.Binds and renders it as a positional $N placeholder 154 | func (PostgresCompiler) VisitBind(context *qb.CompilerContext, bind qb.BindClause) string { 155 | context.Binds = append(context.Binds, bind.Value) 156 | return fmt.Sprintf("$%d", len(context.Binds)) 157 | } 158 | 159 | // VisitUpsert generates INSERT INTO ... VALUES ... ON CONFLICT(...) DO UPDATE SET ... [RETURNING ...] 160 | func (PostgresCompiler) VisitUpsert(context *qb.CompilerContext, upsert qb.UpsertStmt) string { 161 | var ( 162 | colNames []string 163 | values []string 164 | ) 165 | for k, v := range upsert.ValuesMap { 166 | colNames = append(colNames, context.Compiler.VisitLabel(context, k)) 167 | context.Binds = append(context.Binds, v) 168 | values = append(values, fmt.Sprintf("$%d", len(context.Binds))) 169 | } 170 | 171 | var updates []string 172 | for k, v := range upsert.ValuesMap { 173 | context.Binds = append(context.Binds, v) 174 | updates = append(updates, fmt.Sprintf( 175 | "%s = %s", 176 | context.Dialect.Escape(k), 177 | fmt.Sprintf("$%d", len(context.Binds)), 178 | )) 179 | } 180 | 181 | var uniqueCols []string 182 | for _, c := range upsert.Table.PrimaryCols() { 183 | uniqueCols = append(uniqueCols, context.Compiler.VisitLabel(context, c.Name)) 184 | } 185 | 186 | sql := fmt.Sprintf( 187 | "INSERT INTO %s(%s)\nVALUES(%s)\nON CONFLICT (%s) DO UPDATE SET %s", 188 | context.Compiler.VisitLabel(context, upsert.Table.Name), 189 | strings.Join(colNames, ", "), 190 | strings.Join(values, ", "), 191 | strings.Join(uniqueCols, ", "), 192 | strings.Join(updates, ", ")) 193 | // The optional RETURNING clause starts on its own line, after the last SET assignment. 194 | var returning []string 195 | for _, r := range upsert.ReturningCols { 196 | returning = append(returning, context.Compiler.VisitLabel(context, r.Name)) 197 | } 198 | if len(returning) > 0 { 199 | sql += fmt.Sprintf( 200 | "\nRETURNING %s", 201 | strings.Join(returning, ", "), 202 | ) 203 | } 204 | return sql 205 | } 206 |
-------------------------------------------------------------------------------- /dialects/postgres/postgres_test.go: -------------------------------------------------------------------------------- 1 | package postgres 2 | 3 | import ( 4 | "database/sql" 5 | "errors" 6 | "fmt" 7 | "os" 8 | "testing" 9 | "time" 10 | 11 | "github.com/lib/pq" 12 | "github.com/slicebit/qb" 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/suite" 15 | ) 16 | 17 | var postgresDsn = "user=postgres dbname=qb_test sslmode=disable" 18 | 19 | type PostgresTestSuite struct { 20 | suite.Suite 21 | engine *qb.Engine 22 | metadata *qb.MetaDataElem 23 | ctx *qb.CompilerContext 24 | } 25 | 26 | func (suite *PostgresTestSuite) SetupTest() { 27 | var err error 28 | 29 | suite.engine, err = qb.New("postgres", postgresDsn) 30 | suite.ctx = qb.NewCompilerContext(suite.engine.Dialect()) 31 | 32 | assert.Nil(suite.T(), err) 33 | assert.NotNil(suite.T(), suite.engine) 34 | 35 | suite.metadata = qb.MetaData() 36 | } 37 | 38 | func (suite *PostgresTestSuite) TestPostgresBlob() { 39 | dialect := NewDialect() 40 | assert.Equal(suite.T(), "bytea", dialect.CompileType(qb.Blob())) 41 | } 42 | 43 | func (suite *PostgresTestSuite) TestUUID() { 44 | dialect := NewDialect() 45 | assert.Equal(suite.T(), "UUID", dialect.CompileType(qb.UUID())) 46 | } 47 | 48 | func (suite *PostgresTestSuite) TestDialectSimple() { 49 | dialect := NewDialect() 50 | assert.Equal(suite.T(), false, dialect.SupportsUnsigned()) 51 | assert.Equal(suite.T(), "test", dialect.Escape("test")) 52 | assert.Equal(suite.T(), false, dialect.Escaping()) 53 | assert.Equal(suite.T(), "postgres", dialect.Driver()) 54 | } 55 | 56 | func (suite *PostgresTestSuite) TestDialectEscaping() { 57 | dialect := NewDialect() 58 | dialect.SetEscaping(true) 59 | assert.Equal(suite.T(), true, dialect.Escaping()) 60 | assert.Equal(suite.T(), "\"test\"", dialect.Escape("test")) 61 | assert.Equal(suite.T(), []string{"\"test\""}, 
dialect.EscapeAll([]string{"test"})) 62 | } 63 | 64 | func (suite *PostgresTestSuite) TestDialectIntAutoIncrement() { 65 | dialect := NewDialect() 66 | col := qb.Column("autoinc", qb.Int()).AutoIncrement() 67 | assert.Equal(suite.T(), "SERIAL", dialect.AutoIncrement(&col)) 68 | } 69 | 70 | func (suite *PostgresTestSuite) TestDialectBigIntAutoIncrement() { 71 | dialect := NewDialect() 72 | col := qb.Column("autoinc", qb.BigInt()).AutoIncrement() 73 | assert.Equal(suite.T(), "BIGSERIAL", dialect.AutoIncrement(&col)) 74 | } 75 | 76 | func (suite *PostgresTestSuite) TestDialectSmallIntAutoIncrement() { 77 | dialect := NewDialect() 78 | col := qb.Column("autoinc", qb.SmallInt()).AutoIncrement() 79 | assert.Equal(suite.T(), "SMALLSERIAL", dialect.AutoIncrement(&col)) 80 | } 81 | 82 | func (suite *PostgresTestSuite) TestWrapError() { 83 | err := errors.New("xxx") 84 | dialect := NewDialect() 85 | qbErr := dialect.WrapError(err) 86 | assert.Equal(suite.T(), err, qbErr.Orig) 87 | 88 | for _, tt := range []struct { 89 | pgCode string 90 | qbCode qb.ErrorCode 91 | }{ 92 | {"0A000", qb.ErrNotSupported}, 93 | {"20000", qb.ErrProgramming}, 94 | {"21000", qb.ErrProgramming}, 95 | {"22000", qb.ErrData}, 96 | {"23000", qb.ErrIntegrity}, 97 | {"24000", qb.ErrInternal}, 98 | {"27000", qb.ErrOperational}, 99 | {"2D000", qb.ErrInternal}, 100 | {"34000", qb.ErrOperational}, 101 | {"39000", qb.ErrInternal}, 102 | {"3D000", qb.ErrProgramming}, 103 | {"40000", qb.ErrOperational}, 104 | {"42000", qb.ErrProgramming}, 105 | {"54000", qb.ErrOperational}, 106 | {"F0000", qb.ErrInternal}, 107 | {"HV000", qb.ErrOperational}, 108 | {"P0000", qb.ErrInternal}, 109 | {"ZZ000", qb.ErrDatabase}, 110 | } { 111 | pgErr := pq.Error{Code: pq.ErrorCode(tt.pgCode)} 112 | qbErr := suite.engine.Dialect().WrapError(&pgErr) 113 | assert.Equal(suite.T(), tt.qbCode, qbErr.Code) 114 | } 115 | } 116 | 117 | func (suite *PostgresTestSuite) TestPostgres() { 118 | type Actor struct { 119 | ID string `db:"id"` 120 | 
Email string `db:"email"` 121 | FullName string `db:"full_name"` 122 | Bio sql.NullString `db:"bio"` 123 | Oscars int `db:"oscars"` 124 | IgnoreField string `db:"-"` 125 | } 126 | 127 | type Session struct { 128 | ID int64 `db:"id"` 129 | ActorID string `db:"actor_id"` 130 | AuthToken string `db:"auth_token"` 131 | CreatedAt *time.Time `db:"created_at"` 132 | ExpiresAt *time.Time `db:"expires_at"` 133 | } 134 | 135 | actorsTable := qb.Table( 136 | "actors", 137 | qb.Column("id", qb.Type("UUID")), 138 | qb.Column("email", qb.Varchar()).Unique().NotNull(), 139 | qb.Column("full_name", qb.Varchar()).NotNull(), 140 | qb.Column("bio", qb.Text()).Null(), 141 | qb.Column("oscars", qb.Int()).Default(0), 142 | qb.PrimaryKey("id"), 143 | ) 144 | 145 | sessionsTable := qb.Table( 146 | "sessions", 147 | qb.Column("id", qb.Type("BIGSERIAL")), 148 | qb.Column("actor_id", qb.Type("UUID")), 149 | qb.Column("auth_token", qb.Type("UUID")), 150 | qb.Column("created_at", qb.Timestamp()).NotNull(), 151 | qb.Column("expires_at", qb.Timestamp()).NotNull(), 152 | qb.PrimaryKey("id"), 153 | qb.ForeignKey("actor_id").References("actors", "id"), 154 | ).Index("created_at", "expires_at") 155 | 156 | var err error 157 | 158 | suite.metadata.AddTable(actorsTable) 159 | suite.metadata.AddTable(sessionsTable) 160 | 161 | err = suite.metadata.CreateAll(suite.engine) 162 | fmt.Println("Metadata create all", err) 163 | assert.Nil(suite.T(), err) 164 | 165 | ins := qb.Insert(actorsTable).Values(map[string]interface{}{ 166 | "id": "b6f8bfe3-a830-441a-a097-1777e6bfae95", 167 | "email": "jack@nicholson.com", 168 | "full_name": "Jack Nicholson", 169 | "bio": sql.NullString{String: "Jack Nicholson, an American actor, producer, screen-writer and director, is a three-time Academy Award winner and twelve-time nominee.", Valid: true}, 170 | }) 171 | 172 | _, err = suite.engine.Exec(ins) 173 | 174 | ins = qb.Insert(sessionsTable).Values(map[string]interface{}{ 175 | "actor_id": 
"b6f8bfe3-a830-441a-a097-1777e6bfae95", 176 | "auth_token": "e4968197-6137-47a4-ba79-690d8c552248", 177 | "created_at": time.Now(), 178 | "expires_at": time.Now().Add(24 * time.Hour), 179 | }).Returning(sessionsTable.C("id")) 180 | 181 | var id int64 182 | err = suite.engine.QueryRow(ins).Scan(&id) 183 | assert.Nil(suite.T(), err) 184 | 185 | statement := qb.Insert(actorsTable).Values(map[string]interface{}{ 186 | "id": "b6f8bfe3-a830-441a-a097-1777e6bfae95", 187 | "email": "jack@nicholson.com", 188 | "full_name": "Jack Nicholson", 189 | "bio": sql.NullString{}, 190 | }) 191 | 192 | _, err = suite.engine.Exec(statement) 193 | assert.NotNil(suite.T(), err) 194 | 195 | statement = qb.Insert(actorsTable).Values(map[string]interface{}{ 196 | "id": "cf28d117-a12d-4b75-acd8-73a7d3cbb15f", 197 | "email": "jack@nicholson2.com", 198 | "full_name": "Jack Nicholson", 199 | "bio": sql.NullString{}, 200 | }) 201 | 202 | _, err = suite.engine.Exec(statement) 203 | assert.Nil(suite.T(), err) 204 | 205 | // find user using QueryRow() 206 | sel := qb.Select( 207 | actorsTable.C("id"), 208 | actorsTable.C("email"), 209 | actorsTable.C("full_name"), 210 | actorsTable.C("bio")). 211 | From(actorsTable). 212 | Where(actorsTable.C("id").Eq("cf28d117-a12d-4b75-acd8-73a7d3cbb15f")) 213 | 214 | row := suite.engine.QueryRow(sel) 215 | assert.NotNil(suite.T(), row) 216 | 217 | // find user using Query() 218 | rows, err := suite.engine.Query(sel) 219 | 220 | assert.Nil(suite.T(), err) 221 | rowLength := 0 222 | for rows.Next() { 223 | rowLength++ 224 | } 225 | assert.Equal(suite.T(), 1, rowLength) 226 | 227 | // find user using session api's Find() 228 | var actor Actor 229 | 230 | sel = qb.Select( 231 | actorsTable.C("id"), 232 | actorsTable.C("email"), 233 | actorsTable.C("full_name"), 234 | actorsTable.C("bio")). 235 | From(actorsTable). 
236 | Where(actorsTable.C("id").Eq("b6f8bfe3-a830-441a-a097-1777e6bfae95")) 237 | 238 | err = suite.engine.Get(sel, &actor) 239 | assert.Nil(suite.T(), err) 240 | 241 | assert.Equal(suite.T(), "jack@nicholson.com", actor.Email) 242 | assert.Equal(suite.T(), "Jack Nicholson", actor.FullName) 243 | assert.Equal(suite.T(), "Jack Nicholson, an American actor, producer, screen-writer and director, is a three-time Academy Award winner and twelve-time nominee.", actor.Bio.String) 244 | 245 | // select using join 246 | sessionSlice := []Session{} 247 | 248 | sel = qb.Select( 249 | sessionsTable.C("id"), 250 | sessionsTable.C("actor_id"), 251 | sessionsTable.C("auth_token"), 252 | sessionsTable.C("created_at"), 253 | sessionsTable.C("expires_at")). 254 | From(sessionsTable). 255 | InnerJoin(actorsTable, sessionsTable.C("actor_id"), actorsTable.C("id")). 256 | Where(sessionsTable.C("actor_id").Eq("b6f8bfe3-a830-441a-a097-1777e6bfae95")) 257 | 258 | err = suite.engine.Select(sel, &sessionSlice) 259 | 260 | assert.Nil(suite.T(), err) 261 | assert.Equal(suite.T(), 1, len(sessionSlice)) 262 | 263 | assert.Equal(suite.T(), int64(1), sessionSlice[0].ID) 264 | assert.Equal(suite.T(), "b6f8bfe3-a830-441a-a097-1777e6bfae95", sessionSlice[0].ActorID) 265 | assert.Equal(suite.T(), "e4968197-6137-47a4-ba79-690d8c552248", sessionSlice[0].AuthToken) 266 | 267 | // update user 268 | 269 | upd := qb.Update(actorsTable).Values(map[string]interface{}{ 270 | "bio": sql.NullString{Valid: false}, 271 | }) 272 | 273 | _, err = suite.engine.Exec(upd) 274 | 275 | assert.Nil(suite.T(), err) 276 | 277 | sel = qb.Select( 278 | actorsTable.C("id"), 279 | actorsTable.C("email"), 280 | actorsTable.C("full_name"), 281 | actorsTable.C("bio")). 282 | From(actorsTable). 
283 | Where(actorsTable.C("id").Eq("b6f8bfe3-a830-441a-a097-1777e6bfae95")) 284 | 285 | err = suite.engine.Get(sel, &actor) 286 | assert.Nil(suite.T(), err) 287 | assert.Equal(suite.T(), actor.Bio, sql.NullString{Valid: false}) 288 | 289 | // delete session 290 | del := qb.Delete(sessionsTable).Where( 291 | sessionsTable.C("auth_token").Eq("99e591f8-1025-41ef-a833-6904a0f89a38"), 292 | ) 293 | _, err = suite.engine.Exec(del) 294 | assert.Nil(suite.T(), err) 295 | 296 | // drop tables 297 | assert.Nil(suite.T(), suite.metadata.DropAll(suite.engine)) 298 | } 299 | 300 | func (suite *PostgresTestSuite) TestAutoIncrement() { 301 | dialect := NewDialect() 302 | col := qb.Column("id", qb.BigInt()).AutoIncrement() 303 | assert.Equal(suite.T(), 304 | "BIGSERIAL", 305 | dialect.AutoIncrement(&col)) 306 | 307 | col = qb.Column("id", qb.SmallInt()).AutoIncrement() 308 | assert.Equal(suite.T(), 309 | "SMALLSERIAL", 310 | dialect.AutoIncrement(&col)) 311 | 312 | col = qb.Column("id", qb.Int()).AutoIncrement() 313 | assert.Equal(suite.T(), 314 | "SERIAL", 315 | dialect.AutoIncrement(&col)) 316 | 317 | col = qb.Column("id", qb.Int()).AutoIncrement() 318 | col.Options.InlinePrimaryKey = true 319 | assert.Equal(suite.T(), 320 | "SERIAL PRIMARY KEY", 321 | dialect.AutoIncrement(&col)) 322 | } 323 | 324 | func (suite *PostgresTestSuite) TestUpsert() { 325 | 326 | users := qb.Table( 327 | "users", 328 | qb.Column("id", qb.Varchar().Size(36)), 329 | qb.Column("email", qb.Varchar()).Unique(), 330 | qb.Column("created_at", qb.Timestamp()).NotNull(), 331 | qb.PrimaryKey("id"), 332 | ) 333 | now := time.Now().UTC().String() 334 | ups := qb.Upsert(users).Values(map[string]interface{}{ 335 | "id": "9883cf81-3b56-4151-ae4e-3903c5bc436d", 336 | "email": "al@pacino.com", 337 | "created_at": now, 338 | }) 339 | 340 | sql := ups.Accept(suite.ctx) 341 | binds := suite.ctx.Binds 342 | 343 | assert.Contains(suite.T(), sql, "INSERT INTO users") 344 | assert.Contains(suite.T(), sql, "id", "email") 345 
| assert.Contains(suite.T(), sql, "VALUES($1, $2, $3)") 346 | assert.Contains(suite.T(), sql, "ON CONFLICT", "DO UPDATE SET") 347 | assert.Contains(suite.T(), binds, "9883cf81-3b56-4151-ae4e-3903c5bc436d") 348 | assert.Contains(suite.T(), binds, "al@pacino.com") 349 | assert.Equal(suite.T(), 6, len(binds)) 350 | 351 | ups = qb.Upsert(users). 352 | Values(map[string]interface{}{ 353 | "id": "9883cf81-3b56-4151-ae4e-3903c5bc436d", 354 | "email": "al@pacino.com", 355 | }). 356 | Returning(users.C("id"), users.C("email")) 357 | 358 | ctx := qb.NewCompilerContext(NewDialect()) 359 | sql = ups.Accept(ctx) 360 | binds = ctx.Binds 361 | 362 | assert.Contains(suite.T(), sql, "INSERT INTO users") 363 | assert.Contains(suite.T(), sql, "id", "email") 364 | assert.Contains(suite.T(), sql, "ON CONFLICT") 365 | assert.Contains(suite.T(), sql, "DO UPDATE SET") 366 | assert.Contains(suite.T(), sql, "VALUES($1, $2)") 367 | assert.Contains(suite.T(), sql, "RETURNING id, email") 368 | assert.Contains(suite.T(), binds, "9883cf81-3b56-4151-ae4e-3903c5bc436d") 369 | assert.Contains(suite.T(), binds, "al@pacino.com") 370 | assert.Equal(suite.T(), 4, len(binds)) 371 | } 372 | 373 | func TestPostgresTestSuite(t *testing.T) { 374 | suite.Run(t, new(PostgresTestSuite)) 375 | } 376 | 377 | func init() { 378 | dsn := os.Getenv("QBTEST_POSTGRES") 379 | if dsn != "" { 380 | postgresDsn = dsn 381 | } 382 | } 383 | -------------------------------------------------------------------------------- /dialects/sqlite/sqlite.go: -------------------------------------------------------------------------------- 1 | package sqlite 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/mattn/go-sqlite3" 8 | "github.com/slicebit/qb" 9 | ) 10 | 11 | // Dialect is a type of dialect that can be used with sqlite driver 12 | type Dialect struct { 13 | escaping bool 14 | } 15 | 16 | // NewDialect creates a new sqlite3 dialect 17 | func NewDialect() qb.Dialect { 18 | return &Dialect{false} 19 | } 20 | 21 | func 
init() { 22 | qb.RegisterDialect("sqlite3", NewDialect()) 23 | qb.RegisterDialect("sqlite", NewDialect()) 24 | } 25 | 26 | // CompileType compiles a type into its DDL 27 | func (d *Dialect) CompileType(t qb.TypeElem) string { 28 | if t.Name == "UUID" { 29 | return "VARCHAR(36)" 30 | } 31 | return qb.DefaultCompileType(t, d.SupportsUnsigned()) 32 | } 33 | 34 | // Escape wraps str in double quotes when escaping is enabled, doubling any embedded double quote so the identifier stays valid 35 | func (d *Dialect) Escape(str string) string { 36 | if d.escaping { 37 | return fmt.Sprintf(`"%s"`, strings.Replace(str, `"`, `""`, -1)) 38 | } 39 | return str 40 | } 41 | 42 | // EscapeAll wraps all elements of string array 43 | func (d *Dialect) EscapeAll(names []string) []string { 44 | return qb.EscapeAll(d, names) 45 | } 46 | 47 | // SetEscaping sets the escaping parameter of dialect 48 | func (d *Dialect) SetEscaping(escaping bool) { 49 | d.escaping = escaping 50 | } 51 | 52 | // Escaping gets the escaping parameter of dialect 53 | func (d *Dialect) Escaping() bool { 54 | return d.escaping 55 | } 56 | 57 | // AutoIncrement generates auto increment sql of current dialect 58 | func (d *Dialect) AutoIncrement(column *qb.ColumnElem) string { 59 | if !column.Options.InlinePrimaryKey { 60 | panic("Sqlite does not support non-primarykey autoincrement columns") 61 | } 62 | return "INTEGER PRIMARY KEY" 63 | } 64 | 65 | // SupportsUnsigned returns whether driver supports unsigned type mappings or not 66 | func (d *Dialect) SupportsUnsigned() bool { return false } 67 | 68 | // Driver returns the current driver of dialect 69 | func (d *Dialect) Driver() string { 70 | return "sqlite3" 71 | } 72 | 73 | // GetCompiler returns a SqliteCompiler 74 | func (d *Dialect) GetCompiler() qb.Compiler { 75 | return SqliteCompiler{qb.NewSQLCompiler(d)} 76 | } 77 | 78 | // WrapError wraps a native error in a qb Error 79 | func (d *Dialect) WrapError(err error) qb.Error { 80 | qbErr := qb.Error{Orig: err} 81 | sErr, ok := err.(sqlite3.Error) 82 | if !ok { 83 | return qbErr 84 | } 85 |
switch sErr.Code { 86 | case sqlite3.ErrInternal, 87 | sqlite3.ErrNotFound, 88 | sqlite3.ErrNomem: 89 | qbErr.Code = qb.ErrInternal 90 | case sqlite3.ErrError, 91 | sqlite3.ErrPerm, 92 | sqlite3.ErrAbort, 93 | sqlite3.ErrBusy, 94 | sqlite3.ErrLocked, 95 | sqlite3.ErrReadonly, 96 | sqlite3.ErrInterrupt, 97 | sqlite3.ErrIoErr, 98 | sqlite3.ErrFull, 99 | sqlite3.ErrCantOpen, 100 | sqlite3.ErrProtocol, 101 | sqlite3.ErrEmpty, 102 | sqlite3.ErrSchema: 103 | qbErr.Code = qb.ErrOperational 104 | case sqlite3.ErrCorrupt: 105 | qbErr.Code = qb.ErrDatabase 106 | case sqlite3.ErrTooBig: 107 | qbErr.Code = qb.ErrData 108 | case sqlite3.ErrConstraint, 109 | sqlite3.ErrMismatch: 110 | qbErr.Code = qb.ErrIntegrity 111 | case sqlite3.ErrMisuse: 112 | qbErr.Code = qb.ErrProgramming 113 | default: 114 | qbErr.Code = qb.ErrDatabase 115 | } 116 | return qbErr 117 | } 118 | 119 | // SqliteCompiler is a SQLCompiler specialised for Sqlite 120 | type SqliteCompiler struct { 121 | qb.SQLCompiler 122 | } 123 | 124 | // VisitUpsert generates the following sql: REPLACE INTO ... VALUES ... 
125 | func (SqliteCompiler) VisitUpsert(context *qb.CompilerContext, upsert qb.UpsertStmt) string { 126 | var ( 127 | colNames []string 128 | values []string 129 | ) 130 | for k, v := range upsert.ValuesMap { 131 | colNames = append(colNames, context.Compiler.VisitLabel(context, k)) 132 | context.Binds = append(context.Binds, v) 133 | values = append(values, "?") 134 | } 135 | 136 | sql := fmt.Sprintf( 137 | "REPLACE INTO %s(%s)\nVALUES(%s)", 138 | context.Compiler.VisitLabel(context, upsert.Table.Name), 139 | strings.Join(colNames, ", "), 140 | strings.Join(values, ", "), 141 | ) 142 | 143 | return sql 144 | } 145 | -------------------------------------------------------------------------------- /dialects/sqlite/sqlite_test.go: -------------------------------------------------------------------------------- 1 | package sqlite 2 | 3 | import ( 4 | "database/sql" 5 | "errors" 6 | "testing" 7 | "time" 8 | 9 | "github.com/mattn/go-sqlite3" 10 | "github.com/slicebit/qb" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/suite" 13 | ) 14 | 15 | type SqliteTestSuite struct { 16 | suite.Suite 17 | engine *qb.Engine 18 | metadata *qb.MetaDataElem 19 | } 20 | 21 | func (suite *SqliteTestSuite) SetupTest() { 22 | var err error 23 | 24 | suite.engine, err = qb.New("sqlite3", "./qb_test.db") 25 | 26 | suite.metadata = qb.MetaData() 27 | 28 | assert.Nil(suite.T(), err) 29 | assert.NotNil(suite.T(), suite.engine) 30 | } 31 | 32 | func (suite *SqliteTestSuite) TestUUID() { 33 | assert.Equal(suite.T(), "VARCHAR(36)", suite.engine.Dialect().CompileType(qb.UUID())) 34 | } 35 | 36 | func (suite *SqliteTestSuite) TestDialect() { 37 | dialect := qb.NewDialect("sqlite") 38 | assert.Equal(suite.T(), false, dialect.SupportsUnsigned()) 39 | assert.Equal(suite.T(), "test", dialect.Escape("test")) 40 | assert.Equal(suite.T(), false, dialect.Escaping()) 41 | dialect.SetEscaping(true) 42 | assert.Equal(suite.T(), true, dialect.Escaping()) 43 | assert.Equal(suite.T(), 
"\"test\"", dialect.Escape("test")) 44 | assert.Equal(suite.T(), []string{"\"test\""}, dialect.EscapeAll([]string{"test"})) 45 | assert.Equal(suite.T(), "sqlite3", dialect.Driver()) 46 | } 47 | 48 | func (suite *SqliteTestSuite) TestWrapError() { 49 | dialect := qb.NewDialect("sqlite") 50 | err := errors.New("xxx") 51 | qbErr := dialect.WrapError(err) 52 | assert.Equal(suite.T(), err, qbErr.Orig) 53 | 54 | for _, tt := range []struct { 55 | sCode sqlite3.ErrNo 56 | expCode qb.ErrorCode 57 | }{ 58 | {sqlite3.ErrInternal, qb.ErrInternal}, 59 | {sqlite3.ErrNotFound, qb.ErrInternal}, 60 | {sqlite3.ErrNomem, qb.ErrInternal}, 61 | {sqlite3.ErrError, qb.ErrOperational}, 62 | {sqlite3.ErrIoErr, qb.ErrOperational}, 63 | {sqlite3.ErrCorrupt, qb.ErrDatabase}, 64 | {sqlite3.ErrTooBig, qb.ErrData}, 65 | {sqlite3.ErrConstraint, qb.ErrIntegrity}, 66 | {sqlite3.ErrMismatch, qb.ErrIntegrity}, 67 | {sqlite3.ErrMisuse, qb.ErrProgramming}, 68 | {293012, qb.ErrDatabase}, 69 | } { 70 | sErr := sqlite3.Error{Code: tt.sCode} 71 | qErr := dialect.WrapError(sErr) 72 | assert.Equal(suite.T(), tt.expCode, qErr.Code) 73 | } 74 | } 75 | 76 | func (suite *SqliteTestSuite) TestSqlite() { 77 | type User struct { 78 | ID string `db:"id"` 79 | Email string `db:"email"` 80 | FullName string `db:"full_name"` 81 | Bio sql.NullString `db:"bio"` 82 | Oscars int `db:"oscars"` 83 | } 84 | 85 | type Session struct { 86 | UserID string `db:"user_id"` 87 | AuthToken string `db:"auth_token"` 88 | CreatedAt time.Time `db:"created_at"` 89 | ExpiresAt time.Time `db:"expires_at"` 90 | } 91 | 92 | users := qb.Table( 93 | "users", 94 | qb.Column("id", qb.Varchar().Size(40)), 95 | qb.Column("email", qb.Varchar()).NotNull().Unique(), 96 | qb.Column("full_name", qb.Varchar()).NotNull(), 97 | qb.Column("bio", qb.Text()).Null(), 98 | qb.Column("oscars", qb.Int()).NotNull().Default(0), 99 | qb.PrimaryKey("id"), 100 | ) 101 | 102 | sessions := qb.Table( 103 | "sessions", 104 | qb.Column("user_id", qb.Varchar().Size(40)), 
105 | qb.Column("auth_token", qb.Varchar().Size(40)).NotNull().Unique(), 106 | qb.Column("created_at", qb.Timestamp()).NotNull(), 107 | qb.Column("expires_at", qb.Timestamp()).NotNull(), 108 | qb.ForeignKey("user_id").References("users", "id"), 109 | ) 110 | 111 | var err error 112 | 113 | suite.metadata.AddTable(users) 114 | suite.metadata.AddTable(sessions) 115 | 116 | err = suite.metadata.CreateAll(suite.engine) 117 | assert.Nil(suite.T(), err) 118 | 119 | ins := qb.Insert(users).Values(map[string]interface{}{ 120 | "id": "b6f8bfe3-a830-441a-a097-1777e6bfae95", 121 | "email": "jack@nicholson.com", 122 | "full_name": "Jack Nicholson", 123 | "bio": sql.NullString{String: "Jack Nicholson, an American actor, producer, screen-writer and director, is a three-time Academy Award winner and twelve-time nominee.", Valid: true}, 124 | }) 125 | 126 | _, err = suite.engine.Exec(ins) 127 | assert.Nil(suite.T(), err) 128 | 129 | ins = qb.Insert(sessions).Values(map[string]interface{}{ 130 | "user_id": "b6f8bfe3-a830-441a-a097-1777e6bfae95", 131 | "auth_token": "e4968197-6137-47a4-ba79-690d8c552248", 132 | "created_at": time.Now(), 133 | "expires_at": time.Now().Add(24 * time.Hour), 134 | }) 135 | 136 | _, err = suite.engine.Exec(ins) 137 | assert.Nil(suite.T(), err) 138 | 139 | // find user 140 | var user User 141 | 142 | sel := qb.Select(users.C("id"), users.C("email"), users.C("full_name"), users.C("bio")). 143 | From(users). 
144 | Where(users.C("id").Eq("b6f8bfe3-a830-441a-a097-1777e6bfae95")) 145 | 146 | err = suite.engine.Get(sel, &user) 147 | assert.Nil(suite.T(), err) 148 | 149 | assert.Equal(suite.T(), "jack@nicholson.com", user.Email) 150 | assert.Equal(suite.T(), "Jack Nicholson", user.FullName) 151 | assert.Equal(suite.T(), "Jack Nicholson, an American actor, producer, screen-writer and director, is a three-time Academy Award winner and twelve-time nominee.", user.Bio.String) 152 | 153 | // select using join 154 | sessionSlice := []Session{} 155 | 156 | sel = qb.Select(sessions.C("user_id"), sessions.C("auth_token"), sessions.C("created_at"), sessions.C("expires_at")). 157 | From(sessions). 158 | InnerJoin(users, sessions.C("user_id"), users.C("id")). 159 | Where(users.C("id").Eq("b6f8bfe3-a830-441a-a097-1777e6bfae95")) 160 | 161 | err = suite.engine.Select(sel, &sessionSlice) 162 | assert.Nil(suite.T(), err) 163 | assert.Equal(suite.T(), 1, len(sessionSlice)) 164 | 165 | assert.Equal(suite.T(), "b6f8bfe3-a830-441a-a097-1777e6bfae95", sessionSlice[0].UserID) 166 | assert.Equal(suite.T(), "e4968197-6137-47a4-ba79-690d8c552248", sessionSlice[0].AuthToken) 167 | 168 | upd := qb.Update(users). 169 | Values(map[string]interface{}{ 170 | "bio": sql.NullString{String: "nil", Valid: false}, 171 | }). 172 | Where(users.C("id").Eq("b6f8bfe3-a830-441a-a097-1777e6bfae95")) 173 | 174 | _, err = suite.engine.Exec(upd) 175 | assert.Nil(suite.T(), err) 176 | 177 | sel = qb.Select(users.C("id"), users.C("email"), users.C("full_name"), users.C("bio")). 178 | From(users). 
179 | Where(users.C("id").Eq("b6f8bfe3-a830-441a-a097-1777e6bfae95")) 180 | 181 | err = suite.engine.Get(sel, &user) 182 | assert.Nil(suite.T(), err) 183 | assert.Equal(suite.T(), "b6f8bfe3-a830-441a-a097-1777e6bfae95", user.ID) 184 | assert.Equal(suite.T(), sql.NullString{String: "", Valid: false}, user.Bio) 185 | 186 | del := qb.Delete(sessions).Where(sessions.C("auth_token").Eq("e4968197-6137-47a4-ba79-690d8c552248")) 187 | 188 | // delete session 189 | _, err = suite.engine.Exec(del) 190 | assert.Nil(suite.T(), err) 191 | 192 | // drop tables 193 | assert.Nil(suite.T(), suite.metadata.DropAll(suite.engine)) 194 | } 195 | 196 | func (suite *SqliteTestSuite) TestUpsert() { 197 | users := qb.Table( 198 | "users", 199 | qb.Column("id", qb.Varchar().Size(36)), 200 | qb.Column("email", qb.Varchar()).Unique(), 201 | qb.Column("created_at", qb.Timestamp()).NotNull(), 202 | qb.PrimaryKey("id"), 203 | ) 204 | 205 | now := time.Now().UTC().String() 206 | 207 | ups := qb.Upsert(users).Values(map[string]interface{}{ 208 | "id": "9883cf81-3b56-4151-ae4e-3903c5bc436d", 209 | "email": "al@pacino.com", 210 | "created_at": now, 211 | }) 212 | 213 | ctx := qb.NewCompilerContext(NewDialect()) 214 | sql := ups.Accept(ctx) 215 | binds := ctx.Binds 216 | assert.Contains(suite.T(), sql, `REPLACE INTO users`) 217 | assert.Contains(suite.T(), sql, "id", "email", "created_at") 218 | assert.Contains(suite.T(), sql, "VALUES(?, ?, ?)") 219 | assert.Contains(suite.T(), binds, "9883cf81-3b56-4151-ae4e-3903c5bc436d") 220 | assert.Contains(suite.T(), binds, "al@pacino.com") 221 | assert.Contains(suite.T(), binds, now) 222 | assert.Equal(suite.T(), 3, len(binds)) 223 | } 224 | 225 | func (suite *SqliteTestSuite) TestSqliteAutoIncrement() { 226 | col := qb.Column("test", qb.Int()).AutoIncrement() 227 | assert.Panics(suite.T(), func() { 228 | col.String(suite.engine.Dialect()) 229 | }) 230 | 231 | col.Options.InlinePrimaryKey = true 232 | assert.Equal(suite.T(), "INTEGER PRIMARY KEY",
suite.engine.Dialect().AutoIncrement(&col)) 233 | } 234 | 235 | func TestSqliteTestSuite(t *testing.T) { 236 | suite.Run(t, new(SqliteTestSuite)) 237 | } 238 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.3' 2 | 3 | services: 4 | mysql: 5 | container_name: qb-test-mysql 6 | image: mysql:latest 7 | restart: always 8 | environment: 9 | MYSQL_ALLOW_EMPTY_PASSWORD: "yes" 10 | MYSQL_DATABASE: qb_test 11 | ports: 12 | - "3306:3306" 13 | postgres: 14 | container_name: qb-test-postgres 15 | image: library/postgres 16 | restart: always 17 | environment: 18 | POSTGRES_DB: qb_test 19 | POSTGRES_USER: postgres 20 | POSTGRES_HOST_AUTH_METHOD: trust 21 | ports: 22 | - "5432:5432" -------------------------------------------------------------------------------- /engine.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "database/sql" 5 | "log" 6 | "os" 7 | 8 | "github.com/jmoiron/sqlx" 9 | "github.com/serenize/snaker" 10 | ) 11 | 12 | // New generates a new engine and returns it as an engine pointer 13 | func New(driver string, dsn string) (*Engine, error) { 14 | conn, err := sqlx.Open(driver, dsn) 15 | if err != nil { 16 | return nil, err 17 | } 18 | 19 | // set name mapper function 20 | conn.MapperFunc(func(name string) string { 21 | return snaker.CamelToSnake(name) 22 | }) 23 | 24 | return &Engine{ 25 | dialect: NewDialect(driver), 26 | dsn: dsn, 27 | db: conn, 28 | logger: &DefaultLogger{LDefault, log.New(os.Stdout, "", -1)}, 29 | }, nil 30 | } 31 | 32 | // Engine is the generic struct for handling db connections 33 | type Engine struct { 34 | dsn string 35 | db *sqlx.DB 36 | dialect Dialect 37 | logger Logger 38 | } 39 | 40 | // Dialect returns the engine dialect 41 | func (e *Engine) Dialect() Dialect { 42 | return e.dialect 43 | } 44 | 45 | // SetDialect sets the
current engine dialect 46 | func (e *Engine) SetDialect(dialect Dialect) { 47 | e.dialect = dialect 48 | } 49 | 50 | // TranslateError translates the native errors into qb.Error 51 | func (e *Engine) TranslateError(err error) error { 52 | if err != nil { 53 | return e.dialect.WrapError(err) 54 | } 55 | return nil 56 | } 57 | 58 | // Logger returns the active logger of engine 59 | func (e *Engine) Logger() Logger { 60 | return e.logger 61 | } 62 | 63 | // SetLogger sets the logger of engine 64 | func (e *Engine) SetLogger(logger Logger) { 65 | e.logger = logger 66 | } 67 | 68 | // SetLogFlags sets the log flags on the current logger 69 | func (e *Engine) SetLogFlags(flags LogFlags) { 70 | e.logger.SetLogFlags(flags) 71 | } 72 | 73 | func (e *Engine) log(statement *Stmt) { 74 | logFlags := e.logger.LogFlags() 75 | if logFlags&LQuery != 0 { 76 | e.logger.Println("SQL:", statement.SQL()) 77 | } 78 | if logFlags&LBindings != 0 { 79 | e.logger.Println("Bindings:", statement.Bindings()) 80 | } 81 | } 82 | 83 | // Exec executes insert & update type queries and returns sql.Result and error 84 | func (e *Engine) Exec(builder Builder) (sql.Result, error) { 85 | statement := builder.Build(e.dialect) 86 | e.log(statement) 87 | res, err := e.db.Exec(statement.SQL(), statement.Bindings()...)
88 | return res, e.TranslateError(err) 89 | } 90 | 91 | // Row wraps a *sql.Row in order to translate errors 92 | type Row struct { 93 | *sql.Row 94 | TranslateError func(error) error 95 | } 96 | 97 | // Scan wraps sql.Row.Scan() 98 | func (r Row) Scan(dest ...interface{}) error { 99 | return r.TranslateError(r.Row.Scan(dest...)) 100 | } 101 | 102 | // QueryRow wraps *sql.DB.QueryRow() 103 | func (e *Engine) QueryRow(builder Builder) Row { 104 | statement := builder.Build(e.dialect) 105 | e.log(statement) 106 | return Row{ 107 | e.db.QueryRow(statement.SQL(), statement.Bindings()...), 108 | e.TranslateError, 109 | } 110 | } 111 | 112 | // Query wraps *sql.DB.Query() 113 | func (e *Engine) Query(builder Builder) (*sql.Rows, error) { 114 | statement := builder.Build(e.dialect) 115 | e.log(statement) 116 | rows, err := e.db.Query(statement.SQL(), statement.Bindings()...) 117 | return rows, e.TranslateError(err) 118 | } 119 | 120 | // Get maps the single row to a model 121 | func (e *Engine) Get(builder Builder, model interface{}) error { 122 | statement := builder.Build(e.dialect) 123 | e.log(statement) 124 | return e.TranslateError( 125 | e.db.Get(model, statement.SQL(), statement.Bindings()...)) 126 | } 127 | 128 | // Select maps multiple rows to a model array 129 | func (e *Engine) Select(builder Builder, model interface{}) error { 130 | statement := builder.Build(e.dialect) 131 | e.log(statement) 132 | return e.TranslateError( 133 | e.db.Select(model, statement.SQL(), statement.Bindings()...)) 134 | } 135 | 136 | // DB returns sql.DB of wrapped engine connection 137 | func (e *Engine) DB() *sqlx.DB { 138 | return e.db 139 | } 140 | 141 | // Ping pings the db using connection and returns error if connectivity is not present 142 | func (e *Engine) Ping() error { 143 | return e.db.Ping() 144 | } 145 | 146 | // Close closes the sqlx db connection 147 | func (e *Engine) Close() error { 148 | return e.db.Close() 149 | } 150 | 151 | // Driver returns the driver as string 
152 | func (e *Engine) Driver() string { 153 | return e.dialect.Driver() 154 | } 155 | 156 | // Dsn returns the connection dsn 157 | func (e *Engine) Dsn() string { 158 | return e.dsn 159 | } 160 | 161 | // Begin begins a transaction and return a *qb.Tx 162 | func (e *Engine) Begin() (*Tx, error) { 163 | tx, err := e.db.Beginx() 164 | if err != nil { 165 | return nil, e.dialect.WrapError(err) 166 | } 167 | return &Tx{e, tx}, nil 168 | } 169 | 170 | // Tx is an in-progress database transaction 171 | type Tx struct { 172 | engine *Engine 173 | tx *sqlx.Tx 174 | } 175 | 176 | // Tx returns the underlying *sqlx.Tx 177 | func (tx *Tx) Tx() *sqlx.Tx { 178 | return tx.tx 179 | } 180 | 181 | // Commit commits the transaction 182 | func (tx *Tx) Commit() error { 183 | return tx.tx.Commit() 184 | } 185 | 186 | // Rollback aborts the transaction 187 | func (tx *Tx) Rollback() error { 188 | return tx.tx.Rollback() 189 | } 190 | 191 | // Exec executes insert & update type queries and returns sql.Result and error 192 | func (tx *Tx) Exec(builder Builder) (sql.Result, error) { 193 | statement := builder.Build(tx.engine.dialect) 194 | tx.engine.log(statement) 195 | res, err := tx.tx.Exec(statement.SQL(), statement.Bindings()...) 196 | return res, tx.engine.TranslateError(err) 197 | } 198 | 199 | // QueryRow wraps *sql.DB.QueryRow() 200 | func (tx *Tx) QueryRow(builder Builder) Row { 201 | statement := builder.Build(tx.engine.dialect) 202 | tx.engine.log(statement) 203 | return Row{ 204 | tx.tx.QueryRow(statement.SQL(), statement.Bindings()...), 205 | tx.engine.TranslateError, 206 | } 207 | } 208 | 209 | // Query wraps *sql.DB.Query() 210 | func (tx *Tx) Query(builder Builder) (*sql.Rows, error) { 211 | statement := builder.Build(tx.engine.dialect) 212 | tx.engine.log(statement) 213 | rows, err := tx.tx.Query(statement.SQL(), statement.Bindings()...) 
214 | return rows, tx.engine.TranslateError(err) 215 | } 216 | 217 | // Get maps the single row to a model 218 | func (tx *Tx) Get(builder Builder, model interface{}) error { 219 | statement := builder.Build(tx.engine.dialect) 220 | tx.engine.log(statement) 221 | return tx.engine.TranslateError( 222 | tx.tx.Get(model, statement.SQL(), statement.Bindings()...)) 223 | } 224 | 225 | // Select maps multiple rows to a model array 226 | func (tx *Tx) Select(builder Builder, model interface{}) error { 227 | statement := builder.Build(tx.engine.dialect) 228 | tx.engine.log(statement) 229 | return tx.engine.TranslateError( 230 | tx.tx.Select(model, statement.SQL(), statement.Bindings()...)) 231 | } 232 | -------------------------------------------------------------------------------- /engine_test.go: -------------------------------------------------------------------------------- 1 | package qb_test 2 | 3 | import ( 4 | "testing" 5 | 6 | _ "github.com/mattn/go-sqlite3" 7 | "github.com/slicebit/qb" 8 | _ "github.com/slicebit/qb/dialects/sqlite" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestEngine(t *testing.T) { 13 | engine, err := qb.New("sqlite3", ":memory:") 14 | 15 | assert.Equal(t, nil, err) 16 | assert.Equal(t, "sqlite3", engine.Driver()) 17 | assert.Equal(t, engine.DB().Ping(), engine.Ping()) 18 | assert.Equal(t, ":memory:", engine.Dsn()) 19 | } 20 | 21 | func TestInvalidEngine(t *testing.T) { 22 | engine, err := qb.New("invalid", "") 23 | assert.NotEqual(t, nil, err) 24 | assert.Equal(t, (*qb.Engine)(nil), engine) 25 | } 26 | 27 | func TestEngineExec(t *testing.T) { 28 | engine, err := qb.New("sqlite3", ":memory:") 29 | dialect := qb.NewDialect("sqlite") 30 | dialect.SetEscaping(true) 31 | engine.SetDialect(dialect) 32 | 33 | usersTable := qb.Table( 34 | "users", 35 | qb.Column("full_name", qb.Varchar()).NotNull(), 36 | ) 37 | 38 | ins := qb.Insert(usersTable). 
39 | Values(map[string]interface{}{ 40 | "full_name": "Al Pacino", 41 | }) 42 | 43 | assert.Nil(t, err) 44 | 45 | res, err := engine.Exec(ins) 46 | assert.Equal(t, nil, res) 47 | assert.NotNil(t, err) 48 | } 49 | 50 | func TestEngineFail(t *testing.T) { 51 | engine, err := qb.New("sqlite3", ":memory:") 52 | assert.Nil(t, err) 53 | defer engine.Close() 54 | engine.SetDialect(qb.NewDialect("sqlite3")) 55 | 56 | usersTable := qb.Table( 57 | "users", 58 | qb.Column("full_name", qb.Varchar()).NotNull(), 59 | ) 60 | 61 | statement := qb.Insert(usersTable). 62 | Values(map[string]interface{}{ 63 | "full_name": "Robert De Niro", 64 | }) 65 | 66 | _, err = engine.Exec(statement) 67 | assert.NotNil(t, err) 68 | } 69 | 70 | func TestTx(t *testing.T) { 71 | engine, err := qb.New("sqlite3", ":memory:") 72 | assert.Nil(t, err) 73 | defer engine.Close() 74 | 75 | engine.SetDialect(qb.NewDialect("sqlite3")) 76 | 77 | usersTable := qb.Table( 78 | "users", 79 | qb.Column("full_name", qb.Varchar()).NotNull(), 80 | ) 81 | 82 | _, err = engine.DB().Exec(usersTable.Create(engine.Dialect())) 83 | assert.Nil(t, err) 84 | 85 | countStmt := qb.Select(qb.Count(usersTable.C("full_name"))).From(usersTable) 86 | 87 | tx, err := engine.Begin() 88 | assert.Equal(t, nil, err) 89 | 90 | assert.NotNil(t, tx.Tx()) 91 | 92 | _, err = tx.Exec( 93 | usersTable.Insert(). 94 | Values(map[string]interface{}{ 95 | "full_name": "Robert De Niro", 96 | }), 97 | ) 98 | assert.Equal(t, nil, err) 99 | var count int 100 | row := tx.QueryRow(countStmt) 101 | assert.Nil(t, row.Scan(&count)) 102 | assert.Equal(t, 1, count) 103 | 104 | assert.Nil(t, tx.Commit()) 105 | 106 | row = engine.QueryRow(countStmt) 107 | assert.Equal(t, nil, row.Scan(&count)) 108 | assert.Equal(t, 1, count) 109 | 110 | tx, err = engine.Begin() 111 | assert.Equal(t, nil, err) 112 | 113 | _, err = tx.Exec(usersTable.Insert().
114 | Values(map[string]interface{}{ 115 | "full_name": "Al Pacino", 116 | }), 117 | ) 118 | assert.Equal(t, nil, err) 119 | 120 | rows, err := tx.Query(countStmt) 121 | assert.Equal(t, nil, err) 122 | assert.True(t, rows.Next()) 123 | assert.Equal(t, nil, rows.Scan(&count)) 124 | assert.Equal(t, 2, count) 125 | assert.Nil(t, rows.Close()) 126 | assert.Nil(t, tx.Rollback()) 127 | 128 | row = engine.QueryRow(countStmt) 129 | assert.Equal(t, nil, row.Scan(&count)) 130 | assert.Equal(t, 1, count) 131 | 132 | tx, err = engine.Begin() 133 | 134 | assert.Nil(t, err) 135 | 136 | var user struct{ FullName string } 137 | var users []struct{ FullName string } 138 | 139 | assert.Nil(t, 140 | tx.Get(usersTable.Select(usersTable.C("full_name")), &user), 141 | ) 142 | assert.Equal(t, "Robert De Niro", user.FullName) 143 | 144 | assert.Nil(t, 145 | tx.Select(usersTable.Select(usersTable.C("full_name")), &users), 146 | ) 147 | assert.Equal(t, "Robert De Niro", users[0].FullName) 148 | assert.Nil(t, tx.Rollback()) 149 | } 150 | 151 | func TestTxBeginError(t *testing.T) { 152 | engine, err := qb.New("sqlite3", "file:///dev/null?_txlock=exclusive") 153 | assert.Nil(t, err) 154 | _, err = engine.Begin() 155 | assert.NotNil(t, err) 156 | } 157 | 158 | func TestEngineQuery(t *testing.T) { 159 | engine, err := qb.New("sqlite3", ":memory:") 160 | assert.Nil(t, err) 161 | rows, err := engine.Query(qb.Select(qb.SQLText("1"))) 162 | assert.Nil(t, err) 163 | assert.True(t, rows.Next()) 164 | var value int 165 | assert.Nil(t, rows.Scan(&value)) 166 | assert.Equal(t, 1, value) 167 | assert.False(t, rows.Next()) 168 | } 169 | 170 | func TestEngineGet(t *testing.T) { 171 | var s struct { 172 | Value int `db:"value"` 173 | } 174 | engine, err := qb.New("sqlite3", ":memory:") 175 | assert.Nil(t, err) 176 | assert.Nil(t, engine.Get(qb.Select(qb.SQLText("1 AS value")), &s)) 177 | assert.Equal(t, 1, s.Value) 178 | } 179 | 180 | func TestEngineSelect(t *testing.T) { 181 | var s []struct { 182 | Value int `db:"value"` 183 | } 184 | engine, err := qb.New("sqlite3", ":memory:")
185 | assert.Nil(t, err) 186 | assert.Nil(t, engine.Select(qb.Select(qb.SQLText("1 AS value")), &s)) 187 | assert.Equal(t, 1, len(s)) 188 | assert.Equal(t, 1, s[0].Value) 189 | } 190 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | // ErrorCode discriminates the types of errors that qb wraps, mainly the 4 | // constraint errors 5 | // The different kind of errors are based on the python dbapi errors 6 | // (https://www.python.org/dev/peps/pep-0249/#exceptions) 7 | type ErrorCode int 8 | 9 | // Bit 8 and 9 are flags to separate interface errors from database errors 10 | const ( 11 | // ErrAny is for errors that could not be categorized by the dialect 12 | ErrAny ErrorCode = 0 13 | // ErrInterface is a bit mask for errors that are related to the database 14 | // interface rather than the database itself. 15 | ErrInterface ErrorCode = 1 << 8 16 | // ErrDatabase is a bit mask for errors that are related to the database. 17 | ErrDatabase ErrorCode = 1 << 9 18 | ) 19 | 20 | // Database error codes are in bits 5 to 7, leaving bits 0 to 4 for detailed 21 | // codes in a later version 22 | const ( 23 | // ErrData is for errors that are due to problems with the processed data 24 | // like division by zero, numeric value out of range, etc. 25 | ErrData ErrorCode = ErrDatabase | ((iota + 1) << 5) 26 | // ErrOperational is for errors that are related to the database's 27 | // operation and not necessarily under the control of the programmer, e.g. 28 | // an unexpected disconnect occurs, the data source name is not found, a 29 | // transaction could not be processed, a memory allocation error occurred 30 | // during processing, etc. 31 | ErrOperational 32 | // ErrIntegrity is when the relational integrity of the database is 33 | // affected, e.g.
a foreign key check fails 34 | ErrIntegrity 35 | // ErrInternal is when the database encounters an internal error, e.g. the 36 | // cursor is not valid anymore, the transaction is out of sync, etc. 37 | ErrInternal 38 | // ErrProgramming is for programming errors, e.g. table not found or 39 | // already exists, syntax error in the SQL statement, wrong number of 40 | // parameters specified, etc. 41 | ErrProgramming 42 | // ErrNotSupported is in case a method or database API was used which 43 | // is not supported by the database, e.g. requesting a .rollback() on a 44 | // connection that does not support transaction or has transactions turned 45 | // off. 46 | ErrNotSupported 47 | ) 48 | 49 | // IsInterfaceError returns true if the error is a Interface error 50 | func (err ErrorCode) IsInterfaceError() bool { 51 | return err&ErrInterface != 0 52 | } 53 | 54 | // IsDatabaseError returns true if the error is a Database error 55 | func (err ErrorCode) IsDatabaseError() bool { 56 | return err&ErrDatabase != 0 57 | } 58 | 59 | // Error wraps driver errors. 
It helps handle constraint errors in 60 | // a generic way, while still giving access to the original error 61 | type Error struct { 62 | Code ErrorCode 63 | Orig error // The native error from the driver 64 | Table string 65 | Column string 66 | Constraint string 67 | } 68 | 69 | // Error implements the error interface. It returns the message of Orig 70 | // prefixed with a description of Code; unknown codes get no prefix. A nil 71 | // Orig is rendered as "<nil>" instead of panicking. 72 | func (err Error) Error() string { 73 | msg := "<nil>" 74 | if err.Orig != nil { 75 | msg = err.Orig.Error() 76 | } 77 | switch err.Code { 78 | case ErrAny: 79 | return "Uncategorized error: " + msg 80 | case ErrInterface: 81 | return "Interface error: " + msg 82 | case ErrDatabase: 83 | return "Database error: " + msg 84 | case ErrData: 85 | return "Database data error: " + msg 86 | case ErrOperational: 87 | return "Database operational error: " + msg 88 | case ErrIntegrity: 89 | return "Database integrity error: " + msg 90 | case ErrInternal: 91 | return "Database internal error: " + msg 92 | case ErrProgramming: 93 | return "Database programming error: " + msg 94 | case ErrNotSupported: 95 | return "Database not supported error: " + msg 96 | default: 97 | return msg 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /errors_test.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestErrorMessages(t *testing.T) { 11 | var tests = []struct { 12 | code ErrorCode 13 | expected string 14 | }{ 15 | {ErrAny, "Uncategorized error: xxx"}, 16 | {ErrInterface, "Interface error: xxx"}, 17 | {ErrDatabase, "Database error: xxx"}, 18 | {ErrData, "Database data error: xxx"}, 19 | {ErrOperational, "Database operational error: xxx"}, 20 | {ErrIntegrity, "Database integrity error: xxx"}, 21 | {ErrInternal, "Database internal error: xxx"}, 22 | {ErrProgramming, "Database programming error: xxx"}, 23 | {54, "xxx"}, 24 | } 25 | for _, tt := range tests { 26 | assert.Equal(t, tt.expected, Error{Code: tt.code, Orig: errors.New("xxx")}.Error()) 27
| } 28 | } 29 | 30 | func TestErrorCode(t *testing.T) { 31 | assert.True(t, ErrInterface.IsInterfaceError()) 32 | assert.False(t, ErrInterface.IsDatabaseError()) 33 | 34 | assert.True(t, ErrDatabase.IsDatabaseError()) 35 | assert.False(t, ErrDatabase.IsInterfaceError()) 36 | 37 | assert.True(t, ErrData.IsDatabaseError()) 38 | assert.False(t, ErrData.IsInterfaceError()) 39 | 40 | assert.True(t, ErrOperational.IsDatabaseError()) 41 | assert.False(t, ErrOperational.IsInterfaceError()) 42 | 43 | assert.True(t, ErrIntegrity.IsDatabaseError()) 44 | assert.False(t, ErrIntegrity.IsInterfaceError()) 45 | 46 | assert.True(t, ErrInternal.IsDatabaseError()) 47 | assert.False(t, ErrInternal.IsInterfaceError()) 48 | 49 | assert.True(t, ErrProgramming.IsDatabaseError()) 50 | assert.False(t, ErrProgramming.IsInterfaceError()) 51 | } 52 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/slicebit/qb 2 | 3 | require ( 4 | github.com/davecgh/go-spew v1.1.1 // indirect 5 | github.com/go-sql-driver/mysql v1.4.1 6 | github.com/jmoiron/sqlx v1.2.0 7 | github.com/lib/pq v1.0.0 8 | github.com/mattn/go-sqlite3 v1.10.0 9 | github.com/pmezard/go-difflib v1.0.0 // indirect 10 | github.com/serenize/snaker v0.0.0-20171204205717-a683aaf2d516 11 | github.com/stretchr/testify v1.2.2 12 | ) 13 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= 4 | github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= 5 | github.com/go-sql-driver/mysql 
v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= 6 | github.com/jmoiron/sqlx v1.2.0 h1:41Ip0zITnmWNR/vHV+S4m+VoUivnWY5E4OJfLZjCJMA= 7 | github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks= 8 | github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A= 9 | github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= 10 | github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= 11 | github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= 12 | github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= 13 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 14 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 15 | github.com/serenize/snaker v0.0.0-20171204205717-a683aaf2d516 h1:ofR1ZdrNSkiWcMsRrubK9tb2/SlZVWttAfqUjJi6QYc= 16 | github.com/serenize/snaker v0.0.0-20171204205717-a683aaf2d516/go.mod h1:Yow6lPLSAXx2ifx470yD/nUe22Dv5vBvxK/UK9UUTVs= 17 | github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= 18 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 19 | -------------------------------------------------------------------------------- /index.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | // CompositeIndex is the struct definition when building composite indices for any struct that will be mapped into a table 9 | type CompositeIndex string 10 | 11 | // Index generates an index clause given table and columns as params 12 | func Index(table string, cols ...string) IndexElem { 13 | return IndexElem{ 14 | Table: table, 15 | Name: fmt.Sprintf("i_%s", strings.Join(cols, "_")), 16 | Columns: cols, 17 | } 18 | } 19 | 20 | // IndexElem is the definition of 
any index elements for a table 21 | type IndexElem struct { 22 | Table string 23 | Name string 24 | Columns []string 25 | } 26 | 27 | // String returns the index element as an sql clause 28 | func (i IndexElem) String(dialect Dialect) string { 29 | return fmt.Sprintf("CREATE INDEX %s ON %s(%s);", dialect.Escape(i.Name), dialect.Escape(i.Table), strings.Join(dialect.EscapeAll(i.Columns), ", ")) 30 | } 31 | -------------------------------------------------------------------------------- /insert.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | // Insert generates an insert statement and returns it 4 | // Insert(usersTable).Values(map[string]interface{}{"id": 1}) 5 | func Insert(table TableElem) InsertStmt { 6 | return InsertStmt{ 7 | table: table, 8 | values: map[string]interface{}{}, 9 | returning: []ColumnElem{}, 10 | } 11 | } 12 | 13 | // InsertStmt is the base struct for any insert statements 14 | type InsertStmt struct { 15 | table TableElem 16 | values map[string]interface{} 17 | returning []ColumnElem 18 | } 19 | 20 | // Values accepts map[string]interface{} and forms the values map of insert statement 21 | func (s InsertStmt) Values(values map[string]interface{}) InsertStmt { 22 | for k, v := range values { 23 | s.values[k] = v 24 | } 25 | return s 26 | } 27 | 28 | // Returning accepts the column names as strings and forms the returning array of insert statement 29 | // NOTE: Please use it in only postgres dialect, otherwise it'll crash 30 | func (s InsertStmt) Returning(cols ...ColumnElem) InsertStmt { 31 | for _, c := range cols { 32 | s.returning = append(s.returning, c) 33 | } 34 | return s 35 | } 36 | 37 | // Accept implements Clause.Accept 38 | func (s InsertStmt) Accept(context *CompilerContext) string { 39 | return context.Compiler.VisitInsert(context, s) 40 | } 41 | 42 | // Build generates a statement out of InsertStmt object 43 | func (s InsertStmt) Build(dialect Dialect) *Stmt { 44 | 
statement := Statement() 45 | context := NewCompilerContext(dialect) 46 | statement.AddSQLClause(s.Accept(context)) 47 | statement.AddBinding(context.Binds...) 48 | 49 | return statement 50 | } 51 | -------------------------------------------------------------------------------- /insert_test.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestInsert(t *testing.T) { 10 | users := Table( 11 | "users", 12 | Column("id", Varchar().Size(36)), 13 | Column("email", Varchar()).Unique(), 14 | ) 15 | 16 | ins := Insert(users).Values(map[string]interface{}{ 17 | "id": "9883cf81-3b56-4151-ae4e-3903c5bc436d", 18 | "email": "al@pacino.com", 19 | }) 20 | 21 | dialect := NewDefaultDialect() 22 | ctx := NewCompilerContext(dialect) 23 | 24 | sql := ins.Accept(ctx) 25 | binds := ctx.Binds 26 | 27 | assert.Contains(t, sql, "INSERT INTO users") 28 | assert.Contains(t, sql, "id", "email") 29 | assert.Contains(t, sql, "VALUES(?, ?)") 30 | assert.Contains(t, binds, "9883cf81-3b56-4151-ae4e-3903c5bc436d") 31 | assert.Contains(t, binds, "al@pacino.com") 32 | 33 | sql = Insert(users). 34 | Values(map[string]interface{}{ 35 | "id": "9883cf81-3b56-4151-ae4e-3903c5bc436d", 36 | "email": "al@pacino.com", 37 | }). 
38 | Returning(users.C("id"), users.C("email")).Accept(ctx) 39 | binds = ctx.Binds 40 | 41 | assert.Contains(t, sql, "INSERT INTO users") 42 | assert.Contains(t, sql, "id", "email") 43 | assert.Contains(t, sql, "VALUES(?, ?)") 44 | assert.Contains(t, sql, "RETURNING id, email") 45 | assert.Contains(t, binds, "9883cf81-3b56-4151-ae4e-3903c5bc436d", "al@pacino.com") 46 | } 47 | -------------------------------------------------------------------------------- /logger.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import "log" 4 | 5 | // These are the log flags qb can use 6 | const ( 7 | // LDefault is the default flag that logs nothing 8 | LDefault LogFlags = 0 9 | // LQuery Flag to log queries 10 | LQuery LogFlags = 1 << iota 11 | // LBindings Flag to log bindings 12 | LBindings 13 | ) 14 | 15 | // LogFlags is the type we use for flags that can be passed 16 | // to the logger 17 | type LogFlags uint 18 | 19 | // Logger is the std logger interface of the qb engine 20 | type Logger interface { 21 | Print(...interface{}) 22 | Printf(string, ...interface{}) 23 | Println(...interface{}) 24 | 25 | Fatal(...interface{}) 26 | Fatalf(string, ...interface{}) 27 | Fatalln(...interface{}) 28 | 29 | Panic(...interface{}) 30 | Panicf(string, ...interface{}) 31 | Panicln(...interface{}) 32 | 33 | LogFlags() LogFlags 34 | SetLogFlags(LogFlags) 35 | } 36 | 37 | // DefaultLogger is the default logger of qb engine unless engine.SetLogger() is not called 38 | type DefaultLogger struct { 39 | logFlags LogFlags 40 | *log.Logger 41 | } 42 | 43 | // SetLogFlags sets the logflags 44 | // It is for changing engine log flags 45 | func (l *DefaultLogger) SetLogFlags(logFlags LogFlags) { 46 | l.logFlags = logFlags 47 | } 48 | 49 | // LogFlags gets the logflags as an int 50 | func (l *DefaultLogger) LogFlags() LogFlags { 51 | return l.logFlags 52 | } 53 | -------------------------------------------------------------------------------- 
/logger_test.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "log" 5 | "os" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | // func TestLogger(t *testing.T) { 12 | // // engine, err := New("default", ":memory:") 13 | // // metadata := MetaData() 14 | // // actors := Table("actors", 15 | // // Column("id", BigInt()).NotNull(), 16 | // // PrimaryKey("id"), 17 | // // ) 18 | // // metadata.AddTable(actors) 19 | // // metadata.CreateAll(engine) 20 | // defer metadata.DropAll(engine) 21 | // logCapture := &TestingLogWriter{t, nil} 22 | // defer logCapture.Flush() 23 | // engine.SetLogger(&DefaultLogger{LQuery | LBindings, log.New(logCapture, "", log.LstdFlags)}) 24 | // engine.Logger().SetLogFlags(LQuery) 25 | 26 | // _, err = engine.Exec(actors.Insert().Values(map[string]interface{}{"id": 5})) 27 | // assert.Nil(t, err) 28 | 29 | // engine.Logger().SetLogFlags(LQuery | LBindings) 30 | // _, err = engine.Exec(actors.Insert().Values(map[string]interface{}{"id": 10})) 31 | // assert.Nil(t, err) 32 | 33 | // assert.Equal(t, engine.Logger().LogFlags(), LQuery|LBindings) 34 | // } 35 | 36 | func TestLoggerFlags(t *testing.T) { 37 | logger := DefaultLogger{LDefault, log.New(os.Stdout, "", -1)} 38 | 39 | logger.SetLogFlags(LBindings) 40 | 41 | assert.Equal(t, logger.LogFlags(), LBindings) 42 | // engine, err := New("sqlite3", ":memory:") 43 | // assert.Equal(t, nil, err) 44 | 45 | // before setting flags, this is on the default 46 | // assert.Equal(t, engine.Logger().LogFlags(), LDefault) 47 | 48 | // engine.SetLogFlags(LBindings) 49 | // after setting flags, we have the expected value 50 | // assert.Equal(t, engine.Logger().LogFlags(), LBindings) 51 | } 52 | -------------------------------------------------------------------------------- /metadata.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "errors" 5 
| "fmt" 6 | ) 7 | 8 | // MetaData creates a new MetaData object and returns it as a pointer 9 | func MetaData() *MetaDataElem { 10 | return &MetaDataElem{ 11 | tables: []TableElem{}, 12 | } 13 | } 14 | 15 | // MetaDataElem is the container for database structs and tables 16 | type MetaDataElem struct { 17 | tables []TableElem 18 | } 19 | 20 | // AddTable appends table to tables slice 21 | func (m *MetaDataElem) AddTable(table TableElem) { 22 | m.tables = append(m.tables, table) 23 | } 24 | 25 | // Table returns the metadata registered table object. It returns nil if table is not found 26 | func (m *MetaDataElem) Table(name string) TableElem { 27 | for _, t := range m.tables { 28 | if t.Name == name { 29 | return t 30 | } 31 | } 32 | 33 | panic(fmt.Errorf("Table %s not found", name)) 34 | } 35 | 36 | // Tables returns the current tables slice 37 | func (m *MetaDataElem) Tables() []TableElem { 38 | return m.tables 39 | } 40 | 41 | // CreateAll creates all the tables added to metadata 42 | func (m *MetaDataElem) CreateAll(engine *Engine) error { 43 | tx, err := engine.DB().Begin() 44 | if err != nil { 45 | return err 46 | } 47 | 48 | for _, t := range m.tables { 49 | _, err = tx.Exec(t.Create(engine.Dialect())) 50 | if err != nil { 51 | return err 52 | } 53 | } 54 | 55 | err = tx.Commit() 56 | 57 | if len(m.tables) == 0 { 58 | return errors.New("Metadata is empty. 
You need to register tables by calling db.AddTable(model{})") 59 | } 60 | 61 | return err 62 | } 63 | 64 | // DropAll drops all the tables which is added to metadata 65 | func (m *MetaDataElem) DropAll(engine *Engine) error { 66 | tx, err := engine.DB().Begin() 67 | if err != nil { 68 | return err 69 | } 70 | 71 | for i := len(m.tables) - 1; i >= 0; i-- { 72 | drop := m.tables[i].Drop(engine.Dialect()) 73 | _, err = tx.Exec(drop) 74 | if err != nil { 75 | return err 76 | } 77 | } 78 | 79 | err = tx.Commit() 80 | 81 | if len(m.tables) == 0 { 82 | return errors.New("Metadata is empty") 83 | } 84 | return err 85 | } 86 | -------------------------------------------------------------------------------- /qb_logo_128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slicebit/qb/6bc2ae13ece37358c8795bcf0acd352c735ebf4b/qb_logo_128.png -------------------------------------------------------------------------------- /select.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import "fmt" 4 | 5 | // Selectable is any clause from which we can select columns and is suitable 6 | // as a FROM clause element 7 | type Selectable interface { 8 | Clause 9 | All() []Clause 10 | ColumnList() []ColumnElem 11 | C(column string) ColumnElem 12 | DefaultName() string 13 | } 14 | 15 | // Select generates a select statement and returns it 16 | func Select(clauses ...Clause) SelectStmt { 17 | return SelectStmt{ 18 | SelectList: clauses, 19 | GroupByClause: []ColumnElem{}, 20 | HavingClause: []HavingClause{}, 21 | } 22 | } 23 | 24 | // SelectStmt is the base struct for building select statements 25 | type SelectStmt struct { 26 | SelectList []Clause 27 | FromClause Selectable 28 | GroupByClause []ColumnElem 29 | OrderByClause *OrderByClause 30 | HavingClause []HavingClause 31 | WhereClause *WhereClause 32 | ForUpdateClause *ForUpdateClause 33 | OffsetValue *int 34 | 
LimitValue *int 35 | } 36 | 37 | // Select sets the selected columns 38 | func (s SelectStmt) Select(clauses ...Clause) SelectStmt { 39 | s.SelectList = clauses 40 | return s 41 | } 42 | 43 | // From sets the from selectable of select statement 44 | func (s SelectStmt) From(selectable Selectable) SelectStmt { 45 | s.FromClause = selectable 46 | return s 47 | } 48 | 49 | // Where sets the where clause of select statement 50 | func (s SelectStmt) Where(clauses ...Clause) SelectStmt { 51 | where := Where(clauses...) 52 | s.WhereClause = &where 53 | return s 54 | } 55 | 56 | // InnerJoin appends an inner join clause to the select statement 57 | func (s SelectStmt) InnerJoin(right Selectable, onClause ...Clause) SelectStmt { 58 | return s.From(Join("INNER JOIN", s.FromClause, right, onClause...)) 59 | } 60 | 61 | // CrossJoin appends an cross join clause to the select statement 62 | func (s SelectStmt) CrossJoin(right Selectable) SelectStmt { 63 | return s.From(Join("CROSS JOIN", s.FromClause, right, nil)) 64 | } 65 | 66 | // LeftJoin appends an left outer join clause to the select statement 67 | func (s SelectStmt) LeftJoin(right Selectable, onClause ...Clause) SelectStmt { 68 | return s.From(Join("LEFT OUTER JOIN", s.FromClause, right, onClause...)) 69 | } 70 | 71 | // RightJoin appends a right outer join clause to select statement 72 | func (s SelectStmt) RightJoin(right Selectable, onClause ...Clause) SelectStmt { 73 | return s.From(Join("RIGHT OUTER JOIN", s.FromClause, right, onClause...)) 74 | } 75 | 76 | // OrderBy generates an OrderByClause and sets select statement's orderbyclause 77 | // OrderBy(usersTable.C("id")).Asc() 78 | // OrderBy(usersTable.C("email")).Desc() 79 | func (s SelectStmt) OrderBy(columns ...ColumnElem) SelectStmt { 80 | s.OrderByClause = &OrderByClause{columns, "ASC"} 81 | return s 82 | } 83 | 84 | // Asc sets the t type of current order by clause 85 | // NOTE: Please use it after calling OrderBy() 86 | func (s SelectStmt) Asc() SelectStmt 
{ 87 | s.OrderByClause.t = "ASC" 88 | return s 89 | } 90 | 91 | // Desc sets the t type of current order by clause 92 | // NOTE: Please use it after calling OrderBy() 93 | func (s SelectStmt) Desc() SelectStmt { 94 | s.OrderByClause.t = "DESC" 95 | return s 96 | } 97 | 98 | // GroupBy appends columns to group by clause of the select statement 99 | func (s SelectStmt) GroupBy(cols ...ColumnElem) SelectStmt { 100 | s.GroupByClause = append(s.GroupByClause, cols...) 101 | return s 102 | } 103 | 104 | // Having appends a having clause to select statement 105 | func (s SelectStmt) Having(aggregate AggregateClause, op string, value interface{}) SelectStmt { 106 | s.HavingClause = append(s.HavingClause, HavingClause{aggregate, op, value}) 107 | return s 108 | } 109 | 110 | // Limit sets the limit number of rows 111 | func (s SelectStmt) Limit(limit int) SelectStmt { 112 | s.LimitValue = &limit 113 | return s 114 | } 115 | 116 | // Offset sets the offset 117 | func (s SelectStmt) Offset(value int) SelectStmt { 118 | s.OffsetValue = &value 119 | return s 120 | } 121 | 122 | // LimitOffset sets the limit & offset values of the select statement 123 | func (s SelectStmt) LimitOffset(limit, offset int) SelectStmt { 124 | s.LimitValue = &limit 125 | s.OffsetValue = &offset 126 | return s 127 | } 128 | 129 | // ForUpdate adds a "FOR UPDATE" clause 130 | func (s SelectStmt) ForUpdate(tables ...TableElem) SelectStmt { 131 | s.ForUpdateClause = &ForUpdateClause{tables} 132 | return s 133 | } 134 | 135 | // Accept calls the compiler VisitSelect method 136 | func (s SelectStmt) Accept(context *CompilerContext) string { 137 | return context.Compiler.VisitSelect(context, s) 138 | } 139 | 140 | // Build compiles the select statement and returns the Stmt 141 | func (s SelectStmt) Build(dialect Dialect) *Stmt { 142 | context := NewCompilerContext(dialect) 143 | statement := Statement() 144 | statement.AddSQLClause(s.Accept(context)) 145 | statement.AddBinding(context.Binds...) 
146 | 147 | return statement 148 | } 149 | 150 | type joinOnClauseCandidate struct { 151 | source TableElem 152 | fkey ForeignKeyConstraint 153 | target TableElem 154 | } 155 | 156 | func getTable(sel Selectable) (TableElem, bool) { 157 | switch t := sel.(type) { 158 | case TableElem: 159 | return t, true 160 | case *TableElem: 161 | return *t, true 162 | default: 163 | return TableElem{}, false 164 | } 165 | } 166 | 167 | // GuessJoinOnClause finds a join 'ON' clause between two tables 168 | func GuessJoinOnClause(left Selectable, right Selectable) Clause { 169 | leftTable, ok := getTable(left) 170 | if !ok { 171 | panic("left Selectable is not a Table: Cannot guess join onClause") 172 | } 173 | rightTable, ok := getTable(right) 174 | if !ok { 175 | panic("right Selectable is not a Table: Cannot guess join onClause") 176 | } 177 | 178 | var candidates []joinOnClauseCandidate 179 | 180 | for _, fkey := range leftTable.ForeignKeyConstraints.FKeys { 181 | if fkey.RefTable != rightTable.Name { 182 | continue 183 | } 184 | candidates = append( 185 | candidates, 186 | joinOnClauseCandidate{leftTable, fkey, rightTable}) 187 | } 188 | 189 | for _, fkey := range rightTable.ForeignKeyConstraints.FKeys { 190 | if fkey.RefTable != leftTable.Name { 191 | continue 192 | } 193 | candidates = append( 194 | candidates, 195 | joinOnClauseCandidate{rightTable, fkey, leftTable}) 196 | } 197 | switch len(candidates) { 198 | case 0: 199 | panic(fmt.Sprintf( 200 | "No foreign keys found between %s and %s", 201 | leftTable.Name, rightTable.Name)) 202 | case 1: 203 | candidate := candidates[0] 204 | var clauses []Clause 205 | for i, col := range candidate.fkey.Cols { 206 | refCol := candidate.fkey.RefCols[i] 207 | clauses = append( 208 | clauses, 209 | Eq(candidate.source.C(col), candidate.target.C(refCol)), 210 | ) 211 | } 212 | if len(clauses) == 1 { 213 | return clauses[0] 214 | } 215 | return And(clauses...) 
216 | default: 217 | panic(fmt.Sprintf( 218 | "Found %d foreign keys between %s and %s", 219 | len(candidates), leftTable.Name, rightTable.Name)) 220 | } 221 | } 222 | 223 | // MakeJoinOnClause assemble a 'ON' clause for a join from either: 224 | // 0 clause: attempt to guess the join clause (only if left & right are tables), 225 | // otherwise panics 226 | // 1 clause: returns it 227 | // 2 clauses: returns a Eq() of both 228 | // otherwise if panics 229 | func MakeJoinOnClause(left Selectable, right Selectable, onClause ...Clause) Clause { 230 | switch len(onClause) { 231 | case 0: 232 | return GuessJoinOnClause(left, right) 233 | case 1: 234 | return onClause[0] 235 | case 2: 236 | return Eq(onClause[0], onClause[1]) 237 | default: 238 | panic("Cannot make a join condition with more than 2 clauses") 239 | } 240 | } 241 | 242 | // Join returns a new JoinClause 243 | // onClause can be one of: 244 | // - 0 clause: attempt to guess the join clause (only if left & right are tables), 245 | // otherwise panics 246 | // - 1 clause: use it directly 247 | // - 2 clauses: use a Eq() of both 248 | func Join(joinType string, left Selectable, right Selectable, onClause ...Clause) JoinClause { 249 | return JoinClause{ 250 | JoinType: joinType, 251 | Left: left, 252 | Right: right, 253 | OnClause: MakeJoinOnClause(left, right, onClause...), 254 | } 255 | } 256 | 257 | // JoinClause is the base struct for generating join clauses when using select 258 | // It satisfies Clause interface 259 | type JoinClause struct { 260 | JoinType string 261 | Left Selectable 262 | Right Selectable 263 | OnClause Clause 264 | } 265 | 266 | // Accept calls the compiler VisitJoin method 267 | func (c JoinClause) Accept(context *CompilerContext) string { 268 | return context.Compiler.VisitJoin(context, c) 269 | } 270 | 271 | // All returns the columns from both sides of the join 272 | func (c JoinClause) All() []Clause { 273 | return append(c.Left.All(), c.Right.All()...) 
274 | } 275 | 276 | // ColumnList returns the columns from both sides of the join 277 | func (c JoinClause) ColumnList() []ColumnElem { 278 | return append(c.Left.ColumnList(), c.Right.ColumnList()...) 279 | } 280 | 281 | // C returns the first column with the given name 282 | // If columns from both sides of the join match the name, 283 | // the one from the left side will be returned. 284 | func (c JoinClause) C(name string) ColumnElem { 285 | for _, c := range c.ColumnList() { 286 | if c.Name == name { 287 | return c 288 | } 289 | } 290 | panic(fmt.Sprintf("No such column '%s' in join %v", name, c)) 291 | } 292 | 293 | // DefaultName returns an empty string because Joins have no name by default 294 | func (c JoinClause) DefaultName() string { 295 | return "" 296 | } 297 | 298 | // OrderByClause is the base struct for generating order by clauses when using select 299 | // It satisfies SQLClause interface 300 | type OrderByClause struct { 301 | columns []ColumnElem 302 | t string 303 | } 304 | 305 | // Accept generates an order by clause 306 | func (c OrderByClause) Accept(context *CompilerContext) string { 307 | return context.Compiler.VisitOrderBy(context, c) 308 | } 309 | 310 | // HavingClause is the base struct for generating having clauses when using select 311 | // It satisfies SQLClause interface 312 | type HavingClause struct { 313 | aggregate AggregateClause 314 | op string 315 | value interface{} 316 | } 317 | 318 | // Accept calls the compiler VisitHaving function 319 | func (c HavingClause) Accept(context *CompilerContext) string { 320 | return context.Compiler.VisitHaving(context, c) 321 | } 322 | 323 | // ForUpdateClause is a FOR UPDATE expression 324 | type ForUpdateClause struct { 325 | Tables []TableElem 326 | } 327 | 328 | // Accept calls the compiler VisitForUpdate method 329 | func (s ForUpdateClause) Accept(context *CompilerContext) string { 330 | return context.Compiler.VisitForUpdate(context, s) 331 | } 332 | 333 | // Alias returns a new 
AliasClause 334 | func Alias(name string, selectable Selectable) AliasClause { 335 | return AliasClause{ 336 | Name: name, 337 | Selectable: selectable, 338 | } 339 | } 340 | 341 | // AliasClause is a ALIAS sql clause 342 | type AliasClause struct { 343 | Name string 344 | Selectable Selectable 345 | } 346 | 347 | // Accept calls the compiler VisitAlias function 348 | func (c AliasClause) Accept(context *CompilerContext) string { 349 | return context.Compiler.VisitAlias(context, c) 350 | } 351 | 352 | // C returns the aliased selectable column with the given name. 353 | // Before returning it, the 'Table' field is updated with alias 354 | // name so that they can be used in Select() 355 | func (c AliasClause) C(name string) ColumnElem { 356 | col := c.Selectable.C(name) 357 | col.Table = c.Name 358 | return col 359 | } 360 | 361 | // All returns the aliased selectable columns with their "Table" 362 | // field updated with the alias name 363 | func (c AliasClause) All() []Clause { 364 | var clauses []Clause 365 | for _, col := range c.ColumnList() { 366 | clauses = append(clauses, col) 367 | } 368 | return clauses 369 | } 370 | 371 | // ColumnList returns the aliased selectable columns with their "Table" 372 | // field updated with the alias name 373 | func (c AliasClause) ColumnList() []ColumnElem { 374 | var cols []ColumnElem 375 | for _, col := range c.Selectable.ColumnList() { 376 | col.Table = c.Name 377 | cols = append(cols, col) 378 | } 379 | return cols 380 | } 381 | 382 | // DefaultName returns the alias name 383 | func (c AliasClause) DefaultName() string { 384 | return c.Name 385 | } 386 | -------------------------------------------------------------------------------- /select_test.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/suite" 9 | ) 10 | 11 | type SelectTestSuite struct { 12 
| suite.Suite 13 | users TableElem 14 | sessions TableElem 15 | ctx *CompilerContext 16 | dialect Dialect 17 | } 18 | 19 | func (suite *SelectTestSuite) SetupTest() { 20 | suite.users = Table( 21 | "users", 22 | Column("id", BigInt()), 23 | Column("email", Varchar()).NotNull().Unique(), 24 | Column("password", Varchar()).NotNull(), 25 | PrimaryKey("id"), 26 | ) 27 | 28 | suite.sessions = Table( 29 | "sessions", 30 | Column("id", BigInt()), 31 | Column("user_id", BigInt()), 32 | Column("auth_token", Varchar().Size(36)).Unique().NotNull(), 33 | PrimaryKey("id"), 34 | ForeignKey("user_id").References("users", "id"), 35 | ) 36 | 37 | suite.dialect = NewDefaultDialect() 38 | suite.ctx = NewCompilerContext(suite.dialect) 39 | } 40 | 41 | func (suite *SelectTestSuite) TestSelectSimple() { 42 | sel := Select(suite.users.C("id")).From(suite.users) 43 | assert.Equal(suite.T(), "SELECT id\nFROM users", sel.Accept(suite.ctx)) 44 | } 45 | 46 | func (suite *SelectTestSuite) TestSelectAggregate() { 47 | sel := Select(suite.users.C("id")).From(suite.users) 48 | selCount := sel.Select(Count(suite.users.C("id"))) 49 | assert.Equal(suite.T(), "SELECT COUNT(id)\nFROM users", selCount.Accept(suite.ctx)) 50 | } 51 | 52 | func (suite *SelectTestSuite) TestSelectWhere() { 53 | sel := Select(suite.users.C("id")). 54 | From(suite.users). 55 | Where( 56 | And( 57 | Eq(suite.users.C("email"), "al@pacino.com"), 58 | NotEq(suite.users.C("id"), 5), 59 | ), 60 | ) 61 | 62 | sql := sel.Accept(suite.ctx) 63 | binds := suite.ctx.Binds 64 | 65 | assert.Equal(suite.T(), "SELECT id\nFROM users\nWHERE (email = ? AND id != ?)", sql) 66 | assert.Equal(suite.T(), []interface{}{"al@pacino.com", 5}, binds) 67 | } 68 | 69 | func (suite *SelectTestSuite) TestSelectOrderByLimit() { 70 | selOrderByDesc := Select(suite.sessions.C("id")). 71 | From(suite.sessions). 72 | Where(Eq(suite.sessions.C("user_id"), 5)). 73 | OrderBy(suite.sessions.C("id")).Desc(). 
74 | Limit(20) 75 | 76 | sql := selOrderByDesc.Accept(suite.ctx) 77 | binds := suite.ctx.Binds 78 | 79 | assert.Equal(suite.T(), "SELECT id\nFROM sessions\nWHERE user_id = ?\nORDER BY id DESC\nLIMIT 20", sql) 80 | assert.Equal(suite.T(), []interface{}{5}, binds) 81 | } 82 | 83 | func (suite *SelectTestSuite) TestSelectWithoutOrder() { 84 | selWithoutOrder := Select(suite.sessions.C("id")). 85 | From(suite.sessions). 86 | Where(Eq(suite.sessions.C("user_id"), 5)). 87 | OrderBy(suite.sessions.C("id")). 88 | Offset(12) 89 | 90 | sql := selWithoutOrder.Accept(suite.ctx) 91 | binds := suite.ctx.Binds 92 | 93 | assert.Equal(suite.T(), "SELECT id\nFROM sessions\nWHERE user_id = ?\nORDER BY id ASC\nOFFSET 12", sql) 94 | assert.Equal(suite.T(), []interface{}{5}, binds) 95 | } 96 | 97 | func (suite *SelectTestSuite) TestSelectOrderByAsc() { 98 | selOrderByAsc := Select(suite.sessions.C("id")). 99 | From(suite.sessions). 100 | Where(Eq(suite.sessions.C("user_id"), 5)). 101 | OrderBy(suite.sessions.C("id")).Asc(). 102 | LimitOffset(20, 12) 103 | 104 | sql := selOrderByAsc.Accept(suite.ctx) 105 | binds := suite.ctx.Binds 106 | 107 | assert.Equal(suite.T(), "SELECT id\nFROM sessions\nWHERE user_id = ?\nORDER BY id ASC\nLIMIT 20 OFFSET 12", sql) 108 | assert.Equal(suite.T(), []interface{}{5}, binds) 109 | } 110 | 111 | func (suite *SelectTestSuite) TestSelectInnerJoin() { 112 | selInnerJoin := Select(suite.sessions.C("id"), suite.sessions.C("auth_token")). 113 | From(suite.sessions). 114 | InnerJoin(suite.users, suite.sessions.C("user_id"), suite.users.C("id")). 
115 | Where(Eq(suite.sessions.C("user_id"), 5)) 116 | 117 | sql := selInnerJoin.Accept(suite.ctx) 118 | binds := suite.ctx.Binds 119 | 120 | assert.Equal(suite.T(), suite.sessions.C("user_id"), selInnerJoin.FromClause.C("user_id")) 121 | assert.Panics(suite.T(), func() { selInnerJoin.FromClause.C("invalid") }) 122 | assert.Equal(suite.T(), len(suite.sessions.All())+len(suite.users.All()), len(selInnerJoin.FromClause.All())) 123 | 124 | assert.Equal(suite.T(), "SELECT sessions.id, sessions.auth_token\nFROM sessions\nINNER JOIN users ON sessions.user_id = users.id\nWHERE sessions.user_id = ?", sql) 125 | assert.Equal(suite.T(), []interface{}{5}, binds) 126 | } 127 | 128 | func (suite *SelectTestSuite) TestSelectLeftJoin() { 129 | selLeftJoin := Select(suite.sessions.C("id"), suite.sessions.C("auth_token")). 130 | From(suite.sessions). 131 | LeftJoin(suite.users, suite.sessions.C("user_id"), suite.users.C("id")). 132 | Where(Eq(suite.sessions.C("user_id"), 5)) 133 | 134 | sql := selLeftJoin.Accept(suite.ctx) 135 | binds := suite.ctx.Binds 136 | 137 | assert.Equal(suite.T(), "SELECT sessions.id, sessions.auth_token\nFROM sessions\nLEFT OUTER JOIN users ON sessions.user_id = users.id\nWHERE sessions.user_id = ?", sql) 138 | assert.Equal(suite.T(), []interface{}{5}, binds) 139 | } 140 | 141 | func (suite *SelectTestSuite) TestSelectRightJoin() { 142 | selRightJoin := Select(suite.sessions.C("id")). 143 | From(suite.sessions). 144 | RightJoin(suite.users, suite.sessions.C("user_id"), suite.users.C("id")). 145 | Where(Eq(suite.sessions.C("user_id"), 5)) 146 | 147 | sql := selRightJoin.Accept(suite.ctx) 148 | binds := suite.ctx.Binds 149 | 150 | assert.Equal(suite.T(), "SELECT sessions.id\nFROM sessions\nRIGHT OUTER JOIN users ON sessions.user_id = users.id\nWHERE sessions.user_id = ?", sql) 151 | assert.Equal(suite.T(), []interface{}{5}, binds) 152 | } 153 | 154 | func (suite *SelectTestSuite) TestSelectCrossJoin() { 155 | selCrossJoin := Select(suite.sessions.C("id")). 
156 | From(suite.sessions). 157 | CrossJoin(suite.users). 158 | Where(Eq(suite.sessions.C("user_id"), 5)) 159 | 160 | sql := selCrossJoin.Accept(suite.ctx) 161 | binds := suite.ctx.Binds 162 | 163 | assert.Equal(suite.T(), "SELECT sessions.id\nFROM sessions\nCROSS JOIN users\nWHERE sessions.user_id = ?", sql) 164 | assert.Equal(suite.T(), []interface{}{5}, binds) 165 | } 166 | 167 | func (suite *SelectTestSuite) TestSelectGroupByHaving() { 168 | sel := Select(Count(suite.sessions.C("id"))). 169 | From(suite.sessions). 170 | GroupBy(suite.sessions.C("user_id")). 171 | Having(Sum(suite.sessions.C("id")), ">", 4) 172 | 173 | sql := sel.Accept(suite.ctx) 174 | binds := suite.ctx.Binds 175 | 176 | assert.Equal(suite.T(), "SELECT COUNT(id)\nFROM sessions\nGROUP BY user_id\nHAVING SUM(id) > ?", sql) 177 | assert.Equal(suite.T(), []interface{}{4}, binds) 178 | } 179 | 180 | func (suite *SelectTestSuite) TestSelectAliasFrom() { 181 | sessionAlias := Alias("newname", suite.sessions) 182 | 183 | sel := Select(sessionAlias.C("id")).From(sessionAlias) 184 | assert.Equal(suite.T(), "SELECT id\nFROM sessions AS newname", sel.Accept(suite.ctx)) 185 | } 186 | 187 | func (suite *SelectTestSuite) TestSelectAliasAll() { 188 | sessionAlias := Alias("newname", suite.sessions) 189 | 190 | sel := Select(sessionAlias.C("id")).From(sessionAlias) 191 | sql := sel.Accept(suite.ctx) 192 | 193 | assert.Contains(suite.T(), sql, "id", sql) 194 | assert.Contains(suite.T(), sql, "sessions AS newname", sql) 195 | } 196 | 197 | func (suite *SelectTestSuite) TestSelectAliasWhereMultipleTable() { 198 | sessionAlias := Alias("newname", suite.sessions) 199 | usersAlias := Alias("u", suite.users) 200 | sel := Select(usersAlias.C("email")). 201 | From(usersAlias). 202 | LeftJoin(sessionAlias, usersAlias.C("id"), sessionAlias.C("user_id")). 
203 | Where(sessionAlias.C("auth_token").Eq("42")) 204 | 205 | sql := sel.Accept(suite.ctx) 206 | expected := strings.Join([]string{ 207 | "SELECT u.email", 208 | "FROM users AS u", 209 | "LEFT OUTER JOIN sessions AS newname ON u.id = newname.user_id", 210 | "WHERE newname.auth_token = ?", 211 | }, "\n") 212 | assert.Equal(suite.T(), expected, sql) 213 | } 214 | 215 | func (suite *SelectTestSuite) TestSelectGuessJoinOnClause() { 216 | t1 := Table( 217 | "t1", 218 | Column("c1", Int()), 219 | Column("c2", Int()), 220 | ) 221 | t2 := Table( 222 | "t2", 223 | Column("c1", Int()), 224 | Column("c2", Int()), 225 | ) 226 | t3 := Table( 227 | "t3", 228 | Column("c1", Int()), 229 | Column("c2", Int()), 230 | ForeignKey("c1").References("t1", "c1"), 231 | ForeignKey("c1").References("t2", "c1"), 232 | ForeignKey("c2").References("t2", "c2"), 233 | ) 234 | t4 := Table( 235 | "t4", 236 | Column("c1", Int()), 237 | Column("c2", Int()), 238 | ForeignKey("c1", "c2").References("t1", "c1", "c2"), 239 | ) 240 | 241 | assert.Panics(suite.T(), func() { 242 | GuessJoinOnClause(t1, Alias("tt", t3)) 243 | }) 244 | 245 | assert.Panics(suite.T(), func() { 246 | GuessJoinOnClause(Alias("tt", t3), t2) 247 | }) 248 | 249 | assert.Panics(suite.T(), func() { 250 | GuessJoinOnClause(t1, &t2) 251 | }) 252 | 253 | assert.Equal(suite.T(), "t3.c1 = t1.c1", GuessJoinOnClause(t3, t1).Accept(suite.ctx)) 254 | assert.Equal(suite.T(), "t3.c1 = t1.c1", GuessJoinOnClause(t1, t3).Accept(suite.ctx)) 255 | assert.Equal(suite.T(), "(t4.c1 = t1.c1 AND t4.c2 = t1.c2)", GuessJoinOnClause(t4, t1).Accept(suite.ctx)) 256 | 257 | assert.Panics(suite.T(), func() { 258 | GuessJoinOnClause(t2, t3) 259 | }) 260 | } 261 | 262 | func (suite *SelectTestSuite) TestSelectMakeJoinOnClause() { 263 | assert.Panics(suite.T(), func() { 264 | MakeJoinOnClause(TableElem{}, TableElem{}, And(), And(), And()) 265 | }) 266 | } 267 | 268 | func TestSelectTestSuite(t *testing.T) { 269 | suite.Run(t, new(SelectTestSuite)) 270 | } 271 | 
-------------------------------------------------------------------------------- /statement.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | const defaultDelimiter = "\n" 9 | 10 | // Statement creates a new query and returns its pointer 11 | func Statement() *Stmt { 12 | return &Stmt{ 13 | clauses: []string{}, 14 | bindings: []interface{}{}, 15 | delimiter: defaultDelimiter, 16 | bindingIndex: 0, 17 | } 18 | } 19 | 20 | // Stmt is the base abstraction for all sql queries 21 | type Stmt struct { 22 | clauses []string 23 | bindings []interface{} 24 | delimiter string 25 | bindingIndex int 26 | } 27 | 28 | // Text is for executing raw sql 29 | // It parses the sql and generates clauses from 30 | func (s *Stmt) Text(sql string) { 31 | sql = strings.Replace(sql, ";", "", -1) 32 | sql = strings.Replace(sql, "\t", "", -1) 33 | sql = strings.Trim(sql, "\n") 34 | clauses := strings.Split(sql, "\n") 35 | for _, c := range clauses { 36 | s.clauses = append(s.clauses, c) 37 | } 38 | } 39 | 40 | // SetDelimiter sets the delimiter of query 41 | func (s *Stmt) SetDelimiter(delimiter string) { 42 | s.delimiter = delimiter 43 | } 44 | 45 | // AddSQLClause appends a new clause to current query 46 | func (s *Stmt) AddSQLClause(clause string) { 47 | s.clauses = append(s.clauses, clause) 48 | } 49 | 50 | // AddBinding appends a new binding to current query 51 | func (s *Stmt) AddBinding(bindings ...interface{}) { 52 | for _, v := range bindings { 53 | s.bindings = append(s.bindings, v) 54 | } 55 | } 56 | 57 | // SQLClauses returns all clauses of current query 58 | func (s *Stmt) SQLClauses() []string { 59 | return s.clauses 60 | } 61 | 62 | // Bindings returns all bindings of current query 63 | func (s *Stmt) Bindings() []interface{} { 64 | return s.bindings 65 | } 66 | 67 | // SQL returns the query struct sql statement 68 | func (s *Stmt) SQL() string { 69 | if len(s.clauses) > 0 
{ 70 | sql := fmt.Sprintf("%s;", strings.Join(s.clauses, s.delimiter)) 71 | return sql 72 | } 73 | 74 | return "" 75 | } 76 | -------------------------------------------------------------------------------- /statement_test.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | ) 7 | 8 | func TestStatement(t *testing.T) { 9 | statement := Statement() 10 | 11 | statement.AddSQLClause("SELECT name") 12 | statement.AddSQLClause("FROM user") 13 | statement.AddSQLClause("WHERE id = ?") 14 | statement.AddBinding(5) 15 | 16 | assert.Equal(t, []string{"SELECT name", "FROM user", "WHERE id = ?"}, statement.SQLClauses()) 17 | assert.Equal(t, []interface{}{5}, statement.Bindings()) 18 | assert.Equal(t, "SELECT name\nFROM user\nWHERE id = ?;", statement.SQL()) 19 | } 20 | 21 | func TestStatementRaw(t *testing.T) { 22 | 23 | statement := Statement() 24 | sql := ` 25 | SELECT name 26 | FROM user 27 | WHERE id = ?; 28 | ` 29 | statement.Text(sql) 30 | assert.Equal(t, []string{"SELECT name", "FROM user", "WHERE id = ?"}, statement.SQLClauses()) 31 | assert.Equal(t, "SELECT name\nFROM user\nWHERE id = ?;", statement.SQL()) 32 | } 33 | 34 | func TestStatementWithCustomDelimiter(t *testing.T) { 35 | statement := Statement() 36 | 37 | assert.Equal(t, "", statement.SQL()) 38 | 39 | statement.SetDelimiter(" ") 40 | 41 | statement.AddSQLClause("SELECT name") 42 | statement.AddSQLClause("FROM user") 43 | 44 | statement.AddSQLClause("WHERE id = ?") 45 | statement.AddBinding(5) 46 | 47 | assert.Equal(t, []string{"SELECT name", "FROM user", "WHERE id = ?"}, statement.SQLClauses()) 48 | assert.Equal(t, []interface{}{5}, statement.Bindings()) 49 | assert.Equal(t, "SELECT name FROM user WHERE id = ?;", statement.SQL()) 50 | } 51 | -------------------------------------------------------------------------------- /table.go: 
-------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | // Table generates table struct given name and clauses 9 | func Table(name string, clauses ...TableSQLClause) TableElem { 10 | table := TableElem{ 11 | Name: name, 12 | Columns: map[string]ColumnElem{}, 13 | ForeignKeyConstraints: ForeignKeyConstraints{}, 14 | Indices: []IndexElem{}, 15 | } 16 | 17 | var pkeyCols []ColumnElem 18 | 19 | for _, clause := range clauses { 20 | switch clause.(type) { 21 | case ColumnElem: 22 | col := clause.(ColumnElem) 23 | if col.Options.PrimaryKey { 24 | pkeyCols = append(pkeyCols, col) 25 | } 26 | col.Table = name 27 | table.Columns[col.Name] = col 28 | break 29 | case PrimaryKeyConstraint: 30 | table.PrimaryKeyConstraint = clause.(PrimaryKeyConstraint) 31 | break 32 | //case ForeignKeyConstraints: 33 | // table.ForeignKeyConstraints.FKeys = append( 34 | // table.ForeignKeyConstraints.FKeys, 35 | // clause.(ForeignKeyConstraints).FKeys...) 36 | case ForeignKeyConstraint: 37 | table.ForeignKeyConstraints.FKeys = append( 38 | table.ForeignKeyConstraints.FKeys, 39 | clause.(ForeignKeyConstraint), 40 | ) 41 | break 42 | case UniqueKeyConstraint: 43 | table.UniqueKeyConstraint = clause.(UniqueKeyConstraint).Table(table.Name) 44 | break 45 | case IndexElem: 46 | table.Indices = append(table.Indices, clause.(IndexElem)) 47 | break 48 | } 49 | } 50 | 51 | if len(pkeyCols) > 0 && table.PrimaryKeyConstraint.Columns != nil { 52 | panic(fmt.Sprintf("Table %s has both 'PrimaryKey()' columns (%#v) and a PrimaryKeyConstraint. Only only should be set", name, pkeyCols)) 53 | } 54 | if len(pkeyCols) > 0 { 55 | var pkeyNames []string 56 | for _, col := range pkeyCols { 57 | pkeyNames = append(pkeyNames, col.Name) 58 | } 59 | table.PrimaryKeyConstraint = PrimaryKey(pkeyNames...) 
60 | } 61 | 62 | // Make sure the columns are flagged as primary key 63 | for _, name := range table.PrimaryKeyConstraint.Columns { 64 | table.Columns[name] = table.Columns[name].PrimaryKey() 65 | } 66 | if len(table.PrimaryKeyConstraint.Columns) == 1 { 67 | // Make sure the column will inline the primary key 68 | name := table.PrimaryKeyConstraint.Columns[0] 69 | table.Columns[name] = table.Columns[name].inlinePrimaryKey() 70 | } 71 | 72 | return table 73 | } 74 | 75 | // TableElem is the definition of any sql table 76 | type TableElem struct { 77 | Name string 78 | Columns map[string]ColumnElem 79 | PrimaryKeyConstraint PrimaryKeyConstraint 80 | ForeignKeyConstraints ForeignKeyConstraints 81 | UniqueKeyConstraint UniqueKeyConstraint 82 | Indices []IndexElem 83 | } 84 | 85 | // DefaultName returns the name of the table 86 | func (t TableElem) DefaultName() string { 87 | return t.Name 88 | } 89 | 90 | // All returns all columns of table as a column slice 91 | func (t TableElem) All() []Clause { 92 | cols := []Clause{} 93 | for _, v := range t.Columns { 94 | cols = append(cols, v) 95 | } 96 | return cols 97 | } 98 | 99 | // ColumnList columns of the table 100 | func (t TableElem) ColumnList() []ColumnElem { 101 | cols := []ColumnElem{} 102 | for _, v := range t.Columns { 103 | cols = append(cols, v) 104 | } 105 | return cols 106 | } 107 | 108 | // Index appends an IndexElem to current table without giving table name 109 | func (t TableElem) Index(cols ...string) TableElem { 110 | t.Indices = append(t.Indices, Index(t.Name, cols...)) 111 | return t 112 | } 113 | 114 | // Create generates create table syntax and returns it as a query struct 115 | func (t TableElem) Create(dialect Dialect) string { 116 | statement := Statement() 117 | statement.AddSQLClause(fmt.Sprintf("CREATE TABLE %s (", dialect.Escape(t.Name))) 118 | 119 | colClauses := []string{} 120 | for _, col := range t.Columns { 121 | colClauses = append(colClauses, fmt.Sprintf("\t%s", col.String(dialect))) 
122 | } 123 | 124 | if len(t.PrimaryKeyConstraint.Columns) > 1 { 125 | colClauses = append(colClauses, fmt.Sprintf("\t%s", t.PrimaryKeyConstraint.String(dialect))) 126 | } 127 | 128 | if len(t.ForeignKeyConstraints.FKeys) > 0 { 129 | colClauses = append(colClauses, t.ForeignKeyConstraints.String(dialect)) 130 | } 131 | 132 | if t.UniqueKeyConstraint.name != "" { 133 | colClauses = append(colClauses, fmt.Sprintf("\t%s", t.UniqueKeyConstraint.String(dialect))) 134 | } 135 | 136 | statement.AddSQLClause(strings.Join(colClauses, ",\n")) 137 | 138 | statement.AddSQLClause(")") 139 | 140 | ddl := statement.SQL() 141 | 142 | indexSqls := []string{} 143 | for _, index := range t.Indices { 144 | iSQLClause := index.String(dialect) 145 | indexSqls = append(indexSqls, iSQLClause) 146 | } 147 | 148 | sqls := []string{ddl} 149 | sqls = append(sqls, indexSqls...) 150 | 151 | return strings.Join(sqls, "\n") 152 | } 153 | 154 | // Build generates a Statement object out of table ddl 155 | func (t TableElem) Build(dialect Dialect) *Stmt { 156 | sql := t.Create(dialect) 157 | statement := Statement() 158 | statement.AddSQLClause(strings.Trim(sql, ";")) // TODO: Remove this ugly hack 159 | return statement 160 | } 161 | 162 | // PrimaryCols returns the columns that are primary key to the table 163 | func (t TableElem) PrimaryCols() []ColumnElem { 164 | primaryCols := []ColumnElem{} 165 | pkCols := t.PrimaryKeyConstraint.Columns 166 | for _, pkCol := range pkCols { 167 | primaryCols = append(primaryCols, t.C(pkCol)) 168 | } 169 | return primaryCols 170 | } 171 | 172 | // Drop generates drop table syntax and returns it as a query struct 173 | func (t TableElem) Drop(dialect Dialect) string { 174 | stmt := Statement() 175 | stmt.AddSQLClause(fmt.Sprintf("DROP TABLE %s", dialect.Escape(t.Name))) 176 | return stmt.SQL() 177 | } 178 | 179 | // C returns the column name given col 180 | func (t TableElem) C(name string) ColumnElem { 181 | return t.Columns[name] 182 | } 183 | 184 | // query 
starters 185 | 186 | // Insert starts an insert statement by setting the table parameter 187 | func (t TableElem) Insert() InsertStmt { 188 | return Insert(t) 189 | } 190 | 191 | // Update starts an update statement by setting the table parameter 192 | func (t TableElem) Update() UpdateStmt { 193 | return Update(t) 194 | } 195 | 196 | // Delete starts a delete statement by setting the table parameter 197 | func (t TableElem) Delete() DeleteStmt { 198 | return Delete(t) 199 | } 200 | 201 | // Upsert starts an upsert statement by setting the table parameter 202 | func (t TableElem) Upsert() UpsertStmt { 203 | return Upsert(t) 204 | } 205 | 206 | // Select starts a select statement by setting from table 207 | func (t TableElem) Select(clauses ...Clause) SelectStmt { 208 | return Select(clauses...).From(t) 209 | } 210 | 211 | // Accept implements Clause.Accept 212 | func (t TableElem) Accept(context *CompilerContext) string { 213 | return context.Compiler.VisitTable(context, t) 214 | } 215 | -------------------------------------------------------------------------------- /table_test.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | "github.com/stretchr/testify/suite" 8 | ) 9 | 10 | type TableTestSuite struct { 11 | suite.Suite 12 | dialect Dialect 13 | } 14 | 15 | func (suite *TableTestSuite) SetupTest() { 16 | suite.dialect = NewDefaultDialect() 17 | } 18 | 19 | func (suite *TableTestSuite) TestTableSimpleCreate() { 20 | usersTable := Table("users", Column("id", Varchar().Size(40))) 21 | assert.Equal(suite.T(), 1, len(usersTable.All())) 22 | 23 | ddl := usersTable.Create(suite.dialect) 24 | assert.Contains(suite.T(), ddl, "CREATE TABLE users (") 25 | assert.Contains(suite.T(), ddl, "id VARCHAR(40)") 26 | assert.Contains(suite.T(), ddl, ");") 27 | 28 | statement := usersTable.Build(suite.dialect) 29 | sql := statement.SQL() 30 | 
assert.Contains(suite.T(), sql, "CREATE TABLE users (") 31 | assert.Contains(suite.T(), sql, "id VARCHAR(40)") 32 | assert.Contains(suite.T(), sql, ");") 33 | assert.Equal(suite.T(), []interface{}{}, statement.Bindings()) 34 | } 35 | 36 | func (suite *TableTestSuite) TestTableSimpleDrop() { 37 | usersTable := Table("users", Column("id", Varchar().Size(40))) 38 | 39 | assert.Equal(suite.T(), "DROP TABLE users;", usersTable.Drop(suite.dialect)) 40 | } 41 | 42 | func (suite *TableTestSuite) TestTablePrimaryForeignKey() { 43 | usersTable := Table( 44 | "users", 45 | Column("id", Varchar().Size(40)), 46 | Column("session_id", Varchar().Size(40)), 47 | Column("auth_token", Varchar().Size(40)), 48 | Column("role_id", Varchar().Size(40)), 49 | PrimaryKey("id"), 50 | ForeignKey("session_id", "auth_token"). 51 | References("sessions", "id", "auth_token"), 52 | ForeignKey("role_id").References("roles", "id"), 53 | ) 54 | 55 | ddl := usersTable.Create(suite.dialect) 56 | assert.Contains(suite.T(), ddl, "CREATE TABLE users (") 57 | assert.Contains(suite.T(), ddl, "auth_token VARCHAR(40)") 58 | assert.Contains(suite.T(), ddl, "role_id VARCHAR(40)") 59 | assert.Contains(suite.T(), ddl, "id VARCHAR(40) PRIMARY KEY") 60 | assert.Contains(suite.T(), ddl, "session_id VARCHAR(40)") 61 | assert.Contains(suite.T(), ddl, "FOREIGN KEY(session_id, auth_token) REFERENCES sessions(id, auth_token)") 62 | assert.Contains(suite.T(), ddl, "FOREIGN KEY(role_id) REFERENCES roles(id)") 63 | assert.Contains(suite.T(), ddl, ");") 64 | } 65 | 66 | func (suite *TableTestSuite) TestTableSimplePrimaryKey() { 67 | users := Table( 68 | "users", 69 | Column("id", Varchar().Size(40)).PrimaryKey(), 70 | ) 71 | assert.Equal(suite.T(), []string{"id"}, users.PrimaryKeyConstraint.Columns) 72 | } 73 | 74 | func (suite *TableTestSuite) TestTableCompositePrimaryKey() { 75 | 76 | users := Table( 77 | "users", 78 | Column("fname", Varchar().Size(40)).PrimaryKey(), 79 | Column("lname", Varchar().Size(40)).PrimaryKey(), 
80 | ) 81 | 82 | assert.Equal(suite.T(), []string{"fname", "lname"}, users.PrimaryKeyConstraint.Columns) 83 | cols := users.PrimaryCols() 84 | assert.Equal(suite.T(), 2, len(cols)) 85 | assert.Equal(suite.T(), "fname", cols[0].Name) 86 | assert.Equal(suite.T(), "lname", cols[1].Name) 87 | 88 | ddl := users.Create(suite.dialect) 89 | assert.Contains(suite.T(), ddl, "PRIMARY KEY(fname, lname)") 90 | 91 | assert.Panics(suite.T(), func() { 92 | Table( 93 | "users", 94 | Column("id", Varchar().Size(40)).PrimaryKey(), 95 | PrimaryKey("id"), 96 | ) 97 | }) 98 | } 99 | 100 | func (suite *TableTestSuite) TestTableUniqueCompositeUnique() { 101 | usersTable := Table( 102 | "users", 103 | Column("id", Varchar().Size(40)), 104 | Column("email", Varchar().Size(40)).Unique(), 105 | Column("device_id", Varchar().Size(255)).Unique(), 106 | UniqueKey("email", "device_id"), 107 | ) 108 | 109 | ddl := usersTable.Create(suite.dialect) 110 | assert.Contains(suite.T(), ddl, "CREATE TABLE users (") 111 | assert.Contains(suite.T(), ddl, "id VARCHAR(40)") 112 | assert.Contains(suite.T(), ddl, "email VARCHAR(40) UNIQUE") 113 | assert.Contains(suite.T(), ddl, "device_id VARCHAR(255) UNIQUE") 114 | assert.Contains(suite.T(), ddl, "CONSTRAINT u_users_email_device_id UNIQUE(email, device_id)") 115 | assert.Contains(suite.T(), ddl, ");") 116 | } 117 | 118 | func (suite *TableTestSuite) TestTableIndex() { 119 | usersTable := Table( 120 | "users", 121 | Column("id", Varchar().Size(40)), 122 | Column("email", Varchar().Size(40)).Unique(), 123 | Index("users", "id"), 124 | Index("users", "email"), 125 | Index("users", "id", "email"), 126 | ) 127 | ddl := usersTable.Create(suite.dialect) 128 | assert.Contains(suite.T(), ddl, "CREATE TABLE users (") 129 | assert.Contains(suite.T(), ddl, "id VARCHAR(40)") 130 | assert.Contains(suite.T(), ddl, "email VARCHAR(40) UNIQUE") 131 | assert.Contains(suite.T(), ddl, ")") 132 | assert.Contains(suite.T(), ddl, "CREATE INDEX i_id ON users(id)") 133 | 
assert.Contains(suite.T(), ddl, "CREATE INDEX i_email ON users(email)") 134 | assert.Contains(suite.T(), ddl, "CREATE INDEX i_id_email ON users(id, email);") 135 | 136 | assert.Equal(suite.T(), ColumnElem{Name: "id", Type: Varchar().Size(40), Table: "users"}, usersTable.C("id")) 137 | assert.Zero(suite.T(), usersTable.C("nonExisting")) 138 | } 139 | 140 | func (suite *TableTestSuite) TestTableIndexChain() { 141 | usersTable := Table("users", Column("id", Varchar().Size(40))).Index("id") 142 | ddl := usersTable.Create(suite.dialect) 143 | assert.Contains(suite.T(), ddl, "CREATE TABLE users (") 144 | assert.Contains(suite.T(), ddl, "id VARCHAR(40)") 145 | assert.Contains(suite.T(), ddl, ");") 146 | assert.Contains(suite.T(), ddl, "CREATE INDEX i_id ON users(id);") 147 | } 148 | 149 | func (suite *TableTestSuite) TestTableStarters() { 150 | users := Table( 151 | "users", 152 | Column("id", Varchar().Size(40)), 153 | Column("email", Varchar().Size(40)).Unique(), 154 | PrimaryKey("id"), 155 | ) 156 | 157 | ins := users. 158 | Insert(). 159 | Values(map[string]interface{}{ 160 | "id": "5a73ef89-cf0a-4c51-ab8c-cc273ebb3a55", 161 | "email": "al@pacino.com", 162 | }). 163 | Build(suite.dialect) 164 | 165 | assert.Contains(suite.T(), ins.SQL(), "INSERT INTO users") 166 | assert.Contains(suite.T(), ins.SQL(), "id") 167 | assert.Contains(suite.T(), ins.SQL(), "email") 168 | assert.Contains(suite.T(), ins.SQL(), "VALUES(?, ?)") 169 | assert.Contains(suite.T(), ins.Bindings(), "5a73ef89-cf0a-4c51-ab8c-cc273ebb3a55") 170 | assert.Contains(suite.T(), ins.Bindings(), "al@pacino.com") 171 | 172 | ups := users.Upsert() 173 | assert.Equal(suite.T(), users, ups.Table) 174 | 175 | upd := users. 176 | Update(). 177 | Values(map[string]interface{}{ 178 | "email": "al@pacino.com", 179 | }). 180 | Where(users.C("id").Eq("5a73ef89-cf0a-4c51-ab8c-cc273ebb3a55")). 
181 | Build(suite.dialect) 182 | 183 | updSQL := upd.SQL() 184 | assert.Contains(suite.T(), updSQL, "UPDATE users") 185 | assert.Contains(suite.T(), updSQL, "SET email = ?") 186 | assert.Contains(suite.T(), updSQL, "WHERE id = ?;") 187 | 188 | assert.Equal(suite.T(), []interface{}{ 189 | "al@pacino.com", 190 | "5a73ef89-cf0a-4c51-ab8c-cc273ebb3a55", 191 | }, upd.Bindings()) 192 | 193 | del := users. 194 | Delete(). 195 | Where(users.C("id").Eq("5a73ef89-cf0a-4c51-ab8c-cc273ebb3a55")). 196 | Build(suite.dialect) 197 | 198 | delSQL := del.SQL() 199 | 200 | assert.Contains(suite.T(), delSQL, "DELETE FROM users") 201 | assert.Contains(suite.T(), delSQL, "WHERE users.id = ?;") 202 | assert.Equal(suite.T(), []interface{}{"5a73ef89-cf0a-4c51-ab8c-cc273ebb3a55"}, del.Bindings()) 203 | 204 | sel := users. 205 | Select(users.C("id"), users.C("email")). 206 | Where(users.C("id").Eq("5a73ef89-cf0a-4c51-ab8c-cc273ebb3a55")). 207 | Build(suite.dialect) 208 | 209 | selSQL := sel.SQL() 210 | 211 | assert.Contains(suite.T(), selSQL, "SELECT id, email") 212 | assert.Contains(suite.T(), selSQL, "FROM users") 213 | assert.Contains(suite.T(), selSQL, "WHERE id = ?;") 214 | assert.Equal(suite.T(), []interface{}{"5a73ef89-cf0a-4c51-ab8c-cc273ebb3a55"}, sel.Bindings()) 215 | } 216 | 217 | func TestTableTestSuite(t *testing.T) { 218 | suite.Run(t, new(TableTestSuite)) 219 | } 220 | -------------------------------------------------------------------------------- /testutils_test.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | ) 7 | 8 | type TestingLogWriter struct { 9 | t *testing.T 10 | lines []string 11 | } 12 | 13 | func (w *TestingLogWriter) Write(p []byte) (n int, err error) { 14 | w.lines = append(w.lines, string(p)) 15 | return len(p), nil 16 | } 17 | 18 | func (w *TestingLogWriter) Flush() { 19 | w.t.Log("Captured:\n" + strings.Join(w.lines, "")) 20 | } 21 | 
-------------------------------------------------------------------------------- /type.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | // Char creates char type 9 | func Char() TypeElem { 10 | return Type("CHAR") 11 | } 12 | 13 | // Varchar creates varchar type 14 | func Varchar() TypeElem { 15 | return Type("VARCHAR").Size(255) 16 | } 17 | 18 | // Text creates text type 19 | func Text() TypeElem { 20 | return Type("TEXT") 21 | } 22 | 23 | // Int creates int type 24 | func Int() TypeElem { 25 | return Type("INT") 26 | } 27 | 28 | // TinyInt creates tinyint type 29 | func TinyInt() TypeElem { 30 | return Type("TINYINT") 31 | } 32 | 33 | // SmallInt creates smallint type 34 | func SmallInt() TypeElem { 35 | return Type("SMALLINT") 36 | } 37 | 38 | // BigInt creates bigint type 39 | func BigInt() TypeElem { 40 | return Type("BIGINT") 41 | } 42 | 43 | // Numeric creates a numeric type 44 | func Numeric() TypeElem { 45 | return Type("NUMERIC") 46 | } 47 | 48 | // Decimal creates a decimal type 49 | func Decimal() TypeElem { 50 | return Type("DECIMAL") 51 | } 52 | 53 | // Float creates float type 54 | func Float() TypeElem { 55 | return Type("FLOAT") 56 | } 57 | 58 | // Boolean creates boolean type 59 | func Boolean() TypeElem { 60 | return Type("BOOLEAN") 61 | } 62 | 63 | // Timestamp creates timestamp type 64 | func Timestamp() TypeElem { 65 | return Type("TIMESTAMP") 66 | } 67 | 68 | // UUID creates a UUID type 69 | func UUID() TypeElem { 70 | return Type("UUID") 71 | } 72 | 73 | // Blob creates a BLOB type 74 | func Blob() TypeElem { 75 | return Type("BLOB") 76 | } 77 | 78 | const defaultTypeSize = -1 79 | 80 | // Type returns a new TypeElem while defining columns in table 81 | func Type(name string) TypeElem { 82 | return TypeElem{ 83 | Name: name, 84 | size: defaultTypeSize, 85 | precision: []int{}, 86 | } 87 | } 88 | 89 | // TypeElem is the struct for 
defining column types 90 | type TypeElem struct { 91 | Name string 92 | size int 93 | precision []int 94 | unsigned bool 95 | } 96 | 97 | // DefaultCompileType is a default implementation for Dialect.CompileType 98 | func DefaultCompileType(t TypeElem, supportsUnsigned bool) string { 99 | name := t.Name 100 | 101 | if t.unsigned && !supportsUnsigned { 102 | // use a bigger int type so the unsigned values can fit in 103 | switch name { 104 | case "TINYINT": 105 | name = "SMALLINT" 106 | case "SMALLINT": 107 | name = "INT" 108 | case "INT": 109 | name = "BIGINT" 110 | } 111 | } 112 | 113 | sizeSpecified := false 114 | precisionSpecified := false 115 | if t.size != defaultTypeSize { 116 | sizeSpecified = true 117 | } else if len(t.precision) > 0 { 118 | precisionSpecified = true 119 | } 120 | 121 | if sizeSpecified { 122 | name = fmt.Sprintf("%s(%d)", name, t.size) 123 | } else if precisionSpecified { 124 | precision := []string{} 125 | for _, p := range t.precision { 126 | precision = append(precision, fmt.Sprintf("%v", p)) 127 | } 128 | name = fmt.Sprintf("%s(%s)", name, strings.Join(precision, ", ")) 129 | } 130 | 131 | if t.unsigned && supportsUnsigned { 132 | name = fmt.Sprintf("%s UNSIGNED", name) 133 | } 134 | return name 135 | } 136 | 137 | // Size adds size constraint to column type 138 | func (t TypeElem) Size(size int) TypeElem { 139 | t.size = size 140 | return t 141 | } 142 | 143 | // Precision sets the precision of column type 144 | // Note: Use it in Float, Decimal and Numeric types 145 | func (t TypeElem) Precision(p int, s int) TypeElem { 146 | t.precision = []int{p, s} 147 | return t 148 | } 149 | 150 | // Unsigned change the column type to 'unsigned' 151 | // Note: Use it in Float, Decimal and Numeric types 152 | func (t TypeElem) Unsigned() TypeElem { 153 | t.unsigned = true 154 | return t 155 | } 156 | 157 | // Signed change the column type to 'signed' 158 | // Note: Use it in Float, Decimal and Numeric types 159 | func (t TypeElem) Signed() 
TypeElem { 160 | t.unsigned = false 161 | return t 162 | } 163 | -------------------------------------------------------------------------------- /type_test.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "github.com/stretchr/testify/suite" 6 | "testing" 7 | ) 8 | 9 | type TypeTestSuite struct { 10 | suite.Suite 11 | } 12 | 13 | func (suite *TypeTestSuite) TestTypes() { 14 | dialect := NewDialect("") 15 | 16 | precisionType := Type("FLOAT").Precision(2, 5) 17 | 18 | assert.Equal(suite.T(), "FLOAT(2, 5)", dialect.CompileType(precisionType)) 19 | 20 | assert.Equal(suite.T(), "CHAR", dialect.CompileType(Char())) 21 | assert.Equal(suite.T(), "VARCHAR(255)", dialect.CompileType(Varchar())) 22 | assert.Equal(suite.T(), "TEXT", dialect.CompileType(Text())) 23 | assert.Equal(suite.T(), "INT", dialect.CompileType(Int())) 24 | assert.Equal(suite.T(), "SMALLINT", dialect.CompileType(SmallInt())) 25 | assert.Equal(suite.T(), "BIGINT", dialect.CompileType(BigInt())) 26 | assert.Equal(suite.T(), "NUMERIC(2, 5)", dialect.CompileType(Numeric().Precision(2, 5))) 27 | assert.Equal(suite.T(), "DECIMAL", dialect.CompileType(Decimal())) 28 | assert.Equal(suite.T(), "FLOAT", dialect.CompileType(Float())) 29 | assert.Equal(suite.T(), "BOOLEAN", dialect.CompileType(Boolean())) 30 | assert.Equal(suite.T(), "TIMESTAMP", dialect.CompileType(Timestamp())) 31 | assert.Equal(suite.T(), "BLOB", dialect.CompileType(Blob())) 32 | assert.Equal(suite.T(), "UUID", dialect.CompileType(UUID())) 33 | } 34 | 35 | func (suite *TypeTestSuite) TestUnsigned() { 36 | assert.Equal(suite.T(), "BIGINT", DefaultCompileType(BigInt().Signed(), true)) 37 | assert.Equal(suite.T(), "BIGINT UNSIGNED", DefaultCompileType(BigInt().Unsigned(), true)) 38 | assert.Equal(suite.T(), "NUMERIC(2, 5) UNSIGNED", DefaultCompileType(Numeric().Precision(2, 5).Unsigned(), true)) 39 | 40 | assert.Equal(suite.T(), "INT", 
DefaultCompileType(Int().Signed(), false)) 41 | assert.Equal(suite.T(), "SMALLINT", DefaultCompileType(TinyInt().Unsigned(), false)) 42 | assert.Equal(suite.T(), "INT", DefaultCompileType(SmallInt().Unsigned(), false)) 43 | assert.Equal(suite.T(), "BIGINT", DefaultCompileType(Int().Unsigned(), false)) 44 | assert.Equal(suite.T(), "BIGINT", DefaultCompileType(BigInt().Unsigned(), false)) 45 | } 46 | 47 | func TestTypeTestSuite(t *testing.T) { 48 | suite.Run(t, new(TypeTestSuite)) 49 | } 50 | -------------------------------------------------------------------------------- /update.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | // Update generates an update statement and returns it 4 | // qb.Update(usersTable). 5 | // Values(map[string]interface{}{"id": 1}). 6 | // Where(qb.Eq("id", 5)) 7 | func Update(table TableElem) UpdateStmt { 8 | return UpdateStmt{ 9 | table: table, 10 | values: map[string]interface{}{}, 11 | returning: []ColumnElem{}, 12 | } 13 | } 14 | 15 | // UpdateStmt is the base struct for any update statements 16 | type UpdateStmt struct { 17 | table TableElem 18 | values map[string]interface{} 19 | returning []ColumnElem 20 | where *WhereClause 21 | } 22 | 23 | // Accept implements Clause.Accept 24 | func (s UpdateStmt) Accept(context *CompilerContext) string { 25 | return context.Compiler.VisitUpdate(context, s) 26 | } 27 | 28 | // Build generates a statement out of UpdateStmt object 29 | func (s UpdateStmt) Build(dialect Dialect) *Stmt { 30 | context := NewCompilerContext(dialect) 31 | statement := Statement() 32 | statement.AddSQLClause(s.Accept(context)) 33 | statement.AddBinding(context.Binds...) 
34 | 35 | return statement 36 | } 37 | 38 | // Values accepts map[string]interface{} and forms the values map of insert statement 39 | func (s UpdateStmt) Values(values map[string]interface{}) UpdateStmt { 40 | for k, v := range values { 41 | s.values[s.table.C(k).Name] = v 42 | } 43 | return s 44 | } 45 | 46 | // Returning accepts the column names as strings and forms the returning array of insert statement 47 | // NOTE: Please use it in only postgres dialect, otherwise it'll crash 48 | func (s UpdateStmt) Returning(cols ...ColumnElem) UpdateStmt { 49 | for _, c := range cols { 50 | s.returning = append(s.returning, c) 51 | } 52 | return s 53 | } 54 | 55 | // Where adds a where clause to update statement and returns the update statement 56 | func (s UpdateStmt) Where(clause Clause) UpdateStmt { 57 | s.where = &WhereClause{clause} 58 | return s 59 | } 60 | -------------------------------------------------------------------------------- /update_test.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/suite" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | type UpdateTestSuite struct { 12 | suite.Suite 13 | dialect Dialect 14 | ctx *CompilerContext 15 | users TableElem 16 | } 17 | 18 | func (suite *UpdateTestSuite) SetupTest() { 19 | suite.dialect = NewDefaultDialect() 20 | suite.ctx = NewCompilerContext(suite.dialect) 21 | suite.users = Table( 22 | "users", 23 | Column("id", BigInt()).NotNull(), 24 | Column("email", Varchar()).NotNull().Unique(), 25 | PrimaryKey("email"), 26 | ) 27 | } 28 | 29 | func (suite *UpdateTestSuite) TestUpdateSimple() { 30 | 31 | sql := Update(suite.users). 
32 | Values(map[string]interface{}{ 33 | "email": "robert@de.niro", 34 | }).Accept(suite.ctx) 35 | 36 | binds := suite.ctx.Binds 37 | 38 | assert.Contains(suite.T(), sql, "UPDATE users") 39 | assert.Contains(suite.T(), sql, "SET email = ?") 40 | assert.Equal(suite.T(), []interface{}{"robert@de.niro"}, binds) 41 | } 42 | 43 | func (suite *UpdateTestSuite) TestUpdateWhereReturning() { 44 | sql := Update(suite.users). 45 | Values(map[string]interface{}{"email": "robert@de.niro"}). 46 | Where(Eq(suite.users.C("email"), "al@pacino")). 47 | Returning(suite.users.C("id"), suite.users.C("email")). 48 | Accept(suite.ctx) 49 | binds := suite.ctx.Binds 50 | 51 | assert.Contains(suite.T(), sql, "UPDATE users") 52 | assert.Contains(suite.T(), sql, "SET email = ?") 53 | assert.Contains(suite.T(), sql, "WHERE email = ?") 54 | assert.Contains(suite.T(), sql, "RETURNING id, email") 55 | assert.Equal(suite.T(), []interface{}{ 56 | "robert@de.niro", 57 | "al@pacino", 58 | }, binds) 59 | } 60 | 61 | func TestUpdateTestSuite(t *testing.T) { 62 | suite.Run(t, new(UpdateTestSuite)) 63 | } 64 | -------------------------------------------------------------------------------- /upsert.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | // Upsert generates an insert ... on (duplicate key/conflict) update statement 4 | func Upsert(table TableElem) UpsertStmt { 5 | return UpsertStmt{ 6 | Table: table, 7 | ValuesMap: map[string]interface{}{}, 8 | ReturningCols: []ColumnElem{}, 9 | } 10 | } 11 | 12 | // UpsertStmt is the base struct for any insert ... on conflict/duplicate key ... update ... 
statements 13 | type UpsertStmt struct { 14 | Table TableElem 15 | ValuesMap map[string]interface{} 16 | ReturningCols []ColumnElem 17 | } 18 | 19 | // Values accepts map[string]interface{} and forms the values map of insert statement 20 | func (s UpsertStmt) Values(values map[string]interface{}) UpsertStmt { 21 | for k, v := range values { 22 | s.ValuesMap[k] = v 23 | } 24 | return s 25 | } 26 | 27 | // Returning accepts the column names as strings and forms the returning array of insert statement 28 | // NOTE: Please use it in only postgres dialect, otherwise it'll crash 29 | func (s UpsertStmt) Returning(cols ...ColumnElem) UpsertStmt { 30 | for _, c := range cols { 31 | s.ReturningCols = append(s.ReturningCols, c) 32 | } 33 | return s 34 | } 35 | 36 | // Accept calls the compiler VisitUpsert function 37 | func (s UpsertStmt) Accept(context *CompilerContext) string { 38 | return context.Compiler.VisitUpsert(context, s) 39 | } 40 | 41 | // Build generates a statement out of UpdateStmt object 42 | func (s UpsertStmt) Build(dialect Dialect) *Stmt { 43 | context := NewCompilerContext(dialect) 44 | statement := Statement() 45 | statement.AddSQLClause(s.Accept(context)) 46 | statement.AddBinding(context.Binds...) 
47 | 48 | return statement 49 | } 50 | -------------------------------------------------------------------------------- /upsert_test.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestUpsert(t *testing.T) { 10 | def := NewDialect("default") 11 | 12 | users := Table( 13 | "users", 14 | Column("id", Varchar().Size(36)), 15 | Column("email", Varchar()).Unique(), 16 | Column("created_at", Timestamp()).NotNull(), 17 | PrimaryKey("id"), 18 | ) 19 | 20 | now := time.Now().UTC().String() 21 | 22 | ups := Upsert(users).Values(map[string]interface{}{ 23 | "id": "9883cf81-3b56-4151-ae4e-3903c5bc436d", 24 | "email": "al@pacino.com", 25 | "created_at": now, 26 | }) 27 | 28 | assert.Panics(t, func() { 29 | ups.Build(def) 30 | }) 31 | 32 | ups = ups.Returning(users.C("email")) 33 | assert.Equal(t, []ColumnElem{users.C("email")}, ups.ReturningCols) 34 | } 35 | -------------------------------------------------------------------------------- /where.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | // Where generates a compilable where clause 4 | func Where(clauses ...Clause) WhereClause { 5 | var clause Clause 6 | if len(clauses) == 1 { 7 | clause = clauses[0] 8 | } else { 9 | clause = And(clauses...) 10 | } 11 | return WhereClause{clause} 12 | } 13 | 14 | // WhereClause is the base of any where clause when using expression api 15 | type WhereClause struct { 16 | clause Clause 17 | } 18 | 19 | // Accept compiles the where clause, returns sql 20 | func (c WhereClause) Accept(context *CompilerContext) string { 21 | return context.Compiler.VisitWhere(context, c) 22 | } 23 | 24 | // And combine the current clause and the new ones with a And() 25 | func (c WhereClause) And(clauses ...Clause) WhereClause { 26 | clauses = append([]Clause{c.clause}, clauses...) 
27 | c.clause = And(clauses...) 28 | return c 29 | } 30 | 31 | // Or combine the current clause and the new ones with a Or() 32 | func (c WhereClause) Or(clauses ...Clause) WhereClause { 33 | clauses = append([]Clause{c.clause}, clauses...) 34 | c.clause = Or(clauses...) 35 | return c 36 | } 37 | -------------------------------------------------------------------------------- /where_test.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestWhere(t *testing.T) { 10 | ctx := NewCompilerContext(NewDefaultDialect()) 11 | assert.Equal(t, 12 | "WHERE X", 13 | Where(SQLText("X")).Accept(ctx)) 14 | assert.Equal(t, 15 | "WHERE (X AND Y)", 16 | Where(SQLText("X"), SQLText("Y")).Accept(ctx)) 17 | } 18 | 19 | func TestWhereAnd(t *testing.T) { 20 | ctx := NewCompilerContext(NewDefaultDialect()) 21 | assert.Equal(t, 22 | "WHERE (X AND Y)", 23 | Where(SQLText("X")).And(SQLText("Y")).Accept(ctx)) 24 | assert.Equal(t, 25 | "WHERE (X AND Y AND Z)", 26 | Where(SQLText("X")).And(SQLText("Y"), SQLText("Z")).Accept(ctx)) 27 | } 28 | 29 | func TestWhereOr(t *testing.T) { 30 | ctx := NewCompilerContext(NewDefaultDialect()) 31 | assert.Equal(t, 32 | "WHERE (X OR Y)", 33 | Where(SQLText("X")).Or(SQLText("Y")).Accept(ctx)) 34 | assert.Equal(t, 35 | "WHERE (X OR Y OR Z)", 36 | Where(SQLText("X")).Or(SQLText("Y"), SQLText("Z")).Accept(ctx)) 37 | } 38 | --------------------------------------------------------------------------------