├── .gitignore ├── LICENSE ├── README.md ├── args.go ├── benchmark_test.go ├── bind.go ├── bind_test.go ├── common.go ├── context.go ├── context_test.go ├── fs.go ├── gen.go ├── gen_test.go ├── go.mod ├── go.sum ├── h2_test.go ├── h3_test.go ├── h3client.go ├── img └── benchmark.png ├── log_test.go ├── logger ├── color.go ├── color_test.go └── log.go ├── logrus.log ├── middleware ├── basicAuth.go ├── basicAuth_test.go ├── cors.go ├── cors_test.go ├── csrf.go ├── csrf_test.go ├── gzip.go ├── gzip_test.go ├── jwt.go ├── jwt_test.go ├── logger.go ├── logger_test.go ├── multiMiddle_test.go ├── proxy.go ├── rateLimit.go ├── rateLimit_test.go ├── recovery.go ├── recovery_test.go ├── requestID.go ├── requestID_test.go ├── secure.go ├── secure_test.go └── trace.go ├── pprof.go ├── pprof_test.go ├── response_overide.go ├── router.go ├── router_test.go ├── test ├── host.pb.go └── host.proto ├── tree.go ├── yee.go └── yee_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | dist 8 | # Test binary, built with `go test -c` 9 | *.test 10 | testing/h5 11 | .idea 12 | .sum 13 | # Output of the go coverage tool, specifically when used with LiteIDE 14 | *.out 15 | 16 | # Dependency directories (remove the comment below to include it) 17 | # vendor/ 18 | *.pem 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Henry Yee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to 
permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Yee 2 | 3 | ![](https://img.shields.io/badge/build-alpha-brightgreen.svg)   4 | ![](https://img.shields.io/badge/version-v0.0.1-brightgreen.svg) 5 | 6 | 🦄 Web frameworks for Go, easier & faster. 7 | 8 | This is a framework for learning purposes. Refer to the code for Echo and Gin 9 | 10 | - Faster HTTP router 11 | - Build RESTful APIs 12 | - Group APIs 13 | - Extensible middleware framework 14 | - Define middleware at root, group or route level 15 | - Data binding for URI Query, JSON, XML, Protocol Buffer3 and form payload 16 | - HTTP/2(H2C)/Http3(QUIC) support 17 | 18 | # Supported Go versions 19 | 20 | Yee is available as a Go module. 
You need to use Go 1.13 + 21 | 22 | ## Example 23 | 24 | #### Quick start 25 | 26 | ```go 27 | file, err := os.OpenFile("logrus.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) 28 | 29 | if err != nil { 30 | return 31 | } 32 | 33 | y := yee.New() 34 | 35 | y.SetLogLevel(logger.Warning) 36 | 37 | y.SetLogOut(file) 38 | 39 | y.Use(Logger()) 40 | 41 | y.Static("/assets", "dist/assets") 42 | 43 | y.GET("/", func(c yee.Context) (err error) { 44 | return c.HTMLTml(http.StatusOK, "dist/index.html") 45 | }) 46 | 47 | y.GET("/hello", func(c yee.Context) error { 48 | return c.String(http.StatusOK, "

Hello Gee

") 49 | }) 50 | 51 | y.POST("/test", func(c yee.Context) (err error) { 52 | u := new(p) 53 | if err := c.Bind(u); err != nil { 54 | return c.JSON(http.StatusOK, err.Error()) 55 | } 56 | return c.JSON(http.StatusOK, u.Test) 57 | }) 58 | 59 | y.Run(":9000") 60 | ``` 61 | 62 | #### API 63 | 64 | Supports GET, POST, PUT, DELETE, HEAD, OPTIONS, TRACE and PATCH. 65 | 66 | ```go 67 | 68 | y := yee.New() 69 | 70 | y.GET("/someGet", handler) 71 | y.POST("/somePost", handler) 72 | y.PUT("/somePut", handler) 73 | y.DELETE("/someDelete", handler) 74 | y.PATCH("/somePatch", handler) 75 | y.HEAD("/someHead", handler) 76 | y.OPTIONS("/someOptions", handler) 77 | 78 | y.Run(":8000") 79 | 80 | ``` 81 | 82 | #### Restful 83 | 84 | You can use the Any and Restful methods to implement a RESTful API. 85 | 86 | + Any 87 | 88 | ```go 89 | y := yee.New() 90 | 91 | y.Any("/any", handler) 92 | 93 | y.Run(":8000") 94 | 95 | // All request methods for the same URL use the same handler 96 | 97 | ``` 98 | 99 | + Restful 100 | 101 | ```go 102 | 103 | func userUpdate(c yee.Context) (err error) { 104 | return c.String(http.StatusOK, "updated") 105 | } 106 | 107 | func userFetch(c yee.Context) (err error) { 108 | return c.String(http.StatusOK, "get it") 109 | } 110 | 111 | func RestfulApi() yee.RestfulApi { 112 | return yee.RestfulApi{ 113 | Get: userFetch, 114 | Post: userUpdate, 115 | } 116 | } 117 | 118 | y := yee.New() 119 | 120 | y.Restful("/", RestfulApi()) 121 | 122 | y.Run(":8000") 123 | 124 | // Each request method for the same URL is dispatched to its own handler 125 | 126 | 127 | ``` 128 | 129 | ## Middleware 130 | 131 | - basic auth 132 | - cors 133 | - csrf 134 | - gzip 135 | - jwt 136 | - logger 137 | - rate limit 138 | - recovery 139 | - secure 140 | - request id 141 | 142 | ## Benchmark 143 | 144 | Resource 145 | - CPU: i7-9750H 146 | - Memory: 16G 147 | - OS: macOS 10.15.5 148 | 149 | Date: 2020/06/17 150 | 151 | ![](img/benchmark.png) 152 | 153 | 154 | ## License 155 | 156 | MIT 157 | 
-------------------------------------------------------------------------------- /args.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | // Header types 8 | const ( 9 | HeaderSecWebSocketProtocol = "Sec-Websocket-Protocol" 10 | HeaderAccept = "Accept" 11 | HeaderAcceptEncoding = "Accept-Encoding" 12 | HeaderAuthorization = "Authorization" 13 | HeaderContentDisposition = "Content-Disposition" 14 | HeaderContentEncoding = "Content-Encoding" 15 | HeaderContentLength = "Content-Length" 16 | HeaderContentType = "Content-Type" 17 | HeaderCookie = "Cookie" 18 | HeaderSetCookie = "Set-Cookie" 19 | HeaderIfModifiedSince = "If-Modified-Since" 20 | HeaderLastModified = "Last-Modified" 21 | HeaderLocation = "Location" 22 | HeaderUpgrade = "Upgrade" 23 | HeaderConnection = "Connection" 24 | HeaderVary = "Vary" 25 | HeaderWWWAuthenticate = "WWW-Authenticate" 26 | HeaderXForwardedFor = "X-Forwarded-For" 27 | HeaderXForwardedProto = "X-Forwarded-Proto" 28 | HeaderXForwardedProtocol = "X-Forwarded-Protocol" 29 | HeaderXForwardedSsl = "X-Forwarded-Ssl" 30 | HeaderXUrlScheme = "X-Url-Scheme" 31 | HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" 32 | HeaderXRealIP = "X-Real-IP" 33 | HeaderXRequestID = "X-Request-ID" 34 | HeaderXRequestedWith = "X-Requested-With" 35 | HeaderServer = "Server" 36 | HeaderOrigin = "Origin" 37 | 38 | // Access control 39 | HeaderAccessControlRequestMethod = "Access-Control-Request-Method" 40 | HeaderAccessControlRequestHeaders = "Access-Control-Request-Headers" 41 | HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin" 42 | HeaderAccessControlAllowMethods = "Access-Control-Allow-Methods" 43 | HeaderAccessControlAllowHeaders = "Access-Control-Allow-Headers" 44 | HeaderAccessControlAllowCredentials = "Access-Control-Allow-Credentials" 45 | HeaderAccessControlExposeHeaders = "Access-Control-Expose-Headers" 46 | HeaderAccessControlMaxAge = 
"Access-Control-Max-Age" 47 | 48 | // Security 49 | HeaderStrictTransportSecurity = "Strict-Transport-Security" 50 | HeaderXContentTypeOptions = "X-Content-Type-Options" 51 | HeaderXXSSProtection = "X-XSS-Protection" 52 | HeaderXFrameOptions = "X-Frame-Options" 53 | HeaderContentSecurityPolicy = "Content-Security-Policy" 54 | HeaderContentSecurityPolicyReportOnly = "Content-Security-Policy-Report-Only" 55 | HeaderXCSRFToken = "X-CSRF-Token" 56 | HeaderReferrerPolicy = "Referrer-Policy" 57 | ) 58 | 59 | const ( 60 | defaultMemory = 32 << 20 // 32 MB 61 | indexPage = "index.html" 62 | defaultIndent = " " 63 | ) 64 | 65 | const ( 66 | StatusCodeContextCanceled = 499 67 | ) 68 | 69 | // MIME types 70 | const ( 71 | MIMEApplicationJSON = "application/json" 72 | MIMEApplicationJSONCharsetUTF8 = MIMEApplicationJSON + "; " + charsetUTF8 73 | MIMEApplicationJavaScript = "application/javascript" 74 | MIMEApplicationJavaScriptCharsetUTF8 = MIMEApplicationJavaScript + "; " + charsetUTF8 75 | MIMEApplicationXML = "application/xml" 76 | MIMEApplicationXMLCharsetUTF8 = MIMEApplicationXML + "; " + charsetUTF8 77 | MIMETextXML = "text/xml" 78 | MIMETextXMLCharsetUTF8 = MIMETextXML + "; " + charsetUTF8 79 | MIMEApplicationForm = "application/x-www-form-urlencoded" 80 | MIMEApplicationProtobuf = "application/protobuf" 81 | MIMEApplicationMsgpack = "application/msgpack" 82 | MIMETextHTML = "text/html" 83 | MIMETextHTMLCharsetUTF8 = MIMETextHTML + "; " + charsetUTF8 84 | MIMETextPlain = "text/plain" 85 | MIMETextPlainCharsetUTF8 = MIMETextPlain + "; " + charsetUTF8 86 | MIMEMultipartForm = "multipart/form-data" 87 | MIMEOctetStream = "application/octet-stream" 88 | ) 89 | 90 | const ( 91 | charsetUTF8 = "charset=UTF-8" 92 | serverName = "yee" 93 | ) 94 | 95 | // Err types 96 | var ( 97 | ErrUnsupportedMediaType = errors.New("http server not support media type") 98 | ErrValidatorNotRegistered = errors.New("validator not registered") 99 | ErrRendererNotRegistered = errors.New("renderer 
not registered") 100 | ErrInvalidRedirectCode = errors.New("invalid redirect status code") 101 | ErrCookieNotFound = errors.New("cookie not found") 102 | ErrNotFoundHandler = errors.New("404 NOT FOUND") 103 | ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") 104 | ) 105 | -------------------------------------------------------------------------------- /benchmark_test.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "runtime" 7 | "testing" 8 | ) 9 | 10 | type ( 11 | Route struct { 12 | Method string 13 | Path string 14 | } 15 | ) 16 | 17 | var gplusAPI = []*Route{ 18 | // People 19 | {"GET", "/people/:userId"}, 20 | {"GET", "/people"}, 21 | {"GET", "/activities/:activityId/people/:collection"}, 22 | {"GET", "/people/:userId/people/:collection"}, 23 | {"GET", "/people/:userId/openIdConnect"}, 24 | 25 | // Activities 26 | {"GET", "/people/:userId/activities/:collection"}, 27 | {"GET", "/activities/:activityId"}, 28 | {"GET", "/activities"}, 29 | 30 | // Comments 31 | {"GET", "/activities/:activityId/comments"}, 32 | {"GET", "/comments/:commentId"}, 33 | 34 | // Moments 35 | {"POST", "/people/:userId/moments/:collection"}, 36 | {"GET", "/people/:userId/moments/:collection"}, 37 | {"DELETE", "/moments/:id"}, 38 | } 39 | 40 | var githubAPI = []*Route{ 41 | // OAuth Authorizations 42 | {"GET", "/authorizations"}, 43 | {"GET", "/authorizations/:id"}, 44 | {"POST", "/authorizations"}, 45 | //{"PUT", "/authorizations/clients/:client_id"}, 46 | //{"PATCH", "/authorizations/:id"}, 47 | {"DELETE", "/authorizations/:id"}, 48 | {"GET", "/applications/:client_id/tokens/:access_token"}, 49 | {"DELETE", "/applications/:client_id/tokens"}, 50 | {"DELETE", "/applications/:client_id/tokens/:access_token"}, 51 | 52 | // Activity 53 | {"GET", "/events"}, 54 | {"GET", "/repos/:owner/:repo/events"}, 55 | {"GET", 
"/networks/:owner/:repo/events"}, 56 | {"GET", "/orgs/:org/events"}, 57 | {"GET", "/users/:user/received_events"}, 58 | {"GET", "/users/:user/received_events/public"}, 59 | {"GET", "/users/:user/events"}, 60 | {"GET", "/users/:user/events/public"}, 61 | {"GET", "/users/:user/events/orgs/:org"}, 62 | {"GET", "/feeds"}, 63 | {"GET", "/notifications"}, 64 | {"GET", "/repos/:owner/:repo/notifications"}, 65 | {"PUT", "/notifications"}, 66 | {"PUT", "/repos/:owner/:repo/notifications"}, 67 | {"GET", "/notifications/threads/:id"}, 68 | //{"PATCH", "/notifications/threads/:id"}, 69 | {"GET", "/notifications/threads/:id/subscription"}, 70 | {"PUT", "/notifications/threads/:id/subscription"}, 71 | {"DELETE", "/notifications/threads/:id/subscription"}, 72 | {"GET", "/repos/:owner/:repo/stargazers"}, 73 | {"GET", "/users/:user/starred"}, 74 | {"GET", "/user/starred"}, 75 | {"GET", "/user/starred/:owner/:repo"}, 76 | {"PUT", "/user/starred/:owner/:repo"}, 77 | {"DELETE", "/user/starred/:owner/:repo"}, 78 | {"GET", "/repos/:owner/:repo/subscribers"}, 79 | {"GET", "/users/:user/subscriptions"}, 80 | {"GET", "/user/subscriptions"}, 81 | {"GET", "/repos/:owner/:repo/subscription"}, 82 | {"PUT", "/repos/:owner/:repo/subscription"}, 83 | {"DELETE", "/repos/:owner/:repo/subscription"}, 84 | {"GET", "/user/subscriptions/:owner/:repo"}, 85 | {"PUT", "/user/subscriptions/:owner/:repo"}, 86 | {"DELETE", "/user/subscriptions/:owner/:repo"}, 87 | 88 | // Gists 89 | {"GET", "/users/:user/gists"}, 90 | {"GET", "/gists"}, 91 | //{"GET", "/gists/public"}, 92 | //{"GET", "/gists/starred"}, 93 | {"GET", "/gists/:id"}, 94 | {"POST", "/gists"}, 95 | //{"PATCH", "/gists/:id"}, 96 | {"PUT", "/gists/:id/star"}, 97 | {"DELETE", "/gists/:id/star"}, 98 | {"GET", "/gists/:id/star"}, 99 | {"POST", "/gists/:id/forks"}, 100 | {"DELETE", "/gists/:id"}, 101 | 102 | // Git Data 103 | {"GET", "/repos/:owner/:repo/git/blobs/:sha"}, 104 | {"POST", "/repos/:owner/:repo/git/blobs"}, 105 | {"GET", 
"/repos/:owner/:repo/git/commits/:sha"}, 106 | {"POST", "/repos/:owner/:repo/git/commits"}, 107 | //{"GET", "/repos/:owner/:repo/git/refs/*ref"}, 108 | {"GET", "/repos/:owner/:repo/git/refs"}, 109 | {"POST", "/repos/:owner/:repo/git/refs"}, 110 | //{"PATCH", "/repos/:owner/:repo/git/refs/*ref"}, 111 | //{"DELETE", "/repos/:owner/:repo/git/refs/*ref"}, 112 | {"GET", "/repos/:owner/:repo/git/tags/:sha"}, 113 | {"POST", "/repos/:owner/:repo/git/tags"}, 114 | {"GET", "/repos/:owner/:repo/git/trees/:sha"}, 115 | {"POST", "/repos/:owner/:repo/git/trees"}, 116 | 117 | // Issues 118 | {"GET", "/issues"}, 119 | {"GET", "/user/issues"}, 120 | {"GET", "/orgs/:org/issues"}, 121 | {"GET", "/repos/:owner/:repo/issues"}, 122 | {"GET", "/repos/:owner/:repo/issues/:number"}, 123 | {"POST", "/repos/:owner/:repo/issues"}, 124 | //{"PATCH", "/repos/:owner/:repo/issues/:number"}, 125 | {"GET", "/repos/:owner/:repo/assignees"}, 126 | {"GET", "/repos/:owner/:repo/assignees/:assignee"}, 127 | {"GET", "/repos/:owner/:repo/issues/:number/comments"}, 128 | //{"GET", "/repos/:owner/:repo/issues/comments"}, 129 | //{"GET", "/repos/:owner/:repo/issues/comments/:id"}, 130 | {"POST", "/repos/:owner/:repo/issues/:number/comments"}, 131 | //{"PATCH", "/repos/:owner/:repo/issues/comments/:id"}, 132 | //{"DELETE", "/repos/:owner/:repo/issues/comments/:id"}, 133 | {"GET", "/repos/:owner/:repo/issues/:number/events"}, 134 | //{"GET", "/repos/:owner/:repo/issues/events"}, 135 | //{"GET", "/repos/:owner/:repo/issues/events/:id"}, 136 | {"GET", "/repos/:owner/:repo/labels"}, 137 | {"GET", "/repos/:owner/:repo/labels/:name"}, 138 | {"POST", "/repos/:owner/:repo/labels"}, 139 | //{"PATCH", "/repos/:owner/:repo/labels/:name"}, 140 | {"DELETE", "/repos/:owner/:repo/labels/:name"}, 141 | {"GET", "/repos/:owner/:repo/issues/:number/labels"}, 142 | {"POST", "/repos/:owner/:repo/issues/:number/labels"}, 143 | {"DELETE", "/repos/:owner/:repo/issues/:number/labels/:name"}, 144 | {"PUT", 
"/repos/:owner/:repo/issues/:number/labels"}, 145 | {"DELETE", "/repos/:owner/:repo/issues/:number/labels"}, 146 | {"GET", "/repos/:owner/:repo/milestones/:number/labels"}, 147 | {"GET", "/repos/:owner/:repo/milestones"}, 148 | {"GET", "/repos/:owner/:repo/milestones/:number"}, 149 | {"POST", "/repos/:owner/:repo/milestones"}, 150 | //{"PATCH", "/repos/:owner/:repo/milestones/:number"}, 151 | {"DELETE", "/repos/:owner/:repo/milestones/:number"}, 152 | 153 | // Miscellaneous 154 | {"GET", "/emojis"}, 155 | {"GET", "/gitignore/templates"}, 156 | {"GET", "/gitignore/templates/:name"}, 157 | {"POST", "/markdown"}, 158 | {"POST", "/markdown/raw"}, 159 | {"GET", "/meta"}, 160 | {"GET", "/rate_limit"}, 161 | 162 | // Organizations 163 | {"GET", "/users/:user/orgs"}, 164 | {"GET", "/user/orgs"}, 165 | {"GET", "/orgs/:org"}, 166 | //{"PATCH", "/orgs/:org"}, 167 | {"GET", "/orgs/:org/members"}, 168 | {"GET", "/orgs/:org/members/:user"}, 169 | {"DELETE", "/orgs/:org/members/:user"}, 170 | {"GET", "/orgs/:org/public_members"}, 171 | {"GET", "/orgs/:org/public_members/:user"}, 172 | {"PUT", "/orgs/:org/public_members/:user"}, 173 | {"DELETE", "/orgs/:org/public_members/:user"}, 174 | {"GET", "/orgs/:org/teams"}, 175 | {"GET", "/teams/:id"}, 176 | {"POST", "/orgs/:org/teams"}, 177 | //{"PATCH", "/teams/:id"}, 178 | {"DELETE", "/teams/:id"}, 179 | {"GET", "/teams/:id/members"}, 180 | {"GET", "/teams/:id/members/:user"}, 181 | {"PUT", "/teams/:id/members/:user"}, 182 | {"DELETE", "/teams/:id/members/:user"}, 183 | {"GET", "/teams/:id/repos"}, 184 | {"GET", "/teams/:id/repos/:owner/:repo"}, 185 | {"PUT", "/teams/:id/repos/:owner/:repo"}, 186 | {"DELETE", "/teams/:id/repos/:owner/:repo"}, 187 | {"GET", "/user/teams"}, 188 | 189 | // Pull Requests 190 | {"GET", "/repos/:owner/:repo/pulls"}, 191 | {"GET", "/repos/:owner/:repo/pulls/:number"}, 192 | {"POST", "/repos/:owner/:repo/pulls"}, 193 | //{"PATCH", "/repos/:owner/:repo/pulls/:number"}, 194 | {"GET", 
"/repos/:owner/:repo/pulls/:number/commits"}, 195 | {"GET", "/repos/:owner/:repo/pulls/:number/files"}, 196 | {"GET", "/repos/:owner/:repo/pulls/:number/merge"}, 197 | {"PUT", "/repos/:owner/:repo/pulls/:number/merge"}, 198 | {"GET", "/repos/:owner/:repo/pulls/:number/comments"}, 199 | //{"GET", "/repos/:owner/:repo/pulls/comments"}, 200 | //{"GET", "/repos/:owner/:repo/pulls/comments/:number"}, 201 | {"PUT", "/repos/:owner/:repo/pulls/:number/comments"}, 202 | //{"PATCH", "/repos/:owner/:repo/pulls/comments/:number"}, 203 | //{"DELETE", "/repos/:owner/:repo/pulls/comments/:number"}, 204 | 205 | // Repositories 206 | {"GET", "/user/repos"}, 207 | {"GET", "/users/:user/repos"}, 208 | {"GET", "/orgs/:org/repos"}, 209 | {"GET", "/repositories"}, 210 | {"POST", "/user/repos"}, 211 | {"POST", "/orgs/:org/repos"}, 212 | {"GET", "/repos/:owner/:repo"}, 213 | //{"PATCH", "/repos/:owner/:repo"}, 214 | {"GET", "/repos/:owner/:repo/contributors"}, 215 | {"GET", "/repos/:owner/:repo/languages"}, 216 | {"GET", "/repos/:owner/:repo/teams"}, 217 | {"GET", "/repos/:owner/:repo/tags"}, 218 | {"GET", "/repos/:owner/:repo/branches"}, 219 | {"GET", "/repos/:owner/:repo/branches/:branch"}, 220 | {"DELETE", "/repos/:owner/:repo"}, 221 | {"GET", "/repos/:owner/:repo/collaborators"}, 222 | {"GET", "/repos/:owner/:repo/collaborators/:user"}, 223 | {"PUT", "/repos/:owner/:repo/collaborators/:user"}, 224 | {"DELETE", "/repos/:owner/:repo/collaborators/:user"}, 225 | {"GET", "/repos/:owner/:repo/comments"}, 226 | {"GET", "/repos/:owner/:repo/commits/:sha/comments"}, 227 | {"POST", "/repos/:owner/:repo/commits/:sha/comments"}, 228 | {"GET", "/repos/:owner/:repo/comments/:id"}, 229 | //{"PATCH", "/repos/:owner/:repo/comments/:id"}, 230 | {"DELETE", "/repos/:owner/:repo/comments/:id"}, 231 | {"GET", "/repos/:owner/:repo/commits"}, 232 | {"GET", "/repos/:owner/:repo/commits/:sha"}, 233 | {"GET", "/repos/:owner/:repo/readme"}, 234 | //{"GET", "/repos/:owner/:repo/contents/*path"}, 235 | //{"PUT", 
"/repos/:owner/:repo/contents/*path"}, 236 | //{"DELETE", "/repos/:owner/:repo/contents/*path"}, 237 | //{"GET", "/repos/:owner/:repo/:archive_format/:ref"}, 238 | {"GET", "/repos/:owner/:repo/keys"}, 239 | {"GET", "/repos/:owner/:repo/keys/:id"}, 240 | {"POST", "/repos/:owner/:repo/keys"}, 241 | //{"PATCH", "/repos/:owner/:repo/keys/:id"}, 242 | {"DELETE", "/repos/:owner/:repo/keys/:id"}, 243 | {"GET", "/repos/:owner/:repo/downloads"}, 244 | {"GET", "/repos/:owner/:repo/downloads/:id"}, 245 | {"DELETE", "/repos/:owner/:repo/downloads/:id"}, 246 | {"GET", "/repos/:owner/:repo/forks"}, 247 | {"POST", "/repos/:owner/:repo/forks"}, 248 | {"GET", "/repos/:owner/:repo/hooks"}, 249 | {"GET", "/repos/:owner/:repo/hooks/:id"}, 250 | {"POST", "/repos/:owner/:repo/hooks"}, 251 | //{"PATCH", "/repos/:owner/:repo/hooks/:id"}, 252 | {"POST", "/repos/:owner/:repo/hooks/:id/tests"}, 253 | {"DELETE", "/repos/:owner/:repo/hooks/:id"}, 254 | {"POST", "/repos/:owner/:repo/merges"}, 255 | {"GET", "/repos/:owner/:repo/releases"}, 256 | {"GET", "/repos/:owner/:repo/releases/:id"}, 257 | {"POST", "/repos/:owner/:repo/releases"}, 258 | //{"PATCH", "/repos/:owner/:repo/releases/:id"}, 259 | {"DELETE", "/repos/:owner/:repo/releases/:id"}, 260 | {"GET", "/repos/:owner/:repo/releases/:id/assets"}, 261 | {"GET", "/repos/:owner/:repo/stats/contributors"}, 262 | {"GET", "/repos/:owner/:repo/stats/commit_activity"}, 263 | {"GET", "/repos/:owner/:repo/stats/code_frequency"}, 264 | {"GET", "/repos/:owner/:repo/stats/participation"}, 265 | {"GET", "/repos/:owner/:repo/stats/punch_card"}, 266 | {"GET", "/repos/:owner/:repo/statuses/:ref"}, 267 | {"POST", "/repos/:owner/:repo/statuses/:ref"}, 268 | 269 | // Search 270 | {"GET", "/search/repositories"}, 271 | {"GET", "/search/code"}, 272 | {"GET", "/search/issues"}, 273 | {"GET", "/search/users"}, 274 | {"GET", "/legacy/issues/search/:owner/:repository/:state/:keyword"}, 275 | {"GET", "/legacy/repos/search/:keyword"}, 276 | {"GET", 
"/legacy/user/search/:keyword"}, 277 | {"GET", "/legacy/user/email/:email"}, 278 | 279 | // Users 280 | {"GET", "/users/:user"}, 281 | {"GET", "/user"}, 282 | //{"PATCH", "/user"}, 283 | {"GET", "/users"}, 284 | {"GET", "/user/emails"}, 285 | {"POST", "/user/emails"}, 286 | {"DELETE", "/user/emails"}, 287 | {"GET", "/users/:user/followers"}, 288 | {"GET", "/user/followers"}, 289 | {"GET", "/users/:user/following"}, 290 | {"GET", "/user/following"}, 291 | {"GET", "/user/following/:user"}, 292 | {"GET", "/users/:user/following/:target_user"}, 293 | {"PUT", "/user/following/:user"}, 294 | {"DELETE", "/user/following/:user"}, 295 | {"GET", "/users/:user/keys"}, 296 | {"GET", "/user/keys"}, 297 | {"GET", "/user/keys/:id"}, 298 | {"POST", "/user/keys"}, 299 | //{"PATCH", "/user/keys/:id"}, 300 | {"DELETE", "/user/keys/:id"}, 301 | } 302 | 303 | var parseAPI = []*Route{ 304 | // Objects 305 | {"POST", "/1/classes/:className"}, 306 | {"GET", "/1/classes/:className/:objectId"}, 307 | {"PUT", "/1/classes/:className/:objectId"}, 308 | {"GET", "/1/classes/:className"}, 309 | {"DELETE", "/1/classes/:className/:objectId"}, 310 | 311 | // Users 312 | {"POST", "/1/users"}, 313 | {"GET", "/1/login"}, 314 | {"GET", "/1/users/:objectId"}, 315 | {"PUT", "/1/users/:objectId"}, 316 | {"GET", "/1/users"}, 317 | {"DELETE", "/1/users/:objectId"}, 318 | {"POST", "/1/requestPasswordReset"}, 319 | 320 | // Roles 321 | {"POST", "/1/roles"}, 322 | {"GET", "/1/roles/:objectId"}, 323 | {"PUT", "/1/roles/:objectId"}, 324 | {"GET", "/1/roles"}, 325 | {"DELETE", "/1/roles/:objectId"}, 326 | 327 | // Files 328 | {"POST", "/1/files/:fileName"}, 329 | 330 | // Analytics 331 | {"POST", "/1/events/:eventName"}, 332 | 333 | // Push Notifications 334 | {"POST", "/1/push"}, 335 | 336 | // Installations 337 | {"POST", "/1/installations"}, 338 | {"GET", "/1/installations/:objectId"}, 339 | {"PUT", "/1/installations/:objectId"}, 340 | {"GET", "/1/installations"}, 341 | {"DELETE", "/1/installations/:objectId"}, 
342 | 343 | // Cloud Functions 344 | {"POST", "/1/functions"}, 345 | } 346 | 347 | func yeeHandler(method, path string) HandlerFunc { 348 | return func(c Context) error { 349 | return c.String(http.StatusOK, "OK") 350 | } 351 | } 352 | 353 | func loadYeeRoutes(e *Core, routes []*Route) { 354 | for _, r := range routes { 355 | switch r.Method { 356 | case "GET": 357 | e.GET(r.Path, yeeHandler(r.Method, r.Path)) 358 | case "POST": 359 | e.POST(r.Path, yeeHandler(r.Method, r.Path)) 360 | case "PATCH": 361 | e.PATCH(r.Path, yeeHandler(r.Method, r.Path)) 362 | case "PUT": 363 | e.PUT(r.Path, yeeHandler(r.Method, r.Path)) 364 | case "DELETE": 365 | e.DELETE(r.Path, yeeHandler(r.Method, r.Path)) 366 | } 367 | } 368 | } 369 | 370 | func benchmarkRoutes(b *testing.B, router http.Handler, routes []*Route) { 371 | b.ReportAllocs() 372 | r := httptest.NewRequest("GET", "/", nil) 373 | u := r.URL 374 | w := httptest.NewRecorder() 375 | b.SetBytes(1024 * 1024) 376 | for i := 0; i < b.N; i++ { 377 | for _, route := range routes { 378 | r.Method = route.Method 379 | u.Path = route.Path 380 | router.ServeHTTP(w, r) 381 | } 382 | } 383 | } 384 | 385 | func BenchmarkYeeParseAPI(b *testing.B) { 386 | e := C() 387 | loadYeeRoutes(e, parseAPI) 388 | benchmarkRoutes(b, e, parseAPI) 389 | } 390 | 391 | func BenchmarkYeeGplusAPI(b *testing.B) { 392 | e := C() 393 | loadYeeRoutes(e, gplusAPI) 394 | benchmarkRoutes(b, e, gplusAPI) 395 | } 396 | 397 | func BenchmarkYeeGitHubAPI(b *testing.B) { 398 | e := C() 399 | loadYeeRoutes(e, githubAPI) 400 | benchmarkRoutes(b, e, githubAPI) 401 | } 402 | 403 | func BenchmarkYeeStatic(b *testing.B) { 404 | e := C() 405 | e.Static("/front", "color") 406 | b.ReportAllocs() 407 | r := httptest.NewRequest("GET", "/front/color.go", nil) 408 | w := httptest.NewRecorder() 409 | b.SetBytes(1024 * 1024) 410 | for i := 0; i < b.N; i++ { 411 | e.ServeHTTP(w, r) 412 | } 413 | } 414 | 415 | func Benchmark(b *testing.B) { 416 | runtime.GOMAXPROCS(runtime.NumCPU()) 417 
| b.Run("BenchmarkYeeGplusAPI", BenchmarkYeeGplusAPI) 418 | b.Run("BenchmarkYeeParseAPI", BenchmarkYeeParseAPI) 419 | b.Run("BenchmarkYeeGitHubAPI", BenchmarkYeeGitHubAPI) 420 | } 421 | -------------------------------------------------------------------------------- /bind.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | import ( 4 | "encoding" 5 | "encoding/json" 6 | "encoding/xml" 7 | "errors" 8 | "fmt" 9 | "github.com/go-playground/validator/v10" 10 | "github.com/golang/protobuf/proto" 11 | "io" 12 | "reflect" 13 | "strconv" 14 | "strings" 15 | ) 16 | 17 | type ( 18 | // DefaultBinder is the default implementation of the Binder interface. 19 | DefaultBinder struct{} 20 | 21 | // BindUnmarshaler is the interface used to wrap the UnmarshalParam method. 22 | // Types that don't implement this, but do implement encoding.TextUnmarshaler 23 | // will use that interface instead. 24 | BindUnmarshaler interface { 25 | // UnmarshalParam decodes and assigns a value from a form or query param. 26 | UnmarshalParam(param string) error 27 | } 28 | ) 29 | 30 | // bindValidator is shared by all Bind calls: a *validator.Validate is safe for concurrent use and caches parsed struct tags, which creating a new one per request would throw away. 31 | var bindValidator = validator.New() 32 | 33 | // Bind decodes the request into i via bind, then validates i's `validate` struct tags. i itself is handed to the validator, which dereferences pointers on its own (handing it a reflect.Value would validate the reflect.Value wrapper instead of the caller's struct); targets that are neither a struct nor a pointer to one, e.g. maps, are not validated. 34 | func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { 35 | if err = b.bind(i, c); err != nil { 36 | return err 37 | } 38 | if reflect.Indirect(reflect.ValueOf(i)).Kind() != reflect.Struct { 39 | return nil 40 | } 41 | return bindValidator.Struct(i) 42 | } 43 | 44 | // bind populates i from the query string (matched by json tags) and then, unless Content-Length is 0, decodes the body according to the Content-Type header.
45 | func (b *DefaultBinder) bind(i interface{}, c Context) (err error) { 46 | req := c.Request() 47 | if err = b.bindData(i, c.QueryParams(), "json"); err != nil { 48 | return err 49 | } 50 | 51 | if req.ContentLength == 0 { 52 | return 53 | } 54 | 55 | ctype := strings.ToLower(req.Header.Get(HeaderContentType)) 56 | switch { 57 | case strings.HasPrefix(ctype, MIMEApplicationProtobuf): 58 | msg, ok := i.(proto.Message) 59 | if !ok { 60 | return ErrUnsupportedMediaType 61 | } 62 | buf, err := io.ReadAll(req.Body) 63 | if err != nil { 64 | return err 65 | } 66 | return proto.Unmarshal(buf, msg) 67 | case strings.HasPrefix(ctype, MIMEApplicationJSON): 68 | if err = json.NewDecoder(req.Body).Decode(i); err != nil { 69 | var ute *json.UnmarshalTypeError 70 | if errors.As(err, &ute) { 71 | return fmt.Errorf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset) 72 | } 73 | return err 74 | } 75 | case strings.HasPrefix(ctype, MIMEApplicationXML), 76 | strings.HasPrefix(ctype, MIMETextXML): 77 | if err = xml.NewDecoder(req.Body).Decode(i); err != nil { 78 | if ute, ok := err.(*xml.UnsupportedTypeError); ok { 79 | return fmt.Errorf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error()) 80 | } else if se, ok := err.(*xml.SyntaxError); ok { 81 | return fmt.Errorf("Syntax error: line=%v, error=%v", se.Line, se.Error()) 82 | } 83 | return err 84 | } 85 | case strings.HasPrefix(ctype, MIMEOctetStream): 86 | // Octet-stream bodies are not decoded here; the body is left unread. 87 | case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): 88 | params, err := c.FormParams() 89 | if err != nil { 90 | return err 91 | } 92 | if err = b.bindData(i, params, "form"); err != nil { 93 | return err 94 | } 95 | default: 96 | return ErrUnsupportedMediaType 97 | } 98 | return 99 | } 100 | 101 | func (b *DefaultBinder) bindData(ptr 
interface{}, data map[string][]string, tag string) error { 102 | if ptr == nil || len(data) == 0 { 103 | return nil 104 | } 105 | typ := reflect.TypeOf(ptr).Elem() 106 | val := reflect.ValueOf(ptr).Elem() 107 | if typ.Kind() == reflect.Map { 108 | if val.IsNil() { 109 | val.Set(reflect.MakeMap(typ)) // SetMapIndex panics on a nil map, e.g. Bind(&m) with `var m map[string]string` 110 | } 111 | for k, v := range data { 112 | val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0])) 113 | } 114 | return nil 115 | } 116 | if typ.Kind() != reflect.Struct { 117 | return errors.New("binding element must be a struct") 118 | } 119 | for i := 0; i < typ.NumField(); i++ { 120 | typeField := typ.Field(i) 121 | structField := val.Field(i) 122 | if !structField.CanSet() { 123 | continue 124 | } 125 | structFieldKind := structField.Kind() 126 | inputFieldName := typeField.Tag.Get(tag) 127 | if inputFieldName == "" { 128 | inputFieldName = typeField.Name 129 | // If the tag is empty, recurse into struct fields that are not BindUnmarshalers. 130 | if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct { 131 | if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil { 132 | return err 133 | } 134 | continue 135 | } 136 | } 137 | 138 | inputValue, exists := data[inputFieldName] 139 | if !exists { 140 | // Go json.Unmarshal supports case insensitive binding. However the 141 | // url params are bound case sensitive which is inconsistent. To 142 | // fix this we must check all of the map values in a 143 | // case-insensitive search.
144 | for k, v := range data { 145 | if strings.EqualFold(k, inputFieldName) { 146 | inputValue = v 147 | exists = true 148 | break 149 | } 150 | } 151 | } 152 | 153 | if !exists { 154 | continue 155 | } 156 | 157 | // Call this first, in case we're dealing with an alias to an array type 158 | if ok, err := unmarshalField(typeField.Type.Kind(), inputValue[0], structField); ok { 159 | if err != nil { 160 | return err 161 | } 162 | continue 163 | } 164 | 165 | numElems := len(inputValue) 166 | if structFieldKind == reflect.Slice && numElems > 0 { 167 | sliceOf := structField.Type().Elem().Kind() 168 | slice := reflect.MakeSlice(structField.Type(), numElems, numElems) 169 | for j := 0; j < numElems; j++ { 170 | if err := setWithProperType(sliceOf, inputValue[j], slice.Index(j)); err != nil { 171 | return err 172 | } 173 | } 174 | val.Field(i).Set(slice) 175 | } else if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil { 176 | return err 177 | 178 | } 179 | } 180 | return nil 181 | } 182 | 183 | func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error { 184 | // But also call it here, in case we're dealing with an array of BindUnmarshalers 185 | if ok, err := unmarshalField(valueKind, val, structField); ok { 186 | return err 187 | } 188 | 189 | switch valueKind { 190 | case reflect.Ptr: 191 | return setWithProperType(structField.Elem().Kind(), val, structField.Elem()) 192 | case reflect.Int: 193 | return setIntField(val, 0, structField) 194 | case reflect.Int8: 195 | return setIntField(val, 8, structField) 196 | case reflect.Int16: 197 | return setIntField(val, 16, structField) 198 | case reflect.Int32: 199 | return setIntField(val, 32, structField) 200 | case reflect.Int64: 201 | return setIntField(val, 64, structField) 202 | case reflect.Uint: 203 | return setUintField(val, 0, structField) 204 | case reflect.Uint8: 205 | return setUintField(val, 8, structField) 206 | case reflect.Uint16: 207 | 
return setUintField(val, 16, structField) 208 | case reflect.Uint32: 209 | return setUintField(val, 32, structField) 210 | case reflect.Uint64: 211 | return setUintField(val, 64, structField) 212 | case reflect.Bool: 213 | return setBoolField(val, structField) 214 | case reflect.Float32: 215 | return setFloatField(val, 32, structField) 216 | case reflect.Float64: 217 | return setFloatField(val, 64, structField) 218 | case reflect.String: 219 | structField.SetString(val) 220 | default: 221 | return errors.New("unknown type") 222 | } 223 | return nil 224 | } 225 | 226 | func unmarshalField(valueKind reflect.Kind, val string, field reflect.Value) (bool, error) { 227 | switch valueKind { 228 | case reflect.Ptr: 229 | return unmarshalFieldPtr(val, field) 230 | default: 231 | return unmarshalFieldNonPtr(val, field) 232 | } 233 | } 234 | 235 | func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) { 236 | fieldIValue := field.Addr().Interface() 237 | if unmarshaler, ok := fieldIValue.(BindUnmarshaler); ok { 238 | return true, unmarshaler.UnmarshalParam(value) 239 | } 240 | if unmarshaler, ok := fieldIValue.(encoding.TextUnmarshaler); ok { 241 | return true, unmarshaler.UnmarshalText([]byte(value)) 242 | } 243 | 244 | return false, nil 245 | } 246 | 247 | func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) { 248 | if field.IsNil() { 249 | // Initialize the pointer to a nil value 250 | field.Set(reflect.New(field.Type().Elem())) 251 | } 252 | return unmarshalFieldNonPtr(value, field.Elem()) 253 | } 254 | 255 | func setIntField(value string, bitSize int, field reflect.Value) error { 256 | if value == "" { 257 | value = "0" 258 | } 259 | intVal, err := strconv.ParseInt(value, 10, bitSize) 260 | if err == nil { 261 | field.SetInt(intVal) 262 | } 263 | return err 264 | } 265 | 266 | func setUintField(value string, bitSize int, field reflect.Value) error { 267 | if value == "" { 268 | value = "0" 269 | } 270 | uintVal, err := 
strconv.ParseUint(value, 10, bitSize) 271 | if err == nil { 272 | field.SetUint(uintVal) 273 | } 274 | return err 275 | } 276 | 277 | func setBoolField(value string, field reflect.Value) error { 278 | if value == "" { 279 | value = "false" 280 | } 281 | boolVal, err := strconv.ParseBool(value) 282 | if err == nil { 283 | field.SetBool(boolVal) 284 | } 285 | return err 286 | } 287 | 288 | func setFloatField(value string, bitSize int, field reflect.Value) error { 289 | if value == "" { 290 | value = "0.0" 291 | } 292 | floatVal, err := strconv.ParseFloat(value, bitSize) 293 | if err == nil { 294 | field.SetFloat(floatVal) 295 | } 296 | return err 297 | } 298 | -------------------------------------------------------------------------------- /bind_test.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | import ( 4 | "fmt" 5 | "github.com/stretchr/testify/assert" 6 | "io" 7 | "net/http" 8 | "net/http/httptest" 9 | "strings" 10 | "testing" 11 | ) 12 | 13 | type user struct { 14 | Username string `json:"username" validate:"required"` 15 | Password string `json:"password"` 16 | Age int `json:"age"` 17 | } 18 | 19 | type empty struct { 20 | } 21 | 22 | type cmdbBind struct { 23 | RegionId string `json:"region_id"` 24 | SecId string `json:"secId"` 25 | Cloud string `json:"cloud"` 26 | Account string `json:"account"` 27 | } 28 | 29 | var userInfo = `{"username": "henry","age":24,"password":"123123"}` 30 | var invalidInfo = `{"username": "","age":24,"password":"123123"}` 31 | var encrypt = `e2db79dc56e0b5a5866fa4062c9c715e66a5d4820d5424c7645092be0041c1e62d9571b0549758cd02445593b2a276d455cba5b31295e1288d67255e78e4dd78` 32 | 33 | func TestBindJSON(t *testing.T) { 34 | assertions := assert.New(t) 35 | testBindOkay(assertions, strings.NewReader(encrypt), MIMEApplicationJSON) 36 | //testBindError(assertions, strings.NewReader(invalidInfo), MIMEApplicationJSON) 37 | //testBindQueryPrams(assertions, MIMETextHTML) 38 | } 39 | 40 
| func TestDefaultBinder_Bind(t *testing.T) { 41 | e := C() 42 | e.POST("/bind", func(c Context) (err error) { 43 | u := new(user) 44 | if err := c.Bind(u); err != nil { 45 | return err 46 | } 47 | return c.JSON(http.StatusOK, u) 48 | }) 49 | req := httptest.NewRequest(http.MethodPost, "/bind", strings.NewReader(invalidInfo)) 50 | req.Header.Set("Content-Type", MIMEApplicationJSON) 51 | rec := httptest.NewRecorder() 52 | e.ServeHTTP(rec, req) 53 | fmt.Println(rec.Code) 54 | fmt.Println(rec.Body.String()) 55 | } 56 | 57 | func TestBindEncryptOkay(t *testing.T) { 58 | e := C() 59 | e.POST("/", func(c Context) (err error) { 60 | u := new(user) 61 | if err := c.Bind(u); err != nil { 62 | return err 63 | } 64 | return c.JSON(http.StatusOK, u) 65 | }) 66 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(encrypt)) 67 | req.Header.Set("Content-Type", MIMEApplicationJSON) 68 | rec := httptest.NewRecorder() 69 | e.ServeHTTP(rec, req) 70 | } 71 | 72 | func TestDefaultBinder_Params_Bind(t *testing.T) { 73 | e := C() 74 | e.GET("/bind", func(c Context) (err error) { 75 | u := new(empty) 76 | if err := c.Bind(u); err != nil { 77 | return err 78 | } 79 | return c.JSON(http.StatusOK, "") 80 | }) 81 | req := httptest.NewRequest(http.MethodGet, "/bind?username=xxxx&cn", nil) 82 | req.Header.Set("Content-Type", MIMEApplicationJSON) 83 | rec := httptest.NewRecorder() 84 | e.ServeHTTP(rec, req) 85 | assert.Equal(t, http.StatusBadRequest, rec.Code) 86 | } 87 | 88 | func testBindOkay(assert *assert.Assertions, r io.Reader, ctype string) { 89 | e := C() 90 | req := httptest.NewRequest(http.MethodPost, "/", r) 91 | rec := httptest.NewRecorder() 92 | c := e.NewContext(req, rec) 93 | req.Header.Set(HeaderContentType, ctype) 94 | u := new(user) 95 | err := c.Bind(u) 96 | if assert.NoError(err) { 97 | assert.Equal("henry", u.Username) 98 | assert.Equal(24, u.Age) 99 | assert.Equal("123123", u.Password) 100 | } 101 | } 102 | 103 | func testBindError(assert *assert.Assertions, 
r io.Reader, ctype string) { 104 | e := C() 105 | req := httptest.NewRequest(http.MethodPost, "/", r) 106 | rec := httptest.NewRecorder() 107 | c := e.NewContext(req, rec) 108 | req.Header.Set(HeaderContentType, ctype) 109 | u := new(user) 110 | err := c.Bind(u) 111 | assert.Error(err, "Unmarshal type error: expected=yee.user, got=number, field=, offset=1") 112 | } 113 | 114 | func testBindQueryPrams(assert *assert.Assertions, ctype string) { 115 | e := C() 116 | req := httptest.NewRequest(http.MethodGet, "/?secId=sg-gw86k2rjop30v1ktyn3j®ion_id=eu-central", nil) 117 | rec := httptest.NewRecorder() 118 | c := e.NewContext(req, rec) 119 | req.Header.Set(HeaderContentType, ctype) 120 | u := new(cmdbBind) 121 | err := c.Bind(u) 122 | if assert.NoError(err) { 123 | assert.Equal("eu-central", u.RegionId) 124 | assert.Equal("sg-gw86k2rjop30v1ktyn3j", u.SecId) 125 | } 126 | } 127 | 128 | func TestQueryParams(t *testing.T) { 129 | assertions := assert.New(t) 130 | testBindQueryPrams(assertions, MIMETextHTML) 131 | } 132 | -------------------------------------------------------------------------------- /common.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "path" 7 | "reflect" 8 | "unsafe" 9 | ) 10 | 11 | // StringToBytes converts string to byte slice without a memory allocation. 12 | func StringToBytes(s string) (b []byte) { 13 | sh := *(*reflect.StringHeader)(unsafe.Pointer(&s)) 14 | bh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) 15 | bh.Data, bh.Len, bh.Cap = sh.Data, sh.Len, sh.Len 16 | return b 17 | } 18 | 19 | // BytesToString converts byte slice to string without a memory allocation. 
20 | func BytesToString(b []byte) string { 21 | return *(*string)(unsafe.Pointer(&b)) 22 | } 23 | 24 | func lastChar(str string) uint8 { 25 | if str == "" { 26 | panic("The length of the string can't be 0") 27 | } 28 | return str[len(str)-1] 29 | } 30 | 31 | func joinPaths(absolutePath, relativePath string) string { 32 | if relativePath == "" { 33 | return absolutePath 34 | } 35 | 36 | finalPath := path.Join(absolutePath, relativePath) 37 | if lastChar(relativePath) == '/' && lastChar(finalPath) != '/' { 38 | return finalPath + "/" 39 | } 40 | return finalPath 41 | } 42 | 43 | func getBytes(key interface{}) []byte { 44 | var buf bytes.Buffer 45 | enc := json.NewEncoder(&buf) 46 | err := enc.Encode(key) 47 | if err != nil { 48 | return nil 49 | } 50 | return buf.Bytes() 51 | } 52 | 53 | func clearPoint(s string) string { 54 | if len(s) > 0 && s[0] == '"' { 55 | s = s[1:] 56 | } 57 | if len(s) > 0 && s[len(s)-2] == '"' { 58 | s = s[:len(s)-2] 59 | } 60 | return s 61 | } 62 | -------------------------------------------------------------------------------- /context.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "github.com/cookieY/yee/logger" 8 | "github.com/golang/protobuf/proto" 9 | "math" 10 | "mime/multipart" 11 | "net" 12 | "net/http" 13 | "net/url" 14 | "os" 15 | "path/filepath" 16 | "strings" 17 | "sync" 18 | ) 19 | 20 | const crashIndex int = math.MaxInt8 / 2 21 | 22 | // Context is the default implementation interface of context 23 | type Context interface { 24 | Request() *http.Request 25 | SetRequest(r *http.Request) 26 | Response() ResponseWriter 27 | SetResponse(w ResponseWriter) 28 | HTML(code int, html string) (err error) 29 | JSON(code int, i interface{}) error 30 | ProtoBuf(code int, i proto.Message) error 31 | String(code int, s string) error 32 | FormValue(name string) string 33 | FormParams() (url.Values, error) 34 | FormFile(name 
string) (*multipart.FileHeader, error) 35 | File(file string) error 36 | Blob(code int, contentType string, b []byte) (err error) 37 | Status(code int) 38 | QueryParam(name string) string 39 | QueryString() string 40 | SetHeader(key string, value string) 41 | AddHeader(key string, value string) 42 | GetHeader(key string) string 43 | MultipartForm() (*multipart.Form, error) 44 | Redirect(code int, uri string) error 45 | Params(name string) string 46 | RequestURI() string 47 | Scheme() string 48 | IsTLS() bool 49 | IsWebsocket() bool 50 | Next() 51 | HTMLTpl(code int, tml string) (err error) 52 | QueryParams() map[string][]string 53 | Bind(i interface{}) error 54 | Cookie(name string) (*http.Cookie, error) 55 | SetCookie(cookie *http.Cookie) 56 | Cookies() []*http.Cookie 57 | Get(key string) interface{} 58 | Put(key string, values interface{}) 59 | ServerError(code int, defaultMessage string) error 60 | RemoteIP() string 61 | Logger() logger.Logger 62 | Reset() 63 | Crash() 64 | IsCrash() bool 65 | CrashWithStatus(code int) 66 | CrashWithJson(code int, json interface{}) 67 | Path() string 68 | Abort() 69 | } 70 | 71 | type context struct { 72 | engine *Core 73 | writermem responseWriter 74 | w ResponseWriter 75 | r *http.Request 76 | path string 77 | method string 78 | code int 79 | queryList url.Values // cache url.Values 80 | params *Params 81 | Param Params 82 | // middleware 83 | handlers HandlersChain 84 | index int 85 | store map[string]interface{} 86 | lock sync.RWMutex 87 | noRewrite bool 88 | abort bool 89 | } 90 | 91 | func (c *context) Path() string { 92 | return c.path 93 | } 94 | 95 | func (c *context) Reset() { 96 | c.index = -1 97 | c.handlers = c.engine.noRoute 98 | } 99 | 100 | func (c *context) Abort() { 101 | c.abort = true 102 | } 103 | 104 | func (c *context) Bind(i interface{}) error { 105 | return c.engine.bind.Bind(i, c) 106 | } 107 | 108 | func (c *context) reset() { // reset context members 109 | c.w = &c.writermem 110 | c.Param = 
c.Param[0:0] 111 | c.handlers = nil 112 | c.index = -1 113 | c.path = "" 114 | // when context reset clear queryList cache . 115 | // cause if not clear cache the queryParams results will mistake 116 | c.queryList = nil 117 | c.store = nil 118 | *c.params = (*c.params)[0:0] 119 | } 120 | 121 | func (c *context) Next() { 122 | c.index++ 123 | s := len(c.handlers) 124 | if s > 0 { 125 | for ; c.index < s; c.index++ { 126 | err := c.handlers[c.index](c) 127 | if err != nil { 128 | c.Logger().Error(err.Error()) 129 | c.CrashWithString(http.StatusInternalServerError, err.Error()) 130 | } 131 | if c.abort { 132 | c.writermem.WriteHeaderNow() 133 | } 134 | if c.w.Written() { 135 | break 136 | } 137 | } 138 | } 139 | } 140 | 141 | func (c *context) Logger() logger.Logger { 142 | return c.engine.l 143 | } 144 | 145 | func (c *context) ServerError(code int, defaultMessage string) error { 146 | c.writermem.status = code 147 | if c.writermem.Written() { 148 | return errors.New("headers were already written") 149 | } 150 | if c.writermem.Status() == code { 151 | c.writermem.Header()["Content-Type"] = []string{MIMETextPlainCharsetUTF8} 152 | c.Logger().Error(fmt.Sprintf("%s %s", c.r.URL, defaultMessage)) 153 | _, err := c.w.Write([]byte(defaultMessage)) 154 | if err != nil { 155 | return fmt.Errorf("cannot write message to writer during serve error: %v", err) 156 | } 157 | return nil 158 | } 159 | c.writermem.WriteHeaderNow() 160 | return nil 161 | } 162 | 163 | func (c *context) Put(key string, values interface{}) { 164 | c.lock.Lock() 165 | defer c.lock.Unlock() 166 | if c.store == nil { 167 | c.store = make(map[string]interface{}) 168 | } 169 | c.store[key] = values 170 | } 171 | 172 | func (c *context) Crash() { 173 | c.index = crashIndex 174 | } 175 | 176 | func (c *context) IsCrash() bool { 177 | return c.index >= crashIndex 178 | } 179 | 180 | func (c *context) CrashWithStatus(code int) { 181 | c.Status(code) 182 | c.w.WriteHeaderNow() 183 | c.Crash() 184 | } 185 | 186 | 
func (c *context) CrashWithJson(code int, json interface{}) { 187 | c.Crash() 188 | err := c.JSON(code, json) 189 | if err != nil { 190 | return 191 | } 192 | } 193 | 194 | func (c *context) CrashWithString(code int, str string) { 195 | c.Crash() 196 | err := c.String(code, str) 197 | if err != nil { 198 | return 199 | } 200 | } 201 | 202 | func (c *context) Get(key string) interface{} { 203 | c.lock.RLock() 204 | defer c.lock.RUnlock() 205 | return c.store[key] 206 | } 207 | 208 | func (c *context) Request() *http.Request { 209 | return c.r 210 | } 211 | 212 | func (c *context) SetRequest(r *http.Request) { 213 | c.r = r 214 | } 215 | 216 | func (c *context) Response() ResponseWriter { 217 | return c.w 218 | } 219 | 220 | func (c *context) SetResponse(w ResponseWriter) { 221 | c.w = w 222 | } 223 | 224 | func (c *context) RemoteIP() string { 225 | if ip := c.r.Header.Get(HeaderXForwardedFor); ip != "" { 226 | i := strings.IndexAny(ip, ", ") 227 | if i > 0 { 228 | return ip[:i] 229 | } 230 | } 231 | if ip := c.r.Header.Get(HeaderXRealIP); ip != "" { 232 | return ip 233 | } 234 | ip, _, _ := net.SplitHostPort(c.r.RemoteAddr) 235 | return ip 236 | } 237 | 238 | func (c *context) HTML(code int, html string) (err error) { 239 | return c.HTMLBlob(code, []byte(html)) 240 | } 241 | 242 | func (c *context) HTMLTpl(code int, tml string) (err error) { 243 | s, e := os.ReadFile(tml) 244 | if e != nil { 245 | return e 246 | } 247 | return c.HTMLBlob(code, s) 248 | } 249 | 250 | func (c *context) HTMLBlob(code int, b []byte) (err error) { 251 | return c.Blob(code, MIMETextHTMLCharsetUTF8, b) 252 | } 253 | 254 | func (c *context) Blob(code int, contentType string, b []byte) (err error) { 255 | if !c.writermem.Written() { 256 | c.writeContentType(contentType) 257 | c.w.WriteHeader(code) 258 | if _, err = c.w.Write(b); err != nil { 259 | c.Logger().Error(err.Error()) 260 | } 261 | } 262 | return 263 | } 264 | 265 | func (c *context) JSON(code int, i interface{}) (err error) { 266 
| if !c.writermem.Written() { 267 | enc := json.NewEncoder(c.w) 268 | c.writeContentType(MIMEApplicationJSONCharsetUTF8) 269 | c.w.WriteHeader(code) 270 | return enc.Encode(i) 271 | } 272 | return 273 | } 274 | 275 | func (c *context) ProtoBuf(code int, i proto.Message) (err error) { 276 | if !c.writermem.Written() { 277 | c.writeContentType(MIMEApplicationProtobuf) 278 | c.w.WriteHeader(code) 279 | b, err := proto.Marshal(i) 280 | if err != nil { 281 | return err 282 | } 283 | if _, err = c.w.Write(b); err != nil { 284 | return err 285 | } 286 | } 287 | return 288 | } 289 | 290 | func (c *context) JSONP(code int, fn string, i interface{}) error { 291 | enc := json.NewEncoder(c.w) 292 | c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8) 293 | c.w.WriteHeader(code) 294 | if _, err := c.w.Write([]byte(fn + "(")); err != nil { 295 | return err 296 | } 297 | if err := enc.Encode(i); err != nil { 298 | return err 299 | } 300 | if _, err := c.w.Write([]byte(");")); err != nil { 301 | return err 302 | } 303 | return nil 304 | } 305 | 306 | func (c *context) String(code int, s string) error { 307 | return c.Blob(code, MIMETextPlainCharsetUTF8, StringToBytes(s)) 308 | } 309 | 310 | func (c *context) Status(code int) { 311 | c.w.WriteHeader(code) 312 | } 313 | 314 | func (c *context) SetHeader(key string, value string) { 315 | c.w.Header().Set(key, value) 316 | } 317 | 318 | func (c *context) AddHeader(key string, value string) { 319 | c.w.Header().Add(key, value) 320 | } 321 | 322 | func (c *context) GetHeader(key string) string { 323 | return c.r.Header.Get(key) 324 | } 325 | 326 | func (c *context) Params(name string) string { 327 | for _, i := range *c.params { 328 | if i.Key == name { 329 | return i.Value 330 | } 331 | } 332 | return "" 333 | } 334 | func (c *context) QueryParams() map[string][]string { 335 | return c.r.URL.Query() 336 | } 337 | 338 | func (c *context) QueryParam(name string) string { 339 | if c.queryList == nil { 340 | c.queryList = 
c.r.URL.Query() 341 | } 342 | return c.queryList.Get(name) 343 | } 344 | 345 | func (c *context) QueryString() string { 346 | return c.r.URL.RawQuery 347 | } 348 | 349 | func (c *context) FormValue(name string) string { 350 | return c.r.FormValue(name) 351 | } 352 | 353 | func (c *context) FormParams() (url.Values, error) { 354 | if strings.HasPrefix(c.r.Header.Get(HeaderContentType), MIMEMultipartForm) { 355 | if err := c.r.ParseMultipartForm(defaultMemory); err != nil { 356 | return nil, err 357 | } 358 | } else { 359 | if err := c.r.ParseForm(); err != nil { 360 | return nil, err 361 | } 362 | } 363 | return c.r.Form, nil 364 | } 365 | 366 | func (c *context) FormFile(name string) (*multipart.FileHeader, error) { 367 | _, fd, err := c.r.FormFile(name) 368 | return fd, err 369 | } 370 | 371 | func (c *context) File(file string) error { 372 | fd, err := os.Open(file) 373 | if err != nil { 374 | return err 375 | } 376 | 377 | defer fd.Close() 378 | 379 | f, _ := fd.Stat() 380 | if f.IsDir() { 381 | file = filepath.Join(file, indexPage) 382 | fd, err = os.Open(file) 383 | if err != nil { 384 | return ErrNotFoundHandler 385 | } 386 | defer fd.Close() 387 | if f, err = fd.Stat(); err != nil { 388 | return err 389 | } 390 | } 391 | http.ServeContent(c.Response(), c.Request(), f.Name(), f.ModTime(), fd) 392 | return nil 393 | } 394 | 395 | func (c *context) MultipartForm() (*multipart.Form, error) { 396 | err := c.r.ParseMultipartForm(defaultMemory) 397 | return c.r.MultipartForm, err 398 | } 399 | 400 | func (c *context) Cookie(name string) (*http.Cookie, error) { 401 | return c.r.Cookie(name) 402 | } 403 | 404 | func (c *context) SetCookie(cookie *http.Cookie) { 405 | http.SetCookie(c.w, cookie) 406 | } 407 | 408 | func (c *context) Cookies() []*http.Cookie { 409 | return c.r.Cookies() 410 | } 411 | 412 | func (c *context) RequestURI() string { 413 | return c.r.RequestURI 414 | } 415 | 416 | func (c *context) Scheme() string { 417 | scheme := "http" 418 | if scheme := 
c.r.Header.Get(HeaderXForwardedProto); scheme != "" { 419 | return scheme 420 | } 421 | if scheme := c.r.Header.Get(HeaderXForwardedProtocol); scheme != "" { 422 | return scheme 423 | } 424 | if ssl := c.r.Header.Get(HeaderXForwardedSsl); ssl == "on" { 425 | return "https" 426 | } 427 | if scheme := c.r.Header.Get(HeaderXUrlScheme); scheme != "" { 428 | return scheme 429 | } 430 | return scheme 431 | } 432 | 433 | func (c *context) IsTLS() bool { 434 | return c.r.TLS != nil 435 | } 436 | 437 | func (c *context) IsWebsocket() bool { 438 | if strings.Contains(strings.ToLower(c.r.Header.Get(HeaderConnection)), "upgrade") && 439 | strings.EqualFold(c.r.Header.Get(HeaderUpgrade), "websocket") { 440 | return true 441 | } 442 | return false 443 | } 444 | 445 | func (c *context) Redirect(code int, uri string) error { 446 | if code < 300 || code > 308 { 447 | return ErrInvalidRedirectCode 448 | } 449 | c.w.Header().Set(HeaderLocation, uri) 450 | c.w.WriteHeader(code) 451 | return nil 452 | } 453 | 454 | func (c *context) writeContentType(value string) { 455 | header := c.w.Header() 456 | if header.Get(HeaderContentType) == "" { 457 | header.Set(HeaderContentType, value) 458 | } 459 | } 460 | -------------------------------------------------------------------------------- /context_test.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "log" 6 | "net/http" 7 | "net/http/httptest" 8 | "strings" 9 | "testing" 10 | ) 11 | 12 | var testData = `{"id":1,"name":"Jon Snow"}` 13 | 14 | type res struct { 15 | ID int `json:"id"` 16 | Name string `json:"name"` 17 | } 18 | 19 | func TestContextJSON(t *testing.T) { 20 | y := New() 21 | y.POST("/", func(c Context) (err error) { 22 | t := new(res) 23 | if err = c.Bind(&t); err != nil { 24 | return err 25 | } 26 | return c.JSON(http.StatusOK, t) 27 | }) 28 | req := httptest.NewRequest(http.MethodPost, "/", 
strings.NewReader(testData)) 29 | req.Header.Set("Content-Type", MIMEApplicationJSON) 30 | rec := httptest.NewRecorder() 31 | y.ServeHTTP(rec, req) 32 | assert.Equal(t, testData+"\n", rec.Body.String()) 33 | } 34 | 35 | func TestContextForward(t *testing.T) { 36 | y := New() 37 | y.POST("/", func(c Context) (err error) { 38 | return c.JSON(http.StatusOK, c.RemoteIP()) 39 | }) 40 | req := httptest.NewRequest(http.MethodPost, "/", nil) 41 | req.Header.Set(HeaderXForwardedFor, " ") 42 | rec := httptest.NewRecorder() 43 | y.ServeHTTP(rec, req) 44 | } 45 | 46 | func TestContextString(t *testing.T) { 47 | y := New() 48 | y.POST("/", func(c Context) (err error) { 49 | return c.String(http.StatusOK, "hello") 50 | }) 51 | req := httptest.NewRequest(http.MethodPost, "/", nil) 52 | rec := httptest.NewRecorder() 53 | y.ServeHTTP(rec, req) 54 | assert.Equal(t, "hello", rec.Body.String()) 55 | } 56 | 57 | func crashMiddleware() HandlerFunc { 58 | return func(c Context) (err error) { 59 | c.CrashWithStatus(http.StatusUnauthorized) 60 | return 61 | } 62 | } 63 | func sayMiddleware() HandlerFunc { 64 | return func(c Context) (err error) { 65 | log.Println("say") 66 | return 67 | } 68 | } 69 | 70 | func TestCrash(t *testing.T) { 71 | y := New() 72 | y.Use(crashMiddleware()) 73 | y.Use(sayMiddleware()) 74 | y.GET("/", func(c Context) (err error) { 75 | return c.String(http.StatusOK, "hello") 76 | }) 77 | req := httptest.NewRequest(http.MethodGet, "/", nil) 78 | rec := httptest.NewRecorder() 79 | y.ServeHTTP(rec, req) 80 | assert.Equal(t, http.StatusUnauthorized, rec.Body) 81 | } 82 | 83 | func TestRedirect(t *testing.T) { 84 | y := New() 85 | y.GET("/", func(c Context) (err error) { 86 | return c.Redirect(http.StatusMovedPermanently, "/get") 87 | }) 88 | y.GET("/get", func(c Context) (err error) { 89 | return c.String(http.StatusOK, "hello") 90 | }) 91 | req := httptest.NewRequest(http.MethodGet, "/", nil) 92 | rec := httptest.NewRecorder() 93 | y.ServeHTTP(rec, req) 94 | } 95 | 96 | 
func BenchmarkAllocJSON(b *testing.B) { 97 | y := New() 98 | y.POST("/", func(c Context) (err error) { 99 | tl := new(res) 100 | if err = c.Bind(&tl); err != nil { 101 | 102 | return err 103 | } 104 | return c.JSON(http.StatusOK, tl) 105 | }) 106 | b.ResetTimer() 107 | b.ReportAllocs() 108 | for i := 0; i < b.N; i++ { 109 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(testData)) 110 | req.Header.Set("Content-Type", MIMEApplicationJSON) 111 | rec := httptest.NewRecorder() 112 | y.ServeHTTP(rec, req) 113 | } 114 | } 115 | 116 | func BenchmarkAllocString(b *testing.B) { 117 | y := New() 118 | y.POST("/", func(c Context) (err error) { 119 | return c.String(http.StatusOK, "ok") 120 | }) 121 | b.ResetTimer() 122 | b.ReportAllocs() 123 | for i := 0; i < b.N; i++ { 124 | req := httptest.NewRequest(http.MethodPost, "/", nil) 125 | rec := httptest.NewRecorder() 126 | y.ServeHTTP(rec, req) 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /fs.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | import ( 4 | "net/http" 5 | "os" 6 | ) 7 | 8 | type onlyFilesFS struct { 9 | fs http.FileSystem 10 | } 11 | 12 | type neuteredReaddirFile struct { 13 | http.File 14 | } 15 | 16 | // Dir returns a http.FileSystem that can be used by http.FileServer(). It is used internally 17 | // in router.Static(). 18 | // if listDirectory == true, then it works the same as http.Dir() otherwise it returns 19 | // a filesystem that prevents http.FileServer() to list the directory files. 20 | func Dir(root string, listDirectory bool) http.FileSystem { 21 | fs := http.Dir(root) 22 | if listDirectory { 23 | return fs 24 | } 25 | return &onlyFilesFS{fs} 26 | } 27 | 28 | // Open conforms to http.Filesystem. 
29 | func (fs onlyFilesFS) Open(name string) (http.File, error) { 30 | f, err := fs.fs.Open(name) 31 | if err != nil { 32 | return nil, err 33 | } 34 | return neuteredReaddirFile{f}, nil 35 | } 36 | 37 | // Readdir overrides the http.File default implementation. 38 | func (f neuteredReaddirFile) Readdir(count int) ([]os.FileInfo, error) { 39 | // this disables directory listing 40 | return nil, nil 41 | } 42 | -------------------------------------------------------------------------------- /gen.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | // gen.go used to generate Restful code 9 | 10 | const PREFIX = ` 11 | package ${PACKAGE} 12 | 13 | import ( 14 | "net/http" 15 | "github.com/cookieY/yee" 16 | "github.com/jinzhu/gorm" 17 | ) 18 | 19 | ${TP} 20 | 21 | func Paging(page interface{}, total int) (start int, end int) { 22 | start = i*total - total 23 | end = total 24 | return 25 | } 26 | 27 | func Fetch${PACKAGE}(c yee.Context) (err error) { 28 | u := new(FinderPrefix) 29 | if err = c.Bind(u); err != nil { 30 | c.Logger().Error(err.Error()) 31 | return 32 | } 33 | 34 | var order []${MODAL} 35 | 36 | start, end := lib.Paging(u.Page, ${PAGE}) 37 | 38 | if u.Find.Valve { 39 | model.DB().Model(&${MODAL}{}). 
40 | Scopes( 41 | ${QUERY_EXPR} 42 | ).Count(&pg).Order("id desc").Offset(start).Limit(end).Find(&order) 43 | } else { 44 | model.DB().Model(&model.CoreSqlOrder{}).Count(&pg).Order("id desc").Offset(start).Limit(end).Find(&order) 45 | } 46 | 47 | return c.JSON(http.StatusOK, map[string]interface{}{"data": order, "page": pg}) 48 | } 49 | 50 | ${QUERYFUNC} 51 | 52 | ` 53 | 54 | var QueryExprPrefix = ` 55 | func AccordingTo${EXPR_NAME}(val string) func(db *gorm.DB) *gorm.DB { 56 | return func(db *gorm.DB) *gorm.DB { 57 | return db.Where(${QUERY_EXPR},val) 58 | } 59 | } 60 | ` 61 | 62 | var FinderPrefix = ` 63 | type FinderPrefix struct { 64 | Valve bool // 自行添加tag 65 | ${FINDER_EXPR} 66 | } 67 | ` 68 | 69 | type expr struct { 70 | Name string `json:"name"` 71 | Expr string `json:"expr"` 72 | TP string `json:"tp"` 73 | } 74 | 75 | type GenCodeVal struct { 76 | Flag string `json:"flag"` // 根据哪个字段进行CURD 77 | Package string `json:"package"` // 项目名,根据项目名生成package name 78 | QueryExpr []expr `json:"query_expr"` // 查询条件 79 | Page string `json:"page"` //分页大小 80 | Modal string `json:"modal"` 81 | } 82 | 83 | func GenerateRestfulAPI(GenCodeVal GenCodeVal) string { 84 | empty := strings.Replace(PREFIX, "${PACKAGE}", GenCodeVal.Package, -1) 85 | empty = strings.Replace(empty, "${MODAL}", GenCodeVal.Modal, -1) 86 | empty = strings.Replace(empty, "${PAGE}", GenCodeVal.Page, -1) 87 | f, s,l := GenQueryExpr(GenCodeVal.QueryExpr) 88 | empty = strings.Replace(empty, "${QUERYFUNC}", f, -1) 89 | empty = strings.Replace(empty, "${TP}", s, -1) 90 | empty = strings.Replace(empty, "${QUERY_EXPR}", l, -1) 91 | return empty 92 | } 93 | 94 | func GenQueryExpr(QueryExpr []expr) (string, string, string) { 95 | funcEmpty := "" 96 | structEmpty := "" 97 | exprList := "" 98 | for _, i := range QueryExpr { 99 | tmpText := "" 100 | tmpText = strings.Replace(QueryExprPrefix, "${EXPR_NAME}", i.Name, -1) 101 | tmpText = strings.Replace(tmpText, "${QUERY_EXPR}", i.Expr, -1) 102 | funcEmpty += tmpText + 
"\n" 103 | structEmpty += fmt.Sprintf("%s %s \n ", i.Name, i.TP) 104 | exprList += fmt.Sprintf("AccordingTo%s(u.%s),\n ", i.Name, strings.ToLower(i.Name)) 105 | } 106 | structEmpty = strings.Replace(FinderPrefix, "${FINDER_EXPR}", structEmpty, -1) 107 | return funcEmpty, structEmpty, exprList 108 | } 109 | -------------------------------------------------------------------------------- /gen_test.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "os" 7 | "strings" 8 | "testing" 9 | ) 10 | 11 | var test = ` 12 | type GenCodeVal struct { 13 | ${ATTRIBUTE} 14 | } 15 | ` 16 | 17 | func jk(t string) { 18 | test = strings.Replace(test, "${ATTRIBUTE}", t, -1) 19 | f, err := os.OpenFile("koala.go", os.O_WRONLY&os.O_CREATE, 0666) 20 | if err != nil { 21 | log.Println(err.Error()) 22 | } 23 | _, err = f.Write([]byte(test)) 24 | if err != nil { 25 | log.Println(err.Error()) 26 | } 27 | f.Close() 28 | } 29 | 30 | func TestNew(t *testing.T) { 31 | c := []map[string]string{{"a": "int"}, {"k": "string"}} 32 | //l := "" 33 | for _, i := range c { 34 | fmt.Println(i) 35 | //fmt.Println(c[j]) 36 | } 37 | } 38 | 39 | var exprCase = []expr{{Name: "Username", Expr: "username =?", TP: "string"}, {Name: "Age", Expr: "age > ?", TP: "int"}} 40 | 41 | func TestGenQueryExpr(t *testing.T) { 42 | 43 | GenQueryExpr(exprCase) 44 | } 45 | 46 | func TestGenerateRestfulAPI(t *testing.T) { 47 | k := GenCodeVal{ 48 | Package: "manage", 49 | QueryExpr: exprCase, 50 | Page: "20", 51 | Modal: "modal.core_account", 52 | } 53 | fmt.Println(GenerateRestfulAPI(k)) 54 | } 55 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/cookieY/yee 2 | 3 | go 1.16 4 | 5 | require ( 6 | github.com/HdrHistogram/hdrhistogram-go v1.1.2 // indirect 7 | github.com/go-playground/validator/v10 
v10.16.0 8 | github.com/golang-jwt/jwt v3.2.2+incompatible 9 | github.com/golang/protobuf v1.5.3 10 | github.com/google/go-cmp v0.5.9 // indirect 11 | github.com/google/uuid v1.1.1 12 | github.com/mattn/go-colorable v0.1.6 13 | github.com/mattn/go-isatty v0.0.14 14 | github.com/opentracing/opentracing-go v1.2.0 15 | github.com/pkg/errors v0.8.1 // indirect 16 | github.com/stretchr/testify v1.8.2 17 | github.com/uber/jaeger-client-go v2.30.0+incompatible 18 | github.com/uber/jaeger-lib v2.4.1+incompatible // indirect 19 | github.com/valyala/fasttemplate v1.1.0 20 | go.uber.org/atomic v1.9.0 // indirect 21 | golang.org/x/net v0.23.0 22 | google.golang.org/protobuf v1.28.0 23 | ) 24 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= 2 | github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= 3 | github.com/HdrHistogram/hdrhistogram-go v1.1.2 h1:5IcZpTvzydCQeHzK4Ef/D5rrSqwxob0t8PQPMybUNFM= 4 | github.com/HdrHistogram/hdrhistogram-go v1.1.2/go.mod h1:yDgFjdqOqDEKOvasDdhWNXYg9BVp4O+o5f6V/ehm6Oo= 5 | github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= 6 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 7 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 8 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 9 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 10 | github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= 11 | github.com/gabriel-vasile/mimetype v1.4.2 
h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= 12 | github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= 13 | github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= 14 | github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= 15 | github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= 16 | github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= 17 | github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= 18 | github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= 19 | github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= 20 | github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE= 21 | github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= 22 | github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= 23 | github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= 24 | github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= 25 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= 26 | github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= 27 | github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= 28 | github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 29 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 30 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= 31 | github.com/google/go-cmp 
v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 32 | github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= 33 | github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 34 | github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= 35 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 36 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 37 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 38 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 39 | github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= 40 | github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= 41 | github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+vvnE= 42 | github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= 43 | github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= 44 | github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= 45 | github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= 46 | github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= 47 | github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= 48 | github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= 49 | github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= 50 | github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= 51 | github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 52 | github.com/pmezard/go-difflib v1.0.0 
h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 53 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 54 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 55 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 56 | github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= 57 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 58 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 59 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 60 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 61 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 62 | github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= 63 | github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 64 | github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o= 65 | github.com/uber/jaeger-client-go v2.30.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= 66 | github.com/uber/jaeger-lib v2.4.1+incompatible h1:td4jdvLcExb4cBISKIpHuGoVXh+dVKhn2Um6rjCsSsg= 67 | github.com/uber/jaeger-lib v2.4.1+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U= 68 | github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= 69 | github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= 70 | github.com/valyala/fasttemplate v1.1.0 h1:RZqt0yGBsps8NGvLSGW804QQqCUYYLsaOjTVHy1Ocw4= 71 | github.com/valyala/fasttemplate v1.1.0/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= 72 | github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= 73 | go.uber.org/atomic 
v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= 74 | go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= 75 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 76 | golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 77 | golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= 78 | golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= 79 | golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= 80 | golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= 81 | golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= 82 | golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= 83 | golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= 84 | golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= 85 | golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= 86 | golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY= 87 | golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= 88 | golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= 89 | golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= 90 | golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= 91 | golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= 92 | golang.org/x/mod 
v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= 93 | golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= 94 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 95 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 96 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= 97 | golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= 98 | golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= 99 | golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= 100 | golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= 101 | golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= 102 | golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= 103 | golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= 104 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 105 | golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 106 | golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 107 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 108 | golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 109 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 110 | golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 111 | golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod 
h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 112 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 113 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 114 | golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 115 | golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 116 | golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 117 | golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 118 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 119 | golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 120 | golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 121 | golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= 122 | golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 123 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 124 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= 125 | golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= 126 | golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= 127 | golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= 128 | golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= 129 | golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= 130 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 131 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 132 | golang.org/x/text v0.3.7/go.mod 
h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 133 | golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= 134 | golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= 135 | golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= 136 | golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= 137 | golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= 138 | golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= 139 | golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 140 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 141 | golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 142 | golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 143 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 144 | golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= 145 | golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= 146 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 147 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 148 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 149 | gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= 150 | gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= 151 | gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= 152 | 
gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= 153 | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= 154 | google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= 155 | google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= 156 | google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= 157 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 158 | gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= 159 | gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 160 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 161 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 162 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 163 | rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= 164 | -------------------------------------------------------------------------------- /h2_test.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "io/ioutil" 7 | "log" 8 | "net" 9 | "net/http" 10 | "testing" 11 | 12 | "golang.org/x/net/http2" 13 | ) 14 | 15 | func TestH2(T *testing.T) { 16 | y := New() 17 | y.GET("/", func(c Context) error { 18 | return c.String(http.StatusOK, "ok") 19 | }) 20 | y.RunH2C(":9999") 21 | } 22 | 23 | func TestH2cClient(t *testing.T) { 24 | client := http.Client{ 25 | Transport: &http2.Transport{ 26 | AllowHTTP: true, 27 | DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { 28 | return net.Dial(network, addr) 29 | }, 30 | }, 31 | } 32 | 33 | resp, err := 
client.Get("http://localhost:9999") 34 | if err != nil { 35 | log.Fatalf("faild request: %s", err) 36 | } 37 | 38 | defer resp.Body.Close() 39 | 40 | body, err := ioutil.ReadAll(resp.Body) 41 | 42 | if err != nil { 43 | log.Fatalf("read response failed: %s", err) 44 | } 45 | fmt.Printf("proto:%s\ncode %d: %s\n", resp.Proto, resp.StatusCode, string(body)) 46 | } 47 | -------------------------------------------------------------------------------- /h3_test.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | //import ( 4 | // "crypto/tls" 5 | // "fmt" 6 | // pb "github.com/cookieY/yee/test" 7 | // "github.com/quic-go/quic-go" 8 | // "github.com/quic-go/quic-go/http3" 9 | // "io/ioutil" 10 | // "net/http" 11 | // "runtime" 12 | // "testing" 13 | //) 14 | // 15 | //const addr = "https://www.henry.com:9999/hello" 16 | // 17 | //func TestH3Server(t *testing.T) { 18 | // y := New() 19 | // y.SetLogLevel(5) 20 | // y.POST("/hello", func(c Context) (err error) { 21 | // u := new(pb.Svr) 22 | // if err := c.Bind(u); err != nil { 23 | // c.Logger().Error(err.Error()) 24 | // return err 25 | // } 26 | // c.Logger().Debugf("svr get client data: %s", u.Project) 27 | // svr := pb.Svr{Cloud: "hi"} 28 | // return c.ProtoBuf(http.StatusOK, &svr) 29 | // }) 30 | // y.RunH3(":9999", "henry.com+4.pem", "henry.com+4-key.pem") 31 | //} 32 | // 33 | //func TestH3SvrIndex(t *testing.T) { 34 | // y := New() 35 | // y.SetLogLevel(5) 36 | // y.GET("/", func(c Context) (err error) { 37 | // 38 | // return c.JSON(http.StatusOK, "hello") 39 | // }) 40 | // y.Run(":445") 41 | //} 42 | // 43 | //func TestH2SvrIndex(t *testing.T) { 44 | // y := New() 45 | // y.SetLogLevel(5) 46 | // y.GET("/", func(c Context) (err error) { 47 | // 48 | // return c.JSON(http.StatusOK, "hello") 49 | // }) 50 | // y.RunTLS(":444", "henry.com+4.pem", "henry.com+4-key.pem") 51 | //} 52 | // 53 | //var cs = http.Client{ 54 | // Transport: &http3.RoundTripper{ 55 | 
// TLSClientConfig: &tls.Config{}, 56 | // QuicConfig: &quic.Config{}, 57 | // }, 58 | //} 59 | // 60 | //func BenchmarkH3SvrIndex(b *testing.B) { 61 | // b.SetBytes(1024 * 1024) 62 | // for i := 0; i < b.N; i++ { 63 | // http.Get("http://127.0.0.1:445/") 64 | // } 65 | //} 66 | // 67 | //func BenchmarkH2SvrIndex(b *testing.B) { 68 | // b.SetBytes(1024 * 1024) 69 | // for i := 0; i < b.N; i++ { 70 | // http.Get("https://127.0.0.1:444/") 71 | // } 72 | //} 73 | // 74 | //func TestRespProto(t *testing.T) { 75 | // //cs := http.Client{ 76 | // // Transport: &http3.RoundTripper{ 77 | // // TLSClientConfig: &tls.Config{ 78 | // // InsecureSkipVerify: true, 79 | // // }, 80 | // // QuicConfig: &quic.Config{}, 81 | // // }, 82 | // //} 83 | // 84 | // b, err := http.Get("https://127.0.0.1:444/") 85 | // if err != nil { 86 | // t.Error(err) 87 | // } 88 | // c, _ := ioutil.ReadAll(b.Body) 89 | // fmt.Println(string(c)) 90 | // fmt.Println(b.Proto) 91 | //} 92 | // 93 | //func TestNewH3Client(t *testing.T) { 94 | // cs := NewH3Client(&CConfig{ 95 | // Addr: addr, 96 | // InsecureSkipVerify: true, 97 | // }) 98 | // rsp := new(pb.Svr) 99 | // cs.Post(&pb.Svr{Project: "henry"}, rsp) 100 | // fmt.Println(rsp.Cloud) 101 | //} 102 | // 103 | //func TestNewProtoc3(t *testing.T) { 104 | // y := New() 105 | // y.SetLogLevel(5) 106 | // y.POST("/hello", func(c Context) (err error) { 107 | // u := new(pb.Svr) 108 | // if err := c.Bind(u); err != nil { 109 | // c.Logger().Error(err.Error()) 110 | // return err 111 | // } 112 | // c.Logger().Debugf("svr get client data: %s", u.Project) 113 | // svr := pb.Svr{Cloud: "hi"} 114 | // return c.ProtoBuf(http.StatusOK, &svr) 115 | // }) 116 | // y.Run(":9999") 117 | //} 118 | // 119 | //func BenchmarkH2vsH3(b *testing.B) { 120 | // runtime.GOMAXPROCS(runtime.NumCPU()) 121 | // b.Run("http3", BenchmarkH3SvrIndex) 122 | // b.Run("http2", BenchmarkH2SvrIndex) 123 | //} 124 | 
-------------------------------------------------------------------------------- /h3client.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | // 4 | //import ( 5 | // "bytes" 6 | // "crypto/tls" 7 | // "github.com/quic-go/quic-go" 8 | // "io/ioutil" 9 | // "net/http" 10 | // 11 | // "github.com/cookieY/yee/logger" 12 | // "github.com/golang/protobuf/proto" 13 | // "github.com/quic-go/quic-go/http3" 14 | //) 15 | // 16 | //type transport struct { 17 | // addr string 18 | // insecureSkipVerify bool 19 | // logger logger.Logger 20 | // tripper *http3.RoundTripper 21 | // c *http.Client 22 | //} 23 | // 24 | //type CConfig struct { 25 | // Addr string 26 | // InsecureSkipVerify bool 27 | //} 28 | // 29 | //func NewH3Client(c *CConfig) *transport { 30 | // tripper := &http3.RoundTripper{ 31 | // TLSClientConfig: &tls.Config{ 32 | // InsecureSkipVerify: true, 33 | // }, 34 | // QuicConfig: &quic.Config{}, 35 | // } 36 | // return &transport{ 37 | // addr: c.Addr, 38 | // insecureSkipVerify: c.InsecureSkipVerify, 39 | // logger: logger.LogCreator(), 40 | // c: &http.Client{ 41 | // Transport: tripper, 42 | // }, 43 | // tripper: tripper, 44 | // } 45 | //} 46 | // 47 | //func (t *transport) Get(url string) (*http.Response, error) { 48 | // resp, err := t.c.Get(url) 49 | // if err != nil { 50 | // return nil, err 51 | // } 52 | // return resp, nil 53 | //} 54 | // 55 | //func (t *transport) Post(payload proto.Message, recv proto.Message) { 56 | // p, err := proto.Marshal(payload) 57 | // if err != nil { 58 | // t.logger.Critical(err.Error()) 59 | // return 60 | // } 61 | // rsp, err := t.c.Post(t.addr, MIMEApplicationProtobuf, bytes.NewReader(p)) 62 | // if err != nil { 63 | // t.logger.Critical(err.Error()) 64 | // return 65 | // } 66 | // b, err := ioutil.ReadAll(rsp.Body) 67 | // if rsp.StatusCode == 200 { 68 | // err = proto.Unmarshal(b, recv) 69 | // if err != nil { 70 | // t.logger.Critical(err.Error()) 71 | 
// } 72 | // return 73 | // } 74 | // t.logger.Error(string(b)) 75 | // defer t.close() 76 | //} 77 | // 78 | //func (t *transport) close() { 79 | // err := t.tripper.Close() 80 | // if err != nil { 81 | // t.logger.Critical(err.Error()) 82 | // } 83 | //} 84 | -------------------------------------------------------------------------------- /img/benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cookieY/yee/c07a8279ce061505162bad9ba33b576302b6fd92/img/benchmark.png -------------------------------------------------------------------------------- /log_test.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | import ( 4 | "github.com/cookieY/yee/logger" 5 | "github.com/stretchr/testify/assert" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | ) 10 | // TestLogger_LogWrite registers a POST handler that calls every Logger method, then checks the handler still echoes the :b path parameter. 11 | func TestLogger_LogWrite(t *testing.T) { 12 | 13 | y := New() 14 | y.SetLogLevel(logger.Warning) 15 | //file, err := os.OpenFile("logrus.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) 16 | //if err != nil { 17 | // t.Log(err) 18 | // return 19 | //} 20 | //y.SetLogOut(file) 21 | y.POST("/hello/k/:b", func(c Context) error { 22 | c.Logger().Critical("critical") 23 | c.Logger().Error("error") 24 | c.Logger().Warn("warn") 25 | c.Logger().Info("info") 26 | c.Logger().Debug("debug") 27 | c.Logger().Criticalf("test:%v", 123) 28 | c.Logger().Errorf("test:%v", 123) 29 | c.Logger().Warnf("test:%v", 123) 30 | c.Logger().Infof("test:%v", 123) 31 | c.Logger().Debugf("test:%v", 123) 32 | return c.String(http.StatusOK, c.Params("b")) 33 | }) 34 | t.Run("http_get", func(t *testing.T) { 35 | req := httptest.NewRequest(http.MethodPost, "/hello/k/henry", nil) 36 | rec := httptest.NewRecorder() 37 | y.ServeHTTP(rec, req) 38 | tx := assert.New(t) 39 | tx.Equal("henry", rec.Body.String()) 40 | //assert.Equal("*", rec.Header().Get(yee.HeaderAccessControlAllowOrigin)) 41 | }) 42 | } 43 | // BenchmarkLogger_LogWrite calls every Logger method (plain and formatted) once per iteration. 44 | func
BenchmarkLogger_LogWrite(b *testing.B) { 45 | l := logger.LogCreator() 46 | b.ReportAllocs() 47 | b.SetBytes(1024 * 1024) // NOTE(review): nominal value; each iteration writes only short messages, so the reported MB/s is not a real throughput 48 | for i := 0; i < b.N; i++ { 49 | l.Critical("critical") 50 | l.Error("error") 51 | l.Warn("warn") 52 | l.Info("info") 53 | l.Debug("debug") 54 | l.Criticalf("test:%v", 123) 55 | l.Errorf("test:%v", 123) 56 | l.Warnf("test:%v", 123) 57 | l.Infof("test:%v", 123) 58 | l.Debugf("test:%v", 123) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /logger/color.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "os" 8 | 9 | "github.com/mattn/go-colorable" 10 | "github.com/mattn/go-isatty" 11 | ) 12 | // inner renders msg, with extra style codes appended, according to the settings of the given Color. 13 | type ( 14 | inner func(interface{}, []string, *Color) string 15 | ) 16 | 17 | // Color styles: ANSI SGR codes that outer joins into a "\x1b[<code>;<styles>m" prefix. 18 | const ( 19 | // Blk Black text style 20 | Blk = "30" 21 | // Rd red text style 22 | Rd = "31" 23 | // Grn green text style 24 | Grn = "32" 25 | // Yel yellow text style 26 | Yel = "33" 27 | // Blu blue text style 28 | Blu = "34" 29 | // Mgn magenta text style 30 | Mgn = "35" 31 | // Cyn cyan text style 32 | Cyn = "36" 33 | // Wht white text style 34 | Wht = "37" 35 | // Gry grey text style 36 | Gry = "90" 37 | 38 | // BlkBg black background style 39 | BlkBg = "40" 40 | // RdBg red background style 41 | RdBg = "41" 42 | // GrnBg green background style 43 | GrnBg = "42" 44 | // YelBg yellow background style 45 | YelBg = "43" 46 | // BluBg blue background style 47 | BluBg = "44" 48 | // MgnBg magenta background style 49 | MgnBg = "45" 50 | // CynBg cyan background style 51 | CynBg = "46" 52 | // WhtBg white background style 53 | WhtBg = "47" 54 | 55 | // R reset emphasis style 56 | R = "0" 57 | // B bold emphasis style 58 | B = "1" 59 | // D dim emphasis style 60 | D = "2" 61 | // I italic emphasis style 62 | I = "3" 63 | // U underline emphasis style 64 | U = "4" 65 | // In inverse emphasis style 66 |
In = "7" 67 | // H hidden emphasis style 68 | H = "8" 69 | // S strikeout emphasis style 70 | S = "9" 71 | ) 72 | // Style renderers built by outer from the codes above, plus global, the package-level Color used by the top-level helper functions. 73 | var ( 74 | black = outer(Blk) 75 | red = outer(Rd) 76 | green = outer(Grn) 77 | yellow = outer(Yel) 78 | blue = outer(Blu) 79 | magenta = outer(Mgn) 80 | cyan = outer(Cyn) 81 | white = outer(Wht) 82 | grey = outer(Gry) 83 | 84 | blackBg = outer(BlkBg) 85 | redBg = outer(RdBg) 86 | greenBg = outer(GrnBg) 87 | yellowBg = outer(YelBg) 88 | blueBg = outer(BluBg) 89 | magentaBg = outer(MgnBg) 90 | cyanBg = outer(CynBg) 91 | whiteBg = outer(WhtBg) 92 | 93 | reset = outer(R) 94 | bold = outer(B) 95 | dim = outer(D) 96 | italic = outer(I) 97 | underline = outer(U) 98 | inverse = outer(In) 99 | hidden = outer(H) 100 | strikeout = outer(S) 101 | 102 | global = New() 103 | ) 104 | // outer returns an inner that wraps msg as "\x1b[" + n + (";" + each style) + "m" + msg + "\x1b[0m", or formats msg with %v alone when the Color is disabled. 105 | func outer(n string) inner { 106 | return func(msg interface{}, styles []string, c *Color) string { 107 | // TODO: Drop fmt to boost performance? 108 | if c.disabled { 109 | return fmt.Sprintf("%v", msg) 110 | } 111 | 112 | b := new(bytes.Buffer) 113 | b.WriteString("\x1b[") 114 | b.WriteString(n) 115 | for _, s := range styles { 116 | b.WriteString(";") 117 | b.WriteString(s) 118 | } 119 | b.WriteString("m") 120 | return fmt.Sprintf("%s%v\x1b[0m", b.String(), msg) 121 | } 122 | } 123 | // Color writes to output and renders styled strings with ANSI escape codes unless disabled is set. 124 | type ( 125 | Color struct { 126 | output io.Writer 127 | disabled bool 128 | } 129 | ) 130 | 131 | // New creates a Color instance writing to colorable.NewColorableStdout(); SetOutput decides whether colors start disabled. 132 | func New() (c *Color) { 133 | c = new(Color) 134 | c.SetOutput(colorable.NewColorableStdout()) 135 | return 136 | } 137 | 138 | // Output returns the output. 139 | func (c *Color) Output() io.Writer { 140 | return c.output 141 | } 142 | 143 | // SetOutput sets the output and disables colors unless w is an *os.File attached to a terminal. It never re-enables colors; call Enable for that. 144 | func (c *Color) SetOutput(w io.Writer) { 145 | c.output = w 146 | if w, ok := w.(*os.File); !ok || !isatty.IsTerminal(w.Fd()) { 147 | c.disabled = true 148 | } 149 | } 150 | 151 | // Disable disables the colors and styles.
152 | func (c *Color) Disable() { 153 | c.disabled = true 154 | } 155 | 156 | // Enable enables the colors and styles. 157 | func (c *Color) Enable() { 158 | c.disabled = false 159 | } 160 | 161 | // Print is analogous to `fmt.Print`, writing to c.output. 162 | func (c *Color) Print(args ...interface{}) { 163 | fmt.Fprint(c.output, args...) 164 | } 165 | 166 | // Println is analogous to `fmt.Println`, writing to c.output. 167 | func (c *Color) Println(args ...interface{}) { 168 | fmt.Fprintln(c.output, args...) 169 | } 170 | 171 | // Printf is analogous to `fmt.Printf`, writing to c.output. 172 | func (c *Color) Printf(format string, args ...interface{}) { 173 | fmt.Fprintf(c.output, format, args...) 174 | } 175 | // Black renders msg with the Blk style plus styles; plain text if c is disabled. 176 | func (c *Color) Black(msg interface{}, styles ...string) string { 177 | return black(msg, styles, c) 178 | } 179 | // Red renders msg with the Rd style plus styles; plain text if c is disabled. 180 | func (c *Color) Red(msg interface{}, styles ...string) string { 181 | return red(msg, styles, c) 182 | } 183 | // Green renders msg with the Grn style plus styles; plain text if c is disabled. 184 | func (c *Color) Green(msg interface{}, styles ...string) string { 185 | return green(msg, styles, c) 186 | } 187 | // Yellow renders msg with the Yel style plus styles; plain text if c is disabled. 188 | func (c *Color) Yellow(msg interface{}, styles ...string) string { 189 | return yellow(msg, styles, c) 190 | } 191 | // Blue renders msg with the Blu style plus styles; plain text if c is disabled. 192 | func (c *Color) Blue(msg interface{}, styles ...string) string { 193 | return blue(msg, styles, c) 194 | } 195 | // Magenta renders msg with the Mgn style plus styles; plain text if c is disabled. 196 | func (c *Color) Magenta(msg interface{}, styles ...string) string { 197 | return magenta(msg, styles, c) 198 | } 199 | // Cyan renders msg with the Cyn style plus styles; plain text if c is disabled. 200 | func (c *Color) Cyan(msg interface{}, styles ...string) string { 201 | return cyan(msg, styles, c) 202 | } 203 | // White renders msg with the Wht style plus styles; plain text if c is disabled. 204 | func (c *Color) White(msg interface{}, styles ...string) string { 205 | return white(msg, styles, c) 206 | } 207 | // Grey renders msg with the Gry style plus styles; plain text if c is disabled. 208 | func (c *Color) Grey(msg interface{}, styles ...string) string { 209 | return grey(msg, styles, c) 210 | } 211 | // BlackBg renders msg with the BlkBg style plus styles; plain text if c is disabled. 212 | func (c *Color) BlackBg(msg interface{}, styles ...string) string { 213 | return blackBg(msg, styles, c) 214 | } 215 | // RedBg renders msg with the RdBg style plus styles; plain text if c is disabled. 216 | func (c *Color) RedBg(msg interface{}, styles
...string) string { 217 | return redBg(msg, styles, c) 218 | } 219 | // GreenBg renders msg with the GrnBg style plus styles; plain text if c is disabled. 220 | func (c *Color) GreenBg(msg interface{}, styles ...string) string { 221 | return greenBg(msg, styles, c) 222 | } 223 | // YellowBg renders msg with the YelBg style plus styles; plain text if c is disabled. 224 | func (c *Color) YellowBg(msg interface{}, styles ...string) string { 225 | return yellowBg(msg, styles, c) 226 | } 227 | // BlueBg renders msg with the BluBg style plus styles; plain text if c is disabled. 228 | func (c *Color) BlueBg(msg interface{}, styles ...string) string { 229 | return blueBg(msg, styles, c) 230 | } 231 | // MagentaBg renders msg with the MgnBg style plus styles; plain text if c is disabled. 232 | func (c *Color) MagentaBg(msg interface{}, styles ...string) string { 233 | return magentaBg(msg, styles, c) 234 | } 235 | // CyanBg renders msg with the CynBg style plus styles; plain text if c is disabled. 236 | func (c *Color) CyanBg(msg interface{}, styles ...string) string { 237 | return cyanBg(msg, styles, c) 238 | } 239 | // WhiteBg renders msg with the WhtBg style plus styles; plain text if c is disabled. 240 | func (c *Color) WhiteBg(msg interface{}, styles ...string) string { 241 | return whiteBg(msg, styles, c) 242 | } 243 | // Reset renders msg with the R (reset) style plus styles; plain text if c is disabled. 244 | func (c *Color) Reset(msg interface{}, styles ...string) string { 245 | return reset(msg, styles, c) 246 | } 247 | // Bold renders msg with the B (bold) style plus styles; plain text if c is disabled. 248 | func (c *Color) Bold(msg interface{}, styles ...string) string { 249 | return bold(msg, styles, c) 250 | } 251 | // Dim renders msg with the D (dim) style plus styles; plain text if c is disabled. 252 | func (c *Color) Dim(msg interface{}, styles ...string) string { 253 | return dim(msg, styles, c) 254 | } 255 | // Italic renders msg with the I (italic) style plus styles; plain text if c is disabled. 256 | func (c *Color) Italic(msg interface{}, styles ...string) string { 257 | return italic(msg, styles, c) 258 | } 259 | // Underline renders msg with the U (underline) style plus styles; plain text if c is disabled. 260 | func (c *Color) Underline(msg interface{}, styles ...string) string { 261 | return underline(msg, styles, c) 262 | } 263 | // Inverse renders msg with the In (inverse) style plus styles; plain text if c is disabled. 264 | func (c *Color) Inverse(msg interface{}, styles ...string) string { 265 | return inverse(msg, styles, c) 266 | } 267 | // Hidden renders msg with the H (hidden) style plus styles; plain text if c is disabled. 268 | func (c *Color) Hidden(msg interface{}, styles ...string) string { 269 | return hidden(msg, styles, c) 270 | } 271 | // Strikeout renders msg with the S (strikeout) style plus styles; plain text if c is disabled. 272 | func (c *Color) Strikeout(msg interface{}, styles ...string) string { 273 | return strikeout(msg, styles, c) 274 | } 275 | 276 | // Output returns the output of the package-level Color. 277 | func Output() io.Writer { 278 | return global.output 279 | } 280 | 281 | // SetOutput sets the output of the package-level Color (see Color.SetOutput for how it affects colors).
282 | func SetOutput(w io.Writer) { 283 | global.SetOutput(w) 284 | } 285 | 286 | func Disable() { 287 | global.Disable() 288 | } 289 | 290 | func Enable() { 291 | global.Enable() 292 | } 293 | 294 | // Print is analogous to `fmt.Print` with terminal detection; it forwards to global.Print. 295 | func Print(args ...interface{}) { 296 | global.Print(args...) 297 | } 298 | 299 | // Println is analogous to `fmt.Println` with terminal detection; it forwards to global.Println. 300 | func Println(args ...interface{}) { 301 | global.Println(args...) 302 | } 303 | 304 | // Printf is analogous to `fmt.Printf` with terminal detection; it forwards to global.Printf. 305 | func Printf(format string, args ...interface{}) { 306 | global.Printf(format, args...) 307 | } 308 | 309 | func Black(msg interface{}, styles ...string) string { 310 | return global.Black(msg, styles...) 311 | } 312 | 313 | func Red(msg interface{}, styles ...string) string { 314 | return global.Red(msg, styles...) 315 | } 316 | 317 | func Green(msg interface{}, styles ...string) string { 318 | return global.Green(msg, styles...) 319 | } 320 | 321 | func Yellow(msg interface{}, styles ...string) string { 322 | return global.Yellow(msg, styles...) 323 | } 324 | 325 | func Blue(msg interface{}, styles ...string) string { 326 | return global.Blue(msg, styles...) 327 | } 328 | 329 | func Magenta(msg interface{}, styles ...string) string { 330 | return global.Magenta(msg, styles...) 331 | } 332 | 333 | func Cyan(msg interface{}, styles ...string) string { 334 | return global.Cyan(msg, styles...) 335 | } 336 | 337 | func White(msg interface{}, styles ...string) string { 338 | return global.White(msg, styles...) 339 | } 340 | 341 | func Grey(msg interface{}, styles ...string) string { 342 | return global.Grey(msg, styles...) 343 | } 344 | 345 | func BlackBg(msg interface{}, styles ...string) string { 346 | return global.BlackBg(msg, styles...) 347 | } 348 | 349 | func RedBg(msg interface{}, styles ...string) string { 350 | return global.RedBg(msg, styles...)
351 | } 352 | 353 | func GreenBg(msg interface{}, styles ...string) string { 354 | return global.GreenBg(msg, styles...) 355 | } 356 | 357 | func YellowBg(msg interface{}, styles ...string) string { 358 | return global.YellowBg(msg, styles...) 359 | } 360 | 361 | func BlueBg(msg interface{}, styles ...string) string { 362 | return global.BlueBg(msg, styles...) 363 | } 364 | 365 | func MagentaBg(msg interface{}, styles ...string) string { 366 | return global.MagentaBg(msg, styles...) 367 | } 368 | 369 | func CyanBg(msg interface{}, styles ...string) string { 370 | return global.CyanBg(msg, styles...) 371 | } 372 | 373 | func WhiteBg(msg interface{}, styles ...string) string { 374 | return global.WhiteBg(msg, styles...) 375 | } 376 | 377 | func Reset(msg interface{}, styles ...string) string { 378 | return global.Reset(msg, styles...) 379 | } 380 | 381 | func Bold(msg interface{}, styles ...string) string { 382 | return global.Bold(msg, styles...) 383 | } 384 | 385 | func Dim(msg interface{}, styles ...string) string { 386 | return global.Dim(msg, styles...) 387 | } 388 | 389 | func Italic(msg interface{}, styles ...string) string { 390 | return global.Italic(msg, styles...) 391 | } 392 | 393 | func Underline(msg interface{}, styles ...string) string { 394 | return global.Underline(msg, styles...) 395 | } 396 | 397 | func Inverse(msg interface{}, styles ...string) string { 398 | return global.Inverse(msg, styles...) 399 | } 400 | 401 | func Hidden(msg interface{}, styles ...string) string { 402 | return global.Hidden(msg, styles...) 403 | } 404 | 405 | func Strikeout(msg interface{}, styles ...string) string { 406 | return global.Strikeout(msg, styles...)
407 | } -------------------------------------------------------------------------------- /logger/color_test.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | ) 7 | 8 | func TestColor_Black(t *testing.T) { 9 | c := New() 10 | c.Enable() 11 | _, _ = os.Stdout.Write(append([]byte(c.Blue("bule")), '\n')) 12 | } 13 | -------------------------------------------------------------------------------- /logger/log.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | "runtime" 8 | "strings" 9 | "sync" 10 | "time" 11 | ) 12 | 13 | // Log levels; a lower value is more severe (Critical = 0 through Debug = 4). 14 | const ( 15 | Critical = iota 16 | Error 17 | Warning 18 | Info 19 | Debug 20 | ) 21 | 22 | const timeFormat = "2006-01-02 15:04:05" 23 | 24 | type logger struct { 25 | sync.Mutex 26 | level uint8 27 | isLogger bool 28 | version string 29 | producer *Color 30 | out io.Writer 31 | noColor bool 32 | } 33 | 34 | // Logger is the leveled logging interface; LogCreator returns an implementation backed by *logger. 35 | type Logger interface { 36 | Critical(msg interface{}) 37 | Error(msg interface{}) 38 | Warn(msg interface{}) 39 | Info(msg interface{}) 40 | Debug(msg interface{}) 41 | Criticalf(error string, msg ...interface{}) 42 | Errorf(error string, msg ...interface{}) 43 | Warnf(error string, msg ...interface{}) 44 | Infof(error string, msg ...interface{}) 45 | Debugf(error string, msg ...interface{}) 46 | Custom(msg string) 47 | SetLevel(level uint8) 48 | SetOut(out io.Writer) 49 | IsLogger(isOk bool) 50 | } 51 | 52 | // LogCreator returns a Logger that writes to os.Stdout with color enabled, at level 1 (Error) or at uint8(args[0]) when an argument is given.
53 | func LogCreator(args ...int) Logger { 54 | l := new(logger) 55 | l.producer = New() 56 | l.producer.Enable() 57 | l.level = 1 58 | if len(args) > 0 { 59 | l.level = uint8(args[0]) 60 | } 61 | l.out = os.Stdout 62 | return l 63 | } 64 | 65 | var DefaultLogger = LogCreator() 66 | 67 | func (l *logger) SetOut(out io.Writer) { 68 | l.Lock() 69 | defer l.Unlock() 70 | l.out = out 71 | l.noColor = true 72 | } 73 | 74 | func (l *logger) SetLevel(level uint8) { 75 | l.Lock() 76 | defer l.Unlock() 77 | l.level = level 78 | } 79 | 80 | func (l *logger) IsLogger(p bool) { 81 | l.Lock() 82 | defer l.Unlock() 83 | l.isLogger = p 84 | } 85 | 86 | var mappingLevel = map[uint8]string{ 87 | Critical: "Critical", 88 | Error: "Error", 89 | Warning: "Warn", 90 | Info: "Info", 91 | Debug: "Debug", 92 | } 93 | 94 | func (l *logger) logWrite(msg interface{}, level uint8) (string, bool) { 95 | var msgText string 96 | switch v := msg.(type) { 97 | case error: 98 | msgText = v.Error() 99 | case string: 100 | msgText = v 101 | } 102 | 103 | if level > l.level && !l.isLogger { 104 | return "", false 105 | } 106 | 107 | if !l.isLogger { 108 | _, file, lineno, ok := runtime.Caller(2) 109 | 110 | src := "" 111 | 112 | if ok { 113 | src = strings.Replace( 114 | fmt.Sprintf("%s:%d", file, lineno), "%2e", ".", -1) 115 | } 116 | msgText = fmt.Sprintf("%s [%s] %s (%s) %s", l.version, mappingLevel[level], time.Now().Format(timeFormat), src, msgText) 117 | } else { 118 | msgText = fmt.Sprintf("%s [%s] %s %s", l.version, mappingLevel[level], time.Now().Format(timeFormat), msgText) 119 | } 120 | 121 | return msgText, true 122 | } 123 | 124 | func (l *logger) print(msg string) { 125 | l.Lock() 126 | defer l.Unlock() 127 | _, err := l.out.Write(append([]byte(msg), '\n')) 128 | if err != nil { 129 | _, _ = os.Stdout.Write(append([]byte(msg), '\n')) 130 | } 131 | } 132 | 133 | func (l *logger) Custom(msg string) { 134 | l.print(msg) 135 | } 136 | 137 | func (l *logger) Critical(msg interface{}) { 138 | 
if msg, ok := l.logWrite(msg, Critical); ok { 139 | l.dyer(Critical, &msg) 140 | } 141 | } 142 | 143 | func (l *logger) Criticalf(error string, msg ...interface{}) { 144 | if msg, ok := l.logWrite(fmt.Sprintf(error, msg...), Critical); ok { 145 | l.dyer(Critical, &msg) 146 | } 147 | } 148 | 149 | func (l *logger) Error(msg interface{}) { 150 | if msg, ok := l.logWrite(msg, Error); ok { 151 | l.dyer(Error, &msg) 152 | } 153 | } 154 | 155 | func (l *logger) Errorf(error string, msg ...interface{}) { 156 | if msg, ok := l.logWrite(fmt.Sprintf(error, msg...), Error); ok { 157 | l.dyer(Error, &msg) 158 | } 159 | } 160 | 161 | func (l *logger) Warn(msg interface{}) { 162 | if msg, ok := l.logWrite(msg, Warning); ok { 163 | l.dyer(Warning, &msg) 164 | } 165 | } 166 | 167 | func (l *logger) Warnf(error string, msg ...interface{}) { 168 | if msg, ok := l.logWrite(fmt.Sprintf(error, msg...), Warning); ok { 169 | l.dyer(Warning, &msg) 170 | } 171 | } 172 | 173 | func (l *logger) Info(msg interface{}) { 174 | if msg, ok := l.logWrite(msg, Info); ok { 175 | l.dyer(Info, &msg) 176 | } 177 | } 178 | 179 | func (l *logger) Infof(error string, msg ...interface{}) { 180 | if msg, ok := l.logWrite(fmt.Sprintf(error, msg...), Info); ok { 181 | l.dyer(Info, &msg) 182 | } 183 | } 184 | 185 | func (l *logger) Debugf(error string, msg ...interface{}) { 186 | if msg, ok := l.logWrite(fmt.Sprintf(error, msg...), Debug); ok { 187 | l.dyer(Debug, &msg) 188 | } 189 | } 190 | 191 | func (l *logger) Debug(msg interface{}) { 192 | if msg, ok := l.logWrite(msg, Debug); ok { 193 | l.dyer(Debug, &msg) 194 | } 195 | } 196 | 197 | func (l *logger) dyer(level int, msg *string) { 198 | if l.noColor { 199 | l.print(*msg) 200 | return 201 | } 202 | switch level { 203 | case Critical: 204 | l.print(l.producer.Red(*msg)) 205 | case Error: 206 | l.print(l.producer.Magenta(*msg)) 207 | case Warning: 208 | l.print(l.producer.Yellow(*msg)) 209 | case Info: 210 | l.print(l.producer.Blue(*msg)) 211 | case Debug: 
212 | l.print(l.producer.Cyan(*msg)) 213 | } 214 | } 215 | -------------------------------------------------------------------------------- /logrus.log: -------------------------------------------------------------------------------- 1 | [Critical] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:23) critical 2 | [Error] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:24) error 3 | [Warn] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:25) warn 4 | [Info] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:26) info 5 | [Debug] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:27) debug 6 | [Critical] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:28) test:123 7 | [Error] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:29) test:123 8 | [Warn] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:30) test:123 9 | [Info] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:31) test:123 10 | [Debug] 2021-11-10 10:32:38 (/Users/henryyee/Yee/log_test.go:32) test:123 -------------------------------------------------------------------------------- /middleware/basicAuth.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "encoding/base64" 5 | "errors" 6 | "net/http" 7 | "strings" 8 | 9 | "github.com/cookieY/yee" 10 | ) 11 | 12 | // BasicAuthConfig defines the config of basicAuth middleware 13 | type BasicAuthConfig struct { 14 | Validator fnValidator 15 | Realm string 16 | } 17 | 18 | type fnValidator func([]byte) (bool, error) 19 | 20 | const ( 21 | basic = "basic" 22 | ) 23 | 24 | // BasicAuth is the default implementation BasicAuth middleware 25 | func BasicAuth(fn fnValidator) yee.HandlerFunc { 26 | config := BasicAuthConfig{Validator: fn} 27 | config.Realm = "." 
28 | return BasicAuthWithConfig(config) 29 | } 30 | 31 | // BasicAuthWithConfig is the custom implementation BasicAuth middleware 32 | func BasicAuthWithConfig(config BasicAuthConfig) yee.HandlerFunc { 33 | 34 | if config.Validator == nil { 35 | panic("yee: basic-auth middleware requires a validator function") 36 | } 37 | 38 | return func(context yee.Context) (err error) { 39 | decode, _ := parserVerifyData(context) 40 | if verify, err := config.Validator(decode); err == nil && verify { 41 | return err 42 | } 43 | 44 | context.Response().Header().Set(yee.HeaderWWWAuthenticate, basic+" realm="+config.Realm) 45 | 46 | return context.ServerError(http.StatusUnauthorized, "invalid basic auth token") 47 | } 48 | } 49 | 50 | func parserVerifyData(context yee.Context) ([]byte, error) { 51 | var decode []byte 52 | res := context.Request() 53 | if res.Header.Get(yee.HeaderAuthorization) != "" { 54 | auth := strings.Split(res.Header.Get(yee.HeaderAuthorization), " ") 55 | if strings.ToLower(auth[0]) == basic { 56 | decode, err := base64.StdEncoding.DecodeString(auth[1]) 57 | if err != nil { 58 | return decode, err 59 | } 60 | return decode, nil 61 | } 62 | return decode, errors.New("cannot get basic keyword") 63 | } 64 | return decode, errors.New("authorization header is empty") 65 | } 66 | -------------------------------------------------------------------------------- /middleware/basicAuth_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "encoding/base64" 5 | "encoding/json" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | 10 | "github.com/cookieY/yee" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | type user struct { 15 | Username string `json:"username"` 16 | Password string `json:"password"` 17 | } 18 | 19 | func validator(auth []byte) (bool, error) { 20 | var u user 21 | if err := json.Unmarshal(auth, &u); err != nil { 22 | return false, err 23 | } 24 | if u.Username 
== "test" && u.Password == "123123" { 25 | return true, nil 26 | } 27 | return false, nil 28 | 29 | } 30 | 31 | var testUser = map[string]string{"username": "test", "password": "123123"} 32 | 33 | func TestBasicAuth(t *testing.T) { 34 | y := yee.New() 35 | y.Use(BasicAuth(validator)) 36 | y.GET("/", func(context yee.Context) error { 37 | return context.String(http.StatusOK, "ok") 38 | }) 39 | 40 | req := httptest.NewRequest(http.MethodGet, "/", nil) 41 | rec := httptest.NewRecorder() 42 | 43 | y.ServeHTTP(rec, req) 44 | 45 | assert.Equal(t, http.StatusUnauthorized, rec.Code) 46 | 47 | req = httptest.NewRequest(http.MethodGet, "/", nil) 48 | rec = httptest.NewRecorder() 49 | u, _ := json.Marshal(testUser) 50 | encodeString := base64.StdEncoding.EncodeToString(u) 51 | req.Header.Set(yee.HeaderAuthorization, "basic "+encodeString) 52 | y.ServeHTTP(rec, req) 53 | assert.Equal(t, http.StatusOK, rec.Code) 54 | 55 | } 56 | -------------------------------------------------------------------------------- /middleware/cors.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "strconv" 6 | "strings" 7 | 8 | "github.com/cookieY/yee" 9 | ) 10 | 11 | // CORSConfig defined the config of CORS middleware 12 | type CORSConfig struct { 13 | Origins []string 14 | AllowMethods []string 15 | AllowHeaders []string 16 | AllowCredentials bool 17 | ExposeHeaders []string 18 | MaxAge int 19 | } 20 | 21 | // DefaultCORSConfig is the default config of CORS middleware 22 | var DefaultCORSConfig = CORSConfig{ 23 | Origins: []string{"*"}, 24 | AllowMethods: []string{ 25 | http.MethodGet, 26 | http.MethodPut, 27 | http.MethodPost, 28 | http.MethodDelete, 29 | http.MethodPatch, 30 | http.MethodHead, 31 | http.MethodOptions, 32 | http.MethodConnect, 33 | http.MethodTrace, 34 | }, 35 | } 36 | 37 | // Cors is the default implementation CORS middleware 38 | func Cors() yee.HandlerFunc { 39 | return 
CorsWithConfig(DefaultCORSConfig) 40 | } 41 | 42 | // CorsWithConfig is the default implementation CORS middleware 43 | func CorsWithConfig(config CORSConfig) yee.HandlerFunc { 44 | 45 | if len(config.Origins) == 0 { 46 | config.Origins = DefaultCORSConfig.Origins 47 | } 48 | 49 | if len(config.AllowMethods) == 0 { 50 | config.AllowMethods = DefaultCORSConfig.AllowMethods 51 | } 52 | 53 | allowMethods := strings.Join(config.AllowMethods, ",") 54 | 55 | allowHeaders := strings.Join(config.AllowHeaders, ",") 56 | 57 | exposeHeaders := strings.Join(config.ExposeHeaders, ",") 58 | 59 | maxAge := strconv.Itoa(config.MaxAge) 60 | 61 | return func(c yee.Context) (err error) { 62 | 63 | localOrigin := c.GetHeader(yee.HeaderOrigin) 64 | 65 | allowOrigin := "" 66 | 67 | m := c.Request().Method 68 | 69 | for _, o := range config.Origins { 70 | if o == "*" && config.AllowCredentials { 71 | allowOrigin = localOrigin 72 | break 73 | } 74 | if o == "*" || o == localOrigin { 75 | allowOrigin = o 76 | break 77 | } 78 | } 79 | 80 | // when method was not OPTIONS, 81 | // we can return simple response header 82 | // because the OPTIONS method is used to 83 | // describe the communication options for the target resource 84 | // https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS 85 | 86 | if m != http.MethodOptions { 87 | c.AddHeader(yee.HeaderVary, yee.HeaderOrigin) 88 | c.SetHeader(yee.HeaderAccessControlAllowOrigin, allowOrigin) 89 | if config.AllowCredentials { 90 | c.SetHeader(yee.HeaderAccessControlAllowCredentials, "true") 91 | } 92 | if exposeHeaders != "" { 93 | c.SetHeader(yee.HeaderAccessControlExposeHeaders, exposeHeaders) 94 | } 95 | c.Next() 96 | return 97 | } 98 | 99 | c.AddHeader(yee.HeaderVary, yee.HeaderOrigin) 100 | c.AddHeader(yee.HeaderVary, yee.HeaderAccessControlRequestMethod) 101 | c.AddHeader(yee.HeaderVary, yee.HeaderAccessControlRequestHeaders) 102 | c.SetHeader(yee.HeaderAccessControlAllowOrigin, allowOrigin) 103 | 
c.SetHeader(yee.HeaderAccessControlAllowMethods, allowMethods) 104 | if config.AllowCredentials { 105 | c.SetHeader(yee.HeaderAccessControlAllowCredentials, "true") 106 | } 107 | if allowHeaders != "" { 108 | c.SetHeader(yee.HeaderAccessControlAllowHeaders, allowHeaders) 109 | } else { 110 | h := c.GetHeader(yee.HeaderAccessControlRequestHeaders) 111 | if h != "" { 112 | c.SetHeader(yee.HeaderAccessControlAllowHeaders, h) 113 | } 114 | } 115 | if config.MaxAge > 0 { 116 | c.SetHeader(yee.HeaderAccessControlMaxAge, maxAge) 117 | } 118 | return 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /middleware/cors_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "github.com/cookieY/yee" 5 | "github.com/stretchr/testify/assert" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | ) 10 | 11 | func TestCors(t *testing.T) { 12 | y := yee.New() 13 | y.Use(Cors()) 14 | 15 | y.POST("/login", func(c yee.Context) error { 16 | return c.String(http.StatusOK, "test") 17 | }) 18 | 19 | y.OPTIONS("/ok", func(context yee.Context) (err error) { 20 | return err 21 | }) 22 | 23 | t.Run("http_get", func(t *testing.T) { 24 | req := httptest.NewRequest(http.MethodGet, "/ok", nil) 25 | rec := httptest.NewRecorder() 26 | y.ServeHTTP(rec, req) 27 | assert := assert.New(t) 28 | assert.Equal("test", rec.Body.String()) 29 | assert.Equal("*", rec.Header().Get(yee.HeaderAccessControlAllowOrigin)) 30 | }) 31 | 32 | t.Run("http_option", func(t *testing.T) { 33 | req := httptest.NewRequest(http.MethodOptions, "/ok", nil) 34 | rec := httptest.NewRecorder() 35 | y.ServeHTTP(rec, req) 36 | assert := assert.New(t) 37 | assert.Equal(http.MethodGet, rec.Header().Get(yee.HeaderAccessControlAllowMethods)) 38 | assert.Equal("Test", rec.Header().Get(yee.HeaderAccessControlAllowHeaders)) 39 | }) 40 | } 41 | 42 | func TestEncryptServer(t *testing.T) { 43 | e := yee.New() 44 
| e.Use(Cors()) 45 | e.POST("/encrypt", func(c yee.Context) (err error) { 46 | u := new(user) 47 | if err := c.Bind(u); err != nil { 48 | return err 49 | } 50 | return c.JSON(http.StatusOK, u) 51 | }) 52 | e.Run(":9000") 53 | } 54 | -------------------------------------------------------------------------------- /middleware/csrf.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "crypto/subtle" 5 | "errors" 6 | "net/http" 7 | "strings" 8 | "time" 9 | 10 | "github.com/cookieY/yee" 11 | "github.com/google/uuid" 12 | ) 13 | 14 | // CSRFConfig defines the config of CSRF middleware 15 | type CSRFConfig struct { 16 | TokenLength uint8 17 | TokenLookup string 18 | Key string 19 | CookieName string 20 | CookieDomain string 21 | CookiePath string 22 | CookieMaxAge int 23 | CookieSecure bool 24 | CookieHTTPOnly bool 25 | } 26 | 27 | type csrfTokenCreator func(yee.Context) (string, error) 28 | 29 | // CSRFDefaultConfig is the default config of CSRF middleware 30 | var CSRFDefaultConfig = CSRFConfig{ 31 | TokenLength: 16, 32 | TokenLookup: "header:" + yee.HeaderXCSRFToken, 33 | Key: "csrf", 34 | CookieName: "_csrf", 35 | CookieMaxAge: 28800, 36 | } 37 | 38 | // CSRF is the default implementation CSRF middleware 39 | func CSRF() yee.HandlerFunc { 40 | return CSRFWithConfig(CSRFDefaultConfig) 41 | } 42 | 43 | // CSRFWithConfig is the custom implementation CSRF middleware 44 | func CSRFWithConfig(config CSRFConfig) yee.HandlerFunc { 45 | 46 | if config.TokenLength == 0 { 47 | config.TokenLength = CSRFDefaultConfig.TokenLength 48 | } 49 | 50 | if config.TokenLookup == "" { 51 | config.TokenLookup = CSRFDefaultConfig.TokenLookup 52 | } 53 | 54 | if config.Key == "" { 55 | config.Key = CSRFDefaultConfig.Key 56 | } 57 | 58 | if config.CookieName == "" { 59 | config.CookieName = CSRFDefaultConfig.CookieName 60 | } 61 | 62 | if config.CookieMaxAge == 0 { 63 | config.CookieMaxAge = 
CSRFDefaultConfig.CookieMaxAge 64 | } 65 | 66 | proc := strings.Split(config.TokenLookup, ":") 67 | 68 | creator := csrfTokenFromHeader(proc[1]) 69 | 70 | switch proc[0] { 71 | case "query": 72 | creator = csrfTokenFromQuery(proc[1]) 73 | case "form": 74 | creator = csrfTokenFromForm(proc[1]) 75 | } 76 | 77 | return func(context yee.Context) (err error) { 78 | 79 | // we fetch cookie from this request 80 | // if cookie haven`t token info 81 | // we need generate the token and create a new cookie 82 | // otherwise reuse token 83 | 84 | k, err := context.Cookie(config.CookieName) 85 | token := "" 86 | if err != nil { 87 | token = strings.Replace(uuid.New().String(), "-", "", -1) 88 | } else { 89 | token = k.Value 90 | } 91 | 92 | switch context.Request().Method { 93 | case http.MethodGet, http.MethodTrace, http.MethodOptions, http.MethodHead: 94 | default: 95 | clientToken, e := creator(context) 96 | 97 | if e != nil { 98 | return context.ServerError(http.StatusBadRequest, e.Error()) 99 | } 100 | if !validateCSRFToken(token, clientToken) { 101 | return context.ServerError(http.StatusForbidden, "invalid csrf token") 102 | } 103 | } 104 | 105 | nCookie := new(http.Cookie) 106 | nCookie.Name = config.CookieName 107 | nCookie.Value = token 108 | if config.CookiePath != "" { 109 | nCookie.Path = config.CookiePath 110 | } 111 | if config.CookieDomain != "" { 112 | nCookie.Domain = config.CookieDomain 113 | } 114 | nCookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second) 115 | nCookie.Secure = config.CookieSecure 116 | nCookie.HttpOnly = config.CookieHTTPOnly 117 | context.SetCookie(nCookie) 118 | 119 | context.Put(config.Key, token) 120 | context.SetHeader(yee.HeaderVary, yee.HeaderCookie) 121 | context.Next() 122 | return 123 | } 124 | } 125 | 126 | func csrfTokenFromHeader(header string) csrfTokenCreator { 127 | return func(context yee.Context) (string, error) { 128 | token := context.GetHeader(header) 129 | if token == "" { 130 | return "", 
errors.New("missing csrf token in the header string") 131 | } 132 | return token, nil 133 | } 134 | } 135 | 136 | func csrfTokenFromQuery(param string) csrfTokenCreator { 137 | return func(context yee.Context) (string, error) { 138 | token := context.QueryParam(param) 139 | if token == "" { 140 | return "", errors.New("missing csrf token in the query string") 141 | } 142 | return token, nil 143 | } 144 | } 145 | 146 | func csrfTokenFromForm(param string) csrfTokenCreator { 147 | return func(context yee.Context) (string, error) { 148 | token := context.FormValue(param) 149 | if token == "" { 150 | return "", errors.New("missing csrf token in the form string") 151 | } 152 | return token, nil 153 | } 154 | } 155 | func validateCSRFToken(token, clientToken string) bool { 156 | return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1 157 | } 158 | -------------------------------------------------------------------------------- /middleware/csrf_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/cookieY/yee" 10 | "github.com/google/uuid" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestCSRFWithConfig(t *testing.T) { 15 | 16 | y := yee.New() 17 | y.Use(CSRF()) 18 | y.POST("/", func(context yee.Context) error { 19 | return context.String(http.StatusOK, "ok") 20 | }) 21 | req := httptest.NewRequest(http.MethodPost, "/", nil) 22 | rec := httptest.NewRecorder() 23 | 24 | // Without CSRF cookie 25 | req = httptest.NewRequest(http.MethodPost, "/", nil) 26 | rec = httptest.NewRecorder() 27 | y.ServeHTTP(rec, req) 28 | assert.Equal(t, "missing csrf token in the header string", rec.Body.String()) 29 | assert.Equal(t, http.StatusBadRequest, rec.Code) 30 | 31 | // invalid csrf token 32 | req = httptest.NewRequest(http.MethodPost, "/", nil) 33 | req.Header.Set(yee.HeaderXCSRFToken, 
"cbghjiwhd") 34 | rec = httptest.NewRecorder() 35 | y.ServeHTTP(rec, req) 36 | assert.Equal(t, "invalid csrf token", rec.Body.String()) 37 | assert.Equal(t, http.StatusForbidden, rec.Code) 38 | 39 | token := strings.Replace(uuid.New().String(), "-", "", -1) 40 | req = httptest.NewRequest(http.MethodPost, "/", nil) 41 | req.Header.Set(yee.HeaderCookie, "_csrf="+token) 42 | req.Header.Set(yee.HeaderXCSRFToken, token) 43 | rec = httptest.NewRecorder() 44 | y.ServeHTTP(rec, req) 45 | assert.Equal(t, http.StatusOK, rec.Code) 46 | } 47 | -------------------------------------------------------------------------------- /middleware/gzip.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "bufio" 5 | "compress/gzip" 6 | "io" 7 | "io/ioutil" 8 | "net" 9 | "net/http" 10 | "strings" 11 | 12 | "github.com/cookieY/yee" 13 | ) 14 | 15 | // GzipConfig defines config of Gzip middleware 16 | type GzipConfig struct { 17 | Level int 18 | } 19 | 20 | type gzipResponseWriter struct { 21 | io.Writer 22 | http.ResponseWriter 23 | } 24 | 25 | // DefaultGzipConfig is the default config of gzip middleware 26 | var DefaultGzipConfig = GzipConfig{Level: 1} 27 | 28 | // Gzip is the default implementation of gzip middleware 29 | func Gzip() yee.HandlerFunc { 30 | return GzipWithConfig(DefaultGzipConfig) 31 | } 32 | 33 | // GzipWithConfig is the custom implementation of gzip middleware 34 | func GzipWithConfig(config GzipConfig) yee.HandlerFunc { 35 | if config.Level == 0 { 36 | config.Level = DefaultGzipConfig.Level 37 | } 38 | 39 | return func(c yee.Context) (err error) { 40 | if c.IsWebsocket() { 41 | return 42 | } 43 | res := c.Response() 44 | res.Header().Add(yee.HeaderVary, yee.HeaderAcceptEncoding) 45 | if strings.Contains(c.Request().Header.Get(yee.HeaderAcceptEncoding), "gzip") { 46 | res.Header().Set(yee.HeaderContentEncoding, "gzip") 47 | rw := res.Writer() 48 | w, err := gzip.NewWriterLevel(rw, config.Level) 
49 | if err != nil { 50 | return err 51 | } 52 | defer func() { 53 | if res.Size() < 1 { 54 | if res.Header().Get(yee.HeaderContentEncoding) == "gzip" { 55 | res.Header().Del(yee.HeaderContentEncoding) 56 | } 57 | res.Override(rw) 58 | w.Reset(ioutil.Discard) 59 | } 60 | _ = w.Close() 61 | }() 62 | grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw} 63 | res.Override(grw) 64 | } 65 | c.Next() 66 | return 67 | } 68 | } 69 | 70 | func (w *gzipResponseWriter) Write(b []byte) (int, error) { 71 | if w.Header().Get(yee.HeaderContentType) == "" { 72 | w.Header().Set(yee.HeaderContentType, http.DetectContentType(b)) 73 | } 74 | return w.Writer.Write(b) 75 | } 76 | 77 | func (w *gzipResponseWriter) Flush() { 78 | w.Writer.(*gzip.Writer).Flush() 79 | if flusher, ok := w.ResponseWriter.(http.Flusher); ok { 80 | flusher.Flush() 81 | } 82 | } 83 | 84 | func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 85 | return w.ResponseWriter.(http.Hijacker).Hijack() 86 | } 87 | -------------------------------------------------------------------------------- /middleware/gzip_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "github.com/cookieY/yee" 5 | "github.com/stretchr/testify/assert" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | ) 10 | 11 | func TestGzip(t *testing.T) { 12 | y := yee.New() 13 | y.Use(Logger(), GzipWithConfig(GzipConfig{Level: 9})) 14 | y.Static("/", "../testing/dist/assets") 15 | t.Run("http_get", func(t *testing.T) { 16 | req := httptest.NewRequest(http.MethodGet, "/js/app.d2880701.js", nil) 17 | req.Header.Add(yee.HeaderAcceptEncoding, "gzip") 18 | rec := httptest.NewRecorder() 19 | y.ServeHTTP(rec, req) 20 | assert2 := assert.New(t) 21 | assert2.Equal(http.StatusOK, rec.Code) 22 | }) 23 | } 24 | -------------------------------------------------------------------------------- /middleware/jwt.go: 
-------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "github.com/cookieY/yee" 7 | "github.com/golang-jwt/jwt" 8 | "net/http" 9 | "reflect" 10 | ) 11 | 12 | // JwtConfig defines the config of JWT middleware 13 | type JwtConfig struct { 14 | GetKey string 15 | AuthScheme string 16 | SigningKey interface{} 17 | SigningMethod string 18 | TokenLookup []string 19 | Claims jwt.Claims 20 | keyFunc jwt.Keyfunc 21 | ErrorHandler JWTErrorHandler 22 | SuccessHandler JWTSuccessHandler 23 | } 24 | 25 | type jwtExtractor func(yee.Context) (string, error) 26 | 27 | // JWTErrorHandler defines a function which is error for a valid token. 28 | type JWTErrorHandler func(error) error 29 | 30 | // JWTSuccessHandler defines a function which is executed for a valid token. 31 | type JWTSuccessHandler func(yee.Context) 32 | 33 | const algorithmHS256 = "HS256" 34 | 35 | // DefaultJwtConfig is the default config of JWT middleware 36 | var DefaultJwtConfig = JwtConfig{ 37 | GetKey: "auth", 38 | SigningMethod: algorithmHS256, 39 | AuthScheme: "Bearer", 40 | TokenLookup: []string{yee.HeaderAuthorization}, 41 | Claims: jwt.MapClaims{}, 42 | } 43 | 44 | // JWTWithConfig is the custom implementation CORS middleware 45 | func JWTWithConfig(config JwtConfig) yee.HandlerFunc { 46 | if config.SigningKey == nil { 47 | panic("yee: jwt middleware requires signing key") 48 | } 49 | if config.SigningMethod == "" { 50 | config.SigningMethod = DefaultJwtConfig.SigningMethod 51 | } 52 | if config.GetKey == "" { 53 | config.GetKey = DefaultJwtConfig.GetKey 54 | } 55 | if config.AuthScheme == "" { 56 | config.AuthScheme = DefaultJwtConfig.AuthScheme 57 | } 58 | 59 | if config.Claims == nil { 60 | config.Claims = DefaultJwtConfig.Claims 61 | } 62 | 63 | if config.TokenLookup == nil { 64 | config.TokenLookup = DefaultJwtConfig.TokenLookup 65 | } 66 | 67 | config.keyFunc = func(token *jwt.Token) (interface{}, error) { 68 
| if token.Method.Alg() != config.SigningMethod { 69 | return nil, fmt.Errorf("unexpected jwt signing method=%v", token.Header["alg"]) 70 | } 71 | return config.SigningKey, nil 72 | } 73 | 74 | extractor := jwtFromHeader(config.TokenLookup, config.AuthScheme) 75 | 76 | return func(c yee.Context) (err error) { 77 | // cause upgrade websocket will clear custom header 78 | // when header add jwt bearer that panic 79 | auth, err := extractor(c) 80 | if err != nil { 81 | return c.JSON(http.StatusBadRequest, err.Error()) 82 | } 83 | token := new(jwt.Token) 84 | if _, ok := config.Claims.(jwt.MapClaims); ok { 85 | token, err = jwt.Parse(auth, config.keyFunc) 86 | if err != nil { 87 | return c.JSON(http.StatusUnauthorized, err.Error()) 88 | } 89 | } else { 90 | t := reflect.ValueOf(config.Claims).Type().Elem() 91 | claims := reflect.New(t).Interface().(jwt.Claims) 92 | token, err = jwt.ParseWithClaims(auth, claims, config.keyFunc) 93 | } 94 | if err == nil && token.Valid { 95 | c.Put(config.GetKey, token) 96 | return 97 | } 98 | // bug fix 99 | // if invalid or expired jwt, 100 | // we must intercept all handlers and return serverError 101 | return c.JSON(http.StatusUnauthorized, "invalid or expired jwt") 102 | } 103 | } 104 | 105 | func jwtFromHeader(header []string, authScheme string) jwtExtractor { 106 | return func(c yee.Context) (string, error) { 107 | for _, i := range header { 108 | auth := c.Request().Header.Get(i) 109 | l := len(authScheme) 110 | if len(auth) > l+1 && auth[:l] == authScheme { 111 | return auth[l+1:], nil 112 | } 113 | if i == yee.HeaderSecWebSocketProtocol { 114 | return auth, nil 115 | } 116 | } 117 | return "", errors.New("missing or malformed jwt") 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /middleware/jwt_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "github.com/cookieY/yee" 
7 | "github.com/golang-jwt/jwt" 8 | "github.com/stretchr/testify/assert" 9 | "net/http" 10 | "net/http/httptest" 11 | "testing" 12 | "time" 13 | ) 14 | 15 | func GenJwtToken() (string, error) { 16 | token := jwt.New(jwt.SigningMethodHS256) 17 | claims := token.Claims.(jwt.MapClaims) 18 | claims["name"] = "henry" 19 | claims["exp"] = time.Now().Add(time.Minute * 15).Unix() 20 | t, err := token.SignedString([]byte("dbcjqheupqjsuwsm")) 21 | if err != nil { 22 | return "", errors.New("JWT Generate Failure") 23 | } 24 | return t, nil 25 | } 26 | 27 | func JwtParse(c yee.Context) (string, string) { 28 | user := c.Get("auth").(*jwt.Token) 29 | claims := user.Claims.(jwt.MapClaims) 30 | return claims["name"].(string),claims["name"].(string) // NOTE(review): both results read claims["name"]; the second was probably meant to be a different claim 31 | } 32 | 33 | func SuperManageGroup() yee.HandlerFunc { 34 | return func(c yee.Context) (err error) { 35 | user, _ := JwtParse(c) 36 | if user == "henry" { 37 | return 38 | } 39 | return c.JSON(http.StatusForbidden, "非法越权操作!") 40 | } 41 | } 42 | 43 | func TestJwt(t *testing.T) { 44 | 45 | cases := []struct { 46 | Name string 47 | Expected int 48 | IsSign bool 49 | Expire time.Duration 50 | }{ 51 | {"not_token", 400, false, 0}, 52 | {"test_is_ok", 200, true, 0}, 53 | //{"test_is_expire", 401, true,time.Second * 1}, 54 | } 55 | for _, i := range cases { 56 | t.Run(i.Name, func(t *testing.T) { 57 | y := yee.New() 58 | 59 | y.Use(JWTWithConfig(JwtConfig{SigningKey: []byte("dbcjqheupqjsuwsm")})) 60 | y.GET("/", func(context yee.Context) error { 61 | return context.String(http.StatusOK, "is_ok") 62 | }) 63 | 64 | req := httptest.NewRequest(http.MethodGet, "/", nil) 65 | rec := httptest.NewRecorder() 66 | if i.IsSign { 67 | token, _ := GenJwtToken() 68 | req.Header.Set(yee.HeaderAuthorization, fmt.Sprintf("Bearer %s", token)) 69 | } 70 | time.Sleep(i.Expire) 71 | y.ServeHTTP(rec, req) 72 | assert2 := assert.New(t) 73 | assert2.Equal(i.Expected, rec.Code) 74 | }) 75 | } 76 | } 77 | 78 | func TestJwtExpire(t *testing.T) { 79 | y := yee.New()
80 | y.Use(Cors(), Secure()) 81 | y.GET("/", func(context yee.Context) error { 82 | return context.String(http.StatusOK, "is_ok") 83 | }) 84 | r := y.Group("/api", JWTWithConfig(JwtConfig{SigningKey: []byte("dbcjqheupqjsuwsm")})) 85 | k := r.Group("/k", SuperManageGroup()) 86 | k.GET("/o", func(context yee.Context) (err error) { 87 | return context.JSON(http.StatusOK, "pk") 88 | }) 89 | // Drive the router in-process: the previous y.Run(":9999") blocked forever and hung `go test`. 90 | rec := httptest.NewRecorder() 91 | y.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) 92 | assert.Equal(t, http.StatusOK, rec.Code) 93 | assert.Equal(t, "is_ok", rec.Body.String()) 94 | // Without a bearer token the JWT group middleware answers 400, as in TestJwt's "not_token" case. 95 | rec = httptest.NewRecorder() 96 | y.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/k/o", nil)) 97 | assert.Equal(t, http.StatusBadRequest, rec.Code) 98 | } 99 | -------------------------------------------------------------------------------- /middleware/logger.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "fmt" 5 | "github.com/cookieY/yee/logger" 6 | "io" 7 | "log" 8 | 9 | "github.com/cookieY/yee" 10 | "github.com/valyala/fasttemplate" 11 | ) 12 | 13 | //LoggerConfig defines config of logger middleware 14 | type ( 15 | LoggerConfig struct { 16 | Format string 17 | Level uint8 18 | IsLogger bool 19 | } 20 | ) 21 | 22 | // DefaultLoggerConfig is default config of logger middleware 23 | var DefaultLoggerConfig = LoggerConfig{ 24 | Format: `"url":"${url}" "method":"${method}" "status":${status} "protocol":"${protocol}" "remote_ip":"${remote_ip}" "bytes_in": "${bytes_in} bytes" "bytes_out": "${bytes_out} bytes"`, 25 | Level: 3, 26 | IsLogger: true, 27 | } 28 | 29 | // Logger is default implementation of logger middleware 30 | func Logger() yee.HandlerFunc { 31 | return LoggerWithConfig(DefaultLoggerConfig) 32 | } 33 | 34 | // LoggerWithConfig returns a middleware that, after the rest of the chain has run, renders config.Format and logs it at Info level (status < 400) or Warn level otherwise. 35 | func LoggerWithConfig(config LoggerConfig) yee.HandlerFunc { 36 | if config.Format == "" { 37 | config.Format = DefaultLoggerConfig.Format 38 | } 39 | 40 | if config.Level == 0 { 41 | config.Level = DefaultLoggerConfig.Level 42 | } 43 | 44 | t, err := fasttemplate.NewTemplate(config.Format, "${", "}") 45 | 46 | if err != nil { 47 | log.Fatalf("unexpected error when parsing template: %s", err) 48 | } 49 | 50 | lg := logger.LogCreator() // named lg so the imported logger package is not shadowed 51 |
52 | lg.SetLevel(config.Level) 53 | 54 | lg.IsLogger(config.IsLogger) 55 | 56 | return func(context yee.Context) (err error) { 57 | context.Next() 58 | s := t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) { 59 | switch tag { 60 | case "url": 61 | p := context.Request().URL.Path 62 | if p == "" { 63 | p = "/" 64 | } 65 | return w.Write([]byte(p)) 66 | case "method": 67 | return w.Write([]byte(context.Request().Method)) 68 | case "status": 69 | return w.Write([]byte(fmt.Sprintf("%d", context.Response().Status()))) 70 | case "remote_ip": 71 | return w.Write([]byte(context.RemoteIP())) 72 | case "host": 73 | return w.Write([]byte(context.Request().Host)) 74 | case "protocol": 75 | return w.Write([]byte(context.Request().Proto)) 76 | case "bytes_in": 77 | cl := context.Request().Header.Get(yee.HeaderContentLength) 78 | if cl == "" { 79 | cl = "0" 80 | } 81 | return w.Write([]byte(cl)) 82 | case "bytes_out": 83 | size := context.Response().Size() 84 | if size < 0 { 85 | size = 0 // Size() can be -1 (noWritten) when no body bytes were written 86 | } 87 | return w.Write([]byte(fmt.Sprintf("%d", size))) 88 | default: 89 | return w.Write([]byte("")) 90 | } 91 | }) 92 | if context.Response().Status() < 400 { 93 | lg.Info(s) 94 | } else { 95 | lg.Warn(s) 96 | } 97 | return 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /middleware/logger_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "github.com/cookieY/yee" 5 | "github.com/stretchr/testify/assert" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | ) 10 | 11 | func TestLogger(t *testing.T) { 12 | y := yee.New() 13 | y.Use(Logger()) 14 | y.GET("/", func(context yee.Context) error { 15 | context.Logger().Critical("哈哈哈哈") 16 | return context.String(http.StatusOK, "ok") 17 | }) 18 | t.Run("http_get", func(t *testing.T) { 19 | req := httptest.NewRequest(http.MethodGet, "/", nil) 20 | rec := httptest.NewRecorder() 21 | y.ServeHTTP(rec, req) 22 | assert := assert.New(t) 23 |
assert.Equal("ok", rec.Body.String()) 24 | assert.Equal(http.StatusOK, rec.Code) 25 | }) 26 | } 27 | -------------------------------------------------------------------------------- /middleware/multiMiddle_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/cookieY/yee" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestMultiMiddle(t *testing.T) { 13 | y := yee.New() 14 | y.Use(Cors()) 15 | y.Use(JWTWithConfig(JwtConfig{SigningKey: []byte("dbcjqheupqjsuwsm")})) 16 | y.GET("/", func(context yee.Context) error { 17 | return context.String(http.StatusOK, "is_ok") 18 | }) 19 | 20 | req := httptest.NewRequest(http.MethodGet, "/", nil) 21 | rec := httptest.NewRecorder() 22 | 23 | a := assert.New(t) 24 | y.ServeHTTP(rec, req) 25 | a.Equal("\"missing or malformed jwt\"\n", rec.Body.String()) 26 | a.Equal(400, rec.Code) 27 | a.Equal("*", rec.Header().Get(yee.HeaderAccessControlAllowOrigin)) 28 | } 29 | 30 | func TestMultiGroup(t *testing.T) { 31 | y := yee.C() 32 | r := y.Group("/", Cors(), CustomerMiddleware()) 33 | r.GET("/test", func(context yee.Context) error { 34 | return context.String(http.StatusOK, "is_ok") 35 | }) 36 | req := httptest.NewRequest(http.MethodGet, "/test", nil) 37 | rec := httptest.NewRecorder() 38 | 39 | a := assert.New(t) 40 | y.ServeHTTP(rec, req) 41 | a.Equal("非法越权操作!", rec.Body.String()) 42 | a.Equal("*", rec.Header().Get(yee.HeaderAccessControlAllowOrigin)) 43 | a.Equal(403, rec.Code) 44 | } 45 | 46 | func CustomerMiddleware() yee.HandlerFunc { 47 | return func(c yee.Context) (err error) { 48 | if c.QueryParam("test") == "y" { 49 | c.Next() 50 | return 51 | } 52 | return c.ServerError(http.StatusForbidden, "非法越权操作!") 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /middleware/proxy.go:
-------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/cookieY/yee" 7 | "net/http" 8 | "net/http/httputil" 9 | "net/url" 10 | "strings" 11 | ) 12 | // ProxyTargetHandler chooses the upstream for a request; MatchAlgorithm is called once per proxied request. 13 | type ProxyTargetHandler interface { 14 | MatchAlgorithm() ProxyTarget 15 | } 16 | // ProxyTarget is an upstream server; Name, when set, only labels error messages. 17 | type ProxyTarget struct { 18 | Name string 19 | URL *url.URL 20 | } 21 | // ProxyConfig configures ProxyWithConfig. NOTE(review): BalanceType is not read anywhere in this file. 22 | type ProxyConfig struct { 23 | BalanceType string 24 | Transport http.RoundTripper 25 | ModifyResponse func(*http.Response) error 26 | ProxyTarget ProxyTargetHandler 27 | } 28 | // errorHandler carries an upstream failure from the reverse proxy's ErrorHandler back to ProxyWithConfig via c.Put("_error", ...). 29 | type errorHandler struct { 30 | Err error 31 | Code int 32 | } 33 | // ProxyWithConfig fills in X-Real-IP / X-Forwarded-Proto when absent, forwards plain HTTP requests to the target chosen by config.ProxyTarget, and turns upstream failures into c.ServerError responses. 34 | func ProxyWithConfig(config ProxyConfig) yee.HandlerFunc { 35 | return func(c yee.Context) (err error) { 36 | req := c.Request() 37 | res := c.Response() 38 | if req.Header.Get(yee.HeaderXRealIP) == "" { 39 | req.Header.Set(yee.HeaderXRealIP, c.RemoteIP()) 40 | } 41 | if req.Header.Get(yee.HeaderXForwardedProto) == "" { 42 | req.Header.Set(yee.HeaderXForwardedProto, c.Scheme()) 43 | } 44 | if c.IsWebsocket() && req.Header.Get(yee.HeaderXForwardedFor) == "" { // For HTTP, it is automatically set by Go HTTP reverse proxy.
45 | req.Header.Set(yee.HeaderXForwardedFor, c.RemoteIP()) 46 | } 47 | switch { 48 | case c.IsWebsocket(): 49 | //proxyRaw(tgt, c).ServeHTTP(res, req) 50 | case req.Header.Get(yee.HeaderAccept) == "text/event-stream": // NOTE(review): websocket and SSE requests are not proxied yet; both branches do nothing and the middleware returns nil 51 | default: 52 | proxyHTTP(c, config).ServeHTTP(res, req) 53 | } 54 | if e, ok := c.Get("_error").(errorHandler); ok { // stored by proxyHTTP's ErrorHandler 55 | return c.ServerError(e.Code, e.Err.Error()) 56 | } 57 | 58 | return nil 59 | } 60 | } 61 | // proxyHTTP builds a single-host reverse proxy for the target picked by config.ProxyTarget. Transport errors are not written to the client here; they are stored under "_error" for ProxyWithConfig to report. 62 | func proxyHTTP(c yee.Context, config ProxyConfig) http.Handler { 63 | tgt := config.ProxyTarget.MatchAlgorithm() 64 | proxy := httputil.NewSingleHostReverseProxy(tgt.URL) 65 | proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) { 66 | desc := tgt.URL.String() 67 | if tgt.Name != "" { 68 | desc = fmt.Sprintf("%s(%s)", tgt.Name, tgt.URL.String()) 69 | } 70 | if err == context.Canceled || strings.Contains(err.Error(), "operation was canceled") { // NOTE(review): errors.Is(err, context.Canceled) would also match wrapped cancellations 71 | c.Put("_error", errorHandler{fmt.Errorf("client closed connection: %s", err.Error()), yee.StatusCodeContextCanceled}) 72 | } else { 73 | c.Put("_error", errorHandler{fmt.Errorf("remote %s unreachable, could not forward: %v", desc, err), http.StatusBadGateway}) 74 | } 75 | } 76 | proxy.Transport = config.Transport 77 | proxy.ModifyResponse = config.ModifyResponse 78 | return proxy 79 | } 80 | -------------------------------------------------------------------------------- /middleware/rateLimit.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "sync" 6 | "time" 7 | 8 | "github.com/cookieY/yee" 9 | ) 10 | 11 | // RateLimitConfig defines config of rateLimit middleware 12 | type RateLimitConfig struct { 13 | Time time.Duration // length of the counting window; the counter is reset every Time 14 | Rate int // requests allowed per window 15 | lock *sync.Mutex // held while numbers is updated; created by RateLimitWithConfig 16 | numbers int // requests counted in the current window 17 | } 18 | 19 | // DefaultRateLimit is the default config of rateLimit middleware 20 | var DefaultRateLimit = RateLimitConfig{ 21 | Time: 1 * time.Second, 22 | Rate: 5, 23 | } 24 | 25 | //
RateLimit is the default implementation of rateLimit middleware 26 | func RateLimit() yee.HandlerFunc { 27 | return RateLimitWithConfig(DefaultRateLimit) 28 | } 29 | 30 | // RateLimitWithConfig returns a middleware that accepts at most config.Rate requests per config.Time window (zero values fall back to DefaultRateLimit) and answers 429 "too many requests" beyond that. 31 | func RateLimitWithConfig(config RateLimitConfig) yee.HandlerFunc { 32 | 33 | if config.Time == 0 { 34 | config.Time = DefaultRateLimit.Time 35 | } 36 | 37 | if config.Rate == 0 { 38 | config.Rate = DefaultRateLimit.Rate 39 | } 40 | 41 | config.lock = new(sync.Mutex) 42 | 43 | go timer(&config) 44 | // The check and the increment share one critical section so concurrent requests cannot both take the last slot. 45 | return func(context yee.Context) (err error) { 46 | config.lock.Lock() 47 | defer config.lock.Unlock() 48 | if config.numbers >= config.Rate { 49 | return context.ServerError(http.StatusTooManyRequests, "too many requests") 50 | } 51 | config.numbers++ 52 | return 53 | } 54 | } 55 | // timer resets the request counter every c.Time. It never returns and its ticker is never stopped, so one goroutine per RateLimitWithConfig call lives for the rest of the process. 56 | func timer(c *RateLimitConfig) { 57 | ticker := time.NewTicker(c.Time) 58 | for { 59 | <-ticker.C 60 | c.lock.Lock() 61 | c.numbers = 0 62 | c.lock.Unlock() 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /middleware/rateLimit_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "fmt" 5 | "github.com/cookieY/yee" 6 | "net/http" 7 | "net/http/httptest" 8 | "runtime" 9 | "sync" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | func TestRateLimit(t *testing.T) { 15 | runtime.GOMAXPROCS(runtime.NumCPU()) 16 | y := yee.New() 17 | y.Use(RateLimitWithConfig(RateLimitConfig{Rate: 1,Time: time.Second * 2})) 18 | y.GET("/", func(context yee.Context) (err error) { 19 | return context.String(http.StatusOK, "ok") 20 | }) 21 | var wg sync.WaitGroup 22 | var once sync.Once 23 | for i := 0; i < 50; i++ { 24 | if i > 25 { 25 | once.Do(func() { 26 | time.Sleep(time.Second * 2) 27 | }) 28 | } 29 | wg.Add(1) 30 | go func(i int) { 31 | req := httptest.NewRequest(http.MethodGet, "/", nil) 32 | rec := httptest.NewRecorder() 33 |
y.ServeHTTP(rec, req) 34 | fmt.Printf("id: %d code:%d body:%s \n",i,rec.Code,rec.Body.String()) 35 | wg.Done() 36 | }(i) 37 | } 38 | wg.Wait() 39 | } -------------------------------------------------------------------------------- /middleware/recovery.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "runtime" 7 | "strings" 8 | 9 | "github.com/cookieY/yee" 10 | ) 11 | 12 | // Recovery returns a middleware that recovers from a panic raised further down the chain, 13 | // logs the panic value with a traceback through c.Logger().Critical and 14 | // answers 500 "Internal Server Error". 15 | func Recovery() yee.HandlerFunc { 16 | return func(c yee.Context) (err error) { 17 | defer func() { 18 | if r := recover(); r != nil { 19 | err, ok := r.(error) 20 | if !ok { 21 | err = fmt.Errorf("%v", r) 22 | } 23 | var pcs [32]uintptr 24 | n := runtime.Callers(3, pcs[:]) // skip first 3 caller 25 | frames := runtime.CallersFrames(pcs[:n]) // FuncForPC may return nil, and a panic here would escape the recovery; CallersFrames also expands inlined calls 26 | var str strings.Builder 27 | str.WriteString("Traceback:") 28 | for more := n > 0; more; { 29 | var frame runtime.Frame 30 | frame, more = frames.Next() 31 | str.WriteString(fmt.Sprintf("\n\t%s:%d", frame.File, frame.Line)) 32 | } 33 | c.Logger().Critical(fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, str.String())) 34 | _ = c.ServerError(http.StatusInternalServerError, "Internal Server Error") 35 | } 36 | }() 37 | c.Next() 38 | return 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /middleware/recovery_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "github.com/cookieY/yee" 5 | "github.com/stretchr/testify/assert" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | ) 10 | 11 | func TestRecovery(t *testing.T) { 12 | y := yee.New() 13 | y.Use(Recovery()) 14 | y.GET("/y", func(context yee.Context) error { 15 | names := []string{"geektutu"} 16 | return context.String(http.StatusOK, names[100]) 17 | }) 18 | 19 | t.Run("http_get",
func(t *testing.T) { 20 | req := httptest.NewRequest(http.MethodGet, "/y", nil) 21 | rec := httptest.NewRecorder() 22 | y.ServeHTTP(rec, req) 23 | assert := assert.New(t) 24 | assert.Equal("Internal Server Error", rec.Body.String()) 25 | assert.Equal(http.StatusInternalServerError, rec.Code) 26 | }) 27 | } 28 | -------------------------------------------------------------------------------- /middleware/requestID.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/cookieY/yee" 7 | "github.com/google/uuid" 8 | ) 9 | 10 | // RequestIDConfig defines config of requestID middleware 11 | type RequestIDConfig struct { 12 | generator func() string 13 | } 14 | 15 | // DefaultRequestIDConfig is the default config of requestID middleware 16 | var DefaultRequestIDConfig = RequestIDConfig{ 17 | generator: defaultGenerator, 18 | } 19 | // defaultGenerator returns a random UUID with its dashes removed. 20 | func defaultGenerator() string { 21 | return strings.Replace(uuid.New().String(), "-", "", -1) 22 | } 23 | 24 | // RequestID is the default implementation of requestID middleware 25 | func RequestID() yee.HandlerFunc { 26 | return RequestIDWithConfig(DefaultRequestIDConfig) 27 | } 28 | 29 | // RequestIDWithConfig returns a middleware that sets the request-ID header (yee.HeaderXRequestID) on the response, reusing the request's value when present and generating one otherwise. 30 | func RequestIDWithConfig(config RequestIDConfig) yee.HandlerFunc { 31 | 32 | if config.generator == nil { 33 | config.generator = DefaultRequestIDConfig.generator 34 | } 35 | return func(context yee.Context) (err error) { 36 | rid := context.Request().Header.Get(yee.HeaderXRequestID) 37 | if rid == "" { 38 | rid = config.generator() 39 | } 40 | context.Response().Header().Set(yee.HeaderXRequestID, rid) // echo a caller-supplied ID instead of dropping it 41 | return 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /middleware/requestID_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "fmt"
5 | "github.com/cookieY/yee" 6 | "github.com/stretchr/testify/assert" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | ) 11 | 12 | func TestRequestID(t *testing.T) { 13 | y := yee.New() 14 | y.Use(RequestID()) 15 | y.GET("/", func(context yee.Context) error { 16 | return context.String(http.StatusOK, "ok") 17 | }) 18 | 19 | req := httptest.NewRequest(http.MethodGet, "/", nil) 20 | rec := httptest.NewRecorder() 21 | 22 | y.ServeHTTP(rec, req) 23 | fmt.Println(rec.Header().Get(yee.HeaderXRequestID)) 24 | assert.Equal(t, http.StatusOK, rec.Code) 25 | assert.NotEqual(t, "", rec.Header().Get(yee.HeaderXRequestID)) 26 | } 27 | 28 | // A request ID sent by the client is echoed back unchanged instead of being dropped. 29 | func TestRequestIDKeepsIncoming(t *testing.T) { 30 | y := yee.New() 31 | y.Use(RequestID()) 32 | y.GET("/", func(context yee.Context) error { 33 | return context.String(http.StatusOK, "ok") 34 | }) 35 | req := httptest.NewRequest(http.MethodGet, "/", nil) 36 | req.Header.Set(yee.HeaderXRequestID, "upstream-id") 37 | rec := httptest.NewRecorder() 38 | y.ServeHTTP(rec, req) 39 | assert.Equal(t, "upstream-id", rec.Header().Get(yee.HeaderXRequestID)) 40 | } 41 | -------------------------------------------------------------------------------- /middleware/secure.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/cookieY/yee" 7 | ) 8 | 9 | type ( 10 | 11 | //SecureConfig defines config of secure middleware 12 | SecureConfig struct { 13 | XSSProtection string `yaml:"xss_protection"` 14 | 15 | ContentTypeNosniff string `yaml:"content_type_nosniff"` 16 | 17 | XFrameOptions string `yaml:"x_frame_options"` 18 | 19 | HSTSMaxAge int `yaml:"hsts_max_age"` 20 | 21 | HSTSExcludeSubdomains bool `yaml:"hsts_exclude_subdomains"` 22 | 23 | ContentSecurityPolicy string `yaml:"content_security_policy"` 24 | 25 | CSPReportOnly bool `yaml:"csp_report_only"` 26 | 27 | HSTSPreloadEnabled bool `yaml:"hsts_preload_enabled"` 28 | 29 | ReferrerPolicy string `yaml:"referrer_policy"` 30 | } 31 | ) 32 | 33 | // DefaultSecureConfig is default config of secure middleware 34 | var DefaultSecureConfig = SecureConfig{ 35 | XSSProtection: "1; mode=block", 36 | ContentTypeNosniff: "nosniff", 37 | XFrameOptions: "SAMEORIGIN", 38 | HSTSPreloadEnabled: false, 39 | } 40 | 41 | // Secure is default implementation of secure middleware 42 | func Secure() yee.HandlerFunc { 43 | return SecureWithConfig(DefaultSecureConfig) 44 | } 45 | 46 | //
// SecureWithConfig is custom implementation of secure middleware 47 | func SecureWithConfig(config SecureConfig) yee.HandlerFunc { 48 | return func(c yee.Context) (err error) { 49 | 50 | if config.XSSProtection != "" { 51 | c.SetHeader(yee.HeaderXXSSProtection, config.XSSProtection) 52 | } 53 | 54 | if config.ContentTypeNosniff != "" { 55 | c.SetHeader(yee.HeaderXContentTypeOptions, config.ContentTypeNosniff) 56 | } 57 | 58 | if config.XFrameOptions != "" { 59 | c.SetHeader(yee.HeaderXFrameOptions, config.XFrameOptions) 60 | } 61 | 62 | if (c.IsTLS() || (c.GetHeader(yee.HeaderXForwardedProto) == "https")) && config.HSTSMaxAge != 0 { 63 | subdomains := "" 64 | if !config.HSTSExcludeSubdomains { 65 | subdomains = "; includeSubdomains" 66 | } 67 | if config.HSTSPreloadEnabled { 68 | subdomains = fmt.Sprintf("%s; preload", subdomains) 69 | } 70 | c.SetHeader(yee.HeaderStrictTransportSecurity, fmt.Sprintf("max-age=%d%s", config.HSTSMaxAge, subdomains)) 71 | } 72 | // CSP 73 | // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy-Report-Only 74 | // https://developer.mozilla.org/en-US/docs/Mozilla/Add-ons/WebExtensions/Content_Security_Policy 75 | if config.ContentSecurityPolicy != "" { 76 | if config.CSPReportOnly { 77 | c.SetHeader(yee.HeaderContentSecurityPolicyReportOnly, config.ContentSecurityPolicy) 78 | } else { 79 | c.SetHeader(yee.HeaderContentSecurityPolicy, config.ContentSecurityPolicy) 80 | } 81 | } 82 | 83 | // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy 84 | if config.ReferrerPolicy != "" { 85 | c.SetHeader(yee.HeaderReferrerPolicy, config.ReferrerPolicy) 86 | } 87 | return 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /middleware/secure_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "github.com/cookieY/yee" 5 | "github.com/stretchr/testify/assert" 6 | 
"net/http" 7 | "net/http/httptest" 8 | "testing" 9 | ) 10 | 11 | func TestSecure(t *testing.T) { 12 | y := yee.New() 13 | y.Use(Secure()) 14 | y.GET("/", func(context yee.Context) error { 15 | return context.String(http.StatusOK, "ok") 16 | }) 17 | t.Run("http_get", func(t *testing.T) { 18 | req := httptest.NewRequest(http.MethodGet, "/", nil) 19 | rec := httptest.NewRecorder() 20 | y.ServeHTTP(rec, req) 21 | assert := assert.New(t) 22 | assert.Equal("ok", rec.Body.String()) 23 | assert.Equal(http.StatusOK, rec.Code) 24 | assert.Equal("SAMEORIGIN", rec.Header().Get(yee.HeaderXFrameOptions)) 25 | assert.Equal("1; mode=block", rec.Header().Get(yee.HeaderXXSSProtection)) 26 | assert.Equal("nosniff", rec.Header().Get(yee.HeaderXContentTypeOptions)) 27 | }) 28 | } 29 | -------------------------------------------------------------------------------- /middleware/trace.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rand" 6 | "fmt" 7 | "github.com/cookieY/yee" 8 | "github.com/opentracing/opentracing-go" 9 | "github.com/opentracing/opentracing-go/ext" 10 | "github.com/uber/jaeger-client-go/config" 11 | "io" 12 | "io/ioutil" 13 | "net/http" 14 | "time" 15 | ) 16 | 17 | const defaultComponentName = "yee" 18 | 19 | type ( 20 | TraceConfig struct { 21 | // OpenTracing Tracer instance which should be got before 22 | Tracer opentracing.Tracer 23 | ComponentName string 24 | IsBodyDump bool 25 | LimitHTTPBody bool 26 | LimitSize int 27 | } 28 | responseDumper struct { 29 | http.ResponseWriter 30 | 31 | mw io.Writer 32 | buf *bytes.Buffer 33 | } 34 | ) 35 | 36 | func New(e *yee.Core) io.Closer { 37 | // Add Opentracing instrumentation 38 | defcfg := config.Configuration{ 39 | ServiceName: "echo-tracer", 40 | Sampler: &config.SamplerConfig{ 41 | Type: "const", 42 | Param: 1, 43 | }, 44 | Reporter: &config.ReporterConfig{ 45 | LogSpans: true, 46 | BufferFlushInterval: 1 * time.Second, 
47 | }, 48 | } 49 | cfg, err := defcfg.FromEnv() 50 | if err != nil { 51 | panic("Could not parse Jaeger env vars: " + err.Error()) 52 | } 53 | tracer, closer, err := cfg.NewTracer() 54 | if err != nil { 55 | panic("Could not initialize jaeger tracer: " + err.Error()) 56 | } 57 | 58 | opentracing.SetGlobalTracer(tracer) 59 | e.Use(TraceWithConfig(TraceConfig{ 60 | Tracer: tracer, 61 | })) 62 | return closer 63 | } 64 | 65 | func TraceWithConfig(config TraceConfig) yee.HandlerFunc { 66 | if config.Tracer == nil { 67 | panic("yee: trace middleware requires opentracing tracer") 68 | } 69 | if config.ComponentName == "" { 70 | config.ComponentName = defaultComponentName 71 | } 72 | 73 | return func(c yee.Context) error { 74 | req := c.Request() 75 | opname := "HTTP " + req.Method + " URL: " + c.Path() 76 | realIP := c.RemoteIP() 77 | requestID := getRequestID(c) // request-id generated by reverse-proxy 78 | 79 | var sp opentracing.Span 80 | var err error 81 | 82 | ctx, err := config.Tracer.Extract( 83 | opentracing.HTTPHeaders, 84 | opentracing.HTTPHeadersCarrier(req.Header), 85 | ) 86 | 87 | if err != nil { 88 | sp = config.Tracer.StartSpan(opname) 89 | } else { 90 | sp = config.Tracer.StartSpan(opname, ext.RPCServerOption(ctx)) 91 | } 92 | defer sp.Finish() 93 | 94 | ext.HTTPMethod.Set(sp, req.Method) 95 | ext.HTTPUrl.Set(sp, req.URL.String()) 96 | ext.Component.Set(sp, config.ComponentName) 97 | sp.SetTag("client_ip", realIP) 98 | sp.SetTag("request_id", requestID) 99 | 100 | // Dump request & response body 101 | var respDumper *responseDumper 102 | if config.IsBodyDump { 103 | // request 104 | reqBody := []byte{} 105 | if c.Request().Body != nil { 106 | reqBody, _ = ioutil.ReadAll(c.Request().Body) 107 | 108 | if config.LimitHTTPBody { 109 | sp.LogKV("http.req.body", limitString(string(reqBody), config.LimitSize)) 110 | } else { 111 | sp.LogKV("http.req.body", string(reqBody)) 112 | } 113 | } 114 | 115 | req.Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // 
reset original request body 116 | 117 | // response 118 | respDumper = newResponseDumper(c) 119 | c.Response().Override(respDumper.ResponseWriter) 120 | } 121 | 122 | // setup request context - add opentracing span 123 | req = req.WithContext(opentracing.ContextWithSpan(req.Context(), sp)) 124 | c.SetRequest(req) 125 | 126 | // call next middleware / controller 127 | c.Next() 128 | if err != nil { 129 | c.Logger().Error(err) // call custom registered error handler 130 | } 131 | 132 | status := c.Response().Status() 133 | ext.HTTPStatusCode.Set(sp, uint16(status)) 134 | 135 | if err != nil { 136 | logError(sp, err) 137 | } 138 | 139 | // Dump response body 140 | if config.IsBodyDump { 141 | if config.LimitHTTPBody { 142 | sp.LogKV("http.resp.body", limitString(respDumper.GetResponse(), config.LimitSize)) 143 | } else { 144 | sp.LogKV("http.resp.body", respDumper.GetResponse()) 145 | } 146 | } 147 | 148 | return nil // error was already processed with ctx.Error(err) 149 | } 150 | } 151 | 152 | func getRequestID(ctx yee.Context) string { 153 | requestID := ctx.Request().Header.Get(yee.HeaderXRequestID) // request-id generated by reverse-proxy 154 | if requestID == "" { 155 | requestID = generateToken() // missed request-id from proxy, we generate it manually 156 | } 157 | return requestID 158 | } 159 | 160 | func generateToken() string { 161 | b := make([]byte, 16) 162 | rand.Read(b) 163 | return fmt.Sprintf("%x", b) 164 | } 165 | 166 | func limitString(str string, size int) string { 167 | if len(str) > size { 168 | return str[:size/2] + "\n---- skipped ----\n" + str[len(str)-size/2:] 169 | } 170 | 171 | return str 172 | } 173 | 174 | func newResponseDumper(resp yee.Context) *responseDumper { 175 | buf := new(bytes.Buffer) 176 | return &responseDumper{ 177 | ResponseWriter: resp.Response().Writer(), 178 | mw: io.MultiWriter(resp.Response().Writer(), buf), 179 | buf: buf, 180 | } 181 | } 182 | 183 | func (d *responseDumper) Write(b []byte) (int, error) { 184 | return 
d.mw.Write(b) 185 | } 186 | 187 | func (d *responseDumper) GetResponse() string { 188 | return d.buf.String() 189 | } 190 | 191 | func logError(span opentracing.Span, err error) { 192 | span.LogKV("error.message", err.Error()) 193 | span.SetTag("error", true) 194 | } 195 | -------------------------------------------------------------------------------- /pprof.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | import ( 4 | "net/http" 5 | "net/http/pprof" 6 | ) 7 | 8 | const DefaultPrefix = "/debug/pprof" 9 | 10 | func getPrefix(prefixOptions string) string { 11 | prefix := DefaultPrefix 12 | if len(prefixOptions) > 1 { 13 | prefix = "/debug" + prefixOptions 14 | } 15 | return prefix 16 | } 17 | 18 | func WrapF(f http.HandlerFunc) HandlerFunc { 19 | return func(c Context) (err error) { 20 | f(c.Response(), c.Request()) 21 | return nil 22 | } 23 | } 24 | 25 | func WrapH(h http.Handler) HandlerFunc { 26 | return func(c Context) (err error) { 27 | h.ServeHTTP(c.Response(), c.Request()) 28 | return nil 29 | } 30 | } 31 | 32 | func (c *Core) Pprof() { 33 | c.GET(getPrefix("/"), WrapF(pprof.Index)) 34 | c.GET(getPrefix("/cmdline"), WrapF(pprof.Cmdline)) 35 | c.GET(getPrefix("/profile"), WrapF(pprof.Profile)) 36 | c.POST(getPrefix("/symbol"), WrapF(pprof.Symbol)) 37 | c.GET(getPrefix("/symbol"), WrapF(pprof.Symbol)) 38 | c.GET(getPrefix("/trace"), WrapF(pprof.Trace)) 39 | c.GET(getPrefix("/allocs"), WrapH(pprof.Handler("allocs"))) 40 | c.GET(getPrefix("/block"), WrapH(pprof.Handler("block"))) 41 | c.GET(getPrefix("/goroutine"), WrapH(pprof.Handler("goroutine"))) 42 | c.GET(getPrefix("/heap"), WrapH(pprof.Handler("heap"))) 43 | c.GET(getPrefix("/mutex"), WrapH(pprof.Handler("mutex"))) 44 | c.GET(getPrefix("/threadcreate"), WrapH(pprof.Handler("threadcreate"))) 45 | } 46 | -------------------------------------------------------------------------------- /pprof_test.go: 
-------------------------------------------------------------------------------- 1 | package yee 2 | 3 | import "testing" 4 | 5 | func TestCore_Pprof(t *testing.T) { 6 | c := New() 7 | c.Pprof() 8 | c.Run(":9999") 9 | } 10 | -------------------------------------------------------------------------------- /response_overide.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | // Copyright 2014 Manu Martinez-Almeida. All rights reserved. 4 | // Use of this source code is governed by a MIT style 5 | // license that can be found in the LICENSE file. 6 | import ( 7 | "bufio" 8 | "fmt" 9 | "io" 10 | "net" 11 | "net/http" 12 | ) 13 | 14 | const ( 15 | noWritten = -1 16 | defaultStatus = http.StatusOK 17 | ) 18 | 19 | // ResponseWriter ... 20 | type ResponseWriter interface { 21 | http.ResponseWriter 22 | http.Hijacker 23 | http.Flusher 24 | http.CloseNotifier 25 | 26 | // Returns the HTTP response status code of the current request. 27 | Status() int 28 | 29 | // Returns the number of bytes already written into the response http body. 30 | // See Written() 31 | Size() int 32 | 33 | // Writes the string into the response body. 34 | WriteString(string) (int, error) 35 | 36 | // Returns true if the response body was already written. 37 | Written() bool 38 | 39 | //// Forces to write the http header (status code + headers). 
40 | WriteHeaderNow() 41 | 42 | // get the http.Pusher for server push 43 | Pusher() http.Pusher 44 | 45 | Writer() http.ResponseWriter 46 | 47 | Override(rw http.ResponseWriter) 48 | } 49 | 50 | type responseWriter struct { 51 | http.ResponseWriter 52 | size int 53 | status int 54 | } 55 | 56 | var _ ResponseWriter = &responseWriter{} 57 | 58 | func (w *responseWriter) Writer() http.ResponseWriter { 59 | return w.ResponseWriter 60 | } 61 | 62 | func (w *responseWriter) Override(rw http.ResponseWriter) { 63 | w.ResponseWriter = rw 64 | } 65 | 66 | func (w *responseWriter) reset(writer http.ResponseWriter) { 67 | w.ResponseWriter = writer 68 | w.size = noWritten 69 | w.status = defaultStatus 70 | } 71 | 72 | func (w *responseWriter) WriteHeader(code int) { 73 | if code > 0 && w.status != code { 74 | if w.Written() { 75 | fmt.Printf("[WARNING] Headers were already written. Wanted to override status code %d with %d", w.status, code) 76 | } 77 | w.status = code 78 | } 79 | } 80 | 81 | func (w *responseWriter) WriteHeaderNow() { 82 | if !w.Written() { 83 | w.size = 0 84 | w.ResponseWriter.WriteHeader(w.status) 85 | } 86 | } 87 | 88 | func (w *responseWriter) Write(data []byte) (n int, err error) { 89 | w.WriteHeaderNow() 90 | n, err = w.ResponseWriter.Write(data) 91 | w.size += n 92 | return 93 | } 94 | 95 | func (w *responseWriter) WriteString(s string) (n int, err error) { 96 | w.WriteHeaderNow() 97 | n, err = io.WriteString(w.ResponseWriter, s) 98 | w.size += n 99 | return 100 | } 101 | 102 | func (w *responseWriter) Status() int { 103 | return w.status 104 | } 105 | 106 | func (w *responseWriter) Size() int { 107 | return w.size 108 | } 109 | 110 | func (w *responseWriter) Written() bool { 111 | return w.size != noWritten 112 | } 113 | 114 | // Hijack implements the http.Hijacker interface. 
115 | func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 116 | if w.size < 0 { 117 | w.size = 0 118 | } 119 | return w.ResponseWriter.(http.Hijacker).Hijack() 120 | } 121 | 122 | // CloseNotify implements the http.CloseNotify interface. 123 | func (w *responseWriter) CloseNotify() <-chan bool { 124 | return w.ResponseWriter.(http.CloseNotifier).CloseNotify() 125 | } 126 | 127 | // Flush implements the http.Flush interface. 128 | func (w *responseWriter) Flush() { 129 | w.WriteHeaderNow() 130 | w.ResponseWriter.(http.Flusher).Flush() 131 | } 132 | 133 | func (w *responseWriter) Pusher() (pusher http.Pusher) { 134 | if pusher, ok := w.ResponseWriter.(http.Pusher); ok { 135 | return pusher 136 | } 137 | return nil 138 | } 139 | -------------------------------------------------------------------------------- /router.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | import ( 4 | "embed" 5 | "io/fs" 6 | "net/http" 7 | "path" 8 | "strings" 9 | ) 10 | 11 | type Router struct { 12 | handlers []HandlerFunc 13 | core *Core 14 | root bool 15 | basePath string 16 | } 17 | 18 | // RestfulAPI is the default implementation of restfulApi interface 19 | type RestfulAPI struct { 20 | Get HandlerFunc 21 | Post HandlerFunc 22 | Delete HandlerFunc 23 | Put HandlerFunc 24 | } 25 | 26 | // Implement the HTTP method and add to the router table 27 | // GET,POST,PUT,DELETE,OPTIONS,TRACE,HEAD,PATCH 28 | // these are defined in RFC 7231 section 4.3. 
29 | // GET registers handlers for GET requests on path, resolved against the router's base path. 30 | func (r *Router) GET(path string, handler ...HandlerFunc) { 31 | r.handle(http.MethodGet, path, handler) 32 | } 33 | // POST registers handlers for POST requests on path. 34 | func (r *Router) POST(path string, handler ...HandlerFunc) { 35 | r.handle(http.MethodPost, path, handler) 36 | } 37 | // PUT registers handlers for PUT requests on path. 38 | func (r *Router) PUT(path string, handler ...HandlerFunc) { 39 | r.handle(http.MethodPut, path, handler) 40 | } 41 | // DELETE registers handlers for DELETE requests on path. 42 | func (r *Router) DELETE(path string, handler ...HandlerFunc) { 43 | r.handle(http.MethodDelete, path, handler) 44 | } 45 | // PATCH registers handlers for PATCH requests on path. 46 | func (r *Router) PATCH(path string, handler ...HandlerFunc) { 47 | r.handle(http.MethodPatch, path, handler) 48 | } 49 | // HEAD registers handlers for HEAD requests on path. 50 | func (r *Router) HEAD(path string, handler ...HandlerFunc) { 51 | r.handle(http.MethodHead, path, handler) 52 | } 53 | // TRACE registers handlers for TRACE requests on path. 54 | func (r *Router) TRACE(path string, handler ...HandlerFunc) { 55 | r.handle(http.MethodTrace, path, handler) 56 | } 57 | // OPTIONS registers handlers for OPTIONS requests on path. 58 | func (r *Router) OPTIONS(path string, handler ...HandlerFunc) { 59 | r.handle(http.MethodOptions, path, handler) 60 | } 61 | // Restful registers each non-nil handler of api under its matching method on path. 62 | func (r *Router) Restful(path string, api RestfulAPI) { 63 | 64 | if api.Get != nil { 65 | r.handle(http.MethodGet, path, HandlersChain{api.Get}) 66 | } 67 | if api.Post != nil { 68 | r.handle(http.MethodPost, path, HandlersChain{api.Post}) 69 | } 70 | if api.Put != nil { 71 | r.handle(http.MethodPut, path, HandlersChain{api.Put}) 72 | } 73 | if api.Delete != nil { 74 | r.handle(http.MethodDelete, path, HandlersChain{api.Delete}) 75 | } 76 | } 77 | // Any registers handlers for POST, GET, PUT, DELETE and OPTIONS on path (PATCH, HEAD and TRACE are not included). 78 | func (r *Router) Any(path string, handler ...HandlerFunc) { 79 | r.handle(http.MethodPost, path, handler) 80 | r.handle(http.MethodGet, path, handler) 81 | r.handle(http.MethodPut, path, handler) 82 | r.handle(http.MethodDelete, path, handler) 83 | r.handle(http.MethodOptions, path, handler) 84 | } 85 | // Use appends middleware to the router's handler chain. 86 | func (r *Router) Use(middleware ...HandlerFunc) { 87 | r.handlers = append(r.handlers, middleware...)
88 | } 89 | // Group returns a child router at prefix (resolved against this router's base path) that shares the same Core; its handlers are combined with this router's via combineHandlers. 90 | func (r *Router) Group(prefix string, handlers ...HandlerFunc) *Router { 91 | rx := &Router{ 92 | handlers: r.combineHandlers(handlers), 93 | core: r.core, 94 | basePath: r.calculateAbsolutePath(prefix), 95 | } 96 | return rx 97 | } 98 | // handle resolves path against the base path, combines handlers with the router's own and registers the route on the Core. 99 | func (r *Router) handle(method, path string, handlers HandlersChain) { 100 | absolutePath := r.calculateAbsolutePath(path) 101 | handlers = r.combineHandlers(handlers) 102 | r.core.addRoute(method, absolutePath, handlers) 103 | } 104 | // addRoute registers handlers for method and the absolute path prefix, creating the method's tree on first use. It panics when prefix is empty or not rooted at '/', or when method is empty. 105 | func (c *Core) addRoute(method, prefix string, handlers HandlersChain) { 106 | if prefix == "" || prefix[0] != '/' { 107 | panic("path must begin with '/'") 108 | } 109 | 110 | if method == "" { 111 | panic("HTTP method can not be empty") 112 | } 113 | 114 | root := c.trees.get(method) 115 | if root == nil { 116 | root = new(node) 117 | root.fullPath = "/" 118 | c.trees = append(c.trees, methodTree{method: method, root: root}) 119 | } 120 | root.addRoute(prefix, handlers) 121 | 122 | // Update maxParams 123 | if paramsCount := countParams(prefix); paramsCount > c.maxParams { 124 | c.maxParams = paramsCount 125 | } 126 | 127 | } 128 | // Static serves the directory root under relativePath/*filepath for GET and HEAD; relativePath may not contain ':' or '*'. 129 | func (r *Router) Static(relativePath, root string) { 130 | if strings.Contains(relativePath, ":") || strings.Contains(relativePath, "*") { 131 | panic("URL path cannot be used when serving a static folder") 132 | } 133 | handler := r.createDistHandler(relativePath, Dir(root, false)) 134 | url := path.Join(relativePath, "/*filepath") 135 | r.GET(url, handler) 136 | r.HEAD(url, handler) 137 | } 138 | // Pack serves the root sub-directory of the embedded filesystem f under relativePath/*filepath. 139 | func (r *Router) Pack(relativePath string, f embed.FS, root string) { 140 | if strings.Contains(relativePath, ":") || strings.Contains(relativePath, "*") { 141 | panic("URL path cannot be used when serving a static folder") 142 | } 143 | fsys, err := fs.Sub(f, root) 144 | if err != nil { 145 | panic(err) 146 | } 147 | handler := r.createDistHandler(relativePath, http.FS(fsys)) 148 | url := path.Join(relativePath, "/*filepath") 149 | r.GET(url, handler)
150 | r.HEAD(url, handler) 151 | } 152 | 153 | func (r *Router) createDistHandler(relativePath string, fs http.FileSystem) HandlerFunc { 154 | absolutePath := r.calculateAbsolutePath(relativePath) 155 | fileServer := http.StripPrefix(absolutePath, http.FileServer(fs)) 156 | return func(c Context) (err error) { 157 | if _, noListing := fs.(*onlyFilesFS); noListing { 158 | c.Response().WriteHeader(http.StatusNotFound) 159 | } 160 | file := c.Params("filepath") 161 | f, err2 := fs.Open(file) 162 | if err2 != nil { 163 | c.Status(http.StatusNotFound) 164 | c.Reset() 165 | } 166 | if f != nil { 167 | _ = f.Close() 168 | } 169 | fileServer.ServeHTTP(c.Response(), c.Request()) 170 | return 171 | } 172 | } 173 | 174 | func (r *Router) calculateAbsolutePath(relativePath string) string { 175 | return joinPaths(r.basePath, relativePath) 176 | } 177 | 178 | func (r *Router) combineHandlers(handlers HandlersChain) HandlersChain { 179 | finalSize := len(r.handlers) + len(handlers) 180 | if finalSize >= crashIndex { 181 | panic("too many handlers") 182 | } 183 | mergedHandlers := make(HandlersChain, finalSize) 184 | copy(mergedHandlers, r.handlers) 185 | copy(mergedHandlers[len(r.handlers):], handlers) 186 | return mergedHandlers 187 | } 188 | -------------------------------------------------------------------------------- /router_test.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestRouterParam(t *testing.T) { 13 | 14 | y := New() 15 | y.GET("/hello/k/:b", func(c Context) error { 16 | return c.String(http.StatusOK, c.Params("b")) 17 | }) 18 | req := httptest.NewRequest(http.MethodGet, "/hello/k/yee", nil) 19 | rec := httptest.NewRecorder() 20 | y.ServeHTTP(rec, req) 21 | assert.Equal(t, "yee", rec.Body.String()) 22 | assert.Equal(t, http.StatusOK, rec.Code) 23 | } 24 | 25 | func 
TestRouterParams(t *testing.T) { 26 | 27 | y := New() 28 | y.GET("/hello/k/:b/p/:j", func(c Context) error { 29 | return c.String(http.StatusOK, fmt.Sprintf("%s-%s", c.Params("b"), c.Params("j"))) 30 | }) 31 | req := httptest.NewRequest(http.MethodGet, "/hello/k/yee/p/henry", nil) 32 | rec := httptest.NewRecorder() 33 | y.ServeHTTP(rec, req) 34 | assert.Equal(t, "yee-henry", rec.Body.String()) 35 | assert.Equal(t, http.StatusOK, rec.Code) 36 | } 37 | 38 | func TestRouterStaticPath(t *testing.T) { 39 | 40 | y := New() 41 | y.GET("/hello/k/*assets", func(c Context) error { 42 | return c.String(http.StatusOK, c.Params("assets")) 43 | }) 44 | req := httptest.NewRequest(http.MethodGet, "/hello/k/assets/1.js", nil) 45 | rec := httptest.NewRecorder() 46 | y.ServeHTTP(rec, req) 47 | assert.Equal(t, "/assets/1.js", rec.Body.String()) 48 | assert.Equal(t, http.StatusOK, rec.Code) 49 | } 50 | 51 | func TestRouterMultiRoute(t *testing.T) { 52 | 53 | y := New() 54 | y.GET("/hello/k/*k", func(c Context) error { 55 | return c.String(http.StatusOK, c.Params("k")) 56 | }) 57 | y.GET("/hello/k", func(c Context) error { 58 | return c.String(http.StatusOK, "is_ok") 59 | }) 60 | req := httptest.NewRequest(http.MethodGet, "/hello/k/yee/version/route", nil) 61 | rec := httptest.NewRecorder() 62 | y.ServeHTTP(rec, req) 63 | assert.Equal(t, "/yee/version/route", rec.Body.String()) 64 | assert.Equal(t, http.StatusOK, rec.Code) 65 | 66 | req = httptest.NewRequest(http.MethodGet, "/hello/k", nil) 67 | rec = httptest.NewRecorder() 68 | y.ServeHTTP(rec, req) 69 | assert.Equal(t, "is_ok", rec.Body.String()) 70 | assert.Equal(t, http.StatusOK, rec.Code) 71 | 72 | } 73 | 74 | func TestRouterQueryParam(t *testing.T) { 75 | 76 | y := New() 77 | y.GET("/hello/query", func(c Context) error { 78 | return c.String(http.StatusOK, c.QueryParam("query")) 79 | }) 80 | req := httptest.NewRequest(http.MethodGet, "/hello/query?query=henry", nil) 81 | rec := httptest.NewRecorder() 82 | y.ServeHTTP(rec, req) 83 
| assert.Equal(t, "henry", rec.Body.String()) 84 | assert.Equal(t, http.StatusOK, rec.Code) 85 | 86 | } 87 | 88 | type testCase struct { 89 | uri string 90 | expect string 91 | } 92 | 93 | func TestRouter_Static(t *testing.T) { 94 | y := New() 95 | y.Static("/front","color") 96 | req := httptest.NewRequest(http.MethodGet, "/front/color.go", nil) 97 | rec := httptest.NewRecorder() 98 | y.ServeHTTP(rec, req) 99 | fmt.Println(rec.Body.String()) 100 | } 101 | 102 | func TestRouterMixin(t *testing.T) { 103 | y := New() 104 | y.GET("/pay", func(c Context) error { 105 | return c.String(http.StatusOK, "pay") 106 | }) 107 | y.GET("/pay/add", func(c Context) error { 108 | return c.String(http.StatusOK, c.QueryParam("person")) 109 | }) 110 | y.GET("/pay/add/:id", func(c Context) error { 111 | return c.String(http.StatusOK, c.Params("id")) 112 | }) 113 | y.GET("/pay/add/:id/:store", func(c Context) error { 114 | return c.String(http.StatusOK, c.Params("id")+c.Params("store")) 115 | }) 116 | y.GET("/pay/dew", func(c Context) error { 117 | return c.String(http.StatusOK, "dew") 118 | }) 119 | y.GET("/pay/dew/*account", func(c Context) error { 120 | return c.String(http.StatusOK, c.Params("account")) 121 | }) 122 | 123 | c := []testCase{ 124 | { 125 | uri: "/pay", 126 | expect: "pay", 127 | }, 128 | { 129 | uri: "/pay/add?person=henry", 130 | expect: "henry", 131 | }, 132 | { 133 | uri: "/pay/add/1", 134 | expect: "1", 135 | }, 136 | { 137 | uri: "/pay/add/1/a", 138 | expect: "1a", 139 | }, 140 | { 141 | uri: "/pay/dew", 142 | expect: "dew", 143 | }, 144 | { 145 | uri: "/pay/dew/account/css/1.css", 146 | expect: "/account/css/1.css", 147 | }, 148 | } 149 | 150 | for _, i := range c { 151 | req := httptest.NewRequest(http.MethodGet, i.uri, nil) 152 | rec := httptest.NewRecorder() 153 | y.ServeHTTP(rec, req) 154 | assert.Equal(t, i.expect, rec.Body.String()) 155 | assert.Equal(t, http.StatusOK, rec.Code) 156 | } 157 | } 158 | 159 | // If you want to test routing performance, You can 
use benchmark_test to get it 160 | 161 | // --- testing any method 162 | 163 | func testRestfulAPI() RestfulAPI { 164 | 165 | var api RestfulAPI 166 | 167 | api.Get = func(c Context) (err error) { 168 | return c.String(http.StatusOK, "get") 169 | } 170 | 171 | api.Post = func(c Context) (err error) { 172 | return c.String(http.StatusOK, "post") 173 | } 174 | 175 | api.Delete = func(c Context) (err error) { 176 | return c.String(http.StatusOK, "delete") 177 | } 178 | 179 | api.Put = func(c Context) (err error) { 180 | return c.String(http.StatusOK, "put") 181 | } 182 | 183 | return api 184 | } 185 | 186 | func userUpdate(c Context) (err error) { 187 | return c.String(http.StatusOK, "updated") 188 | } 189 | 190 | func userFetch(c Context) (err error) { 191 | return c.String(http.StatusOK, "get it") 192 | } 193 | 194 | func test2RestfulAPI() RestfulAPI { 195 | return RestfulAPI{ 196 | Get: userFetch, 197 | Post: userUpdate, 198 | } 199 | } 200 | 201 | func TestAnyMethod(t *testing.T) { 202 | 203 | y := New() 204 | 205 | y.Restful("/", testRestfulAPI()) 206 | y.Restful("/user", test2RestfulAPI()) 207 | 208 | req := httptest.NewRequest(http.MethodGet, "/", nil) 209 | rec := httptest.NewRecorder() 210 | y.ServeHTTP(rec, req) 211 | assert.Equal(t, "get", rec.Body.String()) 212 | assert.Equal(t, http.StatusOK, rec.Code) 213 | 214 | req = httptest.NewRequest(http.MethodPost, "/", nil) 215 | rec = httptest.NewRecorder() 216 | y.ServeHTTP(rec, req) 217 | assert.Equal(t, "post", rec.Body.String()) 218 | assert.Equal(t, http.StatusOK, rec.Code) 219 | 220 | req = httptest.NewRequest(http.MethodPut, "/", nil) 221 | rec = httptest.NewRecorder() 222 | y.ServeHTTP(rec, req) 223 | assert.Equal(t, "put", rec.Body.String()) 224 | assert.Equal(t, http.StatusOK, rec.Code) 225 | 226 | req = httptest.NewRequest(http.MethodDelete, "/", nil) 227 | rec = httptest.NewRecorder() 228 | y.ServeHTTP(rec, req) 229 | assert.Equal(t, "delete", rec.Body.String()) 230 | assert.Equal(t, http.StatusOK, 
rec.Code) 231 | 232 | req = httptest.NewRequest(http.MethodGet, "/user", nil) 233 | rec = httptest.NewRecorder() 234 | y.ServeHTTP(rec, req) 235 | assert.Equal(t, "get it", rec.Body.String()) 236 | assert.Equal(t, http.StatusOK, rec.Code) 237 | 238 | req = httptest.NewRequest(http.MethodPost, "/user", nil) 239 | rec = httptest.NewRecorder() 240 | y.ServeHTTP(rec, req) 241 | assert.Equal(t, "updated", rec.Body.String()) 242 | assert.Equal(t, http.StatusOK, rec.Code) 243 | 244 | req = httptest.NewRequest(http.MethodPut, "/user", nil) 245 | rec = httptest.NewRecorder() 246 | y.ServeHTTP(rec, req) 247 | assert.Equal(t, http.StatusNotFound, rec.Code) 248 | 249 | req = httptest.NewRequest(http.MethodDelete, "/user", nil) 250 | rec = httptest.NewRecorder() 251 | y.ServeHTTP(rec, req) 252 | assert.Equal(t, http.StatusNotFound, rec.Code) 253 | } 254 | -------------------------------------------------------------------------------- /test/host.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go. DO NOT EDIT. 2 | // versions: 3 | // protoc-gen-go v1.26.0 4 | // protoc v3.12.3 5 | // source: host.proto 6 | 7 | package test 8 | 9 | import ( 10 | protoreflect "google.golang.org/protobuf/reflect/protoreflect" 11 | protoimpl "google.golang.org/protobuf/runtime/protoimpl" 12 | reflect "reflect" 13 | sync "sync" 14 | ) 15 | 16 | const ( 17 | // Verify that this generated code is sufficiently up-to-date. 18 | _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) 19 | // Verify that runtime/protoimpl is sufficiently up-to-date. 
20 | _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) 21 | ) 22 | 23 | type Svr struct { 24 | state protoimpl.MessageState 25 | sizeCache protoimpl.SizeCache 26 | unknownFields protoimpl.UnknownFields 27 | 28 | Name string `protobuf:"bytes,1,opt,name=Name,proto3" json:"Name,omitempty"` 29 | IP string `protobuf:"bytes,2,opt,name=IP,proto3" json:"IP,omitempty"` 30 | Project string `protobuf:"bytes,3,opt,name=Project,proto3" json:"Project,omitempty"` 31 | IDC string `protobuf:"bytes,4,opt,name=IDC,proto3" json:"IDC,omitempty"` 32 | Cloud string `protobuf:"bytes,5,opt,name=Cloud,proto3" json:"Cloud,omitempty"` 33 | } 34 | 35 | func (x *Svr) Reset() { 36 | *x = Svr{} 37 | if protoimpl.UnsafeEnabled { 38 | mi := &file_host_proto_msgTypes[0] 39 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 40 | ms.StoreMessageInfo(mi) 41 | } 42 | } 43 | 44 | func (x *Svr) String() string { 45 | return protoimpl.X.MessageStringOf(x) 46 | } 47 | 48 | func (*Svr) ProtoMessage() {} 49 | 50 | func (x *Svr) ProtoReflect() protoreflect.Message { 51 | mi := &file_host_proto_msgTypes[0] 52 | if protoimpl.UnsafeEnabled && x != nil { 53 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 54 | if ms.LoadMessageInfo() == nil { 55 | ms.StoreMessageInfo(mi) 56 | } 57 | return ms 58 | } 59 | return mi.MessageOf(x) 60 | } 61 | 62 | // Deprecated: Use Svr.ProtoReflect.Descriptor instead. 
func (*Svr) Descriptor() ([]byte, []int) {
	return file_host_proto_rawDescGZIP(), []int{0}
}

func (x *Svr) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

func (x *Svr) GetIP() string {
	if x != nil {
		return x.IP
	}
	return ""
}

func (x *Svr) GetProject() string {
	if x != nil {
		return x.Project
	}
	return ""
}

func (x *Svr) GetIDC() string {
	if x != nil {
		return x.IDC
	}
	return ""
}

func (x *Svr) GetCloud() string {
	if x != nil {
		return x.Cloud
	}
	return ""
}

var File_host_proto protoreflect.FileDescriptor

var file_host_proto_rawDesc = []byte{
	0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x68, 0x6f,
	0x73, 0x74, 0x22, 0x6b, 0x0a, 0x03, 0x53, 0x76, 0x72, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d,
	0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x0e, 0x0a,
	0x02, 0x49, 0x50, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x18, 0x0a,
	0x07, 0x50, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07,
	0x50, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x49, 0x44, 0x43, 0x18, 0x04,
	0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x49, 0x44, 0x43, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x6f,
	0x75, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x6f, 0x75, 0x64, 0x42,
	0x04, 0x5a, 0x02, 0x2e, 0x2f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}

var (
	file_host_proto_rawDescOnce sync.Once
	file_host_proto_rawDescData = file_host_proto_rawDesc
)

func file_host_proto_rawDescGZIP() []byte {
	file_host_proto_rawDescOnce.Do(func() {
		file_host_proto_rawDescData = protoimpl.X.CompressGZIP(file_host_proto_rawDescData)
	})
	return file_host_proto_rawDescData
}

var file_host_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_host_proto_goTypes = []interface{}{
	(*Svr)(nil), // 0: host.Svr
}
var file_host_proto_depIdxs = []int32{
	0, // [0:0] is the sub-list for method output_type
	0, // [0:0] is the sub-list for method input_type
	0, // [0:0] is the sub-list for extension type_name
	0, // [0:0] is the sub-list for extension extendee
	0, // [0:0] is the sub-list for field type_name
}

func init() { file_host_proto_init() }
func file_host_proto_init() {
	if File_host_proto != nil {
		return
	}
	if !protoimpl.UnsafeEnabled {
		file_host_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*Svr); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
	}
	type x struct{}
	out := protoimpl.TypeBuilder{
		File: protoimpl.DescBuilder{
			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
			RawDescriptor: file_host_proto_rawDesc,
			NumEnums:      0,
			NumMessages:   1,
			NumExtensions: 0,
			NumServices:   0,
		},
		GoTypes:           file_host_proto_goTypes,
		DependencyIndexes: file_host_proto_depIdxs,
		MessageInfos:      file_host_proto_msgTypes,
	}.Build()
	File_host_proto = out.File
	file_host_proto_rawDesc = nil
	file_host_proto_goTypes = nil
	file_host_proto_depIdxs = nil
}

--------------------------------------------------------------------------------
/test/host.proto:
--------------------------------------------------------------------------------
syntax = "proto3";
package host;
option go_package = "./";

message Svr {
  string Name = 1;
  string IP = 2;
  string Project = 3;
  string IDC = 4;
  string Cloud = 5;
}

--------------------------------------------------------------------------------
/tree.go:
--------------------------------------------------------------------------------
package yee

import (
	"bytes"
	"net/url"
	"strings"
	"unicode"
	"unicode/utf8"
)

// Byte forms of the two wildcard markers, counted by countParams.
var (
	strColon = []byte(":")
	strStar  = []byte("*")
)

// Param is a single URL parameter, consisting of a key and a value.
type Param struct {
	Key   string
	Value string
}

// Params is a Param-slice, as returned by the router.
// The slice is ordered, the first URL parameter is also the first slice value.
// It is therefore safe to read values by the index.
type Params []Param

// Get returns the value of the first Param whose key matches the given name
// and reports whether such a Param was found. If no matching Param is found,
// it returns an empty string and false.
func (ps Params) Get(name string) (string, bool) {
	for _, entry := range ps {
		if entry.Key == name {
			return entry.Value, true
		}
	}
	return "", false
}

// ByName returns the value of the first Param whose key matches the given name.
// If no matching Param is found, an empty string is returned.
40 | func (ps Params) ByName(name string) (va string) { 41 | va, _ = ps.Get(name) 42 | return 43 | } 44 | 45 | type methodTree struct { 46 | method string 47 | root *node 48 | } 49 | 50 | type methodTrees []methodTree 51 | 52 | func (trees methodTrees) get(method string) *node { 53 | for _, tree := range trees { 54 | if tree.method == method { 55 | return tree.root 56 | } 57 | } 58 | return nil 59 | } 60 | 61 | func min(a, b int) int { 62 | if a <= b { 63 | return a 64 | } 65 | return b 66 | } 67 | 68 | func longestCommonPrefix(a, b string) int { 69 | i := 0 70 | max := min(len(a), len(b)) 71 | for i < max && a[i] == b[i] { 72 | i++ 73 | } 74 | return i 75 | } 76 | 77 | func countParams(path string) uint16 { 78 | var n uint16 79 | s := StringToBytes(path) 80 | n += uint16(bytes.Count(s, strColon)) 81 | n += uint16(bytes.Count(s, strStar)) 82 | return n 83 | } 84 | 85 | type nodeType uint8 86 | 87 | const ( 88 | static nodeType = iota // default 89 | root 90 | param 91 | catchAll 92 | ) 93 | 94 | type node struct { 95 | path string 96 | indices string 97 | wildChild bool 98 | nType nodeType 99 | priority uint32 100 | children []*node 101 | handlers HandlersChain 102 | fullPath string 103 | } 104 | 105 | // Increments priority of the given child and reorders if necessary 106 | func (n *node) incrementChildPrio(pos int) int { 107 | cs := n.children 108 | cs[pos].priority++ 109 | prio := cs[pos].priority 110 | 111 | // Adjust position (move to front) 112 | newPos := pos 113 | for ; newPos > 0 && cs[newPos-1].priority < prio; newPos-- { 114 | // Swap node positions 115 | cs[newPos-1], cs[newPos] = cs[newPos], cs[newPos-1] 116 | 117 | } 118 | 119 | // Build new index char string 120 | if newPos != pos { 121 | n.indices = n.indices[:newPos] + // Unchanged prefix, might be empty 122 | n.indices[pos:pos+1] + // The index char we move 123 | n.indices[newPos:pos] + n.indices[pos+1:] // Rest without char at 'pos' 124 | } 125 | 126 | return newPos 127 | } 128 | 129 | // 
// addRoute adds a node with the given handle to the path.
// Not concurrency-safe!
//
// It walks the compressed prefix tree rooted at n. Where the new path
// diverges from an existing node's path, it splits that edge. The remainder
// is handed to insertChild. It panics if the path conflicts with an existing
// wildcard, or if handlers are already registered for the exact path.
func (n *node) addRoute(path string, handlers HandlersChain) {
	fullPath := path
	// Every route registered through a node raises its priority.
	n.priority++

	// Empty tree
	if len(n.path) == 0 && len(n.children) == 0 {
		n.insertChild(path, fullPath, handlers)
		n.nType = root
		return
	}

	// Length of fullPath already consumed by the ancestors of n.
	parentFullPathIndex := 0

walk:
	for {
		// Find the longest common prefix.
		// This also implies that the common prefix contains no ':' or '*'
		// since the existing key can't contain those chars.
		i := longestCommonPrefix(path, n.path)

		// Split edge: move the non-shared tail of n.path into a new child
		// that inherits n's children, handlers and wildcard flag.
		if i < len(n.path) {
			child := node{
				path:      n.path[i:],
				wildChild: n.wildChild,
				indices:   n.indices,
				children:  n.children,
				handlers:  n.handlers,
				priority:  n.priority - 1,
				fullPath:  n.fullPath,
			}

			n.children = []*node{&child}
			// []byte for proper unicode char conversion, see #65
			n.indices = BytesToString([]byte{n.path[i]})
			n.path = path[:i]
			n.handlers = nil
			n.wildChild = false
			n.fullPath = fullPath[:parentFullPathIndex+i]
		}

		// Make new node a child of this node
		if i < len(path) {
			path = path[i:]

			if n.wildChild {
				parentFullPathIndex += len(n.path)
				n = n.children[0]
				n.priority++

				// Check if the wildcard matches
				if len(path) >= len(n.path) && n.path == path[:len(n.path)] &&
					// Adding a child to a catchAll is not possible
					n.nType != catchAll &&
					// Check for longer wildcard, e.g. :name and :names
					(len(n.path) >= len(path) || path[len(n.path)] == '/') {
					continue walk
				}

				// The new path's segment clashes with the existing wildcard:
				// build a descriptive message and panic.
				pathSeg := path
				if n.nType != catchAll {
					pathSeg = strings.SplitN(path, "/", 2)[0]
				}
				prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path
				panic("'" + pathSeg +
					"' in new path '" + fullPath +
					"' conflicts with existing wildcard '" + n.path +
					"' in existing prefix '" + prefix +
					"'")
			}

			c := path[0]

			// slash after param
			if n.nType == param && c == '/' && len(n.children) == 1 {
				parentFullPathIndex += len(n.path)
				n = n.children[0]
				n.priority++
				continue walk
			}

			// Check if a child with the next path byte exists
			for i, max := 0, len(n.indices); i < max; i++ {
				if c == n.indices[i] {
					parentFullPathIndex += len(n.path)
					// The child may move forward; follow it to its new slot.
					i = n.incrementChildPrio(i)
					n = n.children[i]
					continue walk
				}
			}

			// Otherwise insert it
			if c != ':' && c != '*' {
				// []byte for proper unicode char conversion, see #65
				n.indices += BytesToString([]byte{c})
				child := &node{
					fullPath: fullPath,
				}
				n.children = append(n.children, child)
				n.incrementChildPrio(len(n.indices) - 1)
				n = child
			}
			n.insertChild(path, fullPath, handlers)
			return
		}

		// Otherwise add handle to current node
		if n.handlers != nil {
			panic("handlers are already registered for path '" + fullPath + "'")
		}
		n.handlers = handlers
		n.fullPath = fullPath
		return
	}
}

// findWildcard searches path for the first wildcard segment, i.e. a segment
// starting with ':' (param) or '*' (catch-all), and checks its name for
// invalid characters. It returns the wildcard, its start index, and whether
// the name is free of further ':' or '*'. It returns -1 as index if no
// wildcard was found.
func findWildcard(path string) (wildcard string, i int, valid bool) {
	// Find start
	for start, c := range []byte(path) {
		// A wildcard starts with ':' (param) or '*' (catch-all)
		if c != ':' && c != '*' {
			continue
		}

		// Find end and check for invalid characters
		valid = true
		for end, c := range []byte(path[start+1:]) {
			switch c {
			case '/':
				return path[start : start+1+end], start, valid
			case ':', '*':
				// A second marker inside the same segment makes it invalid.
				valid = false
			}
		}
		// The wildcard runs to the end of path.
		return path[start:], start, valid
	}
	return "", -1, false
}

// insertChild stores path under n. It creates a param node for every ':'
// wildcard and a pair of catch-all nodes for a trailing '*' wildcard, then
// attaches handlers to the final node. It panics on:
//   - malformed wildcards (unnamed, or more than one per segment);
//   - wildcards that would shadow existing children;
//   - a catch-all that is not the last segment or is not preceded by '/'.
func (n *node) insertChild(path string, fullPath string, handlers HandlersChain) {
	for {
		// Find prefix until first wildcard
		wildcard, i, valid := findWildcard(path)
		if i < 0 { // No wildcard found
			break
		}

		// The wildcard name must not contain ':' and '*'
		if !valid {
			panic("only one wildcard per path segment is allowed, has: '" +
				wildcard + "' in path '" + fullPath + "'")
		}

		// check if the wildcard has a name
		if len(wildcard) < 2 {
			panic("wildcards must be named with a non-empty name in path '" + fullPath + "'")
		}

		// Check if this node has existing children which would be
		// unreachable if we insert the wildcard here
		if len(n.children) > 0 {
			panic("wildcard segment '" + wildcard +
				"' conflicts with existing children in path '" + fullPath + "'")
		}

		if wildcard[0] == ':' { // param
			if i > 0 {
				// Insert prefix before the current wildcard
				n.path = path[:i]
				path = path[i:]
			}

			n.wildChild = true
			child := &node{
				nType:    param,
				path:     wildcard,
				fullPath: fullPath,
			}
			n.children = []*node{child}
			n = child
			n.priority++

			// if the path doesn't end with the wildcard, then there
			// will be another non-wildcard subpath starting with '/'
			if len(wildcard) < len(path) {
				path = path[len(wildcard):]

				child := &node{
					priority: 1,
					fullPath: fullPath,
				}
				n.children = []*node{child}
				n = child
				continue
			}

			// Otherwise we're done. Insert the handle in the new leaf
			n.handlers = handlers
			return
		}

		// catchAll
		if i+len(wildcard) != len(path) {
			panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'")
		}

		if len(n.path) > 0 && n.path[len(n.path)-1] == '/' {
			panic("catch-all conflicts with existing handle for the path segment root in path '" + fullPath + "'")
		}

		// currently fixed width 1 for '/'
		// NOTE(review): if the catch-all starts path (i == 0), i becomes -1
		// and path[i] panics with index out of range instead of the message
		// below — confirm whether callers can reach this.
		i--
		if path[i] != '/' {
			panic("no / before catch-all in path '" + fullPath + "'")
		}

		n.path = path[:i]

		// First node: catchAll node with empty path
		child := &node{
			wildChild: true,
			nType:     catchAll,
			fullPath:  fullPath,
		}

		n.children = []*node{child}
		n.indices = string('/')
		n = child
		n.priority++

		// second node: node holding the variable
		child = &node{
			path:     path[i:],
			nType:    catchAll,
			handlers: handlers,
			priority: 1,
			fullPath: fullPath,
		}
		n.children = []*node{child}

		return
	}

	// If no wildcard was found, simply insert the path and handle
	n.path = path
	n.handlers = handlers
	n.fullPath = fullPath
}

// nodeValue holds return values of the (*node).getValue method.
type nodeValue struct {
	handlers HandlersChain
	params   *Params
	tsr      bool // trailing-slash-redirect recommendation
	fullPath string
}

// Returns the handle registered with the given path (key). The values of
// wildcards are appended to params (a slice, not a map).
392 | // If no handle can be found, a TSR (trailing slash redirect) recommendation is 393 | // made if a handle exists with an extra (without the) trailing slash for the 394 | // given path. 395 | func (n *node) getValue(path string, params *Params, unescape bool) (value nodeValue) { 396 | walk: // Outer loop for walking the tree 397 | for { 398 | prefix := n.path 399 | if len(path) > len(prefix) { 400 | if path[:len(prefix)] == prefix { 401 | path = path[len(prefix):] 402 | // If this node does not have a wildcard (param or catchAll) 403 | // child, we can just look up the next child node and continue 404 | // to walk down the tree 405 | if !n.wildChild { 406 | idxc := path[0] 407 | for i, c := range []byte(n.indices) { 408 | if c == idxc { 409 | n = n.children[i] 410 | continue walk 411 | } 412 | } 413 | 414 | // Nothing found. 415 | // We can recommend to redirect to the same URL without a 416 | // trailing slash if a leaf exists for that path. 417 | value.tsr = (path == "/" && n.handlers != nil) 418 | return 419 | } 420 | 421 | // Handle wildcard child 422 | n = n.children[0] 423 | switch n.nType { 424 | case param: 425 | // Find param end (either '/' or path end) 426 | end := 0 427 | for end < len(path) && path[end] != '/' { 428 | end++ 429 | } 430 | // Save param value 431 | if params != nil { 432 | if value.params == nil { 433 | value.params = params 434 | } 435 | // Expand slice within preallocated capacity 436 | i := len(*value.params) 437 | *value.params = (*value.params)[:i+1] 438 | val := path[:end] 439 | if unescape { 440 | if v, err := url.QueryUnescape(val); err == nil { 441 | val = v 442 | } 443 | } 444 | (*value.params)[i] = Param{ 445 | Key: n.path[1:], 446 | Value: val, 447 | } 448 | } 449 | 450 | // we need to go deeper! 451 | if end < len(path) { 452 | if len(n.children) > 0 { 453 | path = path[end:] 454 | n = n.children[0] 455 | continue walk 456 | } 457 | 458 | // ... 
but we can't 459 | value.tsr = (len(path) == end+1) 460 | return 461 | } 462 | 463 | if value.handlers = n.handlers; value.handlers != nil { 464 | value.fullPath = n.fullPath 465 | return 466 | } 467 | if len(n.children) == 1 { 468 | // No handle found. Check if a handle for this path + a 469 | // trailing slash exists for TSR recommendation 470 | n = n.children[0] 471 | value.tsr = (n.path == "/" && n.handlers != nil) 472 | } 473 | return 474 | 475 | case catchAll: 476 | // Save param value 477 | if params != nil { 478 | if value.params == nil { 479 | value.params = params 480 | } 481 | // Expand slice within preallocated capacity 482 | i := len(*value.params) 483 | *value.params = (*value.params)[:i+1] 484 | val := path 485 | if unescape { 486 | if v, err := url.QueryUnescape(path); err == nil { 487 | val = v 488 | } 489 | } 490 | (*value.params)[i] = Param{ 491 | Key: n.path[2:], 492 | Value: val, 493 | } 494 | } 495 | 496 | value.handlers = n.handlers 497 | value.fullPath = n.fullPath 498 | return 499 | 500 | default: 501 | panic("invalid node type") 502 | } 503 | } 504 | } 505 | 506 | if path == prefix { 507 | // We should have reached the node containing the handle. 508 | // Check if this node has a handle registered. 509 | if value.handlers = n.handlers; value.handlers != nil { 510 | value.fullPath = n.fullPath 511 | return 512 | } 513 | 514 | // If there is no handle for this route, but this route has a 515 | // wildcard child, there must be a handle for this path with an 516 | // additional trailing slash 517 | if path == "/" && n.wildChild && n.nType != root { 518 | value.tsr = true 519 | return 520 | } 521 | 522 | // No handle found. 
Check if a handle for this path + a 523 | // trailing slash exists for trailing slash recommendation 524 | for i, c := range []byte(n.indices) { 525 | if c == '/' { 526 | n = n.children[i] 527 | value.tsr = (len(n.path) == 1 && n.handlers != nil) || 528 | (n.nType == catchAll && n.children[0].handlers != nil) 529 | return 530 | } 531 | } 532 | 533 | return 534 | } 535 | 536 | // Nothing found. We can recommend to redirect to the same URL with an 537 | // extra trailing slash if a leaf exists for that path 538 | value.tsr = (path == "/") || 539 | (len(prefix) == len(path)+1 && prefix[len(path)] == '/' && 540 | path == prefix[:len(prefix)-1] && n.handlers != nil) 541 | return 542 | } 543 | } 544 | 545 | // findCaseInsensitivePath makes a case-insensitive lookup of the given path and tries to find a handler. 546 | // It can optionally also fix trailing slashes. 547 | // It returns the case-corrected path and a bool indicating whether the lookup 548 | // was successful. 549 | func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) { 550 | const stackBufSize = 128 551 | 552 | // Preallocate stackBufSize bytes in the common case; if the path is too long, allocate a larger buffer. 553 | // NOTE(review): buf backs the returned slice, so it most likely escapes to the heap either way.
554 | buf := make([]byte, 0, stackBufSize) 555 | if l := len(path) + 1; l > stackBufSize { 556 | buf = make([]byte, 0, l) 557 | } 558 | 559 | ciPath := n.findCaseInsensitivePathRec( 560 | path, 561 | buf, // Preallocate enough memory for new path 562 | [4]byte{}, // Empty rune buffer 563 | fixTrailingSlash, 564 | ) 565 | 566 | return ciPath, ciPath != nil 567 | } 568 | 569 | // shiftNRuneBytes shifts the bytes of rb left by n positions, zero-filling the vacated tail; any n outside 0..3 yields an all-zero array. 570 | func shiftNRuneBytes(rb [4]byte, n int) [4]byte { 571 | switch n { 572 | case 0: 573 | return rb 574 | case 1: 575 | return [4]byte{rb[1], rb[2], rb[3], 0} 576 | case 2: 577 | return [4]byte{rb[2], rb[3]} 578 | case 3: 579 | return [4]byte{rb[3]} 580 | default: 581 | return [4]byte{} 582 | } 583 | } 584 | 585 | // findCaseInsensitivePathRec is the recursive case-insensitive lookup used by n.findCaseInsensitivePath; it appends matched segments to ciPath and returns it, or nil if nothing matches (rb carries the bytes of a rune left unfinished by a previous node). 586 | func (n *node) findCaseInsensitivePathRec(path string, ciPath []byte, rb [4]byte, fixTrailingSlash bool) []byte { 587 | npLen := len(n.path) 588 | 589 | walk: // Outer loop for walking the tree 590 | for len(path) >= npLen && (npLen == 0 || strings.EqualFold(path[1:npLen], n.path[1:])) { 591 | // Add common prefix to result 592 | oldPath := path 593 | path = path[npLen:] 594 | ciPath = append(ciPath, n.path...) 595 | 596 | if len(path) > 0 { 597 | // If this node does not have a wildcard (param or catchAll) child, 598 | // we can just look up the next child node and continue to walk down 599 | // the tree 600 | if !n.wildChild { 601 | // Skip rune bytes already processed 602 | rb = shiftNRuneBytes(rb, npLen) 603 | 604 | if rb[0] != 0 { 605 | // Old rune not finished 606 | idxc := rb[0] 607 | for i, c := range []byte(n.indices) { 608 | if c == idxc { 609 | // continue with child node 610 | n = n.children[i] 611 | npLen = len(n.path) 612 | continue walk 613 | } 614 | } 615 | } else { 616 | // Process a new rune 617 | var rv rune 618 | 619 | // Find rune start. 620 | // Runes are up to 4 bytes long, 621 | // -4 would definitely be another rune.
622 | var off int 623 | for max := min(npLen, 3); off < max; off++ { 624 | if i := npLen - off; utf8.RuneStart(oldPath[i]) { 625 | // Read the rune from the cached (pre-slice) path 626 | rv, _ = utf8.DecodeRuneInString(oldPath[i:]) 627 | break 628 | } 629 | } 630 | 631 | // Calculate lowercase bytes of current rune 632 | lo := unicode.ToLower(rv) 633 | utf8.EncodeRune(rb[:], lo) 634 | 635 | // Skip already processed bytes 636 | rb = shiftNRuneBytes(rb, off) 637 | 638 | idxc := rb[0] 639 | for i, c := range []byte(n.indices) { 640 | // Lowercase matches 641 | if c == idxc { 642 | // Must use a recursive approach since both the 643 | // uppercase byte and the lowercase byte might exist 644 | // as an index 645 | if out := n.children[i].findCaseInsensitivePathRec( 646 | path, ciPath, rb, fixTrailingSlash, 647 | ); out != nil { 648 | return out 649 | } 650 | break 651 | } 652 | } 653 | 654 | // If we found no match, try the same for the uppercase rune, 655 | // if it differs 656 | if up := unicode.ToUpper(rv); up != lo { 657 | utf8.EncodeRune(rb[:], up) 658 | rb = shiftNRuneBytes(rb, off) 659 | 660 | idxc := rb[0] 661 | for i, c := range []byte(n.indices) { 662 | // Uppercase matches 663 | if c == idxc { 664 | // Continue with child node 665 | n = n.children[i] 666 | npLen = len(n.path) 667 | continue walk 668 | } 669 | } 670 | } 671 | } 672 | 673 | // Nothing found. We can recommend to redirect to the same URL 674 | // without a trailing slash if a leaf exists for that path 675 | if fixTrailingSlash && path == "/" && n.handlers != nil { 676 | return ciPath 677 | } 678 | return nil 679 | } 680 | 681 | n = n.children[0] 682 | switch n.nType { 683 | case param: 684 | // Find param end (either '/' or path end) 685 | end := 0 686 | for end < len(path) && path[end] != '/' { 687 | end++ 688 | } 689 | 690 | // Add param value to case insensitive path 691 | ciPath = append(ciPath, path[:end]...) 692 | 693 | // We need to go deeper!
694 | if end < len(path) { 695 | if len(n.children) > 0 { 696 | // Continue with child node 697 | n = n.children[0] 698 | npLen = len(n.path) 699 | path = path[end:] 700 | continue 701 | } 702 | 703 | // ... but we can't. If only the trailing '/' remains, suggest the path without it. 704 | if fixTrailingSlash && len(path) == end+1 { 705 | return ciPath 706 | } 707 | return nil 708 | } 709 | 710 | if n.handlers != nil { 711 | return ciPath 712 | } 713 | 714 | if fixTrailingSlash && len(n.children) == 1 { 715 | // No handle found. Check if a handle for this path + a 716 | // trailing slash exists 717 | n = n.children[0] 718 | if n.path == "/" && n.handlers != nil { 719 | return append(ciPath, '/') 720 | } 721 | } 722 | 723 | return nil 724 | 725 | case catchAll: 726 | return append(ciPath, path...) 727 | 728 | default: 729 | panic("invalid node type") 730 | } 731 | } else { 732 | // We should have reached the node containing the handle. 733 | // Check if this node has a handle registered. 734 | if n.handlers != nil { 735 | return ciPath 736 | } 737 | 738 | // No handle found. 739 | // Try to fix the path by adding a trailing slash 740 | if fixTrailingSlash { 741 | for i, c := range []byte(n.indices) { 742 | if c == '/' { 743 | n = n.children[i] 744 | if (len(n.path) == 1 && n.handlers != nil) || 745 | (n.nType == catchAll && n.children[0].handlers != nil) { 746 | return append(ciPath, '/') 747 | } 748 | return nil 749 | } 750 | } 751 | } 752 | return nil 753 | } 754 | } 755 | 756 | // Nothing found. 757 | // Try to fix the path by adding / removing a trailing slash 758 | if fixTrailingSlash { 759 | if path == "/" { 760 | return ciPath 761 | } 762 | if len(path)+1 == npLen && n.path[len(path)] == '/' && 763 | strings.EqualFold(path[1:], n.path[1:len(path)]) && n.handlers != nil { 764 | return append(ciPath, n.path...)
765 | } 766 | } 767 | return nil 768 | } 769 | 770 | -------------------------------------------------------------------------------- /yee.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | import ( 4 | "fmt" 5 | "github.com/cookieY/yee/logger" 6 | "golang.org/x/net/http2" 7 | "golang.org/x/net/http2/h2c" 8 | "io" 9 | "log" 10 | "net/http" 11 | "os" 12 | "sync" 13 | ) 14 | 15 | // HandlerFunc defines a request handler that operates on a Context and returns an error 16 | type HandlerFunc func(Context) (err error) 17 | 18 | // HandlersChain is an ordered slice of HandlerFunc values 19 | type HandlersChain []HandlerFunc 20 | 21 | // Core is the framework engine; it implements http.Handler through ServeHTTP 22 | type Core struct { 23 | *Router 24 | trees methodTrees 25 | pool sync.Pool 26 | maxParams uint16 27 | HandleMethodNotAllowed bool 28 | H2server *http.Server 29 | allNoRoute HandlersChain 30 | allNoMethod HandlersChain 31 | noRoute HandlersChain 32 | noMethod HandlersChain 33 | l logger.Logger 34 | color *logger.Color 35 | bind DefaultBinder 36 | RedirectTrailingSlash bool 37 | RedirectFixedPath bool 38 | Banner bool 39 | } 40 | 41 | const version = "yee v0.5.1" 42 | 43 | const creator = "Creator: Henry Yee" 44 | const title = "-----Easier and Faster-----" 45 | 46 | const banner = ` 47 | __ __ 48 | _ \/ /_________ 49 | __ /_ _ \ _ \ 50 | _ / / __/ __/ 51 | /_/ \___/\___/ %s 52 | %s 53 | %s 54 | ` 55 | 56 | // New creates a Core via C and prints the startup banner through the core's logger 57 | func New() *Core { 58 | core := C() 59 | core.l.Custom(fmt.Sprintf(banner, logger.Green(version), logger.Red(title), logger.Cyan(creator))) 60 | return core 61 | } 62 | // C creates and initializes a Core (root router, method trees, logger, binder, context pool) without printing the banner 63 | func C() *Core { 64 | router := &Router{ 65 | handlers: nil, 66 | root: true, 67 | basePath: "/", 68 | } 69 | 70 | core := &Core{ 71 | trees: make(methodTrees, 0, 0), 72 | Router: router, 73 | l: logger.LogCreator(), 74 | bind: DefaultBinder{}, 75 | } 76 | 77 | core.core = core 78 | 79 | core.pool.New = func() interface{} { 80 | return core.allocateContext() 81 |
} 82 | return core 83 | } 84 | 85 | // SetLogLevel sets the level of the core's logger 86 | func (c *Core) SetLogLevel(l uint8) { 87 | c.l.SetLevel(l) 88 | } 89 | // SetLogOut redirects the core's logger output to out 90 | func (c *Core) SetLogOut(out io.Writer) { 91 | c.l.SetOut(out) 92 | } 93 | // allocateContext builds a context whose params slice has capacity c.maxParams; it is the context pool's New function 94 | func (c *Core) allocateContext() *context { 95 | v := make(Params, 0, c.maxParams) 96 | return &context{engine: c, params: &v, index: -1} 97 | } 98 | 99 | // Use registers global middleware on the root router and then 100 | // rebuilds the noRoute/noMethod fallback chains (allNoRoute/allNoMethod), 101 | // so requests that match no route or method also pass through the middleware 102 | // (assuming combineHandlers prepends the router's handlers — not visible here). 103 | func (c *Core) Use(middleware ...HandlerFunc) { 104 | c.Router.Use(middleware...) 105 | c.rebuild404Handlers() 106 | c.rebuild405Handlers() 107 | 108 | } 109 | // rebuild404Handlers recomputes c.allNoRoute from c.noRoute via combineHandlers 110 | func (c *Core) rebuild404Handlers() { 111 | c.allNoRoute = c.combineHandlers(c.noRoute) 112 | } 113 | // rebuild405Handlers recomputes c.allNoMethod from c.noMethod via combineHandlers 114 | func (c *Core) rebuild405Handlers() { 115 | c.allNoMethod = c.combineHandlers(c.noMethod) 116 | } 117 | 118 | // ServeHTTP implements http.Handler; every request/response passes through here. 119 | // Contexts are recycled through a sync.Pool to reduce allocations. 120 | // A context is fetched from the pool (allocateContext creates new ones), 121 | // its response writer and request are reset, the request is dispatched 122 | // via c.handleHTTPRequest, and the context is finally put back into the pool. 123 | // Because the context is reused afterwards, handlers must not retain it 124 | // (e.g. in goroutines) once ServeHTTP returns. 125 | func (c *Core) ServeHTTP(w http.ResponseWriter, r *http.Request) { 126 | context := c.pool.Get().(*context) 127 | context.writermem.reset(w) 128 | context.r = r 129 | context.reset() 130 | //context.w.Header().Set(HeaderServer, serverName) 131 | c.handleHTTPRequest(context) 132 | c.pool.Put(context) 133 | } 134 | 135 | // NewContext builds a standalone context for r and w, intended for tests; unlike ServeHTTP it allocates a new context instead of using the pool 136 | func (c *Core) NewContext(r *http.Request, w http.ResponseWriter) Context { 137 | context := new(context) 138 | context.writermem.reset(w) 139 | context.w = &context.writermem 140 | context.r = r 141 | context.engine = c 142 | return context 143 | } 144 | 145 | // Run serves HTTP on addr; if the server fails, the error is logged as critical and the process exits with status 1 146 | func (c *Core) Run(addr string) { 147 | if
err := http.ListenAndServe(addr, c); err != nil { 148 | c.l.Critical(err.Error()) 149 | os.Exit(1) 150 | } 151 | } 152 | 153 | // RunTLS serves HTTPS on addr using certFile and keyFile. 154 | // net/http negotiates HTTP/2 when the client supports it and otherwise falls back to HTTP/1.1. 155 | // If the server fails, the error is logged as critical and the process exits with status 1. 156 | func (c *Core) RunTLS(addr, certFile, keyFile string) { 157 | if err := http.ListenAndServeTLS(addr, certFile, keyFile, c); err != nil { 158 | c.l.Critical(err.Error()) 159 | os.Exit(1) 160 | } 161 | } 162 | 163 | // RunH2C serves cleartext HTTP/2 (h2c) on addr. 164 | // HTTP/2 normally runs over TLS; h2c is HTTP/2 without a certificate. 165 | // Unlike Run and RunTLS, a server failure is reported via log.Fatal rather than the core's logger. 166 | // Notes: 167 | // 1. browsers do not support h2c, so test it with a dedicated client program; 168 | // 2. h2c traffic is unencrypted and therefore not secure. 169 | func (c *Core) RunH2C(addr string) { 170 | s := &http2.Server{} 171 | h1s := &http.Server{ 172 | Addr: addr, 173 | Handler: h2c.NewHandler(c, s), 174 | } 175 | log.Fatal(h1s.ListenAndServe()) 176 | } 177 | 178 | //func (c *Core) RunH3(addr, ca, keyFile string) { 179 | // //if isTCP { 180 | // log.Fatal(http3.ListenAndServe(addr, ca, keyFile, c)) 181 | // //log.Fatal(s.ListenAndServeTLS(ca, keyFile)) 182 | //} 183 | 184 | func (c *Core) handleHTTPRequest(context *context) { 185 | httpMethod := context.r.Method 186 | rPath := context.r.URL.Path 187 | unescape := false 188 | // Find root of the tree for the given HTTP method 189 | t := c.trees 190 | for i, tl := 0, len(t); i < tl; i++ { 191 | 192 | if t[i].method != httpMethod { 193 | continue 194 | } 195 | 196 | root := t[i].root 197 | // Find route in tree 198 | value := root.getValue(rPath, context.params, unescape) 199 | if value.params != nil { 200 | context.Param = *value.params 201 | } 202 | if value.handlers != nil { 203 | context.handlers = value.handlers 204 | context.path = value.fullPath 205 | context.Next() 206 | context.writermem.WriteHeaderNow() 207 | return 208 | } 209 | // No route matched in this method's tree; stop searching. NOTE(review): value.tsr is ignored, so RedirectTrailingSlash/RedirectFixedPath have no effect here. 210 | break 211 | } 212 | // No matching route (or no tree for this method): fall back to the noRoute chain. 213 | context.handlers = c.allNoRoute 214 |
// Notice: 216 | // an unmatched request is checked for the OPTIONS method, 217 | // because browsers send an OPTIONS preflight request before complex (e.g. XMLHttpRequest) requests, 218 | // yet usually no OPTIONS handler is registered; answering 404 there 219 | // could break middleware (e.g. CORS), so such requests get 204 No Content instead. 220 | if httpMethod == http.MethodOptions { 221 | serveError(context, http.StatusNoContent, nil) 222 | } else { 223 | serveError(context, http.StatusNotFound, []byte("404 NOT FOUND")) 224 | } 225 | } 226 | // serveError runs c's handler chain with the status preset to code; if nothing was written and the status is still code, it writes defaultMessage as text/plain, otherwise it flushes the header. 227 | func serveError(c *context, code int, defaultMessage []byte) { 228 | c.writermem.status = code 229 | c.Next() 230 | if c.writermem.Written() { 231 | return 232 | } 233 | if c.writermem.Status() == code { 234 | c.writermem.Header()["Content-Type"] = []string{MIMETextPlain} 235 | _, err := c.w.Write(defaultMessage) 236 | if err != nil { 237 | c.engine.l.Error(fmt.Sprintf("cannot write message to writer during serve error: %v", err)) 238 | } 239 | return 240 | } 241 | c.writermem.WriteHeaderNow() 242 | } 243 | -------------------------------------------------------------------------------- /yee_test.go: -------------------------------------------------------------------------------- 1 | package yee 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | ) 7 | // indexHandle responds with the JSON string "ok". 8 | func indexHandle(c Context) (err error) { 9 | return c.JSON(http.StatusOK, "ok") 10 | } 11 | // addRouter registers indexHandle on GET /. 12 | func addRouter(y *Core) { 13 | y.GET("/", indexHandle) 14 | } 15 | // TestYee starts a server on :9999. NOTE(review): Run blocks, so this never returns under automated test runs. 16 | func TestYee(t *testing.T) { 17 | y := New() 18 | addRouter(y) 19 | y.Run(":9999") 20 | } 21 | // TestRestApi only registers a RESTful route; it makes no requests and no assertions. 22 | func TestRestApi(t *testing.T) { 23 | y := New() 24 | y.Restful("/", RestfulAPI{ 25 | Get: func(c Context) (err error) { 26 | return c.String(http.StatusOK, "updated") 27 | }, 28 | Post: func(c Context) (err error) { 29 | return c.String(http.StatusOK, "get it") 30 | }, 31 | }) 32 | } 33 | // TestDownload serves args.go at GET / and blocks serving on :9999. 34 | func TestDownload(t *testing.T) { 35 | y := New() 36 | y.GET("/", func(c Context) (err error) { 37 | return c.File("args.go") 38 | }) 39 |
y.Run(":9999") 40 | } 41 | // TestStatic serves the "dist" directory at / and blocks serving on :9999. 42 | func TestStatic(t *testing.T) { 43 | y := New() 44 | y.Static("/", "dist") 45 | //y.GET("/", func(c Context) error { 46 | // return c.HTMLTpl(http.StatusOK, "./dist/index.html") 47 | //}) 48 | y.Run(":9999") 49 | } 50 | 51 | ////go:embed dist/* 52 | //var f embed.FS 53 | // 54 | ////go:embed dist/index.html 55 | //var index string 56 | 57 | //func TestPack(t *testing.T) { 58 | // y := New() 59 | // y.Pack("/front", f, "dist") 60 | // y.GET("/", func(c Context) error { 61 | // return c.HTML(http.StatusOK, index) 62 | // }) 63 | // y.Run(":9999") 64 | //} 65 | 66 | const ver = `alt-svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000,h3-Q050=":443"; ma=2592000,h3-Q046=":443"; ma=2592000,h3-Q043=":443"; ma=2592000,quic=":443"; ma=2592000; v="46,43"` 67 | // TestH3 serves GET / via RunH3 on :443 (presumably HTTP/3; RunH3 is defined outside this file) using local certificate files; it blocks. 68 | func TestH3(t *testing.T) { 69 | y := New() 70 | y.GET("/", func(c Context) (err error) { 71 | return c.JSON(http.StatusOK, "hello") 72 | }) 73 | y.RunH3(":443", "henry.com+4.pem", "henry.com+4-key.pem") 74 | } 75 | --------------------------------------------------------------------------------