├── LICENSE ├── README.md ├── examples └── net-http.go ├── jwt.go └── jwt_test.go /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015, 2016 Matthew Dillon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # jwt 2 | 3 | [![GoDoc](https://godoc.org/github.com/thermokarst/jwt?status.svg)](https://godoc.org/github.com/thermokarst/jwt) 4 | 5 | A simple (bring your own logic), opinionated Go net/http middleware for integrating 6 | JSON Web Tokens into your application: 7 | 8 | ```go 9 | package main 10 | 11 | import ( 12 | "encoding/json" 13 | "errors" 14 | "fmt" 15 | "net/http" 16 | "time" 17 | 18 | "github.com/thermokarst/jwt" 19 | ) 20 | 21 | func protectMe(w http.ResponseWriter, r *http.Request) { 22 | fmt.Fprintf(w, "secured") 23 | } 24 | 25 | func dontProtectMe(w http.ResponseWriter, r *http.Request) { 26 | fmt.Fprintf(w, "not secured") 27 | } 28 | 29 | func auth(email string, password string) error { 30 | // Hard-code a user 31 | if email != "test" || password != "test" { 32 | return errors.New("invalid credentials") 33 | } 34 | return nil 35 | } 36 | 37 | func setClaims(id string) (map[string]interface{}, error) { 38 | currentTime := time.Now() 39 | return map[string]interface{}{ 40 | "iat": currentTime.Unix(), 41 | "exp": currentTime.Add(time.Minute * 60 * 24).Unix(), 42 | }, nil 43 | } 44 | 45 | func verifyClaims(claims []byte, r *http.Request) error { 46 | currentTime := time.Now() 47 | var c struct { 48 | Iat int64 49 | Exp int64 50 | } 51 | _ = json.Unmarshal(claims, &c) 52 | if currentTime.After(time.Unix(c.Exp, 0)) { 53 | return errors.New("this token has expired") 54 | } 55 | return nil 56 | } 57 | 58 | func main() { 59 | config := &jwt.Config{ 60 | Secret: "password", 61 | Auth: auth, 62 | Claims: setClaims, 63 | } 64 | 65 | j, err := jwt.New(config) 66 | if err != nil { 67 | panic(err) 68 | } 69 | 70 | protect := http.HandlerFunc(protectMe) 71 | dontProtect := http.HandlerFunc(dontProtectMe) 72 | 73 | http.Handle("/authenticate", j.Authenticate()) 74 | 
http.Handle("/secure", j.Secure(protect, verifyClaims)) 75 | http.Handle("/insecure", dontProtect) 76 | http.ListenAndServe(":8080", nil) 77 | } 78 | ``` 79 | 80 | ```shell 81 | $ http GET :8080/secure 82 | 83 | HTTP/1.1 401 Unauthorized 84 | Content-Length: 23 85 | Content-Type: text/plain; charset=utf-8 86 | Date: Fri, 08 May 2015 06:43:35 GMT 87 | 88 | please provide a token 89 | 90 | $ http POST :8080/authenticate email=test password=test 91 | 92 | HTTP/1.1 200 OK 93 | Content-Length: 130 94 | Content-Type: text/plain; charset=utf-8 95 | Date: Fri, 08 May 2015 06:31:42 GMT 96 | 97 | eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJleHAiOjE0MzExNTI0ODAsImlhdCI6MTQzMTA2NjA4MH0=.UbJmLqOF4bTH/8+o6CrZfoi1Fu7zTDfCV0kwMQyzmos= 98 | 99 | $ http GET :8080/secure Authorization:"Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJleHAiOjE0MzExNTI0ODAsImlhdCI6MTQzMTA2NjA4MH0=.UbJmLqOF4bTH/8+o6CrZfoi1Fu7zTDfCV0kwMQyzmos=" 100 | 101 | HTTP/1.1 200 OK 102 | Content-Length: 7 103 | Content-Type: text/plain; charset=utf-8 104 | Date: Fri, 08 May 2015 06:38:30 GMT 105 | 106 | secured 107 | ``` 108 | 109 | # Installation 110 | 111 | $ go get github.com/thermokarst/jwt 112 | 113 | # Usage 114 | 115 | **This is a work in progress** 116 | 117 | Create a new instance of the middleware by passing in a configuration for your 118 | app. The config includes a shared secret (this middleware only builds HS256 119 | tokens), a function for authenticating a user, and a function for generating a 120 | user's claims. The idea here is to be dead-simple for someone to drop this into 121 | a project and hit the ground running.
122 | 123 | ```go 124 | config := &jwt.Config{ 125 | Secret: "password", 126 | Auth: authFunc, // func(string, string) error 127 | Claims: claimsFunc, // func(string) (map[string]interface{}, error) 128 | } 129 | j, err := jwt.New(config) 130 | ``` 131 | 132 | You can also customize the field names by specifying `IdentityField` and 133 | `VerifyField` in the `Config` struct, if you want the credentials to be 134 | something other than `"email"` and `"password"`. 135 | 136 | Once the middleware is instantiated, create a route at which users can 137 | generate a JWT. 138 | 139 | ```go 140 | http.Handle("/authenticate", j.Authenticate()) 141 | ``` 142 | 143 | The auth function takes two arguments (the identity and the authorization 144 | key), POSTed as a JSON-encoded body (form-encoded bodies are also accepted): 145 | 146 | {"email":"user@example.com","password":"mypassword"} 147 | 148 | These field names are the defaults; set `IdentityField` and `VerifyField` (see above) to change them. 149 | The claims are generated using the claims function provided in the 150 | configuration. This function is only run after the auth function verifies the 151 | user's identity; the submitted identity value (primary key id, UUID, 152 | email, whatever your identity field holds) is then passed as a string to the claims function. Your 153 | function should return a `map[string]interface{}` with the desired claimset. 154 | 155 | Routes are "secured" by wrapping their handlers with `Secure(http.Handler, jwt.VerifyClaimsFunc)`: 156 | 157 | 158 | ```go 159 | http.Handle("/secureendpoint", j.Secure(someHandler, verifyClaimsFunc)) 160 | ``` 161 | 162 | The claims verification function is called after the token has been parsed and 163 | validated: this is where you control how your application handles the claims 164 | contained within the JWT. 165 | 166 | # Motivation 167 | 168 | This work was prepared for a crypto/security class at the University of Alaska 169 | Fairbanks. I hope to use this in some of my projects, but please proceed with 170 | caution if you adopt this for your own work.
As well, the API is still quite 171 | unstable, so be prepared for handling any changes. 172 | 173 | # Tests 174 | 175 | $ go test 176 | 177 | # Contributors 178 | 179 | Matthew Ryan Dillon (matthewrdillon@gmail.com) 180 | 181 | -------------------------------------------------------------------------------- /examples/net-http.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | "time" 9 | 10 | "github.com/thermokarst/jwt" 11 | ) 12 | 13 | func protectMe(w http.ResponseWriter, r *http.Request) { 14 | fmt.Fprintf(w, "secured") 15 | } 16 | 17 | func dontProtectMe(w http.ResponseWriter, r *http.Request) { 18 | fmt.Fprintf(w, "not secured") 19 | } 20 | 21 | func auth(email string, password string) error { 22 | // Hard-code a user 23 | if email != "test" || password != "test" { 24 | return errors.New("invalid credentials") 25 | } 26 | return nil 27 | } 28 | 29 | func setClaims(id string) (map[string]interface{}, error) { 30 | currentTime := time.Now() 31 | return map[string]interface{}{ 32 | "iat": currentTime.Unix(), 33 | "exp": currentTime.Add(time.Minute * 60 * 24).Unix(), 34 | }, nil 35 | } 36 | 37 | func verifyClaims(claims []byte, r *http.Request) error { 38 | currentTime := time.Now() 39 | var c struct { 40 | Iat int64 41 | Exp int64 42 | } 43 | _ = json.Unmarshal(claims, &c) 44 | if currentTime.After(time.Unix(c.Exp, 0)) { 45 | return errors.New("this token has expired") 46 | } 47 | return nil 48 | } 49 | 50 | func main() { 51 | config := &jwt.Config{ 52 | Secret: "password", 53 | Auth: auth, 54 | Claims: setClaims, 55 | } 56 | 57 | j, err := jwt.New(config) 58 | if err != nil { 59 | panic(err) 60 | } 61 | 62 | protect := http.HandlerFunc(protectMe) 63 | dontProtect := http.HandlerFunc(dontProtectMe) 64 | 65 | http.Handle("/authenticate", j.GenerateToken()) 66 | http.Handle("/secure", j.Secure(protect, verifyClaims)) 67 | 
http.Handle("/insecure", dontProtect) 68 | http.ListenAndServe(":8080", nil) 69 | } 70 | -------------------------------------------------------------------------------- /jwt.go: -------------------------------------------------------------------------------- 1 | // Package jwt implements a simple, opinionated net/http-compatible middleware for 2 | // integrating JSON Web Tokens (JWT). 3 | package jwt 4 | 5 | import ( 6 | "crypto/hmac" 7 | "crypto/sha256" 8 | "encoding/base64" 9 | "encoding/json" 10 | "errors" 11 | "fmt" 12 | "log" 13 | "net/http" 14 | "strings" 15 | ) 16 | 17 | const ( 18 | typ = "JWT" 19 | alg = "HS256" 20 | ) 21 | 22 | // Errors introduced by this package. 23 | var ( 24 | ErrMissingConfig = errors.New("missing configuration") 25 | ErrMissingSecret = errors.New("please provide a shared secret") 26 | ErrMissingAuthFunc = errors.New("please provide an auth function") 27 | ErrMissingClaimsFunc = errors.New("please provide a claims function") 28 | ErrEncoding = errors.New("error encoding value") 29 | ErrDecoding = errors.New("error decoding value") 30 | ErrMissingToken = errors.New("please provide a token") 31 | ErrMalformedToken = errors.New("please provide a valid token") 32 | ErrInvalidSignature = errors.New("signature could not be verified") 33 | ErrParsingCredentials = errors.New("error parsing credentials") 34 | ErrInvalidMethod = errors.New("invalid request method") 35 | ) 36 | 37 | // AuthFunc is a type for delegating user authentication to the client-code. 38 | type AuthFunc func(string, string) error 39 | 40 | // ClaimsFunc is a type for delegating claims generation to the client-code. 41 | type ClaimsFunc func(string) (map[string]interface{}, error) 42 | 43 | // VerifyClaimsFunc is a type for processing and validating JWT claims on one 44 | // or more routes in the client-code. 45 | type VerifyClaimsFunc func([]byte, *http.Request) error 46 | 47 | // Config is a container for setting up the JWT middleware. 
48 | type Config struct { 49 | Secret string 50 | Auth AuthFunc 51 | Claims ClaimsFunc 52 | IdentityField string 53 | VerifyField string 54 | } 55 | 56 | // Middleware is where we store all the specifics related to the client's 57 | // JWT needs. 58 | type Middleware struct { 59 | secret string 60 | auth AuthFunc 61 | claims ClaimsFunc 62 | identityField string 63 | verifyField string 64 | } 65 | 66 | // New creates a new Middleware from a user-specified configuration. 67 | func New(c *Config) (*Middleware, error) { 68 | if c == nil { 69 | return nil, ErrMissingConfig 70 | } 71 | if c.Secret == "" { 72 | return nil, ErrMissingSecret 73 | } 74 | if c.Auth == nil { 75 | return nil, ErrMissingAuthFunc 76 | } 77 | if c.Claims == nil { 78 | return nil, ErrMissingClaimsFunc 79 | } 80 | if c.IdentityField == "" { 81 | c.IdentityField = "email" 82 | } 83 | if c.VerifyField == "" { 84 | c.VerifyField = "password" 85 | } 86 | m := &Middleware{ 87 | secret: c.Secret, 88 | auth: c.Auth, 89 | claims: c.Claims, 90 | identityField: c.IdentityField, 91 | verifyField: c.VerifyField, 92 | } 93 | return m, nil 94 | } 95 | 96 | // Secure wraps a client-specified http.Handler with a verification function, 97 | // as well as-built in parsing of the request's JWT. This allows each handler 98 | // to have it's own verification/validation protocol. 
99 | func (m *Middleware) Secure(h http.Handler, v VerifyClaimsFunc) http.Handler { 100 | secureHandler := func(w http.ResponseWriter, r *http.Request) *jwtError { 101 | var token string 102 | 103 | authHeader := r.Header.Get("Authorization") 104 | if authHeader == "" { 105 | token = r.FormValue("token") 106 | if token == "" { 107 | return &jwtError{status: http.StatusUnauthorized, err: ErrMissingToken} 108 | } 109 | } else { 110 | tokenParts := strings.Split(authHeader, " ") 111 | if len(tokenParts) != 2 { 112 | return &jwtError{status: http.StatusUnauthorized, err: ErrMalformedToken} 113 | } 114 | token = tokenParts[1] 115 | } 116 | 117 | if status, message, err := m.VerifyToken(token, v, r); err != nil { 118 | return &jwtError{ 119 | status: status, 120 | message: message, 121 | err: err, 122 | } 123 | } 124 | 125 | // If we make it this far, process the downstream handler 126 | h.ServeHTTP(w, r) 127 | return nil 128 | } 129 | return errorHandler(secureHandler) 130 | } 131 | 132 | // Authenticate returns a middleware that parsing an incoming request for a JWT, 133 | // calls the client-supplied auth function, and if successful, returns a JWT to 134 | // the requester. 
135 | func (m *Middleware) Authenticate() http.Handler { 136 | generateHandler := func(w http.ResponseWriter, r *http.Request) *jwtError { 137 | if r.Method != "POST" { 138 | return &jwtError{ 139 | status: http.StatusBadRequest, 140 | err: ErrInvalidMethod, 141 | message: "receiving request", 142 | } 143 | } 144 | 145 | b := make(map[string]string, 0) 146 | contentType := r.Header.Get("content-type") 147 | switch contentType { 148 | case "application/x-www-form-urlencoded", "application/x-www-form-urlencoded; charset=UTF-8": 149 | identity, verify := r.FormValue(m.identityField), r.FormValue(m.verifyField) 150 | if identity == "" || verify == "" { 151 | return &jwtError{ 152 | status: http.StatusInternalServerError, 153 | err: ErrParsingCredentials, 154 | message: "parsing authorization", 155 | } 156 | } 157 | b[m.identityField], b[m.verifyField] = identity, verify 158 | default: 159 | err := json.NewDecoder(r.Body).Decode(&b) 160 | if err != nil { 161 | return &jwtError{ 162 | status: http.StatusInternalServerError, 163 | err: ErrParsingCredentials, 164 | message: "parsing authorization", 165 | } 166 | } 167 | } 168 | 169 | // Check if required fields are in the body 170 | if _, ok := b[m.identityField]; !ok { 171 | return &jwtError{ 172 | status: http.StatusBadRequest, 173 | err: ErrParsingCredentials, 174 | message: "parsing credentials, missing identity field", 175 | } 176 | } 177 | if _, ok := b[m.verifyField]; !ok { 178 | return &jwtError{ 179 | status: http.StatusBadRequest, 180 | err: ErrParsingCredentials, 181 | message: "parsing credentials, missing verify field", 182 | } 183 | } 184 | err := m.auth(b[m.identityField], b[m.verifyField]) 185 | if err != nil { 186 | return &jwtError{ 187 | status: http.StatusInternalServerError, 188 | err: err, 189 | message: "performing authorization", 190 | } 191 | } 192 | response, err := m.CreateToken(b[m.identityField]) 193 | if err != nil { 194 | return &jwtError{ 195 | status: http.StatusInternalServerError, 196 | 
err: err, 197 | message: response, 198 | } 199 | } 200 | w.Write([]byte(response)) 201 | return nil 202 | } 203 | 204 | return errorHandler(generateHandler) 205 | } 206 | 207 | // CreateToken generates a token from a user's identity 208 | func (m *Middleware) CreateToken(identity string) (string, error) { 209 | // For now, the header will be static 210 | header, err := encode(fmt.Sprintf(`{"typ":%q,"alg":%q}`, typ, alg)) 211 | if err != nil { 212 | return "encoding header", ErrEncoding 213 | } 214 | 215 | // Generate claims for user 216 | claims, err := m.claims(identity) 217 | if err != nil { 218 | return "generating claims", err 219 | } 220 | 221 | claimsJSON, err := json.Marshal(claims) 222 | if err != nil { 223 | return "mashalling claims", ErrEncoding 224 | } 225 | 226 | claimsSet, err := encode(claimsJSON) 227 | if err != nil { 228 | return "encoding claims", ErrEncoding 229 | } 230 | 231 | toSig := strings.Join([]string{header, claimsSet}, ".") 232 | 233 | h := hmac.New(sha256.New, []byte(m.secret)) 234 | h.Write([]byte(toSig)) 235 | sig, err := encode(h.Sum(nil)) 236 | if err != nil { 237 | return "encoding signature", ErrEncoding 238 | } 239 | 240 | response := strings.Join([]string{toSig, sig}, ".") 241 | return response, nil 242 | } 243 | 244 | // VerifyToken verifies a token 245 | func (m *Middleware) VerifyToken(token string, v VerifyClaimsFunc, r *http.Request) (int, string, error) { 246 | tokenParts := strings.Split(token, ".") 247 | if len(tokenParts) != 3 { 248 | return http.StatusUnauthorized, "", ErrMalformedToken 249 | } 250 | 251 | // First, verify JOSE header 252 | header, err := decode(tokenParts[0]) 253 | if err != nil { 254 | return http.StatusInternalServerError, fmt.Sprintf("decoding header (%v)", tokenParts[0]), err 255 | } 256 | var t struct { 257 | Typ string 258 | Alg string 259 | } 260 | err = json.Unmarshal(header, &t) 261 | if err != nil { 262 | return http.StatusInternalServerError, fmt.Sprintf("unmarshalling header (%s)", 
header), ErrMalformedToken 263 | } 264 | 265 | // Then, verify signature 266 | mac := hmac.New(sha256.New, []byte(m.secret)) 267 | message := []byte(strings.Join([]string{tokenParts[0], tokenParts[1]}, ".")) 268 | mac.Write(message) 269 | expectedMac, err := encode(mac.Sum(nil)) 270 | if err != nil { 271 | return http.StatusInternalServerError, "", err 272 | } 273 | if !hmac.Equal([]byte(tokenParts[2]), []byte(expectedMac)) { 274 | return http.StatusUnauthorized, fmt.Sprintf("checking signature (%v)", tokenParts[2]), ErrInvalidSignature 275 | } 276 | 277 | // Finally, check claims 278 | claimSet, err := decode(tokenParts[1]) 279 | if err != nil { 280 | return http.StatusInternalServerError, "decoding claims", ErrDecoding 281 | } 282 | err = v(claimSet, r) 283 | if err != nil { 284 | return http.StatusUnauthorized, "handling claims callback", err 285 | } 286 | 287 | return 200, "", nil 288 | } 289 | 290 | type jwtError struct { 291 | status int 292 | message string 293 | err error 294 | } 295 | 296 | type errorHandler func(http.ResponseWriter, *http.Request) *jwtError 297 | 298 | func (e errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 299 | if err := e(w, r); err != nil { 300 | if err.message != "" { 301 | log.Printf("error (%v) while %s", err.err, err.message) 302 | } 303 | http.Error(w, err.err.Error(), err.status) 304 | } 305 | } 306 | 307 | func encode(s interface{}) (string, error) { 308 | var r []byte 309 | switch v := s.(type) { 310 | case string: 311 | r = []byte(v) 312 | case []byte: 313 | r = v 314 | default: 315 | return "", ErrEncoding 316 | } 317 | return base64.RawURLEncoding.EncodeToString(r), nil 318 | } 319 | 320 | func decode(s string) ([]byte, error) { 321 | return base64.RawURLEncoding.DecodeString(s) 322 | } 323 | -------------------------------------------------------------------------------- /jwt_test.go: -------------------------------------------------------------------------------- 1 | package jwt 2 | 3 | import ( 4 | 
"bytes" 5 | "crypto/hmac" 6 | "crypto/sha256" 7 | "encoding/base64" 8 | "encoding/json" 9 | "errors" 10 | "fmt" 11 | "io/ioutil" 12 | "net/http" 13 | "net/http/httptest" 14 | "strings" 15 | "testing" 16 | "time" 17 | ) 18 | 19 | var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 20 | w.Write([]byte("test")) 21 | }) 22 | 23 | var authFunc = func(email, password string) error { 24 | return nil 25 | } 26 | 27 | var claimsFunc = func(id string) (map[string]interface{}, error) { 28 | currentTime := time.Now() 29 | return map[string]interface{}{ 30 | "iat": currentTime.Unix(), 31 | "exp": currentTime.Add(time.Minute * 60 * 24).Unix(), 32 | }, nil 33 | } 34 | 35 | var verifyClaimsFunc = func(claims []byte, r *http.Request) error { 36 | currentTime := time.Now() 37 | var c struct { 38 | Exp int64 39 | Iat int64 40 | } 41 | err := json.Unmarshal(claims, &c) 42 | if err != nil { 43 | return err 44 | } 45 | if currentTime.After(time.Unix(c.Exp, 0)) { 46 | return errors.New("expired") 47 | } 48 | return nil 49 | } 50 | 51 | func newMiddlewareOrFatal(t *testing.T) *Middleware { 52 | config := &Config{ 53 | Secret: "password", 54 | Auth: authFunc, 55 | Claims: claimsFunc, 56 | } 57 | middleware, err := New(config) 58 | if err != nil { 59 | t.Fatalf("new middleware: %v", err) 60 | } 61 | return middleware 62 | } 63 | 64 | func newToken(t *testing.T) (string, *Middleware) { 65 | middleware := newMiddlewareOrFatal(t) 66 | authBody := map[string]interface{}{ 67 | "email": "user@example.com", 68 | "password": "password", 69 | } 70 | body, err := json.Marshal(authBody) 71 | if err != nil { 72 | t.Error(err) 73 | } 74 | 75 | ts := httptest.NewServer(middleware.Authenticate()) 76 | defer ts.Close() 77 | 78 | resp, err := http.Post(ts.URL, "application/json", bytes.NewReader(body)) 79 | respBody, err := ioutil.ReadAll(resp.Body) 80 | resp.Body.Close() 81 | if err != nil { 82 | t.Error(err) 83 | } 84 | return string(respBody), middleware 85 | } 86 | 87 | 
func TestNewJWTMiddleware(t *testing.T) { 88 | middleware := newMiddlewareOrFatal(t) 89 | if middleware.secret != "password" { 90 | t.Errorf("wanted password, got %v", middleware.secret) 91 | } 92 | err := middleware.auth("", "") 93 | if err != nil { 94 | t.Fatal(err) 95 | } 96 | claimsVal, err := middleware.claims("1") 97 | if err != nil { 98 | t.Fatal(err) 99 | } 100 | if _, ok := claimsVal["iat"]; !ok { 101 | t.Errorf("wanted a claims set, got %v", claimsVal) 102 | } 103 | if middleware.identityField != "email" { 104 | t.Errorf("wanted email, got %v", middleware.identityField) 105 | } 106 | if middleware.verifyField != "password" { 107 | t.Errorf("wanted password, got %v", middleware.verifyField) 108 | } 109 | } 110 | 111 | func TestNewJWTMiddlewareNoConfig(t *testing.T) { 112 | cases := map[*Config]error{ 113 | nil: ErrMissingConfig, 114 | &Config{}: ErrMissingSecret, 115 | &Config{ 116 | Auth: authFunc, 117 | Claims: claimsFunc, 118 | }: ErrMissingSecret, 119 | &Config{ 120 | Secret: "secret", 121 | Claims: claimsFunc, 122 | }: ErrMissingAuthFunc, 123 | &Config{ 124 | Auth: authFunc, 125 | Secret: "secret", 126 | }: ErrMissingClaimsFunc, 127 | } 128 | for config, jwtErr := range cases { 129 | _, err := New(config) 130 | if err != jwtErr { 131 | t.Errorf("wanted error: %v, got error: %v using config: %v", jwtErr, err, config) 132 | } 133 | } 134 | } 135 | func TestGenerateTokenHandler(t *testing.T) { 136 | token, m := newToken(t) 137 | j := strings.Split(token, ".") 138 | 139 | header := base64.RawURLEncoding.EncodeToString([]byte(`{"typ":"JWT","alg":"HS256"}`)) 140 | if j[0] != header { 141 | t.Errorf("wanted %v, got %v", header, j[0]) 142 | } 143 | 144 | claims, err := base64.RawURLEncoding.DecodeString(j[1]) 145 | var c struct { 146 | Exp int 147 | Iat int 148 | } 149 | err = json.Unmarshal(claims, &c) 150 | if err != nil { 151 | t.Error(err) 152 | } 153 | duration := time.Duration(c.Exp-c.Iat) * time.Second 154 | d := time.Minute * 60 * 24 155 | if duration 
!= d { 156 | t.Errorf("wanted %v, got %v", d, duration) 157 | } 158 | mac := hmac.New(sha256.New, []byte(m.secret)) 159 | message := []byte(strings.Join([]string{j[0], j[1]}, ".")) 160 | mac.Write(message) 161 | expectedMac := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) 162 | if !hmac.Equal([]byte(j[2]), []byte(expectedMac)) { 163 | t.Errorf("wanted %v, got %v", expectedMac, j[2]) 164 | } 165 | } 166 | 167 | func TestSecureHandlerNoToken(t *testing.T) { 168 | middleware := newMiddlewareOrFatal(t) 169 | resp := httptest.NewRecorder() 170 | req, _ := http.NewRequest("GET", "http://example.com", nil) 171 | middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req) 172 | body := strings.TrimSpace(resp.Body.String()) 173 | if body != ErrMissingToken.Error() { 174 | t.Errorf("wanted %q, got %q", ErrMissingToken.Error(), body) 175 | } 176 | } 177 | 178 | func TestSecureHandlerBadToken(t *testing.T) { 179 | middleware := newMiddlewareOrFatal(t) 180 | resp := httptest.NewRecorder() 181 | req, _ := http.NewRequest("GET", "http://example.com", nil) 182 | req.Header.Set("Authorization", "Bearer abcdefg") 183 | middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req) 184 | body := strings.TrimSpace(resp.Body.String()) 185 | if body != ErrMalformedToken.Error() { 186 | t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body) 187 | } 188 | 189 | resp = httptest.NewRecorder() 190 | req, _ = http.NewRequest("GET", "http://example.com", nil) 191 | req.Header.Set("Authorization", "Bearer abcd.abcd.abcd") 192 | middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req) 193 | body = strings.TrimSpace(resp.Body.String()) 194 | if body != ErrMalformedToken.Error() { 195 | t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body) 196 | } 197 | } 198 | 199 | func TestSecureHandlerBadSignature(t *testing.T) { 200 | token, middleware := newToken(t) 201 | parts := strings.Split(token, ".") 202 | token = strings.Join([]string{parts[0], 
parts[1], "abcd"}, ".") 203 | resp := httptest.NewRecorder() 204 | req, _ := http.NewRequest("GET", "http://example.com", nil) 205 | req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) 206 | middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req) 207 | body := strings.TrimSpace(resp.Body.String()) 208 | if body != ErrInvalidSignature.Error() { 209 | t.Errorf("wanted %s, got %s", ErrInvalidSignature.Error(), body) 210 | } 211 | } 212 | 213 | func TestSecureHandlerGoodToken(t *testing.T) { 214 | token, middleware := newToken(t) 215 | resp := httptest.NewRecorder() 216 | req, _ := http.NewRequest("GET", "http://example.com", nil) 217 | req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) 218 | middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req) 219 | body := strings.TrimSpace(resp.Body.String()) 220 | if body != "test" { 221 | t.Errorf("wanted %s, got %s", "test", body) 222 | } 223 | } 224 | 225 | func TestGenerateTokenHandlerNotPOST(t *testing.T) { 226 | middleware := newMiddlewareOrFatal(t) 227 | resp := httptest.NewRecorder() 228 | req, _ := http.NewRequest("PUT", "http://example.com", nil) 229 | middleware.Authenticate().ServeHTTP(resp, req) 230 | body := strings.TrimSpace(resp.Body.String()) 231 | if body != ErrInvalidMethod.Error() { 232 | t.Errorf("wanted %q, got %q", ErrInvalidMethod.Error(), body) 233 | } 234 | } 235 | 236 | func TestMalformedAuthorizationHeader(t *testing.T) { 237 | _, middleware := newToken(t) 238 | token := "hello!" 
239 | resp := httptest.NewRecorder() 240 | req, _ := http.NewRequest("GET", "http://example.com", nil) 241 | req.Header.Set("Authorization", token) // No "Bearer " portion of header 242 | middleware.Secure(testHandler, verifyClaimsFunc).ServeHTTP(resp, req) 243 | body := strings.TrimSpace(resp.Body.String()) 244 | if body != ErrMalformedToken.Error() { 245 | t.Errorf("wanted %q, got %q", ErrMalformedToken.Error(), body) 246 | } 247 | } 248 | --------------------------------------------------------------------------------