├── .github └── workflows │ └── go.yml ├── .gitignore ├── LICENSE ├── README.md ├── README_zh_Hans.md ├── basic_tenant_info.go ├── cache.go ├── conn_str_generator.go ├── context.go ├── data ├── conn_str.go ├── conn_str_resolver.go └── context.go ├── docs ├── mode1.png ├── mode1_zh.png ├── mode2.png └── mode2_zh.png ├── ent ├── ent.go └── readme.md ├── examples ├── .gitignore ├── ent │ ├── go.mod │ ├── go.sum │ ├── main.go │ ├── migration.go │ ├── readme.md │ ├── seed.go │ ├── shared │ │ └── ent │ │ │ ├── client.go │ │ │ ├── ent.go │ │ │ ├── enttest │ │ │ └── enttest.go │ │ │ ├── generate.go │ │ │ ├── hook │ │ │ └── hook.go │ │ │ ├── intercept │ │ │ └── intercept.go │ │ │ ├── internal │ │ │ └── schema.go │ │ │ ├── migrate │ │ │ ├── migrate.go │ │ │ └── schema.go │ │ │ ├── mutation.go │ │ │ ├── post.go │ │ │ ├── post │ │ │ ├── post.go │ │ │ └── where.go │ │ │ ├── post_create.go │ │ │ ├── post_delete.go │ │ │ ├── post_query.go │ │ │ ├── post_update.go │ │ │ ├── predicate │ │ │ └── predicate.go │ │ │ ├── privacy │ │ │ └── privacy.go │ │ │ ├── runtime.go │ │ │ ├── runtime │ │ │ └── runtime.go │ │ │ ├── schema │ │ │ ├── post.go │ │ │ ├── tenant.go │ │ │ └── tenantconn.go │ │ │ ├── tenant.go │ │ │ ├── tenant │ │ │ ├── tenant.go │ │ │ └── where.go │ │ │ ├── tenant_create.go │ │ │ ├── tenant_delete.go │ │ │ ├── tenant_query.go │ │ │ ├── tenant_update.go │ │ │ ├── tenantconn.go │ │ │ ├── tenantconn │ │ │ ├── tenantconn.go │ │ │ └── where.go │ │ │ ├── tenantconn_create.go │ │ │ ├── tenantconn_delete.go │ │ │ ├── tenantconn_query.go │ │ │ ├── tenantconn_update.go │ │ │ └── tx.go │ ├── tenant │ │ └── ent │ │ │ ├── client.go │ │ │ ├── ent.go │ │ │ ├── enttest │ │ │ └── enttest.go │ │ │ ├── generate.go │ │ │ ├── hook │ │ │ └── hook.go │ │ │ ├── intercept │ │ │ └── intercept.go │ │ │ ├── internal │ │ │ └── schema.go │ │ │ ├── migrate │ │ │ ├── migrate.go │ │ │ └── schema.go │ │ │ ├── mutation.go │ │ │ ├── post.go │ │ │ ├── post │ │ │ ├── post.go │ │ │ └── where.go │ │ │ ├── 
post_create.go │ │ │ ├── post_delete.go │ │ │ ├── post_query.go │ │ │ ├── post_update.go │ │ │ ├── predicate │ │ │ └── predicate.go │ │ │ ├── privacy │ │ │ └── privacy.go │ │ │ ├── runtime.go │ │ │ ├── runtime │ │ │ └── runtime.go │ │ │ ├── schema │ │ │ └── post.go │ │ │ └── tx.go │ └── tenant_store.go └── gorm │ ├── docker-compose.yml │ ├── go.mod │ ├── go.sum │ ├── main.go │ ├── migration.go │ ├── post.go │ ├── postgres-utils.go │ ├── readme.md │ ├── seed.go │ └── tenant.go ├── gateway └── apisix │ ├── readme.md │ ├── resolver.go │ ├── saas.go │ └── saas_test.go ├── gin ├── multi_tenancy.go ├── multi_tenancy_test.go └── readme.md ├── go.mod ├── go.sum ├── gorm ├── gorm.go ├── has_tenant.go ├── has_tenant_test.go ├── readme.md └── sqlite_db_test.go ├── http ├── cookie_tenant_resolve_contrib.go ├── domain_tenant_resolve_contrib.go ├── form_tenant_resolve_contrib.go ├── header_tenant_resolve_contrib.go ├── multi_tenancy.go ├── multi_tenancy_test.go ├── query_tenant_resolve_contrib.go └── web_multi_tenancy_option.go ├── iris ├── multi_tenancy.go ├── multi_tenancy_test.go └── readme.md ├── kratos ├── header_tenant_resolve_contributor.go └── multi_tenantcy.go ├── multi_tenancy_conn_str_resolver.go ├── multi_tenancy_option.go ├── multi_tenancy_side.go ├── provider.go ├── seed ├── context.go ├── contrib.go ├── option.go └── seeder.go ├── tenant_config.go ├── tenant_config_provider.go ├── tenant_config_test.go ├── tenant_resolve_context.go ├── tenant_resolve_contrib.go ├── tenant_resolve_option.go ├── tenant_resolve_result.go ├── tenant_resolver.go ├── tenant_store.go └── utils.go /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | 11 | build: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v3 15 | 16 | - name: Set up Go 17 | uses: actions/setup-go@v4 18 | with: 19 | go-version: 
'1.20' 20 | 21 | - name: Build 22 | run: go build -v ./... 23 | 24 | - name: Test 25 | run: go test -v ./... -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | 17 | .idea/ 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Goxiaoy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # go-saas 2 | 3 | [English](./README.md) | [中文文档](./README_zh_Hans.md) 4 | 5 | A headless Go framework for SaaS (multi-tenancy). 6 | `go-saas` aims to provide a SaaS solution for Go. 7 | This project is suited to simple (web) projects, also known as monoliths. 8 | 9 | If you are looking for a complete solution that is microservice compatible, please refer to [go-saas-kit](https://github.com/go-saas/kit) 10 | 11 | # Overview 12 | 13 | ## Features 14 | 15 | * Different database architectures 16 | * [x] Single-tenancy: Each database stores data from only one tenant. 17 | 18 | ![img.png](docs/mode1.png) 19 | 20 | * [x] Multi-tenancy: Each database stores data from multiple separate tenants (with mechanisms to protect data privacy). 21 | 22 | ![img.png](docs/mode2.png) 23 | 24 | * [x] Hybrid tenancy models are also available.
25 | 26 | * [x] Implement your own resolver to support strategies such as sharding 27 | 28 | 29 | * Supports multiple web frameworks 30 | * [x] [gin](https://github.com/gin-gonic/gin) 31 | * [x] [iris](https://github.com/kataras/iris) 32 | * [x] net/http 33 | * [x] [kratos](https://github.com/go-kratos/kratos) 34 | * Supported ORMs with data filters (and therefore every database they support) 35 | * [x] [gorm](https://github.com/go-gorm/gorm) 36 | * [x] [ent](https://entgo.io/) 37 | * Customizable tenant resolver 38 | * [x] Query String 39 | * [x] Form parameters 40 | * [x] Header 41 | * [x] Cookie 42 | * [x] Domain format 43 | * Seed and Migration 44 | * [x] Seed/migrate a tenant database after creation or when upgrading to a new version 45 | * Integration with gateways 46 | * [x] [apisix](https://github.com/apache/apisix) 47 | 48 | 49 | ## Install 50 | 51 | ``` 52 | go get github.com/go-saas/saas 53 | ``` 54 | 55 | ## Design 56 | ```mermaid 57 | graph TD 58 | A(Incoming Request) -->|cookie,domain,form,header,query...|B(TenantResolver) 59 | B --> C(Tenant Context) --> D(ConnectionString Resolver) 60 | D --> E(Tenant 1) --> J(Data Filter) --> H(Shared Database) 61 | D --> F(Tenant 2) --> J 62 | D --> G(Tenant 3) --> I(Tenant 3 Database) 63 | ``` 64 | 65 | 66 | # Sample Projects 67 | * [example-gorm](https://github.com/go-saas/saas/tree/main/examples/gorm) combination of `go-saas`,`gin`,`gorm(sqlite/mysql)` 68 | * [example-ent](https://github.com/go-saas/saas/tree/main/examples/ent) combination of `go-saas`,`gin`,`ent(sqlite)` 69 | * [go-saas-kit](https://github.com/go-saas/kit) Microservice architecture starter kit for Go SaaS projects 70 | 71 | # Documentation 72 | Refer to [wiki](https://github.com/go-saas/saas/wiki) 73 | 74 | 75 | # References 76 | 77 | https://docs.microsoft.com/en-us/azure/azure-sql/database/saas-tenancy-app-design-patterns 78 | -------------------------------------------------------------------------------- /README_zh_Hans.md:
-------------------------------------------------------------------------------- 1 | # go-saas 2 | 3 | [English](./README.md) | [中文文档](./README_zh_Hans.md) 4 | 5 | 无头(无UI)的go语言的多租户框架。 6 | 本项目适合于简单的/单体(Web)项目,完整版本(支持微服务)可以看看[go-saas-kit](https://github.com/go-saas/kit) 7 | 8 | # 概览 9 | ## 功能 10 | 11 | * 不同的数据储存方式 12 | * [x] 每个租户各有数据库: 13 | 14 | ![img.png](docs/mode1_zh.png) 15 | 16 | * [x] 各个租户共享数据库: (数据访问层提供隔离) 17 | 18 | ![img.png](docs/mode2_zh.png) 19 | 20 | * [x] 混合模式 21 | 22 | * [x] 实现自己的Resolver来自定义,比如说像分片啥的 23 | 24 | * 支持多种Web框架 25 | * [x] [gin](https://github.com/gin-gonic/gin) 26 | * [x] net/http 27 | * [x] [kratos](https://github.com/go-kratos/kratos) 28 | * 共享数据库下,支持自动数据隔离的Orm, 包括Orm所支持的数据库 29 | * [x] [gorm](https://github.com/go-gorm/gorm) 30 | * 自定义租户解析 31 | * [x] Query String 32 | * [x] Form parameters 33 | * [x] Header 34 | * [x] Cookie 35 | * [x] Domain format 36 | * 初始化和数据库迁移 37 | * [x] 租户创建后初始化/迁移 数据库,或者以后升级到新的版本 38 | * 和网关集成 39 | * [x] [apisix](https://github.com/apache/apisix) 40 | 41 | 42 | ## 安装 43 | 44 | ``` 45 | go get github.com/go-saas/saas 46 | ``` 47 | 48 | ## 设计 49 | ```mermaid 50 | graph TD 51 | A(Incoming Request) -->|cookie,domain,form,header,query...|B(TenantResolver) 52 | B --> C(Tenant Context) --> D(ConnectionString Resolver) 53 | D --> E(Tenant 1) --> J(Data Filter) --> H(Shared Database) 54 | D --> F(Tenant 2) --> J 55 | D --> G(Tenant 3) --> I(Tenant 3 Database) 56 | ``` 57 | 58 | 59 | # 示例 60 | * [example](https://github.com/go-saas/saas/tree/main/examples) 使用 `go-saas`,`gin`,`gorm(sqlite/mysql)` 61 | * [go-saas-kit](https://github.com/go-saas/kit) golang多租户微服务解决方案 62 | 63 | # 文档 64 | [wiki](https://github.com/go-saas/saas/wiki) 65 | 66 | 67 | # 参考 68 | 69 | https://docs.microsoft.com/zh-cn/azure/azure-sql/database/saas-tenancy-app-design-patterns 70 | -------------------------------------------------------------------------------- /basic_tenant_info.go:
-------------------------------------------------------------------------------- 1 | package saas 2 | 3 | type TenantInfo interface { 4 | GetId() string 5 | GetName() string 6 | } 7 | 8 | type BasicTenantInfo struct { 9 | Id string 10 | Name string 11 | } 12 | 13 | func (b *BasicTenantInfo) GetId() string { 14 | return b.Id 15 | } 16 | 17 | func (b *BasicTenantInfo) GetName() string { 18 | return b.Name 19 | } 20 | 21 | func NewBasicTenantInfo(id string, name string) *BasicTenantInfo { 22 | return &BasicTenantInfo{Id: id, Name: name} 23 | } 24 | -------------------------------------------------------------------------------- /cache.go: -------------------------------------------------------------------------------- 1 | package saas 2 | 3 | import ( 4 | "container/list" 5 | "fmt" 6 | "io" 7 | "sync" 8 | ) 9 | 10 | // Cache is a thread-safe LRU (least recently used) cache whose values are closed when they are evicted or deleted. Adapted from https://github.com/Code-Hex/go-generics-cache/blob/main/policy/lru/lru.go 11 | // 12 | // Discards the least recently used items first. This algorithm requires 13 | // keeping track of what was used when, which is expensive if one wants 14 | // to make sure the algorithm always discards the least recently used item. 15 | type Cache[K comparable, V io.Closer] struct { 16 | cap int 17 | list *list.List 18 | items map[K]*list.Element 19 | mu sync.Mutex 20 | } 21 | 22 | type entry[K comparable, V any] struct { 23 | key K 24 | val V 25 | } 26 | 27 | // Option is an option for LRU cache. 28 | type Option func(*options) 29 | 30 | type options struct { 31 | capacity int 32 | } 33 | 34 | func newOptions() *options { 35 | return &options{ 36 | capacity: 128, 37 | } 38 | } 39 | 40 | // WithCapacity is an option to set cache capacity. 41 | func WithCapacity(cap int) Option { 42 | return func(o *options) { 43 | o.capacity = cap 44 | } 45 | } 46 | 47 | // NewCache creates a new thread safe LRU cache. Its capacity defaults to 128 and can be changed with WithCapacity.
48 | func NewCache[K comparable, V io.Closer](opts ...Option) *Cache[K, V] { 49 | o := newOptions() 50 | for _, optFunc := range opts { 51 | optFunc(o) 52 | } 53 | return &Cache[K, V]{ 54 | cap: o.capacity, 55 | list: list.New(), 56 | items: make(map[K]*list.Element, o.capacity), 57 | } 58 | } 59 | 60 | // Get looks up a key's value from the cache and marks it as the most recently used entry. 61 | func (c *Cache[K, V]) Get(key K) (zero V, _ bool) { 62 | c.mu.Lock() 63 | defer c.mu.Unlock() 64 | return c.get(key) 65 | } 66 | 67 | func (c *Cache[K, V]) get(key K) (zero V, _ bool) { 68 | e, ok := c.items[key] 69 | if !ok { 70 | return 71 | } 72 | // updates cache order 73 | c.list.MoveToFront(e) 74 | return e.Value.(*entry[K, V]).val, true 75 | } 76 | // Set stores val under key, replacing (without closing) any existing value, and evicts and closes the least recently used entry once the capacity is exceeded. 77 | func (c *Cache[K, V]) Set(key K, val V) { 78 | c.mu.Lock() 79 | defer c.mu.Unlock() 80 | c.set(key, val) 81 | } 82 | 83 | // set is the unlocked implementation of Set; the caller must hold c.mu. 84 | func (c *Cache[K, V]) set(key K, val V) { 85 | 86 | if e, ok := c.items[key]; ok { 87 | // updates cache order 88 | c.list.MoveToFront(e) 89 | entry := e.Value.(*entry[K, V]) 90 | entry.val = val 91 | return 92 | } 93 | 94 | newEntry := &entry[K, V]{ 95 | key: key, 96 | val: val, 97 | } 98 | e := c.list.PushFront(newEntry) 99 | c.items[key] = e 100 | 101 | if c.list.Len() > c.cap { 102 | c.deleteOldest() 103 | } 104 | } 105 | 106 | // GetOrSet returns the cached value for key, or builds one with factory and caches it; set reports whether a new value was stored. 107 | func (c *Cache[K, V]) GetOrSet(key K, factory func() (V, error)) (zero V, set bool, err error) { 108 | c.mu.Lock() 109 | defer c.mu.Unlock() 110 | if v, ok := c.get(key); ok { 111 | return v, false, nil 112 | } 113 | // cache miss: factory runs with c.mu held, so concurrent callers cannot build a duplicate for the same key 114 | v, err := factory() 115 | if err != nil { 116 | return zero, false, err 117 | } 118 | c.set(key, v) 119 | return v, true, nil 120 | } 121 | 122 | // Keys returns the keys of the cache, ordered from oldest to newest.
123 | func (c *Cache[K, V]) Keys() []K { 124 | c.mu.Lock() 125 | defer c.mu.Unlock() 126 | return c.keys() 127 | } 128 | 129 | func (c *Cache[K, V]) keys() []K { 130 | keys := make([]K, 0, len(c.items)) 131 | for ent := c.list.Back(); ent != nil; ent = ent.Prev() { 132 | entry := ent.Value.(*entry[K, V]) 133 | keys = append(keys, entry.key) 134 | } 135 | return keys 136 | } 137 | 138 | // Len returns the number of items in the cache. 139 | func (c *Cache[K, V]) Len() int { 140 | c.mu.Lock() 141 | defer c.mu.Unlock() 142 | return c.list.Len() 143 | } 144 | 145 | // Delete removes the item with the provided key from the cache and closes its value; an error from Close is discarded. 146 | func (c *Cache[K, V]) Delete(key K) { 147 | c.mu.Lock() 148 | defer c.mu.Unlock() 149 | _ = c.deleteKey(key) // Delete has no error result to report a Close failure through 150 | } 151 | 152 | func (c *Cache[K, V]) deleteKey(key K) error { 153 | if e, ok := c.items[key]; ok { 154 | return c.delete(e) 155 | } 156 | return nil 157 | } 158 | 159 | // Flush deletes (and closes) all items and returns every Close error, wrapped together. 160 | func (c *Cache[K, V]) Flush() error { 161 | c.mu.Lock() 162 | defer c.mu.Unlock() 163 | var err error 164 | for _, k := range c.keys() { 165 | nerr := c.deleteKey(k) 166 | if nerr != nil { 167 | if err == nil { 168 | err = nerr 169 | } else { 170 | err = fmt.Errorf("%w; %w", err, nerr) 171 | } 172 | } 173 | } 174 | return err 175 | } 176 | // deleteOldest evicts the least recently used entry; the caller must hold c.mu. 177 | func (c *Cache[K, V]) deleteOldest() { 178 | // Do not lock here: set is always called with c.mu held and sync.Mutex is not reentrant, 179 | // so locking again would deadlock whenever the capacity is exceeded. 180 | e := c.list.Back() 181 | _ = c.delete(e) // best-effort: a Close error on eviction is not reported 182 | } 183 | // delete unlinks e from the cache and closes its value; the caller must hold c.mu. 184 | func (c *Cache[K, V]) delete(e *list.Element) error { 185 | c.list.Remove(e) 186 | entry := e.Value.(*entry[K, V]) 187 | delete(c.items, entry.key) 188 | 189 | return entry.val.Close() 190 | } 191 | -------------------------------------------------------------------------------- /conn_str_generator.go: -------------------------------------------------------------------------------- 1 | package saas 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | ) 7 | 8 | // ConnStrGenerator generates a connection string for a tenant.
useful for tenant creation 9 | type ConnStrGenerator interface { 10 | Gen(ctx context.Context, tenant TenantInfo) (string, error) 11 | } 12 | 13 | type DefaultConnStrGenerator struct { 14 | format string 15 | } 16 | 17 | var _ ConnStrGenerator = (*DefaultConnStrGenerator)(nil) 18 | 19 | func NewConnStrGenerator(format string) *DefaultConnStrGenerator { 20 | return &DefaultConnStrGenerator{format: format} 21 | } 22 | 23 | func (d *DefaultConnStrGenerator) Gen(ctx context.Context, tenant TenantInfo) (string, error) { 24 | if len(tenant.GetId()) == 0 { 25 | return "", nil 26 | } 27 | return fmt.Sprintf(d.format, tenant.GetId()), nil 28 | } 29 | -------------------------------------------------------------------------------- /context.go: -------------------------------------------------------------------------------- 1 | package saas 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | type ( 8 | currentTenantCtx struct{} 9 | tenantResolveRes struct{} 10 | tenantConfigKey string 11 | ) 12 | 13 | func NewCurrentTenant(ctx context.Context, id, name string) context.Context { 14 | return NewCurrentTenantInfo(ctx, NewBasicTenantInfo(id, name)) 15 | } 16 | 17 | func NewCurrentTenantInfo(ctx context.Context, info TenantInfo) context.Context { 18 | return context.WithValue(ctx, currentTenantCtx{}, info) 19 | } 20 | 21 | func FromCurrentTenant(ctx context.Context) (TenantInfo, bool) { 22 | value, ok := ctx.Value(currentTenantCtx{}).(TenantInfo) 23 | if ok { 24 | return value, true 25 | } 26 | return NewBasicTenantInfo("", ""), false 27 | } 28 | 29 | func NewTenantResolveRes(ctx context.Context, t *TenantResolveResult) context.Context { 30 | return context.WithValue(ctx, tenantResolveRes{}, t) 31 | } 32 | 33 | func FromTenantResolveRes(ctx context.Context) *TenantResolveResult { 34 | v, ok := ctx.Value(tenantResolveRes{}).(*TenantResolveResult) 35 | if ok { 36 | return v 37 | } 38 | return nil 39 | } 40 | 41 | func NewTenantConfigContext(ctx context.Context, tenantId string, cfg 
*TenantConfig) context.Context { 42 | return context.WithValue(ctx, tenantConfigKey(tenantId), cfg) 43 | } 44 | 45 | func FromTenantConfigContext(ctx context.Context, tenantId string) (*TenantConfig, bool) { 46 | v, ok := ctx.Value(tenantConfigKey(tenantId)).(*TenantConfig) 47 | if ok { 48 | return v, ok && v != nil 49 | } 50 | return nil, false 51 | } 52 | -------------------------------------------------------------------------------- /data/conn_str.go: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | import "context" 4 | 5 | type ConnStrings map[string]string 6 | 7 | const Default = "default" 8 | 9 | func (c ConnStrings) Default() string { 10 | return c[Default] 11 | } 12 | 13 | func (c ConnStrings) Resolve(_ context.Context, key string) (string, error) { 14 | s := c.getOrDefault(key) 15 | return s, nil 16 | } 17 | 18 | func (c ConnStrings) getOrDefault(k string) string { 19 | if len(k) == 0 { 20 | return c.Default() 21 | } 22 | ret := c[k] 23 | if ret == "" { 24 | return c.Default() 25 | } 26 | return ret 27 | } 28 | 29 | func (c ConnStrings) SetDefault(value string) { 30 | c[Default] = value 31 | } 32 | -------------------------------------------------------------------------------- /data/conn_str_resolver.go: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | import "context" 4 | 5 | type ConnStrResolver interface { 6 | // Resolve connection string by user-friendly key 7 | Resolve(ctx context.Context, key string) (string, error) 8 | } 9 | 10 | type ConnStrResolverFunc func(ctx context.Context, key string) (string, error) 11 | 12 | func (c ConnStrResolverFunc) Resolve(ctx context.Context, key string) (string, error) { 13 | return c(ctx, key) 14 | } 15 | 16 | var _ ConnStrResolver = (*ConnStrResolverFunc)(nil) 17 | 18 | func ChainConnStrResolver(cs ...ConnStrResolver) ConnStrResolver { 19 | return ConnStrResolverFunc(func(ctx context.Context, key 
string) (string, error) { 20 | for _, c := range cs { 21 | conn, err := c.Resolve(ctx, key) 22 | if err != nil { 23 | return "", err 24 | } 25 | if len(conn) > 0 { 26 | return conn, err 27 | } 28 | } 29 | return "", nil 30 | }) 31 | } 32 | -------------------------------------------------------------------------------- /data/context.go: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | type ( 8 | //soft delete status 9 | multiTenancyDataFilterCtx struct{} 10 | autoSetTenantIdCtx struct{} 11 | ) 12 | 13 | func NewMultiTenancyDataFilter(ctx context.Context, enable ...bool) context.Context { 14 | v := true 15 | if len(enable) > 0 { 16 | v = enable[0] 17 | } 18 | return context.WithValue(ctx, multiTenancyDataFilterCtx{}, v) 19 | } 20 | 21 | //FromMultiTenancyDataFilter resolve where apply multi tenancy data filter, default true 22 | func FromMultiTenancyDataFilter(ctx context.Context) bool { 23 | v := ctx.Value(multiTenancyDataFilterCtx{}) 24 | if v == nil { 25 | return true 26 | } 27 | return v.(bool) 28 | } 29 | 30 | func NewAutoSetTenantId(ctx context.Context, enable ...bool) context.Context { 31 | v := true 32 | if len(enable) > 0 { 33 | v = enable[0] 34 | } 35 | return context.WithValue(ctx, autoSetTenantIdCtx{}, v) 36 | } 37 | 38 | func FromAutoSetTenantId(ctx context.Context) bool { 39 | v := ctx.Value(autoSetTenantIdCtx{}) 40 | if v == nil { 41 | return true 42 | } 43 | return v.(bool) 44 | } 45 | -------------------------------------------------------------------------------- /docs/mode1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/go-saas/saas/ad5675338984beec1433054246443a664bb6916c/docs/mode1.png -------------------------------------------------------------------------------- /docs/mode1_zh.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/go-saas/saas/ad5675338984beec1433054246443a664bb6916c/docs/mode1_zh.png -------------------------------------------------------------------------------- /docs/mode2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/go-saas/saas/ad5675338984beec1433054246443a664bb6916c/docs/mode2.png -------------------------------------------------------------------------------- /docs/mode2_zh.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/go-saas/saas/ad5675338984beec1433054246443a664bb6916c/docs/mode2_zh.png -------------------------------------------------------------------------------- /ent/ent.go: -------------------------------------------------------------------------------- 1 | package ent 2 | 3 | import ( 4 | "context" 5 | entgo "entgo.io/ent" 6 | "entgo.io/ent/dialect/sql" 7 | "entgo.io/ent/schema/field" 8 | "fmt" 9 | "github.com/go-saas/saas" 10 | "github.com/go-saas/saas/data" 11 | "reflect" 12 | ) 13 | 14 | // HasTenant is an ent mixin that adds a nullable tenant_id column, filters queries by the current tenant and stamps tenant_id on mutations. 15 | type HasTenant struct { 16 | entgo.Schema 17 | } 18 | 19 | // Fields declares the optional tenant_id column; rows written without a current tenant keep it NULL. 20 | func (HasTenant) Fields() []entgo.Field { 21 | return []entgo.Field{ 22 | field.String("tenant_id").Optional().GoType(&sql.NullString{}), 23 | } 24 | } 25 | 26 | // WhereP is implemented by builders that accept storage-level predicates. 27 | type WhereP interface{ WhereP(...func(*sql.Selector)) } 28 | 29 | // P appends the tenant predicate for t to w: tenant_id IS NULL for an empty tenant id, tenant_id = id otherwise. 30 | func (h HasTenant) P(t saas.TenantInfo, w interface{ WhereP(...func(*sql.Selector)) }) { 31 | if len(t.GetId()) == 0 { 32 | w.WhereP( 33 | sql.FieldIsNull(h.Fields()[0].Descriptor().Name)) 34 | return 35 | } 36 | w.WhereP( 37 | sql.FieldEQ(h.Fields()[0].Descriptor().Name, t.GetId())) 38 | } 39 | 40 | // Interceptors restricts every query to the current tenant unless the filter is disabled in ctx (data.NewMultiTenancyDataFilter). 41 | func (h HasTenant) Interceptors() []entgo.Interceptor { 42 | return []entgo.Interceptor{ 43 | entgo.TraverseFunc(func(ctx context.Context, q entgo.Query) error { 44 | e := data.FromMultiTenancyDataFilter(ctx) 45 | if !e { 46 | // Skip tenant filter 47 | return nil 48 | } 49 | 50 | ct, _ := saas.FromCurrentTenant(ctx)
51 | // Use WhereP when q exposes it; generated *XxxQuery builders only have a typed variadic Where, so fall back to reflection. 52 | if w, ok := q.(WhereP); ok { 53 | h.P(ct, w) 54 | return nil 55 | } 56 | where := reflect.ValueOf(q).MethodByName("Where") 57 | if !where.IsValid() { 58 | // return an error instead of letting reflect panic on a zero Value 59 | return fmt.Errorf("saas/ent: query %T has neither WhereP nor Where, cannot apply tenant filter", q) 60 | } 61 | // same predicate as P; a plain func(*sql.Selector) is assignable to the generated predicate type 62 | pred := sql.FieldEQ(h.Fields()[0].Descriptor().Name, ct.GetId()) 63 | if len(ct.GetId()) == 0 { 64 | pred = sql.FieldIsNull(h.Fields()[0].Descriptor().Name) 65 | } 66 | where.Call([]reflect.Value{reflect.ValueOf(pred)}) 67 | return nil 68 | }), 69 | } 70 | } 71 | 72 | // Hooks stamps tenant_id with the current tenant on create/update/delete mutations when a tenant is present and auto-set is enabled (data.NewAutoSetTenantId). 73 | // NOTE(review): no tenant predicate is added to update/delete mutations here, and (presumably) ent query interceptors 74 | // do not run for mutations, so bulk Update/Delete may touch other tenants' rows - confirm this is intended. 75 | func (h HasTenant) Hooks() []entgo.Hook { 76 | return []entgo.Hook{ 77 | On( 78 | func(next entgo.Mutator) entgo.Mutator { 79 | type hasTenant interface { 80 | SetOp(entgo.Op) 81 | SetTenantID(ss *sql.NullString) 82 | WhereP(...func(*sql.Selector)) 83 | } 84 | return entgo.MutateFunc(func(ctx context.Context, mutation entgo.Mutation) (entgo.Value, error) { 85 | if hf, ok := mutation.(hasTenant); ok { 86 | ct, _ := saas.FromCurrentTenant(ctx) 87 | // normalize tenant side only: writes without a current tenant keep the caller's tenant_id 88 | if data.FromAutoSetTenantId(ctx) && ct.GetId() != "" { 89 | hf.SetTenantID(&sql.NullString{ 90 | String: ct.GetId(), 91 | Valid: true, 92 | }) 93 | } 94 | } 95 | return next.Mutate(ctx, mutation) 96 | }) 97 | }, 98 | entgo.OpCreate|entgo.OpUpdate|entgo.OpUpdateOne|entgo.OpDeleteOne|entgo.OpDelete, 99 | ), 100 | } 101 | } 102 | 103 | // On runs hk only for mutations whose operation matches op. 104 | func On(hk entgo.Hook, op entgo.Op) entgo.Hook { 105 | return If(hk, HasOp(op)) 106 | } 107 | 108 | // Condition is a hook condition function. 109 | type Condition func(context.Context, entgo.Mutation) bool 110 | 111 | // If executes the given hook under condition. 112 | //
113 | //	If(myHook, HasOp(entgo.OpCreate)) 114 | func If(hk entgo.Hook, cond Condition) entgo.Hook { 115 | return func(next entgo.Mutator) entgo.Mutator { 116 | return entgo.MutateFunc(func(ctx context.Context, m entgo.Mutation) (entgo.Value, error) { 117 | if cond(ctx, m) { 118 | return hk(next).Mutate(ctx, m) 119 | } 120 | return next.Mutate(ctx, m) 121 | }) 122 | } 123 | } 124 | 125 | // HasOp is a condition testing mutation operation. 126 | func HasOp(op entgo.Op) Condition { 127 | return func(_ context.Context, m entgo.Mutation) bool { 128 | return m.Op().Is(op) 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /ent/readme.md: -------------------------------------------------------------------------------- 1 | # [Ent](https://entgo.io/) adapter 2 | 3 | - Enable [EntQL Filtering](https://entgo.io/docs/feature-flags/#entql-filtering) and [Privacy Layer](https://entgo.io/docs/feature-flags/#privacy-layer) features 4 | Modify your `ent/generate.go` 5 | ``` 6 | go generate ... --feature intercept,schema/snapshot ... 7 | ``` 8 | 9 | 10 | - Embed mixin into your schema 11 | 12 | ```go 13 | import ( 14 | sent "github.com/go-saas/saas/ent" 15 | ) 16 | ... 17 | // Post holds the schema definition for the Post entity.
18 | type Post struct { 19 | ent.Schema 20 | } 21 | 22 | func (Post) Mixin() []ent.Mixin { 23 | return []ent.Mixin{ 24 | sent.HasTenant{}, 25 | } 26 | } 27 | ``` 28 | -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | *.db -------------------------------------------------------------------------------- /examples/ent/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/go-saas/saas/examples/ent 2 | 3 | go 1.20 4 | 5 | replace github.com/go-saas/saas => ../../ 6 | 7 | require ( 8 | entgo.io/ent v0.12.4 9 | github.com/gin-gonic/gin v1.8.1 10 | github.com/go-saas/saas v0.0.0-00010101000000-000000000000 11 | github.com/mattn/go-sqlite3 v1.14.16 12 | ) 13 | 14 | require ( 15 | ariga.io/atlas v0.14.1-0.20230918065911-83ad451a4935 // indirect 16 | github.com/agext/levenshtein v1.2.1 // indirect 17 | github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect 18 | github.com/gin-contrib/sse v0.1.0 // indirect 19 | github.com/go-openapi/inflect v0.19.0 // indirect 20 | github.com/go-playground/locales v0.14.0 // indirect 21 | github.com/go-playground/universal-translator v0.18.0 // indirect 22 | github.com/go-playground/validator/v10 v10.11.1 // indirect 23 | github.com/goccy/go-json v0.9.11 // indirect 24 | github.com/google/go-cmp v0.5.6 // indirect 25 | github.com/google/uuid v1.3.1 // indirect 26 | github.com/hashicorp/hcl/v2 v2.13.0 // indirect 27 | github.com/json-iterator/go v1.1.12 // indirect 28 | github.com/leodido/go-urn v1.2.1 // indirect 29 | github.com/mattn/go-isatty v0.0.19 // indirect 30 | github.com/mitchellh/go-wordwrap v1.0.1 // indirect 31 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 32 | github.com/modern-go/reflect2 v1.0.2 // indirect 33 | github.com/pelletier/go-toml/v2 v2.0.5 // indirect 34 | github.com/ugorji/go/codec 
v1.2.7 // indirect 35 | github.com/zclconf/go-cty v1.8.0 // indirect 36 | golang.org/x/crypto v0.13.0 // indirect 37 | golang.org/x/mod v0.10.0 // indirect 38 | golang.org/x/net v0.15.0 // indirect 39 | golang.org/x/sys v0.12.0 // indirect 40 | golang.org/x/text v0.13.0 // indirect 41 | google.golang.org/protobuf v1.31.0 // indirect 42 | gopkg.in/yaml.v2 v2.4.0 // indirect 43 | ) 44 | -------------------------------------------------------------------------------- /examples/ent/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "github.com/gin-gonic/gin" 6 | "github.com/go-saas/saas" 7 | "github.com/go-saas/saas/data" 8 | "github.com/go-saas/saas/examples/ent/shared/ent" 9 | _ "github.com/go-saas/saas/examples/ent/shared/ent/runtime" 10 | ent2 "github.com/go-saas/saas/examples/ent/tenant/ent" 11 | _ "github.com/go-saas/saas/examples/ent/tenant/ent/runtime" 12 | sgin "github.com/go-saas/saas/gin" 13 | "github.com/go-saas/saas/seed" 14 | _ "github.com/mattn/go-sqlite3" 15 | ) 16 | 17 | type SharedDbProvider saas.DbProvider[*ent.Client] // clients for the host (shared) database 18 | type TenantDbProvider saas.DbProvider[*ent2.Client] // clients for tenant databases; connection strings are resolved per tenant 19 | 20 | func main() { 21 | r := gin.Default() 22 | 23 | cache := saas.NewCache[string, *ent.Client]() 24 | defer cache.Flush() 25 | cache2 := saas.NewCache[string, *ent2.Client]() 26 | defer cache2.Flush() // flush the tenant-client cache, not the shared one a second time 27 | 28 | sharedClientProvider := saas.ClientProviderFunc[*ent.Client](func(ctx context.Context, s string) (*ent.Client, error) { 29 | v, _, err := cache.GetOrSet(s, func() (*ent.Client, error) { 30 | client, err := ent.Open("sqlite3", s, ent.Debug()) 31 | if err != nil { 32 | return nil, err 33 | } 34 | return client, nil 35 | }) 36 | return v, err 37 | }) 38 | tenantClientProvider := saas.ClientProviderFunc[*ent2.Client](func(ctx context.Context, s string) (*ent2.Client, error) { 39 | v, _, err := cache2.GetOrSet(s, func() (*ent2.Client, error) { 40 | client, err := 
ent2.Open("sqlite3", s, ent2.Debug()) 41 | if err != nil { 42 | return nil, err 43 | } 44 | return client, nil 45 | }) 46 | return v, err 47 | }) 48 | 49 | conn := make(data.ConnStrings, 1) 50 | // default database 51 | conn.SetDefault("./shared.db?_fk=1") 52 | 53 | var tenantStore saas.TenantStore 54 | 55 | // host (shared) database uses the connection string from config 56 | sharedDbProvider := saas.NewDbProvider[*ent.Client](conn, sharedClientProvider) 57 | 58 | tenantStore = &TenantStore{shared: sharedDbProvider} 59 | 60 | mr := saas.NewMultiTenancyConnStrResolver(tenantStore, conn) 61 | // tenant databases use connection strings resolved from tenantStore 62 | tenantDbProvider := saas.NewDbProvider[*ent2.Client](mr, tenantClientProvider) 63 | 64 | r.Use(sgin.MultiTenancy(tenantStore)) 65 | 66 | // return the current tenant 67 | r.GET("/tenant/current", func(c *gin.Context) { 68 | rCtx := c.Request.Context() 69 | tenantInfo, _ := saas.FromCurrentTenant(rCtx) 70 | trR := saas.FromTenantResolveRes(rCtx) 71 | c.JSON(200, gin.H{ 72 | "tenantId": tenantInfo.GetId(), 73 | "resolvers": trR.AppliedResolvers, 74 | }) 75 | }) 76 | 77 | r.GET("/posts", func(c *gin.Context) { 78 | ctx := c.Request.Context() 79 | tenantInfo, _ := saas.FromCurrentTenant(ctx) 80 | if tenantInfo.GetId() == "" { 81 | db := sharedDbProvider.Get(ctx, "") 82 | if e, err := db.Post.Query().All(ctx); err != nil { 83 | c.AbortWithError(500, err) // Abort does not return from this handler, so the 200 body is written only on success 84 | } else { 85 | c.JSON(200, e) 86 | } 87 | } else { 88 | db := tenantDbProvider.Get(ctx, "") 89 | if e, err := db.Post.Query().All(ctx); err != nil { 90 | c.AbortWithError(500, err) // never follow the 500 with a 200 body 91 | } else { 92 | c.JSON(200, e) 93 | } 94 | } 95 | }) 96 | 97 | // seed data into the databases 98 | seeder := seed.NewDefaultSeeder(NewMigrationSeeder(sharedDbProvider, tenantDbProvider), NewSeed(sharedDbProvider, tenantDbProvider)) 99 | err := 
seeder.Seed(context.Background(), seed.AddHost(), seed.AddTenant("1", "2", "3")) 100 | if err != nil { 101 | panic(err) 102 | } 103 | 104 | if err := r.Run(":8090"); err != nil { // listen and serve on 
0.0.0.0:8090 (for windows "localhost:8090") 105 | panic(err) 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /examples/ent/migration.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "github.com/go-saas/saas/seed" 6 | ) 7 | 8 | type MigrationSeeder struct { 9 | shared SharedDbProvider 10 | tenant TenantDbProvider 11 | } 12 | 13 | func NewMigrationSeeder(shared SharedDbProvider, tenant TenantDbProvider) *MigrationSeeder { 14 | return &MigrationSeeder{shared: shared, tenant: tenant} 15 | } 16 | 17 | func (m *MigrationSeeder) Seed(ctx context.Context, sCtx *seed.Context) error { 18 | if sCtx.TenantId == "" { 19 | c := m.shared.Get(ctx, "") 20 | if err := c.Schema.Create(ctx); err != nil { 21 | return err 22 | } 23 | } else { 24 | c := m.tenant.Get(ctx, "") 25 | if err := c.Schema.Create(ctx); err != nil { 26 | return err 27 | } 28 | } 29 | return nil 30 | } 31 | -------------------------------------------------------------------------------- /examples/ent/readme.md: -------------------------------------------------------------------------------- 1 | # Example project 2 | 3 | A combination of `go-saas`, `gin` and `ent` (SQLite). 4 | 5 | ```shell 6 | go run github.com/go-saas/saas/examples/ent 7 | ``` 8 | --- 9 | Host side (uses the shared database): 10 | 11 | Open `http://localhost:8090/posts` 12 | 13 | --- 14 | Multi-tenancy (uses the shared database): 15 | 16 | Open `http://localhost:8090/posts?__tenant=1` 17 | 18 | Open `http://localhost:8090/posts?__tenant=2` 19 | 20 | --- 21 | Single-tenancy (uses a separate database): 22 | 23 | Open `http://localhost:8090/posts?__tenant=3` -------------------------------------------------------------------------------- /examples/ent/seed.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/go-saas/saas/data" 7 | 
"github.com/go-saas/saas/examples/ent/shared/ent" 8 | "github.com/go-saas/saas/seed" 9 | ) 10 | 11 | type Seed struct { 12 | shared SharedDbProvider 13 | tenant TenantDbProvider 14 | } 15 | 16 | func NewSeed(shared SharedDbProvider, tenant TenantDbProvider) *Seed { 17 | return &Seed{shared: shared, tenant: tenant} 18 | } 19 | 20 | func (s *Seed) Seed(ctx context.Context, sCtx *seed.Context) error { 21 | 22 | if sCtx.TenantId == "" { 23 | //seed host 24 | c := s.shared.Get(ctx, "") 25 | 26 | c3, err := c.TenantConn.Create().SetKey(data.Default).SetValue("./tenant3.db?_fk=1").Save(ctx) // NOTE(review): plain insert, not an upsert; tenant_conns has no unique index, so each start-up (main seeds on every run) adds another row — consider reusing an existing conn 27 | if err != nil { 28 | return err 29 | } 30 | 31 | tenants := make([]*ent.TenantCreate, 3) 32 | tenants[0] = c.Tenant.Create().SetID(1).SetName("Test1") 33 | tenants[1] = c.Tenant.Create().SetID(2).SetName("Test2") 34 | tenants[2] = c.Tenant.Create().SetID(3).SetName("Test3").AddConn(c3) 35 | 36 | err = c.Tenant.CreateBulk(tenants...).OnConflict().UpdateNewValues().Exec(ctx) 37 | if err != nil { 38 | return err 39 | } 40 | 41 | if err := c.Post.Create().SetID(1).SetTitle("Host Side").SetDescription("Hello Host").OnConflict().UpdateNewValues().Exec(ctx); err != nil { 42 | return err 43 | } 44 | 45 | } else if sCtx.TenantId == "1" { 46 | c := s.tenant.Get(ctx, "") 47 | for i := 1; i < 2; i++ { 48 | if err := c.Post.Create().SetID(10 + i).SetTitle(fmt.Sprintf("Tenant %s Post %v", sCtx.TenantId, i)). 49 | SetDescription(fmt.Sprintf("Tenant %s ", sCtx.TenantId)).OnConflict().UpdateNewValues().Exec(ctx); err != nil { 50 | return err 51 | } 52 | } 53 | 54 | } else if sCtx.TenantId == "2" { 55 | c := s.tenant.Get(ctx, "") 56 | for i := 1; i < 3; i++ { 57 | if err := c.Post.Create().SetID(20 + i).SetTitle(fmt.Sprintf("Tenant %s Post %v", sCtx.TenantId, i)). 
58 | SetDescription(fmt.Sprintf("Tenant %s ", sCtx.TenantId)).OnConflict().UpdateNewValues().Exec(ctx); err != nil { 59 | return err 60 | } 61 | } 62 | } else if sCtx.TenantId == "3" { 63 | c := s.tenant.Get(ctx, "") 64 | for i := 1; i < 4; i++ { 65 | if err := c.Post.Create().SetID(30 + i).SetTitle(fmt.Sprintf("Tenant %s Post %v", sCtx.TenantId, i)). 66 | SetDescription(fmt.Sprintf("Tenant %s ", sCtx.TenantId)).OnConflict().UpdateNewValues().Exec(ctx); err != nil { 67 | return err 68 | } 69 | } 70 | } 71 | 72 | return nil 73 | } 74 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/enttest/enttest.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package enttest 4 | 5 | import ( 6 | "context" 7 | 8 | "github.com/go-saas/saas/examples/ent/shared/ent" 9 | // required by schema hooks. 10 | _ "github.com/go-saas/saas/examples/ent/shared/ent/runtime" 11 | 12 | "entgo.io/ent/dialect/sql/schema" 13 | "github.com/go-saas/saas/examples/ent/shared/ent/migrate" 14 | ) 15 | 16 | type ( 17 | // TestingT is the interface that is shared between 18 | // testing.T and testing.B and used by enttest. 19 | TestingT interface { 20 | FailNow() 21 | Error(...any) 22 | } 23 | 24 | // Option configures client creation. 25 | Option func(*options) 26 | 27 | options struct { 28 | opts []ent.Option 29 | migrateOpts []schema.MigrateOption 30 | } 31 | ) 32 | 33 | // WithOptions forwards options to client creation. 34 | func WithOptions(opts ...ent.Option) Option { 35 | return func(o *options) { 36 | o.opts = append(o.opts, opts...) 37 | } 38 | } 39 | 40 | // WithMigrateOptions forwards options to auto migration. 41 | func WithMigrateOptions(opts ...schema.MigrateOption) Option { 42 | return func(o *options) { 43 | o.migrateOpts = append(o.migrateOpts, opts...) 
44 | } 45 | } 46 | 47 | func newOptions(opts []Option) *options { 48 | o := &options{} 49 | for _, opt := range opts { 50 | opt(o) 51 | } 52 | return o 53 | } 54 | 55 | // Open calls ent.Open and auto-run migration. 56 | func Open(t TestingT, driverName, dataSourceName string, opts ...Option) *ent.Client { 57 | o := newOptions(opts) 58 | c, err := ent.Open(driverName, dataSourceName, o.opts...) 59 | if err != nil { 60 | t.Error(err) 61 | t.FailNow() 62 | } 63 | migrateSchema(t, c, o) 64 | return c 65 | } 66 | 67 | // NewClient calls ent.NewClient and auto-run migration. 68 | func NewClient(t TestingT, opts ...Option) *ent.Client { 69 | o := newOptions(opts) 70 | c := ent.NewClient(o.opts...) 71 | migrateSchema(t, c, o) 72 | return c 73 | } 74 | func migrateSchema(t TestingT, c *ent.Client, o *options) { 75 | tables, err := schema.CopyTables(migrate.Tables) 76 | if err != nil { 77 | t.Error(err) 78 | t.FailNow() 79 | } 80 | if err := migrate.Create(context.Background(), c.Schema, tables, o.migrateOpts...); err != nil { 81 | t.Error(err) 82 | t.FailNow() 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/generate.go: -------------------------------------------------------------------------------- 1 | package ent 2 | 3 | //go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature sql/upsert,privacy,intercept,schema/snapshot ./schema 4 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/hook/hook.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package hook 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | 9 | "github.com/go-saas/saas/examples/ent/shared/ent" 10 | ) 11 | 12 | // The PostFunc type is an adapter to allow the use of ordinary 13 | // function as Post mutator.
14 | type PostFunc func(context.Context, *ent.PostMutation) (ent.Value, error) 15 | 16 | // Mutate calls f(ctx, m). 17 | func (f PostFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { 18 | if mv, ok := m.(*ent.PostMutation); ok { 19 | return f(ctx, mv) 20 | } 21 | return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PostMutation", m) 22 | } 23 | 24 | // The TenantFunc type is an adapter to allow the use of ordinary 25 | // function as Tenant mutator. 26 | type TenantFunc func(context.Context, *ent.TenantMutation) (ent.Value, error) 27 | 28 | // Mutate calls f(ctx, m). 29 | func (f TenantFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { 30 | if mv, ok := m.(*ent.TenantMutation); ok { 31 | return f(ctx, mv) 32 | } 33 | return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.TenantMutation", m) 34 | } 35 | 36 | // The TenantConnFunc type is an adapter to allow the use of ordinary 37 | // function as TenantConn mutator. 38 | type TenantConnFunc func(context.Context, *ent.TenantConnMutation) (ent.Value, error) 39 | 40 | // Mutate calls f(ctx, m). 41 | func (f TenantConnFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { 42 | if mv, ok := m.(*ent.TenantConnMutation); ok { 43 | return f(ctx, mv) 44 | } 45 | return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.TenantConnMutation", m) 46 | } 47 | 48 | // Condition is a hook condition function. 49 | type Condition func(context.Context, ent.Mutation) bool 50 | 51 | // And groups conditions with the AND operator. 52 | func And(first, second Condition, rest ...Condition) Condition { 53 | return func(ctx context.Context, m ent.Mutation) bool { 54 | if !first(ctx, m) || !second(ctx, m) { 55 | return false 56 | } 57 | for _, cond := range rest { 58 | if !cond(ctx, m) { 59 | return false 60 | } 61 | } 62 | return true 63 | } 64 | } 65 | 66 | // Or groups conditions with the OR operator.
67 | func Or(first, second Condition, rest ...Condition) Condition { 68 | return func(ctx context.Context, m ent.Mutation) bool { 69 | if first(ctx, m) || second(ctx, m) { 70 | return true 71 | } 72 | for _, cond := range rest { 73 | if cond(ctx, m) { 74 | return true 75 | } 76 | } 77 | return false 78 | } 79 | } 80 | 81 | // Not negates a given condition. 82 | func Not(cond Condition) Condition { 83 | return func(ctx context.Context, m ent.Mutation) bool { 84 | return !cond(ctx, m) 85 | } 86 | } 87 | 88 | // HasOp is a condition testing mutation operation. 89 | func HasOp(op ent.Op) Condition { 90 | return func(_ context.Context, m ent.Mutation) bool { 91 | return m.Op().Is(op) 92 | } 93 | } 94 | 95 | // HasAddedFields is a condition validating `.AddedField` on fields. 96 | func HasAddedFields(field string, fields ...string) Condition { 97 | return func(_ context.Context, m ent.Mutation) bool { 98 | if _, exists := m.AddedField(field); !exists { 99 | return false 100 | } 101 | for _, field := range fields { 102 | if _, exists := m.AddedField(field); !exists { 103 | return false 104 | } 105 | } 106 | return true 107 | } 108 | } 109 | 110 | // HasClearedFields is a condition validating `.FieldCleared` on fields. 111 | func HasClearedFields(field string, fields ...string) Condition { 112 | return func(_ context.Context, m ent.Mutation) bool { 113 | if exists := m.FieldCleared(field); !exists { 114 | return false 115 | } 116 | for _, field := range fields { 117 | if exists := m.FieldCleared(field); !exists { 118 | return false 119 | } 120 | } 121 | return true 122 | } 123 | } 124 | 125 | // HasFields is a condition validating `.Field` on fields.
126 | func HasFields(field string, fields ...string) Condition { 127 | return func(_ context.Context, m ent.Mutation) bool { 128 | if _, exists := m.Field(field); !exists { 129 | return false 130 | } 131 | for _, field := range fields { 132 | if _, exists := m.Field(field); !exists { 133 | return false 134 | } 135 | } 136 | return true 137 | } 138 | } 139 | 140 | // If executes the given hook under condition. 141 | // 142 | // hook.If(ComputeAverage, And(HasFields(...), HasAddedFields(...))) 143 | func If(hk ent.Hook, cond Condition) ent.Hook { 144 | return func(next ent.Mutator) ent.Mutator { 145 | return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { 146 | if cond(ctx, m) { 147 | return hk(next).Mutate(ctx, m) 148 | } 149 | return next.Mutate(ctx, m) 150 | }) 151 | } 152 | } 153 | 154 | // On executes the given hook only for the given operation. 155 | // 156 | // hook.On(Log, ent.Delete|ent.Create) 157 | func On(hk ent.Hook, op ent.Op) ent.Hook { 158 | return If(hk, HasOp(op)) 159 | } 160 | 161 | // Unless skips the given hook only for the given operation. 162 | // 163 | // hook.Unless(Log, ent.Update|ent.UpdateOne) 164 | func Unless(hk ent.Hook, op ent.Op) ent.Hook { 165 | return If(hk, Not(HasOp(op))) 166 | } 167 | 168 | // FixedError is a hook returning a fixed error. 169 | func FixedError(err error) ent.Hook { 170 | return func(ent.Mutator) ent.Mutator { 171 | return ent.MutateFunc(func(context.Context, ent.Mutation) (ent.Value, error) { 172 | return nil, err 173 | }) 174 | } 175 | } 176 | 177 | // Reject returns a hook that rejects all operations that match op. 178 | // 179 | // func (T) Hooks() []ent.Hook { 180 | // return []ent.Hook{ 181 | // Reject(ent.Delete|ent.Update), 182 | // } 183 | // } 184 | func Reject(op ent.Op) ent.Hook { 185 | hk := FixedError(fmt.Errorf("%s operation is not allowed", op)) 186 | return On(hk, op) 187 | } 188 | 189 | // Chain acts as a list of hooks and is effectively immutable.
190 | // Once created, it will always hold the same set of hooks in the same order. 191 | type Chain struct { 192 | hooks []ent.Hook 193 | } 194 | 195 | // NewChain creates a new chain of hooks. 196 | func NewChain(hooks ...ent.Hook) Chain { 197 | return Chain{append([]ent.Hook(nil), hooks...)} 198 | } 199 | 200 | // Hook chains the list of hooks and returns the final hook. 201 | func (c Chain) Hook() ent.Hook { 202 | return func(mutator ent.Mutator) ent.Mutator { 203 | for i := len(c.hooks) - 1; i >= 0; i-- { // wrap in reverse so c.hooks[0] becomes the outermost mutator and runs first 204 | mutator = c.hooks[i](mutator) 205 | } 206 | return mutator 207 | } 208 | } 209 | 210 | // Append extends a chain, adding the specified hook 211 | // as the last ones in the mutation flow. 212 | func (c Chain) Append(hooks ...ent.Hook) Chain { 213 | newHooks := make([]ent.Hook, 0, len(c.hooks)+len(hooks)) 214 | newHooks = append(newHooks, c.hooks...) 215 | newHooks = append(newHooks, hooks...) 216 | return Chain{newHooks} 217 | } 218 | 219 | // Extend extends a chain, adding the specified chain 220 | // as the last ones in the mutation flow. 221 | func (c Chain) Extend(chain Chain) Chain { 222 | return c.Append(chain.hooks...) 223 | } 224 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/internal/schema.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | //go:build tools 4 | // +build tools 5 | 6 | // Package internal holds a loadable version of the latest schema.
7 | package internal 8 | 9 | const Schema = `{"Schema":"github.com/go-saas/saas/examples/ent/shared/ent/schema","Package":"github.com/go-saas/saas/examples/ent/shared/ent","Schemas":[{"name":"Post","config":{"Table":""},"fields":[{"name":"tenant_id","type":{"Type":7,"Ident":"*sql.NullString","PkgPath":"database/sql","PkgName":"sql","Nillable":true,"RType":{"Name":"NullString","Ident":"sql.NullString","Kind":22,"PkgPath":"database/sql","Methods":{"Scan":{"In":[{"Name":"","Ident":"interface {}","Kind":20,"PkgPath":"","Methods":null}],"Out":[{"Name":"error","Ident":"error","Kind":20,"PkgPath":"","Methods":null}]},"Value":{"In":[],"Out":[{"Name":"Value","Ident":"driver.Value","Kind":20,"PkgPath":"database/sql/driver","Methods":null},{"Name":"error","Ident":"error","Kind":20,"PkgPath":"","Methods":null}]}}}},"optional":true,"position":{"Index":0,"MixedIn":true,"MixinIndex":0}},{"name":"id","type":{"Type":12,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"position":{"Index":0,"MixedIn":false,"MixinIndex":0}},{"name":"title","type":{"Type":7,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"position":{"Index":1,"MixedIn":false,"MixinIndex":0}},{"name":"description","type":{"Type":7,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"optional":true,"position":{"Index":2,"MixedIn":false,"MixinIndex":0}},{"name":"dsn","type":{"Type":7,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"optional":true,"position":{"Index":3,"MixedIn":false,"MixinIndex":0}}],"hooks":[{"Index":0,"MixedIn":true,"MixinIndex":0}],"interceptors":[{"Index":0,"MixedIn":true,"MixinIndex":0}]},{"name":"Tenant","config":{"Table":""},"edges":[{"name":"conn","type":"TenantConn"}],"fields":[{"name":"create_time","type":{"Type":2,"Ident":"","PkgPath":"time","PkgName":"","Nillable":false,"RType":null},"default":true,"default_kind":19,"immutable":true,"position":{"Index":0,"MixedIn":true,"MixinIndex":0}},{"name":"update_time","type":{"Typ
e":2,"Ident":"","PkgPath":"time","PkgName":"","Nillable":false,"RType":null},"default":true,"default_kind":19,"update_default":true,"position":{"Index":1,"MixedIn":true,"MixinIndex":0}},{"name":"id","type":{"Type":12,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"position":{"Index":0,"MixedIn":false,"MixinIndex":0}},{"name":"name","type":{"Type":7,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"position":{"Index":1,"MixedIn":false,"MixinIndex":0}},{"name":"display_name","type":{"Type":7,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"optional":true,"position":{"Index":2,"MixedIn":false,"MixinIndex":0}},{"name":"region","type":{"Type":7,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"optional":true,"position":{"Index":3,"MixedIn":false,"MixinIndex":0}}]},{"name":"TenantConn","config":{"Table":""},"fields":[{"name":"create_time","type":{"Type":2,"Ident":"","PkgPath":"time","PkgName":"","Nillable":false,"RType":null},"default":true,"default_kind":19,"immutable":true,"position":{"Index":0,"MixedIn":true,"MixinIndex":0}},{"name":"update_time","type":{"Type":2,"Ident":"","PkgPath":"time","PkgName":"","Nillable":false,"RType":null},"default":true,"default_kind":19,"update_default":true,"position":{"Index":1,"MixedIn":true,"MixinIndex":0}},{"name":"key","type":{"Type":7,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"position":{"Index":0,"MixedIn":false,"MixinIndex":0}},{"name":"value","type":{"Type":7,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"position":{"Index":1,"MixedIn":false,"MixinIndex":0}}]}],"Features":["sql/upsert","privacy","intercept","schema/snapshot"]}` 10 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/migrate/migrate.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT.
2 | 3 | package migrate 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "io" 9 | 10 | "entgo.io/ent/dialect" 11 | "entgo.io/ent/dialect/sql/schema" 12 | ) 13 | 14 | var ( 15 | // WithGlobalUniqueID sets the universal ids options to the migration. 16 | // If this option is enabled, ent migration will allocate a 1<<32 range 17 | // for the ids of each entity (table). 18 | // Note that this option cannot be applied on tables that already exist. 19 | WithGlobalUniqueID = schema.WithGlobalUniqueID 20 | // WithDropColumn sets the drop column option to the migration. 21 | // If this option is enabled, ent migration will drop old columns 22 | // that were used for both fields and edges. This defaults to false. 23 | WithDropColumn = schema.WithDropColumn 24 | // WithDropIndex sets the drop index option to the migration. 25 | // If this option is enabled, ent migration will drop old indexes 26 | // that were defined in the schema. This defaults to false. 27 | // Note that unique constraints are defined using `UNIQUE INDEX`, 28 | // and therefore, it's recommended to enable this option to get more 29 | // flexibility in the schema changes. 30 | WithDropIndex = schema.WithDropIndex 31 | // WithForeignKeys enables creating foreign-key in schema DDL. This defaults to true. 32 | WithForeignKeys = schema.WithForeignKeys 33 | ) 34 | 35 | // Schema is the API for creating, migrating and dropping a schema. 36 | type Schema struct { 37 | drv dialect.Driver 38 | } 39 | 40 | // NewSchema creates a new schema client. 41 | func NewSchema(drv dialect.Driver) *Schema { return &Schema{drv: drv} } 42 | 43 | // Create creates all schema resources. 44 | func (s *Schema) Create(ctx context.Context, opts ...schema.MigrateOption) error { 45 | return Create(ctx, s, Tables, opts...) 46 | } 47 | 48 | // Create creates all table resources using the given schema driver.
49 | func Create(ctx context.Context, s *Schema, tables []*schema.Table, opts ...schema.MigrateOption) error { 50 | migrate, err := schema.NewMigrate(s.drv, opts...) 51 | if err != nil { 52 | return fmt.Errorf("ent/migrate: %w", err) 53 | } 54 | return migrate.Create(ctx, tables...) 55 | } 56 | 57 | // WriteTo writes the schema changes to w instead of running them against the database. 58 | // 59 | // if err := client.Schema.WriteTo(context.Background(), os.Stdout); err != nil { 60 | // log.Fatal(err) 61 | // } 62 | func (s *Schema) WriteTo(ctx context.Context, w io.Writer, opts ...schema.MigrateOption) error { 63 | return Create(ctx, &Schema{drv: &schema.WriteDriver{Writer: w, Driver: s.drv}}, Tables, opts...) 64 | } 65 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/migrate/schema.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package migrate 4 | 5 | import ( 6 | "entgo.io/ent/dialect/sql/schema" 7 | "entgo.io/ent/schema/field" 8 | ) 9 | 10 | var ( 11 | // PostsColumns holds the columns for the "posts" table. 12 | PostsColumns = []*schema.Column{ 13 | {Name: "id", Type: field.TypeInt, Increment: true}, 14 | {Name: "tenant_id", Type: field.TypeString, Nullable: true}, 15 | {Name: "title", Type: field.TypeString}, 16 | {Name: "description", Type: field.TypeString, Nullable: true}, 17 | {Name: "dsn", Type: field.TypeString, Nullable: true}, 18 | } 19 | // PostsTable holds the schema information for the "posts" table. 20 | PostsTable = &schema.Table{ 21 | Name: "posts", 22 | Columns: PostsColumns, 23 | PrimaryKey: []*schema.Column{PostsColumns[0]}, 24 | } 25 | // TenantsColumns holds the columns for the "tenants" table.
26 | TenantsColumns = []*schema.Column{ 27 | {Name: "id", Type: field.TypeInt, Increment: true}, 28 | {Name: "create_time", Type: field.TypeTime}, 29 | {Name: "update_time", Type: field.TypeTime}, 30 | {Name: "name", Type: field.TypeString}, 31 | {Name: "display_name", Type: field.TypeString, Nullable: true}, 32 | {Name: "region", Type: field.TypeString, Nullable: true}, 33 | } 34 | // TenantsTable holds the schema information for the "tenants" table. 35 | TenantsTable = &schema.Table{ 36 | Name: "tenants", 37 | Columns: TenantsColumns, 38 | PrimaryKey: []*schema.Column{TenantsColumns[0]}, 39 | } 40 | // TenantConnsColumns holds the columns for the "tenant_conns" table. 41 | TenantConnsColumns = []*schema.Column{ 42 | {Name: "id", Type: field.TypeInt, Increment: true}, 43 | {Name: "create_time", Type: field.TypeTime}, 44 | {Name: "update_time", Type: field.TypeTime}, 45 | {Name: "key", Type: field.TypeString}, 46 | {Name: "value", Type: field.TypeString}, 47 | {Name: "tenant_conn", Type: field.TypeInt, Nullable: true}, 48 | } 49 | // TenantConnsTable holds the schema information for the "tenant_conns" table. 50 | TenantConnsTable = &schema.Table{ 51 | Name: "tenant_conns", 52 | Columns: TenantConnsColumns, 53 | PrimaryKey: []*schema.Column{TenantConnsColumns[0]}, 54 | ForeignKeys: []*schema.ForeignKey{ 55 | { 56 | Symbol: "tenant_conns_tenants_conn", 57 | Columns: []*schema.Column{TenantConnsColumns[5]}, 58 | RefColumns: []*schema.Column{TenantsColumns[0]}, 59 | OnDelete: schema.SetNull, 60 | }, 61 | }, 62 | } 63 | // Tables holds all the tables in the schema.
64 | Tables = []*schema.Table{ 65 | PostsTable, 66 | TenantsTable, 67 | TenantConnsTable, 68 | } 69 | ) 70 | 71 | func init() { 72 | TenantConnsTable.ForeignKeys[0].RefTable = TenantsTable 73 | } 74 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/post.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "fmt" 7 | "strings" 8 | 9 | "entgo.io/ent" 10 | "entgo.io/ent/dialect/sql" 11 | "github.com/go-saas/saas/examples/ent/shared/ent/post" 12 | ) 13 | 14 | // Post is the model entity for the Post schema. 15 | type Post struct { 16 | config `json:"-"` 17 | // ID of the ent. 18 | ID int `json:"id,omitempty"` 19 | // TenantID holds the value of the "tenant_id" field. 20 | TenantID *sql.NullString `json:"tenant_id,omitempty"` 21 | // Title holds the value of the "title" field. 22 | Title string `json:"title,omitempty"` 23 | // Description holds the value of the "description" field. 24 | Description string `json:"description,omitempty"` 25 | // Dsn holds the value of the "dsn" field. 26 | Dsn string `json:"dsn,omitempty"` 27 | selectValues sql.SelectValues 28 | } 29 | 30 | // scanValues returns the types for scanning values from sql.Rows. 31 | func (*Post) scanValues(columns []string) ([]any, error) { 32 | values := make([]any, len(columns)) 33 | for i := range columns { 34 | switch columns[i] { 35 | case post.FieldID: 36 | values[i] = new(sql.NullInt64) 37 | case post.FieldTenantID, post.FieldTitle, post.FieldDescription, post.FieldDsn: 38 | values[i] = new(sql.NullString) 39 | default: 40 | values[i] = new(sql.UnknownType) 41 | } 42 | } 43 | return values, nil 44 | } 45 | 46 | // assignValues assigns the values that were returned from sql.Rows (after scanning) 47 | // to the Post fields.
48 | func (po *Post) assignValues(columns []string, values []any) error { 49 | if m, n := len(values), len(columns); m < n { 50 | return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) 51 | } 52 | for i := range columns { 53 | switch columns[i] { 54 | case post.FieldID: 55 | value, ok := values[i].(*sql.NullInt64) 56 | if !ok { 57 | return fmt.Errorf("unexpected type %T for field id", value) // NOTE(review): value is the failed assertion's typed nil, so %T always prints *sql.NullInt64; the other cases pass values[i] to report the real type (generated code, so fix it in the generator rather than here) 58 | } 59 | po.ID = int(value.Int64) 60 | case post.FieldTenantID: 61 | if value, ok := values[i].(*sql.NullString); !ok { 62 | return fmt.Errorf("unexpected type %T for field tenant_id", values[i]) 63 | } else if value.Valid { 64 | po.TenantID = value 65 | } 66 | case post.FieldTitle: 67 | if value, ok := values[i].(*sql.NullString); !ok { 68 | return fmt.Errorf("unexpected type %T for field title", values[i]) 69 | } else if value.Valid { 70 | po.Title = value.String 71 | } 72 | case post.FieldDescription: 73 | if value, ok := values[i].(*sql.NullString); !ok { 74 | return fmt.Errorf("unexpected type %T for field description", values[i]) 75 | } else if value.Valid { 76 | po.Description = value.String 77 | } 78 | case post.FieldDsn: 79 | if value, ok := values[i].(*sql.NullString); !ok { 80 | return fmt.Errorf("unexpected type %T for field dsn", values[i]) 81 | } else if value.Valid { 82 | po.Dsn = value.String 83 | } 84 | default: 85 | po.selectValues.Set(columns[i], values[i]) 86 | } 87 | } 88 | return nil 89 | } 90 | 91 | // Value returns the ent.Value that was dynamically selected and assigned to the Post. 92 | // This includes values selected through modifiers, order, etc. 93 | func (po *Post) Value(name string) (ent.Value, error) { 94 | return po.selectValues.Get(name) 95 | } 96 | 97 | // Update returns a builder for updating this Post. 
98 | // Note that you need to call Post.Unwrap() before calling this method if this Post 99 | // was returned from a transaction, and the transaction was committed or rolled back.
100 | func (po *Post) Update() *PostUpdateOne { 101 | return NewPostClient(po.config).UpdateOne(po) 102 | } 103 | 104 | // Unwrap unwraps the Post entity that was returned from a transaction after it was closed, 105 | // so that all future queries will be executed through the driver which created the transaction. 106 | func (po *Post) Unwrap() *Post { 107 | _tx, ok := po.config.driver.(*txDriver) 108 | if !ok { 109 | panic("ent: Post is not a transactional entity") 110 | } 111 | po.config.driver = _tx.drv 112 | return po 113 | } 114 | 115 | // String implements the fmt.Stringer. 116 | func (po *Post) String() string { 117 | var builder strings.Builder 118 | builder.WriteString("Post(") 119 | builder.WriteString(fmt.Sprintf("id=%v, ", po.ID)) 120 | builder.WriteString("tenant_id=") 121 | builder.WriteString(fmt.Sprintf("%v", po.TenantID)) 122 | builder.WriteString(", ") 123 | builder.WriteString("title=") 124 | builder.WriteString(po.Title) 125 | builder.WriteString(", ") 126 | builder.WriteString("description=") 127 | builder.WriteString(po.Description) 128 | builder.WriteString(", ") 129 | builder.WriteString("dsn=") 130 | builder.WriteString(po.Dsn) 131 | builder.WriteByte(')') 132 | return builder.String() 133 | } 134 | 135 | // Posts is a parsable slice of Post. 136 | type Posts []*Post 137 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/post/post.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package post 4 | 5 | import ( 6 | "entgo.io/ent" 7 | "entgo.io/ent/dialect/sql" 8 | ) 9 | 10 | const ( 11 | // Label holds the string label denoting the post type in the database. 12 | Label = "post" 13 | // FieldID holds the string denoting the id field in the database. 14 | FieldID = "id" 15 | // FieldTenantID holds the string denoting the tenant_id field in the database.
16 | FieldTenantID = "tenant_id" 17 | // FieldTitle holds the string denoting the title field in the database. 18 | FieldTitle = "title" 19 | // FieldDescription holds the string denoting the description field in the database. 20 | FieldDescription = "description" 21 | // FieldDsn holds the string denoting the dsn field in the database. 22 | FieldDsn = "dsn" 23 | // Table holds the table name of the post in the database. 24 | Table = "posts" 25 | ) 26 | 27 | // Columns holds all SQL columns for post fields. 28 | var Columns = []string{ 29 | FieldID, 30 | FieldTenantID, 31 | FieldTitle, 32 | FieldDescription, 33 | FieldDsn, 34 | } 35 | 36 | // ValidColumn reports if the column name is valid (part of the table columns). 37 | func ValidColumn(column string) bool { 38 | for i := range Columns { 39 | if column == Columns[i] { 40 | return true 41 | } 42 | } 43 | return false 44 | } 45 | 46 | // Note that the variables below are initialized by the runtime 47 | // package on the initialization of the application. Therefore, 48 | // it should be imported in the main as follows: 49 | // 50 | // import _ "github.com/go-saas/saas/examples/ent/shared/ent/runtime" 51 | var ( 52 | Hooks [1]ent.Hook 53 | Interceptors [1]ent.Interceptor 54 | ) 55 | 56 | // OrderOption defines the ordering options for the Post queries. 57 | type OrderOption func(*sql.Selector) 58 | 59 | // ByID orders the results by the id field. 60 | func ByID(opts ...sql.OrderTermOption) OrderOption { 61 | return sql.OrderByField(FieldID, opts...).ToFunc() 62 | } 63 | 64 | // ByTenantID orders the results by the tenant_id field. 65 | func ByTenantID(opts ...sql.OrderTermOption) OrderOption { 66 | return sql.OrderByField(FieldTenantID, opts...).ToFunc() 67 | } 68 | 69 | // ByTitle orders the results by the title field. 70 | func ByTitle(opts ...sql.OrderTermOption) OrderOption { 71 | return sql.OrderByField(FieldTitle, opts...).ToFunc() 72 | } 73 | 74 | // ByDescription orders the results by the description field.
75 | func ByDescription(opts ...sql.OrderTermOption) OrderOption { 76 | return sql.OrderByField(FieldDescription, opts...).ToFunc() 77 | } 78 | 79 | // ByDsn orders the results by the dsn field. 80 | func ByDsn(opts ...sql.OrderTermOption) OrderOption { 81 | return sql.OrderByField(FieldDsn, opts...).ToFunc() 82 | } 83 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/post_delete.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "context" 7 | 8 | "entgo.io/ent/dialect/sql" 9 | "entgo.io/ent/dialect/sql/sqlgraph" 10 | "entgo.io/ent/schema/field" 11 | "github.com/go-saas/saas/examples/ent/shared/ent/post" 12 | "github.com/go-saas/saas/examples/ent/shared/ent/predicate" 13 | ) 14 | 15 | // PostDelete is the builder for deleting a Post entity. 16 | type PostDelete struct { 17 | config 18 | hooks []Hook 19 | mutation *PostMutation 20 | } 21 | 22 | // Where appends a list predicates to the PostDelete builder. 23 | func (pd *PostDelete) Where(ps ...predicate.Post) *PostDelete { 24 | pd.mutation.Where(ps...) 25 | return pd 26 | } 27 | 28 | // Exec executes the deletion query and returns how many vertices were deleted. 29 | func (pd *PostDelete) Exec(ctx context.Context) (int, error) { 30 | return withHooks(ctx, pd.sqlExec, pd.mutation, pd.hooks) 31 | } 32 | 33 | // ExecX is like Exec, but panics if an error occurs.
34 | func (pd *PostDelete) ExecX(ctx context.Context) int { 35 | n, err := pd.Exec(ctx) 36 | if err != nil { 37 | panic(err) 38 | } 39 | return n 40 | } 41 | 42 | func (pd *PostDelete) sqlExec(ctx context.Context) (int, error) { 43 | _spec := sqlgraph.NewDeleteSpec(post.Table, sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt)) 44 | if ps := pd.mutation.predicates; len(ps) > 0 { 45 | _spec.Predicate = func(selector *sql.Selector) { 46 | for i := range ps { 47 | ps[i](selector) 48 | } 49 | } 50 | } 51 | affected, err := sqlgraph.DeleteNodes(ctx, pd.driver, _spec) 52 | if err != nil && sqlgraph.IsConstraintError(err) { 53 | err = &ConstraintError{msg: err.Error(), wrap: err} 54 | } 55 | pd.mutation.done = true 56 | return affected, err 57 | } 58 | 59 | // PostDeleteOne is the builder for deleting a single Post entity. 60 | type PostDeleteOne struct { 61 | pd *PostDelete 62 | } 63 | 64 | // Where appends a list predicates to the PostDelete builder. 65 | func (pdo *PostDeleteOne) Where(ps ...predicate.Post) *PostDeleteOne { 66 | pdo.pd.mutation.Where(ps...) 67 | return pdo 68 | } 69 | 70 | // Exec executes the deletion query. 71 | func (pdo *PostDeleteOne) Exec(ctx context.Context) error { 72 | n, err := pdo.pd.Exec(ctx) 73 | switch { 74 | case err != nil: 75 | return err 76 | case n == 0: 77 | return &NotFoundError{post.Label} 78 | default: 79 | return nil 80 | } 81 | } 82 | 83 | // ExecX is like Exec, but panics if an error occurs. 84 | func (pdo *PostDeleteOne) ExecX(ctx context.Context) { 85 | if err := pdo.Exec(ctx); err != nil { 86 | panic(err) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/predicate/predicate.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package predicate 4 | 5 | import ( 6 | "entgo.io/ent/dialect/sql" 7 | ) 8 | 9 | // Post is the predicate function for post builders. 
10 | type Post func(*sql.Selector) 11 | 12 | // Tenant is the predicate function for tenant builders. 13 | type Tenant func(*sql.Selector) 14 | 15 | // TenantConn is the predicate function for tenantconn builders. 16 | type TenantConn func(*sql.Selector) 17 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/privacy/privacy.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package privacy 4 | 5 | import ( 6 | "context" 7 | 8 | "github.com/go-saas/saas/examples/ent/shared/ent" 9 | 10 | "entgo.io/ent/privacy" 11 | ) 12 | 13 | var ( 14 | // Allow may be returned by rules to indicate that the policy 15 | // evaluation should terminate with allow decision. 16 | Allow = privacy.Allow 17 | 18 | // Deny may be returned by rules to indicate that the policy 19 | // evaluation should terminate with deny decision. 20 | Deny = privacy.Deny 21 | 22 | // Skip may be returned by rules to indicate that the policy 23 | // evaluation should continue to the next rule. 24 | Skip = privacy.Skip 25 | ) 26 | 27 | // Allowf returns a formatted wrapped Allow decision. 28 | func Allowf(format string, a ...any) error { 29 | return privacy.Allowf(format, a...) 30 | } 31 | 32 | // Denyf returns a formatted wrapped Deny decision. 33 | func Denyf(format string, a ...any) error { 34 | return privacy.Denyf(format, a...) 35 | } 36 | 37 | // Skipf returns a formatted wrapped Skip decision. 38 | func Skipf(format string, a ...any) error { 39 | return privacy.Skipf(format, a...) 40 | } 41 | 42 | // DecisionContext creates a new context from the given parent context with 43 | // a policy decision attach to it. 44 | func DecisionContext(parent context.Context, decision error) context.Context { 45 | return privacy.DecisionContext(parent, decision) 46 | } 47 | 48 | // DecisionFromContext retrieves the policy decision from the context. 
49 | func DecisionFromContext(ctx context.Context) (error, bool) { 50 | return privacy.DecisionFromContext(ctx) 51 | } 52 | 53 | type ( 54 | // Policy groups query and mutation policies. 55 | Policy = privacy.Policy 56 | 57 | // QueryRule defines the interface deciding whether a 58 | // query is allowed and optionally modify it. 59 | QueryRule = privacy.QueryRule 60 | // QueryPolicy combines multiple query rules into a single policy. 61 | QueryPolicy = privacy.QueryPolicy 62 | 63 | // MutationRule defines the interface which decides whether a 64 | // mutation is allowed and optionally modifies it. 65 | MutationRule = privacy.MutationRule 66 | // MutationPolicy combines multiple mutation rules into a single policy. 67 | MutationPolicy = privacy.MutationPolicy 68 | // MutationRuleFunc type is an adapter which allows the use of 69 | // ordinary functions as mutation rules. 70 | MutationRuleFunc = privacy.MutationRuleFunc 71 | 72 | // QueryMutationRule is an interface which groups query and mutation rules. 73 | QueryMutationRule = privacy.QueryMutationRule 74 | ) 75 | 76 | // QueryRuleFunc type is an adapter to allow the use of 77 | // ordinary functions as query rules. 78 | type QueryRuleFunc func(context.Context, ent.Query) error 79 | 80 | // Eval returns f(ctx, q). 81 | func (f QueryRuleFunc) EvalQuery(ctx context.Context, q ent.Query) error { 82 | return f(ctx, q) 83 | } 84 | 85 | // AlwaysAllowRule returns a rule that returns an allow decision. 86 | func AlwaysAllowRule() QueryMutationRule { 87 | return privacy.AlwaysAllowRule() 88 | } 89 | 90 | // AlwaysDenyRule returns a rule that returns a deny decision. 91 | func AlwaysDenyRule() QueryMutationRule { 92 | return privacy.AlwaysDenyRule() 93 | } 94 | 95 | // ContextQueryMutationRule creates a query/mutation rule from a context eval func. 
96 | func ContextQueryMutationRule(eval func(context.Context) error) QueryMutationRule { 97 | return privacy.ContextQueryMutationRule(eval) 98 | } 99 | 100 | // OnMutationOperation evaluates the given rule only on a given mutation operation. 101 | func OnMutationOperation(rule MutationRule, op ent.Op) MutationRule { 102 | return privacy.OnMutationOperation(rule, op) 103 | } 104 | 105 | // DenyMutationOperationRule returns a rule denying specified mutation operation. 106 | func DenyMutationOperationRule(op ent.Op) MutationRule { 107 | rule := MutationRuleFunc(func(_ context.Context, m ent.Mutation) error { 108 | return Denyf("ent/privacy: operation %s is not allowed", m.Op()) 109 | }) 110 | return OnMutationOperation(rule, op) 111 | } 112 | 113 | // The PostQueryRuleFunc type is an adapter to allow the use of ordinary 114 | // functions as a query rule. 115 | type PostQueryRuleFunc func(context.Context, *ent.PostQuery) error 116 | 117 | // EvalQuery return f(ctx, q). 118 | func (f PostQueryRuleFunc) EvalQuery(ctx context.Context, q ent.Query) error { 119 | if q, ok := q.(*ent.PostQuery); ok { 120 | return f(ctx, q) 121 | } 122 | return Denyf("ent/privacy: unexpected query type %T, expect *ent.PostQuery", q) 123 | } 124 | 125 | // The PostMutationRuleFunc type is an adapter to allow the use of ordinary 126 | // functions as a mutation rule. 127 | type PostMutationRuleFunc func(context.Context, *ent.PostMutation) error 128 | 129 | // EvalMutation calls f(ctx, m). 130 | func (f PostMutationRuleFunc) EvalMutation(ctx context.Context, m ent.Mutation) error { 131 | if m, ok := m.(*ent.PostMutation); ok { 132 | return f(ctx, m) 133 | } 134 | return Denyf("ent/privacy: unexpected mutation type %T, expect *ent.PostMutation", m) 135 | } 136 | 137 | // The TenantQueryRuleFunc type is an adapter to allow the use of ordinary 138 | // functions as a query rule. 139 | type TenantQueryRuleFunc func(context.Context, *ent.TenantQuery) error 140 | 141 | // EvalQuery return f(ctx, q). 
142 | func (f TenantQueryRuleFunc) EvalQuery(ctx context.Context, q ent.Query) error { 143 | if q, ok := q.(*ent.TenantQuery); ok { 144 | return f(ctx, q) 145 | } 146 | return Denyf("ent/privacy: unexpected query type %T, expect *ent.TenantQuery", q) 147 | } 148 | 149 | // The TenantMutationRuleFunc type is an adapter to allow the use of ordinary 150 | // functions as a mutation rule. 151 | type TenantMutationRuleFunc func(context.Context, *ent.TenantMutation) error 152 | 153 | // EvalMutation calls f(ctx, m). 154 | func (f TenantMutationRuleFunc) EvalMutation(ctx context.Context, m ent.Mutation) error { 155 | if m, ok := m.(*ent.TenantMutation); ok { 156 | return f(ctx, m) 157 | } 158 | return Denyf("ent/privacy: unexpected mutation type %T, expect *ent.TenantMutation", m) 159 | } 160 | 161 | // The TenantConnQueryRuleFunc type is an adapter to allow the use of ordinary 162 | // functions as a query rule. 163 | type TenantConnQueryRuleFunc func(context.Context, *ent.TenantConnQuery) error 164 | 165 | // EvalQuery return f(ctx, q). 166 | func (f TenantConnQueryRuleFunc) EvalQuery(ctx context.Context, q ent.Query) error { 167 | if q, ok := q.(*ent.TenantConnQuery); ok { 168 | return f(ctx, q) 169 | } 170 | return Denyf("ent/privacy: unexpected query type %T, expect *ent.TenantConnQuery", q) 171 | } 172 | 173 | // The TenantConnMutationRuleFunc type is an adapter to allow the use of ordinary 174 | // functions as a mutation rule. 175 | type TenantConnMutationRuleFunc func(context.Context, *ent.TenantConnMutation) error 176 | 177 | // EvalMutation calls f(ctx, m). 
178 | func (f TenantConnMutationRuleFunc) EvalMutation(ctx context.Context, m ent.Mutation) error { 179 | if m, ok := m.(*ent.TenantConnMutation); ok { 180 | return f(ctx, m) 181 | } 182 | return Denyf("ent/privacy: unexpected mutation type %T, expect *ent.TenantConnMutation", m) 183 | } 184 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/runtime.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | // The schema-stitching logic is generated in github.com/go-saas/saas/examples/ent/shared/ent/runtime/runtime.go 6 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/runtime/runtime.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package runtime 4 | 5 | import ( 6 | "time" 7 | 8 | "github.com/go-saas/saas/examples/ent/shared/ent/post" 9 | "github.com/go-saas/saas/examples/ent/shared/ent/schema" 10 | "github.com/go-saas/saas/examples/ent/shared/ent/tenant" 11 | "github.com/go-saas/saas/examples/ent/shared/ent/tenantconn" 12 | ) 13 | 14 | // The init function reads all schema descriptors with runtime code 15 | // (default values, validators, hooks and policies) and stitches it 16 | // to their package variables. 17 | func init() { 18 | postMixin := schema.Post{}.Mixin() 19 | postMixinHooks0 := postMixin[0].Hooks() 20 | post.Hooks[0] = postMixinHooks0[0] 21 | postMixinInters0 := postMixin[0].Interceptors() 22 | post.Interceptors[0] = postMixinInters0[0] 23 | tenantMixin := schema.Tenant{}.Mixin() 24 | tenantMixinFields0 := tenantMixin[0].Fields() 25 | _ = tenantMixinFields0 26 | tenantFields := schema.Tenant{}.Fields() 27 | _ = tenantFields 28 | // tenantDescCreateTime is the schema descriptor for create_time field. 
29 | tenantDescCreateTime := tenantMixinFields0[0].Descriptor() 30 | // tenant.DefaultCreateTime holds the default value on creation for the create_time field. 31 | tenant.DefaultCreateTime = tenantDescCreateTime.Default.(func() time.Time) 32 | // tenantDescUpdateTime is the schema descriptor for update_time field. 33 | tenantDescUpdateTime := tenantMixinFields0[1].Descriptor() 34 | // tenant.DefaultUpdateTime holds the default value on creation for the update_time field. 35 | tenant.DefaultUpdateTime = tenantDescUpdateTime.Default.(func() time.Time) 36 | // tenant.UpdateDefaultUpdateTime holds the default value on update for the update_time field. 37 | tenant.UpdateDefaultUpdateTime = tenantDescUpdateTime.UpdateDefault.(func() time.Time) 38 | tenantconnMixin := schema.TenantConn{}.Mixin() 39 | tenantconnMixinFields0 := tenantconnMixin[0].Fields() 40 | _ = tenantconnMixinFields0 41 | tenantconnFields := schema.TenantConn{}.Fields() 42 | _ = tenantconnFields 43 | // tenantconnDescCreateTime is the schema descriptor for create_time field. 44 | tenantconnDescCreateTime := tenantconnMixinFields0[0].Descriptor() 45 | // tenantconn.DefaultCreateTime holds the default value on creation for the create_time field. 46 | tenantconn.DefaultCreateTime = tenantconnDescCreateTime.Default.(func() time.Time) 47 | // tenantconnDescUpdateTime is the schema descriptor for update_time field. 48 | tenantconnDescUpdateTime := tenantconnMixinFields0[1].Descriptor() 49 | // tenantconn.DefaultUpdateTime holds the default value on creation for the update_time field. 50 | tenantconn.DefaultUpdateTime = tenantconnDescUpdateTime.Default.(func() time.Time) 51 | // tenantconn.UpdateDefaultUpdateTime holds the default value on update for the update_time field. 52 | tenantconn.UpdateDefaultUpdateTime = tenantconnDescUpdateTime.UpdateDefault.(func() time.Time) 53 | } 54 | 55 | const ( 56 | Version = "v0.12.4" // Version of ent codegen. 
57 | Sum = "h1:LddPnAyxls/O7DTXZvUGDj0NZIdGSu317+aoNLJWbD8=" // Sum of ent codegen. 58 | ) 59 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/schema/post.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import ( 4 | "entgo.io/ent" 5 | "entgo.io/ent/schema/field" 6 | sent "github.com/go-saas/saas/ent" 7 | ) 8 | 9 | // Post holds the schema definition for the Post entity. 10 | type Post struct { 11 | ent.Schema 12 | } 13 | 14 | // Fields of the Post. 15 | func (Post) Fields() []ent.Field { 16 | return []ent.Field{ 17 | field.Int("id"), 18 | field.String("title"), 19 | field.String("description").Optional(), 20 | field.String("dsn").Optional(), 21 | } 22 | } 23 | 24 | // Edges of the Post. 25 | func (Post) Edges() []ent.Edge { 26 | return nil 27 | } 28 | 29 | func (Post) Mixin() []ent.Mixin { 30 | return []ent.Mixin{ 31 | sent.HasTenant{}, 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/schema/tenant.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import ( 4 | "entgo.io/ent" 5 | "entgo.io/ent/schema/edge" 6 | "entgo.io/ent/schema/field" 7 | "entgo.io/ent/schema/mixin" 8 | ) 9 | 10 | // Tenant holds the schema definition for the Tenant entity. 11 | type Tenant struct { 12 | ent.Schema 13 | } 14 | 15 | // Fields of the Tenant. 16 | func (Tenant) Fields() []ent.Field { 17 | return []ent.Field{ 18 | field.Int("id"), 19 | field.String("name"), 20 | field.String("display_name").Optional(), 21 | field.String("region").Optional(), 22 | } 23 | } 24 | 25 | // Edges of the Tenant. 
26 | func (Tenant) Edges() []ent.Edge { 27 | return []ent.Edge{ 28 | edge.To("conn", TenantConn.Type), 29 | } 30 | } 31 | 32 | func (Tenant) Mixin() []ent.Mixin { 33 | return []ent.Mixin{ 34 | mixin.Time{}, 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/schema/tenantconn.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import ( 4 | "entgo.io/ent" 5 | "entgo.io/ent/schema/field" 6 | "entgo.io/ent/schema/mixin" 7 | ) 8 | 9 | // TenantConn holds the schema definition for the TenantConn entity. 10 | type TenantConn struct { 11 | ent.Schema 12 | } 13 | 14 | // Fields of the TenantConn. 15 | func (TenantConn) Fields() []ent.Field { 16 | return []ent.Field{ 17 | field.String("key"), 18 | field.String("value"), 19 | } 20 | } 21 | 22 | // Edges of the TenantConn. 23 | func (TenantConn) Edges() []ent.Edge { 24 | return nil 25 | } 26 | func (TenantConn) Mixin() []ent.Mixin { 27 | return []ent.Mixin{ 28 | mixin.Time{}, 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/tenant.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "fmt" 7 | "strings" 8 | "time" 9 | 10 | "entgo.io/ent" 11 | "entgo.io/ent/dialect/sql" 12 | "github.com/go-saas/saas/examples/ent/shared/ent/tenant" 13 | ) 14 | 15 | // Tenant is the model entity for the Tenant schema. 16 | type Tenant struct { 17 | config `json:"-"` 18 | // ID of the ent. 19 | ID int `json:"id,omitempty"` 20 | // CreateTime holds the value of the "create_time" field. 21 | CreateTime time.Time `json:"create_time,omitempty"` 22 | // UpdateTime holds the value of the "update_time" field. 23 | UpdateTime time.Time `json:"update_time,omitempty"` 24 | // Name holds the value of the "name" field. 
25 | Name string `json:"name,omitempty"` 26 | // DisplayName holds the value of the "display_name" field. 27 | DisplayName string `json:"display_name,omitempty"` 28 | // Region holds the value of the "region" field. 29 | Region string `json:"region,omitempty"` 30 | // Edges holds the relations/edges for other nodes in the graph. 31 | // The values are being populated by the TenantQuery when eager-loading is set. 32 | Edges TenantEdges `json:"edges"` 33 | selectValues sql.SelectValues 34 | } 35 | 36 | // TenantEdges holds the relations/edges for other nodes in the graph. 37 | type TenantEdges struct { 38 | // Conn holds the value of the conn edge. 39 | Conn []*TenantConn `json:"conn,omitempty"` 40 | // loadedTypes holds the information for reporting if a 41 | // type was loaded (or requested) in eager-loading or not. 42 | loadedTypes [1]bool 43 | } 44 | 45 | // ConnOrErr returns the Conn value or an error if the edge 46 | // was not loaded in eager-loading. 47 | func (e TenantEdges) ConnOrErr() ([]*TenantConn, error) { 48 | if e.loadedTypes[0] { 49 | return e.Conn, nil 50 | } 51 | return nil, &NotLoadedError{edge: "conn"} 52 | } 53 | 54 | // scanValues returns the types for scanning values from sql.Rows. 55 | func (*Tenant) scanValues(columns []string) ([]any, error) { 56 | values := make([]any, len(columns)) 57 | for i := range columns { 58 | switch columns[i] { 59 | case tenant.FieldID: 60 | values[i] = new(sql.NullInt64) 61 | case tenant.FieldName, tenant.FieldDisplayName, tenant.FieldRegion: 62 | values[i] = new(sql.NullString) 63 | case tenant.FieldCreateTime, tenant.FieldUpdateTime: 64 | values[i] = new(sql.NullTime) 65 | default: 66 | values[i] = new(sql.UnknownType) 67 | } 68 | } 69 | return values, nil 70 | } 71 | 72 | // assignValues assigns the values that were returned from sql.Rows (after scanning) 73 | // to the Tenant fields. 
74 | func (t *Tenant) assignValues(columns []string, values []any) error { 75 | if m, n := len(values), len(columns); m < n { 76 | return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) 77 | } 78 | for i := range columns { 79 | switch columns[i] { 80 | case tenant.FieldID: 81 | value, ok := values[i].(*sql.NullInt64) 82 | if !ok { 83 | return fmt.Errorf("unexpected type %T for field id", value) 84 | } 85 | t.ID = int(value.Int64) 86 | case tenant.FieldCreateTime: 87 | if value, ok := values[i].(*sql.NullTime); !ok { 88 | return fmt.Errorf("unexpected type %T for field create_time", values[i]) 89 | } else if value.Valid { 90 | t.CreateTime = value.Time 91 | } 92 | case tenant.FieldUpdateTime: 93 | if value, ok := values[i].(*sql.NullTime); !ok { 94 | return fmt.Errorf("unexpected type %T for field update_time", values[i]) 95 | } else if value.Valid { 96 | t.UpdateTime = value.Time 97 | } 98 | case tenant.FieldName: 99 | if value, ok := values[i].(*sql.NullString); !ok { 100 | return fmt.Errorf("unexpected type %T for field name", values[i]) 101 | } else if value.Valid { 102 | t.Name = value.String 103 | } 104 | case tenant.FieldDisplayName: 105 | if value, ok := values[i].(*sql.NullString); !ok { 106 | return fmt.Errorf("unexpected type %T for field display_name", values[i]) 107 | } else if value.Valid { 108 | t.DisplayName = value.String 109 | } 110 | case tenant.FieldRegion: 111 | if value, ok := values[i].(*sql.NullString); !ok { 112 | return fmt.Errorf("unexpected type %T for field region", values[i]) 113 | } else if value.Valid { 114 | t.Region = value.String 115 | } 116 | default: 117 | t.selectValues.Set(columns[i], values[i]) 118 | } 119 | } 120 | return nil 121 | } 122 | 123 | // Value returns the ent.Value that was dynamically selected and assigned to the Tenant. 124 | // This includes values selected through modifiers, order, etc. 
125 | func (t *Tenant) Value(name string) (ent.Value, error) { 126 | return t.selectValues.Get(name) 127 | } 128 | 129 | // QueryConn queries the "conn" edge of the Tenant entity. 130 | func (t *Tenant) QueryConn() *TenantConnQuery { 131 | return NewTenantClient(t.config).QueryConn(t) 132 | } 133 | 134 | // Update returns a builder for updating this Tenant. 135 | // Note that you need to call Tenant.Unwrap() before calling this method if this Tenant 136 | // was returned from a transaction, and the transaction was committed or rolled back. 137 | func (t *Tenant) Update() *TenantUpdateOne { 138 | return NewTenantClient(t.config).UpdateOne(t) 139 | } 140 | 141 | // Unwrap unwraps the Tenant entity that was returned from a transaction after it was closed, 142 | // so that all future queries will be executed through the driver which created the transaction. 143 | func (t *Tenant) Unwrap() *Tenant { 144 | _tx, ok := t.config.driver.(*txDriver) 145 | if !ok { 146 | panic("ent: Tenant is not a transactional entity") 147 | } 148 | t.config.driver = _tx.drv 149 | return t 150 | } 151 | 152 | // String implements the fmt.Stringer. 153 | func (t *Tenant) String() string { 154 | var builder strings.Builder 155 | builder.WriteString("Tenant(") 156 | builder.WriteString(fmt.Sprintf("id=%v, ", t.ID)) 157 | builder.WriteString("create_time=") 158 | builder.WriteString(t.CreateTime.Format(time.ANSIC)) 159 | builder.WriteString(", ") 160 | builder.WriteString("update_time=") 161 | builder.WriteString(t.UpdateTime.Format(time.ANSIC)) 162 | builder.WriteString(", ") 163 | builder.WriteString("name=") 164 | builder.WriteString(t.Name) 165 | builder.WriteString(", ") 166 | builder.WriteString("display_name=") 167 | builder.WriteString(t.DisplayName) 168 | builder.WriteString(", ") 169 | builder.WriteString("region=") 170 | builder.WriteString(t.Region) 171 | builder.WriteByte(')') 172 | return builder.String() 173 | } 174 | 175 | // Tenants is a parsable slice of Tenant. 
176 | type Tenants []*Tenant 177 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/tenant/tenant.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package tenant 4 | 5 | import ( 6 | "time" 7 | 8 | "entgo.io/ent/dialect/sql" 9 | "entgo.io/ent/dialect/sql/sqlgraph" 10 | ) 11 | 12 | const ( 13 | // Label holds the string label denoting the tenant type in the database. 14 | Label = "tenant" 15 | // FieldID holds the string denoting the id field in the database. 16 | FieldID = "id" 17 | // FieldCreateTime holds the string denoting the create_time field in the database. 18 | FieldCreateTime = "create_time" 19 | // FieldUpdateTime holds the string denoting the update_time field in the database. 20 | FieldUpdateTime = "update_time" 21 | // FieldName holds the string denoting the name field in the database. 22 | FieldName = "name" 23 | // FieldDisplayName holds the string denoting the display_name field in the database. 24 | FieldDisplayName = "display_name" 25 | // FieldRegion holds the string denoting the region field in the database. 26 | FieldRegion = "region" 27 | // EdgeConn holds the string denoting the conn edge name in mutations. 28 | EdgeConn = "conn" 29 | // Table holds the table name of the tenant in the database. 30 | Table = "tenants" 31 | // ConnTable is the table that holds the conn relation/edge. 32 | ConnTable = "tenant_conns" 33 | // ConnInverseTable is the table name for the TenantConn entity. 34 | // It exists in this package in order to avoid circular dependency with the "tenantconn" package. 35 | ConnInverseTable = "tenant_conns" 36 | // ConnColumn is the table column denoting the conn relation/edge. 37 | ConnColumn = "tenant_conn" 38 | ) 39 | 40 | // Columns holds all SQL columns for tenant fields. 
41 | var Columns = []string{ 42 | FieldID, 43 | FieldCreateTime, 44 | FieldUpdateTime, 45 | FieldName, 46 | FieldDisplayName, 47 | FieldRegion, 48 | } 49 | 50 | // ValidColumn reports if the column name is valid (part of the table columns). 51 | func ValidColumn(column string) bool { 52 | for i := range Columns { 53 | if column == Columns[i] { 54 | return true 55 | } 56 | } 57 | return false 58 | } 59 | 60 | var ( 61 | // DefaultCreateTime holds the default value on creation for the "create_time" field. 62 | DefaultCreateTime func() time.Time 63 | // DefaultUpdateTime holds the default value on creation for the "update_time" field. 64 | DefaultUpdateTime func() time.Time 65 | // UpdateDefaultUpdateTime holds the default value on update for the "update_time" field. 66 | UpdateDefaultUpdateTime func() time.Time 67 | ) 68 | 69 | // OrderOption defines the ordering options for the Tenant queries. 70 | type OrderOption func(*sql.Selector) 71 | 72 | // ByID orders the results by the id field. 73 | func ByID(opts ...sql.OrderTermOption) OrderOption { 74 | return sql.OrderByField(FieldID, opts...).ToFunc() 75 | } 76 | 77 | // ByCreateTime orders the results by the create_time field. 78 | func ByCreateTime(opts ...sql.OrderTermOption) OrderOption { 79 | return sql.OrderByField(FieldCreateTime, opts...).ToFunc() 80 | } 81 | 82 | // ByUpdateTime orders the results by the update_time field. 83 | func ByUpdateTime(opts ...sql.OrderTermOption) OrderOption { 84 | return sql.OrderByField(FieldUpdateTime, opts...).ToFunc() 85 | } 86 | 87 | // ByName orders the results by the name field. 88 | func ByName(opts ...sql.OrderTermOption) OrderOption { 89 | return sql.OrderByField(FieldName, opts...).ToFunc() 90 | } 91 | 92 | // ByDisplayName orders the results by the display_name field. 93 | func ByDisplayName(opts ...sql.OrderTermOption) OrderOption { 94 | return sql.OrderByField(FieldDisplayName, opts...).ToFunc() 95 | } 96 | 97 | // ByRegion orders the results by the region field. 
98 | func ByRegion(opts ...sql.OrderTermOption) OrderOption { 99 | return sql.OrderByField(FieldRegion, opts...).ToFunc() 100 | } 101 | 102 | // ByConnCount orders the results by conn count. 103 | func ByConnCount(opts ...sql.OrderTermOption) OrderOption { 104 | return func(s *sql.Selector) { 105 | sqlgraph.OrderByNeighborsCount(s, newConnStep(), opts...) 106 | } 107 | } 108 | 109 | // ByConn orders the results by conn terms. 110 | func ByConn(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { 111 | return func(s *sql.Selector) { 112 | sqlgraph.OrderByNeighborTerms(s, newConnStep(), append([]sql.OrderTerm{term}, terms...)...) 113 | } 114 | } 115 | func newConnStep() *sqlgraph.Step { 116 | return sqlgraph.NewStep( 117 | sqlgraph.From(Table, FieldID), 118 | sqlgraph.To(ConnInverseTable, FieldID), 119 | sqlgraph.Edge(sqlgraph.O2M, false, ConnTable, ConnColumn), 120 | ) 121 | } 122 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/tenant_delete.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "context" 7 | 8 | "entgo.io/ent/dialect/sql" 9 | "entgo.io/ent/dialect/sql/sqlgraph" 10 | "entgo.io/ent/schema/field" 11 | "github.com/go-saas/saas/examples/ent/shared/ent/predicate" 12 | "github.com/go-saas/saas/examples/ent/shared/ent/tenant" 13 | ) 14 | 15 | // TenantDelete is the builder for deleting a Tenant entity. 16 | type TenantDelete struct { 17 | config 18 | hooks []Hook 19 | mutation *TenantMutation 20 | } 21 | 22 | // Where appends a list predicates to the TenantDelete builder. 23 | func (td *TenantDelete) Where(ps ...predicate.Tenant) *TenantDelete { 24 | td.mutation.Where(ps...) 25 | return td 26 | } 27 | 28 | // Exec executes the deletion query and returns how many vertices were deleted. 
29 | func (td *TenantDelete) Exec(ctx context.Context) (int, error) { 30 | return withHooks(ctx, td.sqlExec, td.mutation, td.hooks) 31 | } 32 | 33 | // ExecX is like Exec, but panics if an error occurs. 34 | func (td *TenantDelete) ExecX(ctx context.Context) int { 35 | n, err := td.Exec(ctx) 36 | if err != nil { 37 | panic(err) 38 | } 39 | return n 40 | } 41 | 42 | func (td *TenantDelete) sqlExec(ctx context.Context) (int, error) { 43 | _spec := sqlgraph.NewDeleteSpec(tenant.Table, sqlgraph.NewFieldSpec(tenant.FieldID, field.TypeInt)) 44 | if ps := td.mutation.predicates; len(ps) > 0 { 45 | _spec.Predicate = func(selector *sql.Selector) { 46 | for i := range ps { 47 | ps[i](selector) 48 | } 49 | } 50 | } 51 | affected, err := sqlgraph.DeleteNodes(ctx, td.driver, _spec) 52 | if err != nil && sqlgraph.IsConstraintError(err) { 53 | err = &ConstraintError{msg: err.Error(), wrap: err} 54 | } 55 | td.mutation.done = true 56 | return affected, err 57 | } 58 | 59 | // TenantDeleteOne is the builder for deleting a single Tenant entity. 60 | type TenantDeleteOne struct { 61 | td *TenantDelete 62 | } 63 | 64 | // Where appends a list predicates to the TenantDelete builder. 65 | func (tdo *TenantDeleteOne) Where(ps ...predicate.Tenant) *TenantDeleteOne { 66 | tdo.td.mutation.Where(ps...) 67 | return tdo 68 | } 69 | 70 | // Exec executes the deletion query. 71 | func (tdo *TenantDeleteOne) Exec(ctx context.Context) error { 72 | n, err := tdo.td.Exec(ctx) 73 | switch { 74 | case err != nil: 75 | return err 76 | case n == 0: 77 | return &NotFoundError{tenant.Label} 78 | default: 79 | return nil 80 | } 81 | } 82 | 83 | // ExecX is like Exec, but panics if an error occurs. 
84 | func (tdo *TenantDeleteOne) ExecX(ctx context.Context) { 85 | if err := tdo.Exec(ctx); err != nil { 86 | panic(err) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /examples/ent/shared/ent/tenantconn.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "fmt" 7 | "strings" 8 | "time" 9 | 10 | "entgo.io/ent" 11 | "entgo.io/ent/dialect/sql" 12 | "github.com/go-saas/saas/examples/ent/shared/ent/tenantconn" 13 | ) 14 | 15 | // TenantConn is the model entity for the TenantConn schema. 16 | type TenantConn struct { 17 | config `json:"-"` 18 | // ID of the ent. 19 | ID int `json:"id,omitempty"` 20 | // CreateTime holds the value of the "create_time" field. 21 | CreateTime time.Time `json:"create_time,omitempty"` 22 | // UpdateTime holds the value of the "update_time" field. 23 | UpdateTime time.Time `json:"update_time,omitempty"` 24 | // Key holds the value of the "key" field. 25 | Key string `json:"key,omitempty"` 26 | // Value holds the value of the "value" field. 27 | Value string `json:"value,omitempty"` 28 | tenant_conn *int 29 | selectValues sql.SelectValues 30 | } 31 | 32 | // scanValues returns the types for scanning values from sql.Rows. 
33 | func (*TenantConn) scanValues(columns []string) ([]any, error) { 34 | values := make([]any, len(columns)) 35 | for i := range columns { 36 | switch columns[i] { 37 | case tenantconn.FieldID: 38 | values[i] = new(sql.NullInt64) 39 | case tenantconn.FieldKey, tenantconn.FieldValue: 40 | values[i] = new(sql.NullString) 41 | case tenantconn.FieldCreateTime, tenantconn.FieldUpdateTime: 42 | values[i] = new(sql.NullTime) 43 | case tenantconn.ForeignKeys[0]: // tenant_conn 44 | values[i] = new(sql.NullInt64) 45 | default: 46 | values[i] = new(sql.UnknownType) 47 | } 48 | } 49 | return values, nil 50 | } 51 | 52 | // assignValues assigns the values that were returned from sql.Rows (after scanning) 53 | // to the TenantConn fields. 54 | func (tc *TenantConn) assignValues(columns []string, values []any) error { 55 | if m, n := len(values), len(columns); m < n { 56 | return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) 57 | } 58 | for i := range columns { 59 | switch columns[i] { 60 | case tenantconn.FieldID: 61 | value, ok := values[i].(*sql.NullInt64) 62 | if !ok { 63 | return fmt.Errorf("unexpected type %T for field id", value) 64 | } 65 | tc.ID = int(value.Int64) 66 | case tenantconn.FieldCreateTime: 67 | if value, ok := values[i].(*sql.NullTime); !ok { 68 | return fmt.Errorf("unexpected type %T for field create_time", values[i]) 69 | } else if value.Valid { 70 | tc.CreateTime = value.Time 71 | } 72 | case tenantconn.FieldUpdateTime: 73 | if value, ok := values[i].(*sql.NullTime); !ok { 74 | return fmt.Errorf("unexpected type %T for field update_time", values[i]) 75 | } else if value.Valid { 76 | tc.UpdateTime = value.Time 77 | } 78 | case tenantconn.FieldKey: 79 | if value, ok := values[i].(*sql.NullString); !ok { 80 | return fmt.Errorf("unexpected type %T for field key", values[i]) 81 | } else if value.Valid { 82 | tc.Key = value.String 83 | } 84 | case tenantconn.FieldValue: 85 | if value, ok := values[i].(*sql.NullString); !ok { 86 | return 
fmt.Errorf("unexpected type %T for field value", values[i]) 87 | } else if value.Valid { 88 | tc.Value = value.String 89 | } 90 | case tenantconn.ForeignKeys[0]: 91 | if value, ok := values[i].(*sql.NullInt64); !ok { 92 | return fmt.Errorf("unexpected type %T for edge-field tenant_conn", value) 93 | } else if value.Valid { 94 | tc.tenant_conn = new(int) 95 | *tc.tenant_conn = int(value.Int64) 96 | } 97 | default: 98 | tc.selectValues.Set(columns[i], values[i]) 99 | } 100 | } 101 | return nil 102 | } 103 | 104 | // GetValue returns the ent.Value that was dynamically selected and assigned to the TenantConn. 105 | // This includes values selected through modifiers, order, etc. 106 | func (tc *TenantConn) GetValue(name string) (ent.Value, error) { 107 | return tc.selectValues.Get(name) 108 | } 109 | 110 | // Update returns a builder for updating this TenantConn. 111 | // Note that you need to call TenantConn.Unwrap() before calling this method if this TenantConn 112 | // was returned from a transaction, and the transaction was committed or rolled back. 113 | func (tc *TenantConn) Update() *TenantConnUpdateOne { 114 | return NewTenantConnClient(tc.config).UpdateOne(tc) 115 | } 116 | 117 | // Unwrap unwraps the TenantConn entity that was returned from a transaction after it was closed, 118 | // so that all future queries will be executed through the driver which created the transaction. 119 | func (tc *TenantConn) Unwrap() *TenantConn { 120 | _tx, ok := tc.config.driver.(*txDriver) 121 | if !ok { 122 | panic("ent: TenantConn is not a transactional entity") 123 | } 124 | tc.config.driver = _tx.drv 125 | return tc 126 | } 127 | 128 | // String implements the fmt.Stringer. 
func (tc *TenantConn) String() string {
	// Only the schema fields are printed; the tenant_conn foreign key and
	// selectValues are omitted. Timestamps use time.ANSIC, which carries no
	// time zone or sub-second precision.
	var builder strings.Builder
	builder.WriteString("TenantConn(")
	builder.WriteString(fmt.Sprintf("id=%v, ", tc.ID))
	builder.WriteString("create_time=")
	builder.WriteString(tc.CreateTime.Format(time.ANSIC))
	builder.WriteString(", ")
	builder.WriteString("update_time=")
	builder.WriteString(tc.UpdateTime.Format(time.ANSIC))
	builder.WriteString(", ")
	builder.WriteString("key=")
	builder.WriteString(tc.Key)
	builder.WriteString(", ")
	builder.WriteString("value=")
	builder.WriteString(tc.Value)
	builder.WriteByte(')')
	return builder.String()
}

// TenantConns is a parsable slice of TenantConn.
type TenantConns []*TenantConn

-------------------------------------------------------------------------------- /examples/ent/shared/ent/tenantconn/tenantconn.go: --------------------------------------------------------------------------------
// Code generated by ent, DO NOT EDIT.

package tenantconn

import (
	"time"

	"entgo.io/ent/dialect/sql"
)

const (
	// Label holds the string label denoting the tenantconn type in the database.
	Label = "tenant_conn"
	// FieldID holds the string denoting the id field in the database.
	FieldID = "id"
	// FieldCreateTime holds the string denoting the create_time field in the database.
	FieldCreateTime = "create_time"
	// FieldUpdateTime holds the string denoting the update_time field in the database.
	FieldUpdateTime = "update_time"
	// FieldKey holds the string denoting the key field in the database.
	FieldKey = "key"
	// FieldValue holds the string denoting the value field in the database.
	FieldValue = "value"
	// Table holds the table name of the tenantconn in the database.
	Table = "tenant_conns"
)

// Columns holds all SQL columns for tenantconn fields.
29 | var Columns = []string{ 30 | FieldID, 31 | FieldCreateTime, 32 | FieldUpdateTime, 33 | FieldKey, 34 | FieldValue, 35 | } 36 | 37 | // ForeignKeys holds the SQL foreign-keys that are owned by the "tenant_conns" 38 | // table and are not defined as standalone fields in the schema. 39 | var ForeignKeys = []string{ 40 | "tenant_conn", 41 | } 42 | 43 | // ValidColumn reports if the column name is valid (part of the table columns). 44 | func ValidColumn(column string) bool { 45 | for i := range Columns { 46 | if column == Columns[i] { 47 | return true 48 | } 49 | } 50 | for i := range ForeignKeys { 51 | if column == ForeignKeys[i] { 52 | return true 53 | } 54 | } 55 | return false 56 | } 57 | 58 | var ( 59 | // DefaultCreateTime holds the default value on creation for the "create_time" field. 60 | DefaultCreateTime func() time.Time 61 | // DefaultUpdateTime holds the default value on creation for the "update_time" field. 62 | DefaultUpdateTime func() time.Time 63 | // UpdateDefaultUpdateTime holds the default value on update for the "update_time" field. 64 | UpdateDefaultUpdateTime func() time.Time 65 | ) 66 | 67 | // OrderOption defines the ordering options for the TenantConn queries. 68 | type OrderOption func(*sql.Selector) 69 | 70 | // ByID orders the results by the id field. 71 | func ByID(opts ...sql.OrderTermOption) OrderOption { 72 | return sql.OrderByField(FieldID, opts...).ToFunc() 73 | } 74 | 75 | // ByCreateTime orders the results by the create_time field. 76 | func ByCreateTime(opts ...sql.OrderTermOption) OrderOption { 77 | return sql.OrderByField(FieldCreateTime, opts...).ToFunc() 78 | } 79 | 80 | // ByUpdateTime orders the results by the update_time field. 81 | func ByUpdateTime(opts ...sql.OrderTermOption) OrderOption { 82 | return sql.OrderByField(FieldUpdateTime, opts...).ToFunc() 83 | } 84 | 85 | // ByKey orders the results by the key field. 
func ByKey(opts ...sql.OrderTermOption) OrderOption {
	return sql.OrderByField(FieldKey, opts...).ToFunc()
}

// ByValue orders the results by the value field.
func ByValue(opts ...sql.OrderTermOption) OrderOption {
	return sql.OrderByField(FieldValue, opts...).ToFunc()
}

-------------------------------------------------------------------------------- /examples/ent/shared/ent/tenantconn_delete.go: --------------------------------------------------------------------------------
// Code generated by ent, DO NOT EDIT.

package ent

import (
	"context"

	"entgo.io/ent/dialect/sql"
	"entgo.io/ent/dialect/sql/sqlgraph"
	"entgo.io/ent/schema/field"
	"github.com/go-saas/saas/examples/ent/shared/ent/predicate"
	"github.com/go-saas/saas/examples/ent/shared/ent/tenantconn"
)

// TenantConnDelete is the builder for deleting a TenantConn entity.
type TenantConnDelete struct {
	config
	// hooks are handed to withHooks by Exec.
	hooks []Hook
	// mutation accumulates the Where predicates used by sqlExec.
	mutation *TenantConnMutation
}

// Where appends a list of predicates to the TenantConnDelete builder.
func (tcd *TenantConnDelete) Where(ps ...predicate.TenantConn) *TenantConnDelete {
	tcd.mutation.Where(ps...)
	return tcd
}

// Exec executes the deletion query and returns how many vertices were deleted.
// It passes sqlExec, the mutation and the registered hooks to withHooks
// (defined elsewhere in this package).
func (tcd *TenantConnDelete) Exec(ctx context.Context) (int, error) {
	return withHooks(ctx, tcd.sqlExec, tcd.mutation, tcd.hooks)
}

// ExecX is like Exec, but panics if an error occurs.
func (tcd *TenantConnDelete) ExecX(ctx context.Context) int {
	n, err := tcd.Exec(ctx)
	if err != nil {
		panic(err)
	}
	return n
}

// sqlExec builds a delete spec for tenantconn.Table keyed on the integer
// tenantconn.FieldID column and applies any accumulated predicates as the
// selector. It returns the number of affected rows. Constraint violations are
// wrapped in *ConstraintError. The mutation is marked done even when the
// delete fails.
func (tcd *TenantConnDelete) sqlExec(ctx context.Context) (int, error) {
	_spec := sqlgraph.NewDeleteSpec(tenantconn.Table, sqlgraph.NewFieldSpec(tenantconn.FieldID, field.TypeInt))
	if ps := tcd.mutation.predicates; len(ps) > 0 {
		_spec.Predicate = func(selector *sql.Selector) {
			for i := range ps {
				ps[i](selector)
			}
		}
	}
	affected, err := sqlgraph.DeleteNodes(ctx, tcd.driver, _spec)
	if err != nil && sqlgraph.IsConstraintError(err) {
		err = &ConstraintError{msg: err.Error(), wrap: err}
	}
	tcd.mutation.done = true
	return affected, err
}

// TenantConnDeleteOne is the builder for deleting a single TenantConn entity.
type TenantConnDeleteOne struct {
	// tcd is the underlying builder that Where and Exec delegate to.
	tcd *TenantConnDelete
}

// Where appends a list of predicates to the underlying TenantConnDelete builder.
func (tcdo *TenantConnDeleteOne) Where(ps ...predicate.TenantConn) *TenantConnDeleteOne {
	tcdo.tcd.mutation.Where(ps...)
	return tcdo
}

// Exec executes the deletion query. It returns a *NotFoundError when the
// delete affected no rows.
func (tcdo *TenantConnDeleteOne) Exec(ctx context.Context) error {
	n, err := tcdo.tcd.Exec(ctx)
	switch {
	case err != nil:
		return err
	case n == 0:
		return &NotFoundError{tenantconn.Label}
	default:
		return nil
	}
}

// ExecX is like Exec, but panics if an error occurs.
func (tcdo *TenantConnDeleteOne) ExecX(ctx context.Context) {
	if err := tcdo.Exec(ctx); err != nil {
		panic(err)
	}
}

-------------------------------------------------------------------------------- /examples/ent/tenant/ent/enttest/enttest.go: --------------------------------------------------------------------------------
// Code generated by ent, DO NOT EDIT.
2 | 3 | package enttest 4 | 5 | import ( 6 | "context" 7 | 8 | "github.com/go-saas/saas/examples/ent/tenant/ent" 9 | // required by schema hooks. 10 | _ "github.com/go-saas/saas/examples/ent/tenant/ent/runtime" 11 | 12 | "entgo.io/ent/dialect/sql/schema" 13 | "github.com/go-saas/saas/examples/ent/tenant/ent/migrate" 14 | ) 15 | 16 | type ( 17 | // TestingT is the interface that is shared between 18 | // testing.T and testing.B and used by enttest. 19 | TestingT interface { 20 | FailNow() 21 | Error(...any) 22 | } 23 | 24 | // Option configures client creation. 25 | Option func(*options) 26 | 27 | options struct { 28 | opts []ent.Option 29 | migrateOpts []schema.MigrateOption 30 | } 31 | ) 32 | 33 | // WithOptions forwards options to client creation. 34 | func WithOptions(opts ...ent.Option) Option { 35 | return func(o *options) { 36 | o.opts = append(o.opts, opts...) 37 | } 38 | } 39 | 40 | // WithMigrateOptions forwards options to auto migration. 41 | func WithMigrateOptions(opts ...schema.MigrateOption) Option { 42 | return func(o *options) { 43 | o.migrateOpts = append(o.migrateOpts, opts...) 44 | } 45 | } 46 | 47 | func newOptions(opts []Option) *options { 48 | o := &options{} 49 | for _, opt := range opts { 50 | opt(o) 51 | } 52 | return o 53 | } 54 | 55 | // Open calls ent.Open and auto-run migration. 56 | func Open(t TestingT, driverName, dataSourceName string, opts ...Option) *ent.Client { 57 | o := newOptions(opts) 58 | c, err := ent.Open(driverName, dataSourceName, o.opts...) 59 | if err != nil { 60 | t.Error(err) 61 | t.FailNow() 62 | } 63 | migrateSchema(t, c, o) 64 | return c 65 | } 66 | 67 | // NewClient calls ent.NewClient and auto-run migration. 68 | func NewClient(t TestingT, opts ...Option) *ent.Client { 69 | o := newOptions(opts) 70 | c := ent.NewClient(o.opts...) 
71 | migrateSchema(t, c, o) 72 | return c 73 | } 74 | func migrateSchema(t TestingT, c *ent.Client, o *options) { 75 | tables, err := schema.CopyTables(migrate.Tables) 76 | if err != nil { 77 | t.Error(err) 78 | t.FailNow() 79 | } 80 | if err := migrate.Create(context.Background(), c.Schema, tables, o.migrateOpts...); err != nil { 81 | t.Error(err) 82 | t.FailNow() 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /examples/ent/tenant/ent/generate.go: -------------------------------------------------------------------------------- 1 | package ent 2 | 3 | //go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature sql/upsert,privacy,intercept,schema/snapshot ./schema 4 | -------------------------------------------------------------------------------- /examples/ent/tenant/ent/hook/hook.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package hook 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | 9 | "github.com/go-saas/saas/examples/ent/tenant/ent" 10 | ) 11 | 12 | // The PostFunc type is an adapter to allow the use of ordinary 13 | // function as Post mutator. 14 | type PostFunc func(context.Context, *ent.PostMutation) (ent.Value, error) 15 | 16 | // Mutate calls f(ctx, m). 17 | func (f PostFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { 18 | if mv, ok := m.(*ent.PostMutation); ok { 19 | return f(ctx, mv) 20 | } 21 | return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PostMutation", m) 22 | } 23 | 24 | // Condition is a hook condition function. 25 | type Condition func(context.Context, ent.Mutation) bool 26 | 27 | // And groups conditions with the AND operator. 
28 | func And(first, second Condition, rest ...Condition) Condition { 29 | return func(ctx context.Context, m ent.Mutation) bool { 30 | if !first(ctx, m) || !second(ctx, m) { 31 | return false 32 | } 33 | for _, cond := range rest { 34 | if !cond(ctx, m) { 35 | return false 36 | } 37 | } 38 | return true 39 | } 40 | } 41 | 42 | // Or groups conditions with the OR operator. 43 | func Or(first, second Condition, rest ...Condition) Condition { 44 | return func(ctx context.Context, m ent.Mutation) bool { 45 | if first(ctx, m) || second(ctx, m) { 46 | return true 47 | } 48 | for _, cond := range rest { 49 | if cond(ctx, m) { 50 | return true 51 | } 52 | } 53 | return false 54 | } 55 | } 56 | 57 | // Not negates a given condition. 58 | func Not(cond Condition) Condition { 59 | return func(ctx context.Context, m ent.Mutation) bool { 60 | return !cond(ctx, m) 61 | } 62 | } 63 | 64 | // HasOp is a condition testing mutation operation. 65 | func HasOp(op ent.Op) Condition { 66 | return func(_ context.Context, m ent.Mutation) bool { 67 | return m.Op().Is(op) 68 | } 69 | } 70 | 71 | // HasAddedFields is a condition validating `.AddedField` on fields. 72 | func HasAddedFields(field string, fields ...string) Condition { 73 | return func(_ context.Context, m ent.Mutation) bool { 74 | if _, exists := m.AddedField(field); !exists { 75 | return false 76 | } 77 | for _, field := range fields { 78 | if _, exists := m.AddedField(field); !exists { 79 | return false 80 | } 81 | } 82 | return true 83 | } 84 | } 85 | 86 | // HasClearedFields is a condition validating `.FieldCleared` on fields. 
87 | func HasClearedFields(field string, fields ...string) Condition { 88 | return func(_ context.Context, m ent.Mutation) bool { 89 | if exists := m.FieldCleared(field); !exists { 90 | return false 91 | } 92 | for _, field := range fields { 93 | if exists := m.FieldCleared(field); !exists { 94 | return false 95 | } 96 | } 97 | return true 98 | } 99 | } 100 | 101 | // HasFields is a condition validating `.Field` on fields. 102 | func HasFields(field string, fields ...string) Condition { 103 | return func(_ context.Context, m ent.Mutation) bool { 104 | if _, exists := m.Field(field); !exists { 105 | return false 106 | } 107 | for _, field := range fields { 108 | if _, exists := m.Field(field); !exists { 109 | return false 110 | } 111 | } 112 | return true 113 | } 114 | } 115 | 116 | // If executes the given hook under condition. 117 | // 118 | // hook.If(ComputeAverage, And(HasFields(...), HasAddedFields(...))) 119 | func If(hk ent.Hook, cond Condition) ent.Hook { 120 | return func(next ent.Mutator) ent.Mutator { 121 | return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { 122 | if cond(ctx, m) { 123 | return hk(next).Mutate(ctx, m) 124 | } 125 | return next.Mutate(ctx, m) 126 | }) 127 | } 128 | } 129 | 130 | // On executes the given hook only for the given operation. 131 | // 132 | // hook.On(Log, ent.Delete|ent.Create) 133 | func On(hk ent.Hook, op ent.Op) ent.Hook { 134 | return If(hk, HasOp(op)) 135 | } 136 | 137 | // Unless skips the given hook only for the given operation. 138 | // 139 | // hook.Unless(Log, ent.Update|ent.UpdateOne) 140 | func Unless(hk ent.Hook, op ent.Op) ent.Hook { 141 | return If(hk, Not(HasOp(op))) 142 | } 143 | 144 | // FixedError is a hook returning a fixed error. 
func FixedError(err error) ent.Hook {
	// The wrapped Mutator is ignored: every mutation fails with err.
	return func(ent.Mutator) ent.Mutator {
		return ent.MutateFunc(func(context.Context, ent.Mutation) (ent.Value, error) {
			return nil, err
		})
	}
}

// Reject returns a hook that rejects all operations that match op.
//
//	func (T) Hooks() []ent.Hook {
//		return []ent.Hook{
//			Reject(ent.Delete|ent.Update),
//		}
//	}
func Reject(op ent.Op) ent.Hook {
	hk := FixedError(fmt.Errorf("%s operation is not allowed", op))
	return On(hk, op)
}

// Chain acts as a list of hooks and is effectively immutable.
// Once created, it will always hold the same set of hooks in the same order.
type Chain struct {
	hooks []ent.Hook
}

// NewChain creates a new chain of hooks. The hooks are copied, so later
// changes to the caller's slice do not affect the chain.
func NewChain(hooks ...ent.Hook) Chain {
	return Chain{append([]ent.Hook(nil), hooks...)}
}

// Hook chains the list of hooks and returns the final hook.
func (c Chain) Hook() ent.Hook {
	return func(mutator ent.Mutator) ent.Mutator {
		// Wrap from last to first so that c.hooks[0] ends up as the outermost wrapper.
		for i := len(c.hooks) - 1; i >= 0; i-- {
			mutator = c.hooks[i](mutator)
		}
		return mutator
	}
}

// Append extends a chain, adding the specified hooks
// as the last ones in the mutation flow. It returns a new Chain backed by a
// fresh slice and leaves c unchanged.
func (c Chain) Append(hooks ...ent.Hook) Chain {
	newHooks := make([]ent.Hook, 0, len(c.hooks)+len(hooks))
	newHooks = append(newHooks, c.hooks...)
	newHooks = append(newHooks, hooks...)
	return Chain{newHooks}
}

// Extend extends a chain, adding the specified chain
// as the last ones in the mutation flow.
func (c Chain) Extend(chain Chain) Chain {
	return c.Append(chain.hooks...)
}

-------------------------------------------------------------------------------- /examples/ent/tenant/ent/intercept/intercept.go: --------------------------------------------------------------------------------
// Code generated by ent, DO NOT EDIT.

package intercept

import (
	"context"
	"fmt"

	"entgo.io/ent/dialect/sql"
	"github.com/go-saas/saas/examples/ent/tenant/ent"
	"github.com/go-saas/saas/examples/ent/tenant/ent/post"
	"github.com/go-saas/saas/examples/ent/tenant/ent/predicate"
)

// The Query interface represents an operation that queries a graph.
// By using this interface, users can write generic code that manipulates
// query builders of different types.
type Query interface {
	// Type returns the string representation of the query type.
	Type() string
	// Limit the number of records to be returned by this query.
	Limit(int)
	// Offset to start from.
	Offset(int)
	// Unique configures the query builder to filter duplicate records.
	Unique(bool)
	// Order specifies how the records should be ordered.
	Order(...func(*sql.Selector))
	// WhereP appends storage-level predicates to the query builder. Using this method, users
	// can use type-assertion to append predicates that do not depend on any generated package.
	WhereP(...func(*sql.Selector))
}

// The Func type is an adapter that allows ordinary functions to be used as interceptors.
// Unlike traversal functions, interceptors are skipped during graph traversals. Note that the
// implementation of Func is different from the one defined in entgo.io/ent.InterceptFunc.
type Func func(context.Context, Query) error

// Intercept calls f(ctx, q) and then applies the next Querier.
40 | func (f Func) Intercept(next ent.Querier) ent.Querier { 41 | return ent.QuerierFunc(func(ctx context.Context, q ent.Query) (ent.Value, error) { 42 | query, err := NewQuery(q) 43 | if err != nil { 44 | return nil, err 45 | } 46 | if err := f(ctx, query); err != nil { 47 | return nil, err 48 | } 49 | return next.Query(ctx, q) 50 | }) 51 | } 52 | 53 | // The TraverseFunc type is an adapter to allow the use of ordinary function as Traverser. 54 | // If f is a function with the appropriate signature, TraverseFunc(f) is a Traverser that calls f. 55 | type TraverseFunc func(context.Context, Query) error 56 | 57 | // Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. 58 | func (f TraverseFunc) Intercept(next ent.Querier) ent.Querier { 59 | return next 60 | } 61 | 62 | // Traverse calls f(ctx, q). 63 | func (f TraverseFunc) Traverse(ctx context.Context, q ent.Query) error { 64 | query, err := NewQuery(q) 65 | if err != nil { 66 | return err 67 | } 68 | return f(ctx, query) 69 | } 70 | 71 | // The PostFunc type is an adapter to allow the use of ordinary function as a Querier. 72 | type PostFunc func(context.Context, *ent.PostQuery) (ent.Value, error) 73 | 74 | // Query calls f(ctx, q). 75 | func (f PostFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { 76 | if q, ok := q.(*ent.PostQuery); ok { 77 | return f(ctx, q) 78 | } 79 | return nil, fmt.Errorf("unexpected query type %T. expect *ent.PostQuery", q) 80 | } 81 | 82 | // The TraversePost type is an adapter to allow the use of ordinary function as Traverser. 83 | type TraversePost func(context.Context, *ent.PostQuery) error 84 | 85 | // Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. 86 | func (f TraversePost) Intercept(next ent.Querier) ent.Querier { 87 | return next 88 | } 89 | 90 | // Traverse calls f(ctx, q). 
func (f TraversePost) Traverse(ctx context.Context, q ent.Query) error {
	if q, ok := q.(*ent.PostQuery); ok {
		return f(ctx, q)
	}
	// q here is the original ent.Query parameter (the shadowing q above is
	// scoped to the if), so %T reports the actual query type.
	return fmt.Errorf("unexpected query type %T. expect *ent.PostQuery", q)
}

// NewQuery returns the generic Query interface for the given typed query.
// Only *ent.PostQuery is supported; any other type yields an error.
func NewQuery(q ent.Query) (Query, error) {
	switch q := q.(type) {
	case *ent.PostQuery:
		return &query[*ent.PostQuery, predicate.Post, post.OrderOption]{typ: ent.TypePost, tq: q}, nil
	default:
		return nil, fmt.Errorf("unknown query type %T", q)
	}
}

// query adapts a typed query builder (T) to the generic Query interface.
// P and R are the builder's predicate and order-option types. Both have
// func(*sql.Selector) as their underlying type, so plain selector functions
// convert to them directly.
type query[T any, P ~func(*sql.Selector), R ~func(*sql.Selector)] struct {
	typ string
	tq  interface {
		Limit(int) T
		Offset(int) T
		Unique(bool) T
		Order(...R) T
		Where(...P) T
	}
}

// Type returns the entity type name recorded by NewQuery.
func (q query[T, P, R]) Type() string {
	return q.typ
}

// Limit forwards to the wrapped builder, discarding its return value.
func (q query[T, P, R]) Limit(limit int) {
	q.tq.Limit(limit)
}

// Offset forwards to the wrapped builder, discarding its return value.
func (q query[T, P, R]) Offset(offset int) {
	q.tq.Offset(offset)
}

// Unique forwards to the wrapped builder, discarding its return value.
func (q query[T, P, R]) Unique(unique bool) {
	q.tq.Unique(unique)
}

// Order converts each selector function to the builder's order type R and
// forwards them.
func (q query[T, P, R]) Order(orders ...func(*sql.Selector)) {
	rs := make([]R, len(orders))
	for i := range orders {
		rs[i] = orders[i]
	}
	q.tq.Order(rs...)
}

// WhereP converts each selector function to the builder's predicate type P
// and forwards them to Where.
func (q query[T, P, R]) WhereP(ps ...func(*sql.Selector)) {
	p := make([]P, len(ps))
	for i := range ps {
		p[i] = ps[i]
	}
	q.tq.Where(p...)
}

-------------------------------------------------------------------------------- /examples/ent/tenant/ent/internal/schema.go: --------------------------------------------------------------------------------
// Code generated by ent, DO NOT EDIT.

//go:build tools
// +build tools

// Package internal holds a loadable version of the latest schema.
package internal

// Schema is a JSON snapshot of the tenant ent schema: the Post entity with
// its fields, hooks and interceptors, plus the enabled codegen features.
// It is presumably written by the "schema/snapshot" feature listed in its
// Features array; verify before relying on that.
const Schema = `{"Schema":"github.com/go-saas/saas/examples/ent/tenant/ent/schema","Package":"github.com/go-saas/saas/examples/ent/tenant/ent","Schemas":[{"name":"Post","config":{"Table":""},"fields":[{"name":"tenant_id","type":{"Type":7,"Ident":"*sql.NullString","PkgPath":"database/sql","PkgName":"sql","Nillable":true,"RType":{"Name":"NullString","Ident":"sql.NullString","Kind":22,"PkgPath":"database/sql","Methods":{"Scan":{"In":[{"Name":"","Ident":"interface {}","Kind":20,"PkgPath":"","Methods":null}],"Out":[{"Name":"error","Ident":"error","Kind":20,"PkgPath":"","Methods":null}]},"Value":{"In":[],"Out":[{"Name":"Value","Ident":"driver.Value","Kind":20,"PkgPath":"database/sql/driver","Methods":null},{"Name":"error","Ident":"error","Kind":20,"PkgPath":"","Methods":null}]}}}},"optional":true,"position":{"Index":0,"MixedIn":true,"MixinIndex":0}},{"name":"id","type":{"Type":12,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"position":{"Index":0,"MixedIn":false,"MixinIndex":0}},{"name":"title","type":{"Type":7,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"position":{"Index":1,"MixedIn":false,"MixinIndex":0}},{"name":"description","type":{"Type":7,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"optional":true,"position":{"Index":2,"MixedIn":false,"MixinIndex":0}},{"name":"dsn","type":{"Type":7,"Ident":"","PkgPath":"","PkgName":"","Nillable":false,"RType":null},"optional":true,"position":{"Index":3,"MixedIn":false,"MixinIndex":0}}],"hooks":[{"Index":0,"MixedIn":true,"MixinIndex":0}],"interceptors":[{"Index":0,"MixedIn":true,"MixinIndex":0}]}],"Features":["sql/upsert","privacy","intercept","schema/snapshot"]}`

-------------------------------------------------------------------------------- /examples/ent/tenant/ent/migrate/migrate.go: --------------------------------------------------------------------------------
// Code generated by ent, DO NOT EDIT.

package migrate

import (
	"context"
	"fmt"
	"io"

	"entgo.io/ent/dialect"
	"entgo.io/ent/dialect/sql/schema"
)

var (
	// WithGlobalUniqueID sets the universal ids options to the migration.
	// If this option is enabled, ent migration will allocate a 1<<32 range
	// for the ids of each entity (table).
	// Note that this option cannot be applied on tables that already exist.
	WithGlobalUniqueID = schema.WithGlobalUniqueID
	// WithDropColumn sets the drop column option to the migration.
	// If this option is enabled, ent migration will drop old columns
	// that were used for both fields and edges. This defaults to false.
	WithDropColumn = schema.WithDropColumn
	// WithDropIndex sets the drop index option to the migration.
	// If this option is enabled, ent migration will drop old indexes
	// that were defined in the schema. This defaults to false.
	// Note that unique constraints are defined using `UNIQUE INDEX`,
	// and therefore, it's recommended to enable this option to get more
	// flexibility in the schema changes.
	WithDropIndex = schema.WithDropIndex
	// WithForeignKeys enables creating foreign-key in schema DDL. This defaults to true.
	WithForeignKeys = schema.WithForeignKeys
)

// Schema is the API for creating, migrating and dropping a schema.
type Schema struct {
	// drv is the driver that migrations are executed through.
	drv dialect.Driver
}

// NewSchema creates a new schema client.
func NewSchema(drv dialect.Driver) *Schema { return &Schema{drv: drv} }

// Create creates all schema resources, migrating the package-level Tables
// through the package-level Create function.
func (s *Schema) Create(ctx context.Context, opts ...schema.MigrateOption) error {
	return Create(ctx, s, Tables, opts...)
}

// Create creates all table resources using the given schema driver.
func Create(ctx context.Context, s *Schema, tables []*schema.Table, opts ...schema.MigrateOption) error {
	// NewMigrate failures are wrapped with an "ent/migrate:" prefix; errors
	// from the migration run itself are returned as-is. (The local variable
	// shadows the package name within this function.)
	migrate, err := schema.NewMigrate(s.drv, opts...)
	if err != nil {
		return fmt.Errorf("ent/migrate: %w", err)
	}
	return migrate.Create(ctx, tables...)
}

// WriteTo writes the schema changes to w instead of running them against the database.
//
//	if err := client.Schema.WriteTo(context.Background(), os.Stdout); err != nil {
//		log.Fatal(err)
//	}
func (s *Schema) WriteTo(ctx context.Context, w io.Writer, opts ...schema.MigrateOption) error {
	return Create(ctx, &Schema{drv: &schema.WriteDriver{Writer: w, Driver: s.drv}}, Tables, opts...)
}

-------------------------------------------------------------------------------- /examples/ent/tenant/ent/migrate/schema.go: --------------------------------------------------------------------------------
// Code generated by ent, DO NOT EDIT.

package migrate

import (
	"entgo.io/ent/dialect/sql/schema"
	"entgo.io/ent/schema/field"
)

var (
	// PostsColumns holds the columns for the "posts" table.
	PostsColumns = []*schema.Column{
		{Name: "id", Type: field.TypeInt, Increment: true},
		{Name: "tenant_id", Type: field.TypeString, Nullable: true},
		{Name: "title", Type: field.TypeString},
		{Name: "description", Type: field.TypeString, Nullable: true},
		{Name: "dsn", Type: field.TypeString, Nullable: true},
	}
	// PostsTable holds the schema information for the "posts" table.
	// Its primary key is the auto-increment "id" column (PostsColumns[0]).
	PostsTable = &schema.Table{
		Name:       "posts",
		Columns:    PostsColumns,
		PrimaryKey: []*schema.Column{PostsColumns[0]},
	}
	// Tables holds all the tables in the schema.
	Tables = []*schema.Table{
		PostsTable,
	}
)

func init() {
}

-------------------------------------------------------------------------------- /examples/ent/tenant/ent/post.go: --------------------------------------------------------------------------------
// Code generated by ent, DO NOT EDIT.

package ent

import (
	"fmt"
	"strings"

	"entgo.io/ent"
	"entgo.io/ent/dialect/sql"
	"github.com/go-saas/saas/examples/ent/tenant/ent/post"
)

// Post is the model entity for the Post schema.
type Post struct {
	config `json:"-"`
	// ID of the ent.
	ID int `json:"id,omitempty"`
	// TenantID holds the value of the "tenant_id" field.
	// It is nil when tenant_id was NULL or not among the scanned columns.
	TenantID *sql.NullString `json:"tenant_id,omitempty"`
	// Title holds the value of the "title" field.
	Title string `json:"title,omitempty"`
	// Description holds the value of the "description" field.
	Description string `json:"description,omitempty"`
	// Dsn holds the value of the "dsn" field.
	Dsn string `json:"dsn,omitempty"`
	// selectValues collects scanned columns that do not map to any field
	// above; they are read back through Value.
	selectValues sql.SelectValues
}

// scanValues returns the types for scanning values from sql.Rows.
func (*Post) scanValues(columns []string) ([]any, error) {
	values := make([]any, len(columns))
	for i := range columns {
		switch columns[i] {
		case post.FieldID:
			values[i] = new(sql.NullInt64)
		case post.FieldTenantID, post.FieldTitle, post.FieldDescription, post.FieldDsn:
			values[i] = new(sql.NullString)
		default:
			// Unknown columns are scanned opaquely and kept in selectValues.
			values[i] = new(sql.UnknownType)
		}
	}
	return values, nil
}

// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the Post fields.
48 | func (po *Post) assignValues(columns []string, values []any) error { 49 | if m, n := len(values), len(columns); m < n { 50 | return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) 51 | } 52 | for i := range columns { 53 | switch columns[i] { 54 | case post.FieldID: 55 | value, ok := values[i].(*sql.NullInt64) 56 | if !ok { 57 | return fmt.Errorf("unexpected type %T for field id", value) 58 | } 59 | po.ID = int(value.Int64) 60 | case post.FieldTenantID: 61 | if value, ok := values[i].(*sql.NullString); !ok { 62 | return fmt.Errorf("unexpected type %T for field tenant_id", values[i]) 63 | } else if value.Valid { 64 | po.TenantID = value 65 | } 66 | case post.FieldTitle: 67 | if value, ok := values[i].(*sql.NullString); !ok { 68 | return fmt.Errorf("unexpected type %T for field title", values[i]) 69 | } else if value.Valid { 70 | po.Title = value.String 71 | } 72 | case post.FieldDescription: 73 | if value, ok := values[i].(*sql.NullString); !ok { 74 | return fmt.Errorf("unexpected type %T for field description", values[i]) 75 | } else if value.Valid { 76 | po.Description = value.String 77 | } 78 | case post.FieldDsn: 79 | if value, ok := values[i].(*sql.NullString); !ok { 80 | return fmt.Errorf("unexpected type %T for field dsn", values[i]) 81 | } else if value.Valid { 82 | po.Dsn = value.String 83 | } 84 | default: 85 | po.selectValues.Set(columns[i], values[i]) 86 | } 87 | } 88 | return nil 89 | } 90 | 91 | // Value returns the ent.Value that was dynamically selected and assigned to the Post. 92 | // This includes values selected through modifiers, order, etc. 93 | func (po *Post) Value(name string) (ent.Value, error) { 94 | return po.selectValues.Get(name) 95 | } 96 | 97 | // Update returns a builder for updating this Post. 98 | // Note that you need to call Post.Unwrap() before calling this method if this Post 99 | // was returned from a transaction, and the transaction was committed or rolled back. 
100 | func (po *Post) Update() *PostUpdateOne { 101 | return NewPostClient(po.config).UpdateOne(po) 102 | } 103 | 104 | // Unwrap unwraps the Post entity that was returned from a transaction after it was closed, 105 | // so that all future queries will be executed through the driver which created the transaction; it panics if the Post is not bound to a transaction. 106 | func (po *Post) Unwrap() *Post { 107 | _tx, ok := po.config.driver.(*txDriver) 108 | if !ok { 109 | panic("ent: Post is not a transactional entity") 110 | } 111 | po.config.driver = _tx.drv 112 | return po 113 | } 114 | 115 | // String implements the fmt.Stringer. NOTE(review): it prints dsn verbatim; avoid logging Posts if dsn can hold credentials. 116 | func (po *Post) String() string { 117 | var builder strings.Builder 118 | builder.WriteString("Post(") 119 | builder.WriteString(fmt.Sprintf("id=%v, ", po.ID)) 120 | builder.WriteString("tenant_id=") 121 | builder.WriteString(fmt.Sprintf("%v", po.TenantID)) 122 | builder.WriteString(", ") 123 | builder.WriteString("title=") 124 | builder.WriteString(po.Title) 125 | builder.WriteString(", ") 126 | builder.WriteString("description=") 127 | builder.WriteString(po.Description) 128 | builder.WriteString(", ") 129 | builder.WriteString("dsn=") 130 | builder.WriteString(po.Dsn) 131 | builder.WriteByte(')') 132 | return builder.String() 133 | } 134 | 135 | // Posts is a parsable slice of Post. 136 | type Posts []*Post 137 | -------------------------------------------------------------------------------- /examples/ent/tenant/ent/post/post.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package post 4 | 5 | import ( 6 | "entgo.io/ent" 7 | "entgo.io/ent/dialect/sql" 8 | ) 9 | 10 | const ( 11 | // Label holds the string label denoting the post type in the database. 12 | Label = "post" 13 | // FieldID holds the string denoting the id field in the database. 14 | FieldID = "id" 15 | // FieldTenantID holds the string denoting the tenant_id field in the database.
16 | FieldTenantID = "tenant_id" 17 | // FieldTitle holds the string denoting the title field in the database. 18 | FieldTitle = "title" 19 | // FieldDescription holds the string denoting the description field in the database. 20 | FieldDescription = "description" 21 | // FieldDsn holds the string denoting the dsn field in the database. 22 | FieldDsn = "dsn" 23 | // Table holds the table name of the post in the database. 24 | Table = "posts" 25 | ) 26 | 27 | // Columns holds all SQL columns for post fields. 28 | var Columns = []string{ 29 | FieldID, 30 | FieldTenantID, 31 | FieldTitle, 32 | FieldDescription, 33 | FieldDsn, 34 | } 35 | 36 | // ValidColumn reports if the column name is valid (part of the table columns). It does a linear scan of Columns. 37 | func ValidColumn(column string) bool { 38 | for i := range Columns { 39 | if column == Columns[i] { 40 | return true 41 | } 42 | } 43 | return false 44 | } 45 | 46 | // Note that the variables below are initialized by the runtime 47 | // package on the initialization of the application. Therefore, 48 | // it should be imported in the main as follows: 49 | // 50 | // import _ "github.com/go-saas/saas/examples/ent/tenant/ent/runtime" 51 | var ( 52 | Hooks [1]ent.Hook // set by the runtime package from the HasTenant mixin 53 | Interceptors [1]ent.Interceptor // set by the runtime package from the HasTenant mixin 54 | ) 55 | 56 | // OrderOption defines the ordering options for the Post queries. 57 | type OrderOption func(*sql.Selector) 58 | 59 | // ByID orders the results by the id field. 60 | func ByID(opts ...sql.OrderTermOption) OrderOption { 61 | return sql.OrderByField(FieldID, opts...).ToFunc() 62 | } 63 | 64 | // ByTenantID orders the results by the tenant_id field. 65 | func ByTenantID(opts ...sql.OrderTermOption) OrderOption { 66 | return sql.OrderByField(FieldTenantID, opts...).ToFunc() 67 | } 68 | 69 | // ByTitle orders the results by the title field. 70 | func ByTitle(opts ...sql.OrderTermOption) OrderOption { 71 | return sql.OrderByField(FieldTitle, opts...).ToFunc() 72 | } 73 | 74 | // ByDescription orders the results by the description field.
75 | func ByDescription(opts ...sql.OrderTermOption) OrderOption { 76 | return sql.OrderByField(FieldDescription, opts...).ToFunc() 77 | } 78 | 79 | // ByDsn orders the results by the dsn field. 80 | func ByDsn(opts ...sql.OrderTermOption) OrderOption { 81 | return sql.OrderByField(FieldDsn, opts...).ToFunc() 82 | } 83 | -------------------------------------------------------------------------------- /examples/ent/tenant/ent/post_delete.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "context" 7 | 8 | "entgo.io/ent/dialect/sql" 9 | "entgo.io/ent/dialect/sql/sqlgraph" 10 | "entgo.io/ent/schema/field" 11 | "github.com/go-saas/saas/examples/ent/tenant/ent/post" 12 | "github.com/go-saas/saas/examples/ent/tenant/ent/predicate" 13 | ) 14 | 15 | // PostDelete is the builder for deleting a Post entity. 16 | type PostDelete struct { 17 | config 18 | hooks []Hook 19 | mutation *PostMutation 20 | } 21 | 22 | // Where appends a list of predicates to the PostDelete builder. 23 | func (pd *PostDelete) Where(ps ...predicate.Post) *PostDelete { 24 | pd.mutation.Where(ps...) 25 | return pd 26 | } 27 | 28 | // Exec executes the deletion query, running it through withHooks with the builder's hooks, and returns how many vertices were deleted. 29 | func (pd *PostDelete) Exec(ctx context.Context) (int, error) { 30 | return withHooks(ctx, pd.sqlExec, pd.mutation, pd.hooks) 31 | } 32 | 33 | // ExecX is like Exec, but panics if an error occurs.
34 | func (pd *PostDelete) ExecX(ctx context.Context) int { 35 | n, err := pd.Exec(ctx) 36 | if err != nil { 37 | panic(err) 38 | } 39 | return n 40 | } 41 | 42 | func (pd *PostDelete) sqlExec(ctx context.Context) (int, error) { 43 | _spec := sqlgraph.NewDeleteSpec(post.Table, sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt)) 44 | if ps := pd.mutation.predicates; len(ps) > 0 { // every predicate is applied to the same selector 45 | _spec.Predicate = func(selector *sql.Selector) { 46 | for i := range ps { 47 | ps[i](selector) 48 | } 49 | } 50 | } 51 | affected, err := sqlgraph.DeleteNodes(ctx, pd.driver, _spec) 52 | if err != nil && sqlgraph.IsConstraintError(err) { // surface constraint violations as *ConstraintError 53 | err = &ConstraintError{msg: err.Error(), wrap: err} 54 | } 55 | pd.mutation.done = true // set whether or not the delete failed 56 | return affected, err 57 | } 58 | 59 | // PostDeleteOne is the builder for deleting a single Post entity. 60 | type PostDeleteOne struct { 61 | pd *PostDelete 62 | } 63 | 64 | // Where appends a list of predicates to the underlying PostDelete builder. 65 | func (pdo *PostDeleteOne) Where(ps ...predicate.Post) *PostDeleteOne { 66 | pdo.pd.mutation.Where(ps...) 67 | return pdo 68 | } 69 | 70 | // Exec executes the deletion query and returns a *NotFoundError if no Post was deleted. 71 | func (pdo *PostDeleteOne) Exec(ctx context.Context) error { 72 | n, err := pdo.pd.Exec(ctx) 73 | switch { 74 | case err != nil: 75 | return err 76 | case n == 0: 77 | return &NotFoundError{post.Label} 78 | default: 79 | return nil 80 | } 81 | } 82 | 83 | // ExecX is like Exec, but panics if an error occurs. 84 | func (pdo *PostDeleteOne) ExecX(ctx context.Context) { 85 | if err := pdo.Exec(ctx); err != nil { 86 | panic(err) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /examples/ent/tenant/ent/predicate/predicate.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package predicate 4 | 5 | import ( 6 | "entgo.io/ent/dialect/sql" 7 | ) 8 | 9 | // Post is the predicate function for post builders.
10 | type Post func(*sql.Selector) 11 | -------------------------------------------------------------------------------- /examples/ent/tenant/ent/privacy/privacy.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package privacy 4 | 5 | import ( 6 | "context" 7 | 8 | "github.com/go-saas/saas/examples/ent/tenant/ent" 9 | 10 | "entgo.io/ent/privacy" 11 | ) 12 | 13 | var ( 14 | // Allow may be returned by rules to indicate that the policy 15 | // evaluation should terminate with an allow decision. 16 | Allow = privacy.Allow 17 | 18 | // Deny may be returned by rules to indicate that the policy 19 | // evaluation should terminate with a deny decision. 20 | Deny = privacy.Deny 21 | 22 | // Skip may be returned by rules to indicate that the policy 23 | // evaluation should continue to the next rule. 24 | Skip = privacy.Skip 25 | ) 26 | 27 | // Allowf returns a formatted wrapped Allow decision. 28 | func Allowf(format string, a ...any) error { 29 | return privacy.Allowf(format, a...) 30 | } 31 | 32 | // Denyf returns a formatted wrapped Deny decision. 33 | func Denyf(format string, a ...any) error { 34 | return privacy.Denyf(format, a...) 35 | } 36 | 37 | // Skipf returns a formatted wrapped Skip decision. 38 | func Skipf(format string, a ...any) error { 39 | return privacy.Skipf(format, a...) 40 | } 41 | 42 | // DecisionContext creates a new context from the given parent context with 43 | // a policy decision attached to it. 44 | func DecisionContext(parent context.Context, decision error) context.Context { 45 | return privacy.DecisionContext(parent, decision) 46 | } 47 | 48 | // DecisionFromContext retrieves the policy decision from the context. 49 | func DecisionFromContext(ctx context.Context) (error, bool) { 50 | return privacy.DecisionFromContext(ctx) 51 | } 52 | 53 | type ( 54 | // Policy groups query and mutation policies.
55 | Policy = privacy.Policy 56 | 57 | // QueryRule defines the interface deciding whether a 58 | // query is allowed and optionally modifies it. 59 | QueryRule = privacy.QueryRule 60 | // QueryPolicy combines multiple query rules into a single policy. 61 | QueryPolicy = privacy.QueryPolicy 62 | 63 | // MutationRule defines the interface which decides whether a 64 | // mutation is allowed and optionally modifies it. 65 | MutationRule = privacy.MutationRule 66 | // MutationPolicy combines multiple mutation rules into a single policy. 67 | MutationPolicy = privacy.MutationPolicy 68 | // MutationRuleFunc type is an adapter which allows the use of 69 | // ordinary functions as mutation rules. 70 | MutationRuleFunc = privacy.MutationRuleFunc 71 | 72 | // QueryMutationRule is an interface which groups query and mutation rules. 73 | QueryMutationRule = privacy.QueryMutationRule 74 | ) 75 | 76 | // QueryRuleFunc type is an adapter to allow the use of 77 | // ordinary functions as query rules. 78 | type QueryRuleFunc func(context.Context, ent.Query) error 79 | 80 | // EvalQuery returns f(ctx, q). 81 | func (f QueryRuleFunc) EvalQuery(ctx context.Context, q ent.Query) error { 82 | return f(ctx, q) 83 | } 84 | 85 | // AlwaysAllowRule returns a rule that returns an allow decision. 86 | func AlwaysAllowRule() QueryMutationRule { 87 | return privacy.AlwaysAllowRule() 88 | } 89 | 90 | // AlwaysDenyRule returns a rule that returns a deny decision. 91 | func AlwaysDenyRule() QueryMutationRule { 92 | return privacy.AlwaysDenyRule() 93 | } 94 | 95 | // ContextQueryMutationRule creates a query/mutation rule from a context eval func. 96 | func ContextQueryMutationRule(eval func(context.Context) error) QueryMutationRule { 97 | return privacy.ContextQueryMutationRule(eval) 98 | } 99 | 100 | // OnMutationOperation evaluates the given rule only on a given mutation operation.
101 | func OnMutationOperation(rule MutationRule, op ent.Op) MutationRule { 102 | return privacy.OnMutationOperation(rule, op) 103 | } 104 | 105 | // DenyMutationOperationRule returns a rule denying the specified mutation operation. 106 | func DenyMutationOperationRule(op ent.Op) MutationRule { 107 | rule := MutationRuleFunc(func(_ context.Context, m ent.Mutation) error { 108 | return Denyf("ent/privacy: operation %s is not allowed", m.Op()) 109 | }) 110 | return OnMutationOperation(rule, op) 111 | } 112 | 113 | // The PostQueryRuleFunc type is an adapter to allow the use of ordinary 114 | // functions as a query rule. 115 | type PostQueryRuleFunc func(context.Context, *ent.PostQuery) error 116 | 117 | // EvalQuery returns f(ctx, q), denying any query that is not an *ent.PostQuery. 118 | func (f PostQueryRuleFunc) EvalQuery(ctx context.Context, q ent.Query) error { 119 | if q, ok := q.(*ent.PostQuery); ok { 120 | return f(ctx, q) 121 | } 122 | return Denyf("ent/privacy: unexpected query type %T, expect *ent.PostQuery", q) 123 | } 124 | 125 | // The PostMutationRuleFunc type is an adapter to allow the use of ordinary 126 | // functions as a mutation rule. 127 | type PostMutationRuleFunc func(context.Context, *ent.PostMutation) error 128 | 129 | // EvalMutation calls f(ctx, m), denying any mutation that is not an *ent.PostMutation. 130 | func (f PostMutationRuleFunc) EvalMutation(ctx context.Context, m ent.Mutation) error { 131 | if m, ok := m.(*ent.PostMutation); ok { 132 | return f(ctx, m) 133 | } 134 | return Denyf("ent/privacy: unexpected mutation type %T, expect *ent.PostMutation", m) 135 | } 136 | -------------------------------------------------------------------------------- /examples/ent/tenant/ent/runtime.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT.
2 | 3 | package ent 4 | 5 | // The schema-stitching logic is generated in github.com/go-saas/saas/examples/ent/tenant/ent/runtime/runtime.go 6 | -------------------------------------------------------------------------------- /examples/ent/tenant/ent/runtime/runtime.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package runtime 4 | 5 | import ( 6 | "github.com/go-saas/saas/examples/ent/tenant/ent/post" 7 | "github.com/go-saas/saas/examples/ent/tenant/ent/schema" 8 | ) 9 | 10 | // The init function reads all schema descriptors with runtime code 11 | // (default values, validators, hooks and policies) and stitches them 12 | // to their package variables. 13 | func init() { 14 | postMixin := schema.Post{}.Mixin() // postMixin[0] is the sent.HasTenant mixin 15 | postMixinHooks0 := postMixin[0].Hooks() 16 | post.Hooks[0] = postMixinHooks0[0] 17 | postMixinInters0 := postMixin[0].Interceptors() 18 | post.Interceptors[0] = postMixinInters0[0] 19 | } 20 | 21 | const ( 22 | Version = "v0.12.4" // Version of ent codegen. 23 | Sum = "h1:LddPnAyxls/O7DTXZvUGDj0NZIdGSu317+aoNLJWbD8=" // Sum of ent codegen. 24 | ) 25 | -------------------------------------------------------------------------------- /examples/ent/tenant/ent/schema/post.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import ( 4 | "entgo.io/ent" 5 | "entgo.io/ent/schema/field" 6 | sent "github.com/go-saas/saas/ent" 7 | ) 8 | 9 | // Post holds the schema definition for the Post entity. 10 | type Post struct { 11 | ent.Schema 12 | } 13 | 14 | // Fields of the Post. tenant_id is not declared here; the HasTenant mixin provides it. 15 | func (Post) Fields() []ent.Field { 16 | return []ent.Field{ 17 | field.Int("id"), 18 | field.String("title"), 19 | field.String("description").Optional(), 20 | field.String("dsn").Optional(), 21 | } 22 | } 23 | 24 | // Edges of the Post. Post has no edges.
25 | func (Post) Edges() []ent.Edge { 26 | return nil 27 | } 28 | 29 | func (Post) Mixin() []ent.Mixin { 30 | return []ent.Mixin{ 31 | sent.HasTenant{}, // supplies tenant_id and the hook/interceptor that runtime/runtime.go installs 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /examples/ent/tenant/ent/tx.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package ent 4 | 5 | import ( 6 | "context" 7 | "sync" 8 | 9 | "entgo.io/ent/dialect" 10 | ) 11 | 12 | // Tx is a transactional client that is created by calling Client.Tx(). 13 | type Tx struct { 14 | config 15 | // Post is the client for interacting with the Post builders. 16 | Post *PostClient 17 | 18 | // lazily loaded by Client(). 19 | client *Client 20 | clientOnce sync.Once 21 | // ctx lives for the life of the transaction. It is 22 | // the same context used by the underlying connection. 23 | ctx context.Context 24 | } 25 | 26 | type ( 27 | // Committer is the interface that wraps the Commit method. 28 | Committer interface { 29 | Commit(context.Context, *Tx) error 30 | } 31 | 32 | // The CommitFunc type is an adapter to allow the use of ordinary 33 | // functions as a Committer. If f is a function with the appropriate 34 | // signature, CommitFunc(f) is a Committer that calls f. 35 | CommitFunc func(context.Context, *Tx) error 36 | 37 | // CommitHook defines the "commit middleware". A function that gets a Committer 38 | // and returns a Committer. For example: 39 | // 40 | // hook := func(next ent.Committer) ent.Committer { 41 | // return ent.CommitFunc(func(ctx context.Context, tx *ent.Tx) error { 42 | // // Do some stuff before. 43 | // if err := next.Commit(ctx, tx); err != nil { 44 | // return err 45 | // } 46 | // // Do some stuff after. 47 | // return nil 48 | // }) 49 | // } 50 | // 51 | CommitHook func(Committer) Committer 52 | ) 53 | 54 | // Commit calls f(ctx, tx).
55 | func (f CommitFunc) Commit(ctx context.Context, tx *Tx) error { 56 | return f(ctx, tx) 57 | } 58 | 59 | // Commit commits the transaction through the registered OnCommit hooks; the first registered hook is the outermost wrapper. 60 | func (tx *Tx) Commit() error { 61 | txDriver := tx.config.driver.(*txDriver) 62 | var fn Committer = CommitFunc(func(context.Context, *Tx) error { 63 | return txDriver.tx.Commit() 64 | }) 65 | txDriver.mu.Lock() 66 | hooks := append([]CommitHook(nil), txDriver.onCommit...) // copy under the lock so the hooks run without holding mu 67 | txDriver.mu.Unlock() 68 | for i := len(hooks) - 1; i >= 0; i-- { 69 | fn = hooks[i](fn) 70 | } 71 | return fn.Commit(tx.ctx, tx) 72 | } 73 | 74 | // OnCommit adds a hook to call on commit. 75 | func (tx *Tx) OnCommit(f CommitHook) { 76 | txDriver := tx.config.driver.(*txDriver) 77 | txDriver.mu.Lock() 78 | txDriver.onCommit = append(txDriver.onCommit, f) 79 | txDriver.mu.Unlock() 80 | } 81 | 82 | type ( 83 | // Rollbacker is the interface that wraps the Rollback method. 84 | Rollbacker interface { 85 | Rollback(context.Context, *Tx) error 86 | } 87 | 88 | // The RollbackFunc type is an adapter to allow the use of ordinary 89 | // functions as a Rollbacker. If f is a function with the appropriate 90 | // signature, RollbackFunc(f) is a Rollbacker that calls f. 91 | RollbackFunc func(context.Context, *Tx) error 92 | 93 | // RollbackHook defines the "rollback middleware". A function that gets a Rollbacker 94 | // and returns a Rollbacker. For example: 95 | // 96 | // hook := func(next ent.Rollbacker) ent.Rollbacker { 97 | // return ent.RollbackFunc(func(ctx context.Context, tx *ent.Tx) error { 98 | // // Do some stuff before. 99 | // if err := next.Rollback(ctx, tx); err != nil { 100 | // return err 101 | // } 102 | // // Do some stuff after. 103 | // return nil 104 | // }) 105 | // } 106 | // 107 | RollbackHook func(Rollbacker) Rollbacker 108 | ) 109 | 110 | // Rollback calls f(ctx, tx). 111 | func (f RollbackFunc) Rollback(ctx context.Context, tx *Tx) error { 112 | return f(ctx, tx) 113 | } 114 | 115 | // Rollback rolls back the transaction.
116 | func (tx *Tx) Rollback() error { 117 | txDriver := tx.config.driver.(*txDriver) 118 | var fn Rollbacker = RollbackFunc(func(context.Context, *Tx) error { 119 | return txDriver.tx.Rollback() 120 | }) 121 | txDriver.mu.Lock() 122 | hooks := append([]RollbackHook(nil), txDriver.onRollback...) // copy under the lock so the hooks run without holding mu 123 | txDriver.mu.Unlock() 124 | for i := len(hooks) - 1; i >= 0; i-- { 125 | fn = hooks[i](fn) 126 | } 127 | return fn.Rollback(tx.ctx, tx) 128 | } 129 | 130 | // OnRollback adds a hook to call on rollback. 131 | func (tx *Tx) OnRollback(f RollbackHook) { 132 | txDriver := tx.config.driver.(*txDriver) 133 | txDriver.mu.Lock() 134 | txDriver.onRollback = append(txDriver.onRollback, f) 135 | txDriver.mu.Unlock() 136 | } 137 | 138 | // Client returns a Client that binds to the current transaction; it is built once and cached. 139 | func (tx *Tx) Client() *Client { 140 | tx.clientOnce.Do(func() { 141 | tx.client = &Client{config: tx.config} 142 | tx.client.init() 143 | }) 144 | return tx.client 145 | } 146 | 147 | func (tx *Tx) init() { // binds the per-entity clients to this transaction's config 148 | tx.Post = NewPostClient(tx.config) 149 | } 150 | 151 | // txDriver wraps the given dialect.Tx with a nop dialect.Driver implementation. 152 | // The idea is to support transactions without adding any extra code to the builders. 153 | // When a builder calls driver.Tx(), it gets the same dialect.Tx instance. 154 | // Commit and Rollback are nop for the internal builders and the user must call one 155 | // of them in order to commit or roll back the transaction. 156 | // 157 | // If a closed transaction is embedded in one of the generated entities, and the entity 158 | // applies a query, for example: Post.QueryXXX(), the query will be executed 159 | // through the driver which created this transaction. 160 | // 161 | // Note that txDriver is not goroutine safe; only hook registration is guarded by mu. 162 | type txDriver struct { 163 | // the driver we started the transaction from. 164 | drv dialect.Driver 165 | // tx is the underlying transaction. 166 | tx dialect.Tx 167 | // completion hooks.
168 | mu sync.Mutex // guards onCommit and onRollback 169 | onCommit []CommitHook 170 | onRollback []RollbackHook 171 | } 172 | 173 | // newTx creates a new transactional driver. 174 | func newTx(ctx context.Context, drv dialect.Driver) (*txDriver, error) { 175 | tx, err := drv.Tx(ctx) 176 | if err != nil { 177 | return nil, err 178 | } 179 | return &txDriver{tx: tx, drv: drv}, nil 180 | } 181 | 182 | // Tx returns the transaction wrapper (txDriver) to avoid Commit or Rollback calls 183 | // from the internal builders. Should be called only by the internal builders. 184 | func (tx *txDriver) Tx(context.Context) (dialect.Tx, error) { return tx, nil } 185 | 186 | // Dialect returns the dialect of the driver we started the transaction from. 187 | func (tx *txDriver) Dialect() string { return tx.drv.Dialect() } 188 | 189 | // Close is a nop close. 190 | func (*txDriver) Close() error { return nil } 191 | 192 | // Commit is a nop commit for the internal builders. 193 | // User must call `Tx.Commit` in order to commit the transaction. 194 | func (*txDriver) Commit() error { return nil } 195 | 196 | // Rollback is a nop rollback for the internal builders. 197 | // User must call `Tx.Rollback` in order to roll back the transaction. 198 | func (*txDriver) Rollback() error { return nil } 199 | 200 | // Exec calls tx.Exec. 201 | func (tx *txDriver) Exec(ctx context.Context, query string, args, v any) error { 202 | return tx.tx.Exec(ctx, query, args, v) 203 | } 204 | 205 | // Query calls tx.Query.
206 | func (tx *txDriver) Query(ctx context.Context, query string, args, v any) error { 207 | return tx.tx.Query(ctx, query, args, v) 208 | } 209 | 210 | var _ dialect.Driver = (*txDriver)(nil) 211 | -------------------------------------------------------------------------------- /examples/ent/tenant_store.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "strconv" 6 | 7 | "github.com/go-saas/saas" 8 | "github.com/go-saas/saas/examples/ent/shared/ent" 9 | "github.com/go-saas/saas/examples/ent/shared/ent/tenant" 10 | ) 11 | 12 | // TenantStore loads tenant configurations from the shared (host) database. 13 | type TenantStore struct { 14 | shared SharedDbProvider 15 | } 16 | 17 | // GetByNameOrId returns the config of the tenant whose name equals nameOrId or, 18 | // when nameOrId is an integer, whose ID equals it. The config includes the tenant's 19 | // connection strings. It returns saas.ErrTenantNotFound if no tenant matches. 20 | func (t *TenantStore) GetByNameOrId(ctx context.Context, nameOrId string) (*saas.TenantConfig, error) { 21 | // Tenant records live in the host database, so query as the host (empty tenant). 22 | ctx = saas.NewCurrentTenant(ctx, "", "") 23 | db := t.shared.Get(ctx, "") 24 | 25 | var ( 26 | te *ent.Tenant 27 | err error 28 | ) 29 | if id, convErr := strconv.Atoi(nameOrId); convErr == nil { 30 | te, err = db.Tenant.Query().Where(tenant.Or(tenant.ID(id), tenant.Name(nameOrId))).First(ctx) 31 | } else { 32 | te, err = db.Tenant.Query().Where(tenant.Name(nameOrId)).First(ctx) 33 | } 34 | if err != nil { 35 | if ent.IsNotFound(err) { 36 | return nil, saas.ErrTenantNotFound 37 | } 38 | return nil, err 39 | } 40 | 41 | ret := saas.NewTenantConfig(strconv.Itoa(te.ID), te.Name, te.Region, "") 42 | conns, err := te.QueryConn().All(ctx) 43 | if err != nil { 44 | return nil, err 45 | } 46 | for _, conn := range conns { 47 | ret.Conn[conn.Key] = conn.Value 48 | } 49 | return ret, nil 50 | } 51 | -------------------------------------------------------------------------------- /examples/gorm/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.7' 2 | services: 3 | mysqld: 4 | image: mysql:8.0 5 | restart: always 6 | environment: 7 | - MYSQL_ROOT_PASSWORD=youShouldChangeThis 8 | - MYSQL_ROOT_HOST=% 9 | volumes: 10 | - mysql_data:/var/lib/mysql 11 |
ports: 12 | - "3406:3306" 13 | healthcheck: 14 | test: ["CMD", "mysqladmin" ,"ping", "-h", "localhost"] 15 | timeout: 20s 16 | retries: 10 17 | 18 | postgres: 19 | image: postgres:alpine3.19 20 | restart: always 21 | environment: 22 | - POSTGRES_DB=pgsql-saas 23 | - POSTGRES_USER=pgsql-saas 24 | - POSTGRES_PASSWORD=pgsql-saas 25 | volumes: 26 | - postgres_data:/var/lib/postgresql/data 27 | ports: 28 | - "5435:5432" 29 | 30 | volumes: 31 | mysql_data: 32 | postgres_data: -------------------------------------------------------------------------------- /examples/gorm/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/go-saas/saas/examples/gorm 2 | 3 | go 1.20 4 | 5 | require ( 6 | github.com/gin-gonic/gin v1.9.1 7 | github.com/go-saas/saas v0.0.10 8 | github.com/go-sql-driver/mysql v1.6.0 9 | github.com/google/uuid v1.3.1 10 | gorm.io/driver/mysql v1.3.3 11 | gorm.io/driver/sqlite v1.4.3 12 | gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde 13 | ) 14 | 15 | require ( 16 | github.com/bytedance/sonic v1.9.1 // indirect 17 | github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect 18 | github.com/gabriel-vasile/mimetype v1.4.2 // indirect 19 | github.com/jackc/pgpassfile v1.0.0 // indirect 20 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect 21 | github.com/jackc/pgx/v5 v5.4.3 // indirect 22 | github.com/klauspost/cpuid/v2 v2.2.4 // indirect 23 | github.com/twitchyliquid64/golang-asm v0.15.1 // indirect 24 | golang.org/x/arch v0.3.0 // indirect 25 | gopkg.in/yaml.v3 v3.0.1 // indirect 26 | ) 27 | 28 | require ( 29 | github.com/gin-contrib/sse v0.1.0 // indirect 30 | github.com/go-playground/locales v0.14.1 // indirect 31 | github.com/go-playground/universal-translator v0.18.1 // indirect 32 | github.com/go-playground/validator/v10 v10.14.0 // indirect 33 | github.com/goccy/go-json v0.10.2 // indirect 34 | github.com/jinzhu/inflection v1.0.0 // indirect 35 |
github.com/jinzhu/now v1.1.5 // indirect 36 | github.com/json-iterator/go v1.1.12 // indirect 37 | github.com/leodido/go-urn v1.2.4 // indirect 38 | github.com/mattn/go-isatty v0.0.19 // indirect 39 | github.com/mattn/go-sqlite3 v1.14.16 // indirect 40 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 41 | github.com/modern-go/reflect2 v1.0.2 // indirect 42 | github.com/pelletier/go-toml/v2 v2.0.8 // indirect 43 | github.com/ugorji/go/codec v1.2.11 // indirect 44 | golang.org/x/crypto v0.14.0 // indirect 45 | golang.org/x/net v0.15.0 // indirect 46 | golang.org/x/sys v0.13.0 // indirect 47 | golang.org/x/text v0.13.0 // indirect 48 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect 49 | google.golang.org/protobuf v1.31.0 // indirect 50 | gopkg.in/yaml.v2 v2.4.0 // indirect 51 | gorm.io/driver/postgres v1.5.7 52 | ) 53 | 54 | replace github.com/go-saas/saas => ../../ 55 | -------------------------------------------------------------------------------- /examples/gorm/migration.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/go-saas/saas/gorm" 7 | "github.com/go-saas/saas/seed" 8 | ) 9 | 10 | // MigrationSeeder auto-migrates the example's gorm models during seeding. 11 | type MigrationSeeder struct { 12 | dbProvider gorm.DbProvider 13 | } 14 | 15 | // NewMigrationSeeder returns a MigrationSeeder that obtains databases from dbProvider. 16 | func NewMigrationSeeder(dbProvider gorm.DbProvider) *MigrationSeeder { 17 | return &MigrationSeeder{dbProvider: dbProvider} 18 | } 19 | 20 | // Seed migrates the Post table in the database that dbProvider resolves from ctx. For the 21 | // host (empty sCtx.TenantId) it first migrates the Tenant and TenantConn tables, which are host-only. 22 | func (m *MigrationSeeder) Seed(ctx context.Context, sCtx *seed.Context) error { 23 | db := m.dbProvider.Get(ctx, "") 24 | if sCtx.TenantId == "" { 25 | if err := db.AutoMigrate(&Tenant{}, &TenantConn{}); err != nil { 26 | return err 27 | } 28 | } 29 | return db.AutoMigrate(&Post{}) 30 | } 31 | -------------------------------------------------------------------------------- /examples/gorm/post.go:
-------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | gorm2 "github.com/go-saas/saas/gorm" 5 | "gorm.io/gorm" 6 | ) 7 | 8 | // Post is the example model served at /posts; it embeds gorm2.MultiTenancy from go-saas. 9 | type Post struct { 10 | gorm.Model 11 | Title string `json:"title"` 12 | Description string `json:"description"` 13 | gorm2.MultiTenancy 14 | } 15 | -------------------------------------------------------------------------------- /examples/gorm/postgres-utils.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "strings" 7 | ) 8 | 9 | // The helpers below handle key=value style Postgres DSNs such as 10 | // "host=localhost user=app dbname=app port=5432". Pairs are separated by runs 11 | // of whitespace and split on the first '=' only, so values containing '=' stay 12 | // intact. Quoted values with embedded spaces (password='a b') are not supported. 13 | 14 | // ParseDBNameFromPostgresDSN returns the value of the dbname key in dsn. It fails 15 | // if a pair has no '=' or if dbname is absent; the DSN is not echoed in the error 16 | // because it may contain a password. 17 | func ParseDBNameFromPostgresDSN(dsn string) (string, error) { 18 | for _, pair := range strings.Fields(dsn) { 19 | key, value, ok := strings.Cut(pair, "=") 20 | if !ok { 21 | return "", fmt.Errorf("invalid key-value pair: %s", pair) 22 | } 23 | if key == "dbname" { 24 | return value, nil 25 | } 26 | } 27 | return "", errors.New("dbname key not found in DSN") 28 | } 29 | 30 | // RemoveDBNameFromPostgresDSN returns dsn without its dbname pair(s), with the remaining 31 | // pairs joined by single spaces. It fails if a pair has no '='. 32 | func RemoveDBNameFromPostgresDSN(dsn string) (string, error) { 33 | pairs := strings.Fields(dsn) 34 | kept := make([]string, 0, len(pairs)) 35 | for _, pair := range pairs { 36 | key, _, ok := strings.Cut(pair, "=") 37 | if !ok { 38 | return "", fmt.Errorf("invalid key-value pair: %s", pair) 39 | } 40 | if key != "dbname" { 41 | kept = append(kept, pair) 42 | } 43 | } 44 | return strings.Join(kept, " "), nil 45 | } 46 | 47 | // AddSuffixToDBName returns dsn with "-"+suffix appended to the dbname value and the 48 | // dbname pair moved to the end. Pairs without '=' are dropped; if there is no dbname, 49 | // the remaining pairs are returned as is. When dbname repeats, the last one wins. 50 | func AddSuffixToDBName(dsn string, suffix string) string { 51 | var ( 52 | parts []string 53 | dbName string 54 | hasDBName bool 55 | ) 56 | for _, pair := range strings.Fields(dsn) { 57 | key, value, ok := strings.Cut(pair, "=") 58 | switch { 59 | case !ok: 60 | // Malformed pair (no '='): dropped. 61 | case key == "dbname": 62 | dbName, hasDBName = value, true 63 | default: 64 | parts = append(parts, pair) 65 | } 66 | } 67 | if hasDBName { 68 | parts = append(parts, fmt.Sprintf("dbname=%s-%s", dbName, suffix)) 69 | } 70 | // Joining instead of appending " dbname=..." avoids a leading space when dbname 71 | // is the only pair. 72 | return strings.Join(parts, " ") 73 | } 74 |
-------------------------------------------------------------------------------- /examples/gorm/readme.md: -------------------------------------------------------------------------------- 1 | # Example project 2 | 3 | combination of `go-saas`, `gin` and `gorm` (sqlite/mysql/postgres) 4 | 5 | ### sqlite3 6 | ```shell 7 | go run github.com/go-saas/saas/examples/gorm 8 | ``` 9 | --- 10 | ### mysql 11 | ```shell 12 | docker-compose up -d 13 | go run github.com/go-saas/saas/examples/gorm --driver mysql 14 | ``` 15 | --- 16 | ### postgres 17 | ```shell 18 | docker-compose up -d 19 | go run github.com/go-saas/saas/examples/gorm --driver pgx 20 | ``` 21 | 22 | Host side ( use shared database): 23 | 24 | Open
http://localhost:8090/posts 25 | 26 | --- 27 | Multi-tenancy ( use shared database): 28 | 29 | Open http://localhost:8090/posts?__tenant=1 30 | 31 | Open http://localhost:8090/posts?__tenant=2 32 | 33 | --- 34 | Single-tenancy ( use separate database): 35 | 36 | Open http://localhost:8090/posts?__tenant=3 37 | 38 | --- 39 | 40 | Create tenant 41 | ```shell 42 | curl -H "Accept: application/json" -H "Content-type: application/json" -X POST -d '{"name":"newTenant","separateDb":true}' http://localhost:8090/tenant 43 | ``` 44 | Open http://localhost:8090/posts?__tenant=newTenant 45 | -------------------------------------------------------------------------------- /examples/gorm/seed.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/go-saas/saas" 7 | 8 | "github.com/go-saas/saas/data" 9 | "github.com/go-saas/saas/gorm" 10 | "github.com/go-saas/saas/seed" 11 | gorm2 "gorm.io/gorm" 12 | "gorm.io/gorm/clause" 13 | ) 14 | 15 | type Seed struct { 16 | dbProvider gorm.DbProvider 17 | connStrGen saas.ConnStrGenerator 18 | } 19 | 20 | func NewSeed(dbProvider gorm.DbProvider, connStrGen saas.ConnStrGenerator) *Seed { 21 | return &Seed{dbProvider: dbProvider, connStrGen: connStrGen} 22 | } 23 | 24 | func (s *Seed) Seed(ctx context.Context, sCtx *seed.Context) error { 25 | db := s.dbProvider.Get(ctx, "") 26 | 27 | if sCtx.TenantId == "" { 28 | //seed host 29 | t3 := Tenant{ID: "3", Name: "Test3"} 30 | t3Conn, _ := s.connStrGen.Gen(ctx, saas.NewBasicTenantInfo(t3.ID, t3.Name)) 31 | 32 | t3.Conn = []TenantConn{ 33 | {Key: data.Default, Value: t3Conn}, // use tenant3.db 34 | } 35 | err := db.Model(&Tenant{}).Session(&gorm2.Session{FullSaveAssociations: true}).Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches([]Tenant{ 36 | {ID: "1", Name: "Test1"}, // use default shared.db 37 | {ID: "2", Name: "Test2"}, 38 | t3}, 10).Error 39 | if err != nil { 40 | return err 41 | } 
42 | entities := []Post{ 43 | { 44 | Model: gorm2.Model{ID: 1}, 45 | Title: fmt.Sprintf("Host Side"), 46 | Description: fmt.Sprintf("Hello Host"), 47 | }, 48 | } 49 | if err := createPosts(db, entities); err != nil { 50 | return err 51 | } 52 | } 53 | 54 | if sCtx.TenantId == "1" { 55 | entities := []Post{ 56 | { 57 | Model: gorm2.Model{ID: 2}, 58 | Title: fmt.Sprintf("Tenant %s Post 1", sCtx.TenantId), 59 | Description: fmt.Sprintf("Hello from tenant %s. There are one post in this tenant. This is post 1", sCtx.TenantId), 60 | }, 61 | } 62 | if err := createPosts(db, entities); err != nil { 63 | return err 64 | } 65 | } 66 | 67 | if sCtx.TenantId == "2" { 68 | entities := []Post{ 69 | { 70 | Model: gorm2.Model{ID: 3}, 71 | Title: fmt.Sprintf("Tenant %s Post 1", sCtx.TenantId), 72 | Description: fmt.Sprintf("Hello from tenant %s. There are two posts in this tenant. This is post 1", sCtx.TenantId), 73 | }, 74 | { 75 | Model: gorm2.Model{ID: 4}, 76 | Title: fmt.Sprintf("Tenant %s Post 2", sCtx.TenantId), 77 | Description: fmt.Sprintf("Hello from tenant %s. There are two posts in this tenant. This is post 2", sCtx.TenantId), 78 | }, 79 | } 80 | if err := createPosts(db, entities); err != nil { 81 | return err 82 | } 83 | } 84 | 85 | if sCtx.TenantId == "3" { 86 | entities := []Post{ 87 | { 88 | Model: gorm2.Model{ID: 5}, 89 | Title: fmt.Sprintf("Tenant %s Post 1", sCtx.TenantId), 90 | Description: fmt.Sprintf("Hello from tenant %s. There are there posts in this tenant. This is post 1", sCtx.TenantId), 91 | }, 92 | { 93 | Model: gorm2.Model{ID: 6}, 94 | Title: fmt.Sprintf("Tenant %s Post 2", sCtx.TenantId), 95 | Description: fmt.Sprintf("Hello from tenant %s. There are there posts in this tenant. This is post 2", sCtx.TenantId), 96 | }, 97 | { 98 | Model: gorm2.Model{ID: 7}, 99 | Title: fmt.Sprintf("Tenant %s Post 2", sCtx.TenantId), 100 | Description: fmt.Sprintf("Hello from tenant %s. There are there posts in this tenant. 
This is post 3", sCtx.TenantId), 101 | }, 102 | } 103 | if err := createPosts(db, entities); err != nil { 104 | return err 105 | } 106 | } 107 | return nil 108 | } 109 | 110 | func createPosts(db *gorm2.DB, entities []Post) error { 111 | for _, entity := range entities { 112 | err := db.Clauses(clause.OnConflict{ 113 | UpdateAll: true, 114 | }).Model(&Post{}).Create(&entity).Error 115 | if err != nil { 116 | return err 117 | } 118 | } 119 | return nil 120 | } 121 | -------------------------------------------------------------------------------- /examples/gorm/tenant.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "github.com/go-saas/saas" 7 | 8 | sgorm "github.com/go-saas/saas/gorm" 9 | "gorm.io/gorm" 10 | "time" 11 | ) 12 | 13 | type Tenant struct { 14 | ID string `gorm:"type:varchar(36)" json:"id"` 15 | //unique name. usually for domain name 16 | Name string `gorm:"column:name;index;size:255;"` 17 | //localed display name 18 | DisplayName string `gorm:"column:display_name;index;size:255;"` 19 | //region of this tenant 20 | Region string `gorm:"column:region;index;size:255;"` 21 | Logo string 22 | CreatedAt time.Time `gorm:"column:created_at;index;"` 23 | UpdatedAt time.Time `gorm:"column:updated_at;index;"` 24 | DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;index;"` 25 | 26 | //connection 27 | Conn []TenantConn `gorm:"foreignKey:TenantId"` 28 | } 29 | 30 | // TenantConn connection string info 31 | type TenantConn struct { 32 | TenantId string `gorm:"column:tenant_id;primary_key;size:36;"` 33 | //key of connection string 34 | Key string `gorm:"column:key;primary_key;size:100;"` 35 | //connection string 36 | Value string `gorm:"column:value;size:1000;"` 37 | CreatedAt time.Time `gorm:"column:created_at;index;"` 38 | UpdatedAt time.Time `gorm:"column:updated_at;index;"` 39 | } 40 | 41 | type TenantStore struct { 42 | dbProvider sgorm.DbProvider 43 | } 44 | 45 | 
func (t *TenantStore) GetByNameOrId(ctx context.Context, nameOrId string) (*saas.TenantConfig, error) { 46 | //change to host side 47 | ctx = saas.NewCurrentTenant(ctx, "", "") 48 | db := t.dbProvider.Get(ctx, "") 49 | var tenant Tenant 50 | err := db.Model(&Tenant{}).Preload("Conn").Where("id = ? OR name = ?", nameOrId, nameOrId).First(&tenant).Error 51 | if err != nil { 52 | if errors.Is(err, gorm.ErrRecordNotFound) { 53 | return nil, saas.ErrTenantNotFound 54 | } else { 55 | return nil, err 56 | } 57 | } 58 | ret := saas.NewTenantConfig(tenant.ID, tenant.Name, tenant.Region, "") 59 | for _, conn := range tenant.Conn { 60 | ret.Conn[conn.Key] = conn.Value 61 | } 62 | return ret, nil 63 | } 64 | 65 | var _ saas.TenantStore = (*TenantStore)(nil) 66 | -------------------------------------------------------------------------------- /gateway/apisix/readme.md: -------------------------------------------------------------------------------- 1 | # Apisix 2 | 3 | Ref: https://github.com/apache/apisix-go-plugin-runner 4 | 5 | In your `go-runner` 6 | ```go 7 | import _ "github.com/go-saas/saas/gateway/apisix" 8 | 9 | apisix.InitTenantStore(ts) 10 | ``` 11 | -------------------------------------------------------------------------------- /gateway/apisix/resolver.go: -------------------------------------------------------------------------------- 1 | package apisix 2 | 3 | import ( 4 | "context" 5 | pkgHTTP "github.com/apache/apisix-go-plugin-runner/pkg/http" 6 | "github.com/go-saas/saas" 7 | 8 | "regexp" 9 | ) 10 | 11 | type Resolver struct { 12 | r pkgHTTP.Request 13 | key string 14 | pathRegex string 15 | } 16 | 17 | func NewResolver(r pkgHTTP.Request, key, pathRegex string) *Resolver { 18 | return &Resolver{ 19 | r: r, 20 | key: key, 21 | pathRegex: pathRegex, 22 | } 23 | } 24 | 25 | var _ saas.TenantResolver = (*Resolver)(nil) 26 | 27 | func (r *Resolver) Resolve(ctx context.Context) (saas.TenantResolveResult, context.Context, error) { 28 | // default host side 29 | var 
t = "" 30 | if v := r.r.Header().Get(r.key); len(v) > 0 { 31 | t = v 32 | } 33 | if v := r.r.Args().Get(r.key); len(v) > 0 { 34 | t = v 35 | } 36 | if len(r.pathRegex) > 0 { 37 | reg := regexp.MustCompile(r.pathRegex) 38 | f := reg.FindAllStringSubmatch(string(r.r.Path()), -1) 39 | if f != nil { 40 | t = f[0][1] 41 | } 42 | } 43 | 44 | return saas.TenantResolveResult{TenantIdOrName: t}, ctx, nil 45 | } 46 | -------------------------------------------------------------------------------- /gateway/apisix/saas.go: -------------------------------------------------------------------------------- 1 | package apisix 2 | 3 | import ( 4 | "encoding/base64" 5 | "encoding/json" 6 | "github.com/go-kratos/kratos/v2/errors" 7 | "github.com/go-saas/saas" 8 | 9 | shttp "github.com/go-saas/saas/http" 10 | "net/http" 11 | 12 | pkgHTTP "github.com/apache/apisix-go-plugin-runner/pkg/http" 13 | "github.com/apache/apisix-go-plugin-runner/pkg/log" 14 | "github.com/apache/apisix-go-plugin-runner/pkg/plugin" 15 | ) 16 | 17 | func init() { 18 | err := plugin.RegisterPlugin(&Saas{}) 19 | if err != nil { 20 | log.Fatalf("failed to register plugin go-saas: %s", err) 21 | } 22 | } 23 | 24 | //Saas resolve and validate tenant information 25 | type Saas struct { 26 | plugin.DefaultPlugin 27 | } 28 | 29 | type SaasConf struct { 30 | TenantKey string `json:"tenant_key"` 31 | NextHeader string `json:"next_header"` 32 | NextInfoHeader string `json:"next_info_header"` 33 | PathRegex string `json:"path_regex"` 34 | } 35 | 36 | type FormatError func(err error, w http.ResponseWriter) 37 | 38 | //global variable to store tenants 39 | var ( 40 | tenantStore saas.TenantStore 41 | nextTenantHeader string 42 | nextTenantInfoHeader string 43 | ) 44 | 45 | var errFormat FormatError = func(err error, w http.ResponseWriter) { 46 | if errors.Is(err, saas.ErrTenantNotFound) { 47 | w.WriteHeader(404) 48 | } 49 | w.WriteHeader(500) 50 | } 51 | 52 | func Init(t saas.TenantStore, nextHeader, nextInfoHeader string, 
format FormatError) { 53 | tenantStore = t 54 | errFormat = format 55 | nextTenantHeader = nextHeader 56 | nextTenantInfoHeader = nextInfoHeader 57 | } 58 | 59 | func (p *Saas) Name() string { 60 | return "go-saas" 61 | } 62 | 63 | func (p *Saas) ParseConf(in []byte) (interface{}, error) { 64 | conf := SaasConf{} 65 | err := json.Unmarshal(in, &conf) 66 | return conf, err 67 | } 68 | 69 | func (p *Saas) RequestFilter(conf interface{}, w http.ResponseWriter, r pkgHTTP.Request) { 70 | cfg := conf.(SaasConf) 71 | if tenantStore == nil { 72 | log.Warnf("fail to find tenant store. please call Init first") 73 | return 74 | } 75 | key := shttp.KeyOrDefault(cfg.TenantKey) 76 | nextHeader := cfg.NextHeader 77 | if len(nextHeader) == 0 { 78 | nextHeader = nextTenantHeader 79 | } 80 | if len(nextHeader) == 0 { 81 | nextHeader = key 82 | } 83 | ctx := r.Context() 84 | //get tenant config 85 | tenantConfigProvider := saas.NewDefaultTenantConfigProvider(NewResolver(r, key, cfg.PathRegex), tenantStore) 86 | tenantConfig, ctx, err := tenantConfigProvider.Get(ctx) 87 | if err != nil { 88 | errFormat(err, w) 89 | return 90 | } 91 | resolveValue := saas.FromTenantResolveRes(ctx) 92 | idOrName := "" 93 | if resolveValue != nil { 94 | idOrName = resolveValue.TenantIdOrName 95 | } 96 | log.Infof("resolve tenant: %s ,id: %s ,is host: %v", idOrName, tenantConfig.ID, len(tenantConfig.ID) == 0) 97 | r.Header().Set(nextHeader, tenantConfig.ID) 98 | nextInfoHeader := cfg.NextInfoHeader 99 | if len(nextInfoHeader) == 0 { 100 | nextInfoHeader = nextTenantInfoHeader 101 | } 102 | nextInfoHeader = InfoHeaderOrDefault(nextInfoHeader) 103 | b, _ := json.Marshal(tenantConfig) 104 | r.Header().Set(nextInfoHeader, base64.StdEncoding.EncodeToString(b)) 105 | return 106 | } 107 | 108 | func InfoHeaderOrDefault(h string) string { 109 | if len(h) == 0 { 110 | return "X-TENANT-INFO" 111 | } else { 112 | return h 113 | } 114 | } 115 | 
-------------------------------------------------------------------------------- /gateway/apisix/saas_test.go: -------------------------------------------------------------------------------- 1 | package apisix 2 | 3 | import ( 4 | "io/ioutil" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | var ( 12 | testIn = []byte(`{"tenant_key":"tenant","next_header":"next_tenant","path_regex":"http://(.*).test.com"}`) 13 | ) 14 | 15 | func TestSaas(t *testing.T) { 16 | in := testIn 17 | saas := &Saas{} 18 | conf, err := saas.ParseConf(in) 19 | assert.Nil(t, err) 20 | 21 | w := httptest.NewRecorder() 22 | saas.RequestFilter(conf, w, nil) 23 | resp := w.Result() 24 | body, _ := ioutil.ReadAll(resp.Body) 25 | assert.Equal(t, 200, resp.StatusCode) 26 | assert.Equal(t, "", string(body)) 27 | } 28 | 29 | func TestSaas_BadConf(t *testing.T) { 30 | in := []byte(``) 31 | saas := &Saas{} 32 | _, err := saas.ParseConf(in) 33 | assert.NotNil(t, err) 34 | } 35 | 36 | // 37 | //func TestSaasHeader(t *testing.T) { 38 | // InitTenantStore(common.NewMemoryTenantStore([]common.TenantConfig{ 39 | // {ID: "1", Name: "Test1"}, 40 | // {ID: "2", Name: "Test2", Conn: map[string]string{ 41 | // data.Default: ":memory:?cache=shared", 42 | // }}, 43 | // })) 44 | // in := test_in 45 | // saas := &Saas{} 46 | // conf, err := saas.ParseConf(in) 47 | // assert.Nil(t, err) 48 | // 49 | // w := httptest.NewRecorder() 50 | // //TODO 51 | // saas.Filter(conf, w, nil) 52 | // //_ := w.Result() 53 | //} 54 | // 55 | //func TestSaasArgs(t *testing.T) { 56 | // InitTenantStore(common.NewMemoryTenantStore([]common.TenantConfig{ 57 | // {ID: "1", Name: "Test1"}, 58 | // {ID: "2", Name: "Test2", Conn: map[string]string{ 59 | // data.Default: ":memory:?cache=shared", 60 | // }}, 61 | // })) 62 | // in := test_in 63 | // saas := &Saas{} 64 | // conf, err := saas.ParseConf(in) 65 | // assert.Nil(t, err) 66 | // 67 | // w := httptest.NewRecorder() 68 | // //TODO 69 | // 
saas.Filter(conf, w, nil) 70 | // //_ := w.Result() 71 | //} 72 | // 73 | //func TestSaasPath(t *testing.T) { 74 | // InitTenantStore(common.NewMemoryTenantStore([]common.TenantConfig{ 75 | // {ID: "1", Name: "Test1"}, 76 | // {ID: "2", Name: "Test2", Conn: map[string]string{ 77 | // data.Default: ":memory:?cache=shared", 78 | // }}, 79 | // })) 80 | // in := test_in 81 | // saas := &Saas{} 82 | // conf, err := saas.ParseConf(in) 83 | // assert.Nil(t, err) 84 | // 85 | // w := httptest.NewRecorder() 86 | // //TODO 87 | // saas.Filter(conf, w, nil) 88 | // //_ := w.Result() 89 | //} 90 | -------------------------------------------------------------------------------- /gin/multi_tenancy.go: -------------------------------------------------------------------------------- 1 | package gin 2 | 3 | import ( 4 | "errors" 5 | "github.com/gin-gonic/gin" 6 | "github.com/go-saas/saas" 7 | 8 | "github.com/go-saas/saas/data" 9 | "github.com/go-saas/saas/http" 10 | ) 11 | 12 | type ErrorFormatter func(context *gin.Context, err error) 13 | 14 | var ( 15 | DefaultErrorFormatter ErrorFormatter = func(context *gin.Context, err error) { 16 | if errors.Is(err, saas.ErrTenantNotFound) { 17 | context.AbortWithError(404, err) 18 | } else { 19 | context.AbortWithError(500, err) 20 | } 21 | } 22 | ) 23 | 24 | type option struct { 25 | hmtOpt *http.WebMultiTenancyOption 26 | ef ErrorFormatter 27 | resolve []saas.ResolveOption 28 | } 29 | 30 | type Option func(*option) 31 | 32 | func WithMultiTenancyOption(opt *http.WebMultiTenancyOption) Option { 33 | return func(o *option) { 34 | o.hmtOpt = opt 35 | } 36 | } 37 | 38 | func WithErrorFormatter(e ErrorFormatter) Option { 39 | return func(o *option) { 40 | o.ef = e 41 | } 42 | } 43 | 44 | func WithResolveOption(opt ...saas.ResolveOption) Option { 45 | return func(o *option) { 46 | o.resolve = opt 47 | } 48 | } 49 | 50 | func MultiTenancy(ts saas.TenantStore, options ...Option) gin.HandlerFunc { 51 | opt := &option{ 52 | hmtOpt: 
http.NewDefaultWebMultiTenancyOption(), 53 | ef: DefaultErrorFormatter, 54 | resolve: nil, 55 | } 56 | for _, o := range options { 57 | o(opt) 58 | } 59 | return func(context *gin.Context) { 60 | var trOpt []saas.ResolveOption 61 | df := []saas.TenantResolveContrib{ 62 | http.NewCookieTenantResolveContrib(opt.hmtOpt.TenantKey, context.Request), 63 | http.NewFormTenantResolveContrib(opt.hmtOpt.TenantKey, context.Request), 64 | http.NewHeaderTenantResolveContrib(opt.hmtOpt.TenantKey, context.Request), 65 | http.NewQueryTenantResolveContrib(opt.hmtOpt.TenantKey, context.Request)} 66 | if opt.hmtOpt.DomainFormat != "" { 67 | df = append(df, http.NewDomainTenantResolveContrib(opt.hmtOpt.DomainFormat, context.Request)) 68 | } 69 | df = append(df, saas.NewTenantNormalizerContrib(ts)) 70 | trOpt = append(trOpt, saas.AppendContribs(df...)) 71 | trOpt = append(trOpt, opt.resolve...) 72 | 73 | //get tenant config 74 | tenantConfigProvider := saas.NewDefaultTenantConfigProvider(saas.NewDefaultTenantResolver(trOpt...), ts) 75 | tenantConfig, ctx, err := tenantConfigProvider.Get(context) 76 | if err != nil { 77 | opt.ef(context, err) 78 | return 79 | } 80 | //set current tenant 81 | newContext := saas.NewCurrentTenant(ctx, tenantConfig.ID, tenantConfig.Name) 82 | //data filter 83 | newContext = data.NewMultiTenancyDataFilter(newContext) 84 | 85 | //with newContext 86 | context.Request = context.Request.WithContext(newContext) 87 | //next 88 | context.Next() 89 | 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /gin/multi_tenancy_test.go: -------------------------------------------------------------------------------- 1 | package gin 2 | 3 | import ( 4 | "encoding/json" 5 | "github.com/gin-gonic/gin" 6 | "github.com/go-saas/saas" 7 | "github.com/stretchr/testify/assert" 8 | "net/http" 9 | "net/http/httptest" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | func SetUp() *gin.Engine { 15 | r := gin.Default() 16 | 
r.Use(MultiTenancy(saas.NewMemoryTenantStore( 17 | []saas.TenantConfig{ 18 | {ID: "1", Name: "Test1"}, 19 | {ID: "2", Name: "Test3"}, 20 | }))) 21 | r.GET("/", func(c *gin.Context) { 22 | rCtx := c.Request.Context() 23 | tenantInfo, _ := saas.FromCurrentTenant(rCtx) 24 | trR := saas.FromTenantResolveRes(rCtx) 25 | c.JSON(200, gin.H{ 26 | "tenantId": tenantInfo.GetId(), 27 | "resolvers": trR.AppliedResolvers, 28 | }) 29 | }) 30 | return r 31 | } 32 | 33 | func getW(url string, f func(r *http.Request)) *httptest.ResponseRecorder { 34 | r := SetUp() 35 | req, _ := http.NewRequest("GET", url, nil) 36 | f(req) 37 | w := httptest.NewRecorder() 38 | r.ServeHTTP(w, req) 39 | return w 40 | } 41 | 42 | func TestHostMultiTenancy(t *testing.T) { 43 | w := getW("/", func(r *http.Request) { 44 | }) 45 | assert.Equal(t, http.StatusOK, w.Code) 46 | var response map[string]interface{} 47 | err := json.Unmarshal([]byte(w.Body.String()), &response) 48 | value, exists := response["tenantId"] 49 | assert.True(t, exists) 50 | assert.Equal(t, "", value) 51 | assert.Nil(t, err) 52 | } 53 | func TestNotFoundMultiTenancy(t *testing.T) { 54 | w := getW("/", func(r *http.Request) { 55 | r.Header.Set("__tenant", "1000") 56 | }) 57 | assert.Equal(t, http.StatusNotFound, w.Code) 58 | } 59 | 60 | func TestCookieMultiTenancy(t *testing.T) { 61 | w := getW("/", func(r *http.Request) { 62 | r.AddCookie(&http.Cookie{ 63 | Name: "__tenant", 64 | Value: "1", 65 | Path: "", 66 | Domain: "", 67 | Expires: time.Time{}, 68 | RawExpires: "", 69 | MaxAge: 0, 70 | Secure: false, 71 | HttpOnly: false, 72 | SameSite: 0, 73 | Raw: "", 74 | Unparsed: nil, 75 | }) 76 | }) 77 | assert.Equal(t, http.StatusOK, w.Code) 78 | var response map[string]interface{} 79 | err := json.Unmarshal([]byte(w.Body.String()), &response) 80 | value, exists := response["tenantId"] 81 | assert.True(t, exists) 82 | assert.Equal(t, "1", value) 83 | assert.Nil(t, err) 84 | } 85 | 86 | func TestHeaderMultiTenancy(t *testing.T) { 87 | w := 
getW("/", func(r *http.Request) { 88 | r.Header.Set("__tenant", "1") 89 | }) 90 | assert.Equal(t, http.StatusOK, w.Code) 91 | var response map[string]interface{} 92 | err := json.Unmarshal([]byte(w.Body.String()), &response) 93 | value, exists := response["tenantId"] 94 | assert.True(t, exists) 95 | assert.Equal(t, "1", value) 96 | assert.Nil(t, err) 97 | } 98 | -------------------------------------------------------------------------------- /gin/readme.md: -------------------------------------------------------------------------------- 1 | ### [gin](https://github.com/gin-gonic/gin) adapter -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/go-saas/saas 2 | 3 | go 1.20 4 | 5 | require ( 6 | entgo.io/ent v0.12.4 7 | github.com/apache/apisix-go-plugin-runner v0.4.0 8 | github.com/gin-gonic/gin v1.8.1 9 | github.com/go-kratos/kratos/v2 v2.5.0 10 | github.com/google/uuid v1.3.1 11 | github.com/gorilla/mux v1.8.0 12 | github.com/kataras/iris/v12 v12.2.7 13 | github.com/stretchr/testify v1.8.4 14 | gorm.io/driver/sqlite v1.4.3 15 | gorm.io/gorm v1.24.2 16 | 17 | ) 18 | 19 | require ( 20 | github.com/BurntSushi/toml v1.3.2 // indirect 21 | github.com/CloudyKit/fastprinter v0.0.0-20200109182630-33d98a066a53 // indirect 22 | github.com/CloudyKit/jet/v6 v6.2.0 // indirect 23 | github.com/Joker/jade v1.1.3 // indirect 24 | github.com/ReneKroon/ttlcache/v2 v2.11.0 // indirect 25 | github.com/Shopify/goreferrer v0.0.0-20220729165902-8cddb4f5de06 // indirect 26 | github.com/ajg/form v1.5.1 // indirect 27 | github.com/andybalholm/brotli v1.0.5 // indirect 28 | github.com/api7/ext-plugin-proto v0.6.0 // indirect 29 | github.com/aymerick/douceur v0.2.0 // indirect 30 | github.com/davecgh/go-spew v1.1.1 // indirect 31 | github.com/fatih/color v1.15.0 // indirect 32 | github.com/fatih/structs v1.1.0 // indirect 33 | github.com/flosch/pongo2/v4 
v4.0.2 // indirect 34 | github.com/gin-contrib/sse v0.1.0 // indirect 35 | github.com/go-playground/form/v4 v4.2.0 // indirect 36 | github.com/go-playground/locales v0.14.0 // indirect 37 | github.com/go-playground/universal-translator v0.18.0 // indirect 38 | github.com/go-playground/validator/v10 v10.11.1 // indirect 39 | github.com/gobwas/glob v0.2.3 // indirect 40 | github.com/goccy/go-json v0.9.11 // indirect 41 | github.com/golang/protobuf v1.5.2 // indirect 42 | github.com/golang/snappy v0.0.4 // indirect 43 | github.com/gomarkdown/markdown v0.0.0-20230922112808-5421fefb8386 // indirect 44 | github.com/google/flatbuffers v22.9.29+incompatible // indirect 45 | github.com/google/go-querystring v1.1.0 // indirect 46 | github.com/gorilla/css v1.0.0 // indirect 47 | github.com/gorilla/websocket v1.5.0 // indirect 48 | github.com/imkira/go-interpol v1.1.0 // indirect 49 | github.com/iris-contrib/httpexpect/v2 v2.15.2 // indirect 50 | github.com/iris-contrib/schema v0.0.6 // indirect 51 | github.com/jinzhu/inflection v1.0.0 // indirect 52 | github.com/jinzhu/now v1.1.5 // indirect 53 | github.com/josharian/intern v1.0.0 // indirect 54 | github.com/json-iterator/go v1.1.12 // indirect 55 | github.com/kataras/blocks v0.0.8 // indirect 56 | github.com/kataras/golog v0.1.9 // indirect 57 | github.com/kataras/pio v0.0.12 // indirect 58 | github.com/kataras/sitemap v0.0.6 // indirect 59 | github.com/kataras/tunnel v0.0.4 // indirect 60 | github.com/klauspost/compress v1.17.0 // indirect 61 | github.com/leodido/go-urn v1.2.1 // indirect 62 | github.com/mailgun/raymond/v2 v2.0.48 // indirect 63 | github.com/mailru/easyjson v0.7.7 // indirect 64 | github.com/mattn/go-colorable v0.1.13 // indirect 65 | github.com/mattn/go-isatty v0.0.19 // indirect 66 | github.com/mattn/go-sqlite3 v1.14.16 // indirect 67 | github.com/microcosm-cc/bluemonday v1.0.25 // indirect 68 | github.com/mitchellh/go-wordwrap v1.0.1 // indirect 69 | github.com/modern-go/concurrent 
v0.0.0-20180306012644-bacd9c7ef1dd // indirect 70 | github.com/modern-go/reflect2 v1.0.2 // indirect 71 | github.com/pelletier/go-toml/v2 v2.0.5 // indirect 72 | github.com/pmezard/go-difflib v1.0.0 // indirect 73 | github.com/russross/blackfriday/v2 v2.1.0 // indirect 74 | github.com/sanity-io/litter v1.5.5 // indirect 75 | github.com/schollz/closestmatch v2.1.0+incompatible // indirect 76 | github.com/sergi/go-diff v1.0.0 // indirect 77 | github.com/sirupsen/logrus v1.8.1 // indirect 78 | github.com/tdewolff/minify/v2 v2.12.9 // indirect 79 | github.com/tdewolff/parse/v2 v2.6.8 // indirect 80 | github.com/ugorji/go/codec v1.2.7 // indirect 81 | github.com/valyala/bytebufferpool v1.0.0 // indirect 82 | github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect 83 | github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect 84 | github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect 85 | github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect 86 | github.com/xeipuuv/gojsonschema v1.2.0 // indirect 87 | github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0 // indirect 88 | github.com/yosssi/ace v0.0.5 // indirect 89 | github.com/yudai/gojsondiff v1.0.0 // indirect 90 | github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 // indirect 91 | go.uber.org/atomic v1.10.0 // indirect 92 | go.uber.org/multierr v1.8.0 // indirect 93 | go.uber.org/zap v1.23.0 // indirect 94 | golang.org/x/crypto v0.13.0 // indirect 95 | golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect 96 | golang.org/x/net v0.15.0 // indirect 97 | golang.org/x/sync v0.1.0 // indirect 98 | golang.org/x/sys v0.12.0 // indirect 99 | golang.org/x/text v0.13.0 // indirect 100 | golang.org/x/time v0.3.0 // indirect 101 | google.golang.org/genproto v0.0.0-20220930163606-c98284e70a91 // indirect 102 | google.golang.org/grpc v1.49.0 // indirect 103 | google.golang.org/protobuf v1.31.0 // indirect 104 | gopkg.in/ini.v1 v1.67.0 // indirect 105 | 
gopkg.in/yaml.v2 v2.4.0 // indirect 106 | gopkg.in/yaml.v3 v3.0.1 // indirect 107 | moul.io/http2curl/v2 v2.3.0 // indirect 108 | ) 109 | -------------------------------------------------------------------------------- /gorm/gorm.go: -------------------------------------------------------------------------------- 1 | package gorm 2 | 3 | import ( 4 | "context" 5 | "github.com/go-saas/saas" 6 | 7 | "github.com/go-saas/saas/data" 8 | "gorm.io/gorm" 9 | ) 10 | 11 | // MultiTenancy entity 12 | type MultiTenancy struct { 13 | TenantId HasTenant `gorm:"index"` 14 | } 15 | 16 | type DbProvider saas.DbProvider[*gorm.DB] 17 | type ClientProvider saas.ClientProvider[*gorm.DB] 18 | type ClientProviderFunc saas.ClientProviderFunc[*gorm.DB] 19 | 20 | func (c ClientProviderFunc) Get(ctx context.Context, dsn string) (*gorm.DB, error) { 21 | return c(ctx, dsn) 22 | } 23 | 24 | func NewDbProvider(cs data.ConnStrResolver, cp ClientProvider) DbProvider { 25 | return saas.NewDbProvider[*gorm.DB](cs, cp) 26 | } 27 | 28 | type DbWrap struct { 29 | *gorm.DB 30 | } 31 | 32 | // NewDbWrap wrap gorm.DB into io.Close 33 | func NewDbWrap(db *gorm.DB) *DbWrap { 34 | return &DbWrap{db} 35 | } 36 | 37 | func (d *DbWrap) Close() error { 38 | return closeDb(d.DB) 39 | } 40 | 41 | func closeDb(d *gorm.DB) error { 42 | sqlDB, err := d.DB() 43 | if err != nil { 44 | return err 45 | } 46 | cErr := sqlDB.Close() 47 | if cErr != nil { 48 | //todo logging 49 | //logger.Errorf("Gorm db close error: %s", err.Error()) 50 | return cErr 51 | } 52 | return nil 53 | } 54 | -------------------------------------------------------------------------------- /gorm/has_tenant.go: -------------------------------------------------------------------------------- 1 | //ref:https://github.com/go-gorm/gorm/blob/master/soft_delete.go 2 | 3 | package gorm 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | "database/sql/driver" 9 | "encoding/json" 10 | "github.com/go-saas/saas" 11 | 12 | "github.com/go-saas/saas/data" 13 | 
"gorm.io/gorm" 14 | "gorm.io/gorm/clause" 15 | "gorm.io/gorm/schema" 16 | ) 17 | 18 | type HasTenant sql.NullString 19 | 20 | func NewTenantId(s string) HasTenant { 21 | if s == "" { 22 | return HasTenant{ 23 | Valid: false, 24 | } 25 | } else { 26 | return HasTenant{ 27 | String: s, 28 | Valid: true, 29 | } 30 | } 31 | } 32 | 33 | func (t HasTenant) GormValue(ctx context.Context, db *gorm.DB) (expr clause.Expr) { 34 | ct, _ := saas.FromCurrentTenant(ctx) 35 | at := data.FromAutoSetTenantId(ctx) 36 | if at { 37 | if ct.GetId() != t.String { 38 | //mismatch 39 | if ct.GetId() != "" { 40 | //only normalize in tenant side 41 | if !t.Valid || t.String == "" { 42 | //tenant want to insert self 43 | return clause.Expr{SQL: "?", Vars: []interface{}{ct.GetId()}} 44 | } else { 45 | //tenant want to insert others 46 | //force reset 47 | return clause.Expr{SQL: "?", Vars: []interface{}{ct.GetId()}} 48 | } 49 | } 50 | } 51 | } 52 | if t.Valid && t.String != "" { 53 | return clause.Expr{SQL: "?", Vars: []interface{}{t.String}} 54 | } else { 55 | return clause.Expr{SQL: "?", Vars: []interface{}{nil}} 56 | } 57 | } 58 | 59 | // Scan implements the Scanner interface. 60 | func (t *HasTenant) Scan(value interface{}) error { 61 | return (*sql.NullString)(t).Scan(value) 62 | } 63 | 64 | // Value implements the driver Valuer interface. 
65 | func (t HasTenant) Value() (driver.Value, error) { 66 | if !t.Valid { 67 | return nil, nil 68 | } 69 | return t.String, nil 70 | } 71 | 72 | func (t HasTenant) MarshalJSON() ([]byte, error) { 73 | if t.Valid { 74 | return json.Marshal(t.String) 75 | } 76 | return json.Marshal(nil) 77 | } 78 | 79 | func (t *HasTenant) UnmarshalJSON(b []byte) error { // pointer receiver: a value receiver decodes into a copy and the result is lost 80 | if string(b) == "null" { 81 | t.Valid = false 82 | return nil 83 | } 84 | err := json.Unmarshal(b, &t.String) 85 | if err == nil { 86 | t.Valid = true 87 | } 88 | return err 89 | } 90 | 91 | func (HasTenant) QueryClauses(f *schema.Field) []clause.Interface { 92 | return []clause.Interface{HasTenantQueryClause{Field: f}} 93 | } 94 | 95 | type HasTenantQueryClause struct { 96 | Field *schema.Field 97 | } 98 | 99 | func (sd HasTenantQueryClause) Name() string { 100 | return "" 101 | } 102 | 103 | func (sd HasTenantQueryClause) Build(clause.Builder) { 104 | } 105 | 106 | func (sd HasTenantQueryClause) MergeClause(*clause.Clause) { 107 | } 108 | 109 | func (sd HasTenantQueryClause) ModifyStatement(stmt *gorm.Statement) { 110 | t, _ := saas.FromCurrentTenant(stmt.Context) 111 | e := data.FromMultiTenancyDataFilter(stmt.Context) 112 | if !e { 113 | return 114 | } 115 | if _, ok := stmt.Clauses["multi_tenancy_enabled"]; !ok { 116 | if c, ok := stmt.Clauses["WHERE"]; ok { 117 | if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) > 1 { 118 | for _, expr := range where.Exprs { 119 | if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 { 120 | where.Exprs = []clause.Expression{clause.And(where.Exprs...)} 121 | c.Expression = where 122 | stmt.Clauses["WHERE"] = c 123 | break 124 | } 125 | } 126 | } 127 | } 128 | var v interface{} 129 | if t.GetId() == "" { 130 | v = nil 131 | } else { 132 | v = t.GetId() 133 | } 134 | stmt.AddClause(clause.Where{Exprs: []clause.Expression{ 135 | clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: v}, 136 | }}) 137 |
stmt.Clauses["multi_tenancy_enabled"] = clause.Clause{} 138 | } 139 | } 140 | 141 | func (HasTenant) DeleteClauses(f *schema.Field) []clause.Interface { 142 | return []clause.Interface{HasTenantDeleteClause{Field: f}} 143 | } 144 | 145 | type HasTenantDeleteClause struct { 146 | Field *schema.Field 147 | } 148 | 149 | func (sd HasTenantDeleteClause) Name() string { 150 | return "" 151 | } 152 | 153 | func (sd HasTenantDeleteClause) Build(clause.Builder) { 154 | } 155 | 156 | func (sd HasTenantDeleteClause) MergeClause(*clause.Clause) { 157 | } 158 | 159 | func (sd HasTenantDeleteClause) ModifyStatement(stmt *gorm.Statement) { 160 | if stmt.SQL.Len() == 0 { 161 | HasTenantQueryClause(sd).ModifyStatement(stmt) 162 | } 163 | } 164 | 165 | func (HasTenant) UpdateClauses(f *schema.Field) []clause.Interface { 166 | return []clause.Interface{HasTenantUpdateClause{Field: f}} 167 | } 168 | 169 | type HasTenantUpdateClause struct { 170 | Field *schema.Field 171 | } 172 | 173 | func (sd HasTenantUpdateClause) Name() string { 174 | return "" 175 | } 176 | 177 | func (sd HasTenantUpdateClause) Build(clause.Builder) { 178 | } 179 | 180 | func (sd HasTenantUpdateClause) MergeClause(*clause.Clause) { 181 | } 182 | 183 | func (sd HasTenantUpdateClause) ModifyStatement(stmt *gorm.Statement) { 184 | if stmt.SQL.Len() == 0 { 185 | HasTenantQueryClause(sd).ModifyStatement(stmt) 186 | } 187 | } 188 | -------------------------------------------------------------------------------- /gorm/readme.md: -------------------------------------------------------------------------------- 1 | ### [gorm](https://github.com/go-gorm/gorm) adapter -------------------------------------------------------------------------------- /gorm/sqlite_db_test.go: -------------------------------------------------------------------------------- 1 | package gorm 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "github.com/go-saas/saas" 7 | "github.com/google/uuid" 8 | 9 | "github.com/go-saas/saas/data" 10 |
"gorm.io/driver/sqlite" 11 | g "gorm.io/gorm" 12 | "os" 13 | "testing" 14 | ) 15 | 16 | var TestDb *g.DB 17 | 18 | var TestDbProvider DbProvider 19 | 20 | var ( 21 | TenantId1 = uuid.New().String() 22 | TenantId2 = uuid.New().String() 23 | ) 24 | 25 | func TestMain(m *testing.M) { 26 | 27 | clientProvider := ClientProviderFunc(func(ctx context.Context, s string) (*g.DB, error) { 28 | db, err := sql.Open("sqlite3", s) 29 | if err != nil { 30 | return nil, err 31 | } 32 | db.SetMaxIdleConns(1) 33 | db.SetMaxOpenConns(1) 34 | 35 | client, err := g.Open(&sqlite.Dialector{ 36 | DriverName: sqlite.DriverName, 37 | DSN: s, 38 | Conn: db, 39 | }) 40 | if err != nil { 41 | return client, err 42 | } 43 | return client.WithContext(ctx).Debug(), err 44 | }) 45 | TestDbProvider = NewDbProvider(GetConnStrResolver(), clientProvider) 46 | 47 | TestDb = GetDb(context.Background(), TestDbProvider) 48 | err := AutoMigrate(nil, TestDb) 49 | if err != nil { 50 | panic(err) 51 | } 52 | 53 | exitCode := m.Run() 54 | NewDbWrap(TestDb).Close() 55 | // exit 56 | os.Exit(exitCode) 57 | 58 | } 59 | 60 | func GetConnStrResolver() *saas.MultiTenancyConnStrResolver { 61 | //use memory store 62 | 63 | ts := saas.NewMemoryTenantStore( 64 | []saas.TenantConfig{ 65 | {ID: TenantId1, Name: "Test1"}, 66 | {ID: TenantId2, Name: "Test2", Conn: map[string]string{ 67 | data.Default: ":memory:?cache=shared", 68 | }}, 69 | }) 70 | conn := make(data.ConnStrings, 1) 71 | conn.SetDefault("file::memory:?cache=shared") 72 | mr := saas.NewMultiTenancyConnStrResolver(ts, conn) 73 | return mr 74 | } 75 | -------------------------------------------------------------------------------- /http/cookie_tenant_resolve_contrib.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "github.com/go-saas/saas" 5 | "net/http" 6 | ) 7 | 8 | type CookieTenantResolveContrib struct { 9 | key string 10 | request *http.Request 11 | } 12 | 13 | func
NewCookieTenantResolveContrib(key string, r *http.Request) *CookieTenantResolveContrib { 14 | return &CookieTenantResolveContrib{ 15 | key: key, 16 | request: r, 17 | } 18 | } 19 | 20 | func (h *CookieTenantResolveContrib) Name() string { 21 | return "Cookie" 22 | } 23 | 24 | func (h *CookieTenantResolveContrib) Resolve(ctx *saas.Context) error { 25 | v, err := h.request.Cookie(h.key) 26 | if err != nil { 27 | //no cookie 28 | return nil 29 | } 30 | if v.Value == "" { 31 | return nil 32 | } 33 | ctx.TenantIdOrName = v.Value 34 | return nil 35 | } 36 | -------------------------------------------------------------------------------- /http/domain_tenant_resolve_contrib.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "github.com/go-saas/saas" 5 | "net/http" 6 | "regexp" 7 | ) 8 | 9 | type DomainTenantResolveContrib struct { 10 | request *http.Request 11 | format string 12 | } 13 | 14 | func NewDomainTenantResolveContrib(f string, r *http.Request) *DomainTenantResolveContrib { 15 | return &DomainTenantResolveContrib{ 16 | request: r, 17 | format: f, 18 | } 19 | } 20 | 21 | func (h *DomainTenantResolveContrib) Name() string { 22 | return "Domain" 23 | } 24 | 25 | func (h *DomainTenantResolveContrib) Resolve(ctx *saas.Context) error { 26 | host := h.request.Host 27 | r := regexp.MustCompile(h.format) 28 | f := r.FindStringSubmatch(host) 29 | if len(f) < 2 { 30 | //no match, or the format has no capture group (indexing f[1] would panic) 31 | return nil 32 | } 33 | ctx.TenantIdOrName = f[1] 34 | return nil 35 | } 36 | -------------------------------------------------------------------------------- /http/form_tenant_resolve_contrib.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "github.com/go-saas/saas" 5 | "net/http" 6 | ) 7 | 8 | type FormTenantResolveContrib struct { 9 | key string 10 | request *http.Request 11 | } 12 | 13 | func NewFormTenantResolveContrib(key string, r
*http.Request) *FormTenantResolveContrib { 14 | return &FormTenantResolveContrib{ 15 | key: key, 16 | request: r, 17 | } 18 | } 19 | 20 | func (h *FormTenantResolveContrib) Name() string { 21 | return "Form" 22 | } 23 | 24 | func (h *FormTenantResolveContrib) Resolve(ctx *saas.Context) error { 25 | v := h.request.FormValue(h.key) 26 | if v == "" { 27 | return nil 28 | } 29 | ctx.TenantIdOrName = v 30 | return nil 31 | } 32 | -------------------------------------------------------------------------------- /http/header_tenant_resolve_contrib.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "github.com/go-saas/saas" 5 | "net/http" 6 | ) 7 | 8 | type HeaderTenantResolveContrib struct { 9 | key string 10 | request *http.Request 11 | } 12 | 13 | func NewHeaderTenantResolveContrib(key string, r *http.Request) *HeaderTenantResolveContrib { 14 | return &HeaderTenantResolveContrib{ 15 | key: key, 16 | request: r, 17 | } 18 | } 19 | 20 | func (h *HeaderTenantResolveContrib) Name() string { 21 | return "Header" 22 | } 23 | 24 | func (h *HeaderTenantResolveContrib) Resolve(ctx *saas.Context) error { 25 | v := h.request.Header.Get(h.key) 26 | if v == "" { 27 | return nil 28 | } 29 | ctx.TenantIdOrName = v 30 | return nil 31 | } 32 | -------------------------------------------------------------------------------- /http/multi_tenancy.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "errors" 5 | "github.com/go-saas/saas" 6 | "github.com/go-saas/saas/data" 7 | "net/http" 8 | ) 9 | 10 | type ErrorFormatter func(w http.ResponseWriter, err error) 11 | 12 | var ( 13 | DefaultErrorFormatter ErrorFormatter = func(w http.ResponseWriter, err error) { 14 | if errors.Is(err, saas.ErrTenantNotFound) { // not found; errors.Is also matches wrapped errors, as in the iris and kratos adapters 15 | http.Error(w, "Not Found", 404) 16 | } else { 17 | http.Error(w, err.Error(), 500) 18 | } 19 | } 20 | ) 21 | 22 | type option struct { 23 | hmtOpt
*WebMultiTenancyOption 24 | ef ErrorFormatter 25 | resolve []saas.ResolveOption 26 | } 27 | 28 | type Option func(*option) 29 | 30 | func WithMultiTenancyOption(opt *WebMultiTenancyOption) Option { 31 | return func(o *option) { 32 | o.hmtOpt = opt 33 | } 34 | } 35 | 36 | func WithErrorFormatter(e ErrorFormatter) Option { 37 | return func(o *option) { 38 | o.ef = e 39 | } 40 | } 41 | 42 | func WithResolveOption(opt ...saas.ResolveOption) Option { 43 | return func(o *option) { 44 | o.resolve = opt 45 | } 46 | } 47 | 48 | func Middleware(ts saas.TenantStore, options ...Option) func(next http.Handler) http.Handler { 49 | opt := &option{ 50 | hmtOpt: NewDefaultWebMultiTenancyOption(), 51 | ef: DefaultErrorFormatter, 52 | resolve: nil, 53 | } 54 | for _, o := range options { 55 | o(opt) 56 | } 57 | return func(next http.Handler) http.Handler { 58 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 59 | var trOpt []saas.ResolveOption 60 | df := []saas.TenantResolveContrib{ 61 | NewCookieTenantResolveContrib(opt.hmtOpt.TenantKey, r), 62 | NewFormTenantResolveContrib(opt.hmtOpt.TenantKey, r), 63 | NewHeaderTenantResolveContrib(opt.hmtOpt.TenantKey, r), 64 | NewQueryTenantResolveContrib(opt.hmtOpt.TenantKey, r), 65 | } 66 | 67 | if opt.hmtOpt.DomainFormat != "" { 68 | df = append(df, NewDomainTenantResolveContrib(opt.hmtOpt.DomainFormat, r)) 69 | } 70 | df = append(df, saas.NewTenantNormalizerContrib(ts)) 71 | trOpt = append(trOpt, saas.AppendContribs(df...)) 72 | trOpt = append(trOpt, opt.resolve...)
73 | 74 | //get tenant config 75 | tenantConfigProvider := saas.NewDefaultTenantConfigProvider(saas.NewDefaultTenantResolver(trOpt...), ts) 76 | tenantConfig, ctx, err := tenantConfigProvider.Get(r.Context()) 77 | if err != nil { 78 | opt.ef(w, err) 79 | return 80 | } 81 | //set current tenant 82 | newContext := saas.NewCurrentTenant(ctx, tenantConfig.ID, tenantConfig.Name) 83 | //data filter 84 | newContext = data.NewMultiTenancyDataFilter(newContext) 85 | next.ServeHTTP(w, r.WithContext(newContext)) 86 | }) 87 | } 88 | 89 | } 90 | -------------------------------------------------------------------------------- /http/multi_tenancy_test.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "github.com/go-saas/saas" 7 | "github.com/gorilla/mux" 8 | "github.com/stretchr/testify/assert" 9 | "net/http" 10 | "net/http/httptest" 11 | "testing" 12 | "time" 13 | ) 14 | 15 | func SetUp() *mux.Router { 16 | r := mux.NewRouter() 17 | 18 | r.Use(Middleware(saas.NewMemoryTenantStore( 19 | []saas.TenantConfig{ 20 | {ID: "1", Name: "Test1"}, 21 | {ID: "2", Name: "Test3"}, 22 | }))) 23 | 24 | r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 25 | // an example API handler 26 | tenantInfo, _ := saas.FromCurrentTenant(r.Context()) 27 | trR := saas.FromTenantResolveRes(r.Context()) 28 | json.NewEncoder(w).Encode(map[string]interface{}{ 29 | "tenantId": tenantInfo.GetId(), 30 | "resolvers": trR.AppliedResolvers, 31 | }) 32 | }) 33 | return r 34 | } 35 | 36 | func getW(url string, f func(r *http.Request)) *httptest.ResponseRecorder { 37 | r := SetUp() 38 | req, _ := http.NewRequest("GET", url, nil) 39 | f(req) 40 | w := httptest.NewRecorder() 41 | r.ServeHTTP(w, req) 42 | return w 43 | } 44 | 45 | func TestHostMultiTenancy(t *testing.T) { 46 | w := getW("/", func(r *http.Request) { 47 | }) 48 | assert.Equal(t, http.StatusOK, w.Code) 49 | var response
map[string]interface{} 50 | err := json.Unmarshal([]byte(w.Body.String()), &response) 51 | value, exists := response["tenantId"] 52 | assert.True(t, exists) 53 | assert.Equal(t, "", value) 54 | assert.Nil(t, err) 55 | } 56 | func TestNotFoundMultiTenancy(t *testing.T) { 57 | w := getW("/", func(r *http.Request) { 58 | r.Header.Set("__tenant", "1000") 59 | }) 60 | assert.Equal(t, http.StatusNotFound, w.Code) 61 | } 62 | 63 | func TestCookieMultiTenancy(t *testing.T) { 64 | w := getW("/", func(r *http.Request) { 65 | r.AddCookie(&http.Cookie{ 66 | Name: "__tenant", 67 | Value: "1", 68 | Path: "", 69 | Domain: "", 70 | Expires: time.Time{}, 71 | RawExpires: "", 72 | MaxAge: 0, 73 | Secure: false, 74 | HttpOnly: false, 75 | SameSite: 0, 76 | Raw: "", 77 | Unparsed: nil, 78 | }) 79 | }) 80 | assert.Equal(t, http.StatusOK, w.Code) 81 | var response map[string]interface{} 82 | err := json.Unmarshal([]byte(w.Body.String()), &response) 83 | value, exists := response["tenantId"] 84 | assert.True(t, exists) 85 | assert.Equal(t, "1", value) 86 | assert.Nil(t, err) 87 | } 88 | 89 | func TestHeaderMultiTenancy(t *testing.T) { 90 | w := getW("/", func(r *http.Request) { 91 | r.Header.Set("__tenant", "1") 92 | }) 93 | assert.Equal(t, http.StatusOK, w.Code) 94 | var response map[string]interface{} 95 | err := json.Unmarshal([]byte(w.Body.String()), &response) 96 | value, exists := response["tenantId"] 97 | assert.True(t, exists) 98 | assert.Equal(t, "1", value) 99 | assert.Nil(t, err) 100 | } 101 | 102 | func TestTerminate(t *testing.T) { 103 | r := mux.NewRouter() 104 | 105 | r.Use(Middleware(saas.NewMemoryTenantStore( 106 | []saas.TenantConfig{ 107 | {ID: "1", Name: "Test1"}, 108 | {ID: "2", Name: "Test3"}, 109 | }), 110 | WithErrorFormatter(func(w http.ResponseWriter, err error) { 111 | if errors.Is(err, ErrForbidden) { 112 | http.Error(w, "Forbidden", 403) 113 | } 114 | }), 115 | WithResolveOption(saas.AppendContribs(&TerminateContrib{})))) 116 | 117 | r.HandleFunc("/", func(w
http.ResponseWriter, r *http.Request) { 118 | }) 119 | 120 | req, _ := http.NewRequest("GET", "/", nil) 121 | 122 | w := httptest.NewRecorder() 123 | r.ServeHTTP(w, req) 124 | 125 | assert.Equal(t, http.StatusForbidden, w.Code) 126 | } 127 | 128 | var ( 129 | ErrForbidden = errors.New("forbidden") 130 | ) 131 | 132 | type TerminateContrib struct { 133 | } 134 | 135 | func (t *TerminateContrib) Name() string { 136 | return "Terminate" 137 | } 138 | 139 | func (t *TerminateContrib) Resolve(_ *saas.Context) error { 140 | return ErrForbidden 141 | } 142 | -------------------------------------------------------------------------------- /http/query_tenant_resolve_contrib.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "github.com/go-saas/saas" 5 | "net/http" 6 | ) 7 | 8 | type QueryTenantResolveContrib struct { 9 | key string 10 | request *http.Request 11 | } 12 | 13 | func NewQueryTenantResolveContrib(key string, r *http.Request) *QueryTenantResolveContrib { 14 | return &QueryTenantResolveContrib{ 15 | key: key, 16 | request: r, 17 | } 18 | } 19 | 20 | func (h *QueryTenantResolveContrib) Name() string { 21 | return "Query" 22 | } 23 | 24 | func (h *QueryTenantResolveContrib) Resolve(ctx *saas.Context) error { 25 | v := h.request.URL.Query().Get(h.key) 26 | if v == "" { 27 | return nil 28 | } 29 | ctx.TenantIdOrName = v 30 | return nil 31 | } 32 | -------------------------------------------------------------------------------- /http/web_multi_tenancy_option.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | const DefaultKey = "__tenant" 4 | 5 | func KeyOrDefault(key string) string { 6 | if len(key) > 0 { 7 | return key 8 | 9 | } 10 | return DefaultKey 11 | } 12 | 13 | type WebMultiTenancyOption struct { 14 | TenantKey string 15 | DomainFormat string 16 | } 17 | 18 | func NewWebMultiTenancyOption(key string, domainFormat string)
*WebMultiTenancyOption { 19 | key = KeyOrDefault(key) 20 | return &WebMultiTenancyOption{ 21 | TenantKey: key, 22 | DomainFormat: domainFormat, 23 | } 24 | } 25 | 26 | func NewDefaultWebMultiTenancyOption() *WebMultiTenancyOption { 27 | return NewWebMultiTenancyOption("", "") 28 | } 29 | -------------------------------------------------------------------------------- /iris/multi_tenancy.go: -------------------------------------------------------------------------------- 1 | package iris 2 | 3 | import ( 4 | "errors" 5 | "github.com/go-saas/saas" 6 | "github.com/go-saas/saas/data" 7 | "github.com/go-saas/saas/http" 8 | "github.com/kataras/iris/v12" 9 | ) 10 | 11 | type ErrorFormatter func(context iris.Context, err error) 12 | 13 | var ( 14 | DefaultErrorFormatter ErrorFormatter = func(context iris.Context, err error) { 15 | if errors.Is(err, saas.ErrTenantNotFound) { 16 | context.StopWithError(404, err) 17 | } else { 18 | context.StopWithError(500, err) 19 | } 20 | } 21 | ) 22 | 23 | type option struct { 24 | hmtOpt *http.WebMultiTenancyOption 25 | ef ErrorFormatter 26 | resolve []saas.ResolveOption 27 | } 28 | 29 | type Option func(*option) 30 | 31 | func WithMultiTenancyOption(opt *http.WebMultiTenancyOption) Option { 32 | return func(o *option) { 33 | o.hmtOpt = opt 34 | } 35 | } 36 | 37 | func WithErrorFormatter(e ErrorFormatter) Option { 38 | return func(o *option) { 39 | o.ef = e 40 | } 41 | } 42 | 43 | func WithResolveOption(opt ...saas.ResolveOption) Option { 44 | return func(o *option) { 45 | o.resolve = opt 46 | } 47 | } 48 | 49 | func MultiTenancy(ts saas.TenantStore, options ...Option) iris.Handler { 50 | opt := &option{ 51 | hmtOpt: http.NewDefaultWebMultiTenancyOption(), 52 | ef: DefaultErrorFormatter, 53 | resolve: nil, 54 | } 55 | for _, o := range options { 56 | o(opt) 57 | } 58 | return func(context iris.Context) { 59 | var trOpt []saas.ResolveOption 60 | df := []saas.TenantResolveContrib{ 61 |
http.NewCookieTenantResolveContrib(opt.hmtOpt.TenantKey, context.Request()), 62 | http.NewFormTenantResolveContrib(opt.hmtOpt.TenantKey, context.Request()), 63 | http.NewHeaderTenantResolveContrib(opt.hmtOpt.TenantKey, context.Request()), 64 | http.NewQueryTenantResolveContrib(opt.hmtOpt.TenantKey, context.Request())} 65 | if opt.hmtOpt.DomainFormat != "" { 66 | df = append(df, http.NewDomainTenantResolveContrib(opt.hmtOpt.DomainFormat, context.Request())) 67 | } 68 | df = append(df, saas.NewTenantNormalizerContrib(ts)) 69 | trOpt = append(trOpt, saas.AppendContribs(df...)) 70 | trOpt = append(trOpt, opt.resolve...) 71 | 72 | //get tenant config 73 | tenantConfigProvider := saas.NewDefaultTenantConfigProvider(saas.NewDefaultTenantResolver(trOpt...), ts) 74 | tenantConfig, ctx, err := tenantConfigProvider.Get(context) 75 | if err != nil { 76 | opt.ef(context, err) 77 | return 78 | } 79 | //set current tenant 80 | newContext := saas.NewCurrentTenant(ctx, tenantConfig.ID, tenantConfig.Name) 81 | //data filter 82 | newContext = data.NewMultiTenancyDataFilter(newContext) 83 | 84 | //with newContext 85 | context.ResetRequest(context.Request().WithContext(newContext)) 86 | //next 87 | context.Next() 88 | 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /iris/multi_tenancy_test.go: -------------------------------------------------------------------------------- 1 | package iris 2 | 3 | import ( 4 | "github.com/go-saas/saas" 5 | "github.com/kataras/iris/v12" 6 | "github.com/kataras/iris/v12/httptest" 7 | "github.com/stretchr/testify/assert" 8 | "net/http" 9 | "testing" 10 | ) 11 | 12 | func SetUp() *iris.Application { 13 | r := iris.New() 14 | r.Use(MultiTenancy(saas.NewMemoryTenantStore( 15 | []saas.TenantConfig{ 16 | {ID: "1", Name: "Test1"}, 17 | {ID: "2", Name: "Test3"}, 18 | }))) 19 | r.Get("/", func(c iris.Context) { 20 | rCtx := c.Request().Context() 21 | tenantInfo, _ := saas.FromCurrentTenant(rCtx) 22 | trR :=
saas.FromTenantResolveRes(rCtx) 23 | c.JSON(iris.Map{ 24 | "tenantId": tenantInfo.GetId(), 25 | "resolvers": trR.AppliedResolvers, 26 | }) 27 | }) 28 | return r 29 | } 30 | 31 | func TestHostMultiTenancy(t *testing.T) { 32 | e := httptest.New(t, SetUp()) 33 | t1 := e.GET("/").Expect().Status(http.StatusOK) 34 | 35 | var response map[string]interface{} 36 | t1.JSON().Decode(&response) 37 | value, exists := response["tenantId"] 38 | assert.True(t, exists) 39 | assert.Equal(t, "", value) 40 | 41 | } 42 | func TestNotFoundMultiTenancy(t *testing.T) { 43 | 44 | e := httptest.New(t, SetUp()) 45 | e.GET("/").WithHeader("__tenant", "1000").Expect().Status(http.StatusNotFound) 46 | } 47 | 48 | func TestCookieMultiTenancy(t *testing.T) { 49 | 50 | e := httptest.New(t, SetUp()) 51 | t1 := e.GET("/").WithCookie("__tenant", "1").Expect().Status(http.StatusOK) 52 | 53 | var response map[string]interface{} 54 | t1.JSON().Decode(&response) 55 | value, exists := response["tenantId"] 56 | assert.True(t, exists) 57 | assert.Equal(t, "1", value) 58 | 59 | } 60 | 61 | func TestHeaderMultiTenancy(t *testing.T) { 62 | 63 | e := httptest.New(t, SetUp()) 64 | t1 := e.GET("/").WithHeader("__tenant", "1").Expect().Status(http.StatusOK) 65 | 66 | var response map[string]interface{} 67 | t1.JSON().Decode(&response) 68 | value, exists := response["tenantId"] 69 | assert.True(t, exists) 70 | assert.Equal(t, "1", value) 71 | 72 | } 73 | -------------------------------------------------------------------------------- /iris/readme.md: -------------------------------------------------------------------------------- 1 | ### [iris](https://github.com/kataras/iris) adapter -------------------------------------------------------------------------------- /kratos/header_tenant_resolve_contributor.go: -------------------------------------------------------------------------------- 1 | package kratos 2 | 3 | import ( 4 | "github.com/go-kratos/kratos/v2/transport" 5 | "github.com/go-saas/saas" 6 | ) 7 | 8 |
type HeaderTenantResolveContrib struct { 9 | key string 10 | transporter transport.Transporter 11 | } 12 | 13 | func NewHeaderTenantResolveContrib(key string, transporter transport.Transporter) *HeaderTenantResolveContrib { 14 | return &HeaderTenantResolveContrib{ 15 | key: key, 16 | transporter: transporter, 17 | } 18 | } 19 | func (h *HeaderTenantResolveContrib) Name() string { 20 | return "KratosHeader" 21 | } 22 | 23 | func (h *HeaderTenantResolveContrib) Resolve(ctx *saas.Context) error { 24 | v := h.transporter.RequestHeader().Get(h.key) 25 | if v == "" { 26 | return nil 27 | } 28 | ctx.TenantIdOrName = v 29 | return nil 30 | } 31 | -------------------------------------------------------------------------------- /kratos/multi_tenantcy.go: -------------------------------------------------------------------------------- 1 | package kratos 2 | 3 | import ( 4 | "context" 5 | "github.com/go-kratos/kratos/v2/errors" 6 | "github.com/go-kratos/kratos/v2/middleware" 7 | "github.com/go-kratos/kratos/v2/transport" 8 | "github.com/go-kratos/kratos/v2/transport/http" 9 | "github.com/go-saas/saas" 10 | "github.com/go-saas/saas/data" 11 | shttp "github.com/go-saas/saas/http" 12 | ) 13 | 14 | type ErrorFormatter func(err error) (interface{}, error) 15 | 16 | var ( 17 | DefaultErrorFormatter ErrorFormatter = func(err error) (interface{}, error) { 18 | //not found 19 | if errors.Is(err, saas.ErrTenantNotFound) { 20 | return nil, errors.NotFound("TENANT", err.Error()) 21 | } 22 | return nil, err 23 | } 24 | ) 25 | 26 | type option struct { 27 | hmtOpt *shttp.WebMultiTenancyOption 28 | ef ErrorFormatter 29 | resolve []saas.ResolveOption 30 | } 31 | 32 | type Option func(*option) 33 | 34 | func WithMultiTenancyOption(opt *shttp.WebMultiTenancyOption) Option { 35 | return func(o *option) { 36 | o.hmtOpt = opt 37 | } 38 | } 39 | 40 | func WithErrorFormatter(e ErrorFormatter) Option { 41 | return func(o *option) { 42 | o.ef = e 43 | } 44 | } 45 | 46 | func WithResolveOption(opt
...saas.ResolveOption) Option { 47 | return func(o *option) { 48 | o.resolve = opt 49 | } 50 | } 51 | 52 | func Server(ts saas.TenantStore, options ...Option) middleware.Middleware { 53 | opt := &option{ 54 | hmtOpt: shttp.NewDefaultWebMultiTenancyOption(), 55 | ef: DefaultErrorFormatter, 56 | resolve: nil, 57 | } 58 | for _, o := range options { 59 | o(opt) 60 | } 61 | return func(handler middleware.Handler) middleware.Handler { 62 | return func(ctx context.Context, req interface{}) (reply interface{}, err error) { 63 | var trOpt []saas.ResolveOption 64 | if tr, ok := transport.FromServerContext(ctx); ok { 65 | if ht, ok := tr.(*http.Transport); ok { 66 | r := ht.Request() 67 | df := []saas.TenantResolveContrib{ 68 | shttp.NewCookieTenantResolveContrib(opt.hmtOpt.TenantKey, r), 69 | shttp.NewFormTenantResolveContrib(opt.hmtOpt.TenantKey, r), 70 | shttp.NewHeaderTenantResolveContrib(opt.hmtOpt.TenantKey, r), 71 | shttp.NewQueryTenantResolveContrib(opt.hmtOpt.TenantKey, r), 72 | } 73 | if opt.hmtOpt.DomainFormat != "" { 74 | df = append(df, shttp.NewDomainTenantResolveContrib(opt.hmtOpt.DomainFormat, r)) 75 | } 76 | df = append(df, saas.NewTenantNormalizerContrib(ts)) 77 | trOpt = append(trOpt, saas.AppendContribs(df...)) 78 | } else { 79 | trOpt = append(trOpt, saas.AppendContribs(NewHeaderTenantResolveContrib(opt.hmtOpt.TenantKey, tr))) 80 | } 81 | trOpt = append(trOpt, opt.resolve...)
82 | 83 | //get tenant config 84 | tenantConfigProvider := saas.NewDefaultTenantConfigProvider(saas.NewDefaultTenantResolver(trOpt...), ts) 85 | tenantConfig, ctx, err := tenantConfigProvider.Get(ctx) 86 | if err != nil { 87 | return opt.ef(err) 88 | } 89 | newContext := saas.NewCurrentTenant(ctx, tenantConfig.ID, tenantConfig.Name) 90 | //data filter 91 | dataFilterCtx := data.NewMultiTenancyDataFilter(newContext) 92 | return handler(dataFilterCtx, req) 93 | } 94 | return handler(ctx, req) 95 | } 96 | } 97 | } 98 | 99 | func Client(hmtOpt *shttp.WebMultiTenancyOption) middleware.Middleware { 100 | return func(handler middleware.Handler) middleware.Handler { 101 | return func(ctx context.Context, req interface{}) (reply interface{}, err error) { 102 | ti, _ := saas.FromCurrentTenant(ctx) 103 | if tr, ok := transport.FromClientContext(ctx); ok { 104 | if tr.Kind() == transport.KindHTTP { 105 | if ht, ok := tr.(*http.Transport); ok { 106 | ht.RequestHeader().Set(hmtOpt.TenantKey, ti.GetId()) 107 | } 108 | } else if tr.Kind() == transport.KindGRPC { 109 | tr.RequestHeader().Set(hmtOpt.TenantKey, ti.GetId()) // send the id like the HTTP branch; the name may be empty (e.g. NewCurrentTenant(ctx, id, "")) 110 | } 111 | } 112 | return handler(ctx, req) 113 | } 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /multi_tenancy_conn_str_resolver.go: -------------------------------------------------------------------------------- 1 | package saas 2 | 3 | import ( 4 | "context" 5 | "github.com/go-saas/saas/data" 6 | ) 7 | 8 | type MultiTenancyConnStrResolver struct { 9 | //use creator to prevent circular dependency 10 | ts TenantStore 11 | fallback data.ConnStrResolver 12 | } 13 | 14 | var _ data.ConnStrResolver = (*MultiTenancyConnStrResolver)(nil) 15 | 16 | // NewMultiTenancyConnStrResolver resolves connection strings from the current tenant's config, using fallback when the tenant has none 17 | func NewMultiTenancyConnStrResolver(ts TenantStore, fallback data.ConnStrResolver) *MultiTenancyConnStrResolver { 18 | return &MultiTenancyConnStrResolver{ 19 | ts: ts, 20 | fallback: fallback, 21 | } 22 | } 23 |
24 | func (m *MultiTenancyConnStrResolver) Resolve(ctx context.Context, key string) (string, error) { 25 | tenantInfo, _ := FromCurrentTenant(ctx) 26 | id := tenantInfo.GetId() 27 | if len(id) == 0 { 28 | //skip query tenant store 29 | return m.fallback.Resolve(ctx, key) 30 | } 31 | 32 | var tenantConfig *TenantConfig 33 | //read from cache 34 | if tenant, ok := FromTenantConfigContext(ctx, id); ok { 35 | tenantConfig = tenant 36 | } else { 37 | tenant, err := m.ts.GetByNameOrId(ctx, id) 38 | if err != nil { 39 | return "", err 40 | } 41 | tenantConfig = tenant 42 | } 43 | 44 | if tenantConfig.Conn == nil { 45 | //not found 46 | return m.fallback.Resolve(ctx, key) 47 | } 48 | 49 | //get key 50 | ret, err := tenantConfig.Conn.Resolve(ctx, key) 51 | if err != nil { 52 | return "", err 53 | } 54 | if ret != "" { 55 | return ret, nil 56 | } 57 | //still not found 58 | return m.fallback.Resolve(ctx, key) 59 | } 60 | -------------------------------------------------------------------------------- /multi_tenancy_option.go: -------------------------------------------------------------------------------- 1 | package saas 2 | 3 | type DatabaseStyleType int32 4 | 5 | const ( 6 | Single DatabaseStyleType = 1 << 0 7 | PerTenant DatabaseStyleType = 1 << 1 8 | Multi DatabaseStyleType = 1 << 2 9 | ) 10 | 11 | type MultiTenancyOption struct { 12 | IsEnabled bool 13 | DatabaseStyle DatabaseStyleType 14 | } 15 | 16 | type option func(tenancyOption *MultiTenancyOption) 17 | 18 | // WithEnabled sets whether multi-tenancy is enabled. 19 | func WithEnabled(isEnabled bool) option { 20 | return func(tenancyOption *MultiTenancyOption) { 21 | tenancyOption.IsEnabled = isEnabled 22 | } 23 | } 24 | 25 | // WithDatabaseStyle sets the database style. 26 | // Supported styles are Single, PerTenant and Multi. 27 | func WithDatabaseStyle(databaseStyle DatabaseStyleType) option { 28 | return func(tenancyOption *MultiTenancyOption) { 29 | tenancyOption.DatabaseStyle = databaseStyle 30 | } 31 | } 32 | 33 | func NewMultiTenancyOption(opts ...option)
*MultiTenancyOption { 34 | o := MultiTenancyOption{} 35 | for _, opt := range opts { 36 | opt(&o) 37 | } 38 | return &o 39 | } 40 | 41 | func DefaultMultiTenancyOption() *MultiTenancyOption { 42 | return NewMultiTenancyOption(WithEnabled(true), WithDatabaseStyle(Multi)) 43 | } 44 | -------------------------------------------------------------------------------- /multi_tenancy_side.go: -------------------------------------------------------------------------------- 1 | package saas 2 | 3 | type MultiTenancySide int32 4 | 5 | const ( 6 | Tenant MultiTenancySide = 1 << 0 7 | Host MultiTenancySide = 1 << 1 8 | Both = Tenant | Host 9 | ) 10 | -------------------------------------------------------------------------------- /provider.go: -------------------------------------------------------------------------------- 1 | package saas 2 | 3 | import ( 4 | "context" 5 | "github.com/go-saas/saas/data" 6 | ) 7 | 8 | type ( 9 | //ClientProvider resolve by dsn string (connection string) 10 | ClientProvider[TClient interface{}] interface { 11 | Get(ctx context.Context, dsn string) (TClient, error) 12 | } 13 | 14 | // ClientProviderFunc see ClientProvider 15 | ClientProviderFunc[TClient interface{}] func(ctx context.Context, dsn string) (TClient, error) 16 | 17 | //DbProvider resolve TClient from user friendly key 18 | DbProvider[TClient interface{}] interface { 19 | // Get instance by key 20 | Get(ctx context.Context, key string) TClient 21 | } 22 | 23 | //DefaultDbProvider resolve dsn from user friendly key by data.ConnStrResolver, then resolve TClient from dsn by ClientProvider 24 | DefaultDbProvider[TClient interface{}] struct { 25 | cs data.ConnStrResolver 26 | cp ClientProvider[TClient] 27 | } 28 | ) 29 | 30 | func (c ClientProviderFunc[TClient]) Get(ctx context.Context, dsn string) (TClient, error) { 31 | return c(ctx, dsn) 32 | } 33 | 34 | func NewDbProvider[TClient interface{}](cs data.ConnStrResolver, cp ClientProvider[TClient]) (d
*DefaultDbProvider[TClient]) { 35 | d = &DefaultDbProvider[TClient]{ 36 | cs: cs, 37 | cp: cp, 38 | } 39 | return 40 | } 41 | 42 | func (d *DefaultDbProvider[TClient]) Get(ctx context.Context, key string) TClient { 43 | //resolve connection string 44 | s, err := d.cs.Resolve(ctx, key) 45 | if err != nil { 46 | panic(err) 47 | } 48 | c, err := d.cp.Get(ctx, s) 49 | if err != nil { 50 | panic(err) 51 | } 52 | return c 53 | } 54 | -------------------------------------------------------------------------------- /seed/context.go: -------------------------------------------------------------------------------- 1 | package seed 2 | 3 | type Context struct { 4 | TenantId string 5 | //extra properties 6 | Extra map[string]interface{} 7 | } 8 | 9 | func NewSeedContext(tenantId string, extra map[string]interface{}) *Context { 10 | return &Context{ 11 | TenantId: tenantId, 12 | Extra: extra, 13 | } 14 | } 15 | 16 | func (s *Context) WithExtra(k string, v interface{}) *Context { 17 | s.Extra[k] = v 18 | return s 19 | } 20 | -------------------------------------------------------------------------------- /seed/contrib.go: -------------------------------------------------------------------------------- 1 | package seed 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | type Contrib interface { 8 | Seed(ctx context.Context, sCtx *Context) error 9 | } 10 | 11 | type chainContrib struct { 12 | seeds []Contrib 13 | } 14 | 15 | var _ Contrib = (*chainContrib)(nil) 16 | 17 | func (c *chainContrib) Seed(ctx context.Context, sCtx *Context) error { 18 | for _, seed := range c.seeds { 19 | if err := seed.Seed(ctx, sCtx); err != nil { 20 | return err 21 | } 22 | } 23 | return nil 24 | } 25 | 26 | func Chain(seeds ...Contrib) Contrib { 27 | return &chainContrib{seeds: seeds} 28 | } 29 | -------------------------------------------------------------------------------- /seed/option.go: -------------------------------------------------------------------------------- 1 | package seed 2 | 3 | type
SeedOption struct { 4 | TenantIds []string 5 | Extra map[string]interface{} 6 | } 7 | 8 | func NewOption() *SeedOption { 9 | return &SeedOption{Extra: map[string]interface{}{}} 10 | } 11 | 12 | type Option func(opt *SeedOption) 13 | 14 | func WithTenantId(tenants ...string) Option { 15 | return func(opt *SeedOption) { 16 | opt.TenantIds = tenants 17 | } 18 | } 19 | 20 | func AddHost() Option { 21 | return func(opt *SeedOption) { 22 | opt.TenantIds = append(opt.TenantIds, "") 23 | } 24 | } 25 | func AddTenant(tenants ...string) Option { 26 | return func(opt *SeedOption) { 27 | opt.TenantIds = append(opt.TenantIds, tenants...) 28 | } 29 | } 30 | 31 | func WithExtra(extra map[string]interface{}) Option { 32 | return func(opt *SeedOption) { 33 | opt.Extra = extra 34 | } 35 | } 36 | 37 | func SetExtra(key string, v interface{}) Option { 38 | return func(opt *SeedOption) { 39 | opt.Extra[key] = v 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /seed/seeder.go: -------------------------------------------------------------------------------- 1 | package seed 2 | 3 | import ( 4 | "context" 5 | "github.com/go-saas/saas" 6 | ) 7 | 8 | type Seeder interface { 9 | Seed(ctx context.Context, option ...Option) error 10 | } 11 | 12 | var _ Seeder = (*DefaultSeeder)(nil) 13 | 14 | type DefaultSeeder struct { 15 | contrib []Contrib 16 | } 17 | 18 | func NewDefaultSeeder(contrib ...Contrib) *DefaultSeeder { 19 | return &DefaultSeeder{ 20 | contrib: contrib, 21 | } 22 | } 23 | 24 | func (d *DefaultSeeder) Seed(ctx context.Context, options ...Option) error { 25 | opt := NewOption() 26 | for _, option := range options { 27 | option(opt) 28 | } 29 | for _, tenant := range opt.TenantIds { 30 | // change to next tenant 31 | tenantCtx := saas.NewCurrentTenant(ctx, tenant, "") // derive from the caller's ctx so per-tenant contexts do not nest across iterations 32 | 33 | seedFn := func(ctx context.Context) error { 34 | sCtx := NewSeedContext(tenant, opt.Extra) 35 | //create seeder 36 | for _, contributor := range d.contrib { 37 | if
err := contributor.Seed(ctx, sCtx); err != nil { 38 | return err 39 | } 40 | } 41 | return nil 42 | } 43 | if err := seedFn(tenantCtx); err != nil { 44 | return err 45 | } 46 | } 47 | return nil 48 | } 49 | -------------------------------------------------------------------------------- /tenant_config.go: -------------------------------------------------------------------------------- 1 | package saas 2 | 3 | import "github.com/go-saas/saas/data" 4 | 5 | type TenantConfig struct { 6 | ID string `json:"id"` 7 | Name string `json:"name"` 8 | Region string `json:"region"` 9 | PlanKey string `json:"planKey"` 10 | Conn data.ConnStrings `json:"conn"` 11 | } 12 | 13 | func NewTenantConfig(id, name, region, planKey string) *TenantConfig { 14 | return &TenantConfig{ 15 | ID: id, 16 | Name: name, 17 | Region: region, 18 | PlanKey: planKey, 19 | Conn: make(data.ConnStrings), 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /tenant_config_provider.go: -------------------------------------------------------------------------------- 1 | package saas 2 | 3 | import "context" 4 | 5 | // TenantConfigProvider resolve tenant config from current context 6 | type TenantConfigProvider interface { 7 | // Get tenant config 8 | Get(ctx context.Context) (TenantConfig, context.Context, error) 9 | } 10 | 11 | type DefaultTenantConfigProvider struct { 12 | tr TenantResolver 13 | ts TenantStore 14 | } 15 | 16 | func NewDefaultTenantConfigProvider(tr TenantResolver, ts TenantStore) TenantConfigProvider { 17 | return &DefaultTenantConfigProvider{ 18 | tr: tr, 19 | ts: ts, 20 | } 21 | } 22 | 23 | // Get read from context FromTenantConfigContext first, fallback with TenantStore and return new context with cached value 24 | func (d *DefaultTenantConfigProvider) Get(ctx context.Context) (TenantConfig, context.Context, error) { 25 | rr, ctx, err := d.tr.Resolve(ctx) 26 | if err != nil { 27 | return TenantConfig{}, ctx, err 28 | } 29 | if rr.TenantIdOrName
!= "" { 30 | //tenant side 31 | 32 | //read from cache 33 | if cfg, ok := FromTenantConfigContext(ctx, rr.TenantIdOrName); ok { 34 | return *cfg, ctx, nil 35 | } 36 | //get config from tenant store 37 | cfg, err := d.ts.GetByNameOrId(ctx, rr.TenantIdOrName) 38 | if err != nil { 39 | return TenantConfig{}, ctx, err 40 | } 41 | return *cfg, NewTenantConfigContext(ctx, cfg.ID, cfg), nil 42 | } 43 | // host side 44 | return TenantConfig{}, ctx, nil 45 | 46 | } 47 | -------------------------------------------------------------------------------- /tenant_config_test.go: -------------------------------------------------------------------------------- 1 | package saas 2 | 3 | import ( 4 | "encoding/json" 5 | "github.com/stretchr/testify/assert" 6 | "testing" 7 | ) 8 | 9 | func TestTenantConfig(t *testing.T) { 10 | in := []byte(`{"id":"1","name":"1","region":"1","planKey":"","conn":{"a":"a","b":"b"}}`) 11 | conf := &TenantConfig{} 12 | err := json.Unmarshal(in, conf) 13 | assert.NoError(t, err) 14 | 15 | s, err := json.Marshal(conf) 16 | assert.NoError(t, err) 17 | assert.Equal(t, string(in), string(s)) 18 | } 19 | -------------------------------------------------------------------------------- /tenant_resolve_context.go: -------------------------------------------------------------------------------- 1 | package saas 2 | 3 | import "context" 4 | 5 | type Context struct { 6 | context context.Context 7 | TenantIdOrName string 8 | // HasHandled field to handle host side unresolved or resolved 9 | HasHandled bool 10 | } 11 | 12 | func NewTenantResolveContext(ctx context.Context) *Context { 13 | return &Context{ 14 | context: ctx, 15 | } 16 | } 17 | 18 | func (t *Context) HasResolved() bool { 19 | return t.HasHandled 20 | } 21 | 22 | func (t *Context) Context() context.Context { 23 | return t.context 24 | } 25 | 26 | func (t *Context) WithContext(ctx context.Context) { 27 | t.context = ctx 28 | } 29 |
-------------------------------------------------------------------------------- /tenant_resolve_contrib.go: -------------------------------------------------------------------------------- 1 | package saas 2 | 3 | type TenantResolveContrib interface { 4 | // Name of resolver 5 | Name() string 6 | // Resolve tenant 7 | Resolve(ctx *Context) error 8 | } 9 | 10 | //TenantNormalizerContrib normalize tenant id or name into tenant id 11 | type TenantNormalizerContrib struct { 12 | ts TenantStore 13 | } 14 | 15 | var _ TenantResolveContrib = (*TenantNormalizerContrib)(nil) 16 | 17 | func NewTenantNormalizerContrib(ts TenantStore) *TenantNormalizerContrib { 18 | return &TenantNormalizerContrib{ 19 | ts: ts, 20 | } 21 | } 22 | func (t *TenantNormalizerContrib) Name() string { 23 | return "TenantNormalizer" 24 | } 25 | 26 | func (t *TenantNormalizerContrib) Resolve(ctx *Context) error { 27 | if len(ctx.TenantIdOrName) > 0 { 28 | tenant, err := t.ts.GetByNameOrId(ctx.Context(), ctx.TenantIdOrName) 29 | if err != nil { 30 | return err 31 | } 32 | ctx.TenantIdOrName = tenant.ID 33 | //store for cache 34 | ctx.WithContext(NewTenantConfigContext(ctx.Context(), tenant.ID, tenant)) 35 | } 36 | return nil 37 | } 38 | 39 | // ContextContrib resolve from current context 40 | type ContextContrib struct { 41 | } 42 | 43 | var _ TenantResolveContrib = (*ContextContrib)(nil) 44 | 45 | func (c *ContextContrib) Name() string { 46 | return "ContextContrib" 47 | } 48 | 49 | func (c *ContextContrib) Resolve(ctx *Context) error { 50 | info, ok := FromCurrentTenant(ctx.Context()) 51 | if ok { 52 | ctx.TenantIdOrName = info.GetId() 53 | //terminate 54 | ctx.HasHandled = true 55 | } 56 | return nil 57 | } 58 | -------------------------------------------------------------------------------- /tenant_resolve_option.go: -------------------------------------------------------------------------------- 1 | package saas 2 | 3 | type TenantResolveOption struct { 4 | Resolvers []TenantResolveContrib 5 | } 
6 | 7 | type ResolveOption func(resolveOption *TenantResolveOption) 8 | 9 | func AppendContribs(c ...TenantResolveContrib) ResolveOption { 10 | return func(resolveOption *TenantResolveOption) { 11 | resolveOption.AppendContribs(c...) 12 | } 13 | } 14 | 15 | func RemoveContribs(c ...TenantResolveContrib) ResolveOption { 16 | return func(resolveOption *TenantResolveOption) { 17 | resolveOption.RemoveContribs(c...) 18 | } 19 | } 20 | 21 | func NewTenantResolveOption(c ...TenantResolveContrib) *TenantResolveOption { 22 | return &TenantResolveOption{ 23 | Resolvers: c, 24 | } 25 | } 26 | 27 | func (opt *TenantResolveOption) AppendContribs(c ...TenantResolveContrib) { 28 | opt.Resolvers = append(opt.Resolvers, c...) 29 | } 30 | 31 | func (opt *TenantResolveOption) RemoveContribs(c ...TenantResolveContrib) { 32 | var r []TenantResolveContrib 33 | for _, resolver := range opt.Resolvers { 34 | if !contains(c, resolver) { 35 | r = append(r, resolver) 36 | } 37 | } 38 | opt.Resolvers = r 39 | } 40 | 41 | func contains(a []TenantResolveContrib, b TenantResolveContrib) bool { 42 | for i := 0; i < len(a); i++ { 43 | if a[i] == b { 44 | return true 45 | } 46 | } 47 | return false 48 | } 49 | -------------------------------------------------------------------------------- /tenant_resolve_result.go: -------------------------------------------------------------------------------- 1 | package saas 2 | 3 | type TenantResolveResult struct { 4 | TenantIdOrName string 5 | AppliedResolvers []string 6 | } 7 | -------------------------------------------------------------------------------- /tenant_resolver.go: -------------------------------------------------------------------------------- 1 | package saas 2 | 3 | import "context" 4 | 5 | type TenantResolver interface { 6 | Resolve(ctx context.Context) (TenantResolveResult, context.Context, error) 7 | } 8 | 9 | type DefaultTenantResolver struct { 10 | //options 11 | o *TenantResolveOption 12 | } 13 | 14 | func NewDefaultTenantResolver(opt 
...ResolveOption) TenantResolver { 15 | o := NewTenantResolveOption(&ContextContrib{}) 16 | for _, resolveOption := range opt { 17 | resolveOption(o) 18 | } 19 | return &DefaultTenantResolver{ 20 | o: o, 21 | } 22 | } 23 | 24 | func (d *DefaultTenantResolver) Resolve(ctx context.Context) (TenantResolveResult, context.Context, error) { 25 | res := TenantResolveResult{} 26 | trCtx := NewTenantResolveContext(ctx) 27 | for _, resolver := range d.o.Resolvers { 28 | if err := resolver.Resolve(trCtx); err != nil { 29 | return res, trCtx.Context(), err 30 | } 31 | res.AppliedResolvers = append(res.AppliedResolvers, resolver.Name()) 32 | if trCtx.HasResolved() { 33 | break 34 | } 35 | } 36 | res.TenantIdOrName = trCtx.TenantIdOrName 37 | ctx = NewTenantResolveRes(trCtx.Context(), &res) 38 | return res, ctx, nil 39 | } 40 | -------------------------------------------------------------------------------- /tenant_store.go: -------------------------------------------------------------------------------- 1 | package saas 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | ) 7 | 8 | var ( 9 | ErrTenantNotFound = errors.New("tenant not found") 10 | ) 11 | 12 | type TenantStore interface { 13 | // GetByNameOrId return nil and ErrTenantNotFound if tenant not found 14 | GetByNameOrId(ctx context.Context, nameOrId string) (*TenantConfig, error) 15 | } 16 | 17 | type MemoryTenantStore struct { 18 | TenantConfig []TenantConfig 19 | } 20 | 21 | var _ TenantStore = (*MemoryTenantStore)(nil) 22 | 23 | func NewMemoryTenantStore(t []TenantConfig) *MemoryTenantStore { 24 | return &MemoryTenantStore{ 25 | TenantConfig: t, 26 | } 27 | } 28 | 29 | func (m *MemoryTenantStore) GetByNameOrId(_ context.Context, nameOrId string) (*TenantConfig, error) { 30 | for _, config := range m.TenantConfig { 31 | if config.ID == nameOrId || config.Name == nameOrId { 32 | return &config, nil 33 | } 34 | } 35 | return nil, ErrTenantNotFound 36 | } 37 | 
-------------------------------------------------------------------------------- /utils.go: -------------------------------------------------------------------------------- 1 | package saas 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | func GetMultiTenantSide(ctx context.Context) MultiTenancySide { 8 | tenantInfo, _ := FromCurrentTenant(ctx) 9 | if tenantInfo.GetId() == "" { 10 | return Host 11 | } else { 12 | return Tenant 13 | } 14 | } 15 | --------------------------------------------------------------------------------