├── .gitignore
├── LICENSE
├── README.md
├── args.go
├── benchmark_test.go
├── bind.go
├── bind_test.go
├── common.go
├── context.go
├── context_test.go
├── fs.go
├── gen.go
├── gen_test.go
├── go.mod
├── go.sum
├── h2_test.go
├── h3_test.go
├── h3client.go
├── img
└── benchmark.png
├── log_test.go
├── logger
├── color.go
├── color_test.go
└── log.go
├── logrus.log
├── middleware
├── basicAuth.go
├── basicAuth_test.go
├── cors.go
├── cors_test.go
├── csrf.go
├── csrf_test.go
├── gzip.go
├── gzip_test.go
├── jwt.go
├── jwt_test.go
├── logger.go
├── logger_test.go
├── multiMiddle_test.go
├── proxy.go
├── rateLimit.go
├── rateLimit_test.go
├── recovery.go
├── recovery_test.go
├── requestID.go
├── requestID_test.go
├── secure.go
├── secure_test.go
└── trace.go
├── pprof.go
├── pprof_test.go
├── response_overide.go
├── router.go
├── router_test.go
├── test
├── host.pb.go
└── host.proto
├── tree.go
├── yee.go
└── yee_test.go
/.gitignore:
--------------------------------------------------------------------------------
1 | # Binaries for programs and plugins
2 | *.exe
3 | *.exe~
4 | *.dll
5 | *.so
6 | *.dylib
7 | dist
8 | # Test binary, built with `go test -c`
9 | *.test
10 | testing/h5
11 | .idea
12 | .sum
13 | # Output of the go coverage tool, specifically when used with LiteIDE
14 | *.out
15 |
16 | # Dependency directories (remove the comment below to include it)
17 | # vendor/
18 | *.pem
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Henry Yee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Yee
2 |
3 | 
4 | 
5 |
6 | 🦄 A web framework for Go, easier & faster.
7 |
8 | This is a framework built for learning purposes; its code draws on Echo and Gin.
9 |
10 | - Faster HTTP router
11 | - Build RESTful APIs
12 | - Group APIs
13 | - Extensible middleware framework
14 | - Define middleware at root, group or route level
15 | - Data binding for URI Query, JSON, XML, Protocol Buffer3 and form payload
16 | - HTTP/2 (H2C) and HTTP/3 (QUIC) support
17 |
18 | # Supported Go versions
19 |
20 | Yee is available as a Go module. You need Go 1.16 or newer.
21 |
22 | ## Example
23 |
24 | #### Quick start
25 |
26 | ```go
27 | file, err := os.OpenFile("logrus.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
28 |
29 | if err != nil {
30 | return
31 | }
32 |
33 | y := yee.New()
34 |
35 | y.SetLogLevel(logger.Warning)
36 |
37 | y.SetLogOut(file)
38 |
39 | y.Use(Logger())
40 |
41 | y.Static("/assets", "dist/assets")
42 |
43 | y.GET("/", func(c yee.Context) (err error) {
44 | return c.HTMLTml(http.StatusOK, "dist/index.html")
45 | })
46 |
47 | y.GET("/hello", func(c yee.Context) error {
48 | return c.String(http.StatusOK, "Hello Gee")
49 | })
50 |
51 | y.POST("/test", func(c yee.Context) (err error) {
52 | u := new(p)
53 | if err := c.Bind(u); err != nil {
54 | return c.JSON(http.StatusOK, err.Error())
55 | }
56 | return c.JSON(http.StatusOK, u.Test)
57 | })
58 |
59 | y.Run(":9000")
60 | ```
61 |
62 | #### API
63 |
64 | Provides GET, POST, PUT, DELETE, HEAD, OPTIONS, TRACE and PATCH route methods.
65 |
66 | ```go
67 |
68 | y := yee.New()
69 |
70 | y.GET("/someGet", handler)
71 | y.POST("/somePost", handler)
72 | y.PUT("/somePut", handler)
73 | y.DELETE("/someDelete", handler)
74 | y.PATCH("/somePatch", handler)
75 | y.HEAD("/someHead", handler)
76 | y.OPTIONS("/someOptions", handler)
77 |
78 | y.Run(":8000")
79 |
80 | ```
81 |
82 | #### Restful
83 |
84 | You can use the Any and Restful methods to implement a RESTful API.
85 |
86 | + Any
87 |
88 | ```go
89 | y := yee.New()
90 |
91 | y.Any("/any", handler)
92 |
93 | y.Run(":8000")
94 |
95 | // All request methods for the same URL use the same handler
96 |
97 | ```
98 |
99 | + Restful
100 |
101 | ```go
102 |
103 | func userUpdate(c yee.Context) (err error) {
104 | return c.String(http.StatusOK, "updated")
105 | }
106 |
107 | func userFetch(c yee.Context) (err error) {
108 | return c.String(http.StatusOK, "get it")
109 | }
110 |
111 | func testRestfulApi() yee.RestfulApi {
112 | return yee.RestfulApi{
113 | Get: userFetch,
114 | Post: userUpdate,
115 | }
116 | }
117 |
118 | y := yee.New()
119 |
120 | y.Restful("/", testRestfulApi())
121 |
122 | y.Run(":8000")
123 |
124 | // Each request method for the same URL uses its own handler
125 |
126 |
127 | ```
128 |
129 | ## Middleware
130 |
131 | - basic auth
132 | - cors
133 | - csrf
134 | - gzip
135 | - jwt
136 | - logger
137 | - rate limit
138 | - recovery
139 | - secure
140 | - request id
141 |
142 | ## Benchmark
143 |
144 | Resource
145 | - CPU: i7-9750H
146 | - Memory: 16G
147 | - OS: macOS 10.15.5
148 |
149 | Date: 2020/06/17
150 |
151 | 
152 |
153 |
154 | ## License
155 |
156 | MIT
157 |
--------------------------------------------------------------------------------
/args.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | import (
4 | "errors"
5 | )
6 |
// Header types
//
// Names of the HTTP header fields used across the framework, grouped as
// general request/response headers, CORS (access control) headers and
// security-related headers.
const (
	HeaderSecWebSocketProtocol = "Sec-Websocket-Protocol"
	HeaderAccept = "Accept"
	HeaderAcceptEncoding = "Accept-Encoding"
	HeaderAuthorization = "Authorization"
	HeaderContentDisposition = "Content-Disposition"
	HeaderContentEncoding = "Content-Encoding"
	HeaderContentLength = "Content-Length"
	HeaderContentType = "Content-Type"
	HeaderCookie = "Cookie"
	HeaderSetCookie = "Set-Cookie"
	HeaderIfModifiedSince = "If-Modified-Since"
	HeaderLastModified = "Last-Modified"
	HeaderLocation = "Location"
	HeaderUpgrade = "Upgrade"
	HeaderConnection = "Connection"
	HeaderVary = "Vary"
	HeaderWWWAuthenticate = "WWW-Authenticate"
	HeaderXForwardedFor = "X-Forwarded-For"
	HeaderXForwardedProto = "X-Forwarded-Proto"
	HeaderXForwardedProtocol = "X-Forwarded-Protocol"
	HeaderXForwardedSsl = "X-Forwarded-Ssl"
	HeaderXUrlScheme = "X-Url-Scheme"
	HeaderXHTTPMethodOverride = "X-HTTP-Method-Override"
	HeaderXRealIP = "X-Real-IP"
	HeaderXRequestID = "X-Request-ID"
	HeaderXRequestedWith = "X-Requested-With"
	HeaderServer = "Server"
	HeaderOrigin = "Origin"

	// Access control (CORS request and response headers)
	HeaderAccessControlRequestMethod = "Access-Control-Request-Method"
	HeaderAccessControlRequestHeaders = "Access-Control-Request-Headers"
	HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin"
	HeaderAccessControlAllowMethods = "Access-Control-Allow-Methods"
	HeaderAccessControlAllowHeaders = "Access-Control-Allow-Headers"
	HeaderAccessControlAllowCredentials = "Access-Control-Allow-Credentials"
	HeaderAccessControlExposeHeaders = "Access-Control-Expose-Headers"
	HeaderAccessControlMaxAge = "Access-Control-Max-Age"

	// Security (response hardening headers and the CSRF token header)
	HeaderStrictTransportSecurity = "Strict-Transport-Security"
	HeaderXContentTypeOptions = "X-Content-Type-Options"
	HeaderXXSSProtection = "X-XSS-Protection"
	HeaderXFrameOptions = "X-Frame-Options"
	HeaderContentSecurityPolicy = "Content-Security-Policy"
	HeaderContentSecurityPolicyReportOnly = "Content-Security-Policy-Report-Only"
	HeaderXCSRFToken = "X-CSRF-Token"
	HeaderReferrerPolicy = "Referrer-Policy"
)
58 |
const (
	// defaultMemory is 32 MiB (32 << 20 bytes).
	// NOTE(review): presumably the in-memory limit for multipart form
	// parsing — confirm at its call sites.
	defaultMemory = 32 << 20 // 32 MB
	// indexPage is the file name "index.html".
	// NOTE(review): presumably served for directory requests by the static
	// file handler — confirm in fs.go.
	indexPage = "index.html"
	// defaultIndent is an indentation string.
	// NOTE(review): presumably used for indented (pretty-printed) response
	// encoding — confirm at its call sites.
	defaultIndent = " "
)

const (
	// StatusCodeContextCanceled is the non-standard HTTP status code 499.
	// NOTE(review): by its name, presumably reported when the request
	// context is canceled before a response is produced (499 follows nginx's
	// "Client Closed Request" convention) — confirm against its users.
	StatusCodeContextCanceled = 499
)
68 |
// MIME types
//
// Media type strings, e.g. matched against a request's Content-Type when
// binding. Each ...CharsetUTF8 variant is the base type followed by
// "; charset=UTF-8".
const (
	MIMEApplicationJSON = "application/json"
	MIMEApplicationJSONCharsetUTF8 = MIMEApplicationJSON + "; " + charsetUTF8
	MIMEApplicationJavaScript = "application/javascript"
	MIMEApplicationJavaScriptCharsetUTF8 = MIMEApplicationJavaScript + "; " + charsetUTF8
	MIMEApplicationXML = "application/xml"
	MIMEApplicationXMLCharsetUTF8 = MIMEApplicationXML + "; " + charsetUTF8
	MIMETextXML = "text/xml"
	MIMETextXMLCharsetUTF8 = MIMETextXML + "; " + charsetUTF8
	MIMEApplicationForm = "application/x-www-form-urlencoded"
	MIMEApplicationProtobuf = "application/protobuf"
	MIMEApplicationMsgpack = "application/msgpack"
	MIMETextHTML = "text/html"
	MIMETextHTMLCharsetUTF8 = MIMETextHTML + "; " + charsetUTF8
	MIMETextPlain = "text/plain"
	MIMETextPlainCharsetUTF8 = MIMETextPlain + "; " + charsetUTF8
	MIMEMultipartForm = "multipart/form-data"
	MIMEOctetStream = "application/octet-stream"
)

const (
	// charsetUTF8 is the media-type parameter appended by the
	// ...CharsetUTF8 constants.
	charsetUTF8 = "charset=UTF-8"
	// serverName is the framework's name.
	// NOTE(review): presumably the value written to the Server header —
	// confirm at its call sites.
	serverName = "yee"
)
94 |
// Err types
//
// Sentinel errors returned by the framework; compare with errors.Is.
var (
	// ErrUnsupportedMediaType is returned by the default binder when the
	// request's Content-Type matches none of its decoders.
	ErrUnsupportedMediaType = errors.New("http server not support media type")
	// ErrValidatorNotRegistered presumably signals that validation was
	// requested with no validator configured — confirm at its users.
	ErrValidatorNotRegistered = errors.New("validator not registered")
	// ErrRendererNotRegistered presumably signals that template rendering
	// was requested with no renderer configured — confirm at its users.
	ErrRendererNotRegistered = errors.New("renderer not registered")
	// ErrInvalidRedirectCode presumably rejects redirects whose status code
	// is not a 3xx redirect code — confirm at its users.
	ErrInvalidRedirectCode = errors.New("invalid redirect status code")
	// ErrCookieNotFound presumably reports a lookup of a cookie the request
	// does not carry — confirm at its users.
	ErrCookieNotFound = errors.New("cookie not found")
	// ErrNotFoundHandler presumably backs the default 404 response when no
	// route matches — confirm at its users.
	ErrNotFoundHandler = errors.New("404 NOT FOUND")
	// ErrInvalidCertOrKeyType reports a TLS certificate or key argument that
	// is neither a string nor a []byte (per its message).
	ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte")
)
105 |
--------------------------------------------------------------------------------
/benchmark_test.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | import (
4 | "net/http"
5 | "net/http/httptest"
6 | "runtime"
7 | "testing"
8 | )
9 |
// Route is one benchmark fixture entry: an HTTP method paired with the
// route pattern it is registered under (":name" marks a path parameter).
type Route struct {
	Method, Path string
}
16 |
// gplusAPI is a small route table shaped after the Google+ REST API
// (people, activities, comments, moments). It mixes static segments with
// ":param" segments at several depths and is used by BenchmarkYeeGplusAPI.
var gplusAPI = []*Route{
	// People
	{"GET", "/people/:userId"},
	{"GET", "/people"},
	{"GET", "/activities/:activityId/people/:collection"},
	{"GET", "/people/:userId/people/:collection"},
	{"GET", "/people/:userId/openIdConnect"},

	// Activities
	{"GET", "/people/:userId/activities/:collection"},
	{"GET", "/activities/:activityId"},
	{"GET", "/activities"},

	// Comments
	{"GET", "/activities/:activityId/comments"},
	{"GET", "/comments/:commentId"},

	// Moments
	{"POST", "/people/:userId/moments/:collection"},
	{"GET", "/people/:userId/moments/:collection"},
	{"DELETE", "/moments/:id"},
}
39 |
// githubAPI is a large route table shaped after the GitHub REST API and is
// used by BenchmarkYeeGitHubAPI.
//
// NOTE(review): the commented-out entries are PATCH routes, "*" catch-all
// routes, and static segments that sit where another route has a parameter
// (e.g. /gists/public vs /gists/:id). They are presumably excluded because
// this router does not handle them — TODO confirm.
var githubAPI = []*Route{
	// OAuth Authorizations
	{"GET", "/authorizations"},
	{"GET", "/authorizations/:id"},
	{"POST", "/authorizations"},
	//{"PUT", "/authorizations/clients/:client_id"},
	//{"PATCH", "/authorizations/:id"},
	{"DELETE", "/authorizations/:id"},
	{"GET", "/applications/:client_id/tokens/:access_token"},
	{"DELETE", "/applications/:client_id/tokens"},
	{"DELETE", "/applications/:client_id/tokens/:access_token"},

	// Activity
	{"GET", "/events"},
	{"GET", "/repos/:owner/:repo/events"},
	{"GET", "/networks/:owner/:repo/events"},
	{"GET", "/orgs/:org/events"},
	{"GET", "/users/:user/received_events"},
	{"GET", "/users/:user/received_events/public"},
	{"GET", "/users/:user/events"},
	{"GET", "/users/:user/events/public"},
	{"GET", "/users/:user/events/orgs/:org"},
	{"GET", "/feeds"},
	{"GET", "/notifications"},
	{"GET", "/repos/:owner/:repo/notifications"},
	{"PUT", "/notifications"},
	{"PUT", "/repos/:owner/:repo/notifications"},
	{"GET", "/notifications/threads/:id"},
	//{"PATCH", "/notifications/threads/:id"},
	{"GET", "/notifications/threads/:id/subscription"},
	{"PUT", "/notifications/threads/:id/subscription"},
	{"DELETE", "/notifications/threads/:id/subscription"},
	{"GET", "/repos/:owner/:repo/stargazers"},
	{"GET", "/users/:user/starred"},
	{"GET", "/user/starred"},
	{"GET", "/user/starred/:owner/:repo"},
	{"PUT", "/user/starred/:owner/:repo"},
	{"DELETE", "/user/starred/:owner/:repo"},
	{"GET", "/repos/:owner/:repo/subscribers"},
	{"GET", "/users/:user/subscriptions"},
	{"GET", "/user/subscriptions"},
	{"GET", "/repos/:owner/:repo/subscription"},
	{"PUT", "/repos/:owner/:repo/subscription"},
	{"DELETE", "/repos/:owner/:repo/subscription"},
	{"GET", "/user/subscriptions/:owner/:repo"},
	{"PUT", "/user/subscriptions/:owner/:repo"},
	{"DELETE", "/user/subscriptions/:owner/:repo"},

	// Gists
	{"GET", "/users/:user/gists"},
	{"GET", "/gists"},
	//{"GET", "/gists/public"},
	//{"GET", "/gists/starred"},
	{"GET", "/gists/:id"},
	{"POST", "/gists"},
	//{"PATCH", "/gists/:id"},
	{"PUT", "/gists/:id/star"},
	{"DELETE", "/gists/:id/star"},
	{"GET", "/gists/:id/star"},
	{"POST", "/gists/:id/forks"},
	{"DELETE", "/gists/:id"},

	// Git Data
	{"GET", "/repos/:owner/:repo/git/blobs/:sha"},
	{"POST", "/repos/:owner/:repo/git/blobs"},
	{"GET", "/repos/:owner/:repo/git/commits/:sha"},
	{"POST", "/repos/:owner/:repo/git/commits"},
	//{"GET", "/repos/:owner/:repo/git/refs/*ref"},
	{"GET", "/repos/:owner/:repo/git/refs"},
	{"POST", "/repos/:owner/:repo/git/refs"},
	//{"PATCH", "/repos/:owner/:repo/git/refs/*ref"},
	//{"DELETE", "/repos/:owner/:repo/git/refs/*ref"},
	{"GET", "/repos/:owner/:repo/git/tags/:sha"},
	{"POST", "/repos/:owner/:repo/git/tags"},
	{"GET", "/repos/:owner/:repo/git/trees/:sha"},
	{"POST", "/repos/:owner/:repo/git/trees"},

	// Issues
	{"GET", "/issues"},
	{"GET", "/user/issues"},
	{"GET", "/orgs/:org/issues"},
	{"GET", "/repos/:owner/:repo/issues"},
	{"GET", "/repos/:owner/:repo/issues/:number"},
	{"POST", "/repos/:owner/:repo/issues"},
	//{"PATCH", "/repos/:owner/:repo/issues/:number"},
	{"GET", "/repos/:owner/:repo/assignees"},
	{"GET", "/repos/:owner/:repo/assignees/:assignee"},
	{"GET", "/repos/:owner/:repo/issues/:number/comments"},
	//{"GET", "/repos/:owner/:repo/issues/comments"},
	//{"GET", "/repos/:owner/:repo/issues/comments/:id"},
	{"POST", "/repos/:owner/:repo/issues/:number/comments"},
	//{"PATCH", "/repos/:owner/:repo/issues/comments/:id"},
	//{"DELETE", "/repos/:owner/:repo/issues/comments/:id"},
	{"GET", "/repos/:owner/:repo/issues/:number/events"},
	//{"GET", "/repos/:owner/:repo/issues/events"},
	//{"GET", "/repos/:owner/:repo/issues/events/:id"},
	{"GET", "/repos/:owner/:repo/labels"},
	{"GET", "/repos/:owner/:repo/labels/:name"},
	{"POST", "/repos/:owner/:repo/labels"},
	//{"PATCH", "/repos/:owner/:repo/labels/:name"},
	{"DELETE", "/repos/:owner/:repo/labels/:name"},
	{"GET", "/repos/:owner/:repo/issues/:number/labels"},
	{"POST", "/repos/:owner/:repo/issues/:number/labels"},
	{"DELETE", "/repos/:owner/:repo/issues/:number/labels/:name"},
	{"PUT", "/repos/:owner/:repo/issues/:number/labels"},
	{"DELETE", "/repos/:owner/:repo/issues/:number/labels"},
	{"GET", "/repos/:owner/:repo/milestones/:number/labels"},
	{"GET", "/repos/:owner/:repo/milestones"},
	{"GET", "/repos/:owner/:repo/milestones/:number"},
	{"POST", "/repos/:owner/:repo/milestones"},
	//{"PATCH", "/repos/:owner/:repo/milestones/:number"},
	{"DELETE", "/repos/:owner/:repo/milestones/:number"},

	// Miscellaneous
	{"GET", "/emojis"},
	{"GET", "/gitignore/templates"},
	{"GET", "/gitignore/templates/:name"},
	{"POST", "/markdown"},
	{"POST", "/markdown/raw"},
	{"GET", "/meta"},
	{"GET", "/rate_limit"},

	// Organizations
	{"GET", "/users/:user/orgs"},
	{"GET", "/user/orgs"},
	{"GET", "/orgs/:org"},
	//{"PATCH", "/orgs/:org"},
	{"GET", "/orgs/:org/members"},
	{"GET", "/orgs/:org/members/:user"},
	{"DELETE", "/orgs/:org/members/:user"},
	{"GET", "/orgs/:org/public_members"},
	{"GET", "/orgs/:org/public_members/:user"},
	{"PUT", "/orgs/:org/public_members/:user"},
	{"DELETE", "/orgs/:org/public_members/:user"},
	{"GET", "/orgs/:org/teams"},
	{"GET", "/teams/:id"},
	{"POST", "/orgs/:org/teams"},
	//{"PATCH", "/teams/:id"},
	{"DELETE", "/teams/:id"},
	{"GET", "/teams/:id/members"},
	{"GET", "/teams/:id/members/:user"},
	{"PUT", "/teams/:id/members/:user"},
	{"DELETE", "/teams/:id/members/:user"},
	{"GET", "/teams/:id/repos"},
	{"GET", "/teams/:id/repos/:owner/:repo"},
	{"PUT", "/teams/:id/repos/:owner/:repo"},
	{"DELETE", "/teams/:id/repos/:owner/:repo"},
	{"GET", "/user/teams"},

	// Pull Requests
	{"GET", "/repos/:owner/:repo/pulls"},
	{"GET", "/repos/:owner/:repo/pulls/:number"},
	{"POST", "/repos/:owner/:repo/pulls"},
	//{"PATCH", "/repos/:owner/:repo/pulls/:number"},
	{"GET", "/repos/:owner/:repo/pulls/:number/commits"},
	{"GET", "/repos/:owner/:repo/pulls/:number/files"},
	{"GET", "/repos/:owner/:repo/pulls/:number/merge"},
	{"PUT", "/repos/:owner/:repo/pulls/:number/merge"},
	{"GET", "/repos/:owner/:repo/pulls/:number/comments"},
	//{"GET", "/repos/:owner/:repo/pulls/comments"},
	//{"GET", "/repos/:owner/:repo/pulls/comments/:number"},
	{"PUT", "/repos/:owner/:repo/pulls/:number/comments"},
	//{"PATCH", "/repos/:owner/:repo/pulls/comments/:number"},
	//{"DELETE", "/repos/:owner/:repo/pulls/comments/:number"},

	// Repositories
	{"GET", "/user/repos"},
	{"GET", "/users/:user/repos"},
	{"GET", "/orgs/:org/repos"},
	{"GET", "/repositories"},
	{"POST", "/user/repos"},
	{"POST", "/orgs/:org/repos"},
	{"GET", "/repos/:owner/:repo"},
	//{"PATCH", "/repos/:owner/:repo"},
	{"GET", "/repos/:owner/:repo/contributors"},
	{"GET", "/repos/:owner/:repo/languages"},
	{"GET", "/repos/:owner/:repo/teams"},
	{"GET", "/repos/:owner/:repo/tags"},
	{"GET", "/repos/:owner/:repo/branches"},
	{"GET", "/repos/:owner/:repo/branches/:branch"},
	{"DELETE", "/repos/:owner/:repo"},
	{"GET", "/repos/:owner/:repo/collaborators"},
	{"GET", "/repos/:owner/:repo/collaborators/:user"},
	{"PUT", "/repos/:owner/:repo/collaborators/:user"},
	{"DELETE", "/repos/:owner/:repo/collaborators/:user"},
	{"GET", "/repos/:owner/:repo/comments"},
	{"GET", "/repos/:owner/:repo/commits/:sha/comments"},
	{"POST", "/repos/:owner/:repo/commits/:sha/comments"},
	{"GET", "/repos/:owner/:repo/comments/:id"},
	//{"PATCH", "/repos/:owner/:repo/comments/:id"},
	{"DELETE", "/repos/:owner/:repo/comments/:id"},
	{"GET", "/repos/:owner/:repo/commits"},
	{"GET", "/repos/:owner/:repo/commits/:sha"},
	{"GET", "/repos/:owner/:repo/readme"},
	//{"GET", "/repos/:owner/:repo/contents/*path"},
	//{"PUT", "/repos/:owner/:repo/contents/*path"},
	//{"DELETE", "/repos/:owner/:repo/contents/*path"},
	//{"GET", "/repos/:owner/:repo/:archive_format/:ref"},
	{"GET", "/repos/:owner/:repo/keys"},
	{"GET", "/repos/:owner/:repo/keys/:id"},
	{"POST", "/repos/:owner/:repo/keys"},
	//{"PATCH", "/repos/:owner/:repo/keys/:id"},
	{"DELETE", "/repos/:owner/:repo/keys/:id"},
	{"GET", "/repos/:owner/:repo/downloads"},
	{"GET", "/repos/:owner/:repo/downloads/:id"},
	{"DELETE", "/repos/:owner/:repo/downloads/:id"},
	{"GET", "/repos/:owner/:repo/forks"},
	{"POST", "/repos/:owner/:repo/forks"},
	{"GET", "/repos/:owner/:repo/hooks"},
	{"GET", "/repos/:owner/:repo/hooks/:id"},
	{"POST", "/repos/:owner/:repo/hooks"},
	//{"PATCH", "/repos/:owner/:repo/hooks/:id"},
	{"POST", "/repos/:owner/:repo/hooks/:id/tests"},
	{"DELETE", "/repos/:owner/:repo/hooks/:id"},
	{"POST", "/repos/:owner/:repo/merges"},
	{"GET", "/repos/:owner/:repo/releases"},
	{"GET", "/repos/:owner/:repo/releases/:id"},
	{"POST", "/repos/:owner/:repo/releases"},
	//{"PATCH", "/repos/:owner/:repo/releases/:id"},
	{"DELETE", "/repos/:owner/:repo/releases/:id"},
	{"GET", "/repos/:owner/:repo/releases/:id/assets"},
	{"GET", "/repos/:owner/:repo/stats/contributors"},
	{"GET", "/repos/:owner/:repo/stats/commit_activity"},
	{"GET", "/repos/:owner/:repo/stats/code_frequency"},
	{"GET", "/repos/:owner/:repo/stats/participation"},
	{"GET", "/repos/:owner/:repo/stats/punch_card"},
	{"GET", "/repos/:owner/:repo/statuses/:ref"},
	{"POST", "/repos/:owner/:repo/statuses/:ref"},

	// Search
	{"GET", "/search/repositories"},
	{"GET", "/search/code"},
	{"GET", "/search/issues"},
	{"GET", "/search/users"},
	{"GET", "/legacy/issues/search/:owner/:repository/:state/:keyword"},
	{"GET", "/legacy/repos/search/:keyword"},
	{"GET", "/legacy/user/search/:keyword"},
	{"GET", "/legacy/user/email/:email"},

	// Users
	{"GET", "/users/:user"},
	{"GET", "/user"},
	//{"PATCH", "/user"},
	{"GET", "/users"},
	{"GET", "/user/emails"},
	{"POST", "/user/emails"},
	{"DELETE", "/user/emails"},
	{"GET", "/users/:user/followers"},
	{"GET", "/user/followers"},
	{"GET", "/users/:user/following"},
	{"GET", "/user/following"},
	{"GET", "/user/following/:user"},
	{"GET", "/users/:user/following/:target_user"},
	{"PUT", "/user/following/:user"},
	{"DELETE", "/user/following/:user"},
	{"GET", "/users/:user/keys"},
	{"GET", "/user/keys"},
	{"GET", "/user/keys/:id"},
	{"POST", "/user/keys"},
	//{"PATCH", "/user/keys/:id"},
	{"DELETE", "/user/keys/:id"},
}
302 |
// parseAPI is a route table presumably shaped after the Parse REST API
// (objects, users, roles, files, analytics, push, installations, cloud
// functions). Every route lives under the /1 prefix. It is used by
// BenchmarkYeeParseAPI.
var parseAPI = []*Route{
	// Objects
	{"POST", "/1/classes/:className"},
	{"GET", "/1/classes/:className/:objectId"},
	{"PUT", "/1/classes/:className/:objectId"},
	{"GET", "/1/classes/:className"},
	{"DELETE", "/1/classes/:className/:objectId"},

	// Users
	{"POST", "/1/users"},
	{"GET", "/1/login"},
	{"GET", "/1/users/:objectId"},
	{"PUT", "/1/users/:objectId"},
	{"GET", "/1/users"},
	{"DELETE", "/1/users/:objectId"},
	{"POST", "/1/requestPasswordReset"},

	// Roles
	{"POST", "/1/roles"},
	{"GET", "/1/roles/:objectId"},
	{"PUT", "/1/roles/:objectId"},
	{"GET", "/1/roles"},
	{"DELETE", "/1/roles/:objectId"},

	// Files
	{"POST", "/1/files/:fileName"},

	// Analytics
	{"POST", "/1/events/:eventName"},

	// Push Notifications
	{"POST", "/1/push"},

	// Installations
	{"POST", "/1/installations"},
	{"GET", "/1/installations/:objectId"},
	{"PUT", "/1/installations/:objectId"},
	{"GET", "/1/installations"},
	{"DELETE", "/1/installations/:objectId"},

	// Cloud Functions
	{"POST", "/1/functions"},
}
346 |
// yeeHandler returns a trivial handler that answers every request with
// status 200 and the body "OK". The method and path arguments are not used
// by the returned handler.
func yeeHandler(method, path string) HandlerFunc {
	respondOK := func(c Context) error {
		return c.String(http.StatusOK, "OK")
	}
	return respondOK
}
352 |
353 | func loadYeeRoutes(e *Core, routes []*Route) {
354 | for _, r := range routes {
355 | switch r.Method {
356 | case "GET":
357 | e.GET(r.Path, yeeHandler(r.Method, r.Path))
358 | case "POST":
359 | e.POST(r.Path, yeeHandler(r.Method, r.Path))
360 | case "PATCH":
361 | e.PATCH(r.Path, yeeHandler(r.Method, r.Path))
362 | case "PUT":
363 | e.PUT(r.Path, yeeHandler(r.Method, r.Path))
364 | case "DELETE":
365 | e.DELETE(r.Path, yeeHandler(r.Method, r.Path))
366 | }
367 | }
368 | }
369 |
// benchmarkRoutes dispatches every route in routes through router once per
// benchmark iteration.
//
// A single request and a single recorder are reused across calls so the
// harness itself contributes as little allocation as possible. b.SetBytes is
// deliberately not called: no fixed number of bytes is processed per
// operation, so an MB/s figure would be meaningless.
func benchmarkRoutes(b *testing.B, router http.Handler, routes []*Route) {
	r := httptest.NewRequest("GET", "/", nil)
	u := r.URL
	w := httptest.NewRecorder()
	b.ReportAllocs()
	// Building the request/recorder is harness cost, not router cost.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for _, route := range routes {
			r.Method = route.Method
			u.Path = route.Path
			// Drop the previous response body. Otherwise the recorder's buffer
			// grows with every response and its reallocations are charged to
			// the router in allocs/op and memory use.
			w.Body.Reset()
			router.ServeHTTP(w, r)
		}
	}
}
384 |
// BenchmarkYeeParseAPI routes every parseAPI request through a fresh Core
// loaded with exactly those routes.
func BenchmarkYeeParseAPI(b *testing.B) {
	routes := parseAPI
	router := C()
	loadYeeRoutes(router, routes)
	benchmarkRoutes(b, router, routes)
}

// BenchmarkYeeGplusAPI routes every gplusAPI request through a fresh Core
// loaded with exactly those routes.
func BenchmarkYeeGplusAPI(b *testing.B) {
	routes := gplusAPI
	router := C()
	loadYeeRoutes(router, routes)
	benchmarkRoutes(b, router, routes)
}

// BenchmarkYeeGitHubAPI routes every githubAPI request through a fresh Core
// loaded with exactly those routes.
func BenchmarkYeeGitHubAPI(b *testing.B) {
	routes := githubAPI
	router := C()
	loadYeeRoutes(router, routes)
	benchmarkRoutes(b, router, routes)
}
402 |
// BenchmarkYeeStatic measures serving one file through a Static route.
//
// The route mounts the repository's logger/ directory, where color.go
// lives. The previous "color" directory does not exist, so the benchmark
// was timing a not-found response rather than a file being served.
func BenchmarkYeeStatic(b *testing.B) {
	e := C()
	e.Static("/front", "logger")
	r := httptest.NewRequest("GET", "/front/color.go", nil)
	w := httptest.NewRecorder()
	b.ReportAllocs()
	// Router and recorder setup above is not part of what is measured.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// Reuse the recorder but discard the previous body; otherwise it grows
		// by a whole file on every iteration.
		w.Body.Reset()
		e.ServeHTTP(w, r)
	}
}
414 |
// Benchmark runs the three route-table benchmarks as sub-benchmarks, after
// setting GOMAXPROCS to the number of CPUs.
func Benchmark(b *testing.B) {
	runtime.GOMAXPROCS(runtime.NumCPU())
	suites := []struct {
		name string
		fn   func(*testing.B)
	}{
		{"BenchmarkYeeGplusAPI", BenchmarkYeeGplusAPI},
		{"BenchmarkYeeParseAPI", BenchmarkYeeParseAPI},
		{"BenchmarkYeeGitHubAPI", BenchmarkYeeGitHubAPI},
	}
	for _, s := range suites {
		b.Run(s.name, s.fn)
	}
}
421 |
--------------------------------------------------------------------------------
/bind.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | import (
4 | "encoding"
5 | "encoding/json"
6 | "encoding/xml"
7 | "errors"
8 | "fmt"
9 | "github.com/go-playground/validator/v10"
10 | "github.com/golang/protobuf/proto"
11 | "io"
12 | "reflect"
13 | "strconv"
14 | "strings"
15 | )
16 |
17 | type (
18 | // DefaultBinder is the default implementation of the Binder interface.
19 | DefaultBinder struct{}
20 |
21 | // BindUnmarshaler is the interface used to wrap the UnmarshalParam method.
22 | // Types that don't implement this, but do implement encoding.TextUnmarshaler
23 | // will use that interface instead.
24 | BindUnmarshaler interface {
25 | // UnmarshalParam decodes and assigns a value from an form or query param.
26 | UnmarshalParam(param string) error
27 | }
28 | )
29 |
// bindValidator is shared by every Bind call. validator.Validate caches
// per-type struct metadata and is safe for concurrent use, so building a new
// one per request only throws that cache away.
var bindValidator = validator.New()

// Bind decodes the request carried by c into i (see bind) and then runs
// struct-tag validation on the result.
//
// When i is a pointer, the struct it points to is validated. Pointers to
// non-struct targets (e.g. a map filled from query parameters, or a slice
// decoded from a JSON array) and nil pointers are bound but not validated.
// Non-pointer values are passed to the validator unchanged, which rejects
// anything that is not a struct.
func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) {
	if err = b.bind(i, c); err != nil {
		return err
	}
	if v := reflect.ValueOf(i); v.Kind() == reflect.Ptr {
		if v.IsNil() || v.Elem().Kind() != reflect.Struct {
			return nil
		}
	}
	// Hand the validator i itself, never a reflect.Value wrapping it: a
	// reflect.Value is a struct too, so the validator would inspect its
	// internal fields and silently skip every tag on the user's struct. The
	// validator dereferences pointers on its own.
	return bindValidator.Struct(i)
}
43 |
// bind populates i from the request carried by c, without validation.
//
// Query parameters are always bound first, matched against `json` struct
// tags (falling back to field names). If the request has a body
// (ContentLength != 0), it is then decoded according to its Content-Type,
// compared case-insensitively by prefix:
//   - application/protobuf: the whole body is read and unmarshalled; i must
//     implement proto.Message.
//   - application/json: decoded with encoding/json.
//   - application/xml, text/xml: decoded with encoding/xml.
//   - application/octet-stream: accepted, body left unread.
//   - application/x-www-form-urlencoded, multipart/form-data: form values
//     are bound via `form` struct tags.
//
// Any other content type yields ErrUnsupportedMediaType.
func (b *DefaultBinder) bind(i interface{}, c Context) (err error) {
	req := c.Request()
	if err = b.bindData(i, c.QueryParams(), "json"); err != nil {
		return err
	}

	// Only an explicitly empty body is skipped; chunked bodies report -1.
	if req.ContentLength == 0 {
		return nil
	}

	ctype := strings.ToLower(req.Header.Get(HeaderContentType))
	switch {
	case strings.HasPrefix(ctype, MIMEApplicationProtobuf):
		// Check the target before reading: an unchecked assertion would panic.
		msg, ok := i.(proto.Message)
		if !ok {
			return fmt.Errorf("protobuf binding requires a proto.Message, got %T", i)
		}
		buf, err := io.ReadAll(req.Body)
		if err != nil {
			return err
		}
		return proto.Unmarshal(buf, msg)
	case strings.HasPrefix(ctype, MIMEApplicationJSON):
		if err = json.NewDecoder(req.Body).Decode(i); err != nil {
			var ute *json.UnmarshalTypeError
			if errors.As(err, &ute) {
				return fmt.Errorf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)
			}
			return err
		}
	// The prefix checks also cover the "; charset=..." variants.
	case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML):
		if err = xml.NewDecoder(req.Body).Decode(i); err != nil {
			var ute *xml.UnsupportedTypeError
			var se *xml.SyntaxError
			switch {
			case errors.As(err, &ute):
				return fmt.Errorf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())
			case errors.As(err, &se):
				return fmt.Errorf("Syntax error: line=%v, error=%v", se.Line, se.Error())
			}
			return err
		}
	case strings.HasPrefix(ctype, MIMEOctetStream):
		// Raw bodies are left for the handler to consume; nothing to bind.
	case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm):
		params, err := c.FormParams()
		if err != nil {
			return err
		}
		return b.bindData(i, params, "form")
	default:
		return ErrUnsupportedMediaType
	}
	return nil
}
100 |
// bindData copies values from data (query or form parameters) into the value
// ptr points to. A nil ptr or empty data is a no-op.
//
// ptr must otherwise be a non-nil pointer to either:
//   - a map with string keys: each key is stored with its whole []string
//     when the element type is a slice that []string converts to, or with
//     its first value otherwise. A nil map is allocated first.
//   - a struct: each settable field is matched by the name in its `tag`
//     struct tag (options after a comma are ignored, "-" skips the field),
//     falling back to the Go field name. An exact key match is tried first,
//     then a case-insensitive one. Untagged struct-typed fields that do not
//     implement BindUnmarshaler are recursed into with the same data.
func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag string) error {
	if ptr == nil || len(data) == 0 {
		return nil
	}
	rv := reflect.ValueOf(ptr)
	if rv.Kind() != reflect.Ptr || rv.IsNil() {
		return errors.New("binding element must be a non-nil pointer")
	}
	typ := rv.Type().Elem()
	val := rv.Elem()

	// Map
	if typ.Kind() == reflect.Map {
		keyType, elemType := typ.Key(), typ.Elem()
		if keyType.Kind() != reflect.String {
			return errors.New("binding map must have string keys")
		}
		if val.IsNil() {
			// SetMapIndex panics on a nil map.
			val.Set(reflect.MakeMap(typ))
		}
		for k, v := range data {
			if len(v) == 0 {
				continue
			}
			elem := reflect.ValueOf(v[0])
			if elemType.Kind() == reflect.Slice {
				elem = reflect.ValueOf(v)
			}
			if !elem.Type().ConvertibleTo(elemType) {
				return fmt.Errorf("cannot bind parameter %q into map element of type %v", k, elemType)
			}
			// Convert also covers named string key/element types.
			val.SetMapIndex(reflect.ValueOf(k).Convert(keyType), elem.Convert(elemType))
		}
		return nil
	}

	// !struct
	if typ.Kind() != reflect.Struct {
		return errors.New("binding element must be a struct")
	}
	for i := 0; i < typ.NumField(); i++ {
		typeField := typ.Field(i)
		structField := val.Field(i)
		if !structField.CanSet() {
			continue
		}
		structFieldKind := structField.Kind()

		rawTag := typeField.Tag.Get(tag)
		if rawTag == "-" {
			// Explicitly excluded from binding, as with encoding/json.
			continue
		}
		// Tags may carry options (`json:"name,omitempty"`); only the name
		// part identifies the parameter.
		inputFieldName := rawTag
		if idx := strings.IndexByte(inputFieldName, ','); idx >= 0 {
			inputFieldName = inputFieldName[:idx]
		}
		if inputFieldName == "" {
			inputFieldName = typeField.Name
			// Without a tag, a nested struct is bound field-by-field from the
			// same data unless it knows how to unmarshal itself.
			if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct {
				if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil {
					return err
				}
				continue
			}
		}

		inputValue, exists := data[inputFieldName]
		if !exists {
			// encoding/json matches keys case-insensitively; do the same for
			// parameters so both sources bind consistently.
			for k, v := range data {
				if strings.EqualFold(k, inputFieldName) {
					inputValue, exists = v, true
					break
				}
			}
		}
		// A key with no values has nothing to bind (and [0] would panic).
		if !exists || len(inputValue) == 0 {
			continue
		}

		// Call this first, in case we're dealing with an alias to an array type.
		if ok, err := unmarshalField(typeField.Type.Kind(), inputValue[0], structField); ok {
			if err != nil {
				return err
			}
			continue
		}

		if structFieldKind == reflect.Slice {
			n := len(inputValue)
			sliceOf := structField.Type().Elem().Kind()
			slice := reflect.MakeSlice(structField.Type(), n, n)
			for j := 0; j < n; j++ {
				if err := setWithProperType(sliceOf, inputValue[j], slice.Index(j)); err != nil {
					return err
				}
			}
			structField.Set(slice)
			continue
		}
		if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil {
			return err
		}
	}
	return nil
}
182 |
// setWithProperType parses val according to valueKind and stores it in
// structField.
//
// Types implementing BindUnmarshaler or encoding.TextUnmarshaler (as
// detected by unmarshalField) take precedence. Pointers are followed to
// their element; a nil pointer is allocated first. Signed/unsigned integers,
// bool, float32/64 and string are parsed directly. Any other kind returns an
// "unknown type" error.
func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
	// But also call it here, in case we're dealing with an array of BindUnmarshalers.
	if ok, err := unmarshalField(valueKind, val, structField); ok {
		return err
	}

	switch valueKind {
	case reflect.Ptr:
		// Elem of a nil pointer is the zero Value, whose Kind is Invalid, so
		// the field would fail with "unknown type". Allocate the pointee first.
		if structField.IsNil() {
			structField.Set(reflect.New(structField.Type().Elem()))
		}
		return setWithProperType(structField.Elem().Kind(), val, structField.Elem())
	case reflect.Int:
		return setIntField(val, 0, structField)
	case reflect.Int8:
		return setIntField(val, 8, structField)
	case reflect.Int16:
		return setIntField(val, 16, structField)
	case reflect.Int32:
		return setIntField(val, 32, structField)
	case reflect.Int64:
		return setIntField(val, 64, structField)
	case reflect.Uint:
		return setUintField(val, 0, structField)
	case reflect.Uint8:
		return setUintField(val, 8, structField)
	case reflect.Uint16:
		return setUintField(val, 16, structField)
	case reflect.Uint32:
		return setUintField(val, 32, structField)
	case reflect.Uint64:
		return setUintField(val, 64, structField)
	case reflect.Bool:
		return setBoolField(val, structField)
	case reflect.Float32:
		return setFloatField(val, 32, structField)
	case reflect.Float64:
		return setFloatField(val, 64, structField)
	case reflect.String:
		structField.SetString(val)
	default:
		return errors.New("unknown type")
	}
	return nil
}
225 |
226 | func unmarshalField(valueKind reflect.Kind, val string, field reflect.Value) (bool, error) {
227 | switch valueKind {
228 | case reflect.Ptr:
229 | return unmarshalFieldPtr(val, field)
230 | default:
231 | return unmarshalFieldNonPtr(val, field)
232 | }
233 | }
234 |
235 | func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) {
236 | fieldIValue := field.Addr().Interface()
237 | if unmarshaler, ok := fieldIValue.(BindUnmarshaler); ok {
238 | return true, unmarshaler.UnmarshalParam(value)
239 | }
240 | if unmarshaler, ok := fieldIValue.(encoding.TextUnmarshaler); ok {
241 | return true, unmarshaler.UnmarshalText([]byte(value))
242 | }
243 |
244 | return false, nil
245 | }
246 |
247 | func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) {
248 | if field.IsNil() {
249 | // Initialize the pointer to a nil value
250 | field.Set(reflect.New(field.Type().Elem()))
251 | }
252 | return unmarshalFieldNonPtr(value, field.Elem())
253 | }
254 |
// setIntField parses value as a base-10 signed integer of bitSize bits and
// stores it in field. An empty value is treated as "0". On a parse error the
// field is left unchanged and the strconv error is returned.
func setIntField(value string, bitSize int, field reflect.Value) error {
	if value == "" {
		value = "0"
	}
	parsed, err := strconv.ParseInt(value, 10, bitSize)
	if err != nil {
		return err
	}
	field.SetInt(parsed)
	return nil
}
265 |
// setUintField parses value as a base-10 unsigned integer of bitSize bits and
// stores it in field. An empty value is treated as "0"; a parse error leaves
// the field unchanged and is returned.
func setUintField(value string, bitSize int, field reflect.Value) error {
	text := value
	if len(text) == 0 {
		text = "0"
	}
	n, err := strconv.ParseUint(text, 10, bitSize)
	if err != nil {
		return err
	}
	field.SetUint(n)
	return nil
}
276 |
// setBoolField parses value with strconv.ParseBool and stores the result in
// field. An empty value means false; a parse error leaves the field unchanged
// and is returned.
func setBoolField(value string, field reflect.Value) error {
	input := value
	if input == "" {
		input = "false"
	}
	flag, err := strconv.ParseBool(input)
	if err != nil {
		return err
	}
	field.SetBool(flag)
	return nil
}
287 |
// setFloatField parses value as a floating-point number of bitSize bits and
// stores it in field. An empty value is treated as "0.0"; a parse error leaves
// the field unchanged and is returned.
func setFloatField(value string, bitSize int, field reflect.Value) error {
	if len(value) == 0 {
		value = "0.0"
	}
	f, err := strconv.ParseFloat(value, bitSize)
	if err != nil {
		return err
	}
	field.SetFloat(f)
	return nil
}
298 |
--------------------------------------------------------------------------------
/bind_test.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | import (
4 | "fmt"
5 | "github.com/stretchr/testify/assert"
6 | "io"
7 | "net/http"
8 | "net/http/httptest"
9 | "strings"
10 | "testing"
11 | )
12 |
// user is the JSON payload bound by the tests in this file; Username carries a
// "required" validation tag.
type user struct {
	Username string `json:"username" validate:"required"`
	Password string `json:"password"`
	Age      int    `json:"age"`
}

// empty has no fields at all.
type empty struct {
}

// cmdbBind receives query parameters in testBindQueryPrams.
type cmdbBind struct {
	RegionId string `json:"region_id"`
	SecId    string `json:"secId"`
	Cloud    string `json:"cloud"`
	Account  string `json:"account"`
}

var (
	// userInfo is a valid user payload.
	userInfo = `{"username": "henry","age":24,"password":"123123"}`
	// invalidInfo leaves the required username empty.
	invalidInfo = `{"username": "","age":24,"password":"123123"}`
	// encrypt is an encoded request body; TestBindJSON expects it to bind to
	// the same user as userInfo.
	encrypt = `e2db79dc56e0b5a5866fa4062c9c715e66a5d4820d5424c7645092be0041c1e62d9571b0549758cd02445593b2a276d455cba5b31295e1288d67255e78e4dd78`
)
32 |
33 | func TestBindJSON(t *testing.T) {
34 | assertions := assert.New(t)
35 | testBindOkay(assertions, strings.NewReader(encrypt), MIMEApplicationJSON)
36 | //testBindError(assertions, strings.NewReader(invalidInfo), MIMEApplicationJSON)
37 | //testBindQueryPrams(assertions, MIMETextHTML)
38 | }
39 |
40 | func TestDefaultBinder_Bind(t *testing.T) {
41 | e := C()
42 | e.POST("/bind", func(c Context) (err error) {
43 | u := new(user)
44 | if err := c.Bind(u); err != nil {
45 | return err
46 | }
47 | return c.JSON(http.StatusOK, u)
48 | })
49 | req := httptest.NewRequest(http.MethodPost, "/bind", strings.NewReader(invalidInfo))
50 | req.Header.Set("Content-Type", MIMEApplicationJSON)
51 | rec := httptest.NewRecorder()
52 | e.ServeHTTP(rec, req)
53 | fmt.Println(rec.Code)
54 | fmt.Println(rec.Body.String())
55 | }
56 |
57 | func TestBindEncryptOkay(t *testing.T) {
58 | e := C()
59 | e.POST("/", func(c Context) (err error) {
60 | u := new(user)
61 | if err := c.Bind(u); err != nil {
62 | return err
63 | }
64 | return c.JSON(http.StatusOK, u)
65 | })
66 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(encrypt))
67 | req.Header.Set("Content-Type", MIMEApplicationJSON)
68 | rec := httptest.NewRecorder()
69 | e.ServeHTTP(rec, req)
70 | }
71 |
72 | func TestDefaultBinder_Params_Bind(t *testing.T) {
73 | e := C()
74 | e.GET("/bind", func(c Context) (err error) {
75 | u := new(empty)
76 | if err := c.Bind(u); err != nil {
77 | return err
78 | }
79 | return c.JSON(http.StatusOK, "")
80 | })
81 | req := httptest.NewRequest(http.MethodGet, "/bind?username=xxxx&cn", nil)
82 | req.Header.Set("Content-Type", MIMEApplicationJSON)
83 | rec := httptest.NewRecorder()
84 | e.ServeHTTP(rec, req)
85 | assert.Equal(t, http.StatusBadRequest, rec.Code)
86 | }
87 |
88 | func testBindOkay(assert *assert.Assertions, r io.Reader, ctype string) {
89 | e := C()
90 | req := httptest.NewRequest(http.MethodPost, "/", r)
91 | rec := httptest.NewRecorder()
92 | c := e.NewContext(req, rec)
93 | req.Header.Set(HeaderContentType, ctype)
94 | u := new(user)
95 | err := c.Bind(u)
96 | if assert.NoError(err) {
97 | assert.Equal("henry", u.Username)
98 | assert.Equal(24, u.Age)
99 | assert.Equal("123123", u.Password)
100 | }
101 | }
102 |
103 | func testBindError(assert *assert.Assertions, r io.Reader, ctype string) {
104 | e := C()
105 | req := httptest.NewRequest(http.MethodPost, "/", r)
106 | rec := httptest.NewRecorder()
107 | c := e.NewContext(req, rec)
108 | req.Header.Set(HeaderContentType, ctype)
109 | u := new(user)
110 | err := c.Bind(u)
111 | assert.Error(err, "Unmarshal type error: expected=yee.user, got=number, field=, offset=1")
112 | }
113 |
114 | func testBindQueryPrams(assert *assert.Assertions, ctype string) {
115 | e := C()
116 | req := httptest.NewRequest(http.MethodGet, "/?secId=sg-gw86k2rjop30v1ktyn3j®ion_id=eu-central", nil)
117 | rec := httptest.NewRecorder()
118 | c := e.NewContext(req, rec)
119 | req.Header.Set(HeaderContentType, ctype)
120 | u := new(cmdbBind)
121 | err := c.Bind(u)
122 | if assert.NoError(err) {
123 | assert.Equal("eu-central", u.RegionId)
124 | assert.Equal("sg-gw86k2rjop30v1ktyn3j", u.SecId)
125 | }
126 | }
127 |
128 | func TestQueryParams(t *testing.T) {
129 | assertions := assert.New(t)
130 | testBindQueryPrams(assertions, MIMETextHTML)
131 | }
132 |
--------------------------------------------------------------------------------
/common.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | import (
4 | "bytes"
5 | "encoding/json"
6 | "path"
7 | "reflect"
8 | "unsafe"
9 | )
10 |
11 | // StringToBytes converts string to byte slice without a memory allocation.
// StringToBytes returns a []byte view of s without copying.
// The slice aliases the string's immutable memory: callers must never write
// to it.
func StringToBytes(s string) (b []byte) {
	src := (*reflect.StringHeader)(unsafe.Pointer(&s))
	dst := (*reflect.SliceHeader)(unsafe.Pointer(&b))
	dst.Data = src.Data
	dst.Len = src.Len
	dst.Cap = src.Len
	return b
}
18 |
19 | // BytesToString converts byte slice to string without a memory allocation.
// BytesToString reinterprets b as a string without copying. The string shares
// b's memory, so b must not be modified while the string is in use.
func BytesToString(b []byte) string {
	header := unsafe.Pointer(&b)
	return *(*string)(header)
}
23 |
// lastChar returns the final byte of str. It panics on an empty string.
func lastChar(str string) uint8 {
	n := len(str)
	if n == 0 {
		panic("The length of the string can't be 0")
	}
	return str[n-1]
}

// joinPaths joins relativePath onto absolutePath with path.Join, keeping a
// trailing slash when relativePath ends in one (path.Join would drop it).
// An empty relativePath returns absolutePath unchanged.
func joinPaths(absolutePath, relativePath string) string {
	if relativePath == "" {
		return absolutePath
	}
	joined := path.Join(absolutePath, relativePath)
	wantSlash := lastChar(relativePath) == '/'
	if wantSlash && lastChar(joined) != '/' {
		joined += "/"
	}
	return joined
}
42 |
// getBytes JSON-encodes key with json.Encoder, so the result ends with the
// encoder's trailing newline. It returns nil if key cannot be encoded.
func getBytes(key interface{}) []byte {
	buf := new(bytes.Buffer)
	if err := json.NewEncoder(buf).Encode(key); err != nil {
		return nil
	}
	return buf.Bytes()
}
52 |
// clearPoint strips the quoting that JSON-encoding a string with
// json.Encoder leaves behind (see getBytes). It removes a leading '"' and,
// when the second-to-last byte is '"', the final two bytes (the closing quote
// and the encoder's trailing newline).
//
// The tail check requires at least two bytes, so short inputs such as "a" are
// returned as-is instead of indexing out of range.
func clearPoint(s string) string {
	if len(s) > 0 && s[0] == '"' {
		s = s[1:]
	}
	if n := len(s); n >= 2 && s[n-2] == '"' {
		s = s[:n-2]
	}
	return s
}
62 |
--------------------------------------------------------------------------------
/context.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | import (
4 | "encoding/json"
5 | "errors"
6 | "fmt"
7 | "github.com/cookieY/yee/logger"
8 | "github.com/golang/protobuf/proto"
9 | "math"
10 | "mime/multipart"
11 | "net"
12 | "net/http"
13 | "net/url"
14 | "os"
15 | "path/filepath"
16 | "strings"
17 | "sync"
18 | )
19 |
20 | const crashIndex int = math.MaxInt8 / 2
21 |
22 | // Context is the default implementation interface of context
23 | type Context interface {
24 | Request() *http.Request
25 | SetRequest(r *http.Request)
26 | Response() ResponseWriter
27 | SetResponse(w ResponseWriter)
28 | HTML(code int, html string) (err error)
29 | JSON(code int, i interface{}) error
30 | ProtoBuf(code int, i proto.Message) error
31 | String(code int, s string) error
32 | FormValue(name string) string
33 | FormParams() (url.Values, error)
34 | FormFile(name string) (*multipart.FileHeader, error)
35 | File(file string) error
36 | Blob(code int, contentType string, b []byte) (err error)
37 | Status(code int)
38 | QueryParam(name string) string
39 | QueryString() string
40 | SetHeader(key string, value string)
41 | AddHeader(key string, value string)
42 | GetHeader(key string) string
43 | MultipartForm() (*multipart.Form, error)
44 | Redirect(code int, uri string) error
45 | Params(name string) string
46 | RequestURI() string
47 | Scheme() string
48 | IsTLS() bool
49 | IsWebsocket() bool
50 | Next()
51 | HTMLTpl(code int, tml string) (err error)
52 | QueryParams() map[string][]string
53 | Bind(i interface{}) error
54 | Cookie(name string) (*http.Cookie, error)
55 | SetCookie(cookie *http.Cookie)
56 | Cookies() []*http.Cookie
57 | Get(key string) interface{}
58 | Put(key string, values interface{})
59 | ServerError(code int, defaultMessage string) error
60 | RemoteIP() string
61 | Logger() logger.Logger
62 | Reset()
63 | Crash()
64 | IsCrash() bool
65 | CrashWithStatus(code int)
66 | CrashWithJson(code int, json interface{})
67 | Path() string
68 | Abort()
69 | }
70 |
71 | type context struct {
72 | engine *Core
73 | writermem responseWriter
74 | w ResponseWriter
75 | r *http.Request
76 | path string
77 | method string
78 | code int
79 | queryList url.Values // cache url.Values
80 | params *Params
81 | Param Params
82 | // middleware
83 | handlers HandlersChain
84 | index int
85 | store map[string]interface{}
86 | lock sync.RWMutex
87 | noRewrite bool
88 | abort bool
89 | }
90 |
91 | func (c *context) Path() string {
92 | return c.path
93 | }
94 |
95 | func (c *context) Reset() {
96 | c.index = -1
97 | c.handlers = c.engine.noRoute
98 | }
99 |
100 | func (c *context) Abort() {
101 | c.abort = true
102 | }
103 |
104 | func (c *context) Bind(i interface{}) error {
105 | return c.engine.bind.Bind(i, c)
106 | }
107 |
108 | func (c *context) reset() { // reset context members
109 | c.w = &c.writermem
110 | c.Param = c.Param[0:0]
111 | c.handlers = nil
112 | c.index = -1
113 | c.path = ""
114 | // when context reset clear queryList cache .
115 | // cause if not clear cache the queryParams results will mistake
116 | c.queryList = nil
117 | c.store = nil
118 | *c.params = (*c.params)[0:0]
119 | }
120 |
121 | func (c *context) Next() {
122 | c.index++
123 | s := len(c.handlers)
124 | if s > 0 {
125 | for ; c.index < s; c.index++ {
126 | err := c.handlers[c.index](c)
127 | if err != nil {
128 | c.Logger().Error(err.Error())
129 | c.CrashWithString(http.StatusInternalServerError, err.Error())
130 | }
131 | if c.abort {
132 | c.writermem.WriteHeaderNow()
133 | }
134 | if c.w.Written() {
135 | break
136 | }
137 | }
138 | }
139 | }
140 |
141 | func (c *context) Logger() logger.Logger {
142 | return c.engine.l
143 | }
144 |
145 | func (c *context) ServerError(code int, defaultMessage string) error {
146 | c.writermem.status = code
147 | if c.writermem.Written() {
148 | return errors.New("headers were already written")
149 | }
150 | if c.writermem.Status() == code {
151 | c.writermem.Header()["Content-Type"] = []string{MIMETextPlainCharsetUTF8}
152 | c.Logger().Error(fmt.Sprintf("%s %s", c.r.URL, defaultMessage))
153 | _, err := c.w.Write([]byte(defaultMessage))
154 | if err != nil {
155 | return fmt.Errorf("cannot write message to writer during serve error: %v", err)
156 | }
157 | return nil
158 | }
159 | c.writermem.WriteHeaderNow()
160 | return nil
161 | }
162 |
163 | func (c *context) Put(key string, values interface{}) {
164 | c.lock.Lock()
165 | defer c.lock.Unlock()
166 | if c.store == nil {
167 | c.store = make(map[string]interface{})
168 | }
169 | c.store[key] = values
170 | }
171 |
172 | func (c *context) Crash() {
173 | c.index = crashIndex
174 | }
175 |
176 | func (c *context) IsCrash() bool {
177 | return c.index >= crashIndex
178 | }
179 |
180 | func (c *context) CrashWithStatus(code int) {
181 | c.Status(code)
182 | c.w.WriteHeaderNow()
183 | c.Crash()
184 | }
185 |
186 | func (c *context) CrashWithJson(code int, json interface{}) {
187 | c.Crash()
188 | err := c.JSON(code, json)
189 | if err != nil {
190 | return
191 | }
192 | }
193 |
194 | func (c *context) CrashWithString(code int, str string) {
195 | c.Crash()
196 | err := c.String(code, str)
197 | if err != nil {
198 | return
199 | }
200 | }
201 |
202 | func (c *context) Get(key string) interface{} {
203 | c.lock.RLock()
204 | defer c.lock.RUnlock()
205 | return c.store[key]
206 | }
207 |
208 | func (c *context) Request() *http.Request {
209 | return c.r
210 | }
211 |
212 | func (c *context) SetRequest(r *http.Request) {
213 | c.r = r
214 | }
215 |
216 | func (c *context) Response() ResponseWriter {
217 | return c.w
218 | }
219 |
220 | func (c *context) SetResponse(w ResponseWriter) {
221 | c.w = w
222 | }
223 |
224 | func (c *context) RemoteIP() string {
225 | if ip := c.r.Header.Get(HeaderXForwardedFor); ip != "" {
226 | i := strings.IndexAny(ip, ", ")
227 | if i > 0 {
228 | return ip[:i]
229 | }
230 | }
231 | if ip := c.r.Header.Get(HeaderXRealIP); ip != "" {
232 | return ip
233 | }
234 | ip, _, _ := net.SplitHostPort(c.r.RemoteAddr)
235 | return ip
236 | }
237 |
238 | func (c *context) HTML(code int, html string) (err error) {
239 | return c.HTMLBlob(code, []byte(html))
240 | }
241 |
242 | func (c *context) HTMLTpl(code int, tml string) (err error) {
243 | s, e := os.ReadFile(tml)
244 | if e != nil {
245 | return e
246 | }
247 | return c.HTMLBlob(code, s)
248 | }
249 |
250 | func (c *context) HTMLBlob(code int, b []byte) (err error) {
251 | return c.Blob(code, MIMETextHTMLCharsetUTF8, b)
252 | }
253 |
254 | func (c *context) Blob(code int, contentType string, b []byte) (err error) {
255 | if !c.writermem.Written() {
256 | c.writeContentType(contentType)
257 | c.w.WriteHeader(code)
258 | if _, err = c.w.Write(b); err != nil {
259 | c.Logger().Error(err.Error())
260 | }
261 | }
262 | return
263 | }
264 |
265 | func (c *context) JSON(code int, i interface{}) (err error) {
266 | if !c.writermem.Written() {
267 | enc := json.NewEncoder(c.w)
268 | c.writeContentType(MIMEApplicationJSONCharsetUTF8)
269 | c.w.WriteHeader(code)
270 | return enc.Encode(i)
271 | }
272 | return
273 | }
274 |
275 | func (c *context) ProtoBuf(code int, i proto.Message) (err error) {
276 | if !c.writermem.Written() {
277 | c.writeContentType(MIMEApplicationProtobuf)
278 | c.w.WriteHeader(code)
279 | b, err := proto.Marshal(i)
280 | if err != nil {
281 | return err
282 | }
283 | if _, err = c.w.Write(b); err != nil {
284 | return err
285 | }
286 | }
287 | return
288 | }
289 |
290 | func (c *context) JSONP(code int, fn string, i interface{}) error {
291 | enc := json.NewEncoder(c.w)
292 | c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8)
293 | c.w.WriteHeader(code)
294 | if _, err := c.w.Write([]byte(fn + "(")); err != nil {
295 | return err
296 | }
297 | if err := enc.Encode(i); err != nil {
298 | return err
299 | }
300 | if _, err := c.w.Write([]byte(");")); err != nil {
301 | return err
302 | }
303 | return nil
304 | }
305 |
306 | func (c *context) String(code int, s string) error {
307 | return c.Blob(code, MIMETextPlainCharsetUTF8, StringToBytes(s))
308 | }
309 |
310 | func (c *context) Status(code int) {
311 | c.w.WriteHeader(code)
312 | }
313 |
314 | func (c *context) SetHeader(key string, value string) {
315 | c.w.Header().Set(key, value)
316 | }
317 |
318 | func (c *context) AddHeader(key string, value string) {
319 | c.w.Header().Add(key, value)
320 | }
321 |
322 | func (c *context) GetHeader(key string) string {
323 | return c.r.Header.Get(key)
324 | }
325 |
326 | func (c *context) Params(name string) string {
327 | for _, i := range *c.params {
328 | if i.Key == name {
329 | return i.Value
330 | }
331 | }
332 | return ""
333 | }
334 | func (c *context) QueryParams() map[string][]string {
335 | return c.r.URL.Query()
336 | }
337 |
338 | func (c *context) QueryParam(name string) string {
339 | if c.queryList == nil {
340 | c.queryList = c.r.URL.Query()
341 | }
342 | return c.queryList.Get(name)
343 | }
344 |
345 | func (c *context) QueryString() string {
346 | return c.r.URL.RawQuery
347 | }
348 |
349 | func (c *context) FormValue(name string) string {
350 | return c.r.FormValue(name)
351 | }
352 |
353 | func (c *context) FormParams() (url.Values, error) {
354 | if strings.HasPrefix(c.r.Header.Get(HeaderContentType), MIMEMultipartForm) {
355 | if err := c.r.ParseMultipartForm(defaultMemory); err != nil {
356 | return nil, err
357 | }
358 | } else {
359 | if err := c.r.ParseForm(); err != nil {
360 | return nil, err
361 | }
362 | }
363 | return c.r.Form, nil
364 | }
365 |
366 | func (c *context) FormFile(name string) (*multipart.FileHeader, error) {
367 | _, fd, err := c.r.FormFile(name)
368 | return fd, err
369 | }
370 |
371 | func (c *context) File(file string) error {
372 | fd, err := os.Open(file)
373 | if err != nil {
374 | return err
375 | }
376 |
377 | defer fd.Close()
378 |
379 | f, _ := fd.Stat()
380 | if f.IsDir() {
381 | file = filepath.Join(file, indexPage)
382 | fd, err = os.Open(file)
383 | if err != nil {
384 | return ErrNotFoundHandler
385 | }
386 | defer fd.Close()
387 | if f, err = fd.Stat(); err != nil {
388 | return err
389 | }
390 | }
391 | http.ServeContent(c.Response(), c.Request(), f.Name(), f.ModTime(), fd)
392 | return nil
393 | }
394 |
395 | func (c *context) MultipartForm() (*multipart.Form, error) {
396 | err := c.r.ParseMultipartForm(defaultMemory)
397 | return c.r.MultipartForm, err
398 | }
399 |
400 | func (c *context) Cookie(name string) (*http.Cookie, error) {
401 | return c.r.Cookie(name)
402 | }
403 |
404 | func (c *context) SetCookie(cookie *http.Cookie) {
405 | http.SetCookie(c.w, cookie)
406 | }
407 |
408 | func (c *context) Cookies() []*http.Cookie {
409 | return c.r.Cookies()
410 | }
411 |
412 | func (c *context) RequestURI() string {
413 | return c.r.RequestURI
414 | }
415 |
416 | func (c *context) Scheme() string {
417 | scheme := "http"
418 | if scheme := c.r.Header.Get(HeaderXForwardedProto); scheme != "" {
419 | return scheme
420 | }
421 | if scheme := c.r.Header.Get(HeaderXForwardedProtocol); scheme != "" {
422 | return scheme
423 | }
424 | if ssl := c.r.Header.Get(HeaderXForwardedSsl); ssl == "on" {
425 | return "https"
426 | }
427 | if scheme := c.r.Header.Get(HeaderXUrlScheme); scheme != "" {
428 | return scheme
429 | }
430 | return scheme
431 | }
432 |
433 | func (c *context) IsTLS() bool {
434 | return c.r.TLS != nil
435 | }
436 |
437 | func (c *context) IsWebsocket() bool {
438 | if strings.Contains(strings.ToLower(c.r.Header.Get(HeaderConnection)), "upgrade") &&
439 | strings.EqualFold(c.r.Header.Get(HeaderUpgrade), "websocket") {
440 | return true
441 | }
442 | return false
443 | }
444 |
445 | func (c *context) Redirect(code int, uri string) error {
446 | if code < 300 || code > 308 {
447 | return ErrInvalidRedirectCode
448 | }
449 | c.w.Header().Set(HeaderLocation, uri)
450 | c.w.WriteHeader(code)
451 | return nil
452 | }
453 |
454 | func (c *context) writeContentType(value string) {
455 | header := c.w.Header()
456 | if header.Get(HeaderContentType) == "" {
457 | header.Set(HeaderContentType, value)
458 | }
459 | }
460 |
--------------------------------------------------------------------------------
/context_test.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | import (
4 | "github.com/stretchr/testify/assert"
5 | "log"
6 | "net/http"
7 | "net/http/httptest"
8 | "strings"
9 | "testing"
10 | )
11 |
// testData is the JSON body used by the context tests; it round-trips through
// res unchanged.
var testData = `{"id":1,"name":"Jon Snow"}`

// res mirrors testData's shape.
type res struct {
	ID   int    `json:"id"`
	Name string `json:"name"`
}
18 |
19 | func TestContextJSON(t *testing.T) {
20 | y := New()
21 | y.POST("/", func(c Context) (err error) {
22 | t := new(res)
23 | if err = c.Bind(&t); err != nil {
24 | return err
25 | }
26 | return c.JSON(http.StatusOK, t)
27 | })
28 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(testData))
29 | req.Header.Set("Content-Type", MIMEApplicationJSON)
30 | rec := httptest.NewRecorder()
31 | y.ServeHTTP(rec, req)
32 | assert.Equal(t, testData+"\n", rec.Body.String())
33 | }
34 |
35 | func TestContextForward(t *testing.T) {
36 | y := New()
37 | y.POST("/", func(c Context) (err error) {
38 | return c.JSON(http.StatusOK, c.RemoteIP())
39 | })
40 | req := httptest.NewRequest(http.MethodPost, "/", nil)
41 | req.Header.Set(HeaderXForwardedFor, "
")
42 | rec := httptest.NewRecorder()
43 | y.ServeHTTP(rec, req)
44 | }
45 |
46 | func TestContextString(t *testing.T) {
47 | y := New()
48 | y.POST("/", func(c Context) (err error) {
49 | return c.String(http.StatusOK, "hello")
50 | })
51 | req := httptest.NewRequest(http.MethodPost, "/", nil)
52 | rec := httptest.NewRecorder()
53 | y.ServeHTTP(rec, req)
54 | assert.Equal(t, "hello", rec.Body.String())
55 | }
56 |
57 | func crashMiddleware() HandlerFunc {
58 | return func(c Context) (err error) {
59 | c.CrashWithStatus(http.StatusUnauthorized)
60 | return
61 | }
62 | }
63 | func sayMiddleware() HandlerFunc {
64 | return func(c Context) (err error) {
65 | log.Println("say")
66 | return
67 | }
68 | }
69 |
70 | func TestCrash(t *testing.T) {
71 | y := New()
72 | y.Use(crashMiddleware())
73 | y.Use(sayMiddleware())
74 | y.GET("/", func(c Context) (err error) {
75 | return c.String(http.StatusOK, "hello")
76 | })
77 | req := httptest.NewRequest(http.MethodGet, "/", nil)
78 | rec := httptest.NewRecorder()
79 | y.ServeHTTP(rec, req)
80 | assert.Equal(t, http.StatusUnauthorized, rec.Body)
81 | }
82 |
83 | func TestRedirect(t *testing.T) {
84 | y := New()
85 | y.GET("/", func(c Context) (err error) {
86 | return c.Redirect(http.StatusMovedPermanently, "/get")
87 | })
88 | y.GET("/get", func(c Context) (err error) {
89 | return c.String(http.StatusOK, "hello")
90 | })
91 | req := httptest.NewRequest(http.MethodGet, "/", nil)
92 | rec := httptest.NewRecorder()
93 | y.ServeHTTP(rec, req)
94 | }
95 |
96 | func BenchmarkAllocJSON(b *testing.B) {
97 | y := New()
98 | y.POST("/", func(c Context) (err error) {
99 | tl := new(res)
100 | if err = c.Bind(&tl); err != nil {
101 |
102 | return err
103 | }
104 | return c.JSON(http.StatusOK, tl)
105 | })
106 | b.ResetTimer()
107 | b.ReportAllocs()
108 | for i := 0; i < b.N; i++ {
109 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(testData))
110 | req.Header.Set("Content-Type", MIMEApplicationJSON)
111 | rec := httptest.NewRecorder()
112 | y.ServeHTTP(rec, req)
113 | }
114 | }
115 |
116 | func BenchmarkAllocString(b *testing.B) {
117 | y := New()
118 | y.POST("/", func(c Context) (err error) {
119 | return c.String(http.StatusOK, "ok")
120 | })
121 | b.ResetTimer()
122 | b.ReportAllocs()
123 | for i := 0; i < b.N; i++ {
124 | req := httptest.NewRequest(http.MethodPost, "/", nil)
125 | rec := httptest.NewRecorder()
126 | y.ServeHTTP(rec, req)
127 | }
128 | }
129 |
--------------------------------------------------------------------------------
/fs.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | import (
4 | "net/http"
5 | "os"
6 | )
7 |
// onlyFilesFS wraps an http.FileSystem so directories opened through it cannot
// be listed.
type onlyFilesFS struct {
	fs http.FileSystem
}

// neuteredReaddirFile is an http.File whose Readdir reports no entries.
type neuteredReaddirFile struct {
	http.File
}

// Dir returns an http.FileSystem rooted at root for use with
// http.FileServer (router.Static uses it internally). With listDirectory set it
// is exactly http.Dir(root); otherwise directory listings come back empty, so
// http.FileServer cannot enumerate files.
func Dir(root string, listDirectory bool) http.FileSystem {
	if listDirectory {
		return http.Dir(root)
	}
	return &onlyFilesFS{fs: http.Dir(root)}
}

// Open implements http.FileSystem, wrapping every opened file so its Readdir
// is disabled.
func (o onlyFilesFS) Open(name string) (http.File, error) {
	file, err := o.fs.Open(name)
	if err != nil {
		return nil, err
	}
	return neuteredReaddirFile{File: file}, nil
}

// Readdir overrides the embedded implementation and always returns nil, nil,
// which disables directory listing.
func (neuteredReaddirFile) Readdir(_ int) ([]os.FileInfo, error) {
	return nil, nil
}
41 | }
42 |
--------------------------------------------------------------------------------
/gen.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | import (
4 | "fmt"
5 | "strings"
6 | )
7 |
8 | // gen.go used to generate Restful code
9 |
// PREFIX is the source template rendered by GenerateRestfulAPI. Placeholders:
// ${PACKAGE}, ${MODAL}, ${PAGE}, ${TP}, ${QUERY_EXPR}, ${QUERYFUNC}.
//
// NOTE(review): the rendered code references identifiers the template never
// defines (i in Paging, lib.Paging, pg, model, u.Page / u.Find), so its output
// does not compile as-is — presumably meant to be completed by hand; confirm.
const PREFIX = `
package ${PACKAGE}

import (
"net/http"
"github.com/cookieY/yee"
"github.com/jinzhu/gorm"
)

${TP}

func Paging(page interface{}, total int) (start int, end int) {
start = i*total - total
end = total
return
}

func Fetch${PACKAGE}(c yee.Context) (err error) {
u := new(FinderPrefix)
if err = c.Bind(u); err != nil {
c.Logger().Error(err.Error())
return
}

var order []${MODAL}

start, end := lib.Paging(u.Page, ${PAGE})

if u.Find.Valve {
model.DB().Model(&${MODAL}{}).
Scopes(
${QUERY_EXPR}
).Count(&pg).Order("id desc").Offset(start).Limit(end).Find(&order)
} else {
model.DB().Model(&model.CoreSqlOrder{}).Count(&pg).Order("id desc").Offset(start).Limit(end).Find(&order)
}

return c.JSON(http.StatusOK, map[string]interface{}{"data": order, "page": pg})
}

${QUERYFUNC}

`

var (
	// QueryExprPrefix renders one gorm scope function AccordingTo<name>; it
	// takes ${EXPR_NAME} and ${QUERY_EXPR} (the Where clause).
	QueryExprPrefix = `
func AccordingTo${EXPR_NAME}(val string) func(db *gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
return db.Where(${QUERY_EXPR},val)
}
}
`

	// FinderPrefix renders the request struct; ${FINDER_EXPR} receives the
	// field list. The inline template comment reads "add tags yourself".
	FinderPrefix = `
type FinderPrefix struct {
Valve bool // 自行添加tag
${FINDER_EXPR}
}
`
)
68 |
// expr describes one generated query condition.
type expr struct {
	Name string `json:"name"` // suffix of AccordingTo<Name> and the finder field name
	Expr string `json:"expr"` // gorm Where clause, e.g. "age > ?"
	TP   string `json:"tp"`   // Go type of the finder field
}

// GenCodeVal configures GenerateRestfulAPI.
type GenCodeVal struct {
	Flag      string `json:"flag"`       // which field the CRUD operations are keyed on
	Package   string `json:"package"`    // project name; the generated package name is derived from it
	QueryExpr []expr `json:"query_expr"` // query conditions
	Page      string `json:"page"`       // page size
	Modal     string `json:"modal"`      // model type substituted for ${MODAL}
}
82 |
83 | func GenerateRestfulAPI(GenCodeVal GenCodeVal) string {
84 | empty := strings.Replace(PREFIX, "${PACKAGE}", GenCodeVal.Package, -1)
85 | empty = strings.Replace(empty, "${MODAL}", GenCodeVal.Modal, -1)
86 | empty = strings.Replace(empty, "${PAGE}", GenCodeVal.Page, -1)
87 | f, s,l := GenQueryExpr(GenCodeVal.QueryExpr)
88 | empty = strings.Replace(empty, "${QUERYFUNC}", f, -1)
89 | empty = strings.Replace(empty, "${TP}", s, -1)
90 | empty = strings.Replace(empty, "${QUERY_EXPR}", l, -1)
91 | return empty
92 | }
93 |
94 | func GenQueryExpr(QueryExpr []expr) (string, string, string) {
95 | funcEmpty := ""
96 | structEmpty := ""
97 | exprList := ""
98 | for _, i := range QueryExpr {
99 | tmpText := ""
100 | tmpText = strings.Replace(QueryExprPrefix, "${EXPR_NAME}", i.Name, -1)
101 | tmpText = strings.Replace(tmpText, "${QUERY_EXPR}", i.Expr, -1)
102 | funcEmpty += tmpText + "\n"
103 | structEmpty += fmt.Sprintf("%s %s \n ", i.Name, i.TP)
104 | exprList += fmt.Sprintf("AccordingTo%s(u.%s),\n ", i.Name, strings.ToLower(i.Name))
105 | }
106 | structEmpty = strings.Replace(FinderPrefix, "${FINDER_EXPR}", structEmpty, -1)
107 | return funcEmpty, structEmpty, exprList
108 | }
109 |
--------------------------------------------------------------------------------
/gen_test.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | import (
4 | "fmt"
5 | "log"
6 | "os"
7 | "strings"
8 | "testing"
9 | )
10 |
// test is a minimal struct template; jk substitutes ${ATTRIBUTE} into it.
var test = `
type GenCodeVal struct {
${ATTRIBUTE}
}
`

// jk renders test with t as the struct body and writes it to ./koala.go,
// truncating any previous content. Failures are logged, not returned (this is
// a debugging helper).
//
// The open flags are OR-ed: AND-ing O_WRONLY with O_CREATE yields 0
// (read-only, no create), so nothing could ever be written. The template is
// rendered into a local, so repeated calls each substitute their own body
// instead of finding the placeholder already consumed.
//
// NOTE: koala.go lacks a package clause; writing it into a package directory
// breaks that package's build. Run jk from a scratch directory.
func jk(t string) {
	content := strings.Replace(test, "${ATTRIBUTE}", t, -1)
	f, err := os.OpenFile("koala.go", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666)
	if err != nil {
		log.Println(err.Error())
		return
	}
	if _, err = f.Write([]byte(content)); err != nil {
		log.Println(err.Error())
	}
	if err = f.Close(); err != nil {
		log.Println(err.Error())
	}
}
29 |
// TestNew prints each single-entry map in a small fixture; it asserts nothing.
func TestNew(t *testing.T) {
	fields := []map[string]string{{"a": "int"}, {"k": "string"}}
	for idx := range fields {
		fmt.Println(fields[idx])
	}
}
38 |
39 | var exprCase = []expr{{Name: "Username", Expr: "username =?", TP: "string"}, {Name: "Age", Expr: "age > ?", TP: "int"}}
40 |
41 | func TestGenQueryExpr(t *testing.T) {
42 |
43 | GenQueryExpr(exprCase)
44 | }
45 |
46 | func TestGenerateRestfulAPI(t *testing.T) {
47 | k := GenCodeVal{
48 | Package: "manage",
49 | QueryExpr: exprCase,
50 | Page: "20",
51 | Modal: "modal.core_account",
52 | }
53 | fmt.Println(GenerateRestfulAPI(k))
54 | }
55 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/cookieY/yee
2 |
3 | go 1.16
4 |
5 | require (
6 | github.com/HdrHistogram/hdrhistogram-go v1.1.2 // indirect
7 | github.com/go-playground/validator/v10 v10.16.0
8 | github.com/golang-jwt/jwt v3.2.2+incompatible
9 | github.com/golang/protobuf v1.5.3
10 | github.com/google/go-cmp v0.5.9 // indirect
11 | github.com/google/uuid v1.1.1
12 | github.com/mattn/go-colorable v0.1.6
13 | github.com/mattn/go-isatty v0.0.14
14 | github.com/opentracing/opentracing-go v1.2.0
15 | github.com/pkg/errors v0.8.1 // indirect
16 | github.com/stretchr/testify v1.8.2
17 | github.com/uber/jaeger-client-go v2.30.0+incompatible
18 | github.com/uber/jaeger-lib v2.4.1+incompatible // indirect
19 | github.com/valyala/fasttemplate v1.1.0
20 | go.uber.org/atomic v1.9.0 // indirect
21 | golang.org/x/net v0.23.0
22 | google.golang.org/protobuf v1.28.0
23 | )
24 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
2 | github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
3 | github.com/HdrHistogram/hdrhistogram-go v1.1.2 h1:5IcZpTvzydCQeHzK4Ef/D5rrSqwxob0t8PQPMybUNFM=
4 | github.com/HdrHistogram/hdrhistogram-go v1.1.2/go.mod h1:yDgFjdqOqDEKOvasDdhWNXYg9BVp4O+o5f6V/ehm6Oo=
5 | github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw=
6 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
7 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
8 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
9 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
10 | github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
11 | github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
12 | github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
13 | github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
14 | github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
15 | github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
16 | github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
17 | github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
18 | github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
19 | github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
20 | github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE=
21 | github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
22 | github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
23 | github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
24 | github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
25 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
26 | github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
27 | github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
28 | github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
29 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
30 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
31 | github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
32 | github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
33 | github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
34 | github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
35 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
36 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
37 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
38 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
39 | github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
40 | github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
41 | github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+vvnE=
42 | github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
43 | github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
44 | github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
45 | github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
46 | github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
47 | github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
48 | github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs=
49 | github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
50 | github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
51 | github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
52 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
53 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
54 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
55 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
56 | github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
57 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
58 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
59 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
60 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
61 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
62 | github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
63 | github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
64 | github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o=
65 | github.com/uber/jaeger-client-go v2.30.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk=
66 | github.com/uber/jaeger-lib v2.4.1+incompatible h1:td4jdvLcExb4cBISKIpHuGoVXh+dVKhn2Um6rjCsSsg=
67 | github.com/uber/jaeger-lib v2.4.1+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U=
68 | github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
69 | github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
70 | github.com/valyala/fasttemplate v1.1.0 h1:RZqt0yGBsps8NGvLSGW804QQqCUYYLsaOjTVHy1Ocw4=
71 | github.com/valyala/fasttemplate v1.1.0/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
72 | github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
73 | go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
74 | go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
75 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
76 | golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
77 | golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
78 | golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
79 | golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
80 | golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
81 | golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
82 | golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
83 | golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
84 | golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
85 | golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
86 | golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY=
87 | golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
88 | golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
89 | golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
90 | golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
91 | golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
92 | golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
93 | golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
94 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
95 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
96 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
97 | golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
98 | golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
99 | golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
100 | golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
101 | golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
102 | golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=
103 | golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
104 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
105 | golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
106 | golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
107 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
108 | golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
109 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
110 | golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
111 | golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
112 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
113 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
114 | golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
115 | golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
116 | golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
117 | golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
118 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
119 | golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
120 | golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
121 | golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
122 | golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
123 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
124 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
125 | golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
126 | golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
127 | golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
128 | golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
129 | golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
130 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
131 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
132 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
133 | golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
134 | golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
135 | golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
136 | golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
137 | golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
138 | golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
139 | golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
140 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
141 | golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
142 | golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
143 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
144 | golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
145 | golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
146 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
147 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
148 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
149 | gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo=
150 | gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
151 | gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
152 | gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
153 | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
154 | google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
155 | google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
156 | google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
157 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
158 | gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
159 | gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
160 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
161 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
162 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
163 | rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
164 |
--------------------------------------------------------------------------------
/h2_test.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | import (
4 | "crypto/tls"
5 | "fmt"
6 | "io/ioutil"
7 | "log"
8 | "net"
9 | "net/http"
10 | "testing"
11 |
12 | "golang.org/x/net/http2"
13 | )
14 |
15 | func TestH2(T *testing.T) {
16 | y := New()
17 | y.GET("/", func(c Context) error {
18 | return c.String(http.StatusOK, "ok")
19 | })
20 | y.RunH2C(":9999")
21 | }
22 |
23 | func TestH2cClient(t *testing.T) {
24 | client := http.Client{
25 | Transport: &http2.Transport{
26 | AllowHTTP: true,
27 | DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
28 | return net.Dial(network, addr)
29 | },
30 | },
31 | }
32 |
33 | resp, err := client.Get("http://localhost:9999")
34 | if err != nil {
35 | log.Fatalf("faild request: %s", err)
36 | }
37 |
38 | defer resp.Body.Close()
39 |
40 | body, err := ioutil.ReadAll(resp.Body)
41 |
42 | if err != nil {
43 | log.Fatalf("read response failed: %s", err)
44 | }
45 | fmt.Printf("proto:%s\ncode %d: %s\n", resp.Proto, resp.StatusCode, string(body))
46 | }
47 |
--------------------------------------------------------------------------------
/h3_test.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | //import (
4 | // "crypto/tls"
5 | // "fmt"
6 | // pb "github.com/cookieY/yee/test"
7 | // "github.com/quic-go/quic-go"
8 | // "github.com/quic-go/quic-go/http3"
9 | // "io/ioutil"
10 | // "net/http"
11 | // "runtime"
12 | // "testing"
13 | //)
14 | //
15 | //const addr = "https://www.henry.com:9999/hello"
16 | //
17 | //func TestH3Server(t *testing.T) {
18 | // y := New()
19 | // y.SetLogLevel(5)
20 | // y.POST("/hello", func(c Context) (err error) {
21 | // u := new(pb.Svr)
22 | // if err := c.Bind(u); err != nil {
23 | // c.Logger().Error(err.Error())
24 | // return err
25 | // }
26 | // c.Logger().Debugf("svr get client data: %s", u.Project)
27 | // svr := pb.Svr{Cloud: "hi"}
28 | // return c.ProtoBuf(http.StatusOK, &svr)
29 | // })
30 | // y.RunH3(":9999", "henry.com+4.pem", "henry.com+4-key.pem")
31 | //}
32 | //
33 | //func TestH3SvrIndex(t *testing.T) {
34 | // y := New()
35 | // y.SetLogLevel(5)
36 | // y.GET("/", func(c Context) (err error) {
37 | //
38 | // return c.JSON(http.StatusOK, "hello")
39 | // })
40 | // y.Run(":445")
41 | //}
42 | //
43 | //func TestH2SvrIndex(t *testing.T) {
44 | // y := New()
45 | // y.SetLogLevel(5)
46 | // y.GET("/", func(c Context) (err error) {
47 | //
48 | // return c.JSON(http.StatusOK, "hello")
49 | // })
50 | // y.RunTLS(":444", "henry.com+4.pem", "henry.com+4-key.pem")
51 | //}
52 | //
53 | //var cs = http.Client{
54 | // Transport: &http3.RoundTripper{
55 | // TLSClientConfig: &tls.Config{},
56 | // QuicConfig: &quic.Config{},
57 | // },
58 | //}
59 | //
60 | //func BenchmarkH3SvrIndex(b *testing.B) {
61 | // b.SetBytes(1024 * 1024)
62 | // for i := 0; i < b.N; i++ {
63 | // http.Get("http://127.0.0.1:445/")
64 | // }
65 | //}
66 | //
67 | //func BenchmarkH2SvrIndex(b *testing.B) {
68 | // b.SetBytes(1024 * 1024)
69 | // for i := 0; i < b.N; i++ {
70 | // http.Get("https://127.0.0.1:444/")
71 | // }
72 | //}
73 | //
74 | //func TestRespProto(t *testing.T) {
75 | // //cs := http.Client{
76 | // // Transport: &http3.RoundTripper{
77 | // // TLSClientConfig: &tls.Config{
78 | // // InsecureSkipVerify: true,
79 | // // },
80 | // // QuicConfig: &quic.Config{},
81 | // // },
82 | // //}
83 | //
84 | // b, err := http.Get("https://127.0.0.1:444/")
85 | // if err != nil {
86 | // t.Error(err)
87 | // }
88 | // c, _ := ioutil.ReadAll(b.Body)
89 | // fmt.Println(string(c))
90 | // fmt.Println(b.Proto)
91 | //}
92 | //
93 | //func TestNewH3Client(t *testing.T) {
94 | // cs := NewH3Client(&CConfig{
95 | // Addr: addr,
96 | // InsecureSkipVerify: true,
97 | // })
98 | // rsp := new(pb.Svr)
99 | // cs.Post(&pb.Svr{Project: "henry"}, rsp)
100 | // fmt.Println(rsp.Cloud)
101 | //}
102 | //
103 | //func TestNewProtoc3(t *testing.T) {
104 | // y := New()
105 | // y.SetLogLevel(5)
106 | // y.POST("/hello", func(c Context) (err error) {
107 | // u := new(pb.Svr)
108 | // if err := c.Bind(u); err != nil {
109 | // c.Logger().Error(err.Error())
110 | // return err
111 | // }
112 | // c.Logger().Debugf("svr get client data: %s", u.Project)
113 | // svr := pb.Svr{Cloud: "hi"}
114 | // return c.ProtoBuf(http.StatusOK, &svr)
115 | // })
116 | // y.Run(":9999")
117 | //}
118 | //
119 | //func BenchmarkH2vsH3(b *testing.B) {
120 | // runtime.GOMAXPROCS(runtime.NumCPU())
121 | // b.Run("http3", BenchmarkH3SvrIndex)
122 | // b.Run("http2", BenchmarkH2SvrIndex)
123 | //}
124 |
--------------------------------------------------------------------------------
/h3client.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | //
4 | //import (
5 | // "bytes"
6 | // "crypto/tls"
7 | // "github.com/quic-go/quic-go"
8 | // "io/ioutil"
9 | // "net/http"
10 | //
11 | // "github.com/cookieY/yee/logger"
12 | // "github.com/golang/protobuf/proto"
13 | // "github.com/quic-go/quic-go/http3"
14 | //)
15 | //
16 | //type transport struct {
17 | // addr string
18 | // insecureSkipVerify bool
19 | // logger logger.Logger
20 | // tripper *http3.RoundTripper
21 | // c *http.Client
22 | //}
23 | //
24 | //type CConfig struct {
25 | // Addr string
26 | // InsecureSkipVerify bool
27 | //}
28 | //
29 | //func NewH3Client(c *CConfig) *transport {
30 | // tripper := &http3.RoundTripper{
31 | // TLSClientConfig: &tls.Config{
32 | // InsecureSkipVerify: true,
33 | // },
34 | // QuicConfig: &quic.Config{},
35 | // }
36 | // return &transport{
37 | // addr: c.Addr,
38 | // insecureSkipVerify: c.InsecureSkipVerify,
39 | // logger: logger.LogCreator(),
40 | // c: &http.Client{
41 | // Transport: tripper,
42 | // },
43 | // tripper: tripper,
44 | // }
45 | //}
46 | //
47 | //func (t *transport) Get(url string) (*http.Response, error) {
48 | // resp, err := t.c.Get(url)
49 | // if err != nil {
50 | // return nil, err
51 | // }
52 | // return resp, nil
53 | //}
54 | //
55 | //func (t *transport) Post(payload proto.Message, recv proto.Message) {
56 | // p, err := proto.Marshal(payload)
57 | // if err != nil {
58 | // t.logger.Critical(err.Error())
59 | // return
60 | // }
61 | // rsp, err := t.c.Post(t.addr, MIMEApplicationProtobuf, bytes.NewReader(p))
62 | // if err != nil {
63 | // t.logger.Critical(err.Error())
64 | // return
65 | // }
66 | // b, err := ioutil.ReadAll(rsp.Body)
67 | // if rsp.StatusCode == 200 {
68 | // err = proto.Unmarshal(b, recv)
69 | // if err != nil {
70 | // t.logger.Critical(err.Error())
71 | // }
72 | // return
73 | // }
74 | // t.logger.Error(string(b))
75 | // defer t.close()
76 | //}
77 | //
78 | //func (t *transport) close() {
79 | // err := t.tripper.Close()
80 | // if err != nil {
81 | // t.logger.Critical(err.Error())
82 | // }
83 | //}
84 |
--------------------------------------------------------------------------------
/img/benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cookieY/yee/c07a8279ce061505162bad9ba33b576302b6fd92/img/benchmark.png
--------------------------------------------------------------------------------
/log_test.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | import (
4 | "github.com/cookieY/yee/logger"
5 | "github.com/stretchr/testify/assert"
6 | "net/http"
7 | "net/http/httptest"
8 | "testing"
9 | )
10 |
11 | func TestLogger_LogWrite(t *testing.T) {
12 |
13 | y := New()
14 | y.SetLogLevel(logger.Warning)
15 | //file, err := os.OpenFile("logrus.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
16 | //if err != nil {
17 | // t.Log(err)
18 | // return
19 | //}
20 | //y.SetLogOut(file)
21 | y.POST("/hello/k/:b", func(c Context) error {
22 | c.Logger().Critical("critical")
23 | c.Logger().Error("error")
24 | c.Logger().Warn("warn")
25 | c.Logger().Info("info")
26 | c.Logger().Debug("debug")
27 | c.Logger().Criticalf("test:%v", 123)
28 | c.Logger().Errorf("test:%v", 123)
29 | c.Logger().Warnf("test:%v", 123)
30 | c.Logger().Infof("test:%v", 123)
31 | c.Logger().Debugf("test:%v", 123)
32 | return c.String(http.StatusOK, c.Params("b"))
33 | })
34 | t.Run("http_get", func(t *testing.T) {
35 | req := httptest.NewRequest(http.MethodPost, "/hello/k/henry", nil)
36 | rec := httptest.NewRecorder()
37 | y.ServeHTTP(rec, req)
38 | tx := assert.New(t)
39 | tx.Equal("henry", rec.Body.String())
40 | //assert.Equal("*", rec.Header().Get(yee.HeaderAccessControlAllowOrigin))
41 | })
42 | }
43 |
44 | func BenchmarkLogger_LogWrite(b *testing.B) {
45 | l := logger.LogCreator()
46 | b.ReportAllocs()
47 | b.SetBytes(1024 * 1024)
48 | for i := 0; i < b.N; i++ {
49 | l.Critical("critical")
50 | l.Error("error")
51 | l.Warn("warn")
52 | l.Info("info")
53 | l.Debug("debug")
54 | l.Criticalf("test:%v", 123)
55 | l.Errorf("test:%v", 123)
56 | l.Warnf("test:%v", 123)
57 | l.Infof("test:%v", 123)
58 | l.Debugf("test:%v", 123)
59 | }
60 | }
61 |
--------------------------------------------------------------------------------
/logger/color.go:
--------------------------------------------------------------------------------
1 | package logger
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "io"
7 | "os"
8 |
9 | "github.com/mattn/go-colorable"
10 | "github.com/mattn/go-isatty"
11 | )
12 |
13 | type (
14 | inner func(interface{}, []string, *Color) string
15 | )
16 |
// Foreground (text) colour codes used in ANSI SGR escape sequences.
const (
	// Blk is the black text style.
	Blk = "30"
	// Rd is the red text style.
	Rd = "31"
	// Grn is the green text style.
	Grn = "32"
	// Yel is the yellow text style.
	Yel = "33"
	// Blu is the blue text style.
	Blu = "34"
	// Mgn is the magenta text style.
	Mgn = "35"
	// Cyn is the cyan text style.
	Cyn = "36"
	// Wht is the white text style.
	Wht = "37"
	// Gry is the grey text style.
	Gry = "90"
)

// Background colour codes used in ANSI SGR escape sequences.
const (
	// BlkBg is the black background style.
	BlkBg = "40"
	// RdBg is the red background style.
	RdBg = "41"
	// GrnBg is the green background style.
	GrnBg = "42"
	// YelBg is the yellow background style.
	YelBg = "43"
	// BluBg is the blue background style.
	BluBg = "44"
	// MgnBg is the magenta background style.
	MgnBg = "45"
	// CynBg is the cyan background style.
	CynBg = "46"
	// WhtBg is the white background style.
	WhtBg = "47"
)

// Emphasis codes used in ANSI SGR escape sequences.
const (
	// R is the reset emphasis style.
	R = "0"
	// B is the bold emphasis style.
	B = "1"
	// D is the dim emphasis style.
	D = "2"
	// I is the italic emphasis style.
	I = "3"
	// U is the underline emphasis style.
	U = "4"
	// In is the inverse emphasis style.
	In = "7"
	// H is the hidden emphasis style.
	H = "8"
	// S is the strikeout emphasis style.
	S = "9"
)
72 |
73 | var (
74 | black = outer(Blk)
75 | red = outer(Rd)
76 | green = outer(Grn)
77 | yellow = outer(Yel)
78 | blue = outer(Blu)
79 | magenta = outer(Mgn)
80 | cyan = outer(Cyn)
81 | white = outer(Wht)
82 | grey = outer(Gry)
83 |
84 | blackBg = outer(BlkBg)
85 | redBg = outer(RdBg)
86 | greenBg = outer(GrnBg)
87 | yellowBg = outer(YelBg)
88 | blueBg = outer(BluBg)
89 | magentaBg = outer(MgnBg)
90 | cyanBg = outer(CynBg)
91 | whiteBg = outer(WhtBg)
92 |
93 | reset = outer(R)
94 | bold = outer(B)
95 | dim = outer(D)
96 | italic = outer(I)
97 | underline = outer(U)
98 | inverse = outer(In)
99 | hidden = outer(H)
100 | strikeout = outer(S)
101 |
102 | global = New()
103 | )
104 |
105 | func outer(n string) inner {
106 | return func(msg interface{}, styles []string, c *Color) string {
107 | // TODO: Drop fmt to boost performance?
108 | if c.disabled {
109 | return fmt.Sprintf("%v", msg)
110 | }
111 |
112 | b := new(bytes.Buffer)
113 | b.WriteString("\x1b[")
114 | b.WriteString(n)
115 | for _, s := range styles {
116 | b.WriteString(";")
117 | b.WriteString(s)
118 | }
119 | b.WriteString("m")
120 | return fmt.Sprintf("%s%v\x1b[0m", b.String(), msg)
121 | }
122 | }
123 |
// Color renders ANSI-styled strings and writes formatted output to a
// configurable writer. While disabled is set, its colour and style helpers
// return the plain message without escape sequences.
type Color struct {
	output   io.Writer // destination of Print, Println and Printf
	disabled bool      // true suppresses escape sequences
}
130 |
131 | // New creates a Color instance.
132 | func New() (c *Color) {
133 | c = new(Color)
134 | c.SetOutput(colorable.NewColorableStdout())
135 | return
136 | }
137 |
138 | // Output returns the output.
139 | func (c *Color) Output() io.Writer {
140 | return c.output
141 | }
142 |
143 | // SetOutput sets the output.
144 | func (c *Color) SetOutput(w io.Writer) {
145 | c.output = w
146 | if w, ok := w.(*os.File); !ok || !isatty.IsTerminal(w.Fd()) {
147 | c.disabled = true
148 | }
149 | }
150 |
151 | // Disable disables the colors and styles.
152 | func (c *Color) Disable() {
153 | c.disabled = true
154 | }
155 |
156 | // Enable enables the colors and styles.
157 | func (c *Color) Enable() {
158 | c.disabled = false
159 | }
160 |
161 | // Print is analogous to `fmt.Print` with termial detection.
162 | func (c *Color) Print(args ...interface{}) {
163 | fmt.Fprint(c.output, args...)
164 | }
165 |
166 | // Println is analogous to `fmt.Println` with termial detection.
167 | func (c *Color) Println(args ...interface{}) {
168 | fmt.Fprintln(c.output, args...)
169 | }
170 |
171 | // Printf is analogous to `fmt.Printf` with termial detection.
172 | func (c *Color) Printf(format string, args ...interface{}) {
173 | fmt.Fprintf(c.output, format, args...)
174 | }
175 |
176 | func (c *Color) Black(msg interface{}, styles ...string) string {
177 | return black(msg, styles, c)
178 | }
179 |
180 | func (c *Color) Red(msg interface{}, styles ...string) string {
181 | return red(msg, styles, c)
182 | }
183 |
184 | func (c *Color) Green(msg interface{}, styles ...string) string {
185 | return green(msg, styles, c)
186 | }
187 |
188 | func (c *Color) Yellow(msg interface{}, styles ...string) string {
189 | return yellow(msg, styles, c)
190 | }
191 |
192 | func (c *Color) Blue(msg interface{}, styles ...string) string {
193 | return blue(msg, styles, c)
194 | }
195 |
196 | func (c *Color) Magenta(msg interface{}, styles ...string) string {
197 | return magenta(msg, styles, c)
198 | }
199 |
200 | func (c *Color) Cyan(msg interface{}, styles ...string) string {
201 | return cyan(msg, styles, c)
202 | }
203 |
204 | func (c *Color) White(msg interface{}, styles ...string) string {
205 | return white(msg, styles, c)
206 | }
207 |
208 | func (c *Color) Grey(msg interface{}, styles ...string) string {
209 | return grey(msg, styles, c)
210 | }
211 |
212 | func (c *Color) BlackBg(msg interface{}, styles ...string) string {
213 | return blackBg(msg, styles, c)
214 | }
215 |
216 | func (c *Color) RedBg(msg interface{}, styles ...string) string {
217 | return redBg(msg, styles, c)
218 | }
219 |
220 | func (c *Color) GreenBg(msg interface{}, styles ...string) string {
221 | return greenBg(msg, styles, c)
222 | }
223 |
224 | func (c *Color) YellowBg(msg interface{}, styles ...string) string {
225 | return yellowBg(msg, styles, c)
226 | }
227 |
228 | func (c *Color) BlueBg(msg interface{}, styles ...string) string {
229 | return blueBg(msg, styles, c)
230 | }
231 |
232 | func (c *Color) MagentaBg(msg interface{}, styles ...string) string {
233 | return magentaBg(msg, styles, c)
234 | }
235 |
236 | func (c *Color) CyanBg(msg interface{}, styles ...string) string {
237 | return cyanBg(msg, styles, c)
238 | }
239 |
240 | func (c *Color) WhiteBg(msg interface{}, styles ...string) string {
241 | return whiteBg(msg, styles, c)
242 | }
243 |
244 | func (c *Color) Reset(msg interface{}, styles ...string) string {
245 | return reset(msg, styles, c)
246 | }
247 |
248 | func (c *Color) Bold(msg interface{}, styles ...string) string {
249 | return bold(msg, styles, c)
250 | }
251 |
252 | func (c *Color) Dim(msg interface{}, styles ...string) string {
253 | return dim(msg, styles, c)
254 | }
255 |
256 | func (c *Color) Italic(msg interface{}, styles ...string) string {
257 | return italic(msg, styles, c)
258 | }
259 |
260 | func (c *Color) Underline(msg interface{}, styles ...string) string {
261 | return underline(msg, styles, c)
262 | }
263 |
264 | func (c *Color) Inverse(msg interface{}, styles ...string) string {
265 | return inverse(msg, styles, c)
266 | }
267 |
268 | func (c *Color) Hidden(msg interface{}, styles ...string) string {
269 | return hidden(msg, styles, c)
270 | }
271 |
272 | func (c *Color) Strikeout(msg interface{}, styles ...string) string {
273 | return strikeout(msg, styles, c)
274 | }
275 |
276 | // Output returns the output.
277 | func Output() io.Writer {
278 | return global.output
279 | }
280 |
281 | // SetOutput sets the output.
282 | func SetOutput(w io.Writer) {
283 | global.SetOutput(w)
284 | }
285 |
286 | func Disable() {
287 | global.Disable()
288 | }
289 |
290 | func Enable() {
291 | global.Enable()
292 | }
293 |
294 | // Print is analogous to `fmt.Print` with termial detection.
295 | func Print(args ...interface{}) {
296 | global.Print(args...)
297 | }
298 |
299 | // Println is analogous to `fmt.Println` with termial detection.
300 | func Println(args ...interface{}) {
301 | global.Println(args...)
302 | }
303 |
304 | // Printf is analogous to `fmt.Printf` with termial detection.
305 | func Printf(format string, args ...interface{}) {
306 | global.Printf(format, args...)
307 | }
308 |
309 | func Black(msg interface{}, styles ...string) string {
310 | return global.Black(msg, styles...)
311 | }
312 |
313 | func Red(msg interface{}, styles ...string) string {
314 | return global.Red(msg, styles...)
315 | }
316 |
317 | func Green(msg interface{}, styles ...string) string {
318 | return global.Green(msg, styles...)
319 | }
320 |
321 | func Yellow(msg interface{}, styles ...string) string {
322 | return global.Yellow(msg, styles...)
323 | }
324 |
325 | func Blue(msg interface{}, styles ...string) string {
326 | return global.Blue(msg, styles...)
327 | }
328 |
329 | func Magenta(msg interface{}, styles ...string) string {
330 | return global.Magenta(msg, styles...)
331 | }
332 |
333 | func Cyan(msg interface{}, styles ...string) string {
334 | return global.Cyan(msg, styles...)
335 | }
336 |
337 | func White(msg interface{}, styles ...string) string {
338 | return global.White(msg, styles...)
339 | }
340 |
341 | func Grey(msg interface{}, styles ...string) string {
342 | return global.Grey(msg, styles...)
343 | }
344 |
345 | func BlackBg(msg interface{}, styles ...string) string {
346 | return global.BlackBg(msg, styles...)
347 | }
348 |
349 | func RedBg(msg interface{}, styles ...string) string {
350 | return global.RedBg(msg, styles...)
351 | }
352 |
353 | func GreenBg(msg interface{}, styles ...string) string {
354 | return global.GreenBg(msg, styles...)
355 | }
356 |
357 | func YellowBg(msg interface{}, styles ...string) string {
358 | return global.YellowBg(msg, styles...)
359 | }
360 |
361 | func BlueBg(msg interface{}, styles ...string) string {
362 | return global.BlueBg(msg, styles...)
363 | }
364 |
365 | func MagentaBg(msg interface{}, styles ...string) string {
366 | return global.MagentaBg(msg, styles...)
367 | }
368 |
369 | func CyanBg(msg interface{}, styles ...string) string {
370 | return global.CyanBg(msg, styles...)
371 | }
372 |
373 | func WhiteBg(msg interface{}, styles ...string) string {
374 | return global.WhiteBg(msg, styles...)
375 | }
376 |
377 | func Reset(msg interface{}, styles ...string) string {
378 | return global.Reset(msg, styles...)
379 | }
380 |
381 | func Bold(msg interface{}, styles ...string) string {
382 | return global.Bold(msg, styles...)
383 | }
384 |
385 | func Dim(msg interface{}, styles ...string) string {
386 | return global.Dim(msg, styles...)
387 | }
388 |
389 | func Italic(msg interface{}, styles ...string) string {
390 | return global.Italic(msg, styles...)
391 | }
392 |
393 | func Underline(msg interface{}, styles ...string) string {
394 | return global.Underline(msg, styles...)
395 | }
396 |
397 | func Inverse(msg interface{}, styles ...string) string {
398 | return global.Inverse(msg, styles...)
399 | }
400 |
401 | func Hidden(msg interface{}, styles ...string) string {
402 | return global.Hidden(msg, styles...)
403 | }
404 |
405 | func Strikeout(msg interface{}, styles ...string) string {
406 | return global.Strikeout(msg, styles...)
407 | }
--------------------------------------------------------------------------------
/logger/color_test.go:
--------------------------------------------------------------------------------
1 | package logger
2 |
3 | import (
4 | "os"
5 | "testing"
6 | )
7 |
8 | func TestColor_Black(t *testing.T) {
9 | c := New()
10 | c.Enable()
11 | _, _ = os.Stdout.Write(append([]byte(c.Blue("bule")), '\n'))
12 | }
13 |
--------------------------------------------------------------------------------
/logger/log.go:
--------------------------------------------------------------------------------
1 | package logger
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "os"
7 | "runtime"
8 | "strings"
9 | "sync"
10 | "time"
11 | )
12 |
// Log levels. Lower values are more severe; logWrite drops a message whose
// level value exceeds the logger's configured level (outside request-logger
// mode).
const (
	Critical = iota // 0: most severe
	Error           // 1
	Warning         // 2
	Info            // 3
	Debug           // 4: most verbose
)

// timeFormat is the time layout used for the timestamp in every log line.
const timeFormat = "2006-01-02 15:04:05"
23 |
24 | type logger struct {
25 | sync.Mutex
26 | level uint8
27 | isLogger bool
28 | version string
29 | producer *Color
30 | out io.Writer
31 | noColor bool
32 | }
33 |
// Logger is a levelled logger.
//
// The plain methods accept a message value (the default implementation in
// this package formats errors, strings and other values); the *f variants
// take a fmt-style format string followed by its arguments.
type Logger interface {
	Critical(msg interface{})
	Error(msg interface{})
	Warn(msg interface{})
	Info(msg interface{})
	Debug(msg interface{})
	// The parameter was previously named "error", shadowing the predeclared
	// type; parameter names are not part of the interface's method set, so
	// the rename is fully compatible.
	Criticalf(format string, args ...interface{})
	Errorf(format string, args ...interface{})
	Warnf(format string, args ...interface{})
	Infof(format string, args ...interface{})
	Debugf(format string, args ...interface{})
	// Custom writes msg without level tag or timestamp.
	Custom(msg string)
	// SetLevel sets the level threshold.
	SetLevel(level uint8)
	// SetOut redirects output to out.
	SetOut(out io.Writer)
	// IsLogger toggles request-logger mode.
	IsLogger(enabled bool)
}
51 |
52 | // LogCreator ...
53 | func LogCreator(args ...int) Logger {
54 | l := new(logger)
55 | l.producer = New()
56 | l.producer.Enable()
57 | l.level = 1
58 | if len(args) > 0 {
59 | l.level = uint8(args[0])
60 | }
61 | l.out = os.Stdout
62 | return l
63 | }
64 |
65 | var DefaultLogger = LogCreator()
66 |
67 | func (l *logger) SetOut(out io.Writer) {
68 | l.Lock()
69 | defer l.Unlock()
70 | l.out = out
71 | l.noColor = true
72 | }
73 |
74 | func (l *logger) SetLevel(level uint8) {
75 | l.Lock()
76 | defer l.Unlock()
77 | l.level = level
78 | }
79 |
80 | func (l *logger) IsLogger(p bool) {
81 | l.Lock()
82 | defer l.Unlock()
83 | l.isLogger = p
84 | }
85 |
86 | var mappingLevel = map[uint8]string{
87 | Critical: "Critical",
88 | Error: "Error",
89 | Warning: "Warn",
90 | Info: "Info",
91 | Debug: "Debug",
92 | }
93 |
94 | func (l *logger) logWrite(msg interface{}, level uint8) (string, bool) {
95 | var msgText string
96 | switch v := msg.(type) {
97 | case error:
98 | msgText = v.Error()
99 | case string:
100 | msgText = v
101 | }
102 |
103 | if level > l.level && !l.isLogger {
104 | return "", false
105 | }
106 |
107 | if !l.isLogger {
108 | _, file, lineno, ok := runtime.Caller(2)
109 |
110 | src := ""
111 |
112 | if ok {
113 | src = strings.Replace(
114 | fmt.Sprintf("%s:%d", file, lineno), "%2e", ".", -1)
115 | }
116 | msgText = fmt.Sprintf("%s [%s] %s (%s) %s", l.version, mappingLevel[level], time.Now().Format(timeFormat), src, msgText)
117 | } else {
118 | msgText = fmt.Sprintf("%s [%s] %s %s", l.version, mappingLevel[level], time.Now().Format(timeFormat), msgText)
119 | }
120 |
121 | return msgText, true
122 | }
123 |
124 | func (l *logger) print(msg string) {
125 | l.Lock()
126 | defer l.Unlock()
127 | _, err := l.out.Write(append([]byte(msg), '\n'))
128 | if err != nil {
129 | _, _ = os.Stdout.Write(append([]byte(msg), '\n'))
130 | }
131 | }
132 |
133 | func (l *logger) Custom(msg string) {
134 | l.print(msg)
135 | }
136 |
137 | func (l *logger) Critical(msg interface{}) {
138 | if msg, ok := l.logWrite(msg, Critical); ok {
139 | l.dyer(Critical, &msg)
140 | }
141 | }
142 |
143 | func (l *logger) Criticalf(error string, msg ...interface{}) {
144 | if msg, ok := l.logWrite(fmt.Sprintf(error, msg...), Critical); ok {
145 | l.dyer(Critical, &msg)
146 | }
147 | }
148 |
149 | func (l *logger) Error(msg interface{}) {
150 | if msg, ok := l.logWrite(msg, Error); ok {
151 | l.dyer(Error, &msg)
152 | }
153 | }
154 |
155 | func (l *logger) Errorf(error string, msg ...interface{}) {
156 | if msg, ok := l.logWrite(fmt.Sprintf(error, msg...), Error); ok {
157 | l.dyer(Error, &msg)
158 | }
159 | }
160 |
161 | func (l *logger) Warn(msg interface{}) {
162 | if msg, ok := l.logWrite(msg, Warning); ok {
163 | l.dyer(Warning, &msg)
164 | }
165 | }
166 |
167 | func (l *logger) Warnf(error string, msg ...interface{}) {
168 | if msg, ok := l.logWrite(fmt.Sprintf(error, msg...), Warning); ok {
169 | l.dyer(Warning, &msg)
170 | }
171 | }
172 |
173 | func (l *logger) Info(msg interface{}) {
174 | if msg, ok := l.logWrite(msg, Info); ok {
175 | l.dyer(Info, &msg)
176 | }
177 | }
178 |
179 | func (l *logger) Infof(error string, msg ...interface{}) {
180 | if msg, ok := l.logWrite(fmt.Sprintf(error, msg...), Info); ok {
181 | l.dyer(Info, &msg)
182 | }
183 | }
184 |
185 | func (l *logger) Debugf(error string, msg ...interface{}) {
186 | if msg, ok := l.logWrite(fmt.Sprintf(error, msg...), Debug); ok {
187 | l.dyer(Debug, &msg)
188 | }
189 | }
190 |
191 | func (l *logger) Debug(msg interface{}) {
192 | if msg, ok := l.logWrite(msg, Debug); ok {
193 | l.dyer(Debug, &msg)
194 | }
195 | }
196 |
197 | func (l *logger) dyer(level int, msg *string) {
198 | if l.noColor {
199 | l.print(*msg)
200 | return
201 | }
202 | switch level {
203 | case Critical:
204 | l.print(l.producer.Red(*msg))
205 | case Error:
206 | l.print(l.producer.Magenta(*msg))
207 | case Warning:
208 | l.print(l.producer.Yellow(*msg))
209 | case Info:
210 | l.print(l.producer.Blue(*msg))
211 | case Debug:
212 | l.print(l.producer.Cyan(*msg))
213 | }
214 | }
215 |
--------------------------------------------------------------------------------
/logrus.log:
--------------------------------------------------------------------------------
1 | [Critical] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:23) critical
2 | [Error] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:24) error
3 | [Warn] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:25) warn
4 | [Info] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:26) info
5 | [Debug] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:27) debug
6 | [Critical] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:28) test:123
7 | [Error] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:29) test:123
8 | [Warn] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:30) test:123
9 | [Info] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:31) test:123
10 | [Debug] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:32) test:123
--------------------------------------------------------------------------------
/middleware/basicAuth.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "encoding/base64"
5 | "errors"
6 | "net/http"
7 | "strings"
8 |
9 | "github.com/cookieY/yee"
10 | )
11 |
// BasicAuthConfig configures the BasicAuth middleware.
type BasicAuthConfig struct {
	// Validator receives the base64-decoded payload of the Authorization
	// header. Access is granted only when it returns true and a nil error.
	// Required: BasicAuthWithConfig panics when it is nil.
	Validator fnValidator

	// Realm is advertised in the WWW-Authenticate challenge of a 401.
	Realm string
}

// fnValidator checks decoded credentials.
type fnValidator func([]byte) (bool, error)

// basic is the lower-cased authentication scheme name.
const basic = "basic"
23 |
24 | // BasicAuth is the default implementation BasicAuth middleware
25 | func BasicAuth(fn fnValidator) yee.HandlerFunc {
26 | config := BasicAuthConfig{Validator: fn}
27 | config.Realm = "."
28 | return BasicAuthWithConfig(config)
29 | }
30 |
31 | // BasicAuthWithConfig is the custom implementation BasicAuth middleware
32 | func BasicAuthWithConfig(config BasicAuthConfig) yee.HandlerFunc {
33 |
34 | if config.Validator == nil {
35 | panic("yee: basic-auth middleware requires a validator function")
36 | }
37 |
38 | return func(context yee.Context) (err error) {
39 | decode, _ := parserVerifyData(context)
40 | if verify, err := config.Validator(decode); err == nil && verify {
41 | return err
42 | }
43 |
44 | context.Response().Header().Set(yee.HeaderWWWAuthenticate, basic+" realm="+config.Realm)
45 |
46 | return context.ServerError(http.StatusUnauthorized, "invalid basic auth token")
47 | }
48 | }
49 |
50 | func parserVerifyData(context yee.Context) ([]byte, error) {
51 | var decode []byte
52 | res := context.Request()
53 | if res.Header.Get(yee.HeaderAuthorization) != "" {
54 | auth := strings.Split(res.Header.Get(yee.HeaderAuthorization), " ")
55 | if strings.ToLower(auth[0]) == basic {
56 | decode, err := base64.StdEncoding.DecodeString(auth[1])
57 | if err != nil {
58 | return decode, err
59 | }
60 | return decode, nil
61 | }
62 | return decode, errors.New("cannot get basic keyword")
63 | }
64 | return decode, errors.New("authorization header is empty")
65 | }
66 |
--------------------------------------------------------------------------------
/middleware/basicAuth_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "encoding/base64"
5 | "encoding/json"
6 | "net/http"
7 | "net/http/httptest"
8 | "testing"
9 |
10 | "github.com/cookieY/yee"
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
// user is the JSON credential payload used by the middleware tests.
type user struct {
	Username string `json:"username"`
	Password string `json:"password"`
}

// validator accepts exactly the test/123123 pair, encoded as JSON. Payloads
// that are not valid JSON are reported as errors.
func validator(auth []byte) (bool, error) {
	var u user
	err := json.Unmarshal(auth, &u)
	if err != nil {
		return false, err
	}
	matched := u.Username == "test" && u.Password == "123123"
	return matched, nil
}

// testUser is the credential set that validator accepts.
var testUser = map[string]string{"username": "test", "password": "123123"}
32 |
33 | func TestBasicAuth(t *testing.T) {
34 | y := yee.New()
35 | y.Use(BasicAuth(validator))
36 | y.GET("/", func(context yee.Context) error {
37 | return context.String(http.StatusOK, "ok")
38 | })
39 |
40 | req := httptest.NewRequest(http.MethodGet, "/", nil)
41 | rec := httptest.NewRecorder()
42 |
43 | y.ServeHTTP(rec, req)
44 |
45 | assert.Equal(t, http.StatusUnauthorized, rec.Code)
46 |
47 | req = httptest.NewRequest(http.MethodGet, "/", nil)
48 | rec = httptest.NewRecorder()
49 | u, _ := json.Marshal(testUser)
50 | encodeString := base64.StdEncoding.EncodeToString(u)
51 | req.Header.Set(yee.HeaderAuthorization, "basic "+encodeString)
52 | y.ServeHTTP(rec, req)
53 | assert.Equal(t, http.StatusOK, rec.Code)
54 |
55 | }
56 |
--------------------------------------------------------------------------------
/middleware/cors.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "net/http"
5 | "strconv"
6 | "strings"
7 |
8 | "github.com/cookieY/yee"
9 | )
10 |
// CORSConfig configures the CORS middleware.
type CORSConfig struct {
	// Origins lists the allowed origins; "*" allows any. With
	// AllowCredentials, "*" reflects the request's Origin instead.
	Origins []string

	// AllowMethods is joined into Access-Control-Allow-Methods on preflight.
	AllowMethods []string

	// AllowHeaders is joined into Access-Control-Allow-Headers on preflight.
	// When empty, the request's Access-Control-Request-Headers is echoed.
	AllowHeaders []string

	// AllowCredentials adds Access-Control-Allow-Credentials: true.
	AllowCredentials bool

	// ExposeHeaders is joined into Access-Control-Expose-Headers on
	// non-preflight responses.
	ExposeHeaders []string

	// MaxAge, in seconds, is sent as Access-Control-Max-Age when positive.
	MaxAge int
}

// defaultCORSMethods is the method list advertised by DefaultCORSConfig.
var defaultCORSMethods = []string{
	http.MethodGet,
	http.MethodPut,
	http.MethodPost,
	http.MethodDelete,
	http.MethodPatch,
	http.MethodHead,
	http.MethodOptions,
	http.MethodConnect,
	http.MethodTrace,
}

// DefaultCORSConfig allows every origin and every standard HTTP method.
var DefaultCORSConfig = CORSConfig{
	Origins:      []string{"*"},
	AllowMethods: defaultCORSMethods,
}
36 |
37 | // Cors is the default implementation CORS middleware
38 | func Cors() yee.HandlerFunc {
39 | return CorsWithConfig(DefaultCORSConfig)
40 | }
41 |
42 | // CorsWithConfig is the default implementation CORS middleware
43 | func CorsWithConfig(config CORSConfig) yee.HandlerFunc {
44 |
45 | if len(config.Origins) == 0 {
46 | config.Origins = DefaultCORSConfig.Origins
47 | }
48 |
49 | if len(config.AllowMethods) == 0 {
50 | config.AllowMethods = DefaultCORSConfig.AllowMethods
51 | }
52 |
53 | allowMethods := strings.Join(config.AllowMethods, ",")
54 |
55 | allowHeaders := strings.Join(config.AllowHeaders, ",")
56 |
57 | exposeHeaders := strings.Join(config.ExposeHeaders, ",")
58 |
59 | maxAge := strconv.Itoa(config.MaxAge)
60 |
61 | return func(c yee.Context) (err error) {
62 |
63 | localOrigin := c.GetHeader(yee.HeaderOrigin)
64 |
65 | allowOrigin := ""
66 |
67 | m := c.Request().Method
68 |
69 | for _, o := range config.Origins {
70 | if o == "*" && config.AllowCredentials {
71 | allowOrigin = localOrigin
72 | break
73 | }
74 | if o == "*" || o == localOrigin {
75 | allowOrigin = o
76 | break
77 | }
78 | }
79 |
80 | // when method was not OPTIONS,
81 | // we can return simple response header
82 | // because the OPTIONS method is used to
83 | // describe the communication options for the target resource
84 | // https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS
85 |
86 | if m != http.MethodOptions {
87 | c.AddHeader(yee.HeaderVary, yee.HeaderOrigin)
88 | c.SetHeader(yee.HeaderAccessControlAllowOrigin, allowOrigin)
89 | if config.AllowCredentials {
90 | c.SetHeader(yee.HeaderAccessControlAllowCredentials, "true")
91 | }
92 | if exposeHeaders != "" {
93 | c.SetHeader(yee.HeaderAccessControlExposeHeaders, exposeHeaders)
94 | }
95 | c.Next()
96 | return
97 | }
98 |
99 | c.AddHeader(yee.HeaderVary, yee.HeaderOrigin)
100 | c.AddHeader(yee.HeaderVary, yee.HeaderAccessControlRequestMethod)
101 | c.AddHeader(yee.HeaderVary, yee.HeaderAccessControlRequestHeaders)
102 | c.SetHeader(yee.HeaderAccessControlAllowOrigin, allowOrigin)
103 | c.SetHeader(yee.HeaderAccessControlAllowMethods, allowMethods)
104 | if config.AllowCredentials {
105 | c.SetHeader(yee.HeaderAccessControlAllowCredentials, "true")
106 | }
107 | if allowHeaders != "" {
108 | c.SetHeader(yee.HeaderAccessControlAllowHeaders, allowHeaders)
109 | } else {
110 | h := c.GetHeader(yee.HeaderAccessControlRequestHeaders)
111 | if h != "" {
112 | c.SetHeader(yee.HeaderAccessControlAllowHeaders, h)
113 | }
114 | }
115 | if config.MaxAge > 0 {
116 | c.SetHeader(yee.HeaderAccessControlMaxAge, maxAge)
117 | }
118 | return
119 | }
120 | }
121 |
--------------------------------------------------------------------------------
/middleware/cors_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "github.com/cookieY/yee"
5 | "github.com/stretchr/testify/assert"
6 | "net/http"
7 | "net/http/httptest"
8 | "testing"
9 | )
10 |
11 | func TestCors(t *testing.T) {
12 | y := yee.New()
13 | y.Use(Cors())
14 |
15 | y.POST("/login", func(c yee.Context) error {
16 | return c.String(http.StatusOK, "test")
17 | })
18 |
19 | y.OPTIONS("/ok", func(context yee.Context) (err error) {
20 | return err
21 | })
22 |
23 | t.Run("http_get", func(t *testing.T) {
24 | req := httptest.NewRequest(http.MethodGet, "/ok", nil)
25 | rec := httptest.NewRecorder()
26 | y.ServeHTTP(rec, req)
27 | assert := assert.New(t)
28 | assert.Equal("test", rec.Body.String())
29 | assert.Equal("*", rec.Header().Get(yee.HeaderAccessControlAllowOrigin))
30 | })
31 |
32 | t.Run("http_option", func(t *testing.T) {
33 | req := httptest.NewRequest(http.MethodOptions, "/ok", nil)
34 | rec := httptest.NewRecorder()
35 | y.ServeHTTP(rec, req)
36 | assert := assert.New(t)
37 | assert.Equal(http.MethodGet, rec.Header().Get(yee.HeaderAccessControlAllowMethods))
38 | assert.Equal("Test", rec.Header().Get(yee.HeaderAccessControlAllowHeaders))
39 | })
40 | }
41 |
42 | func TestEncryptServer(t *testing.T) {
43 | e := yee.New()
44 | e.Use(Cors())
45 | e.POST("/encrypt", func(c yee.Context) (err error) {
46 | u := new(user)
47 | if err := c.Bind(u); err != nil {
48 | return err
49 | }
50 | return c.JSON(http.StatusOK, u)
51 | })
52 | e.Run(":9000")
53 | }
54 |
--------------------------------------------------------------------------------
/middleware/csrf.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "crypto/subtle"
5 | "errors"
6 | "net/http"
7 | "strings"
8 | "time"
9 |
10 | "github.com/cookieY/yee"
11 | "github.com/google/uuid"
12 | )
13 |
14 | // CSRFConfig defines the config of CSRF middleware
15 | type CSRFConfig struct {
16 | TokenLength uint8
17 | TokenLookup string
18 | Key string
19 | CookieName string
20 | CookieDomain string
21 | CookiePath string
22 | CookieMaxAge int
23 | CookieSecure bool
24 | CookieHTTPOnly bool
25 | }
26 |
27 | type csrfTokenCreator func(yee.Context) (string, error)
28 |
29 | // CSRFDefaultConfig is the default config of CSRF middleware
30 | var CSRFDefaultConfig = CSRFConfig{
31 | TokenLength: 16,
32 | TokenLookup: "header:" + yee.HeaderXCSRFToken,
33 | Key: "csrf",
34 | CookieName: "_csrf",
35 | CookieMaxAge: 28800,
36 | }
37 |
38 | // CSRF is the default implementation CSRF middleware
39 | func CSRF() yee.HandlerFunc {
40 | return CSRFWithConfig(CSRFDefaultConfig)
41 | }
42 |
43 | // CSRFWithConfig is the custom implementation CSRF middleware
44 | func CSRFWithConfig(config CSRFConfig) yee.HandlerFunc {
45 |
46 | if config.TokenLength == 0 {
47 | config.TokenLength = CSRFDefaultConfig.TokenLength
48 | }
49 |
50 | if config.TokenLookup == "" {
51 | config.TokenLookup = CSRFDefaultConfig.TokenLookup
52 | }
53 |
54 | if config.Key == "" {
55 | config.Key = CSRFDefaultConfig.Key
56 | }
57 |
58 | if config.CookieName == "" {
59 | config.CookieName = CSRFDefaultConfig.CookieName
60 | }
61 |
62 | if config.CookieMaxAge == 0 {
63 | config.CookieMaxAge = CSRFDefaultConfig.CookieMaxAge
64 | }
65 |
66 | proc := strings.Split(config.TokenLookup, ":")
67 |
68 | creator := csrfTokenFromHeader(proc[1])
69 |
70 | switch proc[0] {
71 | case "query":
72 | creator = csrfTokenFromQuery(proc[1])
73 | case "form":
74 | creator = csrfTokenFromForm(proc[1])
75 | }
76 |
77 | return func(context yee.Context) (err error) {
78 |
79 | // we fetch cookie from this request
80 | // if cookie haven`t token info
81 | // we need generate the token and create a new cookie
82 | // otherwise reuse token
83 |
84 | k, err := context.Cookie(config.CookieName)
85 | token := ""
86 | if err != nil {
87 | token = strings.Replace(uuid.New().String(), "-", "", -1)
88 | } else {
89 | token = k.Value
90 | }
91 |
92 | switch context.Request().Method {
93 | case http.MethodGet, http.MethodTrace, http.MethodOptions, http.MethodHead:
94 | default:
95 | clientToken, e := creator(context)
96 |
97 | if e != nil {
98 | return context.ServerError(http.StatusBadRequest, e.Error())
99 | }
100 | if !validateCSRFToken(token, clientToken) {
101 | return context.ServerError(http.StatusForbidden, "invalid csrf token")
102 | }
103 | }
104 |
105 | nCookie := new(http.Cookie)
106 | nCookie.Name = config.CookieName
107 | nCookie.Value = token
108 | if config.CookiePath != "" {
109 | nCookie.Path = config.CookiePath
110 | }
111 | if config.CookieDomain != "" {
112 | nCookie.Domain = config.CookieDomain
113 | }
114 | nCookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)
115 | nCookie.Secure = config.CookieSecure
116 | nCookie.HttpOnly = config.CookieHTTPOnly
117 | context.SetCookie(nCookie)
118 |
119 | context.Put(config.Key, token)
120 | context.SetHeader(yee.HeaderVary, yee.HeaderCookie)
121 | context.Next()
122 | return
123 | }
124 | }
125 |
126 | func csrfTokenFromHeader(header string) csrfTokenCreator {
127 | return func(context yee.Context) (string, error) {
128 | token := context.GetHeader(header)
129 | if token == "" {
130 | return "", errors.New("missing csrf token in the header string")
131 | }
132 | return token, nil
133 | }
134 | }
135 |
136 | func csrfTokenFromQuery(param string) csrfTokenCreator {
137 | return func(context yee.Context) (string, error) {
138 | token := context.QueryParam(param)
139 | if token == "" {
140 | return "", errors.New("missing csrf token in the query string")
141 | }
142 | return token, nil
143 | }
144 | }
145 |
146 | func csrfTokenFromForm(param string) csrfTokenCreator {
147 | return func(context yee.Context) (string, error) {
148 | token := context.FormValue(param)
149 | if token == "" {
150 | return "", errors.New("missing csrf token in the form string")
151 | }
152 | return token, nil
153 | }
154 | }
// validateCSRFToken reports whether clientToken equals the server-side token.
// It compares in constant time, so timing does not reveal how much of the
// token matched. An empty server token never validates; otherwise a blank
// cookie paired with a blank client value would compare equal.
func validateCSRFToken(token, clientToken string) bool {
	if token == "" {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
}
158 |
--------------------------------------------------------------------------------
/middleware/csrf_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "net/http"
5 | "net/http/httptest"
6 | "strings"
7 | "testing"
8 |
9 | "github.com/cookieY/yee"
10 | "github.com/google/uuid"
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
14 | func TestCSRFWithConfig(t *testing.T) {
15 |
16 | y := yee.New()
17 | y.Use(CSRF())
18 | y.POST("/", func(context yee.Context) error {
19 | return context.String(http.StatusOK, "ok")
20 | })
21 | req := httptest.NewRequest(http.MethodPost, "/", nil)
22 | rec := httptest.NewRecorder()
23 |
24 | // Without CSRF cookie
25 | req = httptest.NewRequest(http.MethodPost, "/", nil)
26 | rec = httptest.NewRecorder()
27 | y.ServeHTTP(rec, req)
28 | assert.Equal(t, "missing csrf token in the header string", rec.Body.String())
29 | assert.Equal(t, http.StatusBadRequest, rec.Code)
30 |
31 | // invalid csrf token
32 | req = httptest.NewRequest(http.MethodPost, "/", nil)
33 | req.Header.Set(yee.HeaderXCSRFToken, "cbghjiwhd")
34 | rec = httptest.NewRecorder()
35 | y.ServeHTTP(rec, req)
36 | assert.Equal(t, "invalid csrf token", rec.Body.String())
37 | assert.Equal(t, http.StatusForbidden, rec.Code)
38 |
39 | token := strings.Replace(uuid.New().String(), "-", "", -1)
40 | req = httptest.NewRequest(http.MethodPost, "/", nil)
41 | req.Header.Set(yee.HeaderCookie, "_csrf="+token)
42 | req.Header.Set(yee.HeaderXCSRFToken, token)
43 | rec = httptest.NewRecorder()
44 | y.ServeHTTP(rec, req)
45 | assert.Equal(t, http.StatusOK, rec.Code)
46 | }
47 |
--------------------------------------------------------------------------------
/middleware/gzip.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "bufio"
5 | "compress/gzip"
6 | "io"
7 | "io/ioutil"
8 | "net"
9 | "net/http"
10 | "strings"
11 |
12 | "github.com/cookieY/yee"
13 | )
14 |
// GzipConfig configures the Gzip middleware.
type GzipConfig struct {
	// Level is passed to gzip.NewWriterLevel; 0 selects the default.
	Level int
}

// gzipResponseWriter routes body writes through the gzip stream (Writer)
// while headers and status still go to the wrapped http.ResponseWriter.
type gzipResponseWriter struct {
	io.Writer
	http.ResponseWriter
}

// DefaultGzipConfig uses compression level 1.
var DefaultGzipConfig = GzipConfig{
	Level: 1,
}
27 |
28 | // Gzip is the default implementation of gzip middleware
29 | func Gzip() yee.HandlerFunc {
30 | return GzipWithConfig(DefaultGzipConfig)
31 | }
32 |
33 | // GzipWithConfig is the custom implementation of gzip middleware
34 | func GzipWithConfig(config GzipConfig) yee.HandlerFunc {
35 | if config.Level == 0 {
36 | config.Level = DefaultGzipConfig.Level
37 | }
38 |
39 | return func(c yee.Context) (err error) {
40 | if c.IsWebsocket() {
41 | return
42 | }
43 | res := c.Response()
44 | res.Header().Add(yee.HeaderVary, yee.HeaderAcceptEncoding)
45 | if strings.Contains(c.Request().Header.Get(yee.HeaderAcceptEncoding), "gzip") {
46 | res.Header().Set(yee.HeaderContentEncoding, "gzip")
47 | rw := res.Writer()
48 | w, err := gzip.NewWriterLevel(rw, config.Level)
49 | if err != nil {
50 | return err
51 | }
52 | defer func() {
53 | if res.Size() < 1 {
54 | if res.Header().Get(yee.HeaderContentEncoding) == "gzip" {
55 | res.Header().Del(yee.HeaderContentEncoding)
56 | }
57 | res.Override(rw)
58 | w.Reset(ioutil.Discard)
59 | }
60 | _ = w.Close()
61 | }()
62 | grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw}
63 | res.Override(grw)
64 | }
65 | c.Next()
66 | return
67 | }
68 | }
69 |
70 | func (w *gzipResponseWriter) Write(b []byte) (int, error) {
71 | if w.Header().Get(yee.HeaderContentType) == "" {
72 | w.Header().Set(yee.HeaderContentType, http.DetectContentType(b))
73 | }
74 | return w.Writer.Write(b)
75 | }
76 |
77 | func (w *gzipResponseWriter) Flush() {
78 | w.Writer.(*gzip.Writer).Flush()
79 | if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
80 | flusher.Flush()
81 | }
82 | }
83 |
84 | func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
85 | return w.ResponseWriter.(http.Hijacker).Hijack()
86 | }
87 |
--------------------------------------------------------------------------------
/middleware/gzip_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "github.com/cookieY/yee"
5 | "github.com/stretchr/testify/assert"
6 | "net/http"
7 | "net/http/httptest"
8 | "testing"
9 | )
10 |
11 | func TestGzip(t *testing.T) {
12 | y := yee.New()
13 | y.Use(Logger(), GzipWithConfig(GzipConfig{Level: 9}))
14 | y.Static("/", "../testing/dist/assets")
15 | t.Run("http_get", func(t *testing.T) {
16 | req := httptest.NewRequest(http.MethodGet, "/js/app.d2880701.js", nil)
17 | req.Header.Add(yee.HeaderAcceptEncoding, "gzip")
18 | rec := httptest.NewRecorder()
19 | y.ServeHTTP(rec, req)
20 | assert2 := assert.New(t)
21 | assert2.Equal(http.StatusOK, rec.Code)
22 | })
23 | }
24 |
--------------------------------------------------------------------------------
/middleware/jwt.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "github.com/cookieY/yee"
7 | "github.com/golang-jwt/jwt"
8 | "net/http"
9 | "reflect"
10 | )
11 |
12 | // JwtConfig defines the config of JWT middleware
13 | type JwtConfig struct {
14 | GetKey string
15 | AuthScheme string
16 | SigningKey interface{}
17 | SigningMethod string
18 | TokenLookup []string
19 | Claims jwt.Claims
20 | keyFunc jwt.Keyfunc
21 | ErrorHandler JWTErrorHandler
22 | SuccessHandler JWTSuccessHandler
23 | }
24 |
25 | type jwtExtractor func(yee.Context) (string, error)
26 |
27 | // JWTErrorHandler defines a function which is error for a valid token.
28 | type JWTErrorHandler func(error) error
29 |
30 | // JWTSuccessHandler defines a function which is executed for a valid token.
31 | type JWTSuccessHandler func(yee.Context)
32 |
33 | const algorithmHS256 = "HS256"
34 |
35 | // DefaultJwtConfig is the default config of JWT middleware
36 | var DefaultJwtConfig = JwtConfig{
37 | GetKey: "auth",
38 | SigningMethod: algorithmHS256,
39 | AuthScheme: "Bearer",
40 | TokenLookup: []string{yee.HeaderAuthorization},
41 | Claims: jwt.MapClaims{},
42 | }
43 |
44 | // JWTWithConfig is the custom implementation CORS middleware
45 | func JWTWithConfig(config JwtConfig) yee.HandlerFunc {
46 | if config.SigningKey == nil {
47 | panic("yee: jwt middleware requires signing key")
48 | }
49 | if config.SigningMethod == "" {
50 | config.SigningMethod = DefaultJwtConfig.SigningMethod
51 | }
52 | if config.GetKey == "" {
53 | config.GetKey = DefaultJwtConfig.GetKey
54 | }
55 | if config.AuthScheme == "" {
56 | config.AuthScheme = DefaultJwtConfig.AuthScheme
57 | }
58 |
59 | if config.Claims == nil {
60 | config.Claims = DefaultJwtConfig.Claims
61 | }
62 |
63 | if config.TokenLookup == nil {
64 | config.TokenLookup = DefaultJwtConfig.TokenLookup
65 | }
66 |
67 | config.keyFunc = func(token *jwt.Token) (interface{}, error) {
68 | if token.Method.Alg() != config.SigningMethod {
69 | return nil, fmt.Errorf("unexpected jwt signing method=%v", token.Header["alg"])
70 | }
71 | return config.SigningKey, nil
72 | }
73 |
74 | extractor := jwtFromHeader(config.TokenLookup, config.AuthScheme)
75 |
76 | return func(c yee.Context) (err error) {
77 | // cause upgrade websocket will clear custom header
78 | // when header add jwt bearer that panic
79 | auth, err := extractor(c)
80 | if err != nil {
81 | return c.JSON(http.StatusBadRequest, err.Error())
82 | }
83 | token := new(jwt.Token)
84 | if _, ok := config.Claims.(jwt.MapClaims); ok {
85 | token, err = jwt.Parse(auth, config.keyFunc)
86 | if err != nil {
87 | return c.JSON(http.StatusUnauthorized, err.Error())
88 | }
89 | } else {
90 | t := reflect.ValueOf(config.Claims).Type().Elem()
91 | claims := reflect.New(t).Interface().(jwt.Claims)
92 | token, err = jwt.ParseWithClaims(auth, claims, config.keyFunc)
93 | }
94 | if err == nil && token.Valid {
95 | c.Put(config.GetKey, token)
96 | return
97 | }
98 | // bug fix
99 | // if invalid or expired jwt,
100 | // we must intercept all handlers and return serverError
101 | return c.JSON(http.StatusUnauthorized, "invalid or expired jwt")
102 | }
103 | }
104 |
105 | func jwtFromHeader(header []string, authScheme string) jwtExtractor {
106 | return func(c yee.Context) (string, error) {
107 | for _, i := range header {
108 | auth := c.Request().Header.Get(i)
109 | l := len(authScheme)
110 | if len(auth) > l+1 && auth[:l] == authScheme {
111 | return auth[l+1:], nil
112 | }
113 | if i == yee.HeaderSecWebSocketProtocol {
114 | return auth, nil
115 | }
116 | }
117 | return "", errors.New("missing or malformed jwt")
118 | }
119 | }
120 |
--------------------------------------------------------------------------------
/middleware/jwt_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "github.com/cookieY/yee"
7 | "github.com/golang-jwt/jwt"
8 | "github.com/stretchr/testify/assert"
9 | "net/http"
10 | "net/http/httptest"
11 | "testing"
12 | "time"
13 | )
14 |
15 | func GenJwtToken() (string, error) {
16 | token := jwt.New(jwt.SigningMethodHS256)
17 | claims := token.Claims.(jwt.MapClaims)
18 | claims["name"] = "henry"
19 | claims["exp"] = time.Now().Add(time.Minute * 15).Unix()
20 | t, err := token.SignedString([]byte("dbcjqheupqjsuwsm"))
21 | if err != nil {
22 | return "", errors.New("JWT Generate Failure")
23 | }
24 | return t, nil
25 | }
26 |
27 | func JwtParse(c yee.Context) (string, string) {
28 | user := c.Get("auth").(*jwt.Token)
29 | claims := user.Claims.(jwt.MapClaims)
30 | return claims["name"].(string),claims["name"].(string)
31 | }
32 |
33 | func SuperManageGroup() yee.HandlerFunc {
34 | return func(c yee.Context) (err error) {
35 | user, _ := JwtParse(c)
36 | if user == "henry" {
37 | return
38 | }
39 | return c.JSON(http.StatusForbidden, "非法越权操作!")
40 | }
41 | }
42 |
43 | func TestJwt(t *testing.T) {
44 |
45 | cases := []struct {
46 | Name string
47 | Expected int
48 | IsSign bool
49 | Expire time.Duration
50 | }{
51 | {"not_token", 400, false, 0},
52 | {"test_is_ok", 200, true, 0},
53 | //{"test_is_expire", 401, true,time.Second * 1},
54 | }
55 | for _, i := range cases {
56 | t.Run(i.Name, func(t *testing.T) {
57 | y := yee.New()
58 |
59 | y.Use(JWTWithConfig(JwtConfig{SigningKey: []byte("dbcjqheupqjsuwsm")}))
60 | y.GET("/", func(context yee.Context) error {
61 | return context.String(http.StatusOK, "is_ok")
62 | })
63 |
64 | req := httptest.NewRequest(http.MethodGet, "/", nil)
65 | rec := httptest.NewRecorder()
66 | if i.IsSign {
67 | token, _ := GenJwtToken()
68 | req.Header.Set(yee.HeaderAuthorization, fmt.Sprintf("Bearer %s", token))
69 | }
70 | time.Sleep(i.Expire)
71 | y.ServeHTTP(rec, req)
72 | assert2 := assert.New(t)
73 | assert2.Equal(i.Expected, rec.Code)
74 | })
75 | }
76 | }
77 |
78 | func TestJwtExpire(t *testing.T) {
79 | y := yee.New()
80 | y.Use(Cors(), Secure())
81 | y.GET("/", func(context yee.Context) error {
82 | return context.String(http.StatusOK, "is_ok")
83 | })
84 | r := y.Group("/api", JWTWithConfig(JwtConfig{SigningKey: []byte("dbcjqheupqjsuwsm")}))
85 | k := r.Group("/k", SuperManageGroup())
86 | k.GET("/o", func(context yee.Context) (err error) {
87 | return context.JSON(http.StatusOK, "pk")
88 | })
89 | y.Run(":9999")
90 |
91 | }
92 |
--------------------------------------------------------------------------------
/middleware/logger.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "fmt"
5 | "github.com/cookieY/yee/logger"
6 | "io"
7 | "log"
8 |
9 | "github.com/cookieY/yee"
10 | "github.com/valyala/fasttemplate"
11 | )
12 |
13 | //LoggerConfig defines config of logger middleware
14 | type (
15 | LoggerConfig struct {
16 | Format string
17 | Level uint8
18 | IsLogger bool
19 | }
20 | )
21 |
22 | // DefaultLoggerConfig is default config of logger middleware
23 | var DefaultLoggerConfig = LoggerConfig{
24 | Format: `"url":"${url}" "method":"${method}" "status":${status} "protocol":"${protocol}" "remote_ip":"${remote_ip}" "bytes_in": "${bytes_in} bytes" "bytes_out": "${bytes_out} bytes"`,
25 | Level: 3,
26 | IsLogger: true,
27 | }
28 |
29 | // Logger is default implementation of logger middleware
30 | func Logger() yee.HandlerFunc {
31 | return LoggerWithConfig(DefaultLoggerConfig)
32 | }
33 |
34 | // LoggerWithConfig is custom implementation of logger middleware
35 | func LoggerWithConfig(config LoggerConfig) yee.HandlerFunc {
36 | if config.Format == "" {
37 | config.Format = DefaultLoggerConfig.Format
38 | }
39 |
40 | if config.Level == 0 {
41 | config.Level = DefaultLoggerConfig.Level
42 | }
43 |
44 | t, err := fasttemplate.NewTemplate(config.Format, "${", "}")
45 |
46 | if err != nil {
47 | log.Fatalf("unexpected error when parsing template: %s", err)
48 | }
49 |
50 | logger := logger.LogCreator()
51 |
52 | logger.SetLevel(config.Level)
53 |
54 | logger.IsLogger(config.IsLogger)
55 |
56 | return func(context yee.Context) (err error) {
57 | context.Next()
58 | s := t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
59 | switch tag {
60 | case "url":
61 | p := context.Request().URL.Path
62 | if p == "" {
63 | p = "/"
64 | }
65 | return w.Write([]byte(p))
66 | case "method":
67 | return w.Write([]byte(context.Request().Method))
68 | case "status":
69 | return w.Write([]byte(fmt.Sprintf("%d", context.Response().Status())))
70 | case "remote_ip":
71 | return w.Write([]byte(context.RemoteIP()))
72 | case "host":
73 | return w.Write([]byte(context.Request().Host))
74 | case "protocol":
75 | return w.Write([]byte(context.Request().Proto))
76 | case "bytes_in":
77 | cl := context.Request().Header.Get(yee.HeaderContentLength)
78 | if cl == "" {
79 | cl = "0"
80 | }
81 | return w.Write([]byte(cl))
82 | case "bytes_out":
83 | return w.Write([]byte(fmt.Sprintf("%d", context.Response().Size())))
84 | default:
85 | return w.Write([]byte(""))
86 | }
87 | })
88 | if context.Response().Status() < 400 {
89 | logger.Info(s)
90 | } else {
91 | logger.Warn(s)
92 | }
93 | return
94 | }
95 | }
96 |
--------------------------------------------------------------------------------
/middleware/logger_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "github.com/cookieY/yee"
5 | "github.com/stretchr/testify/assert"
6 | "net/http"
7 | "net/http/httptest"
8 | "testing"
9 | )
10 |
11 | func TestLogger(t *testing.T) {
12 | y := yee.New()
13 | y.Use(Logger())
14 | y.GET("/", func(context yee.Context) error {
15 | context.Logger().Critical("哈哈哈哈")
16 | return context.String(http.StatusOK, "ok")
17 | })
18 | t.Run("http_get", func(t *testing.T) {
19 | req := httptest.NewRequest(http.MethodGet, "/", nil)
20 | rec := httptest.NewRecorder()
21 | y.ServeHTTP(rec, req)
22 | assert := assert.New(t)
23 | assert.Equal("ok", rec.Body.String())
24 | assert.Equal(http.StatusOK, rec.Code)
25 | })
26 | }
27 |
--------------------------------------------------------------------------------
/middleware/multiMiddle_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "net/http"
5 | "net/http/httptest"
6 | "testing"
7 |
8 | "github.com/cookieY/yee"
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
12 | func TestMultiMiddle(t *testing.T) {
13 | y := yee.New()
14 | y.Use(Cors())
15 | y.Use(JWTWithConfig(JwtConfig{SigningKey: []byte("dbcjqheupqjsuwsm")}))
16 | y.GET("/", func(context yee.Context) error {
17 | return context.String(http.StatusOK, "is_ok")
18 | })
19 |
20 | req := httptest.NewRequest(http.MethodGet, "/", nil)
21 | rec := httptest.NewRecorder()
22 |
23 | a := assert.New(t)
24 | y.ServeHTTP(rec, req)
25 | a.Equal("\"missing or malformed jwt\"\n", rec.Body.String())
26 | a.Equal(400, rec.Code)
27 | a.Equal("*", rec.Header().Get(yee.HeaderAccessControlAllowOrigin))
28 | }
29 |
30 | func TestMultiGroup(t *testing.T) {
31 | y := yee.C()
32 | r := y.Group("/", Cors(), CustomerMiddleware())
33 | r.GET("/test", func(context yee.Context) error {
34 | return context.String(http.StatusOK, "is_ok")
35 | })
36 | req := httptest.NewRequest(http.MethodGet, "/test", nil)
37 | rec := httptest.NewRecorder()
38 |
39 | a := assert.New(t)
40 | y.ServeHTTP(rec, req)
41 | a.Equal("非法越权操作!", rec.Body.String())
42 | a.Equal("*", rec.Header().Get(yee.HeaderAccessControlAllowOrigin))
43 | a.Equal(403, rec.Code)
44 | }
45 |
46 | func CustomerMiddleware() yee.HandlerFunc {
47 | return func(c yee.Context) (err error) {
48 | if c.QueryParam("test") == "y" {
49 | c.Next()
50 | return
51 | }
52 | return c.ServerError(http.StatusForbidden, "非法越权操作!")
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
/middleware/proxy.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "github.com/cookieY/yee"
7 | "net/http"
8 | "net/http/httputil"
9 | "net/url"
10 | "strings"
11 | )
12 |
type (
	// ProxyTargetHandler chooses the upstream a request is forwarded to
	// (presumably a load-balancing strategy; confirm against implementations).
	ProxyTargetHandler interface {
		MatchAlgorithm() ProxyTarget
	}

	// ProxyTarget describes one upstream server; Name is only used in error messages.
	ProxyTarget struct {
		Name string
		URL  *url.URL
	}

	// ProxyConfig configures ProxyWithConfig.
	// BalanceType is not read by the code in this file.
	ProxyConfig struct {
		BalanceType    string
		Transport      http.RoundTripper
		ModifyResponse func(*http.Response) error
		ProxyTarget    ProxyTargetHandler
	}

	// errorHandler carries a proxy failure from the ReverseProxy error callback
	// back to the middleware through the context key "_error".
	errorHandler struct {
		Err  error
		Code int
	}
)
33 |
34 | func ProxyWithConfig(config ProxyConfig) yee.HandlerFunc {
35 | return func(c yee.Context) (err error) {
36 | req := c.Request()
37 | res := c.Response()
38 | if req.Header.Get(yee.HeaderXRealIP) == "" {
39 | req.Header.Set(yee.HeaderXRealIP, c.RemoteIP())
40 | }
41 | if req.Header.Get(yee.HeaderXForwardedProto) == "" {
42 | req.Header.Set(yee.HeaderXForwardedProto, c.Scheme())
43 | }
44 | if c.IsWebsocket() && req.Header.Get(yee.HeaderXForwardedFor) == "" { // For HTTP, it is automatically set by Go HTTP reverse proxy.
45 | req.Header.Set(yee.HeaderXForwardedFor, c.RemoteIP())
46 | }
47 | switch {
48 | case c.IsWebsocket():
49 | //proxyRaw(tgt, c).ServeHTTP(res, req)
50 | case req.Header.Get(yee.HeaderAccept) == "text/event-stream":
51 | default:
52 | proxyHTTP(c, config).ServeHTTP(res, req)
53 | }
54 | if e, ok := c.Get("_error").(errorHandler); ok {
55 | return c.ServerError(e.Code, e.Err.Error())
56 | }
57 |
58 | return nil
59 | }
60 | }
61 |
62 | func proxyHTTP(c yee.Context, config ProxyConfig) http.Handler {
63 | tgt := config.ProxyTarget.MatchAlgorithm()
64 | proxy := httputil.NewSingleHostReverseProxy(tgt.URL)
65 | proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) {
66 | desc := tgt.URL.String()
67 | if tgt.Name != "" {
68 | desc = fmt.Sprintf("%s(%s)", tgt.Name, tgt.URL.String())
69 | }
70 | if err == context.Canceled || strings.Contains(err.Error(), "operation was canceled") {
71 | c.Put("_error", errorHandler{fmt.Errorf("client closed connection: %s", err.Error()), yee.StatusCodeContextCanceled})
72 | } else {
73 | c.Put("_error", errorHandler{fmt.Errorf("remote %s unreachable, could not forward: %v", desc, err), http.StatusBadGateway})
74 | }
75 | }
76 | proxy.Transport = config.Transport
77 | proxy.ModifyResponse = config.ModifyResponse
78 | return proxy
79 | }
80 |
--------------------------------------------------------------------------------
/middleware/rateLimit.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "net/http"
5 | "sync"
6 | "time"
7 |
8 | "github.com/cookieY/yee"
9 | )
10 |
11 | // RateLimitConfig defines config of rateLimit middleware
12 | type RateLimitConfig struct {
13 | Time time.Duration
14 | Rate int
15 | lock *sync.Mutex
16 | numbers int
17 | }
18 |
19 | // DefaultRateLimit is the default config of rateLimit middleware
20 | var DefaultRateLimit = RateLimitConfig{
21 | Time: 1 * time.Second,
22 | Rate: 5,
23 | }
24 |
25 | // RateLimit is the default implementation of rateLimit middleware
26 | func RateLimit() yee.HandlerFunc {
27 | return RateLimitWithConfig(DefaultRateLimit)
28 | }
29 |
30 | // RateLimitWithConfig is the custom implementation of rateLimit middleware
31 | func RateLimitWithConfig(config RateLimitConfig) yee.HandlerFunc {
32 |
33 | if config.Time == 0 {
34 | config.Time = DefaultRateLimit.Time
35 | }
36 |
37 | if config.Rate == 0 {
38 | config.Rate = DefaultRateLimit.Rate
39 | }
40 |
41 | config.lock = new(sync.Mutex)
42 |
43 | go timer(&config)
44 |
45 | return func(context yee.Context) (err error) {
46 | if config.numbers >= config.Rate {
47 | return context.ServerError(http.StatusTooManyRequests, "too many requests")
48 | }
49 | config.lock.Lock()
50 | config.numbers++
51 | defer config.lock.Unlock()
52 | return
53 | }
54 | }
55 |
56 | func timer(c *RateLimitConfig) {
57 | ticker := time.NewTicker(c.Time)
58 | for {
59 | <-ticker.C
60 | c.lock.Lock()
61 | c.numbers = 0
62 | c.lock.Unlock()
63 | }
64 | }
65 |
--------------------------------------------------------------------------------
/middleware/rateLimit_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "fmt"
5 | "github.com/cookieY/yee"
6 | "net/http"
7 | "net/http/httptest"
8 | "runtime"
9 | "sync"
10 | "testing"
11 | "time"
12 | )
13 |
14 | func TestRateLimit(t *testing.T) {
15 | runtime.GOMAXPROCS(runtime.NumCPU())
16 | y := yee.New()
17 | y.Use(RateLimitWithConfig(RateLimitConfig{Rate: 1,Time: time.Second * 2}))
18 | y.GET("/", func(context yee.Context) (err error) {
19 | return context.String(http.StatusOK, "ok")
20 | })
21 | var wg sync.WaitGroup
22 | var once sync.Once
23 | for i := 0; i < 50; i++ {
24 | if i > 25 {
25 | once.Do(func() {
26 | time.Sleep(time.Second * 2)
27 | })
28 | }
29 | wg.Add(1)
30 | go func(i int) {
31 | req := httptest.NewRequest(http.MethodGet, "/", nil)
32 | rec := httptest.NewRecorder()
33 | y.ServeHTTP(rec, req)
34 | fmt.Printf("id: %d code:%d body:%s \n",i,rec.Code,rec.Body.String())
35 | wg.Done()
36 | }(i)
37 | }
38 | wg.Wait()
39 | }
--------------------------------------------------------------------------------
/middleware/recovery.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "runtime"
7 | "strings"
8 |
9 | "github.com/cookieY/yee"
10 | )
11 |
12 | // Recovery is a recovery middleware
13 | // when the program was panic
14 | // it can recovery program and print stack info
15 | func Recovery() yee.HandlerFunc {
16 | return func(c yee.Context) (err error) {
17 | defer func() {
18 | if r := recover(); r != nil {
19 | err, ok := r.(error)
20 | if !ok {
21 | err = fmt.Errorf("%v", r)
22 | }
23 | var pcs [32]uintptr
24 | n := runtime.Callers(3, pcs[:]) // skip first 3 caller
25 |
26 | var str strings.Builder
27 | str.WriteString("Traceback:")
28 | for _, pc := range pcs[:n] {
29 | fn := runtime.FuncForPC(pc)
30 | file, line := fn.FileLine(pc)
31 | str.WriteString(fmt.Sprintf("\n\t%s:%d", file, line))
32 | }
33 | c.Logger().Critical(fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, str.String()))
34 | _ = c.ServerError(http.StatusInternalServerError, "Internal Server Error")
35 | }
36 | }()
37 | c.Next()
38 | return
39 | }
40 | }
41 |
--------------------------------------------------------------------------------
/middleware/recovery_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "github.com/cookieY/yee"
5 | "github.com/stretchr/testify/assert"
6 | "net/http"
7 | "net/http/httptest"
8 | "testing"
9 | )
10 |
11 | func TestRecovery(t *testing.T) {
12 | y := yee.New()
13 | y.Use(Recovery())
14 | y.GET("/y", func(context yee.Context) error {
15 | names := []string{"geektutu"}
16 | return context.String(http.StatusOK, names[100])
17 | })
18 |
19 | t.Run("http_get", func(t *testing.T) {
20 | req := httptest.NewRequest(http.MethodGet, "/y", nil)
21 | rec := httptest.NewRecorder()
22 | y.ServeHTTP(rec, req)
23 | assert := assert.New(t)
24 | assert.Equal("Internal Server Error", rec.Body.String())
25 | assert.Equal(http.StatusInternalServerError, rec.Code)
26 | })
27 | }
28 |
--------------------------------------------------------------------------------
/middleware/requestID.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "strings"
5 |
6 | "github.com/cookieY/yee"
7 | "github.com/google/uuid"
8 | )
9 |
10 | // RequestIDConfig defines config of requestID middleware
11 | type RequestIDConfig struct {
12 | generator func() string
13 | }
14 |
15 | // DefaultRequestIDConfig is the default config of requestID middleware
16 | var DefaultRequestIDConfig = RequestIDConfig{
17 | generator: defaultGenerator,
18 | }
19 |
20 | func defaultGenerator() string {
21 | return strings.Replace(uuid.New().String(), "-", "", -1)
22 | }
23 |
24 | // RequestID is the default implementation of requestID middleware
25 | func RequestID() yee.HandlerFunc {
26 | return RequestIDWithConfig(DefaultRequestIDConfig)
27 | }
28 |
29 | // RequestIDWithConfig is the custom implementation of requestID middleware
30 | func RequestIDWithConfig(config RequestIDConfig) yee.HandlerFunc {
31 |
32 | if config.generator == nil {
33 | config.generator = DefaultRequestIDConfig.generator
34 | }
35 | return func(context yee.Context) (err error) {
36 | req := context.Request()
37 | res := context.Response()
38 | if req.Header.Get(yee.HeaderXRequestID) == "" {
39 | res.Header().Set(yee.HeaderXRequestID, config.generator())
40 | }
41 | return
42 | }
43 | }
44 |
--------------------------------------------------------------------------------
/middleware/requestID_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "fmt"
5 | "github.com/cookieY/yee"
6 | "github.com/stretchr/testify/assert"
7 | "net/http"
8 | "net/http/httptest"
9 | "testing"
10 | )
11 |
12 | func TestRequestID(t *testing.T) {
13 | y := yee.New()
14 | y.Use(RequestID())
15 | y.GET("/", func(context yee.Context) error {
16 | return context.String(http.StatusOK, "ok")
17 | })
18 |
19 | req := httptest.NewRequest(http.MethodGet, "/", nil)
20 | rec := httptest.NewRecorder()
21 |
22 | y.ServeHTTP(rec, req)
23 | fmt.Println(rec.Header().Get(yee.HeaderXRequestID))
24 | assert.Equal(t, http.StatusOK, rec.Code)
25 | assert.NotEqual(t, "", rec.Header().Get(yee.HeaderXRequestID))
26 | }
27 |
--------------------------------------------------------------------------------
/middleware/secure.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/cookieY/yee"
7 | )
8 |
9 | type (
10 |
11 | //SecureConfig defines config of secure middleware
12 | SecureConfig struct {
13 | XSSProtection string `yaml:"xss_protection"`
14 |
15 | ContentTypeNosniff string `yaml:"content_type_nosniff"`
16 |
17 | XFrameOptions string `yaml:"x_frame_options"`
18 |
19 | HSTSMaxAge int `yaml:"hsts_max_age"`
20 |
21 | HSTSExcludeSubdomains bool `yaml:"hsts_exclude_subdomains"`
22 |
23 | ContentSecurityPolicy string `yaml:"content_security_policy"`
24 |
25 | CSPReportOnly bool `yaml:"csp_report_only"`
26 |
27 | HSTSPreloadEnabled bool `yaml:"hsts_preload_enabled"`
28 |
29 | ReferrerPolicy string `yaml:"referrer_policy"`
30 | }
31 | )
32 |
33 | // DefaultSecureConfig is default config of secure middleware
34 | var DefaultSecureConfig = SecureConfig{
35 | XSSProtection: "1; mode=block",
36 | ContentTypeNosniff: "nosniff",
37 | XFrameOptions: "SAMEORIGIN",
38 | HSTSPreloadEnabled: false,
39 | }
40 |
41 | // Secure is default implementation of secure middleware
42 | func Secure() yee.HandlerFunc {
43 | return SecureWithConfig(DefaultSecureConfig)
44 | }
45 |
46 | // SecureWithConfig is custom implementation of secure middleware
47 | func SecureWithConfig(config SecureConfig) yee.HandlerFunc {
48 | return func(c yee.Context) (err error) {
49 |
50 | if config.XSSProtection != "" {
51 | c.SetHeader(yee.HeaderXXSSProtection, config.XSSProtection)
52 | }
53 |
54 | if config.ContentTypeNosniff != "" {
55 | c.SetHeader(yee.HeaderXContentTypeOptions, config.ContentTypeNosniff)
56 | }
57 |
58 | if config.XFrameOptions != "" {
59 | c.SetHeader(yee.HeaderXFrameOptions, config.XFrameOptions)
60 | }
61 |
62 | if (c.IsTLS() || (c.GetHeader(yee.HeaderXForwardedProto) == "https")) && config.HSTSMaxAge != 0 {
63 | subdomains := ""
64 | if !config.HSTSExcludeSubdomains {
65 | subdomains = "; includeSubdomains"
66 | }
67 | if config.HSTSPreloadEnabled {
68 | subdomains = fmt.Sprintf("%s; preload", subdomains)
69 | }
70 | c.SetHeader(yee.HeaderStrictTransportSecurity, fmt.Sprintf("max-age=%d%s", config.HSTSMaxAge, subdomains))
71 | }
72 | // CSP
73 | // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy-Report-Only
74 | // https://developer.mozilla.org/en-US/docs/Mozilla/Add-ons/WebExtensions/Content_Security_Policy
75 | if config.ContentSecurityPolicy != "" {
76 | if config.CSPReportOnly {
77 | c.SetHeader(yee.HeaderContentSecurityPolicyReportOnly, config.ContentSecurityPolicy)
78 | } else {
79 | c.SetHeader(yee.HeaderContentSecurityPolicy, config.ContentSecurityPolicy)
80 | }
81 | }
82 |
83 | // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy
84 | if config.ReferrerPolicy != "" {
85 | c.SetHeader(yee.HeaderReferrerPolicy, config.ReferrerPolicy)
86 | }
87 | return
88 | }
89 | }
90 |
--------------------------------------------------------------------------------
/middleware/secure_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "github.com/cookieY/yee"
5 | "github.com/stretchr/testify/assert"
6 | "net/http"
7 | "net/http/httptest"
8 | "testing"
9 | )
10 |
11 | func TestSecure(t *testing.T) {
12 | y := yee.New()
13 | y.Use(Secure())
14 | y.GET("/", func(context yee.Context) error {
15 | return context.String(http.StatusOK, "ok")
16 | })
17 | t.Run("http_get", func(t *testing.T) {
18 | req := httptest.NewRequest(http.MethodGet, "/", nil)
19 | rec := httptest.NewRecorder()
20 | y.ServeHTTP(rec, req)
21 | assert := assert.New(t)
22 | assert.Equal("ok", rec.Body.String())
23 | assert.Equal(http.StatusOK, rec.Code)
24 | assert.Equal("SAMEORIGIN", rec.Header().Get(yee.HeaderXFrameOptions))
25 | assert.Equal("1; mode=block", rec.Header().Get(yee.HeaderXXSSProtection))
26 | assert.Equal("nosniff", rec.Header().Get(yee.HeaderXContentTypeOptions))
27 | })
28 | }
29 |
--------------------------------------------------------------------------------
/middleware/trace.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "bytes"
5 | "crypto/rand"
6 | "fmt"
7 | "github.com/cookieY/yee"
8 | "github.com/opentracing/opentracing-go"
9 | "github.com/opentracing/opentracing-go/ext"
10 | "github.com/uber/jaeger-client-go/config"
11 | "io"
12 | "io/ioutil"
13 | "net/http"
14 | "time"
15 | )
16 |
17 | const defaultComponentName = "yee"
18 |
19 | type (
20 | TraceConfig struct {
21 | // OpenTracing Tracer instance which should be got before
22 | Tracer opentracing.Tracer
23 | ComponentName string
24 | IsBodyDump bool
25 | LimitHTTPBody bool
26 | LimitSize int
27 | }
28 | responseDumper struct {
29 | http.ResponseWriter
30 |
31 | mw io.Writer
32 | buf *bytes.Buffer
33 | }
34 | )
35 |
36 | func New(e *yee.Core) io.Closer {
37 | // Add Opentracing instrumentation
38 | defcfg := config.Configuration{
39 | ServiceName: "echo-tracer",
40 | Sampler: &config.SamplerConfig{
41 | Type: "const",
42 | Param: 1,
43 | },
44 | Reporter: &config.ReporterConfig{
45 | LogSpans: true,
46 | BufferFlushInterval: 1 * time.Second,
47 | },
48 | }
49 | cfg, err := defcfg.FromEnv()
50 | if err != nil {
51 | panic("Could not parse Jaeger env vars: " + err.Error())
52 | }
53 | tracer, closer, err := cfg.NewTracer()
54 | if err != nil {
55 | panic("Could not initialize jaeger tracer: " + err.Error())
56 | }
57 |
58 | opentracing.SetGlobalTracer(tracer)
59 | e.Use(TraceWithConfig(TraceConfig{
60 | Tracer: tracer,
61 | }))
62 | return closer
63 | }
64 |
65 | func TraceWithConfig(config TraceConfig) yee.HandlerFunc {
66 | if config.Tracer == nil {
67 | panic("yee: trace middleware requires opentracing tracer")
68 | }
69 | if config.ComponentName == "" {
70 | config.ComponentName = defaultComponentName
71 | }
72 |
73 | return func(c yee.Context) error {
74 | req := c.Request()
75 | opname := "HTTP " + req.Method + " URL: " + c.Path()
76 | realIP := c.RemoteIP()
77 | requestID := getRequestID(c) // request-id generated by reverse-proxy
78 |
79 | var sp opentracing.Span
80 | var err error
81 |
82 | ctx, err := config.Tracer.Extract(
83 | opentracing.HTTPHeaders,
84 | opentracing.HTTPHeadersCarrier(req.Header),
85 | )
86 |
87 | if err != nil {
88 | sp = config.Tracer.StartSpan(opname)
89 | } else {
90 | sp = config.Tracer.StartSpan(opname, ext.RPCServerOption(ctx))
91 | }
92 | defer sp.Finish()
93 |
94 | ext.HTTPMethod.Set(sp, req.Method)
95 | ext.HTTPUrl.Set(sp, req.URL.String())
96 | ext.Component.Set(sp, config.ComponentName)
97 | sp.SetTag("client_ip", realIP)
98 | sp.SetTag("request_id", requestID)
99 |
100 | // Dump request & response body
101 | var respDumper *responseDumper
102 | if config.IsBodyDump {
103 | // request
104 | reqBody := []byte{}
105 | if c.Request().Body != nil {
106 | reqBody, _ = ioutil.ReadAll(c.Request().Body)
107 |
108 | if config.LimitHTTPBody {
109 | sp.LogKV("http.req.body", limitString(string(reqBody), config.LimitSize))
110 | } else {
111 | sp.LogKV("http.req.body", string(reqBody))
112 | }
113 | }
114 |
115 | req.Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // reset original request body
116 |
117 | // response
118 | respDumper = newResponseDumper(c)
119 | c.Response().Override(respDumper.ResponseWriter)
120 | }
121 |
122 | // setup request context - add opentracing span
123 | req = req.WithContext(opentracing.ContextWithSpan(req.Context(), sp))
124 | c.SetRequest(req)
125 |
126 | // call next middleware / controller
127 | c.Next()
128 | if err != nil {
129 | c.Logger().Error(err) // call custom registered error handler
130 | }
131 |
132 | status := c.Response().Status()
133 | ext.HTTPStatusCode.Set(sp, uint16(status))
134 |
135 | if err != nil {
136 | logError(sp, err)
137 | }
138 |
139 | // Dump response body
140 | if config.IsBodyDump {
141 | if config.LimitHTTPBody {
142 | sp.LogKV("http.resp.body", limitString(respDumper.GetResponse(), config.LimitSize))
143 | } else {
144 | sp.LogKV("http.resp.body", respDumper.GetResponse())
145 | }
146 | }
147 |
148 | return nil // error was already processed with ctx.Error(err)
149 | }
150 | }
151 |
152 | func getRequestID(ctx yee.Context) string {
153 | requestID := ctx.Request().Header.Get(yee.HeaderXRequestID) // request-id generated by reverse-proxy
154 | if requestID == "" {
155 | requestID = generateToken() // missed request-id from proxy, we generate it manually
156 | }
157 | return requestID
158 | }
159 |
// generateToken returns 16 bytes from crypto/rand encoded as 32 lowercase hex digits.
func generateToken() string {
	var raw [16]byte
	// The error is deliberately ignored: crypto/rand.Read does not fail on
	// supported platforms (Go 1.24+ documents that it never returns an error).
	_, _ = rand.Read(raw[:])
	return fmt.Sprintf("%x", raw[:])
}
165 |
// limitString shortens str to at most size bytes of content.
//
// It keeps the head and the tail of str around a "---- skipped ----" marker
// and returns str unchanged when it already fits.
//
// Fixes over the previous version:
//   - a negative size is treated as 0 instead of panicking;
//   - odd sizes keep exactly size bytes instead of size-1;
//   - cut points are moved so no multi-byte UTF-8 sequence is split, which
//     avoids emitting invalid UTF-8 into span logs.
func limitString(str string, size int) string {
	if size < 0 {
		size = 0
	}
	if len(str) <= size {
		return str
	}

	head := size / 2                  // bytes kept from the start
	tail := len(str) - (size - head)  // index where the kept suffix begins

	// 0b10xxxxxx bytes are UTF-8 continuation bytes; never cut right before one.
	for head > 0 && str[head]&0xC0 == 0x80 {
		head--
	}
	for tail < len(str) && str[tail]&0xC0 == 0x80 {
		tail++
	}

	return str[:head] + "\n---- skipped ----\n" + str[tail:]
}
173 |
174 | func newResponseDumper(resp yee.Context) *responseDumper {
175 | buf := new(bytes.Buffer)
176 | return &responseDumper{
177 | ResponseWriter: resp.Response().Writer(),
178 | mw: io.MultiWriter(resp.Response().Writer(), buf),
179 | buf: buf,
180 | }
181 | }
182 |
183 | func (d *responseDumper) Write(b []byte) (int, error) {
184 | return d.mw.Write(b)
185 | }
186 |
187 | func (d *responseDumper) GetResponse() string {
188 | return d.buf.String()
189 | }
190 |
191 | func logError(span opentracing.Span, err error) {
192 | span.LogKV("error.message", err.Error())
193 | span.SetTag("error", true)
194 | }
195 |
--------------------------------------------------------------------------------
/pprof.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | import (
4 | "net/http"
5 | "net/http/pprof"
6 | )
7 |
// DefaultPrefix is the URL path under which the pprof endpoints are mounted.
const DefaultPrefix = "/debug/pprof"

// getPrefix returns the route for a pprof endpoint:
//   - "" or "/" yields DefaultPrefix itself;
//   - anything else is appended to it ("/heap" -> "/debug/pprof/heap").
//
// The previous implementation prepended "/debug" instead of DefaultPrefix, so
// the profiles were served at /debug/heap, /debug/profile, and so on, rather
// than the /debug/pprof/... paths that `go tool pprof` and the net/http/pprof
// docs expect.
func getPrefix(prefixOptions string) string {
	if len(prefixOptions) <= 1 {
		return DefaultPrefix
	}
	return DefaultPrefix + prefixOptions
}
17 |
18 | func WrapF(f http.HandlerFunc) HandlerFunc {
19 | return func(c Context) (err error) {
20 | f(c.Response(), c.Request())
21 | return nil
22 | }
23 | }
24 |
25 | func WrapH(h http.Handler) HandlerFunc {
26 | return func(c Context) (err error) {
27 | h.ServeHTTP(c.Response(), c.Request())
28 | return nil
29 | }
30 | }
31 |
32 | func (c *Core) Pprof() {
33 | c.GET(getPrefix("/"), WrapF(pprof.Index))
34 | c.GET(getPrefix("/cmdline"), WrapF(pprof.Cmdline))
35 | c.GET(getPrefix("/profile"), WrapF(pprof.Profile))
36 | c.POST(getPrefix("/symbol"), WrapF(pprof.Symbol))
37 | c.GET(getPrefix("/symbol"), WrapF(pprof.Symbol))
38 | c.GET(getPrefix("/trace"), WrapF(pprof.Trace))
39 | c.GET(getPrefix("/allocs"), WrapH(pprof.Handler("allocs")))
40 | c.GET(getPrefix("/block"), WrapH(pprof.Handler("block")))
41 | c.GET(getPrefix("/goroutine"), WrapH(pprof.Handler("goroutine")))
42 | c.GET(getPrefix("/heap"), WrapH(pprof.Handler("heap")))
43 | c.GET(getPrefix("/mutex"), WrapH(pprof.Handler("mutex")))
44 | c.GET(getPrefix("/threadcreate"), WrapH(pprof.Handler("threadcreate")))
45 | }
46 |
--------------------------------------------------------------------------------
/pprof_test.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | import "testing"
4 |
5 | func TestCore_Pprof(t *testing.T) {
6 | c := New()
7 | c.Pprof()
8 | c.Run(":9999")
9 | }
10 |
--------------------------------------------------------------------------------
/response_overide.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | // Copyright 2014 Manu Martinez-Almeida. All rights reserved.
4 | // Use of this source code is governed by a MIT style
5 | // license that can be found in the LICENSE file.
6 | import (
7 | "bufio"
8 | "fmt"
9 | "io"
10 | "net"
11 | "net/http"
12 | )
13 |
14 | const (
15 | noWritten = -1
16 | defaultStatus = http.StatusOK
17 | )
18 |
// ResponseWriter is the response type handed to yee handlers. It extends
// http.ResponseWriter with hijacking, flushing and close notification, and
// exposes the status code and body size recorded for the response.
type ResponseWriter interface {
	http.ResponseWriter
	http.Hijacker
	http.Flusher
	http.CloseNotifier

	// Status reports the HTTP status code recorded for the current response.
	Status() int

	// Size reports how many bytes have been written to the response body.
	// See Written.
	Size() int

	// WriteString writes the string to the response body.
	WriteString(string) (int, error)

	// Written reports whether the response has already been written.
	Written() bool

	// WriteHeaderNow forces the status code and headers to be sent now.
	WriteHeaderNow()

	// Pusher returns the http.Pusher used for HTTP/2 server push.
	Pusher() http.Pusher

	// Writer returns the wrapped http.ResponseWriter.
	Writer() http.ResponseWriter

	// Override replaces the wrapped http.ResponseWriter with rw.
	Override(rw http.ResponseWriter)
}
49 |
// responseWriter is the default ResponseWriter. It wraps an
// http.ResponseWriter and records the status code and body size.
type responseWriter struct {
	http.ResponseWriter // writer that actually sends data to the client

	size   int // body bytes written; noWritten until the header is sent
	status int // status code recorded for the response
}
55 |
56 | var _ ResponseWriter = &responseWriter{}
57 |
58 | func (w *responseWriter) Writer() http.ResponseWriter {
59 | return w.ResponseWriter
60 | }
61 |
62 | func (w *responseWriter) Override(rw http.ResponseWriter) {
63 | w.ResponseWriter = rw
64 | }
65 |
66 | func (w *responseWriter) reset(writer http.ResponseWriter) {
67 | w.ResponseWriter = writer
68 | w.size = noWritten
69 | w.status = defaultStatus
70 | }
71 |
72 | func (w *responseWriter) WriteHeader(code int) {
73 | if code > 0 && w.status != code {
74 | if w.Written() {
75 | fmt.Printf("[WARNING] Headers were already written. Wanted to override status code %d with %d", w.status, code)
76 | }
77 | w.status = code
78 | }
79 | }
80 |
81 | func (w *responseWriter) WriteHeaderNow() {
82 | if !w.Written() {
83 | w.size = 0
84 | w.ResponseWriter.WriteHeader(w.status)
85 | }
86 | }
87 |
88 | func (w *responseWriter) Write(data []byte) (n int, err error) {
89 | w.WriteHeaderNow()
90 | n, err = w.ResponseWriter.Write(data)
91 | w.size += n
92 | return
93 | }
94 |
95 | func (w *responseWriter) WriteString(s string) (n int, err error) {
96 | w.WriteHeaderNow()
97 | n, err = io.WriteString(w.ResponseWriter, s)
98 | w.size += n
99 | return
100 | }
101 |
102 | func (w *responseWriter) Status() int {
103 | return w.status
104 | }
105 |
106 | func (w *responseWriter) Size() int {
107 | return w.size
108 | }
109 |
110 | func (w *responseWriter) Written() bool {
111 | return w.size != noWritten
112 | }
113 |
114 | // Hijack implements the http.Hijacker interface.
115 | func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
116 | if w.size < 0 {
117 | w.size = 0
118 | }
119 | return w.ResponseWriter.(http.Hijacker).Hijack()
120 | }
121 |
122 | // CloseNotify implements the http.CloseNotify interface.
123 | func (w *responseWriter) CloseNotify() <-chan bool {
124 | return w.ResponseWriter.(http.CloseNotifier).CloseNotify()
125 | }
126 |
127 | // Flush implements the http.Flush interface.
128 | func (w *responseWriter) Flush() {
129 | w.WriteHeaderNow()
130 | w.ResponseWriter.(http.Flusher).Flush()
131 | }
132 |
133 | func (w *responseWriter) Pusher() (pusher http.Pusher) {
134 | if pusher, ok := w.ResponseWriter.(http.Pusher); ok {
135 | return pusher
136 | }
137 | return nil
138 | }
139 |
--------------------------------------------------------------------------------
/router.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | import (
4 | "embed"
5 | "io/fs"
6 | "net/http"
7 | "path"
8 | "strings"
9 | )
10 |
11 | type Router struct {
12 | handlers []HandlerFunc
13 | core *Core
14 | root bool
15 | basePath string
16 | }
17 |
18 | // RestfulAPI is the default implementation of restfulApi interface
19 | type RestfulAPI struct {
20 | Get HandlerFunc
21 | Post HandlerFunc
22 | Delete HandlerFunc
23 | Put HandlerFunc
24 | }
25 |
// The methods below register a handler chain for one HTTP method on the
// given path. GET, POST, PUT, DELETE, HEAD, OPTIONS and TRACE are defined in
// RFC 7231 section 4.3; PATCH is defined in RFC 5789.
29 |
30 | func (r *Router) GET(path string, handler ...HandlerFunc) {
31 | r.handle(http.MethodGet, path, handler)
32 | }
33 |
34 | func (r *Router) POST(path string, handler ...HandlerFunc) {
35 | r.handle(http.MethodPost, path, handler)
36 | }
37 |
38 | func (r *Router) PUT(path string, handler ...HandlerFunc) {
39 | r.handle(http.MethodPut, path, handler)
40 | }
41 |
42 | func (r *Router) DELETE(path string, handler ...HandlerFunc) {
43 | r.handle(http.MethodDelete, path, handler)
44 | }
45 |
46 | func (r *Router) PATCH(path string, handler ...HandlerFunc) {
47 | r.handle(http.MethodPatch, path, handler)
48 | }
49 |
50 | func (r *Router) HEAD(path string, handler ...HandlerFunc) {
51 | r.handle(http.MethodHead, path, handler)
52 | }
53 |
54 | func (r *Router) TRACE(path string, handler ...HandlerFunc) {
55 | r.handle(http.MethodTrace, path, handler)
56 | }
57 |
58 | func (r *Router) OPTIONS(path string, handler ...HandlerFunc) {
59 | r.handle(http.MethodOptions, path, handler)
60 | }
61 |
62 | func (r *Router) Restful(path string, api RestfulAPI) {
63 |
64 | if api.Get != nil {
65 | r.handle(http.MethodGet, path, HandlersChain{api.Get})
66 | }
67 | if api.Post != nil {
68 | r.handle(http.MethodPost, path, HandlersChain{api.Post})
69 | }
70 | if api.Put != nil {
71 | r.handle(http.MethodPut, path, HandlersChain{api.Put})
72 | }
73 | if api.Delete != nil {
74 | r.handle(http.MethodDelete, path, HandlersChain{api.Delete})
75 | }
76 | }
77 |
78 | func (r *Router) Any(path string, handler ...HandlerFunc) {
79 | r.handle(http.MethodPost, path, handler)
80 | r.handle(http.MethodGet, path, handler)
81 | r.handle(http.MethodPut, path, handler)
82 | r.handle(http.MethodDelete, path, handler)
83 | r.handle(http.MethodOptions, path, handler)
84 | }
85 |
86 | func (r *Router) Use(middleware ...HandlerFunc) {
87 | r.handlers = append(r.handlers, middleware...)
88 | }
89 |
90 | func (r *Router) Group(prefix string, handlers ...HandlerFunc) *Router {
91 | rx := &Router{
92 | handlers: r.combineHandlers(handlers),
93 | core: r.core,
94 | basePath: r.calculateAbsolutePath(prefix),
95 | }
96 | return rx
97 | }
98 |
99 | func (r *Router) handle(method, path string, handlers HandlersChain) {
100 | absolutePath := r.calculateAbsolutePath(path)
101 | handlers = r.combineHandlers(handlers)
102 | r.core.addRoute(method, absolutePath, handlers)
103 | }
104 |
105 | func (c *Core) addRoute(method, prefix string, handlers HandlersChain) {
106 | if prefix[0] != '/' {
107 | panic("path must begin with '/'")
108 | }
109 |
110 | if method == "" {
111 | panic("HTTP method can not be empty")
112 | }
113 |
114 | root := c.trees.get(method)
115 | if root == nil {
116 | root = new(node)
117 | root.fullPath = "/"
118 | c.trees = append(c.trees, methodTree{method: method, root: root})
119 | }
120 | root.addRoute(prefix, handlers)
121 |
122 | // Update maxParams
123 | if paramsCount := countParams(prefix); paramsCount > c.maxParams {
124 | c.maxParams = paramsCount
125 | }
126 |
127 | }
128 |
129 | func (r *Router) Static(relativePath, root string) {
130 | if strings.Contains(relativePath, ":") || strings.Contains(relativePath, "*") {
131 | panic("URL path cannot be used when serving a static folder")
132 | }
133 | handler := r.createDistHandler(relativePath, Dir(root, false))
134 | url := path.Join(relativePath, "/*filepath")
135 | r.GET(url, handler)
136 | r.HEAD(url, handler)
137 | }
138 |
139 | func (r *Router) Pack(relativePath string, f embed.FS, root string) {
140 | if strings.Contains(relativePath, ":") || strings.Contains(relativePath, "*") {
141 | panic("URL path cannot be used when serving a static folder")
142 | }
143 | fsys, err := fs.Sub(f, root)
144 | if err != nil {
145 | panic(err)
146 | }
147 | handler := r.createDistHandler(relativePath, http.FS(fsys))
148 | url := path.Join(relativePath, "/*filepath")
149 | r.GET(url, handler)
150 | r.HEAD(url, handler)
151 | }
152 |
// createDistHandler builds the file-serving handler shared by Static and
// Pack. Requests are passed to http.FileServer(fs) after the router's
// absolute relativePath is stripped from the URL path.
//
// NOTE(review): the parameter name fs shadows the io/fs package inside this
// function; the type assertion below is on the parameter, not the package.
func (r *Router) createDistHandler(relativePath string, fs http.FileSystem) HandlerFunc {
	absolutePath := r.calculateAbsolutePath(relativePath)
	fileServer := http.StripPrefix(absolutePath, http.FileServer(fs))
	return func(c Context) (err error) {
		// For a listing-disabled filesystem, record 404 up front; presumably
		// the file server replaces it when it actually serves a file.
		if _, noListing := fs.(*onlyFilesFS); noListing {
			c.Response().WriteHeader(http.StatusNotFound)
		}
		// "filepath" is the catch-all parameter registered by Static/Pack.
		file := c.Params("filepath")
		f, err2 := fs.Open(file)
		if err2 != nil {
			c.Status(http.StatusNotFound)
			// NOTE(review): execution continues to fileServer.ServeHTTP after
			// Reset; what Reset clears is not visible here — confirm the
			// request and response are still valid at that point.
			c.Reset()
		}
		if f != nil {
			// The probe handle is not needed; http.FileServer opens the file itself.
			_ = f.Close()
		}
		fileServer.ServeHTTP(c.Response(), c.Request())
		// err is never assigned, so this handler always returns nil.
		return
	}
}
173 |
174 | func (r *Router) calculateAbsolutePath(relativePath string) string {
175 | return joinPaths(r.basePath, relativePath)
176 | }
177 |
178 | func (r *Router) combineHandlers(handlers HandlersChain) HandlersChain {
179 | finalSize := len(r.handlers) + len(handlers)
180 | if finalSize >= crashIndex {
181 | panic("too many handlers")
182 | }
183 | mergedHandlers := make(HandlersChain, finalSize)
184 | copy(mergedHandlers, r.handlers)
185 | copy(mergedHandlers[len(r.handlers):], handlers)
186 | return mergedHandlers
187 | }
188 |
--------------------------------------------------------------------------------
/router_test.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 |
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
12 | func TestRouterParam(t *testing.T) {
13 |
14 | y := New()
15 | y.GET("/hello/k/:b", func(c Context) error {
16 | return c.String(http.StatusOK, c.Params("b"))
17 | })
18 | req := httptest.NewRequest(http.MethodGet, "/hello/k/yee", nil)
19 | rec := httptest.NewRecorder()
20 | y.ServeHTTP(rec, req)
21 | assert.Equal(t, "yee", rec.Body.String())
22 | assert.Equal(t, http.StatusOK, rec.Code)
23 | }
24 |
25 | func TestRouterParams(t *testing.T) {
26 |
27 | y := New()
28 | y.GET("/hello/k/:b/p/:j", func(c Context) error {
29 | return c.String(http.StatusOK, fmt.Sprintf("%s-%s", c.Params("b"), c.Params("j")))
30 | })
31 | req := httptest.NewRequest(http.MethodGet, "/hello/k/yee/p/henry", nil)
32 | rec := httptest.NewRecorder()
33 | y.ServeHTTP(rec, req)
34 | assert.Equal(t, "yee-henry", rec.Body.String())
35 | assert.Equal(t, http.StatusOK, rec.Code)
36 | }
37 |
38 | func TestRouterStaticPath(t *testing.T) {
39 |
40 | y := New()
41 | y.GET("/hello/k/*assets", func(c Context) error {
42 | return c.String(http.StatusOK, c.Params("assets"))
43 | })
44 | req := httptest.NewRequest(http.MethodGet, "/hello/k/assets/1.js", nil)
45 | rec := httptest.NewRecorder()
46 | y.ServeHTTP(rec, req)
47 | assert.Equal(t, "/assets/1.js", rec.Body.String())
48 | assert.Equal(t, http.StatusOK, rec.Code)
49 | }
50 |
51 | func TestRouterMultiRoute(t *testing.T) {
52 |
53 | y := New()
54 | y.GET("/hello/k/*k", func(c Context) error {
55 | return c.String(http.StatusOK, c.Params("k"))
56 | })
57 | y.GET("/hello/k", func(c Context) error {
58 | return c.String(http.StatusOK, "is_ok")
59 | })
60 | req := httptest.NewRequest(http.MethodGet, "/hello/k/yee/version/route", nil)
61 | rec := httptest.NewRecorder()
62 | y.ServeHTTP(rec, req)
63 | assert.Equal(t, "/yee/version/route", rec.Body.String())
64 | assert.Equal(t, http.StatusOK, rec.Code)
65 |
66 | req = httptest.NewRequest(http.MethodGet, "/hello/k", nil)
67 | rec = httptest.NewRecorder()
68 | y.ServeHTTP(rec, req)
69 | assert.Equal(t, "is_ok", rec.Body.String())
70 | assert.Equal(t, http.StatusOK, rec.Code)
71 |
72 | }
73 |
74 | func TestRouterQueryParam(t *testing.T) {
75 |
76 | y := New()
77 | y.GET("/hello/query", func(c Context) error {
78 | return c.String(http.StatusOK, c.QueryParam("query"))
79 | })
80 | req := httptest.NewRequest(http.MethodGet, "/hello/query?query=henry", nil)
81 | rec := httptest.NewRecorder()
82 | y.ServeHTTP(rec, req)
83 | assert.Equal(t, "henry", rec.Body.String())
84 | assert.Equal(t, http.StatusOK, rec.Code)
85 |
86 | }
87 |
// testCase pairs a request URI with the response body it should produce.
type testCase struct {
	uri    string // request target, including any query string
	expect string // expected response body
}
92 |
93 | func TestRouter_Static(t *testing.T) {
94 | y := New()
95 | y.Static("/front","color")
96 | req := httptest.NewRequest(http.MethodGet, "/front/color.go", nil)
97 | rec := httptest.NewRecorder()
98 | y.ServeHTTP(rec, req)
99 | fmt.Println(rec.Body.String())
100 | }
101 |
102 | func TestRouterMixin(t *testing.T) {
103 | y := New()
104 | y.GET("/pay", func(c Context) error {
105 | return c.String(http.StatusOK, "pay")
106 | })
107 | y.GET("/pay/add", func(c Context) error {
108 | return c.String(http.StatusOK, c.QueryParam("person"))
109 | })
110 | y.GET("/pay/add/:id", func(c Context) error {
111 | return c.String(http.StatusOK, c.Params("id"))
112 | })
113 | y.GET("/pay/add/:id/:store", func(c Context) error {
114 | return c.String(http.StatusOK, c.Params("id")+c.Params("store"))
115 | })
116 | y.GET("/pay/dew", func(c Context) error {
117 | return c.String(http.StatusOK, "dew")
118 | })
119 | y.GET("/pay/dew/*account", func(c Context) error {
120 | return c.String(http.StatusOK, c.Params("account"))
121 | })
122 |
123 | c := []testCase{
124 | {
125 | uri: "/pay",
126 | expect: "pay",
127 | },
128 | {
129 | uri: "/pay/add?person=henry",
130 | expect: "henry",
131 | },
132 | {
133 | uri: "/pay/add/1",
134 | expect: "1",
135 | },
136 | {
137 | uri: "/pay/add/1/a",
138 | expect: "1a",
139 | },
140 | {
141 | uri: "/pay/dew",
142 | expect: "dew",
143 | },
144 | {
145 | uri: "/pay/dew/account/css/1.css",
146 | expect: "/account/css/1.css",
147 | },
148 | }
149 |
150 | for _, i := range c {
151 | req := httptest.NewRequest(http.MethodGet, i.uri, nil)
152 | rec := httptest.NewRecorder()
153 | y.ServeHTTP(rec, req)
154 | assert.Equal(t, i.expect, rec.Body.String())
155 | assert.Equal(t, http.StatusOK, rec.Code)
156 | }
157 | }
158 |
// To measure routing performance, run the benchmarks in benchmark_test.go.

// --- tests for per-method routing (Restful)
162 |
163 | func testRestfulAPI() RestfulAPI {
164 |
165 | var api RestfulAPI
166 |
167 | api.Get = func(c Context) (err error) {
168 | return c.String(http.StatusOK, "get")
169 | }
170 |
171 | api.Post = func(c Context) (err error) {
172 | return c.String(http.StatusOK, "post")
173 | }
174 |
175 | api.Delete = func(c Context) (err error) {
176 | return c.String(http.StatusOK, "delete")
177 | }
178 |
179 | api.Put = func(c Context) (err error) {
180 | return c.String(http.StatusOK, "put")
181 | }
182 |
183 | return api
184 | }
185 |
186 | func userUpdate(c Context) (err error) {
187 | return c.String(http.StatusOK, "updated")
188 | }
189 |
190 | func userFetch(c Context) (err error) {
191 | return c.String(http.StatusOK, "get it")
192 | }
193 |
194 | func test2RestfulAPI() RestfulAPI {
195 | return RestfulAPI{
196 | Get: userFetch,
197 | Post: userUpdate,
198 | }
199 | }
200 |
201 | func TestAnyMethod(t *testing.T) {
202 |
203 | y := New()
204 |
205 | y.Restful("/", testRestfulAPI())
206 | y.Restful("/user", test2RestfulAPI())
207 |
208 | req := httptest.NewRequest(http.MethodGet, "/", nil)
209 | rec := httptest.NewRecorder()
210 | y.ServeHTTP(rec, req)
211 | assert.Equal(t, "get", rec.Body.String())
212 | assert.Equal(t, http.StatusOK, rec.Code)
213 |
214 | req = httptest.NewRequest(http.MethodPost, "/", nil)
215 | rec = httptest.NewRecorder()
216 | y.ServeHTTP(rec, req)
217 | assert.Equal(t, "post", rec.Body.String())
218 | assert.Equal(t, http.StatusOK, rec.Code)
219 |
220 | req = httptest.NewRequest(http.MethodPut, "/", nil)
221 | rec = httptest.NewRecorder()
222 | y.ServeHTTP(rec, req)
223 | assert.Equal(t, "put", rec.Body.String())
224 | assert.Equal(t, http.StatusOK, rec.Code)
225 |
226 | req = httptest.NewRequest(http.MethodDelete, "/", nil)
227 | rec = httptest.NewRecorder()
228 | y.ServeHTTP(rec, req)
229 | assert.Equal(t, "delete", rec.Body.String())
230 | assert.Equal(t, http.StatusOK, rec.Code)
231 |
232 | req = httptest.NewRequest(http.MethodGet, "/user", nil)
233 | rec = httptest.NewRecorder()
234 | y.ServeHTTP(rec, req)
235 | assert.Equal(t, "get it", rec.Body.String())
236 | assert.Equal(t, http.StatusOK, rec.Code)
237 |
238 | req = httptest.NewRequest(http.MethodPost, "/user", nil)
239 | rec = httptest.NewRecorder()
240 | y.ServeHTTP(rec, req)
241 | assert.Equal(t, "updated", rec.Body.String())
242 | assert.Equal(t, http.StatusOK, rec.Code)
243 |
244 | req = httptest.NewRequest(http.MethodPut, "/user", nil)
245 | rec = httptest.NewRecorder()
246 | y.ServeHTTP(rec, req)
247 | assert.Equal(t, http.StatusNotFound, rec.Code)
248 |
249 | req = httptest.NewRequest(http.MethodDelete, "/user", nil)
250 | rec = httptest.NewRecorder()
251 | y.ServeHTTP(rec, req)
252 | assert.Equal(t, http.StatusNotFound, rec.Code)
253 | }
254 |
--------------------------------------------------------------------------------
/test/host.pb.go:
--------------------------------------------------------------------------------
1 | // Code generated by protoc-gen-go. DO NOT EDIT.
2 | // versions:
3 | // protoc-gen-go v1.26.0
4 | // protoc v3.12.3
5 | // source: host.proto
6 |
7 | package test
8 |
9 | import (
10 | protoreflect "google.golang.org/protobuf/reflect/protoreflect"
11 | protoimpl "google.golang.org/protobuf/runtime/protoimpl"
12 | reflect "reflect"
13 | sync "sync"
14 | )
15 |
// NOTE: this file is generated by protoc-gen-go from host.proto; regenerate
// it instead of editing by hand.

const (
	// Verify that this generated code is sufficiently up-to-date.
	_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
	// Verify that runtime/protoimpl is sufficiently up-to-date.
	_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)

// Svr is the Go type for the host.Svr message declared in host.proto.
// The first three (unexported) fields are managed through protoimpl; the
// exported string fields map to message fields 1-5.
type Svr struct {
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	Name    string `protobuf:"bytes,1,opt,name=Name,proto3" json:"Name,omitempty"`
	IP      string `protobuf:"bytes,2,opt,name=IP,proto3" json:"IP,omitempty"`
	Project string `protobuf:"bytes,3,opt,name=Project,proto3" json:"Project,omitempty"`
	IDC     string `protobuf:"bytes,4,opt,name=IDC,proto3" json:"IDC,omitempty"`
	Cloud   string `protobuf:"bytes,5,opt,name=Cloud,proto3" json:"Cloud,omitempty"`
}

// Reset overwrites x with the zero Svr and, when protoimpl.UnsafeEnabled,
// stores the message info for it again.
func (x *Svr) Reset() {
	*x = Svr{}
	if protoimpl.UnsafeEnabled {
		mi := &file_host_proto_msgTypes[0]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

// String returns the text form produced by protoimpl.X.MessageStringOf.
func (x *Svr) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage marks Svr as a protobuf message.
func (*Svr) ProtoMessage() {}

// ProtoReflect returns the reflective view of x, lazily storing the message
// info on first use when unsafe access is enabled.
func (x *Svr) ProtoReflect() protoreflect.Message {
	mi := &file_host_proto_msgTypes[0]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use Svr.ProtoReflect.Descriptor instead.
func (*Svr) Descriptor() ([]byte, []int) {
	return file_host_proto_rawDescGZIP(), []int{0}
}

// GetName returns x.Name, or "" when x is nil.
func (x *Svr) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

// GetIP returns x.IP, or "" when x is nil.
func (x *Svr) GetIP() string {
	if x != nil {
		return x.IP
	}
	return ""
}

// GetProject returns x.Project, or "" when x is nil.
func (x *Svr) GetProject() string {
	if x != nil {
		return x.Project
	}
	return ""
}

// GetIDC returns x.IDC, or "" when x is nil.
func (x *Svr) GetIDC() string {
	if x != nil {
		return x.IDC
	}
	return ""
}

// GetCloud returns x.Cloud, or "" when x is nil.
func (x *Svr) GetCloud() string {
	if x != nil {
		return x.Cloud
	}
	return ""
}
101 |
// File_host_proto is the descriptor for host.proto; it is set by
// file_host_proto_init.
var File_host_proto protoreflect.FileDescriptor

// file_host_proto_rawDesc is the serialized file descriptor for host.proto.
var file_host_proto_rawDesc = []byte{
	0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x68, 0x6f,
	0x73, 0x74, 0x22, 0x6b, 0x0a, 0x03, 0x53, 0x76, 0x72, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d,
	0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x0e, 0x0a,
	0x02, 0x49, 0x50, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x18, 0x0a,
	0x07, 0x50, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07,
	0x50, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x49, 0x44, 0x43, 0x18, 0x04,
	0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x49, 0x44, 0x43, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x6f,
	0x75, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x6f, 0x75, 0x64, 0x42,
	0x04, 0x5a, 0x02, 0x2e, 0x2f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}

// rawDescData starts as the raw descriptor and is replaced by its gzip form
// the first time file_host_proto_rawDescGZIP runs.
var (
	file_host_proto_rawDescOnce sync.Once
	file_host_proto_rawDescData = file_host_proto_rawDesc
)

// file_host_proto_rawDescGZIP returns the gzip-compressed descriptor,
// compressing it at most once (guarded by sync.Once).
func file_host_proto_rawDescGZIP() []byte {
	file_host_proto_rawDescOnce.Do(func() {
		file_host_proto_rawDescData = protoimpl.X.CompressGZIP(file_host_proto_rawDescData)
	})
	return file_host_proto_rawDescData
}

// Message type info, Go types and dependency indexes consumed by
// file_host_proto_init.
var file_host_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_host_proto_goTypes = []interface{}{
	(*Svr)(nil), // 0: host.Svr
}
var file_host_proto_depIdxs = []int32{
	0, // [0:0] is the sub-list for method output_type
	0, // [0:0] is the sub-list for method input_type
	0, // [0:0] is the sub-list for extension type_name
	0, // [0:0] is the sub-list for extension extendee
	0, // [0:0] is the sub-list for field type_name
}
139 |
// init builds the host.proto descriptor when the package is loaded.
func init() { file_host_proto_init() }

// file_host_proto_init builds File_host_proto and the Svr message info with
// protoimpl.TypeBuilder. It returns immediately if File_host_proto is already
// set, so repeated calls do nothing.
func file_host_proto_init() {
	if File_host_proto != nil {
		return
	}
	// Without unsafe support, the runtime reaches Svr's internal fields
	// through this exporter: index 0 is state, 1 sizeCache, 2 unknownFields.
	if !protoimpl.UnsafeEnabled {
		file_host_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*Svr); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
	}
	// x is an empty local type used only to obtain this package's import path.
	type x struct{}
	out := protoimpl.TypeBuilder{
		File: protoimpl.DescBuilder{
			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
			RawDescriptor: file_host_proto_rawDesc,
			NumEnums:      0,
			NumMessages:   1,
			NumExtensions: 0,
			NumServices:   0,
		},
		GoTypes:           file_host_proto_goTypes,
		DependencyIndexes: file_host_proto_depIdxs,
		MessageInfos:      file_host_proto_msgTypes,
	}.Build()
	File_host_proto = out.File
	// The build inputs are dropped once the descriptor exists.
	file_host_proto_rawDesc = nil
	file_host_proto_goTypes = nil
	file_host_proto_depIdxs = nil
}
178 |
--------------------------------------------------------------------------------
/test/host.proto:
--------------------------------------------------------------------------------
// host.proto defines the Svr message; Go code is generated into host.pb.go.
syntax = "proto3";
package host;
// NOTE(review): "./" is a relative go_package; protoc-gen-go normally expects
// a full Go import path here — confirm the intended package path.
option go_package = "./";

// Svr is a server record with five string fields.
message Svr {
  string Name = 1;
  string IP = 2;
  string Project = 3;
  string IDC = 4; // NOTE(review): presumably a data-center identifier — confirm
  string Cloud = 5;
}
--------------------------------------------------------------------------------
/tree.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | import (
4 | "bytes"
5 | "net/url"
6 | "strings"
7 | "unicode"
8 | "unicode/utf8"
9 | )
10 |
// Byte forms of the wildcard markers; countParams counts their occurrences.
var strColon = []byte(":")
var strStar = []byte("*")
15 |
// Param is one URL parameter: Key names the parameter and Value holds the
// value taken from the request path.
type Param struct {
	Key   string
	Value string
}

// Params is the ordered list of parameters returned by the router. The
// first URL parameter is the first element, so reading by index is safe.
type Params []Param
26 |
27 | // Get returns the value of the first Param which key matches the given name.
28 | // If no matching Param is found, an empty string is returned.
29 | func (ps Params) Get(name string) (string, bool) {
30 | for _, entry := range ps {
31 | if entry.Key == name {
32 | return entry.Value, true
33 | }
34 | }
35 | return "", false
36 | }
37 |
38 | // ByName returns the value of the first Param which key matches the given name.
39 | // If no matching Param is found, an empty string is returned.
40 | func (ps Params) ByName(name string) (va string) {
41 | va, _ = ps.Get(name)
42 | return
43 | }
44 |
45 | type methodTree struct {
46 | method string
47 | root *node
48 | }
49 |
50 | type methodTrees []methodTree
51 |
52 | func (trees methodTrees) get(method string) *node {
53 | for _, tree := range trees {
54 | if tree.method == method {
55 | return tree.root
56 | }
57 | }
58 | return nil
59 | }
60 |
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}

// longestCommonPrefix returns the length in bytes of the longest common
// prefix of a and b.
func longestCommonPrefix(a, b string) int {
	limit := min(len(a), len(b))
	for i := 0; i < limit; i++ {
		if a[i] != b[i] {
			return i
		}
	}
	return limit
}
76 |
77 | func countParams(path string) uint16 {
78 | var n uint16
79 | s := StringToBytes(path)
80 | n += uint16(bytes.Count(s, strColon))
81 | n += uint16(bytes.Count(s, strStar))
82 | return n
83 | }
84 |
// nodeType classifies a node of the routing tree.
type nodeType uint8

// Node kinds. static is the zero value, so a new node is static by default.
const (
	static   nodeType = 0 // plain path segment
	root     nodeType = 1 // first node inserted into an empty tree
	param    nodeType = 2 // named parameter such as ":id"
	catchAll nodeType = 3 // catch-all wildcard such as "*filepath"
)
93 |
94 | type node struct {
95 | path string
96 | indices string
97 | wildChild bool
98 | nType nodeType
99 | priority uint32
100 | children []*node
101 | handlers HandlersChain
102 | fullPath string
103 | }
104 |
105 | // Increments priority of the given child and reorders if necessary
106 | func (n *node) incrementChildPrio(pos int) int {
107 | cs := n.children
108 | cs[pos].priority++
109 | prio := cs[pos].priority
110 |
111 | // Adjust position (move to front)
112 | newPos := pos
113 | for ; newPos > 0 && cs[newPos-1].priority < prio; newPos-- {
114 | // Swap node positions
115 | cs[newPos-1], cs[newPos] = cs[newPos], cs[newPos-1]
116 |
117 | }
118 |
119 | // Build new index char string
120 | if newPos != pos {
121 | n.indices = n.indices[:newPos] + // Unchanged prefix, might be empty
122 | n.indices[pos:pos+1] + // The index char we move
123 | n.indices[newPos:pos] + n.indices[pos+1:] // Rest without char at 'pos'
124 | }
125 |
126 | return newPos
127 | }
128 |
// addRoute adds a node with the given handle to the path.
// Not concurrency-safe!
//
// Starting at n, it walks the radix tree. It splits an existing edge when the
// new path shares only part of it, descends into matching children, and hands
// the unmatched remainder to insertChild. It panics on a wildcard conflict or
// when handlers are already registered for the same path.
func (n *node) addRoute(path string, handlers HandlersChain) {
	fullPath := path
	n.priority++

	// Empty tree
	if len(n.path) == 0 && len(n.children) == 0 {
		n.insertChild(path, fullPath, handlers)
		n.nType = root
		return
	}

	// Number of fullPath bytes consumed by the nodes above the current n;
	// used to slice fullPath when an edge is split.
	parentFullPathIndex := 0

walk:
	for {
		// Find the longest common prefix.
		// This also implies that the common prefix contains no ':' or '*'
		// since the existing key can't contain those chars.
		i := longestCommonPrefix(path, n.path)

		// Split edge: move the unmatched tail of n.path (and n's children and
		// handlers) into a new child, keeping only the common prefix in n.
		if i < len(n.path) {
			child := node{
				path:      n.path[i:],
				wildChild: n.wildChild,
				indices:   n.indices,
				children:  n.children,
				handlers:  n.handlers,
				priority:  n.priority - 1,
				fullPath:  n.fullPath,
			}

			n.children = []*node{&child}
			// []byte for proper unicode char conversion, see #65
			n.indices = BytesToString([]byte{n.path[i]})
			n.path = path[:i]
			n.handlers = nil
			n.wildChild = false
			n.fullPath = fullPath[:parentFullPathIndex+i]
		}

		// Make new node a child of this node
		if i < len(path) {
			path = path[i:]

			if n.wildChild {
				parentFullPathIndex += len(n.path)
				n = n.children[0]
				n.priority++

				// Check if the wildcard matches
				if len(path) >= len(n.path) && n.path == path[:len(n.path)] &&
					// Adding a child to a catchAll is not possible
					n.nType != catchAll &&
					// Check for longer wildcard, e.g. :name and :names
					(len(n.path) >= len(path) || path[len(n.path)] == '/') {
					continue walk
				}

				// The new path's wildcard differs from the existing one:
				// report the conflicting segment and the existing prefix.
				pathSeg := path
				if n.nType != catchAll {
					pathSeg = strings.SplitN(path, "/", 2)[0]
				}
				prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path
				panic("'" + pathSeg +
					"' in new path '" + fullPath +
					"' conflicts with existing wildcard '" + n.path +
					"' in existing prefix '" + prefix +
					"'")
			}

			c := path[0]

			// slash after param
			if n.nType == param && c == '/' && len(n.children) == 1 {
				parentFullPathIndex += len(n.path)
				n = n.children[0]
				n.priority++
				continue walk
			}

			// Check if a child with the next path byte exists
			for i, max := 0, len(n.indices); i < max; i++ {
				if c == n.indices[i] {
					parentFullPathIndex += len(n.path)
					// Re-sorting by priority may move the child; use its new index.
					i = n.incrementChildPrio(i)
					n = n.children[i]
					continue walk
				}
			}

			// Otherwise insert it
			if c != ':' && c != '*' {
				// []byte for proper unicode char conversion, see #65
				n.indices += BytesToString([]byte{c})
				child := &node{
					fullPath: fullPath,
				}
				n.children = append(n.children, child)
				n.incrementChildPrio(len(n.indices) - 1)
				n = child
			}
			n.insertChild(path, fullPath, handlers)
			return
		}

		// Otherwise add the handle to the current node
		if n.handlers != nil {
			panic("handlers are already registered for path '" + fullPath + "'")
		}
		n.handlers = handlers
		n.fullPath = fullPath
		return
	}
}
246 |
// findWildcard locates the first wildcard segment in path. A wildcard starts
// with ':' (named parameter) or '*' (catch-all) and runs up to the next '/'
// or the end of path. It returns the wildcard text and its start index.
// valid is false when the wildcard contains a further ':' or '*'.
// When path has no wildcard the result is "", -1, false.
func findWildcard(path string) (wildcard string, i int, valid bool) {
	start := strings.IndexAny(path, ":*")
	if start < 0 {
		return "", -1, false
	}

	valid = true
	rest := path[start+1:]
	for j := 0; j < len(rest); j++ {
		switch rest[j] {
		case '/':
			return path[start : start+1+j], start, valid
		case ':', '*':
			valid = false
		}
	}
	return path[start:], start, valid
}
271 |
// insertChild stores path (the unmatched remainder of fullPath) under n.
// Each wildcard becomes its own child: a param node for ':name', or a
// catchAll pair of nodes for '*name'. Static text between wildcards goes into
// intermediate nodes. The handlers are attached to the final node. It panics
// on malformed wildcards or on conflicts with existing children.
func (n *node) insertChild(path string, fullPath string, handlers HandlersChain) {
	for {
		// Find prefix until first wildcard
		wildcard, i, valid := findWildcard(path)
		if i < 0 { // No wildcard found
			break
		}

		// The wildcard name must not contain ':' and '*'
		if !valid {
			panic("only one wildcard per path segment is allowed, has: '" +
				wildcard + "' in path '" + fullPath + "'")
		}

		// check if the wildcard has a name
		if len(wildcard) < 2 {
			panic("wildcards must be named with a non-empty name in path '" + fullPath + "'")
		}

		// Check if this node has existing children which would be
		// unreachable if we insert the wildcard here
		if len(n.children) > 0 {
			panic("wildcard segment '" + wildcard +
				"' conflicts with existing children in path '" + fullPath + "'")
		}

		if wildcard[0] == ':' { // param
			if i > 0 {
				// Insert prefix before the current wildcard
				n.path = path[:i]
				path = path[i:]
			}

			n.wildChild = true
			child := &node{
				nType:    param,
				path:     wildcard,
				fullPath: fullPath,
			}
			n.children = []*node{child}
			n = child
			n.priority++

			// if the path doesn't end with the wildcard, then there
			// will be another non-wildcard subpath starting with '/'
			if len(wildcard) < len(path) {
				path = path[len(wildcard):]

				child := &node{
					priority: 1,
					fullPath: fullPath,
				}
				n.children = []*node{child}
				n = child
				continue
			}

			// Otherwise we're done. Insert the handle in the new leaf
			n.handlers = handlers
			return
		}

		// catchAll
		if i+len(wildcard) != len(path) {
			panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'")
		}

		if len(n.path) > 0 && n.path[len(n.path)-1] == '/' {
			panic("catch-all conflicts with existing handle for the path segment root in path '" + fullPath + "'")
		}

		// currently fixed width 1 for '/'
		// NOTE(review): if the catch-all starts at index 0 of path, i becomes -1
		// and path[i] panics with index out of range rather than the message below.
		i--
		if path[i] != '/' {
			panic("no / before catch-all in path '" + fullPath + "'")
		}

		n.path = path[:i]

		// First node: catchAll node with empty path
		child := &node{
			wildChild: true,
			nType:     catchAll,
			fullPath:  fullPath,
		}

		n.children = []*node{child}
		n.indices = string('/')
		n = child
		n.priority++

		// second node: node holding the variable (path starts with "/*")
		child = &node{
			path:     path[i:],
			nType:    catchAll,
			handlers: handlers,
			priority: 1,
			fullPath: fullPath,
		}
		n.children = []*node{child}

		return
	}

	// If no wildcard was found, simply insert the path and handle
	n.path = path
	n.handlers = handlers
	n.fullPath = fullPath
}
381 |
// nodeValue holds the return values of the (*node).getValue method.
type nodeValue struct {
	// handlers is the chain registered for the matched route; nil when no
	// route matched.
	handlers HandlersChain
	// params points at the caller-supplied Params slice once at least one
	// wildcard value has been recorded; it stays nil otherwise.
	params *Params
	// tsr is the trailing slash redirect recommendation: set when no
	// handlers matched but a route exists for the path with an extra (or
	// without the) trailing slash.
	tsr bool
	// fullPath is the full route path stored on the matched node.
	fullPath string
}
389 |
// getValue returns the handlers registered with the given path (key) in the
// tree rooted at n.
//
// Wildcard values are NOT saved to a map: each ':param' or '*catchAll' match
// appends a Param{Key, Value} to *params by re-slicing within the slice's
// existing capacity. The caller must therefore supply a slice whose capacity
// is at least the number of wildcards in the matched route, otherwise the
// re-slice panics. If params is nil, wildcard values are not recorded. If
// unescape is true, each value is decoded with url.QueryUnescape; a value
// that fails to decode is kept as-is.
//
// If no handle can be found, value.tsr carries a TSR (trailing slash
// redirect) recommendation: it is true when a handle exists for the given
// path with an extra (or without the) trailing slash.
//
// The walk never backtracks: once a static child or the wildcard child is
// chosen, a later mismatch ends the lookup.
func (n *node) getValue(path string, params *Params, unescape bool) (value nodeValue) {
walk: // Outer loop for walking the tree
	for {
		prefix := n.path
		if len(path) > len(prefix) {
			if path[:len(prefix)] == prefix {
				path = path[len(prefix):]
				// If this node does not have a wildcard (param or catchAll)
				// child, we can just look up the next child node and continue
				// to walk down the tree
				if !n.wildChild {
					// n.indices[i] is the byte that selects n.children[i].
					idxc := path[0]
					for i, c := range []byte(n.indices) {
						if c == idxc {
							n = n.children[i]
							continue walk
						}
					}

					// Nothing found.
					// We can recommend to redirect to the same URL without a
					// trailing slash if a leaf exists for that path.
					value.tsr = (path == "/" && n.handlers != nil)
					return
				}

				// Handle wildcard child. insertChild installs the wildcard
				// node as the only child of a node with wildChild set.
				n = n.children[0]
				switch n.nType {
				case param:
					// Find param end (either '/' or path end)
					end := 0
					for end < len(path) && path[end] != '/' {
						end++
					}
					// Save param value
					if params != nil {
						if value.params == nil {
							value.params = params
						}
						// Expand slice within preallocated capacity; this
						// panics if the caller's capacity is too small.
						i := len(*value.params)
						*value.params = (*value.params)[:i+1]
						val := path[:end]
						if unescape {
							if v, err := url.QueryUnescape(val); err == nil {
								val = v
							}
						}
						(*value.params)[i] = Param{
							Key:   n.path[1:], // node path is ":name"; drop the ':'
							Value: val,
						}
					}

					// we need to go deeper!
					if end < len(path) {
						if len(n.children) > 0 {
							path = path[end:]
							n = n.children[0]
							continue walk
						}

						// ... but we can't. Recommend a redirect only when the
						// unmatched remainder is a single trailing '/'.
						value.tsr = (len(path) == end+1)
						return
					}

					if value.handlers = n.handlers; value.handlers != nil {
						value.fullPath = n.fullPath
						return
					}
					if len(n.children) == 1 {
						// No handle found. Check if a handle for this path + a
						// trailing slash exists for TSR recommendation
						n = n.children[0]
						value.tsr = (n.path == "/" && n.handlers != nil)
					}
					return

				case catchAll:
					// Save param value; a catch-all consumes the whole
					// remaining path, including any '/'.
					if params != nil {
						if value.params == nil {
							value.params = params
						}
						// Expand slice within preallocated capacity; this
						// panics if the caller's capacity is too small.
						i := len(*value.params)
						*value.params = (*value.params)[:i+1]
						val := path
						if unescape {
							if v, err := url.QueryUnescape(path); err == nil {
								val = v
							}
						}
						(*value.params)[i] = Param{
							Key:   n.path[2:], // node path is "/*name"; drop the "/*"
							Value: val,
						}
					}

					value.handlers = n.handlers
					value.fullPath = n.fullPath
					return

				default:
					panic("invalid node type")
				}
			}
		}

		if path == prefix {
			// We should have reached the node containing the handle.
			// Check if this node has a handle registered.
			if value.handlers = n.handlers; value.handlers != nil {
				value.fullPath = n.fullPath
				return
			}

			// If there is no handle for this route, but this route has a
			// wildcard child, there must be a handle for this path with an
			// additional trailing slash
			if path == "/" && n.wildChild && n.nType != root {
				value.tsr = true
				return
			}

			// No handle found. Check if a handle for this path + a
			// trailing slash exists for trailing slash recommendation
			for i, c := range []byte(n.indices) {
				if c == '/' {
					n = n.children[i]
					value.tsr = (len(n.path) == 1 && n.handlers != nil) ||
						(n.nType == catchAll && n.children[0].handlers != nil)
					return
				}
			}

			return
		}

		// Nothing found. We can recommend to redirect to the same URL with an
		// extra trailing slash if a leaf exists for that path
		value.tsr = (path == "/") ||
			(len(prefix) == len(path)+1 && prefix[len(path)] == '/' &&
				path == prefix[:len(prefix)-1] && n.handlers != nil)
		return
	}
}
544 |
545 | // Makes a case-insensitive lookup of the given path and tries to find a handler.
546 | // It can optionally also fix trailing slashes.
547 | // It returns the case-corrected path and a bool indicating whether the lookup
548 | // was successful.
549 | func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) {
550 | const stackBufSize = 128
551 |
552 | // Use a static sized buffer on the stack in the common case.
553 | // If the path is too long, allocate a buffer on the heap instead.
554 | buf := make([]byte, 0, stackBufSize)
555 | if l := len(path) + 1; l > stackBufSize {
556 | buf = make([]byte, 0, l)
557 | }
558 |
559 | ciPath := n.findCaseInsensitivePathRec(
560 | path,
561 | buf, // Preallocate enough memory for new path
562 | [4]byte{}, // Empty rune buffer
563 | fixTrailingSlash,
564 | )
565 |
566 | return ciPath, ciPath != nil
567 | }
568 |
// shiftNRuneBytes shifts the bytes of rb left by n positions and zero-fills
// the vacated tail. A shift of 0 returns rb unchanged; a negative shift or a
// shift of 4 or more yields an all-zero array.
func shiftNRuneBytes(rb [4]byte, n int) [4]byte {
	if n == 0 {
		return rb
	}
	var shifted [4]byte
	if n > 0 && n < len(rb) {
		copy(shifted[:], rb[n:])
	}
	return shifted
}
584 |
// findCaseInsensitivePathRec is the recursive worker behind
// n.findCaseInsensitivePath. It walks the tree from n, matching path against
// the stored node paths case-insensitively, and returns ciPath extended with
// the corrected path. Static segments are appended as stored in the tree;
// wildcard values are appended from path unchanged. It returns nil when no
// route matches.
//
// rb carries the not-yet-consumed bytes of a case-converted multi-byte rune
// whose encoding is split across node boundaries; rb[0] == 0 means no rune
// is pending. If fixTrailingSlash is true, a match that differs only by a
// trailing slash is also accepted (a '/' may be appended to the result).
func (n *node) findCaseInsensitivePathRec(path string, ciPath []byte, rb [4]byte, fixTrailingSlash bool) []byte {
	npLen := len(n.path)

	// Byte 0 of n.path is not compared here: for child nodes it was already
	// matched through the parent's n.indices (in either letter case).
	// NOTE(review): for the root this presumably relies on both paths
	// starting with '/' — confirm. npLen == 0 occurs for the empty-path
	// catch-all intermediate node.
walk: // Outer loop for walking the tree
	for len(path) >= npLen && (npLen == 0 || strings.EqualFold(path[1:npLen], n.path[1:])) {
		// Add common prefix to result
		oldPath := path
		path = path[npLen:]
		ciPath = append(ciPath, n.path...)

		if len(path) > 0 {
			// If this node does not have a wildcard (param or catchAll) child,
			// we can just look up the next child node and continue to walk down
			// the tree
			if !n.wildChild {
				// Skip rune bytes already processed
				rb = shiftNRuneBytes(rb, npLen)

				if rb[0] != 0 {
					// Old rune not finished
					idxc := rb[0]
					for i, c := range []byte(n.indices) {
						if c == idxc {
							// continue with child node
							n = n.children[i]
							npLen = len(n.path)
							continue walk
						}
					}
				} else {
					// Process a new rune
					var rv rune

					// Find rune start.
					// Runes are up to 4 byte long,
					// -4 would definitely be another rune.
					// Starting at oldPath[npLen] (the first unmatched byte) and
					// moving backwards into this node's prefix, look for the
					// start byte of the rune that the next byte belongs to.
					// off ends up as the number of that rune's bytes already
					// consumed by the prefix; if no start byte is found, rv
					// stays 0. (max here is a local shadowing the builtin.)
					var off int
					for max := min(npLen, 3); off < max; off++ {
						if i := npLen - off; utf8.RuneStart(oldPath[i]) {
							// read rune from cached path
							rv, _ = utf8.DecodeRuneInString(oldPath[i:])
							break
						}
					}

					// Calculate lowercase bytes of current rune
					lo := unicode.ToLower(rv)
					utf8.EncodeRune(rb[:], lo)

					// Skip already processed bytes
					rb = shiftNRuneBytes(rb, off)

					idxc := rb[0]
					for i, c := range []byte(n.indices) {
						// Lowercase matches
						if c == idxc {
							// must use a recursive approach since both the
							// uppercase byte and the lowercase byte might exist
							// as an index
							if out := n.children[i].findCaseInsensitivePathRec(
								path, ciPath, rb, fixTrailingSlash,
							); out != nil {
								return out
							}
							break
						}
					}

					// If we found no match, the same for the uppercase rune,
					// if it differs
					if up := unicode.ToUpper(rv); up != lo {
						utf8.EncodeRune(rb[:], up)
						rb = shiftNRuneBytes(rb, off)

						idxc := rb[0]
						for i, c := range []byte(n.indices) {
							// Uppercase matches
							if c == idxc {
								// Continue with child node
								n = n.children[i]
								npLen = len(n.path)
								continue walk
							}
						}
					}
				}

				// Nothing found. We can recommend to redirect to the same URL
				// without a trailing slash if a leaf exists for that path
				if fixTrailingSlash && path == "/" && n.handlers != nil {
					return ciPath
				}
				return nil
			}

			// Wildcard child: it is the only child of this node.
			n = n.children[0]
			switch n.nType {
			case param:
				// Find param end (either '/' or path end)
				end := 0
				for end < len(path) && path[end] != '/' {
					end++
				}

				// Add param value to case insensitive path
				ciPath = append(ciPath, path[:end]...)

				// We need to go deeper!
				if end < len(path) {
					if len(n.children) > 0 {
						// Continue with child node
						n = n.children[0]
						npLen = len(n.path)
						path = path[end:]
						continue
					}

					// ... but we can't. Accept only if the remainder is a
					// single trailing '/' (dropped from the result).
					if fixTrailingSlash && len(path) == end+1 {
						return ciPath
					}
					return nil
				}

				if n.handlers != nil {
					return ciPath
				}

				if fixTrailingSlash && len(n.children) == 1 {
					// No handle found. Check if a handle for this path + a
					// trailing slash exists
					n = n.children[0]
					if n.path == "/" && n.handlers != nil {
						return append(ciPath, '/')
					}
				}

				return nil

			case catchAll:
				// A catch-all takes the rest of the path verbatim.
				return append(ciPath, path...)

			default:
				panic("invalid node type")
			}
		} else {
			// We should have reached the node containing the handle.
			// Check if this node has a handle registered.
			if n.handlers != nil {
				return ciPath
			}

			// No handle found.
			// Try to fix the path by adding a trailing slash
			if fixTrailingSlash {
				for i, c := range []byte(n.indices) {
					if c == '/' {
						n = n.children[i]
						if (len(n.path) == 1 && n.handlers != nil) ||
							(n.nType == catchAll && n.children[0].handlers != nil) {
							return append(ciPath, '/')
						}
						return nil
					}
				}
			}
			return nil
		}
	}

	// Nothing found.
	// Try to fix the path by adding / removing a trailing slash
	if fixTrailingSlash {
		if path == "/" {
			return ciPath
		}
		if len(path)+1 == npLen && n.path[len(path)] == '/' &&
			strings.EqualFold(path[1:], n.path[1:len(path)]) && n.handlers != nil {
			return append(ciPath, n.path...)
		}
	}
	return nil
}
769 |
770 |
--------------------------------------------------------------------------------
/yee.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | import (
4 | "fmt"
5 | "github.com/cookieY/yee/logger"
6 | "golang.org/x/net/http2"
7 | "golang.org/x/net/http2/h2c"
8 | "io"
9 | "log"
10 | "net/http"
11 | "os"
12 | "sync"
13 | )
14 |
15 | // HandlerFunc define handler of context
16 | type HandlerFunc func(Context) (err error)
17 |
18 | // HandlersChain define handler chain of context
19 | type HandlersChain []HandlerFunc
20 |
// Core is the yee engine. It implements http.Handler through ServeHTTP,
// owns one routing tree per HTTP method, and embeds the root *Router so
// routing methods such as Use can be called on the engine directly.
type Core struct {
	// Router is the root router (basePath "/") created by C.
	*Router
	// trees holds one routing tree per HTTP method; handleHTTPRequest
	// consults the first tree whose method matches the request.
	trees methodTrees
	// pool recycles *context values between requests; entries are created
	// by allocateContext.
	pool sync.Pool
	// maxParams is the capacity of each pooled context's Params slice.
	// getValue panics if a matched route has more wildcards than this.
	// NOTE(review): presumably raised while routes are registered (not
	// visible here) — confirm.
	maxParams uint16
	// HandleMethodNotAllowed is not referenced by handleHTTPRequest in this
	// file. TODO confirm where (or whether) it is honored.
	HandleMethodNotAllowed bool
	// H2server is not referenced in this file.
	H2server *http.Server
	// allNoRoute is noRoute combined by combineHandlers (see
	// rebuild404Handlers); it runs when no route matches.
	allNoRoute HandlersChain
	// allNoMethod is noMethod combined by combineHandlers (see
	// rebuild405Handlers). handleHTTPRequest in this file does not run it.
	allNoMethod HandlersChain
	// noRoute and noMethod are the user-supplied fallback handlers.
	noRoute  HandlersChain
	noMethod HandlersChain
	// l is the engine logger (configured by SetLogLevel / SetLogOut).
	l logger.Logger
	// color is not referenced in this file.
	color *logger.Color
	// bind is the request binder; C initializes it to DefaultBinder{}.
	bind DefaultBinder
	// RedirectTrailingSlash and RedirectFixedPath are not referenced by
	// handleHTTPRequest in this file. TODO confirm they are honored elsewhere.
	RedirectTrailingSlash bool
	RedirectFixedPath     bool
	// Banner is not referenced in this file; New prints the banner
	// unconditionally.
	Banner bool
}
40 |
// version is the framework version string shown in the startup banner.
const version = "yee v0.5.1"

// creator and title are the other two lines of the startup banner.
const creator = "Creator: Henry Yee"
const title = "-----Easier and Faster-----"

// banner is the startup banner printed by New. Its three %s verbs are
// filled, in order, with the colored version, title and creator.
const banner = `
__ __
_ \/ /_________
__ /_ _ \ _ \
_ / / __/ __/
/_/ \___/\___/ %s
%s
%s
`
55 |
56 | // New create a core and perform a series of initializations
57 | func New() *Core {
58 | core := C()
59 | core.l.Custom(fmt.Sprintf(banner, logger.Green(version), logger.Red(title), logger.Cyan(creator)))
60 | return core
61 | }
62 |
63 | func C() *Core {
64 | router := &Router{
65 | handlers: nil,
66 | root: true,
67 | basePath: "/",
68 | }
69 |
70 | core := &Core{
71 | trees: make(methodTrees, 0, 0),
72 | Router: router,
73 | l: logger.LogCreator(),
74 | bind: DefaultBinder{},
75 | }
76 |
77 | core.core = core
78 |
79 | core.pool.New = func() interface{} {
80 | return core.allocateContext()
81 | }
82 | return core
83 | }
84 |
85 | // SetLogLevel define custom log level
86 | func (c *Core) SetLogLevel(l uint8) {
87 | c.l.SetLevel(l)
88 | }
89 |
90 | func (c *Core) SetLogOut(out io.Writer) {
91 | c.l.SetOut(out)
92 | }
93 |
94 | func (c *Core) allocateContext() *context {
95 | v := make(Params, 0, c.maxParams)
96 | return &context{engine: c, params: &v, index: -1}
97 | }
98 |
99 | // Use defines which middleware is uesd
100 | // when we dose not match prefix or method
101 | // we`ll register noRoute or noMethod handle for this
102 | // otherwise, we cannot be verified for noRoute/noMethod
103 | func (c *Core) Use(middleware ...HandlerFunc) {
104 | c.Router.Use(middleware...)
105 | c.rebuild404Handlers()
106 | c.rebuild405Handlers()
107 |
108 | }
109 |
110 | func (c *Core) rebuild404Handlers() {
111 | c.allNoRoute = c.combineHandlers(c.noRoute)
112 | }
113 |
114 | func (c *Core) rebuild405Handlers() {
115 | c.allNoMethod = c.combineHandlers(c.noMethod)
116 | }
117 |
118 | // override Handler.ServeHTTP
119 | // all requests/response deal with here
120 | // we use sync.pool save context variable
121 | // because we do this can be used less memory
122 | // we just only reset context, when before callback c.handleHTTPRequest func
123 | // and put context variable into poll
124 |
125 | func (c *Core) ServeHTTP(w http.ResponseWriter, r *http.Request) {
126 | context := c.pool.Get().(*context)
127 | context.writermem.reset(w)
128 | context.r = r
129 | context.reset()
130 | //context.w.Header().Set(HeaderServer, serverName)
131 | c.handleHTTPRequest(context)
132 | c.pool.Put(context)
133 | }
134 |
135 | // NewContext is for testing
136 | func (c *Core) NewContext(r *http.Request, w http.ResponseWriter) Context {
137 | context := new(context)
138 | context.writermem.reset(w)
139 | context.w = &context.writermem
140 | context.r = r
141 | context.engine = c
142 | return context
143 | }
144 |
145 | // Run is launch of http
146 | func (c *Core) Run(addr string) {
147 | if err := http.ListenAndServe(addr, c); err != nil {
148 | c.l.Critical(err.Error())
149 | os.Exit(1)
150 | }
151 | }
152 |
153 | // RunTLS is launch of tls
154 | // golang supports http2,if client supports http2
155 | // Otherwise, the http protocol return to http1.1
156 | func (c *Core) RunTLS(addr, certFile, keyFile string) {
157 | if err := http.ListenAndServeTLS(addr, certFile, keyFile, c); err != nil {
158 | c.l.Critical(err.Error())
159 | os.Exit(1)
160 | }
161 | }
162 |
163 | // RunH2C is launch of h2c
164 | // In normal conditions, http2 must used certificate
165 | // H2C is non-certificate`s http2
166 | // notes:
167 | // 1.the browser is not supports H2C proto, you should write your web client test program
168 | // 2.the H2C protocol is not safety
169 | func (c *Core) RunH2C(addr string) {
170 | s := &http2.Server{}
171 | h1s := &http.Server{
172 | Addr: addr,
173 | Handler: h2c.NewHandler(c, s),
174 | }
175 | log.Fatal(h1s.ListenAndServe())
176 | }
177 |
178 | //func (c *Core) RunH3(addr, ca, keyFile string) {
179 | // //if isTCP {
180 | // log.Fatal(http3.ListenAndServe(addr, ca, keyFile, c))
181 | // //log.Fatal(s.ListenAndServeTLS(ca, keyFile))
182 | //}
183 |
184 | func (c *Core) handleHTTPRequest(context *context) {
185 | httpMethod := context.r.Method
186 | rPath := context.r.URL.Path
187 | unescape := false
188 | // Find root of the tree for the given HTTP method
189 | t := c.trees
190 | for i, tl := 0, len(t); i < tl; i++ {
191 |
192 | if t[i].method != httpMethod {
193 | continue
194 | }
195 |
196 | root := t[i].root
197 | // Find route in tree
198 | value := root.getValue(rPath, context.params, unescape)
199 | if value.params != nil {
200 | context.Param = *value.params
201 | }
202 | if value.handlers != nil {
203 | context.handlers = value.handlers
204 | context.path = value.fullPath
205 | context.Next()
206 | context.writermem.WriteHeaderNow()
207 | return
208 | }
209 |
210 | break
211 | }
212 |
213 | context.handlers = c.allNoRoute
214 |
215 | // Notice
216 | // We must judge whether an empty request is OPTIONS method,
217 | // Because when complex request (XMLHttpRequest) will send an OPTIONS request and fetch the preflight resource.
218 | // But in general, we do not register an OPTIONS handle,
219 | // So this may cause some middleware errors.
220 | if httpMethod == http.MethodOptions {
221 | serveError(context, http.StatusNoContent, nil)
222 | } else {
223 | serveError(context, http.StatusNotFound, []byte("404 NOT FOUND"))
224 | }
225 | }
226 |
227 | func serveError(c *context, code int, defaultMessage []byte) {
228 | c.writermem.status = code
229 | c.Next()
230 | if c.writermem.Written() {
231 | return
232 | }
233 | if c.writermem.Status() == code {
234 | c.writermem.Header()["Content-Type"] = []string{MIMETextPlain}
235 | _, err := c.w.Write(defaultMessage)
236 | if err != nil {
237 | c.engine.l.Error(fmt.Sprintf("cannot write message to writer during serve error: %v", err))
238 | }
239 | return
240 | }
241 | c.writermem.WriteHeaderNow()
242 | }
243 |
--------------------------------------------------------------------------------
/yee_test.go:
--------------------------------------------------------------------------------
1 | package yee
2 |
3 | import (
4 | "net/http"
5 | "testing"
6 | )
7 |
8 | func indexHandle(c Context) (err error) {
9 | return c.JSON(http.StatusOK, "ok")
10 | }
11 |
12 | func addRouter(y *Core) {
13 | y.GET("/", indexHandle)
14 | }
15 |
16 | func TestYee(t *testing.T) {
17 | y := New()
18 | addRouter(y)
19 | y.Run(":9999")
20 | }
21 |
22 | func TestRestApi(t *testing.T) {
23 | y := New()
24 | y.Restful("/", RestfulAPI{
25 | Get: func(c Context) (err error) {
26 | return c.String(http.StatusOK, "updated")
27 | },
28 | Post: func(c Context) (err error) {
29 | return c.String(http.StatusOK, "get it")
30 | },
31 | })
32 | }
33 |
34 | func TestDownload(t *testing.T) {
35 | y := New()
36 | y.GET("/", func(c Context) (err error) {
37 | return c.File("args.go")
38 | })
39 | y.Run(":9999")
40 | }
41 |
42 | func TestStatic(t *testing.T) {
43 | y := New()
44 | y.Static("/", "dist")
45 | //y.GET("/", func(c Context) error {
46 | // return c.HTMLTpl(http.StatusOK, "./dist/index.html")
47 | //})
48 | y.Run(":9999")
49 | }
50 |
51 | ////go:embed dist/*
52 | //var f embed.FS
53 | //
54 | ////go:embed dist/index.html
55 | //var index string
56 |
57 | //func TestPack(t *testing.T) {
58 | // y := New()
59 | // y.Pack("/front", f, "dist")
60 | // y.GET("/", func(c Context) error {
61 | // return c.HTML(http.StatusOK, index)
62 | // })
63 | // y.Run(":9999")
64 | //}
65 |
66 | const ver = `alt-svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000,h3-Q050=":443"; ma=2592000,h3-Q046=":443"; ma=2592000,h3-Q043=":443"; ma=2592000,quic=":443"; ma=2592000; v="46,43"`
67 |
68 | func TestH3(t *testing.T) {
69 | y := New()
70 | y.GET("/", func(c Context) (err error) {
71 | return c.JSON(http.StatusOK, "hello")
72 | })
73 | y.RunH3(":443", "henry.com+4.pem", "henry.com+4-key.pem")
74 | }
75 |
--------------------------------------------------------------------------------