├── LICENSE ├── README.md ├── chi.go ├── decoder.go ├── decoder_test.go ├── encoder.go ├── error.go ├── go.mod ├── go.sum ├── hrt.go ├── hrt_example_get_test.go ├── hrt_example_post_test.go ├── hrt_test.go └── internal ├── ht └── ht.go └── rfutil └── rfutil.go /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 diamondburned 2 | 3 | Permission to use, copy, modify, and/or distribute this software for any purpose 4 | with or without fee is hereby granted, provided that the above copyright notice 5 | and this permission notice appear in all copies. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS” AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH 8 | REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, 10 | INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS 11 | OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER 12 | TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF 13 | THIS SOFTWARE. 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **Package hrt** implements a type-safe HTTP router. It aids in creating a 2 | uniform API interface while making it easier to create API handlers. 3 | 4 | HRT stands for (H)TTP (r)outer with (t)ypes. 5 | 6 | ## Example 7 | 8 | Below is a trimmed down version of the Get example in the GoDoc. 
9 | 10 | ```go 11 | type EchoRequest struct { 12 | What string `query:"what"` 13 | } 14 | 15 | type EchoResponse struct { 16 | What string `json:"what"` 17 | } 18 | 19 | func handleEcho(ctx context.Context, req EchoRequest) (EchoResponse, error) { 20 | return EchoResponse{What: req.What}, nil 21 | } 22 | ``` 23 | 24 | ```go 25 | r := hrt.NewPlainRouter() 26 | r.Use(hrt.Use(hrt.Opts{ 27 | Encoder: hrt.JSONEncoder, 28 | ErrorWriter: hrt.JSONErrorWriter("error"), 29 | })) 30 | 31 | r.Get("/echo", hrt.Wrap(handleEcho)) 32 | ``` 33 | 34 | ## Documentation 35 | 36 | For documentation and examples, see [GoDoc](https://godoc.org/libdb.so/hrt/v2). 37 | 38 | ## Dependencies 39 | 40 | HRT depends on [chi v5](https://pkg.go.dev/github.com/go-chi/chi/v5) for URL 41 | parameters when routing. Apps that use HRT should also use chi for routing. 42 | 43 | Note that it is still possible to make a custom URL parameter decoder that would 44 | replace chi's, but it is not recommended. 45 | -------------------------------------------------------------------------------- /chi.go: -------------------------------------------------------------------------------- 1 | package hrt 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/go-chi/chi/v5" 7 | ) 8 | 9 | // Router redefines [chi.Router] to modify all method-routing functions to 10 | // accept an [http.Handler] instead of a [http.HandlerFunc]. 11 | type Router interface { 12 | http.Handler 13 | chi.Routes 14 | 15 | // Use appends one or more middlewares onto the Router stack. 16 | Use(middlewares ...func(http.Handler) http.Handler) 17 | 18 | // With adds inline middlewares for an endpoint handler. 19 | With(middlewares ...func(http.Handler) http.Handler) Router 20 | 21 | // Group adds a new inline-Router along the current routing 22 | // path, with a fresh middleware stack for the inline-Router. 23 | Group(fn func(r Router)) Router 24 | 25 | // Route mounts a sub-Router along a `pattern` string. 
26 | Route(pattern string, fn func(r Router)) Router 27 | 28 | // Mount attaches another http.Handler along ./pattern/* 29 | Mount(pattern string, h http.Handler) 30 | 31 | // Handle and HandleFunc adds routes for `pattern` that matches 32 | // all HTTP methods. 33 | Handle(pattern string, h http.Handler) 34 | HandleFunc(pattern string, h http.HandlerFunc) 35 | 36 | // Method and MethodFunc adds routes for `pattern` that matches 37 | // the `method` HTTP method. 38 | Method(method, pattern string, h http.Handler) 39 | MethodFunc(method, pattern string, h http.HandlerFunc) 40 | 41 | // HTTP-method routing along `pattern` 42 | Connect(pattern string, h http.Handler) 43 | Delete(pattern string, h http.Handler) 44 | Get(pattern string, h http.Handler) 45 | Head(pattern string, h http.Handler) 46 | Options(pattern string, h http.Handler) 47 | Patch(pattern string, h http.Handler) 48 | Post(pattern string, h http.Handler) 49 | Put(pattern string, h http.Handler) 50 | Trace(pattern string, h http.Handler) 51 | 52 | // NotFound defines a handler to respond whenever a route could 53 | // not be found. 54 | NotFound(h http.HandlerFunc) 55 | 56 | // MethodNotAllowed defines a handler to respond whenever a method is 57 | // not allowed. 58 | MethodNotAllowed(h http.HandlerFunc) 59 | } 60 | 61 | // NewRouter creates a [chi.Router] wrapper that turns all method-routing 62 | // functions to take a regular [http.Handler] instead of an [http.HandlerFunc]. 63 | // This allows [hrt.Wrap] to function properly. This router also has the given 64 | // opts injected into its context, so there is no need to call [hrt.Use]. 65 | func NewRouter(opts Opts) Router { 66 | r := router{chi.NewRouter()} 67 | r.Use(Use(opts)) 68 | return r 69 | } 70 | 71 | // NewPlainRouter is like [NewRouter] but does not inject any options into the 72 | // context. 
73 | func NewPlainRouter() Router { 74 | return router{chi.NewRouter()} 75 | } 76 | 77 | // WrapRouter wraps a [chi.Router] to turn all method-routing functions to take 78 | // a regular [http.Handler] instead of an [http.HandlerFunc]. This allows 79 | // [hrt.Wrap] to function properly. 80 | func WrapRouter(r chi.Router) Router { 81 | return router{r} 82 | } 83 | 84 | type router struct{ chi.Router } 85 | 86 | func (r router) With(middlewares ...func(http.Handler) http.Handler) Router { 87 | return router{r.Router.With(middlewares...)} 88 | } 89 | 90 | func (r router) Group(fn func(r Router)) Router { 91 | return router{r.Router.Group(func(r chi.Router) { 92 | fn(router{r}) 93 | })} 94 | } 95 | 96 | func (r router) Route(pattern string, fn func(r Router)) Router { 97 | return router{r.Router.Route(pattern, func(r chi.Router) { 98 | fn(router{r}) 99 | })} 100 | } 101 | 102 | func (r router) Connect(pattern string, h http.Handler) { 103 | r.Router.Method("connect", pattern, h) 104 | } 105 | 106 | func (r router) Delete(pattern string, h http.Handler) { 107 | r.Router.Method("delete", pattern, h) 108 | } 109 | 110 | func (r router) Get(pattern string, h http.Handler) { 111 | r.Router.Method("get", pattern, h) 112 | } 113 | 114 | func (r router) Head(pattern string, h http.Handler) { 115 | r.Router.Method("head", pattern, h) 116 | } 117 | 118 | func (r router) Options(pattern string, h http.Handler) { 119 | r.Router.Method("options", pattern, h) 120 | } 121 | 122 | func (r router) Patch(pattern string, h http.Handler) { 123 | r.Router.Method("patch", pattern, h) 124 | } 125 | 126 | func (r router) Post(pattern string, h http.Handler) { 127 | r.Router.Method("post", pattern, h) 128 | } 129 | 130 | func (r router) Put(pattern string, h http.Handler) { 131 | r.Router.Method("put", pattern, h) 132 | } 133 | 134 | func (r router) Trace(pattern string, h http.Handler) { 135 | r.Router.Method("trace", pattern, h) 136 | } 137 | 
-------------------------------------------------------------------------------- /decoder.go: -------------------------------------------------------------------------------- 1 | package hrt 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | "reflect" 7 | "strings" 8 | 9 | "github.com/go-chi/chi/v5" 10 | "github.com/pkg/errors" 11 | "libdb.so/hrt/v2/internal/rfutil" 12 | ) 13 | 14 | // Decoder describes a decoder that decodes the request type. 15 | type Decoder interface { 16 | // Decode decodes the given value from the given request. 17 | Decode(*http.Request, any) error 18 | } 19 | 20 | // MethodDecoder is a decoder that picks the decoder registered for the request 21 | // method, falling back to the "*" entry if the method has none. 22 | type MethodDecoder map[string]Decoder 23 | 24 | // Decode implements the Decoder interface. 25 | func (e MethodDecoder) Decode(r *http.Request, v any) error { 26 | dec, ok := e[r.Method] 27 | if !ok { 28 | dec, ok = e["*"] 29 | } 30 | if !ok { 31 | return WrapHTTPError(http.StatusMethodNotAllowed, errors.New("method not allowed")) 32 | } 33 | return dec.Decode(r, v) 34 | } 35 | 36 | // URLDecoder decodes chi.URLParams and url.Values into a struct. It only does 37 | // decoding; it has no Encode method. The decoder makes no effort to 38 | // traverse the struct and decode nested structs. If neither a chi.URLParam nor 39 | // a url.Value is found for a field, the field is left untouched. 40 | // 41 | // The following tags are supported: 42 | // 43 | // - `url` - uses chi.URLParam to decode the value. 44 | // - `form` - uses r.FormValue to decode the value. 45 | // - `query` - similar to `form`. 46 | // - `schema` - similar to `form`, exists for compatibility with gorilla/schema. 47 | // - `json` - uses either chi.URLParam or r.FormValue to decode the value. 48 | // If the value is provided within the form, then it is unmarshaled as JSON 49 | // into the field unless the type is a string. 
If the value is provided within 50 | // the URL, then it is unmarshaled as a primitive value. 51 | // 52 | // If a struct field has no tag, it is assumed to be the same as the field name. 53 | // If a struct field has a tag, then only that tag is used. 54 | // 55 | // # Example 56 | // 57 | // The following Go type would be decoded to have 2 URL parameters: 58 | // 59 | // type Data struct { 60 | // ID string 61 | // Num int `url:"num"` 62 | // Nested struct { 63 | // ID string 64 | // } 65 | // } 66 | var URLDecoder Decoder = urlDecoder{} 67 | 68 | type urlDecoder struct{} 69 | 70 | // Decode implements the Decoder interface. Fields are resolved by the first 71 | // matching tag; untagged fields fall back to a case-insensitive name match. 72 | func (d urlDecoder) Decode(r *http.Request, v any) error { 73 | return rfutil.EachStructField(v, func(rft reflect.StructField, rfv reflect.Value) error { 74 | for _, tag := range []string{"form", "query", "schema"} { 75 | if tagValue := rft.Tag.Get(tag); tagValue != "" { 76 | val := r.FormValue(tagValue) 77 | return rfutil.SetPrimitiveFromString(rft.Type, rfv, val) 78 | } 79 | } 80 | 81 | if tagValue := rft.Tag.Get("url"); tagValue != "" { 82 | val := chi.URLParam(r, tagValue) 83 | return rfutil.SetPrimitiveFromString(rft.Type, rfv, val) 84 | } 85 | 86 | if tagValue := rft.Tag.Get("json"); tagValue != "" { 87 | if val := chi.URLParam(r, tagValue); val != "" { 88 | return rfutil.SetPrimitiveFromString(rft.Type, rfv, val) 89 | } 90 | 91 | val := r.FormValue(tagValue) 92 | if val == "" { 93 | // Nothing was provided: leave the field untouched instead of 94 | // failing to unmarshal an empty document. 95 | return nil 96 | } 97 | 98 | if rft.Type.Kind() == reflect.String { 99 | rfv.SetString(val) 100 | return nil 101 | } 102 | 103 | jsonValue := reflect.New(rft.Type) 104 | if err := json.Unmarshal([]byte(val), jsonValue.Interface()); err != nil { 105 | return errors.Wrap(err, "failed to unmarshal JSON") 106 | } 107 | rfv.Set(jsonValue.Elem()) 108 | 109 | // A tagged field only ever uses its tag, so stop here rather than 110 | // falling through to the name-based lookups below. 111 | return nil 112 | } 113 | 114 | // Search for the URL parameters manually. 
115 | if rctx := chi.RouteContext(r.Context()); rctx != nil { 116 | for i, k := range rctx.URLParams.Keys { 117 | if strings.EqualFold(k, rft.Name) { 118 | return rfutil.SetPrimitiveFromString(rfv.Type(), rfv, rctx.URLParams.Values[i]) 119 | } 120 | } 121 | } 122 | 123 | // Trigger form parsing. 124 | r.FormValue("") 125 | 126 | // Search for URL form values manually. 127 | for k, v := range r.Form { 128 | if len(v) > 0 && strings.EqualFold(k, rft.Name) { 129 | return rfutil.SetPrimitiveFromString(rfv.Type(), rfv, v[0]) 130 | } 131 | } 132 | 133 | return nil // ignore 134 | }) 135 | } 136 | 137 | // DecoderWithValidator wraps a decoder with one that calls Validate() on the 138 | // value after decoding if the value implements Validator. 139 | func DecoderWithValidator(enc Decoder) Decoder { 140 | return validatorDecoder{enc} 141 | } 142 | 143 | type validatorDecoder struct{ dec Decoder } 144 | 145 | func (e validatorDecoder) Decode(r *http.Request, v any) error { 146 | if err := e.dec.Decode(r, v); err != nil { 147 | return err 148 | } 149 | 150 | if validator, ok := v.(Validator); ok { 151 | if err := validator.Validate(); err != nil { 152 | return err 153 | } 154 | } 155 | 156 | return nil 157 | } 158 | -------------------------------------------------------------------------------- /decoder_test.go: -------------------------------------------------------------------------------- 1 | package hrt 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | "reflect" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestURLDecoder(t *testing.T) { 12 | type Mega struct { 13 | String string `form:"string"` 14 | Number float64 `form:"number"` 15 | Integer int `form:"integer"` 16 | Time time.Time `form:"time"` 17 | OptString *string `form:"optstring"` 18 | OptNumber *float64 `form:"optnumber"` 19 | OptInteger *int `form:"optinteger"` 20 | OptTime *time.Time `form:"opttime"` 21 | } 22 | 23 | tests := []struct { 24 | name string 25 | input url.Values 26 | expect result[Mega] 27 
| }{ 28 | { 29 | name: "only required fields", 30 | input: url.Values{ 31 | "string": {"hello"}, 32 | "number": {"3.14"}, 33 | "integer": {"42"}, 34 | "time": {"2021-01-01T00:00:00Z"}, 35 | }, 36 | expect: okResult(Mega{ 37 | String: "hello", 38 | Number: 3.14, 39 | Integer: 42, 40 | Time: time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC), 41 | }), 42 | }, 43 | { 44 | name: "only optional fields", 45 | input: url.Values{ 46 | "optstring": {"world"}, 47 | "optnumber": {"2.71"}, 48 | "optinteger": {"24"}, 49 | "opttime": {"2020-01-01T00:00:00Z"}, 50 | }, 51 | expect: okResult(Mega{ 52 | OptString: ptrTo("world"), 53 | OptNumber: ptrTo(2.71), 54 | OptInteger: ptrTo(24), 55 | OptTime: ptrTo(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), 56 | }), 57 | }, 58 | } 59 | 60 | for _, test := range tests { 61 | t.Run(test.name, func(t *testing.T) { 62 | req := &http.Request{ 63 | Form: test.input, 64 | } 65 | 66 | var got Mega 67 | err := URLDecoder.Decode(req, &got) 68 | res := combineResult(got, err) 69 | 70 | if !reflect.DeepEqual(test.expect, res) { 71 | t.Errorf("unexpected test result:\n"+ 72 | "expected: %v\n"+ 73 | "got: %v\n", test.expect, res) 74 | } 75 | }) 76 | } 77 | } 78 | 79 | type result[T any] struct { 80 | value T 81 | error string 82 | } 83 | 84 | func okResult[T any](value T) result[T] { 85 | return result[T]{value: value} 86 | } 87 | 88 | func combineResult[T any](value T, err error) result[T] { 89 | res := result[T]{value: value} 90 | if err != nil { 91 | res.error = err.Error() 92 | } 93 | return res 94 | } 95 | 96 | func ptrTo[T any](v T) *T { return &v } 97 | -------------------------------------------------------------------------------- /encoder.go: -------------------------------------------------------------------------------- 1 | package hrt 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | 7 | "github.com/pkg/errors" 8 | ) 9 | 10 | // DefaultEncoder is the default encoder used by the router. 
It decodes GET 11 | // requests using the query string and URL parameter; everything else uses JSON. 12 | // 13 | // For the sake of being RESTful, we use a URLDecoder for GET requests. 14 | // Everything else will be decoded as JSON. 15 | var DefaultEncoder = CombinedEncoder{ 16 | Encoder: EncoderWithValidator(JSONEncoder), 17 | Decoder: DecoderWithValidator(MethodDecoder{ 18 | "GET": URLDecoder, 19 | "*": JSONEncoder, 20 | }), 21 | } 22 | 23 | // Encoder describes an encoder that encodes or decodes the request and response 24 | // types. 25 | type Encoder interface { 26 | // Encode encodes the given value into the given writer. 27 | Encode(http.ResponseWriter, any) error 28 | // An encoder must be able to decode the same type it encodes. 29 | Decoder 30 | } 31 | 32 | // CombinedEncoder combines an encoder and decoder pair into one. 33 | type CombinedEncoder struct { 34 | Encoder Encoder 35 | Decoder Decoder 36 | } 37 | 38 | var _ Encoder = CombinedEncoder{} 39 | 40 | // Encode implements the Encoder interface. 41 | func (e CombinedEncoder) Encode(w http.ResponseWriter, v any) error { 42 | return e.Encoder.Encode(w, v) 43 | } 44 | 45 | // Decode implements the Decoder interface. 46 | func (e CombinedEncoder) Decode(r *http.Request, v any) error { 47 | return e.Decoder.Decode(r, v) 48 | } 49 | 50 | // UnencodableEncoder is an encoder that can only decode and not encode. 51 | // It wraps an existing decoder. 52 | // Calling Encode will return a 500 error, as it is considered a bug to return 53 | // anything. 54 | type UnencodableEncoder struct { 55 | Decoder 56 | } 57 | 58 | var _ Encoder = UnencodableEncoder{} 59 | 60 | func (e UnencodableEncoder) Encode(w http.ResponseWriter, v any) error { 61 | return WrapHTTPError(http.StatusInternalServerError, errors.New("cannot encode")) 62 | } 63 | 64 | // JSONEncoder is an encoder that encodes and decodes JSON. 
65 | var JSONEncoder Encoder = jsonEncoder{} 66 | 67 | type jsonEncoder struct{} 68 | 69 | func (e jsonEncoder) Encode(w http.ResponseWriter, v any) error { 70 | w.Header().Set("Content-Type", "application/json") 71 | return json.NewEncoder(w).Encode(v) 72 | } 73 | 74 | func (e jsonEncoder) Decode(r *http.Request, v any) error { 75 | return json.NewDecoder(r.Body).Decode(v) 76 | } 77 | 78 | // Validator describes a type that can validate itself. 79 | type Validator interface { 80 | Validate() error 81 | } 82 | 83 | // EncoderWithValidator wraps an encoder with one that calls Validate() on the 84 | // value after decoding and before encoding if the value implements Validator. 85 | func EncoderWithValidator(enc Encoder) Encoder { 86 | return validatorEncoder{enc} 87 | } 88 | 89 | type validatorEncoder struct{ enc Encoder } 90 | 91 | func (e validatorEncoder) Encode(w http.ResponseWriter, v any) error { 92 | if validator, ok := v.(Validator); ok { 93 | if err := validator.Validate(); err != nil { 94 | return err 95 | } 96 | } 97 | 98 | if err := e.enc.Encode(w, v); err != nil { 99 | return err 100 | } 101 | 102 | return nil 103 | } 104 | 105 | func (e validatorEncoder) Decode(r *http.Request, v any) error { 106 | return (validatorDecoder{e.enc}).Decode(r, v) 107 | } 108 | -------------------------------------------------------------------------------- /error.go: -------------------------------------------------------------------------------- 1 | package hrt 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | ) 9 | 10 | // HTTPError extends the error interface with an HTTP status code. 11 | type HTTPError interface { 12 | error 13 | HTTPStatus() int 14 | } 15 | 16 | // ErrorHTTPStatus returns the HTTP status code for the given error. If the 17 | // error is not an HTTPError, it returns defaultCode. 
18 | func ErrorHTTPStatus(err error, defaultCode int) int { 19 | var httpErr HTTPError 20 | if errors.As(err, &httpErr) { 21 | return httpErr.HTTPStatus() 22 | } 23 | return defaultCode 24 | } 25 | 26 | type wrappedHTTPError struct { 27 | code int 28 | err error 29 | } 30 | 31 | // WrapHTTPError wraps an error with an HTTP status code. If the error is 32 | // already of type HTTPError, it is returned as-is. To change the HTTP status 33 | // code, use OverrideHTTPError. 34 | func WrapHTTPError(code int, err error) HTTPError { 35 | var httpErr HTTPError 36 | if errors.As(err, &httpErr) { 37 | return httpErr 38 | } 39 | return wrappedHTTPError{code, err} 40 | } 41 | 42 | // NewHTTPError creates a new HTTPError with the given status code and message. 43 | func NewHTTPError(code int, str string) HTTPError { 44 | return wrappedHTTPError{code, errors.New(str)} 45 | } 46 | 47 | // OverrideHTTPError overrides the HTTP status code of the given error. If the 48 | // error is not of type HTTPError, or is an HTTPError that wraps nothing, it is wrapped as-is with the given status code. 49 | // Otherwise the HTTPError is unwrapped and its inner error is wrapped with the new status code. 50 | func OverrideHTTPError(code int, err error) HTTPError { 51 | var httpErr HTTPError 52 | if errors.As(err, &httpErr) && errors.Unwrap(httpErr) != nil { 53 | err = errors.Unwrap(httpErr) 54 | } 55 | return wrappedHTTPError{code, err} 56 | } 57 | 58 | func (e wrappedHTTPError) HTTPStatus() int { 59 | return e.code 60 | } 61 | 62 | func (e wrappedHTTPError) Error() string { 63 | return fmt.Sprintf("%d: %s", e.code, e.err) 64 | } 65 | 66 | func (e wrappedHTTPError) Unwrap() error { 67 | return e.err 68 | } 69 | 70 | // ErrorWriter is a writer that writes an error to the response. 71 | type ErrorWriter interface { 72 | WriteError(w http.ResponseWriter, err error) 73 | } 74 | 75 | // WriteErrorFunc is a function that implements the ErrorWriter interface. 76 | type WriteErrorFunc func(w http.ResponseWriter, err error) 77 | 78 | // WriteError implements the ErrorWriter interface. 
79 | func (f WriteErrorFunc) WriteError(w http.ResponseWriter, err error) { 80 | f(w, err) 81 | } 82 | 83 | // TextErrorWriter writes the error into the response in plain text. 500 84 | // status code is used by default. 85 | var TextErrorWriter ErrorWriter = textErrorWriter{} 86 | 87 | type textErrorWriter struct{} 88 | 89 | func (textErrorWriter) WriteError(w http.ResponseWriter, err error) { 90 | w.Header().Set("Content-Type", "text/plain") 91 | w.WriteHeader(ErrorHTTPStatus(err, http.StatusInternalServerError)) 92 | fmt.Fprintln(w, err) 93 | } 94 | 95 | // JSONErrorWriter writes the error into the response in JSON. 500 status code 96 | // is used by default. The given field is used as the key for the error message. 97 | func JSONErrorWriter(field string) ErrorWriter { 98 | return WriteErrorFunc(func(w http.ResponseWriter, err error) { 99 | w.Header().Set("Content-Type", "application/json") 100 | w.WriteHeader(ErrorHTTPStatus(err, http.StatusInternalServerError)) 101 | 102 | msg := map[string]any{field: err.Error()} 103 | json.NewEncoder(w).Encode(msg) 104 | }) 105 | } 106 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module libdb.so/hrt/v2 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/go-chi/chi/v5 v5.0.8 7 | github.com/pkg/errors v0.9.1 8 | ) 9 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/go-chi/chi/v5 v5.0.8 h1:lD+NLqFcAi1ovnVZpsnObHGW4xb4J8lNmoYVfECH1Y0= 2 | github.com/go-chi/chi/v5 v5.0.8/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= 3 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 4 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 5 | 
-------------------------------------------------------------------------------- /hrt.go: -------------------------------------------------------------------------------- 1 | // Package hrt implements a type-safe HTTP router. It aids in creating a uniform 2 | // API interface while making it easier to create API handlers. 3 | package hrt 4 | 5 | import ( 6 | "context" 7 | "net/http" 8 | "reflect" 9 | ) 10 | 11 | type ctxKey uint8 12 | 13 | const ( 14 | routerOptsCtxKey ctxKey = iota 15 | requestCtxKey 16 | ) 17 | 18 | // RequestFromContext returns the request from the Handler's context. 19 | func RequestFromContext(ctx context.Context) *http.Request { 20 | return ctx.Value(requestCtxKey).(*http.Request) 21 | } 22 | 23 | // Opts contains options for the router. 24 | type Opts struct { 25 | Encoder Encoder 26 | ErrorWriter ErrorWriter 27 | } 28 | 29 | // DefaultOpts is the default options for the router. 30 | var DefaultOpts = Opts{ 31 | Encoder: DefaultEncoder, 32 | ErrorWriter: JSONErrorWriter("error"), 33 | } 34 | 35 | // OptsFromContext returns the options from the Handler's context. DefaultOpts 36 | // is returned if no options are found. 37 | func OptsFromContext(ctx context.Context) Opts { 38 | opts, ok := ctx.Value(routerOptsCtxKey).(Opts) 39 | if ok { 40 | return opts 41 | } 42 | return DefaultOpts 43 | } 44 | 45 | // WithOpts returns a new context with the given options. 46 | func WithOpts(ctx context.Context, opts Opts) context.Context { 47 | return context.WithValue(ctx, routerOptsCtxKey, opts) 48 | } 49 | 50 | // Use creates a middleware that injects itself into each request's context. 
51 | func Use(opts Opts) func(http.Handler) http.Handler { 52 | return func(next http.Handler) http.Handler { 53 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 54 | ctx := WithOpts(r.Context(), opts) 55 | next.ServeHTTP(w, r.WithContext(ctx)) 56 | }) 57 | } 58 | } 59 | 60 | // None indicates that the request has no body or the request does not return 61 | // anything. 62 | type None struct{} 63 | 64 | // Empty is a value of None. 65 | var Empty = None{} 66 | 67 | // Handler describes a generic handler that takes in a type and returns a 68 | // response. 69 | type Handler[RequestT, ResponseT any] func(ctx context.Context, req RequestT) (ResponseT, error) 70 | 71 | // Wrap wraps a handler into a http.Handler. It exists because Go's type 72 | // inference doesn't work well with the Handler type. 73 | func Wrap[RequestT, ResponseT any](f func(ctx context.Context, req RequestT) (ResponseT, error)) http.Handler { 74 | return Handler[RequestT, ResponseT](f) 75 | } 76 | 77 | // ServeHTTP implements the http.Handler interface. 78 | func (h Handler[RequestT, ResponseT]) ServeHTTP(w http.ResponseWriter, r *http.Request) { 79 | var req RequestT 80 | 81 | // Context cycle! Let's go!! 
82 | ctx := context.WithValue(r.Context(), requestCtxKey, r) 83 | 84 | opts := OptsFromContext(ctx) 85 | 86 | req, err := decodeRequest[RequestT](r, opts) 87 | if err != nil { 88 | opts.ErrorWriter.WriteError(w, WrapHTTPError(http.StatusBadRequest, err)) 89 | return 90 | } 91 | 92 | resp, err := h(ctx, req) 93 | if err != nil { 94 | opts.ErrorWriter.WriteError(w, err) 95 | return 96 | } 97 | 98 | if _, ok := any(resp).(None); !ok { 99 | if err := opts.Encoder.Encode(w, resp); err != nil { 100 | opts.ErrorWriter.WriteError(w, WrapHTTPError(http.StatusInternalServerError, err)) 101 | return 102 | } 103 | } 104 | } 105 | 106 | func decodeRequest[RequestT any](r *http.Request, opts Opts) (RequestT, error) { 107 | var req RequestT 108 | if _, ok := any(req).(None); ok { 109 | return req, nil 110 | } 111 | 112 | if reflect.TypeFor[RequestT]().Kind() == reflect.Ptr { 113 | // RequestT is a pointer type, so we need to allocate a new instance. 114 | v := reflect.New(reflect.TypeFor[RequestT]().Elem()).Interface() 115 | 116 | if err := opts.Encoder.Decode(r, v); err != nil { 117 | return req, err 118 | } 119 | 120 | // Return the value as-is, since it's already a pointer. 121 | return v.(RequestT), nil 122 | } 123 | 124 | // RequestT is a value type, so we need to allocate a new pointer instance 125 | // and dereference it afterwards. 126 | v := reflect.New(reflect.TypeFor[RequestT]()).Interface() 127 | 128 | if err := opts.Encoder.Decode(r, v); err != nil { 129 | return req, err 130 | } 131 | 132 | return *v.(*RequestT), nil 133 | } 134 | 135 | // HandlerIntrospection is a struct that contains information about a handler. 136 | // This is primarily used for documentation. 137 | type HandlerIntrospection struct { 138 | // FuncType is the type of the function. 139 | FuncType reflect.Type 140 | // RequestType is the type of the request parameter. 141 | RequestType reflect.Type 142 | // ResponseType is the type of the response parameter. 
143 | ResponseType reflect.Type 144 | } 145 | 146 | // TryIntrospectingHandler checks if h is an hrt.Handler and returns its 147 | // introspection if it is, otherwise it returns false. 148 | func TryIntrospectingHandler(h http.Handler) (HandlerIntrospection, bool) { 149 | type introspector interface { 150 | Introspect() HandlerIntrospection 151 | } 152 | var _ introspector = Handler[None, None](nil) 153 | 154 | if h, ok := h.(introspector); ok { 155 | return h.Introspect(), true 156 | } 157 | return HandlerIntrospection{}, false 158 | } 159 | 160 | // Introspect returns information about the handler. 161 | func (h Handler[RequestT, ResponseT]) Introspect() HandlerIntrospection { 162 | var req RequestT 163 | var resp ResponseT 164 | 165 | return HandlerIntrospection{ 166 | FuncType: reflect.TypeOf(h), 167 | RequestType: reflect.TypeOf(req), 168 | ResponseType: reflect.TypeOf(resp), 169 | } 170 | } 171 | -------------------------------------------------------------------------------- /hrt_example_get_test.go: -------------------------------------------------------------------------------- 1 | package hrt_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/url" 7 | "strings" 8 | 9 | "github.com/pkg/errors" 10 | "libdb.so/hrt/v2" 11 | "libdb.so/hrt/v2/internal/ht" 12 | ) 13 | 14 | // EchoRequest is a simple request type that echoes the request. 15 | type EchoRequest struct { 16 | What string `query:"what"` 17 | } 18 | 19 | // Validate implements the hrt.Validator interface. 20 | func (r EchoRequest) Validate() error { 21 | if !strings.HasSuffix(r.What, "!") { 22 | return errors.New("enthusiasm required") 23 | } 24 | return nil 25 | } 26 | 27 | // EchoResponse is a simple response that follows after EchoRequest. 
28 | type EchoResponse struct { 29 | What string `json:"what"` 30 | } 31 | 32 | func handleEcho(ctx context.Context, req EchoRequest) (EchoResponse, error) { 33 | return EchoResponse{What: req.What}, nil 34 | } 35 | 36 | func Example_get() { 37 | r := hrt.NewPlainRouter() // options are injected by the Use call below 38 | r.Use(hrt.Use(hrt.DefaultOpts)) 39 | r.Get("/echo", hrt.Wrap(handleEcho)) 40 | 41 | srv := ht.NewServer(r) 42 | defer srv.Close() 43 | 44 | resp := srv.MustGet("/echo", url.Values{"what": {"hi"}}) 45 | fmt.Printf("HTTP %d: %s", resp.Status, resp.Body) 46 | 47 | resp = srv.MustGet("/echo", url.Values{"what": {"hi!"}}) 48 | fmt.Printf("HTTP %d: %s", resp.Status, resp.Body) 49 | 50 | // Output: 51 | // HTTP 400: {"error":"400: enthusiasm required"} 52 | // HTTP 200: {"what":"hi!"} 53 | } 54 | -------------------------------------------------------------------------------- /hrt_example_post_test.go: -------------------------------------------------------------------------------- 1 | package hrt_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | 8 | "github.com/go-chi/chi/v5" 9 | "github.com/pkg/errors" 10 | "libdb.so/hrt/v2" 11 | "libdb.so/hrt/v2/internal/ht" 12 | ) 13 | 14 | // User is a simple user type. 15 | type User struct { 16 | ID int `json:"id"` 17 | Name string `json:"name"` 18 | } 19 | 20 | var ( 21 | users = make(map[int]User) 22 | usersMu sync.RWMutex 23 | ) 24 | 25 | // GetUserRequest is a request that fetches a user by ID. 26 | type GetUserRequest struct { 27 | ID int `url:"id"` 28 | } 29 | 30 | // Validate implements the hrt.Validator interface. 
31 | func (r GetUserRequest) Validate() error { 32 | if r.ID == 0 { 33 | return errors.New("invalid ID") 34 | } 35 | return nil 36 | } 37 | 38 | func handleGetUser(ctx context.Context, req GetUserRequest) (User, error) { 39 | usersMu.RLock() 40 | defer usersMu.RUnlock() 41 | 42 | user, ok := users[req.ID] 43 | if !ok { 44 | return User{}, hrt.WrapHTTPError(404, errors.New("user not found")) 45 | } 46 | 47 | return user, nil 48 | } 49 | 50 | // CreateUserRequest is a request that creates a user. 51 | type CreateUserRequest struct { 52 | Name string `json:"name"` 53 | } 54 | 55 | // Validate implements the hrt.Validator interface. 56 | func (r CreateUserRequest) Validate() error { 57 | if r.Name == "" { 58 | return errors.New("name is required") 59 | } 60 | return nil 61 | } 62 | 63 | func handleCreateUser(ctx context.Context, req CreateUserRequest) (User, error) { 64 | // Take the write lock before reading len(users) so the ID is computed and stored atomically. 65 | usersMu.Lock() 66 | defer usersMu.Unlock() 67 | 68 | user := User{ 69 | ID: len(users) + 1, 70 | Name: req.Name, 71 | } 72 | users[user.ID] = user 73 | return user, nil 74 | } 75 | 76 | func Example_post() { 77 | r := chi.NewRouter() 78 | r.Use(hrt.Use(hrt.DefaultOpts)) 79 | r.Route("/users", func(r chi.Router) { 80 | r.Method("get", "/{id}", hrt.Wrap(handleGetUser)) 81 | r.Method("post", "/", hrt.Wrap(handleCreateUser)) 82 | }) 83 | 84 | srv := ht.NewServer(r) 85 | defer srv.Close() 86 | 87 | resps := []ht.Response{ 88 | srv.MustGet("/users/1", nil), 89 | srv.MustPost("/users", "application/json", ht.AsJSON(map[string]any{})), 90 | srv.MustPost("/users", "application/json", ht.AsJSON(map[string]any{ 91 | "name": "diamondburned", 92 | })), 93 | srv.MustGet("/users/1", nil), 94 | } 95 | 96 | for _, resp := range resps { 97 | fmt.Printf("HTTP %d: %s", resp.Status, resp.Body) 98 | } 99 | 100 | // Output: 101 | // HTTP 404: {"error":"404: user not found"} 102 | // HTTP 400: {"error":"400: name is required"} 103 | // HTTP 200: {"id":1,"name":"diamondburned"} 104 | // HTTP 200: 
{"id":1,"name":"diamondburned"} 105 | } 106 | -------------------------------------------------------------------------------- /hrt_test.go: -------------------------------------------------------------------------------- 1 | package hrt 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | func TestHandler_Introspect(t *testing.T) { 12 | handler := Wrap(func(ctx context.Context, req echoRequest) (echoResponse, error) { 13 | return echoResponse{What: req.What}, nil 14 | }) 15 | introspection, ok := TryIntrospectingHandler(handler) 16 | if !ok { 17 | t.Fatal("hrt.Handler is not introspectable") 18 | } 19 | t.Log(introspection) 20 | } 21 | 22 | type echoRequest struct { 23 | What string `query:"what"` 24 | } 25 | 26 | func (r echoRequest) Validate() error { 27 | if !strings.HasSuffix(r.What, "!") { 28 | return errors.New("enthusiasm required") 29 | } 30 | return nil 31 | } 32 | 33 | type echoResponse struct { 34 | What string `json:"what"` 35 | } 36 | -------------------------------------------------------------------------------- /internal/ht/ht.go: -------------------------------------------------------------------------------- 1 | // Package ht contains HTTP testing utilities. 2 | package ht 3 | 4 | import ( 5 | "bytes" 6 | "encoding/json" 7 | "io" 8 | "net/http" 9 | "net/http/httptest" 10 | "net/url" 11 | ) 12 | 13 | // Response is a minimal HTTP response. 14 | type Response struct { 15 | Status int 16 | Body []byte 17 | } 18 | 19 | // Server is a test server that can be used to test HTTP handlers. 20 | type Server struct { 21 | httptest.Server 22 | } 23 | 24 | // NewServer creates a new test server with the given handler. 25 | func NewServer(h http.Handler) *Server { 26 | s := &Server{*httptest.NewUnstartedServer(h)} 27 | s.Start() 28 | return s 29 | } 30 | 31 | // Close closes the server and cleans up any resources. 
32 | func (s *Server) Close() { 33 | s.Server.Close() 34 | } 35 | 36 | // MustGet performs a GET request to the given path with the given query 37 | // parameters. It panics if the request fails. 38 | func (s *Server) MustGet(path string, v url.Values) Response { 39 | url := s.URL + path 40 | if v != nil { 41 | url += "?" + v.Encode() 42 | } 43 | 44 | r, err := s.Client().Get(url) 45 | if err != nil { 46 | panic(err) 47 | } 48 | 49 | defer r.Body.Close() 50 | 51 | rbody, err := io.ReadAll(r.Body) 52 | if err != nil { 53 | panic(err) 54 | } 55 | 56 | return Response{r.StatusCode, rbody} 57 | } 58 | 59 | // MustPost performs a POST request to the given path with the given content 60 | // type and body. It panics if the request fails. 61 | func (s *Server) MustPost(path, contentType string, body []byte) Response { 62 | r, err := s.Client().Post(s.URL+path, contentType, bytes.NewReader(body)) 63 | if err != nil { 64 | panic(err) 65 | } 66 | 67 | defer r.Body.Close() 68 | 69 | rbody, err := io.ReadAll(r.Body) 70 | if err != nil { 71 | panic(err) 72 | } 73 | 74 | return Response{r.StatusCode, rbody} 75 | } 76 | 77 | // AsJSON marshals v into its JSON encoding. It panics if the marshaling 78 | // fails. 79 | func AsJSON(v any) []byte { 80 | b, err := json.Marshal(v) 81 | if err != nil { 82 | panic(err) 83 | } 84 | return b 85 | } 86 | -------------------------------------------------------------------------------- /internal/rfutil/rfutil.go: -------------------------------------------------------------------------------- 1 | // Package rfutil contains reflect utilities. 2 | package rfutil 3 | 4 | import ( 5 | "encoding" 6 | "reflect" 7 | "strconv" 8 | 9 | "github.com/pkg/errors" 10 | ) 11 | 12 | var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]() 13 | 14 | // SetPrimitiveFromString sets the value of a primitive type from a string. It 15 | // supports strings, ints, uints, floats and bools. If s is empty, the value is 16 | // left untouched. 
17 | func SetPrimitiveFromString(rf reflect.Type, rv reflect.Value, s string) error { 18 | if s == "" { 19 | return nil 20 | } 21 | 22 | if rf.Kind() == reflect.Ptr { 23 | rf = rf.Elem() 24 | 25 | newValue := reflect.New(rf) 26 | rv.Set(newValue) 27 | rv = newValue.Elem() 28 | } 29 | 30 | switch rf.Kind() { 31 | case reflect.String: 32 | rv.SetString(s) 33 | return nil 34 | 35 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 36 | i, err := strconv.ParseInt(s, 10, rf.Bits()) 37 | if err != nil { 38 | return errors.Wrap(err, "invalid int") 39 | } 40 | rv.SetInt(i) 41 | return nil 42 | 43 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 44 | i, err := strconv.ParseUint(s, 10, rf.Bits()) 45 | if err != nil { 46 | return errors.Wrap(err, "invalid uint") 47 | } 48 | rv.SetUint(i) 49 | return nil 50 | 51 | case reflect.Float32, reflect.Float64: 52 | f, err := strconv.ParseFloat(s, rf.Bits()) 53 | if err != nil { 54 | return errors.Wrap(err, "invalid float") 55 | } 56 | rv.SetFloat(f) 57 | return nil 58 | 59 | case reflect.Bool: 60 | // False means omitted according to MDN. 61 | rv.SetBool(s != "") 62 | return nil 63 | } 64 | 65 | if reflect.PointerTo(rf).Implements(textUnmarshalerType) { 66 | unmarshaler := rv.Addr().Interface().(encoding.TextUnmarshaler) 67 | if err := unmarshaler.UnmarshalText([]byte(s)); err != nil { 68 | return errors.Wrap(err, "text unmarshaling") 69 | } 70 | } 71 | 72 | return nil 73 | } 74 | 75 | // EachStructField calls the given function for each field of the given struct. 
func EachStructField(v any, f func(reflect.StructField, reflect.Value) error) error {
	// Look through one level of pointer so both T and *T are accepted. A nil
	// input (or nil pointer) yields the zero reflect.Value, rejected below.
	structValue := reflect.Indirect(reflect.ValueOf(v))

	switch {
	case !structValue.IsValid():
		return errors.New("invalid value")
	case structValue.Kind() != reflect.Struct:
		return errors.New("value is not a struct")
	}

	structType := structValue.Type()

	for i, n := 0, structValue.NumField(); i < n; i++ {
		field := structType.Field(i)
		// Unexported fields are never handed to the callback.
		if !field.IsExported() {
			continue
		}

		// The first callback error stops the walk and is returned as-is.
		if err := f(field, structValue.Field(i)); err != nil {
			return err
		}
	}

	return nil
}