├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── const.go
├── const_test.go
├── doc.go
├── errors
├── error.go
└── response.go
├── example
├── README.md
├── client
│ └── client.go
└── server
│ ├── server.go
│ └── static
│ ├── auth.html
│ ├── auth.png
│ ├── login.html
│ ├── login.png
│ └── token.png
├── generate.go
├── generates
├── access.go
├── access_test.go
├── authorize.go
├── authorize_test.go
├── jwt_access.go
└── jwt_access_test.go
├── go.mod
├── go.sum
├── go.test.sh
├── manage.go
├── manage
├── config.go
├── manage_test.go
├── manager.go
├── util.go
└── util_test.go
├── model.go
├── models
├── client.go
└── token.go
├── server
├── config.go
├── handler.go
├── handler_test.go
├── server.go
├── server_config.go
└── server_test.go
├── store.go
└── store
├── client.go
├── client_test.go
├── token.go
└── token_test.go
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled Object files, Static and Dynamic libs (Shared Objects)
2 | *.o
3 | *.a
4 | *.so
5 |
6 | # Folders
7 | _obj
8 | _test
9 |
10 | # Architecture specific extensions/prefixes
11 | *.[568vq]
12 | [568vq].out
13 |
14 | *.cgo1.go
15 | *.cgo2.c
16 | _cgo_defun.c
17 | _cgo_gotypes.go
18 | _cgo_export.*
19 |
20 | _testmain.go
21 |
22 | *.exe
23 | *.test
24 | *.prof
25 |
26 | coverage.txt
27 |
28 | # OSX
29 | *.DS_Store
30 | *.db
31 | *.swp
32 | /example/client/client
33 | /example/server/server
34 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: go
2 | sudo: false
3 | go_import_path: github.com/go-oauth2/oauth2/v4
4 | go:
5 | - 1.13
6 | before_install:
7 | - go get -t -v ./...
8 |
9 | script:
10 | - chmod +x ./go.test.sh && ./go.test.sh
11 |
12 | after_success:
13 | - bash <(curl -s https://codecov.io/bash)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2016 Lyric
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Golang OAuth 2.0 Server
2 |
3 | > An open protocol to allow secure authorization in a simple and standard method from web, mobile and desktop applications.
4 |
5 | [![Build][build-status-image]][build-status-url] [![Codecov][codecov-image]][codecov-url] [![ReportCard][reportcard-image]][reportcard-url] [![GoDoc][godoc-image]][godoc-url] [![License][license-image]][license-url]
6 |
7 | ## Protocol Flow
8 |
9 | ```text
10 | +--------+ +---------------+
11 | | |--(A)- Authorization Request ->| Resource |
12 | | | | Owner |
13 | | |<-(B)-- Authorization Grant ---| |
14 | | | +---------------+
15 | | |
16 | | | +---------------+
17 | | |--(C)-- Authorization Grant -->| Authorization |
18 | | Client | | Server |
19 | | |<-(D)----- Access Token -------| |
20 | | | +---------------+
21 | | |
22 | | | +---------------+
23 | | |--(E)----- Access Token ------>| Resource |
24 | | | | Server |
25 | | |<-(F)--- Protected Resource ---| |
26 | +--------+ +---------------+
27 | ```
28 |
29 | ## Quick Start
30 |
31 | ### Download and install
32 |
33 | ```bash
34 | go get -u -v github.com/go-oauth2/oauth2/v4/...
35 | ```
36 |
37 | ### Create file `server.go`
38 |
39 | ```go
40 | package main
41 |
42 | import (
43 | "log"
44 | "net/http"
45 |
46 | "github.com/go-oauth2/oauth2/v4/errors"
47 | "github.com/go-oauth2/oauth2/v4/manage"
48 | "github.com/go-oauth2/oauth2/v4/models"
49 | "github.com/go-oauth2/oauth2/v4/server"
50 | "github.com/go-oauth2/oauth2/v4/store"
51 | )
52 |
53 | func main() {
54 | manager := manage.NewDefaultManager()
55 | // token memory store
56 | manager.MustTokenStorage(store.NewMemoryTokenStore())
57 |
58 | // client memory store
59 | clientStore := store.NewClientStore()
60 | clientStore.Set("000000", &models.Client{
61 | ID: "000000",
62 | Secret: "999999",
63 | Domain: "http://localhost",
64 | })
65 | manager.MapClientStorage(clientStore)
66 |
67 | srv := server.NewDefaultServer(manager)
68 | srv.SetAllowGetAccessRequest(true)
69 | srv.SetClientInfoHandler(server.ClientFormHandler)
70 |
71 | srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (userID string, err error) {
72 | return "000000", nil
73 | }
74 |
75 | srv.SetInternalErrorHandler(func(err error) (re *errors.Response) {
76 | log.Println("Internal Error:", err.Error())
77 | return
78 | })
79 |
80 | srv.SetResponseErrorHandler(func(re *errors.Response) {
81 | log.Println("Response Error:", re.Error.Error())
82 | })
83 |
84 | http.HandleFunc("/authorize", func(w http.ResponseWriter, r *http.Request) {
85 | err := srv.HandleAuthorizeRequest(w, r)
86 | if err != nil {
87 | http.Error(w, err.Error(), http.StatusBadRequest)
88 | }
89 | })
90 |
91 | http.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
92 | srv.HandleTokenRequest(w, r)
93 | })
94 |
95 | log.Fatal(http.ListenAndServe(":9096", nil))
96 | }
97 |
98 | ```
99 |
100 | ### Build and run
101 |
102 | ```bash
103 | go build server.go
104 |
105 | ./server
106 | ```
107 |
108 | ### Open in your web browser
109 | **Authorization Request**:
110 | [http://localhost:9096/authorize?client_id=000000&response_type=code](http://localhost:9096/authorize?client_id=000000&response_type=code)
111 |
112 | **Grant Token Request**:
113 | [http://localhost:9096/token?grant_type=client_credentials&client_id=000000&client_secret=999999&scope=read](http://localhost:9096/token?grant_type=client_credentials&client_id=000000&client_secret=999999&scope=read)
114 |
115 | ```json
116 | {
117 | "access_token": "J86XVRYSNFCFI233KXDL0Q",
118 | "expires_in": 7200,
119 | "scope": "read",
120 | "token_type": "Bearer"
121 | }
122 | ```
123 |
124 | ## Features
125 |
126 | - Easy to use
127 | - Based on the [RFC 6749](https://tools.ietf.org/html/rfc6749) implementation
128 | - Token storage support TTL
129 | - Support custom expiration time of the access token
130 | - Support custom extension field
131 | - Support custom scope
132 | - Support jwt to generate access tokens
133 |
134 | ## Example
135 |
136 | > A complete example of simulation authorization code model
137 |
138 | Simulation examples of authorization code model, please check [example](/example)
139 |
140 | ### Use jwt to generate access tokens
141 |
142 | ```go
143 |
144 | import (
145 | "github.com/go-oauth2/oauth2/v4/generates"
146 | "github.com/dgrijalva/jwt-go"
147 | )
148 |
149 | // ...
150 | manager.MapAccessGenerate(generates.NewJWTAccessGenerate("", []byte("00000000"), jwt.SigningMethodHS512))
151 |
152 | // Parse and verify jwt access token
153 | token, err := jwt.ParseWithClaims(access, &generates.JWTAccessClaims{}, func(t *jwt.Token) (interface{}, error) {
154 | if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
155 | return nil, fmt.Errorf("parse error")
156 | }
157 | return []byte("00000000"), nil
158 | })
159 | if err != nil {
160 | // panic(err)
161 | }
162 |
163 | claims, ok := token.Claims.(*generates.JWTAccessClaims)
164 | if !ok || !token.Valid {
165 | // panic("invalid token")
166 | }
167 | ```
168 |
169 | ## Store Implements
170 |
171 | - [BuntDB](https://github.com/tidwall/buntdb)(default store)
172 | - [Redis](https://github.com/go-oauth2/redis)
173 | - [MongoDB](https://github.com/go-oauth2/mongo)
174 | - [MySQL](https://github.com/go-oauth2/mysql)
175 | - [MySQL (Provides both client and token store)](https://github.com/imrenagi/go-oauth2-mysql)
176 | - [PostgreSQL](https://github.com/vgarvardt/go-oauth2-pg)
177 | - [DynamoDB](https://github.com/contamobi/go-oauth2-dynamodb)
178 | - [XORM](https://github.com/techknowlogick/go-oauth2-xorm)
179 | - [XORM (MySQL, client and token store)](https://github.com/rainlay/go-oauth2-xorm)
180 | - [GORM](https://github.com/techknowlogick/go-oauth2-gorm)
181 | - [Firestore](https://github.com/tslamic/go-oauth2-firestore)
182 | - [Hazelcast](https://github.com/clowre/go-oauth2-hazelcast) (token only)
183 |
184 | ## Handy Utilities
185 |
186 | - [OAuth2 Proxy Logger (Debug utility that proxies interfaces and logs)](https://github.com/aubelsb2/oauth2-logger-proxy)
187 |
188 | ## MIT License
189 |
190 | Copyright (c) 2016 Lyric
191 |
192 | [build-status-url]: https://travis-ci.org/go-oauth2/oauth2
193 | [build-status-image]: https://travis-ci.org/go-oauth2/oauth2.svg?branch=master
194 | [codecov-url]: https://codecov.io/gh/go-oauth2/oauth2
195 | [codecov-image]: https://codecov.io/gh/go-oauth2/oauth2/branch/master/graph/badge.svg
196 | [reportcard-url]: https://goreportcard.com/report/github.com/go-oauth2/oauth2/v4
197 | [reportcard-image]: https://goreportcard.com/badge/github.com/go-oauth2/oauth2/v4
198 | [godoc-url]: https://godoc.org/github.com/go-oauth2/oauth2/v4
199 | [godoc-image]: https://godoc.org/github.com/go-oauth2/oauth2/v4?status.svg
200 | [license-url]: http://opensource.org/licenses/MIT
201 | [license-image]: https://img.shields.io/npm/l/express.svg
202 |
--------------------------------------------------------------------------------
/const.go:
--------------------------------------------------------------------------------
1 | package oauth2
2 |
3 | import (
4 | "crypto/sha256"
5 | "encoding/base64"
6 | "strings"
7 | )
8 |
9 | // ResponseType the type of authorization request
10 | type ResponseType string
11 |
12 | // define the type of authorization request
13 | const (
14 | Code ResponseType = "code"
15 | Token ResponseType = "token"
16 | )
17 |
18 | func (rt ResponseType) String() string {
19 | return string(rt)
20 | }
21 |
22 | // GrantType authorization model
23 | type GrantType string
24 |
25 | // define authorization model
26 | const (
27 | AuthorizationCode GrantType = "authorization_code"
28 | PasswordCredentials GrantType = "password"
29 | ClientCredentials GrantType = "client_credentials"
30 | Refreshing GrantType = "refresh_token"
31 | Implicit GrantType = "__implicit"
32 | )
33 |
34 | func (gt GrantType) String() string {
35 | if gt == AuthorizationCode ||
36 | gt == PasswordCredentials ||
37 | gt == ClientCredentials ||
38 | gt == Refreshing {
39 | return string(gt)
40 | }
41 | return ""
42 | }
43 |
44 | // CodeChallengeMethod PCKE method
45 | type CodeChallengeMethod string
46 |
47 | const (
48 | // CodeChallengePlain PCKE Method
49 | CodeChallengePlain CodeChallengeMethod = "plain"
50 | // CodeChallengeS256 PCKE Method
51 | CodeChallengeS256 CodeChallengeMethod = "S256"
52 | )
53 |
54 | func (ccm CodeChallengeMethod) String() string {
55 | if ccm == CodeChallengePlain ||
56 | ccm == CodeChallengeS256 {
57 | return string(ccm)
58 | }
59 | return ""
60 | }
61 |
62 | // Validate code challenge
63 | func (ccm CodeChallengeMethod) Validate(cc, ver string) bool {
64 | switch ccm {
65 | case CodeChallengePlain:
66 | return cc == ver
67 | case CodeChallengeS256:
68 | s256 := sha256.Sum256([]byte(ver))
69 | // trim padding
70 | a := strings.TrimRight(base64.URLEncoding.EncodeToString(s256[:]), "=")
71 | b := strings.TrimRight(cc, "=")
72 | return a == b
73 | default:
74 | return false
75 | }
76 | }
77 |
--------------------------------------------------------------------------------
/const_test.go:
--------------------------------------------------------------------------------
1 | package oauth2_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/go-oauth2/oauth2/v4"
7 | )
8 |
9 | func TestValidatePlain(t *testing.T) {
10 | cc := oauth2.CodeChallengePlain
11 | if !cc.Validate("plaintest", "plaintest") {
12 | t.Fatal("not valid")
13 | }
14 | }
15 |
16 | func TestValidateS256(t *testing.T) {
17 | cc := oauth2.CodeChallengeS256
18 | if !cc.Validate("W6YWc_4yHwYN-cGDgGmOMHF3l7KDy7VcRjf7q2FVF-o=", "s256test") {
19 | t.Fatal("not valid")
20 | }
21 | }
22 |
23 | func TestValidateS256NoPadding(t *testing.T) {
24 | cc := oauth2.CodeChallengeS256
25 | if !cc.Validate("W6YWc_4yHwYN-cGDgGmOMHF3l7KDy7VcRjf7q2FVF-o", "s256test") {
26 | t.Fatal("not valid")
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/doc.go:
--------------------------------------------------------------------------------
1 | // OAuth 2.0 server library for the Go programming language
2 | //
3 | // package main
4 | // import (
5 | // "net/http"
6 | // "github.com/go-oauth2/oauth2/v4/manage"
7 | // "github.com/go-oauth2/oauth2/v4/server"
8 | // "github.com/go-oauth2/oauth2/v4/store"
9 | // )
10 | // func main() {
11 | // manager := manage.NewDefaultManager()
12 | // manager.MustTokenStorage(store.NewMemoryTokenStore())
13 | // manager.MapClientStorage(store.NewTestClientStore())
14 | // srv := server.NewDefaultServer(manager)
15 | // http.HandleFunc("/authorize", func(w http.ResponseWriter, r *http.Request) {
16 | // srv.HandleAuthorizeRequest(w, r)
17 | // })
18 | // http.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
19 | // srv.HandleTokenRequest(w, r)
20 | // })
21 | // http.ListenAndServe(":9096", nil)
22 | // }
23 |
24 | package oauth2
25 |
--------------------------------------------------------------------------------
/errors/error.go:
--------------------------------------------------------------------------------
1 | package errors
2 |
3 | import "errors"
4 |
5 | // New returns an error that formats as the given text.
6 | var New = errors.New
7 |
8 | // known errors
9 | var (
10 | ErrInvalidRedirectURI = errors.New("invalid redirect uri")
11 | ErrInvalidAuthorizeCode = errors.New("invalid authorize code")
12 | ErrInvalidAccessToken = errors.New("invalid access token")
13 | ErrInvalidRefreshToken = errors.New("invalid refresh token")
14 | ErrExpiredAccessToken = errors.New("expired access token")
15 | ErrExpiredRefreshToken = errors.New("expired refresh token")
16 | ErrMissingCodeVerifier = errors.New("missing code verifier")
17 | ErrMissingCodeChallenge = errors.New("missing code challenge")
18 | ErrInvalidCodeChallenge = errors.New("invalid code challenge")
19 | )
20 |
--------------------------------------------------------------------------------
/errors/response.go:
--------------------------------------------------------------------------------
1 | package errors
2 |
3 | import (
4 | "errors"
5 | "net/http"
6 | )
7 |
8 | // Response error response
9 | type Response struct {
10 | Error error
11 | ErrorCode int
12 | Description string
13 | URI string
14 | StatusCode int
15 | Header http.Header
16 | }
17 |
18 | // NewResponse create the response pointer
19 | func NewResponse(err error, statusCode int) *Response {
20 | return &Response{
21 | Error: err,
22 | StatusCode: statusCode,
23 | }
24 | }
25 |
26 | // SetHeader sets the header entries associated with key to
27 | // the single element value.
28 | func (r *Response) SetHeader(key, value string) {
29 | if r.Header == nil {
30 | r.Header = make(http.Header)
31 | }
32 | r.Header.Set(key, value)
33 | }
34 |
35 | // https://tools.ietf.org/html/rfc6749#section-5.2
36 | var (
37 | ErrInvalidRequest = errors.New("invalid_request")
38 | ErrUnauthorizedClient = errors.New("unauthorized_client")
39 | ErrAccessDenied = errors.New("access_denied")
40 | ErrUnsupportedResponseType = errors.New("unsupported_response_type")
41 | ErrInvalidScope = errors.New("invalid_scope")
42 | ErrServerError = errors.New("server_error")
43 | ErrTemporarilyUnavailable = errors.New("temporarily_unavailable")
44 | ErrInvalidClient = errors.New("invalid_client")
45 | ErrInvalidGrant = errors.New("invalid_grant")
46 | ErrUnsupportedGrantType = errors.New("unsupported_grant_type")
47 | ErrCodeChallengeRquired = errors.New("invalid_request")
48 | ErrUnsupportedCodeChallengeMethod = errors.New("invalid_request")
49 | ErrInvalidCodeChallengeLen = errors.New("invalid_request")
50 | )
51 |
52 | // Descriptions error description
53 | var Descriptions = map[error]string{
54 | ErrInvalidRequest: "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed",
55 | ErrUnauthorizedClient: "The client is not authorized to request an authorization code using this method",
56 | ErrAccessDenied: "The resource owner or authorization server denied the request",
57 | ErrUnsupportedResponseType: "The authorization server does not support obtaining an authorization code using this method",
58 | ErrInvalidScope: "The requested scope is invalid, unknown, or malformed",
59 | ErrServerError: "The authorization server encountered an unexpected condition that prevented it from fulfilling the request",
60 | ErrTemporarilyUnavailable: "The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server",
61 | ErrInvalidClient: "Client authentication failed",
62 | ErrInvalidGrant: "The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client",
63 | ErrUnsupportedGrantType: "The authorization grant type is not supported by the authorization server",
64 | ErrCodeChallengeRquired: "PKCE is required. code_challenge is missing",
65 | ErrUnsupportedCodeChallengeMethod: "Selected code_challenge_method not supported",
66 | ErrInvalidCodeChallengeLen: "Code challenge length must be between 43 and 128 charachters long",
67 | }
68 |
69 | // StatusCodes response error HTTP status code
70 | var StatusCodes = map[error]int{
71 | ErrInvalidRequest: 400,
72 | ErrUnauthorizedClient: 401,
73 | ErrAccessDenied: 403,
74 | ErrUnsupportedResponseType: 401,
75 | ErrInvalidScope: 400,
76 | ErrServerError: 500,
77 | ErrTemporarilyUnavailable: 503,
78 | ErrInvalidClient: 401,
79 | ErrInvalidGrant: 401,
80 | ErrUnsupportedGrantType: 401,
81 | ErrCodeChallengeRquired: 400,
82 | ErrUnsupportedCodeChallengeMethod: 400,
83 | ErrInvalidCodeChallengeLen: 400,
84 | }
85 |
--------------------------------------------------------------------------------
/example/README.md:
--------------------------------------------------------------------------------
1 | # Use Examples
2 |
3 | ## Run Server
4 |
5 | ``` bash
6 | $ cd example/server
7 | $ go build server.go
8 | $ ./server
9 | ```
10 |
11 | ## Run Client
12 |
13 | ```
14 | $ cd example/client
15 | $ go build client.go
16 | $ ./client
17 | ```
18 |
19 | ## Authorization Code Grant
20 |
21 | ### Open the browser
22 |
23 | [http://localhost:9094](http://localhost:9094)
24 |
25 | ```
26 | {
27 | "access_token": "GIGXO8XWPQSAUGOYQGTV8Q",
28 | "token_type": "Bearer",
29 | "refresh_token": "5FBLXQ47XJ2MGTY8YRZQ8W",
30 | "expiry": "2019-01-08T01:53:45.868194+08:00"
31 | }
32 | ```
33 |
34 |
35 | ### Try access token
36 |
37 | Open the browser [http://localhost:9094/try](http://localhost:9094/try)
38 |
39 | ```
40 | {
41 | "client_id": "222222",
42 | "expires_in": 7195,
43 | "user_id": "000000"
44 | }
45 | ```
46 |
47 | ## Refresh token
48 |
49 | Open the browser [http://localhost:9094/refresh](http://localhost:9094/refresh)
50 |
51 | ```
52 | {
53 | "access_token": "0IIL4_AJN2-SR0JEYZVQWG",
54 | "token_type": "Bearer",
55 | "refresh_token": "AG6-63MLXUEFUV2Q_BLYIW",
56 | "expiry": "2019-01-09T23:03:16.374062+08:00"
57 | }
58 | ```
59 |
60 | ## Password Credentials Grant
61 |
62 | Open the browser [http://localhost:9094/pwd](http://localhost:9094/pwd)
63 |
64 | ```
65 | {
66 | "access_token": "87JT3N6WOWANXVDNZFHY7Q",
67 | "token_type": "Bearer",
68 | "refresh_token": "LDIS6PXAVY-BXHPEDESWNG",
69 | "expiry": "2019-02-12T10:58:43.734902+08:00"
70 | }
71 | ```
72 |
73 | ## Client Credentials Grant
74 |
75 | Open the browser [http://localhost:9094/client](http://localhost:9094/client)
76 |
77 | ```
78 | {
79 | "access_token": "OA6ITALNMDOGD58C0SN-MG",
80 | "token_type": "Bearer",
81 | "expiry": "2019-02-12T11:10:35.864838+08:00"
82 | }
83 | ```
84 |
85 | 
86 | 
87 | 
88 |
--------------------------------------------------------------------------------
/example/client/client.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "crypto/sha256"
6 | "encoding/base64"
7 | "encoding/json"
8 | "fmt"
9 | "io"
10 | "log"
11 | "net/http"
12 | "time"
13 |
14 | "golang.org/x/oauth2"
15 | "golang.org/x/oauth2/clientcredentials"
16 | )
17 |
18 | const (
19 | authServerURL = "http://localhost:9096"
20 | )
21 |
22 | var (
23 | config = oauth2.Config{
24 | ClientID: "222222",
25 | ClientSecret: "22222222",
26 | Scopes: []string{"all"},
27 | RedirectURL: "http://localhost:9094/oauth2",
28 | Endpoint: oauth2.Endpoint{
29 | AuthURL: authServerURL + "/oauth/authorize",
30 | TokenURL: authServerURL + "/oauth/token",
31 | },
32 | }
33 | globalToken *oauth2.Token // Non-concurrent security
34 | )
35 |
36 | func main() {
37 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
38 | u := config.AuthCodeURL("xyz",
39 | oauth2.SetAuthURLParam("code_challenge", genCodeChallengeS256("s256example")),
40 | oauth2.SetAuthURLParam("code_challenge_method", "S256"))
41 | http.Redirect(w, r, u, http.StatusFound)
42 | })
43 |
44 | http.HandleFunc("/oauth2", func(w http.ResponseWriter, r *http.Request) {
45 | r.ParseForm()
46 | state := r.Form.Get("state")
47 | if state != "xyz" {
48 | http.Error(w, "State invalid", http.StatusBadRequest)
49 | return
50 | }
51 | code := r.Form.Get("code")
52 | if code == "" {
53 | http.Error(w, "Code not found", http.StatusBadRequest)
54 | return
55 | }
56 | token, err := config.Exchange(context.Background(), code, oauth2.SetAuthURLParam("code_verifier", "s256example"))
57 | if err != nil {
58 | http.Error(w, err.Error(), http.StatusInternalServerError)
59 | return
60 | }
61 | globalToken = token
62 |
63 | e := json.NewEncoder(w)
64 | e.SetIndent("", " ")
65 | e.Encode(token)
66 | })
67 |
68 | http.HandleFunc("/refresh", func(w http.ResponseWriter, r *http.Request) {
69 | if globalToken == nil {
70 | http.Redirect(w, r, "/", http.StatusFound)
71 | return
72 | }
73 |
74 | globalToken.Expiry = time.Now()
75 | token, err := config.TokenSource(context.Background(), globalToken).Token()
76 | if err != nil {
77 | http.Error(w, err.Error(), http.StatusInternalServerError)
78 | return
79 | }
80 |
81 | globalToken = token
82 | e := json.NewEncoder(w)
83 | e.SetIndent("", " ")
84 | e.Encode(token)
85 | })
86 |
87 | http.HandleFunc("/try", func(w http.ResponseWriter, r *http.Request) {
88 | if globalToken == nil {
89 | http.Redirect(w, r, "/", http.StatusFound)
90 | return
91 | }
92 |
93 | resp, err := http.Get(fmt.Sprintf("%s/test?access_token=%s", authServerURL, globalToken.AccessToken))
94 | if err != nil {
95 | http.Error(w, err.Error(), http.StatusBadRequest)
96 | return
97 | }
98 | defer resp.Body.Close()
99 |
100 | io.Copy(w, resp.Body)
101 | })
102 |
103 | http.HandleFunc("/pwd", func(w http.ResponseWriter, r *http.Request) {
104 | token, err := config.PasswordCredentialsToken(context.Background(), "test", "test")
105 | if err != nil {
106 | http.Error(w, err.Error(), http.StatusInternalServerError)
107 | return
108 | }
109 |
110 | globalToken = token
111 | e := json.NewEncoder(w)
112 | e.SetIndent("", " ")
113 | e.Encode(token)
114 | })
115 |
116 | http.HandleFunc("/client", func(w http.ResponseWriter, r *http.Request) {
117 | cfg := clientcredentials.Config{
118 | ClientID: config.ClientID,
119 | ClientSecret: config.ClientSecret,
120 | TokenURL: config.Endpoint.TokenURL,
121 | }
122 |
123 | token, err := cfg.Token(context.Background())
124 | if err != nil {
125 | http.Error(w, err.Error(), http.StatusInternalServerError)
126 | return
127 | }
128 |
129 | e := json.NewEncoder(w)
130 | e.SetIndent("", " ")
131 | e.Encode(token)
132 | })
133 |
134 | log.Println("Client is running at 9094 port.Please open http://localhost:9094")
135 | log.Fatal(http.ListenAndServe(":9094", nil))
136 | }
137 |
138 | func genCodeChallengeS256(s string) string {
139 | s256 := sha256.Sum256([]byte(s))
140 | return base64.URLEncoding.EncodeToString(s256[:])
141 | }
142 |
--------------------------------------------------------------------------------
/example/server/server.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "flag"
7 | "fmt"
8 | "io"
9 | "log"
10 | "net/http"
11 | "net/http/httputil"
12 | "net/url"
13 | "os"
14 | "time"
15 |
16 | "github.com/go-oauth2/oauth2/v4/generates"
17 |
18 | "github.com/go-oauth2/oauth2/v4/errors"
19 | "github.com/go-oauth2/oauth2/v4/manage"
20 | "github.com/go-oauth2/oauth2/v4/models"
21 | "github.com/go-oauth2/oauth2/v4/server"
22 | "github.com/go-oauth2/oauth2/v4/store"
23 | "github.com/go-session/session/v3"
24 | )
25 |
26 | var (
27 | dumpvar bool
28 | idvar string
29 | secretvar string
30 | domainvar string
31 | portvar int
32 | )
33 |
34 | func init() {
35 | flag.BoolVar(&dumpvar, "d", true, "Dump requests and responses")
36 | flag.StringVar(&idvar, "i", "222222", "The client id being passed in")
37 | flag.StringVar(&secretvar, "s", "22222222", "The client secret being passed in")
38 | flag.StringVar(&domainvar, "r", "http://localhost:9094", "The domain of the redirect url")
39 | flag.IntVar(&portvar, "p", 9096, "the base port for the server")
40 | }
41 |
42 | func main() {
43 | flag.Parse()
44 | if dumpvar {
45 | log.Println("Dumping requests")
46 | }
47 | manager := manage.NewDefaultManager()
48 | manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
49 |
50 | // token store
51 | manager.MustTokenStorage(store.NewMemoryTokenStore())
52 |
53 | // generate jwt access token
54 | // manager.MapAccessGenerate(generates.NewJWTAccessGenerate("", []byte("00000000"), jwt.SigningMethodHS512))
55 | manager.MapAccessGenerate(generates.NewAccessGenerate())
56 |
57 | clientStore := store.NewClientStore()
58 | clientStore.Set(idvar, &models.Client{
59 | ID: idvar,
60 | Secret: secretvar,
61 | Domain: domainvar,
62 | })
63 | manager.MapClientStorage(clientStore)
64 |
65 | srv := server.NewServer(server.NewConfig(), manager)
66 |
67 | srv.SetPasswordAuthorizationHandler(func(ctx context.Context, clientID, username, password string) (userID string, err error) {
68 | if username == "test" && password == "test" {
69 | userID = "test"
70 | } else {
71 | err = errors.New("invalid username or password")
72 | }
73 | return
74 | })
75 |
76 | srv.SetUserAuthorizationHandler(userAuthorizeHandler)
77 |
78 | srv.SetInternalErrorHandler(func(err error) (re *errors.Response) {
79 | log.Println("Internal Error:", err.Error())
80 | return
81 | })
82 |
83 | srv.SetResponseErrorHandler(func(re *errors.Response) {
84 | log.Println("Response Error:", re.Error.Error())
85 | })
86 |
87 | http.HandleFunc("/login", loginHandler)
88 | http.HandleFunc("/auth", authHandler)
89 |
90 | http.HandleFunc("/oauth/authorize", func(w http.ResponseWriter, r *http.Request) {
91 | if dumpvar {
92 | dumpRequest(os.Stdout, "authorize", r)
93 | }
94 |
95 | store, err := session.Start(r.Context(), w, r)
96 | if err != nil {
97 | http.Error(w, err.Error(), http.StatusInternalServerError)
98 | return
99 | }
100 |
101 | var form url.Values
102 | if v, ok := store.Get("ReturnUri"); ok {
103 | form = v.(url.Values)
104 | }
105 | r.Form = form
106 |
107 | store.Delete("ReturnUri")
108 | store.Save()
109 |
110 | err = srv.HandleAuthorizeRequest(w, r)
111 | if err != nil {
112 | http.Error(w, err.Error(), http.StatusBadRequest)
113 | }
114 | })
115 |
116 | http.HandleFunc("/oauth/token", func(w http.ResponseWriter, r *http.Request) {
117 | if dumpvar {
118 | _ = dumpRequest(os.Stdout, "token", r) // Ignore the error
119 | }
120 |
121 | err := srv.HandleTokenRequest(w, r)
122 | if err != nil {
123 | http.Error(w, err.Error(), http.StatusInternalServerError)
124 | }
125 | })
126 |
127 | http.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
128 | if dumpvar {
129 | _ = dumpRequest(os.Stdout, "test", r) // Ignore the error
130 | }
131 | token, err := srv.ValidationBearerToken(r)
132 | if err != nil {
133 | http.Error(w, err.Error(), http.StatusBadRequest)
134 | return
135 | }
136 |
137 | data := map[string]interface{}{
138 | "expires_in": int64(token.GetAccessCreateAt().Add(token.GetAccessExpiresIn()).Sub(time.Now()).Seconds()),
139 | "client_id": token.GetClientID(),
140 | "user_id": token.GetUserID(),
141 | }
142 | e := json.NewEncoder(w)
143 | e.SetIndent("", " ")
144 | e.Encode(data)
145 | })
146 |
147 | log.Printf("Server is running at %d port.\n", portvar)
148 | log.Printf("Point your OAuth client Auth endpoint to %s:%d%s", "http://localhost", portvar, "/oauth/authorize")
149 | log.Printf("Point your OAuth client Token endpoint to %s:%d%s", "http://localhost", portvar, "/oauth/token")
150 | log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", portvar), nil))
151 | }
152 |
153 | func dumpRequest(writer io.Writer, header string, r *http.Request) error {
154 | data, err := httputil.DumpRequest(r, true)
155 | if err != nil {
156 | return err
157 | }
158 | writer.Write([]byte("\n" + header + ": \n"))
159 | writer.Write(data)
160 | return nil
161 | }
162 |
163 | func userAuthorizeHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) {
164 | if dumpvar {
165 | _ = dumpRequest(os.Stdout, "userAuthorizeHandler", r) // Ignore the error
166 | }
167 | store, err := session.Start(r.Context(), w, r)
168 | if err != nil {
169 | return
170 | }
171 |
172 | uid, ok := store.Get("LoggedInUserID")
173 | if !ok {
174 | if r.Form == nil {
175 | r.ParseForm()
176 | }
177 | store.Set("ReturnUri", r.Form)
178 | store.Save()
179 |
180 | w.Header().Set("Location", "/login")
181 | w.WriteHeader(http.StatusFound)
182 | return
183 | }
184 |
185 | userID = uid.(string)
186 | store.Delete("LoggedInUserID")
187 | store.Save()
188 | return
189 | }
190 |
191 | func loginHandler(w http.ResponseWriter, r *http.Request) {
192 | if dumpvar {
193 | _ = dumpRequest(os.Stdout, "login", r) // Ignore the error
194 | }
195 | store, err := session.Start(r.Context(), w, r)
196 | if err != nil {
197 | http.Error(w, err.Error(), http.StatusInternalServerError)
198 | return
199 | }
200 |
201 | if r.Method == "POST" {
202 | if r.Form == nil {
203 | if err := r.ParseForm(); err != nil {
204 | http.Error(w, err.Error(), http.StatusInternalServerError)
205 | return
206 | }
207 | }
208 |
209 | if r.Form.Get("username") == "test" && r.Form.Get("password") == "test" {
210 | store.Set("LoggedInUserID", r.Form.Get("username"))
211 | store.Save()
212 |
213 | w.Header().Set("Location", "/auth")
214 | w.WriteHeader(http.StatusFound)
215 | return
216 | } else {
217 | http.Error(w, "Invalid username or password", http.StatusUnauthorized)
218 | return
219 | }
220 | }
221 | outputHTML(w, r, "static/login.html")
222 | }
223 |
224 | func authHandler(w http.ResponseWriter, r *http.Request) {
225 | if dumpvar {
226 | _ = dumpRequest(os.Stdout, "auth", r) // Ignore the error
227 | }
228 | store, err := session.Start(nil, w, r)
229 | if err != nil {
230 | http.Error(w, err.Error(), http.StatusInternalServerError)
231 | return
232 | }
233 |
234 | if _, ok := store.Get("LoggedInUserID"); !ok {
235 | w.Header().Set("Location", "/login")
236 | w.WriteHeader(http.StatusFound)
237 | return
238 | }
239 |
240 | outputHTML(w, r, "static/auth.html")
241 | }
242 |
243 | func outputHTML(w http.ResponseWriter, req *http.Request, filename string) {
244 | file, err := os.Open(filename)
245 | if err != nil {
246 | http.Error(w, err.Error(), 500)
247 | return
248 | }
249 | defer file.Close()
250 | fi, _ := file.Stat()
251 | http.ServeContent(w, req, file.Name(), fi.ModTime(), file)
252 | }
253 |
--------------------------------------------------------------------------------
/example/server/static/auth.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Auth
6 |
10 |
11 |
12 |
13 |
14 |
15 |
32 |
33 |
34 |
--------------------------------------------------------------------------------
/example/server/static/auth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/go-oauth2/oauth2/bb6e415e9b625f5109f64f9bad0897141cc2f8dd/example/server/static/auth.png
--------------------------------------------------------------------------------
/example/server/static/login.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Login
7 |
8 |
9 |
10 |
11 |
12 |
13 |
27 |
28 |
29 |
--------------------------------------------------------------------------------
/example/server/static/login.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/go-oauth2/oauth2/bb6e415e9b625f5109f64f9bad0897141cc2f8dd/example/server/static/login.png
--------------------------------------------------------------------------------
/example/server/static/token.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/go-oauth2/oauth2/bb6e415e9b625f5109f64f9bad0897141cc2f8dd/example/server/static/token.png
--------------------------------------------------------------------------------
/generate.go:
--------------------------------------------------------------------------------
1 | package oauth2
2 |
3 | import (
4 | "context"
5 | "net/http"
6 | "time"
7 | )
8 |
9 | type (
10 | // GenerateBasic provide the basis of the generated token data
11 | GenerateBasic struct {
12 | Client ClientInfo
13 | UserID string
14 | CreateAt time.Time
15 | TokenInfo TokenInfo
16 | Request *http.Request
17 | }
18 |
19 | // AuthorizeGenerate generate the authorization code interface
20 | AuthorizeGenerate interface {
21 | Token(ctx context.Context, data *GenerateBasic) (code string, err error)
22 | }
23 |
24 | // AccessGenerate generate the access and refresh tokens interface
25 | AccessGenerate interface {
26 | Token(ctx context.Context, data *GenerateBasic, isGenRefresh bool) (access, refresh string, err error)
27 | }
28 | )
29 |
--------------------------------------------------------------------------------
/generates/access.go:
--------------------------------------------------------------------------------
1 | package generates
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/base64"
7 | "strconv"
8 | "strings"
9 |
10 | "github.com/go-oauth2/oauth2/v4"
11 | "github.com/google/uuid"
12 | )
13 |
14 | // NewAccessGenerate create to generate the access token instance
15 | func NewAccessGenerate() *AccessGenerate {
16 | return &AccessGenerate{}
17 | }
18 |
19 | // AccessGenerate generate the access token
20 | type AccessGenerate struct {
21 | }
22 |
23 | // Token based on the UUID generated token
24 | func (ag *AccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) {
25 | buf := bytes.NewBufferString(data.Client.GetID())
26 | buf.WriteString(data.UserID)
27 | buf.WriteString(strconv.FormatInt(data.CreateAt.UnixNano(), 10))
28 |
29 | access := base64.URLEncoding.EncodeToString([]byte(uuid.NewMD5(uuid.Must(uuid.NewRandom()), buf.Bytes()).String()))
30 | access = strings.ToUpper(strings.TrimRight(access, "="))
31 | refresh := ""
32 | if isGenRefresh {
33 | refresh = base64.URLEncoding.EncodeToString([]byte(uuid.NewSHA1(uuid.Must(uuid.NewRandom()), buf.Bytes()).String()))
34 | refresh = strings.ToUpper(strings.TrimRight(refresh, "="))
35 | }
36 |
37 | return access, refresh, nil
38 | }
39 |
--------------------------------------------------------------------------------
/generates/access_test.go:
--------------------------------------------------------------------------------
1 | package generates_test
2 |
3 | import (
4 | "context"
5 | "testing"
6 | "time"
7 |
8 | "github.com/go-oauth2/oauth2/v4"
9 | "github.com/go-oauth2/oauth2/v4/generates"
10 | "github.com/go-oauth2/oauth2/v4/models"
11 |
12 | . "github.com/smartystreets/goconvey/convey"
13 | )
14 |
15 | func TestAccess(t *testing.T) {
16 | Convey("Test Access Generate", t, func() {
17 | data := &oauth2.GenerateBasic{
18 | Client: &models.Client{
19 | ID: "123456",
20 | Secret: "123456",
21 | },
22 | UserID: "000000",
23 | CreateAt: time.Now(),
24 | }
25 | gen := generates.NewAccessGenerate()
26 | access, refresh, err := gen.Token(context.Background(), data, true)
27 | So(err, ShouldBeNil)
28 | So(access, ShouldNotBeEmpty)
29 | So(refresh, ShouldNotBeEmpty)
30 | Println("\nAccess Token:" + access)
31 | Println("Refresh Token:" + refresh)
32 | })
33 | }
34 |
--------------------------------------------------------------------------------
/generates/authorize.go:
--------------------------------------------------------------------------------
1 | package generates
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/base64"
7 | "strings"
8 |
9 | "github.com/go-oauth2/oauth2/v4"
10 | "github.com/google/uuid"
11 | )
12 |
13 | // NewAuthorizeGenerate create to generate the authorize code instance
14 | func NewAuthorizeGenerate() *AuthorizeGenerate {
15 | return &AuthorizeGenerate{}
16 | }
17 |
18 | // AuthorizeGenerate generate the authorize code
19 | type AuthorizeGenerate struct{}
20 |
21 | // Token based on the UUID generated token
22 | func (ag *AuthorizeGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic) (string, error) {
23 | buf := bytes.NewBufferString(data.Client.GetID())
24 | buf.WriteString(data.UserID)
25 | token := uuid.NewMD5(uuid.Must(uuid.NewRandom()), buf.Bytes())
26 | code := base64.URLEncoding.EncodeToString([]byte(token.String()))
27 | code = strings.ToUpper(strings.TrimRight(code, "="))
28 |
29 | return code, nil
30 | }
31 |
--------------------------------------------------------------------------------
/generates/authorize_test.go:
--------------------------------------------------------------------------------
1 | package generates_test
2 |
3 | import (
4 | "context"
5 | "testing"
6 | "time"
7 |
8 | "github.com/go-oauth2/oauth2/v4"
9 | "github.com/go-oauth2/oauth2/v4/generates"
10 | "github.com/go-oauth2/oauth2/v4/models"
11 |
12 | . "github.com/smartystreets/goconvey/convey"
13 | )
14 |
15 | func TestAuthorize(t *testing.T) {
16 | Convey("Test Authorize Generate", t, func() {
17 | data := &oauth2.GenerateBasic{
18 | Client: &models.Client{
19 | ID: "123456",
20 | Secret: "123456",
21 | },
22 | UserID: "000000",
23 | CreateAt: time.Now(),
24 | }
25 | gen := generates.NewAuthorizeGenerate()
26 | code, err := gen.Token(context.Background(), data)
27 | So(err, ShouldBeNil)
28 | So(code, ShouldNotBeEmpty)
29 | Println("\nAuthorize Code:" + code)
30 | })
31 | }
32 |
--------------------------------------------------------------------------------
/generates/jwt_access.go:
--------------------------------------------------------------------------------
1 | package generates
2 |
3 | import (
4 | "context"
5 | "encoding/base64"
6 | "strings"
7 | "time"
8 |
9 | "github.com/go-oauth2/oauth2/v4"
10 | "github.com/go-oauth2/oauth2/v4/errors"
11 | "github.com/golang-jwt/jwt/v5"
12 | "github.com/google/uuid"
13 | )
14 |
15 | // JWTAccessClaims jwt claims
16 | type JWTAccessClaims struct {
17 | jwt.RegisteredClaims
18 | }
19 |
20 | // Valid claims verification
21 | func (a *JWTAccessClaims) Valid() error {
22 | if a.ExpiresAt != nil && time.Unix(a.ExpiresAt.Unix(), 0).Before(time.Now()) {
23 | return errors.ErrInvalidAccessToken
24 | }
25 | return nil
26 | }
27 |
28 | // NewJWTAccessGenerate create to generate the jwt access token instance
29 | func NewJWTAccessGenerate(kid string, key []byte, method jwt.SigningMethod) *JWTAccessGenerate {
30 | return &JWTAccessGenerate{
31 | SignedKeyID: kid,
32 | SignedKey: key,
33 | SignedMethod: method,
34 | }
35 | }
36 |
37 | // JWTAccessGenerate generate the jwt access token
38 | type JWTAccessGenerate struct {
39 | SignedKeyID string
40 | SignedKey []byte
41 | SignedMethod jwt.SigningMethod
42 | }
43 |
44 | // Token based on the UUID generated token
45 | func (a *JWTAccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) {
46 | claims := &JWTAccessClaims{
47 | RegisteredClaims: jwt.RegisteredClaims{
48 | Audience: jwt.ClaimStrings{data.Client.GetID()},
49 | Subject: data.UserID,
50 | ExpiresAt: jwt.NewNumericDate(data.TokenInfo.GetAccessCreateAt().Add(data.TokenInfo.GetAccessExpiresIn())),
51 | },
52 | }
53 |
54 | token := jwt.NewWithClaims(a.SignedMethod, claims)
55 | if a.SignedKeyID != "" {
56 | token.Header["kid"] = a.SignedKeyID
57 | }
58 | var key interface{}
59 | if a.isEs() {
60 | v, err := jwt.ParseECPrivateKeyFromPEM(a.SignedKey)
61 | if err != nil {
62 | return "", "", err
63 | }
64 | key = v
65 | } else if a.isRsOrPS() {
66 | v, err := jwt.ParseRSAPrivateKeyFromPEM(a.SignedKey)
67 | if err != nil {
68 | return "", "", err
69 | }
70 | key = v
71 | } else if a.isHs() {
72 | key = a.SignedKey
73 | } else if a.isEd() {
74 | v, err := jwt.ParseEdPrivateKeyFromPEM(a.SignedKey)
75 | if err != nil {
76 | return "", "", err
77 | }
78 | key = v
79 | } else {
80 | return "", "", errors.New("unsupported sign method")
81 | }
82 |
83 | access, err := token.SignedString(key)
84 | if err != nil {
85 | return "", "", err
86 | }
87 | refresh := ""
88 |
89 | if isGenRefresh {
90 | t := uuid.NewSHA1(uuid.Must(uuid.NewRandom()), []byte(access)).String()
91 | refresh = base64.URLEncoding.EncodeToString([]byte(t))
92 | refresh = strings.ToUpper(strings.TrimRight(refresh, "="))
93 | }
94 |
95 | return access, refresh, nil
96 | }
97 |
98 | func (a *JWTAccessGenerate) isEs() bool {
99 | return strings.HasPrefix(a.SignedMethod.Alg(), "ES")
100 | }
101 |
102 | func (a *JWTAccessGenerate) isRsOrPS() bool {
103 | isRs := strings.HasPrefix(a.SignedMethod.Alg(), "RS")
104 | isPs := strings.HasPrefix(a.SignedMethod.Alg(), "PS")
105 | return isRs || isPs
106 | }
107 |
108 | func (a *JWTAccessGenerate) isHs() bool {
109 | return strings.HasPrefix(a.SignedMethod.Alg(), "HS")
110 | }
111 |
112 | func (a *JWTAccessGenerate) isEd() bool {
113 | return strings.HasPrefix(a.SignedMethod.Alg(), "Ed")
114 | }
115 |
--------------------------------------------------------------------------------
/generates/jwt_access_test.go:
--------------------------------------------------------------------------------
1 | package generates_test
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "testing"
7 | "time"
8 |
9 | "github.com/go-oauth2/oauth2/v4"
10 | "github.com/go-oauth2/oauth2/v4/generates"
11 | "github.com/go-oauth2/oauth2/v4/models"
12 | "github.com/golang-jwt/jwt/v5"
13 |
14 | . "github.com/smartystreets/goconvey/convey"
15 | )
16 |
17 | func TestJWTAccess(t *testing.T) {
18 | Convey("Test JWT Access Generate", t, func() {
19 | data := &oauth2.GenerateBasic{
20 | Client: &models.Client{
21 | ID: "123456",
22 | Secret: "123456",
23 | },
24 | UserID: "000000",
25 | TokenInfo: &models.Token{
26 | AccessCreateAt: time.Now(),
27 | AccessExpiresIn: time.Second * 120,
28 | },
29 | }
30 |
31 | gen := generates.NewJWTAccessGenerate("", []byte("00000000"), jwt.SigningMethodHS512)
32 | access, refresh, err := gen.Token(context.Background(), data, true)
33 | So(err, ShouldBeNil)
34 | So(access, ShouldNotBeEmpty)
35 | So(refresh, ShouldNotBeEmpty)
36 |
37 | token, err := jwt.ParseWithClaims(access, &generates.JWTAccessClaims{}, func(t *jwt.Token) (interface{}, error) {
38 | if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
39 | return nil, fmt.Errorf("parse error")
40 | }
41 | return []byte("00000000"), nil
42 | })
43 | So(err, ShouldBeNil)
44 |
45 | claims, ok := token.Claims.(*generates.JWTAccessClaims)
46 | So(ok, ShouldBeTrue)
47 | So(token.Valid, ShouldBeTrue)
48 | aud, err := claims.GetAudience()
49 | So(err, ShouldBeNil)
50 | So(len(aud), ShouldEqual, 1)
51 | So(aud[0], ShouldEqual, "123456")
52 | So(claims.Subject, ShouldEqual, "000000")
53 | })
54 | }
55 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/go-oauth2/oauth2/v4
2 |
3 | go 1.13
4 |
5 | require (
6 | github.com/ajg/form v1.5.1 // indirect
7 | github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072 // indirect
8 | github.com/fatih/structs v1.1.0 // indirect
9 | github.com/gavv/httpexpect v2.0.0+incompatible
10 | github.com/go-session/session/v3 v3.2.1
11 | github.com/golang-jwt/jwt v3.2.2+incompatible
12 | github.com/google/go-querystring v1.0.0 // indirect
13 | github.com/google/uuid v1.1.1
14 | github.com/gorilla/websocket v1.4.2 // indirect
15 | github.com/imkira/go-interpol v1.1.0 // indirect
16 | github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88 // indirect
17 | github.com/mattn/go-colorable v0.1.7 // indirect
18 | github.com/moul/http2curl v1.0.0 // indirect
19 | github.com/onsi/ginkgo v1.13.0 // indirect
20 | github.com/sergi/go-diff v1.1.0 // indirect
21 | github.com/smartystreets/goconvey v1.6.4
22 | github.com/tidwall/buntdb v1.1.2
23 | github.com/tidwall/gjson v1.12.1 // indirect
24 | github.com/valyala/fasthttp v1.34.0 // indirect
25 | github.com/xeipuuv/gojsonschema v1.2.0 // indirect
26 | github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0 // indirect
27 | github.com/yudai/gojsondiff v1.0.0 // indirect
28 | github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 // indirect
29 | github.com/yudai/pp v2.0.1+incompatible // indirect
30 | golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
31 | google.golang.org/appengine v1.6.6 // indirect
32 | gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 // indirect
33 | )
34 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
2 | github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
3 | github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
4 | github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
5 | github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
6 | github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6 h1:FCLDGi1EmB7JzjVVYNZiqc/zAJj2BQ5M0lfkVOxbfs8=
7 | github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6/go.mod h1:5FoAH5xUHHCMDvQPy1rnj8moqLkLHFaDVBjHhcFwEi0=
8 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
9 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
10 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
11 | github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072 h1:DddqAaWDpywytcG8w/qoQ5sAN8X12d3Z3koB0C3Rxsc=
12 | github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072/go.mod h1:duJ4Jxv5lDcvg4QuQr0oowTf7dz4/CR8NtyCooz9HL8=
13 | github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo=
14 | github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M=
15 | github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
16 | github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
17 | github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
18 | github.com/gavv/httpexpect v2.0.0+incompatible h1:1X9kcRshkSKEjNJJxX9Y9mQ5BRfbxU5kORdjhlA1yX8=
19 | github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv1M62hOWzwo5OXotisrKc=
20 | github.com/go-session/session/v3 v3.2.1 h1:APQf5JFW84+bhbqRjEZO8J+IppSgT1jMQTFI/XVyIFY=
21 | github.com/go-session/session/v3 v3.2.1/go.mod h1:RftEBbyuzqkNCAxIrCLJe+rfBqB/4G11qxq9KYKrx4M=
22 | github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
23 | github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
24 | github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
25 | github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
26 | github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
27 | github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
28 | github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
29 | github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
30 | github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
31 | github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0=
32 | github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
33 | github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
34 | github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
35 | github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
36 | github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
37 | github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=
38 | github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
39 | github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
40 | github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
41 | github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
42 | github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 h1:l5lAOZEym3oK3SQ2HBHWsJUfbNBiTXJDeW2QDxw9AQ0=
43 | github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
44 | github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
45 | github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
46 | github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
47 | github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk=
48 | github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA=
49 | github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
50 | github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
51 | github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88 h1:uC1QfSlInpQF+M0ao65imhwqKnz3Q2z/d8PWZRMQvDM=
52 | github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88/go.mod h1:3w7q1U84EfirKl04SVQ/s7nPm1ZPhiXd34z40TNz36k=
53 | github.com/klauspost/compress v1.15.0 h1:xqfchp4whNFxn5A4XFyyYtitiWI8Hy5EW59jEwcyL6U=
54 | github.com/klauspost/compress v1.15.0/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
55 | github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
56 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
57 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
58 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
59 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
60 | github.com/mattn/go-colorable v0.1.7 h1:bQGKb3vps/j0E9GfJQ03JyhRuxsvdAanXlT9BTw3mdw=
61 | github.com/mattn/go-colorable v0.1.7/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
62 | github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY=
63 | github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
64 | github.com/moul/http2curl v1.0.0 h1:dRMWoAtb+ePxMlLkrCbAqh4TlPHXvoGUSQ323/9Zahs=
65 | github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ=
66 | github.com/nxadm/tail v1.4.4 h1:DQuhQpB1tVlglWS2hLQ5OV6B5r8aGxSrPc5Qo6uTN78=
67 | github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
68 | github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
69 | github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk=
70 | github.com/onsi/ginkgo v1.13.0 h1:M76yO2HkZASFjXL0HSoZJ1AYEmQxNJmY41Jx1zNUq1Y=
71 | github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1lskyM0=
72 | github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
73 | github.com/onsi/gomega v1.10.1 h1:o0+MgICZLuZ7xjH7Vx6zS/zcu93/BEp1VwkIW1mEXCE=
74 | github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
75 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
76 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
77 | github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw=
78 | github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0=
79 | github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
80 | github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
81 | github.com/smartystreets/assertions v1.1.0 h1:MkTeG1DMwsrdH7QtLXy5W+fUxWq+vmb6cLmyJ7aRtF0=
82 | github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo=
83 | github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
84 | github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
85 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
86 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
87 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
88 | github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
89 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
90 | github.com/tidwall/btree v0.0.0-20191029221954-400434d76274 h1:G6Z6HvJuPjG6XfNGi/feOATzeJrfgTNJY+rGrHbA04E=
91 | github.com/tidwall/btree v0.0.0-20191029221954-400434d76274/go.mod h1:huei1BkDWJ3/sLXmO+bsCNELL+Bp2Kks9OLyQFkzvA8=
92 | github.com/tidwall/buntdb v1.1.2 h1:noCrqQXL9EKMtcdwJcmuVKSEjqu1ua99RHHgbLTEHRo=
93 | github.com/tidwall/buntdb v1.1.2/go.mod h1:xAzi36Hir4FarpSHyfuZ6JzPJdjRZ8QlLZSntE2mqlI=
94 | github.com/tidwall/gjson v1.3.4/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls=
95 | github.com/tidwall/gjson v1.12.1 h1:ikuZsLdhr8Ws0IdROXUS1Gi4v9Z4pGqpX/CvJkxvfpo=
96 | github.com/tidwall/gjson v1.12.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
97 | github.com/tidwall/grect v0.0.0-20161006141115-ba9a043346eb h1:5NSYaAdrnblKByzd7XByQEJVT8+9v0W/tIY0Oo4OwrE=
98 | github.com/tidwall/grect v0.0.0-20161006141115-ba9a043346eb/go.mod h1:lKYYLFIr9OIgdgrtgkZ9zgRxRdvPYsExnYBsEAd8W5M=
99 | github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E=
100 | github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
101 | github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
102 | github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
103 | github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
104 | github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
105 | github.com/tidwall/rtree v0.0.0-20180113144539-6cd427091e0e h1:+NL1GDIUOKxVfbp2KoJQD9cTQ6dyP2co9q4yzmT9FZo=
106 | github.com/tidwall/rtree v0.0.0-20180113144539-6cd427091e0e/go.mod h1:/h+UnNGt0IhNNJLkGikcdcJqm66zGD/uJGMRxK/9+Ao=
107 | github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563 h1:Otn9S136ELckZ3KKDyCkxapfufrqDqwmGjcHfAyXRrE=
108 | github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563/go.mod h1:mLqSmt7Dv/CNneF2wfcChfN1rvapyQr01LGKnKex0DQ=
109 | github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
110 | github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
111 | github.com/valyala/fasthttp v1.34.0 h1:d3AAQJ2DRcxJYHm7OXNXtXt2as1vMDfxeIcFvhmGGm4=
112 | github.com/valyala/fasthttp v1.34.0/go.mod h1:epZA5N+7pY6ZaEKRmstzOuYJx9HI8DI1oaCGZpdH4h0=
113 | github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
114 | github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c=
115 | github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
116 | github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0=
117 | github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
118 | github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74=
119 | github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
120 | github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0 h1:6fRhSjgLCkTD3JnJxvaJ4Sj+TYblw757bqYgZaOq5ZY=
121 | github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0/go.mod h1:/LWChgwKmvncFJFHJ7Gvn9wZArjbV5/FppcK2fKk/tI=
122 | github.com/yudai/gojsondiff v1.0.0 h1:27cbfqXLVEJ1o8I6v3y9lg8Ydm53EKqHXAOMxEGlCOA=
123 | github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FBNExI05xg=
124 | github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 h1:BHyfKlQyqbsFN5p3IfnEUduWvb9is428/nNb5L3U01M=
125 | github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM=
126 | github.com/yudai/pp v2.0.1+incompatible h1:Q4//iY4pNF6yPLZIigmvcl7k/bPgrcTPIFIcmawg5bI=
127 | github.com/yudai/pp v2.0.1+incompatible/go.mod h1:PuxR/8QJ7cyCkFp/aUDS+JY727OFEZkTdatxwunjIkc=
128 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
129 | golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
130 | golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
131 | golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
132 | golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
133 | golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
134 | golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
135 | golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
136 | golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
137 | golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc=
138 | golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
139 | golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw=
140 | golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
141 | golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
142 | golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
143 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
144 | golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
145 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
146 | golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
147 | golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
148 | golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
149 | golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
150 | golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
151 | golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
152 | golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
153 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
154 | golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
155 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
156 | golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
157 | golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
158 | golang.org/x/sys v0.0.0-20221010170243-090e33056c14 h1:k5II8e6QD8mITdi+okbbmR/cIyEbeXLBhy5Ha4nevyc=
159 | golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
160 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
161 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
162 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
163 | golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
164 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
165 | golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
166 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
167 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
168 | golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
169 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
170 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
171 | google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
172 | google.golang.org/appengine v1.6.6 h1:lMO5rYAqUxkmaj76jAkRUvt5JZgFymx/+Q5Mzfivuhc=
173 | google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
174 | google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
175 | google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
176 | google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
177 | google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
178 | google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
179 | google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM=
180 | google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
181 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
182 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
183 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
184 | gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
185 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
186 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
187 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
188 | gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
189 | gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU=
190 | gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
191 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
192 | gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 h1:tQIYjPdBoyREyB9XMu+nnTclpTYkz2zFM+lzLJFO4gQ=
193 | gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
194 |
--------------------------------------------------------------------------------
/go.test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -e
4 | echo "" > coverage.txt
5 |
6 | for d in $(go list ./... | grep -v vendor); do
7 | go test -race -coverprofile=profile.out -covermode=atomic "$d"
8 | if [ -f profile.out ]; then
9 | cat profile.out >> coverage.txt
10 | rm profile.out
11 | fi
12 | done
--------------------------------------------------------------------------------
/manage.go:
--------------------------------------------------------------------------------
1 | package oauth2
2 |
3 | import (
4 | "context"
5 | "net/http"
6 | "time"
7 | )
8 |
9 | // TokenGenerateRequest provide to generate the token request parameters
10 | type TokenGenerateRequest struct {
11 | ClientID string
12 | ClientSecret string
13 | UserID string
14 | RedirectURI string
15 | Scope string
16 | Code string
17 | CodeChallenge string
18 | CodeChallengeMethod CodeChallengeMethod
19 | Refresh string
20 | CodeVerifier string
21 | AccessTokenExp time.Duration
22 | Request *http.Request
23 | }
24 |
25 | // Manager authorization management interface
26 | type Manager interface {
27 | // get the client information
28 | GetClient(ctx context.Context, clientID string) (cli ClientInfo, err error)
29 |
30 | // generate the authorization token(code)
31 | GenerateAuthToken(ctx context.Context, rt ResponseType, tgr *TokenGenerateRequest) (authToken TokenInfo, err error)
32 |
33 | // generate the access token
34 | GenerateAccessToken(ctx context.Context, gt GrantType, tgr *TokenGenerateRequest) (accessToken TokenInfo, err error)
35 |
36 | // refreshing an access token
37 | RefreshAccessToken(ctx context.Context, tgr *TokenGenerateRequest) (accessToken TokenInfo, err error)
38 |
39 | // use the access token to delete the token information
40 | RemoveAccessToken(ctx context.Context, access string) (err error)
41 |
42 | // use the refresh token to delete the token information
43 | RemoveRefreshToken(ctx context.Context, refresh string) (err error)
44 |
45 | // according to the access token for corresponding token information
46 | LoadAccessToken(ctx context.Context, access string) (ti TokenInfo, err error)
47 |
48 | // according to the refresh token for corresponding token information
49 | LoadRefreshToken(ctx context.Context, refresh string) (ti TokenInfo, err error)
50 | }
51 |
--------------------------------------------------------------------------------
/manage/config.go:
--------------------------------------------------------------------------------
1 | package manage
2 |
3 | import "time"
4 |
5 | // Config authorization configuration parameters
6 | type Config struct {
7 | // access token expiration time, 0 means it doesn't expire
8 | AccessTokenExp time.Duration
9 | // refresh token expiration time, 0 means it doesn't expire
10 | RefreshTokenExp time.Duration
11 | // whether to generate the refreshing token
12 | IsGenerateRefresh bool
13 | }
14 |
15 | // RefreshingConfig refreshing token config
16 | type RefreshingConfig struct {
17 | // access token expiration time, 0 means it doesn't expire
18 | AccessTokenExp time.Duration
19 | // refresh token expiration time, 0 means it doesn't expire
20 | RefreshTokenExp time.Duration
21 | // whether to generate the refreshing token
22 | IsGenerateRefresh bool
23 | // whether to reset the refreshing create time
24 | IsResetRefreshTime bool
25 | // whether to remove access token
26 | IsRemoveAccess bool
27 | // whether to remove refreshing token
28 | IsRemoveRefreshing bool
29 | }
30 |
31 | // default configs
32 | var (
33 | DefaultCodeExp = time.Minute * 10
34 | DefaultAuthorizeCodeTokenCfg = &Config{AccessTokenExp: time.Hour * 2, RefreshTokenExp: time.Hour * 24 * 3, IsGenerateRefresh: true}
35 | DefaultImplicitTokenCfg = &Config{AccessTokenExp: time.Hour * 1}
36 | DefaultPasswordTokenCfg = &Config{AccessTokenExp: time.Hour * 2, RefreshTokenExp: time.Hour * 24 * 7, IsGenerateRefresh: true}
37 | DefaultClientTokenCfg = &Config{AccessTokenExp: time.Hour * 2}
38 | DefaultRefreshTokenCfg = &RefreshingConfig{IsGenerateRefresh: true, IsRemoveAccess: true, IsRemoveRefreshing: true}
39 | )
40 |
--------------------------------------------------------------------------------
/manage/manage_test.go:
--------------------------------------------------------------------------------
1 | package manage_test
2 |
3 | import (
4 | "context"
5 | "testing"
6 | "time"
7 |
8 | "github.com/go-oauth2/oauth2/v4"
9 | "github.com/go-oauth2/oauth2/v4/manage"
10 | "github.com/go-oauth2/oauth2/v4/models"
11 | "github.com/go-oauth2/oauth2/v4/store"
12 |
13 | . "github.com/smartystreets/goconvey/convey"
14 | )
15 |
16 | func TestManager(t *testing.T) {
17 | Convey("Manager test", t, func() {
18 | manager := manage.NewDefaultManager()
19 | ctx := context.Background()
20 |
21 | manager.MustTokenStorage(store.NewMemoryTokenStore())
22 |
23 | clientStore := store.NewClientStore()
24 | _ = clientStore.Set("1", &models.Client{
25 | ID: "1",
26 | Secret: "11",
27 | Domain: "http://localhost",
28 | })
29 | manager.MapClientStorage(clientStore)
30 |
31 | tgr := &oauth2.TokenGenerateRequest{
32 | ClientID: "1",
33 | UserID: "123456",
34 | RedirectURI: "http://localhost/oauth2",
35 | Scope: "all",
36 | }
37 |
38 | Convey("GetClient test", func() {
39 | cli, err := manager.GetClient(ctx, "1")
40 | So(err, ShouldBeNil)
41 | So(cli.GetSecret(), ShouldEqual, "11")
42 | })
43 |
44 | Convey("Token test", func() {
45 | testManager(tgr, manager)
46 | })
47 |
48 | Convey("zero expiration access token test", func() {
49 | testZeroAccessExpirationManager(tgr, manager)
50 | testCannotRequestZeroExpirationAccessTokens(tgr, manager)
51 | })
52 |
53 | Convey("zero expiration refresh token test", func() {
54 | testZeroRefreshExpirationManager(tgr, manager)
55 | })
56 | })
57 | }
58 |
59 | func testManager(tgr *oauth2.TokenGenerateRequest, manager oauth2.Manager) {
60 | ctx := context.Background()
61 | cti, err := manager.GenerateAuthToken(ctx, oauth2.Code, tgr)
62 | So(err, ShouldBeNil)
63 |
64 | code := cti.GetCode()
65 | So(code, ShouldNotBeEmpty)
66 |
67 | atParams := &oauth2.TokenGenerateRequest{
68 | ClientID: tgr.ClientID,
69 | ClientSecret: "11",
70 | RedirectURI: tgr.RedirectURI,
71 | Code: code,
72 | }
73 | ati, err := manager.GenerateAccessToken(ctx, oauth2.AuthorizationCode, atParams)
74 | So(err, ShouldBeNil)
75 |
76 | accessToken, refreshToken := ati.GetAccess(), ati.GetRefresh()
77 | So(accessToken, ShouldNotBeEmpty)
78 | So(refreshToken, ShouldNotBeEmpty)
79 |
80 | ainfo, err := manager.LoadAccessToken(ctx, accessToken)
81 | So(err, ShouldBeNil)
82 | So(ainfo.GetClientID(), ShouldEqual, atParams.ClientID)
83 |
84 | arinfo, err := manager.LoadRefreshToken(ctx, accessToken)
85 | So(err, ShouldNotBeNil)
86 | So(arinfo, ShouldBeNil)
87 |
88 | rainfo, err := manager.LoadAccessToken(ctx, refreshToken)
89 | So(err, ShouldNotBeNil)
90 | So(rainfo, ShouldBeNil)
91 |
92 | rinfo, err := manager.LoadRefreshToken(ctx, refreshToken)
93 | So(err, ShouldBeNil)
94 | So(rinfo.GetClientID(), ShouldEqual, atParams.ClientID)
95 |
96 | refreshParams := &oauth2.TokenGenerateRequest{
97 | Refresh: refreshToken,
98 | Scope: "owner",
99 | }
100 | rti, err := manager.RefreshAccessToken(ctx, refreshParams)
101 | So(err, ShouldBeNil)
102 |
103 | refreshAT := rti.GetAccess()
104 | So(refreshAT, ShouldNotBeEmpty)
105 |
106 | _, err = manager.LoadAccessToken(ctx, accessToken)
107 | So(err, ShouldNotBeNil)
108 |
109 | refreshAInfo, err := manager.LoadAccessToken(ctx, refreshAT)
110 | So(err, ShouldBeNil)
111 | So(refreshAInfo.GetScope(), ShouldEqual, "owner")
112 |
113 | err = manager.RemoveAccessToken(ctx, refreshAT)
114 | So(err, ShouldBeNil)
115 |
116 | _, err = manager.LoadAccessToken(ctx, refreshAT)
117 | So(err, ShouldNotBeNil)
118 |
119 | err = manager.RemoveRefreshToken(ctx, refreshToken)
120 | So(err, ShouldBeNil)
121 |
122 | _, err = manager.LoadRefreshToken(ctx, refreshToken)
123 | So(err, ShouldNotBeNil)
124 | }
125 |
126 | func testZeroAccessExpirationManager(tgr *oauth2.TokenGenerateRequest, manager oauth2.Manager) {
127 | ctx := context.Background()
128 | config := manage.Config{
129 | AccessTokenExp: 0, // Set explicitly as we're testing 0 (no) expiration
130 | IsGenerateRefresh: true,
131 | }
132 | m, ok := manager.(*manage.Manager)
133 | So(ok, ShouldBeTrue)
134 | m.SetAuthorizeCodeTokenCfg(&config)
135 |
136 | cti, err := manager.GenerateAuthToken(ctx, oauth2.Code, tgr)
137 | So(err, ShouldBeNil)
138 |
139 | code := cti.GetCode()
140 | So(code, ShouldNotBeEmpty)
141 |
142 | atParams := &oauth2.TokenGenerateRequest{
143 | ClientID: tgr.ClientID,
144 | ClientSecret: "11",
145 | RedirectURI: tgr.RedirectURI,
146 | Code: code,
147 | }
148 | ati, err := manager.GenerateAccessToken(ctx, oauth2.AuthorizationCode, atParams)
149 | So(err, ShouldBeNil)
150 |
151 | accessToken, refreshToken := ati.GetAccess(), ati.GetRefresh()
152 | So(accessToken, ShouldNotBeEmpty)
153 | So(refreshToken, ShouldNotBeEmpty)
154 |
155 | tokenInfo, err := manager.LoadAccessToken(ctx, accessToken)
156 | So(err, ShouldBeNil)
157 | So(tokenInfo, ShouldNotBeNil)
158 | So(tokenInfo.GetAccess(), ShouldEqual, accessToken)
159 | So(tokenInfo.GetAccessExpiresIn(), ShouldEqual, 0)
160 | }
161 |
162 | func testCannotRequestZeroExpirationAccessTokens(tgr *oauth2.TokenGenerateRequest, manager oauth2.Manager) {
163 | ctx := context.Background()
164 | config := manage.Config{
165 | AccessTokenExp: time.Hour * 5,
166 | }
167 | m, ok := manager.(*manage.Manager)
168 | So(ok, ShouldBeTrue)
169 | m.SetAuthorizeCodeTokenCfg(&config)
170 |
171 | cti, err := manager.GenerateAuthToken(ctx, oauth2.Code, tgr)
172 | So(err, ShouldBeNil)
173 |
174 | code := cti.GetCode()
175 | So(code, ShouldNotBeEmpty)
176 |
177 | atParams := &oauth2.TokenGenerateRequest{
178 | ClientID: tgr.ClientID,
179 | ClientSecret: "11",
180 | RedirectURI: tgr.RedirectURI,
181 | AccessTokenExp: 0, // requesting token without expiration
182 | Code: code,
183 | }
184 | ati, err := manager.GenerateAccessToken(ctx, oauth2.AuthorizationCode, atParams)
185 | So(err, ShouldBeNil)
186 |
187 | accessToken := ati.GetAccess()
188 | So(accessToken, ShouldNotBeEmpty)
189 | So(ati.GetAccessExpiresIn(), ShouldEqual, time.Hour*5)
190 | }
191 |
192 | func testZeroRefreshExpirationManager(tgr *oauth2.TokenGenerateRequest, manager oauth2.Manager) {
193 | ctx := context.Background()
194 | config := manage.Config{
195 | RefreshTokenExp: 0, // Set explicitly as we're testing 0 (no) expiration
196 | IsGenerateRefresh: true,
197 | }
198 | m, ok := manager.(*manage.Manager)
199 | So(ok, ShouldBeTrue)
200 | m.SetAuthorizeCodeTokenCfg(&config)
201 |
202 | cti, err := manager.GenerateAuthToken(ctx, oauth2.Code, tgr)
203 | So(err, ShouldBeNil)
204 |
205 | code := cti.GetCode()
206 | So(code, ShouldNotBeEmpty)
207 |
208 | atParams := &oauth2.TokenGenerateRequest{
209 | ClientID: tgr.ClientID,
210 | ClientSecret: "11",
211 | RedirectURI: tgr.RedirectURI,
212 | AccessTokenExp: time.Hour,
213 | Code: code,
214 | }
215 | ati, err := manager.GenerateAccessToken(ctx, oauth2.AuthorizationCode, atParams)
216 | So(err, ShouldBeNil)
217 |
218 | accessToken, refreshToken := ati.GetAccess(), ati.GetRefresh()
219 | So(accessToken, ShouldNotBeEmpty)
220 | So(refreshToken, ShouldNotBeEmpty)
221 |
222 | tokenInfo, err := manager.LoadRefreshToken(ctx, refreshToken)
223 | So(err, ShouldBeNil)
224 | So(tokenInfo, ShouldNotBeNil)
225 | So(tokenInfo.GetRefresh(), ShouldEqual, refreshToken)
226 | So(tokenInfo.GetRefreshExpiresIn(), ShouldEqual, 0)
227 |
228 | // LoadAccessToken also checks refresh expiry
229 | tokenInfo, err = manager.LoadAccessToken(ctx, accessToken)
230 | So(err, ShouldBeNil)
231 | So(tokenInfo, ShouldNotBeNil)
232 | So(tokenInfo.GetRefresh(), ShouldEqual, refreshToken)
233 | So(tokenInfo.GetRefreshExpiresIn(), ShouldEqual, 0)
234 | }
235 |
--------------------------------------------------------------------------------
/manage/manager.go:
--------------------------------------------------------------------------------
1 | package manage
2 |
3 | import (
4 | "context"
5 | "net/url"
6 | "time"
7 |
8 | "github.com/go-oauth2/oauth2/v4"
9 | "github.com/go-oauth2/oauth2/v4/errors"
10 | "github.com/go-oauth2/oauth2/v4/generates"
11 | "github.com/go-oauth2/oauth2/v4/models"
12 | )
13 |
14 | // NewDefaultManager create to default authorization management instance
15 | func NewDefaultManager() *Manager {
16 | m := NewManager()
17 | // default implementation
18 | m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate())
19 | m.MapAccessGenerate(generates.NewAccessGenerate())
20 |
21 | return m
22 | }
23 |
24 | // NewManager create to authorization management instance
25 | func NewManager() *Manager {
26 | return &Manager{
27 | gtcfg: make(map[oauth2.GrantType]*Config),
28 | validateURI: DefaultValidateURI,
29 | }
30 | }
31 |
32 | // Manager provide authorization management
33 | type Manager struct {
34 | codeExp time.Duration
35 | gtcfg map[oauth2.GrantType]*Config
36 | rcfg *RefreshingConfig
37 | validateURI ValidateURIHandler
38 | extractExtension ExtractExtensionHandler
39 | authorizeGenerate oauth2.AuthorizeGenerate
40 | accessGenerate oauth2.AccessGenerate
41 | tokenStore oauth2.TokenStore
42 | clientStore oauth2.ClientStore
43 | }
44 |
45 | // get grant type config
46 | func (m *Manager) grantConfig(gt oauth2.GrantType) *Config {
47 | if c, ok := m.gtcfg[gt]; ok && c != nil {
48 | return c
49 | }
50 | switch gt {
51 | case oauth2.AuthorizationCode:
52 | return DefaultAuthorizeCodeTokenCfg
53 | case oauth2.Implicit:
54 | return DefaultImplicitTokenCfg
55 | case oauth2.PasswordCredentials:
56 | return DefaultPasswordTokenCfg
57 | case oauth2.ClientCredentials:
58 | return DefaultClientTokenCfg
59 | }
60 | return &Config{}
61 | }
62 |
63 | // SetAuthorizeCodeExp set the authorization code expiration time
64 | func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) {
65 | m.codeExp = exp
66 | }
67 |
68 | // SetAuthorizeCodeTokenCfg set the authorization code grant token config
69 | func (m *Manager) SetAuthorizeCodeTokenCfg(cfg *Config) {
70 | m.gtcfg[oauth2.AuthorizationCode] = cfg
71 | }
72 |
73 | // SetImplicitTokenCfg set the implicit grant token config
74 | func (m *Manager) SetImplicitTokenCfg(cfg *Config) {
75 | m.gtcfg[oauth2.Implicit] = cfg
76 | }
77 |
78 | // SetPasswordTokenCfg set the password grant token config
79 | func (m *Manager) SetPasswordTokenCfg(cfg *Config) {
80 | m.gtcfg[oauth2.PasswordCredentials] = cfg
81 | }
82 |
83 | // SetClientTokenCfg set the client grant token config
84 | func (m *Manager) SetClientTokenCfg(cfg *Config) {
85 | m.gtcfg[oauth2.ClientCredentials] = cfg
86 | }
87 |
88 | // SetRefreshTokenCfg set the refreshing token config
89 | func (m *Manager) SetRefreshTokenCfg(cfg *RefreshingConfig) {
90 | m.rcfg = cfg
91 | }
92 |
93 | // SetValidateURIHandler set the validates that RedirectURI is contained in baseURI
94 | func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) {
95 | m.validateURI = handler
96 | }
97 |
98 | // SetExtractExtensionHandler set the token extension extractor
99 | func (m *Manager) SetExtractExtensionHandler(handler ExtractExtensionHandler) {
100 | m.extractExtension = handler
101 | }
102 |
103 | // MapAuthorizeGenerate mapping the authorize code generate interface
104 | func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) {
105 | m.authorizeGenerate = gen
106 | }
107 |
108 | // MapAccessGenerate mapping the access token generate interface
109 | func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) {
110 | m.accessGenerate = gen
111 | }
112 |
113 | // MapClientStorage mapping the client store interface
114 | func (m *Manager) MapClientStorage(stor oauth2.ClientStore) {
115 | m.clientStore = stor
116 | }
117 |
118 | // MustClientStorage mandatory mapping the client store interface
119 | func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) {
120 | if err != nil {
121 | panic(err.Error())
122 | }
123 | m.clientStore = stor
124 | }
125 |
126 | // MapTokenStorage mapping the token store interface
127 | func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) {
128 | m.tokenStore = stor
129 | }
130 |
131 | // MustTokenStorage mandatory mapping the token store interface
132 | func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) {
133 | if err != nil {
134 | panic(err)
135 | }
136 | m.tokenStore = stor
137 | }
138 |
139 | // GetClient get the client information
140 | func (m *Manager) GetClient(ctx context.Context, clientID string) (cli oauth2.ClientInfo, err error) {
141 | cli, err = m.clientStore.GetByID(ctx, clientID)
142 | if err != nil {
143 | return
144 | } else if cli == nil {
145 | err = errors.ErrInvalidClient
146 | }
147 | return
148 | }
149 |
150 | // GenerateAuthToken generate the authorization token(code)
151 | func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
152 | cli, err := m.GetClient(ctx, tgr.ClientID)
153 | if err != nil {
154 | return nil, err
155 | } else if tgr.RedirectURI != "" {
156 | if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil {
157 | return nil, err
158 | }
159 | }
160 |
161 | ti := models.NewToken()
162 | if m.extractExtension != nil {
163 | m.extractExtension(tgr, ti)
164 | }
165 | ti.SetClientID(tgr.ClientID)
166 | ti.SetUserID(tgr.UserID)
167 | ti.SetRedirectURI(tgr.RedirectURI)
168 | ti.SetScope(tgr.Scope)
169 |
170 | createAt := time.Now()
171 | td := &oauth2.GenerateBasic{
172 | Client: cli,
173 | UserID: tgr.UserID,
174 | CreateAt: createAt,
175 | TokenInfo: ti,
176 | Request: tgr.Request,
177 | }
178 | switch rt {
179 | case oauth2.Code:
180 | codeExp := m.codeExp
181 | if codeExp == 0 {
182 | codeExp = DefaultCodeExp
183 | }
184 | ti.SetCodeCreateAt(createAt)
185 | ti.SetCodeExpiresIn(codeExp)
186 | if exp := tgr.AccessTokenExp; exp > 0 {
187 | ti.SetAccessExpiresIn(exp)
188 | }
189 | if tgr.CodeChallenge != "" {
190 | ti.SetCodeChallenge(tgr.CodeChallenge)
191 | ti.SetCodeChallengeMethod(tgr.CodeChallengeMethod)
192 | }
193 |
194 | tv, err := m.authorizeGenerate.Token(ctx, td)
195 | if err != nil {
196 | return nil, err
197 | }
198 | ti.SetCode(tv)
199 | case oauth2.Token:
200 | // set access token expires
201 | icfg := m.grantConfig(oauth2.Implicit)
202 | aexp := icfg.AccessTokenExp
203 | if exp := tgr.AccessTokenExp; exp > 0 {
204 | aexp = exp
205 | }
206 | ti.SetAccessCreateAt(createAt)
207 | ti.SetAccessExpiresIn(aexp)
208 |
209 | if icfg.IsGenerateRefresh {
210 | ti.SetRefreshCreateAt(createAt)
211 | ti.SetRefreshExpiresIn(icfg.RefreshTokenExp)
212 | }
213 |
214 | tv, rv, err := m.accessGenerate.Token(ctx, td, icfg.IsGenerateRefresh)
215 | if err != nil {
216 | return nil, err
217 | }
218 | ti.SetAccess(tv)
219 |
220 | if rv != "" {
221 | ti.SetRefresh(rv)
222 | }
223 | }
224 |
225 | err = m.tokenStore.Create(ctx, ti)
226 | if err != nil {
227 | return nil, err
228 | }
229 | return ti, nil
230 | }
231 |
232 | // get authorization code data
233 | func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
234 | ti, err := m.tokenStore.GetByCode(ctx, code)
235 | if err != nil {
236 | return nil, err
237 | } else if ti == nil || ti.GetCode() != code || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) {
238 | err = errors.ErrInvalidAuthorizeCode
239 | return nil, errors.ErrInvalidAuthorizeCode
240 | }
241 | return ti, nil
242 | }
243 |
244 | // delete authorization code data
245 | func (m *Manager) delAuthorizationCode(ctx context.Context, code string) error {
246 | return m.tokenStore.RemoveByCode(ctx, code)
247 | }
248 |
249 | // get and delete authorization code data
250 | func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
251 | code := tgr.Code
252 | ti, err := m.getAuthorizationCode(ctx, code)
253 | if err != nil {
254 | return nil, err
255 | } else if ti.GetClientID() != tgr.ClientID {
256 | return nil, errors.ErrInvalidAuthorizeCode
257 | } else if codeURI := ti.GetRedirectURI(); codeURI != "" && codeURI != tgr.RedirectURI {
258 | return nil, errors.ErrInvalidAuthorizeCode
259 | }
260 |
261 | err = m.delAuthorizationCode(ctx, code)
262 | if err != nil {
263 | return nil, err
264 | }
265 | return ti, nil
266 | }
267 |
268 | func (m *Manager) validateCodeChallenge(ti oauth2.TokenInfo, ver string) error {
269 | cc := ti.GetCodeChallenge()
270 | // early return
271 | if cc == "" && ver == "" {
272 | return nil
273 | }
274 | if cc == "" {
275 | return errors.ErrMissingCodeVerifier
276 | }
277 | if ver == "" {
278 | return errors.ErrMissingCodeVerifier
279 | }
280 | ccm := ti.GetCodeChallengeMethod()
281 | if ccm.String() == "" {
282 | ccm = oauth2.CodeChallengePlain
283 | }
284 | if !ccm.Validate(cc, ver) {
285 | return errors.ErrInvalidCodeChallenge
286 | }
287 | return nil
288 | }
289 |
290 | // GenerateAccessToken generate the access token
291 | func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
292 | cli, err := m.GetClient(ctx, tgr.ClientID)
293 | if err != nil {
294 | return nil, err
295 | }
296 | if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok {
297 | if !cliPass.VerifyPassword(tgr.ClientSecret) {
298 | return nil, errors.ErrInvalidClient
299 | }
300 | } else if len(cli.GetSecret()) > 0 && tgr.ClientSecret != cli.GetSecret() {
301 | return nil, errors.ErrInvalidClient
302 | }
303 | if tgr.RedirectURI != "" {
304 | if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil {
305 | return nil, err
306 | }
307 | }
308 |
309 | if gt == oauth2.ClientCredentials && cli.IsPublic() == true {
310 | return nil, errors.ErrInvalidClient
311 | }
312 |
313 | var extension url.Values
314 |
315 | if gt == oauth2.AuthorizationCode {
316 | ti, err := m.getAndDelAuthorizationCode(ctx, tgr)
317 | if err != nil {
318 | return nil, err
319 | }
320 | if err := m.validateCodeChallenge(ti, tgr.CodeVerifier); err != nil {
321 | return nil, err
322 | }
323 | tgr.UserID = ti.GetUserID()
324 | tgr.Scope = ti.GetScope()
325 | if exp := ti.GetAccessExpiresIn(); exp > 0 {
326 | tgr.AccessTokenExp = exp
327 | }
328 | if eti, ok := ti.(oauth2.ExtendableTokenInfo); ok {
329 | extension = eti.GetExtension()
330 | }
331 | }
332 |
333 | ti := models.NewToken()
334 | ti.SetExtension(extension)
335 | if m.extractExtension != nil {
336 | m.extractExtension(tgr, ti)
337 | }
338 | ti.SetClientID(tgr.ClientID)
339 | ti.SetUserID(tgr.UserID)
340 | ti.SetRedirectURI(tgr.RedirectURI)
341 | ti.SetScope(tgr.Scope)
342 |
343 | createAt := time.Now()
344 | ti.SetAccessCreateAt(createAt)
345 |
346 | // set access token expires
347 | gcfg := m.grantConfig(gt)
348 | aexp := gcfg.AccessTokenExp
349 | if exp := tgr.AccessTokenExp; exp > 0 {
350 | aexp = exp
351 | }
352 | ti.SetAccessExpiresIn(aexp)
353 | if gcfg.IsGenerateRefresh {
354 | ti.SetRefreshCreateAt(createAt)
355 | ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp)
356 | }
357 |
358 | td := &oauth2.GenerateBasic{
359 | Client: cli,
360 | UserID: tgr.UserID,
361 | CreateAt: createAt,
362 | TokenInfo: ti,
363 | Request: tgr.Request,
364 | }
365 |
366 | av, rv, err := m.accessGenerate.Token(ctx, td, gcfg.IsGenerateRefresh)
367 | if err != nil {
368 | return nil, err
369 | }
370 | ti.SetAccess(av)
371 |
372 | if rv != "" {
373 | ti.SetRefresh(rv)
374 | }
375 |
376 | err = m.tokenStore.Create(ctx, ti)
377 | if err != nil {
378 | return nil, err
379 | }
380 |
381 | return ti, nil
382 | }
383 |
384 | // RefreshAccessToken refreshing an access token
385 | func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
386 | ti, err := m.LoadRefreshToken(ctx, tgr.Refresh)
387 | if err != nil {
388 | return nil, err
389 | }
390 |
391 | cli, err := m.GetClient(ctx, ti.GetClientID())
392 | if err != nil {
393 | return nil, err
394 | }
395 |
396 | oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh()
397 |
398 | td := &oauth2.GenerateBasic{
399 | Client: cli,
400 | UserID: ti.GetUserID(),
401 | CreateAt: time.Now(),
402 | TokenInfo: ti,
403 | Request: tgr.Request,
404 | }
405 |
406 | rcfg := DefaultRefreshTokenCfg
407 | if v := m.rcfg; v != nil {
408 | rcfg = v
409 | }
410 |
411 | ti.SetAccessCreateAt(td.CreateAt)
412 | if v := rcfg.AccessTokenExp; v > 0 {
413 | ti.SetAccessExpiresIn(v)
414 | }
415 |
416 | if v := rcfg.RefreshTokenExp; v > 0 {
417 | ti.SetRefreshExpiresIn(v)
418 | }
419 |
420 | if rcfg.IsResetRefreshTime {
421 | ti.SetRefreshCreateAt(td.CreateAt)
422 | }
423 |
424 | if scope := tgr.Scope; scope != "" {
425 | ti.SetScope(scope)
426 | }
427 |
428 | tv, rv, err := m.accessGenerate.Token(ctx, td, rcfg.IsGenerateRefresh)
429 | if err != nil {
430 | return nil, err
431 | }
432 |
433 | ti.SetAccess(tv)
434 | if rv != "" {
435 | ti.SetRefresh(rv)
436 | }
437 |
438 | if err := m.tokenStore.Create(ctx, ti); err != nil {
439 | return nil, err
440 | }
441 |
442 | if rcfg.IsRemoveAccess {
443 | // remove the old access token
444 | if err := m.tokenStore.RemoveByAccess(ctx, oldAccess); err != nil {
445 | return nil, err
446 | }
447 | }
448 |
449 | if rcfg.IsRemoveRefreshing && rv != "" {
450 | // remove the old refresh token
451 | if err := m.tokenStore.RemoveByRefresh(ctx, oldRefresh); err != nil {
452 | return nil, err
453 | }
454 | }
455 |
456 | if rv == "" {
457 | ti.SetRefresh("")
458 | ti.SetRefreshCreateAt(time.Now())
459 | ti.SetRefreshExpiresIn(0)
460 | }
461 |
462 | return ti, nil
463 | }
464 |
465 | // RemoveAccessToken use the access token to delete the token information
466 | func (m *Manager) RemoveAccessToken(ctx context.Context, access string) error {
467 | if access == "" {
468 | return errors.ErrInvalidAccessToken
469 | }
470 | return m.tokenStore.RemoveByAccess(ctx, access)
471 | }
472 |
473 | // RemoveRefreshToken use the refresh token to delete the token information
474 | func (m *Manager) RemoveRefreshToken(ctx context.Context, refresh string) error {
475 | if refresh == "" {
476 | return errors.ErrInvalidAccessToken
477 | }
478 | return m.tokenStore.RemoveByRefresh(ctx, refresh)
479 | }
480 |
481 | // LoadAccessToken according to the access token for corresponding token information
482 | func (m *Manager) LoadAccessToken(ctx context.Context, access string) (oauth2.TokenInfo, error) {
483 | if access == "" {
484 | return nil, errors.ErrInvalidAccessToken
485 | }
486 |
487 | ct := time.Now()
488 | ti, err := m.tokenStore.GetByAccess(ctx, access)
489 | if err != nil {
490 | return nil, err
491 | } else if ti == nil || ti.GetAccess() != access {
492 | return nil, errors.ErrInvalidAccessToken
493 | } else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 &&
494 | ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) {
495 | return nil, errors.ErrExpiredRefreshToken
496 | } else if ti.GetAccessExpiresIn() != 0 &&
497 | ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) {
498 | return nil, errors.ErrExpiredAccessToken
499 | }
500 | return ti, nil
501 | }
502 |
503 | // LoadRefreshToken according to the refresh token for corresponding token information
504 | func (m *Manager) LoadRefreshToken(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
505 | if refresh == "" {
506 | return nil, errors.ErrInvalidRefreshToken
507 | }
508 |
509 | ti, err := m.tokenStore.GetByRefresh(ctx, refresh)
510 | if err != nil {
511 | return nil, err
512 | } else if ti == nil || ti.GetRefresh() != refresh {
513 | return nil, errors.ErrInvalidRefreshToken
514 | } else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire
515 | ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) {
516 | return nil, errors.ErrExpiredRefreshToken
517 | }
518 | return ti, nil
519 | }
520 |
--------------------------------------------------------------------------------
/manage/util.go:
--------------------------------------------------------------------------------
1 | package manage
2 |
3 | import (
4 | "github.com/go-oauth2/oauth2/v4"
5 | "net/url"
6 | "strings"
7 |
8 | "github.com/go-oauth2/oauth2/v4/errors"
9 | )
10 |
11 | type (
12 | // ValidateURIHandler validates that redirectURI is contained in baseURI
13 | ValidateURIHandler func(baseURI, redirectURI string) error
14 | ExtractExtensionHandler func(*oauth2.TokenGenerateRequest, oauth2.ExtendableTokenInfo)
15 | )
16 |
17 | // DefaultValidateURI validates that redirectURI is contained in baseURI
18 | func DefaultValidateURI(baseURI string, redirectURI string) error {
19 | base, err := url.Parse(baseURI)
20 | if err != nil {
21 | return err
22 | }
23 |
24 | redirect, err := url.Parse(redirectURI)
25 | if err != nil {
26 | return err
27 | }
28 | if !strings.HasSuffix(redirect.Host, base.Host) {
29 | return errors.ErrInvalidRedirectURI
30 | }
31 | return nil
32 | }
33 |
--------------------------------------------------------------------------------
/manage/util_test.go:
--------------------------------------------------------------------------------
1 | package manage_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/go-oauth2/oauth2/v4/manage"
7 |
8 | . "github.com/smartystreets/goconvey/convey"
9 | )
10 |
11 | func TestUtil(t *testing.T) {
12 | Convey("Util Test", t, func() {
13 | Convey("ValidateURI Test", func() {
14 | err := manage.DefaultValidateURI("http://www.example.com", "http://www.example.com/cb?code=xxx")
15 | So(err, ShouldBeNil)
16 | })
17 | })
18 | }
19 |
--------------------------------------------------------------------------------
/model.go:
--------------------------------------------------------------------------------
1 | package oauth2
2 |
3 | import (
4 | "net/url"
5 | "time"
6 | )
7 |
8 | type (
9 | // ClientInfo the client information model interface
10 | ClientInfo interface {
11 | GetID() string
12 | GetSecret() string
13 | GetDomain() string
14 | IsPublic() bool
15 | GetUserID() string
16 | }
17 |
18 | // ClientPasswordVerifier the password handler interface
19 | ClientPasswordVerifier interface {
20 | VerifyPassword(string) bool
21 | }
22 |
23 | // TokenInfo the token information model interface
24 | TokenInfo interface {
25 | New() TokenInfo
26 |
27 | GetClientID() string
28 | SetClientID(string)
29 | GetUserID() string
30 | SetUserID(string)
31 | GetRedirectURI() string
32 | SetRedirectURI(string)
33 | GetScope() string
34 | SetScope(string)
35 |
36 | GetCode() string
37 | SetCode(string)
38 | GetCodeCreateAt() time.Time
39 | SetCodeCreateAt(time.Time)
40 | GetCodeExpiresIn() time.Duration
41 | SetCodeExpiresIn(time.Duration)
42 | GetCodeChallenge() string
43 | SetCodeChallenge(string)
44 | GetCodeChallengeMethod() CodeChallengeMethod
45 | SetCodeChallengeMethod(CodeChallengeMethod)
46 |
47 | GetAccess() string
48 | SetAccess(string)
49 | GetAccessCreateAt() time.Time
50 | SetAccessCreateAt(time.Time)
51 | GetAccessExpiresIn() time.Duration
52 | SetAccessExpiresIn(time.Duration)
53 |
54 | GetRefresh() string
55 | SetRefresh(string)
56 | GetRefreshCreateAt() time.Time
57 | SetRefreshCreateAt(time.Time)
58 | GetRefreshExpiresIn() time.Duration
59 | SetRefreshExpiresIn(time.Duration)
60 | }
61 |
62 | ExtendableTokenInfo interface {
63 | TokenInfo
64 | GetExtension() url.Values
65 | SetExtension(url.Values)
66 | }
67 | )
68 |
--------------------------------------------------------------------------------
/models/client.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | // Client client model
4 | type Client struct {
5 | ID string
6 | Secret string
7 | Domain string
8 | Public bool
9 | UserID string
10 | }
11 |
12 | // GetID client id
13 | func (c *Client) GetID() string {
14 | return c.ID
15 | }
16 |
17 | // GetSecret client secret
18 | func (c *Client) GetSecret() string {
19 | return c.Secret
20 | }
21 |
22 | // GetDomain client domain
23 | func (c *Client) GetDomain() string {
24 | return c.Domain
25 | }
26 |
27 | // IsPublic public
28 | func (c *Client) IsPublic() bool {
29 | return c.Public
30 | }
31 |
32 | // GetUserID user id
33 | func (c *Client) GetUserID() string {
34 | return c.UserID
35 | }
36 |
--------------------------------------------------------------------------------
/models/token.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import (
4 | "net/url"
5 | "time"
6 |
7 | "github.com/go-oauth2/oauth2/v4"
8 | )
9 |
10 | // NewToken create to token model instance
11 | func NewToken() *Token {
12 | return &Token{Extension: make(url.Values)}
13 | }
14 |
15 | // Token token model
16 | type Token struct {
17 | ClientID string `bson:"ClientID"`
18 | UserID string `bson:"UserID"`
19 | RedirectURI string `bson:"RedirectURI"`
20 | Scope string `bson:"Scope"`
21 | Code string `bson:"Code"`
22 | CodeChallenge string `bson:"CodeChallenge"`
23 | CodeChallengeMethod string `bson:"CodeChallengeMethod"`
24 | CodeCreateAt time.Time `bson:"CodeCreateAt"`
25 | CodeExpiresIn time.Duration `bson:"CodeExpiresIn"`
26 | Access string `bson:"Access"`
27 | AccessCreateAt time.Time `bson:"AccessCreateAt"`
28 | AccessExpiresIn time.Duration `bson:"AccessExpiresIn"`
29 | Refresh string `bson:"Refresh"`
30 | RefreshCreateAt time.Time `bson:"RefreshCreateAt"`
31 | RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"`
32 | Extension url.Values `bson:"Extension"`
33 | }
34 |
35 | // New create to token model instance
36 | func (t *Token) New() oauth2.TokenInfo {
37 | return NewToken()
38 | }
39 |
40 | // GetClientID the client id
41 | func (t *Token) GetClientID() string {
42 | return t.ClientID
43 | }
44 |
45 | // SetClientID the client id
46 | func (t *Token) SetClientID(clientID string) {
47 | t.ClientID = clientID
48 | }
49 |
50 | // GetUserID the user id
51 | func (t *Token) GetUserID() string {
52 | return t.UserID
53 | }
54 |
55 | // SetUserID the user id
56 | func (t *Token) SetUserID(userID string) {
57 | t.UserID = userID
58 | }
59 |
60 | // GetRedirectURI redirect URI
61 | func (t *Token) GetRedirectURI() string {
62 | return t.RedirectURI
63 | }
64 |
65 | // SetRedirectURI redirect URI
66 | func (t *Token) SetRedirectURI(redirectURI string) {
67 | t.RedirectURI = redirectURI
68 | }
69 |
70 | // GetScope get scope of authorization
71 | func (t *Token) GetScope() string {
72 | return t.Scope
73 | }
74 |
75 | // SetScope get scope of authorization
76 | func (t *Token) SetScope(scope string) {
77 | t.Scope = scope
78 | }
79 |
80 | // GetCode authorization code
81 | func (t *Token) GetCode() string {
82 | return t.Code
83 | }
84 |
85 | // SetCode authorization code
86 | func (t *Token) SetCode(code string) {
87 | t.Code = code
88 | }
89 |
90 | // GetCodeCreateAt create Time
91 | func (t *Token) GetCodeCreateAt() time.Time {
92 | return t.CodeCreateAt
93 | }
94 |
95 | // SetCodeCreateAt create Time
96 | func (t *Token) SetCodeCreateAt(createAt time.Time) {
97 | t.CodeCreateAt = createAt
98 | }
99 |
100 | // GetCodeExpiresIn the lifetime in seconds of the authorization code
101 | func (t *Token) GetCodeExpiresIn() time.Duration {
102 | return t.CodeExpiresIn
103 | }
104 |
105 | // SetCodeExpiresIn the lifetime in seconds of the authorization code
106 | func (t *Token) SetCodeExpiresIn(exp time.Duration) {
107 | t.CodeExpiresIn = exp
108 | }
109 |
110 | // GetCodeChallenge challenge code
111 | func (t *Token) GetCodeChallenge() string {
112 | return t.CodeChallenge
113 | }
114 |
115 | // SetCodeChallenge challenge code
116 | func (t *Token) SetCodeChallenge(code string) {
117 | t.CodeChallenge = code
118 | }
119 |
120 | // GetCodeChallengeMethod challenge method
121 | func (t *Token) GetCodeChallengeMethod() oauth2.CodeChallengeMethod {
122 | return oauth2.CodeChallengeMethod(t.CodeChallengeMethod)
123 | }
124 |
125 | // SetCodeChallengeMethod challenge method
126 | func (t *Token) SetCodeChallengeMethod(method oauth2.CodeChallengeMethod) {
127 | t.CodeChallengeMethod = string(method)
128 | }
129 |
130 | // GetAccess access Token
131 | func (t *Token) GetAccess() string {
132 | return t.Access
133 | }
134 |
135 | // SetAccess access Token
136 | func (t *Token) SetAccess(access string) {
137 | t.Access = access
138 | }
139 |
140 | // GetAccessCreateAt create Time
141 | func (t *Token) GetAccessCreateAt() time.Time {
142 | return t.AccessCreateAt
143 | }
144 |
145 | // SetAccessCreateAt create Time
146 | func (t *Token) SetAccessCreateAt(createAt time.Time) {
147 | t.AccessCreateAt = createAt
148 | }
149 |
150 | // GetAccessExpiresIn the lifetime in seconds of the access token
151 | func (t *Token) GetAccessExpiresIn() time.Duration {
152 | return t.AccessExpiresIn
153 | }
154 |
155 | // SetAccessExpiresIn the lifetime in seconds of the access token
156 | func (t *Token) SetAccessExpiresIn(exp time.Duration) {
157 | t.AccessExpiresIn = exp
158 | }
159 |
160 | // GetRefresh refresh Token
161 | func (t *Token) GetRefresh() string {
162 | return t.Refresh
163 | }
164 |
165 | // SetRefresh refresh Token
166 | func (t *Token) SetRefresh(refresh string) {
167 | t.Refresh = refresh
168 | }
169 |
170 | // GetRefreshCreateAt create Time
171 | func (t *Token) GetRefreshCreateAt() time.Time {
172 | return t.RefreshCreateAt
173 | }
174 |
175 | // SetRefreshCreateAt create Time
176 | func (t *Token) SetRefreshCreateAt(createAt time.Time) {
177 | t.RefreshCreateAt = createAt
178 | }
179 |
180 | // GetRefreshExpiresIn the lifetime in seconds of the refresh token
181 | func (t *Token) GetRefreshExpiresIn() time.Duration {
182 | return t.RefreshExpiresIn
183 | }
184 |
185 | // SetRefreshExpiresIn the lifetime in seconds of the refresh token
186 | func (t *Token) SetRefreshExpiresIn(exp time.Duration) {
187 | t.RefreshExpiresIn = exp
188 | }
189 |
190 | // GetExtension extension of token
191 | func (t *Token) GetExtension() url.Values {
192 | return t.Extension
193 | }
194 |
195 | // SetExtension set extension of token
196 | func (t *Token) SetExtension(e url.Values) {
197 | t.Extension = e
198 | }
199 |
--------------------------------------------------------------------------------
/server/config.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "net/http"
5 | "time"
6 |
7 | "github.com/go-oauth2/oauth2/v4"
8 | )
9 |
10 | // Config configuration parameters
11 | type Config struct {
12 | TokenType string // token type
13 | AllowGetAccessRequest bool // to allow GET requests for the token
14 | AllowedResponseTypes []oauth2.ResponseType // allow the authorization type
15 | AllowedGrantTypes []oauth2.GrantType // allow the grant type
16 | AllowedCodeChallengeMethods []oauth2.CodeChallengeMethod
17 | ForcePKCE bool
18 | }
19 |
20 | // NewConfig create to configuration instance
21 | func NewConfig() *Config {
22 | return &Config{
23 | TokenType: "Bearer",
24 | AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code, oauth2.Token},
25 | AllowedGrantTypes: []oauth2.GrantType{
26 | oauth2.AuthorizationCode,
27 | oauth2.PasswordCredentials,
28 | oauth2.ClientCredentials,
29 | oauth2.Refreshing,
30 | },
31 | AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{
32 | oauth2.CodeChallengePlain,
33 | oauth2.CodeChallengeS256,
34 | },
35 | }
36 | }
37 |
38 | // AuthorizeRequest authorization request
39 | type AuthorizeRequest struct {
40 | ResponseType oauth2.ResponseType
41 | ClientID string
42 | Scope string
43 | RedirectURI string
44 | State string
45 | UserID string
46 | CodeChallenge string
47 | CodeChallengeMethod oauth2.CodeChallengeMethod
48 | AccessTokenExp time.Duration
49 | Request *http.Request
50 | }
51 |
--------------------------------------------------------------------------------
/server/handler.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "context"
5 | "net/http"
6 | "strings"
7 | "time"
8 |
9 | "github.com/go-oauth2/oauth2/v4"
10 | "github.com/go-oauth2/oauth2/v4/errors"
11 | )
12 |
13 | type (
14 | // ClientInfoHandler get client info from request
15 | ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error)
16 |
17 | // ClientAuthorizedHandler check the client allows to use this authorization grant type
18 | ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error)
19 |
20 | // ClientScopeHandler check the client allows to use scope
21 | ClientScopeHandler func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error)
22 |
23 | // UserAuthorizationHandler get user id from request authorization
24 | UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error)
25 |
26 | // PasswordAuthorizationHandler get user id from username and password
27 | PasswordAuthorizationHandler func(ctx context.Context, clientID, username, password string) (userID string, err error)
28 |
29 | // RefreshingScopeHandler check the scope of the refreshing token
30 | RefreshingScopeHandler func(tgr *oauth2.TokenGenerateRequest, oldScope string) (allowed bool, err error)
31 |
32 | // RefreshingValidationHandler check if refresh_token is still valid. eg no revocation or other
33 | RefreshingValidationHandler func(ti oauth2.TokenInfo) (allowed bool, err error)
34 |
35 | // ResponseErrorHandler response error handing
36 | ResponseErrorHandler func(re *errors.Response)
37 |
38 | // InternalErrorHandler internal error handing
39 | InternalErrorHandler func(err error) (re *errors.Response)
40 |
41 | // PreRedirectErrorHandler is used to override "redirect-on-error" behavior
42 | PreRedirectErrorHandler func(w http.ResponseWriter, req *AuthorizeRequest, err error) error
43 |
44 | // AuthorizeScopeHandler set the authorized scope
45 | AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error)
46 |
47 | // AccessTokenExpHandler set expiration date for the access token
48 | AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error)
49 |
50 | // ExtensionFieldsHandler in response to the access token with the extension of the field
51 | ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{})
52 |
53 | // ResponseTokenHandler response token handling
54 | ResponseTokenHandler func(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error
55 |
56 | // Handler to fetch the refresh token from the request
57 | RefreshTokenResolveHandler func(r *http.Request) (string, error)
58 |
59 | // Handler to fetch the access token from the request
60 | AccessTokenResolveHandler func(r *http.Request) (string, bool)
61 | )
62 |
63 | // ClientFormHandler get client data from form
64 | func ClientFormHandler(r *http.Request) (string, string, error) {
65 | clientID := r.Form.Get("client_id")
66 | if clientID == "" {
67 | return "", "", errors.ErrInvalidClient
68 | }
69 | clientSecret := r.Form.Get("client_secret")
70 | return clientID, clientSecret, nil
71 | }
72 |
73 | // ClientBasicHandler get client data from basic authorization
74 | func ClientBasicHandler(r *http.Request) (string, string, error) {
75 | username, password, ok := r.BasicAuth()
76 | if !ok {
77 | return "", "", errors.ErrInvalidClient
78 | }
79 | return username, password, nil
80 | }
81 |
82 | func RefreshTokenFormResolveHandler(r *http.Request) (string, error) {
83 | rt := r.FormValue("refresh_token")
84 | if rt == "" {
85 | return "", errors.ErrInvalidRequest
86 | }
87 |
88 | return rt, nil
89 | }
90 |
91 | func RefreshTokenCookieResolveHandler(r *http.Request) (string, error) {
92 | c, err := r.Cookie("refresh_token")
93 | if err != nil {
94 | return "", errors.ErrInvalidRequest
95 | }
96 |
97 | return c.Value, nil
98 | }
99 |
100 | func AccessTokenDefaultResolveHandler(r *http.Request) (string, bool) {
101 | token := ""
102 | auth := r.Header.Get("Authorization")
103 | prefix := "Bearer "
104 |
105 | if auth != "" && strings.HasPrefix(auth, prefix) {
106 | token = auth[len(prefix):]
107 | } else {
108 | token = r.FormValue("access_token")
109 | }
110 |
111 | return token, token != ""
112 | }
113 |
114 | func AccessTokenCookieResolveHandler(r *http.Request) (string, bool) {
115 | c, err := r.Cookie("access_token")
116 | if err != nil {
117 | return "", false
118 | }
119 |
120 | return c.Value, true
121 | }
122 |
--------------------------------------------------------------------------------
/server/handler_test.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "net/http"
5 | "net/http/httptest"
6 | "net/url"
7 | "strings"
8 | "testing"
9 | "time"
10 |
11 | "github.com/go-oauth2/oauth2/v4/errors"
12 | . "github.com/smartystreets/goconvey/convey"
13 | )
14 |
15 | func TestRefreshTokenFormResolveHandler(t *testing.T) {
16 | Convey("Correct Request", t, func() {
17 | f := url.Values{}
18 | f.Add("refresh_token", "test_token")
19 |
20 | r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
21 | r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
22 |
23 | token, err := RefreshTokenFormResolveHandler(r)
24 | So(err, ShouldBeNil)
25 | So(token, ShouldEqual, "test_token")
26 | })
27 |
28 | Convey("Missing Refresh Token", t, func() {
29 | r := httptest.NewRequest("POST", "/", nil)
30 | r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
31 |
32 | token, err := RefreshTokenFormResolveHandler(r)
33 | So(err, ShouldBeError, errors.ErrInvalidRequest)
34 | So(token, ShouldBeEmpty)
35 | })
36 | }
37 |
38 | func TestRefreshTokenCookieResolveHandler(t *testing.T) {
39 | Convey("Correct Request", t, func() {
40 | r := httptest.NewRequest(http.MethodPost, "/", nil)
41 | r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
42 | r.AddCookie(&http.Cookie{
43 | Name: "refresh_token",
44 | Value: "test_token",
45 | HttpOnly: true,
46 | Path: "/",
47 | Domain: ".example.com",
48 | Expires: time.Now().Add(time.Hour),
49 | })
50 |
51 | token, err := RefreshTokenCookieResolveHandler(r)
52 | So(err, ShouldBeNil)
53 | So(token, ShouldEqual, "test_token")
54 | })
55 |
56 | Convey("Missing Refresh Token", t, func() {
57 | r := httptest.NewRequest("POST", "/", nil)
58 | r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
59 |
60 | token, err := RefreshTokenCookieResolveHandler(r)
61 | So(err, ShouldBeError, errors.ErrInvalidRequest)
62 | So(token, ShouldBeEmpty)
63 | })
64 | }
65 |
66 | func TestAccessTokenDefaultHandler(t *testing.T) {
67 | Convey("Request Has Header", t, func() {
68 | r := httptest.NewRequest(http.MethodPost, "/", nil)
69 | r.Header.Add("Authorization", "Bearer test_token")
70 |
71 | token, ok := AccessTokenDefaultResolveHandler(r)
72 | So(ok, ShouldBeTrue)
73 | So(token, ShouldEqual, "test_token")
74 | })
75 |
76 | Convey("Request Has FormValue", t, func() {
77 | f := url.Values{}
78 | f.Add("access_token", "test_token")
79 | r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
80 | r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
81 |
82 | token, ok := AccessTokenDefaultResolveHandler(r)
83 | So(ok, ShouldBeTrue)
84 | So(token, ShouldEqual, "test_token")
85 | })
86 |
87 | Convey("Request Has Nothing", t, func() {
88 | r := httptest.NewRequest(http.MethodPost, "/", nil)
89 |
90 | token, ok := AccessTokenDefaultResolveHandler(r)
91 | So(ok, ShouldBeFalse)
92 | So(token, ShouldBeEmpty)
93 | })
94 | }
95 |
96 | func TestAccessTokenCookieHandler(t *testing.T) {
97 | Convey("Request Has Cookie", t, func() {
98 | r := httptest.NewRequest(http.MethodPost, "/", nil)
99 | r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
100 | r.AddCookie(&http.Cookie{
101 | Name: "access_token",
102 | Value: "test_token",
103 | HttpOnly: true,
104 | Path: "/",
105 | Domain: ".example.com",
106 | Expires: time.Now().Add(time.Hour),
107 | })
108 |
109 | token, ok := AccessTokenCookieResolveHandler(r)
110 | So(ok, ShouldBeTrue)
111 | So(token, ShouldEqual, "test_token")
112 | })
113 |
114 | Convey("Request Has No Cookie", t, func() {
115 | r := httptest.NewRequest(http.MethodPost, "/", nil)
116 |
117 | token, ok := AccessTokenCookieResolveHandler(r)
118 | So(ok, ShouldBeFalse)
119 | So(token, ShouldBeEmpty)
120 | })
121 | }
122 |
--------------------------------------------------------------------------------
/server/server.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "fmt"
7 | "net/http"
8 | "net/url"
9 | "time"
10 |
11 | "github.com/go-oauth2/oauth2/v4"
12 | "github.com/go-oauth2/oauth2/v4/errors"
13 | )
14 |
15 | // NewDefaultServer create a default authorization server
16 | func NewDefaultServer(manager oauth2.Manager) *Server {
17 | return NewServer(NewConfig(), manager)
18 | }
19 |
20 | // NewServer create authorization server
21 | func NewServer(cfg *Config, manager oauth2.Manager) *Server {
22 | srv := &Server{
23 | Config: cfg,
24 | Manager: manager,
25 | }
26 |
27 | // default handlers
28 | srv.ClientInfoHandler = ClientBasicHandler
29 | srv.RefreshTokenResolveHandler = RefreshTokenFormResolveHandler
30 | srv.AccessTokenResolveHandler = AccessTokenDefaultResolveHandler
31 |
32 | srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) {
33 | return "", errors.ErrAccessDenied
34 | }
35 |
36 | srv.PasswordAuthorizationHandler = func(ctx context.Context, clientID, username, password string) (string, error) {
37 | return "", errors.ErrAccessDenied
38 | }
39 | return srv
40 | }
41 |
42 | // Server Provide authorization server
43 | type Server struct {
44 | Config *Config
45 | Manager oauth2.Manager
46 | ClientInfoHandler ClientInfoHandler
47 | ClientAuthorizedHandler ClientAuthorizedHandler
48 | ClientScopeHandler ClientScopeHandler
49 | UserAuthorizationHandler UserAuthorizationHandler
50 | PasswordAuthorizationHandler PasswordAuthorizationHandler
51 | RefreshingValidationHandler RefreshingValidationHandler
52 | PreRedirectErrorHandler PreRedirectErrorHandler
53 | RefreshingScopeHandler RefreshingScopeHandler
54 | ResponseErrorHandler ResponseErrorHandler
55 | InternalErrorHandler InternalErrorHandler
56 | ExtensionFieldsHandler ExtensionFieldsHandler
57 | AccessTokenExpHandler AccessTokenExpHandler
58 | AuthorizeScopeHandler AuthorizeScopeHandler
59 | ResponseTokenHandler ResponseTokenHandler
60 | RefreshTokenResolveHandler RefreshTokenResolveHandler
61 | AccessTokenResolveHandler AccessTokenResolveHandler
62 | }
63 |
64 | func (s *Server) handleError(w http.ResponseWriter, req *AuthorizeRequest, err error) error {
65 | if fn := s.PreRedirectErrorHandler; fn != nil {
66 | return fn(w, req, err)
67 | }
68 |
69 | return s.redirectError(w, req, err)
70 | }
71 |
72 | func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error {
73 | if req == nil {
74 | return err
75 | }
76 |
77 | data, _, _ := s.GetErrorData(err)
78 | return s.redirect(w, req, data)
79 | }
80 |
81 | func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error {
82 | uri, err := s.GetRedirectURI(req, data)
83 | if err != nil {
84 | return err
85 | }
86 |
87 | w.Header().Set("Location", uri)
88 | w.WriteHeader(302)
89 | return nil
90 | }
91 |
92 | func (s *Server) tokenError(w http.ResponseWriter, err error) error {
93 | data, statusCode, header := s.GetErrorData(err)
94 | return s.token(w, data, header, statusCode)
95 | }
96 |
97 | func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error {
98 | if fn := s.ResponseTokenHandler; fn != nil {
99 | return fn(w, data, header, statusCode...)
100 | }
101 | w.Header().Set("Content-Type", "application/json;charset=UTF-8")
102 | w.Header().Set("Cache-Control", "no-store")
103 | w.Header().Set("Pragma", "no-cache")
104 |
105 | for key := range header {
106 | w.Header().Set(key, header.Get(key))
107 | }
108 |
109 | status := http.StatusOK
110 | if len(statusCode) > 0 && statusCode[0] > 0 {
111 | status = statusCode[0]
112 | }
113 |
114 | w.WriteHeader(status)
115 | return json.NewEncoder(w).Encode(data)
116 | }
117 |
118 | // GetRedirectURI get redirect uri
119 | func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) {
120 | u, err := url.Parse(req.RedirectURI)
121 | if err != nil {
122 | return "", err
123 | }
124 |
125 | q := u.Query()
126 | if req.State != "" {
127 | q.Set("state", req.State)
128 | }
129 |
130 | for k, v := range data {
131 | q.Set(k, fmt.Sprint(v))
132 | }
133 |
134 | switch req.ResponseType {
135 | case oauth2.Code:
136 | u.RawQuery = q.Encode()
137 | case oauth2.Token:
138 | u.RawQuery = ""
139 | fragment, err := url.QueryUnescape(q.Encode())
140 | if err != nil {
141 | return "", err
142 | }
143 | u.Fragment = fragment
144 | }
145 |
146 | return u.String(), nil
147 | }
148 |
149 | // CheckResponseType check allows response type
150 | func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool {
151 | for _, art := range s.Config.AllowedResponseTypes {
152 | if art == rt {
153 | return true
154 | }
155 | }
156 | return false
157 | }
158 |
159 | // CheckCodeChallengeMethod checks for allowed code challenge method
160 | func (s *Server) CheckCodeChallengeMethod(ccm oauth2.CodeChallengeMethod) bool {
161 | for _, c := range s.Config.AllowedCodeChallengeMethods {
162 | if c == ccm {
163 | return true
164 | }
165 | }
166 | return false
167 | }
168 |
169 | // ValidationAuthorizeRequest the authorization request validation
170 | func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) {
171 | redirectURI := r.FormValue("redirect_uri")
172 | clientID := r.FormValue("client_id")
173 | if !(r.Method == "GET" || r.Method == "POST") ||
174 | clientID == "" {
175 | return nil, errors.ErrInvalidRequest
176 | }
177 |
178 | resType := oauth2.ResponseType(r.FormValue("response_type"))
179 | if resType.String() == "" {
180 | return nil, errors.ErrUnsupportedResponseType
181 | } else if allowed := s.CheckResponseType(resType); !allowed {
182 | return nil, errors.ErrUnauthorizedClient
183 | }
184 |
185 | cc := r.FormValue("code_challenge")
186 | if cc == "" && s.Config.ForcePKCE {
187 | return nil, errors.ErrCodeChallengeRquired
188 | }
189 | if cc != "" && (len(cc) < 43 || len(cc) > 128) {
190 | return nil, errors.ErrInvalidCodeChallengeLen
191 | }
192 |
193 | ccm := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method"))
194 | // set default
195 | if ccm == "" {
196 | ccm = oauth2.CodeChallengePlain
197 | }
198 | if ccm != "" && !s.CheckCodeChallengeMethod(ccm) {
199 | return nil, errors.ErrUnsupportedCodeChallengeMethod
200 | }
201 |
202 | req := &AuthorizeRequest{
203 | RedirectURI: redirectURI,
204 | ResponseType: resType,
205 | ClientID: clientID,
206 | State: r.FormValue("state"),
207 | Scope: r.FormValue("scope"),
208 | Request: r,
209 | CodeChallenge: cc,
210 | CodeChallengeMethod: ccm,
211 | }
212 | return req, nil
213 | }
214 |
215 | // GetAuthorizeToken get authorization token(code)
216 | func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) {
217 | // check the client allows the grant type
218 | if fn := s.ClientAuthorizedHandler; fn != nil {
219 | gt := oauth2.AuthorizationCode
220 | if req.ResponseType == oauth2.Token {
221 | gt = oauth2.Implicit
222 | }
223 |
224 | allowed, err := fn(req.ClientID, gt)
225 | if err != nil {
226 | return nil, err
227 | } else if !allowed {
228 | return nil, errors.ErrUnauthorizedClient
229 | }
230 | }
231 |
232 | tgr := &oauth2.TokenGenerateRequest{
233 | ClientID: req.ClientID,
234 | UserID: req.UserID,
235 | RedirectURI: req.RedirectURI,
236 | Scope: req.Scope,
237 | AccessTokenExp: req.AccessTokenExp,
238 | Request: req.Request,
239 | }
240 |
241 | // check the client allows the authorized scope
242 | if fn := s.ClientScopeHandler; fn != nil {
243 | allowed, err := fn(tgr)
244 | if err != nil {
245 | return nil, err
246 | } else if !allowed {
247 | return nil, errors.ErrInvalidScope
248 | }
249 | }
250 |
251 | tgr.CodeChallenge = req.CodeChallenge
252 | tgr.CodeChallengeMethod = req.CodeChallengeMethod
253 |
254 | return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr)
255 | }
256 |
257 | // GetAuthorizeData get authorization response data
258 | func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} {
259 | if rt == oauth2.Code {
260 | return map[string]interface{}{
261 | "code": ti.GetCode(),
262 | }
263 | }
264 | return s.GetTokenData(ti)
265 | }
266 |
267 | // HandleAuthorizeRequest the authorization request handling
268 | func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error {
269 | ctx := r.Context()
270 |
271 | req, err := s.ValidationAuthorizeRequest(r)
272 | if err != nil {
273 | return s.handleError(w, req, err)
274 | }
275 |
276 | // user authorization
277 | userID, err := s.UserAuthorizationHandler(w, r)
278 | if err != nil {
279 | return s.handleError(w, req, err)
280 | } else if userID == "" {
281 | return nil
282 | }
283 | req.UserID = userID
284 |
285 | // specify the scope of authorization
286 | if fn := s.AuthorizeScopeHandler; fn != nil {
287 | scope, err := fn(w, r)
288 | if err != nil {
289 | return err
290 | } else if scope != "" {
291 | req.Scope = scope
292 | }
293 | }
294 |
295 | // specify the expiration time of access token
296 | if fn := s.AccessTokenExpHandler; fn != nil {
297 | exp, err := fn(w, r)
298 | if err != nil {
299 | return err
300 | }
301 | req.AccessTokenExp = exp
302 | }
303 |
304 | ti, err := s.GetAuthorizeToken(ctx, req)
305 | if err != nil {
306 | return s.handleError(w, req, err)
307 | }
308 |
309 | // If the redirect URI is empty, the default domain provided by the client is used.
310 | if req.RedirectURI == "" {
311 | client, err := s.Manager.GetClient(ctx, req.ClientID)
312 | if err != nil {
313 | return err
314 | }
315 | req.RedirectURI = client.GetDomain()
316 | }
317 |
318 | return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti))
319 | }
320 |
321 | // ValidationTokenRequest the token request validation
322 | func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) {
323 | if v := r.Method; !(v == "POST" ||
324 | (s.Config.AllowGetAccessRequest && v == "GET")) {
325 | return "", nil, errors.ErrInvalidRequest
326 | }
327 |
328 | gt := oauth2.GrantType(r.FormValue("grant_type"))
329 | if gt.String() == "" {
330 | return "", nil, errors.ErrUnsupportedGrantType
331 | }
332 |
333 | clientID, clientSecret, err := s.ClientInfoHandler(r)
334 | if err != nil {
335 | return "", nil, err
336 | }
337 |
338 | tgr := &oauth2.TokenGenerateRequest{
339 | ClientID: clientID,
340 | ClientSecret: clientSecret,
341 | Request: r,
342 | }
343 |
344 | switch gt {
345 | case oauth2.AuthorizationCode:
346 | tgr.RedirectURI = r.FormValue("redirect_uri")
347 | tgr.Code = r.FormValue("code")
348 | if tgr.RedirectURI == "" ||
349 | tgr.Code == "" {
350 | return "", nil, errors.ErrInvalidRequest
351 | }
352 | tgr.CodeVerifier = r.FormValue("code_verifier")
353 | if s.Config.ForcePKCE && tgr.CodeVerifier == "" {
354 | return "", nil, errors.ErrInvalidRequest
355 | }
356 | case oauth2.PasswordCredentials:
357 | tgr.Scope = r.FormValue("scope")
358 | username, password := r.FormValue("username"), r.FormValue("password")
359 | if username == "" || password == "" {
360 | return "", nil, errors.ErrInvalidRequest
361 | }
362 |
363 | userID, err := s.PasswordAuthorizationHandler(r.Context(), clientID, username, password)
364 | if err != nil {
365 | return "", nil, err
366 | } else if userID == "" {
367 | return "", nil, errors.ErrInvalidGrant
368 | }
369 | tgr.UserID = userID
370 | case oauth2.ClientCredentials:
371 | tgr.Scope = r.FormValue("scope")
372 | case oauth2.Refreshing:
373 | tgr.Refresh, err = s.RefreshTokenResolveHandler(r)
374 | tgr.Scope = r.FormValue("scope")
375 | if err != nil {
376 | return "", nil, err
377 | }
378 | }
379 | return gt, tgr, nil
380 | }
381 |
382 | // CheckGrantType check allows grant type
383 | func (s *Server) CheckGrantType(gt oauth2.GrantType) bool {
384 | for _, agt := range s.Config.AllowedGrantTypes {
385 | if agt == gt {
386 | return true
387 | }
388 | }
389 | return false
390 | }
391 |
392 | // GetAccessToken access token
393 | func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo,
394 | error) {
395 | if allowed := s.CheckGrantType(gt); !allowed {
396 | return nil, errors.ErrUnauthorizedClient
397 | }
398 |
399 | if fn := s.ClientAuthorizedHandler; fn != nil {
400 | allowed, err := fn(tgr.ClientID, gt)
401 | if err != nil {
402 | return nil, err
403 | } else if !allowed {
404 | return nil, errors.ErrUnauthorizedClient
405 | }
406 | }
407 |
408 | switch gt {
409 | case oauth2.AuthorizationCode:
410 | ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr)
411 | if err != nil {
412 | switch err {
413 | case errors.ErrInvalidAuthorizeCode, errors.ErrInvalidCodeChallenge, errors.ErrMissingCodeChallenge:
414 | return nil, errors.ErrInvalidGrant
415 | case errors.ErrInvalidClient:
416 | return nil, errors.ErrInvalidClient
417 | default:
418 | return nil, err
419 | }
420 | }
421 | return ti, nil
422 | case oauth2.PasswordCredentials, oauth2.ClientCredentials:
423 | if fn := s.ClientScopeHandler; fn != nil {
424 | allowed, err := fn(tgr)
425 | if err != nil {
426 | return nil, err
427 | } else if !allowed {
428 | return nil, errors.ErrInvalidScope
429 | }
430 | }
431 | return s.Manager.GenerateAccessToken(ctx, gt, tgr)
432 | case oauth2.Refreshing:
433 | // check scope
434 | if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil {
435 | rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh)
436 | if err != nil {
437 | if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
438 | return nil, errors.ErrInvalidGrant
439 | }
440 | return nil, err
441 | }
442 |
443 | allowed, err := scopeFn(tgr, rti.GetScope())
444 | if err != nil {
445 | return nil, err
446 | } else if !allowed {
447 | return nil, errors.ErrInvalidScope
448 | }
449 | }
450 |
451 | if validationFn := s.RefreshingValidationHandler; validationFn != nil {
452 | rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh)
453 | if err != nil {
454 | if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
455 | return nil, errors.ErrInvalidGrant
456 | }
457 | return nil, err
458 | }
459 | allowed, err := validationFn(rti)
460 | if err != nil {
461 | return nil, err
462 | } else if !allowed {
463 | return nil, errors.ErrInvalidScope
464 | }
465 | }
466 |
467 | ti, err := s.Manager.RefreshAccessToken(ctx, tgr)
468 | if err != nil {
469 | if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
470 | return nil, errors.ErrInvalidGrant
471 | }
472 | return nil, err
473 | }
474 | return ti, nil
475 | }
476 |
477 | return nil, errors.ErrUnsupportedGrantType
478 | }
479 |
480 | // GetTokenData token data
481 | func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} {
482 | data := map[string]interface{}{
483 | "access_token": ti.GetAccess(),
484 | "token_type": s.Config.TokenType,
485 | "expires_in": int64(ti.GetAccessExpiresIn() / time.Second),
486 | }
487 |
488 | if scope := ti.GetScope(); scope != "" {
489 | data["scope"] = scope
490 | }
491 |
492 | if refresh := ti.GetRefresh(); refresh != "" {
493 | data["refresh_token"] = refresh
494 | }
495 |
496 | if fn := s.ExtensionFieldsHandler; fn != nil {
497 | ext := fn(ti)
498 | for k, v := range ext {
499 | if _, ok := data[k]; ok {
500 | continue
501 | }
502 | data[k] = v
503 | }
504 | }
505 | return data
506 | }
507 |
508 | // HandleTokenRequest token request handling
509 | func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error {
510 | ctx := r.Context()
511 |
512 | gt, tgr, err := s.ValidationTokenRequest(r)
513 | if err != nil {
514 | return s.tokenError(w, err)
515 | }
516 |
517 | ti, err := s.GetAccessToken(ctx, gt, tgr)
518 | if err != nil {
519 | return s.tokenError(w, err)
520 | }
521 |
522 | return s.token(w, s.GetTokenData(ti), nil)
523 | }
524 |
525 | // GetErrorData get error response data
526 | func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) {
527 | var re errors.Response
528 | if v, ok := errors.Descriptions[err]; ok {
529 | re.Error = err
530 | re.Description = v
531 | re.StatusCode = errors.StatusCodes[err]
532 | } else {
533 | if fn := s.InternalErrorHandler; fn != nil {
534 | if v := fn(err); v != nil {
535 | re = *v
536 | }
537 | }
538 |
539 | if re.Error == nil {
540 | re.Error = errors.ErrServerError
541 | re.Description = errors.Descriptions[errors.ErrServerError]
542 | re.StatusCode = errors.StatusCodes[errors.ErrServerError]
543 | }
544 | }
545 |
546 | if fn := s.ResponseErrorHandler; fn != nil {
547 | fn(&re)
548 | }
549 |
550 | data := make(map[string]interface{})
551 | if err := re.Error; err != nil {
552 | data["error"] = err.Error()
553 | }
554 |
555 | if v := re.ErrorCode; v != 0 {
556 | data["error_code"] = v
557 | }
558 |
559 | if v := re.Description; v != "" {
560 | data["error_description"] = v
561 | }
562 |
563 | if v := re.URI; v != "" {
564 | data["error_uri"] = v
565 | }
566 |
567 | statusCode := http.StatusInternalServerError
568 | if v := re.StatusCode; v > 0 {
569 | statusCode = v
570 | }
571 |
572 | return data, statusCode, re.Header
573 | }
574 |
575 | // ValidationBearerToken validation the bearer tokens
576 | // https://tools.ietf.org/html/rfc6750
577 | func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
578 | ctx := r.Context()
579 |
580 | accessToken, ok := s.AccessTokenResolveHandler(r)
581 | if !ok {
582 | return nil, errors.ErrInvalidAccessToken
583 | }
584 |
585 | return s.Manager.LoadAccessToken(ctx, accessToken)
586 | }
587 |
--------------------------------------------------------------------------------
/server/server_config.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "github.com/go-oauth2/oauth2/v4"
5 | )
6 |
7 | // SetTokenType token type
8 | func (s *Server) SetTokenType(tokenType string) {
9 | s.Config.TokenType = tokenType
10 | }
11 |
12 | // SetAllowGetAccessRequest to allow GET requests for the token
13 | func (s *Server) SetAllowGetAccessRequest(allow bool) {
14 | s.Config.AllowGetAccessRequest = allow
15 | }
16 |
17 | // SetAllowedResponseType allow the authorization types
18 | func (s *Server) SetAllowedResponseType(types ...oauth2.ResponseType) {
19 | s.Config.AllowedResponseTypes = types
20 | }
21 |
22 | // SetAllowedGrantType allow the grant types
23 | func (s *Server) SetAllowedGrantType(types ...oauth2.GrantType) {
24 | s.Config.AllowedGrantTypes = types
25 | }
26 |
27 | // SetClientInfoHandler get client info from request
28 | func (s *Server) SetClientInfoHandler(handler ClientInfoHandler) {
29 | s.ClientInfoHandler = handler
30 | }
31 |
32 | // SetClientAuthorizedHandler check the client allows to use this authorization grant type
33 | func (s *Server) SetClientAuthorizedHandler(handler ClientAuthorizedHandler) {
34 | s.ClientAuthorizedHandler = handler
35 | }
36 |
37 | // SetClientScopeHandler check the client allows to use scope
38 | func (s *Server) SetClientScopeHandler(handler ClientScopeHandler) {
39 | s.ClientScopeHandler = handler
40 | }
41 |
42 | // SetUserAuthorizationHandler get user id from request authorization
43 | func (s *Server) SetUserAuthorizationHandler(handler UserAuthorizationHandler) {
44 | s.UserAuthorizationHandler = handler
45 | }
46 |
47 | // SetPasswordAuthorizationHandler get user id from username and password
48 | func (s *Server) SetPasswordAuthorizationHandler(handler PasswordAuthorizationHandler) {
49 | s.PasswordAuthorizationHandler = handler
50 | }
51 |
52 | // SetRefreshingScopeHandler check the scope of the refreshing token
53 | func (s *Server) SetRefreshingScopeHandler(handler RefreshingScopeHandler) {
54 | s.RefreshingScopeHandler = handler
55 | }
56 |
57 | // SetRefreshingValidationHandler check if refresh_token is still valid. eg no revocation or other
58 | func (s *Server) SetRefreshingValidationHandler(handler RefreshingValidationHandler) {
59 | s.RefreshingValidationHandler = handler
60 | }
61 |
62 | // SetResponseErrorHandler response error handling
63 | func (s *Server) SetResponseErrorHandler(handler ResponseErrorHandler) {
64 | s.ResponseErrorHandler = handler
65 | }
66 |
67 | // SetInternalErrorHandler internal error handling
68 | func (s *Server) SetInternalErrorHandler(handler InternalErrorHandler) {
69 | s.InternalErrorHandler = handler
70 | }
71 |
72 | // SetPreRedirectErrorHandler sets the PreRedirectErrorHandler in current Server instance
73 | func (s *Server) SetPreRedirectErrorHandler(handler PreRedirectErrorHandler) {
74 | s.PreRedirectErrorHandler = handler
75 | }
76 |
77 | // SetExtensionFieldsHandler in response to the access token with the extension of the field
78 | func (s *Server) SetExtensionFieldsHandler(handler ExtensionFieldsHandler) {
79 | s.ExtensionFieldsHandler = handler
80 | }
81 |
82 | // SetAccessTokenExpHandler set expiration date for the access token
83 | func (s *Server) SetAccessTokenExpHandler(handler AccessTokenExpHandler) {
84 | s.AccessTokenExpHandler = handler
85 | }
86 |
87 | // SetAuthorizeScopeHandler set scope for the access token
88 | func (s *Server) SetAuthorizeScopeHandler(handler AuthorizeScopeHandler) {
89 | s.AuthorizeScopeHandler = handler
90 | }
91 |
92 | // SetResponseTokenHandler response token handing
93 | func (s *Server) SetResponseTokenHandler(handler ResponseTokenHandler) {
94 | s.ResponseTokenHandler = handler
95 | }
96 |
97 | // SetRefreshTokenResolveHandler refresh token resolver
98 | func (s *Server) SetRefreshTokenResolveHandler(handler RefreshTokenResolveHandler) {
99 | s.RefreshTokenResolveHandler = handler
100 | }
101 |
102 | // SetAccessTokenResolveHandler access token resolver
103 | func (s *Server) SetAccessTokenResolveHandler(handler AccessTokenResolveHandler) {
104 | s.AccessTokenResolveHandler = handler
105 | }
106 |
--------------------------------------------------------------------------------
/server/server_test.go:
--------------------------------------------------------------------------------
1 | package server_test
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "net/http"
7 | "net/http/httptest"
8 | "testing"
9 |
10 | "github.com/gavv/httpexpect"
11 | "github.com/go-oauth2/oauth2/v4"
12 | "github.com/go-oauth2/oauth2/v4/errors"
13 | "github.com/go-oauth2/oauth2/v4/manage"
14 | "github.com/go-oauth2/oauth2/v4/models"
15 | "github.com/go-oauth2/oauth2/v4/server"
16 | "github.com/go-oauth2/oauth2/v4/store"
17 | )
18 |
19 | var (
20 | srv *server.Server
21 | tsrv *httptest.Server
22 | manager *manage.Manager
23 | csrv *httptest.Server
24 | clientID = "111111"
25 | clientSecret = "11111111"
26 |
27 | plainChallenge = "ThisIsAFourtyThreeCharactersLongStringThing"
28 | s256Challenge = "s256tests256tests256tests256tests256tests256test"
29 | // sha2562 := sha256.Sum256([]byte(s256Challenge))
30 | // fmt.Printf(base64.URLEncoding.EncodeToString(sha2562[:]))
31 | s256ChallengeHash = "To2Xqv01cm16bC9Sf7KRRS8CO2SFss_HSMQOr3sdCDE="
32 | )
33 |
34 | func init() {
35 | manager = manage.NewDefaultManager()
36 | manager.MustTokenStorage(store.NewMemoryTokenStore())
37 | }
38 |
39 | func clientStore(domain string, public bool) oauth2.ClientStore {
40 | clientStore := store.NewClientStore()
41 | var secret string
42 | if public {
43 | secret = ""
44 | } else {
45 | secret = clientSecret
46 | }
47 | clientStore.Set(clientID, &models.Client{
48 | ID: clientID,
49 | Secret: secret,
50 | Domain: domain,
51 | Public: public,
52 | })
53 | return clientStore
54 | }
55 |
56 | func testServer(t *testing.T, w http.ResponseWriter, r *http.Request) {
57 | switch r.URL.Path {
58 | case "/authorize":
59 | err := srv.HandleAuthorizeRequest(w, r)
60 | if err != nil {
61 | t.Error(err)
62 | }
63 | case "/token":
64 | err := srv.HandleTokenRequest(w, r)
65 | if err != nil {
66 | t.Error(err)
67 | }
68 | }
69 | }
70 |
71 | func TestAuthorizeCode(t *testing.T) {
72 | tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
73 | testServer(t, w, r)
74 | }))
75 | defer tsrv.Close()
76 |
77 | e := httpexpect.New(t, tsrv.URL)
78 |
79 | csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
80 | switch r.URL.Path {
81 | case "/oauth2":
82 | r.ParseForm()
83 | code, state := r.Form.Get("code"), r.Form.Get("state")
84 | if state != "123" {
85 | t.Error("unrecognized state:", state)
86 | return
87 | }
88 | resObj := e.POST("/token").
89 | WithFormField("redirect_uri", csrv.URL+"/oauth2").
90 | WithFormField("code", code).
91 | WithFormField("grant_type", "authorization_code").
92 | WithFormField("client_id", clientID).
93 | WithBasicAuth(clientID, clientSecret).
94 | Expect().
95 | Status(http.StatusOK).
96 | JSON().Object()
97 |
98 | t.Logf("%#v\n", resObj.Raw())
99 |
100 | validationAccessToken(t, resObj.Value("access_token").String().Raw())
101 | }
102 | }))
103 | defer csrv.Close()
104 |
105 | manager.MapClientStorage(clientStore(csrv.URL, true))
106 | srv = server.NewDefaultServer(manager)
107 | srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) {
108 | userID = "000000"
109 | return
110 | })
111 |
112 | e.GET("/authorize").
113 | WithQuery("response_type", "code").
114 | WithQuery("client_id", clientID).
115 | WithQuery("scope", "all").
116 | WithQuery("state", "123").
117 | WithQuery("redirect_uri", csrv.URL+"/oauth2").
118 | Expect().Status(http.StatusOK)
119 | }
120 |
121 | func TestAuthorizeCodeWithChallengePlain(t *testing.T) {
122 | tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
123 | testServer(t, w, r)
124 | }))
125 | defer tsrv.Close()
126 |
127 | e := httpexpect.New(t, tsrv.URL)
128 |
129 | csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
130 | switch r.URL.Path {
131 | case "/oauth2":
132 | r.ParseForm()
133 | code, state := r.Form.Get("code"), r.Form.Get("state")
134 | if state != "123" {
135 | t.Error("unrecognized state:", state)
136 | return
137 | }
138 | resObj := e.POST("/token").
139 | WithFormField("redirect_uri", csrv.URL+"/oauth2").
140 | WithFormField("code", code).
141 | WithFormField("grant_type", "authorization_code").
142 | WithFormField("client_id", clientID).
143 | WithFormField("code", code).
144 | WithFormField("code_verifier", plainChallenge).
145 | Expect().
146 | Status(http.StatusOK).
147 | JSON().Object()
148 |
149 | t.Logf("%#v\n", resObj.Raw())
150 |
151 | validationAccessToken(t, resObj.Value("access_token").String().Raw())
152 | }
153 | }))
154 | defer csrv.Close()
155 |
156 | manager.MapClientStorage(clientStore(csrv.URL, true))
157 | srv = server.NewDefaultServer(manager)
158 | srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) {
159 | userID = "000000"
160 | return
161 | })
162 | srv.SetClientInfoHandler(server.ClientFormHandler)
163 |
164 | e.GET("/authorize").
165 | WithQuery("response_type", "code").
166 | WithQuery("client_id", clientID).
167 | WithQuery("scope", "all").
168 | WithQuery("state", "123").
169 | WithQuery("redirect_uri", csrv.URL+"/oauth2").
170 | WithQuery("code_challenge", plainChallenge).
171 | Expect().Status(http.StatusOK)
172 | }
173 |
174 | func TestAuthorizeCodeWithChallengeS256(t *testing.T) {
175 | tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
176 | testServer(t, w, r)
177 | }))
178 | defer tsrv.Close()
179 |
180 | e := httpexpect.New(t, tsrv.URL)
181 |
182 | csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
183 | switch r.URL.Path {
184 | case "/oauth2":
185 | r.ParseForm()
186 | code, state := r.Form.Get("code"), r.Form.Get("state")
187 | if state != "123" {
188 | t.Error("unrecognized state:", state)
189 | return
190 | }
191 | resObj := e.POST("/token").
192 | WithFormField("redirect_uri", csrv.URL+"/oauth2").
193 | WithFormField("code", code).
194 | WithFormField("grant_type", "authorization_code").
195 | WithFormField("client_id", clientID).
196 | WithFormField("code", code).
197 | WithFormField("code_verifier", s256Challenge).
198 | Expect().
199 | Status(http.StatusOK).
200 | JSON().Object()
201 |
202 | t.Logf("%#v\n", resObj.Raw())
203 |
204 | validationAccessToken(t, resObj.Value("access_token").String().Raw())
205 | }
206 | }))
207 | defer csrv.Close()
208 |
209 | manager.MapClientStorage(clientStore(csrv.URL, true))
210 | srv = server.NewDefaultServer(manager)
211 | srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) {
212 | userID = "000000"
213 | return
214 | })
215 | srv.SetClientInfoHandler(server.ClientFormHandler)
216 |
217 | e.GET("/authorize").
218 | WithQuery("response_type", "code").
219 | WithQuery("client_id", clientID).
220 | WithQuery("scope", "all").
221 | WithQuery("state", "123").
222 | WithQuery("redirect_uri", csrv.URL+"/oauth2").
223 | WithQuery("code_challenge", s256ChallengeHash).
224 | WithQuery("code_challenge_method", "S256").
225 | Expect().Status(http.StatusOK)
226 | }
227 |
228 | func TestImplicit(t *testing.T) {
229 | tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
230 | testServer(t, w, r)
231 | }))
232 | defer tsrv.Close()
233 | e := httpexpect.New(t, tsrv.URL)
234 |
235 | csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
236 | defer csrv.Close()
237 |
238 | manager.MapClientStorage(clientStore(csrv.URL, false))
239 | srv = server.NewDefaultServer(manager)
240 | srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) {
241 | userID = "000000"
242 | return
243 | })
244 |
245 | e.GET("/authorize").
246 | WithQuery("response_type", "token").
247 | WithQuery("client_id", clientID).
248 | WithQuery("scope", "all").
249 | WithQuery("state", "123").
250 | WithQuery("redirect_uri", csrv.URL+"/oauth2").
251 | Expect().Status(http.StatusOK)
252 | }
253 |
254 | func TestPasswordCredentials(t *testing.T) {
255 | tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
256 | testServer(t, w, r)
257 | }))
258 | defer tsrv.Close()
259 | e := httpexpect.New(t, tsrv.URL)
260 |
261 | manager.MapClientStorage(clientStore("", false))
262 | srv = server.NewDefaultServer(manager)
263 | srv.SetPasswordAuthorizationHandler(func(ctx context.Context, clientID, username, password string) (userID string, err error) {
264 | if username == "admin" && password == "123456" {
265 | userID = "000000"
266 | return
267 | }
268 | err = fmt.Errorf("user not found")
269 | return
270 | })
271 |
272 | resObj := e.POST("/token").
273 | WithFormField("grant_type", "password").
274 | WithFormField("username", "admin").
275 | WithFormField("password", "123456").
276 | WithFormField("scope", "all").
277 | WithBasicAuth(clientID, clientSecret).
278 | Expect().
279 | Status(http.StatusOK).
280 | JSON().Object()
281 |
282 | t.Logf("%#v\n", resObj.Raw())
283 |
284 | validationAccessToken(t, resObj.Value("access_token").String().Raw())
285 | }
286 |
287 | func TestClientCredentials(t *testing.T) {
288 | tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
289 | testServer(t, w, r)
290 | }))
291 | defer tsrv.Close()
292 | e := httpexpect.New(t, tsrv.URL)
293 |
294 | manager.MapClientStorage(clientStore("", false))
295 |
296 | srv = server.NewDefaultServer(manager)
297 | srv.SetClientInfoHandler(server.ClientFormHandler)
298 |
299 | srv.SetInternalErrorHandler(func(err error) (re *errors.Response) {
300 | t.Log("OAuth 2.0 Error:", err.Error())
301 | return
302 | })
303 |
304 | srv.SetResponseErrorHandler(func(re *errors.Response) {
305 | t.Log("Response Error:", re.Error)
306 | })
307 |
308 | srv.SetAllowedGrantType(oauth2.ClientCredentials)
309 | srv.SetAllowGetAccessRequest(false)
310 | srv.SetExtensionFieldsHandler(func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) {
311 | fieldsValue = map[string]interface{}{
312 | "extension": "param",
313 | }
314 | return
315 | })
316 | srv.SetAuthorizeScopeHandler(func(w http.ResponseWriter, r *http.Request) (scope string, err error) {
317 | return
318 | })
319 | srv.SetClientScopeHandler(func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) {
320 | allowed = true
321 | return
322 | })
323 |
324 | resObj := e.POST("/token").
325 | WithFormField("grant_type", "client_credentials").
326 | WithFormField("scope", "all").
327 | WithFormField("client_id", clientID).
328 | WithFormField("client_secret", clientSecret).
329 | Expect().
330 | Status(http.StatusOK).
331 | JSON().Object()
332 |
333 | t.Logf("%#v\n", resObj.Raw())
334 |
335 | validationAccessToken(t, resObj.Value("access_token").String().Raw())
336 | }
337 |
338 | func TestRefreshing(t *testing.T) {
339 | tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
340 | testServer(t, w, r)
341 | }))
342 | defer tsrv.Close()
343 | e := httpexpect.New(t, tsrv.URL)
344 |
345 | csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
346 | switch r.URL.Path {
347 | case "/oauth2":
348 | r.ParseForm()
349 | code, state := r.Form.Get("code"), r.Form.Get("state")
350 | if state != "123" {
351 | t.Error("unrecognized state:", state)
352 | return
353 | }
354 | jresObj := e.POST("/token").
355 | WithFormField("redirect_uri", csrv.URL+"/oauth2").
356 | WithFormField("code", code).
357 | WithFormField("grant_type", "authorization_code").
358 | WithFormField("client_id", clientID).
359 | WithBasicAuth(clientID, clientSecret).
360 | Expect().
361 | Status(http.StatusOK).
362 | JSON().Object()
363 |
364 | t.Logf("%#v\n", jresObj.Raw())
365 |
366 | validationAccessToken(t, jresObj.Value("access_token").String().Raw())
367 |
368 | resObj := e.POST("/token").
369 | WithFormField("grant_type", "refresh_token").
370 | WithFormField("scope", "one").
371 | WithFormField("refresh_token", jresObj.Value("refresh_token").String().Raw()).
372 | WithBasicAuth(clientID, clientSecret).
373 | Expect().
374 | Status(http.StatusOK).
375 | JSON().Object()
376 |
377 | t.Logf("%#v\n", resObj.Raw())
378 |
379 | validationAccessToken(t, resObj.Value("access_token").String().Raw())
380 | }
381 | }))
382 | defer csrv.Close()
383 |
384 | manager.MapClientStorage(clientStore(csrv.URL, true))
385 | srv = server.NewDefaultServer(manager)
386 | srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) {
387 | userID = "000000"
388 | return
389 | })
390 |
391 | e.GET("/authorize").
392 | WithQuery("response_type", "code").
393 | WithQuery("client_id", clientID).
394 | WithQuery("scope", "all").
395 | WithQuery("state", "123").
396 | WithQuery("redirect_uri", csrv.URL+"/oauth2").
397 | Expect().Status(http.StatusOK)
398 | }
399 |
400 | // validation access token
401 | func validationAccessToken(t *testing.T, accessToken string) {
402 | req := httptest.NewRequest("GET", "http://example.com", nil)
403 |
404 | req.Header.Set("Authorization", "Bearer "+accessToken)
405 |
406 | ti, err := srv.ValidationBearerToken(req)
407 | if err != nil {
408 | t.Error(err.Error())
409 | return
410 | }
411 | if ti.GetClientID() != clientID {
412 | t.Error("invalid access token")
413 | }
414 | }
415 |
--------------------------------------------------------------------------------
/store.go:
--------------------------------------------------------------------------------
1 | package oauth2
2 |
3 | import "context"
4 |
5 | type (
6 | // ClientStore the client information storage interface
7 | ClientStore interface {
8 | // according to the ID for the client information
9 | GetByID(ctx context.Context, id string) (ClientInfo, error)
10 | }
11 |
12 | // TokenStore the token information storage interface
13 | TokenStore interface {
14 | // create and store the new token information
15 | Create(ctx context.Context, info TokenInfo) error
16 |
17 | // delete the authorization code
18 | RemoveByCode(ctx context.Context, code string) error
19 |
20 | // use the access token to delete the token information
21 | RemoveByAccess(ctx context.Context, access string) error
22 |
23 | // use the refresh token to delete the token information
24 | RemoveByRefresh(ctx context.Context, refresh string) error
25 |
26 | // use the authorization code for token information data
27 | GetByCode(ctx context.Context, code string) (TokenInfo, error)
28 |
29 | // use the access token for token information data
30 | GetByAccess(ctx context.Context, access string) (TokenInfo, error)
31 |
32 | // use the refresh token for token information data
33 | GetByRefresh(ctx context.Context, refresh string) (TokenInfo, error)
34 | }
35 | )
36 |
--------------------------------------------------------------------------------
/store/client.go:
--------------------------------------------------------------------------------
1 | package store
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "sync"
7 |
8 | "github.com/go-oauth2/oauth2/v4"
9 | )
10 |
11 | // NewClientStore create client store
12 | func NewClientStore() *ClientStore {
13 | return &ClientStore{
14 | data: make(map[string]oauth2.ClientInfo),
15 | }
16 | }
17 |
18 | // ClientStore client information store
19 | type ClientStore struct {
20 | sync.RWMutex
21 | data map[string]oauth2.ClientInfo
22 | }
23 |
24 | // GetByID according to the ID for the client information
25 | func (cs *ClientStore) GetByID(ctx context.Context, id string) (oauth2.ClientInfo, error) {
26 | cs.RLock()
27 | defer cs.RUnlock()
28 |
29 | if c, ok := cs.data[id]; ok {
30 | return c, nil
31 | }
32 | return nil, errors.New("not found")
33 | }
34 |
35 | // Set set client information
36 | func (cs *ClientStore) Set(id string, cli oauth2.ClientInfo) (err error) {
37 | cs.Lock()
38 | defer cs.Unlock()
39 |
40 | cs.data[id] = cli
41 | return
42 | }
43 |
--------------------------------------------------------------------------------
/store/client_test.go:
--------------------------------------------------------------------------------
1 | package store_test
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/go-oauth2/oauth2/v4/models"
8 | "github.com/go-oauth2/oauth2/v4/store"
9 |
10 | . "github.com/smartystreets/goconvey/convey"
11 | )
12 |
13 | func TestClientStore(t *testing.T) {
14 | Convey("Test client store", t, func() {
15 | clientStore := store.NewClientStore()
16 |
17 | err := clientStore.Set("1", &models.Client{ID: "1", Secret: "2"})
18 | So(err, ShouldBeNil)
19 |
20 | cli, err := clientStore.GetByID(context.Background(), "1")
21 | So(err, ShouldBeNil)
22 | So(cli.GetID(), ShouldEqual, "1")
23 | })
24 | }
25 |
--------------------------------------------------------------------------------
/store/token.go:
--------------------------------------------------------------------------------
1 | package store
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "time"
7 |
8 | "github.com/go-oauth2/oauth2/v4"
9 | "github.com/go-oauth2/oauth2/v4/models"
10 | "github.com/google/uuid"
11 | "github.com/tidwall/buntdb"
12 | )
13 |
14 | // NewMemoryTokenStore create a token store instance based on memory
15 | func NewMemoryTokenStore() (oauth2.TokenStore, error) {
16 | return NewFileTokenStore(":memory:")
17 | }
18 |
19 | // NewFileTokenStore create a token store instance based on file
20 | func NewFileTokenStore(filename string) (oauth2.TokenStore, error) {
21 | db, err := buntdb.Open(filename)
22 | if err != nil {
23 | return nil, err
24 | }
25 | return &TokenStore{db: db}, nil
26 | }
27 |
28 | // TokenStore token storage based on buntdb(https://github.com/tidwall/buntdb)
29 | type TokenStore struct {
30 | db *buntdb.DB
31 | }
32 |
33 | // Create create and store the new token information
34 | func (ts *TokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
35 | ct := time.Now()
36 | jv, err := json.Marshal(info)
37 | if err != nil {
38 | return err
39 | }
40 |
41 | return ts.db.Update(func(tx *buntdb.Tx) error {
42 | if code := info.GetCode(); code != "" {
43 | _, _, err := tx.Set(code, string(jv), &buntdb.SetOptions{Expires: true, TTL: info.GetCodeExpiresIn()})
44 | return err
45 | }
46 |
47 | basicID := uuid.Must(uuid.NewRandom()).String()
48 | aexp := info.GetAccessExpiresIn()
49 | rexp := aexp
50 | expires := true
51 | if refresh := info.GetRefresh(); refresh != "" {
52 | rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
53 | if aexp.Seconds() > rexp.Seconds() {
54 | aexp = rexp
55 | }
56 | expires = info.GetRefreshExpiresIn() != 0
57 | _, _, err := tx.Set(refresh, basicID, &buntdb.SetOptions{Expires: expires, TTL: rexp})
58 | if err != nil {
59 | return err
60 | }
61 | }
62 |
63 | _, _, err := tx.Set(basicID, string(jv), &buntdb.SetOptions{Expires: expires, TTL: rexp})
64 | if err != nil {
65 | return err
66 | }
67 | _, _, err = tx.Set(info.GetAccess(), basicID, &buntdb.SetOptions{Expires: expires, TTL: aexp})
68 | return err
69 | })
70 | }
71 |
72 | // remove key
73 | func (ts *TokenStore) remove(key string) error {
74 | err := ts.db.Update(func(tx *buntdb.Tx) error {
75 | _, err := tx.Delete(key)
76 | return err
77 | })
78 | if err == buntdb.ErrNotFound {
79 | return nil
80 | }
81 | return err
82 | }
83 |
84 | // RemoveByCode use the authorization code to delete the token information
85 | func (ts *TokenStore) RemoveByCode(ctx context.Context, code string) error {
86 | return ts.remove(code)
87 | }
88 |
89 | // RemoveByAccess use the access token to delete the token information
90 | func (ts *TokenStore) RemoveByAccess(ctx context.Context, access string) error {
91 | return ts.remove(access)
92 | }
93 |
94 | // RemoveByRefresh use the refresh token to delete the token information
95 | func (ts *TokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
96 | return ts.remove(refresh)
97 | }
98 |
99 | func (ts *TokenStore) getData(key string) (oauth2.TokenInfo, error) {
100 | var ti oauth2.TokenInfo
101 | err := ts.db.View(func(tx *buntdb.Tx) error {
102 | jv, err := tx.Get(key)
103 | if err != nil {
104 | return err
105 | }
106 |
107 | var tm models.Token
108 | err = json.Unmarshal([]byte(jv), &tm)
109 | if err != nil {
110 | return err
111 | }
112 | ti = &tm
113 | return nil
114 | })
115 | if err != nil {
116 | if err == buntdb.ErrNotFound {
117 | return nil, nil
118 | }
119 | return nil, err
120 | }
121 | return ti, nil
122 | }
123 |
124 | func (ts *TokenStore) getBasicID(key string) (string, error) {
125 | var basicID string
126 | err := ts.db.View(func(tx *buntdb.Tx) error {
127 | v, err := tx.Get(key)
128 | if err != nil {
129 | return err
130 | }
131 | basicID = v
132 | return nil
133 | })
134 | if err != nil {
135 | if err == buntdb.ErrNotFound {
136 | return "", nil
137 | }
138 | return "", err
139 | }
140 | return basicID, nil
141 | }
142 |
143 | // GetByCode use the authorization code for token information data
144 | func (ts *TokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
145 | return ts.getData(code)
146 | }
147 |
148 | // GetByAccess use the access token for token information data
149 | func (ts *TokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
150 | basicID, err := ts.getBasicID(access)
151 | if err != nil {
152 | return nil, err
153 | }
154 | return ts.getData(basicID)
155 | }
156 |
157 | // GetByRefresh use the refresh token for token information data
158 | func (ts *TokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
159 | basicID, err := ts.getBasicID(refresh)
160 | if err != nil {
161 | return nil, err
162 | }
163 | return ts.getData(basicID)
164 | }
165 |
--------------------------------------------------------------------------------
/store/token_test.go:
--------------------------------------------------------------------------------
1 | package store_test
2 |
3 | import (
4 | "context"
5 | "os"
6 | "testing"
7 | "time"
8 |
9 | "github.com/go-oauth2/oauth2/v4"
10 | "github.com/go-oauth2/oauth2/v4/models"
11 | "github.com/go-oauth2/oauth2/v4/store"
12 |
13 | . "github.com/smartystreets/goconvey/convey"
14 | )
15 |
16 | func TestTokenStore(t *testing.T) {
17 | Convey("Test memory store", t, func() {
18 | store, err := store.NewMemoryTokenStore()
19 | So(err, ShouldBeNil)
20 | testToken(store)
21 | })
22 |
23 | Convey("Test file store", t, func() {
24 | os.Remove("data.db")
25 |
26 | store, err := store.NewFileTokenStore("data.db")
27 | So(err, ShouldBeNil)
28 | testToken(store)
29 | })
30 | }
31 |
32 | func testToken(store oauth2.TokenStore) {
33 | Convey("Test authorization code store", func() {
34 | ctx := context.Background()
35 | info := &models.Token{
36 | ClientID: "1",
37 | UserID: "1_1",
38 | RedirectURI: "http://localhost/",
39 | Scope: "all",
40 | Code: "11_11_11",
41 | CodeCreateAt: time.Now(),
42 | CodeExpiresIn: time.Second * 5,
43 | }
44 | err := store.Create(ctx, info)
45 | So(err, ShouldBeNil)
46 |
47 | cinfo, err := store.GetByCode(ctx, info.Code)
48 | So(err, ShouldBeNil)
49 | So(cinfo.GetUserID(), ShouldEqual, info.UserID)
50 |
51 | err = store.RemoveByCode(ctx, info.Code)
52 | So(err, ShouldBeNil)
53 |
54 | cinfo, err = store.GetByCode(ctx, info.Code)
55 | So(err, ShouldBeNil)
56 | So(cinfo, ShouldBeNil)
57 | })
58 |
59 | Convey("Test access token store", func() {
60 | ctx := context.Background()
61 | info := &models.Token{
62 | ClientID: "1",
63 | UserID: "1_1",
64 | RedirectURI: "http://localhost/",
65 | Scope: "all",
66 | Access: "1_1_1",
67 | AccessCreateAt: time.Now(),
68 | AccessExpiresIn: time.Second * 5,
69 | }
70 | err := store.Create(ctx, info)
71 | So(err, ShouldBeNil)
72 |
73 | ainfo, err := store.GetByAccess(ctx, info.GetAccess())
74 | So(err, ShouldBeNil)
75 | So(ainfo.GetUserID(), ShouldEqual, info.GetUserID())
76 |
77 | err = store.RemoveByAccess(ctx, info.GetAccess())
78 | So(err, ShouldBeNil)
79 |
80 | ainfo, err = store.GetByAccess(ctx, info.GetAccess())
81 | So(err, ShouldBeNil)
82 | So(ainfo, ShouldBeNil)
83 | })
84 |
85 | Convey("Test refresh token store", func() {
86 | ctx := context.Background()
87 | info := &models.Token{
88 | ClientID: "1",
89 | UserID: "1_2",
90 | RedirectURI: "http://localhost/",
91 | Scope: "all",
92 | Access: "1_2_1",
93 | AccessCreateAt: time.Now(),
94 | AccessExpiresIn: time.Second * 5,
95 | Refresh: "1_2_2",
96 | RefreshCreateAt: time.Now(),
97 | RefreshExpiresIn: time.Second * 15,
98 | }
99 | err := store.Create(ctx, info)
100 | So(err, ShouldBeNil)
101 |
102 | rinfo, err := store.GetByRefresh(ctx, info.GetRefresh())
103 | So(err, ShouldBeNil)
104 | So(rinfo.GetUserID(), ShouldEqual, info.GetUserID())
105 |
106 | err = store.RemoveByRefresh(ctx, info.GetRefresh())
107 | So(err, ShouldBeNil)
108 |
109 | rinfo, err = store.GetByRefresh(ctx, info.GetRefresh())
110 | So(err, ShouldBeNil)
111 | So(rinfo, ShouldBeNil)
112 | })
113 |
114 | Convey("Test TTL", func() {
115 | ctx := context.Background()
116 | info := &models.Token{
117 | ClientID: "1",
118 | UserID: "1_1",
119 | RedirectURI: "http://localhost/",
120 | Scope: "all",
121 | Access: "1_3_1",
122 | AccessCreateAt: time.Now(),
123 | AccessExpiresIn: time.Second * 1,
124 | Refresh: "1_3_2",
125 | RefreshCreateAt: time.Now(),
126 | RefreshExpiresIn: time.Second * 1,
127 | }
128 | err := store.Create(ctx, info)
129 | So(err, ShouldBeNil)
130 |
131 | time.Sleep(time.Second * 1)
132 | ainfo, err := store.GetByAccess(ctx, info.Access)
133 | So(err, ShouldBeNil)
134 | So(ainfo, ShouldBeNil)
135 | rinfo, err := store.GetByRefresh(ctx, info.Refresh)
136 | So(err, ShouldBeNil)
137 | So(rinfo, ShouldBeNil)
138 | })
139 | }
140 |
--------------------------------------------------------------------------------