├── .travis.yml ├── LICENSE ├── README.md ├── basic.go ├── basic_test.go ├── example ├── authserver │ └── main.go └── resourceserver │ └── main.go ├── go.mod ├── go.sum ├── middleware.go ├── middleware_test.go ├── model.go ├── render.go ├── security.go ├── security_test.go ├── server.go └── server_test.go /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.17 5 | 6 | before_install: 7 | - go get -t -v ./... 8 | 9 | script: 10 | - go test -coverprofile=coverage.txt -covermode=atomic 11 | 12 | after_success: 13 | - bash <(curl -s https://codecov.io/bash) 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 go-chi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # oauth middleware 2 | OAuth 2.0 Authorization Server & Authorization Middleware for [go-chi](https://github.com/go-chi/chi) 3 | 4 | This library was ported to go-chi from https://github.com/maxzerbini/oauth by [jeffreydwalter](https://github.com/jeffreydwalter/oauth). 5 | 6 | This library offers an OAuth 2.0 Authorization Server based on go-chi and an Authorization Middleware usable in Resource Servers developed with go-chi. 7 | 8 | 9 | ## Build status 10 | [![Build Status](https://app.travis-ci.com/go-chi/oauth.svg?branch=master)](https://app.travis-ci.com/github/go-chi/oauth) 11 | 12 | ## Authorization Server 13 | The Authorization Server is implemented by the struct _BearerServer_ that manages two grant types of authorizations (password and client_credentials). 14 | This Authorization Server is made to provide an authorization token usable for consuming resources API. 15 | 16 | ### Password grant type 17 | _BearerServer_ supports the password grant type, allowing the token generation for username / password credentials. 18 | 19 | ### Client Credentials grant type 20 | _BearerServer_ supports the client_credentials grant type, allowing the token generation for client_id / client_secret credentials. 21 | 22 | ### Authorization Code and Implicit grant type 23 | These grant types are currently partially supported by implementing the AuthorizationCodeVerifier interface. The method ValidateCode is called during phase two of the authorization_code grant type evaluation. 24 | 25 | ### Refresh token grant type 26 | When the authorization token expires, the client can obtain a new token by calling the authorization server with the refresh_token grant type. 
27 | 28 | ## Authorization Middleware 29 | The go-chi middleware _BearerAuthentication_ intercepts the resource server calls and authorizes only resource requests containing a valid bearer token. 30 | 31 | ## Token Formatter 32 | The Authorization Server encrypts the token using the Token Formatter and the Authorization Middleware decrypts the token using the same Token Formatter. 33 | This library contains a default implementation of the formatter interface called _SHA256RC4TokenSecureFormatter_ based on the algorithms SHA256 and RC4. 34 | Programmers can develop their own Token Formatter by implementing the interface _TokenSecureFormatter_, and this is strongly recommended before publishing the API in a production environment. 35 | 36 | ## Credentials Verifier 37 | The interface _CredentialsVerifier_ defines the hooks called during the token generation process. 38 | The methods are called in this order: 39 | - _ValidateUser() or ValidateClient()_ called first for credentials verification 40 | - _AddClaims()_ used to add information to the token that will be encrypted 41 | - _StoreTokenID()_ called after the token generation but before the response, programmers can use this method for storing the generated IDs 42 | - _AddProperties()_ used to add clear (unencrypted) information to the response 43 | 44 | There is another method in the _CredentialsVerifier_ interface that is involved during the refresh token process. 
45 | In this case the methods are called in this order: 46 | - _ValidateTokenID()_ called first for TokenID verification, the method receives the TokenID related to the token associated to the refresh token 47 | - _AddClaims()_ used to add information to the token that will be encrypted 48 | - _StoreTokenID()_ called after the token regeneration but before the response, programmers can use this method for storing the generated IDs 49 | - _AddProperties()_ used to add clear (unencrypted) information to the response 50 | 51 | ## Authorization Server usage example 52 | This snippet shows how to create an authorization server 53 | ```Go 54 | func main() { 55 | r := chi.NewRouter() 56 | r.Use(middleware.Logger) 57 | r.Use(middleware.Recoverer) 58 | 59 | s := oauth.NewBearerServer( 60 | "mySecretKey-10101", 61 | time.Second*120, 62 | &TestUserVerifier{}, 63 | nil) 64 | 65 | r.Post("/token", s.UserCredentials) 66 | r.Post("/auth", s.ClientCredentials) 67 | http.ListenAndServe(":8080", r) 68 | } 69 | ``` 70 | See [/example/authserver/main.go](https://github.com/go-chi/oauth/blob/master/example/authserver/main.go) for the full example. 71 | 72 | ## Authorization Middleware usage example 73 | This snippet shows how to use the middleware 74 | ```Go 75 | r.Route("/", func(r chi.Router) { 76 | // use the Bearer Authentication middleware 77 | r.Use(oauth.Authorize("mySecretKey-10101", nil)) 78 | 79 | r.Get("/customers", GetCustomers) 80 | r.Get("/customers/{id}/orders", GetOrders) 81 | }) 82 | ``` 83 | See [/example/resourceserver/main.go](https://github.com/go-chi/oauth/blob/master/example/resourceserver/main.go) for the full example. 84 | 85 | Note that the authorization server and the authorization middleware are both using the same token formatter and the same secret key for encryption/decryption. 
// GetBasicAuthentication extracts the username and password carried by a
// "Basic" Authorization header.
//
// It returns empty strings and a nil error when the header is absent, uses a
// different scheme, or decodes to a value without a "user:password" shape
// (no colon, or an empty user part). An error is returned only when the
// header uses the Basic scheme but its payload is not valid base64.
func GetBasicAuthentication(r *http.Request) (username, password string, err error) {
	// The scheme name is matched case-insensitively, including the separating space.
	const scheme = "basic "

	header := r.Header.Get("Authorization")
	// Length is checked first: slicing a header shorter than the scheme
	// prefix would panic.
	if len(header) < len(scheme) || !strings.EqualFold(header[:len(scheme)], scheme) {
		return "", "", nil
	}

	value, err := base64.StdEncoding.DecodeString(header[len(scheme):])
	if err != nil {
		return "", "", err
	}

	// Split on the first colon only; the password itself may contain colons.
	strValue := string(value)
	if ind := strings.Index(strValue, ":"); ind > 0 {
		return strValue[:ind], strValue[ind+1:], nil
	}
	return "", "", nil
}

// CheckBasicAuthentication compares the credentials found in the request's
// Basic Authorization header with the expected username and password.
//
// A request that carries no Basic credentials is accepted (nil error), so the
// header stays optional. When credentials are present, BOTH the username and
// the password must match; otherwise an "invalid credentials" error is
// returned. Decoding errors from GetBasicAuthentication are passed through.
func CheckBasicAuthentication(username, password string, r *http.Request) error {
	u, p, err := GetBasicAuthentication(r)
	if err != nil {
		return err
	}
	if u == "" && p == "" {
		// No credentials supplied in the header: nothing to check.
		return nil
	}
	if u != username || p != password {
		return errors.New("invalid credentials")
	}
	return nil
}
"+base64.StdEncoding.EncodeToString([]byte("admin:password123456"))) 12 | 13 | username, password, err := GetBasicAuthentication(req) 14 | if err != nil { 15 | t.Fatalf("Error %s", err.Error()) 16 | } else { 17 | if username != "admin" { 18 | t.Fatalf("Wrong Username = %s", username) 19 | } 20 | if password != "password123456" { 21 | t.Fatalf("Wrong Username = %s", password) 22 | } 23 | } 24 | } 25 | 26 | func TestVoidBasicAuthentication(t *testing.T) { 27 | req, _ := http.NewRequest("GET", "/token", nil) 28 | 29 | username, password, err := GetBasicAuthentication(req) 30 | if err != nil { 31 | t.Fatalf("Error %s", err.Error()) 32 | } else { 33 | if username != "" { 34 | t.Fatalf("Wrong Username = %s", username) 35 | } 36 | if password != "" { 37 | t.Fatalf("Wrong Username = %s", password) 38 | } 39 | } 40 | 41 | } 42 | 43 | func TestCheckBasicAuthentication(t *testing.T) { 44 | req, _ := http.NewRequest("GET", "/token", nil) 45 | req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password123456"))) 46 | 47 | err := CheckBasicAuthentication("admin", "password123456", req) 48 | if err != nil { 49 | t.Fatalf("Error %s", err.Error()) 50 | } else { 51 | t.Log("Credentials are OK") 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /example/authserver/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "time" 7 | 8 | "github.com/go-chi/chi/v5" 9 | "github.com/go-chi/chi/v5/middleware" 10 | "github.com/go-chi/cors" 11 | "github.com/go-chi/oauth" 12 | ) 13 | 14 | /* 15 | Authorization Server Example 16 | 17 | Generate Token using username & password 18 | 19 | POST http://localhost:3000/token 20 | User-Agent: Fiddler 21 | Host: localhost:3000 22 | Content-Length: 50 23 | Content-Type: application/x-www-form-urlencoded 24 | 25 | grant_type=password&username=user01&password=12345 26 | 27 
| Generate Token using clientID & secret 28 | 29 | POST http://localhost:3000/auth 30 | User-Agent: Fiddler 31 | Host: localhost:3000 32 | Content-Length: 66 33 | Content-Type: application/x-www-form-urlencoded 34 | 35 | grant_type=client_credentials&client_id=abcdef&client_secret=12345 36 | 37 | RefreshTokenGrant Token 38 | 39 | POST http://localhost:3000/token 40 | User-Agent: Fiddler 41 | Host: localhost:3000 42 | Content-Length: 50 43 | Content-Type: application/x-www-form-urlencoded 44 | 45 | grant_type=refresh_token&refresh_token={the refresh_token obtained in the previous response} 46 | */ 47 | func main() { 48 | r := chi.NewRouter() 49 | r.Use(middleware.Logger) 50 | r.Use(middleware.Recoverer) 51 | r.Use(cors.Handler(cors.Options{ 52 | AllowedOrigins: []string{"*"}, 53 | AllowedMethods: []string{"GET", "PUT", "POST", "DELETE", "HEAD", "OPTION"}, 54 | AllowedHeaders: []string{"User-Agent", "Content-Type", "Accept", "Accept-Encoding", "Accept-Language", "Cache-Control", "Connection", "DNT", "Host", "Origin", "Pragma", "Referer"}, 55 | ExposedHeaders: []string{"Link"}, 56 | AllowCredentials: true, 57 | MaxAge: 300, // Maximum value not ignored by any of major browsers 58 | })) 59 | registerAPI(r) 60 | _ = http.ListenAndServe(":8080", r) 61 | } 62 | 63 | func registerAPI(r *chi.Mux) { 64 | s := oauth.NewBearerServer( 65 | "mySecretKey-10101", 66 | time.Second*120, 67 | &TestUserVerifier{}, 68 | nil) 69 | r.Post("/token", s.UserCredentials) 70 | r.Post("/auth", s.ClientCredentials) 71 | } 72 | 73 | // TestUserVerifier provides user credentials verifier for testing. 
74 | type TestUserVerifier struct { 75 | } 76 | 77 | // ValidateUser validates username and password returning an error if the user credentials are wrong 78 | func (*TestUserVerifier) ValidateUser(username, password, scope string, r *http.Request) error { 79 | if username == "user01" && password == "12345" { 80 | return nil 81 | } 82 | 83 | return errors.New("wrong user") 84 | } 85 | 86 | // ValidateClient validates clientID and secret returning an error if the client credentials are wrong 87 | func (*TestUserVerifier) ValidateClient(clientID, clientSecret, scope string, r *http.Request) error { 88 | if clientID == "abcdef" && clientSecret == "12345" { 89 | return nil 90 | } 91 | 92 | return errors.New("wrong client") 93 | } 94 | 95 | // ValidateCode validates token ID 96 | func (*TestUserVerifier) ValidateCode(clientID, clientSecret, code, redirectURI string, r *http.Request) (string, error) { 97 | return "", nil 98 | } 99 | 100 | // AddClaims provides additional claims to the token 101 | func (*TestUserVerifier) AddClaims(tokenType oauth.TokenType, credential, tokenID, scope string, r *http.Request) (map[string]string, error) { 102 | claims := make(map[string]string) 103 | claims["customer_id"] = "1001" 104 | claims["customer_data"] = `{"order_date":"2016-12-14","order_id":"9999"}` 105 | return claims, nil 106 | } 107 | 108 | // AddProperties provides additional information to the token response 109 | func (*TestUserVerifier) AddProperties(tokenType oauth.TokenType, credential, tokenID, scope string, r *http.Request) (map[string]string, error) { 110 | props := make(map[string]string) 111 | props["customer_name"] = "Gopher" 112 | return props, nil 113 | } 114 | 115 | // ValidateTokenID validates token ID 116 | func (*TestUserVerifier) ValidateTokenID(tokenType oauth.TokenType, credential, tokenID, refreshTokenID string) error { 117 | return nil 118 | } 119 | 120 | // StoreTokenID saves the token id generated for the user 121 | func (*TestUserVerifier) 
StoreTokenID(tokenType oauth.TokenType, credential, tokenID, refreshTokenID string) error { 122 | return nil 123 | } 124 | 125 | -------------------------------------------------------------------------------- /example/resourceserver/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "net/http" 7 | 8 | "github.com/go-chi/chi/v5" 9 | "github.com/go-chi/chi/v5/middleware" 10 | "github.com/go-chi/cors" 11 | 12 | "github.com/go-chi/oauth" 13 | ) 14 | 15 | /* 16 | Resource Server Example 17 | 18 | Get Customers 19 | 20 | GET http://localhost:3200/customers 21 | User-Agent: Fiddler 22 | Host: localhost:3200 23 | Content-Length: 0 24 | Content-Type: application/json 25 | Authorization: Bearer {access_token} 26 | 27 | Get Orders 28 | 29 | GET http://localhost:3200/customers/12345/orders 30 | User-Agent: Fiddler 31 | Host: localhost:3200 32 | Content-Length: 0 33 | Content-Type: application/json 34 | Authorization: Bearer {access_token} 35 | 36 | {access_token} is produced by the Authorization Server response (see example /test/authserver). 
37 | 38 | */ 39 | func main() { 40 | r := chi.NewRouter() 41 | r.Use(middleware.Logger) 42 | r.Use(middleware.Recoverer) 43 | r.Use(cors.Handler(cors.Options{ 44 | AllowedOrigins: []string{"*"}, 45 | AllowedMethods: []string{"GET", "PUT", "POST", "DELETE", "HEAD", "OPTION"}, 46 | AllowedHeaders: []string{"User-Agent", "Content-Type", "Accept", "Accept-Encoding", "Accept-Language", "Cache-Control", "Connection", "DNT", "Host", "Origin", "Pragma", "Referer"}, 47 | ExposedHeaders: []string{"Link"}, 48 | AllowCredentials: true, 49 | MaxAge: 300, // Maximum value not ignored by any of major browsers 50 | })) 51 | registerAPI(r) 52 | _ = http.ListenAndServe(":8081", r) 53 | } 54 | 55 | func registerAPI(r *chi.Mux) { 56 | r.Route("/", func(r chi.Router) { 57 | // use the Bearer Authentication middleware 58 | r.Use(oauth.Authorize("mySecretKey-10101", nil)) 59 | r.Get("/customers", GetCustomers) 60 | r.Get("/customers/{id}/orders", GetOrders) 61 | }) 62 | } 63 | 64 | func renderJSON(w http.ResponseWriter, v interface{}, statusCode int) { 65 | buf := &bytes.Buffer{} 66 | enc := json.NewEncoder(buf) 67 | enc.SetEscapeHTML(true) 68 | if err := enc.Encode(v); err != nil { 69 | http.Error(w, err.Error(), http.StatusInternalServerError) 70 | return 71 | } 72 | 73 | w.Header().Set("Content-Type", "application/json; charset=utf-8") 74 | w.WriteHeader(statusCode) 75 | _, _ = w.Write(buf.Bytes()) 76 | } 77 | 78 | func GetCustomers(w http.ResponseWriter, _ *http.Request) { 79 | renderJSON(w, `{ 80 | "Status": "verified", 81 | "Customer": "test001", 82 | "Customer_name": "Max", 83 | "Customer_email": "test@test.com", 84 | }`, http.StatusOK) 85 | } 86 | 87 | func GetOrders(w http.ResponseWriter, _ *http.Request) { 88 | renderJSON(w, `{ 89 | "status": "sent", 90 | "customer": "test001", 91 | "order_id": "100234", 92 | "total_order_items": "199", 93 | }`, http.StatusOK) 94 | } 95 | -------------------------------------------------------------------------------- /go.mod: 
-------------------------------------------------------------------------------- 1 | module github.com/go-chi/oauth 2 | 3 | go 1.17 4 | 5 | require ( 6 | github.com/go-chi/chi/v5 v5.0.4 7 | github.com/go-chi/cors v1.2.0 8 | github.com/gofrs/uuid v4.0.0+incompatible 9 | ) 10 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/go-chi/chi/v5 v5.0.4 h1:5e494iHzsYBiyXQAHHuI4tyJS9M3V84OuX3ufIIGHFo= 2 | github.com/go-chi/chi/v5 v5.0.4/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= 3 | github.com/go-chi/cors v1.2.0 h1:tV1g1XENQ8ku4Bq3K9ub2AtgG+p16SmzeMSGTwrOKdE= 4 | github.com/go-chi/cors v1.2.0/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= 5 | github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= 6 | github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= 7 | -------------------------------------------------------------------------------- /middleware.go: -------------------------------------------------------------------------------- 1 | package oauth 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net/http" 7 | "strings" 8 | "time" 9 | ) 10 | 11 | type contextKey string 12 | 13 | const ( 14 | CredentialContext contextKey = "oauth.credential" 15 | ClaimsContext contextKey = "oauth.claims" 16 | ScopeContext contextKey = "oauth.scope" 17 | TokenTypeContext contextKey = "oauth.tokentype" 18 | AccessTokenContext contextKey = "oauth.accesstoken" 19 | ) 20 | 21 | // BearerAuthentication middleware for go-chi 22 | type BearerAuthentication struct { 23 | secretKey string 24 | provider *TokenProvider 25 | } 26 | 27 | // NewBearerAuthentication create a BearerAuthentication middleware 28 | func NewBearerAuthentication(secretKey string, formatter TokenSecureFormatter) *BearerAuthentication { 29 | ba := &BearerAuthentication{secretKey: secretKey} 30 | if 
formatter == nil { 31 | formatter = NewSHA256RC4TokenSecurityProvider([]byte(secretKey)) 32 | } 33 | ba.provider = NewTokenProvider(formatter) 34 | return ba 35 | } 36 | 37 | // Authorize is the OAuth 2.0 middleware for go-chi resource server. 38 | // Authorize creates a BearerAuthentication middleware and return the Authorize method. 39 | func Authorize(secretKey string, formatter TokenSecureFormatter) func(next http.Handler) http.Handler { 40 | return NewBearerAuthentication(secretKey, formatter).Authorize 41 | } 42 | 43 | // Authorize verifies the bearer token authorizing or not the request. 44 | // Token is retrieved from the Authorization HTTP header that respects the format 45 | // Authorization: Bearer {access_token} 46 | func (ba *BearerAuthentication) Authorize(next http.Handler) http.Handler { 47 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 48 | auth := r.Header.Get("Authorization") 49 | token, err := ba.checkAuthorizationHeader(auth) 50 | if err != nil { 51 | renderJSON(w, "Not authorized: "+err.Error(), http.StatusUnauthorized) 52 | return 53 | } 54 | 55 | ctx := r.Context() 56 | ctx = context.WithValue(ctx, CredentialContext, token.Credential) 57 | ctx = context.WithValue(ctx, ClaimsContext, token.Claims) 58 | ctx = context.WithValue(ctx, ScopeContext, token.Scope) 59 | ctx = context.WithValue(ctx, TokenTypeContext, token.TokenType) 60 | ctx = context.WithValue(ctx, AccessTokenContext, auth[7:]) 61 | next.ServeHTTP(w, r.WithContext(ctx)) 62 | }) 63 | } 64 | 65 | // Check header and token. 
66 | func (ba *BearerAuthentication) checkAuthorizationHeader(auth string) (t *Token, err error) { 67 | if len(auth) < 7 { 68 | return nil, errors.New("Invalid bearer authorization header") 69 | } 70 | authType := strings.ToLower(auth[:6]) 71 | if authType != "bearer" { 72 | return nil, errors.New("Invalid bearer authorization header") 73 | } 74 | token, err := ba.provider.DecryptToken(auth[7:]) 75 | if err != nil { 76 | return nil, errors.New("Invalid token") 77 | } 78 | if time.Now().UTC().After(token.CreationDate.Add(token.ExpiresIn)) { 79 | return nil, errors.New("Token expired") 80 | } 81 | return token, nil 82 | } 83 | -------------------------------------------------------------------------------- /middleware_test.go: -------------------------------------------------------------------------------- 1 | package oauth 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | ) 7 | 8 | var _mut *BearerAuthentication 9 | 10 | func init() { 11 | _mut = NewBearerAuthentication( 12 | "mySecretKey-10101", 13 | nil) 14 | } 15 | 16 | func TestAuthorizationHeader(t *testing.T) { 17 | r := new(http.Request) 18 | resp, code := _sut.generateTokenResponse("password", "user111", "password111", "", "", "", "", r) 19 | if code != 200 { 20 | t.Fatalf("Error StatusCode = %d", code) 21 | } 22 | t.Logf("Token response: %v", resp) 23 | 24 | header := "Bearer " + resp.(*TokenResponse).Token 25 | token, err := _mut.checkAuthorizationHeader(header) 26 | if err != nil { 27 | t.Fatalf("Error %s", err.Error()) 28 | } 29 | t.Logf("Verified token : %v", token) 30 | } 31 | 32 | func TestExpiredAuthorizationHeader(t *testing.T) { 33 | header := `Bearer 
// TokenType labels a token; it is stored inside tokens and returned in
// token responses.
type TokenType string

// Known token types.
// NOTE(review): the one-letter values presumably stand for authorization-code,
// user and client tokens; confirm against server.go.
const (
	BearerToken TokenType = "Bearer"
	AuthToken   TokenType = "A"
	UserToken   TokenType = "U"
	ClientToken TokenType = "C"
)

// TokenResponse is the authorization server response
type TokenResponse struct {
	Token        string            `json:"access_token"`
	RefreshToken string            `json:"refresh_token"`
	TokenType    TokenType         `json:"token_type"` // bearer
	ExpiresIn    int64             `json:"expires_in"` // secs
	Properties   map[string]string `json:"properties"`
}

// Token structure generated by the authorization server
type Token struct {
	ID           string            `json:"id_token"`
	CreationDate time.Time         `json:"date"`
	ExpiresIn    time.Duration     `json:"expires_in"` // a time.Duration: encoding/json writes it as integer nanoseconds
	Credential   string            `json:"credential"`
	Scope        string            `json:"scope"`
	Claims       map[string]string `json:"claims"`
	TokenType    TokenType         `json:"type"`
}

// RefreshToken structure included in the authorization server response
type RefreshToken struct {
	CreationDate   time.Time `json:"date"`
	TokenID        string    `json:"id_token"`
	RefreshTokenID string    `json:"id_refresh_token"`
	Credential     string    `json:"credential"`
	TokenType      TokenType `json:"type"`
	Scope          string    `json:"scope"`
}

// renderJSON marshals 'v' to JSON, automatically escaping HTML, setting the
// Content-Type as application/json, and sending the status code header.
// If 'v' cannot be encoded, a 500 plain-text error is sent instead.
func renderJSON(w http.ResponseWriter, v interface{}, statusCode int) {
	buf := &bytes.Buffer{}
	enc := json.NewEncoder(buf)
	enc.SetEscapeHTML(true)
	if err := enc.Encode(v); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.WriteHeader(statusCode)
	// A failed write means the client went away; there is nothing left to report to.
	_, _ = w.Write(buf.Bytes())
}

// TokenSecureFormatter turns a serialized token into its protected form and back.
type TokenSecureFormatter interface {
	CryptToken(source []byte) ([]byte, error)
	DecryptToken(source []byte) ([]byte, error)
}

// TokenProvider serializes tokens to JSON, protects them with a
// TokenSecureFormatter and encodes the result as standard base64 text.
type TokenProvider struct {
	secureFormatter TokenSecureFormatter
}

// NewTokenProvider returns a TokenProvider that uses formatter.
func NewTokenProvider(formatter TokenSecureFormatter) *TokenProvider {
	return &TokenProvider{secureFormatter: formatter}
}

// CryptToken returns the base64 text of the encrypted JSON form of t.
func (tp *TokenProvider) CryptToken(t *Token) (token string, err error) {
	bToken, err := json.Marshal(t)
	if err != nil {
		return "", err
	}
	return tp.crypt(bToken)
}

// CryptRefreshToken returns the base64 text of the encrypted JSON form of t.
func (tp *TokenProvider) CryptRefreshToken(t *RefreshToken) (token string, err error) {
	bToken, err := json.Marshal(t)
	if err != nil {
		return "", err
	}
	return tp.crypt(bToken)
}

// DecryptToken decodes, decrypts and unmarshals an access token.
// It never returns a nil token together with a nil error.
func (tp *TokenProvider) DecryptToken(token string) (t *Token, err error) {
	bToken, err := tp.decrypt(token)
	if err != nil {
		return nil, err
	}
	if err = json.Unmarshal(bToken, &t); err != nil {
		return nil, err
	}
	// A JSON "null" payload unmarshals successfully but leaves t nil;
	// callers dereference the result, so report it as invalid.
	if t == nil {
		return nil, errors.New("Invalid token")
	}
	return t, nil
}

// DecryptRefreshTokens decodes, decrypts and unmarshals a refresh token.
// It never returns a nil refresh token together with a nil error.
func (tp *TokenProvider) DecryptRefreshTokens(refreshToken string) (refresh *RefreshToken, err error) {
	bRefresh, err := tp.decrypt(refreshToken)
	if err != nil {
		return nil, err
	}
	if err = json.Unmarshal(bRefresh, &refresh); err != nil {
		return nil, err
	}
	// Same JSON "null" guard as DecryptToken.
	if refresh == nil {
		return nil, errors.New("Invalid token")
	}
	return refresh, nil
}

// crypt protects token with the formatter and base64-encodes the result.
func (tp *TokenProvider) crypt(token []byte) (string, error) {
	ctoken, err := tp.secureFormatter.CryptToken(token)
	if err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(ctoken), nil
}

// decrypt base64-decodes token and hands the bytes to the formatter.
func (tp *TokenProvider) decrypt(token string) ([]byte, error) {
	b, err := base64.StdEncoding.DecodeString(token)
	if err != nil {
		return nil, err
	}
	return tp.secureFormatter.DecryptToken(b)
}

// rc4XOR returns src XORed with the RC4 key stream derived from key.
// A fresh cipher is created per call, so the same function both encrypts and
// decrypts. It fails when rc4.NewCipher rejects the key.
func rc4XOR(key, src []byte) ([]byte, error) {
	c, err := rc4.NewCipher(key)
	if err != nil {
		return nil, err
	}
	dst := make([]byte, len(src))
	c.XORKeyStream(dst, src)
	return dst, nil
}

// RC4TokenSecureFormatter encrypts tokens with RC4 and adds no integrity check.
type RC4TokenSecureFormatter struct {
	key    []byte
	cipher *rc4.Cipher // not used by any method below
}

// NewRC4TokenSecurityProvider returns an RC4 formatter using key.
func NewRC4TokenSecurityProvider(key []byte) *RC4TokenSecureFormatter {
	return &RC4TokenSecureFormatter{key: key}
}

// CryptToken returns source encrypted with RC4.
func (sc *RC4TokenSecureFormatter) CryptToken(source []byte) ([]byte, error) {
	return rc4XOR(sc.key, source)
}

// DecryptToken returns source decrypted with RC4. An unusable key is
// reported as an error, as in CryptToken, instead of panicking.
func (sc *RC4TokenSecureFormatter) DecryptToken(source []byte) ([]byte, error) {
	return rc4XOR(sc.key, source)
}

// SHA256RC4TokenSecureFormatter prefixes the token with its SHA-256 digest and
// encrypts digest and token together with RC4.
//
// NOTE(review): the digest is unkeyed and every call restarts RC4 with the
// same key, so identical key streams are reused across tokens. Treat this as
// a default for development and supply a stronger TokenSecureFormatter in
// production.
type SHA256RC4TokenSecureFormatter struct {
	key    []byte
	cipher *rc4.Cipher // not used by any method below
}

// NewSHA256RC4TokenSecurityProvider returns a SHA-256 + RC4 formatter using key.
func NewSHA256RC4TokenSecurityProvider(key []byte) *SHA256RC4TokenSecureFormatter {
	return &SHA256RC4TokenSecureFormatter{key: key}
}

// CryptToken returns RC4(SHA256(source) || source).
func (sc *SHA256RC4TokenSecureFormatter) CryptToken(source []byte) ([]byte, error) {
	sum := sha256.Sum256(source)
	plain := make([]byte, 0, sha256.Size+len(source))
	plain = append(plain, sum[:]...)
	plain = append(plain, source...)
	return rc4XOR(sc.key, plain)
}

// DecryptToken reverses CryptToken. It returns "Invalid token" when source is
// shorter than a digest or when the embedded digest does not match the payload.
func (sc *SHA256RC4TokenSecureFormatter) DecryptToken(source []byte) ([]byte, error) {
	if len(source) < sha256.Size {
		return nil, errors.New("Invalid token")
	}
	dest, err := rc4XOR(sc.key, source)
	if err != nil {
		return nil, err
	}
	payload := dest[sha256.Size:]
	sum := sha256.Sum256(payload)
	if string(sum[:]) != string(dest[:sha256.Size]) {
		return nil, errors.New("Invalid token")
	}
	return payload, nil
}
%s", err.Error()) 19 | } 20 | t.Logf("Base64 Token : %v", result) 21 | } 22 | 23 | func TestDecrypt(t *testing.T) { 24 | var token string = `{"CreationDate":"2016-12-14","Expiration":"1000"}` 25 | var bToken []byte = []byte(token) 26 | t.Logf("Base64 Token : %v", bToken) 27 | result, err := _sutRC4.crypt([]byte(token)) 28 | if err != nil { 29 | t.Fatalf("Error %s", err.Error()) 30 | } 31 | t.Logf("Base64 Token : %v", result) 32 | decrypt, err := _sutRC4.decrypt(result) 33 | if err != nil { 34 | t.Fatalf("Error %s", err.Error()) 35 | } 36 | t.Logf("Base64 Token Decrypted: %v", decrypt) 37 | t.Logf("Base64 Token Decrypted: %s", decrypt) 38 | for i := range bToken { 39 | if bToken[i] != decrypt[i] { 40 | t.Fatalf("Error in decryption") 41 | } 42 | } 43 | } 44 | 45 | func TestCryptSHA256(t *testing.T) { 46 | var token string = `{"CreationDate":"2016-12-14","Expiration":"1000"}` 47 | result, err := _sutSHA256.crypt([]byte(token)) 48 | if err != nil { 49 | t.Fatalf("Error %s", err.Error()) 50 | } 51 | t.Logf("Base64 Token : %v", result) 52 | } 53 | 54 | func TestDecryptSHA256(t *testing.T) { 55 | var token string = `{"CreationDate":"2016-12-14","Expiration":"1000"}` 56 | var bToken []byte = []byte(token) 57 | t.Logf("Base64 Token : %v", bToken) 58 | result, err := _sutSHA256.crypt([]byte(token)) 59 | if err != nil { 60 | t.Fatalf("Error %s", err.Error()) 61 | } 62 | t.Logf("Base64 Token : %v", result) 63 | decrypt, err := _sutSHA256.decrypt(result) 64 | if err != nil { 65 | t.Fatalf("Error %s", err.Error()) 66 | } 67 | t.Logf("Base64 Token Decrypted: %v", decrypt) 68 | t.Logf("Base64 Token Decrypted: %s", decrypt) 69 | for i := range bToken { 70 | if bToken[i] != decrypt[i] { 71 | t.Fatalf("Error in decryption") 72 | } 73 | } 74 | } 75 | 76 | func TestDecryptSHA256_LongKey(t *testing.T) { 77 | sutSHA256 := NewTokenProvider(NewSHA256RC4TokenSecurityProvider([]byte("518baffa-b290-4c01-a150-1980f5b06a01"))) 78 | var token string = 
`{"CreationDate":"2016-12-14","Expiration":"1000"}` 79 | var bToken []byte = []byte(token) 80 | t.Logf("Base64 Token : %v", bToken) 81 | result, err := sutSHA256.crypt([]byte(token)) 82 | if err != nil { 83 | t.Fatalf("Error %s", err.Error()) 84 | } 85 | t.Logf("Base64 Token : %v", result) 86 | decrypt, err := sutSHA256.decrypt(result) 87 | if err != nil { 88 | t.Fatalf("Error %s", err.Error()) 89 | } 90 | t.Logf("Base64 Token Decrypted: %v", decrypt) 91 | t.Logf("Base64 Token Decrypted: %s", decrypt) 92 | for i := range bToken { 93 | if bToken[i] != decrypt[i] { 94 | t.Fatalf("Error in decryption") 95 | } 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package oauth 2 | 3 | import ( 4 | "net/http" 5 | "time" 6 | 7 | "github.com/gofrs/uuid" 8 | ) 9 | 10 | type GrantType string 11 | 12 | const ( 13 | PasswordGrant GrantType = "password" 14 | ClientCredentialsGrant GrantType = "client_credentials" 15 | AuthCodeGrant GrantType = "authorization_code" 16 | RefreshTokenGrant GrantType = "refresh_token" 17 | ) 18 | 19 | // CredentialsVerifier defines the interface of the user and client credentials verifier. 
20 | type CredentialsVerifier interface { 21 | // Validate username and password returning an error if the user credentials are wrong 22 | ValidateUser(username, password, scope string, r *http.Request) error 23 | // Validate clientID and secret returning an error if the client credentials are wrong 24 | ValidateClient(clientID, clientSecret, scope string, r *http.Request) error 25 | // Provide additional claims to the token 26 | AddClaims(tokenType TokenType, credential, tokenID, scope string, r *http.Request) (map[string]string, error) 27 | // Provide additional information to the authorization server response 28 | AddProperties(tokenType TokenType, credential, tokenID, scope string, r *http.Request) (map[string]string, error) 29 | // Optionally validate previously stored tokenID during refresh request 30 | ValidateTokenID(tokenType TokenType, credential, tokenID, refreshTokenID string) error 31 | // Optionally store the tokenID generated for the user 32 | StoreTokenID(tokenType TokenType, credential, tokenID, refreshTokenID string) error 33 | } 34 | 35 | // AuthorizationCodeVerifier defines the interface of the Authorization Code verifier 36 | type AuthorizationCodeVerifier interface { 37 | // ValidateCode checks the authorization code and returns the user credential 38 | ValidateCode(clientID, clientSecret, code, redirectURI string, r *http.Request) (string, error) 39 | } 40 | 41 | // BearerServer is the OAuth 2 bearer server implementation. 
42 | type BearerServer struct { 43 | secretKey string 44 | TokenTTL time.Duration 45 | verifier CredentialsVerifier 46 | provider *TokenProvider 47 | } 48 | 49 | // NewBearerServer creates new OAuth 2 bearer server 50 | func NewBearerServer(secretKey string, ttl time.Duration, verifier CredentialsVerifier, formatter TokenSecureFormatter) *BearerServer { 51 | if formatter == nil { 52 | formatter = NewSHA256RC4TokenSecurityProvider([]byte(secretKey)) 53 | } 54 | return &BearerServer{ 55 | secretKey: secretKey, 56 | TokenTTL: ttl, 57 | verifier: verifier, 58 | provider: NewTokenProvider(formatter)} 59 | } 60 | 61 | // UserCredentials manages password grant type requests 62 | func (bs *BearerServer) UserCredentials(w http.ResponseWriter, r *http.Request) { 63 | grantType := r.FormValue("grant_type") 64 | username := r.FormValue("username") 65 | password := r.FormValue("password") 66 | scope := r.FormValue("scope") 67 | if username == "" || password == "" { 68 | // get username and password from basic authorization header 69 | var err error 70 | username, password, err = GetBasicAuthentication(r) 71 | if err != nil { 72 | renderJSON(w, "Not authorized", http.StatusUnauthorized) 73 | return 74 | } 75 | } 76 | 77 | refreshToken := r.FormValue("refresh_token") 78 | resp, statusCode := bs.generateTokenResponse(GrantType(grantType), username, password, refreshToken, scope, "", "", r) 79 | renderJSON(w, resp, statusCode) 80 | } 81 | 82 | // ClientCredentials manages client credentials grant type requests 83 | func (bs *BearerServer) ClientCredentials(w http.ResponseWriter, r *http.Request) { 84 | grantType := r.FormValue("grant_type") 85 | // grant_type client_credentials variables 86 | clientID := r.FormValue("client_id") 87 | clientSecret := r.FormValue("client_secret") 88 | if clientID == "" || clientSecret == "" { 89 | // get clientID and secret from basic authorization header 90 | var err error 91 | clientID, clientSecret, err = GetBasicAuthentication(r) 92 | if err != 
nil { 93 | renderJSON(w, "Not authorized", http.StatusUnauthorized) 94 | return 95 | } 96 | } 97 | scope := r.FormValue("scope") 98 | refreshToken := r.FormValue("refresh_token") 99 | resp, statusCode := bs.generateTokenResponse(GrantType(grantType), clientID, clientSecret, refreshToken, scope, "", "", r) 100 | renderJSON(w, resp, statusCode) 101 | } 102 | 103 | // AuthorizationCode manages authorization code grant type requests for the phase two of the authorization process 104 | func (bs *BearerServer) AuthorizationCode(w http.ResponseWriter, r *http.Request) { 105 | grantType := r.FormValue("grant_type") 106 | // grant_type client_credentials variables 107 | clientID := r.FormValue("client_id") 108 | clientSecret := r.FormValue("client_secret") // not mandatory 109 | code := r.FormValue("code") 110 | redirectURI := r.FormValue("redirect_uri") // not mandatory 111 | scope := r.FormValue("scope") // not mandatory 112 | if clientID == "" { 113 | var err error 114 | clientID, clientSecret, err = GetBasicAuthentication(r) 115 | if err != nil { 116 | renderJSON(w, "Not authorized", http.StatusUnauthorized) 117 | return 118 | } 119 | } 120 | resp, status := bs.generateTokenResponse(GrantType(grantType), clientID, clientSecret, "", scope, code, redirectURI, r) 121 | renderJSON(w, resp, status) 122 | } 123 | 124 | // Generate token response 125 | func (bs *BearerServer) generateTokenResponse(grantType GrantType, credential string, secret string, refreshToken string, scope string, code string, redirectURI string, r *http.Request) (interface{}, int) { 126 | var resp *TokenResponse 127 | switch grantType { 128 | case PasswordGrant: 129 | if err := bs.verifier.ValidateUser(credential, secret, scope, r); err != nil { 130 | return "Not authorized", http.StatusUnauthorized 131 | } 132 | 133 | token, refresh, err := bs.generateTokens(UserToken, credential, scope, r) 134 | if err != nil { 135 | return "Token generation failed, check claims", http.StatusInternalServerError 136 | } 
137 | 138 | if err = bs.verifier.StoreTokenID(token.TokenType, credential, token.ID, refresh.RefreshTokenID); err != nil { 139 | return "Storing Token ID failed", http.StatusInternalServerError 140 | } 141 | 142 | if resp, err = bs.cryptTokens(token, refresh, r); err != nil { 143 | return "Token generation failed, check security provider", http.StatusInternalServerError 144 | } 145 | case ClientCredentialsGrant: 146 | if err := bs.verifier.ValidateClient(credential, secret, scope, r); err != nil { 147 | return "Not authorized", http.StatusUnauthorized 148 | } 149 | 150 | token, refresh, err := bs.generateTokens(ClientToken, credential, scope, r) 151 | if err != nil { 152 | return "Token generation failed, check claims", http.StatusInternalServerError 153 | } 154 | 155 | if err = bs.verifier.StoreTokenID(token.TokenType, credential, token.ID, refresh.RefreshTokenID); err != nil { 156 | return "Storing Token ID failed", http.StatusInternalServerError 157 | } 158 | 159 | if resp, err = bs.cryptTokens(token, refresh, r); err != nil { 160 | return "Token generation failed, check security provider", http.StatusInternalServerError 161 | } 162 | case AuthCodeGrant: 163 | codeVerifier, ok := bs.verifier.(AuthorizationCodeVerifier) 164 | if !ok { 165 | return "Not authorized, grant type not supported", http.StatusUnauthorized 166 | } 167 | 168 | user, err := codeVerifier.ValidateCode(credential, secret, code, redirectURI, r) 169 | if err != nil { 170 | return "Not authorized", http.StatusUnauthorized 171 | } 172 | 173 | token, refresh, err := bs.generateTokens(AuthToken, user, scope, r) 174 | if err != nil { 175 | return "Token generation failed, check claims", http.StatusInternalServerError 176 | } 177 | 178 | err = bs.verifier.StoreTokenID(token.TokenType, user, token.ID, refresh.RefreshTokenID) 179 | if err != nil { 180 | return "Storing Token ID failed", http.StatusInternalServerError 181 | } 182 | 183 | if resp, err = bs.cryptTokens(token, refresh, r); err != nil { 184 
| return "Token generation failed, check security provider", http.StatusInternalServerError 185 | } 186 | case RefreshTokenGrant: 187 | refresh, err := bs.provider.DecryptRefreshTokens(refreshToken) 188 | if err != nil { 189 | return "Not authorized", http.StatusUnauthorized 190 | } 191 | 192 | if err = bs.verifier.ValidateTokenID(refresh.TokenType, refresh.Credential, refresh.TokenID, refresh.RefreshTokenID); err != nil { 193 | return "Not authorized invalid token", http.StatusUnauthorized 194 | } 195 | 196 | token, refresh, err := bs.generateTokens(refresh.TokenType, refresh.Credential, refresh.Scope, r) 197 | if err != nil { 198 | return "Token generation failed", http.StatusInternalServerError 199 | } 200 | 201 | err = bs.verifier.StoreTokenID(token.TokenType, refresh.Credential, token.ID, refresh.RefreshTokenID) 202 | if err != nil { 203 | return "Storing Token ID failed", http.StatusInternalServerError 204 | } 205 | 206 | if resp, err = bs.cryptTokens(token, refresh, r); err != nil { 207 | return "Token generation failed", http.StatusInternalServerError 208 | } 209 | default: 210 | return "Invalid grant_type", http.StatusBadRequest 211 | } 212 | 213 | return resp, http.StatusOK 214 | } 215 | 216 | func (bs *BearerServer) generateTokens(tokenType TokenType, username, scope string, r *http.Request) (*Token, *RefreshToken, error) { 217 | token := &Token{ID: uuid.Must(uuid.NewV4()).String(), Credential: username, ExpiresIn: bs.TokenTTL, CreationDate: time.Now().UTC(), TokenType: tokenType, Scope: scope} 218 | if bs.verifier != nil { 219 | claims, err := bs.verifier.AddClaims(token.TokenType, username, token.ID, token.Scope, r) 220 | if err != nil { 221 | return nil, nil, err 222 | } 223 | token.Claims = claims 224 | } 225 | 226 | refreshToken := &RefreshToken{RefreshTokenID: uuid.Must(uuid.NewV4()).String(), TokenID: token.ID, CreationDate: time.Now().UTC(), Credential: username, TokenType: tokenType, Scope: scope} 227 | 228 | return token, refreshToken, nil 229 
| } 230 | 231 | func (bs *BearerServer) cryptTokens(token *Token, refresh *RefreshToken, r *http.Request) (*TokenResponse, error) { 232 | cToken, err := bs.provider.CryptToken(token) 233 | if err != nil { 234 | return nil, err 235 | } 236 | cRefreshToken, err := bs.provider.CryptRefreshToken(refresh) 237 | if err != nil { 238 | return nil, err 239 | } 240 | 241 | tokenResponse := &TokenResponse{Token: cToken, RefreshToken: cRefreshToken, TokenType: BearerToken, ExpiresIn: (int64)(bs.TokenTTL / time.Second)} 242 | 243 | if bs.verifier != nil { 244 | props, err := bs.verifier.AddProperties(token.TokenType, token.Credential, token.ID, token.Scope, r) 245 | if err != nil { 246 | return nil, err 247 | } 248 | tokenResponse.Properties = props 249 | } 250 | return tokenResponse, nil 251 | } 252 | -------------------------------------------------------------------------------- /server_test.go: -------------------------------------------------------------------------------- 1 | package oauth 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net/http" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | var _sut = NewBearerServer( 12 | "mySecretKey-10101", 13 | time.Second*60, 14 | new(TestUserVerifier), 15 | nil) 16 | 17 | // TestUserVerifier provides user credentials verifier for testing. 18 | type TestUserVerifier struct { 19 | } 20 | 21 | // Validate username and password returning an error if the user credentials are wrong 22 | func (TestUserVerifier) ValidateUser(username, password, scope string, r *http.Request) error { 23 | // Add something to the request context, so we can access it in the claims and props funcs. 
24 | ctx := r.Context() 25 | ctx = context.WithValue(ctx, "oauth.claims.test", "test") 26 | ctx = context.WithValue(ctx, "oauth.props.test", "test") 27 | *r = *r.Clone(ctx) 28 | 29 | switch { 30 | case username == "user111" && password == "password111": 31 | return nil 32 | case username == "user222" && password == "password222": 33 | return nil 34 | case username == "user333" && password == "password333": 35 | return nil 36 | default: 37 | return errors.New("wrong user") 38 | } 39 | } 40 | 41 | // Validate clientID and secret returning an error if the client credentials are wrong 42 | func (TestUserVerifier) ValidateClient(clientID, clientSecret, scope string, r *http.Request) error { 43 | // Add something to the request context, so we can access it in the claims and props funcs. 44 | ctx := r.Context() 45 | ctx = context.WithValue(ctx, "oauth.claims.test", "test") 46 | ctx = context.WithValue(ctx, "oauth.props.test", "test") 47 | *r = *r.Clone(ctx) 48 | 49 | if clientID == "abcdef" && clientSecret == "12345" { 50 | return nil 51 | } 52 | return errors.New("wrong client") 53 | } 54 | 55 | // Provide additional claims to the token 56 | func (TestUserVerifier) AddClaims(tokenType TokenType, credential, tokenID, scope string, r *http.Request) (map[string]string, error) { 57 | claims := make(map[string]string) 58 | claims["customer_id"] = "1001" 59 | claims["customer_data"] = `{"order_date":"2016-12-14","order_id":"9999"}` 60 | 61 | // Get value from request context, and add it to our claims. 
62 | test := r.Context().Value("oauth.claims.test") 63 | if test != nil { 64 | claims["ctx_value"] = test.(string) 65 | } 66 | return claims, nil 67 | } 68 | 69 | // Provide additional information to the token response 70 | func (TestUserVerifier) AddProperties(tokenType TokenType, credential, tokenID, scope string, r *http.Request) (map[string]string, error) { 71 | props := make(map[string]string) 72 | props["customer_name"] = "Gopher" 73 | 74 | // Get value from request context, and add it to our props. 75 | test := r.Context().Value("oauth.props.test") 76 | if test != nil { 77 | props["ctx_value"] = test.(string) 78 | } 79 | return props, nil 80 | } 81 | 82 | // Validate token ID 83 | func (TestUserVerifier) ValidateTokenID(tokenType TokenType, credential, tokenID, refreshTokenID string) error { 84 | return nil 85 | } 86 | 87 | // Optionally store the token ID generated for the user 88 | func (TestUserVerifier) StoreTokenID(tokenType TokenType, credential, tokenID, refreshTokenID string) error { 89 | return nil 90 | } 91 | 92 | func TestGenerateTokensByUsername(t *testing.T) { 93 | r := new(http.Request) 94 | token, refresh, err := _sut.generateTokens(UserToken, "user111", "", r) 95 | if err == nil { 96 | t.Logf("Token: %v", token) 97 | t.Logf("Refresh Token: %v", refresh) 98 | } else { 99 | t.Fatalf("Error %s", err.Error()) 100 | } 101 | } 102 | 103 | func TestCryptTokens(t *testing.T) { 104 | r := new(http.Request) 105 | token, refresh, err := _sut.generateTokens(UserToken, "user222","", r) 106 | if err == nil { 107 | t.Logf("Token: %v", token) 108 | t.Logf("Refresh Token: %v", refresh) 109 | } else { 110 | t.Fatalf("Error %s", err.Error()) 111 | } 112 | 113 | resp, err := _sut.cryptTokens(token, refresh, r) 114 | if err == nil { 115 | t.Logf("Response: %v", resp) 116 | } else { 117 | t.Fatalf("Error %s", err.Error()) 118 | } 119 | } 120 | 121 | func TestDecryptRefreshTokens(t *testing.T) { 122 | r := new(http.Request) 123 | token, refresh, err := 
_sut.generateTokens(UserToken,"user333","", r) 124 | if err == nil { 125 | t.Logf("Token: %v", token) 126 | t.Logf("Refresh Token: %v", refresh) 127 | } else { 128 | t.Fatalf("Error %s", err.Error()) 129 | } 130 | 131 | resp, err := _sut.cryptTokens(token, refresh, r) 132 | if err == nil { 133 | t.Logf("Response: %v", resp) 134 | t.Logf("Response Refresh Token: %v", resp.RefreshToken) 135 | } else { 136 | t.Fatalf("Error %s", err.Error()) 137 | } 138 | 139 | refresh2, err := _sut.provider.DecryptRefreshTokens(resp.RefreshToken) 140 | if err == nil { 141 | t.Logf("Refresh Token Decrypted: %v", refresh2) 142 | } else { 143 | t.Fatalf("Error %s", err.Error()) 144 | } 145 | } 146 | 147 | func TestGenerateToken4Password(t *testing.T) { 148 | resp, code := _sut.generateTokenResponse(PasswordGrant, "user111", "password111", "", "", "", "", new(http.Request)) 149 | if code != 200 { 150 | t.Fatalf("Error StatusCode = %d", code) 151 | } 152 | if resp.(*TokenResponse).Properties["ctx_value"] != "test" { 153 | t.Fatalf("Error ctx_value invalid = %s", resp.(*TokenResponse).Properties["ctx_value"]) 154 | } 155 | t.Logf("Token response: %v", resp) 156 | } 157 | 158 | func TestShouldFailGenerateToken4Password(t *testing.T) { 159 | _, code := _sut.generateTokenResponse(PasswordGrant, "user111", "password4444", "", "", "", "", new(http.Request)) 160 | t.Logf("Server response: %v", code) 161 | if code != 401 { 162 | t.Fatalf("Error StatusCode = %d", code) 163 | } 164 | } 165 | 166 | func TestGenerateToken4ClientCredentials(t *testing.T) { 167 | resp, code := _sut.generateTokenResponse(ClientCredentialsGrant, "abcdef", "12345", "", "", "", "", new(http.Request)) 168 | if code != 200 { 169 | t.Fatalf("Error StatusCode = %d", code) 170 | } 171 | if resp.(*TokenResponse).Properties["ctx_value"] != "test" { 172 | t.Fatalf("Error ctx_value invalid = %s", resp.(*TokenResponse).Properties["ctx_value"]) 173 | } 174 | t.Logf("Token response: %v", resp) 175 | } 176 | 177 | func 
TestRefreshToken4ClientCredentials(t *testing.T) { 178 | r := new(http.Request) 179 | resp, code := _sut.generateTokenResponse(ClientCredentialsGrant, "abcdef", "12345", "", "", "", "", r) 180 | if code != 200 { 181 | t.Fatalf("Error StatusCode = %d", code) 182 | } 183 | if resp.(*TokenResponse).Properties["ctx_value"] != "test" { 184 | t.Fatalf("Error ctx_value invalid = %s", resp.(*TokenResponse).Properties["ctx_value"]) 185 | } 186 | t.Logf("Token Response: %v", resp) 187 | resp2, code2 := _sut.generateTokenResponse(RefreshTokenGrant, "", "", resp.(*TokenResponse).RefreshToken, "", "", "", r) 188 | if code2 != 200 { 189 | t.Fatalf("Error StatusCode = %d", code2) 190 | } 191 | t.Logf("New Token Response: %v", resp2) 192 | } --------------------------------------------------------------------------------