├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── acceptlang
├── README.md
├── handler.go
└── handler_test.go
├── auth
├── README.md
├── basic.go
├── basic_test.go
├── util.go
└── util_test.go
├── binding
├── README.md
├── binding.go
└── binding_test.go
├── cors
├── README.md
├── cors.go
└── cors_test.go
├── encoder
├── README.md
├── encoder.go
└── encoder_test.go
├── gzip
├── README.md
├── gzip.go
└── gzip_test.go
├── method
├── README.md
├── override.go
└── override_test.go
├── render
├── README.md
├── fixtures
│ ├── basic
│ │ ├── admin
│ │ │ └── index.tmpl
│ │ ├── another_layout.tmpl
│ │ ├── content.tmpl
│ │ ├── delims.tmpl
│ │ ├── hello.tmpl
│ │ ├── hypertext.html
│ │ └── layout.tmpl
│ └── custom_funcs
│ │ └── index.tmpl
├── render.go
└── render_test.go
├── secure
├── README.md
├── secure.go
└── secure_test.go
├── sessionauth
├── README.md
├── example
│ ├── auth_example.go
│ ├── templates
│ │ ├── index.tmpl
│ │ ├── login.tmpl
│ │ └── private.tmpl
│ └── user.go
├── login.go
└── login_test.go
├── sessions
├── README.md
├── benchmarks_test.go
├── cookie_store.go
├── sessions.go
└── sessions_test.go
├── strip
├── README.md
├── prefix.go
└── prefix_test.go
├── web
├── LICENSE
├── README.md
├── web.go
└── web_test.go
└── wercker.yml
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | *.go text eol=lf
3 | *.html text eol=lf
4 | *.tmpl text eol=lf
5 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled Object files, Static and Dynamic libs (Shared Objects)
2 | *.o
3 | *.a
4 | *.so
5 |
6 | # Folders
7 | _obj
8 | _test
9 |
10 | # Architecture specific extensions/prefixes
11 | *.[568vq]
12 | [568vq].out
13 |
14 | *.cgo1.go
15 | *.cgo2.c
16 | _cgo_defun.c
17 | _cgo_gotypes.go
18 | _cgo_export.*
19 |
20 | _testmain.go
21 |
22 | *.exe
23 | *.test
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2013 Jeremy Saenz
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | this software and associated documentation files (the "Software"), to deal in
7 | the Software without restriction, including without limitation the rights to
8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
9 | the Software, and to permit persons to whom the Software is furnished to do so,
10 | subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # This Project Has Moved!
2 | martini-contrib got so big that it has now become its own GitHub organization! You can find updated packages at https://github.com/martini-contrib
3 |
--------------------------------------------------------------------------------
/acceptlang/README.md:
--------------------------------------------------------------------------------
1 | # acceptlang
2 | Using the `acceptlang` handler you can automatically parse the `Accept-Language` HTTP header and expose it as an `AcceptLanguages` slice in your handler functions. The `AcceptLanguages` slice contains `AcceptLanguage` values, each of which represents a qualified (or unqualified) language. The values in the slice are sorted descending by qualification (the most qualified languages will have the lowest indexes in the slice).
3 |
4 | Unqualified languages are interpreted as having the maximum qualification of `1`, as defined in the HTTP/1.1 specification.
5 |
6 | For more information:
7 | * [HTTP/1.1 Accept-Language specification](http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.4)
8 | * [API Reference](http://godoc.org/github.com/codegangsta/martini-contrib/acceptlang)
9 |
10 | ## Usage
11 | Simply add a new handler function instance to your handler chain using the `acceptlang.Languages()` function as well as an `AcceptLanguages` dependency in your handler function. The `AcceptLanguages` dependency will be satisfied by the handler.
12 |
13 | For example:
14 |
15 | ```go
16 | package main
17 |
18 | import (
19 | "fmt"
20 | "github.com/codegangsta/martini"
21 | "github.com/codegangsta/martini-contrib/acceptlang"
22 | "net/http"
23 | )
24 |
25 | func main() {
26 | m := martini.Classic()
27 |
28 | m.Get("/", acceptlang.Languages(), func(languages acceptlang.AcceptLanguages) string {
29 | return fmt.Sprintf("Languages: %s", languages)
30 | })
31 |
32 | http.ListenAndServe(":8090", m)
33 | }
34 | ```
35 |
36 | ## Authors
37 | * [Tom Bruggeman](http://github.com/tmbrggmn)
38 |
--------------------------------------------------------------------------------
/acceptlang/handler.go:
--------------------------------------------------------------------------------
1 | // Package acceptlang provides a Martini handler and primitives to parse
2 | // the Accept-Language HTTP header values.
3 | //
4 | // See the HTTP header fields specification for more details
5 | // (http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.4).
6 | //
7 | // Example
8 | //
9 | // Use the handler to automatically parse the Accept-Language header and
10 | // return the results as response:
11 | // m.Get("/", acceptlang.Languages(), func(languages acceptlang.AcceptLanguages) string {
12 | // return fmt.Sprintf("Languages: %s", languages)
13 | // })
14 | //
15 | package acceptlang
16 |
17 | import (
18 | "bytes"
19 | "fmt"
20 | "github.com/codegangsta/martini"
21 | "net/http"
22 | "sort"
23 | "strconv"
24 | "strings"
25 | )
26 |
27 | const (
28 | acceptLanguageHeader = "Accept-Language"
29 | )
30 |
// AcceptLanguage is a single language range taken from the Accept-Language
// HTTP header, paired with its quality value.
type AcceptLanguage struct {
	Language string
	Quality  float32
}

// AcceptLanguages is a sortable slice of AcceptLanguage values. Its
// sort.Interface implementation orders entries from the highest quality to
// the lowest.
type AcceptLanguages []AcceptLanguage

// Len reports the number of items in the slice (sort.Interface).
func (al AcceptLanguages) Len() int {
	return len(al)
}

// Swap exchanges the items at positions i and j (sort.Interface).
func (al AcceptLanguages) Swap(i, j int) {
	al[i], al[j] = al[j], al[i]
}

// Less reports whether the item at position i must sort before the item at
// position j, which is the case when it has the higher quality
// (sort.Interface).
func (al AcceptLanguages) Less(i, j int) bool {
	return al[j].Quality < al[i].Quality
}

// String renders the languages in a human readable form such as
// "nl (1.0), en-us (0.5)". An empty slice is rendered as "[]".
func (al AcceptLanguages) String() string {
	if len(al) == 0 {
		return "[]"
	}

	var output bytes.Buffer
	for i, language := range al {
		if i > 0 {
			output.WriteString(", ")
		}
		fmt.Fprintf(&output, "%s (%1.1f)", language.Language, language.Quality)
	}
	return output.String()
}
67 |
68 | // Creates a new handler that parses the Accept-Language HTTP header.
69 | //
70 | // The parsed structure is a slice of Accept-Language values stored in an
71 | // AcceptLanguages instance, sorted based on the language qualifier.
72 | func Languages() martini.Handler {
73 | return func(context martini.Context, request *http.Request) {
74 | header := request.Header.Get(acceptLanguageHeader)
75 | if header != "" {
76 | acceptLanguageHeaderValues := strings.Split(header, ",")
77 | acceptLanguages := make(AcceptLanguages, len(acceptLanguageHeaderValues))
78 |
79 | for i, languageRange := range acceptLanguageHeaderValues {
80 | // Check if a given range is qualified or not
81 | if qualifiedRange := strings.Split(languageRange, ";q="); len(qualifiedRange) == 2 {
82 | quality, error := strconv.ParseFloat(qualifiedRange[1], 32)
83 | if error != nil {
84 | // When the quality is unparseable, assume it's 1
85 | acceptLanguages[i] = AcceptLanguage{trimLanguage(qualifiedRange[0]), 1}
86 | } else {
87 | acceptLanguages[i] = AcceptLanguage{trimLanguage(qualifiedRange[0]), float32(quality)}
88 | }
89 | } else {
90 | acceptLanguages[i] = AcceptLanguage{trimLanguage(languageRange), 1}
91 | }
92 | }
93 |
94 | sort.Sort(acceptLanguages)
95 | context.Map(acceptLanguages)
96 | } else {
97 | // If we have no Accept-Language header just map an empty slice
98 | context.Map(make(AcceptLanguages, 0))
99 | }
100 | }
101 | }
102 |
// trimLanguage removes leading and trailing space characters from a language
// range. Only the space character is stripped; other whitespace is kept.
func trimLanguage(language string) string {
	const cutset = " "
	withoutLeading := strings.TrimLeft(language, cutset)
	return strings.TrimRight(withoutLeading, cutset)
}
106 |
--------------------------------------------------------------------------------
/acceptlang/handler_test.go:
--------------------------------------------------------------------------------
1 | package acceptlang
2 |
3 | import (
4 | "github.com/codegangsta/martini"
5 | "net/http"
6 | "net/http/httptest"
7 | "reflect"
8 | "testing"
9 | )
10 |
11 | type acceptLanguageTest struct {
12 | path string
13 | header string
14 | expected AcceptLanguages
15 | }
16 |
17 | var acceptLanguageTests = []acceptLanguageTest{
18 | // Test an empty header
19 | {"/none", "", make(AcceptLanguages, 0)},
20 |
21 | // Test a single unqualified header value
22 | {"/single", "en-gb", AcceptLanguages{AcceptLanguage{"en-gb", 1}}},
23 |
24 | // Test a single qualified header value
25 | {"/single_qualified", "en-gb;q=0.8", AcceptLanguages{AcceptLanguage{"en-gb", 0.8}}},
26 |
27 | // Test multiple unqualified header values
28 | {"/multiple", "en-gb, nl,en-us", AcceptLanguages{
29 | AcceptLanguage{"en-gb", 1}, AcceptLanguage{"nl", 1}, AcceptLanguage{"en-us", 1},
30 | }},
31 |
32 | // Test multiple qualified header values
33 | {"/multiple_qualified", "en-gb;q=0.2, nl;q=1,en-us;q=0.5", AcceptLanguages{
34 | AcceptLanguage{"nl", 1}, AcceptLanguage{"en-us", 0.5}, AcceptLanguage{"en-gb", 0.2},
35 | }},
36 | }
37 |
38 | func TestAcceptLanguageTests(t *testing.T) {
39 | for _, test := range acceptLanguageTests {
40 | m := martini.Classic()
41 | m.Get(test.path, Languages(), func(result AcceptLanguages) {
42 | if !reflect.DeepEqual(result, test.expected) {
43 | t.Errorf("Unexpected test result:\nExpected: %#v\nResult: %#v", test.expected, result)
44 | }
45 | })
46 |
47 | recorder := httptest.NewRecorder()
48 | r, _ := http.NewRequest("GET", test.path, nil)
49 | if test.header != "" {
50 | r.Header.Add(acceptLanguageHeader, test.header)
51 | }
52 | m.ServeHTTP(recorder, r)
53 | }
54 | }
55 |
56 | func BenchmarkLanguages1(b *testing.B) {
57 | m := newBenchmarkMartini()
58 |
59 | recorder := httptest.NewRecorder()
60 | r, _ := http.NewRequest("GET", "/benchmark", nil)
61 | r.Header.Add(acceptLanguageHeader, "en-us;q=0.7")
62 |
63 | for n := 0; n < b.N; n++ {
64 | m.ServeHTTP(recorder, r)
65 | }
66 | }
67 |
68 | func BenchmarkLanguages6(b *testing.B) {
69 | m := newBenchmarkMartini()
70 |
71 | recorder := httptest.NewRecorder()
72 | r, _ := http.NewRequest("GET", "/benchmark", nil)
73 | r.Header.Add(acceptLanguageHeader, "en-us;q=0.7, en-GB;q=0.8, de;q=1, nl;q=0.1, fr-FR;q=0.3, es")
74 |
75 | for n := 0; n < b.N; n++ {
76 | m.ServeHTTP(recorder, r)
77 | }
78 | }
79 |
80 | func newBenchmarkMartini() *martini.ClassicMartini {
81 | router := martini.NewRouter()
82 | base := martini.New()
83 | base.Action(router.Handle)
84 |
85 | m := &martini.ClassicMartini{base, router}
86 | m.Get("/benchmark", Languages(), func(result AcceptLanguages) {
87 | //b.Logf("Parsed languages: %s", result)
88 | })
89 |
90 | return m
91 | }
92 |
--------------------------------------------------------------------------------
/auth/README.md:
--------------------------------------------------------------------------------
1 | # auth
2 | Martini middleware/handler for http basic authentication.
3 |
4 | [API Reference](http://godoc.org/github.com/codegangsta/martini-contrib/auth)
5 |
6 | ## Usage
7 |
8 | ~~~ go
9 | import (
10 | "github.com/codegangsta/martini"
11 | "github.com/codegangsta/martini-contrib/auth"
12 | )
13 |
14 | func main() {
15 | m := martini.Classic()
16 | // authenticate every request
17 | m.Use(auth.Basic("username", "secretpassword"))
18 | m.Run()
19 | }
20 |
21 | ~~~
22 |
23 | ## Authors
24 | * [Jeremy Saenz](http://github.com/codegangsta)
25 | * [Brendon Murphy](http://github.com/bemurphy)
26 |
--------------------------------------------------------------------------------
/auth/basic.go:
--------------------------------------------------------------------------------
1 | package auth
2 |
3 | import (
4 | "encoding/base64"
5 | "net/http"
6 | )
7 |
8 | // Basic returns a Handler that authenticates via Basic Auth. Writes a http.StatusUnauthorized
9 | // if authentication fails
10 | func Basic(username string, password string) http.HandlerFunc {
11 | var siteAuth = base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
12 | return func(res http.ResponseWriter, req *http.Request) {
13 | auth := req.Header.Get("Authorization")
14 | if !SecureCompare(auth, "Basic "+siteAuth) {
15 | res.Header().Set("WWW-Authenticate", "Basic realm=\"Authorization Required\"")
16 | http.Error(res, "Not Authorized", http.StatusUnauthorized)
17 | }
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/auth/basic_test.go:
--------------------------------------------------------------------------------
1 | package auth
2 |
3 | import (
4 | "encoding/base64"
5 | "github.com/codegangsta/martini"
6 | "net/http"
7 | "net/http/httptest"
8 | "testing"
9 | )
10 |
11 | func Test_BasicAuth(t *testing.T) {
12 | recorder := httptest.NewRecorder()
13 |
14 | auth := "Basic " + base64.StdEncoding.EncodeToString([]byte("foo:bar"))
15 |
16 | m := martini.New()
17 | m.Use(Basic("foo", "bar"))
18 | m.Use(func(res http.ResponseWriter, req *http.Request) {
19 | res.Write([]byte("hello"))
20 | })
21 |
22 | r, _ := http.NewRequest("GET", "foo", nil)
23 |
24 | m.ServeHTTP(recorder, r)
25 |
26 | if recorder.Code != 401 {
27 | t.Error("Response not 401")
28 | }
29 |
30 | if recorder.Body.String() == "hello" {
31 | t.Error("Auth block failed")
32 | }
33 |
34 | recorder = httptest.NewRecorder()
35 | r.Header.Set("Authorization", auth)
36 | m.ServeHTTP(recorder, r)
37 |
38 | if recorder.Code == 401 {
39 | t.Error("Response is 401")
40 | }
41 |
42 | if recorder.Body.String() != "hello" {
43 | t.Error("Auth failed, got: ", recorder.Body.String())
44 | }
45 | }
46 |
--------------------------------------------------------------------------------
/auth/util.go:
--------------------------------------------------------------------------------
1 | package auth
2 |
3 | import (
4 | "crypto/subtle"
5 | )
6 |
// SecureCompare performs a constant time compare of two strings to limit
// timing attacks. It reports true only when given and actual have the same
// length and the same bytes.
func SecureCompare(given string, actual string) bool {
	sameLength := subtle.ConstantTimeEq(int32(len(given)), int32(len(actual))) == 1
	if !sameLength {
		// Still run a comparison over actual so the mismatching-length path
		// does comparable work; its result is deliberately discarded.
		subtle.ConstantTimeCompare([]byte(actual), []byte(actual))
		return false
	}

	return subtle.ConstantTimeCompare([]byte(given), []byte(actual)) == 1
}
16 |
--------------------------------------------------------------------------------
/auth/util_test.go:
--------------------------------------------------------------------------------
1 | package auth
2 |
3 | import (
4 | "testing"
5 | )
6 |
7 | var comparetests = []struct {
8 | a string
9 | b string
10 | val bool
11 | }{
12 | {"foo", "foo", true},
13 | {"bar", "bar", true},
14 | {"password", "password", true},
15 | {"Foo", "foo", false},
16 | {"foo", "foobar", false},
17 | {"password", "pass", false},
18 | }
19 |
20 | func Test_SecureCompare(t *testing.T) {
21 | for _, tt := range comparetests {
22 | if SecureCompare(tt.a, tt.b) != tt.val {
23 | t.Errorf("Expected SecureCompare(%v, %v) to return %v but did not", tt.a, tt.b, tt.val)
24 | }
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/binding/README.md:
--------------------------------------------------------------------------------
1 | # binding
2 |
3 | Request data binding for Martini.
4 |
5 | [API Reference](http://godoc.org/github.com/codegangsta/martini-contrib/binding)
6 |
7 |
8 |
9 | ## Description
10 |
11 | Package `binding` provides several middleware for transforming raw request data into populated structs, validating the input, and handling the errors. Each handler is independent and optional.
12 |
13 | #### Bind
14 |
15 | `binding.Bind` is a convenient wrapper over the other handlers in this package. It does the following boilerplate for you:
16 |
17 | 1. Deserializes the request data into a struct you supply
18 | 2. Performs validation with `binding.Validate`
19 | 3. Bails out with `binding.ErrorHandler` if there are any errors
20 |
21 | Your application (the final handler) will not even see the request if there are any errors.
22 |
23 | It reads the Content-Type of the request to know how to deserialize it, or if the Content-Type is not specified, it tries different deserializers until one returns without errors.
24 |
25 | **Important safety tip:** Don't attempt to bind a pointer to a struct. This will cause a panic [to prevent a race condition](https://github.com/codegangsta/martini-contrib/pull/34#issuecomment-29683659) where every request would be pointing to the same struct.
26 |
27 | #### Form
28 |
29 | `binding.Form` deserializes form data from the request, whether in the query string or as a form-urlencoded payload, and puts the data into a struct you pass in. It then invokes the `binding.Validate` middleware to perform validation. No error handling is performed, but you can get the errors in your handler by receiving a `binding.Errors` type.
30 |
31 |
32 | #### Json
33 |
34 | `binding.Json` deserializes JSON data in the payload of the request and uses `binding.Validate` to perform validation. Similar to `binding.Form`, no error handling is performed, but you can get the errors and handle them yourself.
35 |
36 |
37 | #### Validate
38 |
39 | `binding.Validate` receives a populated struct and checks it for errors, first by enforcing the `binding:"required"` value on struct field tags, then by executing the `Validate()` method on the struct, if it is a `binding.Validator`. (See usage below for an example.)
40 |
41 | *Note:* Marking a field as "required" means that you do not allow the zero value for that type (i.e. if you want to allow 0 in an int field, do not make it required).
42 |
43 |
44 | #### ErrorHandler
45 |
46 | `binding.ErrorHandler` is a small middleware that writes an error status code to the response (`400` when deserialization failed, `422` otherwise) along with a JSON payload describing the errors, *if* any errors have been mapped to the context. It does nothing if there are no errors.
47 |
48 |
49 |
50 | ## Usage
51 |
52 | This is a contrived example to show a few different ways to use the `binding` package.
53 |
54 | ```go
55 | package main
56 |
57 | import (
58 | "net/http"
59 |
60 | "github.com/codegangsta/martini"
61 | "github.com/codegangsta/martini-contrib/binding"
62 | )
63 |
64 | type BlogPost struct {
65 | Title string `form:"title" json:"title" binding:"required"`
66 | Content string `form:"content" json:"content"`
67 | Views int `form:"views" json:"views"`
68 | unexported string `form:"-"` // skip binding of unexported fields
69 | }
70 |
71 | // This method implements binding.Validator and is executed by the binding.Validate middleware
72 | func (bp BlogPost) Validate(errors *binding.Errors, req *http.Request) {
73 | if req.Header.Get("X-Custom-Thing") == "" {
74 | errors.Overall["x-custom-thing"] = "The X-Custom-Thing header is required"
75 | }
76 | if len(bp.Title) < 4 {
77 | errors.Fields["title"] = "Too short; minimum 4 characters"
78 | } else if len(bp.Title) > 120 {
79 | errors.Fields["title"] = "Too long; maximum 120 characters"
80 | }
81 | if bp.Views < 0 {
82 | errors.Fields["views"] = "Views must be at least 0"
83 | }
84 | }
85 |
86 | func main() {
87 | m := martini.Classic()
88 |
89 | m.Post("/blog", binding.Bind(BlogPost{}), func(blogpost BlogPost) string {
90 | // This function won't execute if there were errors
91 | return blogpost.Title
92 | })
93 |
94 | m.Get("/blog", binding.Form(BlogPost{}), binding.ErrorHandler, func(blogpost BlogPost) string {
95 | // This function won't execute if there were errors
96 | return blogpost.Title
97 | })
98 |
99 | m.Get("/blog", binding.Form(BlogPost{}), func(blogpost BlogPost, err binding.Errors, resp http.ResponseWriter) string {
100 | // This function WILL execute if there are errors because binding.Form doesn't handle errors
101 | if err.Count() > 0 {
102 | resp.WriteHeader(http.StatusBadRequest)
103 | }
104 | return blogpost.Title
105 | })
106 |
107 | m.Post("/blog", binding.Json(BlogPost{}), myOwnErrorHandler, func(blogpost BlogPost) string {
108 | // By this point, I assume that my own middleware took care of any errors
109 | return blogpost.Title
110 | })
111 |
112 | m.Run()
113 | }
114 | ```
115 |
116 | ## Authors
117 | * [Matthew Holt](https://github.com/mholt)
118 | * [Michael Whatcott](https://github.com/mdwhatcott)
119 | * [Jeremy Saenz](https://github.com/codegangsta)
120 |
--------------------------------------------------------------------------------
/binding/binding.go:
--------------------------------------------------------------------------------
1 | // Package binding transforms, with validation, a raw request into
2 | // a populated structure used by your application logic.
3 | package binding
4 |
5 | import (
6 | "encoding/json"
7 | "github.com/codegangsta/martini"
8 | "net/http"
9 | "reflect"
10 | "strconv"
11 | "strings"
12 | )
13 |
14 | /*
15 | To the land of Middle-ware Earth:
16 |
17 | One func to rule them all,
18 | One func to find them,
19 | One func to bring them all,
20 | And in this package BIND them.
21 | */
22 |
23 | // Bind accepts a copy of an empty struct and populates it with
24 | // values from the request (if deserialization is successful). It
25 | // wraps up the functionality of the Form and Json middleware
26 | // according to the Content-Type of the request, and it guesses
27 | // if no Content-Type is specified. Bind invokes the ErrorHandler
28 | // middleware to bail out if errors occurred. If you want to perform
29 | // your own error handling, use Form or Json middleware directly.
30 | // An interface pointer can be added as a second argument in order
31 | // to map the struct to a specific interface.
32 | func Bind(obj interface{}, ifacePtr ...interface{}) martini.Handler {
33 | return func(context martini.Context, req *http.Request) {
34 | contentType := req.Header.Get("Content-Type")
35 |
36 | if strings.Contains(contentType, "form-urlencoded") {
37 | context.Invoke(Form(obj, ifacePtr...))
38 | } else if strings.Contains(contentType, "multipart/form-data") {
39 | context.Invoke(MultipartForm(obj, ifacePtr...))
40 | } else if strings.Contains(contentType, "json") {
41 | context.Invoke(Json(obj, ifacePtr...))
42 | } else {
43 | context.Invoke(Json(obj, ifacePtr...))
44 | if getErrors(context).Count() > 0 {
45 | context.Invoke(Form(obj, ifacePtr...))
46 | }
47 | }
48 |
49 | context.Invoke(ErrorHandler)
50 | }
51 | }
52 |
53 | // Form is middleware to deserialize form-urlencoded data from the request.
54 | // It gets data from the form-urlencoded body, if present, or from the
55 | // query string. It uses the http.Request.ParseForm() method
56 | // to perform deserialization, then reflection is used to map each field
57 | // into the struct with the proper type. Structs with primitive slice types
58 | // (bool, float, int, string) can support deserialization of repeated form
59 | // keys, for example: key=val1&key=val2&key=val3
60 | // An interface pointer can be added as a second argument in order
61 | // to map the struct to a specific interface.
62 | func Form(formStruct interface{}, ifacePtr ...interface{}) martini.Handler {
63 | return func(context martini.Context, req *http.Request) {
64 | ensureNotPointer(formStruct)
65 | formStruct := reflect.New(reflect.TypeOf(formStruct))
66 | errors := newErrors()
67 | parseErr := req.ParseForm()
68 |
69 | // Format validation of the request body or the URL would add considerable overhead,
70 | // and ParseForm does not complain when URL encoding is off.
71 | // Because an empty request body or url can also mean absence of all needed values,
72 | // it is not in all cases a bad request, so let's return 422.
73 | if parseErr != nil {
74 | errors.Overall[DeserializationError] = parseErr.Error()
75 | }
76 |
77 | mapForm(formStruct, req.Form, errors)
78 |
79 | validateAndMap(formStruct, context, errors, ifacePtr...)
80 | }
81 | }
82 |
83 | func MultipartForm(formStruct interface{}, ifacePtr ...interface{}) martini.Handler {
84 | return func(context martini.Context, req *http.Request) {
85 | ensureNotPointer(formStruct)
86 | formStruct := reflect.New(reflect.TypeOf(formStruct))
87 | errors := newErrors()
88 |
89 | // Workaround for multipart forms returning nil instead of an error
90 | // when content is not multipart
91 | // https://code.google.com/p/go/issues/detail?id=6334
92 | multipartReader, err := req.MultipartReader()
93 | if err != nil {
94 | errors.Overall[DeserializationError] = err.Error()
95 | } else {
96 | form, parseErr := multipartReader.ReadForm(MaxMemory)
97 |
98 | if parseErr != nil {
99 | errors.Overall[DeserializationError] = parseErr.Error()
100 | }
101 |
102 | req.MultipartForm = form
103 | }
104 |
105 | mapForm(formStruct, req.MultipartForm.Value, errors)
106 |
107 | validateAndMap(formStruct, context, errors, ifacePtr...)
108 | }
109 | }
110 |
111 | // Json is middleware to deserialize a JSON payload from the request
112 | // into the struct that is passed in. The resulting struct is then
113 | // validated, but no error handling is actually performed here.
114 | // An interface pointer can be added as a second argument in order
115 | // to map the struct to a specific interface.
116 | func Json(jsonStruct interface{}, ifacePtr ...interface{}) martini.Handler {
117 | return func(context martini.Context, req *http.Request) {
118 | ensureNotPointer(jsonStruct)
119 | jsonStruct := reflect.New(reflect.TypeOf(jsonStruct))
120 | errors := newErrors()
121 |
122 | if req.Body != nil {
123 | defer req.Body.Close()
124 | }
125 |
126 | if err := json.NewDecoder(req.Body).Decode(jsonStruct.Interface()); err != nil {
127 | errors.Overall[DeserializationError] = err.Error()
128 | }
129 |
130 | validateAndMap(jsonStruct, context, errors, ifacePtr...)
131 | }
132 | }
133 |
134 | // Validate is middleware to enforce required fields. If the struct
135 | // passed in is a Validator, then the user-defined Validate method
136 | // is executed, and its errors are mapped to the context. This middleware
137 | // performs no error handling: it merely detects them and maps them.
138 | func Validate(obj interface{}) martini.Handler {
139 | return func(context martini.Context, req *http.Request) {
140 | errors := newErrors()
141 | validateStruct(errors, obj)
142 |
143 | if validator, ok := obj.(Validator); ok {
144 | validator.Validate(errors, req)
145 | }
146 | context.Map(*errors)
147 |
148 | }
149 | }
150 |
151 | func validateStruct(errors *Errors, obj interface{}) {
152 | typ := reflect.TypeOf(obj)
153 | val := reflect.ValueOf(obj)
154 |
155 | if typ.Kind() == reflect.Ptr {
156 | typ = typ.Elem()
157 | val = val.Elem()
158 | }
159 |
160 | for i := 0; i < typ.NumField(); i++ {
161 | field := typ.Field(i)
162 |
163 | // Allow ignored fields in the struct
164 | if field.Tag.Get("form") == "-" {
165 | continue
166 | }
167 |
168 | fieldValue := val.Field(i).Interface()
169 | zero := reflect.Zero(field.Type).Interface()
170 |
171 | if strings.Index(field.Tag.Get("binding"), "required") > -1 {
172 | if field.Type.Kind() == reflect.Struct {
173 | validateStruct(errors, fieldValue)
174 | } else if reflect.DeepEqual(zero, fieldValue) {
175 | errors.Fields[field.Name] = RequireError
176 | }
177 | }
178 | }
179 | }
180 |
181 | func mapForm(formStruct reflect.Value, form map[string][]string, errors *Errors) {
182 | typ := formStruct.Elem().Type()
183 |
184 | for i := 0; i < typ.NumField(); i++ {
185 | typeField := typ.Field(i)
186 | if inputFieldName := typeField.Tag.Get("form"); inputFieldName != "" {
187 | structField := formStruct.Elem().Field(i)
188 | if !structField.CanSet() {
189 | continue
190 | }
191 |
192 | inputValue, exists := form[inputFieldName]
193 |
194 | if !exists {
195 | continue
196 | }
197 |
198 | numElems := len(inputValue)
199 | if structField.Kind() == reflect.Slice && numElems > 0 {
200 | sliceOf := structField.Type().Elem().Kind()
201 | slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
202 | for i := 0; i < numElems; i++ {
203 | setWithProperType(sliceOf, inputValue[i], slice.Index(i), inputFieldName, errors)
204 | }
205 | formStruct.Elem().Field(i).Set(slice)
206 | } else {
207 | setWithProperType(typeField.Type.Kind(), inputValue[0], structField, inputFieldName, errors)
208 | }
209 | }
210 | }
211 | }
212 |
213 | // ErrorHandler simply counts the number of errors in the
214 | // context and, if more than 0, writes a 400 Bad Request
215 | // response and a JSON payload describing the errors with
216 | // the "Content-Type" set to "application/json".
217 | // Middleware remaining on the stack will not even see the request
218 | // if, by this point, there are any errors.
219 | // This is a "default" handler, of sorts, and you are
220 | // welcome to use your own instead. The Bind middleware
221 | // invokes this automatically for convenience.
222 | func ErrorHandler(errs Errors, resp http.ResponseWriter) {
223 | if errs.Count() > 0 {
224 | resp.Header().Set("Content-Type", "application/json; charset=utf-8")
225 | if _, ok := errs.Overall[DeserializationError]; ok {
226 | resp.WriteHeader(http.StatusBadRequest)
227 | } else {
228 | resp.WriteHeader(422)
229 | }
230 | errOutput, _ := json.Marshal(errs)
231 | resp.Write(errOutput)
232 | return
233 | }
234 | }
235 |
236 | // This sets the value in a struct of an indeterminate type to the
237 | // matching value from the request (via Form middleware) in the
238 | // same type, so that not all deserialized values have to be strings.
239 | // Supported types are string, int, float, and bool.
240 | func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value, nameInTag string, errors *Errors) {
241 | switch valueKind {
242 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
243 | if val == "" {
244 | val = "0"
245 | }
246 | intVal, err := strconv.Atoi(val)
247 | if err != nil {
248 | errors.Fields[nameInTag] = IntegerTypeError
249 | } else {
250 | structField.SetInt(int64(intVal))
251 | }
252 | case reflect.Bool:
253 | if val == "" {
254 | val = "false"
255 | }
256 | boolVal, err := strconv.ParseBool(val)
257 | if err != nil {
258 | errors.Fields[nameInTag] = BooleanTypeError
259 | } else {
260 | structField.SetBool(boolVal)
261 | }
262 | case reflect.Float32:
263 | if val == "" {
264 | val = "0.0"
265 | }
266 | floatVal, err := strconv.ParseFloat(val, 32)
267 | if err != nil {
268 | errors.Fields[nameInTag] = FloatTypeError
269 | } else {
270 | structField.SetFloat(floatVal)
271 | }
272 | case reflect.Float64:
273 | if val == "" {
274 | val = "0.0"
275 | }
276 | floatVal, err := strconv.ParseFloat(val, 64)
277 | if err != nil {
278 | errors.Fields[nameInTag] = FloatTypeError
279 | } else {
280 | structField.SetFloat(floatVal)
281 | }
282 | case reflect.String:
283 | structField.SetString(val)
284 | }
285 | }
286 |
// ensureNotPointer panics when obj is a pointer: binding models must be
// passed by value, because pointer models have led to bugs. See:
// https://github.com/codegangsta/martini-contrib/issues/40
// https://github.com/codegangsta/martini-contrib/pull/34#issuecomment-29683659
func ensureNotPointer(obj interface{}) {
	kind := reflect.TypeOf(obj).Kind()
	if kind != reflect.Ptr {
		return
	}
	panic("Pointers are not accepted as binding models")
}
295 |
296 | // Performs validation and combines errors from validation
297 | // with errors from deserialization, then maps both the
298 | // resulting struct and the errors to the context.
299 | func validateAndMap(obj reflect.Value, context martini.Context, errors *Errors, ifacePtr ...interface{}) {
300 | context.Invoke(Validate(obj.Interface()))
301 | errors.combine(getErrors(context))
302 | context.Map(*errors)
303 | context.Map(obj.Elem().Interface())
304 | if len(ifacePtr) > 0 {
305 | context.MapTo(obj.Elem().Interface(), ifacePtr[0])
306 | }
307 | }
308 |
309 | func newErrors() *Errors {
310 | return &Errors{make(map[string]string), make(map[string]string)}
311 | }
312 |
313 | func getErrors(context martini.Context) Errors {
314 | return context.Get(reflect.TypeOf(Errors{})).Interface().(Errors)
315 | }
316 |
317 | func (this *Errors) combine(other Errors) {
318 | for key, val := range other.Fields {
319 | if _, exists := this.Fields[key]; !exists {
320 | this.Fields[key] = val
321 | }
322 | }
323 | for key, val := range other.Overall {
324 | if _, exists := this.Overall[key]; !exists {
325 | this.Overall[key] = val
326 | }
327 | }
328 | }
329 |
330 | // Total errors is the sum of errors with the request overall
331 | // and errors on individual fields.
332 | func (self Errors) Count() int {
333 | return len(self.Overall) + len(self.Fields)
334 | }
335 |
type (
	// Errors represents the contract of the response body when the
	// binding step fails before getting to the application.
	// ErrorHandler serializes it to JSON using the tags below.
	Errors struct {
		// Overall holds errors about the request as a whole, keyed by
		// error kind (for example DeserializationError).
		Overall map[string]string `json:"overall"`
		// Fields holds errors about individual fields, keyed by field name.
		Fields map[string]string `json:"fields"`
	}

	// Implement the Validator interface to define your own input
	// validation before the request even gets to your application.
	// The Validate method will be executed during the validation phase;
	// report problems by adding entries to the given Errors.
	Validator interface {
		Validate(*Errors, *http.Request)
	}
)
351 |
var (
	// MaxMemory is the maximum amount of memory, in bytes, to use when
	// parsing a multipart form. It is an exported variable: set it to
	// whatever value you prefer. The default is 10 MB.
	MaxMemory = int64(1024 * 1024 * 10)
)
357 |
// Error identifiers stored in Errors. DeserializationError is used as a
// key in Errors.Overall; the *TypeError constants are stored as values in
// Errors.Fields when a string cannot be converted to the field's type.
const (
	RequireError         string = "Required"
	DeserializationError string = "DeserializationError"
	IntegerTypeError     string = "IntegerTypeError"
	BooleanTypeError     string = "BooleanTypeError"
	FloatTypeError       string = "FloatTypeError"
)
365 |
--------------------------------------------------------------------------------
/binding/binding_test.go:
--------------------------------------------------------------------------------
1 | package binding
2 |
3 | import (
4 | "bytes"
5 | "mime/multipart"
6 | "net/http"
7 | "net/http/httptest"
8 | "strconv"
9 | "strings"
10 | "testing"
11 |
12 | "github.com/codegangsta/martini"
13 | )
14 |
15 | func TestBind(t *testing.T) {
16 | testBind(t, false)
17 | }
18 |
19 | func TestBindWithInterface(t *testing.T) {
20 | testBind(t, true)
21 | }
22 |
23 | func TestMultipartBind(t *testing.T) {
24 | index := 0
25 | for test, expectStatus := range bindMultipartTests {
26 | handler := func(post BlogPost, errors Errors) {
27 | handle(test, t, index, post, errors)
28 | }
29 | recorder := testMultipart(t, test, Bind(BlogPost{}), handler, index)
30 |
31 | if recorder.Code != expectStatus {
32 | t.Errorf("On test case %v, got status code %d but expected %d", test, recorder.Code, expectStatus)
33 | }
34 |
35 | index++
36 | }
37 | }
38 |
39 | func TestForm(t *testing.T) {
40 | testForm(t, false)
41 | }
42 |
43 | func TestFormWithInterface(t *testing.T) {
44 | testForm(t, true)
45 | }
46 |
47 | func TestMultipartForm(t *testing.T) {
48 | for index, test := range multipartformTests {
49 | handler := func(post BlogPost, errors Errors) {
50 | handle(test, t, index, post, errors)
51 | }
52 | testMultipart(t, test, MultipartForm(BlogPost{}), handler, index)
53 | }
54 | }
55 |
56 | func TestMultipartFormWithInterface(t *testing.T) {
57 | for index, test := range multipartformTests {
58 | handler := func(post Modeler, errors Errors) {
59 | post.Create(test, t, index)
60 | }
61 | testMultipart(t, test, MultipartForm(BlogPost{}, (*Modeler)(nil)), handler, index)
62 | }
63 | }
64 |
65 | func TestJson(t *testing.T) {
66 | testJson(t, false)
67 | }
68 |
69 | func TestJsonWithInterface(t *testing.T) {
70 | testJson(t, true)
71 | }
72 |
73 | func TestValidate(t *testing.T) {
74 | handlerMustErr := func(errors Errors) {
75 | if errors.Count() == 0 {
76 | t.Error("Expected at least one error, got 0")
77 | }
78 | }
79 | handlerNoErr := func(errors Errors) {
80 | if errors.Count() > 0 {
81 | t.Error("Expected no errors, got", errors.Count())
82 | }
83 | }
84 |
85 | performValidationTest(&BlogPost{"", "...", 0, 0, []int{}}, handlerMustErr, t)
86 | performValidationTest(&BlogPost{"Good Title", "Good content", 0, 0, []int{}}, handlerNoErr, t)
87 |
88 | performValidationTest(&User{Name: "Jim", Home: Address{"", ""}}, handlerMustErr, t)
89 | performValidationTest(&User{Name: "Jim", Home: Address{"required", ""}}, handlerNoErr, t)
90 | }
91 |
92 | func handle(test testCase, t *testing.T, index int, post BlogPost, errors Errors) {
93 | assertEqualField(t, "Title", index, test.ref.Title, post.Title)
94 | assertEqualField(t, "Content", index, test.ref.Content, post.Content)
95 | assertEqualField(t, "Views", index, test.ref.Views, post.Views)
96 |
97 | for i := range test.ref.Multiple {
98 | if i >= len(post.Multiple) {
99 | t.Errorf("Expected: %v (size %d) to have same size as: %v (size %d)", post.Multiple, len(post.Multiple), test.ref.Multiple, len(test.ref.Multiple))
100 | break
101 | }
102 | if test.ref.Multiple[i] != post.Multiple[i] {
103 | t.Errorf("Expected: %v to deep equal: %v", post.Multiple, test.ref.Multiple)
104 | break
105 | }
106 | }
107 |
108 | if test.ok && errors.Count() > 0 {
109 | t.Errorf("%+v should be OK (0 errors), but had errors: %+v", test, errors)
110 | } else if !test.ok && errors.Count() == 0 {
111 | t.Errorf("%+v should have errors, but was OK (0 errors): %+v", test)
112 | }
113 | }
114 |
115 | func testBind(t *testing.T, withInterface bool) {
116 | index := 0
117 | for test, expectStatus := range bindTests {
118 | m := martini.Classic()
119 | recorder := httptest.NewRecorder()
120 | handler := func(post BlogPost, errors Errors) { handle(test, t, index, post, errors) }
121 | binding := Bind(BlogPost{})
122 |
123 | if withInterface {
124 | handler = func(post BlogPost, errors Errors) {
125 | post.Create(test, t, index)
126 | }
127 | binding = Bind(BlogPost{}, (*Modeler)(nil))
128 | }
129 |
130 | switch test.method {
131 | case "GET":
132 | m.Get(route, binding, handler)
133 | case "POST":
134 | m.Post(route, binding, handler)
135 | }
136 |
137 | req, err := http.NewRequest(test.method, test.path, strings.NewReader(test.payload))
138 | req.Header.Add("Content-Type", test.contentType)
139 |
140 | if err != nil {
141 | t.Error(err)
142 | }
143 | m.ServeHTTP(recorder, req)
144 |
145 | if recorder.Code != expectStatus {
146 | t.Errorf("On test case %v, got status code %d but expected %d", test, recorder.Code, expectStatus)
147 | }
148 |
149 | index++
150 | }
151 | }
152 |
153 | func testJson(t *testing.T, withInterface bool) {
154 | for index, test := range jsonTests {
155 | recorder := httptest.NewRecorder()
156 | handler := func(post BlogPost, errors Errors) { handle(test, t, index, post, errors) }
157 | binding := Json(BlogPost{})
158 |
159 | if withInterface {
160 | handler = func(post BlogPost, errors Errors) {
161 | post.Create(test, t, index)
162 | }
163 | binding = Bind(BlogPost{}, (*Modeler)(nil))
164 | }
165 |
166 | m := martini.Classic()
167 | switch test.method {
168 | case "GET":
169 | m.Get(route, binding, handler)
170 | case "POST":
171 | m.Post(route, binding, handler)
172 | case "PUT":
173 | m.Put(route, binding, handler)
174 | case "DELETE":
175 | m.Delete(route, binding, handler)
176 | }
177 |
178 | req, err := http.NewRequest(test.method, route, strings.NewReader(test.payload))
179 | if err != nil {
180 | t.Error(err)
181 | }
182 | m.ServeHTTP(recorder, req)
183 | }
184 | }
185 |
186 | func testForm(t *testing.T, withInterface bool) {
187 | for index, test := range formTests {
188 | recorder := httptest.NewRecorder()
189 | handler := func(post BlogPost, errors Errors) { handle(test, t, index, post, errors) }
190 | binding := Form(BlogPost{})
191 |
192 | if withInterface {
193 | handler = func(post BlogPost, errors Errors) {
194 | post.Create(test, t, index)
195 | }
196 | binding = Form(BlogPost{}, (*Modeler)(nil))
197 | }
198 |
199 | m := martini.Classic()
200 | switch test.method {
201 | case "GET":
202 | m.Get(route, binding, handler)
203 | case "POST":
204 | m.Post(route, binding, handler)
205 | }
206 |
207 | req, err := http.NewRequest(test.method, test.path, nil)
208 | if err != nil {
209 | t.Error(err)
210 | }
211 | m.ServeHTTP(recorder, req)
212 | }
213 | }
214 |
215 | func testMultipart(t *testing.T, test testCase, middleware martini.Handler, handler martini.Handler, index int) *httptest.ResponseRecorder {
216 | recorder := httptest.NewRecorder()
217 |
218 | m := martini.Classic()
219 | m.Post(route, middleware, handler)
220 |
221 | body := &bytes.Buffer{}
222 | writer := multipart.NewWriter(body)
223 | writer.WriteField("title", test.ref.Title)
224 | writer.WriteField("content", test.ref.Content)
225 | writer.WriteField("views", strconv.Itoa(test.ref.Views))
226 | if len(test.ref.Multiple) != 0 {
227 | for _, value := range test.ref.Multiple {
228 | writer.WriteField("multiple", strconv.Itoa(value))
229 | }
230 | }
231 |
232 | req, err := http.NewRequest(test.method, test.path, body)
233 | req.Header.Add("Content-Type", writer.FormDataContentType())
234 |
235 | if err != nil {
236 | t.Error(err)
237 | }
238 |
239 | err = writer.Close()
240 | if err != nil {
241 | t.Error(err)
242 | }
243 |
244 | m.ServeHTTP(recorder, req)
245 |
246 | return recorder
247 | }
248 |
// assertEqualField reports a test error when got differs from expected.
// fieldname and testcasenumber only serve to identify the failure.
// The values are compared with !=, so both must be of comparable types.
func assertEqualField(t *testing.T, fieldname string, testcasenumber int, expected interface{}, got interface{}) {
	if expected != got {
		// %v, not %s: the values are not always strings (Views is an int).
		t.Errorf("%s: expected=%v, got=%v in test case %d\n", fieldname, expected, got, testcasenumber)
	}
}
254 |
255 | func performValidationTest(data interface{}, handler func(Errors), t *testing.T) {
256 | recorder := httptest.NewRecorder()
257 | m := martini.Classic()
258 | m.Get(route, Validate(data), handler)
259 |
260 | req, err := http.NewRequest("GET", route, nil)
261 | if err != nil {
262 | t.Error("HTTP error:", err)
263 | }
264 |
265 | m.ServeHTTP(recorder, req)
266 | }
267 |
268 | func (self BlogPost) Validate(errors *Errors, req *http.Request) {
269 | if len(self.Title) < 4 {
270 | errors.Fields["Title"] = "Too short; minimum 4 characters"
271 | }
272 | if len(self.Content) > 1024 {
273 | errors.Fields["Content"] = "Too long; maximum 1024 characters"
274 | }
275 | if len(self.Content) < 5 {
276 | errors.Fields["Content"] = "Too short; minimum 5 characters"
277 | }
278 | }
279 |
280 | func (self BlogPost) Create(test testCase, t *testing.T, index int) {
281 | assertEqualField(t, "Title", index, test.ref.Title, self.Title)
282 | assertEqualField(t, "Content", index, test.ref.Content, self.Content)
283 | assertEqualField(t, "Views", index, test.ref.Views, self.Views)
284 |
285 | for i := range test.ref.Multiple {
286 | if i >= len(self.Multiple) {
287 | t.Errorf("Expected: %v (size %d) to have same size as: %v (size %d)", self.Multiple, len(self.Multiple), test.ref.Multiple, len(test.ref.Multiple))
288 | break
289 | }
290 | if test.ref.Multiple[i] != self.Multiple[i] {
291 | t.Errorf("Expected: %v to deep equal: %v", self.Multiple, test.ref.Multiple)
292 | break
293 | }
294 | }
295 | }
296 |
297 | func (self BlogSection) Create(test testCase, t *testing.T, index int) {
298 | // intentionally left empty
299 | }
300 |
type (
	// testCase describes one request sent to a binding middleware and
	// what the bound result should look like.
	testCase struct {
		method      string    // HTTP method of the request
		path        string    // request URL, including any query string
		payload     string    // request body
		contentType string    // value of the Content-Type header
		ok          bool      // true when no errors are expected
		ref         *BlogPost // field values the bound struct must have
	}

	// Modeler is the interface the bound struct is mapped to in the
	// "WithInterface" tests.
	Modeler interface {
		Create(test testCase, t *testing.T, index int)
	}

	// BlogPost is the main binding model used by the tests.
	BlogPost struct {
		Title    string `form:"title" json:"title" binding:"required"`
		Content  string `form:"content" json:"content"`
		Views    int    `form:"views" json:"views"`
		internal int    `form:"-"`
		Multiple []int  `form:"multiple"`
	}

	// BlogSection is a second model with a required title.
	BlogSection struct {
		Title   string `form:"title" json:"title" binding:"required"`
		Content string `form:"content" json:"content"`
	}

	// User embeds a required Address to exercise nested validation.
	User struct {
		Name string  `json:"name" binding:"required"`
		Home Address `json:"address" binding:"required"`
	}

	// Address is the nested struct inside User; Street1 is required.
	Address struct {
		Street1 string `json:"street1" binding:"required"`
		Street2 string `json:"street2"`
	}
)
338 |
339 | var (
340 | bindTests = map[testCase]int{
341 | // These should bail at the deserialization/binding phase
342 | testCase{
343 | "POST",
344 | path,
345 | `{ bad JSON `,
346 | "application/json",
347 | false,
348 | new(BlogPost),
349 | }: http.StatusBadRequest,
350 | testCase{
351 | "POST",
352 | path,
353 | `not multipart but has content-type`,
354 | "multipart/form-data",
355 | false,
356 | new(BlogPost),
357 | }: http.StatusBadRequest,
358 | testCase{
359 | "POST",
360 | path,
361 | `no content-type and not URL-encoded or JSON"`,
362 | "",
363 | false,
364 | new(BlogPost),
365 | }: http.StatusBadRequest,
366 |
367 | // These should deserialize, then bail at the validation phase
368 | testCase{
369 | "POST",
370 | path + "?title= This is wrong ",
371 | `not URL-encoded but has content-type`,
372 | "x-www-form-urlencoded",
373 | false,
374 | new(BlogPost),
375 | }: 422, // according to comments in Form() -> although the request is not url encoded, ParseForm does not complain
376 | testCase{
377 | "GET",
378 | path + "?content=This+is+the+content",
379 | ``,
380 | "x-www-form-urlencoded",
381 | false,
382 | &BlogPost{Title: "", Content: "This is the content"},
383 | }: 422,
384 | testCase{
385 | "GET",
386 | path + "",
387 | `{"content":"", "title":"Blog Post Title"}`,
388 | "application/json",
389 | false,
390 | &BlogPost{Title: "Blog Post Title", Content: ""},
391 | }: 422,
392 |
393 | // These should succeed
394 | testCase{
395 | "GET",
396 | path + "",
397 | `{"content":"This is the content", "title":"Blog Post Title"}`,
398 | "application/json",
399 | true,
400 | &BlogPost{Title: "Blog Post Title", Content: "This is the content"},
401 | }: http.StatusOK,
402 | testCase{
403 | "GET",
404 | path + "?content=This+is+the+content&title=Blog+Post+Title",
405 | ``,
406 | "",
407 | true,
408 | &BlogPost{Title: "Blog Post Title", Content: "This is the content"},
409 | }: http.StatusOK,
410 | testCase{
411 | "GET",
412 | path + "?content=This is the content&title=Blog+Post+Title",
413 | `{"content":"This is the content", "title":"Blog Post Title"}`,
414 | "",
415 | true,
416 | &BlogPost{Title: "Blog Post Title", Content: "This is the content"},
417 | }: http.StatusOK,
418 | testCase{
419 | "GET",
420 | path + "",
421 | `{"content":"This is the content", "title":"Blog Post Title"}`,
422 | "",
423 | true,
424 | &BlogPost{Title: "Blog Post Title", Content: "This is the content"},
425 | }: http.StatusOK,
426 | }
427 |
428 | bindMultipartTests = map[testCase]int{
429 | // This should deserialize, then bail at the validation phase
430 | testCase{
431 | "POST",
432 | path,
433 | "",
434 | "multipart/form-data",
435 | false,
436 | &BlogPost{Title: "", Content: "This is the content"},
437 | }: 422,
438 | // This should succeed
439 | testCase{
440 | "POST",
441 | path,
442 | "",
443 | "multipart/form-data",
444 | true,
445 | &BlogPost{Title: "This is the Title", Content: "This is the content"},
446 | }: http.StatusOK,
447 | }
448 |
449 | formTests = []testCase{
450 | {
451 | "GET",
452 | path + "?content=This is the content",
453 | "",
454 | "",
455 | false,
456 | &BlogPost{Title: "", Content: "This is the content"},
457 | },
458 | {
459 | "POST",
460 | path + "?content=This+is+the+content&title=Blog+Post+Title&views=3",
461 | "",
462 | "",
463 | false, // false because POST requests should have a body, not just a query string
464 | &BlogPost{Title: "Blog Post Title", Content: "This is the content", Views: 3},
465 | },
466 | {
467 | "GET",
468 | path + "?content=This+is+the+content&title=Blog+Post+Title&views=3&multiple=5&multiple=10&multiple=15&multiple=20",
469 | "",
470 | "",
471 | true,
472 | &BlogPost{Title: "Blog Post Title", Content: "This is the content", Views: 3, Multiple: []int{5, 10, 15, 20}},
473 | },
474 | }
475 |
476 | multipartformTests = []testCase{
477 | {
478 | "POST",
479 | path,
480 | "",
481 | "multipart/form-data",
482 | false,
483 | &BlogPost{Title: "", Content: "This is the content"},
484 | },
485 | {
486 | "POST",
487 | path,
488 | "",
489 | "multipart/form-data",
490 | false,
491 | &BlogPost{Title: "Blog Post Title", Views: 3},
492 | },
493 | {
494 | "POST",
495 | path,
496 | "",
497 | "multipart/form-data",
498 | true,
499 | &BlogPost{Title: "Blog Post Title", Content: "This is the content", Views: 3, Multiple: []int{5, 10, 15, 20}},
500 | },
501 | }
502 |
503 | jsonTests = []testCase{
504 | // bad requests
505 | {
506 | "GET",
507 | "",
508 | `{blah blah blah}`,
509 | "",
510 | false,
511 | &BlogPost{},
512 | },
513 | {
514 | "POST",
515 | "",
516 | `{asdf}`,
517 | "",
518 | false,
519 | &BlogPost{},
520 | },
521 | {
522 | "PUT",
523 | "",
524 | `{blah blah blah}`,
525 | "",
526 | false,
527 | &BlogPost{},
528 | },
529 | {
530 | "DELETE",
531 | "",
532 | `{;sdf _SDf- }`,
533 | "",
534 | false,
535 | &BlogPost{},
536 | },
537 |
538 | // Valid-JSON requests
539 | {
540 | "GET",
541 | "",
542 | `{"content":"This is the content"}`,
543 | "",
544 | false,
545 | &BlogPost{Title: "", Content: "This is the content"},
546 | },
547 | {
548 | "POST",
549 | "",
550 | `{}`,
551 | "application/json",
552 | false,
553 | &BlogPost{Title: "", Content: ""},
554 | },
555 | {
556 | "POST",
557 | "",
558 | `{"content":"This is the content", "title":"Blog Post Title"}`,
559 | "",
560 | true,
561 | &BlogPost{Title: "Blog Post Title", Content: "This is the content"},
562 | },
563 | {
564 | "PUT",
565 | "",
566 | `{"content":"This is the content", "title":"Blog Post Title"}`,
567 | "",
568 | true,
569 | &BlogPost{Title: "Blog Post Title", Content: "This is the content"},
570 | },
571 | {
572 | "DELETE",
573 | "",
574 | `{"content":"This is the content", "title":"Blog Post Title"}`,
575 | "",
576 | true,
577 | &BlogPost{Title: "Blog Post Title", Content: "This is the content"},
578 | },
579 | }
580 | )
581 |
const (
	// route is the pattern the test handlers are registered under.
	route = "/blogposts/create"
	// path is the full URL used when building test requests for route.
	path = "http://localhost:3000" + route
)
586 |
--------------------------------------------------------------------------------
/cors/README.md:
--------------------------------------------------------------------------------
1 | # cors
2 |
3 | Martini middleware/handler to enable CORS support.
4 |
5 | ## Usage
6 |
7 | ~~~ go
8 | import (
9 | "github.com/codegangsta/martini"
10 | "github.com/codegangsta/martini-contrib/cors"
11 | )
12 |
13 | func main() {
14 | m := martini.Classic()
15 | // CORS for https://foo.* origins, allowing:
16 | // - PUT and PATCH methods
17 | // - Origin header
18 | // - Credentials share
19 | m.Use(cors.Allow(&cors.Options{
20 | AllowOrigins: []string{"https://foo\\.*"},
21 | AllowMethods: []string{"PUT", "PATCH"},
22 | AllowHeaders: []string{"Origin"},
23 | ExposeHeaders: []string{"Content-Length"},
24 | AllowCredentials: true,
25 | }))
26 | m.Run()
27 | }
28 | ~~~
29 |
30 | ## Authors
31 |
32 | * [Burcu Dogan](http://github.com/rakyll)
33 |
--------------------------------------------------------------------------------
/cors/cors.go:
--------------------------------------------------------------------------------
1 | // Package cors provides handlers to enable CORS support.
2 | package cors
3 |
4 | import (
5 | "net/http"
6 | "regexp"
7 | "strconv"
8 | "strings"
9 | "time"
10 | )
11 |
const (
	// Response headers written by this package.
	headerAllowOrigin      = "Access-Control-Allow-Origin"
	headerAllowCredentials = "Access-Control-Allow-Credentials"
	headerAllowHeaders     = "Access-Control-Allow-Headers"
	headerAllowMethods     = "Access-Control-Allow-Methods"
	headerExposeHeaders    = "Access-Control-Expose-Headers"
	headerMaxAge           = "Access-Control-Max-Age"

	// Request headers read by this package.
	headerOrigin         = "Origin"
	headerRequestMethod  = "Access-Control-Request-Method"
	headerRequestHeaders = "Access-Control-Request-Headers"
)
24 |
// Options represents the Access Control (CORS) settings applied by Allow.
type Options struct {
	// If set, all origins are allowed and AllowOrigins is not consulted.
	AllowAllOrigins bool
	// A list of allowed origin patterns. Each entry is a regular
	// expression matched against the request's Origin header.
	AllowOrigins []string
	// If set, allows to share auth credentials such as cookies.
	AllowCredentials bool
	// A list of allowed HTTP methods.
	AllowMethods []string
	// A list of allowed HTTP headers.
	AllowHeaders []string
	// A list of exposed HTTP headers.
	ExposeHeaders []string
	// Max age of the CORS headers; sent to the client in whole seconds.
	// Zero or negative values omit the header.
	MaxAge time.Duration
}
42 |
43 | // Converts options into CORS headers.
44 | func (o *Options) Header(origin string) (headers map[string]string) {
45 | headers = make(map[string]string)
46 | // if origin is not alowed, don't extend the headers
47 | // with CORS headers.
48 | if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) {
49 | return
50 | }
51 |
52 | // add allow origin
53 | if o.AllowAllOrigins {
54 | headers[headerAllowOrigin] = "*"
55 | } else {
56 | headers[headerAllowOrigin] = origin
57 | }
58 |
59 | // add allow credentials
60 | headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials)
61 |
62 | // add allow methods
63 | if len(o.AllowMethods) > 0 {
64 | headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",")
65 | }
66 |
67 | // add allow headers
68 | if len(o.AllowHeaders) > 0 {
69 | // TODO: Add default headers
70 | headers[headerAllowHeaders] = strings.Join(o.AllowHeaders, ",")
71 | }
72 |
73 | // add exposed header
74 | if len(o.ExposeHeaders) > 0 {
75 | headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",")
76 | }
77 | // add a max age header
78 | if o.MaxAge > time.Duration(0) {
79 | headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10)
80 | }
81 | return
82 | }
83 |
84 | // Converts options into CORS headers for a preflight response.
85 | func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map[string]string) {
86 | headers = make(map[string]string)
87 | if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) {
88 | return
89 | }
90 | // verify if requested method is allowed
91 | // TODO: Too many for loops
92 | for _, method := range o.AllowMethods {
93 | if method == rMethod {
94 | headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",")
95 | break
96 | }
97 | }
98 |
99 | // verify if requested headers are allowed
100 | var allowed []string
101 | for _, rHeader := range strings.Split(rHeaders, ",") {
102 | lookupLoop:
103 | for _, allowedHeader := range o.AllowHeaders {
104 | if rHeader == allowedHeader {
105 | allowed = append(allowed, rHeader)
106 | break lookupLoop
107 | }
108 | }
109 | }
110 |
111 | // add allowed headers
112 | if len(allowed) > 0 {
113 | headers[headerAllowHeaders] = strings.Join(allowed, ",")
114 | }
115 |
116 | // add exposed headers
117 | if len(o.ExposeHeaders) > 0 {
118 | headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",")
119 | }
120 | // add a max age header
121 | if o.MaxAge > time.Duration(0) {
122 | headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10)
123 | }
124 | return
125 | }
126 |
127 | // Looks up if the origin matches one of the patterns
128 | // provided in Options.AllowOrigins patterns.
129 | func (o *Options) IsOriginAllowed(origin string) (allowed bool) {
130 | for _, pattern := range o.AllowOrigins {
131 | allowed, _ = regexp.MatchString(pattern, origin)
132 | if allowed {
133 | return
134 | }
135 | }
136 | return
137 | }
138 |
139 | // Allows CORS for requests those match the provided options.
140 | func Allow(opts *Options) http.HandlerFunc {
141 | return func(res http.ResponseWriter, req *http.Request) {
142 | var (
143 | origin = req.Header.Get(headerOrigin)
144 | requestedMethod = req.Header.Get(headerRequestMethod)
145 | requestedHeaders = req.Header.Get(headerRequestHeaders)
146 | // additional headers to be added
147 | // to the response.
148 | headers map[string]string
149 | )
150 |
151 | if req.Method == "OPTIONS" &&
152 | (requestedMethod != "" || requestedHeaders != "") {
153 | // TODO: if preflight, respond with exact headers if allowed
154 | headers = opts.PreflightHeader(origin, requestedMethod, requestedHeaders)
155 | } else {
156 | headers = opts.Header(origin)
157 | }
158 |
159 | for key, value := range headers {
160 | res.Header().Set(key, value)
161 | }
162 | }
163 | }
164 |
--------------------------------------------------------------------------------
/cors/cors_test.go:
--------------------------------------------------------------------------------
1 | package cors
2 |
3 | import (
4 | "net/http"
5 | "net/http/httptest"
6 | "testing"
7 | "time"
8 |
9 | "github.com/codegangsta/martini"
10 | )
11 |
12 | func Test_AllowAll(t *testing.T) {
13 | recorder := httptest.NewRecorder()
14 | m := martini.New()
15 | m.Use(Allow(&Options{
16 | AllowAllOrigins: true,
17 | }))
18 |
19 | r, _ := http.NewRequest("PUT", "foo", nil)
20 | m.ServeHTTP(recorder, r)
21 |
22 | if recorder.HeaderMap.Get(headerAllowOrigin) != "*" {
23 | t.Errorf("Allow-Origin header should be *")
24 | }
25 | }
26 |
27 | func Test_AllowRegexMatch(t *testing.T) {
28 | recorder := httptest.NewRecorder()
29 | m := martini.New()
30 | m.Use(Allow(&Options{
31 | AllowOrigins: []string{"https://aaa.com", "https://foo\\.*"},
32 | }))
33 |
34 | origin := "https://foo.com"
35 | r, _ := http.NewRequest("PUT", "foo", nil)
36 | r.Header.Add("Origin", origin)
37 | m.ServeHTTP(recorder, r)
38 |
39 | headerValue := recorder.HeaderMap.Get(headerAllowOrigin)
40 | if headerValue != origin {
41 | t.Errorf("Allow-Origin header should be %v, found %v", origin, headerValue)
42 | }
43 | }
44 |
45 | func Test_AllowRegexNoMatch(t *testing.T) {
46 | recorder := httptest.NewRecorder()
47 | m := martini.New()
48 | m.Use(Allow(&Options{
49 | AllowOrigins: []string{"https://foo\\.*"},
50 | }))
51 |
52 | origin := "https://bar.com"
53 | r, _ := http.NewRequest("PUT", "foo", nil)
54 | r.Header.Add("Origin", origin)
55 | m.ServeHTTP(recorder, r)
56 |
57 | headerValue := recorder.HeaderMap.Get(headerAllowOrigin)
58 | if headerValue != "" {
59 | t.Errorf("Allow-Origin header should not exist, found %v", headerValue)
60 | }
61 | }
62 |
63 | func Test_OtherHeaders(t *testing.T) {
64 | recorder := httptest.NewRecorder()
65 | m := martini.New()
66 | m.Use(Allow(&Options{
67 | AllowAllOrigins: true,
68 | AllowCredentials: true,
69 | AllowMethods: []string{"PATCH", "GET"},
70 | AllowHeaders: []string{"Origin", "X-whatever"},
71 | ExposeHeaders: []string{"Content-Length", "Hello"},
72 | MaxAge: 5 * time.Minute,
73 | }))
74 |
75 | r, _ := http.NewRequest("PUT", "foo", nil)
76 | m.ServeHTTP(recorder, r)
77 |
78 | credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials)
79 | methodsVal := recorder.HeaderMap.Get(headerAllowMethods)
80 | headersVal := recorder.HeaderMap.Get(headerAllowHeaders)
81 | exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders)
82 | maxAgeVal := recorder.HeaderMap.Get(headerMaxAge)
83 |
84 | if credentialsVal != "true" {
85 | t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal)
86 | }
87 |
88 | if methodsVal != "PATCH,GET" {
89 | t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal)
90 | }
91 |
92 | if headersVal != "Origin,X-whatever" {
93 | t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal)
94 | }
95 |
96 | if exposedHeadersVal != "Content-Length,Hello" {
97 | t.Errorf("Expose-Headers are expected to be Content-Length,Hello. Found %v", exposedHeadersVal)
98 | }
99 |
100 | if maxAgeVal != "300" {
101 | t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal)
102 | }
103 | }
104 |
105 | func Test_Preflight(t *testing.T) {
106 | recorder := httptest.NewRecorder()
107 | m := martini.New()
108 | m.Use(Allow(&Options{
109 | AllowAllOrigins: true,
110 | AllowMethods: []string{"PUT", "PATCH"},
111 | AllowHeaders: []string{"Origin", "X-whatever"},
112 | }))
113 |
114 | r, _ := http.NewRequest("OPTIONS", "foo", nil)
115 | r.Header.Add(headerRequestMethod, "PUT")
116 | r.Header.Add(headerRequestHeaders, "X-whatever")
117 | m.ServeHTTP(recorder, r)
118 |
119 | methodsVal := recorder.HeaderMap.Get(headerAllowMethods)
120 | headersVal := recorder.HeaderMap.Get(headerAllowHeaders)
121 |
122 | if methodsVal != "PUT,PATCH" {
123 | t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal)
124 | }
125 |
126 | if headersVal != "X-whatever" {
127 | t.Errorf("Allow-Headers is expected to be X-whatever, found %v", headersVal)
128 | }
129 | }
130 |
131 | func Benchmark_WithoutCORS(b *testing.B) {
132 | recorder := httptest.NewRecorder()
133 | m := martini.New()
134 |
135 | b.ResetTimer()
136 | for i := 0; i < 100; i++ {
137 | r, _ := http.NewRequest("PUT", "foo", nil)
138 | m.ServeHTTP(recorder, r)
139 | }
140 | }
141 |
142 | func Benchmark_WithCORS(b *testing.B) {
143 | recorder := httptest.NewRecorder()
144 | m := martini.New()
145 | m.Use(Allow(&Options{
146 | AllowAllOrigins: true,
147 | AllowCredentials: true,
148 | AllowMethods: []string{"PATCH", "GET"},
149 | AllowHeaders: []string{"Origin", "X-whatever"},
150 | MaxAge: 5 * time.Minute,
151 | }))
152 |
153 | b.ResetTimer()
154 | for i := 0; i < 100; i++ {
155 | r, _ := http.NewRequest("PUT", "foo", nil)
156 | m.ServeHTTP(recorder, r)
157 | }
158 | }
159 |
--------------------------------------------------------------------------------
/encoder/README.md:
--------------------------------------------------------------------------------
1 | #### Encoder.
2 | This is a simple wrapper around json.Marshal that adds the ability to skip some fields
3 | of a struct.
4 | Unlike the 'render' package, it doesn't write anything; it just returns the marshalled data.
5 | It's useful for hiding things like passwords, statuses, activation codes, etc.
6 |
7 | E.g.:
8 |
9 | ```go
10 | type Some struct {
11 | Login string `json:"login"`
12 | Password string `json:"password,omitempty" out:"false"`
13 | }
14 | ```
15 |
16 | The 'Password' field won't be included in the output.
17 |
18 | #### Usage.
19 | It's pretty straightforward:
20 |
21 | ```go
22 | m.Use(func(c martini.Context, w http.ResponseWriter) {
23 | c.MapTo(encoder.JsonEncoder{}, (*encoder.Encoder)(nil))
24 | w.Header().Set("Content-Type", "application/json; charset=utf-8")
25 | })
26 | ```
27 |
28 | Here is a ready to use example:
29 |
30 | ```go
31 | package main
32 |
33 | import (
34 | "github.com/codegangsta/martini-contrib/encoder"
35 | "github.com/codegangsta/martini"
36 | "log"
37 | "net/http"
38 | )
39 |
40 | type Some struct {
41 | Login string `json:"login"`
42 | Password string `json:"password" out:"false"`
43 | }
44 |
45 | func main() {
46 | m := martini.New()
47 | route := martini.NewRouter()
48 |
49 | // map json encoder
50 | m.Use(func(c martini.Context, w http.ResponseWriter) {
51 | c.MapTo(encoder.JsonEncoder{}, (*encoder.Encoder)(nil))
52 | w.Header().Set("Content-Type", "application/json; charset=utf-8")
53 | })
54 |
55 | route.Get("/test", func(enc encoder.Encoder) (int, []byte) {
56 | result := &Some{"awesome", "hidden"}
57 | return http.StatusOK, encoder.Must(enc.Encode(result))
58 | })
59 |
60 | m.Action(route.Handle)
61 |
62 | log.Println("Waiting for connections...")
63 |
64 | if err := http.ListenAndServe(":8000", m); err != nil {
65 | log.Fatal(err)
66 | }
67 | }
68 | ```
69 |
--------------------------------------------------------------------------------
/encoder/encoder.go:
--------------------------------------------------------------------------------
1 | package encoder
2 |
3 | // Original code borrowed from https://github.com/PuerkitoBio/martini-api-example
4 | // TextEncoder and XmlEncoder has been removed. If someone really needs it, let me know.
5 |
6 | // Supported tags:
7 | // - "out": when set to "false", the field's value is not copied into the encoded output
8 | import (
9 | "encoding/json"
10 | "reflect"
11 | )
12 |
// An Encoder implements an encoding format of values to be sent as response to
// requests on the API endpoints.
type Encoder interface {
	// Encode serializes the given values and returns the encoded bytes, or
	// an error when encoding fails.
	Encode(v ...interface{}) ([]byte, error)
}
18 |
// Must unwraps the (data, err) pair returned by an Encoder: data is handed
// back untouched when err is nil, otherwise Must panics with err.
//
// Per the original author's note, panics are caught by martini's Recovery
// handler, so this can be used to turn an encoding failure into a
// server-side error (500). A helpful text message should probably be sent,
// although not the technical error (which is printed in the log).
func Must(data []byte, err error) []byte {
	if err == nil {
		return data
	}
	panic(err)
}
28 |
// JsonEncoder is an Encoder that produces JSON-formatted responses.
type JsonEncoder struct{}

// Encode marshals its arguments to JSON.
//
// No arguments produce `[]`, a single argument is encoded as-is, and several
// arguments are encoded as a JSON array. When the single argument is a struct
// (or a non-nil pointer to one), fields tagged `out:"false"` are reset to
// their zero value before marshalling, recursively for nested struct fields.
// A nil argument or nil pointer encodes as `null`.
func (JsonEncoder) Encode(v ...interface{}) ([]byte, error) {
	var data interface{} = v

	if v == nil {
		// So that empty results produces `[]` and not `null`
		data = []interface{}{}
	} else if len(v) == 1 {
		data = v[0]
	}

	val := reflect.ValueOf(data)
	if !val.IsValid() {
		// data is an untyped nil: there is no type to inspect.
		return json.Marshal(data)
	}

	t := val.Type()
	if t.Kind() == reflect.Ptr {
		if val.IsNil() {
			// A nil pointer has nothing to filter and must not be dereferenced.
			return json.Marshal(data)
		}
		t = t.Elem()
	}

	if t.Kind() != reflect.Struct {
		return json.Marshal(data)
	}

	return json.Marshal(copyStruct(val, t).Interface())
}

// copyStruct returns a copy of the struct held by v (or pointed to by v) of
// type t, in which every exported field tagged `out:"false"` is zeroed.
// Nested exported struct fields are filtered recursively. Unexported fields
// are carried over verbatim: they cannot be set individually through
// reflection, so the struct is copied wholesale first and only exported
// fields are touched afterwards.
func copyStruct(v reflect.Value, t reflect.Type) reflect.Value {
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}

	result := reflect.New(t).Elem()
	result.Set(v)

	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)

		// A non-empty PkgPath marks an unexported field; Set would panic on it.
		if field.PkgPath != "" {
			continue
		}

		if field.Tag.Get("out") == "false" {
			result.Field(i).Set(reflect.Zero(field.Type))
			continue
		}

		if field.Type.Kind() == reflect.Struct {
			result.Field(i).Set(copyStruct(v.Field(i), field.Type))
		}
	}

	return result
}
82 |
--------------------------------------------------------------------------------
/encoder/encoder_test.go:
--------------------------------------------------------------------------------
1 | package encoder
2 |
3 | import (
4 | "encoding/json"
5 | "testing"
6 | )
7 |
// Sample is the fixture for TestEncoder: Visible is encoded normally, while
// Hidden carries the `out:"false"` tag and is expected to come back empty.
type Sample struct {
	Visible string `json:"visible"`
	Hidden  string `json:"hidden" out:"false"`
}
12 |
13 | func TestEncoder(t *testing.T) {
14 | src := &Sample{Visible: "visible", Hidden: "this field won't be exported"}
15 | dst := &Sample{}
16 |
17 | enc := &JsonEncoder{}
18 | result, err := enc.Encode(src)
19 | if err != nil {
20 | t.Fatal(err)
21 | }
22 |
23 | if err := json.Unmarshal(result, dst); err != nil {
24 | t.Fatal("Unmarshal error:", err)
25 | }
26 |
27 | if dst.Hidden != "" {
28 | t.Fatalf("Expected empty field 'Hidden', got %v\n", dst.Hidden)
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/gzip/README.md:
--------------------------------------------------------------------------------
1 | # gzip
2 | Gzip middleware for Martini.
3 |
4 | [API Reference](http://godoc.org/github.com/codegangsta/martini-contrib/gzip)
5 |
6 | ## Usage
7 |
8 | ~~~ go
9 | import (
10 | "github.com/codegangsta/martini"
11 | "github.com/codegangsta/martini-contrib/gzip"
12 | )
13 |
14 | func main() {
15 | m := martini.Classic()
16 | // gzip every request
17 | m.Use(gzip.All())
18 | m.Run()
19 | }
20 |
21 | ~~~
22 |
23 | Make sure to include the Gzip middleware above other middleware that alter the response body (like the render middleware).
24 |
25 | ## Authors
26 | * [Jeremy Saenz](http://github.com/codegangsta)
27 | * [Shane Logsdon](http://github.com/slogsdon)
28 |
--------------------------------------------------------------------------------
/gzip/gzip.go:
--------------------------------------------------------------------------------
1 | package gzip
2 |
3 | import (
4 | "compress/gzip"
5 | "github.com/codegangsta/martini"
6 | "net/http"
7 | "strings"
8 | )
9 |
// Names of the HTTP headers used by this package.
const (
	HeaderAcceptEncoding  = "Accept-Encoding"
	HeaderContentEncoding = "Content-Encoding"
	HeaderContentLength   = "Content-Length"
	HeaderContentType     = "Content-Type"
	HeaderVary            = "Vary"
)
17 |
18 | var serveGzip = func(w http.ResponseWriter, r *http.Request, c martini.Context) {
19 | if !strings.Contains(r.Header.Get(HeaderAcceptEncoding), "gzip") {
20 | return
21 | }
22 |
23 | headers := w.Header()
24 | headers.Set(HeaderContentEncoding, "gzip")
25 | headers.Set(HeaderVary, HeaderAcceptEncoding)
26 |
27 | gz := gzip.NewWriter(w)
28 | defer gz.Close()
29 |
30 | gzw := gzipResponseWriter{gz, w.(martini.ResponseWriter)}
31 | c.MapTo(gzw, (*http.ResponseWriter)(nil))
32 |
33 | c.Next()
34 |
35 | // delete content length after we know we have been written to
36 | gzw.Header().Del("Content-Length")
37 | }
38 |
// All returns a Handler that adds gzip compression to all requests.
// The returned handler is the package-level serveGzip function.
func All() martini.Handler {
	return serveGzip
}
43 |
44 | type gzipResponseWriter struct {
45 | w *gzip.Writer
46 | martini.ResponseWriter
47 | }
48 |
49 | func (grw gzipResponseWriter) Write(p []byte) (int, error) {
50 | if len(grw.Header().Get(HeaderContentType)) == 0 {
51 | grw.Header().Set(HeaderContentType, http.DetectContentType(p))
52 | }
53 |
54 | return grw.w.Write(p)
55 | }
56 |
--------------------------------------------------------------------------------
/gzip/gzip_test.go:
--------------------------------------------------------------------------------
1 | package gzip
2 |
3 | import (
4 | "github.com/codegangsta/martini"
5 | "net/http"
6 | "net/http/httptest"
7 | "strings"
8 | "testing"
9 | )
10 |
11 | func Test_GzipAll(t *testing.T) {
12 | // Set up
13 | recorder := httptest.NewRecorder()
14 | before := false
15 |
16 | m := martini.New()
17 | m.Use(All())
18 | m.Use(func(r http.ResponseWriter) {
19 | r.(martini.ResponseWriter).Before(func(rw martini.ResponseWriter) {
20 | before = true
21 | })
22 | })
23 |
24 | r, err := http.NewRequest("GET", "/", nil)
25 | if err != nil {
26 | t.Error(err)
27 | }
28 |
29 | m.ServeHTTP(recorder, r)
30 |
31 | // Make our assertions
32 | _, ok := recorder.HeaderMap[HeaderContentEncoding]
33 | if ok {
34 | t.Error(HeaderContentEncoding + " present")
35 | }
36 |
37 | ce := recorder.Header().Get(HeaderContentEncoding)
38 | if strings.EqualFold(ce, "gzip") {
39 | t.Error(HeaderContentEncoding + " is 'gzip'")
40 | }
41 |
42 | recorder = httptest.NewRecorder()
43 | r.Header.Set(HeaderAcceptEncoding, "gzip")
44 | m.ServeHTTP(recorder, r)
45 |
46 | // Make our assertions
47 | _, ok = recorder.HeaderMap[HeaderContentEncoding]
48 | if !ok {
49 | t.Error(HeaderContentEncoding + " not present")
50 | }
51 |
52 | ce = recorder.Header().Get(HeaderContentEncoding)
53 | if !strings.EqualFold(ce, "gzip") {
54 | t.Error(HeaderContentEncoding + " is not 'gzip'")
55 | }
56 |
57 | if before == false {
58 | t.Error("Before hook was not called")
59 | }
60 | }
61 |
--------------------------------------------------------------------------------
/method/README.md:
--------------------------------------------------------------------------------
1 | # method
2 | Martini middleware/handler for handling http method overrides.
3 | This checks for the X-HTTP-Method-Override header and uses it
4 | if the original request method is POST.
5 | GET/HEAD methods shouldn't be overridden, hence they can't be overridden.
6 |
7 | This is useful for REST APIs and services making use of many HTTP verbs, and when http clients don't support all of them.
8 |
9 | [API Reference](http://godoc.org/github.com/codegangsta/martini-contrib/method)
10 |
11 | ## Usage
12 |
13 | ~~~ go
14 | import (
15 | "github.com/codegangsta/martini"
16 | "github.com/codegangsta/martini-contrib/method"
17 | )
18 |
19 | func main() {
20 | m := martini.Classic()
21 | m.Use(method.Override())
22 | m.Run()
23 | }
24 |
25 | ~~~
26 |
27 | ## Authors
28 | * [Vincent Petithory](http://github.com/vincent-petithory)
29 |
--------------------------------------------------------------------------------
/method/override.go:
--------------------------------------------------------------------------------
1 | // package method implements http method override
2 | // using the X-HTTP-Method-Override http header.
3 | package method
4 |
5 | import (
6 | "errors"
7 | "net/http"
8 | )
9 |
// HeaderHTTPMethodOverride is a commonly used
// HTTP header to override the method.
const HeaderHTTPMethodOverride = "X-HTTP-Method-Override"

// ParamHTTPMethodOverride is a commonly used
// HTML form parameter to override the method.
const ParamHTTPMethodOverride = "_method"
17 |
// httpMethods lists the methods a request may be overridden to.
var httpMethods = []string{"PUT", "PATCH", "DELETE"}

// ErrInvalidOverrideMethod is returned when
// an invalid http method was given to OverrideRequestMethod.
var ErrInvalidOverrideMethod = errors.New("invalid override method")

// isValidOverrideMethod reports whether method is one of httpMethods.
// The comparison is exact, so it is case-sensitive.
func isValidOverrideMethod(method string) bool {
	for i := 0; i < len(httpMethods); i++ {
		if httpMethods[i] == method {
			return true
		}
	}
	return false
}
32 |
33 | // Override checks for the X-HTTP-Method-Override header
34 | // or the HTML for parameter, `_method`
35 | // and uses (if valid) the http method instead of
36 | // Request.Method.
37 | // This is especially useful for http clients
38 | // that don't support many http verbs.
39 | // It isn't secure to override e.g a GET to a POST,
40 | // so only Request.Method which are POSTs are considered.
41 | func Override() http.Handler {
42 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
43 | if r.Method == "POST" {
44 | m := r.FormValue(ParamHTTPMethodOverride)
45 | if isValidOverrideMethod(m) {
46 | OverrideRequestMethod(r, m)
47 | }
48 | m = r.Header.Get(HeaderHTTPMethodOverride)
49 | if isValidOverrideMethod(m) {
50 | r.Method = m
51 | }
52 | }
53 | })
54 | }
55 |
56 | // OverrideRequestMethod overrides the http
57 | // request's method with the specified method.
58 | func OverrideRequestMethod(r *http.Request, method string) error {
59 | if !isValidOverrideMethod(method) {
60 | return ErrInvalidOverrideMethod
61 | }
62 | r.Header.Set(HeaderHTTPMethodOverride, method)
63 | return nil
64 | }
65 |
--------------------------------------------------------------------------------
/method/override_test.go:
--------------------------------------------------------------------------------
1 | package method
2 |
3 | import (
4 | "github.com/codegangsta/martini"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 | "time"
9 | )
10 |
// tests is the table shared by the tests in this file. Each entry gives the
// method the request is created with, the override that is requested, and the
// method the request is expected to carry after Override has run.
var tests = []struct {
	Method         string
	OverrideMethod string
	ExpectedMethod string
}{
	{"POST", "PUT", "PUT"},
	{"POST", "PATCH", "PATCH"},
	{"POST", "DELETE", "DELETE"},
	{"GET", "GET", "GET"},
	{"HEAD", "HEAD", "HEAD"},
	// Non-POST requests must keep their original method.
	{"GET", "PUT", "GET"},
	{"HEAD", "DELETE", "HEAD"},
}
24 |
25 | func TestOverride(t *testing.T) {
26 | for _, test := range tests {
27 | w := httptest.NewRecorder()
28 | r, err := http.NewRequest(test.Method, "/", nil)
29 | if err != nil {
30 | t.Error(err)
31 | }
32 | OverrideRequestMethod(r, test.OverrideMethod)
33 | Override().ServeHTTP(w, r)
34 | if r.Method != test.ExpectedMethod {
35 | t.Errorf("Expected %s, got %s", test.ExpectedMethod, r.Method)
36 | }
37 | }
38 | }
39 |
// selectRoute registers h on r for the path "/" under the given HTTP method.
// It panics with "bad method" for any method not listed in the switch.
func selectRoute(r martini.Router, method string, h martini.Handler) {
	switch method {
	case "GET":
		r.Get("/", h)
	case "PATCH":
		r.Patch("/", h)
	case "POST":
		r.Post("/", h)
	case "PUT":
		r.Put("/", h)
	case "DELETE":
		r.Delete("/", h)
	case "OPTIONS":
		r.Options("/", h)
	case "HEAD":
		r.Head("/", h)
	default:
		panic("bad method")
	}
}
60 |
61 | func TestMartiniSelectiveRouter(t *testing.T) {
62 | for _, test := range tests {
63 | w := httptest.NewRecorder()
64 | r := martini.NewRouter()
65 |
66 | done := make(chan bool)
67 | selectRoute(r, test.ExpectedMethod, func(rq *http.Request) {
68 | done <- true
69 | })
70 |
71 | req, err := http.NewRequest(test.Method, "/", nil)
72 | if err != nil {
73 | t.Fatal(err)
74 | }
75 | OverrideRequestMethod(req, test.OverrideMethod)
76 |
77 | m := martini.New()
78 | m.Use(Override())
79 | m.Action(r.Handle)
80 | go m.ServeHTTP(w, req)
81 | select {
82 | case <-done:
83 | case <-time.After(30 * time.Millisecond):
84 | t.Errorf("Expected router to route to %s, got something else (%v).", test.ExpectedMethod, test)
85 | }
86 | }
87 | }
88 |
89 | func TestInMartini(t *testing.T) {
90 | for _, test := range tests {
91 | w := httptest.NewRecorder()
92 | m := martini.New()
93 | m.Use(Override())
94 | m.Use(func(w http.ResponseWriter, r *http.Request) {
95 | if r.Method != test.ExpectedMethod {
96 | t.Errorf("Expected %s, got %s", test.ExpectedMethod, r.Method)
97 | }
98 | })
99 |
100 | r, err := http.NewRequest(test.Method, "/", nil)
101 | if err != nil {
102 | t.Fatal(err)
103 | }
104 | OverrideRequestMethod(r, test.OverrideMethod)
105 |
106 | m.ServeHTTP(w, r)
107 | }
108 |
109 | }
110 |
111 | func TestParamenterOverrideInMartini(t *testing.T) {
112 | for _, test := range tests {
113 | w := httptest.NewRecorder()
114 | m := martini.New()
115 | m.Use(Override())
116 | m.Use(func(w http.ResponseWriter, r *http.Request) {
117 | if r.Method != test.ExpectedMethod {
118 | t.Errorf("Expected %s, got %s", test.ExpectedMethod, r.Method)
119 | }
120 | })
121 |
122 | query := "_method=" + test.OverrideMethod
123 | r, err := http.NewRequest(test.Method, "/?"+query, nil)
124 | if err != nil {
125 | t.Fatal(err)
126 | }
127 |
128 | m.ServeHTTP(w, r)
129 | }
130 |
131 | }
132 |
--------------------------------------------------------------------------------
/render/README.md:
--------------------------------------------------------------------------------
1 | # render
2 | Martini middleware/handler for easily rendering serialized JSON and HTML template responses.
3 |
4 | [API Reference](http://godoc.org/github.com/codegangsta/martini-contrib/render)
5 |
6 | ## Usage
7 | render uses Go's [html/template](http://golang.org/pkg/html/template/) package to render html templates.
8 |
9 | ~~~ go
10 | // main.go
11 | package main
12 |
13 | import (
14 | "github.com/codegangsta/martini"
15 | "github.com/codegangsta/martini-contrib/render"
16 | )
17 |
18 | func main() {
19 | m := martini.Classic()
20 | // render html templates from templates directory
21 | m.Use(render.Renderer())
22 |
23 | m.Get("/", func(r render.Render) {
24 | r.HTML(200, "hello", "jeremy")
25 | })
26 |
27 | m.Run()
28 | }
29 |
30 | ~~~
31 |
32 | ~~~ html
33 | <!-- templates/hello.tmpl -->
34 | <h2>Hello {{.}}!</h2>
35 | ~~~
36 |
37 | ### Options
38 | `render.Renderer` comes with a variety of configuration options:
39 |
40 | ~~~ go
41 | // ...
42 | m.Use(render.Renderer(render.Options{
43 | Directory: "templates", // Specify what path to load the templates from.
44 | Layout: "layout", // Specify a layout template. Layouts can call {{ yield }} to render the current template.
45 | Extensions: []string{".tmpl", ".html"}, // Specify extensions to load for templates.
46 | Funcs: []template.FuncMap{AppHelpers}, // Specify helper function maps for templates to access.
47 | Delims: render.Delims{"{[{", "}]}"}, // Sets delimiters to the specified strings.
48 | Charset: "UTF-8", // Sets encoding for json and html content-types. Default is "UTF-8".
49 | IndentJSON: true, // Output human readable JSON
50 | }))
51 | // ...
52 | ~~~
53 |
54 | ### Loading Templates
55 | By default the `render.Renderer` middleware will attempt to load templates with a '.tmpl' extension from the "templates" directory. Templates are found by traversing the templates directory and are named by path and basename. For instance, the following directory structure:
56 |
57 | ~~~
58 | templates/
59 | |
60 | |__ admin/
61 | | |
62 | | |__ index.tmpl
63 | | |
64 | | |__ edit.tmpl
65 | |
66 | |__ home.tmpl
67 | ~~~
68 |
69 | Will provide the following templates:
70 | ~~~
71 | admin/index
72 | admin/edit
73 | home
74 | ~~~
75 | ### Layouts
76 | `render.Renderer` provides a `yield` function for layouts to access:
77 | ~~~ go
78 | // ...
79 | m.Use(render.Renderer(render.Options{
80 | Layout: "layout",
81 | }))
82 | // ...
83 | ~~~
84 |
85 | ~~~ html
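86 | <!-- templates/layout.tmpl -->
87 | <html>
88 |   <head>
89 |     <title>Martini Plz</title>
90 |   </head>
91 |   <body>
92 |     <!-- Render the current template here -->
93 |     {{ yield }}
94 |   </body>
95 | </html>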
96 | ~~~
97 |
98 | ### Character Encodings
99 | The `render.Renderer` middleware will automatically set the proper Content-Type header based on which function you call. See below for an example of what the default settings would output (note that UTF-8 is the default):
100 | ~~~ go
101 | // main.go
102 | package main
103 |
104 | import (
105 | "github.com/codegangsta/martini"
106 | "github.com/codegangsta/martini-contrib/render"
107 | )
108 |
109 | func main() {
110 | m := martini.Classic()
111 | m.Use(render.Renderer())
112 |
113 | // This will set the Content-Type header to "text/html; charset=UTF-8"
114 | m.Get("/", func(r render.Render) {
115 | r.HTML(200, "hello", "world")
116 | })
117 |
118 | // This will set the Content-Type header to "application/json; charset=UTF-8"
119 | m.Get("/api", func(r render.Render) {
120 | r.JSON(200, map[string]interface{}{"hello": "world"})
121 | })
122 |
123 | m.Run()
124 | }
125 |
126 | ~~~
127 |
128 | In order to change the charset, you can set the `Charset` within the `render.Options` to your encoding value:
129 | ~~~ go
130 | // main.go
131 | package main
132 |
133 | import (
134 | "github.com/codegangsta/martini"
135 | "github.com/codegangsta/martini-contrib/render"
136 | )
137 |
138 | func main() {
139 | m := martini.Classic()
140 | m.Use(render.Renderer(render.Options{
141 | Charset: "ISO-8859-1",
142 | }))
143 |
144 | // This will set the Content-Type to "text/html; charset=ISO-8859-1"
145 | m.Get("/", func(r render.Render) {
146 | r.HTML(200, "hello", "world")
147 | })
148 |
149 | // This will set the Content-Type to "application/json; charset=ISO-8859-1"
150 | m.Get("/api", func(r render.Render) {
151 | r.JSON(200, map[string]interface{}{"hello": "world"})
152 | })
153 |
154 | m.Run()
155 | }
156 |
157 | ~~~
158 |
159 | ## Authors
160 | * [Jeremy Saenz](http://github.com/codegangsta)
161 | * [Cory Jacobsen](http://github.com/cojac)
162 |
--------------------------------------------------------------------------------
/render/fixtures/basic/admin/index.tmpl:
--------------------------------------------------------------------------------
1 | Admin {{.}}
2 |
--------------------------------------------------------------------------------
/render/fixtures/basic/another_layout.tmpl:
--------------------------------------------------------------------------------
1 | another head
2 | {{ yield }}
3 | another foot
4 |
--------------------------------------------------------------------------------
/render/fixtures/basic/content.tmpl:
--------------------------------------------------------------------------------
1 | {{ . }}
2 |
--------------------------------------------------------------------------------
/render/fixtures/basic/delims.tmpl:
--------------------------------------------------------------------------------
1 | Hello {[{.}]}
--------------------------------------------------------------------------------
/render/fixtures/basic/hello.tmpl:
--------------------------------------------------------------------------------
1 | Hello {{.}}
2 |
--------------------------------------------------------------------------------
/render/fixtures/basic/hypertext.html:
--------------------------------------------------------------------------------
1 | Hypertext!
2 |
--------------------------------------------------------------------------------
/render/fixtures/basic/layout.tmpl:
--------------------------------------------------------------------------------
1 | head
2 | {{ yield }}
3 | foot
4 |
--------------------------------------------------------------------------------
/render/fixtures/custom_funcs/index.tmpl:
--------------------------------------------------------------------------------
1 | {{ myCustomFunc }}
2 |
--------------------------------------------------------------------------------
/render/render.go:
--------------------------------------------------------------------------------
1 | // Package render is a middleware for Martini that provides easy JSON serialization and HTML template rendering.
2 | //
3 | // package main
4 | //
5 | // import (
6 | // "github.com/codegangsta/martini"
7 | // "github.com/codegangsta/martini-contrib/render"
8 | // )
9 | //
10 | // func main() {
11 | // m := martini.Classic()
12 | // m.Use(render.Renderer()) // reads "templates" directory by default
13 | //
14 | // m.Get("/html", func(r render.Render) {
15 | // r.HTML(200, "mytemplate", nil)
16 | // })
17 | //
18 | // m.Get("/json", func(r render.Render) {
19 | // r.JSON(200, "hello world")
20 | // })
21 | //
22 | // m.Run()
23 | // }
24 | package render
25 |
26 | import (
27 | "bytes"
28 | "encoding/json"
29 | "fmt"
30 | "github.com/codegangsta/martini"
31 | "html/template"
32 | "io"
33 | "io/ioutil"
34 | "net/http"
35 | "os"
36 | "path/filepath"
37 | )
38 |
// Header names and content-type values used when writing responses.
const (
	ContentType    = "Content-Type"
	ContentLength  = "Content-Length"
	ContentJSON    = "application/json"
	ContentHTML    = "text/html"
	defaultCharset = "UTF-8" // used by prepareCharset when Options.Charset is empty
)
46 |
// Included helper functions for use when rendering html.
// The "yield" registered here always fails; renderer.addYield replaces it
// with a working implementation when a layout is being rendered.
var helperFuncs = template.FuncMap{
	"yield": func() (string, error) {
		return "", fmt.Errorf("yield called with no layout defined")
	},
}
53 |
// Render is a service that can be injected into a Martini handler. Render provides functions for easily writing JSON and
// HTML templates out to a http Response.
type Render interface {
	// JSON writes the given status and JSON serialized version of the given value to the http.ResponseWriter.
	JSON(status int, v interface{})
	// HTML renders a html template specified by the name and writes the result and given status to the http.ResponseWriter.
	// An optional HTMLOptions value overrides the configured layout for this call.
	HTML(status int, name string, v interface{}, htmlOpt ...HTMLOptions)
	// Error is a convenience function that writes an http status to the http.ResponseWriter.
	Error(status int)
	// Redirect is a convenience function that sends an HTTP redirect. If status is omitted, uses 302 (Found)
	Redirect(location string, status ...int)
	// Template returns the internal *template.Template used to render the HTML
	Template() *template.Template
}
68 |
// Delims represents a set of Left and Right delimiters for HTML template rendering.
// Both values are passed to the template's Delims method when templates are compiled.
type Delims struct {
	// Left delimiter, defaults to {{
	Left string
	// Right delimiter, defaults to }}
	Right string
}
76 |
// Options is a struct for specifying configuration options for the render.Renderer middleware
type Options struct {
	// Directory to load templates. Default is "templates"
	Directory string
	// Layout template name. Will not render a layout if "". Defaults to "".
	Layout string
	// Extensions to parse template files from. Defaults to [".tmpl"]
	Extensions []string
	// Funcs is a slice of FuncMaps to apply to the template upon compilation. This is useful for helper functions. Defaults to [].
	Funcs []template.FuncMap
	// Delims sets the action delimiters to the specified strings in the Delims struct.
	Delims Delims
	// Appends the given charset to the Content-Type header. Default is "UTF-8".
	Charset string
	// Outputs human readable JSON (json.MarshalIndent is used instead of json.Marshal)
	IndentJSON bool
}
94 |
// HTMLOptions is a struct for overriding some rendering Options for a specific HTML call
type HTMLOptions struct {
	// Layout template name. Overrides Options.Layout.
	// When an HTMLOptions value is passed, an empty Layout renders without any layout.
	Layout string
}
100 |
101 | // Renderer is a Middleware that maps a render.Render service into the Martini handler chain. An single variadic render.Options
102 | // struct can be optionally provided to configure HTML rendering. The default directory for templates is "templates" and the default
103 | // file extension is ".tmpl".
104 | //
105 | // If MARTINI_ENV is set to "" or "development" then templates will be recompiled on every request. For more performance, set the
106 | // MARTINI_ENV environment variable to "production"
107 | func Renderer(options ...Options) martini.Handler {
108 | opt := prepareOptions(options)
109 | cs := prepareCharset(opt.Charset)
110 | t := compile(opt)
111 | return func(res http.ResponseWriter, req *http.Request, c martini.Context) {
112 | // recompile for easy development
113 | if martini.Env == martini.Dev {
114 | t = compile(opt)
115 | }
116 | tc, _ := t.Clone()
117 | c.MapTo(&renderer{res, req, tc, opt, cs}, (*Render)(nil))
118 | }
119 | }
120 |
121 | func prepareCharset(charset string) string {
122 | if len(charset) != 0 {
123 | return "; charset=" + charset
124 | }
125 |
126 | return "; charset=" + defaultCharset
127 | }
128 |
129 | func prepareOptions(options []Options) Options {
130 | var opt Options
131 | if len(options) > 0 {
132 | opt = options[0]
133 | }
134 |
135 | // Defaults
136 | if len(opt.Directory) == 0 {
137 | opt.Directory = "templates"
138 | }
139 | if len(opt.Extensions) == 0 {
140 | opt.Extensions = []string{".tmpl"}
141 | }
142 |
143 | return opt
144 | }
145 |
146 | func compile(options Options) *template.Template {
147 | dir := options.Directory
148 | t := template.New(dir)
149 | t.Delims(options.Delims.Left, options.Delims.Right)
150 | // parse an initial template in case we don't have any
151 | template.Must(t.Parse("Martini"))
152 |
153 | filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
154 | r, err := filepath.Rel(dir, path)
155 | if err != nil {
156 | return err
157 | }
158 |
159 | ext := filepath.Ext(r)
160 | for _, extension := range options.Extensions {
161 | if ext == extension {
162 |
163 | buf, err := ioutil.ReadFile(path)
164 | if err != nil {
165 | panic(err)
166 | }
167 |
168 | name := (r[0 : len(r)-len(ext)])
169 | tmpl := t.New(filepath.ToSlash(name))
170 |
171 | // add our funcmaps
172 | for _, funcs := range options.Funcs {
173 | tmpl.Funcs(funcs)
174 | }
175 |
176 | // Bomb out if parse fails. We don't want any silent server starts.
177 | template.Must(tmpl.Funcs(helperFuncs).Parse(string(buf)))
178 | break
179 | }
180 | }
181 |
182 | return nil
183 | })
184 |
185 | return t
186 | }
187 |
// renderer is the Render implementation mapped into the handler chain by
// Renderer; one is created per request.
type renderer struct {
	http.ResponseWriter
	req             *http.Request      // current request, needed by Redirect
	t               *template.Template // per-request clone of the compiled templates
	opt             Options
	compiledCharset string // "; charset=..." suffix for Content-Type values
}
195 |
196 | func (r *renderer) JSON(status int, v interface{}) {
197 | var result []byte
198 | var err error
199 | if r.opt.IndentJSON {
200 | result, err = json.MarshalIndent(v, "", " ")
201 | } else {
202 | result, err = json.Marshal(v)
203 | }
204 | if err != nil {
205 | http.Error(r, err.Error(), 500)
206 | return
207 | }
208 |
209 | // json rendered fine, write out the result
210 | r.Header().Set(ContentType, ContentJSON+r.compiledCharset)
211 | r.WriteHeader(status)
212 | r.Write(result)
213 | }
214 |
215 | func (r *renderer) HTML(status int, name string, binding interface{}, htmlOpt ...HTMLOptions) {
216 | opt := r.prepareHTMLOptions(htmlOpt)
217 | // assign a layout if there is one
218 | if len(opt.Layout) > 0 {
219 | r.addYield(name, binding)
220 | name = opt.Layout
221 | }
222 |
223 | out, err := r.execute(name, binding)
224 | if err != nil {
225 | http.Error(r, err.Error(), http.StatusInternalServerError)
226 | return
227 | }
228 |
229 | // template rendered fine, write out the result
230 | r.Header().Set(ContentType, ContentHTML+r.compiledCharset)
231 | r.WriteHeader(status)
232 | io.Copy(r, out)
233 | }
234 |
// Error writes the given HTTP status to the current ResponseWriter.
// No response body is written.
func (r *renderer) Error(status int) {
	r.WriteHeader(status)
}
239 |
240 | func (r *renderer) Redirect(location string, status ...int) {
241 | code := http.StatusFound
242 | if len(status) == 1 {
243 | code = status[0]
244 | }
245 |
246 | http.Redirect(r, r.req, location, code)
247 | }
248 |
// Template returns the *template.Template this renderer executes templates on.
func (r *renderer) Template() *template.Template {
	return r.t
}
252 |
253 | func (r *renderer) execute(name string, binding interface{}) (*bytes.Buffer, error) {
254 | buf := new(bytes.Buffer)
255 | return buf, r.t.ExecuteTemplate(buf, name, binding)
256 | }
257 |
258 | func (r *renderer) addYield(name string, binding interface{}) {
259 | funcs := template.FuncMap{
260 | "yield": func() (template.HTML, error) {
261 | buf, err := r.execute(name, binding)
262 | // return safe html here since we are rendering our own template
263 | return template.HTML(buf.String()), err
264 | },
265 | }
266 | r.t.Funcs(funcs)
267 | }
268 |
269 | func (r *renderer) prepareHTMLOptions(htmlOpt []HTMLOptions) HTMLOptions {
270 | if len(htmlOpt) > 0 {
271 | return htmlOpt[0]
272 | }
273 |
274 | return HTMLOptions{
275 | Layout: r.opt.Layout,
276 | }
277 | }
278 |
--------------------------------------------------------------------------------
/render/render_test.go:
--------------------------------------------------------------------------------
1 | package render
2 |
3 | import (
4 | "github.com/codegangsta/martini"
5 | "html/template"
6 | "net/http"
7 | "net/http/httptest"
8 | "net/url"
9 | "reflect"
10 | "testing"
11 | )
12 |
// Greeting is the value serialized by the JSON rendering tests.
type Greeting struct {
	One string `json:"one"`
	Two string `json:"two"`
}
17 |
18 | func Test_Render_JSON(t *testing.T) {
19 | m := martini.Classic()
20 | m.Use(Renderer(Options{
21 | // nothing here to configure
22 | }))
23 |
24 | // routing
25 | m.Get("/foobar", func(r Render) {
26 | r.JSON(300, Greeting{"hello", "world"})
27 | })
28 |
29 | res := httptest.NewRecorder()
30 | req, _ := http.NewRequest("GET", "/foobar", nil)
31 |
32 | m.ServeHTTP(res, req)
33 |
34 | expect(t, res.Code, 300)
35 | expect(t, res.Header().Get(ContentType), ContentJSON+"; charset=UTF-8")
36 | expect(t, res.Body.String(), `{"one":"hello","two":"world"}`)
37 | }
38 |
// Test_Render_Indented_JSON checks that, with Options.IndentJSON enabled,
// Render.JSON writes the indented encoding of the value.
func Test_Render_Indented_JSON(t *testing.T) {
	m := martini.Classic()
	m.Use(Renderer(Options{
		IndentJSON: true,
	}))

	// routing
	m.Get("/foobar", func(r Render) {
		r.JSON(300, Greeting{"hello", "world"})
	})

	res := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/foobar", nil)

	m.ServeHTTP(res, req)

	expect(t, res.Code, 300)
	expect(t, res.Header().Get(ContentType), ContentJSON+"; charset=UTF-8")
	// NOTE(review): the leading whitespace inside this raw string must match
	// the indent that renderer.JSON passes to json.MarshalIndent; verify it.
	expect(t, res.Body.String(), `{
"one": "hello",
"two": "world"
}`)
}
62 |
63 | func Test_Render_Bad_HTML(t *testing.T) {
64 | m := martini.Classic()
65 | m.Use(Renderer(Options{
66 | Directory: "fixtures/basic",
67 | }))
68 |
69 | // routing
70 | m.Get("/foobar", func(r Render) {
71 | r.HTML(200, "nope", nil)
72 | })
73 |
74 | res := httptest.NewRecorder()
75 | req, _ := http.NewRequest("GET", "/foobar", nil)
76 |
77 | m.ServeHTTP(res, req)
78 |
79 | expect(t, res.Code, 500)
80 | expect(t, res.Body.String(), "html/template: \"nope\" is undefined\n")
81 | }
82 |
83 | func Test_Render_HTML(t *testing.T) {
84 | m := martini.Classic()
85 | m.Use(Renderer(Options{
86 | Directory: "fixtures/basic",
87 | }))
88 |
89 | // routing
90 | m.Get("/foobar", func(r Render) {
91 | r.HTML(200, "hello", "jeremy")
92 | })
93 |
94 | res := httptest.NewRecorder()
95 | req, _ := http.NewRequest("GET", "/foobar", nil)
96 |
97 | m.ServeHTTP(res, req)
98 |
99 | expect(t, res.Code, 200)
100 | expect(t, res.Header().Get(ContentType), ContentHTML+"; charset=UTF-8")
101 | expect(t, res.Body.String(), "Hello jeremy
\n")
102 | }
103 |
104 | func Test_Render_Extensions(t *testing.T) {
105 | m := martini.Classic()
106 | m.Use(Renderer(Options{
107 | Directory: "fixtures/basic",
108 | Extensions: []string{".tmpl", ".html"},
109 | }))
110 |
111 | // routing
112 | m.Get("/foobar", func(r Render) {
113 | r.HTML(200, "hypertext", nil)
114 | })
115 |
116 | res := httptest.NewRecorder()
117 | req, _ := http.NewRequest("GET", "/foobar", nil)
118 |
119 | m.ServeHTTP(res, req)
120 |
121 | expect(t, res.Code, 200)
122 | expect(t, res.Header().Get(ContentType), ContentHTML+"; charset=UTF-8")
123 | expect(t, res.Body.String(), "Hypertext!\n")
124 | }
125 |
126 | func Test_Render_Funcs(t *testing.T) {
127 |
128 | m := martini.Classic()
129 | m.Use(Renderer(Options{
130 | Directory: "fixtures/custom_funcs",
131 | Funcs: []template.FuncMap{
132 | {
133 | "myCustomFunc": func() string {
134 | return "My custom function"
135 | },
136 | },
137 | },
138 | }))
139 |
140 | // routing
141 | m.Get("/foobar", func(r Render) {
142 | r.HTML(200, "index", "jeremy")
143 | })
144 |
145 | res := httptest.NewRecorder()
146 | req, _ := http.NewRequest("GET", "/foobar", nil)
147 |
148 | m.ServeHTTP(res, req)
149 |
150 | expect(t, res.Body.String(), "My custom function\n")
151 | }
152 |
153 | func Test_Render_Layout(t *testing.T) {
154 | m := martini.Classic()
155 | m.Use(Renderer(Options{
156 | Directory: "fixtures/basic",
157 | Layout: "layout",
158 | }))
159 |
160 | // routing
161 | m.Get("/foobar", func(r Render) {
162 | r.HTML(200, "content", "jeremy")
163 | })
164 |
165 | res := httptest.NewRecorder()
166 | req, _ := http.NewRequest("GET", "/foobar", nil)
167 |
168 | m.ServeHTTP(res, req)
169 |
170 | expect(t, res.Body.String(), "head\njeremy
\n\nfoot\n")
171 | }
172 |
173 | func Test_Render_Nested_HTML(t *testing.T) {
174 | m := martini.Classic()
175 | m.Use(Renderer(Options{
176 | Directory: "fixtures/basic",
177 | }))
178 |
179 | // routing
180 | m.Get("/foobar", func(r Render) {
181 | r.HTML(200, "admin/index", "jeremy")
182 | })
183 |
184 | res := httptest.NewRecorder()
185 | req, _ := http.NewRequest("GET", "/foobar", nil)
186 |
187 | m.ServeHTTP(res, req)
188 |
189 | expect(t, res.Code, 200)
190 | expect(t, res.Header().Get(ContentType), ContentHTML+"; charset=UTF-8")
191 | expect(t, res.Body.String(), "Admin jeremy
\n")
192 | }
193 |
194 | func Test_Render_Delimiters(t *testing.T) {
195 | m := martini.Classic()
196 | m.Use(Renderer(Options{
197 | Delims: Delims{"{[{", "}]}"},
198 | Directory: "fixtures/basic",
199 | }))
200 |
201 | // routing
202 | m.Get("/foobar", func(r Render) {
203 | r.HTML(200, "delims", "jeremy")
204 | })
205 |
206 | res := httptest.NewRecorder()
207 | req, _ := http.NewRequest("GET", "/foobar", nil)
208 |
209 | m.ServeHTTP(res, req)
210 |
211 | expect(t, res.Code, 200)
212 | expect(t, res.Header().Get(ContentType), ContentHTML+"; charset=UTF-8")
213 | expect(t, res.Body.String(), "Hello jeremy
")
214 | }
215 |
216 | func Test_Render_Error404(t *testing.T) {
217 | res := httptest.NewRecorder()
218 | r := renderer{res, nil, nil, Options{}, ""}
219 | r.Error(404)
220 | expect(t, res.Code, 404)
221 | }
222 |
223 | func Test_Render_Error500(t *testing.T) {
224 | res := httptest.NewRecorder()
225 | r := renderer{res, nil, nil, Options{}, ""}
226 | r.Error(500)
227 | expect(t, res.Code, 500)
228 | }
229 |
230 | func Test_Render_Redirect_Default(t *testing.T) {
231 | url, _ := url.Parse("http://localhost/path/one")
232 | req := http.Request{
233 | Method: "GET",
234 | URL: url,
235 | }
236 | res := httptest.NewRecorder()
237 |
238 | r := renderer{res, &req, nil, Options{}, ""}
239 | r.Redirect("two")
240 |
241 | expect(t, res.Code, 302)
242 | expect(t, res.HeaderMap["Location"][0], "/path/two")
243 | }
244 |
245 | func Test_Render_Redirect_Code(t *testing.T) {
246 | url, _ := url.Parse("http://localhost/path/one")
247 | req := http.Request{
248 | Method: "GET",
249 | URL: url,
250 | }
251 | res := httptest.NewRecorder()
252 |
253 | r := renderer{res, &req, nil, Options{}, ""}
254 | r.Redirect("two", 307)
255 |
256 | expect(t, res.Code, 307)
257 | expect(t, res.HeaderMap["Location"][0], "/path/two")
258 | }
259 |
260 | func Test_Render_Charset_JSON(t *testing.T) {
261 | m := martini.Classic()
262 | m.Use(Renderer(Options{
263 | Charset: "foobar",
264 | }))
265 |
266 | // routing
267 | m.Get("/foobar", func(r Render) {
268 | r.JSON(300, Greeting{"hello", "world"})
269 | })
270 |
271 | res := httptest.NewRecorder()
272 | req, _ := http.NewRequest("GET", "/foobar", nil)
273 |
274 | m.ServeHTTP(res, req)
275 |
276 | expect(t, res.Code, 300)
277 | expect(t, res.Header().Get(ContentType), ContentJSON+"; charset=foobar")
278 | expect(t, res.Body.String(), `{"one":"hello","two":"world"}`)
279 | }
280 |
281 | func Test_Render_Default_Charset_HTML(t *testing.T) {
282 | m := martini.Classic()
283 | m.Use(Renderer(Options{
284 | Directory: "fixtures/basic",
285 | }))
286 |
287 | // routing
288 | m.Get("/foobar", func(r Render) {
289 | r.HTML(200, "hello", "jeremy")
290 | })
291 |
292 | res := httptest.NewRecorder()
293 | req, _ := http.NewRequest("GET", "/foobar", nil)
294 |
295 | m.ServeHTTP(res, req)
296 |
297 | expect(t, res.Code, 200)
298 | expect(t, res.Header().Get(ContentType), ContentHTML+"; charset=UTF-8")
299 | // ContentLength should be deferred to the ResponseWriter and not Render
300 | expect(t, res.Header().Get(ContentLength), "")
301 | expect(t, res.Body.String(), "Hello jeremy
\n")
302 | }
303 |
304 | func Test_Render_Override_Layout(t *testing.T) {
305 | m := martini.Classic()
306 | m.Use(Renderer(Options{
307 | Directory: "fixtures/basic",
308 | Layout: "layout",
309 | }))
310 |
311 | // routing
312 | m.Get("/foobar", func(r Render) {
313 | r.HTML(200, "content", "jeremy", HTMLOptions{
314 | Layout: "another_layout",
315 | })
316 | })
317 |
318 | res := httptest.NewRecorder()
319 | req, _ := http.NewRequest("GET", "/foobar", nil)
320 |
321 | m.ServeHTTP(res, req)
322 |
323 | expect(t, res.Code, 200)
324 | expect(t, res.Header().Get(ContentType), ContentHTML+"; charset=UTF-8")
325 | expect(t, res.Body.String(), "another head\njeremy
\n\nanother foot\n")
326 | }
327 |
328 | /* Test Helpers */
// expect reports a test error when a (the value obtained) differs from b
// (the value wanted). The operands are compared with != on interface values,
// so both must hold comparable dynamic types.
func expect(t *testing.T, a interface{}, b interface{}) {
	// Attribute the failure to the calling test line rather than this helper.
	t.Helper()
	if a != b {
		t.Errorf("Expected %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a))
	}
}
334 |
// refute is the inverse of expect: it reports a test error when a and b
// compare equal as interface values.
func refute(t *testing.T, a interface{}, b interface{}) {
	if a != b {
		return
	}
	typeOfA, typeOfB := reflect.TypeOf(a), reflect.TypeOf(b)
	t.Errorf("Did not expect %v (type %v) - Got %v (type %v)", b, typeOfB, a, typeOfA)
}
340 |
--------------------------------------------------------------------------------
/secure/README.md:
--------------------------------------------------------------------------------
1 | # secure
2 | Martini middleware that helps enable some quick security wins.
3 |
4 | [API Reference](http://godoc.org/github.com/codegangsta/martini-contrib/secure)
5 |
6 | ## Usage
7 |
8 | ~~~ go
9 | import (
10 | "github.com/codegangsta/martini"
11 | "github.com/codegangsta/martini-contrib/secure"
12 | )
13 |
14 | func main() {
15 | m := martini.Classic()
16 |
17 | martini.Env = martini.Prod // You have to set the environment to `production` for all of secure to work properly!
18 |
19 | m.Use(secure.Secure(secure.Options{
20 | AllowedHosts: []string{"example.com", "ssl.example.com"},
21 | SSLRedirect: true,
22 | SSLHost: "ssl.example.com",
23 | SSLProxyHeaders: map[string]string{"X-Forwarded-Proto": "https"},
24 | STSSeconds: 315360000,
25 | STSIncludeSubdomains: true,
26 | FrameDeny: true,
27 | ContentTypeNosniff: true,
28 | BrowserXssFilter: true,
29 | ContentSecurityPolicy: "default-src 'self'",
30 | }))
31 | m.Run()
32 | }
33 |
34 | ~~~
35 |
36 | Make sure to include the secure middleware as close to the top as possible. It's best to do the allowed hosts and SSL check before anything else.
37 |
38 | The above example will only allow requests with a host name of 'example.com', or 'ssl.example.com'. Also if the request is not https, it will be redirected to https with the host name of 'ssl.example.com'.
39 | After this it will add the following headers:
40 | ~~~
41 | Strict-Transport-Security: max-age=315360000; includeSubdomains
42 | X-Frame-Options: DENY
43 | X-Content-Type-Options: nosniff
44 | X-XSS-Protection: 1; mode=block
45 | Content-Security-Policy: default-src 'self'
46 | ~~~
47 |
48 | ### Set the `MARTINI_ENV` environment variable to `production` when deploying!
49 | If you don't, the SSLRedirect and STS header will not work. This lets you work in development/test mode without any annoying redirects to HTTPS (i.e., development can happen over http). If this is not the behavior you're expecting, see the `DisableProdCheck` option below.
50 |
51 | You can also disable the production check for testing like so:
52 | ~~~ go
53 | //...
54 | m.Use(secure.Secure(secure.Options{
55 | SSLRedirect: true,
56 | STSSeconds: 315360000,
57 | DisableProdCheck: martini.Env == martini.Test,
58 | }))
59 | //...
60 | ~~~
61 |
62 |
63 | ### Options
64 | `secure.Secure` comes with a variety of configuration options:
65 |
66 | ~~~ go
67 | // ...
68 | m.Use(secure.Secure(secure.Options{
69 | AllowedHosts: []string{"ssl.example.com"}, // AllowedHosts is a list of fully qualified domain names that are allowed. Default is empty list, which allows any and all host names.
70 | SSLRedirect: true, // If SSLRedirect is set to true, then only allow https requests. Default is false.
71 | SSLHost: "ssl.example.com", // SSLHost is the host name that is used to redirect http requests to https. Default is "", which indicates to use the same host.
72 | SSLProxyHeaders: map[string]string{"X-Forwarded-Proto": "https"}, // SSLProxyHeaders is set of header keys with associated values that would indicate a valid https request. Useful when using Nginx: `map[string]string{"X-Forwarded-Proto": "https"}`. Default is blank map.
73 | STSSeconds: 315360000, // STSSeconds is the max-age of the Strict-Transport-Security header. Default is 0, which would NOT include the header.
74 | STSIncludeSubdomains: true, // If STSIncludeSubdomains is set to true, the `includeSubdomains` will be appended to the Strict-Transport-Security header. Default is false.
75 | FrameDeny: true, // If FrameDeny is set to true, adds the X-Frame-Options header with the value of `DENY`. Default is false.
76 | CustomFrameOptionsValue: "SAMEORIGIN", // CustomFrameOptionsValue allows the X-Frame-Options header value to be set with a custom value. This overrides the FrameDeny option.
77 | ContentTypeNosniff: true, // If ContentTypeNosniff is true, adds the X-Content-Type-Options header with the value `nosniff`. Default is false.
78 | BrowserXssFilter: true, // If BrowserXssFilter is true, adds the X-XSS-Protection header with the value `1; mode=block`. Default is false.
79 | ContentSecurityPolicy: "default-src 'self'", // ContentSecurityPolicy allows the Content-Security-Policy header value to be set with a custom value. Default is "".
80 | DisableProdCheck: true, // This will ignore our production check, and will follow the SSLRedirect and STSSeconds/STSIncludeSubdomains options... even in development! This would likely only be used to mimic a production environment on your local development machine.
81 | }))
82 | // ...
83 | ~~~
84 |
85 | ### Nginx
86 | If you would like to add the above security rules directly to your nginx configuration, everything is below:
87 | ~~~
88 | # Allowed Hosts:
89 | if ($host !~* ^(example.com|ssl.example.com)$ ) {
90 | return 500;
91 | }
92 |
93 | # SSL Redirect:
94 | server {
95 | listen 80;
96 | server_name example.com ssl.example.com;
97 | return 301 https://ssl.example.com$request_uri;
98 | }
99 |
100 | # Headers to be added:
101 | add_header Strict-Transport-Security "max-age=315360000; includeSubdomains";
102 | add_header X-Frame-Options "DENY";
103 | add_header X-Content-Type-Options "nosniff";
104 | add_header X-XSS-Protection "1; mode=block";
105 | add_header Content-Security-Policy "default-src 'self'";
106 | ~~~
107 |
108 | ## Authors
109 | * [Cory Jacobsen](http://github.com/cojac)
110 |
--------------------------------------------------------------------------------
/secure/secure.go:
--------------------------------------------------------------------------------
1 | // Package secure is a middleware for Martini that helps enable some quick security wins.
2 | //
3 | // package main
4 | //
5 | // import (
6 | // "github.com/codegangsta/martini"
7 | // "github.com/codegangsta/martini-contrib/secure"
8 | // )
9 | //
10 | // func main() {
11 | // m := martini.Classic()
12 | //
13 | // m.Use(secure.Secure(secure.Options{
14 | // AllowedHosts: []string{"www.example.com", "sub.example.com"},
15 | // SSLRedirect: true,
16 | // }))
17 | //
18 | // m.Get("/", func() string {
19 | // return "Hello World"
20 | // })
21 | //
22 | // m.Run()
23 | // }
24 | package secure
25 |
26 | import (
27 | "fmt"
28 | "github.com/codegangsta/martini"
29 | "net/http"
30 | "strings"
31 | )
32 |
// Names of the response headers this middleware can add.
const (
	stsHeader           = "Strict-Transport-Security"
	frameOptionsHeader  = "X-Frame-Options"
	contentTypeHeader   = "X-Content-Type-Options"
	xssProtectionHeader = "X-XSS-Protection"
	cspHeader           = "Content-Security-Policy"
)

// Fixed header values and value fragments used with the names above.
const (
	// stsSubdomainString is appended to the max-age directive of the
	// Strict-Transport-Security value.
	stsSubdomainString = "; includeSubdomains"
	frameOptionsValue  = "DENY"
	contentTypeValue   = "nosniff"
	xssProtectionValue = "1; mode=block"
)
44 |
// Options is a struct for specifying configuration options for the secure.Secure middleware.
// The zero value enables nothing: no host check, no redirect and no extra headers.
type Options struct {
	// AllowedHosts is a list of fully qualified domain names that are allowed. Default is empty list, which allows any and all host names.
	// Entries are compared with the request's Host case-insensitively.
	AllowedHosts []string
	// If SSLRedirect is set to true, then only allow https requests. Default is false.
	// Non-https requests are answered with a 301 redirect to the https URL.
	SSLRedirect bool
	// SSLHost is the host name that is used to redirect http requests to https. Default is "", which indicates to use the same host.
	SSLHost string
	// SSLProxyHeaders is set of header keys with associated values that would indicate a valid https request. Useful when using Nginx: `map[string]string{"X-Forwarded-Proto": "https"}`. Default is blank map.
	// A single matching key/value pair is enough for the request to count as https.
	SSLProxyHeaders map[string]string
	// STSSeconds is the max-age of the Strict-Transport-Security header. Default is 0, which would NOT include the header.
	STSSeconds int64
	// If STSIncludeSubdomains is set to true, the `includeSubdomains` will be appended to the Strict-Transport-Security header. Default is false.
	STSIncludeSubdomains bool
	// If FrameDeny is set to true, adds the X-Frame-Options header with the value of `DENY`. Default is false.
	FrameDeny bool
	// CustomFrameOptionsValue allows the X-Frame-Options header value to be set with a custom value. This overrides the FrameDeny option.
	CustomFrameOptionsValue string
	// If ContentTypeNosniff is true, adds the X-Content-Type-Options header with the value `nosniff`. Default is false.
	ContentTypeNosniff bool
	// If BrowserXssFilter is true, adds the X-XSS-Protection header with the value `1; mode=block`. Default is false.
	BrowserXssFilter bool
	// ContentSecurityPolicy allows the Content-Security-Policy header value to be set with a custom value. Default is "".
	ContentSecurityPolicy string
	// When developing, the SSL and STS options can cause some unwanted effects. Usually testing happens on http, not https, so
	// SSLRedirect and the STS header are only applied when `martini.Env == martini.Prod`.
	// If you would like your development environment to mimic production with complete SSL redirects and STS headers, set this to true. Default is false.
	DisableProdCheck bool
}
73 |
74 | // Secure is a middleware that helps setup a few basic security features. A single secure.Options struct can be
75 | // provided to configure which features should be enabled, and the ability to override a few of the default values.
76 | func Secure(opt Options) martini.Handler {
77 | return func(res http.ResponseWriter, req *http.Request, c martini.Context) {
78 | // Allowed hosts check.
79 | applyAllowedHosts(opt, res, req)
80 |
81 | // SSL check.
82 | applySSL(opt, res, req)
83 |
84 | // Strict Transport Security header.
85 | applySTS(opt, res, req)
86 |
87 | // Frame Options header.
88 | applyFrameOptions(opt, res, req)
89 |
90 | // Content Type Options header.
91 | applyContentTypeOptions(opt, res, req)
92 |
93 | // XSS Protection header.
94 | applyXSS(opt, res, req)
95 |
96 | // Content Security Policy header.
97 | applyCSP(opt, res, req)
98 | }
99 | }
100 |
101 | func applyAllowedHosts(opt Options, res http.ResponseWriter, req *http.Request) {
102 | if len(opt.AllowedHosts) > 0 {
103 | isGoodHost := false
104 | for _, allowedHost := range opt.AllowedHosts {
105 | if strings.EqualFold(allowedHost, req.Host) {
106 | isGoodHost = true
107 | break
108 | }
109 | }
110 |
111 | if isGoodHost == false {
112 | http.Error(res, "Bad Host", http.StatusInternalServerError)
113 | }
114 | }
115 | }
116 |
117 | func applySSL(opt Options, res http.ResponseWriter, req *http.Request) {
118 | if opt.SSLRedirect && (martini.Env == martini.Prod || opt.DisableProdCheck == true) {
119 | isSSL := false
120 | if strings.EqualFold(req.URL.Scheme, "https") || req.TLS != nil {
121 | isSSL = true
122 | } else {
123 | for hKey, hVal := range opt.SSLProxyHeaders {
124 | if req.Header.Get(hKey) == hVal {
125 | isSSL = true
126 | break
127 | }
128 | }
129 | }
130 |
131 | if isSSL == false {
132 | url := req.URL
133 | url.Scheme = "https"
134 | url.Host = req.Host
135 |
136 | if opt.SSLHost != "" {
137 | url.Host = opt.SSLHost
138 | }
139 |
140 | http.Redirect(res, req, url.String(), http.StatusMovedPermanently)
141 | }
142 | }
143 | }
144 |
145 | func applySTS(opt Options, res http.ResponseWriter, req *http.Request) {
146 | if opt.STSSeconds != 0 && (martini.Env == martini.Prod || opt.DisableProdCheck == true) {
147 | stsSub := ""
148 | if opt.STSIncludeSubdomains {
149 | stsSub = stsSubdomainString
150 | }
151 |
152 | res.Header().Add(stsHeader, fmt.Sprintf("max-age=%d%s", opt.STSSeconds, stsSub))
153 | }
154 | }
155 |
156 | func applyFrameOptions(opt Options, res http.ResponseWriter, req *http.Request) {
157 | if opt.CustomFrameOptionsValue != "" {
158 | res.Header().Add(frameOptionsHeader, opt.CustomFrameOptionsValue)
159 | } else if opt.FrameDeny {
160 | res.Header().Add(frameOptionsHeader, frameOptionsValue)
161 | }
162 | }
163 |
164 | func applyContentTypeOptions(opt Options, res http.ResponseWriter, req *http.Request) {
165 | if opt.ContentTypeNosniff {
166 | res.Header().Add(contentTypeHeader, contentTypeValue)
167 | }
168 | }
169 |
170 | func applyXSS(opt Options, res http.ResponseWriter, req *http.Request) {
171 | if opt.BrowserXssFilter {
172 | res.Header().Add(xssProtectionHeader, xssProtectionValue)
173 | }
174 | }
175 |
176 | func applyCSP(opt Options, res http.ResponseWriter, req *http.Request) {
177 | if opt.ContentSecurityPolicy != "" {
178 | res.Header().Add(cspHeader, opt.ContentSecurityPolicy)
179 | }
180 | }
181 |
--------------------------------------------------------------------------------
/secure/secure_test.go:
--------------------------------------------------------------------------------
1 | package secure
2 |
3 | import (
4 | "github.com/codegangsta/martini"
5 | "net/http"
6 | "net/http/httptest"
7 | "reflect"
8 | "testing"
9 | )
10 |
11 | func Test_No_Config(t *testing.T) {
12 | m := martini.Classic()
13 | m.Use(Secure(Options{
14 | // nothing here to configure
15 | }))
16 |
17 | m.Get("/foo", func() string {
18 | return "bar"
19 | })
20 |
21 | res := httptest.NewRecorder()
22 | req, _ := http.NewRequest("GET", "/foo", nil)
23 |
24 | m.ServeHTTP(res, req)
25 |
26 | expect(t, res.Code, http.StatusOK)
27 | expect(t, res.Body.String(), `bar`)
28 | }
29 |
30 | func Test_No_AllowHosts(t *testing.T) {
31 | m := martini.Classic()
32 | m.Use(Secure(Options{
33 | AllowedHosts: []string{},
34 | }))
35 |
36 | m.Get("/foo", func() string {
37 | return "bar"
38 | })
39 |
40 | res := httptest.NewRecorder()
41 | req, _ := http.NewRequest("GET", "/foo", nil)
42 | req.Host = "www.example.com"
43 |
44 | m.ServeHTTP(res, req)
45 |
46 | expect(t, res.Code, http.StatusOK)
47 | expect(t, res.Body.String(), `bar`)
48 | }
49 |
50 | func Test_Good_Single_AllowHosts(t *testing.T) {
51 | m := martini.Classic()
52 | m.Use(Secure(Options{
53 | AllowedHosts: []string{"www.example.com"},
54 | }))
55 |
56 | m.Get("/foo", func() string {
57 | return "bar"
58 | })
59 |
60 | res := httptest.NewRecorder()
61 | req, _ := http.NewRequest("GET", "/foo", nil)
62 | req.Host = "www.example.com"
63 |
64 | m.ServeHTTP(res, req)
65 |
66 | expect(t, res.Code, http.StatusOK)
67 | expect(t, res.Body.String(), `bar`)
68 | }
69 |
70 | func Test_Bad_Single_AllowHosts(t *testing.T) {
71 | m := martini.Classic()
72 | m.Use(Secure(Options{
73 | AllowedHosts: []string{"sub.example.com"},
74 | }))
75 |
76 | m.Get("/foo", func() string {
77 | return "bar"
78 | })
79 |
80 | res := httptest.NewRecorder()
81 | req, _ := http.NewRequest("GET", "/foo", nil)
82 | req.Host = "www.example.com"
83 |
84 | m.ServeHTTP(res, req)
85 |
86 | expect(t, res.Code, http.StatusInternalServerError)
87 | }
88 |
89 | func Test_Good_Multiple_AllowHosts(t *testing.T) {
90 | m := martini.Classic()
91 | m.Use(Secure(Options{
92 | AllowedHosts: []string{"www.example.com", "sub.example.com"},
93 | }))
94 |
95 | m.Get("/foo", func() string {
96 | return "bar"
97 | })
98 |
99 | res := httptest.NewRecorder()
100 | req, _ := http.NewRequest("GET", "/foo", nil)
101 | req.Host = "sub.example.com"
102 |
103 | m.ServeHTTP(res, req)
104 |
105 | expect(t, res.Code, http.StatusOK)
106 | expect(t, res.Body.String(), `bar`)
107 | }
108 |
109 | func Test_Bad_Multiple_AllowHosts(t *testing.T) {
110 | m := martini.Classic()
111 | m.Use(Secure(Options{
112 | AllowedHosts: []string{"www.example.com", "sub.example.com"},
113 | }))
114 |
115 | m.Get("/foo", func() string {
116 | return "bar"
117 | })
118 |
119 | res := httptest.NewRecorder()
120 | req, _ := http.NewRequest("GET", "/foo", nil)
121 | req.Host = "www3.example.com"
122 |
123 | m.ServeHTTP(res, req)
124 |
125 | expect(t, res.Code, http.StatusInternalServerError)
126 | }
127 |
128 | func Test_SSL(t *testing.T) {
129 | m := martini.Classic()
130 | martini.Env = martini.Prod
131 | m.Use(Secure(Options{
132 | SSLRedirect: true,
133 | }))
134 |
135 | m.Get("/foo", func() string {
136 | return "bar"
137 | })
138 |
139 | res := httptest.NewRecorder()
140 | req, _ := http.NewRequest("GET", "/foo", nil)
141 | req.Host = "www.example.com"
142 | req.URL.Scheme = "https"
143 |
144 | m.ServeHTTP(res, req)
145 |
146 | expect(t, res.Code, http.StatusOK)
147 | }
148 |
149 | func Test_SSL_In_Dev_Mode(t *testing.T) {
150 | m := martini.Classic()
151 | martini.Env = martini.Dev
152 | m.Use(Secure(Options{
153 | SSLRedirect: true,
154 | }))
155 |
156 | m.Get("/foo", func() string {
157 | return "bar"
158 | })
159 |
160 | res := httptest.NewRecorder()
161 | req, _ := http.NewRequest("GET", "/foo", nil)
162 | req.Host = "www.example.com"
163 | req.URL.Scheme = "http"
164 |
165 | m.ServeHTTP(res, req)
166 |
167 | expect(t, res.Code, http.StatusOK)
168 | }
169 |
170 | func Test_SSL_In_Dev_Mode_But_Disable_Prod_Check(t *testing.T) {
171 | m := martini.Classic()
172 | martini.Env = martini.Dev
173 | m.Use(Secure(Options{
174 | SSLRedirect: true,
175 | DisableProdCheck: true,
176 | }))
177 |
178 | m.Get("/foo", func() string {
179 | return "bar"
180 | })
181 |
182 | res := httptest.NewRecorder()
183 | req, _ := http.NewRequest("GET", "/foo", nil)
184 | req.Host = "www.example.com"
185 | req.URL.Scheme = "http"
186 |
187 | m.ServeHTTP(res, req)
188 |
189 | expect(t, res.Code, http.StatusMovedPermanently)
190 | expect(t, res.Header().Get("Location"), "https://www.example.com/foo")
191 | }
192 |
193 | func Test_Basic_SSL(t *testing.T) {
194 | m := martini.Classic()
195 | martini.Env = martini.Prod
196 | m.Use(Secure(Options{
197 | SSLRedirect: true,
198 | }))
199 |
200 | m.Get("/foo", func() string {
201 | return "bar"
202 | })
203 |
204 | res := httptest.NewRecorder()
205 | req, _ := http.NewRequest("GET", "/foo", nil)
206 | req.Host = "www.example.com"
207 | req.URL.Scheme = "http"
208 |
209 | m.ServeHTTP(res, req)
210 |
211 | expect(t, res.Code, http.StatusMovedPermanently)
212 | expect(t, res.Header().Get("Location"), "https://www.example.com/foo")
213 | }
214 |
215 | func Test_Basic_SSL_With_Host(t *testing.T) {
216 | m := martini.Classic()
217 | martini.Env = martini.Prod
218 | m.Use(Secure(Options{
219 | SSLRedirect: true,
220 | SSLHost: "secure.example.com",
221 | }))
222 |
223 | m.Get("/foo", func() string {
224 | return "bar"
225 | })
226 |
227 | res := httptest.NewRecorder()
228 | req, _ := http.NewRequest("GET", "/foo", nil)
229 | req.Host = "www.example.com"
230 | req.URL.Scheme = "http"
231 |
232 | m.ServeHTTP(res, req)
233 |
234 | expect(t, res.Code, http.StatusMovedPermanently)
235 | expect(t, res.Header().Get("Location"), "https://secure.example.com/foo")
236 | }
237 |
238 | func Test_Bad_Proxy_SSL(t *testing.T) {
239 | m := martini.Classic()
240 | martini.Env = martini.Prod
241 | m.Use(Secure(Options{
242 | SSLRedirect: true,
243 | }))
244 |
245 | m.Get("/foo", func() string {
246 | return "bar"
247 | })
248 |
249 | res := httptest.NewRecorder()
250 | req, _ := http.NewRequest("GET", "/foo", nil)
251 | req.Host = "www.example.com"
252 | req.URL.Scheme = "http"
253 | req.Header.Add("X-Forwarded-Proto", "https")
254 |
255 | m.ServeHTTP(res, req)
256 |
257 | expect(t, res.Code, http.StatusMovedPermanently)
258 | expect(t, res.Header().Get("Location"), "https://www.example.com/foo")
259 | }
260 |
261 | func Test_Custom_Proxy_SSL(t *testing.T) {
262 | m := martini.Classic()
263 | martini.Env = martini.Prod
264 | m.Use(Secure(Options{
265 | SSLRedirect: true,
266 | SSLProxyHeaders: map[string]string{"X-Forwarded-Proto": "https"},
267 | }))
268 |
269 | m.Get("/foo", func() string {
270 | return "bar"
271 | })
272 |
273 | res := httptest.NewRecorder()
274 | req, _ := http.NewRequest("GET", "/foo", nil)
275 | req.Host = "www.example.com"
276 | req.URL.Scheme = "http"
277 | req.Header.Add("X-Forwarded-Proto", "https")
278 |
279 | m.ServeHTTP(res, req)
280 |
281 | expect(t, res.Code, http.StatusOK)
282 | }
283 |
284 | func Test_Custom_Proxy_SSL_In_Dev_Mode(t *testing.T) {
285 | m := martini.Classic()
286 | martini.Env = martini.Dev
287 | m.Use(Secure(Options{
288 | SSLRedirect: true,
289 | SSLProxyHeaders: map[string]string{"X-Forwarded-Proto": "https"},
290 | }))
291 |
292 | m.Get("/foo", func() string {
293 | return "bar"
294 | })
295 |
296 | res := httptest.NewRecorder()
297 | req, _ := http.NewRequest("GET", "/foo", nil)
298 | req.Host = "www.example.com"
299 | req.URL.Scheme = "http"
300 | req.Header.Add("X-Forwarded-Proto", "http")
301 |
302 | m.ServeHTTP(res, req)
303 |
304 | expect(t, res.Code, http.StatusOK)
305 | }
306 |
307 | func Test_Custom_Proxy_And_Host_SSL(t *testing.T) {
308 | m := martini.Classic()
309 | martini.Env = martini.Prod
310 | m.Use(Secure(Options{
311 | SSLRedirect: true,
312 | SSLProxyHeaders: map[string]string{"X-Forwarded-Proto": "https"},
313 | SSLHost: "secure.example.com",
314 | }))
315 |
316 | m.Get("/foo", func() string {
317 | return "bar"
318 | })
319 |
320 | res := httptest.NewRecorder()
321 | req, _ := http.NewRequest("GET", "/foo", nil)
322 | req.Host = "www.example.com"
323 | req.URL.Scheme = "http"
324 | req.Header.Add("X-Forwarded-Proto", "https")
325 |
326 | m.ServeHTTP(res, req)
327 |
328 | expect(t, res.Code, http.StatusOK)
329 | }
330 |
331 | func Test_Custom_Bad_Proxy_And_Host_SSL(t *testing.T) {
332 | m := martini.Classic()
333 | martini.Env = martini.Prod
334 | m.Use(Secure(Options{
335 | SSLRedirect: true,
336 | SSLProxyHeaders: map[string]string{"X-Forwarded-Proto": "superman"},
337 | SSLHost: "secure.example.com",
338 | }))
339 |
340 | m.Get("/foo", func() string {
341 | return "bar"
342 | })
343 |
344 | res := httptest.NewRecorder()
345 | req, _ := http.NewRequest("GET", "/foo", nil)
346 | req.Host = "www.example.com"
347 | req.URL.Scheme = "http"
348 | req.Header.Add("X-Forwarded-Proto", "https")
349 |
350 | m.ServeHTTP(res, req)
351 |
352 | expect(t, res.Code, http.StatusMovedPermanently)
353 | expect(t, res.Header().Get("Location"), "https://secure.example.com/foo")
354 | }
355 |
356 | func Test_STS_Header(t *testing.T) {
357 | m := martini.Classic()
358 | martini.Env = martini.Prod
359 | m.Use(Secure(Options{
360 | STSSeconds: 315360000,
361 | }))
362 |
363 | m.Get("/foo", func() string {
364 | return "bar"
365 | })
366 |
367 | res := httptest.NewRecorder()
368 | req, _ := http.NewRequest("GET", "/foo", nil)
369 |
370 | m.ServeHTTP(res, req)
371 |
372 | expect(t, res.Code, http.StatusOK)
373 | expect(t, res.Header().Get("Strict-Transport-Security"), "max-age=315360000")
374 | }
375 |
376 | func Test_STS_Header_In_Dev_Mode(t *testing.T) {
377 | m := martini.Classic()
378 | martini.Env = martini.Dev
379 | m.Use(Secure(Options{
380 | STSSeconds: 315360000,
381 | }))
382 |
383 | m.Get("/foo", func() string {
384 | return "bar"
385 | })
386 |
387 | res := httptest.NewRecorder()
388 | req, _ := http.NewRequest("GET", "/foo", nil)
389 |
390 | m.ServeHTTP(res, req)
391 |
392 | expect(t, res.Code, http.StatusOK)
393 | expect(t, res.Header().Get("Strict-Transport-Security"), "")
394 | }
395 |
396 | func Test_STS_Header_With_Subdomain(t *testing.T) {
397 | m := martini.Classic()
398 | martini.Env = martini.Prod
399 | m.Use(Secure(Options{
400 | STSSeconds: 315360000,
401 | STSIncludeSubdomains: true,
402 | }))
403 |
404 | m.Get("/foo", func() string {
405 | return "bar"
406 | })
407 |
408 | res := httptest.NewRecorder()
409 | req, _ := http.NewRequest("GET", "/foo", nil)
410 |
411 | m.ServeHTTP(res, req)
412 |
413 | expect(t, res.Code, http.StatusOK)
414 | expect(t, res.Header().Get("Strict-Transport-Security"), "max-age=315360000; includeSubdomains")
415 | }
416 |
417 | func Test_Frame_Deny(t *testing.T) {
418 | m := martini.Classic()
419 | m.Use(Secure(Options{
420 | FrameDeny: true,
421 | }))
422 |
423 | m.Get("/foo", func() string {
424 | return "bar"
425 | })
426 |
427 | res := httptest.NewRecorder()
428 | req, _ := http.NewRequest("GET", "/foo", nil)
429 |
430 | m.ServeHTTP(res, req)
431 |
432 | expect(t, res.Code, http.StatusOK)
433 | expect(t, res.Header().Get("X-Frame-Options"), "DENY")
434 | }
435 |
436 | func Test_Custom_Frame_Value(t *testing.T) {
437 | m := martini.Classic()
438 | m.Use(Secure(Options{
439 | CustomFrameOptionsValue: "SAMEORIGIN",
440 | }))
441 |
442 | m.Get("/foo", func() string {
443 | return "bar"
444 | })
445 |
446 | res := httptest.NewRecorder()
447 | req, _ := http.NewRequest("GET", "/foo", nil)
448 |
449 | m.ServeHTTP(res, req)
450 |
451 | expect(t, res.Code, http.StatusOK)
452 | expect(t, res.Header().Get("X-Frame-Options"), "SAMEORIGIN")
453 | }
454 |
455 | func Test_Custom_Frame_Value_With_Deny(t *testing.T) {
456 | m := martini.Classic()
457 | m.Use(Secure(Options{
458 | FrameDeny: true,
459 | CustomFrameOptionsValue: "SAMEORIGIN",
460 | }))
461 |
462 | m.Get("/foo", func() string {
463 | return "bar"
464 | })
465 |
466 | res := httptest.NewRecorder()
467 | req, _ := http.NewRequest("GET", "/foo", nil)
468 |
469 | m.ServeHTTP(res, req)
470 |
471 | expect(t, res.Code, http.StatusOK)
472 | expect(t, res.Header().Get("X-Frame-Options"), "SAMEORIGIN")
473 | }
474 |
475 | func Test_Content_Nosniff(t *testing.T) {
476 | m := martini.Classic()
477 | m.Use(Secure(Options{
478 | ContentTypeNosniff: true,
479 | }))
480 |
481 | m.Get("/foo", func() string {
482 | return "bar"
483 | })
484 |
485 | res := httptest.NewRecorder()
486 | req, _ := http.NewRequest("GET", "/foo", nil)
487 |
488 | m.ServeHTTP(res, req)
489 |
490 | expect(t, res.Code, http.StatusOK)
491 | expect(t, res.Header().Get("X-Content-Type-Options"), "nosniff")
492 | }
493 |
494 | func Test_XSS_Protection(t *testing.T) {
495 | m := martini.Classic()
496 | m.Use(Secure(Options{
497 | BrowserXssFilter: true,
498 | }))
499 |
500 | m.Get("/foo", func() string {
501 | return "bar"
502 | })
503 |
504 | res := httptest.NewRecorder()
505 | req, _ := http.NewRequest("GET", "/foo", nil)
506 |
507 | m.ServeHTTP(res, req)
508 |
509 | expect(t, res.Code, http.StatusOK)
510 | expect(t, res.Header().Get("X-XSS-Protection"), "1; mode=block")
511 | }
512 |
513 | func Test_CSP(t *testing.T) {
514 | m := martini.Classic()
515 | m.Use(Secure(Options{
516 | ContentSecurityPolicy: "default-src 'self'",
517 | }))
518 |
519 | m.Get("/foo", func() string {
520 | return "bar"
521 | })
522 |
523 | res := httptest.NewRecorder()
524 | req, _ := http.NewRequest("GET", "/foo", nil)
525 |
526 | m.ServeHTTP(res, req)
527 |
528 | expect(t, res.Code, http.StatusOK)
529 | expect(t, res.Header().Get("Content-Security-Policy"), "default-src 'self'")
530 | }
531 |
/* Test Helpers */

// expect records a test error on t when a and b differ under Go's ==
// comparison for interface values. The message prints the expected value b
// and the actual value a together with their dynamic types. The test keeps
// running after a mismatch because t.Errorf is used rather than t.Fatalf.
func expect(t *testing.T, a interface{}, b interface{}) {
	if a == b {
		return
	}
	t.Errorf("Expected %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a))
}
538 |
--------------------------------------------------------------------------------
/sessionauth/README.md:
--------------------------------------------------------------------------------
1 | # martini-login
2 |
3 | ## Purpose
4 |
5 | This package provides a simple way to make routes require a login, and to handle user logins in
6 | the session. It should work with any user model that you have in your application, so long as
7 | your user model implements the sessionauth.User interface.
8 |
9 | Please see the example program in the example/ directory.
10 |
11 | ## Program Flow:
12 |
13 | Every new request to Martini will generate an anonymous sessionauth.User using the function passed
14 | to SessionUser. This should default to a zero value user model, and must implement the sessionauth.User
15 | interface. If a user exists in the request session, this user will be injected into every request
16 | handler. Otherwise the zero value object will be injected.
17 |
18 | When a user visits any route with the **LoginRequired** handler, the sessionauth.User object will be
19 | examined with the IsAuthenticated() function. If the user is not authenticated, they will be
20 | redirected to a login page (/login by default; set sessionauth.RedirectUrl to change it).
21 |
22 | To log your users in, you should create a POST route, and verify the user/password that was sent
23 | from the client. Because there are many possible ways of doing this, you are responsible for
24 | validating the user. Once the user is validated, call sessionauth.AuthenticateSession() to mark the
25 | session as authenticated.
26 |
27 | Your user type should implement the sessionauth.User interface:
28 |
29 | ```go
30 | type User interface {
31 | // Return whether this user is logged in or not
32 | IsAuthenticated() bool
33 |
34 | // Set any flags or extra data that should be available
35 | Login()
36 |
37 | // Clear any sensitive data out of the user
38 | Logout()
39 |
40 | // Return the unique identifier of this user object
41 | UniqueId() interface{}
42 |
43 | // Populate this user object with values
44 | GetById(id interface{}) error
45 | }
46 | ```
47 |
48 | The SessionUser() Martini middleware will inject the sessionauth.User interface
49 | into your route handlers. This interface value must be converted to your
50 | own user type before you can use its fields.
51 |
52 | ```go
53 | func handler(user sessionauth.User, db *MyDB) {
54 | u := user.(*UserModel)
55 | db.Save(u)
56 | }
57 | ```
58 |
--------------------------------------------------------------------------------
/sessionauth/example/auth_example.go:
--------------------------------------------------------------------------------
1 | // Auth example is an example application which requires a login
2 | // to view a private link. The username is "testuser" and the password
3 | // is "password". This will require GORP and an SQLite3 database.
4 | package main
5 |
6 | import (
7 | "database/sql"
8 | "github.com/codegangsta/martini"
9 | "github.com/codegangsta/martini-contrib/binding"
10 | "github.com/codegangsta/martini-contrib/render"
11 | "github.com/codegangsta/martini-contrib/sessionauth"
12 | "github.com/codegangsta/martini-contrib/sessions"
13 | "github.com/coopernurse/gorp"
14 | _ "github.com/mattn/go-sqlite3"
15 | "log"
16 | "net/http"
17 | "os"
18 | )
19 |
20 | var dbmap *gorp.DbMap
21 |
22 | func initDb() *gorp.DbMap {
23 | // Delete our SQLite database if it already exists so we have a clean start
24 | _, err := os.Open("martini-sessionauth.bin")
25 | if err == nil {
26 | os.Remove("martini-sessionauth.bin")
27 | }
28 |
29 | db, err := sql.Open("sqlite3", "martini-sessionauth.bin")
30 | if err != nil {
31 | log.Fatalln("Fail to create database", err)
32 | }
33 |
34 | dbmap := &gorp.DbMap{Db: db, Dialect: gorp.SqliteDialect{}}
35 | dbmap.AddTableWithName(MyUserModel{}, "users").SetKeys(true, "Id")
36 | err = dbmap.CreateTablesIfNotExists()
37 | if err != nil {
38 | log.Fatalln("Could not build tables", err)
39 | }
40 |
41 | user := MyUserModel{1, "testuser", "password", false}
42 | err = dbmap.Insert(&user)
43 | if err != nil {
44 | log.Fatalln("Could not insert test user", err)
45 | }
46 | return dbmap
47 | }
48 |
49 | func main() {
50 | store := sessions.NewCookieStore([]byte("secret123"))
51 | dbmap = initDb()
52 |
53 | m := martini.Classic()
54 | m.Use(render.Renderer())
55 | m.Use(sessions.Sessions("my_session", store))
56 | m.Use(sessionauth.SessionUser(GenerateAnonymousUser))
57 | sessionauth.RedirectUrl = "/new-login"
58 | sessionauth.RedirectParam = "new-next"
59 |
60 | m.Get("/", func(r render.Render) {
61 | r.HTML(200, "index", nil)
62 | })
63 |
64 | m.Get("/new-login", func(r render.Render) {
65 | r.HTML(200, "login", nil)
66 | })
67 |
68 | m.Post("/new-login", binding.Bind(MyUserModel{}), func(session sessions.Session, postedUser MyUserModel, r render.Render, req *http.Request) {
69 | // You should verify credentials against a database or some other mechanism at this point.
70 | // Then you can authenticate this session.
71 | user := MyUserModel{}
72 | err := dbmap.SelectOne(&user, "SELECT * FROM users WHERE username = $1 and password = $2", postedUser.Username, postedUser.Password)
73 | if err != nil {
74 | r.Redirect(sessionauth.RedirectUrl)
75 | return
76 | } else {
77 | err := sessionauth.AuthenticateSession(session, &user)
78 | if err != nil {
79 | r.JSON(500, err)
80 | }
81 |
82 | params := req.URL.Query()
83 | redirect := params.Get(sessionauth.RedirectParam)
84 | r.Redirect(redirect)
85 | return
86 | }
87 | })
88 |
89 | m.Get("/private", sessionauth.LoginRequired, func(r render.Render, user sessionauth.User) {
90 | r.HTML(200, "private", user.(*MyUserModel))
91 | })
92 |
93 | m.Get("/logout", sessionauth.LoginRequired, func(session sessions.Session, user sessionauth.User, r render.Render) {
94 | sessionauth.Logout(session, user)
95 | r.Redirect("/")
96 | })
97 |
98 | m.Run()
99 | }
100 |
--------------------------------------------------------------------------------
/sessionauth/example/templates/index.tmpl:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | This is the Martini-Sessionauth example
5 | Try to visit this private link
6 |
7 |
8 |
--------------------------------------------------------------------------------
/sessionauth/example/templates/login.tmpl:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | You must login!
5 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/sessionauth/example/templates/private.tmpl:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | This is a private link!
5 | Hello {{ .Username }}
6 | Logout
7 |
8 |
9 |
--------------------------------------------------------------------------------
/sessionauth/example/user.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "github.com/codegangsta/martini-contrib/sessionauth"
5 | )
6 |
7 | // MyUserModel can be any struct that represents a user in my system
8 | type MyUserModel struct {
9 | Id int64 `form:"id" db:"id"`
10 | Username string `form:"name" db:"username"`
11 | Password string `form:"password" db:"password"`
12 | authenticated bool `form:"-" db:"-"`
13 | }
14 |
15 | // GetAnonymousUser should generate an anonymous user model
16 | // for all sessions. This should be an unauthenticated 0 value struct.
17 | func GenerateAnonymousUser() sessionauth.User {
18 | return &MyUserModel{}
19 | }
20 |
21 | // Login will preform any actions that are required to make a user model
22 | // officially authenticated.
23 | func (u *MyUserModel) Login() {
24 | // Update last login time
25 | // Add to logged-in user's list
26 | // etc ...
27 | u.authenticated = true
28 | }
29 |
30 | // Logout will preform any actions that are required to completely
31 | // logout a user.
32 | func (u *MyUserModel) Logout() {
33 | // Remove from logged-in user's list
34 | // etc ...
35 | u.authenticated = false
36 | }
37 |
38 | func (u *MyUserModel) IsAuthenticated() bool {
39 | return u.authenticated
40 | }
41 |
42 | func (u *MyUserModel) UniqueId() interface{} {
43 | return u.Id
44 | }
45 |
46 | // GetById will populate a user object from a database model with
47 | // a matching id.
48 | func (u *MyUserModel) GetById(id interface{}) error {
49 | err := dbmap.SelectOne(u, "SELECT * FROM users WHERE id = $1", id)
50 | if err != nil {
51 | return err
52 | }
53 |
54 | return nil
55 | }
56 |
--------------------------------------------------------------------------------
/sessionauth/login.go:
--------------------------------------------------------------------------------
1 | // Package sessionauth is a middleware for Martini that provides a simple way to track user login
2 | // sessions on a website. Please see https://github.com/codegangsta/martini-contrib/blob/master/sessionauth/README.md
3 | // for a more detailed description of the package.
4 | package sessionauth
5 |
import (
	"fmt"
	"log"
	"net/http"
	"net/url"

	"github.com/codegangsta/martini"
	"github.com/codegangsta/martini-contrib/render"
	"github.com/codegangsta/martini-contrib/sessions"
)
14 |
// Package-wide configuration. The values below are the defaults; they may be
// reassigned at any time, typically while Martini is being set up.
var (
	// RedirectUrl is the relative URL of the login route that
	// unauthenticated users are redirected to.
	RedirectUrl = "/login"

	// RedirectParam is the name of the query string parameter that carries
	// the page the user was trying to visit before being intercepted.
	RedirectParam = "next"

	// SessionKey is the session key under which the user's unique ID is stored.
	SessionKey = "AUTHUNIQUEID"
)
29 |
// User defines all the functions necessary to work with the user's authentication.
// The caller should implement these functions for whatever system of authentication
// they choose to use. SessionUser maps a value of this type into every request,
// and LoginRequired consults it to decide whether to redirect.
type User interface {
	// IsAuthenticated returns whether this user is logged in or not.
	IsAuthenticated() bool

	// Login sets any flags or extra data that should be available once the
	// user is authenticated.
	Login()

	// Logout clears any sensitive data out of the user.
	Logout()

	// UniqueId returns the unique identifier of this user object. This is
	// the value that is stored in the session under SessionKey.
	UniqueId() interface{}

	// GetById populates this user object with the values belonging to the
	// given identifier, as previously returned by UniqueId.
	GetById(id interface{}) error
}
49 |
50 | // SessionUser will try to read a unique user ID out of the session. Then it tries
51 | // to populate an anonymous user object from the database based on that ID. If this
52 | // is successful, the valid user is mapped into the context. Otherwise the anonymous
53 | // user is mapped into the contact.
54 | // The newUser() function should provide a valid 0value structure for the caller's
55 | // user type.
56 | func SessionUser(newUser func() User) martini.Handler {
57 | return func(s sessions.Session, c martini.Context, l *log.Logger) {
58 | userId := s.Get(SessionKey)
59 | user := newUser()
60 |
61 | if userId != nil {
62 | err := user.GetById(userId)
63 | if err != nil {
64 | l.Printf("Login Error: %v\n", err)
65 | } else {
66 | user.Login()
67 | }
68 | }
69 |
70 | c.MapTo(user, (*User)(nil))
71 | }
72 | }
73 |
74 | // AuthenticateSession will mark the session and user object as authenticated. Then
75 | // the Login() user function will be called. This function should be called after
76 | // you have validated a user.
77 | func AuthenticateSession(s sessions.Session, user User) error {
78 | user.Login()
79 | return UpdateUser(s, user)
80 | }
81 |
// Logout calls the Logout() user function and then removes the user's unique
// ID (stored under SessionKey) from the session. Other values in the session
// are left in place.
func Logout(s sessions.Session, user User) {
	user.Logout()
	s.Delete(SessionKey)
}
87 |
88 | // LoginRequired verifies that the current user is authenticated. Any routes that
89 | // require a login should have this handler placed in the flow. If the user is not
90 | // authenticated, they will be redirected to /login with the "next" get parameter
91 | // set to the attempted URL.
92 | func LoginRequired(r render.Render, user User, req *http.Request) {
93 | if user.IsAuthenticated() == false {
94 | path := fmt.Sprintf("%s?%s=%s", RedirectUrl, RedirectParam, req.URL.Path)
95 | r.Redirect(path, 302)
96 | }
97 | }
98 |
// UpdateUser stores the user's unique ID (the result of UniqueId) in the
// session under SessionKey. This is useful in case the user's identifier
// changes and the new value needs to persist across requests. The returned
// error is currently always nil.
func UpdateUser(s sessions.Session, user User) error {
	s.Set(SessionKey, user.UniqueId())
	return nil
}
105 |
--------------------------------------------------------------------------------
/sessionauth/login_test.go:
--------------------------------------------------------------------------------
1 | package sessionauth
2 |
3 | import (
4 | "github.com/codegangsta/martini"
5 | "github.com/codegangsta/martini-contrib/render"
6 | "github.com/codegangsta/martini-contrib/sessions"
7 | "net/http"
8 | "net/http/httptest"
9 | "testing"
10 | )
11 |
// TestUser is a minimal in-memory user used by the tests in this file.
type TestUser struct {
	Id            int    `json:"id"`
	Name          string `json:"name"`
	Age           int    `json:"age"`
	authenticated bool   `json:"-"`
}

// IsAuthenticated reports whether Login was the last of Login/Logout to be
// called on the user.
func (tu *TestUser) IsAuthenticated() bool {
	return tu.authenticated
}

// Login marks the user as authenticated.
func (tu *TestUser) Login() {
	tu.authenticated = true
}

// Logout marks the user as not authenticated.
func (tu *TestUser) Logout() {
	tu.authenticated = false
}

// UniqueId returns the user's Id.
func (tu *TestUser) UniqueId() interface{} {
	return tu.Id
}

// GetById fills the user with fixed test data and the given id. The id must
// hold an int; any other dynamic type makes the type assertion panic.
func (tu *TestUser) GetById(id interface{}) error {
	numeric := id.(int)
	tu.Id, tu.Name, tu.Age = numeric, "My Test User", 42

	return nil
}
42 |
43 | func NewUser() User {
44 | return &TestUser{}
45 | }
46 |
47 | func TestAuthenticateSession(t *testing.T) {
48 | store := sessions.NewCookieStore([]byte("secret123"))
49 | m := martini.Classic()
50 |
51 | m.Use(render.Renderer())
52 | m.Use(sessions.Sessions("my_session", store))
53 | m.Use(SessionUser(NewUser))
54 |
55 | m.Get("/setauth", func(session sessions.Session, user User) string {
56 | err := AuthenticateSession(session, user)
57 | if err != nil {
58 | t.Error(err)
59 | }
60 | return "OK"
61 | })
62 |
63 | m.Get("/private", LoginRequired, func(session sessions.Session, user User) string {
64 | return "OK"
65 | })
66 |
67 | m.Get("/logout", LoginRequired, func(session sessions.Session, user User) string {
68 | Logout(session, user)
69 | return "OK"
70 | })
71 |
72 | res := httptest.NewRecorder()
73 | req, _ := http.NewRequest("GET", "/private", nil)
74 | m.ServeHTTP(res, req)
75 | if res.Code != 302 {
76 | t.Errorf("Private response should be 302, was %d", res.Code)
77 | }
78 |
79 | res1 := httptest.NewRecorder()
80 | req1, _ := http.NewRequest("GET", "/setauth", nil)
81 | req1.Header.Set("Cookie", res.Header().Get("Set-Cookie"))
82 | m.ServeHTTP(res1, req1)
83 | if res1.Code != 200 {
84 | t.Errorf("Setauth response should be 200, was %d", res.Code)
85 | }
86 |
87 | res2 := httptest.NewRecorder()
88 | req2, _ := http.NewRequest("GET", "/private", nil)
89 | req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie"))
90 | m.ServeHTTP(res2, req2)
91 | if res2.Code != 200 {
92 | t.Errorf("Authenticated private response should be 200, was %d", res.Code)
93 | }
94 |
95 | res3 := httptest.NewRecorder()
96 | req3, _ := http.NewRequest("GET", "/logout", nil)
97 | req3.Header.Set("Cookie", res2.Header().Get("Set-Cookie"))
98 | m.ServeHTTP(res3, req3)
99 | if res3.Code != 302 {
100 | t.Errorf("Logout response should be 302, was %d", res.Code)
101 | }
102 |
103 | res4 := httptest.NewRecorder()
104 | req4, _ := http.NewRequest("GET", "/private", nil)
105 | req4.Header.Set("Cookie", res3.Header().Get("Set-Cookie"))
106 | m.ServeHTTP(res4, req4)
107 | if res4.Code != 302 {
108 | t.Errorf("Unauthenticated private response should be 302, was %d", res.Code)
109 | }
110 |
111 | }
112 |
--------------------------------------------------------------------------------
/sessions/README.md:
--------------------------------------------------------------------------------
1 | # sessions
2 | Martini middleware/handler for easy session management.
3 |
4 | [API Reference](http://godoc.org/github.com/codegangsta/martini-contrib/sessions)
5 |
6 | ## Usage
7 |
8 | ~~~ go
9 | package main
10 |
11 | import (
12 | "github.com/codegangsta/martini"
13 | "github.com/codegangsta/martini-contrib/sessions"
14 | )
15 |
16 | func main() {
17 | m := martini.Classic()
18 |
19 | store := sessions.NewCookieStore([]byte("secret123"))
20 | m.Use(sessions.Sessions("my_session", store))
21 |
22 | m.Get("/set", func(session sessions.Session) string {
23 | session.Set("hello", "world")
24 | return "OK"
25 | })
26 |
27 | m.Get("/get", func(session sessions.Session) string {
28 | v := session.Get("hello")
29 | if v == nil {
30 | return ""
31 | }
32 | return v.(string)
33 | })
34 |
35 | m.Run()
36 | }
37 |
38 | ~~~
39 |
40 | ## Authors
41 | * [Jeremy Saenz](http://github.com/codegangsta)
42 |
--------------------------------------------------------------------------------
/sessions/benchmarks_test.go:
--------------------------------------------------------------------------------
1 | package sessions
2 |
3 | import (
4 | "github.com/codegangsta/martini"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 | )
9 |
10 | func BenchmarkNoSessionsMiddleware(b *testing.B) {
11 | m := testMartini()
12 | m.Get("/foo", func() string {
13 | return "Foo"
14 | })
15 |
16 | recorder := httptest.NewRecorder()
17 | r, _ := http.NewRequest("GET", "/foo", nil)
18 |
19 | b.ResetTimer()
20 | for n := 0; n < b.N; n++ {
21 | m.ServeHTTP(recorder, r)
22 | }
23 | }
24 |
25 | func BenchmarkSessionsNoWrites(b *testing.B) {
26 | m := testMartini()
27 | store := NewCookieStore([]byte("secret123"))
28 | m.Use(Sessions("my_session", store))
29 | m.Get("/foo", func() string {
30 | return "Foo"
31 | })
32 |
33 | recorder := httptest.NewRecorder()
34 | r, _ := http.NewRequest("GET", "/foo", nil)
35 |
36 | b.ResetTimer()
37 | for n := 0; n < b.N; n++ {
38 | m.ServeHTTP(recorder, r)
39 | }
40 | }
41 |
42 | func BenchmarkSessionsWithWrite(b *testing.B) {
43 | m := testMartini()
44 | store := NewCookieStore([]byte("secret123"))
45 | m.Use(Sessions("my_session", store))
46 | m.Get("/foo", func(s Session) string {
47 | s.Set("foo", "bar")
48 | return "Foo"
49 | })
50 |
51 | recorder := httptest.NewRecorder()
52 | r, _ := http.NewRequest("GET", "/foo", nil)
53 |
54 | b.ResetTimer()
55 | for n := 0; n < b.N; n++ {
56 | m.ServeHTTP(recorder, r)
57 | }
58 | }
59 |
60 | func BenchmarkSessionsWithRead(b *testing.B) {
61 | m := testMartini()
62 | store := NewCookieStore([]byte("secret123"))
63 | m.Use(Sessions("my_session", store))
64 | m.Get("/foo", func(s Session) string {
65 | s.Get("foo")
66 | return "Foo"
67 | })
68 |
69 | recorder := httptest.NewRecorder()
70 | r, _ := http.NewRequest("GET", "/foo", nil)
71 |
72 | b.ResetTimer()
73 | for n := 0; n < b.N; n++ {
74 | m.ServeHTTP(recorder, r)
75 | }
76 | }
77 |
78 | func testMartini() *martini.ClassicMartini {
79 | m := martini.Classic()
80 | m.Handlers()
81 | return m
82 | }
83 |
--------------------------------------------------------------------------------
/sessions/cookie_store.go:
--------------------------------------------------------------------------------
1 | package sessions
2 |
3 | import (
4 | "github.com/gorilla/sessions"
5 | )
6 |
// CookieStore is an interface that represents a Cookie based storage
// for Sessions. Values of this type are created with NewCookieStore.
type CookieStore interface {
	// Store is an embedded interface so that CookieStore can be used
	// as a session store, for example as the store argument of Sessions.
	Store
	// Options sets the default options for each session stored in this
	// CookieStore.
	Options(Options)
}
17 |
18 | // NewCookieStore returns a new CookieStore.
19 | //
20 | // Keys are defined in pairs to allow key rotation, but the common case is to set a single
21 | // authentication key and optionally an encryption key.
22 | //
23 | // The first key in a pair is used for authentication and the second for encryption. The
24 | // encryption key can be set to nil or omitted in the last pair, but the authentication key
25 | // is required in all pairs.
26 | //
27 | // It is recommended to use an authentication key with 32 or 64 bytes. The encryption key,
28 | // if set, must be either 16, 24, or 32 bytes to select AES-128, AES-192, or AES-256 modes.
29 | func NewCookieStore(keyPairs ...[]byte) CookieStore {
30 | return &cookieStore{sessions.NewCookieStore(keyPairs...)}
31 | }
32 |
33 | type cookieStore struct {
34 | *sessions.CookieStore
35 | }
36 |
37 | func (c *cookieStore) Options(options Options) {
38 | c.CookieStore.Options = &sessions.Options{
39 | Path: options.Path,
40 | Domain: options.Domain,
41 | MaxAge: options.MaxAge,
42 | Secure: options.Secure,
43 | HttpOnly: options.HttpOnly,
44 | }
45 | }
46 |
--------------------------------------------------------------------------------
/sessions/sessions.go:
--------------------------------------------------------------------------------
1 | // Package sessions contains middleware for easy session management in Martini.
2 | //
3 | // package main
4 | //
5 | // import (
6 | // "github.com/codegangsta/martini"
7 | // "github.com/codegangsta/martini-contrib/sessions"
8 | // )
9 | //
10 | // func main() {
11 | // m := martini.Classic()
12 | //
13 | // store := sessions.NewCookieStore([]byte("secret123"))
14 | // m.Use(sessions.Sessions("my_session", store))
15 | //
16 | // m.Get("/", func(session sessions.Session) string {
17 | // session.Set("hello", "world")
18 | // })
19 | // }
20 | package sessions
21 |
22 | import (
23 | "github.com/codegangsta/martini"
24 | "github.com/gorilla/context"
25 | "github.com/gorilla/sessions"
26 | "log"
27 | "net/http"
28 | )
29 |
// errorFormat is the log format used by check when reporting a session error.
const errorFormat = "[sessions] ERROR! %s\n"
33 |
// Store is an interface for custom session stores. It embeds the
// gorilla/sessions Store interface and adds nothing to it, so any gorilla
// store can be passed to Sessions.
type Store interface {
	sessions.Store
}
38 |
// Options stores configuration for a session or session store.
//
// Fields are a subset of http.Cookie fields.
type Options struct {
	Path, Domain string
	// MaxAge=0 means no 'Max-Age' attribute specified.
	// MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'.
	// MaxAge>0 means Max-Age attribute present and given in seconds.
	MaxAge           int
	Secure, HttpOnly bool
}
52 |
// Session stores the values and optional configuration for a session. The
// Sessions middleware maps a value of this type into every request.
type Session interface {
	// Get returns the session value associated to the given key.
	Get(key interface{}) interface{}
	// Set sets the session value associated to the given key.
	Set(key interface{}, val interface{})
	// Delete removes the session value associated to the given key.
	Delete(key interface{})
	// AddFlash adds a flash message to the session.
	// A single variadic argument is accepted, and it is optional: it defines the flash key.
	// If not defined "_flash" is used by default.
	AddFlash(value interface{}, vars ...string)
	// Flashes returns a slice of flash messages from the session.
	// A single variadic argument is accepted, and it is optional: it defines the flash key.
	// If not defined "_flash" is used by default.
	Flashes(vars ...string) []interface{}
	// Options sets configuration for a session.
	Options(Options)
}
72 |
73 | // Sessions is a Middleware that maps a session.Session service into the Martini handler chain.
74 | // Sessions can use a number of storage solutions with the given store.
75 | func Sessions(name string, store Store) martini.Handler {
76 | return func(res http.ResponseWriter, r *http.Request, c martini.Context, l *log.Logger) {
77 | // Map to the Session interface
78 | s := &session{name, r, l, store, nil, false}
79 | c.MapTo(s, (*Session)(nil))
80 |
81 | // Use before hook to save out the session
82 | rw := res.(martini.ResponseWriter)
83 | rw.Before(func(martini.ResponseWriter) {
84 | if s.Written() {
85 | check(s.Session().Save(r, res), l)
86 | }
87 | })
88 |
89 | // clear the context, we don't need to use
90 | // gorilla context and we don't want memory leaks
91 | defer context.Clear(r)
92 |
93 | c.Next()
94 | }
95 | }
96 |
97 | type session struct {
98 | name string
99 | request *http.Request
100 | logger *log.Logger
101 | store Store
102 | session *sessions.Session
103 | written bool
104 | }
105 |
106 | func (s *session) Get(key interface{}) interface{} {
107 | return s.Session().Values[key]
108 | }
109 |
110 | func (s *session) Set(key interface{}, val interface{}) {
111 | s.Session().Values[key] = val
112 | s.written = true
113 | }
114 |
115 | func (s *session) Delete(key interface{}) {
116 | delete(s.Session().Values, key)
117 | s.written = true
118 | }
119 |
120 | func (s *session) AddFlash(value interface{}, vars ...string) {
121 | s.Session().AddFlash(value, vars...)
122 | s.written = true
123 | }
124 |
125 | func (s *session) Flashes(vars ...string) []interface{} {
126 | s.written = true
127 | return s.Session().Flashes(vars...)
128 | }
129 |
130 | func (s *session) Options(options Options) {
131 | s.Session().Options = &sessions.Options{
132 | Path: options.Path,
133 | Domain: options.Domain,
134 | MaxAge: options.MaxAge,
135 | Secure: options.Secure,
136 | HttpOnly: options.HttpOnly,
137 | }
138 | }
139 |
140 | func (s *session) Session() *sessions.Session {
141 | if s.session == nil {
142 | var err error
143 | s.session, err = s.store.Get(s.request, s.name)
144 | check(err, s.logger)
145 | }
146 |
147 | return s.session
148 | }
149 |
150 | func (s *session) Written() bool {
151 | return s.written
152 | }
153 |
154 | func check(err error, l *log.Logger) {
155 | if err != nil {
156 | l.Printf(errorFormat, err)
157 | }
158 | }
159 |
--------------------------------------------------------------------------------
/sessions/sessions_test.go:
--------------------------------------------------------------------------------
1 | package sessions
2 |
3 | import (
4 | "github.com/codegangsta/martini"
5 | "net/http"
6 | "net/http/httptest"
7 | "strings"
8 | "testing"
9 | )
10 |
11 | func Test_Sessions(t *testing.T) {
12 | m := martini.Classic()
13 |
14 | store := NewCookieStore([]byte("secret123"))
15 | m.Use(Sessions("my_session", store))
16 |
17 | m.Get("/testsession", func(session Session) string {
18 | session.Set("hello", "world")
19 | return "OK"
20 | })
21 |
22 | m.Get("/show", func(session Session) string {
23 | if session.Get("hello") != "world" {
24 | t.Error("Session writing failed")
25 | }
26 | return "OK"
27 | })
28 |
29 | res := httptest.NewRecorder()
30 | req, _ := http.NewRequest("GET", "/testsession", nil)
31 | m.ServeHTTP(res, req)
32 |
33 | res2 := httptest.NewRecorder()
34 | req2, _ := http.NewRequest("GET", "/show", nil)
35 | req2.Header.Set("Cookie", res.Header().Get("Set-Cookie"))
36 | m.ServeHTTP(res2, req2)
37 | }
38 |
39 | func Test_SessionsDeleteValue(t *testing.T) {
40 | m := martini.Classic()
41 |
42 | store := NewCookieStore([]byte("secret123"))
43 | m.Use(Sessions("my_session", store))
44 |
45 | m.Get("/testsession", func(session Session) string {
46 | session.Set("hello", "world")
47 | session.Delete("hello")
48 | return "OK"
49 | })
50 |
51 | m.Get("/show", func(session Session) string {
52 | if session.Get("hello") == "world" {
53 | t.Error("Session value deleting failed")
54 | }
55 | return "OK"
56 | })
57 |
58 | res := httptest.NewRecorder()
59 | req, _ := http.NewRequest("GET", "/testsession", nil)
60 | m.ServeHTTP(res, req)
61 |
62 | res2 := httptest.NewRecorder()
63 | req2, _ := http.NewRequest("GET", "/show", nil)
64 | req2.Header.Set("Cookie", res.Header().Get("Set-Cookie"))
65 | m.ServeHTTP(res2, req2)
66 | }
67 |
68 | func Test_Options(t *testing.T) {
69 | m := martini.Classic()
70 | store := NewCookieStore([]byte("secret123"))
71 | store.Options(Options{
72 | Domain: "martini.codegangsta.io",
73 | })
74 | m.Use(Sessions("my_session", store))
75 |
76 | m.Get("/", func(session Session) string {
77 | session.Set("hello", "world")
78 | session.Options(Options{
79 | Path: "/foo/bar/bat",
80 | })
81 | return "OK"
82 | })
83 |
84 | m.Get("/foo", func(session Session) string {
85 | session.Set("hello", "world")
86 | return "OK"
87 | })
88 |
89 | res := httptest.NewRecorder()
90 | req, _ := http.NewRequest("GET", "/", nil)
91 | m.ServeHTTP(res, req)
92 |
93 | res2 := httptest.NewRecorder()
94 | req2, _ := http.NewRequest("GET", "/foo", nil)
95 | m.ServeHTTP(res2, req2)
96 |
97 | s := strings.Split(res.Header().Get("Set-Cookie"), ";")
98 | if s[1] != " Path=/foo/bar/bat" {
99 | t.Error("Error writing path with options:", s[1])
100 | }
101 |
102 | s = strings.Split(res2.Header().Get("Set-Cookie"), ";")
103 | if s[1] != " Domain=martini.codegangsta.io" {
104 | t.Error("Error writing domain with options:", s[1])
105 | }
106 | }
107 |
108 | func Test_Flashes(t *testing.T) {
109 | m := martini.Classic()
110 |
111 | store := NewCookieStore([]byte("secret123"))
112 | m.Use(Sessions("my_session", store))
113 |
114 | m.Get("/set", func(session Session) string {
115 | session.AddFlash("hello world")
116 | return "OK"
117 | })
118 |
119 | m.Get("/show", func(session Session) string {
120 | l := len(session.Flashes())
121 | if l != 1 {
122 | t.Error("Flashes count does not equal 1. Equals ", l)
123 | }
124 | return "OK"
125 | })
126 |
127 | m.Get("/showagain", func(session Session) string {
128 | l := len(session.Flashes())
129 | if l != 0 {
130 | t.Error("flashes count is not 0 after reading. Equals ", l)
131 | }
132 | return "OK"
133 | })
134 |
135 | res := httptest.NewRecorder()
136 | req, _ := http.NewRequest("GET", "/set", nil)
137 | m.ServeHTTP(res, req)
138 |
139 | res2 := httptest.NewRecorder()
140 | req2, _ := http.NewRequest("GET", "/show", nil)
141 | req2.Header.Set("Cookie", res.Header().Get("Set-Cookie"))
142 | m.ServeHTTP(res2, req2)
143 |
144 | res3 := httptest.NewRecorder()
145 | req3, _ := http.NewRequest("GET", "/showagain", nil)
146 | req3.Header.Set("Cookie", res2.Header().Get("Set-Cookie"))
147 | m.ServeHTTP(res3, req3)
148 | }
149 |
--------------------------------------------------------------------------------
/strip/README.md:
--------------------------------------------------------------------------------
1 | # strip
2 |
3 | [API Reference](http://godoc.org/github.com/codegangsta/martini-contrib/strip)
4 |
5 | ## Description
6 | Package `strip` modifies the request URL before the request reaches the other
7 | handlers.
8 |
9 | Currently the main function in package strip is `strip.Prefix`, which provides
10 | the same functionality as `http.StripPrefix` and can be used at both the martini
11 | instance level and the request context level.
12 |
13 | With `strip.Prefix`, martini instances can be mounted on top of each other, and
14 | so can other web frameworks such as [web.go][].
15 |
16 | [web.go]:https://github.com/hoisie/web
17 |
18 | ## Usage
19 |
20 | ~~~ go
21 | package main
22 |
23 | import (
24 | "github.com/codegangsta/martini-contrib/strip"
25 | "github.com/codegangsta/martini"
26 | )
27 |
28 | func main() {
29 | m := martini.Classic()
30 |
31 | m2 := martini.Classic()
32 | m2.Get("/", func() string {
33 | return "Hello World from 2nd martini"
34 | })
35 |
36 | m2.Get("/foo", func() string {
37 | return "Hello foo"
38 | })
39 |
40 | m.Get("/", func() string {
41 | return "Hello World from 1st martini"
42 | })
43 | m.Get("/2ndMartini/.*", strip.Prefix("/2ndMartini"), m2.ServeHTTP)
44 |
45 | m.Run()
46 | }
47 | ~~~
48 |
49 | However, the example above only forwards requests of the same HTTP method, from `m.Get`
50 | to `m2.Get`. To forward every kind of request, such as `Post`, `Delete`,
51 | etc., to `m2`, martini would have to provide an `Any` method that matches any HTTP method
52 | for a given URL pattern.
53 |
54 | ## Authors
55 | * [Jeremy Saenz](http://github.com/codegangsta)
56 | * [Archs Sun](http://github.com/Archs)
57 | * [hoisie](http://github.com/hoisie)
58 |
--------------------------------------------------------------------------------
/strip/prefix.go:
--------------------------------------------------------------------------------
1 | // Package strip provides the same functionality as http.StripPrefix
2 | // and can be used at the martini instance level and the request context level.
3 | package strip
4 |
5 | import (
6 | "github.com/codegangsta/martini"
7 | "net/http"
8 | "strings"
9 | )
10 |
11 | // strip Prefix for every incoming http request
12 | func Prefix(prefix string) martini.Handler {
13 | return func(w http.ResponseWriter, r *http.Request) {
14 | if prefix == "" {
15 | return
16 | }
17 | if p := strings.TrimPrefix(r.URL.Path, prefix); len(p) < len(r.URL.Path) {
18 | r.URL.Path = p
19 | } else {
20 | http.NotFound(w, r)
21 | }
22 | }
23 | }
24 |
--------------------------------------------------------------------------------
/strip/prefix_test.go:
--------------------------------------------------------------------------------
1 | package strip
2 |
3 | import (
4 | "github.com/codegangsta/martini"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 | )
9 |
10 | func TestStripPrefix(t *testing.T) {
11 | w := httptest.NewRecorder()
12 | r, _ := http.NewRequest("GET", "/foo/bar", nil)
13 | Prefix("/foo").(func(http.ResponseWriter, *http.Request))(w, r)
14 | if r.URL.Path != "/bar" {
15 | t.Fatalf("Strip Prefix Failed")
16 | }
17 | }
18 |
19 | func TestInMartini(t *testing.T) {
20 | m := martini.New()
21 | m.Use(Prefix("/foo"))
22 | m.Use(func(w http.ResponseWriter, r *http.Request) {
23 | if r.URL.Path != "/bar" {
24 | t.Fatalf("Strip Prefix Failed")
25 | }
26 | })
27 | w := httptest.NewRecorder()
28 | r, _ := http.NewRequest("GET", "/foo/bar", nil)
29 | m.ServeHTTP(w, r)
30 | }
31 |
32 | func TestInRequestContext(t *testing.T) {
33 | m := martini.Classic()
34 | m.Get("/foo/bar", Prefix("/foo"), func(w http.ResponseWriter, r *http.Request) {
35 | if r.URL.Path != "/bar" {
36 | t.Fatalf("Strip Prefix Failed")
37 | }
38 | })
39 | w := httptest.NewRecorder()
40 | r, _ := http.NewRequest("GET", "/foo/bar", nil)
41 | m.ServeHTTP(w, r)
42 | }
43 |
--------------------------------------------------------------------------------
/web/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2013 Archs Sun
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/web/README.md:
--------------------------------------------------------------------------------
1 | # web.Context
2 | [hoisie][] [web.go][]'s Context for Martini.
3 |
4 | [hoisie]:https://github.com/hoisie
5 | [web.go]:https://github.com/hoisie/web
6 |
7 | [API Reference](http://godoc.org/github.com/codegangsta/martini-contrib/web)
8 |
9 | ## Description
10 | `web.Context` provides a [web.go][] compatible layer for reusing code written with
11 | hoisie's `web.go` framework. Here, compatible means that `web.Context` can be used the same
12 | way as in hoisie's `web.go`, but the other parts of `web.go` are not covered.
13 |
14 | ## Usage
15 |
16 | ~~~ go
17 | package main
18 |
19 | import (
20 | "github.com/codegangsta/martini"
21 | "github.com/codegangsta/martini-contrib/web"
22 | )
23 |
24 | func main() {
25 | m := martini.Classic()
26 | m.Use(web.ContextWithCookieSecret(""))
27 |
28 | m.Post("/hello", func(ctx *web.Context){
29 | ctx.WriteString("Hello World!")
30 | })
31 |
32 | m.Run()
33 | }
34 | ~~~
35 |
36 | ## Authors
37 | * [Jeremy Saenz](http://github.com/codegangsta)
38 | * [Archs Sun](http://github.com/Archs)
39 | * [hoisie][]
40 |
--------------------------------------------------------------------------------
/web/web.go:
--------------------------------------------------------------------------------
1 | // Package web provides a web.go compatible layer for reusing code written with
2 | // hoisie's `web.go` framework. Basically, this package adds web.Context to
3 | // martini's dependency injection system.
4 | package web
5 |
6 | import (
7 | "bytes"
8 | "crypto/hmac"
9 | "crypto/sha1"
10 | "encoding/base64"
11 | "fmt"
12 | "github.com/codegangsta/martini"
13 | "io/ioutil"
14 | "mime"
15 | "net/http"
16 | "strconv"
17 | "strings"
18 | "time"
19 | )
20 |
// Context carries the per-request state handed to web.go-style handlers: the
// incoming request, its form parameters, and the response writer. Because
// http.ResponseWriter is embedded, a *Context can itself be used as the
// response writer.
type Context struct {
	// Request is the incoming HTTP request.
	Request *http.Request
	// Params maps form field names to a single value per field.
	Params map[string]string
	// cookieSecret is the key used to sign and verify secure cookies.
	cookieSecret string
	http.ResponseWriter
}
31 |
32 | // if cookie secret is set to "", then SetSecureCookie would not work
33 | func ContextWithCookieSecret(secret string) martini.Handler {
34 | return func(w http.ResponseWriter, req *http.Request, mc martini.Context) {
35 | ctx := &Context{req, map[string]string{}, secret, w}
36 | //set some default headers
37 | tm := time.Now().UTC()
38 |
39 | //ignore errors from ParseForm because it's usually harmless.
40 | req.ParseForm()
41 | if len(req.Form) > 0 {
42 | for k, v := range req.Form {
43 | ctx.Params[k] = v[0]
44 | }
45 | }
46 | ctx.SetHeader("Date", webTime(tm), true)
47 | //Set the default content-type
48 | ctx.SetHeader("Content-Type", "text/html; charset=utf-8", true)
49 | // set martini context for web.Context
50 | mc.Map(ctx)
51 | }
52 | }
53 |
// webTime formats t as an HTTP date, for example
// "Sat, 30 Nov 2013 12:00:00 GMT".
//
// HTTP dates are always expressed in GMT, so t is converted to UTC before
// formatting. A time in another zone therefore yields the same instant in
// GMT rather than a local time with a zone abbreviation.
func webTime(t time.Time) string {
	return t.UTC().Format(http.TimeFormat)
}
62 |
63 | // WriteString writes string data into the response object.
64 | func (ctx *Context) WriteString(content string) {
65 | ctx.ResponseWriter.Write([]byte(content))
66 | }
67 |
68 | // Abort is a helper method that sends an HTTP header and an optional
69 | // body. It is useful for returning 4xx or 5xx errors.
70 | // Once it has been called, any return value from the handler will
71 | // not be written to the response.
72 | func (ctx *Context) Abort(status int, body string) {
73 | ctx.ResponseWriter.WriteHeader(status)
74 | ctx.ResponseWriter.Write([]byte(body))
75 | }
76 |
77 | // Redirect is a helper method for 3xx redirects.
78 | func (ctx *Context) Redirect(status int, url_ string) {
79 | ctx.ResponseWriter.Header().Set("Location", url_)
80 | ctx.ResponseWriter.WriteHeader(status)
81 | ctx.ResponseWriter.Write([]byte("Redirecting to: " + url_))
82 | }
83 |
84 | // Notmodified writes a 304 HTTP response
85 | func (ctx *Context) NotModified() {
86 | ctx.ResponseWriter.WriteHeader(304)
87 | }
88 |
89 | // NotFound writes a 404 HTTP response
90 | func (ctx *Context) NotFound(message string) {
91 | ctx.ResponseWriter.WriteHeader(404)
92 | ctx.ResponseWriter.Write([]byte(message))
93 | }
94 |
95 | //Unauthorized writes a 401 HTTP response
96 | func (ctx *Context) Unauthorized() {
97 | ctx.ResponseWriter.WriteHeader(401)
98 | }
99 |
100 | //Forbidden writes a 403 HTTP response
101 | func (ctx *Context) Forbidden() {
102 | ctx.ResponseWriter.WriteHeader(403)
103 | }
104 |
105 | // ContentType sets the Content-Type header for an HTTP response.
106 | // For example, ctx.ContentType("json") sets the content-type to "application/json"
107 | // If the supplied value contains a slash (/) it is set as the Content-Type
108 | // verbatim. The return value is the content type as it was
109 | // set, or an empty string if none was found.
110 | func (ctx *Context) ContentType(val string) string {
111 | var ctype string
112 | if strings.ContainsRune(val, '/') {
113 | ctype = val
114 | } else {
115 | if !strings.HasPrefix(val, ".") {
116 | val = "." + val
117 | }
118 | ctype = mime.TypeByExtension(val)
119 | }
120 | if ctype != "" {
121 | ctx.Header().Set("Content-Type", ctype)
122 | }
123 | return ctype
124 | }
125 |
126 | // SetHeader sets a response header. If `unique` is true, the current value
127 | // of that header will be overwritten . If false, it will be appended.
128 | func (ctx *Context) SetHeader(hdr string, val string, unique bool) {
129 | if unique {
130 | ctx.Header().Set(hdr, val)
131 | } else {
132 | ctx.Header().Add(hdr, val)
133 | }
134 | }
135 |
136 | // SetCookie adds a cookie header to the response.
137 | func (ctx *Context) SetCookie(cookie *http.Cookie) {
138 | ctx.SetHeader("Set-Cookie", cookie.String(), false)
139 | }
140 |
// getCookieSig returns the lowercase hexadecimal HMAC-SHA1, keyed with key,
// of val immediately followed by timestamp.
func getCookieSig(key string, val []byte, timestamp string) string {
	mac := hmac.New(sha1.New, []byte(key))
	for _, chunk := range [][]byte{val, []byte(timestamp)} {
		mac.Write(chunk)
	}
	return fmt.Sprintf("%02x", mac.Sum(nil))
}
150 |
// NewCookie builds an http.Cookie with the given name and value, for use with
// ctx.SetCookie. age is the lifetime in seconds counted from now. An age of
// zero produces a "permanent" cookie that expires at Unix time 2^31 - 1
// (roughly the year 2038).
func NewCookie(name string, value string, age int64) *http.Cookie {
	expiry := int64(2147483647) // 2^31 - 1 seconds
	if age != 0 {
		expiry = time.Now().Unix() + age
	}
	return &http.Cookie{
		Name:    name,
		Value:   value,
		Expires: time.Unix(expiry, 0),
	}
}
164 |
165 | func (ctx *Context) SetSecureCookie(name string, val string, age int64) {
166 | //base64 encode the val
167 | if len(ctx.cookieSecret) == 0 {
168 | return
169 | }
170 | var buf bytes.Buffer
171 | encoder := base64.NewEncoder(base64.StdEncoding, &buf)
172 | encoder.Write([]byte(val))
173 | encoder.Close()
174 | vs := buf.String()
175 | vb := buf.Bytes()
176 | timestamp := strconv.FormatInt(time.Now().Unix(), 10)
177 | sig := getCookieSig(ctx.cookieSecret, vb, timestamp)
178 | cookie := strings.Join([]string{vs, timestamp, sig}, "|")
179 | ctx.SetCookie(NewCookie(name, cookie, age))
180 | }
181 |
182 | func (ctx *Context) GetSecureCookie(name string) (string, bool) {
183 | for _, cookie := range ctx.Request.Cookies() {
184 | if cookie.Name != name {
185 | continue
186 | }
187 |
188 | parts := strings.SplitN(cookie.Value, "|", 3)
189 |
190 | val := parts[0]
191 | timestamp := parts[1]
192 | sig := parts[2]
193 |
194 | if getCookieSig(ctx.cookieSecret, []byte(val), timestamp) != sig {
195 | return "", false
196 | }
197 |
198 | ts, _ := strconv.ParseInt(timestamp, 0, 64)
199 |
200 | if time.Now().Unix()-31*86400 > ts {
201 | return "", false
202 | }
203 |
204 | buf := bytes.NewBufferString(val)
205 | encoder := base64.NewDecoder(base64.StdEncoding, buf)
206 |
207 | res, _ := ioutil.ReadAll(encoder)
208 | return string(res), true
209 | }
210 | return "", false
211 | }
212 |
--------------------------------------------------------------------------------
/web/web_test.go:
--------------------------------------------------------------------------------
1 | package web
2 |
3 | import (
4 | "github.com/codegangsta/martini"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 | )
9 |
10 | func TestWriteString(t *testing.T) {
11 | str := "Hello World!"
12 | m := martini.Classic()
13 | m.Use(ContextWithCookieSecret("secret"))
14 | m.Get("/", func(ctx *Context) {
15 | ctx.WriteString(str)
16 | })
17 | res := httptest.NewRecorder()
18 | req, _ := http.NewRequest("GET", "/", nil)
19 | m.ServeHTTP(res, req)
20 | if res.Body.String() != str {
21 | t.Errorf("WriteString Error")
22 | }
23 | }
24 |
25 | func TestAbort(t *testing.T) {
26 | str := "Hello World!"
27 | m := martini.Classic()
28 | m.Use(ContextWithCookieSecret("secret"))
29 | m.Get("/", func(ctx *Context) {
30 | ctx.Abort(401, str)
31 | })
32 | res := httptest.NewRecorder()
33 | req, _ := http.NewRequest("GET", "/", nil)
34 | m.ServeHTTP(res, req)
35 | if res.Code != 401 {
36 | t.Error("Response Code Error")
37 | }
38 | if res.Body.String() != str {
39 | t.Error("Abort Content Error")
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/wercker.yml:
--------------------------------------------------------------------------------
1 | box: wercker/golang@1.1.1
--------------------------------------------------------------------------------