├── .github
├── FUNDING.yml
├── dependabot.yml
└── workflows
│ └── go.yml
├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── core
├── body.go
├── body_test.go
├── core.go
├── core_test.go
├── default.go
├── default_test.go
├── directive.go
├── directive_test.go
├── directiveruntime.go
├── error.go
├── errorhandler.go
├── errorhandler_test.go
├── file.go
├── file_test.go
├── fileable.go
├── fileable_test.go
├── fileslicable.go
├── fileslicable_test.go
├── form.go
├── form_test.go
├── formencoder.go
├── formencoder_test.go
├── formextractor.go
├── header.go
├── header_test.go
├── hybridcoder.go
├── hybridcoder_test.go
├── nonzero.go
├── nonzero_test.go
├── omitempty.go
├── option.go
├── option_test.go
├── owl.go
├── patch_test.go
├── path.go
├── path_test.go
├── query.go
├── query_test.go
├── registry.go
├── requestbuilder.go
├── requestbuilder_test.go
├── required.go
├── required_test.go
├── stringable.go
├── stringable_test.go
├── stringslicable.go
├── stringslicable_test.go
└── typekind.go
├── go.mod
├── go.sum
├── httpin.go
├── httpin_test.go
├── integration
├── echo.go
├── echo_test.go
├── gochi.go
├── gochi_test.go
├── gorilla.go
├── gorilla_test.go
├── http.go
└── http_test.go
├── internal
├── misc.go
├── misc_test.go
├── stringable.go
├── stringable_test.go
├── stringableadaptor.go
└── stringableadaptor_test.go
└── patch
├── patch.go
└── patch_test.go
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: ggicci
4 | patreon: # ggicci
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: # ggicci
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | otechie: # Replace with a single Otechie username
12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
14 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: gomod
4 | directory: /
5 | schedule:
6 | interval: daily
7 | - package-ecosystem: github-actions
8 | directory: /
9 | schedule:
10 | interval: daily
11 |
--------------------------------------------------------------------------------
/.github/workflows/go.yml:
--------------------------------------------------------------------------------
1 | name: Go
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 |
9 | jobs:
10 | test:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - uses: actions/checkout@v4
14 |
15 | - name: Set up Go
16 | uses: actions/setup-go@v5
17 | with:
18 | go-version: "1.23"
19 |
20 | - name: Test and generate coverage report
21 | run: make test/cover
22 |
23 | - name: Upload coverage to Codecov
24 | uses: codecov/codecov-action@v5
25 | with:
26 | name: codecov-umbrella
27 | token: ${{ secrets.CODECOV_TOKEN }}
28 | files: ./main.cover.out
29 | flags: unittests
30 | fail_ci_if_error: true
31 | verbose: true
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Binaries for programs and plugins
2 | *.exe
3 | *.exe~
4 | *.dll
5 | *.so
6 | *.dylib
7 |
8 | # Test binary, built with `go test -c`
9 | *.test
10 |
11 | # Output of the go coverage tool, specifically when used with LiteIDE
12 | *.out
13 |
14 | # Dependency directories (remove the comment below to include it)
15 | # vendor/
16 | .vscode
17 |
18 | # Project specified
19 | .DS_Store
20 | __debug*
21 | docs/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Ggicci
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | default: test
2 |
3 | SHELL=/usr/bin/env bash
4 | GO=go
5 | GOTEST=$(GO) test
6 | GOCOVER=$(GO) tool cover
7 |
8 | .PHONY: test
9 | test: test/cover test/report
10 |
11 | .PHONY: test/cover
12 | test/cover:
13 | $(GOTEST) -v -race -failfast -parallel 4 -cpu 4 -coverprofile main.cover.out ./...
14 |
15 | .PHONY: test/report
16 | test/report:
17 | if [[ "$$HOSTNAME" =~ "codespaces-"* ]]; then \
18 | mkdir -p /tmp/httpin_test; \
19 | $(GOCOVER) -html=main.cover.out -o /tmp/httpin_test/coverage.html; \
20 | sudo python -m http.server -d /tmp/httpin_test -b localhost 80; \
21 | else \
22 | $(GOCOVER) -html=main.cover.out; \
23 | fi
24 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # httpin - HTTP Input for Go
6 |
7 |
HTTP Request from/to Go Struct
8 |
9 |
10 |
11 | [](https://github.com/ggicci/httpin/actions/workflows/go.yml) [](https://github.com/ggicci/httpin/actions/workflows/documentation.yml) [](https://codecov.io/gh/ggicci/httpin) [](https://goreportcard.com/report/github.com/ggicci/httpin) [](https://github.com/avelino/awesome-go) [](https://pkg.go.dev/github.com/ggicci/httpin)
12 |
13 |
27 |
28 |
29 |
30 | ## Core Features
31 |
32 | **httpin** helps you easily decode data from an HTTP request, including:
33 |
34 | - **Query parameters**, e.g. `?name=john&is_member=true`
35 | - **Headers**, e.g. `Authorization: xxx`
36 | - **Form data**, e.g. `username=john&password=******`
37 | - **JSON/XML Body**, e.g. `POST {"name":"john"}`
38 | - **Path variables**, e.g. `/users/{username}`
39 | - **File uploads**
40 |
41 | You **only** need to define a struct to receive/bind data from an HTTP request, **without** writing any parsing code yourself.
42 |
43 | Since v0.15.0, httpin also supports creating an HTTP request (`http.Request`) from a Go struct instance.
44 |
45 | **httpin** is:
46 |
47 | - **well documented**: at https://ggicci.github.io/httpin/
48 | - **easily integrated**: with [net/http](https://ggicci.github.io/httpin/integrations/http), [go-chi/chi](https://ggicci.github.io/httpin/integrations/gochi), [gorilla/mux](https://ggicci.github.io/httpin/integrations/gorilla), [gin-gonic/gin](https://ggicci.github.io/httpin/integrations/gin), etc.
49 | - **extensible** (advanced feature): by adding your custom directives. Read [httpin - custom directives](https://ggicci.github.io/httpin/directives/custom) for more details.
50 |
51 | ## Add Httpin Directives by Tagging the Struct Fields with `in`
52 |
53 | ```go
54 | type ListUsersInput struct {
55 | Token string `in:"query=access_token;header=x-access-token"`
56 | Page int `in:"query=page;default=1"`
57 | PerPage int `in:"query=per_page;default=20"`
58 | IsMember bool `in:"query=is_member"`
59 | Search *string `in:"query=search;omitempty"`
60 | }
61 | ```
62 |
63 | ## How to decode an HTTP request into a Go struct?
64 |
65 | ```go
66 | func ListUsers(rw http.ResponseWriter, r *http.Request) {
67 | input := r.Context().Value(httpin.Input).(*ListUsersInput)
68 |
69 | if input.IsMember {
70 | // Do sth.
71 | }
72 | // Do sth.
73 | }
74 | ```
75 |
76 | ## How to encode a Go struct into an HTTP request?
77 |
78 | ```go
79 | func SDKListUsers() {
80 | payload := &ListUsersInput{
81 | Token: os.Getenv("MY_APP_ACCESS_TOKEN"),
82 | Page: 2,
83 | IsMember: true,
84 | }
85 |
86 | // Easy to remember, http.NewRequest -> httpin.NewRequest
87 | req, err := httpin.NewRequest("GET", "/users", payload)
88 | // ...
89 | }
90 | ```
91 |
92 | ## Why this package?
93 |
94 | #### Compared with using `net/http` package
95 |
96 | ```go
97 | func ListUsers(rw http.ResponseWriter, r *http.Request) {
98 | page, err := strconv.ParseInt(r.FormValue("page"), 10, 64)
99 | if err != nil {
100 | // Invalid parameter: page.
101 | return
102 | }
103 | perPage, err := strconv.ParseInt(r.FormValue("per_page"), 10, 64)
104 | if err != nil {
105 | // Invalid parameter: per_page.
106 | return
107 | }
108 | isMember, err := strconv.ParseBool(r.FormValue("is_member"))
109 | if err != nil {
110 | // Invalid parameter: is_member.
111 | return
112 | }
113 |
114 | // Do sth.
115 | }
116 | ```
117 |
118 | | Benefits | Before (use net/http package) | After (use ggicci/httpin package) |
119 | | ----------------------- | ------------------------------------------ | ---------------------------------------------------------------------------------------------- |
120 | | ⌛️ Developer Time | 😫 Expensive (too much parsing code) | 🚀 **Faster** (define the struct for receiving input data and leave the parsing job to httpin) |
121 | | ♻️ Code Repetition Rate | 😞 High | 😍 **Lower** |
122 | | 📖 Code Readability | 😟 Poor | 🤩 **Highly readable** |
123 | | 🔨 Maintainability | 😡 Poor | 🥰 **Highly maintainable** |
124 |
125 | ## Alternatives and Similar Projects
126 |
127 | - [gorilla/schema](https://github.com/gorilla/schema): converts structs to and from form values
128 | - [google/go-querystring](https://github.com/google/go-querystring): encoding structs into URL query parameters
129 |
--------------------------------------------------------------------------------
/core/body.go:
--------------------------------------------------------------------------------
1 | // directive: "body"
2 | // https://ggicci.github.io/httpin/directives/body
3 |
4 | package core
5 |
6 | import (
7 | "bytes"
8 | "encoding/json"
9 | "encoding/xml"
10 | "errors"
11 | "fmt"
12 | "io"
13 | "strings"
14 |
15 | "github.com/ggicci/httpin/internal"
16 | )
17 |
var (
	// ErrUnknownBodyFormat is returned by the body directive when no
	// BodySerializer has been registered for the requested body format.
	ErrUnknownBodyFormat = errors.New("unknown body format")
)

// DirectiveBody implements the "body" directive, which reads or writes a
// whole field as the HTTP request body through a registered BodySerializer.
type DirectiveBody struct{}
23 |
24 | func (db *DirectiveBody) Decode(rtm *DirectiveRuntime) error {
25 | req := rtm.GetRequest()
26 | bodyFormat, bodySerializer := db.getSerializer(rtm)
27 | if bodySerializer == nil {
28 | return fmt.Errorf("%w: %q", ErrUnknownBodyFormat, bodyFormat)
29 | }
30 | if err := bodySerializer.Decode(req.Body, rtm.Value.Elem().Addr().Interface()); err != nil {
31 | return err
32 | }
33 | return nil
34 | }
35 |
36 | func (db *DirectiveBody) Encode(rtm *DirectiveRuntime) error {
37 | bodyFormat, bodySerializer := db.getSerializer(rtm)
38 | if bodySerializer == nil {
39 | return fmt.Errorf("%w: %q", ErrUnknownBodyFormat, bodyFormat)
40 | }
41 | if bodyReader, err := bodySerializer.Encode(rtm.Value.Interface()); err != nil {
42 | return err
43 | } else {
44 | rtm.GetRequestBuilder().SetBody(bodyFormat, io.NopCloser(bodyReader))
45 | rtm.MarkFieldSet(true)
46 | return nil
47 | }
48 | }
49 |
50 | func (*DirectiveBody) getSerializer(rtm *DirectiveRuntime) (bodyFormat string, serializer BodySerializer) {
51 | bodyFormat = "json"
52 | if len(rtm.Directive.Argv) > 0 {
53 | bodyFormat = strings.ToLower(rtm.Directive.Argv[0])
54 | }
55 | serializer = getBodySerializer(bodyFormat)
56 | return
57 | }
58 |
59 | var bodyFormats = map[string]BodySerializer{
60 | "json": &JSONBody{},
61 | "xml": &XMLBody{},
62 | }
63 |
64 | // BodySerializer is the interface for encoding and decoding the request body.
65 | // Common body formats are: json, xml, yaml, etc.
66 | type BodySerializer interface {
67 | // Decode decodes the request body into the specified object.
68 | Decode(src io.Reader, dst any) error
69 | // Encode encodes the specified object into a reader for the request body.
70 | Encode(src any) (io.Reader, error)
71 | }
72 |
73 | // RegisterBodyFormat registers a new data formatter for the body request, which has the
74 | // BodyEncoderDecoder interface implemented. Panics on taken name, empty name or nil
75 | // decoder. Pass parameter force (true) to ignore the name conflict.
76 | //
77 | // The BodyEncoderDecoder is used by the body directive to decode and encode the data in
78 | // the given format (body format).
79 | //
80 | // It is also useful when you want to override the default registered
81 | // BodyEncoderDecoder. For example, the default JSON decoder is borrowed from
82 | // encoding/json. You can replace it with your own implementation, e.g.
83 | // json-iterator/go. For example:
84 | //
85 | // func init() {
86 | // RegisterBodyFormat("json", &myJSONBody{}, true) // force register, replace the old one
87 | // RegisterBodyFormat("yaml", &myYAMLBody{}) // register a new body format "yaml"
88 | // }
89 | func RegisterBodyFormat(format string, body BodySerializer, force ...bool) {
90 | internal.PanicOnError(
91 | registerBodyFormat(format, body, force...),
92 | )
93 | }
94 |
95 | func getBodySerializer(bodyFormat string) BodySerializer {
96 | return bodyFormats[bodyFormat]
97 | }
98 |
// JSONBody is the built-in BodySerializer for the "json" body format, backed by
// encoding/json.
type JSONBody struct{}

// Decode decodes one JSON value read from src into dst.
func (*JSONBody) Decode(src io.Reader, dst any) error {
	return json.NewDecoder(src).Decode(dst)
}

// Encode marshals src as JSON and returns a reader over the result. As with
// json.Encoder, the output is terminated by a newline.
func (*JSONBody) Encode(src any) (io.Reader, error) {
	var buf bytes.Buffer
	if err := json.NewEncoder(&buf).Encode(src); err != nil {
		return nil, err
	}
	return &buf, nil
}

// XMLBody is the built-in BodySerializer for the "xml" body format, backed by
// encoding/xml.
type XMLBody struct{}

// Decode decodes one XML element read from src into dst.
func (*XMLBody) Decode(src io.Reader, dst any) error {
	return xml.NewDecoder(src).Decode(dst)
}

// Encode marshals src as XML and returns a reader over the result.
func (*XMLBody) Encode(src any) (io.Reader, error) {
	var buf bytes.Buffer
	if err := xml.NewEncoder(&buf).Encode(src); err != nil {
		return nil, err
	}
	return &buf, nil
}
126 |
127 | func registerBodyFormat(format string, body BodySerializer, force ...bool) error {
128 | ignoreConflict := len(force) > 0 && force[0]
129 | format = strings.ToLower(format)
130 | if !ignoreConflict && getBodySerializer(format) != nil {
131 | return fmt.Errorf("duplicate body format: %q", format)
132 | }
133 | if format == "" {
134 | return errors.New("body format cannot be empty")
135 | }
136 | if body == nil {
137 | return errors.New("body serializer cannot be nil")
138 | }
139 | bodyFormats[format] = body
140 | return nil
141 | }
142 |
--------------------------------------------------------------------------------
/core/body_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "bytes"
5 | "encoding/json"
6 | "encoding/xml"
7 | "errors"
8 | "io"
9 | "net/http"
10 | "net/url"
11 | "strings"
12 | "testing"
13 |
14 | "github.com/stretchr/testify/assert"
15 | )
16 |
// LanguageLevel is a nested element of BodyPayload; it exercises slices of
// struct pointers in both JSON and XML bodies.
type LanguageLevel struct {
	Language string `json:"lang" xml:"lang"`
	Level    int    `json:"level" xml:"level"`
}

// BodyPayload is the value carried in the request body by the test inputs.
type BodyPayload struct {
	Name      string           `json:"name" xml:"name"`
	Age       int              `json:"age" xml:"age"`
	IsNative  bool             `json:"is_native" xml:"is_native"`
	Hobbies   []string         `json:"hobbies" xml:"hobbies"`
	Languages []*LanguageLevel `json:"languages" xml:"languages"`
}

// BodyPayloadInJSON carries its Body field as a JSON request body.
type BodyPayloadInJSON struct {
	Body *BodyPayload `in:"body=json"`
}

// BodyPayloadInXML carries its Body field as an XML request body.
type BodyPayloadInXML struct {
	Body *BodyPayload `in:"body=xml"`
}

// YouCannotUseFormAndBodyAtTheSameTime mixes "form" and "body" directives; the
// tests expect encoding such a struct into a request to be rejected.
type YouCannotUseFormAndBodyAtTheSameTime struct {
	Page     int          `in:"form=page"`
	PageSize int          `in:"form=page_size"`
	Body     *BodyPayload `in:"body=json"`
}

// sampleBodyPayloadInJSONText is the JSON form of sampleBodyPayloadInJSONObject.Body.
var sampleBodyPayloadInJSONText = `
{
	"name": "Elia",
	"is_native": false,
	"age": 14,
	"hobbies": ["Gaming", "Drawing"],
	"languages": [
		{"lang": "English", "level": 10},
		{"lang": "Japanese", "level": 3}
	]
}`

var sampleBodyPayloadInJSONObject = &BodyPayloadInJSON{
	Body: &BodyPayload{
		Name:     "Elia",
		Age:      14,
		IsNative: false,
		Hobbies:  []string{"Gaming", "Drawing"},
		Languages: []*LanguageLevel{
			{"English", 10},
			{"Japanese", 3},
		},
	},
}

// sampleBodyPayloadInXMLText is the XML form of sampleBodyPayloadInXMLObject.Body.
// Repeated <hobbies> and <languages> elements decode into the slice fields.
var sampleBodyPayloadInXMLText = `
<BodyPayload>
	<name>Elia</name>
	<age>14</age>
	<is_native>false</is_native>
	<hobbies>Gaming</hobbies>
	<hobbies>Drawing</hobbies>
	<languages>
		<lang>English</lang>
		<level>10</level>
	</languages>
	<languages>
		<lang>Japanese</lang>
		<level>3</level>
	</languages>
</BodyPayload>`

var sampleBodyPayloadInXMLObject = &BodyPayloadInXML{
	Body: &BodyPayload{
		Name:     "Elia",
		Age:      14,
		IsNative: false,
		Hobbies:  []string{"Gaming", "Drawing"},
		Languages: []*LanguageLevel{
			{"English", 10},
			{"Japanese", 3},
		},
	},
}
98 |
99 | func TestBodyDirective_Decode_JSON(t *testing.T) {
100 | assert := assert.New(t)
101 | co, err := New(BodyPayloadInJSON{})
102 | assert.NoError(err)
103 |
104 | r, _ := http.NewRequest("GET", "https://example.com", nil)
105 | r.Form = make(url.Values)
106 | r.Form.Set("page", "4")
107 | r.Form.Set("page_size", "30")
108 | r.Body = makeBodyReader(sampleBodyPayloadInJSONText)
109 | r.Header.Set("Content-Type", "application/json")
110 | gotValue, err := co.Decode(r)
111 | assert.NoError(err)
112 | assert.Equal(sampleBodyPayloadInJSONObject, gotValue)
113 | }
114 |
115 | func TestBodyDirective_Decode_XML(t *testing.T) {
116 | assert := assert.New(t)
117 | co, err := New(BodyPayloadInXML{})
118 | assert.NoError(err)
119 |
120 | r, _ := http.NewRequest("GET", "https://example.com", nil)
121 | r.Body = makeBodyReader(sampleBodyPayloadInXMLText)
122 | r.Header.Set("Content-Type", "application/xml")
123 |
124 | gotValue, err := co.Decode(r)
125 | assert.NoError(err)
126 | assert.Equal(sampleBodyPayloadInXMLObject, gotValue)
127 | }
128 |
129 | func TestBodyDirective_Decode_ErrUnknownBodyFormat(t *testing.T) {
130 | type UnknownBodyFormatPayload struct {
131 | Body *BodyPayload `in:"body=yaml"`
132 | }
133 |
134 | co, err := New(UnknownBodyFormatPayload{})
135 | assert.NoError(t, err)
136 | req, _ := http.NewRequest("GET", "https://example.com", nil)
137 | req.Body = makeBodyReader(sampleBodyPayloadInJSONText)
138 | _, err = co.Decode(req)
139 | assert.ErrorContains(t, err, "unknown body format: \"yaml\"")
140 | }
141 |
142 | func TestBodyDirective_Decode_ErrConflictWithFormDirective(t *testing.T) {
143 | co, err := New(YouCannotUseFormAndBodyAtTheSameTime{})
144 | assert.NoError(t, err)
145 | req, err := co.NewRequest("POST", "/data", &YouCannotUseFormAndBodyAtTheSameTime{
146 | Page: 1,
147 | PageSize: 20,
148 | Body: sampleBodyPayloadInJSONObject.Body,
149 | })
150 | assert.ErrorContains(t, err, "cannot use both form and body directive at the same time")
151 | assert.Nil(t, req)
152 | }
153 |
// errYamlNotImplemented is what every yamlBody method returns.
var errYamlNotImplemented = errors.New("yaml not implemented")

// yamlBody is a test-only BodySerializer stub for the "yaml" body format. Both
// of its methods fail with errYamlNotImplemented.
type yamlBody struct{}

func (*yamlBody) Decode(io.Reader, any) error { return errYamlNotImplemented }

func (*yamlBody) Encode(any) (io.Reader, error) { return nil, errYamlNotImplemented }

// YamlInput reads its Body field through the "yaml" body format.
type YamlInput struct {
	Body map[string]any `in:"body=yaml"`
}
169 |
170 | func TestRegisterBodyFormat(t *testing.T) {
171 | assert.NotPanics(t, func() {
172 | RegisterBodyFormat("yaml", &yamlBody{})
173 | })
174 | assert.Panics(t, func() {
175 | RegisterBodyFormat("yaml", &yamlBody{})
176 | })
177 |
178 | co, err := New(YamlInput{})
179 | assert.NoError(t, err)
180 |
181 | r, _ := http.NewRequest("GET", "https://example.com", nil)
182 | r.Body = makeBodyReader(`version: "3"`)
183 |
184 | gotValue, err := co.Decode(r)
185 | assert.ErrorIs(t, err, errYamlNotImplemented)
186 | assert.Nil(t, gotValue)
187 | unregisterBodyFormat("yaml")
188 | }
189 |
190 | func TestRegisterBodyFormat_ErrNilBodySerializer(t *testing.T) {
191 | assert.Panics(t, func() {
192 | RegisterBodyFormat("toml", nil)
193 | })
194 | }
195 |
196 | func TestRegisterBodyFormat_ForceRegister(t *testing.T) {
197 | assert.NotPanics(t, func() {
198 | RegisterBodyFormat("yaml", &yamlBody{}, true)
199 | })
200 | assert.NotPanics(t, func() {
201 | RegisterBodyFormat("yaml", &yamlBody{}, true)
202 | })
203 | unregisterBodyFormat("yaml")
204 | }
205 |
206 | func TestRegisterBodyFormat_ForceRegisterWithEmptyBodyFormat(t *testing.T) {
207 | assert.PanicsWithError(t, "httpin: body format cannot be empty", func() {
208 | RegisterBodyFormat("", &yamlBody{}, true)
209 | })
210 | }
211 |
212 | func TestBodyDirective_NewRequest_JSON(t *testing.T) {
213 | assert := assert.New(t)
214 | co, err := New(BodyPayloadInJSON{})
215 | assert.NoError(err)
216 | req, err := co.NewRequest("POST", "/data", sampleBodyPayloadInJSONObject)
217 | expected, _ := http.NewRequest("POST", "/data", nil)
218 | expected.Header.Set("Content-Type", "application/json")
219 | assert.NoError(err)
220 | var body bytes.Buffer
221 | assert.NoError(json.NewEncoder(&body).Encode(sampleBodyPayloadInJSONObject.Body))
222 | expected.Body = io.NopCloser(&body)
223 | assert.Equal(expected, req)
224 |
225 | // On the server side (decode).
226 | gotValue, err := co.Decode(req)
227 | assert.NoError(err)
228 | got, ok := gotValue.(*BodyPayloadInJSON)
229 | assert.True(ok)
230 | assert.Equal(sampleBodyPayloadInJSONObject, got)
231 | }
232 |
233 | func TestBodyDirective_NewRequest_XML(t *testing.T) {
234 | assert := assert.New(t)
235 | co, err := New(BodyPayloadInXML{})
236 | assert.NoError(err)
237 | req, err := co.NewRequest("POST", "/data", sampleBodyPayloadInXMLObject)
238 | expected, _ := http.NewRequest("POST", "/data", nil)
239 | expected.Header.Set("Content-Type", "application/xml")
240 | assert.NoError(err)
241 | var body bytes.Buffer
242 | assert.NoError(xml.NewEncoder(&body).Encode(sampleBodyPayloadInXMLObject.Body))
243 | expected.Body = io.NopCloser(&body)
244 | assert.Equal(expected, req)
245 |
246 | // On the server side (decode).
247 | gotValue, err := co.Decode(req)
248 | assert.NoError(err)
249 | got, ok := gotValue.(*BodyPayloadInXML)
250 | assert.True(ok)
251 | assert.Equal(sampleBodyPayloadInXMLObject, got)
252 | }
253 |
254 | func TestBodyDirective_NewRequest_ErrUnknownBodyFormat(t *testing.T) {
255 | type UnknownBodyFormatPayload struct {
256 | Body *BodyPayload `in:"body=yaml"`
257 | }
258 | query := &UnknownBodyFormatPayload{
259 | Body: nil,
260 | }
261 | co, err := New(UnknownBodyFormatPayload{})
262 | assert.NoError(t, err)
263 | req, err := co.NewRequest("PUT", "/apples/10", query)
264 | assert.ErrorContains(t, err, "unknown body format: \"yaml\"")
265 | assert.Nil(t, req)
266 | }
267 |
268 | func unregisterBodyFormat(format string) {
269 | delete(bodyFormats, format)
270 | }
271 |
// makeBodyReader returns text as an io.ReadCloser whose Close is a no-op,
// suitable for assigning to http.Request.Body in tests.
func makeBodyReader(text string) io.ReadCloser {
	reader := strings.NewReader(text)
	return io.NopCloser(reader)
}
275 |
--------------------------------------------------------------------------------
/core/core.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "mime"
8 | "net/http"
9 | "reflect"
10 | "sort"
11 | "sync"
12 |
13 | "github.com/ggicci/owl"
14 | )
15 |
// builtResolvers caches the normalized resolver for each input struct type.
// It is read and written by buildResolver.
var builtResolvers sync.Map // map[reflect.Type]*owl.Resolver

// Core is the core of httpin. It holds the resolver of one specific struct
// type and uses it both to decode HTTP requests into instances of that type
// and to encode instances of that type into HTTP requests.
//
// Core contains a sync.RWMutex field, so a Core must not be copied after first
// use; New returns a *Core.
type Core struct {
	resolver               *owl.Resolver // for decoding
	scanResolver           *owl.Resolver // for encoding; created lazily by prepareScanResolver under resolverMu
	errorHandler           ErrorHandler  // per-Core handler; GetErrorHandler falls back to the global one when nil
	maxMemory              int64         // in bytes; passed to ParseMultipartForm by parseRequestForm
	enableNestedDirectives bool          // forwarded to owl.WithNestedDirectivesEnabled when decoding/encoding
	resolverMu             sync.RWMutex  // guards the lazy initialization of scanResolver
}
29 |
30 | // New creates a new Core instance for the given intpuStruct. It will build a resolver
31 | // for the inputStruct and apply the given options to the Core instance. The Core instance
32 | // is responsible for both:
33 | //
34 | // - decoding an HTTP request to an instance of the inputStruct;
35 | // - encoding an instance of the inputStruct to an HTTP request.
36 | func New(inputStruct any, opts ...Option) (*Core, error) {
37 | resolver, err := buildResolver(inputStruct)
38 | if err != nil {
39 | return nil, err
40 | }
41 |
42 | core := &Core{
43 | resolver: resolver,
44 | }
45 |
46 | // Apply default options and user custom options to the
47 | var allOptions []Option
48 | defaultOptions := []Option{
49 | WithMaxMemory(defaultMaxMemory),
50 | WithNestedDirectivesEnabled(globalNestedDirectivesEnabled),
51 | }
52 | allOptions = append(allOptions, defaultOptions...)
53 | allOptions = append(allOptions, opts...)
54 |
55 | for _, opt := range allOptions {
56 | if err := opt(core); err != nil {
57 | return nil, fmt.Errorf("invalid option: %w", err)
58 | }
59 | }
60 |
61 | return core, nil
62 | }
63 |
64 | // Decode decodes an HTTP request to an instance of the input struct and returns
65 | // its pointer. For example:
66 | //
67 | // New(Input{}).Decode(req) -> *Input
68 | func (c *Core) Decode(req *http.Request) (any, error) {
69 | // Create the input struct instance. Used to be created by owl.Resolve().
70 | value := reflect.New(c.resolver.Type).Interface()
71 | if err := c.DecodeTo(req, value); err != nil {
72 | return nil, err
73 | } else {
74 | return value, nil
75 | }
76 | }
77 |
78 | // DecodeTo decodes an HTTP request to the given value. The value must be a pointer
79 | // to the struct instance of the type that the Core instance holds.
80 | func (c *Core) DecodeTo(req *http.Request, value any) (err error) {
81 | if err = c.parseRequestForm(req); err != nil {
82 | return fmt.Errorf("failed to parse request form: %w", err)
83 | }
84 |
85 | err = c.resolver.ResolveTo(
86 | value,
87 | owl.WithNamespace(decoderNamespace),
88 | owl.WithValue(CtxRequest, req),
89 | owl.WithNestedDirectivesEnabled(c.enableNestedDirectives),
90 | )
91 | if err != nil && !errors.Is(err, owl.ErrInvalidResolveTarget) {
92 | return NewInvalidFieldError(err)
93 | }
94 | return err
95 | }
96 |
97 | // NewRequest wraps NewRequestWithContext using context.Background(), see
98 | // NewRequestWithContext.
99 | func (c *Core) NewRequest(method string, url string, input any) (*http.Request, error) {
100 | return c.NewRequestWithContext(context.Background(), method, url, input)
101 | }
102 |
103 | // NewRequestWithContext turns the given input struct into an HTTP request. Note
104 | // that the Core instance is bound to a specific type of struct. Which means
105 | // when the given input is not the type of the struct that the Core instance
106 | // holds, error of type mismatch will be returned. In order to avoid this error,
107 | // you can always use httpin.NewRequest() instead. Which will create a Core
108 | // instance for you on demand. There's no performance penalty for doing so.
109 | // Because there's a cache layer for all the Core instances.
110 | func (c *Core) NewRequestWithContext(ctx context.Context, method string, url string, input any) (*http.Request, error) {
111 | c.prepareScanResolver()
112 | req, err := http.NewRequestWithContext(ctx, method, url, nil)
113 | if err != nil {
114 | return nil, err
115 | }
116 |
117 | rb := NewRequestBuilder(ctx)
118 |
119 | // NOTE(ggicci): the error returned a joined error by using errors.Join.
120 | if err = c.scanResolver.Scan(
121 | input,
122 | owl.WithNamespace(encoderNamespace),
123 | owl.WithValue(CtxRequestBuilder, rb),
124 | owl.WithNestedDirectivesEnabled(c.enableNestedDirectives),
125 | ); err != nil {
126 | // err is a list of *owl.ScanError that joined by errors.Join.
127 | if errs, ok := err.(interface{ Unwrap() []error }); ok {
128 | var invalidFieldErrors MultiInvalidFieldError
129 | for _, err := range errs.Unwrap() {
130 | invalidFieldErrors = append(invalidFieldErrors, NewInvalidFieldError(err))
131 | }
132 | return nil, invalidFieldErrors
133 | } else {
134 | return nil, err // should never happen, just in case
135 | }
136 | }
137 |
138 | // Populate the request with the encoded values.
139 | if err := rb.Populate(req); err != nil {
140 | return nil, fmt.Errorf("failed to populate request: %w", err)
141 | }
142 |
143 | return req, nil
144 | }
145 |
146 | // GetErrorHandler returns the error handler of the core if set, or the global
147 | // custom error handler.
148 | func (c *Core) GetErrorHandler() ErrorHandler {
149 | if c.errorHandler != nil {
150 | return c.errorHandler
151 | }
152 |
153 | return globalCustomErrorHandler
154 | }
155 |
156 | func (c *Core) prepareScanResolver() {
157 | c.resolverMu.RLock()
158 | if c.scanResolver == nil {
159 | c.resolverMu.RUnlock()
160 | c.resolverMu.Lock()
161 | defer c.resolverMu.Unlock()
162 |
163 | if c.scanResolver == nil {
164 | c.scanResolver = c.resolver.Copy()
165 |
166 | // Reorder the directives to make sure the "default" and "nonzero" directives work properly.
167 | c.scanResolver.Iterate(func(r *owl.Resolver) error {
168 | sort.Sort(directiveOrderForEncoding(r.Directives))
169 | return nil
170 | })
171 | }
172 | } else {
173 | c.resolverMu.RUnlock()
174 | }
175 | }
176 |
177 | func (c *Core) parseRequestForm(req *http.Request) (err error) {
178 | ct, _, _ := mime.ParseMediaType(req.Header.Get("Content-Type"))
179 | if ct == "multipart/form-data" {
180 | err = req.ParseMultipartForm(c.maxMemory)
181 | } else {
182 | err = req.ParseForm()
183 | }
184 | return
185 | }
186 |
187 | // buildResolver builds a resolver for the inputStruct. It will run normalizations
188 | // on the resolver and cache it.
189 | func buildResolver(inputStruct any) (*owl.Resolver, error) {
190 | resolver, err := owl.New(inputStruct)
191 | if err != nil {
192 | return nil, err
193 | }
194 |
195 | // Returns the cached resolver if it's already built.
196 | if cached, ok := builtResolvers.Load(resolver.Type); ok {
197 | return cached.(*owl.Resolver), nil
198 | }
199 |
200 | // Normalize the resolver before caching it.
201 | if err := normalizeResolver(resolver); err != nil {
202 | return nil, err
203 | }
204 |
205 | // Cache the resolver.
206 | builtResolvers.Store(resolver.Type, resolver)
207 | return resolver, nil
208 | }
209 |
210 | // normalizeResolver normalizes the resolvers by running a series of
211 | // normalizations on every field resolver.
212 | func normalizeResolver(r *owl.Resolver) error {
213 | normalize := func(r *owl.Resolver) error {
214 | for _, fn := range []func(*owl.Resolver) error{
215 | removeDecoderDirective, // backward compatibility, use "coder" instead
216 | removeCoderDirective, // "coder" takes precedence over "decoder"
217 | ensureDirectiveExecutorsRegistered, // always the last one
218 | } {
219 | if err := fn(r); err != nil {
220 | return err
221 | }
222 | }
223 | return nil
224 | }
225 |
226 | return r.Iterate(normalize)
227 | }
228 |
229 | func removeDecoderDirective(r *owl.Resolver) error {
230 | return reserveCoderDirective(r, "decoder")
231 | }
232 |
233 | func removeCoderDirective(r *owl.Resolver) error {
234 | return reserveCoderDirective(r, "coder")
235 | }
236 |
237 | // reserveCoderDirective removes the directive from the resolver. name is "coder" or "decoder".
238 | // The "decoder"/"coder"are two special directives which do nothing, but an indicator of
239 | // overriding the decoder and encoder for a specific field.
240 | func reserveCoderDirective(r *owl.Resolver, name string) error {
241 | d := r.RemoveDirective(name)
242 | if d == nil {
243 | return nil
244 | }
245 | if len(d.Argv) == 0 {
246 | return fmt.Errorf("directive %s: missing coder name", name)
247 | }
248 |
249 | if isFileType(r.Type) {
250 | return fmt.Errorf("directive %s: cannot be used on a file type field", name)
251 | }
252 |
253 | namedAdaptor := namedStringableAdaptors[d.Argv[0]]
254 | if namedAdaptor == nil {
255 | return fmt.Errorf("directive %s: %w: %q", name, ErrUnregisteredCoder, d.Argv[0])
256 | }
257 |
258 | r.Context = context.WithValue(r.Context, CtxCustomCoder, namedAdaptor)
259 | return nil
260 | }
261 |
262 | // ensureDirectiveExecutorsRegistered ensures all directives that defined in the
263 | // resolver are registered in the executor registry.
264 | func ensureDirectiveExecutorsRegistered(r *owl.Resolver) error {
265 | for _, d := range r.Directives {
266 | if decoderNamespace.LookupExecutor(d.Name) == nil {
267 | return fmt.Errorf("%w: %q", ErrUnregisteredDirective, d.Name)
268 | }
269 | // NOTE: don't need to check encoderNamespace because a directive
270 | // will always be registered in both namespaces. See RegisterDirective().
271 | }
272 | return nil
273 | }
274 |
275 | type directiveOrderForEncoding []*owl.Directive
276 |
277 | func (d directiveOrderForEncoding) Len() int {
278 | return len(d)
279 | }
280 |
281 | func (d directiveOrderForEncoding) Less(i, j int) bool {
282 | if d[i].Name == "default" {
283 | return true // always the first one to run
284 | } else if d[i].Name == "nonzero" {
285 | return true // always the second one to run
286 | }
287 | return false
288 | }
289 |
290 | func (d directiveOrderForEncoding) Swap(i, j int) {
291 | d[i], d[j] = d[j], d[i]
292 | }
293 |
--------------------------------------------------------------------------------
/core/default.go:
--------------------------------------------------------------------------------
1 | // directive: "default"
2 | // https://ggicci.github.io/httpin/directives/default
3 |
4 | package core
5 |
6 | import (
7 | "mime/multipart"
8 | )
9 |
10 | type DirectiveDefault struct{}
11 |
12 | func (*DirectiveDefault) Decode(rtm *DirectiveRuntime) error {
13 | if rtm.IsFieldSet() {
14 | return nil // noop, the field was set by a former executor
15 | }
16 |
17 | // Transform:
18 | // 1. ctx.Argv -> input values
19 | // 2. ["default"] -> keys
20 | extractor := &FormExtractor{
21 | Runtime: rtm,
22 | Form: multipart.Form{
23 | Value: map[string][]string{
24 | "default": rtm.Directive.Argv,
25 | },
26 | },
27 | }
28 | return extractor.Extract("default")
29 | }
30 |
31 | func (*DirectiveDefault) Encode(rtm *DirectiveRuntime) error {
32 | if !rtm.Value.IsZero() {
33 | return nil // skip if the field is not empty
34 | }
35 | var adapt AnyStringableAdaptor
36 | coder := rtm.GetCustomCoder()
37 | if coder != nil {
38 | adapt = coder.Adapt
39 | }
40 | if stringSlicable, err := NewStringSlicable(rtm.Value, adapt); err != nil {
41 | return err
42 | } else {
43 | return stringSlicable.FromStringSlice(rtm.Directive.Argv)
44 | }
45 | }
46 |
--------------------------------------------------------------------------------
/core/default_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "net/http"
5 | "net/url"
6 | "testing"
7 | "time"
8 |
9 | "github.com/ggicci/httpin/internal"
10 | "github.com/ggicci/httpin/patch"
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
14 | func TestDirectiveDefault_Decode(t *testing.T) {
15 | type ThingWithDefaultValues struct {
16 | Page int `in:"form=page;default=1"`
17 | PointerPage *int `in:"form=pointer_page;default=1"`
18 | PatchPage patch.Field[int] `in:"form=patch_page;default=1"`
19 | PerPage int `in:"form=per_page;default=20"`
20 | StateList []string `in:"form=state;default=pending,in_progress,failed"`
21 | PatchStateList patch.Field[[]string] `in:"form=patch_state;default=a,b,c"`
22 | }
23 |
24 | r, _ := http.NewRequest("GET", "/", nil)
25 | r.Form = url.Values{
26 | "page": {"7"},
27 | "pointer_page": {"9"},
28 | "patch_page": {"11"},
29 | "state": {},
30 | "patch_state": {},
31 | }
32 | expected := &ThingWithDefaultValues{
33 | Page: 7,
34 | PointerPage: internal.Pointerize[int](9),
35 | PatchPage: patch.Field[int]{Value: 11, Valid: true},
36 | PerPage: 20,
37 | StateList: []string{"pending", "in_progress", "failed"},
38 | PatchStateList: patch.Field[[]string]{Value: []string{"a", "b", "c"}, Valid: true},
39 | }
40 | co, err := New(ThingWithDefaultValues{})
41 | assert.NoError(t, err)
42 | got, err := co.Decode(r)
43 | assert.NoError(t, err)
44 | assert.Equal(t, expected, got)
45 | }
46 |
47 | // FIX: https://github.com/ggicci/httpin/issues/77
48 | // Decode parameter struct with default values only works the first time
49 | func TestDirectiveDeafult_Decode_DecodeTwice(t *testing.T) {
50 | type ThingWithDefaultValues struct {
51 | Id uint `in:"query=id;required"`
52 | Page int `in:"query=page;default=1"`
53 | PerPage int `in:"query=page_size;default=127"`
54 | }
55 |
56 | r, _ := http.NewRequest("GET", "/?id=123", nil)
57 | expected := &ThingWithDefaultValues{
58 | Id: 123,
59 | Page: 1,
60 | PerPage: 127,
61 | }
62 |
63 | co, err := New(ThingWithDefaultValues{})
64 | assert.NoError(t, err)
65 |
66 | // First decode works as expected
67 | xxx, err := co.Decode(r)
68 | assert.NoError(t, err)
69 | assert.Equal(t, expected, xxx)
70 |
71 | // Second decode generates error
72 | xxx, err = co.Decode(r)
73 | assert.NoError(t, err)
74 | assert.Equal(t, expected, xxx)
75 | }
76 |
77 | func TestDirectiveDefault_NewRequest(t *testing.T) {
78 | type ListTicketRequest struct {
79 | Page int `in:"query=page;default=1"`
80 | PerPage int `in:"query=per_page;default=20"`
81 | States []string `in:"query=state;default=assigned,in_progress"`
82 | }
83 |
84 | co, err := New(ListTicketRequest{})
85 | assert.NoError(t, err)
86 |
87 | payload := &ListTicketRequest{
88 | Page: 2,
89 | }
90 | expected, _ := http.NewRequest("GET", "/tickets", nil)
91 | expected.URL.RawQuery = url.Values{
92 | "page": {"2"},
93 | "per_page": {"20"},
94 | "state": {"assigned", "in_progress"},
95 | }.Encode()
96 | req, err := co.NewRequest("GET", "/tickets", payload)
97 | assert.NoError(t, err)
98 | assert.Equal(t, expected, req)
99 | }
100 |
101 | func TestDirectiveDefault_NewRequest_WithNamedCoder(t *testing.T) {
102 | registerMyDate()
103 | type ListUsersRequest struct {
104 | Page int `in:"query=page;default=1"`
105 | PerPage int `in:"query=per_page;default=20"`
106 | RegistrationDate time.Time `in:"query=registration_date;default=2020-01-01;coder=mydate"`
107 | }
108 |
109 | co, err := New(ListUsersRequest{})
110 | assert.NoError(t, err)
111 |
112 | payload := &ListUsersRequest{
113 | Page: 2,
114 | PerPage: 10,
115 | }
116 | expected, _ := http.NewRequest("GET", "/users", nil)
117 | expected.URL.RawQuery = url.Values{
118 | "page": {"2"},
119 | "per_page": {"10"},
120 | "registration_date": {"2020-01-01"},
121 | }.Encode()
122 | req, err := co.NewRequest("GET", "/users", payload)
123 | assert.NoError(t, err)
124 | assert.Equal(t, expected, req)
125 | unregisterMyDate()
126 | }
127 |
--------------------------------------------------------------------------------
/core/directive.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 |
7 | "github.com/ggicci/httpin/internal"
8 | "github.com/ggicci/owl"
9 | )
10 |
11 | func init() {
12 | RegisterDirective("form", &DirectvieForm{})
13 | RegisterDirective("query", &DirectiveQuery{})
14 | RegisterDirective("header", &DirectiveHeader{})
15 | RegisterDirective("body", &DirectiveBody{})
16 | RegisterDirective("required", &DirectiveRequired{})
17 | RegisterDirective("default", &DirectiveDefault{})
18 | RegisterDirective("nonzero", &DirectiveNonzero{})
19 | registerDirective("path", defaultPathDirective)
20 | registerDirective("omitempty", &DirectiveOmitEmpty{})
21 |
22 | // decoder is a special executor which does nothing, but is an indicator of
23 | // overriding the decoder for a specific field.
24 | registerDirective("decoder", noopDirective)
25 | registerDirective("coder", noopDirective)
26 | }
27 |
28 | var (
29 | // decoderNamespace is the namespace for registering directive executors that are
30 | // used to decode the http request to input struct.
31 | decoderNamespace = owl.NewNamespace()
32 |
33 | // encoderNamespace is the namespace for registering directive executors that are
34 | // used to encode the input struct to http request.
35 | encoderNamespace = owl.NewNamespace()
36 |
37 | // reservedExecutorNames are the names that cannot be used to register user defined directives
38 | reservedExecutorNames = []string{"decoder", "coder"}
39 |
40 | noopDirective = &directiveNoop{}
41 | )
42 |
43 | type DirectiveExecutor interface {
44 | // Encode encodes the field of the input struct to the HTTP request.
45 | Encode(*DirectiveRuntime) error
46 |
47 | // Decode decodes the field of the input struct from the HTTP request.
48 | Decode(*DirectiveRuntime) error
49 | }
50 |
51 | // RegisterDirective registers a DirectiveExecutor with the given directive name. The
52 | // directive should be able to both extract the value from the HTTP request and build
53 | // the HTTP request from the value. The Decode API is used to decode data from the HTTP
54 | // request to a field of the input struct, and Encode API is used to encode the field of
55 | // the input struct to the HTTP request.
56 | //
57 | // Will panic if the name were taken or given executor is nil. Pass parameter force
58 | // (true) to ignore the name conflict.
59 | func RegisterDirective(name string, executor DirectiveExecutor, force ...bool) {
60 | panicOnReservedExecutorName(name)
61 | registerDirective(name, executor, force...)
62 | }
63 |
64 | func registerDirective(name string, executor DirectiveExecutor, force ...bool) {
65 | registerDirectiveExecutorToNamespace(decoderNamespace, name, executor, force...)
66 | registerDirectiveExecutorToNamespace(encoderNamespace, name, executor, force...)
67 | }
68 |
69 | func registerDirectiveExecutorToNamespace(ns *owl.Namespace, name string, exe DirectiveExecutor, force ...bool) {
70 | if exe == nil {
71 | internal.PanicOnError(errors.New("nil directive executor"))
72 | }
73 | if ns == decoderNamespace {
74 | ns.RegisterDirectiveExecutor(name, asOwlDirectiveExecutor(exe.Decode), force...)
75 | } else {
76 | ns.RegisterDirectiveExecutor(name, asOwlDirectiveExecutor(exe.Encode), force...)
77 | }
78 | }
79 |
80 | func asOwlDirectiveExecutor(directiveFunc func(*DirectiveRuntime) error) owl.DirectiveExecutor {
81 | return owl.DirectiveExecutorFunc(func(dr *owl.DirectiveRuntime) error {
82 | return directiveFunc((*DirectiveRuntime)(dr))
83 | })
84 | }
85 |
86 | func panicOnReservedExecutorName(name string) {
87 | for _, reservedName := range reservedExecutorNames {
88 | if name == reservedName {
89 | internal.PanicOnError(fmt.Errorf("reserved executor name: %q", name))
90 | }
91 | }
92 | }
93 |
94 | // directiveNoop is a DirectiveExecutor that does nothing, "noop" stands for "no operation".
95 | type directiveNoop struct{}
96 |
97 | func (*directiveNoop) Encode(*DirectiveRuntime) error { return nil }
98 | func (*directiveNoop) Decode(*DirectiveRuntime) error { return nil }
99 |
--------------------------------------------------------------------------------
/core/directive_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/assert"
7 | )
8 |
9 | func TestRegisterDirectiveExecutor(t *testing.T) {
10 | assert.NotPanics(t, func() {
11 | RegisterDirective("noop_TestRegisterDirectiveExecutor", noopDirective)
12 | })
13 |
14 | assert.Panics(t, func() {
15 | RegisterDirective("noop_TestRegisterDirectiveExecutor", noopDirective)
16 | }, "should panic on duplicate name")
17 |
18 | assert.Panics(t, func() {
19 | RegisterDirective("nil_TestRegisterDirectiveExecutor", nil)
20 | }, "should panic on nil executor")
21 |
22 | assert.Panics(t, func() {
23 | RegisterDirective("decoder", noopDirective)
24 | }, "should panic on reserved name")
25 | }
26 |
27 | func TestRegisterDirectiveExecutor_ForceReplace(t *testing.T) {
28 | assert.NotPanics(t, func() {
29 | RegisterDirective("noop_TestRegisterDirectiveExecutor_forceReplace", noopDirective, true)
30 | })
31 |
32 | assert.NotPanics(t, func() {
33 | RegisterDirective("noop_TestRegisterDirectiveExecutor_forceReplace", noopDirective, true)
34 | }, "should not panic on duplicate name")
35 |
36 | assert.Panics(t, func() {
37 | RegisterDirective("nil_TestRegisterDirectiveExecutor_forceReplace", nil, true)
38 | }, "should panic on nil executor")
39 |
40 | assert.Panics(t, func() {
41 | RegisterDirective("decoder", noopDirective, true)
42 | }, "should panic on reserved name")
43 | }
44 |
--------------------------------------------------------------------------------
/core/directiveruntime.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "net/http"
7 | "reflect"
8 |
9 | "github.com/ggicci/owl"
10 | )
11 |
// contextKey is the type of the context keys defined by this package. Because
// the type is unexported, these keys cannot collide with keys of other
// packages.
type contextKey int

const (
	// CtxRequest is the key to get the HTTP request value (of *http.Request)
	// from DirectiveRuntime.Context. The HTTP request value is injected by
	// httpin to the context of DirectiveRuntime before executing the directive.
	// See Core.Decode() for more details.
	CtxRequest contextKey = iota

	// CtxRequestBuilder is the key to get the *RequestBuilder from
	// DirectiveRuntime.Context. See DirectiveRuntime.GetRequestBuilder.
	CtxRequestBuilder

	// CtxCustomCoder is the key to get the custom coder of a field from
	// Resolver.Context. The coder is specified by the "coder" directive (or its
	// legacy alias "decoder"). While the resolver is being built, that
	// directive is removed from the resolver, and the coder it names is stored
	// in Resolver.Context under this key. For example:
	//
	//	type GreetInput struct {
	//		Message string `in:"coder=custom"`
	//	}
	//
	// For the above example, the coder named "custom" is stored in the
	// resolver of the Message field with this key.
	CtxCustomCoder

	// CtxFieldSet is used by executors to tell whether a field has been set. When
	// multiple executors were applied to a field, if the field value were set
	// by a former executor, the latter executors MAY skip running by consulting
	// this context value.
	CtxFieldSet
)
42 |
43 | // DirectiveRuntime is the runtime of a directive execution. It wraps owl.DirectiveRuntime,
44 | // providing some additional helper methods particular to httpin.
45 | //
46 | // See owl.DirectiveRuntime for more details.
47 | type DirectiveRuntime owl.DirectiveRuntime
48 |
49 | func (rtm *DirectiveRuntime) GetRequest() *http.Request {
50 | if req := rtm.Context.Value(CtxRequest); req != nil {
51 | return req.(*http.Request)
52 | }
53 | return nil
54 | }
55 |
56 | func (rtm *DirectiveRuntime) GetRequestBuilder() *RequestBuilder {
57 | if rb := rtm.Context.Value(CtxRequestBuilder); rb != nil {
58 | return rb.(*RequestBuilder)
59 | }
60 | return nil
61 | }
62 |
63 | func (rtm *DirectiveRuntime) GetCustomCoder() *NamedAnyStringableAdaptor {
64 | if info := rtm.Resolver.Context.Value(CtxCustomCoder); info != nil {
65 | return info.(*NamedAnyStringableAdaptor)
66 | } else {
67 | return nil
68 | }
69 | }
70 |
71 | func (rtm *DirectiveRuntime) IsFieldSet() bool {
72 | return rtm.Context.Value(CtxFieldSet) == true
73 | }
74 |
75 | func (rtm *DirectiveRuntime) MarkFieldSet(value bool) {
76 | rtm.Context = context.WithValue(rtm.Context, CtxFieldSet, value)
77 | }
78 |
79 | func (rtm *DirectiveRuntime) SetValue(value any) error {
80 | if value == nil {
81 | // NOTE: should we wipe the value here? i.e. set the value to nil if necessary.
82 | // No case found yet, at least for now.
83 | return nil
84 | }
85 | newValue := reflect.ValueOf(value)
86 | targetType := rtm.Value.Type().Elem()
87 |
88 | if !newValue.Type().AssignableTo(targetType) {
89 | return fmt.Errorf("%w: value of type %q is not assignable to type %q",
90 | ErrTypeMismatch, reflect.TypeOf(value), targetType)
91 | }
92 |
93 | rtm.Value.Elem().Set(newValue)
94 | return nil
95 | }
96 |
--------------------------------------------------------------------------------
/core/error.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "strings"
7 |
8 | "github.com/ggicci/httpin/internal"
9 | "github.com/ggicci/owl"
10 | )
11 |
12 | var (
13 | ErrUnregisteredDirective = errors.New("unregistered directive")
14 | ErrUnregisteredCoder = errors.New("unregistered coder")
15 | ErrTypeMismatch = internal.ErrTypeMismatch
16 | )
17 |
// InvalidFieldError describes a field of the input struct that could not be
// processed. It carries the underlying error and where the failure happened.
// Its exported fields are JSON-tagged so the value can be written to clients
// as is.
type InvalidFieldError struct {
	// err is the underlying error thrown by the directive executor.
	err error

	// Field is the name of the field.
	Field string `json:"field"`

	// Directive is the name of the directive which caused the error,
	// e.g. form, header, required, etc.
	Directive string `json:"directive"`

	// Key is the key used to get the input data from the source.
	Key string `json:"key"`

	// Value is the input data.
	Value any `json:"value"`

	// ErrorMessage is the string representation of the underlying error.
	ErrorMessage string `json:"error"`
}

// Error formats the error as `invalid field "<Field>": <underlying error>`.
func (e *InvalidFieldError) Error() string {
	return fmt.Sprintf("invalid field %q: %v", e.Field, e.err)
}

// Unwrap returns the underlying error so that errors.Is and errors.As can
// inspect it.
func (e *InvalidFieldError) Unwrap() error {
	return e.err
}
46 |
47 | func NewInvalidFieldError(err error) *InvalidFieldError {
48 | var (
49 | r *owl.Resolver
50 | de *owl.DirectiveExecutionError
51 | )
52 |
53 | switch err := err.(type) {
54 | case *InvalidFieldError:
55 | return err
56 | case *owl.ResolveError:
57 | r = err.Resolver
58 | de = err.AsDirectiveExecutionError()
59 | case *owl.ScanError:
60 | r = err.Resolver
61 | de = err.AsDirectiveExecutionError()
62 | default:
63 | return &InvalidFieldError{
64 | err: err,
65 | ErrorMessage: err.Error(),
66 | }
67 | }
68 |
69 | var fe *fieldError
70 | var inputKey string
71 | var inputValue any
72 | errors.As(err, &fe)
73 | if fe != nil {
74 | inputValue = fe.Value
75 | inputKey = fe.Key
76 | }
77 |
78 | return &InvalidFieldError{
79 | err: err,
80 | Field: r.Field.Name,
81 | Directive: de.Name, // e.g. form, header, required, etc.
82 | Key: inputKey,
83 | Value: inputValue,
84 | ErrorMessage: err.Error(),
85 | }
86 | }
87 |
88 | type MultiInvalidFieldError []*InvalidFieldError
89 |
90 | func (me MultiInvalidFieldError) Error() string {
91 | if len(me) == 1 {
92 | return me[0].Error()
93 | }
94 | var sb strings.Builder
95 | sb.WriteString(fmt.Sprintf("%d invalid fields: ", len(me)))
96 | for i, e := range me {
97 | if i > 0 {
98 | sb.WriteString("; ")
99 | }
100 | sb.WriteString(e.Error())
101 | }
102 | return sb.String()
103 | }
104 |
105 | func (me MultiInvalidFieldError) Unwrap() []error {
106 | var errs []error
107 | for _, e := range me {
108 | errs = append(errs, e)
109 | }
110 | return errs
111 | }
112 |
// fieldError annotates an error with the input key and value that caused it.
// NewInvalidFieldError copies Key and Value out of it.
type fieldError struct {
	Key           string
	Value         any
	internalError error
}

// Error returns the message of the wrapped error.
func (e fieldError) Error() string {
	cause := e.internalError
	return cause.Error()
}

// Unwrap returns the wrapped error for errors.Is and errors.As.
func (e fieldError) Unwrap() error {
	cause := e.internalError
	return cause
}
126 |
--------------------------------------------------------------------------------
/core/errorhandler.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "encoding/json"
5 | "errors"
6 | "net/http"
7 |
8 | "github.com/ggicci/httpin/internal"
9 | )
10 |
11 | var globalCustomErrorHandler ErrorHandler = defaultErrorHandler
12 |
13 | // RegisterErrorHandler replaces the default error handler with the given
14 | // custom error handler. The default error handler will be used in the http.Handler
15 | // that decoreated by the middleware created by NewInput().
16 | func RegisterErrorHandler(handler ErrorHandler) {
17 | internal.PanicOnError(validateErrorHandler(handler))
18 | globalCustomErrorHandler = handler
19 | }
20 |
21 | func defaultErrorHandler(rw http.ResponseWriter, r *http.Request, err error) {
22 | var invalidFieldError *InvalidFieldError
23 | if errors.As(err, &invalidFieldError) {
24 | rw.Header().Add("Content-Type", "application/json")
25 | rw.WriteHeader(http.StatusUnprocessableEntity) // status: 422
26 | json.NewEncoder(rw).Encode(invalidFieldError)
27 | return
28 | }
29 |
30 | http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) // status: 500
31 | }
32 |
// ErrorHandler is the type of a custom error handler. The http.Handler created
// by NewInput() uses it to handle errors that occur while decoding the HTTP
// request.
type ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error)

// validateErrorHandler returns an error if handler is nil, and nil otherwise.
func validateErrorHandler(handler ErrorHandler) error {
	if handler != nil {
		return nil
	}
	return errors.New("nil error handler")
}
44 |
--------------------------------------------------------------------------------
/core/errorhandler_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "errors"
5 | "io"
6 | "net/http"
7 | "net/http/httptest"
8 | "testing"
9 |
10 | "github.com/stretchr/testify/assert"
11 | )
12 |
13 | func myCustomErrorHandler(rw http.ResponseWriter, r *http.Request, err error) {
14 | var invalidFieldError *InvalidFieldError
15 | if errors.As(err, &invalidFieldError) {
16 | rw.WriteHeader(http.StatusBadRequest) // status: 400
17 | io.WriteString(rw, invalidFieldError.Error())
18 | return
19 | }
20 | http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) // status: 500
21 | }
22 |
23 | func TestRegisterErrorHandler(t *testing.T) {
24 | // Nil handler should panic.
25 | assert.PanicsWithError(t, "httpin: nil error handler", func() {
26 | RegisterErrorHandler(nil)
27 | })
28 |
29 | RegisterErrorHandler(myCustomErrorHandler)
30 | assert.True(t, equalFuncs(globalCustomErrorHandler, myCustomErrorHandler))
31 | }
32 |
33 | func TestDefaultErrorHandler(t *testing.T) {
34 | r, _ := http.NewRequest("GET", "/", nil)
35 | rw := httptest.NewRecorder()
36 |
37 | // When met InvalidFieldError, it should return 422.
38 | defaultErrorHandler(rw, r, &InvalidFieldError{err: assert.AnError, ErrorMessage: assert.AnError.Error()})
39 | assert.Equal(t, 422, rw.Code)
40 |
41 | // When met other errors, it should return 500.
42 | rw = httptest.NewRecorder()
43 | defaultErrorHandler(rw, r, assert.AnError)
44 | assert.Equal(t, 500, rw.Code)
45 | }
46 |
--------------------------------------------------------------------------------
/core/file.go:
--------------------------------------------------------------------------------
1 | // https://ggicci.github.io/httpin/advanced/upload-files
2 |
3 | package core
4 |
5 | import (
6 | "errors"
7 | "io"
8 | "mime/multipart"
9 | "os"
10 | )
11 |
12 | func init() {
13 | RegisterFileCoder[*File]()
14 | }
15 |
16 | // File is the builtin type of httpin to manupulate file uploads. On the server
17 | // side, it is used to represent a file in a multipart/form-data request. On the
18 | // client side, it is used to represent a file to be uploaded.
19 | type File struct {
20 | FileHeader
21 | uploadFilename string
22 | uploadReader io.ReadCloser
23 | }
24 |
25 | // UploadFile is a helper function to create a File instance from a file path.
26 | // It is useful when you want to upload a file from the local file system.
27 | func UploadFile(filename string) *File {
28 | return &File{uploadFilename: filename}
29 | }
30 |
31 | // UploadStream is a helper function to create a File instance from a io.Reader. It
32 | // is useful when you want to upload a file from a stream.
33 | func UploadStream(contentReader io.ReadCloser) *File {
34 | return &File{uploadReader: contentReader}
35 | }
36 |
37 | // Filename returns the filename of the file. On the server side, it returns the
38 | // filename of the file in the multipart/form-data request. On the client side, it
39 | // returns the filename of the file to be uploaded.
40 | func (f *File) Filename() string {
41 | if f.IsUpload() {
42 | return f.uploadFilename
43 | }
44 | if f.FileHeader != nil {
45 | return f.FileHeader.Filename()
46 | }
47 | return ""
48 | }
49 |
50 | // MarshalFile implements FileMarshaler.
51 | func (f *File) MarshalFile() (io.ReadCloser, error) {
52 | if f.IsUpload() {
53 | return f.OpenUploadStream()
54 | } else {
55 | return f.OpenReceiveStream()
56 | }
57 | }
58 |
59 | func (f *File) UnmarshalFile(fh FileHeader) error {
60 | f.FileHeader = fh
61 | return nil
62 | }
63 |
64 | // IsUpload returns true when the File instance is created for an upload purpose.
65 | // Typically, you should use UploadFilename or UploadReader to create a File instance
66 | // for upload.
67 | func (f *File) IsUpload() bool {
68 | return f.uploadFilename != "" || f.uploadReader != nil
69 | }
70 |
71 | // OpenUploadStream returns a io.ReadCloser for the file to be uploaded.
72 | // Call this method on the client side for uploading a file.
73 | func (f *File) OpenUploadStream() (io.ReadCloser, error) {
74 | if f.uploadReader != nil {
75 | return f.uploadReader, nil
76 | }
77 | if f.uploadFilename != "" {
78 | return os.Open(f.uploadFilename)
79 | }
80 | return nil, errors.New("invalid upload (client): no filename or reader")
81 | }
82 |
83 | // OpenReceiveStream returns a io.Reader for the file in the multipart/form-data request.
84 | // Call this method on the server side to read the file content.
85 | func (f *File) OpenReceiveStream() (multipart.File, error) {
86 | if f.FileHeader == nil {
87 | return nil, errors.New("invalid upload (server): nil file header")
88 | }
89 | return f.FileHeader.Open()
90 | }
91 |
92 | func (f *File) ReadAll() ([]byte, error) {
93 | reader, err := f.MarshalFile()
94 | if err != nil {
95 | return nil, err
96 | }
97 | return io.ReadAll(reader)
98 | }
99 |
--------------------------------------------------------------------------------
/core/fileable.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "errors"
5 | "io"
6 | "mime/multipart"
7 | "net/textproto"
8 | "reflect"
9 |
10 | "github.com/ggicci/httpin/internal"
11 | )
12 |
type (
	// FileHeader abstracts multipart.FileHeader. Its Filename, Size and
	// Header data are exposed as methods, alongside Open.
	FileHeader interface {
		Filename() string
		Size() int64
		MIMEHeader() textproto.MIMEHeader
		Open() (multipart.File, error)
	}

	// FileMarshaler provides a file name and a readable stream of the file
	// content.
	FileMarshaler interface {
		Filename() string
		MarshalFile() (io.ReadCloser, error)
	}

	// FileUnmarshaler populates itself from a received file's header.
	FileUnmarshaler interface {
		UnmarshalFile(FileHeader) error
	}

	// Fileable is both a FileMarshaler and a FileUnmarshaler.
	Fileable interface {
		FileMarshaler
		FileUnmarshaler
	}
)
34 |
35 | func NewFileable(rv reflect.Value) (Fileable, error) {
36 | if IsPatchField(rv.Type()) {
37 | return NewFileablePatchFieldWrapper(rv)
38 | }
39 |
40 | return newFileable(rv)
41 | }
42 |
43 | func newFileable(rv reflect.Value) (Fileable, error) {
44 | rv, err := getPointer(rv)
45 | if err != nil {
46 | return nil, err
47 | }
48 |
49 | if rv.Type().Implements(fileableType) && rv.CanInterface() {
50 | return rv.Interface().(Fileable), nil
51 | }
52 | return nil, errors.New("unsupported file type")
53 | }
54 |
55 | type FileablePatchFieldWrapper struct {
56 | Value reflect.Value // of patch.Field[T]
57 | internalFileable Fileable
58 | }
59 |
60 | func NewFileablePatchFieldWrapper(rv reflect.Value) (*FileablePatchFieldWrapper, error) {
61 | fileable, err := NewFileable(rv.FieldByName("Value"))
62 | if err != nil {
63 | return nil, err
64 | } else {
65 | return &FileablePatchFieldWrapper{
66 | Value: rv,
67 | internalFileable: fileable,
68 | }, nil
69 | }
70 | }
71 |
72 | func (w *FileablePatchFieldWrapper) Filename() string {
73 | return w.internalFileable.Filename()
74 | }
75 |
76 | func (w *FileablePatchFieldWrapper) MarshalFile() (io.ReadCloser, error) {
77 | return w.internalFileable.MarshalFile()
78 | }
79 |
80 | func (w *FileablePatchFieldWrapper) UnmarshalFile(fh FileHeader) error {
81 | if err := w.internalFileable.UnmarshalFile(fh); err != nil {
82 | return err
83 | } else {
84 | w.Value.FieldByName("Valid").SetBool(true)
85 | return nil
86 | }
87 | }
88 |
89 | var fileableType = internal.TypeOf[Fileable]()
90 |
91 | // multipartFileHeader is the adaptor of multipart.FileHeader.
92 | type multipartFileHeader struct{ *multipart.FileHeader }
93 |
94 | func (h *multipartFileHeader) Filename() string {
95 | return h.FileHeader.Filename
96 | }
97 |
98 | func (h *multipartFileHeader) Size() int64 {
99 | return h.FileHeader.Size
100 | }
101 |
102 | func (h *multipartFileHeader) MIMEHeader() textproto.MIMEHeader {
103 | return h.FileHeader.Header
104 | }
105 |
106 | func (h *multipartFileHeader) Open() (multipart.File, error) {
107 | return h.FileHeader.Open()
108 | }
109 |
110 | func toFileHeaderList(fhs []*multipart.FileHeader) []FileHeader {
111 | result := make([]FileHeader, len(fhs))
112 | for i, fh := range fhs {
113 | result[i] = &multipartFileHeader{fh}
114 | }
115 | return result
116 | }
117 |
--------------------------------------------------------------------------------
/core/fileable_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "io"
5 | "mime/multipart"
6 | "net/textproto"
7 | "os"
8 | "reflect"
9 | "testing"
10 |
11 | "github.com/ggicci/httpin/patch"
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
15 | type MyFiles struct {
16 | Avatar File
17 | AvatarPointer *File
18 | Avatars []File
19 | AvatarPointers []*File
20 |
21 | PatchAvatar patch.Field[File]
22 | PatchAvatarPointer patch.Field[*File]
23 | PatchAvatars patch.Field[[]File]
24 | PatchAvatarPointers patch.Field[[]*File]
25 | }
26 |
27 | func TestFileable_UnmarshalFile(t *testing.T) {
28 | rv := reflect.New(reflect.TypeOf(MyFiles{})).Elem()
29 | s := rv.Addr().Interface().(*MyFiles)
30 |
31 | fileAvatar := testAssignFile(t, rv.FieldByName("Avatar"))
32 | fileAvatarPointer := testAssignFile(t, rv.FieldByName("AvatarPointer"))
33 | testNewFileableErrUnsupported(t, rv.FieldByName("Avatars"))
34 | testNewFileableErrUnsupported(t, rv.FieldByName("AvatarPointers"))
35 |
36 | filePatchAvatar := testAssignFile(t, rv.FieldByName("PatchAvatar"))
37 | filePatchAvatarPointer := testAssignFile(t, rv.FieldByName("PatchAvatarPointer"))
38 | testNewFileableErrUnsupported(t, rv.FieldByName("PatchAvatars"))
39 | testNewFileableErrUnsupported(t, rv.FieldByName("PatchAvatarPointers"))
40 |
41 | validateFile(t, fileAvatar, &s.Avatar)
42 | validateFile(t, fileAvatarPointer, s.AvatarPointer)
43 |
44 | assert.True(t, s.PatchAvatar.Valid)
45 | validateFile(t, filePatchAvatar, &s.PatchAvatar.Value)
46 | assert.True(t, s.PatchAvatarPointer.Valid)
47 | validateFile(t, filePatchAvatarPointer, s.PatchAvatarPointer.Value)
48 | }
49 |
50 | func TestFileable_MarshalFile(t *testing.T) {
51 | fileAvatar := createTempFileV2(t)
52 | fileAvatarPointer := createTempFileV2(t)
53 | filePatchAvatar := createTempFileV2(t)
54 | filePatchAvatarPointer := createTempFileV2(t)
55 |
56 | var s = &MyFiles{
57 | Avatar: *UploadFile(fileAvatar.Filename),
58 | AvatarPointer: UploadFile(fileAvatarPointer.Filename),
59 | Avatars: []File{*UploadFile(fileAvatar.Filename)},
60 | AvatarPointers: []*File{UploadFile(fileAvatarPointer.Filename)},
61 |
62 | PatchAvatar: patch.Field[File]{Value: *UploadFile(filePatchAvatar.Filename), Valid: true},
63 | PatchAvatarPointer: patch.Field[*File]{Value: UploadFile(filePatchAvatarPointer.Filename), Valid: true},
64 | PatchAvatars: patch.Field[[]File]{Value: []File{*UploadFile(filePatchAvatar.Filename)}, Valid: true},
65 | PatchAvatarPointers: patch.Field[[]*File]{Value: []*File{UploadFile(filePatchAvatarPointer.Filename)}, Valid: true},
66 | }
67 |
68 | rv := reflect.ValueOf(s).Elem()
69 |
70 | validateRvFile(t, fileAvatar, rv.FieldByName("Avatar"))
71 | validateRvFile(t, fileAvatarPointer, rv.FieldByName("AvatarPointer"))
72 | testNewFileableErrUnsupported(t, rv.FieldByName("Avatars"))
73 | testNewFileableErrUnsupported(t, rv.FieldByName("AvatarPointers"))
74 |
75 | validateRvFile(t, filePatchAvatar, rv.FieldByName("PatchAvatar"))
76 | validateRvFile(t, filePatchAvatarPointer, rv.FieldByName("PatchAvatarPointer"))
77 | testNewFileableErrUnsupported(t, rv.FieldByName("PatchAvatars"))
78 | testNewFileableErrUnsupported(t, rv.FieldByName("PatchAvatarPointers"))
79 | }
80 |
81 | func testNewFileableErrUnsupported(t *testing.T, rv reflect.Value) {
82 | fileable, err := NewFileable(rv)
83 | assert.ErrorContains(t, err, "unsupported file type")
84 | assert.Nil(t, fileable)
85 | }
86 |
87 | func validateFile(t *testing.T, expected *tempFile, actual FileMarshaler) {
88 | assert.Equal(t, expected.Filename, actual.Filename())
89 | reader, err := actual.MarshalFile()
90 | assert.NoError(t, err)
91 | content, err := io.ReadAll(reader)
92 | assert.NoError(t, err)
93 | assert.Equal(t, expected.Content, content)
94 | }
95 |
96 | func validateRvFile(t *testing.T, expected *tempFile, actual reflect.Value) {
97 | file, err := NewFileable(actual)
98 | assert.NoError(t, err)
99 | reader, err := file.MarshalFile()
100 | assert.NoError(t, err)
101 | content, err := io.ReadAll(reader)
102 | assert.NoError(t, err)
103 |
104 | assert.Equal(t, expected.Filename, file.Filename())
105 | assert.Equal(t, expected.Content, content)
106 | }
107 |
108 | func testAssignFile(t *testing.T, rv reflect.Value) *tempFile {
109 | fileable, err := NewFileable(rv)
110 | assert.NoError(t, err)
111 | file := createTempFileV2(t)
112 | assert.NoError(t, fileable.UnmarshalFile(mockFileHeader(t, file.Filename)))
113 | return file
114 | }
115 |
116 | type dummyFileHeader struct {
117 | file *os.File
118 | }
119 |
120 | func mockFileHeader(t *testing.T, filename string) FileHeader {
121 | file, err := os.Open(filename)
122 | if err != nil {
123 | panic(err)
124 | }
125 | return &dummyFileHeader{
126 | file: file,
127 | }
128 | }
129 |
130 | func (f *dummyFileHeader) Filename() string {
131 | return f.file.Name()
132 | }
133 |
134 | func (f *dummyFileHeader) Size() int64 {
135 | stat, err := f.file.Stat()
136 | if err != nil {
137 | panic(err)
138 | }
139 | return stat.Size()
140 | }
141 |
142 | func (f *dummyFileHeader) MIMEHeader() textproto.MIMEHeader {
143 | return textproto.MIMEHeader{}
144 | }
145 |
146 | func (f *dummyFileHeader) Open() (multipart.File, error) {
147 | return f.file, nil
148 | }
149 |
--------------------------------------------------------------------------------
/core/fileslicable.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "reflect"
7 | )
8 |
9 | type FileSlicable interface {
10 | ToFileSlice() ([]FileMarshaler, error)
11 | FromFileSlice([]FileHeader) error
12 | }
13 |
14 | func NewFileSlicable(rv reflect.Value) (FileSlicable, error) {
15 | if IsPatchField(rv.Type()) {
16 | return NewFileSlicablePatchFieldWrapper(rv)
17 | }
18 |
19 | if isSliceType(rv.Type()) {
20 | return NewFileableSliceWrapper(rv)
21 | } else {
22 | return NewFileSlicableSingleFileableWrapper(rv)
23 | }
24 | }
25 |
26 | type FileSlicablePatchFieldWrapper struct {
27 | Value reflect.Value // of patch.Field[T]
28 | internalFileSliceable FileSlicable
29 | }
30 |
31 | func NewFileSlicablePatchFieldWrapper(rv reflect.Value) (*FileSlicablePatchFieldWrapper, error) {
32 | fileSlicable, err := NewFileSlicable(rv.FieldByName("Value"))
33 | if err != nil {
34 | return nil, err
35 | } else {
36 | return &FileSlicablePatchFieldWrapper{
37 | Value: rv,
38 | internalFileSliceable: fileSlicable,
39 | }, nil
40 | }
41 | }
42 |
43 | func (w *FileSlicablePatchFieldWrapper) ToFileSlice() ([]FileMarshaler, error) {
44 | if w.Value.FieldByName("Valid").Bool() {
45 | return w.internalFileSliceable.ToFileSlice()
46 | } else {
47 | return []FileMarshaler{}, nil
48 | }
49 | }
50 |
51 | func (w *FileSlicablePatchFieldWrapper) FromFileSlice(fhs []FileHeader) error {
52 | if err := w.internalFileSliceable.FromFileSlice(fhs); err != nil {
53 | return err
54 | } else {
55 | w.Value.FieldByName("Valid").SetBool(true)
56 | return nil
57 | }
58 | }
59 |
60 | type FileableSliceWrapper struct {
61 | Value reflect.Value
62 | }
63 |
64 | func NewFileableSliceWrapper(rv reflect.Value) (*FileableSliceWrapper, error) {
65 | if !rv.CanAddr() {
66 | return nil, errors.New("unaddressable value")
67 | }
68 | return &FileableSliceWrapper{Value: rv}, nil
69 | }
70 |
71 | func (w *FileableSliceWrapper) ToFileSlice() ([]FileMarshaler, error) {
72 | var files = make([]FileMarshaler, w.Value.Len())
73 | for i := 0; i < w.Value.Len(); i++ {
74 | if fileable, err := NewFileable(w.Value.Index(i)); err != nil {
75 | return nil, fmt.Errorf("cannot create Fileable at index %d: %w", i, err)
76 | } else {
77 | files[i] = fileable
78 | }
79 | }
80 | return files, nil
81 | }
82 |
83 | func (w *FileableSliceWrapper) FromFileSlice(fhs []FileHeader) error {
84 | w.Value.Set(reflect.MakeSlice(w.Value.Type(), len(fhs), len(fhs)))
85 | for i, fh := range fhs {
86 | fileable, err := NewFileable(w.Value.Index(i))
87 | if err != nil {
88 | return fmt.Errorf("cannot create Fileable at index %d: %w", i, err)
89 | }
90 | if err := fileable.UnmarshalFile(fh); err != nil {
91 | return fmt.Errorf("cannot unmarshal file %q at index %d: %w", fh.Filename(), i, err)
92 | }
93 | }
94 | return nil
95 | }
96 |
97 | type FileSlicableSingleFileableWrapper struct{ Fileable }
98 |
99 | func NewFileSlicableSingleFileableWrapper(rv reflect.Value) (*FileSlicableSingleFileableWrapper, error) {
100 | if fileable, err := NewFileable(rv); err != nil {
101 | return nil, err
102 | } else {
103 | return &FileSlicableSingleFileableWrapper{fileable}, nil
104 | }
105 | }
106 |
107 | func (w *FileSlicableSingleFileableWrapper) ToFileSlice() ([]FileMarshaler, error) {
108 | return []FileMarshaler{w.Fileable}, nil
109 | }
110 |
111 | func (w *FileSlicableSingleFileableWrapper) FromFileSlice(files []FileHeader) error {
112 | if len(files) > 0 {
113 | return w.UnmarshalFile(files[0])
114 | }
115 | return nil
116 | }
117 |
--------------------------------------------------------------------------------
/core/fileslicable_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "reflect"
5 | "testing"
6 |
7 | "github.com/ggicci/httpin/patch"
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestFileSlicable_FromFileSlice(t *testing.T) {
12 | rv := reflect.New(reflect.TypeOf(MyFiles{})).Elem()
13 | s := rv.Addr().Interface().(*MyFiles)
14 |
15 | fileAvatar := createTempFileV2(t)
16 | fileAvatarPointer := createTempFileV2(t)
17 | filePatchAvatar := createTempFileV2(t)
18 | filePatchAvatarPointer := createTempFileV2(t)
19 |
20 | testAssignFileSlice(t, rv.FieldByName("Avatar"), []FileHeader{
21 | mockFileHeader(t, fileAvatar.Filename),
22 | })
23 | testAssignFileSlice(t, rv.FieldByName("AvatarPointer"), []FileHeader{
24 | mockFileHeader(t, fileAvatarPointer.Filename),
25 | })
26 | testAssignFileSlice(t, rv.FieldByName("Avatars"), []FileHeader{
27 | mockFileHeader(t, fileAvatar.Filename),
28 | mockFileHeader(t, fileAvatarPointer.Filename),
29 | })
30 | testAssignFileSlice(t, rv.FieldByName("AvatarPointers"), []FileHeader{
31 | mockFileHeader(t, fileAvatarPointer.Filename),
32 | mockFileHeader(t, fileAvatar.Filename),
33 | })
34 |
35 | testAssignFileSlice(t, rv.FieldByName("PatchAvatar"), []FileHeader{
36 | mockFileHeader(t, filePatchAvatar.Filename),
37 | })
38 | testAssignFileSlice(t, rv.FieldByName("PatchAvatarPointer"), []FileHeader{
39 | mockFileHeader(t, filePatchAvatarPointer.Filename),
40 | })
41 | testAssignFileSlice(t, rv.FieldByName("PatchAvatars"), []FileHeader{
42 | mockFileHeader(t, fileAvatar.Filename),
43 | mockFileHeader(t, filePatchAvatar.Filename),
44 | mockFileHeader(t, filePatchAvatarPointer.Filename),
45 | })
46 | testAssignFileSlice(t, rv.FieldByName("PatchAvatarPointers"), []FileHeader{
47 | mockFileHeader(t, fileAvatar.Filename),
48 | mockFileHeader(t, fileAvatarPointer.Filename),
49 | mockFileHeader(t, filePatchAvatar.Filename),
50 | mockFileHeader(t, filePatchAvatarPointer.Filename),
51 | })
52 |
53 | validateFile(t, fileAvatar, &s.Avatar)
54 | validateFile(t, fileAvatarPointer, s.AvatarPointer)
55 |
56 | assert.Len(t, s.Avatars, 2)
57 | validateFile(t, fileAvatar, &s.Avatars[0])
58 | validateFile(t, fileAvatarPointer, &s.Avatars[1])
59 |
60 | assert.Len(t, s.AvatarPointers, 2)
61 | validateFile(t, fileAvatarPointer, s.AvatarPointers[0])
62 | validateFile(t, fileAvatar, s.AvatarPointers[1])
63 |
64 | assert.True(t, s.PatchAvatar.Valid)
65 | validateFile(t, filePatchAvatar, &s.PatchAvatar.Value)
66 |
67 | assert.True(t, s.PatchAvatarPointer.Valid)
68 | validateFile(t, filePatchAvatarPointer, s.PatchAvatarPointer.Value)
69 |
70 | assert.True(t, s.PatchAvatars.Valid)
71 | assert.Len(t, s.PatchAvatars.Value, 3)
72 | validateFile(t, fileAvatar, &s.PatchAvatars.Value[0])
73 | validateFile(t, filePatchAvatar, &s.PatchAvatars.Value[1])
74 | validateFile(t, filePatchAvatarPointer, &s.PatchAvatars.Value[2])
75 |
76 | assert.True(t, s.PatchAvatarPointers.Valid)
77 | assert.Len(t, s.PatchAvatarPointers.Value, 4)
78 | validateFile(t, fileAvatar, s.PatchAvatarPointers.Value[0])
79 | validateFile(t, fileAvatarPointer, s.PatchAvatarPointers.Value[1])
80 | validateFile(t, filePatchAvatar, s.PatchAvatarPointers.Value[2])
81 | validateFile(t, filePatchAvatarPointer, s.PatchAvatarPointers.Value[3])
82 | }
83 |
84 | func TestFileSlicable_ToFileSlice(t *testing.T) {
85 | fileAvatar := createTempFileV2(t)
86 | fileAvatarPointer := createTempFileV2(t)
87 | filePatchAvatar := createTempFileV2(t)
88 | filePatchAvatarPointer := createTempFileV2(t)
89 |
90 | var s = &MyFiles{
91 | Avatar: *UploadFile(fileAvatar.Filename),
92 | AvatarPointer: UploadFile(fileAvatarPointer.Filename),
93 | Avatars: []File{*UploadFile(fileAvatar.Filename), *UploadFile(fileAvatarPointer.Filename)},
94 | AvatarPointers: []*File{UploadFile(fileAvatarPointer.Filename), UploadFile(fileAvatar.Filename)},
95 | PatchAvatar: patch.Field[File]{Value: *UploadFile(filePatchAvatar.Filename), Valid: true},
96 | PatchAvatarPointer: patch.Field[*File]{
97 | Value: UploadFile(filePatchAvatarPointer.Filename),
98 | Valid: true,
99 | },
100 | PatchAvatars: patch.Field[[]File]{
101 | Value: []File{
102 | *UploadFile(fileAvatar.Filename),
103 | *UploadFile(filePatchAvatar.Filename),
104 | *UploadFile(filePatchAvatarPointer.Filename),
105 | },
106 | Valid: true,
107 | },
108 | PatchAvatarPointers: patch.Field[[]*File]{
109 | Value: []*File{
110 | UploadFile(fileAvatar.Filename),
111 | UploadFile(fileAvatarPointer.Filename),
112 | UploadFile(filePatchAvatar.Filename),
113 | UploadFile(filePatchAvatarPointer.Filename),
114 | },
115 | Valid: true,
116 | },
117 | }
118 |
119 | rv := reflect.ValueOf(s).Elem()
120 | validateFileList(t, []*tempFile{fileAvatar}, testGetFileSlice(t, rv.FieldByName("Avatar")))
121 | validateFileList(t, []*tempFile{fileAvatarPointer}, testGetFileSlice(t, rv.FieldByName("AvatarPointer")))
122 | validateFileList(t, []*tempFile{fileAvatar, fileAvatarPointer}, testGetFileSlice(t, rv.FieldByName("Avatars")))
123 | validateFileList(t, []*tempFile{fileAvatarPointer, fileAvatar}, testGetFileSlice(t, rv.FieldByName("AvatarPointers")))
124 | validateFileList(t, []*tempFile{filePatchAvatar}, testGetFileSlice(t, rv.FieldByName("PatchAvatar")))
125 | validateFileList(t, []*tempFile{filePatchAvatarPointer}, testGetFileSlice(t, rv.FieldByName("PatchAvatarPointer")))
126 | validateFileList(t, []*tempFile{fileAvatar, filePatchAvatar, filePatchAvatarPointer}, testGetFileSlice(t, rv.FieldByName("PatchAvatars")))
127 | validateFileList(t, []*tempFile{fileAvatar, fileAvatarPointer, filePatchAvatar, filePatchAvatarPointer}, testGetFileSlice(t, rv.FieldByName("PatchAvatarPointers")))
128 | }
129 |
130 | func testAssignFileSlice(t *testing.T, rv reflect.Value, files []FileHeader) {
131 | fs, err := NewFileSlicable(rv)
132 | assert.NoError(t, err)
133 | assert.NoError(t, fs.FromFileSlice(files))
134 | }
135 |
136 | func testGetFileSlice(t *testing.T, rv reflect.Value) []FileMarshaler {
137 | fs, err := NewFileSlicable(rv)
138 | assert.NoError(t, err)
139 | files, err := fs.ToFileSlice()
140 | assert.NoError(t, err)
141 | return files
142 | }
143 |
144 | func validateFileList(t *testing.T, expected []*tempFile, actual []FileMarshaler) {
145 | assert.Len(t, actual, len(expected))
146 | for i, file := range expected {
147 | validateFile(t, file, actual[i])
148 | }
149 | }
150 |
--------------------------------------------------------------------------------
/core/form.go:
--------------------------------------------------------------------------------
1 | // directive: "form"
2 | // https://ggicci.github.io/httpin/directives/form
3 |
4 | package core
5 |
6 | import (
7 | "mime/multipart"
8 | )
9 |
10 | type DirectvieForm struct{}
11 |
12 | // Decode implements the "form" executor who extracts values from
13 | // the forms of an HTTP request.
14 | func (*DirectvieForm) Decode(rtm *DirectiveRuntime) error {
15 | req := rtm.GetRequest()
16 | var form multipart.Form
17 | if req.MultipartForm != nil {
18 | form = *req.MultipartForm
19 | } else {
20 | if req.Form != nil {
21 | form.Value = req.Form
22 | }
23 | }
24 | extractor := &FormExtractor{
25 | Runtime: rtm,
26 | Form: form,
27 | KeyNormalizer: nil,
28 | }
29 | return extractor.Extract()
30 | }
31 |
32 | // Encode implements the encoder/request builder for "form" directive.
33 | // It builds the form values of an HTTP request, including:
34 | // - form data
35 | // - multipart form data (file upload)
36 | func (*DirectvieForm) Encode(rtm *DirectiveRuntime) error {
37 | encoder := &FormEncoder{
38 | Setter: rtm.GetRequestBuilder().SetForm,
39 | }
40 | return encoder.Execute(rtm)
41 | }
42 |
--------------------------------------------------------------------------------
/core/form_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "encoding/base64"
5 | "io"
6 | "net/http"
7 | "net/url"
8 | "strings"
9 | "testing"
10 | "time"
11 |
12 | "github.com/ggicci/httpin/internal"
13 | "github.com/stretchr/testify/assert"
14 | )
15 |
// ChaosQuery is designed to make the normal case test coverage higher.
// It declares one "form" field for each basic kind, a pointer twin for each
// (key suffixed with "_pointer"), and several slice fields.
type ChaosQuery struct {
	// Basic Types
	BoolValue bool `in:"form=bool"`
	IntValue int `in:"form=int"`
	Int8Value int8 `in:"form=int8"`
	Int16Value int16 `in:"form=int16"`
	Int32Value int32 `in:"form=int32"`
	Int64Value int64 `in:"form=int64"`
	UintValue uint `in:"form=uint"`
	Uint8Value uint8 `in:"form=uint8"`
	Uint16Value uint16 `in:"form=uint16"`
	Uint32Value uint32 `in:"form=uint32"`
	Uint64Value uint64 `in:"form=uint64"`
	Float32Value float32 `in:"form=float32"`
	Float64Value float64 `in:"form=float64"`
	Complex64Value complex64 `in:"form=complex64"`
	Complex128Value complex128 `in:"form=complex128"`
	StringValue string `in:"form=string"`
	TimeValue time.Time `in:"form=time"` // time type is special

	// Pointer Types: same kinds as above, under "<key>_pointer" form keys.
	BoolPointer *bool `in:"form=bool_pointer"`
	IntPointer *int `in:"form=int_pointer"`
	Int8Pointer *int8 `in:"form=int8_pointer"`
	Int16Pointer *int16 `in:"form=int16_pointer"`
	Int32Pointer *int32 `in:"form=int32_pointer"`
	Int64Pointer *int64 `in:"form=int64_pointer"`
	UintPointer *uint `in:"form=uint_pointer"`
	Uint8Pointer *uint8 `in:"form=uint8_pointer"`
	Uint16Pointer *uint16 `in:"form=uint16_pointer"`
	Uint32Pointer *uint32 `in:"form=uint32_pointer"`
	Uint64Pointer *uint64 `in:"form=uint64_pointer"`
	Float32Pointer *float32 `in:"form=float32_pointer"`
	Float64Pointer *float64 `in:"form=float64_pointer"`
	Complex64Pointer *complex64 `in:"form=complex64_pointer"`
	Complex128Pointer *complex128 `in:"form=complex128_pointer"`
	StringPointer *string `in:"form=string_pointer"`
	TimePointer *time.Time `in:"form=time_pointer"`

	// Array: slice fields filled from repeated form values.
	BoolList []bool `in:"form=bools"`
	IntList []int `in:"form=ints"`
	FloatList []float64 `in:"form=floats"`
	StringList []string `in:"form=strings"`
	TimeList []time.Time `in:"form=times"`
}
63 |
var (
	// sampleChaosQuery is the expected result of decoding the form in
	// TestDirectiveForm_Decode. It is also the input that
	// TestDirectiveForm_NewRequest encodes. Every pointer field mirrors its
	// value-field counterpart, and all times are in UTC.
	sampleChaosQuery = &ChaosQuery{
		BoolValue: true,
		IntValue: 9,
		Int8Value: 14,
		Int16Value: 841,
		Int32Value: 193,
		Int64Value: 475,
		UintValue: 11,
		Uint8Value: 4,
		Uint16Value: 48,
		Uint32Value: 9583,
		Uint64Value: 183471,
		Float32Value: 3.14,
		Float64Value: 0.618,
		Complex64Value: 1 + 4i,
		Complex128Value: -6 + 17i,
		StringValue: "doggy",
		TimeValue: time.Date(1991, 11, 10, 0, 0, 0, 0, time.UTC),

		BoolPointer: internal.Pointerize[bool](true),
		IntPointer: internal.Pointerize[int](9),
		Int8Pointer: internal.Pointerize[int8](14),
		Int16Pointer: internal.Pointerize[int16](841),
		Int32Pointer: internal.Pointerize[int32](193),
		Int64Pointer: internal.Pointerize[int64](475),
		UintPointer: internal.Pointerize[uint](11),
		Uint8Pointer: internal.Pointerize[uint8](4),
		Uint16Pointer: internal.Pointerize[uint16](48),
		Uint32Pointer: internal.Pointerize[uint32](9583),
		Uint64Pointer: internal.Pointerize[uint64](183471),
		Float32Pointer: internal.Pointerize[float32](3.14),
		Float64Pointer: internal.Pointerize[float64](0.618),
		Complex64Pointer: internal.Pointerize[complex64](1 + 4i),
		Complex128Pointer: internal.Pointerize[complex128](-6 + 17i),
		StringPointer: internal.Pointerize[string]("doggy"),
		TimePointer: internal.Pointerize[time.Time](time.Date(1991, 11, 10, 0, 0, 0, 0, time.UTC)),

		BoolList: []bool{true, false, false, true},
		IntList: []int{9, 9, 6},
		FloatList: []float64{0.0, 0.5, 1.0},
		StringList: []string{"Life", "is", "a", "Miracle"},
		TimeList: []time.Time{
			time.Date(2000, 1, 2, 22, 4, 5, 0, time.UTC),
			time.Date(1991, 6, 28, 6, 0, 0, 0, time.UTC),
		},
	}
)
112 |
113 | func TestDirectiveForm_Decode(t *testing.T) {
114 | r, _ := http.NewRequest("GET", "/", nil)
115 | r.Form = url.Values{
116 | "bool": {"true"},
117 | "int": {"9"},
118 | "int8": {"14"},
119 | "int16": {"841"},
120 | "int32": {"193"},
121 | "int64": {"475"},
122 | "uint": {"11"},
123 | "uint8": {"4"},
124 | "uint16": {"48"},
125 | "uint32": {"9583"},
126 | "uint64": {"183471"},
127 | "float32": {"3.14"},
128 | "float64": {"0.618"},
129 | "complex64": {"1+4i"},
130 | "complex128": {"-6+17i"},
131 | "string": {"doggy"},
132 | "time": {"1991-11-10T08:00:00+08:00"},
133 |
134 | "bool_pointer": {"true"},
135 | "int_pointer": {"9"},
136 | "int8_pointer": {"14"},
137 | "int16_pointer": {"841"},
138 | "int32_pointer": {"193"},
139 | "int64_pointer": {"475"},
140 | "uint_pointer": {"11"},
141 | "uint8_pointer": {"4"},
142 | "uint16_pointer": {"48"},
143 | "uint32_pointer": {"9583"},
144 | "uint64_pointer": {"183471"},
145 | "float32_pointer": {"3.14"},
146 | "float64_pointer": {"0.618"},
147 | "complex64_pointer": {"1+4i"},
148 | "complex128_pointer": {"-6+17i"},
149 | "string_pointer": {"doggy"},
150 | "time_pointer": {"1991-11-10T08:00:00+08:00"},
151 |
152 | "bools": {"true", "false", "0", "1"},
153 | "ints": {"9", "9", "6"},
154 | "floats": {"0", "0.5", "1"},
155 | "strings": {"Life", "is", "a", "Miracle"},
156 | "times": {"2000-01-02T15:04:05-07:00", "678088800"},
157 | }
158 | expected := sampleChaosQuery
159 | co, err := New(ChaosQuery{})
160 | assert.NoError(t, err)
161 | got, err := co.Decode(r)
162 | assert.NoError(t, err)
163 | assert.Equal(t, expected, got.(*ChaosQuery))
164 | }
165 |
166 | func TestDirectiveForm_NewRequest(t *testing.T) {
167 | co, err := New(ChaosQuery{})
168 | assert.NoError(t, err)
169 | req, err := co.NewRequest("POST", "/signup", sampleChaosQuery)
170 | assert.NoError(t, err)
171 |
172 | expected, _ := http.NewRequest("POST", "/signup", nil)
173 | expectedForm := url.Values{
174 | "bool": {"true"},
175 | "int": {"9"},
176 | "int8": {"14"},
177 | "int16": {"841"},
178 | "int32": {"193"},
179 | "int64": {"475"},
180 | "uint": {"11"},
181 | "uint8": {"4"},
182 | "uint16": {"48"},
183 | "uint32": {"9583"},
184 | "uint64": {"183471"},
185 | "float32": {"3.14"},
186 | "float64": {"0.618"},
187 | "complex64": {"(1+4i)"},
188 | "complex128": {"(-6+17i)"},
189 | "string": {"doggy"},
190 | "time": {"1991-11-10T00:00:00Z"},
191 |
192 | "bool_pointer": {"true"},
193 | "int_pointer": {"9"},
194 | "int8_pointer": {"14"},
195 | "int16_pointer": {"841"},
196 | "int32_pointer": {"193"},
197 | "int64_pointer": {"475"},
198 | "uint_pointer": {"11"},
199 | "uint8_pointer": {"4"},
200 | "uint16_pointer": {"48"},
201 | "uint32_pointer": {"9583"},
202 | "uint64_pointer": {"183471"},
203 | "float32_pointer": {"3.14"},
204 | "float64_pointer": {"0.618"},
205 | "complex64_pointer": {"(1+4i)"},
206 | "complex128_pointer": {"(-6+17i)"},
207 | "string_pointer": {"doggy"},
208 | "time_pointer": {"1991-11-10T00:00:00Z"},
209 |
210 | "bools": {"true", "false", "false", "true"},
211 | "ints": {"9", "9", "6"},
212 | "floats": {"0", "0.5", "1"},
213 | "strings": {"Life", "is", "a", "Miracle"},
214 | "times": {"2000-01-02T22:04:05Z", "1991-06-28T06:00:00Z"},
215 | }
216 | expected.Body = io.NopCloser(strings.NewReader(expectedForm.Encode()))
217 | expected.Header.Set("Content-Type", "application/x-www-form-urlencoded")
218 | assert.Equal(t, expected, req)
219 | }
220 |
221 | func TestDirectiveForm_NewRequest_ByteSlice(t *testing.T) {
222 | type ByteSlice struct {
223 | Bytes []byte `in:"form=bytes"`
224 | MultiBytes [][]byte `in:"form=multi_bytes"`
225 | }
226 | co, err := New(ByteSlice{})
227 | assert.NoError(t, err)
228 | payload := &ByteSlice{
229 | Bytes: []byte("hello"),
230 | MultiBytes: [][]byte{
231 | []byte("hello"),
232 | []byte("world"),
233 | },
234 | }
235 | expected, _ := http.NewRequest("POST", "/api", nil)
236 | expectedForm := url.Values{
237 | "bytes": {base64.StdEncoding.EncodeToString(payload.Bytes)},
238 | "multi_bytes": {
239 | base64.StdEncoding.EncodeToString(payload.MultiBytes[0]),
240 | base64.StdEncoding.EncodeToString(payload.MultiBytes[1]),
241 | },
242 | }
243 | expected.Header.Set("Content-Type", "application/x-www-form-urlencoded")
244 | expected.Body = io.NopCloser(strings.NewReader(expectedForm.Encode()))
245 | req, err := co.NewRequest("POST", "/api", payload)
246 | assert.NoError(t, err)
247 | assert.Equal(t, expected, req)
248 | }
249 |
--------------------------------------------------------------------------------
/core/formencoder.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "github.com/ggicci/httpin/internal"
5 | )
6 |
7 | type FormEncoder struct {
8 | Setter func(key string, value []string) // form value setter
9 | }
10 |
11 | func (e *FormEncoder) Execute(rtm *DirectiveRuntime) error {
12 | if rtm.Value.IsZero() {
13 | if rtm.Resolver.GetDirective("omitempty") != nil {
14 | return nil
15 | }
16 | }
17 |
18 | if rtm.IsFieldSet() {
19 | return nil // skip when already encoded by former directives
20 | }
21 |
22 | key := rtm.Directive.Argv[0]
23 | valueType := rtm.Value.Type()
24 | // When baseType is a file type, we treat it as a file upload.
25 | if isFileType(valueType) {
26 | if internal.IsNil(rtm.Value) {
27 | return nil // skip when nil, which means no file uploaded
28 | }
29 |
30 | encoder, err := NewFileSlicable(rtm.Value)
31 | if err != nil {
32 | return err
33 | }
34 | files, err := encoder.ToFileSlice()
35 | if err != nil {
36 | return err
37 | }
38 | if len(files) == 0 {
39 | return nil // skip when no file uploaded
40 | }
41 | return fileUploadBuilder(rtm, files)
42 | }
43 |
44 | var adapt AnyStringableAdaptor
45 | encoderInfo := rtm.GetCustomCoder()
46 | if encoderInfo != nil {
47 | adapt = encoderInfo.Adapt
48 | }
49 | var encoder StringSlicable
50 | encoder, err := NewStringSlicable(rtm.Value, adapt)
51 | if err != nil {
52 | return err
53 | }
54 |
55 | if values, err := encoder.ToStringSlice(); err != nil {
56 | return err
57 | } else {
58 | e.Setter(key, values)
59 | rtm.MarkFieldSet(true)
60 | return nil
61 | }
62 | }
63 |
64 | func fileUploadBuilder(rtm *DirectiveRuntime, files []FileMarshaler) error {
65 | rb := rtm.GetRequestBuilder()
66 | key := rtm.Directive.Argv[0]
67 | rb.SetAttachment(key, files)
68 | rtm.MarkFieldSet(true)
69 | return nil
70 | }
71 |
--------------------------------------------------------------------------------
/core/formencoder_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "net/http"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
10 | func TestFormEncoder_FieldSetByFormerDirectives(t *testing.T) {
11 | type SearchQuery struct {
12 | AccessToken string `in:"query=access_token;header=x-api-key"`
13 | }
14 |
15 | co, err := New(&SearchQuery{})
16 | assert.NoError(t, err)
17 |
18 | req, err := co.NewRequest("GET", "/search", &SearchQuery{
19 | AccessToken: "123456",
20 | })
21 | assert.NoError(t, err)
22 |
23 | // The AccessToken field should be set by the query directive.
24 | expected, _ := http.NewRequest("GET", "/search?access_token=123456", nil)
25 | assert.Equal(t, expected, req)
26 | }
27 |
--------------------------------------------------------------------------------
/core/formextractor.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "mime/multipart"
5 | )
6 |
7 | type FormExtractor struct {
8 | Runtime *DirectiveRuntime
9 | multipart.Form
10 | KeyNormalizer func(string) string
11 | }
12 |
13 | func (e *FormExtractor) Extract(keys ...string) error {
14 | if len(keys) == 0 {
15 | keys = e.Runtime.Directive.Argv
16 | }
17 | for _, key := range keys {
18 | if e.KeyNormalizer != nil {
19 | key = e.KeyNormalizer(key)
20 | }
21 | if err := e.extract(key); err != nil {
22 | return err
23 | }
24 | }
25 | return nil
26 | }
27 |
28 | func (e *FormExtractor) extract(key string) error {
29 | if e.Runtime.IsFieldSet() {
30 | return nil // skip when already extracted by former directives
31 | }
32 |
33 | values := e.Form.Value[key]
34 | files := e.Form.File[key]
35 |
36 | // Quick fail on empty input.
37 | if len(values) == 0 && len(files) == 0 {
38 | return nil
39 | }
40 |
41 | var sourceValue any
42 | var err error
43 | valueType := e.Runtime.Value.Type().Elem()
44 | if isFileType(valueType) {
45 | // When fileDecoder is not nil, it means that the field is a file upload.
46 | // We should decode files instead of values.
47 | if len(files) == 0 {
48 | return nil // skip when no file uploaded
49 | }
50 | sourceValue = files
51 |
52 | var decoder FileSlicable
53 | decoder, err = NewFileSlicable(e.Runtime.Value.Elem())
54 | if err == nil {
55 | err = decoder.FromFileSlice(toFileHeaderList(files))
56 | }
57 | } else {
58 | if len(values) == 0 {
59 | return nil // skip when no value given
60 | }
61 | sourceValue = values
62 |
63 | var adapt AnyStringableAdaptor
64 | decoderInfo := e.Runtime.GetCustomCoder() // custom decoder, specified by "decoder" directive
65 | if decoderInfo != nil {
66 | adapt = decoderInfo.Adapt
67 | }
68 | var decoder StringSlicable
69 | decoder, err = NewStringSlicable(e.Runtime.Value.Elem(), adapt)
70 | if err == nil {
71 | err = decoder.FromStringSlice(values)
72 | }
73 | }
74 |
75 | if err != nil {
76 | return &fieldError{key, sourceValue, err}
77 | }
78 | e.Runtime.MarkFieldSet(true)
79 | return nil
80 | }
81 |
--------------------------------------------------------------------------------
/core/header.go:
--------------------------------------------------------------------------------
1 | // directive: "header"
2 | // https://ggicci.github.io/httpin/directives/header
3 |
4 | package core
5 |
6 | import (
7 | "mime/multipart"
8 | "net/http"
9 | )
10 |
11 | type DirectiveHeader struct{}
12 |
13 | // Decode implements the "header" executor who extracts values
14 | // from the HTTP headers.
15 | func (*DirectiveHeader) Decode(rtm *DirectiveRuntime) error {
16 | req := rtm.GetRequest()
17 | extractor := &FormExtractor{
18 | Runtime: rtm,
19 | Form: multipart.Form{
20 | Value: req.Header,
21 | },
22 | KeyNormalizer: http.CanonicalHeaderKey,
23 | }
24 | return extractor.Extract()
25 | }
26 |
27 | func (*DirectiveHeader) Encode(rtm *DirectiveRuntime) error {
28 | encoder := &FormEncoder{
29 | Setter: rtm.GetRequestBuilder().SetHeader,
30 | }
31 | return encoder.Execute(rtm)
32 | }
33 |
--------------------------------------------------------------------------------
/core/header_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "net/http"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
10 | func TestDirectiveHeader_Decode(t *testing.T) {
11 | type SearchQuery struct {
12 | ApiUid int `in:"header=x-api-uid"`
13 | ApiToken string `in:"header=x-api-token"`
14 | }
15 |
16 | r, _ := http.NewRequest("GET", "/", nil)
17 | r.Header.Set("X-Api-Token", "some-secret-token")
18 | r.Header.Set("X-Api-Uid", "91241844")
19 | expected := &SearchQuery{
20 | ApiUid: 91241844,
21 | ApiToken: "some-secret-token",
22 | }
23 | co, err := New(SearchQuery{})
24 | assert.NoError(t, err)
25 | got, err := co.Decode(r)
26 | assert.NoError(t, err)
27 | assert.Equal(t, expected, got.(*SearchQuery))
28 | }
29 |
30 | func TestDirectiveHeader_NewRequest(t *testing.T) {
31 | type ApiQuery struct {
32 | ApiUid int `in:"header=x-api-uid;omitempty"`
33 | ApiToken *string `in:"header=X-Api-Token;omitempty"`
34 | }
35 |
36 | t.Run("with all values", func(t *testing.T) {
37 | tk := "some-secret-token"
38 | query := &ApiQuery{
39 | ApiUid: 91241844,
40 | ApiToken: &tk,
41 | }
42 |
43 | co, err := New(ApiQuery{})
44 | assert.NoError(t, err)
45 | req, err := co.NewRequest("POST", "/api", query)
46 | assert.NoError(t, err)
47 |
48 | expected, _ := http.NewRequest("POST", "/api", nil)
49 | // NOTE: the key will be canonicalized
50 | expected.Header.Set("x-api-uid", "91241844")
51 | expected.Header.Set("X-Api-Token", "some-secret-token")
52 | assert.Equal(t, expected, req)
53 | })
54 |
55 | t.Run("with empty value", func(t *testing.T) {
56 | query := &ApiQuery{
57 | ApiUid: 0,
58 | ApiToken: nil,
59 | }
60 |
61 | co, err := New(ApiQuery{})
62 | assert.NoError(t, err)
63 | req, err := co.NewRequest("POST", "/api", query)
64 | assert.NoError(t, err)
65 |
66 | expected, _ := http.NewRequest("POST", "/api", nil)
67 | assert.Equal(t, expected, req)
68 |
69 | _, ok := req.Header["X-Api-Uid"]
70 | assert.False(t, ok)
71 |
72 | _, ok = req.Header["X-Api-Token"]
73 | assert.False(t, ok)
74 | })
75 | }
76 |
--------------------------------------------------------------------------------
/core/hybridcoder.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "encoding"
5 | "errors"
6 | "reflect"
7 |
8 | "github.com/ggicci/httpin/internal"
9 | )
10 |
11 | type HybridCoder struct {
12 | internal.StringMarshaler
13 | internal.StringUnmarshaler
14 | }
15 |
16 | func (c *HybridCoder) ToString() (string, error) {
17 | if c.StringMarshaler != nil {
18 | return c.StringMarshaler.ToString()
19 | }
20 | return "", errors.New("StringMarshaler not implemented")
21 | }
22 |
23 | func (c *HybridCoder) FromString(s string) error {
24 | if c.StringUnmarshaler != nil {
25 | return c.StringUnmarshaler.FromString(s)
26 | }
27 | return errors.New("StringUnmarshaler not implemented")
28 | }
29 |
30 | // Hybridize a reflect.Value to a Stringable if possible.
31 | func hybridizeCoder(rv reflect.Value) Stringable {
32 | if !rv.CanInterface() {
33 | return nil
34 | }
35 |
36 | coder := &HybridCoder{}
37 |
38 | // Interface: StringMarshaler.
39 | if rv.Type().Implements(stringMarshalerType) {
40 | coder.StringMarshaler = rv.Interface().(internal.StringMarshaler)
41 | } else if rv.Type().Implements(textMarshalerType) {
42 | coder.StringMarshaler = &textMarshalerWrapper{rv.Interface().(encoding.TextMarshaler), nil}
43 | }
44 |
45 | // Interface: StringUnmarshaler.
46 | if rv.Type().Implements(stringUnmarshalerType) {
47 | coder.StringUnmarshaler = rv.Interface().(internal.StringUnmarshaler)
48 | } else if rv.Type().Implements(textUnmarshalerType) {
49 | coder.StringUnmarshaler = &textMarshalerWrapper{nil, rv.Interface().(encoding.TextUnmarshaler)}
50 | }
51 |
52 | if coder.StringMarshaler == nil && coder.StringUnmarshaler == nil {
53 | return nil
54 | }
55 |
56 | return coder
57 | }
58 |
// textMarshalerWrapper adapts the standard encoding.TextMarshaler and
// encoding.TextUnmarshaler interfaces to the ToString/FromString pair. Only the
// half used by the method being called needs to be non-nil.
type textMarshalerWrapper struct {
	encoding.TextMarshaler
	encoding.TextUnmarshaler
}

// ToString returns the MarshalText output as a string. If marshaling fails,
// it returns "" and the error.
func (w textMarshalerWrapper) ToString() (string, error) {
	text, err := w.TextMarshaler.MarshalText()
	if err == nil {
		return string(text), nil
	}
	return "", err
}

// FromString passes s, converted to bytes, to UnmarshalText.
func (w textMarshalerWrapper) FromString(s string) error {
	raw := []byte(s)
	return w.TextUnmarshaler.UnmarshalText(raw)
}
75 |
76 | var (
77 | stringMarshalerType = internal.TypeOf[internal.StringMarshaler]()
78 | stringUnmarshalerType = internal.TypeOf[internal.StringUnmarshaler]()
79 | textMarshalerType = internal.TypeOf[encoding.TextMarshaler]()
80 | textUnmarshalerType = internal.TypeOf[encoding.TextUnmarshaler]()
81 | )
82 |
--------------------------------------------------------------------------------
/core/hybridcoder_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "errors"
5 | "reflect"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestHybridCoder_MarshalTextOnly(t *testing.T) {
12 | apple := &textMarshalerApple{}
13 | rv := reflect.ValueOf(apple)
14 | stringable := hybridizeCoder(rv)
15 | assert.NotNil(t, stringable)
16 |
17 | text, err := stringable.ToString()
18 | assert.NoError(t, err)
19 | assert.Equal(t, "apple", text)
20 |
21 | assert.ErrorContains(t, stringable.FromString("red apple"), "StringUnmarshaler not implemented")
22 | }
23 |
24 | func TestHybridCoder_UnmarshalTextOnly(t *testing.T) {
25 | banana := &textUnmarshalerBanana{}
26 | rv := reflect.ValueOf(banana)
27 | stringable := hybridizeCoder(rv)
28 | assert.NotNil(t, stringable)
29 |
30 | text, err := stringable.ToString()
31 | assert.ErrorContains(t, err, "StringMarshaler not implemented")
32 | assert.Empty(t, text)
33 |
34 | err = stringable.FromString("yellow banana")
35 | assert.NoError(t, err)
36 | assert.Equal(t, "yellow banana", banana.Content)
37 | }
38 |
39 | func TestHybridCoder_MarshalText_and_UnmarshalText(t *testing.T) {
40 | orange := &textMarshalerAndUnmarshalerOrange{Content: "orange"}
41 | rv := reflect.ValueOf(orange)
42 | stringable := hybridizeCoder(rv)
43 | assert.NotNil(t, stringable)
44 |
45 | text, err := stringable.ToString()
46 | assert.NoError(t, err)
47 | assert.Equal(t, "orange", text)
48 |
49 | err = stringable.FromString("red orange")
50 | assert.NoError(t, err)
51 | assert.Equal(t, "red orange", orange.Content)
52 | }
53 |
54 | func TestHybridCoder_StringMarshaler_TakesPrecedence(t *testing.T) {
55 | peach := &stringMarshalerAndTextMarshalerPeach{Content: "peach"}
56 | rv := reflect.ValueOf(peach)
57 | stringable := hybridizeCoder(rv)
58 | assert.NotNil(t, stringable)
59 |
60 | text, err := stringable.ToString()
61 | assert.NoError(t, err)
62 | assert.Equal(t, "ToString:peach", text)
63 |
64 | err = stringable.FromString("red peach")
65 | assert.ErrorContains(t, err, "StringUnmarshaler not implemented")
66 | }
67 |
68 | func TestHybridCoder_StringUnmarshaler_TakesPrecedence(t *testing.T) {
69 | peach := &stringUnmarshalerAndTextUnmarshalerPeach{Content: "peach"}
70 | rv := reflect.ValueOf(peach)
71 | stringable := hybridizeCoder(rv)
72 | assert.NotNil(t, stringable)
73 |
74 | text, err := stringable.ToString()
75 | assert.ErrorContains(t, err, "StringMarshaler not implemented")
76 | assert.Empty(t, text)
77 |
78 | err = stringable.FromString("red peach")
79 | assert.NoError(t, err)
80 | assert.Equal(t, "FromString:red peach", peach.Content)
81 | }
82 |
83 | func TestHybridCoder_StringMarshaler_and_TextUnmarshaler(t *testing.T) {
84 | pineapple := &stringMarshalerAndTextUnmarshalerPineapple{Content: "pineapple"}
85 | rv := reflect.ValueOf(pineapple)
86 | stringable := hybridizeCoder(rv)
87 | assert.NotNil(t, stringable)
88 |
89 | text, err := stringable.ToString()
90 | assert.NoError(t, err)
91 | assert.Equal(t, "ToString:pineapple", text)
92 |
93 | err = stringable.FromString("red pineapple")
94 | assert.NoError(t, err)
95 | assert.Equal(t, "UnmarshalText:red pineapple", pineapple.Content)
96 | }
97 |
98 | func TestHybridCoder_MarshalText_Error(t *testing.T) {
99 | watermelon := &textMarshalerSpoiledWatermelon{}
100 | rv := reflect.ValueOf(watermelon)
101 | stringable := hybridizeCoder(rv)
102 | assert.NotNil(t, stringable)
103 |
104 | text, err := stringable.ToString()
105 | assert.ErrorContains(t, err, "spoiled")
106 | assert.Empty(t, text)
107 | }
108 |
109 | func TestHybridCoder_ErrCannotInterface(t *testing.T) {
110 | type mystruct struct {
111 | unexportedName string
112 | }
113 | v := mystruct{unexportedName: "mystruct"}
114 | rv := reflect.ValueOf(v)
115 |
116 | stringable := hybridizeCoder(rv.Field(0))
117 | assert.Nil(t, stringable)
118 | }
119 |
120 | func TestHybridCoder_NilOnNoInterfacesDetected(t *testing.T) {
121 | var zero zeroInterface
122 | rv := reflect.ValueOf(zero)
123 |
124 | stringable := hybridizeCoder(rv)
125 | assert.Nil(t, stringable)
126 | }
127 |
// textMarshalerApple implements only encoding.TextMarshaler.
type textMarshalerApple struct{}

func (a *textMarshalerApple) MarshalText() ([]byte, error) {
	return []byte("apple"), nil
}

// textUnmarshalerBanana implements only encoding.TextUnmarshaler.
type textUnmarshalerBanana struct{ Content string }

func (b *textUnmarshalerBanana) UnmarshalText(text []byte) error {
	b.Content = string(text)
	return nil
}

// textMarshalerAndUnmarshalerOrange implements both encoding.TextMarshaler and
// encoding.TextUnmarshaler, storing and returning Content verbatim.
type textMarshalerAndUnmarshalerOrange struct{ Content string }

func (o *textMarshalerAndUnmarshalerOrange) MarshalText() ([]byte, error) {
	return []byte(o.Content), nil
}

func (o *textMarshalerAndUnmarshalerOrange) UnmarshalText(text []byte) error {
	o.Content = string(text)
	return nil
}

// stringMarshalerAndTextMarshalerPeach implements both internal.StringMarshaler
// and encoding.TextMarshaler. The coder should pick ToString.
type stringMarshalerAndTextMarshalerPeach struct{ Content string }

func (p *stringMarshalerAndTextMarshalerPeach) ToString() (string, error) {
	return "ToString:" + p.Content, nil
}

func (p *stringMarshalerAndTextMarshalerPeach) MarshalText() ([]byte, error) {
	out := "MarshalText:" + p.Content
	return []byte(out), nil
}

// stringUnmarshalerAndTextUnmarshalerPeach implements both
// internal.StringUnmarshaler and encoding.TextUnmarshaler. The coder should
// pick FromString.
type stringUnmarshalerAndTextUnmarshalerPeach struct{ Content string }

func (p *stringUnmarshalerAndTextUnmarshalerPeach) FromString(text string) error {
	p.Content = "FromString:" + text
	return nil
}

func (p *stringUnmarshalerAndTextUnmarshalerPeach) UnmarshalText(text []byte) error {
	p.Content = "UnmarshalText:" + string(text)
	return nil
}

// stringMarshalerAndTextUnmarshalerPineapple mixes the two families: it
// encodes via ToString and decodes via UnmarshalText.
type stringMarshalerAndTextUnmarshalerPineapple struct{ Content string }

func (p *stringMarshalerAndTextUnmarshalerPineapple) ToString() (string, error) {
	return "ToString:" + p.Content, nil
}

func (p *stringMarshalerAndTextUnmarshalerPineapple) UnmarshalText(text []byte) error {
	p.Content = "UnmarshalText:" + string(text)
	return nil
}

// textMarshalerSpoiledWatermelon always fails to marshal.
type textMarshalerSpoiledWatermelon struct{}

func (w *textMarshalerSpoiledWatermelon) MarshalText() ([]byte, error) {
	return nil, errors.New("spoiled")
}

// zeroInterface implements none of the coding interfaces.
type zeroInterface struct{}
196 |
--------------------------------------------------------------------------------
/core/nonzero.go:
--------------------------------------------------------------------------------
1 | // directive: "nonzero"
2 | // https://ggicci.github.io/httpin/directives/nonzero
3 |
4 | package core
5 |
6 | import "errors"
7 |
8 | // DirectiveNonzero implements the "nonzero" executor who indicates that the field must not be a "zero value".
9 | // In golang, the "zero value" means:
10 | // - nil
11 | // - false
12 | // - 0
13 | // - ""
14 | // - etc.
15 | //
16 | // Unlike the "required" executor, the "nonzero" executor checks the value of the field.
17 | type DirectiveNonzero struct{}
18 |
19 | func (*DirectiveNonzero) Decode(rtm *DirectiveRuntime) error {
20 | if rtm.Value.Elem().IsZero() {
21 | return errors.New("zero value")
22 | }
23 | return nil
24 | }
25 |
26 | func (*DirectiveNonzero) Encode(rtm *DirectiveRuntime) error {
27 | if rtm.Value.IsZero() {
28 | return errors.New("zero value")
29 | }
30 | return nil
31 | }
32 |
--------------------------------------------------------------------------------
/core/nonzero_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "net/http"
5 | "net/url"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | type NonzeroQuery struct {
12 | Name string `in:"query=name;nonzero"`
13 | AgeRange []int `in:"query=age;nonzero"`
14 | }
15 |
16 | func TestDirectiveNonzero_Decode(t *testing.T) {
17 | co, err := New(&NonzeroQuery{})
18 | assert.NoError(t, err)
19 |
20 | r, _ := http.NewRequest("GET", "/users", nil)
21 | r.URL.RawQuery = url.Values{
22 | "name": {"ggicci"},
23 | "age": {"18", "999"},
24 | }.Encode()
25 |
26 | got, err := co.Decode(r)
27 | assert.NoError(t, err)
28 | assert.Equal(t, &NonzeroQuery{
29 | Name: "ggicci",
30 | AgeRange: []int{18, 999},
31 | }, got.(*NonzeroQuery))
32 | }
33 |
34 | func TestDirectiveNonzero_Decode_ErrZeroValue(t *testing.T) {
35 | co, err := New(&NonzeroQuery{})
36 | assert.NoError(t, err)
37 |
38 | r, _ := http.NewRequest("GET", "/users", nil)
39 | r.URL.RawQuery = url.Values{
40 | "name": {"ggicci"},
41 | }.Encode()
42 |
43 | _, err = co.Decode(r)
44 | assert.Error(t, err)
45 | var invalidField *InvalidFieldError
46 | assert.ErrorAs(t, err, &invalidField)
47 | assert.Equal(t, "AgeRange", invalidField.Field)
48 | assert.Equal(t, "nonzero", invalidField.Directive)
49 | assert.Empty(t, invalidField.Key)
50 | assert.Nil(t, invalidField.Value)
51 | }
52 |
53 | func TestDirectiveNonzero_Decode_InNestedJSONBody_Issue49(t *testing.T) {
54 | type UpdateUserInput struct {
55 | Payload struct {
56 | Display string `json:"display" in:"nonzero"`
57 | } `in:"body=json"`
58 | }
59 |
60 | // NOTE: WithNestedDirectivesEnabled(true) is required to enable nested directives.
61 | co, err := New(&UpdateUserInput{}, WithNestedDirectivesEnabled(true))
62 | assert.NoError(t, err)
63 |
64 | r, _ := http.NewRequest("POST", "/users/1", nil)
65 | r.Header.Set("Content-Type", "application/json")
66 | r.Body = makeBodyReader(`{"display": ""}`)
67 | got, err := co.Decode(r)
68 | assert.Nil(t, got)
69 | assert.ErrorContains(t, err, "nonzero")
70 | var invalidField *InvalidFieldError
71 | assert.ErrorAs(t, err, &invalidField)
72 | assert.Equal(t, "Payload", invalidField.Field)
73 | assert.Equal(t, "nonzero", invalidField.Directive)
74 | }
75 |
76 | func TestDirectiveNonzero_NewRequest(t *testing.T) {
77 | co, err := New(&NonzeroQuery{})
78 | assert.NoError(t, err)
79 |
80 | expected, _ := http.NewRequest("GET", "/users", nil)
81 | expected.URL.RawQuery = url.Values{
82 | "name": {"ggicci"},
83 | "age": {"18", "999"},
84 | }.Encode()
85 |
86 | req, err := co.NewRequest("GET", "/users", &NonzeroQuery{
87 | Name: "ggicci",
88 | AgeRange: []int{18, 999},
89 | })
90 | assert.NoError(t, err)
91 | assert.Equal(t, expected, req)
92 | }
93 |
94 | func TestDirectiveNonzero_NewRequest_ErrZeroValue(t *testing.T) {
95 | co, err := New(&NonzeroQuery{})
96 | assert.NoError(t, err)
97 |
98 | _, err = co.NewRequest("GET", "/users", &NonzeroQuery{})
99 | assert.ErrorContains(t, err, "zero value")
100 | assert.ErrorContains(t, err, "Name")
101 | assert.ErrorContains(t, err, "AgeRange")
102 | }
103 |
--------------------------------------------------------------------------------
/core/omitempty.go:
--------------------------------------------------------------------------------
1 | // directive: "omitempty"
2 | // https://ggicci.github.io/httpin/directives/omitempty
3 |
4 | package core
5 |
6 | // DirectiveOmitEmpty is used with the DirectiveQuery, DirectiveForm, and DirectiveHeader to indicate that the field
7 | // should be omitted when the value is empty.
8 | // It does not have any affect when used by itself
9 | type DirectiveOmitEmpty struct{}
10 |
11 | func (*DirectiveOmitEmpty) Decode(_ *DirectiveRuntime) error {
12 | return nil
13 | }
14 |
15 | func (*DirectiveOmitEmpty) Encode(_ *DirectiveRuntime) error {
16 | return nil
17 | }
18 |
--------------------------------------------------------------------------------
/core/option.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "errors"
5 | )
6 |
// Limits for the memory budget used when reading request bodies.
const (
	minimumMaxMemory = int64(1 << 10)  // 1KB: smallest value WithMaxMemory accepts
	defaultMaxMemory = int64(32 << 20) // 32 MB
)
9 |
10 | type Option func(*Core) error
11 |
// globalNestedDirectivesEnabled is the package-wide default for nested
// directives. It is toggled by EnableNestedDirectives and starts out false
// (disabled). The explicit `= false` initializer was redundant.
var globalNestedDirectivesEnabled bool

// EnableNestedDirectives sets the global flag to enable nested directives.
// Nested directives are disabled by default.
//
// NOTE(review): this writes package-level state without synchronization, so
// call it during program initialization rather than concurrently with code
// that constructs Cores.
func EnableNestedDirectives(on bool) {
	globalNestedDirectivesEnabled = on
}
19 |
20 | // WithErrorHandler overrides the default error handler.
21 | func WithErrorHandler(custom ErrorHandler) Option {
22 | return func(c *Core) error {
23 | if err := validateErrorHandler(custom); err != nil {
24 | return err
25 | } else {
26 | c.errorHandler = custom
27 | return nil
28 | }
29 | }
30 | }
31 |
32 | // WithMaxMemory overrides the default maximum memory size (32MB) when reading
33 | // the request body. See https://pkg.go.dev/net/http#Request.ParseMultipartForm
34 | // for more details.
35 | func WithMaxMemory(maxMemory int64) Option {
36 | return func(c *Core) error {
37 | if maxMemory < minimumMaxMemory {
38 | return errors.New("max memory too small")
39 | }
40 | c.maxMemory = maxMemory
41 | return nil
42 | }
43 | }
44 |
45 | // WithNestedDirectivesEnabled enables/disables nested directives.
46 | func WithNestedDirectivesEnabled(enable bool) Option {
47 | return func(c *Core) error {
48 | c.enableNestedDirectives = enable
49 | return nil
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/core/option_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "net/http"
5 | "reflect"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestWithErrorHandler(t *testing.T) {
12 | // Use the default error handler.
13 | co, _ := New(ProductQuery{})
14 | assert.True(t, equalFuncs(globalCustomErrorHandler, co.GetErrorHandler()))
15 |
16 | // Override the default error handler.
17 | myErrorHandler := func(rw http.ResponseWriter, r *http.Request, err error) {}
18 | co, _ = New(ProductQuery{}, WithErrorHandler(myErrorHandler))
19 | assert.True(t, equalFuncs(myErrorHandler, co.GetErrorHandler()))
20 |
21 | // Fail on nil error handler.
22 | _, err := New(ProductQuery{}, WithErrorHandler(nil))
23 | assert.ErrorContains(t, err, "nil error handler")
24 | }
25 |
26 | func TestWithMaxMemory(t *testing.T) {
27 | // Use the default max memory.
28 | co, _ := New(ProductQuery{})
29 | assert.Equal(t, defaultMaxMemory, co.maxMemory)
30 |
31 | // Override the default max memory.
32 | co, _ = New(ProductQuery{}, WithMaxMemory(16<<20))
33 | assert.Equal(t, int64(16<<20), co.maxMemory)
34 |
35 | // Fail on too small max memory.
36 | _, err := New(ProductQuery{}, WithMaxMemory(100))
37 | assert.ErrorContains(t, err, "max memory too small")
38 | }
39 |
40 | func TestWithNestedDirectivesEnabled(t *testing.T) {
41 | // Override the default nested directives flag.
42 | co, _ := New(ProductQuery{}, WithNestedDirectivesEnabled(true))
43 | assert.Equal(t, true, co.enableNestedDirectives)
44 | co, _ = New(ProductQuery{}, WithNestedDirectivesEnabled(false))
45 | assert.Equal(t, false, co.enableNestedDirectives)
46 | }
47 |
48 | func TestEnableNestedDirectives(t *testing.T) {
49 | // Use the default nested directives flag.
50 | EnableNestedDirectives(false)
51 | co, _ := New(ProductQuery{})
52 | assert.Equal(t, false, co.enableNestedDirectives)
53 |
54 | EnableNestedDirectives(true)
55 | co, _ = New(ProductQuery{})
56 | assert.Equal(t, true, co.enableNestedDirectives)
57 | }
58 |
// equalFuncs reports whether two function values share the same code pointer,
// as given by reflect.Value.Pointer.
func equalFuncs(expected, actual any) bool {
	lhs := reflect.ValueOf(expected).Pointer()
	rhs := reflect.ValueOf(actual).Pointer()
	return lhs == rhs
}
62 |
--------------------------------------------------------------------------------
/core/owl.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import "github.com/ggicci/owl"
4 |
5 | func init() {
6 | owl.UseTag("in")
7 | }
8 |
--------------------------------------------------------------------------------
/core/patch_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "io"
5 | "net/http"
6 | "net/url"
7 | "path/filepath"
8 | "strings"
9 | "testing"
10 |
11 | "github.com/ggicci/httpin/patch"
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
15 | type AccountPatch struct {
16 | Email patch.Field[string] `in:"form=email"`
17 | Age patch.Field[int] `in:"form=age"`
18 | Avatar patch.Field[*File] `in:"form=avatar"`
19 | Hobbies patch.Field[[]string] `in:"form=hobbies"`
20 | Pictures patch.Field[[]*File] `in:"form=pictures"`
21 | }
22 |
23 | func TestPatchField(t *testing.T) {
24 | fileContent := []byte("hello")
25 | r := newMultipartFormRequestFromMap(map[string]any{
26 | "age": "18",
27 | "avatar": fileContent,
28 | "hobbies": []string{
29 | "reading",
30 | "swimming",
31 | },
32 | })
33 |
34 | co, err := New(AccountPatch{})
35 | assert.NoError(t, err)
36 | gotValue, err := co.Decode(r)
37 | assert.NoError(t, err)
38 | got := gotValue.(*AccountPatch)
39 |
40 | assert.Equal(t, patch.Field[string]{
41 | Valid: false,
42 | Value: "",
43 | }, got.Email)
44 |
45 | assert.Equal(t, patch.Field[int]{
46 | Valid: true,
47 | Value: 18,
48 | }, got.Age)
49 |
50 | assert.Equal(t, patch.Field[[]string]{
51 | Valid: true,
52 | Value: []string{"reading", "swimming"},
53 | }, got.Hobbies)
54 |
55 | assert.Equal(t, patch.Field[[]*File]{
56 | Valid: false,
57 | Value: nil,
58 | }, got.Pictures)
59 |
60 | assertDecodedFile(t, got.Avatar.Value, "avatar.txt", fileContent)
61 | }
62 |
63 | func TestPatchField_DecodeValueFailed(t *testing.T) {
64 | r, _ := http.NewRequest("GET", "/", nil)
65 | r.Form = url.Values{
66 | "email": {"abc@example.com"},
67 | "age": {"eighteen"},
68 | }
69 | co, err := New(AccountPatch{})
70 | assert.NoError(t, err)
71 | gotValue, err := co.Decode(r)
72 | assert.Error(t, err)
73 | var ferr *InvalidFieldError
74 | assert.ErrorAs(t, err, &ferr)
75 | assert.Equal(t, "Age", ferr.Field)
76 | assert.Equal(t, []string{"eighteen"}, ferr.Value)
77 | assert.Equal(t, "form", ferr.Directive)
78 | assert.Nil(t, gotValue)
79 | }
80 |
81 | func TestPatchField_DecodeFileFailed(t *testing.T) {
82 | body, writer := newMultipartFormWriterFromMap(map[string]any{
83 | "email": "abc@example.com",
84 | "age": "18",
85 | "avatar": []byte("hello"),
86 | })
87 |
88 | // break the boundary to make the file decoder fail
89 | r, _ := http.NewRequest("POST", "/", breakMultipartFormBoundary(body))
90 | r.Header.Set("Content-Type", writer.FormDataContentType())
91 |
92 | co, err := New(AccountPatch{})
93 | assert.NoError(t, err)
94 | gotValue, err := co.Decode(r)
95 | assert.Nil(t, gotValue)
96 | assert.Error(t, err)
97 | }
98 |
99 | func TestPatchField_NewRequest(t *testing.T) {
100 | type ListQuery struct {
101 | Username patch.Field[string] `in:"query=username"`
102 | Age patch.Field[int] `in:"query=age"`
103 | State patch.Field[[]string] `in:"query=state[]"`
104 | }
105 |
106 | co, err := New(ListQuery{})
107 | assert.NoError(t, err)
108 |
109 | testcases := []struct {
110 | Query *ListQuery
111 | Expected url.Values
112 | }{
113 | {&ListQuery{
114 | Username: patch.Field[string]{Value: "ggicci", Valid: true},
115 | Age: patch.Field[int]{Value: 18, Valid: false},
116 | }, url.Values{"username": {"ggicci"}}},
117 | {&ListQuery{
118 | Age: patch.Field[int]{Value: 18, Valid: false},
119 | }, url.Values{}},
120 | {&ListQuery{
121 | Age: patch.Field[int]{Value: 18, Valid: true},
122 | }, url.Values{"age": {"18"}}},
123 | {&ListQuery{
124 | Username: patch.Field[string]{Value: "ggicci", Valid: true},
125 | Age: patch.Field[int]{Value: 18, Valid: true},
126 | State: patch.Field[[]string]{
127 | Value: []string{"reading", "swimming"},
128 | Valid: true,
129 | },
130 | }, url.Values{
131 | "username": {"ggicci"},
132 | "age": {"18"},
133 | "state[]": {"reading", "swimming"},
134 | }},
135 | }
136 |
137 | for _, c := range testcases {
138 | req, err := co.NewRequest("GET", "/list", c.Query)
139 | assert.NoError(t, err)
140 |
141 | expected, _ := http.NewRequest("GET", "/list", nil)
142 | expected.URL.RawQuery = c.Expected.Encode()
143 | assert.Equal(t, expected, req)
144 | }
145 | }
146 |
147 | func TestPatchField_NewRequest_NoFiles(t *testing.T) {
148 | assert := assert.New(t)
149 | payload := &AccountPatch{
150 | Email: patch.Field[string]{Value: "abc@example.com", Valid: true},
151 | Age: patch.Field[int]{Value: 18, Valid: true},
152 | Avatar: patch.Field[*File]{Value: nil, Valid: false},
153 | Hobbies: patch.Field[[]string]{
154 | Value: []string{"reading", "swimming"},
155 | Valid: true,
156 | },
157 | Pictures: patch.Field[[]*File]{Value: nil, Valid: false},
158 | }
159 |
160 | expected, _ := http.NewRequest("POST", "/patchAccount", nil)
161 | expectedForm := url.Values{
162 | "email": {"abc@example.com"},
163 | "age": {"18"},
164 | "hobbies": {"reading", "swimming"},
165 | }
166 | expected.Header.Set("Content-Type", "application/x-www-form-urlencoded")
167 | expected.Body = io.NopCloser(strings.NewReader(expectedForm.Encode()))
168 |
169 | co, err := New(AccountPatch{})
170 | assert.NoError(err)
171 | req, err := co.NewRequest("POST", "/patchAccount", payload)
172 | assert.NoError(err)
173 | assert.Equal(expected, req)
174 | }
175 |
176 | func TestPatchField_NewRequest_WithFiles(t *testing.T) {
177 | assert := assert.New(t)
178 | avatarFile := createTempFile(t, []byte("handsome avatar image"))
179 | pic1Filename := createTempFile(t, []byte("pic1 content"))
180 | pic2Filename := createTempFile(t, []byte("pic2 content"))
181 |
182 | payload := &AccountPatch{
183 | Email: patch.Field[string]{Value: "abc@example.com", Valid: true},
184 | Age: patch.Field[int]{Value: 18, Valid: true},
185 | Avatar: patch.Field[*File]{Value: UploadFile(avatarFile), Valid: true},
186 | Hobbies: patch.Field[[]string]{
187 | Value: []string{"reading", "swimming"},
188 | Valid: true,
189 | },
190 | Pictures: patch.Field[[]*File]{
191 | Value: []*File{
192 | UploadFile(pic1Filename),
193 | UploadFile(pic2Filename),
194 | },
195 | Valid: true,
196 | },
197 | }
198 |
199 | // See TestMultipartFormEncode_UploadFilename for more details.
200 | co, err := New(AccountPatch{})
201 | assert.NoError(err)
202 | req, err := co.NewRequest("POST", "/patchAccount", payload)
203 | assert.NoError(err)
204 |
205 | // Server side: receive files (decode).
206 | gotValue, err := co.Decode(req)
207 | assert.NoError(err)
208 | got, ok := gotValue.(*AccountPatch)
209 | assert.True(ok)
210 | assert.True(got.Email.Valid)
211 | assert.Equal("abc@example.com", got.Email.Value)
212 | assert.True(got.Age.Valid)
213 | assert.Equal(18, got.Age.Value)
214 | assert.True(got.Hobbies.Valid)
215 | assert.Equal([]string{"reading", "swimming"}, got.Hobbies.Value)
216 | assert.True(got.Avatar.Valid)
217 | assertDecodedFile(t, got.Avatar.Value, filepath.Base(avatarFile), []byte("handsome avatar image"))
218 | assert.True(got.Pictures.Valid)
219 | assert.Len(got.Pictures.Value, 2)
220 | assertDecodedFile(t, got.Pictures.Value[0], filepath.Base(pic1Filename), []byte("pic1 content"))
221 | assertDecodedFile(t, got.Pictures.Value[1], filepath.Base(pic2Filename), []byte("pic2 content"))
222 | }
223 |
--------------------------------------------------------------------------------
/core/path.go:
--------------------------------------------------------------------------------
1 | // directive: "path"
2 | // https://ggicci.github.io/httpin/directives/path
3 |
4 | package core
5 |
6 | import "errors"
7 |
8 | type DirectivePath struct {
9 | decode func(*DirectiveRuntime) error
10 | }
11 |
12 | func NewDirectivePath(decodeFunc func(*DirectiveRuntime) error) *DirectivePath {
13 | return &DirectivePath{
14 | decode: decodeFunc,
15 | }
16 | }
17 |
18 | func (dir *DirectivePath) Decode(rtm *DirectiveRuntime) error {
19 | return dir.decode(rtm)
20 | }
21 |
22 | // Encode replaces the placeholders in URL path with the given value.
23 | func (*DirectivePath) Encode(rtm *DirectiveRuntime) error {
24 | encoder := &FormEncoder{
25 | Setter: rtm.GetRequestBuilder().SetPath,
26 | }
27 | return encoder.Execute(rtm)
28 | }
29 |
30 | // defaultPathDirective is the default path directive, which only supports encoding,
31 | // while the decoding function is not implmented. Because the path decoding depends on the
32 | // routing framework, it should be implemented in the integration package.
33 | // See integration/gochi.go and integration/gorilla.go for examples.
34 | var defaultPathDirective = NewDirectivePath(func(rtm *DirectiveRuntime) error {
35 | return errors.New("unimplemented path decoding function")
36 | })
37 |
--------------------------------------------------------------------------------
/core/path_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "bytes"
5 | "encoding/json"
6 | "io"
7 | "net/http"
8 | "testing"
9 |
10 | "github.com/stretchr/testify/assert"
11 | )
12 |
13 | func myCustomPathDecode(rtm *DirectiveRuntime) error {
14 | return assert.AnError
15 | }
16 |
17 | func TestDirectivePath_Decode(t *testing.T) {
18 | pathDirective := NewDirectivePath(myCustomPathDecode)
19 | assert.ErrorIs(t, pathDirective.Decode(nil), assert.AnError)
20 | }
21 |
22 | func TestDirectivePath_ErrDefaultPathDecodingIsUnimplemented(t *testing.T) {
23 | type GetProfileRequest struct {
24 | Username string `in:"path=username"`
25 | }
26 |
27 | co, err := New(GetProfileRequest{})
28 | assert.NoError(t, err)
29 | req, _ := http.NewRequest("GET", "/users/ggicci", nil)
30 | _, err = co.Decode(req)
31 | assert.ErrorContains(t, err, "unimplemented path decoding function")
32 | }
33 |
34 | func TestDirectivePath_NewRequest_DefalutPathEncodingShouldWork(t *testing.T) {
35 | assert := assert.New(t)
36 | type Repository struct {
37 | Name string `json:"name"`
38 | Visibility string `json:"visibility"` // public, private, internal
39 | License string `json:"license"`
40 | }
41 | type CreateRepositoryRequest struct {
42 | Owner string `in:"path=owner"`
43 | Payload *Repository `in:"body=json"`
44 | }
45 |
46 | query := &CreateRepositoryRequest{
47 | Owner: "ggicci",
48 | Payload: &Repository{
49 | Name: "httpin",
50 | Visibility: "public",
51 | License: "MIT",
52 | },
53 | }
54 |
55 | co, err := New(query)
56 | assert.NoError(err)
57 | req, err := co.NewRequest("POST", "/users/{owner}/repos", query)
58 | assert.NoError(err)
59 |
60 | expected, _ := http.NewRequest("POST", "/users/ggicci/repos", nil)
61 | expected.Header.Set("Content-Type", "application/json")
62 | var body bytes.Buffer
63 | assert.NoError(json.NewEncoder(&body).Encode(query.Payload))
64 | expected.Body = io.NopCloser(&body)
65 | assert.Equal(expected, req)
66 | }
67 |
--------------------------------------------------------------------------------
/core/query.go:
--------------------------------------------------------------------------------
1 | // directive: "query"
2 | // https://ggicci.github.io/httpin/directives/query
3 |
4 | package core
5 |
6 | import (
7 | "mime/multipart"
8 | )
9 |
10 | type DirectiveQuery struct{}
11 |
12 | // Decode implements the "query" executor who extracts values from
13 | // the querystring of an HTTP request.
14 | func (*DirectiveQuery) Decode(rtm *DirectiveRuntime) error {
15 | req := rtm.GetRequest()
16 | extractor := &FormExtractor{
17 | Runtime: rtm,
18 | Form: multipart.Form{
19 | Value: req.URL.Query(),
20 | },
21 | }
22 | return extractor.Extract()
23 | }
24 |
25 | func (*DirectiveQuery) Encode(rtm *DirectiveRuntime) error {
26 | encoder := &FormEncoder{
27 | Setter: rtm.GetRequestBuilder().SetQuery,
28 | }
29 | return encoder.Execute(rtm)
30 | }
31 |
--------------------------------------------------------------------------------
/core/query_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "net/url"
7 | "testing"
8 |
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
12 | func TestDirectiveQuery_Decode(t *testing.T) {
13 | type SearchQuery struct {
14 | Query string `in:"query=q;required"`
15 | PageNumber int `in:"query=p"`
16 | PageSize int `in:"query=page_size"`
17 | }
18 |
19 | r, _ := http.NewRequest("GET", "/?q=doggy&p=2&page_size=5", nil)
20 | expected := &SearchQuery{
21 | Query: "doggy",
22 | PageNumber: 2,
23 | PageSize: 5,
24 | }
25 |
26 | co, err := New(SearchQuery{})
27 | assert.NoError(t, err)
28 | got, err := co.Decode(r)
29 | assert.NoError(t, err)
30 | assert.Equal(t, expected, got.(*SearchQuery))
31 | }
32 |
// TestDirectiveQuery_NewRequest checks that NewRequest encodes struct fields
// into the querystring. Fields tagged "omitempty" are dropped when zero/nil;
// fields without it are emitted even when zero/nil.
func TestDirectiveQuery_NewRequest(t *testing.T) {
	type SearchQuery struct {
		Name    string  `in:"query=name"`
		Age     int     `in:"query=age;omitempty"`
		Enabled bool    `in:"query=enabled"`
		Price   float64 `in:"query=price"`

		NameList []string `in:"query=name_list[]"`
		AgeList  []int    `in:"query=age_list[]"`

		NamePointer *string `in:"query=name_pointer"`
		AgePointer  *int    `in:"query=age_pointer;omitempty"`
	}

	t.Run("with all values", func(t *testing.T) {
		query := &SearchQuery{
			Name:     "cupcake",
			Age:      12,
			Enabled:  true,
			Price:    6.28,
			NameList: []string{"apple", "banana", "cherry"},
			AgeList:  []int{1, 2, 3},
			NamePointer: func() *string {
				s := "pointer cupcake"
				return &s
			}(),
			AgePointer: func() *int {
				i := 19
				return &i
			}(),
		}

		co, err := New(SearchQuery{})
		if !assert.NoError(t, err) {
			return
		}
		req, err := co.NewRequest("GET", "/pets", query)
		if !assert.NoError(t, err) {
			return
		}

		expected, _ := http.NewRequest("GET", "/pets", nil)
		expectedQuery := make(url.Values)
		expectedQuery.Set("name", query.Name)                 // query.Name
		expectedQuery.Set("age", "12")                        // query.Age
		expectedQuery.Set("enabled", "true")                  // query.Enabled
		expectedQuery.Set("price", "6.28")                    // query.Price
		expectedQuery["name_list[]"] = query.NameList         // query.NameList
		expectedQuery["age_list[]"] = []string{"1", "2", "3"} // query.AgeList
		expectedQuery.Set("name_pointer", *query.NamePointer) // query.NamePointer
		expectedQuery.Set("age_pointer", "19")                // query.AgePointer
		expected.URL.RawQuery = expectedQuery.Encode()
		assert.Equal(t, expected, req)
	})

	t.Run("with empty values", func(t *testing.T) {
		query := &SearchQuery{}

		co, err := New(SearchQuery{})
		if !assert.NoError(t, err) {
			return
		}
		req, err := co.NewRequest("GET", "/pets", query)
		if !assert.NoError(t, err) {
			return
		}

		values := req.URL.Query()
		// Keys without "omitempty" are present even for zero values / nil
		// pointers; keys with it are omitted.
		assert.True(t, values.Has("name"))
		assert.False(t, values.Has("age"))

		assert.True(t, values.Has("name_pointer"))
		assert.False(t, values.Has("age_pointer"))
	})
}
99 |
100 | type Location struct {
101 | Latitude float64
102 | Longitude float64
103 | }
104 |
105 | func (l Location) ToString() (string, error) {
106 | return fmt.Sprintf("%f,%f", l.Latitude, l.Longitude), nil
107 | }
108 |
109 | type LocationImplementedTextMarshaler Location
110 |
111 | func (l LocationImplementedTextMarshaler) MarshalText() ([]byte, error) {
112 | if s, err := (Location)(l).ToString(); err != nil {
113 | return nil, err
114 | } else {
115 | return []byte("MarshalText:" + s), nil
116 | }
117 | }
// TestDirectiveQuery_NewRequest_ErrUnsupportedType checks that NewRequest
// rejects a query field whose type (a map) has no string coder, failing with
// ErrUnsupportedType.
func TestDirectiveQuery_NewRequest_ErrUnsupportedType(t *testing.T) {
	type SearchQuery struct {
		Map map[string]string `in:"query=map"` // unsupported type: map
	}

	co, newErr := New(SearchQuery{})
	assert.NoError(t, newErr)

	_, reqErr := co.NewRequest("GET", "/pets", &SearchQuery{})
	assert.ErrorIs(t, reqErr, ErrUnsupportedType)
}
128 |
// TestDirectiveQuery_NewRequest_WithTextMarshaler checks querystring encoding of
// two kinds of field: one whose type only has ToString (Location), and one
// whose type only has MarshalText (LocationImplementedTextMarshaler).
// See hybridcoder_test.go for more details.
func TestDirectiveQuery_NewRequest_WithTextMarshaler(t *testing.T) {
	type SearchQuery struct {
		L0     *Location                         `in:"query=l0"`
		L2     *LocationImplementedTextMarshaler `in:"query=l2"`
		Radius int                               `in:"query=radius"`
	}

	input := &SearchQuery{
		L0:     &Location{Latitude: 1.234, Longitude: 5.678},
		L2:     &LocationImplementedTextMarshaler{Latitude: 1.234, Longitude: 5.678},
		Radius: 1000,
	}

	co, err := New(SearchQuery{})
	assert.NoError(t, err)
	got, err := co.NewRequest("GET", "/pets", input)
	assert.NoError(t, err)

	want, _ := http.NewRequest("GET", "/pets", nil)
	want.URL.RawQuery = url.Values{
		"l0":     {"1.234000,5.678000"},
		"l2":     {"MarshalText:1.234000,5.678000"},
		"radius": {"1000"},
	}.Encode()
	assert.Equal(t, want, got)
}
162 |
--------------------------------------------------------------------------------
/core/registry.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "reflect"
5 |
6 | "github.com/ggicci/httpin/internal"
7 | )
8 |
9 | type AnyStringableAdaptor = internal.AnyStringableAdaptor
10 |
11 | var (
12 | fileTypes = make(map[reflect.Type]struct{})
13 | customStringableAdaptors = make(map[reflect.Type]AnyStringableAdaptor)
14 | namedStringableAdaptors = make(map[string]*NamedAnyStringableAdaptor)
15 | )
16 |
17 | // RegisterCoder registers a custom coder for the given type T. When a field of
18 | // type T is encountered, this coder will be used to convert the value to a
19 | // Stringable, which will be used to convert the value from/to string.
20 | //
21 | // NOTE: this function is designed to override the default Stringable adaptors
22 | // that are registered by this package. For example, if you want to override the
23 | // defualt behaviour of converting a bool value from/to string, you can do this:
24 | //
25 | // func init() {
26 | // core.RegisterCoder[bool](func(b *bool) (core.Stringable, error) {
27 | // return (*YesNo)(b), nil
28 | // })
29 | // }
30 | //
31 | // type YesNo bool
32 | //
33 | // func (yn YesNo) String() string {
34 | // if yn {
35 | // return "yes"
36 | // }
37 | // return "no"
38 | // }
39 | //
40 | // func (yn *YesNo) FromString(s string) error {
41 | // switch s {
42 | // case "yes":
43 | // *yn = true
44 | // case "no":
45 | // *yn = false
46 | // default:
47 | // return fmt.Errorf("invalid YesNo value: %q", s)
48 | // }
49 | // return nil
50 | // }
51 | func RegisterCoder[T any](adapt func(*T) (Stringable, error)) {
52 | customStringableAdaptors[internal.TypeOf[T]()] = internal.NewAnyStringableAdaptor[T](adapt)
53 | }
54 |
55 | // RegisterNamedCoder works similar to RegisterCoder, except that it binds the
56 | // coder to a name. This is useful when you only want to override the types in
57 | // a specific struct field. You will be using the "coder" or "decoder" directive
58 | // to specify the name of the coder to use. For example:
59 | //
60 | // type MyStruct struct {
61 | // Bool bool // use default bool coder
62 | // YesNo bool `in:"coder=yesno"` // use YesNo coder
63 | // }
64 | //
65 | // func init() {
66 | // core.RegisterNamedCoder[bool]("yesno", func(b *bool) (core.Stringable, error) {
67 | // return (*YesNo)(b), nil
68 | // })
69 | // }
70 | //
71 | // type YesNo bool
72 | //
73 | // func (yn YesNo) String() string {
74 | // if yn {
75 | // return "yes"
76 | // }
77 | // return "no"
78 | // }
79 | //
80 | // func (yn *YesNo) FromString(s string) error {
81 | // switch s {
82 | // case "yes":
83 | // *yn = true
84 | // case "no":
85 | // *yn = false
86 | // default:
87 | // return fmt.Errorf("invalid YesNo value: %q", s)
88 | // }
89 | // return nil
90 | // }
91 | func RegisterNamedCoder[T any](name string, adapt func(*T) (Stringable, error)) {
92 | namedStringableAdaptors[name] = &NamedAnyStringableAdaptor{
93 | Name: name,
94 | BaseType: internal.TypeOf[T](),
95 | Adapt: internal.NewAnyStringableAdaptor[T](adapt),
96 | }
97 | }
98 |
99 | // RegisterFileCoder registers the given type T as a file type. T must implement
100 | // the Fileable interface. Remember if you don't register the type explicitly,
101 | // it won't be recognized as a file type.
102 | func RegisterFileCoder[T Fileable]() error {
103 | fileTypes[internal.TypeOf[T]()] = struct{}{}
104 | return nil
105 | }
106 |
// NamedAnyStringableAdaptor is a coder registered under a name via
// RegisterNamedCoder.
type NamedAnyStringableAdaptor struct {
	Name     string               // the name passed to RegisterNamedCoder
	BaseType reflect.Type         // the type T the coder was registered for
	Adapt    AnyStringableAdaptor // builds a Stringable from a *T
}
112 |
// isFileType reports whether the base type of typ, as computed by BaseTypeOf,
// was registered via RegisterFileCoder. An error from BaseTypeOf is ignored, and
// whatever base type it returned is looked up anyway.
func isFileType(typ reflect.Type) (registered bool) {
	baseType, _ := BaseTypeOf(typ)
	_, registered = fileTypes[baseType]
	return registered
}
118 |
--------------------------------------------------------------------------------
/core/requestbuilder.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "io"
8 | "mime/multipart"
9 | "net/http"
10 | "net/url"
11 | "path/filepath"
12 | "strings"
13 | )
14 |
// RequestBuilder collects the pieces of an outgoing HTTP request set by the
// directive encoders, and applies them to an *http.Request in Populate.
type RequestBuilder struct {
	Query      url.Values                 // querystring parameters, encoded into req.URL.RawQuery
	Form       url.Values                 // form fields (urlencoded, or multipart when there are attachments)
	Attachment map[string][]FileMarshaler // files to upload; any entry forces a multipart form
	Header     http.Header                // replaces req.Header in Populate when non-nil
	Cookie     []*http.Cookie             // cookies added to the request
	Path       map[string]string          // placeholder: value
	BodyType   string                     // json, xml, etc.
	Body       io.ReadCloser              // request body; only used when BodyType is also set
	ctx        context.Context            // checked before each multipart part; canceling it aborts the upload
}
26 |
// NewRequestBuilder returns a RequestBuilder whose collections are all
// initialized, so the Set* methods can write to them directly. ctx is consulted
// while a multipart body is being streamed.
func NewRequestBuilder(ctx context.Context) *RequestBuilder {
	rb := &RequestBuilder{ctx: ctx}
	rb.Query = make(url.Values)
	rb.Form = make(url.Values)
	rb.Attachment = make(map[string][]FileMarshaler)
	rb.Header = make(http.Header)
	rb.Cookie = make([]*http.Cookie, 0)
	rb.Path = make(map[string]string)
	return rb
}
38 |
// Populate applies everything collected by the builder to req:
//
//   - the querystring replaces req.URL.RawQuery;
//   - form fields become an urlencoded body, or a streamed multipart body when
//     there are attachments;
//   - a body set via SetBody becomes req.Body, and its Content-Type is set when
//     the body type has a known media type;
//   - "{placeholder}" segments of req.URL.Path are substituted;
//   - rb.Header, when non-nil, replaces req.Header;
//   - cookies are appended.
//
// It returns an error when both a form and a body were set, and it propagates
// any error from preparing the multipart body.
//
// NOTE: path values are substituted verbatim (they are not escaped), so a value
// containing "/" introduces extra path segments.
func (rb *RequestBuilder) Populate(req *http.Request) error {
	if err := rb.validate(); err != nil {
		return err
	}

	// Populate the querystring.
	req.URL.RawQuery = rb.Query.Encode()

	// Populate forms.
	if rb.hasForm() {
		if rb.hasAttachment() { // multipart form
			if err := rb.populateMultipartForm(req); err != nil {
				return err
			}
		} else { // urlencoded form
			rb.populateForm(req)
		}
	}

	// Populate body.
	if rb.hasBody() {
		req.Body = rb.Body
		// Body types other than json/xml have no known media type. Setting ""
		// would emit an empty Content-Type and clobber a value set through the
		// header directive, so only set a known one.
		if contentType := rb.bodyContentType(); contentType != "" {
			rb.Header.Set("Content-Type", contentType)
		}
	}

	// Populate path.
	if rb.hasPath() {
		newPath := req.URL.Path
		for key, value := range rb.Path {
			newPath = strings.ReplaceAll(newPath, "{"+key+"}", value)
		}
		req.URL.Path = newPath
		// Path was rewritten in decoded form; drop the stale raw form so the
		// URL is re-escaped from Path.
		req.URL.RawPath = ""
	}

	// Populate the headers.
	if rb.Header != nil {
		req.Header = rb.Header
	}

	// Populate the cookies.
	for _, cookie := range rb.Cookie {
		req.AddCookie(cookie)
	}

	return nil
}
86 |
// SetQuery stores value as the querystring values of key, replacing any previous
// values. The slice is stored as is (not copied).
func (rb *RequestBuilder) SetQuery(key string, value []string) {
	rb.Query[key] = value
}

// SetForm stores value as the form values of key, replacing any previous values.
func (rb *RequestBuilder) SetForm(key string, value []string) {
	rb.Form[key] = value
}

// SetHeader stores value under the canonical form of key (see
// http.CanonicalHeaderKey), replacing any previous values.
func (rb *RequestBuilder) SetHeader(key string, value []string) {
	canonical := http.CanonicalHeaderKey(key)
	rb.Header[canonical] = value
}

// SetPath records the substitution for the "{key}" path placeholder. Only the
// first element of value is used, and an empty value is ignored.
func (rb *RequestBuilder) SetPath(key string, value []string) {
	if len(value) == 0 {
		return
	}
	rb.Path[key] = value[0]
}

// SetBody sets the request body together with its type ("json", "xml", ...).
func (rb *RequestBuilder) SetBody(bodyType string, bodyReader io.ReadCloser) {
	rb.BodyType, rb.Body = bodyType, bodyReader
}

// SetAttachment stores the files to upload under key, replacing any previous
// ones. Any attachment makes Populate send a multipart form.
func (rb *RequestBuilder) SetAttachment(key string, files []FileMarshaler) {
	rb.Attachment[key] = files
}
113 |
// bodyContentType maps the body type set via SetBody to a media type. Only
// "json" and "xml" are known; any other body type yields "".
func (rb *RequestBuilder) bodyContentType() string {
	if rb.BodyType == "json" {
		return "application/json"
	}
	if rb.BodyType == "xml" {
		return "application/xml"
	}
	return ""
}

// validate rejects a builder that carries both form data (fields or
// attachments) and a body, since the request can only have one payload.
func (rb *RequestBuilder) validate() error {
	if !rb.hasForm() || !rb.hasBody() {
		return nil
	}
	return errors.New("cannot use both form and body directive at the same time")
}
130 |
131 | func (rb *RequestBuilder) hasPath() bool {
132 | return len(rb.Path) > 0
133 | }
134 |
135 | func (rb *RequestBuilder) hasForm() bool {
136 | return len(rb.Form) > 0 || rb.hasAttachment()
137 | }
138 |
139 | func (rb *RequestBuilder) hasAttachment() bool {
140 | return len(rb.Attachment) > 0
141 | }
142 |
143 | func (rb *RequestBuilder) hasBody() bool {
144 | return rb.Body != nil && rb.BodyType != ""
145 | }
146 |
// populateForm sends rb.Form as an application/x-www-form-urlencoded body and
// sets the matching Content-Type on rb.Header.
func (rb *RequestBuilder) populateForm(req *http.Request) {
	rb.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	body := strings.NewReader(rb.Form.Encode())
	req.Body = io.NopCloser(body)
}
152 |
// populateMultipartForm streams rb.Form and rb.Attachment to req as a
// multipart/form-data body, and sets the matching Content-Type (including the
// boundary) on rb.Header.
//
// The body is produced lazily by a goroutine that writes into an io.Pipe, so
// file contents are not buffered in memory. The goroutine stops, and the reader
// of req.Body observes an error, when:
//   - rb.ctx is done (checked before each part; the reader gets ctx.Err()),
//   - a file cannot be marshaled or copied,
//   - the consumer closes req.Body before reading everything.
//
// It currently always returns nil; errors surface while req.Body is read.
func (rb *RequestBuilder) populateMultipartForm(req *http.Request) error {
	pr, pw := io.Pipe()
	writer := multipart.NewWriter(pw)
	// Read before the writer goroutine starts, so writer is never used from
	// two goroutines at once.
	contentType := writer.FormDataContentType()

	go func() {
		err := rb.writeMultipartBody(writer)
		if err == nil {
			// Write the closing boundary only when every part succeeded.
			err = writer.Close()
		}
		// CloseWithError(nil) behaves like Close: the reader then sees io.EOF.
		pw.CloseWithError(err)
	}()

	// Expose the pipe's read end directly instead of wrapping it in
	// io.NopCloser. Closing req.Body (as http.Client does once a request is
	// sent or has failed) then makes pending writes fail, so the goroutine
	// above exits instead of blocking forever.
	req.Body = pr
	rb.Header.Set("Content-Type", contentType)
	return nil
}

// writeMultipartBody writes every form field, then every attachment, as a part
// of writer. It checks rb.ctx before each part and returns ctx.Err() once the
// context is done.
func (rb *RequestBuilder) writeMultipartBody(writer *multipart.Writer) error {
	for key, values := range rb.Form {
		for _, value := range values {
			if err := rb.ctx.Err(); err != nil {
				return err
			}
			if err := writer.WriteField(key, value); err != nil {
				return fmt.Errorf("write form field %q: %w", key, err)
			}
		}
	}

	for key, files := range rb.Attachment {
		for i, file := range files {
			if err := rb.ctx.Err(); err != nil {
				return err
			}
			if err := writeMultipartFile(writer, key, i, file); err != nil {
				return err
			}
		}
	}
	return nil
}

// writeMultipartFile writes file as a form-file part named key. index is the
// position of file among the attachments of key, and is used to synthesize a
// filename when the file has none.
func writeMultipartFile(writer *multipart.Writer, key string, index int, file FileMarshaler) error {
	filename := file.Filename()
	contentReader, err := file.MarshalFile()
	filename = normalizeUploadFilename(key, filename, index)
	if err != nil {
		return fmt.Errorf("upload %s %q failed: %w", key, filename, err)
	}
	// Release the reader once it has been consumed (or abandoned on error).
	// NOTE(review): this assumes the builder owns readers returned by
	// MarshalFile; confirm against the FileMarshaler implementations.
	if closer, ok := contentReader.(io.Closer); ok {
		defer closer.Close()
	}

	fileWriter, err := writer.CreateFormFile(key, filename)
	if err != nil {
		return fmt.Errorf("upload %s %q failed: %w", key, filename, err)
	}
	if _, err := io.Copy(fileWriter, contentReader); err != nil {
		return fmt.Errorf("upload %s %q failed: %w", key, filename, err)
	}
	return nil
}
209 |
// normalizeUploadFilename returns the filename to put in a multipart file part.
// This is the last element of filename (filepath.Base), or "<key>_<index>" when
// filename is empty.
func normalizeUploadFilename(key, filename string, index int) string {
	if filename != "" {
		return filepath.Base(filename)
	}
	return fmt.Sprintf("%s_%d", key, index)
}
216 |
--------------------------------------------------------------------------------
/core/requestbuilder_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "context"
5 | "io"
6 | "net/http"
7 | "strings"
8 | "testing"
9 | "time"
10 |
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
// FIX: https://github.com/ggicci/httpin/issues/88
// Impossible to make streaming
//
// TestIssue88_RequestBuilderFileUploadStreaming checks that an attachment given
// as a stream is sent as a multipart file part. The stream has no filename, so
// the part gets the synthesized name "<key>_<index>".
func TestIssue88_RequestBuilderFileUploadStreaming(t *testing.T) {
	rb := NewRequestBuilder(context.Background())

	var contentReader = strings.NewReader("hello world")
	rb.SetAttachment("file", []FileMarshaler{
		UploadStream(io.NopCloser(contentReader)),
	})
	req, _ := http.NewRequest("GET", "/", nil)
	if !assert.NoError(t, rb.Populate(req)) {
		return
	}

	err := req.ParseMultipartForm(32 << 20)
	assert.NoError(t, err)

	file, fh, err := req.FormFile("file")
	if !assert.NoError(t, err) {
		return // fh and file are nil; dereferencing them would panic
	}
	assert.Equal(t, "file_0", fh.Filename)
	content, err := io.ReadAll(file)
	assert.NoError(t, err)
	assert.Equal(t, "hello world", string(content))
}
36 |
// TestIssue88_CancelStreaming checks that canceling the builder's context
// aborts the multipart stream, and that the reader observes the context error.
//
// NOTE: the cancellation must be seen by the writer goroutine before it starts
// writing the part, so this test is timing-sensitive; keep cancel() right after
// Populate.
func TestIssue88_CancelStreaming(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())

	rb := NewRequestBuilder(ctx)
	var contentReader = strings.NewReader("hello world")
	rb.SetAttachment("file", []FileMarshaler{
		UploadStream(io.NopCloser(contentReader)),
	})
	req, _ := http.NewRequest("GET", "/", nil)
	populateErr := rb.Populate(req)

	cancel()
	assert.NoError(t, populateErr)
	time.Sleep(time.Millisecond * 100)

	err := req.ParseMultipartForm(32 << 20)
	assert.ErrorContains(t, err, "context canceled")
}
54 |
--------------------------------------------------------------------------------
/core/required.go:
--------------------------------------------------------------------------------
1 | // directive: "required"
2 | // https://ggicci.github.io/httpin/directives/required
3 |
4 | package core
5 |
6 | import "errors"
7 |
// DirectiveRequired implements the "required" directive, which indicates that
// the field must be set. If no preceding directive has set the field, both
// Decode and Encode return a "missing required field" error.
//
// NOTE: the "required" directive does not check the value of the field; it only
// checks whether the field is set. In practice it requires the key to be present
// in the input data (e.g. form, header), but it allows the value to be empty.
type DirectiveRequired struct{}
16 |
// Decode fails with "missing required field" unless a preceding directive has
// already set the field.
func (*DirectiveRequired) Decode(rtm *DirectiveRuntime) error {
	if !rtm.IsFieldSet() {
		return errors.New("missing required field")
	}
	return nil
}

// Encode applies the same check as Decode when building a request.
func (*DirectiveRequired) Encode(rtm *DirectiveRuntime) error {
	if !rtm.IsFieldSet() {
		return errors.New("missing required field")
	}
	return nil
}
30 |
--------------------------------------------------------------------------------
/core/required_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "io"
5 | "net/http"
6 | "net/url"
7 | "strings"
8 | "testing"
9 | "time"
10 |
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
// RequiredQuery is a fixture with one required form field (created_at) and one
// optional field whose tag lists two form keys (colour, color).
type RequiredQuery struct {
	CreatedAt time.Time `in:"form=created_at;required"`
	Color     string    `in:"form=colour,color"`
}
18 |
// TestDirectiveRequired_Decode_RequiredFieldMissing checks that decoding fails
// with an *InvalidFieldError attributed to the "required" directive of field
// CreatedAt when its form key is absent.
func TestDirectiveRequired_Decode_RequiredFieldMissing(t *testing.T) {
	r, _ := http.NewRequest("GET", "/", nil)
	r.Form = url.Values{
		"color": {"red"},
	}
	co, err := New(&RequiredQuery{})
	if !assert.NoError(t, err) {
		return
	}
	_, err = co.Decode(r)
	assert.ErrorContains(t, err, "missing required field")

	var invalidField *InvalidFieldError
	// invalidField stays nil when ErrorAs fails; don't dereference it then.
	if !assert.ErrorAs(t, err, &invalidField) {
		return
	}
	assert.Equal(t, "CreatedAt", invalidField.Field)
	assert.Equal(t, "required", invalidField.Directive)
	assert.Empty(t, invalidField.Key)
	assert.Nil(t, invalidField.Value)
}
35 |
// TestDirectiveRequired_Decode_NonRequiredFieldAbsent checks that a missing
// optional key leaves its field at the zero value, and that form keys with no
// matching field are ignored.
func TestDirectiveRequired_Decode_NonRequiredFieldAbsent(t *testing.T) {
	co, err := New(RequiredQuery{})
	assert.NoError(t, err)

	req, _ := http.NewRequest("GET", "/", nil)
	req.Form = url.Values{
		"created_at": {"1991-11-10T08:00:00+08:00"},
		"is_soldout": {"true"},
		"page":       {"1"},
		"per_page":   {"20"},
	}

	got, err := co.Decode(req)
	assert.NoError(t, err)
	assert.Equal(t, &RequiredQuery{
		CreatedAt: time.Date(1991, 11, 10, 0, 0, 0, 0, time.UTC),
		Color:     "",
	}, got.(*RequiredQuery))
}
54 |
// TestDirectiveRequired_NewRequest_RequiredFieldPresent checks that NewRequest
// encodes the fields as an urlencoded form body. For a tag listing several
// keys, the first key is used.
func TestDirectiveRequired_NewRequest_RequiredFieldPresent(t *testing.T) {
	co, err := New(&RequiredQuery{})
	assert.NoError(t, err)

	payload := &RequiredQuery{
		CreatedAt: time.Date(1991, 11, 10, 0, 0, 0, 0, time.UTC),
		Color:     "red",
	}
	req, err := co.NewRequest("GET", "/hello", payload)
	assert.NoError(t, err)

	form := url.Values{}
	form.Set("created_at", "1991-11-10T00:00:00Z")
	form.Set("colour", "red") // NOTE: will use the first name in the tag
	expected, _ := http.NewRequest("GET", "/hello", nil)
	expected.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	expected.Body = io.NopCloser(strings.NewReader(form.Encode()))

	assert.Equal(t, expected, req)
}
74 |
--------------------------------------------------------------------------------
/core/stringable.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "reflect"
7 | "time"
8 |
9 | "github.com/ggicci/httpin/internal"
10 | "github.com/ggicci/owl"
11 | )
12 |
// Stringable is an alias of internal.Stringable: a value that converts to and
// from a string via ToString and FromString.
type Stringable = internal.Stringable

// ErrUnsupportedType is owl.ErrUnsupportedType re-exported; match it with
// errors.Is.
var ErrUnsupportedType = owl.ErrUnsupportedType
16 |
// NewStringable creates a Stringable for rv. A patch.Field[T] value is wrapped
// in a StringablePatchFieldWrapper; any other value goes through
// newStringable. adapt, when non-nil, is used as the coder of the underlying
// value. On failure the returned Stringable is a plain nil interface, never a
// typed nil pointer.
func NewStringable(rv reflect.Value, adapt AnyStringableAdaptor) (Stringable, error) {
	if !IsPatchField(rv.Type()) {
		s, err := newStringable(rv, adapt)
		if err != nil {
			return nil, err
		}
		return s, nil
	}

	wrapper, err := NewStringablePatchFieldWrapper(rv, adapt)
	if err != nil {
		return nil, err
	}
	return wrapper, nil
}
28 |
// newStringable creates a Stringable for rv. It always works on a pointer to
// the value. A pointer rv is used as is, after allocating the element if the
// pointer is nil. For any other rv its address is taken, which fails when rv is
// not addressable.
//
// The coder is chosen in this order:
//  1. adapt, when non-nil;
//  2. an adaptor registered for the element type via RegisterCoder;
//  3. the built-in coder for time.Time (TextMarshaler is deliberately not used
//     for it);
//  4. a hybrid coder, when hybridizeCoder returns one;
//  5. the built-in coders of the internal package.
func newStringable(rv reflect.Value, adapt AnyStringableAdaptor) (Stringable, error) {
	ptr, err := getPointer(rv)
	if err != nil {
		return nil, err
	}

	if adapt != nil {
		return adapt(ptr.Interface())
	}

	elemType := ptr.Type().Elem()

	// A coder registered by the user overrides every default.
	if registered, ok := customStringableAdaptors[elemType]; ok {
		return registered(ptr.Interface())
	}

	if elemType == timeType {
		return internal.NewStringable(ptr)
	}

	if hybrid := hybridizeCoder(ptr); hybrid != nil {
		return hybrid, nil
	}

	return internal.NewStringable(ptr)
}
63 |
// StringablePatchFieldWrapper adapts a patch.Field[T] to Stringable. It
// delegates to a Stringable built for the field's Value, and reads or sets the
// field's Valid flag.
type StringablePatchFieldWrapper struct {
	Value              reflect.Value // of patch.Field[T]
	internalStringable Stringable    // built for the Value field of the patch.Field
}
68 |
// NewStringablePatchFieldWrapper creates a StringablePatchFieldWrapper for rv.
// rv must be a patch.Field[T] value; its "Value" field is what gets converted.
// adapt is forwarded to the Stringable built for that field.
//
// On error it returns a nil wrapper rather than a half-initialized one, whose
// nil internal Stringable would panic on first use.
func NewStringablePatchFieldWrapper(rv reflect.Value, adapt AnyStringableAdaptor) (*StringablePatchFieldWrapper, error) {
	stringable, err := NewStringable(rv.FieldByName("Value"), adapt)
	if err != nil {
		return nil, fmt.Errorf("cannot create Stringable for PatchField: %w", err)
	}
	return &StringablePatchFieldWrapper{
		Value:              rv,
		internalStringable: stringable,
	}, nil
}
80 |
// ToString returns the string form of the wrapped Value. It fails with an
// "invalid value" error, and returns "", when the patch.Field's Valid flag is
// false.
func (w *StringablePatchFieldWrapper) ToString() (string, error) {
	if !w.Value.FieldByName("Valid").Bool() {
		return "", errors.New("invalid value") // when Valid is false
	}
	return w.internalStringable.ToString()
}
88 |
89 | // FromString sets the value of the wrapped patch.Field[T] from the given
90 | // string. It returns an error if the given string is not valid. And leaves the
91 | // original value of both Value and Valid unchanged. On the other hand, if no
92 | // error occurs, it sets Valid to true.
93 | func (w *StringablePatchFieldWrapper) FromString(s string) error {
94 | if err := w.internalStringable.FromString(s); err != nil {
95 | return err
96 | } else {
97 | w.Value.FieldByName("Valid").SetBool(true)
98 | return nil
99 | }
100 | }
101 |
// timeType is the reflect.Type of time.Time, which newStringable special-cases.
var timeType = internal.TypeOf[time.Time]()
103 |
// getPointer returns a pointer Value through which rv can be read and written.
// A pointer rv is returned as is, after a new element is allocated if it is
// nil. Any other rv must be addressable, and its address is returned.
func getPointer(rv reflect.Value) (reflect.Value, error) {
	if rv.Kind() != reflect.Pointer {
		return addressOf(rv)
	}
	return createInstanceIfNil(rv), nil
}

// createInstanceIfNil stores a newly allocated element in the pointer rv when
// it is nil, then returns rv. When rv is nil it must be settable, otherwise
// reflect panics.
func createInstanceIfNil(rv reflect.Value) reflect.Value {
	if !rv.IsNil() {
		return rv
	}
	rv.Set(reflect.New(rv.Type().Elem()))
	return rv
}

// addressOf returns rv.Addr(). When rv is not addressable it returns rv itself
// together with an error.
func addressOf(rv reflect.Value) (reflect.Value, error) {
	if rv.CanAddr() {
		return rv.Addr(), nil
	}
	return rv, fmt.Errorf("cannot get address of value %v", rv)
}
126 |
--------------------------------------------------------------------------------
/core/stringable_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 | "testing"
7 | "time"
8 |
9 | "github.com/ggicci/httpin/internal"
10 | "github.com/ggicci/httpin/patch"
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
14 | type Point2D struct {
15 | X int
16 | Y int
17 | }
18 |
19 | func (p Point2D) ToString() (string, error) {
20 | return fmt.Sprintf("Point2D(%d,%d)", p.X, p.Y), nil
21 | }
22 |
23 | func (p *Point2D) FromString(s string) error {
24 | _, err := fmt.Sscanf(s, "Point2D(%d,%d)", &p.X, &p.Y)
25 | return err
26 | }
27 |
// MyStruct is a fixture that covers, for each of string, int and Point2D: a
// plain value, a pointer, a slice, a patch.Field of the value, and a
// patch.Field of a slice.
type MyStruct struct {
	Name        string
	NamePointer *string
	Names       []string
	PatchName   patch.Field[string]
	PatchNames  patch.Field[[]string]

	Age        int
	AgePointer *int
	Ages       []int
	PatchAge   patch.Field[int]
	PatchAges  patch.Field[[]int]

	Dot        Point2D
	DotPointer *Point2D
	Dots       []Point2D
	PatchDot   patch.Field[Point2D]
	PatchDots  patch.Field[[]Point2D]
}
47 |
48 | type MyDate time.Time // adapted from time.Time
49 |
50 | func (t MyDate) ToString() (string, error) {
51 | return time.Time(t).Format("2006-01-02"), nil
52 | }
53 |
54 | func (t *MyDate) FromString(value string) error {
55 | v, err := time.Parse("2006-01-02", value)
56 | if err != nil {
57 | return &InvalidDate{Value: value, Err: err}
58 | }
59 | *t = MyDate(v)
60 | return nil
61 | }
62 |
// TestStringable_FromString checks FromString through NewStringable on plain,
// pointer, slice and patch.Field fields of string, int and Point2D type.
// Slice-typed fields, and patch.Field of a slice, must be rejected with
// ErrUnsupportedType and left untouched. Slices are handled by
// NewStringSlicable instead.
func TestStringable_FromString(t *testing.T) {
	// rv is an addressable zero MyStruct, and s points at the same memory, so
	// writes made through rv can be checked directly. Pointer fields start nil
	// and are allocated when a Stringable is built for them.
	rv := reflect.New(reflect.TypeOf(MyStruct{})).Elem()
	s := rv.Addr().Interface().(*MyStruct)

	// string
	testAssignString(t, rv.FieldByName("Name"), "Alice")
	testAssignString(t, rv.FieldByName("NamePointer"), "Charlie")
	testNewStringableErrUnsupported(t, rv.FieldByName("Names"))
	testAssignString(t, rv.FieldByName("PatchName"), "Bob")
	testNewStringableErrUnsupported(t, rv.FieldByName("PatchNames"))

	assert.Equal(t, "Alice", s.Name)
	assert.Equal(t, "Charlie", *s.NamePointer)
	assert.Equal(t, []string(nil), s.Names)
	assert.Equal(t, "Bob", s.PatchName.Value)
	assert.True(t, s.PatchName.Valid)

	// int
	testAssignString(t, rv.FieldByName("Age"), "18")
	testAssignString(t, rv.FieldByName("AgePointer"), "20")
	testNewStringableErrUnsupported(t, rv.FieldByName("Ages"))
	testAssignString(t, rv.FieldByName("PatchAge"), "18")
	testNewStringableErrUnsupported(t, rv.FieldByName("PatchAges"))

	assert.Equal(t, 18, s.Age)
	assert.Equal(t, 20, *s.AgePointer)
	assert.Equal(t, []int(nil), s.Ages)
	assert.Equal(t, 18, s.PatchAge.Value)
	assert.True(t, s.PatchAge.Valid)

	// Point2D
	testAssignString(t, rv.FieldByName("Dot"), "Point2D(1,2)")
	testAssignString(t, rv.FieldByName("DotPointer"), "Point2D(3,4)")
	testNewStringableErrUnsupported(t, rv.FieldByName("Dots"))
	testAssignString(t, rv.FieldByName("PatchDot"), "Point2D(5,6)")
	testNewStringableErrUnsupported(t, rv.FieldByName("PatchDots"))

	assert.Equal(t, Point2D{1, 2}, s.Dot)
	assert.Equal(t, &Point2D{3, 4}, s.DotPointer)
	assert.Equal(t, []Point2D(nil), s.Dots)
	assert.Equal(t, Point2D{5, 6}, s.PatchDot.Value)
	assert.True(t, s.PatchDot.Valid)
}
106 |
// TestStringable_String checks ToString through NewStringable on a fully
// populated MyStruct. Scalar, pointer and patch.Field fields render their
// value; slice-typed fields (and patch.Field of a slice) are rejected with
// ErrUnsupportedType.
func TestStringable_String(t *testing.T) {
	var s = &MyStruct{
		Name:        "Alice",
		NamePointer: internal.Pointerize[string]("Charlie"),
		Names:       []string{"Alice", "Bob", "Charlie"},
		PatchName:   patch.Field[string]{Value: "Bob", Valid: true},
		PatchNames:  patch.Field[[]string]{Value: []string{"Alice", "Bob", "Charlie"}, Valid: true},

		Age:        18,
		AgePointer: internal.Pointerize[int](20),
		Ages:       []int{18, 20},
		PatchAge:   patch.Field[int]{Value: 18, Valid: true},
		PatchAges:  patch.Field[[]int]{Value: []int{18, 20}, Valid: true},

		Dot:        Point2D{1, 2},
		DotPointer: internal.Pointerize[Point2D](Point2D{3, 4}),
		Dots:       []Point2D{{1, 2}, {3, 4}},
		PatchDot:   patch.Field[Point2D]{Value: Point2D{5, 6}, Valid: true},
		PatchDots:  patch.Field[[]Point2D]{Value: []Point2D{{1, 2}, {3, 4}}, Valid: true},
	}

	// Elem() of a pointer yields addressable fields, which newStringable needs.
	rv := reflect.ValueOf(s).Elem()

	assert.Equal(t, "Alice", testGetString(t, rv.FieldByName("Name")))
	assert.Equal(t, "Charlie", testGetString(t, rv.FieldByName("NamePointer")))
	testNewStringableErrUnsupported(t, rv.FieldByName("Names"))
	assert.Equal(t, "Bob", testGetString(t, rv.FieldByName("PatchName")))
	testNewStringableErrUnsupported(t, rv.FieldByName("PatchNames"))

	assert.Equal(t, "18", testGetString(t, rv.FieldByName("Age")))
	assert.Equal(t, "20", testGetString(t, rv.FieldByName("AgePointer")))
	testNewStringableErrUnsupported(t, rv.FieldByName("Ages"))
	assert.Equal(t, "18", testGetString(t, rv.FieldByName("PatchAge")))
	testNewStringableErrUnsupported(t, rv.FieldByName("PatchAges"))

	assert.Equal(t, "Point2D(1,2)", testGetString(t, rv.FieldByName("Dot")))
	assert.Equal(t, "Point2D(3,4)", testGetString(t, rv.FieldByName("DotPointer")))
	testNewStringableErrUnsupported(t, rv.FieldByName("Dots"))
	assert.Equal(t, "Point2D(5,6)", testGetString(t, rv.FieldByName("PatchDot")))
	testNewStringableErrUnsupported(t, rv.FieldByName("PatchDots"))
}
148 |
149 | func TestStringablePatchFieldWrapper_String(t *testing.T) {
150 | var patchString = patch.Field[string]{Value: "Alice", Valid: true}
151 | rv := reflect.ValueOf(&patchString).Elem()
152 | assert.True(t, IsPatchField(rv.Type()))
153 | stringable, err := NewStringablePatchFieldWrapper(rv, nil)
154 | assert.NoError(t, err)
155 |
156 | sv, err := stringable.ToString()
157 | assert.NoError(t, err)
158 | assert.Equal(t, "Alice", sv)
159 |
160 | patchString.Valid = false // make it invalid
161 | sv, err = stringable.ToString()
162 | assert.ErrorContains(t, err, "invalid value")
163 | assert.Empty(t, sv, "invalid patch field should return empty string")
164 | }
165 |
// TestStringablePatchFieldWrapper_FromString checks that FromString sets both
// Value and Valid on success, and leaves them untouched when parsing fails.
func TestStringablePatchFieldWrapper_FromString(t *testing.T) {
	// string
	var patchString = patch.Field[string]{}

	assert.Empty(t, patchString.Value)
	assert.False(t, patchString.Valid)

	rv := reflect.ValueOf(&patchString).Elem()
	assert.True(t, IsPatchField(rv.Type()))
	stringable, err := NewStringablePatchFieldWrapper(rv, nil)
	if !assert.NoError(t, err) {
		return
	}
	assert.NoError(t, stringable.FromString("Alice"))
	assert.Equal(t, "Alice", patchString.Value)
	assert.True(t, patchString.Valid, "Valid should be set to true after a successful FromString")

	// int
	var patchInt = patch.Field[int]{}
	rv = reflect.ValueOf(&patchInt).Elem()
	assert.True(t, IsPatchField(rv.Type()))
	stringable, err = NewStringablePatchFieldWrapper(rv, nil)
	if !assert.NoError(t, err) {
		return
	}
	assert.Error(t, stringable.FromString("Alice")) // cannot convert "Alice" to int
	assert.Zero(t, patchInt.Value, "Value should not be changed after a failed FromString")
	assert.False(t, patchInt.Valid, "Valid should not be changed after a failed FromString")

	assert.NoError(t, stringable.FromString("18"))
	assert.Equal(t, 18, patchInt.Value)
	assert.True(t, patchInt.Valid, "Valid should be set to true after a successful FromString")

	assert.Error(t, stringable.FromString("18.0")) // cannot convert "18.0" to int
	assert.Equal(t, 18, patchInt.Value, "Value should not be changed after a failed FromString")
	assert.True(t, patchInt.Valid, "Valid should not be changed after a failed FromString")
}
199 |
200 | func TestStringable_WithAdaptor(t *testing.T) {
201 | adapt := func(t *time.Time) (Stringable, error) { return (*MyDate)(t), nil }
202 | var now = time.Now()
203 | rvTimePointer := reflect.ValueOf(&now)
204 |
205 | coder, err := NewStringable(rvTimePointer, internal.NewAnyStringableAdaptor[time.Time](adapt))
206 | assert.NoError(t, err)
207 | assert.NoError(t, coder.FromString("1991-11-10"))
208 |
209 | s, err := coder.ToString()
210 | assert.NoError(t, err)
211 | assert.Equal(t, "1991-11-10", s)
212 |
213 | assert.ErrorContains(t, coder.FromString("1991-11-10T08:00:00+08:00"), "parsing time")
214 | }
215 |
216 | type InvalidDate struct {
217 | Value string
218 | Err error
219 | }
220 |
221 | func (e *InvalidDate) Error() string {
222 | return fmt.Sprintf("invalid date: %q (date must conform to format \"2006-01-02\"), %s", e.Value, e.Err)
223 | }
224 |
225 | func (e *InvalidDate) Unwrap() error {
226 | return e.Err
227 | }
228 |
// testAssignString builds a Stringable for rv (without an adaptor) and feeds
// value to its FromString, failing the test on any error.
func testAssignString(t *testing.T, rv reflect.Value, value string) {
	t.Helper()
	s, err := NewStringable(rv, nil)
	if !assert.NoError(t, err) {
		return // s is nil; calling FromString would panic
	}
	assert.NoError(t, s.FromString(value))
}

// testNewStringableErrUnsupported asserts that no Stringable can be built for
// rv: the error matches ErrUnsupportedType and the Stringable is nil.
func testNewStringableErrUnsupported(t *testing.T, rv reflect.Value) {
	t.Helper()
	s, err := NewStringable(rv, nil)
	assert.ErrorIs(t, err, ErrUnsupportedType)
	assert.Nil(t, s)
}

// testGetString builds a Stringable for rv (without an adaptor) and returns its
// ToString result, failing the test on any error. It returns "" when no
// Stringable could be built.
func testGetString(t *testing.T, rv reflect.Value) string {
	t.Helper()
	coder, err := NewStringable(rv, nil)
	if !assert.NoError(t, err) {
		return "" // coder is nil; calling ToString would panic
	}
	s, err := coder.ToString()
	assert.NoError(t, err)
	return s
}
248 |
--------------------------------------------------------------------------------
/core/stringslicable.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "reflect"
7 |
8 | "github.com/ggicci/httpin/internal"
9 | )
10 |
// StringSlicable is implemented by values that can be converted to and from
// a list of strings (e.g. the multiple values of a query parameter).
type StringSlicable interface {
	// FromStringSlice assigns the value from the given strings.
	FromStringSlice([]string) error
	// ToStringSlice renders the value as a list of strings.
	ToStringSlice() ([]string, error)
}
15 |
16 | func NewStringSlicable(rv reflect.Value, adapt AnyStringableAdaptor) (StringSlicable, error) {
17 | if rv.Type().Implements(stringSliceableType) && rv.CanInterface() {
18 | return rv.Interface().(StringSlicable), nil
19 | }
20 |
21 | if IsPatchField(rv.Type()) {
22 | return NewStringSlicablePatchFieldWrapper(rv, adapt)
23 | }
24 |
25 | if isSliceType(rv.Type()) && !isByteSliceType(rv.Type()) {
26 | return NewStringableSliceWrapper(rv, adapt)
27 | } else {
28 | return NewStringSlicableSingleStringableWrapper(rv, adapt)
29 | }
30 | }
31 |
32 | // StringSlicablePatchFieldWrapper wraps a patch.Field[T] to implement
33 | // StringSlicable. The wrapped reflect.Value must be a patch.Field[T].
34 | //
35 | // It works like a proxy. It delegates the ToStringSlice and FromStringSlice
36 | // calls to the internal StringSlicable.
37 | type StringSlicablePatchFieldWrapper struct {
38 | Value reflect.Value // of patch.Field[T]
39 | internalStringSliceable StringSlicable
40 | }
41 |
42 | // NewStringSlicablePatchFieldWrapper creates a StringSlicablePatchFieldWrapper from rv.
43 | // Returns error when patch.Field.Value is not a StringSlicable.
44 | func NewStringSlicablePatchFieldWrapper(rv reflect.Value, adapt AnyStringableAdaptor) (*StringSlicablePatchFieldWrapper, error) {
45 | stringSlicable, err := NewStringSlicable(rv.FieldByName("Value"), adapt)
46 | if err != nil {
47 | return nil, err
48 | } else {
49 | return &StringSlicablePatchFieldWrapper{
50 | Value: rv,
51 | internalStringSliceable: stringSlicable,
52 | }, nil
53 | }
54 | }
55 |
56 | func (w *StringSlicablePatchFieldWrapper) ToStringSlice() ([]string, error) {
57 | if w.Value.FieldByName("Valid").Bool() {
58 | return w.internalStringSliceable.ToStringSlice()
59 | } else {
60 | return []string{}, nil
61 | }
62 | }
63 |
64 | func (w *StringSlicablePatchFieldWrapper) FromStringSlice(values []string) error {
65 | if err := w.internalStringSliceable.FromStringSlice(values); err != nil {
66 | return err
67 | } else {
68 | w.Value.FieldByName("Valid").SetBool(true)
69 | return nil
70 | }
71 | }
72 |
73 | type StringableSlice []Stringable
74 |
75 | func (sa StringableSlice) ToStringSlice() ([]string, error) {
76 | values := make([]string, len(sa))
77 | for i, s := range sa {
78 | if value, err := s.ToString(); err != nil {
79 | return nil, fmt.Errorf("cannot stringify %v at index %d: %w", s, i, err)
80 | } else {
81 | values[i] = value
82 | }
83 | }
84 | return values, nil
85 | }
86 |
87 | func (sa StringableSlice) FromStringSlice(values []string) error {
88 | for i, s := range values {
89 | if err := sa[i].FromString(s); err != nil {
90 | return fmt.Errorf("cannot convert from string %q at index %d: %w", s, i, err)
91 | }
92 | }
93 | return nil
94 | }
95 |
96 | // StringableSliceWrapper wraps a reflect.Value to implement StringSlicable. The
97 | // wrapped reflect.Value must be a slice of Stringable.
98 | type StringableSliceWrapper struct {
99 | Value reflect.Value
100 | Adapt AnyStringableAdaptor
101 | }
102 |
103 | // NewStringableSliceWrapper creates a StringableSliceWrapper from rv.
104 | // Returns error when rv is not a slice of Stringable or cannot get address of rv.
105 | func NewStringableSliceWrapper(rv reflect.Value, adapt AnyStringableAdaptor) (*StringableSliceWrapper, error) {
106 | if !rv.CanAddr() {
107 | return nil, errors.New("unaddressable value")
108 | }
109 | return &StringableSliceWrapper{Value: rv, Adapt: adapt}, nil
110 | }
111 |
112 | func (w *StringableSliceWrapper) ToStringSlice() ([]string, error) {
113 | var stringables = make(StringableSlice, w.Value.Len())
114 | for i := 0; i < w.Value.Len(); i++ {
115 | if stringable, err := NewStringable(w.Value.Index(i), w.Adapt); err != nil {
116 | return nil, fmt.Errorf("cannot create Stringable from %q at index %d: %w", w.Value.Index(i), i, err)
117 | } else {
118 | stringables[i] = stringable
119 | }
120 | }
121 | return stringables.ToStringSlice()
122 | }
123 |
124 | func (w *StringableSliceWrapper) FromStringSlice(ss []string) error {
125 | var stringables = make(StringableSlice, len(ss))
126 | w.Value.Set(reflect.MakeSlice(w.Value.Type(), len(ss), len(ss)))
127 | for i := range ss {
128 | if stringable, err := NewStringable(w.Value.Index(i), w.Adapt); err != nil {
129 | return fmt.Errorf("cannot create Stringable at index %d: %w", i, err)
130 | } else {
131 | stringables[i] = stringable
132 | }
133 | }
134 | return stringables.FromStringSlice(ss)
135 | }
136 |
137 | // StringSlicableSingleStringableWrapper wraps a reflect.Value to implement
138 | // StringSlicable. The wrapped reflect.Value must be a Stringable.
139 | type StringSlicableSingleStringableWrapper struct{ Stringable }
140 |
141 | func NewStringSlicableSingleStringableWrapper(rv reflect.Value, adapt AnyStringableAdaptor) (*StringSlicableSingleStringableWrapper, error) {
142 | if stringable, err := NewStringable(rv, adapt); err != nil {
143 | return nil, err
144 | } else {
145 | return &StringSlicableSingleStringableWrapper{stringable}, nil
146 | }
147 | }
148 |
149 | func (w *StringSlicableSingleStringableWrapper) ToStringSlice() ([]string, error) {
150 | if value, err := w.ToString(); err != nil {
151 | return nil, err
152 | } else {
153 | return []string{value}, nil
154 | }
155 | }
156 |
157 | func (w *StringSlicableSingleStringableWrapper) FromStringSlice(values []string) error {
158 | if len(values) > 0 {
159 | return w.FromString(values[0])
160 | }
161 | return nil
162 | }
163 |
164 | var (
165 | stringSliceableType = internal.TypeOf[StringSlicable]()
166 | byteType = internal.TypeOf[byte]()
167 | )
168 |
// isSliceType reports whether t is a slice or an array type.
func isSliceType(t reflect.Type) bool {
	switch t.Kind() {
	case reflect.Slice, reflect.Array:
		return true
	default:
		return false
	}
}
172 |
173 | func isByteSliceType(t reflect.Type) bool {
174 | if isSliceType(t) && t.Elem() == byteType {
175 | return true
176 | }
177 | return false
178 | }
179 |
--------------------------------------------------------------------------------
/core/stringslicable_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "reflect"
5 | "testing"
6 |
7 | "github.com/ggicci/httpin/internal"
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestStringSlicable_FromStringSlice(t *testing.T) {
12 | rv := reflect.New(reflect.TypeOf(MyStruct{})).Elem()
13 | s := rv.Addr().Interface().(*MyStruct)
14 |
15 | testAssignStringSlice(t, rv.FieldByName("Name"), []string{"Alice"})
16 | testAssignStringSlice(t, rv.FieldByName("NamePointer"), []string{"Charlie"})
17 | testAssignStringSlice(t, rv.FieldByName("Names"), []string{"Alice", "Bob", "Charlie"})
18 |
19 | testAssignStringSlice(t, rv.FieldByName("Age"), []string{"18"})
20 | testAssignStringSlice(t, rv.FieldByName("AgePointer"), []string{"20"})
21 | testAssignStringSlice(t, rv.FieldByName("Ages"), []string{"18", "20"})
22 |
23 | assert.Equal(t, "Alice", s.Name)
24 | assert.Equal(t, "Charlie", *s.NamePointer)
25 | assert.Equal(t, []string{"Alice", "Bob", "Charlie"}, s.Names)
26 |
27 | assert.Equal(t, 18, s.Age)
28 | assert.Equal(t, 20, *s.AgePointer)
29 | assert.Equal(t, []int{18, 20}, s.Ages)
30 | }
31 |
32 | func TestStringSlicable_ToStringSlice(t *testing.T) {
33 | var s = &MyStruct{
34 | Name: "Alice",
35 | NamePointer: internal.Pointerize[string]("Charlie"),
36 | Names: []string{"Alice", "Bob", "Charlie"},
37 |
38 | Age: 18,
39 | AgePointer: internal.Pointerize[int](20),
40 | Ages: []int{18, 20},
41 | }
42 |
43 | rv := reflect.ValueOf(s).Elem()
44 | assert.Equal(t, []string{"Alice"}, testGetStringSlice(t, rv.FieldByName("Name")))
45 | assert.Equal(t, []string{"Charlie"}, testGetStringSlice(t, rv.FieldByName("NamePointer")))
46 | assert.Equal(t, []string{"Alice", "Bob", "Charlie"}, testGetStringSlice(t, rv.FieldByName("Names")))
47 |
48 | assert.Equal(t, []string{"18"}, testGetStringSlice(t, rv.FieldByName("Age")))
49 | assert.Equal(t, []string{"20"}, testGetStringSlice(t, rv.FieldByName("AgePointer")))
50 | assert.Equal(t, []string{"18", "20"}, testGetStringSlice(t, rv.FieldByName("Ages")))
51 | }
52 |
53 | func testAssignStringSlice(t *testing.T, rv reflect.Value, values []string) {
54 | ss, err := NewStringSlicable(rv, nil)
55 | assert.NoError(t, err)
56 | assert.NoError(t, ss.FromStringSlice(values))
57 | }
58 |
59 | func testGetStringSlice(t *testing.T, rv reflect.Value) []string {
60 | ss, err := NewStringSlicable(rv, nil)
61 | assert.NoError(t, err)
62 | values, err := ss.ToStringSlice()
63 | assert.NoError(t, err)
64 | return values
65 | }
66 |
--------------------------------------------------------------------------------
/core/typekind.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "reflect"
5 | "strings"
6 | )
7 |
// TypeKind describes how a field's type is composed from its base type; see
// BaseTypeOf.
type TypeKind int

const (
	// TypeKindT: the type is the base type itself (T).
	TypeKindT TypeKind = iota
	// TypeKindTSlice: a slice of the base type ([]T).
	TypeKindTSlice
	// TypeKindPatchT: a patch field holding the base type (patch.Field[T]).
	TypeKindPatchT
	// TypeKindPatchTSlice: a patch field holding a slice (patch.Field[[]T]).
	TypeKindPatchTSlice
)
16 |
17 | // BaseTypeOf returns the base type of the given type its kind. The kind
18 | // represents how the given type is constructed from the base type.
19 | // - T -> T, TypeKindT
20 | // - []T -> T, TypeKindTSlice
21 | // - patch.Field[T] -> T, TypeKindPatchT
22 | // - patch.Field[[]T] -> T, TypeKindPatchTSlice
23 | func BaseTypeOf(valueType reflect.Type) (reflect.Type, TypeKind) {
24 | if valueType.Kind() == reflect.Slice {
25 | return valueType.Elem(), TypeKindTSlice
26 | }
27 | if IsPatchField(valueType) {
28 | subElemType, isMulti := patchFieldElemType(valueType)
29 | if isMulti {
30 | return subElemType, TypeKindPatchTSlice
31 | } else {
32 | return subElemType, TypeKindPatchT
33 | }
34 | }
35 | return valueType, TypeKindT
36 | }
37 |
// IsPatchField reports whether t is an instantiation of the generic
// patch.Field type from github.com/ggicci/httpin/patch. The check is done by
// kind, package path and the "Field[" name prefix of instantiated types.
func IsPatchField(t reflect.Type) bool {
	if t.Kind() != reflect.Struct {
		return false
	}
	const patchPkgPath = "github.com/ggicci/httpin/patch"
	return t.PkgPath() == patchPkgPath && strings.HasPrefix(t.Name(), "Field[")
}
43 |
// patchFieldElemType returns the element type carried by a patch field type t
// (which must have a "Value" field) and whether that Value is a slice. For a
// slice Value the slice's element type is returned.
func patchFieldElemType(t reflect.Type) (reflect.Type, bool) {
	field, _ := t.FieldByName("Value")
	elemType := field.Type
	isMulti := elemType.Kind() == reflect.Slice
	if isMulti {
		elemType = elemType.Elem()
	}
	return elemType, isMulti
}
51 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/ggicci/httpin
2 |
3 | go 1.23
4 |
5 | require (
6 | github.com/ggicci/owl v0.8.2
7 | github.com/go-chi/chi/v5 v5.0.11
8 | github.com/gorilla/mux v1.8.1
9 | github.com/justinas/alice v1.2.0
10 | github.com/labstack/echo/v4 v4.12.0
11 | github.com/stretchr/testify v1.9.0
12 | )
13 |
14 | require (
15 | github.com/davecgh/go-spew v1.1.1 // indirect
16 | github.com/labstack/gommon v0.4.2 // indirect
17 | github.com/mattn/go-colorable v0.1.13 // indirect
18 | github.com/mattn/go-isatty v0.0.20 // indirect
19 | github.com/pmezard/go-difflib v1.0.0 // indirect
20 | github.com/valyala/bytebufferpool v1.0.0 // indirect
21 | github.com/valyala/fasttemplate v1.2.2 // indirect
22 | golang.org/x/crypto v0.22.0 // indirect
23 | golang.org/x/net v0.24.0 // indirect
24 | golang.org/x/sys v0.19.0 // indirect
25 | golang.org/x/text v0.14.0 // indirect
26 | gopkg.in/yaml.v3 v3.0.1 // indirect
27 | )
28 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
3 | github.com/ggicci/owl v0.8.2 h1:og+lhqpzSMPDdEB+NJfzoAJARP7qCG3f8uUC3xvGukA=
4 | github.com/ggicci/owl v0.8.2/go.mod h1:PHRD57u41vFN5UtFz2SF79yTVoM3HlWpjMiE+ZU2dj4=
5 | github.com/go-chi/chi/v5 v5.0.11 h1:BnpYbFZ3T3S1WMpD79r7R5ThWX40TaFB7L31Y8xqSwA=
6 | github.com/go-chi/chi/v5 v5.0.11/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
7 | github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
8 | github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
9 | github.com/justinas/alice v1.2.0 h1:+MHSA/vccVCF4Uq37S42jwlkvI2Xzl7zTPCN5BnZNVo=
10 | github.com/justinas/alice v1.2.0/go.mod h1:fN5HRH/reO/zrUflLfTN43t3vXvKzvZIENsNEe7i7qA=
11 | github.com/labstack/echo/v4 v4.12.0 h1:IKpw49IMryVB2p1a4dzwlhP1O2Tf2E0Ir/450lH+kI0=
12 | github.com/labstack/echo/v4 v4.12.0/go.mod h1:UP9Cr2DJXbOK3Kr9ONYzNowSh7HP0aG0ShAyycHSJvM=
13 | github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0=
14 | github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU=
15 | github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
16 | github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
17 | github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
18 | github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
19 | github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
20 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
21 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
22 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
23 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
24 | github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
25 | github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
26 | github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo=
27 | github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
28 | golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
29 | golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
30 | golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
31 | golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
32 | golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
33 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
34 | golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
35 | golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
36 | golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
37 | golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
38 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
39 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
40 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
41 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
42 |
--------------------------------------------------------------------------------
/httpin.go:
--------------------------------------------------------------------------------
1 | // Package httpin helps decoding an HTTP request to a custom struct by binding
2 | // data with querystring (query params), HTTP headers, form data, JSON/XML
3 | // payloads, URL path params, and file uploads (multipart/form-data).
4 | package httpin
5 |
6 | import (
7 | "context"
8 | "errors"
9 | "io"
10 | "net/http"
11 | "reflect"
12 |
13 | "github.com/ggicci/httpin/core"
14 | "github.com/ggicci/httpin/internal"
15 | )
16 |
// contextKey is the unexported type of the keys httpin stores in a request
// context; a package-private type keeps them from colliding with keys set by
// other packages.
type contextKey int

// Input is the key to get the input object from Request.Context() injected
// by httpin, e.g.
//
//	input := r.Context().Value(httpin.Input).(*InputStruct)
const Input contextKey = iota
25 |
26 | // Option is a collection of options for creating a Core instance.
27 | var Option coreOptions = coreOptions{
28 | WithErrorHandler: core.WithErrorHandler,
29 | WithMaxMemory: core.WithMaxMemory,
30 | WithNestedDirectivesEnabled: core.WithNestedDirectivesEnabled,
31 | }
32 |
33 | // New calls core.New to create a new Core instance. Which is responsible for both:
34 | //
35 | // - decoding an HTTP request to an instance of the inputStruct;
36 | // - and encoding an instance of the inputStruct to an HTTP request.
37 | //
38 | // Note that the Core instance is bound to the given specific type, it will not
39 | // work for other types. If you want to decode/encode other types, you need to
40 | // create another Core instance. Or directly use the following functions, which are
41 | // just shortcuts of Core's methods, so you don't need to create a Core instance:
42 | // - httpin.Decode(): decode an HTTP request to an instance of the inputStruct.
43 | // - httpin.NewRequest() to encode an instance of the inputStruct to an HTTP request.
44 | //
45 | // For best practice, we would recommend using httpin.NewInput() to create an
46 | // HTTP middleware for a specific input type. The middleware can be bound to an
47 | // API, chained with other middlewares, and also reused in other APIs. You even
48 | // don't need to call the Deocde() method explicitly, the middleware will do it
49 | // for you and put the decoded instance to the request's context.
50 | func New(inputStruct any, opts ...core.Option) (*core.Core, error) {
51 | return core.New(inputStruct, opts...)
52 | }
53 |
54 | // File is the builtin type of httpin to manupulate file uploads. On the server
55 | // side, it is used to represent a file in a multipart/form-data request. On the
56 | // client side, it is used to represent a file to be uploaded.
57 | type File = core.File
58 |
59 | // UploadFile is a helper function to create a File instance from a file path.
60 | // It is useful when you want to upload a file from the local file system.
61 | func UploadFile(path string) *File {
62 | return core.UploadFile(path)
63 | }
64 |
65 | // UploadStream is a helper function to create a File instance from a io.Reader. It
66 | // is useful when you want to upload a file from a stream.
67 | func UploadStream(r io.ReadCloser) *File {
68 | return core.UploadStream(r)
69 | }
70 |
71 | // DecodeTo decodes an HTTP request and populates input with data from the HTTP request.
72 | // The input must be a pointer to a struct instance. For example:
73 | //
74 | // input := &InputStruct{}
75 | // if err := DecodeTo(req, input); err != nil { ... }
76 | //
77 | // input is now populated with data from the request.
78 | func DecodeTo(req *http.Request, input any, opts ...core.Option) error {
79 | co, err := New(internal.DereferencedType(input), opts...)
80 | if err != nil {
81 | return err
82 | }
83 | return co.DecodeTo(req, input)
84 | }
85 |
86 | // Decode decodes an HTTP request to an instance of T and returns its pointer
87 | // (*T). T must be a struct type. For example:
88 | //
89 | // if user, err := Decode[User](req); err != nil { ... }
90 | // // now user is a *User instance, which has been populated with data from the request.
91 | func Decode[T any](req *http.Request, opts ...core.Option) (*T, error) {
92 | rt := internal.TypeOf[T]()
93 | if rt.Kind() != reflect.Struct {
94 | return nil, errors.New("generic type T must be a struct type")
95 | }
96 | co, err := New(rt, opts...)
97 | if err != nil {
98 | return nil, err
99 | }
100 | if v, err := co.Decode(req); err != nil {
101 | return nil, err
102 | } else {
103 | return v.(*T), nil
104 | }
105 | }
106 |
107 | // NewRequest wraps NewRequestWithContext using context.Background(), see NewRequestWithContext.
108 | func NewRequest(method, url string, input any, opts ...core.Option) (*http.Request, error) {
109 | return NewRequestWithContext(context.Background(), method, url, input)
110 | }
111 |
112 | // NewRequestWithContext turns the given input into an HTTP request. The input
113 | // must be a struct instance. And its fields' "in" tags define how to bind the
114 | // data from the struct to the HTTP request. Use it as the replacement of
115 | // http.NewRequest().
116 | //
117 | // addUserPayload := &AddUserRequest{...}
118 | // addUserRequest, err := NewRequestWithContext(context.Background(), "GET", "http://example.com", addUserPayload)
119 | // http.DefaultClient.Do(addUserRequest)
120 | func NewRequestWithContext(ctx context.Context, method, url string, input any, opts ...core.Option) (*http.Request, error) {
121 | co, err := New(input, opts...)
122 | if err != nil {
123 | return nil, err
124 | }
125 | return co.NewRequestWithContext(ctx, method, url, input)
126 | }
127 |
128 | // NewInput creates an HTTP middleware handler. Which is a function that takes
129 | // in an http.Handler and returns another http.Handler.
130 | //
131 | // The middleware created by NewInput is to add the decoding function to an
132 | // existing http.Handler. This functionality will decode the HTTP request into a
133 | // struct instance and put its pointer to the request's context. So that the
134 | // next hop can get the decoded struct instance from the request's context.
135 | //
136 | // We recommend using https://github.com/justinas/alice to chain your
137 | // middlewares. If you're using some popular web frameworks, they may have
138 | // already provided a middleware chaining mechanism.
139 | //
140 | // For example:
141 | //
142 | // type ListUsersRequest struct {
143 | // Page int `in:"query=page,page_index,index"`
144 | // PerPage int `in:"query=per_page,page_size"`
145 | // }
146 | //
147 | // func ListUsersHandler(rw http.ResponseWriter, r *http.Request) {
148 | // input := r.Context().Value(httpin.Input).(*ListUsersRequest)
149 | // // ...
150 | // }
151 | //
152 | // func init() {
153 | // http.Handle("/users", alice.New(httpin.NewInput(&ListUsersRequest{})).ThenFunc(ListUsersHandler))
154 | // }
155 | func NewInput(inputStruct any, opts ...core.Option) func(http.Handler) http.Handler {
156 | co, err := New(inputStruct, opts...)
157 | internal.PanicOnError(err)
158 |
159 | return func(next http.Handler) http.Handler {
160 | return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
161 | // Here we read the request and decode it to fill our structure.
162 | // Once failed, the request should end here.
163 | input, err := co.Decode(r)
164 | if err != nil {
165 | co.GetErrorHandler()(rw, r, err)
166 | return
167 | }
168 |
169 | // We put the `input` to the request's context, and it will pass to the next hop.
170 | ctx := context.WithValue(r.Context(), Input, input)
171 | next.ServeHTTP(rw, r.WithContext(ctx))
172 | })
173 | }
174 | }
175 |
176 | type coreOptions struct {
177 | // WithErrorHandler overrides the default error handler.
178 | WithErrorHandler func(core.ErrorHandler) core.Option
179 |
180 | // WithMaxMemory overrides the default maximum memory size (32MB) when reading
181 | // the request body. See https://pkg.go.dev/net/http#Request.ParseMultipartForm
182 | // for more details.
183 | WithMaxMemory func(int64) core.Option
184 |
185 | // WithNestedDirectivesEnabled enables/disables nested directives.
186 | WithNestedDirectivesEnabled func(bool) core.Option
187 | }
188 |
--------------------------------------------------------------------------------
/httpin_test.go:
--------------------------------------------------------------------------------
1 | package httpin
2 |
3 | import (
4 | "encoding/json"
5 | "errors"
6 | "io"
7 | "net/http"
8 | "net/http/httptest"
9 | "net/url"
10 | "strings"
11 | "testing"
12 |
13 | "github.com/ggicci/httpin/core"
14 | "github.com/justinas/alice"
15 | "github.com/stretchr/testify/assert"
16 | )
17 |
// Pagination is a small input struct bound from form values. Each form
// directive lists alternative key names for the field.
type Pagination struct {
	Page    int `in:"form=page,page_index,index"`
	PerPage int `in:"form=per_page,page_size"`
}
22 |
23 | func testcasePagination1100() (*http.Request, *Pagination) {
24 | r, _ := http.NewRequest("GET", "/", nil)
25 | r.Form = url.Values{
26 | "page": {"1"},
27 | "per_page": {"100"},
28 | }
29 | return r, &Pagination{
30 | Page: 1,
31 | PerPage: 100,
32 | }
33 | }
34 |
35 | func TestDecodeTo(t *testing.T) {
36 | r, expected := testcasePagination1100()
37 |
38 | func() {
39 | input := &Pagination{}
40 | err := DecodeTo(r, input) // pointer to a struct instance
41 | assert.NoError(t, err)
42 | assert.Equal(t, expected, input)
43 | }()
44 |
45 | func() {
46 | input := Pagination{}
47 | err := DecodeTo(r, &input) // addressable struct instance
48 | assert.NoError(t, err)
49 | assert.Equal(t, expected, &input)
50 | }()
51 |
52 | func() {
53 | input := &Pagination{}
54 | err := DecodeTo(r, &input) // pointer to pointer of struct instance
55 | assert.NoError(t, err)
56 | assert.Equal(t, expected, input)
57 | }()
58 |
59 | func() {
60 | input := Pagination{}
61 | err := DecodeTo(r, input) // non-pointer struct instance should fail
62 | assert.ErrorContains(t, err, "invalid resolve target")
63 | }()
64 | }
65 |
66 | func TestDecode(t *testing.T) {
67 | r, expected := testcasePagination1100()
68 |
69 | p, err := Decode[Pagination](r)
70 | assert.NoError(t, err)
71 | assert.Equal(t, expected, p)
72 | }
73 |
74 | func TestDecode_ErrNotAStruct(t *testing.T) {
75 | r, _ := testcasePagination1100()
76 |
77 | _, err := Decode[int](r)
78 | assert.ErrorContains(t, err, "T must be a struct type")
79 |
80 | _, err = Decode[*Pagination](r)
81 | assert.ErrorContains(t, err, "T must be a struct type")
82 | }
83 |
84 | func TestDecode_ErrBuildResolverFailed(t *testing.T) {
85 | r, _ := testcasePagination1100()
86 |
87 | type Foo struct {
88 | Name string `in:"nonexistent=foo"`
89 | }
90 |
91 | assert.Error(t, DecodeTo(r, &Foo{}))
92 |
93 | v, err := Decode[Foo](r)
94 | assert.Nil(t, v)
95 | assert.Error(t, err)
96 | }
97 |
98 | func TestDecode_ErrDecodeFailure(t *testing.T) {
99 | r, _ := http.NewRequest("GET", "/", nil)
100 | r.Form = url.Values{
101 | "page": {"1"},
102 | "per_page": {"one-hundred"},
103 | }
104 |
105 | p := &Pagination{}
106 | assert.Error(t, DecodeTo(r, p))
107 |
108 | v, err := Decode[Pagination](r)
109 | assert.Nil(t, v)
110 | assert.Error(t, err)
111 | }
112 |
// EchoInput is the input decoded for EchoHandler. Token is required and may
// come from either the access_token form field or the x-api-key header.
type EchoInput struct {
	Token  string `in:"form=access_token;header=x-api-key;required"`
	Saying string `in:"form=saying"`
}
117 |
118 | func EchoHandler(rw http.ResponseWriter, r *http.Request) {
119 | var input = r.Context().Value(Input).(*EchoInput)
120 | json.NewEncoder(rw).Encode(input)
121 | }
122 |
123 | func TestNewInput_WithNil(t *testing.T) {
124 | assert.Panics(t, func() {
125 | NewInput(nil)
126 | })
127 | }
128 |
129 | func TestNewInput_Success(t *testing.T) {
130 | r, err := http.NewRequest("GET", "/", nil)
131 | assert.NoError(t, err)
132 |
133 | r.Header.Add("X-Api-Key", "abc")
134 | var params = url.Values{}
135 | params.Add("saying", "TO THINE OWE SELF BE TRUE")
136 | r.URL.RawQuery = params.Encode()
137 |
138 | rw := httptest.NewRecorder()
139 | handler := alice.New(NewInput(EchoInput{})).ThenFunc(EchoHandler)
140 | handler.ServeHTTP(rw, r)
141 | assert.Equal(t, 200, rw.Code)
142 | expected := `{"Token":"abc","Saying":"TO THINE OWE SELF BE TRUE"}` + "\n"
143 | assert.Equal(t, expected, rw.Body.String())
144 | }
145 |
146 | func TestNewInput_ErrorHandledByDefaultErrorHandler(t *testing.T) {
147 | r, err := http.NewRequest("GET", "/", nil)
148 | assert.NoError(t, err)
149 |
150 | var params = url.Values{}
151 | params.Add("saying", "TO THINE OWE SELF BE TRUE")
152 | r.URL.RawQuery = params.Encode()
153 |
154 | rw := httptest.NewRecorder()
155 | handler := alice.New(NewInput(EchoInput{})).ThenFunc(EchoHandler)
156 | handler.ServeHTTP(rw, r)
157 | var out map[string]any
158 | assert.Nil(t, json.NewDecoder(rw.Body).Decode(&out))
159 |
160 | assert.Equal(t, 422, rw.Code)
161 | assert.Equal(t, "Token", out["field"])
162 | assert.Equal(t, "required", out["directive"])
163 | assert.Contains(t, out["error"], "missing required field")
164 | }
165 |
166 | func CustomErrorHandler(rw http.ResponseWriter, r *http.Request, err error) {
167 | var invalidFieldError *core.InvalidFieldError
168 | if errors.As(err, &invalidFieldError) {
169 | rw.WriteHeader(http.StatusBadRequest) // status: 400
170 | io.WriteString(rw, invalidFieldError.Error())
171 | return
172 | }
173 | http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) // status: 500
174 | }
175 |
176 | func TestNewRequest(t *testing.T) {
177 | req, err := NewRequest("GET", "/products", &Pagination{
178 | Page: 19,
179 | PerPage: 50,
180 | })
181 | assert.NoError(t, err)
182 |
183 | expected, _ := http.NewRequest("GET", "/products", nil)
184 | expected.Body = io.NopCloser(strings.NewReader("page=19&per_page=50"))
185 | expected.Header.Set("Content-Type", "application/x-www-form-urlencoded")
186 | assert.Equal(t, expected, req)
187 | }
188 |
189 | func TestNewRequest_ErrNewFailure(t *testing.T) {
190 | _, err := NewRequest("GET", "/products", 123)
191 | assert.Error(t, err)
192 | }
193 |
--------------------------------------------------------------------------------
/integration/echo.go:
--------------------------------------------------------------------------------
1 | package integration
2 |
3 | import (
4 | "mime/multipart"
5 |
6 | "github.com/ggicci/httpin/core"
7 | "github.com/labstack/echo/v4"
8 | )
9 |
10 | // UseEchoRouter registers a new directive executor which can extract values
11 | // from path variables.
12 | // https://ggicci.github.io/httpin/integrations/echo
13 | //
14 | // Usage:
15 | //
16 | // import httpin_integration "github.com/ggicci/httpin/integration"
17 | //
18 | // func init() {
19 | // e := echo.New()
20 | // httpin_integration.UseEchoRouter("path", e)
21 | // }
22 | func UseEchoRouter(name string, e *echo.Echo) {
23 | core.RegisterDirective(
24 | name,
25 | core.NewDirectivePath((&echoRouterExtractor{e}).Execute),
26 | true,
27 | )
28 | }
29 |
30 | // echoRouterExtractor is an extractor for mux.Vars
31 | type echoRouterExtractor struct {
32 | e *echo.Echo
33 | }
34 |
35 | func (mux *echoRouterExtractor) Execute(rtm *core.DirectiveRuntime) error {
36 | req := rtm.GetRequest()
37 | kvs := make(map[string][]string)
38 |
39 | c := mux.e.NewContext(req, nil)
40 | c.SetRequest(req)
41 |
42 | mux.e.Router().Find(req.Method, req.URL.Path, c)
43 |
44 | for _, key := range c.ParamNames() {
45 | kvs[key] = []string{c.Param(key)}
46 | }
47 |
48 | extractor := &core.FormExtractor{
49 | Runtime: rtm,
50 | Form: multipart.Form{
51 | Value: kvs,
52 | },
53 | }
54 | return extractor.Extract()
55 | }
56 |
--------------------------------------------------------------------------------
/integration/echo_test.go:
--------------------------------------------------------------------------------
1 | package integration_test
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "net/http/httptest"
7 | "strings"
8 | "testing"
9 |
10 | "github.com/ggicci/httpin"
11 | httpin_integration "github.com/ggicci/httpin/integration"
12 | "github.com/labstack/echo/v4"
13 | "github.com/stretchr/testify/assert"
14 | )
15 |
16 | func TestUseEchoMux(t *testing.T) {
17 | e := echo.New()
18 | // NOTE: I removed the API UseEchoPathRouter because it introduces minimal benefit
19 | // but adds surface area and maintenance cost.
20 | httpin_integration.UseEchoRouter("path", e)
21 |
22 | req := httptest.NewRequest(http.MethodGet, "/users/ggicci/posts/123", nil)
23 | req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
24 | rec := httptest.NewRecorder()
25 |
26 | c := e.NewContext(req, rec)
27 |
28 | handler := func(ctx echo.Context) error {
29 | param := &GetPostOfUserInput{}
30 | core, err := httpin.New(param)
31 | if err != nil {
32 | return err
33 | }
34 | v, err := core.Decode(ctx.Request())
35 | if err != nil {
36 | return err
37 | }
38 | fmt.Println(param)
39 | return c.JSON(http.StatusOK, v)
40 | }
41 | e.GET("/users/:username/posts/:pid", handler)
42 | err := handler(c)
43 | assert.NoError(t, err)
44 | assert.Equal(t, http.StatusOK, rec.Code)
45 | assert.Equal(t, `{"Username":"ggicci","PostID":123}`, strings.TrimSpace(rec.Body.String()))
46 | }
47 |
--------------------------------------------------------------------------------
/integration/gochi.go:
--------------------------------------------------------------------------------
1 | // integration: "gochi"
2 | // https://ggicci.github.io/httpin/integrations/gochi
3 |
4 | package integration
5 |
6 | import (
7 | "mime/multipart"
8 | "net/http"
9 |
10 | "github.com/ggicci/httpin/core"
11 | )
12 |
13 | // GochiURLParamFunc is chi.URLParam
14 | type GochiURLParamFunc func(r *http.Request, key string) string
15 |
16 | // UseGochiURLParam registers a directive executor which can extract values
17 | // from `chi.URLParam`, i.e. path variables.
18 | // https://ggicci.github.io/httpin/integrations/gochi
19 | //
20 | // Usage:
21 | //
22 | // import httpin_integration "github.com/ggicci/httpin/integration"
23 | //
24 | // func init() {
25 | // httpin_integration.UseGochiURLParam("path", chi.URLParam)
26 | // }
27 | func UseGochiURLParam(name string, fn GochiURLParamFunc) {
28 | core.RegisterDirective(
29 | name,
30 | core.NewDirectivePath((&gochiURLParamExtractor{URLParam: fn}).Execute),
31 | true,
32 | )
33 | }
34 |
35 | type gochiURLParamExtractor struct {
36 | URLParam GochiURLParamFunc
37 | }
38 |
39 | func (chi *gochiURLParamExtractor) Execute(rtm *core.DirectiveRuntime) error {
40 | req := rtm.GetRequest()
41 | kvs := make(map[string][]string)
42 |
43 | for _, key := range rtm.Directive.Argv {
44 | value := chi.URLParam(req, key)
45 | if value != "" {
46 | kvs[key] = []string{value}
47 | }
48 | }
49 |
50 | extractor := &core.FormExtractor{
51 | Runtime: rtm,
52 | Form: multipart.Form{
53 | Value: kvs,
54 | },
55 | }
56 | return extractor.Extract()
57 | }
58 |
--------------------------------------------------------------------------------
/integration/gochi_test.go:
--------------------------------------------------------------------------------
1 | package integration_test
2 |
3 | import (
4 | "encoding/json"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 |
9 | "github.com/ggicci/httpin"
10 | httpin_integration "github.com/ggicci/httpin/integration"
11 | "github.com/go-chi/chi/v5"
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
// GetArticleOfUserInput is decoded from chi URL parameters via the "gochi"
// directive registered in TestUseGochiURLParam.
type GetArticleOfUserInput struct {
	Author    string `in:"gochi=author"` // equivalent to chi.URLParam("author")
	ArticleID int64  `in:"gochi=articleID"`
}
19 |
20 | func GetArticleOfUser(rw http.ResponseWriter, r *http.Request) {
21 | var input = r.Context().Value(httpin.Input).(*GetArticleOfUserInput)
22 | json.NewEncoder(rw).Encode(input)
23 | }
24 |
25 | func TestUseGochiURLParam(t *testing.T) {
26 | // Register the "gochi" directive, usually in init().
27 | // In most cases, you register this as "path", here's just an example.
28 | // Which is in order to avoid test conflicts with other tests
29 | httpin_integration.UseGochiURLParam("gochi", chi.URLParam)
30 |
31 | rw := httptest.NewRecorder()
32 | r, err := http.NewRequest("GET", "/ggicci/articles/1024", nil)
33 | assert.NoError(t, err)
34 |
35 | router := chi.NewRouter()
36 | router.With(
37 | httpin.NewInput(GetArticleOfUserInput{}),
38 | ).Get("/{author}/articles/{articleID}", GetArticleOfUser)
39 |
40 | router.ServeHTTP(rw, r)
41 | assert.Equal(t, 200, rw.Code)
42 | expected := `{"Author":"ggicci","ArticleID":1024}` + "\n"
43 | assert.Equal(t, expected, rw.Body.String())
44 | }
45 |
--------------------------------------------------------------------------------
/integration/gorilla.go:
--------------------------------------------------------------------------------
1 | // Mux vars extension for github.com/gorilla/mux package.
2 |
3 | package integration
4 |
5 | import (
6 | "mime/multipart"
7 | "net/http"
8 |
9 | "github.com/ggicci/httpin/core"
10 | )
11 |
12 | // GorillaMuxVarsFunc is mux.Vars
13 | type GorillaMuxVarsFunc func(*http.Request) map[string]string
14 |
15 | // UseGorillaMux registers a new directive executor which can extract values
16 | // from `mux.Vars`, i.e. path variables.
17 | // https://ggicci.github.io/httpin/integrations/gorilla
18 | //
19 | // Usage:
20 | //
21 | // import httpin_integration "github.com/ggicci/httpin/integration"
22 | //
23 | // func init() {
24 | // httpin_integration.UseGorillaMux("path", mux.Vars)
25 | // }
26 | func UseGorillaMux(name string, fnVars GorillaMuxVarsFunc) {
27 | core.RegisterDirective(
28 | name,
29 | core.NewDirectivePath((&gorillaMuxVarsExtractor{Vars: fnVars}).Execute),
30 | true,
31 | )
32 | }
33 |
34 | type gorillaMuxVarsExtractor struct {
35 | Vars GorillaMuxVarsFunc
36 | }
37 |
38 | func (mux *gorillaMuxVarsExtractor) Execute(rtm *core.DirectiveRuntime) error {
39 | req := rtm.GetRequest()
40 | kvs := make(map[string][]string)
41 |
42 | for key, value := range mux.Vars(req) {
43 | kvs[key] = []string{value}
44 | }
45 |
46 | extractor := &core.FormExtractor{
47 | Runtime: rtm,
48 | Form: multipart.Form{
49 | Value: kvs,
50 | },
51 | }
52 | return extractor.Extract()
53 | }
54 |
--------------------------------------------------------------------------------
/integration/gorilla_test.go:
--------------------------------------------------------------------------------
1 | package integration_test
2 |
3 | import (
4 | "encoding/json"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 |
9 | "github.com/ggicci/httpin"
10 | httpin_integration "github.com/ggicci/httpin/integration"
11 | "github.com/gorilla/mux"
12 | "github.com/justinas/alice"
13 | "github.com/stretchr/testify/assert"
14 | )
15 |
// GetPostOfUserInput is decoded from the "path" directive's path variables.
type GetPostOfUserInput struct {
	Username string `in:"path=username"`
	PostID   int64  `in:"path=pid"`
}
20 |
21 | func GetPostOfUserHandler(rw http.ResponseWriter, r *http.Request) {
22 | var input = r.Context().Value(httpin.Input).(*GetPostOfUserInput)
23 | json.NewEncoder(rw).Encode(input)
24 | }
25 |
26 | func TestGorillaMuxVars(t *testing.T) {
27 | httpin_integration.UseGorillaMux("path", mux.Vars) // register the "path" directive, usually in init()
28 |
29 | rw := httptest.NewRecorder()
30 | r, err := http.NewRequest("GET", "/ggicci/posts/1024", nil)
31 | assert.NoError(t, err)
32 |
33 | router := mux.NewRouter()
34 | router.Handle("/{username}/posts/{pid}", alice.New(
35 | httpin.NewInput(GetPostOfUserInput{}),
36 | ).ThenFunc(GetPostOfUserHandler)).Methods("GET")
37 | router.ServeHTTP(rw, r)
38 | assert.Equal(t, 200, rw.Code)
39 | expected := `{"Username":"ggicci","PostID":1024}` + "\n"
40 | assert.Equal(t, expected, rw.Body.String())
41 | }
42 |
--------------------------------------------------------------------------------
/integration/http.go:
--------------------------------------------------------------------------------
1 | package integration
2 |
3 | import (
4 | "mime/multipart"
5 | "net/http"
6 |
7 | "github.com/ggicci/httpin/core"
8 | )
9 |
10 | type HttpMuxVarsFunc func(*http.Request) map[string]string
11 |
12 | // UseHttpPathVariable registers a new directive executor which can extract
13 | // values from URL path variables via `http.Request.PathValue` API.
14 | // https://ggicci.github.io/httpin/integrations/http
15 | //
16 | // Usage:
17 | //
18 | // import httpin_integration "github.com/ggicci/httpin/integration"
19 | // func init() {
20 | // httpin_integration.UseHttpPathVariable("path")
21 | // }
22 | func UseHttpPathVariable(name string) {
23 | core.RegisterDirective(
24 | name,
25 | core.NewDirectivePath((&httpMuxVarsExtractor{}).Execute),
26 | true,
27 | )
28 | }
29 |
// httpMuxVarsExtractor reads path variables with http.Request.PathValue; it
// carries no state.
type httpMuxVarsExtractor struct{}
31 |
32 | func (mux *httpMuxVarsExtractor) Execute(rtm *core.DirectiveRuntime) error {
33 | req := rtm.GetRequest()
34 | kvs := make(map[string][]string)
35 |
36 | for _, key := range rtm.Directive.Argv {
37 | value := req.PathValue(key)
38 | if value != "" {
39 | kvs[key] = []string{value}
40 | }
41 | }
42 |
43 | extractor := &core.FormExtractor{
44 | Runtime: rtm,
45 | Form: multipart.Form{
46 | Value: kvs,
47 | },
48 | }
49 | return extractor.Extract()
50 | }
51 |
--------------------------------------------------------------------------------
/integration/http_test.go:
--------------------------------------------------------------------------------
1 | package integration_test
2 |
3 | import (
4 | "encoding/json"
5 | "io"
6 | "net/http"
7 | "net/http/httptest"
8 | "strings"
9 | "testing"
10 |
11 | "github.com/ggicci/httpin"
12 | httpin_integration "github.com/ggicci/httpin/integration"
13 | "github.com/stretchr/testify/assert"
14 | )
15 |
// HttpMuxPathInput is decoded from net/http (Go 1.22+) path wildcards.
type HttpMuxPathInput struct {
	Username string `in:"path=username"`
	PostID   int64  `in:"path=pid"`
}
20 |
21 | func TestUseHttpMux(t *testing.T) {
22 | httpin_integration.UseHttpPathVariable("path")
23 |
24 | srv := http.NewServeMux()
25 | srv.HandleFunc("/users/{username}/posts/{pid}", func(w http.ResponseWriter, r *http.Request) {
26 | param := &HttpMuxPathInput{}
27 | core, err := httpin.New(param)
28 | if err != nil {
29 | t.Fatal(err)
30 | return
31 | }
32 | v, err := core.Decode(r)
33 | if err != nil {
34 | t.Fatal(err)
35 | return
36 | }
37 | jsonBytes, err := json.Marshal(v)
38 | if err != nil {
39 | t.Fatal(err)
40 | return
41 | }
42 |
43 | w.WriteHeader(http.StatusOK)
44 | w.Write(jsonBytes)
45 | })
46 | ts := httptest.NewServer(srv)
47 | defer ts.Close()
48 |
49 | resp, err := http.DefaultClient.Get(ts.URL + "/users/chriss-de/posts/456")
50 | assert.NoError(t, err)
51 | assert.Equal(t, http.StatusOK, resp.StatusCode)
52 |
53 | bodyBytes, err := io.ReadAll(resp.Body)
54 | assert.NoError(t, err)
55 | bodyString := string(bodyBytes)
56 |
57 | assert.Equal(t, `{"Username":"chriss-de","PostID":456}`, strings.TrimSpace(bodyString))
58 | }
59 |
--------------------------------------------------------------------------------
/internal/misc.go:
--------------------------------------------------------------------------------
1 | package internal
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 | )
7 |
8 | func IsNil(value reflect.Value) bool {
9 | switch value.Kind() {
10 | case reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.Interface, reflect.Slice:
11 | return value.IsNil()
12 | default:
13 | return false
14 | }
15 | }
16 |
17 | func PanicOnError(err error) {
18 | if err != nil {
19 | panic(fmt.Errorf("httpin: %w", err))
20 | }
21 | }
22 |
23 | // TypeOf returns the reflect.Type of a given type.
24 | // e.g. TypeOf[int]() returns reflect.TypeOf(0)
25 | func TypeOf[T any]() reflect.Type {
26 | var zero [0]T
27 | return reflect.TypeOf(zero).Elem()
28 | }
29 |
// Pointerize returns a pointer to a fresh copy of v.
func Pointerize[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
33 |
34 | // DereferencedType returns the underlying type of a pointer.
35 | func DereferencedType(v any) reflect.Type {
36 | rv := reflect.ValueOf(v)
37 | for rv.Kind() == reflect.Pointer {
38 | rv = rv.Elem()
39 | }
40 | return rv.Type()
41 | }
42 |
--------------------------------------------------------------------------------
/internal/misc_test.go:
--------------------------------------------------------------------------------
1 | package internal
2 |
3 | import (
4 | "reflect"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
10 | func TestIsNil(t *testing.T) {
11 | assert.False(t, IsNil(reflect.ValueOf("hello")))
12 | assert.True(t, IsNil(reflect.ValueOf((*string)(nil))))
13 | }
14 |
15 | func TestPanicOnError(t *testing.T) {
16 | PanicOnError(nil)
17 |
18 | assert.PanicsWithError(t, "httpin: "+assert.AnError.Error(), func() {
19 | PanicOnError(assert.AnError)
20 | })
21 | }
22 |
23 | func TestTypeOf(t *testing.T) {
24 | assert.Equal(t, reflect.TypeOf(0), TypeOf[int]())
25 | }
26 |
27 | func TestPointerize(t *testing.T) {
28 | assert.Equal(t, 102, *Pointerize[int](102))
29 | }
30 |
31 | func TestDereferencedType(t *testing.T) {
32 | type Object struct{}
33 |
34 | var o = new(Object)
35 | var po = &o
36 | var ppo = &po
37 | assert.Equal(t, reflect.TypeOf(Object{}), DereferencedType(Object{}))
38 | assert.Equal(t, reflect.TypeOf(Object{}), DereferencedType(o))
39 | assert.Equal(t, reflect.TypeOf(Object{}), DereferencedType(po))
40 | assert.Equal(t, reflect.TypeOf(Object{}), DereferencedType(ppo))
41 | }
42 |
--------------------------------------------------------------------------------
/internal/stringable.go:
--------------------------------------------------------------------------------
1 | package internal
2 |
3 | import (
4 | "encoding/base64"
5 | "errors"
6 | "fmt"
7 | "reflect"
8 | "regexp"
9 | "strconv"
10 | "strings"
11 | "time"
12 | "github.com/ggicci/owl"
13 | )
14 |
15 | var (
16 | ErrTypeMismatch = errors.New("type mismatch")
17 | ErrUnsupportedType = owl.ErrUnsupportedType
18 |
19 | builtinStringableAdaptors = make(map[reflect.Type]AnyStringableAdaptor)
20 | )
21 |
22 | func init() {
23 | builtinStringable[string](func(v *string) (Stringable, error) { return (*String)(v), nil })
24 | builtinStringable[bool](func(v *bool) (Stringable, error) { return (*Bool)(v), nil })
25 | builtinStringable[int](func(v *int) (Stringable, error) { return (*Int)(v), nil })
26 | builtinStringable[int8](func(v *int8) (Stringable, error) { return (*Int8)(v), nil })
27 | builtinStringable[int16](func(v *int16) (Stringable, error) { return (*Int16)(v), nil })
28 | builtinStringable[int32](func(v *int32) (Stringable, error) { return (*Int32)(v), nil })
29 | builtinStringable[int64](func(v *int64) (Stringable, error) { return (*Int64)(v), nil })
30 | builtinStringable[uint](func(v *uint) (Stringable, error) { return (*Uint)(v), nil })
31 | builtinStringable[uint8](func(v *uint8) (Stringable, error) { return (*Uint8)(v), nil })
32 | builtinStringable[uint16](func(v *uint16) (Stringable, error) { return (*Uint16)(v), nil })
33 | builtinStringable[uint32](func(v *uint32) (Stringable, error) { return (*Uint32)(v), nil })
34 | builtinStringable[uint64](func(v *uint64) (Stringable, error) { return (*Uint64)(v), nil })
35 | builtinStringable[float32](func(v *float32) (Stringable, error) { return (*Float32)(v), nil })
36 | builtinStringable[float64](func(v *float64) (Stringable, error) { return (*Float64)(v), nil })
37 | builtinStringable[complex64](func(v *complex64) (Stringable, error) { return (*Complex64)(v), nil })
38 | builtinStringable[complex128](func(v *complex128) (Stringable, error) { return (*Complex128)(v), nil })
39 | builtinStringable[time.Time](func(v *time.Time) (Stringable, error) { return (*Time)(v), nil })
40 | builtinStringable[[]byte](func(b *[]byte) (Stringable, error) { return (*ByteSlice)(b), nil })
41 | }
42 |
type (
	// StringMarshaler is implemented by values that can render themselves
	// as a string.
	StringMarshaler interface {
		ToString() (string, error)
	}

	// StringUnmarshaler is implemented by values that can be set from a
	// string.
	StringUnmarshaler interface {
		FromString(string) error
	}

	// Stringable is a value that can be converted both to and from a string.
	Stringable interface {
		StringMarshaler
		StringUnmarshaler
	}
)
55 |
56 | // NewStringable returns a Stringable from the given reflect.Value.
57 | // We assume that the given reflect.Value is a non-nil pointer to a value.
58 | // It will panic if the given reflect.Value is not a pointer.
59 | func NewStringable(rv reflect.Value) (Stringable, error) {
60 | baseType := rv.Type().Elem()
61 | if adapt, ok := builtinStringableAdaptors[baseType]; ok {
62 | return adapt(rv.Interface())
63 | } else {
64 | return nil, UnsupportedType(baseType)
65 | }
66 | }
67 |
68 | func builtinStringable[T any](adaptor StringableAdaptor[T]) {
69 | builtinStringableAdaptors[TypeOf[T]()] = NewAnyStringableAdaptor[T](adaptor)
70 | }
71 |
// String adapts a string to Stringable; both conversions are the identity
// and never fail.
type String string

func (sv String) ToString() (string, error) {
	s := string(sv)
	return s, nil
}

func (sv *String) FromString(s string) error {
	*sv = String(s)
	return nil
}
82 |
// Bool adapts a bool to Stringable. FromString accepts exactly what
// strconv.ParseBool accepts and leaves the value unchanged on failure.
type Bool bool

func (bv Bool) ToString() (string, error) {
	s := strconv.FormatBool(bool(bv))
	return s, nil
}

func (bv *Bool) FromString(s string) error {
	parsed, err := strconv.ParseBool(s)
	if err == nil {
		*bv = Bool(parsed)
	}
	return err
}
97 |
// Int adapts an int to Stringable using base-10 text. FromString goes through
// strconv.Atoi and leaves the value unchanged on failure.
type Int int

func (iv Int) ToString() (string, error) {
	s := strconv.FormatInt(int64(iv), 10)
	return s, nil
}

func (iv *Int) FromString(s string) error {
	parsed, err := strconv.Atoi(s)
	if err == nil {
		*iv = Int(parsed)
	}
	return err
}

// Int8 adapts an int8 to Stringable; out-of-range input is rejected.
type Int8 int8

func (iv Int8) ToString() (string, error) {
	s := strconv.FormatInt(int64(iv), 10)
	return s, nil
}

func (iv *Int8) FromString(s string) error {
	parsed, err := strconv.ParseInt(s, 10, 8)
	if err == nil {
		*iv = Int8(parsed)
	}
	return err
}

// Int16 adapts an int16 to Stringable; out-of-range input is rejected.
type Int16 int16

func (iv Int16) ToString() (string, error) {
	s := strconv.FormatInt(int64(iv), 10)
	return s, nil
}

func (iv *Int16) FromString(s string) error {
	parsed, err := strconv.ParseInt(s, 10, 16)
	if err == nil {
		*iv = Int16(parsed)
	}
	return err
}

// Int32 adapts an int32 to Stringable; out-of-range input is rejected.
type Int32 int32

func (iv Int32) ToString() (string, error) {
	s := strconv.FormatInt(int64(iv), 10)
	return s, nil
}

func (iv *Int32) FromString(s string) error {
	parsed, err := strconv.ParseInt(s, 10, 32)
	if err == nil {
		*iv = Int32(parsed)
	}
	return err
}

// Int64 adapts an int64 to Stringable; out-of-range input is rejected.
type Int64 int64

func (iv Int64) ToString() (string, error) {
	s := strconv.FormatInt(int64(iv), 10)
	return s, nil
}

func (iv *Int64) FromString(s string) error {
	parsed, err := strconv.ParseInt(s, 10, 64)
	if err == nil {
		*iv = Int64(parsed)
	}
	return err
}
172 |
173 | type Uint uint
174 |
175 | func (uv Uint) ToString() (string, error) {
176 | return strconv.FormatUint(uint64(uv), 10), nil
177 | }
178 |
179 | func (uv *Uint) FromString(s string) error {
180 | v, err := strconv.ParseUint(s, 10, 64)
181 | if err != nil {
182 | return err
183 | }
184 | *uv = Uint(v)
185 | return nil
186 | }
187 |
// Uint8 adapts a uint8 to Stringable; out-of-range input is rejected.
type Uint8 uint8

func (uv Uint8) ToString() (string, error) {
	s := strconv.FormatUint(uint64(uv), 10)
	return s, nil
}

func (uv *Uint8) FromString(s string) error {
	parsed, err := strconv.ParseUint(s, 10, 8)
	if err == nil {
		*uv = Uint8(parsed)
	}
	return err
}

// Uint16 adapts a uint16 to Stringable; out-of-range input is rejected.
type Uint16 uint16

func (uv Uint16) ToString() (string, error) {
	s := strconv.FormatUint(uint64(uv), 10)
	return s, nil
}

func (uv *Uint16) FromString(s string) error {
	parsed, err := strconv.ParseUint(s, 10, 16)
	if err == nil {
		*uv = Uint16(parsed)
	}
	return err
}

// Uint32 adapts a uint32 to Stringable; out-of-range input is rejected.
type Uint32 uint32

func (uv Uint32) ToString() (string, error) {
	s := strconv.FormatUint(uint64(uv), 10)
	return s, nil
}

func (uv *Uint32) FromString(s string) error {
	parsed, err := strconv.ParseUint(s, 10, 32)
	if err == nil {
		*uv = Uint32(parsed)
	}
	return err
}

// Uint64 adapts a uint64 to Stringable; out-of-range input is rejected.
type Uint64 uint64

func (uv Uint64) ToString() (string, error) {
	s := strconv.FormatUint(uint64(uv), 10)
	return s, nil
}

func (uv *Uint64) FromString(s string) error {
	parsed, err := strconv.ParseUint(s, 10, 64)
	if err == nil {
		*uv = Uint64(parsed)
	}
	return err
}
247 |
// Float32 adapts a float32 to Stringable. ToString uses plain decimal
// notation ('f') with the shortest precision that round-trips at 32 bits.
type Float32 float32

func (fv Float32) ToString() (string, error) {
	s := strconv.FormatFloat(float64(fv), 'f', -1, 32)
	return s, nil
}

func (fv *Float32) FromString(s string) error {
	parsed, err := strconv.ParseFloat(s, 32)
	if err == nil {
		*fv = Float32(parsed)
	}
	return err
}

// Float64 adapts a float64 to Stringable. ToString uses plain decimal
// notation ('f') with the shortest precision that round-trips at 64 bits.
type Float64 float64

func (fv Float64) ToString() (string, error) {
	s := strconv.FormatFloat(float64(fv), 'f', -1, 64)
	return s, nil
}

func (fv *Float64) FromString(s string) error {
	parsed, err := strconv.ParseFloat(s, 64)
	if err == nil {
		*fv = Float64(parsed)
	}
	return err
}
277 |
// Complex64 adapts a complex64 to Stringable using strconv's complex format,
// e.g. "(1+2i)".
type Complex64 complex64

func (cv Complex64) ToString() (string, error) {
	s := strconv.FormatComplex(complex128(cv), 'f', -1, 64)
	return s, nil
}

func (cv *Complex64) FromString(s string) error {
	parsed, err := strconv.ParseComplex(s, 64)
	if err == nil {
		*cv = Complex64(parsed)
	}
	return err
}

// Complex128 adapts a complex128 to Stringable using strconv's complex
// format, e.g. "(1+2i)".
type Complex128 complex128

func (cv Complex128) ToString() (string, error) {
	s := strconv.FormatComplex(complex128(cv), 'f', -1, 128)
	return s, nil
}

func (cv *Complex128) FromString(s string) error {
	parsed, err := strconv.ParseComplex(s, 128)
	if err == nil {
		*cv = Complex128(parsed)
	}
	return err
}
307 |
308 | type Time time.Time
309 |
310 | func (tv Time) ToString() (string, error) {
311 | return time.Time(tv).UTC().Format(time.RFC3339Nano), nil
312 | }
313 |
314 | func (tv *Time) FromString(s string) error {
315 | if t, err := DecodeTime(s); err != nil {
316 | return err
317 | } else {
318 | *tv = Time(t)
319 | return nil
320 | }
321 | }
322 |
323 | var reUnixtime = regexp.MustCompile(`^\d+(\.\d{1,9})?$`)
324 |
325 | // DecodeTime parses data bytes as time.Time in UTC timezone.
326 | // Supported formats of the data bytes are:
327 | // 1. RFC3339Nano string, e.g. "2006-01-02T15:04:05-07:00".
328 | // 2. Date string, e.g. "2006-01-02".
329 | // 3. Unix timestamp, e.g. "1136239445", "1136239445.8", "1136239445.812738".
330 | func DecodeTime(value string) (time.Time, error) {
331 | // Try parsing value as RFC3339 format.
332 | if t, err := time.ParseInLocation(time.RFC3339Nano, value, time.UTC); err == nil {
333 | return t.UTC(), nil
334 | }
335 |
336 | // Try parsing value as date format.
337 | if t, err := time.ParseInLocation("2006-01-02", value, time.UTC); err == nil {
338 | return t.UTC(), nil
339 | }
340 |
341 | // Try parsing value as timestamp, both integer and float formats supported.
342 | // e.g. "1618974933", "1618974933.284368".
343 | if reUnixtime.MatchString(value) {
344 | return DecodeUnixtime(value)
345 | }
346 |
347 | return time.Time{}, errors.New("invalid time value")
348 | }
349 |
350 | // value must be valid unix timestamp, matches reUnixtime.
351 | func DecodeUnixtime(value string) (time.Time, error) {
352 | parts := strings.Split(value, ".")
353 | // Note: errors are ignored, since we already validated the value.
354 | sec, _ := strconv.ParseInt(parts[0], 10, 64)
355 | var nsec int64
356 | if len(parts) == 2 {
357 | nsec, _ = strconv.ParseInt(nanoSecondPrecision(parts[1]), 10, 64)
358 | }
359 | return time.Unix(sec, nsec).UTC(), nil
360 | }
361 |
362 | func nanoSecondPrecision(value string) string {
363 | return value + strings.Repeat("0", 9-len(value))
364 | }
365 |
// ByteSlice adapts a []byte to Stringable as base64 text.
// NOTE: this uses base64.StdEncoding (padded, '+' and '/'), not URLEncoding.
// FromString leaves the value unchanged when the input is not valid base64.
type ByteSlice []byte

func (bs ByteSlice) ToString() (string, error) {
	encoded := base64.StdEncoding.EncodeToString(bs)
	return encoded, nil
}

func (bs *ByteSlice) FromString(s string) error {
	decoded, err := base64.StdEncoding.DecodeString(s)
	if err == nil {
		*bs = ByteSlice(decoded)
	}
	return err
}
382 |
383 | func UnsupportedType(rt reflect.Type) error {
384 | return fmt.Errorf("%w: %v", ErrUnsupportedType, rt)
385 | }
386 |
--------------------------------------------------------------------------------
/internal/stringable_test.go:
--------------------------------------------------------------------------------
1 | package internal
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 | "testing"
7 | "time"
8 |
9 | "github.com/stretchr/testify/assert"
10 | "github.com/ggicci/owl"
11 | )
12 |
13 | func TestNewStringable_string(t *testing.T) {
14 | var s string = "hello"
15 | rvString := reflect.ValueOf(s)
16 | assert.Panics(t, func() {
17 | NewStringable(rvString)
18 | })
19 |
20 | rvStringPointer := reflect.ValueOf(&s)
21 | sv, err := NewStringable(rvStringPointer)
22 | assert.NoError(t, err)
23 | got, err := sv.ToString()
24 | assert.NoError(t, err)
25 | assert.Equal(t, "hello", got)
26 | sv.FromString("world")
27 | assert.Equal(t, "world", s)
28 | }
29 |
30 | func TestNewStringable_bool(t *testing.T) {
31 | var b bool = true
32 | rvBool := reflect.ValueOf(b)
33 | assert.Panics(t, func() {
34 | NewStringable(rvBool)
35 | })
36 |
37 | rvBoolPointer := reflect.ValueOf(&b)
38 | sv, err := NewStringable(rvBoolPointer)
39 | assert.NoError(t, err)
40 | got, err := sv.ToString()
41 | assert.NoError(t, err)
42 | assert.Equal(t, "true", got)
43 | sv.FromString("false")
44 | assert.Equal(t, false, b)
45 |
46 | assert.Error(t, sv.FromString("hello"))
47 | }
48 |
49 | func TestNewStringable_int(t *testing.T) {
50 | testInteger[int](t, 2045, "hello")
51 | }
52 |
53 | func TestNewStringable_int8(t *testing.T) {
54 | testInteger[int8](t, int8(127), "128")
55 | }
56 |
57 | func TestNewStringable_int16(t *testing.T) {
58 | testInteger[int16](t, int16(32767), "32768")
59 | }
60 |
61 | func TestNewStringable_int32(t *testing.T) {
62 | testInteger[int32](t, int32(2147483647), "2147483648")
63 | }
64 |
65 | func TestNewStringable_int64(t *testing.T) {
66 | testInteger[int64](t, int64(9223372036854775807), "9223372036854775808")
67 | }
68 |
69 | func TestNewStringable_uint(t *testing.T) {
70 | testInteger[uint](t, uint(2045), "-1")
71 | }
72 |
73 | func TestNewStringable_uint8(t *testing.T) {
74 | testInteger[uint8](t, uint8(255), "256")
75 | }
76 |
77 | func TestNewStringable_uint16(t *testing.T) {
78 | testInteger[uint16](t, uint16(65535), "65536")
79 | }
80 |
81 | func TestNewStringable_uint32(t *testing.T) {
82 | testInteger[uint32](t, uint32(4294967295), "4294967296")
83 | }
84 |
85 | func TestNewStringable_uint64(t *testing.T) {
86 | testInteger[uint64](t, uint64(18446744073709551615), "18446744073709551616")
87 | }
88 |
89 | func TestNewStringable_float32(t *testing.T) {
90 | testInteger[float32](t, float32(3.1415926), "hello")
91 | }
92 |
93 | func TestNewStringable_float64(t *testing.T) {
94 | testInteger[float64](t, float64(3.14159265358979323846264338327950288419716939937510582097494459), "hello")
95 | }
96 |
97 | func TestNewStringable_complex64(t *testing.T) {
98 | testInteger[complex64](t, complex64(3.1415926+2.71828i), "hello")
99 | }
100 |
101 | func TestNewStringable_complex128(t *testing.T) {
102 | testInteger[complex128](t, complex128(3.14159265358979323846264338327950288419716939937510582097494459+2.71828182845904523536028747135266249775724709369995957496696763i), "hello")
103 | }
104 |
105 | func TestNewStringable_Time(t *testing.T) {
106 | var now = time.Now()
107 | rvTime := reflect.ValueOf(now)
108 | assert.Panics(t, func() {
109 | NewStringable(rvTime)
110 | })
111 |
112 | rvTimePointer := reflect.ValueOf(&now)
113 | sv, err := NewStringable(rvTimePointer)
114 | assert.NoError(t, err)
115 |
116 | // RFC3339Nano
117 | testTime(t, sv, "1991-11-10T08:00:00+08:00", time.Date(1991, 11, 10, 8, 0, 0, 0, time.FixedZone("Asia/Shanghai", +8*3600)), "1991-11-10T00:00:00Z")
118 | // Date string
119 | testTime(t, sv, "1991-11-10", time.Date(1991, 11, 10, 0, 0, 0, 0, time.UTC), "1991-11-10T00:00:00Z")
120 |
121 | // Unix timestamp
122 | testTime(t, sv, "678088800", time.Date(1991, 6, 28, 6, 0, 0, 0, time.UTC), "1991-06-28T06:00:00Z")
123 |
124 | // Unix timestamp fraction
125 | testTime(t, sv, "678088800.123456789", time.Date(1991, 6, 28, 6, 0, 0, 123456789, time.UTC), "1991-06-28T06:00:00.123456789Z")
126 |
127 | // Unsupported format
128 | assert.Error(t, sv.FromString("hello"))
129 | }
130 |
131 | func TestNewStringable_ByteSlice(t *testing.T) {
132 | var b []byte = []byte("hello")
133 | rvByteSlice := reflect.ValueOf(b)
134 | assert.NotPanics(t, func() {
135 | NewStringable(rvByteSlice)
136 | })
137 |
138 | rvByteSlicePointer := reflect.ValueOf(&b)
139 | sv, err := NewStringable(rvByteSlicePointer)
140 | assert.NoError(t, err)
141 | got, err := sv.ToString()
142 | assert.NoError(t, err)
143 | assert.Equal(t, "aGVsbG8=", got)
144 |
145 | sv.FromString("d29ybGQ=")
146 | assert.Equal(t, []byte("world"), b)
147 |
148 | assert.Error(t, sv.FromString("hello"))
149 | }
150 |
151 | func TestNewStringable_ErrUnsupportedType(t *testing.T) {
152 | type MyStruct struct{ Name string }
153 | var s MyStruct
154 | rvStruct := reflect.ValueOf(s)
155 | assert.Panics(t, func() {
156 | NewStringable(rvStruct)
157 | })
158 | rvStructPointer := reflect.ValueOf(&s)
159 | sv, err := NewStringable(rvStructPointer)
160 | assert.ErrorIs(t, err, owl.ErrUnsupportedType)
161 | assert.Nil(t, sv)
162 | }
163 |
// Numeric is the set of built-in numeric types exercised by testInteger.
type Numeric interface {
	int | int8 | int16 | int32 | int64 |
		uint | uint8 | uint16 | uint32 | uint64 |
		float32 | float64 |
		complex64 | complex128
}
167 |
168 | func testInteger[T Numeric](t *testing.T, vSuccess T, invalidStr string) {
169 | rv := reflect.ValueOf(vSuccess)
170 | assert.Panics(t, func() {
171 | NewStringable(rv)
172 | })
173 |
174 | rvPointer := reflect.ValueOf(&vSuccess)
175 | sv, err := NewStringable(rvPointer)
176 | assert.NoError(t, err)
177 | got, err := sv.ToString()
178 | assert.NoError(t, err)
179 | assert.Equal(t, fmt.Sprintf("%v", vSuccess), got)
180 | sv.FromString("2")
181 | assert.Equal(t, T(2), vSuccess)
182 |
183 | assert.Error(t, sv.FromString(invalidStr))
184 | }
185 |
186 | func testTime(t *testing.T, sv Stringable, fromStr string, expected time.Time, expectedToStr string) {
187 | assert.NoError(t, sv.FromString(fromStr))
188 | assert.True(t, equalTime(expected, time.Time(*sv.(*Time))))
189 | ts, err := sv.ToString()
190 | assert.NoError(t, err)
191 | assert.Equal(t, expectedToStr, ts)
192 | }
193 |
194 | func equalTime(expected, actual time.Time) bool {
195 | return expected.UTC() == actual.UTC()
196 | }
197 |
--------------------------------------------------------------------------------
/internal/stringableadaptor.go:
--------------------------------------------------------------------------------
1 | package internal
2 |
3 | import "fmt"
4 |
5 | type StringableAdaptor[T any] func(*T) (Stringable, error)
6 | type AnyStringableAdaptor func(any) (Stringable, error)
7 |
8 | func NewAnyStringableAdaptor[T any](adapt StringableAdaptor[T]) AnyStringableAdaptor {
9 | return func(v any) (Stringable, error) {
10 | if cv, ok := v.(*T); ok {
11 | return adapt(cv)
12 | } else {
13 | return nil, fmt.Errorf("%w: cannot convert %T to %s", ErrTypeMismatch, v, TypeOf[*T]())
14 | }
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/internal/stringableadaptor_test.go:
--------------------------------------------------------------------------------
1 | package internal
2 |
3 | import (
4 | "errors"
5 | "strings"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
// YesNo is a test Stringable for bools that uses "yes"/"no" as its text form.
// Parsing is case-insensitive; anything else fails with "invalid value".
type YesNo bool

func (yn YesNo) ToString() (string, error) {
	text := "no"
	if yn {
		text = "yes"
	}
	return text, nil
}

func (yn *YesNo) FromString(s string) error {
	lower := strings.ToLower(s)
	if lower == "yes" {
		*yn = true
		return nil
	}
	if lower == "no" {
		*yn = false
		return nil
	}
	return errors.New("invalid value")
}
32 |
33 | func TestToAnyStringableAdaptor(t *testing.T) {
34 | adaptor := NewAnyStringableAdaptor[bool](func(b *bool) (Stringable, error) {
35 | return (*YesNo)(b), nil
36 | })
37 |
38 | var validBoolean bool = true
39 | stringable, err := adaptor(&validBoolean)
40 | assert.NoError(t, err)
41 | v, err := stringable.ToString()
42 | assert.NoError(t, err)
43 | assert.Equal(t, "yes", v)
44 | assert.NoError(t, stringable.FromString("no"))
45 | assert.False(t, validBoolean)
46 |
47 | var invalidType int = 0
48 | stringable, err = adaptor(&invalidType)
49 | assert.ErrorIs(t, err, ErrTypeMismatch)
50 | assert.Nil(t, stringable)
51 | assert.ErrorContains(t, err, "cannot convert *int to *bool")
52 | }
53 |
--------------------------------------------------------------------------------
/patch/patch.go:
--------------------------------------------------------------------------------
1 | package patch
2 |
3 | import (
4 | "bytes"
5 | "encoding/json"
6 | )
7 |
// Field wraps a value of type T and records whether it was decoded from JSON.
//
// Valid is true once UnmarshalJSON has populated Value from non-null data.
// A key that is absent from the input and a key whose value is JSON null are
// both treated as "no data provided" and leave the Field untouched. This lets
// partial-update payloads tell "not sent" apart from "sent as the zero value".
type Field[T any] struct {
	Value T
	Valid bool
}

// MarshalJSON encodes Value when Valid is true and JSON null otherwise.
func (f Field[T]) MarshalJSON() ([]byte, error) {
	if !f.Valid {
		return []byte("null"), nil
	}
	return json.Marshal(f.Value)
}

// UnmarshalJSON decodes data into Value and marks the Field as Valid.
//
// A JSON null is a no-op, following the encoding/json convention for
// Unmarshalers. The Field is left exactly as it was, just as if the key had
// been missing. Decoding the null into Value would otherwise reset a pointer,
// slice, map or interface Value to nil while a previously set Valid stayed true.
//
// If decoding fails, the error is returned and Valid is left unchanged. Value
// may have been partially written in that case.
func (f *Field[T]) UnmarshalJSON(data []byte) error {
	// encoding/json passes the bare literal, but direct callers may not.
	if bytes.Equal(bytes.TrimSpace(data), []byte("null")) {
		return nil
	}
	if err := json.Unmarshal(data, &f.Value); err != nil {
		return err
	}
	f.Valid = true
	return nil
}
30 |
--------------------------------------------------------------------------------
/patch/patch_test.go:
--------------------------------------------------------------------------------
1 | package patch_test
2 |
3 | import (
4 | "encoding/json"
5 | "reflect"
6 | "testing"
7 | "time"
8 |
9 | "github.com/ggicci/httpin/patch"
10 | )
11 |
// shouldBeNil marks the test as failed, without stopping it, when err is
// non-nil. It logs failMessage together with the error.
func shouldBeNil(t *testing.T, err error, failMessage string) {
	t.Helper() // attribute failures to the caller's line, not this helper
	if err != nil {
		t.Errorf("%s, got error: %v", failMessage, err)
	}
}
18 |
// shouldResemble marks the test as failed, without stopping it, unless
// expected and got are deeply equal according to reflect.DeepEqual.
// The argument order only affects the failure message.
func shouldResemble(t *testing.T, expected, got any, failMessage string) {
	t.Helper() // attribute failures to the caller's line, not this helper
	if !reflect.DeepEqual(expected, got) {
		t.Errorf("%s, expected %#v, got %#v", failMessage, expected, got)
	}
}
26 |
// fixedZone returns the *time.Location for an offset given in seconds east of
// UTC. A zero offset maps to time.UTC; anything else maps to an unnamed
// fixed zone.
//
// NOTE(review): time.Parse returns time.Local instead of a fabricated fixed
// zone when the parsed offset matches the local zone at that instant. On such
// machines, expectations built with this helper may not reflect.DeepEqual the
// decoded value — confirm if those tests flake.
func fixedZone(offset int) *time.Location {
	if offset != 0 {
		return time.FixedZone("", offset)
	}
	return time.UTC
}
37 |
38 | func testJSONMarshalling(t *testing.T, tc testcase) {
39 | bs, err := json.Marshal(tc.Expected)
40 | if err != nil {
41 | t.Logf("marshal failed, got error: %v", err)
42 | t.Fail()
43 | }
44 | if string(bs) != tc.Content {
45 | t.Logf("marshal failed, expected %q, got %q", tc.Content, string(bs))
46 | t.Fail()
47 | }
48 | }
49 |
50 | func testJSONUnmarshalling(t *testing.T, tc testcase) {
51 | rt := reflect.TypeOf(tc.Expected) // type: patch.Field
52 | rv := reflect.New(rt) // rv: *patch.Field
53 |
54 | shouldBeNil(t, json.Unmarshal([]byte(tc.Content), rv.Interface()), "unmarshal failed")
55 | shouldResemble(t, rv.Elem().Interface(), tc.Expected, "unmarshal failed")
56 | }
57 |
// testcase pairs a JSON document with the Go value it corresponds to. Tests
// check decoding from Content, encoding to Content, or both.
type testcase struct {
	Content  string // JSON text
	Expected any    // value of the concrete type under test
}
62 |
// Field names such as Id and AvatarUrl are kept as-is: test literals use them,
// and Account's untagged fields are encoded under their Go names.
type (
	// GitHubProfile holds a few GitHub profile fields under snake_case JSON keys.
	GitHubProfile struct {
		Id        int64  `json:"id"`
		Login     string `json:"login"`
		AvatarUrl string `json:"avatar_url"`
	}

	// GenderType is a string-backed gender label.
	GenderType string

	// Account is the sample resource used by these tests. Its fields carry no
	// JSON tags, so they are encoded as "Id", "Email", "Tags", "Gender", "GitHub".
	Account struct {
		Id     int64
		Email  string
		Tags   []string
		Gender GenderType
		GitHub *GitHubProfile
	}
)
78 |
79 | type AccountPatch struct {
80 | Email patch.Field[string] `json:"email"`
81 | Tags patch.Field[[]string] `json:"tags"`
82 | Gender patch.Field[GenderType] `json:"gender"`
83 | GitHub patch.Field[*GitHubProfile] `json:"github"`
84 | }
85 |
86 | func TestField(t *testing.T) {
87 | var cases = []testcase{
88 | {"true", patch.Field[bool]{true, true}},
89 | {"false", patch.Field[bool]{false, true}},
90 | {"2045", patch.Field[int]{2045, true}},
91 | {"127", patch.Field[int8]{127, true}},
92 | {"32767", patch.Field[int16]{32767, true}},
93 | {"2147483647", patch.Field[int32]{2147483647, true}},
94 | {"9223372036854775807", patch.Field[int64]{9223372036854775807, true}},
95 | {"2045", patch.Field[uint]{2045, true}},
96 | {"255", patch.Field[uint8]{255, true}},
97 | {"65535", patch.Field[uint16]{65535, true}},
98 | {"4294967295", patch.Field[uint32]{4294967295, true}},
99 | {"18446744073709551615", patch.Field[uint64]{18446744073709551615, true}},
100 | {"3.14", patch.Field[float32]{3.14, true}},
101 | {"3.14", patch.Field[float64]{3.14, true}},
102 | {"\"hello\"", patch.Field[string]{"hello", true}},
103 |
104 | // Array
105 | {`[true,false]`, patch.Field[[]bool]{[]bool{true, false}, true}},
106 | {"[1,2,3]", patch.Field[[]int]{[]int{1, 2, 3}, true}},
107 | {"[1,2,3]", patch.Field[[]int8]{[]int8{1, 2, 3}, true}},
108 | {"[1,2,3]", patch.Field[[]int16]{[]int16{1, 2, 3}, true}},
109 | {"[1,2,3]", patch.Field[[]int32]{[]int32{1, 2, 3}, true}},
110 | {"[1,2,3]", patch.Field[[]int64]{[]int64{1, 2, 3}, true}},
111 | {"[1,2,3]", patch.Field[[]uint]{[]uint{1, 2, 3}, true}},
112 | // NOTE(ggicci): []uint8 is a special case, check TestFieldUint8Array
113 | {"[1,2,3]", patch.Field[[]uint16]{[]uint16{1, 2, 3}, true}},
114 | {"[1,2,3]", patch.Field[[]uint32]{[]uint32{1, 2, 3}, true}},
115 | {"[1,2,3]", patch.Field[[]uint64]{[]uint64{1, 2, 3}, true}},
116 | {"[0.618,1,3.14]", patch.Field[[]float32]{[]float32{0.618, 1, 3.14}, true}},
117 | {"[0.618,1,3.14]", patch.Field[[]float64]{[]float64{0.618, 1, 3.14}, true}},
118 | {`["hello","world"]`, patch.Field[[]string]{[]string{"hello", "world"}, true}},
119 |
120 | // time.Time
121 | {
122 | `"2019-08-25T07:19:34Z"`,
123 | patch.Field[time.Time]{
124 | time.Date(2019, 8, 25, 7, 19, 34, 0, fixedZone(0)),
125 | true,
126 | },
127 | },
128 | {
129 | `"1991-11-10T08:00:00-07:00"`,
130 | patch.Field[time.Time]{
131 | time.Date(1991, 11, 10, 8, 0, 0, 0, fixedZone(-7*3600)),
132 | true,
133 | },
134 | },
135 | {
136 | `"1991-11-10T08:00:00+08:00"`,
137 | patch.Field[time.Time]{
138 | time.Date(1991, 11, 10, 8, 0, 0, 0, fixedZone(+8*3600)),
139 | true,
140 | },
141 | },
142 |
143 | // Custom structs
144 | {
145 | `{"Id":1000,"Email":"ggicci@example.com","Tags":["developer","修勾"],"Gender":"male","GitHub":{"id":3077555,"login":"ggicci","avatar_url":"https://avatars.githubusercontent.com/u/3077555?v=4"}}`,
146 | patch.Field[*Account]{
147 | &Account{
148 | Id: 1000,
149 | Email: "ggicci@example.com",
150 | Tags: []string{"developer", "修勾"},
151 | Gender: "male",
152 | GitHub: &GitHubProfile{
153 | Id: 3077555,
154 | Login: "ggicci",
155 | AvatarUrl: "https://avatars.githubusercontent.com/u/3077555?v=4",
156 | },
157 | },
158 | true,
159 | },
160 | },
161 | }
162 |
163 | for _, c := range cases {
164 | testJSONMarshalling(t, c)
165 | testJSONUnmarshalling(t, c)
166 | }
167 | }
168 |
169 | // TestFieldUint8Array runs JSON marshalling & unmarshalling tests on type Field[[]uint8].
170 | // Because in golang's encoding/json package, encoding uint8[] is special.
171 | // See: https://golang.org/pkg/encoding/json/#Marshal
172 | //
173 | // > Array and slice values encode as JSON arrays, except that []byte encodes
174 | // as a base64-encoded string, and a nil slice encodes as the null JSON
175 | // value.
176 | //
177 | // uint8 the set of all unsigned 8-bit integers (0 to 255)
178 | // byte alias for uint8
179 | func TestFieldUint8Array(t *testing.T) {
180 | var a1 patch.Field[[]uint8]
181 | // unmarshal
182 | shouldBeNil(t, json.Unmarshal([]byte("[1,2,3]"), &a1), "unmarshal Field[[]uint8] failed")
183 | shouldResemble(t, patch.Field[[]uint8]{[]uint8{1, 2, 3}, true}, a1, "unmarshal Field[[]uint8] failed")
184 |
185 | // marshal
186 | var a2 = patch.Field[[]uint8]{[]uint8{1, 2, 3}, true}
187 | out, err := json.Marshal(a2)
188 | shouldBeNil(t, err, "marshal Field[[]uint8] failed")
189 | shouldResemble(t, `"AQID"`, string(out), "marshal Field[[]uint8] failed")
190 | }
191 |
192 | func TestField_UnmarshalJSON_Struct(t *testing.T) {
193 | var testcases = []testcase{
194 | {
195 | `{"email":"ggicci.2@example.com","tags":["artist","photographer"]}`,
196 | AccountPatch{
197 | Email: patch.Field[string]{"ggicci.2@example.com", true},
198 | Gender: patch.Field[GenderType]{"", false},
199 | Tags: patch.Field[[]string]{[]string{"artist", "photographer"}, true},
200 | GitHub: patch.Field[*GitHubProfile]{nil, false},
201 | },
202 | },
203 | {
204 | `{"tags":null,"gender":"female","github":{"id":100,"login":"ggicci.2","avatar_url":null}}`,
205 | AccountPatch{
206 | Email: patch.Field[string]{"", false},
207 | Gender: patch.Field[GenderType]{"female", true},
208 | Tags: patch.Field[[]string]{nil, false},
209 | GitHub: patch.Field[*GitHubProfile]{&GitHubProfile{
210 | Id: 100,
211 | Login: "ggicci.2",
212 | AvatarUrl: "",
213 | }, true},
214 | },
215 | },
216 | }
217 |
218 | for _, c := range testcases {
219 | testJSONUnmarshalling(t, c)
220 | }
221 | }
222 |
223 | func TestField_MarshalJSON_Struct(t *testing.T) {
224 | var testcases = []testcase{
225 | {
226 | `{"email":"hello","tags":null,"gender":null,"github":null}`,
227 | AccountPatch{
228 | Email: patch.Field[string]{"hello", true},
229 | Tags: patch.Field[[]string]{nil, false},
230 | Gender: patch.Field[GenderType]{"", false},
231 | GitHub: patch.Field[*GitHubProfile]{nil, false},
232 | },
233 | },
234 | }
235 |
236 | for _, c := range testcases {
237 | testJSONMarshalling(t, c)
238 | }
239 | }
240 |
--------------------------------------------------------------------------------