├── .github └── workflows │ └── run-tests.yaml ├── .gitignore ├── LICENSE ├── README.md ├── bench_test.go ├── cmd └── example │ ├── api.http │ ├── go.mod │ ├── go.sum │ └── main.go ├── errors.go ├── go.mod ├── handler.go ├── handler_test.go ├── method.go ├── method_test.go ├── request_reader.go ├── request_reader_test.go ├── response_writer.go ├── restruct.go ├── router.go ├── structtag ├── structtag.go └── structtag_test.go ├── util.go └── util_test.go /.github/workflows/run-tests.yaml: -------------------------------------------------------------------------------- 1 | name: Run Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | workflow_dispatch: 9 | 10 | jobs: 11 | 12 | build: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v2 16 | 17 | - name: Set up Go 18 | uses: actions/setup-go@v2 19 | with: 20 | go-version: 1.16 21 | 22 | - name: Test 23 | run: go test -cover -v ./... -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.exe 2 | *.exe~ 3 | *.dll 4 | *.so 5 | *.dylib 6 | 7 | # Test binary, built with `go test -c` 8 | *.test 9 | 10 | # Output of the go coverage tool, specifically when used with LiteIDE 11 | *.out 12 | 13 | # Dependency directories (remove the comment below to include it) 14 | # vendor/ 15 | 16 | # Go workspace file 17 | go.work 18 | 19 | __debug_bin -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Altlimit LLC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Run Tests](https://github.com/altlimit/restruct/actions/workflows/run-tests.yaml/badge.svg) 2 | 3 | # restruct 4 | 5 | RESTruct is a go rest framework based on structs. The goal of this project is to automate routing, request and response based on struct methods. 6 | 7 | --- 8 | * [Install](#install) 9 | * [Router](#router) 10 | * [Examples](#examples) 11 | * [Response Writer](#response-writer) 12 | * [Request Reader](#request-reader) 13 | * [Middleware](#middleware) 14 | * [Nested Structs](#nested-structs) 15 | * [Utilities](#utilities) 16 | --- 17 | 18 | ## Install 19 | 20 | ```sh 21 | go get github.com/altlimit/restruct 22 | ``` 23 | 24 | ## Router 25 | 26 | Exported struct methods will be your handlers and will be routed like the following. 
27 | 28 | ``` 29 | UpperCase turns to upper-case 30 | With_Underscore to with/underscore 31 | HasParam_0 to has-param/{0} 32 | HasParam_0_AndMore_1 to has-param/{0}/and-more/{1} 33 | ``` 34 | 35 | There are multiple ways to process a request and a response, such as strongly typed parameters and returns or with `*http.Request` or `http.ResponseWriter` parameters. You can also use the `context.Context` parameter. Any other parameters will use the `DefaultReader` which you can override in your `Handler.Reader`. 36 | 37 | ```go 38 | type Calculator struct { 39 | } 40 | 41 | func (c *Calculator) Add(r *http.Request) interface{} { 42 | var req struct { 43 | A int64 `json:"a"` 44 | B int64 `json:"b"` 45 | } 46 | if err := restruct.Bind(r, &req, http.MethodPost); err != nil { 47 | return err 48 | } 49 | return req.A + req.B 50 | } 51 | 52 | func (c *Calculator) Subtract(a, b int64) int64 { 53 | return a - b 54 | } 55 | 56 | func (c *Calculator) Divide(a, b int64) (int64, error) { 57 | if b == 0 { 58 | return 0, errors.New("divide by 0") 59 | } 60 | return a / b, nil 61 | } 62 | 63 | func (c *Calculator) Multiply(r struct { 64 | A int64 `json:"a"` 65 | B int64 `json:"b"` 66 | }) int64 { 67 | return r.A * r.B 68 | } 69 | 70 | func main() { 71 | restruct.Handle("/api/v1/", &Calculator{}) 72 | http.ListenAndServe(":8080", nil) 73 | } 74 | ``` 75 | 76 | We have registered the `Calculator` struct here as our service and we should now have available endpoints which you can send json request and response to. 
77 | 78 | ```js 79 | // POST http://localhost:8080/api/v1/add 80 | { 81 | "a": 10, 82 | "b": 20 83 | } 84 | // -> 30 85 | // -> or any errors such as 400 {"error": "Bad Request"} 86 | 87 | // POST http://localhost:8080/api/v1/subtract 88 | // Since these parameters are not the request, response writer or context, 89 | // they are read from a JSON array body, which is the default behavior of DefaultReader 90 | [ 91 | 20, 92 | 10 93 | ] 94 | // -> 10 95 | 96 | // POST http://localhost:8080/api/v1/divide 97 | // Handlers can also be strongly typed in both their parameters and return types. 98 | // By default, when a handler has multiple return values and the last one is a non-nil error, DefaultWriter writes that error. 99 | [ 100 | 1, 101 | 0 102 | ] 103 | // -> 500 {"error":"Internal Server Error"} 104 | 105 | // POST http://localhost:8080/api/v1/multiply 106 | // With a single struct as a parameter, it will be similar to Add's implementation where it uses Bind internally to populate it. You can change your Bind with DefaultReader{Bind:...} to add your validation library. 107 | { 108 | "a": 2, 109 | "b": 5 110 | } 111 | // -> 10 112 | ``` 113 | 114 | You can override the default method-named routes using the `Router` interface. Implement `Router` in your service and return a slice of `Route`. 115 | 116 | ```go 117 | func (c *Calculator) Routes() []Route { 118 | return []Route{ 119 | Route{Handler: "Add", Path:"addition", Methods: []string{http.MethodPost}}, 120 | Route{Handler: "Subtract", Path:"subtraction", Methods: []string{http.MethodPost}}, 121 | } 122 | } 123 | ``` 124 | 125 | 126 | ## Examples 127 | 128 | Here are more ways to create handlers. 
129 | 130 | ```go 131 | type Blob struct { 132 | Internal bool 133 | } 134 | 135 | func (b *Blob) Routes() []Route { 136 | return []Route{ 137 | {Handler: "Download", Path: "blob/{path:.+}", Methods: []string{http.MethodGet}}, 138 | } 139 | } 140 | 141 | // Will be available at /blob/{path:.+} since we overwrite it in Routes 142 | // you can also avoid using regex by naming your handler with Blob_0Path and access with "0Path" params. 143 | func (b *Blob) Download(w http.ResponseWriter, r *http.Request) { 144 | path := restruct.Params(r)["path"] 145 | // handle your struct like normal 146 | } 147 | ``` 148 | 149 | Here we use the `Router` interface to add a regular expression. The path param on the Download route will accept anything, even additional nested path segments (`/`), and the handler has a standard `net/http` handler signature. 150 | 151 | To register the above service: 152 | 153 | ```go 154 | func main() { 155 | restruct.Handle("/api/v1/", &Blob{}) 156 | http.ListenAndServe(":8080", nil) 157 | } 158 | ``` 159 | 160 | You can create an additional service with a different prefix by calling `NewHandler` on your struct and then adding it with `AddService`. 161 | 162 | ```go 163 | h := restruct.NewHandler(&Blob{}) 164 | h.AddService("/internal/{tag}/", &Blob{Internal: true}) 165 | restruct.Handle("/api/v1/", h) 166 | ``` 167 | 168 | All your services will now be at `/api/v1/internal/{tag}`. You can also register the returned Handler in a third party router but make sure you call `WithPrefix(...)` on it if it's not a root route. 
169 | 170 | ```go 171 | http.Handle("/", h) 172 | // or if it's not a root route 173 | http.Handle("/api/v1/", h.WithPrefix("/api/v1/")) 174 | ``` 175 | 176 | You can add path parameters by using numbers in the method name and access them using `restruct.Params(req)` or `restruct.Vars(ctx)`: 177 | 178 | ```go 179 | // Will be available at /upload/{0} 180 | func (b *Blob) Upload_0(r *http.Request) interface{} { 181 | uploadType := restruct.Params(r)["0"] 182 | // handle your request normally 183 | fileID := ... 184 | return fileID 185 | } 186 | ``` 187 | 188 | Refer to cmd/example for some advanced usage. 189 | 190 | ## Response Writer 191 | 192 | The default `ResponseWriter` is `DefaultWriter` which uses json.Encoder().Encode to write outputs. This also handles errors and status codes. You can modify the output by implementing the ResponseWriter interface and setting it in your `Handler.Writer`. 193 | 194 | ```go 195 | type TextWriter struct {} 196 | 197 | func (tw *TextWriter) Write(w http.ResponseWriter, r *http.Request, types []reflect.Type, vals []reflect.Value) { 198 | // types - slice of return types 199 | // vals - slice of actual returned values 200 | // in this writer we simply write anything returned as text 201 | var out []interface{} 202 | for _, val := range vals { 203 | out = append(out, val.Interface()) 204 | } 205 | w.Header().Set("Content-Type", "text/plain") 206 | w.WriteHeader(http.StatusOK) 207 | w.Write([]byte(fmt.Sprintf("%v", out))) 208 | } 209 | 210 | h := restruct.NewHandler(&Blob{}) 211 | h.Writer = &TextWriter{} 212 | ``` 213 | 214 | ## Request Reader 215 | 216 | A handler can have any number of parameters, or none. The parameters that don't go through the request reader are `context.Context`, `*http.Request` and `http.ResponseWriter`; these are not passed to `RequestReader.Read`. 
217 | 218 | ```go 219 | // use form for urlencoded post 220 | type login struct { 221 | Username string `json:"username" form:"username"` 222 | Password string `json:"password" form:"password"` 223 | } 224 | 225 | func (b *Blob) Login(l *login) interface{} { 226 | log.Println("Login", l.Username, l.Password) 227 | return "OK" 228 | } 229 | ``` 230 | 231 | This uses the `DefaultReader`, which by default can unmarshal a single struct using the default bind (`restruct.Bind`); you can use your own Bind with `DefaultReader{Bind:yourBinder}` if you want to add validation libraries. Bind decodes the body as JSON, or reads form values. If you have multiple parameters you will need to send a JSON array body. 232 | 233 | ```json 234 | [ 235 | "FirstParam", 236 | 2, 237 | {"third":"param"} 238 | ] 239 | ``` 240 | 241 | This is the default behaviour of `DefaultReader`. You can implement the `RequestReader` interface, which allows you to control your own parameter parsing. 242 | 243 | ```go 244 | type CustomReader struct {} 245 | func (cr *CustomReader) Read(r *http.Request, types []reflect.Type) (vals []reflect.Value, err error) { 246 | // types are your handler's parameter types, in order; you must return the same number of vals as types. 247 | // You'll only get types that are not *http.Request, http.ResponseWriter, context.Context 248 | // You can return Error{} type here to return ResponseWriter errors/response and wrap your errors inside Error{Err:...} 249 | return 250 | } 251 | 252 | ``` 253 | ## Middleware 254 | 255 | Middleware uses the standard `func(http.Handler) http.Handler` signature; add it with `handler.Use(...)`, or attach it to a `Route` when using the `Router` interface. 
256 | 257 | ```go 258 | func auth(next http.Handler) http.Handler { 259 | // you can use your h.Writer here if it's accessible somewhere 260 | wr := restruct.DefaultWriter{} 261 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 262 | if r.Header.Get("Authorization") != "abc" { 263 | wr.WriteJSON(w, restruct.Error{Status: http.StatusUnauthorized}) 264 | return 265 | } 266 | next.ServeHTTP(w, r) 267 | }) 268 | } 269 | 270 | h := restruct.NewHandler(&Blob{}) 271 | h.Use(auth) 272 | ``` 273 | 274 | ## Nested Structs 275 | 276 | Nested structs are automatically routed. You can use the `route` tag to customize the path, or add `route:"-"` to skip an exported struct field. 277 | 278 | ```go 279 | type ( 280 | V1 struct { 281 | Users User 282 | DB DB `route:"-"` 283 | } 284 | 285 | User struct { 286 | 287 | } 288 | ) 289 | 290 | func (v *V1) Drop() {} 291 | func (u *User) SendEmail() {} 292 | 293 | func main() { 294 | restruct.Handle("/api/v1/", &V1{}) 295 | http.ListenAndServe(":8080", nil) 296 | } 297 | ``` 298 | 299 | This generates the routes `/api/v1/drop` and `/api/v1/users/send-email`. 300 | 301 | ## Utilities 302 | 303 | Helper utilities for processing requests and responses. 304 | 305 | ```go 306 | // Adding context values in middleware such as logged in userID 307 | auth := r.Header.Get("Authorization") // some key or JWT 308 | if userID, ok := UserIDFromAuth(auth); ok { 309 | r = restruct.SetValue(r, "userID", userID) 310 | } 311 | // then access it from anywhere or a private method for getting your user record 312 | if userID, ok := restruct.GetValue(r, "userID").(int64); ok { 313 | user, err := DB.GetUserByID(ctx, userID) 314 | // do something with user 315 | } 316 | 317 | // Bind helps read your json and form requests into a struct, you can add tag "query" 318 | // to bind query strings at the same time. You can also add tag "form" to bind form posts from 319 | // urlencoded or multipart. You can also use explicit functions BindQuery or BindForm. 
320 | var loginReq struct { 321 | Username string `json:"username"` 322 | Password string `json:"password"` 323 | } 324 | if err := restruct.Bind(r, &loginReq, http.MethodPost); err != nil { 325 | return err 326 | } 327 | 328 | // Reading path parameters with Params /products/{0} 329 | params := restruct.Params(r) 330 | productID := params["0"] 331 | ``` 332 | 333 | ## License 334 | 335 | MIT -------------------------------------------------------------------------------- /bench_test.go: -------------------------------------------------------------------------------- 1 | package restruct 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | ) 7 | 8 | type testService struct{} 9 | 10 | func (ts *testService) Hello(r *http.Request) {} 11 | 12 | type testService2 struct{} 13 | 14 | func (ts *testService2) Hello_0(r *http.Request) {} 15 | 16 | type testService3 struct{} 17 | 18 | func (ts *testService3) Hello(r *http.Request) {} 19 | 20 | func (ts *testService3) Routes() map[string]Route { 21 | return map[string]Route{"Hello": {Path: "{v1}/{v2}/{v3}/{v4}/{v5}"}} 22 | } 23 | 24 | type testService4 struct{} 25 | 26 | func (ts *testService4) Hello(r *http.Request) {} 27 | func (ts *testService4) Hello_0(r *http.Request) {} 28 | func (ts *testService4) World(r *http.Request) {} 29 | 30 | func (ts *testService4) Routes() map[string]Route { 31 | return map[string]Route{"Hello": {Path: "tags/{tag:.+}"}} 32 | } 33 | 34 | // goos: linux 35 | // goarch: amd64 36 | // pkg: github.com/altlimit/restruct 37 | // cpu: Intel(R) Core(TM) i7-3770K CPU @ 3.50GHz 38 | // BenchmarkHandlerStatic-8 2548689 425.5 ns/op 72 B/op 4 allocs/op 39 | // PASS 40 | // ok github.com/altlimit/restruct 1.569s 41 | func BenchmarkHandlerStatic(b *testing.B) { 42 | h := NewHandler(&testService{}) 43 | h.mustCompile("/api/v1") 44 | request, _ := http.NewRequest("GET", "/api/v1/hello", nil) 45 | for i := 0; i < b.N; i++ { 46 | h.ServeHTTP(nil, request) 47 | } 48 | } 49 | 50 | // goos: linux 51 | // goarch: amd64 
52 | // pkg: github.com/altlimit/restruct 53 | // cpu: Intel(R) Core(TM) i7-3770K CPU @ 3.50GHz 54 | // BenchmarkHandlerWithParam-8 1180539 983.0 ns/op 856 B/op 9 allocs/op 55 | // PASS 56 | // ok github.com/altlimit/restruct 2.055s 57 | func BenchmarkHandlerWithParam(b *testing.B) { 58 | h := NewHandler(&testService2{}) 59 | h.mustCompile("/api/v1") 60 | 61 | requestA, _ := http.NewRequest("GET", "/api/v1/hello/1", nil) 62 | for i := 0; i < b.N; i++ { 63 | h.ServeHTTP(nil, requestA) // NOTE(review): nil ResponseWriter; presumably safe only while the matched handler writes nothing — confirm 64 | } 65 | } 66 | 67 | // goos: linux 68 | // goarch: amd64 69 | // pkg: github.com/altlimit/restruct 70 | // cpu: Intel(R) Core(TM) i7-3770K CPU @ 3.50GHz 71 | // BenchmarkWithManyParams-8 5754241 209.3 ns/op 104 B/op 3 allocs/op 72 | // PASS 73 | // ok github.com/altlimit/restruct 1.425s 74 | func BenchmarkWithManyParams(b *testing.B) { 75 | h := NewHandler(&testService3{}) 76 | h.mustCompile("/api/v1") 77 | 78 | matchingRequest, _ := http.NewRequest("GET", "/api/v1/1/2/3/4/5", nil) 79 | for i := 0; i < b.N; i++ { 80 | h.ServeHTTP(nil, matchingRequest) 81 | } 82 | } 83 | 84 | // goos: linux 85 | // goarch: amd64 86 | // pkg: github.com/altlimit/restruct 87 | // cpu: Intel(R) Core(TM) i7-3770K CPU @ 3.50GHz 88 | // BenchmarkMixedHandler-8 463426 2287 ns/op 1152 B/op 21 allocs/op 89 | // PASS 90 | // ok github.com/altlimit/restruct 1.094s 91 | func BenchmarkMixedHandler(b *testing.B) { 92 | h := NewHandler(&testService4{}) 93 | h.mustCompile("/api/v1") 94 | 95 | matchingRequest, _ := http.NewRequest("GET", "/api/v1/tags/abc/123", nil) 96 | matchingRequest2, _ := http.NewRequest("GET", "/api/v1/hello/123", nil) 97 | matchingRequest3, _ := http.NewRequest("GET", "/api/v1/world", nil) 98 | notMatchingRequest, _ := http.NewRequest("GET", "/api/v1/world", nil) // NOTE(review): same path as matchingRequest3, so this request matches despite its name; a truly unmatched path would reach the not-found path with a nil ResponseWriter — confirm intent before changing 99 | for i := 0; i < b.N; i++ { 100 | h.ServeHTTP(nil, matchingRequest) 101 | h.ServeHTTP(nil, matchingRequest2) 102 | h.ServeHTTP(nil, matchingRequest3) 103 | h.ServeHTTP(nil, notMatchingRequest) 104 | } 105 | } 106 | 107 | // goos: 
linux 108 | // goarch: amd64 109 | // pkg: github.com/altlimit/restruct 110 | // cpu: Intel(R) Core(TM) i7-3770K CPU @ 3.50GHz 111 | // BenchmarkMatch-8 5562338 212.0 ns/op 336 B/op 2 allocs/op 112 | // PASS 113 | // ok github.com/altlimit/restruct 1.407s 114 | func BenchmarkMatch(b *testing.B) { 115 | m := &method{path: "catch/{all}"} 116 | m.mustParse() 117 | for i := 0; i < b.N; i++ { 118 | matchPath(paramCache{path: m.path, pathParts: m.pathParts}, "catch/hello") 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /cmd/example/api.http: -------------------------------------------------------------------------------- 1 | @baseUrl = http://localhost:8090/api/v1 2 | 3 | ### 4 | 5 | GET {{baseUrl}}/docs 6 | 7 | ### 8 | 9 | GET {{baseUrl}}/pages?err=hello 10 | 11 | ### 12 | 13 | GET {{baseUrl}}/raw-response 14 | 15 | ### 16 | POST {{baseUrl}}/users/login 17 | Content-Type: application/json 18 | 19 | { 20 | "username": "admin", 21 | "password": "admin" 22 | } 23 | 24 | ### 25 | 26 | POST {{baseUrl}}/users 27 | Content-Type: application/json 28 | 29 | {} 30 | 31 | ### 32 | 33 | GET {{baseUrl}}/blobs/download/abc 34 | Authorization: admin 35 | 36 | ### 37 | 38 | GET {{baseUrl}}/blobs/link/abc/def 39 | Authorization: admin 40 | 41 | ### 42 | 43 | GET {{baseUrl}}/blobs/links/abc/def 44 | Authorization: admin 45 | 46 | ### 47 | 48 | POST {{baseUrl}}/blobs/upload 49 | Authorization: admin 50 | Content-Type: multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW 51 | 52 | ------WebKitFormBoundary7MA4YWxkTrZu0gW 53 | Content-Disposition: form-data; name="name" 54 | 55 | A sample file 56 | ------WebKitFormBoundary7MA4YWxkTrZu0gW 57 | Content-Disposition: form-data; name="file"; filename="sample.txt" 58 | Content-Type: text/plain 59 | 60 | < ./api.http 61 | ------WebKitFormBoundary7MA4YWxkTrZu0gW-- -------------------------------------------------------------------------------- /cmd/example/go.mod: 
-------------------------------------------------------------------------------- 1 | module restructexample 2 | 3 | go 1.17 4 | 5 | replace github.com/altlimit/restruct => ../../ 6 | 7 | require ( 8 | github.com/altlimit/restruct v0.0.0-20220616021605-5da3fb060604 9 | github.com/go-playground/validator/v10 v10.11.0 10 | ) 11 | 12 | require ( 13 | github.com/go-playground/locales v0.14.0 // indirect 14 | github.com/go-playground/universal-translator v0.18.0 // indirect 15 | github.com/leodido/go-urn v1.2.1 // indirect 16 | golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 // indirect 17 | golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069 // indirect 18 | golang.org/x/text v0.3.7 // indirect 19 | ) 20 | -------------------------------------------------------------------------------- /cmd/example/go.sum: -------------------------------------------------------------------------------- 1 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 4 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= 6 | github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= 7 | github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= 8 | github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= 9 | github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho= 10 | github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= 11 | github.com/go-playground/validator/v10 v10.11.0 h1:0W+xRM511GY47Yy3bZUbJVitCNg2BOGlCyvTqsp/xIw= 12 | 
github.com/go-playground/validator/v10 v10.11.0/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU= 13 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 14 | github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= 15 | github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= 16 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 17 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 18 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 19 | github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= 20 | github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= 21 | github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= 22 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 23 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 24 | github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= 25 | github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= 26 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 27 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 28 | github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= 29 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 30 | golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 h1:0es+/5331RGQPcXlMfP+WrnIIS6dNnNRe0WB02W0F4M= 31 | golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= 32 | golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod 
h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= 33 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 34 | golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 35 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 36 | golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069 h1:siQdpVirKtzPhKl3lZWozZraCFObP8S1v6PRp0bLrtU= 37 | golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 38 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 39 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 40 | golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= 41 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 42 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 43 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 44 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 45 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 46 | gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= 47 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 48 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= 49 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 50 | -------------------------------------------------------------------------------- /cmd/example/main.go: 
-------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "log" 10 | "mime/multipart" 11 | "net/http" 12 | "reflect" 13 | "strings" 14 | 15 | rs "github.com/altlimit/restruct" 16 | "github.com/go-playground/validator/v10" 17 | ) 18 | 19 | type ( 20 | V1 struct { 21 | validate *validator.Validate 22 | DB struct{} `route:"-"` 23 | 24 | User User `route:"users"` 25 | Blobs Blob 26 | docs []string 27 | } 28 | 29 | User struct { 30 | } 31 | 32 | Blob struct { 33 | } 34 | ) 35 | 36 | var ( 37 | errBadRequest = errors.New("bad request") 38 | errAuth = fmt.Errorf("not logged in") 39 | 40 | v1 *V1 41 | writer = &rs.DefaultWriter{ 42 | Errors: map[error]rs.Error{ 43 | errAuth: {Status: http.StatusUnauthorized}, 44 | errBadRequest: {Status: http.StatusBadRequest}, 45 | }, 46 | } 47 | ) 48 | 49 | func init() { 50 | v1 = &V1{validate: validator.New()} 51 | v1.validate.RegisterTagNameFunc(func(fld reflect.StructField) string { // report fields by their json/query tag name, falling back to the Go field name 52 | tags := []string{"json", "query"} 53 | for _, tag := range tags { 54 | name := strings.SplitN(fld.Tag.Get(tag), ",", 2)[0] 55 | if name != "-" && name != "" { 56 | return name 57 | } 58 | } 59 | return fld.Name 60 | }) 61 | } 62 | 63 | // bind extends rs.Bind with validation from go-playground/validator 64 | func (v *V1) bind(r *http.Request, src interface{}, methods ...string) error { 65 | // we still use default bind but add in our custom validator library below 66 | if err := rs.Bind(r, src, methods...); err != nil { 67 | return err 68 | } 69 | if src == nil { 70 | return nil 71 | } 72 | if err := v.validate.Struct(src); err != nil { 73 | valErrors := make(map[string]string) 74 | for _, err := range err.(validator.ValidationErrors) { // NOTE(review): unchecked assertion — validator documents that Struct may also return *validator.InvalidValidationError for bad input, which would panic here; confirm 75 | valErrors[err.Namespace()] = err.Tag() 76 | } 77 | return rs.Error{Status: http.StatusBadRequest, Message: "validation error", Data: valErrors} 78 | } 79 | 80 | return nil 81 | } 82 | 83 | func (v *V1) 
user(r *http.Request) (int64, error) { 84 | if userID, ok := rs.GetValue(r, "userID").(int64); ok { 85 | // a real service would look the current user up in the DB here 86 | return userID, nil 87 | } 88 | return 0, errAuth 89 | } 90 | 91 | func (v *V1) Docs() []string { // Docs returns the stored docs slice 92 | return v.docs 93 | } 94 | 95 | func (v *V1) Pages(r *http.Request) (code int, pages []string, err error) { // Pages returns 202, two sample pages, and an error built from the "err" query value when it is set 96 | code = http.StatusAccepted 97 | pages = append(pages, "hello", "world") 98 | if e := r.URL.Query().Get("err"); e != "" { 99 | err = errors.New(e) 100 | } 101 | return 102 | } 103 | 104 | func (v *V1) RawResponse() *rs.Response { // RawResponse returns a fixed text/html response with status 200 105 | return &rs.Response{ 106 | Status: http.StatusOK, 107 | ContentType: "text/html", 108 | Content: []byte(`Hi`), 109 | } 110 | } 111 | 112 | // limitsMiddleware caps the request body at 1 MiB, or 128 MiB for POSTs to paths ending in /upload 113 | func limitsMiddleware(next http.Handler) http.Handler { 114 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 115 | var maxBodyLimit int64 = 1 << 20 116 | if r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/upload") { 117 | maxBodyLimit = 128 << 20 118 | } 119 | r.Body = http.MaxBytesReader(w, r.Body, maxBodyLimit) 120 | next.ServeHTTP(w, r) 121 | }) 122 | } 123 | 124 | // authMiddleware rejects requests whose Authorization header is not "admin" and stores userID 1 on the request otherwise 125 | func authMiddleware(next http.Handler) http.Handler { 126 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 127 | if r.Header.Get("Authorization") != "admin" { 128 | writer.WriteJSON(w, rs.Error{Status: http.StatusUnauthorized}) 129 | return 130 | } 131 | // use SetValue/GetValue to easily set and get values from the context 132 | r = rs.SetValue(r, "userID", int64(1)) 133 | next.ServeHTTP(w, r) 134 | }) 135 | } 136 | 137 | func loggerMiddleware(next http.Handler) http.Handler { // loggerMiddleware logs the method and path of every request before passing it on 138 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 139 | log.Println(r.Method, " - ", r.URL.Path) 140 | next.ServeHTTP(w, r) 141 | }) 142 | } 143 | 144 | func catchAllHandler() http.HandlerFunc { 145 | return func(w http.ResponseWriter, r 
*http.Request) { 146 | log.Println("Caught", r.URL.Path) 147 | } 148 | } 149 | 150 | // Add middleware to this service without changing their paths 151 | func (b *Blob) Routes() []rs.Route { 152 | auth := []rs.Middleware{authMiddleware} 153 | return []rs.Route{ 154 | {Handler: "Download_0", Methods: []string{http.MethodGet}, Middlewares: auth}, 155 | {Handler: "Upload", Middlewares: auth}, 156 | {Handler: "Link", Path: "links/{path:.+}"}, 157 | } 158 | } 159 | 160 | // Add middleware to the whole struct 161 | func (b *Blob) Middlewares() []rs.Middleware { 162 | return []rs.Middleware{loggerMiddleware} 163 | } 164 | 165 | // Magic var 0Path means anything after /link/.+ without regex 166 | func (b *Blob) Link_0Path(ctx context.Context) string { 167 | return rs.Vars(ctx)["0Path"] 168 | } 169 | 170 | func (b *Blob) Link(ctx context.Context) string { // Link returns the "path" variable captured by the links/{path:.+} route 171 | return rs.Vars(ctx)["path"] 172 | } 173 | 174 | // Download_0 is a standard net/http handler, so it must write its own response 175 | func (b *Blob) Download_0(w http.ResponseWriter, r *http.Request) { 176 | user, err := v1.user(r) 177 | if err != nil { 178 | writer.WriteJSON(w, err) 179 | return 180 | } 181 | blobID := rs.Params(r)["0"] 182 | w.Header().Add("Content-Type", "text/plain") // headers must be set before WriteHeader, or they are not sent 183 | w.WriteHeader(http.StatusOK) 184 | fmt.Fprintf(w, "BlobByUser: %d -> %s", user, blobID) 185 | } 186 | 187 | type uploadRequest struct { 188 | Name string `form:"name" validate:"required"` 189 | File *multipart.FileHeader `form:"file" validate:"required"` 190 | } 191 | 192 | func (b *Blob) Upload(ctx context.Context, upload *uploadRequest) interface{} { // Upload echoes the uploaded file's name, size and content, or returns the open/read error 193 | f, err := upload.File.Open() 194 | if err != nil { 195 | return err 196 | } 197 | defer f.Close() // release the uploaded file once it has been read 198 | buf, err := io.ReadAll(f) 199 | if err != nil { 200 | return err 201 | } 202 | return map[string]interface{}{ 203 | "filename": upload.File.Filename, 204 | "size": upload.File.Size, 205 | "content": string(buf), 206 | } 207 | } 208 | 209 | func (u *User) Login(login struct { 210 | 
Username string `json:"username" validate:"required"` 211 | Password string `json:"password" validate:"required"` 212 | }) (bool, error) { 213 | if login.Username == "admin" && login.Password == "admin" { 214 | return true, nil 215 | } 216 | return false, rs.Error{Status: http.StatusForbidden, Message: "Invalid login"} 217 | } 218 | 219 | // CRUD API with POST on api/v1/users and GET, PUT, DELETE on api/v1/users/{id} 220 | func (*User) Routes() []rs.Route { 221 | return []rs.Route{ 222 | {Handler: "CreateUser", Path: ".", Methods: []string{http.MethodPost}}, 223 | {Handler: "ReadUser", Path: "{id}", Methods: []string{http.MethodGet}}, 224 | {Handler: "UpdateUser", Path: "{id}", Methods: []string{http.MethodPut}}, 225 | {Handler: "DeleteUser", Path: "{id}", Methods: []string{http.MethodDelete}}, 226 | } 227 | } 228 | 229 | func (u *User) CreateUser() { 230 | log.Println("CreateUser") 231 | } 232 | 233 | func (u *User) ReadUser(ctx context.Context) { 234 | log.Println("ReadUser", rs.Vars(ctx)["id"]) 235 | } 236 | 237 | func (u *User) UpdateUser(ctx context.Context) { 238 | log.Println("UpdateUser", rs.Vars(ctx)["id"]) 239 | } 240 | 241 | func (u *User) DeleteUser(ctx context.Context) { 242 | log.Println("DeleteUser", rs.Vars(ctx)["id"]) 243 | } 244 | 245 | func (v *V1) notFound(r *http.Request) error { // notFound logs the unmatched path and returns a 404 error 246 | log.Println("Not Found", r.URL.Path) 247 | return rs.Error{Status: http.StatusNotFound} 248 | } 249 | 250 | // all initialization can happen within the struct method 251 | func (v *V1) Init(h *rs.Handler) { 252 | v1.docs = h.Routes() // NOTE(review): uses the package-level v1 rather than the receiver v — equivalent only while v1 is the registered service; confirm 253 | // still DefaultReader, but with our bind to add validation errors 254 | h.Reader = &rs.DefaultReader{Bind: v1.bind} 255 | // still DefaultWriter, but with options to map custom errors 256 | h.Writer = writer 257 | // add middleware 258 | h.Use(limitsMiddleware) 259 | h.NotFound(v1.notFound) 260 | 261 | var buf bytes.Buffer 262 | buf.WriteString("Endpoints:") 263 | for _, r := range h.Routes() { 264 | buf.WriteString("\n> " + r) 
265 | } 266 | log.Println(buf.String()) 267 | } 268 | 269 | func main() { 270 | rs.Handle("/api/v1/", v1) 271 | http.Handle("/", catchAllHandler()) 272 | port := "8090" 273 | log.Println("Listening", port) 274 | http.ListenAndServe(":"+port, nil) 275 | } 276 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package restruct 2 | 3 | import "net/http" 4 | 5 | type ( 6 | Error struct { 7 | Status int 8 | Message string 9 | Data interface{} 10 | Err error 11 | } 12 | ) 13 | 14 | func (e Error) Error() string { 15 | if e.Status == 0 { 16 | return http.StatusText(http.StatusInternalServerError) 17 | } 18 | return http.StatusText(e.Status) 19 | } 20 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/altlimit/restruct 2 | 3 | go 1.16 4 | -------------------------------------------------------------------------------- /handler.go: -------------------------------------------------------------------------------- 1 | package restruct 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net/http" 7 | "reflect" 8 | "regexp" 9 | "sort" 10 | "strings" 11 | ) 12 | 13 | type ( 14 | ctxKey string 15 | ) 16 | 17 | var ( 18 | ErrReaderReturnLen = errors.New("reader args len does not match") 19 | ) 20 | 21 | const ( 22 | keyParams ctxKey = "params" 23 | keyVals ctxKey = "vals" 24 | ) 25 | 26 | type ( 27 | Middleware func(http.Handler) http.Handler 28 | 29 | Handler struct { 30 | // Writer controls the output of your service, defaults to DefaultWriter 31 | Writer ResponseWriter 32 | // Reader controls the input of your service, defaults to DefaultReader 33 | Reader RequestReader 34 | 35 | prefix string 36 | prefixLen int 37 | services map[string]interface{} 38 | cache *methodCache 39 | middlewares []Middleware 40 | 
notFound *method 41 | } 42 | 43 | methodCache struct { 44 | byParams []paramCache 45 | byPath map[string][]*method 46 | } 47 | 48 | paramCache struct { 49 | path string 50 | pathParts []string 51 | pathRe *regexp.Regexp 52 | methods []*method 53 | } 54 | 55 | wrappedHandler struct { 56 | handler http.Handler 57 | } 58 | ) 59 | 60 | func (mc *methodCache) methods() (methods []*method) { 61 | for _, param := range mc.byParams { 62 | methods = append(methods, param.methods...) 63 | } 64 | for _, m := range mc.byPath { 65 | methods = append(methods, m...) 66 | } 67 | return 68 | } 69 | 70 | func (wh *wrappedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 71 | wh.handler.ServeHTTP(w, r) 72 | } 73 | 74 | // NewHandler creates a handler for a given struct. 75 | func NewHandler(svc interface{}) *Handler { 76 | h := &Handler{ 77 | services: map[string]interface{}{"": svc}, 78 | } 79 | h.mustCompile("") 80 | if init, ok := svc.(Init); ok { 81 | init.Init(h) 82 | } 83 | return h 84 | } 85 | 86 | // WithPrefix prefixes your service with given path. You can't use parameters here. 87 | // This is useful if you want to register this handler with another third party router. 
88 | func (h *Handler) WithPrefix(prefix string) *Handler { 89 | h.mustCompile(prefix) 90 | return h 91 | } 92 | 93 | // Routes returns a list of routes registered and it's definition 94 | func (h *Handler) Routes() (routes []string) { 95 | h.updateCache() 96 | for _, m := range h.cache.methods() { 97 | var methods []string 98 | for k := range m.methods { 99 | methods = append(methods, k) 100 | } 101 | if len(methods) == 0 { 102 | methods = append(methods, "*") 103 | } 104 | r := h.prefix + m.path + " [" + strings.Join(methods, ",") + "] -> " + m.location 105 | var params []string 106 | for _, v := range m.params { 107 | params = append(params, v.String()) 108 | } 109 | var returns []string 110 | for _, v := range m.returns { 111 | returns = append(returns, v.String()) 112 | } 113 | r += "(" + strings.Join(params, ", ") + ")" 114 | if len(returns) > 0 { 115 | r += " (" + strings.Join(returns, ", ") + ")" 116 | } 117 | routes = append(routes, r) 118 | } 119 | sort.Strings(routes) 120 | return 121 | } 122 | 123 | // AddService adds a new service to specified route. 124 | // You can put {param} in this route. 125 | func (h *Handler) AddService(path string, svc interface{}) { 126 | path = strings.TrimPrefix(path, "/") 127 | if !strings.HasSuffix(path, "/") { 128 | path += "/" 129 | } 130 | if _, ok := h.services[path]; ok { 131 | panic("service " + path + " already exists") 132 | } 133 | h.services[path] = svc 134 | h.cache = nil 135 | } 136 | 137 | // Use adds a middleware to your services. 138 | func (h *Handler) Use(fns ...Middleware) { 139 | h.middlewares = append(h.middlewares, fns...) 140 | } 141 | 142 | // NotFound sets the notFound handler and calls it 143 | // if no route matches 144 | func (h *Handler) NotFound(handler interface{}) { 145 | h.notFound = &method{source: reflect.ValueOf(handler)} 146 | h.notFound.mustParse() 147 | } 148 | 149 | // ServeHTTP calls the method with the matched route. 
150 | func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 151 | path := r.URL.Path[h.prefixLen:] 152 | if h.Writer == nil { 153 | h.Writer = &DefaultWriter{} 154 | } 155 | if h.Reader == nil { 156 | h.Reader = &DefaultReader{Bind: Bind} 157 | } 158 | // if there are middleware we wrap it in reverse so it's called 159 | // in the order they were added 160 | chain := func(m *method) *wrappedHandler { 161 | handler := &wrappedHandler{handler: h.createHandler(m)} 162 | middlewares := append(h.middlewares, m.middlewares...) 163 | for i := len(middlewares) - 1; i >= 0; i-- { 164 | handler = &wrappedHandler{handler: middlewares[i](handler)} 165 | } 166 | return handler 167 | } 168 | runMethod := func(m *method) { 169 | chain(m).ServeHTTP(w, r) 170 | } 171 | // we check path look up first then see if proper method 172 | if vals, ok := h.cache.byPath[path]; ok { 173 | for _, m := range vals { 174 | ok := m.methods == nil 175 | if !ok { 176 | _, ok = m.methods[r.Method] 177 | } 178 | if ok { 179 | runMethod(m) 180 | return 181 | } 182 | } 183 | h.Writer.Write(w, r, refTypes(typeError), refVals(Error{Status: http.StatusMethodNotAllowed})) 184 | return 185 | } 186 | // we do heavier look up such as path parts or regex then if any match 187 | // we set path found but still need to match method for proper error return 188 | status := http.StatusNotFound 189 | for _, bp := range h.cache.byParams { 190 | params, ok := matchPath(bp, path) 191 | if ok { 192 | for _, v := range bp.methods { 193 | ok := v.methods == nil 194 | if !ok { 195 | _, ok = v.methods[r.Method] 196 | } 197 | if ok { 198 | if len(params) > 0 { 199 | ctx := r.Context() 200 | ctx = context.WithValue(ctx, keyParams, params) 201 | r = r.WithContext(ctx) 202 | } 203 | runMethod(v) 204 | return 205 | } 206 | } 207 | status = http.StatusMethodNotAllowed 208 | } 209 | } 210 | if status == http.StatusNotFound && h.notFound != nil { 211 | runMethod(h.notFound) 212 | return 213 | } 214 | h.Writer.Write(w, 
r, refTypes(typeError), refVals(Error{Status: status})) 215 | } 216 | 217 | // wrapped handler that calls the actual method and processes the returns 218 | // the parameter allowed here are *http.Request and http.ResponseWriter 219 | // the returns can be anything or an error which will be sent to the ResponseWriter 220 | // a multiple return is passed as slice of interface{} 221 | func (h *Handler) createHandler(m *method) http.Handler { 222 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 223 | var ( 224 | argTypes []reflect.Type 225 | argIndexes []int 226 | ) 227 | args := make([]reflect.Value, len(m.params)) 228 | for k, v := range m.params { 229 | switch v { 230 | case typeHttpRequest: 231 | args[k] = reflect.ValueOf(r) 232 | case typeHttpWriter: 233 | args[k] = reflect.ValueOf(w) 234 | case typeContext: 235 | args[k] = reflect.ValueOf(r.Context()) 236 | default: 237 | argTypes = append(argTypes, v) 238 | argIndexes = append(argIndexes, k) 239 | } 240 | } 241 | // has unknown types in parameters, use RequestReader 242 | if len(argIndexes) > 0 { 243 | typeArgs, err := h.Reader.Read(r, argTypes) 244 | if err != nil { 245 | h.Writer.Write(w, r, refTypes(typeError), refVals(err)) 246 | return 247 | } 248 | if len(typeArgs) != len(argIndexes) { 249 | h.Writer.Write(w, r, refTypes(typeError), refVals(Error{Err: ErrReaderReturnLen})) 250 | return 251 | } 252 | for k, i := range argIndexes { 253 | args[i] = typeArgs[k] 254 | } 255 | } 256 | out := m.source.Call(args) 257 | ot := len(out) 258 | if ot == 0 { 259 | return 260 | } 261 | h.Writer.Write(w, r, m.returns, out) 262 | }) 263 | } 264 | 265 | // Called every time you add a handler to create a cached info about 266 | // your routes and which methods it points to. This will also look up 267 | // exported structs to add as a service. 
You can avoid this by adding 268 | // route:"-" or to specify specific route add route:"path/{hello}" 269 | func (h *Handler) updateCache() { 270 | if h.cache != nil { 271 | return 272 | } 273 | if h.prefix == "" { 274 | h.mustCompile("") 275 | } 276 | h.cache = &methodCache{ 277 | byPath: make(map[string][]*method), 278 | } 279 | // cache all same paths so we only compare it once 280 | pathCache := make(map[string][]*method) 281 | // we store ordered paths so it's still looked up in order you enter it 282 | var orderedPaths []string 283 | for k, svc := range h.services { 284 | for _, v := range serviceToMethods(k, svc) { 285 | if v.pathRe != nil || v.pathParts != nil { 286 | _, ok := pathCache[v.path] 287 | if !ok { 288 | orderedPaths = append(orderedPaths, v.path) 289 | } 290 | pathCache[v.path] = append(pathCache[v.path], v) 291 | } else { 292 | h.cache.byPath[v.path] = append(h.cache.byPath[v.path], v) 293 | } 294 | } 295 | } 296 | for _, path := range orderedPaths { 297 | // all of methods here have the same path so we use first one 298 | m := pathCache[path][0] 299 | h.cache.byParams = append(h.cache.byParams, paramCache{ 300 | path: path, 301 | pathParts: m.pathParts, 302 | pathRe: m.pathRe, 303 | methods: pathCache[path], 304 | }) 305 | } 306 | } 307 | 308 | func (h *Handler) mustCompile(prefix string) { 309 | if !strings.HasSuffix(prefix, "/") { 310 | prefix += "/" 311 | } 312 | h.prefix = prefix 313 | h.prefixLen = len(h.prefix) 314 | h.updateCache() 315 | } 316 | 317 | // Checks path against request path if it's valid, this accepts a stripped path and not a full path 318 | func matchPath(pc paramCache, path string) (params map[string]string, ok bool) { 319 | params = make(map[string]string) 320 | if pc.pathRe != nil { 321 | match := pc.pathRe.FindStringSubmatch(path) 322 | if len(match) > 0 { 323 | for i, name := range pc.pathRe.SubexpNames() { 324 | if i != 0 && name != "" { 325 | params[name] = match[i] 326 | } 327 | } 328 | ok = true 329 | } 330 | } 
else if pc.pathParts != nil { 331 | // match by parts 332 | idx := -1 333 | pt := len(pc.pathParts) 334 | for { 335 | idx++ 336 | if idx+1 > pt { 337 | return 338 | } 339 | i := strings.Index(path, "/") 340 | var part string 341 | if i == -1 { 342 | part = path[i+1:] 343 | if part == "" { 344 | return 345 | } 346 | } else { 347 | part = path[:i] 348 | } 349 | mPart := pc.pathParts[idx] 350 | if mPart[0] == '{' { 351 | name := mPart[1 : len(mPart)-1] 352 | if mPart == "{0Path}" { 353 | params[name] = path 354 | ok = true 355 | return 356 | } 357 | params[name] = part 358 | } else if mPart != part { 359 | return 360 | } 361 | if i == -1 { 362 | break 363 | } 364 | path = path[i+1:] 365 | } 366 | ok = idx+1 == pt 367 | } 368 | return 369 | } 370 | -------------------------------------------------------------------------------- /handler_test.go: -------------------------------------------------------------------------------- 1 | package restruct_test 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "fmt" 8 | "io/ioutil" 9 | "net/http" 10 | "net/http/httptest" 11 | "strings" 12 | "testing" 13 | 14 | rs "github.com/altlimit/restruct" 15 | ) 16 | 17 | // Integration tests 18 | 19 | type ( 20 | DB struct{} 21 | V1 struct { 22 | DB DB `route:"-"` 23 | 24 | User User `route:"users"` 25 | Blobs Blob 26 | } 27 | 28 | User struct { 29 | } 30 | 31 | Blob struct { 32 | } 33 | 34 | Calculator struct { 35 | } 36 | ) 37 | 38 | var ( 39 | errBadRequest = errors.New("bad request") 40 | errAuth = fmt.Errorf("not logged in") 41 | 42 | executions int 43 | ) 44 | 45 | func (v *V1) bind(r *http.Request, src interface{}, methods ...string) error { 46 | if err := rs.Bind(r, src, methods...); err != nil { 47 | return err 48 | } 49 | if src == nil { 50 | return nil 51 | } 52 | return nil 53 | } 54 | 55 | func execMiddleware(next http.Handler) http.Handler { 56 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 57 | executions++ 58 | r = rs.SetValue(r, "execs", 
executions) 59 | next.ServeHTTP(w, r) 60 | }) 61 | } 62 | 63 | func authMiddleware(next http.Handler) http.Handler { 64 | wr := rs.DefaultWriter{} 65 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 66 | if r.Header.Get("Authorization") != "admin" { 67 | wr.WriteJSON(w, rs.Error{Status: http.StatusUnauthorized}) 68 | return 69 | } 70 | r = rs.SetValue(r, "userID", int64(1)) 71 | next.ServeHTTP(w, r) 72 | }) 73 | } 74 | 75 | func (db *DB) Query() error { 76 | return nil 77 | } 78 | 79 | func (c *Calculator) Add(r *http.Request) interface{} { 80 | var req struct { 81 | A int64 `json:"a"` 82 | B int64 `json:"b"` 83 | } 84 | if err := rs.Bind(r, &req, http.MethodPost); err != nil { 85 | return err 86 | } 87 | return req.A + req.B 88 | } 89 | 90 | func (c *Calculator) Subtract(a, b int64) int64 { 91 | return a - b 92 | } 93 | 94 | func (c *Calculator) Divide(a, b int64) (int64, error) { 95 | if b == 0 { 96 | return 0, errors.New("divide by 0") 97 | } 98 | return a / b, nil 99 | } 100 | 101 | func (c *Calculator) Multiply(r struct { 102 | A int64 `json:"a"` 103 | B int64 `json:"b"` 104 | }) int64 { 105 | return r.A * r.B 106 | } 107 | 108 | func (b *Blob) Routes() []rs.Route { 109 | // todo maybe ability to somehow put middleware to a whole nested struct 110 | auth := []rs.Middleware{authMiddleware} 111 | return []rs.Route{ 112 | {Handler: "Download_0", Methods: []string{http.MethodGet}, Middlewares: auth}, 113 | {Handler: "Upload", Path: ".custom/{path:.+}", Middlewares: auth}, 114 | } 115 | } 116 | 117 | // Standard handler, you must handle your own response 118 | func (b *Blob) Download_0(w http.ResponseWriter, r *http.Request) { 119 | blobID := rs.Params(r)["0"] 120 | w.WriteHeader(http.StatusOK) 121 | w.Header().Add("Content-Type", "text/plain") 122 | w.Write([]byte(string(blobID))) 123 | } 124 | 125 | func (b *Blob) Link_0Path(ctx context.Context) string { 126 | return rs.Vars(ctx)["0Path"] 127 | } 128 | 129 | func (b *Blob) Upload(r 
*http.Request) interface{} { 130 | return rs.Params(r)["path"] 131 | } 132 | 133 | func (u *User) Login(ctx context.Context, login struct { 134 | Username string `json:"username" form:"username"` 135 | Password string `json:"password" form:"password"` 136 | }) (bool, error) { 137 | if login.Username == "admin" && login.Password == "admin" { 138 | return true, nil 139 | } 140 | return false, rs.Error{Status: http.StatusForbidden, Message: "Invalid login"} 141 | } 142 | 143 | func (u *User) Execs(r *http.Request) int { 144 | execs := rs.GetValue(r, "execs").(int) 145 | return execs 146 | } 147 | 148 | func TestHandler(t *testing.T) { 149 | v1 := &V1{} 150 | h := rs.NewHandler(v1) 151 | h.AddService("calc", new(Calculator)) 152 | h.Use(execMiddleware) 153 | h.Reader = &rs.DefaultReader{Bind: v1.bind} 154 | h.Writer = &rs.DefaultWriter{ 155 | Errors: map[error]rs.Error{ 156 | errAuth: {Status: http.StatusUnauthorized}, 157 | errBadRequest: {Status: http.StatusBadRequest}, 158 | }, 159 | } 160 | var buf bytes.Buffer 161 | for _, r := range h.Routes() { 162 | buf.WriteString(r + "\n") 163 | } 164 | routes := `/blobs/.custom/{path:.+} [*] -> github.com/altlimit/restruct_test.Blob.Upload(*http.Request) (interface {}) 165 | /blobs/download/{0} [GET] -> github.com/altlimit/restruct_test.Blob.Download_0(http.ResponseWriter, *http.Request) 166 | /blobs/link/{0Path} [*] -> github.com/altlimit/restruct_test.Blob.Link_0Path(context.Context) (string) 167 | /calc/add [*] -> github.com/altlimit/restruct_test.Calculator.Add(*http.Request) (interface {}) 168 | /calc/divide [*] -> github.com/altlimit/restruct_test.Calculator.Divide(int64, int64) (int64, error) 169 | /calc/multiply [*] -> github.com/altlimit/restruct_test.Calculator.Multiply(struct { A int64 "json:\"a\""; B int64 "json:\"b\"" }) (int64) 170 | /calc/subtract [*] -> github.com/altlimit/restruct_test.Calculator.Subtract(int64, int64) (int64) 171 | /users/execs [*] -> 
github.com/altlimit/restruct_test.User.Execs(*http.Request) (int) 172 | /users/login [*] -> github.com/altlimit/restruct_test.User.Login(context.Context, struct { Username string "json:\"username\" form:\"username\""; Password string "json:\"password\" form:\"password\"" }) (bool, error)` 173 | found := strings.Trim(buf.String(), "\n") 174 | if routes != found { 175 | t.Errorf("wanted \n%s\n routes got \n%s\n", routes, found) 176 | } 177 | jh := map[string]string{"Content-Type": "application/json"} 178 | table := []struct { 179 | method string 180 | path string 181 | request string 182 | headers map[string]string 183 | response string 184 | status int 185 | }{ 186 | {http.MethodPost, "/users/login", `{"username": "admin", "password": "admin"}`, jh, `true`, 200}, 187 | {http.MethodPost, "/users/login", `{}`, jh, `{"error":"Invalid login"}`, 403}, 188 | {http.MethodPost, "/users/login", `{`, jh, `{"error":"Bad Request"}`, 400}, 189 | {http.MethodPost, "/users/login", `{"username": "admin", "password": "admin"}`, map[string]string{"Content-Type": "application/x-www-form-urlencoded"}, 190 | `{"error":"Invalid login"}`, 403}, 191 | {http.MethodPost, "/users/login", `username=admin&password=admin`, map[string]string{"Content-Type": "application/x-www-form-urlencoded"}, 192 | `true`, 200}, 193 | {http.MethodPost, "/blobs/download/abc", `{}`, jh, `{"error":"Method Not Allowed"}`, 405}, 194 | {http.MethodGet, "/blobs/download/abc", `{}`, nil, `{"error":"Unauthorized"}`, 401}, 195 | {http.MethodGet, "/blobs/download/abc/", `{}`, nil, `{"error":"Not Found"}`, 404}, 196 | {http.MethodGet, "/blobs/download/abc", ``, map[string]string{"Authorization": "admin"}, `abc`, 200}, 197 | {http.MethodGet, "/blobs/.custom/abc/123", ``, nil, `{"error":"Unauthorized"}`, 401}, 198 | {http.MethodGet, "/blobs/.custom/abc/123/", ``, map[string]string{"Authorization": "admin"}, `"abc/123/"`, 200}, 199 | {http.MethodPost, "/calc/add", `{"a":10,"b":20}`, jh, `30`, 200}, 200 | {http.MethodPost, 
"/calc/subtract", `[20,10]`, jh, `10`, 200}, 201 | {http.MethodPost, "/calc/subtract", `["bad"]`, jh, `{"error":"Bad Request"}`, 400}, 202 | {http.MethodPost, "/calc/divide", `[10,2]`, jh, `5`, 200}, 203 | {http.MethodPost, "/calc/divide", `[10,0]`, jh, `{"error":"Internal Server Error"}`, 500}, 204 | {http.MethodPost, "/calc/multiply", `{"a":10,"b":2}`, jh, `20`, 200}, 205 | {http.MethodPost, "/calc/multiply", `{"a":10,"b":2}`, nil, `{"error":"Unsupported Media Type"}`, 415}, 206 | {http.MethodGet, "/users/execs", ``, nil, `---EXECS---`, 200}, 207 | } 208 | 209 | runs := 0 210 | for _, v := range table { 211 | if v.response == "---EXECS---" { 212 | v.response = fmt.Sprintf("%d", runs+1) 213 | } 214 | req := httptest.NewRequest(v.method, v.path, strings.NewReader(v.request)) 215 | w := httptest.NewRecorder() 216 | 217 | if v.headers != nil { 218 | for k, v := range v.headers { 219 | req.Header.Add(k, v) 220 | } 221 | } 222 | 223 | h.ServeHTTP(w, req) 224 | 225 | res := w.Result() 226 | defer res.Body.Close() 227 | data, err := ioutil.ReadAll(res.Body) 228 | if err != nil { 229 | t.Errorf("ioutil.ReadAll error %v", err) 230 | } 231 | resp := strings.TrimRight(string(data), "\n") 232 | if !(resp == v.response && res.StatusCode == v.status) { 233 | t.Errorf("path %s wanted %d `%s` got %d `%s`", v.path, v.status, v.response, res.StatusCode, resp) 234 | } 235 | if res.StatusCode != 404 && res.StatusCode != 405 { 236 | runs++ 237 | } 238 | } 239 | } 240 | -------------------------------------------------------------------------------- /method.go: -------------------------------------------------------------------------------- 1 | package restruct 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "mime/multipart" 8 | "net/http" 9 | "reflect" 10 | "regexp" 11 | "strings" 12 | "unicode" 13 | ) 14 | 15 | var ( 16 | pathToRe = regexp.MustCompile(`{[^}]+}`) 17 | 18 | typeHttpRequest = reflect.TypeOf(&http.Request{}) 19 | typeHttpWriter = 
reflect.TypeOf((*http.ResponseWriter)(nil)).Elem() 20 | typeContext = reflect.TypeOf((*context.Context)(nil)).Elem() 21 | typeError = reflect.TypeOf((*error)(nil)).Elem() 22 | typeInt = reflect.TypeOf((*int)(nil)).Elem() 23 | typeMultipartFileHeader = reflect.TypeOf(&multipart.FileHeader{}) 24 | typeMultipartFileHeaderSlice = reflect.TypeOf([]*multipart.FileHeader{}) 25 | ) 26 | 27 | type ( 28 | method struct { 29 | location string 30 | source reflect.Value 31 | path string 32 | pathRe *regexp.Regexp 33 | pathParts []string 34 | params []reflect.Type 35 | returns []reflect.Type 36 | methods map[string]bool 37 | middlewares []Middleware 38 | } 39 | ) 40 | 41 | // returns methods from structs and nested structs 42 | func serviceToMethods(prefix string, svc interface{}) (methods []*method) { 43 | tv := reflect.TypeOf(svc) 44 | vv := reflect.ValueOf(svc) 45 | 46 | // get methods first 47 | routes := make(map[string][]Route) 48 | skipMethods := map[string]bool{} 49 | if router, ok := svc.(Router); ok { 50 | for _, route := range router.Routes() { 51 | routes[route.Handler] = append(routes[route.Handler], route) 52 | } 53 | skipMethods["Routes"] = true 54 | } 55 | if _, ok := svc.(Init); ok { 56 | skipMethods["Init"] = true 57 | } 58 | var middlewares []Middleware 59 | if mws, ok := svc.(Middlewares); ok { 60 | middlewares = mws.Middlewares() 61 | skipMethods["Middlewares"] = true 62 | } 63 | tvt := vv.NumMethod() 64 | tvEl := tv 65 | if tv.Kind() == reflect.Ptr { 66 | tvEl = tv.Elem() 67 | } 68 | location := tvEl.PkgPath() + "." + tvEl.Name() 69 | for i := 0; i < tvt; i++ { 70 | m := tv.Method(i) 71 | // Skip interface methods 72 | if _, ok := skipMethods[m.Name]; ok { 73 | continue 74 | } 75 | mm := &method{ 76 | location: location + "." 
+ m.Name, 77 | source: vv.Method(i), 78 | middlewares: middlewares, 79 | } 80 | if len(routes) > 0 { 81 | rts, ok := routes[m.Name] 82 | if ok { 83 | for _, route := range rts { 84 | mr := &method{ 85 | location: mm.location, 86 | source: mm.source, 87 | middlewares: mm.middlewares, 88 | } 89 | mr.middlewares = append(mr.middlewares, route.Middlewares...) 90 | if route.Path != "" { 91 | if route.Path == "." { 92 | mr.path = strings.TrimRight(prefix, "/") 93 | } else { 94 | mr.path = prefix + strings.TrimLeft(route.Path, "/") 95 | } 96 | } else { 97 | mr.path = prefix + nameToPath(m.Name) 98 | } 99 | if len(route.Methods) > 0 { 100 | mr.methods = make(map[string]bool) 101 | for _, method := range route.Methods { 102 | mr.methods[method] = true 103 | } 104 | } 105 | mr.mustParse() 106 | methods = append(methods, mr) 107 | } 108 | continue 109 | } 110 | } 111 | mm.path = prefix + nameToPath(m.Name) 112 | mm.mustParse() 113 | methods = append(methods, mm) 114 | } 115 | 116 | if tv.Kind() == reflect.Ptr { 117 | tv = tv.Elem() 118 | vv = vv.Elem() 119 | } 120 | // check fields 121 | tvt = vv.NumField() 122 | for i := 0; i < tvt; i++ { 123 | f := tv.Field(i) 124 | if f.PkgPath != "" { 125 | continue 126 | } 127 | route := f.Tag.Get("route") 128 | if route != "-" { 129 | fk := f.Type.Kind() 130 | fv := vv.Field(i) 131 | if fk == reflect.Ptr { 132 | fk = f.Type.Elem().Kind() 133 | fv = fv.Elem() 134 | } 135 | if fk == reflect.Struct && fv.IsValid() { 136 | if route == "" { 137 | route = nameToPath(f.Name) 138 | } 139 | route = strings.Trim(route, "/") + "/" 140 | sv := fv.Addr().Interface() 141 | methods = append(methods, serviceToMethods(prefix+route, sv)...) 
// nameToPath converts an exported Go name into a route path:
//
//	HelloWorld    -> hello-world
//	Hello_World   -> hello/world
//	Hello_0       -> hello/{0}
//	Hello_0_World -> hello/{0}/world
//
// An upper-case letter starts a new dash-separated word, an underscore
// becomes a path separator, and a segment that begins with a digit becomes a
// {param} that runs until the next underscore or the end of the name.
//
// The name is walked rune by rune so non-ASCII identifiers are lower-cased
// and copied intact instead of being split into individual bytes.
func nameToPath(name string) string {
	var buf bytes.Buffer
	// skipDash is true right after a '/', where an upper-case letter must
	// not be preceded by '-'.
	skipDash := false
	// startParam is true while inside a {param} segment; case is preserved
	// there.
	startParam := false
	for i, c := range name {
		switch {
		case !startParam && unicode.IsUpper(c):
			if i > 0 && !skipDash {
				buf.WriteByte('-')
			}
			buf.WriteRune(unicode.ToLower(c))
			skipDash = false
		case c == '_':
			if startParam {
				buf.WriteByte('}')
				startParam = false
			}
			buf.WriteByte('/')
			skipDash = true
		case !startParam && skipDash && unicode.IsNumber(c):
			// a digit directly after a separator opens a parameter
			startParam = true
			buf.WriteByte('{')
			buf.WriteRune(c)
		default:
			buf.WriteRune(c)
			if !startParam {
				skipDash = false
			}
		}
	}
	if startParam {
		buf.WriteByte('}')
	}
	return buf.String()
}
194 | func (m *method) mustParse() { 195 | rePath := m.path 196 | params := pathToRe.FindAllString(m.path, -1) 197 | if len(params) > 0 { 198 | withRe := false 199 | for _, m := range params { 200 | ex := fmt.Sprintf(`(?P<%s>\w+)`, m[1:len(m)-1]) 201 | if idx := strings.Index(m, ":"); idx != -1 { 202 | ex = fmt.Sprintf(`(?P<%s>%s)`, m[1:idx], m[idx+1:len(m)-1]) 203 | withRe = true 204 | } 205 | rePath = strings.ReplaceAll(rePath, m, ex) 206 | } 207 | if withRe { 208 | rePath = "^" + rePath + "$" 209 | m.pathRe = regexp.MustCompile(rePath) 210 | } else { 211 | m.pathParts = strings.Split(m.path, "/") 212 | } 213 | } 214 | 215 | if m.source.IsValid() { 216 | mt := m.source.Type() 217 | if mt.Kind() != reflect.Func { 218 | panic("method must be of type func") 219 | } 220 | for i := 0; i < mt.NumOut(); i++ { 221 | m.returns = append(m.returns, mt.Out(i)) 222 | } 223 | for i := 0; i < mt.NumIn(); i++ { 224 | m.params = append(m.params, mt.In(i)) 225 | } 226 | } 227 | } 228 | -------------------------------------------------------------------------------- /method_test.go: -------------------------------------------------------------------------------- 1 | package restruct 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | func TestNameToPath(t *testing.T) { 11 | table := []struct { 12 | name string 13 | path string 14 | }{ 15 | {"Add", "add"}, 16 | {"UserAuth", "user-auth"}, 17 | {"Hello_World", "hello/world"}, 18 | {"UserAuth_Bad", "user-auth/bad"}, 19 | {"Products_0", `products/{0}`}, 20 | {"Products_0_1", `products/{0}/{1}`}, 21 | {"Products_0_UserX_1", `products/{0}/user-x/{1}`}, 22 | {"Products15", `products15`}, 23 | } 24 | 25 | for _, v := range table { 26 | p := nameToPath(v.name) 27 | if v.path != p { 28 | t.Errorf("got path %s want %s", p, v.path) 29 | } 30 | } 31 | } 32 | 33 | type serviceA struct { 34 | Alpha serviceB `route:"-"` 35 | Bravo serviceB `route:"my/{tag}"` 36 | Charlie *serviceB 37 | Delta *serviceB 38 | Echo 
*serviceD 39 | } 40 | 41 | type serviceB struct { 42 | Delta serviceC 43 | } 44 | 45 | type serviceC struct{} 46 | 47 | type serviceD struct{} 48 | 49 | func (s *serviceA) Hello(r *http.Request) {} 50 | func (s *serviceB) World(w http.ResponseWriter) {} 51 | func (s *serviceC) HelloWorld(r *http.Request, w http.ResponseWriter) {} 52 | func (s serviceC) Hello_World(w http.ResponseWriter, r *http.Request) {} 53 | 54 | func (s serviceD) Overwrite() {} 55 | func (s serviceD) Root(c context.Context) {} 56 | 57 | func (s *serviceA) Path_0() {} 58 | func (s *serviceA) Path_0_1() {} 59 | func (s *serviceA) Path_0_Sub_1() {} 60 | func (s *serviceA) Link_0FP() {} 61 | func (s *serviceA) Link_0FP_0123() {} 62 | 63 | func (s *serviceD) Routes() []Route { 64 | return []Route{ 65 | {Handler: "Overwrite", Path: ".custom/{pid}/_download_"}, 66 | {Handler: "Root", Path: "."}, 67 | } 68 | } 69 | 70 | func TestServiceToMethods(t *testing.T) { 71 | s1 := &serviceA{Charlie: &serviceB{}, Echo: &serviceD{}} 72 | 73 | routes := map[string][]reflect.Type{ 74 | "s1/hello": {typeHttpRequest}, 75 | "s1/my/{tag}/world": {typeHttpWriter}, 76 | "s1/my/{tag}/delta/hello-world": {typeHttpRequest, typeHttpWriter}, 77 | "s1/my/{tag}/delta/hello/world": {typeHttpWriter, typeHttpRequest}, 78 | "s1/charlie/world": {typeHttpWriter}, 79 | "s1/charlie/delta/hello-world": {typeHttpRequest, typeHttpWriter}, 80 | "s1/charlie/delta/hello/world": {typeHttpWriter, typeHttpRequest}, 81 | "s1/echo/.custom/{pid}/_download_": {}, 82 | "s1/echo": {typeContext}, 83 | "s1/path/{0}": {}, 84 | "s1/path/{0}/{1}": {}, 85 | "s1/path/{0}/sub/{1}": {}, 86 | "s1/link/{0FP}": {}, 87 | "s1/link/{0FP}/{0123}": {}, 88 | } 89 | methods := serviceToMethods("s1/", s1) 90 | if len(methods) != len(routes) { 91 | t.Fatalf("expected %d methods got %d", len(routes), len(methods)) 92 | } 93 | for _, m := range methods { 94 | if _, ok := routes[m.path]; !ok { 95 | t.Fatalf("route %s not found", m.path) 96 | } 97 | if len(m.params) != 
len(routes[m.path]) { 98 | t.Errorf("%s param mismatch expected %d got %d", m.path, len(routes[m.path]), len(m.params)) 99 | } 100 | for i, v := range m.params { 101 | if v != routes[m.path][i] { 102 | t.Errorf("route mismatch %s", m.path) 103 | } 104 | } 105 | } 106 | } 107 | 108 | func TestMethodMustParseMatch(t *testing.T) { 109 | table := []struct { 110 | path string 111 | test string 112 | match bool 113 | params map[string]string 114 | }{ 115 | {"with/regex/{tag:\\d+}/{hello}", "with/regex/123/world", true, map[string]string{"tag": "123", "hello": "world"}}, 116 | {"path/{tag}/hello/{0}", "path/Anything/hello/129", true, map[string]string{"tag": "Anything", "0": "129"}}, 117 | {"catch/{all:.+}", "catch/hello/world/caught/all", true, map[string]string{"all": "hello/world/caught/all"}}, 118 | {"test/{a}", "test/", false, map[string]string{}}, 119 | } 120 | 121 | for _, v := range table { 122 | m := &method{path: v.path} 123 | m.mustParse() 124 | params, ok := matchPath(paramCache{path: m.path, pathParts: m.pathParts, pathRe: m.pathRe}, v.test) 125 | if ok != v.match { 126 | t.Errorf("path %s not match %s", v.path, v.test) 127 | } 128 | for k, p := range v.params { 129 | if p != params[k] { 130 | t.Errorf("got param %s want %s for %s in %s", params[k], p, k, v.path) 131 | } 132 | } 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /request_reader.go: -------------------------------------------------------------------------------- 1 | package restruct 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io/ioutil" 7 | "net/http" 8 | "reflect" 9 | ) 10 | 11 | type ( 12 | // RequestReader is called for input for your method if your parameter contains 13 | // a things other than *http.Request, http.ResponseWriter, context.Context 14 | // you'll get a slice of types and you must return values corresponding to those types 15 | RequestReader interface { 16 | Read(*http.Request, []reflect.Type) ([]reflect.Value, error) 
17 | } 18 | 19 | // DefaultReader processes request with json.Encoder, urlencoded form and multipart for structs 20 | // if it's just basic types it will be read from body as array such as [1, "hello", false] 21 | // you can overwrite bind to apply validation library, etc 22 | DefaultReader struct { 23 | Bind func(*http.Request, interface{}, ...string) error 24 | } 25 | ) 26 | 27 | func (dr *DefaultReader) Read(r *http.Request, types []reflect.Type) (vals []reflect.Value, err error) { 28 | typeLen := len(types) 29 | vals = make([]reflect.Value, typeLen) 30 | 31 | if typeLen == 0 { 32 | return 33 | } 34 | 35 | // if types is just 1 and a struct, we simply Bind and return 36 | if typeLen == 1 && (types[0].Kind() == reflect.Struct || 37 | types[0].Kind() == reflect.Ptr && types[0].Elem().Kind() == reflect.Struct) { 38 | var ptr bool 39 | arg := types[0] 40 | if arg.Kind() == reflect.Ptr { 41 | arg = arg.Elem() 42 | ptr = true 43 | } 44 | val := reflect.New(arg) 45 | err = dr.Bind(r, val.Interface()) 46 | if err != nil { 47 | return 48 | } 49 | if !ptr { 50 | val = val.Elem() 51 | } 52 | vals[0] = val 53 | return 54 | } 55 | // otherwise we get request body as json array 56 | badRequest := func(s string, f ...interface{}) { 57 | err = Error{ 58 | Status: http.StatusBadRequest, 59 | Err: fmt.Errorf(s, f...), 60 | } 61 | } 62 | var params []json.RawMessage 63 | var body []byte 64 | body, err = ioutil.ReadAll(r.Body) 65 | if err != nil { 66 | err = fmt.Errorf("DefaultReader.Read: ioutil.ReadAll error %v", err) 67 | return 68 | } 69 | err = r.Body.Close() 70 | if err != nil { 71 | err = fmt.Errorf("DefaultReader.Read: r.Body.Close error %v", err) 72 | return 73 | } 74 | err = json.Unmarshal(body, ¶ms) 75 | if err != nil { 76 | badRequest("DefaultReader.Read: json.Unmarshal error %v", err) 77 | return 78 | } 79 | if len(params) < typeLen { 80 | badRequest("DefaultReader.Read: missing params") 81 | return 82 | } 83 | for i := 0; i < typeLen; i++ { 84 | t := types[i] 85 | val 
:= reflect.New(t) 86 | err = json.Unmarshal(params[i], val.Interface()) 87 | if err != nil { 88 | badRequest("DefaultReader.Read: param %d must be %s (%v)", i, t, err) 89 | return 90 | } 91 | vals[i] = val.Elem() 92 | } 93 | return 94 | } 95 | -------------------------------------------------------------------------------- /request_reader_test.go: -------------------------------------------------------------------------------- 1 | package restruct_test 2 | 3 | import ( 4 | "context" 5 | "io/ioutil" 6 | "net/http" 7 | "net/http/httptest" 8 | "net/url" 9 | "strings" 10 | "testing" 11 | 12 | "github.com/altlimit/restruct" 13 | ) 14 | 15 | type ( 16 | sampleService struct{} 17 | 18 | addRequest struct { 19 | A int64 `json:"a" form:"a"` 20 | B int64 `json:"b" form:"b"` 21 | } 22 | ) 23 | 24 | func (ss *sampleService) Add(ctx context.Context, r *addRequest, x map[string]int64, y float64, z int) int64 { 25 | var total int64 26 | for _, v := range x { 27 | total += v 28 | } 29 | return total + r.A + r.B + int64(y) + int64(z) 30 | } 31 | 32 | func (ss *sampleService) Add2(ctx context.Context, r *addRequest) (int64, error) { 33 | return r.A + r.B, nil 34 | } 35 | 36 | func TestDefaultReaderRead(t *testing.T) { 37 | h := restruct.NewHandler(&sampleService{}) 38 | 39 | bod := `[{"a":10,"b":20}, {"S":30,"d":50}, 100, 200]` 40 | req := httptest.NewRequest(http.MethodPost, "/add", strings.NewReader(bod)) 41 | w := httptest.NewRecorder() 42 | 43 | h.ServeHTTP(w, req) 44 | 45 | res := w.Result() 46 | defer res.Body.Close() 47 | data, err := ioutil.ReadAll(res.Body) 48 | if err != nil { 49 | t.Errorf("ioutil.ReadAll error %v", err) 50 | } 51 | if strings.TrimRight(string(data), "\n") != "410" { 52 | t.Errorf("wanted 410 got %s", data) 53 | } 54 | 55 | ubod := url.Values{"a": {"4"}, "b": {"3"}} 56 | req = httptest.NewRequest(http.MethodPost, "/add2", strings.NewReader(ubod.Encode())) 57 | req.Header.Add("Content-Type", "application/x-www-form-urlencoded") 58 | w = 
httptest.NewRecorder() 59 | 60 | h.ServeHTTP(w, req) 61 | 62 | res = w.Result() 63 | defer res.Body.Close() 64 | data, err = ioutil.ReadAll(res.Body) 65 | if err != nil { 66 | t.Errorf("ioutil.ReadAll error %v", err) 67 | } 68 | if strings.TrimRight(string(data), "\n") != "7" { 69 | t.Errorf("wanted 7 got %s", data) 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /response_writer.go: -------------------------------------------------------------------------------- 1 | package restruct 2 | 3 | import ( 4 | "encoding/json" 5 | "log" 6 | "net/http" 7 | "reflect" 8 | ) 9 | 10 | type ( 11 | // ResponseWriter is called on outputs of your methods. 12 | // slice of reflect.Type & Value is the types and returned values 13 | ResponseWriter interface { 14 | Write(http.ResponseWriter, *http.Request, []reflect.Type, []reflect.Value) 15 | } 16 | 17 | // DefaultWriter uses json.Encoder for output 18 | // and manages error handling. Adding Errors mapping can 19 | // help with your existing error to a proper Error{} 20 | DefaultWriter struct { 21 | // Optional ErrorHandler, called whenever unhandled errors occurs, defaults to logging errors 22 | ErrorHandler func(error) 23 | Errors map[error]Error 24 | EscapeJsonHtml bool 25 | } 26 | 27 | // Response is used by DefaultWriter for custom response 28 | Response struct { 29 | Status int 30 | Headers map[string]string 31 | ContentType string 32 | Content []byte 33 | } 34 | 35 | // Json response to specify a status code for default writer 36 | Json struct { 37 | Status int 38 | Content interface{} 39 | } 40 | ) 41 | 42 | // Write implements the DefaultWriter ResponseWriter 43 | // returning (int, any, error) will write status int, any response if error is nil 44 | // returning (any, error) will write any response if error is nil with status 200 or 400, 500 depdening on your error 45 | // returning (int, any, any, error) will write status int slice of [any, any] response if error is nil 
46 | func (dw *DefaultWriter) Write(w http.ResponseWriter, r *http.Request, types []reflect.Type, vals []reflect.Value) { 47 | // no returns are not sent here so we just check if 1 or more 48 | lt := len(types) 49 | if lt == 1 { 50 | val := vals[0].Interface() 51 | if resp, ok := val.(*Response); ok { 52 | dw.WriteResponse(w, resp) 53 | } else { 54 | dw.WriteJSON(w, val) 55 | } 56 | return 57 | } 58 | var ( 59 | out interface{} 60 | j *Json 61 | ) 62 | defer func() { 63 | if j != nil { 64 | j.Content = out 65 | out = j 66 | } 67 | dw.WriteJSON(w, out) 68 | }() 69 | // return with last type error 70 | if types[lt-1] == typeError { 71 | errVal := vals[lt-1] 72 | if !errVal.IsNil() { 73 | out = errVal.Interface() 74 | return 75 | } 76 | vals = vals[:lt-1] 77 | } 78 | // returning (int, something) means status code, response 79 | if len(vals) > 1 && types[0] == typeInt { 80 | j = &Json{Status: int(vals[0].Int())} 81 | vals = vals[1:] 82 | } 83 | if len(vals) == 1 { 84 | out = vals[0].Interface() 85 | return 86 | } 87 | var args []interface{} 88 | for _, v := range vals { 89 | args = append(args, v.Interface()) 90 | } 91 | out = args 92 | } 93 | 94 | func (dw *DefaultWriter) log(err error) { 95 | if dw.ErrorHandler == nil { 96 | dw.ErrorHandler = func(err error) { 97 | log.Println("InternalError:", err) 98 | } 99 | } 100 | dw.ErrorHandler(err) 101 | } 102 | 103 | // This writes application/json content type uses status code 200 104 | // on valid ones and 500 on uncaught, 400 on malformed json, etc. 
105 | // use Json{Status, Content} to specify a code 106 | func (dw *DefaultWriter) WriteJSON(w http.ResponseWriter, out interface{}) { 107 | if w == nil { 108 | return 109 | } 110 | status := http.StatusOK 111 | if j, ok := out.(*Json); ok { 112 | if j.Status > 0 { 113 | status = j.Status 114 | } 115 | out = j.Content 116 | } 117 | if out == nil { 118 | w.WriteHeader(status) 119 | return 120 | } 121 | cType := "application/json; charset=UTF-8" 122 | 123 | var headers map[string]string 124 | if err, ok := out.(error); ok { 125 | status = http.StatusInternalServerError 126 | var ( 127 | msg string 128 | errData interface{} 129 | ) 130 | e, ok := err.(Error) 131 | if dw.Errors != nil && !ok { 132 | if ee, k := dw.Errors[err]; k { 133 | ok = true 134 | e = ee 135 | } 136 | } 137 | if ok { 138 | if e.Status != 0 { 139 | status = e.Status 140 | } 141 | if e.Message != "" { 142 | msg = e.Message 143 | } 144 | if e.Data != nil { 145 | errData = e.Data 146 | } 147 | if e.Err != nil { 148 | dw.log(e.Err) 149 | } 150 | } else { 151 | dw.log(err) 152 | } 153 | if msg == "" { 154 | msg = http.StatusText(status) 155 | } 156 | errResp := map[string]interface{}{ 157 | "error": msg, 158 | } 159 | if errData != nil { 160 | errResp["data"] = errData 161 | } 162 | out = errResp 163 | } 164 | 165 | w.WriteHeader(status) 166 | h := w.Header() 167 | foundContentType := false 168 | for k, v := range headers { 169 | if k == "Content-Type" { 170 | foundContentType = true 171 | } 172 | h.Add(k, v) 173 | } 174 | if !foundContentType { 175 | h.Set("Content-Type", cType) 176 | } 177 | enc := json.NewEncoder(w) 178 | enc.SetEscapeHTML(dw.EscapeJsonHtml) 179 | if err := enc.Encode(out); err != nil { 180 | dw.log(err) 181 | } 182 | } 183 | 184 | func (dw *DefaultWriter) WriteResponse(w http.ResponseWriter, resp *Response) { 185 | if resp.Status > 0 { 186 | w.WriteHeader(resp.Status) 187 | } 188 | if resp.ContentType != "" && w.Header().Get("Content-Type") == "" { 189 | 
w.Header().Add("Content-Type", resp.ContentType) 190 | } 191 | if len(resp.Content) > 0 { 192 | if _, err := w.Write(resp.Content); err != nil { 193 | dw.log(err) 194 | } 195 | } 196 | } 197 | -------------------------------------------------------------------------------- /restruct.go: -------------------------------------------------------------------------------- 1 | package restruct 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | // Handle registers a struct or a *Handler for the given pattern in the http.DefaultServeMux. 8 | func Handle(pattern string, svc interface{}) { 9 | h, ok := svc.(*Handler) 10 | if !ok { 11 | h = NewHandler(svc) 12 | } 13 | h.mustCompile(pattern) 14 | http.Handle(h.prefix, h) 15 | } 16 | -------------------------------------------------------------------------------- /router.go: -------------------------------------------------------------------------------- 1 | package restruct 2 | 3 | type ( 4 | // Router can be used to override method name to specific path, 5 | // implement Router interface in your service and return a slice of Route: 6 | // [Route{Handler:"ProductEdit", Path: "product/{pid}"}] 7 | Router interface { 8 | Routes() []Route 9 | } 10 | 11 | // Middlewares interface for common middleware for a struct 12 | Middlewares interface { 13 | Middlewares() []Middleware 14 | } 15 | 16 | // Init interface to access and override handler configs 17 | Init interface { 18 | Init(*Handler) 19 | } 20 | 21 | // Route for doing overrides with router interface and method restrictions. 
22 | Route struct { 23 | // Handler is the method name you want to use for this route 24 | Handler string 25 | // optional path, will use default behaviour if not present 26 | Path string 27 | // optional methods, will allow all if not present 28 | Methods []string 29 | // optional middlewares, run specific middleware for this route 30 | Middlewares []Middleware 31 | } 32 | ) 33 | -------------------------------------------------------------------------------- /structtag/structtag.go: -------------------------------------------------------------------------------- 1 | package structtag 2 | 3 | import ( 4 | "crypto/sha1" 5 | "encoding/hex" 6 | "reflect" 7 | "strings" 8 | "sync" 9 | ) 10 | 11 | var ( 12 | tagsCache sync.Map 13 | ) 14 | 15 | type StructField struct { 16 | Tag string 17 | Index int 18 | Tags map[string]string 19 | } 20 | 21 | func (sf *StructField) Value(tag string) (v string, ok bool) { 22 | v, ok = sf.Tags[tag] 23 | return 24 | } 25 | 26 | func NewStructField(index int, tag string) *StructField { 27 | tags := strings.Split(tag, ",") 28 | sf := &StructField{Index: index, Tags: make(map[string]string)} 29 | for i := 0; i < len(tags); i++ { 30 | t := strings.Split(tags[i], "=") 31 | var val string 32 | if len(t) > 1 { 33 | val = t[1] 34 | } 35 | if i == 0 { 36 | sf.Tag = t[0] 37 | } 38 | sf.Tags[t[0]] = val 39 | } 40 | return sf 41 | } 42 | 43 | func GetFieldsByTag(i interface{}, tag string) []*StructField { 44 | t := reflect.TypeOf(i) 45 | if t.Kind() == reflect.Ptr { 46 | t = t.Elem() 47 | } 48 | cKey := t.String() + ":" + tag 49 | if len(cKey) > 50 { 50 | h := sha1.New() 51 | h.Write([]byte(cKey)) 52 | cKey = hex.EncodeToString(h.Sum(nil)) 53 | } 54 | cache, ok := tagsCache.Load(cKey) 55 | if !ok { 56 | var fields []*StructField 57 | for i := 0; i < t.NumField(); i++ { 58 | f := t.Field(i) 59 | m := f.Tag.Get(tag) 60 | if m != "" { 61 | fields = append(fields, NewStructField(i, m)) 62 | } 63 | } 64 | tagsCache.Store(cKey, fields) 65 | cache = fields 
66 | } 67 | 68 | return cache.([]*StructField) 69 | } 70 | -------------------------------------------------------------------------------- /structtag/structtag_test.go: -------------------------------------------------------------------------------- 1 | package structtag 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestGetFieldsByTag(t *testing.T) { 8 | type hello struct { 9 | World string `json:"world" marshal:"f1"` 10 | X string `json:"X"` 11 | Y string `marshal:"f2,noindex=x"` 12 | Z string `marshal:"f1,f2,f3=abc"` 13 | } 14 | h := &hello{World: "123"} 15 | for _, f := range GetFieldsByTag(h, "marshal") { 16 | k := f.Tag 17 | if f.Index == 0 { 18 | if k != "f1" { 19 | t.Errorf("wanted f1 got %s", k) 20 | } 21 | if _, ok := f.Value("noindex"); ok { 22 | t.Errorf("wanted noindex not found but found") 23 | } 24 | } else if f.Index == 2 { 25 | if k != "f2" { 26 | t.Errorf("wanted f2 got %s", k) 27 | } 28 | if _, ok := f.Value("noindex"); !ok { 29 | t.Errorf("wanted noindex but found") 30 | } 31 | } else if f.Index == 3 { 32 | if k != "f1" { 33 | t.Errorf("wanted f1 got %s", k) 34 | } 35 | if v, ok := f.Value("f2"); !ok || v != "" { 36 | t.Errorf("wanted f2 but found") 37 | } 38 | if v, ok := f.Value("f3"); !ok || v != "abc" { 39 | t.Errorf("wanted f3 but found or v != abc -> %s", v) 40 | } 41 | } 42 | } 43 | if _, ok := tagsCache.Load("structtag.hello:marshal"); !ok { 44 | t.Errorf("No cache found") 45 | } 46 | z := &hello{World: "1235"} 47 | for _, f := range GetFieldsByTag(z, "marshal") { 48 | if f.Index == 0 { 49 | if _, ok := f.Value("noindex"); ok { 50 | t.Errorf("wanted noindex not found but found") 51 | } 52 | } else if f.Index == 2 { 53 | if v, ok := f.Value("noindex"); !ok { 54 | t.Errorf("wanted noindex but found") 55 | } else if v != "x" { 56 | t.Errorf("wanted x value found %s", v) 57 | } 58 | } 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /util.go: 
-------------------------------------------------------------------------------- 1 | package restruct 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "mime/multipart" 9 | "net/http" 10 | "reflect" 11 | "strconv" 12 | "strings" 13 | 14 | "github.com/altlimit/restruct/structtag" 15 | ) 16 | 17 | // Params returns map of params from url path like /{param1} will be map[param1] = value 18 | func Params(r *http.Request) map[string]string { 19 | return Vars(r.Context()) 20 | } 21 | 22 | // Vars returns map of params from url from request context 23 | func Vars(ctx context.Context) map[string]string { 24 | if params, ok := ctx.Value(keyParams).(map[string]string); ok { 25 | return params 26 | } 27 | return map[string]string{} 28 | } 29 | 30 | // Query returns a query string value 31 | func Query(r *http.Request, name string) string { 32 | return r.URL.Query().Get(name) 33 | } 34 | 35 | // Bind checks for valid methods and tries to bind query strings and body into struct 36 | func Bind(r *http.Request, out interface{}, methods ...string) error { 37 | if len(methods) > 0 { 38 | found := false 39 | for _, m := range methods { 40 | if r.Method == m { 41 | found = true 42 | break 43 | } 44 | } 45 | if !found { 46 | return Error{Status: http.StatusMethodNotAllowed} 47 | } 48 | } 49 | if out == nil { 50 | return nil 51 | } 52 | if len(r.URL.Query()) > 0 { 53 | if err := BindQuery(r, out); err != nil { 54 | return err 55 | } 56 | } 57 | if r.Method == http.MethodGet { 58 | return nil 59 | } 60 | cType := r.Header.Get("Content-Type") 61 | if idx := strings.Index(cType, ";"); idx != -1 { 62 | cType = cType[0:idx] 63 | } 64 | switch cType { 65 | case "application/json": 66 | return BindJson(r, out) 67 | case "application/x-www-form-urlencoded", "multipart/form-data": 68 | return BindForm(r, out) 69 | } 70 | return Error{Status: http.StatusUnsupportedMediaType} 71 | } 72 | 73 | // BindJson puts all json tagged values into struct fields 74 | func BindJson(r 
*http.Request, out interface{}) error { 75 | body, err := io.ReadAll(r.Body) 76 | if err != nil { 77 | return fmt.Errorf("Bind: io.ReadAll error %v", err) 78 | } 79 | if err := r.Body.Close(); err != nil { 80 | return fmt.Errorf("Bind: r.Body.Close error %v", err) 81 | } 82 | if err := json.Unmarshal(body, out); err != nil { 83 | return Error{ 84 | Status: http.StatusBadRequest, 85 | Err: fmt.Errorf("Bind: json.Unmarshal error %v", err), 86 | } 87 | } 88 | return nil 89 | } 90 | 91 | // BindQuery puts all query string values into struct fields with tag:"query" 92 | func BindQuery(r *http.Request, out interface{}) error { 93 | t := reflect.TypeOf(out) 94 | v := reflect.ValueOf(out) 95 | if t.Kind() == reflect.Ptr { 96 | v = v.Elem() 97 | } 98 | intSlice := reflect.TypeOf([]int{}) 99 | int64Slice := reflect.TypeOf([]int64{}) 100 | stringSlice := reflect.TypeOf([]string{}) 101 | toTypeSlice := func(vals reflect.Value, sliceType reflect.Type) interface{} { 102 | if sliceType == stringSlice { 103 | return vals.Interface() 104 | } 105 | newVals := reflect.New(sliceType).Elem() 106 | for i := 0; i < vals.Len(); i++ { 107 | val := vals.Index(i).String() 108 | var v interface{} 109 | switch sliceType { 110 | case intSlice: 111 | v, _ = strconv.Atoi(val) 112 | case int64Slice: 113 | v, _ = strconv.ParseInt(val, 10, 64) 114 | default: 115 | v = nil 116 | } 117 | if v != nil { 118 | newVals = reflect.Append(newVals, reflect.ValueOf(v)) 119 | } 120 | } 121 | return newVals.Interface() 122 | } 123 | for _, field := range structtag.GetFieldsByTag(out, "query") { 124 | tag := field.Tag 125 | if q := Query(r, tag); q != "" { 126 | vv := v.Field(field.Index) 127 | vk := vv.Kind() 128 | if vk != reflect.String { 129 | var val interface{} 130 | switch vk { 131 | case reflect.Int: 132 | val, _ = strconv.Atoi(q) 133 | case reflect.Int64: 134 | val, _ = strconv.ParseInt(q, 10, 64) 135 | case reflect.Slice: 136 | val = toTypeSlice(reflect.ValueOf(r.URL.Query()[tag]), vv.Type()) 137 | } 
138 | vv.Set(reflect.ValueOf(val)) 139 | } else { 140 | vv.Set(reflect.ValueOf(q)) 141 | } 142 | } 143 | } 144 | return nil 145 | } 146 | 147 | // BindForm puts all struct fields with tag:"form" from a form request 148 | func BindForm(r *http.Request, out interface{}) error { 149 | t := reflect.TypeOf(out) 150 | v := reflect.ValueOf(out) 151 | if t.Kind() == reflect.Ptr { 152 | v = v.Elem() 153 | } 154 | cType := r.Header.Get("Content-Type") 155 | formValues := make(map[string]interface{}) 156 | if strings.HasPrefix(cType, "application/x-www-form-urlencoded") { 157 | r.ParseForm() 158 | for k := range r.PostForm { 159 | formValues[k] = r.PostFormValue(k) 160 | } 161 | } else if strings.Contains(cType, "multipart/form-data") { 162 | r.ParseMultipartForm(32 << 20) 163 | for k := range r.PostForm { 164 | formValues[k] = r.FormValue(k) 165 | } 166 | for k, v := range r.MultipartForm.File { 167 | if strings.HasSuffix(k, "[]") { 168 | formValues[k[:len(k)-2]] = v 169 | } else { 170 | formValues[k] = v[0] 171 | } 172 | } 173 | } 174 | if len(formValues) == 0 { 175 | return nil 176 | } 177 | for _, field := range structtag.GetFieldsByTag(out, "form") { 178 | tag := field.Tag 179 | if formVal, ok := formValues[tag]; ok { 180 | vv := v.Field(field.Index) 181 | vk := vv.Kind() 182 | var val interface{} 183 | if vk == reflect.String { 184 | if v, ok := formVal.(string); ok { 185 | val = v 186 | } 187 | } else { 188 | switch vk { 189 | case reflect.Int: 190 | v := formVal.(string) 191 | val, _ = strconv.Atoi(v) 192 | case reflect.Int64: 193 | v := formVal.(string) 194 | val, _ = strconv.ParseInt(v, 10, 64) 195 | case reflect.Float64: 196 | v := formVal.(string) 197 | val, _ = strconv.ParseFloat(v, 64) 198 | case reflect.Ptr: 199 | if vv.Type() == typeMultipartFileHeader { 200 | if fh, ok := formVal.(*multipart.FileHeader); ok { 201 | val = fh 202 | } 203 | } 204 | case reflect.Slice: 205 | if vv.Type() == typeMultipartFileHeaderSlice { 206 | val = formVal 207 | } 208 | } 209 | 
} 210 | if val != nil { 211 | vv.Set(reflect.ValueOf(val)) 212 | } 213 | } 214 | } 215 | return nil 216 | } 217 | 218 | func GetVals(ctx context.Context) map[string]interface{} { 219 | vars, ok := ctx.Value(keyVals).(map[string]interface{}) 220 | if ok { 221 | return vars 222 | } 223 | return make(map[string]interface{}) 224 | } 225 | 226 | func SetVal(ctx context.Context, key string, val interface{}) context.Context { 227 | vals := GetVals(ctx) 228 | vals[key] = val 229 | return context.WithValue(ctx, keyVals, vals) 230 | } 231 | 232 | func GetVal(ctx context.Context, key string) interface{} { 233 | val, ok := GetVals(ctx)[key] 234 | if ok { 235 | return val 236 | } 237 | return nil 238 | 239 | } 240 | 241 | // GetValues returns a map of all values from context 242 | func GetValues(r *http.Request) map[string]interface{} { 243 | return GetVals(r.Context()) 244 | } 245 | 246 | // SetValue stores a key value pair in context 247 | func SetValue(r *http.Request, key string, val interface{}) *http.Request { 248 | return r.WithContext(SetVal(r.Context(), key, val)) 249 | } 250 | 251 | // GetValue returns the stored value from context 252 | func GetValue(r *http.Request, key string) interface{} { 253 | return GetVal(r.Context(), key) 254 | } 255 | 256 | func refTypes(types ...reflect.Type) []reflect.Type { 257 | return types 258 | } 259 | 260 | func refVals(vals ...interface{}) (values []reflect.Value) { 261 | for _, v := range vals { 262 | values = append(values, reflect.ValueOf(v)) 263 | } 264 | return 265 | } 266 | -------------------------------------------------------------------------------- /util_test.go: -------------------------------------------------------------------------------- 1 | package restruct 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "testing" 7 | ) 8 | 9 | func TestGetSetValues(t *testing.T) { 10 | r := &http.Request{} 11 | r2 := SetValue(r, "a", "1") 12 | a := GetValue(r, "a") 13 | if a == "1" { 14 | t.Errorf("Want a to blank got 1 %v", a) 
15 | } 16 | a = GetValue(r2, "a") 17 | if a != "1" { 18 | t.Errorf("Want a to be 1 got %v", a) 19 | } 20 | r = r2 21 | r = SetValue(r, "a", "2") 22 | r = SetValue(r, "b", "c") 23 | vals := GetValues(r) 24 | if vals["a"] != "2" { 25 | t.Errorf("Want a to be 2 got %v", vals["a"]) 26 | } 27 | if vals["b"] != "c" { 28 | t.Errorf("Want b to be c got %v", vals["b"]) 29 | } 30 | 31 | if x, ok := GetValue(r, "hello").(string); ok { 32 | t.Errorf("X %v %v", x, ok) 33 | } 34 | } 35 | 36 | func TestGetSetVals(t *testing.T) { 37 | ctx := context.Background() 38 | c2 := SetVal(ctx, "a", "1") 39 | a := GetVal(ctx, "a") 40 | if a == "1" { 41 | t.Errorf("Want a to blank got 1 %v", a) 42 | } 43 | a = GetVal(c2, "a") 44 | if a != "1" { 45 | t.Errorf("Want a to be 1 got %v", a) 46 | } 47 | ctx = c2 48 | ctx = SetVal(ctx, "a", "2") 49 | ctx = SetVal(ctx, "b", "c") 50 | vals := GetVals(ctx) 51 | if vals["a"] != "2" { 52 | t.Errorf("Want a to be 2 got %v", vals["a"]) 53 | } 54 | if vals["b"] != "c" { 55 | t.Errorf("Want b to be c got %v", vals["b"]) 56 | } 57 | 58 | if x, ok := GetVal(ctx, "hello").(string); ok { 59 | t.Errorf("X %v %v", x, ok) 60 | } 61 | } 62 | --------------------------------------------------------------------------------