├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── README_EN.md ├── auth2 ├── claims.go ├── claims_test.go ├── example │ └── main.go ├── extractor.go ├── interface.go ├── jwt.go ├── jwt_test.go ├── local.go ├── local_test.go ├── redis.go ├── redis_test.go ├── token.go ├── token_test.go └── verifier.go ├── conf ├── auth.go ├── auth_test.go ├── config.go ├── config_test.go ├── cros.go ├── mysql.go ├── mysql_test.go ├── operate.go ├── viper.go └── viper_test.go ├── e └── error.go ├── example ├── main.go ├── public │ └── index.html └── readme.md ├── go.mod ├── go.sum ├── httptest ├── base.go ├── base_test.go ├── common.go ├── printer.go ├── reporter.go ├── respose.go └── respose_test.go ├── loadtls.go ├── menu.go ├── migrate.go ├── migrate_test.go ├── model.go ├── router.go ├── run.go ├── run_darwin.go ├── run_linux.go ├── run_windows.go ├── server.go ├── server_test.go └── validate.go /.gitignore: -------------------------------------------------------------------------------- 1 | # ext 2 | *.dll 3 | *.db 4 | *.exe 5 | *.log 6 | *.out 7 | *.txt 8 | *.yaml 9 | *.jpg 10 | *.json 11 | 12 | # file 13 | rbac_model.conf 14 | 15 | 16 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.18.x 5 | 6 | before_script: 7 | - sudo redis-server /etc/redis/redis.conf --port 6379 --requirepass 'secret' 8 | - mkdir -p data/db 9 | - mongod --dbpath=data/db & 10 | - sleep 5 11 | - mongo mongo_test --eval 'db.createUser({user:"travis",pwd:"test",roles:["readWrite"]});' 12 | 13 | services: 14 | - mysql 15 | 16 | env: 17 | - GO111MODULE=on redisPwd=secret mongoAddr='travis:test@127.0.0.1:27017/mongo_test' 18 | 19 | before_install: 20 | - go get -v -t ./... 
21 | 22 | script: 23 | - go test -v -race -coverprofile='coverage.txt' -covermode=atomic github.com/snowlyg/iris-admin/g github.com/snowlyg/iris-admin/migration github.com/snowlyg/iris-admin/seed github.com/snowlyg/iris-admin/... 24 | 25 | after_success: 26 | - bash <(curl -s https://codecov.io/bash) 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

IrisAdmin

2 | 3 | [![Build Status](https://app.travis-ci.com/snowlyg/iris-admin.svg?branch=master)](https://app.travis-ci.com/snowlyg/iris-admin) 4 | [![LICENSE](https://img.shields.io/github/license/snowlyg/iris-admin)](https://github.com/snowlyg/iris-admin/blob/master/LICENSE) 5 | [![go doc](https://godoc.org/github.com/snowlyg/iris-admin?status.svg)](https://godoc.org/github.com/snowlyg/iris-admin) 6 | [![go report](https://goreportcard.com/badge/github.com/snowlyg/iris-admin)](https://goreportcard.com/badge/github.com/snowlyg/iris-admin) 7 | [![Build Status](https://codecov.io/gh/snowlyg/iris-admin/branch/master/graph/badge.svg)](https://codecov.io/gh/snowlyg/iris-admin) 8 | 9 | 简体中文 | [English](./README_EN.md) 10 | 11 | #### 项目地址 12 | 13 | [GITHUB](https://github.com/snowlyg/iris-admin) 14 | 15 | > 简单项目仅供学习,欢迎指点! 16 | 17 | #### 相关文档 18 | 19 | - [IRIS-ADMIN-DOC](https://doc.snowlyg.com) 20 | - [IRIS V12 中文文档](https://github.com/snowlyg/iris/wiki) 21 | - [godoc](https://pkg.go.dev/github.com/snowlyg/iris-admin?utm_source=godoc) 22 | 23 | e9939a7e92f32337871feb22e06bd05a.jpeg 24 | e9939a7e92f32337871feb22e06bd05a.jpeg 25 | 26 | #### iris 学习记录分享 27 | 28 | - [Iris-go 项目登陆 API 构建细节实现过程](https://snowlyg.github.io/iris-go-api-1/) 29 | 30 | - [iris + casbin 从陌生到学会使用的过程](https://snowlyg.github.io/iris-go-api-2/) 31 | 32 | --- 33 | 34 | #### 简单使用 35 | 36 | - 获取依赖包,注意必须带上 `master` 版本 37 | 38 | ```sh 39 | go get github.com/snowlyg/iris-admin@master 40 | ``` 41 | 42 | #### 打赏 43 | 44 | > 您的打赏将用于支付网站运行,会在项目介绍中特别鸣谢您 45 | - [爱发电](https://afdian.net/@snowlyg/plan) 46 | - [donating](https://paypal.me/snowlyg?country.x=C2&locale.x=zh_XC) 47 | -------------------------------------------------------------------------------- /README_EN.md: -------------------------------------------------------------------------------- 1 |

IrisAdmin

2 | 3 | [![Build Status](https://app.travis-ci.com/snowlyg/iris-admin.svg?branch=master)](https://app.travis-ci.com/snowlyg/iris-admin) 4 | [![LICENSE](https://img.shields.io/github/license/snowlyg/iris-admin)](https://github.com/snowlyg/iris-admin/blob/master/LICENSE) 5 | [![go doc](https://godoc.org/github.com/snowlyg/iris-admin?status.svg)](https://godoc.org/github.com/snowlyg/iris-admin) 6 | [![go report](https://goreportcard.com/badge/github.com/snowlyg/iris-admin)](https://goreportcard.com/badge/github.com/snowlyg/iris-admin) 7 | [![Coverage](https://codecov.io/gh/snowlyg/iris-admin/branch/master/graph/badge.svg)](https://codecov.io/gh/snowlyg/iris-admin) 8 | 9 | [简体中文](./README.md) | English 10 | 11 | #### Project URL 12 | 13 | [GITHUB](https://github.com/snowlyg/iris-admin) | [GITEE](https://gitee.com/snowlyg/iris-admin) 14 | 15 | --- 16 | 17 | > This project is just for learning Go; suggestions are welcome! 18 | 19 | #### Documentation 20 | 21 | - [IRIS-ADMIN-DOC](https://doc.snowlyg.com) 22 | - [IRIS V12 documentation (Chinese)](https://github.com/snowlyg/iris/wiki) 23 | - [godoc](https://pkg.go.dev/github.com/snowlyg/iris-admin?utm_source=godoc) 24 | 25 | [![Gitter](https://badges.gitter.im/iris-go-tenancy/community.svg)](https://gitter.im/iris-go-tenancy/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) [![Join the chat at https://gitter.im/iris-go-tenancy/iris-admin](https://badges.gitter.im/iris-go-tenancy/iris-admin.svg)](https://gitter.im/iris-go-tenancy/iris-admin?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) 26 | 27 | #### BLOG 28 | 29 | - [REST API with iris-go web framework](https://snowlyg.github.io/iris-go-api-1/) 30 | 31 | - [How to use iris-go with casbin](https://snowlyg.github.io/iris-go-api-2/) 32 | 33 | --- 34 | 35 | #### Getting started 36 | 37 | - Get the package; note that you must use the `master` version.
38 | 39 | ```sh 40 | go get github.com/snowlyg/iris-admin@master 41 | ``` 42 | 43 | 44 | ## ☕️ Buy me a coffee 45 | 46 | > Please be sure to leave your name, GitHub account or other social media accounts when you donate by the following means so that I can add it to the list of donors as a token of my appreciation. 47 | - [爱发电](https://afdian.net/@snowlyg/plan) 48 | - [donating](https://paypal.me/snowlyg?country.x=C2&locale.x=zh_XC) 49 | -------------------------------------------------------------------------------- /auth2/claims.go: -------------------------------------------------------------------------------- 1 | package auth2 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "strings" 7 | "time" 8 | 9 | "github.com/golang-jwt/jwt" 10 | ) 11 | 12 | const ( 13 | ValidationMalformed uint32 = 1 << iota // Token is malformed 14 | ValidationUnverifiable // Token could not be verified because of signing problems 15 | ValidationSignatureInvalid // Signature validation failed 16 | ValidationExpired // EXP validation failed 17 | ValidationId 18 | ValidationUsername 19 | ValidationAuthId 20 | ValidationRoleType 21 | ValidationLoginType 22 | ValidationAuthType 23 | ) 24 | 25 | type Claims struct { 26 | Id string `json:"id,omitempty" redis:"id"` 27 | SuperAdmin bool `json:"superAdmin,omitempty" redis:"super_admin"` 28 | Username string `json:"username,omitempty" redis:"username"` 29 | AuthId string `json:"authId,omitempty" redis:"auth_id"` 30 | RoleType int `json:"roleType,omitempty" redis:"role_type"` 31 | LoginType int `json:"loginType,omitempty" redis:"login_type"` 32 | AuthType int `json:"authType,omitempty" redis:"auth_type"` 33 | CreationTime int64 `json:"creationData,omitempty" redis:"creation_data"` // NOTE(review): serialized as "creationData"/"creation_data", not "creationTime"; keep for wire compatibility 34 | ExpiresAt int64 `json:"expiresAt,omitempty" redis:"expires_at"` 35 | } 36 | 37 | func (c *Claims) roleType() RoleType { 38 | return RoleType(c.RoleType) 39 | } 40 | 41 | func (c *Claims) loginType() LoginType { 42 | return LoginType(c.LoginType) 43 | } 44 | 45 | func
(c *Claims) authType() AuthType { 46 | return AuthType(c.AuthType) 47 | } 48 | 49 | func (c *Claims) setRoleType(roleType int) { 50 | c.RoleType = roleType 51 | } 52 | 53 | func (c *Claims) setLoginType(loginType int) { 54 | c.LoginType = loginType 55 | } 56 | 57 | func (c *Claims) setAuthType(authType int) { 58 | c.AuthType = authType 59 | } 60 | 61 | func NewClaims(m *Agent) *Claims { // NewClaims copies the Agent fields; AuthIds are joined with "-" and CreationTime is always the current time 62 | claims := &Claims{ 63 | Id: strconv.FormatUint(uint64(m.Id), 10), 64 | SuperAdmin: m.SuperAdmin, 65 | Username: m.Username, 66 | AuthId: strings.Join(m.AuthIds, "-"), 67 | RoleType: int(m.RoleType), 68 | LoginType: int(m.LoginType), 69 | AuthType: int(m.AuthType), 70 | CreationTime: time.Now().Local().Unix(), 71 | ExpiresAt: m.ExpiresAt, 72 | } 73 | return claims 74 | } 75 | 76 | func (c *Claims) Valid() error { 77 | vErr := new(jwt.ValidationError) // Errors gets one bit per failed check; Inner is overwritten by each failure, so it keeps only the last message 78 | now := time.Now().Unix() 79 | // Only exp is optional (req=false): an unset ExpiresAt (0) passes. Id, Username, AuthId and 80 | // RoleType must be set; LoginType and AuthType must fall inside their enum ranges.
81 | if !c.VerifyExpiresAt(now, false) { 82 | delta := time.Unix(now, 0).Sub(time.Unix(c.ExpiresAt, 0)) 83 | vErr.Inner = fmt.Errorf("claims:token is expired by %v", delta) 84 | vErr.Errors |= ValidationExpired 85 | } 86 | if !c.VerifyId() { 87 | vErr.Inner = fmt.Errorf("claims:id[%s] is empty", c.Id) 88 | vErr.Errors |= ValidationId 89 | } 90 | if !c.VerifyUsername() { 91 | vErr.Inner = fmt.Errorf("claims:username[%s] is empty", c.Username) 92 | vErr.Errors |= ValidationUsername 93 | } 94 | if !c.VerifyAuthId() { 95 | vErr.Inner = fmt.Errorf("claims:authId[%s] is empty", c.AuthId) 96 | vErr.Errors |= ValidationAuthId 97 | } 98 | if !c.VerifyType() { 99 | vErr.Inner = fmt.Errorf("claims:roleType[%d] is invalid", c.RoleType) 100 | vErr.Errors |= ValidationRoleType 101 | } 102 | if !c.VerifyLoginType() { 103 | vErr.Inner = fmt.Errorf("claims:loginType[%d] is invalid", c.LoginType) 104 | vErr.Errors |= ValidationLoginType 105 | } 106 | if !c.VerifyAuthType() { 107 | vErr.Inner = fmt.Errorf("claims:authType[%d] is invalid", c.AuthType) 108 | vErr.Errors |= ValidationAuthType 109 | } 110 | if !valid(vErr) { 111 | return vErr 112 | } 113 | 114 | return nil 115 | } 116 | 117 | // valid reports whether no validation bit has been set on e. 118 | func valid(e *jwt.ValidationError) bool { 119 | return e.Errors == 0 120 | } 121 | 122 | // VerifyExpiresAt reports whether cmp (a Unix time) is not after the exp claim.
123 | // If req is false an unset (zero) exp also passes; if req is true an unset exp fails. 124 | func (c *Claims) VerifyExpiresAt(cmp int64, req bool) bool { 125 | return verifyExp(c.ExpiresAt, cmp, req) 126 | } 127 | 128 | func verifyExp(exp int64, now int64, required bool) bool { 129 | if exp == 0 { 130 | return !required 131 | } 132 | return now <= exp // inclusive: still valid during the expiry second itself 133 | } 134 | 135 | func (c *Claims) VerifyId() bool { // Id must parse with strconv.Atoi to a value > 0 136 | if id, err := strconv.Atoi(c.Id); err != nil { 137 | return false 138 | } else if id > 0 { 139 | return true 140 | } 141 | return false 142 | } 143 | 144 | func (c *Claims) VerifyUsername() bool { 145 | return c.Username != "" 146 | } 147 | 148 | func (c *Claims) VerifyAuthId() bool { 149 | return c.AuthId != "" 150 | } 151 | 152 | func (c *Claims) VerifyType() bool { // checks RoleType only; RoleNone (0) and negative values fail 153 | return c.RoleType > 0 154 | } 155 | 156 | func (c *Claims) VerifyLoginType() bool { 157 | return LoginType(c.LoginType) >= LoginTypeWeb && LoginType(c.LoginType) <= LoginTypeDevice 158 | } 159 | 160 | func (c *Claims) VerifyAuthType() bool { 161 | return AuthType(c.AuthType) >= NoAuth && AuthType(c.AuthType) <= AuthThirdParty 162 | } 163 | -------------------------------------------------------------------------------- /auth2/claims_test.go: -------------------------------------------------------------------------------- 1 | package auth2 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | "time" 7 | 8 | "github.com/golang-jwt/jwt" 9 | ) 10 | 11 | var testAgent = &Agent{ 12 | Id: uint(8457585), 13 | SuperAdmin: true, 14 | Username: "username", 15 | AuthIds: []string{"999"}, 16 | RoleType: RoleAdmin, 17 | LoginType: LoginTypeWeb, 18 | AuthType: AuthPwd, 19 | ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(), // evaluated once, at package initialisation 20 | } 21 | 22 | func TestNewClaims(t *testing.T) { 23 | cla := NewClaims(testAgent) 24 | if cla == nil { 25 | t.Fatal("claims init return is nil") 26 | } 27 | if cla.Id != "8457585" { 28 | t.Error("claims id is not 8457585") 29 | } 30 | if cla.SuperAdmin != true { 31 |
t.Error("claims super admin is not true") 32 | } 33 | if cla.Username != "username" { 34 | t.Error("claims username is not username") 35 | } 36 | if cla.AuthId != "999" { 37 | t.Error("claims auth ids is not 999") 38 | } 39 | if cla.roleType() != RoleAdmin { 40 | t.Error("claims type is not admin") 41 | } 42 | if cla.loginType() != LoginTypeWeb { 43 | t.Error("claims login type is not web") 44 | } 45 | if cla.authType() != AuthPwd { 46 | t.Error("claims auth type is not pwd") 47 | } 48 | if d := time.Now().Local().Add(TimeoutWeb).Unix() - cla.ExpiresAt; d < 0 || d > 5 { // testAgent.ExpiresAt was computed at package init, so an exact-second match is flaky 49 | t.Errorf("claims expires at drifted %ds from now+TimeoutWeb", d) 50 | } 51 | 52 | cla.setAuthType(2) 53 | if cla.authType() != AuthCode { 54 | t.Error("claims authType is not authCode") 55 | } 56 | 57 | cla.setLoginType(1) 58 | if cla.loginType() != LoginTypeApp { 59 | t.Error("claims loginType is not loginTypeApp") 60 | } 61 | cla.setRoleType(2) 62 | if cla.roleType() != RoleTenancy { 63 | t.Error("claims roleType is not roleTenancy") 64 | } 65 | } 66 | 67 | func TestValid(t *testing.T) { 68 | cla := NewClaims(&Agent{Id: uint(8457585), SuperAdmin: true, Username: "username", AuthIds: []string{"999"}, RoleType: RoleAdmin, LoginType: LoginTypeWeb, AuthType: AuthPwd, ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix()}) 69 | if err := cla.Valid(); err != nil { 70 | t.Fatal(err) 71 | } 72 | args := []struct { 73 | agent *Agent 74 | name string 75 | want uint32 76 | }{ 77 | { 78 | name: "ValidationAuthType", 79 | agent: &Agent{Id: uint(8457585), SuperAdmin: true, Username: "username", AuthIds: []string{"999"}, AuthType: 99, RoleType: RoleAdmin, LoginType: LoginTypeWeb, ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix()}, 80 | want: ValidationAuthType, 81 | }, 82 | { 83 | name: "ValidationLoginType", 84 | agent: &Agent{Id: uint(8457585), SuperAdmin: true, Username: "username", AuthIds: []string{"999"}, LoginType: 99, RoleType: RoleAdmin, AuthType: AuthPwd, ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix()}, 85 | want:
ValidationLoginType, 86 | }, 87 | { 88 | name: "ValidationRoleType", 89 | agent: &Agent{Id: uint(8457585), SuperAdmin: true, Username: "username", AuthIds: []string{"999"}, LoginType: LoginTypeApp, RoleType: -1, AuthType: AuthPwd, ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix()}, 90 | want: ValidationRoleType, 91 | }, 92 | { 93 | name: "ValidationAuthId", 94 | agent: &Agent{Id: uint(8457585), SuperAdmin: true, Username: "username", AuthIds: []string{""}, LoginType: LoginTypeApp, RoleType: RoleAdmin, AuthType: AuthPwd, ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix()}, 95 | want: ValidationAuthId, 96 | }, 97 | { 98 | name: "ValidationUsername", 99 | agent: &Agent{Id: uint(8457585), SuperAdmin: true, Username: "", AuthIds: []string{"1"}, LoginType: LoginTypeApp, RoleType: RoleAdmin, AuthType: AuthPwd, ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix()}, 100 | want: ValidationUsername, 101 | }, 102 | { 103 | name: "ValidationId", 104 | agent: &Agent{Id: 0, SuperAdmin: true, Username: "username", AuthIds: []string{"1"}, LoginType: LoginTypeApp, RoleType: RoleAdmin, AuthType: AuthPwd, ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix()}, 105 | want: ValidationId, 106 | }, 107 | } 108 | 109 | for _, arg := range args { 110 | t.Run(arg.name, func(t *testing.T) { 111 | cla = NewClaims(arg.agent) 112 | if err := cla.Valid(); err == nil { 113 | t.Fatal("error is nil") 114 | } else { 115 | 116 | if v, ok := err.(*jwt.ValidationError); !ok { 117 | t.Fatalf("%s %s", reflect.TypeOf(err).String(), err.Error()) 118 | } else if v.Errors != arg.want { 119 | t.Fatalf("%d %d %s", v.Errors, arg.want, v.Error()) 120 | } 121 | } 122 | }) 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /auth2/example/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "time" 7 | 8 | "github.com/gin-gonic/gin" 9 |
"github.com/go-redis/redis/v8" 10 | "github.com/snowlyg/iris-admin/auth2" 11 | ) 12 | 13 | func init() { 14 | options := &redis.UniversalOptions{ 15 | Addrs: []string{"127.0.0.1:6379"}, 16 | Password: "", 17 | PoolSize: 10, 18 | IdleTimeout: 300 * time.Second, 19 | // Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) { 20 | // conn, err := net.Dial(network, addr) 21 | // if err == nil { 22 | // go func() { 23 | // time.Sleep(5 * time.Second) 24 | // conn.Close() 25 | // }() 26 | // } 27 | // return conn, err 28 | // }, 29 | } 30 | 31 | err := auth2.NewAgent(&auth2.Config{ 32 | Type: "redis", 33 | Max: 10, 34 | UniversalClient: redis.NewUniversalClient(options)}) 35 | if err != nil { 36 | panic(fmt.Sprintf("auth is not init get err %v\n", err)) 37 | } 38 | } 39 | 40 | func auth() gin.HandlerFunc { 41 | verifier := auth2.NewVerifier() 42 | verifier.Extractors = []auth2.TokenExtractor{auth2.FromHeader} // extract token only from Authorization: Bearer $token 43 | return verifier.Verify() 44 | } 45 | 46 | func main() { 47 | app := gin.New() 48 | 49 | app.GET("/", generateToken()) 50 | 51 | protectedAPI := app.Group("/protected") 52 | // Register the verify middleware to allow access only to authorized clients. 53 | protectedAPI.Use(auth()) 54 | // ^ or UseRouter(verifyMiddleware) to disallow unauthorized http error handlers too. 55 | 56 | protectedAPI.GET("/", protected) 57 | // Invalidate the token through server-side, even if it's not expired yet. 
58 | protectedAPI.GET("/logout", logout) 59 | 60 | // http://localhost:8080 61 | // http://localhost:8080/protected (or Authorization: Bearer $token) 62 | // http://localhost:8080/protected/logout 63 | // http://localhost:8080/protected (401) 64 | app.Run(":8080") 65 | } 66 | 67 | func generateToken() gin.HandlerFunc { 68 | return func(ctx *gin.Context) { 69 | claims := auth2.NewClaims(&auth2.Agent{ 70 | Id: 1, 71 | Username: "your name", 72 | AuthIds: []string{"your authority id"}, 73 | RoleType: auth2.RoleAdmin, 74 | LoginType: auth2.LoginTypeWeb, 75 | AuthType: auth2.AuthPwd, 76 | CreationTime: time.Now().Local().Unix(), 77 | ExpiresAt: time.Now().Local().Add(auth2.TimeoutWeb).Unix(), 78 | }) 79 | 80 | token, _, err := auth2.AuthAgent.Generate(claims) 81 | if err != nil { 82 | ctx.AbortWithStatus(http.StatusInternalServerError) 83 | return 84 | } 85 | 86 | ctx.String(200, token) 87 | } 88 | } 89 | 90 | func protected(ctx *gin.Context) { 91 | claims := auth2.Get(ctx) 92 | ctx.JSON(http.StatusOK, fmt.Sprintf("claims=%+v\n", claims)) 93 | } 94 | 95 | func logout(ctx *gin.Context) { 96 | token := auth2.GetVerifiedToken(ctx) 97 | if token == nil { 98 | ctx.String(http.StatusOK, auth2.ErrEmptyToken.Error()) 99 | return 100 | } 101 | err := auth2.AuthAgent.DelCache(string(token)) 102 | if err != nil { 103 | ctx.JSON(http.StatusOK, err.Error()) 104 | return 105 | } 106 | ctx.String(http.StatusOK, "token invalidated, a new token is required to access the protected API") 107 | } 108 | -------------------------------------------------------------------------------- /auth2/extractor.go: -------------------------------------------------------------------------------- 1 | package auth2 2 | 3 | import ( 4 | "log" 5 | "strings" 6 | 7 | "github.com/gin-gonic/gin" 8 | "github.com/gin-gonic/gin/binding" 9 | ) 10 | 11 | // TokenExtractor is a function that takes a context as input and returns 12 | // a token. 
An empty string should be returned if no token found 13 | // without additional information. 14 | type TokenExtractor func(*gin.Context) string 15 | 16 | // FromHeader is a token extractor. 17 | // It reads the token from the Authorization request header of form: 18 | // Authorization: "Bearer {token}". 19 | func FromHeader(ctx *gin.Context) string { 20 | authHeader := ctx.GetHeader("Authorization") 21 | if authHeader == "" { 22 | return "" 23 | } 24 | 25 | // pure check: authorization header format must be Bearer {token} 26 | authHeaderParts := strings.Split(authHeader, " ") 27 | if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { 28 | return "" 29 | } 30 | 31 | return authHeaderParts[1] 32 | } 33 | 34 | // FromQuery is a token extractor. 35 | // It reads the token from the "token" url query parameter. 36 | func FromQuery(ctx *gin.Context) string { 37 | return ctx.Query("token") 38 | } 39 | 40 | // FromJSON is a token extractor. 41 | // Reads a json request body and extracts the json based on the given field. 42 | // The request content-type should contain the: application/json header value, otherwise 43 | // this method will not try to read and consume the body. 
44 | func FromJSON(jsonKey string) TokenExtractor { 45 | return func(ctx *gin.Context) string { 46 | if ctx.ContentType() != binding.MIMEJSON { 47 | log.Printf("extractor: content-type %s not supported\n", ctx.ContentType()) 48 | return "" 49 | } 50 | // Decode with ShouldBindJSON: BindJSON would also abort the request with 400 on malformed JSON, a side effect an extractor that only reports "no token" must not have. 51 | var m gin.H 52 | if err := ctx.ShouldBindJSON(&m); err != nil { 53 | log.Println("extractor: bind json error:", err.Error()) 54 | return "" 55 | } 56 | 57 | if m == nil { 58 | log.Println("extractor: json is empty") 59 | return "" 60 | } 61 | 62 | v, ok := m[jsonKey] 63 | if !ok { 64 | log.Printf("extractor: key %s not found\n", jsonKey) 65 | return "" 66 | } 67 | 68 | tok, ok := v.(string) 69 | if !ok { 70 | log.Printf("extractor: key %s value:[%v] is not a string\n", jsonKey, v) 71 | return "" 72 | } 73 | return tok 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /auth2/interface.go: -------------------------------------------------------------------------------- 1 | package auth2 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/go-redis/redis/v8" 9 | ) 10 | // Cache key prefixes shared by the local and redis backends; LimitTokenPrefix is used as a complete key (the stored per-user token limit), not as a prefix. 11 | const ( 12 | TokenPrefix = "GST:" 13 | BindUserPrefix = "GSBU:" 14 | UserPrefix = "GSU:" 15 | LimitTokenPrefix = "GT_LIMIT_TOKEN" 16 | ) 17 | // LimitTokenDefault is the per-user token limit used when SetLimit has not stored one. 18 | var ( 19 | AuthTypeSplit = "-" 20 | LimitTokenDefault int64 = 10 21 | ) 22 | // Sentinel errors returned (possibly wrapped) by the backends; match them with errors.Is. 23 | var ( 24 | ErrTokenInvalid = errors.New("token_invalid") 25 | ErrEmptyToken = errors.New("token_empty") 26 | ErrOverLimit = errors.New("token_over_limit") 27 | ) 28 | // RoleType, AuthType and LoginType classify an Agent; LoginType also selects the token TTL (see getExpire). 29 | type RoleType int 30 | 31 | const ( 32 | RoleNone RoleType = iota 33 | RoleAdmin 34 | RoleTenancy 35 | RoleGeneral 36 | ) 37 | 38 | type AuthType int 39 | 40 | const ( 41 | NoAuth AuthType = iota 42 | AuthPwd 43 | AuthCode 44 | AuthThirdParty 45 | ) 46 | 47 | type LoginType int 48 | 49 | const ( 50 | LoginTypeWeb LoginType = iota 51 | LoginTypeApp 52 | LoginTypeWx 53 | LoginTypeDevice 54 | ) 55 | // Token lifetimes per LoginType; 5 * 52 * 168h is roughly five years. 56 | var ( 57 | TimeoutWeb = 4 * time.Hour 58 | TimeoutApp = 7 * 24 * time.Hour 59 | TimeoutWx =
5 * 52 * 168 * time.Hour 60 | TimeoutDevice = 5 * 52 * 168 * time.Hour 61 | ) 62 | // NewAgent builds the backend selected by c.Type ("redis", "local"; anything else uses jwt) and stores it in AuthAgent. A zero c.Max is replaced by LimitTokenDefault and, for redis and local, applied with SetLimit. 63 | func NewAgent(c *Config) error { 64 | if c.Max == 0 { 65 | c.Max = LimitTokenDefault 66 | } 67 | switch c.Type { 68 | case "redis": 69 | agent, err := NewRedis(c.UniversalClient) 70 | if err != nil { 71 | return err 72 | } 73 | 74 | AuthAgent = agent 75 | err = AuthAgent.SetLimit(c.Max) 76 | if err != nil { 77 | return err 78 | } 79 | case "local": 80 | AuthAgent = NewLocal() 81 | err := AuthAgent.SetLimit(c.Max) 82 | if err != nil { 83 | return err 84 | } 85 | case "jwt": 86 | AuthAgent = NewJwt(c.HmacSecret) 87 | default: 88 | AuthAgent = NewJwt(c.HmacSecret) 89 | } 90 | 91 | return nil 92 | } 93 | 94 | // Agent holds the identity fields passed to NewClaims when a token is issued. 95 | type Agent struct { 96 | Id uint `json:"id,omitempty"` 97 | SuperAdmin bool `json:"superAdmin,omitempty"` 98 | Username string `json:"username,omitempty"` 99 | AuthIds []string `json:"authIds,omitempty"` 100 | RoleType RoleType `json:"type,omitempty"` 101 | LoginType LoginType `json:"loginType,omitempty"` 102 | AuthType AuthType `json:"authType,omitempty"` 103 | CreationTime int64 `json:"creationData,omitempty"` 104 | ExpiresAt int64 `json:"expiresAt,omitempty"` 105 | } 106 | // Config configures NewAgent: UniversalClient is used by the "redis" backend and HmacSecret by jwt (NewJwt falls back to a built-in sample secret when it is nil). 107 | type Config struct { 108 | Type string 109 | Max int64 110 | UniversalClient redis.UniversalClient 111 | HmacSecret []byte 112 | } 113 | 114 | type ( 115 | // TokenValidator provides further token and claims validation. 116 | TokenValidator interface { 117 | // Validater accepts the token, the claims extracted from that 118 | // and any error that may caused by claims validation (e.g. ErrExpired) 119 | // or the previous validator. 120 | // A token validator can skip the builtin validation and return a nil error. 121 | // Usage: 122 | // func(v *myValidator) Validater(token []byte, err error) error { 123 | // if err!=nil { return err } <- to respect the previous error 124 | // // otherwise return nil or any custom error.
125 | // } 126 | // 127 | // Look `Blocklist`, `Expected` and `Leeway` for builtin implementations. 128 | Validater(token []byte, err error) error 129 | } 130 | 131 | // ValidatorFunc is the interface-as-function shortcut for a TokenValidator. 132 | ValidatorFunc func(token []byte, err error) error 133 | ) 134 | 135 | // Validater calls fn itself, so a ValidatorFunc satisfies TokenValidator. 136 | func (fn ValidatorFunc) Validater(token []byte, err error) error { return fn(token, err) } 137 | 138 | // ValidateToken is equivalent to Validater and is kept for existing callers. 139 | func (fn ValidatorFunc) ValidateToken(token []byte, err error) error { return fn(token, err) } 140 | // AuthAgent is the process-wide backend assigned by NewAgent. 141 | var AuthAgent Authentication 142 | 143 | // Authentication is the token-store contract implemented by the jwt, local and redis backends. 144 | type Authentication interface { 145 | Generate(claims *Claims) (string, int64, error) 146 | DelCache(token string) error 147 | UpdateCacheExpire(token string) error 148 | GetClaims(token string) (*Claims, error) 149 | Token(claims *Claims) (string, error) 150 | CleanCache(roleType RoleType, userId string) error 151 | SetLimit(max int64) error 152 | IsRole(token string, roleType RoleType) (bool, error) 153 | IsSuperAdmin(token string) bool 154 | Close() 155 | } 156 | 157 | // getExpire maps a login type to its token TTL; unknown types fall back to TimeoutWeb. 158 | func getExpire(loginType LoginType) time.Duration { 159 | switch loginType { 160 | case LoginTypeWeb: 161 | return TimeoutWeb 162 | case LoginTypeWx: 163 | return TimeoutWx 164 | case LoginTypeApp: 165 | return TimeoutApp 166 | case LoginTypeDevice: 167 | return TimeoutDevice 168 | default: 169 | return TimeoutWeb 170 | } 171 | } 172 | 173 | // getPrefixKey builds the key of a user's token list: UserPrefix + roleType + "_" + id. 174 | func getPrefixKey(roleType RoleType, id string) string { 175 | return fmt.Sprintf("%s%d_%s", UserPrefix, roleType, id) 176 | } 177 | -------------------------------------------------------------------------------- /auth2/jwt.go: -------------------------------------------------------------------------------- 1 | package auth2 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | 7 | "github.com/golang-jwt/jwt" 8 | "github.com/snowlyg/helper/arr" 9 | ) 10 | 11 | var hmacSampleSecret = []byte("updPA0L2uQ56LwHZoyUX") 12 | 13 | // JwtAuth 14 | type JwtAuth
struct { 15 | HmacSecret []byte 16 | delToken arr.ArrayType 17 | } 18 | 19 | // NewJwt 20 | func NewJwt(hmacSecret []byte) *JwtAuth { 21 | ja := &JwtAuth{ 22 | HmacSecret: hmacSecret, 23 | delToken: arr.NewCheckArrayType(0), 24 | } 25 | if ja.HmacSecret == nil { 26 | ja.HmacSecret = hmacSampleSecret 27 | } 28 | return ja 29 | } 30 | 31 | // Generate 32 | func (ra *JwtAuth) Generate(claims *Claims) (string, int64, error) { 33 | token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) 34 | 35 | // Sign and get the complete encoded token as a string using the secret 36 | tokenString, err := token.SignedString(ra.HmacSecret) 37 | if err != nil { 38 | return "", 0, err 39 | } 40 | return tokenString, 0, nil 41 | } 42 | 43 | // Token 44 | func (ra *JwtAuth) Token(cla *Claims) (string, error) { 45 | log.Printf("jwt:get token not support\n") 46 | return "", nil 47 | } 48 | 49 | // GetClaims 50 | func (ra *JwtAuth) GetClaims(tokenString string) (*Claims, error) { 51 | if ra.delToken.Check(tokenString) { 52 | return nil, fmt.Errorf("jwt:token deleted %w", ErrTokenInvalid) 53 | } 54 | mc := &Claims{} 55 | token, err := jwt.ParseWithClaims(tokenString, mc, func(token *jwt.Token) (interface{}, error) { 56 | // Don't forget to validate the alg is what you expect: 57 | if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { 58 | return nil, fmt.Errorf("incorrect signing method: %v", token.Header["alg"]) 59 | } 60 | return ra.HmacSecret, nil 61 | }) 62 | if err != nil { 63 | return nil, err 64 | } 65 | 66 | if _, ok := token.Claims.(*Claims); ok && token.Valid { 67 | return mc, nil 68 | } else { 69 | return nil, fmt.Errorf("token[%s]:%w", tokenString, ErrTokenInvalid) 70 | } 71 | } 72 | 73 | // SetLimit 74 | func (ra *JwtAuth) SetLimit(limit int64) error { 75 | log.Printf("jwt:set max count not support\n") 76 | return nil 77 | } 78 | 79 | // UpdateCacheExpire 80 | func (ra *JwtAuth) UpdateCacheExpire(token string) error { 81 | log.Printf("jwt:UpdateCacheExpire not support\n") 
82 | return nil 83 | } 84 | 85 | // DelCache 86 | func (ra *JwtAuth) DelCache(token string) error { 87 | ra.delToken.Add(token) 88 | return nil 89 | } 90 | 91 | // CleanCache 92 | func (ra *JwtAuth) CleanCache(roleType RoleType, userId string) error { 93 | log.Printf("jwt:CleanCache not support") 94 | return nil 95 | } 96 | 97 | // IsRole 98 | func (ra *JwtAuth) IsRole(token string, roleType RoleType) (bool, error) { 99 | rcc, err := ra.GetClaims(token) 100 | if err != nil { 101 | return false, fmt.Errorf("jwt:get User's infomation return error: %w", err) 102 | } 103 | return rcc.roleType() == roleType, nil 104 | } 105 | 106 | // IsSuperAdmin 107 | func (ra *JwtAuth) IsSuperAdmin(token string) bool { 108 | rcc, err := ra.GetClaims(token) 109 | if err != nil { 110 | log.Printf("jwt:get claims fail:%s\n", err.Error()) 111 | return false 112 | } 113 | return rcc.SuperAdmin 114 | } 115 | 116 | // Close 117 | func (ra *JwtAuth) Close() {} 118 | -------------------------------------------------------------------------------- /auth2/jwt_test.go: -------------------------------------------------------------------------------- 1 | package auth2 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | var ( 9 | jwtAuth = NewJwt(nil) 10 | jwtClaims = NewClaims( 11 | &Agent{ 12 | Id: uint(8457585), 13 | Username: "jwt username", 14 | AuthIds: []string{"999"}, 15 | RoleType: RoleAdmin, 16 | LoginType: LoginTypeWeb, 17 | AuthType: AuthPwd, 18 | ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(), 19 | }, 20 | ) 21 | ) 22 | 23 | func TestJwtGenerateToken(t *testing.T) { 24 | defer jwtAuth.CleanCache(jwtClaims.roleType(), jwtClaims.Id) 25 | token, _, err := jwtAuth.Generate(jwtClaims) 26 | if err != nil { 27 | t.Fatalf("generate token %v", err) 28 | } 29 | if token == "" { 30 | t.Error("generate token is empty") 31 | } 32 | 33 | cc, err := jwtAuth.GetClaims(token) 34 | if err != nil { 35 | t.Fatalf("get custom claims fail:%v", err) 36 | } 37 | 38 | if cc.Id != jwtClaims.Id { 39 | 
t.Errorf("get custom id want %v but get %v", jwtClaims.Id, cc.Id) 40 | } 41 | if cc.Username != jwtClaims.Username { 42 | t.Errorf("get custom username want %v but get %v", jwtClaims.Username, cc.Username) 43 | } 44 | if cc.AuthId != jwtClaims.AuthId { 45 | t.Errorf("get custom authority_id want %v but get %v", jwtClaims.AuthId, cc.AuthId) 46 | } 47 | if cc.RoleType != jwtClaims.RoleType { 48 | t.Errorf("get custom authority_type want %v but get %v", jwtClaims.RoleType, cc.RoleType) 49 | } 50 | if cc.LoginType != jwtClaims.LoginType { 51 | t.Errorf("get custom login_type want %v but get %v", jwtClaims.LoginType, cc.LoginType) 52 | } 53 | if cc.AuthType != jwtClaims.AuthType { 54 | t.Errorf("get custom auth_type want %v but get %v", jwtClaims.AuthType, cc.AuthType) 55 | } 56 | if cc.CreationTime != jwtClaims.CreationTime { 57 | t.Errorf("get custom creation_data want %v but get %v", jwtClaims.CreationTime, cc.CreationTime) 58 | } 59 | if cc.ExpiresAt != jwtClaims.ExpiresAt { 60 | t.Errorf("get custom expires_at want %v but get %v", jwtClaims.ExpiresAt, cc.ExpiresAt) 61 | } 62 | } 63 | 64 | func TestJwtSetUserTokenMaxCount(t *testing.T) { 65 | err := jwtAuth.SetLimit(3) 66 | if err != nil { 67 | t.Errorf("get token by claims token want %v but get %v", nil, err) 68 | } 69 | } 70 | 71 | func TestJwtGetMultiClaims(t *testing.T) { 72 | defer jwtAuth.CleanCache(jwtClaims.roleType(), jwtClaims.Id) 73 | var token string 74 | jwtClaims.setLoginType(int(LoginTypeWeb)) 75 | token, _, err := jwtAuth.Generate(jwtClaims) 76 | if err != nil { 77 | t.Fatalf("get custom claims %v", err) 78 | } 79 | for i := LoginTypeWeb; i <= LoginTypeDevice; i++ { 80 | wg.Add(1) 81 | jwtClaims.setLoginType(int(i)) 82 | go func(i LoginType) { 83 | jwtAuth.Generate(jwtClaims) 84 | wg.Done() 85 | }(i) 86 | wg.Wait() 87 | } 88 | for i := 0; i < 4; i++ { 89 | go func() { 90 | _, err := jwtAuth.GetClaims(token) 91 | if err != nil { 92 | t.Errorf("get custom claims %v", err) 93 | } 94 | }() 95 | } 96 | 
time.Sleep(3 * time.Second) 97 | } 98 | 99 | func TestJwtGetTokenByClaims(t *testing.T) { 100 | _, err := jwtAuth.Token(jwtClaims) 101 | if err != nil { 102 | t.Errorf("get token by claims token want %v but get %v", nil, err) 103 | } 104 | } 105 | 106 | func TestJwtDelUserTokenCache(t *testing.T) { 107 | token, _, _ := jwtAuth.Generate(jwtClaims) 108 | if token == "" { 109 | t.Error("generate token is empty") 110 | } 111 | err := jwtAuth.DelCache(token) 112 | if err != nil { 113 | t.Errorf("del token fail:%v", err.Error()) 114 | } 115 | if _, err := jwtAuth.GetClaims(token); err == nil { 116 | t.Error("del user token fail") 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /auth2/local.go: -------------------------------------------------------------------------------- 1 | package auth2 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/patrickmn/go-cache" 9 | ) 10 | 11 | type tokens []string 12 | 13 | var localCache *cache.Cache 14 | 15 | type LocalAuth struct { 16 | Cache *cache.Cache 17 | } 18 | 19 | func NewLocal() *LocalAuth { 20 | if localCache == nil { 21 | localCache = cache.New(4*time.Hour, 24*time.Minute) 22 | } 23 | return &LocalAuth{ 24 | Cache: localCache, 25 | } 26 | } 27 | 28 | // Generate 29 | func (la *LocalAuth) Generate(claims *Claims) (string, int64, error) { 30 | 31 | if la.isUserTokenOver(claims.roleType(), claims.Id) { 32 | return "", 0, fmt.Errorf("local: is user token over fail:%w", ErrOverLimit) 33 | } 34 | token, err := getToken() 35 | if err != nil { 36 | return "", 0, fmt.Errorf("local: get token fail:%w", err) 37 | } 38 | la.toCache(token, claims) 39 | if e := la.syncUserCache(token); e != nil { 40 | return "", 0, fmt.Errorf("local: sync user token fail:%w", e) 41 | } 42 | return token, int64(claims.ExpiresAt), nil 43 | } 44 | 45 | func (la *LocalAuth) toCache(token string, rcc *Claims) error { 46 | sKey := TokenPrefix + token 47 | la.Cache.Set(sKey, rcc, 
getExpire(rcc.loginType())) 48 | return nil 49 | } 50 | // syncUserCache appends token to its user's token list (stored with no expiration) and binds BindUserPrefix+token to that list's key with the login type's TTL. 51 | func (la *LocalAuth) syncUserCache(token string) error { 52 | rcc, err := la.GetClaims(token) 53 | if err != nil { 54 | return err 55 | } 56 | userPrefixKey := getPrefixKey(rcc.roleType(), rcc.Id) 57 | ts := tokens{} 58 | if uTokens, ok := la.Cache.Get(userPrefixKey); ok && uTokens != nil { 59 | ts = uTokens.(tokens) 60 | } 61 | ts = append(ts, token) 62 | la.Cache.Set(userPrefixKey, ts, cache.NoExpiration) 63 | la.Cache.Set(BindUserPrefix+token, userPrefixKey, getExpire(rcc.loginType())) 64 | return nil 65 | } 66 | // DelCache revokes token: it is removed from its user's token list, then its claims and user binding are deleted. 67 | func (la *LocalAuth) DelCache(token string) error { 68 | rcc, err := la.GetClaims(token) 69 | if err != nil { 70 | return err 71 | } 72 | userPrefixKey := getPrefixKey(rcc.roleType(), rcc.Id) 73 | if utokens, ok := la.Cache.Get(userPrefixKey); ok && utokens != nil { 74 | t := utokens.(tokens) 75 | for index, u := range t { 76 | if u == token { 77 | // Splice out t[index]: the tail must start at index+1 (t[index:] re-inserts the token), 78 | // and the capped t[:index:index] forces a copy so the cached backing array is not mutated. 79 | t = append(t[:index:index], t[index+1:]...) 80 | // Tokens are unique per user, so there is nothing more to find.
81 | break 82 | } 83 | } 84 | la.Cache.Set(userPrefixKey, t, cache.NoExpiration) 85 | } 86 | la.delTokenCache(token) 87 | return nil 88 | } 89 | 90 | // delTokenCache deletes the token's claims entry and its BindUserPrefix binding; it always returns nil. 91 | func (la *LocalAuth) delTokenCache(token string) error { 92 | la.Cache.Delete(BindUserPrefix + token) 93 | la.Cache.Delete(TokenPrefix + token) 94 | return nil 95 | } 96 | // UpdateCacheExpire refreshes the TTL of the token's claims and of its user binding; the binding keeps holding the user-list key written by syncUserCache, not the claims. 97 | func (la *LocalAuth) UpdateCacheExpire(token string) error { 98 | rsv2, err := la.GetClaims(token) 99 | if err != nil { 100 | return err 101 | } 102 | if rsv2 == nil { 103 | return errors.New("token cache is nil") 104 | } 105 | la.Cache.Set(BindUserPrefix+token, getPrefixKey(rsv2.roleType(), rsv2.Id), getExpire(rsv2.loginType())) 106 | la.Cache.Set(TokenPrefix+token, rsv2, getExpire(rsv2.loginType())) 107 | return nil 108 | } 109 | // GetClaims returns the claims cached for token, or an error wrapping ErrTokenInvalid when there are none. 110 | func (la *LocalAuth) GetClaims(token string) (*Claims, error) { 111 | sKey := TokenPrefix + token 112 | food, found := la.Cache.Get(sKey) 113 | if !found || food == nil { 114 | return nil, fmt.Errorf("token not found:%w", ErrTokenInvalid) 115 | } 116 | return food.(*Claims), nil 117 | } 118 | 119 | // Token returns a live token of the same user whose claims match cla on auth type, id, role type, auth id and login type, or "" if there is none. 120 | func (la *LocalAuth) Token(cla *Claims) (string, error) { 121 | userTokens, err := la.getUserTokens(cla.roleType(), cla.Id) 122 | if err != nil { 123 | return "", err 124 | } 125 | clas, err := la.getMultiClaimses(userTokens) 126 | if err != nil { 127 | return "", err 128 | } 129 | for token, existCla := range clas { 130 | if cla.AuthType == existCla.AuthType && 131 | cla.Id == existCla.Id && 132 | cla.RoleType == existCla.RoleType && 133 | cla.AuthId == existCla.AuthId && 134 | cla.LoginType == existCla.LoginType { 135 | return token, nil 136 | } 137 | } 138 | return "", nil 139 | } 140 | 141 | // getUserTokens returns the token list stored for the user, or nil if there is none; the error is always nil. 142 | func (la *LocalAuth) getUserTokens(roleType RoleType, userId string) (tokens, error) { 143 | if utokens, ok := la.Cache.Get(getPrefixKey(roleType, userId)); ok && utokens != nil { 144 | return utokens.(tokens), nil 145 | } 146 | return nil, nil 147 | } 148 |
149 | // getMultiClaimses returns the claims of every token that is still cached; tokens without claims are skipped and the error is always nil. 150 | func (la *LocalAuth) getMultiClaimses(tokens tokens) (map[string]*Claims, error) { 151 | clas := make(map[string]*Claims, la.getUserTokenMaxCount()) 152 | for _, token := range tokens { 153 | cla, err := la.GetClaims(token) 154 | if err != nil { 155 | continue 156 | } 157 | clas[token] = cla 158 | } 159 | 160 | return clas, nil 161 | } 162 | // isUserTokenOver reports whether the user already holds at least the configured maximum number of live tokens. 163 | func (la *LocalAuth) isUserTokenOver(roleType RoleType, userId string) bool { 164 | return la.getUserTokenCount(roleType, userId) >= la.getUserTokenMaxCount() 165 | } 166 | 167 | // getUserTokenCount returns the number of the user's tokens whose claims are still cached. 168 | func (la *LocalAuth) getUserTokenCount(roleType RoleType, userId string) int64 { 169 | return la.checkMaxCount(roleType, userId) 170 | } 171 | // checkMaxCount drops the user's tokens whose claims have left the cache, stores the pruned list (no expiration) and returns how many remain. 172 | func (la *LocalAuth) checkMaxCount(roleType RoleType, userId string) int64 { 173 | utokens, _ := la.getUserTokens(roleType, userId) 174 | if utokens == nil { 175 | return 0 176 | } 177 | // Rebuild the list from tokens whose claims are still cached. Filtering into a 178 | // fresh slice is simpler and safer than removing by index while ranging, and it 179 | // leaves the backing array of the cached value untouched. 180 | alive := make(tokens, 0, len(utokens)) 181 | for _, u := range utokens { 182 | if _, found := la.Cache.Get(TokenPrefix + u); found {
183 | alive = append(alive, u) 184 | } 185 | } 186 | la.Cache.Set(getPrefixKey(roleType, userId), alive, cache.NoExpiration) 187 | return int64(len(alive)) 188 | 189 | } 190 | 191 | // getUserTokenMaxCount returns the limit stored by SetLimit, or LimitTokenDefault when none is stored. 192 | func (la *LocalAuth) getUserTokenMaxCount() int64 { 193 | if count, found := la.Cache.Get(LimitTokenPrefix); !found { 194 | return LimitTokenDefault 195 | } else { 196 | return count.(int64) 197 | } 198 | } 199 | 200 | // SetLimit stores the per-user token limit with no expiration; it always returns nil. 201 | func (la *LocalAuth) SetLimit(tokenMaxCount int64) error { 202 | la.Cache.Set(LimitTokenPrefix, tokenMaxCount, cache.NoExpiration) 203 | return nil 204 | } 205 | 206 | // CleanCache deletes every token of the user (claims and binding) and then the user's token list. 207 | func (la *LocalAuth) CleanCache(roleType RoleType, userId string) error { 208 | utokens, _ := la.getUserTokens(roleType, userId) 209 | if utokens == nil { 210 | return nil 211 | } 212 | for _, token := range utokens { 213 | err := la.delTokenCache(token) 214 | if err != nil { 215 | continue 216 | } 217 | } 218 | la.Cache.Delete(getPrefixKey(roleType, userId)) 219 | return nil 220 | } 221 | 222 | // IsRole reports whether the claims cached for token carry the given role type. 223 | func (la *LocalAuth) IsRole(token string, roleType RoleType) (bool, error) { 224 | rcc, err := la.GetClaims(token) 225 | if err != nil { 226 | return false, fmt.Errorf("local: get multi claims fail %w", err) 227 | } 228 | return rcc.roleType() == roleType, nil 229 | } 230 | 231 | // IsSuperAdmin reports whether the claims cached for token have SuperAdmin set; unknown tokens yield false. 232 | func (la *LocalAuth) IsSuperAdmin(token string) bool { 233 | rcc, err := la.GetClaims(token) 234 | if err != nil { 235 | return false 236 | } 237 | return rcc.SuperAdmin 238 | } 239 | // Close is a no-op for the in-memory backend. 240 | func (la *LocalAuth) Close() {} 241 | -------------------------------------------------------------------------------- /auth2/local_test.go: -------------------------------------------------------------------------------- 1 | package auth2 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | var ( 10 | localAuth = NewLocal() 11 | tToken = "TVRReU1EVTFOek13TmpFd09UWXlPRFF4TmcuTWpBeU1TMHdOeTB5T1ZRd09Ub3pNRG95T1Nzd09Eb3dNQQ.MTQyMDU1NzMwNjEwOTYyODQxNg" 12 | loginTypeApp =
loginTypeApp = NewClaims( 13 | &Agent{ 14 | Id: uint(1), 15 | SuperAdmin: true, 16 | Username: "username", 17 | AuthIds: []string{"999"}, 18 | RoleType: RoleAdmin, 19 | LoginType: LoginTypeWeb, 20 | AuthType: AuthPwd, 21 | ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(), 22 | }, 23 | ) 24 | userKey = getPrefixKey(loginTypeApp.roleType(), loginTypeApp.Id) 25 | ) 26 | 27 | func TestNewLocalAuth(t *testing.T) { 28 | if NewLocal() == nil { 29 | t.Error("new local auth get nil") 30 | } 31 | } 32 | 33 | func TestGenerateToken(t *testing.T) { 34 | token, expiresIn, err := localAuth.Generate(loginTypeApp) 35 | if err != nil { 36 | t.Fatalf("generate token %v", err) 37 | } 38 | if token == "" { 39 | t.Error("generate token is empty") 40 | } 41 | 42 | if expiresIn != loginTypeApp.ExpiresAt { 43 | t.Errorf("generate token expires want %v but get %v", loginTypeApp.ExpiresAt, expiresIn) 44 | } 45 | cc, err := localAuth.GetClaims(token) 46 | if err != nil { 47 | t.Fatalf("get custom claims %v", err) 48 | } 49 | 50 | if cc.Id != loginTypeApp.Id { 51 | t.Errorf("get custom id want %v but get %v", loginTypeApp.Id, cc.Id) 52 | } 53 | if cc.Username != loginTypeApp.Username { 54 | t.Errorf("get custom username want %v but get %v", loginTypeApp.Username, cc.Username) 55 | } 56 | if cc.AuthId != loginTypeApp.AuthId { 57 | t.Errorf("get custom authority_id want %v but get %v", loginTypeApp.AuthId, cc.AuthId) 58 | } 59 | if cc.RoleType != loginTypeApp.RoleType { 60 | t.Errorf("get custom authority_type want %v but get %v", loginTypeApp.RoleType, cc.RoleType) 61 | } 62 | if cc.LoginType != loginTypeApp.LoginType { 63 | t.Errorf("get custom login_type want %v but get %v", loginTypeApp.LoginType, cc.LoginType) 64 | } 65 | if cc.AuthType != loginTypeApp.AuthType { 66 | t.Errorf("get custom auth_type want %v but get %v", loginTypeApp.AuthType, cc.AuthType) 67 | } 68 | if cc.CreationTime != loginTypeApp.CreationTime { 69 | t.Errorf("get custom creation_data want %v but get %v", 
loginTypeApp.CreationTime, cc.CreationTime) 70 | } 71 | if cc.ExpiresAt != loginTypeApp.ExpiresAt { 72 | t.Errorf("get custom expires_at want %v but get %v", loginTypeApp.ExpiresAt, cc.ExpiresAt) 73 | } 74 | 75 | if uTokens, uFound := localAuth.Cache.Get(userKey); uFound { 76 | ts := uTokens.(tokens) 77 | if len(ts) == 0 || ts[0] != token { 78 | t.Errorf("user prefix value want %v but get %v", userKey, uTokens) 79 | } 80 | } else { 81 | t.Error("user prefix value is emptpy") 82 | } 83 | bindKey := BindUserPrefix + token 84 | if uTokens, uFound := localAuth.Cache.Get(bindKey); uFound { 85 | if uTokens != userKey { 86 | t.Errorf("bind user prefix value want %v but get %v", userKey, uTokens) 87 | } 88 | } else { 89 | t.Error("bind user prefix value is emptpy") 90 | } 91 | } 92 | 93 | func TestToCache(t *testing.T) { 94 | err := localAuth.toCache(tToken, loginTypeApp) 95 | if err != nil { 96 | t.Fatalf("generate token %v", err) 97 | } 98 | cc, err := localAuth.GetClaims(tToken) 99 | if err != nil { 100 | t.Fatalf("get custom claims %v", err) 101 | } 102 | 103 | if cc.Id != loginTypeApp.Id { 104 | t.Errorf("get custom id want %v but get %v", loginTypeApp.Id, cc.Id) 105 | } 106 | if cc.Username != loginTypeApp.Username { 107 | t.Errorf("get custom username want %v but get %v", loginTypeApp.Username, cc.Username) 108 | } 109 | if cc.AuthId != loginTypeApp.AuthId { 110 | t.Errorf("get custom authority_id want %v but get %v", loginTypeApp.AuthId, cc.AuthId) 111 | } 112 | if cc.RoleType != loginTypeApp.RoleType { 113 | t.Errorf("get custom authority_type want %v but get %v", loginTypeApp.RoleType, cc.RoleType) 114 | } 115 | if cc.LoginType != loginTypeApp.LoginType { 116 | t.Errorf("get custom login_type want %v but get %v", loginTypeApp.LoginType, cc.LoginType) 117 | } 118 | if cc.AuthType != loginTypeApp.AuthType { 119 | t.Errorf("get custom auth_type want %v but get %v", loginTypeApp.AuthType, cc.AuthType) 120 | } 121 | if cc.CreationTime != loginTypeApp.CreationTime { 
122 | t.Errorf("get custom creation_data want %v but get %v", loginTypeApp.CreationTime, cc.CreationTime) 123 | } 124 | if cc.ExpiresAt != loginTypeApp.ExpiresAt { 125 | t.Errorf("get custom expires_at want %v but get %v", loginTypeApp.ExpiresAt, cc.ExpiresAt) 126 | } 127 | } 128 | 129 | func TestDelUserTokenCache(t *testing.T) { 130 | cc := NewClaims( 131 | &Agent{ 132 | Id: uint(2), 133 | Username: "username", 134 | SuperAdmin: true, 135 | AuthIds: []string{"999"}, 136 | RoleType: RoleAdmin, 137 | LoginType: LoginTypeWeb, 138 | AuthType: AuthPwd, 139 | ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(), 140 | }, 141 | ) 142 | token, _, _ := localAuth.Generate(cc) 143 | if token == "" { 144 | t.Error("generate token is empty") 145 | } 146 | 147 | err := localAuth.DelCache(token) 148 | if err != nil { 149 | t.Fatalf("del user token cache %v", err) 150 | } 151 | _, err = localAuth.GetClaims(token) 152 | if !errors.Is(err, ErrTokenInvalid) { 153 | t.Fatalf("get custom claims err want %v but get %v", ErrTokenInvalid, err) 154 | } 155 | 156 | if uTokens, uFound := localAuth.Cache.Get(UserPrefix + cc.Id); uFound && uTokens != nil { 157 | t.Errorf("user prefix value want empty but get %v", uTokens) 158 | } 159 | bindKey := BindUserPrefix + token 160 | if key, uFound := localAuth.Cache.Get(bindKey); uFound { 161 | t.Errorf("bind user prefix value want empty but get %v", key) 162 | } 163 | } 164 | 165 | func TestIsUserTokenOver(t *testing.T) { 166 | cc := NewClaims( 167 | &Agent{ 168 | Id: uint(3), 169 | Username: "username", 170 | SuperAdmin: true, 171 | AuthIds: []string{"999"}, 172 | RoleType: RoleAdmin, 173 | LoginType: LoginTypeWeb, 174 | AuthType: AuthPwd, 175 | ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(), 176 | }, 177 | ) 178 | for i := 0; i < 6; i++ { 179 | localAuth.Generate(cc) 180 | } 181 | if localAuth.isUserTokenOver(cc.roleType(), cc.Id) { 182 | t.Error("user token want not over but get over") 183 | } 184 | count := 
localAuth.getUserTokenCount(cc.roleType(), cc.Id) 185 | if count != 6 { 186 | t.Errorf("user token count want %v but get %v", 6, count) 187 | } 188 | } 189 | 190 | func TestSetUserTokenMaxCount(t *testing.T) { 191 | for i := 0; i < 6; i++ { 192 | localAuth.Generate(loginTypeApp) 193 | } 194 | if err := localAuth.SetLimit(5); err != nil { 195 | t.Fatalf("set user token max count %v", err) 196 | } 197 | count := localAuth.getUserTokenMaxCount() 198 | if count != 5 { 199 | t.Errorf("user token max count want %v but get %v", 5, count) 200 | } 201 | if !localAuth.isUserTokenOver(loginTypeApp.roleType(), loginTypeApp.Id) { 202 | t.Error("user token want over but get not over") 203 | } 204 | } 205 | func TestCleanUserTokenCache(t *testing.T) { 206 | for i := 0; i < 6; i++ { 207 | localAuth.Generate(loginTypeApp) 208 | } 209 | if err := localAuth.CleanCache(loginTypeApp.roleType(), loginTypeApp.Id); err != nil { 210 | t.Fatalf("clear user token cache %v", err) 211 | } 212 | if localAuth.getUserTokenCount(loginTypeApp.roleType(), loginTypeApp.Id) != 0 { 213 | t.Error("user token count want 0 but get not 0") 214 | } 215 | } 216 | 217 | func TestLocalGetMultiClaims(t *testing.T) { 218 | defer localAuth.CleanCache(loginTypeApp.roleType(), loginTypeApp.Id) 219 | var token string 220 | loginTypeApp.LoginType = 3 221 | token, _, err := localAuth.Generate(loginTypeApp) 222 | if err != nil { 223 | t.Fatalf("get custom claims %v", err) 224 | } 225 | for i := LoginTypeWeb; i <= LoginTypeDevice; i++ { 226 | wg.Add(1) 227 | loginTypeApp.setLoginType(int(i)) 228 | go func(i LoginType) { 229 | localAuth.Generate(loginTypeApp) 230 | wg.Done() 231 | }(i) 232 | wg.Wait() 233 | } 234 | for i := 0; i < 4; i++ { 235 | go func() { 236 | _, err := localAuth.GetClaims(token) 237 | if err != nil { 238 | t.Errorf("get custom claims fail:%v", err) 239 | } 240 | }() 241 | } 242 | time.Sleep(3 * time.Second) 243 | } 244 | 245 | func TestLocalGetUserTokens(t *testing.T) { 246 | loginTypeWeb := 
NewClaims( 247 | &Agent{ 248 | Id: uint(121321), 249 | Username: "username", 250 | SuperAdmin: true, 251 | AuthIds: []string{"999"}, 252 | RoleType: RoleAdmin, 253 | LoginType: LoginTypeWeb, 254 | AuthType: AuthPwd, 255 | ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(), 256 | }, 257 | ) 258 | 259 | defer localAuth.CleanCache(loginTypeWeb.roleType(), loginTypeWeb.Id) 260 | defer localAuth.CleanCache(loginTypeApp.roleType(), loginTypeApp.Id) 261 | 262 | token, _, err := localAuth.Generate(loginTypeApp) 263 | if err != nil { 264 | t.Fatalf("get user tokens by claims generate token %v \n", err) 265 | } 266 | 267 | if token == "" { 268 | t.Fatal("get user tokens by claims generate token is empty \n") 269 | } 270 | 271 | token3232, _, err := localAuth.Generate(loginTypeWeb) 272 | if err != nil { 273 | t.Fatalf("get user tokens by claims generate token %v \n", err) 274 | } 275 | 276 | if token3232 == "" { 277 | t.Fatal("get user tokens by claims generate token is empty \n") 278 | } 279 | 280 | if token == token3232 { 281 | t.Fatal("get user tokens by claims generate token is same") 282 | } 283 | 284 | tokens, err := localAuth.getUserTokens(loginTypeApp.roleType(), loginTypeApp.Id) 285 | if err != nil { 286 | t.Fatalf("get user tokens by claims %v", err) 287 | } 288 | 289 | if len(tokens) != 1 { 290 | t.Fatalf("get user tokens by claims want len 1 but get %d", len(tokens)) 291 | } 292 | } 293 | 294 | func TestLocalGetTokenByClaims(t *testing.T) { 295 | LoginTypeWeb := NewClaims( 296 | &Agent{ 297 | Id: uint(3232), 298 | Username: "username", 299 | SuperAdmin: true, 300 | AuthIds: []string{"999"}, 301 | RoleType: RoleAdmin, 302 | LoginType: LoginTypeWeb, 303 | AuthType: AuthPwd, 304 | ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(), 305 | }, 306 | ) 307 | defer localAuth.CleanCache(LoginTypeWeb.roleType(), LoginTypeWeb.Id) 308 | defer localAuth.CleanCache(loginTypeApp.roleType(), loginTypeApp.Id) 309 | 310 | token, _, err := localAuth.Generate(loginTypeApp) 311 | 
if err != nil { 312 | t.Fatalf("get token by claims generate token %v \n", err) 313 | } 314 | 315 | if token == "" { 316 | t.Fatal("get token by claims generate token is empty \n") 317 | } 318 | 319 | token3232, _, err := localAuth.Generate(LoginTypeWeb) 320 | if err != nil { 321 | t.Fatalf("get token by claims generate token %v \n", err) 322 | } 323 | 324 | if token3232 == "" { 325 | t.Fatal("get token by claims generate token is empty \n") 326 | } 327 | 328 | userToken, err := localAuth.Token(loginTypeApp) 329 | if err != nil { 330 | t.Fatalf("get token by claims %v", err) 331 | } 332 | 333 | if token != userToken { 334 | t.Errorf("get token by claims token want %s but get '%s'", token, userToken) 335 | } 336 | if token == token3232 { 337 | t.Errorf("get token by claims token not want %s but get '%s'", token3232, token) 338 | } 339 | 340 | } 341 | func TestLocalGetMultiClaimses(t *testing.T) { 342 | defer localAuth.CleanCache(loginTypeApp.roleType(), loginTypeApp.Id) 343 | tokenLen := 0 344 | for i := LoginTypeWeb; i <= LoginTypeWx; i++ { 345 | wg.Add(1) 346 | tokenLen++ 347 | loginTypeApp.setLoginType(int(i)) 348 | go func(i LoginType) { 349 | localAuth.Generate(loginTypeApp) 350 | wg.Done() 351 | }(i) 352 | wg.Wait() 353 | } 354 | userTokens, err := localAuth.getUserTokens(loginTypeApp.roleType(), loginTypeApp.Id) 355 | if err != nil { 356 | t.Fatal("get custom claimses generate token is empty \n") 357 | } 358 | clas, err := localAuth.getMultiClaimses(userTokens) 359 | if err != nil { 360 | t.Fatalf("get custom claimses %v", err) 361 | } 362 | 363 | if len(userTokens) != tokenLen { 364 | t.Fatalf("get custom claimses want len %d but get %d", tokenLen, len(userTokens)) 365 | } 366 | if len(clas) != tokenLen { 367 | t.Fatalf("get custom claimses want len %d but get %d", tokenLen, len(clas)) 368 | } 369 | 370 | } 371 | -------------------------------------------------------------------------------- /auth2/redis.go: 
-------------------------------------------------------------------------------- 1 | package auth2 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | 8 | "github.com/go-redis/redis/v8" 9 | "github.com/pkg/errors" 10 | ) 11 | 12 | // RedisAuth 13 | type RedisAuth struct { 14 | Client redis.UniversalClient 15 | } 16 | 17 | // NewRedis 18 | func NewRedis(client redis.UniversalClient) (*RedisAuth, error) { 19 | if client == nil { 20 | return nil, errors.New("redis client is nil") 21 | } 22 | _, err := client.Ping(context.Background()).Result() 23 | if err != nil { 24 | return nil, err 25 | } 26 | return &RedisAuth{ 27 | Client: client, 28 | }, nil 29 | } 30 | 31 | // Generate 32 | func (ra *RedisAuth) Generate(claims *Claims) (string, int64, error) { 33 | token, err := ra.Token(claims) 34 | if err != nil { 35 | return "", int64(claims.ExpiresAt), err 36 | } 37 | 38 | if token == "" { 39 | if isOver, err := ra.isUserTokenOver(claims.roleType(), claims.Id); err != nil { 40 | return "", int64(claims.ExpiresAt), err 41 | } else if isOver { 42 | return "", int64(claims.ExpiresAt), ErrOverLimit 43 | } 44 | 45 | token, err = getToken() 46 | if err != nil { 47 | return "", int64(claims.ExpiresAt), err 48 | } 49 | } 50 | 51 | if err = ra.toCache(token, claims); err != nil { 52 | return "", int64(claims.ExpiresAt), err 53 | } 54 | 55 | if err = ra.syncUserTokenCache(token); err != nil { 56 | return "", int64(claims.ExpiresAt), err 57 | } 58 | 59 | return token, int64(claims.ExpiresAt), nil 60 | } 61 | 62 | // toCache 63 | func (ra *RedisAuth) toCache(token string, cla *Claims) error { 64 | sKey := TokenPrefix + token 65 | if _, err := ra.Client.HMSet(context.Background(), sKey, 66 | "id", cla.Id, 67 | "super_admin", cla.SuperAdmin, 68 | "login_type", cla.LoginType, 69 | "auth_type", cla.AuthType, 70 | "username", cla.Username, 71 | "auth_id", cla.AuthId, 72 | "role_type", cla.RoleType, 73 | "creation_data", cla.CreationTime, 74 | "expires_at", cla.ExpiresAt, 75 | ).Result(); 
err != nil { 76 | return fmt.Errorf("to cache token %w", err) 77 | } 78 | err := ra.setExpire(sKey, cla.loginType()) 79 | if err != nil { 80 | return err 81 | } 82 | 83 | return nil 84 | } 85 | 86 | // Token 87 | func (ra *RedisAuth) Token(cla *Claims) (string, error) { 88 | userTokens, err := ra.getUserTokens(cla.roleType(), cla.Id) 89 | if err != nil { 90 | return "", err 91 | } 92 | clas, err := ra.getMultiClaimses(userTokens) 93 | if err != nil { 94 | return "", err 95 | } 96 | for token, existCla := range clas { 97 | if cla.AuthType == existCla.AuthType && cla.Id == existCla.Id && cla.RoleType == existCla.RoleType && 98 | cla.AuthId == existCla.AuthId && cla.LoginType == existCla.LoginType { 99 | return token, nil 100 | } 101 | } 102 | return "", nil 103 | } 104 | 105 | // getMultiClaimses 106 | func (ra *RedisAuth) getMultiClaimses(tokens []string) (map[string]*Claims, error) { 107 | clas := make(map[string]*Claims, ra.getUserTokenLimit()) 108 | for _, token := range tokens { 109 | cla, err := ra.GetClaims(token) 110 | if err != nil { 111 | continue 112 | } 113 | clas[token] = cla 114 | } 115 | 116 | return clas, nil 117 | } 118 | 119 | // GetClaims 120 | func (ra *RedisAuth) GetClaims(token string) (*Claims, error) { 121 | cla := new(Claims) 122 | if err := ra.Client.HGetAll(context.Background(), TokenPrefix+token).Scan(cla); err != nil { 123 | return nil, fmt.Errorf("get custom claims redis hgetall %w", err) 124 | } 125 | 126 | if cla.Id == "" { 127 | return nil, ErrEmptyToken 128 | } 129 | 130 | return cla, nil 131 | } 132 | 133 | // isUserTokenOver 134 | func (ra *RedisAuth) isUserTokenOver(roleType RoleType, userId string) (bool, error) { 135 | max, err := ra.getUserTokenCount(roleType, userId) 136 | if err != nil { 137 | return true, err 138 | } 139 | return max >= ra.getUserTokenLimit(), nil 140 | } 141 | 142 | // getUserTokens 143 | func (ra *RedisAuth) getUserTokens(roleType RoleType, userId string) ([]string, error) { 144 | userTokens, err := 
ra.Client.SMembers(context.Background(), getPrefixKey(roleType, userId)).Result() 145 | if err != nil { 146 | return nil, fmt.Errorf("get user token count menbers %w", err) 147 | } 148 | return userTokens, nil 149 | } 150 | 151 | // getUserTokenCount 152 | func (ra *RedisAuth) getUserTokenCount(roleType RoleType, userId string) (int64, error) { 153 | var count int64 154 | userTokens, err := ra.getUserTokens(roleType, userId) 155 | if err != nil { 156 | return count, fmt.Errorf("get user token count menbers %w", err) 157 | } 158 | userPrefixKey := getPrefixKey(roleType, userId) 159 | for _, token := range userTokens { 160 | if ra.checkUserTokenCount(token, userPrefixKey) == 1 { 161 | count++ 162 | } 163 | } 164 | return count, nil 165 | } 166 | 167 | // checkUserTokenCount 168 | func (ra *RedisAuth) checkUserTokenCount(token, userPrefixKey string) int64 { 169 | mun, err := ra.Client.Exists(context.Background(), TokenPrefix+token).Result() 170 | if err != nil || mun == 0 { 171 | ra.Client.SRem(context.Background(), userPrefixKey, token) 172 | } 173 | return mun 174 | } 175 | 176 | // getUserTokenLimit 177 | func (ra *RedisAuth) getUserTokenLimit() int64 { 178 | count, err := ra.Client.Get(context.Background(), LimitTokenPrefix).Int64() 179 | if err != nil { 180 | return LimitTokenDefault 181 | } 182 | return count 183 | } 184 | 185 | // SetLimit 186 | func (ra *RedisAuth) SetLimit(limit int64) error { 187 | err := ra.Client.Set(context.Background(), LimitTokenPrefix, limit, 0).Err() 188 | if err != nil { 189 | return err 190 | } 191 | return nil 192 | } 193 | 194 | // syncUserTokenCache 195 | func (ra *RedisAuth) syncUserTokenCache(token string) error { 196 | cla, err := ra.GetClaims(token) 197 | if err != nil { 198 | return fmt.Errorf("sysnc user token cache %w", err) 199 | } 200 | userPrefixKey := getPrefixKey(cla.roleType(), cla.Id) 201 | if _, err := ra.Client.SAdd(context.Background(), userPrefixKey, token).Result(); err != nil { 202 | return fmt.Errorf("sync 
user token cache redis sadd %w", err) 203 | } 204 | 205 | bindUserPrefixKey := BindUserPrefix + token 206 | _, err = ra.Client.Set(context.Background(), bindUserPrefixKey, userPrefixKey, getExpire(cla.loginType())).Result() 207 | if err != nil { 208 | return fmt.Errorf("sync user token cache %w", err) 209 | } 210 | return nil 211 | } 212 | 213 | // UpdateCacheExpire 214 | func (ra *RedisAuth) UpdateCacheExpire(token string) error { 215 | rcc, err := ra.GetClaims(token) 216 | if err != nil { 217 | return fmt.Errorf("update user token cache expire %w", err) 218 | } 219 | if rcc == nil { 220 | return errors.New("token cache is nil") 221 | } 222 | if err = ra.setExpire(TokenPrefix+token, rcc.loginType()); err != nil { 223 | return fmt.Errorf("update user token cache expire redis expire %w", err) 224 | } 225 | if err = ra.setExpire(BindUserPrefix+token, rcc.loginType()); err != nil { 226 | return fmt.Errorf("update user token cache expire redis expire %w", err) 227 | } 228 | return nil 229 | } 230 | 231 | func (ra *RedisAuth) setExpire(key string, loginType LoginType) error { 232 | if _, err := ra.Client.Expire(context.Background(), key, getExpire(loginType)).Result(); err != nil { 233 | return fmt.Errorf("update user token cache expire redis expire %w", err) 234 | } 235 | return nil 236 | } 237 | 238 | // DelCache 239 | func (ra *RedisAuth) DelCache(token string) error { 240 | log.Println("auth2: redis del user token") 241 | cla, err := ra.GetClaims(token) 242 | if err != nil { 243 | return err 244 | } 245 | if cla == nil { 246 | return errors.New("del user token, reids cache is nil") 247 | } 248 | 249 | if e := ra.delUserTokenPrefixToken(cla.roleType(), cla.Id, token); e != nil { 250 | return e 251 | } 252 | 253 | if e := ra.delTokenCache(token); e != nil { 254 | return e 255 | } 256 | return nil 257 | } 258 | 259 | // delUserTokenPrefixToken 260 | func (ra *RedisAuth) delUserTokenPrefixToken(roleType RoleType, id, token string) error { 261 | _, err := 
ra.Client.SRem(context.Background(), getPrefixKey(roleType, id), token).Result() 262 | if err != nil { 263 | return fmt.Errorf("del user token cache redis srem %w", err) 264 | } 265 | return nil 266 | } 267 | 268 | // delTokenCache 269 | func (ra *RedisAuth) delTokenCache(token string) error { 270 | sKey2 := BindUserPrefix + token 271 | _, err := ra.Client.Del(context.Background(), sKey2).Result() 272 | if err != nil { 273 | return fmt.Errorf("del user token cache redis del2 %w", err) 274 | } 275 | 276 | sKey3 := TokenPrefix + token 277 | _, err = ra.Client.Del(context.Background(), sKey3).Result() 278 | if err != nil { 279 | return fmt.Errorf("del user token cache redis del3 %w", err) 280 | } 281 | 282 | return nil 283 | } 284 | 285 | // CleanCache 286 | func (ra *RedisAuth) CleanCache(roleType RoleType, userId string) error { 287 | allTokens, err := ra.getUserTokens(roleType, userId) 288 | if err != nil { 289 | return fmt.Errorf("clean user token cache redis smembers %w", err) 290 | } 291 | _, err = ra.Client.Del(context.Background(), getPrefixKey(roleType, userId)).Result() 292 | if err != nil { 293 | return fmt.Errorf("clean user token cache redis del %w", err) 294 | } 295 | 296 | for _, token := range allTokens { 297 | err = ra.delTokenCache(token) 298 | if err != nil { 299 | return err 300 | } 301 | } 302 | return nil 303 | } 304 | 305 | // IsRole 306 | func (ra *RedisAuth) IsRole(token string, roleType RoleType) (bool, error) { 307 | rcc, err := ra.GetClaims(token) 308 | if err != nil { 309 | return false, fmt.Errorf("get User's infomation return error: %w", err) 310 | } 311 | return rcc.roleType() == roleType, nil 312 | } 313 | 314 | // IsSuperAdmin 315 | func (ra *RedisAuth) IsSuperAdmin(token string) bool { 316 | rcc, err := ra.GetClaims(token) 317 | if err != nil { 318 | return false 319 | } 320 | return rcc.SuperAdmin 321 | } 322 | 323 | // Close 324 | func (ra *RedisAuth) Close() { 325 | ra.Client.Close() 326 | } 327 | 
-------------------------------------------------------------------------------- /auth2/redis_test.go: -------------------------------------------------------------------------------- 1 | package auth2 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "os" 7 | "sync" 8 | "testing" 9 | "time" 10 | 11 | "github.com/go-redis/redis/v8" 12 | ) 13 | 14 | var ( 15 | wg sync.WaitGroup 16 | options = &redis.UniversalOptions{ 17 | DB: 1, 18 | Addrs: []string{"127.0.0.1:6379"}, 19 | Password: os.Getenv("redisPwd"), // 20 | PoolSize: 10, 21 | IdleTimeout: 300 * time.Second, 22 | // Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) { 23 | // conn, err := net.Dial(network, addr) 24 | // if err == nil { 25 | // go func() { 26 | // time.Sleep(5 * time.Second) 27 | // conn.Close() 28 | // }() 29 | // } 30 | // return conn, err 31 | // }, 32 | } 33 | 34 | rToken = "TVRReU1EVTFOek13TmpFd09UWXlPRFF4TmcuTWpBeU1TMHdOeTB5T1ZRd09Ub3pNRG95T1Nzd09Eb3dNQQ.MTQyMDU1NzMwNjEwOTYyODrtrt" 35 | logTypeWeb = NewClaims( 36 | &Agent{ 37 | Id: uint(121321), 38 | Username: "username", 39 | SuperAdmin: true, 40 | AuthIds: []string{"999"}, 41 | RoleType: RoleAdmin, 42 | LoginType: LoginTypeWeb, 43 | AuthType: AuthPwd, 44 | ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(), 45 | }, 46 | ) 47 | ruserKey = getPrefixKey(logTypeWeb.roleType(), logTypeWeb.Id) 48 | ) 49 | 50 | func TestRedisGenerateToken(t *testing.T) { 51 | redisAuth, err := NewRedis(redis.NewUniversalClient(options)) 52 | if err != nil { 53 | t.Fatal(err.Error()) 54 | } 55 | defer redisAuth.CleanCache(logTypeWeb.roleType(), logTypeWeb.Id) 56 | token, expiresIn, err := redisAuth.Generate(logTypeWeb) 57 | if err != nil { 58 | t.Fatalf("generate token %v", err) 59 | } 60 | if token == "" { 61 | t.Error("generate token is empty") 62 | } 63 | 64 | if expiresIn != logTypeWeb.ExpiresAt { 65 | t.Errorf("generate token expires want %v but get %v", logTypeWeb.ExpiresAt, expiresIn) 66 | } 67 | cc, err := 
redisAuth.GetClaims(token) 68 | if err != nil { 69 | t.Fatalf("get custom claims %v", err) 70 | } 71 | 72 | if cc.Id != logTypeWeb.Id { 73 | t.Errorf("get custom id want %v but get %v", logTypeWeb.Id, cc.Id) 74 | } 75 | if cc.Username != logTypeWeb.Username { 76 | t.Errorf("get custom username want %v but get %v", logTypeWeb.Username, cc.Username) 77 | } 78 | if cc.AuthId != logTypeWeb.AuthId { 79 | t.Errorf("get custom authority_id want %v but get %v", logTypeWeb.AuthId, cc.AuthId) 80 | } 81 | if cc.RoleType != logTypeWeb.RoleType { 82 | t.Errorf("get custom authority_type want %v but get %v", logTypeWeb.RoleType, cc.RoleType) 83 | } 84 | if cc.LoginType != logTypeWeb.LoginType { 85 | t.Errorf("get custom login_type want %v but get %v", logTypeWeb.LoginType, cc.LoginType) 86 | } 87 | if cc.AuthType != logTypeWeb.AuthType { 88 | t.Errorf("get custom auth_type want %v but get %v", logTypeWeb.AuthType, cc.AuthType) 89 | } 90 | if cc.CreationTime != logTypeWeb.CreationTime { 91 | t.Errorf("get custom creation_data want %v but get %v", logTypeWeb.CreationTime, cc.CreationTime) 92 | } 93 | if cc.ExpiresAt != logTypeWeb.ExpiresAt { 94 | t.Errorf("get custom expires_at want %v but get %v", logTypeWeb.ExpiresAt, cc.ExpiresAt) 95 | } 96 | 97 | if uTokens, err := redisAuth.Client.SMembers(context.Background(), ruserKey).Result(); err != nil { 98 | t.Fatalf("user prefix value get %s", err) 99 | } else { 100 | if len(uTokens) == 0 || uTokens[0] != token { 101 | t.Errorf("user prefix value want %v but get %v", ruserKey, uTokens) 102 | } 103 | } 104 | bindKey := BindUserPrefix + token 105 | key, err := redisAuth.Client.Get(context.Background(), bindKey).Result() 106 | if err != nil { 107 | t.Fatal(err) 108 | } 109 | if key != ruserKey { 110 | t.Errorf("bind user prefix value want %v but get %v", ruserKey, key) 111 | } 112 | } 113 | 114 | func TestRedisToCache(t *testing.T) { 115 | redisAuth, err := NewRedis(redis.NewUniversalClient(options)) 116 | if err != nil { 117 | 
t.Fatal(err.Error()) 118 | } 119 | defer redisAuth.Client.Del(context.Background(), TokenPrefix+rToken) 120 | if err := redisAuth.toCache(rToken, logTypeWeb); err != nil { 121 | t.Fatalf("generate token %v", err) 122 | } 123 | cc, err := redisAuth.GetClaims(rToken) 124 | if err != nil { 125 | t.Fatalf("get custom claims %v", err) 126 | } 127 | 128 | if cc.Id != logTypeWeb.Id { 129 | t.Errorf("get custom id want %v but get %v", logTypeWeb.Id, cc.Id) 130 | } 131 | if cc.Username != logTypeWeb.Username { 132 | t.Errorf("get custom username want %v but get %v", logTypeWeb.Username, cc.Username) 133 | } 134 | if cc.AuthId != logTypeWeb.AuthId { 135 | t.Errorf("get custom authority_id want %v but get %v", logTypeWeb.AuthId, cc.AuthId) 136 | } 137 | if cc.RoleType != logTypeWeb.RoleType { 138 | t.Errorf("get custom authority_type want %v but get %v", logTypeWeb.RoleType, cc.RoleType) 139 | } 140 | if cc.LoginType != logTypeWeb.LoginType { 141 | t.Errorf("get custom login_type want %v but get %v", logTypeWeb.LoginType, cc.LoginType) 142 | } 143 | if cc.AuthType != logTypeWeb.AuthType { 144 | t.Errorf("get custom auth_type want %v but get %v", logTypeWeb.AuthType, cc.AuthType) 145 | } 146 | if cc.CreationTime != logTypeWeb.CreationTime { 147 | t.Errorf("get custom creation_data want %v but get %v", logTypeWeb.CreationTime, cc.CreationTime) 148 | } 149 | if cc.ExpiresAt != logTypeWeb.ExpiresAt { 150 | t.Errorf("get custom expires_at want %v but get %v", logTypeWeb.ExpiresAt, cc.ExpiresAt) 151 | } 152 | } 153 | 154 | func TestRedisDelUserTokenCache(t *testing.T) { 155 | cc := NewClaims( 156 | &Agent{ 157 | Id: uint(221), 158 | Username: "username", 159 | SuperAdmin: true, 160 | AuthIds: []string{"999"}, 161 | RoleType: RoleAdmin, 162 | LoginType: LoginTypeWeb, 163 | AuthType: AuthPwd, 164 | ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(), 165 | }, 166 | ) 167 | redisAuth, err := NewRedis(redis.NewUniversalClient(options)) 168 | if err != nil { 169 | t.Fatal(err.Error()) 
170 | } 171 | defer redisAuth.CleanCache(cc.roleType(), cc.Id) 172 | token, _, _ := redisAuth.Generate(cc) 173 | if token == "" { 174 | t.Error("generate token is empty") 175 | } 176 | 177 | if err := redisAuth.DelCache(token); err != nil { 178 | t.Fatalf("del user token cache %v", err) 179 | } 180 | _, err = redisAuth.GetClaims(token) 181 | if !errors.Is(err, ErrEmptyToken) { 182 | t.Fatalf("get custom claims err want '%v' but get '%v'", ErrEmptyToken, err) 183 | } 184 | 185 | if uTokens, err := redisAuth.Client.SMembers(context.Background(), UserPrefix+cc.Id).Result(); err != nil { 186 | t.Fatalf("user prefix value wantget %v", err) 187 | } else if len(uTokens) != 0 { 188 | t.Errorf("user prefix value want empty but get %+v", uTokens) 189 | } 190 | bindKey := BindUserPrefix + token 191 | key, _ := redisAuth.Client.Get(context.Background(), bindKey).Result() 192 | if key != "" { 193 | t.Errorf("bind user prefix value want empty but get %v", key) 194 | } 195 | } 196 | 197 | func TestRedisIsUserTokenOver(t *testing.T) { 198 | cc := NewClaims( 199 | &Agent{ 200 | Id: uint(3232), 201 | Username: "username", 202 | SuperAdmin: true, 203 | AuthIds: []string{"999"}, 204 | RoleType: RoleAdmin, 205 | LoginType: LoginTypeWeb, 206 | AuthType: AuthPwd, 207 | ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(), 208 | }, 209 | ) 210 | redisAuth, err := NewRedis(redis.NewUniversalClient(options)) 211 | if err != nil { 212 | t.Fatal(err.Error()) 213 | } 214 | defer redisAuth.CleanCache(cc.roleType(), cc.Id) 215 | if err := redisAuth.SetLimit(10); err != nil { 216 | t.Fatalf("set user token max count %v", err) 217 | } 218 | var wantTokenLen int64 = 0 219 | for i := LoginTypeWeb; i <= LoginTypeWx; i++ { 220 | cc.setLoginType(int(i)) 221 | wg.Add(1) 222 | wantTokenLen++ 223 | go func(i LoginType) { 224 | redisAuth.Generate(cc) 225 | wg.Done() 226 | }(i) 227 | wg.Wait() 228 | } 229 | isOver, err := redisAuth.isUserTokenOver(cc.roleType(), cc.Id) 230 | if err != nil { 231 | 
t.Fatalf("is user token over get %v", err) 232 | } 233 | if isOver { 234 | t.Error("user token want not over but get over") 235 | } 236 | count, err := redisAuth.getUserTokenCount(cc.roleType(), cc.Id) 237 | if err != nil { 238 | t.Fatalf("user token count get %v", err) 239 | } 240 | if count != wantTokenLen { 241 | t.Errorf("user token count want %v but get %v", wantTokenLen, count) 242 | } 243 | } 244 | 245 | func TestRedisSetUserTokenMaxCount(t *testing.T) { 246 | redisAuth, err := NewRedis(redis.NewUniversalClient(options)) 247 | if err != nil { 248 | t.Fatal(err.Error()) 249 | } 250 | defer redisAuth.CleanCache(logTypeWeb.roleType(), logTypeWeb.Id) 251 | if err := redisAuth.SetLimit(10); err != nil { 252 | t.Fatalf("set user token max count %v", err) 253 | } 254 | for i := LoginTypeWeb; i <= LoginTypeWx; i++ { 255 | wg.Add(1) 256 | logTypeWeb.setLoginType(int(i)) 257 | go func(i LoginType) { 258 | redisAuth.Generate(logTypeWeb) 259 | wg.Done() 260 | }(i) 261 | wg.Wait() 262 | } 263 | if err := redisAuth.SetLimit(3); err != nil { 264 | t.Fatalf("set user token max count %v", err) 265 | } 266 | count := redisAuth.getUserTokenLimit() 267 | if count != 3 { 268 | t.Errorf("user token max count want %v but get %v", 3, count) 269 | } 270 | isOver, err := redisAuth.isUserTokenOver(logTypeWeb.roleType(), logTypeWeb.Id) 271 | if err != nil { 272 | t.Fatalf("is user token over get %v", err) 273 | } 274 | if !isOver { 275 | t.Error("user token want over but get not over") 276 | } 277 | } 278 | func TestRedisCleanUserTokenCache(t *testing.T) { 279 | redisAuth, err := NewRedis(redis.NewUniversalClient(options)) 280 | if err != nil { 281 | t.Fatal(err.Error()) 282 | } 283 | defer redisAuth.CleanCache(logTypeWeb.roleType(), logTypeWeb.Id) 284 | for i := LoginTypeWeb; i <= LoginTypeWx; i++ { 285 | wg.Add(1) 286 | logTypeWeb.setLoginType(int(i)) 287 | go func(i LoginType) { 288 | redisAuth.Generate(logTypeWeb) 289 | wg.Done() 290 | }(i) 291 | wg.Wait() 292 | } 293 | if err := 
redisAuth.CleanCache(logTypeWeb.roleType(), logTypeWeb.Id); err != nil { 294 | t.Fatalf("clear user token cache %v", err) 295 | } 296 | count, err := redisAuth.getUserTokenCount(logTypeWeb.roleType(), logTypeWeb.Id) 297 | if err != nil { 298 | t.Fatalf("user token count get %v", err) 299 | } 300 | if count != 0 { 301 | t.Error("user token count want 0 but get not 0") 302 | } 303 | } 304 | 305 | func TestRedisGetMultiClaims(t *testing.T) { 306 | redisAuth, err := NewRedis(redis.NewUniversalClient(options)) 307 | if err != nil { 308 | t.Fatal(err.Error()) 309 | } 310 | defer redisAuth.CleanCache(logTypeWeb.roleType(), logTypeWeb.Id) 311 | logTypeWeb.LoginType = 3 312 | token, _, err := redisAuth.Generate(logTypeWeb) 313 | if err != nil { 314 | t.Fatalf("get custom claims %v", err) 315 | } 316 | for i := LoginTypeWeb; i <= LoginTypeWx; i++ { 317 | wg.Add(1) 318 | logTypeWeb.setLoginType(int(i)) 319 | go func(i LoginType) { 320 | redisAuth.Generate(logTypeWeb) 321 | wg.Done() 322 | }(i) 323 | wg.Wait() 324 | } 325 | for i := 0; i < 4; i++ { 326 | go func() { 327 | _, err := redisAuth.GetClaims(token) 328 | if err != nil { 329 | t.Errorf("get custom claims %v", err) 330 | } 331 | }() 332 | } 333 | time.Sleep(3 * time.Second) 334 | } 335 | 336 | func TestRedisGetUserTokens(t *testing.T) { 337 | cc := NewClaims( 338 | &Agent{ 339 | Id: uint(121321), 340 | Username: "username", 341 | SuperAdmin: true, 342 | AuthIds: []string{"999"}, 343 | RoleType: RoleAdmin, 344 | LoginType: LoginTypeWeb, 345 | AuthType: AuthPwd, 346 | ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(), 347 | }, 348 | ) 349 | redisAuth, err := NewRedis(redis.NewUniversalClient(options)) 350 | if err != nil { 351 | t.Fatal(err.Error()) 352 | } 353 | defer redisAuth.CleanCache(cc.roleType(), cc.Id) 354 | defer redisAuth.CleanCache(logTypeWeb.roleType(), logTypeWeb.Id) 355 | token, _, err := redisAuth.Generate(logTypeWeb) 356 | if err != nil { 357 | t.Fatalf("get user tokens by claims generate token %v 
\n", err) 358 | } 359 | 360 | if token == "" { 361 | t.Fatal("get user tokens by claims generate token is empty \n") 362 | } 363 | 364 | token3232, _, err := redisAuth.Generate(cc) 365 | if err != nil { 366 | t.Fatalf("get user tokens by claims generate token %v \n", err) 367 | } 368 | 369 | if token3232 == "" { 370 | t.Fatal("get user tokens by claims generate token is empty \n") 371 | } 372 | 373 | tokens, err := redisAuth.getUserTokens(logTypeWeb.roleType(), logTypeWeb.Id) 374 | if err != nil { 375 | t.Fatalf("get user tokens by claims %v", err) 376 | } 377 | wantTokenLen := 2 378 | if len(tokens) != wantTokenLen { 379 | t.Fatalf("get user tokens by claims want len %d but get %d", wantTokenLen, len(tokens)) 380 | } 381 | } 382 | 383 | func TestRedisGetTokenByClaims(t *testing.T) { 384 | cc := NewClaims( 385 | &Agent{ 386 | Id: uint(3232), 387 | Username: "username", 388 | SuperAdmin: true, 389 | AuthIds: []string{"999"}, 390 | RoleType: RoleAdmin, 391 | LoginType: LoginTypeWeb, 392 | AuthType: AuthPwd, 393 | ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(), 394 | }, 395 | ) 396 | redisAuth, err := NewRedis(redis.NewUniversalClient(options)) 397 | if err != nil { 398 | t.Fatal(err.Error()) 399 | } 400 | defer redisAuth.CleanCache(cc.roleType(), cc.Id) 401 | defer redisAuth.CleanCache(logTypeWeb.roleType(), logTypeWeb.Id) 402 | 403 | token, _, err := redisAuth.Generate(logTypeWeb) 404 | if err != nil { 405 | t.Fatalf("get token by claims generate token %v \n", err) 406 | } 407 | 408 | if token == "" { 409 | t.Fatal("get token by claims generate token is empty \n") 410 | } 411 | 412 | token3232, _, err := redisAuth.Generate(cc) 413 | if err != nil { 414 | t.Fatalf("get token by claims generate token %v \n", err) 415 | } 416 | 417 | if token3232 == "" { 418 | t.Fatal("get token by claims generate token is empty \n") 419 | } 420 | 421 | userToken, err := redisAuth.Token(logTypeWeb) 422 | if err != nil { 423 | t.Fatalf("get token by claims %v", err) 424 | } 425 | 
426 | if token != userToken { 427 | t.Errorf("get token by claims token want %s but get %s", token, userToken) 428 | } 429 | if token == token3232 { 430 | t.Errorf("get token by claims token not want %s but get %s", token3232, token) 431 | } 432 | 433 | } 434 | func TestRedisGetMultiClaimses(t *testing.T) { 435 | redisAuth, err := NewRedis(redis.NewUniversalClient(options)) 436 | if err != nil { 437 | t.Fatal(err.Error()) 438 | } 439 | defer redisAuth.CleanCache(logTypeWeb.roleType(), logTypeWeb.Id) 440 | wantTokenLen := 0 441 | for i := LoginTypeWeb; i <= LoginTypeWx; i++ { 442 | wg.Add(1) 443 | wantTokenLen++ 444 | logTypeWeb.setLoginType(int(i)) 445 | go func(i LoginType) { 446 | redisAuth.Generate(logTypeWeb) 447 | wg.Done() 448 | }(i) 449 | wg.Wait() 450 | } 451 | userTokens, err := redisAuth.getUserTokens(logTypeWeb.roleType(), logTypeWeb.Id) 452 | if err != nil { 453 | t.Fatal("get custom claimses generate token is empty \n") 454 | } 455 | clas, err := redisAuth.getMultiClaimses(userTokens) 456 | if err != nil { 457 | t.Fatalf("get custom claimses %v", err) 458 | } 459 | 460 | if len(userTokens) != wantTokenLen { 461 | t.Fatalf("get custom claimses want len %d but get %d", wantTokenLen, len(userTokens)) 462 | } 463 | 464 | if len(clas) != wantTokenLen { 465 | t.Fatalf("get custom claimses want len %d but get %d", wantTokenLen, len(clas)) 466 | } 467 | 468 | } 469 | -------------------------------------------------------------------------------- /auth2/token.go: -------------------------------------------------------------------------------- 1 | package auth2 2 | 3 | import ( 4 | "bytes" 5 | "encoding/base64" 6 | "fmt" 7 | 8 | "github.com/bwmarrin/snowflake" 9 | uuid "github.com/satori/go.uuid" 10 | "github.com/snowlyg/helper/dir" 11 | ) 12 | 13 | var ( 14 | sep = []byte(".") 15 | pad = []byte("=") 16 | padStr = string(pad) 17 | ) 18 | 19 | // getToken 20 | func getToken() (string, error) { 21 | v4 := uuid.NewV4() 22 | node, err := snowflake.NewNode(1) 23 | 
if err != nil { 24 | return "", fmt.Errorf("token: get token %w", err) 25 | } 26 | 27 | // 混入两个时间,防止并发token重复 28 | nodeBytes, _ := dir.Md5Byte(Base64Encode(node.Generate().Bytes())) 29 | uuidBytes, _ := dir.Md5Byte(Base64Encode(joinParts(Base64Encode(v4.Bytes()), []byte(nodeBytes)))) 30 | token := joinParts(Base64Encode([]byte(uuidBytes)), Base64Encode([]byte(nodeBytes))) 31 | return string(Base64Encode([]byte(token))), nil 32 | } 33 | 34 | // joinParts 35 | func joinParts(parts ...[]byte) []byte { 36 | return bytes.Join(parts, sep) 37 | } 38 | 39 | // Base64Encode 40 | func Base64Encode(src []byte) []byte { 41 | buf := make([]byte, base64.URLEncoding.EncodedLen(len(src))) 42 | base64.URLEncoding.Encode(buf, src) 43 | 44 | return bytes.TrimRight(buf, padStr) // JWT: no trailing '='. 45 | } 46 | 47 | // Base64Decode decodes "src" to jwt base64 url format. 48 | // We could use the base64.RawURLEncoding but the below is a bit faster. 49 | func Base64Decode(src []byte) ([]byte, error) { 50 | if n := len(src) % 4; n > 0 { 51 | // JWT: Because of no trailing '=' let's suffix it 52 | // with the correct number of those '=' before decoding. 53 | src = append(src, bytes.Repeat(pad, 4-n)...) 
54 | } 55 | 56 | buf := make([]byte, base64.URLEncoding.DecodedLen(len(src))) 57 | n, err := base64.URLEncoding.Decode(buf, src) 58 | return buf[:n], err 59 | } 60 | -------------------------------------------------------------------------------- /auth2/token_test.go: -------------------------------------------------------------------------------- 1 | package auth2 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/snowlyg/helper/arr" 8 | ) 9 | 10 | func TestGetToken(t *testing.T) { 11 | token, err := getToken() 12 | if err != nil { 13 | t.Error(err) 14 | } 15 | if token == "" { 16 | t.Error("Generate token is fail.") 17 | } 18 | if token1, err := getToken(); err != nil { 19 | t.Error(err) 20 | } else if token == "" { 21 | t.Error("Generate token is fail.") 22 | } else if token == token1 { 23 | t.Errorf("token[%s] token1[%s] is repeat", token, token1) 24 | } 25 | } 26 | 27 | func TestJoinParts(t *testing.T) { 28 | afterJoin := joinParts([]byte("header"), []byte("footer")) 29 | want := []byte("header.footer") 30 | if bytes.Compare(afterJoin, want) > 0 { 31 | t.Errorf("Join parts want %s but get %s", string(want), string(afterJoin)) 32 | } 33 | } 34 | 35 | func TestBase64Encode(t *testing.T) { 36 | want := []byte("header") 37 | baseEncode := Base64Encode(want) 38 | afterDecode, err := Base64Decode(baseEncode) 39 | if err != nil { 40 | t.Error(err) 41 | } 42 | if bytes.Compare(afterDecode, want) > 0 { 43 | t.Errorf("Base64Encode and Base64Decode not effect") 44 | } 45 | } 46 | 47 | func BenchmarkGetToken(b *testing.B) { 48 | b.Run("Benchmark test get token", func(b *testing.B) { 49 | tokens := Token{CheckArrayType: *arr.NewCheckArrayType(b.N)} 50 | for i := 0; i < b.N; i++ { 51 | token, err := getToken() 52 | if err != nil { 53 | b.Error(err) 54 | } 55 | if token == "" { 56 | b.Error("Generate token is fail.") 57 | } 58 | if tokens.Check(token) { 59 | b.Fatalf("token is repeat") 60 | } 61 | tokens.Add(token) 62 | } 63 | }) 64 | } 65 | 66 | type Token 
struct { 67 | arr.CheckArrayType 68 | } 69 | -------------------------------------------------------------------------------- /auth2/verifier.go: -------------------------------------------------------------------------------- 1 | package auth2 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | "strconv" 7 | "strings" 8 | 9 | "github.com/gin-gonic/gin" 10 | ) 11 | 12 | const ( 13 | claimsContextKey = "gin.auth2.claims" 14 | verifiedTokenContextKey = "gin.auth2.token" 15 | ) 16 | 17 | // Get returns the claims decoded by a verifier. 18 | func Get(ctx *gin.Context) *Claims { 19 | v, b := ctx.Get(claimsContextKey) 20 | if !b { 21 | log.Println("verifier: key not exist") 22 | return nil 23 | } 24 | tok, ok := v.(*Claims) 25 | if !ok { 26 | log.Println("verifier: object not claims") 27 | return nil 28 | } 29 | return tok 30 | } 31 | 32 | // GetType 33 | func GetType(ctx *gin.Context) RoleType { 34 | if v := Get(ctx); v != nil { 35 | return v.roleType() 36 | } 37 | return 0 38 | } 39 | 40 | // GetAuthId 41 | func GetAuthId(ctx *gin.Context) []string { 42 | if v := Get(ctx); v != nil { 43 | return strings.Split(v.AuthId, AuthTypeSplit) 44 | } 45 | return nil 46 | } 47 | 48 | // GetUserId 49 | func GetUserId(ctx *gin.Context) uint { 50 | v := Get(ctx) 51 | if v == nil { 52 | return 0 53 | } 54 | id, err := strconv.Atoi(v.Id) 55 | if err != nil { 56 | return 0 57 | } 58 | return uint(id) 59 | } 60 | 61 | // IsSuperAdmin 62 | func IsSuperAdmin(ctx *gin.Context) bool { 63 | v := Get(ctx) 64 | if v == nil { 65 | log.Println("verifier: Claim is nil") 66 | return false 67 | } 68 | return v.SuperAdmin 69 | } 70 | 71 | // GetUsername 72 | func GetUsername(ctx *gin.Context) string { 73 | if v := Get(ctx); v != nil { 74 | return v.Username 75 | } 76 | return "" 77 | } 78 | 79 | // GetCreationDate 80 | func GetCreationDate(ctx *gin.Context) int64 { 81 | if v := Get(ctx); v != nil { 82 | return v.CreationTime 83 | } 84 | return 0 85 | } 86 | 87 | // GetExpiresIn 88 | func GetExpiresIn(ctx 
*gin.Context) int64 { 89 | if v := Get(ctx); v != nil { 90 | return v.ExpiresAt 91 | } 92 | return 0 93 | } 94 | 95 | func GetVerifiedToken(ctx *gin.Context) []byte { 96 | v, b := ctx.Get(verifiedTokenContextKey) 97 | if !b { 98 | return nil 99 | } 100 | if tok, ok := v.([]byte); ok { 101 | return tok 102 | } 103 | return nil 104 | } 105 | 106 | func IsRole(ctx *gin.Context, roleType RoleType) bool { 107 | v := GetVerifiedToken(ctx) 108 | if v == nil { 109 | return false 110 | } 111 | b, err := AuthAgent.IsRole(string(v), roleType) 112 | if err != nil { 113 | return false 114 | } 115 | return b 116 | } 117 | 118 | func IsAdmin(ctx *gin.Context) bool { 119 | return IsRole(ctx, RoleAdmin) 120 | } 121 | 122 | type Verifier struct { 123 | Extractors []TokenExtractor 124 | Validators []TokenValidator 125 | ErrorHandler func(ctx *gin.Context, err error) 126 | } 127 | 128 | func NewVerifier(validators ...TokenValidator) *Verifier { 129 | return &Verifier{ 130 | Extractors: []TokenExtractor{FromHeader, FromQuery}, 131 | ErrorHandler: func(ctx *gin.Context, err error) { 132 | ctx.AbortWithError(http.StatusUnauthorized, err) 133 | }, 134 | Validators: validators, 135 | } 136 | } 137 | 138 | // Invalidate 139 | func (v *Verifier) invalidate(ctx *gin.Context) { 140 | if verifiedToken := GetVerifiedToken(ctx); verifiedToken != nil { 141 | ctx.Set(claimsContextKey, "") 142 | ctx.Set(verifiedTokenContextKey, "") 143 | } 144 | } 145 | 146 | // RequestToken extracts the token from the 147 | func (v *Verifier) RequestToken(ctx *gin.Context) (token string) { 148 | for _, extract := range v.Extractors { 149 | if token = extract(ctx); token != "" { 150 | break // ok we found it. 
151 | } 152 | } 153 | return 154 | } 155 | 156 | func (v *Verifier) VerifyToken(token []byte, validators ...TokenValidator) ([]byte, *Claims, error) { 157 | if len(token) == 0 { 158 | return nil, nil, ErrEmptyToken 159 | } 160 | var err error 161 | for _, validator := range validators { 162 | // A token validator can skip the builtin validation and return a nil error, 163 | // in that case the previous error is skipped. 164 | if err = validator.Validater(token, err); err != nil { 165 | break 166 | } 167 | } 168 | if err != nil { 169 | // Exit on parsing standard claims error(when Plain is missing) or standard claims validation error or custom validators. 170 | return nil, nil, err 171 | } 172 | rcc, err := AuthAgent.GetClaims(string(token)) 173 | if err != nil { 174 | return nil, nil, err 175 | } 176 | err = rcc.Valid() 177 | if err != nil { 178 | return nil, nil, err 179 | } 180 | return token, rcc, nil 181 | } 182 | 183 | func (v *Verifier) Verify(validators ...TokenValidator) gin.HandlerFunc { 184 | return func(ctx *gin.Context) { 185 | token := []byte(v.RequestToken(ctx)) 186 | verifiedToken, rcc, err := v.VerifyToken(token, validators...) 
187 | if err != nil { 188 | v.invalidate(ctx) 189 | v.ErrorHandler(ctx, err) 190 | return 191 | } 192 | ctx.Set(claimsContextKey, rcc) 193 | ctx.Set(verifiedTokenContextKey, verifiedToken) 194 | ctx.Next() 195 | } 196 | } 197 | -------------------------------------------------------------------------------- /conf/auth.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "fmt" 5 | "path/filepath" 6 | 7 | "github.com/casbin/casbin/v2" 8 | gormadapter "github.com/casbin/gorm-adapter/v3" 9 | "github.com/snowlyg/helper/dir" 10 | "gorm.io/gorm" 11 | ) 12 | 13 | const CasbinName = "rbac_model.conf" 14 | 15 | // Remove del config file 16 | func (conf *Conf) RemoveRbacModel() error { 17 | p := conf.casbinFilePath() 18 | if filepath.Base(p) != CasbinName { 19 | return nil 20 | } 21 | if dir.IsExist(p) && dir.IsFile(p) { 22 | return dir.Remove(p) 23 | } 24 | return nil 25 | } 26 | 27 | // casbinFilePath 28 | func (conf *Conf) casbinFilePath() string { 29 | return filepath.Join(dir.GetCurrentAbPath(), ConfigDir, CasbinName) 30 | } 31 | 32 | // newRbacModel initialize casbin's config file as rbac_model.conf name 33 | func (conf *Conf) newRbacModel() { 34 | if dir.IsExist(conf.casbinFilePath()) { 35 | // casbin rbac_model.conf file 36 | // log.Printf("rbac_model.conf file is existed.") 37 | return 38 | } 39 | 40 | var rbacModelConf = []byte(`[request_definition] 41 | r = sub, obj, act 42 | 43 | [policy_definition] 44 | p = sub, obj, act 45 | 46 | [role_definition] 47 | g = _, _ 48 | 49 | [policy_effect] 50 | e = some(where (p.eft == allow)) 51 | 52 | [matchers] 53 | m = g(r.sub, p.sub) && keyMatch2(r.obj, p.obj) && (r.act == p.act || p.act == "*")`) 54 | if _, err := dir.WriteBytes(conf.casbinFilePath(), rbacModelConf); err != nil { 55 | panic(fmt.Errorf("initialize casbin rbac_model.conf file return error: %w ", err)) 56 | } 57 | } 58 | 59 | // getEnforcer get casbin.Enforcer 60 | func (conf *Conf) 
GetEnforcer(db *gorm.DB) (*casbin.Enforcer, error) { 61 | if db == nil { 62 | return nil, gorm.ErrInvalidDB 63 | } 64 | c, err := gormadapter.NewAdapterByDBUseTableName(db, "", "casbin_rule") // Your driver and data source. 65 | if err != nil { 66 | return nil, err 67 | } 68 | enforcer, err := casbin.NewEnforcer(conf.casbinFilePath(), c) 69 | if err != nil { 70 | return nil, err 71 | } 72 | if err = enforcer.LoadPolicy(); err != nil { 73 | return nil, err 74 | } 75 | return enforcer, nil 76 | } 77 | -------------------------------------------------------------------------------- /conf/auth_test.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/snowlyg/helper/dir" 7 | ) 8 | 9 | func TestNewRbacModel(t *testing.T) { 10 | conf := new(Conf) 11 | if conf.casbinFilePath() == "" { 12 | t.Errorf("rbac model path:%s empty", conf.casbinFilePath()) 13 | } 14 | conf.newRbacModel() 15 | if !dir.IsExist(CasbinName) { 16 | t.Error("rbac_model.conf not exist after conf not init and new rbac model") 17 | } 18 | conf.RemoveRbacModel() 19 | if dir.IsExist(CasbinName) { 20 | t.Error("rbac_model.conf exist after conf not init and new rbac model") 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /conf/config.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "log" 7 | "os" 8 | "strconv" 9 | "strings" 10 | 11 | "github.com/gin-gonic/gin" 12 | "github.com/spf13/viper" 13 | ) 14 | 15 | const ( 16 | ConfigType = "json" // config's type 17 | ConfigDir = "config" // config's dir 18 | ) 19 | 20 | var ( 21 | mysqlAddrKey = "IRIS_ADMIN_MYSQL_ADDR" 22 | mysqlPwdKey = "IRIS_ADMIN_MYSQL_PWD" 23 | mysqlNameKey = "IRIS_ADMIN_MYSQL_NAME" 24 | webAddrKey = "IRIS_ADMIN_WEB_ADDR" 25 | ) 26 | 27 | func NewConf() *Conf { 28 | c := &Conf{ 29 | Locale: 
"zh", 30 | FileMaxSize: 1024, // upload file size limit 1024M 31 | SessionTimeout: 172800, // session timeout after 4 months 32 | CorsConf: CorsConf{ 33 | AccessOrigin: "*", 34 | AccessHeaders: "Content-Type,AccessToken,X-CSRF-Token, Authorization, Token,X-Token,X-User-Id", 35 | AccessMethods: "POST,GET,OPTIONS,DELETE,PUT", 36 | AccessExposeHeaders: "Content-Length,Access-Control-Allow-Origin,Access-Control-Allow-Headers,Content-Type", 37 | AccessCredentials: "true", 38 | }, 39 | Except: Route{ 40 | Uri: "", 41 | Method: "", 42 | }, 43 | System: System{ 44 | Tls: false, 45 | GinMode: gin.ReleaseMode, 46 | Level: "debug", 47 | Addr: "127.0.0.1:8080", 48 | TimeFormat: "2006-01-02 15:04:05", 49 | }, 50 | Limit: Limit{ 51 | Disable: true, 52 | Limit: 0, 53 | Burst: 5, 54 | }, 55 | Captcha: Captcha{ 56 | KeyLong: 0, 57 | ImgWidth: 240, 58 | ImgHeight: 80, 59 | }, 60 | Mysql: &Mysql{ 61 | Path: "127.0.0.1:3306", 62 | Config: "charset=utf8mb4&parseTime=True&loc=Local", 63 | DbName: "iris-admin", 64 | Username: "root", 65 | Password: "", 66 | MaxIdleConns: 0, 67 | MaxOpenConns: 0, 68 | LogMode: false, 69 | LogZap: "error", 70 | }, 71 | Operate: Operate{ 72 | Except: Route{ 73 | Uri: "api/v1/upload;api/v1/upload", 74 | Method: "post;put", 75 | }, 76 | Include: Route{ 77 | Uri: "api/v1/menus", 78 | Method: "get", 79 | }, 80 | }, 81 | } 82 | mysqlAddr := strings.TrimSpace(os.Getenv(mysqlAddrKey)) 83 | mysqlPwd := strings.TrimSpace(os.Getenv(mysqlPwdKey)) 84 | mysqlName := strings.TrimSpace(os.Getenv(mysqlNameKey)) 85 | webAddr := strings.TrimSpace(os.Getenv(webAddrKey)) 86 | if mysqlAddr != "" { 87 | c.Mysql.Path = mysqlAddr 88 | } 89 | if mysqlPwd != "" { 90 | c.Mysql.Password = mysqlPwd 91 | } 92 | if mysqlName != "" { 93 | c.Mysql.Username = mysqlName 94 | } 95 | if webAddr != "" { 96 | c.System.Addr = webAddr 97 | } 98 | if c.Mysql.Path == "" || c.Mysql.Password == "" || c.Mysql.DbName == "" { 99 | log.Printf("mysql driver config empty,you can set env %s %s %s to change 
it.\n", mysqlAddrKey, mysqlPwdKey, mysqlNameKey) 100 | } 101 | return c 102 | } 103 | 104 | type Conf struct { 105 | Locale string `mapstructure:"locale" json:"locale" yaml:"locale"` 106 | FileMaxSize int64 `mapstructure:"file-max-size" json:"file-max-size" yaml:"file-max-siz"` 107 | SessionTimeout int64 `mapstructure:"session-timeout" json:"session-timeout" yaml:"session-timeout"` 108 | Except Route `mapstructure:"except" json:"except" yaml:"except"` 109 | System System `mapstructure:"system" json:"system" yaml:"system"` 110 | Limit Limit `mapstructure:"limit" json:"limit" yaml:"limit"` 111 | Captcha Captcha `mapstructure:"captcha" json:"captcha" yaml:"captcha"` 112 | CorsConf CorsConf `mapstructure:"cors" json:"cors" yaml:"cors"` 113 | Mysql *Mysql `mapstructure:"mysql" json:"mysql" yaml:"mysql"` 114 | Operate Operate `mapstructure:"operate" json:"operate" yaml:"operate"` 115 | } 116 | 117 | type Route struct { 118 | Uri string `mapstructure:"uri" json:"uri" yaml:"uri"` 119 | Method string `mapstructure:"method" json:"method" yaml:"method"` 120 | } 121 | 122 | type Captcha struct { 123 | KeyLong int `mapstructure:"key-long" json:"key-long" yaml:"key-long"` 124 | ImgWidth int `mapstructure:"img-width" json:"img-width" yaml:"img-width"` 125 | ImgHeight int `mapstructure:"img-height" json:"img-height" yaml:"img-height"` 126 | } 127 | 128 | type Limit struct { 129 | Disable bool `mapstructure:"disable" json:"disable" yaml:"disable"` 130 | Limit float64 `mapstructure:"limit" json:"limit" yaml:"limit"` 131 | Burst int `mapstructure:"burst" json:"burst" yaml:"burst"` 132 | } 133 | 134 | type System struct { 135 | GinMode string `mapstructure:"gin-mode" json:"gin-mode" yaml:"gin-mode"` 136 | Tls bool `mapstructure:"tls" json:"tls" yaml:"tls"` 137 | Level string `mapstructure:"level" json:"level" yaml:"level"` // debug,release,test 138 | Addr string `mapstructure:"addr" json:"addr" yaml:"addr"` 139 | DbType string `mapstructure:"db-type" json:"db-type" yaml:"db-type"` 140 
| TimeFormat string `mapstructure:"time-format" json:"time-format" yaml:"time-format"` 141 | } 142 | 143 | // SetDefaultAddrAndTimeFormat 144 | func (conf *Conf) SetDefaultAddrAndTimeFormat() { 145 | if conf.System.Addr == "" { 146 | conf.System.Addr = "127.0.0.1:8080" 147 | } 148 | 149 | if conf.System.TimeFormat == "" { 150 | conf.System.TimeFormat = "2006-01-02 15:04:05" 151 | } 152 | } 153 | 154 | // // toStaticUrl 155 | // func (conf *Conf) toStaticUrl(uri string) string { 156 | // path := filepath.Join(conf.System.Addr, conf.System.StaticPrefix, uri) 157 | // if conf.System.Tls { 158 | // return filepath.ToSlash(str.Join("https://", path)) 159 | // } 160 | // return filepath.ToSlash(str.Join("http://", path)) 161 | // } 162 | 163 | // IsExist config file is exist 164 | func (conf *Conf) IsExist() bool { 165 | return conf.getViperConfig().IsExist() 166 | } 167 | 168 | // RemoveFile remove config file 169 | func (conf *Conf) RemoveFile() error { 170 | return conf.getViperConfig().RemoveFile() 171 | } 172 | 173 | // Recover 174 | func (conf *Conf) Recover() error { 175 | conf.newRbacModel() 176 | b, err := json.MarshalIndent(conf, "", "\t") 177 | if err != nil { 178 | return fmt.Errorf("iris-admin recover config faild:%w", err) 179 | } 180 | return conf.getViperConfig().Recover(b) 181 | } 182 | 183 | // getViperConfig get viper config 184 | func (conf *Conf) getViperConfig() *ViperConf { 185 | maxSize := strconv.FormatInt(conf.FileMaxSize, 10) 186 | sessionTimeout := strconv.FormatInt(conf.SessionTimeout, 10) 187 | keyLong := strconv.FormatInt(int64(conf.Captcha.KeyLong), 10) 188 | imgWidth := strconv.FormatInt(int64(conf.Captcha.ImgWidth), 10) 189 | imgHeight := strconv.FormatInt(int64(conf.Captcha.ImgHeight), 10) 190 | limit := strconv.FormatInt(int64(conf.Limit.Limit), 10) 191 | burst := strconv.FormatInt(int64(conf.Limit.Burst), 10) 192 | disable := strconv.FormatBool(conf.Limit.Disable) 193 | tls := strconv.FormatBool(conf.System.Tls) 194 | 195 | 
mxIdleConns := fmt.Sprintf("%d", conf.Mysql.MaxIdleConns) 196 | mxOpenConns := fmt.Sprintf("%d", conf.Mysql.MaxOpenConns) 197 | logMode := fmt.Sprintf("%t", conf.Mysql.LogMode) 198 | 199 | configName := "iris_admin" 200 | return &ViperConf{ 201 | dir: ConfigDir, 202 | name: configName, 203 | t: ConfigType, 204 | watch: func(vi *viper.Viper) error { 205 | if err := vi.Unmarshal(&conf); err != nil { 206 | return fmt.Errorf("get Unarshal error: %v", err) 207 | } 208 | // watch config file change 209 | vi.SetConfigName(configName) 210 | return nil 211 | }, 212 | // 213 | Default: []byte(` 214 | { 215 | "locale": "` + conf.Locale + `", 216 | "file-max-size": ` + maxSize + `, 217 | "session-timeout": ` + sessionTimeout + `, 218 | "except": 219 | { 220 | "uri": "` + conf.Except.Uri + `", 221 | "method": "` + conf.Except.Method + `" 222 | }, 223 | "cors": 224 | { 225 | "access-origin": "` + conf.CorsConf.AccessOrigin + `", 226 | "access-headers": "` + conf.CorsConf.AccessHeaders + `", 227 | "access-methods": "` + conf.CorsConf.AccessMethods + `", 228 | "access-expose-headers": "` + conf.CorsConf.AccessExposeHeaders + `", 229 | "access-credentials": "` + conf.CorsConf.AccessCredentials + `" 230 | }, 231 | "captcha": 232 | { 233 | "key-long": ` + keyLong + `, 234 | "img-width": ` + imgWidth + `, 235 | "img-height": ` + imgHeight + ` 236 | }, 237 | "limit": 238 | { 239 | "limit": ` + limit + `, 240 | "disable": ` + disable + `, 241 | "burst": ` + burst + ` 242 | }, 243 | "system": 244 | { 245 | "tls": ` + tls + `, 246 | "level": "` + conf.System.Level + `", 247 | "gin-mode": "` + conf.System.GinMode + `", 248 | "addr": "` + conf.System.Addr + `", 249 | "time-format": "` + conf.System.TimeFormat + `" 250 | }, 251 | "mysql": 252 | { 253 | "path": "` + conf.Mysql.Path + `", 254 | "config": "` + conf.Mysql.Config + `", 255 | "db-name": "` + conf.Mysql.DbName + `", 256 | "username": "` + conf.Mysql.Username + `", 257 | "password": "` + conf.Mysql.Password + `", 258 | 
"max-idle-conns": ` + mxIdleConns + `, 259 | "max-open-conns": ` + mxOpenConns + `, 260 | "log-mode": ` + logMode + `, 261 | "log-zap": "` + conf.Mysql.LogZap + `" 262 | }, 263 | "operate": 264 | { 265 | "except":{ 266 | "uri": "` + conf.Operate.Except.Uri + `", 267 | "method": "` + conf.Operate.Except.Method + `" 268 | }, 269 | "include": 270 | { 271 | "uri": "` + conf.Operate.Include.Uri + `", 272 | "method": "` + conf.Operate.Include.Method + `" 273 | } 274 | } 275 | }`), 276 | } 277 | } 278 | -------------------------------------------------------------------------------- /conf/config_test.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import "testing" 4 | 5 | func TestSetDefaultAddrAndTimeFormat(t *testing.T) { 6 | dc := &Conf{} 7 | addr := "" 8 | if dc.System.Addr != addr { 9 | t.Errorf("config system addr want '%s' but get '%s'", addr, dc.System.Addr) 10 | } 11 | timeFormat := "" 12 | if dc.System.TimeFormat != timeFormat { 13 | t.Errorf("config system time format want '%s' but get '%s'", timeFormat, dc.System.TimeFormat) 14 | } 15 | dc.SetDefaultAddrAndTimeFormat() 16 | addr = "127.0.0.1:8080" 17 | if dc.System.Addr != addr { 18 | t.Errorf("config system addr want '%s' but get '%s'", addr, dc.System.Addr) 19 | } 20 | timeFormat = "2006-01-02 15:04:05" 21 | if dc.System.TimeFormat != timeFormat { 22 | t.Errorf("config system time format want '%s' but get '%s'", timeFormat, dc.System.TimeFormat) 23 | } 24 | 25 | c := NewConf() 26 | if c.IsExist() { 27 | t.Error("config exist before init") 28 | } 29 | addr = "127.0.0.1:8080" 30 | if c.System.Addr != addr { 31 | t.Errorf("config system addr want '%s' but get '%s'", addr, c.System.Addr) 32 | } 33 | timeFormat = "2006-01-02 15:04:05" 34 | if c.System.TimeFormat != timeFormat { 35 | t.Errorf("config system time format want '%s' but get '%s'", timeFormat, c.System.TimeFormat) 36 | } 37 | 38 | if err := c.Recover(); err != nil { 39 | t.Error(err.Error()) 
40 | } 41 | defer func() { 42 | if err := c.getViperConfig().RemoveDir(); err != nil { 43 | t.Error(err.Error()) 44 | } 45 | c.RemoveRbacModel() 46 | }() 47 | if !c.IsExist() { 48 | t.Error("config not exist after recover") 49 | } 50 | c.RemoveFile() 51 | if c.IsExist() { 52 | t.Error("config exist after remove") 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /conf/cros.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/gin-gonic/gin" 7 | ) 8 | 9 | type CorsConf struct { 10 | AccessOrigin string `mapstructure:"access-origin" json:"burst" access-origin:"access-origin"` 11 | AccessHeaders string `mapstructure:"access-headers" json:"access-headers" yaml:"access-headers"` 12 | AccessMethods string `mapstructure:"access-methods" json:"access-methods" yaml:"access-methods"` 13 | AccessExposeHeaders string `mapstructure:"access-expose-headers" json:"access-expose-headers" yaml:"access-expose-headers"` 14 | AccessCredentials string `mapstructure:"access-credentials" json:"access-credentials" yaml:"access-credentials"` 15 | } 16 | 17 | // Cors 18 | func (corsConf *CorsConf) Cors() gin.HandlerFunc { 19 | return func(c *gin.Context) { 20 | method := c.Request.Method 21 | c.Header("Access-Control-Allow-Origin", corsConf.AccessOrigin) 22 | c.Header("Access-Control-Allow-Headers", corsConf.AccessHeaders) 23 | c.Header("Access-Control-Allow-Methods", corsConf.AccessMethods) 24 | c.Header("Access-Control-Expose-Headers", corsConf.AccessExposeHeaders) 25 | c.Header("Access-Control-Allow-Credentials", corsConf.AccessCredentials) 26 | if method == "OPTIONS" { 27 | c.AbortWithStatus(http.StatusNoContent) 28 | } 29 | c.Next() 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /conf/mysql.go: -------------------------------------------------------------------------------- 1 | package 
conf 2 | 3 | import "fmt" 4 | 5 | type Mysql struct { 6 | Path string `mapstructure:"path" json:"path" yaml:"path"` 7 | Config string `mapstructure:"config" json:"config" yaml:"config"` 8 | DbName string `mapstructure:"db-name" json:"db-name" yaml:"db-name"` 9 | Username string `mapstructure:"username" json:"username" yaml:"username"` 10 | Password string `mapstructure:"password" json:"password" yaml:"password"` 11 | MaxIdleConns int `mapstructure:"max-idle-conns" json:"max-idle-conns" yaml:"max-idle-conns"` 12 | MaxOpenConns int `mapstructure:"max-open-conns" json:"max-open-conns" yaml:"max-open-conns"` 13 | LogMode bool `mapstructure:"log-mode" json:"log-mode" yaml:"log-mode"` 14 | LogZap string `mapstructure:"log-zap" json:"log-zap" yaml:"log-zap"` //silent,error,warn,info,zap 15 | } 16 | 17 | // Dsn return mysql dsn 18 | func (m *Mysql) Dsn() string { 19 | return fmt.Sprintf("%s%s?%s", m.BaseDsn(), m.DbName, m.Config) 20 | } 21 | 22 | // Dsn return 23 | func (m *Mysql) BaseDsn() string { 24 | return fmt.Sprintf("%s:%s@tcp(%s)/", m.Username, m.Password, m.Path) 25 | } 26 | -------------------------------------------------------------------------------- /conf/mysql_test.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import "testing" 4 | 5 | func TestMysqlBaseDsn(t *testing.T) { 6 | m := &Mysql{ 7 | Path: "127.0.0.1:3306", 8 | Config: "charset=utf8mb4", 9 | DbName: "db_name", 10 | Username: "name", 11 | Password: "pwd", 12 | MaxIdleConns: 0, 13 | MaxOpenConns: 0, 14 | LogMode: false, 15 | LogZap: "", 16 | } 17 | b := m.BaseDsn() 18 | want := "name:pwd@tcp(127.0.0.1:3306)/" 19 | if b != want { 20 | t.Errorf("mysql config base dsn want '%s' but get '%s'", want, b) 21 | } 22 | dsn := m.Dsn() 23 | wantDsn := "name:pwd@tcp(127.0.0.1:3306)/db_name?charset=utf8mb4" 24 | if dsn != wantDsn { 25 | t.Errorf("mysql config base dsn want '%s' but get '%s'", wantDsn, dsn) 26 | } 27 | } 28 | 
-------------------------------------------------------------------------------- /conf/operate.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "strings" 5 | ) 6 | 7 | // Operate 8 | // Except set which routers don't generate system log, use ';' to separate. 9 | // Include set which routers need to generate system log, use ';' to separate. 10 | type Operate struct { 11 | Except Route `mapstructure:"except" json:"except" yaml:"except"` 12 | Include Route `mapstructure:"include" json:"include" yaml:"include"` 13 | } 14 | 15 | // GetExcept return routers which need to excepted 16 | func (op Operate) GetExcept() ([]string, []string) { 17 | uri := strings.Split(op.Except.Uri, ";") 18 | method := strings.Split(op.Except.Method, ";") 19 | return uri, method 20 | } 21 | 22 | // GetInclude return routers which need to included 23 | func (op Operate) GetInclude() ([]string, []string) { 24 | uri := strings.Split(op.Include.Uri, ";") 25 | method := strings.Split(op.Include.Method, ";") 26 | return uri, method 27 | } 28 | 29 | // IsInclude check whether the current route needs to belong to the included data 30 | func (op Operate) IsInclude(uri, method string) bool { 31 | incUri, incMethod := op.GetInclude() 32 | if len(incUri) != len(incMethod) { 33 | return false 34 | } 35 | 36 | for i := 0; i < len(incUri); i++ { 37 | if uri == incUri[i] && method == incMethod[i] { 38 | return true 39 | } 40 | } 41 | return false 42 | } 43 | 44 | // IsExcept check whether the current route needs to belong to the excepted data 45 | func (op Operate) IsExcept(uri, method string) bool { 46 | excUri, excMethod := op.GetExcept() 47 | if len(excUri) != len(excMethod) { 48 | return false 49 | } 50 | 51 | for i := 0; i < len(excUri); i++ { 52 | if uri == excUri[i] && method == excMethod[i] { 53 | return true 54 | } 55 | } 56 | return false 57 | } 58 | 
-------------------------------------------------------------------------------- /conf/viper.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "os" 7 | "path/filepath" 8 | 9 | "github.com/snowlyg/helper/dir" 10 | "github.com/snowlyg/helper/str" 11 | "github.com/snowlyg/iris-admin/e" 12 | "github.com/spf13/viper" 13 | ) 14 | 15 | type ViperConf struct { 16 | dir string 17 | name string 18 | t string 19 | Default []byte 20 | watch func(*viper.Viper) error 21 | } 22 | 23 | // getConfPath 24 | func (vc *ViperConf) getConfPath() string { 25 | if vc == nil { 26 | return "" 27 | } 28 | return filepath.Join(dir.GetCurrentAbPath(), vc.Dir(), str.Join(vc.name, ".", vc.t)) 29 | } 30 | 31 | // Dir 32 | func (vc *ViperConf) Dir() string { 33 | if vc.dir == "" { 34 | vc.dir = "config" 35 | return vc.dir 36 | } 37 | return vc.dir 38 | } 39 | 40 | // IsExist 41 | func (vc *ViperConf) IsExist() bool { 42 | if vc == nil { 43 | return false 44 | } 45 | return dir.IsExist(vc.getConfPath()) 46 | } 47 | 48 | // RemoveFile remove config file 49 | func (vc *ViperConf) RemoveFile() error { 50 | if vc == nil { 51 | return e.ErrViperConfInvalid 52 | } 53 | d := filepath.Dir(vc.getConfPath()) 54 | b := filepath.Base(d) 55 | if b != vc.Dir() { 56 | return nil 57 | } 58 | return dir.Remove(vc.getConfPath()) 59 | } 60 | 61 | // RemoveDir remove config dir 62 | func (vc *ViperConf) RemoveDir() error { 63 | if vc == nil { 64 | return e.ErrViperConfInvalid 65 | } 66 | d := filepath.Dir(vc.getConfPath()) 67 | b := filepath.Base(d) 68 | if b != vc.Dir() { 69 | return fmt.Errorf("%s viper conf base '%s' want but get '%s'", d, b, vc.Dir()) 70 | } 71 | return os.RemoveAll(d) 72 | } 73 | 74 | // Recover 75 | func (vc *ViperConf) Recover(b []byte) error { 76 | if vc == nil { 77 | return e.ErrViperConfInvalid 78 | } 79 | _, err := dir.WriteBytes(vc.getConfPath(), b) 80 | return err 81 | } 82 | 83 | // 
NewViperConf 84 | func NewViperConf(vc *ViperConf) error { 85 | if vc == nil { 86 | return e.ErrViperConfInvalid 87 | } 88 | if vc.name == "" { 89 | return e.ErrConfigNameEmpty 90 | } 91 | if vc.t == "" { 92 | vc.t = "yaml" 93 | } 94 | 95 | vc.dir = vc.Dir() 96 | filePath := vc.getConfPath() 97 | 98 | vi := viper.New() 99 | vi.SetConfigName(vc.name) 100 | vi.SetConfigType(vc.t) 101 | vi.AddConfigPath(vc.dir) 102 | isExist := dir.IsExist(filePath) 103 | if !isExist { 104 | if vc.Dir() != "./" { 105 | if err := dir.InsureDir(filepath.Dir(filePath)); err != nil { 106 | return fmt.Errorf("create dir %s fail : %v", filePath, err) 107 | } 108 | } 109 | // ReadConfig 110 | if err := vi.ReadConfig(bytes.NewBuffer(vc.Default)); err != nil { 111 | return fmt.Errorf("read default config fail : %w ", err) 112 | } 113 | // WriteConfigAs 114 | if err := vi.WriteConfigAs(filePath); err != nil { 115 | return fmt.Errorf("write config to path fail: %w ", err) 116 | } 117 | } else { 118 | vi.SetConfigFile(filePath) 119 | if err := vi.ReadInConfig(); err != nil { 120 | return fmt.Errorf("read config fail: %w ", err) 121 | } 122 | } 123 | if err := vc.watch(vi); err != nil { 124 | return fmt.Errorf("watch config fail: %w ", err) 125 | } 126 | return nil 127 | } 128 | -------------------------------------------------------------------------------- /conf/viper_test.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "path/filepath" 8 | "testing" 9 | 10 | "github.com/snowlyg/helper/dir" 11 | "github.com/snowlyg/helper/str" 12 | "github.com/snowlyg/iris-admin/e" 13 | "github.com/spf13/viper" 14 | ) 15 | 16 | func TestNewViperConfFail(t *testing.T) { 17 | if err := NewViperConf(nil); !errors.Is(err, e.ErrViperConfInvalid) { 18 | t.Errorf("new viper conf with nil return err not confi invalid:%v", err) 19 | } 20 | if err := NewViperConf(&ViperConf{}); !errors.Is(err, 
e.ErrConfigNameEmpty) { 21 | t.Errorf("new viper conf with nil return err not emtpy name:%v", err) 22 | } 23 | } 24 | 25 | type Zap struct { 26 | Level int64 `mapstructure:"level" json:"level" yaml:"level"` //debug ,info,warn,error,panic,fatal 27 | StacktraceKey string `mapstructure:"stacktrace-key" json:"stacktrace-key" yaml:"stacktrace-key"` 28 | LogInConsole bool `mapstructure:"log-in-console" json:"log-in-console" yaml:"log-in-console"` 29 | } 30 | 31 | func TestViperInit(t *testing.T) { 32 | tc := &Zap{} 33 | vi := &ViperConf{ 34 | // directory: ConfigDir, 35 | name: "config", 36 | t: ConfigType, 37 | watch: func(vi *viper.Viper) error { 38 | if err := vi.Unmarshal(tc); err != nil { 39 | return fmt.Errorf("get Unarshal error: %v", err) 40 | } 41 | vi.SetConfigName("config") 42 | return nil 43 | }, 44 | // 45 | Default: []byte(`{ 46 | "level": 0, 47 | "stacktrace-key": "stacktrace", 48 | "log-in-console": true}`), 49 | } 50 | defer func() { 51 | if err := vi.RemoveDir(); err != nil { 52 | t.Error(err.Error()) 53 | } 54 | }() 55 | 56 | if vi.IsExist() { 57 | t.Error("config exist") 58 | } 59 | 60 | vi.Dir() 61 | if vi.dir != "config" { 62 | t.Errorf("directory want '%s' but get '%s'", "config", vi.dir) 63 | } 64 | 65 | want := Zap{ 66 | Level: 0, 67 | StacktraceKey: "stacktrace", 68 | LogInConsole: true, 69 | } 70 | 71 | if err := NewViperConf(vi); err != nil { 72 | t.Errorf("init %s's config get error: %v", str.Join(vi.name, ".", vi.t), err) 73 | } 74 | 75 | if !vi.IsExist() { 76 | t.Error("config not exist") 77 | } 78 | 79 | if want.Level != tc.Level { 80 | t.Errorf("want %+v but get %+v", want.Level, tc.Level) 81 | } 82 | if want.StacktraceKey != tc.StacktraceKey { 83 | t.Errorf("want %+v but get %+v", want.StacktraceKey, tc.StacktraceKey) 84 | } 85 | if want.LogInConsole != tc.LogInConsole { 86 | t.Errorf("want %+v but get %+v", want.LogInConsole, tc.LogInConsole) 87 | } 88 | 89 | dir.WriteBytes(filepath.Join(vi.getConfPath()), []byte(`{ 90 | "level": 2, 91 
| "stacktrace-key": "stacktrace1", 92 | "log-in-console": false}`)) 93 | 94 | want1 := Zap{ 95 | Level: 2, 96 | StacktraceKey: "stacktrace1", 97 | LogInConsole: false, 98 | } 99 | 100 | if err := NewViperConf(vi); err != nil { 101 | t.Errorf("init %s's config get error: %v", str.Join(vi.name, ".", vi.t), err) 102 | } 103 | 104 | if want1.Level != tc.Level { 105 | t.Errorf("want1 %+v but get %+v", want1.Level, tc.Level) 106 | } 107 | if want1.StacktraceKey != tc.StacktraceKey { 108 | t.Errorf("want1 %+v but get %+v", want1.StacktraceKey, tc.StacktraceKey) 109 | } 110 | if want1.LogInConsole != tc.LogInConsole { 111 | t.Errorf("want1 %+v but get %+v", want1.LogInConsole, tc.LogInConsole) 112 | } 113 | 114 | tc.Level = 3 115 | tc.StacktraceKey = "stacktrace3" 116 | tc.LogInConsole = true 117 | 118 | b, err := json.Marshal(&tc) 119 | if err != nil { 120 | t.Error(err.Error()) 121 | } 122 | 123 | if err := vi.Recover(b); err != nil { 124 | t.Error(err.Error()) 125 | } 126 | 127 | want2 := &Zap{} 128 | if b, err := dir.ReadBytes(vi.getConfPath()); err != nil { 129 | t.Error(err.Error()) 130 | } else { 131 | if err := json.Unmarshal(b, want2); err != nil { 132 | t.Error(err.Error()) 133 | } 134 | } 135 | 136 | if want2.Level != tc.Level { 137 | t.Errorf("want2 %+v but get %+v", tc.Level, want2.Level) 138 | } 139 | if want2.StacktraceKey != tc.StacktraceKey { 140 | t.Errorf("want2 %+v but get %+v", tc.StacktraceKey, want2.StacktraceKey) 141 | } 142 | if want2.LogInConsole != tc.LogInConsole { 143 | t.Errorf("want2 %+v but get %+v", tc.LogInConsole, want2.LogInConsole) 144 | } 145 | 146 | if err := vi.RemoveFile(); err != nil { 147 | t.Error(err.Error()) 148 | } 149 | 150 | if vi.IsExist() { 151 | t.Error("config file exist after remove") 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /e/error.go: -------------------------------------------------------------------------------- 1 | package e 2 | 3 | import ( 4 | 
"errors" 5 | ) 6 | 7 | var ( 8 | ErrConfigNameEmpty = errors.New("config file name empty") 9 | ErrViperConfInvalid = errors.New("viper conf invalid") 10 | ErrConfigInvalid = errors.New("config invalid") 11 | 12 | ErrAuthInvalid = errors.New("auth invalid") 13 | ErrDbTableNameEmpty = errors.New("database table name empty") 14 | ) 15 | -------------------------------------------------------------------------------- /example/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | 7 | "github.com/gin-gonic/gin" 8 | admin "github.com/snowlyg/iris-admin" 9 | "github.com/snowlyg/iris-admin/conf" 10 | ) 11 | 12 | func main() { 13 | c := conf.NewConf() 14 | // change default config 15 | if err := c.Recover(); err != nil { 16 | panic(err.Error()) 17 | } 18 | s, err := admin.NewServe(c) 19 | if err != nil { 20 | panic(err.Error()) 21 | } 22 | 23 | engine := s.Engine() 24 | // add group api v1 25 | v1 := engine.Group("/api/v1") 26 | { 27 | v1.GET("/health", func(ctx *gin.Context) { 28 | ctx.String(http.StatusOK, "OK") 29 | }) 30 | } 31 | 32 | // notice: the static route must not be the bare root path "/", 33 | // because Static registers <path>/*filepath, and "/*filepath" would match every path (conflicting with /api/v1) 34 | engine.Static("/admin", "./public") 35 | log.Printf("open: http://%s/admin in your browser\n", s.SystemAddr()) 36 | 37 | s.Run() 38 | } 39 | -------------------------------------------------------------------------------- /example/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Iris-Admin by Snowlyg 5 | 71 | 72 | 73 |
74 |

Iris-Admin

75 |

A powerful admin backend built with the Go Gin framework

76 |
77 | 78 |
79 |
80 |

About Iris-Admin

81 |

82 | Snowlyg/Iris-Admin is a robust and modular backend system based on the Gin web framework. It offers a scalable, maintainable, and extensible architecture for rapid development of admin dashboards and APIs. 83 |

84 |
85 | 86 |
87 |

Key Features

88 |
89 |
90 |

Modular Architecture

91 |

Build and manage your system with reusable modules for users, roles, permissions, and more.

92 |
93 |
94 |

RBAC & JWT Auth

95 |

Secure your application with role-based access control and JWT authentication.

96 |
97 |
98 |

Swagger Integration

99 |

Auto-generate API documentation using Swagger for easy testing and integration.

100 |
101 |
102 |

Frontend Ready

103 |

Designed to integrate seamlessly with Vue-based admin UIs like vue-element-admin.

104 |
105 |
106 |
107 | 108 |
109 |

Get Started

110 |

111 | Visit the GitHub repository and follow the README to get started: 112 |
113 | 114 | https://github.com/snowlyg/iris-admin 115 | 116 |

117 |
118 |
119 | 120 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /example/readme.md: -------------------------------------------------------------------------------- 1 | # example 2 | 3 | ## todo 4 | 5 | [+] example: a single page showing repository information. 6 | [-] menu sync. 7 | [-] permission & role & user. 8 | [-] login/logout. 9 | [-] use [xaboy/form-create-designer](https://github.com/xaboy/form-create-designer) to create forms from gorm models. 10 | 11 | 1. get [vue-element-admin](https://github.com/PanJiaChen/vue-element-admin.git) 12 | 13 | ```shell 14 | git clone https://github.com/PanJiaChen/vue-element-admin.git 15 | ``` 16 | 17 | 2. get [iris-admin example](https://github.com/snowlyg/iris-admin.git) 18 | 19 | ```shell 20 | git clone https://github.com/snowlyg/iris-admin.git 21 | 22 | cd iris-admin/example 23 | 24 | go run main.go 25 | ``` -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/snowlyg/iris-admin 2 | 3 | go 1.22.6 4 | 5 | require ( 6 | github.com/aviddiviner/gin-limit v0.0.0-20170918012823-43b5f79762c1 7 | github.com/bwmarrin/snowflake v0.3.0 8 | github.com/casbin/casbin/v2 v2.104.0 9 | github.com/casbin/gorm-adapter/v3 v3.32.0 10 | github.com/fvbock/endless v0.0.0-20170109170031-447134032cb6 11 | github.com/gavv/httpexpect/v2 v2.17.0 12 | github.com/gin-contrib/pprof v1.3.0 13 | github.com/gin-gonic/gin v1.8.1 14 | github.com/go-gormigrate/gormigrate/v2 v2.0.0 15 | github.com/go-playground/validator/v10 v10.11.0 16 | github.com/go-redis/redis/v8 v8.11.5 17 | github.com/golang-jwt/jwt v3.2.2+incompatible 18 | github.com/gorilla/websocket v1.5.0 19 | github.com/mattn/go-colorable v0.1.13 20 | github.com/patrickmn/go-cache v2.1.0+incompatible 21 | github.com/pkg/errors v0.9.1 22 | github.com/satori/go.uuid v1.2.0 23 | github.com/snowlyg/helper v0.1.33 24 |
github.com/spf13/viper v1.12.0 25 | github.com/unrolled/secure v1.0.9 26 | go.mongodb.org/mongo-driver v1.11.4 27 | gorm.io/driver/mysql v1.5.7 28 | gorm.io/gorm v1.25.12 29 | ) 30 | 31 | require ( 32 | github.com/TylerBrock/colorjson v0.0.0-20200706003622-8a50f05110d2 // indirect 33 | github.com/ajg/form v1.5.1 // indirect 34 | github.com/andybalholm/brotli v1.0.4 // indirect 35 | github.com/bmatcuk/doublestar/v4 v4.6.1 // indirect 36 | github.com/casbin/govaluate v1.3.0 // indirect 37 | github.com/cespare/xxhash/v2 v2.1.2 // indirect 38 | github.com/davecgh/go-spew v1.1.1 // indirect 39 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 40 | github.com/dustin/go-humanize v1.0.1 // indirect 41 | github.com/fatih/color v1.15.0 // indirect 42 | github.com/fatih/structs v1.1.0 // indirect 43 | github.com/fsnotify/fsnotify v1.8.0 // indirect 44 | github.com/gin-contrib/sse v0.1.0 // indirect 45 | github.com/glebarez/go-sqlite v1.20.3 // indirect 46 | github.com/glebarez/sqlite v1.7.0 // indirect 47 | github.com/go-playground/locales v0.14.0 // indirect 48 | github.com/go-playground/universal-translator v0.18.0 // indirect 49 | github.com/go-sql-driver/mysql v1.7.0 // indirect 50 | github.com/gobwas/glob v0.2.3 // indirect 51 | github.com/goccy/go-json v0.9.8-0.20220506185958-23bd66f4c0d5 // indirect 52 | github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect 53 | github.com/golang-sql/sqlexp v0.1.0 // indirect 54 | github.com/golang/snappy v0.0.4 // indirect 55 | github.com/google/go-querystring v1.1.0 // indirect 56 | github.com/google/uuid v1.3.0 // indirect 57 | github.com/gosuri/uilive v0.0.4 // indirect 58 | github.com/gosuri/uiprogress v0.0.1 // indirect 59 | github.com/hashicorp/hcl v1.0.0 // indirect 60 | github.com/imkira/go-interpol v1.1.0 // indirect 61 | github.com/jackc/pgpassfile v1.0.0 // indirect 62 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect 63 | 
github.com/jackc/pgx/v5 v5.5.5 // indirect 64 | github.com/jackc/puddle/v2 v2.2.1 // indirect 65 | github.com/jinzhu/inflection v1.0.0 // indirect 66 | github.com/jinzhu/now v1.1.5 // indirect 67 | github.com/json-iterator/go v1.1.12 // indirect 68 | github.com/klauspost/compress v1.15.6 // indirect 69 | github.com/leodido/go-urn v1.2.1 // indirect 70 | github.com/magiconair/properties v1.8.6 // indirect 71 | github.com/mattn/go-isatty v0.0.18 // indirect 72 | github.com/microsoft/go-mssqldb v1.6.0 // indirect 73 | github.com/mitchellh/go-wordwrap v1.0.1 // indirect 74 | github.com/mitchellh/mapstructure v1.5.0 // indirect 75 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 76 | github.com/modern-go/reflect2 v1.0.2 // indirect 77 | github.com/montanaflynn/stats v0.7.0 // indirect 78 | github.com/pelletier/go-toml v1.9.5 // indirect 79 | github.com/pelletier/go-toml/v2 v2.0.2 // indirect 80 | github.com/pmezard/go-difflib v1.0.0 // indirect 81 | github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 // indirect 82 | github.com/sanity-io/litter v1.5.5 // indirect 83 | github.com/sergi/go-diff v1.2.0 // indirect 84 | github.com/spf13/afero v1.8.2 // indirect 85 | github.com/spf13/cast v1.5.0 // indirect 86 | github.com/spf13/jwalterweatherman v1.1.0 // indirect 87 | github.com/spf13/pflag v1.0.5 // indirect 88 | github.com/stretchr/testify v1.8.4 // indirect 89 | github.com/subosito/gotenv v1.3.0 // indirect 90 | github.com/ugorji/go/codec v1.2.7 // indirect 91 | github.com/valyala/bytebufferpool v1.0.0 // indirect 92 | github.com/valyala/fasthttp v1.40.0 // indirect 93 | github.com/xdg-go/pbkdf2 v1.0.0 // indirect 94 | github.com/xdg-go/scram v1.1.1 // indirect 95 | github.com/xdg-go/stringprep v1.0.3 // indirect 96 | github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect 97 | github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect 98 | github.com/xeipuuv/gojsonschema 
v1.2.0 // indirect 99 | github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0 // indirect 100 | github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect 101 | github.com/yudai/gojsondiff v1.0.0 // indirect 102 | github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 // indirect 103 | golang.org/x/crypto v0.31.0 // indirect 104 | golang.org/x/net v0.33.0 // indirect 105 | golang.org/x/sync v0.10.0 // indirect 106 | golang.org/x/sys v0.28.0 // indirect 107 | golang.org/x/text v0.21.0 // indirect 108 | google.golang.org/protobuf v1.28.0 // indirect 109 | gopkg.in/ini.v1 v1.66.6 // indirect 110 | gopkg.in/yaml.v2 v2.4.0 // indirect 111 | gopkg.in/yaml.v3 v3.0.1 // indirect 112 | gorm.io/driver/postgres v1.5.9 // indirect 113 | gorm.io/driver/sqlserver v1.5.3 // indirect 114 | gorm.io/plugin/dbresolver v1.5.3 // indirect 115 | modernc.org/libc v1.22.2 // indirect 116 | modernc.org/mathutil v1.5.0 // indirect 117 | modernc.org/memory v1.5.0 // indirect 118 | modernc.org/sqlite v1.20.3 // indirect 119 | moul.io/http2curl/v2 v2.3.0 // indirect 120 | ) 121 | -------------------------------------------------------------------------------- /httptest/base.go: -------------------------------------------------------------------------------- 1 | package httptest 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "log" 7 | "net/http" 8 | "testing" 9 | 10 | "github.com/gavv/httpexpect/v2" 11 | "github.com/snowlyg/helper/str" 12 | ) 13 | 14 | var ( 15 | c *Client 16 | // default page request params 17 | GetRequestFunc = NewWithQueryObjectParamFunc(map[string]any{"page": 1, "pageSize": 10}) 18 | 19 | // default page request params 20 | PostRequestFunc = NewWithJsonParamFunc(map[string]any{"page": 1, "pageSize": 10}) 21 | 22 | // default login request params 23 | LoginFunc = NewWithJsonParamFunc(map[string]any{"username": "admin", "password": "123456"}) 24 | 25 | // default login response params 26 | LoginResponse = Responses{ 27 | {Key: "status", Value: http.StatusOK}, 
28 | {Key: "message", Value: "OK"}, 29 | {Key: "data", 30 | Value: Responses{ 31 | {Key: "accessToken", Value: "", Type: "notempty"}, 32 | }, 33 | }, 34 | } 35 | 36 | // SuccessResponse default success response params 37 | SuccessResponse = Responses{ 38 | {Key: "status", Value: http.StatusOK}, 39 | {Key: "message", Value: "OK"}, 40 | } 41 | 42 | // ResponsePage default data response params 43 | ResponsePage = Responses{ 44 | {Key: "status", Value: http.StatusOK}, 45 | {Key: "message", Value: "OK"}, 46 | {Key: "data", Value: Responses{ 47 | {Key: "pageSize", Value: 10}, 48 | {Key: "page", Value: 1}, 49 | }}, 50 | } 51 | ) 52 | 53 | // paramFunc 54 | type paramFunc func(req *httpexpect.Request) *httpexpect.Request 55 | 56 | // NewWithJsonParamFunc return req.WithJSON 57 | func NewWithJsonParamFunc(query map[string]any) paramFunc { 58 | return func(req *httpexpect.Request) *httpexpect.Request { 59 | return req.WithJSON(query) 60 | } 61 | } 62 | 63 | // NewWithQueryObjectParamFunc query for get method 64 | func NewWithQueryObjectParamFunc(query map[string]any) paramFunc { 65 | return func(req *httpexpect.Request) *httpexpect.Request { 66 | return req.WithQueryObject(query) 67 | } 68 | } 69 | 70 | // NewWithFileParamFunc return req.WithFile 71 | func NewWithFileParamFunc(fs []File, query map[string]any) paramFunc { 72 | return func(req *httpexpect.Request) *httpexpect.Request { 73 | if len(fs) == 0 { 74 | return req 75 | } 76 | req = req.WithMultipart() 77 | for _, f := range fs { 78 | req = req.WithFile(f.Key, f.Path, f.Reader) 79 | } 80 | if query == nil { 81 | return req 82 | } 83 | return req.WithForm(query) 84 | } 85 | } 86 | 87 | // NewWithFormParamFunc 88 | func NewWithFormParamFunc(query map[string]any) paramFunc { 89 | return func(req *httpexpect.Request) *httpexpect.Request { 90 | if query == nil { 91 | return req 92 | } 93 | return req.WithMultipart().WithForm(query) 94 | } 95 | } 96 | 97 | // NewResponsesWithLength return Responses with length value for 
data key 98 | func NewResponsesWithLength(status int, message string, data []Responses, length int) Responses { 99 | return Responses{ 100 | {Key: "status", Value: status}, 101 | {Key: "message", Value: message}, 102 | {Key: "data", Value: data, Length: length}, 103 | } 104 | } 105 | 106 | // NewResponses return Responses 107 | func NewResponses(status int, message string, data ...Responses) Responses { 108 | if status != http.StatusOK { 109 | return Responses{ 110 | {Key: "status", Value: status}, 111 | {Key: "message", Value: message}, 112 | } 113 | } 114 | if len(data) == 0 { 115 | return Responses{ 116 | {Key: "status", Value: status}, 117 | {Key: "message", Value: message}, 118 | } 119 | } 120 | if len(data) == 1 { 121 | return Responses{ 122 | {Key: "status", Value: status}, 123 | {Key: "message", Value: message}, 124 | {Key: "data", Value: data[0]}, 125 | } 126 | } 127 | return Responses{ 128 | {Key: "status", Value: status}, 129 | {Key: "message", Value: message}, 130 | {Key: "data", Value: data}, 131 | } 132 | } 133 | 134 | type Client struct { 135 | t *testing.T 136 | conf httpexpect.Config 137 | expect *httpexpect.Expect 138 | status int 139 | headers map[string]string 140 | 141 | tokenIndex string 142 | loginApi string 143 | logoutApi string 144 | } 145 | 146 | type ClientConf struct { 147 | Key string 148 | Value string 149 | } 150 | 151 | const ( 152 | BASE_URL = "base_url" 153 | TOKEN_INDEX = "token_index" 154 | LoginApi = "login_api" 155 | LogoutApi = "logout_api" 156 | ) 157 | 158 | func NewBaseUrlConf(value string) ClientConf { 159 | return ClientConf{Key: BASE_URL, Value: value} 160 | } 161 | 162 | func NewTokenIndexConf(value string) ClientConf { 163 | return ClientConf{Key: TOKEN_INDEX, Value: value} 164 | } 165 | 166 | func NewLoginApiConf(value string) ClientConf { 167 | return ClientConf{Key: LoginApi, Value: value} 168 | } 169 | func NewLogoutApiConf(value string) ClientConf { 170 | return ClientConf{Key: LogoutApi, Value: value} 171 | } 
172 | 173 | // NewClient return test client instance 174 | func NewClient(t *testing.T, handler http.Handler, confs ...ClientConf) *Client { 175 | c = &Client{ 176 | t: t, 177 | conf: httpexpect.Config{ 178 | TestName: t.Name(), 179 | Client: &http.Client{ 180 | Transport: httpexpect.NewBinder(handler), 181 | Jar: httpexpect.NewCookieJar(), 182 | }, 183 | Reporter: httpexpect.NewAssertReporter(t), 184 | Printers: []httpexpect.Printer{ 185 | NewDebugPrinter(t, true), 186 | // httpexpect.NewCompactPrinter(t), 187 | // httpexpect.NewCurlPrinter(t), 188 | }, 189 | // Printers: []httpexpect.Printer{ 190 | // httpexpect.NewCompactPrinter(t), 191 | // }, 192 | Formatter: &httpexpect.DefaultFormatter{ 193 | // DisablePaths: true, 194 | // DisableDiffs: true, 195 | // FloatFormat: httpexpect.FloatFormatScientific, 196 | ColorMode: httpexpect.ColorModeAlways, 197 | // LineWidth: 80, 198 | }, 199 | }, 200 | headers: map[string]string{}, 201 | tokenIndex: "data.accessToken", 202 | loginApi: "/login", 203 | logoutApi: "/logout", 204 | } 205 | if len(confs) > 0 { 206 | for _, conf := range confs { 207 | switch conf.Key { 208 | case BASE_URL: 209 | c.conf.BaseURL = conf.Value 210 | case TOKEN_INDEX: 211 | c.tokenIndex = conf.Value 212 | case LoginApi: 213 | c.loginApi = conf.Value 214 | case LogoutApi: 215 | c.logoutApi = conf.Value 216 | } 217 | } 218 | } 219 | c.expect = httpexpect.WithConfig(c.conf) 220 | return c 221 | } 222 | 223 | func (c *Client) SwitchT(t *testing.T) { 224 | c.t = t 225 | c.conf.TestName = t.Name() 226 | c.conf.Reporter = httpexpect.NewAssertReporter(t) 227 | c.conf.Printers = []httpexpect.Printer{ 228 | NewDebugPrinter(t, true), 229 | // httpexpect.NewCompactPrinter(t), 230 | // httpexpect.NewCurlPrinter(t), 231 | } 232 | c.expect = httpexpect.WithConfig(c.conf) 233 | c.expect = c.expect.Builder(func(req *httpexpect.Request) { 234 | req.WithHeaders(c.headers) 235 | }) 236 | } 237 | 238 | // Login for http login 239 | func (c *Client) Login(res Responses, 
paramFuncs ...paramFunc) error { 240 | if len(paramFuncs) == 0 { 241 | paramFuncs = append(paramFuncs, LoginFunc) 242 | } 243 | c.POST(c.loginApi, res, paramFuncs...) 244 | token := res.GetString(c.tokenIndex) 245 | fmt.Printf("token %s is '%s'\n", c.tokenIndex, token) 246 | if token == "" { 247 | return fmt.Errorf("token %s is empty", c.tokenIndex) 248 | } 249 | c.headers["Authorization"] = str.Join("Bearer ", token) 250 | c.expect = c.expect.Builder(func(req *httpexpect.Request) { 251 | req.WithHeaders(c.headers) 252 | }) 253 | return nil 254 | } 255 | 256 | // Logout for http logout 257 | func (c *Client) Logout(res Responses) { 258 | if res == nil { 259 | res = SuccessResponse 260 | } 261 | c.GET(c.logoutApi, res) 262 | 263 | c.headers["Authorization"] = "" 264 | c.expect = c.expect.Builder(func(req *httpexpect.Request) { 265 | req.WithHeaders(c.headers) 266 | }) 267 | } 268 | 269 | type File struct { 270 | Key string 271 | Path string 272 | Reader io.Reader 273 | } 274 | 275 | // checkStatus check what's http response stauts want 276 | func (c *Client) checkStatus() int { 277 | if c.status == 0 { 278 | return http.StatusOK 279 | } 280 | return c.status 281 | } 282 | 283 | // SetStatus set what's http response stauts want 284 | func (c *Client) SetStatus(status int) *Client { 285 | c.status = status 286 | return c 287 | } 288 | 289 | // AddHeader 290 | func (c *Client) AddHeader(k, v string) *Client { 291 | c.headers[k] = v 292 | return c 293 | } 294 | 295 | // POST 296 | func (c *Client) POST(url string, resps any, paramFuncs ...paramFunc) { 297 | req := c.expect.POST(url) 298 | if len(paramFuncs) > 0 { 299 | for _, f := range paramFuncs { 300 | req = f(req) 301 | } 302 | } 303 | if testRes, ok := resps.(Responses); ok { 304 | obj := req.Expect().Status(c.checkStatus()).JSON() 305 | testRes.Test(obj) 306 | } else if testRes, ok := resps.([]Responses); ok { 307 | array := req.Expect().Status(c.checkStatus()).JSON().Array() 308 | for i, v := range testRes { 309 
| v.Test(array.Value(i)) 310 | } 311 | } else { 312 | log.Println("data type error") 313 | } 314 | } 315 | 316 | // PUT 317 | func (c *Client) PUT(url string, resps any, paramFuncs ...paramFunc) { 318 | req := c.expect.PUT(url) 319 | if len(paramFuncs) > 0 { 320 | for _, f := range paramFuncs { 321 | req = f(req) 322 | } 323 | } 324 | if testRes, ok := resps.(Responses); ok { 325 | obj := req.Expect().Status(c.checkStatus()).JSON() 326 | testRes.Test(obj) 327 | } else if testRes, ok := resps.([]Responses); ok { 328 | array := req.Expect().Status(c.checkStatus()).JSON().Array() 329 | for i, v := range testRes { 330 | v.Test(array.Value(i)) 331 | } 332 | } else { 333 | log.Println("data type error") 334 | } 335 | } 336 | 337 | // UPLOAD 338 | func (c *Client) UPLOAD(url string, resps any, paramFuncs ...paramFunc) { 339 | req := c.expect.POST(url) 340 | if len(paramFuncs) > 0 { 341 | for _, f := range paramFuncs { 342 | req = f(req) 343 | } 344 | } 345 | if testRes, ok := resps.(Responses); ok { 346 | obj := req.Expect().Status(c.checkStatus()).JSON() 347 | testRes.Test(obj) 348 | } else if testRes, ok := resps.([]Responses); ok { 349 | array := req.Expect().Status(c.checkStatus()).JSON().Array() 350 | for i, v := range testRes { 351 | v.Test(array.Value(i)) 352 | } 353 | } else { 354 | log.Println("data type error") 355 | } 356 | } 357 | 358 | // GET 359 | func (c *Client) GET(url string, resps any, paramFuncs ...paramFunc) { 360 | req := c.expect.GET(url) 361 | if len(paramFuncs) > 0 { 362 | for _, f := range paramFuncs { 363 | req = f(req) 364 | } 365 | } 366 | if resp, ok := resps.(Responses); ok { 367 | obj := req.Expect().Status(c.checkStatus()).JSON() 368 | resp.Test(obj) 369 | } else if resp, ok := resps.([]Responses); ok { 370 | array := req.Expect().Status(c.checkStatus()).JSON().Array() 371 | for i, v := range resp { 372 | v.Test(array.Value(i)) 373 | } 374 | } else { 375 | log.Println("data type error") 376 | } 377 | } 378 | 379 | // DOWNLOAD 380 | func (c 
*Client) DOWNLOAD(url string, resps any, paramFuncs ...paramFunc) string { 381 | req := c.expect.GET(url) 382 | if len(paramFuncs) > 0 { 383 | for _, f := range paramFuncs { 384 | req = f(req) 385 | } 386 | } 387 | 388 | return req.Expect().Status(c.checkStatus()).Body().NotEmpty().Raw() 389 | } 390 | 391 | // DELETE 392 | func (c *Client) DELETE(url string, resps any, paramFuncs ...paramFunc) { 393 | req := c.expect.DELETE(url) 394 | if len(paramFuncs) > 0 { 395 | for _, f := range paramFuncs { 396 | req = f(req) 397 | } 398 | } 399 | if testRes, ok := resps.(Responses); ok { 400 | obj := req.Expect().Status(c.checkStatus()).JSON() 401 | testRes.Test(obj) 402 | } else if testRes, ok := resps.([]Responses); ok { 403 | array := req.Expect().Status(c.checkStatus()).JSON().Array() 404 | for i, v := range testRes { 405 | v.Test(array.Value(i)) 406 | } 407 | } else { 408 | log.Println("data type error") 409 | } 410 | } 411 | -------------------------------------------------------------------------------- /httptest/base_test.go: -------------------------------------------------------------------------------- 1 | package httptest 2 | 3 | import ( 4 | "net/http" 5 | "os" 6 | "testing" 7 | 8 | "github.com/gin-gonic/gin" 9 | "github.com/snowlyg/helper/dir" 10 | ) 11 | 12 | type Request struct { 13 | Message string `json:"message" form:"message" uri:"message"` 14 | } 15 | 16 | // GinHandler Create add /example route to gin engine 17 | func GinHandler(r *gin.Engine) *gin.Engine { 18 | 19 | // Add route to the gin engine 20 | r.GET("/example", func(c *gin.Context) { 21 | var req Request 22 | if errs := c.ShouldBind(&req); errs != nil { 23 | c.JSON(http.StatusBadRequest, nil) 24 | return 25 | } 26 | message := "pong" 27 | if req.Message != "" { 28 | message = req.Message 29 | } 30 | c.JSON(http.StatusOK, gin.H{ 31 | "status": 200, 32 | "message": "OK", 33 | "data": gin.H{ 34 | "message": message, 35 | }, 36 | }) 37 | }) 38 | 39 | // Add route to the gin engine 40 | 
r.GET("/array", func(c *gin.Context) { 41 | var req Request 42 | if errs := c.ShouldBind(&req); errs != nil { 43 | c.JSON(http.StatusBadRequest, nil) 44 | return 45 | } 46 | c.JSON(http.StatusOK, []string{"1", "2"}) 47 | }) 48 | 49 | // Add route to the gin engine 50 | r.GET("/mutil", func(c *gin.Context) { 51 | var req Request 52 | if errs := c.ShouldBind(&req); errs != nil { 53 | c.JSON(http.StatusBadRequest, nil) 54 | return 55 | } 56 | message := "pong" 57 | if req.Message != "" { 58 | message = req.Message 59 | } 60 | c.JSON(http.StatusOK, gin.H{ 61 | "status": 200, 62 | "message": "OK", 63 | "data": []gin.H{ 64 | {"message": message}, 65 | {"message": message}, 66 | }, 67 | }) 68 | }) 69 | 70 | // Add route to the gin engine 71 | r.POST("/example", func(c *gin.Context) { 72 | var req Request 73 | if errs := c.ShouldBindJSON(&req); errs != nil { 74 | c.JSON(http.StatusBadRequest, gin.H{ 75 | "status": http.StatusBadRequest, 76 | "message": "FAIL", 77 | }) 78 | return 79 | } 80 | message := "pong" 81 | if req.Message != "" { 82 | message = req.Message 83 | } 84 | c.JSON(http.StatusOK, gin.H{ 85 | "status": 200, 86 | "message": "OK", 87 | "data": gin.H{ 88 | "message": message, 89 | }, 90 | }) 91 | }) 92 | 93 | // Add route to the gin engine 94 | r.POST("/upload", func(c *gin.Context) { 95 | _, _, err := c.Request.FormFile("file") 96 | if err != nil { 97 | c.JSON(http.StatusBadRequest, gin.H{ 98 | "status": http.StatusBadRequest, 99 | "message": "FAIL", 100 | }) 101 | return 102 | } 103 | 104 | c.JSON(http.StatusOK, gin.H{ 105 | "status": 200, 106 | "message": "OK", 107 | }) 108 | }) 109 | 110 | type RequestId struct { 111 | Id uint `json:"id" form:"id" uri:"id"` 112 | } 113 | 114 | // Add route to the gin engine 115 | r.DELETE("/example/:id", func(c *gin.Context) { 116 | var req RequestId 117 | if errs := c.ShouldBindUri(&req); errs != nil { 118 | c.JSON(http.StatusBadRequest, nil) 119 | return 120 | } 121 | c.JSON(http.StatusOK, gin.H{ 122 | "status": 200, 123 
| "message": "OK", 124 | "data": gin.H{ 125 | "id": req.Id, 126 | }, 127 | }) 128 | }) 129 | 130 | // Add route to the gin engine 131 | r.POST("login", func(c *gin.Context) { 132 | c.JSON(http.StatusOK, gin.H{ 133 | "status": 200, 134 | "message": "OK", 135 | "data": gin.H{ 136 | "AccessToken": "EIIDFJDIKFJJIdfdkfk.uisdifsdfisdouf", 137 | "user": gin.H{ 138 | "id": 1, 139 | }, 140 | }, 141 | }) 142 | }) 143 | 144 | // Add route to the gin engine 145 | r.GET("logout", func(c *gin.Context) { 146 | c.JSON(http.StatusOK, gin.H{ 147 | "status": 200, 148 | "message": "OK", 149 | }) 150 | }) 151 | 152 | // Add route to the gin engine 153 | r.GET("header", func(c *gin.Context) { 154 | c.GetHeader("Authorization") 155 | c.JSON(http.StatusOK, gin.H{ 156 | "status": 200, 157 | "message": "OK", 158 | "data": gin.H{ 159 | "Authorization": c.GetHeader("Authorization"), 160 | }, 161 | }) 162 | }) 163 | 164 | // return gin engine with newly added route 165 | return r 166 | } 167 | 168 | func TestNewClient(t *testing.T) { 169 | engine := gin.New() 170 | // Create httpexpect instance 171 | client := NewClient(t, GinHandler(engine)) 172 | client.GET("/example", NewResponses(http.StatusOK, "OK", Responses{{Key: "message", Value: "pong"}})) 173 | client.DELETE("/example/1", NewResponses(http.StatusOK, "OK", Responses{{Key: "id", Value: 1}})) 174 | } 175 | 176 | func TestNewWithQueryObjectParamFunc(t *testing.T) { 177 | engine := gin.New() 178 | // Create httpexpect instance 179 | client := NewClient(t, GinHandler(engine)) 180 | pageKeys := Responses{{Key: "message", Value: "message"}} 181 | client.GET("/example", NewResponses(http.StatusOK, "OK", pageKeys), NewWithQueryObjectParamFunc(map[string]interface{}{"message": "message"})) 182 | } 183 | 184 | func TestNewNewWithJsonParamFunc(t *testing.T) { 185 | engine := gin.New() 186 | // Create httpexpect instance 187 | client := NewClient(t, GinHandler(engine)) 188 | client.POST("/example", NewResponses(http.StatusOK, "OK", Responses{{Key: 
"message", Value: "message"}}), NewWithJsonParamFunc(map[string]interface{}{"message": "message"})) 189 | client.POST("/example", NewResponses(http.StatusOK, "OK", Responses{{Key: "message", Value: "pong"}}), NewWithJsonParamFunc(map[string]interface{}{"message": ""})) 190 | } 191 | 192 | func TestNewResponses(t *testing.T) { 193 | engine := gin.New() 194 | // Create httpexpect instance 195 | client := NewClient(t, GinHandler(engine)) 196 | 197 | client.GET("/example", NewResponses(http.StatusOK, "OK", Responses{{Key: "message", Value: "pong"}})) 198 | client.GET("/mutil", NewResponses(http.StatusOK, "OK", Responses{{Key: "message", Value: "pong"}}, Responses{{Key: "message", Value: "pong"}})) 199 | client.SetStatus(http.StatusBadRequest).POST("/example", NewResponses(http.StatusBadRequest, "FAIL", nil)) 200 | } 201 | 202 | func TestNewResponsesWithLength(t *testing.T) { 203 | engine := gin.New() 204 | // Create httpexpect instance 205 | client := NewClient(t, GinHandler(engine)) 206 | res := []Responses{{{Key: "message", Value: "pong"}}, {{Key: "message", Value: "pong"}}} 207 | client.GET("/mutil", NewResponsesWithLength(http.StatusOK, "OK", res, 2)) 208 | } 209 | 210 | func TestNewWithFileParamFunc(t *testing.T) { 211 | name := "test_img.jpg" 212 | if _, err := dir.WriteString("./"+name, ""); err != nil { 213 | t.Fatal(err.Error()) 214 | } 215 | defer os.Remove("./" + name) 216 | 217 | engine := gin.New() 218 | // Create httpexpect instance 219 | client := NewClient(t, GinHandler(engine)) 220 | fh, _ := os.Open("./" + name) 221 | defer fh.Close() 222 | 223 | uf := []File{{Key: "file", Path: name, Reader: fh}} 224 | client.UPLOAD("/upload", SuccessResponse, NewWithFileParamFunc(uf, nil)) 225 | } 226 | 227 | func TestLogin(t *testing.T) { 228 | engine := gin.New() 229 | // Create httpexpect instance 230 | client := NewClient(t, GinHandler(engine), NewTokenIndexConf("data.AccessToken")) 231 | x := Responses{{Key: "AccessToken", Value: 
"EIIDFJDIKFJJIdfdkfk.uisdifsdfisdouf"}, {Key: "user", Value: Responses{{Key: "id", Value: 1}}}} 232 | err := client.Login(NewResponses(http.StatusOK, "OK", x)) 233 | if err != nil { 234 | t.Error(err.Error()) 235 | } 236 | if x.GetId("data.user.id") == 0 { 237 | t.Error("id is 0") 238 | } 239 | client.GET("/header", NewResponses(http.StatusOK, "OK", Responses{{Key: "Authorization", Value: "Bearer EIIDFJDIKFJJIdfdkfk.uisdifsdfisdouf"}})) 240 | } 241 | 242 | func TestLogout(t *testing.T) { 243 | engine := gin.New() 244 | client := NewClient(t, GinHandler(engine)) 245 | client.Logout(SuccessResponse) 246 | } 247 | -------------------------------------------------------------------------------- /httptest/common.go: -------------------------------------------------------------------------------- 1 | package httptest 2 | 3 | // // BeforeTestMainGin 4 | // func BeforeTestMainGin(party func(wi *WebServer), seed func(wi *WebServer, mc *MigrationCmd)) (string, *WebServer) { 5 | // dbType := admin.TestDbType 6 | // if dbType != "" { 7 | // CONFIG.System.DbType = dbType 8 | // } 9 | // if dbType == "redis" { 10 | // if err := cache.Recover(); err != nil { 11 | // log.Printf("cache recover fail:%s\n", err.Error()) 12 | // } 13 | // } 14 | // if err := Recover(); err != nil { 15 | // log.Printf("web recover fail:%s\n", err.Error()) 16 | // } 17 | 18 | // node, _ := snowflake.NewNode(1) 19 | // uuid := str.Join("gin", "_", node.Generate().String()) 20 | 21 | // CONFIG.DbName = uuid 22 | // if user := admin.TestMysqlName; user != "" { 23 | // CONFIG.Username = user 24 | // } 25 | // if pwd := admin.TestMysqlPwd; pwd != "" { 26 | // CONFIG.Password = pwd 27 | // } 28 | // if addr := admin.TestMysqlAddr; addr != "" { 29 | // CONFIG.Path = addr 30 | // } 31 | // CONFIG.LogMode = true 32 | // if err := Recover(); err != nil { 33 | // log.Printf("databse recover fail:%s\n", err.Error()) 34 | // } 35 | 36 | // if Instance() == nil { 37 | // fmt.Println("database instance is nil") 38 | 
// return uuid, nil 39 | // } 40 | 41 | // wi := Init() 42 | // party(wi) 43 | // StartTest(wi) 44 | 45 | // mc := New() 46 | // seed(wi, mc) 47 | // err := mc.Migrate() 48 | // if err != nil { 49 | // fmt.Printf("migrate fail: [%s]", err.Error()) 50 | // return uuid, nil 51 | // } 52 | // err = mc.Seed() 53 | // if err != nil { 54 | // fmt.Printf("seed fail: [%s]", err.Error()) 55 | // return uuid, nil 56 | // } 57 | 58 | // return uuid, wi 59 | // } 60 | 61 | // func AfterTestMain(uuid string, isDelDb bool) { 62 | // if isDelDb { 63 | // dsn := CONFIG.BaseDsn() 64 | // if err := DorpDB(dsn, "mysql", uuid); err != nil { 65 | // log.Printf("drop table(%s) on dsn(%s) fail %s\n", uuid, dsn, err.Error()) 66 | // } 67 | // } 68 | 69 | // if db, _ := Instance().DB(); db != nil { 70 | // db.Close() 71 | // } 72 | 73 | // // defer operation.Remove() 74 | // // defer casbin.Remove() 75 | // // defer Remove() 76 | // // defer Remove() 77 | // } 78 | -------------------------------------------------------------------------------- /httptest/printer.go: -------------------------------------------------------------------------------- 1 | package httptest 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "log" 7 | "net/http" 8 | "net/http/httputil" 9 | "strings" 10 | "time" 11 | 12 | "github.com/gavv/httpexpect/v2" 13 | "github.com/gorilla/websocket" 14 | ) 15 | 16 | // DebugPrinter implements Printer and WebsocketPrinter. 17 | // Uses net/http/httputil to dump both requests and responses. 18 | // Also prints all websocket messages. 19 | type DebugPrinter struct { 20 | logger httpexpect.Logger 21 | body bool 22 | } 23 | 24 | var contentTypes = "text/html,image/jpeg,image/png,video/mp4" 25 | 26 | // NewDebugPrinter returns a new DebugPrinter given a logger and body 27 | // flag. If body is true, request and response body is also printed. 
28 | func NewDebugPrinter(logger httpexpect.Logger, body bool) DebugPrinter { 29 | return DebugPrinter{logger, body} 30 | } 31 | 32 | // Request implements Printer.Request. 33 | func (p DebugPrinter) Request(req *http.Request) { 34 | if req == nil { 35 | return 36 | } 37 | log.Printf("Content-Type:%s\n", req.Header.Get("Content-Type")) 38 | if req.URL.Path == "/api/v1/file/upload" || strings.Contains(req.Header.Get("Content-Type"), "multipart/form-data") { 39 | p.body = false 40 | } 41 | 42 | for _, contentType := range strings.Split(contentTypes, ",") { 43 | if !p.body { 44 | continue 45 | } 46 | if req.Header.Get("Content-Type") == contentType { 47 | p.body = false 48 | } 49 | } 50 | 51 | dump, err := httputil.DumpRequest(req, p.body) 52 | if err != nil { 53 | panic(err) 54 | } 55 | p.logger.Logf("%s", dump) 56 | } 57 | 58 | // Response implements Printer.Response. 59 | func (p DebugPrinter) Response(resp *http.Response, duration time.Duration) { 60 | if resp == nil { 61 | return 62 | } 63 | for _, contentType := range strings.Split(contentTypes, ",") { 64 | if !p.body { 65 | continue 66 | } 67 | if resp.Header.Get("Content-Type") == contentType { 68 | p.body = false 69 | } 70 | } 71 | 72 | dump, err := httputil.DumpResponse(resp, p.body) 73 | if err != nil { 74 | panic(err) 75 | } 76 | 77 | text := strings.Replace(string(dump), "\r\n", "\n", -1) 78 | lines := strings.SplitN(text, "\n", 2) 79 | 80 | p.logger.Logf("%s %s\n%s", lines[0], duration, lines[1]) 81 | } 82 | 83 | // WebsocketWrite implements WebsocketPrinter.WebsocketWrite. 
84 | func (p DebugPrinter) WebsocketWrite(typ int, content []byte, closeCode int) { 85 | b := &bytes.Buffer{} 86 | fmt.Fprintf(b, "-> Sent: %s", wsMessageType(typ)) 87 | if typ == websocket.CloseMessage { 88 | fmt.Fprintf(b, " %s", wsCloseCode(closeCode)) 89 | } 90 | fmt.Fprint(b, "\n") 91 | if len(content) > 0 { 92 | if typ == websocket.BinaryMessage { 93 | fmt.Fprintf(b, "%v\n", content) 94 | } else { 95 | fmt.Fprintf(b, "%s\n", content) 96 | } 97 | } 98 | fmt.Fprintf(b, "\n") 99 | p.logger.Logf(b.String()) 100 | } 101 | 102 | // WebsocketRead implements WebsocketPrinter.WebsocketRead. 103 | func (p DebugPrinter) WebsocketRead(typ int, content []byte, closeCode int) { 104 | b := &bytes.Buffer{} 105 | fmt.Fprintf(b, "<- Received: %s", wsMessageType(typ)) 106 | if typ == websocket.CloseMessage { 107 | fmt.Fprintf(b, " %s", wsCloseCode(closeCode)) 108 | } 109 | fmt.Fprint(b, "\n") 110 | if len(content) > 0 { 111 | if typ == websocket.BinaryMessage { 112 | fmt.Fprintf(b, "%v\n", content) 113 | } else { 114 | fmt.Fprintf(b, "%s\n", content) 115 | } 116 | } 117 | fmt.Fprintf(b, "\n") 118 | p.logger.Logf(b.String()) 119 | } 120 | 121 | type wsMessageType int 122 | 123 | func (wmt wsMessageType) String() string { 124 | s := "unknown" 125 | 126 | switch wmt { 127 | case websocket.TextMessage: 128 | s = "text" 129 | case websocket.BinaryMessage: 130 | s = "binary" 131 | case websocket.CloseMessage: 132 | s = "close" 133 | case websocket.PingMessage: 134 | s = "ping" 135 | case websocket.PongMessage: 136 | s = "pong" 137 | } 138 | 139 | return fmt.Sprintf("%s(%d)", s, wmt) 140 | } 141 | 142 | type wsCloseCode int 143 | 144 | // https://developer.mozilla.org/en-US/docs/Web/API/CloseEvent/code 145 | func (wcc wsCloseCode) String() string { 146 | s := "Unknown" 147 | 148 | switch wcc { 149 | case 1000: 150 | s = "NormalClosure" 151 | case 1001: 152 | s = "GoingAway" 153 | case 1002: 154 | s = "ProtocolError" 155 | case 1003: 156 | s = "UnsupportedData" 157 | case 1004: 158 | s 
= "Reserved" 159 | case 1005: 160 | s = "NoStatusReceived" 161 | case 1006: 162 | s = "AbnormalClosure" 163 | case 1007: 164 | s = "InvalidFramePayloadData" 165 | case 1008: 166 | s = "PolicyViolation" 167 | case 1009: 168 | s = "MessageTooBig" 169 | case 1010: 170 | s = "MandatoryExtension" 171 | case 1011: 172 | s = "InternalServerError" 173 | case 1012: 174 | s = "ServiceRestart" 175 | case 1013: 176 | s = "TryAgainLater" 177 | case 1014: 178 | s = "BadGateway" 179 | case 1015: 180 | s = "TLSHandshake" 181 | } 182 | 183 | return fmt.Sprintf("%s(%d)", s, wcc) 184 | } 185 | -------------------------------------------------------------------------------- /httptest/reporter.go: -------------------------------------------------------------------------------- 1 | package httptest 2 | -------------------------------------------------------------------------------- /httptest/respose.go: -------------------------------------------------------------------------------- 1 | package httptest 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "reflect" 7 | "strconv" 8 | "strings" 9 | 10 | "github.com/gavv/httpexpect/v2" 11 | "github.com/snowlyg/helper/arr" 12 | ) 13 | 14 | // Responses 15 | type Responses []Response 16 | 17 | // Response 18 | type Response struct { 19 | Type string // httpest type , if empty use IsEqual() function to test 20 | Key string // httptest data's key 21 | Value any // httptest data's value 22 | Length int // httptest data's length,when the data are array or map 23 | Func func(obj any) // httpest func, you can add your test logic ,can be empty 24 | } 25 | 26 | // Keys return Responses object key array 27 | func (res Responses) Keys() []string { 28 | keys := []string{} 29 | for _, re := range res { 30 | keys = append(keys, re.Key) 31 | } 32 | return keys 33 | } 34 | 35 | // IdKeys return Responses with id 36 | func IdKeys() Responses { 37 | return Responses{ 38 | {Key: "id", Value: 0, Type: "ge"}, 39 | } 40 | } 41 | 42 | // Test for data test 43 | 
func Test(value *httpexpect.Value, reses ...any) { 44 | for _, ks := range reses { 45 | if ks == nil { 46 | return 47 | } 48 | reflectTypeString := reflect.TypeOf(ks).String() 49 | switch reflectTypeString { 50 | case "bool": 51 | value.Boolean().IsEqual(ks.(bool)) 52 | case "string": 53 | value.String().IsEqual(ks.(string)) 54 | case "float64": 55 | value.Number().IsEqual(ks.(float64)) 56 | case "uint": 57 | value.Number().IsEqual(ks.(uint)) 58 | case "int": 59 | value.IsEqual(ks.(int)) 60 | 61 | case "[]httptest.Responses": 62 | valueLen := len(ks.([]Responses)) 63 | length := int(value.Array().Length().Raw()) 64 | value.Array().Length().IsEqual(valueLen) 65 | if length > 0 { 66 | max := 1 67 | if valueLen == length { 68 | max = length 69 | } 70 | for i := 0; i < max; i++ { 71 | ks.([]Responses)[i].Test(value.Array().Value(i)) 72 | } 73 | } 74 | 75 | case "map[int][]httptest.Responses": 76 | values := ks.(map[int][]Responses) 77 | length := len(values) 78 | value.Object().Keys().Length().IsEqual(length) 79 | if length > 0 { 80 | for key, v := range values { 81 | for _, vres := range v { 82 | vres.Test(value.Object().Value(strconv.FormatInt(int64(key), 10))) 83 | } 84 | } 85 | } 86 | case "httptest.Responses": 87 | ks.(Responses).Test(value) 88 | case "[]uint": 89 | valueLen := len(ks.([]uint)) 90 | value.Array().Length().IsEqual(valueLen) 91 | length := int(value.Array().Length().Raw()) 92 | if length > 0 { 93 | max := 1 94 | if valueLen == length { 95 | max = length 96 | } 97 | for i := 0; i < max; i++ { 98 | value.Array().Value(i).Number().IsEqual(ks.([]uint)[i]) 99 | } 100 | } 101 | 102 | case "[]string": 103 | valueLen := len(ks.([]string)) 104 | value.Array().Length().IsEqual(valueLen) 105 | length := int(value.Array().Length().Raw()) 106 | if length > 0 { 107 | max := 1 108 | if valueLen == length { 109 | max = length 110 | } 111 | for i := 0; i < max; i++ { 112 | value.Array().Value(i).String().IsEqual(ks.([]string)[i]) 113 | } 114 | } 115 | case 
"map[int]string": 116 | values := ks.(map[int]string) 117 | value.Object().Keys().Length().IsEqual(len(values)) 118 | for key, v := range values { 119 | value.Object().Value(strconv.FormatInt(int64(key), 10)).IsEqual(v) 120 | } 121 | default: 122 | continue 123 | } 124 | } 125 | } 126 | 127 | // Scan scan data form http response 128 | func Scan(object *httpexpect.Object, reses ...Responses) { 129 | if len(reses) == 0 { 130 | return 131 | } 132 | 133 | //return once 134 | if len(reses) == 1 { 135 | reses[0].Scan(object.Value("data").Object()) 136 | return 137 | } 138 | 139 | array := object.Value("data").Array() 140 | length := int(array.Length().Raw()) 141 | if length < len(reses) { 142 | fmt.Println("Return data not IsEqual keys length") 143 | array.Length().IsEqual(len(reses)) 144 | return 145 | } 146 | 147 | // return array 148 | for m, res := range reses { 149 | if res == nil { 150 | return 151 | } 152 | res.Scan(object.Value("data").Array().Value(m).Object()) 153 | } 154 | } 155 | 156 | // Test Test Responses object 157 | func (resp Responses) Test(value *httpexpect.Value) { 158 | for _, rs := range resp { 159 | if rs.Value == nil { 160 | continue 161 | } 162 | if rs.Func != nil { 163 | rs.Func(value.Object().Value(rs.Key)) 164 | 165 | } else { 166 | reflectTypeString := reflect.TypeOf(rs.Value).String() 167 | switch reflectTypeString { 168 | case "bool": 169 | value.Object().Value(rs.Key).Boolean().IsEqual(rs.Value.(bool)) 170 | case "string": 171 | if strings.ToLower(rs.Type) == "notempty" { 172 | value.Object().Value(rs.Key).String().NotEmpty() 173 | } else { 174 | value.Object().Value(rs.Key).String().IsEqual(rs.Value.(string)) 175 | } 176 | case "float64": 177 | if strings.ToLower(rs.Type) == "ge" { 178 | value.Object().Value(rs.Key).Number().Ge(rs.Value.(float64)) 179 | } else { 180 | value.Object().Value(rs.Key).Number().IsEqual(rs.Value.(float64)) 181 | } 182 | case "uint": 183 | if strings.ToLower(rs.Type) == "ge" { 184 | 
value.Object().Value(rs.Key).Number().Ge(rs.Value.(uint)) 185 | } else { 186 | value.Object().Value(rs.Key).Number().IsEqual(rs.Value.(uint)) 187 | } 188 | case "int": 189 | if strings.ToLower(rs.Type) == "ge" { 190 | value.Object().Value(rs.Key).Number().Ge(rs.Value.(int)) 191 | } else { 192 | value.Object().Value(rs.Key).Number().IsEqual(rs.Value.(int)) 193 | } 194 | case "[]httptest.Responses": 195 | valueLen := len(rs.Value.([]Responses)) 196 | length := int(value.Object().Value(rs.Key).Array().Length().Raw()) 197 | value.Object().Value(rs.Key).Array().Length().IsEqual(valueLen) 198 | if length > 0 { 199 | max := 1 200 | if rs.Length > 0 { 201 | max = rs.Length 202 | } 203 | if valueLen == length { 204 | max = length 205 | } 206 | if valueLen > 0 { 207 | for i := 0; i < max; i++ { 208 | rs.Value.([]Responses)[i].Test(value.Object().Value(rs.Key).Array().Value(i)) 209 | } 210 | } 211 | } 212 | 213 | case "map[int][]httptest.Responses": 214 | values := rs.Value.(map[int][]Responses) 215 | length := len(values) 216 | value.Object().Value(rs.Key).Object().Keys().Length().IsEqual(length) 217 | if length > 0 { 218 | for key, v := range values { 219 | for _, vres := range v { 220 | vres.Test(value.Object().Value(rs.Key).Object().Value(strconv.FormatInt(int64(key), 10))) 221 | } 222 | } 223 | } 224 | case "httptest.Responses": 225 | rs.Value.(Responses).Test(value.Object().Value(rs.Key)) 226 | case "[]uint": 227 | valueLen := len(rs.Value.([]uint)) 228 | value.Object().Value(rs.Key).Array().Length().IsEqual(valueLen) 229 | length := int(value.Object().Value(rs.Key).Array().Length().Raw()) 230 | if length > 0 { 231 | max := 1 232 | if rs.Length > 0 { 233 | max = rs.Length 234 | } 235 | if valueLen == length { 236 | max = length 237 | } 238 | for i := 0; i < max; i++ { 239 | value.Object().Value(rs.Key).Array().ContainsAny(rs.Value.([]uint)[i]) 240 | } 241 | } 242 | 243 | case "[]string": 244 | 245 | if strings.ToLower(rs.Type) == "null" { 246 | 
value.Object().Value(rs.Key).IsNull() 247 | } else if strings.ToLower(rs.Type) == "notnull" { 248 | value.Object().Value(rs.Key).NotNull() 249 | } else { 250 | valueLen := len(rs.Value.([]string)) 251 | value.Object().Value(rs.Key).Array().Length().IsEqual(valueLen) 252 | length := int(value.Object().Value(rs.Key).Array().Length().Raw()) 253 | if length > 0 { 254 | max := 1 255 | if rs.Length > 0 { 256 | max = rs.Length 257 | } 258 | if valueLen == length { 259 | max = length 260 | } 261 | for i := 0; i < max; i++ { 262 | value.Object().Value(rs.Key).Array().ContainsAny(rs.Value.([]string)[i]) 263 | } 264 | } 265 | } 266 | case "map[int]string": 267 | if strings.ToLower(rs.Type) == "null" { 268 | value.Object().Value(rs.Key).IsNull() 269 | } else if strings.ToLower(rs.Type) == "notnull" { 270 | value.Object().Value(rs.Key).NotNull() 271 | } else { 272 | values := rs.Value.(map[int]string) 273 | value.Object().Value(rs.Key).Object().Keys().Length().IsEqual(len(values)) 274 | for key, v := range values { 275 | value.Object().Value(rs.Key).Object().Value(strconv.FormatInt(int64(key), 10)).IsEqual(v) 276 | } 277 | } 278 | default: 279 | continue 280 | } 281 | } 282 | } 283 | resp.Scan(value.Object()) 284 | } 285 | 286 | // Scan Scan response data to Responses object. 
287 | func (res Responses) Scan(object *httpexpect.Object) { 288 | for k, rk := range res { 289 | if !Exist(object, rk.Key) { 290 | continue 291 | } 292 | if rk.Value == nil { 293 | continue 294 | } 295 | valueTypeName := reflect.TypeOf(rk.Value).String() 296 | switch valueTypeName { 297 | case "bool": 298 | res[k].Value = object.Value(rk.Key).Boolean().Raw() 299 | case "string": 300 | res[k].Value = object.Value(rk.Key).String().Raw() 301 | case "uint": 302 | res[k].Value = uint(object.Value(rk.Key).Number().Raw()) 303 | case "int": 304 | res[k].Value = int(object.Value(rk.Key).Number().Raw()) 305 | case "int32": 306 | res[k].Value = int32(object.Value(rk.Key).Number().Raw()) 307 | case "float64": 308 | res[k].Value = object.Value(rk.Key).Number().Raw() 309 | case "[]httptest.Responses": 310 | valueLen := len(res[k].Value.([]Responses)) 311 | if rk.Length > 0 { 312 | valueLen = rk.Length 313 | } 314 | object.Value(rk.Key).Array().Length().IsEqual(valueLen) 315 | length := int(object.Value(rk.Key).Array().Length().Raw()) 316 | if length > 0 { 317 | max := 1 318 | if rk.Length > 0 { 319 | max = rk.Length 320 | } 321 | if valueLen == length { 322 | max = length 323 | } 324 | if valueLen > 0 { 325 | for i := 0; i < max; i++ { 326 | res[k].Value.([]Responses)[i].Scan(object.Value(rk.Key).Array().Value(i).Object()) 327 | } 328 | } 329 | } 330 | case "httptest.Responses": 331 | rk.Value.(Responses).Scan(object.Value(rk.Key).Object()) 332 | case "[]string": 333 | if strings.ToLower(rk.Type) == "null" { 334 | res[k].Value = []string{} 335 | } else if strings.ToLower(rk.Type) == "notnull" { 336 | continue 337 | } else { 338 | length := int(object.Value(rk.Key).Array().Length().Raw()) 339 | if length == 0 { 340 | continue 341 | } 342 | reskey, ok := res[k].Value.([]string) 343 | if ok { 344 | var strings []string 345 | for i := 0; i < length; i++ { 346 | strings = append(reskey, object.Value(rk.Key).Array().Value(i).String().Raw()) 347 | } 348 | res[k].Value = strings 349 | 
} 350 | } 351 | default: 352 | continue 353 | } 354 | } 355 | } 356 | 357 | // Exist Check object keys if the key is in the keys array. 358 | func Exist(object *httpexpect.Object, key string) bool { 359 | objectKyes := object.Keys().Raw() 360 | for _, objectKey := range objectKyes { 361 | if key == objectKey.(string) { 362 | return true 363 | } 364 | } 365 | return false 366 | } 367 | 368 | // GetString return string value. 369 | func (res Responses) GetString(key ...string) string { 370 | if len(key) == 0 { 371 | return "" 372 | } 373 | 374 | if len(key) == 1 { 375 | k := key[0] 376 | if strings.Contains(k, ".") { 377 | keys := strings.Split(k, ".") 378 | if len(keys) == 0 { 379 | return "" 380 | } 381 | key = keys 382 | } 383 | } 384 | 385 | for i := 0; i < len(key); i++ { 386 | for m, rk := range res { 387 | if rk.Value == nil { 388 | return "" 389 | } 390 | reflectTypeString := reflect.TypeOf(rk.Value).String() 391 | if key[i] == rk.Key { 392 | switch reflectTypeString { 393 | case "string": 394 | return rk.Value.(string) 395 | case "httptest.Responses": 396 | return res[m].Value.(Responses).GetString(key[i+1:]...) 397 | } 398 | } 399 | } 400 | 401 | } 402 | return "" 403 | } 404 | 405 | // GetStrArray return string array value. 
406 | func (rks Responses) GetStrArray(key string) []string { 407 | for _, rk := range rks { 408 | if key == rk.Key { 409 | if rk.Value == nil { 410 | return nil 411 | } 412 | switch reflect.TypeOf(rk.Value).String() { 413 | case "[]string": 414 | return rk.Value.([]string) 415 | } 416 | } 417 | } 418 | return nil 419 | } 420 | 421 | // GetResponses return Resposnes Array value 422 | func (rks Responses) GetResponses(key string) []Responses { 423 | for _, rk := range rks { 424 | if key == rk.Key { 425 | if rk.Value == nil { 426 | return nil 427 | } 428 | switch reflect.TypeOf(rk.Value).String() { 429 | case "[]httptest.Responses": 430 | return rk.Value.([]Responses) 431 | } 432 | } 433 | } 434 | return nil 435 | } 436 | 437 | // GetResponsereturn Resposnes value 438 | func (rks Responses) GetResponse(key string) Responses { 439 | for _, rk := range rks { 440 | if key == rk.Key { 441 | if rk.Value == nil { 442 | return nil 443 | } 444 | switch reflect.TypeOf(rk.Value).String() { 445 | case "httptest.Responses": 446 | return rk.Value.(Responses) 447 | } 448 | } 449 | } 450 | return nil 451 | } 452 | 453 | // GetUint return uint value 454 | func (rks Responses) GetUint(key ...string) uint { 455 | 456 | if len(key) == 0 { 457 | return 0 458 | } 459 | 460 | if len(key) == 1 { 461 | k := key[0] 462 | if strings.Contains(k, ".") { 463 | keys := strings.Split(k, ".") 464 | if len(keys) == 0 { 465 | return 0 466 | } 467 | key = keys 468 | } 469 | } 470 | 471 | for i := 0; i < len(key); i++ { 472 | for m, rk := range rks { 473 | if key[i] == rk.Key { 474 | if rk.Value == nil { 475 | return 0 476 | } 477 | valueTypeName := reflect.TypeOf(rk.Value).String() 478 | switch valueTypeName { 479 | case "float64": 480 | return uint(rk.Value.(float64)) 481 | case "int32": 482 | return uint(rk.Value.(int32)) 483 | case "uint": 484 | return rk.Value.(uint) 485 | case "int": 486 | return uint(rk.Value.(int)) 487 | case "httptest.Responses": 488 | return 
rks[m].Value.(Responses).GetUint(key[i:]...) 489 | } 490 | } 491 | } 492 | } 493 | 494 | return 0 495 | } 496 | 497 | // GetInt return int value 498 | func (rks Responses) GetInt(key ...string) int { 499 | if len(key) == 0 { 500 | return 0 501 | } 502 | 503 | if len(key) == 1 { 504 | k := key[0] 505 | if strings.Contains(k, ".") { 506 | keys := strings.Split(k, ".") 507 | if len(keys) == 0 { 508 | return 0 509 | } 510 | key = keys 511 | } 512 | } 513 | 514 | for i := 0; i < len(key); i++ { 515 | for m, rk := range rks { 516 | if key[i] == rk.Key { 517 | if rk.Value == nil { 518 | return 0 519 | } 520 | switch reflect.TypeOf(rk.Value).String() { 521 | case "float64": 522 | return int(rk.Value.(float64)) 523 | case "int": 524 | return rk.Value.(int) 525 | case "int32": 526 | return int(rk.Value.(int32)) 527 | case "uint": 528 | return int(rk.Value.(uint)) 529 | case "httptest.Responses": 530 | return rks[m].Value.(Responses).GetInt(key[i+1:]...) 531 | } 532 | } 533 | } 534 | } 535 | 536 | return 0 537 | } 538 | 539 | // GetInt32 return int32. 540 | func (rks Responses) GetInt32(key ...string) int32 { 541 | if len(key) == 0 { 542 | return 0 543 | } 544 | if len(key) == 1 { 545 | k := key[0] 546 | if strings.Contains(k, ".") { 547 | keys := strings.Split(k, ".") 548 | if len(keys) == 0 { 549 | return 0 550 | } 551 | key = keys 552 | } 553 | } 554 | for i := 0; i < len(key); i++ { 555 | for m, rk := range rks { 556 | if key[i] == rk.Key { 557 | if rk.Value == nil { 558 | return 0 559 | } 560 | switch reflect.TypeOf(rk.Value).String() { 561 | case "float64": 562 | return int32(rk.Value.(float64)) 563 | case "int32": 564 | return rk.Value.(int32) 565 | case "int": 566 | return int32(rk.Value.(int)) 567 | case "uint": 568 | return int32(rk.Value.(uint)) 569 | case "httptest.Responses": 570 | return rks[m].Value.(Responses).GetInt32(key[i+1:]...) 
571 | } 572 | } 573 | } 574 | } 575 | return 0 576 | } 577 | 578 | // GetFloat64 return float64 579 | func (rks Responses) GetFloat64(key ...string) float64 { 580 | if len(key) == 0 { 581 | return 0 582 | } 583 | if len(key) == 1 { 584 | k := key[0] 585 | if strings.Contains(k, ".") { 586 | keys := strings.Split(k, ".") 587 | if len(keys) == 0 { 588 | return 0 589 | } 590 | key = keys 591 | } 592 | } 593 | for i := 0; i < len(key); i++ { 594 | for m, rk := range rks { 595 | if key[i] == rk.Key { 596 | if rk.Value == nil { 597 | return 0 598 | } 599 | switch reflect.TypeOf(rk.Value).String() { 600 | case "float64": 601 | return rk.Value.(float64) 602 | case "int": 603 | return float64(rk.Value.(int)) 604 | case "int32": 605 | return float64(rk.Value.(int32)) 606 | case "uint": 607 | return float64(rk.Value.(uint)) 608 | case "httptest.Responses": 609 | return rks[m].Value.(Responses).GetFloat64(key[i+1:]...) 610 | } 611 | } 612 | } 613 | } 614 | return 0 615 | } 616 | 617 | // GetId return id 618 | func (res Responses) GetId(key ...string) uint { 619 | if len(key) == 0 { 620 | key = append(key, "data", "id") 621 | } 622 | return res.GetUint(key...) 
623 | } 624 | 625 | var NotEmptyKey = arr.NewCheckArrayType(0) 626 | 627 | // Schema 628 | func Schema(str []byte) (Responses, error) { 629 | objs := Responses{} 630 | j := map[string]any{} 631 | if err := json.Unmarshal(str, &j); err != nil { 632 | return objs, fmt.Errorf("json unmarshal error %w", err) 633 | } 634 | if o, err := schema(j); err != nil { 635 | return objs, err 636 | } else { 637 | objs = o 638 | } 639 | return objs, nil 640 | } 641 | 642 | func (r Responses) Replace(key string, value any, testType ...string) { 643 | if len(r) == 0 { 644 | return 645 | } 646 | 647 | if !strings.Contains(key, ".") { 648 | for i1, k1 := range r { 649 | if k1.Key == key { 650 | r[i1].Value = value 651 | if len(testType) > 0 && testType[0] != "" { 652 | r[i1].Type = testType[0] 653 | } 654 | } 655 | } 656 | return 657 | } 658 | keys := strings.Split(key, ".") 659 | if len(keys) == 1 { 660 | for i1, k1 := range r { 661 | if k1.Key == keys[0] { 662 | r[i1].Value = value 663 | if len(testType) > 0 && testType[0] != "" { 664 | r[i1].Type = testType[0] 665 | } 666 | } 667 | } 668 | return 669 | } 670 | for i1, k1 := range r { 671 | if k1.Key != keys[0] || k1.Value == nil { 672 | continue 673 | } 674 | tof := reflect.TypeOf(k1.Value).String() 675 | if tof == "httptest.Responses" { 676 | r[i1].Value.(Responses).Replace(strings.Join(keys[1:], "."), value, testType...) 677 | } else if tof == "[]httptest.Responses" { 678 | if len(keys) <= 1 { 679 | continue 680 | } 681 | key1, _ := strconv.Atoi(keys[1]) 682 | if r[i1].Value.([]Responses)[key1] != nil { 683 | r[i1].Value.([]Responses)[key1].Replace(strings.Join(keys[2:], "."), value, testType...) 
684 | } 685 | } 686 | } 687 | } 688 | 689 | // schema 690 | func schema(j map[string]any) (Responses, error) { 691 | objs := Responses{} 692 | if j == nil { 693 | return objs, nil 694 | } 695 | for k, v := range j { 696 | if k == "" { 697 | continue 698 | } 699 | obj := schemaResponse(k, v) 700 | objs = append(objs, obj) 701 | } 702 | return objs, nil 703 | } 704 | 705 | // schemaResponse 706 | func schemaResponse(k string, v any) Response { 707 | obj := Response{} 708 | obj.Key = k 709 | 710 | if v == nil { 711 | return obj 712 | } 713 | typeName := reflect.TypeOf(v).String() 714 | switch typeName { 715 | case "bool": 716 | obj.Value = v.(bool) 717 | case "string": 718 | if obj.Key == "createdAt" || obj.Key == "updatedAt" || obj.Key == "deletedAt" { 719 | obj.Type = "notempty" 720 | } else if NotEmptyKey.Len() > 0 && NotEmptyKey.Check(obj.Key) { 721 | obj.Type = "notempty" 722 | } else { 723 | obj.Value = v.(string) 724 | } 725 | case "uint": 726 | obj.Value = v.(uint) 727 | case "int": 728 | obj.Value = v.(int) 729 | case "int32": 730 | obj.Value = v.(int32) 731 | case "float64": 732 | obj.Value = v.(float64) 733 | case "[]string": 734 | obj.Value = v.([]string) 735 | case "map[string]interface {}": 736 | if value, _ := schema(v.(map[string]any)); value != nil { 737 | obj.Value = value 738 | } 739 | case "[]interface {}": 740 | list := []Responses{} 741 | for _, v1 := range v.([]any) { 742 | listObj := Responses{} 743 | if v3, ok := v1.(map[string]any); ok { 744 | for k2, v2 := range v3 { 745 | listObj = append(listObj, schemaResponse(k2, v2)) 746 | } 747 | list = append(list, listObj) 748 | obj.Value = list 749 | } else if _, ok := v1.(string); ok { 750 | obj.Value = v 751 | } 752 | } 753 | 754 | default: 755 | fmt.Printf("schemaResponse key:%s valueTypeName:%s\n", k, typeName) 756 | } 757 | return obj 758 | } 759 | -------------------------------------------------------------------------------- /httptest/respose_test.go: 
-------------------------------------------------------------------------------- 1 | package httptest 2 | 3 | import ( 4 | "encoding/json" 5 | "log" 6 | "net/http" 7 | "reflect" 8 | "testing" 9 | 10 | "github.com/gavv/httpexpect/v2" 11 | "github.com/gin-gonic/gin" 12 | "github.com/snowlyg/helper/arr" 13 | ) 14 | 15 | func TestIdKeys(t *testing.T) { 16 | want := Responses{ 17 | {Key: "id", Value: 0, Type: "ge"}, 18 | } 19 | t.Run("Test id keys", func(t *testing.T) { 20 | idKeys := IdKeys() 21 | if !reflect.DeepEqual(want, idKeys) { 22 | t.Errorf("IdKeys want %+v but get %+v", want, idKeys) 23 | } 24 | }) 25 | } 26 | 27 | func TestHttpTest(t *testing.T) { 28 | engine := gin.New() 29 | // Add /example route via handler function to the gin instance 30 | handler := GinHandler(engine) 31 | // Create httpexpect instance 32 | e := httpexpect.WithConfig(httpexpect.Config{ 33 | Client: &http.Client{ 34 | Transport: httpexpect.NewBinder(handler), 35 | Jar: httpexpect.NewCookieJar(), 36 | }, 37 | Reporter: httpexpect.NewAssertReporter(t), 38 | Printers: []httpexpect.Printer{ 39 | httpexpect.NewDebugPrinter(t, true), 40 | }, 41 | }) 42 | pageKeys := Responses{{Key: "message", Value: "OK"}, {Key: "status", Value: 200}, {Key: "data", Value: Response{Key: "message", Value: "pong"}}} 43 | value := e.GET("/example").Expect().Status(http.StatusOK).JSON() 44 | 45 | Test(value, pageKeys) 46 | } 47 | 48 | func TestHttpTestArray(t *testing.T) { 49 | engine := gin.New() 50 | // Add /example route via handler function to the gin instance 51 | handler := GinHandler(engine) 52 | // Create httpexpect instance 53 | e := httpexpect.WithConfig(httpexpect.Config{ 54 | Client: &http.Client{ 55 | Transport: httpexpect.NewBinder(handler), 56 | Jar: httpexpect.NewCookieJar(), 57 | }, 58 | Reporter: httpexpect.NewAssertReporter(t), 59 | Printers: []httpexpect.Printer{ 60 | httpexpect.NewDebugPrinter(t, true), 61 | }, 62 | }) 63 | pageKeys := []string{"1", "2"} 64 | value := 
e.GET("/array").Expect().Status(http.StatusOK).JSON() 65 | Test(value, pageKeys) 66 | } 67 | 68 | func TestHttpScan(t *testing.T) { 69 | engine := gin.New() 70 | // Add /example route via handler function to the gin instance 71 | handler := GinHandler(engine) 72 | // Create httpexpect instance 73 | e := httpexpect.WithConfig(httpexpect.Config{ 74 | Client: &http.Client{ 75 | Transport: httpexpect.NewBinder(handler), 76 | Jar: httpexpect.NewCookieJar(), 77 | }, 78 | Reporter: httpexpect.NewAssertReporter(t), 79 | Printers: []httpexpect.Printer{ 80 | httpexpect.NewDebugPrinter(t, true), 81 | }, 82 | }) 83 | pageKeys := Responses{{Key: "message", Value: ""}} 84 | obj := e.GET("/example").Expect().Status(http.StatusOK).JSON().Object() 85 | 86 | Scan(obj, pageKeys) 87 | x := pageKeys.GetString("data.message") 88 | if x != "pong" { 89 | t.Errorf("Scan want get pong but get %s", x) 90 | } 91 | } 92 | 93 | func TestSchema(t *testing.T) { 94 | wantJson := `{ 95 | "status": 200, 96 | "data": { 97 | "list": [ 98 | { 99 | "createdAt": "2025-03-21T16:27:20+08:00", 100 | "deletedAt": "", 101 | "updatedAt": "2025-03-21T16:27:20+08:00", 102 | "dev_remark": "", 103 | "pac_room_id": 1, 104 | "room_desc": "1413-301" 105 | } 106 | ], 107 | "total": 1, 108 | "page": 1, 109 | "pageSize": 20 110 | }, 111 | "message": "OK" 112 | }` 113 | res, err := Schema([]byte(wantJson)) 114 | if err != nil { 115 | t.Fatal(err.Error()) 116 | } 117 | if res.GetInt("status") != 200 { 118 | t.Errorf("status want %d but get %d", 200, res.GetInt("status")) 119 | } 120 | if res.GetString("message") != "OK" { 121 | t.Errorf("message want %s but get %s", "OK", res.GetString("message")) 122 | } 123 | data := res.GetResponse("data") 124 | if data != nil { 125 | if data.GetInt("total") != 1 { 126 | t.Errorf("total want %d but get %d", 1, data.GetInt("total")) 127 | } 128 | if data.GetInt("page") != 1 { 129 | t.Errorf("page want %d but get %d", 1, data.GetInt("page")) 130 | } 131 | if data.GetInt("pageSize") != 20 
{ 132 | t.Errorf("pageSize want %d but get %d", 20, data.GetInt("pageSize")) 133 | } 134 | list := data.GetResponses("list") 135 | if len(list) != 1 { 136 | t.Errorf("list len want %d but get %d", 1, len(list)) 137 | } 138 | first := list[0] 139 | if first.GetId("pac_room_id") != 1 { 140 | t.Errorf("pac_room_id want %d but get %d", 1, first.GetId("pac_room_id")) 141 | } 142 | roomDesc := first.GetString("room_desc") 143 | if roomDesc != "1413-301" { 144 | t.Errorf("room_desc want %s but get '%s'", "1413-301", roomDesc) 145 | } 146 | if first.GetString("dev_remark") != "" { 147 | t.Errorf("dev_remark want %s but get '%s'", "", first.GetString("dev_remark")) 148 | } 149 | keys := arr.NewCheckArrayType(0) 150 | for _, v := range first.Keys() { 151 | keys.Add(v) 152 | } 153 | 154 | for _, k := range []string{"createdAt", "deletedAt", "updatedAt", "dev_remark", "pac_room_id", "room_desc"} { 155 | if !keys.Check(k) { 156 | t.Errorf("%s not in keys", k) 157 | } 158 | } 159 | } 160 | } 161 | 162 | func TestSchemaResponse(t *testing.T) { 163 | data := `{ 164 | "status": 200, 165 | "message": "OK" 166 | }` 167 | j := map[string]any{} 168 | if err := json.Unmarshal([]byte(data), &j); err != nil { 169 | t.Error(err.Error()) 170 | } 171 | log.Printf("j %+v\n", j) 172 | wantKey := "data" 173 | resp := schemaResponse(wantKey, j) 174 | 175 | if value, ok := resp.Value.(Responses); !ok { 176 | t.Error("schema response return value not Responses") 177 | } else { 178 | keys := arr.NewCheckArrayType(0) 179 | for _, v := range value { 180 | keys.Add(v.Key) 181 | if v.Key == "message" { 182 | wantValue := "OK" 183 | if v.Value != wantValue { 184 | t.Errorf("%s Value want '%v' but get '%v'", v.Key, wantValue, v.Value) 185 | } 186 | } else if v.Key == "status" { 187 | var wantValue float64 = 200 188 | if v.Value != wantValue { 189 | t.Errorf("%s Value want '%v' but get '%v'", v.Key, wantValue, v.Value) 190 | } 191 | } else { 192 | t.Errorf("key %s is in response", v.Key) 193 | } 194 | } 
195 | if !keys.Check("message") { 196 | t.Error("message not in keys") 197 | } 198 | if !keys.Check("status") { 199 | t.Error("status not in keys") 200 | } 201 | } 202 | } 203 | -------------------------------------------------------------------------------- /loadtls.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/gin-gonic/gin" 7 | "github.com/unrolled/secure" 8 | ) 9 | 10 | // LoadTls 11 | func LoadTls() gin.HandlerFunc { 12 | return func(c *gin.Context) { 13 | middleware := secure.New(secure.Options{ 14 | SSLRedirect: true, 15 | SSLHost: "127.0.0.1:443", 16 | }) 17 | err := middleware.Process(c.Writer, c.Request) 18 | if err != nil { 19 | fmt.Println(err) 20 | return 21 | } 22 | c.Next() 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /menu.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import "gorm.io/gorm" 4 | 5 | type Menu struct { 6 | gorm.Model 7 | Path string `json:"path"` 8 | Component string `json:"component"` 9 | Redirect string `json:"redirect"` 10 | Hidden bool `json:"hidden"` 11 | AlwaysShow bool `json:"alwaysShow"` 12 | Meta 13 | Children []*Menu `json:"children" gorm:"-"` 14 | } 15 | 16 | func (m *Menu) TableName() string { 17 | return "menus" 18 | } 19 | 20 | type Meta struct { 21 | Roles []string `json:"roles" gorm:"-"` 22 | Title string `json:"title"` 23 | Icon string `json:"icon"` 24 | NoCache bool `json:"noCache"` 25 | } 26 | -------------------------------------------------------------------------------- /migrate.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/go-gormigrate/gormigrate/v2" 7 | "gorm.io/gorm" 8 | ) 9 | 10 | // AddMigration add *gormigrate.Migration 11 | func (ws *WebServe) AddMigration(m 
...*gormigrate.Migration) { 12 | ws.items = append(ws.items, m...) 13 | } 14 | 15 | // MigrationLen length of MigrationCollection 16 | func (ws *WebServe) MigrationLen() int { 17 | return len(ws.items) 18 | } 19 | 20 | // Refresh refresh migration 21 | func (mc *WebServe) Refresh() error { 22 | if mc.getFirstMigration() == "" { 23 | return nil 24 | } 25 | err := mc.rollbackTo(mc.getFirstMigration()) 26 | if !errors.Is(err, gormigrate.ErrMigrationIDDoesNotExist) && err != nil { 27 | return err 28 | } 29 | return mc.Migrate() 30 | } 31 | 32 | // rollbackTo roolback migration to migrationId 33 | func (ws *WebServe) rollbackTo(migrationId string) error { 34 | return ws.m.RollbackTo(migrationId) 35 | } 36 | 37 | // Rollback roolback migrations 38 | func (ws *WebServe) Rollback(migrationId string) error { 39 | if ws.MigrationLen() == 0 { 40 | return nil 41 | } 42 | if migrationId == "" { 43 | err := ws.rollbackLast() 44 | if !errors.Is(err, gormigrate.ErrMigrationIDDoesNotExist) && err != nil { 45 | return err 46 | } 47 | return nil 48 | } 49 | err := ws.rollbackTo(migrationId) 50 | if !errors.Is(err, gormigrate.ErrMigrationIDDoesNotExist) && err != nil { 51 | return err 52 | } 53 | return nil 54 | } 55 | 56 | // rollbackLast roolback the lasted migration 57 | func (ws *WebServe) rollbackLast() error { 58 | return ws.m.RollbackLast() 59 | } 60 | 61 | // Migrate exec migration cmd 62 | func (ws *WebServe) Migrate() error { 63 | // add migrations 64 | ws.AddMigration( 65 | &gormigrate.Migration{ 66 | ID: "init_system", 67 | Migrate: func(tx *gorm.DB) error { 68 | return tx.AutoMigrate(&Router{}, &Menu{}) 69 | }, 70 | Rollback: func(tx *gorm.DB) error { 71 | return tx.Migrator().DropTable(new(Router).TableName(), new(Menu).TableName()) 72 | }, 73 | }, 74 | // add more migrations 75 | ) 76 | if ws.m == nil { 77 | ws.m = gormigrate.New(ws.db, gormigrate.DefaultOptions, ws.items) 78 | } 79 | if err := ws.m.Migrate(); err != nil { 80 | return err 81 | } 82 | return nil 83 | } 
84 | 85 | // getFirstMigration get first migration's id 86 | func (ws *WebServe) getFirstMigration() string { 87 | if ws.MigrationLen() == 0 { 88 | return "" 89 | } 90 | return ws.items[0].ID 91 | } 92 | -------------------------------------------------------------------------------- /migrate_test.go: -------------------------------------------------------------------------------- 1 | package admin 2 | -------------------------------------------------------------------------------- /model.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "time" 7 | 8 | "github.com/gin-gonic/gin" 9 | "github.com/snowlyg/helper/str" 10 | "gorm.io/gorm" 11 | ) 12 | 13 | type ErrMsg struct { 14 | Code int64 `json:"code"` 15 | Msg string `json:"message"` 16 | } 17 | 18 | var ( 19 | ErrParamValidate = errors.New("param unvalidate") 20 | ErrPaginateParam = errors.New("paginate param unvalidate") 21 | ErrUnSupportFramework = errors.New("unsupport framework") 22 | ) 23 | 24 | // Model 25 | type Model struct { 26 | Id uint `json:"id"` 27 | UpdatedAt string `json:"updatedAt"` 28 | CreatedAt string `json:"createdAt"` 29 | DeletedAt string `json:"deletedAt"` 30 | } 31 | 32 | // Paginate param for paginate query 33 | type Paginate struct { 34 | Page int `json:"page" form:"page"` 35 | PageSize int `json:"pageSize" form:"pageSize"` 36 | OrderBy string `json:"orderBy" form:"orderBy"` 37 | Sort string `json:"sort" form:"sort"` 38 | } 39 | 40 | func (req *Paginate) Request(ctx *gin.Context) error { 41 | if err := ctx.ShouldBind(req); err != nil { 42 | return ErrParamValidate 43 | } 44 | return nil 45 | } 46 | 47 | // PaginateScope paginate scope 48 | func (req *Paginate) PaginateScope() func(db *gorm.DB) *gorm.DB { 49 | return PaginateScope(req.Page, req.PageSize, req.Sort, req.OrderBy) 50 | } 51 | 52 | // IdScope 53 | func IdScope(id any) func(db *gorm.DB) *gorm.DB { 54 | return func(db *gorm.DB) 
*gorm.DB { 55 | return db.Where("id = ?", id) 56 | } 57 | } 58 | 59 | // InIdsScope 60 | func InIdsScope(ids []uint) func(db *gorm.DB) *gorm.DB { 61 | return func(db *gorm.DB) *gorm.DB { 62 | return db.Where("id in ?", ids) 63 | } 64 | } 65 | 66 | // InNamesScope 67 | func InNamesScope(names []string) func(db *gorm.DB) *gorm.DB { 68 | return func(db *gorm.DB) *gorm.DB { 69 | return db.Where("name in ?", names) 70 | } 71 | } 72 | 73 | // InUuidsScope 74 | func InUuidsScope(uuids []string) func(db *gorm.DB) *gorm.DB { 75 | return func(db *gorm.DB) *gorm.DB { 76 | return db.Where("uuid in ?", uuids) 77 | } 78 | } 79 | 80 | // NeIdScope 81 | func NeIdScope(id any) func(db *gorm.DB) *gorm.DB { 82 | return func(db *gorm.DB) *gorm.DB { 83 | return db.Where("id != ?", id) 84 | } 85 | } 86 | 87 | // PaginateScope return paginate scope for gorm 88 | func PaginateScope(page, pageSize int, sort, orderBy string) func(db *gorm.DB) *gorm.DB { 89 | return func(db *gorm.DB) *gorm.DB { 90 | pageSize := getPageSize(pageSize) 91 | offset := getOffset(page, pageSize) 92 | return db.Order(getOrderBy(sort, orderBy)).Offset(offset).Limit(pageSize) 93 | } 94 | } 95 | 96 | // getOffset 97 | func getOffset(page, pageSize int) int { 98 | if page == 0 { 99 | page = 1 100 | } 101 | offset := (page - 1) * pageSize 102 | if page < 0 { 103 | offset = -1 104 | } 105 | return offset 106 | } 107 | 108 | // getPageSize 109 | func getPageSize(pageSize int) int { 110 | switch { 111 | case pageSize > 100: 112 | pageSize = 100 113 | case pageSize < 0: 114 | pageSize = -1 115 | case pageSize == 0: 116 | pageSize = 10 117 | } 118 | return pageSize 119 | } 120 | 121 | // getOrderBy 122 | func getOrderBy(sort, orderBy string) string { 123 | if sort == "" { 124 | sort = "desc" 125 | } 126 | if orderBy == "" { 127 | orderBy = "created_at" 128 | } 129 | return str.Join(orderBy, " ", sort) 130 | } 131 | 132 | const ( 133 | ResponseOkMessage = "OK" 134 | ResponseErrorMessage = "FAIL" 135 | ) 136 | 137 | type 
Response struct { 138 | Code int `json:"status"` 139 | Data any `json:"data,omitempty"` 140 | Msg string `json:"message"` 141 | } 142 | 143 | func Result(code int, data any, msg string, ctx *gin.Context) { 144 | ctx.JSON(http.StatusOK, Response{code, data, msg}) 145 | } 146 | 147 | func Ok(ctx *gin.Context) { 148 | Result(http.StatusOK, map[string]any{}, ResponseOkMessage, ctx) 149 | } 150 | 151 | func OkWithMessage(message string, ctx *gin.Context) { 152 | Result(http.StatusOK, map[string]any{}, message, ctx) 153 | } 154 | 155 | func OkWithData(data any, ctx *gin.Context) { 156 | Result(http.StatusOK, data, ResponseOkMessage, ctx) 157 | } 158 | 159 | func OkWithDetailed(data any, message string, ctx *gin.Context) { 160 | Result(http.StatusOK, data, message, ctx) 161 | } 162 | 163 | func Fail(ctx *gin.Context) { 164 | Result(http.StatusBadRequest, map[string]any{}, ResponseErrorMessage, ctx) 165 | } 166 | 167 | func UnauthorizedFailWithMessage(message string, ctx *gin.Context) { 168 | Result(http.StatusUnauthorized, map[string]any{}, message, ctx) 169 | } 170 | 171 | func UnauthorizedFailWithDetailed(data any, message string, ctx *gin.Context) { 172 | Result(http.StatusUnauthorized, data, message, ctx) 173 | } 174 | 175 | func ForbiddenFailWithMessage(message string, ctx *gin.Context) { 176 | Result(http.StatusForbidden, map[string]any{}, message, ctx) 177 | } 178 | 179 | func FailWithMessage(message string, ctx *gin.Context) { 180 | Result(http.StatusBadRequest, map[string]any{}, message, ctx) 181 | } 182 | 183 | func FailWithDetailed(data any, message string, ctx *gin.Context) { 184 | Result(http.StatusBadRequest, data, message, ctx) 185 | } 186 | 187 | type PageResult struct { 188 | List any `json:"list,omitempty"` 189 | Total int64 `json:"total"` 190 | Page int `json:"page"` 191 | PageSize int `json:"pageSize"` 192 | } 193 | 194 | type BaseResponse struct { 195 | Id uint `json:"id"` 196 | CreatedAt *time.Time `json:"createdAt"` 197 | UpdatedAt *time.Time 
`json:"updatedAt"` 198 | } 199 | 200 | // Paging common input parameter structure 201 | type PageInfo struct { 202 | Page int `json:"page" form:"page" validate:"required"` 203 | PageSize int `json:"pageSize" form:"pageSize" validate:"required"` 204 | OrderBy string `json:"orderBy" form:"orderBy"` 205 | SortBy string `json:"sortBy" form:"sortBy"` 206 | } 207 | 208 | type IdsBinding struct { 209 | Ids []uint `json:"ids" form:"ids" validate:"required,dive,required"` 210 | } 211 | 212 | // Request get id data form the context of every query 213 | func (req *IdsBinding) Request(ctx *gin.Context) error { 214 | if err := ctx.ShouldBind(req); err != nil { 215 | return ErrParamValidate 216 | } 217 | return nil 218 | } 219 | 220 | // Id the struct has used to get id form the context of every query 221 | type Id struct { 222 | Id uint `json:"id" uri:"id"` 223 | } 224 | 225 | // Request get id data form the context of every query 226 | func (req *Id) Request(ctx *gin.Context) error { 227 | if err := ctx.ShouldBindUri(req); err != nil { 228 | return ErrParamValidate 229 | } 230 | return nil 231 | } 232 | 233 | type Empty struct{} 234 | -------------------------------------------------------------------------------- /router.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | "path/filepath" 7 | "strings" 8 | 9 | "github.com/snowlyg/helper/arr" 10 | "gorm.io/gorm" 11 | ) 12 | 13 | type Router struct { 14 | gorm.Model 15 | Path string `json:"path"` 16 | Title string `json:"title"` 17 | Group string `json:"group"` 18 | Method string `json:"method"` 19 | Children []*Router `json:"children" gorm:"-"` 20 | } 21 | 22 | func (m *Router) TableName() string { 23 | return "routers" 24 | } 25 | 26 | func (ws *WebServe) routers() { 27 | methodExcepts := strings.Split(ws.conf.Except.Method, ";") 28 | uriExcepts := strings.Split(ws.conf.Except.Uri, ";") 29 | 30 | // routeLen := 
len(ws.engine.Routes()) 31 | // permRoutes := make([]*Router, 0, routeLen) 32 | // otherMethodTypes := make([]*Router, 0, routeLen) 33 | 34 | for _, r := range ws.engine.Routes() { 35 | // log.Printf("handler:%s, method:%s, path:%s\n", r.Handler, r.Method, r.Path) 36 | if strings.Contains(r.Path, "/*filepath") || r.Handler == "github.com/gin-gonic/gin.(*RouterGroup).createStaticHandler.func1" { 37 | continue 38 | } 39 | path := filepath.ToSlash(filepath.Clean(r.Path)) 40 | route := &Router{ 41 | Path: path, 42 | Title: path, 43 | Group: "", 44 | Method: r.Method, 45 | } 46 | 47 | httpStatusType := arr.NewCheckArrayType(4) 48 | httpStatusType.AddMutil(http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete) 49 | if !httpStatusType.Check(r.Method) { 50 | ws.otherRoutes = append(ws.otherRoutes, route) 51 | continue 52 | } 53 | 54 | if len(methodExcepts) > 0 && len(uriExcepts) > 0 && len(methodExcepts) == len(uriExcepts) { 55 | for i := range methodExcepts { 56 | if strings.EqualFold(r.Method, strings.ToLower(methodExcepts[i])) && strings.EqualFold(path, strings.ToLower(uriExcepts[i])) { 57 | ws.otherRoutes = append(ws.otherRoutes, route) 58 | continue 59 | } 60 | } 61 | } 62 | ws.permRoutes = append(ws.permRoutes, route) 63 | } 64 | 65 | // log.Printf("permRoutes:%d other:%d\n", len(ws.permRoutes), len(ws.otherRoutes)) 66 | 67 | if ws.db == nil { 68 | return 69 | } 70 | 71 | if len(ws.permRoutes) == 0 { 72 | return 73 | } 74 | 75 | // seed routers 76 | olds := []*Router{} 77 | dels := []uint{} 78 | adds := []*Router{} 79 | if err := ws.db.Model(&Router{}).Find(&olds).Error; err != nil { 80 | log.Printf("iris-admin: old router find get err:%s\n", err.Error()) 81 | } 82 | 83 | if len(olds) == 0 { 84 | if err := ws.db.Create(&ws.permRoutes).Error; err == nil { 85 | log.Printf("iris-admin: add %d router \n", len(ws.permRoutes)) 86 | } 87 | return 88 | } 89 | 90 | oldCheck := arr.NewCheckArrayType(len(olds)) 91 | for _, old := range olds { 92 | 
oldCheck.Add(old.Path) 93 | found := false 94 | for _, a := range ws.permRoutes { 95 | if old.Path == a.Path && old.Method == a.Method { 96 | found = true 97 | break 98 | } 99 | } 100 | if !found { 101 | dels = append(dels, old.ID) 102 | } 103 | } 104 | 105 | if len(dels) > 0 { 106 | if err := ws.db.Delete(&Router{}, dels).Error; err == nil { 107 | log.Printf("iris-admin: delete %d router\n", len(dels)) 108 | } 109 | } 110 | 111 | for _, r := range ws.permRoutes { 112 | if !oldCheck.Check(r.Path) { 113 | adds = append(adds, r) 114 | } 115 | } 116 | 117 | if len(adds) > 0 { 118 | if err := ws.db.Create(&adds).Error; err == nil { 119 | log.Printf("iris-admin: add %d router,old:%d\n", len(adds), len(olds)) 120 | } 121 | } 122 | 123 | } 124 | -------------------------------------------------------------------------------- /run.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | type serve interface { 4 | ListenAndServe() error 5 | } 6 | -------------------------------------------------------------------------------- /run_darwin.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "fmt" 5 | "syscall" 6 | "time" 7 | 8 | "github.com/fvbock/endless" 9 | "github.com/gin-gonic/gin" 10 | ) 11 | 12 | func run(address string, router *gin.Engine) serve { 13 | s := endless.NewServer(address, router) 14 | s.BeforeBegin = func(add string) { 15 | fmt.Printf("Actual pid is %d\n", syscall.Getpid()) 16 | } 17 | s.ReadHeaderTimeout = 10 * time.Millisecond 18 | s.WriteTimeout = 10 * time.Second 19 | s.MaxHeaderBytes = 1 << 20 20 | return s 21 | } 22 | -------------------------------------------------------------------------------- /run_linux.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/fvbock/endless" 7 | "github.com/gin-gonic/gin" 8 | ) 9 | 10 | func 
run(address string, router *gin.Engine) serve { 11 | s := endless.NewServer(address, router) 12 | s.ReadHeaderTimeout = 10 * time.Millisecond 13 | s.WriteTimeout = 10 * time.Second 14 | s.MaxHeaderBytes = 1 << 20 15 | return s 16 | } 17 | -------------------------------------------------------------------------------- /run_windows.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "net/http" 5 | "time" 6 | 7 | "github.com/gin-gonic/gin" 8 | ) 9 | 10 | func run(address string, router *gin.Engine) serve { 11 | return &http.Server{ 12 | Addr: address, 13 | Handler: router, 14 | ReadTimeout: 10 * time.Second, 15 | WriteTimeout: 10 * time.Second, 16 | MaxHeaderBytes: 1 << 20, 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "time" 7 | 8 | limit "github.com/aviddiviner/gin-limit" 9 | "github.com/casbin/casbin/v2" 10 | "github.com/gin-gonic/gin" 11 | "github.com/go-gormigrate/gormigrate/v2" 12 | "github.com/mattn/go-colorable" 13 | "github.com/snowlyg/iris-admin/conf" 14 | "github.com/snowlyg/iris-admin/e" 15 | 16 | "gorm.io/driver/mysql" 17 | "gorm.io/gorm" 18 | ) 19 | 20 | // Status 21 | const ( 22 | StatusUnknown int = iota 23 | StatusTrue 24 | StatusFalse 25 | ) 26 | 27 | type WebServe struct { 28 | serve 29 | conf *conf.Conf 30 | db *gorm.DB 31 | enforcer *casbin.Enforcer 32 | engine *gin.Engine 33 | iroutes *gin.IRoutes 34 | 35 | validate *Validator 36 | 37 | m *gormigrate.Gormigrate 38 | items []*gormigrate.Migration 39 | 40 | permRoutes []*Router 41 | otherRoutes []*Router 42 | } 43 | 44 | // gormDb 45 | func gormDb(m *conf.Mysql) (*gorm.DB, error) { 46 | if m == nil { 47 | return nil, e.ErrConfigInvalid 48 | } 49 | if m.DbName == "" { 50 | return nil, e.ErrDbTableNameEmpty 51 | } 52 | 
mysqlConfig := mysql.Config{ 53 | DSN: m.Dsn(), 54 | DefaultStringSize: 191, 55 | // DisableDatetimePrecision: true, 56 | // DontSupportRenameIndex: true, 57 | // DontSupportRenameColumn: true, 58 | // SkipInitializeWithVersion: false, 59 | } 60 | if db, err := gorm.Open(mysql.New(mysqlConfig)); err != nil { 61 | fmt.Printf("open mysql[%s] is fail:%v\n", m.Dsn(), err) 62 | return nil, err 63 | } else { 64 | sqlDB, err := db.DB() 65 | if err != nil { 66 | return nil, err 67 | } 68 | if err := sqlDB.Ping(); err != nil { 69 | log.Printf("ping mysql[%s] is fail:%v\n", m.Dsn(), err) 70 | return nil, err 71 | } 72 | sqlDB.SetMaxIdleConns(m.MaxIdleConns) 73 | sqlDB.SetMaxOpenConns(m.MaxOpenConns) 74 | return db, nil 75 | } 76 | } 77 | 78 | // NewServe 79 | func NewServe(c *conf.Conf) (*WebServe, error) { 80 | 81 | gin.SetMode(c.System.GinMode) 82 | app := gin.Default() 83 | if c.System.Tls { 84 | app.Use(LoadTls()) 85 | } 86 | app.Use(c.CorsConf.Cors()) 87 | // registerValidation() 88 | gin.DefaultWriter = colorable.NewColorableStdout() 89 | c.SetDefaultAddrAndTimeFormat() 90 | db, err := gormDb(c.Mysql) 91 | if err != nil { 92 | return nil, err 93 | } 94 | 95 | auth, err := c.GetEnforcer(db) 96 | if err != nil { 97 | return nil, err 98 | } 99 | 100 | ws := &WebServe{ 101 | conf: c, 102 | engine: app, 103 | enforcer: auth, 104 | db: db, 105 | permRoutes: []*Router{}, 106 | otherRoutes: []*Router{}, 107 | } 108 | if err := ws.Migrate(); err != nil { 109 | return nil, err 110 | } 111 | 112 | switch c.Locale { 113 | case "en": 114 | ws.validate = newEn() 115 | case "zh": 116 | ws.validate = newZh() 117 | default: 118 | ws.validate = newZh() 119 | } 120 | 121 | ws.engine.Use(limit.MaxAllowed(50)) 122 | 123 | return ws, nil 124 | } 125 | 126 | // Engine return *gin.Engine 127 | func (ws *WebServe) Engine() *gin.Engine { 128 | return ws.engine 129 | } 130 | 131 | func (ws *WebServe) IRoutes() *gin.IRoutes { 132 | return ws.iroutes 133 | } 134 | 135 | // Config 136 | func (ws 
*WebServe) Config() *conf.Conf { 137 | return ws.conf 138 | } 139 | 140 | // SystemAddr 141 | func (ws *WebServe) SystemAddr() string { 142 | return ws.conf.System.Addr 143 | } 144 | 145 | // Auth 146 | func (ws *WebServe) Auth() *casbin.Enforcer { 147 | return ws.enforcer 148 | } 149 | 150 | // DB 151 | func (ws *WebServe) DB() *gorm.DB { 152 | return ws.db 153 | } 154 | 155 | // // Deprecated: use nginx or apache instead. 156 | // func (ws *WebServe) AddWebStatic(staticAbsPath, webPrefix string, paths ...string) { 157 | // } 158 | 159 | // // Deprecated: use nginx or apache instead. 160 | // func (ws *WebServe) AddUploadStatic(webPrefix, staticAbsPath string) { 161 | // } 162 | 163 | // Run 164 | func (ws *WebServe) Run() { 165 | if ws.engine == nil { 166 | panic("init engine please") 167 | } 168 | 169 | // ws.Engine().NoRoute(func(ctx *gin.Context) { 170 | // // excepte for /v0 /v1 and so on 171 | // reg := `^/v[0-9]+$|^(/v[0-9]+)/` 172 | // ok, _ := regexp.MatchString(reg, ctx.Request.RequestURI) 173 | // if ok { 174 | // ctx.Writer.WriteHeader(http.StatusNotFound) 175 | // ctx.Writer.Flush() 176 | // return 177 | // } 178 | 179 | // var indexFile []byte 180 | // for _, wp := range ws.statics { 181 | // // match /admin or /admin/*** 182 | // reg := str.Join("^", wp.Prefix, "$|^(", wp.Prefix, ")/") 183 | // ok, err := regexp.MatchString(reg, ctx.Request.RequestURI) 184 | // if err != nil || !ok { 185 | // continue 186 | // } 187 | // indexFile = wp.IndexFile 188 | // } 189 | 190 | // ctx.Writer.WriteHeader(http.StatusOK) 191 | // ctx.Writer.Write(indexFile) 192 | 193 | // ctx.Writer.Header().Add("Accept", "text/html") 194 | // ctx.Writer.Flush() 195 | // }) 196 | 197 | ws.routers() 198 | 199 | systemAddr := ws.SystemAddr() 200 | s := run(systemAddr, ws.engine) 201 | time.Sleep(10 * time.Microsecond) 202 | 203 | log.Printf("listen on: http://%s\n", systemAddr) 204 | 205 | s.ListenAndServe() 206 | } 207 | 
-------------------------------------------------------------------------------- /server_test.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "io" 5 | "net/http" 6 | "os" 7 | "testing" 8 | "time" 9 | 10 | "github.com/snowlyg/iris-admin/conf" 11 | ) 12 | 13 | func TestStart(t *testing.T) { 14 | go func() { 15 | os.Setenv("IRIS_ADMIN_WEB_ADDR", "127.0.0.1:18088") 16 | c := conf.NewConf() 17 | if serve, err := NewServe(c); err != nil { 18 | t.Error(err.Error()) 19 | } else { 20 | serve.Engine() 21 | serve.Run() 22 | } 23 | }() 24 | 25 | time.Sleep(3 * time.Second) 26 | 27 | resp, err := http.Get("http://127.0.0.1:18088") 28 | if err != nil { 29 | t.Errorf("test web start get %v", err) 30 | } 31 | defer resp.Body.Close() 32 | 33 | _, err = io.ReadAll(resp.Body) 34 | if err != nil { 35 | t.Errorf("test web start get %v", err) 36 | } 37 | 38 | if resp.StatusCode != http.StatusNotFound { 39 | t.Errorf("test web start want [%d] but get [%d]", http.StatusNotFound, resp.StatusCode) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /validate.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "github.com/gin-gonic/gin" 8 | "github.com/go-playground/locales/en" 9 | "github.com/go-playground/locales/zh" 10 | ut "github.com/go-playground/universal-translator" 11 | "github.com/go-playground/validator/v10" 12 | en_translations "github.com/go-playground/validator/v10/translations/en" 13 | zh_translations "github.com/go-playground/validator/v10/translations/zh" 14 | ) 15 | 16 | type Validator struct { 17 | uni *ut.UniversalTranslator 18 | validate *validator.Validate 19 | trans ut.Translator 20 | } 21 | 22 | func newZh() *Validator { 23 | zh := zh.New() 24 | uni := ut.New(zh, zh) 25 | 26 | // this is usually know or extracted from http 'Accept-Language' header 
27 | // also see uni.FindTranslator(...) 28 | trans, _ := uni.GetTranslator("zh") 29 | 30 | validate := validator.New() 31 | zh_translations.RegisterDefaultTranslations(validate, trans) 32 | return &Validator{uni: uni, validate: validate, trans: trans} 33 | } 34 | 35 | func newEn() *Validator { 36 | en := en.New() 37 | uni := ut.New(en, en) 38 | 39 | // this is usually know or extracted from http 'Accept-Language' header 40 | // also see uni.FindTranslator(...) 41 | trans, _ := uni.GetTranslator("en") 42 | 43 | validate := validator.New() 44 | en_translations.RegisterDefaultTranslations(validate, trans) 45 | return &Validator{uni: uni, validate: validate, trans: trans} 46 | } 47 | 48 | type IdBinding struct { 49 | Id uint `json:"id" uri:"id" validate:"required"` 50 | } 51 | 52 | // ShouldBindUri binds the passed struct pointer using the specified binding engine. 53 | func (val *Validator) ShouldBindUri(ctx *gin.Context) (uint, error) { 54 | var id IdBinding 55 | if e := ctx.ShouldBindUri(&id); e != nil { 56 | return 0, e 57 | } 58 | if e := val.Translate(id); e != nil { 59 | return 0, e 60 | } 61 | return id.Id, nil 62 | } 63 | 64 | func (val *Validator) Translate(s any) error { 65 | if err := val.validate.Struct(s); err != nil { 66 | // translate all error at once 67 | errs := err.(validator.ValidationErrors) 68 | for _, v := range errs.Translate(val.trans) { 69 | if v != "" { 70 | return errors.New(v) 71 | } 72 | } 73 | } 74 | return nil 75 | } 76 | 77 | func (val *Validator) ValidateMap(data map[string]any, rules map[string]any) error { 78 | mapErrs := val.validate.ValidateMap(data, rules) 79 | for mk, err := range mapErrs { 80 | if err != nil { 81 | // translate all error at once 82 | errs := err.(validator.ValidationErrors) 83 | for _, v := range errs.Translate(val.trans) { 84 | if v != "" { 85 | return fmt.Errorf("%s%s", mk, v) 86 | } 87 | } 88 | } 89 | } 90 | return nil 91 | } 92 | 
--------------------------------------------------------------------------------