├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── const.go ├── const_test.go ├── doc.go ├── errors ├── error.go └── response.go ├── example ├── README.md ├── client │ └── client.go └── server │ ├── server.go │ └── static │ ├── auth.html │ ├── auth.png │ ├── login.html │ ├── login.png │ └── token.png ├── generate.go ├── generates ├── access.go ├── access_test.go ├── authorize.go ├── authorize_test.go ├── jwt_access.go └── jwt_access_test.go ├── go.mod ├── go.sum ├── go.test.sh ├── manage.go ├── manage ├── config.go ├── manage_test.go ├── manager.go ├── util.go └── util_test.go ├── model.go ├── models ├── client.go └── token.go ├── server ├── config.go ├── handler.go ├── handler_test.go ├── server.go ├── server_config.go └── server_test.go ├── store.go └── store ├── client.go ├── client_test.go ├── token.go └── token_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | 26 | coverage.txt 27 | 28 | # OSX 29 | *.DS_Store 30 | *.db 31 | *.swp 32 | /example/client/client 33 | /example/server/server 34 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | sudo: false 3 | go_import_path: github.com/go-oauth2/oauth2/v4 4 | go: 5 | - 1.13 6 | before_install: 7 | - go get -t -v ./... 
8 | 9 | script: 10 | - chmod +x ./go.test.sh && ./go.test.sh 11 | 12 | after_success: 13 | - bash <(curl -s https://codecov.io/bash) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Lyric 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Golang OAuth 2.0 Server 2 | 3 | > An open protocol to allow secure authorization in a simple and standard method from web, mobile and desktop applications. 
4 | 5 | [![Build][build-status-image]][build-status-url] [![Codecov][codecov-image]][codecov-url] [![ReportCard][reportcard-image]][reportcard-url] [![GoDoc][godoc-image]][godoc-url] [![License][license-image]][license-url] 6 | 7 | ## Protocol Flow 8 | 9 | ```text 10 | +--------+ +---------------+ 11 | | |--(A)- Authorization Request ->| Resource | 12 | | | | Owner | 13 | | |<-(B)-- Authorization Grant ---| | 14 | | | +---------------+ 15 | | | 16 | | | +---------------+ 17 | | |--(C)-- Authorization Grant -->| Authorization | 18 | | Client | | Server | 19 | | |<-(D)----- Access Token -------| | 20 | | | +---------------+ 21 | | | 22 | | | +---------------+ 23 | | |--(E)----- Access Token ------>| Resource | 24 | | | | Server | 25 | | |<-(F)--- Protected Resource ---| | 26 | +--------+ +---------------+ 27 | ``` 28 | 29 | ## Quick Start 30 | 31 | ### Download and install 32 | 33 | ```bash 34 | go get -u -v github.com/go-oauth2/oauth2/v4/... 35 | ``` 36 | 37 | ### Create file `server.go` 38 | 39 | ```go 40 | package main 41 | 42 | import ( 43 | "log" 44 | "net/http" 45 | 46 | "github.com/go-oauth2/oauth2/v4/errors" 47 | "github.com/go-oauth2/oauth2/v4/manage" 48 | "github.com/go-oauth2/oauth2/v4/models" 49 | "github.com/go-oauth2/oauth2/v4/server" 50 | "github.com/go-oauth2/oauth2/v4/store" 51 | ) 52 | 53 | func main() { 54 | manager := manage.NewDefaultManager() 55 | // token memory store 56 | manager.MustTokenStorage(store.NewMemoryTokenStore()) 57 | 58 | // client memory store 59 | clientStore := store.NewClientStore() 60 | clientStore.Set("000000", &models.Client{ 61 | ID: "000000", 62 | Secret: "999999", 63 | Domain: "http://localhost", 64 | }) 65 | manager.MapClientStorage(clientStore) 66 | 67 | srv := server.NewDefaultServer(manager) 68 | srv.SetAllowGetAccessRequest(true) 69 | srv.SetClientInfoHandler(server.ClientFormHandler) 70 | 71 | srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (userID string, err error) { 72 | return 
"000000", nil 73 | } 74 | 75 | srv.SetInternalErrorHandler(func(err error) (re *errors.Response) { 76 | log.Println("Internal Error:", err.Error()) 77 | return 78 | }) 79 | 80 | srv.SetResponseErrorHandler(func(re *errors.Response) { 81 | log.Println("Response Error:", re.Error.Error()) 82 | }) 83 | 84 | http.HandleFunc("/authorize", func(w http.ResponseWriter, r *http.Request) { 85 | err := srv.HandleAuthorizeRequest(w, r) 86 | if err != nil { 87 | http.Error(w, err.Error(), http.StatusBadRequest) 88 | } 89 | }) 90 | 91 | http.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { 92 | srv.HandleTokenRequest(w, r) 93 | }) 94 | 95 | log.Fatal(http.ListenAndServe(":9096", nil)) 96 | } 97 | 98 | ``` 99 | 100 | ### Build and run 101 | 102 | ```bash 103 | go build server.go 104 | 105 | ./server 106 | ``` 107 | 108 | ### Open in your web browser 109 | **Authorization Request**: 110 | [http://localhost:9096/authorize?client_id=000000&response_type=code](http://localhost:9096/authorize?client_id=000000&response_type=code) 111 | 112 | **Token Request (client credentials grant)**: 113 | [http://localhost:9096/token?grant_type=client_credentials&client_id=000000&client_secret=999999&scope=read](http://localhost:9096/token?grant_type=client_credentials&client_id=000000&client_secret=999999&scope=read) 114 | 115 | ```json 116 | { 117 | "access_token": "J86XVRYSNFCFI233KXDL0Q", 118 | "expires_in": 7200, 119 | "scope": "read", 120 | "token_type": "Bearer" 121 | } 122 | ``` 123 | 124 | ## Features 125 | 126 | - Easy to use 127 | - Implements [RFC 6749](https://tools.ietf.org/html/rfc6749) 128 | - Token storage supports TTL 129 | - Supports custom access token expiration times 130 | - Supports custom extension fields 131 | - Supports custom scopes 132 | - Supports JWT-formatted access tokens 133 | 134 | ## Example 135 | 136 | > A complete example that simulates the authorization code flow 137 | 138 | For a simulated authorization code flow, please see
[example](/example) 139 | 140 | ### Use JWT to generate access tokens 141 | 142 | ```go 143 | 144 | import ( 145 | "github.com/go-oauth2/oauth2/v4/generates" 146 | "github.com/dgrijalva/jwt-go" 147 | ) 148 | 149 | // ... 150 | manager.MapAccessGenerate(generates.NewJWTAccessGenerate("", []byte("00000000"), jwt.SigningMethodHS512)) 151 | 152 | // Parse and verify the JWT access token (access holds the token string) 153 | token, err := jwt.ParseWithClaims(access, &generates.JWTAccessClaims{}, func(t *jwt.Token) (interface{}, error) { 154 | if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { 155 | return nil, fmt.Errorf("parse error") 156 | } 157 | return []byte("00000000"), nil 158 | }) 159 | if err != nil { 160 | // panic(err) 161 | } 162 | 163 | claims, ok := token.Claims.(*generates.JWTAccessClaims) 164 | if !ok || !token.Valid { 165 | // panic("invalid token") 166 | } 167 | ``` 168 | 169 | ## Store Implementations 170 | 171 | - [BuntDB](https://github.com/tidwall/buntdb) (default store) 172 | - [Redis](https://github.com/go-oauth2/redis) 173 | - [MongoDB](https://github.com/go-oauth2/mongo) 174 | - [MySQL](https://github.com/go-oauth2/mysql) 175 | - [MySQL (provides both client and token stores)](https://github.com/imrenagi/go-oauth2-mysql) 176 | - [PostgreSQL](https://github.com/vgarvardt/go-oauth2-pg) 177 | - [DynamoDB](https://github.com/contamobi/go-oauth2-dynamodb) 178 | - [XORM](https://github.com/techknowlogick/go-oauth2-xorm) 179 | - [XORM (MySQL, client and token store)](https://github.com/rainlay/go-oauth2-xorm) 180 | - [GORM](https://github.com/techknowlogick/go-oauth2-gorm) 181 | - [Firestore](https://github.com/tslamic/go-oauth2-firestore) 182 | - [Hazelcast](https://github.com/clowre/go-oauth2-hazelcast) (token only) 183 | 184 | ## Handy Utilities 185 | 186 | - [OAuth2 Proxy Logger (Debug utility that proxies interfaces and logs)](https://github.com/aubelsb2/oauth2-logger-proxy) 187 | 188 | ## MIT License 189 | 190 | Copyright (c) 2016 Lyric 191 | 192 | [build-status-url]:
https://travis-ci.org/go-oauth2/oauth2 193 | [build-status-image]: https://travis-ci.org/go-oauth2/oauth2.svg?branch=master 194 | [codecov-url]: https://codecov.io/gh/go-oauth2/oauth2 195 | [codecov-image]: https://codecov.io/gh/go-oauth2/oauth2/branch/master/graph/badge.svg 196 | [reportcard-url]: https://goreportcard.com/report/github.com/go-oauth2/oauth2/v4 197 | [reportcard-image]: https://goreportcard.com/badge/github.com/go-oauth2/oauth2/v4 198 | [godoc-url]: https://godoc.org/github.com/go-oauth2/oauth2/v4 199 | [godoc-image]: https://godoc.org/github.com/go-oauth2/oauth2/v4?status.svg 200 | [license-url]: http://opensource.org/licenses/MIT 201 | [license-image]: https://img.shields.io/npm/l/express.svg 202 | -------------------------------------------------------------------------------- /const.go: -------------------------------------------------------------------------------- 1 | package oauth2 2 | 3 | import ( 4 | "crypto/sha256" 5 | "encoding/base64" 6 | "strings" 7 | ) 8 | 9 | // ResponseType is the response_type requested by an authorization request. 10 | type ResponseType string 11 | 12 | // Supported response types. 13 | const ( 14 | Code ResponseType = "code" 15 | Token ResponseType = "token" 16 | ) 17 | // String returns the response type as a plain string. 18 | func (rt ResponseType) String() string { 19 | return string(rt) 20 | } 21 | 22 | // GrantType is the grant type (authorization model) of a token request. 23 | type GrantType string 24 | 25 | // Supported grant types. Implicit uses a non-standard value and is not reported by String. 26 | const ( 27 | AuthorizationCode GrantType = "authorization_code" 28 | PasswordCredentials GrantType = "password" 29 | ClientCredentials GrantType = "client_credentials" 30 | Refreshing GrantType = "refresh_token" 31 | Implicit GrantType = "__implicit" 32 | ) 33 | // String returns the value for AuthorizationCode, PasswordCredentials, ClientCredentials and Refreshing, and "" for any other value (including Implicit). 34 | func (gt GrantType) String() string { 35 | if gt == AuthorizationCode || 36 | gt == PasswordCredentials || 37 | gt == ClientCredentials || 38 | gt == Refreshing { 39 | return string(gt) 40 | } 41 | return "" 42 | } 43 | 44 | // CodeChallengeMethod is a PKCE (RFC 7636) code challenge method. 45 | type CodeChallengeMethod string 46 | 47 |
const ( 48 | // CodeChallengePlain is the PKCE "plain" method: the challenge equals the verifier. 49 | CodeChallengePlain CodeChallengeMethod = "plain" 50 | // CodeChallengeS256 is the PKCE "S256" method: the challenge is the base64url-encoded SHA-256 of the verifier. 51 | CodeChallengeS256 CodeChallengeMethod = "S256" 52 | ) 53 | // String returns the method name, or "" if ccm is not a supported PKCE method. 54 | func (ccm CodeChallengeMethod) String() string { 55 | if ccm == CodeChallengePlain || 56 | ccm == CodeChallengeS256 { 57 | return string(ccm) 58 | } 59 | return "" 60 | } 61 | 62 | // Validate reports whether verifier ver matches challenge cc under method ccm. For S256, trailing '=' padding is ignored on both sides; unknown methods never validate. 63 | func (ccm CodeChallengeMethod) Validate(cc, ver string) bool { 64 | switch ccm { 65 | case CodeChallengePlain: 66 | return cc == ver 67 | case CodeChallengeS256: 68 | s256 := sha256.Sum256([]byte(ver)) 69 | // trim padding so padded and unpadded challenges compare equal 70 | a := strings.TrimRight(base64.URLEncoding.EncodeToString(s256[:]), "=") 71 | b := strings.TrimRight(cc, "=") 72 | return a == b 73 | default: 74 | return false 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /const_test.go: -------------------------------------------------------------------------------- 1 | package oauth2_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/go-oauth2/oauth2/v4" 7 | ) 8 | 9 | func TestValidatePlain(t *testing.T) { 10 | cc := oauth2.CodeChallengePlain 11 | if !cc.Validate("plaintest", "plaintest") { 12 | t.Fatal("not valid") 13 | } 14 | } 15 | 16 | func TestValidateS256(t *testing.T) { 17 | cc := oauth2.CodeChallengeS256 18 | if !cc.Validate("W6YWc_4yHwYN-cGDgGmOMHF3l7KDy7VcRjf7q2FVF-o=", "s256test") { 19 | t.Fatal("not valid") 20 | } 21 | } 22 | 23 | func TestValidateS256NoPadding(t *testing.T) { 24 | cc := oauth2.CodeChallengeS256 25 | if !cc.Validate("W6YWc_4yHwYN-cGDgGmOMHF3l7KDy7VcRjf7q2FVF-o", "s256test") { 26 | t.Fatal("not valid") 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // OAuth 2.0 server library for the Go programming language 2 | // 3 | // package main 4 | // import ( 5 |
"net/http" 6 | // "github.com/go-oauth2/oauth2/v4/manage" 7 | // "github.com/go-oauth2/oauth2/v4/server" 8 | // "github.com/go-oauth2/oauth2/v4/store" 9 | // ) 10 | // func main() { 11 | // manager := manage.NewDefaultManager() 12 | // manager.MustTokenStorage(store.NewMemoryTokenStore()) 13 | // manager.MapClientStorage(store.NewTestClientStore()) 14 | // srv := server.NewDefaultServer(manager) 15 | // http.HandleFunc("/authorize", func(w http.ResponseWriter, r *http.Request) { 16 | // srv.HandleAuthorizeRequest(w, r) 17 | // }) 18 | // http.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { 19 | // srv.HandleTokenRequest(w, r) 20 | // }) 21 | // http.ListenAndServe(":9096", nil) 22 | // } 23 | 24 | package oauth2 25 | -------------------------------------------------------------------------------- /errors/error.go: -------------------------------------------------------------------------------- 1 | package errors 2 | 3 | import "errors" 4 | 5 | // New returns an error that formats as the given text. 
6 | var New = errors.New 7 | 8 | // Known library errors; each is a distinct sentinel value, so compare with == or errors.Is. 9 | var ( 10 | ErrInvalidRedirectURI = errors.New("invalid redirect uri") 11 | ErrInvalidAuthorizeCode = errors.New("invalid authorize code") 12 | ErrInvalidAccessToken = errors.New("invalid access token") 13 | ErrInvalidRefreshToken = errors.New("invalid refresh token") 14 | ErrExpiredAccessToken = errors.New("expired access token") 15 | ErrExpiredRefreshToken = errors.New("expired refresh token") 16 | ErrMissingCodeVerifier = errors.New("missing code verifier") 17 | ErrMissingCodeChallenge = errors.New("missing code challenge") 18 | ErrInvalidCodeChallenge = errors.New("invalid code challenge") 19 | ) 20 | -------------------------------------------------------------------------------- /errors/response.go: -------------------------------------------------------------------------------- 1 | package errors 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | ) 7 | 8 | // Response describes an error response: the underlying error, an error code, a human-readable description and URI, the HTTP status code and optional extra headers. 9 | type Response struct { 10 | Error error 11 | ErrorCode int 12 | Description string 13 | URI string 14 | StatusCode int 15 | Header http.Header 16 | } 17 | 18 | // NewResponse returns a Response for err with the given HTTP status code; all other fields are left zero. 19 | func NewResponse(err error, statusCode int) *Response { 20 | return &Response{ 21 | Error: err, 22 | StatusCode: statusCode, 23 | } 24 | } 25 | 26 | // SetHeader sets the header entries associated with key to 27 | // the single element value, allocating Header on first use.
28 | func (r *Response) SetHeader(key, value string) { 29 | if r.Header == nil { 30 | r.Header = make(http.Header) 31 | } 32 | r.Header.Set(key, value) 33 | } 34 | 35 | // Error codes from https://tools.ietf.org/html/rfc6749#section-5.2. The three PKCE errors share the invalid_request text but are distinct values, so they are keyed separately below. 36 | var ( 37 | ErrInvalidRequest = errors.New("invalid_request") 38 | ErrUnauthorizedClient = errors.New("unauthorized_client") 39 | ErrAccessDenied = errors.New("access_denied") 40 | ErrUnsupportedResponseType = errors.New("unsupported_response_type") 41 | ErrInvalidScope = errors.New("invalid_scope") 42 | ErrServerError = errors.New("server_error") 43 | ErrTemporarilyUnavailable = errors.New("temporarily_unavailable") 44 | ErrInvalidClient = errors.New("invalid_client") 45 | ErrInvalidGrant = errors.New("invalid_grant") 46 | ErrUnsupportedGrantType = errors.New("unsupported_grant_type") 47 | ErrCodeChallengeRquired = errors.New("invalid_request") 48 | ErrUnsupportedCodeChallengeMethod = errors.New("invalid_request") 49 | ErrInvalidCodeChallengeLen = errors.New("invalid_request") 50 | ) 51 | 52 | // Descriptions maps each error above to its human-readable description. 53 | var Descriptions = map[error]string{ 54 | ErrInvalidRequest: "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed", 55 | ErrUnauthorizedClient: "The client is not authorized to request an authorization code using this method", 56 | ErrAccessDenied: "The resource owner or authorization server denied the request", 57 | ErrUnsupportedResponseType: "The authorization server does not support obtaining an authorization code using this method", 58 | ErrInvalidScope: "The requested scope is invalid, unknown, or malformed", 59 | ErrServerError: "The authorization server encountered an unexpected condition that prevented it from fulfilling the request", 60 | ErrTemporarilyUnavailable: "The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server", 61 | ErrInvalidClient:
"Client authentication failed", 62 | ErrInvalidGrant: "The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client", 63 | ErrUnsupportedGrantType: "The authorization grant type is not supported by the authorization server", 64 | ErrCodeChallengeRquired: "PKCE is required. code_challenge is missing", 65 | ErrUnsupportedCodeChallengeMethod: "Selected code_challenge_method not supported", 66 | ErrInvalidCodeChallengeLen: "Code challenge length must be between 43 and 128 characters long", 67 | } 68 | 69 | // StatusCodes maps each error above to the HTTP status code of its response. NOTE(review): RFC 6749 section 5.2 uses 400 for invalid_grant and unsupported_grant_type; the 401s are kept as-is to avoid changing existing behavior. 70 | var StatusCodes = map[error]int{ 71 | ErrInvalidRequest: 400, 72 | ErrUnauthorizedClient: 401, 73 | ErrAccessDenied: 403, 74 | ErrUnsupportedResponseType: 401, 75 | ErrInvalidScope: 400, 76 | ErrServerError: 500, 77 | ErrTemporarilyUnavailable: 503, 78 | ErrInvalidClient: 401, 79 | ErrInvalidGrant: 401, 80 | ErrUnsupportedGrantType: 401, 81 | ErrCodeChallengeRquired: 400, 82 | ErrUnsupportedCodeChallengeMethod: 400, 83 | ErrInvalidCodeChallengeLen: 400, 84 | } 85 | -------------------------------------------------------------------------------- /example/README.md: -------------------------------------------------------------------------------- 1 | # Use Examples 2 | 3 | ## Run Server 4 | 5 | ``` bash 6 | $ cd example/server 7 | $ go build server.go 8 | $ ./server 9 | ``` 10 | 11 | ## Run Client 12 | 13 | ``` 14 | $ cd example/client 15 | $ go build client.go 16 | $ ./client 17 | ``` 18 | 19 | ## Authorization Code Grant 20 | 21 | ### Open the browser 22 | 23 | [http://localhost:9094](http://localhost:9094) 24 | 25 | ``` 26 | { 27 | "access_token": "GIGXO8XWPQSAUGOYQGTV8Q", 28 | "token_type": "Bearer", 29 | "refresh_token": "5FBLXQ47XJ2MGTY8YRZQ8W", 30 | "expiry": "2019-01-08T01:53:45.868194+08:00" 31 | } 32 | ``` 33 | 34 | 35 | ### Try access
token 36 | 37 | Open the browser [http://localhost:9094/try](http://localhost:9094/try) 38 | 39 | ``` 40 | { 41 | "client_id": "222222", 42 | "expires_in": 7195, 43 | "user_id": "000000" 44 | } 45 | ``` 46 | 47 | ## Refresh token 48 | 49 | Open the browser [http://localhost:9094/refresh](http://localhost:9094/refresh) 50 | 51 | ``` 52 | { 53 | "access_token": "0IIL4_AJN2-SR0JEYZVQWG", 54 | "token_type": "Bearer", 55 | "refresh_token": "AG6-63MLXUEFUV2Q_BLYIW", 56 | "expiry": "2019-01-09T23:03:16.374062+08:00" 57 | } 58 | ``` 59 | 60 | ## Password Credentials Grant 61 | 62 | Open the browser [http://localhost:9094/pwd](http://localhost:9094/pwd) 63 | 64 | ``` 65 | { 66 | "access_token": "87JT3N6WOWANXVDNZFHY7Q", 67 | "token_type": "Bearer", 68 | "refresh_token": "LDIS6PXAVY-BXHPEDESWNG", 69 | "expiry": "2019-02-12T10:58:43.734902+08:00" 70 | } 71 | ``` 72 | 73 | ## Client Credentials Grant 74 | 75 | Open the browser [http://localhost:9094/client](http://localhost:9094/client) 76 | 77 | ``` 78 | { 79 | "access_token": "OA6ITALNMDOGD58C0SN-MG", 80 | "token_type": "Bearer", 81 | "expiry": "2019-02-12T11:10:35.864838+08:00" 82 | } 83 | ``` 84 | 85 | ![login](https://raw.githubusercontent.com/go-oauth2/oauth2/master/example/server/static/login.png) 86 | ![auth](https://raw.githubusercontent.com/go-oauth2/oauth2/master/example/server/static/auth.png) 87 | ![token](https://raw.githubusercontent.com/go-oauth2/oauth2/master/example/server/static/token.png) 88 | -------------------------------------------------------------------------------- /example/client/client.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "crypto/sha256" 6 | "encoding/base64" 7 | "encoding/json" 8 | "fmt" 9 | "io" 10 | "log" 11 | "net/http" 12 | "time" 13 | 14 | "golang.org/x/oauth2" 15 | "golang.org/x/oauth2/clientcredentials" 16 | ) 17 | 18 | const ( 19 | authServerURL = "http://localhost:9096" 20 | ) 21 | 22 |
var ( 23 | config = oauth2.Config{ 24 | ClientID: "222222", 25 | ClientSecret: "22222222", 26 | Scopes: []string{"all"}, 27 | RedirectURL: "http://localhost:9094/oauth2", 28 | Endpoint: oauth2.Endpoint{ 29 | AuthURL: authServerURL + "/oauth/authorize", 30 | TokenURL: authServerURL + "/oauth/token", 31 | }, 32 | } 33 | globalToken *oauth2.Token // shared by all handlers without locking: not safe for concurrent use (example only) 34 | ) 35 | // main registers the demo client's HTTP handlers and serves them on :9094. 36 | func main() { 37 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 38 | u := config.AuthCodeURL("xyz", 39 | oauth2.SetAuthURLParam("code_challenge", genCodeChallengeS256("s256example")), 40 | oauth2.SetAuthURLParam("code_challenge_method", "S256")) 41 | http.Redirect(w, r, u, http.StatusFound) 42 | }) 43 | 44 | http.HandleFunc("/oauth2", func(w http.ResponseWriter, r *http.Request) { 45 | r.ParseForm() 46 | state := r.Form.Get("state") 47 | if state != "xyz" { 48 | http.Error(w, "State invalid", http.StatusBadRequest) 49 | return 50 | } 51 | code := r.Form.Get("code") 52 | if code == "" { 53 | http.Error(w, "Code not found", http.StatusBadRequest) 54 | return 55 | } 56 | token, err := config.Exchange(context.Background(), code, oauth2.SetAuthURLParam("code_verifier", "s256example")) 57 | if err != nil { 58 | http.Error(w, err.Error(), http.StatusInternalServerError) 59 | return 60 | } 61 | globalToken = token 62 | 63 | e := json.NewEncoder(w) 64 | e.SetIndent("", " ") 65 | e.Encode(token) 66 | }) 67 | 68 | http.HandleFunc("/refresh", func(w http.ResponseWriter, r *http.Request) { 69 | if globalToken == nil { 70 | http.Redirect(w, r, "/", http.StatusFound) 71 | return 72 | } 73 | 74 | globalToken.Expiry = time.Now() 75 | token, err := config.TokenSource(context.Background(), globalToken).Token() 76 | if err != nil { 77 | http.Error(w, err.Error(), http.StatusInternalServerError) 78 | return 79 | } 80 | 81 | globalToken = token 82 | e := json.NewEncoder(w) 83 | e.SetIndent("", " ") 84 | e.Encode(token) 85 | }) 86 | 87 | http.HandleFunc("/try", func(w
http.ResponseWriter, r *http.Request) { 88 | if globalToken == nil { 89 | http.Redirect(w, r, "/", http.StatusFound) 90 | return 91 | } 92 | 93 | resp, err := http.Get(fmt.Sprintf("%s/test?access_token=%s", authServerURL, globalToken.AccessToken)) 94 | if err != nil { 95 | http.Error(w, err.Error(), http.StatusBadRequest) 96 | return 97 | } 98 | defer resp.Body.Close() 99 | 100 | io.Copy(w, resp.Body) 101 | }) 102 | 103 | http.HandleFunc("/pwd", func(w http.ResponseWriter, r *http.Request) { 104 | token, err := config.PasswordCredentialsToken(context.Background(), "test", "test") 105 | if err != nil { 106 | http.Error(w, err.Error(), http.StatusInternalServerError) 107 | return 108 | } 109 | 110 | globalToken = token 111 | e := json.NewEncoder(w) 112 | e.SetIndent("", " ") 113 | e.Encode(token) 114 | }) 115 | 116 | http.HandleFunc("/client", func(w http.ResponseWriter, r *http.Request) { 117 | cfg := clientcredentials.Config{ 118 | ClientID: config.ClientID, 119 | ClientSecret: config.ClientSecret, 120 | TokenURL: config.Endpoint.TokenURL, 121 | } 122 | 123 | token, err := cfg.Token(context.Background()) 124 | if err != nil { 125 | http.Error(w, err.Error(), http.StatusInternalServerError) 126 | return 127 | } 128 | 129 | e := json.NewEncoder(w) 130 | e.SetIndent("", " ") 131 | e.Encode(token) 132 | }) 133 | 134 | log.Println("Client is running at 9094 port.Please open http://localhost:9094") 135 | log.Fatal(http.ListenAndServe(":9094", nil)) 136 | } 137 | // genCodeChallengeS256 returns the RFC 7636 S256 code challenge for verifier s: BASE64URL(SHA-256(s)) without '=' padding, as the RFC requires. 138 | func genCodeChallengeS256(s string) string { 139 | s256 := sha256.Sum256([]byte(s)) 140 | return base64.RawURLEncoding.EncodeToString(s256[:]) 141 | } 142 | -------------------------------------------------------------------------------- /example/server/server.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "flag" 7 | "fmt" 8 | "io" 9 | "log" 10 | "net/http" 11 | "net/http/httputil" 12 | "net/url" 13 |
"os" 14 | "time" 15 | 16 | "github.com/go-oauth2/oauth2/v4/generates" 17 | 18 | "github.com/go-oauth2/oauth2/v4/errors" 19 | "github.com/go-oauth2/oauth2/v4/manage" 20 | "github.com/go-oauth2/oauth2/v4/models" 21 | "github.com/go-oauth2/oauth2/v4/server" 22 | "github.com/go-oauth2/oauth2/v4/store" 23 | "github.com/go-session/session/v3" 24 | ) 25 | 26 | var ( 27 | dumpvar bool 28 | idvar string 29 | secretvar string 30 | domainvar string 31 | portvar int 32 | ) 33 | // init registers the command-line flags read by main. 34 | func init() { 35 | flag.BoolVar(&dumpvar, "d", true, "Dump requests and responses") 36 | flag.StringVar(&idvar, "i", "222222", "The client id being passed in") 37 | flag.StringVar(&secretvar, "s", "22222222", "The client secret being passed in") 38 | flag.StringVar(&domainvar, "r", "http://localhost:9094", "The domain of the redirect url") 39 | flag.IntVar(&portvar, "p", 9096, "the base port for the server") 40 | } 41 | // main wires a manager (in-memory token store, one demo client) into an OAuth server and serves the OAuth endpoints, the login/consent pages and /test on the -p port. 42 | func main() { 43 | flag.Parse() 44 | if dumpvar { 45 | log.Println("Dumping requests") 46 | } 47 | manager := manage.NewDefaultManager() 48 | manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg) 49 | 50 | // token store 51 | manager.MustTokenStorage(store.NewMemoryTokenStore()) 52 | 53 | // generate jwt access token 54 | // manager.MapAccessGenerate(generates.NewJWTAccessGenerate("", []byte("00000000"), jwt.SigningMethodHS512)) 55 | manager.MapAccessGenerate(generates.NewAccessGenerate()) 56 | 57 | clientStore := store.NewClientStore() 58 | clientStore.Set(idvar, &models.Client{ 59 | ID: idvar, 60 | Secret: secretvar, 61 | Domain: domainvar, 62 | }) 63 | manager.MapClientStorage(clientStore) 64 | 65 | srv := server.NewServer(server.NewConfig(), manager) 66 | 67 | srv.SetPasswordAuthorizationHandler(func(ctx context.Context, clientID, username, password string) (userID string, err error) { 68 | if username == "test" && password == "test" { 69 | userID = "test" 70 | } else { 71 | err = errors.New("invalid username or password") 72 | } 73 | return 74 | }) 75 | 76 |
| srv.SetUserAuthorizationHandler(userAuthorizeHandler) 77 | 78 | srv.SetInternalErrorHandler(func(err error) (re *errors.Response) { 79 | log.Println("Internal Error:", err.Error()) 80 | return 81 | }) 82 | 83 | srv.SetResponseErrorHandler(func(re *errors.Response) { 84 | log.Println("Response Error:", re.Error.Error()) 85 | }) 86 | 87 | http.HandleFunc("/login", loginHandler) 88 | http.HandleFunc("/auth", authHandler) 89 | 90 | http.HandleFunc("/oauth/authorize", func(w http.ResponseWriter, r *http.Request) { 91 | if dumpvar { 92 | _ = dumpRequest(os.Stdout, "authorize", r) // Ignore the error 93 | } 94 | 95 | store, err := session.Start(r.Context(), w, r) 96 | if err != nil { 97 | http.Error(w, err.Error(), http.StatusInternalServerError) 98 | return 99 | } 100 | 101 | var form url.Values 102 | if v, ok := store.Get("ReturnUri"); ok { 103 | form = v.(url.Values) 104 | } 105 | r.Form = form 106 | 107 | store.Delete("ReturnUri") 108 | store.Save() 109 | 110 | err = srv.HandleAuthorizeRequest(w, r) 111 | if err != nil { 112 | http.Error(w, err.Error(), http.StatusBadRequest) 113 | } 114 | }) 115 | 116 | http.HandleFunc("/oauth/token", func(w http.ResponseWriter, r *http.Request) { 117 | if dumpvar { 118 | _ = dumpRequest(os.Stdout, "token", r) // Ignore the error 119 | } 120 | 121 | err := srv.HandleTokenRequest(w, r) 122 | if err != nil { 123 | http.Error(w, err.Error(), http.StatusInternalServerError) 124 | } 125 | }) 126 | 127 | http.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { 128 | if dumpvar { 129 | _ = dumpRequest(os.Stdout, "test", r) // Ignore the error 130 | } 131 | token, err := srv.ValidationBearerToken(r) 132 | if err != nil { 133 | http.Error(w, err.Error(), http.StatusBadRequest) 134 | return 135 | } 136 | 137 | data := map[string]interface{}{ 138 | "expires_in": int64(time.Until(token.GetAccessCreateAt().Add(token.GetAccessExpiresIn())).Seconds()), 139 | "client_id": token.GetClientID(), 140 | "user_id": token.GetUserID(), 141 | } 142 | e :=
json.NewEncoder(w) 143 | e.SetIndent("", " ") 144 | e.Encode(data) 145 | }) 146 | 147 | log.Printf("Server is running at %d port.\n", portvar) 148 | log.Printf("Point your OAuth client Auth endpoint to %s:%d%s", "http://localhost", portvar, "/oauth/authorize") 149 | log.Printf("Point your OAuth client Token endpoint to %s:%d%s", "http://localhost", portvar, "/oauth/token") 150 | log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", portvar), nil)) 151 | } 152 | // dumpRequest writes header followed by the full wire dump of r (including its body) to writer. 153 | func dumpRequest(writer io.Writer, header string, r *http.Request) error { 154 | data, err := httputil.DumpRequest(r, true) 155 | if err != nil { 156 | return err 157 | } 158 | writer.Write([]byte("\n" + header + ": \n")) 159 | writer.Write(data) 160 | return nil 161 | } 162 | // userAuthorizeHandler returns (and clears) the user ID that loginHandler stored in the session; if none is stored it saves the request form under "ReturnUri" and redirects to /login. 163 | func userAuthorizeHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) { 164 | if dumpvar { 165 | _ = dumpRequest(os.Stdout, "userAuthorizeHandler", r) // Ignore the error 166 | } 167 | store, err := session.Start(r.Context(), w, r) 168 | if err != nil { 169 | return 170 | } 171 | 172 | uid, ok := store.Get("LoggedInUserID") 173 | if !ok { 174 | if r.Form == nil { 175 | r.ParseForm() 176 | } 177 | store.Set("ReturnUri", r.Form) 178 | store.Save() 179 | 180 | w.Header().Set("Location", "/login") 181 | w.WriteHeader(http.StatusFound) 182 | return 183 | } 184 | 185 | userID = uid.(string) 186 | store.Delete("LoggedInUserID") 187 | store.Save() 188 | return 189 | } 190 | // loginHandler serves the login page and, on a POST with the demo credentials test/test, records the user in the session and redirects to /auth. 191 | func loginHandler(w http.ResponseWriter, r *http.Request) { 192 | if dumpvar { 193 | _ = dumpRequest(os.Stdout, "login", r) // Ignore the error 194 | } 195 | store, err := session.Start(r.Context(), w, r) 196 | if err != nil { 197 | http.Error(w, err.Error(), http.StatusInternalServerError) 198 | return 199 | } 200 | 201 | if r.Method == "POST" { 202 | if r.Form == nil { 203 | if err := r.ParseForm(); err != nil { 204 | http.Error(w, err.Error(), http.StatusInternalServerError) 205 | return 206 | } 207 | } 208 | 209 | if
r.Form.Get("username") == "test" && r.Form.Get("password") == "test" { 210 | store.Set("LoggedInUserID", r.Form.Get("username")) 211 | store.Save() 212 | 213 | w.Header().Set("Location", "/auth") 214 | w.WriteHeader(http.StatusFound) 215 | return 216 | } else { 217 | http.Error(w, "Invalid username or password", http.StatusUnauthorized) 218 | return 219 | } 220 | } 221 | outputHTML(w, r, "static/login.html") 222 | } 223 | // authHandler serves the consent page to a logged-in user and redirects to /login otherwise. 224 | func authHandler(w http.ResponseWriter, r *http.Request) { 225 | if dumpvar { 226 | _ = dumpRequest(os.Stdout, "auth", r) // Ignore the error 227 | } 228 | store, err := session.Start(r.Context(), w, r) 229 | if err != nil { 230 | http.Error(w, err.Error(), http.StatusInternalServerError) 231 | return 232 | } 233 | 234 | if _, ok := store.Get("LoggedInUserID"); !ok { 235 | w.Header().Set("Location", "/login") 236 | w.WriteHeader(http.StatusFound) 237 | return 238 | } 239 | 240 | outputHTML(w, r, "static/auth.html") 241 | } 242 | // outputHTML serves the named file with http.ServeContent, answering 500 if it cannot be opened or stat'ed. 243 | func outputHTML(w http.ResponseWriter, req *http.Request, filename string) { 244 | file, err := os.Open(filename) 245 | if err != nil { 246 | http.Error(w, err.Error(), 500) 247 | return 248 | } 249 | defer file.Close() 250 | fi, err := file.Stat() 251 | if err != nil { 252 | http.Error(w, err.Error(), 500) 253 | return 254 | } 255 | http.ServeContent(w, req, file.Name(), fi.ModTime(), file) 256 | } 257 | -------------------------------------------------------------------------------- /example/server/static/auth.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Auth 6 | 10 | 11 | 12 | 13 | 14 | 15 |
16 |
17 |
18 |

Authorize

19 |

The client would like to perform actions on your behalf.

20 |

21 | 28 |

29 |
30 |
31 |
32 | 33 | 34 | -------------------------------------------------------------------------------- /example/server/static/auth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/go-oauth2/oauth2/bb6e415e9b625f5109f64f9bad0897141cc2f8dd/example/server/static/auth.png -------------------------------------------------------------------------------- /example/server/static/login.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Login 7 | 8 | 9 | 10 | 11 | 12 | 13 |
14 |

Log In

15 |
16 |
17 | 18 | 19 |
20 |
21 | 22 | 23 |
24 | 25 |
26 |
27 | 28 | 29 | -------------------------------------------------------------------------------- /example/server/static/login.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/go-oauth2/oauth2/bb6e415e9b625f5109f64f9bad0897141cc2f8dd/example/server/static/login.png -------------------------------------------------------------------------------- /example/server/static/token.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/go-oauth2/oauth2/bb6e415e9b625f5109f64f9bad0897141cc2f8dd/example/server/static/token.png -------------------------------------------------------------------------------- /generate.go: -------------------------------------------------------------------------------- 1 | package oauth2 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "time" 7 | ) 8 | 9 | type ( 10 | // GenerateBasic holds the inputs a generator uses to derive a token 11 | GenerateBasic struct { 12 | Client ClientInfo 13 | UserID string 14 | CreateAt time.Time 15 | TokenInfo TokenInfo 16 | Request *http.Request 17 | } 18 | 19 | // AuthorizeGenerate is implemented by authorization code generators 20 | AuthorizeGenerate interface { 21 | Token(ctx context.Context, data *GenerateBasic) (code string, err error) 22 | } 23 | 24 | // AccessGenerate is implemented by access token generators; isGenRefresh also requests a refresh token 25 | AccessGenerate interface { 26 | Token(ctx context.Context, data *GenerateBasic, isGenRefresh bool) (access, refresh string, err error) 27 | } 28 | ) 29 | -------------------------------------------------------------------------------- /generates/access.go: -------------------------------------------------------------------------------- 1 | package generates 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/base64" 7 | "strconv" 8 | "strings" 9 | 10 | "github.com/go-oauth2/oauth2/v4" 11 | "github.com/google/uuid" 12 | ) 13 | 14 | // NewAccessGenerate returns a
generator of UUID-based opaque access tokens 15 | func NewAccessGenerate() *AccessGenerate { 16 | return &AccessGenerate{} 17 | } 18 | 19 | // AccessGenerate generates opaque access and refresh tokens from name-based UUIDs under random namespaces 20 | type AccessGenerate struct { 21 | } 22 | 23 | // Token returns base64url-encoded, upper-cased, '='-trimmed tokens derived from the client ID, user ID and CreateAt; refresh is empty unless isGenRefresh 24 | func (ag *AccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) { 25 | buf := bytes.NewBufferString(data.Client.GetID()) 26 | buf.WriteString(data.UserID) 27 | buf.WriteString(strconv.FormatInt(data.CreateAt.UnixNano(), 10)) 28 | 29 | access := base64.URLEncoding.EncodeToString([]byte(uuid.NewMD5(uuid.Must(uuid.NewRandom()), buf.Bytes()).String())) 30 | access = strings.ToUpper(strings.TrimRight(access, "=")) 31 | refresh := "" 32 | if isGenRefresh { 33 | refresh = base64.URLEncoding.EncodeToString([]byte(uuid.NewSHA1(uuid.Must(uuid.NewRandom()), buf.Bytes()).String())) 34 | refresh = strings.ToUpper(strings.TrimRight(refresh, "=")) 35 | } 36 | 37 | return access, refresh, nil 38 | } 39 | -------------------------------------------------------------------------------- /generates/access_test.go: -------------------------------------------------------------------------------- 1 | package generates_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/go-oauth2/oauth2/v4" 9 | "github.com/go-oauth2/oauth2/v4/generates" 10 | "github.com/go-oauth2/oauth2/v4/models" 11 | 12 | .
"github.com/smartystreets/goconvey/convey" 13 | ) 14 | 15 | func TestAccess(t *testing.T) { 16 | Convey("Test Access Generate", t, func() { 17 | data := &oauth2.GenerateBasic{ 18 | Client: &models.Client{ 19 | ID: "123456", 20 | Secret: "123456", 21 | }, 22 | UserID: "000000", 23 | CreateAt: time.Now(), 24 | } 25 | gen := generates.NewAccessGenerate() 26 | access, refresh, err := gen.Token(context.Background(), data, true) 27 | So(err, ShouldBeNil) 28 | So(access, ShouldNotBeEmpty) 29 | So(refresh, ShouldNotBeEmpty) 30 | Println("\nAccess Token:" + access) 31 | Println("Refresh Token:" + refresh) 32 | }) 33 | } 34 | -------------------------------------------------------------------------------- /generates/authorize.go: -------------------------------------------------------------------------------- 1 | package generates 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/base64" 7 | "strings" 8 | 9 | "github.com/go-oauth2/oauth2/v4" 10 | "github.com/google/uuid" 11 | ) 12 | 13 | // NewAuthorizeGenerate returns a generator of UUID-based authorization codes 14 | func NewAuthorizeGenerate() *AuthorizeGenerate { 15 | return &AuthorizeGenerate{} 16 | } 17 | 18 | // AuthorizeGenerate generates opaque authorization codes 19 | type AuthorizeGenerate struct{} 20 | 21 | // Token returns a base64url-encoded, upper-cased, '='-trimmed code built from an MD5 name-based UUID of the client and user IDs under a random namespace 22 | func (ag *AuthorizeGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic) (string, error) { 23 | buf := bytes.NewBufferString(data.Client.GetID()) 24 | buf.WriteString(data.UserID) 25 | token := uuid.NewMD5(uuid.Must(uuid.NewRandom()), buf.Bytes()) 26 | code := base64.URLEncoding.EncodeToString([]byte(token.String())) 27 | code = strings.ToUpper(strings.TrimRight(code, "=")) 28 | 29 | return code, nil 30 | } 31 | -------------------------------------------------------------------------------- /generates/authorize_test.go: -------------------------------------------------------------------------------- 1 | package generates_test 2 | 3 | import ( 4 | "context" 5 |
"testing" 6 | "time" 7 | 8 | "github.com/go-oauth2/oauth2/v4" 9 | "github.com/go-oauth2/oauth2/v4/generates" 10 | "github.com/go-oauth2/oauth2/v4/models" 11 | 12 | . "github.com/smartystreets/goconvey/convey" 13 | ) 14 | 15 | func TestAuthorize(t *testing.T) { 16 | Convey("Test Authorize Generate", t, func() { 17 | data := &oauth2.GenerateBasic{ 18 | Client: &models.Client{ 19 | ID: "123456", 20 | Secret: "123456", 21 | }, 22 | UserID: "000000", 23 | CreateAt: time.Now(), 24 | } 25 | gen := generates.NewAuthorizeGenerate() 26 | code, err := gen.Token(context.Background(), data) 27 | So(err, ShouldBeNil) 28 | So(code, ShouldNotBeEmpty) 29 | Println("\nAuthorize Code:" + code) 30 | }) 31 | } 32 | -------------------------------------------------------------------------------- /generates/jwt_access.go: -------------------------------------------------------------------------------- 1 | package generates 2 | 3 | import ( 4 | "context" 5 | "encoding/base64" 6 | "strings" 7 | "time" 8 | 9 | "github.com/go-oauth2/oauth2/v4" 10 | "github.com/go-oauth2/oauth2/v4/errors" 11 | "github.com/golang-jwt/jwt/v5" 12 | "github.com/google/uuid" 13 | ) 14 | 15 | // JWTAccessClaims are the registered JWT claims carried by access tokens 16 | type JWTAccessClaims struct { 17 | jwt.RegisteredClaims 18 | } 19 | 20 | // Valid returns ErrInvalidAccessToken when ExpiresAt is set and already past. NOTE(review): golang-jwt v5 parsers no longer call Valid(); confirm callers invoke it explicitly 21 | func (a *JWTAccessClaims) Valid() error { 22 | if a.ExpiresAt != nil && time.Unix(a.ExpiresAt.Unix(), 0).Before(time.Now()) { 23 | return errors.ErrInvalidAccessToken 24 | } 25 | return nil 26 | } 27 | 28 | // NewJWTAccessGenerate returns a JWT access token generator for the given key ID (empty to omit the "kid" header), signing key and method 29 | func NewJWTAccessGenerate(kid string, key []byte, method jwt.SigningMethod) *JWTAccessGenerate { 30 | return &JWTAccessGenerate{ 31 | SignedKeyID: kid, 32 | SignedKey: key, 33 | SignedMethod: method, 34 | } 35 | } 36 | 37 | // JWTAccessGenerate generates signed JWT access tokens 38 | type JWTAccessGenerate struct { 39 | SignedKeyID string 40 | SignedKey []byte 41 | SignedMethod jwt.SigningMethod 42 | } 43 | 44 |
// Token returns a signed JWT access token and, if isGenRefresh, an opaque refresh token. 45 | func (a *JWTAccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) { 46 | claims := &JWTAccessClaims{ 47 | RegisteredClaims: jwt.RegisteredClaims{ 48 | Audience: jwt.ClaimStrings{data.Client.GetID()}, 49 | Subject: data.UserID, 50 | ExpiresAt: jwt.NewNumericDate(data.TokenInfo.GetAccessCreateAt().Add(data.TokenInfo.GetAccessExpiresIn())), 51 | }, 52 | } 53 | 54 | token := jwt.NewWithClaims(a.SignedMethod, claims) 55 | if a.SignedKeyID != "" { 56 | token.Header["kid"] = a.SignedKeyID 57 | } 58 | // HMAC ("HS*") signs with the raw SignedKey bytes; the ES*, RS*/PS* and 59 | // Ed* families expect SignedKey to hold a PEM-encoded private key. The 60 | // two-letter prefixes are disjoint, so at most one case matches. 61 | var ( 62 | key interface{} 63 | err error 64 | ) 65 | switch { 66 | case a.isEs(): 67 | key, err = jwt.ParseECPrivateKeyFromPEM(a.SignedKey) 68 | case a.isRsOrPS(): 69 | key, err = jwt.ParseRSAPrivateKeyFromPEM(a.SignedKey) 70 | case a.isHs(): 71 | key = a.SignedKey 72 | case a.isEd(): 73 | key, err = jwt.ParseEdPrivateKeyFromPEM(a.SignedKey) 74 | default: 75 | return "", "", errors.New("unsupported sign method") 76 | } 77 | if err != nil { 78 | return "", "", err 79 | } 80 | 81 | access, err := token.SignedString(key) 82 | if err != nil { 83 | return "", "", err 84 | } 85 | refresh := "" 86 | 87 | // The refresh token is opaque (not a JWT): a SHA-1 name-based UUID of the 88 | // access token under a random namespace, base64url-encoded, upper-cased, '=' trimmed. 89 | if isGenRefresh { 90 | t := uuid.NewSHA1(uuid.Must(uuid.NewRandom()), []byte(access)).String() 91 | refresh = base64.URLEncoding.EncodeToString([]byte(t)) 92 | refresh = strings.ToUpper(strings.TrimRight(refresh, "=")) 93 | } 94 | 95 | return access, refresh, nil 96 | } 97 | 98 | // isEs reports whether the algorithm name starts with "ES" (EC private key). 99 | func (a *JWTAccessGenerate) isEs() bool { 100 | return strings.HasPrefix(a.SignedMethod.Alg(), "ES") 101 | } 102 | 103 | // isRsOrPS reports whether the algorithm name starts with "RS" or "PS" (RSA private key). 104 | func (a *JWTAccessGenerate) isRsOrPS() bool { 105 | isRs := strings.HasPrefix(a.SignedMethod.Alg(), "RS") 106 | isPs := strings.HasPrefix(a.SignedMethod.Alg(), "PS") 107 | return isRs || isPs 108 | } 109 | 110 | // isHs reports whether the algorithm name starts with "HS" (HMAC, raw key bytes). 111 | func (a *JWTAccessGenerate) isHs() bool { 112 | return
strings.HasPrefix(a.SignedMethod.Alg(), "HS") 113 | } 114 | 115 | // isEd reports whether the algorithm name starts with "Ed" (Ed25519 private key). 116 | func (a *JWTAccessGenerate) isEd() bool { 117 | return strings.HasPrefix(a.SignedMethod.Alg(), "Ed") 118 | } 119 | -------------------------------------------------------------------------------- /generates/jwt_access_test.go: -------------------------------------------------------------------------------- 1 | package generates_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | "time" 8 | 9 | "github.com/go-oauth2/oauth2/v4" 10 | "github.com/go-oauth2/oauth2/v4/generates" 11 | "github.com/go-oauth2/oauth2/v4/models" 12 | "github.com/golang-jwt/jwt/v5" 13 | 14 | . "github.com/smartystreets/goconvey/convey" 15 | ) 16 | 17 | func TestJWTAccess(t *testing.T) { 18 | Convey("Test JWT Access Generate", t, func() { 19 | data := &oauth2.GenerateBasic{ 20 | Client: &models.Client{ 21 | ID: "123456", 22 | Secret: "123456", 23 | }, 24 | UserID: "000000", 25 | TokenInfo: &models.Token{ 26 | AccessCreateAt: time.Now(), 27 | AccessExpiresIn: time.Second * 120, 28 | }, 29 | } 30 | 31 | gen := generates.NewJWTAccessGenerate("", []byte("00000000"), jwt.SigningMethodHS512) 32 | access, refresh, err := gen.Token(context.Background(), data, true) 33 | So(err, ShouldBeNil) 34 | So(access, ShouldNotBeEmpty) 35 | So(refresh, ShouldNotBeEmpty) 36 | 37 | token, err := jwt.ParseWithClaims(access, &generates.JWTAccessClaims{}, func(t *jwt.Token) (interface{}, error) { 38 | if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { 39 | return nil, fmt.Errorf("parse error") 40 | } 41 | return []byte("00000000"), nil 42 | }) 43 | So(err, ShouldBeNil) 44 | 45 | claims, ok := token.Claims.(*generates.JWTAccessClaims) 46 | So(ok, ShouldBeTrue) 47 | So(token.Valid, ShouldBeTrue) 48 | aud, err := claims.GetAudience() 49 | So(err, ShouldBeNil) 50 | So(len(aud), ShouldEqual, 1) 51 | So(aud[0], ShouldEqual, "123456") 52 | So(claims.Subject, ShouldEqual, "000000") 53 | }) 54 | } 55 |
-------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/go-oauth2/oauth2/v4 2 | 3 | go 1.13 4 | 5 | require ( 6 | github.com/ajg/form v1.5.1 // indirect 7 | github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072 // indirect 8 | github.com/fatih/structs v1.1.0 // indirect 9 | github.com/gavv/httpexpect v2.0.0+incompatible 10 | github.com/go-session/session/v3 v3.2.1 11 | github.com/golang-jwt/jwt v3.2.2+incompatible 12 | github.com/google/go-querystring v1.0.0 // indirect 13 | github.com/google/uuid v1.1.1 14 | github.com/gorilla/websocket v1.4.2 // indirect 15 | github.com/imkira/go-interpol v1.1.0 // indirect 16 | github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88 // indirect 17 | github.com/mattn/go-colorable v0.1.7 // indirect 18 | github.com/moul/http2curl v1.0.0 // indirect 19 | github.com/onsi/ginkgo v1.13.0 // indirect 20 | github.com/sergi/go-diff v1.1.0 // indirect 21 | github.com/smartystreets/goconvey v1.6.4 22 | github.com/tidwall/buntdb v1.1.2 23 | github.com/tidwall/gjson v1.12.1 // indirect 24 | github.com/valyala/fasthttp v1.34.0 // indirect 25 | github.com/xeipuuv/gojsonschema v1.2.0 // indirect 26 | github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0 // indirect 27 | github.com/yudai/gojsondiff v1.0.0 // indirect 28 | github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 // indirect 29 | github.com/yudai/pp v2.0.1+incompatible // indirect 30 | golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d 31 | google.golang.org/appengine v1.6.6 // indirect 32 | gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 // indirect 33 | ) 34 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | cloud.google.com/go v0.34.0/go.mod 
h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= 2 | github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= 3 | github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= 4 | github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= 5 | github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= 6 | github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6 h1:FCLDGi1EmB7JzjVVYNZiqc/zAJj2BQ5M0lfkVOxbfs8= 7 | github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6/go.mod h1:5FoAH5xUHHCMDvQPy1rnj8moqLkLHFaDVBjHhcFwEi0= 8 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 9 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 10 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 11 | github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072 h1:DddqAaWDpywytcG8w/qoQ5sAN8X12d3Z3koB0C3Rxsc= 12 | github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072/go.mod h1:duJ4Jxv5lDcvg4QuQr0oowTf7dz4/CR8NtyCooz9HL8= 13 | github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= 14 | github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= 15 | github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= 16 | github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= 17 | github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= 18 | github.com/gavv/httpexpect v2.0.0+incompatible h1:1X9kcRshkSKEjNJJxX9Y9mQ5BRfbxU5kORdjhlA1yX8= 19 | github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv1M62hOWzwo5OXotisrKc= 20 | github.com/go-session/session/v3 v3.2.1 h1:APQf5JFW84+bhbqRjEZO8J+IppSgT1jMQTFI/XVyIFY= 21 | github.com/go-session/session/v3 v3.2.1/go.mod 
h1:RftEBbyuzqkNCAxIrCLJe+rfBqB/4G11qxq9KYKrx4M= 22 | github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= 23 | github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= 24 | github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 25 | github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 26 | github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= 27 | github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= 28 | github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= 29 | github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= 30 | github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= 31 | github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= 32 | github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= 33 | github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= 34 | github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= 35 | github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= 36 | github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 37 | github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk= 38 | github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= 39 | github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= 40 | github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 41 | github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod 
h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= 42 | github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 h1:l5lAOZEym3oK3SQ2HBHWsJUfbNBiTXJDeW2QDxw9AQ0= 43 | github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= 44 | github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= 45 | github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 46 | github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= 47 | github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk= 48 | github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA= 49 | github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= 50 | github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= 51 | github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88 h1:uC1QfSlInpQF+M0ao65imhwqKnz3Q2z/d8PWZRMQvDM= 52 | github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88/go.mod h1:3w7q1U84EfirKl04SVQ/s7nPm1ZPhiXd34z40TNz36k= 53 | github.com/klauspost/compress v1.15.0 h1:xqfchp4whNFxn5A4XFyyYtitiWI8Hy5EW59jEwcyL6U= 54 | github.com/klauspost/compress v1.15.0/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= 55 | github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= 56 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 57 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 58 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= 59 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 60 | github.com/mattn/go-colorable v0.1.7 h1:bQGKb3vps/j0E9GfJQ03JyhRuxsvdAanXlT9BTw3mdw= 61 | github.com/mattn/go-colorable v0.1.7/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= 62 | 
github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= 63 | github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= 64 | github.com/moul/http2curl v1.0.0 h1:dRMWoAtb+ePxMlLkrCbAqh4TlPHXvoGUSQ323/9Zahs= 65 | github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ= 66 | github.com/nxadm/tail v1.4.4 h1:DQuhQpB1tVlglWS2hLQ5OV6B5r8aGxSrPc5Qo6uTN78= 67 | github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= 68 | github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= 69 | github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= 70 | github.com/onsi/ginkgo v1.13.0 h1:M76yO2HkZASFjXL0HSoZJ1AYEmQxNJmY41Jx1zNUq1Y= 71 | github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1lskyM0= 72 | github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= 73 | github.com/onsi/gomega v1.10.1 h1:o0+MgICZLuZ7xjH7Vx6zS/zcu93/BEp1VwkIW1mEXCE= 74 | github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= 75 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 76 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 77 | github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw= 78 | github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0= 79 | github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= 80 | github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= 81 | github.com/smartystreets/assertions v1.1.0 h1:MkTeG1DMwsrdH7QtLXy5W+fUxWq+vmb6cLmyJ7aRtF0= 82 | github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo= 83 | github.com/smartystreets/goconvey v1.6.4 
h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= 84 | github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= 85 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 86 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 87 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 88 | github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= 89 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 90 | github.com/tidwall/btree v0.0.0-20191029221954-400434d76274 h1:G6Z6HvJuPjG6XfNGi/feOATzeJrfgTNJY+rGrHbA04E= 91 | github.com/tidwall/btree v0.0.0-20191029221954-400434d76274/go.mod h1:huei1BkDWJ3/sLXmO+bsCNELL+Bp2Kks9OLyQFkzvA8= 92 | github.com/tidwall/buntdb v1.1.2 h1:noCrqQXL9EKMtcdwJcmuVKSEjqu1ua99RHHgbLTEHRo= 93 | github.com/tidwall/buntdb v1.1.2/go.mod h1:xAzi36Hir4FarpSHyfuZ6JzPJdjRZ8QlLZSntE2mqlI= 94 | github.com/tidwall/gjson v1.3.4/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls= 95 | github.com/tidwall/gjson v1.12.1 h1:ikuZsLdhr8Ws0IdROXUS1Gi4v9Z4pGqpX/CvJkxvfpo= 96 | github.com/tidwall/gjson v1.12.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= 97 | github.com/tidwall/grect v0.0.0-20161006141115-ba9a043346eb h1:5NSYaAdrnblKByzd7XByQEJVT8+9v0W/tIY0Oo4OwrE= 98 | github.com/tidwall/grect v0.0.0-20161006141115-ba9a043346eb/go.mod h1:lKYYLFIr9OIgdgrtgkZ9zgRxRdvPYsExnYBsEAd8W5M= 99 | github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E= 100 | github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= 101 | github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= 102 | github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= 103 | github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= 104 | 
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= 105 | github.com/tidwall/rtree v0.0.0-20180113144539-6cd427091e0e h1:+NL1GDIUOKxVfbp2KoJQD9cTQ6dyP2co9q4yzmT9FZo= 106 | github.com/tidwall/rtree v0.0.0-20180113144539-6cd427091e0e/go.mod h1:/h+UnNGt0IhNNJLkGikcdcJqm66zGD/uJGMRxK/9+Ao= 107 | github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563 h1:Otn9S136ELckZ3KKDyCkxapfufrqDqwmGjcHfAyXRrE= 108 | github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563/go.mod h1:mLqSmt7Dv/CNneF2wfcChfN1rvapyQr01LGKnKex0DQ= 109 | github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= 110 | github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= 111 | github.com/valyala/fasthttp v1.34.0 h1:d3AAQJ2DRcxJYHm7OXNXtXt2as1vMDfxeIcFvhmGGm4= 112 | github.com/valyala/fasthttp v1.34.0/go.mod h1:epZA5N+7pY6ZaEKRmstzOuYJx9HI8DI1oaCGZpdH4h0= 113 | github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= 114 | github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c= 115 | github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= 116 | github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= 117 | github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= 118 | github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= 119 | github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= 120 | github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0 h1:6fRhSjgLCkTD3JnJxvaJ4Sj+TYblw757bqYgZaOq5ZY= 121 | github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0/go.mod h1:/LWChgwKmvncFJFHJ7Gvn9wZArjbV5/FppcK2fKk/tI= 122 | 
github.com/yudai/gojsondiff v1.0.0 h1:27cbfqXLVEJ1o8I6v3y9lg8Ydm53EKqHXAOMxEGlCOA= 123 | github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FBNExI05xg= 124 | github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 h1:BHyfKlQyqbsFN5p3IfnEUduWvb9is428/nNb5L3U01M= 125 | github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM= 126 | github.com/yudai/pp v2.0.1+incompatible h1:Q4//iY4pNF6yPLZIigmvcl7k/bPgrcTPIFIcmawg5bI= 127 | github.com/yudai/pp v2.0.1+incompatible/go.mod h1:PuxR/8QJ7cyCkFp/aUDS+JY727OFEZkTdatxwunjIkc= 128 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 129 | golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= 130 | golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 131 | golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 132 | golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 133 | golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 134 | golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= 135 | golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= 136 | golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= 137 | golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc= 138 | golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= 139 | golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw= 140 | golang.org/x/oauth2 
v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= 141 | golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 142 | golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 143 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 144 | golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 145 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 146 | golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 147 | golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 148 | golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 149 | golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 150 | golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 151 | golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 152 | golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 153 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 154 | golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 155 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 156 | golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 157 | golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 158 | golang.org/x/sys v0.0.0-20221010170243-090e33056c14 h1:k5II8e6QD8mITdi+okbbmR/cIyEbeXLBhy5Ha4nevyc= 159 | golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 160 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 161 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= 162 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 163 | golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= 164 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 165 | golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= 166 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 167 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 168 | golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= 169 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= 170 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 171 | google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= 172 | google.golang.org/appengine v1.6.6 h1:lMO5rYAqUxkmaj76jAkRUvt5JZgFymx/+Q5Mzfivuhc= 173 | google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= 174 | google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= 175 | google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= 176 | google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod 
h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= 177 | google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= 178 | google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= 179 | google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM= 180 | google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= 181 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 182 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= 183 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 184 | gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= 185 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= 186 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= 187 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 188 | gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 189 | gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= 190 | gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 191 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 192 | gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 h1:tQIYjPdBoyREyB9XMu+nnTclpTYkz2zFM+lzLJFO4gQ= 193 | gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 194 | -------------------------------------------------------------------------------- /go.test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | echo "" > 
coverage.txt 5 | 6 | for d in $(go list ./... | grep -v vendor); do 7 | go test -race -coverprofile=profile.out -covermode=atomic "$d" 8 | if [ -f profile.out ]; then 9 | cat profile.out >> coverage.txt 10 | rm profile.out 11 | fi 12 | done -------------------------------------------------------------------------------- /manage.go: -------------------------------------------------------------------------------- 1 | package oauth2 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "time" 7 | ) 8 | 9 | // TokenGenerateRequest provide to generate the token request parameters 10 | type TokenGenerateRequest struct { 11 | ClientID string 12 | ClientSecret string 13 | UserID string 14 | RedirectURI string 15 | Scope string 16 | Code string 17 | CodeChallenge string 18 | CodeChallengeMethod CodeChallengeMethod 19 | Refresh string 20 | CodeVerifier string 21 | AccessTokenExp time.Duration 22 | Request *http.Request 23 | } 24 | 25 | // Manager authorization management interface 26 | type Manager interface { 27 | // get the client information 28 | GetClient(ctx context.Context, clientID string) (cli ClientInfo, err error) 29 | 30 | // generate the authorization token(code) 31 | GenerateAuthToken(ctx context.Context, rt ResponseType, tgr *TokenGenerateRequest) (authToken TokenInfo, err error) 32 | 33 | // generate the access token 34 | GenerateAccessToken(ctx context.Context, gt GrantType, tgr *TokenGenerateRequest) (accessToken TokenInfo, err error) 35 | 36 | // refreshing an access token 37 | RefreshAccessToken(ctx context.Context, tgr *TokenGenerateRequest) (accessToken TokenInfo, err error) 38 | 39 | // use the access token to delete the token information 40 | RemoveAccessToken(ctx context.Context, access string) (err error) 41 | 42 | // use the refresh token to delete the token information 43 | RemoveRefreshToken(ctx context.Context, refresh string) (err error) 44 | 45 | // according to the access token for corresponding token information 46 | LoadAccessToken(ctx 
context.Context, access string) (ti TokenInfo, err error) 47 | 48 | // load the token information associated with the refresh token 49 | LoadRefreshToken(ctx context.Context, refresh string) (ti TokenInfo, err error) 50 | } 51 | -------------------------------------------------------------------------------- /manage/config.go: -------------------------------------------------------------------------------- 1 | package manage 2 | 3 | import "time" 4 | 5 | // Config holds the token settings the Manager applies when issuing tokens for one grant type 6 | type Config struct { 7 | // access token lifetime, 0 means it doesn't expire; a positive TokenGenerateRequest.AccessTokenExp overrides it 8 | AccessTokenExp time.Duration 9 | // refresh token lifetime (applied only when IsGenerateRefresh is set), 0 means it doesn't expire 10 | RefreshTokenExp time.Duration 11 | // whether to also issue a refresh token 12 | IsGenerateRefresh bool 13 | } 14 | 15 | // RefreshingConfig controls how Manager.RefreshAccessToken reissues tokens 16 | type RefreshingConfig struct { 17 | // new access token lifetime; 0 keeps the lifetime already recorded on the token being refreshed 18 | AccessTokenExp time.Duration 19 | // new refresh token lifetime; 0 keeps the lifetime already recorded on the token being refreshed 20 | RefreshTokenExp time.Duration 21 | // whether to issue a new refresh token; when false the old refresh token stays on the stored record 22 | IsGenerateRefresh bool 23 | // whether to restart the refresh token's lifetime from the time of the refresh 24 | IsResetRefreshTime bool 25 | // whether to delete the old access token after the refreshed token is stored 26 | IsRemoveAccess bool 27 | // whether to delete the old refresh token when a new one has been issued 28 | IsRemoveRefreshing bool 29 | } 30 | 31 | // default configs, used when the corresponding config has not been set on the Manager 32 | var ( 33 | DefaultCodeExp = time.Minute * 10 34 | DefaultAuthorizeCodeTokenCfg = &Config{AccessTokenExp: time.Hour * 2, RefreshTokenExp: time.Hour * 24 * 3, IsGenerateRefresh: true} 35 | DefaultImplicitTokenCfg = &Config{AccessTokenExp: time.Hour * 1} 36 | DefaultPasswordTokenCfg = &Config{AccessTokenExp: time.Hour * 2, RefreshTokenExp: time.Hour * 24 * 7, IsGenerateRefresh: true} 37 | DefaultClientTokenCfg = &Config{AccessTokenExp: time.Hour * 2} 38 | DefaultRefreshTokenCfg = &RefreshingConfig{IsGenerateRefresh: true, IsRemoveAccess: true,
IsRemoveRefreshing: true} 39 | ) 40 | -------------------------------------------------------------------------------- /manage/manage_test.go: -------------------------------------------------------------------------------- 1 | package manage_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/go-oauth2/oauth2/v4" 9 | "github.com/go-oauth2/oauth2/v4/manage" 10 | "github.com/go-oauth2/oauth2/v4/models" 11 | "github.com/go-oauth2/oauth2/v4/store" 12 | 13 | . "github.com/smartystreets/goconvey/convey" 14 | ) 15 | 16 | func TestManager(t *testing.T) { 17 | Convey("Manager test", t, func() { 18 | manager := manage.NewDefaultManager() 19 | ctx := context.Background() 20 | 21 | manager.MustTokenStorage(store.NewMemoryTokenStore()) 22 | 23 | clientStore := store.NewClientStore() 24 | _ = clientStore.Set("1", &models.Client{ 25 | ID: "1", 26 | Secret: "11", 27 | Domain: "http://localhost", 28 | }) 29 | manager.MapClientStorage(clientStore) 30 | 31 | tgr := &oauth2.TokenGenerateRequest{ 32 | ClientID: "1", 33 | UserID: "123456", 34 | RedirectURI: "http://localhost/oauth2", 35 | Scope: "all", 36 | } 37 | 38 | Convey("GetClient test", func() { 39 | cli, err := manager.GetClient(ctx, "1") 40 | So(err, ShouldBeNil) 41 | So(cli.GetSecret(), ShouldEqual, "11") 42 | }) 43 | 44 | Convey("Token test", func() { 45 | testManager(tgr, manager) 46 | }) 47 | 48 | Convey("zero expiration access token test", func() { 49 | testZeroAccessExpirationManager(tgr, manager) 50 | testCannotRequestZeroExpirationAccessTokens(tgr, manager) 51 | }) 52 | 53 | Convey("zero expiration refresh token test", func() { 54 | testZeroRefreshExpirationManager(tgr, manager) 55 | }) 56 | }) 57 | } 58 | 59 | func testManager(tgr *oauth2.TokenGenerateRequest, manager oauth2.Manager) { 60 | ctx := context.Background() 61 | cti, err := manager.GenerateAuthToken(ctx, oauth2.Code, tgr) 62 | So(err, ShouldBeNil) 63 | 64 | code := cti.GetCode() 65 | So(code, ShouldNotBeEmpty) 66 | 67 | atParams 
:= &oauth2.TokenGenerateRequest{ 68 | ClientID: tgr.ClientID, 69 | ClientSecret: "11", 70 | RedirectURI: tgr.RedirectURI, 71 | Code: code, 72 | } 73 | ati, err := manager.GenerateAccessToken(ctx, oauth2.AuthorizationCode, atParams) 74 | So(err, ShouldBeNil) 75 | 76 | accessToken, refreshToken := ati.GetAccess(), ati.GetRefresh() 77 | So(accessToken, ShouldNotBeEmpty) 78 | So(refreshToken, ShouldNotBeEmpty) 79 | 80 | ainfo, err := manager.LoadAccessToken(ctx, accessToken) 81 | So(err, ShouldBeNil) 82 | So(ainfo.GetClientID(), ShouldEqual, atParams.ClientID) 83 | 84 | arinfo, err := manager.LoadRefreshToken(ctx, accessToken) 85 | So(err, ShouldNotBeNil) 86 | So(arinfo, ShouldBeNil) 87 | 88 | rainfo, err := manager.LoadAccessToken(ctx, refreshToken) 89 | So(err, ShouldNotBeNil) 90 | So(rainfo, ShouldBeNil) 91 | 92 | rinfo, err := manager.LoadRefreshToken(ctx, refreshToken) 93 | So(err, ShouldBeNil) 94 | So(rinfo.GetClientID(), ShouldEqual, atParams.ClientID) 95 | 96 | refreshParams := &oauth2.TokenGenerateRequest{ 97 | Refresh: refreshToken, 98 | Scope: "owner", 99 | } 100 | rti, err := manager.RefreshAccessToken(ctx, refreshParams) 101 | So(err, ShouldBeNil) 102 | 103 | refreshAT := rti.GetAccess() 104 | So(refreshAT, ShouldNotBeEmpty) 105 | 106 | _, err = manager.LoadAccessToken(ctx, accessToken) 107 | So(err, ShouldNotBeNil) 108 | 109 | refreshAInfo, err := manager.LoadAccessToken(ctx, refreshAT) 110 | So(err, ShouldBeNil) 111 | So(refreshAInfo.GetScope(), ShouldEqual, "owner") 112 | 113 | err = manager.RemoveAccessToken(ctx, refreshAT) 114 | So(err, ShouldBeNil) 115 | 116 | _, err = manager.LoadAccessToken(ctx, refreshAT) 117 | So(err, ShouldNotBeNil) 118 | 119 | err = manager.RemoveRefreshToken(ctx, refreshToken) 120 | So(err, ShouldBeNil) 121 | 122 | _, err = manager.LoadRefreshToken(ctx, refreshToken) 123 | So(err, ShouldNotBeNil) 124 | } 125 | 126 | func testZeroAccessExpirationManager(tgr *oauth2.TokenGenerateRequest, manager oauth2.Manager) { 127 | ctx := 
context.Background() 128 | config := manage.Config{ 129 | AccessTokenExp: 0, // Set explicitly as we're testing 0 (no) expiration 130 | IsGenerateRefresh: true, 131 | } 132 | m, ok := manager.(*manage.Manager) 133 | So(ok, ShouldBeTrue) 134 | m.SetAuthorizeCodeTokenCfg(&config) 135 | 136 | cti, err := manager.GenerateAuthToken(ctx, oauth2.Code, tgr) 137 | So(err, ShouldBeNil) 138 | 139 | code := cti.GetCode() 140 | So(code, ShouldNotBeEmpty) 141 | 142 | atParams := &oauth2.TokenGenerateRequest{ 143 | ClientID: tgr.ClientID, 144 | ClientSecret: "11", 145 | RedirectURI: tgr.RedirectURI, 146 | Code: code, 147 | } 148 | ati, err := manager.GenerateAccessToken(ctx, oauth2.AuthorizationCode, atParams) 149 | So(err, ShouldBeNil) 150 | 151 | accessToken, refreshToken := ati.GetAccess(), ati.GetRefresh() 152 | So(accessToken, ShouldNotBeEmpty) 153 | So(refreshToken, ShouldNotBeEmpty) 154 | 155 | tokenInfo, err := manager.LoadAccessToken(ctx, accessToken) 156 | So(err, ShouldBeNil) 157 | So(tokenInfo, ShouldNotBeNil) 158 | So(tokenInfo.GetAccess(), ShouldEqual, accessToken) 159 | So(tokenInfo.GetAccessExpiresIn(), ShouldEqual, 0) 160 | } 161 | 162 | func testCannotRequestZeroExpirationAccessTokens(tgr *oauth2.TokenGenerateRequest, manager oauth2.Manager) { 163 | ctx := context.Background() 164 | config := manage.Config{ 165 | AccessTokenExp: time.Hour * 5, 166 | } 167 | m, ok := manager.(*manage.Manager) 168 | So(ok, ShouldBeTrue) 169 | m.SetAuthorizeCodeTokenCfg(&config) 170 | 171 | cti, err := manager.GenerateAuthToken(ctx, oauth2.Code, tgr) 172 | So(err, ShouldBeNil) 173 | 174 | code := cti.GetCode() 175 | So(code, ShouldNotBeEmpty) 176 | 177 | atParams := &oauth2.TokenGenerateRequest{ 178 | ClientID: tgr.ClientID, 179 | ClientSecret: "11", 180 | RedirectURI: tgr.RedirectURI, 181 | AccessTokenExp: 0, // requesting token without expiration 182 | Code: code, 183 | } 184 | ati, err := manager.GenerateAccessToken(ctx, oauth2.AuthorizationCode, atParams) 185 | So(err, 
ShouldBeNil) 186 | 187 | accessToken := ati.GetAccess() 188 | So(accessToken, ShouldNotBeEmpty) 189 | So(ati.GetAccessExpiresIn(), ShouldEqual, time.Hour*5) 190 | } 191 | 192 | func testZeroRefreshExpirationManager(tgr *oauth2.TokenGenerateRequest, manager oauth2.Manager) { 193 | ctx := context.Background() 194 | config := manage.Config{ 195 | RefreshTokenExp: 0, // Set explicitly as we're testing 0 (no) expiration 196 | IsGenerateRefresh: true, 197 | } 198 | m, ok := manager.(*manage.Manager) 199 | So(ok, ShouldBeTrue) 200 | m.SetAuthorizeCodeTokenCfg(&config) 201 | 202 | cti, err := manager.GenerateAuthToken(ctx, oauth2.Code, tgr) 203 | So(err, ShouldBeNil) 204 | 205 | code := cti.GetCode() 206 | So(code, ShouldNotBeEmpty) 207 | 208 | atParams := &oauth2.TokenGenerateRequest{ 209 | ClientID: tgr.ClientID, 210 | ClientSecret: "11", 211 | RedirectURI: tgr.RedirectURI, 212 | AccessTokenExp: time.Hour, 213 | Code: code, 214 | } 215 | ati, err := manager.GenerateAccessToken(ctx, oauth2.AuthorizationCode, atParams) 216 | So(err, ShouldBeNil) 217 | 218 | accessToken, refreshToken := ati.GetAccess(), ati.GetRefresh() 219 | So(accessToken, ShouldNotBeEmpty) 220 | So(refreshToken, ShouldNotBeEmpty) 221 | 222 | tokenInfo, err := manager.LoadRefreshToken(ctx, refreshToken) 223 | So(err, ShouldBeNil) 224 | So(tokenInfo, ShouldNotBeNil) 225 | So(tokenInfo.GetRefresh(), ShouldEqual, refreshToken) 226 | So(tokenInfo.GetRefreshExpiresIn(), ShouldEqual, 0) 227 | 228 | // LoadAccessToken also checks refresh expiry 229 | tokenInfo, err = manager.LoadAccessToken(ctx, accessToken) 230 | So(err, ShouldBeNil) 231 | So(tokenInfo, ShouldNotBeNil) 232 | So(tokenInfo.GetRefresh(), ShouldEqual, refreshToken) 233 | So(tokenInfo.GetRefreshExpiresIn(), ShouldEqual, 0) 234 | } 235 | -------------------------------------------------------------------------------- /manage/manager.go: -------------------------------------------------------------------------------- 1 | package manage 2 | 3 | import 
( 4 | "context" 5 | "net/url" 6 | "time" 7 | 8 | "github.com/go-oauth2/oauth2/v4" 9 | "github.com/go-oauth2/oauth2/v4/errors" 10 | "github.com/go-oauth2/oauth2/v4/generates" 11 | "github.com/go-oauth2/oauth2/v4/models" 12 | ) 13 | 14 | // NewDefaultManager creates a Manager wired with the default authorization-code and access-token generators; token and client stores must still be mapped by the caller 15 | func NewDefaultManager() *Manager { 16 | m := NewManager() 17 | // default generator implementations from the generates package 18 | m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) 19 | m.MapAccessGenerate(generates.NewAccessGenerate()) 20 | 21 | return m 22 | } 23 | 24 | // NewManager creates a Manager that validates redirect URIs with DefaultValidateURI; generators and stores must be mapped before use 25 | func NewManager() *Manager { 26 | return &Manager{ 27 | gtcfg: make(map[oauth2.GrantType]*Config), 28 | validateURI: DefaultValidateURI, 29 | } 30 | } 31 | 32 | // Manager is the default oauth2.Manager implementation, built from pluggable generators and stores 33 | type Manager struct { 34 | codeExp time.Duration 35 | gtcfg map[oauth2.GrantType]*Config 36 | rcfg *RefreshingConfig 37 | validateURI ValidateURIHandler 38 | extractExtension ExtractExtensionHandler 39 | authorizeGenerate oauth2.AuthorizeGenerate 40 | accessGenerate oauth2.AccessGenerate 41 | tokenStore oauth2.TokenStore 42 | clientStore oauth2.ClientStore 43 | } 44 | 45 | // grantConfig returns the non-nil config set for gt, otherwise the package default for that grant type (an empty Config for other grant types) 46 | func (m *Manager) grantConfig(gt oauth2.GrantType) *Config { 47 | if c, ok := m.gtcfg[gt]; ok && c != nil { 48 | return c 49 | } 50 | switch gt { 51 | case oauth2.AuthorizationCode: 52 | return DefaultAuthorizeCodeTokenCfg 53 | case oauth2.Implicit: 54 | return DefaultImplicitTokenCfg 55 | case oauth2.PasswordCredentials: 56 | return DefaultPasswordTokenCfg 57 | case oauth2.ClientCredentials: 58 | return DefaultClientTokenCfg 59 | } 60 | return &Config{} 61 | } 62 | 63 | // SetAuthorizeCodeExp sets the authorization code lifetime; 0 falls back to DefaultCodeExp 64 | func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) { 65 | m.codeExp = exp 66 | } 67 | 68 | // SetAuthorizeCodeTokenCfg sets the token config used for the authorization code grant 69 | func (m *Manager) 
SetAuthorizeCodeTokenCfg(cfg *Config) { 70 | m.gtcfg[oauth2.AuthorizationCode] = cfg 71 | } 72 | 73 | // SetImplicitTokenCfg set the implicit grant token config 74 | func (m *Manager) SetImplicitTokenCfg(cfg *Config) { 75 | m.gtcfg[oauth2.Implicit] = cfg 76 | } 77 | 78 | // SetPasswordTokenCfg set the password grant token config 79 | func (m *Manager) SetPasswordTokenCfg(cfg *Config) { 80 | m.gtcfg[oauth2.PasswordCredentials] = cfg 81 | } 82 | 83 | // SetClientTokenCfg set the client grant token config 84 | func (m *Manager) SetClientTokenCfg(cfg *Config) { 85 | m.gtcfg[oauth2.ClientCredentials] = cfg 86 | } 87 | 88 | // SetRefreshTokenCfg set the refreshing token config 89 | func (m *Manager) SetRefreshTokenCfg(cfg *RefreshingConfig) { 90 | m.rcfg = cfg 91 | } 92 | 93 | // SetValidateURIHandler set the validates that RedirectURI is contained in baseURI 94 | func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) { 95 | m.validateURI = handler 96 | } 97 | 98 | // SetExtractExtensionHandler set the token extension extractor 99 | func (m *Manager) SetExtractExtensionHandler(handler ExtractExtensionHandler) { 100 | m.extractExtension = handler 101 | } 102 | 103 | // MapAuthorizeGenerate mapping the authorize code generate interface 104 | func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { 105 | m.authorizeGenerate = gen 106 | } 107 | 108 | // MapAccessGenerate mapping the access token generate interface 109 | func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) { 110 | m.accessGenerate = gen 111 | } 112 | 113 | // MapClientStorage mapping the client store interface 114 | func (m *Manager) MapClientStorage(stor oauth2.ClientStore) { 115 | m.clientStore = stor 116 | } 117 | 118 | // MustClientStorage mandatory mapping the client store interface 119 | func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) { 120 | if err != nil { 121 | panic(err.Error()) 122 | } 123 | m.clientStore = stor 124 | } 125 | 126 | // 
MapTokenStorage mapping the token store interface 127 | func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) { 128 | m.tokenStore = stor 129 | } 130 | 131 | // MustTokenStorage mandatory mapping the token store interface 132 | func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) { 133 | if err != nil { 134 | panic(err) 135 | } 136 | m.tokenStore = stor 137 | } 138 | 139 | // GetClient get the client information 140 | func (m *Manager) GetClient(ctx context.Context, clientID string) (cli oauth2.ClientInfo, err error) { 141 | cli, err = m.clientStore.GetByID(ctx, clientID) 142 | if err != nil { 143 | return 144 | } else if cli == nil { 145 | err = errors.ErrInvalidClient 146 | } 147 | return 148 | } 149 | 150 | // GenerateAuthToken generate the authorization token(code) 151 | func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { 152 | cli, err := m.GetClient(ctx, tgr.ClientID) 153 | if err != nil { 154 | return nil, err 155 | } else if tgr.RedirectURI != "" { 156 | if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { 157 | return nil, err 158 | } 159 | } 160 | 161 | ti := models.NewToken() 162 | if m.extractExtension != nil { 163 | m.extractExtension(tgr, ti) 164 | } 165 | ti.SetClientID(tgr.ClientID) 166 | ti.SetUserID(tgr.UserID) 167 | ti.SetRedirectURI(tgr.RedirectURI) 168 | ti.SetScope(tgr.Scope) 169 | 170 | createAt := time.Now() 171 | td := &oauth2.GenerateBasic{ 172 | Client: cli, 173 | UserID: tgr.UserID, 174 | CreateAt: createAt, 175 | TokenInfo: ti, 176 | Request: tgr.Request, 177 | } 178 | switch rt { 179 | case oauth2.Code: 180 | codeExp := m.codeExp 181 | if codeExp == 0 { 182 | codeExp = DefaultCodeExp 183 | } 184 | ti.SetCodeCreateAt(createAt) 185 | ti.SetCodeExpiresIn(codeExp) 186 | if exp := tgr.AccessTokenExp; exp > 0 { 187 | ti.SetAccessExpiresIn(exp) 188 | } 189 | if tgr.CodeChallenge != "" { 190 | 
ti.SetCodeChallenge(tgr.CodeChallenge) 191 | ti.SetCodeChallengeMethod(tgr.CodeChallengeMethod) 192 | } 193 | 194 | tv, err := m.authorizeGenerate.Token(ctx, td) 195 | if err != nil { 196 | return nil, err 197 | } 198 | ti.SetCode(tv) 199 | case oauth2.Token: 200 | // implicit grant: issue the access token directly; a positive requested AccessTokenExp overrides the config 201 | icfg := m.grantConfig(oauth2.Implicit) 202 | aexp := icfg.AccessTokenExp 203 | if exp := tgr.AccessTokenExp; exp > 0 { 204 | aexp = exp 205 | } 206 | ti.SetAccessCreateAt(createAt) 207 | ti.SetAccessExpiresIn(aexp) 208 | 209 | if icfg.IsGenerateRefresh { 210 | ti.SetRefreshCreateAt(createAt) 211 | ti.SetRefreshExpiresIn(icfg.RefreshTokenExp) 212 | } 213 | 214 | tv, rv, err := m.accessGenerate.Token(ctx, td, icfg.IsGenerateRefresh) 215 | if err != nil { 216 | return nil, err 217 | } 218 | ti.SetAccess(tv) 219 | 220 | if rv != "" { 221 | ti.SetRefresh(rv) 222 | } 223 | } 224 | 225 | err = m.tokenStore.Create(ctx, ti) 226 | if err != nil { 227 | return nil, err 228 | } 229 | return ti, nil 230 | } 231 | 232 | // getAuthorizationCode loads the token stored for code and rejects unknown, mismatched, or expired codes 233 | func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { 234 | ti, err := m.tokenStore.GetByCode(ctx, code) 235 | if err != nil { 236 | return nil, err 237 | } else if ti == nil || ti.GetCode() != code || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) { 238 | // unknown or mismatched code, or a code past its expiry 239 | return nil, errors.ErrInvalidAuthorizeCode 240 | } 241 | return ti, nil 242 | } 243 | 244 | // delAuthorizationCode removes the stored authorization code 245 | func (m *Manager) delAuthorizationCode(ctx context.Context, code string) error { 246 | return m.tokenStore.RemoveByCode(ctx, code) 247 | } 248 | 249 | // getAndDelAuthorizationCode loads the code, checks it was issued to the requesting client and redirect URI, then deletes it so it cannot be reused 250 | func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { 251 | code := tgr.Code 252 | ti, err := m.getAuthorizationCode(ctx, code) 253 | if err != nil { 254 | return nil, err
255 | } else if ti.GetClientID() != tgr.ClientID { 256 | return nil, errors.ErrInvalidAuthorizeCode 257 | } else if codeURI := ti.GetRedirectURI(); codeURI != "" && codeURI != tgr.RedirectURI { 258 | return nil, errors.ErrInvalidAuthorizeCode 259 | } 260 | 261 | err = m.delAuthorizationCode(ctx, code) 262 | if err != nil { 263 | return nil, err 264 | } 265 | return ti, nil 266 | } 267 | // validateCodeChallenge checks the PKCE code verifier against the challenge stored with the authorization code 268 | func (m *Manager) validateCodeChallenge(ti oauth2.TokenInfo, ver string) error { 269 | cc := ti.GetCodeChallenge() 270 | // PKCE was not used for this code and no verifier was sent: nothing to check 271 | if cc == "" && ver == "" { 272 | return nil 273 | } 274 | if cc == "" { // a verifier was sent for a code issued without a challenge; NOTE(review): ErrMissingCodeVerifier is a misleading name for this case 275 | return errors.ErrMissingCodeVerifier 276 | } 277 | if ver == "" { 278 | return errors.ErrMissingCodeVerifier 279 | } 280 | ccm := ti.GetCodeChallengeMethod() 281 | if ccm.String() == "" { // no method recorded: RFC 7636 defaults to plain 282 | ccm = oauth2.CodeChallengePlain 283 | } 284 | if !ccm.Validate(cc, ver) { 285 | return errors.ErrInvalidCodeChallenge 286 | } 287 | return nil 288 | } 289 | 290 | // GenerateAccessToken authenticates the client and issues (and stores) an access token for the given grant type 291 | func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { 292 | cli, err := m.GetClient(ctx, tgr.ClientID) 293 | if err != nil { 294 | return nil, err 295 | } 296 | if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok { 297 | if !cliPass.VerifyPassword(tgr.ClientSecret) { 298 | return nil, errors.ErrInvalidClient 299 | } 300 | } else if len(cli.GetSecret()) > 0 && tgr.ClientSecret != cli.GetSecret() { // NOTE(review): plain string comparison is not constant-time 301 | return nil, errors.ErrInvalidClient 302 | } 303 | if tgr.RedirectURI != "" { 304 | if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { 305 | return nil, err 306 | } 307 | } 308 | 309 | if gt == oauth2.ClientCredentials && cli.IsPublic() { 310 | return nil, errors.ErrInvalidClient 311 | } 312 | 313 | var extension url.Values 314 | 315 | if gt == oauth2.AuthorizationCode { 316 | ti, err := m.getAndDelAuthorizationCode(ctx, tgr) 317 | if err
!= nil { 318 | return nil, err 319 | } 320 | if err := m.validateCodeChallenge(ti, tgr.CodeVerifier); err != nil { 321 | return nil, err 322 | } 323 | tgr.UserID = ti.GetUserID() 324 | tgr.Scope = ti.GetScope() 325 | if exp := ti.GetAccessExpiresIn(); exp > 0 { 326 | tgr.AccessTokenExp = exp 327 | } 328 | if eti, ok := ti.(oauth2.ExtendableTokenInfo); ok { 329 | extension = eti.GetExtension() 330 | } 331 | } 332 | 333 | ti := models.NewToken() 334 | ti.SetExtension(extension) 335 | if m.extractExtension != nil { 336 | m.extractExtension(tgr, ti) 337 | } 338 | ti.SetClientID(tgr.ClientID) 339 | ti.SetUserID(tgr.UserID) 340 | ti.SetRedirectURI(tgr.RedirectURI) 341 | ti.SetScope(tgr.Scope) 342 | 343 | createAt := time.Now() 344 | ti.SetAccessCreateAt(createAt) 345 | 346 | // set access token expires 347 | gcfg := m.grantConfig(gt) 348 | aexp := gcfg.AccessTokenExp 349 | if exp := tgr.AccessTokenExp; exp > 0 { 350 | aexp = exp 351 | } 352 | ti.SetAccessExpiresIn(aexp) 353 | if gcfg.IsGenerateRefresh { 354 | ti.SetRefreshCreateAt(createAt) 355 | ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp) 356 | } 357 | 358 | td := &oauth2.GenerateBasic{ 359 | Client: cli, 360 | UserID: tgr.UserID, 361 | CreateAt: createAt, 362 | TokenInfo: ti, 363 | Request: tgr.Request, 364 | } 365 | 366 | av, rv, err := m.accessGenerate.Token(ctx, td, gcfg.IsGenerateRefresh) 367 | if err != nil { 368 | return nil, err 369 | } 370 | ti.SetAccess(av) 371 | 372 | if rv != "" { 373 | ti.SetRefresh(rv) 374 | } 375 | 376 | err = m.tokenStore.Create(ctx, ti) 377 | if err != nil { 378 | return nil, err 379 | } 380 | 381 | return ti, nil 382 | } 383 | 384 | // RefreshAccessToken refreshing an access token 385 | func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { 386 | ti, err := m.LoadRefreshToken(ctx, tgr.Refresh) 387 | if err != nil { 388 | return nil, err 389 | } 390 | 391 | cli, err := m.GetClient(ctx, ti.GetClientID()) 392 | if err != 
nil { 393 | return nil, err 394 | } 395 | // remember the current tokens so they can be removed after the refreshed token is stored 396 | oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh() 397 | 398 | td := &oauth2.GenerateBasic{ 399 | Client: cli, 400 | UserID: ti.GetUserID(), 401 | CreateAt: time.Now(), 402 | TokenInfo: ti, 403 | Request: tgr.Request, 404 | } 405 | // fall back to DefaultRefreshTokenCfg when no refresh config was set 406 | rcfg := DefaultRefreshTokenCfg 407 | if v := m.rcfg; v != nil { 408 | rcfg = v 409 | } 410 | 411 | ti.SetAccessCreateAt(td.CreateAt) 412 | if v := rcfg.AccessTokenExp; v > 0 { // 0 keeps the access lifetime already on the token 413 | ti.SetAccessExpiresIn(v) 414 | } 415 | 416 | if v := rcfg.RefreshTokenExp; v > 0 { // 0 keeps the refresh lifetime already on the token 417 | ti.SetRefreshExpiresIn(v) 418 | } 419 | 420 | if rcfg.IsResetRefreshTime { 421 | ti.SetRefreshCreateAt(td.CreateAt) 422 | } 423 | 424 | if scope := tgr.Scope; scope != "" { // NOTE(review): not checked against the originally granted scope (RFC 6749 section 6); confirm callers enforce this 425 | ti.SetScope(scope) 426 | } 427 | 428 | tv, rv, err := m.accessGenerate.Token(ctx, td, rcfg.IsGenerateRefresh) 429 | if err != nil { 430 | return nil, err 431 | } 432 | 433 | ti.SetAccess(tv) 434 | if rv != "" { 435 | ti.SetRefresh(rv) 436 | } 437 | 438 | if err := m.tokenStore.Create(ctx, ti); err != nil { 439 | return nil, err 440 | } 441 | 442 | if rcfg.IsRemoveAccess { 443 | // remove the old access token 444 | if err := m.tokenStore.RemoveByAccess(ctx, oldAccess); err != nil { 445 | return nil, err 446 | } 447 | } 448 | // only drop the old refresh token when a replacement was issued 449 | if rcfg.IsRemoveRefreshing && rv != "" { 450 | // remove the old refresh token 451 | if err := m.tokenStore.RemoveByRefresh(ctx, oldRefresh); err != nil { 452 | return nil, err 453 | } 454 | } 455 | // no new refresh token was issued: clear the refresh fields on the returned token (the record passed to Create above carried the old refresh token) 456 | if rv == "" { 457 | ti.SetRefresh("") 458 | ti.SetRefreshCreateAt(time.Now()) 459 | ti.SetRefreshExpiresIn(0) 460 | } 461 | 462 | return ti, nil 463 | } 464 | 465 | // RemoveAccessToken deletes the token information stored for the access token 466 | func (m *Manager) RemoveAccessToken(ctx context.Context, access string) error { 467 | if access == "" { 468 | return errors.ErrInvalidAccessToken 469 | } 470 | return m.tokenStore.RemoveByAccess(ctx, access) 471 | } 472 | 473 | // RemoveRefreshToken use the refresh token
to delete the token information 474 | func (m *Manager) RemoveRefreshToken(ctx context.Context, refresh string) error { 475 | if refresh == "" { 476 | return errors.ErrInvalidRefreshToken 477 | } 478 | return m.tokenStore.RemoveByRefresh(ctx, refresh) 479 | } 480 | 481 | // LoadAccessToken returns the token information for an access token; it rejects unknown tokens, expired access tokens, and access tokens whose linked refresh token has expired 482 | func (m *Manager) LoadAccessToken(ctx context.Context, access string) (oauth2.TokenInfo, error) { 483 | if access == "" { 484 | return nil, errors.ErrInvalidAccessToken 485 | } 486 | 487 | ct := time.Now() 488 | ti, err := m.tokenStore.GetByAccess(ctx, access) 489 | if err != nil { 490 | return nil, err 491 | } else if ti == nil || ti.GetAccess() != access { 492 | return nil, errors.ErrInvalidAccessToken 493 | } else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 && // an expired refresh token also invalidates its access token 494 | ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { 495 | return nil, errors.ErrExpiredRefreshToken 496 | } else if ti.GetAccessExpiresIn() != 0 && // 0 means the access token never expires 497 | ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { 498 | return nil, errors.ErrExpiredAccessToken 499 | } 500 | return ti, nil 501 | } 502 | 503 | // LoadRefreshToken returns the token information for a refresh token; it rejects unknown and expired refresh tokens 504 | func (m *Manager) LoadRefreshToken(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { 505 | if refresh == "" { 506 | return nil, errors.ErrInvalidRefreshToken 507 | } 508 | 509 | ti, err := m.tokenStore.GetByRefresh(ctx, refresh) 510 | if err != nil { 511 | return nil, err 512 | } else if ti == nil || ti.GetRefresh() != refresh { 513 | return nil, errors.ErrInvalidRefreshToken 514 | } else if ti.GetRefreshExpiresIn() != 0 && // 0 means the refresh token never expires 515 | ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) { 516 | return nil, errors.ErrExpiredRefreshToken 517 | } 518 | return ti, nil 519 | } 520 | 
-------------------------------------------------------------------------------- /manage/util.go: -------------------------------------------------------------------------------- 1 | package manage 2 | 3 | import ( 4 | "net/url" 5 | "strings" 6 | 7 | "github.com/go-oauth2/oauth2/v4" 8 | "github.com/go-oauth2/oauth2/v4/errors" 9 | ) 10 | 11 | type ( 12 | // ValidateURIHandler validates that redirectURI is contained in baseURI 13 | ValidateURIHandler func(baseURI, redirectURI string) error 14 | // ExtractExtensionHandler is called with the token generate request and the newly created token before the token is stored, so callers can attach extension values 15 | ExtractExtensionHandler func(*oauth2.TokenGenerateRequest, oauth2.ExtendableTokenInfo) 16 | ) 17 | 18 | // DefaultValidateURI validates that redirectURI is contained in baseURI: 19 | // the redirect host must equal the base host or be a subdomain of it. 20 | // If baseURI has no host component, every redirectURI is accepted. 21 | func DefaultValidateURI(baseURI string, redirectURI string) error { 22 | base, err := url.Parse(baseURI) 23 | if err != nil { 24 | return err 25 | } 26 | 27 | redirect, err := url.Parse(redirectURI) 28 | if err != nil { 29 | return err 30 | } 31 | if base.Host == "" { 32 | // NOTE(review): kept for backward compatibility; an empty base host 33 | // has always matched any redirect host, which is an open redirect. 34 | return nil 35 | } 36 | // Match on a label boundary: a bare suffix check would also accept 37 | // look-alike hosts such as "evilexample.com" for "example.com". 38 | if redirect.Host != base.Host && !strings.HasSuffix(redirect.Host, "."+base.Host) { 39 | return errors.ErrInvalidRedirectURI 40 | } 41 | return nil 42 | } 43 | -------------------------------------------------------------------------------- /manage/util_test.go: -------------------------------------------------------------------------------- 1 | package manage_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/go-oauth2/oauth2/v4/manage" 7 | 8 | . 
"github.com/smartystreets/goconvey/convey"
9 | )
10 | 
11 | func TestUtil(t *testing.T) {
12 | 	Convey("Util Test", t, func() {
13 | 		Convey("ValidateURI Test", func() {
14 | 			err := manage.DefaultValidateURI("http://www.example.com", "http://www.example.com/cb?code=xxx")
15 | 			So(err, ShouldBeNil)
16 | 		})
17 | 	})
18 | }
19 | 
--------------------------------------------------------------------------------
/model.go:
--------------------------------------------------------------------------------
1 | package oauth2
2 | 
3 | import (
4 | 	"net/url"
5 | 	"time"
6 | )
7 | 
8 | type (
9 | 	// ClientInfo is the client information model interface.
10 | 	ClientInfo interface {
11 | 		GetID() string
12 | 		GetSecret() string
13 | 		GetDomain() string
14 | 		IsPublic() bool
15 | 		GetUserID() string
16 | 	}
17 | 
18 | 	// ClientPasswordVerifier is the password handler interface.
19 | 	ClientPasswordVerifier interface {
20 | 		VerifyPassword(string) bool
21 | 	}
22 | 
23 | 	// TokenInfo is the token information model interface.
24 | 	TokenInfo interface {
25 | 		New() TokenInfo
26 | 
27 | 		GetClientID() string
28 | 		SetClientID(string)
29 | 		GetUserID() string
30 | 		SetUserID(string)
31 | 		GetRedirectURI() string
32 | 		SetRedirectURI(string)
33 | 		GetScope() string
34 | 		SetScope(string)
35 | 
36 | 		GetCode() string
37 | 		SetCode(string)
38 | 		GetCodeCreateAt() time.Time
39 | 		SetCodeCreateAt(time.Time)
40 | 		GetCodeExpiresIn() time.Duration
41 | 		SetCodeExpiresIn(time.Duration)
42 | 		GetCodeChallenge() string
43 | 		SetCodeChallenge(string)
44 | 		GetCodeChallengeMethod() CodeChallengeMethod
45 | 		SetCodeChallengeMethod(CodeChallengeMethod)
46 | 
47 | 		GetAccess() string
48 | 		SetAccess(string)
49 | 		GetAccessCreateAt() time.Time
50 | 		SetAccessCreateAt(time.Time)
51 | 		GetAccessExpiresIn() time.Duration
52 | 		SetAccessExpiresIn(time.Duration)
53 | 
54 | 		GetRefresh() string
55 | 		SetRefresh(string)
56 | 		GetRefreshCreateAt() time.Time
57 | 		SetRefreshCreateAt(time.Time)
58 | 		GetRefreshExpiresIn() time.Duration
59 | 		SetRefreshExpiresIn(time.Duration)
60 | 	}
61 | 
62 | 	// ExtendableTokenInfo is a TokenInfo that also carries extension values
63 | 	// (url.Values) alongside the standard token fields.
64 | 	ExtendableTokenInfo interface {
65 | 		TokenInfo
66 | 		GetExtension() url.Values
67 | 		SetExtension(url.Values)
68 | 	}
69 | )
70 | 
--------------------------------------------------------------------------------
/models/client.go:
--------------------------------------------------------------------------------
1 | package models
2 | 
3 | // Client is the client model; it implements oauth2.ClientInfo.
4 | type Client struct {
5 | 	ID     string
6 | 	Secret string
7 | 	Domain string
8 | 	Public bool
9 | 	UserID string
10 | }
11 | 
12 | // GetID returns the client id.
13 | func (c *Client) GetID() string {
14 | 	return c.ID
15 | }
16 | 
17 | // GetSecret returns the client secret.
18 | func (c *Client) GetSecret() string {
19 | 	return c.Secret
20 | }
21 | 
22 | // GetDomain returns the client domain.
23 | func (c *Client) GetDomain() string {
24 | 	return c.Domain
25 | }
26 | 
27 | // IsPublic reports whether the client is public.
28 | func (c *Client) IsPublic() bool {
29 | 	return c.Public
30 | }
31 | 
32 | // GetUserID returns the user id.
33 | func (c *Client) GetUserID() string {
34 | 	return c.UserID
35 | }
36 | 
--------------------------------------------------------------------------------
/models/token.go:
--------------------------------------------------------------------------------
1 | package models
2 | 
3 | import (
4 | 	"net/url"
5 | 	"time"
6 | 
7 | 	"github.com/go-oauth2/oauth2/v4"
8 | )
9 | 
10 | // NewToken creates a token model instance with a non-nil Extension map.
11 | func NewToken() *Token {
12 | 	return &Token{Extension: make(url.Values)}
13 | }
14 | 
15 | // Token is the token model; it implements oauth2.TokenInfo and oauth2.ExtendableTokenInfo.
16 | type Token struct {
17 | 	ClientID            string        `bson:"ClientID"`
18 | 	UserID              string        `bson:"UserID"`
19 | 	RedirectURI         string        `bson:"RedirectURI"`
20 | 	Scope               string        `bson:"Scope"`
21 | 	Code                string        `bson:"Code"`
22 | 	CodeChallenge       string        `bson:"CodeChallenge"`
23 | 	CodeChallengeMethod string        `bson:"CodeChallengeMethod"`
24 | 	CodeCreateAt        time.Time     `bson:"CodeCreateAt"`
25 | 	CodeExpiresIn       time.Duration `bson:"CodeExpiresIn"`
26 | 	Access              string        `bson:"Access"`
27 | 	AccessCreateAt      time.Time     `bson:"AccessCreateAt"`
28 | 	AccessExpiresIn     time.Duration `bson:"AccessExpiresIn"`
29 | 	Refresh             string        `bson:"Refresh"`
30 | 	RefreshCreateAt     time.Time     `bson:"RefreshCreateAt"`
31 | 	RefreshExpiresIn    time.Duration `bson:"RefreshExpiresIn"`
32 | 	Extension           url.Values    `bson:"Extension"`
33 | }
34 | 
35 | // New returns a new token model instance created by NewToken.
36 | func (t *Token) New() oauth2.TokenInfo {
37 | 	return NewToken()
38 | }
39 | 
40 | // GetClientID returns the client id.
41 | func (t *Token) GetClientID() string {
42 | 	return t.ClientID
43 | }
44 | 
45 | // SetClientID sets the client id.
46 | func (t *Token) SetClientID(clientID string) {
47 | 	t.ClientID = clientID
48 | }
49 | 
50 | // GetUserID returns the user id.
51 | func (t *Token) GetUserID() string {
52 | 	return t.UserID
53 | }
54 | 
55 | // SetUserID sets the user id.
56 | func (t *Token) SetUserID(userID string) {
57 | 	t.UserID = userID
58 | }
59 | 
60 | // GetRedirectURI returns the redirect URI.
61 | func (t *Token) GetRedirectURI() string {
62 | 	return t.RedirectURI
63 | }
64 | 
65 | // SetRedirectURI sets the redirect URI.
66 | func (t *Token) SetRedirectURI(redirectURI string) {
67 | 	t.RedirectURI = redirectURI
68 | }
69 | 
70 | // GetScope returns the scope of authorization.
71 | func (t *Token) GetScope() string {
72 | 	return t.Scope
73 | }
74 | 
75 | // SetScope sets the scope of authorization.
76 | func (t *Token) SetScope(scope string) {
77 | 	t.Scope = scope
78 | }
79 | 
80 | // GetCode returns the authorization code.
81 | func (t *Token) GetCode() string {
82 | 	return t.Code
83 | }
84 | 
85 | // SetCode sets the authorization code.
86 | func (t *Token) SetCode(code string) {
87 | 	t.Code = code
88 | }
89 | 
90 | // GetCodeCreateAt returns when the authorization code was created.
91 | func (t *Token) GetCodeCreateAt() time.Time {
92 | 	return t.CodeCreateAt
93 | }
94 | 
95 | // SetCodeCreateAt sets when the authorization code was created.
96 | func (t *Token) SetCodeCreateAt(createAt time.Time) {
97 | 	t.CodeCreateAt = createAt
98 | }
99 | 
100 | // GetCodeExpiresIn returns the lifetime of the authorization code as a time.Duration (not a count of seconds).
101 | func (t *Token) GetCodeExpiresIn() time.Duration {
102 | 	return t.CodeExpiresIn
103 | }
104 | 
105 | // SetCodeExpiresIn sets the lifetime of the authorization code as a time.Duration (not a count of seconds).
106 | func (t *Token) SetCodeExpiresIn(exp time.Duration) {
107 | 	t.CodeExpiresIn = exp
108 | }
109 | 
110 | // GetCodeChallenge returns the PKCE code challenge.
111 | func (t *Token) GetCodeChallenge() string {
112 | 	return t.CodeChallenge
113 | }
114 | 
115 | // SetCodeChallenge sets the PKCE code challenge.
116 | func (t *Token) SetCodeChallenge(code string) {
117 | 	t.CodeChallenge = code
118 | }
119 | 
120 | // GetCodeChallengeMethod returns the PKCE code challenge method.
121 | func (t *Token) GetCodeChallengeMethod() oauth2.CodeChallengeMethod {
122 | 	return oauth2.CodeChallengeMethod(t.CodeChallengeMethod)
123 | }
124 | 
125 | // SetCodeChallengeMethod sets the PKCE code challenge method.
126 | func (t *Token) SetCodeChallengeMethod(method oauth2.CodeChallengeMethod) {
127 | 	t.CodeChallengeMethod = string(method)
128 | }
129 | 
130 | // GetAccess returns the access token.
131 | func (t *Token) GetAccess() string {
132 | 	return t.Access
133 | }
134 | 
135 | // SetAccess sets the access token.
136 | func (t *Token) SetAccess(access string) {
137 | 	t.Access = access
138 | }
139 | 
140 | // GetAccessCreateAt returns when the access token was created.
141 | func (t *Token) GetAccessCreateAt() time.Time {
142 | 	return t.AccessCreateAt
143 | }
144 | 
145 | // SetAccessCreateAt sets when the access token was created.
146 | func (t *Token) SetAccessCreateAt(createAt time.Time) {
147 | 	t.AccessCreateAt = createAt
148 | }
149 | 
150 | // GetAccessExpiresIn returns the lifetime of the access token as a time.Duration (not a count of seconds).
151 | func (t *Token) GetAccessExpiresIn() time.Duration {
152 | 	return t.AccessExpiresIn
153 | }
154 | 
155 | // SetAccessExpiresIn sets the lifetime of the access token as a time.Duration (not a count of seconds).
156 | func (t *Token) SetAccessExpiresIn(exp time.Duration) {
157 | 	t.AccessExpiresIn = exp
158 | }
159 | 
160 | // GetRefresh returns the refresh token.
161 | func (t *Token) GetRefresh() string {
162 | 	return t.Refresh
163 | }
164 | 
165 | // SetRefresh sets the refresh token.
166 | func (t *Token) SetRefresh(refresh string) {
167 | 	t.Refresh = refresh
168 | }
169 | 
170 | // GetRefreshCreateAt returns when the refresh token was created.
171 | func (t *Token) GetRefreshCreateAt() time.Time {
172 | 	return t.RefreshCreateAt
173 | }
174 | 
175 | // SetRefreshCreateAt sets when the refresh token was created.
176 | func (t *Token) SetRefreshCreateAt(createAt time.Time) {
177 | 	t.RefreshCreateAt = createAt
178 | }
179 | 
180 | // GetRefreshExpiresIn returns the lifetime of the refresh token as a time.Duration (not a count of seconds).
181 | func (t *Token) GetRefreshExpiresIn() time.Duration {
182 | 	return t.RefreshExpiresIn
183 | }
184 | 
185 | // SetRefreshExpiresIn sets the lifetime of the refresh token as a time.Duration (not a count of seconds).
186 | func (t *Token) SetRefreshExpiresIn(exp time.Duration) {
187 | 	t.RefreshExpiresIn = exp
188 | }
189 | 
190 | // GetExtension returns the extension values of the token.
191 | func (t *Token) GetExtension() url.Values {
192 | 	return t.Extension
193 | }
194 | 
195 | // SetExtension sets the extension values of the token.
196 | func (t *Token) SetExtension(e url.Values) {
197 | 	t.Extension = e
198 | }
199 | 
--------------------------------------------------------------------------------
/server/config.go:
--------------------------------------------------------------------------------
1 | package server
2 | 
3 | import (
4 | 	"net/http"
5 | 	"time"
6 | 
7 | 	"github.com/go-oauth2/oauth2/v4"
8 | )
9 | 
10 | // Config holds the authorization server configuration parameters.
11 | type Config struct {
12 | 	TokenType                   string                       // token type
13 | 	AllowGetAccessRequest       bool                         // to allow GET requests for the token
14 | 	AllowedResponseTypes        []oauth2.ResponseType        // allow the authorization type
15 | 	AllowedGrantTypes           []oauth2.GrantType           // allow the grant type
16 | 	AllowedCodeChallengeMethods []oauth2.CodeChallengeMethod // allowed code_challenge_method values
17 | 	ForcePKCE                   bool                         // require code_challenge (authorize) and code_verifier (token)
18 | }
19 | 
20 | // NewConfig returns a Config with the default token type, response types, grant types and code challenge methods.
21 | func NewConfig() *Config {
22 | 	return &Config{
23 | 		TokenType:            "Bearer",
24 | 		AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code, oauth2.Token},
25 | 		AllowedGrantTypes: []oauth2.GrantType{
26 | 			oauth2.AuthorizationCode,
27 | 			oauth2.PasswordCredentials,
28 | 			oauth2.ClientCredentials,
29 | 			oauth2.Refreshing,
30 | 		},
31 | 		AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{
32 | 			oauth2.CodeChallengePlain,
33 | 			oauth2.CodeChallengeS256,
34 | 		},
35 | 	}
36 | }
37 | 
38 | // AuthorizeRequest is a parsed and validated authorization request.
39 | type AuthorizeRequest struct {
40 | 	ResponseType        oauth2.ResponseType
41 | 	ClientID            string
42 | 	Scope               string
43 | 	RedirectURI         string
44 | 	State               string
45 | 	UserID              string
46 | 	CodeChallenge       string
47 | 	CodeChallengeMethod oauth2.CodeChallengeMethod
48 | 	AccessTokenExp      time.Duration
49 | 	Request             *http.Request
50 | }
51 | 
--------------------------------------------------------------------------------
/server/handler.go:
--------------------------------------------------------------------------------
1 | package server
2 | 
3 | import (
4 | 	"context"
5 | 	"net/http"
6 | 	"strings"
7 | 	"time"
8 | 
9 | 	"github.com/go-oauth2/oauth2/v4"
10 | 	"github.com/go-oauth2/oauth2/v4/errors"
11 | )
12 | 
13 | type (
14 | 	// ClientInfoHandler extracts the client id and secret from the request.
15 | 	ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error)
16 | 
17 | 	// ClientAuthorizedHandler reports whether the client may use the given grant type.
18 | 	ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error)
19 | 
20 | 	// ClientScopeHandler reports whether the client may use the requested scope.
21 | 	ClientScopeHandler func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error)
22 | 
23 | 	// UserAuthorizationHandler returns the id of the user authorizing the request.
24 | 	UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error)
25 | 
26 | 	// PasswordAuthorizationHandler returns the user id for a username and password.
27 | 	PasswordAuthorizationHandler func(ctx context.Context, clientID, username, password string) (userID string, err error)
28 | 
29 | 	// RefreshingScopeHandler checks the scope requested on refresh against the token's old scope.
30 | 	RefreshingScopeHandler func(tgr *oauth2.TokenGenerateRequest, oldScope string) (allowed bool, err error)
31 | 
32 | 	// RefreshingValidationHandler checks that a refresh token is still usable, e.g. not revoked.
33 | 	RefreshingValidationHandler func(ti oauth2.TokenInfo) (allowed bool, err error)
34 | 
35 | 	// ResponseErrorHandler can inspect or modify an error response before it is written.
36 | 	ResponseErrorHandler func(re *errors.Response)
37 | 
38 | 	// InternalErrorHandler maps an error without a predefined description to an error response.
39 | 	InternalErrorHandler func(err error) (re *errors.Response)
40 | 
41 | 	// PreRedirectErrorHandler is used to override "redirect-on-error" behavior
42 | 	PreRedirectErrorHandler func(w http.ResponseWriter, req *AuthorizeRequest, err error) error
43 | 
44 | 	// AuthorizeScopeHandler sets the authorized scope; an empty result keeps the requested scope.
45 | 	AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error)
46 | 
47 | 	// AccessTokenExpHandler sets the lifetime of the access token.
48 | 	AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error)
49 | 
50 | 	// ExtensionFieldsHandler adds extra fields to the access token response; standard fields are never overwritten.
51 | 	ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{})
52 | 
53 | 	// ResponseTokenHandler writes token and token-error responses in place of the default JSON writer.
54 | 	ResponseTokenHandler func(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error
55 | 
56 | 	// RefreshTokenResolveHandler fetches the refresh token from the request.
57 | 	RefreshTokenResolveHandler func(r *http.Request) (string, error)
58 | 
59 | 	// AccessTokenResolveHandler fetches the access token from the request; false means none was found.
60 | 	AccessTokenResolveHandler func(r *http.Request) (string, bool)
61 | )
62 | 
63 | // ClientFormHandler gets client data from the client_id/client_secret form values, parsing the form if needed.
64 | func ClientFormHandler(r *http.Request) (string, string, error) {
65 | 	clientID := r.FormValue("client_id")
66 | 	if clientID == "" {
67 | 		return "", "", errors.ErrInvalidClient
68 | 	}
69 | 	clientSecret := r.FormValue("client_secret")
70 | 	return clientID, clientSecret, nil
71 | }
72 | 
73 | // ClientBasicHandler gets the client id and secret from HTTP basic authorization.
74 | func ClientBasicHandler(r *http.Request) (string, string, error) {
75 | 	username, password, ok := r.BasicAuth()
76 | 	if !ok {
77 | 		return "", "", errors.ErrInvalidClient
78 | 	}
79 | 	return username, password, nil
80 | }
81 | 
82 | // RefreshTokenFormResolveHandler reads the refresh_token form value; a missing value is an invalid request.
83 | func RefreshTokenFormResolveHandler(r *http.Request) (string, error) {
84 | 	rt := r.FormValue("refresh_token")
85 | 	if rt == "" {
86 | 		return "", errors.ErrInvalidRequest
87 | 	}
88 | 
89 | 	return rt, nil
90 | }
91 | 
92 | // RefreshTokenCookieResolveHandler reads the refresh_token cookie; a missing or empty cookie is an invalid request.
93 | func RefreshTokenCookieResolveHandler(r *http.Request) (string, error) {
94 | 	c, err := r.Cookie("refresh_token")
95 | 	if err != nil || c.Value == "" {
96 | 		return "", errors.ErrInvalidRequest
97 | 	}
98 | 
99 | 	return c.Value, nil
100 | }
101 | 
102 | // AccessTokenDefaultResolveHandler reads a Bearer token from the Authorization header,
103 | // falling back to the access_token form value when the header is not a Bearer credential.
104 | func AccessTokenDefaultResolveHandler(r *http.Request) (string, bool) {
105 | 	var token string
106 | 	auth := r.Header.Get("Authorization")
107 | 	prefix := "Bearer "
108 | 	// The auth-scheme is case-insensitive (RFC 7235 §2.1), so "bearer x" matches too.
109 | 	if len(auth) >= len(prefix) && strings.EqualFold(auth[:len(prefix)], prefix) {
110 | 		token = auth[len(prefix):]
111 | 	} else {
112 | 		token = r.FormValue("access_token")
113 | 	}
114 | 
115 | 	return token, token != ""
116 | }
117 | 
118 | // AccessTokenCookieResolveHandler reads the access_token cookie; a missing or empty cookie yields false.
119 | func AccessTokenCookieResolveHandler(r *http.Request) (string, bool) {
120 | 	c, err := r.Cookie("access_token")
121 | 	if err != nil {
122 | 		return "", false
123 | 	}
124 | 	// An empty cookie carries no token, matching AccessTokenDefaultResolveHandler.
125 | 	return c.Value, c.Value != ""
126 | }
127 | 
--------------------------------------------------------------------------------
/server/handler_test.go:
--------------------------------------------------------------------------------
1 | package server
2 | 
3 | import (
4 | 	"net/http"
5 | 	"net/http/httptest"
6 | 	"net/url"
7 | 	"strings"
8 | 	"testing"
9 | 	"time"
10 | 
11 | 	"github.com/go-oauth2/oauth2/v4/errors"
12 | 	. 
"github.com/smartystreets/goconvey/convey"
13 | )
14 | // TestRefreshTokenFormResolveHandler covers a url-encoded POST body with refresh_token and one without it.
15 | func TestRefreshTokenFormResolveHandler(t *testing.T) {
16 | 	Convey("Correct Request", t, func() {
17 | 		f := url.Values{}
18 | 		f.Add("refresh_token", "test_token")
19 | 
20 | 		r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
21 | 		r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
22 | 
23 | 		token, err := RefreshTokenFormResolveHandler(r)
24 | 		So(err, ShouldBeNil)
25 | 		So(token, ShouldEqual, "test_token")
26 | 	})
27 | 
28 | 	Convey("Missing Refresh Token", t, func() {
29 | 		r := httptest.NewRequest("POST", "/", nil) // nil body: the form has no refresh_token
30 | 		r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
31 | 
32 | 		token, err := RefreshTokenFormResolveHandler(r)
33 | 		So(err, ShouldBeError, errors.ErrInvalidRequest)
34 | 		So(token, ShouldBeEmpty)
35 | 	})
36 | }
37 | // TestRefreshTokenCookieResolveHandler covers a request with a refresh_token cookie and one without it.
38 | func TestRefreshTokenCookieResolveHandler(t *testing.T) {
39 | 	Convey("Correct Request", t, func() {
40 | 		r := httptest.NewRequest(http.MethodPost, "/", nil)
41 | 		r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
42 | 		r.AddCookie(&http.Cookie{
43 | 			Name:     "refresh_token",
44 | 			Value:    "test_token",
45 | 			HttpOnly: true,
46 | 			Path:     "/",
47 | 			Domain:   ".example.com",
48 | 			Expires:  time.Now().Add(time.Hour),
49 | 		})
50 | 
51 | 		token, err := RefreshTokenCookieResolveHandler(r)
52 | 		So(err, ShouldBeNil)
53 | 		So(token, ShouldEqual, "test_token")
54 | 	})
55 | 
56 | 	Convey("Missing Refresh Token", t, func() {
57 | 		r := httptest.NewRequest("POST", "/", nil) // no cookie attached
58 | 		r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
59 | 
60 | 		token, err := RefreshTokenCookieResolveHandler(r)
61 | 		So(err, ShouldBeError, errors.ErrInvalidRequest)
62 | 		So(token, ShouldBeEmpty)
63 | 	})
64 | }
65 | // TestAccessTokenDefaultHandler covers a Bearer Authorization header, an access_token form value, and neither.
66 | func TestAccessTokenDefaultHandler(t *testing.T) {
67 | 	Convey("Request Has Header", t, func() {
68 | 		r := httptest.NewRequest(http.MethodPost, "/", nil)
69 | 		r.Header.Add("Authorization", "Bearer test_token")
70 | 
71 | 		token, ok := AccessTokenDefaultResolveHandler(r)
72 | 		So(ok, ShouldBeTrue)
73 | 		So(token, ShouldEqual, "test_token")
74 | 	})
75 | 
76 | 	Convey("Request Has FormValue", t, func() {
77 | 		f := url.Values{}
78 | 		f.Add("access_token", "test_token")
79 | 		r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
80 | 		r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
81 | 
82 | 		token, ok := AccessTokenDefaultResolveHandler(r)
83 | 		So(ok, ShouldBeTrue)
84 | 		So(token, ShouldEqual, "test_token")
85 | 	})
86 | 
87 | 	Convey("Request Has Nothing", t, func() {
88 | 		r := httptest.NewRequest(http.MethodPost, "/", nil) // no header and no body
89 | 
90 | 		token, ok := AccessTokenDefaultResolveHandler(r)
91 | 		So(ok, ShouldBeFalse)
92 | 		So(token, ShouldBeEmpty)
93 | 	})
94 | }
95 | // TestAccessTokenCookieHandler covers a request with an access_token cookie and one without it.
96 | func TestAccessTokenCookieHandler(t *testing.T) {
97 | 	Convey("Request Has Cookie", t, func() {
98 | 		r := httptest.NewRequest(http.MethodPost, "/", nil)
99 | 		r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
100 | 		r.AddCookie(&http.Cookie{
101 | 			Name:     "access_token",
102 | 			Value:    "test_token",
103 | 			HttpOnly: true,
104 | 			Path:     "/",
105 | 			Domain:   ".example.com",
106 | 			Expires:  time.Now().Add(time.Hour),
107 | 		})
108 | 
109 | 		token, ok := AccessTokenCookieResolveHandler(r)
110 | 		So(ok, ShouldBeTrue)
111 | 		So(token, ShouldEqual, "test_token")
112 | 	})
113 | 
114 | 	Convey("Request Has No Cookie", t, func() {
115 | 		r := httptest.NewRequest(http.MethodPost, "/", nil) // no cookie attached
116 | 
117 | 		token, ok := AccessTokenCookieResolveHandler(r)
118 | 		So(ok, ShouldBeFalse)
119 | 		So(token, ShouldBeEmpty)
120 | 	})
121 | }
122 | 
--------------------------------------------------------------------------------
/server/server.go:
--------------------------------------------------------------------------------
1 | package server
2 | 
3 | import (
4 | 	"context"
5 | 	"encoding/json"
6 | 	"fmt"
7 | 	"net/http"
8 | 	"net/url"
9 | 	"time"
10 | 
11 | 	"github.com/go-oauth2/oauth2/v4"
12 | 	"github.com/go-oauth2/oauth2/v4/errors"
13 | )
14 | 
15
| // NewDefaultServer create a default authorization server 16 | func NewDefaultServer(manager oauth2.Manager) *Server { 17 | return NewServer(NewConfig(), manager) 18 | } 19 | 20 | // NewServer create authorization server 21 | func NewServer(cfg *Config, manager oauth2.Manager) *Server { 22 | srv := &Server{ 23 | Config: cfg, 24 | Manager: manager, 25 | } 26 | 27 | // default handlers 28 | srv.ClientInfoHandler = ClientBasicHandler 29 | srv.RefreshTokenResolveHandler = RefreshTokenFormResolveHandler 30 | srv.AccessTokenResolveHandler = AccessTokenDefaultResolveHandler 31 | 32 | srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { 33 | return "", errors.ErrAccessDenied 34 | } 35 | 36 | srv.PasswordAuthorizationHandler = func(ctx context.Context, clientID, username, password string) (string, error) { 37 | return "", errors.ErrAccessDenied 38 | } 39 | return srv 40 | } 41 | 42 | // Server Provide authorization server 43 | type Server struct { 44 | Config *Config 45 | Manager oauth2.Manager 46 | ClientInfoHandler ClientInfoHandler 47 | ClientAuthorizedHandler ClientAuthorizedHandler 48 | ClientScopeHandler ClientScopeHandler 49 | UserAuthorizationHandler UserAuthorizationHandler 50 | PasswordAuthorizationHandler PasswordAuthorizationHandler 51 | RefreshingValidationHandler RefreshingValidationHandler 52 | PreRedirectErrorHandler PreRedirectErrorHandler 53 | RefreshingScopeHandler RefreshingScopeHandler 54 | ResponseErrorHandler ResponseErrorHandler 55 | InternalErrorHandler InternalErrorHandler 56 | ExtensionFieldsHandler ExtensionFieldsHandler 57 | AccessTokenExpHandler AccessTokenExpHandler 58 | AuthorizeScopeHandler AuthorizeScopeHandler 59 | ResponseTokenHandler ResponseTokenHandler 60 | RefreshTokenResolveHandler RefreshTokenResolveHandler 61 | AccessTokenResolveHandler AccessTokenResolveHandler 62 | } 63 | 64 | func (s *Server) handleError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { 65 | if fn := 
s.PreRedirectErrorHandler; fn != nil { 66 | return fn(w, req, err) 67 | } 68 | 69 | return s.redirectError(w, req, err) 70 | } 71 | 72 | func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { 73 | if req == nil { 74 | return err 75 | } 76 | 77 | data, _, _ := s.GetErrorData(err) 78 | return s.redirect(w, req, data) 79 | } 80 | 81 | func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error { 82 | uri, err := s.GetRedirectURI(req, data) 83 | if err != nil { 84 | return err 85 | } 86 | 87 | w.Header().Set("Location", uri) 88 | w.WriteHeader(302) 89 | return nil 90 | } 91 | 92 | func (s *Server) tokenError(w http.ResponseWriter, err error) error { 93 | data, statusCode, header := s.GetErrorData(err) 94 | return s.token(w, data, header, statusCode) 95 | } 96 | 97 | func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { 98 | if fn := s.ResponseTokenHandler; fn != nil { 99 | return fn(w, data, header, statusCode...) 
100 | } 101 | w.Header().Set("Content-Type", "application/json;charset=UTF-8") 102 | w.Header().Set("Cache-Control", "no-store") 103 | w.Header().Set("Pragma", "no-cache") 104 | 105 | for key := range header { 106 | w.Header().Set(key, header.Get(key)) 107 | } 108 | 109 | status := http.StatusOK 110 | if len(statusCode) > 0 && statusCode[0] > 0 { 111 | status = statusCode[0] 112 | } 113 | 114 | w.WriteHeader(status) 115 | return json.NewEncoder(w).Encode(data) 116 | } 117 | 118 | // GetRedirectURI get redirect uri 119 | func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) { 120 | u, err := url.Parse(req.RedirectURI) 121 | if err != nil { 122 | return "", err 123 | } 124 | 125 | q := u.Query() 126 | if req.State != "" { 127 | q.Set("state", req.State) 128 | } 129 | 130 | for k, v := range data { 131 | q.Set(k, fmt.Sprint(v)) 132 | } 133 | 134 | switch req.ResponseType { 135 | case oauth2.Code: 136 | u.RawQuery = q.Encode() 137 | case oauth2.Token: 138 | u.RawQuery = "" 139 | fragment, err := url.QueryUnescape(q.Encode()) 140 | if err != nil { 141 | return "", err 142 | } 143 | u.Fragment = fragment 144 | } 145 | 146 | return u.String(), nil 147 | } 148 | 149 | // CheckResponseType check allows response type 150 | func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { 151 | for _, art := range s.Config.AllowedResponseTypes { 152 | if art == rt { 153 | return true 154 | } 155 | } 156 | return false 157 | } 158 | 159 | // CheckCodeChallengeMethod checks for allowed code challenge method 160 | func (s *Server) CheckCodeChallengeMethod(ccm oauth2.CodeChallengeMethod) bool { 161 | for _, c := range s.Config.AllowedCodeChallengeMethods { 162 | if c == ccm { 163 | return true 164 | } 165 | } 166 | return false 167 | } 168 | 169 | // ValidationAuthorizeRequest the authorization request validation 170 | func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { 171 | redirectURI := 
r.FormValue("redirect_uri") 172 | clientID := r.FormValue("client_id") 173 | if !(r.Method == "GET" || r.Method == "POST") || 174 | clientID == "" { 175 | return nil, errors.ErrInvalidRequest 176 | } 177 | 178 | resType := oauth2.ResponseType(r.FormValue("response_type")) 179 | if resType.String() == "" { 180 | return nil, errors.ErrUnsupportedResponseType 181 | } else if allowed := s.CheckResponseType(resType); !allowed { 182 | return nil, errors.ErrUnauthorizedClient 183 | } 184 | 185 | cc := r.FormValue("code_challenge") 186 | if cc == "" && s.Config.ForcePKCE { 187 | return nil, errors.ErrCodeChallengeRquired 188 | } 189 | if cc != "" && (len(cc) < 43 || len(cc) > 128) { 190 | return nil, errors.ErrInvalidCodeChallengeLen 191 | } 192 | 193 | ccm := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method")) 194 | // set default 195 | if ccm == "" { 196 | ccm = oauth2.CodeChallengePlain 197 | } 198 | if ccm != "" && !s.CheckCodeChallengeMethod(ccm) { 199 | return nil, errors.ErrUnsupportedCodeChallengeMethod 200 | } 201 | 202 | req := &AuthorizeRequest{ 203 | RedirectURI: redirectURI, 204 | ResponseType: resType, 205 | ClientID: clientID, 206 | State: r.FormValue("state"), 207 | Scope: r.FormValue("scope"), 208 | Request: r, 209 | CodeChallenge: cc, 210 | CodeChallengeMethod: ccm, 211 | } 212 | return req, nil 213 | } 214 | 215 | // GetAuthorizeToken get authorization token(code) 216 | func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) { 217 | // check the client allows the grant type 218 | if fn := s.ClientAuthorizedHandler; fn != nil { 219 | gt := oauth2.AuthorizationCode 220 | if req.ResponseType == oauth2.Token { 221 | gt = oauth2.Implicit 222 | } 223 | 224 | allowed, err := fn(req.ClientID, gt) 225 | if err != nil { 226 | return nil, err 227 | } else if !allowed { 228 | return nil, errors.ErrUnauthorizedClient 229 | } 230 | } 231 | 232 | tgr := &oauth2.TokenGenerateRequest{ 233 | ClientID: 
req.ClientID, 234 | UserID: req.UserID, 235 | RedirectURI: req.RedirectURI, 236 | Scope: req.Scope, 237 | AccessTokenExp: req.AccessTokenExp, 238 | Request: req.Request, 239 | } 240 | 241 | // check the client allows the authorized scope 242 | if fn := s.ClientScopeHandler; fn != nil { 243 | allowed, err := fn(tgr) 244 | if err != nil { 245 | return nil, err 246 | } else if !allowed { 247 | return nil, errors.ErrInvalidScope 248 | } 249 | } 250 | 251 | tgr.CodeChallenge = req.CodeChallenge 252 | tgr.CodeChallengeMethod = req.CodeChallengeMethod 253 | 254 | return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) 255 | } 256 | 257 | // GetAuthorizeData get authorization response data 258 | func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} { 259 | if rt == oauth2.Code { 260 | return map[string]interface{}{ 261 | "code": ti.GetCode(), 262 | } 263 | } 264 | return s.GetTokenData(ti) 265 | } 266 | 267 | // HandleAuthorizeRequest the authorization request handling 268 | func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { 269 | ctx := r.Context() 270 | 271 | req, err := s.ValidationAuthorizeRequest(r) 272 | if err != nil { 273 | return s.handleError(w, req, err) 274 | } 275 | 276 | // user authorization 277 | userID, err := s.UserAuthorizationHandler(w, r) 278 | if err != nil { 279 | return s.handleError(w, req, err) 280 | } else if userID == "" { 281 | return nil 282 | } 283 | req.UserID = userID 284 | 285 | // specify the scope of authorization 286 | if fn := s.AuthorizeScopeHandler; fn != nil { 287 | scope, err := fn(w, r) 288 | if err != nil { 289 | return err 290 | } else if scope != "" { 291 | req.Scope = scope 292 | } 293 | } 294 | 295 | // specify the expiration time of access token 296 | if fn := s.AccessTokenExpHandler; fn != nil { 297 | exp, err := fn(w, r) 298 | if err != nil { 299 | return err 300 | } 301 | req.AccessTokenExp = exp 302 | } 303 | 304 | ti, err := 
s.GetAuthorizeToken(ctx, req) 305 | if err != nil { 306 | return s.handleError(w, req, err) 307 | } 308 | 309 | // If the redirect URI is empty, the default domain provided by the client is used. 310 | if req.RedirectURI == "" { 311 | client, err := s.Manager.GetClient(ctx, req.ClientID) 312 | if err != nil { 313 | return err 314 | } 315 | req.RedirectURI = client.GetDomain() 316 | } 317 | 318 | return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) 319 | } 320 | 321 | // ValidationTokenRequest the token request validation 322 | func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) { 323 | if v := r.Method; !(v == "POST" || 324 | (s.Config.AllowGetAccessRequest && v == "GET")) { 325 | return "", nil, errors.ErrInvalidRequest 326 | } 327 | 328 | gt := oauth2.GrantType(r.FormValue("grant_type")) 329 | if gt.String() == "" { 330 | return "", nil, errors.ErrUnsupportedGrantType 331 | } 332 | 333 | clientID, clientSecret, err := s.ClientInfoHandler(r) 334 | if err != nil { 335 | return "", nil, err 336 | } 337 | 338 | tgr := &oauth2.TokenGenerateRequest{ 339 | ClientID: clientID, 340 | ClientSecret: clientSecret, 341 | Request: r, 342 | } 343 | 344 | switch gt { 345 | case oauth2.AuthorizationCode: 346 | tgr.RedirectURI = r.FormValue("redirect_uri") 347 | tgr.Code = r.FormValue("code") 348 | if tgr.RedirectURI == "" || 349 | tgr.Code == "" { 350 | return "", nil, errors.ErrInvalidRequest 351 | } 352 | tgr.CodeVerifier = r.FormValue("code_verifier") 353 | if s.Config.ForcePKCE && tgr.CodeVerifier == "" { 354 | return "", nil, errors.ErrInvalidRequest 355 | } 356 | case oauth2.PasswordCredentials: 357 | tgr.Scope = r.FormValue("scope") 358 | username, password := r.FormValue("username"), r.FormValue("password") 359 | if username == "" || password == "" { 360 | return "", nil, errors.ErrInvalidRequest 361 | } 362 | 363 | userID, err := s.PasswordAuthorizationHandler(r.Context(), clientID, username, 
password) 364 | if err != nil { 365 | return "", nil, err 366 | } else if userID == "" { 367 | return "", nil, errors.ErrInvalidGrant 368 | } 369 | tgr.UserID = userID 370 | case oauth2.ClientCredentials: 371 | tgr.Scope = r.FormValue("scope") 372 | case oauth2.Refreshing: 373 | tgr.Refresh, err = s.RefreshTokenResolveHandler(r) 374 | tgr.Scope = r.FormValue("scope") 375 | if err != nil { 376 | return "", nil, err 377 | } 378 | } 379 | return gt, tgr, nil 380 | } 381 | 382 | // CheckGrantType reports whether gt is listed in s.Config.AllowedGrantTypes. 383 | func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { 384 | for _, agt := range s.Config.AllowedGrantTypes { 385 | if agt == gt { 386 | return true 387 | } 388 | } 389 | return false 390 | } 391 | 392 | // GetAccessToken issues (or refreshes) a token for gt. It first checks that the grant type is allowed by the server config and by ClientAuthorizedHandler, then runs the grant-specific scope and validation hooks. 393 | func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, 394 | error) { 395 | if allowed := s.CheckGrantType(gt); !allowed { 396 | return nil, errors.ErrUnauthorizedClient 397 | } 398 | 399 | if fn := s.ClientAuthorizedHandler; fn != nil { 400 | allowed, err := fn(tgr.ClientID, gt) 401 | if err != nil { 402 | return nil, err 403 | } else if !allowed { 404 | return nil, errors.ErrUnauthorizedClient 405 | } 406 | } 407 | 408 | switch gt { 409 | case oauth2.AuthorizationCode: 410 | ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) 411 | if err != nil { 412 | switch err { 413 | case errors.ErrInvalidAuthorizeCode, errors.ErrInvalidCodeChallenge, errors.ErrMissingCodeChallenge: // code and PKCE challenge failures are reported as ErrInvalidGrant 414 | return nil, errors.ErrInvalidGrant 415 | case errors.ErrInvalidClient: 416 | return nil, errors.ErrInvalidClient 417 | default: 418 | return nil, err 419 | } 420 | } 421 | return ti, nil 422 | case oauth2.PasswordCredentials, oauth2.ClientCredentials: 423 | if fn := s.ClientScopeHandler; fn != nil { 424 | allowed, err := fn(tgr) 425 | if err != nil { 426 | return nil, err 427 | } else if !allowed { 428 | return nil, errors.ErrInvalidScope 429 | } 430 | } 431 
| return s.Manager.GenerateAccessToken(ctx, gt, tgr) 432 | case oauth2.Refreshing: 433 | // Optional scope check: runs only when a scope was requested and RefreshingScopeHandler is set. 434 | if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil { 435 | rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) 436 | if err != nil { 437 | if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { 438 | return nil, errors.ErrInvalidGrant 439 | } 440 | return nil, err 441 | } 442 | 443 | allowed, err := scopeFn(tgr, rti.GetScope()) 444 | if err != nil { 445 | return nil, err 446 | } else if !allowed { 447 | return nil, errors.ErrInvalidScope 448 | } 449 | } 450 | 451 | if validationFn := s.RefreshingValidationHandler; validationFn != nil { // NOTE(review): loads the refresh token again if the scope check above already loaded it 452 | rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) 453 | if err != nil { 454 | if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { 455 | return nil, errors.ErrInvalidGrant 456 | } 457 | return nil, err 458 | } 459 | allowed, err := validationFn(rti) 460 | if err != nil { 461 | return nil, err 462 | } else if !allowed { 463 | return nil, errors.ErrInvalidScope // NOTE(review): RFC 6749 section 5.2 maps a rejected (e.g. revoked) refresh token to invalid_grant; confirm invalid_scope is intended before changing, since the error is client-visible 464 | } 465 | } 466 | 467 | ti, err := s.Manager.RefreshAccessToken(ctx, tgr) 468 | if err != nil { 469 | if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { 470 | return nil, errors.ErrInvalidGrant 471 | } 472 | return nil, err 473 | } 474 | return ti, nil 475 | } 476 | 477 | return nil, errors.ErrUnsupportedGrantType // allowed grant types that have no case in the switch above 478 | } 479 | 480 | // GetTokenData builds the token response body. It always sets access_token, token_type and expires_in (whole seconds), adds scope and refresh_token when non-empty, and adds any ExtensionFieldsHandler fields whose keys are not already set. 481 | func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} { 482 | data := map[string]interface{}{ 483 | "access_token": ti.GetAccess(), 484 | "token_type": s.Config.TokenType, 485 | "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), 486 | } 487 | 488 | if scope := ti.GetScope(); scope != "" { 489 | data["scope"] = scope 490 | } 491 | 492 | if refresh := ti.GetRefresh(); refresh != "" { 493 | data["refresh_token"] = refresh 494 | } 495 | 496 | if fn := 
s.ExtensionFieldsHandler; fn != nil { 497 | ext := fn(ti) 498 | for k, v := range ext { 499 | if _, ok := data[k]; ok { // never overwrite the standard fields 500 | continue 501 | } 502 | data[k] = v 503 | } 504 | } 505 | return data 506 | } 507 | 508 | // HandleTokenRequest validates the token request and obtains the token via GetAccessToken. It writes the token response, or on failure the error via tokenError. 509 | func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { 510 | ctx := r.Context() 511 | 512 | gt, tgr, err := s.ValidationTokenRequest(r) 513 | if err != nil { 514 | return s.tokenError(w, err) 515 | } 516 | 517 | ti, err := s.GetAccessToken(ctx, gt, tgr) 518 | if err != nil { 519 | return s.tokenError(w, err) 520 | } 521 | 522 | return s.token(w, s.GetTokenData(ti), nil) 523 | } 524 | 525 | // GetErrorData converts err into the error response body, HTTP status and headers. Errors without a standard description go through InternalErrorHandler and otherwise become ErrServerError. ResponseErrorHandler may adjust the result last. 526 | func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { 527 | var re errors.Response 528 | if v, ok := errors.Descriptions[err]; ok { 529 | re.Error = err 530 | re.Description = v 531 | re.StatusCode = errors.StatusCodes[err] 532 | } else { 533 | if fn := s.InternalErrorHandler; fn != nil { 534 | if v := fn(err); v != nil { 535 | re = *v 536 | } 537 | } 538 | 539 | if re.Error == nil { 540 | re.Error = errors.ErrServerError 541 | re.Description = errors.Descriptions[errors.ErrServerError] 542 | re.StatusCode = errors.StatusCodes[errors.ErrServerError] 543 | } 544 | } 545 | 546 | if fn := s.ResponseErrorHandler; fn != nil { 547 | fn(&re) 548 | } 549 | 550 | data := make(map[string]interface{}) 551 | if err := re.Error; err != nil { 552 | data["error"] = err.Error() 553 | } 554 | 555 | if v := re.ErrorCode; v != 0 { 556 | data["error_code"] = v 557 | } 558 | 559 | if v := re.Description; v != "" { 560 | data["error_description"] = v 561 | } 562 | 563 | if v := re.URI; v != "" { 564 | data["error_uri"] = v 565 | } 566 | 567 | statusCode := http.StatusInternalServerError // used unless the response carries a positive StatusCode 568 | if v := re.StatusCode; v > 0 { 569 | statusCode = v 570 | } 571 | 572 | return data, statusCode, re.Header 573 | } 574 | 575 | // 
ValidationBearerToken extracts the access token with AccessTokenResolveHandler (ErrInvalidAccessToken if none is found) and loads its token information from the Manager. 576 | // https://tools.ietf.org/html/rfc6750 577 | func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { 578 | ctx := r.Context() 579 | 580 | accessToken, ok := s.AccessTokenResolveHandler(r) 581 | if !ok { 582 | return nil, errors.ErrInvalidAccessToken 583 | } 584 | 585 | return s.Manager.LoadAccessToken(ctx, accessToken) 586 | } 587 | -------------------------------------------------------------------------------- /server/server_config.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/go-oauth2/oauth2/v4" 5 | ) 6 | 7 | // SetTokenType sets the token_type value returned in token responses. 8 | func (s *Server) SetTokenType(tokenType string) { 9 | s.Config.TokenType = tokenType 10 | } 11 | 12 | // SetAllowGetAccessRequest sets whether GET requests for the token are allowed. 13 | func (s *Server) SetAllowGetAccessRequest(allow bool) { 14 | s.Config.AllowGetAccessRequest = allow 15 | } 16 | 17 | // SetAllowedResponseType replaces the allowed authorization response types. 18 | func (s *Server) SetAllowedResponseType(types ...oauth2.ResponseType) { 19 | s.Config.AllowedResponseTypes = types 20 | } 21 | 22 | // SetAllowedGrantType replaces the grant types accepted by CheckGrantType. 23 | func (s *Server) SetAllowedGrantType(types ...oauth2.GrantType) { 24 | s.Config.AllowedGrantTypes = types 25 | } 26 | 27 | // SetClientInfoHandler sets the handler that reads client info from a request. 28 | func (s *Server) SetClientInfoHandler(handler ClientInfoHandler) { 29 | s.ClientInfoHandler = handler 30 | } 31 | 32 | // SetClientAuthorizedHandler sets the hook GetAccessToken uses to check whether a client may use a grant type. 33 | func (s *Server) SetClientAuthorizedHandler(handler ClientAuthorizedHandler) { 34 | s.ClientAuthorizedHandler = handler 35 | } 36 | 37 | // SetClientScopeHandler sets the hook that approves the requested scope for password and client_credentials grants. 38 | func (s *Server) SetClientScopeHandler(handler ClientScopeHandler) { 39 | s.ClientScopeHandler = handler 40 | } 41 | 42 | // 
SetUserAuthorizationHandler sets the handler that resolves the user ID from an authorization request. 43 | func (s *Server) SetUserAuthorizationHandler(handler UserAuthorizationHandler) { 44 | s.UserAuthorizationHandler = handler 45 | } 46 | 47 | // SetPasswordAuthorizationHandler sets the handler that resolves the user ID from a username and password; an empty user ID is rejected as ErrInvalidGrant. 48 | func (s *Server) SetPasswordAuthorizationHandler(handler PasswordAuthorizationHandler) { 49 | s.PasswordAuthorizationHandler = handler 50 | } 51 | 52 | // SetRefreshingScopeHandler sets the hook that checks a requested scope against the refresh token's original scope. 53 | func (s *Server) SetRefreshingScopeHandler(handler RefreshingScopeHandler) { 54 | s.RefreshingScopeHandler = handler 55 | } 56 | 57 | // SetRefreshingValidationHandler sets the hook that decides whether a refresh token may still be used (e.g. it has not been revoked). 58 | func (s *Server) SetRefreshingValidationHandler(handler RefreshingValidationHandler) { 59 | s.RefreshingValidationHandler = handler 60 | } 61 | 62 | // SetResponseErrorHandler sets the hook that may modify every error response before it is serialized. 63 | func (s *Server) SetResponseErrorHandler(handler ResponseErrorHandler) { 64 | s.ResponseErrorHandler = handler 65 | } 66 | 67 | // SetInternalErrorHandler sets the hook that maps non-standard errors to a response; returning nil falls back to ErrServerError. 68 | func (s *Server) SetInternalErrorHandler(handler InternalErrorHandler) { 69 | s.InternalErrorHandler = handler 70 | } 71 | 72 | // SetPreRedirectErrorHandler sets the PreRedirectErrorHandler in current Server instance 73 | func (s *Server) SetPreRedirectErrorHandler(handler PreRedirectErrorHandler) { 74 | s.PreRedirectErrorHandler = handler 75 | } 76 | 77 | // SetExtensionFieldsHandler sets the hook that adds extra fields to token responses; standard fields are never overwritten. 78 | func (s *Server) SetExtensionFieldsHandler(handler ExtensionFieldsHandler) { 79 | s.ExtensionFieldsHandler = handler 80 | } 81 | 82 | // SetAccessTokenExpHandler sets the handler that decides the access token's expiration. 83 | func (s *Server) SetAccessTokenExpHandler(handler AccessTokenExpHandler) { 84 | s.AccessTokenExpHandler = handler 85 | } 86 | 87 | // SetAuthorizeScopeHandler sets the handler that decides the scope for the access token.
88 | func (s *Server) SetAuthorizeScopeHandler(handler AuthorizeScopeHandler) { 89 | s.AuthorizeScopeHandler = handler 90 | } 91 | 92 | // SetResponseTokenHandler response token handling 93 | func (s *Server) SetResponseTokenHandler(handler ResponseTokenHandler) { 94 | s.ResponseTokenHandler = handler 95 | } 96 | 97 | // SetRefreshTokenResolveHandler sets how the refresh token is read from a refresh_token grant request. 98 | func (s *Server) SetRefreshTokenResolveHandler(handler RefreshTokenResolveHandler) { 99 | s.RefreshTokenResolveHandler = handler 100 | } 101 | 102 | // SetAccessTokenResolveHandler sets how ValidationBearerToken extracts the access token from a request. 103 | func (s *Server) SetAccessTokenResolveHandler(handler AccessTokenResolveHandler) { 104 | s.AccessTokenResolveHandler = handler 105 | } 106 | -------------------------------------------------------------------------------- /server/server_test.go: -------------------------------------------------------------------------------- 1 | package server_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | 10 | "github.com/gavv/httpexpect" 11 | "github.com/go-oauth2/oauth2/v4" 12 | "github.com/go-oauth2/oauth2/v4/errors" 13 | "github.com/go-oauth2/oauth2/v4/manage" 14 | "github.com/go-oauth2/oauth2/v4/models" 15 | "github.com/go-oauth2/oauth2/v4/server" 16 | "github.com/go-oauth2/oauth2/v4/store" 17 | ) 18 | 19 | var ( 20 | srv *server.Server 21 | tsrv *httptest.Server 22 | manager *manage.Manager 23 | csrv *httptest.Server 24 | clientID = "111111" 25 | clientSecret = "11111111" 26 | 27 | plainChallenge = "ThisIsAFourtyThreeCharactersLongStringThing" 28 | s256Challenge = "s256tests256tests256tests256tests256tests256test" 29 | // sha2562 := sha256.Sum256([]byte(s256Challenge)) 30 | // fmt.Printf(base64.URLEncoding.EncodeToString(sha2562[:])) 31 | s256ChallengeHash = "To2Xqv01cm16bC9Sf7KRRS8CO2SFss_HSMQOr3sdCDE=" 32 | ) 33 | // init creates the single manager, backed by an in-memory token store, that all tests share. 34 | func init() { 35 | manager = manage.NewDefaultManager() 36 | manager.MustTokenStorage(store.NewMemoryTokenStore()) 37 | }
38 | // clientStore returns a ClientStore holding the single test client; a public client gets an empty secret. 39 | func clientStore(domain string, public bool) oauth2.ClientStore { 40 | clientStore := store.NewClientStore() 41 | var secret string 42 | if public { 43 | secret = "" 44 | } else { 45 | secret = clientSecret 46 | } 47 | clientStore.Set(clientID, &models.Client{ 48 | ID: clientID, 49 | Secret: secret, 50 | Domain: domain, 51 | Public: public, 52 | }) 53 | return clientStore 54 | } 55 | // testServer routes /authorize and /token to the shared srv and reports handler errors via t. 56 | func testServer(t *testing.T, w http.ResponseWriter, r *http.Request) { 57 | switch r.URL.Path { 58 | case "/authorize": 59 | err := srv.HandleAuthorizeRequest(w, r) 60 | if err != nil { 61 | t.Error(err) 62 | } 63 | case "/token": 64 | err := srv.HandleTokenRequest(w, r) 65 | if err != nil { 66 | t.Error(err) 67 | } 68 | } 69 | } 70 | 71 | func TestAuthorizeCode(t *testing.T) { 72 | tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 73 | testServer(t, w, r) 74 | })) 75 | defer tsrv.Close() 76 | 77 | e := httpexpect.New(t, tsrv.URL) 78 | 79 | csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 80 | switch r.URL.Path { 81 | case "/oauth2": 82 | r.ParseForm() 83 | code, state := r.Form.Get("code"), r.Form.Get("state") 84 | if state != "123" { 85 | t.Error("unrecognized state:", state) 86 | return 87 | } 88 | resObj := e.POST("/token"). 89 | WithFormField("redirect_uri", csrv.URL+"/oauth2"). 90 | WithFormField("code", code). 91 | WithFormField("grant_type", "authorization_code"). 92 | WithFormField("client_id", clientID). 93 | WithBasicAuth(clientID, clientSecret). 94 | Expect(). 95 | Status(http.StatusOK). 
96 | JSON().Object() 97 | 98 | t.Logf("%#v\n", resObj.Raw()) 99 | 100 | validationAccessToken(t, resObj.Value("access_token").String().Raw()) 101 | } 102 | })) 103 | defer csrv.Close() 104 | 105 | manager.MapClientStorage(clientStore(csrv.URL, true)) 106 | srv = server.NewDefaultServer(manager) 107 | srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { 108 | userID = "000000" 109 | return 110 | }) 111 | // the authorization response redirects to csrv's /oauth2 handler, which exchanges the code above 112 | e.GET("/authorize"). 113 | WithQuery("response_type", "code"). 114 | WithQuery("client_id", clientID). 115 | WithQuery("scope", "all"). 116 | WithQuery("state", "123"). 117 | WithQuery("redirect_uri", csrv.URL+"/oauth2"). 118 | Expect().Status(http.StatusOK) 119 | } 120 | 121 | func TestAuthorizeCodeWithChallengePlain(t *testing.T) { 122 | tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 123 | testServer(t, w, r) 124 | })) 125 | defer tsrv.Close() 126 | 127 | e := httpexpect.New(t, tsrv.URL) 128 | 129 | csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 130 | switch r.URL.Path { 131 | case "/oauth2": 132 | r.ParseForm() 133 | code, state := r.Form.Get("code"), r.Form.Get("state") 134 | if state != "123" { 135 | t.Error("unrecognized state:", state) 136 | return 137 | } 138 | resObj := e.POST("/token"). 139 | WithFormField("redirect_uri", csrv.URL+"/oauth2"). 140 | WithFormField("code", code). 141 | WithFormField("grant_type", "authorization_code"). 142 | WithFormField("client_id", clientID). 143 | WithFormField("code", code). // NOTE(review): repeats the code field already set on line 140 144 | WithFormField("code_verifier", plainChallenge). 145 | Expect(). 146 | Status(http.StatusOK). 
147 | JSON().Object() 148 | 149 | t.Logf("%#v\n", resObj.Raw()) 150 | 151 | validationAccessToken(t, resObj.Value("access_token").String().Raw()) 152 | } 153 | })) 154 | defer csrv.Close() 155 | 156 | manager.MapClientStorage(clientStore(csrv.URL, true)) 157 | srv = server.NewDefaultServer(manager) 158 | srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { 159 | userID = "000000" 160 | return 161 | }) 162 | srv.SetClientInfoHandler(server.ClientFormHandler) 163 | 164 | e.GET("/authorize"). 165 | WithQuery("response_type", "code"). 166 | WithQuery("client_id", clientID). 167 | WithQuery("scope", "all"). 168 | WithQuery("state", "123"). 169 | WithQuery("redirect_uri", csrv.URL+"/oauth2"). 170 | WithQuery("code_challenge", plainChallenge). // no code_challenge_method: relies on the plain default 171 | Expect().Status(http.StatusOK) 172 | } 173 | 174 | func TestAuthorizeCodeWithChallengeS256(t *testing.T) { 175 | tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 176 | testServer(t, w, r) 177 | })) 178 | defer tsrv.Close() 179 | 180 | e := httpexpect.New(t, tsrv.URL) 181 | 182 | csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 183 | switch r.URL.Path { 184 | case "/oauth2": 185 | r.ParseForm() 186 | code, state := r.Form.Get("code"), r.Form.Get("state") 187 | if state != "123" { 188 | t.Error("unrecognized state:", state) 189 | return 190 | } 191 | resObj := e.POST("/token"). 192 | WithFormField("redirect_uri", csrv.URL+"/oauth2"). 193 | WithFormField("code", code). 194 | WithFormField("grant_type", "authorization_code"). 195 | WithFormField("client_id", clientID). 196 | WithFormField("code", code). // NOTE(review): repeats the code field already set on line 193 197 | WithFormField("code_verifier", s256Challenge). 198 | Expect(). 199 | Status(http.StatusOK). 
200 | JSON().Object() 201 | 202 | t.Logf("%#v\n", resObj.Raw()) 203 | 204 | validationAccessToken(t, resObj.Value("access_token").String().Raw()) 205 | } 206 | })) 207 | defer csrv.Close() 208 | 209 | manager.MapClientStorage(clientStore(csrv.URL, true)) 210 | srv = server.NewDefaultServer(manager) 211 | srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { 212 | userID = "000000" 213 | return 214 | }) 215 | srv.SetClientInfoHandler(server.ClientFormHandler) 216 | 217 | e.GET("/authorize"). 218 | WithQuery("response_type", "code"). 219 | WithQuery("client_id", clientID). 220 | WithQuery("scope", "all"). 221 | WithQuery("state", "123"). 222 | WithQuery("redirect_uri", csrv.URL+"/oauth2"). 223 | WithQuery("code_challenge", s256ChallengeHash). 224 | WithQuery("code_challenge_method", "S256"). 225 | Expect().Status(http.StatusOK) 226 | } 227 | 228 | func TestImplicit(t *testing.T) { 229 | tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 230 | testServer(t, w, r) 231 | })) 232 | defer tsrv.Close() 233 | e := httpexpect.New(t, tsrv.URL) 234 | 235 | csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) 236 | defer csrv.Close() 237 | 238 | manager.MapClientStorage(clientStore(csrv.URL, false)) 239 | srv = server.NewDefaultServer(manager) 240 | srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { 241 | userID = "000000" 242 | return 243 | }) 244 | 245 | e.GET("/authorize"). 246 | WithQuery("response_type", "token"). 247 | WithQuery("client_id", clientID). 248 | WithQuery("scope", "all"). 249 | WithQuery("state", "123"). 250 | WithQuery("redirect_uri", csrv.URL+"/oauth2"). 
251 | Expect().Status(http.StatusOK) 252 | } 253 | 254 | func TestPasswordCredentials(t *testing.T) { 255 | tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 256 | testServer(t, w, r) 257 | })) 258 | defer tsrv.Close() 259 | e := httpexpect.New(t, tsrv.URL) 260 | 261 | manager.MapClientStorage(clientStore("", false)) 262 | srv = server.NewDefaultServer(manager) 263 | srv.SetPasswordAuthorizationHandler(func(ctx context.Context, clientID, username, password string) (userID string, err error) { // clientID here shadows the package-level test clientID 264 | if username == "admin" && password == "123456" { 265 | userID = "000000" 266 | return 267 | } 268 | err = fmt.Errorf("user not found") 269 | return 270 | }) 271 | 272 | resObj := e.POST("/token"). 273 | WithFormField("grant_type", "password"). 274 | WithFormField("username", "admin"). 275 | WithFormField("password", "123456"). 276 | WithFormField("scope", "all"). 277 | WithBasicAuth(clientID, clientSecret). 278 | Expect(). 279 | Status(http.StatusOK). 280 | JSON().Object() 281 | 282 | t.Logf("%#v\n", resObj.Raw()) 283 | 284 | validationAccessToken(t, resObj.Value("access_token").String().Raw()) 285 | } 286 | 287 | func TestClientCredentials(t *testing.T) { 288 | tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 289 | testServer(t, w, r) 290 | })) 291 | defer tsrv.Close() 292 | e := httpexpect.New(t, tsrv.URL) 293 | 294 | manager.MapClientStorage(clientStore("", false)) 295 | 296 | srv = server.NewDefaultServer(manager) 297 | srv.SetClientInfoHandler(server.ClientFormHandler) 298 | 299 | srv.SetInternalErrorHandler(func(err error) (re *errors.Response) { 300 | t.Log("OAuth 2.0 Error:", err.Error()) 301 | return // nil response: the server falls back to ErrServerError 302 | }) 303 | 304 | srv.SetResponseErrorHandler(func(re *errors.Response) { 305 | t.Log("Response Error:", re.Error) 306 | }) 307 | 308 | srv.SetAllowedGrantType(oauth2.ClientCredentials) 309 | srv.SetAllowGetAccessRequest(false) 310 | srv.SetExtensionFieldsHandler(func(ti 
oauth2.TokenInfo) (fieldsValue map[string]interface{}) { 311 | fieldsValue = map[string]interface{}{ 312 | "extension": "param", 313 | } 314 | return 315 | }) 316 | srv.SetAuthorizeScopeHandler(func(w http.ResponseWriter, r *http.Request) (scope string, err error) { 317 | return 318 | }) 319 | srv.SetClientScopeHandler(func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) { 320 | allowed = true 321 | return 322 | }) 323 | 324 | resObj := e.POST("/token"). 325 | WithFormField("grant_type", "client_credentials"). 326 | WithFormField("scope", "all"). 327 | WithFormField("client_id", clientID). 328 | WithFormField("client_secret", clientSecret). 329 | Expect(). 330 | Status(http.StatusOK). 331 | JSON().Object() 332 | 333 | t.Logf("%#v\n", resObj.Raw()) 334 | 335 | validationAccessToken(t, resObj.Value("access_token").String().Raw()) 336 | } 337 | 338 | func TestRefreshing(t *testing.T) { 339 | tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 340 | testServer(t, w, r) 341 | })) 342 | defer tsrv.Close() 343 | e := httpexpect.New(t, tsrv.URL) 344 | 345 | csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 346 | switch r.URL.Path { 347 | case "/oauth2": 348 | r.ParseForm() 349 | code, state := r.Form.Get("code"), r.Form.Get("state") 350 | if state != "123" { 351 | t.Error("unrecognized state:", state) 352 | return 353 | } 354 | jresObj := e.POST("/token"). 355 | WithFormField("redirect_uri", csrv.URL+"/oauth2"). 356 | WithFormField("code", code). 357 | WithFormField("grant_type", "authorization_code"). 358 | WithFormField("client_id", clientID). 359 | WithBasicAuth(clientID, clientSecret). 360 | Expect(). 361 | Status(http.StatusOK). 362 | JSON().Object() 363 | 364 | t.Logf("%#v\n", jresObj.Raw()) 365 | 366 | validationAccessToken(t, jresObj.Value("access_token").String().Raw()) 367 | // use the refresh token from the first response to request a new token with scope "one" 368 | resObj := e.POST("/token"). 369 | WithFormField("grant_type", "refresh_token"). 
370 | WithFormField("scope", "one"). 371 | WithFormField("refresh_token", jresObj.Value("refresh_token").String().Raw()). 372 | WithBasicAuth(clientID, clientSecret). 373 | Expect(). 374 | Status(http.StatusOK). 375 | JSON().Object() 376 | 377 | t.Logf("%#v\n", resObj.Raw()) 378 | 379 | validationAccessToken(t, resObj.Value("access_token").String().Raw()) 380 | } 381 | })) 382 | defer csrv.Close() 383 | 384 | manager.MapClientStorage(clientStore(csrv.URL, true)) 385 | srv = server.NewDefaultServer(manager) 386 | srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { 387 | userID = "000000" 388 | return 389 | }) 390 | 391 | e.GET("/authorize"). 392 | WithQuery("response_type", "code"). 393 | WithQuery("client_id", clientID). 394 | WithQuery("scope", "all"). 395 | WithQuery("state", "123"). 396 | WithQuery("redirect_uri", csrv.URL+"/oauth2"). 397 | Expect().Status(http.StatusOK) 398 | } 399 | 400 | // validationAccessToken checks that srv accepts accessToken as a bearer token and that the token was issued to clientID. 401 | func validationAccessToken(t *testing.T, accessToken string) { 402 | req := httptest.NewRequest("GET", "http://example.com", nil) 403 | 404 | req.Header.Set("Authorization", "Bearer "+accessToken) 405 | 406 | ti, err := srv.ValidationBearerToken(req) 407 | if err != nil { 408 | t.Error(err.Error()) 409 | return 410 | } 411 | if ti.GetClientID() != clientID { 412 | t.Error("invalid access token") 413 | } 414 | } 415 | -------------------------------------------------------------------------------- /store.go: -------------------------------------------------------------------------------- 1 | package oauth2 2 | 3 | import "context" 4 | 5 | type ( 6 | // ClientStore is the client information storage interface. 7 | ClientStore interface { 8 | // GetByID returns the client information for id. 9 | GetByID(ctx context.Context, id string) (ClientInfo, error) 10 | } 11 | 12 | // TokenStore is the token information storage interface. 13 | TokenStore interface { 14 | // create and store the new token 
information 15 | Create(ctx context.Context, info TokenInfo) error 16 | 17 | // delete the authorization code 18 | RemoveByCode(ctx context.Context, code string) error 19 | 20 | // use the access token to delete the token information 21 | RemoveByAccess(ctx context.Context, access string) error 22 | 23 | // use the refresh token to delete the token information 24 | RemoveByRefresh(ctx context.Context, refresh string) error 25 | 26 | // use the authorization code for token information data 27 | GetByCode(ctx context.Context, code string) (TokenInfo, error) 28 | 29 | // use the access token for token information data 30 | GetByAccess(ctx context.Context, access string) (TokenInfo, error) 31 | 32 | // use the refresh token for token information data 33 | GetByRefresh(ctx context.Context, refresh string) (TokenInfo, error) 34 | } 35 | ) 36 | -------------------------------------------------------------------------------- /store/client.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "sync" 7 | 8 | "github.com/go-oauth2/oauth2/v4" 9 | ) 10 | 11 | // NewClientStore creates an empty in-memory client store. 12 | func NewClientStore() *ClientStore { 13 | return &ClientStore{ 14 | data: make(map[string]oauth2.ClientInfo), 15 | } 16 | } 17 | 18 | // ClientStore is an in-memory client store; the embedded RWMutex guards data. 19 | type ClientStore struct { 20 | sync.RWMutex 21 | data map[string]oauth2.ClientInfo 22 | } 23 | 24 | // GetByID returns the client stored under id, or a "not found" error. 25 | func (cs *ClientStore) GetByID(ctx context.Context, id string) (oauth2.ClientInfo, error) { 26 | cs.RLock() 27 | defer cs.RUnlock() 28 | 29 | if c, ok := cs.data[id]; ok { 30 | return c, nil 31 | } 32 | return nil, errors.New("not found") 33 | } 34 | 35 | // Set stores cli under id, replacing any existing entry; it always returns nil. 36 | func (cs *ClientStore) Set(id string, cli oauth2.ClientInfo) (err error) { 37 | cs.Lock() 38 | defer cs.Unlock() 39 | 40 | cs.data[id] = cli 41 | return 42 | } 43 
| -------------------------------------------------------------------------------- /store/client_test.go: -------------------------------------------------------------------------------- 1 | package store_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/go-oauth2/oauth2/v4/models" 8 | "github.com/go-oauth2/oauth2/v4/store" 9 | 10 | . "github.com/smartystreets/goconvey/convey" 11 | ) 12 | 13 | func TestClientStore(t *testing.T) { 14 | Convey("Test client store", t, func() { 15 | clientStore := store.NewClientStore() 16 | 17 | err := clientStore.Set("1", &models.Client{ID: "1", Secret: "2"}) 18 | So(err, ShouldBeNil) 19 | 20 | cli, err := clientStore.GetByID(context.Background(), "1") 21 | So(err, ShouldBeNil) 22 | So(cli.GetID(), ShouldEqual, "1") 23 | }) 24 | } 25 | -------------------------------------------------------------------------------- /store/token.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "time" 7 | 8 | "github.com/go-oauth2/oauth2/v4" 9 | "github.com/go-oauth2/oauth2/v4/models" 10 | "github.com/google/uuid" 11 | "github.com/tidwall/buntdb" 12 | ) 13 | 14 | // NewMemoryTokenStore creates a buntdb token store held in memory (":memory:"). 15 | func NewMemoryTokenStore() (oauth2.TokenStore, error) { 16 | return NewFileTokenStore(":memory:") 17 | } 18 | 19 | // NewFileTokenStore creates a buntdb token store persisted to filename. 20 | func NewFileTokenStore(filename string) (oauth2.TokenStore, error) { 21 | db, err := buntdb.Open(filename) 22 | if err != nil { 23 | return nil, err 24 | } 25 | return &TokenStore{db: db}, nil 26 | } 27 | 28 | // TokenStore token storage based on buntdb(https://github.com/tidwall/buntdb) 29 | type TokenStore struct { 30 | db *buntdb.DB 31 | } 32 | 33 | // Create stores info. An authorization code is stored under the code with the code TTL. Otherwise the JSON record is stored under a random basic ID, and the access (and refresh) token keys point to that ID. 34 | func (ts *TokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { 35 | ct := time.Now() 36 
| jv, err := json.Marshal(info) 37 | if err != nil { 38 | return err 39 | } 40 | 41 | return ts.db.Update(func(tx *buntdb.Tx) error { 42 | if code := info.GetCode(); code != "" { 43 | _, _, err := tx.Set(code, string(jv), &buntdb.SetOptions{Expires: true, TTL: info.GetCodeExpiresIn()}) 44 | return err 45 | } 46 | 47 | basicID := uuid.Must(uuid.NewRandom()).String() 48 | aexp := info.GetAccessExpiresIn() 49 | rexp := aexp // without a refresh token the basic record lives as long as the access token 50 | expires := true // NOTE(review): with no refresh token this stays true even if AccessExpiresIn is 0; confirm how buntdb treats TTL 0 51 | if refresh := info.GetRefresh(); refresh != "" { 52 | rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct) 53 | if aexp.Seconds() > rexp.Seconds() { // the access key must not outlive the record it points to 54 | aexp = rexp 55 | } 56 | expires = info.GetRefreshExpiresIn() != 0 // a zero refresh lifetime disables expiry for all three keys 57 | _, _, err := tx.Set(refresh, basicID, &buntdb.SetOptions{Expires: expires, TTL: rexp}) 58 | if err != nil { 59 | return err 60 | } 61 | } 62 | 63 | _, _, err := tx.Set(basicID, string(jv), &buntdb.SetOptions{Expires: expires, TTL: rexp}) 64 | if err != nil { 65 | return err 66 | } 67 | _, _, err = tx.Set(info.GetAccess(), basicID, &buntdb.SetOptions{Expires: expires, TTL: aexp}) 68 | return err 69 | }) 70 | } 71 | 72 | // remove deletes key; a missing key is not treated as an error. 73 | func (ts *TokenStore) remove(key string) error { 74 | err := ts.db.Update(func(tx *buntdb.Tx) error { 75 | _, err := tx.Delete(key) 76 | return err 77 | }) 78 | if err == buntdb.ErrNotFound { 79 | return nil 80 | } 81 | return err 82 | } 83 | 84 | // RemoveByCode deletes the record stored under an authorization code. 85 | func (ts *TokenStore) RemoveByCode(ctx context.Context, code string) error { 86 | return ts.remove(code) 87 | } 88 | 89 | // RemoveByAccess deletes the access token key; the basic ID record it points to is not deleted. 90 | func (ts *TokenStore) RemoveByAccess(ctx context.Context, access string) error { 91 | return ts.remove(access) 92 | } 93 | 94 | // RemoveByRefresh deletes the refresh token key; the basic ID record it points to is not deleted. 95 | func (ts *TokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { 96 | return ts.remove(refresh) 97 | } 98 | // getData loads and decodes the token JSON stored under key; a missing key yields (nil, nil). 99
| func (ts *TokenStore) getData(key string) (oauth2.TokenInfo, error) { 100 | var ti oauth2.TokenInfo 101 | err := ts.db.View(func(tx *buntdb.Tx) error { 102 | jv, err := tx.Get(key) 103 | if err != nil { 104 | return err 105 | } 106 | 107 | var tm models.Token 108 | err = json.Unmarshal([]byte(jv), &tm) 109 | if err != nil { 110 | return err 111 | } 112 | ti = &tm 113 | return nil 114 | }) 115 | if err != nil { 116 | if err == buntdb.ErrNotFound { 117 | return nil, nil // not found is reported as (nil, nil), not as an error 118 | } 119 | return nil, err 120 | } 121 | return ti, nil 122 | } 123 | // getBasicID returns the basic ID a token key points to; a missing key yields ("", nil). 124 | func (ts *TokenStore) getBasicID(key string) (string, error) { 125 | var basicID string 126 | err := ts.db.View(func(tx *buntdb.Tx) error { 127 | v, err := tx.Get(key) 128 | if err != nil { 129 | return err 130 | } 131 | basicID = v 132 | return nil 133 | }) 134 | if err != nil { 135 | if err == buntdb.ErrNotFound { 136 | return "", nil 137 | } 138 | return "", err 139 | } 140 | return basicID, nil 141 | } 142 | 143 | // GetByCode returns the token information stored under an authorization code, or nil if there is none. 144 | func (ts *TokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { 145 | return ts.getData(code) 146 | } 147 | 148 | // GetByAccess resolves the access token to its basic ID and returns that record, or nil if the token is unknown. 149 | func (ts *TokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { 150 | basicID, err := ts.getBasicID(access) 151 | if err != nil { 152 | return nil, err 153 | } 154 | return ts.getData(basicID) 155 | } 156 | 157 | // GetByRefresh resolves the refresh token to its basic ID and returns that record, or nil if the token is unknown. 158 | func (ts *TokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { 159 | basicID, err := ts.getBasicID(refresh) 160 | if err != nil { 161 | return nil, err 162 | } 163 | return ts.getData(basicID) 164 | } 165 | -------------------------------------------------------------------------------- /store/token_test.go: 
-------------------------------------------------------------------------------- 1 | package store_test 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | "time" 8 | 9 | "github.com/go-oauth2/oauth2/v4" 10 | "github.com/go-oauth2/oauth2/v4/models" 11 | "github.com/go-oauth2/oauth2/v4/store" 12 | 13 | . "github.com/smartystreets/goconvey/convey" 14 | ) 15 | 16 | func TestTokenStore(t *testing.T) { 17 | Convey("Test memory store", t, func() { 18 | store, err := store.NewMemoryTokenStore() // shadows the store package inside this closure 19 | So(err, ShouldBeNil) 20 | testToken(store) 21 | }) 22 | 23 | Convey("Test file store", t, func() { 24 | os.Remove("data.db") // start from a fresh file; the error is ignored and data.db is not removed afterwards 25 | 26 | store, err := store.NewFileTokenStore("data.db") 27 | So(err, ShouldBeNil) 28 | testToken(store) 29 | }) 30 | } 31 | // testToken exercises the code, access-token, refresh-token and TTL behavior of a TokenStore. 32 | func testToken(store oauth2.TokenStore) { 33 | Convey("Test authorization code store", func() { 34 | ctx := context.Background() 35 | info := &models.Token{ 36 | ClientID: "1", 37 | UserID: "1_1", 38 | RedirectURI: "http://localhost/", 39 | Scope: "all", 40 | Code: "11_11_11", 41 | CodeCreateAt: time.Now(), 42 | CodeExpiresIn: time.Second * 5, 43 | } 44 | err := store.Create(ctx, info) 45 | So(err, ShouldBeNil) 46 | 47 | cinfo, err := store.GetByCode(ctx, info.Code) 48 | So(err, ShouldBeNil) 49 | So(cinfo.GetUserID(), ShouldEqual, info.UserID) 50 | 51 | err = store.RemoveByCode(ctx, info.Code) 52 | So(err, ShouldBeNil) 53 | 54 | cinfo, err = store.GetByCode(ctx, info.Code) 55 | So(err, ShouldBeNil) 56 | So(cinfo, ShouldBeNil) 57 | }) 58 | 59 | Convey("Test access token store", func() { 60 | ctx := context.Background() 61 | info := &models.Token{ 62 | ClientID: "1", 63 | UserID: "1_1", 64 | RedirectURI: "http://localhost/", 65 | Scope: "all", 66 | Access: "1_1_1", 67 | AccessCreateAt: time.Now(), 68 | AccessExpiresIn: time.Second * 5, 69 | } 70 | err := store.Create(ctx, info) 71 | So(err, ShouldBeNil) 72 | 73 | ainfo, err := store.GetByAccess(ctx, info.GetAccess()) 74 | So(err, ShouldBeNil) 75 | So(ainfo.GetUserID(), 
ShouldEqual, info.GetUserID()) 76 | 77 | err = store.RemoveByAccess(ctx, info.GetAccess()) 78 | So(err, ShouldBeNil) 79 | 80 | ainfo, err = store.GetByAccess(ctx, info.GetAccess()) 81 | So(err, ShouldBeNil) 82 | So(ainfo, ShouldBeNil) 83 | }) 84 | 85 | Convey("Test refresh token store", func() { 86 | ctx := context.Background() 87 | info := &models.Token{ 88 | ClientID: "1", 89 | UserID: "1_2", 90 | RedirectURI: "http://localhost/", 91 | Scope: "all", 92 | Access: "1_2_1", 93 | AccessCreateAt: time.Now(), 94 | AccessExpiresIn: time.Second * 5, 95 | Refresh: "1_2_2", 96 | RefreshCreateAt: time.Now(), 97 | RefreshExpiresIn: time.Second * 15, 98 | } 99 | err := store.Create(ctx, info) 100 | So(err, ShouldBeNil) 101 | 102 | rinfo, err := store.GetByRefresh(ctx, info.GetRefresh()) 103 | So(err, ShouldBeNil) 104 | So(rinfo.GetUserID(), ShouldEqual, info.GetUserID()) 105 | 106 | err = store.RemoveByRefresh(ctx, info.GetRefresh()) 107 | So(err, ShouldBeNil) 108 | 109 | rinfo, err = store.GetByRefresh(ctx, info.GetRefresh()) 110 | So(err, ShouldBeNil) 111 | So(rinfo, ShouldBeNil) 112 | }) 113 | 114 | Convey("Test TTL", func() { 115 | ctx := context.Background() 116 | info := &models.Token{ 117 | ClientID: "1", 118 | UserID: "1_1", 119 | RedirectURI: "http://localhost/", 120 | Scope: "all", 121 | Access: "1_3_1", 122 | AccessCreateAt: time.Now(), 123 | AccessExpiresIn: time.Second * 1, 124 | Refresh: "1_3_2", 125 | RefreshCreateAt: time.Now(), 126 | RefreshExpiresIn: time.Second * 1, 127 | } 128 | err := store.Create(ctx, info) 129 | So(err, ShouldBeNil) 130 | 131 | time.Sleep(time.Second * 1) // NOTE(review): sleeps exactly the 1s TTL, so this relies on expiry at the boundary; a slightly longer sleep would be less timing-sensitive 132 | ainfo, err := store.GetByAccess(ctx, info.Access) 133 | So(err, ShouldBeNil) 134 | So(ainfo, ShouldBeNil) 135 | rinfo, err := store.GetByRefresh(ctx, info.Refresh) 136 | So(err, ShouldBeNil) 137 | So(rinfo, ShouldBeNil) 138 | }) 139 | } 140 | --------------------------------------------------------------------------------