├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── README_EN.md
├── auth2
├── claims.go
├── claims_test.go
├── example
│ └── main.go
├── extractor.go
├── interface.go
├── jwt.go
├── jwt_test.go
├── local.go
├── local_test.go
├── redis.go
├── redis_test.go
├── token.go
├── token_test.go
└── verifier.go
├── conf
├── auth.go
├── auth_test.go
├── config.go
├── config_test.go
├── cros.go
├── mysql.go
├── mysql_test.go
├── operate.go
├── viper.go
└── viper_test.go
├── e
└── error.go
├── example
├── main.go
├── public
│ └── index.html
└── readme.md
├── go.mod
├── go.sum
├── httptest
├── base.go
├── base_test.go
├── common.go
├── printer.go
├── reporter.go
├── respose.go
└── respose_test.go
├── loadtls.go
├── menu.go
├── migrate.go
├── migrate_test.go
├── model.go
├── router.go
├── run.go
├── run_darwin.go
├── run_linux.go
├── run_windows.go
├── server.go
├── server_test.go
└── validate.go
/.gitignore:
--------------------------------------------------------------------------------
1 | # ext
2 | *.dll
3 | *.db
4 | *.exe
5 | *.log
6 | *.out
7 | *.txt
8 | *.yaml
9 | *.jpg
10 | *.json
11 |
12 | # file
13 | rbac_model.conf
14 |
15 |
16 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: go
2 |
3 | go:
4 | - 1.18.x
5 |
6 | before_script:
7 | - sudo redis-server /etc/redis/redis.conf --port 6379 --requirepass 'secret'
8 | - mkdir -p data/db
9 | - mongod --dbpath=data/db &
10 | - sleep 5
11 | - mongo mongo_test --eval 'db.createUser({user:"travis",pwd:"test",roles:["readWrite"]});'
12 |
13 | services:
14 | - mysql
15 |
16 | env:
17 | - GO111MODULE=on redisPwd=secret mongoAddr='travis:test@127.0.0.1:27017/mongo_test'
18 |
19 | before_install:
20 | - go get -v -t ./...
21 |
22 | script:
23 | - go test -v -race -coverprofile='coverage.txt' -covermode=atomic github.com/snowlyg/iris-admin/...
24 |
25 | after_success:
26 | - bash <(curl -s https://codecov.io/bash)
27 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | IrisAdmin
2 |
3 | [Build Status](https://app.travis-ci.com/snowlyg/iris-admin)
4 | [License](https://github.com/snowlyg/iris-admin/blob/master/LICENSE)
5 | [GoDoc](https://godoc.org/github.com/snowlyg/iris-admin)
6 | [Go Report Card](https://goreportcard.com/badge/github.com/snowlyg/iris-admin)
7 | [codecov](https://codecov.io/gh/snowlyg/iris-admin)
8 |
9 | 简体中文 | [English](./README_EN.md)
10 |
11 | #### 项目地址
12 |
13 | [GITHUB](https://github.com/snowlyg/iris-admin)
14 |
15 | > 简单项目仅供学习,欢迎指点!
16 |
17 | #### 相关文档
18 |
19 | - [IRIS-ADMIN-DOC](https://doc.snowlyg.com)
20 | - [IRIS V12 中文文档](https://github.com/snowlyg/iris/wiki)
21 | - [godoc](https://pkg.go.dev/github.com/snowlyg/iris-admin?utm_source=godoc)
22 |
23 |
24 |
25 |
26 | #### iris 学习记录分享
27 |
28 | - [Iris-go 项目登录 API 构建细节实现过程](https://snowlyg.github.io/iris-go-api-1/)
29 |
30 | - [iris + casbin 从陌生到学会使用的过程](https://snowlyg.github.io/iris-go-api-2/)
31 |
32 | ---
33 |
34 | #### 简单使用
35 |
36 | - 获取依赖包,注意必须带上 `master` 版本
37 |
38 | ```sh
39 | go get github.com/snowlyg/iris-admin@master
40 | ```
41 |
42 | #### 打赏
43 |
44 | > 您的打赏将用于支付网站运行费用,我们会在项目介绍中特别鸣谢您
45 | - [爱发电](https://afdian.net/@snowlyg/plan)
46 | - [donating](https://paypal.me/snowlyg?country.x=C2&locale.x=zh_XC)
47 |
--------------------------------------------------------------------------------
/README_EN.md:
--------------------------------------------------------------------------------
1 | IrisAdmin
2 |
3 | [Build Status](https://app.travis-ci.com/snowlyg/iris-admin)
4 | [License](https://github.com/snowlyg/iris-admin/blob/master/LICENSE)
5 | [GoDoc](https://godoc.org/github.com/snowlyg/iris-admin)
6 | [Go Report Card](https://goreportcard.com/badge/github.com/snowlyg/iris-admin)
7 | [codecov](https://codecov.io/gh/snowlyg/iris-admin)
8 |
9 | [简体中文](./README.md) | English
10 |
11 | #### Project url
12 |
13 | [GITHUB](https://github.com/snowlyg/iris-admin) | [GITEE](https://gitee.com/snowlyg/iris-admin)
14 |
15 | ---
16 |
17 | > This project is just for learning Go; suggestions are welcome!
18 |
19 | #### Documentation
20 |
21 | - [IRIS-ADMIN-DOC](https://doc.snowlyg.com)
22 | - [IRIS V12 documentation (Chinese)](https://github.com/snowlyg/iris/wiki)
23 | - [godoc](https://pkg.go.dev/github.com/snowlyg/iris-admin?utm_source=godoc)
24 |
25 | [Gitter: iris-go-tenancy community](https://gitter.im/iris-go-tenancy/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) [Gitter: iris-admin](https://gitter.im/iris-go-tenancy/iris-admin?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
26 |
27 | #### BLOG
28 |
29 | - [REST API with iris-go web framework](https://snowlyg.github.io/iris-go-api-1/)
30 |
31 | - [How to use iris-go with casbin](https://snowlyg.github.io/iris-go-api-2/)
32 |
33 | ---
34 |
35 | #### Getting started
36 |
37 | - Get the package. Note: you must use the `master` version.
38 |
39 | ```sh
40 | go get github.com/snowlyg/iris-admin@master
41 | ```
42 |
43 |
44 | ## ☕️ Buy me a coffee
45 |
46 | > When you donate through one of the channels below, please leave your name, GitHub account or another social media account so that I can add you to the list of donors as a token of my appreciation.
47 | - [爱发电](https://afdian.net/@snowlyg/plan)
48 | - [donating](https://paypal.me/snowlyg?country.x=C2&locale.x=zh_XC)
49 |
--------------------------------------------------------------------------------
/auth2/claims.go:
--------------------------------------------------------------------------------
1 | package auth2
2 |
3 | import (
4 | "fmt"
5 | "strconv"
6 | "strings"
7 | "time"
8 |
9 | "github.com/golang-jwt/jwt"
10 | )
11 |
// Validation* are bit flags placed in jwt.ValidationError.Errors. Several may
// be set at once, so test them with a bitwise AND. Claims.Valid in this file
// sets ValidationExpired and the Validation{Id,Username,AuthId,RoleType,
// LoginType,AuthType} flags. The first three flags are not set by Claims.Valid.
// NOTE(review): these values are defined independently of jwt's own
// ValidationError* constants; check for overlapping bits before testing
// jwt-defined flags on errors produced by Claims.Valid.
const (
	ValidationMalformed        uint32 = 1 << 0 // Token is malformed
	ValidationUnverifiable     uint32 = 1 << 1 // Token could not be verified because of signing problems
	ValidationSignatureInvalid uint32 = 1 << 2 // Signature validation failed
	ValidationExpired          uint32 = 1 << 3 // EXP validation failed
	ValidationId               uint32 = 1 << 4 // Id is missing or not a positive integer
	ValidationUsername         uint32 = 1 << 5 // Username is empty
	ValidationAuthId           uint32 = 1 << 6 // AuthId is empty
	ValidationRoleType         uint32 = 1 << 7 // RoleType is not positive
	ValidationLoginType        uint32 = 1 << 8 // LoginType is out of range
	ValidationAuthType         uint32 = 1 << 9 // AuthType is out of range
)
24 |
// Claims is the token payload describing an authenticated agent.
// The json tags drive JSON encoding. The redis tags are presumably consumed by
// the redis-backed store (not shown here); confirm before renaming. Field order
// and tag names are part of the encoded format.
type Claims struct {
	// Id is the decimal string form of the agent's numeric id (see NewClaims).
	Id         string `json:"id,omitempty" redis:"id"`
	SuperAdmin bool   `json:"superAdmin,omitempty" redis:"super_admin"`
	Username   string `json:"username,omitempty" redis:"username"`
	// AuthId holds the agent's AuthIds joined with "-" (see NewClaims).
	AuthId    string `json:"authId,omitempty" redis:"auth_id"`
	RoleType  int    `json:"roleType,omitempty" redis:"role_type"`
	LoginType int    `json:"loginType,omitempty" redis:"login_type"`
	AuthType  int    `json:"authType,omitempty" redis:"auth_type"`
	// CreationTime is a Unix timestamp in seconds.
	// NOTE(review): the tags say "creationData"/"creation_data", not
	// "creationTime"; presumably kept as-is for wire compatibility.
	CreationTime int64 `json:"creationData,omitempty" redis:"creation_data"`
	// ExpiresAt is a Unix timestamp in seconds; 0 means "no expiry" for
	// Claims.Valid (see verifyExp).
	ExpiresAt int64 `json:"expiresAt,omitempty" redis:"expires_at"`
}
36 |
37 | func (c *Claims) roleType() RoleType {
38 | return RoleType(c.RoleType)
39 | }
40 |
41 | func (c *Claims) loginType() LoginType {
42 | return LoginType(c.LoginType)
43 | }
44 |
45 | func (c *Claims) authType() AuthType {
46 | return AuthType(c.AuthType)
47 | }
48 |
49 | func (c *Claims) setRoleType(roleType int) {
50 | c.RoleType = roleType
51 | }
52 |
53 | func (c *Claims) setLoginType(loginType int) {
54 | c.LoginType = loginType
55 | }
56 |
57 | func (c *Claims) setAuthType(authType int) {
58 | c.AuthType = authType
59 | }
60 |
61 | func NewClaims(m *Agent) *Claims {
62 | claims := &Claims{
63 | Id: strconv.FormatUint(uint64(m.Id), 10),
64 | SuperAdmin: m.SuperAdmin,
65 | Username: m.Username,
66 | AuthId: strings.Join(m.AuthIds, "-"),
67 | RoleType: int(m.RoleType),
68 | LoginType: int(m.LoginType),
69 | AuthType: int(m.AuthType),
70 | CreationTime: time.Now().Local().Unix(),
71 | ExpiresAt: m.ExpiresAt,
72 | }
73 | return claims
74 | }
75 |
76 | func (c *Claims) Valid() error {
77 | vErr := new(jwt.ValidationError)
78 | now := time.Now().Unix()
79 | // The claims below are optional, by default, so if they are set to the
80 | // default value in Go, let's not fail the verification for them.
81 | if !c.VerifyExpiresAt(now, false) {
82 | delta := time.Unix(now, 0).Sub(time.Unix(c.ExpiresAt, 0))
83 | vErr.Inner = fmt.Errorf("claims:token is expired by %v", delta)
84 | vErr.Errors |= ValidationExpired
85 | }
86 | if !c.VerifyId() {
87 | vErr.Inner = fmt.Errorf("claims:id[%s] is empty", c.Id)
88 | vErr.Errors |= ValidationId
89 | }
90 | if !c.VerifyUsername() {
91 | vErr.Inner = fmt.Errorf("claims:username[%s] is empty", c.Username)
92 | vErr.Errors |= ValidationUsername
93 | }
94 | if !c.VerifyAuthId() {
95 | vErr.Inner = fmt.Errorf("claims:authId[%s] is empty", c.AuthId)
96 | vErr.Errors |= ValidationAuthId
97 | }
98 | if !c.VerifyType() {
99 | vErr.Inner = fmt.Errorf("claims:roleType[%d] is invalid", c.RoleType)
100 | vErr.Errors |= ValidationRoleType
101 | }
102 | if !c.VerifyLoginType() {
103 | vErr.Inner = fmt.Errorf("claims:loginType[%d] is invalid", c.LoginType)
104 | vErr.Errors |= ValidationLoginType
105 | }
106 | if !c.VerifyAuthType() {
107 | vErr.Inner = fmt.Errorf("claims:authType[%d] is invalid", c.AuthType)
108 | vErr.Errors |= ValidationAuthType
109 | }
110 | if !valid(vErr) {
111 | return vErr
112 | }
113 |
114 | return nil
115 | }
116 |
117 | // No errors
118 | func valid(e *jwt.ValidationError) bool {
119 | return e.Errors == 0
120 | }
121 |
122 | // Compares the exp claim against cmp.
123 | // If required is false, this method will return true if the value matches or is unset
124 | func (c *Claims) VerifyExpiresAt(cmp int64, req bool) bool {
125 | return verifyExp(c.ExpiresAt, cmp, req)
126 | }
127 |
// verifyExp reports whether the expiry time exp (Unix seconds) has not yet
// passed at time now. exp == 0 means "unset" and yields !required.
func verifyExp(exp int64, now int64, required bool) bool {
	switch {
	case exp == 0:
		return !required
	default:
		return exp >= now
	}
}
134 |
135 | func (c *Claims) VerifyId() bool {
136 | if id, err := strconv.Atoi(c.Id); err != nil {
137 | return false
138 | } else if id > 0 {
139 | return true
140 | }
141 | return false
142 | }
143 |
144 | func (c *Claims) VerifyUsername() bool {
145 | return c.Username != ""
146 | }
147 |
148 | func (c *Claims) VerifyAuthId() bool {
149 | return c.AuthId != ""
150 | }
151 |
152 | func (c *Claims) VerifyType() bool {
153 | return c.RoleType > 0
154 | }
155 |
156 | func (c *Claims) VerifyLoginType() bool {
157 | return LoginType(c.LoginType) >= LoginTypeWeb && LoginType(c.LoginType) <= LoginTypeDevice
158 | }
159 |
160 | func (c *Claims) VerifyAuthType() bool {
161 | return AuthType(c.AuthType) >= NoAuth && AuthType(c.AuthType) <= AuthThirdParty
162 | }
163 |
--------------------------------------------------------------------------------
/auth2/claims_test.go:
--------------------------------------------------------------------------------
1 | package auth2
2 |
3 | import (
4 | "reflect"
5 | "testing"
6 | "time"
7 |
8 | "github.com/golang-jwt/jwt"
9 | )
10 |
11 | var testAgent = &Agent{
12 | Id: uint(8457585),
13 | SuperAdmin: true,
14 | Username: "username",
15 | AuthIds: []string{"999"},
16 | RoleType: RoleAdmin,
17 | LoginType: LoginTypeWeb,
18 | AuthType: AuthPwd,
19 | ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(),
20 | }
21 |
22 | func TestNewClaims(t *testing.T) {
23 | cla := NewClaims(testAgent)
24 | if cla == nil {
25 | t.Fatal("claims init return is nil")
26 | }
27 | if cla.Id != "8457585" {
28 | t.Error("claims id is not 8457585")
29 | }
30 | if cla.SuperAdmin != true {
31 | t.Error("claims super admin is not true")
32 | }
33 | if cla.Username != "username" {
34 | t.Error("claims username is not username")
35 | }
36 | if cla.AuthId != "999" {
37 | t.Error("claims auth ids is not 999")
38 | }
39 | if cla.roleType() != RoleAdmin {
40 | t.Error("claims type is not admin")
41 | }
42 | if cla.loginType() != LoginTypeWeb {
43 | t.Error("claims login type is not web")
44 | }
45 | if cla.authType() != AuthPwd {
46 | t.Error("claims auth type is not web")
47 | }
48 | if cla.ExpiresAt != time.Now().Local().Add(TimeoutWeb).Unix() {
49 | t.Error("claims expires at is not now")
50 | }
51 |
52 | cla.setAuthType(2)
53 | if cla.authType() != AuthCode {
54 | t.Error("claims authType is not authCode")
55 | }
56 |
57 | cla.setLoginType(1)
58 | if cla.loginType() != LoginTypeApp {
59 | t.Error("claims loginType is not loginTypeApp")
60 | }
61 | cla.setRoleType(2)
62 | if cla.roleType() != RoleTenancy {
63 | t.Error("claims roleType is not roleTenancy")
64 | }
65 | }
66 |
67 | func TestValid(t *testing.T) {
68 | cla := NewClaims(&Agent{Id: uint(8457585), SuperAdmin: true, Username: "username", AuthIds: []string{"999"}, RoleType: RoleAdmin, LoginType: LoginTypeWeb, AuthType: AuthPwd, ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix()})
69 | if err := cla.Valid(); err != nil {
70 | t.Fatal(err)
71 | }
72 | args := []struct {
73 | agent *Agent
74 | name string
75 | want uint32
76 | }{
77 | {
78 | name: "ValidationAuthType",
79 | agent: &Agent{Id: uint(8457585), SuperAdmin: true, Username: "username", AuthIds: []string{"999"}, AuthType: 99, RoleType: RoleAdmin, LoginType: LoginTypeWeb, ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix()},
80 | want: ValidationAuthType,
81 | },
82 | {
83 | name: "ValidationLoginType",
84 | agent: &Agent{Id: uint(8457585), SuperAdmin: true, Username: "username", AuthIds: []string{"999"}, LoginType: 99, RoleType: RoleAdmin, AuthType: AuthPwd, ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix()},
85 | want: ValidationLoginType,
86 | },
87 | {
88 | name: "ValidationRoleType",
89 | agent: &Agent{Id: uint(8457585), SuperAdmin: true, Username: "username", AuthIds: []string{"999"}, LoginType: LoginTypeApp, RoleType: -1, AuthType: AuthPwd, ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix()},
90 | want: ValidationRoleType,
91 | },
92 | {
93 | name: "ValidationAuthId",
94 | agent: &Agent{Id: uint(8457585), SuperAdmin: true, Username: "username", AuthIds: []string{""}, LoginType: LoginTypeApp, RoleType: RoleAdmin, AuthType: AuthPwd, ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix()},
95 | want: ValidationAuthId,
96 | },
97 | {
98 | name: "ValidationUsername",
99 | agent: &Agent{Id: uint(8457585), SuperAdmin: true, Username: "", AuthIds: []string{"1"}, LoginType: LoginTypeApp, RoleType: RoleAdmin, AuthType: AuthPwd, ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix()},
100 | want: ValidationUsername,
101 | },
102 | {
103 | name: "ValidationId",
104 | agent: &Agent{Id: 0, SuperAdmin: true, Username: "username", AuthIds: []string{"1"}, LoginType: LoginTypeApp, RoleType: RoleAdmin, AuthType: AuthPwd, ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix()},
105 | want: ValidationId,
106 | },
107 | }
108 |
109 | for _, arg := range args {
110 | t.Run(arg.name, func(t *testing.T) {
111 | cla = NewClaims(arg.agent)
112 | if err := cla.Valid(); err == nil {
113 | t.Fatal("error is nil")
114 | } else {
115 |
116 | if v, ok := err.(*jwt.ValidationError); !ok {
117 | t.Fatalf("%s %s", reflect.TypeOf(err).String(), err.Error())
118 | } else if v.Errors != arg.want {
119 | t.Fatalf("%d %d %s", v.Errors, arg.want, v.Error())
120 | }
121 | }
122 | })
123 | }
124 | }
125 |
--------------------------------------------------------------------------------
/auth2/example/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "time"
7 |
8 | "github.com/gin-gonic/gin"
9 | "github.com/go-redis/redis/v8"
10 | "github.com/snowlyg/iris-admin/auth2"
11 | )
12 |
13 | func init() {
14 | options := &redis.UniversalOptions{
15 | Addrs: []string{"127.0.0.1:6379"},
16 | Password: "",
17 | PoolSize: 10,
18 | IdleTimeout: 300 * time.Second,
19 | // Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
20 | // conn, err := net.Dial(network, addr)
21 | // if err == nil {
22 | // go func() {
23 | // time.Sleep(5 * time.Second)
24 | // conn.Close()
25 | // }()
26 | // }
27 | // return conn, err
28 | // },
29 | }
30 |
31 | err := auth2.NewAgent(&auth2.Config{
32 | Type: "redis",
33 | Max: 10,
34 | UniversalClient: redis.NewUniversalClient(options)})
35 | if err != nil {
36 | panic(fmt.Sprintf("auth is not init get err %v\n", err))
37 | }
38 | }
39 |
40 | func auth() gin.HandlerFunc {
41 | verifier := auth2.NewVerifier()
42 | verifier.Extractors = []auth2.TokenExtractor{auth2.FromHeader} // extract token only from Authorization: Bearer $token
43 | return verifier.Verify()
44 | }
45 |
46 | func main() {
47 | app := gin.New()
48 |
49 | app.GET("/", generateToken())
50 |
51 | protectedAPI := app.Group("/protected")
52 | // Register the verify middleware to allow access only to authorized clients.
53 | protectedAPI.Use(auth())
54 | // ^ or UseRouter(verifyMiddleware) to disallow unauthorized http error handlers too.
55 |
56 | protectedAPI.GET("/", protected)
57 | // Invalidate the token through server-side, even if it's not expired yet.
58 | protectedAPI.GET("/logout", logout)
59 |
60 | // http://localhost:8080
61 | // http://localhost:8080/protected (or Authorization: Bearer $token)
62 | // http://localhost:8080/protected/logout
63 | // http://localhost:8080/protected (401)
64 | app.Run(":8080")
65 | }
66 |
67 | func generateToken() gin.HandlerFunc {
68 | return func(ctx *gin.Context) {
69 | claims := auth2.NewClaims(&auth2.Agent{
70 | Id: 1,
71 | Username: "your name",
72 | AuthIds: []string{"your authority id"},
73 | RoleType: auth2.RoleAdmin,
74 | LoginType: auth2.LoginTypeWeb,
75 | AuthType: auth2.AuthPwd,
76 | CreationTime: time.Now().Local().Unix(),
77 | ExpiresAt: time.Now().Local().Add(auth2.TimeoutWeb).Unix(),
78 | })
79 |
80 | token, _, err := auth2.AuthAgent.Generate(claims)
81 | if err != nil {
82 | ctx.AbortWithStatus(http.StatusInternalServerError)
83 | return
84 | }
85 |
86 | ctx.String(200, token)
87 | }
88 | }
89 |
90 | func protected(ctx *gin.Context) {
91 | claims := auth2.Get(ctx)
92 | ctx.JSON(http.StatusOK, fmt.Sprintf("claims=%+v\n", claims))
93 | }
94 |
95 | func logout(ctx *gin.Context) {
96 | token := auth2.GetVerifiedToken(ctx)
97 | if token == nil {
98 | ctx.String(http.StatusOK, auth2.ErrEmptyToken.Error())
99 | return
100 | }
101 | err := auth2.AuthAgent.DelCache(string(token))
102 | if err != nil {
103 | ctx.JSON(http.StatusOK, err.Error())
104 | return
105 | }
106 | ctx.String(http.StatusOK, "token invalidated, a new token is required to access the protected API")
107 | }
108 |
--------------------------------------------------------------------------------
/auth2/extractor.go:
--------------------------------------------------------------------------------
1 | package auth2
2 |
3 | import (
4 | "log"
5 | "strings"
6 |
7 | "github.com/gin-gonic/gin"
8 | "github.com/gin-gonic/gin/binding"
9 | )
10 |
// TokenExtractor is a function that takes a context as input and returns
// a token. It must return an empty string when no token is found; the
// signature has no way to report why extraction failed.
type TokenExtractor func(*gin.Context) string
15 |
16 | // FromHeader is a token extractor.
17 | // It reads the token from the Authorization request header of form:
18 | // Authorization: "Bearer {token}".
19 | func FromHeader(ctx *gin.Context) string {
20 | authHeader := ctx.GetHeader("Authorization")
21 | if authHeader == "" {
22 | return ""
23 | }
24 |
25 | // pure check: authorization header format must be Bearer {token}
26 | authHeaderParts := strings.Split(authHeader, " ")
27 | if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
28 | return ""
29 | }
30 |
31 | return authHeaderParts[1]
32 | }
33 |
34 | // FromQuery is a token extractor.
35 | // It reads the token from the "token" url query parameter.
36 | func FromQuery(ctx *gin.Context) string {
37 | return ctx.Query("token")
38 | }
39 |
40 | // FromJSON is a token extractor.
41 | // Reads a json request body and extracts the json based on the given field.
42 | // The request content-type should contain the: application/json header value, otherwise
43 | // this method will not try to read and consume the body.
44 | func FromJSON(jsonKey string) TokenExtractor {
45 | return func(ctx *gin.Context) string {
46 | if ctx.ContentType() != binding.MIMEJSON {
47 | log.Printf("extractor: content-type %s not supported\n", ctx.ContentType())
48 | return ""
49 | }
50 |
51 | var m gin.H
52 | if err := ctx.BindJSON(&m); err != nil {
53 | log.Println("extractor: bind json error:", err.Error())
54 | return ""
55 | }
56 |
57 | if m == nil {
58 | log.Println("extractor: json is empty")
59 | return ""
60 | }
61 |
62 | v, ok := m[jsonKey]
63 | if !ok {
64 | log.Printf("extractor: key %s not found\n", jsonKey)
65 | return ""
66 | }
67 |
68 | tok, ok := v.(string)
69 | if !ok {
70 | log.Printf("extractor: key %s value:[%v] is not a string\n", jsonKey, v)
71 | return ""
72 | }
73 | return tok
74 | }
75 | }
76 |
--------------------------------------------------------------------------------
/auth2/interface.go:
--------------------------------------------------------------------------------
1 | package auth2
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "time"
7 |
8 | "github.com/go-redis/redis/v8"
9 | )
10 |
// Cache key prefixes used by the stateful (local and redis) agents.
const (
	// TokenPrefix prefixes the key under which a token's claims are stored.
	TokenPrefix = "GST:"
	// BindUserPrefix prefixes the key that maps a token back to its user key.
	BindUserPrefix = "GSBU:"
	// UserPrefix prefixes the key that lists a user's tokens (see getPrefixKey).
	UserPrefix = "GSU:"
	// LimitTokenPrefix is the key under which the per-user token limit is kept.
	LimitTokenPrefix = "GT_LIMIT_TOKEN"
)

var (
	// AuthTypeSplit is the separator string "-".
	AuthTypeSplit = "-"
	// LimitTokenDefault is the per-user token limit reported by the local agent
	// when no limit has been stored.
	LimitTokenDefault int64 = 10
)

// Sentinel errors; callers should match them with errors.Is.
var (
	ErrTokenInvalid = errors.New("token_invalid")
	ErrEmptyToken   = errors.New("token_empty")
	ErrOverLimit    = errors.New("token_over_limit")
)

// RoleType identifies the kind of principal a token belongs to.
type RoleType int

const (
	RoleNone RoleType = iota // 0
	RoleAdmin                // 1
	RoleTenancy              // 2
	RoleGeneral              // 3
)

// AuthType records how the user authenticated.
type AuthType int

const (
	NoAuth         AuthType = iota // 0
	AuthPwd                        // 1
	AuthCode                       // 2
	AuthThirdParty                 // 3
)

// LoginType records the client a user logged in from; it selects the token
// cache lifetime via getExpire.
type LoginType int

const (
	LoginTypeWeb    LoginType = iota // 0
	LoginTypeApp                     // 1
	LoginTypeWx                      // 2
	LoginTypeDevice                  // 3
)

// Token cache lifetimes per login type.
var (
	TimeoutWeb    = 4 * time.Hour
	TimeoutApp    = 7 * 24 * time.Hour
	TimeoutWx     = 5 * 52 * 7 * 24 * time.Hour // 5 × 52 weeks
	TimeoutDevice = 5 * 52 * 7 * 24 * time.Hour // 5 × 52 weeks
)
62 |
63 | func NewAgent(c *Config) error {
64 | if c.Max == 0 {
65 | c.Max = 10
66 | }
67 | switch c.Type {
68 | case "redis":
69 | agent, err := NewRedis(c.UniversalClient)
70 | if err != nil {
71 | return err
72 | }
73 |
74 | AuthAgent = agent
75 | err = AuthAgent.SetLimit(c.Max)
76 | if err != nil {
77 | return err
78 | }
79 | case "local":
80 | AuthAgent = NewLocal()
81 | err := AuthAgent.SetLimit(c.Max)
82 | if err != nil {
83 | return err
84 | }
85 | case "jwt":
86 | AuthAgent = NewJwt(c.HmacSecret)
87 | default:
88 | AuthAgent = NewJwt(c.HmacSecret)
89 | }
90 |
91 | return nil
92 | }
93 |
// Agent carries the identity and login details of an authenticated user.
// The tests in this package turn it into Claims with NewClaims.
type Agent struct {
	Id           uint      `json:"id,omitempty"`
	SuperAdmin   bool      `json:"superAdmin,omitempty"`
	Username     string    `json:"username,omitempty"`
	AuthIds      []string  `json:"authIds,omitempty"`
	RoleType     RoleType  `json:"type,omitempty"` // note: serialised under the JSON key "type"
	LoginType    LoginType `json:"loginType,omitempty"`
	AuthType     AuthType  `json:"authType,omitempty"`
	CreationTime int64     `json:"creationData,omitempty"` // note: JSON key is "creationData"
	ExpiresAt    int64     `json:"expiresAt,omitempty"`    // tests set this to a Unix timestamp in seconds
}
106 |
// Config selects and configures the backend built by NewAgent.
type Config struct {
	Type            string                // "redis", "local", or "jwt"; other values fall back to jwt
	Max             int64                 // per-user token limit for redis/local; 0 selects the default (10)
	UniversalClient redis.UniversalClient // required when Type is "redis"
	HmacSecret      []byte                // jwt signing key; nil falls back to a built-in sample key
}
113 |
type (
	// TokenValidator provides further validation of a raw token.
	TokenValidator interface {
		// Validater receives the raw token and any error produced by an
		// earlier validation step. An implementation may respect that error by
		// returning it, replace it with its own error, or clear it by
		// returning nil. Usage:
		//
		//	func (v *myValidator) Validater(token []byte, err error) error {
		//		if err != nil { return err } // respect the previous error
		//		// otherwise return nil or any custom error.
		//	}
		Validater(token []byte, err error) error
	}

	// ValidatorFunc is the function form of a TokenValidator.
	ValidatorFunc func(token []byte, err error) error
)

// Validater implements TokenValidator by calling fn.
func (fn ValidatorFunc) Validater(token []byte, err error) error {
	return fn(token, err)
}

// ValidateToken calls fn. It is kept for backward compatibility; Validater
// is the method TokenValidator actually requires.
func (fn ValidatorFunc) ValidateToken(token []byte, err error) error {
	return fn(token, err)
}

// Compile-time check that ValidatorFunc satisfies TokenValidator.
var _ TokenValidator = ValidatorFunc(nil)
140 |
// AuthAgent is the package-wide Authentication instance installed by NewAgent.
var AuthAgent Authentication

// Authentication is the contract shared by the JWT, local-cache and redis
// backends in this package.
type Authentication interface {
	// Generate issues a token for claims; the int64 result is an expiry
	// timestamp reported by the backend.
	Generate(claims *Claims) (string, int64, error)
	// DelCache revokes a single token.
	DelCache(token string) error
	// UpdateCacheExpire refreshes the lifetime of a token's cache entries
	// (a logged no-op for the JWT backend).
	UpdateCacheExpire(token string) error
	// GetClaims returns the claims of a valid, non-revoked token.
	GetClaims(token string) (*Claims, error)
	// Token returns an existing token whose claims match, or "" (the JWT
	// backend always returns "").
	Token(claims *Claims) (string, error)
	// CleanCache revokes every token of a user (a logged no-op for JWT).
	CleanCache(roleType RoleType, userId string) error
	// SetLimit sets the maximum number of tokens per user (a logged no-op for JWT).
	SetLimit(max int64) error
	// IsRole reports whether the token's claims carry roleType.
	IsRole(token string, roleType RoleType) (bool, error)
	// IsSuperAdmin reports whether the token's claims mark a super admin;
	// invalid tokens yield false.
	IsSuperAdmin(token string) bool
	// Close releases backend resources (a no-op for the JWT and local backends).
	Close()
}
156 |
157 | // getExpire
158 | func getExpire(loginType LoginType) time.Duration {
159 | switch loginType {
160 | case LoginTypeWeb:
161 | return TimeoutWeb
162 | case LoginTypeWx:
163 | return TimeoutWx
164 | case LoginTypeApp:
165 | return TimeoutApp
166 | case LoginTypeDevice:
167 | return TimeoutDevice
168 | default:
169 | return TimeoutWeb
170 | }
171 | }
172 |
173 | // getPrefixKey
174 | func getPrefixKey(roleType RoleType, id string) string {
175 | return fmt.Sprintf("%s%d_%s", UserPrefix, roleType, id)
176 | }
177 |
--------------------------------------------------------------------------------
/auth2/jwt.go:
--------------------------------------------------------------------------------
1 | package auth2
2 |
3 | import (
4 | "fmt"
5 | "log"
6 |
7 | "github.com/golang-jwt/jwt"
8 | "github.com/snowlyg/helper/arr"
9 | )
10 |
// hmacSampleSecret is the fallback HMAC key NewJwt uses when no secret is
// configured. It is hard-coded in source, so anyone who can read this code
// can mint tokens that verify against it.
// NOTE(review): production deployments should always set Config.HmacSecret.
var hmacSampleSecret = []byte("updPA0L2uQ56LwHZoyUX")
12 |
// JwtAuth is a stateless Authentication backed by HS256-signed JWTs.
// Revocation is emulated by recording deleted tokens in delToken, which
// GetClaims consults. No method here ever removes entries from delToken.
// NOTE(review): confirm whether arr.ArrayType is safe for concurrent use.
type JwtAuth struct {
	HmacSecret []byte
	delToken   arr.ArrayType
}
18 |
19 | // NewJwt
20 | func NewJwt(hmacSecret []byte) *JwtAuth {
21 | ja := &JwtAuth{
22 | HmacSecret: hmacSecret,
23 | delToken: arr.NewCheckArrayType(0),
24 | }
25 | if ja.HmacSecret == nil {
26 | ja.HmacSecret = hmacSampleSecret
27 | }
28 | return ja
29 | }
30 |
31 | // Generate
32 | func (ra *JwtAuth) Generate(claims *Claims) (string, int64, error) {
33 | token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
34 |
35 | // Sign and get the complete encoded token as a string using the secret
36 | tokenString, err := token.SignedString(ra.HmacSecret)
37 | if err != nil {
38 | return "", 0, err
39 | }
40 | return tokenString, 0, nil
41 | }
42 |
43 | // Token
44 | func (ra *JwtAuth) Token(cla *Claims) (string, error) {
45 | log.Printf("jwt:get token not support\n")
46 | return "", nil
47 | }
48 |
49 | // GetClaims
50 | func (ra *JwtAuth) GetClaims(tokenString string) (*Claims, error) {
51 | if ra.delToken.Check(tokenString) {
52 | return nil, fmt.Errorf("jwt:token deleted %w", ErrTokenInvalid)
53 | }
54 | mc := &Claims{}
55 | token, err := jwt.ParseWithClaims(tokenString, mc, func(token *jwt.Token) (interface{}, error) {
56 | // Don't forget to validate the alg is what you expect:
57 | if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
58 | return nil, fmt.Errorf("incorrect signing method: %v", token.Header["alg"])
59 | }
60 | return ra.HmacSecret, nil
61 | })
62 | if err != nil {
63 | return nil, err
64 | }
65 |
66 | if _, ok := token.Claims.(*Claims); ok && token.Valid {
67 | return mc, nil
68 | } else {
69 | return nil, fmt.Errorf("token[%s]:%w", tokenString, ErrTokenInvalid)
70 | }
71 | }
72 |
73 | // SetLimit
74 | func (ra *JwtAuth) SetLimit(limit int64) error {
75 | log.Printf("jwt:set max count not support\n")
76 | return nil
77 | }
78 |
79 | // UpdateCacheExpire
80 | func (ra *JwtAuth) UpdateCacheExpire(token string) error {
81 | log.Printf("jwt:UpdateCacheExpire not support\n")
82 | return nil
83 | }
84 |
// DelCache revokes token by adding it to the delToken deny list consulted
// by GetClaims. The token is not parsed or validated first, and the call
// always returns nil.
func (ra *JwtAuth) DelCache(token string) error {
	ra.delToken.Add(token)
	return nil
}
90 |
91 | // CleanCache
92 | func (ra *JwtAuth) CleanCache(roleType RoleType, userId string) error {
93 | log.Printf("jwt:CleanCache not support")
94 | return nil
95 | }
96 |
97 | // IsRole
98 | func (ra *JwtAuth) IsRole(token string, roleType RoleType) (bool, error) {
99 | rcc, err := ra.GetClaims(token)
100 | if err != nil {
101 | return false, fmt.Errorf("jwt:get User's infomation return error: %w", err)
102 | }
103 | return rcc.roleType() == roleType, nil
104 | }
105 |
106 | // IsSuperAdmin
107 | func (ra *JwtAuth) IsSuperAdmin(token string) bool {
108 | rcc, err := ra.GetClaims(token)
109 | if err != nil {
110 | log.Printf("jwt:get claims fail:%s\n", err.Error())
111 | return false
112 | }
113 | return rcc.SuperAdmin
114 | }
115 |
// Close is a no-op: JwtAuth holds nothing that needs releasing here.
func (ra *JwtAuth) Close() {}
118 |
--------------------------------------------------------------------------------
/auth2/jwt_test.go:
--------------------------------------------------------------------------------
1 | package auth2
2 |
3 | import (
4 | "testing"
5 | "time"
6 | )
7 |
// Shared fixtures for the JWT tests: a JwtAuth built with a nil secret (so it
// uses the built-in sample secret) and one set of admin claims expiring
// TimeoutWeb from now. Several tests mutate jwtClaims' login type.
var (
	jwtAuth   = NewJwt(nil)
	jwtClaims = NewClaims(
		&Agent{
			Id:        uint(8457585),
			Username:  "jwt username",
			AuthIds:   []string{"999"},
			RoleType:  RoleAdmin,
			LoginType: LoginTypeWeb,
			AuthType:  AuthPwd,
			ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(),
		},
	)
)
22 |
23 | func TestJwtGenerateToken(t *testing.T) {
24 | defer jwtAuth.CleanCache(jwtClaims.roleType(), jwtClaims.Id)
25 | token, _, err := jwtAuth.Generate(jwtClaims)
26 | if err != nil {
27 | t.Fatalf("generate token %v", err)
28 | }
29 | if token == "" {
30 | t.Error("generate token is empty")
31 | }
32 |
33 | cc, err := jwtAuth.GetClaims(token)
34 | if err != nil {
35 | t.Fatalf("get custom claims fail:%v", err)
36 | }
37 |
38 | if cc.Id != jwtClaims.Id {
39 | t.Errorf("get custom id want %v but get %v", jwtClaims.Id, cc.Id)
40 | }
41 | if cc.Username != jwtClaims.Username {
42 | t.Errorf("get custom username want %v but get %v", jwtClaims.Username, cc.Username)
43 | }
44 | if cc.AuthId != jwtClaims.AuthId {
45 | t.Errorf("get custom authority_id want %v but get %v", jwtClaims.AuthId, cc.AuthId)
46 | }
47 | if cc.RoleType != jwtClaims.RoleType {
48 | t.Errorf("get custom authority_type want %v but get %v", jwtClaims.RoleType, cc.RoleType)
49 | }
50 | if cc.LoginType != jwtClaims.LoginType {
51 | t.Errorf("get custom login_type want %v but get %v", jwtClaims.LoginType, cc.LoginType)
52 | }
53 | if cc.AuthType != jwtClaims.AuthType {
54 | t.Errorf("get custom auth_type want %v but get %v", jwtClaims.AuthType, cc.AuthType)
55 | }
56 | if cc.CreationTime != jwtClaims.CreationTime {
57 | t.Errorf("get custom creation_data want %v but get %v", jwtClaims.CreationTime, cc.CreationTime)
58 | }
59 | if cc.ExpiresAt != jwtClaims.ExpiresAt {
60 | t.Errorf("get custom expires_at want %v but get %v", jwtClaims.ExpiresAt, cc.ExpiresAt)
61 | }
62 | }
63 |
64 | func TestJwtSetUserTokenMaxCount(t *testing.T) {
65 | err := jwtAuth.SetLimit(3)
66 | if err != nil {
67 | t.Errorf("get token by claims token want %v but get %v", nil, err)
68 | }
69 | }
70 |
71 | func TestJwtGetMultiClaims(t *testing.T) {
72 | defer jwtAuth.CleanCache(jwtClaims.roleType(), jwtClaims.Id)
73 | var token string
74 | jwtClaims.setLoginType(int(LoginTypeWeb))
75 | token, _, err := jwtAuth.Generate(jwtClaims)
76 | if err != nil {
77 | t.Fatalf("get custom claims %v", err)
78 | }
79 | for i := LoginTypeWeb; i <= LoginTypeDevice; i++ {
80 | wg.Add(1)
81 | jwtClaims.setLoginType(int(i))
82 | go func(i LoginType) {
83 | jwtAuth.Generate(jwtClaims)
84 | wg.Done()
85 | }(i)
86 | wg.Wait()
87 | }
88 | for i := 0; i < 4; i++ {
89 | go func() {
90 | _, err := jwtAuth.GetClaims(token)
91 | if err != nil {
92 | t.Errorf("get custom claims %v", err)
93 | }
94 | }()
95 | }
96 | time.Sleep(3 * time.Second)
97 | }
98 |
99 | func TestJwtGetTokenByClaims(t *testing.T) {
100 | _, err := jwtAuth.Token(jwtClaims)
101 | if err != nil {
102 | t.Errorf("get token by claims token want %v but get %v", nil, err)
103 | }
104 | }
105 |
106 | func TestJwtDelUserTokenCache(t *testing.T) {
107 | token, _, _ := jwtAuth.Generate(jwtClaims)
108 | if token == "" {
109 | t.Error("generate token is empty")
110 | }
111 | err := jwtAuth.DelCache(token)
112 | if err != nil {
113 | t.Errorf("del token fail:%v", err.Error())
114 | }
115 | if _, err := jwtAuth.GetClaims(token); err == nil {
116 | t.Error("del user token fail")
117 | }
118 | }
119 |
--------------------------------------------------------------------------------
/auth2/local.go:
--------------------------------------------------------------------------------
1 | package auth2
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "time"
7 |
8 | "github.com/patrickmn/go-cache"
9 | )
10 |
// tokens is the list of token strings stored per user under getPrefixKey.
type tokens []string

// localCache is the process-wide go-cache instance shared by every
// LocalAuth returned from NewLocal.
var localCache *cache.Cache

// LocalAuth is an Authentication backed by an in-process go-cache. Its state
// is not shared with other processes.
type LocalAuth struct {
	Cache *cache.Cache
}
18 |
19 | func NewLocal() *LocalAuth {
20 | if localCache == nil {
21 | localCache = cache.New(4*time.Hour, 24*time.Minute)
22 | }
23 | return &LocalAuth{
24 | Cache: localCache,
25 | }
26 | }
27 |
28 | // Generate
29 | func (la *LocalAuth) Generate(claims *Claims) (string, int64, error) {
30 |
31 | if la.isUserTokenOver(claims.roleType(), claims.Id) {
32 | return "", 0, fmt.Errorf("local: is user token over fail:%w", ErrOverLimit)
33 | }
34 | token, err := getToken()
35 | if err != nil {
36 | return "", 0, fmt.Errorf("local: get token fail:%w", err)
37 | }
38 | la.toCache(token, claims)
39 | if e := la.syncUserCache(token); e != nil {
40 | return "", 0, fmt.Errorf("local: sync user token fail:%w", e)
41 | }
42 | return token, int64(claims.ExpiresAt), nil
43 | }
44 |
45 | func (la *LocalAuth) toCache(token string, rcc *Claims) error {
46 | sKey := TokenPrefix + token
47 | la.Cache.Set(sKey, rcc, getExpire(rcc.loginType()))
48 | return nil
49 | }
50 |
51 | func (la *LocalAuth) syncUserCache(token string) error {
52 | rcc, err := la.GetClaims(token)
53 | if err != nil {
54 | return err
55 | }
56 | userPrefixKey := getPrefixKey(rcc.roleType(), rcc.Id)
57 | ts := tokens{}
58 | if uTokens, ok := la.Cache.Get(userPrefixKey); ok && uTokens != nil {
59 | ts = uTokens.(tokens)
60 | }
61 | ts = append(ts, token)
62 | la.Cache.Set(userPrefixKey, ts, cache.NoExpiration)
63 | la.Cache.Set(BindUserPrefix+token, userPrefixKey, getExpire(rcc.loginType()))
64 | return nil
65 | }
66 |
67 | func (la *LocalAuth) DelCache(token string) error {
68 | rcc, err := la.GetClaims(token)
69 | if err != nil {
70 | return err
71 | }
72 | userPrefixKey := getPrefixKey(rcc.roleType(), rcc.Id)
73 | if utokens, ok := la.Cache.Get(userPrefixKey); ok && utokens != nil {
74 | t := utokens.(tokens)
75 | for index, u := range t {
76 | if u == token {
77 | if len(t) == 1 {
78 | utokens = nil
79 | } else {
80 | utokens = append(t[0:index], t[index:]...)
81 | }
82 | }
83 | }
84 | la.Cache.Set(userPrefixKey, utokens, cache.NoExpiration)
85 | }
86 | la.delTokenCache(token)
87 | return nil
88 | }
89 |
90 | // delTokenCache
91 | func (la *LocalAuth) delTokenCache(token string) error {
92 | la.Cache.Delete(BindUserPrefix + token)
93 | la.Cache.Delete(TokenPrefix + token)
94 | return nil
95 | }
96 |
97 | func (la *LocalAuth) UpdateCacheExpire(token string) error {
98 | rsv2, err := la.GetClaims(token)
99 | if err != nil {
100 | return err
101 | }
102 | if rsv2 == nil {
103 | return errors.New("token cache is nil")
104 | }
105 | la.Cache.Set(BindUserPrefix+token, rsv2, getExpire(rsv2.loginType()))
106 | la.Cache.Set(TokenPrefix+token, rsv2, getExpire(rsv2.loginType()))
107 | return nil
108 | }
109 |
110 | func (la *LocalAuth) GetClaims(token string) (*Claims, error) {
111 | sKey := TokenPrefix + token
112 | if food, found := la.Cache.Get(sKey); !found || food == nil {
113 | return nil, fmt.Errorf("token not found:%w", ErrTokenInvalid)
114 | } else {
115 | return food.(*Claims), nil
116 | }
117 | }
118 |
119 | // Token
120 | func (la *LocalAuth) Token(cla *Claims) (string, error) {
121 | userTokens, err := la.getUserTokens(cla.roleType(), cla.Id)
122 | if err != nil {
123 | return "", err
124 | }
125 | clas, err := la.getMultiClaimses(userTokens)
126 | if err != nil {
127 | return "", err
128 | }
129 | for token, existCla := range clas {
130 | if cla.AuthType == existCla.AuthType &&
131 | cla.Id == existCla.Id &&
132 | cla.RoleType == existCla.RoleType &&
133 | cla.AuthId == existCla.AuthId &&
134 | cla.LoginType == existCla.LoginType {
135 | return token, nil
136 | }
137 | }
138 | return "", nil
139 | }
140 |
141 | // getUserTokens
142 | func (la *LocalAuth) getUserTokens(roleType RoleType, userId string) (tokens, error) {
143 | if utokens, ok := la.Cache.Get(getPrefixKey(roleType, userId)); ok && utokens != nil {
144 | return utokens.(tokens), nil
145 | }
146 | return nil, nil
147 | }
148 |
149 | // getMultiClaimses 获取用户信息
150 | func (la *LocalAuth) getMultiClaimses(tokens tokens) (map[string]*Claims, error) {
151 | clas := make(map[string]*Claims, la.getUserTokenMaxCount())
152 | for _, token := range tokens {
153 | cla, err := la.GetClaims(token)
154 | if err != nil {
155 | continue
156 | }
157 | clas[token] = cla
158 | }
159 |
160 | return clas, nil
161 | }
162 |
163 | func (la *LocalAuth) isUserTokenOver(roleType RoleType, userId string) bool {
164 | return la.getUserTokenCount(roleType, userId) >= la.getUserTokenMaxCount()
165 | }
166 |
167 | // getUserTokenCount 获取登录数量
168 | func (la *LocalAuth) getUserTokenCount(roleType RoleType, userId string) int64 {
169 | return la.checkMaxCount(roleType, userId)
170 | }
171 |
172 | func (la *LocalAuth) checkMaxCount(roleType RoleType, userId string) int64 {
173 | utokens, _ := la.getUserTokens(roleType, userId)
174 | if utokens == nil {
175 | return 0
176 | }
177 | for index, u := range utokens {
178 | if _, found := la.Cache.Get(TokenPrefix + u); !found {
179 | if len(utokens) == 1 {
180 | utokens = nil
181 | } else {
182 | utokens = append(utokens[0:index], utokens[index:]...)
183 | }
184 | }
185 | }
186 | la.Cache.Set(getPrefixKey(roleType, userId), utokens, cache.NoExpiration)
187 | return int64(len(utokens))
188 |
189 | }
190 |
191 | // getUserTokenMaxCount
192 | func (la *LocalAuth) getUserTokenMaxCount() int64 {
193 | if count, found := la.Cache.Get(LimitTokenPrefix); !found {
194 | return LimitTokenDefault
195 | } else {
196 | return count.(int64)
197 | }
198 | }
199 |
200 | // SetLimit
201 | func (la *LocalAuth) SetLimit(tokenMaxCount int64) error {
202 | la.Cache.Set(LimitTokenPrefix, tokenMaxCount, cache.NoExpiration)
203 | return nil
204 | }
205 |
206 | // CleanCache
207 | func (la *LocalAuth) CleanCache(roleType RoleType, userId string) error {
208 | utokens, _ := la.getUserTokens(roleType, userId)
209 | if utokens == nil {
210 | return nil
211 | }
212 | for _, token := range utokens {
213 | err := la.delTokenCache(token)
214 | if err != nil {
215 | continue
216 | }
217 | }
218 | la.Cache.Delete(getPrefixKey(roleType, userId))
219 | return nil
220 | }
221 |
222 | // IsRole
223 | func (la *LocalAuth) IsRole(token string, roleType RoleType) (bool, error) {
224 | rcc, err := la.GetClaims(token)
225 | if err != nil {
226 | return false, fmt.Errorf("local: get multi claims fail %w", err)
227 | }
228 | return rcc.roleType() == roleType, nil
229 | }
230 |
231 | // IsRole
232 | func (la *LocalAuth) IsSuperAdmin(token string) bool {
233 | rcc, err := la.GetClaims(token)
234 | if err != nil {
235 | return false
236 | }
237 | return rcc.SuperAdmin
238 | }
239 |
// Close is a no-op: the shared local cache is not torn down here.
func (la *LocalAuth) Close() {}
241 |
--------------------------------------------------------------------------------
/auth2/local_test.go:
--------------------------------------------------------------------------------
1 | package auth2
2 |
3 | import (
4 | "errors"
5 | "testing"
6 | "time"
7 | )
8 |
// Shared fixtures for the local-cache tests. localAuth uses the package-wide
// localCache, so cache state carries over from one test to the next.
var (
	localAuth = NewLocal()
	// tToken is a fixed token string used by TestToCache.
	tToken = "TVRReU1EVTFOek13TmpFd09UWXlPRFF4TmcuTWpBeU1TMHdOeTB5T1ZRd09Ub3pNRG95T1Nzd09Eb3dNQQ.MTQyMDU1NzMwNjEwOTYyODQxNg"
	// loginTypeApp is a super-admin user (id 1) shared by many tests. Despite
	// its name it starts as LoginTypeWeb, and several tests change its login type.
	loginTypeApp = NewClaims(
		&Agent{
			Id:         uint(1),
			SuperAdmin: true,
			Username:   "username",
			AuthIds:    []string{"999"},
			RoleType:   RoleAdmin,
			LoginType:  LoginTypeWeb,
			AuthType:   AuthPwd,
			ExpiresAt:  time.Now().Local().Add(TimeoutWeb).Unix(),
		},
	)
	// userKey is the cache key of loginTypeApp's token list.
	userKey = getPrefixKey(loginTypeApp.roleType(), loginTypeApp.Id)
)
26 |
27 | func TestNewLocalAuth(t *testing.T) {
28 | if NewLocal() == nil {
29 | t.Error("new local auth get nil")
30 | }
31 | }
32 |
33 | func TestGenerateToken(t *testing.T) {
34 | token, expiresIn, err := localAuth.Generate(loginTypeApp)
35 | if err != nil {
36 | t.Fatalf("generate token %v", err)
37 | }
38 | if token == "" {
39 | t.Error("generate token is empty")
40 | }
41 |
42 | if expiresIn != loginTypeApp.ExpiresAt {
43 | t.Errorf("generate token expires want %v but get %v", loginTypeApp.ExpiresAt, expiresIn)
44 | }
45 | cc, err := localAuth.GetClaims(token)
46 | if err != nil {
47 | t.Fatalf("get custom claims %v", err)
48 | }
49 |
50 | if cc.Id != loginTypeApp.Id {
51 | t.Errorf("get custom id want %v but get %v", loginTypeApp.Id, cc.Id)
52 | }
53 | if cc.Username != loginTypeApp.Username {
54 | t.Errorf("get custom username want %v but get %v", loginTypeApp.Username, cc.Username)
55 | }
56 | if cc.AuthId != loginTypeApp.AuthId {
57 | t.Errorf("get custom authority_id want %v but get %v", loginTypeApp.AuthId, cc.AuthId)
58 | }
59 | if cc.RoleType != loginTypeApp.RoleType {
60 | t.Errorf("get custom authority_type want %v but get %v", loginTypeApp.RoleType, cc.RoleType)
61 | }
62 | if cc.LoginType != loginTypeApp.LoginType {
63 | t.Errorf("get custom login_type want %v but get %v", loginTypeApp.LoginType, cc.LoginType)
64 | }
65 | if cc.AuthType != loginTypeApp.AuthType {
66 | t.Errorf("get custom auth_type want %v but get %v", loginTypeApp.AuthType, cc.AuthType)
67 | }
68 | if cc.CreationTime != loginTypeApp.CreationTime {
69 | t.Errorf("get custom creation_data want %v but get %v", loginTypeApp.CreationTime, cc.CreationTime)
70 | }
71 | if cc.ExpiresAt != loginTypeApp.ExpiresAt {
72 | t.Errorf("get custom expires_at want %v but get %v", loginTypeApp.ExpiresAt, cc.ExpiresAt)
73 | }
74 |
75 | if uTokens, uFound := localAuth.Cache.Get(userKey); uFound {
76 | ts := uTokens.(tokens)
77 | if len(ts) == 0 || ts[0] != token {
78 | t.Errorf("user prefix value want %v but get %v", userKey, uTokens)
79 | }
80 | } else {
81 | t.Error("user prefix value is emptpy")
82 | }
83 | bindKey := BindUserPrefix + token
84 | if uTokens, uFound := localAuth.Cache.Get(bindKey); uFound {
85 | if uTokens != userKey {
86 | t.Errorf("bind user prefix value want %v but get %v", userKey, uTokens)
87 | }
88 | } else {
89 | t.Error("bind user prefix value is emptpy")
90 | }
91 | }
92 |
93 | func TestToCache(t *testing.T) {
94 | err := localAuth.toCache(tToken, loginTypeApp)
95 | if err != nil {
96 | t.Fatalf("generate token %v", err)
97 | }
98 | cc, err := localAuth.GetClaims(tToken)
99 | if err != nil {
100 | t.Fatalf("get custom claims %v", err)
101 | }
102 |
103 | if cc.Id != loginTypeApp.Id {
104 | t.Errorf("get custom id want %v but get %v", loginTypeApp.Id, cc.Id)
105 | }
106 | if cc.Username != loginTypeApp.Username {
107 | t.Errorf("get custom username want %v but get %v", loginTypeApp.Username, cc.Username)
108 | }
109 | if cc.AuthId != loginTypeApp.AuthId {
110 | t.Errorf("get custom authority_id want %v but get %v", loginTypeApp.AuthId, cc.AuthId)
111 | }
112 | if cc.RoleType != loginTypeApp.RoleType {
113 | t.Errorf("get custom authority_type want %v but get %v", loginTypeApp.RoleType, cc.RoleType)
114 | }
115 | if cc.LoginType != loginTypeApp.LoginType {
116 | t.Errorf("get custom login_type want %v but get %v", loginTypeApp.LoginType, cc.LoginType)
117 | }
118 | if cc.AuthType != loginTypeApp.AuthType {
119 | t.Errorf("get custom auth_type want %v but get %v", loginTypeApp.AuthType, cc.AuthType)
120 | }
121 | if cc.CreationTime != loginTypeApp.CreationTime {
122 | t.Errorf("get custom creation_data want %v but get %v", loginTypeApp.CreationTime, cc.CreationTime)
123 | }
124 | if cc.ExpiresAt != loginTypeApp.ExpiresAt {
125 | t.Errorf("get custom expires_at want %v but get %v", loginTypeApp.ExpiresAt, cc.ExpiresAt)
126 | }
127 | }
128 |
129 | func TestDelUserTokenCache(t *testing.T) {
130 | cc := NewClaims(
131 | &Agent{
132 | Id: uint(2),
133 | Username: "username",
134 | SuperAdmin: true,
135 | AuthIds: []string{"999"},
136 | RoleType: RoleAdmin,
137 | LoginType: LoginTypeWeb,
138 | AuthType: AuthPwd,
139 | ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(),
140 | },
141 | )
142 | token, _, _ := localAuth.Generate(cc)
143 | if token == "" {
144 | t.Error("generate token is empty")
145 | }
146 |
147 | err := localAuth.DelCache(token)
148 | if err != nil {
149 | t.Fatalf("del user token cache %v", err)
150 | }
151 | _, err = localAuth.GetClaims(token)
152 | if !errors.Is(err, ErrTokenInvalid) {
153 | t.Fatalf("get custom claims err want %v but get %v", ErrTokenInvalid, err)
154 | }
155 |
156 | if uTokens, uFound := localAuth.Cache.Get(UserPrefix + cc.Id); uFound && uTokens != nil {
157 | t.Errorf("user prefix value want empty but get %v", uTokens)
158 | }
159 | bindKey := BindUserPrefix + token
160 | if key, uFound := localAuth.Cache.Get(bindKey); uFound {
161 | t.Errorf("bind user prefix value want empty but get %v", key)
162 | }
163 | }
164 |
// TestIsUserTokenOver generates six tokens for a fresh user (id 3). It
// expects the user to be under the limit and to hold exactly six tokens.
// It presumably relies on the limit still being LimitTokenDefault (10);
// TestSetUserTokenMaxCount, which lowers it to 5, appears later in this file.
func TestIsUserTokenOver(t *testing.T) {
	cc := NewClaims(
		&Agent{
			Id:         uint(3),
			Username:   "username",
			SuperAdmin: true,
			AuthIds:    []string{"999"},
			RoleType:   RoleAdmin,
			LoginType:  LoginTypeWeb,
			AuthType:   AuthPwd,
			ExpiresAt:  time.Now().Local().Add(TimeoutWeb).Unix(),
		},
	)
	// Errors are ignored; the count check below catches missing tokens.
	for i := 0; i < 6; i++ {
		localAuth.Generate(cc)
	}
	if localAuth.isUserTokenOver(cc.roleType(), cc.Id) {
		t.Error("user token want not over but get over")
	}
	count := localAuth.getUserTokenCount(cc.roleType(), cc.Id)
	if count != 6 {
		t.Errorf("user token count want %v but get %v", 6, count)
	}
}
189 |
// TestSetUserTokenMaxCount issues six more tokens for loginTypeApp's user,
// then lowers the limit to 5. It expects the stored limit to read back as 5
// and the user, who now holds more than five tokens, to be over the limit.
// The limit stays at 5 for the tests that follow in this file.
func TestSetUserTokenMaxCount(t *testing.T) {
	for i := 0; i < 6; i++ {
		localAuth.Generate(loginTypeApp)
	}
	if err := localAuth.SetLimit(5); err != nil {
		t.Fatalf("set user token max count %v", err)
	}
	count := localAuth.getUserTokenMaxCount()
	if count != 5 {
		t.Errorf("user token max count want %v but get %v", 5, count)
	}
	if !localAuth.isUserTokenOver(loginTypeApp.roleType(), loginTypeApp.Id) {
		t.Error("user token want over but get not over")
	}
}
// TestCleanUserTokenCache calls Generate six more times for loginTypeApp's
// user; results are ignored, and presumably most calls are rejected by the
// limit set in the previous test. CleanCache must then leave the user with
// zero tokens.
func TestCleanUserTokenCache(t *testing.T) {
	for i := 0; i < 6; i++ {
		localAuth.Generate(loginTypeApp)
	}
	if err := localAuth.CleanCache(loginTypeApp.roleType(), loginTypeApp.Id); err != nil {
		t.Fatalf("clear user token cache %v", err)
	}
	if localAuth.getUserTokenCount(loginTypeApp.roleType(), loginTypeApp.Id) != 0 {
		t.Error("user token count want 0 but get not 0")
	}
}
216 |
217 | func TestLocalGetMultiClaims(t *testing.T) {
218 | defer localAuth.CleanCache(loginTypeApp.roleType(), loginTypeApp.Id)
219 | var token string
220 | loginTypeApp.LoginType = 3
221 | token, _, err := localAuth.Generate(loginTypeApp)
222 | if err != nil {
223 | t.Fatalf("get custom claims %v", err)
224 | }
225 | for i := LoginTypeWeb; i <= LoginTypeDevice; i++ {
226 | wg.Add(1)
227 | loginTypeApp.setLoginType(int(i))
228 | go func(i LoginType) {
229 | localAuth.Generate(loginTypeApp)
230 | wg.Done()
231 | }(i)
232 | wg.Wait()
233 | }
234 | for i := 0; i < 4; i++ {
235 | go func() {
236 | _, err := localAuth.GetClaims(token)
237 | if err != nil {
238 | t.Errorf("get custom claims fail:%v", err)
239 | }
240 | }()
241 | }
242 | time.Sleep(3 * time.Second)
243 | }
244 |
// TestLocalGetUserTokens issues one token each for loginTypeApp's user and
// for a second user (id 121321). It checks that the two tokens differ and
// that loginTypeApp's token list holds exactly one entry. That count relies
// on the previous test's deferred CleanCache having emptied the list.
func TestLocalGetUserTokens(t *testing.T) {
	loginTypeWeb := NewClaims(
		&Agent{
			Id:         uint(121321),
			Username:   "username",
			SuperAdmin: true,
			AuthIds:    []string{"999"},
			RoleType:   RoleAdmin,
			LoginType:  LoginTypeWeb,
			AuthType:   AuthPwd,
			ExpiresAt:  time.Now().Local().Add(TimeoutWeb).Unix(),
		},
	)

	defer localAuth.CleanCache(loginTypeWeb.roleType(), loginTypeWeb.Id)
	defer localAuth.CleanCache(loginTypeApp.roleType(), loginTypeApp.Id)

	token, _, err := localAuth.Generate(loginTypeApp)
	if err != nil {
		t.Fatalf("get user tokens by claims generate token %v \n", err)
	}

	if token == "" {
		t.Fatal("get user tokens by claims generate token is empty \n")
	}

	token3232, _, err := localAuth.Generate(loginTypeWeb)
	if err != nil {
		t.Fatalf("get user tokens by claims generate token %v \n", err)
	}

	if token3232 == "" {
		t.Fatal("get user tokens by claims generate token is empty \n")
	}

	if token == token3232 {
		t.Fatal("get user tokens by claims generate token is same")
	}

	// Note: this local variable shadows the package-level tokens type.
	tokens, err := localAuth.getUserTokens(loginTypeApp.roleType(), loginTypeApp.Id)
	if err != nil {
		t.Fatalf("get user tokens by claims %v", err)
	}

	if len(tokens) != 1 {
		t.Fatalf("get user tokens by claims want len 1 but get %d", len(tokens))
	}
}
293 |
// TestLocalGetTokenByClaims issues tokens for loginTypeApp's user and for a
// second user (id 3232). It checks that Token(loginTypeApp) finds the first
// user's token and that the two tokens differ.
func TestLocalGetTokenByClaims(t *testing.T) {
	// Note: this local variable shadows the LoginTypeWeb constant.
	LoginTypeWeb := NewClaims(
		&Agent{
			Id:         uint(3232),
			Username:   "username",
			SuperAdmin: true,
			AuthIds:    []string{"999"},
			RoleType:   RoleAdmin,
			LoginType:  LoginTypeWeb,
			AuthType:   AuthPwd,
			ExpiresAt:  time.Now().Local().Add(TimeoutWeb).Unix(),
		},
	)
	defer localAuth.CleanCache(LoginTypeWeb.roleType(), LoginTypeWeb.Id)
	defer localAuth.CleanCache(loginTypeApp.roleType(), loginTypeApp.Id)

	token, _, err := localAuth.Generate(loginTypeApp)
	if err != nil {
		t.Fatalf("get token by claims generate token %v \n", err)
	}

	if token == "" {
		t.Fatal("get token by claims generate token is empty \n")
	}

	token3232, _, err := localAuth.Generate(LoginTypeWeb)
	if err != nil {
		t.Fatalf("get token by claims generate token %v \n", err)
	}

	if token3232 == "" {
		t.Fatal("get token by claims generate token is empty \n")
	}

	userToken, err := localAuth.Token(loginTypeApp)
	if err != nil {
		t.Fatalf("get token by claims %v", err)
	}

	if token != userToken {
		t.Errorf("get token by claims token want %s but get '%s'", token, userToken)
	}
	if token == token3232 {
		t.Errorf("get token by claims token not want %s but get '%s'", token3232, token)
	}

}
// TestLocalGetMultiClaimses issues one token per login type from
// LoginTypeWeb through LoginTypeWx (three in total) for loginTypeApp's user.
// Both the user's token list and the resolved claims map must have that
// many entries, which assumes the list was empty when the test started.
func TestLocalGetMultiClaimses(t *testing.T) {
	defer localAuth.CleanCache(loginTypeApp.roleType(), loginTypeApp.Id)
	tokenLen := 0
	for i := LoginTypeWeb; i <= LoginTypeWx; i++ {
		wg.Add(1)
		tokenLen++
		loginTypeApp.setLoginType(int(i))
		go func(i LoginType) {
			localAuth.Generate(loginTypeApp)
			wg.Done()
		}(i)
		// Waiting inside the loop serialises the Generate calls.
		wg.Wait()
	}
	userTokens, err := localAuth.getUserTokens(loginTypeApp.roleType(), loginTypeApp.Id)
	if err != nil {
		t.Fatal("get custom claimses generate token is empty \n")
	}
	clas, err := localAuth.getMultiClaimses(userTokens)
	if err != nil {
		t.Fatalf("get custom claimses %v", err)
	}

	if len(userTokens) != tokenLen {
		t.Fatalf("get custom claimses want len %d but get %d", tokenLen, len(userTokens))
	}
	if len(clas) != tokenLen {
		t.Fatalf("get custom claimses want len %d but get %d", tokenLen, len(clas))
	}

}
371 |
--------------------------------------------------------------------------------
/auth2/redis.go:
--------------------------------------------------------------------------------
1 | package auth2
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "log"
7 |
8 | "github.com/go-redis/redis/v8"
9 | "github.com/pkg/errors"
10 | )
11 |
// RedisAuth is an Authentication backed by redis. Each token's claims are
// stored as a hash under TokenPrefix+token.
type RedisAuth struct {
	Client redis.UniversalClient
}
16 |
17 | // NewRedis
18 | func NewRedis(client redis.UniversalClient) (*RedisAuth, error) {
19 | if client == nil {
20 | return nil, errors.New("redis client is nil")
21 | }
22 | _, err := client.Ping(context.Background()).Result()
23 | if err != nil {
24 | return nil, err
25 | }
26 | return &RedisAuth{
27 | Client: client,
28 | }, nil
29 | }
30 |
31 | // Generate
32 | func (ra *RedisAuth) Generate(claims *Claims) (string, int64, error) {
33 | token, err := ra.Token(claims)
34 | if err != nil {
35 | return "", int64(claims.ExpiresAt), err
36 | }
37 |
38 | if token == "" {
39 | if isOver, err := ra.isUserTokenOver(claims.roleType(), claims.Id); err != nil {
40 | return "", int64(claims.ExpiresAt), err
41 | } else if isOver {
42 | return "", int64(claims.ExpiresAt), ErrOverLimit
43 | }
44 |
45 | token, err = getToken()
46 | if err != nil {
47 | return "", int64(claims.ExpiresAt), err
48 | }
49 | }
50 |
51 | if err = ra.toCache(token, claims); err != nil {
52 | return "", int64(claims.ExpiresAt), err
53 | }
54 |
55 | if err = ra.syncUserTokenCache(token); err != nil {
56 | return "", int64(claims.ExpiresAt), err
57 | }
58 |
59 | return token, int64(claims.ExpiresAt), nil
60 | }
61 |
62 | // toCache
63 | func (ra *RedisAuth) toCache(token string, cla *Claims) error {
64 | sKey := TokenPrefix + token
65 | if _, err := ra.Client.HMSet(context.Background(), sKey,
66 | "id", cla.Id,
67 | "super_admin", cla.SuperAdmin,
68 | "login_type", cla.LoginType,
69 | "auth_type", cla.AuthType,
70 | "username", cla.Username,
71 | "auth_id", cla.AuthId,
72 | "role_type", cla.RoleType,
73 | "creation_data", cla.CreationTime,
74 | "expires_at", cla.ExpiresAt,
75 | ).Result(); err != nil {
76 | return fmt.Errorf("to cache token %w", err)
77 | }
78 | err := ra.setExpire(sKey, cla.loginType())
79 | if err != nil {
80 | return err
81 | }
82 |
83 | return nil
84 | }
85 |
86 | // Token
87 | func (ra *RedisAuth) Token(cla *Claims) (string, error) {
88 | userTokens, err := ra.getUserTokens(cla.roleType(), cla.Id)
89 | if err != nil {
90 | return "", err
91 | }
92 | clas, err := ra.getMultiClaimses(userTokens)
93 | if err != nil {
94 | return "", err
95 | }
96 | for token, existCla := range clas {
97 | if cla.AuthType == existCla.AuthType && cla.Id == existCla.Id && cla.RoleType == existCla.RoleType &&
98 | cla.AuthId == existCla.AuthId && cla.LoginType == existCla.LoginType {
99 | return token, nil
100 | }
101 | }
102 | return "", nil
103 | }
104 |
105 | // getMultiClaimses
106 | func (ra *RedisAuth) getMultiClaimses(tokens []string) (map[string]*Claims, error) {
107 | clas := make(map[string]*Claims, ra.getUserTokenLimit())
108 | for _, token := range tokens {
109 | cla, err := ra.GetClaims(token)
110 | if err != nil {
111 | continue
112 | }
113 | clas[token] = cla
114 | }
115 |
116 | return clas, nil
117 | }
118 |
119 | // GetClaims
120 | func (ra *RedisAuth) GetClaims(token string) (*Claims, error) {
121 | cla := new(Claims)
122 | if err := ra.Client.HGetAll(context.Background(), TokenPrefix+token).Scan(cla); err != nil {
123 | return nil, fmt.Errorf("get custom claims redis hgetall %w", err)
124 | }
125 |
126 | if cla.Id == "" {
127 | return nil, ErrEmptyToken
128 | }
129 |
130 | return cla, nil
131 | }
132 |
133 | // isUserTokenOver
134 | func (ra *RedisAuth) isUserTokenOver(roleType RoleType, userId string) (bool, error) {
135 | max, err := ra.getUserTokenCount(roleType, userId)
136 | if err != nil {
137 | return true, err
138 | }
139 | return max >= ra.getUserTokenLimit(), nil
140 | }
141 |
142 | // getUserTokens
143 | func (ra *RedisAuth) getUserTokens(roleType RoleType, userId string) ([]string, error) {
144 | userTokens, err := ra.Client.SMembers(context.Background(), getPrefixKey(roleType, userId)).Result()
145 | if err != nil {
146 | return nil, fmt.Errorf("get user token count menbers %w", err)
147 | }
148 | return userTokens, nil
149 | }
150 |
151 | // getUserTokenCount
152 | func (ra *RedisAuth) getUserTokenCount(roleType RoleType, userId string) (int64, error) {
153 | var count int64
154 | userTokens, err := ra.getUserTokens(roleType, userId)
155 | if err != nil {
156 | return count, fmt.Errorf("get user token count menbers %w", err)
157 | }
158 | userPrefixKey := getPrefixKey(roleType, userId)
159 | for _, token := range userTokens {
160 | if ra.checkUserTokenCount(token, userPrefixKey) == 1 {
161 | count++
162 | }
163 | }
164 | return count, nil
165 | }
166 |
167 | // checkUserTokenCount
168 | func (ra *RedisAuth) checkUserTokenCount(token, userPrefixKey string) int64 {
169 | mun, err := ra.Client.Exists(context.Background(), TokenPrefix+token).Result()
170 | if err != nil || mun == 0 {
171 | ra.Client.SRem(context.Background(), userPrefixKey, token)
172 | }
173 | return mun
174 | }
175 |
176 | // getUserTokenLimit
177 | func (ra *RedisAuth) getUserTokenLimit() int64 {
178 | count, err := ra.Client.Get(context.Background(), LimitTokenPrefix).Int64()
179 | if err != nil {
180 | return LimitTokenDefault
181 | }
182 | return count
183 | }
184 |
185 | // SetLimit
186 | func (ra *RedisAuth) SetLimit(limit int64) error {
187 | err := ra.Client.Set(context.Background(), LimitTokenPrefix, limit, 0).Err()
188 | if err != nil {
189 | return err
190 | }
191 | return nil
192 | }
193 |
194 | // syncUserTokenCache
195 | func (ra *RedisAuth) syncUserTokenCache(token string) error {
196 | cla, err := ra.GetClaims(token)
197 | if err != nil {
198 | return fmt.Errorf("sysnc user token cache %w", err)
199 | }
200 | userPrefixKey := getPrefixKey(cla.roleType(), cla.Id)
201 | if _, err := ra.Client.SAdd(context.Background(), userPrefixKey, token).Result(); err != nil {
202 | return fmt.Errorf("sync user token cache redis sadd %w", err)
203 | }
204 |
205 | bindUserPrefixKey := BindUserPrefix + token
206 | _, err = ra.Client.Set(context.Background(), bindUserPrefixKey, userPrefixKey, getExpire(cla.loginType())).Result()
207 | if err != nil {
208 | return fmt.Errorf("sync user token cache %w", err)
209 | }
210 | return nil
211 | }
212 |
213 | // UpdateCacheExpire
214 | func (ra *RedisAuth) UpdateCacheExpire(token string) error {
215 | rcc, err := ra.GetClaims(token)
216 | if err != nil {
217 | return fmt.Errorf("update user token cache expire %w", err)
218 | }
219 | if rcc == nil {
220 | return errors.New("token cache is nil")
221 | }
222 | if err = ra.setExpire(TokenPrefix+token, rcc.loginType()); err != nil {
223 | return fmt.Errorf("update user token cache expire redis expire %w", err)
224 | }
225 | if err = ra.setExpire(BindUserPrefix+token, rcc.loginType()); err != nil {
226 | return fmt.Errorf("update user token cache expire redis expire %w", err)
227 | }
228 | return nil
229 | }
230 |
231 | func (ra *RedisAuth) setExpire(key string, loginType LoginType) error {
232 | if _, err := ra.Client.Expire(context.Background(), key, getExpire(loginType)).Result(); err != nil {
233 | return fmt.Errorf("update user token cache expire redis expire %w", err)
234 | }
235 | return nil
236 | }
237 |
238 | // DelCache
239 | func (ra *RedisAuth) DelCache(token string) error {
240 | log.Println("auth2: redis del user token")
241 | cla, err := ra.GetClaims(token)
242 | if err != nil {
243 | return err
244 | }
245 | if cla == nil {
246 | return errors.New("del user token, reids cache is nil")
247 | }
248 |
249 | if e := ra.delUserTokenPrefixToken(cla.roleType(), cla.Id, token); e != nil {
250 | return e
251 | }
252 |
253 | if e := ra.delTokenCache(token); e != nil {
254 | return e
255 | }
256 | return nil
257 | }
258 |
259 | // delUserTokenPrefixToken
260 | func (ra *RedisAuth) delUserTokenPrefixToken(roleType RoleType, id, token string) error {
261 | _, err := ra.Client.SRem(context.Background(), getPrefixKey(roleType, id), token).Result()
262 | if err != nil {
263 | return fmt.Errorf("del user token cache redis srem %w", err)
264 | }
265 | return nil
266 | }
267 |
268 | // delTokenCache
269 | func (ra *RedisAuth) delTokenCache(token string) error {
270 | sKey2 := BindUserPrefix + token
271 | _, err := ra.Client.Del(context.Background(), sKey2).Result()
272 | if err != nil {
273 | return fmt.Errorf("del user token cache redis del2 %w", err)
274 | }
275 |
276 | sKey3 := TokenPrefix + token
277 | _, err = ra.Client.Del(context.Background(), sKey3).Result()
278 | if err != nil {
279 | return fmt.Errorf("del user token cache redis del3 %w", err)
280 | }
281 |
282 | return nil
283 | }
284 |
285 | // CleanCache
286 | func (ra *RedisAuth) CleanCache(roleType RoleType, userId string) error {
287 | allTokens, err := ra.getUserTokens(roleType, userId)
288 | if err != nil {
289 | return fmt.Errorf("clean user token cache redis smembers %w", err)
290 | }
291 | _, err = ra.Client.Del(context.Background(), getPrefixKey(roleType, userId)).Result()
292 | if err != nil {
293 | return fmt.Errorf("clean user token cache redis del %w", err)
294 | }
295 |
296 | for _, token := range allTokens {
297 | err = ra.delTokenCache(token)
298 | if err != nil {
299 | return err
300 | }
301 | }
302 | return nil
303 | }
304 |
305 | // IsRole
306 | func (ra *RedisAuth) IsRole(token string, roleType RoleType) (bool, error) {
307 | rcc, err := ra.GetClaims(token)
308 | if err != nil {
309 | return false, fmt.Errorf("get User's infomation return error: %w", err)
310 | }
311 | return rcc.roleType() == roleType, nil
312 | }
313 |
314 | // IsSuperAdmin
315 | func (ra *RedisAuth) IsSuperAdmin(token string) bool {
316 | rcc, err := ra.GetClaims(token)
317 | if err != nil {
318 | return false
319 | }
320 | return rcc.SuperAdmin
321 | }
322 |
323 | // Close
324 | func (ra *RedisAuth) Close() {
325 | ra.Client.Close()
326 | }
327 |
--------------------------------------------------------------------------------
/auth2/redis_test.go:
--------------------------------------------------------------------------------
1 | package auth2
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "os"
7 | "sync"
8 | "testing"
9 | "time"
10 |
11 | "github.com/go-redis/redis/v8"
12 | )
13 |
14 | var (
15 | wg sync.WaitGroup
16 | options = &redis.UniversalOptions{
17 | DB: 1,
18 | Addrs: []string{"127.0.0.1:6379"},
19 | Password: os.Getenv("redisPwd"), //
20 | PoolSize: 10,
21 | IdleTimeout: 300 * time.Second,
22 | // Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
23 | // conn, err := net.Dial(network, addr)
24 | // if err == nil {
25 | // go func() {
26 | // time.Sleep(5 * time.Second)
27 | // conn.Close()
28 | // }()
29 | // }
30 | // return conn, err
31 | // },
32 | }
33 |
34 | rToken = "TVRReU1EVTFOek13TmpFd09UWXlPRFF4TmcuTWpBeU1TMHdOeTB5T1ZRd09Ub3pNRG95T1Nzd09Eb3dNQQ.MTQyMDU1NzMwNjEwOTYyODrtrt"
35 | logTypeWeb = NewClaims(
36 | &Agent{
37 | Id: uint(121321),
38 | Username: "username",
39 | SuperAdmin: true,
40 | AuthIds: []string{"999"},
41 | RoleType: RoleAdmin,
42 | LoginType: LoginTypeWeb,
43 | AuthType: AuthPwd,
44 | ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(),
45 | },
46 | )
47 | ruserKey = getPrefixKey(logTypeWeb.roleType(), logTypeWeb.Id)
48 | )
49 |
50 | func TestRedisGenerateToken(t *testing.T) {
51 | redisAuth, err := NewRedis(redis.NewUniversalClient(options))
52 | if err != nil {
53 | t.Fatal(err.Error())
54 | }
55 | defer redisAuth.CleanCache(logTypeWeb.roleType(), logTypeWeb.Id)
56 | token, expiresIn, err := redisAuth.Generate(logTypeWeb)
57 | if err != nil {
58 | t.Fatalf("generate token %v", err)
59 | }
60 | if token == "" {
61 | t.Error("generate token is empty")
62 | }
63 |
64 | if expiresIn != logTypeWeb.ExpiresAt {
65 | t.Errorf("generate token expires want %v but get %v", logTypeWeb.ExpiresAt, expiresIn)
66 | }
67 | cc, err := redisAuth.GetClaims(token)
68 | if err != nil {
69 | t.Fatalf("get custom claims %v", err)
70 | }
71 |
72 | if cc.Id != logTypeWeb.Id {
73 | t.Errorf("get custom id want %v but get %v", logTypeWeb.Id, cc.Id)
74 | }
75 | if cc.Username != logTypeWeb.Username {
76 | t.Errorf("get custom username want %v but get %v", logTypeWeb.Username, cc.Username)
77 | }
78 | if cc.AuthId != logTypeWeb.AuthId {
79 | t.Errorf("get custom authority_id want %v but get %v", logTypeWeb.AuthId, cc.AuthId)
80 | }
81 | if cc.RoleType != logTypeWeb.RoleType {
82 | t.Errorf("get custom authority_type want %v but get %v", logTypeWeb.RoleType, cc.RoleType)
83 | }
84 | if cc.LoginType != logTypeWeb.LoginType {
85 | t.Errorf("get custom login_type want %v but get %v", logTypeWeb.LoginType, cc.LoginType)
86 | }
87 | if cc.AuthType != logTypeWeb.AuthType {
88 | t.Errorf("get custom auth_type want %v but get %v", logTypeWeb.AuthType, cc.AuthType)
89 | }
90 | if cc.CreationTime != logTypeWeb.CreationTime {
91 | t.Errorf("get custom creation_data want %v but get %v", logTypeWeb.CreationTime, cc.CreationTime)
92 | }
93 | if cc.ExpiresAt != logTypeWeb.ExpiresAt {
94 | t.Errorf("get custom expires_at want %v but get %v", logTypeWeb.ExpiresAt, cc.ExpiresAt)
95 | }
96 |
97 | if uTokens, err := redisAuth.Client.SMembers(context.Background(), ruserKey).Result(); err != nil {
98 | t.Fatalf("user prefix value get %s", err)
99 | } else {
100 | if len(uTokens) == 0 || uTokens[0] != token {
101 | t.Errorf("user prefix value want %v but get %v", ruserKey, uTokens)
102 | }
103 | }
104 | bindKey := BindUserPrefix + token
105 | key, err := redisAuth.Client.Get(context.Background(), bindKey).Result()
106 | if err != nil {
107 | t.Fatal(err)
108 | }
109 | if key != ruserKey {
110 | t.Errorf("bind user prefix value want %v but get %v", ruserKey, key)
111 | }
112 | }
113 |
114 | func TestRedisToCache(t *testing.T) {
115 | redisAuth, err := NewRedis(redis.NewUniversalClient(options))
116 | if err != nil {
117 | t.Fatal(err.Error())
118 | }
119 | defer redisAuth.Client.Del(context.Background(), TokenPrefix+rToken)
120 | if err := redisAuth.toCache(rToken, logTypeWeb); err != nil {
121 | t.Fatalf("generate token %v", err)
122 | }
123 | cc, err := redisAuth.GetClaims(rToken)
124 | if err != nil {
125 | t.Fatalf("get custom claims %v", err)
126 | }
127 |
128 | if cc.Id != logTypeWeb.Id {
129 | t.Errorf("get custom id want %v but get %v", logTypeWeb.Id, cc.Id)
130 | }
131 | if cc.Username != logTypeWeb.Username {
132 | t.Errorf("get custom username want %v but get %v", logTypeWeb.Username, cc.Username)
133 | }
134 | if cc.AuthId != logTypeWeb.AuthId {
135 | t.Errorf("get custom authority_id want %v but get %v", logTypeWeb.AuthId, cc.AuthId)
136 | }
137 | if cc.RoleType != logTypeWeb.RoleType {
138 | t.Errorf("get custom authority_type want %v but get %v", logTypeWeb.RoleType, cc.RoleType)
139 | }
140 | if cc.LoginType != logTypeWeb.LoginType {
141 | t.Errorf("get custom login_type want %v but get %v", logTypeWeb.LoginType, cc.LoginType)
142 | }
143 | if cc.AuthType != logTypeWeb.AuthType {
144 | t.Errorf("get custom auth_type want %v but get %v", logTypeWeb.AuthType, cc.AuthType)
145 | }
146 | if cc.CreationTime != logTypeWeb.CreationTime {
147 | t.Errorf("get custom creation_data want %v but get %v", logTypeWeb.CreationTime, cc.CreationTime)
148 | }
149 | if cc.ExpiresAt != logTypeWeb.ExpiresAt {
150 | t.Errorf("get custom expires_at want %v but get %v", logTypeWeb.ExpiresAt, cc.ExpiresAt)
151 | }
152 | }
153 |
154 | func TestRedisDelUserTokenCache(t *testing.T) {
155 | cc := NewClaims(
156 | &Agent{
157 | Id: uint(221),
158 | Username: "username",
159 | SuperAdmin: true,
160 | AuthIds: []string{"999"},
161 | RoleType: RoleAdmin,
162 | LoginType: LoginTypeWeb,
163 | AuthType: AuthPwd,
164 | ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(),
165 | },
166 | )
167 | redisAuth, err := NewRedis(redis.NewUniversalClient(options))
168 | if err != nil {
169 | t.Fatal(err.Error())
170 | }
171 | defer redisAuth.CleanCache(cc.roleType(), cc.Id)
172 | token, _, _ := redisAuth.Generate(cc)
173 | if token == "" {
174 | t.Error("generate token is empty")
175 | }
176 |
177 | if err := redisAuth.DelCache(token); err != nil {
178 | t.Fatalf("del user token cache %v", err)
179 | }
180 | _, err = redisAuth.GetClaims(token)
181 | if !errors.Is(err, ErrEmptyToken) {
182 | t.Fatalf("get custom claims err want '%v' but get '%v'", ErrEmptyToken, err)
183 | }
184 |
185 | if uTokens, err := redisAuth.Client.SMembers(context.Background(), UserPrefix+cc.Id).Result(); err != nil {
186 | t.Fatalf("user prefix value wantget %v", err)
187 | } else if len(uTokens) != 0 {
188 | t.Errorf("user prefix value want empty but get %+v", uTokens)
189 | }
190 | bindKey := BindUserPrefix + token
191 | key, _ := redisAuth.Client.Get(context.Background(), bindKey).Result()
192 | if key != "" {
193 | t.Errorf("bind user prefix value want empty but get %v", key)
194 | }
195 | }
196 |
197 | func TestRedisIsUserTokenOver(t *testing.T) {
198 | cc := NewClaims(
199 | &Agent{
200 | Id: uint(3232),
201 | Username: "username",
202 | SuperAdmin: true,
203 | AuthIds: []string{"999"},
204 | RoleType: RoleAdmin,
205 | LoginType: LoginTypeWeb,
206 | AuthType: AuthPwd,
207 | ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(),
208 | },
209 | )
210 | redisAuth, err := NewRedis(redis.NewUniversalClient(options))
211 | if err != nil {
212 | t.Fatal(err.Error())
213 | }
214 | defer redisAuth.CleanCache(cc.roleType(), cc.Id)
215 | if err := redisAuth.SetLimit(10); err != nil {
216 | t.Fatalf("set user token max count %v", err)
217 | }
218 | var wantTokenLen int64 = 0
219 | for i := LoginTypeWeb; i <= LoginTypeWx; i++ {
220 | cc.setLoginType(int(i))
221 | wg.Add(1)
222 | wantTokenLen++
223 | go func(i LoginType) {
224 | redisAuth.Generate(cc)
225 | wg.Done()
226 | }(i)
227 | wg.Wait()
228 | }
229 | isOver, err := redisAuth.isUserTokenOver(cc.roleType(), cc.Id)
230 | if err != nil {
231 | t.Fatalf("is user token over get %v", err)
232 | }
233 | if isOver {
234 | t.Error("user token want not over but get over")
235 | }
236 | count, err := redisAuth.getUserTokenCount(cc.roleType(), cc.Id)
237 | if err != nil {
238 | t.Fatalf("user token count get %v", err)
239 | }
240 | if count != wantTokenLen {
241 | t.Errorf("user token count want %v but get %v", wantTokenLen, count)
242 | }
243 | }
244 |
245 | func TestRedisSetUserTokenMaxCount(t *testing.T) {
246 | redisAuth, err := NewRedis(redis.NewUniversalClient(options))
247 | if err != nil {
248 | t.Fatal(err.Error())
249 | }
250 | defer redisAuth.CleanCache(logTypeWeb.roleType(), logTypeWeb.Id)
251 | if err := redisAuth.SetLimit(10); err != nil {
252 | t.Fatalf("set user token max count %v", err)
253 | }
254 | for i := LoginTypeWeb; i <= LoginTypeWx; i++ {
255 | wg.Add(1)
256 | logTypeWeb.setLoginType(int(i))
257 | go func(i LoginType) {
258 | redisAuth.Generate(logTypeWeb)
259 | wg.Done()
260 | }(i)
261 | wg.Wait()
262 | }
263 | if err := redisAuth.SetLimit(3); err != nil {
264 | t.Fatalf("set user token max count %v", err)
265 | }
266 | count := redisAuth.getUserTokenLimit()
267 | if count != 3 {
268 | t.Errorf("user token max count want %v but get %v", 3, count)
269 | }
270 | isOver, err := redisAuth.isUserTokenOver(logTypeWeb.roleType(), logTypeWeb.Id)
271 | if err != nil {
272 | t.Fatalf("is user token over get %v", err)
273 | }
274 | if !isOver {
275 | t.Error("user token want over but get not over")
276 | }
277 | }
278 | func TestRedisCleanUserTokenCache(t *testing.T) {
279 | redisAuth, err := NewRedis(redis.NewUniversalClient(options))
280 | if err != nil {
281 | t.Fatal(err.Error())
282 | }
283 | defer redisAuth.CleanCache(logTypeWeb.roleType(), logTypeWeb.Id)
284 | for i := LoginTypeWeb; i <= LoginTypeWx; i++ {
285 | wg.Add(1)
286 | logTypeWeb.setLoginType(int(i))
287 | go func(i LoginType) {
288 | redisAuth.Generate(logTypeWeb)
289 | wg.Done()
290 | }(i)
291 | wg.Wait()
292 | }
293 | if err := redisAuth.CleanCache(logTypeWeb.roleType(), logTypeWeb.Id); err != nil {
294 | t.Fatalf("clear user token cache %v", err)
295 | }
296 | count, err := redisAuth.getUserTokenCount(logTypeWeb.roleType(), logTypeWeb.Id)
297 | if err != nil {
298 | t.Fatalf("user token count get %v", err)
299 | }
300 | if count != 0 {
301 | t.Error("user token count want 0 but get not 0")
302 | }
303 | }
304 |
305 | func TestRedisGetMultiClaims(t *testing.T) {
306 | redisAuth, err := NewRedis(redis.NewUniversalClient(options))
307 | if err != nil {
308 | t.Fatal(err.Error())
309 | }
310 | defer redisAuth.CleanCache(logTypeWeb.roleType(), logTypeWeb.Id)
311 | logTypeWeb.LoginType = 3
312 | token, _, err := redisAuth.Generate(logTypeWeb)
313 | if err != nil {
314 | t.Fatalf("get custom claims %v", err)
315 | }
316 | for i := LoginTypeWeb; i <= LoginTypeWx; i++ {
317 | wg.Add(1)
318 | logTypeWeb.setLoginType(int(i))
319 | go func(i LoginType) {
320 | redisAuth.Generate(logTypeWeb)
321 | wg.Done()
322 | }(i)
323 | wg.Wait()
324 | }
325 | for i := 0; i < 4; i++ {
326 | go func() {
327 | _, err := redisAuth.GetClaims(token)
328 | if err != nil {
329 | t.Errorf("get custom claims %v", err)
330 | }
331 | }()
332 | }
333 | time.Sleep(3 * time.Second)
334 | }
335 |
336 | func TestRedisGetUserTokens(t *testing.T) {
337 | cc := NewClaims(
338 | &Agent{
339 | Id: uint(121321),
340 | Username: "username",
341 | SuperAdmin: true,
342 | AuthIds: []string{"999"},
343 | RoleType: RoleAdmin,
344 | LoginType: LoginTypeWeb,
345 | AuthType: AuthPwd,
346 | ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(),
347 | },
348 | )
349 | redisAuth, err := NewRedis(redis.NewUniversalClient(options))
350 | if err != nil {
351 | t.Fatal(err.Error())
352 | }
353 | defer redisAuth.CleanCache(cc.roleType(), cc.Id)
354 | defer redisAuth.CleanCache(logTypeWeb.roleType(), logTypeWeb.Id)
355 | token, _, err := redisAuth.Generate(logTypeWeb)
356 | if err != nil {
357 | t.Fatalf("get user tokens by claims generate token %v \n", err)
358 | }
359 |
360 | if token == "" {
361 | t.Fatal("get user tokens by claims generate token is empty \n")
362 | }
363 |
364 | token3232, _, err := redisAuth.Generate(cc)
365 | if err != nil {
366 | t.Fatalf("get user tokens by claims generate token %v \n", err)
367 | }
368 |
369 | if token3232 == "" {
370 | t.Fatal("get user tokens by claims generate token is empty \n")
371 | }
372 |
373 | tokens, err := redisAuth.getUserTokens(logTypeWeb.roleType(), logTypeWeb.Id)
374 | if err != nil {
375 | t.Fatalf("get user tokens by claims %v", err)
376 | }
377 | wantTokenLen := 2
378 | if len(tokens) != wantTokenLen {
379 | t.Fatalf("get user tokens by claims want len %d but get %d", wantTokenLen, len(tokens))
380 | }
381 | }
382 |
383 | func TestRedisGetTokenByClaims(t *testing.T) {
384 | cc := NewClaims(
385 | &Agent{
386 | Id: uint(3232),
387 | Username: "username",
388 | SuperAdmin: true,
389 | AuthIds: []string{"999"},
390 | RoleType: RoleAdmin,
391 | LoginType: LoginTypeWeb,
392 | AuthType: AuthPwd,
393 | ExpiresAt: time.Now().Local().Add(TimeoutWeb).Unix(),
394 | },
395 | )
396 | redisAuth, err := NewRedis(redis.NewUniversalClient(options))
397 | if err != nil {
398 | t.Fatal(err.Error())
399 | }
400 | defer redisAuth.CleanCache(cc.roleType(), cc.Id)
401 | defer redisAuth.CleanCache(logTypeWeb.roleType(), logTypeWeb.Id)
402 |
403 | token, _, err := redisAuth.Generate(logTypeWeb)
404 | if err != nil {
405 | t.Fatalf("get token by claims generate token %v \n", err)
406 | }
407 |
408 | if token == "" {
409 | t.Fatal("get token by claims generate token is empty \n")
410 | }
411 |
412 | token3232, _, err := redisAuth.Generate(cc)
413 | if err != nil {
414 | t.Fatalf("get token by claims generate token %v \n", err)
415 | }
416 |
417 | if token3232 == "" {
418 | t.Fatal("get token by claims generate token is empty \n")
419 | }
420 |
421 | userToken, err := redisAuth.Token(logTypeWeb)
422 | if err != nil {
423 | t.Fatalf("get token by claims %v", err)
424 | }
425 |
426 | if token != userToken {
427 | t.Errorf("get token by claims token want %s but get %s", token, userToken)
428 | }
429 | if token == token3232 {
430 | t.Errorf("get token by claims token not want %s but get %s", token3232, token)
431 | }
432 |
433 | }
434 | func TestRedisGetMultiClaimses(t *testing.T) {
435 | redisAuth, err := NewRedis(redis.NewUniversalClient(options))
436 | if err != nil {
437 | t.Fatal(err.Error())
438 | }
439 | defer redisAuth.CleanCache(logTypeWeb.roleType(), logTypeWeb.Id)
440 | wantTokenLen := 0
441 | for i := LoginTypeWeb; i <= LoginTypeWx; i++ {
442 | wg.Add(1)
443 | wantTokenLen++
444 | logTypeWeb.setLoginType(int(i))
445 | go func(i LoginType) {
446 | redisAuth.Generate(logTypeWeb)
447 | wg.Done()
448 | }(i)
449 | wg.Wait()
450 | }
451 | userTokens, err := redisAuth.getUserTokens(logTypeWeb.roleType(), logTypeWeb.Id)
452 | if err != nil {
453 | t.Fatal("get custom claimses generate token is empty \n")
454 | }
455 | clas, err := redisAuth.getMultiClaimses(userTokens)
456 | if err != nil {
457 | t.Fatalf("get custom claimses %v", err)
458 | }
459 |
460 | if len(userTokens) != wantTokenLen {
461 | t.Fatalf("get custom claimses want len %d but get %d", wantTokenLen, len(userTokens))
462 | }
463 |
464 | if len(clas) != wantTokenLen {
465 | t.Fatalf("get custom claimses want len %d but get %d", wantTokenLen, len(clas))
466 | }
467 |
468 | }
469 |
--------------------------------------------------------------------------------
/auth2/token.go:
--------------------------------------------------------------------------------
1 | package auth2
2 |
3 | import (
4 | "bytes"
5 | "encoding/base64"
6 | "fmt"
7 |
8 | "github.com/bwmarrin/snowflake"
9 | uuid "github.com/satori/go.uuid"
10 | "github.com/snowlyg/helper/dir"
11 | )
12 |
var (
	// sep joins the parts of a token, JWT-style ("a.b").
	sep = []byte{'.'}
	// pad is the base64 padding character, which tokens omit.
	pad = []byte{'='}
	// padStr is pad as a string, usable as a bytes.TrimRight cutset.
	padStr = string(pad)
)
18 |
19 | // getToken
20 | func getToken() (string, error) {
21 | v4 := uuid.NewV4()
22 | node, err := snowflake.NewNode(1)
23 | if err != nil {
24 | return "", fmt.Errorf("token: get token %w", err)
25 | }
26 |
27 | // 混入两个时间,防止并发token重复
28 | nodeBytes, _ := dir.Md5Byte(Base64Encode(node.Generate().Bytes()))
29 | uuidBytes, _ := dir.Md5Byte(Base64Encode(joinParts(Base64Encode(v4.Bytes()), []byte(nodeBytes))))
30 | token := joinParts(Base64Encode([]byte(uuidBytes)), Base64Encode([]byte(nodeBytes)))
31 | return string(Base64Encode([]byte(token))), nil
32 | }
33 |
34 | // joinParts
35 | func joinParts(parts ...[]byte) []byte {
36 | return bytes.Join(parts, sep)
37 | }
38 |
// Base64Encode returns src in base64url form with the trailing '=' padding
// removed (the JWT segment encoding). Base64Decode reverses it.
func Base64Encode(src []byte) []byte {
	// RawURLEncoding is URLEncoding with no padding, which is exactly
	// URLEncoding output with the trailing '=' characters trimmed.
	out := make([]byte, base64.RawURLEncoding.EncodedLen(len(src)))
	base64.RawURLEncoding.Encode(out, src)
	return out
}
46 |
// Base64Decode decodes src from base64url form. Input without '=' padding,
// as produced by Base64Encode, is accepted, and so is padded input.
//
// Missing padding is restored on a private copy before decoding. The caller's
// slice, including any spare capacity behind it, is never written.
func Base64Decode(src []byte) ([]byte, error) {
	if n := len(src) % 4; n > 0 {
		// JWT-style input has no trailing '='; re-pad to a multiple of 4.
		// Appending to src directly would overwrite the caller's backing
		// array whenever cap(src) > len(src), so build a fresh slice.
		padded := make([]byte, len(src), len(src)+4-n)
		copy(padded, src)
		src = append(padded, bytes.Repeat([]byte{'='}, 4-n)...)
	}

	buf := make([]byte, base64.URLEncoding.DecodedLen(len(src)))
	n, err := base64.URLEncoding.Decode(buf, src)
	return buf[:n], err
}
60 |
--------------------------------------------------------------------------------
/auth2/token_test.go:
--------------------------------------------------------------------------------
1 | package auth2
2 |
3 | import (
4 | "bytes"
5 | "testing"
6 |
7 | "github.com/snowlyg/helper/arr"
8 | )
9 |
10 | func TestGetToken(t *testing.T) {
11 | token, err := getToken()
12 | if err != nil {
13 | t.Error(err)
14 | }
15 | if token == "" {
16 | t.Error("Generate token is fail.")
17 | }
18 | if token1, err := getToken(); err != nil {
19 | t.Error(err)
20 | } else if token == "" {
21 | t.Error("Generate token is fail.")
22 | } else if token == token1 {
23 | t.Errorf("token[%s] token1[%s] is repeat", token, token1)
24 | }
25 | }
26 |
27 | func TestJoinParts(t *testing.T) {
28 | afterJoin := joinParts([]byte("header"), []byte("footer"))
29 | want := []byte("header.footer")
30 | if bytes.Compare(afterJoin, want) > 0 {
31 | t.Errorf("Join parts want %s but get %s", string(want), string(afterJoin))
32 | }
33 | }
34 |
35 | func TestBase64Encode(t *testing.T) {
36 | want := []byte("header")
37 | baseEncode := Base64Encode(want)
38 | afterDecode, err := Base64Decode(baseEncode)
39 | if err != nil {
40 | t.Error(err)
41 | }
42 | if bytes.Compare(afterDecode, want) > 0 {
43 | t.Errorf("Base64Encode and Base64Decode not effect")
44 | }
45 | }
46 |
47 | func BenchmarkGetToken(b *testing.B) {
48 | b.Run("Benchmark test get token", func(b *testing.B) {
49 | tokens := Token{CheckArrayType: *arr.NewCheckArrayType(b.N)}
50 | for i := 0; i < b.N; i++ {
51 | token, err := getToken()
52 | if err != nil {
53 | b.Error(err)
54 | }
55 | if token == "" {
56 | b.Error("Generate token is fail.")
57 | }
58 | if tokens.Check(token) {
59 | b.Fatalf("token is repeat")
60 | }
61 | tokens.Add(token)
62 | }
63 | })
64 | }
65 |
66 | type Token struct {
67 | arr.CheckArrayType
68 | }
69 |
--------------------------------------------------------------------------------
/auth2/verifier.go:
--------------------------------------------------------------------------------
1 | package auth2
2 |
3 | import (
4 | "log"
5 | "net/http"
6 | "strconv"
7 | "strings"
8 |
9 | "github.com/gin-gonic/gin"
10 | )
11 |
// Keys under which Verify stores its results on the gin context.
const (
	// claimsContextKey holds the *Claims resolved for the request token.
	claimsContextKey = "gin.auth2.claims"
	// verifiedTokenContextKey holds the raw token bytes that passed verification.
	verifiedTokenContextKey = "gin.auth2.token"
)
16 |
17 | // Get returns the claims decoded by a verifier.
18 | func Get(ctx *gin.Context) *Claims {
19 | v, b := ctx.Get(claimsContextKey)
20 | if !b {
21 | log.Println("verifier: key not exist")
22 | return nil
23 | }
24 | tok, ok := v.(*Claims)
25 | if !ok {
26 | log.Println("verifier: object not claims")
27 | return nil
28 | }
29 | return tok
30 | }
31 |
32 | // GetType
33 | func GetType(ctx *gin.Context) RoleType {
34 | if v := Get(ctx); v != nil {
35 | return v.roleType()
36 | }
37 | return 0
38 | }
39 |
40 | // GetAuthId
41 | func GetAuthId(ctx *gin.Context) []string {
42 | if v := Get(ctx); v != nil {
43 | return strings.Split(v.AuthId, AuthTypeSplit)
44 | }
45 | return nil
46 | }
47 |
48 | // GetUserId
49 | func GetUserId(ctx *gin.Context) uint {
50 | v := Get(ctx)
51 | if v == nil {
52 | return 0
53 | }
54 | id, err := strconv.Atoi(v.Id)
55 | if err != nil {
56 | return 0
57 | }
58 | return uint(id)
59 | }
60 |
61 | // IsSuperAdmin
62 | func IsSuperAdmin(ctx *gin.Context) bool {
63 | v := Get(ctx)
64 | if v == nil {
65 | log.Println("verifier: Claim is nil")
66 | return false
67 | }
68 | return v.SuperAdmin
69 | }
70 |
71 | // GetUsername
72 | func GetUsername(ctx *gin.Context) string {
73 | if v := Get(ctx); v != nil {
74 | return v.Username
75 | }
76 | return ""
77 | }
78 |
79 | // GetCreationDate
80 | func GetCreationDate(ctx *gin.Context) int64 {
81 | if v := Get(ctx); v != nil {
82 | return v.CreationTime
83 | }
84 | return 0
85 | }
86 |
87 | // GetExpiresIn
88 | func GetExpiresIn(ctx *gin.Context) int64 {
89 | if v := Get(ctx); v != nil {
90 | return v.ExpiresAt
91 | }
92 | return 0
93 | }
94 |
95 | func GetVerifiedToken(ctx *gin.Context) []byte {
96 | v, b := ctx.Get(verifiedTokenContextKey)
97 | if !b {
98 | return nil
99 | }
100 | if tok, ok := v.([]byte); ok {
101 | return tok
102 | }
103 | return nil
104 | }
105 |
106 | func IsRole(ctx *gin.Context, roleType RoleType) bool {
107 | v := GetVerifiedToken(ctx)
108 | if v == nil {
109 | return false
110 | }
111 | b, err := AuthAgent.IsRole(string(v), roleType)
112 | if err != nil {
113 | return false
114 | }
115 | return b
116 | }
117 |
118 | func IsAdmin(ctx *gin.Context) bool {
119 | return IsRole(ctx, RoleAdmin)
120 | }
121 |
122 | type Verifier struct {
123 | Extractors []TokenExtractor
124 | Validators []TokenValidator
125 | ErrorHandler func(ctx *gin.Context, err error)
126 | }
127 |
128 | func NewVerifier(validators ...TokenValidator) *Verifier {
129 | return &Verifier{
130 | Extractors: []TokenExtractor{FromHeader, FromQuery},
131 | ErrorHandler: func(ctx *gin.Context, err error) {
132 | ctx.AbortWithError(http.StatusUnauthorized, err)
133 | },
134 | Validators: validators,
135 | }
136 | }
137 |
138 | // Invalidate
139 | func (v *Verifier) invalidate(ctx *gin.Context) {
140 | if verifiedToken := GetVerifiedToken(ctx); verifiedToken != nil {
141 | ctx.Set(claimsContextKey, "")
142 | ctx.Set(verifiedTokenContextKey, "")
143 | }
144 | }
145 |
146 | // RequestToken extracts the token from the
147 | func (v *Verifier) RequestToken(ctx *gin.Context) (token string) {
148 | for _, extract := range v.Extractors {
149 | if token = extract(ctx); token != "" {
150 | break // ok we found it.
151 | }
152 | }
153 | return
154 | }
155 |
156 | func (v *Verifier) VerifyToken(token []byte, validators ...TokenValidator) ([]byte, *Claims, error) {
157 | if len(token) == 0 {
158 | return nil, nil, ErrEmptyToken
159 | }
160 | var err error
161 | for _, validator := range validators {
162 | // A token validator can skip the builtin validation and return a nil error,
163 | // in that case the previous error is skipped.
164 | if err = validator.Validater(token, err); err != nil {
165 | break
166 | }
167 | }
168 | if err != nil {
169 | // Exit on parsing standard claims error(when Plain is missing) or standard claims validation error or custom validators.
170 | return nil, nil, err
171 | }
172 | rcc, err := AuthAgent.GetClaims(string(token))
173 | if err != nil {
174 | return nil, nil, err
175 | }
176 | err = rcc.Valid()
177 | if err != nil {
178 | return nil, nil, err
179 | }
180 | return token, rcc, nil
181 | }
182 |
183 | func (v *Verifier) Verify(validators ...TokenValidator) gin.HandlerFunc {
184 | return func(ctx *gin.Context) {
185 | token := []byte(v.RequestToken(ctx))
186 | verifiedToken, rcc, err := v.VerifyToken(token, validators...)
187 | if err != nil {
188 | v.invalidate(ctx)
189 | v.ErrorHandler(ctx, err)
190 | return
191 | }
192 | ctx.Set(claimsContextKey, rcc)
193 | ctx.Set(verifiedTokenContextKey, verifiedToken)
194 | ctx.Next()
195 | }
196 | }
197 |
--------------------------------------------------------------------------------
/conf/auth.go:
--------------------------------------------------------------------------------
1 | package conf
2 |
3 | import (
4 | "fmt"
5 | "path/filepath"
6 |
7 | "github.com/casbin/casbin/v2"
8 | gormadapter "github.com/casbin/gorm-adapter/v3"
9 | "github.com/snowlyg/helper/dir"
10 | "gorm.io/gorm"
11 | )
12 |
// CasbinName is the file name of the casbin RBAC model file, which lives in
// ConfigDir (see casbinFilePath).
const CasbinName = "rbac_model.conf"
14 |
15 | // Remove del config file
16 | func (conf *Conf) RemoveRbacModel() error {
17 | p := conf.casbinFilePath()
18 | if filepath.Base(p) != CasbinName {
19 | return nil
20 | }
21 | if dir.IsExist(p) && dir.IsFile(p) {
22 | return dir.Remove(p)
23 | }
24 | return nil
25 | }
26 |
27 | // casbinFilePath
28 | func (conf *Conf) casbinFilePath() string {
29 | return filepath.Join(dir.GetCurrentAbPath(), ConfigDir, CasbinName)
30 | }
31 |
32 | // newRbacModel initialize casbin's config file as rbac_model.conf name
33 | func (conf *Conf) newRbacModel() {
34 | if dir.IsExist(conf.casbinFilePath()) {
35 | // casbin rbac_model.conf file
36 | // log.Printf("rbac_model.conf file is existed.")
37 | return
38 | }
39 |
40 | var rbacModelConf = []byte(`[request_definition]
41 | r = sub, obj, act
42 |
43 | [policy_definition]
44 | p = sub, obj, act
45 |
46 | [role_definition]
47 | g = _, _
48 |
49 | [policy_effect]
50 | e = some(where (p.eft == allow))
51 |
52 | [matchers]
53 | m = g(r.sub, p.sub) && keyMatch2(r.obj, p.obj) && (r.act == p.act || p.act == "*")`)
54 | if _, err := dir.WriteBytes(conf.casbinFilePath(), rbacModelConf); err != nil {
55 | panic(fmt.Errorf("initialize casbin rbac_model.conf file return error: %w ", err))
56 | }
57 | }
58 |
59 | // getEnforcer get casbin.Enforcer
60 | func (conf *Conf) GetEnforcer(db *gorm.DB) (*casbin.Enforcer, error) {
61 | if db == nil {
62 | return nil, gorm.ErrInvalidDB
63 | }
64 | c, err := gormadapter.NewAdapterByDBUseTableName(db, "", "casbin_rule") // Your driver and data source.
65 | if err != nil {
66 | return nil, err
67 | }
68 | enforcer, err := casbin.NewEnforcer(conf.casbinFilePath(), c)
69 | if err != nil {
70 | return nil, err
71 | }
72 | if err = enforcer.LoadPolicy(); err != nil {
73 | return nil, err
74 | }
75 | return enforcer, nil
76 | }
77 |
--------------------------------------------------------------------------------
/conf/auth_test.go:
--------------------------------------------------------------------------------
1 | package conf
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/snowlyg/helper/dir"
7 | )
8 |
9 | func TestNewRbacModel(t *testing.T) {
10 | conf := new(Conf)
11 | if conf.casbinFilePath() == "" {
12 | t.Errorf("rbac model path:%s empty", conf.casbinFilePath())
13 | }
14 | conf.newRbacModel()
15 | if !dir.IsExist(CasbinName) {
16 | t.Error("rbac_model.conf not exist after conf not init and new rbac model")
17 | }
18 | conf.RemoveRbacModel()
19 | if dir.IsExist(CasbinName) {
20 | t.Error("rbac_model.conf exist after conf not init and new rbac model")
21 | }
22 | }
23 |
--------------------------------------------------------------------------------
/conf/config.go:
--------------------------------------------------------------------------------
1 | package conf
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "log"
7 | "os"
8 | "strconv"
9 | "strings"
10 |
11 | "github.com/gin-gonic/gin"
12 | "github.com/spf13/viper"
13 | )
14 |
// Location and format of the iris-admin config file.
const (
	// ConfigType is the config file type, also used as its file extension.
	ConfigType = "json"
	// ConfigDir is the directory holding the config file and the casbin model.
	ConfigDir = "config"
)
19 |
// Environment variables NewConf reads to override its built-in defaults.
var (
	mysqlAddrKey = "IRIS_ADMIN_MYSQL_ADDR" // overrides Mysql.Path
	mysqlPwdKey  = "IRIS_ADMIN_MYSQL_PWD"  // overrides Mysql.Password
	mysqlNameKey = "IRIS_ADMIN_MYSQL_NAME" // NewConf assigns it to Mysql.Username
	webAddrKey   = "IRIS_ADMIN_WEB_ADDR"   // overrides System.Addr
)
26 |
27 | func NewConf() *Conf {
28 | c := &Conf{
29 | Locale: "zh",
30 | FileMaxSize: 1024, // upload file size limit 1024M
31 | SessionTimeout: 172800, // session timeout after 4 months
32 | CorsConf: CorsConf{
33 | AccessOrigin: "*",
34 | AccessHeaders: "Content-Type,AccessToken,X-CSRF-Token, Authorization, Token,X-Token,X-User-Id",
35 | AccessMethods: "POST,GET,OPTIONS,DELETE,PUT",
36 | AccessExposeHeaders: "Content-Length,Access-Control-Allow-Origin,Access-Control-Allow-Headers,Content-Type",
37 | AccessCredentials: "true",
38 | },
39 | Except: Route{
40 | Uri: "",
41 | Method: "",
42 | },
43 | System: System{
44 | Tls: false,
45 | GinMode: gin.ReleaseMode,
46 | Level: "debug",
47 | Addr: "127.0.0.1:8080",
48 | TimeFormat: "2006-01-02 15:04:05",
49 | },
50 | Limit: Limit{
51 | Disable: true,
52 | Limit: 0,
53 | Burst: 5,
54 | },
55 | Captcha: Captcha{
56 | KeyLong: 0,
57 | ImgWidth: 240,
58 | ImgHeight: 80,
59 | },
60 | Mysql: &Mysql{
61 | Path: "127.0.0.1:3306",
62 | Config: "charset=utf8mb4&parseTime=True&loc=Local",
63 | DbName: "iris-admin",
64 | Username: "root",
65 | Password: "",
66 | MaxIdleConns: 0,
67 | MaxOpenConns: 0,
68 | LogMode: false,
69 | LogZap: "error",
70 | },
71 | Operate: Operate{
72 | Except: Route{
73 | Uri: "api/v1/upload;api/v1/upload",
74 | Method: "post;put",
75 | },
76 | Include: Route{
77 | Uri: "api/v1/menus",
78 | Method: "get",
79 | },
80 | },
81 | }
82 | mysqlAddr := strings.TrimSpace(os.Getenv(mysqlAddrKey))
83 | mysqlPwd := strings.TrimSpace(os.Getenv(mysqlPwdKey))
84 | mysqlName := strings.TrimSpace(os.Getenv(mysqlNameKey))
85 | webAddr := strings.TrimSpace(os.Getenv(webAddrKey))
86 | if mysqlAddr != "" {
87 | c.Mysql.Path = mysqlAddr
88 | }
89 | if mysqlPwd != "" {
90 | c.Mysql.Password = mysqlPwd
91 | }
92 | if mysqlName != "" {
93 | c.Mysql.Username = mysqlName
94 | }
95 | if webAddr != "" {
96 | c.System.Addr = webAddr
97 | }
98 | if c.Mysql.Path == "" || c.Mysql.Password == "" || c.Mysql.DbName == "" {
99 | log.Printf("mysql driver config empty,you can set env %s %s %s to change it.\n", mysqlAddrKey, mysqlPwdKey, mysqlNameKey)
100 | }
101 | return c
102 | }
103 |
104 | type Conf struct {
105 | Locale string `mapstructure:"locale" json:"locale" yaml:"locale"`
106 | FileMaxSize int64 `mapstructure:"file-max-size" json:"file-max-size" yaml:"file-max-siz"`
107 | SessionTimeout int64 `mapstructure:"session-timeout" json:"session-timeout" yaml:"session-timeout"`
108 | Except Route `mapstructure:"except" json:"except" yaml:"except"`
109 | System System `mapstructure:"system" json:"system" yaml:"system"`
110 | Limit Limit `mapstructure:"limit" json:"limit" yaml:"limit"`
111 | Captcha Captcha `mapstructure:"captcha" json:"captcha" yaml:"captcha"`
112 | CorsConf CorsConf `mapstructure:"cors" json:"cors" yaml:"cors"`
113 | Mysql *Mysql `mapstructure:"mysql" json:"mysql" yaml:"mysql"`
114 | Operate Operate `mapstructure:"operate" json:"operate" yaml:"operate"`
115 | }
116 |
// Route pairs a URI list with a method list. Operate stores ';'-separated
// lists in both fields.
type Route struct {
	Uri    string `mapstructure:"uri" json:"uri" yaml:"uri"`
	Method string `mapstructure:"method" json:"method" yaml:"method"`
}
121 |
// Captcha holds captcha settings; NewConf defaults to a 240x80 image.
type Captcha struct {
	KeyLong   int `mapstructure:"key-long" json:"key-long" yaml:"key-long"`
	ImgWidth  int `mapstructure:"img-width" json:"img-width" yaml:"img-width"`
	ImgHeight int `mapstructure:"img-height" json:"img-height" yaml:"img-height"`
}
127 |
// Limit holds request rate-limit settings. Limit is a float64, so fractional
// values are representable.
type Limit struct {
	Disable bool    `mapstructure:"disable" json:"disable" yaml:"disable"`
	Limit   float64 `mapstructure:"limit" json:"limit" yaml:"limit"`
	Burst   int     `mapstructure:"burst" json:"burst" yaml:"burst"`
}
133 |
// System holds server-level settings such as the listen address and the gin
// mode.
type System struct {
	GinMode string `mapstructure:"gin-mode" json:"gin-mode" yaml:"gin-mode"`
	Tls     bool   `mapstructure:"tls" json:"tls" yaml:"tls"`
	// NOTE(review): the original comment listed "debug,release,test" for
	// Level, which look like gin modes — confirm which values Level accepts.
	Level      string `mapstructure:"level" json:"level" yaml:"level"`
	Addr       string `mapstructure:"addr" json:"addr" yaml:"addr"`
	DbType     string `mapstructure:"db-type" json:"db-type" yaml:"db-type"`
	TimeFormat string `mapstructure:"time-format" json:"time-format" yaml:"time-format"`
}
142 |
143 | // SetDefaultAddrAndTimeFormat
144 | func (conf *Conf) SetDefaultAddrAndTimeFormat() {
145 | if conf.System.Addr == "" {
146 | conf.System.Addr = "127.0.0.1:8080"
147 | }
148 |
149 | if conf.System.TimeFormat == "" {
150 | conf.System.TimeFormat = "2006-01-02 15:04:05"
151 | }
152 | }
153 |
154 | // // toStaticUrl
155 | // func (conf *Conf) toStaticUrl(uri string) string {
156 | // path := filepath.Join(conf.System.Addr, conf.System.StaticPrefix, uri)
157 | // if conf.System.Tls {
158 | // return filepath.ToSlash(str.Join("https://", path))
159 | // }
160 | // return filepath.ToSlash(str.Join("http://", path))
161 | // }
162 |
163 | // IsExist config file is exist
164 | func (conf *Conf) IsExist() bool {
165 | return conf.getViperConfig().IsExist()
166 | }
167 |
168 | // RemoveFile remove config file
169 | func (conf *Conf) RemoveFile() error {
170 | return conf.getViperConfig().RemoveFile()
171 | }
172 |
173 | // Recover
174 | func (conf *Conf) Recover() error {
175 | conf.newRbacModel()
176 | b, err := json.MarshalIndent(conf, "", "\t")
177 | if err != nil {
178 | return fmt.Errorf("iris-admin recover config faild:%w", err)
179 | }
180 | return conf.getViperConfig().Recover(b)
181 | }
182 |
183 | // getViperConfig get viper config
184 | func (conf *Conf) getViperConfig() *ViperConf {
185 | maxSize := strconv.FormatInt(conf.FileMaxSize, 10)
186 | sessionTimeout := strconv.FormatInt(conf.SessionTimeout, 10)
187 | keyLong := strconv.FormatInt(int64(conf.Captcha.KeyLong), 10)
188 | imgWidth := strconv.FormatInt(int64(conf.Captcha.ImgWidth), 10)
189 | imgHeight := strconv.FormatInt(int64(conf.Captcha.ImgHeight), 10)
190 | limit := strconv.FormatInt(int64(conf.Limit.Limit), 10)
191 | burst := strconv.FormatInt(int64(conf.Limit.Burst), 10)
192 | disable := strconv.FormatBool(conf.Limit.Disable)
193 | tls := strconv.FormatBool(conf.System.Tls)
194 |
195 | mxIdleConns := fmt.Sprintf("%d", conf.Mysql.MaxIdleConns)
196 | mxOpenConns := fmt.Sprintf("%d", conf.Mysql.MaxOpenConns)
197 | logMode := fmt.Sprintf("%t", conf.Mysql.LogMode)
198 |
199 | configName := "iris_admin"
200 | return &ViperConf{
201 | dir: ConfigDir,
202 | name: configName,
203 | t: ConfigType,
204 | watch: func(vi *viper.Viper) error {
205 | if err := vi.Unmarshal(&conf); err != nil {
206 | return fmt.Errorf("get Unarshal error: %v", err)
207 | }
208 | // watch config file change
209 | vi.SetConfigName(configName)
210 | return nil
211 | },
212 | //
213 | Default: []byte(`
214 | {
215 | "locale": "` + conf.Locale + `",
216 | "file-max-size": ` + maxSize + `,
217 | "session-timeout": ` + sessionTimeout + `,
218 | "except":
219 | {
220 | "uri": "` + conf.Except.Uri + `",
221 | "method": "` + conf.Except.Method + `"
222 | },
223 | "cors":
224 | {
225 | "access-origin": "` + conf.CorsConf.AccessOrigin + `",
226 | "access-headers": "` + conf.CorsConf.AccessHeaders + `",
227 | "access-methods": "` + conf.CorsConf.AccessMethods + `",
228 | "access-expose-headers": "` + conf.CorsConf.AccessExposeHeaders + `",
229 | "access-credentials": "` + conf.CorsConf.AccessCredentials + `"
230 | },
231 | "captcha":
232 | {
233 | "key-long": ` + keyLong + `,
234 | "img-width": ` + imgWidth + `,
235 | "img-height": ` + imgHeight + `
236 | },
237 | "limit":
238 | {
239 | "limit": ` + limit + `,
240 | "disable": ` + disable + `,
241 | "burst": ` + burst + `
242 | },
243 | "system":
244 | {
245 | "tls": ` + tls + `,
246 | "level": "` + conf.System.Level + `",
247 | "gin-mode": "` + conf.System.GinMode + `",
248 | "addr": "` + conf.System.Addr + `",
249 | "time-format": "` + conf.System.TimeFormat + `"
250 | },
251 | "mysql":
252 | {
253 | "path": "` + conf.Mysql.Path + `",
254 | "config": "` + conf.Mysql.Config + `",
255 | "db-name": "` + conf.Mysql.DbName + `",
256 | "username": "` + conf.Mysql.Username + `",
257 | "password": "` + conf.Mysql.Password + `",
258 | "max-idle-conns": ` + mxIdleConns + `,
259 | "max-open-conns": ` + mxOpenConns + `,
260 | "log-mode": ` + logMode + `,
261 | "log-zap": "` + conf.Mysql.LogZap + `"
262 | },
263 | "operate":
264 | {
265 | "except":{
266 | "uri": "` + conf.Operate.Except.Uri + `",
267 | "method": "` + conf.Operate.Except.Method + `"
268 | },
269 | "include":
270 | {
271 | "uri": "` + conf.Operate.Include.Uri + `",
272 | "method": "` + conf.Operate.Include.Method + `"
273 | }
274 | }
275 | }`),
276 | }
277 | }
278 |
--------------------------------------------------------------------------------
/conf/config_test.go:
--------------------------------------------------------------------------------
1 | package conf
2 |
3 | import "testing"
4 |
5 | func TestSetDefaultAddrAndTimeFormat(t *testing.T) {
6 | dc := &Conf{}
7 | addr := ""
8 | if dc.System.Addr != addr {
9 | t.Errorf("config system addr want '%s' but get '%s'", addr, dc.System.Addr)
10 | }
11 | timeFormat := ""
12 | if dc.System.TimeFormat != timeFormat {
13 | t.Errorf("config system time format want '%s' but get '%s'", timeFormat, dc.System.TimeFormat)
14 | }
15 | dc.SetDefaultAddrAndTimeFormat()
16 | addr = "127.0.0.1:8080"
17 | if dc.System.Addr != addr {
18 | t.Errorf("config system addr want '%s' but get '%s'", addr, dc.System.Addr)
19 | }
20 | timeFormat = "2006-01-02 15:04:05"
21 | if dc.System.TimeFormat != timeFormat {
22 | t.Errorf("config system time format want '%s' but get '%s'", timeFormat, dc.System.TimeFormat)
23 | }
24 |
25 | c := NewConf()
26 | if c.IsExist() {
27 | t.Error("config exist before init")
28 | }
29 | addr = "127.0.0.1:8080"
30 | if c.System.Addr != addr {
31 | t.Errorf("config system addr want '%s' but get '%s'", addr, c.System.Addr)
32 | }
33 | timeFormat = "2006-01-02 15:04:05"
34 | if c.System.TimeFormat != timeFormat {
35 | t.Errorf("config system time format want '%s' but get '%s'", timeFormat, c.System.TimeFormat)
36 | }
37 |
38 | if err := c.Recover(); err != nil {
39 | t.Error(err.Error())
40 | }
41 | defer func() {
42 | if err := c.getViperConfig().RemoveDir(); err != nil {
43 | t.Error(err.Error())
44 | }
45 | c.RemoveRbacModel()
46 | }()
47 | if !c.IsExist() {
48 | t.Error("config not exist after recover")
49 | }
50 | c.RemoveFile()
51 | if c.IsExist() {
52 | t.Error("config exist after remove")
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
/conf/cros.go:
--------------------------------------------------------------------------------
1 | package conf
2 |
3 | import (
4 | "net/http"
5 |
6 | "github.com/gin-gonic/gin"
7 | )
8 |
// CorsConf holds the values written to the CORS response headers by Cors.
//
// AccessOrigin previously had `json:"burst"` and a malformed tag key.
// Recover (which uses the json tags) therefore wrote the origin under "burst",
// and viper (which uses the mapstructure tags) could not read it back.
type CorsConf struct {
	AccessOrigin        string `mapstructure:"access-origin" json:"access-origin" yaml:"access-origin"`
	AccessHeaders       string `mapstructure:"access-headers" json:"access-headers" yaml:"access-headers"`
	AccessMethods       string `mapstructure:"access-methods" json:"access-methods" yaml:"access-methods"`
	AccessExposeHeaders string `mapstructure:"access-expose-headers" json:"access-expose-headers" yaml:"access-expose-headers"`
	AccessCredentials   string `mapstructure:"access-credentials" json:"access-credentials" yaml:"access-credentials"`
}
16 |
17 | // Cors
18 | func (corsConf *CorsConf) Cors() gin.HandlerFunc {
19 | return func(c *gin.Context) {
20 | method := c.Request.Method
21 | c.Header("Access-Control-Allow-Origin", corsConf.AccessOrigin)
22 | c.Header("Access-Control-Allow-Headers", corsConf.AccessHeaders)
23 | c.Header("Access-Control-Allow-Methods", corsConf.AccessMethods)
24 | c.Header("Access-Control-Expose-Headers", corsConf.AccessExposeHeaders)
25 | c.Header("Access-Control-Allow-Credentials", corsConf.AccessCredentials)
26 | if method == "OPTIONS" {
27 | c.AbortWithStatus(http.StatusNoContent)
28 | }
29 | c.Next()
30 | }
31 | }
32 |
--------------------------------------------------------------------------------
/conf/mysql.go:
--------------------------------------------------------------------------------
1 | package conf
2 |
3 | import "fmt"
4 |
// Mysql holds the MySQL connection settings used to build the DSN (see Dsn
// and BaseDsn) and to size the connection pool.
type Mysql struct {
	Path         string `mapstructure:"path" json:"path" yaml:"path"`                               // host:port
	Config       string `mapstructure:"config" json:"config" yaml:"config"`                         // DSN parameters after '?'
	DbName       string `mapstructure:"db-name" json:"db-name" yaml:"db-name"`                      // database name
	Username     string `mapstructure:"username" json:"username" yaml:"username"`                   // login user
	Password     string `mapstructure:"password" json:"password" yaml:"password"`                   // login password
	MaxIdleConns int    `mapstructure:"max-idle-conns" json:"max-idle-conns" yaml:"max-idle-conns"` // idle pool size
	MaxOpenConns int    `mapstructure:"max-open-conns" json:"max-open-conns" yaml:"max-open-conns"` // open pool size
	LogMode      bool   `mapstructure:"log-mode" json:"log-mode" yaml:"log-mode"`                   // enable SQL logging
	LogZap       string `mapstructure:"log-zap" json:"log-zap" yaml:"log-zap"`                      // silent,error,warn,info,zap
}
16 |
17 | // Dsn return mysql dsn
18 | func (m *Mysql) Dsn() string {
19 | return fmt.Sprintf("%s%s?%s", m.BaseDsn(), m.DbName, m.Config)
20 | }
21 |
22 | // Dsn return
23 | func (m *Mysql) BaseDsn() string {
24 | return fmt.Sprintf("%s:%s@tcp(%s)/", m.Username, m.Password, m.Path)
25 | }
26 |
--------------------------------------------------------------------------------
/conf/mysql_test.go:
--------------------------------------------------------------------------------
1 | package conf
2 |
3 | import "testing"
4 |
5 | func TestMysqlBaseDsn(t *testing.T) {
6 | m := &Mysql{
7 | Path: "127.0.0.1:3306",
8 | Config: "charset=utf8mb4",
9 | DbName: "db_name",
10 | Username: "name",
11 | Password: "pwd",
12 | MaxIdleConns: 0,
13 | MaxOpenConns: 0,
14 | LogMode: false,
15 | LogZap: "",
16 | }
17 | b := m.BaseDsn()
18 | want := "name:pwd@tcp(127.0.0.1:3306)/"
19 | if b != want {
20 | t.Errorf("mysql config base dsn want '%s' but get '%s'", want, b)
21 | }
22 | dsn := m.Dsn()
23 | wantDsn := "name:pwd@tcp(127.0.0.1:3306)/db_name?charset=utf8mb4"
24 | if dsn != wantDsn {
25 | t.Errorf("mysql config base dsn want '%s' but get '%s'", wantDsn, dsn)
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/conf/operate.go:
--------------------------------------------------------------------------------
1 | package conf
2 |
3 | import (
4 | "strings"
5 | )
6 |
7 | // Operate
8 | // Except set which routers don't generate system log, use ';' to separate.
9 | // Include set which routers need to generate system log, use ';' to separate.
10 | type Operate struct {
11 | Except Route `mapstructure:"except" json:"except" yaml:"except"`
12 | Include Route `mapstructure:"include" json:"include" yaml:"include"`
13 | }
14 |
15 | // GetExcept return routers which need to excepted
16 | func (op Operate) GetExcept() ([]string, []string) {
17 | uri := strings.Split(op.Except.Uri, ";")
18 | method := strings.Split(op.Except.Method, ";")
19 | return uri, method
20 | }
21 |
22 | // GetInclude return routers which need to included
23 | func (op Operate) GetInclude() ([]string, []string) {
24 | uri := strings.Split(op.Include.Uri, ";")
25 | method := strings.Split(op.Include.Method, ";")
26 | return uri, method
27 | }
28 |
29 | // IsInclude check whether the current route needs to belong to the included data
30 | func (op Operate) IsInclude(uri, method string) bool {
31 | incUri, incMethod := op.GetInclude()
32 | if len(incUri) != len(incMethod) {
33 | return false
34 | }
35 |
36 | for i := 0; i < len(incUri); i++ {
37 | if uri == incUri[i] && method == incMethod[i] {
38 | return true
39 | }
40 | }
41 | return false
42 | }
43 |
44 | // IsExcept check whether the current route needs to belong to the excepted data
45 | func (op Operate) IsExcept(uri, method string) bool {
46 | excUri, excMethod := op.GetExcept()
47 | if len(excUri) != len(excMethod) {
48 | return false
49 | }
50 |
51 | for i := 0; i < len(excUri); i++ {
52 | if uri == excUri[i] && method == excMethod[i] {
53 | return true
54 | }
55 | }
56 | return false
57 | }
58 |
--------------------------------------------------------------------------------
/conf/viper.go:
--------------------------------------------------------------------------------
1 | package conf
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "os"
7 | "path/filepath"
8 |
9 | "github.com/snowlyg/helper/dir"
10 | "github.com/snowlyg/helper/str"
11 | "github.com/snowlyg/iris-admin/e"
12 | "github.com/spf13/viper"
13 | )
14 |
15 | type ViperConf struct {
16 | dir string
17 | name string
18 | t string
19 | Default []byte
20 | watch func(*viper.Viper) error
21 | }
22 |
23 | // getConfPath
24 | func (vc *ViperConf) getConfPath() string {
25 | if vc == nil {
26 | return ""
27 | }
28 | return filepath.Join(dir.GetCurrentAbPath(), vc.Dir(), str.Join(vc.name, ".", vc.t))
29 | }
30 |
31 | // Dir
32 | func (vc *ViperConf) Dir() string {
33 | if vc.dir == "" {
34 | vc.dir = "config"
35 | return vc.dir
36 | }
37 | return vc.dir
38 | }
39 |
40 | // IsExist
41 | func (vc *ViperConf) IsExist() bool {
42 | if vc == nil {
43 | return false
44 | }
45 | return dir.IsExist(vc.getConfPath())
46 | }
47 |
48 | // RemoveFile remove config file
49 | func (vc *ViperConf) RemoveFile() error {
50 | if vc == nil {
51 | return e.ErrViperConfInvalid
52 | }
53 | d := filepath.Dir(vc.getConfPath())
54 | b := filepath.Base(d)
55 | if b != vc.Dir() {
56 | return nil
57 | }
58 | return dir.Remove(vc.getConfPath())
59 | }
60 |
61 | // RemoveDir remove config dir
62 | func (vc *ViperConf) RemoveDir() error {
63 | if vc == nil {
64 | return e.ErrViperConfInvalid
65 | }
66 | d := filepath.Dir(vc.getConfPath())
67 | b := filepath.Base(d)
68 | if b != vc.Dir() {
69 | return fmt.Errorf("%s viper conf base '%s' want but get '%s'", d, b, vc.Dir())
70 | }
71 | return os.RemoveAll(d)
72 | }
73 |
74 | // Recover
75 | func (vc *ViperConf) Recover(b []byte) error {
76 | if vc == nil {
77 | return e.ErrViperConfInvalid
78 | }
79 | _, err := dir.WriteBytes(vc.getConfPath(), b)
80 | return err
81 | }
82 |
83 | // NewViperConf
84 | func NewViperConf(vc *ViperConf) error {
85 | if vc == nil {
86 | return e.ErrViperConfInvalid
87 | }
88 | if vc.name == "" {
89 | return e.ErrConfigNameEmpty
90 | }
91 | if vc.t == "" {
92 | vc.t = "yaml"
93 | }
94 |
95 | vc.dir = vc.Dir()
96 | filePath := vc.getConfPath()
97 |
98 | vi := viper.New()
99 | vi.SetConfigName(vc.name)
100 | vi.SetConfigType(vc.t)
101 | vi.AddConfigPath(vc.dir)
102 | isExist := dir.IsExist(filePath)
103 | if !isExist {
104 | if vc.Dir() != "./" {
105 | if err := dir.InsureDir(filepath.Dir(filePath)); err != nil {
106 | return fmt.Errorf("create dir %s fail : %v", filePath, err)
107 | }
108 | }
109 | // ReadConfig
110 | if err := vi.ReadConfig(bytes.NewBuffer(vc.Default)); err != nil {
111 | return fmt.Errorf("read default config fail : %w ", err)
112 | }
113 | // WriteConfigAs
114 | if err := vi.WriteConfigAs(filePath); err != nil {
115 | return fmt.Errorf("write config to path fail: %w ", err)
116 | }
117 | } else {
118 | vi.SetConfigFile(filePath)
119 | if err := vi.ReadInConfig(); err != nil {
120 | return fmt.Errorf("read config fail: %w ", err)
121 | }
122 | }
123 | if err := vc.watch(vi); err != nil {
124 | return fmt.Errorf("watch config fail: %w ", err)
125 | }
126 | return nil
127 | }
128 |
--------------------------------------------------------------------------------
/conf/viper_test.go:
--------------------------------------------------------------------------------
1 | package conf
2 |
3 | import (
4 | "encoding/json"
5 | "errors"
6 | "fmt"
7 | "path/filepath"
8 | "testing"
9 |
10 | "github.com/snowlyg/helper/dir"
11 | "github.com/snowlyg/helper/str"
12 | "github.com/snowlyg/iris-admin/e"
13 | "github.com/spf13/viper"
14 | )
15 |
16 | func TestNewViperConfFail(t *testing.T) {
17 | if err := NewViperConf(nil); !errors.Is(err, e.ErrViperConfInvalid) {
18 | t.Errorf("new viper conf with nil return err not confi invalid:%v", err)
19 | }
20 | if err := NewViperConf(&ViperConf{}); !errors.Is(err, e.ErrConfigNameEmpty) {
21 | t.Errorf("new viper conf with nil return err not emtpy name:%v", err)
22 | }
23 | }
24 |
// Zap is a small config struct used by TestViperInit to exercise loading,
// reloading and recovering a config file.
type Zap struct {
	// Level is a log level; the original note listed
	// debug, info, warn, error, panic, fatal.
	Level         int64  `mapstructure:"level" json:"level" yaml:"level"`
	StacktraceKey string `mapstructure:"stacktrace-key" json:"stacktrace-key" yaml:"stacktrace-key"`
	LogInConsole  bool   `mapstructure:"log-in-console" json:"log-in-console" yaml:"log-in-console"`
}
30 |
31 | func TestViperInit(t *testing.T) {
32 | tc := &Zap{}
33 | vi := &ViperConf{
34 | // directory: ConfigDir,
35 | name: "config",
36 | t: ConfigType,
37 | watch: func(vi *viper.Viper) error {
38 | if err := vi.Unmarshal(tc); err != nil {
39 | return fmt.Errorf("get Unarshal error: %v", err)
40 | }
41 | vi.SetConfigName("config")
42 | return nil
43 | },
44 | //
45 | Default: []byte(`{
46 | "level": 0,
47 | "stacktrace-key": "stacktrace",
48 | "log-in-console": true}`),
49 | }
50 | defer func() {
51 | if err := vi.RemoveDir(); err != nil {
52 | t.Error(err.Error())
53 | }
54 | }()
55 |
56 | if vi.IsExist() {
57 | t.Error("config exist")
58 | }
59 |
60 | vi.Dir()
61 | if vi.dir != "config" {
62 | t.Errorf("directory want '%s' but get '%s'", "config", vi.dir)
63 | }
64 |
65 | want := Zap{
66 | Level: 0,
67 | StacktraceKey: "stacktrace",
68 | LogInConsole: true,
69 | }
70 |
71 | if err := NewViperConf(vi); err != nil {
72 | t.Errorf("init %s's config get error: %v", str.Join(vi.name, ".", vi.t), err)
73 | }
74 |
75 | if !vi.IsExist() {
76 | t.Error("config not exist")
77 | }
78 |
79 | if want.Level != tc.Level {
80 | t.Errorf("want %+v but get %+v", want.Level, tc.Level)
81 | }
82 | if want.StacktraceKey != tc.StacktraceKey {
83 | t.Errorf("want %+v but get %+v", want.StacktraceKey, tc.StacktraceKey)
84 | }
85 | if want.LogInConsole != tc.LogInConsole {
86 | t.Errorf("want %+v but get %+v", want.LogInConsole, tc.LogInConsole)
87 | }
88 |
89 | dir.WriteBytes(filepath.Join(vi.getConfPath()), []byte(`{
90 | "level": 2,
91 | "stacktrace-key": "stacktrace1",
92 | "log-in-console": false}`))
93 |
94 | want1 := Zap{
95 | Level: 2,
96 | StacktraceKey: "stacktrace1",
97 | LogInConsole: false,
98 | }
99 |
100 | if err := NewViperConf(vi); err != nil {
101 | t.Errorf("init %s's config get error: %v", str.Join(vi.name, ".", vi.t), err)
102 | }
103 |
104 | if want1.Level != tc.Level {
105 | t.Errorf("want1 %+v but get %+v", want1.Level, tc.Level)
106 | }
107 | if want1.StacktraceKey != tc.StacktraceKey {
108 | t.Errorf("want1 %+v but get %+v", want1.StacktraceKey, tc.StacktraceKey)
109 | }
110 | if want1.LogInConsole != tc.LogInConsole {
111 | t.Errorf("want1 %+v but get %+v", want1.LogInConsole, tc.LogInConsole)
112 | }
113 |
114 | tc.Level = 3
115 | tc.StacktraceKey = "stacktrace3"
116 | tc.LogInConsole = true
117 |
118 | b, err := json.Marshal(&tc)
119 | if err != nil {
120 | t.Error(err.Error())
121 | }
122 |
123 | if err := vi.Recover(b); err != nil {
124 | t.Error(err.Error())
125 | }
126 |
127 | want2 := &Zap{}
128 | if b, err := dir.ReadBytes(vi.getConfPath()); err != nil {
129 | t.Error(err.Error())
130 | } else {
131 | if err := json.Unmarshal(b, want2); err != nil {
132 | t.Error(err.Error())
133 | }
134 | }
135 |
136 | if want2.Level != tc.Level {
137 | t.Errorf("want2 %+v but get %+v", tc.Level, want2.Level)
138 | }
139 | if want2.StacktraceKey != tc.StacktraceKey {
140 | t.Errorf("want2 %+v but get %+v", tc.StacktraceKey, want2.StacktraceKey)
141 | }
142 | if want2.LogInConsole != tc.LogInConsole {
143 | t.Errorf("want2 %+v but get %+v", tc.LogInConsole, want2.LogInConsole)
144 | }
145 |
146 | if err := vi.RemoveFile(); err != nil {
147 | t.Error(err.Error())
148 | }
149 |
150 | if vi.IsExist() {
151 | t.Error("config file exist after remove")
152 | }
153 | }
154 |
--------------------------------------------------------------------------------
/e/error.go:
--------------------------------------------------------------------------------
1 | package e
2 |
3 | import (
4 | "errors"
5 | )
6 |
// Sentinel errors; compare them with errors.Is.
var (
	// ErrConfigNameEmpty means a config file name was not set.
	ErrConfigNameEmpty = errors.New("config file name empty")
	// ErrViperConfInvalid means a viper config descriptor is missing or
	// invalid. The previous text, "not invalid", stated the opposite.
	ErrViperConfInvalid = errors.New("viper conf invalid")
	// ErrConfigInvalid means a configuration is invalid.
	ErrConfigInvalid = errors.New("config invalid")

	// ErrAuthInvalid means authentication is invalid.
	ErrAuthInvalid = errors.New("auth invalid")
	// ErrDbTableNameEmpty means a database table name was not set.
	ErrDbTableNameEmpty = errors.New("database table name empty")
)
15 |
--------------------------------------------------------------------------------
/example/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "log"
5 | "net/http"
6 |
7 | "github.com/gin-gonic/gin"
8 | admin "github.com/snowlyg/iris-admin"
9 | "github.com/snowlyg/iris-admin/conf"
10 | )
11 |
12 | func main() {
13 | c := conf.NewConf()
14 | // change default config
15 | if err := c.Recover(); err != nil {
16 | panic(err.Error())
17 | }
18 | s, err := admin.NewServe(c)
19 | if err != nil {
20 | panic(err.Error())
21 | }
22 |
23 | engine := s.Engine()
24 | // add group api v1
25 | v1 := engine.Group("/api/v1")
26 | {
27 | v1.GET("/health", func(ctx *gin.Context) {
28 | ctx.String(http.StatusOK, "OK")
29 | })
30 | }
31 |
32 | // noitce the static path should not start with /
33 | // because static path use /*filepath to match all path start with /
34 | engine.Static("/admin", "./public")
35 | log.Printf("open: http://%s/admin in your browser\n", s.SystemAddr())
36 |
37 | s.Run()
38 | }
39 |
--------------------------------------------------------------------------------
/example/public/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Iris-Admin by Snowlyg
5 |
71 |
72 |
73 |
77 |
78 |
79 |
80 | About Iris-Admin
81 |
82 | Snowlyg/Iris-Admin is a robust and modular backend system built on the Gin web framework. It offers a scalable, maintainable, and extensible architecture for rapid development of admin dashboards and APIs.
83 |
84 |
85 |
86 |
87 | Key Features
88 |
89 |
90 |
Modular Architecture
91 |
Build and manage your system with reusable modules for users, roles, permissions, and more.
92 |
93 |
94 |
RBAC & JWT Auth
95 |
Secure your application with role-based access control and JWT authentication.
96 |
97 |
98 |
Swagger Integration
99 |
Auto-generate API documentation using Swagger for easy testing and integration.
100 |
101 |
102 |
Frontend Ready
103 |
Designed to integrate seamlessly with Vue-based admin UIs like vue-element-admin.
104 |
105 |
106 |
107 |
108 |
118 |
119 |
120 |
121 | © 2025 Snowlyg - Iris-Admin Project
122 |
123 |
124 |
125 |
--------------------------------------------------------------------------------
/example/readme.md:
--------------------------------------------------------------------------------
1 | # example
2 |
3 | ## todo
4 |
5 | [+] Example: a single page that shows repository information.
6 | [-] menu sync.
7 | [-] permission & role & user.
8 | [-] login logout.
9 | [-] use [xaboy/form-create-designer](https://github.com/xaboy/form-create-designer) to create forms from gorm models.
10 |
11 | 1. get [vue-element-admin](https://github.com/PanJiaChen/vue-element-admin.git)
12 |
13 | ```shell
14 | git clone https://github.com/PanJiaChen/vue-element-admin.git
15 | ```
16 |
17 | 2. get [iris-admin example](https://github.com/snowlyg/iris-admin.git)
18 |
19 | ```shell
20 | git clone https://github.com/snowlyg/iris-admin.git
21 |
22 | cd iris-admin/example
23 |
24 | go run main.go
25 | ```
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/snowlyg/iris-admin
2 |
3 | go 1.22.6
4 |
5 | require (
6 | github.com/aviddiviner/gin-limit v0.0.0-20170918012823-43b5f79762c1
7 | github.com/bwmarrin/snowflake v0.3.0
8 | github.com/casbin/casbin/v2 v2.104.0
9 | github.com/casbin/gorm-adapter/v3 v3.32.0
10 | github.com/fvbock/endless v0.0.0-20170109170031-447134032cb6
11 | github.com/gavv/httpexpect/v2 v2.17.0
12 | github.com/gin-contrib/pprof v1.3.0
13 | github.com/gin-gonic/gin v1.8.1
14 | github.com/go-gormigrate/gormigrate/v2 v2.0.0
15 | github.com/go-playground/validator/v10 v10.11.0
16 | github.com/go-redis/redis/v8 v8.11.5
17 | github.com/golang-jwt/jwt v3.2.2+incompatible
18 | github.com/gorilla/websocket v1.5.0
19 | github.com/mattn/go-colorable v0.1.13
20 | github.com/patrickmn/go-cache v2.1.0+incompatible
21 | github.com/pkg/errors v0.9.1
22 | github.com/satori/go.uuid v1.2.0
23 | github.com/snowlyg/helper v0.1.33
24 | github.com/spf13/viper v1.12.0
25 | github.com/unrolled/secure v1.0.9
26 | go.mongodb.org/mongo-driver v1.11.4
27 | gorm.io/driver/mysql v1.5.7
28 | gorm.io/gorm v1.25.12
29 | )
30 |
31 | require (
32 | github.com/TylerBrock/colorjson v0.0.0-20200706003622-8a50f05110d2 // indirect
33 | github.com/ajg/form v1.5.1 // indirect
34 | github.com/andybalholm/brotli v1.0.4 // indirect
35 | github.com/bmatcuk/doublestar/v4 v4.6.1 // indirect
36 | github.com/casbin/govaluate v1.3.0 // indirect
37 | github.com/cespare/xxhash/v2 v2.1.2 // indirect
38 | github.com/davecgh/go-spew v1.1.1 // indirect
39 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
40 | github.com/dustin/go-humanize v1.0.1 // indirect
41 | github.com/fatih/color v1.15.0 // indirect
42 | github.com/fatih/structs v1.1.0 // indirect
43 | github.com/fsnotify/fsnotify v1.8.0 // indirect
44 | github.com/gin-contrib/sse v0.1.0 // indirect
45 | github.com/glebarez/go-sqlite v1.20.3 // indirect
46 | github.com/glebarez/sqlite v1.7.0 // indirect
47 | github.com/go-playground/locales v0.14.0 // indirect
48 | github.com/go-playground/universal-translator v0.18.0 // indirect
49 | github.com/go-sql-driver/mysql v1.7.0 // indirect
50 | github.com/gobwas/glob v0.2.3 // indirect
51 | github.com/goccy/go-json v0.9.8-0.20220506185958-23bd66f4c0d5 // indirect
52 | github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
53 | github.com/golang-sql/sqlexp v0.1.0 // indirect
54 | github.com/golang/snappy v0.0.4 // indirect
55 | github.com/google/go-querystring v1.1.0 // indirect
56 | github.com/google/uuid v1.3.0 // indirect
57 | github.com/gosuri/uilive v0.0.4 // indirect
58 | github.com/gosuri/uiprogress v0.0.1 // indirect
59 | github.com/hashicorp/hcl v1.0.0 // indirect
60 | github.com/imkira/go-interpol v1.1.0 // indirect
61 | github.com/jackc/pgpassfile v1.0.0 // indirect
62 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
63 | github.com/jackc/pgx/v5 v5.5.5 // indirect
64 | github.com/jackc/puddle/v2 v2.2.1 // indirect
65 | github.com/jinzhu/inflection v1.0.0 // indirect
66 | github.com/jinzhu/now v1.1.5 // indirect
67 | github.com/json-iterator/go v1.1.12 // indirect
68 | github.com/klauspost/compress v1.15.6 // indirect
69 | github.com/leodido/go-urn v1.2.1 // indirect
70 | github.com/magiconair/properties v1.8.6 // indirect
71 | github.com/mattn/go-isatty v0.0.18 // indirect
72 | github.com/microsoft/go-mssqldb v1.6.0 // indirect
73 | github.com/mitchellh/go-wordwrap v1.0.1 // indirect
74 | github.com/mitchellh/mapstructure v1.5.0 // indirect
75 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
76 | github.com/modern-go/reflect2 v1.0.2 // indirect
77 | github.com/montanaflynn/stats v0.7.0 // indirect
78 | github.com/pelletier/go-toml v1.9.5 // indirect
79 | github.com/pelletier/go-toml/v2 v2.0.2 // indirect
80 | github.com/pmezard/go-difflib v1.0.0 // indirect
81 | github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 // indirect
82 | github.com/sanity-io/litter v1.5.5 // indirect
83 | github.com/sergi/go-diff v1.2.0 // indirect
84 | github.com/spf13/afero v1.8.2 // indirect
85 | github.com/spf13/cast v1.5.0 // indirect
86 | github.com/spf13/jwalterweatherman v1.1.0 // indirect
87 | github.com/spf13/pflag v1.0.5 // indirect
88 | github.com/stretchr/testify v1.8.4 // indirect
89 | github.com/subosito/gotenv v1.3.0 // indirect
90 | github.com/ugorji/go/codec v1.2.7 // indirect
91 | github.com/valyala/bytebufferpool v1.0.0 // indirect
92 | github.com/valyala/fasthttp v1.40.0 // indirect
93 | github.com/xdg-go/pbkdf2 v1.0.0 // indirect
94 | github.com/xdg-go/scram v1.1.1 // indirect
95 | github.com/xdg-go/stringprep v1.0.3 // indirect
96 | github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
97 | github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
98 | github.com/xeipuuv/gojsonschema v1.2.0 // indirect
99 | github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0 // indirect
100 | github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect
101 | github.com/yudai/gojsondiff v1.0.0 // indirect
102 | github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 // indirect
103 | golang.org/x/crypto v0.31.0 // indirect
104 | golang.org/x/net v0.33.0 // indirect
105 | golang.org/x/sync v0.10.0 // indirect
106 | golang.org/x/sys v0.28.0 // indirect
107 | golang.org/x/text v0.21.0 // indirect
108 | google.golang.org/protobuf v1.28.0 // indirect
109 | gopkg.in/ini.v1 v1.66.6 // indirect
110 | gopkg.in/yaml.v2 v2.4.0 // indirect
111 | gopkg.in/yaml.v3 v3.0.1 // indirect
112 | gorm.io/driver/postgres v1.5.9 // indirect
113 | gorm.io/driver/sqlserver v1.5.3 // indirect
114 | gorm.io/plugin/dbresolver v1.5.3 // indirect
115 | modernc.org/libc v1.22.2 // indirect
116 | modernc.org/mathutil v1.5.0 // indirect
117 | modernc.org/memory v1.5.0 // indirect
118 | modernc.org/sqlite v1.20.3 // indirect
119 | moul.io/http2curl/v2 v2.3.0 // indirect
120 | )
121 |
--------------------------------------------------------------------------------
/httptest/base.go:
--------------------------------------------------------------------------------
1 | package httptest
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "log"
7 | "net/http"
8 | "testing"
9 |
10 | "github.com/gavv/httpexpect/v2"
11 | "github.com/snowlyg/helper/str"
12 | )
13 |
14 | var (
15 | c *Client
16 | // default page request params
17 | GetRequestFunc = NewWithQueryObjectParamFunc(map[string]any{"page": 1, "pageSize": 10})
18 |
19 | // default page request params
20 | PostRequestFunc = NewWithJsonParamFunc(map[string]any{"page": 1, "pageSize": 10})
21 |
22 | // default login request params
23 | LoginFunc = NewWithJsonParamFunc(map[string]any{"username": "admin", "password": "123456"})
24 |
25 | // default login response params
26 | LoginResponse = Responses{
27 | {Key: "status", Value: http.StatusOK},
28 | {Key: "message", Value: "OK"},
29 | {Key: "data",
30 | Value: Responses{
31 | {Key: "accessToken", Value: "", Type: "notempty"},
32 | },
33 | },
34 | }
35 |
36 | // SuccessResponse default success response params
37 | SuccessResponse = Responses{
38 | {Key: "status", Value: http.StatusOK},
39 | {Key: "message", Value: "OK"},
40 | }
41 |
42 | // ResponsePage default data response params
43 | ResponsePage = Responses{
44 | {Key: "status", Value: http.StatusOK},
45 | {Key: "message", Value: "OK"},
46 | {Key: "data", Value: Responses{
47 | {Key: "pageSize", Value: 10},
48 | {Key: "page", Value: 1},
49 | }},
50 | }
51 | )
52 |
53 | // paramFunc
54 | type paramFunc func(req *httpexpect.Request) *httpexpect.Request
55 |
56 | // NewWithJsonParamFunc return req.WithJSON
57 | func NewWithJsonParamFunc(query map[string]any) paramFunc {
58 | return func(req *httpexpect.Request) *httpexpect.Request {
59 | return req.WithJSON(query)
60 | }
61 | }
62 |
63 | // NewWithQueryObjectParamFunc query for get method
64 | func NewWithQueryObjectParamFunc(query map[string]any) paramFunc {
65 | return func(req *httpexpect.Request) *httpexpect.Request {
66 | return req.WithQueryObject(query)
67 | }
68 | }
69 |
70 | // NewWithFileParamFunc return req.WithFile
71 | func NewWithFileParamFunc(fs []File, query map[string]any) paramFunc {
72 | return func(req *httpexpect.Request) *httpexpect.Request {
73 | if len(fs) == 0 {
74 | return req
75 | }
76 | req = req.WithMultipart()
77 | for _, f := range fs {
78 | req = req.WithFile(f.Key, f.Path, f.Reader)
79 | }
80 | if query == nil {
81 | return req
82 | }
83 | return req.WithForm(query)
84 | }
85 | }
86 |
87 | // NewWithFormParamFunc
88 | func NewWithFormParamFunc(query map[string]any) paramFunc {
89 | return func(req *httpexpect.Request) *httpexpect.Request {
90 | if query == nil {
91 | return req
92 | }
93 | return req.WithMultipart().WithForm(query)
94 | }
95 | }
96 |
97 | // NewResponsesWithLength return Responses with length value for data key
98 | func NewResponsesWithLength(status int, message string, data []Responses, length int) Responses {
99 | return Responses{
100 | {Key: "status", Value: status},
101 | {Key: "message", Value: message},
102 | {Key: "data", Value: data, Length: length},
103 | }
104 | }
105 |
106 | // NewResponses return Responses
107 | func NewResponses(status int, message string, data ...Responses) Responses {
108 | if status != http.StatusOK {
109 | return Responses{
110 | {Key: "status", Value: status},
111 | {Key: "message", Value: message},
112 | }
113 | }
114 | if len(data) == 0 {
115 | return Responses{
116 | {Key: "status", Value: status},
117 | {Key: "message", Value: message},
118 | }
119 | }
120 | if len(data) == 1 {
121 | return Responses{
122 | {Key: "status", Value: status},
123 | {Key: "message", Value: message},
124 | {Key: "data", Value: data[0]},
125 | }
126 | }
127 | return Responses{
128 | {Key: "status", Value: status},
129 | {Key: "message", Value: message},
130 | {Key: "data", Value: data},
131 | }
132 | }
133 |
134 | type Client struct {
135 | t *testing.T
136 | conf httpexpect.Config
137 | expect *httpexpect.Expect
138 | status int
139 | headers map[string]string
140 |
141 | tokenIndex string
142 | loginApi string
143 | logoutApi string
144 | }
145 |
146 | type ClientConf struct {
147 | Key string
148 | Value string
149 | }
150 |
151 | const (
152 | BASE_URL = "base_url"
153 | TOKEN_INDEX = "token_index"
154 | LoginApi = "login_api"
155 | LogoutApi = "logout_api"
156 | )
157 |
158 | func NewBaseUrlConf(value string) ClientConf {
159 | return ClientConf{Key: BASE_URL, Value: value}
160 | }
161 |
162 | func NewTokenIndexConf(value string) ClientConf {
163 | return ClientConf{Key: TOKEN_INDEX, Value: value}
164 | }
165 |
166 | func NewLoginApiConf(value string) ClientConf {
167 | return ClientConf{Key: LoginApi, Value: value}
168 | }
169 | func NewLogoutApiConf(value string) ClientConf {
170 | return ClientConf{Key: LogoutApi, Value: value}
171 | }
172 |
173 | // NewClient return test client instance
174 | func NewClient(t *testing.T, handler http.Handler, confs ...ClientConf) *Client {
175 | c = &Client{
176 | t: t,
177 | conf: httpexpect.Config{
178 | TestName: t.Name(),
179 | Client: &http.Client{
180 | Transport: httpexpect.NewBinder(handler),
181 | Jar: httpexpect.NewCookieJar(),
182 | },
183 | Reporter: httpexpect.NewAssertReporter(t),
184 | Printers: []httpexpect.Printer{
185 | NewDebugPrinter(t, true),
186 | // httpexpect.NewCompactPrinter(t),
187 | // httpexpect.NewCurlPrinter(t),
188 | },
189 | // Printers: []httpexpect.Printer{
190 | // httpexpect.NewCompactPrinter(t),
191 | // },
192 | Formatter: &httpexpect.DefaultFormatter{
193 | // DisablePaths: true,
194 | // DisableDiffs: true,
195 | // FloatFormat: httpexpect.FloatFormatScientific,
196 | ColorMode: httpexpect.ColorModeAlways,
197 | // LineWidth: 80,
198 | },
199 | },
200 | headers: map[string]string{},
201 | tokenIndex: "data.accessToken",
202 | loginApi: "/login",
203 | logoutApi: "/logout",
204 | }
205 | if len(confs) > 0 {
206 | for _, conf := range confs {
207 | switch conf.Key {
208 | case BASE_URL:
209 | c.conf.BaseURL = conf.Value
210 | case TOKEN_INDEX:
211 | c.tokenIndex = conf.Value
212 | case LoginApi:
213 | c.loginApi = conf.Value
214 | case LogoutApi:
215 | c.logoutApi = conf.Value
216 | }
217 | }
218 | }
219 | c.expect = httpexpect.WithConfig(c.conf)
220 | return c
221 | }
222 |
223 | func (c *Client) SwitchT(t *testing.T) {
224 | c.t = t
225 | c.conf.TestName = t.Name()
226 | c.conf.Reporter = httpexpect.NewAssertReporter(t)
227 | c.conf.Printers = []httpexpect.Printer{
228 | NewDebugPrinter(t, true),
229 | // httpexpect.NewCompactPrinter(t),
230 | // httpexpect.NewCurlPrinter(t),
231 | }
232 | c.expect = httpexpect.WithConfig(c.conf)
233 | c.expect = c.expect.Builder(func(req *httpexpect.Request) {
234 | req.WithHeaders(c.headers)
235 | })
236 | }
237 |
238 | // Login for http login
239 | func (c *Client) Login(res Responses, paramFuncs ...paramFunc) error {
240 | if len(paramFuncs) == 0 {
241 | paramFuncs = append(paramFuncs, LoginFunc)
242 | }
243 | c.POST(c.loginApi, res, paramFuncs...)
244 | token := res.GetString(c.tokenIndex)
245 | fmt.Printf("token %s is '%s'\n", c.tokenIndex, token)
246 | if token == "" {
247 | return fmt.Errorf("token %s is empty", c.tokenIndex)
248 | }
249 | c.headers["Authorization"] = str.Join("Bearer ", token)
250 | c.expect = c.expect.Builder(func(req *httpexpect.Request) {
251 | req.WithHeaders(c.headers)
252 | })
253 | return nil
254 | }
255 |
256 | // Logout for http logout
257 | func (c *Client) Logout(res Responses) {
258 | if res == nil {
259 | res = SuccessResponse
260 | }
261 | c.GET(c.logoutApi, res)
262 |
263 | c.headers["Authorization"] = ""
264 | c.expect = c.expect.Builder(func(req *httpexpect.Request) {
265 | req.WithHeaders(c.headers)
266 | })
267 | }
268 |
269 | type File struct {
270 | Key string
271 | Path string
272 | Reader io.Reader
273 | }
274 |
275 | // checkStatus check what's http response stauts want
276 | func (c *Client) checkStatus() int {
277 | if c.status == 0 {
278 | return http.StatusOK
279 | }
280 | return c.status
281 | }
282 |
283 | // SetStatus set what's http response stauts want
284 | func (c *Client) SetStatus(status int) *Client {
285 | c.status = status
286 | return c
287 | }
288 |
289 | // AddHeader
290 | func (c *Client) AddHeader(k, v string) *Client {
291 | c.headers[k] = v
292 | return c
293 | }
294 |
295 | // POST
296 | func (c *Client) POST(url string, resps any, paramFuncs ...paramFunc) {
297 | req := c.expect.POST(url)
298 | if len(paramFuncs) > 0 {
299 | for _, f := range paramFuncs {
300 | req = f(req)
301 | }
302 | }
303 | if testRes, ok := resps.(Responses); ok {
304 | obj := req.Expect().Status(c.checkStatus()).JSON()
305 | testRes.Test(obj)
306 | } else if testRes, ok := resps.([]Responses); ok {
307 | array := req.Expect().Status(c.checkStatus()).JSON().Array()
308 | for i, v := range testRes {
309 | v.Test(array.Value(i))
310 | }
311 | } else {
312 | log.Println("data type error")
313 | }
314 | }
315 |
316 | // PUT
317 | func (c *Client) PUT(url string, resps any, paramFuncs ...paramFunc) {
318 | req := c.expect.PUT(url)
319 | if len(paramFuncs) > 0 {
320 | for _, f := range paramFuncs {
321 | req = f(req)
322 | }
323 | }
324 | if testRes, ok := resps.(Responses); ok {
325 | obj := req.Expect().Status(c.checkStatus()).JSON()
326 | testRes.Test(obj)
327 | } else if testRes, ok := resps.([]Responses); ok {
328 | array := req.Expect().Status(c.checkStatus()).JSON().Array()
329 | for i, v := range testRes {
330 | v.Test(array.Value(i))
331 | }
332 | } else {
333 | log.Println("data type error")
334 | }
335 | }
336 |
337 | // UPLOAD
338 | func (c *Client) UPLOAD(url string, resps any, paramFuncs ...paramFunc) {
339 | req := c.expect.POST(url)
340 | if len(paramFuncs) > 0 {
341 | for _, f := range paramFuncs {
342 | req = f(req)
343 | }
344 | }
345 | if testRes, ok := resps.(Responses); ok {
346 | obj := req.Expect().Status(c.checkStatus()).JSON()
347 | testRes.Test(obj)
348 | } else if testRes, ok := resps.([]Responses); ok {
349 | array := req.Expect().Status(c.checkStatus()).JSON().Array()
350 | for i, v := range testRes {
351 | v.Test(array.Value(i))
352 | }
353 | } else {
354 | log.Println("data type error")
355 | }
356 | }
357 |
358 | // GET
359 | func (c *Client) GET(url string, resps any, paramFuncs ...paramFunc) {
360 | req := c.expect.GET(url)
361 | if len(paramFuncs) > 0 {
362 | for _, f := range paramFuncs {
363 | req = f(req)
364 | }
365 | }
366 | if resp, ok := resps.(Responses); ok {
367 | obj := req.Expect().Status(c.checkStatus()).JSON()
368 | resp.Test(obj)
369 | } else if resp, ok := resps.([]Responses); ok {
370 | array := req.Expect().Status(c.checkStatus()).JSON().Array()
371 | for i, v := range resp {
372 | v.Test(array.Value(i))
373 | }
374 | } else {
375 | log.Println("data type error")
376 | }
377 | }
378 |
379 | // DOWNLOAD
380 | func (c *Client) DOWNLOAD(url string, resps any, paramFuncs ...paramFunc) string {
381 | req := c.expect.GET(url)
382 | if len(paramFuncs) > 0 {
383 | for _, f := range paramFuncs {
384 | req = f(req)
385 | }
386 | }
387 |
388 | return req.Expect().Status(c.checkStatus()).Body().NotEmpty().Raw()
389 | }
390 |
391 | // DELETE
392 | func (c *Client) DELETE(url string, resps any, paramFuncs ...paramFunc) {
393 | req := c.expect.DELETE(url)
394 | if len(paramFuncs) > 0 {
395 | for _, f := range paramFuncs {
396 | req = f(req)
397 | }
398 | }
399 | if testRes, ok := resps.(Responses); ok {
400 | obj := req.Expect().Status(c.checkStatus()).JSON()
401 | testRes.Test(obj)
402 | } else if testRes, ok := resps.([]Responses); ok {
403 | array := req.Expect().Status(c.checkStatus()).JSON().Array()
404 | for i, v := range testRes {
405 | v.Test(array.Value(i))
406 | }
407 | } else {
408 | log.Println("data type error")
409 | }
410 | }
411 |
--------------------------------------------------------------------------------
/httptest/base_test.go:
--------------------------------------------------------------------------------
1 | package httptest
2 |
3 | import (
4 | "net/http"
5 | "os"
6 | "testing"
7 |
8 | "github.com/gin-gonic/gin"
9 | "github.com/snowlyg/helper/dir"
10 | )
11 |
12 | type Request struct {
13 | Message string `json:"message" form:"message" uri:"message"`
14 | }
15 |
16 | // GinHandler Create add /example route to gin engine
17 | func GinHandler(r *gin.Engine) *gin.Engine {
18 |
19 | // Add route to the gin engine
20 | r.GET("/example", func(c *gin.Context) {
21 | var req Request
22 | if errs := c.ShouldBind(&req); errs != nil {
23 | c.JSON(http.StatusBadRequest, nil)
24 | return
25 | }
26 | message := "pong"
27 | if req.Message != "" {
28 | message = req.Message
29 | }
30 | c.JSON(http.StatusOK, gin.H{
31 | "status": 200,
32 | "message": "OK",
33 | "data": gin.H{
34 | "message": message,
35 | },
36 | })
37 | })
38 |
39 | // Add route to the gin engine
40 | r.GET("/array", func(c *gin.Context) {
41 | var req Request
42 | if errs := c.ShouldBind(&req); errs != nil {
43 | c.JSON(http.StatusBadRequest, nil)
44 | return
45 | }
46 | c.JSON(http.StatusOK, []string{"1", "2"})
47 | })
48 |
49 | // Add route to the gin engine
50 | r.GET("/mutil", func(c *gin.Context) {
51 | var req Request
52 | if errs := c.ShouldBind(&req); errs != nil {
53 | c.JSON(http.StatusBadRequest, nil)
54 | return
55 | }
56 | message := "pong"
57 | if req.Message != "" {
58 | message = req.Message
59 | }
60 | c.JSON(http.StatusOK, gin.H{
61 | "status": 200,
62 | "message": "OK",
63 | "data": []gin.H{
64 | {"message": message},
65 | {"message": message},
66 | },
67 | })
68 | })
69 |
70 | // Add route to the gin engine
71 | r.POST("/example", func(c *gin.Context) {
72 | var req Request
73 | if errs := c.ShouldBindJSON(&req); errs != nil {
74 | c.JSON(http.StatusBadRequest, gin.H{
75 | "status": http.StatusBadRequest,
76 | "message": "FAIL",
77 | })
78 | return
79 | }
80 | message := "pong"
81 | if req.Message != "" {
82 | message = req.Message
83 | }
84 | c.JSON(http.StatusOK, gin.H{
85 | "status": 200,
86 | "message": "OK",
87 | "data": gin.H{
88 | "message": message,
89 | },
90 | })
91 | })
92 |
93 | // Add route to the gin engine
94 | r.POST("/upload", func(c *gin.Context) {
95 | _, _, err := c.Request.FormFile("file")
96 | if err != nil {
97 | c.JSON(http.StatusBadRequest, gin.H{
98 | "status": http.StatusBadRequest,
99 | "message": "FAIL",
100 | })
101 | return
102 | }
103 |
104 | c.JSON(http.StatusOK, gin.H{
105 | "status": 200,
106 | "message": "OK",
107 | })
108 | })
109 |
110 | type RequestId struct {
111 | Id uint `json:"id" form:"id" uri:"id"`
112 | }
113 |
114 | // Add route to the gin engine
115 | r.DELETE("/example/:id", func(c *gin.Context) {
116 | var req RequestId
117 | if errs := c.ShouldBindUri(&req); errs != nil {
118 | c.JSON(http.StatusBadRequest, nil)
119 | return
120 | }
121 | c.JSON(http.StatusOK, gin.H{
122 | "status": 200,
123 | "message": "OK",
124 | "data": gin.H{
125 | "id": req.Id,
126 | },
127 | })
128 | })
129 |
130 | // Add route to the gin engine
131 | r.POST("login", func(c *gin.Context) {
132 | c.JSON(http.StatusOK, gin.H{
133 | "status": 200,
134 | "message": "OK",
135 | "data": gin.H{
136 | "AccessToken": "EIIDFJDIKFJJIdfdkfk.uisdifsdfisdouf",
137 | "user": gin.H{
138 | "id": 1,
139 | },
140 | },
141 | })
142 | })
143 |
144 | // Add route to the gin engine
145 | r.GET("logout", func(c *gin.Context) {
146 | c.JSON(http.StatusOK, gin.H{
147 | "status": 200,
148 | "message": "OK",
149 | })
150 | })
151 |
152 | // Add route to the gin engine
153 | r.GET("header", func(c *gin.Context) {
154 | c.GetHeader("Authorization")
155 | c.JSON(http.StatusOK, gin.H{
156 | "status": 200,
157 | "message": "OK",
158 | "data": gin.H{
159 | "Authorization": c.GetHeader("Authorization"),
160 | },
161 | })
162 | })
163 |
164 | // return gin engine with newly added route
165 | return r
166 | }
167 |
168 | func TestNewClient(t *testing.T) {
169 | engine := gin.New()
170 | // Create httpexpect instance
171 | client := NewClient(t, GinHandler(engine))
172 | client.GET("/example", NewResponses(http.StatusOK, "OK", Responses{{Key: "message", Value: "pong"}}))
173 | client.DELETE("/example/1", NewResponses(http.StatusOK, "OK", Responses{{Key: "id", Value: 1}}))
174 | }
175 |
176 | func TestNewWithQueryObjectParamFunc(t *testing.T) {
177 | engine := gin.New()
178 | // Create httpexpect instance
179 | client := NewClient(t, GinHandler(engine))
180 | pageKeys := Responses{{Key: "message", Value: "message"}}
181 | client.GET("/example", NewResponses(http.StatusOK, "OK", pageKeys), NewWithQueryObjectParamFunc(map[string]interface{}{"message": "message"}))
182 | }
183 |
184 | func TestNewNewWithJsonParamFunc(t *testing.T) {
185 | engine := gin.New()
186 | // Create httpexpect instance
187 | client := NewClient(t, GinHandler(engine))
188 | client.POST("/example", NewResponses(http.StatusOK, "OK", Responses{{Key: "message", Value: "message"}}), NewWithJsonParamFunc(map[string]interface{}{"message": "message"}))
189 | client.POST("/example", NewResponses(http.StatusOK, "OK", Responses{{Key: "message", Value: "pong"}}), NewWithJsonParamFunc(map[string]interface{}{"message": ""}))
190 | }
191 |
192 | func TestNewResponses(t *testing.T) {
193 | engine := gin.New()
194 | // Create httpexpect instance
195 | client := NewClient(t, GinHandler(engine))
196 |
197 | client.GET("/example", NewResponses(http.StatusOK, "OK", Responses{{Key: "message", Value: "pong"}}))
198 | client.GET("/mutil", NewResponses(http.StatusOK, "OK", Responses{{Key: "message", Value: "pong"}}, Responses{{Key: "message", Value: "pong"}}))
199 | client.SetStatus(http.StatusBadRequest).POST("/example", NewResponses(http.StatusBadRequest, "FAIL", nil))
200 | }
201 |
202 | func TestNewResponsesWithLength(t *testing.T) {
203 | engine := gin.New()
204 | // Create httpexpect instance
205 | client := NewClient(t, GinHandler(engine))
206 | res := []Responses{{{Key: "message", Value: "pong"}}, {{Key: "message", Value: "pong"}}}
207 | client.GET("/mutil", NewResponsesWithLength(http.StatusOK, "OK", res, 2))
208 | }
209 |
210 | func TestNewWithFileParamFunc(t *testing.T) {
211 | name := "test_img.jpg"
212 | if _, err := dir.WriteString("./"+name, ""); err != nil {
213 | t.Fatal(err.Error())
214 | }
215 | defer os.Remove("./" + name)
216 |
217 | engine := gin.New()
218 | // Create httpexpect instance
219 | client := NewClient(t, GinHandler(engine))
220 | fh, _ := os.Open("./" + name)
221 | defer fh.Close()
222 |
223 | uf := []File{{Key: "file", Path: name, Reader: fh}}
224 | client.UPLOAD("/upload", SuccessResponse, NewWithFileParamFunc(uf, nil))
225 | }
226 |
227 | func TestLogin(t *testing.T) {
228 | engine := gin.New()
229 | // Create httpexpect instance
230 | client := NewClient(t, GinHandler(engine), NewTokenIndexConf("data.AccessToken"))
231 | x := Responses{{Key: "AccessToken", Value: "EIIDFJDIKFJJIdfdkfk.uisdifsdfisdouf"}, {Key: "user", Value: Responses{{Key: "id", Value: 1}}}}
232 | err := client.Login(NewResponses(http.StatusOK, "OK", x))
233 | if err != nil {
234 | t.Error(err.Error())
235 | }
236 | if x.GetId("data.user.id") == 0 {
237 | t.Error("id is 0")
238 | }
239 | client.GET("/header", NewResponses(http.StatusOK, "OK", Responses{{Key: "Authorization", Value: "Bearer EIIDFJDIKFJJIdfdkfk.uisdifsdfisdouf"}}))
240 | }
241 |
242 | func TestLogout(t *testing.T) {
243 | engine := gin.New()
244 | client := NewClient(t, GinHandler(engine))
245 | client.Logout(SuccessResponse)
246 | }
247 |
--------------------------------------------------------------------------------
/httptest/common.go:
--------------------------------------------------------------------------------
1 | package httptest
2 |
3 | // // BeforeTestMainGin
4 | // func BeforeTestMainGin(party func(wi *WebServer), seed func(wi *WebServer, mc *MigrationCmd)) (string, *WebServer) {
5 | // dbType := admin.TestDbType
6 | // if dbType != "" {
7 | // CONFIG.System.DbType = dbType
8 | // }
9 | // if dbType == "redis" {
10 | // if err := cache.Recover(); err != nil {
11 | // log.Printf("cache recover fail:%s\n", err.Error())
12 | // }
13 | // }
14 | // if err := Recover(); err != nil {
15 | // log.Printf("web recover fail:%s\n", err.Error())
16 | // }
17 |
18 | // node, _ := snowflake.NewNode(1)
19 | // uuid := str.Join("gin", "_", node.Generate().String())
20 |
21 | // CONFIG.DbName = uuid
22 | // if user := admin.TestMysqlName; user != "" {
23 | // CONFIG.Username = user
24 | // }
25 | // if pwd := admin.TestMysqlPwd; pwd != "" {
26 | // CONFIG.Password = pwd
27 | // }
28 | // if addr := admin.TestMysqlAddr; addr != "" {
29 | // CONFIG.Path = addr
30 | // }
31 | // CONFIG.LogMode = true
32 | // if err := Recover(); err != nil {
33 | // log.Printf("databse recover fail:%s\n", err.Error())
34 | // }
35 |
36 | // if Instance() == nil {
37 | // fmt.Println("database instance is nil")
38 | // return uuid, nil
39 | // }
40 |
41 | // wi := Init()
42 | // party(wi)
43 | // StartTest(wi)
44 |
45 | // mc := New()
46 | // seed(wi, mc)
47 | // err := mc.Migrate()
48 | // if err != nil {
49 | // fmt.Printf("migrate fail: [%s]", err.Error())
50 | // return uuid, nil
51 | // }
52 | // err = mc.Seed()
53 | // if err != nil {
54 | // fmt.Printf("seed fail: [%s]", err.Error())
55 | // return uuid, nil
56 | // }
57 |
58 | // return uuid, wi
59 | // }
60 |
61 | // func AfterTestMain(uuid string, isDelDb bool) {
62 | // if isDelDb {
63 | // dsn := CONFIG.BaseDsn()
64 | // if err := DorpDB(dsn, "mysql", uuid); err != nil {
65 | // log.Printf("drop table(%s) on dsn(%s) fail %s\n", uuid, dsn, err.Error())
66 | // }
67 | // }
68 |
69 | // if db, _ := Instance().DB(); db != nil {
70 | // db.Close()
71 | // }
72 |
73 | // // defer operation.Remove()
74 | // // defer casbin.Remove()
75 | // // defer Remove()
76 | // // defer Remove()
77 | // }
78 |
--------------------------------------------------------------------------------
/httptest/printer.go:
--------------------------------------------------------------------------------
1 | package httptest
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "log"
7 | "net/http"
8 | "net/http/httputil"
9 | "strings"
10 | "time"
11 |
12 | "github.com/gavv/httpexpect/v2"
13 | "github.com/gorilla/websocket"
14 | )
15 |
16 | // DebugPrinter implements Printer and WebsocketPrinter.
17 | // Uses net/http/httputil to dump both requests and responses.
18 | // Also prints all websocket messages.
19 | type DebugPrinter struct {
20 | logger httpexpect.Logger
21 | body bool
22 | }
23 |
24 | var contentTypes = "text/html,image/jpeg,image/png,video/mp4"
25 |
26 | // NewDebugPrinter returns a new DebugPrinter given a logger and body
27 | // flag. If body is true, request and response body is also printed.
28 | func NewDebugPrinter(logger httpexpect.Logger, body bool) DebugPrinter {
29 | return DebugPrinter{logger, body}
30 | }
31 |
32 | // Request implements Printer.Request.
33 | func (p DebugPrinter) Request(req *http.Request) {
34 | if req == nil {
35 | return
36 | }
37 | log.Printf("Content-Type:%s\n", req.Header.Get("Content-Type"))
38 | if req.URL.Path == "/api/v1/file/upload" || strings.Contains(req.Header.Get("Content-Type"), "multipart/form-data") {
39 | p.body = false
40 | }
41 |
42 | for _, contentType := range strings.Split(contentTypes, ",") {
43 | if !p.body {
44 | continue
45 | }
46 | if req.Header.Get("Content-Type") == contentType {
47 | p.body = false
48 | }
49 | }
50 |
51 | dump, err := httputil.DumpRequest(req, p.body)
52 | if err != nil {
53 | panic(err)
54 | }
55 | p.logger.Logf("%s", dump)
56 | }
57 |
58 | // Response implements Printer.Response.
59 | func (p DebugPrinter) Response(resp *http.Response, duration time.Duration) {
60 | if resp == nil {
61 | return
62 | }
63 | for _, contentType := range strings.Split(contentTypes, ",") {
64 | if !p.body {
65 | continue
66 | }
67 | if resp.Header.Get("Content-Type") == contentType {
68 | p.body = false
69 | }
70 | }
71 |
72 | dump, err := httputil.DumpResponse(resp, p.body)
73 | if err != nil {
74 | panic(err)
75 | }
76 |
77 | text := strings.Replace(string(dump), "\r\n", "\n", -1)
78 | lines := strings.SplitN(text, "\n", 2)
79 |
80 | p.logger.Logf("%s %s\n%s", lines[0], duration, lines[1])
81 | }
82 |
83 | // WebsocketWrite implements WebsocketPrinter.WebsocketWrite.
84 | func (p DebugPrinter) WebsocketWrite(typ int, content []byte, closeCode int) {
85 | b := &bytes.Buffer{}
86 | fmt.Fprintf(b, "-> Sent: %s", wsMessageType(typ))
87 | if typ == websocket.CloseMessage {
88 | fmt.Fprintf(b, " %s", wsCloseCode(closeCode))
89 | }
90 | fmt.Fprint(b, "\n")
91 | if len(content) > 0 {
92 | if typ == websocket.BinaryMessage {
93 | fmt.Fprintf(b, "%v\n", content)
94 | } else {
95 | fmt.Fprintf(b, "%s\n", content)
96 | }
97 | }
98 | fmt.Fprintf(b, "\n")
99 | p.logger.Logf(b.String())
100 | }
101 |
102 | // WebsocketRead implements WebsocketPrinter.WebsocketRead.
103 | func (p DebugPrinter) WebsocketRead(typ int, content []byte, closeCode int) {
104 | b := &bytes.Buffer{}
105 | fmt.Fprintf(b, "<- Received: %s", wsMessageType(typ))
106 | if typ == websocket.CloseMessage {
107 | fmt.Fprintf(b, " %s", wsCloseCode(closeCode))
108 | }
109 | fmt.Fprint(b, "\n")
110 | if len(content) > 0 {
111 | if typ == websocket.BinaryMessage {
112 | fmt.Fprintf(b, "%v\n", content)
113 | } else {
114 | fmt.Fprintf(b, "%s\n", content)
115 | }
116 | }
117 | fmt.Fprintf(b, "\n")
118 | p.logger.Logf(b.String())
119 | }
120 |
121 | type wsMessageType int
122 |
123 | func (wmt wsMessageType) String() string {
124 | s := "unknown"
125 |
126 | switch wmt {
127 | case websocket.TextMessage:
128 | s = "text"
129 | case websocket.BinaryMessage:
130 | s = "binary"
131 | case websocket.CloseMessage:
132 | s = "close"
133 | case websocket.PingMessage:
134 | s = "ping"
135 | case websocket.PongMessage:
136 | s = "pong"
137 | }
138 |
139 | return fmt.Sprintf("%s(%d)", s, wmt)
140 | }
141 |
// wsCloseCode is a websocket close status code.
type wsCloseCode int

// String formats the close code as "Name(code)". Names for 1000-1015 follow
// https://developer.mozilla.org/en-US/docs/Web/API/CloseEvent/code; any other
// code is named "Unknown".
func (wcc wsCloseCode) String() string {
	// Index i holds the name of close code 1000+i.
	names := [...]string{
		"NormalClosure",           // 1000
		"GoingAway",               // 1001
		"ProtocolError",           // 1002
		"UnsupportedData",         // 1003
		"Reserved",                // 1004
		"NoStatusReceived",        // 1005
		"AbnormalClosure",         // 1006
		"InvalidFramePayloadData", // 1007
		"PolicyViolation",         // 1008
		"MessageTooBig",           // 1009
		"MandatoryExtension",      // 1010
		"InternalServerError",     // 1011
		"ServiceRestart",          // 1012
		"TryAgainLater",           // 1013
		"BadGateway",              // 1014
		"TLSHandshake",            // 1015
	}
	label := "Unknown"
	if wcc >= 1000 && int(wcc-1000) < len(names) {
		label = names[wcc-1000]
	}
	return fmt.Sprintf("%s(%d)", label, wcc)
}
185 |
--------------------------------------------------------------------------------
/httptest/reporter.go:
--------------------------------------------------------------------------------
1 | package httptest
2 |
--------------------------------------------------------------------------------
/httptest/respose.go:
--------------------------------------------------------------------------------
1 | package httptest
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "reflect"
7 | "strconv"
8 | "strings"
9 |
10 | "github.com/gavv/httpexpect/v2"
11 | "github.com/snowlyg/helper/arr"
12 | )
13 |
// Responses is an ordered list of expected (or scanned) JSON fields.
type Responses []Response

// Response describes one field of a JSON response.
type Response struct {
	// Type modifies the assertion made by Responses.Test ("ge", "notempty",
	// "null", "notnull"); when empty, the value is compared for equality.
	Type string
	// Key is the JSON field name.
	Key string
	// Value is the expected value; Scan overwrites it with the actual one.
	// Entries with a nil Value are skipped by Test.
	Value any
	// Length is how many elements of a slice value to check when the actual
	// array length differs from the expectation.
	Length int
	// Func, when set (and Value is non-nil), receives the field's
	// *httpexpect.Value and replaces the built-in checks.
	Func func(obj any)
}
25 |
26 | // Keys return Responses object key array
27 | func (res Responses) Keys() []string {
28 | keys := []string{}
29 | for _, re := range res {
30 | keys = append(keys, re.Key)
31 | }
32 | return keys
33 | }
34 |
35 | // IdKeys return Responses with id
36 | func IdKeys() Responses {
37 | return Responses{
38 | {Key: "id", Value: 0, Type: "ge"},
39 | }
40 | }
41 |
42 | // Test for data test
43 | func Test(value *httpexpect.Value, reses ...any) {
44 | for _, ks := range reses {
45 | if ks == nil {
46 | return
47 | }
48 | reflectTypeString := reflect.TypeOf(ks).String()
49 | switch reflectTypeString {
50 | case "bool":
51 | value.Boolean().IsEqual(ks.(bool))
52 | case "string":
53 | value.String().IsEqual(ks.(string))
54 | case "float64":
55 | value.Number().IsEqual(ks.(float64))
56 | case "uint":
57 | value.Number().IsEqual(ks.(uint))
58 | case "int":
59 | value.IsEqual(ks.(int))
60 |
61 | case "[]httptest.Responses":
62 | valueLen := len(ks.([]Responses))
63 | length := int(value.Array().Length().Raw())
64 | value.Array().Length().IsEqual(valueLen)
65 | if length > 0 {
66 | max := 1
67 | if valueLen == length {
68 | max = length
69 | }
70 | for i := 0; i < max; i++ {
71 | ks.([]Responses)[i].Test(value.Array().Value(i))
72 | }
73 | }
74 |
75 | case "map[int][]httptest.Responses":
76 | values := ks.(map[int][]Responses)
77 | length := len(values)
78 | value.Object().Keys().Length().IsEqual(length)
79 | if length > 0 {
80 | for key, v := range values {
81 | for _, vres := range v {
82 | vres.Test(value.Object().Value(strconv.FormatInt(int64(key), 10)))
83 | }
84 | }
85 | }
86 | case "httptest.Responses":
87 | ks.(Responses).Test(value)
88 | case "[]uint":
89 | valueLen := len(ks.([]uint))
90 | value.Array().Length().IsEqual(valueLen)
91 | length := int(value.Array().Length().Raw())
92 | if length > 0 {
93 | max := 1
94 | if valueLen == length {
95 | max = length
96 | }
97 | for i := 0; i < max; i++ {
98 | value.Array().Value(i).Number().IsEqual(ks.([]uint)[i])
99 | }
100 | }
101 |
102 | case "[]string":
103 | valueLen := len(ks.([]string))
104 | value.Array().Length().IsEqual(valueLen)
105 | length := int(value.Array().Length().Raw())
106 | if length > 0 {
107 | max := 1
108 | if valueLen == length {
109 | max = length
110 | }
111 | for i := 0; i < max; i++ {
112 | value.Array().Value(i).String().IsEqual(ks.([]string)[i])
113 | }
114 | }
115 | case "map[int]string":
116 | values := ks.(map[int]string)
117 | value.Object().Keys().Length().IsEqual(len(values))
118 | for key, v := range values {
119 | value.Object().Value(strconv.FormatInt(int64(key), 10)).IsEqual(v)
120 | }
121 | default:
122 | continue
123 | }
124 | }
125 | }
126 |
127 | // Scan scan data form http response
128 | func Scan(object *httpexpect.Object, reses ...Responses) {
129 | if len(reses) == 0 {
130 | return
131 | }
132 |
133 | //return once
134 | if len(reses) == 1 {
135 | reses[0].Scan(object.Value("data").Object())
136 | return
137 | }
138 |
139 | array := object.Value("data").Array()
140 | length := int(array.Length().Raw())
141 | if length < len(reses) {
142 | fmt.Println("Return data not IsEqual keys length")
143 | array.Length().IsEqual(len(reses))
144 | return
145 | }
146 |
147 | // return array
148 | for m, res := range reses {
149 | if res == nil {
150 | return
151 | }
152 | res.Scan(object.Value("data").Array().Value(m).Object())
153 | }
154 | }
155 |
156 | // Test Test Responses object
157 | func (resp Responses) Test(value *httpexpect.Value) {
158 | for _, rs := range resp {
159 | if rs.Value == nil {
160 | continue
161 | }
162 | if rs.Func != nil {
163 | rs.Func(value.Object().Value(rs.Key))
164 |
165 | } else {
166 | reflectTypeString := reflect.TypeOf(rs.Value).String()
167 | switch reflectTypeString {
168 | case "bool":
169 | value.Object().Value(rs.Key).Boolean().IsEqual(rs.Value.(bool))
170 | case "string":
171 | if strings.ToLower(rs.Type) == "notempty" {
172 | value.Object().Value(rs.Key).String().NotEmpty()
173 | } else {
174 | value.Object().Value(rs.Key).String().IsEqual(rs.Value.(string))
175 | }
176 | case "float64":
177 | if strings.ToLower(rs.Type) == "ge" {
178 | value.Object().Value(rs.Key).Number().Ge(rs.Value.(float64))
179 | } else {
180 | value.Object().Value(rs.Key).Number().IsEqual(rs.Value.(float64))
181 | }
182 | case "uint":
183 | if strings.ToLower(rs.Type) == "ge" {
184 | value.Object().Value(rs.Key).Number().Ge(rs.Value.(uint))
185 | } else {
186 | value.Object().Value(rs.Key).Number().IsEqual(rs.Value.(uint))
187 | }
188 | case "int":
189 | if strings.ToLower(rs.Type) == "ge" {
190 | value.Object().Value(rs.Key).Number().Ge(rs.Value.(int))
191 | } else {
192 | value.Object().Value(rs.Key).Number().IsEqual(rs.Value.(int))
193 | }
194 | case "[]httptest.Responses":
195 | valueLen := len(rs.Value.([]Responses))
196 | length := int(value.Object().Value(rs.Key).Array().Length().Raw())
197 | value.Object().Value(rs.Key).Array().Length().IsEqual(valueLen)
198 | if length > 0 {
199 | max := 1
200 | if rs.Length > 0 {
201 | max = rs.Length
202 | }
203 | if valueLen == length {
204 | max = length
205 | }
206 | if valueLen > 0 {
207 | for i := 0; i < max; i++ {
208 | rs.Value.([]Responses)[i].Test(value.Object().Value(rs.Key).Array().Value(i))
209 | }
210 | }
211 | }
212 |
213 | case "map[int][]httptest.Responses":
214 | values := rs.Value.(map[int][]Responses)
215 | length := len(values)
216 | value.Object().Value(rs.Key).Object().Keys().Length().IsEqual(length)
217 | if length > 0 {
218 | for key, v := range values {
219 | for _, vres := range v {
220 | vres.Test(value.Object().Value(rs.Key).Object().Value(strconv.FormatInt(int64(key), 10)))
221 | }
222 | }
223 | }
224 | case "httptest.Responses":
225 | rs.Value.(Responses).Test(value.Object().Value(rs.Key))
226 | case "[]uint":
227 | valueLen := len(rs.Value.([]uint))
228 | value.Object().Value(rs.Key).Array().Length().IsEqual(valueLen)
229 | length := int(value.Object().Value(rs.Key).Array().Length().Raw())
230 | if length > 0 {
231 | max := 1
232 | if rs.Length > 0 {
233 | max = rs.Length
234 | }
235 | if valueLen == length {
236 | max = length
237 | }
238 | for i := 0; i < max; i++ {
239 | value.Object().Value(rs.Key).Array().ContainsAny(rs.Value.([]uint)[i])
240 | }
241 | }
242 |
243 | case "[]string":
244 |
245 | if strings.ToLower(rs.Type) == "null" {
246 | value.Object().Value(rs.Key).IsNull()
247 | } else if strings.ToLower(rs.Type) == "notnull" {
248 | value.Object().Value(rs.Key).NotNull()
249 | } else {
250 | valueLen := len(rs.Value.([]string))
251 | value.Object().Value(rs.Key).Array().Length().IsEqual(valueLen)
252 | length := int(value.Object().Value(rs.Key).Array().Length().Raw())
253 | if length > 0 {
254 | max := 1
255 | if rs.Length > 0 {
256 | max = rs.Length
257 | }
258 | if valueLen == length {
259 | max = length
260 | }
261 | for i := 0; i < max; i++ {
262 | value.Object().Value(rs.Key).Array().ContainsAny(rs.Value.([]string)[i])
263 | }
264 | }
265 | }
266 | case "map[int]string":
267 | if strings.ToLower(rs.Type) == "null" {
268 | value.Object().Value(rs.Key).IsNull()
269 | } else if strings.ToLower(rs.Type) == "notnull" {
270 | value.Object().Value(rs.Key).NotNull()
271 | } else {
272 | values := rs.Value.(map[int]string)
273 | value.Object().Value(rs.Key).Object().Keys().Length().IsEqual(len(values))
274 | for key, v := range values {
275 | value.Object().Value(rs.Key).Object().Value(strconv.FormatInt(int64(key), 10)).IsEqual(v)
276 | }
277 | }
278 | default:
279 | continue
280 | }
281 | }
282 | }
283 | resp.Scan(value.Object())
284 | }
285 |
286 | // Scan Scan response data to Responses object.
287 | func (res Responses) Scan(object *httpexpect.Object) {
288 | for k, rk := range res {
289 | if !Exist(object, rk.Key) {
290 | continue
291 | }
292 | if rk.Value == nil {
293 | continue
294 | }
295 | valueTypeName := reflect.TypeOf(rk.Value).String()
296 | switch valueTypeName {
297 | case "bool":
298 | res[k].Value = object.Value(rk.Key).Boolean().Raw()
299 | case "string":
300 | res[k].Value = object.Value(rk.Key).String().Raw()
301 | case "uint":
302 | res[k].Value = uint(object.Value(rk.Key).Number().Raw())
303 | case "int":
304 | res[k].Value = int(object.Value(rk.Key).Number().Raw())
305 | case "int32":
306 | res[k].Value = int32(object.Value(rk.Key).Number().Raw())
307 | case "float64":
308 | res[k].Value = object.Value(rk.Key).Number().Raw()
309 | case "[]httptest.Responses":
310 | valueLen := len(res[k].Value.([]Responses))
311 | if rk.Length > 0 {
312 | valueLen = rk.Length
313 | }
314 | object.Value(rk.Key).Array().Length().IsEqual(valueLen)
315 | length := int(object.Value(rk.Key).Array().Length().Raw())
316 | if length > 0 {
317 | max := 1
318 | if rk.Length > 0 {
319 | max = rk.Length
320 | }
321 | if valueLen == length {
322 | max = length
323 | }
324 | if valueLen > 0 {
325 | for i := 0; i < max; i++ {
326 | res[k].Value.([]Responses)[i].Scan(object.Value(rk.Key).Array().Value(i).Object())
327 | }
328 | }
329 | }
330 | case "httptest.Responses":
331 | rk.Value.(Responses).Scan(object.Value(rk.Key).Object())
332 | case "[]string":
333 | if strings.ToLower(rk.Type) == "null" {
334 | res[k].Value = []string{}
335 | } else if strings.ToLower(rk.Type) == "notnull" {
336 | continue
337 | } else {
338 | length := int(object.Value(rk.Key).Array().Length().Raw())
339 | if length == 0 {
340 | continue
341 | }
342 | reskey, ok := res[k].Value.([]string)
343 | if ok {
344 | var strings []string
345 | for i := 0; i < length; i++ {
346 | strings = append(reskey, object.Value(rk.Key).Array().Value(i).String().Raw())
347 | }
348 | res[k].Value = strings
349 | }
350 | }
351 | default:
352 | continue
353 | }
354 | }
355 | }
356 |
357 | // Exist Check object keys if the key is in the keys array.
358 | func Exist(object *httpexpect.Object, key string) bool {
359 | objectKyes := object.Keys().Raw()
360 | for _, objectKey := range objectKyes {
361 | if key == objectKey.(string) {
362 | return true
363 | }
364 | }
365 | return false
366 | }
367 |
368 | // GetString return string value.
369 | func (res Responses) GetString(key ...string) string {
370 | if len(key) == 0 {
371 | return ""
372 | }
373 |
374 | if len(key) == 1 {
375 | k := key[0]
376 | if strings.Contains(k, ".") {
377 | keys := strings.Split(k, ".")
378 | if len(keys) == 0 {
379 | return ""
380 | }
381 | key = keys
382 | }
383 | }
384 |
385 | for i := 0; i < len(key); i++ {
386 | for m, rk := range res {
387 | if rk.Value == nil {
388 | return ""
389 | }
390 | reflectTypeString := reflect.TypeOf(rk.Value).String()
391 | if key[i] == rk.Key {
392 | switch reflectTypeString {
393 | case "string":
394 | return rk.Value.(string)
395 | case "httptest.Responses":
396 | return res[m].Value.(Responses).GetString(key[i+1:]...)
397 | }
398 | }
399 | }
400 |
401 | }
402 | return ""
403 | }
404 |
405 | // GetStrArray return string array value.
406 | func (rks Responses) GetStrArray(key string) []string {
407 | for _, rk := range rks {
408 | if key == rk.Key {
409 | if rk.Value == nil {
410 | return nil
411 | }
412 | switch reflect.TypeOf(rk.Value).String() {
413 | case "[]string":
414 | return rk.Value.([]string)
415 | }
416 | }
417 | }
418 | return nil
419 | }
420 |
421 | // GetResponses return Resposnes Array value
422 | func (rks Responses) GetResponses(key string) []Responses {
423 | for _, rk := range rks {
424 | if key == rk.Key {
425 | if rk.Value == nil {
426 | return nil
427 | }
428 | switch reflect.TypeOf(rk.Value).String() {
429 | case "[]httptest.Responses":
430 | return rk.Value.([]Responses)
431 | }
432 | }
433 | }
434 | return nil
435 | }
436 |
437 | // GetResponsereturn Resposnes value
438 | func (rks Responses) GetResponse(key string) Responses {
439 | for _, rk := range rks {
440 | if key == rk.Key {
441 | if rk.Value == nil {
442 | return nil
443 | }
444 | switch reflect.TypeOf(rk.Value).String() {
445 | case "httptest.Responses":
446 | return rk.Value.(Responses)
447 | }
448 | }
449 | }
450 | return nil
451 | }
452 |
453 | // GetUint return uint value
454 | func (rks Responses) GetUint(key ...string) uint {
455 |
456 | if len(key) == 0 {
457 | return 0
458 | }
459 |
460 | if len(key) == 1 {
461 | k := key[0]
462 | if strings.Contains(k, ".") {
463 | keys := strings.Split(k, ".")
464 | if len(keys) == 0 {
465 | return 0
466 | }
467 | key = keys
468 | }
469 | }
470 |
471 | for i := 0; i < len(key); i++ {
472 | for m, rk := range rks {
473 | if key[i] == rk.Key {
474 | if rk.Value == nil {
475 | return 0
476 | }
477 | valueTypeName := reflect.TypeOf(rk.Value).String()
478 | switch valueTypeName {
479 | case "float64":
480 | return uint(rk.Value.(float64))
481 | case "int32":
482 | return uint(rk.Value.(int32))
483 | case "uint":
484 | return rk.Value.(uint)
485 | case "int":
486 | return uint(rk.Value.(int))
487 | case "httptest.Responses":
488 | return rks[m].Value.(Responses).GetUint(key[i:]...)
489 | }
490 | }
491 | }
492 | }
493 |
494 | return 0
495 | }
496 |
497 | // GetInt return int value
498 | func (rks Responses) GetInt(key ...string) int {
499 | if len(key) == 0 {
500 | return 0
501 | }
502 |
503 | if len(key) == 1 {
504 | k := key[0]
505 | if strings.Contains(k, ".") {
506 | keys := strings.Split(k, ".")
507 | if len(keys) == 0 {
508 | return 0
509 | }
510 | key = keys
511 | }
512 | }
513 |
514 | for i := 0; i < len(key); i++ {
515 | for m, rk := range rks {
516 | if key[i] == rk.Key {
517 | if rk.Value == nil {
518 | return 0
519 | }
520 | switch reflect.TypeOf(rk.Value).String() {
521 | case "float64":
522 | return int(rk.Value.(float64))
523 | case "int":
524 | return rk.Value.(int)
525 | case "int32":
526 | return int(rk.Value.(int32))
527 | case "uint":
528 | return int(rk.Value.(uint))
529 | case "httptest.Responses":
530 | return rks[m].Value.(Responses).GetInt(key[i+1:]...)
531 | }
532 | }
533 | }
534 | }
535 |
536 | return 0
537 | }
538 |
539 | // GetInt32 return int32.
540 | func (rks Responses) GetInt32(key ...string) int32 {
541 | if len(key) == 0 {
542 | return 0
543 | }
544 | if len(key) == 1 {
545 | k := key[0]
546 | if strings.Contains(k, ".") {
547 | keys := strings.Split(k, ".")
548 | if len(keys) == 0 {
549 | return 0
550 | }
551 | key = keys
552 | }
553 | }
554 | for i := 0; i < len(key); i++ {
555 | for m, rk := range rks {
556 | if key[i] == rk.Key {
557 | if rk.Value == nil {
558 | return 0
559 | }
560 | switch reflect.TypeOf(rk.Value).String() {
561 | case "float64":
562 | return int32(rk.Value.(float64))
563 | case "int32":
564 | return rk.Value.(int32)
565 | case "int":
566 | return int32(rk.Value.(int))
567 | case "uint":
568 | return int32(rk.Value.(uint))
569 | case "httptest.Responses":
570 | return rks[m].Value.(Responses).GetInt32(key[i+1:]...)
571 | }
572 | }
573 | }
574 | }
575 | return 0
576 | }
577 |
578 | // GetFloat64 return float64
579 | func (rks Responses) GetFloat64(key ...string) float64 {
580 | if len(key) == 0 {
581 | return 0
582 | }
583 | if len(key) == 1 {
584 | k := key[0]
585 | if strings.Contains(k, ".") {
586 | keys := strings.Split(k, ".")
587 | if len(keys) == 0 {
588 | return 0
589 | }
590 | key = keys
591 | }
592 | }
593 | for i := 0; i < len(key); i++ {
594 | for m, rk := range rks {
595 | if key[i] == rk.Key {
596 | if rk.Value == nil {
597 | return 0
598 | }
599 | switch reflect.TypeOf(rk.Value).String() {
600 | case "float64":
601 | return rk.Value.(float64)
602 | case "int":
603 | return float64(rk.Value.(int))
604 | case "int32":
605 | return float64(rk.Value.(int32))
606 | case "uint":
607 | return float64(rk.Value.(uint))
608 | case "httptest.Responses":
609 | return rks[m].Value.(Responses).GetFloat64(key[i+1:]...)
610 | }
611 | }
612 | }
613 | }
614 | return 0
615 | }
616 |
617 | // GetId return id
618 | func (res Responses) GetId(key ...string) uint {
619 | if len(key) == 0 {
620 | key = append(key, "data", "id")
621 | }
622 | return res.GetUint(key...)
623 | }
624 |
625 | var NotEmptyKey = arr.NewCheckArrayType(0)
626 |
627 | // Schema
628 | func Schema(str []byte) (Responses, error) {
629 | objs := Responses{}
630 | j := map[string]any{}
631 | if err := json.Unmarshal(str, &j); err != nil {
632 | return objs, fmt.Errorf("json unmarshal error %w", err)
633 | }
634 | if o, err := schema(j); err != nil {
635 | return objs, err
636 | } else {
637 | objs = o
638 | }
639 | return objs, nil
640 | }
641 |
642 | func (r Responses) Replace(key string, value any, testType ...string) {
643 | if len(r) == 0 {
644 | return
645 | }
646 |
647 | if !strings.Contains(key, ".") {
648 | for i1, k1 := range r {
649 | if k1.Key == key {
650 | r[i1].Value = value
651 | if len(testType) > 0 && testType[0] != "" {
652 | r[i1].Type = testType[0]
653 | }
654 | }
655 | }
656 | return
657 | }
658 | keys := strings.Split(key, ".")
659 | if len(keys) == 1 {
660 | for i1, k1 := range r {
661 | if k1.Key == keys[0] {
662 | r[i1].Value = value
663 | if len(testType) > 0 && testType[0] != "" {
664 | r[i1].Type = testType[0]
665 | }
666 | }
667 | }
668 | return
669 | }
670 | for i1, k1 := range r {
671 | if k1.Key != keys[0] || k1.Value == nil {
672 | continue
673 | }
674 | tof := reflect.TypeOf(k1.Value).String()
675 | if tof == "httptest.Responses" {
676 | r[i1].Value.(Responses).Replace(strings.Join(keys[1:], "."), value, testType...)
677 | } else if tof == "[]httptest.Responses" {
678 | if len(keys) <= 1 {
679 | continue
680 | }
681 | key1, _ := strconv.Atoi(keys[1])
682 | if r[i1].Value.([]Responses)[key1] != nil {
683 | r[i1].Value.([]Responses)[key1].Replace(strings.Join(keys[2:], "."), value, testType...)
684 | }
685 | }
686 | }
687 | }
688 |
689 | // schema
690 | func schema(j map[string]any) (Responses, error) {
691 | objs := Responses{}
692 | if j == nil {
693 | return objs, nil
694 | }
695 | for k, v := range j {
696 | if k == "" {
697 | continue
698 | }
699 | obj := schemaResponse(k, v)
700 | objs = append(objs, obj)
701 | }
702 | return objs, nil
703 | }
704 |
705 | // schemaResponse
706 | func schemaResponse(k string, v any) Response {
707 | obj := Response{}
708 | obj.Key = k
709 |
710 | if v == nil {
711 | return obj
712 | }
713 | typeName := reflect.TypeOf(v).String()
714 | switch typeName {
715 | case "bool":
716 | obj.Value = v.(bool)
717 | case "string":
718 | if obj.Key == "createdAt" || obj.Key == "updatedAt" || obj.Key == "deletedAt" {
719 | obj.Type = "notempty"
720 | } else if NotEmptyKey.Len() > 0 && NotEmptyKey.Check(obj.Key) {
721 | obj.Type = "notempty"
722 | } else {
723 | obj.Value = v.(string)
724 | }
725 | case "uint":
726 | obj.Value = v.(uint)
727 | case "int":
728 | obj.Value = v.(int)
729 | case "int32":
730 | obj.Value = v.(int32)
731 | case "float64":
732 | obj.Value = v.(float64)
733 | case "[]string":
734 | obj.Value = v.([]string)
735 | case "map[string]interface {}":
736 | if value, _ := schema(v.(map[string]any)); value != nil {
737 | obj.Value = value
738 | }
739 | case "[]interface {}":
740 | list := []Responses{}
741 | for _, v1 := range v.([]any) {
742 | listObj := Responses{}
743 | if v3, ok := v1.(map[string]any); ok {
744 | for k2, v2 := range v3 {
745 | listObj = append(listObj, schemaResponse(k2, v2))
746 | }
747 | list = append(list, listObj)
748 | obj.Value = list
749 | } else if _, ok := v1.(string); ok {
750 | obj.Value = v
751 | }
752 | }
753 |
754 | default:
755 | fmt.Printf("schemaResponse key:%s valueTypeName:%s\n", k, typeName)
756 | }
757 | return obj
758 | }
759 |
--------------------------------------------------------------------------------
/httptest/respose_test.go:
--------------------------------------------------------------------------------
1 | package httptest
2 |
3 | import (
4 | "encoding/json"
5 | "log"
6 | "net/http"
7 | "reflect"
8 | "testing"
9 |
10 | "github.com/gavv/httpexpect/v2"
11 | "github.com/gin-gonic/gin"
12 | "github.com/snowlyg/helper/arr"
13 | )
14 |
15 | func TestIdKeys(t *testing.T) {
16 | want := Responses{
17 | {Key: "id", Value: 0, Type: "ge"},
18 | }
19 | t.Run("Test id keys", func(t *testing.T) {
20 | idKeys := IdKeys()
21 | if !reflect.DeepEqual(want, idKeys) {
22 | t.Errorf("IdKeys want %+v but get %+v", want, idKeys)
23 | }
24 | })
25 | }
26 |
27 | func TestHttpTest(t *testing.T) {
28 | engine := gin.New()
29 | // Add /example route via handler function to the gin instance
30 | handler := GinHandler(engine)
31 | // Create httpexpect instance
32 | e := httpexpect.WithConfig(httpexpect.Config{
33 | Client: &http.Client{
34 | Transport: httpexpect.NewBinder(handler),
35 | Jar: httpexpect.NewCookieJar(),
36 | },
37 | Reporter: httpexpect.NewAssertReporter(t),
38 | Printers: []httpexpect.Printer{
39 | httpexpect.NewDebugPrinter(t, true),
40 | },
41 | })
42 | pageKeys := Responses{{Key: "message", Value: "OK"}, {Key: "status", Value: 200}, {Key: "data", Value: Response{Key: "message", Value: "pong"}}}
43 | value := e.GET("/example").Expect().Status(http.StatusOK).JSON()
44 |
45 | Test(value, pageKeys)
46 | }
47 |
48 | func TestHttpTestArray(t *testing.T) {
49 | engine := gin.New()
50 | // Add /example route via handler function to the gin instance
51 | handler := GinHandler(engine)
52 | // Create httpexpect instance
53 | e := httpexpect.WithConfig(httpexpect.Config{
54 | Client: &http.Client{
55 | Transport: httpexpect.NewBinder(handler),
56 | Jar: httpexpect.NewCookieJar(),
57 | },
58 | Reporter: httpexpect.NewAssertReporter(t),
59 | Printers: []httpexpect.Printer{
60 | httpexpect.NewDebugPrinter(t, true),
61 | },
62 | })
63 | pageKeys := []string{"1", "2"}
64 | value := e.GET("/array").Expect().Status(http.StatusOK).JSON()
65 | Test(value, pageKeys)
66 | }
67 |
68 | func TestHttpScan(t *testing.T) {
69 | engine := gin.New()
70 | // Add /example route via handler function to the gin instance
71 | handler := GinHandler(engine)
72 | // Create httpexpect instance
73 | e := httpexpect.WithConfig(httpexpect.Config{
74 | Client: &http.Client{
75 | Transport: httpexpect.NewBinder(handler),
76 | Jar: httpexpect.NewCookieJar(),
77 | },
78 | Reporter: httpexpect.NewAssertReporter(t),
79 | Printers: []httpexpect.Printer{
80 | httpexpect.NewDebugPrinter(t, true),
81 | },
82 | })
83 | pageKeys := Responses{{Key: "message", Value: ""}}
84 | obj := e.GET("/example").Expect().Status(http.StatusOK).JSON().Object()
85 |
86 | Scan(obj, pageKeys)
87 | x := pageKeys.GetString("data.message")
88 | if x != "pong" {
89 | t.Errorf("Scan want get pong but get %s", x)
90 | }
91 | }
92 |
93 | func TestSchema(t *testing.T) {
94 | wantJson := `{
95 | "status": 200,
96 | "data": {
97 | "list": [
98 | {
99 | "createdAt": "2025-03-21T16:27:20+08:00",
100 | "deletedAt": "",
101 | "updatedAt": "2025-03-21T16:27:20+08:00",
102 | "dev_remark": "",
103 | "pac_room_id": 1,
104 | "room_desc": "1413-301"
105 | }
106 | ],
107 | "total": 1,
108 | "page": 1,
109 | "pageSize": 20
110 | },
111 | "message": "OK"
112 | }`
113 | res, err := Schema([]byte(wantJson))
114 | if err != nil {
115 | t.Fatal(err.Error())
116 | }
117 | if res.GetInt("status") != 200 {
118 | t.Errorf("status want %d but get %d", 200, res.GetInt("status"))
119 | }
120 | if res.GetString("message") != "OK" {
121 | t.Errorf("message want %s but get %s", "OK", res.GetString("message"))
122 | }
123 | data := res.GetResponse("data")
124 | if data != nil {
125 | if data.GetInt("total") != 1 {
126 | t.Errorf("total want %d but get %d", 1, data.GetInt("total"))
127 | }
128 | if data.GetInt("page") != 1 {
129 | t.Errorf("page want %d but get %d", 1, data.GetInt("page"))
130 | }
131 | if data.GetInt("pageSize") != 20 {
132 | t.Errorf("pageSize want %d but get %d", 20, data.GetInt("pageSize"))
133 | }
134 | list := data.GetResponses("list")
135 | if len(list) != 1 {
136 | t.Errorf("list len want %d but get %d", 1, len(list))
137 | }
138 | first := list[0]
139 | if first.GetId("pac_room_id") != 1 {
140 | t.Errorf("pac_room_id want %d but get %d", 1, first.GetId("pac_room_id"))
141 | }
142 | roomDesc := first.GetString("room_desc")
143 | if roomDesc != "1413-301" {
144 | t.Errorf("room_desc want %s but get '%s'", "1413-301", roomDesc)
145 | }
146 | if first.GetString("dev_remark") != "" {
147 | t.Errorf("dev_remark want %s but get '%s'", "", first.GetString("dev_remark"))
148 | }
149 | keys := arr.NewCheckArrayType(0)
150 | for _, v := range first.Keys() {
151 | keys.Add(v)
152 | }
153 |
154 | for _, k := range []string{"createdAt", "deletedAt", "updatedAt", "dev_remark", "pac_room_id", "room_desc"} {
155 | if !keys.Check(k) {
156 | t.Errorf("%s not in keys", k)
157 | }
158 | }
159 | }
160 | }
161 |
162 | func TestSchemaResponse(t *testing.T) {
163 | data := `{
164 | "status": 200,
165 | "message": "OK"
166 | }`
167 | j := map[string]any{}
168 | if err := json.Unmarshal([]byte(data), &j); err != nil {
169 | t.Error(err.Error())
170 | }
171 | log.Printf("j %+v\n", j)
172 | wantKey := "data"
173 | resp := schemaResponse(wantKey, j)
174 |
175 | if value, ok := resp.Value.(Responses); !ok {
176 | t.Error("schema response return value not Responses")
177 | } else {
178 | keys := arr.NewCheckArrayType(0)
179 | for _, v := range value {
180 | keys.Add(v.Key)
181 | if v.Key == "message" {
182 | wantValue := "OK"
183 | if v.Value != wantValue {
184 | t.Errorf("%s Value want '%v' but get '%v'", v.Key, wantValue, v.Value)
185 | }
186 | } else if v.Key == "status" {
187 | var wantValue float64 = 200
188 | if v.Value != wantValue {
189 | t.Errorf("%s Value want '%v' but get '%v'", v.Key, wantValue, v.Value)
190 | }
191 | } else {
192 | t.Errorf("key %s is in response", v.Key)
193 | }
194 | }
195 | if !keys.Check("message") {
196 | t.Error("message not in keys")
197 | }
198 | if !keys.Check("status") {
199 | t.Error("status not in keys")
200 | }
201 | }
202 | }
203 |
--------------------------------------------------------------------------------
/loadtls.go:
--------------------------------------------------------------------------------
1 | package admin
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/gin-gonic/gin"
7 | "github.com/unrolled/secure"
8 | )
9 |
10 | // LoadTls
11 | func LoadTls() gin.HandlerFunc {
12 | return func(c *gin.Context) {
13 | middleware := secure.New(secure.Options{
14 | SSLRedirect: true,
15 | SSLHost: "127.0.0.1:443",
16 | })
17 | err := middleware.Process(c.Writer, c.Request)
18 | if err != nil {
19 | fmt.Println(err)
20 | return
21 | }
22 | c.Next()
23 | }
24 | }
25 |
--------------------------------------------------------------------------------
/menu.go:
--------------------------------------------------------------------------------
1 | package admin
2 |
3 | import "gorm.io/gorm"
4 |
// Menu is a gorm model for a menu entry, stored in the "menus" table (see
// TableName). Meta is embedded anonymously, so its fields are flattened into
// Menu's JSON object and, for fields without gorm:"-", into its columns.
type Menu struct {
	gorm.Model
	Path       string `json:"path"`
	Component  string `json:"component"`
	Redirect   string `json:"redirect"`
	Hidden     bool   `json:"hidden"`
	AlwaysShow bool   `json:"alwaysShow"`
	Meta
	// Children is not a database column (gorm:"-"); presumably filled in
	// memory when building a menu tree — TODO confirm against callers.
	Children []*Menu `json:"children" gorm:"-"`
}
15 |
16 | func (m *Menu) TableName() string {
17 | return "menus"
18 | }
19 |
// Meta holds display metadata for a menu entry. Menu embeds it anonymously,
// so these fields appear at the top level of a Menu's JSON.
type Meta struct {
	// Roles is not a database column (gorm:"-").
	Roles   []string `json:"roles" gorm:"-"`
	Title   string   `json:"title"`
	Icon    string   `json:"icon"`
	NoCache bool     `json:"noCache"`
}
26 |
--------------------------------------------------------------------------------
/migrate.go:
--------------------------------------------------------------------------------
1 | package admin
2 |
3 | import (
4 | "errors"
5 |
6 | "github.com/go-gormigrate/gormigrate/v2"
7 | "gorm.io/gorm"
8 | )
9 |
10 | // AddMigration add *gormigrate.Migration
11 | func (ws *WebServe) AddMigration(m ...*gormigrate.Migration) {
12 | ws.items = append(ws.items, m...)
13 | }
14 |
15 | // MigrationLen length of MigrationCollection
16 | func (ws *WebServe) MigrationLen() int {
17 | return len(ws.items)
18 | }
19 |
20 | // Refresh refresh migration
21 | func (mc *WebServe) Refresh() error {
22 | if mc.getFirstMigration() == "" {
23 | return nil
24 | }
25 | err := mc.rollbackTo(mc.getFirstMigration())
26 | if !errors.Is(err, gormigrate.ErrMigrationIDDoesNotExist) && err != nil {
27 | return err
28 | }
29 | return mc.Migrate()
30 | }
31 |
32 | // rollbackTo roolback migration to migrationId
33 | func (ws *WebServe) rollbackTo(migrationId string) error {
34 | return ws.m.RollbackTo(migrationId)
35 | }
36 |
37 | // Rollback roolback migrations
38 | func (ws *WebServe) Rollback(migrationId string) error {
39 | if ws.MigrationLen() == 0 {
40 | return nil
41 | }
42 | if migrationId == "" {
43 | err := ws.rollbackLast()
44 | if !errors.Is(err, gormigrate.ErrMigrationIDDoesNotExist) && err != nil {
45 | return err
46 | }
47 | return nil
48 | }
49 | err := ws.rollbackTo(migrationId)
50 | if !errors.Is(err, gormigrate.ErrMigrationIDDoesNotExist) && err != nil {
51 | return err
52 | }
53 | return nil
54 | }
55 |
56 | // rollbackLast roolback the lasted migration
57 | func (ws *WebServe) rollbackLast() error {
58 | return ws.m.RollbackLast()
59 | }
60 |
61 | // Migrate exec migration cmd
62 | func (ws *WebServe) Migrate() error {
63 | // add migrations
64 | ws.AddMigration(
65 | &gormigrate.Migration{
66 | ID: "init_system",
67 | Migrate: func(tx *gorm.DB) error {
68 | return tx.AutoMigrate(&Router{}, &Menu{})
69 | },
70 | Rollback: func(tx *gorm.DB) error {
71 | return tx.Migrator().DropTable(new(Router).TableName(), new(Menu).TableName())
72 | },
73 | },
74 | // add more migrations
75 | )
76 | if ws.m == nil {
77 | ws.m = gormigrate.New(ws.db, gormigrate.DefaultOptions, ws.items)
78 | }
79 | if err := ws.m.Migrate(); err != nil {
80 | return err
81 | }
82 | return nil
83 | }
84 |
85 | // getFirstMigration get first migration's id
86 | func (ws *WebServe) getFirstMigration() string {
87 | if ws.MigrationLen() == 0 {
88 | return ""
89 | }
90 | return ws.items[0].ID
91 | }
92 |
--------------------------------------------------------------------------------
/migrate_test.go:
--------------------------------------------------------------------------------
1 | package admin
2 |
--------------------------------------------------------------------------------
/model.go:
--------------------------------------------------------------------------------
1 | package admin
2 |
import (
	"errors"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/snowlyg/helper/str"
	"gorm.io/gorm"
)
12 |
// ErrMsg is a code/message pair, serialised as {"code": ..., "message": ...}.
type ErrMsg struct {
	Code int64  `json:"code"`
	Msg  string `json:"message"`
}

// Errors reported by the request helpers in this package.
var (
	// ErrParamValidate is returned by the Request binding methods in this
	// file whenever gin fails to bind the request.
	ErrParamValidate = errors.New("param unvalidate")
	// ErrPaginateParam is not used in this file; presumably reports invalid
	// pagination input elsewhere — confirm against callers.
	ErrPaginateParam = errors.New("paginate param unvalidate")
	// ErrUnSupportFramework is not used in this file; presumably reports an
	// unsupported web framework — confirm against callers.
	ErrUnSupportFramework = errors.New("unsupport framework")
)

// Model is the JSON shape of the common id and timestamp columns.
// NOTE(review): the timestamps are strings here; their format is not
// established in this file.
type Model struct {
	Id        uint   `json:"id"`
	UpdatedAt string `json:"updatedAt"`
	CreatedAt string `json:"createdAt"`
	DeletedAt string `json:"deletedAt"`
}
31 |
32 | // Paginate param for paginate query
33 | type Paginate struct {
34 | Page int `json:"page" form:"page"`
35 | PageSize int `json:"pageSize" form:"pageSize"`
36 | OrderBy string `json:"orderBy" form:"orderBy"`
37 | Sort string `json:"sort" form:"sort"`
38 | }
39 |
40 | func (req *Paginate) Request(ctx *gin.Context) error {
41 | if err := ctx.ShouldBind(req); err != nil {
42 | return ErrParamValidate
43 | }
44 | return nil
45 | }
46 |
47 | // PaginateScope paginate scope
48 | func (req *Paginate) PaginateScope() func(db *gorm.DB) *gorm.DB {
49 | return PaginateScope(req.Page, req.PageSize, req.Sort, req.OrderBy)
50 | }
51 |
52 | // IdScope
53 | func IdScope(id any) func(db *gorm.DB) *gorm.DB {
54 | return func(db *gorm.DB) *gorm.DB {
55 | return db.Where("id = ?", id)
56 | }
57 | }
58 |
59 | // InIdsScope
60 | func InIdsScope(ids []uint) func(db *gorm.DB) *gorm.DB {
61 | return func(db *gorm.DB) *gorm.DB {
62 | return db.Where("id in ?", ids)
63 | }
64 | }
65 |
66 | // InNamesScope
67 | func InNamesScope(names []string) func(db *gorm.DB) *gorm.DB {
68 | return func(db *gorm.DB) *gorm.DB {
69 | return db.Where("name in ?", names)
70 | }
71 | }
72 |
73 | // InUuidsScope
74 | func InUuidsScope(uuids []string) func(db *gorm.DB) *gorm.DB {
75 | return func(db *gorm.DB) *gorm.DB {
76 | return db.Where("uuid in ?", uuids)
77 | }
78 | }
79 |
80 | // NeIdScope
81 | func NeIdScope(id any) func(db *gorm.DB) *gorm.DB {
82 | return func(db *gorm.DB) *gorm.DB {
83 | return db.Where("id != ?", id)
84 | }
85 | }
86 |
87 | // PaginateScope return paginate scope for gorm
88 | func PaginateScope(page, pageSize int, sort, orderBy string) func(db *gorm.DB) *gorm.DB {
89 | return func(db *gorm.DB) *gorm.DB {
90 | pageSize := getPageSize(pageSize)
91 | offset := getOffset(page, pageSize)
92 | return db.Order(getOrderBy(sort, orderBy)).Offset(offset).Limit(pageSize)
93 | }
94 | }
95 |
// getOffset converts a 1-based page number into a row offset.
// Page 0 is treated as the first page, and any negative page yields -1.
func getOffset(page, pageSize int) int {
	switch {
	case page < 0:
		return -1
	case page == 0:
		return 0
	default:
		return (page - 1) * pageSize
	}
}
107 |
// getPageSize normalises a requested page size. It maps 0 to the default of
// 10, caps values at 100, and maps every negative value to -1.
func getPageSize(pageSize int) int {
	const (
		defaultSize = 10
		maxSize     = 100
	)
	if pageSize < 0 {
		return -1
	}
	if pageSize == 0 {
		return defaultSize
	}
	if pageSize > maxSize {
		return maxSize
	}
	return pageSize
}
120 |
121 | // getOrderBy
122 | func getOrderBy(sort, orderBy string) string {
123 | if sort == "" {
124 | sort = "desc"
125 | }
126 | if orderBy == "" {
127 | orderBy = "created_at"
128 | }
129 | return str.Join(orderBy, " ", sort)
130 | }
131 |
// Default messages used by the Ok*/Fail* helpers when no message is given.
const (
	ResponseOkMessage    = "OK"
	ResponseErrorMessage = "FAIL"
)

// Response is the JSON envelope written by Result. Code is serialised under
// the "status" key, and Data is omitted when it is nil.
type Response struct {
	Code int    `json:"status"`
	Data any    `json:"data,omitempty"`
	Msg  string `json:"message"`
}
142 |
143 | func Result(code int, data any, msg string, ctx *gin.Context) {
144 | ctx.JSON(http.StatusOK, Response{code, data, msg})
145 | }
146 |
147 | func Ok(ctx *gin.Context) {
148 | Result(http.StatusOK, map[string]any{}, ResponseOkMessage, ctx)
149 | }
150 |
151 | func OkWithMessage(message string, ctx *gin.Context) {
152 | Result(http.StatusOK, map[string]any{}, message, ctx)
153 | }
154 |
155 | func OkWithData(data any, ctx *gin.Context) {
156 | Result(http.StatusOK, data, ResponseOkMessage, ctx)
157 | }
158 |
159 | func OkWithDetailed(data any, message string, ctx *gin.Context) {
160 | Result(http.StatusOK, data, message, ctx)
161 | }
162 |
163 | func Fail(ctx *gin.Context) {
164 | Result(http.StatusBadRequest, map[string]any{}, ResponseErrorMessage, ctx)
165 | }
166 |
167 | func UnauthorizedFailWithMessage(message string, ctx *gin.Context) {
168 | Result(http.StatusUnauthorized, map[string]any{}, message, ctx)
169 | }
170 |
171 | func UnauthorizedFailWithDetailed(data any, message string, ctx *gin.Context) {
172 | Result(http.StatusUnauthorized, data, message, ctx)
173 | }
174 |
175 | func ForbiddenFailWithMessage(message string, ctx *gin.Context) {
176 | Result(http.StatusForbidden, map[string]any{}, message, ctx)
177 | }
178 |
179 | func FailWithMessage(message string, ctx *gin.Context) {
180 | Result(http.StatusBadRequest, map[string]any{}, message, ctx)
181 | }
182 |
183 | func FailWithDetailed(data any, message string, ctx *gin.Context) {
184 | Result(http.StatusBadRequest, data, message, ctx)
185 | }
186 |
// PageResult is the response payload of a paginated list; List is omitted
// when nil.
type PageResult struct {
	List     any   `json:"list,omitempty"`
	Total    int64 `json:"total"`
	Page     int   `json:"page"`
	PageSize int   `json:"pageSize"`
}

// BaseResponse carries an id and creation/update timestamps; presumably
// embedded in response DTOs — confirm against callers.
type BaseResponse struct {
	Id        uint       `json:"id"`
	CreatedAt *time.Time `json:"createdAt"`
	UpdatedAt *time.Time `json:"updatedAt"`
}

// PageInfo is a common paging input structure.
// NOTE(review): gin's default binding validator reads `binding` tags, not
// `validate`, so these `validate:"required"` tags are only enforced if the
// value is also passed through Validator.Translate — confirm callers do so.
type PageInfo struct {
	Page     int    `json:"page" form:"page" validate:"required"`
	PageSize int    `json:"pageSize" form:"pageSize" validate:"required"`
	OrderBy  string `json:"orderBy" form:"orderBy"`
	SortBy   string `json:"sortBy" form:"sortBy"`
}

// IdsBinding binds a list of ids from the request body or form.
// NOTE(review): as with PageInfo, the `validate` tag is not read by gin's
// ShouldBind; it only takes effect through Validator.Translate.
type IdsBinding struct {
	Ids []uint `json:"ids" form:"ids" validate:"required,dive,required"`
}
211 |
212 | // Request get id data form the context of every query
213 | func (req *IdsBinding) Request(ctx *gin.Context) error {
214 | if err := ctx.ShouldBind(req); err != nil {
215 | return ErrParamValidate
216 | }
217 | return nil
218 | }
219 |
220 | // Id the struct has used to get id form the context of every query
221 | type Id struct {
222 | Id uint `json:"id" uri:"id"`
223 | }
224 |
225 | // Request get id data form the context of every query
226 | func (req *Id) Request(ctx *gin.Context) error {
227 | if err := ctx.ShouldBindUri(req); err != nil {
228 | return ErrParamValidate
229 | }
230 | return nil
231 | }
232 |
233 | type Empty struct{}
234 |
--------------------------------------------------------------------------------
/router.go:
--------------------------------------------------------------------------------
1 | package admin
2 |
3 | import (
4 | "log"
5 | "net/http"
6 | "path/filepath"
7 | "strings"
8 |
9 | "github.com/snowlyg/helper/arr"
10 | "gorm.io/gorm"
11 | )
12 |
// Router is one HTTP route as persisted in the routers table. Children is
// excluded from the table (gorm:"-").
type Router struct {
	gorm.Model
	Path     string    `json:"path"`
	Title    string    `json:"title"`
	Group    string    `json:"group"`
	Method   string    `json:"method"`
	Children []*Router `json:"children" gorm:"-"`
}

// TableName returns the table name gorm uses for Router.
func (m *Router) TableName() string {
	return "routers"
}
25 |
26 | func (ws *WebServe) routers() {
27 | methodExcepts := strings.Split(ws.conf.Except.Method, ";")
28 | uriExcepts := strings.Split(ws.conf.Except.Uri, ";")
29 |
30 | // routeLen := len(ws.engine.Routes())
31 | // permRoutes := make([]*Router, 0, routeLen)
32 | // otherMethodTypes := make([]*Router, 0, routeLen)
33 |
34 | for _, r := range ws.engine.Routes() {
35 | // log.Printf("handler:%s, method:%s, path:%s\n", r.Handler, r.Method, r.Path)
36 | if strings.Contains(r.Path, "/*filepath") || r.Handler == "github.com/gin-gonic/gin.(*RouterGroup).createStaticHandler.func1" {
37 | continue
38 | }
39 | path := filepath.ToSlash(filepath.Clean(r.Path))
40 | route := &Router{
41 | Path: path,
42 | Title: path,
43 | Group: "",
44 | Method: r.Method,
45 | }
46 |
47 | httpStatusType := arr.NewCheckArrayType(4)
48 | httpStatusType.AddMutil(http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete)
49 | if !httpStatusType.Check(r.Method) {
50 | ws.otherRoutes = append(ws.otherRoutes, route)
51 | continue
52 | }
53 |
54 | if len(methodExcepts) > 0 && len(uriExcepts) > 0 && len(methodExcepts) == len(uriExcepts) {
55 | for i := range methodExcepts {
56 | if strings.EqualFold(r.Method, strings.ToLower(methodExcepts[i])) && strings.EqualFold(path, strings.ToLower(uriExcepts[i])) {
57 | ws.otherRoutes = append(ws.otherRoutes, route)
58 | continue
59 | }
60 | }
61 | }
62 | ws.permRoutes = append(ws.permRoutes, route)
63 | }
64 |
65 | // log.Printf("permRoutes:%d other:%d\n", len(ws.permRoutes), len(ws.otherRoutes))
66 |
67 | if ws.db == nil {
68 | return
69 | }
70 |
71 | if len(ws.permRoutes) == 0 {
72 | return
73 | }
74 |
75 | // seed routers
76 | olds := []*Router{}
77 | dels := []uint{}
78 | adds := []*Router{}
79 | if err := ws.db.Model(&Router{}).Find(&olds).Error; err != nil {
80 | log.Printf("iris-admin: old router find get err:%s\n", err.Error())
81 | }
82 |
83 | if len(olds) == 0 {
84 | if err := ws.db.Create(&ws.permRoutes).Error; err == nil {
85 | log.Printf("iris-admin: add %d router \n", len(ws.permRoutes))
86 | }
87 | return
88 | }
89 |
90 | oldCheck := arr.NewCheckArrayType(len(olds))
91 | for _, old := range olds {
92 | oldCheck.Add(old.Path)
93 | found := false
94 | for _, a := range ws.permRoutes {
95 | if old.Path == a.Path && old.Method == a.Method {
96 | found = true
97 | break
98 | }
99 | }
100 | if !found {
101 | dels = append(dels, old.ID)
102 | }
103 | }
104 |
105 | if len(dels) > 0 {
106 | if err := ws.db.Delete(&Router{}, dels).Error; err == nil {
107 | log.Printf("iris-admin: delete %d router\n", len(dels))
108 | }
109 | }
110 |
111 | for _, r := range ws.permRoutes {
112 | if !oldCheck.Check(r.Path) {
113 | adds = append(adds, r)
114 | }
115 | }
116 |
117 | if len(adds) > 0 {
118 | if err := ws.db.Create(&adds).Error; err == nil {
119 | log.Printf("iris-admin: add %d router,old:%d\n", len(adds), len(olds))
120 | }
121 | }
122 |
123 | }
124 |
--------------------------------------------------------------------------------
/run.go:
--------------------------------------------------------------------------------
1 | package admin
2 |
// serve is the server behaviour Run needs. The platform-specific run
// functions (run_darwin.go, run_linux.go, run_windows.go) each return an
// implementation of it.
type serve interface {
	ListenAndServe() error
}
6 |
--------------------------------------------------------------------------------
/run_darwin.go:
--------------------------------------------------------------------------------
1 | package admin
2 |
3 | import (
4 | "fmt"
5 | "syscall"
6 | "time"
7 |
8 | "github.com/fvbock/endless"
9 | "github.com/gin-gonic/gin"
10 | )
11 |
12 | func run(address string, router *gin.Engine) serve {
13 | s := endless.NewServer(address, router)
14 | s.BeforeBegin = func(add string) {
15 | fmt.Printf("Actual pid is %d\n", syscall.Getpid())
16 | }
17 | s.ReadHeaderTimeout = 10 * time.Millisecond
18 | s.WriteTimeout = 10 * time.Second
19 | s.MaxHeaderBytes = 1 << 20
20 | return s
21 | }
22 |
--------------------------------------------------------------------------------
/run_linux.go:
--------------------------------------------------------------------------------
1 | package admin
2 |
3 | import (
4 | "time"
5 |
6 | "github.com/fvbock/endless"
7 | "github.com/gin-gonic/gin"
8 | )
9 |
10 | func run(address string, router *gin.Engine) serve {
11 | s := endless.NewServer(address, router)
12 | s.ReadHeaderTimeout = 10 * time.Millisecond
13 | s.WriteTimeout = 10 * time.Second
14 | s.MaxHeaderBytes = 1 << 20
15 | return s
16 | }
17 |
--------------------------------------------------------------------------------
/run_windows.go:
--------------------------------------------------------------------------------
1 | package admin
2 |
3 | import (
4 | "net/http"
5 | "time"
6 |
7 | "github.com/gin-gonic/gin"
8 | )
9 |
10 | func run(address string, router *gin.Engine) serve {
11 | return &http.Server{
12 | Addr: address,
13 | Handler: router,
14 | ReadTimeout: 10 * time.Second,
15 | WriteTimeout: 10 * time.Second,
16 | MaxHeaderBytes: 1 << 20,
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/server.go:
--------------------------------------------------------------------------------
1 | package admin
2 |
3 | import (
4 | "fmt"
5 | "log"
6 | "time"
7 |
8 | limit "github.com/aviddiviner/gin-limit"
9 | "github.com/casbin/casbin/v2"
10 | "github.com/gin-gonic/gin"
11 | "github.com/go-gormigrate/gormigrate/v2"
12 | "github.com/mattn/go-colorable"
13 | "github.com/snowlyg/iris-admin/conf"
14 | "github.com/snowlyg/iris-admin/e"
15 |
16 | "gorm.io/driver/mysql"
17 | "gorm.io/gorm"
18 | )
19 |
// Status values: StatusUnknown is the zero value (0), StatusTrue is 1 and
// StatusFalse is 2.
const (
	StatusUnknown int = iota
	StatusTrue
	StatusFalse
)
26 |
// WebServe bundles the gin engine with its configuration, database handle,
// casbin enforcer, request validator, migrations and collected routes.
type WebServe struct {
	// serve is embedded but not assigned in the code visible here; Run
	// obtains its server from run() instead.
	serve
	conf     *conf.Conf
	db       *gorm.DB
	enforcer *casbin.Enforcer
	engine   *gin.Engine
	iroutes  *gin.IRoutes

	validate *Validator

	// m is created by Migrate from the migrations registered in items.
	m     *gormigrate.Gormigrate
	items []*gormigrate.Migration

	// permRoutes and otherRoutes are filled by routers(): routes subject to
	// permission checks, and all remaining routes.
	permRoutes  []*Router
	otherRoutes []*Router
}
43 |
44 | // gormDb
45 | func gormDb(m *conf.Mysql) (*gorm.DB, error) {
46 | if m == nil {
47 | return nil, e.ErrConfigInvalid
48 | }
49 | if m.DbName == "" {
50 | return nil, e.ErrDbTableNameEmpty
51 | }
52 | mysqlConfig := mysql.Config{
53 | DSN: m.Dsn(),
54 | DefaultStringSize: 191,
55 | // DisableDatetimePrecision: true,
56 | // DontSupportRenameIndex: true,
57 | // DontSupportRenameColumn: true,
58 | // SkipInitializeWithVersion: false,
59 | }
60 | if db, err := gorm.Open(mysql.New(mysqlConfig)); err != nil {
61 | fmt.Printf("open mysql[%s] is fail:%v\n", m.Dsn(), err)
62 | return nil, err
63 | } else {
64 | sqlDB, err := db.DB()
65 | if err != nil {
66 | return nil, err
67 | }
68 | if err := sqlDB.Ping(); err != nil {
69 | log.Printf("ping mysql[%s] is fail:%v\n", m.Dsn(), err)
70 | return nil, err
71 | }
72 | sqlDB.SetMaxIdleConns(m.MaxIdleConns)
73 | sqlDB.SetMaxOpenConns(m.MaxOpenConns)
74 | return db, nil
75 | }
76 | }
77 |
78 | // NewServe
79 | func NewServe(c *conf.Conf) (*WebServe, error) {
80 |
81 | gin.SetMode(c.System.GinMode)
82 | app := gin.Default()
83 | if c.System.Tls {
84 | app.Use(LoadTls())
85 | }
86 | app.Use(c.CorsConf.Cors())
87 | // registerValidation()
88 | gin.DefaultWriter = colorable.NewColorableStdout()
89 | c.SetDefaultAddrAndTimeFormat()
90 | db, err := gormDb(c.Mysql)
91 | if err != nil {
92 | return nil, err
93 | }
94 |
95 | auth, err := c.GetEnforcer(db)
96 | if err != nil {
97 | return nil, err
98 | }
99 |
100 | ws := &WebServe{
101 | conf: c,
102 | engine: app,
103 | enforcer: auth,
104 | db: db,
105 | permRoutes: []*Router{},
106 | otherRoutes: []*Router{},
107 | }
108 | if err := ws.Migrate(); err != nil {
109 | return nil, err
110 | }
111 |
112 | switch c.Locale {
113 | case "en":
114 | ws.validate = newEn()
115 | case "zh":
116 | ws.validate = newZh()
117 | default:
118 | ws.validate = newZh()
119 | }
120 |
121 | ws.engine.Use(limit.MaxAllowed(50))
122 |
123 | return ws, nil
124 | }
125 |
// Engine returns the underlying *gin.Engine.
func (ws *WebServe) Engine() *gin.Engine {
	return ws.engine
}

// IRoutes returns the stored *gin.IRoutes.
// NOTE(review): the iroutes field is not assigned anywhere in the code
// visible here, so this presumably always returns nil — confirm before use.
func (ws *WebServe) IRoutes() *gin.IRoutes {
	return ws.iroutes
}

// Config returns the configuration the server was built with.
func (ws *WebServe) Config() *conf.Conf {
	return ws.conf
}

// SystemAddr returns the configured listen address (conf.System.Addr).
func (ws *WebServe) SystemAddr() string {
	return ws.conf.System.Addr
}

// Auth returns the casbin enforcer.
func (ws *WebServe) Auth() *casbin.Enforcer {
	return ws.enforcer
}

// DB returns the gorm database handle.
func (ws *WebServe) DB() *gorm.DB {
	return ws.db
}
154 |
155 | // // Deprecated: use nginx or apache instead.
156 | // func (ws *WebServe) AddWebStatic(staticAbsPath, webPrefix string, paths ...string) {
157 | // }
158 |
159 | // // Deprecated: use nginx or apache instead.
160 | // func (ws *WebServe) AddUploadStatic(webPrefix, staticAbsPath string) {
161 | // }
162 |
163 | // Run
164 | func (ws *WebServe) Run() {
165 | if ws.engine == nil {
166 | panic("init engine please")
167 | }
168 |
169 | // ws.Engine().NoRoute(func(ctx *gin.Context) {
170 | // // excepte for /v0 /v1 and so on
171 | // reg := `^/v[0-9]+$|^(/v[0-9]+)/`
172 | // ok, _ := regexp.MatchString(reg, ctx.Request.RequestURI)
173 | // if ok {
174 | // ctx.Writer.WriteHeader(http.StatusNotFound)
175 | // ctx.Writer.Flush()
176 | // return
177 | // }
178 |
179 | // var indexFile []byte
180 | // for _, wp := range ws.statics {
181 | // // match /admin or /admin/***
182 | // reg := str.Join("^", wp.Prefix, "$|^(", wp.Prefix, ")/")
183 | // ok, err := regexp.MatchString(reg, ctx.Request.RequestURI)
184 | // if err != nil || !ok {
185 | // continue
186 | // }
187 | // indexFile = wp.IndexFile
188 | // }
189 |
190 | // ctx.Writer.WriteHeader(http.StatusOK)
191 | // ctx.Writer.Write(indexFile)
192 |
193 | // ctx.Writer.Header().Add("Accept", "text/html")
194 | // ctx.Writer.Flush()
195 | // })
196 |
197 | ws.routers()
198 |
199 | systemAddr := ws.SystemAddr()
200 | s := run(systemAddr, ws.engine)
201 | time.Sleep(10 * time.Microsecond)
202 |
203 | log.Printf("listen on: http://%s\n", systemAddr)
204 |
205 | s.ListenAndServe()
206 | }
207 |
--------------------------------------------------------------------------------
/server_test.go:
--------------------------------------------------------------------------------
1 | package admin
2 |
3 | import (
4 | "io"
5 | "net/http"
6 | "os"
7 | "testing"
8 | "time"
9 |
10 | "github.com/snowlyg/iris-admin/conf"
11 | )
12 |
13 | func TestStart(t *testing.T) {
14 | go func() {
15 | os.Setenv("IRIS_ADMIN_WEB_ADDR", "127.0.0.1:18088")
16 | c := conf.NewConf()
17 | if serve, err := NewServe(c); err != nil {
18 | t.Error(err.Error())
19 | } else {
20 | serve.Engine()
21 | serve.Run()
22 | }
23 | }()
24 |
25 | time.Sleep(3 * time.Second)
26 |
27 | resp, err := http.Get("http://127.0.0.1:18088")
28 | if err != nil {
29 | t.Errorf("test web start get %v", err)
30 | }
31 | defer resp.Body.Close()
32 |
33 | _, err = io.ReadAll(resp.Body)
34 | if err != nil {
35 | t.Errorf("test web start get %v", err)
36 | }
37 |
38 | if resp.StatusCode != http.StatusNotFound {
39 | t.Errorf("test web start want [%d] but get [%d]", http.StatusNotFound, resp.StatusCode)
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/validate.go:
--------------------------------------------------------------------------------
1 | package admin
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 |
7 | "github.com/gin-gonic/gin"
8 | "github.com/go-playground/locales/en"
9 | "github.com/go-playground/locales/zh"
10 | ut "github.com/go-playground/universal-translator"
11 | "github.com/go-playground/validator/v10"
12 | en_translations "github.com/go-playground/validator/v10/translations/en"
13 | zh_translations "github.com/go-playground/validator/v10/translations/zh"
14 | )
15 |
// Validator wraps a go-playground validator together with the translator used
// to render its error messages in a single locale (built by newZh or newEn).
type Validator struct {
	uni      *ut.UniversalTranslator
	validate *validator.Validate
	trans    ut.Translator
}
21 |
22 | func newZh() *Validator {
23 | zh := zh.New()
24 | uni := ut.New(zh, zh)
25 |
26 | // this is usually know or extracted from http 'Accept-Language' header
27 | // also see uni.FindTranslator(...)
28 | trans, _ := uni.GetTranslator("zh")
29 |
30 | validate := validator.New()
31 | zh_translations.RegisterDefaultTranslations(validate, trans)
32 | return &Validator{uni: uni, validate: validate, trans: trans}
33 | }
34 |
35 | func newEn() *Validator {
36 | en := en.New()
37 | uni := ut.New(en, en)
38 |
39 | // this is usually know or extracted from http 'Accept-Language' header
40 | // also see uni.FindTranslator(...)
41 | trans, _ := uni.GetTranslator("en")
42 |
43 | validate := validator.New()
44 | en_translations.RegisterDefaultTranslations(validate, trans)
45 | return &Validator{uni: uni, validate: validate, trans: trans}
46 | }
47 |
48 | type IdBinding struct {
49 | Id uint `json:"id" uri:"id" validate:"required"`
50 | }
51 |
52 | // ShouldBindUri binds the passed struct pointer using the specified binding engine.
53 | func (val *Validator) ShouldBindUri(ctx *gin.Context) (uint, error) {
54 | var id IdBinding
55 | if e := ctx.ShouldBindUri(&id); e != nil {
56 | return 0, e
57 | }
58 | if e := val.Translate(id); e != nil {
59 | return 0, e
60 | }
61 | return id.Id, nil
62 | }
63 |
64 | func (val *Validator) Translate(s any) error {
65 | if err := val.validate.Struct(s); err != nil {
66 | // translate all error at once
67 | errs := err.(validator.ValidationErrors)
68 | for _, v := range errs.Translate(val.trans) {
69 | if v != "" {
70 | return errors.New(v)
71 | }
72 | }
73 | }
74 | return nil
75 | }
76 |
77 | func (val *Validator) ValidateMap(data map[string]any, rules map[string]any) error {
78 | mapErrs := val.validate.ValidateMap(data, rules)
79 | for mk, err := range mapErrs {
80 | if err != nil {
81 | // translate all error at once
82 | errs := err.(validator.ValidationErrors)
83 | for _, v := range errs.Translate(val.trans) {
84 | if v != "" {
85 | return fmt.Errorf("%s%s", mk, v)
86 | }
87 | }
88 | }
89 | }
90 | return nil
91 | }
92 |
--------------------------------------------------------------------------------