├── go.mod ├── recovery.go ├── LICENSE ├── fs.go ├── gzip.go ├── example └── main.go ├── gzip_test.go ├── validator.go ├── nano_test.go ├── tree.go ├── go.sum ├── router.go ├── validator_test.go ├── context.go ├── binding_test.go ├── nano.go ├── router_test.go ├── cors.go ├── binding.go ├── context_test.go └── README.md /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hariadivicky/nano 2 | 3 | go 1.13 4 | 5 | require ( 6 | github.com/go-playground/locales v0.13.0 7 | github.com/go-playground/universal-translator v0.17.0 8 | github.com/go-playground/validator v9.31.0+incompatible 9 | github.com/go-playground/validator/v10 v10.3.0 10 | github.com/json-iterator/go v1.1.9 11 | github.com/liamylian/jsontime/v2 v2.0.0 12 | ) 13 | -------------------------------------------------------------------------------- /recovery.go: -------------------------------------------------------------------------------- 1 | package nano 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "net/http" 7 | "runtime" 8 | ) 9 | 10 | // Recovery is middleware to recover panic. 11 | func Recovery() HandlerFunc { 12 | return func(c *Context) { 13 | 14 | // defered call 15 | defer func() { 16 | if recovered := recover(); recovered != nil { 17 | err, ok := recovered.(error) 18 | 19 | if !ok { 20 | err = fmt.Errorf("%v", recovered) 21 | } 22 | 23 | // Create 1kb stack size. 24 | stacks := make([]byte, 1024) 25 | length := runtime.Stack(stacks, true) 26 | 27 | // print error and stack trace. 28 | log.Printf("[recovered] %v\n\nTrace %s\n", err, stacks[:length]) 29 | 30 | // response 31 | c.String(http.StatusInternalServerError, "500 Internal Server Error") 32 | } 33 | }() 34 | 35 | c.Next() 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2020 Vicky Hariadi 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /fs.go: -------------------------------------------------------------------------------- 1 | package nano 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | // fileServerHandler handles static file server. 8 | func fileServerHandler(routerPrefix, baseURL string, rootDir http.FileSystem) HandlerFunc { 9 | return func(c *Context) { 10 | prefix := baseURL + "/" 11 | // if current file server not in root group, append router group prefix to baseurl. 12 | if routerPrefix != "" { 13 | prefix = routerPrefix + baseURL + "/" 14 | } 15 | 16 | fs := http.FileServer(rootDir) 17 | // remove static prefix of url. 18 | fileServer := http.StripPrefix(prefix, fs) 19 | 20 | // we will check existence of file, 21 | // if current requested file doesn't exists, we will send not found as response. 22 | file, err := rootDir.Open(c.Param("filepath")) 23 | if err != nil { 24 | c.String(http.StatusNotFound, "file not found") 25 | return 26 | } 27 | 28 | stat, err := file.Stat() 29 | if err != nil { 30 | panic(err) 31 | } 32 | file.Close() 33 | 34 | // disable directory listing. 35 | if stat.IsDir() { 36 | c.String(http.StatusForbidden, "access forbidden") 37 | return 38 | } 39 | 40 | fileServer.ServeHTTP(c.Writer, c.Request) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /gzip.go: -------------------------------------------------------------------------------- 1 | package nano 2 | 3 | import ( 4 | "compress/gzip" 5 | "net/http" 6 | "strings" 7 | ) 8 | 9 | type gzipWriter struct { 10 | http.ResponseWriter 11 | writer *gzip.Writer 12 | } 13 | 14 | // Gzip compression for http response. 15 | // this compression works when client accept gzip in their request. 16 | func Gzip(compressionLevel int) HandlerFunc { 17 | return func(c *Context) { 18 | // make sure if client request has gzip in accept-encoding header. 19 | if !strings.Contains(c.GetRequestHeader(HeaderAcceptEncoding), "gzip") { 20 | c.Next() 21 | return 22 | } 23 | 24 | gz, err := gzip.NewWriterLevel(c.Writer, compressionLevel) 25 | // this error may caused incorrect compression level value. 26 | if err != nil { 27 | c.String(http.StatusInternalServerError, "internal server error") 28 | return 29 | } 30 | c.SetHeader(HeaderContentEncoding, "gzip") 31 | defer gz.Close() 32 | 33 | gzWriter := &gzipWriter{c.Writer, gz} 34 | 35 | // replace default writter with Gzip Writer. 36 | c.Writer = gzWriter 37 | c.Next() 38 | } 39 | } 40 | 41 | // Write overrides default http response writer with gzip writter. 42 | func (g *gzipWriter) Write(data []byte) (int, error) { 43 | return g.writer.Write(data) 44 | } 45 | 46 | // WriteHeader overrides response writer to delete content length. 47 | // reference: https://github.com/labstack/echo/issues/444 48 | // If Content-Length header is set, gzip probably writes the wrong number of bytes. 49 | // We should delete the Content-Length header prior to writing the headers on a gzipped response. 50 | func (g *gzipWriter) WriteHeader(code int) { 51 | g.Header().Del(HeaderContentLength) 52 | g.ResponseWriter.WriteHeader(code) 53 | } 54 | -------------------------------------------------------------------------------- /example/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "net/http" 7 | "os" 8 | "os/signal" 9 | "syscall" 10 | "time" 11 | 12 | "github.com/hariadivicky/nano" 13 | ) 14 | 15 | func main() { 16 | app := nano.New() 17 | 18 | // simple endpoint to print hello world. 19 | app.GET("/", func(c *nano.Context) { 20 | c.String(http.StatusOK, "hello world\n") 21 | }) 22 | 23 | // below is logic to gracefully shutdown the web server. 24 | // done channel is used to notify when the shutting down process is complete. 25 | done := make(chan struct{}) 26 | shutdown := make(chan os.Signal) 27 | signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM, syscall.SIGINT) 28 | 29 | // create server from http std package 30 | server := &http.Server{ 31 | WriteTimeout: 10 * time.Second, 32 | ReadTimeout: 10 * time.Second, 33 | IdleTimeout: 30 * time.Second, 34 | Handler: app, // append nano app as server handler. 35 | Addr: ":8000", 36 | } 37 | 38 | go shutdownHandler(server, shutdown, done) 39 | 40 | log.Println("server running") 41 | server.ListenAndServe() 42 | 43 | // waiting web server to complete shutdown. 44 | <-done 45 | log.Println("server closed") 46 | } 47 | 48 | // shutdownHandler do the graceful shutdown to web server. 49 | // when shutdown signal occurred, it will wait all active request to completly receive their responses. 50 | // we will wait all unfinished request until 30 seconds. 51 | func shutdownHandler(server *http.Server, shutdown <-chan os.Signal, done chan struct{}) { 52 | // waiting for shutdown signal. 53 | <-shutdown 54 | log.Println("shutting down...") 55 | 56 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 57 | defer cancel() 58 | 59 | server.SetKeepAlivesEnabled(false) 60 | if err := server.Shutdown(ctx); err != nil { 61 | log.Fatalf("could not shutdown server: %v", err) 62 | } 63 | 64 | close(done) 65 | } 66 | -------------------------------------------------------------------------------- /gzip_test.go: -------------------------------------------------------------------------------- 1 | package nano 2 | 3 | import ( 4 | "compress/gzip" 5 | "log" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | ) 10 | 11 | func TestGzipMiddleware(t *testing.T) { 12 | app := New() 13 | app.Use(Gzip(gzip.DefaultCompression)) 14 | 15 | app.GET("/", func(c *Context) { 16 | c.String(http.StatusOK, "hello world") 17 | }) 18 | 19 | req, err := http.NewRequest(http.MethodGet, "/", nil) 20 | if err != nil { 21 | log.Fatalf("could not create http request: %v", err) 22 | } 23 | 24 | req.Header.Add(HeaderAcceptEncoding, "gzip") 25 | 26 | rec := httptest.NewRecorder() 27 | app.ServeHTTP(rec, req) 28 | 29 | if encoding := rec.Header().Get(HeaderContentEncoding); encoding != "gzip" { 30 | t.Errorf("expected encoding to be gzip; got %s", encoding) 31 | } 32 | } 33 | 34 | func TestGzipWithoutAcceptEncoding(t *testing.T) { 35 | app := New() 36 | app.Use(Gzip(gzip.DefaultCompression)) 37 | 38 | app.GET("/", func(c *Context) { 39 | c.String(http.StatusOK, "hello world") 40 | }) 41 | 42 | req, err := http.NewRequest(http.MethodGet, "/", nil) 43 | if err != nil { 44 | log.Fatalf("could not create http request: %v", err) 45 | } 46 | rec := httptest.NewRecorder() 47 | 48 | app.ServeHTTP(rec, req) 49 | 50 | if encoding := rec.Header().Get(HeaderContentEncoding); encoding == "gzip" { 51 | t.Errorf("expected encoding not to be gzip; got %s", encoding) 52 | } 53 | } 54 | 55 | func TestGzipWithWrongCompressionLevel(t *testing.T) { 56 | app := New() 57 | 58 | app.Use(Gzip(10)) 59 | 60 | app.GET("/", func(c *Context) { 61 | c.String(http.StatusOK, "hello world") 62 | }) 63 | 64 | req, err := http.NewRequest(http.MethodGet, "/", nil) 65 | if err != nil { 66 | log.Fatalf("could not create http request: %v", err) 67 | } 68 | req.Header.Add(HeaderAcceptEncoding, "gzip") 69 | 70 | rec := httptest.NewRecorder() 71 | app.ServeHTTP(rec, req) 72 | 73 | if rec.Code != http.StatusInternalServerError { 74 | t.Fatalf("expected response code to be 500; got %v", rec.Code) 75 | } 76 | 77 | if encoding := rec.Header().Get(HeaderContentEncoding); encoding == "gzip" { 78 | t.Errorf("expected encoding not to be gzip; got %s", encoding) 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /validator.go: -------------------------------------------------------------------------------- 1 | package nano 2 | 3 | import ( 4 | "net/http" 5 | "reflect" 6 | "strings" 7 | 8 | "github.com/go-playground/locales/en" 9 | ut "github.com/go-playground/universal-translator" 10 | "github.com/go-playground/validator/v10" 11 | en_translations "github.com/go-playground/validator/v10/translations/en" 12 | ) 13 | 14 | // newTranslator returns validator translation. default using "en" 15 | func newTranslator() ut.Translator { 16 | // NOTE: ommitting allot of error checking for brevity 17 | en := en.New() 18 | uni := ut.New(en, en) 19 | 20 | // this is usually know or extracted from http 'Accept-Language' header 21 | // also see uni.FindTranslator(...) 22 | trans, _ := uni.GetTranslator("en") 23 | return trans 24 | } 25 | 26 | func newValidator(trans ut.Translator) *validator.Validate { 27 | v10 := validator.New() 28 | v10.RegisterTagNameFunc(func(fld reflect.StructField) string { 29 | name := strings.SplitN(fld.Tag.Get("form"), ",", 2)[0] 30 | 31 | if name == "-" { 32 | return "" 33 | } 34 | 35 | return name 36 | }) 37 | 38 | en_translations.RegisterDefaultTranslations(v10, trans) 39 | return v10 40 | } 41 | 42 | // validate is default struct validator. this function will called when you do request binding to some struct. 43 | // Current validation rule is only to validate "required" field. To apply field into validation, just add "rules" at field tag. 44 | // if you apply "required" rule, that is mean you are not allowed to use zero type value in you request body field 45 | // because it will give you validation error. 46 | // so if you need 0 value for int field or false value for boolean field, pelase consider to not use "required" rules. 47 | func validate(c *Context, targetStruct interface{}) error { 48 | // only accept pointer 49 | if reflect.TypeOf(targetStruct).Kind() != reflect.Ptr { 50 | return &ErrBinding{ 51 | Text: "expected pointer to target struct, got non-pointer", 52 | Status: http.StatusInternalServerError, 53 | } 54 | } 55 | 56 | err := c.validator.Struct(targetStruct) 57 | 58 | if err != nil { 59 | var errFields []string 60 | for _, err := range err.(validator.ValidationErrors) { 61 | errFields = append(errFields, err.Translate(c.translator)) 62 | } 63 | 64 | return ErrBinding{ 65 | Status: http.StatusUnprocessableEntity, 66 | Text: "validation error", 67 | Fields: errFields, 68 | } 69 | } 70 | 71 | return nil 72 | } 73 | -------------------------------------------------------------------------------- /nano_test.go: -------------------------------------------------------------------------------- 1 | package nano 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | ) 9 | 10 | func TestUseMiddleware(t *testing.T) { 11 | app := New() 12 | 13 | emptyHandler := func(c *Context) {} 14 | 15 | app.Use(emptyHandler, emptyHandler, emptyHandler) 16 | 17 | if mlen := len(app.middlewares); mlen != 3 { 18 | t.Errorf("expect num of middlewares to be 3; got %d", mlen) 19 | } 20 | } 21 | 22 | func TestGroup(t *testing.T) { 23 | app := New() 24 | 25 | api := app.Group("/api") 26 | if api.prefix != "/api" { 27 | t.Errorf("expected group prefix to be /api; got %s", api.prefix) 28 | } 29 | 30 | finance := api.Group("/finance") 31 | if finance.prefix != "/api/finance" { 32 | t.Errorf("expected group prefix to be /api/finance; got %s", api.prefix) 33 | } 34 | } 35 | 36 | func TestRouteRegistration(t *testing.T) { 37 | app := New() 38 | 39 | emptyHandler := func(c *Context) {} 40 | app.GET("/", emptyHandler) 41 | app.POST("/", emptyHandler) 42 | app.PUT("/", emptyHandler) 43 | app.DELETE("/", emptyHandler) 44 | 45 | if hlen := len(app.router.handlers); hlen != 4 { 46 | t.Errorf("expected num of registered routes to be 4; got %d", hlen) 47 | } 48 | } 49 | 50 | func TestDefaultHandler(t *testing.T) { 51 | app := New() 52 | 53 | if app.router.defaultHandler != nil { 54 | t.Fatalf("expected initial value of default handler to be nil") 55 | } 56 | 57 | t.Run("set default handler", func(st *testing.T) { 58 | app.Default(func(c *Context) { 59 | c.String(http.StatusOK, "ok") 60 | }) 61 | 62 | if app.router.defaultHandler == nil { 63 | st.Errorf("expected default handler to be setted; got %v", app.router.defaultHandler) 64 | } 65 | }) 66 | 67 | t.Run("set default handler when it already set", func(st *testing.T) { 68 | err := app.Default(func(c *Context) { 69 | c.String(http.StatusOK, "ok") 70 | }) 71 | 72 | if err != ErrDefaultHandler { 73 | st.Errorf("expected result to be ErrDefaultHandler; got %v", err) 74 | } 75 | }) 76 | } 77 | 78 | func TestServeHTTP(t *testing.T) { 79 | app := New() 80 | app.GET("/", func(c *Context) { 81 | c.String(http.StatusOK, "ok") 82 | }) 83 | 84 | req, err := http.NewRequest(http.MethodGet, "/", nil) 85 | if err != nil { 86 | log.Fatalf("could not make http request: %v", err) 87 | } 88 | rec := httptest.NewRecorder() 89 | 90 | app.ServeHTTP(rec, req) 91 | 92 | if rec.Code != http.StatusOK { 93 | t.Errorf("expected response code to be 200; got %d", rec.Code) 94 | } 95 | 96 | if body := rec.Body.String(); body != "ok" { 97 | t.Errorf("expected response text to be ok; got %s", body) 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /tree.go: -------------------------------------------------------------------------------- 1 | package nano 2 | 3 | import "strings" 4 | 5 | // node defines tree node. 6 | type node struct { 7 | urlPattern string 8 | urlPart string 9 | childrens []*node 10 | isWildcard bool 11 | } 12 | 13 | // insertChildren inserts node as children. 14 | // this function calls recursively as length of urlParts and cursor position (level) 15 | func (n *node) insertChildren(urlPattern string, urlParts []string, level int) { 16 | 17 | // last inserted node cause cursor (level) has reached maximum value. 18 | // stop recursive calls. 19 | if len(urlParts) == level { 20 | // fill url pattern to marks current node as complete url pattern. 21 | n.urlPattern = urlPattern 22 | 23 | return 24 | } 25 | 26 | urlPart := urlParts[level] 27 | 28 | // scan existence of current url part in children list. 29 | child := n.findChildren(urlPart) 30 | if child == nil { 31 | // current url part is not already registered as children node. 32 | // register children now. 33 | isWildcard := urlPart[0] == ':' || urlPart[0] == '*' 34 | child = &node{urlPart: urlPart, isWildcard: isWildcard} 35 | n.childrens = append(n.childrens, child) 36 | } 37 | 38 | // insert next urlParts as next level children. 39 | // moving cursor to next urlParts. 40 | child.insertChildren(urlPattern, urlParts, level+1) 41 | } 42 | 43 | // findChildren is functions to find children by url part value. 44 | // this function may return nil value. 45 | func (n *node) findChildren(urlPart string) *node { 46 | 47 | // scanning for children 48 | for _, child := range n.childrens { 49 | // if current child url part is match or contain wildcard, so it's found. 50 | if child.urlPart == urlPart || child.isWildcard { 51 | return child 52 | } 53 | } 54 | 55 | // there is no children with current urlPart 56 | return nil 57 | } 58 | 59 | // findNode finds a node. 60 | // first (n *node) may be node that located at router.nodes[requestMethod]. 61 | func (n *node) findNode(searchParts []string, level int) *node { 62 | // cursor (level) reached maximum position. 63 | // or current url part has * wildcard 64 | if len(searchParts) == level || strings.HasPrefix(n.urlPart, "*") { 65 | // if current pattern has no url pattern, this mean current node doesn't complete. 66 | // not found. 67 | if n.urlPattern == "" { 68 | return nil 69 | } 70 | 71 | return n 72 | } 73 | 74 | // get current search part by cursor (level). 75 | urlPart := searchParts[level] 76 | 77 | // scan for nested childrens*. 78 | // *please read about getChildren. 79 | for _, child := range n.getChildren(urlPart) { 80 | // move cursor, scan recursively. 81 | result := child.findNode(searchParts, level+1) 82 | // found! 83 | if result != nil { 84 | return result 85 | } 86 | 87 | return nil 88 | } 89 | 90 | return nil 91 | } 92 | 93 | // getChildren finds a children that has certain part 94 | // or it's a wildcard 95 | func (n *node) getChildren(urlPart string) []*node { 96 | nodes := make([]*node, 0) 97 | 98 | for _, node := range n.childrens { 99 | if node.urlPart == urlPart || node.isWildcard { 100 | nodes = append(nodes, node) 101 | } 102 | } 103 | 104 | return nodes 105 | } 106 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= 5 | github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= 6 | github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= 7 | github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= 8 | github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= 9 | github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= 10 | github.com/go-playground/validator v9.31.0+incompatible h1:UA72EPEogEnq76ehGdEDp4Mit+3FDh548oRqwVgNsHA= 11 | github.com/go-playground/validator v9.31.0+incompatible/go.mod h1:yrEkQXlcI+PugkyDjY2bRrL/UBU4f3rvrgkN3V8JEig= 12 | github.com/go-playground/validator/v10 v10.3.0 h1:nZU+7q+yJoFmwvNgv/LnPUkwPal62+b2xXj0AU1Es7o= 13 | github.com/go-playground/validator/v10 v10.3.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= 14 | github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= 15 | github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= 16 | github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= 17 | github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= 18 | github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= 19 | github.com/liamylian/jsontime v1.0.1 h1:zM/Dxvu7X0iq9BpM2KMpGsKYEIHYDxf04z0GmcKId44= 20 | github.com/liamylian/jsontime/v2 v2.0.0 h1:3if2kDW/boymUdO+4Qj/m4uaXMBSF6np9KEgg90cwH0= 21 | github.com/liamylian/jsontime/v2 v2.0.0/go.mod h1:UHp1oAPqCBfspokvGmaGe0IAl2IgOpgOgDaKPcvcGGY= 22 | github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 23 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= 24 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 25 | github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= 26 | github.com/modern-go/reflect2 v1.0.1 h1:9f412s+6RmYXLWZSEzVVgPGK7C2PphHj5RJrvfx9AWI= 27 | github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= 28 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 29 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 30 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 31 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 32 | github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= 33 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 34 | golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= 35 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 36 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 37 | gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= 38 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 39 | -------------------------------------------------------------------------------- /router.go: -------------------------------------------------------------------------------- 1 | package nano 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "strings" 7 | ) 8 | 9 | // router defines main router structure. 10 | type router struct { 11 | nodes map[string]*node 12 | handlers map[string][]HandlerFunc 13 | defaultHandler HandlerFunc 14 | } 15 | 16 | // newRouter creates new router instance. 17 | func newRouter() *router { 18 | return &router{ 19 | nodes: make(map[string]*node), 20 | handlers: make(map[string][]HandlerFunc), 21 | } 22 | } 23 | 24 | // createUrlParts returns splitted path. 25 | func createURLParts(urlPattern string) []string { 26 | patternParts := strings.Split(urlPattern, "/") 27 | 28 | urlParts := make([]string, 0) 29 | 30 | for _, path := range patternParts { 31 | // ignore root path 32 | if path != "" { 33 | urlParts = append(urlParts, path) 34 | 35 | // only * wildcard is allowed. 36 | if path[0] == '*' { 37 | break 38 | } 39 | } 40 | } 41 | 42 | return urlParts 43 | } 44 | 45 | // addRoute registers route to router. 46 | // you could use multiple handler. 47 | func (r *router) addRoute(requestMethod, urlPattern string, handler ...HandlerFunc) { 48 | urlParts := createURLParts(urlPattern) 49 | 50 | rootNode, exists := r.nodes[requestMethod] 51 | 52 | // current request method root node doesn't exists. 53 | if !exists { 54 | r.nodes[requestMethod] = &node{} 55 | rootNode = r.nodes[requestMethod] 56 | } 57 | 58 | // register route. 59 | key := fmt.Sprintf("%s-%s", requestMethod, urlPattern) 60 | 61 | // insert children to tree. 62 | rootNode.insertChildren(urlPattern, urlParts, 0) 63 | r.handlers[key] = handler 64 | } 65 | 66 | // findRoute finds current request with stored url pattern in node tree. 67 | // this function also mapping your parameter (which was defined in url pattern) from url request. 68 | func (r *router) findRoute(requestMethod, urlPath string) (*node, map[string]string) { 69 | searchParts := createURLParts(urlPath) 70 | params := make(map[string]string) 71 | 72 | rootNode, exists := r.nodes[requestMethod] 73 | 74 | // there are no routes with current request method 75 | if !exists { 76 | return nil, nil 77 | } 78 | 79 | // scan child node recursively. 80 | node := rootNode.findNode(searchParts, 0) 81 | 82 | if node != nil { 83 | // replace param placeholder with current request value. 84 | for index, path := range createURLParts(node.urlPattern) { 85 | // current pattern is parameter. 86 | if path[0] == ':' { 87 | params[path[1:]] = searchParts[index] 88 | } 89 | 90 | // current pattern is * wildcard, that means all path are used. 91 | if path[0] == '*' && len(path) > 1 { 92 | params[path[1:]] = strings.Join(searchParts[index:], "/") 93 | } 94 | } 95 | 96 | return node, params 97 | } 98 | 99 | return nil, nil 100 | } 101 | 102 | // notFoundHandler is router default handler. 103 | func (r *router) notFoundHandler() HandlerFunc { 104 | return func(c *Context) { 105 | c.String(http.StatusNotFound, "nano/1.0 not found") 106 | } 107 | } 108 | 109 | // serveDefaultHandler appends default handler to call stacks. 110 | // if you not set the default handler, we will set notFoundHandler as default. 111 | func (r *router) serveDefaultHandler(c *Context) { 112 | // create not found handler when default handler not set yet. 113 | if r.defaultHandler == nil { 114 | r.defaultHandler = r.notFoundHandler() 115 | } 116 | 117 | c.handlers = append(c.handlers, r.defaultHandler) 118 | c.Next() 119 | } 120 | 121 | // handle incoming request. if there is no matching route, 122 | // router will serve default handler. 123 | func (r *router) handle(c *Context) { 124 | node, params := r.findRoute(c.Method, c.Path) 125 | 126 | // current request has a match route. 127 | if node != nil { 128 | key := fmt.Sprintf("%s-%s", c.Method, node.urlPattern) 129 | c.Params = params 130 | 131 | // append current handler to handler stack. 132 | // extract route handler(s). 133 | c.handlers = append(c.handlers, r.handlers[key]...) 134 | } else { 135 | // no matching routes, serve default. 136 | r.serveDefaultHandler(c) 137 | } 138 | 139 | // call handlers stack. 140 | c.Next() 141 | } 142 | -------------------------------------------------------------------------------- /validator_test.go: -------------------------------------------------------------------------------- 1 | package nano 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | ) 8 | 9 | func setupContext() *Context { 10 | r, _ := http.NewRequest(http.MethodGet, "/", nil) 11 | w := httptest.NewRecorder() 12 | 13 | return newContext(w, r) 14 | 15 | } 16 | 17 | func TestValidator(t *testing.T) { 18 | type Person struct { 19 | Name string `form:"name" json:"name" validate:"required"` 20 | Gender string `form:"gender" json:"gender" validate:"required"` 21 | Email string `form:"email" json:"email" validate:"required"` 22 | Phone string `form:"phone" json:"phone"` 23 | privateField string 24 | IgnoredField string `form:"-"` 25 | } 26 | 27 | person := Person{ 28 | Name: "foo", 29 | Gender: "male", 30 | Email: "hariadivicky@gmail.com", 31 | Phone: "", 32 | } 33 | 34 | ctx := setupContext() 35 | 36 | t.Run("pass non-pointer struct", func(st *testing.T) { 37 | err := validate(ctx, person) 38 | if err == nil { 39 | t.Fatalf("expected error to be returned") 40 | } 41 | 42 | if errBind, ok := err.(ErrBinding); ok { 43 | if errBind.Status != ErrBindNonPointer.Status { 44 | st.Errorf("expected HTTPStatusCode error to be %d; got %d", ErrBindNonPointer.Status, errBind.Status) 45 | } 46 | 47 | if errBind.Text != ErrBindNonPointer.Text { 48 | st.Errorf("expected error message to be %s; got %s", ErrBindNonPointer.Text, errBind.Text) 49 | } 50 | } 51 | }) 52 | 53 | t.Run("validation should be passed", func(st *testing.T) { 54 | errBind := validate(ctx, &person) 55 | 56 | if errBind != nil { 57 | t.Errorf("expected error binding to be nil; got %v", errBind) 58 | } 59 | }) 60 | 61 | t.Run("empty value on required fields", func(st *testing.T) { 62 | person.Name = "" 63 | person.Gender = "" 64 | person.Email = "" 65 | 66 | err := validate(ctx, &person) 67 | if err == nil { 68 | st.Fatalf("expected error to be returned") 69 | } 70 | 71 | if bindErr, ok := err.(ErrBinding); ok { 72 | if bindErr.Status != http.StatusUnprocessableEntity { 73 | st.Errorf("expected HTTPStatusCode error to be %d; got %d", ErrBindNonPointer.Status, http.StatusUnprocessableEntity) 74 | } 75 | 76 | if bindErr.Text != "validation error" { 77 | st.Errorf("expected error message to be %s; got %s", ErrBindNonPointer.Text, bindErr.Text) 78 | } 79 | 80 | if errFieldsCount := len(bindErr.Fields); errFieldsCount != 3 { 81 | st.Fatalf("expected num of error fields to be 3; got %d", errFieldsCount) 82 | } 83 | 84 | errFields := []string{ 85 | "name is a required field", 86 | "gender is a required field", 87 | "email is a required field", 88 | } 89 | 90 | for i, errMsg := range bindErr.Fields { 91 | if errMsg != errFields[i] { 92 | st.Errorf("expected error %d to be %s; got %s", i, errFields[i], errMsg) 93 | } 94 | } 95 | 96 | return 97 | } 98 | 99 | st.Fatalf("expected error type to be ErrBinding, got %T", err) 100 | }) 101 | 102 | } 103 | 104 | func TestNestedStructValidation(t *testing.T) { 105 | type Person struct { 106 | Name string `form:"name" json:"name" validate:"required"` 107 | Gender string `form:"gender" json:"gender" validate:"required"` 108 | Address struct { 109 | CityID int `form:"city_id" json:"city_id" validate:"required"` 110 | PostalCode int `form:"postal_code" json:"postal_code"` 111 | } 112 | } 113 | 114 | person := Person{ 115 | Name: "foo", 116 | Gender: "", 117 | Address: struct { 118 | CityID int `form:"city_id" json:"city_id" validate:"required"` 119 | PostalCode int `form:"postal_code" json:"postal_code"` 120 | }{ 121 | CityID: 0, 122 | PostalCode: 204, 123 | }, 124 | } 125 | 126 | ctx := setupContext() 127 | 128 | err := validate(ctx, &person) 129 | if err == nil { 130 | t.Fatalf("expected error to be returned") 131 | } 132 | 133 | if errBind, ok := err.(ErrBinding); ok { 134 | if errBind.Status != http.StatusUnprocessableEntity { 135 | t.Errorf("expected error HTTPStatusCode to be %d; got %d", http.StatusUnprocessableEntity, errBind.Status) 136 | } 137 | 138 | if errBind.Text != "validation error" { 139 | t.Errorf("expected error message to be validation error; got %s", errBind.Text) 140 | } 141 | 142 | errFields := []string{ 143 | "gender is a required field", 144 | "city_id is a required field", 145 | } 146 | 147 | for i, errMsg := range errBind.Fields { 148 | if errMsg != errFields[i] { 149 | t.Errorf("expected error %d to be %s; got %s", i, errFields[i], errMsg) 150 | } 151 | } 152 | 153 | return 154 | } 155 | 156 | t.Fatalf("expected ErrBinding, got %T", err) 157 | } 158 | -------------------------------------------------------------------------------- /context.go: -------------------------------------------------------------------------------- 1 | package nano 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "strings" 7 | 8 | ut "github.com/go-playground/universal-translator" 9 | "github.com/go-playground/validator/v10" 10 | ) 11 | 12 | // Bag stores context key:value parameter. 13 | type Bag struct { 14 | data map[string]interface{} 15 | } 16 | 17 | // NewBag creates new bag instance. 18 | func NewBag() *Bag { 19 | return &Bag{ 20 | data: make(map[string]interface{}), 21 | } 22 | } 23 | 24 | // Set bad data. 25 | func (b *Bag) Set(key string, data interface{}) { 26 | b.data[key] = data 27 | } 28 | 29 | // Get data by given key. 30 | func (b *Bag) Get(key string) interface{} { 31 | if data, ok := b.data[key]; ok { 32 | return data 33 | } 34 | 35 | return nil 36 | } 37 | 38 | // Context defines nano request - response context. 39 | type Context struct { 40 | Request *http.Request 41 | Writer http.ResponseWriter 42 | Method string 43 | Path string 44 | Origin string 45 | Params map[string]string 46 | handlers []HandlerFunc 47 | Bag *Bag 48 | cursor int // used for handlers stack. 49 | validator *validator.Validate 50 | translator ut.Translator 51 | } 52 | 53 | // newContext is Context constructor. 54 | func newContext(w http.ResponseWriter, r *http.Request) *Context { 55 | 56 | trans := newTranslator() 57 | validator := newValidator(trans) 58 | 59 | return &Context{ 60 | Request: r, 61 | Writer: w, 62 | Method: r.Method, 63 | Path: r.URL.Path, 64 | Origin: r.Header.Get(HeaderOrigin), 65 | cursor: -1, 66 | Bag: NewBag(), 67 | validator: validator, 68 | translator: trans, 69 | } 70 | } 71 | 72 | // Next moves cursor to the next handler stack. 73 | func (c *Context) Next() { 74 | // moving cursor. 75 | c.cursor++ 76 | 77 | if c.cursor < len(c.handlers) { 78 | c.handlers[c.cursor](c) 79 | } 80 | } 81 | 82 | // Status sets http status code response. 83 | func (c *Context) Status(statusCode int) { 84 | c.Writer.WriteHeader(statusCode) 85 | } 86 | 87 | // SetHeader sets http response header. 88 | func (c *Context) SetHeader(key, value string) { 89 | c.Writer.Header().Set(key, value) 90 | } 91 | 92 | // GetRequestHeader returns header value by given key. 93 | func (c *Context) GetRequestHeader(key string) string { 94 | return c.Request.Header.Get(key) 95 | } 96 | 97 | // SetContentType sets http content type response header. 98 | func (c *Context) SetContentType(contentType string) { 99 | c.SetHeader(HeaderContentType, contentType) 100 | } 101 | 102 | // Param gets request parameter. 103 | func (c *Context) Param(key string) string { 104 | value, _ := c.Params[key] 105 | return value 106 | } 107 | 108 | // PostForm gets form body field. 109 | func (c *Context) PostForm(key string) string { 110 | return c.Request.FormValue(key) 111 | } 112 | 113 | // PostFormDefault returns default value when form body field is empty. 114 | func (c *Context) PostFormDefault(key string, defaultValue string) string { 115 | v := c.PostForm(key) 116 | 117 | if v == "" { 118 | return defaultValue 119 | } 120 | 121 | return v 122 | } 123 | 124 | // Query gets url query. 125 | func (c *Context) Query(key string) string { 126 | return c.Request.URL.Query().Get(key) 127 | } 128 | 129 | // QueryDefault return default value when url query is empty 130 | func (c *Context) QueryDefault(key string, defaultValue string) string { 131 | v := c.Query(key) 132 | 133 | if v == "" { 134 | return defaultValue 135 | } 136 | 137 | return v 138 | } 139 | 140 | // IsJSON returns true when client send json body. 141 | func (c *Context) IsJSON() bool { 142 | return c.GetRequestHeader(HeaderContentType) == MimeJSON 143 | } 144 | 145 | // ExpectJSON returns true when client request json response, 146 | // since this function use string.Contains, value ordering in Accept values doesn't matter. 147 | func (c *Context) ExpectJSON() bool { 148 | return strings.Contains(c.GetRequestHeader(HeaderAccept), MimeJSON) 149 | } 150 | 151 | // JSON writes json as response. 152 | func (c *Context) JSON(statusCode int, object interface{}) { 153 | rs, err := json.Marshal(object) 154 | if err != nil { 155 | c.String(http.StatusInternalServerError, "internal server error") 156 | return 157 | } 158 | 159 | c.SetContentType(MimeJSON) 160 | c.Status(statusCode) 161 | c.Writer.Write(rs) 162 | } 163 | 164 | // String writes plain text as response. 165 | func (c *Context) String(statusCode int, template string, value ...interface{}) { 166 | c.SetContentType(MimePlainText) 167 | c.Status(statusCode) 168 | 169 | text := fmt.Sprintf(template, value...) 170 | 171 | c.Writer.Write([]byte(text)) 172 | } 173 | 174 | // File returns static file as response. 175 | func (c *Context) File(statusCode int, filepath string) { 176 | http.ServeFile(c.Writer, c.Request, filepath) 177 | } 178 | 179 | // HTML writes html as response. 180 | func (c *Context) HTML(statusCode int, html string) { 181 | c.SetContentType(MimeHTML) 182 | c.Status(statusCode) 183 | c.Writer.Write([]byte(html)) 184 | } 185 | 186 | // Data writes binary as response. 187 | func (c *Context) Data(statusCode int, binary []byte) { 188 | c.Status(statusCode) 189 | c.Writer.Write(binary) 190 | } 191 | -------------------------------------------------------------------------------- /binding_test.go: -------------------------------------------------------------------------------- 1 | package nano 2 | 3 | import ( 4 | "bytes" 5 | "log" 6 | "mime/multipart" 7 | "net/http" 8 | "net/http/httptest" 9 | "net/url" 10 | "strings" 11 | "testing" 12 | ) 13 | 14 | func TestAutoBindingForUnexpectedContentType(t *testing.T) { 15 | req, err := http.NewRequest(http.MethodPost, "/", nil) 16 | if err != nil { 17 | log.Fatalf("could not create http request: %v", err) 18 | } 19 | req.Header.Add(HeaderContentType, "x-unknown") 20 | rec := httptest.NewRecorder() 21 | ctx := newContext(rec, req) 22 | 23 | type Person struct { 24 | Name string `form:"name"` 25 | Gender string `form:"gender"` 26 | } 27 | 28 | var person Person 29 | if err = ctx.Bind(&person); err == nil { 30 | t.Fatalf("expected error returned") 31 | } 32 | 33 | if err, ok := err.(ErrBinding); ok { 34 | if err.Status != ErrBindContentType.Status { 35 | t.Errorf("expected error HTTPStatusCode to be %d; got %d", ErrBindContentType.Status, err.Status) 36 | } 37 | 38 | if err.Text != ErrBindContentType.Text { 39 | t.Errorf("expected error message to be %s; got %s", ErrBindContentType.Text, err.Text) 40 | } 41 | 42 | return 43 | } 44 | 45 | t.Fatalf("expected ErrBinding type returned, got %T", err) 46 | 47 | } 48 | 49 | func TestAutoBindingForURLEncoded(t *testing.T) { 50 | form := url.Values{} 51 | form.Set("name", "foo") 52 | form.Set("gender", "male") 53 | 54 | req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode())) 55 | if err != nil { 56 | log.Fatalf("could not create http request: %v", err) 57 | } 58 | req.Header.Add(HeaderContentType, MimeFormURLEncoded) 59 | rec := httptest.NewRecorder() 60 | ctx := newContext(rec, req) 61 | 62 | type Person struct { 63 | Name string `form:"name"` 64 | Gender string `form:"gender"` 65 | } 66 | 67 | var person Person 68 | errBinding := ctx.Bind(&person) 69 | 70 | if nm := ctx.PostForm("name"); nm != "foo" { 71 | t.Fatalf("expected form name value to be foo; got %s", nm) 72 | } 73 | 74 | if errBinding != nil { 75 | t.Fatalf("expected err binding to nil") 76 | } 77 | 78 | if person.Name != "foo" { 79 | t.Errorf("expected name to be foo; got %s", person.Name) 80 | } 81 | 82 | if person.Gender != "male" { 83 | t.Errorf("expected gender to be male; got %s", person.Gender) 84 | } 85 | } 86 | 87 | func TestAutoBindingForJSON(t *testing.T) { 88 | form := []byte(`{"name":"foo", "gender":"male"}`) 89 | req, err := http.NewRequest(http.MethodPost, "/", bytes.NewBuffer(form)) 90 | if err != nil { 91 | log.Fatalf("could not create http request: %v", err) 92 | } 93 | req.Header.Add(HeaderContentType, MimeJSON) 94 | rec := httptest.NewRecorder() 95 | ctx := newContext(rec, req) 96 | 97 | type Person struct { 98 | Name string `form:"name" json:"name"` 99 | Gender string `form:"gender" json:"gender"` 100 | } 101 | 102 | var person Person 103 | errBinding := ctx.Bind(&person) 104 | 105 | if errBinding != nil { 106 | t.Fatalf("expected err binding to nil") 107 | } 108 | 109 | if person.Name != "foo" { 110 | t.Errorf("expected name to be foo; got %s", person.Name) 111 | } 112 | 113 | if person.Gender != "male" { 114 | t.Errorf("expected gender to be male; got %s", person.Gender) 115 | } 116 | } 117 | 118 | func TestAutoBindingForMultipartForm(t *testing.T) { 119 | body := new(bytes.Buffer) 120 | form := multipart.NewWriter(body) 121 | form.WriteField("name", "foo") 122 | form.WriteField("gender", "male") 123 | 124 | req, err := http.NewRequest(http.MethodPost, "/", body) 125 | if err != nil { 126 | log.Fatalf("could not create http request: %v", err) 127 | } 128 | req.Header.Add(HeaderContentType, form.FormDataContentType()) 129 | form.Close() 130 | rec := httptest.NewRecorder() 131 | ctx := newContext(rec, req) 132 | 133 | type Person struct { 134 | Name string `form:"name" json:"name"` 135 | Gender string `form:"gender" json:"gender"` 136 | } 137 | 138 | var person Person 139 | 140 | if err = ctx.Bind(&person); err != nil { 141 | t.Fatalf("expected err binding to nil; got %T", err) 142 | } 143 | 144 | if person.Name != "foo" { 145 | t.Errorf("expected name to be foo; got %s", person.Name) 146 | } 147 | 148 | if person.Gender != "male" { 149 | t.Errorf("expected gender to be male; got %s", person.Gender) 150 | } 151 | } 152 | 153 | func TestBindJSON(t *testing.T) { 154 | type Person struct { 155 | Name string 156 | Gender string 157 | } 158 | 159 | var person Person 160 | 161 | t.Run("bind non pointer struct", func(st *testing.T) { 162 | req, err := http.NewRequest(http.MethodGet, "/", nil) 163 | if err != nil { 164 | log.Fatalf("could not make http request: %v", err) 165 | } 166 | w := httptest.NewRecorder() 167 | 168 | ctx := newContext(w, req) 169 | 170 | err = ctx.BindJSON(person) 171 | if err == nil { 172 | st.Errorf("expected error to be returned; got %T", err) 173 | } 174 | 175 | if errBinding, ok := err.(ErrBinding); ok { 176 | if errBinding.Error() != ErrBindNonPointer.Error() { 177 | st.Errorf("expect error to be ErrBindNonPointer; got %v", errBinding) 178 | } 179 | 180 | return 181 | } 182 | 183 | st.Fatalf("expected ErrBinding, got %T", err) 184 | 185 | }) 186 | } 187 | -------------------------------------------------------------------------------- /nano.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 Vicky Hariadi Pratama. All rights reserved. 2 | // license that can be found in the LICENSE file. 3 | // this package is http route multiplexing. 4 | 5 | package nano 6 | 7 | import ( 8 | "errors" 9 | "net/http" 10 | "strings" 11 | 12 | jsontime "github.com/liamylian/jsontime/v2/v2" 13 | ) 14 | 15 | func init() { 16 | jsontime.AddTimeFormatAlias("sql_date", "2006-01-02") 17 | jsontime.AddTimeFormatAlias("sql_datetime", "2006-01-02 15:04:02") 18 | } 19 | 20 | const ( 21 | // HeaderAcceptEncoding is accept encoding. 22 | HeaderAcceptEncoding = "Accept-Encoding" 23 | // HeaderContentEncoding is content encoding. 24 | HeaderContentEncoding = "Content-Encoding" 25 | // HeaderContentLength is content length. 26 | HeaderContentLength = "Content-Length" 27 | // HeaderContentType is content type. 28 | HeaderContentType = "Content-Type" 29 | // HeaderAccept is accept content type. 30 | HeaderAccept = "Accept" 31 | // HeaderOrigin is request origin. 32 | HeaderOrigin = "Origin" 33 | // HeaderVary is request vary. 34 | HeaderVary = "Vary" 35 | // HeaderAccessControlRequestMethod is cors request method. 36 | HeaderAccessControlRequestMethod = "Access-Control-Request-Method" 37 | // HeaderAccessControlRequestHeader is cors request header. 38 | HeaderAccessControlRequestHeader = "Access-Control-Request-Header" 39 | // HeaderAccessControlAllowOrigin is cors allowed origins. 40 | HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin" 41 | // HeaderAccessControlAllowMethods is cors allowed origins. 42 | HeaderAccessControlAllowMethods = "Access-Control-Allow-Methods" 43 | // HeaderAccessControlAllowHeader is cors allowed headers. 44 | HeaderAccessControlAllowHeader = "Access-Control-Allow-Header" 45 | 46 | // MimeJSON is standard json mime. 47 | MimeJSON = "application/json" 48 | // MimeXML is standard json mime. 49 | MimeXML = "application/xml" 50 | // MimeHTML is standard html mime. 51 | MimeHTML = "text/html" 52 | // MimePlainText is standard plain text mime. 53 | MimePlainText = "text/plain" 54 | // MimeMultipartForm is standard multipart form mime. 55 | MimeMultipartForm = "multipart/form-data" 56 | // MimeFormURLEncoded is standard urlencoded form mime. 57 | MimeFormURLEncoded = "application/x-www-form-urlencoded" 58 | ) 59 | 60 | var ( 61 | json = jsontime.ConfigWithCustomTimeFormat 62 | // ErrDefaultHandler should be returned when user try to set default handler for seconds time. 63 | ErrDefaultHandler = errors.New("default handler already registered") 64 | ) 65 | 66 | // Engine defines nano web engine. 67 | type Engine struct { 68 | *RouterGroup 69 | router *router 70 | debug bool 71 | groups []*RouterGroup 72 | } 73 | 74 | // RouterGroup defines collection of route that has same prefix 75 | type RouterGroup struct { 76 | prefix string 77 | engine *Engine 78 | middlewares []HandlerFunc 79 | parent *RouterGroup 80 | } 81 | 82 | // H defines json wrapper. 83 | type H map[string]interface{} 84 | 85 | // HandlerFunc defines nano request handler function signature. 86 | type HandlerFunc func(c *Context) 87 | 88 | // New is nano constructor 89 | func New() *Engine { 90 | engine := &Engine{ 91 | router: newRouter(), 92 | debug: false, 93 | } 94 | 95 | engine.RouterGroup = &RouterGroup{engine: engine} 96 | engine.groups = []*RouterGroup{engine.RouterGroup} 97 | 98 | return engine 99 | } 100 | 101 | // Use functions to apply middleware function(s). 102 | func (rg *RouterGroup) Use(middlewares ...HandlerFunc) { 103 | rg.middlewares = append(rg.middlewares, middlewares...) 104 | } 105 | 106 | // Group functions to create new router group. 107 | func (rg *RouterGroup) Group(prefix string) *RouterGroup { 108 | group := &RouterGroup{ 109 | prefix: rg.prefix + prefix, 110 | parent: rg, 111 | engine: rg.engine, 112 | } 113 | 114 | rg.engine.groups = append(rg.engine.groups, group) 115 | 116 | return group 117 | } 118 | 119 | // HEAD functions to register route with HEAD request method. 120 | func (rg *RouterGroup) HEAD(urlPattern string, handler ...HandlerFunc) { 121 | rg.addRoute(http.MethodHead, urlPattern, handler...) 122 | } 123 | 124 | // GET functions to register route with GET request method. 125 | func (rg *RouterGroup) GET(urlPattern string, handler ...HandlerFunc) { 126 | rg.addRoute(http.MethodGet, urlPattern, handler...) 127 | } 128 | 129 | // POST functions to register route with POST request method. 130 | func (rg *RouterGroup) POST(urlPattern string, handler ...HandlerFunc) { 131 | rg.addRoute(http.MethodPost, urlPattern, handler...) 132 | } 133 | 134 | // PUT functions to register route with PUT request method. 135 | func (rg *RouterGroup) PUT(urlPattern string, handler ...HandlerFunc) { 136 | rg.addRoute(http.MethodPut, urlPattern, handler...) 137 | } 138 | 139 | // OPTIONS functions to register route with OPTIONS request method. 140 | func (rg *RouterGroup) OPTIONS(urlPattern string, handler ...HandlerFunc) { 141 | rg.addRoute(http.MethodOptions, urlPattern, handler...) 142 | } 143 | 144 | // PATCH functions to register route with PATCH request method. 145 | func (rg *RouterGroup) PATCH(urlPattern string, handler ...HandlerFunc) { 146 | rg.addRoute(http.MethodPatch, urlPattern, handler...) 147 | } 148 | 149 | // DELETE functions to register route with DELETE request method. 150 | func (rg *RouterGroup) DELETE(urlPattern string, handler ...HandlerFunc) { 151 | rg.addRoute(http.MethodDelete, urlPattern, handler...) 152 | } 153 | 154 | // Default functions to register default handler when no matching routes. 155 | // Only one Default handler allowed to register. 156 | func (rg *RouterGroup) Default(handler HandlerFunc) error { 157 | // reject overriding. 158 | if rg.engine.router.defaultHandler != nil { 159 | return ErrDefaultHandler 160 | } 161 | 162 | rg.engine.router.defaultHandler = handler 163 | return nil 164 | } 165 | 166 | // Static creates static file server. 167 | func (rg *RouterGroup) Static(baseURL string, rootDir http.FileSystem) { 168 | if strings.Contains(baseURL, ":") || strings.Contains(baseURL, "*") { 169 | panic("cannot use dynamic url parameter in file server base url") 170 | } 171 | 172 | urlPattern := baseURL + "/*filepath" 173 | handler := fileServerHandler(rg.prefix, baseURL, rootDir) 174 | rg.GET(urlPattern, handler) 175 | rg.HEAD(urlPattern, handler) 176 | } 177 | 178 | // addRoute functions to register new route with current group prefix. 179 | func (rg *RouterGroup) addRoute(requestMethod, urlPattern string, handler ...HandlerFunc) { 180 | // append router group prefix. 181 | prefixedURLPattern := rg.prefix + urlPattern 182 | 183 | rg.engine.router.addRoute(requestMethod, prefixedURLPattern, handler...) 184 | } 185 | 186 | // ServeHTTP implements multiplexer. 187 | func (ng *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) { 188 | middlewares := make([]HandlerFunc, 0) 189 | 190 | // scanning for router group middleware. 191 | for _, group := range ng.groups { 192 | if strings.HasPrefix(r.URL.Path, group.prefix) { 193 | middlewares = append(middlewares, group.middlewares...) 194 | } 195 | } 196 | 197 | ctx := newContext(w, r) 198 | ctx.handlers = middlewares 199 | ng.router.handle(ctx) 200 | } 201 | 202 | // Run application. 203 | func (ng *Engine) Run(address string) error { 204 | return http.ListenAndServe(address, ng) 205 | } 206 | -------------------------------------------------------------------------------- /router_test.go: -------------------------------------------------------------------------------- 1 | package nano 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | ) 9 | 10 | func TestCreateURLParts(t *testing.T) { 11 | tt := []struct { 12 | name string 13 | url string 14 | slices []string 15 | }{ 16 | {"root url", "/", []string{}}, 17 | {"one url part", "/home", []string{"home"}}, 18 | {"one url part without backslash prefix", "home", []string{"home"}}, 19 | {"one url part with backslash suffix", "home/", []string{"home"}}, 20 | {"two url parts", "/home/services", []string{"home", "services"}}, 21 | {"two url parts without backslash prefix", "home/services", []string{"home", "services"}}, 22 | {"two url parts with backslash suffix", "home/services/", []string{"home", "services"}}, 23 | {"three url parts with star wildcard", "/downloads/*file", []string{"downloads", "*file"}}, 24 | } 25 | 26 | for _, tc := range tt { 27 | t.Run(tc.name, func(st *testing.T) { 28 | rs := createURLParts(tc.url) 29 | 30 | if ln := len(rs); ln != len(tc.slices) { 31 | st.Errorf("expected result length to be %d; got %d", len(tc.slices), ln) 32 | } 33 | 34 | if len(tc.slices) > 0 { 35 | for i, urlPart := range rs { 36 | if urlPart != tc.slices[i] { 37 | st.Errorf("expected url part at index %d to be %s; got %s", i, tc.slices[i], urlPart) 38 | } 39 | } 40 | } 41 | }) 42 | } 43 | } 44 | 45 | func TestCreateRoute(t *testing.T) { 46 | router := newRouter() 47 | 48 | if handlersLen := len(router.handlers); handlersLen != 0 { 49 | t.Fatalf("expected num of handlers to be 0; got %d", handlersLen) 50 | } 51 | 52 | if nodesLen := len(router.nodes); nodesLen != 0 { 53 | t.Fatalf("expected num of nodes to be 0; got %d", nodesLen) 54 | } 55 | } 56 | 57 | func TestAddRoute(t *testing.T) { 58 | r := newRouter() 59 | 60 | t.Run("existence route", func(st *testing.T) { 61 | emptyHandler := func(c *Context) {} 62 | 63 | tt := []struct { 64 | method string 65 | path string 66 | key string 67 | }{ 68 | {http.MethodGet, "/", "GET-/"}, 69 | {http.MethodGet, "/about", "GET-/about"}, 70 | {http.MethodGet, "/downloads/*", "GET-/downloads/*"}, 71 | {http.MethodPost, "/articles/:id", "POST-/articles/:id"}, 72 | } 73 | 74 | for _, tc := range tt { 75 | r.addRoute(tc.method, tc.path, emptyHandler) 76 | 77 | if _, ok := r.handlers[tc.key]; !ok { 78 | st.Errorf("expected key %s to be exists in handlers", tc.key) 79 | } 80 | } 81 | }) 82 | 83 | t.Run("handler count", func(st *testing.T) { 84 | firstHandler := func(c *Context) {} 85 | secondHandler := func(c *Context) {} 86 | r.addRoute(http.MethodGet, "/", firstHandler, secondHandler) 87 | 88 | route, ok := r.handlers["GET-/"] 89 | 90 | if !ok { 91 | st.Fatalf("expected route GET-/ to found; got not found") 92 | } 93 | 94 | if handlerCount := len(route); handlerCount != 2 { 95 | st.Errorf("expected handler count to be 2; got %d", handlerCount) 96 | } 97 | }) 98 | 99 | } 100 | 101 | func TestFindRoute(t *testing.T) { 102 | r := newRouter() 103 | 104 | emptyHandler := func(c *Context) {} 105 | 106 | tt := []struct { 107 | name string 108 | method string 109 | urlPattern string 110 | requestedURL string 111 | params map[string]string 112 | }{ 113 | {"root url", http.MethodGet, "/", "/", map[string]string{}}, 114 | {"one parameter", http.MethodGet, "/users/:id", "users/1", map[string]string{"id": "1"}}, 115 | {"one parameter with static path on last url", http.MethodGet, "/users/:id/about", "users/1/about", map[string]string{"id": "1"}}, 116 | {"two parameter", http.MethodGet, "/users/:id/about/:section", "users/1/about/jobs", map[string]string{"id": "1", "section": "jobs"}}, 117 | } 118 | 119 | for _, tc := range tt { 120 | 121 | t.Run(tc.name, func(st *testing.T) { 122 | r.addRoute(tc.method, tc.urlPattern, emptyHandler) 123 | node, params := r.findRoute(tc.method, tc.requestedURL) 124 | if node == nil { 125 | st.Errorf("expected route to be found; got not found") 126 | } 127 | 128 | if node.urlPattern != tc.urlPattern { 129 | st.Errorf("expected found url to be %s; got %s", tc.urlPattern, node.urlPattern) 130 | } 131 | 132 | if paramsLen := len(params); paramsLen != len(tc.params) { 133 | st.Errorf("expected params length to be %d; got %d", len(tc.params), paramsLen) 134 | } 135 | 136 | for key, param := range params { 137 | if param != tc.params[key] { 138 | st.Errorf("expected param %s to be %s; got %s", key, tc.params[key], param) 139 | } 140 | } 141 | }) 142 | } 143 | } 144 | 145 | func TestDefaultRouteHandler(t *testing.T) { 146 | r := newRouter() 147 | 148 | if r.defaultHandler != nil { 149 | t.Fatalf("expected default handler to be nil; got %T", r.defaultHandler) 150 | } 151 | 152 | tt := []struct { 153 | name string 154 | method string 155 | url string 156 | responseCode int 157 | responseText string 158 | useCustomDefault bool // if it's true, this will modify default route handler using defaultCode & defaultText. 159 | defaultCode int 160 | defaultText string 161 | }{ 162 | {"nano default route handler", http.MethodGet, "/", http.StatusNotFound, "nano/1.0 not found", false, 0, ""}, 163 | {"set custom default route handler", http.MethodGet, "/", http.StatusOK, "it's works", true, http.StatusOK, "it's works"}, 164 | } 165 | 166 | for _, tc := range tt { 167 | t.Run(tc.name, func(st *testing.T) { 168 | rec := httptest.NewRecorder() 169 | req, err := http.NewRequest(tc.method, tc.url, nil) 170 | if err != nil { 171 | log.Fatalf("could not create http request: %v", err) 172 | } 173 | 174 | if tc.useCustomDefault { 175 | r.defaultHandler = func(c *Context) { 176 | c.String(tc.defaultCode, tc.defaultText) 177 | } 178 | } 179 | 180 | ctx := newContext(rec, req) 181 | r.serveDefaultHandler(ctx) 182 | 183 | if code := rec.Code; code != tc.responseCode { 184 | st.Fatalf("expected response code to be %d; got %d", tc.responseCode, code) 185 | } 186 | 187 | if body := rec.Body.String(); body != tc.responseText { 188 | st.Errorf("expected default handle response to be %s got %s", tc.responseText, body) 189 | } 190 | }) 191 | } 192 | } 193 | 194 | func TestHandle(t *testing.T) { 195 | r := newRouter() 196 | r.addRoute(http.MethodGet, "/hello/:name", func(c *Context) { 197 | c.String(http.StatusOK, "hello %s", c.Param("name")) 198 | }) 199 | r.addRoute(http.MethodGet, "/d/*path", func(c *Context) { 200 | c.String(http.StatusOK, "downloading %s", c.Param("path")) 201 | }) 202 | 203 | tt := []struct { 204 | name string 205 | method string 206 | url string 207 | responseCode int 208 | responseText string 209 | }{ 210 | {"not found handler", http.MethodGet, "/unregistered/path", http.StatusNotFound, "nano/1.0 not found"}, 211 | {"not found on exist path but wrong method", http.MethodPost, "/hello/foo", http.StatusNotFound, "nano/1.0 not found"}, 212 | {"echo parameter", http.MethodGet, "/hello/foo", http.StatusOK, "hello foo"}, 213 | {"echo asterisk wildcard parameter", http.MethodGet, "/d/static/app.js", http.StatusOK, "downloading static/app.js"}, 214 | } 215 | 216 | for _, tc := range tt { 217 | t.Run(tc.name, func(st *testing.T) { 218 | req, err := http.NewRequest(tc.method, tc.url, nil) 219 | if err != nil { 220 | log.Fatalf("could not create http request: %v", err) 221 | } 222 | 223 | rec := httptest.NewRecorder() 224 | ctx := newContext(rec, req) 225 | r.handle(ctx) 226 | 227 | if code := rec.Code; code != tc.responseCode { 228 | st.Fatalf("expected response code to be %d; got %d", tc.responseCode, code) 229 | } 230 | 231 | if body := rec.Body.String(); body != tc.responseText { 232 | st.Errorf("expected %s as response text; got %v", tc.responseText, body) 233 | } 234 | }) 235 | } 236 | } 237 | -------------------------------------------------------------------------------- /cors.go: -------------------------------------------------------------------------------- 1 | package nano 2 | 3 | // This cross-origin sharing standard is used to enable cross-site HTTP requests for: 4 | 5 | // Invocations of the XMLHttpRequest or Fetch APIs in a cross-site manner, as discussed above. 6 | // Web Fonts (for cross-domain font usage in @font-face within CSS), so that servers can deploy TrueType fonts that can only be cross-site loaded and used by web sites that are permitted to do so. 7 | // WebGL textures. 8 | // Images/video frames drawn to a canvas using drawImage(). 9 | import ( 10 | "net/http" 11 | "strings" 12 | ) 13 | 14 | // CORSConfig define nano cors middleware configuration. 15 | type CORSConfig struct { 16 | AllowedOrigins []string 17 | AllowedMethods []string 18 | AllowedHeaders []string 19 | } 20 | 21 | // CORS struct. 22 | type CORS struct { 23 | allowedOrigins []string 24 | allowedMethods []string 25 | allowedHeaders []string 26 | } 27 | 28 | // parseRequestHeader splits header string to array of headers. 29 | func parseRequestHeader(header string) []string { 30 | 31 | // request does not provide field Access-Control-Request-Header. 32 | if header == "" { 33 | return []string{} 34 | } 35 | 36 | // only requested one header. 37 | if !strings.Contains(header, ",") { 38 | return []string{header} 39 | } 40 | 41 | result := make([]string, 0) 42 | 43 | for _, part := range strings.Split(header, ",") { 44 | result = append(result, strings.Trim(part, " ")) 45 | } 46 | 47 | return result 48 | } 49 | 50 | // SetAllowedOrigins functions to fill/replace all allowed origins. 51 | func (cors *CORS) SetAllowedOrigins(origins []string) { 52 | cors.allowedOrigins = origins 53 | } 54 | 55 | // SetAllowedMethods functions to fill/replace all allowed methods. 56 | func (cors *CORS) SetAllowedMethods(methods []string) { 57 | cors.allowedMethods = methods 58 | } 59 | 60 | // SetAllowedHeaders functions to fill/replace all allowed headers. 61 | func (cors *CORS) SetAllowedHeaders(headers []string) { 62 | cors.allowedHeaders = headers 63 | } 64 | 65 | // AddAllowedHeader functions to append method to allowed list. 66 | func (cors *CORS) AddAllowedHeader(header string) { 67 | cors.allowedHeaders = append(cors.allowedHeaders, header) 68 | } 69 | 70 | // AddAllowedMethod functions to append method to allowed list. 71 | func (cors *CORS) AddAllowedMethod(method string) { 72 | cors.allowedMethods = append(cors.allowedMethods, method) 73 | } 74 | 75 | // AddAllowedOrigin appends method to allowed list. 76 | func (cors *CORS) AddAllowedOrigin(origin string) { 77 | cors.allowedOrigins = append(cors.allowedOrigins, origin) 78 | } 79 | 80 | // isAllowAllOrigin returns true when there is * wildcrad in the origin list. 81 | func (cors *CORS) isAllowAllOrigin() bool { 82 | for _, origin := range cors.allowedOrigins { 83 | if origin == "*" { 84 | return true 85 | } 86 | } 87 | 88 | return false 89 | } 90 | 91 | // isOriginAllowed returns true when origin found in allowed origin list. 92 | func (cors *CORS) isOriginAllowed(requestOrigin string) bool { 93 | for _, origin := range cors.allowedOrigins { 94 | if origin == requestOrigin || origin == "*" { 95 | return true 96 | } 97 | } 98 | 99 | return false 100 | } 101 | 102 | // isMethodAllowed returns true when method found in allowed method list. 103 | func (cors *CORS) isMethodAllowed(requestMethod string) bool { 104 | for _, method := range cors.allowedMethods { 105 | if method == requestMethod { 106 | return true 107 | } 108 | } 109 | 110 | return false 111 | } 112 | 113 | // mergeMethods functions to stringify the allowed method list. 114 | func (cors *CORS) mergeMethods() string { 115 | // when there is found * wildcard in the list, so just return it. 116 | for _, method := range cors.allowedMethods { 117 | if method == "*" { 118 | return method 119 | } 120 | } 121 | 122 | return strings.Join(cors.allowedMethods, ", ") 123 | } 124 | 125 | // isAllHeaderAllowed returns true when there is * wildcrad in the allowed header list. 126 | func (cors *CORS) isAllHeaderAllowed() bool { 127 | for _, header := range cors.allowedHeaders { 128 | if header == "*" { 129 | return true 130 | } 131 | } 132 | 133 | return false 134 | } 135 | 136 | // areHeadersAllowed checks if all requested headers are allowed 137 | func (cors *CORS) areHeadersAllowed(requestedHeaders []string) bool { 138 | // alway return true if there is no control header. 139 | if cors.isAllHeaderAllowed() { 140 | return true 141 | } 142 | 143 | for _, requestedHeader := range requestedHeaders { 144 | allowed := false 145 | 146 | for _, allowedHeader := range cors.allowedHeaders { 147 | if allowedHeader == requestedHeader { 148 | allowed = true 149 | } 150 | } 151 | 152 | if !allowed { 153 | return false 154 | } 155 | } 156 | 157 | return true 158 | } 159 | 160 | // handlePrefilghtRequest handles cross-origin preflight request. 161 | func (cors *CORS) handlePrefilghtRequest(c *Context) { 162 | if c.Origin == "" { 163 | return 164 | } 165 | 166 | if !cors.isOriginAllowed(c.Origin) { 167 | return 168 | } 169 | 170 | requestedMethod := c.GetRequestHeader(HeaderAccessControlRequestMethod) 171 | if !cors.isMethodAllowed(requestedMethod) { 172 | return 173 | } 174 | 175 | requestedHeader := c.GetRequestHeader(HeaderAccessControlRequestHeader) 176 | requestedHeaders := parseRequestHeader(requestedHeader) 177 | 178 | if len(requestedHeaders) > 0 { 179 | if !cors.areHeadersAllowed(requestedHeaders) { 180 | return 181 | } 182 | } 183 | 184 | // vary must be set. 185 | c.SetHeader(HeaderVary, "Origin, Access-Control-Request-Methods, Access-Control-Request-Header") 186 | 187 | if cors.isAllowAllOrigin() { 188 | c.SetHeader(HeaderAccessControlAllowOrigin, "*") 189 | } else { 190 | c.SetHeader(HeaderAccessControlAllowOrigin, c.Origin) 191 | } 192 | 193 | c.SetHeader(HeaderAccessControlAllowMethods, cors.mergeMethods()) 194 | 195 | if len(requestedHeader) > 0 { 196 | c.SetHeader(HeaderAccessControlAllowHeader, requestedHeader) 197 | } 198 | } 199 | 200 | // handleSimpleRequest handles simple cross origin request 201 | func (cors *CORS) handleSimpleRequest(c *Context) { 202 | if c.Origin == "" { 203 | return 204 | } 205 | 206 | if !cors.isOriginAllowed(c.Origin) { 207 | return 208 | } 209 | 210 | // vary must be set. 211 | c.SetHeader(HeaderVary, HeaderOrigin) 212 | 213 | if cors.isAllowAllOrigin() { 214 | c.SetHeader(HeaderAccessControlAllowOrigin, "*") 215 | } else { 216 | c.SetHeader(HeaderAccessControlAllowOrigin, c.Origin) 217 | } 218 | } 219 | 220 | // Handle corss-origin request 221 | // The Cross-Origin Resource Sharing standard works by adding new HTTP headers that allow servers 222 | // to describe the set of origins that are permitted to read that information using a web browser. 223 | // Additionally, for HTTP request methods that can cause side-effects on server's data 224 | // (in particular, for HTTP methods other than GET, or for POST usage with certain MIME types), 225 | // the specification mandates that browsers "preflight" the request, 226 | // soliciting supported methods from the server with an HTTP OPTIONS request method, 227 | // and then, upon "approval" from the server, sending the actual request with the actual HTTP request method. 228 | // Servers can also notify clients whether "credentials" (including Cookies and HTTP Authentication data) should be sent with requests. 229 | func (cors *CORS) Handle(c *Context) { 230 | // preflighted requests first send an HTTP request by the OPTIONS method to the resource on the other domain, 231 | // in order to determine whether the actual request is safe to send. 232 | // Cross-site requests are preflighted like this since they may have implications to user data. 233 | if c.Method == http.MethodOptions && c.GetRequestHeader(HeaderAccessControlRequestMethod) != "" { 234 | cors.handlePrefilghtRequest(c) 235 | return 236 | } 237 | 238 | // Some requests don’t trigger a CORS preflight. Those are called “simple requests”, 239 | // though the Fetch spec (which defines CORS) doesn’t use that term. 240 | // A request that doesn’t trigger a CORS preflight—a so-called “simple request” 241 | cors.handleSimpleRequest(c) 242 | 243 | c.Next() 244 | } 245 | 246 | // CORSWithConfig returns cors middleware. 247 | func CORSWithConfig(config CORSConfig) HandlerFunc { 248 | 249 | cors := new(CORS) 250 | 251 | // create default value for all configuration field. 252 | // default value is allowed for all origin, methods, and headers. 253 | if len(config.AllowedMethods) == 0 { 254 | config.AllowedMethods = []string{http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodGet} 255 | } 256 | 257 | if len(config.AllowedOrigins) == 0 { 258 | config.AllowedOrigins = []string{"*"} 259 | } 260 | 261 | if len(config.AllowedHeaders) == 0 { 262 | config.AllowedHeaders = []string{"*"} 263 | } 264 | 265 | cors.SetAllowedMethods(config.AllowedMethods) 266 | cors.SetAllowedOrigins(config.AllowedOrigins) 267 | cors.SetAllowedHeaders(config.AllowedHeaders) 268 | 269 | return cors.Handle 270 | } 271 | -------------------------------------------------------------------------------- /binding.go: -------------------------------------------------------------------------------- 1 | package nano 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | "reflect" 8 | "strconv" 9 | "strings" 10 | ) 11 | 12 | // ErrBinding defines an error interface implementation and it will returned when binding failed. 13 | // Status will set to 422 when there is error on validation, 14 | // 400 when client sent unsupported/without Content-Type header, and 15 | // 500 when targetStruct is not pointer or type conversion is fail. 16 | type ErrBinding struct { 17 | Status int 18 | Text string 19 | Fields []string 20 | } 21 | 22 | var ( 23 | // ErrBindNonPointer must be returned when non-pointer struct passed as targetStruct parameter. 24 | ErrBindNonPointer = ErrBinding{ 25 | Text: "expected pointer to target struct, got non-pointer", 26 | Status: http.StatusInternalServerError, 27 | } 28 | 29 | // ErrBindContentType returned when client content type besides json, urlencoded, & multipart form. 30 | ErrBindContentType = ErrBinding{ 31 | Status: http.StatusBadRequest, 32 | Text: "unknown content type of request body", 33 | } 34 | ) 35 | 36 | // Error implements error interface. 37 | func (e ErrBinding) Error() string { 38 | if len(e.Fields) > 0 { 39 | return e.Text + " " + strings.Join(e.Fields, ",") 40 | } 41 | 42 | return e.Text 43 | } 44 | 45 | // Bind request body into defined user struct. 46 | // This function help you to automatic binding based on request Content-Type & request method. 47 | // If you want to chooose binding method manually, you could use : 48 | // BindSimpleForm to bind urlencoded form & url query, 49 | // BindMultipartForm to bind multipart/form data, 50 | // and BindJSON to bind application/json request body. 51 | func (c *Context) Bind(targetStruct interface{}) error { 52 | contentType := c.GetRequestHeader(HeaderContentType) 53 | 54 | // if client request using POST, PUT, & PATCH we will try to bind request using simple form (urlencoded & url query), 55 | // multipart form, and JSON. if you need both binding e.g. to bind multipart form & url query, 56 | // this method doesn't works. you should call BindSimpleForm & BindMultipartForm manually from your handler. 57 | if c.Method == http.MethodPost || c.Method == http.MethodPut || c.Method == http.MethodPatch || contentType != "" { 58 | if strings.Contains(contentType, MimeFormURLEncoded) { 59 | return c.BindSimpleForm(targetStruct) 60 | } 61 | 62 | if strings.Contains(contentType, MimeMultipartForm) { 63 | return c.BindMultipartForm(targetStruct) 64 | } 65 | 66 | if c.IsJSON() { 67 | return c.BindJSON(targetStruct) 68 | } 69 | 70 | return ErrBindContentType 71 | } 72 | 73 | // when client request using GET method, we will serve binding using simple form. 74 | // it's can binding url-encoded form & url query data. 75 | return c.BindSimpleForm(targetStruct) 76 | } 77 | 78 | // BindJSON functions to bind request body (with contet type application/json) to targetStruct. 79 | // targetStruct must be pointer to user defined struct. 80 | func (c *Context) BindJSON(targetStruct interface{}) error { 81 | // only accept pointer 82 | if reflect.TypeOf(targetStruct).Kind() != reflect.Ptr { 83 | return ErrBindNonPointer 84 | } 85 | 86 | if c.Request.Body != nil { 87 | defer c.Request.Body.Close() 88 | err := json.NewDecoder(c.Request.Body).Decode(targetStruct) 89 | if err != nil && err != io.EOF { 90 | return ErrBinding{ 91 | Text: err.Error(), 92 | Status: http.StatusBadRequest, 93 | } 94 | } 95 | } 96 | 97 | return validate(c, targetStruct) 98 | } 99 | 100 | // BindSimpleForm functions to bind request body (with content type form-urlencoded or url query) to targetStruct. 101 | // targetStruct must be pointer to user defined struct. 102 | func (c *Context) BindSimpleForm(targetStruct interface{}) error { 103 | // only accept pointer 104 | if reflect.TypeOf(targetStruct).Kind() != reflect.Ptr { 105 | return ErrBinding{ 106 | Text: "expected pointer to target struct, got non-pointer", 107 | Status: http.StatusInternalServerError, 108 | } 109 | } 110 | 111 | if err := c.Request.ParseForm(); err != nil { 112 | return ErrBinding{ 113 | Text: fmt.Sprintf("could not parsing form body: %v", err), 114 | Status: http.StatusInternalServerError, 115 | } 116 | } 117 | 118 | if err := bindForm(c.Request.Form, targetStruct); err != nil { 119 | return ErrBinding{ 120 | Status: http.StatusInternalServerError, 121 | Text: fmt.Sprintf("binding error: %v", err), 122 | } 123 | } 124 | 125 | return validate(c, targetStruct) 126 | } 127 | 128 | // BindMultipartForm functions to bind request body (with contet type multipart/form-data) to targetStruct. 129 | // targetStruct must be pointer to user defined struct. 130 | func (c *Context) BindMultipartForm(targetStruct interface{}) error { 131 | // only accept pointer 132 | if reflect.TypeOf(targetStruct).Kind() != reflect.Ptr { 133 | return ErrBinding{ 134 | Text: "expected pointer to target struct, got non-pointer", 135 | Status: http.StatusInternalServerError, 136 | } 137 | } 138 | 139 | err := c.Request.ParseMultipartForm(16 << 10) 140 | if err != nil { 141 | return ErrBinding{ 142 | Text: fmt.Sprintf("could not parsing form body: %v", err), 143 | Status: http.StatusBadRequest, 144 | } 145 | } 146 | 147 | err = bindForm(c.Request.MultipartForm.Value, targetStruct) 148 | if err != nil { 149 | return ErrBinding{ 150 | Status: http.StatusInternalServerError, 151 | Text: fmt.Sprintf("binding error: %v", err), 152 | } 153 | } 154 | 155 | return validate(c, targetStruct) 156 | } 157 | 158 | // bindForm maps each field in request body into targetStruct. 159 | func bindForm(form map[string][]string, targetStruct interface{}) error { 160 | targetPtr := reflect.ValueOf(targetStruct).Elem() 161 | targetType := targetPtr.Type() 162 | 163 | // only accept struct as target binding 164 | if targetPtr.Kind() != reflect.Struct { 165 | return fmt.Errorf("expected target binding to be struct") 166 | } 167 | 168 | for i := 0; i < targetPtr.NumField(); i++ { 169 | fieldValue := targetPtr.Field(i) 170 | // this is used to get field tag. 171 | fieldType := targetType.Field(i) 172 | 173 | // continue iteration when field is not settable. 174 | if !fieldValue.CanSet() { 175 | continue 176 | } 177 | 178 | // check if current field nested struct. 179 | // this is possible when current request body is json type. 180 | if fieldValue.Kind() == reflect.Struct { 181 | // bind recursively. 182 | err := bindForm(form, fieldValue.Addr().Interface()) 183 | if err != nil { 184 | return err 185 | } 186 | } else { 187 | // web use tag "form" as field name in request body. 188 | // so make sure you have matching name at field name in request body and field tag in your target struct 189 | formFieldName := fieldType.Tag.Get("form") 190 | // continue iteration when field doesnt have form tag. 191 | if formFieldName == "" { 192 | continue 193 | } 194 | 195 | formValue, exists := form[formFieldName] 196 | // could not find value in request body, let it empty 197 | if !exists { 198 | continue 199 | } 200 | 201 | formValueCount := len(formValue) 202 | // it's possible if current field value is an array. 203 | if fieldValue.Kind() == reflect.Slice && formValueCount > 0 { 204 | sliceKind := fieldValue.Type().Elem().Kind() 205 | slice := reflect.MakeSlice(fieldValue.Type(), formValueCount, formValueCount) 206 | for i := 0; i < formValueCount; i++ { 207 | if err := setFieldValue(sliceKind, formValue[i], slice.Index(i)); err != nil { 208 | return err 209 | } 210 | } 211 | fieldValue.Field(i).Set(slice) 212 | } else { 213 | // it's a single value. just do direct set. 214 | if err := setFieldValue(fieldValue.Kind(), formValue[0], fieldValue); err != nil { 215 | return err 216 | } 217 | } 218 | } 219 | } 220 | 221 | return nil 222 | } 223 | 224 | // setFieldValue sets field with typed value. 225 | // we will find the best type & size for your field value. 226 | // if empty string provided to value parameter, we will use zero type value as default field value. 227 | func setFieldValue(kind reflect.Kind, value string, fieldValue reflect.Value) error { 228 | switch kind { 229 | case reflect.Int: 230 | setIntField(value, 0, fieldValue) 231 | case reflect.Int8: 232 | setIntField(value, 8, fieldValue) 233 | case reflect.Int16: 234 | setIntField(value, 16, fieldValue) 235 | case reflect.Int32: 236 | setIntField(value, 32, fieldValue) 237 | case reflect.Int64: 238 | setIntField(value, 64, fieldValue) 239 | case reflect.Uint: 240 | setUintField(value, 0, fieldValue) 241 | case reflect.Uint8: 242 | setUintField(value, 8, fieldValue) 243 | case reflect.Uint16: 244 | setUintField(value, 16, fieldValue) 245 | case reflect.Uint32: 246 | setUintField(value, 32, fieldValue) 247 | case reflect.Uint64: 248 | setUintField(value, 64, fieldValue) 249 | case reflect.Bool: 250 | setBoolField(value, fieldValue) 251 | case reflect.Float32: 252 | setFloatField(value, 32, fieldValue) 253 | case reflect.Float64: 254 | setFloatField(value, 64, fieldValue) 255 | case reflect.String: 256 | // no conversion needed. because value already a string. 257 | fieldValue.SetString(value) 258 | default: 259 | // whoopss.. 260 | return fmt.Errorf("unknown type") 261 | } 262 | return nil 263 | } 264 | 265 | // setIntField converts input string (value) into integer. 266 | func setIntField(value string, size int, field reflect.Value) { 267 | convertedValue, err := strconv.ParseInt(value, 10, size) 268 | // set default empty value when conversion. 269 | if err != nil { 270 | convertedValue = 0 271 | } 272 | field.SetInt(convertedValue) 273 | } 274 | 275 | // setUintField converts input string (value) into unsigned integer. 276 | func setUintField(value string, size int, field reflect.Value) { 277 | convertedValue, err := strconv.ParseUint(value, 10, size) 278 | // set default empty value when conversion. 279 | if err != nil { 280 | convertedValue = 0 281 | } 282 | field.SetUint(convertedValue) 283 | } 284 | 285 | // setBoolField converts input string (value) into boolean. 286 | func setBoolField(value string, field reflect.Value) { 287 | convertedValue, err := strconv.ParseBool(value) 288 | // set default empty value when conversion. 289 | if err != nil { 290 | convertedValue = false 291 | } 292 | field.SetBool(convertedValue) 293 | } 294 | 295 | // setFloatField converts input string (value) into floating. 296 | func setFloatField(value string, size int, field reflect.Value) { 297 | convertedValue, err := strconv.ParseFloat(value, size) 298 | // set default empty value when conversion. 299 | if err != nil { 300 | convertedValue = 0.0 301 | } 302 | field.SetFloat(convertedValue) 303 | } 304 | -------------------------------------------------------------------------------- /context_test.go: -------------------------------------------------------------------------------- 1 | package nano 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "log" 7 | "net/http" 8 | "net/http/httptest" 9 | "strings" 10 | "testing" 11 | ) 12 | 13 | func TestCreateNewContext(t *testing.T) { 14 | path := "/hello" 15 | method := http.MethodGet 16 | req, err := http.NewRequest(method, path, nil) 17 | if err != nil { 18 | log.Fatalf("could not make http request: %v", err) 19 | } 20 | 21 | rec := httptest.NewRecorder() 22 | ctx := newContext(rec, req) 23 | 24 | if ctx.Path != path { 25 | t.Errorf("expected path to be %s; got %s", path, ctx.Path) 26 | } 27 | 28 | if ctx.Method != method { 29 | t.Errorf("expected method to be %s; got %s", method, ctx.Method) 30 | } 31 | 32 | if ctx.cursor != -1 { 33 | t.Errorf("expected cursor to be -1; got %d", ctx.cursor) 34 | } 35 | } 36 | 37 | func TestNext(t *testing.T) { 38 | req, err := http.NewRequest(http.MethodGet, "/", nil) 39 | if err != nil { 40 | log.Fatalf("could not make http request: %v", err) 41 | } 42 | rec := httptest.NewRecorder() 43 | ctx := newContext(rec, req) 44 | 45 | emptyHandler := func(c *Context) { 46 | c.Next() 47 | } 48 | helloHandler := func(c *Context) { 49 | c.String(http.StatusOK, "ok") 50 | } 51 | 52 | r := newRouter() 53 | r.addRoute(http.MethodGet, "/", emptyHandler, emptyHandler, emptyHandler, helloHandler, emptyHandler) 54 | r.handle(ctx) 55 | 56 | if ctx.cursor != 3 { 57 | t.Fatalf("expected stack cursor to be 3; got %d", ctx.cursor) 58 | } 59 | 60 | if hlen := len(ctx.handlers); hlen != 5 { 61 | t.Errorf("expected total handler to be 5; got %d", hlen) 62 | } 63 | } 64 | 65 | func TestStatusCode(t *testing.T) { 66 | r := newRouter() 67 | r.addRoute(http.MethodGet, "/", func(c *Context) { 68 | c.Status(http.StatusNoContent) 69 | }) 70 | 71 | req, err := http.NewRequest(http.MethodGet, "/", nil) 72 | if err != nil { 73 | log.Fatalf("could not make http request: %v", err) 74 | } 75 | rec := httptest.NewRecorder() 76 | ctx := newContext(rec, req) 77 | 78 | r.handle(ctx) 79 | 80 | if rec.Code != http.StatusNoContent { 81 | t.Errorf("expected status code to be %d; got %d", http.StatusNoContent, rec.Code) 82 | } 83 | } 84 | 85 | func TestSetHeader(t *testing.T) { 86 | headers := map[string]string{ 87 | "X-Powered-By": "nano/1.1", 88 | "X-Foo": "Bar,Baz", 89 | } 90 | 91 | r := newRouter() 92 | r.addRoute(http.MethodGet, "/", func(c *Context) { 93 | for key, val := range headers { 94 | c.SetHeader(key, val) 95 | } 96 | 97 | c.String(http.StatusOK, "ok") 98 | }) 99 | 100 | req, err := http.NewRequest(http.MethodGet, "/", nil) 101 | if err != nil { 102 | log.Fatalf("could not make http request: %v", err) 103 | } 104 | rec := httptest.NewRecorder() 105 | ctx := newContext(rec, req) 106 | 107 | r.handle(ctx) 108 | 109 | for key, val := range headers { 110 | if head := rec.Header().Get(key); head != val { 111 | t.Errorf("expected header %s to be %s; got %s", key, val, head) 112 | } 113 | } 114 | } 115 | 116 | // GetRequestHeader returns header value by given key. 117 | func TestGetRequestHeader(t *testing.T) { 118 | reqHeaders := [2]struct { 119 | Key string 120 | Value string 121 | }{ 122 | {HeaderContentType, MimeJSON}, 123 | {HeaderAccept, MimeJSON}, 124 | } 125 | 126 | r := newRouter() 127 | r.addRoute(http.MethodGet, "/", func(c *Context) { 128 | params := make([]string, 0) 129 | for _, header := range reqHeaders { 130 | val := c.GetRequestHeader(header.Key) 131 | params = append(params, fmt.Sprintf("%s:%s", header.Key, val)) 132 | } 133 | c.String(http.StatusOK, strings.Join(params, ",")) 134 | }) 135 | 136 | req, err := http.NewRequest(http.MethodGet, "/", nil) 137 | if err != nil { 138 | log.Fatalf("could not make http request: %v", err) 139 | } 140 | for _, header := range reqHeaders { 141 | req.Header.Set(header.Key, header.Value) 142 | } 143 | 144 | rec := httptest.NewRecorder() 145 | ctx := newContext(rec, req) 146 | 147 | r.handle(ctx) 148 | 149 | params := make([]string, 0) 150 | for _, rs := range reqHeaders { 151 | params = append(params, fmt.Sprintf("%s:%s", rs.Key, rs.Value)) 152 | } 153 | 154 | expectedRes := strings.Join(params, ",") 155 | if body := rec.Body.String(); body != expectedRes { 156 | t.Fatalf("expected response body %s; got %s", expectedRes, body) 157 | } 158 | } 159 | 160 | func TestGetRouteParameter(t *testing.T) { 161 | tt := []struct { 162 | name string 163 | method string 164 | urlPattern string 165 | requestURL string 166 | paramNames []string 167 | paramValues []string 168 | }{ 169 | {"one parameter", http.MethodGet, "/hello/:name", "/hello/foo", []string{"name"}, []string{"foo"}}, 170 | {"two parameters", http.MethodGet, "/hello/:name/show/:section", "/hello/bar/show/media", []string{"name", "section"}, []string{"bar", "media"}}, 171 | {"one parameter with asterisk wildcard", http.MethodGet, "/d/:timeout/u/*path", "/d/30/u/files/nano.zip", []string{"timeout", "path"}, []string{"30", "files/nano.zip"}}, 172 | // please consider the test ordering, because it's has added to same router instance. 173 | {"asterisk wildcard", http.MethodGet, "/d/u/*path", "/d/u/files/nano.zip", []string{"path"}, []string{"files/nano.zip"}}, 174 | } 175 | 176 | r := newRouter() 177 | 178 | for _, tc := range tt { 179 | r.addRoute(tc.method, tc.urlPattern, func(c *Context) { 180 | params := make([]string, 0) 181 | for _, name := range tc.paramNames { 182 | params = append(params, c.Param(name)) 183 | } 184 | 185 | c.String(http.StatusOK, strings.Join(params, ",")) 186 | }) 187 | 188 | req, err := http.NewRequest(tc.method, tc.requestURL, nil) 189 | if err != nil { 190 | log.Fatalf("could not make http request: %v", err) 191 | } 192 | 193 | rec := httptest.NewRecorder() 194 | ctx := newContext(rec, req) 195 | r.handle(ctx) 196 | 197 | expectedRes := strings.Join(tc.paramValues, ",") 198 | if body := rec.Body.String(); body != expectedRes { 199 | t.Errorf("expected parameter to be %s; got %s", expectedRes, body) 200 | } 201 | } 202 | } 203 | 204 | func TestResponse(t *testing.T) { 205 | r := newRouter() 206 | jsonHandler := func(c *Context) { 207 | c.JSON(http.StatusOK, H{ 208 | "message": "ok", 209 | }) 210 | } 211 | stringHandler := func(c *Context) { 212 | c.String(http.StatusOK, "ok") 213 | } 214 | htmlHandler := func(c *Context) { 215 | c.HTML(http.StatusOK, "
