├── .github ├── FUNDING.yml ├── dependabot.yml └── workflows │ └── go.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── core ├── body.go ├── body_test.go ├── core.go ├── core_test.go ├── default.go ├── default_test.go ├── directive.go ├── directive_test.go ├── directiveruntime.go ├── error.go ├── errorhandler.go ├── errorhandler_test.go ├── file.go ├── file_test.go ├── fileable.go ├── fileable_test.go ├── fileslicable.go ├── fileslicable_test.go ├── form.go ├── form_test.go ├── formencoder.go ├── formencoder_test.go ├── formextractor.go ├── header.go ├── header_test.go ├── hybridcoder.go ├── hybridcoder_test.go ├── nonzero.go ├── nonzero_test.go ├── omitempty.go ├── option.go ├── option_test.go ├── owl.go ├── patch_test.go ├── path.go ├── path_test.go ├── query.go ├── query_test.go ├── registry.go ├── requestbuilder.go ├── requestbuilder_test.go ├── required.go ├── required_test.go ├── stringable.go ├── stringable_test.go ├── stringslicable.go ├── stringslicable_test.go └── typekind.go ├── go.mod ├── go.sum ├── httpin.go ├── httpin_test.go ├── integration ├── echo.go ├── echo_test.go ├── gochi.go ├── gochi_test.go ├── gorilla.go ├── gorilla_test.go ├── http.go └── http_test.go ├── internal ├── misc.go ├── misc_test.go ├── stringable.go ├── stringable_test.go ├── stringableadaptor.go └── stringableadaptor_test.go └── patch ├── patch.go └── patch_test.go /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: ggicci 4 | patreon: # ggicci 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # ggicci 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # 
Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: gomod 4 | directory: / 5 | schedule: 6 | interval: daily 7 | - package-ecosystem: github-actions 8 | directory: / 9 | schedule: 10 | interval: daily 11 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - name: Set up Go 16 | uses: actions/setup-go@v5 17 | with: 18 | go-version: "1.23" 19 | 20 | - name: Test and generate coverage report 21 | run: make test/cover 22 | 23 | - name: Upload coverage to Codecov 24 | uses: codecov/codecov-action@v5 25 | with: 26 | name: codecov-umbrella 27 | token: ${{ secrets.CODECOV_TOKEN }} 28 | files: ./main.cover.out 29 | flags: unittests 30 | fail_ci_if_error: true 31 | verbose: true 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | .vscode 17 | 18 | # Project specified 19 | 
.DS_Store 20 | __debug* 21 | docs/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Ggicci 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | default: test 2 | 3 | SHELL=/usr/bin/env bash 4 | GO=go 5 | GOTEST=$(GO) test 6 | GOCOVER=$(GO) tool cover 7 | 8 | .PHONY: test 9 | test: test/cover test/report 10 | 11 | .PHONY: test/cover 12 | test/cover: 13 | $(GOTEST) -v -race -failfast -parallel 4 -cpu 4 -coverprofile main.cover.out ./... 
14 | 15 | .PHONY: test/report 16 | test/report: 17 | if [[ "$$HOSTNAME" =~ "codespaces-"* ]]; then \ 18 | mkdir -p /tmp/httpin_test; \ 19 | $(GOCOVER) -html=main.cover.out -o /tmp/httpin_test/coverage.html; \ 20 | sudo python -m http.server -d /tmp/httpin_test -b localhost 80; \ 21 | else \ 22 | $(GOCOVER) -html=main.cover.out; \ 23 | fi 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | httpin logo 3 | 4 | 5 | # httpin - HTTP Input for Go 6 | 7 |

HTTP Request from/to Go Struct

8 | 9 |
10 | 11 | [![Go](https://github.com/ggicci/httpin/actions/workflows/go.yml/badge.svg?branch=main)](https://github.com/ggicci/httpin/actions/workflows/go.yml) [![documentation](https://github.com/ggicci/httpin/actions/workflows/documentation.yml/badge.svg?branch=documentation)](https://github.com/ggicci/httpin/actions/workflows/documentation.yml) [![codecov](https://codecov.io/gh/ggicci/httpin/branch/main/graph/badge.svg?token=RT61L9ngHj)](https://codecov.io/gh/ggicci/httpin) [![Go Report Card](https://goreportcard.com/badge/github.com/ggicci/httpin)](https://goreportcard.com/report/github.com/ggicci/httpin) [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go) [![Go Reference](https://pkg.go.dev/badge/github.com/ggicci/httpin.svg)](https://pkg.go.dev/github.com/ggicci/httpin) 12 | 13 | 14 | 15 | 20 | 21 | 22 | 25 | 26 |
16 | 17 | 18 | 19 |
23 | Documentation 24 |
27 | 28 |
29 | 30 | ## Core Features 31 | 32 | **httpin** helps you easily decode data from an HTTP request, including: 33 | 34 | - **Query parameters**, e.g. `?name=john&is_member=true` 35 | - **Headers**, e.g. `Authorization: xxx` 36 | - **Form data**, e.g. `username=john&password=******` 37 | - **JSON/XML Body**, e.g. `POST {"name":"john"}` 38 | - **Path variables**, e.g. `/users/{username}` 39 | - **File uploads** 40 | 41 | You **only** need to define a struct to receive/bind data from an HTTP request, **without** writing any parsing code yourself. 42 | 43 | Since v0.15.0, httpin also supports creating an HTTP request (`http.Request`) from a Go struct instance. 44 | 45 | **httpin** is: 46 | 47 | - **well documented**: at https://ggicci.github.io/httpin/ 48 | - **easy to integrate**: with [net/http](https://ggicci.github.io/httpin/integrations/http), [go-chi/chi](https://ggicci.github.io/httpin/integrations/gochi), [gorilla/mux](https://ggicci.github.io/httpin/integrations/gorilla), [gin-gonic/gin](https://ggicci.github.io/httpin/integrations/gin), etc. 49 | - **extensible** (advanced feature): by adding your custom directives. Read [httpin - custom directives](https://ggicci.github.io/httpin/directives/custom) for more details. 50 | 51 | ## Add Httpin Directives by Tagging the Struct Fields with `in` 52 | 53 | ```go 54 | type ListUsersInput struct { 55 | Token string `in:"query=access_token;header=x-access-token"` 56 | Page int `in:"query=page;default=1"` 57 | PerPage int `in:"query=per_page;default=20"` 58 | IsMember bool `in:"query=is_member"` 59 | Search *string `in:"query=search;omitempty"` 60 | } 61 | ``` 62 | 63 | ## How to decode an HTTP request to Go struct? 64 | 65 | ```go 66 | func ListUsers(rw http.ResponseWriter, r *http.Request) { 67 | input := r.Context().Value(httpin.Input).(*ListUsersInput) 68 | 69 | if input.IsMember { 70 | // Do sth. 71 | } 72 | // Do sth. 73 | } 74 | ``` 75 | 76 | ## How to encode a Go struct to HTTP request?
77 | 78 | ```go 79 | func SDKListUsers() { 80 | payload := &ListUsersInput{ 81 | Token: os.Getenv("MY_APP_ACCESS_TOKEN"), 82 | Page: 2, 83 | IsMember: true, 84 | } 85 | 86 | // Easy to remember, http.NewRequest -> httpin.NewRequest 87 | req, err := httpin.NewRequest("GET", "/users", payload) 88 | // ... 89 | } 90 | ``` 91 | 92 | ## Why this package? 93 | 94 | #### Compared with using `net/http` package 95 | 96 | ```go 97 | func ListUsers(rw http.ResponseWriter, r *http.Request) { 98 | page, err := strconv.ParseInt(r.FormValue("page"), 10, 64) 99 | if err != nil { 100 | // Invalid parameter: page. 101 | return 102 | } 103 | perPage, err := strconv.ParseInt(r.FormValue("per_page"), 10, 64) 104 | if err != nil { 105 | // Invalid parameter: per_page. 106 | return 107 | } 108 | isMember, err := strconv.ParseBool(r.FormValue("is_member")) 109 | if err != nil { 110 | // Invalid parameter: is_member. 111 | return 112 | } 113 | 114 | // Do sth. 115 | } 116 | ``` 117 | 118 | | Benefits | Before (use net/http package) | After (use ggicci/httpin package) | 119 | | ----------------------- | ------------------------------------------ | ---------------------------------------------------------------------------------------------- | 120 | | ⌛️ Developer Time | 😫 Expensive (too much parsing stuff code) | 🚀 **Faster** (define the struct for receiving input data and leave the parsing job to httpin) | 121 | | ♻️ Code Repetition Rate | 😞 High | 😍 **Lower** | 122 | | 📖 Code Readability | 😟 Poor | 🤩 **Highly readable** | 123 | | 🔨 Maintainability | 😡 Poor | 🥰 **Highly maintainable** | 124 | 125 | ## Alternatives and Similars 126 | 127 | - [gorilla/schema](https://github.com/gorilla/schema): converts structs to and from form values 128 | - [google/go-querystring](https://github.com/google/go-querystring): encoding structs into URL query parameters 129 | -------------------------------------------------------------------------------- /core/body.go: 
-------------------------------------------------------------------------------- 1 | // directive: "body" 2 | // https://ggicci.github.io/httpin/directives/body 3 | 4 | package core 5 | 6 | import ( 7 | "bytes" 8 | "encoding/json" 9 | "encoding/xml" 10 | "errors" 11 | "fmt" 12 | "io" 13 | "strings" 14 | 15 | "github.com/ggicci/httpin/internal" 16 | ) 17 | 18 | // ErrUnknownBodyFormat is returned when a serializer for the specified body format has not been specified. 19 | var ErrUnknownBodyFormat = errors.New("unknown body format") 20 | 21 | // DirectiveBody is the implementation of the "body" directive. 22 | type DirectiveBody struct{} 23 | 24 | func (db *DirectiveBody) Decode(rtm *DirectiveRuntime) error { 25 | req := rtm.GetRequest() 26 | bodyFormat, bodySerializer := db.getSerializer(rtm) 27 | if bodySerializer == nil { 28 | return fmt.Errorf("%w: %q", ErrUnknownBodyFormat, bodyFormat) 29 | } 30 | if err := bodySerializer.Decode(req.Body, rtm.Value.Elem().Addr().Interface()); err != nil { 31 | return err 32 | } 33 | return nil 34 | } 35 | 36 | func (db *DirectiveBody) Encode(rtm *DirectiveRuntime) error { 37 | bodyFormat, bodySerializer := db.getSerializer(rtm) 38 | if bodySerializer == nil { 39 | return fmt.Errorf("%w: %q", ErrUnknownBodyFormat, bodyFormat) 40 | } 41 | if bodyReader, err := bodySerializer.Encode(rtm.Value.Interface()); err != nil { 42 | return err 43 | } else { 44 | rtm.GetRequestBuilder().SetBody(bodyFormat, io.NopCloser(bodyReader)) 45 | rtm.MarkFieldSet(true) 46 | return nil 47 | } 48 | } 49 | 50 | func (*DirectiveBody) getSerializer(rtm *DirectiveRuntime) (bodyFormat string, serializer BodySerializer) { 51 | bodyFormat = "json" 52 | if len(rtm.Directive.Argv) > 0 { 53 | bodyFormat = strings.ToLower(rtm.Directive.Argv[0]) 54 | } 55 | serializer = getBodySerializer(bodyFormat) 56 | return 57 | } 58 | 59 | var bodyFormats = map[string]BodySerializer{ 60 | "json": &JSONBody{}, 61 | "xml": &XMLBody{}, 62 | } 63 | 64 | // BodySerializer is the 
interface for encoding and decoding the request body. 65 | // Common body formats are: json, xml, yaml, etc. 66 | type BodySerializer interface { 67 | // Decode decodes the request body into the specified object. 68 | Decode(src io.Reader, dst any) error 69 | // Encode encodes the specified object into a reader for the request body. 70 | Encode(src any) (io.Reader, error) 71 | } 72 | 73 | // RegisterBodyFormat registers a new data formatter for the body request, which has the 74 | // BodyEncoderDecoder interface implemented. Panics on taken name, empty name or nil 75 | // decoder. Pass parameter force (true) to ignore the name conflict. 76 | // 77 | // The BodyEncoderDecoder is used by the body directive to decode and encode the data in 78 | // the given format (body format). 79 | // 80 | // It is also useful when you want to override the default registered 81 | // BodyEncoderDecoder. For example, the default JSON decoder is borrowed from 82 | // encoding/json. You can replace it with your own implementation, e.g. 83 | // json-iterator/go. 
For example: 84 | // 85 | // func init() { 86 | // RegisterBodyFormat("json", &myJSONBody{}, true) // force register, replace the old one 87 | // RegisterBodyFormat("yaml", &myYAMLBody{}) // register a new body format "yaml" 88 | // } 89 | func RegisterBodyFormat(format string, body BodySerializer, force ...bool) { 90 | internal.PanicOnError( 91 | registerBodyFormat(format, body, force...), 92 | ) 93 | } 94 | 95 | func getBodySerializer(bodyFormat string) BodySerializer { 96 | return bodyFormats[bodyFormat] 97 | } 98 | 99 | type JSONBody struct{} 100 | 101 | func (de *JSONBody) Decode(src io.Reader, dst any) error { 102 | return json.NewDecoder(src).Decode(dst) 103 | } 104 | 105 | func (en *JSONBody) Encode(src any) (io.Reader, error) { 106 | var buf bytes.Buffer 107 | if err := json.NewEncoder(&buf).Encode(src); err != nil { 108 | return nil, err 109 | } 110 | return &buf, nil 111 | } 112 | 113 | type XMLBody struct{} 114 | 115 | func (de *XMLBody) Decode(src io.Reader, dst any) error { 116 | return xml.NewDecoder(src).Decode(dst) 117 | } 118 | 119 | func (en *XMLBody) Encode(src any) (io.Reader, error) { 120 | var buf bytes.Buffer 121 | if err := xml.NewEncoder(&buf).Encode(src); err != nil { 122 | return nil, err 123 | } 124 | return &buf, nil 125 | } 126 | 127 | func registerBodyFormat(format string, body BodySerializer, force ...bool) error { 128 | ignoreConflict := len(force) > 0 && force[0] 129 | format = strings.ToLower(format) 130 | if !ignoreConflict && getBodySerializer(format) != nil { 131 | return fmt.Errorf("duplicate body format: %q", format) 132 | } 133 | if format == "" { 134 | return errors.New("body format cannot be empty") 135 | } 136 | if body == nil { 137 | return errors.New("body serializer cannot be nil") 138 | } 139 | bodyFormats[format] = body 140 | return nil 141 | } 142 | -------------------------------------------------------------------------------- /core/body_test.go: 
-------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "encoding/xml" 7 | "errors" 8 | "io" 9 | "net/http" 10 | "net/url" 11 | "strings" 12 | "testing" 13 | 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | type LanguageLevel struct { 18 | Language string `json:"lang" xml:"lang"` 19 | Level int `json:"level" xml:"level"` 20 | } 21 | 22 | type BodyPayload struct { 23 | Name string `json:"name" xml:"name"` 24 | Age int `json:"age" xml:"age"` 25 | IsNative bool `json:"is_native" xml:"is_native"` 26 | Hobbies []string `json:"hobbies" xml:"hobbies"` 27 | Languages []*LanguageLevel `json:"languages" xml:"languages"` 28 | } 29 | 30 | type BodyPayloadInJSON struct { 31 | Body *BodyPayload `in:"body=json"` 32 | } 33 | 34 | type BodyPayloadInXML struct { 35 | Body *BodyPayload `in:"body=xml"` 36 | } 37 | 38 | type YouCannotUseFormAndBodyAtTheSameTime struct { 39 | Page int `in:"form=page"` 40 | PageSize int `in:"form=page_size"` 41 | Body *BodyPayload `in:"body=json"` 42 | } 43 | 44 | var sampleBodyPayloadInJSONText = ` 45 | { 46 | "name": "Elia", 47 | "is_native": false, 48 | "age": 14, 49 | "hobbies": ["Gaming", "Drawing"], 50 | "languages": [ 51 | {"lang": "English", "level": 10}, 52 | {"lang": "Japanese", "level": 3} 53 | ] 54 | }` 55 | 56 | var sampleBodyPayloadInJSONObject = &BodyPayloadInJSON{ 57 | Body: &BodyPayload{ 58 | Name: "Elia", 59 | Age: 14, 60 | IsNative: false, 61 | Hobbies: []string{"Gaming", "Drawing"}, 62 | Languages: []*LanguageLevel{ 63 | {"English", 10}, 64 | {"Japanese", 3}, 65 | }, 66 | }, 67 | } 68 | 69 | var sampleBodyPayloadInXMLText = ` 70 | 71 | Elia 72 | 14 73 | false 74 | Gaming 75 | Drawing 76 | 77 | English 78 | 10 79 | 80 | 81 | Japanese 82 | 3 83 | 84 | ` 85 | 86 | var sampleBodyPayloadInXMLObject = &BodyPayloadInXML{ 87 | Body: &BodyPayload{ 88 | Name: "Elia", 89 | Age: 14, 90 | IsNative: false, 91 | Hobbies: []string{"Gaming", 
"Drawing"}, 92 | Languages: []*LanguageLevel{ 93 | {"English", 10}, 94 | {"Japanese", 3}, 95 | }, 96 | }, 97 | } 98 | 99 | func TestBodyDirective_Decode_JSON(t *testing.T) { 100 | assert := assert.New(t) 101 | co, err := New(BodyPayloadInJSON{}) 102 | assert.NoError(err) 103 | 104 | r, _ := http.NewRequest("GET", "https://example.com", nil) 105 | r.Form = make(url.Values) 106 | r.Form.Set("page", "4") 107 | r.Form.Set("page_size", "30") 108 | r.Body = makeBodyReader(sampleBodyPayloadInJSONText) 109 | r.Header.Set("Content-Type", "application/json") 110 | gotValue, err := co.Decode(r) 111 | assert.NoError(err) 112 | assert.Equal(sampleBodyPayloadInJSONObject, gotValue) 113 | } 114 | 115 | func TestBodyDirective_Decode_XML(t *testing.T) { 116 | assert := assert.New(t) 117 | co, err := New(BodyPayloadInXML{}) 118 | assert.NoError(err) 119 | 120 | r, _ := http.NewRequest("GET", "https://example.com", nil) 121 | r.Body = makeBodyReader(sampleBodyPayloadInXMLText) 122 | r.Header.Set("Content-Type", "application/xml") 123 | 124 | gotValue, err := co.Decode(r) 125 | assert.NoError(err) 126 | assert.Equal(sampleBodyPayloadInXMLObject, gotValue) 127 | } 128 | 129 | func TestBodyDirective_Decode_ErrUnknownBodyFormat(t *testing.T) { 130 | type UnknownBodyFormatPayload struct { 131 | Body *BodyPayload `in:"body=yaml"` 132 | } 133 | 134 | co, err := New(UnknownBodyFormatPayload{}) 135 | assert.NoError(t, err) 136 | req, _ := http.NewRequest("GET", "https://example.com", nil) 137 | req.Body = makeBodyReader(sampleBodyPayloadInJSONText) 138 | _, err = co.Decode(req) 139 | assert.ErrorContains(t, err, "unknown body format: \"yaml\"") 140 | } 141 | 142 | func TestBodyDirective_Decode_ErrConflictWithFormDirective(t *testing.T) { 143 | co, err := New(YouCannotUseFormAndBodyAtTheSameTime{}) 144 | assert.NoError(t, err) 145 | req, err := co.NewRequest("POST", "/data", &YouCannotUseFormAndBodyAtTheSameTime{ 146 | Page: 1, 147 | PageSize: 20, 148 | Body: 
sampleBodyPayloadInJSONObject.Body, 149 | }) 150 | assert.ErrorContains(t, err, "cannot use both form and body directive at the same time") 151 | assert.Nil(t, req) 152 | } 153 | 154 | type yamlBody struct{} 155 | 156 | var errYamlNotImplemented = errors.New("yaml not implemented") 157 | 158 | func (de *yamlBody) Decode(src io.Reader, dst any) error { 159 | return errYamlNotImplemented // for test only 160 | } 161 | 162 | func (en *yamlBody) Encode(src any) (io.Reader, error) { 163 | return nil, errYamlNotImplemented // for test only 164 | } 165 | 166 | type YamlInput struct { 167 | Body map[string]any `in:"body=yaml"` 168 | } 169 | 170 | func TestRegisterBodyFormat(t *testing.T) { 171 | assert.NotPanics(t, func() { 172 | RegisterBodyFormat("yaml", &yamlBody{}) 173 | }) 174 | assert.Panics(t, func() { 175 | RegisterBodyFormat("yaml", &yamlBody{}) 176 | }) 177 | 178 | co, err := New(YamlInput{}) 179 | assert.NoError(t, err) 180 | 181 | r, _ := http.NewRequest("GET", "https://example.com", nil) 182 | r.Body = makeBodyReader(`version: "3"`) 183 | 184 | gotValue, err := co.Decode(r) 185 | assert.ErrorIs(t, err, errYamlNotImplemented) 186 | assert.Nil(t, gotValue) 187 | unregisterBodyFormat("yaml") 188 | } 189 | 190 | func TestRegisterBodyFormat_ErrNilBodySerializer(t *testing.T) { 191 | assert.Panics(t, func() { 192 | RegisterBodyFormat("toml", nil) 193 | }) 194 | } 195 | 196 | func TestRegisterBodyFormat_ForceRegister(t *testing.T) { 197 | assert.NotPanics(t, func() { 198 | RegisterBodyFormat("yaml", &yamlBody{}, true) 199 | }) 200 | assert.NotPanics(t, func() { 201 | RegisterBodyFormat("yaml", &yamlBody{}, true) 202 | }) 203 | unregisterBodyFormat("yaml") 204 | } 205 | 206 | func TestRegisterBodyFormat_ForceRegisterWithEmptyBodyFormat(t *testing.T) { 207 | assert.PanicsWithError(t, "httpin: body format cannot be empty", func() { 208 | RegisterBodyFormat("", &yamlBody{}, true) 209 | }) 210 | } 211 | 212 | func TestBodyDirective_NewRequest_JSON(t *testing.T) { 213 | 
assert := assert.New(t) 214 | co, err := New(BodyPayloadInJSON{}) 215 | assert.NoError(err) 216 | req, err := co.NewRequest("POST", "/data", sampleBodyPayloadInJSONObject) 217 | expected, _ := http.NewRequest("POST", "/data", nil) 218 | expected.Header.Set("Content-Type", "application/json") 219 | assert.NoError(err) 220 | var body bytes.Buffer 221 | assert.NoError(json.NewEncoder(&body).Encode(sampleBodyPayloadInJSONObject.Body)) 222 | expected.Body = io.NopCloser(&body) 223 | assert.Equal(expected, req) 224 | 225 | // On the server side (decode). 226 | gotValue, err := co.Decode(req) 227 | assert.NoError(err) 228 | got, ok := gotValue.(*BodyPayloadInJSON) 229 | assert.True(ok) 230 | assert.Equal(sampleBodyPayloadInJSONObject, got) 231 | } 232 | 233 | func TestBodyDirective_NewRequest_XML(t *testing.T) { 234 | assert := assert.New(t) 235 | co, err := New(BodyPayloadInXML{}) 236 | assert.NoError(err) 237 | req, err := co.NewRequest("POST", "/data", sampleBodyPayloadInXMLObject) 238 | expected, _ := http.NewRequest("POST", "/data", nil) 239 | expected.Header.Set("Content-Type", "application/xml") 240 | assert.NoError(err) 241 | var body bytes.Buffer 242 | assert.NoError(xml.NewEncoder(&body).Encode(sampleBodyPayloadInXMLObject.Body)) 243 | expected.Body = io.NopCloser(&body) 244 | assert.Equal(expected, req) 245 | 246 | // On the server side (decode). 
247 | gotValue, err := co.Decode(req) 248 | assert.NoError(err) 249 | got, ok := gotValue.(*BodyPayloadInXML) 250 | assert.True(ok) 251 | assert.Equal(sampleBodyPayloadInXMLObject, got) 252 | } 253 | 254 | func TestBodyDirective_NewRequest_ErrUnknownBodyFormat(t *testing.T) { 255 | type UnknownBodyFormatPayload struct { 256 | Body *BodyPayload `in:"body=yaml"` 257 | } 258 | query := &UnknownBodyFormatPayload{ 259 | Body: nil, 260 | } 261 | co, err := New(UnknownBodyFormatPayload{}) 262 | assert.NoError(t, err) 263 | req, err := co.NewRequest("PUT", "/apples/10", query) 264 | assert.ErrorContains(t, err, "unknown body format: \"yaml\"") 265 | assert.Nil(t, req) 266 | } 267 | 268 | func unregisterBodyFormat(format string) { 269 | delete(bodyFormats, format) 270 | } 271 | 272 | func makeBodyReader(text string) io.ReadCloser { 273 | return io.NopCloser(strings.NewReader(text)) 274 | } 275 | -------------------------------------------------------------------------------- /core/core.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "mime" 8 | "net/http" 9 | "reflect" 10 | "sort" 11 | "sync" 12 | 13 | "github.com/ggicci/owl" 14 | ) 15 | 16 | var builtResolvers sync.Map // map[reflect.Type]*owl.Resolver 17 | 18 | // Core is the Core of httpin. It holds the resolver of a specific struct type. 19 | // Who is responsible for decoding an HTTP request to an instance of such struct 20 | // type. 21 | type Core struct { 22 | resolver *owl.Resolver // for decoding 23 | scanResolver *owl.Resolver // for encoding 24 | errorHandler ErrorHandler 25 | maxMemory int64 // in bytes 26 | enableNestedDirectives bool 27 | resolverMu sync.RWMutex 28 | } 29 | 30 | // New creates a new Core instance for the given intpuStruct. It will build a resolver 31 | // for the inputStruct and apply the given options to the Core instance. 
The Core instance 32 | // is responsible for both: 33 | // 34 | // - decoding an HTTP request to an instance of the inputStruct; 35 | // - encoding an instance of the inputStruct to an HTTP request. 36 | func New(inputStruct any, opts ...Option) (*Core, error) { 37 | resolver, err := buildResolver(inputStruct) 38 | if err != nil { 39 | return nil, err 40 | } 41 | 42 | core := &Core{ 43 | resolver: resolver, 44 | } 45 | 46 | // Apply default options and user custom options to the 47 | var allOptions []Option 48 | defaultOptions := []Option{ 49 | WithMaxMemory(defaultMaxMemory), 50 | WithNestedDirectivesEnabled(globalNestedDirectivesEnabled), 51 | } 52 | allOptions = append(allOptions, defaultOptions...) 53 | allOptions = append(allOptions, opts...) 54 | 55 | for _, opt := range allOptions { 56 | if err := opt(core); err != nil { 57 | return nil, fmt.Errorf("invalid option: %w", err) 58 | } 59 | } 60 | 61 | return core, nil 62 | } 63 | 64 | // Decode decodes an HTTP request to an instance of the input struct and returns 65 | // its pointer. For example: 66 | // 67 | // New(Input{}).Decode(req) -> *Input 68 | func (c *Core) Decode(req *http.Request) (any, error) { 69 | // Create the input struct instance. Used to be created by owl.Resolve(). 70 | value := reflect.New(c.resolver.Type).Interface() 71 | if err := c.DecodeTo(req, value); err != nil { 72 | return nil, err 73 | } else { 74 | return value, nil 75 | } 76 | } 77 | 78 | // DecodeTo decodes an HTTP request to the given value. The value must be a pointer 79 | // to the struct instance of the type that the Core instance holds. 
80 | func (c *Core) DecodeTo(req *http.Request, value any) (err error) { 81 | if err = c.parseRequestForm(req); err != nil { 82 | return fmt.Errorf("failed to parse request form: %w", err) 83 | } 84 | 85 | err = c.resolver.ResolveTo( 86 | value, 87 | owl.WithNamespace(decoderNamespace), 88 | owl.WithValue(CtxRequest, req), 89 | owl.WithNestedDirectivesEnabled(c.enableNestedDirectives), 90 | ) 91 | if err != nil && !errors.Is(err, owl.ErrInvalidResolveTarget) { 92 | return NewInvalidFieldError(err) 93 | } 94 | return err 95 | } 96 | 97 | // NewRequest wraps NewRequestWithContext using context.Background(), see 98 | // NewRequestWithContext. 99 | func (c *Core) NewRequest(method string, url string, input any) (*http.Request, error) { 100 | return c.NewRequestWithContext(context.Background(), method, url, input) 101 | } 102 | 103 | // NewRequestWithContext turns the given input struct into an HTTP request. Note 104 | // that the Core instance is bound to a specific type of struct. Which means 105 | // when the given input is not the type of the struct that the Core instance 106 | // holds, error of type mismatch will be returned. In order to avoid this error, 107 | // you can always use httpin.NewRequest() instead. Which will create a Core 108 | // instance for you on demand. There's no performance penalty for doing so. 109 | // Because there's a cache layer for all the Core instances. 110 | func (c *Core) NewRequestWithContext(ctx context.Context, method string, url string, input any) (*http.Request, error) { 111 | c.prepareScanResolver() 112 | req, err := http.NewRequestWithContext(ctx, method, url, nil) 113 | if err != nil { 114 | return nil, err 115 | } 116 | 117 | rb := NewRequestBuilder(ctx) 118 | 119 | // NOTE(ggicci): the error returned a joined error by using errors.Join. 
120 | if err = c.scanResolver.Scan( 121 | input, 122 | owl.WithNamespace(encoderNamespace), 123 | owl.WithValue(CtxRequestBuilder, rb), 124 | owl.WithNestedDirectivesEnabled(c.enableNestedDirectives), 125 | ); err != nil { 126 | // err is a list of *owl.ScanError that joined by errors.Join. 127 | if errs, ok := err.(interface{ Unwrap() []error }); ok { 128 | var invalidFieldErrors MultiInvalidFieldError 129 | for _, err := range errs.Unwrap() { 130 | invalidFieldErrors = append(invalidFieldErrors, NewInvalidFieldError(err)) 131 | } 132 | return nil, invalidFieldErrors 133 | } else { 134 | return nil, err // should never happen, just in case 135 | } 136 | } 137 | 138 | // Populate the request with the encoded values. 139 | if err := rb.Populate(req); err != nil { 140 | return nil, fmt.Errorf("failed to populate request: %w", err) 141 | } 142 | 143 | return req, nil 144 | } 145 | 146 | // GetErrorHandler returns the error handler of the core if set, or the global 147 | // custom error handler. 148 | func (c *Core) GetErrorHandler() ErrorHandler { 149 | if c.errorHandler != nil { 150 | return c.errorHandler 151 | } 152 | 153 | return globalCustomErrorHandler 154 | } 155 | 156 | func (c *Core) prepareScanResolver() { 157 | c.resolverMu.RLock() 158 | if c.scanResolver == nil { 159 | c.resolverMu.RUnlock() 160 | c.resolverMu.Lock() 161 | defer c.resolverMu.Unlock() 162 | 163 | if c.scanResolver == nil { 164 | c.scanResolver = c.resolver.Copy() 165 | 166 | // Reorder the directives to make sure the "default" and "nonzero" directives work properly. 
167 | c.scanResolver.Iterate(func(r *owl.Resolver) error { 168 | sort.Sort(directiveOrderForEncoding(r.Directives)) 169 | return nil 170 | }) 171 | } 172 | } else { 173 | c.resolverMu.RUnlock() 174 | } 175 | } 176 | 177 | func (c *Core) parseRequestForm(req *http.Request) (err error) { 178 | ct, _, _ := mime.ParseMediaType(req.Header.Get("Content-Type")) 179 | if ct == "multipart/form-data" { 180 | err = req.ParseMultipartForm(c.maxMemory) 181 | } else { 182 | err = req.ParseForm() 183 | } 184 | return 185 | } 186 | 187 | // buildResolver builds a resolver for the inputStruct. It will run normalizations 188 | // on the resolver and cache it. 189 | func buildResolver(inputStruct any) (*owl.Resolver, error) { 190 | resolver, err := owl.New(inputStruct) 191 | if err != nil { 192 | return nil, err 193 | } 194 | 195 | // Returns the cached resolver if it's already built. 196 | if cached, ok := builtResolvers.Load(resolver.Type); ok { 197 | return cached.(*owl.Resolver), nil 198 | } 199 | 200 | // Normalize the resolver before caching it. 201 | if err := normalizeResolver(resolver); err != nil { 202 | return nil, err 203 | } 204 | 205 | // Cache the resolver. 206 | builtResolvers.Store(resolver.Type, resolver) 207 | return resolver, nil 208 | } 209 | 210 | // normalizeResolver normalizes the resolvers by running a series of 211 | // normalizations on every field resolver. 
212 | func normalizeResolver(r *owl.Resolver) error { 213 | normalize := func(r *owl.Resolver) error { 214 | for _, fn := range []func(*owl.Resolver) error{ 215 | removeDecoderDirective, // backward compatibility, use "coder" instead 216 | removeCoderDirective, // "coder" takes precedence over "decoder" 217 | ensureDirectiveExecutorsRegistered, // always the last one 218 | } { 219 | if err := fn(r); err != nil { 220 | return err 221 | } 222 | } 223 | return nil 224 | } 225 | 226 | return r.Iterate(normalize) 227 | } 228 | 229 | func removeDecoderDirective(r *owl.Resolver) error { 230 | return reserveCoderDirective(r, "decoder") 231 | } 232 | 233 | func removeCoderDirective(r *owl.Resolver) error { 234 | return reserveCoderDirective(r, "coder") 235 | } 236 | 237 | // reserveCoderDirective removes the directive from the resolver. name is "coder" or "decoder". 238 | // The "decoder"/"coder"are two special directives which do nothing, but an indicator of 239 | // overriding the decoder and encoder for a specific field. 240 | func reserveCoderDirective(r *owl.Resolver, name string) error { 241 | d := r.RemoveDirective(name) 242 | if d == nil { 243 | return nil 244 | } 245 | if len(d.Argv) == 0 { 246 | return fmt.Errorf("directive %s: missing coder name", name) 247 | } 248 | 249 | if isFileType(r.Type) { 250 | return fmt.Errorf("directive %s: cannot be used on a file type field", name) 251 | } 252 | 253 | namedAdaptor := namedStringableAdaptors[d.Argv[0]] 254 | if namedAdaptor == nil { 255 | return fmt.Errorf("directive %s: %w: %q", name, ErrUnregisteredCoder, d.Argv[0]) 256 | } 257 | 258 | r.Context = context.WithValue(r.Context, CtxCustomCoder, namedAdaptor) 259 | return nil 260 | } 261 | 262 | // ensureDirectiveExecutorsRegistered ensures all directives that defined in the 263 | // resolver are registered in the executor registry. 
264 | func ensureDirectiveExecutorsRegistered(r *owl.Resolver) error { 265 | for _, d := range r.Directives { 266 | if decoderNamespace.LookupExecutor(d.Name) == nil { 267 | return fmt.Errorf("%w: %q", ErrUnregisteredDirective, d.Name) 268 | } 269 | // NOTE: don't need to check encoderNamespace because a directive 270 | // will always be registered in both namespaces. See RegisterDirective(). 271 | } 272 | return nil 273 | } 274 | 275 | type directiveOrderForEncoding []*owl.Directive 276 | 277 | func (d directiveOrderForEncoding) Len() int { 278 | return len(d) 279 | } 280 | 281 | func (d directiveOrderForEncoding) Less(i, j int) bool { 282 | if d[i].Name == "default" { 283 | return true // always the first one to run 284 | } else if d[i].Name == "nonzero" { 285 | return true // always the second one to run 286 | } 287 | return false 288 | } 289 | 290 | func (d directiveOrderForEncoding) Swap(i, j int) { 291 | d[i], d[j] = d[j], d[i] 292 | } 293 | -------------------------------------------------------------------------------- /core/default.go: -------------------------------------------------------------------------------- 1 | // directive: "default" 2 | // https://ggicci.github.io/httpin/directives/default 3 | 4 | package core 5 | 6 | import ( 7 | "mime/multipart" 8 | ) 9 | 10 | type DirectiveDefault struct{} 11 | 12 | func (*DirectiveDefault) Decode(rtm *DirectiveRuntime) error { 13 | if rtm.IsFieldSet() { 14 | return nil // noop, the field was set by a former executor 15 | } 16 | 17 | // Transform: 18 | // 1. ctx.Argv -> input values 19 | // 2. 
["default"] -> keys 20 | extractor := &FormExtractor{ 21 | Runtime: rtm, 22 | Form: multipart.Form{ 23 | Value: map[string][]string{ 24 | "default": rtm.Directive.Argv, 25 | }, 26 | }, 27 | } 28 | return extractor.Extract("default") 29 | } 30 | 31 | func (*DirectiveDefault) Encode(rtm *DirectiveRuntime) error { 32 | if !rtm.Value.IsZero() { 33 | return nil // skip if the field is not empty 34 | } 35 | var adapt AnyStringableAdaptor 36 | coder := rtm.GetCustomCoder() 37 | if coder != nil { 38 | adapt = coder.Adapt 39 | } 40 | if stringSlicable, err := NewStringSlicable(rtm.Value, adapt); err != nil { 41 | return err 42 | } else { 43 | return stringSlicable.FromStringSlice(rtm.Directive.Argv) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /core/default_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | "testing" 7 | "time" 8 | 9 | "github.com/ggicci/httpin/internal" 10 | "github.com/ggicci/httpin/patch" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestDirectiveDefault_Decode(t *testing.T) { 15 | type ThingWithDefaultValues struct { 16 | Page int `in:"form=page;default=1"` 17 | PointerPage *int `in:"form=pointer_page;default=1"` 18 | PatchPage patch.Field[int] `in:"form=patch_page;default=1"` 19 | PerPage int `in:"form=per_page;default=20"` 20 | StateList []string `in:"form=state;default=pending,in_progress,failed"` 21 | PatchStateList patch.Field[[]string] `in:"form=patch_state;default=a,b,c"` 22 | } 23 | 24 | r, _ := http.NewRequest("GET", "/", nil) 25 | r.Form = url.Values{ 26 | "page": {"7"}, 27 | "pointer_page": {"9"}, 28 | "patch_page": {"11"}, 29 | "state": {}, 30 | "patch_state": {}, 31 | } 32 | expected := &ThingWithDefaultValues{ 33 | Page: 7, 34 | PointerPage: internal.Pointerize[int](9), 35 | PatchPage: patch.Field[int]{Value: 11, Valid: true}, 36 | PerPage: 20, 37 | StateList: 
[]string{"pending", "in_progress", "failed"}, 38 | PatchStateList: patch.Field[[]string]{Value: []string{"a", "b", "c"}, Valid: true}, 39 | } 40 | co, err := New(ThingWithDefaultValues{}) 41 | assert.NoError(t, err) 42 | got, err := co.Decode(r) 43 | assert.NoError(t, err) 44 | assert.Equal(t, expected, got) 45 | } 46 | 47 | // FIX: https://github.com/ggicci/httpin/issues/77 48 | // Decode parameter struct with default values only works the first time 49 | func TestDirectiveDeafult_Decode_DecodeTwice(t *testing.T) { 50 | type ThingWithDefaultValues struct { 51 | Id uint `in:"query=id;required"` 52 | Page int `in:"query=page;default=1"` 53 | PerPage int `in:"query=page_size;default=127"` 54 | } 55 | 56 | r, _ := http.NewRequest("GET", "/?id=123", nil) 57 | expected := &ThingWithDefaultValues{ 58 | Id: 123, 59 | Page: 1, 60 | PerPage: 127, 61 | } 62 | 63 | co, err := New(ThingWithDefaultValues{}) 64 | assert.NoError(t, err) 65 | 66 | // First decode works as expected 67 | xxx, err := co.Decode(r) 68 | assert.NoError(t, err) 69 | assert.Equal(t, expected, xxx) 70 | 71 | // Second decode generates error 72 | xxx, err = co.Decode(r) 73 | assert.NoError(t, err) 74 | assert.Equal(t, expected, xxx) 75 | } 76 | 77 | func TestDirectiveDefault_NewRequest(t *testing.T) { 78 | type ListTicketRequest struct { 79 | Page int `in:"query=page;default=1"` 80 | PerPage int `in:"query=per_page;default=20"` 81 | States []string `in:"query=state;default=assigned,in_progress"` 82 | } 83 | 84 | co, err := New(ListTicketRequest{}) 85 | assert.NoError(t, err) 86 | 87 | payload := &ListTicketRequest{ 88 | Page: 2, 89 | } 90 | expected, _ := http.NewRequest("GET", "/tickets", nil) 91 | expected.URL.RawQuery = url.Values{ 92 | "page": {"2"}, 93 | "per_page": {"20"}, 94 | "state": {"assigned", "in_progress"}, 95 | }.Encode() 96 | req, err := co.NewRequest("GET", "/tickets", payload) 97 | assert.NoError(t, err) 98 | assert.Equal(t, expected, req) 99 | } 100 | 101 | func 
TestDirectiveDefault_NewRequest_WithNamedCoder(t *testing.T) { 102 | registerMyDate() 103 | type ListUsersRequest struct { 104 | Page int `in:"query=page;default=1"` 105 | PerPage int `in:"query=per_page;default=20"` 106 | RegistrationDate time.Time `in:"query=registration_date;default=2020-01-01;coder=mydate"` 107 | } 108 | 109 | co, err := New(ListUsersRequest{}) 110 | assert.NoError(t, err) 111 | 112 | payload := &ListUsersRequest{ 113 | Page: 2, 114 | PerPage: 10, 115 | } 116 | expected, _ := http.NewRequest("GET", "/users", nil) 117 | expected.URL.RawQuery = url.Values{ 118 | "page": {"2"}, 119 | "per_page": {"10"}, 120 | "registration_date": {"2020-01-01"}, 121 | }.Encode() 122 | req, err := co.NewRequest("GET", "/users", payload) 123 | assert.NoError(t, err) 124 | assert.Equal(t, expected, req) 125 | unregisterMyDate() 126 | } 127 | -------------------------------------------------------------------------------- /core/directive.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "github.com/ggicci/httpin/internal" 8 | "github.com/ggicci/owl" 9 | ) 10 | 11 | func init() { 12 | RegisterDirective("form", &DirectvieForm{}) 13 | RegisterDirective("query", &DirectiveQuery{}) 14 | RegisterDirective("header", &DirectiveHeader{}) 15 | RegisterDirective("body", &DirectiveBody{}) 16 | RegisterDirective("required", &DirectiveRequired{}) 17 | RegisterDirective("default", &DirectiveDefault{}) 18 | RegisterDirective("nonzero", &DirectiveNonzero{}) 19 | registerDirective("path", defaultPathDirective) 20 | registerDirective("omitempty", &DirectiveOmitEmpty{}) 21 | 22 | // decoder is a special executor which does nothing, but is an indicator of 23 | // overriding the decoder for a specific field. 
24 | registerDirective("decoder", noopDirective) 25 | registerDirective("coder", noopDirective) 26 | } 27 | 28 | var ( 29 | // decoderNamespace is the namespace for registering directive executors that are 30 | // used to decode the http request to input struct. 31 | decoderNamespace = owl.NewNamespace() 32 | 33 | // encoderNamespace is the namespace for registering directive executors that are 34 | // used to encode the input struct to http request. 35 | encoderNamespace = owl.NewNamespace() 36 | 37 | // reservedExecutorNames are the names that cannot be used to register user defined directives 38 | reservedExecutorNames = []string{"decoder", "coder"} 39 | 40 | noopDirective = &directiveNoop{} 41 | ) 42 | 43 | type DirectiveExecutor interface { 44 | // Encode encodes the field of the input struct to the HTTP request. 45 | Encode(*DirectiveRuntime) error 46 | 47 | // Decode decodes the field of the input struct from the HTTP request. 48 | Decode(*DirectiveRuntime) error 49 | } 50 | 51 | // RegisterDirective registers a DirectiveExecutor with the given directive name. The 52 | // directive should be able to both extract the value from the HTTP request and build 53 | // the HTTP request from the value. The Decode API is used to decode data from the HTTP 54 | // request to a field of the input struct, and Encode API is used to encode the field of 55 | // the input struct to the HTTP request. 56 | // 57 | // Will panic if the name were taken or given executor is nil. Pass parameter force 58 | // (true) to ignore the name conflict. 59 | func RegisterDirective(name string, executor DirectiveExecutor, force ...bool) { 60 | panicOnReservedExecutorName(name) 61 | registerDirective(name, executor, force...) 62 | } 63 | 64 | func registerDirective(name string, executor DirectiveExecutor, force ...bool) { 65 | registerDirectiveExecutorToNamespace(decoderNamespace, name, executor, force...) 66 | registerDirectiveExecutorToNamespace(encoderNamespace, name, executor, force...) 
67 | } 68 | 69 | func registerDirectiveExecutorToNamespace(ns *owl.Namespace, name string, exe DirectiveExecutor, force ...bool) { 70 | if exe == nil { 71 | internal.PanicOnError(errors.New("nil directive executor")) 72 | } 73 | if ns == decoderNamespace { 74 | ns.RegisterDirectiveExecutor(name, asOwlDirectiveExecutor(exe.Decode), force...) 75 | } else { 76 | ns.RegisterDirectiveExecutor(name, asOwlDirectiveExecutor(exe.Encode), force...) 77 | } 78 | } 79 | 80 | func asOwlDirectiveExecutor(directiveFunc func(*DirectiveRuntime) error) owl.DirectiveExecutor { 81 | return owl.DirectiveExecutorFunc(func(dr *owl.DirectiveRuntime) error { 82 | return directiveFunc((*DirectiveRuntime)(dr)) 83 | }) 84 | } 85 | 86 | func panicOnReservedExecutorName(name string) { 87 | for _, reservedName := range reservedExecutorNames { 88 | if name == reservedName { 89 | internal.PanicOnError(fmt.Errorf("reserved executor name: %q", name)) 90 | } 91 | } 92 | } 93 | 94 | // directiveNoop is a DirectiveExecutor that does nothing, "noop" stands for "no operation". 
type directiveNoop struct{}

// Encode and Decode both return nil without reading or modifying the field.
func (*directiveNoop) Encode(*DirectiveRuntime) error { return nil }
func (*directiveNoop) Decode(*DirectiveRuntime) error { return nil }
--------------------------------------------------------------------------------
/core/directive_test.go:
--------------------------------------------------------------------------------
package core

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestRegisterDirectiveExecutor checks the panics of RegisterDirective without
// the force flag: duplicate names, nil executors and reserved names.
func TestRegisterDirectiveExecutor(t *testing.T) {
	assert.NotPanics(t, func() {
		RegisterDirective("noop_TestRegisterDirectiveExecutor", noopDirective)
	})

	assert.Panics(t, func() {
		RegisterDirective("noop_TestRegisterDirectiveExecutor", noopDirective)
	}, "should panic on duplicate name")

	assert.Panics(t, func() {
		RegisterDirective("nil_TestRegisterDirectiveExecutor", nil)
	}, "should panic on nil executor")

	assert.Panics(t, func() {
		RegisterDirective("decoder", noopDirective)
	}, "should panic on reserved name")
}

// TestRegisterDirectiveExecutor_ForceReplace checks that force=true only lifts
// the duplicate-name check; nil executors and reserved names still panic.
func TestRegisterDirectiveExecutor_ForceReplace(t *testing.T) {
	assert.NotPanics(t, func() {
		RegisterDirective("noop_TestRegisterDirectiveExecutor_forceReplace", noopDirective, true)
	})

	assert.NotPanics(t, func() {
		RegisterDirective("noop_TestRegisterDirectiveExecutor_forceReplace", noopDirective, true)
	}, "should not panic on duplicate name")

	assert.Panics(t, func() {
		RegisterDirective("nil_TestRegisterDirectiveExecutor_forceReplace", nil, true)
	}, "should panic on nil executor")

	assert.Panics(t, func() {
		RegisterDirective("decoder", noopDirective, true)
	}, "should panic on reserved name")
}
--------------------------------------------------------------------------------
/core/directiveruntime.go:
--------------------------------------------------------------------------------
package core

import (
	"context"
	"fmt"
	"net/http"
	"reflect"

	"github.com/ggicci/owl"
)

// contextKey is the private type of httpin's context keys, which keeps them
// from colliding with keys defined by other packages.
type contextKey int

const (
	// CtxRequest is the key to get the HTTP request value (of *http.Request)
	// from DirectiveRuntime.Context. The HTTP request value is injected by
	// httpin to the context of DirectiveRuntime before executing the directive.
	// See Core.Decode() for more details.
	CtxRequest contextKey = iota

	// CtxRequestBuilder is the key to get the request builder value (of
	// *RequestBuilder) from DirectiveRuntime.Context.
	CtxRequestBuilder

	// CtxCustomCoder is the key to get the custom coder of a field from
	// Resolver.Context. The coder is specified by the "coder" directive (or
	// its legacy alias "decoder"). During the resolver building phase, that
	// directive is removed from the resolver, and the coder it names (a
	// *NamedAnyStringableAdaptor) is put into Resolver.Context with this
	// key. e.g.
	//
	//    type GreetInput struct {
	//        Message string `in:"coder=custom"`
	//    }
	// For the above example, the coder registered as "custom" will be put
	// into the resolver of the Message field with this key.
	CtxCustomCoder

	// CtxFieldSet is used by executors to tell whether a field has been set. When
	// multiple executors were applied to a field, if the field value were set
	// by a former executor, the latter executors MAY skip running by consulting
	// this context value.
	CtxFieldSet
)

// DirectiveRuntime is the runtime of a directive execution. It wraps owl.DirectiveRuntime,
// providing some additional helper methods particular to httpin.
//
// See owl.DirectiveRuntime for more details.
47 | type DirectiveRuntime owl.DirectiveRuntime 48 | 49 | func (rtm *DirectiveRuntime) GetRequest() *http.Request { 50 | if req := rtm.Context.Value(CtxRequest); req != nil { 51 | return req.(*http.Request) 52 | } 53 | return nil 54 | } 55 | 56 | func (rtm *DirectiveRuntime) GetRequestBuilder() *RequestBuilder { 57 | if rb := rtm.Context.Value(CtxRequestBuilder); rb != nil { 58 | return rb.(*RequestBuilder) 59 | } 60 | return nil 61 | } 62 | 63 | func (rtm *DirectiveRuntime) GetCustomCoder() *NamedAnyStringableAdaptor { 64 | if info := rtm.Resolver.Context.Value(CtxCustomCoder); info != nil { 65 | return info.(*NamedAnyStringableAdaptor) 66 | } else { 67 | return nil 68 | } 69 | } 70 | 71 | func (rtm *DirectiveRuntime) IsFieldSet() bool { 72 | return rtm.Context.Value(CtxFieldSet) == true 73 | } 74 | 75 | func (rtm *DirectiveRuntime) MarkFieldSet(value bool) { 76 | rtm.Context = context.WithValue(rtm.Context, CtxFieldSet, value) 77 | } 78 | 79 | func (rtm *DirectiveRuntime) SetValue(value any) error { 80 | if value == nil { 81 | // NOTE: should we wipe the value here? i.e. set the value to nil if necessary. 82 | // No case found yet, at least for now. 
83 | return nil 84 | } 85 | newValue := reflect.ValueOf(value) 86 | targetType := rtm.Value.Type().Elem() 87 | 88 | if !newValue.Type().AssignableTo(targetType) { 89 | return fmt.Errorf("%w: value of type %q is not assignable to type %q", 90 | ErrTypeMismatch, reflect.TypeOf(value), targetType) 91 | } 92 | 93 | rtm.Value.Elem().Set(newValue) 94 | return nil 95 | } 96 | -------------------------------------------------------------------------------- /core/error.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/ggicci/httpin/internal" 9 | "github.com/ggicci/owl" 10 | ) 11 | 12 | var ( 13 | ErrUnregisteredDirective = errors.New("unregistered directive") 14 | ErrUnregisteredCoder = errors.New("unregistered coder") 15 | ErrTypeMismatch = internal.ErrTypeMismatch 16 | ) 17 | 18 | type InvalidFieldError struct { 19 | // err is the underlying error thrown by the directive executor. 20 | err error 21 | 22 | // Field is the name of the field. 23 | Field string `json:"field"` 24 | 25 | // Source is the directive which causes the error. 26 | // e.g. form, header, required, etc. 27 | Directive string `json:"directive"` 28 | 29 | // Key is the key to get the input data from the source. 30 | Key string `json:"key"` 31 | 32 | // Value is the input data. 33 | Value any `json:"value"` 34 | 35 | // ErrorMessage is the string representation of `internalError`. 
36 | ErrorMessage string `json:"error"` 37 | } 38 | 39 | func (e *InvalidFieldError) Error() string { 40 | return fmt.Sprintf("invalid field %q: %v", e.Field, e.err) 41 | } 42 | 43 | func (e *InvalidFieldError) Unwrap() error { 44 | return e.err 45 | } 46 | 47 | func NewInvalidFieldError(err error) *InvalidFieldError { 48 | var ( 49 | r *owl.Resolver 50 | de *owl.DirectiveExecutionError 51 | ) 52 | 53 | switch err := err.(type) { 54 | case *InvalidFieldError: 55 | return err 56 | case *owl.ResolveError: 57 | r = err.Resolver 58 | de = err.AsDirectiveExecutionError() 59 | case *owl.ScanError: 60 | r = err.Resolver 61 | de = err.AsDirectiveExecutionError() 62 | default: 63 | return &InvalidFieldError{ 64 | err: err, 65 | ErrorMessage: err.Error(), 66 | } 67 | } 68 | 69 | var fe *fieldError 70 | var inputKey string 71 | var inputValue any 72 | errors.As(err, &fe) 73 | if fe != nil { 74 | inputValue = fe.Value 75 | inputKey = fe.Key 76 | } 77 | 78 | return &InvalidFieldError{ 79 | err: err, 80 | Field: r.Field.Name, 81 | Directive: de.Name, // e.g. form, header, required, etc. 
82 | Key: inputKey, 83 | Value: inputValue, 84 | ErrorMessage: err.Error(), 85 | } 86 | } 87 | 88 | type MultiInvalidFieldError []*InvalidFieldError 89 | 90 | func (me MultiInvalidFieldError) Error() string { 91 | if len(me) == 1 { 92 | return me[0].Error() 93 | } 94 | var sb strings.Builder 95 | sb.WriteString(fmt.Sprintf("%d invalid fields: ", len(me))) 96 | for i, e := range me { 97 | if i > 0 { 98 | sb.WriteString("; ") 99 | } 100 | sb.WriteString(e.Error()) 101 | } 102 | return sb.String() 103 | } 104 | 105 | func (me MultiInvalidFieldError) Unwrap() []error { 106 | var errs []error 107 | for _, e := range me { 108 | errs = append(errs, e) 109 | } 110 | return errs 111 | } 112 | 113 | type fieldError struct { 114 | Key string 115 | Value any 116 | internalError error 117 | } 118 | 119 | func (e fieldError) Error() string { 120 | return e.internalError.Error() 121 | } 122 | 123 | func (e fieldError) Unwrap() error { 124 | return e.internalError 125 | } 126 | -------------------------------------------------------------------------------- /core/errorhandler.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "net/http" 7 | 8 | "github.com/ggicci/httpin/internal" 9 | ) 10 | 11 | var globalCustomErrorHandler ErrorHandler = defaultErrorHandler 12 | 13 | // RegisterErrorHandler replaces the default error handler with the given 14 | // custom error handler. The default error handler will be used in the http.Handler 15 | // that decoreated by the middleware created by NewInput(). 
16 | func RegisterErrorHandler(handler ErrorHandler) { 17 | internal.PanicOnError(validateErrorHandler(handler)) 18 | globalCustomErrorHandler = handler 19 | } 20 | 21 | func defaultErrorHandler(rw http.ResponseWriter, r *http.Request, err error) { 22 | var invalidFieldError *InvalidFieldError 23 | if errors.As(err, &invalidFieldError) { 24 | rw.Header().Add("Content-Type", "application/json") 25 | rw.WriteHeader(http.StatusUnprocessableEntity) // status: 422 26 | json.NewEncoder(rw).Encode(invalidFieldError) 27 | return 28 | } 29 | 30 | http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) // status: 500 31 | } 32 | 33 | // ErrorHandler is the type of custom error handler. The error handler is used 34 | // by the http.Handler that created by NewInput() to handle errors during 35 | // decoding the HTTP request. 36 | type ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) 37 | 38 | func validateErrorHandler(handler ErrorHandler) error { 39 | if handler == nil { 40 | return errors.New("nil error handler") 41 | } 42 | return nil 43 | } 44 | -------------------------------------------------------------------------------- /core/errorhandler_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func myCustomErrorHandler(rw http.ResponseWriter, r *http.Request, err error) { 14 | var invalidFieldError *InvalidFieldError 15 | if errors.As(err, &invalidFieldError) { 16 | rw.WriteHeader(http.StatusBadRequest) // status: 400 17 | io.WriteString(rw, invalidFieldError.Error()) 18 | return 19 | } 20 | http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) // status: 500 21 | } 22 | 23 | func TestRegisterErrorHandler(t *testing.T) { 24 | // Nil handler should panic. 
25 | assert.PanicsWithError(t, "httpin: nil error handler", func() { 26 | RegisterErrorHandler(nil) 27 | }) 28 | 29 | RegisterErrorHandler(myCustomErrorHandler) 30 | assert.True(t, equalFuncs(globalCustomErrorHandler, myCustomErrorHandler)) 31 | } 32 | 33 | func TestDefaultErrorHandler(t *testing.T) { 34 | r, _ := http.NewRequest("GET", "/", nil) 35 | rw := httptest.NewRecorder() 36 | 37 | // When met InvalidFieldError, it should return 422. 38 | defaultErrorHandler(rw, r, &InvalidFieldError{err: assert.AnError, ErrorMessage: assert.AnError.Error()}) 39 | assert.Equal(t, 422, rw.Code) 40 | 41 | // When met other errors, it should return 500. 42 | rw = httptest.NewRecorder() 43 | defaultErrorHandler(rw, r, assert.AnError) 44 | assert.Equal(t, 500, rw.Code) 45 | } 46 | -------------------------------------------------------------------------------- /core/file.go: -------------------------------------------------------------------------------- 1 | // https://ggicci.github.io/httpin/advanced/upload-files 2 | 3 | package core 4 | 5 | import ( 6 | "errors" 7 | "io" 8 | "mime/multipart" 9 | "os" 10 | ) 11 | 12 | func init() { 13 | RegisterFileCoder[*File]() 14 | } 15 | 16 | // File is the builtin type of httpin to manupulate file uploads. On the server 17 | // side, it is used to represent a file in a multipart/form-data request. On the 18 | // client side, it is used to represent a file to be uploaded. 19 | type File struct { 20 | FileHeader 21 | uploadFilename string 22 | uploadReader io.ReadCloser 23 | } 24 | 25 | // UploadFile is a helper function to create a File instance from a file path. 26 | // It is useful when you want to upload a file from the local file system. 27 | func UploadFile(filename string) *File { 28 | return &File{uploadFilename: filename} 29 | } 30 | 31 | // UploadStream is a helper function to create a File instance from a io.Reader. It 32 | // is useful when you want to upload a file from a stream. 
33 | func UploadStream(contentReader io.ReadCloser) *File { 34 | return &File{uploadReader: contentReader} 35 | } 36 | 37 | // Filename returns the filename of the file. On the server side, it returns the 38 | // filename of the file in the multipart/form-data request. On the client side, it 39 | // returns the filename of the file to be uploaded. 40 | func (f *File) Filename() string { 41 | if f.IsUpload() { 42 | return f.uploadFilename 43 | } 44 | if f.FileHeader != nil { 45 | return f.FileHeader.Filename() 46 | } 47 | return "" 48 | } 49 | 50 | // MarshalFile implements FileMarshaler. 51 | func (f *File) MarshalFile() (io.ReadCloser, error) { 52 | if f.IsUpload() { 53 | return f.OpenUploadStream() 54 | } else { 55 | return f.OpenReceiveStream() 56 | } 57 | } 58 | 59 | func (f *File) UnmarshalFile(fh FileHeader) error { 60 | f.FileHeader = fh 61 | return nil 62 | } 63 | 64 | // IsUpload returns true when the File instance is created for an upload purpose. 65 | // Typically, you should use UploadFilename or UploadReader to create a File instance 66 | // for upload. 67 | func (f *File) IsUpload() bool { 68 | return f.uploadFilename != "" || f.uploadReader != nil 69 | } 70 | 71 | // OpenUploadStream returns a io.ReadCloser for the file to be uploaded. 72 | // Call this method on the client side for uploading a file. 73 | func (f *File) OpenUploadStream() (io.ReadCloser, error) { 74 | if f.uploadReader != nil { 75 | return f.uploadReader, nil 76 | } 77 | if f.uploadFilename != "" { 78 | return os.Open(f.uploadFilename) 79 | } 80 | return nil, errors.New("invalid upload (client): no filename or reader") 81 | } 82 | 83 | // OpenReceiveStream returns a io.Reader for the file in the multipart/form-data request. 84 | // Call this method on the server side to read the file content. 
85 | func (f *File) OpenReceiveStream() (multipart.File, error) { 86 | if f.FileHeader == nil { 87 | return nil, errors.New("invalid upload (server): nil file header") 88 | } 89 | return f.FileHeader.Open() 90 | } 91 | 92 | func (f *File) ReadAll() ([]byte, error) { 93 | reader, err := f.MarshalFile() 94 | if err != nil { 95 | return nil, err 96 | } 97 | return io.ReadAll(reader) 98 | } 99 | -------------------------------------------------------------------------------- /core/fileable.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "mime/multipart" 7 | "net/textproto" 8 | "reflect" 9 | 10 | "github.com/ggicci/httpin/internal" 11 | ) 12 | 13 | // FileHeader is the interface that groups the methods of multipart.FileHeader. 14 | type FileHeader interface { 15 | Filename() string 16 | Size() int64 17 | MIMEHeader() textproto.MIMEHeader 18 | Open() (multipart.File, error) 19 | } 20 | 21 | type FileMarshaler interface { 22 | Filename() string 23 | MarshalFile() (io.ReadCloser, error) 24 | } 25 | 26 | type FileUnmarshaler interface { 27 | UnmarshalFile(FileHeader) error 28 | } 29 | 30 | type Fileable interface { 31 | FileMarshaler 32 | FileUnmarshaler 33 | } 34 | 35 | func NewFileable(rv reflect.Value) (Fileable, error) { 36 | if IsPatchField(rv.Type()) { 37 | return NewFileablePatchFieldWrapper(rv) 38 | } 39 | 40 | return newFileable(rv) 41 | } 42 | 43 | func newFileable(rv reflect.Value) (Fileable, error) { 44 | rv, err := getPointer(rv) 45 | if err != nil { 46 | return nil, err 47 | } 48 | 49 | if rv.Type().Implements(fileableType) && rv.CanInterface() { 50 | return rv.Interface().(Fileable), nil 51 | } 52 | return nil, errors.New("unsupported file type") 53 | } 54 | 55 | type FileablePatchFieldWrapper struct { 56 | Value reflect.Value // of patch.Field[T] 57 | internalFileable Fileable 58 | } 59 | 60 | func NewFileablePatchFieldWrapper(rv reflect.Value) 
(*FileablePatchFieldWrapper, error) { 61 | fileable, err := NewFileable(rv.FieldByName("Value")) 62 | if err != nil { 63 | return nil, err 64 | } else { 65 | return &FileablePatchFieldWrapper{ 66 | Value: rv, 67 | internalFileable: fileable, 68 | }, nil 69 | } 70 | } 71 | 72 | func (w *FileablePatchFieldWrapper) Filename() string { 73 | return w.internalFileable.Filename() 74 | } 75 | 76 | func (w *FileablePatchFieldWrapper) MarshalFile() (io.ReadCloser, error) { 77 | return w.internalFileable.MarshalFile() 78 | } 79 | 80 | func (w *FileablePatchFieldWrapper) UnmarshalFile(fh FileHeader) error { 81 | if err := w.internalFileable.UnmarshalFile(fh); err != nil { 82 | return err 83 | } else { 84 | w.Value.FieldByName("Valid").SetBool(true) 85 | return nil 86 | } 87 | } 88 | 89 | var fileableType = internal.TypeOf[Fileable]() 90 | 91 | // multipartFileHeader is the adaptor of multipart.FileHeader. 92 | type multipartFileHeader struct{ *multipart.FileHeader } 93 | 94 | func (h *multipartFileHeader) Filename() string { 95 | return h.FileHeader.Filename 96 | } 97 | 98 | func (h *multipartFileHeader) Size() int64 { 99 | return h.FileHeader.Size 100 | } 101 | 102 | func (h *multipartFileHeader) MIMEHeader() textproto.MIMEHeader { 103 | return h.FileHeader.Header 104 | } 105 | 106 | func (h *multipartFileHeader) Open() (multipart.File, error) { 107 | return h.FileHeader.Open() 108 | } 109 | 110 | func toFileHeaderList(fhs []*multipart.FileHeader) []FileHeader { 111 | result := make([]FileHeader, len(fhs)) 112 | for i, fh := range fhs { 113 | result[i] = &multipartFileHeader{fh} 114 | } 115 | return result 116 | } 117 | -------------------------------------------------------------------------------- /core/fileable_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "io" 5 | "mime/multipart" 6 | "net/textproto" 7 | "os" 8 | "reflect" 9 | "testing" 10 | 11 | "github.com/ggicci/httpin/patch" 12 | 
"github.com/stretchr/testify/assert" 13 | ) 14 | 15 | type MyFiles struct { 16 | Avatar File 17 | AvatarPointer *File 18 | Avatars []File 19 | AvatarPointers []*File 20 | 21 | PatchAvatar patch.Field[File] 22 | PatchAvatarPointer patch.Field[*File] 23 | PatchAvatars patch.Field[[]File] 24 | PatchAvatarPointers patch.Field[[]*File] 25 | } 26 | 27 | func TestFileable_UnmarshalFile(t *testing.T) { 28 | rv := reflect.New(reflect.TypeOf(MyFiles{})).Elem() 29 | s := rv.Addr().Interface().(*MyFiles) 30 | 31 | fileAvatar := testAssignFile(t, rv.FieldByName("Avatar")) 32 | fileAvatarPointer := testAssignFile(t, rv.FieldByName("AvatarPointer")) 33 | testNewFileableErrUnsupported(t, rv.FieldByName("Avatars")) 34 | testNewFileableErrUnsupported(t, rv.FieldByName("AvatarPointers")) 35 | 36 | filePatchAvatar := testAssignFile(t, rv.FieldByName("PatchAvatar")) 37 | filePatchAvatarPointer := testAssignFile(t, rv.FieldByName("PatchAvatarPointer")) 38 | testNewFileableErrUnsupported(t, rv.FieldByName("PatchAvatars")) 39 | testNewFileableErrUnsupported(t, rv.FieldByName("PatchAvatarPointers")) 40 | 41 | validateFile(t, fileAvatar, &s.Avatar) 42 | validateFile(t, fileAvatarPointer, s.AvatarPointer) 43 | 44 | assert.True(t, s.PatchAvatar.Valid) 45 | validateFile(t, filePatchAvatar, &s.PatchAvatar.Value) 46 | assert.True(t, s.PatchAvatarPointer.Valid) 47 | validateFile(t, filePatchAvatarPointer, s.PatchAvatarPointer.Value) 48 | } 49 | 50 | func TestFileable_MarshalFile(t *testing.T) { 51 | fileAvatar := createTempFileV2(t) 52 | fileAvatarPointer := createTempFileV2(t) 53 | filePatchAvatar := createTempFileV2(t) 54 | filePatchAvatarPointer := createTempFileV2(t) 55 | 56 | var s = &MyFiles{ 57 | Avatar: *UploadFile(fileAvatar.Filename), 58 | AvatarPointer: UploadFile(fileAvatarPointer.Filename), 59 | Avatars: []File{*UploadFile(fileAvatar.Filename)}, 60 | AvatarPointers: []*File{UploadFile(fileAvatarPointer.Filename)}, 61 | 62 | PatchAvatar: patch.Field[File]{Value: 
*UploadFile(filePatchAvatar.Filename), Valid: true}, 63 | PatchAvatarPointer: patch.Field[*File]{Value: UploadFile(filePatchAvatarPointer.Filename), Valid: true}, 64 | PatchAvatars: patch.Field[[]File]{Value: []File{*UploadFile(filePatchAvatar.Filename)}, Valid: true}, 65 | PatchAvatarPointers: patch.Field[[]*File]{Value: []*File{UploadFile(filePatchAvatarPointer.Filename)}, Valid: true}, 66 | } 67 | 68 | rv := reflect.ValueOf(s).Elem() 69 | 70 | validateRvFile(t, fileAvatar, rv.FieldByName("Avatar")) 71 | validateRvFile(t, fileAvatarPointer, rv.FieldByName("AvatarPointer")) 72 | testNewFileableErrUnsupported(t, rv.FieldByName("Avatars")) 73 | testNewFileableErrUnsupported(t, rv.FieldByName("AvatarPointers")) 74 | 75 | validateRvFile(t, filePatchAvatar, rv.FieldByName("PatchAvatar")) 76 | validateRvFile(t, filePatchAvatarPointer, rv.FieldByName("PatchAvatarPointer")) 77 | testNewFileableErrUnsupported(t, rv.FieldByName("PatchAvatars")) 78 | testNewFileableErrUnsupported(t, rv.FieldByName("PatchAvatarPointers")) 79 | } 80 | 81 | func testNewFileableErrUnsupported(t *testing.T, rv reflect.Value) { 82 | fileable, err := NewFileable(rv) 83 | assert.ErrorContains(t, err, "unsupported file type") 84 | assert.Nil(t, fileable) 85 | } 86 | 87 | func validateFile(t *testing.T, expected *tempFile, actual FileMarshaler) { 88 | assert.Equal(t, expected.Filename, actual.Filename()) 89 | reader, err := actual.MarshalFile() 90 | assert.NoError(t, err) 91 | content, err := io.ReadAll(reader) 92 | assert.NoError(t, err) 93 | assert.Equal(t, expected.Content, content) 94 | } 95 | 96 | func validateRvFile(t *testing.T, expected *tempFile, actual reflect.Value) { 97 | file, err := NewFileable(actual) 98 | assert.NoError(t, err) 99 | reader, err := file.MarshalFile() 100 | assert.NoError(t, err) 101 | content, err := io.ReadAll(reader) 102 | assert.NoError(t, err) 103 | 104 | assert.Equal(t, expected.Filename, file.Filename()) 105 | assert.Equal(t, expected.Content, content) 106 | } 107 
| 108 | func testAssignFile(t *testing.T, rv reflect.Value) *tempFile { 109 | fileable, err := NewFileable(rv) 110 | assert.NoError(t, err) 111 | file := createTempFileV2(t) 112 | assert.NoError(t, fileable.UnmarshalFile(mockFileHeader(t, file.Filename))) 113 | return file 114 | } 115 | 116 | type dummyFileHeader struct { 117 | file *os.File 118 | } 119 | 120 | func mockFileHeader(t *testing.T, filename string) FileHeader { 121 | file, err := os.Open(filename) 122 | if err != nil { 123 | panic(err) 124 | } 125 | return &dummyFileHeader{ 126 | file: file, 127 | } 128 | } 129 | 130 | func (f *dummyFileHeader) Filename() string { 131 | return f.file.Name() 132 | } 133 | 134 | func (f *dummyFileHeader) Size() int64 { 135 | stat, err := f.file.Stat() 136 | if err != nil { 137 | panic(err) 138 | } 139 | return stat.Size() 140 | } 141 | 142 | func (f *dummyFileHeader) MIMEHeader() textproto.MIMEHeader { 143 | return textproto.MIMEHeader{} 144 | } 145 | 146 | func (f *dummyFileHeader) Open() (multipart.File, error) { 147 | return f.file, nil 148 | } 149 | -------------------------------------------------------------------------------- /core/fileslicable.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "reflect" 7 | ) 8 | 9 | type FileSlicable interface { 10 | ToFileSlice() ([]FileMarshaler, error) 11 | FromFileSlice([]FileHeader) error 12 | } 13 | 14 | func NewFileSlicable(rv reflect.Value) (FileSlicable, error) { 15 | if IsPatchField(rv.Type()) { 16 | return NewFileSlicablePatchFieldWrapper(rv) 17 | } 18 | 19 | if isSliceType(rv.Type()) { 20 | return NewFileableSliceWrapper(rv) 21 | } else { 22 | return NewFileSlicableSingleFileableWrapper(rv) 23 | } 24 | } 25 | 26 | type FileSlicablePatchFieldWrapper struct { 27 | Value reflect.Value // of patch.Field[T] 28 | internalFileSliceable FileSlicable 29 | } 30 | 31 | func NewFileSlicablePatchFieldWrapper(rv reflect.Value) 
(*FileSlicablePatchFieldWrapper, error) { 32 | fileSlicable, err := NewFileSlicable(rv.FieldByName("Value")) 33 | if err != nil { 34 | return nil, err 35 | } else { 36 | return &FileSlicablePatchFieldWrapper{ 37 | Value: rv, 38 | internalFileSliceable: fileSlicable, 39 | }, nil 40 | } 41 | } 42 | 43 | func (w *FileSlicablePatchFieldWrapper) ToFileSlice() ([]FileMarshaler, error) { 44 | if w.Value.FieldByName("Valid").Bool() { 45 | return w.internalFileSliceable.ToFileSlice() 46 | } else { 47 | return []FileMarshaler{}, nil 48 | } 49 | } 50 | 51 | func (w *FileSlicablePatchFieldWrapper) FromFileSlice(fhs []FileHeader) error { 52 | if err := w.internalFileSliceable.FromFileSlice(fhs); err != nil { 53 | return err 54 | } else { 55 | w.Value.FieldByName("Valid").SetBool(true) 56 | return nil 57 | } 58 | } 59 | 60 | type FileableSliceWrapper struct { 61 | Value reflect.Value 62 | } 63 | 64 | func NewFileableSliceWrapper(rv reflect.Value) (*FileableSliceWrapper, error) { 65 | if !rv.CanAddr() { 66 | return nil, errors.New("unaddressable value") 67 | } 68 | return &FileableSliceWrapper{Value: rv}, nil 69 | } 70 | 71 | func (w *FileableSliceWrapper) ToFileSlice() ([]FileMarshaler, error) { 72 | var files = make([]FileMarshaler, w.Value.Len()) 73 | for i := 0; i < w.Value.Len(); i++ { 74 | if fileable, err := NewFileable(w.Value.Index(i)); err != nil { 75 | return nil, fmt.Errorf("cannot create Fileable at index %d: %w", i, err) 76 | } else { 77 | files[i] = fileable 78 | } 79 | } 80 | return files, nil 81 | } 82 | 83 | func (w *FileableSliceWrapper) FromFileSlice(fhs []FileHeader) error { 84 | w.Value.Set(reflect.MakeSlice(w.Value.Type(), len(fhs), len(fhs))) 85 | for i, fh := range fhs { 86 | fileable, err := NewFileable(w.Value.Index(i)) 87 | if err != nil { 88 | return fmt.Errorf("cannot create Fileable at index %d: %w", i, err) 89 | } 90 | if err := fileable.UnmarshalFile(fh); err != nil { 91 | return fmt.Errorf("cannot unmarshal file %q at index %d: %w", 
fh.Filename(), i, err) 92 | } 93 | } 94 | return nil 95 | } 96 | 97 | type FileSlicableSingleFileableWrapper struct{ Fileable } 98 | 99 | func NewFileSlicableSingleFileableWrapper(rv reflect.Value) (*FileSlicableSingleFileableWrapper, error) { 100 | if fileable, err := NewFileable(rv); err != nil { 101 | return nil, err 102 | } else { 103 | return &FileSlicableSingleFileableWrapper{fileable}, nil 104 | } 105 | } 106 | 107 | func (w *FileSlicableSingleFileableWrapper) ToFileSlice() ([]FileMarshaler, error) { 108 | return []FileMarshaler{w.Fileable}, nil 109 | } 110 | 111 | func (w *FileSlicableSingleFileableWrapper) FromFileSlice(files []FileHeader) error { 112 | if len(files) > 0 { 113 | return w.UnmarshalFile(files[0]) 114 | } 115 | return nil 116 | } 117 | -------------------------------------------------------------------------------- /core/fileslicable_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/ggicci/httpin/patch" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestFileSlicable_FromFileSlice(t *testing.T) { 12 | rv := reflect.New(reflect.TypeOf(MyFiles{})).Elem() 13 | s := rv.Addr().Interface().(*MyFiles) 14 | 15 | fileAvatar := createTempFileV2(t) 16 | fileAvatarPointer := createTempFileV2(t) 17 | filePatchAvatar := createTempFileV2(t) 18 | filePatchAvatarPointer := createTempFileV2(t) 19 | 20 | testAssignFileSlice(t, rv.FieldByName("Avatar"), []FileHeader{ 21 | mockFileHeader(t, fileAvatar.Filename), 22 | }) 23 | testAssignFileSlice(t, rv.FieldByName("AvatarPointer"), []FileHeader{ 24 | mockFileHeader(t, fileAvatarPointer.Filename), 25 | }) 26 | testAssignFileSlice(t, rv.FieldByName("Avatars"), []FileHeader{ 27 | mockFileHeader(t, fileAvatar.Filename), 28 | mockFileHeader(t, fileAvatarPointer.Filename), 29 | }) 30 | testAssignFileSlice(t, rv.FieldByName("AvatarPointers"), []FileHeader{ 31 | mockFileHeader(t, 
fileAvatarPointer.Filename), 32 | mockFileHeader(t, fileAvatar.Filename), 33 | }) 34 | 35 | testAssignFileSlice(t, rv.FieldByName("PatchAvatar"), []FileHeader{ 36 | mockFileHeader(t, filePatchAvatar.Filename), 37 | }) 38 | testAssignFileSlice(t, rv.FieldByName("PatchAvatarPointer"), []FileHeader{ 39 | mockFileHeader(t, filePatchAvatarPointer.Filename), 40 | }) 41 | testAssignFileSlice(t, rv.FieldByName("PatchAvatars"), []FileHeader{ 42 | mockFileHeader(t, fileAvatar.Filename), 43 | mockFileHeader(t, filePatchAvatar.Filename), 44 | mockFileHeader(t, filePatchAvatarPointer.Filename), 45 | }) 46 | testAssignFileSlice(t, rv.FieldByName("PatchAvatarPointers"), []FileHeader{ 47 | mockFileHeader(t, fileAvatar.Filename), 48 | mockFileHeader(t, fileAvatarPointer.Filename), 49 | mockFileHeader(t, filePatchAvatar.Filename), 50 | mockFileHeader(t, filePatchAvatarPointer.Filename), 51 | }) 52 | 53 | validateFile(t, fileAvatar, &s.Avatar) 54 | validateFile(t, fileAvatarPointer, s.AvatarPointer) 55 | 56 | assert.Len(t, s.Avatars, 2) 57 | validateFile(t, fileAvatar, &s.Avatars[0]) 58 | validateFile(t, fileAvatarPointer, &s.Avatars[1]) 59 | 60 | assert.Len(t, s.AvatarPointers, 2) 61 | validateFile(t, fileAvatarPointer, s.AvatarPointers[0]) 62 | validateFile(t, fileAvatar, s.AvatarPointers[1]) 63 | 64 | assert.True(t, s.PatchAvatar.Valid) 65 | validateFile(t, filePatchAvatar, &s.PatchAvatar.Value) 66 | 67 | assert.True(t, s.PatchAvatarPointer.Valid) 68 | validateFile(t, filePatchAvatarPointer, s.PatchAvatarPointer.Value) 69 | 70 | assert.True(t, s.PatchAvatars.Valid) 71 | assert.Len(t, s.PatchAvatars.Value, 3) 72 | validateFile(t, fileAvatar, &s.PatchAvatars.Value[0]) 73 | validateFile(t, filePatchAvatar, &s.PatchAvatars.Value[1]) 74 | validateFile(t, filePatchAvatarPointer, &s.PatchAvatars.Value[2]) 75 | 76 | assert.True(t, s.PatchAvatarPointers.Valid) 77 | assert.Len(t, s.PatchAvatarPointers.Value, 4) 78 | validateFile(t, fileAvatar, s.PatchAvatarPointers.Value[0]) 79 | 
validateFile(t, fileAvatarPointer, s.PatchAvatarPointers.Value[1]) 80 | validateFile(t, filePatchAvatar, s.PatchAvatarPointers.Value[2]) 81 | validateFile(t, filePatchAvatarPointer, s.PatchAvatarPointers.Value[3]) 82 | } 83 | 84 | func TestFileSlicable_ToFileSlice(t *testing.T) { 85 | fileAvatar := createTempFileV2(t) 86 | fileAvatarPointer := createTempFileV2(t) 87 | filePatchAvatar := createTempFileV2(t) 88 | filePatchAvatarPointer := createTempFileV2(t) 89 | 90 | var s = &MyFiles{ 91 | Avatar: *UploadFile(fileAvatar.Filename), 92 | AvatarPointer: UploadFile(fileAvatarPointer.Filename), 93 | Avatars: []File{*UploadFile(fileAvatar.Filename), *UploadFile(fileAvatarPointer.Filename)}, 94 | AvatarPointers: []*File{UploadFile(fileAvatarPointer.Filename), UploadFile(fileAvatar.Filename)}, 95 | PatchAvatar: patch.Field[File]{Value: *UploadFile(filePatchAvatar.Filename), Valid: true}, 96 | PatchAvatarPointer: patch.Field[*File]{ 97 | Value: UploadFile(filePatchAvatarPointer.Filename), 98 | Valid: true, 99 | }, 100 | PatchAvatars: patch.Field[[]File]{ 101 | Value: []File{ 102 | *UploadFile(fileAvatar.Filename), 103 | *UploadFile(filePatchAvatar.Filename), 104 | *UploadFile(filePatchAvatarPointer.Filename), 105 | }, 106 | Valid: true, 107 | }, 108 | PatchAvatarPointers: patch.Field[[]*File]{ 109 | Value: []*File{ 110 | UploadFile(fileAvatar.Filename), 111 | UploadFile(fileAvatarPointer.Filename), 112 | UploadFile(filePatchAvatar.Filename), 113 | UploadFile(filePatchAvatarPointer.Filename), 114 | }, 115 | Valid: true, 116 | }, 117 | } 118 | 119 | rv := reflect.ValueOf(s).Elem() 120 | validateFileList(t, []*tempFile{fileAvatar}, testGetFileSlice(t, rv.FieldByName("Avatar"))) 121 | validateFileList(t, []*tempFile{fileAvatarPointer}, testGetFileSlice(t, rv.FieldByName("AvatarPointer"))) 122 | validateFileList(t, []*tempFile{fileAvatar, fileAvatarPointer}, testGetFileSlice(t, rv.FieldByName("Avatars"))) 123 | validateFileList(t, []*tempFile{fileAvatarPointer, fileAvatar}, 
testGetFileSlice(t, rv.FieldByName("AvatarPointers"))) 124 | validateFileList(t, []*tempFile{filePatchAvatar}, testGetFileSlice(t, rv.FieldByName("PatchAvatar"))) 125 | validateFileList(t, []*tempFile{filePatchAvatarPointer}, testGetFileSlice(t, rv.FieldByName("PatchAvatarPointer"))) 126 | validateFileList(t, []*tempFile{fileAvatar, filePatchAvatar, filePatchAvatarPointer}, testGetFileSlice(t, rv.FieldByName("PatchAvatars"))) 127 | validateFileList(t, []*tempFile{fileAvatar, fileAvatarPointer, filePatchAvatar, filePatchAvatarPointer}, testGetFileSlice(t, rv.FieldByName("PatchAvatarPointers"))) 128 | } 129 | 130 | func testAssignFileSlice(t *testing.T, rv reflect.Value, files []FileHeader) { 131 | fs, err := NewFileSlicable(rv) 132 | assert.NoError(t, err) 133 | assert.NoError(t, fs.FromFileSlice(files)) 134 | } 135 | 136 | func testGetFileSlice(t *testing.T, rv reflect.Value) []FileMarshaler { 137 | fs, err := NewFileSlicable(rv) 138 | assert.NoError(t, err) 139 | files, err := fs.ToFileSlice() 140 | assert.NoError(t, err) 141 | return files 142 | } 143 | 144 | func validateFileList(t *testing.T, expected []*tempFile, actual []FileMarshaler) { 145 | assert.Len(t, actual, len(expected)) 146 | for i, file := range expected { 147 | validateFile(t, file, actual[i]) 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /core/form.go: -------------------------------------------------------------------------------- 1 | // directive: "form" 2 | // https://ggicci.github.io/httpin/directives/form 3 | 4 | package core 5 | 6 | import ( 7 | "mime/multipart" 8 | ) 9 | 10 | type DirectvieForm struct{} 11 | 12 | // Decode implements the "form" executor who extracts values from 13 | // the forms of an HTTP request. 
14 | func (*DirectvieForm) Decode(rtm *DirectiveRuntime) error { 15 | req := rtm.GetRequest() 16 | var form multipart.Form 17 | if req.MultipartForm != nil { 18 | form = *req.MultipartForm 19 | } else { 20 | if req.Form != nil { 21 | form.Value = req.Form 22 | } 23 | } 24 | extractor := &FormExtractor{ 25 | Runtime: rtm, 26 | Form: form, 27 | KeyNormalizer: nil, 28 | } 29 | return extractor.Extract() 30 | } 31 | 32 | // Encode implements the encoder/request builder for "form" directive. 33 | // It builds the form values of an HTTP request, including: 34 | // - form data 35 | // - multipart form data (file upload) 36 | func (*DirectvieForm) Encode(rtm *DirectiveRuntime) error { 37 | encoder := &FormEncoder{ 38 | Setter: rtm.GetRequestBuilder().SetForm, 39 | } 40 | return encoder.Execute(rtm) 41 | } 42 | -------------------------------------------------------------------------------- /core/form_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "encoding/base64" 5 | "io" 6 | "net/http" 7 | "net/url" 8 | "strings" 9 | "testing" 10 | "time" 11 | 12 | "github.com/ggicci/httpin/internal" 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | // ChaosQuery is designed to make the normal case test coverage higher. 
17 | type ChaosQuery struct { 18 | // Basic Types 19 | BoolValue bool `in:"form=bool"` 20 | IntValue int `in:"form=int"` 21 | Int8Value int8 `in:"form=int8"` 22 | Int16Value int16 `in:"form=int16"` 23 | Int32Value int32 `in:"form=int32"` 24 | Int64Value int64 `in:"form=int64"` 25 | UintValue uint `in:"form=uint"` 26 | Uint8Value uint8 `in:"form=uint8"` 27 | Uint16Value uint16 `in:"form=uint16"` 28 | Uint32Value uint32 `in:"form=uint32"` 29 | Uint64Value uint64 `in:"form=uint64"` 30 | Float32Value float32 `in:"form=float32"` 31 | Float64Value float64 `in:"form=float64"` 32 | Complex64Value complex64 `in:"form=complex64"` 33 | Complex128Value complex128 `in:"form=complex128"` 34 | StringValue string `in:"form=string"` 35 | TimeValue time.Time `in:"form=time"` // time type is special 36 | 37 | // Pointer Types 38 | BoolPointer *bool `in:"form=bool_pointer"` 39 | IntPointer *int `in:"form=int_pointer"` 40 | Int8Pointer *int8 `in:"form=int8_pointer"` 41 | Int16Pointer *int16 `in:"form=int16_pointer"` 42 | Int32Pointer *int32 `in:"form=int32_pointer"` 43 | Int64Pointer *int64 `in:"form=int64_pointer"` 44 | UintPointer *uint `in:"form=uint_pointer"` 45 | Uint8Pointer *uint8 `in:"form=uint8_pointer"` 46 | Uint16Pointer *uint16 `in:"form=uint16_pointer"` 47 | Uint32Pointer *uint32 `in:"form=uint32_pointer"` 48 | Uint64Pointer *uint64 `in:"form=uint64_pointer"` 49 | Float32Pointer *float32 `in:"form=float32_pointer"` 50 | Float64Pointer *float64 `in:"form=float64_pointer"` 51 | Complex64Pointer *complex64 `in:"form=complex64_pointer"` 52 | Complex128Pointer *complex128 `in:"form=complex128_pointer"` 53 | StringPointer *string `in:"form=string_pointer"` 54 | TimePointer *time.Time `in:"form=time_pointer"` 55 | 56 | // Array 57 | BoolList []bool `in:"form=bools"` 58 | IntList []int `in:"form=ints"` 59 | FloatList []float64 `in:"form=floats"` 60 | StringList []string `in:"form=strings"` 61 | TimeList []time.Time `in:"form=times"` 62 | } 63 | 64 | var ( 65 | sampleChaosQuery = 
&ChaosQuery{ 66 | BoolValue: true, 67 | IntValue: 9, 68 | Int8Value: 14, 69 | Int16Value: 841, 70 | Int32Value: 193, 71 | Int64Value: 475, 72 | UintValue: 11, 73 | Uint8Value: 4, 74 | Uint16Value: 48, 75 | Uint32Value: 9583, 76 | Uint64Value: 183471, 77 | Float32Value: 3.14, 78 | Float64Value: 0.618, 79 | Complex64Value: 1 + 4i, 80 | Complex128Value: -6 + 17i, 81 | StringValue: "doggy", 82 | TimeValue: time.Date(1991, 11, 10, 0, 0, 0, 0, time.UTC), 83 | 84 | BoolPointer: internal.Pointerize[bool](true), 85 | IntPointer: internal.Pointerize[int](9), 86 | Int8Pointer: internal.Pointerize[int8](14), 87 | Int16Pointer: internal.Pointerize[int16](841), 88 | Int32Pointer: internal.Pointerize[int32](193), 89 | Int64Pointer: internal.Pointerize[int64](475), 90 | UintPointer: internal.Pointerize[uint](11), 91 | Uint8Pointer: internal.Pointerize[uint8](4), 92 | Uint16Pointer: internal.Pointerize[uint16](48), 93 | Uint32Pointer: internal.Pointerize[uint32](9583), 94 | Uint64Pointer: internal.Pointerize[uint64](183471), 95 | Float32Pointer: internal.Pointerize[float32](3.14), 96 | Float64Pointer: internal.Pointerize[float64](0.618), 97 | Complex64Pointer: internal.Pointerize[complex64](1 + 4i), 98 | Complex128Pointer: internal.Pointerize[complex128](-6 + 17i), 99 | StringPointer: internal.Pointerize[string]("doggy"), 100 | TimePointer: internal.Pointerize[time.Time](time.Date(1991, 11, 10, 0, 0, 0, 0, time.UTC)), 101 | 102 | BoolList: []bool{true, false, false, true}, 103 | IntList: []int{9, 9, 6}, 104 | FloatList: []float64{0.0, 0.5, 1.0}, 105 | StringList: []string{"Life", "is", "a", "Miracle"}, 106 | TimeList: []time.Time{ 107 | time.Date(2000, 1, 2, 22, 4, 5, 0, time.UTC), 108 | time.Date(1991, 6, 28, 6, 0, 0, 0, time.UTC), 109 | }, 110 | } 111 | ) 112 | 113 | func TestDirectiveForm_Decode(t *testing.T) { 114 | r, _ := http.NewRequest("GET", "/", nil) 115 | r.Form = url.Values{ 116 | "bool": {"true"}, 117 | "int": {"9"}, 118 | "int8": {"14"}, 119 | "int16": {"841"}, 120 | 
"int32": {"193"}, 121 | "int64": {"475"}, 122 | "uint": {"11"}, 123 | "uint8": {"4"}, 124 | "uint16": {"48"}, 125 | "uint32": {"9583"}, 126 | "uint64": {"183471"}, 127 | "float32": {"3.14"}, 128 | "float64": {"0.618"}, 129 | "complex64": {"1+4i"}, 130 | "complex128": {"-6+17i"}, 131 | "string": {"doggy"}, 132 | "time": {"1991-11-10T08:00:00+08:00"}, 133 | 134 | "bool_pointer": {"true"}, 135 | "int_pointer": {"9"}, 136 | "int8_pointer": {"14"}, 137 | "int16_pointer": {"841"}, 138 | "int32_pointer": {"193"}, 139 | "int64_pointer": {"475"}, 140 | "uint_pointer": {"11"}, 141 | "uint8_pointer": {"4"}, 142 | "uint16_pointer": {"48"}, 143 | "uint32_pointer": {"9583"}, 144 | "uint64_pointer": {"183471"}, 145 | "float32_pointer": {"3.14"}, 146 | "float64_pointer": {"0.618"}, 147 | "complex64_pointer": {"1+4i"}, 148 | "complex128_pointer": {"-6+17i"}, 149 | "string_pointer": {"doggy"}, 150 | "time_pointer": {"1991-11-10T08:00:00+08:00"}, 151 | 152 | "bools": {"true", "false", "0", "1"}, 153 | "ints": {"9", "9", "6"}, 154 | "floats": {"0", "0.5", "1"}, 155 | "strings": {"Life", "is", "a", "Miracle"}, 156 | "times": {"2000-01-02T15:04:05-07:00", "678088800"}, 157 | } 158 | expected := sampleChaosQuery 159 | co, err := New(ChaosQuery{}) 160 | assert.NoError(t, err) 161 | got, err := co.Decode(r) 162 | assert.NoError(t, err) 163 | assert.Equal(t, expected, got.(*ChaosQuery)) 164 | } 165 | 166 | func TestDirectiveForm_NewRequest(t *testing.T) { 167 | co, err := New(ChaosQuery{}) 168 | assert.NoError(t, err) 169 | req, err := co.NewRequest("POST", "/signup", sampleChaosQuery) 170 | assert.NoError(t, err) 171 | 172 | expected, _ := http.NewRequest("POST", "/signup", nil) 173 | expectedForm := url.Values{ 174 | "bool": {"true"}, 175 | "int": {"9"}, 176 | "int8": {"14"}, 177 | "int16": {"841"}, 178 | "int32": {"193"}, 179 | "int64": {"475"}, 180 | "uint": {"11"}, 181 | "uint8": {"4"}, 182 | "uint16": {"48"}, 183 | "uint32": {"9583"}, 184 | "uint64": {"183471"}, 185 | "float32": 
{"3.14"}, 186 | "float64": {"0.618"}, 187 | "complex64": {"(1+4i)"}, 188 | "complex128": {"(-6+17i)"}, 189 | "string": {"doggy"}, 190 | "time": {"1991-11-10T00:00:00Z"}, 191 | 192 | "bool_pointer": {"true"}, 193 | "int_pointer": {"9"}, 194 | "int8_pointer": {"14"}, 195 | "int16_pointer": {"841"}, 196 | "int32_pointer": {"193"}, 197 | "int64_pointer": {"475"}, 198 | "uint_pointer": {"11"}, 199 | "uint8_pointer": {"4"}, 200 | "uint16_pointer": {"48"}, 201 | "uint32_pointer": {"9583"}, 202 | "uint64_pointer": {"183471"}, 203 | "float32_pointer": {"3.14"}, 204 | "float64_pointer": {"0.618"}, 205 | "complex64_pointer": {"(1+4i)"}, 206 | "complex128_pointer": {"(-6+17i)"}, 207 | "string_pointer": {"doggy"}, 208 | "time_pointer": {"1991-11-10T00:00:00Z"}, 209 | 210 | "bools": {"true", "false", "false", "true"}, 211 | "ints": {"9", "9", "6"}, 212 | "floats": {"0", "0.5", "1"}, 213 | "strings": {"Life", "is", "a", "Miracle"}, 214 | "times": {"2000-01-02T22:04:05Z", "1991-06-28T06:00:00Z"}, 215 | } 216 | expected.Body = io.NopCloser(strings.NewReader(expectedForm.Encode())) 217 | expected.Header.Set("Content-Type", "application/x-www-form-urlencoded") 218 | assert.Equal(t, expected, req) 219 | } 220 | 221 | func TestDirectiveForm_NewRequest_ByteSlice(t *testing.T) { 222 | type ByteSlice struct { 223 | Bytes []byte `in:"form=bytes"` 224 | MultiBytes [][]byte `in:"form=multi_bytes"` 225 | } 226 | co, err := New(ByteSlice{}) 227 | assert.NoError(t, err) 228 | payload := &ByteSlice{ 229 | Bytes: []byte("hello"), 230 | MultiBytes: [][]byte{ 231 | []byte("hello"), 232 | []byte("world"), 233 | }, 234 | } 235 | expected, _ := http.NewRequest("POST", "/api", nil) 236 | expectedForm := url.Values{ 237 | "bytes": {base64.StdEncoding.EncodeToString(payload.Bytes)}, 238 | "multi_bytes": { 239 | base64.StdEncoding.EncodeToString(payload.MultiBytes[0]), 240 | base64.StdEncoding.EncodeToString(payload.MultiBytes[1]), 241 | }, 242 | } 243 | expected.Header.Set("Content-Type", 
"application/x-www-form-urlencoded") 244 | expected.Body = io.NopCloser(strings.NewReader(expectedForm.Encode())) 245 | req, err := co.NewRequest("POST", "/api", payload) 246 | assert.NoError(t, err) 247 | assert.Equal(t, expected, req) 248 | } 249 | -------------------------------------------------------------------------------- /core/formencoder.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "github.com/ggicci/httpin/internal" 5 | ) 6 | 7 | type FormEncoder struct { 8 | Setter func(key string, value []string) // form value setter 9 | } 10 | 11 | func (e *FormEncoder) Execute(rtm *DirectiveRuntime) error { 12 | if rtm.Value.IsZero() { 13 | if rtm.Resolver.GetDirective("omitempty") != nil { 14 | return nil 15 | } 16 | } 17 | 18 | if rtm.IsFieldSet() { 19 | return nil // skip when already encoded by former directives 20 | } 21 | 22 | key := rtm.Directive.Argv[0] 23 | valueType := rtm.Value.Type() 24 | // When baseType is a file type, we treat it as a file upload. 
25 | if isFileType(valueType) { 26 | if internal.IsNil(rtm.Value) { 27 | return nil // skip when nil, which means no file uploaded 28 | } 29 | 30 | encoder, err := NewFileSlicable(rtm.Value) 31 | if err != nil { 32 | return err 33 | } 34 | files, err := encoder.ToFileSlice() 35 | if err != nil { 36 | return err 37 | } 38 | if len(files) == 0 { 39 | return nil // skip when no file uploaded 40 | } 41 | return fileUploadBuilder(rtm, files) 42 | } 43 | 44 | var adapt AnyStringableAdaptor 45 | encoderInfo := rtm.GetCustomCoder() 46 | if encoderInfo != nil { 47 | adapt = encoderInfo.Adapt 48 | } 49 | var encoder StringSlicable 50 | encoder, err := NewStringSlicable(rtm.Value, adapt) 51 | if err != nil { 52 | return err 53 | } 54 | 55 | if values, err := encoder.ToStringSlice(); err != nil { 56 | return err 57 | } else { 58 | e.Setter(key, values) 59 | rtm.MarkFieldSet(true) 60 | return nil 61 | } 62 | } 63 | 64 | func fileUploadBuilder(rtm *DirectiveRuntime, files []FileMarshaler) error { 65 | rb := rtm.GetRequestBuilder() 66 | key := rtm.Directive.Argv[0] 67 | rb.SetAttachment(key, files) 68 | rtm.MarkFieldSet(true) 69 | return nil 70 | } 71 | -------------------------------------------------------------------------------- /core/formencoder_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestFormEncoder_FieldSetByFormerDirectives(t *testing.T) { 11 | type SearchQuery struct { 12 | AccessToken string `in:"query=access_token;header=x-api-key"` 13 | } 14 | 15 | co, err := New(&SearchQuery{}) 16 | assert.NoError(t, err) 17 | 18 | req, err := co.NewRequest("GET", "/search", &SearchQuery{ 19 | AccessToken: "123456", 20 | }) 21 | assert.NoError(t, err) 22 | 23 | // The AccessToken field should be set by the query directive. 
24 | expected, _ := http.NewRequest("GET", "/search?access_token=123456", nil) 25 | assert.Equal(t, expected, req) 26 | } 27 | -------------------------------------------------------------------------------- /core/formextractor.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "mime/multipart" 5 | ) 6 | 7 | type FormExtractor struct { 8 | Runtime *DirectiveRuntime 9 | multipart.Form 10 | KeyNormalizer func(string) string 11 | } 12 | 13 | func (e *FormExtractor) Extract(keys ...string) error { 14 | if len(keys) == 0 { 15 | keys = e.Runtime.Directive.Argv 16 | } 17 | for _, key := range keys { 18 | if e.KeyNormalizer != nil { 19 | key = e.KeyNormalizer(key) 20 | } 21 | if err := e.extract(key); err != nil { 22 | return err 23 | } 24 | } 25 | return nil 26 | } 27 | 28 | func (e *FormExtractor) extract(key string) error { 29 | if e.Runtime.IsFieldSet() { 30 | return nil // skip when already extracted by former directives 31 | } 32 | 33 | values := e.Form.Value[key] 34 | files := e.Form.File[key] 35 | 36 | // Quick fail on empty input. 37 | if len(values) == 0 && len(files) == 0 { 38 | return nil 39 | } 40 | 41 | var sourceValue any 42 | var err error 43 | valueType := e.Runtime.Value.Type().Elem() 44 | if isFileType(valueType) { 45 | // When fileDecoder is not nil, it means that the field is a file upload. 46 | // We should decode files instead of values. 
47 | if len(files) == 0 { 48 | return nil // skip when no file uploaded 49 | } 50 | sourceValue = files 51 | 52 | var decoder FileSlicable 53 | decoder, err = NewFileSlicable(e.Runtime.Value.Elem()) 54 | if err == nil { 55 | err = decoder.FromFileSlice(toFileHeaderList(files)) 56 | } 57 | } else { 58 | if len(values) == 0 { 59 | return nil // skip when no value given 60 | } 61 | sourceValue = values 62 | 63 | var adapt AnyStringableAdaptor 64 | decoderInfo := e.Runtime.GetCustomCoder() // custom decoder, specified by "decoder" directive 65 | if decoderInfo != nil { 66 | adapt = decoderInfo.Adapt 67 | } 68 | var decoder StringSlicable 69 | decoder, err = NewStringSlicable(e.Runtime.Value.Elem(), adapt) 70 | if err == nil { 71 | err = decoder.FromStringSlice(values) 72 | } 73 | } 74 | 75 | if err != nil { 76 | return &fieldError{key, sourceValue, err} 77 | } 78 | e.Runtime.MarkFieldSet(true) 79 | return nil 80 | } 81 | -------------------------------------------------------------------------------- /core/header.go: -------------------------------------------------------------------------------- 1 | // directive: "header" 2 | // https://ggicci.github.io/httpin/directives/header 3 | 4 | package core 5 | 6 | import ( 7 | "mime/multipart" 8 | "net/http" 9 | ) 10 | 11 | type DirectiveHeader struct{} 12 | 13 | // Decode implements the "header" executor who extracts values 14 | // from the HTTP headers. 
15 | func (*DirectiveHeader) Decode(rtm *DirectiveRuntime) error { 16 | req := rtm.GetRequest() 17 | extractor := &FormExtractor{ 18 | Runtime: rtm, 19 | Form: multipart.Form{ 20 | Value: req.Header, 21 | }, 22 | KeyNormalizer: http.CanonicalHeaderKey, 23 | } 24 | return extractor.Extract() 25 | } 26 | 27 | func (*DirectiveHeader) Encode(rtm *DirectiveRuntime) error { 28 | encoder := &FormEncoder{ 29 | Setter: rtm.GetRequestBuilder().SetHeader, 30 | } 31 | return encoder.Execute(rtm) 32 | } 33 | -------------------------------------------------------------------------------- /core/header_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestDirectiveHeader_Decode(t *testing.T) { 11 | type SearchQuery struct { 12 | ApiUid int `in:"header=x-api-uid"` 13 | ApiToken string `in:"header=x-api-token"` 14 | } 15 | 16 | r, _ := http.NewRequest("GET", "/", nil) 17 | r.Header.Set("X-Api-Token", "some-secret-token") 18 | r.Header.Set("X-Api-Uid", "91241844") 19 | expected := &SearchQuery{ 20 | ApiUid: 91241844, 21 | ApiToken: "some-secret-token", 22 | } 23 | co, err := New(SearchQuery{}) 24 | assert.NoError(t, err) 25 | got, err := co.Decode(r) 26 | assert.NoError(t, err) 27 | assert.Equal(t, expected, got.(*SearchQuery)) 28 | } 29 | 30 | func TestDirectiveHeader_NewRequest(t *testing.T) { 31 | type ApiQuery struct { 32 | ApiUid int `in:"header=x-api-uid;omitempty"` 33 | ApiToken *string `in:"header=X-Api-Token;omitempty"` 34 | } 35 | 36 | t.Run("with all values", func(t *testing.T) { 37 | tk := "some-secret-token" 38 | query := &ApiQuery{ 39 | ApiUid: 91241844, 40 | ApiToken: &tk, 41 | } 42 | 43 | co, err := New(ApiQuery{}) 44 | assert.NoError(t, err) 45 | req, err := co.NewRequest("POST", "/api", query) 46 | assert.NoError(t, err) 47 | 48 | expected, _ := http.NewRequest("POST", "/api", nil) 49 | // NOTE: the 
key will be canonicalized 50 | expected.Header.Set("x-api-uid", "91241844") 51 | expected.Header.Set("X-Api-Token", "some-secret-token") 52 | assert.Equal(t, expected, req) 53 | }) 54 | 55 | t.Run("with empty value", func(t *testing.T) { 56 | query := &ApiQuery{ 57 | ApiUid: 0, 58 | ApiToken: nil, 59 | } 60 | 61 | co, err := New(ApiQuery{}) 62 | assert.NoError(t, err) 63 | req, err := co.NewRequest("POST", "/api", query) 64 | assert.NoError(t, err) 65 | 66 | expected, _ := http.NewRequest("POST", "/api", nil) 67 | assert.Equal(t, expected, req) 68 | 69 | _, ok := req.Header["X-Api-Uid"] 70 | assert.False(t, ok) 71 | 72 | _, ok = req.Header["X-Api-Token"] 73 | assert.False(t, ok) 74 | }) 75 | } 76 | -------------------------------------------------------------------------------- /core/hybridcoder.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "encoding" 5 | "errors" 6 | "reflect" 7 | 8 | "github.com/ggicci/httpin/internal" 9 | ) 10 | 11 | type HybridCoder struct { 12 | internal.StringMarshaler 13 | internal.StringUnmarshaler 14 | } 15 | 16 | func (c *HybridCoder) ToString() (string, error) { 17 | if c.StringMarshaler != nil { 18 | return c.StringMarshaler.ToString() 19 | } 20 | return "", errors.New("StringMarshaler not implemented") 21 | } 22 | 23 | func (c *HybridCoder) FromString(s string) error { 24 | if c.StringUnmarshaler != nil { 25 | return c.StringUnmarshaler.FromString(s) 26 | } 27 | return errors.New("StringUnmarshaler not implemented") 28 | } 29 | 30 | // Hybridize a reflect.Value to a Stringable if possible. 31 | func hybridizeCoder(rv reflect.Value) Stringable { 32 | if !rv.CanInterface() { 33 | return nil 34 | } 35 | 36 | coder := &HybridCoder{} 37 | 38 | // Interface: StringMarshaler. 
39 | if rv.Type().Implements(stringMarshalerType) { 40 | coder.StringMarshaler = rv.Interface().(internal.StringMarshaler) 41 | } else if rv.Type().Implements(textMarshalerType) { 42 | coder.StringMarshaler = &textMarshalerWrapper{rv.Interface().(encoding.TextMarshaler), nil} 43 | } 44 | 45 | // Interface: StringUnmarshaler. 46 | if rv.Type().Implements(stringUnmarshalerType) { 47 | coder.StringUnmarshaler = rv.Interface().(internal.StringUnmarshaler) 48 | } else if rv.Type().Implements(textUnmarshalerType) { 49 | coder.StringUnmarshaler = &textMarshalerWrapper{nil, rv.Interface().(encoding.TextUnmarshaler)} 50 | } 51 | 52 | if coder.StringMarshaler == nil && coder.StringUnmarshaler == nil { 53 | return nil 54 | } 55 | 56 | return coder 57 | } 58 | 59 | type textMarshalerWrapper struct { 60 | encoding.TextMarshaler 61 | encoding.TextUnmarshaler 62 | } 63 | 64 | func (w textMarshalerWrapper) ToString() (string, error) { 65 | b, err := w.TextMarshaler.MarshalText() 66 | if err != nil { 67 | return "", err 68 | } 69 | return string(b), nil 70 | } 71 | 72 | func (w textMarshalerWrapper) FromString(s string) error { 73 | return w.TextUnmarshaler.UnmarshalText([]byte(s)) 74 | } 75 | 76 | var ( 77 | stringMarshalerType = internal.TypeOf[internal.StringMarshaler]() 78 | stringUnmarshalerType = internal.TypeOf[internal.StringUnmarshaler]() 79 | textMarshalerType = internal.TypeOf[encoding.TextMarshaler]() 80 | textUnmarshalerType = internal.TypeOf[encoding.TextUnmarshaler]() 81 | ) 82 | -------------------------------------------------------------------------------- /core/hybridcoder_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "errors" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestHybridCoder_MarshalTextOnly(t *testing.T) { 12 | apple := &textMarshalerApple{} 13 | rv := reflect.ValueOf(apple) 14 | stringable := hybridizeCoder(rv) 15 | 
assert.NotNil(t, stringable)

	text, err := stringable.ToString()
	assert.NoError(t, err)
	assert.Equal(t, "apple", text)

	assert.ErrorContains(t, stringable.FromString("red apple"), "StringUnmarshaler not implemented")
}

// A type with only UnmarshalText can be unmarshaled but not marshaled.
func TestHybridCoder_UnmarshalTextOnly(t *testing.T) {
	banana := &textUnmarshalerBanana{}
	rv := reflect.ValueOf(banana)
	stringable := hybridizeCoder(rv)
	assert.NotNil(t, stringable)

	text, err := stringable.ToString()
	assert.ErrorContains(t, err, "StringMarshaler not implemented")
	assert.Empty(t, text)

	err = stringable.FromString("yellow banana")
	assert.NoError(t, err)
	assert.Equal(t, "yellow banana", banana.Content)
}

// A type with both MarshalText and UnmarshalText works in both directions.
func TestHybridCoder_MarshalText_and_UnmarshalText(t *testing.T) {
	orange := &textMarshalerAndUnmarshalerOrange{Content: "orange"}
	rv := reflect.ValueOf(orange)
	stringable := hybridizeCoder(rv)
	assert.NotNil(t, stringable)

	text, err := stringable.ToString()
	assert.NoError(t, err)
	assert.Equal(t, "orange", text)

	err = stringable.FromString("red orange")
	assert.NoError(t, err)
	assert.Equal(t, "red orange", orange.Content)
}

// ToString is preferred over MarshalText; with no unmarshaler, FromString fails.
func TestHybridCoder_StringMarshaler_TakesPrecedence(t *testing.T) {
	peach := &stringMarshalerAndTextMarshalerPeach{Content: "peach"}
	rv := reflect.ValueOf(peach)
	stringable := hybridizeCoder(rv)
	assert.NotNil(t, stringable)

	text, err := stringable.ToString()
	assert.NoError(t, err)
	assert.Equal(t, "ToString:peach", text)

	err = stringable.FromString("red peach")
	assert.ErrorContains(t, err, "StringUnmarshaler not implemented")
}

// FromString is preferred over UnmarshalText; with no marshaler, ToString fails.
func TestHybridCoder_StringUnmarshaler_TakesPrecedence(t *testing.T) {
	peach := &stringUnmarshalerAndTextUnmarshalerPeach{Content: "peach"}
	rv := reflect.ValueOf(peach)
	stringable := hybridizeCoder(rv)
	assert.NotNil(t, stringable)

	text, err := stringable.ToString()
	assert.ErrorContains(t, err, "StringMarshaler not implemented")
	assert.Empty(t, text)

	err = stringable.FromString("red peach")
	assert.NoError(t, err)
	assert.Equal(t, "FromString:red peach", peach.Content)
}

// The two halves are resolved independently: ToString marshals, UnmarshalText
// unmarshals.
func TestHybridCoder_StringMarshaler_and_TextUnmarshaler(t *testing.T) {
	pineapple := &stringMarshalerAndTextUnmarshalerPineapple{Content: "pineapple"}
	rv := reflect.ValueOf(pineapple)
	stringable := hybridizeCoder(rv)
	assert.NotNil(t, stringable)

	text, err := stringable.ToString()
	assert.NoError(t, err)
	assert.Equal(t, "ToString:pineapple", text)

	err = stringable.FromString("red pineapple")
	assert.NoError(t, err)
	assert.Equal(t, "UnmarshalText:red pineapple", pineapple.Content)
}

// An error from MarshalText is propagated by ToString.
func TestHybridCoder_MarshalText_Error(t *testing.T) {
	watermelon := &textMarshalerSpoiledWatermelon{}
	rv := reflect.ValueOf(watermelon)
	stringable := hybridizeCoder(rv)
	assert.NotNil(t, stringable)

	text, err := stringable.ToString()
	assert.ErrorContains(t, err, "spoiled")
	assert.Empty(t, text)
}

// A value read through an unexported field fails CanInterface, so no coder is
// built.
func TestHybridCoder_ErrCannotInterface(t *testing.T) {
	type mystruct struct {
		unexportedName string
	}
	v := mystruct{unexportedName: "mystruct"}
	rv := reflect.ValueOf(v)

	stringable := hybridizeCoder(rv.Field(0))
	assert.Nil(t, stringable)
}

// A type implementing none of the four interfaces yields nil.
func TestHybridCoder_NilOnNoInterfacesDetected(t *testing.T) {
	var zero zeroInterface
	rv := reflect.ValueOf(zero)

	stringable := hybridizeCoder(rv)
	assert.Nil(t, stringable)
}

type textMarshalerApple struct{} // only implements encoding.TextMarshaler

func (t *textMarshalerApple) MarshalText() ([]byte, error) {
	return []byte("apple"), nil
}

type textUnmarshalerBanana struct{ Content string } // only implements encoding.TextUnmarshaler

func (t *textUnmarshalerBanana) UnmarshalText(text []byte) error {
	t.Content = string(text)
	return nil
}

type textMarshalerAndUnmarshalerOrange struct{ Content string } // implements both encoding.TextMarshaler and encoding.TextUnmarshaler

func (t *textMarshalerAndUnmarshalerOrange) MarshalText() ([]byte, error) {
	return []byte(t.Content), nil
}

func (t *textMarshalerAndUnmarshalerOrange) UnmarshalText(text []byte) error {
	t.Content = string(text)
	return nil
}

// implements internal.StringMarshaler and encoding.TextMarshaler
// will use internal.StringMarshaler
type stringMarshalerAndTextMarshalerPeach struct{ Content string }

func (s *stringMarshalerAndTextMarshalerPeach) ToString() (string, error) {
	return "ToString:" + s.Content, nil
}

func (s *stringMarshalerAndTextMarshalerPeach) MarshalText() ([]byte, error) {
	return []byte("MarshalText:" + s.Content), nil
}

// implements internal.StringUnmarshaler and encoding.TextUnmarshaler
// will use internal.StringUnmarshaler
type stringUnmarshalerAndTextUnmarshalerPeach struct{ Content string }

func (s *stringUnmarshalerAndTextUnmarshalerPeach) FromString(text string) error {
	s.Content = "FromString:" + text
	return nil
}

func (s *stringUnmarshalerAndTextUnmarshalerPeach) UnmarshalText(text []byte) error {
	s.Content = "UnmarshalText:" + string(text)
	return nil
}

// implements internal.StringMarshaler and encoding.TextUnmarshaler
// will use ToString for marshaling and UnmarshalText for unmarshaling
type stringMarshalerAndTextUnmarshalerPineapple struct{ Content string }

func (s *stringMarshalerAndTextUnmarshalerPineapple) ToString() (string, error) {
	return "ToString:" + s.Content, nil
}

func (s *stringMarshalerAndTextUnmarshalerPineapple) UnmarshalText(text []byte) error
{
	s.Content = "UnmarshalText:" + string(text)
	return nil
}

// textMarshalerSpoiledWatermelon's MarshalText always fails.
type textMarshalerSpoiledWatermelon struct{}

func (t *textMarshalerSpoiledWatermelon) MarshalText() ([]byte, error) {
	return nil, errors.New("spoiled")
}

// zeroInterface implements none of the marshaling interfaces.
type zeroInterface struct{}

--------------------------------------------------------------------------------
/core/nonzero.go:
--------------------------------------------------------------------------------
// directive: "nonzero"
// https://ggicci.github.io/httpin/directives/nonzero

package core

import "errors"

// DirectiveNonzero implements the "nonzero" directive, which requires that the
// field does not hold the "zero value" of its type. In Go, the "zero value" means:
//   - nil
//   - false
//   - 0
//   - ""
//   - etc.
//
// The check uses reflect.Value.IsZero, so an empty but non-nil slice or map is
// NOT considered zero; only a nil one is.
//
// Unlike the "required" directive, the "nonzero" directive checks the value of
// the field.
type DirectiveNonzero struct{}

// Decode fails with "zero value" if the decoded field holds its zero value.
// NOTE(review): it inspects rtm.Value.Elem(), so rtm.Value is presumably a
// pointer to the field during decoding — confirm against DirectiveRuntime.
func (*DirectiveNonzero) Decode(rtm *DirectiveRuntime) error {
	if rtm.Value.Elem().IsZero() {
		return errors.New("zero value")
	}
	return nil
}

// Encode fails with "zero value" if the field being encoded holds its zero
// value. Unlike Decode, it checks rtm.Value directly (no Elem()).
func (*DirectiveNonzero) Encode(rtm *DirectiveRuntime) error {
	if rtm.Value.IsZero() {
		return errors.New("zero value")
	}
	return nil
}

--------------------------------------------------------------------------------
/core/nonzero_test.go:
--------------------------------------------------------------------------------
package core

import (
	"net/http"
	"net/url"
	"testing"

	"github.com/stretchr/testify/assert"
)

// NonzeroQuery requires both query parameters to yield non-zero values.
type NonzeroQuery struct {
	Name     string `in:"query=name;nonzero"`
	AgeRange []int  `in:"query=age;nonzero"`
}

func TestDirectiveNonzero_Decode(t *testing.T) {
	co, err := New(&NonzeroQuery{})
	assert.NoError(t, err)

	r, _ := http.NewRequest("GET", "/users", nil)
	r.URL.RawQuery = url.Values{
		"name": {"ggicci"},
		"age":  {"18", "999"},
	}.Encode()

	got, err := co.Decode(r)
	assert.NoError(t, err)
	assert.Equal(t, &NonzeroQuery{
		Name:     "ggicci",
		AgeRange: []int{18, 999},
	}, got.(*NonzeroQuery))
}

// A missing "age" parameter leaves AgeRange nil, which "nonzero" rejects.
func TestDirectiveNonzero_Decode_ErrZeroValue(t *testing.T) {
	co, err := New(&NonzeroQuery{})
	assert.NoError(t, err)

	r, _ := http.NewRequest("GET", "/users", nil)
	r.URL.RawQuery = url.Values{
		"name": {"ggicci"},
	}.Encode()

	_, err = co.Decode(r)
	assert.Error(t, err)
	var invalidField *InvalidFieldError
	assert.ErrorAs(t, err, &invalidField)
	assert.Equal(t, "AgeRange", invalidField.Field)
	assert.Equal(t, "nonzero", invalidField.Directive)
	assert.Empty(t, invalidField.Key)
	assert.Nil(t, invalidField.Value)
}

// Regression test for issue #49: "nonzero" on a field nested inside a JSON
// body must be enforced, and the error is reported against the outer field.
func TestDirectiveNonzero_Decode_InNestedJSONBody_Issue49(t *testing.T) {
	type UpdateUserInput struct {
		Payload struct {
			Display string `json:"display" in:"nonzero"`
		} `in:"body=json"`
	}

	// NOTE: WithNestedDirectivesEnabled(true) is required to enable nested directives.
	co, err := New(&UpdateUserInput{}, WithNestedDirectivesEnabled(true))
	assert.NoError(t, err)

	r, _ := http.NewRequest("POST", "/users/1", nil)
	r.Header.Set("Content-Type", "application/json")
	r.Body = makeBodyReader(`{"display": ""}`)
	got, err := co.Decode(r)
	assert.Nil(t, got)
	assert.ErrorContains(t, err, "nonzero")
	var invalidField *InvalidFieldError
	assert.ErrorAs(t, err, &invalidField)
	assert.Equal(t, "Payload", invalidField.Field)
	assert.Equal(t, "nonzero", invalidField.Directive)
}

func TestDirectiveNonzero_NewRequest(t *testing.T) {
	co, err := New(&NonzeroQuery{})
	assert.NoError(t, err)

	expected, _ := http.NewRequest("GET", "/users", nil)
	expected.URL.RawQuery = url.Values{
		"name": {"ggicci"},
		"age":  {"18", "999"},
	}.Encode()

	req, err := co.NewRequest("GET", "/users", &NonzeroQuery{
		Name:     "ggicci",
		AgeRange: []int{18, 999},
	})
	assert.NoError(t, err)
	assert.Equal(t, expected, req)
}

// Encoding reports every zero-valued "nonzero" field, not just the first one.
func TestDirectiveNonzero_NewRequest_ErrZeroValue(t *testing.T) {
	co, err := New(&NonzeroQuery{})
	assert.NoError(t, err)

	_, err = co.NewRequest("GET", "/users", &NonzeroQuery{})
	assert.ErrorContains(t, err, "zero value")
	assert.ErrorContains(t, err, "Name")
	assert.ErrorContains(t, err, "AgeRange")
}

--------------------------------------------------------------------------------
/core/omitempty.go:
--------------------------------------------------------------------------------
// directive: "omitempty"
// https://ggicci.github.io/httpin/directives/omitempty

package core

// DirectiveOmitEmpty is used with the DirectiveQuery, DirectiveForm, and DirectiveHeader to indicate that the field
// should be omitted when the value is empty.
8 | // It does not have any affect when used by itself 9 | type DirectiveOmitEmpty struct{} 10 | 11 | func (*DirectiveOmitEmpty) Decode(_ *DirectiveRuntime) error { 12 | return nil 13 | } 14 | 15 | func (*DirectiveOmitEmpty) Encode(_ *DirectiveRuntime) error { 16 | return nil 17 | } 18 | -------------------------------------------------------------------------------- /core/option.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | const minimumMaxMemory = int64(1 << 10) // 1KB 8 | const defaultMaxMemory = int64(32 << 20) // 32 MB 9 | 10 | type Option func(*Core) error 11 | 12 | var globalNestedDirectivesEnabled bool = false 13 | 14 | // EnableNestedDirectives sets the global flag to enable nested directives. 15 | // Nested directives are disabled by default. 16 | func EnableNestedDirectives(on bool) { 17 | globalNestedDirectivesEnabled = on 18 | } 19 | 20 | // WithErrorHandler overrides the default error handler. 21 | func WithErrorHandler(custom ErrorHandler) Option { 22 | return func(c *Core) error { 23 | if err := validateErrorHandler(custom); err != nil { 24 | return err 25 | } else { 26 | c.errorHandler = custom 27 | return nil 28 | } 29 | } 30 | } 31 | 32 | // WithMaxMemory overrides the default maximum memory size (32MB) when reading 33 | // the request body. See https://pkg.go.dev/net/http#Request.ParseMultipartForm 34 | // for more details. 35 | func WithMaxMemory(maxMemory int64) Option { 36 | return func(c *Core) error { 37 | if maxMemory < minimumMaxMemory { 38 | return errors.New("max memory too small") 39 | } 40 | c.maxMemory = maxMemory 41 | return nil 42 | } 43 | } 44 | 45 | // WithNestedDirectivesEnabled enables/disables nested directives. 
46 | func WithNestedDirectivesEnabled(enable bool) Option { 47 | return func(c *Core) error { 48 | c.enableNestedDirectives = enable 49 | return nil 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /core/option_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "net/http" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestWithErrorHandler(t *testing.T) { 12 | // Use the default error handler. 13 | co, _ := New(ProductQuery{}) 14 | assert.True(t, equalFuncs(globalCustomErrorHandler, co.GetErrorHandler())) 15 | 16 | // Override the default error handler. 17 | myErrorHandler := func(rw http.ResponseWriter, r *http.Request, err error) {} 18 | co, _ = New(ProductQuery{}, WithErrorHandler(myErrorHandler)) 19 | assert.True(t, equalFuncs(myErrorHandler, co.GetErrorHandler())) 20 | 21 | // Fail on nil error handler. 22 | _, err := New(ProductQuery{}, WithErrorHandler(nil)) 23 | assert.ErrorContains(t, err, "nil error handler") 24 | } 25 | 26 | func TestWithMaxMemory(t *testing.T) { 27 | // Use the default max memory. 28 | co, _ := New(ProductQuery{}) 29 | assert.Equal(t, defaultMaxMemory, co.maxMemory) 30 | 31 | // Override the default max memory. 32 | co, _ = New(ProductQuery{}, WithMaxMemory(16<<20)) 33 | assert.Equal(t, int64(16<<20), co.maxMemory) 34 | 35 | // Fail on too small max memory. 36 | _, err := New(ProductQuery{}, WithMaxMemory(100)) 37 | assert.ErrorContains(t, err, "max memory too small") 38 | } 39 | 40 | func TestWithNestedDirectivesEnabled(t *testing.T) { 41 | // Override the default nested directives flag. 
42 | co, _ := New(ProductQuery{}, WithNestedDirectivesEnabled(true)) 43 | assert.Equal(t, true, co.enableNestedDirectives) 44 | co, _ = New(ProductQuery{}, WithNestedDirectivesEnabled(false)) 45 | assert.Equal(t, false, co.enableNestedDirectives) 46 | } 47 | 48 | func TestEnableNestedDirectives(t *testing.T) { 49 | // Use the default nested directives flag. 50 | EnableNestedDirectives(false) 51 | co, _ := New(ProductQuery{}) 52 | assert.Equal(t, false, co.enableNestedDirectives) 53 | 54 | EnableNestedDirectives(true) 55 | co, _ = New(ProductQuery{}) 56 | assert.Equal(t, true, co.enableNestedDirectives) 57 | } 58 | 59 | func equalFuncs(expected, actual any) bool { 60 | return reflect.ValueOf(expected).Pointer() == reflect.ValueOf(actual).Pointer() 61 | } 62 | -------------------------------------------------------------------------------- /core/owl.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import "github.com/ggicci/owl" 4 | 5 | func init() { 6 | owl.UseTag("in") 7 | } 8 | -------------------------------------------------------------------------------- /core/patch_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "io" 5 | "net/http" 6 | "net/url" 7 | "path/filepath" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/ggicci/httpin/patch" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | type AccountPatch struct { 16 | Email patch.Field[string] `in:"form=email"` 17 | Age patch.Field[int] `in:"form=age"` 18 | Avatar patch.Field[*File] `in:"form=avatar"` 19 | Hobbies patch.Field[[]string] `in:"form=hobbies"` 20 | Pictures patch.Field[[]*File] `in:"form=pictures"` 21 | } 22 | 23 | func TestPatchField(t *testing.T) { 24 | fileContent := []byte("hello") 25 | r := newMultipartFormRequestFromMap(map[string]any{ 26 | "age": "18", 27 | "avatar": fileContent, 28 | "hobbies": []string{ 29 | "reading", 30 | "swimming", 31 | }, 32 
})

	co, err := New(AccountPatch{})
	assert.NoError(t, err)
	gotValue, err := co.Decode(r)
	assert.NoError(t, err)
	got := gotValue.(*AccountPatch)

	// Keys absent from the form (email, pictures) must yield Valid == false.
	assert.Equal(t, patch.Field[string]{
		Valid: false,
		Value: "",
	}, got.Email)

	assert.Equal(t, patch.Field[int]{
		Valid: true,
		Value: 18,
	}, got.Age)

	assert.Equal(t, patch.Field[[]string]{
		Valid: true,
		Value: []string{"reading", "swimming"},
	}, got.Hobbies)

	assert.Equal(t, patch.Field[[]*File]{
		Valid: false,
		Value: nil,
	}, got.Pictures)

	assertDecodedFile(t, got.Avatar.Value, "avatar.txt", fileContent)
}

// A value that cannot be parsed as int surfaces as an InvalidFieldError for Age.
func TestPatchField_DecodeValueFailed(t *testing.T) {
	r, _ := http.NewRequest("GET", "/", nil)
	r.Form = url.Values{
		"email": {"abc@example.com"},
		"age":   {"eighteen"},
	}
	co, err := New(AccountPatch{})
	assert.NoError(t, err)
	gotValue, err := co.Decode(r)
	assert.Error(t, err)
	var ferr *InvalidFieldError
	assert.ErrorAs(t, err, &ferr)
	assert.Equal(t, "Age", ferr.Field)
	assert.Equal(t, []string{"eighteen"}, ferr.Value)
	assert.Equal(t, "form", ferr.Directive)
	assert.Nil(t, gotValue)
}

func TestPatchField_DecodeFileFailed(t *testing.T) {
	body, writer := newMultipartFormWriterFromMap(map[string]any{
		"email":  "abc@example.com",
		"age":    "18",
		"avatar": []byte("hello"),
	})

	// break the boundary to make the file decoder fail
	r, _ := http.NewRequest("POST", "/", breakMultipartFormBoundary(body))
	r.Header.Set("Content-Type", writer.FormDataContentType())

	co, err := New(AccountPatch{})
	assert.NoError(t, err)
	gotValue, err := co.Decode(r)
	assert.Nil(t, gotValue)
	assert.Error(t, err)
}

// Only fields with Valid == true are written to the query string.
func TestPatchField_NewRequest(t *testing.T) {
	type ListQuery struct {
		Username patch.Field[string]   `in:"query=username"`
		Age      patch.Field[int]      `in:"query=age"`
		State    patch.Field[[]string] `in:"query=state[]"`
	}

	co, err := New(ListQuery{})
	assert.NoError(t, err)

	testcases := []struct {
		Query    *ListQuery
		Expected url.Values
	}{
		{&ListQuery{
			Username: patch.Field[string]{Value: "ggicci", Valid: true},
			Age:      patch.Field[int]{Value: 18, Valid: false},
		}, url.Values{"username": {"ggicci"}}},
		{&ListQuery{
			Age: patch.Field[int]{Value: 18, Valid: false},
		}, url.Values{}},
		{&ListQuery{
			Age: patch.Field[int]{Value: 18, Valid: true},
		}, url.Values{"age": {"18"}}},
		{&ListQuery{
			Username: patch.Field[string]{Value: "ggicci", Valid: true},
			Age:      patch.Field[int]{Value: 18, Valid: true},
			State: patch.Field[[]string]{
				Value: []string{"reading", "swimming"},
				Valid: true,
			},
		}, url.Values{
			"username": {"ggicci"},
			"age":      {"18"},
			"state[]":  {"reading", "swimming"},
		}},
	}

	for _, c := range testcases {
		req, err := co.NewRequest("GET", "/list", c.Query)
		assert.NoError(t, err)

		expected, _ := http.NewRequest("GET", "/list", nil)
		expected.URL.RawQuery = c.Expected.Encode()
		assert.Equal(t, expected, req)
	}
}

// Without any valid file field, the request body is encoded as
// application/x-www-form-urlencoded.
func TestPatchField_NewRequest_NoFiles(t *testing.T) {
	assert := assert.New(t)
	payload := &AccountPatch{
		Email:  patch.Field[string]{Value: "abc@example.com", Valid: true},
		Age:    patch.Field[int]{Value: 18, Valid: true},
		Avatar: patch.Field[*File]{Value: nil, Valid: false},
		Hobbies: patch.Field[[]string]{
			Value: []string{"reading", "swimming"},
			Valid: true,
		},
		Pictures: patch.Field[[]*File]{Value: nil, Valid: false},
	}

	expected, _ := http.NewRequest("POST", "/patchAccount", nil)
	expectedForm := url.Values{
		"email":   {"abc@example.com"},
		"age":     {"18"},
		"hobbies": {"reading", "swimming"},
	}
	expected.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	expected.Body = io.NopCloser(strings.NewReader(expectedForm.Encode()))

	co, err := New(AccountPatch{})
	assert.NoError(err)
	req, err := co.NewRequest("POST", "/patchAccount", payload)
	assert.NoError(err)
	assert.Equal(expected, req)
}

// Round trip: encode a request carrying files, then decode it back and compare.
func TestPatchField_NewRequest_WithFiles(t *testing.T) {
	assert := assert.New(t)
	avatarFile := createTempFile(t, []byte("handsome avatar image"))
	pic1Filename := createTempFile(t, []byte("pic1 content"))
	pic2Filename := createTempFile(t, []byte("pic2 content"))

	payload := &AccountPatch{
		Email:  patch.Field[string]{Value: "abc@example.com", Valid: true},
		Age:    patch.Field[int]{Value: 18, Valid: true},
		Avatar: patch.Field[*File]{Value: UploadFile(avatarFile), Valid: true},
		Hobbies: patch.Field[[]string]{
			Value: []string{"reading", "swimming"},
			Valid: true,
		},
		Pictures: patch.Field[[]*File]{
			Value: []*File{
				UploadFile(pic1Filename),
				UploadFile(pic2Filename),
			},
			Valid: true,
		},
	}

	// See TestMultipartFormEncode_UploadFilename for more details.
	co, err := New(AccountPatch{})
	assert.NoError(err)
	req, err := co.NewRequest("POST", "/patchAccount", payload)
	assert.NoError(err)

	// Server side: receive files (decode).
	gotValue, err := co.Decode(req)
	assert.NoError(err)
	got, ok := gotValue.(*AccountPatch)
	assert.True(ok)
	assert.True(got.Email.Valid)
	assert.Equal("abc@example.com", got.Email.Value)
	assert.True(got.Age.Valid)
	assert.Equal(18, got.Age.Value)
	assert.True(got.Hobbies.Valid)
	assert.Equal([]string{"reading", "swimming"}, got.Hobbies.Value)
	assert.True(got.Avatar.Valid)
	assertDecodedFile(t, got.Avatar.Value, filepath.Base(avatarFile), []byte("handsome avatar image"))
	assert.True(got.Pictures.Valid)
	assert.Len(got.Pictures.Value, 2)
	assertDecodedFile(t, got.Pictures.Value[0], filepath.Base(pic1Filename), []byte("pic1 content"))
	assertDecodedFile(t, got.Pictures.Value[1], filepath.Base(pic2Filename), []byte("pic2 content"))
}

--------------------------------------------------------------------------------
/core/path.go:
--------------------------------------------------------------------------------
// directive: "path"
// https://ggicci.github.io/httpin/directives/path

package core

import "errors"

// DirectivePath implements the "path" directive. Encoding is built in. Decoding
// is delegated to a caller-supplied function because reading path parameters
// depends on the routing framework in use.
type DirectivePath struct {
	decode func(*DirectiveRuntime) error
}

// NewDirectivePath returns a DirectivePath whose Decode calls decodeFunc.
// decodeFunc must not be nil, because Decode invokes it unconditionally.
func NewDirectivePath(decodeFunc func(*DirectiveRuntime) error) *DirectivePath {
	return &DirectivePath{
		decode: decodeFunc,
	}
}

// Decode delegates to the function given to NewDirectivePath.
func (dir *DirectivePath) Decode(rtm *DirectiveRuntime) error {
	return dir.decode(rtm)
}

// Encode replaces the matching placeholder in the URL path (e.g. "{owner}")
// with the field's value, via the request builder's SetPath.
func (*DirectivePath) Encode(rtm *DirectiveRuntime) error {
	encoder := &FormEncoder{
		Setter: rtm.GetRequestBuilder().SetPath,
	}
	return encoder.Execute(rtm)
}

// defaultPathDirective is the default path directive, which only supports encoding,
// while the decoding function is not implemented.
Because the path decoding depends on the 32 | // routing framework, it should be implemented in the integration package. 33 | // See integration/gochi.go and integration/gorilla.go for examples. 34 | var defaultPathDirective = NewDirectivePath(func(rtm *DirectiveRuntime) error { 35 | return errors.New("unimplemented path decoding function") 36 | }) 37 | -------------------------------------------------------------------------------- /core/path_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "io" 7 | "net/http" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func myCustomPathDecode(rtm *DirectiveRuntime) error { 14 | return assert.AnError 15 | } 16 | 17 | func TestDirectivePath_Decode(t *testing.T) { 18 | pathDirective := NewDirectivePath(myCustomPathDecode) 19 | assert.ErrorIs(t, pathDirective.Decode(nil), assert.AnError) 20 | } 21 | 22 | func TestDirectivePath_ErrDefaultPathDecodingIsUnimplemented(t *testing.T) { 23 | type GetProfileRequest struct { 24 | Username string `in:"path=username"` 25 | } 26 | 27 | co, err := New(GetProfileRequest{}) 28 | assert.NoError(t, err) 29 | req, _ := http.NewRequest("GET", "/users/ggicci", nil) 30 | _, err = co.Decode(req) 31 | assert.ErrorContains(t, err, "unimplemented path decoding function") 32 | } 33 | 34 | func TestDirectivePath_NewRequest_DefalutPathEncodingShouldWork(t *testing.T) { 35 | assert := assert.New(t) 36 | type Repository struct { 37 | Name string `json:"name"` 38 | Visibility string `json:"visibility"` // public, private, internal 39 | License string `json:"license"` 40 | } 41 | type CreateRepositoryRequest struct { 42 | Owner string `in:"path=owner"` 43 | Payload *Repository `in:"body=json"` 44 | } 45 | 46 | query := &CreateRepositoryRequest{ 47 | Owner: "ggicci", 48 | Payload: &Repository{ 49 | Name: "httpin", 50 | Visibility: "public", 51 | License: "MIT", 52 | }, 53 | 
} 54 | 55 | co, err := New(query) 56 | assert.NoError(err) 57 | req, err := co.NewRequest("POST", "/users/{owner}/repos", query) 58 | assert.NoError(err) 59 | 60 | expected, _ := http.NewRequest("POST", "/users/ggicci/repos", nil) 61 | expected.Header.Set("Content-Type", "application/json") 62 | var body bytes.Buffer 63 | assert.NoError(json.NewEncoder(&body).Encode(query.Payload)) 64 | expected.Body = io.NopCloser(&body) 65 | assert.Equal(expected, req) 66 | } 67 | -------------------------------------------------------------------------------- /core/query.go: -------------------------------------------------------------------------------- 1 | // directive: "query" 2 | // https://ggicci.github.io/httpin/directives/query 3 | 4 | package core 5 | 6 | import ( 7 | "mime/multipart" 8 | ) 9 | 10 | type DirectiveQuery struct{} 11 | 12 | // Decode implements the "query" executor who extracts values from 13 | // the querystring of an HTTP request. 14 | func (*DirectiveQuery) Decode(rtm *DirectiveRuntime) error { 15 | req := rtm.GetRequest() 16 | extractor := &FormExtractor{ 17 | Runtime: rtm, 18 | Form: multipart.Form{ 19 | Value: req.URL.Query(), 20 | }, 21 | } 22 | return extractor.Extract() 23 | } 24 | 25 | func (*DirectiveQuery) Encode(rtm *DirectiveRuntime) error { 26 | encoder := &FormEncoder{ 27 | Setter: rtm.GetRequestBuilder().SetQuery, 28 | } 29 | return encoder.Execute(rtm) 30 | } 31 | -------------------------------------------------------------------------------- /core/query_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/url" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestDirectiveQuery_Decode(t *testing.T) { 13 | type SearchQuery struct { 14 | Query string `in:"query=q;required"` 15 | PageNumber int `in:"query=p"` 16 | PageSize int `in:"query=page_size"` 17 | } 18 | 19 | r, _ := http.NewRequest("GET", 
"/?q=doggy&p=2&page_size=5", nil) 20 | expected := &SearchQuery{ 21 | Query: "doggy", 22 | PageNumber: 2, 23 | PageSize: 5, 24 | } 25 | 26 | co, err := New(SearchQuery{}) 27 | assert.NoError(t, err) 28 | got, err := co.Decode(r) 29 | assert.NoError(t, err) 30 | assert.Equal(t, expected, got.(*SearchQuery)) 31 | } 32 | 33 | func TestDirectiveQuery_NewRequest(t *testing.T) { 34 | type SearchQuery struct { 35 | Name string `in:"query=name"` 36 | Age int `in:"query=age;omitempty"` 37 | Enabled bool `in:"query=enabled"` 38 | Price float64 `in:"query=price"` 39 | 40 | NameList []string `in:"query=name_list[]"` 41 | AgeList []int `in:"query=age_list[]"` 42 | 43 | NamePointer *string `in:"query=name_pointer"` 44 | AgePointer *int `in:"query=age_pointer;omitempty"` 45 | } 46 | 47 | t.Run("with all values", func(t *testing.T) { 48 | query := &SearchQuery{ 49 | Name: "cupcake", 50 | Age: 12, 51 | Enabled: true, 52 | Price: 6.28, 53 | NameList: []string{"apple", "banana", "cherry"}, 54 | AgeList: []int{1, 2, 3}, 55 | NamePointer: func() *string { 56 | s := "pointer cupcake" 57 | return &s 58 | }(), 59 | AgePointer: func() *int { 60 | i := 19 61 | return &i 62 | }(), 63 | } 64 | 65 | co, err := New(SearchQuery{}) 66 | assert.NoError(t, err) 67 | req, err := co.NewRequest("GET", "/pets", query) 68 | assert.NoError(t, err) 69 | 70 | expected, _ := http.NewRequest("GET", "/pets", nil) 71 | expectedQuery := make(url.Values) 72 | expectedQuery.Set("name", query.Name) // query.Name 73 | expectedQuery.Set("age", "12") // query.Age 74 | expectedQuery.Set("enabled", "true") // query.Enabled 75 | expectedQuery.Set("price", "6.28") // query.Price 76 | expectedQuery["name_list[]"] = query.NameList // query.NameList 77 | expectedQuery["age_list[]"] = []string{"1", "2", "3"} // query.AgeList 78 | expectedQuery.Set("name_pointer", *query.NamePointer) // query.NamePointer 79 | expectedQuery.Set("age_pointer", "19") // query.PointerAge 80 | expected.URL.RawQuery = expectedQuery.Encode() 81 | 
assert.Equal(t, expected, req) 82 | }) 83 | 84 | t.Run("with empty values", func(t *testing.T) { 85 | query := &SearchQuery{} 86 | 87 | co, err := New(SearchQuery{}) 88 | assert.NoError(t, err) 89 | req, err := co.NewRequest("GET", "/pets", query) 90 | assert.NoError(t, err) 91 | 92 | assert.True(t, req.URL.Query().Has("name")) 93 | assert.False(t, req.URL.Query().Has("age")) 94 | 95 | assert.True(t, req.URL.Query().Has("name_pointer")) 96 | assert.False(t, req.URL.Query().Has("age_pointer")) 97 | }) 98 | } 99 | 100 | type Location struct { 101 | Latitude float64 102 | Longitude float64 103 | } 104 | 105 | func (l Location) ToString() (string, error) { 106 | return fmt.Sprintf("%f,%f", l.Latitude, l.Longitude), nil 107 | } 108 | 109 | type LocationImplementedTextMarshaler Location 110 | 111 | func (l LocationImplementedTextMarshaler) MarshalText() ([]byte, error) { 112 | if s, err := (Location)(l).ToString(); err != nil { 113 | return nil, err 114 | } else { 115 | return []byte("MarshalText:" + s), nil 116 | } 117 | } 118 | func TestDirectiveQuery_NewRequest_ErrUnsupportedType(t *testing.T) { 119 | type SearchQuery struct { 120 | Map map[string]string `in:"query=map"` // unsupported type: map 121 | } 122 | 123 | co, err := New(SearchQuery{}) 124 | assert.NoError(t, err) 125 | _, err = co.NewRequest("GET", "/pets", &SearchQuery{}) 126 | assert.ErrorIs(t, err, ErrUnsupportedType) 127 | } 128 | 129 | // See hybridcoder_test.go for more details. 
130 | func TestDirectiveQuery_NewRequest_WithTextMarshaler(t *testing.T) { 131 | type SearchQuery struct { 132 | L0 *Location `in:"query=l0"` 133 | L2 *LocationImplementedTextMarshaler `in:"query=l2"` 134 | Radius int `in:"query=radius"` 135 | } 136 | 137 | query := &SearchQuery{ 138 | L0: &Location{ 139 | Latitude: 1.234, 140 | Longitude: 5.678, 141 | }, 142 | L2: &LocationImplementedTextMarshaler{ 143 | Latitude: 1.234, 144 | Longitude: 5.678, 145 | }, 146 | Radius: 1000, 147 | } 148 | 149 | co, err := New(SearchQuery{}) 150 | assert.NoError(t, err) 151 | req, err := co.NewRequest("GET", "/pets", query) 152 | assert.NoError(t, err) 153 | 154 | expected, _ := http.NewRequest("GET", "/pets", nil) 155 | expectedQuery := make(url.Values) 156 | expectedQuery.Set("l0", "1.234000,5.678000") 157 | expectedQuery.Set("l2", "MarshalText:1.234000,5.678000") 158 | expectedQuery.Set("radius", "1000") 159 | expected.URL.RawQuery = expectedQuery.Encode() 160 | assert.Equal(t, expected, req) 161 | } 162 | -------------------------------------------------------------------------------- /core/registry.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "reflect" 5 | 6 | "github.com/ggicci/httpin/internal" 7 | ) 8 | 9 | type AnyStringableAdaptor = internal.AnyStringableAdaptor 10 | 11 | var ( 12 | fileTypes = make(map[reflect.Type]struct{}) 13 | customStringableAdaptors = make(map[reflect.Type]AnyStringableAdaptor) 14 | namedStringableAdaptors = make(map[string]*NamedAnyStringableAdaptor) 15 | ) 16 | 17 | // RegisterCoder registers a custom coder for the given type T. When a field of 18 | // type T is encountered, this coder will be used to convert the value to a 19 | // Stringable, which will be used to convert the value from/to string. 20 | // 21 | // NOTE: this function is designed to override the default Stringable adaptors 22 | // that are registered by this package. 
For example, if you want to override the 23 | // defualt behaviour of converting a bool value from/to string, you can do this: 24 | // 25 | // func init() { 26 | // core.RegisterCoder[bool](func(b *bool) (core.Stringable, error) { 27 | // return (*YesNo)(b), nil 28 | // }) 29 | // } 30 | // 31 | // type YesNo bool 32 | // 33 | // func (yn YesNo) String() string { 34 | // if yn { 35 | // return "yes" 36 | // } 37 | // return "no" 38 | // } 39 | // 40 | // func (yn *YesNo) FromString(s string) error { 41 | // switch s { 42 | // case "yes": 43 | // *yn = true 44 | // case "no": 45 | // *yn = false 46 | // default: 47 | // return fmt.Errorf("invalid YesNo value: %q", s) 48 | // } 49 | // return nil 50 | // } 51 | func RegisterCoder[T any](adapt func(*T) (Stringable, error)) { 52 | customStringableAdaptors[internal.TypeOf[T]()] = internal.NewAnyStringableAdaptor[T](adapt) 53 | } 54 | 55 | // RegisterNamedCoder works similar to RegisterCoder, except that it binds the 56 | // coder to a name. This is useful when you only want to override the types in 57 | // a specific struct field. You will be using the "coder" or "decoder" directive 58 | // to specify the name of the coder to use. 
For example: 59 | // 60 | // type MyStruct struct { 61 | // Bool bool // use default bool coder 62 | // YesNo bool `in:"coder=yesno"` // use YesNo coder 63 | // } 64 | // 65 | // func init() { 66 | // core.RegisterNamedCoder[bool]("yesno", func(b *bool) (core.Stringable, error) { 67 | // return (*YesNo)(b), nil 68 | // }) 69 | // } 70 | // 71 | // type YesNo bool 72 | // 73 | // func (yn YesNo) String() string { 74 | // if yn { 75 | // return "yes" 76 | // } 77 | // return "no" 78 | // } 79 | // 80 | // func (yn *YesNo) FromString(s string) error { 81 | // switch s { 82 | // case "yes": 83 | // *yn = true 84 | // case "no": 85 | // *yn = false 86 | // default: 87 | // return fmt.Errorf("invalid YesNo value: %q", s) 88 | // } 89 | // return nil 90 | // } 91 | func RegisterNamedCoder[T any](name string, adapt func(*T) (Stringable, error)) { 92 | namedStringableAdaptors[name] = &NamedAnyStringableAdaptor{ 93 | Name: name, 94 | BaseType: internal.TypeOf[T](), 95 | Adapt: internal.NewAnyStringableAdaptor[T](adapt), 96 | } 97 | } 98 | 99 | // RegisterFileCoder registers the given type T as a file type. T must implement 100 | // the Fileable interface. Remember if you don't register the type explicitly, 101 | // it won't be recognized as a file type. 
102 | func RegisterFileCoder[T Fileable]() error { 103 | fileTypes[internal.TypeOf[T]()] = struct{}{} 104 | return nil 105 | } 106 | 107 | type NamedAnyStringableAdaptor struct { 108 | Name string 109 | BaseType reflect.Type 110 | Adapt AnyStringableAdaptor 111 | } 112 | 113 | func isFileType(typ reflect.Type) bool { 114 | baseType, _ := BaseTypeOf(typ) 115 | _, ok := fileTypes[baseType] 116 | return ok 117 | } 118 | -------------------------------------------------------------------------------- /core/requestbuilder.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "mime/multipart" 9 | "net/http" 10 | "net/url" 11 | "path/filepath" 12 | "strings" 13 | ) 14 | 15 | type RequestBuilder struct { 16 | Query url.Values 17 | Form url.Values 18 | Attachment map[string][]FileMarshaler 19 | Header http.Header 20 | Cookie []*http.Cookie 21 | Path map[string]string // placeholder: value 22 | BodyType string // json, xml, etc. 23 | Body io.ReadCloser 24 | ctx context.Context 25 | } 26 | 27 | func NewRequestBuilder(ctx context.Context) *RequestBuilder { 28 | return &RequestBuilder{ 29 | Query: make(url.Values), 30 | Form: make(url.Values), 31 | Attachment: make(map[string][]FileMarshaler), 32 | Header: make(http.Header), 33 | Cookie: make([]*http.Cookie, 0), 34 | Path: make(map[string]string), 35 | ctx: ctx, 36 | } 37 | } 38 | 39 | func (rb *RequestBuilder) Populate(req *http.Request) error { 40 | if err := rb.validate(); err != nil { 41 | return err 42 | } 43 | 44 | // Populate the querystring. 45 | req.URL.RawQuery = rb.Query.Encode() 46 | 47 | // Populate forms. 48 | if rb.hasForm() { 49 | if rb.hasAttachment() { // multipart form 50 | if err := rb.populateMultipartForm(req); err != nil { 51 | return err 52 | } 53 | } else { // urlencoded form 54 | rb.populateForm(req) 55 | } 56 | } 57 | 58 | // Populate body. 
59 | if rb.hasBody() { 60 | req.Body = rb.Body 61 | rb.Header.Set("Content-Type", rb.bodyContentType()) 62 | } 63 | 64 | // Populate path. 65 | if rb.hasPath() { 66 | newPath := req.URL.Path 67 | for key, value := range rb.Path { 68 | newPath = strings.Replace(newPath, "{"+key+"}", value, -1) 69 | } 70 | req.URL.Path = newPath 71 | req.URL.RawPath = "" 72 | } 73 | 74 | // Populate the headers. 75 | if rb.Header != nil { 76 | req.Header = rb.Header 77 | } 78 | 79 | // Populate the cookies. 80 | for _, cookie := range rb.Cookie { 81 | req.AddCookie(cookie) 82 | } 83 | 84 | return nil 85 | } 86 | 87 | func (rb *RequestBuilder) SetQuery(key string, value []string) { 88 | rb.Query[key] = value 89 | } 90 | 91 | func (rb *RequestBuilder) SetForm(key string, value []string) { 92 | rb.Form[key] = value 93 | } 94 | 95 | func (rb *RequestBuilder) SetHeader(key string, value []string) { 96 | rb.Header[http.CanonicalHeaderKey(key)] = value 97 | } 98 | 99 | func (rb *RequestBuilder) SetPath(key string, value []string) { 100 | if len(value) > 0 { 101 | rb.Path[key] = value[0] 102 | } 103 | } 104 | 105 | func (rb *RequestBuilder) SetBody(bodyType string, bodyReader io.ReadCloser) { 106 | rb.BodyType = bodyType 107 | rb.Body = bodyReader 108 | } 109 | 110 | func (rb *RequestBuilder) SetAttachment(key string, files []FileMarshaler) { 111 | rb.Attachment[key] = files 112 | } 113 | 114 | func (rb *RequestBuilder) bodyContentType() string { 115 | switch rb.BodyType { 116 | case "json": 117 | return "application/json" 118 | case "xml": 119 | return "application/xml" 120 | } 121 | return "" 122 | } 123 | 124 | func (rb *RequestBuilder) validate() error { 125 | if rb.hasForm() && rb.hasBody() { 126 | return errors.New("cannot use both form and body directive at the same time") 127 | } 128 | return nil 129 | } 130 | 131 | func (rb *RequestBuilder) hasPath() bool { 132 | return len(rb.Path) > 0 133 | } 134 | 135 | func (rb *RequestBuilder) hasForm() bool { 136 | return len(rb.Form) > 0 || 
rb.hasAttachment() 137 | } 138 | 139 | func (rb *RequestBuilder) hasAttachment() bool { 140 | return len(rb.Attachment) > 0 141 | } 142 | 143 | func (rb *RequestBuilder) hasBody() bool { 144 | return rb.Body != nil && rb.BodyType != "" 145 | } 146 | 147 | func (rb *RequestBuilder) populateForm(req *http.Request) { 148 | rb.Header.Set("Content-Type", "application/x-www-form-urlencoded") 149 | formData := rb.Form.Encode() 150 | req.Body = io.NopCloser(strings.NewReader(formData)) 151 | } 152 | 153 | func (rb *RequestBuilder) populateMultipartForm(req *http.Request) error { 154 | // Create a pipe and a multipart writer. 155 | pr, pw := io.Pipe() 156 | writer := multipart.NewWriter(pw) 157 | 158 | // Write the multipart form data to the pipe in a separate goroutine. 159 | go func() { 160 | defer pw.Close() 161 | defer writer.Close() 162 | 163 | // Populate the form fields. 164 | for k, v := range rb.Form { 165 | for _, sv := range v { 166 | select { 167 | case <-rb.ctx.Done(): 168 | pw.CloseWithError(rb.ctx.Err()) 169 | return 170 | default: 171 | fieldWriter, _ := writer.CreateFormField(k) 172 | fieldWriter.Write([]byte(sv)) 173 | } 174 | } 175 | } 176 | 177 | // Populate the attachments. 
178 | for key, files := range rb.Attachment { 179 | for i, file := range files { 180 | select { 181 | case <-rb.ctx.Done(): 182 | pw.CloseWithError(rb.ctx.Err()) 183 | return 184 | default: 185 | filename := file.Filename() 186 | contentReader, err := file.MarshalFile() 187 | filename = normalizeUploadFilename(key, filename, i) 188 | 189 | if err != nil { 190 | pw.CloseWithError(fmt.Errorf("upload %s %q failed: %w", key, filename, err)) 191 | return 192 | } 193 | 194 | fileWriter, _ := writer.CreateFormFile(key, filename) 195 | if _, err = io.Copy(fileWriter, contentReader); err != nil { 196 | pw.CloseWithError(fmt.Errorf("upload %s %q failed: %w", key, filename, err)) 197 | return 198 | } 199 | } 200 | } 201 | } 202 | }() 203 | 204 | // Set the body to the read end of the pipe and the content type. 205 | req.Body = io.NopCloser(pr) 206 | rb.Header.Set("Content-Type", writer.FormDataContentType()) 207 | return nil 208 | } 209 | 210 | func normalizeUploadFilename(key, filename string, index int) string { 211 | if filename == "" { 212 | return fmt.Sprintf("%s_%d", key, index) 213 | } 214 | return filepath.Base(filename) 215 | } 216 | -------------------------------------------------------------------------------- /core/requestbuilder_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "net/http" 7 | "strings" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | // FIX: https://github.com/ggicci/httpin/issues/88 15 | // Impossible to make streaming 16 | func TestIssue88_RequestBuilderFileUploadStreaming(t *testing.T) { 17 | rb := NewRequestBuilder(context.Background()) 18 | 19 | var contentReader = strings.NewReader("hello world") 20 | rb.SetAttachment("file", []FileMarshaler{ 21 | UploadStream(io.NopCloser(contentReader)), 22 | }) 23 | req, _ := http.NewRequest("GET", "/", nil) 24 | rb.Populate(req) 25 | 26 | err := 
req.ParseMultipartForm(32 << 20) 27 | assert.NoError(t, err) 28 | 29 | file, fh, err := req.FormFile("file") 30 | assert.NoError(t, err) 31 | assert.Equal(t, "file_0", fh.Filename) 32 | content, err := io.ReadAll(file) 33 | assert.NoError(t, err) 34 | assert.Equal(t, "hello world", string(content)) 35 | } 36 | 37 | func TestIssue88_CancelStreaming(t *testing.T) { 38 | ctx, cancel := context.WithCancel(context.Background()) 39 | 40 | rb := NewRequestBuilder(ctx) 41 | var contentReader = strings.NewReader("hello world") 42 | rb.SetAttachment("file", []FileMarshaler{ 43 | UploadStream(io.NopCloser(contentReader)), 44 | }) 45 | req, _ := http.NewRequest("GET", "/", nil) 46 | rb.Populate(req) 47 | 48 | cancel() 49 | time.Sleep(time.Millisecond * 100) 50 | 51 | err := req.ParseMultipartForm(32 << 20) 52 | assert.ErrorContains(t, err, "context canceled") 53 | } 54 | -------------------------------------------------------------------------------- /core/required.go: -------------------------------------------------------------------------------- 1 | // directive: "required" 2 | // https://ggicci.github.io/httpin/directives/required 3 | 4 | package core 5 | 6 | import "errors" 7 | 8 | // DirectiveRequired implements the "required" executor who indicates that the field must be set. 9 | // If the field value were not set by former executors, errMissingField will be 10 | // returned. 11 | // 12 | // NOTE: the "required" executor does not check the value of the field, it only checks 13 | // if the field is set. In realcases, it's used to require that the key is present in 14 | // the input data, e.g. form, header, etc. But it allows the value to be empty. 
15 | type DirectiveRequired struct{} 16 | 17 | func (*DirectiveRequired) Decode(rtm *DirectiveRuntime) error { 18 | if rtm.IsFieldSet() { 19 | return nil 20 | } 21 | return errors.New("missing required field") 22 | } 23 | 24 | func (*DirectiveRequired) Encode(rtm *DirectiveRuntime) error { 25 | if rtm.IsFieldSet() { 26 | return nil 27 | } 28 | return errors.New("missing required field") 29 | } 30 | -------------------------------------------------------------------------------- /core/required_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "io" 5 | "net/http" 6 | "net/url" 7 | "strings" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | type RequiredQuery struct { 15 | CreatedAt time.Time `in:"form=created_at;required"` 16 | Color string `in:"form=colour,color"` 17 | } 18 | 19 | func TestDirectiveRequired_Decode_RequiredFieldMissing(t *testing.T) { 20 | r, _ := http.NewRequest("GET", "/", nil) 21 | r.Form = url.Values{ 22 | "color": {"red"}, 23 | } 24 | co, err := New(&RequiredQuery{}) 25 | assert.NoError(t, err) 26 | _, err = co.Decode(r) 27 | assert.ErrorContains(t, err, "missing required field") 28 | var invalidField *InvalidFieldError 29 | assert.ErrorAs(t, err, &invalidField) 30 | assert.Equal(t, "CreatedAt", invalidField.Field) 31 | assert.Equal(t, "required", invalidField.Directive) 32 | assert.Empty(t, invalidField.Key) 33 | assert.Nil(t, invalidField.Value) 34 | } 35 | 36 | func TestDirectiveRequired_Decode_NonRequiredFieldAbsent(t *testing.T) { 37 | r, _ := http.NewRequest("GET", "/", nil) 38 | r.Form = url.Values{ 39 | "created_at": {"1991-11-10T08:00:00+08:00"}, 40 | "is_soldout": {"true"}, 41 | "page": {"1"}, 42 | "per_page": {"20"}, 43 | } 44 | expected := &RequiredQuery{ 45 | CreatedAt: time.Date(1991, 11, 10, 0, 0, 0, 0, time.UTC), 46 | Color: "", 47 | } 48 | co, err := New(RequiredQuery{}) 49 | assert.NoError(t, err) 50 | got, err 
:= co.Decode(r) 51 | assert.NoError(t, err) 52 | assert.Equal(t, expected, got.(*RequiredQuery)) 53 | } 54 | 55 | func TestDirectiveRequired_NewRequest_RequiredFieldPresent(t *testing.T) { 56 | co, err := New(&RequiredQuery{}) 57 | assert.NoError(t, err) 58 | 59 | payload := &RequiredQuery{ 60 | CreatedAt: time.Date(1991, 11, 10, 0, 0, 0, 0, time.UTC), 61 | Color: "red", 62 | } 63 | expected, _ := http.NewRequest("GET", "/hello", nil) 64 | expectedForm := url.Values{ 65 | "created_at": {"1991-11-10T00:00:00Z"}, 66 | "colour": {"red"}, // NOTE: will use the first name in the tag 67 | } 68 | expected.Header.Set("Content-Type", "application/x-www-form-urlencoded") 69 | expected.Body = io.NopCloser(strings.NewReader(expectedForm.Encode())) 70 | req, err := co.NewRequest("GET", "/hello", payload) 71 | assert.NoError(t, err) 72 | assert.Equal(t, expected, req) 73 | } 74 | -------------------------------------------------------------------------------- /core/stringable.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "reflect" 7 | "time" 8 | 9 | "github.com/ggicci/httpin/internal" 10 | "github.com/ggicci/owl" 11 | ) 12 | 13 | type Stringable = internal.Stringable 14 | 15 | var ErrUnsupportedType = owl.ErrUnsupportedType 16 | 17 | func NewStringable(rv reflect.Value, adapt AnyStringableAdaptor) (stringable Stringable, err error) { 18 | if IsPatchField(rv.Type()) { 19 | stringable, err = NewStringablePatchFieldWrapper(rv, adapt) 20 | } else { 21 | stringable, err = newStringable(rv, adapt) 22 | } 23 | if err != nil { 24 | return nil, err 25 | } 26 | return stringable, nil 27 | } 28 | 29 | // Create a Stringable from a reflect.Value. If rv is a pointer type, it will 30 | // try to create a Stringable from rv. Otherwise, it will try to create a 31 | // Stringable from rv.Addr(). Only basic built-in types are supported. As a 32 | // special case, time.Time is also supported. 
33 | func newStringable(rv reflect.Value, adapt AnyStringableAdaptor) (Stringable, error) { 34 | rv, err := getPointer(rv) 35 | if err != nil { 36 | return nil, err 37 | } 38 | 39 | // Now rv is a pointer type. 40 | if adapt != nil { 41 | return adapt(rv.Interface()) 42 | } 43 | 44 | // Custom type adaptors go first. Which means the coder of a specific type 45 | // has already been registered/overridden by user. 46 | if adapt, ok := customStringableAdaptors[rv.Type().Elem()]; ok { 47 | return adapt(rv.Interface()) 48 | } 49 | 50 | // For the base type time.Time, it is a special case here. 51 | // We won't use TextMarshaler/TextUnmarshaler for time.Time. 52 | if rv.Type().Elem() == timeType { 53 | return internal.NewStringable(rv) 54 | } 55 | 56 | if hybridCoder := hybridizeCoder(rv); hybridCoder != nil { 57 | return hybridCoder, nil 58 | } 59 | 60 | // Fallback to use built-in stringable types. 61 | return internal.NewStringable(rv) 62 | } 63 | 64 | type StringablePatchFieldWrapper struct { 65 | Value reflect.Value // of patch.Field[T] 66 | internalStringable Stringable 67 | } 68 | 69 | func NewStringablePatchFieldWrapper(rv reflect.Value, adapt AnyStringableAdaptor) (*StringablePatchFieldWrapper, error) { 70 | stringable, err := NewStringable(rv.FieldByName("Value"), adapt) 71 | if err != nil { 72 | return &StringablePatchFieldWrapper{}, fmt.Errorf("cannot create Stringable for PatchField: %w", err) 73 | } else { 74 | return &StringablePatchFieldWrapper{ 75 | Value: rv, 76 | internalStringable: stringable, 77 | }, nil 78 | } 79 | } 80 | 81 | func (w *StringablePatchFieldWrapper) ToString() (string, error) { 82 | if w.Value.FieldByName("Valid").Bool() { 83 | return w.internalStringable.ToString() 84 | } else { 85 | return "", errors.New("invalid value") // when Valid is false 86 | } 87 | } 88 | 89 | // FromString sets the value of the wrapped patch.Field[T] from the given 90 | // string. It returns an error if the given string is not valid. 
And leaves the 91 | // original value of both Value and Valid unchanged. On the other hand, if no 92 | // error occurs, it sets Valid to true. 93 | func (w *StringablePatchFieldWrapper) FromString(s string) error { 94 | if err := w.internalStringable.FromString(s); err != nil { 95 | return err 96 | } else { 97 | w.Value.FieldByName("Valid").SetBool(true) 98 | return nil 99 | } 100 | } 101 | 102 | var timeType = internal.TypeOf[time.Time]() 103 | 104 | func getPointer(rv reflect.Value) (reflect.Value, error) { 105 | if rv.Kind() == reflect.Pointer { 106 | return createInstanceIfNil(rv), nil 107 | } else { 108 | return addressOf(rv) 109 | } 110 | } 111 | 112 | func createInstanceIfNil(rv reflect.Value) reflect.Value { 113 | if rv.IsNil() { 114 | rv.Set(reflect.New(rv.Type().Elem())) 115 | } 116 | return rv 117 | } 118 | 119 | func addressOf(rv reflect.Value) (reflect.Value, error) { 120 | if !rv.CanAddr() { 121 | return rv, fmt.Errorf("cannot get address of value %v", rv) 122 | } 123 | rv = rv.Addr() 124 | return rv, nil 125 | } 126 | -------------------------------------------------------------------------------- /core/stringable_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | "time" 8 | 9 | "github.com/ggicci/httpin/internal" 10 | "github.com/ggicci/httpin/patch" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | type Point2D struct { 15 | X int 16 | Y int 17 | } 18 | 19 | func (p Point2D) ToString() (string, error) { 20 | return fmt.Sprintf("Point2D(%d,%d)", p.X, p.Y), nil 21 | } 22 | 23 | func (p *Point2D) FromString(s string) error { 24 | _, err := fmt.Sscanf(s, "Point2D(%d,%d)", &p.X, &p.Y) 25 | return err 26 | } 27 | 28 | type MyStruct struct { 29 | Name string 30 | NamePointer *string 31 | Names []string 32 | PatchName patch.Field[string] 33 | PatchNames patch.Field[[]string] 34 | 35 | Age int 36 | AgePointer *int 37 | Ages []int 38 | 
PatchAge patch.Field[int] 39 | PatchAges patch.Field[[]int] 40 | 41 | Dot Point2D 42 | DotPointer *Point2D 43 | Dots []Point2D 44 | PatchDot patch.Field[Point2D] 45 | PatchDots patch.Field[[]Point2D] 46 | } 47 | 48 | type MyDate time.Time // adapted from time.Time 49 | 50 | func (t MyDate) ToString() (string, error) { 51 | return time.Time(t).Format("2006-01-02"), nil 52 | } 53 | 54 | func (t *MyDate) FromString(value string) error { 55 | v, err := time.Parse("2006-01-02", value) 56 | if err != nil { 57 | return &InvalidDate{Value: value, Err: err} 58 | } 59 | *t = MyDate(v) 60 | return nil 61 | } 62 | 63 | func TestStringable_FromString(t *testing.T) { 64 | rv := reflect.New(reflect.TypeOf(MyStruct{})).Elem() 65 | s := rv.Addr().Interface().(*MyStruct) 66 | 67 | // string 68 | testAssignString(t, rv.FieldByName("Name"), "Alice") 69 | testAssignString(t, rv.FieldByName("NamePointer"), "Charlie") 70 | testNewStringableErrUnsupported(t, rv.FieldByName("Names")) 71 | testAssignString(t, rv.FieldByName("PatchName"), "Bob") 72 | testNewStringableErrUnsupported(t, rv.FieldByName("PatchNames")) 73 | 74 | assert.Equal(t, "Alice", s.Name) 75 | assert.Equal(t, "Charlie", *s.NamePointer) 76 | assert.Equal(t, []string(nil), s.Names) 77 | assert.Equal(t, "Bob", s.PatchName.Value) 78 | assert.True(t, s.PatchName.Valid) 79 | 80 | // int 81 | testAssignString(t, rv.FieldByName("Age"), "18") 82 | testAssignString(t, rv.FieldByName("AgePointer"), "20") 83 | testNewStringableErrUnsupported(t, rv.FieldByName("Ages")) 84 | testAssignString(t, rv.FieldByName("PatchAge"), "18") 85 | testNewStringableErrUnsupported(t, rv.FieldByName("PatchAges")) 86 | 87 | assert.Equal(t, 18, s.Age) 88 | assert.Equal(t, 20, *s.AgePointer) 89 | assert.Equal(t, []int(nil), s.Ages) 90 | assert.Equal(t, 18, s.PatchAge.Value) 91 | assert.True(t, s.PatchAge.Valid) 92 | 93 | // Point2D 94 | testAssignString(t, rv.FieldByName("Dot"), "Point2D(1,2)") 95 | testAssignString(t, rv.FieldByName("DotPointer"), 
"Point2D(3,4)") 96 | testNewStringableErrUnsupported(t, rv.FieldByName("Dots")) 97 | testAssignString(t, rv.FieldByName("PatchDot"), "Point2D(5,6)") 98 | testNewStringableErrUnsupported(t, rv.FieldByName("PatchDots")) 99 | 100 | assert.Equal(t, Point2D{1, 2}, s.Dot) 101 | assert.Equal(t, &Point2D{3, 4}, s.DotPointer) 102 | assert.Equal(t, []Point2D(nil), s.Dots) 103 | assert.Equal(t, Point2D{5, 6}, s.PatchDot.Value) 104 | assert.True(t, s.PatchDot.Valid) 105 | } 106 | 107 | func TestStringable_String(t *testing.T) { 108 | var s = &MyStruct{ 109 | Name: "Alice", 110 | NamePointer: internal.Pointerize[string]("Charlie"), 111 | Names: []string{"Alice", "Bob", "Charlie"}, 112 | PatchName: patch.Field[string]{Value: "Bob", Valid: true}, 113 | PatchNames: patch.Field[[]string]{Value: []string{"Alice", "Bob", "Charlie"}, Valid: true}, 114 | 115 | Age: 18, 116 | AgePointer: internal.Pointerize[int](20), 117 | Ages: []int{18, 20}, 118 | PatchAge: patch.Field[int]{Value: 18, Valid: true}, 119 | PatchAges: patch.Field[[]int]{Value: []int{18, 20}, Valid: true}, 120 | 121 | Dot: Point2D{1, 2}, 122 | DotPointer: internal.Pointerize[Point2D](Point2D{3, 4}), 123 | Dots: []Point2D{{1, 2}, {3, 4}}, 124 | PatchDot: patch.Field[Point2D]{Value: Point2D{5, 6}, Valid: true}, 125 | PatchDots: patch.Field[[]Point2D]{Value: []Point2D{{1, 2}, {3, 4}}, Valid: true}, 126 | } 127 | 128 | rv := reflect.ValueOf(s).Elem() 129 | 130 | assert.Equal(t, "Alice", testGetString(t, rv.FieldByName("Name"))) 131 | assert.Equal(t, "Charlie", testGetString(t, rv.FieldByName("NamePointer"))) 132 | testNewStringableErrUnsupported(t, rv.FieldByName("Names")) 133 | assert.Equal(t, "Bob", testGetString(t, rv.FieldByName("PatchName"))) 134 | testNewStringableErrUnsupported(t, rv.FieldByName("PatchNames")) 135 | 136 | assert.Equal(t, "18", testGetString(t, rv.FieldByName("Age"))) 137 | assert.Equal(t, "20", testGetString(t, rv.FieldByName("AgePointer"))) 138 | testNewStringableErrUnsupported(t, 
rv.FieldByName("Ages")) 139 | assert.Equal(t, "18", testGetString(t, rv.FieldByName("PatchAge"))) 140 | testNewStringableErrUnsupported(t, rv.FieldByName("PatchAges")) 141 | 142 | assert.Equal(t, "Point2D(1,2)", testGetString(t, rv.FieldByName("Dot"))) 143 | assert.Equal(t, "Point2D(3,4)", testGetString(t, rv.FieldByName("DotPointer"))) 144 | testNewStringableErrUnsupported(t, rv.FieldByName("Dots")) 145 | assert.Equal(t, "Point2D(5,6)", testGetString(t, rv.FieldByName("PatchDot"))) 146 | testNewStringableErrUnsupported(t, rv.FieldByName("PatchDots")) 147 | } 148 | 149 | func TestStringablePatchFieldWrapper_String(t *testing.T) { 150 | var patchString = patch.Field[string]{Value: "Alice", Valid: true} 151 | rv := reflect.ValueOf(&patchString).Elem() 152 | assert.True(t, IsPatchField(rv.Type())) 153 | stringable, err := NewStringablePatchFieldWrapper(rv, nil) 154 | assert.NoError(t, err) 155 | 156 | sv, err := stringable.ToString() 157 | assert.NoError(t, err) 158 | assert.Equal(t, "Alice", sv) 159 | 160 | patchString.Valid = false // make it invalid 161 | sv, err = stringable.ToString() 162 | assert.ErrorContains(t, err, "invalid value") 163 | assert.Empty(t, sv, "invalid patch field should return empty string") 164 | } 165 | 166 | func TestStringablePatchFieldWrapper_FromString(t *testing.T) { 167 | // string 168 | var patchString = patch.Field[string]{} 169 | 170 | assert.Empty(t, patchString.Value) 171 | assert.False(t, patchString.Valid) 172 | 173 | rv := reflect.ValueOf(&patchString).Elem() 174 | assert.True(t, IsPatchField(rv.Type())) 175 | stringable, err := NewStringablePatchFieldWrapper(rv, nil) 176 | assert.NoError(t, err) 177 | assert.NoError(t, stringable.FromString("Alice")) 178 | assert.Equal(t, "Alice", patchString.Value) 179 | assert.True(t, patchString.Valid, "Valid should be set to true after a succssful FromString") 180 | 181 | // int 182 | var patchInt = patch.Field[int]{} 183 | rv = reflect.ValueOf(&patchInt).Elem() 184 | assert.True(t, 
IsPatchField(rv.Type())) 185 | stringable, err = NewStringablePatchFieldWrapper(rv, nil) 186 | assert.NoError(t, err) 187 | assert.Error(t, stringable.FromString("Alice")) // cannot convert "Alice" to int 188 | assert.Zero(t, patchInt.Value, "Value should not be changed after a failed FromString") 189 | assert.False(t, patchInt.Valid, "Valid should not be changed after a failed FromString") 190 | 191 | assert.NoError(t, stringable.FromString("18")) 192 | assert.Equal(t, 18, patchInt.Value) 193 | assert.True(t, patchInt.Valid, "Valid should be set to true after a succssful FromString") 194 | 195 | assert.Error(t, stringable.FromString("18.0")) // cannot convert "18.0" to int 196 | assert.Equal(t, 18, patchInt.Value, "Value should not be changed after a failed FromString") 197 | assert.True(t, patchInt.Valid, "Valid should not be changed after a failed FromString") 198 | } 199 | 200 | func TestStringable_WithAdaptor(t *testing.T) { 201 | adapt := func(t *time.Time) (Stringable, error) { return (*MyDate)(t), nil } 202 | var now = time.Now() 203 | rvTimePointer := reflect.ValueOf(&now) 204 | 205 | coder, err := NewStringable(rvTimePointer, internal.NewAnyStringableAdaptor[time.Time](adapt)) 206 | assert.NoError(t, err) 207 | assert.NoError(t, coder.FromString("1991-11-10")) 208 | 209 | s, err := coder.ToString() 210 | assert.NoError(t, err) 211 | assert.Equal(t, "1991-11-10", s) 212 | 213 | assert.ErrorContains(t, coder.FromString("1991-11-10T08:00:00+08:00"), "parsing time") 214 | } 215 | 216 | type InvalidDate struct { 217 | Value string 218 | Err error 219 | } 220 | 221 | func (e *InvalidDate) Error() string { 222 | return fmt.Sprintf("invalid date: %q (date must conform to format \"2006-01-02\"), %s", e.Value, e.Err) 223 | } 224 | 225 | func (e *InvalidDate) Unwrap() error { 226 | return e.Err 227 | } 228 | 229 | func testAssignString(t *testing.T, rv reflect.Value, value string) { 230 | s, err := NewStringable(rv, nil) 231 | assert.NoError(t, err) 232 | 
assert.NoError(t, s.FromString(value)) 233 | } 234 | 235 | func testNewStringableErrUnsupported(t *testing.T, rv reflect.Value) { 236 | s, err := NewStringable(rv, nil) 237 | assert.ErrorIs(t, err, ErrUnsupportedType) 238 | assert.Nil(t, s) 239 | } 240 | 241 | func testGetString(t *testing.T, rv reflect.Value) string { 242 | coder, err := NewStringable(rv, nil) 243 | assert.NoError(t, err) 244 | s, err := coder.ToString() 245 | assert.NoError(t, err) 246 | return s 247 | } 248 | -------------------------------------------------------------------------------- /core/stringslicable.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "reflect" 7 | 8 | "github.com/ggicci/httpin/internal" 9 | ) 10 | 11 | type StringSlicable interface { 12 | ToStringSlice() ([]string, error) 13 | FromStringSlice([]string) error 14 | } 15 | 16 | func NewStringSlicable(rv reflect.Value, adapt AnyStringableAdaptor) (StringSlicable, error) { 17 | if rv.Type().Implements(stringSliceableType) && rv.CanInterface() { 18 | return rv.Interface().(StringSlicable), nil 19 | } 20 | 21 | if IsPatchField(rv.Type()) { 22 | return NewStringSlicablePatchFieldWrapper(rv, adapt) 23 | } 24 | 25 | if isSliceType(rv.Type()) && !isByteSliceType(rv.Type()) { 26 | return NewStringableSliceWrapper(rv, adapt) 27 | } else { 28 | return NewStringSlicableSingleStringableWrapper(rv, adapt) 29 | } 30 | } 31 | 32 | // StringSlicablePatchFieldWrapper wraps a patch.Field[T] to implement 33 | // StringSlicable. The wrapped reflect.Value must be a patch.Field[T]. 34 | // 35 | // It works like a proxy. It delegates the ToStringSlice and FromStringSlice 36 | // calls to the internal StringSlicable. 
37 | type StringSlicablePatchFieldWrapper struct { 38 | Value reflect.Value // of patch.Field[T] 39 | internalStringSliceable StringSlicable // adapts the "Value" field of the patch.Field 40 | } 41 | 42 | // NewStringSlicablePatchFieldWrapper creates a StringSlicablePatchFieldWrapper from rv, 43 | // which is expected to hold a patch.Field[T]. It returns the error from 44 | // NewStringSlicable when rv's "Value" field cannot be adapted to a StringSlicable. 45 | func NewStringSlicablePatchFieldWrapper(rv reflect.Value, adapt AnyStringableAdaptor) (*StringSlicablePatchFieldWrapper, error) { 46 | stringSlicable, err := NewStringSlicable(rv.FieldByName("Value"), adapt) 47 | if err != nil { 48 | return nil, err 49 | } 50 | return &StringSlicablePatchFieldWrapper{ 51 | Value: rv, 52 | internalStringSliceable: stringSlicable, 53 | }, nil 54 | } 55 | 56 | // ToStringSlice stringifies the wrapped patch field's value when its Valid flag 57 | // is set. When Valid is false it returns an empty, non-nil slice. 58 | func (w *StringSlicablePatchFieldWrapper) ToStringSlice() ([]string, error) { 59 | if !w.Value.FieldByName("Valid").Bool() { 60 | return []string{}, nil 61 | } 62 | return w.internalStringSliceable.ToStringSlice() 63 | } 64 | 65 | // FromStringSlice decodes values into the patch field's value and, only when 66 | // decoding succeeds, sets the patch field's Valid flag to true. 67 | func (w *StringSlicablePatchFieldWrapper) FromStringSlice(values []string) error { 68 | if err := w.internalStringSliceable.FromStringSlice(values); err != nil { 69 | return err 70 | } 71 | w.Value.FieldByName("Valid").SetBool(true) 72 | return nil 73 | } 74 | 75 | // StringableSlice adapts a slice of Stringable elements to StringSlicable. 76 | type StringableSlice []Stringable 77 | 78 | // ToStringSlice stringifies every element in order and stops at the first 79 | // element that fails. 80 | func (sa StringableSlice) ToStringSlice() ([]string, error) { 81 | values := make([]string, len(sa)) 82 | for i, s := range sa { 83 | value, err := s.ToString() 84 | if err != nil { 85 | return nil, fmt.Errorf("cannot stringify %v at index %d: %w", s, i, err) 86 | } 87 | values[i] = value 88 | } 89 | return values, nil 90 | } 91 | 92 | // FromStringSlice assigns values[i] to sa[i] in order and stops at the first 93 | // failure. It returns an error, rather than panicking with an index out of 94 | // range, when there are more values than elements to receive them. 95 | func (sa StringableSlice) FromStringSlice(values []string) error { 96 | if len(values) > len(sa) { 97 | return fmt.Errorf("cannot convert %d strings into %d stringables", len(values), len(sa)) 98 | } 99 | for i, s := range values { 100 | if err := sa[i].FromString(s); err != nil { 101 | return fmt.Errorf("cannot convert from string %q at index %d: %w", s, i, err) 102 | } 103 | } 104 | return nil 105 | } 106 | 
107 | // StringableSliceWrapper wraps a reflect.Value to implement 108 | // StringSlicable. The wrapped reflect.Value must be an addressable slice whose 109 | // elements can be adapted by NewStringable. 110 | type StringableSliceWrapper struct { 111 | Value reflect.Value 112 | Adapt AnyStringableAdaptor 113 | } 114 | 115 | // NewStringableSliceWrapper creates a StringableSliceWrapper from rv. 116 | // It returns an error only when rv is not addressable; whether each element can 117 | // be adapted to a Stringable is checked later by ToStringSlice/FromStringSlice. 118 | func NewStringableSliceWrapper(rv reflect.Value, adapt AnyStringableAdaptor) (*StringableSliceWrapper, error) { 119 | if !rv.CanAddr() { 120 | return nil, errors.New("unaddressable value") 121 | } 122 | return &StringableSliceWrapper{Value: rv, Adapt: adapt}, nil 123 | } 124 | 125 | // ToStringSlice adapts each element of the wrapped slice to a Stringable and 126 | // stringifies them in order. 127 | func (w *StringableSliceWrapper) ToStringSlice() ([]string, error) { 128 | stringables := make(StringableSlice, w.Value.Len()) 129 | for i := range stringables { 130 | elem := w.Value.Index(i) 131 | stringable, err := NewStringable(elem, w.Adapt) 132 | if err != nil { 133 | // Use %v: fmt prints the value held by a reflect.Value, and %q would 134 | // render non-string elements such as ints as quoted rune literals. 135 | return nil, fmt.Errorf("cannot create Stringable from %v at index %d: %w", elem, i, err) 136 | } 137 | stringables[i] = stringable 138 | } 139 | return stringables.ToStringSlice() 140 | } 141 | 142 | // FromStringSlice replaces the wrapped slice with a new slice of len(ss) 143 | // elements and decodes ss into it element by element. 144 | func (w *StringableSliceWrapper) FromStringSlice(ss []string) error { 145 | stringables := make(StringableSlice, len(ss)) 146 | w.Value.Set(reflect.MakeSlice(w.Value.Type(), len(ss), len(ss))) 147 | for i := range ss { 148 | stringable, err := NewStringable(w.Value.Index(i), w.Adapt) 149 | if err != nil { 150 | return fmt.Errorf("cannot create Stringable at index %d: %w", i, err) 151 | } 152 | stringables[i] = stringable 153 | } 154 | return stringables.FromStringSlice(ss) 155 | } 156 | 
157 | // StringSlicableSingleStringableWrapper wraps a reflect.Value to implement 158 | // StringSlicable. The wrapped reflect.Value must be a Stringable. 159 | type StringSlicableSingleStringableWrapper struct{ Stringable } 160 | 161 | // NewStringSlicableSingleStringableWrapper adapts rv with NewStringable so that 162 | // it can be used where a StringSlicable is expected. 163 | func NewStringSlicableSingleStringableWrapper(rv reflect.Value, adapt AnyStringableAdaptor) (*StringSlicableSingleStringableWrapper, error) { 164 | stringable, err := NewStringable(rv, adapt) 165 | if err != nil { 166 | return nil, err 167 | } 168 | return &StringSlicableSingleStringableWrapper{stringable}, nil 169 | } 170 | 171 | // ToStringSlice returns the wrapped value as a one-element slice. 172 | func (w *StringSlicableSingleStringableWrapper) ToStringSlice() ([]string, error) { 173 | value, err := w.ToString() 174 | if err != nil { 175 | return nil, err 176 | } 177 | return []string{value}, nil 178 | } 179 | 180 | // FromStringSlice decodes only values[0]; extra values are ignored and an 181 | // empty slice leaves the wrapped value untouched. 182 | func (w *StringSlicableSingleStringableWrapper) FromStringSlice(values []string) error { 183 | if len(values) > 0 { 184 | return w.FromString(values[0]) 185 | } 186 | return nil 187 | } 188 | 189 | var ( 190 | stringSliceableType = internal.TypeOf[StringSlicable]() 191 | byteType = internal.TypeOf[byte]() 192 | ) 193 | 194 | // isSliceType reports whether t is a slice or an array type. 195 | func isSliceType(t reflect.Type) bool { 196 | return t.Kind() == reflect.Slice || t.Kind() == reflect.Array 197 | } 198 | 199 | // isByteSliceType reports whether t is a slice or array whose element type is byte. 200 | func isByteSliceType(t reflect.Type) bool { 201 | return isSliceType(t) && t.Elem() == byteType 202 | } 203 | -------------------------------------------------------------------------------- /core/stringslicable_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/ggicci/httpin/internal" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestStringSlicable_FromStringSlice(t *testing.T) { 12 | rv := reflect.New(reflect.TypeOf(MyStruct{})).Elem() 13 | s := rv.Addr().Interface().(*MyStruct) 14 | 15 | testAssignStringSlice(t, rv.FieldByName("Name"), []string{"Alice"}) 16 | testAssignStringSlice(t, rv.FieldByName("NamePointer"), []string{"Charlie"}) 17 | testAssignStringSlice(t, rv.FieldByName("Names"),
[]string{"Alice", "Bob", "Charlie"}) 18 | 19 | testAssignStringSlice(t, rv.FieldByName("Age"), []string{"18"}) 20 | testAssignStringSlice(t, rv.FieldByName("AgePointer"), []string{"20"}) 21 | testAssignStringSlice(t, rv.FieldByName("Ages"), []string{"18", "20"}) 22 | 23 | assert.Equal(t, "Alice", s.Name) 24 | assert.Equal(t, "Charlie", *s.NamePointer) 25 | assert.Equal(t, []string{"Alice", "Bob", "Charlie"}, s.Names) 26 | 27 | assert.Equal(t, 18, s.Age) 28 | assert.Equal(t, 20, *s.AgePointer) 29 | assert.Equal(t, []int{18, 20}, s.Ages) 30 | } 31 | 32 | func TestStringSlicable_ToStringSlice(t *testing.T) { 33 | var s = &MyStruct{ 34 | Name: "Alice", 35 | NamePointer: internal.Pointerize[string]("Charlie"), 36 | Names: []string{"Alice", "Bob", "Charlie"}, 37 | 38 | Age: 18, 39 | AgePointer: internal.Pointerize[int](20), 40 | Ages: []int{18, 20}, 41 | } 42 | 43 | rv := reflect.ValueOf(s).Elem() 44 | assert.Equal(t, []string{"Alice"}, testGetStringSlice(t, rv.FieldByName("Name"))) 45 | assert.Equal(t, []string{"Charlie"}, testGetStringSlice(t, rv.FieldByName("NamePointer"))) 46 | assert.Equal(t, []string{"Alice", "Bob", "Charlie"}, testGetStringSlice(t, rv.FieldByName("Names"))) 47 | 48 | assert.Equal(t, []string{"18"}, testGetStringSlice(t, rv.FieldByName("Age"))) 49 | assert.Equal(t, []string{"20"}, testGetStringSlice(t, rv.FieldByName("AgePointer"))) 50 | assert.Equal(t, []string{"18", "20"}, testGetStringSlice(t, rv.FieldByName("Ages"))) 51 | } 52 | 53 | func testAssignStringSlice(t *testing.T, rv reflect.Value, values []string) { 54 | ss, err := NewStringSlicable(rv, nil) 55 | assert.NoError(t, err) 56 | assert.NoError(t, ss.FromStringSlice(values)) 57 | } 58 | 59 | func testGetStringSlice(t *testing.T, rv reflect.Value) []string { 60 | ss, err := NewStringSlicable(rv, nil) 61 | assert.NoError(t, err) 62 | values, err := ss.ToStringSlice() 63 | assert.NoError(t, err) 64 | return values 65 | } 66 |
-------------------------------------------------------------------------------- /core/typekind.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | ) 7 | 8 | // TypeKind describes how a field type is constructed from its base type. 9 | type TypeKind int 10 | 11 | const ( 12 | TypeKindT TypeKind = iota // T 13 | TypeKindTSlice // []T 14 | TypeKindPatchT // patch.Field[T] 15 | TypeKindPatchTSlice // patch.Field[[]T] 16 | ) 17 | 18 | // BaseTypeOf splits valueType into its base type and a TypeKind telling how 19 | // valueType is constructed from that base type: 20 | // - T -> T, TypeKindT 21 | // - []T -> T, TypeKindTSlice 22 | // - patch.Field[T] -> T, TypeKindPatchT 23 | // - patch.Field[[]T] -> T, TypeKindPatchTSlice 24 | // Only slices are unwrapped; any other type, arrays included, is reported as TypeKindT. 25 | func BaseTypeOf(valueType reflect.Type) (reflect.Type, TypeKind) { 26 | switch { 27 | case valueType.Kind() == reflect.Slice: 28 | return valueType.Elem(), TypeKindTSlice 29 | case IsPatchField(valueType): 30 | elemType, isSlice := patchFieldElemType(valueType) 31 | if isSlice { 32 | return elemType, TypeKindPatchTSlice 33 | } 34 | return elemType, TypeKindPatchT 35 | default: 36 | return valueType, TypeKindT 37 | } 38 | } 39 | 40 | // IsPatchField reports whether t is a struct type named "Field[...]" declared in 41 | // the github.com/ggicci/httpin/patch package, i.e. an instantiation of patch.Field. 42 | func IsPatchField(t reflect.Type) bool { 43 | if t.Kind() != reflect.Struct { 44 | return false 45 | } 46 | return t.PkgPath() == "github.com/ggicci/httpin/patch" && strings.HasPrefix(t.Name(), "Field[") 47 | } 48 | 49 | // patchFieldElemType inspects the "Value" field of the patch.Field type t. If 50 | // that field is a slice it returns the slice's element type and true; otherwise 51 | // it returns the field's type and false. 52 | func patchFieldElemType(t reflect.Type) (reflect.Type, bool) { 53 | valueField, _ := t.FieldByName("Value") 54 | fieldType := valueField.Type 55 | if fieldType.Kind() != reflect.Slice { 56 | return fieldType, false 57 | } 58 | return fieldType.Elem(), true 59 | } 60 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/ggicci/httpin 2 | 3 | go 1.23 4 | 5 | require ( 6 | github.com/ggicci/owl v0.8.2 7 | github.com/go-chi/chi/v5 v5.0.11 8 | github.com/gorilla/mux v1.8.1 9 | github.com/justinas/alice
v1.2.0 10 | github.com/labstack/echo/v4 v4.12.0 11 | github.com/stretchr/testify v1.9.0 12 | ) 13 | 14 | require ( 15 | github.com/davecgh/go-spew v1.1.1 // indirect 16 | github.com/labstack/gommon v0.4.2 // indirect 17 | github.com/mattn/go-colorable v0.1.13 // indirect 18 | github.com/mattn/go-isatty v0.0.20 // indirect 19 | github.com/pmezard/go-difflib v1.0.0 // indirect 20 | github.com/valyala/bytebufferpool v1.0.0 // indirect 21 | github.com/valyala/fasttemplate v1.2.2 // indirect 22 | golang.org/x/crypto v0.22.0 // indirect 23 | golang.org/x/net v0.24.0 // indirect 24 | golang.org/x/sys v0.19.0 // indirect 25 | golang.org/x/text v0.14.0 // indirect 26 | gopkg.in/yaml.v3 v3.0.1 // indirect 27 | ) 28 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/ggicci/owl v0.8.2 h1:og+lhqpzSMPDdEB+NJfzoAJARP7qCG3f8uUC3xvGukA= 4 | github.com/ggicci/owl v0.8.2/go.mod h1:PHRD57u41vFN5UtFz2SF79yTVoM3HlWpjMiE+ZU2dj4= 5 | github.com/go-chi/chi/v5 v5.0.11 h1:BnpYbFZ3T3S1WMpD79r7R5ThWX40TaFB7L31Y8xqSwA= 6 | github.com/go-chi/chi/v5 v5.0.11/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= 7 | github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= 8 | github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= 9 | github.com/justinas/alice v1.2.0 h1:+MHSA/vccVCF4Uq37S42jwlkvI2Xzl7zTPCN5BnZNVo= 10 | github.com/justinas/alice v1.2.0/go.mod h1:fN5HRH/reO/zrUflLfTN43t3vXvKzvZIENsNEe7i7qA= 11 | github.com/labstack/echo/v4 v4.12.0 h1:IKpw49IMryVB2p1a4dzwlhP1O2Tf2E0Ir/450lH+kI0= 12 | github.com/labstack/echo/v4 v4.12.0/go.mod h1:UP9Cr2DJXbOK3Kr9ONYzNowSh7HP0aG0ShAyycHSJvM= 13 | github.com/labstack/gommon v0.4.2 
h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= 14 | github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= 15 | github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= 16 | github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= 17 | github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= 18 | github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= 19 | github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= 20 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 21 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 22 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 23 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 24 | github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= 25 | github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= 26 | github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= 27 | github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= 28 | golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= 29 | golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= 30 | golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= 31 | golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= 32 | golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 33 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 34 | golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= 35 | golang.org/x/sys 
v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 36 | golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= 37 | golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= 38 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 39 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 40 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 41 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 42 | -------------------------------------------------------------------------------- /httpin.go: -------------------------------------------------------------------------------- 1 | // Package httpin helps decode an HTTP request into a custom struct by binding 2 | // data with querystring (query params), HTTP headers, form data, JSON/XML 3 | // payloads, URL path params, and file uploads (multipart/form-data). 4 | package httpin 5 | 6 | import ( 7 | "context" 8 | "errors" 9 | "io" 10 | "net/http" 11 | "reflect" 12 | 13 | "github.com/ggicci/httpin/core" 14 | "github.com/ggicci/httpin/internal" 15 | ) 16 | // contextKey is the unexported type of the context keys defined by this package. 17 | type contextKey int 18 | 19 | const ( 20 | // Input is the context key under which the middleware created by NewInput stores the decoded input object, e.g. 21 | // 22 | // input := r.Context().Value(httpin.Input).(*InputStruct) 23 | Input contextKey = iota 24 | ) 25 | 26 | // Option is a collection of option constructors for creating a Core instance, e.g. Option.WithMaxMemory(n). 27 | var Option coreOptions = coreOptions{ 28 | WithErrorHandler: core.WithErrorHandler, 29 | WithMaxMemory: core.WithMaxMemory, 30 | WithNestedDirectivesEnabled: core.WithNestedDirectivesEnabled, 31 | } 32 | 33 | // New calls core.New to create a new Core instance, which is responsible for both: 34 | // 35 | // - decoding an HTTP request to an instance of the inputStruct; 36 | // - and encoding an instance of the inputStruct to an HTTP request.
37 | // 38 | // Note that the Core instance is bound to the given specific type; it will not 39 | // work for other types. If you want to decode/encode other types, you need to 40 | // create another Core instance. Or directly use the following functions, which are 41 | // just shortcuts of Core's methods, so you don't need to create a Core instance: 42 | // - httpin.Decode(): decode an HTTP request to an instance of the inputStruct. 43 | // - httpin.NewRequest(): encode an instance of the inputStruct to an HTTP request. 44 | // 45 | // As a best practice, we recommend using httpin.NewInput() to create an 46 | // HTTP middleware for a specific input type. The middleware can be bound to an 47 | // API, chained with other middlewares, and also reused in other APIs. You don't 48 | // even need to call the Decode() method explicitly; the middleware will do it 49 | // for you and put the decoded instance into the request's context. 50 | func New(inputStruct any, opts ...core.Option) (*core.Core, error) { 51 | return core.New(inputStruct, opts...) 52 | } 53 | 54 | // File is the builtin type of httpin to manipulate file uploads. On the server 55 | // side, it is used to represent a file in a multipart/form-data request. On the 56 | // client side, it is used to represent a file to be uploaded. 57 | type File = core.File 58 | 59 | // UploadFile is a helper function to create a File instance from a file path. 60 | // It is useful when you want to upload a file from the local file system. 61 | func UploadFile(path string) *File { 62 | return core.UploadFile(path) 63 | } 64 | 65 | // UploadStream is a helper function to create a File instance from an io.ReadCloser. It 66 | // is useful when you want to upload a file from a stream. 67 | func UploadStream(r io.ReadCloser) *File { 68 | return core.UploadStream(r) 69 | } 70 | 71 | // DecodeTo decodes an HTTP request and populates input with data from the HTTP request. 72 | // The input must be a pointer to a struct instance.
For example: 73 | // 74 | // input := &InputStruct{} 75 | // if err := DecodeTo(req, input); err != nil { ... } 76 | // 77 | // input is now populated with data from the request. 78 | func DecodeTo(req *http.Request, input any, opts ...core.Option) error { 79 | co, err := New(internal.DereferencedType(input), opts...) 80 | if err != nil { 81 | return err 82 | } 83 | return co.DecodeTo(req, input) 84 | } 85 | 86 | // Decode decodes an HTTP request to an instance of T and returns its pointer 87 | // (*T). T must be a struct type. For example: 88 | // 89 | // if user, err := Decode[User](req); err != nil { ... } 90 | // // now user is a *User instance, which has been populated with data from the request. 91 | func Decode[T any](req *http.Request, opts ...core.Option) (*T, error) { 92 | rt := internal.TypeOf[T]() 93 | if rt.Kind() != reflect.Struct { 94 | return nil, errors.New("generic type T must be a struct type") 95 | } 96 | co, err := New(rt, opts...) 97 | if err != nil { 98 | return nil, err 99 | } 100 | if v, err := co.Decode(req); err != nil { 101 | return nil, err 102 | } else { 103 | return v.(*T), nil 104 | } 105 | } 106 | 107 | // NewRequest wraps NewRequestWithContext using context.Background(), see NewRequestWithContext. 108 | func NewRequest(method, url string, input any, opts ...core.Option) (*http.Request, error) { 109 | return NewRequestWithContext(context.Background(), method, url, input) 110 | } 111 | 112 | // NewRequestWithContext turns the given input into an HTTP request. The input 113 | // must be a struct instance. And its fields' "in" tags define how to bind the 114 | // data from the struct to the HTTP request. Use it as the replacement of 115 | // http.NewRequest(). 
116 | // 117 | // addUserPayload := &AddUserRequest{...} 118 | // addUserRequest, err := NewRequestWithContext(context.Background(), "POST", "http://example.com", addUserPayload) 119 | // http.DefaultClient.Do(addUserRequest) 120 | func NewRequestWithContext(ctx context.Context, method, url string, input any, opts ...core.Option) (*http.Request, error) { 121 | co, err := New(input, opts...) 122 | if err != nil { 123 | return nil, err 124 | } 125 | return co.NewRequestWithContext(ctx, method, url, input) 126 | } 127 | 128 | // NewInput creates an HTTP middleware, i.e. a function that takes 129 | // in an http.Handler and returns another http.Handler. It panics if New fails for inputStruct. 130 | // 131 | // The middleware created by NewInput adds a decoding step in front of an 132 | // existing http.Handler. This functionality will decode the HTTP request into a 133 | // struct instance and put its pointer into the request's context, so that the 134 | // next hop can get the decoded struct instance from the request's context. 135 | // 136 | // We recommend using https://github.com/justinas/alice to chain your 137 | // middlewares. If you're using some popular web frameworks, they may have 138 | // already provided a middleware chaining mechanism. 139 | // 140 | // For example: 141 | // 142 | // type ListUsersRequest struct { 143 | // Page int `in:"query=page,page_index,index"` 144 | // PerPage int `in:"query=per_page,page_size"` 145 | // } 146 | // 147 | // func ListUsersHandler(rw http.ResponseWriter, r *http.Request) { 148 | // input := r.Context().Value(httpin.Input).(*ListUsersRequest) 149 | // // ... 150 | // } 151 | // 152 | // func init() { 153 | // http.Handle("/users", alice.New(httpin.NewInput(&ListUsersRequest{})).ThenFunc(ListUsersHandler)) 154 | // } 155 | func NewInput(inputStruct any, opts ...core.Option) func(http.Handler) http.Handler { 156 | co, err := New(inputStruct, opts...)
157 | internal.PanicOnError(err) 158 | 159 | return func(next http.Handler) http.Handler { 160 | return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 161 | // Decode the request into a new input instance. On failure, pass the error 162 | // to the Core's error handler and stop; the next handler is not called. 163 | input, err := co.Decode(r) 164 | if err != nil { 165 | co.GetErrorHandler()(rw, r, err) 166 | return 167 | } 168 | 169 | // Store the decoded input in the request's context under the Input key for the next handler. 170 | ctx := context.WithValue(r.Context(), Input, input) 171 | next.ServeHTTP(rw, r.WithContext(ctx)) 172 | }) 173 | } 174 | } 175 | // coreOptions is the type of Option; each field holds the core option constructor of the same name. 176 | type coreOptions struct { 177 | // WithErrorHandler overrides the default error handler. 178 | WithErrorHandler func(core.ErrorHandler) core.Option 179 | 180 | // WithMaxMemory overrides the default maximum memory size (32MB) when reading 181 | // the request body. See https://pkg.go.dev/net/http#Request.ParseMultipartForm 182 | // for more details. 183 | WithMaxMemory func(int64) core.Option 184 | 185 | // WithNestedDirectivesEnabled enables/disables nested directives.
186 | WithNestedDirectivesEnabled func(bool) core.Option 187 | } 188 | -------------------------------------------------------------------------------- /httpin_test.go: -------------------------------------------------------------------------------- 1 | package httpin 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "io" 7 | "net/http" 8 | "net/http/httptest" 9 | "net/url" 10 | "strings" 11 | "testing" 12 | 13 | "github.com/ggicci/httpin/core" 14 | "github.com/justinas/alice" 15 | "github.com/stretchr/testify/assert" 16 | ) 17 | 18 | type Pagination struct { 19 | Page int `in:"form=page,page_index,index"` 20 | PerPage int `in:"form=per_page,page_size"` 21 | } 22 | 23 | func testcasePagination1100() (*http.Request, *Pagination) { 24 | r, _ := http.NewRequest("GET", "/", nil) 25 | r.Form = url.Values{ 26 | "page": {"1"}, 27 | "per_page": {"100"}, 28 | } 29 | return r, &Pagination{ 30 | Page: 1, 31 | PerPage: 100, 32 | } 33 | } 34 | 35 | func TestDecodeTo(t *testing.T) { 36 | r, expected := testcasePagination1100() 37 | 38 | func() { 39 | input := &Pagination{} 40 | err := DecodeTo(r, input) // pointer to a struct instance 41 | assert.NoError(t, err) 42 | assert.Equal(t, expected, input) 43 | }() 44 | 45 | func() { 46 | input := Pagination{} 47 | err := DecodeTo(r, &input) // addressable struct instance 48 | assert.NoError(t, err) 49 | assert.Equal(t, expected, &input) 50 | }() 51 | 52 | func() { 53 | input := &Pagination{} 54 | err := DecodeTo(r, &input) // pointer to pointer of struct instance 55 | assert.NoError(t, err) 56 | assert.Equal(t, expected, input) 57 | }() 58 | 59 | func() { 60 | input := Pagination{} 61 | err := DecodeTo(r, input) // non-pointer struct instance should fail 62 | assert.ErrorContains(t, err, "invalid resolve target") 63 | }() 64 | } 65 | 66 | func TestDecode(t *testing.T) { 67 | r, expected := testcasePagination1100() 68 | 69 | p, err := Decode[Pagination](r) 70 | assert.NoError(t, err) 71 | assert.Equal(t, expected, p) 72 | } 73 | 74
| func TestDecode_ErrNotAStruct(t *testing.T) { 75 | r, _ := testcasePagination1100() 76 | 77 | _, err := Decode[int](r) 78 | assert.ErrorContains(t, err, "T must be a struct type") 79 | 80 | _, err = Decode[*Pagination](r) 81 | assert.ErrorContains(t, err, "T must be a struct type") 82 | } 83 | 84 | func TestDecode_ErrBuildResolverFailed(t *testing.T) { 85 | r, _ := testcasePagination1100() 86 | 87 | type Foo struct { 88 | Name string `in:"nonexistent=foo"` 89 | } 90 | 91 | assert.Error(t, DecodeTo(r, &Foo{})) 92 | 93 | v, err := Decode[Foo](r) 94 | assert.Nil(t, v) 95 | assert.Error(t, err) 96 | } 97 | 98 | func TestDecode_ErrDecodeFailure(t *testing.T) { 99 | r, _ := http.NewRequest("GET", "/", nil) 100 | r.Form = url.Values{ 101 | "page": {"1"}, 102 | "per_page": {"one-hundred"}, 103 | } 104 | 105 | p := &Pagination{} 106 | assert.Error(t, DecodeTo(r, p)) 107 | 108 | v, err := Decode[Pagination](r) 109 | assert.Nil(t, v) 110 | assert.Error(t, err) 111 | } 112 | 113 | type EchoInput struct { 114 | Token string `in:"form=access_token;header=x-api-key;required"` 115 | Saying string `in:"form=saying"` 116 | } 117 | 118 | func EchoHandler(rw http.ResponseWriter, r *http.Request) { 119 | var input = r.Context().Value(Input).(*EchoInput) 120 | json.NewEncoder(rw).Encode(input) 121 | } 122 | 123 | func TestNewInput_WithNil(t *testing.T) { 124 | assert.Panics(t, func() { 125 | NewInput(nil) 126 | }) 127 | } 128 | 129 | func TestNewInput_Success(t *testing.T) { 130 | r, err := http.NewRequest("GET", "/", nil) 131 | assert.NoError(t, err) 132 | 133 | r.Header.Add("X-Api-Key", "abc") 134 | var params = url.Values{} 135 | params.Add("saying", "TO THINE OWE SELF BE TRUE") 136 | r.URL.RawQuery = params.Encode() 137 | 138 | rw := httptest.NewRecorder() 139 | handler := alice.New(NewInput(EchoInput{})).ThenFunc(EchoHandler) 140 | handler.ServeHTTP(rw, r) 141 | assert.Equal(t, 200, rw.Code) 142 | expected := `{"Token":"abc","Saying":"TO THINE OWE SELF BE TRUE"}` + "\n" 143 |
assert.Equal(t, expected, rw.Body.String()) 144 | } 145 | 146 | func TestNewInput_ErrorHandledByDefaultErrorHandler(t *testing.T) { 147 | r, err := http.NewRequest("GET", "/", nil) 148 | assert.NoError(t, err) 149 | 150 | var params = url.Values{} 151 | params.Add("saying", "TO THINE OWE SELF BE TRUE") 152 | r.URL.RawQuery = params.Encode() 153 | 154 | rw := httptest.NewRecorder() 155 | handler := alice.New(NewInput(EchoInput{})).ThenFunc(EchoHandler) 156 | handler.ServeHTTP(rw, r) 157 | var out map[string]any 158 | assert.Nil(t, json.NewDecoder(rw.Body).Decode(&out)) 159 | 160 | assert.Equal(t, 422, rw.Code) 161 | assert.Equal(t, "Token", out["field"]) 162 | assert.Equal(t, "required", out["directive"]) 163 | assert.Contains(t, out["error"], "missing required field") 164 | } 165 | 166 | func CustomErrorHandler(rw http.ResponseWriter, r *http.Request, err error) { 167 | var invalidFieldError *core.InvalidFieldError 168 | if errors.As(err, &invalidFieldError) { 169 | rw.WriteHeader(http.StatusBadRequest) // status: 400 170 | io.WriteString(rw, invalidFieldError.Error()) 171 | return 172 | } 173 | http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) // status: 500 174 | } // NOTE(review): CustomErrorHandler is not referenced by any test in this file. 175 | // TestNewRequest checks that NewRequest encodes the form-tagged Pagination fields as an application/x-www-form-urlencoded body. 176 | func TestNewRequest(t *testing.T) { 177 | req, err := NewRequest("GET", "/products", &Pagination{ 178 | Page: 19, 179 | PerPage: 50, 180 | }) 181 | assert.NoError(t, err) 182 | 183 | expected, _ := http.NewRequest("GET", "/products", nil) 184 | expected.Body = io.NopCloser(strings.NewReader("page=19&per_page=50")) 185 | expected.Header.Set("Content-Type", "application/x-www-form-urlencoded") 186 | assert.Equal(t, expected, req) 187 | } 188 | // TestNewRequest_ErrNewFailure checks that NewRequest returns an error for a non-struct input. 189 | func TestNewRequest_ErrNewFailure(t *testing.T) { 190 | _, err := NewRequest("GET", "/products", 123) 191 | assert.Error(t, err) 192 | } 193 | -------------------------------------------------------------------------------- /integration/echo.go:
-------------------------------------------------------------------------------- 1 | package integration 2 | 3 | import ( 4 | "mime/multipart" 5 | 6 | "github.com/ggicci/httpin/core" 7 | "github.com/labstack/echo/v4" 8 | ) 9 | 10 | // UseEchoRouter registers a new directive executor which can extract values 11 | // from path variables. 12 | // https://ggicci.github.io/httpin/integrations/echo 13 | // 14 | // Usage: 15 | // 16 | // import httpin_integration "github.com/ggicci/httpin/integration" 17 | // 18 | // func init() { 19 | // e := echo.New() 20 | // httpin_integration.UseEchoRouter("path", e) 21 | // } 22 | func UseEchoRouter(name string, e *echo.Echo) { 23 | core.RegisterDirective( 24 | name, 25 | core.NewDirectivePath((&echoRouterExtractor{e}).Execute), 26 | true, 27 | ) 28 | } 29 | 30 | // echoRouterExtractor extracts URL path parameters by matching the request 31 | // against the routes registered on an echo.Echo instance. 32 | type echoRouterExtractor struct { 33 | e *echo.Echo 34 | } 35 | 36 | // Execute resolves the request's path parameters with the Echo router and 37 | // passes all of them to a core.FormExtractor. 38 | func (x *echoRouterExtractor) Execute(rtm *core.DirectiveRuntime) error { 39 | req := rtm.GetRequest() 40 | // NewContext already binds req to the context. No ResponseWriter is needed 41 | // because the context is only used to collect the matched route parameters. 42 | c := x.e.NewContext(req, nil) 43 | x.e.Router().Find(req.Method, req.URL.Path, c) 44 | 45 | names := c.ParamNames() 46 | kvs := make(map[string][]string, len(names)) 47 | for _, key := range names { 48 | kvs[key] = []string{c.Param(key)} 49 | } 50 | 51 | extractor := &core.FormExtractor{ 52 | Runtime: rtm, 53 | Form: multipart.Form{ 54 | Value: kvs, 55 | }, 56 | } 57 | return extractor.Extract() 58 | } 59 | -------------------------------------------------------------------------------- /integration/echo_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/ggicci/httpin" 10 | httpin_integration "github.com/ggicci/httpin/integration" 11 | "github.com/labstack/echo/v4" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func TestUseEchoMux(t *testing.T) {
16 | e := echo.New() 17 | // NOTE: I removed the API UseEchoPathRouter because it introduces minimal benefit 18 | // but adds surface area and maintenance cost. 19 | httpin_integration.UseEchoRouter("path", e) 20 | 21 | req := httptest.NewRequest(http.MethodGet, "/users/ggicci/posts/123", nil) 22 | req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) 23 | rec := httptest.NewRecorder() 24 | 25 | c := e.NewContext(req, rec) 26 | 27 | handler := func(ctx echo.Context) error { 28 | co, err := httpin.New(&GetPostOfUserInput{}) 29 | if err != nil { 30 | return err 31 | } 32 | v, err := co.Decode(ctx.Request()) 33 | if err != nil { 34 | return err 35 | } 36 | return ctx.JSON(http.StatusOK, v) 37 | } 38 | e.GET("/users/:username/posts/:pid", handler) 39 | err := handler(c) 40 | assert.NoError(t, err) 41 | assert.Equal(t, http.StatusOK, rec.Code) 42 | assert.Equal(t, `{"Username":"ggicci","PostID":123}`, strings.TrimSpace(rec.Body.String())) 43 | } 44 | -------------------------------------------------------------------------------- /integration/gochi.go: -------------------------------------------------------------------------------- 1 | // integration: "gochi" 2 | // https://ggicci.github.io/httpin/integrations/gochi 3 | 4 | package integration 5 | 6 | import ( 7 | "mime/multipart" 8 | "net/http" 9 | 10 | "github.com/ggicci/httpin/core" 11 | ) 12 | 13 | // GochiURLParamFunc is the signature of chi.URLParam. 14 | type GochiURLParamFunc func(r *http.Request, key string) string 15 | 16 | // UseGochiURLParam registers a directive executor which can extract values 17 | // from `chi.URLParam`, i.e. path variables.
18 | // https://ggicci.github.io/httpin/integrations/gochi 19 | // 20 | // Usage: 21 | // 22 | // import httpin_integration "github.com/ggicci/httpin/integration" 23 | // 24 | // func init() { 25 | // httpin_integration.UseGochiURLParam("path", chi.URLParam) 26 | // } 27 | func UseGochiURLParam(name string, fn GochiURLParamFunc) { 28 | core.RegisterDirective( 29 | name, 30 | core.NewDirectivePath((&gochiURLParamExtractor{URLParam: fn}).Execute), 31 | true, 32 | ) 33 | } 34 | // gochiURLParamExtractor looks up path parameters through URLParam (e.g. chi.URLParam). 35 | type gochiURLParamExtractor struct { 36 | URLParam GochiURLParamFunc 37 | } 38 | // Execute looks up each key in the directive's arguments with URLParam, keeps only non-empty values, and hands them to a core.FormExtractor. 39 | func (chi *gochiURLParamExtractor) Execute(rtm *core.DirectiveRuntime) error { 40 | req := rtm.GetRequest() 41 | kvs := make(map[string][]string) 42 | 43 | for _, key := range rtm.Directive.Argv { 44 | value := chi.URLParam(req, key) 45 | if value != "" { 46 | kvs[key] = []string{value} 47 | } 48 | } 49 | 50 | extractor := &core.FormExtractor{ 51 | Runtime: rtm, 52 | Form: multipart.Form{ 53 | Value: kvs, 54 | }, 55 | } 56 | return extractor.Extract() 57 | } 58 | -------------------------------------------------------------------------------- /integration/gochi_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | 9 | "github.com/ggicci/httpin" 10 | httpin_integration "github.com/ggicci/httpin/integration" 11 | "github.com/go-chi/chi/v5" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | type GetArticleOfUserInput struct { 16 | Author string `in:"gochi=author"` // equivalent to chi.URLParam("author") 17 | ArticleID int64 `in:"gochi=articleID"` 18 | } 19 | 20 | func GetArticleOfUser(rw http.ResponseWriter, r *http.Request) { 21 | var input = r.Context().Value(httpin.Input).(*GetArticleOfUserInput) 22 | json.NewEncoder(rw).Encode(input) 23 | } 24 | 25 | func TestUseGochiURLParam(t *testing.T) { 26 | // Register the "gochi" directive, usually
in init(). 27 | // In most cases you would register this as "path"; "gochi" is used here 28 | // only to avoid conflicts with the "path" directive registered by other tests. 29 | httpin_integration.UseGochiURLParam("gochi", chi.URLParam) 30 | 31 | rw := httptest.NewRecorder() 32 | r, err := http.NewRequest("GET", "/ggicci/articles/1024", nil) 33 | assert.NoError(t, err) 34 | 35 | router := chi.NewRouter() 36 | router.With( 37 | httpin.NewInput(GetArticleOfUserInput{}), 38 | ).Get("/{author}/articles/{articleID}", GetArticleOfUser) 39 | 40 | router.ServeHTTP(rw, r) 41 | assert.Equal(t, 200, rw.Code) 42 | expected := `{"Author":"ggicci","ArticleID":1024}` + "\n" 43 | assert.Equal(t, expected, rw.Body.String()) 44 | } 45 | -------------------------------------------------------------------------------- /integration/gorilla.go: -------------------------------------------------------------------------------- 1 | // Mux vars extension for github.com/gorilla/mux package. 2 | 3 | package integration 4 | 5 | import ( 6 | "mime/multipart" 7 | "net/http" 8 | 9 | "github.com/ggicci/httpin/core" 10 | ) 11 | 12 | // GorillaMuxVarsFunc is the signature of mux.Vars. 13 | type GorillaMuxVarsFunc func(*http.Request) map[string]string 14 | 15 | // UseGorillaMux registers a new directive executor which can extract values 16 | // from `mux.Vars`, i.e. path variables.
17 | // https://ggicci.github.io/httpin/integrations/gorilla 18 | // 19 | // Usage: 20 | // 21 | // import httpin_integration "github.com/ggicci/httpin/integration" 22 | // 23 | // func init() { 24 | // httpin_integration.UseGorillaMux("path", mux.Vars) 25 | // } 26 | func UseGorillaMux(name string, fnVars GorillaMuxVarsFunc) { 27 | core.RegisterDirective( 28 | name, 29 | core.NewDirectivePath((&gorillaMuxVarsExtractor{Vars: fnVars}).Execute), 30 | true, 31 | ) 32 | } 33 | // gorillaMuxVarsExtractor reads path variables through Vars (e.g. mux.Vars). 34 | type gorillaMuxVarsExtractor struct { 35 | Vars GorillaMuxVarsFunc 36 | } 37 | // Execute passes every variable returned by Vars (it does not filter by the directive's arguments) to a core.FormExtractor. 38 | func (mux *gorillaMuxVarsExtractor) Execute(rtm *core.DirectiveRuntime) error { 39 | req := rtm.GetRequest() 40 | kvs := make(map[string][]string) 41 | 42 | for key, value := range mux.Vars(req) { 43 | kvs[key] = []string{value} 44 | } 45 | 46 | extractor := &core.FormExtractor{ 47 | Runtime: rtm, 48 | Form: multipart.Form{ 49 | Value: kvs, 50 | }, 51 | } 52 | return extractor.Extract() 53 | } 54 | -------------------------------------------------------------------------------- /integration/gorilla_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | 9 | "github.com/ggicci/httpin" 10 | httpin_integration "github.com/ggicci/httpin/integration" 11 | "github.com/gorilla/mux" 12 | "github.com/justinas/alice" 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | type GetPostOfUserInput struct { 17 | Username string `in:"path=username"` 18 | PostID int64 `in:"path=pid"` 19 | } 20 | 21 | func GetPostOfUserHandler(rw http.ResponseWriter, r *http.Request) { 22 | var input = r.Context().Value(httpin.Input).(*GetPostOfUserInput) 23 | json.NewEncoder(rw).Encode(input) 24 | } 25 | 26 | func TestGorillaMuxVars(t *testing.T) { 27 | httpin_integration.UseGorillaMux("path", mux.Vars) // register the "path" directive, usually in init() 28 | 29 | rw :=
httptest.NewRecorder() 30 | r, err := http.NewRequest("GET", "/ggicci/posts/1024", nil) 31 | assert.NoError(t, err) 32 | 33 | router := mux.NewRouter() 34 | router.Handle("/{username}/posts/{pid}", alice.New( 35 | httpin.NewInput(GetPostOfUserInput{}), 36 | ).ThenFunc(GetPostOfUserHandler)).Methods("GET") 37 | router.ServeHTTP(rw, r) 38 | assert.Equal(t, 200, rw.Code) 39 | expected := `{"Username":"ggicci","PostID":1024}` + "\n" 40 | assert.Equal(t, expected, rw.Body.String()) 41 | } 42 | -------------------------------------------------------------------------------- /integration/http.go: -------------------------------------------------------------------------------- 1 | package integration 2 | 3 | import ( 4 | "mime/multipart" 5 | "net/http" 6 | 7 | "github.com/ggicci/httpin/core" 8 | ) 9 | 10 | type HttpMuxVarsFunc func(*http.Request) map[string]string 11 | 12 | // UseHttpPathVariable registers a new directive executor which can extract 13 | // values from URL path variables via `http.Request.PathValue` API.
14 | // https://ggicci.github.io/httpin/integrations/http 15 | // 16 | // Usage: 17 | // 18 | // import httpin_integration "github.com/ggicci/httpin/integration" 19 | // func init() { 20 | // httpin_integration.UseHttpPathVariable("path") 21 | // } 22 | func UseHttpPathVariable(name string) { 23 | core.RegisterDirective( 24 | name, 25 | core.NewDirectivePath((&httpMuxVarsExtractor{}).Execute), 26 | true, 27 | ) 28 | } 29 | 30 | type httpMuxVarsExtractor struct{} 31 | 32 | func (mux *httpMuxVarsExtractor) Execute(rtm *core.DirectiveRuntime) error { 33 | req := rtm.GetRequest() 34 | kvs := make(map[string][]string, len(rtm.Directive.Argv)) // at most one entry per directive argument 35 | 36 | for _, key := range rtm.Directive.Argv { 37 | value := req.PathValue(key) 38 | if value != "" { // skip empty values so the key stays absent from the form 39 | kvs[key] = []string{value} 40 | } 41 | } 42 | 43 | extractor := &core.FormExtractor{ 44 | Runtime: rtm, 45 | Form: multipart.Form{ 46 | Value: kvs, 47 | }, 48 | } 49 | return extractor.Extract() 50 | } 51 | -------------------------------------------------------------------------------- /integration/http_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "encoding/json" 5 | "io" 6 | "net/http" 7 | "net/http/httptest" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/ggicci/httpin" 12 | httpin_integration "github.com/ggicci/httpin/integration" 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | type HttpMuxPathInput struct { 17 | Username string `in:"path=username"` 18 | PostID int64 `in:"path=pid"` 19 | } 20 | 21 | func TestUseHttpMux(t *testing.T) { 22 | httpin_integration.UseHttpPathVariable("path") 23 | 24 | srv := http.NewServeMux() 25 | srv.HandleFunc("/users/{username}/posts/{pid}", func(w http.ResponseWriter, r *http.Request) { 26 | param := &HttpMuxPathInput{} 27 | core, err := httpin.New(param) 28 | if err != nil { 29 | t.Error(err) // the handler runs on a server goroutine, where t.Fatal/FailNow must not be called 30 | return 31 | } 32 | v, err := core.Decode(r) 33 | if err != nil { 34 | t.Error(err) 35 | return 36 | } 37 |
jsonBytes, err := json.Marshal(v) 38 | if err != nil { 39 | t.Error(err) 40 | return 41 | } 42 | 43 | w.WriteHeader(http.StatusOK) 44 | w.Write(jsonBytes) 45 | }) 46 | ts := httptest.NewServer(srv) 47 | defer ts.Close() 48 | 49 | resp, err := http.DefaultClient.Get(ts.URL + "/users/chriss-de/posts/456") 50 | if !assert.NoError(t, err) { 51 | return 52 | } 53 | defer resp.Body.Close() 54 | assert.Equal(t, http.StatusOK, resp.StatusCode) 55 | 56 | bodyBytes, err := io.ReadAll(resp.Body) 57 | assert.NoError(t, err) 58 | bodyString := string(bodyBytes) 59 | 60 | assert.Equal(t, `{"Username":"chriss-de","PostID":456}`, strings.TrimSpace(bodyString)) 61 | } 62 | -------------------------------------------------------------------------------- /internal/misc.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | ) 7 | 8 | func IsNil(value reflect.Value) bool { 9 | switch value.Kind() { 10 | case reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.Interface, reflect.Slice: 11 | return value.IsNil() 12 | default: 13 | return false 14 | } 15 | } 16 | 17 | func PanicOnError(err error) { 18 | if err != nil { 19 | panic(fmt.Errorf("httpin: %w", err)) 20 | } 21 | } 22 | 23 | // TypeOf returns the reflect.Type of a given type. 24 | // e.g. TypeOf[int]() returns reflect.TypeOf(0) 25 | func TypeOf[T any]() reflect.Type { 26 | var zero [0]T 27 | return reflect.TypeOf(zero).Elem() 28 | } 29 | 30 | func Pointerize[T any](v T) *T { 31 | return &v 32 | } 33 | 34 | // DereferencedType returns the type obtained by following every pointer indirection of v's type.
35 | func DereferencedType(v any) reflect.Type { 36 | rt := reflect.TypeOf(v) // walk types, not values: a nil pointer still has a type, and a nil interface yields nil 37 | for rt != nil && rt.Kind() == reflect.Pointer { 38 | rt = rt.Elem() 39 | } 40 | return rt 41 | } 42 | -------------------------------------------------------------------------------- /internal/misc_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestIsNil(t *testing.T) { 11 | assert.False(t, IsNil(reflect.ValueOf("hello"))) 12 | assert.True(t, IsNil(reflect.ValueOf((*string)(nil)))) 13 | } 14 | 15 | func TestPanicOnError(t *testing.T) { 16 | PanicOnError(nil) 17 | 18 | assert.PanicsWithError(t, "httpin: "+assert.AnError.Error(), func() { 19 | PanicOnError(assert.AnError) 20 | }) 21 | } 22 | 23 | func TestTypeOf(t *testing.T) { 24 | assert.Equal(t, reflect.TypeOf(0), TypeOf[int]()) 25 | } 26 | 27 | func TestPointerize(t *testing.T) { 28 | assert.Equal(t, 102, *Pointerize[int](102)) 29 | } 30 | 31 | func TestDereferencedType(t *testing.T) { 32 | type Object struct{} 33 | 34 | var o = new(Object) 35 | var po = &o 36 | var ppo = &po 37 | assert.Equal(t, reflect.TypeOf(Object{}), DereferencedType(Object{})) 38 | assert.Equal(t, reflect.TypeOf(Object{}), DereferencedType(o)) 39 | assert.Equal(t, reflect.TypeOf(Object{}), DereferencedType(po)) 40 | assert.Equal(t, reflect.TypeOf(Object{}), DereferencedType(ppo)) 41 | assert.Equal(t, reflect.TypeOf(Object{}), DereferencedType((*Object)(nil))) 42 | assert.Nil(t, DereferencedType(nil)) 43 | } 44 | -------------------------------------------------------------------------------- /internal/stringable.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "encoding/base64" 5 | "errors" 6 | "fmt" 7 | "reflect" 8 | "regexp" 9 | "strconv" 10 | "strings" 11 | "time" 12 | "github.com/ggicci/owl" 13 | ) 14 | 15 | var ( 16 | ErrTypeMismatch = errors.New("type mismatch") 17 | ErrUnsupportedType = owl.ErrUnsupportedType 18 | 19 |
builtinStringableAdaptors = make(map[reflect.Type]AnyStringableAdaptor) 20 | ) 21 | 22 | func init() { 23 | builtinStringable[string](func(v *string) (Stringable, error) { return (*String)(v), nil }) 24 | builtinStringable[bool](func(v *bool) (Stringable, error) { return (*Bool)(v), nil }) 25 | builtinStringable[int](func(v *int) (Stringable, error) { return (*Int)(v), nil }) 26 | builtinStringable[int8](func(v *int8) (Stringable, error) { return (*Int8)(v), nil }) 27 | builtinStringable[int16](func(v *int16) (Stringable, error) { return (*Int16)(v), nil }) 28 | builtinStringable[int32](func(v *int32) (Stringable, error) { return (*Int32)(v), nil }) 29 | builtinStringable[int64](func(v *int64) (Stringable, error) { return (*Int64)(v), nil }) 30 | builtinStringable[uint](func(v *uint) (Stringable, error) { return (*Uint)(v), nil }) 31 | builtinStringable[uint8](func(v *uint8) (Stringable, error) { return (*Uint8)(v), nil }) 32 | builtinStringable[uint16](func(v *uint16) (Stringable, error) { return (*Uint16)(v), nil }) 33 | builtinStringable[uint32](func(v *uint32) (Stringable, error) { return (*Uint32)(v), nil }) 34 | builtinStringable[uint64](func(v *uint64) (Stringable, error) { return (*Uint64)(v), nil }) 35 | builtinStringable[float32](func(v *float32) (Stringable, error) { return (*Float32)(v), nil }) 36 | builtinStringable[float64](func(v *float64) (Stringable, error) { return (*Float64)(v), nil }) 37 | builtinStringable[complex64](func(v *complex64) (Stringable, error) { return (*Complex64)(v), nil }) 38 | builtinStringable[complex128](func(v *complex128) (Stringable, error) { return (*Complex128)(v), nil }) 39 | builtinStringable[time.Time](func(v *time.Time) (Stringable, error) { return (*Time)(v), nil }) 40 | builtinStringable[[]byte](func(b *[]byte) (Stringable, error) { return (*ByteSlice)(b), nil }) 41 | } 42 | 43 | type StringMarshaler interface { 44 | ToString() (string, error) 45 | } 46 | 47 | type StringUnmarshaler interface { 48 |
FromString(string) error 49 | } 50 | 51 | type Stringable interface { 52 | StringMarshaler 53 | StringUnmarshaler 54 | } 55 | 56 | // NewStringable returns a Stringable that reads and writes the value rv points to. 57 | // rv is expected to be a non-nil pointer; rv.Type().Elem() panics for types without an element type (e.g. string). 58 | // An error wrapping ErrUnsupportedType is returned if no built-in adaptor is registered for the element type. 59 | func NewStringable(rv reflect.Value) (Stringable, error) { 60 | baseType := rv.Type().Elem() 61 | adapt, ok := builtinStringableAdaptors[baseType] 62 | if !ok { 63 | return nil, UnsupportedType(baseType) 64 | } 65 | return adapt(rv.Interface()) 66 | } 67 | 68 | func builtinStringable[T any](adaptor StringableAdaptor[T]) { 69 | builtinStringableAdaptors[TypeOf[T]()] = NewAnyStringableAdaptor[T](adaptor) 70 | } 71 | 72 | type String string 73 | 74 | func (sv String) ToString() (string, error) { 75 | return string(sv), nil 76 | } 77 | 78 | func (sv *String) FromString(s string) error { 79 | *sv = String(s) 80 | return nil 81 | } 82 | 83 | type Bool bool 84 | 85 | func (bv Bool) ToString() (string, error) { 86 | return strconv.FormatBool(bool(bv)), nil 87 | } 88 | 89 | func (bv *Bool) FromString(s string) error { 90 | v, err := strconv.ParseBool(s) 91 | if err != nil { 92 | return err 93 | } 94 | *bv = Bool(v) 95 | return nil 96 | } 97 | 98 | type Int int 99 | 100 | func (iv Int) ToString() (string, error) { 101 | return strconv.Itoa(int(iv)), nil 102 | } 103 | 104 | func (iv *Int) FromString(s string) error { 105 | v, err := strconv.Atoi(s) 106 | if err != nil { 107 | return err 108 | } 109 | *iv = Int(v) 110 | return nil 111 | } 112 | 113 | type Int8 int8 114 | 115 | func (iv Int8) ToString() (string, error) { 116 | return strconv.FormatInt(int64(iv), 10), nil 117 | } 118 | 119 | func (iv *Int8) FromString(s string) error { 120 | v, err := strconv.ParseInt(s, 10, 8) 121 | if err != nil { 122 | return err 123 | } 124 | *iv = Int8(v) 125 | return nil 126 | } 127 | 128 | type Int16 int16 129 | 130 | func (iv Int16)
ToString() (string, error) { 131 | return strconv.FormatInt(int64(iv), 10), nil 132 | } 133 | 134 | func (iv *Int16) FromString(s string) error { 135 | v, err := strconv.ParseInt(s, 10, 16) 136 | if err != nil { 137 | return err 138 | } 139 | *iv = Int16(v) 140 | return nil 141 | } 142 | 143 | type Int32 int32 144 | 145 | func (iv Int32) ToString() (string, error) { 146 | return strconv.FormatInt(int64(iv), 10), nil 147 | } 148 | 149 | func (iv *Int32) FromString(s string) error { 150 | v, err := strconv.ParseInt(s, 10, 32) 151 | if err != nil { 152 | return err 153 | } 154 | *iv = Int32(v) 155 | return nil 156 | } 157 | 158 | type Int64 int64 159 | 160 | func (iv Int64) ToString() (string, error) { 161 | return strconv.FormatInt(int64(iv), 10), nil 162 | } 163 | 164 | func (iv *Int64) FromString(s string) error { 165 | v, err := strconv.ParseInt(s, 10, 64) 166 | if err != nil { 167 | return err 168 | } 169 | *iv = Int64(v) 170 | return nil 171 | } 172 | 173 | type Uint uint 174 | 175 | func (uv Uint) ToString() (string, error) { 176 | return strconv.FormatUint(uint64(uv), 10), nil 177 | } 178 | 179 | func (uv *Uint) FromString(s string) error { 180 | v, err := strconv.ParseUint(s, 10, strconv.IntSize) // uint is 32 bits on 32-bit platforms; bitSize 64 would let out-of-range input wrap 181 | if err != nil { 182 | return err 183 | } 184 | *uv = Uint(v) 185 | return nil 186 | } 187 | 188 | type Uint8 uint8 189 | 190 | func (uv Uint8) ToString() (string, error) { 191 | return strconv.FormatUint(uint64(uv), 10), nil 192 | } 193 | 194 | func (uv *Uint8) FromString(s string) error { 195 | v, err := strconv.ParseUint(s, 10, 8) 196 | if err != nil { 197 | return err 198 | } 199 | *uv = Uint8(v) 200 | return nil 201 | } 202 | 203 | type Uint16 uint16 204 | 205 | func (uv Uint16) ToString() (string, error) { 206 | return strconv.FormatUint(uint64(uv), 10), nil 207 | } 208 | 209 | func (uv *Uint16) FromString(s string) error { 210 | v, err := strconv.ParseUint(s, 10, 16) 211 | if err != nil { 212 | return err 213 | } 214 | *uv = Uint16(v) 215 | return nil 216 | } 217 |
218 | type Uint32 uint32 219 | 220 | func (uv Uint32) ToString() (string, error) { 221 | return strconv.FormatUint(uint64(uv), 10), nil 222 | } 223 | 224 | func (uv *Uint32) FromString(s string) error { 225 | v, err := strconv.ParseUint(s, 10, 32) 226 | if err != nil { 227 | return err 228 | } 229 | *uv = Uint32(v) 230 | return nil 231 | } 232 | 233 | type Uint64 uint64 234 | 235 | func (uv Uint64) ToString() (string, error) { 236 | return strconv.FormatUint(uint64(uv), 10), nil 237 | } 238 | 239 | func (uv *Uint64) FromString(s string) error { 240 | v, err := strconv.ParseUint(s, 10, 64) 241 | if err != nil { 242 | return err 243 | } 244 | *uv = Uint64(v) 245 | return nil 246 | } 247 | 248 | type Float32 float32 249 | 250 | func (fv Float32) ToString() (string, error) { 251 | return strconv.FormatFloat(float64(fv), 'f', -1, 32), nil 252 | } 253 | 254 | func (fv *Float32) FromString(s string) error { 255 | v, err := strconv.ParseFloat(s, 32) 256 | if err != nil { 257 | return err 258 | } 259 | *fv = Float32(v) 260 | return nil 261 | } 262 | 263 | type Float64 float64 264 | 265 | func (fv Float64) ToString() (string, error) { 266 | return strconv.FormatFloat(float64(fv), 'f', -1, 64), nil 267 | } 268 | 269 | func (fv *Float64) FromString(s string) error { 270 | v, err := strconv.ParseFloat(s, 64) 271 | if err != nil { 272 | return err 273 | } 274 | *fv = Float64(v) 275 | return nil 276 | } 277 | 278 | type Complex64 complex64 279 | 280 | func (cv Complex64) ToString() (string, error) { 281 | return strconv.FormatComplex(complex128(cv), 'f', -1, 64), nil 282 | } 283 | 284 | func (cv *Complex64) FromString(s string) error { 285 | v, err := strconv.ParseComplex(s, 64) 286 | if err != nil { 287 | return err 288 | } 289 | *cv = Complex64(v) 290 | return nil 291 | } 292 | 293 | type Complex128 complex128 294 | 295 | func (cv Complex128) ToString() (string, error) { 296 | return strconv.FormatComplex(complex128(cv), 'f', -1, 128), nil 297 | } 298 | 299 | func (cv
*Complex128) FromString(s string) error { 300 | v, err := strconv.ParseComplex(s, 128) 301 | if err != nil { 302 | return err 303 | } 304 | *cv = Complex128(v) 305 | return nil 306 | } 307 | 308 | type Time time.Time 309 | 310 | func (tv Time) ToString() (string, error) { 311 | return time.Time(tv).UTC().Format(time.RFC3339Nano), nil 312 | } 313 | 314 | func (tv *Time) FromString(s string) error { 315 | t, err := DecodeTime(s) 316 | if err != nil { 317 | return err 318 | } 319 | *tv = Time(t) 320 | return nil 321 | } 322 | 323 | var reUnixtime = regexp.MustCompile(`^\d+(\.\d{1,9})?$`) 324 | 325 | // DecodeTime parses value as a time.Time in the UTC timezone. 326 | // Supported formats of value are: 327 | // 1. RFC3339Nano string, e.g. "2006-01-02T15:04:05-07:00". 328 | // 2. Date string, e.g. "2006-01-02". 329 | // 3. Unix timestamp, e.g. "1136239445", "1136239445.8", "1136239445.812738". 330 | func DecodeTime(value string) (time.Time, error) { 331 | // Try parsing value as RFC3339 format. 332 | if t, err := time.ParseInLocation(time.RFC3339Nano, value, time.UTC); err == nil { 333 | return t.UTC(), nil 334 | } 335 | 336 | // Try parsing value as date format. 337 | if t, err := time.ParseInLocation("2006-01-02", value, time.UTC); err == nil { 338 | return t.UTC(), nil 339 | } 340 | 341 | // Try parsing value as timestamp, both integer and float formats supported. 342 | // e.g. "1618974933", "1618974933.284368". 343 | if reUnixtime.MatchString(value) { 344 | return DecodeUnixtime(value) 345 | } 346 | 347 | return time.Time{}, errors.New("invalid time value") 348 | } 349 | 350 | // DecodeUnixtime parses value as a unix timestamp of the form "<sec>" or "<sec>.<frac>" 351 | // (1 to 9 fractional digits) and returns it in UTC. Malformed input, including seconds 352 | // that overflow int64 such as "99999999999999999999", is reported as an error. 353 | func DecodeUnixtime(value string) (time.Time, error) { 354 |
secPart, fracPart, hasFrac := strings.Cut(value, ".") 355 | sec, err := strconv.ParseInt(secPart, 10, 64) 356 | if err != nil { 357 | return time.Time{}, fmt.Errorf("invalid unix timestamp %q: %w", value, err) 358 | } 359 | var nsec int64 360 | if hasFrac { 361 | // nanoSecondPrecision pads to exactly 9 digits and panics on longer input. 362 | if fracPart == "" || len(fracPart) > 9 { 363 | return time.Time{}, fmt.Errorf("invalid unix timestamp %q: fraction must have 1 to 9 digits", value) 364 | } 365 | if nsec, err = strconv.ParseInt(nanoSecondPrecision(fracPart), 10, 64); err != nil { 366 | return time.Time{}, fmt.Errorf("invalid unix timestamp %q: %w", value, err) 367 | } 368 | } 369 | return time.Unix(sec, nsec).UTC(), nil 370 | } 371 | 372 | // nanoSecondPrecision right-pads value with zeros to 9 digits, e.g. "8" -> "800000000"; len(value) must be <= 9. 373 | func nanoSecondPrecision(value string) string { 374 | return value + strings.Repeat("0", 9-len(value)) 375 | } 376 | 377 | // ByteSlice is a wrapper of []byte to implement Stringable. 378 | // NOTE: we're using base64.StdEncoding here, not base64.URLEncoding. 379 | type ByteSlice []byte 380 | 381 | func (bs ByteSlice) ToString() (string, error) { 382 | return base64.StdEncoding.EncodeToString(bs), nil 383 | } 384 | 385 | func (bs *ByteSlice) FromString(s string) error { 386 | v, err := base64.StdEncoding.DecodeString(s) 387 | if err != nil { 388 | return err 389 | } 390 | *bs = ByteSlice(v) 391 | return nil 392 | } 393 | 394 | // UnsupportedType returns an error that wraps ErrUnsupportedType and names rt. 395 | func UnsupportedType(rt reflect.Type) error { 396 | return fmt.Errorf("%w: %v", ErrUnsupportedType, rt) 397 | } 398 | -------------------------------------------------------------------------------- /internal/stringable_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | "time" 8 | 9 | "github.com/ggicci/owl" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestNewStringable_string(t *testing.T) { 14 | var s string = "hello" 15 | rvString := reflect.ValueOf(s) 16 | assert.Panics(t, func() { 17 | NewStringable(rvString) 18 | }) 19 | 20 | rvStringPointer := reflect.ValueOf(&s) 21 | sv, err := NewStringable(rvStringPointer) 22 | assert.NoError(t, err) 23 | got, err := sv.ToString() 24 | assert.NoError(t, err) 25 | assert.Equal(t, "hello", got) 26 | sv.FromString("world") 27 | assert.Equal(t, "world", s) 28 | } 29 | 30 | func TestNewStringable_bool(t *testing.T) { 31 | var b bool = true 32 | rvBool := reflect.ValueOf(b) 33 | assert.Panics(t, func() { 34 | NewStringable(rvBool)
35 | }) 36 | 37 | rvBoolPointer := reflect.ValueOf(&b) 38 | sv, err := NewStringable(rvBoolPointer) 39 | assert.NoError(t, err) 40 | got, err := sv.ToString() 41 | assert.NoError(t, err) 42 | assert.Equal(t, "true", got) 43 | sv.FromString("false") 44 | assert.Equal(t, false, b) 45 | 46 | assert.Error(t, sv.FromString("hello")) 47 | } 48 | 49 | func TestNewStringable_int(t *testing.T) { 50 | testInteger[int](t, 2045, "hello") 51 | } 52 | 53 | func TestNewStringable_int8(t *testing.T) { 54 | testInteger[int8](t, int8(127), "128") 55 | } 56 | 57 | func TestNewStringable_int16(t *testing.T) { 58 | testInteger[int16](t, int16(32767), "32768") 59 | } 60 | 61 | func TestNewStringable_int32(t *testing.T) { 62 | testInteger[int32](t, int32(2147483647), "2147483648") 63 | } 64 | 65 | func TestNewStringable_int64(t *testing.T) { 66 | testInteger[int64](t, int64(9223372036854775807), "9223372036854775808") 67 | } 68 | 69 | func TestNewStringable_uint(t *testing.T) { 70 | testInteger[uint](t, uint(2045), "-1") 71 | } 72 | 73 | func TestNewStringable_uint8(t *testing.T) { 74 | testInteger[uint8](t, uint8(255), "256") 75 | } 76 | 77 | func TestNewStringable_uint16(t *testing.T) { 78 | testInteger[uint16](t, uint16(65535), "65536") 79 | } 80 | 81 | func TestNewStringable_uint32(t *testing.T) { 82 | testInteger[uint32](t, uint32(4294967295), "4294967296") 83 | } 84 | 85 | func TestNewStringable_uint64(t *testing.T) { 86 | testInteger[uint64](t, uint64(18446744073709551615), "18446744073709551616") 87 | } 88 | 89 | func TestNewStringable_float32(t *testing.T) { 90 | testInteger[float32](t, float32(3.1415926), "hello") 91 | } 92 | 93 | func TestNewStringable_float64(t *testing.T) { 94 | testInteger[float64](t, float64(3.14159265358979323846264338327950288419716939937510582097494459), "hello") 95 | } 96 | 97 | func TestNewStringable_complex64(t *testing.T) { 98 | testInteger[complex64](t, complex64(3.1415926+2.71828i), "hello") 99 | } 100 | 101 | func TestNewStringable_complex128(t
*testing.T) { 102 | testInteger[complex128](t, complex128(3.14159265358979323846264338327950288419716939937510582097494459+2.71828182845904523536028747135266249775724709369995957496696763i), "hello") 103 | } 104 | 105 | func TestNewStringable_Time(t *testing.T) { 106 | var now = time.Now() 107 | rvTime := reflect.ValueOf(now) 108 | assert.Panics(t, func() { 109 | NewStringable(rvTime) 110 | }) 111 | 112 | rvTimePointer := reflect.ValueOf(&now) 113 | sv, err := NewStringable(rvTimePointer) 114 | assert.NoError(t, err) 115 | 116 | // RFC3339Nano 117 | testTime(t, sv, "1991-11-10T08:00:00+08:00", time.Date(1991, 11, 10, 8, 0, 0, 0, time.FixedZone("Asia/Shanghai", +8*3600)), "1991-11-10T00:00:00Z") 118 | // Date string 119 | testTime(t, sv, "1991-11-10", time.Date(1991, 11, 10, 0, 0, 0, 0, time.UTC), "1991-11-10T00:00:00Z") 120 | 121 | // Unix timestamp 122 | testTime(t, sv, "678088800", time.Date(1991, 6, 28, 6, 0, 0, 0, time.UTC), "1991-06-28T06:00:00Z") 123 | 124 | // Unix timestamp fraction 125 | testTime(t, sv, "678088800.123456789", time.Date(1991, 6, 28, 6, 0, 0, 123456789, time.UTC), "1991-06-28T06:00:00.123456789Z") 126 | 127 | // Unsupported format 128 | assert.Error(t, sv.FromString("hello")) 129 | } 130 | 131 | func TestNewStringable_ByteSlice(t *testing.T) { 132 | var b []byte = []byte("hello") 133 | rvByteSlice := reflect.ValueOf(b) 134 | assert.NotPanics(t, func() { 135 | NewStringable(rvByteSlice) 136 | }) 137 | 138 | rvByteSlicePointer := reflect.ValueOf(&b) 139 | sv, err := NewStringable(rvByteSlicePointer) 140 | assert.NoError(t, err) 141 | got, err := sv.ToString() 142 | assert.NoError(t, err) 143 | assert.Equal(t, "aGVsbG8=", got) 144 | 145 | sv.FromString("d29ybGQ=") 146 | assert.Equal(t, []byte("world"), b) 147 | 148 | assert.Error(t, sv.FromString("hello")) 149 | } 150 | 151 | func TestNewStringable_ErrUnsupportedType(t *testing.T) { 152 | type MyStruct struct{ Name string } 153 | var s MyStruct 154 | rvStruct := reflect.ValueOf(s) 155 |
assert.Panics(t, func() { 156 | NewStringable(rvStruct) 157 | }) 158 | rvStructPointer := reflect.ValueOf(&s) 159 | sv, err := NewStringable(rvStructPointer) 160 | assert.ErrorIs(t, err, owl.ErrUnsupportedType) 161 | assert.Nil(t, sv) 162 | } 163 | 164 | type Numeric interface { 165 | int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64 | float32 | float64 | complex64 | complex128 166 | } 167 | 168 | func testInteger[T Numeric](t *testing.T, vSuccess T, invalidStr string) { 169 | rv := reflect.ValueOf(vSuccess) 170 | assert.Panics(t, func() { 171 | NewStringable(rv) 172 | }) 173 | 174 | rvPointer := reflect.ValueOf(&vSuccess) 175 | sv, err := NewStringable(rvPointer) 176 | assert.NoError(t, err) 177 | got, err := sv.ToString() 178 | assert.NoError(t, err) 179 | assert.Equal(t, fmt.Sprintf("%v", vSuccess), got) 180 | sv.FromString("2") 181 | assert.Equal(t, T(2), vSuccess) 182 | 183 | assert.Error(t, sv.FromString(invalidStr)) 184 | } 185 | 186 | func testTime(t *testing.T, sv Stringable, fromStr string, expected time.Time, expectedToStr string) { 187 | assert.NoError(t, sv.FromString(fromStr)) 188 | assert.True(t, equalTime(expected, time.Time(*sv.(*Time)))) 189 | ts, err := sv.ToString() 190 | assert.NoError(t, err) 191 | assert.Equal(t, expectedToStr, ts) 192 | } 193 | 194 | func equalTime(expected, actual time.Time) bool { 195 | return expected.Equal(actual) 196 | } 197 | 198 | func TestDecodeUnixtime_invalid(t *testing.T) { 199 | for _, s := range []string{"99999999999999999999", "1.1234567890", "1.", "abc"} { 200 | _, err := DecodeUnixtime(s) 201 | assert.Error(t, err, s) 202 | } 203 | // An overflowing timestamp still matches reUnixtime, so DecodeTime must report the error too. 204 | _, err := DecodeTime("99999999999999999999") 205 | assert.Error(t, err) 206 | } 207 | -------------------------------------------------------------------------------- /internal/stringableadaptor.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import "fmt" 4 | 5 | // StringableAdaptor adapts a *T to a Stringable. 6 | type StringableAdaptor[T any] func(*T) (Stringable, error) 7 | 8 | // AnyStringableAdaptor is a type-erased StringableAdaptor that accepts any value. 9 | type AnyStringableAdaptor func(any) (Stringable, error) 10 | 11 | // NewAnyStringableAdaptor wraps adapt so it can be called with an untyped value. A value 12 | // whose dynamic type is not *T yields an error wrapping ErrTypeMismatch. 13 | func NewAnyStringableAdaptor[T any](adapt StringableAdaptor[T]) AnyStringableAdaptor { 14 | return func(v any) (Stringable, error) { 15 | cv, ok := v.(*T) 16 | if !ok { 17 |
return nil, fmt.Errorf("%w: cannot convert %T to %s", ErrTypeMismatch, v, TypeOf[*T]()) 18 | } 19 | return adapt(cv) 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /internal/stringableadaptor_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | type YesNo bool 12 | 13 | func (yn YesNo) ToString() (string, error) { 14 | if yn { 15 | return "yes", nil 16 | } else { 17 | return "no", nil 18 | } 19 | } 20 | 21 | func (yn *YesNo) FromString(s string) error { 22 | switch strings.ToLower(s) { 23 | case "yes": 24 | *yn = true 25 | case "no": 26 | *yn = false 27 | default: 28 | return errors.New("invalid value") 29 | } 30 | return nil 31 | } 32 | 33 | func TestToAnyStringableAdaptor(t *testing.T) { 34 | adaptor := NewAnyStringableAdaptor[bool](func(b *bool) (Stringable, error) { 35 | return (*YesNo)(b), nil 36 | }) 37 | 38 | var validBoolean bool = true 39 | stringable, err := adaptor(&validBoolean) 40 | assert.NoError(t, err) 41 | v, err := stringable.ToString() 42 | assert.NoError(t, err) 43 | assert.Equal(t, "yes", v) 44 | assert.NoError(t, stringable.FromString("no")) 45 | assert.False(t, validBoolean) 46 | 47 | var invalidType int = 0 48 | stringable, err = adaptor(&invalidType) 49 | assert.ErrorIs(t, err, ErrTypeMismatch) 50 | assert.Nil(t, stringable) 51 | assert.ErrorContains(t, err, "cannot convert *int to *bool") 52 | } 53 | -------------------------------------------------------------------------------- /patch/patch.go: -------------------------------------------------------------------------------- 1 | package patch 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | ) 7 | 8 | // Field is a wrapper which can tell if a field was unmarshalled from the data provided.
9 | // When `Field.Valid` is true, `Field.Value` was populated by decoding the raw data. 10 | // Otherwise the field was missing, was the JSON literal null, or failed to decode. 11 | type Field[T any] struct { 12 | Value T 13 | Valid bool 14 | } 15 | 16 | // MarshalJSON encodes Value, or the JSON literal null when the field is not Valid. 17 | func (f Field[T]) MarshalJSON() ([]byte, error) { 18 | if !f.Valid { 19 | return []byte("null"), nil 20 | } 21 | return json.Marshal(f.Value) 22 | } 23 | 24 | // UnmarshalJSON decodes data into Value and sets Valid when decoding succeeds and data is not null; it never resets Valid to false. 25 | func (f *Field[T]) UnmarshalJSON(data []byte) error { 26 | err := json.Unmarshal(data, &f.Value) 27 | if err == nil && !bytes.Equal(data, []byte("null")) { 28 | f.Valid = true 29 | } 30 | return err 31 | } 32 | -------------------------------------------------------------------------------- /patch/patch_test.go: -------------------------------------------------------------------------------- 1 | package patch_test 2 | 3 | import ( 4 | "encoding/json" 5 | "reflect" 6 | "testing" 7 | "time" 8 | 9 | "github.com/ggicci/httpin/patch" 10 | ) 11 | 12 | func shouldBeNil(t *testing.T, err error, failMessage string) { 13 | if err != nil { 14 | t.Logf("%s, got error: %v", failMessage, err) 15 | t.Fail() 16 | } 17 | } 18 | 19 | func shouldResemble(t *testing.T, va, vb any, failMessage string) { 20 | if reflect.DeepEqual(va, vb) { 21 | return 22 | } 23 | t.Logf("%s, expected %#v, got %#v", failMessage, va, vb) 24 | t.Fail() 25 | } 26 | 27 | func fixedZone(offset int) *time.Location { 28 | if offset == 0 { 29 | return time.UTC 30 | } 31 | // _, localOffset := time.Now().Local().Zone() 32 | // if offset == localOffset { 33 | // return time.Local 34 | // } 35 | return time.FixedZone("", offset) 36 | } 37 | 38 | func testJSONMarshalling(t *testing.T, tc testcase) { 39 | bs, err := json.Marshal(tc.Expected) 40 | if err != nil { 41 | t.Logf("marshal failed, got error: %v", err) 42 | t.Fail() 43 | } 44 | if string(bs) != tc.Content { 45 | t.Logf("marshal failed, expected %q, got %q", tc.Content, string(bs)) 46 | t.Fail() 47 | } 48 | } 49 | 50 | func testJSONUnmarshalling(t *testing.T, tc
testcase) { 51 | rt := reflect.TypeOf(tc.Expected) // type: patch.Field 52 | rv := reflect.New(rt) // rv: *patch.Field 53 | 54 | shouldBeNil(t, json.Unmarshal([]byte(tc.Content), rv.Interface()), "unmarshal failed") 55 | shouldResemble(t, rv.Elem().Interface(), tc.Expected, "unmarshal failed") 56 | } 57 | 58 | type testcase struct { 59 | Content string 60 | Expected any 61 | } 62 | 63 | type GitHubProfile struct { 64 | Id int64 `json:"id"` 65 | Login string `json:"login"` 66 | AvatarUrl string `json:"avatar_url"` 67 | } 68 | 69 | type GenderType string 70 | 71 | type Account struct { 72 | Id int64 73 | Email string 74 | Tags []string 75 | Gender GenderType 76 | GitHub *GitHubProfile 77 | } 78 | 79 | type AccountPatch struct { 80 | Email patch.Field[string] `json:"email"` 81 | Tags patch.Field[[]string] `json:"tags"` 82 | Gender patch.Field[GenderType] `json:"gender"` 83 | GitHub patch.Field[*GitHubProfile] `json:"github"` 84 | } 85 | 86 | func TestField(t *testing.T) { 87 | var cases = []testcase{ 88 | {"true", patch.Field[bool]{true, true}}, 89 | {"false", patch.Field[bool]{false, true}}, 90 | {"2045", patch.Field[int]{2045, true}}, 91 | {"127", patch.Field[int8]{127, true}}, 92 | {"32767", patch.Field[int16]{32767, true}}, 93 | {"2147483647", patch.Field[int32]{2147483647, true}}, 94 | {"9223372036854775807", patch.Field[int64]{9223372036854775807, true}}, 95 | {"2045", patch.Field[uint]{2045, true}}, 96 | {"255", patch.Field[uint8]{255, true}}, 97 | {"65535", patch.Field[uint16]{65535, true}}, 98 | {"4294967295", patch.Field[uint32]{4294967295, true}}, 99 | {"18446744073709551615", patch.Field[uint64]{18446744073709551615, true}}, 100 | {"3.14", patch.Field[float32]{3.14, true}}, 101 | {"3.14", patch.Field[float64]{3.14, true}}, 102 | {"\"hello\"", patch.Field[string]{"hello", true}}, 103 | 104 | // Array 105 | {`[true,false]`, patch.Field[[]bool]{[]bool{true, false}, true}}, 106 | {"[1,2,3]", patch.Field[[]int]{[]int{1, 2, 3}, true}}, 107 | {"[1,2,3]",
patch.Field[[]int8]{[]int8{1, 2, 3}, true}}, 108 | {"[1,2,3]", patch.Field[[]int16]{[]int16{1, 2, 3}, true}}, 109 | {"[1,2,3]", patch.Field[[]int32]{[]int32{1, 2, 3}, true}}, 110 | {"[1,2,3]", patch.Field[[]int64]{[]int64{1, 2, 3}, true}}, 111 | {"[1,2,3]", patch.Field[[]uint]{[]uint{1, 2, 3}, true}}, 112 | // NOTE(ggicci): []uint8 is a special case, check TestFieldUint8Array 113 | {"[1,2,3]", patch.Field[[]uint16]{[]uint16{1, 2, 3}, true}}, 114 | {"[1,2,3]", patch.Field[[]uint32]{[]uint32{1, 2, 3}, true}}, 115 | {"[1,2,3]", patch.Field[[]uint64]{[]uint64{1, 2, 3}, true}}, 116 | {"[0.618,1,3.14]", patch.Field[[]float32]{[]float32{0.618, 1, 3.14}, true}}, 117 | {"[0.618,1,3.14]", patch.Field[[]float64]{[]float64{0.618, 1, 3.14}, true}}, 118 | {`["hello","world"]`, patch.Field[[]string]{[]string{"hello", "world"}, true}}, 119 | 120 | // time.Time 121 | { 122 | `"2019-08-25T07:19:34Z"`, 123 | patch.Field[time.Time]{ 124 | time.Date(2019, 8, 25, 7, 19, 34, 0, fixedZone(0)), 125 | true, 126 | }, 127 | }, 128 | { 129 | `"1991-11-10T08:00:00-07:00"`, 130 | patch.Field[time.Time]{ 131 | time.Date(1991, 11, 10, 8, 0, 0, 0, fixedZone(-7*3600)), 132 | true, 133 | }, 134 | }, 135 | { 136 | `"1991-11-10T08:00:00+08:00"`, 137 | patch.Field[time.Time]{ 138 | time.Date(1991, 11, 10, 8, 0, 0, 0, fixedZone(+8*3600)), 139 | true, 140 | }, 141 | }, 142 | 143 | // Custom structs 144 | { 145 | `{"Id":1000,"Email":"ggicci@example.com","Tags":["developer","修勾"],"Gender":"male","GitHub":{"id":3077555,"login":"ggicci","avatar_url":"https://avatars.githubusercontent.com/u/3077555?v=4"}}`, 146 | patch.Field[*Account]{ 147 | &Account{ 148 | Id: 1000, 149 | Email: "ggicci@example.com", 150 | Tags: []string{"developer", "修勾"}, 151 | Gender: "male", 152 | GitHub: &GitHubProfile{ 153 | Id: 3077555, 154 | Login: "ggicci", 155 | AvatarUrl: "https://avatars.githubusercontent.com/u/3077555?v=4", 156 | }, 157 | }, 158 | true, 159 | }, 160 | }, 161 | } 162 | 163 | for _, c := range cases { 164 |
testJSONMarshalling(t, c) 165 | testJSONUnmarshalling(t, c) 166 | } 167 | } 168 | 169 | // TestFieldUint8Array runs JSON marshalling & unmarshalling tests on type Field[[]uint8]. 170 | // It is a separate test because encoding/json encodes []uint8 specially. 171 | // See: https://golang.org/pkg/encoding/json/#Marshal 172 | // 173 | // > Array and slice values encode as JSON arrays, except that []byte encodes 174 | // as a base64-encoded string, and a nil slice encodes as the null JSON 175 | // value. 176 | // 177 | // uint8 the set of all unsigned 8-bit integers (0 to 255) 178 | // byte alias for uint8 179 | func TestFieldUint8Array(t *testing.T) { 180 | var a1 patch.Field[[]uint8] 181 | // unmarshal 182 | shouldBeNil(t, json.Unmarshal([]byte("[1,2,3]"), &a1), "unmarshal Field[[]uint8] failed") 183 | shouldResemble(t, patch.Field[[]uint8]{[]uint8{1, 2, 3}, true}, a1, "unmarshal Field[[]uint8] failed") 184 | 185 | // marshal 186 | var a2 = patch.Field[[]uint8]{[]uint8{1, 2, 3}, true} 187 | out, err := json.Marshal(a2) 188 | shouldBeNil(t, err, "marshal Field[[]uint8] failed") 189 | shouldResemble(t, `"AQID"`, string(out), "marshal Field[[]uint8] failed") 190 | } 191 | 192 | func TestField_UnmarshalJSON_Struct(t *testing.T) { 193 | var testcases = []testcase{ 194 | { 195 | `{"email":"ggicci.2@example.com","tags":["artist","photographer"]}`, 196 | AccountPatch{ 197 | Email: patch.Field[string]{"ggicci.2@example.com", true}, 198 | Gender: patch.Field[GenderType]{"", false}, 199 | Tags: patch.Field[[]string]{[]string{"artist", "photographer"}, true}, 200 | GitHub: patch.Field[*GitHubProfile]{nil, false}, 201 | }, 202 | }, 203 | { 204 | `{"tags":null,"gender":"female","github":{"id":100,"login":"ggicci.2","avatar_url":null}}`, 205 | AccountPatch{ 206 | Email: patch.Field[string]{"", false}, 207 | Gender: patch.Field[GenderType]{"female", true}, 208 | Tags: patch.Field[[]string]{nil, false}, 209 | GitHub: patch.Field[*GitHubProfile]{&GitHubProfile{ 210 | Id: 100, 211 |
Login: "ggicci.2", 212 | AvatarUrl: "", 213 | }, true}, 214 | }, 215 | }, 216 | } 217 | 218 | for _, c := range testcases { 219 | testJSONUnmarshalling(t, c) 220 | } 221 | } 222 | 223 | func TestField_MarshalJSON_Struct(t *testing.T) { 224 | var testcases = []testcase{ 225 | { 226 | `{"email":"hello","tags":null,"gender":null,"github":null}`, 227 | AccountPatch{ 228 | Email: patch.Field[string]{"hello", true}, 229 | Tags: patch.Field[[]string]{nil, false}, 230 | Gender: patch.Field[GenderType]{"", false}, 231 | GitHub: patch.Field[*GitHubProfile]{nil, false}, 232 | }, 233 | }, 234 | } 235 | 236 | for _, c := range testcases { 237 | testJSONMarshalling(t, c) 238 | } 239 | } 240 | --------------------------------------------------------------------------------