├── .github └── workflow │ └── ci.yml ├── LICENSE ├── README.md ├── context.go ├── driver.go ├── driver_test.go ├── go.mod ├── go.sum ├── internal └── examples │ ├── ctxlevel │ ├── go.mod │ ├── go.sum │ ├── main.go │ └── main_test.go │ ├── multilevel │ ├── README.md │ ├── go.mod │ ├── go.sum │ └── main.go │ └── todo │ ├── README.md │ ├── ent.graphql │ ├── ent │ ├── client.go │ ├── config.go │ ├── context.go │ ├── ent.go │ ├── entc.go │ ├── enttest │ │ └── enttest.go │ ├── generate.go │ ├── gql_collection.go │ ├── gql_edge.go │ ├── gql_node.go │ ├── gql_pagination.go │ ├── gql_transaction.go │ ├── gql_where_input.go │ ├── hook │ │ └── hook.go │ ├── migrate │ │ ├── migrate.go │ │ └── schema.go │ ├── mutation.go │ ├── mutation_input.go │ ├── predicate │ │ └── predicate.go │ ├── runtime.go │ ├── runtime │ │ └── runtime.go │ ├── schema │ │ ├── todo.go │ │ └── user.go │ ├── template │ │ └── mutation_input.tmpl │ ├── todo.go │ ├── todo │ │ ├── todo.go │ │ └── where.go │ ├── todo_create.go │ ├── todo_delete.go │ ├── todo_query.go │ ├── todo_update.go │ ├── tx.go │ ├── user.go │ ├── user │ │ ├── user.go │ │ └── where.go │ ├── user_create.go │ ├── user_delete.go │ ├── user_query.go │ └── user_update.go │ ├── generate.go │ ├── generated.go │ ├── go.mod │ ├── go.sum │ ├── gqlgen.yml │ ├── resolver.go │ ├── todo.graphql │ └── todo.resolvers.go └── level.go /.github/workflow/ci.yml: -------------------------------------------------------------------------------- 1 | name: Continuous Integration 2 | on: [push, pull_request] 3 | 4 | jobs: 5 | lint: 6 | runs-on: ubuntu-latest 7 | strategy: 8 | matrix: 9 | goversion: [1.17, 1.16] 10 | steps: 11 | - uses: actions/checkout@v2 12 | - uses: actions/setup-go@v2 13 | with: 14 | go-version: ${{ matrix.goversion }} 15 | - name: Run linters 16 | uses: golangci/golangci-lint-action@v2.5.2 17 | with: 18 | version: v1.41.1 19 | args: --timeout 3m 20 | test: 21 | runs-on: ubuntu-latest 22 | strategy: 23 | matrix: 24 | goversion: [1.17, 1.16] 25 
| steps: 26 | - uses: actions/checkout@v2 27 | - uses: actions/setup-go@v2 28 | with: 29 | go-version: ${{ matrix.goversion }} 30 | - uses: actions/cache@v2 31 | with: 32 | path: | 33 | ~/go/pkg/mod 34 | ~/.cache/go-build 35 | key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} 36 | restore-keys: | 37 | ${{ runner.os }}-go- 38 | - name: Run tests 39 | run: go test -race ./... -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # entcache 2 | 3 | An experimental cache driver for [ent](https://github.com/ent/ent) with a variety of storage options, such as: 4 | 5 | 1. A `context.Context`-based cache. Usually, attached to an HTTP request. 6 | 7 | 2. A driver level cache embedded in the `ent.Client`. Used to share cache entries on the process level. 8 | 9 | 3. A remote cache. For example, a Redis database that provides a persistence layer for storing and sharing cache 10 | entries between multiple processes. 11 | 12 | 4. A cache hierarchy, or multi-level cache allows structuring the cache in a hierarchical way. For example, a 2-level cache 13 | that is composed of an LRU-cache in the application memory, and a remote-level cache backed by a Redis database. 14 | 15 | ## Quick Introduction 16 | 17 | First, `go get` the package using the following command.
18 | 19 | ```shell 20 | go get ariga.io/entcache 21 | ``` 22 | 23 | After installing `entcache`, you can easily add it to your project with the snippet below: 24 | 25 | ```go 26 | // Open the database connection. 27 | db, err := sql.Open(dialect.SQLite, "file:ent?mode=memory&cache=shared&_fk=1") 28 | if err != nil { 29 | log.Fatal("opening database", err) 30 | } 31 | // Decorates the sql.Driver with entcache.Driver. 32 | drv := entcache.NewDriver(db) 33 | // Create an ent.Client. 34 | client := ent.NewClient(ent.Driver(drv)) 35 | 36 | // Tell the entcache.Driver to skip the caching layer 37 | // when running the schema migration. 38 | if err := client.Schema.Create(entcache.Skip(ctx)); err != nil { 39 | log.Fatal("running schema migration", err) 40 | } 41 | 42 | // Run queries. 43 | if u, err := client.User.Get(ctx, id); err != nil { 44 | log.Fatal("querying user", err) 45 | } 46 | // The query below is cached. 47 | if u, err := client.User.Get(ctx, id); err != nil { 48 | log.Fatal("querying user", err) 49 | } 50 | ``` 51 | 52 | **However**, you need to choose the cache storage carefully before adding `entcache` to your project. 53 | The section below covers the different approaches provided by this package. 54 | 55 | 56 | ## High Level Design 57 | 58 | On a high level, `entcache.Driver` decorates the `Query` method of the given driver, and for each call, generates a cache 59 | key (i.e. hash) from its arguments (i.e. statement and parameters). After the query is executed, the driver records the 60 | raw values of the returned rows (`sql.Rows`), and stores them in the cache store with the generated cache key. This 61 | means that the recorded rows will be returned the next time the query is executed, if they were not evicted by the cache store. 62 | 63 | The package provides a variety of options to configure the TTL of the cache entries, control the hash function, provide 64 | custom and multi-level cache stores, evict and skip cache entries.
See the full documentation in 65 | [go.dev/entcache](https://pkg.go.dev/ariga.io/entcache). 66 | 67 | ### Caching Levels 68 | 69 | `entcache` provides several builtin cache levels: 70 | 71 | 1. A `context.Context`-based cache. Usually, attached to a request and does not work with other cache levels. 72 | It is used to eliminate duplicate queries that are executed by the same request. 73 | 74 | 2. A driver-level cache used by the `ent.Client`. An application usually creates a driver per database, 75 | and therefore, we treat it as a process-level cache. 76 | 77 | 3. A remote cache. For example, a Redis database that provides a persistence layer for storing and sharing cache 78 | entries between multiple processes. A remote cache layer is resistant to application deployment changes or failures, 79 | and allows reducing the number of identical queries executed on the database by different processes. 80 | 81 | 4. A cache hierarchy, or multi-level cache allows structuring the cache in a hierarchical way. The hierarchy of cache 82 | stores is mostly based on access speeds and cache sizes. For example, a 2-level cache that is composed of an LRU-cache 83 | in the application memory, and a remote-level cache backed by a Redis database. 84 | 85 | #### Context Level Cache 86 | 87 | The `ContextLevel` option configures the driver to work with a `context.Context` level cache. The context is usually 88 | attached to a request (e.g. `*http.Request`) and is not available in multi-level mode. When this option is used as 89 | a cache store, the attached `context.Context` carries an LRU cache (can be configured differently), and the driver 90 | stores and searches entries in the LRU cache when queries are executed. 91 | 92 | This option is ideal for applications that require strong consistency, but still want to avoid executing duplicate 93 | database queries on the same request. For example, given the following GraphQL query: 94 | 95 | ```graphql 96 | query($ids: [ID!]!)
{ 97 | nodes(ids: $ids) { 98 | ... on User { 99 | id 100 | name 101 | todos { 102 | id 103 | owner { 104 | id 105 | name 106 | } 107 | } 108 | } 109 | } 110 | } 111 | ``` 112 | 113 | A naive solution for resolving the above query will execute 1 query for getting N users, another N queries for getting 114 | the todos of each user, and a query for each todo item for getting its owner (read more about the 115 | [_N+1 Problem_](https://entgo.io/docs/tutorial-todo-gql-field-collection/#problem)). 116 | 117 | However, Ent provides a unique approach for resolving such queries (read more in the 118 | [Ent website](https://entgo.io/docs/tutorial-todo-gql-field-collection)) and therefore, only 3 queries will be executed 119 | in this case. 1 for getting N users, 1 for getting the todo items of **all** users, and 1 query for getting the owners 120 | of **all** todo items. 121 | 122 | With `entcache`, the number of queries may be reduced to 2, as the first and last queries are identical (see 123 | [code example](internal/examples/ctxlevel/main_test.go)). 124 | 125 | ![context-level-cache](https://github.com/ariga/entcache/blob/assets/internal/assets/ctxlevel.png) 126 | 127 | ##### Usage In GraphQL 128 | 129 | In order to instantiate an `entcache.Driver` in a `ContextLevel` mode and use it in the generated `ent.Client` use the 130 | following configuration. 131 | 132 | ```go 133 | db, err := sql.Open(dialect.SQLite, "file:ent?mode=memory&cache=shared&_fk=1") 134 | if err != nil { 135 | log.Fatal("opening database", err) 136 | } 137 | drv := entcache.NewDriver(db, entcache.ContextLevel()) 138 | client := ent.NewClient(ent.Driver(drv)) 139 | ``` 140 | 141 | Then, when a GraphQL query hits the server, we wrap the request `context.Context` with an `entcache.NewContext`.
142 | 143 | ```go 144 | srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { 145 | if op := graphql.GetOperationContext(ctx).Operation; op != nil && op.Operation == ast.Query { 146 | ctx = entcache.NewContext(ctx) 147 | } 148 | return next(ctx) 149 | }) 150 | ``` 151 | 152 | That's it! Your server is ready to use `entcache` with GraphQL, and a full server example exists in 153 | [examples/ctxlevel](internal/examples/ctxlevel). 154 | 155 | ##### Middleware Example 156 | 157 | An example of using the common middleware pattern in Go for wrapping the request `context.Context` with 158 | an `entcache.NewContext` in case of `GET` requests. 159 | 160 | ```go 161 | srv.Use(func(next http.Handler) http.Handler { 162 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 163 | if r.Method == http.MethodGet { 164 | r = r.WithContext(entcache.NewContext(r.Context())) 165 | } 166 | next.ServeHTTP(w, r) 167 | }) 168 | }) 169 | ``` 170 | 171 | #### Driver Level Cache 172 | 173 | A driver-based level cache stores the cache entries on the `ent.Client`. An application usually creates a driver per 174 | database (i.e. `sql.DB`), and therefore, we treat it as a process-level cache. The default cache storage for this option 175 | is an LRU cache with no limit and no TTL for its entries, but can be configured differently. 176 | 177 | ![driver-level-cache](https://github.com/ariga/entcache/blob/assets/internal/assets/drvlevel.png) 178 | 179 | ##### Create a default cache driver, with no limit and no TTL. 180 | 181 | ```go 182 | db, err := sql.Open(dialect.SQLite, "file:ent?mode=memory&cache=shared&_fk=1") 183 | if err != nil { 184 | log.Fatal("opening database", err) 185 | } 186 | drv := entcache.NewDriver(db) 187 | client := ent.NewClient(ent.Driver(drv)) 188 | ``` 189 | 190 | ##### Set the TTL to 1s.
191 | 192 | ```go 193 | drv := entcache.NewDriver(drv, entcache.TTL(time.Second)) 194 | client := ent.NewClient(ent.Driver(drv)) 195 | ``` 196 | 197 | ##### Limit the cache to 128 entries and set the TTL to 1s. 198 | 199 | ```go 200 | drv := entcache.NewDriver( 201 | drv, 202 | entcache.TTL(time.Second), 203 | entcache.Levels(entcache.NewLRU(128)), 204 | ) 205 | client := ent.NewClient(ent.Driver(drv)) 206 | ``` 207 | 208 | #### Remote Level Cache 209 | 210 | A remote-based level cache is used to share cached entries between multiple processes. For example, a Redis database. 211 | A remote cache layer is resistant to application deployment changes or failures, and allows reducing the number of 212 | identical queries executed on the database by different processes. This option plays nicely with the multi-level option below. 213 | 214 | #### Multi Level Cache 215 | 216 | A cache hierarchy, or multi-level cache allows structuring the cache in a hierarchical way. The hierarchy of cache 217 | stores is mostly based on access speeds and cache sizes. For example, a 2-level cache that is composed of an LRU-cache 218 | in the application memory, and a remote-level cache backed by a Redis database. 219 | 220 | ![multi-level-cache](https://github.com/ariga/entcache/blob/assets/internal/assets/multilevel.png) 221 | 222 | ```go 223 | rdb := redis.NewClient(&redis.Options{ 224 | Addr: ":6379", 225 | }) 226 | if err := rdb.Ping(ctx).Err(); err != nil { 227 | log.Fatal(err) 228 | } 229 | drv := entcache.NewDriver( 230 | drv, 231 | entcache.TTL(time.Second), 232 | entcache.Levels( 233 | entcache.NewLRU(256), 234 | entcache.NewRedis(rdb), 235 | ), 236 | ) 237 | client := ent.NewClient(ent.Driver(drv)) 238 | ``` 239 | 240 | ### Future Work 241 | 242 | There are a few features we are working on, and wish to work on, but need help from the community to design them 243 | properly.
If you are interested in one of the tasks or features below, do not hesitate to open an issue, or start a 244 | discussion on GitHub or in [Ent Slack channel](https://entgo.io/docs/slack). 245 | 246 | 1. Add a Memcache implementation for a remote-level cache. 247 | 2. Support for smart eviction mechanism based on SQL parsing. 248 | -------------------------------------------------------------------------------- /context.go: -------------------------------------------------------------------------------- 1 | package entcache 2 | 3 | import ( 4 | "context" 5 | "time" 6 | ) 7 | 8 | type ctxKey struct{} 9 | 10 | // NewContext returns a new Context that carries a cache. 11 | func NewContext(ctx context.Context, levels ...AddGetDeleter) context.Context { 12 | var cache AddGetDeleter 13 | switch len(levels) { 14 | case 0: 15 | cache = NewLRU(0) 16 | case 1: 17 | cache = levels[0] 18 | default: 19 | cache = &multiLevel{levels: levels} 20 | } 21 | return context.WithValue(ctx, ctxKey{}, cache) 22 | } 23 | 24 | // FromContext returns the cache value stored in ctx, if any. 25 | func FromContext(ctx context.Context) (AddGetDeleter, bool) { 26 | c, ok := ctx.Value(ctxKey{}).(AddGetDeleter) 27 | return c, ok 28 | } 29 | 30 | // ctxOptions allows injecting runtime options. 31 | type ctxOptions struct { 32 | skip bool // i.e. skip entry. 33 | evict bool // i.e. skip and invalidate entry. 34 | key Key // entry key. 35 | ttl time.Duration // entry duration. 36 | } 37 | 38 | var ctxOptionsKey ctxOptions 39 | 40 | // Skip returns a new Context that tells the Driver 41 | // to skip the cache entry on Query. 
42 | // 43 | // client.T.Query().All(entcache.Skip(ctx)) 44 | // 45 | func Skip(ctx context.Context) context.Context { 46 | c, ok := ctx.Value(ctxOptionsKey).(*ctxOptions) 47 | if !ok { 48 | return context.WithValue(ctx, ctxOptionsKey, &ctxOptions{skip: true}) 49 | } 50 | c.skip = true 51 | return ctx 52 | } 53 | 54 | // Evict returns a new Context that tells the Driver 55 | // to skip and invalidate the cache entry on Query. 56 | // 57 | // client.T.Query().All(entcache.Evict(ctx)) 58 | // 59 | func Evict(ctx context.Context) context.Context { 60 | c, ok := ctx.Value(ctxOptionsKey).(*ctxOptions) 61 | if !ok { 62 | return context.WithValue(ctx, ctxOptionsKey, &ctxOptions{skip: true, evict: true}) 63 | } 64 | c.skip = true 65 | c.evict = true 66 | return ctx 67 | } 68 | 69 | // WithKey returns a new Context that carries the Key for the cache entry. 70 | // Note that, this option should not be used if the ent.Client query involves 71 | // more than 1 SQL query (e.g. eager loading). 72 | // 73 | // client.T.Query().All(entcache.WithKey(ctx, "key")) 74 | // 75 | func WithKey(ctx context.Context, key Key) context.Context { 76 | c, ok := ctx.Value(ctxOptionsKey).(*ctxOptions) 77 | if !ok { 78 | return context.WithValue(ctx, ctxOptionsKey, &ctxOptions{key: key}) 79 | } 80 | c.key = key 81 | return ctx 82 | } 83 | 84 | // WithTTL returns a new Context that carries the TTL for the cache entry. 
85 | // 86 | // client.T.Query().All(entcache.WithTTL(ctx, time.Second)) 87 | // 88 | func WithTTL(ctx context.Context, ttl time.Duration) context.Context { 89 | c, ok := ctx.Value(ctxOptionsKey).(*ctxOptions) 90 | if !ok { 91 | return context.WithValue(ctx, ctxOptionsKey, &ctxOptions{ttl: ttl}) 92 | } 93 | c.ttl = ttl 94 | return ctx 95 | } 96 | -------------------------------------------------------------------------------- /driver.go: -------------------------------------------------------------------------------- 1 | package entcache 2 | 3 | import ( 4 | "context" 5 | stdsql "database/sql" 6 | "database/sql/driver" 7 | "errors" 8 | "fmt" 9 | "strings" 10 | "sync/atomic" 11 | "time" 12 | _ "unsafe" 13 | 14 | "entgo.io/ent/dialect" 15 | "entgo.io/ent/dialect/sql" 16 | "github.com/mitchellh/hashstructure/v2" 17 | ) 18 | 19 | type ( 20 | // Options wraps the basic configuration cache options. 21 | Options struct { 22 | // TTL defines the period of time that an Entry 23 | // is valid in the cache. 24 | TTL time.Duration 25 | 26 | // Cache defines the AddGetDeleter (cache implementation) 27 | // for holding the cache entries. If no cache implementation 28 | // was provided, an LRU cache with no limit is used. 29 | Cache AddGetDeleter 30 | 31 | // Hash defines an optional Hash function for converting 32 | // a query and its arguments to a cache key. If no Hash 33 | // function was provided, the DefaultHash is used. 34 | Hash func(query string, args []any) (Key, error) 35 | 36 | // Log function. If provided, the Driver will call it with 37 | // errors that cannot be handled. 38 | Log func(...any) 39 | } 40 | 41 | // Option allows configuring the cache 42 | // driver using functional options. 43 | Option func(*Options) 44 | 45 | // A Driver is an SQL cached client. Users should use the 46 | // constructor below for creating a new driver.
47 | Driver struct { 48 | dialect.Driver 49 | *Options 50 | stats Stats 51 | } 52 | ) 53 | 54 | // NewDriver returns a new Driver from an existing driver and optional 55 | // configuration functions. For example: 56 | // 57 | // entcache.NewDriver( 58 | // drv, 59 | // entcache.TTL(time.Minute), 60 | // entcache.Levels( 61 | // NewLRU(256), 62 | // NewRedis(redis.NewClient(&redis.Options{ 63 | // Addr: ":6379", 64 | // })), 65 | // ), 66 | // ) 67 | func NewDriver(drv dialect.Driver, opts ...Option) *Driver { 68 | options := &Options{Hash: DefaultHash, Cache: NewLRU(0)} 69 | for _, opt := range opts { 70 | opt(options) 71 | } 72 | return &Driver{ 73 | Driver: drv, 74 | Options: options, 75 | } 76 | } 77 | 78 | // TTL configures the period of time that an Entry 79 | // is valid in the cache. 80 | func TTL(ttl time.Duration) Option { 81 | return func(o *Options) { 82 | o.TTL = ttl 83 | } 84 | } 85 | 86 | // Hash configures an optional Hash function for 87 | // converting a query and its arguments to a cache key. 88 | func Hash(hash func(query string, args []any) (Key, error)) Option { 89 | return func(o *Options) { 90 | o.Hash = hash 91 | } 92 | } 93 | 94 | // Levels configures the Driver to work with the given cache levels. 95 | // For example, an in-process LRU cache and a remote Redis cache. 96 | func Levels(levels ...AddGetDeleter) Option { 97 | return func(o *Options) { 98 | if len(levels) == 1 { 99 | o.Cache = levels[0] 100 | } else { 101 | o.Cache = &multiLevel{levels: levels} 102 | } 103 | } 104 | } 105 | 106 | // ContextLevel configures the driver to work with context/request level cache.
107 | // Users that use this option should wrap the *http.Request context with the 108 | // cache value as follows: 109 | // 110 | // ctx = entcache.NewContext(ctx) 111 | // 112 | // ctx = entcache.NewContext(ctx, entcache.NewLRU(128)) 113 | func ContextLevel() Option { 114 | return func(o *Options) { 115 | o.Cache = &contextLevel{} 116 | } 117 | } 118 | 119 | // Query implements the Querier interface for the driver. It falls back to the 120 | // underlying wrapped driver in case of caching error. 121 | // 122 | // Note that, the driver does not synchronize identical queries that are executed 123 | // concurrently. Hence, if 2 identical queries are executed at the ~same time, and 124 | // there is no cache entry for them, the driver will execute both of them and the 125 | // last successful one will be stored in the cache. 126 | func (d *Driver) Query(ctx context.Context, query string, args, v any) error { 127 | // Check if the given statement looks like a standard Ent query (e.g. SELECT). 128 | // Custom queries (e.g. CTE) or statements that are prefixed with comments are 129 | // not supported. This check is mainly necessary, because PostgreSQL and SQLite 130 | // may execute insert statement like "INSERT ... RETURNING" using Driver.Query. 131 | if !strings.HasPrefix(query, "SELECT") && !strings.HasPrefix(query, "select") { 132 | return d.Driver.Query(ctx, query, args, v) 133 | } 134 | vr, ok := v.(*sql.Rows) 135 | if !ok { 136 | return fmt.Errorf("entcache: invalid type %T. expect *sql.Rows", v) 137 | } 138 | argv, ok := args.([]any) 139 | if !ok { 140 | return fmt.Errorf("entcache: invalid type %T.
expect []interface{} for args", args) 141 | } 142 | opts, err := d.optionsFromContext(ctx, query, argv) 143 | if err != nil { 144 | return d.Driver.Query(ctx, query, args, v) 145 | } 146 | atomic.AddUint64(&d.stats.Gets, 1) 147 | switch e, err := d.Cache.Get(ctx, opts.key); { 148 | case err == nil: 149 | atomic.AddUint64(&d.stats.Hits, 1) 150 | vr.ColumnScanner = &repeater{columns: e.Columns, values: e.Values} 151 | case errors.Is(err, ErrNotFound): 152 | if err := d.Driver.Query(ctx, query, args, vr); err != nil { 153 | return err 154 | } 155 | vr.ColumnScanner = &recorder{ 156 | ColumnScanner: vr.ColumnScanner, 157 | onClose: func(columns []string, values [][]driver.Value) { 158 | err := d.Cache.Add(ctx, opts.key, &Entry{Columns: columns, Values: values}, opts.ttl) 159 | if err != nil && d.Log != nil { 160 | atomic.AddUint64(&d.stats.Errors, 1) 161 | d.Log(fmt.Sprintf("entcache: failed storing entry %v in cache: %v", opts.key, err)) 162 | } 163 | }, 164 | } 165 | default: 166 | return d.Driver.Query(ctx, query, args, v) 167 | } 168 | return nil 169 | } 170 | 171 | // Stats returns a copy of the cache statistics. 172 | func (d *Driver) Stats() Stats { 173 | return Stats{ 174 | Gets: atomic.LoadUint64(&d.stats.Gets), 175 | Hits: atomic.LoadUint64(&d.stats.Hits), 176 | Errors: atomic.LoadUint64(&d.stats.Errors), 177 | } 178 | } 179 | 180 | // QueryContext calls QueryContext of the underlying driver, or fails if it is not supported. 181 | // Note, this method is not part of the caching layer since Ent does not use it by default. 182 | func (d *Driver) QueryContext(ctx context.Context, query string, args ...any) (*stdsql.Rows, error) { 183 | drv, ok := d.Driver.(interface { 184 | QueryContext(context.Context, string, ...any) (*stdsql.Rows, error) 185 | }) 186 | if !ok { 187 | return nil, fmt.Errorf("Driver.QueryContext is not supported") 188 | } 189 | return drv.QueryContext(ctx, query, args...)
190 | } 191 | 192 | // ExecContext calls ExecContext of the underlying driver, or fails if it is not supported. 193 | func (d *Driver) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) { 194 | drv, ok := d.Driver.(interface { 195 | ExecContext(context.Context, string, ...any) (stdsql.Result, error) 196 | }) 197 | if !ok { 198 | return nil, fmt.Errorf("Driver.ExecContext is not supported") 199 | } 200 | return drv.ExecContext(ctx, query, args...) 201 | } 202 | 203 | // errSkip tells the driver to skip the cache layer. 204 | var errSkip = errors.New("entcache: skip cache") 205 | 206 | // optionsFromContext returns the options injected into the context, completed with the driver defaults (d.Hash for the key, d.TTL for the TTL). It returns errSkip when hashing fails or the skip option is set. 207 | func (d *Driver) optionsFromContext(ctx context.Context, query string, args []any) (ctxOptions, error) { 208 | var opts ctxOptions 209 | if c, ok := ctx.Value(ctxOptionsKey).(*ctxOptions); ok { 210 | opts = *c // Work on a copy so the defaults filled in below do not modify the value stored in the context. 211 | } 212 | if opts.key == nil { // No explicit key: derive one from the query and its arguments. 213 | key, err := d.Hash(query, args) 214 | if err != nil { 215 | return opts, errSkip // The hash error itself is dropped; the caller only learns that the cache must be skipped. 216 | } 217 | opts.key = key 218 | } 219 | if opts.ttl == 0 { 220 | opts.ttl = d.TTL 221 | } 222 | if opts.evict { // Eviction runs before the skip check, so evict and skip can be combined. 223 | if err := d.Cache.Del(ctx, opts.key); err != nil { 224 | return opts, err 225 | } 226 | } 227 | if opts.skip { 228 | return opts, errSkip 229 | } 230 | return opts, nil 231 | } 232 | 233 | // DefaultHash provides the default implementation for converting 234 | // a query and its arguments to a cache key. 235 | func DefaultHash(query string, args []any) (Key, error) { 236 | key, err := hashstructure.Hash(struct { 237 | Q string 238 | A []any 239 | }{ 240 | Q: query, 241 | A: args, 242 | }, hashstructure.FormatV2, nil) 243 | if err != nil { 244 | return nil, err 245 | } 246 | return key, nil 247 | } 248 | 249 | // Stats represents the cache statistics of the driver. 
250 | type Stats struct { 251 | Gets uint64 // Number of cache lookups performed by Driver.Query. 252 | Hits uint64 // Number of lookups that were served from the cache. 253 | Errors uint64 // Number of failures storing an entry in the cache; counted only when a Log function is configured. 254 | } 255 | 256 | // rawCopy copies the driver values by implementing 257 | // the sql.Scanner interface. 258 | type rawCopy struct { 259 | values []driver.Value 260 | } 261 | 262 | func (c *rawCopy) Scan(src interface{}) error { 263 | if b, ok := src.([]byte); ok { // Byte slices are copied so the stored value does not alias the source buffer. 264 | b1 := make([]byte, len(b)) 265 | copy(b1, b) 266 | src = b1 267 | } 268 | c.values[0] = src 269 | c.values = c.values[1:] // Advance to the next slot; the following Scan call fills it. 270 | return nil 271 | } 272 | 273 | // recorder represents an sql.Rows recorder that implements 274 | // the entgo.io/ent/dialect/sql.ColumnScanner interface. 275 | type recorder struct { 276 | sql.ColumnScanner 277 | values [][]driver.Value // Rows recorded by Scan, in scan order. 278 | columns []string // Column names recorded by Columns. 279 | done bool // Set when the last Next call returned false. 280 | onClose func([]string, [][]driver.Value) // Receives the recorded columns and rows from Close. 281 | } 282 | 283 | // Next wraps the underlying Next method and records whether the rows were exhausted. 284 | func (r *recorder) Next() bool { 285 | hasNext := r.ColumnScanner.Next() 286 | r.done = !hasNext 287 | return hasNext 288 | } 289 | 290 | // Scan copies database values for future use (by the repeater) 291 | // and assigns them to the given destinations using the standard 292 | // database/sql.convertAssign function. 293 | func (r *recorder) Scan(dest ...any) error { 294 | values := make([]driver.Value, len(dest)) 295 | args := make([]any, len(dest)) 296 | c := &rawCopy{values: values} 297 | for i := range args { 298 | args[i] = c // Every destination is the same rawCopy; each of its Scan calls fills the next slot of values. 299 | } 300 | if err := r.ColumnScanner.Scan(args...); err != nil { 301 | return err 302 | } 303 | for i := range values { 304 | if err := convertAssign(dest[i], values[i]); err != nil { 305 | return err 306 | } 307 | } 308 | r.values = append(r.values, values) 309 | return nil 310 | } 311 | 312 | // Columns wraps the underlying Columns method and stores the result in the recorder state. 313 | // The cached entry only has columns for the repeater if this method was called while recording. 314 | // That means, raw scanning should be identical for identical queries. 
315 | func (r *recorder) Columns() ([]string, error) { 316 | columns, err := r.ColumnScanner.Columns() 317 | if err != nil { 318 | return nil, err 319 | } 320 | r.columns = columns 321 | return columns, nil 322 | } 323 | 324 | func (r *recorder) Close() error { 325 | if err := r.ColumnScanner.Close(); err != nil { 326 | return err 327 | } 328 | // Store the result in the cache only if no error was encountered during iteration 329 | // AND all rows were scanned; otherwise a failed or partial result set would be cached. 330 | if err := r.ColumnScanner.Err(); err == nil && r.done { 331 | r.onClose(r.columns, r.values) 332 | } 333 | return nil 334 | } 335 | 336 | // repeater repeats columns scanning from cache history. 337 | type repeater struct { 338 | columns []string 339 | values [][]driver.Value 340 | } 341 | 342 | func (*repeater) Close() error { 343 | return nil 344 | } 345 | func (*repeater) ColumnTypes() ([]*stdsql.ColumnType, error) { 346 | return nil, fmt.Errorf("entcache.ColumnTypes is not supported") 347 | } 348 | func (r *repeater) Columns() ([]string, error) { 349 | return r.columns, nil 350 | } 351 | func (*repeater) Err() error { 352 | return nil 353 | } 354 | func (r *repeater) Next() bool { 355 | return len(r.values) > 0 356 | } 357 | func (r *repeater) NextResultSet() bool { 358 | return len(r.values) > 0 359 | } 360 | 361 | func (r *repeater) Scan(dest ...any) error { 362 | if !r.Next() { 363 | return stdsql.ErrNoRows 364 | } 365 | for i, src := range r.values[0] { 366 | if err := convertAssign(dest[i], src); err != nil { 367 | return err 368 | } 369 | } 370 | r.values = r.values[1:] // Consume the row that was just replayed. 371 | return nil 372 | } 373 | 374 | //go:linkname convertAssign database/sql.convertAssign 375 | func convertAssign(dest, src any) error 376 | -------------------------------------------------------------------------------- /driver_test.go: -------------------------------------------------------------------------------- 1 | package entcache_test 2 | 3 | import ( 4 | "context" 5 | "database/sql/driver" 6 | "testing" 7 | "time" 8 | 9 | 
"ariga.io/entcache" 10 | 11 | "entgo.io/ent/dialect" 12 | "entgo.io/ent/dialect/sql" 13 | "github.com/DATA-DOG/go-sqlmock" 14 | "github.com/go-redis/redismock/v9" 15 | ) 16 | 17 | func TestDriver_ContextLevel(t *testing.T) { 18 | db, mock, err := sqlmock.New() 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | drv := sql.OpenDB(dialect.MySQL, db) 23 | 24 | t.Run("One", func(t *testing.T) { 25 | drv := entcache.NewDriver(drv, entcache.ContextLevel()) 26 | mock.ExpectQuery("SELECT id FROM users"). 27 | WillReturnRows( 28 | sqlmock.NewRows([]string{"id"}). 29 | AddRow(1). 30 | AddRow(2). 31 | AddRow(3), 32 | ) 33 | ctx := entcache.NewContext(context.Background()) 34 | expectQuery(ctx, t, drv, "SELECT id FROM users", []interface{}{int64(1), int64(2), int64(3)}) 35 | expectQuery(ctx, t, drv, "SELECT id FROM users", []interface{}{int64(1), int64(2), int64(3)}) 36 | if err := mock.ExpectationsWereMet(); err != nil { 37 | t.Fatal(err) 38 | } 39 | }) 40 | 41 | t.Run("Multi", func(t *testing.T) { 42 | drv := entcache.NewDriver(drv, entcache.ContextLevel()) 43 | mock.ExpectQuery("SELECT name FROM users"). 44 | WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("a8m")) 45 | ctx1 := entcache.NewContext(context.Background()) 46 | expectQuery(ctx1, t, drv, "SELECT name FROM users", []interface{}{"a8m"}) 47 | ctx2 := entcache.NewContext(context.Background()) 48 | mock.ExpectQuery("SELECT name FROM users"). 49 | WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("a8m")) 50 | expectQuery(ctx2, t, drv, "SELECT name FROM users", []interface{}{"a8m"}) 51 | if err := mock.ExpectationsWereMet(); err != nil { 52 | t.Fatal(err) 53 | } 54 | }) 55 | 56 | t.Run("TTL", func(t *testing.T) { 57 | drv := entcache.NewDriver(drv, entcache.ContextLevel(), entcache.TTL(-1)) 58 | mock.ExpectQuery("SELECT name FROM users"). 59 | WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("a8m")) 60 | mock.ExpectQuery("SELECT name FROM users"). 
61 | WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("a8m")) 62 | ctx := entcache.NewContext(context.Background()) 63 | expectQuery(ctx, t, drv, "SELECT name FROM users", []interface{}{"a8m"}) 64 | expectQuery(ctx, t, drv, "SELECT name FROM users", []interface{}{"a8m"}) 65 | if err := mock.ExpectationsWereMet(); err != nil { 66 | t.Fatal(err) 67 | } 68 | }) 69 | } 70 | 71 | func TestDriver_Levels(t *testing.T) { 72 | db, mock, err := sqlmock.New() 73 | if err != nil { 74 | t.Fatal(err) 75 | } 76 | drv := sql.OpenDB(dialect.Postgres, db) 77 | 78 | t.Run("One", func(t *testing.T) { 79 | drv := entcache.NewDriver(drv, entcache.TTL(time.Second)) 80 | mock.ExpectQuery("SELECT age FROM users"). 81 | WillReturnRows( 82 | sqlmock.NewRows([]string{"age"}). 83 | AddRow(20.1). 84 | AddRow(30.2). 85 | AddRow(40.5), 86 | ) 87 | expectQuery(context.Background(), t, drv, "SELECT age FROM users", []interface{}{20.1, 30.2, 40.5}) 88 | expectQuery(context.Background(), t, drv, "SELECT age FROM users", []interface{}{20.1, 30.2, 40.5}) 89 | if err := mock.ExpectationsWereMet(); err != nil { 90 | t.Fatal(err) 91 | } 92 | }) 93 | 94 | t.Run("Multi", func(t *testing.T) { 95 | drv := entcache.NewDriver( 96 | drv, 97 | entcache.Levels( 98 | entcache.NewLRU(-1), // Nop. 99 | entcache.NewLRU(0), // No limit. 100 | ), 101 | ) 102 | mock.ExpectQuery("SELECT age FROM users"). 103 | WillReturnRows( 104 | sqlmock.NewRows([]string{"age"}). 105 | AddRow(20.1). 106 | AddRow(30.2). 
107 | AddRow(40.5), 108 | ) 109 | expectQuery(context.Background(), t, drv, "SELECT age FROM users", []interface{}{20.1, 30.2, 40.5}) 110 | expectQuery(context.Background(), t, drv, "SELECT age FROM users", []interface{}{20.1, 30.2, 40.5}) 111 | if err := mock.ExpectationsWereMet(); err != nil { 112 | t.Fatal(err) 113 | } 114 | }) 115 | 116 | t.Run("Redis", func(t *testing.T) { 117 | var ( 118 | rdb, rmock = redismock.NewClientMock() 119 | drv = entcache.NewDriver( 120 | drv, 121 | entcache.Levels( 122 | entcache.NewLRU(-1), 123 | entcache.NewRedis(rdb), 124 | ), 125 | entcache.Hash(func(string, []interface{}) (entcache.Key, error) { 126 | return 1, nil 127 | }), 128 | ) 129 | ) 130 | mock.ExpectQuery("SELECT active FROM users"). 131 | WillReturnRows(sqlmock.NewRows([]string{"active"}).AddRow(true).AddRow(false)) 132 | rmock.ExpectGet("1").RedisNil() 133 | buf, _ := entcache.Entry{Values: [][]driver.Value{{true}, {false}}}.MarshalBinary() 134 | rmock.ExpectSet("1", buf, 0).RedisNil() 135 | expectQuery(context.Background(), t, drv, "SELECT active FROM users", []interface{}{true, false}) 136 | rmock.ExpectGet("1").SetVal(string(buf)) 137 | expectQuery(context.Background(), t, drv, "SELECT active FROM users", []interface{}{true, false}) 138 | if err := rmock.ExpectationsWereMet(); err != nil { 139 | t.Fatal(err) 140 | } 141 | expected := entcache.Stats{Gets: 2, Hits: 1} 142 | if s := drv.Stats(); s != expected { 143 | t.Errorf("unexpected stats: %v != %v", s, expected) 144 | } 145 | }) 146 | } 147 | 148 | func TestDriver_ContextOptions(t *testing.T) { 149 | db, mock, err := sqlmock.New() 150 | if err != nil { 151 | t.Fatal(err) 152 | } 153 | drv := sql.OpenDB(dialect.MySQL, db) 154 | 155 | t.Run("Skip", func(t *testing.T) { 156 | drv := entcache.NewDriver(drv) 157 | mock.ExpectQuery("SELECT name FROM users"). 
158 | WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("a8m")) 159 | ctx := context.Background() 160 | expectQuery(ctx, t, drv, "SELECT name FROM users", []interface{}{"a8m"}) 161 | expectQuery(ctx, t, drv, "SELECT name FROM users", []interface{}{"a8m"}) 162 | mock.ExpectQuery("SELECT name FROM users"). 163 | WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("a8m")) 164 | skipCtx := entcache.Skip(ctx) 165 | expectQuery(skipCtx, t, drv, "SELECT name FROM users", []interface{}{"a8m"}) 166 | expectQuery(ctx, t, drv, "SELECT name FROM users", []interface{}{"a8m"}) 167 | if err := mock.ExpectationsWereMet(); err != nil { 168 | t.Fatal(err) 169 | } 170 | }) 171 | 172 | t.Run("Evict", func(t *testing.T) { 173 | drv := entcache.NewDriver(drv) 174 | mock.ExpectQuery("SELECT name FROM users"). 175 | WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("a8m")) 176 | ctx := context.Background() 177 | expectQuery(ctx, t, drv, "SELECT name FROM users", []interface{}{"a8m"}) 178 | expectQuery(ctx, t, drv, "SELECT name FROM users", []interface{}{"a8m"}) 179 | mock.ExpectQuery("SELECT name FROM users"). 180 | WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("a8m")) 181 | evictCtx := entcache.Evict(ctx) 182 | expectQuery(evictCtx, t, drv, "SELECT name FROM users", []interface{}{"a8m"}) 183 | mock.ExpectQuery("SELECT name FROM users"). 184 | WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("a8m")) 185 | expectQuery(ctx, t, drv, "SELECT name FROM users", []interface{}{"a8m"}) 186 | if err := mock.ExpectationsWereMet(); err != nil { 187 | t.Fatal(err) 188 | } 189 | }) 190 | 191 | t.Run("WithTTL", func(t *testing.T) { 192 | drv := entcache.NewDriver(drv) 193 | mock.ExpectQuery("SELECT name FROM users"). 
194 | WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("a8m")) 195 | ttlCtx := entcache.WithTTL(context.Background(), -1) 196 | expectQuery(ttlCtx, t, drv, "SELECT name FROM users", []interface{}{"a8m"}) 197 | mock.ExpectQuery("SELECT name FROM users"). 198 | WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("a8m")) 199 | expectQuery(ttlCtx, t, drv, "SELECT name FROM users", []interface{}{"a8m"}) 200 | if err := mock.ExpectationsWereMet(); err != nil { 201 | t.Fatal(err) 202 | } 203 | }) 204 | 205 | t.Run("WithKey", func(t *testing.T) { 206 | drv := entcache.NewDriver(drv) 207 | mock.ExpectQuery("SELECT name FROM users"). 208 | WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("a8m")) 209 | ctx := context.Background() 210 | keyCtx := entcache.WithKey(ctx, "cache-key") 211 | expectQuery(keyCtx, t, drv, "SELECT name FROM users", []interface{}{"a8m"}) 212 | expectQuery(keyCtx, t, drv, "SELECT name FROM users", []interface{}{"a8m"}) 213 | mock.ExpectQuery("SELECT name FROM users"). 214 | WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("a8m")) 215 | expectQuery(ctx, t, drv, "SELECT name FROM users", []interface{}{"a8m"}) 216 | if err := drv.Cache.Del(ctx, "cache-key"); err != nil { 217 | t.Fatal(err) 218 | } 219 | mock.ExpectQuery("SELECT name FROM users"). 
220 | WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("a8m")) 221 | expectQuery(keyCtx, t, drv, "SELECT name FROM users", []interface{}{"a8m"}) 222 | if err := mock.ExpectationsWereMet(); err != nil { 223 | t.Fatal(err) 224 | } 225 | expected := entcache.Stats{Gets: 4, Hits: 1} 226 | if s := drv.Stats(); s != expected { 227 | t.Errorf("unexpected stats: %v != %v", s, expected) 228 | } 229 | }) 230 | } 231 | 232 | func TestDriver_SkipInsert(t *testing.T) { 233 | db, mock, err := sqlmock.New() 234 | if err != nil { 235 | t.Fatal(err) 236 | } 237 | drv := entcache.NewDriver(sql.OpenDB(dialect.Postgres, db), entcache.Hash(func(string, []interface{}) (entcache.Key, error) { 238 | t.Fatal("Driver.Query should not be called for INSERT statements") 239 | return nil, nil 240 | })) 241 | mock.ExpectQuery("INSERT INTO users DEFAULT VALUES RETURNING id"). 242 | WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) 243 | expectQuery(context.Background(), t, drv, "INSERT INTO users DEFAULT VALUES RETURNING id", []interface{}{int64(1)}) 244 | if err := mock.ExpectationsWereMet(); err != nil { 245 | t.Fatal(err) 246 | } 247 | var expected entcache.Stats 248 | if s := drv.Stats(); s != expected { 249 | t.Errorf("unexpected stats: %v != %v", s, expected) 250 | } 251 | } 252 | 253 | func expectQuery(ctx context.Context, t *testing.T, drv dialect.Driver, query string, args []interface{}) { 254 | rows := &sql.Rows{} 255 | if err := drv.Query(ctx, query, []interface{}{}, rows); err != nil { 256 | t.Fatalf("unexpected query failure: %q: %v", query, err) 257 | } 258 | var dest []interface{} 259 | for rows.Next() { 260 | var v interface{} 261 | if err := rows.Scan(&v); err != nil { 262 | t.Fatal("unexpected Rows.Scan failure:", err) 263 | } 264 | dest = append(dest, v) 265 | } 266 | if len(dest) != len(args) { 267 | t.Fatalf("mismatch rows length: %d != %d", len(dest), len(args)) 268 | } 269 | for i := range dest { 270 | if dest[i] != args[i] { 271 | t.Fatalf("mismatch values: 
%v != %v", dest[i], args[i]) 272 | } 273 | } 274 | if err := rows.Close(); err != nil { 275 | t.Fatal(err) 276 | } 277 | } 278 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module ariga.io/entcache 2 | 3 | go 1.19 4 | 5 | require ( 6 | entgo.io/ent v0.11.2-0.20220805114204-0066eb986dd3 7 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da 8 | github.com/google/uuid v1.3.0 // indirect 9 | github.com/mitchellh/hashstructure/v2 v2.0.2 10 | ) 11 | 12 | require ( 13 | github.com/DATA-DOG/go-sqlmock v1.5.0 14 | github.com/go-redis/redismock/v9 v9.0.3 15 | github.com/redis/go-redis/v9 v9.0.5 16 | ) 17 | 18 | require ( 19 | github.com/cespare/xxhash/v2 v2.2.0 // indirect 20 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 21 | ) 22 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | entgo.io/ent v0.11.2-0.20220805114204-0066eb986dd3 h1:JNVBOLEDn2snBII7sjINt0AWn8WOwCdw4wNjW0J3px8= 2 | entgo.io/ent v0.11.2-0.20220805114204-0066eb986dd3/go.mod h1:YGHEQnmmIUgtD5b1ICD5vg74dS3npkNnmC5K+0J+IHU= 3 | github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= 4 | github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= 5 | github.com/bsm/ginkgo/v2 v2.7.0 h1:ItPMPH90RbmZJt5GtkcNvIRuGEdwlBItdNVoyzaNQao= 6 | github.com/bsm/gomega v1.26.0 h1:LhQm+AFcgV2M0WyKroMASzAzCAJVpAxQXv4SaI9a69Y= 7 | github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= 8 | github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 9 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 10 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f 
h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= 11 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= 12 | github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= 13 | github.com/go-redis/redismock/v9 v9.0.3 h1:mtHQi2l51lCmXIbTRTqb1EiHYe9tL5Yk5oorlSJJqR0= 14 | github.com/go-redis/redismock/v9 v9.0.3/go.mod h1:F6tJRfnU8R/NZ0E+Gjvoluk14MqMC5ueSZX6vVQypc0= 15 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= 16 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= 17 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= 18 | github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= 19 | github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 20 | github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4= 21 | github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE= 22 | github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= 23 | github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= 24 | github.com/onsi/gomega v1.25.0 h1:Vw7br2PCDYijJHSfBOWhov+8cAnUf8MfMaIOV323l6Y= 25 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 26 | github.com/redis/go-redis/v9 v9.0.5 h1:CuQcn5HIEeK7BgElubPP8CGtE0KakrnbBSTLjathl5o= 27 | github.com/redis/go-redis/v9 v9.0.5/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk= 28 | github.com/stretchr/testify v1.7.1-0.20210427113832-6241f9ab9942 h1:t0lM6y/M5IiUZyvbBTcngso8SZEZICH7is9B6g/obVU= 29 | golang.org/x/net v0.5.0 h1:GyT4nK/YDHSqa1c4753ouYCDajOYKTja9Xb/OHtgvSw= 30 | golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= 31 | golang.org/x/text v0.6.0 
h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= 32 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= 33 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 34 | -------------------------------------------------------------------------------- /internal/examples/ctxlevel/go.mod: -------------------------------------------------------------------------------- 1 | module ctxlevel 2 | 3 | go 1.17 4 | 5 | replace ( 6 | ariga.io/entcache => ../../../ 7 | todo => ../todo 8 | ) 9 | 10 | require ( 11 | ariga.io/entcache v0.0.0 12 | todo v0.0.0 13 | ) 14 | 15 | require ( 16 | entgo.io/contrib v0.1.1-0.20211009150803-2f98d3a15e7d 17 | entgo.io/ent v0.9.2-0.20211014063230-899e9f0e50ba 18 | github.com/99designs/gqlgen v0.14.0 19 | github.com/alecthomas/kong v0.2.16 20 | github.com/mattn/go-sqlite3 v1.14.8 21 | github.com/vektah/gqlparser/v2 v2.2.0 22 | ) 23 | 24 | require ( 25 | github.com/agnivade/levenshtein v1.1.0 // indirect 26 | github.com/cespare/xxhash/v2 v2.1.1 // indirect 27 | github.com/davecgh/go-spew v1.1.1 // indirect 28 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 29 | github.com/go-openapi/inflect v0.19.0 // indirect 30 | github.com/go-redis/redis/v8 v8.11.3 // indirect 31 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect 32 | github.com/google/uuid v1.3.0 // indirect 33 | github.com/gorilla/websocket v1.4.2 // indirect 34 | github.com/graphql-go/graphql v0.7.10-0.20210411022516-8a92e977c10b // indirect 35 | github.com/hashicorp/errwrap v1.1.0 // indirect 36 | github.com/hashicorp/go-multierror v1.1.1 // indirect 37 | github.com/hashicorp/golang-lru v0.5.4 // indirect 38 | github.com/mitchellh/hashstructure v1.1.0 // indirect 39 | github.com/mitchellh/mapstructure v1.4.2 // indirect 40 | github.com/pkg/errors v0.9.1 // indirect 41 | github.com/pmezard/go-difflib v1.0.0 // indirect 42 | github.com/vmihailenco/msgpack/v5 v5.2.0 
// indirect 43 | github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect 44 | golang.org/x/mod v0.4.2 // indirect 45 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect 46 | golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e // indirect 47 | golang.org/x/tools v0.1.7 // indirect 48 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect 49 | gopkg.in/yaml.v2 v2.4.0 // indirect 50 | gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 // indirect 51 | ) 52 | -------------------------------------------------------------------------------- /internal/examples/ctxlevel/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "net/http" 7 | 8 | "todo" 9 | "todo/ent" 10 | "todo/ent/migrate" 11 | 12 | "ariga.io/entcache" 13 | "entgo.io/contrib/entgql" 14 | "entgo.io/ent/dialect" 15 | "entgo.io/ent/dialect/sql" 16 | "github.com/99designs/gqlgen/graphql" 17 | "github.com/99designs/gqlgen/graphql/handler" 18 | "github.com/99designs/gqlgen/graphql/playground" 19 | "github.com/alecthomas/kong" 20 | _ "github.com/mattn/go-sqlite3" 21 | "github.com/vektah/gqlparser/v2/ast" 22 | ) 23 | 24 | func main() { 25 | var cli struct { 26 | Addr string `name:"address" default:":8081" help:"Address to listen on."` 27 | Cache bool `name:"cache" default:"true" help:"Enable context-level cache mode."` 28 | } 29 | kong.Parse(&cli) 30 | db, err := sql.Open(dialect.SQLite, "file:ent?mode=memory&cache=shared&_fk=1") 31 | if err != nil { 32 | log.Fatal("opening database", err) 33 | } 34 | // Run the migration without the debug information. 
35 | if err := ent.NewClient(ent.Driver(db)).Schema.Create( 36 | context.Background(), 37 | migrate.WithGlobalUniqueID(true), 38 | ); err != nil { 39 | log.Fatal("running schema migration", err) 40 | } 41 | drv := dialect.Debug(db) 42 | if cli.Cache { 43 | // In case of the context/request level cache is enabled, we wrap the 44 | // driver with a cache driver, and configures it to work in this mode. 45 | drv = entcache.NewDriver(drv, entcache.ContextLevel()) 46 | } 47 | client := ent.NewClient(ent.Driver(drv)) 48 | srv := handler.NewDefaultServer(todo.NewSchema(client)) 49 | srv.Use(entgql.Transactioner{TxOpener: client}) 50 | if cli.Cache { 51 | // In case of the context/request level cache is enabled, we add a middleware 52 | // that wraps the context of GraphQL queries with cache context. 53 | srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { 54 | if op := graphql.GetOperationContext(ctx).Operation; op != nil && op.Operation == ast.Query { 55 | ctx = entcache.NewContext(ctx) 56 | } 57 | return next(ctx) 58 | }) 59 | } 60 | http.Handle("/", playground.Handler("Todo", "/query")) 61 | http.Handle("/query", srv) 62 | log.Println("listening on", cli.Addr) 63 | if err := http.ListenAndServe(cli.Addr, nil); err != nil { 64 | log.Fatal("http server terminated", err) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /internal/examples/ctxlevel/main_test.go: -------------------------------------------------------------------------------- 1 | package main_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "todo" 8 | "todo/ent" 9 | "todo/ent/migrate" 10 | "todo/ent/user" 11 | 12 | "ariga.io/entcache" 13 | 14 | "entgo.io/ent/dialect" 15 | "entgo.io/ent/dialect/sql" 16 | gqlclient "github.com/99designs/gqlgen/client" 17 | "github.com/99designs/gqlgen/graphql" 18 | "github.com/99designs/gqlgen/graphql/handler" 19 | _ "github.com/mattn/go-sqlite3" 20 | ) 21 | 22 | func 
TestContextLevel(t *testing.T) { 23 | db, err := sql.Open(dialect.SQLite, "file:ent?mode=memory&cache=shared&_fk=1") 24 | if err != nil { 25 | t.Fatal("opening database", err) 26 | } 27 | ctx := context.Background() 28 | // Run the migration without the debug information. 29 | if err := ent.NewClient(ent.Driver(db)).Schema.Create(ctx, migrate.WithGlobalUniqueID(true)); err != nil { 30 | t.Fatal("running schema migration", err) 31 | } 32 | // Wraps the driver with a query counter. 33 | q := &queryCount{Driver: db} 34 | // Wrap the driver with a cache driver, and configure 35 | // it to work in a context-level mode. 36 | drv := entcache.NewDriver(q, entcache.ContextLevel()) 37 | client := ent.NewClient(ent.Driver(drv)) 38 | 39 | parent := client.Todo.Create().SetText("parent").SaveX(ctx) 40 | children := client.Todo.CreateBulk( 41 | client.Todo.Create().SetText("child-1").SetParent(parent), 42 | client.Todo.Create().SetText("child-2").SetParent(parent), 43 | client.Todo.Create().SetText("child-3").SetParent(parent), 44 | ).SaveX(ctx) 45 | client.User.CreateBulk( 46 | client.User.Create().SetName("a8m").AddTodos(parent, children[1]), 47 | client.User.Create().SetName("nati").AddTodos(children[2:]...), 48 | ).ExecX(ctx) 49 | ids := client.User.Query().IDsX(ctx) 50 | 51 | t.Run("WithoutCache", func(t *testing.T) { 52 | q.expectCount(t, 3, func() { 53 | client.User.Query(). 54 | Where(user.IDIn(ids...)). 55 | WithTodos(func(q *ent.TodoQuery) { 56 | q.WithOwner() 57 | }). 58 | AllX(ctx) 59 | }) 60 | }) 61 | 62 | t.Run("WithCache", func(t *testing.T) { 63 | ctx := entcache.NewContext(ctx) 64 | q.expectCount(t, 2, func() { 65 | client.User.Query(). 66 | Where(user.IDIn(ids...)). 67 | WithTodos(func(q *ent.TodoQuery) { 68 | q.WithOwner() 69 | }). 70 | AllX(ctx) 71 | }) 72 | }) 73 | 74 | // Demonstrate the usage with GraphQL. 75 | t.Run("GQL", func(t *testing.T) { 76 | const query = `query($ids: [ID!]!) { 77 | nodes(ids: $ids) { 78 | ... 
on User { 79 | id 80 | todos { 81 | id 82 | owner { 83 | id 84 | name 85 | } 86 | } 87 | } 88 | } 89 | }` 90 | // Load the ent_types table on the first Node query. 91 | if _, err := client.Noders(ctx, ids); err != nil { 92 | t.Fatal(err) 93 | } 94 | t.Run("WithoutCache", func(t *testing.T) { 95 | var ( 96 | rsp any 97 | srv = handler.NewDefaultServer(todo.NewSchema(client)) 98 | gql = gqlclient.New(srv) 99 | ) 100 | q.expectCount(t, 3, func() { 101 | if err := gql.Post(query, &rsp, gqlclient.Var("ids", ids)); err != nil { 102 | t.Fatal(err) 103 | } 104 | }) 105 | }) 106 | t.Run("WithCache", func(t *testing.T) { 107 | var ( 108 | rsp any 109 | srv = handler.NewDefaultServer(todo.NewSchema(client)) 110 | gql = gqlclient.New(srv) 111 | ) 112 | srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { 113 | return next(entcache.NewContext(ctx)) 114 | }) 115 | q.expectCount(t, 2, func() { 116 | if err := gql.Post(query, &rsp, gqlclient.Var("ids", ids)); err != nil { 117 | t.Fatal(err) 118 | } 119 | }) 120 | }) 121 | }) 122 | } 123 | 124 | type queryCount struct { 125 | n int 126 | dialect.Driver 127 | } 128 | 129 | func (q *queryCount) Query(ctx context.Context, query string, args, v any) error { 130 | q.n++ 131 | return q.Driver.Query(ctx, query, args, v) 132 | } 133 | 134 | // expectCount expects the given function to execute "n" queries. 
135 | func (q *queryCount) expectCount(t *testing.T, n int, fn func()) { 136 | q.n = 0 137 | fn() 138 | if q.n != n { 139 | t.Errorf("expect client to execute %d queries, got: %d", n, q.n) 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /internal/examples/multilevel/README.md: -------------------------------------------------------------------------------- 1 | ### Run Redis Container 2 | 3 | ```shell 4 | docker run --rm -p 6379:6379 redis 5 | ``` 6 | 7 | ### Run The App 8 | 9 | ```shell 10 | go run main.go 11 | ``` 12 | 13 | Open [localhost:8081](http://localhost:8081/) and execute GraphQL queries. 14 | 15 | ### Get Cache Stats 16 | 17 | ```shell 18 | curl :8081/stats 19 | 20 | # cache stats (gets: 61, hits: 44, errors: 0) 21 | ``` 22 | -------------------------------------------------------------------------------- /internal/examples/multilevel/go.mod: -------------------------------------------------------------------------------- 1 | module multilevel 2 | 3 | go 1.17 4 | 5 | replace ( 6 | ariga.io/entcache => ../../../ 7 | todo => ../todo 8 | ) 9 | 10 | require ( 11 | ariga.io/entcache v0.0.0 12 | todo v0.0.0 13 | ) 14 | 15 | require ( 16 | entgo.io/contrib v0.1.1-0.20211009150803-2f98d3a15e7d 17 | entgo.io/ent v0.9.2-0.20211014063230-899e9f0e50ba 18 | github.com/99designs/gqlgen v0.14.0 19 | github.com/alecthomas/kong v0.2.16 20 | github.com/mattn/go-sqlite3 v1.14.8 21 | github.com/vektah/gqlparser/v2 v2.2.0 22 | ) 23 | 24 | require ( 25 | github.com/agnivade/levenshtein v1.1.0 // indirect 26 | github.com/cespare/xxhash/v2 v2.1.1 // indirect 27 | github.com/davecgh/go-spew v1.1.1 // indirect 28 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 29 | github.com/go-openapi/inflect v0.19.0 // indirect 30 | github.com/go-redis/redis/v8 v8.11.3 // indirect 31 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect 32 | github.com/google/uuid v1.3.0 // indirect 
33 | github.com/gorilla/websocket v1.4.2 // indirect 34 | github.com/graphql-go/graphql v0.7.10-0.20210411022516-8a92e977c10b // indirect 35 | github.com/hashicorp/errwrap v1.1.0 // indirect 36 | github.com/hashicorp/go-multierror v1.1.1 // indirect 37 | github.com/hashicorp/golang-lru v0.5.4 // indirect 38 | github.com/mitchellh/hashstructure v1.1.0 // indirect 39 | github.com/mitchellh/mapstructure v1.4.2 // indirect 40 | github.com/pkg/errors v0.9.1 // indirect 41 | github.com/pmezard/go-difflib v1.0.0 // indirect 42 | github.com/vmihailenco/msgpack/v5 v5.2.0 // indirect 43 | github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect 44 | golang.org/x/mod v0.4.2 // indirect 45 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect 46 | golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e // indirect 47 | golang.org/x/tools v0.1.7 // indirect 48 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect 49 | gopkg.in/yaml.v2 v2.4.0 // indirect 50 | gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 // indirect 51 | ) 52 | -------------------------------------------------------------------------------- /internal/examples/multilevel/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "net/http" 8 | "time" 9 | 10 | "todo" 11 | "todo/ent" 12 | "todo/ent/migrate" 13 | 14 | "ariga.io/entcache" 15 | "entgo.io/contrib/entgql" 16 | "entgo.io/ent/dialect" 17 | "entgo.io/ent/dialect/sql" 18 | "github.com/99designs/gqlgen/graphql/handler" 19 | "github.com/99designs/gqlgen/graphql/playground" 20 | "github.com/alecthomas/kong" 21 | "github.com/go-redis/redis/v8" 22 | _ "github.com/mattn/go-sqlite3" 23 | ) 24 | 25 | func main() { 26 | var cli struct { 27 | Addr string `name:"address" default:":8081" help:"Address to listen on."` 28 | Cache bool `name:"cache" default:"true" help:"Enable context-level cache mode."` 29 | RedisAddr string `name:"redis" 
default:":6379" help:"Redis address"` 30 | } 31 | kong.Parse(&cli) 32 | db, err := sql.Open(dialect.SQLite, "file:ent?mode=memory&cache=shared&_fk=1") 33 | if err != nil { 34 | log.Fatal("opening database", err) 35 | } 36 | ctx := context.Background() 37 | // Run the migration without the debug information. 38 | if err := ent.NewClient(ent.Driver(db)).Schema.Create(ctx, migrate.WithGlobalUniqueID(true)); err != nil { 39 | log.Fatal("running schema migration", err) 40 | } 41 | drv := dialect.Debug(db) 42 | if cli.Cache { 43 | rdb := redis.NewClient(&redis.Options{ 44 | Addr: cli.RedisAddr, 45 | }) 46 | if err := rdb.Ping(ctx).Err(); err != nil { 47 | log.Fatal(err) 48 | } 49 | // In case of the cache cache is enabled, we wrap the driver with 50 | // a cache driver, and configures it to work in multi-level mode. 51 | drv = entcache.NewDriver( 52 | drv, 53 | entcache.TTL(time.Second*5), 54 | entcache.Levels( 55 | entcache.NewLRU(256), 56 | entcache.NewRedis(rdb), 57 | ), 58 | ) 59 | } 60 | client := ent.NewClient(ent.Driver(drv)) 61 | srv := handler.NewDefaultServer(todo.NewSchema(client)) 62 | srv.Use(entgql.Transactioner{TxOpener: client}) 63 | http.Handle("/", playground.Handler("Todo", "/query")) 64 | http.Handle("/query", srv) 65 | http.HandleFunc("/stats", func(w http.ResponseWriter, r *http.Request) { 66 | if cd, ok := drv.(*entcache.Driver); ok { 67 | stat := cd.Stats() 68 | fmt.Fprintf(w, "cache stats (gets: %d, hits: %d, errors: %d)\n", stat.Gets, stat.Hits, stat.Errors) 69 | } else { 70 | fmt.Fprintln(w, "cache mode is not enabled") 71 | } 72 | }) 73 | log.Println("listening on", cli.Addr) 74 | if err := http.ListenAndServe(cli.Addr, nil); err != nil { 75 | log.Fatal("http server terminated", err) 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /internal/examples/todo/README.md: -------------------------------------------------------------------------------- 1 | # ent-graphql-example 2 | 3 | Please go to 
[github.com/a8m/ent-graphql-example](https://github.com/a8m/ent-graphql-example) for the complete example. -------------------------------------------------------------------------------- /internal/examples/todo/ent.graphql: -------------------------------------------------------------------------------- 1 | """ 2 | TodoWhereInput is used for filtering Todo objects. 3 | Input was generated by ent. 4 | """ 5 | input TodoWhereInput { 6 | not: TodoWhereInput 7 | and: [TodoWhereInput!] 8 | or: [TodoWhereInput!] 9 | 10 | """text field predicates""" 11 | text: String 12 | textNEQ: String 13 | textIn: [String!] 14 | textNotIn: [String!] 15 | textGT: String 16 | textGTE: String 17 | textLT: String 18 | textLTE: String 19 | textContains: String 20 | textHasPrefix: String 21 | textHasSuffix: String 22 | textEqualFold: String 23 | textContainsFold: String 24 | 25 | """created_at field predicates""" 26 | createdAt: Time 27 | createdAtNEQ: Time 28 | createdAtIn: [Time!] 29 | createdAtNotIn: [Time!] 30 | createdAtGT: Time 31 | createdAtGTE: Time 32 | createdAtLT: Time 33 | createdAtLTE: Time 34 | 35 | """status field predicates""" 36 | status: Status 37 | statusNEQ: Status 38 | statusIn: [Status!] 39 | statusNotIn: [Status!] 40 | 41 | """priority field predicates""" 42 | priority: Int 43 | priorityNEQ: Int 44 | priorityIn: [Int!] 45 | priorityNotIn: [Int!] 46 | priorityGT: Int 47 | priorityGTE: Int 48 | priorityLT: Int 49 | priorityLTE: Int 50 | 51 | """id field predicates""" 52 | id: ID 53 | idNEQ: ID 54 | idIn: [ID!] 55 | idNotIn: [ID!] 56 | idGT: ID 57 | idGTE: ID 58 | idLT: ID 59 | idLTE: ID 60 | 61 | """children edge predicates""" 62 | hasChildren: Boolean 63 | hasChildrenWith: [TodoWhereInput!] 64 | 65 | """parent edge predicates""" 66 | hasParent: Boolean 67 | hasParentWith: [TodoWhereInput!] 68 | 69 | """owner edge predicates""" 70 | hasOwner: Boolean 71 | hasOwnerWith: [UserWhereInput!] 72 | } 73 | 74 | """ 75 | UserWhereInput is used for filtering User objects. 
76 | Input was generated by ent. 77 | """ 78 | input UserWhereInput { 79 | not: UserWhereInput 80 | and: [UserWhereInput!] 81 | or: [UserWhereInput!] 82 | 83 | """name field predicates""" 84 | name: String 85 | nameNEQ: String 86 | nameIn: [String!] 87 | nameNotIn: [String!] 88 | nameGT: String 89 | nameGTE: String 90 | nameLT: String 91 | nameLTE: String 92 | nameContains: String 93 | nameHasPrefix: String 94 | nameHasSuffix: String 95 | nameEqualFold: String 96 | nameContainsFold: String 97 | 98 | """id field predicates""" 99 | id: ID 100 | idNEQ: ID 101 | idIn: [ID!] 102 | idNotIn: [ID!] 103 | idGT: ID 104 | idGTE: ID 105 | idLT: ID 106 | idLTE: ID 107 | 108 | """todos edge predicates""" 109 | hasTodos: Boolean 110 | hasTodosWith: [TodoWhereInput!] 111 | } 112 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/client.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "log" 9 | 10 | "todo/ent/migrate" 11 | 12 | "todo/ent/todo" 13 | "todo/ent/user" 14 | 15 | "entgo.io/ent/dialect" 16 | "entgo.io/ent/dialect/sql" 17 | "entgo.io/ent/dialect/sql/sqlgraph" 18 | ) 19 | 20 | // Client is the client that holds all ent builders. 21 | type Client struct { 22 | config 23 | // Schema is the client for creating, migrating and dropping schema. 24 | Schema *migrate.Schema 25 | // Todo is the client for interacting with the Todo builders. 26 | Todo *TodoClient 27 | // User is the client for interacting with the User builders. 28 | User *UserClient 29 | // additional fields for node api 30 | tables tables 31 | } 32 | 33 | // NewClient creates a new client configured with the given options. 34 | func NewClient(opts ...Option) *Client { 35 | cfg := config{log: log.Println, hooks: &hooks{}} 36 | cfg.options(opts...) 
37 | client := &Client{config: cfg} 38 | client.init() 39 | return client 40 | } 41 | 42 | func (c *Client) init() { 43 | c.Schema = migrate.NewSchema(c.driver) 44 | c.Todo = NewTodoClient(c.config) 45 | c.User = NewUserClient(c.config) 46 | } 47 | 48 | // Open opens a database/sql.DB specified by the driver name and 49 | // the data source name, and returns a new client attached to it. 50 | // Optional parameters can be added for configuring the client. 51 | func Open(driverName, dataSourceName string, options ...Option) (*Client, error) { 52 | switch driverName { 53 | case dialect.MySQL, dialect.Postgres, dialect.SQLite: 54 | drv, err := sql.Open(driverName, dataSourceName) 55 | if err != nil { 56 | return nil, err 57 | } 58 | return NewClient(append(options, Driver(drv))...), nil 59 | default: 60 | return nil, fmt.Errorf("unsupported driver: %q", driverName) 61 | } 62 | } 63 | 64 | // Tx returns a new transactional client. The provided context 65 | // is used until the transaction is committed or rolled back. 66 | func (c *Client) Tx(ctx context.Context) (*Tx, error) { 67 | if _, ok := c.driver.(*txDriver); ok { 68 | return nil, fmt.Errorf("ent: cannot start a transaction within a transaction") 69 | } 70 | tx, err := newTx(ctx, c.driver) 71 | if err != nil { 72 | return nil, fmt.Errorf("ent: starting a transaction: %w", err) 73 | } 74 | cfg := c.config 75 | cfg.driver = tx 76 | return &Tx{ 77 | ctx: ctx, 78 | config: cfg, 79 | Todo: NewTodoClient(cfg), 80 | User: NewUserClient(cfg), 81 | }, nil 82 | } 83 | 84 | // BeginTx returns a transactional client with specified options. 
85 | func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { 86 | if _, ok := c.driver.(*txDriver); ok { 87 | return nil, fmt.Errorf("ent: cannot start a transaction within a transaction") 88 | } 89 | tx, err := c.driver.(interface { 90 | BeginTx(context.Context, *sql.TxOptions) (dialect.Tx, error) 91 | }).BeginTx(ctx, opts) 92 | if err != nil { 93 | return nil, fmt.Errorf("ent: starting a transaction: %w", err) 94 | } 95 | cfg := c.config 96 | cfg.driver = &txDriver{tx: tx, drv: c.driver} 97 | return &Tx{ 98 | config: cfg, 99 | Todo: NewTodoClient(cfg), 100 | User: NewUserClient(cfg), 101 | }, nil 102 | } 103 | 104 | // Debug returns a new debug-client. It's used to get verbose logging on specific operations. 105 | // 106 | // client.Debug(). 107 | // Todo. 108 | // Query(). 109 | // Count(ctx) 110 | // 111 | func (c *Client) Debug() *Client { 112 | if c.debug { 113 | return c 114 | } 115 | cfg := c.config 116 | cfg.driver = dialect.Debug(c.driver, c.log) 117 | client := &Client{config: cfg} 118 | client.init() 119 | return client 120 | } 121 | 122 | // Close closes the database connection and prevents new queries from starting. 123 | func (c *Client) Close() error { 124 | return c.driver.Close() 125 | } 126 | 127 | // Use adds the mutation hooks to all the entity clients. 128 | // In order to add hooks to a specific client, call: `client.Node.Use(...)`. 129 | func (c *Client) Use(hooks ...Hook) { 130 | c.Todo.Use(hooks...) 131 | c.User.Use(hooks...) 132 | } 133 | 134 | // TodoClient is a client for the Todo schema. 135 | type TodoClient struct { 136 | config 137 | } 138 | 139 | // NewTodoClient returns a client for the Todo from the given config. 140 | func NewTodoClient(c config) *TodoClient { 141 | return &TodoClient{config: c} 142 | } 143 | 144 | // Use adds a list of mutation hooks to the hooks stack. 145 | // A call to `Use(f, g, h)` equals to `todo.Hooks(f(g(h())))`. 
146 | func (c *TodoClient) Use(hooks ...Hook) { 147 | c.hooks.Todo = append(c.hooks.Todo, hooks...) 148 | } 149 | 150 | // Create returns a create builder for Todo. 151 | func (c *TodoClient) Create() *TodoCreate { 152 | mutation := newTodoMutation(c.config, OpCreate) 153 | return &TodoCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} 154 | } 155 | 156 | // CreateBulk returns a builder for creating a bulk of Todo entities. 157 | func (c *TodoClient) CreateBulk(builders ...*TodoCreate) *TodoCreateBulk { 158 | return &TodoCreateBulk{config: c.config, builders: builders} 159 | } 160 | 161 | // Update returns an update builder for Todo. 162 | func (c *TodoClient) Update() *TodoUpdate { 163 | mutation := newTodoMutation(c.config, OpUpdate) 164 | return &TodoUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} 165 | } 166 | 167 | // UpdateOne returns an update builder for the given entity. 168 | func (c *TodoClient) UpdateOne(t *Todo) *TodoUpdateOne { 169 | mutation := newTodoMutation(c.config, OpUpdateOne, withTodo(t)) 170 | return &TodoUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} 171 | } 172 | 173 | // UpdateOneID returns an update builder for the given id. 174 | func (c *TodoClient) UpdateOneID(id int) *TodoUpdateOne { 175 | mutation := newTodoMutation(c.config, OpUpdateOne, withTodoID(id)) 176 | return &TodoUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} 177 | } 178 | 179 | // Delete returns a delete builder for Todo. 180 | func (c *TodoClient) Delete() *TodoDelete { 181 | mutation := newTodoMutation(c.config, OpDelete) 182 | return &TodoDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} 183 | } 184 | 185 | // DeleteOne returns a delete builder for the given entity. 186 | func (c *TodoClient) DeleteOne(t *Todo) *TodoDeleteOne { 187 | return c.DeleteOneID(t.ID) 188 | } 189 | 190 | // DeleteOneID returns a delete builder for the given id. 
191 | func (c *TodoClient) DeleteOneID(id int) *TodoDeleteOne { 192 | builder := c.Delete().Where(todo.ID(id)) 193 | builder.mutation.id = &id 194 | builder.mutation.op = OpDeleteOne 195 | return &TodoDeleteOne{builder} 196 | } 197 | 198 | // Query returns a query builder for Todo. 199 | func (c *TodoClient) Query() *TodoQuery { 200 | return &TodoQuery{ 201 | config: c.config, 202 | } 203 | } 204 | 205 | // Get returns a Todo entity by its id. 206 | func (c *TodoClient) Get(ctx context.Context, id int) (*Todo, error) { 207 | return c.Query().Where(todo.ID(id)).Only(ctx) 208 | } 209 | 210 | // GetX is like Get, but panics if an error occurs. 211 | func (c *TodoClient) GetX(ctx context.Context, id int) *Todo { 212 | obj, err := c.Get(ctx, id) 213 | if err != nil { 214 | panic(err) 215 | } 216 | return obj 217 | } 218 | 219 | // QueryChildren queries the children edge of a Todo. 220 | func (c *TodoClient) QueryChildren(t *Todo) *TodoQuery { 221 | query := &TodoQuery{config: c.config} 222 | query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { 223 | id := t.ID 224 | step := sqlgraph.NewStep( 225 | sqlgraph.From(todo.Table, todo.FieldID, id), 226 | sqlgraph.To(todo.Table, todo.FieldID), 227 | sqlgraph.Edge(sqlgraph.O2M, true, todo.ChildrenTable, todo.ChildrenColumn), 228 | ) 229 | fromV = sqlgraph.Neighbors(t.driver.Dialect(), step) 230 | return fromV, nil 231 | } 232 | return query 233 | } 234 | 235 | // QueryParent queries the parent edge of a Todo. 
236 | func (c *TodoClient) QueryParent(t *Todo) *TodoQuery { 237 | query := &TodoQuery{config: c.config} 238 | query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { 239 | id := t.ID 240 | step := sqlgraph.NewStep( 241 | sqlgraph.From(todo.Table, todo.FieldID, id), 242 | sqlgraph.To(todo.Table, todo.FieldID), 243 | sqlgraph.Edge(sqlgraph.M2O, false, todo.ParentTable, todo.ParentColumn), 244 | ) 245 | fromV = sqlgraph.Neighbors(t.driver.Dialect(), step) 246 | return fromV, nil 247 | } 248 | return query 249 | } 250 | 251 | // QueryOwner queries the owner edge of a Todo. 252 | func (c *TodoClient) QueryOwner(t *Todo) *UserQuery { 253 | query := &UserQuery{config: c.config} 254 | query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { 255 | id := t.ID 256 | step := sqlgraph.NewStep( 257 | sqlgraph.From(todo.Table, todo.FieldID, id), 258 | sqlgraph.To(user.Table, user.FieldID), 259 | sqlgraph.Edge(sqlgraph.M2O, true, todo.OwnerTable, todo.OwnerColumn), 260 | ) 261 | fromV = sqlgraph.Neighbors(t.driver.Dialect(), step) 262 | return fromV, nil 263 | } 264 | return query 265 | } 266 | 267 | // Hooks returns the client hooks. 268 | func (c *TodoClient) Hooks() []Hook { 269 | return c.hooks.Todo 270 | } 271 | 272 | // UserClient is a client for the User schema. 273 | type UserClient struct { 274 | config 275 | } 276 | 277 | // NewUserClient returns a client for the User from the given config. 278 | func NewUserClient(c config) *UserClient { 279 | return &UserClient{config: c} 280 | } 281 | 282 | // Use adds a list of mutation hooks to the hooks stack. 283 | // A call to `Use(f, g, h)` equals to `user.Hooks(f(g(h())))`. 284 | func (c *UserClient) Use(hooks ...Hook) { 285 | c.hooks.User = append(c.hooks.User, hooks...) 286 | } 287 | 288 | // Create returns a create builder for User. 
289 | func (c *UserClient) Create() *UserCreate { 290 | mutation := newUserMutation(c.config, OpCreate) 291 | return &UserCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} 292 | } 293 | 294 | // CreateBulk returns a builder for creating a bulk of User entities. 295 | func (c *UserClient) CreateBulk(builders ...*UserCreate) *UserCreateBulk { 296 | return &UserCreateBulk{config: c.config, builders: builders} 297 | } 298 | 299 | // Update returns an update builder for User. 300 | func (c *UserClient) Update() *UserUpdate { 301 | mutation := newUserMutation(c.config, OpUpdate) 302 | return &UserUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} 303 | } 304 | 305 | // UpdateOne returns an update builder for the given entity. 306 | func (c *UserClient) UpdateOne(u *User) *UserUpdateOne { 307 | mutation := newUserMutation(c.config, OpUpdateOne, withUser(u)) 308 | return &UserUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} 309 | } 310 | 311 | // UpdateOneID returns an update builder for the given id. 312 | func (c *UserClient) UpdateOneID(id int) *UserUpdateOne { 313 | mutation := newUserMutation(c.config, OpUpdateOne, withUserID(id)) 314 | return &UserUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} 315 | } 316 | 317 | // Delete returns a delete builder for User. 318 | func (c *UserClient) Delete() *UserDelete { 319 | mutation := newUserMutation(c.config, OpDelete) 320 | return &UserDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} 321 | } 322 | 323 | // DeleteOne returns a delete builder for the given entity. 324 | func (c *UserClient) DeleteOne(u *User) *UserDeleteOne { 325 | return c.DeleteOneID(u.ID) 326 | } 327 | 328 | // DeleteOneID returns a delete builder for the given id. 
329 | func (c *UserClient) DeleteOneID(id int) *UserDeleteOne { 330 | builder := c.Delete().Where(user.ID(id)) 331 | builder.mutation.id = &id 332 | builder.mutation.op = OpDeleteOne 333 | return &UserDeleteOne{builder} 334 | } 335 | 336 | // Query returns a query builder for User. 337 | func (c *UserClient) Query() *UserQuery { 338 | return &UserQuery{ 339 | config: c.config, 340 | } 341 | } 342 | 343 | // Get returns a User entity by its id. 344 | func (c *UserClient) Get(ctx context.Context, id int) (*User, error) { 345 | return c.Query().Where(user.ID(id)).Only(ctx) 346 | } 347 | 348 | // GetX is like Get, but panics if an error occurs. 349 | func (c *UserClient) GetX(ctx context.Context, id int) *User { 350 | obj, err := c.Get(ctx, id) 351 | if err != nil { 352 | panic(err) 353 | } 354 | return obj 355 | } 356 | 357 | // QueryTodos queries the todos edge of a User. 358 | func (c *UserClient) QueryTodos(u *User) *TodoQuery { 359 | query := &TodoQuery{config: c.config} 360 | query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { 361 | id := u.ID 362 | step := sqlgraph.NewStep( 363 | sqlgraph.From(user.Table, user.FieldID, id), 364 | sqlgraph.To(todo.Table, todo.FieldID), 365 | sqlgraph.Edge(sqlgraph.O2M, false, user.TodosTable, user.TodosColumn), 366 | ) 367 | fromV = sqlgraph.Neighbors(u.driver.Dialect(), step) 368 | return fromV, nil 369 | } 370 | return query 371 | } 372 | 373 | // Hooks returns the client hooks. 374 | func (c *UserClient) Hooks() []Hook { 375 | return c.hooks.User 376 | } 377 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/config.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "entgo.io/ent" 7 | "entgo.io/ent/dialect" 8 | ) 9 | 10 | // Option function to configure the client. 
11 | type Option func(*config) 12 | 13 | // Config is the configuration for the client and its builder. 14 | type config struct { 15 | // driver used for executing database requests. 16 | driver dialect.Driver 17 | // debug enable a debug logging. 18 | debug bool 19 | // log used for logging on debug mode. 20 | log func(...any) 21 | // hooks to execute on mutations. 22 | hooks *hooks 23 | } 24 | 25 | // hooks per client, for fast access. 26 | type hooks struct { 27 | Todo []ent.Hook 28 | User []ent.Hook 29 | } 30 | 31 | // Options applies the options on the config object. 32 | func (c *config) options(opts ...Option) { 33 | for _, opt := range opts { 34 | opt(c) 35 | } 36 | if c.debug { 37 | c.driver = dialect.Debug(c.driver, c.log) 38 | } 39 | } 40 | 41 | // Debug enables debug logging on the ent.Driver. 42 | func Debug() Option { 43 | return func(c *config) { 44 | c.debug = true 45 | } 46 | } 47 | 48 | // Log sets the logging function for debug mode. 49 | func Log(fn func(...any)) Option { 50 | return func(c *config) { 51 | c.log = fn 52 | } 53 | } 54 | 55 | // Driver configures the client driver. 56 | func Driver(driver dialect.Driver) Option { 57 | return func(c *config) { 58 | c.driver = driver 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/context.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "context" 7 | ) 8 | 9 | type clientCtxKey struct{} 10 | 11 | // FromContext returns a Client stored inside a context, or nil if there isn't one. 12 | func FromContext(ctx context.Context) *Client { 13 | c, _ := ctx.Value(clientCtxKey{}).(*Client) 14 | return c 15 | } 16 | 17 | // NewContext returns a new context with the given Client attached. 
18 | func NewContext(parent context.Context, c *Client) context.Context { 19 | return context.WithValue(parent, clientCtxKey{}, c) 20 | } 21 | 22 | type txCtxKey struct{} 23 | 24 | // TxFromContext returns a Tx stored inside a context, or nil if there isn't one. 25 | func TxFromContext(ctx context.Context) *Tx { 26 | tx, _ := ctx.Value(txCtxKey{}).(*Tx) 27 | return tx 28 | } 29 | 30 | // NewTxContext returns a new context with the given Tx attached. 31 | func NewTxContext(parent context.Context, tx *Tx) context.Context { 32 | return context.WithValue(parent, txCtxKey{}, tx) 33 | } 34 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/ent.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "errors" 7 | "fmt" 8 | "todo/ent/todo" 9 | "todo/ent/user" 10 | 11 | "entgo.io/ent" 12 | "entgo.io/ent/dialect/sql" 13 | ) 14 | 15 | // ent aliases to avoid import conflicts in user's code. 16 | type ( 17 | Op = ent.Op 18 | Hook = ent.Hook 19 | Value = ent.Value 20 | Query = ent.Query 21 | Policy = ent.Policy 22 | Mutator = ent.Mutator 23 | Mutation = ent.Mutation 24 | MutateFunc = ent.MutateFunc 25 | ) 26 | 27 | // OrderFunc applies an ordering on the sql selector. 28 | type OrderFunc func(*sql.Selector) 29 | 30 | // columnChecker returns a function indicates if the column exists in the given column. 
31 | func columnChecker(table string) func(string) error { 32 | checks := map[string]func(string) bool{ 33 | todo.Table: todo.ValidColumn, 34 | user.Table: user.ValidColumn, 35 | } 36 | check, ok := checks[table] 37 | if !ok { 38 | return func(string) error { 39 | return fmt.Errorf("unknown table %q", table) 40 | } 41 | } 42 | return func(column string) error { 43 | if !check(column) { 44 | return fmt.Errorf("unknown column %q for table %q", column, table) 45 | } 46 | return nil 47 | } 48 | } 49 | 50 | // Asc applies the given fields in ASC order. 51 | func Asc(fields ...string) OrderFunc { 52 | return func(s *sql.Selector) { 53 | check := columnChecker(s.TableName()) 54 | for _, f := range fields { 55 | if err := check(f); err != nil { 56 | s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)}) 57 | } 58 | s.OrderBy(sql.Asc(s.C(f))) 59 | } 60 | } 61 | } 62 | 63 | // Desc applies the given fields in DESC order. 64 | func Desc(fields ...string) OrderFunc { 65 | return func(s *sql.Selector) { 66 | check := columnChecker(s.TableName()) 67 | for _, f := range fields { 68 | if err := check(f); err != nil { 69 | s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)}) 70 | } 71 | s.OrderBy(sql.Desc(s.C(f))) 72 | } 73 | } 74 | } 75 | 76 | // AggregateFunc applies an aggregation step on the group-by traversal/selector. 77 | type AggregateFunc func(*sql.Selector) string 78 | 79 | // As is a pseudo aggregation function for renaming another other functions with custom names. For example: 80 | // 81 | // GroupBy(field1, field2). 82 | // Aggregate(ent.As(ent.Sum(field1), "sum_field1"), (ent.As(ent.Sum(field2), "sum_field2")). 83 | // Scan(ctx, &v) 84 | // 85 | func As(fn AggregateFunc, end string) AggregateFunc { 86 | return func(s *sql.Selector) string { 87 | return sql.As(fn(s), end) 88 | } 89 | } 90 | 91 | // Count applies the "count" aggregation function on each group. 
92 | func Count() AggregateFunc { 93 | return func(s *sql.Selector) string { 94 | return sql.Count("*") 95 | } 96 | } 97 | 98 | // Max applies the "max" aggregation function on the given field of each group. 99 | func Max(field string) AggregateFunc { 100 | return func(s *sql.Selector) string { 101 | check := columnChecker(s.TableName()) 102 | if err := check(field); err != nil { 103 | s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) 104 | return "" 105 | } 106 | return sql.Max(s.C(field)) 107 | } 108 | } 109 | 110 | // Mean applies the "mean" aggregation function on the given field of each group. 111 | func Mean(field string) AggregateFunc { 112 | return func(s *sql.Selector) string { 113 | check := columnChecker(s.TableName()) 114 | if err := check(field); err != nil { 115 | s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) 116 | return "" 117 | } 118 | return sql.Avg(s.C(field)) 119 | } 120 | } 121 | 122 | // Min applies the "min" aggregation function on the given field of each group. 123 | func Min(field string) AggregateFunc { 124 | return func(s *sql.Selector) string { 125 | check := columnChecker(s.TableName()) 126 | if err := check(field); err != nil { 127 | s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) 128 | return "" 129 | } 130 | return sql.Min(s.C(field)) 131 | } 132 | } 133 | 134 | // Sum applies the "sum" aggregation function on the given field of each group. 135 | func Sum(field string) AggregateFunc { 136 | return func(s *sql.Selector) string { 137 | check := columnChecker(s.TableName()) 138 | if err := check(field); err != nil { 139 | s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) 140 | return "" 141 | } 142 | return sql.Sum(s.C(field)) 143 | } 144 | } 145 | 146 | // ValidationError returns when validating a field or edge fails. 147 | type ValidationError struct { 148 | Name string // Field or edge name. 
149 | err error 150 | } 151 | 152 | // Error implements the error interface. 153 | func (e *ValidationError) Error() string { 154 | return e.err.Error() 155 | } 156 | 157 | // Unwrap implements the errors.Wrapper interface. 158 | func (e *ValidationError) Unwrap() error { 159 | return e.err 160 | } 161 | 162 | // IsValidationError returns a boolean indicating whether the error is a validation error. 163 | func IsValidationError(err error) bool { 164 | if err == nil { 165 | return false 166 | } 167 | var e *ValidationError 168 | return errors.As(err, &e) 169 | } 170 | 171 | // NotFoundError returns when trying to fetch a specific entity and it was not found in the database. 172 | type NotFoundError struct { 173 | label string 174 | } 175 | 176 | // Error implements the error interface. 177 | func (e *NotFoundError) Error() string { 178 | return "ent: " + e.label + " not found" 179 | } 180 | 181 | // IsNotFound returns a boolean indicating whether the error is a not found error. 182 | func IsNotFound(err error) bool { 183 | if err == nil { 184 | return false 185 | } 186 | var e *NotFoundError 187 | return errors.As(err, &e) 188 | } 189 | 190 | // MaskNotFound masks not found error. 191 | func MaskNotFound(err error) error { 192 | if IsNotFound(err) { 193 | return nil 194 | } 195 | return err 196 | } 197 | 198 | // NotSingularError returns when trying to fetch a singular entity and more then one was found in the database. 199 | type NotSingularError struct { 200 | label string 201 | } 202 | 203 | // Error implements the error interface. 204 | func (e *NotSingularError) Error() string { 205 | return "ent: " + e.label + " not singular" 206 | } 207 | 208 | // IsNotSingular returns a boolean indicating whether the error is a not singular error. 
209 | func IsNotSingular(err error) bool { 210 | if err == nil { 211 | return false 212 | } 213 | var e *NotSingularError 214 | return errors.As(err, &e) 215 | } 216 | 217 | // NotLoadedError returns when trying to get a node that was not loaded by the query. 218 | type NotLoadedError struct { 219 | edge string 220 | } 221 | 222 | // Error implements the error interface. 223 | func (e *NotLoadedError) Error() string { 224 | return "ent: " + e.edge + " edge was not loaded" 225 | } 226 | 227 | // IsNotLoaded returns a boolean indicating whether the error is a not loaded error. 228 | func IsNotLoaded(err error) bool { 229 | if err == nil { 230 | return false 231 | } 232 | var e *NotLoadedError 233 | return errors.As(err, &e) 234 | } 235 | 236 | // ConstraintError returns when trying to create/update one or more entities and 237 | // one or more of their constraints failed. For example, violation of edge or 238 | // field uniqueness. 239 | type ConstraintError struct { 240 | msg string 241 | wrap error 242 | } 243 | 244 | // Error implements the error interface. 245 | func (e ConstraintError) Error() string { 246 | return "ent: constraint failed: " + e.msg 247 | } 248 | 249 | // Unwrap implements the errors.Wrapper interface. 250 | func (e *ConstraintError) Unwrap() error { 251 | return e.wrap 252 | } 253 | 254 | // IsConstraintError returns a boolean indicating whether the error is a constraint failure. 
255 | func IsConstraintError(err error) bool { 256 | if err == nil { 257 | return false 258 | } 259 | var e *ConstraintError 260 | return errors.As(err, &e) 261 | } 262 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/entc.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | // +build ignore 3 | 4 | package main 5 | 6 | import ( 7 | "log" 8 | 9 | "entgo.io/contrib/entgql" 10 | "entgo.io/ent/entc" 11 | "entgo.io/ent/entc/gen" 12 | ) 13 | 14 | func main() { 15 | ex, err := entgql.NewExtension( 16 | entgql.WithWhereFilters(true), 17 | entgql.WithConfigPath("../gqlgen.yml"), 18 | // Generate the filters to a separate schema 19 | // file and load it in the gqlgen.yml config. 20 | entgql.WithSchemaPath("../ent.graphql"), 21 | ) 22 | if err != nil { 23 | log.Fatalf("creating entgql extension: %v", err) 24 | } 25 | opts := []entc.Option{ 26 | entc.Extensions(ex), 27 | entc.TemplateDir("./template"), 28 | } 29 | if err := entc.Generate("./schema", &gen.Config{}, opts...); err != nil { 30 | log.Fatalf("running ent codegen: %v", err) 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/enttest/enttest.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package enttest 4 | 5 | import ( 6 | "context" 7 | "todo/ent" 8 | // required by schema hooks. 9 | _ "todo/ent/runtime" 10 | 11 | "entgo.io/ent/dialect/sql/schema" 12 | ) 13 | 14 | type ( 15 | // TestingT is the interface that is shared between 16 | // testing.T and testing.B and used by enttest. 17 | TestingT interface { 18 | FailNow() 19 | Error(...any) 20 | } 21 | 22 | // Option configures client creation. 
23 | Option func(*options) 24 | 25 | options struct { 26 | opts []ent.Option 27 | migrateOpts []schema.MigrateOption 28 | } 29 | ) 30 | 31 | // WithOptions forwards options to client creation. 32 | func WithOptions(opts ...ent.Option) Option { 33 | return func(o *options) { 34 | o.opts = append(o.opts, opts...) 35 | } 36 | } 37 | 38 | // WithMigrateOptions forwards options to auto migration. 39 | func WithMigrateOptions(opts ...schema.MigrateOption) Option { 40 | return func(o *options) { 41 | o.migrateOpts = append(o.migrateOpts, opts...) 42 | } 43 | } 44 | 45 | func newOptions(opts []Option) *options { 46 | o := &options{} 47 | for _, opt := range opts { 48 | opt(o) 49 | } 50 | return o 51 | } 52 | 53 | // Open calls ent.Open and auto-run migration. 54 | func Open(t TestingT, driverName, dataSourceName string, opts ...Option) *ent.Client { 55 | o := newOptions(opts) 56 | c, err := ent.Open(driverName, dataSourceName, o.opts...) 57 | if err != nil { 58 | t.Error(err) 59 | t.FailNow() 60 | } 61 | if err := c.Schema.Create(context.Background(), o.migrateOpts...); err != nil { 62 | t.Error(err) 63 | t.FailNow() 64 | } 65 | return c 66 | } 67 | 68 | // NewClient calls ent.NewClient and auto-run migration. 69 | func NewClient(t TestingT, opts ...Option) *ent.Client { 70 | o := newOptions(opts) 71 | c := ent.NewClient(o.opts...) 
72 | if err := c.Schema.Create(context.Background(), o.migrateOpts...); err != nil { 73 | t.Error(err) 74 | t.FailNow() 75 | } 76 | return c 77 | } 78 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/generate.go: -------------------------------------------------------------------------------- 1 | package ent 2 | 3 | //go:generate go run -mod=mod entc.go 4 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/gql_collection.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "context" 7 | 8 | "github.com/99designs/gqlgen/graphql" 9 | ) 10 | 11 | // CollectFields tells the query-builder to eagerly load connected nodes by resolver context. 12 | func (t *TodoQuery) CollectFields(ctx context.Context, satisfies ...string) *TodoQuery { 13 | if fc := graphql.GetFieldContext(ctx); fc != nil { 14 | t = t.collectField(graphql.GetOperationContext(ctx), fc.Field, satisfies...) 15 | } 16 | return t 17 | } 18 | 19 | func (t *TodoQuery) collectField(ctx *graphql.OperationContext, field graphql.CollectedField, satisfies ...string) *TodoQuery { 20 | for _, field := range graphql.CollectFields(ctx, field.Selections, satisfies) { 21 | switch field.Name { 22 | case "children": 23 | t = t.WithChildren(func(query *TodoQuery) { 24 | query.collectField(ctx, field) 25 | }) 26 | case "owner": 27 | t = t.WithOwner(func(query *UserQuery) { 28 | query.collectField(ctx, field) 29 | }) 30 | case "parent": 31 | t = t.WithParent(func(query *TodoQuery) { 32 | query.collectField(ctx, field) 33 | }) 34 | } 35 | } 36 | return t 37 | } 38 | 39 | // CollectFields tells the query-builder to eagerly load connected nodes by resolver context. 
40 | func (u *UserQuery) CollectFields(ctx context.Context, satisfies ...string) *UserQuery { 41 | if fc := graphql.GetFieldContext(ctx); fc != nil { 42 | u = u.collectField(graphql.GetOperationContext(ctx), fc.Field, satisfies...) 43 | } 44 | return u 45 | } 46 | 47 | func (u *UserQuery) collectField(ctx *graphql.OperationContext, field graphql.CollectedField, satisfies ...string) *UserQuery { 48 | for _, field := range graphql.CollectFields(ctx, field.Selections, satisfies) { 49 | switch field.Name { 50 | case "todos": 51 | u = u.WithTodos(func(query *TodoQuery) { 52 | query.collectField(ctx, field) 53 | }) 54 | } 55 | } 56 | return u 57 | } 58 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/gql_edge.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import "context" 6 | 7 | func (t *Todo) Children(ctx context.Context) ([]*Todo, error) { 8 | result, err := t.Edges.ChildrenOrErr() 9 | if IsNotLoaded(err) { 10 | result, err = t.QueryChildren().All(ctx) 11 | } 12 | return result, err 13 | } 14 | 15 | func (t *Todo) Parent(ctx context.Context) (*Todo, error) { 16 | result, err := t.Edges.ParentOrErr() 17 | if IsNotLoaded(err) { 18 | result, err = t.QueryParent().Only(ctx) 19 | } 20 | return result, MaskNotFound(err) 21 | } 22 | 23 | func (t *Todo) Owner(ctx context.Context) (*User, error) { 24 | result, err := t.Edges.OwnerOrErr() 25 | if IsNotLoaded(err) { 26 | result, err = t.QueryOwner().Only(ctx) 27 | } 28 | return result, MaskNotFound(err) 29 | } 30 | 31 | func (u *User) Todos(ctx context.Context) ([]*Todo, error) { 32 | result, err := u.Edges.TodosOrErr() 33 | if IsNotLoaded(err) { 34 | result, err = u.QueryTodos().All(ctx) 35 | } 36 | return result, err 37 | } 38 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/gql_node.go: 
-------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "context" 7 | "encoding/json" 8 | "fmt" 9 | "sync" 10 | "sync/atomic" 11 | "todo/ent/todo" 12 | "todo/ent/user" 13 | 14 | "entgo.io/contrib/entgql" 15 | "entgo.io/ent/dialect" 16 | "entgo.io/ent/dialect/sql" 17 | "entgo.io/ent/dialect/sql/schema" 18 | "github.com/99designs/gqlgen/graphql" 19 | "github.com/hashicorp/go-multierror" 20 | "golang.org/x/sync/semaphore" 21 | ) 22 | 23 | // Noder wraps the basic Node method. 24 | type Noder interface { 25 | Node(context.Context) (*Node, error) 26 | } 27 | 28 | // Node in the graph. 29 | type Node struct { 30 | ID int `json:"id,omitempty"` // node id. 31 | Type string `json:"type,omitempty"` // node type. 32 | Fields []*Field `json:"fields,omitempty"` // node fields. 33 | Edges []*Edge `json:"edges,omitempty"` // node edges. 34 | } 35 | 36 | // Field of a node. 37 | type Field struct { 38 | Type string `json:"type,omitempty"` // field type. 39 | Name string `json:"name,omitempty"` // field name (as in struct). 40 | Value string `json:"value,omitempty"` // stringified value. 41 | } 42 | 43 | // Edges between two nodes. 44 | type Edge struct { 45 | Type string `json:"type,omitempty"` // edge type. 46 | Name string `json:"name,omitempty"` // edge name. 47 | IDs []int `json:"ids,omitempty"` // node ids (where this edge point to). 
48 | } 49 | 50 | func (t *Todo) Node(ctx context.Context) (node *Node, err error) { 51 | node = &Node{ 52 | ID: t.ID, 53 | Type: "Todo", 54 | Fields: make([]*Field, 4), 55 | Edges: make([]*Edge, 3), 56 | } 57 | var buf []byte 58 | if buf, err = json.Marshal(t.Text); err != nil { 59 | return nil, err 60 | } 61 | node.Fields[0] = &Field{ 62 | Type: "string", 63 | Name: "text", 64 | Value: string(buf), 65 | } 66 | if buf, err = json.Marshal(t.CreatedAt); err != nil { 67 | return nil, err 68 | } 69 | node.Fields[1] = &Field{ 70 | Type: "time.Time", 71 | Name: "created_at", 72 | Value: string(buf), 73 | } 74 | if buf, err = json.Marshal(t.Status); err != nil { 75 | return nil, err 76 | } 77 | node.Fields[2] = &Field{ 78 | Type: "todo.Status", 79 | Name: "status", 80 | Value: string(buf), 81 | } 82 | if buf, err = json.Marshal(t.Priority); err != nil { 83 | return nil, err 84 | } 85 | node.Fields[3] = &Field{ 86 | Type: "int", 87 | Name: "priority", 88 | Value: string(buf), 89 | } 90 | node.Edges[0] = &Edge{ 91 | Type: "Todo", 92 | Name: "children", 93 | } 94 | err = t.QueryChildren(). 95 | Select(todo.FieldID). 96 | Scan(ctx, &node.Edges[0].IDs) 97 | if err != nil { 98 | return nil, err 99 | } 100 | node.Edges[1] = &Edge{ 101 | Type: "Todo", 102 | Name: "parent", 103 | } 104 | err = t.QueryParent(). 105 | Select(todo.FieldID). 106 | Scan(ctx, &node.Edges[1].IDs) 107 | if err != nil { 108 | return nil, err 109 | } 110 | node.Edges[2] = &Edge{ 111 | Type: "User", 112 | Name: "owner", 113 | } 114 | err = t.QueryOwner(). 115 | Select(user.FieldID). 
116 | Scan(ctx, &node.Edges[2].IDs) 117 | if err != nil { 118 | return nil, err 119 | } 120 | return node, nil 121 | } 122 | 123 | func (u *User) Node(ctx context.Context) (node *Node, err error) { 124 | node = &Node{ 125 | ID: u.ID, 126 | Type: "User", 127 | Fields: make([]*Field, 1), 128 | Edges: make([]*Edge, 1), 129 | } 130 | var buf []byte 131 | if buf, err = json.Marshal(u.Name); err != nil { 132 | return nil, err 133 | } 134 | node.Fields[0] = &Field{ 135 | Type: "string", 136 | Name: "name", 137 | Value: string(buf), 138 | } 139 | node.Edges[0] = &Edge{ 140 | Type: "Todo", 141 | Name: "todos", 142 | } 143 | err = u.QueryTodos(). 144 | Select(todo.FieldID). 145 | Scan(ctx, &node.Edges[0].IDs) 146 | if err != nil { 147 | return nil, err 148 | } 149 | return node, nil 150 | } 151 | 152 | func (c *Client) Node(ctx context.Context, id int) (*Node, error) { 153 | n, err := c.Noder(ctx, id) 154 | if err != nil { 155 | return nil, err 156 | } 157 | return n.Node(ctx) 158 | } 159 | 160 | var errNodeInvalidID = &NotFoundError{"node"} 161 | 162 | // NodeOption allows configuring the Noder execution using functional options. 163 | type NodeOption func(*nodeOptions) 164 | 165 | // WithNodeType sets the node Type resolver function (i.e. the table to query). 166 | // If was not provided, the table will be derived from the universal-id 167 | // configuration as described in: https://entgo.io/docs/migrate/#universal-ids. 168 | func WithNodeType(f func(context.Context, int) (string, error)) NodeOption { 169 | return func(o *nodeOptions) { 170 | o.nodeType = f 171 | } 172 | } 173 | 174 | // WithFixedNodeType sets the Type of the node to a fixed value. 
func WithFixedNodeType(t string) NodeOption {
	// Wrap the constant in a resolver that ignores both the context and the id.
	return WithNodeType(func(context.Context, int) (string, error) {
		return t, nil
	})
}

// nodeOptions holds the settings collected from the NodeOption values
// passed to Noder / Noders.
type nodeOptions struct {
	// nodeType resolves the type (table name) of a node from its id.
	nodeType func(context.Context, int) (string, error)
}

// newNodeOpts applies opts in order to a fresh nodeOptions. When none of them
// set a nodeType resolver, it falls back to c.tables.nodeType using the
// client's driver.
func (c *Client) newNodeOpts(opts []NodeOption) *nodeOptions {
	nopts := &nodeOptions{}
	for _, opt := range opts {
		opt(nopts)
	}
	if nopts.nodeType == nil {
		nopts.nodeType = func(ctx context.Context, id int) (string, error) {
			return c.tables.nodeType(ctx, c.driver, id)
		}
	}
	return nopts
}

// Noder returns a Node by its id. If the NodeType was not provided, it will
// be derived from the id value according to the universal-id configuration.
//
//	c.Noder(ctx, id)
//	c.Noder(ctx, id, ent.WithNodeType(pet.Table))
//
func (c *Client) Noder(ctx context.Context, id int, opts ...NodeOption) (_ Noder, err error) {
	// The named result lets the deferred func extend a not-found error with
	// entgql.ErrNodeNotFound(id) regardless of which return path produced it.
	defer func() {
		if IsNotFound(err) {
			err = multierror.Append(err, entgql.ErrNodeNotFound(id))
		}
	}()
	table, err := c.newNodeOpts(opts).nodeType(ctx, id)
	if err != nil {
		return nil, err
	}
	return c.noder(ctx, table, id)
}

// noder loads the single entity with the given id from the given table.
// The query is passed through CollectFields with the GraphQL type name before
// it is executed with Only. An unrecognized table yields an error wrapping
// errNodeInvalidID.
func (c *Client) noder(ctx context.Context, table string, id int) (Noder, error) {
	switch table {
	case todo.Table:
		n, err := c.Todo.Query().
			Where(todo.ID(id)).
			CollectFields(ctx, "Todo").
			Only(ctx)
		if err != nil {
			return nil, err
		}
		return n, nil
	case user.Table:
		n, err := c.User.Query().
			Where(user.ID(id)).
			CollectFields(ctx, "User").
			Only(ctx)
		if err != nil {
			return nil, err
		}
		return n, nil
	default:
		return nil, fmt.Errorf("cannot resolve noder from table %q: %w", table, errNodeInvalidID)
	}
}

// Noders returns the nodes for the given ids, one result slot per input id
// (same order, duplicates allowed).
//
// With exactly one id it delegates to Noder and returns that call's error.
// With no ids it returns an empty, non-nil slice. With several ids the
// returned error is always nil: a slot whose lookup failed or found nothing
// is left nil, and its error is reported through graphql.AddError under a
// path context carrying the slot's index.
func (c *Client) Noders(ctx context.Context, ids []int, opts ...NodeOption) ([]Noder, error) {
	switch len(ids) {
	case 1:
		noder, err := c.Noder(ctx, ids[0], opts...)
		if err != nil {
			return nil, err
		}
		return []Noder{noder}, nil
	case 0:
		return []Noder{}, nil
	}

	noders := make([]Noder, len(ids))
	errors := make([]error, len(ids))
	// tables groups the requested ids by their resolved table.
	tables := make(map[string][]int)
	// id2idx maps an id to every position it occupies in ids.
	id2idx := make(map[int][]int, len(ids))
	nopts := c.newNodeOpts(opts)
	for i, id := range ids {
		table, err := nopts.nodeType(ctx, id)
		if err != nil {
			// The type could not be resolved: record the error for this
			// slot and leave the id out of the per-table batches.
			errors[i] = err
			continue
		}
		tables[table] = append(tables[table], id)
		id2idx[id] = append(id2idx[id], i)
	}

	// One batched lookup per table; results (or the batch error) are fanned
	// back out to every slot that asked for the id.
	for table, ids := range tables {
		nodes, err := c.noders(ctx, table, ids)
		if err != nil {
			for _, id := range ids {
				for _, idx := range id2idx[id] {
					errors[idx] = err
				}
			}
		} else {
			for i, id := range ids {
				for _, idx := range id2idx[id] {
					noders[idx] = nodes[i]
				}
			}
		}
	}

	for i, id := range ids {
		if errors[i] == nil {
			if noders[i] != nil {
				continue
			}
			// No error and no node: the batch simply did not return this id.
			errors[i] = entgql.ErrNodeNotFound(id)
		} else if IsNotFound(errors[i]) {
			errors[i] = multierror.Append(errors[i], entgql.ErrNodeNotFound(id))
		}
		// Shadow ctx so the error is attached under this slot's index only.
		ctx := graphql.WithPathContext(ctx,
			graphql.NewPathWithIndex(i),
		)
		graphql.AddError(ctx, errors[i])
	}
	return noders, nil
}

// noders loads the entities with the given ids from the given table in a
// single query. The result has one slot per entry of ids, in the same order;
// slots whose id was not returned by the query stay nil. An unrecognized
// table yields an error wrapping errNodeInvalidID.
func (c *Client) noders(ctx context.Context, table string, ids []int) ([]Noder, error) {
	noders := make([]Noder, len(ids))
	// idmap points from an id to every result slot that should receive the
	// node, so a repeated id fills all of its slots.
	idmap := make(map[int][]*Noder, len(ids))
	for i, id := range ids {
		idmap[id] = append(idmap[id], &noders[i])
	}
	switch table {
	case todo.Table:
		nodes, err := c.Todo.Query().
			Where(todo.IDIn(ids...)).
			CollectFields(ctx, "Todo").
			All(ctx)
		if err != nil {
			return nil, err
		}
		for _, node := range nodes {
			for _, noder := range idmap[node.ID] {
				*noder = node
			}
		}
	case user.Table:
		nodes, err := c.User.Query().
			Where(user.IDIn(ids...)).
			CollectFields(ctx, "User").
			All(ctx)
		if err != nil {
			return nil, err
		}
		for _, node := range nodes {
			for _, noder := range idmap[node.ID] {
				*noder = node
			}
		}
	default:
		return nil, fmt.Errorf("cannot resolve noders from table %q: %w", table, errNodeInvalidID)
	}
	return noders, nil
}

// tables caches the list of type (table) names used to map an id to its table.
type tables struct {
	once  sync.Once           // guards the one-time creation of sem
	sem   *semaphore.Weighted // weight-1 semaphore serializing the load
	value atomic.Value        // holds the loaded []string once available
}

// nodeType returns the table name for the given id by indexing the loaded
// table list with id / (1<<32 - 1). An index outside the list yields an error
// wrapping errNodeInvalidID.
func (t *tables) nodeType(ctx context.Context, drv dialect.Driver, id int) (string, error) {
	tables, err := t.Load(ctx, drv)
	if err != nil {
		return "", err
	}
	// NOTE(review): the migrate package comment says each table gets a 1<<32
	// id range, while the divisor here is 1<<32 - 1; confirm this is intended.
	idx := int(id / (1<<32 - 1))
	if idx < 0 || idx >= len(tables) {
		return "", fmt.Errorf("cannot resolve table from id %v: %w", id, errNodeInvalidID)
	}
	return tables[idx], nil
}

// Load returns the cached table list, loading it through t.load on first use.
// A failed load is not cached, so a later call tries again.
func (t *tables) Load(ctx context.Context, drv dialect.Driver) ([]string, error) {
	// Fast path: a previous call already stored the list.
	if tables := t.value.Load(); tables != nil {
		return tables.([]string), nil
	}
	t.once.Do(func() { t.sem = semaphore.NewWeighted(1) })
	// Acquire takes ctx, so waiting for the semaphore can fail with an error.
	if err := t.sem.Acquire(ctx, 1); err != nil {
		return nil, err
	}
	defer t.sem.Release(1)
	// Re-check: another caller may have stored the list while we waited.
	if tables := t.value.Load(); tables != nil {
		return tables.([]string), nil
	}
	tables, err := t.load(ctx, drv)
	if err == nil {
		t.value.Store(tables)
	}
	return tables, err
}

func (*tables) load(ctx context.Context, drv dialect.Driver) ([]string, error) { 380 | rows := &sql.Rows{} 381 | query, args := sql.Dialect(drv.Dialect()). 382 | Select("type"). 383 | From(sql.Table(schema.TypeTable)). 384 | OrderBy(sql.Asc("id")). 385 | Query() 386 | if err := drv.Query(ctx, query, args, rows); err != nil { 387 | return nil, err 388 | } 389 | defer rows.Close() 390 | var tables []string 391 | return tables, sql.ScanSlice(rows, &tables) 392 | } 393 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/gql_transaction.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "context" 7 | "database/sql/driver" 8 | "errors" 9 | ) 10 | 11 | // OpenTx opens a transaction and returns a transactional 12 | // context along with the created transaction. 13 | func (c *Client) OpenTx(ctx context.Context) (context.Context, driver.Tx, error) { 14 | tx, err := c.Tx(ctx) 15 | if err != nil { 16 | return nil, nil, err 17 | } 18 | ctx = NewTxContext(ctx, tx) 19 | ctx = NewContext(ctx, tx.Client()) 20 | return ctx, tx, nil 21 | } 22 | 23 | // OpenTxFromContext open transactions from client stored in context. 24 | func OpenTxFromContext(ctx context.Context) (context.Context, driver.Tx, error) { 25 | client := FromContext(ctx) 26 | if client == nil { 27 | return nil, nil, errors.New("no client attached to context") 28 | } 29 | return client.OpenTx(ctx) 30 | } 31 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/hook/hook.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package hook 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "todo/ent" 9 | ) 10 | 11 | // The TodoFunc type is an adapter to allow the use of ordinary 12 | // function as Todo mutator. 
13 | type TodoFunc func(context.Context, *ent.TodoMutation) (ent.Value, error) 14 | 15 | // Mutate calls f(ctx, m). 16 | func (f TodoFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { 17 | mv, ok := m.(*ent.TodoMutation) 18 | if !ok { 19 | return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.TodoMutation", m) 20 | } 21 | return f(ctx, mv) 22 | } 23 | 24 | // The UserFunc type is an adapter to allow the use of ordinary 25 | // function as User mutator. 26 | type UserFunc func(context.Context, *ent.UserMutation) (ent.Value, error) 27 | 28 | // Mutate calls f(ctx, m). 29 | func (f UserFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { 30 | mv, ok := m.(*ent.UserMutation) 31 | if !ok { 32 | return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserMutation", m) 33 | } 34 | return f(ctx, mv) 35 | } 36 | 37 | // Condition is a hook condition function. 38 | type Condition func(context.Context, ent.Mutation) bool 39 | 40 | // And groups conditions with the AND operator. 41 | func And(first, second Condition, rest ...Condition) Condition { 42 | return func(ctx context.Context, m ent.Mutation) bool { 43 | if !first(ctx, m) || !second(ctx, m) { 44 | return false 45 | } 46 | for _, cond := range rest { 47 | if !cond(ctx, m) { 48 | return false 49 | } 50 | } 51 | return true 52 | } 53 | } 54 | 55 | // Or groups conditions with the OR operator. 56 | func Or(first, second Condition, rest ...Condition) Condition { 57 | return func(ctx context.Context, m ent.Mutation) bool { 58 | if first(ctx, m) || second(ctx, m) { 59 | return true 60 | } 61 | for _, cond := range rest { 62 | if cond(ctx, m) { 63 | return true 64 | } 65 | } 66 | return false 67 | } 68 | } 69 | 70 | // Not negates a given condition. 71 | func Not(cond Condition) Condition { 72 | return func(ctx context.Context, m ent.Mutation) bool { 73 | return !cond(ctx, m) 74 | } 75 | } 76 | 77 | // HasOp is a condition testing mutation operation. 
78 | func HasOp(op ent.Op) Condition { 79 | return func(_ context.Context, m ent.Mutation) bool { 80 | return m.Op().Is(op) 81 | } 82 | } 83 | 84 | // HasAddedFields is a condition validating `.AddedField` on fields. 85 | func HasAddedFields(field string, fields ...string) Condition { 86 | return func(_ context.Context, m ent.Mutation) bool { 87 | if _, exists := m.AddedField(field); !exists { 88 | return false 89 | } 90 | for _, field := range fields { 91 | if _, exists := m.AddedField(field); !exists { 92 | return false 93 | } 94 | } 95 | return true 96 | } 97 | } 98 | 99 | // HasClearedFields is a condition validating `.FieldCleared` on fields. 100 | func HasClearedFields(field string, fields ...string) Condition { 101 | return func(_ context.Context, m ent.Mutation) bool { 102 | if exists := m.FieldCleared(field); !exists { 103 | return false 104 | } 105 | for _, field := range fields { 106 | if exists := m.FieldCleared(field); !exists { 107 | return false 108 | } 109 | } 110 | return true 111 | } 112 | } 113 | 114 | // HasFields is a condition validating `.Field` on fields. 115 | func HasFields(field string, fields ...string) Condition { 116 | return func(_ context.Context, m ent.Mutation) bool { 117 | if _, exists := m.Field(field); !exists { 118 | return false 119 | } 120 | for _, field := range fields { 121 | if _, exists := m.Field(field); !exists { 122 | return false 123 | } 124 | } 125 | return true 126 | } 127 | } 128 | 129 | // If executes the given hook under condition. 130 | // 131 | // hook.If(ComputeAverage, And(HasFields(...), HasAddedFields(...))) 132 | // 133 | func If(hk ent.Hook, cond Condition) ent.Hook { 134 | return func(next ent.Mutator) ent.Mutator { 135 | return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { 136 | if cond(ctx, m) { 137 | return hk(next).Mutate(ctx, m) 138 | } 139 | return next.Mutate(ctx, m) 140 | }) 141 | } 142 | } 143 | 144 | // On executes the given hook only for the given operation. 
145 | // 146 | // hook.On(Log, ent.Delete|ent.Create) 147 | // 148 | func On(hk ent.Hook, op ent.Op) ent.Hook { 149 | return If(hk, HasOp(op)) 150 | } 151 | 152 | // Unless skips the given hook only for the given operation. 153 | // 154 | // hook.Unless(Log, ent.Update|ent.UpdateOne) 155 | // 156 | func Unless(hk ent.Hook, op ent.Op) ent.Hook { 157 | return If(hk, Not(HasOp(op))) 158 | } 159 | 160 | // FixedError is a hook returning a fixed error. 161 | func FixedError(err error) ent.Hook { 162 | return func(ent.Mutator) ent.Mutator { 163 | return ent.MutateFunc(func(context.Context, ent.Mutation) (ent.Value, error) { 164 | return nil, err 165 | }) 166 | } 167 | } 168 | 169 | // Reject returns a hook that rejects all operations that match op. 170 | // 171 | // func (T) Hooks() []ent.Hook { 172 | // return []ent.Hook{ 173 | // Reject(ent.Delete|ent.Update), 174 | // } 175 | // } 176 | // 177 | func Reject(op ent.Op) ent.Hook { 178 | hk := FixedError(fmt.Errorf("%s operation is not allowed", op)) 179 | return On(hk, op) 180 | } 181 | 182 | // Chain acts as a list of hooks and is effectively immutable. 183 | // Once created, it will always hold the same set of hooks in the same order. 184 | type Chain struct { 185 | hooks []ent.Hook 186 | } 187 | 188 | // NewChain creates a new chain of hooks. 189 | func NewChain(hooks ...ent.Hook) Chain { 190 | return Chain{append([]ent.Hook(nil), hooks...)} 191 | } 192 | 193 | // Hook chains the list of hooks and returns the final hook. 194 | func (c Chain) Hook() ent.Hook { 195 | return func(mutator ent.Mutator) ent.Mutator { 196 | for i := len(c.hooks) - 1; i >= 0; i-- { 197 | mutator = c.hooks[i](mutator) 198 | } 199 | return mutator 200 | } 201 | } 202 | 203 | // Append extends a chain, adding the specified hook 204 | // as the last ones in the mutation flow. 205 | func (c Chain) Append(hooks ...ent.Hook) Chain { 206 | newHooks := make([]ent.Hook, 0, len(c.hooks)+len(hooks)) 207 | newHooks = append(newHooks, c.hooks...) 
208 | newHooks = append(newHooks, hooks...) 209 | return Chain{newHooks} 210 | } 211 | 212 | // Extend extends a chain, adding the specified chain 213 | // as the last ones in the mutation flow. 214 | func (c Chain) Extend(chain Chain) Chain { 215 | return c.Append(chain.hooks...) 216 | } 217 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/migrate/migrate.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package migrate 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "io" 9 | 10 | "entgo.io/ent/dialect" 11 | "entgo.io/ent/dialect/sql/schema" 12 | ) 13 | 14 | var ( 15 | // WithGlobalUniqueID sets the universal ids options to the migration. 16 | // If this option is enabled, ent migration will allocate a 1<<32 range 17 | // for the ids of each entity (table). 18 | // Note that this option cannot be applied on tables that already exist. 19 | WithGlobalUniqueID = schema.WithGlobalUniqueID 20 | // WithDropColumn sets the drop column option to the migration. 21 | // If this option is enabled, ent migration will drop old columns 22 | // that were used for both fields and edges. This defaults to false. 23 | WithDropColumn = schema.WithDropColumn 24 | // WithDropIndex sets the drop index option to the migration. 25 | // If this option is enabled, ent migration will drop old indexes 26 | // that were defined in the schema. This defaults to false. 27 | // Note that unique constraints are defined using `UNIQUE INDEX`, 28 | // and therefore, it's recommended to enable this option to get more 29 | // flexibility in the schema changes. 30 | WithDropIndex = schema.WithDropIndex 31 | // WithFixture sets the foreign-key renaming option to the migration when upgrading 32 | // ent from v0.1.0 (issue-#285). Defaults to false. 33 | WithFixture = schema.WithFixture 34 | // WithForeignKeys enables creating foreign-key in schema DDL. 
This defaults to true. 35 | WithForeignKeys = schema.WithForeignKeys 36 | ) 37 | 38 | // Schema is the API for creating, migrating and dropping a schema. 39 | type Schema struct { 40 | drv dialect.Driver 41 | universalID bool 42 | } 43 | 44 | // NewSchema creates a new schema client. 45 | func NewSchema(drv dialect.Driver) *Schema { return &Schema{drv: drv} } 46 | 47 | // Create creates all schema resources. 48 | func (s *Schema) Create(ctx context.Context, opts ...schema.MigrateOption) error { 49 | migrate, err := schema.NewMigrate(s.drv, opts...) 50 | if err != nil { 51 | return fmt.Errorf("ent/migrate: %w", err) 52 | } 53 | return migrate.Create(ctx, Tables...) 54 | } 55 | 56 | // WriteTo writes the schema changes to w instead of running them against the database. 57 | // 58 | // if err := client.Schema.WriteTo(context.Background(), os.Stdout); err != nil { 59 | // log.Fatal(err) 60 | // } 61 | // 62 | func (s *Schema) WriteTo(ctx context.Context, w io.Writer, opts ...schema.MigrateOption) error { 63 | drv := &schema.WriteDriver{ 64 | Writer: w, 65 | Driver: s.drv, 66 | } 67 | migrate, err := schema.NewMigrate(drv, opts...) 68 | if err != nil { 69 | return fmt.Errorf("ent/migrate: %w", err) 70 | } 71 | return migrate.Create(ctx, Tables...) 72 | } 73 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/migrate/schema.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package migrate 4 | 5 | import ( 6 | "entgo.io/ent/dialect/sql/schema" 7 | "entgo.io/ent/schema/field" 8 | ) 9 | 10 | var ( 11 | // TodosColumns holds the columns for the "todos" table. 
12 | TodosColumns = []*schema.Column{ 13 | {Name: "id", Type: field.TypeInt, Increment: true}, 14 | {Name: "text", Type: field.TypeString, Size: 2147483647}, 15 | {Name: "created_at", Type: field.TypeTime}, 16 | {Name: "status", Type: field.TypeEnum, Enums: []string{"IN_PROGRESS", "COMPLETED"}, Default: "IN_PROGRESS"}, 17 | {Name: "priority", Type: field.TypeInt, Default: 0}, 18 | {Name: "todo_parent", Type: field.TypeInt, Nullable: true}, 19 | {Name: "user_todos", Type: field.TypeInt, Nullable: true}, 20 | } 21 | // TodosTable holds the schema information for the "todos" table. 22 | TodosTable = &schema.Table{ 23 | Name: "todos", 24 | Columns: TodosColumns, 25 | PrimaryKey: []*schema.Column{TodosColumns[0]}, 26 | ForeignKeys: []*schema.ForeignKey{ 27 | { 28 | Symbol: "todos_todos_parent", 29 | Columns: []*schema.Column{TodosColumns[5]}, 30 | RefColumns: []*schema.Column{TodosColumns[0]}, 31 | OnDelete: schema.SetNull, 32 | }, 33 | { 34 | Symbol: "todos_users_todos", 35 | Columns: []*schema.Column{TodosColumns[6]}, 36 | RefColumns: []*schema.Column{UsersColumns[0]}, 37 | OnDelete: schema.SetNull, 38 | }, 39 | }, 40 | } 41 | // UsersColumns holds the columns for the "users" table. 42 | UsersColumns = []*schema.Column{ 43 | {Name: "id", Type: field.TypeInt, Increment: true}, 44 | {Name: "name", Type: field.TypeString}, 45 | } 46 | // UsersTable holds the schema information for the "users" table. 47 | UsersTable = &schema.Table{ 48 | Name: "users", 49 | Columns: UsersColumns, 50 | PrimaryKey: []*schema.Column{UsersColumns[0]}, 51 | } 52 | // Tables holds all the tables in the schema. 
53 | Tables = []*schema.Table{ 54 | TodosTable, 55 | UsersTable, 56 | } 57 | ) 58 | 59 | func init() { 60 | TodosTable.ForeignKeys[0].RefTable = TodosTable 61 | TodosTable.ForeignKeys[1].RefTable = UsersTable 62 | } 63 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/mutation_input.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "time" 7 | "todo/ent/todo" 8 | ) 9 | 10 | // CreateTodoInput represents a mutation input for creating todos. 11 | type CreateTodoInput struct { 12 | Text string 13 | CreatedAt *time.Time 14 | Status *todo.Status 15 | Priority *int 16 | Children []int 17 | Parent *int 18 | Owner *int 19 | } 20 | 21 | // Mutate applies the CreateTodoInput on the TodoCreate builder. 22 | func (i *CreateTodoInput) Mutate(m *TodoCreate) { 23 | m.SetText(i.Text) 24 | if v := i.CreatedAt; v != nil { 25 | m.SetCreatedAt(*v) 26 | } 27 | if v := i.Status; v != nil { 28 | m.SetStatus(*v) 29 | } 30 | if v := i.Priority; v != nil { 31 | m.SetPriority(*v) 32 | } 33 | if ids := i.Children; len(ids) > 0 { 34 | m.AddChildIDs(ids...) 35 | } 36 | if v := i.Parent; v != nil { 37 | m.SetParentID(*v) 38 | } 39 | if v := i.Owner; v != nil { 40 | m.SetOwnerID(*v) 41 | } 42 | } 43 | 44 | // SetInput applies the change-set in the CreateTodoInput on the create builder. 45 | func (c *TodoCreate) SetInput(i CreateTodoInput) *TodoCreate { 46 | i.Mutate(c) 47 | return c 48 | } 49 | 50 | // UpdateTodoInput represents a mutation input for updating todos. 51 | type UpdateTodoInput struct { 52 | Text *string 53 | Status *todo.Status 54 | Priority *int 55 | AddChildIDs []int 56 | RemoveChildIDs []int 57 | Parent *int 58 | ClearParent bool 59 | Owner *int 60 | ClearOwner bool 61 | } 62 | 63 | // Mutate applies the UpdateTodoInput on the TodoMutation. 
64 | func (i *UpdateTodoInput) Mutate(m *TodoMutation) { 65 | if v := i.Text; v != nil { 66 | m.SetText(*v) 67 | } 68 | if v := i.Status; v != nil { 69 | m.SetStatus(*v) 70 | } 71 | if v := i.Priority; v != nil { 72 | m.SetPriority(*v) 73 | } 74 | if ids := i.AddChildIDs; len(ids) > 0 { 75 | m.AddChildIDs(ids...) 76 | } 77 | if ids := i.RemoveChildIDs; len(ids) > 0 { 78 | m.RemoveChildIDs(ids...) 79 | } 80 | if i.ClearParent { 81 | m.ClearParent() 82 | } 83 | if v := i.Parent; v != nil { 84 | m.SetParentID(*v) 85 | } 86 | if i.ClearOwner { 87 | m.ClearOwner() 88 | } 89 | if v := i.Owner; v != nil { 90 | m.SetOwnerID(*v) 91 | } 92 | } 93 | 94 | // SetInput applies the change-set in the UpdateTodoInput on the update builder. 95 | func (u *TodoUpdate) SetInput(i UpdateTodoInput) *TodoUpdate { 96 | i.Mutate(u.Mutation()) 97 | return u 98 | } 99 | 100 | // SetInput applies the change-set in the UpdateTodoInput on the update-one builder. 101 | func (u *TodoUpdateOne) SetInput(i UpdateTodoInput) *TodoUpdateOne { 102 | i.Mutate(u.Mutation()) 103 | return u 104 | } 105 | 106 | // CreateUserInput represents a mutation input for creating users. 107 | type CreateUserInput struct { 108 | Name string 109 | Todos []int 110 | } 111 | 112 | // Mutate applies the CreateUserInput on the UserCreate builder. 113 | func (i *CreateUserInput) Mutate(m *UserCreate) { 114 | m.SetName(i.Name) 115 | if ids := i.Todos; len(ids) > 0 { 116 | m.AddTodoIDs(ids...) 117 | } 118 | } 119 | 120 | // SetInput applies the change-set in the CreateUserInput on the create builder. 121 | func (c *UserCreate) SetInput(i CreateUserInput) *UserCreate { 122 | i.Mutate(c) 123 | return c 124 | } 125 | 126 | // UpdateUserInput represents a mutation input for updating users. 127 | type UpdateUserInput struct { 128 | Name *string 129 | AddTodoIDs []int 130 | RemoveTodoIDs []int 131 | } 132 | 133 | // Mutate applies the UpdateUserInput on the UserMutation. 
134 | func (i *UpdateUserInput) Mutate(m *UserMutation) { 135 | if v := i.Name; v != nil { 136 | m.SetName(*v) 137 | } 138 | if ids := i.AddTodoIDs; len(ids) > 0 { 139 | m.AddTodoIDs(ids...) 140 | } 141 | if ids := i.RemoveTodoIDs; len(ids) > 0 { 142 | m.RemoveTodoIDs(ids...) 143 | } 144 | } 145 | 146 | // SetInput applies the change-set in the UpdateUserInput on the update builder. 147 | func (u *UserUpdate) SetInput(i UpdateUserInput) *UserUpdate { 148 | i.Mutate(u.Mutation()) 149 | return u 150 | } 151 | 152 | // SetInput applies the change-set in the UpdateUserInput on the update-one builder. 153 | func (u *UserUpdateOne) SetInput(i UpdateUserInput) *UserUpdateOne { 154 | i.Mutate(u.Mutation()) 155 | return u 156 | } 157 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/predicate/predicate.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package predicate 4 | 5 | import ( 6 | "entgo.io/ent/dialect/sql" 7 | ) 8 | 9 | // Todo is the predicate function for todo builders. 10 | type Todo func(*sql.Selector) 11 | 12 | // User is the predicate function for user builders. 13 | type User func(*sql.Selector) 14 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/runtime.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "time" 7 | "todo/ent/schema" 8 | "todo/ent/todo" 9 | ) 10 | 11 | // The init function reads all schema descriptors with runtime code 12 | // (default values, validators, hooks and policies) and stitches it 13 | // to their package variables. 14 | func init() { 15 | todoFields := schema.Todo{}.Fields() 16 | _ = todoFields 17 | // todoDescText is the schema descriptor for text field. 
18 | todoDescText := todoFields[0].Descriptor() 19 | // todo.TextValidator is a validator for the "text" field. It is called by the builders before save. 20 | todo.TextValidator = todoDescText.Validators[0].(func(string) error) 21 | // todoDescCreatedAt is the schema descriptor for created_at field. 22 | todoDescCreatedAt := todoFields[1].Descriptor() 23 | // todo.DefaultCreatedAt holds the default value on creation for the created_at field. 24 | todo.DefaultCreatedAt = todoDescCreatedAt.Default.(func() time.Time) 25 | // todoDescPriority is the schema descriptor for priority field. 26 | todoDescPriority := todoFields[3].Descriptor() 27 | // todo.DefaultPriority holds the default value on creation for the priority field. 28 | todo.DefaultPriority = todoDescPriority.Default.(int) 29 | } 30 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/runtime/runtime.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package runtime 4 | 5 | // The schema-stitching logic is generated in todo/ent/runtime.go 6 | 7 | const ( 8 | Version = "(devel)" // Version of ent codegen. 9 | ) 10 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/schema/todo.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import ( 4 | "time" 5 | 6 | "entgo.io/contrib/entgql" 7 | "entgo.io/ent/schema/edge" 8 | 9 | "entgo.io/ent" 10 | "entgo.io/ent/schema/field" 11 | ) 12 | 13 | // Todo holds the schema definition for the Todo entity. 14 | type Todo struct { 15 | ent.Schema 16 | } 17 | 18 | // Fields of the Todo. 19 | func (Todo) Fields() []ent.Field { 20 | return []ent.Field{ 21 | field.Text("text"). 22 | NotEmpty(). 23 | Annotations( 24 | entgql.OrderField("TEXT"), 25 | ), 26 | field.Time("created_at"). 27 | Default(time.Now). 28 | Immutable(). 
29 | Annotations( 30 | entgql.OrderField("CREATED_AT"), 31 | ), 32 | field.Enum("status"). 33 | NamedValues( 34 | "InProgress", "IN_PROGRESS", 35 | "Completed", "COMPLETED", 36 | ). 37 | Default("IN_PROGRESS"). 38 | Annotations( 39 | entgql.OrderField("STATUS"), 40 | ), 41 | field.Int("priority"). 42 | Default(0). 43 | Annotations( 44 | entgql.OrderField("PRIORITY"), 45 | ), 46 | } 47 | } 48 | 49 | // Edges of the Todo. 50 | func (Todo) Edges() []ent.Edge { 51 | return []ent.Edge{ 52 | edge.To("parent", Todo.Type). 53 | Annotations(entgql.Bind()). 54 | Unique(). 55 | From("children"). 56 | Annotations(entgql.Bind()), 57 | edge.From("owner", User.Type). 58 | Ref("todos"). 59 | Unique(). 60 | Annotations(entgql.Bind()), 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/schema/user.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import ( 4 | "entgo.io/contrib/entgql" 5 | "entgo.io/ent" 6 | "entgo.io/ent/schema/edge" 7 | "entgo.io/ent/schema/field" 8 | ) 9 | 10 | // User holds the schema definition for the User entity. 11 | type User struct { 12 | ent.Schema 13 | } 14 | 15 | // Fields of the User. 16 | func (User) Fields() []ent.Field { 17 | return []ent.Field{ 18 | field.String("name"), 19 | } 20 | } 21 | 22 | // Edges of the User. 23 | func (User) Edges() []ent.Edge { 24 | return []ent.Edge{ 25 | edge.To("todos", Todo.Type). 26 | Annotations(entgql.Bind()), 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/template/mutation_input.tmpl: -------------------------------------------------------------------------------- 1 | {{ define "mutation_input" }} 2 | 3 | {{- /*gotype: entgo.io/ent/entc/gen.Graph*/ -}} 4 | 5 | {{ $pkg := base $.Config.Package }} 6 | {{- with extend $ "Package" $pkg }} 7 | {{ template "header" . 
}} 8 | {{- end }} 9 | 10 | {{ template "import" $ }} 11 | 12 | {{- range $n := $.Nodes }} 13 | {{ $input := print "Create" $n.Name "Input" }} 14 | // {{ $input }} represents a mutation input for creating {{ plural $n.Name | lower }}. 15 | type {{ $input }} struct { 16 | {{- range $f := $n.Fields }} 17 | {{- if not $f.IsEdgeField }} 18 | {{ $f.StructField }} {{ if and (or $f.Optional $f.Default) (not $f.Type.RType.IsPtr) }}*{{ end }}{{ $f.Type }} 19 | {{- end }} 20 | {{- end }} 21 | {{- range $e := $n.Edges }} 22 | {{- if $e.Unique }} 23 | {{ $e.StructField }} {{ if $e.Optional }}*{{ end }}{{ $e.Type.ID.Type }} 24 | {{- else }} 25 | {{ $e.StructField }} []{{ $e.Type.ID.Type }} 26 | {{- end }} 27 | {{- end }} 28 | } 29 | 30 | // Mutate applies the {{ $input }} on the {{ $n.CreateName }} builder. 31 | func (i *{{ $input }}) Mutate(m *{{ $n.CreateName }}) { 32 | {{- range $f := $n.Fields }} 33 | {{- if not $f.IsEdgeField }} 34 | {{- if or $f.Optional $f.Default }} 35 | if v := i.{{ $f.StructField }}; v != nil { 36 | m.{{ $f.MutationSet }}(*v) 37 | } 38 | {{- else }} 39 | m.{{ $f.MutationSet }}(i.{{ $f.StructField }}) 40 | {{- end }} 41 | {{- end }} 42 | {{- end }} 43 | {{- range $e := $n.Edges }} 44 | {{- if $e.Unique }} 45 | {{- if $e.Optional }} 46 | if v := i.{{ $e.StructField }}; v != nil { 47 | m.{{ $e.MutationSet }}(*v) 48 | } 49 | {{- else }} 50 | m.{{ $e.MutationSet }}(i.{{ $e.StructField }}) 51 | {{- end }} 52 | {{- else }} 53 | if ids := i.{{ $e.StructField }}; len(ids) > 0 { 54 | m.{{ $e.MutationAdd }}(ids...) 55 | } 56 | {{- end }} 57 | {{- end }} 58 | } 59 | 60 | // SetInput applies the change-set in the {{ $input }} on the create builder. 61 | func(c *{{ $n.CreateName }}) SetInput(i {{ $input }}) *{{ $n.CreateName }} { 62 | i.Mutate(c) 63 | return c 64 | } 65 | 66 | {{ $input = print "Update" $n.Name "Input" }} 67 | // {{ $input }} represents a mutation input for updating {{ plural $n.Name | lower }}. 
68 | type {{ $input }} struct { 69 | {{- range $f := $n.MutableFields }} 70 | {{ $f.StructField }} {{ if not $f.Type.RType.IsPtr }}*{{ end }}{{ $f.Type }} 71 | {{- if $f.Optional }} 72 | {{ print "Clear" $f.StructField }} bool 73 | {{- end }} 74 | {{- end }} 75 | {{- range $e := $n.Edges }} 76 | {{- if $e.Unique }} 77 | {{ $e.StructField }} *{{ $e.Type.ID.Type }} 78 | {{ $e.MutationClear }} bool 79 | {{- else }} 80 | {{ $e.MutationAdd }} []{{ $e.Type.ID.Type }} 81 | {{ $e.MutationRemove }} []{{ $e.Type.ID.Type }} 82 | {{- end }} 83 | {{- end }} 84 | } 85 | 86 | // Mutate applies the {{ $input }} on the {{ $n.MutationName }}. 87 | func (i *{{ $input }}) Mutate(m *{{ $n.MutationName }}) { 88 | {{- range $f := $n.MutableFields }} 89 | {{- if $f.Optional }} 90 | if i.{{ print "Clear" $f.StructField }} { 91 | m.{{ print "Clear" $f.StructField }}() 92 | } 93 | {{- end }} 94 | if v := i.{{ $f.StructField }}; v != nil { 95 | m.{{ $f.MutationSet }}(*v) 96 | } 97 | {{- end }} 98 | {{- range $e := $n.Edges }} 99 | {{- if $e.Unique }} 100 | if i.{{ $e.MutationClear }} { 101 | m.{{ $e.MutationClear }}() 102 | } 103 | if v := i.{{ $e.StructField }}; v != nil { 104 | m.{{ $e.MutationSet }}(*v) 105 | } 106 | {{- else }} 107 | if ids := i.{{ $e.MutationAdd }}; len(ids) > 0 { 108 | m.{{ $e.MutationAdd }}(ids...) 109 | } 110 | if ids := i.{{ $e.MutationRemove }}; len(ids) > 0 { 111 | m.{{ $e.MutationRemove }}(ids...) 112 | } 113 | {{- end }} 114 | {{- end }} 115 | } 116 | 117 | // SetInput applies the change-set in the {{ $input }} on the update builder. 118 | func(u *{{ $n.UpdateName }}) SetInput(i {{ $input }}) *{{ $n.UpdateName }} { 119 | i.Mutate(u.Mutation()) 120 | return u 121 | } 122 | 123 | // SetInput applies the change-set in the {{ $input }} on the update-one builder. 
124 | func(u *{{ $n.UpdateOneName }}) SetInput(i {{ $input }}) *{{ $n.UpdateOneName }} { 125 | i.Mutate(u.Mutation()) 126 | return u 127 | } 128 | {{- end }} 129 | {{ end }} 130 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/todo.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "fmt" 7 | "strings" 8 | "time" 9 | "todo/ent/todo" 10 | "todo/ent/user" 11 | 12 | "entgo.io/ent/dialect/sql" 13 | ) 14 | 15 | // Todo is the model entity for the Todo schema. 16 | type Todo struct { 17 | config `json:"-"` 18 | // ID of the ent. 19 | ID int `json:"id,omitempty"` 20 | // Text holds the value of the "text" field. 21 | Text string `json:"text,omitempty"` 22 | // CreatedAt holds the value of the "created_at" field. 23 | CreatedAt time.Time `json:"created_at,omitempty"` 24 | // Status holds the value of the "status" field. 25 | Status todo.Status `json:"status,omitempty"` 26 | // Priority holds the value of the "priority" field. 27 | Priority int `json:"priority,omitempty"` 28 | // Edges holds the relations/edges for other nodes in the graph. 29 | // The values are being populated by the TodoQuery when eager-loading is set. 30 | Edges TodoEdges `json:"edges"` 31 | todo_parent *int 32 | user_todos *int 33 | } 34 | 35 | // TodoEdges holds the relations/edges for other nodes in the graph. 36 | type TodoEdges struct { 37 | // Children holds the value of the children edge. 38 | Children []*Todo `json:"children,omitempty"` 39 | // Parent holds the value of the parent edge. 40 | Parent *Todo `json:"parent,omitempty"` 41 | // Owner holds the value of the owner edge. 42 | Owner *User `json:"owner,omitempty"` 43 | // loadedTypes holds the information for reporting if a 44 | // type was loaded (or requested) in eager-loading or not. 
45 | loadedTypes [3]bool 46 | } 47 | 48 | // ChildrenOrErr returns the Children value or an error if the edge 49 | // was not loaded in eager-loading. 50 | func (e TodoEdges) ChildrenOrErr() ([]*Todo, error) { 51 | if e.loadedTypes[0] { 52 | return e.Children, nil 53 | } 54 | return nil, &NotLoadedError{edge: "children"} 55 | } 56 | 57 | // ParentOrErr returns the Parent value or an error if the edge 58 | // was not loaded in eager-loading, or loaded but was not found. 59 | func (e TodoEdges) ParentOrErr() (*Todo, error) { 60 | if e.loadedTypes[1] { 61 | if e.Parent == nil { 62 | // The edge parent was loaded in eager-loading, 63 | // but was not found. 64 | return nil, &NotFoundError{label: todo.Label} 65 | } 66 | return e.Parent, nil 67 | } 68 | return nil, &NotLoadedError{edge: "parent"} 69 | } 70 | 71 | // OwnerOrErr returns the Owner value or an error if the edge 72 | // was not loaded in eager-loading, or loaded but was not found. 73 | func (e TodoEdges) OwnerOrErr() (*User, error) { 74 | if e.loadedTypes[2] { 75 | if e.Owner == nil { 76 | // The edge owner was loaded in eager-loading, 77 | // but was not found. 78 | return nil, &NotFoundError{label: user.Label} 79 | } 80 | return e.Owner, nil 81 | } 82 | return nil, &NotLoadedError{edge: "owner"} 83 | } 84 | 85 | // scanValues returns the types for scanning values from sql.Rows. 
86 | func (*Todo) scanValues(columns []string) ([]any, error) { 87 | values := make([]any, len(columns)) 88 | for i := range columns { 89 | switch columns[i] { 90 | case todo.FieldID, todo.FieldPriority: 91 | values[i] = new(sql.NullInt64) 92 | case todo.FieldText, todo.FieldStatus: 93 | values[i] = new(sql.NullString) 94 | case todo.FieldCreatedAt: 95 | values[i] = new(sql.NullTime) 96 | case todo.ForeignKeys[0]: // todo_parent 97 | values[i] = new(sql.NullInt64) 98 | case todo.ForeignKeys[1]: // user_todos 99 | values[i] = new(sql.NullInt64) 100 | default: 101 | return nil, fmt.Errorf("unexpected column %q for type Todo", columns[i]) 102 | } 103 | } 104 | return values, nil 105 | } 106 | 107 | // assignValues assigns the values that were returned from sql.Rows (after scanning) 108 | // to the Todo fields. 109 | func (t *Todo) assignValues(columns []string, values []any) error { 110 | if m, n := len(values), len(columns); m < n { 111 | return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) 112 | } 113 | for i := range columns { 114 | switch columns[i] { 115 | case todo.FieldID: 116 | value, ok := values[i].(*sql.NullInt64) 117 | if !ok { 118 | return fmt.Errorf("unexpected type %T for field id", value) 119 | } 120 | t.ID = int(value.Int64) 121 | case todo.FieldText: 122 | if value, ok := values[i].(*sql.NullString); !ok { 123 | return fmt.Errorf("unexpected type %T for field text", values[i]) 124 | } else if value.Valid { 125 | t.Text = value.String 126 | } 127 | case todo.FieldCreatedAt: 128 | if value, ok := values[i].(*sql.NullTime); !ok { 129 | return fmt.Errorf("unexpected type %T for field created_at", values[i]) 130 | } else if value.Valid { 131 | t.CreatedAt = value.Time 132 | } 133 | case todo.FieldStatus: 134 | if value, ok := values[i].(*sql.NullString); !ok { 135 | return fmt.Errorf("unexpected type %T for field status", values[i]) 136 | } else if value.Valid { 137 | t.Status = todo.Status(value.String) 138 | } 139 | case 
todo.FieldPriority: 140 | if value, ok := values[i].(*sql.NullInt64); !ok { 141 | return fmt.Errorf("unexpected type %T for field priority", values[i]) 142 | } else if value.Valid { 143 | t.Priority = int(value.Int64) 144 | } 145 | case todo.ForeignKeys[0]: 146 | if value, ok := values[i].(*sql.NullInt64); !ok { 147 | return fmt.Errorf("unexpected type %T for edge-field todo_parent", value) 148 | } else if value.Valid { 149 | t.todo_parent = new(int) 150 | *t.todo_parent = int(value.Int64) 151 | } 152 | case todo.ForeignKeys[1]: 153 | if value, ok := values[i].(*sql.NullInt64); !ok { 154 | return fmt.Errorf("unexpected type %T for edge-field user_todos", value) 155 | } else if value.Valid { 156 | t.user_todos = new(int) 157 | *t.user_todos = int(value.Int64) 158 | } 159 | } 160 | } 161 | return nil 162 | } 163 | 164 | // QueryChildren queries the "children" edge of the Todo entity. 165 | func (t *Todo) QueryChildren() *TodoQuery { 166 | return (&TodoClient{config: t.config}).QueryChildren(t) 167 | } 168 | 169 | // QueryParent queries the "parent" edge of the Todo entity. 170 | func (t *Todo) QueryParent() *TodoQuery { 171 | return (&TodoClient{config: t.config}).QueryParent(t) 172 | } 173 | 174 | // QueryOwner queries the "owner" edge of the Todo entity. 175 | func (t *Todo) QueryOwner() *UserQuery { 176 | return (&TodoClient{config: t.config}).QueryOwner(t) 177 | } 178 | 179 | // Update returns a builder for updating this Todo. 180 | // Note that you need to call Todo.Unwrap() before calling this method if this Todo 181 | // was returned from a transaction, and the transaction was committed or rolled back. 182 | func (t *Todo) Update() *TodoUpdateOne { 183 | return (&TodoClient{config: t.config}).UpdateOne(t) 184 | } 185 | 186 | // Unwrap unwraps the Todo entity that was returned from a transaction after it was closed, 187 | // so that all future queries will be executed through the driver which created the transaction. 
188 | func (t *Todo) Unwrap() *Todo { 189 | tx, ok := t.config.driver.(*txDriver) 190 | if !ok { 191 | panic("ent: Todo is not a transactional entity") 192 | } 193 | t.config.driver = tx.drv 194 | return t 195 | } 196 | 197 | // String implements the fmt.Stringer. 198 | func (t *Todo) String() string { 199 | var builder strings.Builder 200 | builder.WriteString("Todo(") 201 | builder.WriteString(fmt.Sprintf("id=%v", t.ID)) 202 | builder.WriteString(", text=") 203 | builder.WriteString(t.Text) 204 | builder.WriteString(", created_at=") 205 | builder.WriteString(t.CreatedAt.Format(time.ANSIC)) 206 | builder.WriteString(", status=") 207 | builder.WriteString(fmt.Sprintf("%v", t.Status)) 208 | builder.WriteString(", priority=") 209 | builder.WriteString(fmt.Sprintf("%v", t.Priority)) 210 | builder.WriteByte(')') 211 | return builder.String() 212 | } 213 | 214 | // Todos is a parsable slice of Todo. 215 | type Todos []*Todo 216 | 217 | func (t Todos) config(cfg config) { 218 | for _i := range t { 219 | t[_i].config = cfg 220 | } 221 | } 222 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/todo/todo.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package todo 4 | 5 | import ( 6 | "fmt" 7 | "io" 8 | "strconv" 9 | "time" 10 | ) 11 | 12 | const ( 13 | // Label holds the string label denoting the todo type in the database. 14 | Label = "todo" 15 | // FieldID holds the string denoting the id field in the database. 16 | FieldID = "id" 17 | // FieldText holds the string denoting the text field in the database. 18 | FieldText = "text" 19 | // FieldCreatedAt holds the string denoting the created_at field in the database. 20 | FieldCreatedAt = "created_at" 21 | // FieldStatus holds the string denoting the status field in the database. 
22 | FieldStatus = "status" 23 | // FieldPriority holds the string denoting the priority field in the database. 24 | FieldPriority = "priority" 25 | // EdgeChildren holds the string denoting the children edge name in mutations. 26 | EdgeChildren = "children" 27 | // EdgeParent holds the string denoting the parent edge name in mutations. 28 | EdgeParent = "parent" 29 | // EdgeOwner holds the string denoting the owner edge name in mutations. 30 | EdgeOwner = "owner" 31 | // Table holds the table name of the todo in the database. 32 | Table = "todos" 33 | // ChildrenTable is the table that holds the children relation/edge. 34 | ChildrenTable = "todos" 35 | // ChildrenColumn is the table column denoting the children relation/edge. 36 | ChildrenColumn = "todo_parent" 37 | // ParentTable is the table that holds the parent relation/edge. 38 | ParentTable = "todos" 39 | // ParentColumn is the table column denoting the parent relation/edge. 40 | ParentColumn = "todo_parent" 41 | // OwnerTable is the table that holds the owner relation/edge. 42 | OwnerTable = "todos" 43 | // OwnerInverseTable is the table name for the User entity. 44 | // It exists in this package in order to avoid circular dependency with the "user" package. 45 | OwnerInverseTable = "users" 46 | // OwnerColumn is the table column denoting the owner relation/edge. 47 | OwnerColumn = "user_todos" 48 | ) 49 | 50 | // Columns holds all SQL columns for todo fields. 51 | var Columns = []string{ 52 | FieldID, 53 | FieldText, 54 | FieldCreatedAt, 55 | FieldStatus, 56 | FieldPriority, 57 | } 58 | 59 | // ForeignKeys holds the SQL foreign-keys that are owned by the "todos" 60 | // table and are not defined as standalone fields in the schema. 61 | var ForeignKeys = []string{ 62 | "todo_parent", 63 | "user_todos", 64 | } 65 | 66 | // ValidColumn reports if the column name is valid (part of the table columns). 
67 | func ValidColumn(column string) bool { 68 | for i := range Columns { 69 | if column == Columns[i] { 70 | return true 71 | } 72 | } 73 | for i := range ForeignKeys { 74 | if column == ForeignKeys[i] { 75 | return true 76 | } 77 | } 78 | return false 79 | } 80 | 81 | var ( 82 | // TextValidator is a validator for the "text" field. It is called by the builders before save. 83 | TextValidator func(string) error 84 | // DefaultCreatedAt holds the default value on creation for the "created_at" field. 85 | DefaultCreatedAt func() time.Time 86 | // DefaultPriority holds the default value on creation for the "priority" field. 87 | DefaultPriority int 88 | ) 89 | 90 | // Status defines the type for the "status" enum field. 91 | type Status string 92 | 93 | // StatusInProgress is the default value of the Status enum. 94 | const DefaultStatus = StatusInProgress 95 | 96 | // Status values. 97 | const ( 98 | StatusInProgress Status = "IN_PROGRESS" 99 | StatusCompleted Status = "COMPLETED" 100 | ) 101 | 102 | func (s Status) String() string { 103 | return string(s) 104 | } 105 | 106 | // StatusValidator is a validator for the "status" field enum values. It is called by the builders before save. 107 | func StatusValidator(s Status) error { 108 | switch s { 109 | case StatusInProgress, StatusCompleted: 110 | return nil 111 | default: 112 | return fmt.Errorf("todo: invalid enum value for status field: %q", s) 113 | } 114 | } 115 | 116 | // MarshalGQL implements graphql.Marshaler interface. 117 | func (s Status) MarshalGQL(w io.Writer) { 118 | io.WriteString(w, strconv.Quote(s.String())) 119 | } 120 | 121 | // UnmarshalGQL implements graphql.Unmarshaler interface. 
122 | func (s *Status) UnmarshalGQL(val any) error { 123 | str, ok := val.(string) 124 | if !ok { 125 | return fmt.Errorf("enum %T must be a string", val) 126 | } 127 | *s = Status(str) 128 | if err := StatusValidator(*s); err != nil { 129 | return fmt.Errorf("%s is not a valid Status", str) 130 | } 131 | return nil 132 | } 133 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/todo_create.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "context" 7 | "errors" 8 | "fmt" 9 | "time" 10 | "todo/ent/todo" 11 | "todo/ent/user" 12 | 13 | "entgo.io/ent/dialect/sql/sqlgraph" 14 | "entgo.io/ent/schema/field" 15 | ) 16 | 17 | // TodoCreate is the builder for creating a Todo entity. 18 | type TodoCreate struct { 19 | config 20 | mutation *TodoMutation 21 | hooks []Hook 22 | } 23 | 24 | // SetText sets the "text" field. 25 | func (tc *TodoCreate) SetText(s string) *TodoCreate { 26 | tc.mutation.SetText(s) 27 | return tc 28 | } 29 | 30 | // SetCreatedAt sets the "created_at" field. 31 | func (tc *TodoCreate) SetCreatedAt(t time.Time) *TodoCreate { 32 | tc.mutation.SetCreatedAt(t) 33 | return tc 34 | } 35 | 36 | // SetNillableCreatedAt sets the "created_at" field if the given value is not nil. 37 | func (tc *TodoCreate) SetNillableCreatedAt(t *time.Time) *TodoCreate { 38 | if t != nil { 39 | tc.SetCreatedAt(*t) 40 | } 41 | return tc 42 | } 43 | 44 | // SetStatus sets the "status" field. 45 | func (tc *TodoCreate) SetStatus(t todo.Status) *TodoCreate { 46 | tc.mutation.SetStatus(t) 47 | return tc 48 | } 49 | 50 | // SetNillableStatus sets the "status" field if the given value is not nil. 51 | func (tc *TodoCreate) SetNillableStatus(t *todo.Status) *TodoCreate { 52 | if t != nil { 53 | tc.SetStatus(*t) 54 | } 55 | return tc 56 | } 57 | 58 | // SetPriority sets the "priority" field. 
59 | func (tc *TodoCreate) SetPriority(i int) *TodoCreate { 60 | tc.mutation.SetPriority(i) 61 | return tc 62 | } 63 | 64 | // SetNillablePriority sets the "priority" field if the given value is not nil. 65 | func (tc *TodoCreate) SetNillablePriority(i *int) *TodoCreate { 66 | if i != nil { 67 | tc.SetPriority(*i) 68 | } 69 | return tc 70 | } 71 | 72 | // AddChildIDs adds the "children" edge to the Todo entity by IDs. 73 | func (tc *TodoCreate) AddChildIDs(ids ...int) *TodoCreate { 74 | tc.mutation.AddChildIDs(ids...) 75 | return tc 76 | } 77 | 78 | // AddChildren adds the "children" edges to the Todo entity. 79 | func (tc *TodoCreate) AddChildren(t ...*Todo) *TodoCreate { 80 | ids := make([]int, len(t)) 81 | for i := range t { 82 | ids[i] = t[i].ID 83 | } 84 | return tc.AddChildIDs(ids...) 85 | } 86 | 87 | // SetParentID sets the "parent" edge to the Todo entity by ID. 88 | func (tc *TodoCreate) SetParentID(id int) *TodoCreate { 89 | tc.mutation.SetParentID(id) 90 | return tc 91 | } 92 | 93 | // SetNillableParentID sets the "parent" edge to the Todo entity by ID if the given value is not nil. 94 | func (tc *TodoCreate) SetNillableParentID(id *int) *TodoCreate { 95 | if id != nil { 96 | tc = tc.SetParentID(*id) 97 | } 98 | return tc 99 | } 100 | 101 | // SetParent sets the "parent" edge to the Todo entity. 102 | func (tc *TodoCreate) SetParent(t *Todo) *TodoCreate { 103 | return tc.SetParentID(t.ID) 104 | } 105 | 106 | // SetOwnerID sets the "owner" edge to the User entity by ID. 107 | func (tc *TodoCreate) SetOwnerID(id int) *TodoCreate { 108 | tc.mutation.SetOwnerID(id) 109 | return tc 110 | } 111 | 112 | // SetNillableOwnerID sets the "owner" edge to the User entity by ID if the given value is not nil. 113 | func (tc *TodoCreate) SetNillableOwnerID(id *int) *TodoCreate { 114 | if id != nil { 115 | tc = tc.SetOwnerID(*id) 116 | } 117 | return tc 118 | } 119 | 120 | // SetOwner sets the "owner" edge to the User entity. 
121 | func (tc *TodoCreate) SetOwner(u *User) *TodoCreate { 122 | return tc.SetOwnerID(u.ID) 123 | } 124 | 125 | // Mutation returns the TodoMutation object of the builder. 126 | func (tc *TodoCreate) Mutation() *TodoMutation { 127 | return tc.mutation 128 | } 129 | 130 | // Save creates the Todo in the database. 131 | func (tc *TodoCreate) Save(ctx context.Context) (*Todo, error) { 132 | var ( 133 | err error 134 | node *Todo 135 | ) 136 | tc.defaults() 137 | if len(tc.hooks) == 0 { 138 | if err = tc.check(); err != nil { 139 | return nil, err 140 | } 141 | node, err = tc.sqlSave(ctx) 142 | } else { 143 | var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { 144 | mutation, ok := m.(*TodoMutation) 145 | if !ok { 146 | return nil, fmt.Errorf("unexpected mutation type %T", m) 147 | } 148 | if err = tc.check(); err != nil { 149 | return nil, err 150 | } 151 | tc.mutation = mutation 152 | if node, err = tc.sqlSave(ctx); err != nil { 153 | return nil, err 154 | } 155 | mutation.id = &node.ID 156 | mutation.done = true 157 | return node, err 158 | }) 159 | for i := len(tc.hooks) - 1; i >= 0; i-- { 160 | if tc.hooks[i] == nil { 161 | return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") 162 | } 163 | mut = tc.hooks[i](mut) 164 | } 165 | if _, err := mut.Mutate(ctx, tc.mutation); err != nil { 166 | return nil, err 167 | } 168 | } 169 | return node, err 170 | } 171 | 172 | // SaveX calls Save and panics if Save returns an error. 173 | func (tc *TodoCreate) SaveX(ctx context.Context) *Todo { 174 | v, err := tc.Save(ctx) 175 | if err != nil { 176 | panic(err) 177 | } 178 | return v 179 | } 180 | 181 | // Exec executes the query. 182 | func (tc *TodoCreate) Exec(ctx context.Context) error { 183 | _, err := tc.Save(ctx) 184 | return err 185 | } 186 | 187 | // ExecX is like Exec, but panics if an error occurs. 
188 | func (tc *TodoCreate) ExecX(ctx context.Context) { 189 | if err := tc.Exec(ctx); err != nil { 190 | panic(err) 191 | } 192 | } 193 | 194 | // defaults sets the default values of the builder before save. 195 | func (tc *TodoCreate) defaults() { 196 | if _, ok := tc.mutation.CreatedAt(); !ok { 197 | v := todo.DefaultCreatedAt() 198 | tc.mutation.SetCreatedAt(v) 199 | } 200 | if _, ok := tc.mutation.Status(); !ok { 201 | v := todo.DefaultStatus 202 | tc.mutation.SetStatus(v) 203 | } 204 | if _, ok := tc.mutation.Priority(); !ok { 205 | v := todo.DefaultPriority 206 | tc.mutation.SetPriority(v) 207 | } 208 | } 209 | 210 | // check runs all checks and user-defined validators on the builder. 211 | func (tc *TodoCreate) check() error { 212 | if _, ok := tc.mutation.Text(); !ok { 213 | return &ValidationError{Name: "text", err: errors.New(`ent: missing required field "Todo.text"`)} 214 | } 215 | if v, ok := tc.mutation.Text(); ok { 216 | if err := todo.TextValidator(v); err != nil { 217 | return &ValidationError{Name: "text", err: fmt.Errorf(`ent: validator failed for field "Todo.text": %w`, err)} 218 | } 219 | } 220 | if _, ok := tc.mutation.CreatedAt(); !ok { 221 | return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "Todo.created_at"`)} 222 | } 223 | if _, ok := tc.mutation.Status(); !ok { 224 | return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "Todo.status"`)} 225 | } 226 | if v, ok := tc.mutation.Status(); ok { 227 | if err := todo.StatusValidator(v); err != nil { 228 | return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Todo.status": %w`, err)} 229 | } 230 | } 231 | if _, ok := tc.mutation.Priority(); !ok { 232 | return &ValidationError{Name: "priority", err: errors.New(`ent: missing required field "Todo.priority"`)} 233 | } 234 | return nil 235 | } 236 | 237 | func (tc *TodoCreate) sqlSave(ctx context.Context) (*Todo, error) { 238 | _node, _spec := 
tc.createSpec() 239 | if err := sqlgraph.CreateNode(ctx, tc.driver, _spec); err != nil { 240 | if sqlgraph.IsConstraintError(err) { 241 | err = &ConstraintError{err.Error(), err} 242 | } 243 | return nil, err 244 | } 245 | id := _spec.ID.Value.(int64) 246 | _node.ID = int(id) 247 | return _node, nil 248 | } 249 | 250 | func (tc *TodoCreate) createSpec() (*Todo, *sqlgraph.CreateSpec) { 251 | var ( 252 | _node = &Todo{config: tc.config} 253 | _spec = &sqlgraph.CreateSpec{ 254 | Table: todo.Table, 255 | ID: &sqlgraph.FieldSpec{ 256 | Type: field.TypeInt, 257 | Column: todo.FieldID, 258 | }, 259 | } 260 | ) 261 | if value, ok := tc.mutation.Text(); ok { 262 | _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ 263 | Type: field.TypeString, 264 | Value: value, 265 | Column: todo.FieldText, 266 | }) 267 | _node.Text = value 268 | } 269 | if value, ok := tc.mutation.CreatedAt(); ok { 270 | _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ 271 | Type: field.TypeTime, 272 | Value: value, 273 | Column: todo.FieldCreatedAt, 274 | }) 275 | _node.CreatedAt = value 276 | } 277 | if value, ok := tc.mutation.Status(); ok { 278 | _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ 279 | Type: field.TypeEnum, 280 | Value: value, 281 | Column: todo.FieldStatus, 282 | }) 283 | _node.Status = value 284 | } 285 | if value, ok := tc.mutation.Priority(); ok { 286 | _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ 287 | Type: field.TypeInt, 288 | Value: value, 289 | Column: todo.FieldPriority, 290 | }) 291 | _node.Priority = value 292 | } 293 | if nodes := tc.mutation.ChildrenIDs(); len(nodes) > 0 { 294 | edge := &sqlgraph.EdgeSpec{ 295 | Rel: sqlgraph.O2M, 296 | Inverse: true, 297 | Table: todo.ChildrenTable, 298 | Columns: []string{todo.ChildrenColumn}, 299 | Bidi: false, 300 | Target: &sqlgraph.EdgeTarget{ 301 | IDSpec: &sqlgraph.FieldSpec{ 302 | Type: field.TypeInt, 303 | Column: todo.FieldID, 304 | }, 305 | }, 306 | } 307 | for _, k := range nodes { 308 | 
edge.Target.Nodes = append(edge.Target.Nodes, k) 309 | } 310 | _spec.Edges = append(_spec.Edges, edge) 311 | } 312 | if nodes := tc.mutation.ParentIDs(); len(nodes) > 0 { 313 | edge := &sqlgraph.EdgeSpec{ 314 | Rel: sqlgraph.M2O, 315 | Inverse: false, 316 | Table: todo.ParentTable, 317 | Columns: []string{todo.ParentColumn}, 318 | Bidi: false, 319 | Target: &sqlgraph.EdgeTarget{ 320 | IDSpec: &sqlgraph.FieldSpec{ 321 | Type: field.TypeInt, 322 | Column: todo.FieldID, 323 | }, 324 | }, 325 | } 326 | for _, k := range nodes { 327 | edge.Target.Nodes = append(edge.Target.Nodes, k) 328 | } 329 | _node.todo_parent = &nodes[0] 330 | _spec.Edges = append(_spec.Edges, edge) 331 | } 332 | if nodes := tc.mutation.OwnerIDs(); len(nodes) > 0 { 333 | edge := &sqlgraph.EdgeSpec{ 334 | Rel: sqlgraph.M2O, 335 | Inverse: true, 336 | Table: todo.OwnerTable, 337 | Columns: []string{todo.OwnerColumn}, 338 | Bidi: false, 339 | Target: &sqlgraph.EdgeTarget{ 340 | IDSpec: &sqlgraph.FieldSpec{ 341 | Type: field.TypeInt, 342 | Column: user.FieldID, 343 | }, 344 | }, 345 | } 346 | for _, k := range nodes { 347 | edge.Target.Nodes = append(edge.Target.Nodes, k) 348 | } 349 | _node.user_todos = &nodes[0] 350 | _spec.Edges = append(_spec.Edges, edge) 351 | } 352 | return _node, _spec 353 | } 354 | 355 | // TodoCreateBulk is the builder for creating many Todo entities in bulk. 356 | type TodoCreateBulk struct { 357 | config 358 | builders []*TodoCreate 359 | } 360 | 361 | // Save creates the Todo entities in the database. 
//
// One mutator is built per builder and the mutators are chained: mutator i
// validates its builder, produces its create spec and then invokes mutator
// i+1. The last mutator in the chain issues a single sqlgraph.BatchCreate for
// all collected specs. Once the rest of the chain returns without error, each
// mutator marks its mutation as done and copies the ID reported in its spec
// (if any) onto its node.
//
// It returns the created nodes in builder order, or the first error raised by
// validation, by a hook, or by the batch insert (constraint violations are
// wrapped in *ConstraintError).
func (tcb *TodoCreateBulk) Save(ctx context.Context) ([]*Todo, error) {
	specs := make([]*sqlgraph.CreateSpec, len(tcb.builders))
	nodes := make([]*Todo, len(tcb.builders))
	mutators := make([]Mutator, len(tcb.builders))
	for i := range tcb.builders {
		// The immediately-invoked function gives every closure below its own
		// copy of i, and keeps the caller's context available as root.
		func(i int, root context.Context) {
			builder := tcb.builders[i]
			builder.defaults()
			var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
				mutation, ok := m.(*TodoMutation)
				if !ok {
					return nil, fmt.Errorf("unexpected mutation type %T", m)
				}
				if err := builder.check(); err != nil {
					return nil, err
				}
				builder.mutation = mutation
				nodes[i], specs[i] = builder.createSpec()
				var err error
				if i < len(mutators)-1 {
					// Not the last builder: hand over to the next mutator,
					// using the root context rather than this hook's ctx.
					_, err = mutators[i+1].Mutate(root, tcb.builders[i+1].mutation)
				} else {
					spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
					// Invoke the actual operation on the latest mutation in the chain.
					if err = sqlgraph.BatchCreate(ctx, tcb.driver, spec); err != nil {
						if sqlgraph.IsConstraintError(err) {
							err = &ConstraintError{err.Error(), err}
						}
					}
				}
				if err != nil {
					return nil, err
				}
				mutation.id = &nodes[i].ID
				mutation.done = true
				// The spec's ID value may be unset; only copy it when present.
				if specs[i].ID.Value != nil {
					id := specs[i].ID.Value.(int64)
					nodes[i].ID = int(id)
				}
				return nodes[i], nil
			})
			// Wrap in reverse so that hooks[0] ends up outermost.
			for i := len(builder.hooks) - 1; i >= 0; i-- {
				mut = builder.hooks[i](mut)
			}
			mutators[i] = mut
		}(i, ctx)
	}
	if len(mutators) > 0 {
		// Kicking off the first mutator drives the whole chain.
		if _, err := mutators[0].Mutate(ctx, tcb.builders[0].mutation); err != nil {
			return nil, err
		}
	}
	return nodes, nil
}

// SaveX is like Save, but panics if an error occurs.
418 | func (tcb *TodoCreateBulk) SaveX(ctx context.Context) []*Todo { 419 | v, err := tcb.Save(ctx) 420 | if err != nil { 421 | panic(err) 422 | } 423 | return v 424 | } 425 | 426 | // Exec executes the query. 427 | func (tcb *TodoCreateBulk) Exec(ctx context.Context) error { 428 | _, err := tcb.Save(ctx) 429 | return err 430 | } 431 | 432 | // ExecX is like Exec, but panics if an error occurs. 433 | func (tcb *TodoCreateBulk) ExecX(ctx context.Context) { 434 | if err := tcb.Exec(ctx); err != nil { 435 | panic(err) 436 | } 437 | } 438 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/todo_delete.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "todo/ent/predicate" 9 | "todo/ent/todo" 10 | 11 | "entgo.io/ent/dialect/sql" 12 | "entgo.io/ent/dialect/sql/sqlgraph" 13 | "entgo.io/ent/schema/field" 14 | ) 15 | 16 | // TodoDelete is the builder for deleting a Todo entity. 17 | type TodoDelete struct { 18 | config 19 | hooks []Hook 20 | mutation *TodoMutation 21 | } 22 | 23 | // Where appends a list predicates to the TodoDelete builder. 24 | func (td *TodoDelete) Where(ps ...predicate.Todo) *TodoDelete { 25 | td.mutation.Where(ps...) 26 | return td 27 | } 28 | 29 | // Exec executes the deletion query and returns how many vertices were deleted. 
30 | func (td *TodoDelete) Exec(ctx context.Context) (int, error) { 31 | var ( 32 | err error 33 | affected int 34 | ) 35 | if len(td.hooks) == 0 { 36 | affected, err = td.sqlExec(ctx) 37 | } else { 38 | var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { 39 | mutation, ok := m.(*TodoMutation) 40 | if !ok { 41 | return nil, fmt.Errorf("unexpected mutation type %T", m) 42 | } 43 | td.mutation = mutation 44 | affected, err = td.sqlExec(ctx) 45 | mutation.done = true 46 | return affected, err 47 | }) 48 | for i := len(td.hooks) - 1; i >= 0; i-- { 49 | if td.hooks[i] == nil { 50 | return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") 51 | } 52 | mut = td.hooks[i](mut) 53 | } 54 | if _, err := mut.Mutate(ctx, td.mutation); err != nil { 55 | return 0, err 56 | } 57 | } 58 | return affected, err 59 | } 60 | 61 | // ExecX is like Exec, but panics if an error occurs. 62 | func (td *TodoDelete) ExecX(ctx context.Context) int { 63 | n, err := td.Exec(ctx) 64 | if err != nil { 65 | panic(err) 66 | } 67 | return n 68 | } 69 | 70 | func (td *TodoDelete) sqlExec(ctx context.Context) (int, error) { 71 | _spec := &sqlgraph.DeleteSpec{ 72 | Node: &sqlgraph.NodeSpec{ 73 | Table: todo.Table, 74 | ID: &sqlgraph.FieldSpec{ 75 | Type: field.TypeInt, 76 | Column: todo.FieldID, 77 | }, 78 | }, 79 | } 80 | if ps := td.mutation.predicates; len(ps) > 0 { 81 | _spec.Predicate = func(selector *sql.Selector) { 82 | for i := range ps { 83 | ps[i](selector) 84 | } 85 | } 86 | } 87 | return sqlgraph.DeleteNodes(ctx, td.driver, _spec) 88 | } 89 | 90 | // TodoDeleteOne is the builder for deleting a single Todo entity. 91 | type TodoDeleteOne struct { 92 | td *TodoDelete 93 | } 94 | 95 | // Exec executes the deletion query. 
96 | func (tdo *TodoDeleteOne) Exec(ctx context.Context) error { 97 | n, err := tdo.td.Exec(ctx) 98 | switch { 99 | case err != nil: 100 | return err 101 | case n == 0: 102 | return &NotFoundError{todo.Label} 103 | default: 104 | return nil 105 | } 106 | } 107 | 108 | // ExecX is like Exec, but panics if an error occurs. 109 | func (tdo *TodoDeleteOne) ExecX(ctx context.Context) { 110 | tdo.td.ExecX(ctx) 111 | } 112 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/tx.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "context" 7 | "sync" 8 | 9 | "entgo.io/ent/dialect" 10 | ) 11 | 12 | // Tx is a transactional client that is created by calling Client.Tx(). 13 | type Tx struct { 14 | config 15 | // Todo is the client for interacting with the Todo builders. 16 | Todo *TodoClient 17 | // User is the client for interacting with the User builders. 18 | User *UserClient 19 | 20 | // lazily loaded. 21 | client *Client 22 | clientOnce sync.Once 23 | 24 | // completion callbacks. 25 | mu sync.Mutex 26 | onCommit []CommitHook 27 | onRollback []RollbackHook 28 | 29 | // ctx lives for the life of the transaction. It is 30 | // the same context used by the underlying connection. 31 | ctx context.Context 32 | } 33 | 34 | type ( 35 | // Committer is the interface that wraps the Committer method. 36 | Committer interface { 37 | Commit(context.Context, *Tx) error 38 | } 39 | 40 | // The CommitFunc type is an adapter to allow the use of ordinary 41 | // function as a Committer. If f is a function with the appropriate 42 | // signature, CommitFunc(f) is a Committer that calls f. 43 | CommitFunc func(context.Context, *Tx) error 44 | 45 | // CommitHook defines the "commit middleware". A function that gets a Committer 46 | // and returns a Committer. 
For example: 47 | // 48 | // hook := func(next ent.Committer) ent.Committer { 49 | // return ent.CommitFunc(func(context.Context, tx *ent.Tx) error { 50 | // // Do some stuff before. 51 | // if err := next.Commit(ctx, tx); err != nil { 52 | // return err 53 | // } 54 | // // Do some stuff after. 55 | // return nil 56 | // }) 57 | // } 58 | // 59 | CommitHook func(Committer) Committer 60 | ) 61 | 62 | // Commit calls f(ctx, m). 63 | func (f CommitFunc) Commit(ctx context.Context, tx *Tx) error { 64 | return f(ctx, tx) 65 | } 66 | 67 | // Commit commits the transaction. 68 | func (tx *Tx) Commit() error { 69 | txDriver := tx.config.driver.(*txDriver) 70 | var fn Committer = CommitFunc(func(context.Context, *Tx) error { 71 | return txDriver.tx.Commit() 72 | }) 73 | tx.mu.Lock() 74 | hooks := append([]CommitHook(nil), tx.onCommit...) 75 | tx.mu.Unlock() 76 | for i := len(hooks) - 1; i >= 0; i-- { 77 | fn = hooks[i](fn) 78 | } 79 | return fn.Commit(tx.ctx, tx) 80 | } 81 | 82 | // OnCommit adds a hook to call on commit. 83 | func (tx *Tx) OnCommit(f CommitHook) { 84 | tx.mu.Lock() 85 | defer tx.mu.Unlock() 86 | tx.onCommit = append(tx.onCommit, f) 87 | } 88 | 89 | type ( 90 | // Rollbacker is the interface that wraps the Rollbacker method. 91 | Rollbacker interface { 92 | Rollback(context.Context, *Tx) error 93 | } 94 | 95 | // The RollbackFunc type is an adapter to allow the use of ordinary 96 | // function as a Rollbacker. If f is a function with the appropriate 97 | // signature, RollbackFunc(f) is a Rollbacker that calls f. 98 | RollbackFunc func(context.Context, *Tx) error 99 | 100 | // RollbackHook defines the "rollback middleware". A function that gets a Rollbacker 101 | // and returns a Rollbacker. For example: 102 | // 103 | // hook := func(next ent.Rollbacker) ent.Rollbacker { 104 | // return ent.RollbackFunc(func(context.Context, tx *ent.Tx) error { 105 | // // Do some stuff before. 
106 | // if err := next.Rollback(ctx, tx); err != nil { 107 | // return err 108 | // } 109 | // // Do some stuff after. 110 | // return nil 111 | // }) 112 | // } 113 | // 114 | RollbackHook func(Rollbacker) Rollbacker 115 | ) 116 | 117 | // Rollback calls f(ctx, m). 118 | func (f RollbackFunc) Rollback(ctx context.Context, tx *Tx) error { 119 | return f(ctx, tx) 120 | } 121 | 122 | // Rollback rollbacks the transaction. 123 | func (tx *Tx) Rollback() error { 124 | txDriver := tx.config.driver.(*txDriver) 125 | var fn Rollbacker = RollbackFunc(func(context.Context, *Tx) error { 126 | return txDriver.tx.Rollback() 127 | }) 128 | tx.mu.Lock() 129 | hooks := append([]RollbackHook(nil), tx.onRollback...) 130 | tx.mu.Unlock() 131 | for i := len(hooks) - 1; i >= 0; i-- { 132 | fn = hooks[i](fn) 133 | } 134 | return fn.Rollback(tx.ctx, tx) 135 | } 136 | 137 | // OnRollback adds a hook to call on rollback. 138 | func (tx *Tx) OnRollback(f RollbackHook) { 139 | tx.mu.Lock() 140 | defer tx.mu.Unlock() 141 | tx.onRollback = append(tx.onRollback, f) 142 | } 143 | 144 | // Client returns a Client that binds to current transaction. 145 | func (tx *Tx) Client() *Client { 146 | tx.clientOnce.Do(func() { 147 | tx.client = &Client{config: tx.config} 148 | tx.client.init() 149 | }) 150 | return tx.client 151 | } 152 | 153 | func (tx *Tx) init() { 154 | tx.Todo = NewTodoClient(tx.config) 155 | tx.User = NewUserClient(tx.config) 156 | } 157 | 158 | // txDriver wraps the given dialect.Tx with a nop dialect.Driver implementation. 159 | // The idea is to support transactions without adding any extra code to the builders. 160 | // When a builder calls to driver.Tx(), it gets the same dialect.Tx instance. 161 | // Commit and Rollback are nop for the internal builders and the user must call one 162 | // of them in order to commit or rollback the transaction. 
163 | // 164 | // If a closed transaction is embedded in one of the generated entities, and the entity 165 | // applies a query, for example: Todo.QueryXXX(), the query will be executed 166 | // through the driver which created this transaction. 167 | // 168 | // Note that txDriver is not goroutine safe. 169 | type txDriver struct { 170 | // the driver we started the transaction from. 171 | drv dialect.Driver 172 | // tx is the underlying transaction. 173 | tx dialect.Tx 174 | } 175 | 176 | // newTx creates a new transactional driver. 177 | func newTx(ctx context.Context, drv dialect.Driver) (*txDriver, error) { 178 | tx, err := drv.Tx(ctx) 179 | if err != nil { 180 | return nil, err 181 | } 182 | return &txDriver{tx: tx, drv: drv}, nil 183 | } 184 | 185 | // Tx returns the transaction wrapper (txDriver) to avoid Commit or Rollback calls 186 | // from the internal builders. Should be called only by the internal builders. 187 | func (tx *txDriver) Tx(context.Context) (dialect.Tx, error) { return tx, nil } 188 | 189 | // Dialect returns the dialect of the driver we started the transaction from. 190 | func (tx *txDriver) Dialect() string { return tx.drv.Dialect() } 191 | 192 | // Close is a nop close. 193 | func (*txDriver) Close() error { return nil } 194 | 195 | // Commit is a nop commit for the internal builders. 196 | // User must call `Tx.Commit` in order to commit the transaction. 197 | func (*txDriver) Commit() error { return nil } 198 | 199 | // Rollback is a nop rollback for the internal builders. 200 | // User must call `Tx.Rollback` in order to rollback the transaction. 201 | func (*txDriver) Rollback() error { return nil } 202 | 203 | // Exec calls tx.Exec. 204 | func (tx *txDriver) Exec(ctx context.Context, query string, args, v any) error { 205 | return tx.tx.Exec(ctx, query, args, v) 206 | } 207 | 208 | // Query calls tx.Query. 
209 | func (tx *txDriver) Query(ctx context.Context, query string, args, v any) error { 210 | return tx.tx.Query(ctx, query, args, v) 211 | } 212 | 213 | var _ dialect.Driver = (*txDriver)(nil) 214 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/user.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "fmt" 7 | "strings" 8 | "todo/ent/user" 9 | 10 | "entgo.io/ent/dialect/sql" 11 | ) 12 | 13 | // User is the model entity for the User schema. 14 | type User struct { 15 | config `json:"-"` 16 | // ID of the ent. 17 | ID int `json:"id,omitempty"` 18 | // Name holds the value of the "name" field. 19 | Name string `json:"name,omitempty"` 20 | // Edges holds the relations/edges for other nodes in the graph. 21 | // The values are being populated by the UserQuery when eager-loading is set. 22 | Edges UserEdges `json:"edges"` 23 | } 24 | 25 | // UserEdges holds the relations/edges for other nodes in the graph. 26 | type UserEdges struct { 27 | // Todos holds the value of the todos edge. 28 | Todos []*Todo `json:"todos,omitempty"` 29 | // loadedTypes holds the information for reporting if a 30 | // type was loaded (or requested) in eager-loading or not. 31 | loadedTypes [1]bool 32 | } 33 | 34 | // TodosOrErr returns the Todos value or an error if the edge 35 | // was not loaded in eager-loading. 36 | func (e UserEdges) TodosOrErr() ([]*Todo, error) { 37 | if e.loadedTypes[0] { 38 | return e.Todos, nil 39 | } 40 | return nil, &NotLoadedError{edge: "todos"} 41 | } 42 | 43 | // scanValues returns the types for scanning values from sql.Rows. 
44 | func (*User) scanValues(columns []string) ([]any, error) { 45 | values := make([]any, len(columns)) 46 | for i := range columns { 47 | switch columns[i] { 48 | case user.FieldID: 49 | values[i] = new(sql.NullInt64) 50 | case user.FieldName: 51 | values[i] = new(sql.NullString) 52 | default: 53 | return nil, fmt.Errorf("unexpected column %q for type User", columns[i]) 54 | } 55 | } 56 | return values, nil 57 | } 58 | 59 | // assignValues assigns the values that were returned from sql.Rows (after scanning) 60 | // to the User fields. 61 | func (u *User) assignValues(columns []string, values []any) error { 62 | if m, n := len(values), len(columns); m < n { 63 | return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) 64 | } 65 | for i := range columns { 66 | switch columns[i] { 67 | case user.FieldID: 68 | value, ok := values[i].(*sql.NullInt64) 69 | if !ok { 70 | return fmt.Errorf("unexpected type %T for field id", value) 71 | } 72 | u.ID = int(value.Int64) 73 | case user.FieldName: 74 | if value, ok := values[i].(*sql.NullString); !ok { 75 | return fmt.Errorf("unexpected type %T for field name", values[i]) 76 | } else if value.Valid { 77 | u.Name = value.String 78 | } 79 | } 80 | } 81 | return nil 82 | } 83 | 84 | // QueryTodos queries the "todos" edge of the User entity. 85 | func (u *User) QueryTodos() *TodoQuery { 86 | return (&UserClient{config: u.config}).QueryTodos(u) 87 | } 88 | 89 | // Update returns a builder for updating this User. 90 | // Note that you need to call User.Unwrap() before calling this method if this User 91 | // was returned from a transaction, and the transaction was committed or rolled back. 92 | func (u *User) Update() *UserUpdateOne { 93 | return (&UserClient{config: u.config}).UpdateOne(u) 94 | } 95 | 96 | // Unwrap unwraps the User entity that was returned from a transaction after it was closed, 97 | // so that all future queries will be executed through the driver which created the transaction. 
98 | func (u *User) Unwrap() *User { 99 | tx, ok := u.config.driver.(*txDriver) 100 | if !ok { 101 | panic("ent: User is not a transactional entity") 102 | } 103 | u.config.driver = tx.drv 104 | return u 105 | } 106 | 107 | // String implements the fmt.Stringer. 108 | func (u *User) String() string { 109 | var builder strings.Builder 110 | builder.WriteString("User(") 111 | builder.WriteString(fmt.Sprintf("id=%v", u.ID)) 112 | builder.WriteString(", name=") 113 | builder.WriteString(u.Name) 114 | builder.WriteByte(')') 115 | return builder.String() 116 | } 117 | 118 | // Users is a parsable slice of User. 119 | type Users []*User 120 | 121 | func (u Users) config(cfg config) { 122 | for _i := range u { 123 | u[_i].config = cfg 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/user/user.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package user 4 | 5 | const ( 6 | // Label holds the string label denoting the user type in the database. 7 | Label = "user" 8 | // FieldID holds the string denoting the id field in the database. 9 | FieldID = "id" 10 | // FieldName holds the string denoting the name field in the database. 11 | FieldName = "name" 12 | // EdgeTodos holds the string denoting the todos edge name in mutations. 13 | EdgeTodos = "todos" 14 | // Table holds the table name of the user in the database. 15 | Table = "users" 16 | // TodosTable is the table that holds the todos relation/edge. 17 | TodosTable = "todos" 18 | // TodosInverseTable is the table name for the Todo entity. 19 | // It exists in this package in order to avoid circular dependency with the "todo" package. 20 | TodosInverseTable = "todos" 21 | // TodosColumn is the table column denoting the todos relation/edge. 22 | TodosColumn = "user_todos" 23 | ) 24 | 25 | // Columns holds all SQL columns for user fields. 
26 | var Columns = []string{ 27 | FieldID, 28 | FieldName, 29 | } 30 | 31 | // ValidColumn reports if the column name is valid (part of the table columns). 32 | func ValidColumn(column string) bool { 33 | for i := range Columns { 34 | if column == Columns[i] { 35 | return true 36 | } 37 | } 38 | return false 39 | } 40 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/user/where.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package user 4 | 5 | import ( 6 | "todo/ent/predicate" 7 | 8 | "entgo.io/ent/dialect/sql" 9 | "entgo.io/ent/dialect/sql/sqlgraph" 10 | ) 11 | 12 | // ID filters vertices based on their ID field. 13 | func ID(id int) predicate.User { 14 | return predicate.User(func(s *sql.Selector) { 15 | s.Where(sql.EQ(s.C(FieldID), id)) 16 | }) 17 | } 18 | 19 | // IDEQ applies the EQ predicate on the ID field. 20 | func IDEQ(id int) predicate.User { 21 | return predicate.User(func(s *sql.Selector) { 22 | s.Where(sql.EQ(s.C(FieldID), id)) 23 | }) 24 | } 25 | 26 | // IDNEQ applies the NEQ predicate on the ID field. 27 | func IDNEQ(id int) predicate.User { 28 | return predicate.User(func(s *sql.Selector) { 29 | s.Where(sql.NEQ(s.C(FieldID), id)) 30 | }) 31 | } 32 | 33 | // IDIn applies the In predicate on the ID field. 34 | func IDIn(ids ...int) predicate.User { 35 | return predicate.User(func(s *sql.Selector) { 36 | // if not arguments were provided, append the FALSE constants, 37 | // since we can't apply "IN ()". This will make this predicate falsy. 38 | if len(ids) == 0 { 39 | s.Where(sql.False()) 40 | return 41 | } 42 | v := make([]any, len(ids)) 43 | for i := range v { 44 | v[i] = ids[i] 45 | } 46 | s.Where(sql.In(s.C(FieldID), v...)) 47 | }) 48 | } 49 | 50 | // IDNotIn applies the NotIn predicate on the ID field. 
51 | func IDNotIn(ids ...int) predicate.User { 52 | return predicate.User(func(s *sql.Selector) { 53 | // if not arguments were provided, append the FALSE constants, 54 | // since we can't apply "IN ()". This will make this predicate falsy. 55 | if len(ids) == 0 { 56 | s.Where(sql.False()) 57 | return 58 | } 59 | v := make([]any, len(ids)) 60 | for i := range v { 61 | v[i] = ids[i] 62 | } 63 | s.Where(sql.NotIn(s.C(FieldID), v...)) 64 | }) 65 | } 66 | 67 | // IDGT applies the GT predicate on the ID field. 68 | func IDGT(id int) predicate.User { 69 | return predicate.User(func(s *sql.Selector) { 70 | s.Where(sql.GT(s.C(FieldID), id)) 71 | }) 72 | } 73 | 74 | // IDGTE applies the GTE predicate on the ID field. 75 | func IDGTE(id int) predicate.User { 76 | return predicate.User(func(s *sql.Selector) { 77 | s.Where(sql.GTE(s.C(FieldID), id)) 78 | }) 79 | } 80 | 81 | // IDLT applies the LT predicate on the ID field. 82 | func IDLT(id int) predicate.User { 83 | return predicate.User(func(s *sql.Selector) { 84 | s.Where(sql.LT(s.C(FieldID), id)) 85 | }) 86 | } 87 | 88 | // IDLTE applies the LTE predicate on the ID field. 89 | func IDLTE(id int) predicate.User { 90 | return predicate.User(func(s *sql.Selector) { 91 | s.Where(sql.LTE(s.C(FieldID), id)) 92 | }) 93 | } 94 | 95 | // Name applies equality check predicate on the "name" field. It's identical to NameEQ. 96 | func Name(v string) predicate.User { 97 | return predicate.User(func(s *sql.Selector) { 98 | s.Where(sql.EQ(s.C(FieldName), v)) 99 | }) 100 | } 101 | 102 | // NameEQ applies the EQ predicate on the "name" field. 103 | func NameEQ(v string) predicate.User { 104 | return predicate.User(func(s *sql.Selector) { 105 | s.Where(sql.EQ(s.C(FieldName), v)) 106 | }) 107 | } 108 | 109 | // NameNEQ applies the NEQ predicate on the "name" field. 
110 | func NameNEQ(v string) predicate.User { 111 | return predicate.User(func(s *sql.Selector) { 112 | s.Where(sql.NEQ(s.C(FieldName), v)) 113 | }) 114 | } 115 | 116 | // NameIn applies the In predicate on the "name" field. 117 | func NameIn(vs ...string) predicate.User { 118 | v := make([]any, len(vs)) 119 | for i := range v { 120 | v[i] = vs[i] 121 | } 122 | return predicate.User(func(s *sql.Selector) { 123 | // if not arguments were provided, append the FALSE constants, 124 | // since we can't apply "IN ()". This will make this predicate falsy. 125 | if len(v) == 0 { 126 | s.Where(sql.False()) 127 | return 128 | } 129 | s.Where(sql.In(s.C(FieldName), v...)) 130 | }) 131 | } 132 | 133 | // NameNotIn applies the NotIn predicate on the "name" field. 134 | func NameNotIn(vs ...string) predicate.User { 135 | v := make([]any, len(vs)) 136 | for i := range v { 137 | v[i] = vs[i] 138 | } 139 | return predicate.User(func(s *sql.Selector) { 140 | // if not arguments were provided, append the FALSE constants, 141 | // since we can't apply "IN ()". This will make this predicate falsy. 142 | if len(v) == 0 { 143 | s.Where(sql.False()) 144 | return 145 | } 146 | s.Where(sql.NotIn(s.C(FieldName), v...)) 147 | }) 148 | } 149 | 150 | // NameGT applies the GT predicate on the "name" field. 151 | func NameGT(v string) predicate.User { 152 | return predicate.User(func(s *sql.Selector) { 153 | s.Where(sql.GT(s.C(FieldName), v)) 154 | }) 155 | } 156 | 157 | // NameGTE applies the GTE predicate on the "name" field. 158 | func NameGTE(v string) predicate.User { 159 | return predicate.User(func(s *sql.Selector) { 160 | s.Where(sql.GTE(s.C(FieldName), v)) 161 | }) 162 | } 163 | 164 | // NameLT applies the LT predicate on the "name" field. 165 | func NameLT(v string) predicate.User { 166 | return predicate.User(func(s *sql.Selector) { 167 | s.Where(sql.LT(s.C(FieldName), v)) 168 | }) 169 | } 170 | 171 | // NameLTE applies the LTE predicate on the "name" field. 
172 | func NameLTE(v string) predicate.User { 173 | return predicate.User(func(s *sql.Selector) { 174 | s.Where(sql.LTE(s.C(FieldName), v)) 175 | }) 176 | } 177 | 178 | // NameContains applies the Contains predicate on the "name" field. 179 | func NameContains(v string) predicate.User { 180 | return predicate.User(func(s *sql.Selector) { 181 | s.Where(sql.Contains(s.C(FieldName), v)) 182 | }) 183 | } 184 | 185 | // NameHasPrefix applies the HasPrefix predicate on the "name" field. 186 | func NameHasPrefix(v string) predicate.User { 187 | return predicate.User(func(s *sql.Selector) { 188 | s.Where(sql.HasPrefix(s.C(FieldName), v)) 189 | }) 190 | } 191 | 192 | // NameHasSuffix applies the HasSuffix predicate on the "name" field. 193 | func NameHasSuffix(v string) predicate.User { 194 | return predicate.User(func(s *sql.Selector) { 195 | s.Where(sql.HasSuffix(s.C(FieldName), v)) 196 | }) 197 | } 198 | 199 | // NameEqualFold applies the EqualFold predicate on the "name" field. 200 | func NameEqualFold(v string) predicate.User { 201 | return predicate.User(func(s *sql.Selector) { 202 | s.Where(sql.EqualFold(s.C(FieldName), v)) 203 | }) 204 | } 205 | 206 | // NameContainsFold applies the ContainsFold predicate on the "name" field. 207 | func NameContainsFold(v string) predicate.User { 208 | return predicate.User(func(s *sql.Selector) { 209 | s.Where(sql.ContainsFold(s.C(FieldName), v)) 210 | }) 211 | } 212 | 213 | // HasTodos applies the HasEdge predicate on the "todos" edge. 214 | func HasTodos() predicate.User { 215 | return predicate.User(func(s *sql.Selector) { 216 | step := sqlgraph.NewStep( 217 | sqlgraph.From(Table, FieldID), 218 | sqlgraph.To(TodosTable, FieldID), 219 | sqlgraph.Edge(sqlgraph.O2M, false, TodosTable, TodosColumn), 220 | ) 221 | sqlgraph.HasNeighbors(s, step) 222 | }) 223 | } 224 | 225 | // HasTodosWith applies the HasEdge predicate on the "todos" edge with a given conditions (other predicates). 
226 | func HasTodosWith(preds ...predicate.Todo) predicate.User { 227 | return predicate.User(func(s *sql.Selector) { 228 | step := sqlgraph.NewStep( 229 | sqlgraph.From(Table, FieldID), 230 | sqlgraph.To(TodosInverseTable, FieldID), 231 | sqlgraph.Edge(sqlgraph.O2M, false, TodosTable, TodosColumn), 232 | ) 233 | sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { 234 | for _, p := range preds { 235 | p(s) 236 | } 237 | }) 238 | }) 239 | } 240 | 241 | // And groups predicates with the AND operator between them. 242 | func And(predicates ...predicate.User) predicate.User { 243 | return predicate.User(func(s *sql.Selector) { 244 | s1 := s.Clone().SetP(nil) 245 | for _, p := range predicates { 246 | p(s1) 247 | } 248 | s.Where(s1.P()) 249 | }) 250 | } 251 | 252 | // Or groups predicates with the OR operator between them. 253 | func Or(predicates ...predicate.User) predicate.User { 254 | return predicate.User(func(s *sql.Selector) { 255 | s1 := s.Clone().SetP(nil) 256 | for i, p := range predicates { 257 | if i > 0 { 258 | s1.Or() 259 | } 260 | p(s1) 261 | } 262 | s.Where(s1.P()) 263 | }) 264 | } 265 | 266 | // Not applies the not operator on the given predicate. 267 | func Not(p predicate.User) predicate.User { 268 | return predicate.User(func(s *sql.Selector) { 269 | p(s.Not()) 270 | }) 271 | } 272 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/user_create.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "context" 7 | "errors" 8 | "fmt" 9 | "todo/ent/todo" 10 | "todo/ent/user" 11 | 12 | "entgo.io/ent/dialect/sql/sqlgraph" 13 | "entgo.io/ent/schema/field" 14 | ) 15 | 16 | // UserCreate is the builder for creating a User entity. 17 | type UserCreate struct { 18 | config 19 | mutation *UserMutation 20 | hooks []Hook 21 | } 22 | 23 | // SetName sets the "name" field. 
24 | func (uc *UserCreate) SetName(s string) *UserCreate { 25 | uc.mutation.SetName(s) 26 | return uc 27 | } 28 | 29 | // AddTodoIDs adds the "todos" edge to the Todo entity by IDs. 30 | func (uc *UserCreate) AddTodoIDs(ids ...int) *UserCreate { 31 | uc.mutation.AddTodoIDs(ids...) 32 | return uc 33 | } 34 | 35 | // AddTodos adds the "todos" edges to the Todo entity. 36 | func (uc *UserCreate) AddTodos(t ...*Todo) *UserCreate { 37 | ids := make([]int, len(t)) 38 | for i := range t { 39 | ids[i] = t[i].ID 40 | } 41 | return uc.AddTodoIDs(ids...) 42 | } 43 | 44 | // Mutation returns the UserMutation object of the builder. 45 | func (uc *UserCreate) Mutation() *UserMutation { 46 | return uc.mutation 47 | } 48 | 49 | // Save creates the User in the database. 50 | func (uc *UserCreate) Save(ctx context.Context) (*User, error) { 51 | var ( 52 | err error 53 | node *User 54 | ) 55 | if len(uc.hooks) == 0 { 56 | if err = uc.check(); err != nil { 57 | return nil, err 58 | } 59 | node, err = uc.sqlSave(ctx) 60 | } else { 61 | var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { 62 | mutation, ok := m.(*UserMutation) 63 | if !ok { 64 | return nil, fmt.Errorf("unexpected mutation type %T", m) 65 | } 66 | if err = uc.check(); err != nil { 67 | return nil, err 68 | } 69 | uc.mutation = mutation 70 | if node, err = uc.sqlSave(ctx); err != nil { 71 | return nil, err 72 | } 73 | mutation.id = &node.ID 74 | mutation.done = true 75 | return node, err 76 | }) 77 | for i := len(uc.hooks) - 1; i >= 0; i-- { 78 | if uc.hooks[i] == nil { 79 | return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") 80 | } 81 | mut = uc.hooks[i](mut) 82 | } 83 | if _, err := mut.Mutate(ctx, uc.mutation); err != nil { 84 | return nil, err 85 | } 86 | } 87 | return node, err 88 | } 89 | 90 | // SaveX calls Save and panics if Save returns an error. 
func (uc *UserCreate) SaveX(ctx context.Context) *User {
	v, err := uc.Save(ctx)
	if err != nil {
		panic(err)
	}
	return v
}

// Exec executes the query.
// It runs Save and discards the created entity, returning only the error.
func (uc *UserCreate) Exec(ctx context.Context) error {
	_, err := uc.Save(ctx)
	return err
}

// ExecX is like Exec, but panics if an error occurs.
func (uc *UserCreate) ExecX(ctx context.Context) {
	if err := uc.Exec(ctx); err != nil {
		panic(err)
	}
}

// check runs all checks and user-defined validators on the builder.
// For User this is only the presence check of the required "name" field;
// a missing name is reported as a *ValidationError.
func (uc *UserCreate) check() error {
	if _, ok := uc.mutation.Name(); !ok {
		return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "User.name"`)}
	}
	return nil
}

// sqlSave builds the create spec, inserts the node through sqlgraph and
// copies the generated ID back onto the returned User. Constraint violations
// are reported as *ConstraintError.
func (uc *UserCreate) sqlSave(ctx context.Context) (*User, error) {
	_node, _spec := uc.createSpec()
	if err := sqlgraph.CreateNode(ctx, uc.driver, _spec); err != nil {
		if sqlgraph.IsConstraintError(err) {
			err = &ConstraintError{err.Error(), err}
		}
		return nil, err
	}
	// The ID value is read back from the spec as an int64.
	id := _spec.ID.Value.(int64)
	_node.ID = int(id)
	return _node, nil
}

// createSpec translates the mutation into a sqlgraph.CreateSpec and, in
// parallel, populates the User value that will be handed back to the caller.
func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
	var (
		_node = &User{config: uc.config}
		_spec = &sqlgraph.CreateSpec{
			Table: user.Table,
			ID: &sqlgraph.FieldSpec{
				Type:   field.TypeInt,
				Column: user.FieldID,
			},
		}
	)
	// The "name" field is emitted only when it is set on the mutation.
	if value, ok := uc.mutation.Name(); ok {
		_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
			Type:   field.TypeString,
			Value:  value,
			Column: user.FieldName,
		})
		_node.Name = value
	}
	// "todos" edge: one target node per Todo ID added to the mutation.
	if nodes := uc.mutation.TodosIDs(); len(nodes) > 0 {
		edge := &sqlgraph.EdgeSpec{
			Rel:     sqlgraph.O2M,
			Inverse: false,
			Table:   user.TodosTable,
			Columns: []string{user.TodosColumn},
			Bidi:    false,
			Target: &sqlgraph.EdgeTarget{
				IDSpec: &sqlgraph.FieldSpec{
					Type:   field.TypeInt,
					Column: todo.FieldID,
				},
			},
		}
		for _, k := range nodes {
			edge.Target.Nodes = append(edge.Target.Nodes, k)
		}
		_spec.Edges = append(_spec.Edges, edge)
	}
	return _node, _spec
}

// UserCreateBulk is the builder for creating many User entities in bulk.
type UserCreateBulk struct {
	config
	builders []*UserCreate
}

// Save creates the User entities in the database.
//
// One mutator is built per builder and the mutators are chained: mutator i
// validates its builder, produces its create spec and then invokes mutator
// i+1. The last mutator in the chain issues a single sqlgraph.BatchCreate for
// all collected specs. Once the rest of the chain returns without error, each
// mutator marks its mutation as done and copies the ID reported in its spec
// (if any) onto its node.
func (ucb *UserCreateBulk) Save(ctx context.Context) ([]*User, error) {
	specs := make([]*sqlgraph.CreateSpec, len(ucb.builders))
	nodes := make([]*User, len(ucb.builders))
	mutators := make([]Mutator, len(ucb.builders))
	for i := range ucb.builders {
		// The immediately-invoked function gives every closure below its own
		// copy of i, and keeps the caller's context available as root.
		func(i int, root context.Context) {
			builder := ucb.builders[i]
			var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
				mutation, ok := m.(*UserMutation)
				if !ok {
					return nil, fmt.Errorf("unexpected mutation type %T", m)
				}
				if err := builder.check(); err != nil {
					return nil, err
				}
				builder.mutation = mutation
				nodes[i], specs[i] = builder.createSpec()
				var err error
				if i < len(mutators)-1 {
					// Not the last builder: hand over to the next mutator,
					// using the root context rather than this hook's ctx.
					_, err = mutators[i+1].Mutate(root, ucb.builders[i+1].mutation)
				} else {
					spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
					// Invoke the actual operation on the latest mutation in the chain.
					if err = sqlgraph.BatchCreate(ctx, ucb.driver, spec); err != nil {
						if sqlgraph.IsConstraintError(err) {
							err = &ConstraintError{err.Error(), err}
						}
					}
				}
				if err != nil {
					return nil, err
				}
				mutation.id = &nodes[i].ID
				mutation.done = true
				// The spec's ID value may be unset; only copy it when present.
				if specs[i].ID.Value != nil {
					id := specs[i].ID.Value.(int64)
					nodes[i].ID = int(id)
				}
				return nodes[i], nil
			})
			// Wrap in reverse so that hooks[0] ends up outermost.
			for i := len(builder.hooks) - 1; i >= 0; i-- {
				mut = builder.hooks[i](mut)
			}
			mutators[i] = mut
		}(i, ctx)
	}
	if len(mutators) > 0 {
		// Kicking off the first mutator drives the whole chain.
		if _, err := mutators[0].Mutate(ctx, ucb.builders[0].mutation); err != nil {
			return nil, err
		}
	}
	return nodes, nil
}

// SaveX is like Save, but panics if an error occurs.
func (ucb *UserCreateBulk) SaveX(ctx context.Context) []*User {
	v, err := ucb.Save(ctx)
	if err != nil {
		panic(err)
	}
	return v
}

// Exec executes the query.
// It runs Save and discards the created entities, returning only the error.
func (ucb *UserCreateBulk) Exec(ctx context.Context) error {
	_, err := ucb.Save(ctx)
	return err
}

// ExecX is like Exec, but panics if an error occurs.
func (ucb *UserCreateBulk) ExecX(ctx context.Context) {
	if err := ucb.Exec(ctx); err != nil {
		panic(err)
	}
}
--------------------------------------------------------------------------------
/internal/examples/todo/ent/user_delete.go:
--------------------------------------------------------------------------------
// Code generated by entc, DO NOT EDIT.

package ent

import (
	"context"
	"fmt"
	"todo/ent/predicate"
	"todo/ent/user"

	"entgo.io/ent/dialect/sql"
	"entgo.io/ent/dialect/sql/sqlgraph"
	"entgo.io/ent/schema/field"
)

// UserDelete is the builder for deleting a User entity.
17 | type UserDelete struct { 18 | config 19 | hooks []Hook 20 | mutation *UserMutation 21 | } 22 | 23 | // Where appends a list predicates to the UserDelete builder. 24 | func (ud *UserDelete) Where(ps ...predicate.User) *UserDelete { 25 | ud.mutation.Where(ps...) 26 | return ud 27 | } 28 | 29 | // Exec executes the deletion query and returns how many vertices were deleted. 30 | func (ud *UserDelete) Exec(ctx context.Context) (int, error) { 31 | var ( 32 | err error 33 | affected int 34 | ) 35 | if len(ud.hooks) == 0 { 36 | affected, err = ud.sqlExec(ctx) 37 | } else { 38 | var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { 39 | mutation, ok := m.(*UserMutation) 40 | if !ok { 41 | return nil, fmt.Errorf("unexpected mutation type %T", m) 42 | } 43 | ud.mutation = mutation 44 | affected, err = ud.sqlExec(ctx) 45 | mutation.done = true 46 | return affected, err 47 | }) 48 | for i := len(ud.hooks) - 1; i >= 0; i-- { 49 | if ud.hooks[i] == nil { 50 | return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") 51 | } 52 | mut = ud.hooks[i](mut) 53 | } 54 | if _, err := mut.Mutate(ctx, ud.mutation); err != nil { 55 | return 0, err 56 | } 57 | } 58 | return affected, err 59 | } 60 | 61 | // ExecX is like Exec, but panics if an error occurs. 
62 | func (ud *UserDelete) ExecX(ctx context.Context) int { 63 | n, err := ud.Exec(ctx) 64 | if err != nil { 65 | panic(err) 66 | } 67 | return n 68 | } 69 | 70 | func (ud *UserDelete) sqlExec(ctx context.Context) (int, error) { 71 | _spec := &sqlgraph.DeleteSpec{ 72 | Node: &sqlgraph.NodeSpec{ 73 | Table: user.Table, 74 | ID: &sqlgraph.FieldSpec{ 75 | Type: field.TypeInt, 76 | Column: user.FieldID, 77 | }, 78 | }, 79 | } 80 | if ps := ud.mutation.predicates; len(ps) > 0 { 81 | _spec.Predicate = func(selector *sql.Selector) { 82 | for i := range ps { 83 | ps[i](selector) 84 | } 85 | } 86 | } 87 | return sqlgraph.DeleteNodes(ctx, ud.driver, _spec) 88 | } 89 | 90 | // UserDeleteOne is the builder for deleting a single User entity. 91 | type UserDeleteOne struct { 92 | ud *UserDelete 93 | } 94 | 95 | // Exec executes the deletion query. 96 | func (udo *UserDeleteOne) Exec(ctx context.Context) error { 97 | n, err := udo.ud.Exec(ctx) 98 | switch { 99 | case err != nil: 100 | return err 101 | case n == 0: 102 | return &NotFoundError{user.Label} 103 | default: 104 | return nil 105 | } 106 | } 107 | 108 | // ExecX is like Exec, but panics if an error occurs. 109 | func (udo *UserDeleteOne) ExecX(ctx context.Context) { 110 | udo.ud.ExecX(ctx) 111 | } 112 | -------------------------------------------------------------------------------- /internal/examples/todo/ent/user_update.go: -------------------------------------------------------------------------------- 1 | // Code generated by entc, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "context" 7 | "errors" 8 | "fmt" 9 | "todo/ent/predicate" 10 | "todo/ent/todo" 11 | "todo/ent/user" 12 | 13 | "entgo.io/ent/dialect/sql" 14 | "entgo.io/ent/dialect/sql/sqlgraph" 15 | "entgo.io/ent/schema/field" 16 | ) 17 | 18 | // UserUpdate is the builder for updating User entities. 
19 | type UserUpdate struct { 20 | config 21 | hooks []Hook 22 | mutation *UserMutation 23 | } 24 | 25 | // Where appends a list predicates to the UserUpdate builder. 26 | func (uu *UserUpdate) Where(ps ...predicate.User) *UserUpdate { 27 | uu.mutation.Where(ps...) 28 | return uu 29 | } 30 | 31 | // SetName sets the "name" field. 32 | func (uu *UserUpdate) SetName(s string) *UserUpdate { 33 | uu.mutation.SetName(s) 34 | return uu 35 | } 36 | 37 | // AddTodoIDs adds the "todos" edge to the Todo entity by IDs. 38 | func (uu *UserUpdate) AddTodoIDs(ids ...int) *UserUpdate { 39 | uu.mutation.AddTodoIDs(ids...) 40 | return uu 41 | } 42 | 43 | // AddTodos adds the "todos" edges to the Todo entity. 44 | func (uu *UserUpdate) AddTodos(t ...*Todo) *UserUpdate { 45 | ids := make([]int, len(t)) 46 | for i := range t { 47 | ids[i] = t[i].ID 48 | } 49 | return uu.AddTodoIDs(ids...) 50 | } 51 | 52 | // Mutation returns the UserMutation object of the builder. 53 | func (uu *UserUpdate) Mutation() *UserMutation { 54 | return uu.mutation 55 | } 56 | 57 | // ClearTodos clears all "todos" edges to the Todo entity. 58 | func (uu *UserUpdate) ClearTodos() *UserUpdate { 59 | uu.mutation.ClearTodos() 60 | return uu 61 | } 62 | 63 | // RemoveTodoIDs removes the "todos" edge to Todo entities by IDs. 64 | func (uu *UserUpdate) RemoveTodoIDs(ids ...int) *UserUpdate { 65 | uu.mutation.RemoveTodoIDs(ids...) 66 | return uu 67 | } 68 | 69 | // RemoveTodos removes "todos" edges to Todo entities. 70 | func (uu *UserUpdate) RemoveTodos(t ...*Todo) *UserUpdate { 71 | ids := make([]int, len(t)) 72 | for i := range t { 73 | ids[i] = t[i].ID 74 | } 75 | return uu.RemoveTodoIDs(ids...) 76 | } 77 | 78 | // Save executes the query and returns the number of nodes affected by the update operation. 
79 | func (uu *UserUpdate) Save(ctx context.Context) (int, error) { 80 | var ( 81 | err error 82 | affected int 83 | ) 84 | if len(uu.hooks) == 0 { 85 | affected, err = uu.sqlSave(ctx) 86 | } else { 87 | var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { 88 | mutation, ok := m.(*UserMutation) 89 | if !ok { 90 | return nil, fmt.Errorf("unexpected mutation type %T", m) 91 | } 92 | uu.mutation = mutation 93 | affected, err = uu.sqlSave(ctx) 94 | mutation.done = true 95 | return affected, err 96 | }) 97 | for i := len(uu.hooks) - 1; i >= 0; i-- { 98 | if uu.hooks[i] == nil { 99 | return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") 100 | } 101 | mut = uu.hooks[i](mut) 102 | } 103 | if _, err := mut.Mutate(ctx, uu.mutation); err != nil { 104 | return 0, err 105 | } 106 | } 107 | return affected, err 108 | } 109 | 110 | // SaveX is like Save, but panics if an error occurs. 111 | func (uu *UserUpdate) SaveX(ctx context.Context) int { 112 | affected, err := uu.Save(ctx) 113 | if err != nil { 114 | panic(err) 115 | } 116 | return affected 117 | } 118 | 119 | // Exec executes the query. 120 | func (uu *UserUpdate) Exec(ctx context.Context) error { 121 | _, err := uu.Save(ctx) 122 | return err 123 | } 124 | 125 | // ExecX is like Exec, but panics if an error occurs. 
126 | func (uu *UserUpdate) ExecX(ctx context.Context) { 127 | if err := uu.Exec(ctx); err != nil { 128 | panic(err) 129 | } 130 | } 131 | 132 | func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) { 133 | _spec := &sqlgraph.UpdateSpec{ 134 | Node: &sqlgraph.NodeSpec{ 135 | Table: user.Table, 136 | Columns: user.Columns, 137 | ID: &sqlgraph.FieldSpec{ 138 | Type: field.TypeInt, 139 | Column: user.FieldID, 140 | }, 141 | }, 142 | } 143 | if ps := uu.mutation.predicates; len(ps) > 0 { 144 | _spec.Predicate = func(selector *sql.Selector) { 145 | for i := range ps { 146 | ps[i](selector) 147 | } 148 | } 149 | } 150 | if value, ok := uu.mutation.Name(); ok { 151 | _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ 152 | Type: field.TypeString, 153 | Value: value, 154 | Column: user.FieldName, 155 | }) 156 | } 157 | if uu.mutation.TodosCleared() { 158 | edge := &sqlgraph.EdgeSpec{ 159 | Rel: sqlgraph.O2M, 160 | Inverse: false, 161 | Table: user.TodosTable, 162 | Columns: []string{user.TodosColumn}, 163 | Bidi: false, 164 | Target: &sqlgraph.EdgeTarget{ 165 | IDSpec: &sqlgraph.FieldSpec{ 166 | Type: field.TypeInt, 167 | Column: todo.FieldID, 168 | }, 169 | }, 170 | } 171 | _spec.Edges.Clear = append(_spec.Edges.Clear, edge) 172 | } 173 | if nodes := uu.mutation.RemovedTodosIDs(); len(nodes) > 0 && !uu.mutation.TodosCleared() { 174 | edge := &sqlgraph.EdgeSpec{ 175 | Rel: sqlgraph.O2M, 176 | Inverse: false, 177 | Table: user.TodosTable, 178 | Columns: []string{user.TodosColumn}, 179 | Bidi: false, 180 | Target: &sqlgraph.EdgeTarget{ 181 | IDSpec: &sqlgraph.FieldSpec{ 182 | Type: field.TypeInt, 183 | Column: todo.FieldID, 184 | }, 185 | }, 186 | } 187 | for _, k := range nodes { 188 | edge.Target.Nodes = append(edge.Target.Nodes, k) 189 | } 190 | _spec.Edges.Clear = append(_spec.Edges.Clear, edge) 191 | } 192 | if nodes := uu.mutation.TodosIDs(); len(nodes) > 0 { 193 | edge := &sqlgraph.EdgeSpec{ 194 | Rel: sqlgraph.O2M, 195 | Inverse: false, 
196 | Table: user.TodosTable, 197 | Columns: []string{user.TodosColumn}, 198 | Bidi: false, 199 | Target: &sqlgraph.EdgeTarget{ 200 | IDSpec: &sqlgraph.FieldSpec{ 201 | Type: field.TypeInt, 202 | Column: todo.FieldID, 203 | }, 204 | }, 205 | } 206 | for _, k := range nodes { 207 | edge.Target.Nodes = append(edge.Target.Nodes, k) 208 | } 209 | _spec.Edges.Add = append(_spec.Edges.Add, edge) 210 | } 211 | if n, err = sqlgraph.UpdateNodes(ctx, uu.driver, _spec); err != nil { 212 | if _, ok := err.(*sqlgraph.NotFoundError); ok { 213 | err = &NotFoundError{user.Label} 214 | } else if sqlgraph.IsConstraintError(err) { 215 | err = &ConstraintError{err.Error(), err} 216 | } 217 | return 0, err 218 | } 219 | return n, nil 220 | } 221 | 222 | // UserUpdateOne is the builder for updating a single User entity. 223 | type UserUpdateOne struct { 224 | config 225 | fields []string 226 | hooks []Hook 227 | mutation *UserMutation 228 | } 229 | 230 | // SetName sets the "name" field. 231 | func (uuo *UserUpdateOne) SetName(s string) *UserUpdateOne { 232 | uuo.mutation.SetName(s) 233 | return uuo 234 | } 235 | 236 | // AddTodoIDs adds the "todos" edge to the Todo entity by IDs. 237 | func (uuo *UserUpdateOne) AddTodoIDs(ids ...int) *UserUpdateOne { 238 | uuo.mutation.AddTodoIDs(ids...) 239 | return uuo 240 | } 241 | 242 | // AddTodos adds the "todos" edges to the Todo entity. 243 | func (uuo *UserUpdateOne) AddTodos(t ...*Todo) *UserUpdateOne { 244 | ids := make([]int, len(t)) 245 | for i := range t { 246 | ids[i] = t[i].ID 247 | } 248 | return uuo.AddTodoIDs(ids...) 249 | } 250 | 251 | // Mutation returns the UserMutation object of the builder. 252 | func (uuo *UserUpdateOne) Mutation() *UserMutation { 253 | return uuo.mutation 254 | } 255 | 256 | // ClearTodos clears all "todos" edges to the Todo entity. 
257 | func (uuo *UserUpdateOne) ClearTodos() *UserUpdateOne { 258 | uuo.mutation.ClearTodos() 259 | return uuo 260 | } 261 | 262 | // RemoveTodoIDs removes the "todos" edge to Todo entities by IDs. 263 | func (uuo *UserUpdateOne) RemoveTodoIDs(ids ...int) *UserUpdateOne { 264 | uuo.mutation.RemoveTodoIDs(ids...) 265 | return uuo 266 | } 267 | 268 | // RemoveTodos removes "todos" edges to Todo entities. 269 | func (uuo *UserUpdateOne) RemoveTodos(t ...*Todo) *UserUpdateOne { 270 | ids := make([]int, len(t)) 271 | for i := range t { 272 | ids[i] = t[i].ID 273 | } 274 | return uuo.RemoveTodoIDs(ids...) 275 | } 276 | 277 | // Select allows selecting one or more fields (columns) of the returned entity. 278 | // The default is selecting all fields defined in the entity schema. 279 | func (uuo *UserUpdateOne) Select(field string, fields ...string) *UserUpdateOne { 280 | uuo.fields = append([]string{field}, fields...) 281 | return uuo 282 | } 283 | 284 | // Save executes the query and returns the updated User entity. 285 | func (uuo *UserUpdateOne) Save(ctx context.Context) (*User, error) { 286 | var ( 287 | err error 288 | node *User 289 | ) 290 | if len(uuo.hooks) == 0 { 291 | node, err = uuo.sqlSave(ctx) 292 | } else { 293 | var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { 294 | mutation, ok := m.(*UserMutation) 295 | if !ok { 296 | return nil, fmt.Errorf("unexpected mutation type %T", m) 297 | } 298 | uuo.mutation = mutation 299 | node, err = uuo.sqlSave(ctx) 300 | mutation.done = true 301 | return node, err 302 | }) 303 | for i := len(uuo.hooks) - 1; i >= 0; i-- { 304 | if uuo.hooks[i] == nil { 305 | return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") 306 | } 307 | mut = uuo.hooks[i](mut) 308 | } 309 | if _, err := mut.Mutate(ctx, uuo.mutation); err != nil { 310 | return nil, err 311 | } 312 | } 313 | return node, err 314 | } 315 | 316 | // SaveX is like Save, but panics if an error occurs. 
317 | func (uuo *UserUpdateOne) SaveX(ctx context.Context) *User { 318 | node, err := uuo.Save(ctx) 319 | if err != nil { 320 | panic(err) 321 | } 322 | return node 323 | } 324 | 325 | // Exec executes the query on the entity. 326 | func (uuo *UserUpdateOne) Exec(ctx context.Context) error { 327 | _, err := uuo.Save(ctx) 328 | return err 329 | } 330 | 331 | // ExecX is like Exec, but panics if an error occurs. 332 | func (uuo *UserUpdateOne) ExecX(ctx context.Context) { 333 | if err := uuo.Exec(ctx); err != nil { 334 | panic(err) 335 | } 336 | } 337 | 338 | func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { 339 | _spec := &sqlgraph.UpdateSpec{ 340 | Node: &sqlgraph.NodeSpec{ 341 | Table: user.Table, 342 | Columns: user.Columns, 343 | ID: &sqlgraph.FieldSpec{ 344 | Type: field.TypeInt, 345 | Column: user.FieldID, 346 | }, 347 | }, 348 | } 349 | id, ok := uuo.mutation.ID() 350 | if !ok { 351 | return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "User.id" for update`)} 352 | } 353 | _spec.Node.ID.Value = id 354 | if fields := uuo.fields; len(fields) > 0 { 355 | _spec.Node.Columns = make([]string, 0, len(fields)) 356 | _spec.Node.Columns = append(_spec.Node.Columns, user.FieldID) 357 | for _, f := range fields { 358 | if !user.ValidColumn(f) { 359 | return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} 360 | } 361 | if f != user.FieldID { 362 | _spec.Node.Columns = append(_spec.Node.Columns, f) 363 | } 364 | } 365 | } 366 | if ps := uuo.mutation.predicates; len(ps) > 0 { 367 | _spec.Predicate = func(selector *sql.Selector) { 368 | for i := range ps { 369 | ps[i](selector) 370 | } 371 | } 372 | } 373 | if value, ok := uuo.mutation.Name(); ok { 374 | _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ 375 | Type: field.TypeString, 376 | Value: value, 377 | Column: user.FieldName, 378 | }) 379 | } 380 | if uuo.mutation.TodosCleared() { 381 | edge := &sqlgraph.EdgeSpec{ 
382 | Rel: sqlgraph.O2M, 383 | Inverse: false, 384 | Table: user.TodosTable, 385 | Columns: []string{user.TodosColumn}, 386 | Bidi: false, 387 | Target: &sqlgraph.EdgeTarget{ 388 | IDSpec: &sqlgraph.FieldSpec{ 389 | Type: field.TypeInt, 390 | Column: todo.FieldID, 391 | }, 392 | }, 393 | } 394 | _spec.Edges.Clear = append(_spec.Edges.Clear, edge) 395 | } 396 | if nodes := uuo.mutation.RemovedTodosIDs(); len(nodes) > 0 && !uuo.mutation.TodosCleared() { 397 | edge := &sqlgraph.EdgeSpec{ 398 | Rel: sqlgraph.O2M, 399 | Inverse: false, 400 | Table: user.TodosTable, 401 | Columns: []string{user.TodosColumn}, 402 | Bidi: false, 403 | Target: &sqlgraph.EdgeTarget{ 404 | IDSpec: &sqlgraph.FieldSpec{ 405 | Type: field.TypeInt, 406 | Column: todo.FieldID, 407 | }, 408 | }, 409 | } 410 | for _, k := range nodes { 411 | edge.Target.Nodes = append(edge.Target.Nodes, k) 412 | } 413 | _spec.Edges.Clear = append(_spec.Edges.Clear, edge) 414 | } 415 | if nodes := uuo.mutation.TodosIDs(); len(nodes) > 0 { 416 | edge := &sqlgraph.EdgeSpec{ 417 | Rel: sqlgraph.O2M, 418 | Inverse: false, 419 | Table: user.TodosTable, 420 | Columns: []string{user.TodosColumn}, 421 | Bidi: false, 422 | Target: &sqlgraph.EdgeTarget{ 423 | IDSpec: &sqlgraph.FieldSpec{ 424 | Type: field.TypeInt, 425 | Column: todo.FieldID, 426 | }, 427 | }, 428 | } 429 | for _, k := range nodes { 430 | edge.Target.Nodes = append(edge.Target.Nodes, k) 431 | } 432 | _spec.Edges.Add = append(_spec.Edges.Add, edge) 433 | } 434 | _node = &User{config: uuo.config} 435 | _spec.Assign = _node.assignValues 436 | _spec.ScanValues = _node.scanValues 437 | if err = sqlgraph.UpdateNode(ctx, uuo.driver, _spec); err != nil { 438 | if _, ok := err.(*sqlgraph.NotFoundError); ok { 439 | err = &NotFoundError{user.Label} 440 | } else if sqlgraph.IsConstraintError(err) { 441 | err = &ConstraintError{err.Error(), err} 442 | } 443 | return nil, err 444 | } 445 | return _node, nil 446 | } 447 | 
-------------------------------------------------------------------------------- /internal/examples/todo/generate.go: -------------------------------------------------------------------------------- 1 | package todo 2 | 3 | //go:generate go run -mod=mod github.com/99designs/gqlgen 4 | -------------------------------------------------------------------------------- /internal/examples/todo/go.mod: -------------------------------------------------------------------------------- 1 | module todo 2 | 3 | go 1.17 4 | 5 | require ( 6 | entgo.io/contrib v0.1.1-0.20211009150803-2f98d3a15e7d 7 | entgo.io/ent v0.9.2-0.20211014063230-899e9f0e50ba 8 | github.com/99designs/gqlgen v0.14.0 9 | github.com/hashicorp/go-multierror v1.1.1 10 | github.com/pkg/errors v0.9.1 // indirect 11 | github.com/vektah/gqlparser/v2 v2.2.0 12 | github.com/vmihailenco/msgpack/v5 v5.2.0 13 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c 14 | golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e // indirect 15 | ) 16 | 17 | require ( 18 | github.com/agnivade/levenshtein v1.1.0 // indirect 19 | github.com/go-openapi/inflect v0.19.0 // indirect 20 | github.com/google/uuid v1.3.0 // indirect 21 | github.com/graphql-go/graphql v0.7.10-0.20210411022516-8a92e977c10b // indirect 22 | github.com/hashicorp/errwrap v1.1.0 // indirect 23 | github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect 24 | golang.org/x/mod v0.4.2 // indirect 25 | golang.org/x/tools v0.1.7 // indirect 26 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect 27 | gopkg.in/yaml.v2 v2.4.0 // indirect 28 | ) 29 | -------------------------------------------------------------------------------- /internal/examples/todo/gqlgen.yml: -------------------------------------------------------------------------------- 1 | # schema tells gqlgen where the GraphQL schema is located. 2 | schema: 3 | - todo.graphql 4 | - ent.graphql 5 | 6 | # resolver reports where the resolver implementations go.
7 | resolver: 8 | layout: follow-schema 9 | dir: . 10 | 11 | # gqlgen will search for any type names in the schema in these go packages 12 | # if they match it will use them, otherwise it will generate them. 13 | 14 | # autobind tells gqlgen to search for any type names in the GraphQL schema in the 15 | # provided package. If they match it will use them, otherwise it will generate new ones. 16 | autobind: 17 | - todo/ent 18 | - todo/ent/todo 19 | 20 | # This section declares type mapping between the GraphQL and Go type systems. 21 | models: 22 | # Defines the ID field as Go 'int'. 23 | ID: 24 | model: 25 | - github.com/99designs/gqlgen/graphql.IntID 26 | Node: 27 | model: 28 | - todo/ent.Noder -------------------------------------------------------------------------------- /internal/examples/todo/resolver.go: -------------------------------------------------------------------------------- 1 | package todo 2 | 3 | import ( 4 | "todo/ent" 5 | 6 | "github.com/99designs/gqlgen/graphql" 7 | ) 8 | 9 | // This file will not be regenerated automatically. 10 | // 11 | // It serves as dependency injection for your app, add any dependencies you require here. 12 | 13 | // Resolver is the resolver root. 14 | type Resolver struct{ client *ent.Client } 15 | 16 | // NewSchema creates a graphql executable schema. 17 | func NewSchema(client *ent.Client) graphql.ExecutableSchema { 18 | return NewExecutableSchema(Config{ 19 | Resolvers: &Resolver{client}, 20 | }) 21 | } 22 | -------------------------------------------------------------------------------- /internal/examples/todo/todo.graphql: -------------------------------------------------------------------------------- 1 | interface Node { 2 | id: ID!
3 | } 4 | 5 | """Maps a Time GraphQL scalar to a Go time.Time struct.""" 6 | scalar Time 7 | 8 | """ 9 | Define a Relay Cursor type: 10 | https://relay.dev/graphql/connections.htm#sec-Cursor 11 | """ 12 | scalar Cursor 13 | 14 | """ 15 | Define an enumeration type and map it later to Ent enum (Go type). 16 | https://graphql.org/learn/schema/#enumeration-types 17 | """ 18 | enum Status { 19 | IN_PROGRESS 20 | COMPLETED 21 | } 22 | 23 | type PageInfo { 24 | hasNextPage: Boolean! 25 | hasPreviousPage: Boolean! 26 | startCursor: Cursor 27 | endCursor: Cursor 28 | } 29 | 30 | type TodoConnection { 31 | totalCount: Int! 32 | pageInfo: PageInfo! 33 | edges: [TodoEdge] 34 | } 35 | 36 | type TodoEdge { 37 | node: Todo 38 | cursor: Cursor! 39 | } 40 | 41 | """The following enums are matched the entgql annotations in the ent/schema.""" 42 | enum TodoOrderField { 43 | CREATED_AT 44 | PRIORITY 45 | STATUS 46 | TEXT 47 | } 48 | 49 | enum OrderDirection { 50 | ASC 51 | DESC 52 | } 53 | 54 | input TodoOrder { 55 | direction: OrderDirection! 56 | field: TodoOrderField 57 | } 58 | 59 | """ 60 | Define an object type and map it later to the generated Ent model. 61 | https://graphql.org/learn/schema/#object-types-and-fields 62 | """ 63 | type Todo implements Node { 64 | id: ID! 65 | createdAt: Time 66 | status: Status! 67 | priority: Int! 68 | text: String! 69 | owner: User 70 | parent: Todo 71 | children: [Todo!] 72 | } 73 | 74 | """ 75 | Define an input type for the mutation below. 76 | https://graphql.org/learn/schema/#input-types 77 | 78 | Note that, this type is mapped to the generated 79 | input type in mutation_input.go. 80 | """ 81 | input CreateTodoInput { 82 | status: Status! = IN_PROGRESS 83 | priority: Int 84 | text: String 85 | parent: ID 86 | children: [ID!] 87 | } 88 | 89 | """ 90 | Define an input type for the mutation below. 
91 | https://graphql.org/learn/schema/#input-types 92 | 93 | Note that, this type is mapped to the generated 94 | input type in mutation_input.go. 95 | """ 96 | input UpdateTodoInput { 97 | status: Status 98 | priority: Int 99 | text: String 100 | parent: ID 101 | clearParent: Boolean 102 | addChildIDs: [ID!] 103 | removeChildIDs: [ID!] 104 | } 105 | 106 | """ 107 | Define an input type for the mutation below. 108 | https://graphql.org/learn/schema/#input-types 109 | 110 | Note that, this type is mapped to the generated 111 | input type in mutation_input.go. 112 | """ 113 | input CreateUserInput { 114 | name: String 115 | todos: [ID!] 116 | } 117 | 118 | """ 119 | Define an object type and map it later to the generated Ent model. 120 | https://graphql.org/learn/schema/#object-types-and-fields 121 | """ 122 | type User implements Node { 123 | id: ID! 124 | name: String! 125 | todos: [Todo!] 126 | } 127 | 128 | """ 129 | Define a mutation for creating todos. 130 | https://graphql.org/learn/queries/#mutations 131 | """ 132 | type Mutation { 133 | createTodo(input: CreateTodoInput!): Todo! 134 | updateTodo(id: ID!, input: UpdateTodoInput!): Todo! 135 | updateTodos(ids: [ID!]!, input: UpdateTodoInput!): [Todo!]! 136 | createUser(input: CreateUserInput!): User! 137 | } 138 | 139 | """Define a query for getting all todos and support the Node interface.""" 140 | type Query { 141 | todos(after: Cursor, first: Int, before: Cursor, last: Int, orderBy: TodoOrder, where: TodoWhereInput): TodoConnection 142 | node(id: ID!): Node 143 | nodes(ids: [ID!]!): [Node]! 
144 | } 145 | 146 | -------------------------------------------------------------------------------- /internal/examples/todo/todo.resolvers.go: -------------------------------------------------------------------------------- 1 | package todo 2 | 3 | // This file will be automatically regenerated based on the schema, any resolver implementations 4 | // will be copied through when generating and any unknown code will be moved to the end. 5 | 6 | import ( 7 | "context" 8 | "todo/ent" 9 | "todo/ent/todo" 10 | ) 11 | 12 | func (r *mutationResolver) CreateTodo(ctx context.Context, input ent.CreateTodoInput) (*ent.Todo, error) { 13 | return ent.FromContext(ctx).Todo.Create().SetInput(input).Save(ctx) 14 | } 15 | 16 | func (r *mutationResolver) UpdateTodo(ctx context.Context, id int, input ent.UpdateTodoInput) (*ent.Todo, error) { 17 | return ent.FromContext(ctx).Todo.UpdateOneID(id).SetInput(input).Save(ctx) 18 | } 19 | 20 | func (r *mutationResolver) UpdateTodos(ctx context.Context, ids []int, input ent.UpdateTodoInput) ([]*ent.Todo, error) { 21 | client := ent.FromContext(ctx) 22 | if err := client.Todo.Update().Where(todo.IDIn(ids...)).SetInput(input).Exec(ctx); err != nil { 23 | return nil, err 24 | } 25 | return client.Todo.Query().Where(todo.IDIn(ids...)).All(ctx) 26 | } 27 | 28 | func (r *mutationResolver) CreateUser(ctx context.Context, input ent.CreateUserInput) (*ent.User, error) { 29 | return ent.FromContext(ctx).User.Create().SetInput(input).Save(ctx) 30 | } 31 | 32 | func (r *queryResolver) Todos(ctx context.Context, after *ent.Cursor, first *int, before *ent.Cursor, last *int, orderBy *ent.TodoOrder, where *ent.TodoWhereInput) (*ent.TodoConnection, error) { 33 | return r.client.Todo.Query(). 
34 | Paginate(ctx, after, first, before, last, 35 | ent.WithTodoOrder(orderBy), 36 | ent.WithTodoFilter(where.Filter), 37 | ) 38 | } 39 | 40 | func (r *queryResolver) Node(ctx context.Context, id int) (ent.Noder, error) { 41 | return r.client.Noder(ctx, id) 42 | } 43 | 44 | func (r *queryResolver) Nodes(ctx context.Context, ids []int) ([]ent.Noder, error) { 45 | return r.client.Noders(ctx, ids) 46 | } 47 | 48 | // Mutation returns MutationResolver implementation. 49 | func (r *Resolver) Mutation() MutationResolver { return &mutationResolver{r} } 50 | 51 | // Query returns QueryResolver implementation. 52 | func (r *Resolver) Query() QueryResolver { return &queryResolver{r} } 53 | 54 | type mutationResolver struct{ *Resolver } 55 | type queryResolver struct{ *Resolver } 56 | -------------------------------------------------------------------------------- /level.go: -------------------------------------------------------------------------------- 1 | package entcache 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "database/sql/driver" 7 | "encoding/gob" 8 | "errors" 9 | "fmt" 10 | "sync" 11 | "time" 12 | 13 | "github.com/golang/groupcache/lru" 14 | "github.com/redis/go-redis/v9" 15 | ) 16 | 17 | type ( 18 | // Entry defines an entry to store in a cache. 19 | Entry struct { 20 | Columns []string 21 | Values [][]driver.Value 22 | } 23 | 24 | // A Key defines a comparable Go value. 25 | // See http://golang.org/ref/spec#Comparison_operators 26 | Key any 27 | 28 | // AddGetDeleter defines the interface for getting, 29 | // adding and deleting entries from the cache. 30 | AddGetDeleter interface { 31 | Del(context.Context, Key) error 32 | Add(context.Context, Key, *Entry, time.Duration) error 33 | Get(context.Context, Key) (*Entry, error) 34 | } 35 | ) 36 | 37 | func init() { 38 | // Register non builtin driver.Values. 39 | gob.Register(time.Time{}) 40 | } 41 | 42 | // MarshalBinary implements the encoding.BinaryMarshaler interface. 
43 | func (e Entry) MarshalBinary() ([]byte, error) { 44 | entry := struct { 45 | C []string 46 | V [][]driver.Value 47 | }{ 48 | C: e.Columns, 49 | V: e.Values, 50 | } 51 | var buf bytes.Buffer 52 | if err := gob.NewEncoder(&buf).Encode(entry); err != nil { 53 | return nil, err 54 | } 55 | return buf.Bytes(), nil 56 | } 57 | 58 | // UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. 59 | func (e *Entry) UnmarshalBinary(buf []byte) error { 60 | var entry struct { 61 | C []string 62 | V [][]driver.Value 63 | } 64 | if err := gob.NewDecoder(bytes.NewBuffer(buf)).Decode(&entry); err != nil { 65 | return err 66 | } 67 | e.Values = entry.V 68 | e.Columns = entry.C 69 | return nil 70 | } 71 | 72 | // ErrNotFound is returned by Get when and Entry does not exist in the cache. 73 | var ErrNotFound = errors.New("entcache: entry was not found") 74 | 75 | type ( 76 | // LRU provides an LRU cache that implements the AddGetter interface. 77 | LRU struct { 78 | mu sync.Mutex 79 | *lru.Cache 80 | } 81 | // entry wraps the Entry with additional expiry information. 82 | entry struct { 83 | *Entry 84 | expiry time.Time 85 | } 86 | ) 87 | 88 | // NewLRU creates a new Cache. 89 | // If maxEntries is zero, the cache has no limit. 90 | func NewLRU(maxEntries int) *LRU { 91 | return &LRU{ 92 | Cache: lru.New(maxEntries), 93 | } 94 | } 95 | 96 | // Add adds the entry to the cache. 97 | func (l *LRU) Add(_ context.Context, k Key, e *Entry, ttl time.Duration) error { 98 | l.mu.Lock() 99 | defer l.mu.Unlock() 100 | buf, err := e.MarshalBinary() 101 | if err != nil { 102 | return err 103 | } 104 | ne := &Entry{} 105 | if err := ne.UnmarshalBinary(buf); err != nil { 106 | return err 107 | } 108 | if ttl == 0 { 109 | l.Cache.Add(k, ne) 110 | } else { 111 | l.Cache.Add(k, &entry{Entry: ne, expiry: time.Now().Add(ttl)}) 112 | } 113 | return nil 114 | } 115 | 116 | // Get gets an entry from the cache. 
117 | func (l *LRU) Get(_ context.Context, k Key) (*Entry, error) { 118 | l.mu.Lock() 119 | e, ok := l.Cache.Get(k) 120 | l.mu.Unlock() 121 | if !ok { 122 | return nil, ErrNotFound 123 | } 124 | switch e := e.(type) { 125 | case *Entry: 126 | return e, nil 127 | case *entry: 128 | if time.Now().Before(e.expiry) { 129 | return e.Entry, nil 130 | } 131 | l.mu.Lock() 132 | l.Cache.Remove(k) 133 | l.mu.Unlock() 134 | return nil, ErrNotFound 135 | default: 136 | return nil, fmt.Errorf("entcache: unexpected entry type: %T", e) 137 | } 138 | } 139 | 140 | // Del deletes an entry from the cache. 141 | func (l *LRU) Del(_ context.Context, k Key) error { 142 | l.mu.Lock() 143 | l.Cache.Remove(k) 144 | l.mu.Unlock() 145 | return nil 146 | } 147 | 148 | // Redis provides a remote cache backed by Redis 149 | // and implements the SetGetter interface. 150 | type Redis struct { 151 | c redis.Cmdable 152 | } 153 | 154 | // NewRedis returns a new Redis cache level from the given Redis connection. 155 | // 156 | // entcache.NewRedis(redis.NewClient(&redis.Options{ 157 | // Addr: ":6379" 158 | // })) 159 | // 160 | // entcache.NewRedis(redis.NewClusterClient(&redis.ClusterOptions{ 161 | // Addrs: []string{":7000", ":7001", ":7002"}, 162 | // })) 163 | func NewRedis(c redis.Cmdable) *Redis { 164 | return &Redis{c: c} 165 | } 166 | 167 | // Add adds the entry to the cache. 168 | func (r *Redis) Add(ctx context.Context, k Key, e *Entry, ttl time.Duration) error { 169 | key := fmt.Sprint(k) 170 | if key == "" { 171 | return nil 172 | } 173 | buf, err := e.MarshalBinary() 174 | if err != nil { 175 | return err 176 | } 177 | if err := r.c.Set(ctx, key, buf, ttl).Err(); err != nil { 178 | return err 179 | } 180 | return nil 181 | } 182 | 183 | // Get gets an entry from the cache. 
184 | func (r *Redis) Get(ctx context.Context, k Key) (*Entry, error) { 185 | key := fmt.Sprint(k) 186 | if key == "" { 187 | return nil, ErrNotFound 188 | } 189 | buf, err := r.c.Get(ctx, key).Bytes() 190 | if err != nil || len(buf) == 0 { 191 | return nil, ErrNotFound 192 | } 193 | e := &Entry{} 194 | if err := e.UnmarshalBinary(buf); err != nil { 195 | return nil, err 196 | } 197 | return e, nil 198 | } 199 | 200 | // Del deletes an entry from the cache. 201 | func (r *Redis) Del(ctx context.Context, k Key) error { 202 | key := fmt.Sprint(k) 203 | if key == "" { 204 | return nil 205 | } 206 | return r.c.Del(ctx, key).Err() 207 | } 208 | 209 | // multiLevel provides a multi-level cache implementation. 210 | type multiLevel struct { 211 | levels []AddGetDeleter 212 | } 213 | 214 | // Add adds the entry to the cache. 215 | func (m *multiLevel) Add(ctx context.Context, k Key, e *Entry, ttl time.Duration) error { 216 | for i := range m.levels { 217 | if err := m.levels[i].Add(ctx, k, e, ttl); err != nil { 218 | return err 219 | } 220 | } 221 | return nil 222 | } 223 | 224 | // Get gets an entry from the cache. 225 | func (m *multiLevel) Get(ctx context.Context, k Key) (*Entry, error) { 226 | for i := range m.levels { 227 | switch e, err := m.levels[i].Get(ctx, k); { 228 | case err == nil: 229 | return e, nil 230 | case err != ErrNotFound: 231 | return nil, err 232 | } 233 | } 234 | return nil, ErrNotFound 235 | } 236 | 237 | // Del deletes an entry from the cache. 238 | func (m *multiLevel) Del(ctx context.Context, k Key) error { 239 | for i := range m.levels { 240 | if err := m.levels[i].Del(ctx, k); err != nil { 241 | return err 242 | } 243 | } 244 | return nil 245 | } 246 | 247 | // contextLevel provides a context/request level cache implementation. 248 | type contextLevel struct{} 249 | 250 | // Get gets an entry from the cache. 
251 | func (*contextLevel) Get(ctx context.Context, k Key) (*Entry, error) { 252 | c, ok := FromContext(ctx) 253 | if !ok { 254 | return nil, ErrNotFound 255 | } 256 | return c.Get(ctx, k) 257 | } 258 | 259 | // Add adds the entry to the cache. 260 | func (*contextLevel) Add(ctx context.Context, k Key, e *Entry, ttl time.Duration) error { 261 | c, ok := FromContext(ctx) 262 | if !ok { 263 | return nil 264 | } 265 | return c.Add(ctx, k, e, ttl) 266 | } 267 | 268 | // Del deletes an entry from the cache. 269 | func (*contextLevel) Del(ctx context.Context, k Key) error { 270 | c, ok := FromContext(ctx) 271 | if !ok { 272 | return nil 273 | } 274 | return c.Del(ctx, k) 275 | } 276 | --------------------------------------------------------------------------------