├── .github └── workflows │ ├── ci.yml │ └── lint.yml ├── .golangci.yml ├── LICENSE ├── README.md ├── api ├── access_token.go ├── access_token_test.go ├── form.go └── form_test.go ├── device ├── device_flow.go ├── device_flow_test.go ├── examples_test.go └── poller.go ├── examples_test.go ├── go.mod ├── go.sum ├── oauth.go ├── oauth_device.go ├── oauth_webapp.go └── webapp ├── examples_test.go ├── local_server.go ├── local_server_test.go ├── webapp_flow.go └── webapp_flow_test.go /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | on: [push, pull_request] 2 | 3 | name: CI 4 | jobs: 5 | test: 6 | strategy: 7 | matrix: 8 | go: [ '1.21', '1.22', '1.23' ] 9 | os: [ ubuntu-latest, macos-latest, windows-latest ] 10 | fail-fast: false 11 | 12 | name: Test suite 13 | runs-on: ${{ matrix.os }} 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Setup Go 18 | uses: actions/setup-go@v5 19 | with: 20 | go-version: ${{ matrix.go }} 21 | - name: Run tests 22 | run: go test -v ./... 
23 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | on: 3 | push: 4 | paths: 5 | - "**.go" 6 | - go.mod 7 | - go.sum 8 | pull_request: 9 | paths: 10 | - "**.go" 11 | - go.mod 12 | - go.sum 13 | 14 | jobs: 15 | lint: 16 | runs-on: ubuntu-latest 17 | 18 | steps: 19 | - name: Set up Go 1.19 20 | uses: actions/setup-go@v3 21 | with: 22 | go-version: 1.19 23 | 24 | - name: Check out code 25 | uses: actions/checkout@v4 26 | 27 | - name: Verify dependencies 28 | env: 29 | LINT_VERSION: 1.50.1 30 | run: | 31 | go mod verify 32 | go mod download 33 | 34 | curl -fsSL https://github.com/golangci/golangci-lint/releases/download/v${LINT_VERSION}/golangci-lint-${LINT_VERSION}-linux-amd64.tar.gz | \ 35 | tar xz --strip-components 1 --wildcards \*/golangci-lint 36 | mkdir -p bin && mv golangci-lint bin/ 37 | 38 | - name: Run checks 39 | run: bin/golangci-lint run --out-format=github-actions 40 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | linters: 2 | enable: 3 | - gofmt 4 | - godot 5 | - revive 6 | 7 | linters-settings: 8 | godot: 9 | # comments to be checked: `declarations`, `toplevel`, or `all` 10 | scope: declarations 11 | # check that each sentence starts with a capital letter 12 | capital: true 13 | 14 | issues: 15 | exclude-use-default: false 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 GitHub, Inc. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # oauth 2 | 3 | A library for Go client applications that need to perform OAuth authorization against a server, typically GitHub.com. 4 | 5 |

6 |
7 | 8 |

9 | 10 | Traditionally, OAuth for web applications involves redirecting to a URI after the user authorizes an app. While web apps (and some native client apps) can receive a browser redirect, client apps such as CLI applications do not have such an option. 11 | 12 | To accommodate client apps, this library implements the [OAuth Device Authorization Grant][oauth-device], which [GitHub.com now supports][gh-device]. With Device flow, the user is presented with a one-time code that they will have to enter in a web browser while authorizing the app on the server. Device flow is suitable for cases where the web browser may be running on a different device from the client app itself; for example, a CLI application could run within a headless, containerized instance, but the user may complete authorization using a browser on their phone. 13 | 14 | To transparently enable OAuth authorization on _any GitHub host_ (e.g. GHES instances without OAuth “Device flow” support), this library also bundles an implementation of OAuth web application flow in which the client app starts a local server at `http://127.0.0.1:<port>/` that acts as a receiver for the browser redirect. First, Device flow is attempted, and the localhost server is used as a fallback. With the localhost server, the user's web browser must be running on the same machine as the client application itself. 15 | 16 | ## Usage 17 | 18 | - [OAuth Device flow with fallback](./examples_test.go) 19 | - [manual OAuth Device flow](./device/examples_test.go) 20 | - [manual OAuth web application flow](./webapp/examples_test.go) 21 | 22 | Applications that need more control over the user experience around authentication should interface directly with the `github.com/cli/oauth/device` and `github.com/cli/oauth/webapp` packages. 23 | 24 | In theory, these packages would enable authorization on any OAuth-enabled host. In practice, however, they have only been tested for authorizing with GitHub. 
25 | 26 | 27 | [oauth-device]: https://oauth.net/2/device-flow/ 28 | [gh-device]: https://docs.github.com/en/free-pro-team@latest/developers/apps/authorizing-oauth-apps#device-flow 29 | -------------------------------------------------------------------------------- /api/access_token.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | // AccessToken is an OAuth access token. 4 | type AccessToken struct { 5 | // The token value, typically a 40-character random string. 6 | Token string 7 | // The refresh token value, associated with the access token. 8 | RefreshToken string 9 | // The token type, e.g. "bearer". 10 | Type string 11 | // Space-separated list of OAuth scopes that this token grants. 12 | Scope string 13 | } 14 | 15 | // AccessToken extracts the access token information from a server response. If the response carries no "access_token" value, it returns the error described by the response (see FormResponse.Err) instead. 16 | func (f FormResponse) AccessToken() (*AccessToken, error) { 17 | if accessToken := f.Get("access_token"); accessToken != "" { 18 | return &AccessToken{ 19 | Token: accessToken, 20 | RefreshToken: f.Get("refresh_token"), 21 | Type: f.Get("token_type"), 22 | Scope: f.Get("scope"), 23 | }, nil 24 | } 25 | 26 | return nil, f.Err() 27 | } 28 | -------------------------------------------------------------------------------- /api/access_token_test.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "net/url" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestFormResponse_AccessToken(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | response FormResponse 13 | want *AccessToken 14 | wantErr *Error 15 | }{ 16 | { 17 | name: "with token", 18 | response: FormResponse{ 19 | values: url.Values{ 20 | "access_token": []string{"ATOKEN"}, 21 | "token_type": []string{"bearer"}, 22 | "scope": []string{"repo gist"}, 23 | }, 24 | }, 25 | want: &AccessToken{ 26 | Token: "ATOKEN", 27 | RefreshToken: "", 28 | Type: "bearer", 29 | Scope: "repo
gist", 30 | }, 31 | wantErr: nil, 32 | }, 33 | { 34 | name: "with refresh token", 35 | response: FormResponse{ 36 | values: url.Values{ 37 | "access_token": []string{"ATOKEN"}, 38 | "refresh_token": []string{"AREFRESHTOKEN"}, 39 | "token_type": []string{"bearer"}, 40 | "scope": []string{"repo gist"}, 41 | }, 42 | }, 43 | want: &AccessToken{ 44 | Token: "ATOKEN", 45 | RefreshToken: "AREFRESHTOKEN", 46 | Type: "bearer", 47 | Scope: "repo gist", 48 | }, 49 | wantErr: nil, 50 | }, 51 | { 52 | name: "no token", 53 | response: FormResponse{ 54 | StatusCode: 200, 55 | values: url.Values{ 56 | "error": []string{"access_denied"}, 57 | }, 58 | }, 59 | want: nil, 60 | wantErr: &Error{ 61 | Code: "access_denied", 62 | ResponseCode: 200, 63 | }, 64 | }, 65 | } 66 | for _, tt := range tests { 67 | t.Run(tt.name, func(t *testing.T) { 68 | got, err := tt.response.AccessToken() 69 | if err != nil { 70 | apiError := err.(*Error) 71 | if !reflect.DeepEqual(apiError, tt.wantErr) { 72 | t.Fatalf("error %v, want %v", apiError, tt.wantErr) 73 | } 74 | } else if tt.wantErr != nil { 75 | t.Fatalf("want error %v, got nil", tt.wantErr) 76 | } 77 | if !reflect.DeepEqual(got, tt.want) { 78 | t.Errorf("FormResponse.AccessToken() = %v, want %v", got, tt.want) 79 | } 80 | }) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /api/form.go: -------------------------------------------------------------------------------- 1 | // Package api implements request and response parsing logic shared between different OAuth strategies. 2 | package api 3 | 4 | import ( 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "io/ioutil" 9 | "mime" 10 | "net/http" 11 | "net/url" 12 | "strconv" 13 | ) 14 | 15 | type httpClient interface { 16 | PostForm(string, url.Values) (*http.Response, error) 17 | } 18 | 19 | // FormResponse is the parsed response from the server, decoded from either an "application/x-www-form-urlencoded" or an "application/json" body.
20 | type FormResponse struct { 21 | StatusCode int // HTTP status code returned by the server. 22 | 23 | requestURI string // URL the form was posted to. 24 | values url.Values // Decoded response body; nil when the content type was not recognized. 25 | } 26 | 27 | // Get returns the response value named k, or an empty string if k is absent. 28 | func (f FormResponse) Get(k string) string { 29 | return f.values.Get(k) 30 | } 31 | 32 | // Err returns an Error object extracted from the response. It always returns a non-nil error, even when the response carries no "error" value. 33 | func (f FormResponse) Err() error { 34 | return &Error{ 35 | RequestURI: f.requestURI, 36 | ResponseCode: f.StatusCode, 37 | Code: f.Get("error"), 38 | message: f.Get("error_description"), 39 | } 40 | } 41 | 42 | // Error is the result of an unexpected HTTP response from the server. 43 | type Error struct { 44 | Code string // The response's "error" value, if any. 45 | ResponseCode int // HTTP status code of the response. 46 | RequestURI string // URL the request was sent to. 47 | 48 | message string // The response's "error_description" value, if any. 49 | } 50 | 51 | func (e Error) Error() string { 52 | if e.message != "" { 53 | return fmt.Sprintf("%s (%s)", e.message, e.Code) 54 | } 55 | if e.Code != "" { 56 | return e.Code 57 | } 58 | return fmt.Sprintf("HTTP %d", e.ResponseCode) 59 | } 60 | 61 | // PostForm makes a POST request by serializing input parameters as a form and parsing the response 62 | // body as either "application/x-www-form-urlencoded" or "application/json".
63 | func PostForm(c httpClient, u string, params url.Values) (*FormResponse, error) { 64 | resp, err := c.PostForm(u, params) 65 | if err != nil { 66 | return nil, err 67 | } 68 | defer func() { 69 | _ = resp.Body.Close() 70 | }() 71 | 72 | r := &FormResponse{ 73 | StatusCode: resp.StatusCode, 74 | requestURI: u, 75 | } 76 | 77 | mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type")) 78 | switch mediaType { 79 | case "application/x-www-form-urlencoded": 80 | var bb []byte 81 | bb, err = ioutil.ReadAll(resp.Body) 82 | if err != nil { 83 | return r, err 84 | } 85 | 86 | r.values, err = url.ParseQuery(string(bb)) 87 | if err != nil { 88 | return r, err 89 | } 90 | case "application/json": 91 | var values map[string]interface{} 92 | if err := json.NewDecoder(resp.Body).Decode(&values); err != nil { 93 | return r, err 94 | } 95 | 96 | r.values = make(url.Values) 97 | for key, value := range values { 98 | switch v := value.(type) { // encoding/json decodes JSON scalars into string, float64, or bool (never int64); nested objects, arrays, and null are dropped. 99 | case string: 100 | r.values.Set(key, v) 101 | case bool: 102 | r.values.Set(key, strconv.FormatBool(v)) 103 | case float64: 104 | r.values.Set(key, strconv.FormatFloat(v, 'f', -1, 64)) 105 | } 106 | } 107 | default: 108 | _, err = io.Copy(ioutil.Discard, resp.Body) 109 | if err != nil { 110 | return r, err 111 | } 112 | } 113 | 114 | return r, nil 115 | } 116 | -------------------------------------------------------------------------------- /api/form_test.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "bytes" 5 | "io/ioutil" 6 | "net/http" 7 | "net/url" 8 | "reflect" 9 | "testing" 10 | ) 11 | 12 | func TestFormResponse_Get(t *testing.T) { 13 | tests := []struct { 14 | name string 15 | response FormResponse 16 | key string 17 | want string 18 | }{ 19 | { 20 | name: "blank", 21 | response: FormResponse{}, 22 | key: "access_token", 23 | want: "", 24 | }, 25 | { 26 | name: "with value", 27 | response: FormResponse{ 28 | values: url.Values{ 29 |
"access_token": []string{"ATOKEN"}, 30 | }, 31 | }, 32 | key: "access_token", 33 | want: "ATOKEN", 34 | }, 35 | } 36 | for _, tt := range tests { 37 | t.Run(tt.name, func(t *testing.T) { 38 | if got := tt.response.Get(tt.key); got != tt.want { 39 | t.Errorf("FormResponse.Get() = %v, want %v", got, tt.want) 40 | } 41 | }) 42 | } 43 | } 44 | 45 | func TestFormResponse_Err(t *testing.T) { 46 | tests := []struct { 47 | name string 48 | response FormResponse 49 | wantErr Error 50 | errorMsg string 51 | }{ 52 | { 53 | name: "blank", 54 | response: FormResponse{}, 55 | wantErr: Error{}, 56 | errorMsg: "HTTP 0", 57 | }, 58 | { 59 | name: "with values", 60 | response: FormResponse{ 61 | StatusCode: 422, 62 | requestURI: "http://example.com/path", 63 | values: url.Values{ 64 | "error": []string{"try_again"}, 65 | "error_description": []string{"maybe it works later"}, 66 | }, 67 | }, 68 | wantErr: Error{ 69 | Code: "try_again", 70 | ResponseCode: 422, 71 | RequestURI: "http://example.com/path", 72 | }, 73 | errorMsg: "maybe it works later (try_again)", 74 | }, 75 | { 76 | name: "no values", 77 | response: FormResponse{ 78 | StatusCode: 422, 79 | requestURI: "http://example.com/path", 80 | }, 81 | wantErr: Error{ 82 | Code: "", 83 | ResponseCode: 422, 84 | RequestURI: "http://example.com/path", 85 | }, 86 | errorMsg: "HTTP 422", 87 | }, 88 | } 89 | for _, tt := range tests { 90 | t.Run(tt.name, func(t *testing.T) { 91 | err := tt.response.Err() 92 | if err == nil { 93 | t.Fatalf("FormResponse.Err() = %v, want %v", nil, tt.wantErr) 94 | } 95 | apiError := err.(*Error) 96 | if apiError.Code != tt.wantErr.Code { 97 | t.Errorf("Error.Code = %v, want %v", apiError.Code, tt.wantErr.Code) 98 | } 99 | if apiError.ResponseCode != tt.wantErr.ResponseCode { 100 | t.Errorf("Error.ResponseCode = %v, want %v", apiError.ResponseCode, tt.wantErr.ResponseCode) 101 | } 102 | if apiError.RequestURI != tt.wantErr.RequestURI { 103 | t.Errorf("Error.RequestURI = %v, want %v", apiError.RequestURI,
tt.wantErr.RequestURI) 104 | } 105 | if apiError.Error() != tt.errorMsg { 106 | t.Errorf("Error.Error() = %q, want %q", apiError.Error(), tt.errorMsg) 107 | } 108 | }) 109 | } 110 | } 111 | 112 | type apiClient struct { 113 | status int 114 | body string 115 | contentType string 116 | 117 | postCount int 118 | } 119 | 120 | func (c *apiClient) PostForm(u string, params url.Values) (*http.Response, error) { 121 | c.postCount++ 122 | return &http.Response{ 123 | Body: ioutil.NopCloser(bytes.NewBufferString(c.body)), 124 | Header: http.Header{ 125 | "Content-Type": {c.contentType}, 126 | }, 127 | StatusCode: c.status, 128 | }, nil 129 | } 130 | 131 | func TestPostForm(t *testing.T) { 132 | type args struct { 133 | url string 134 | params url.Values 135 | } 136 | tests := []struct { 137 | name string 138 | args args 139 | http apiClient 140 | want *FormResponse 141 | wantErr bool 142 | }{ 143 | { 144 | name: "success urlencoded", 145 | args: args{ 146 | url: "https://github.com/oauth", 147 | }, 148 | http: apiClient{ 149 | body: "access_token=123abc&scopes=repo%20gist", 150 | status: 200, 151 | contentType: "application/x-www-form-urlencoded; charset=utf-8", 152 | }, 153 | want: &FormResponse{ 154 | StatusCode: 200, 155 | requestURI: "https://github.com/oauth", 156 | values: url.Values{ 157 | "access_token": {"123abc"}, 158 | "scopes": {"repo gist"}, 159 | }, 160 | }, 161 | wantErr: false, 162 | }, 163 | { 164 | name: "success JSON", 165 | args: args{ 166 | url: "https://github.com/oauth", 167 | }, 168 | http: apiClient{ 169 | body: `{"access_token":"123abc", "scopes":"repo gist"}`, 170 | status: 200, 171 | contentType: "application/json; charset=utf-8", 172 | }, 173 | want: &FormResponse{ 174 | StatusCode: 200, 175 | requestURI: "https://github.com/oauth", 176 | values: url.Values{ 177 | "access_token": {"123abc"}, 178 | "scopes": {"repo gist"}, 179 | }, 180 | }, 181 | wantErr: false, 182 | }, 183 | { 184 | name: "HTML response", 185 | args: args{ 186 | url:
"https://github.com/oauth", 187 | }, 188 | http: apiClient{ 189 | body: "

Something went wrong

", 190 | status: 502, 191 | contentType: "text/html", 192 | }, 193 | want: &FormResponse{ 194 | StatusCode: 502, 195 | requestURI: "https://github.com/oauth", 196 | values: url.Values(nil), 197 | }, 198 | wantErr: false, 199 | }, 200 | } 201 | for _, tt := range tests { 202 | t.Run(tt.name, func(t *testing.T) { 203 | got, err := PostForm(&tt.http, tt.args.url, tt.args.params) 204 | if (err != nil) != tt.wantErr { 205 | t.Errorf("PostForm() error = %v, wantErr %v", err, tt.wantErr) 206 | return 207 | } 208 | if tt.http.postCount != 1 { 209 | t.Errorf("expected PostForm to happen 1 time; happened %d times", tt.http.postCount) 210 | } 211 | if !reflect.DeepEqual(got, tt.want) { 212 | t.Errorf("PostForm() = %v, want %v", got, tt.want) 213 | } 214 | }) 215 | } 216 | } 217 | -------------------------------------------------------------------------------- /device/device_flow.go: -------------------------------------------------------------------------------- 1 | // Package device facilitates performing OAuth Device Authorization Flow for client applications 2 | // such as CLIs that can not receive redirects from a web site. 3 | // 4 | // First, RequestCode should be used to obtain a CodeResponse. 5 | // 6 | // Next, the user will need to navigate to VerificationURI in their web browser on any device and fill 7 | // in the UserCode. 8 | // 9 | // While the user is completing the web flow, the application should invoke Wait (PollToken is deprecated), which blocks 10 | // the goroutine until the user has authorized the app on the server. 11 | // 12 | // https://docs.github.com/en/free-pro-team@latest/developers/apps/authorizing-oauth-apps#device-flow 13 | package device 14 | 15 | import ( 16 | "context" 17 | "errors" 18 | "fmt" 19 | "net/http" 20 | "net/url" 21 | "strconv" 22 | "strings" 23 | "time" 24 | 25 | "github.com/cli/oauth/api" 26 | ) 27 | 28 | var ( 29 | // ErrUnsupported is returned when the server does not implement Device flow.
30 | ErrUnsupported = errors.New("device flow not supported") 31 | // ErrTimeout is returned when polling the server for the granted token has timed out. 32 | ErrTimeout = errors.New("authentication timed out") 33 | ) 34 | 35 | type httpClient interface { 36 | PostForm(string, url.Values) (*http.Response, error) 37 | } 38 | 39 | // CodeResponse holds information about the authorization-in-progress. 40 | type CodeResponse struct { 41 | // The user verification code is displayed on the device so the user can enter the code in a browser. 42 | UserCode string 43 | // The verification URL where users need to enter the UserCode. 44 | VerificationURI string 45 | // The optional verification URL that includes the UserCode. 46 | VerificationURIComplete string 47 | 48 | // The device verification code is 40 characters and used to verify the device. 49 | DeviceCode string 50 | // The number of seconds before the DeviceCode and UserCode expire. 51 | ExpiresIn int 52 | // The minimum number of seconds that must pass before you can make a new access token request to 53 | // complete the device authorization. 54 | Interval int 55 | } 56 | 57 | // AuthRequestEditorFn defines the function signature for setting additional form values. 58 | type AuthRequestEditorFn func(*url.Values) 59 | 60 | // WithAudience sets the audience parameter in the request. An empty audience leaves the request unchanged. 61 | func WithAudience(audience string) AuthRequestEditorFn { 62 | return func(values *url.Values) { 63 | if audience != "" { 64 | values.Add("audience", audience) 65 | } 66 | } 67 | } 68 | 69 | // RequestCode initiates the authorization flow by requesting a code from uri. It returns ErrUnsupported when the server signals that Device flow is not available.
70 | func RequestCode(c httpClient, uri string, clientID string, scopes []string, 71 | optionalRequestParams ...AuthRequestEditorFn) (*CodeResponse, error) { 72 | values := url.Values{ 73 | "client_id": {clientID}, 74 | "scope": {strings.Join(scopes, " ")}, 75 | } 76 | 77 | for _, fn := range optionalRequestParams { 78 | fn(&values) 79 | } 80 | 81 | resp, err := api.PostForm(c, uri, values) 82 | if err != nil { 83 | return nil, err 84 | } 85 | 86 | verificationURI := resp.Get("verification_uri") 87 | if verificationURI == "" { 88 | // Google's "OAuth 2.0 for TV and Limited-Input Device Applications" uses `verification_url`. 89 | verificationURI = resp.Get("verification_url") 90 | } 91 | 92 | if resp.StatusCode == 401 || resp.StatusCode == 403 || resp.StatusCode == 404 || resp.StatusCode == 422 || 93 | (resp.StatusCode == 200 && verificationURI == "") || 94 | (resp.StatusCode == 400 && resp.Get("error") == "device_flow_disabled") || 95 | (resp.StatusCode == 400 && resp.Get("error") == "unauthorized_client") { 96 | return nil, ErrUnsupported 97 | } 98 | 99 | if resp.StatusCode != 200 { 100 | return nil, resp.Err() 101 | } 102 | intervalSeconds := 5 // RFC 8628 §3.2: "interval" is optional and clients MUST default to 5 seconds when it is absent. 103 | if s := resp.Get("interval"); s != "" { 104 | if intervalSeconds, err = strconv.Atoi(s); err != nil { 105 | return nil, fmt.Errorf("could not parse interval=%q as integer: %w", s, err) 106 | } 107 | } 108 | expiresIn, err := strconv.Atoi(resp.Get("expires_in")) 109 | if err != nil { 110 | return nil, fmt.Errorf("could not parse expires_in=%q as integer: %w", resp.Get("expires_in"), err) 111 | } 112 | 113 | return &CodeResponse{ 114 | DeviceCode: resp.Get("device_code"), 115 | UserCode: resp.Get("user_code"), 116 | VerificationURI: verificationURI, 117 | VerificationURIComplete: resp.Get("verification_uri_complete"), 118 | Interval: intervalSeconds, 119 | ExpiresIn: expiresIn, 120 | }, nil 121 | } 122 | 123 | const defaultGrantType = "urn:ietf:params:oauth:grant-type:device_code" // Grant type defined by RFC 8628 §3.4. 124 | 125 | // PollToken polls the server at pollURL until an
access token is granted or denied. 126 | // 127 | // Deprecated: use Wait. 128 | func PollToken(c httpClient, pollURL string, clientID string, code *CodeResponse) (*api.AccessToken, error) { 129 | return Wait(context.Background(), c, pollURL, WaitOptions{ 130 | ClientID: clientID, 131 | DeviceCode: code, 132 | }) 133 | } 134 | 135 | // WaitOptions specifies parameters to poll the server with until authentication completes. 136 | type WaitOptions struct { 137 | // ClientID is the app client ID value. 138 | ClientID string 139 | // ClientSecret is the app client secret value. Optional: only pass if the server requires it. 140 | ClientSecret string 141 | // DeviceCode is the value obtained from RequestCode. 142 | DeviceCode *CodeResponse 143 | // GrantType overrides the default value specified by OAuth 2.0 Device Code. Optional. 144 | GrantType string 145 | 146 | newPoller pollerFactory 147 | } 148 | 149 | // Wait polls the server at uri until an access token is granted or polling fails. 150 | // 151 | // An "authorization_pending" response keeps polling at the current rate. A "slow_down" response 152 | // adds 5 seconds to the polling interval for that and all subsequent requests (RFC 8628 §3.5). 153 | // Any other error, including errors returned by the poller, is returned to the caller. 154 | func Wait(ctx context.Context, c httpClient, uri string, opts WaitOptions) (*api.AccessToken, error) { 155 | checkInterval := time.Duration(opts.DeviceCode.Interval) * time.Second 156 | expiresIn := time.Duration(opts.DeviceCode.ExpiresIn) * time.Second 157 | grantType := opts.GrantType 158 | if opts.GrantType == "" { 159 | grantType = defaultGrantType 160 | } 161 | 162 | makePoller := opts.newPoller 163 | if makePoller == nil { 164 | makePoller = newPoller 165 | } 166 | pollCtx, poll := makePoller(ctx, checkInterval, expiresIn) 167 | 168 | // Extra delay accumulated from "slow_down" responses, applied on top of the poller's interval. 169 | var slowDown time.Duration 170 | 171 | for { 172 | if err := poll.Wait(); err != nil { 173 | return nil, err 174 | } 175 | if slowDown > 0 { 176 | select { 177 | case <-pollCtx.Done(): 178 | return nil, pollCtx.Err() 179 | case <-time.After(slowDown): 180 | } 181 | } 182 | 183 | values := url.Values{ 184 | "client_id": {opts.ClientID}, 185 | "device_code": {opts.DeviceCode.DeviceCode}, 186 | "grant_type": {grantType}, 187 | } 188 | 189 | // Google's "OAuth 2.0 for TV and Limited-Input Device Applications" requires `client_secret`.
190 | if opts.ClientSecret != "" { 191 | values.Add("client_secret", opts.ClientSecret) 192 | } 193 | 194 | // TODO: pass pollCtx down to the HTTP layer; httpClient.PostForm does not accept a context yet. 195 | resp, err := api.PostForm(c, uri, values) 196 | if err != nil { 197 | return nil, err 198 | } 199 | 200 | token, err := resp.AccessToken() 201 | if err == nil { 202 | return token, nil 203 | } 204 | var apiError *api.Error 205 | if !errors.As(err, &apiError) { 206 | return nil, err 207 | } 208 | switch apiError.Code { 209 | case "authorization_pending": 210 | // The user has not finished authorizing yet; poll again at the current rate. 211 | case "slow_down": 212 | slowDown += 5 * time.Second 213 | default: 214 | return nil, err 215 | } 216 | } 217 | } 218 | -------------------------------------------------------------------------------- /device/device_flow_test.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "io/ioutil" 8 | "net/http" 9 | "net/url" 10 | "reflect" 11 | "testing" 12 | "time" 13 | 14 | "github.com/cli/oauth/api" 15 | ) 16 | 17 | type apiStub struct { 18 | status int 19 | body string 20 | contentType string 21 | } 22 | 23 | type postArgs struct { 24 | url string 25 | params url.Values 26 | } 27 | 28 | type apiClient struct { 29 | stubs []apiStub 30 | calls []postArgs 31 | 32 | postCount int 33 | } 34 | 35 | func (c *apiClient) PostForm(u string, params url.Values) (*http.Response, error) { 36 | stub := c.stubs[c.postCount] 37 | c.calls = append(c.calls, postArgs{url: u, params: params}) 38 | c.postCount++ 39 | return &http.Response{ 40 | Body: ioutil.NopCloser(bytes.NewBufferString(stub.body)), 41 | Header: http.Header{ 42 | "Content-Type": {stub.contentType}, 43 | }, 44 | StatusCode: stub.status, 45 | }, nil 46 | } 47 | 48 | func TestRequestCode(t *testing.T) { 49 | type args struct { 50 | http apiClient 51 | url string 52 | clientID string 53 | scopes []string 54 | audience string 55 | } 56 | tests := []struct { 57 | name string 58 | args args 59 | want *CodeResponse 60 | wantErr string 61 | posts []postArgs 62 | }{ 63 | { 64 | name: "success", 65 | args: args{ 66 | http: apiClient{ 67
| stubs: []apiStub{ 68 | { 69 | body: "verification_uri=http://verify.me&interval=5&expires_in=99&device_code=DEVIC&user_code=123-abc", 70 | status: 200, 71 | contentType: "application/x-www-form-urlencoded; charset=utf-8", 72 | }, 73 | }, 74 | }, 75 | url: "https://github.com/oauth", 76 | clientID: "CLIENT-ID", 77 | scopes: []string{"repo", "gist"}, 78 | }, 79 | want: &CodeResponse{ 80 | DeviceCode: "DEVIC", 81 | UserCode: "123-abc", 82 | VerificationURI: "http://verify.me", 83 | ExpiresIn: 99, 84 | Interval: 5, 85 | }, 86 | posts: []postArgs{ 87 | { 88 | url: "https://github.com/oauth", 89 | params: url.Values{ 90 | "client_id": {"CLIENT-ID"}, 91 | "scope": {"repo gist"}, 92 | }, 93 | }, 94 | }, 95 | }, 96 | { 97 | name: "with verification_uri_complete", 98 | args: args{ 99 | http: apiClient{ 100 | stubs: []apiStub{ 101 | { 102 | body: "verification_uri=http://verify.me&interval=5&expires_in=99&device_code=DEVIC&user_code=123-abc&verification_uri_complete=http://verify.me/?code=123-abc", 103 | status: 200, 104 | contentType: "application/x-www-form-urlencoded; charset=utf-8", 105 | }, 106 | }, 107 | }, 108 | url: "https://github.com/oauth", 109 | clientID: "CLIENT-ID", 110 | scopes: []string{"repo", "gist"}, 111 | }, 112 | want: &CodeResponse{ 113 | DeviceCode: "DEVIC", 114 | UserCode: "123-abc", 115 | VerificationURI: "http://verify.me", 116 | VerificationURIComplete: "http://verify.me/?code=123-abc", 117 | ExpiresIn: 99, 118 | Interval: 5, 119 | }, 120 | posts: []postArgs{ 121 | { 122 | url: "https://github.com/oauth", 123 | params: url.Values{ 124 | "client_id": {"CLIENT-ID"}, 125 | "scope": {"repo gist"}, 126 | }, 127 | }, 128 | }, 129 | }, 130 | { 131 | name: "with audience", 132 | args: args{ 133 | http: apiClient{ 134 | stubs: []apiStub{ 135 | { 136 | body: "verification_uri=http://verify.me&interval=5&expires_in=99&device_code=DEVIC&user_code=123-abc&verification_uri_complete=http://verify.me/?code=123-abc", 137 | status: 200, 138 | contentType:
"application/x-www-form-urlencoded; charset=utf-8", 139 | }, 140 | }, 141 | }, 142 | url: "https://github.com/oauth", 143 | clientID: "CLIENT-ID", 144 | scopes: []string{"repo", "gist"}, 145 | audience: "https://api.github.com", 146 | }, 147 | want: &CodeResponse{ 148 | DeviceCode: "DEVIC", 149 | UserCode: "123-abc", 150 | VerificationURI: "http://verify.me", 151 | VerificationURIComplete: "http://verify.me/?code=123-abc", 152 | ExpiresIn: 99, 153 | Interval: 5, 154 | }, 155 | posts: []postArgs{ 156 | { 157 | url: "https://github.com/oauth", 158 | params: url.Values{ 159 | "client_id": {"CLIENT-ID"}, 160 | "scope": {"repo gist"}, 161 | "audience": {"https://api.github.com"}, 162 | }, 163 | }, 164 | }, 165 | }, 166 | { 167 | name: "unsupported", 168 | args: args{ 169 | http: apiClient{ 170 | stubs: []apiStub{ 171 | { 172 | body: "", 173 | status: 404, 174 | contentType: "text/html", 175 | }, 176 | }, 177 | }, 178 | url: "https://github.com/oauth", 179 | clientID: "CLIENT-ID", 180 | scopes: []string{"repo", "gist"}, 181 | }, 182 | wantErr: "device flow not supported", 183 | posts: []postArgs{ 184 | { 185 | url: "https://github.com/oauth", 186 | params: url.Values{ 187 | "client_id": {"CLIENT-ID"}, 188 | "scope": {"repo gist"}, 189 | }, 190 | }, 191 | }, 192 | }, 193 | { 194 | name: "unauthorized client", 195 | args: args{ 196 | http: apiClient{ 197 | stubs: []apiStub{ 198 | { 199 | body: "error=unauthorized_client", 200 | status: 400, 201 | contentType: "application/x-www-form-urlencoded; charset=utf-8", 202 | }, 203 | }, 204 | }, 205 | url: "https://github.com/oauth", 206 | clientID: "CLIENT-ID", 207 | scopes: []string{"repo", "gist"}, 208 | }, 209 | wantErr: "device flow not supported", 210 | posts: []postArgs{ 211 | { 212 | url: "https://github.com/oauth", 213 | params: url.Values{ 214 | "client_id": {"CLIENT-ID"}, 215 | "scope": {"repo gist"}, 216 | }, 217 | }, 218 | }, 219 | }, 220 | { 221 | name: "device flow disabled", 222 | args: args{ 223 | http: apiClient{
224 | stubs: []apiStub{ 225 | { 226 | body: "error=device_flow_disabled", 227 | status: 400, 228 | contentType: "application/x-www-form-urlencoded; charset=utf-8", 229 | }, 230 | }, 231 | }, 232 | url: "https://github.com/oauth", 233 | clientID: "CLIENT-ID", 234 | scopes: []string{"repo", "gist"}, 235 | }, 236 | wantErr: "device flow not supported", 237 | posts: []postArgs{ 238 | { 239 | url: "https://github.com/oauth", 240 | params: url.Values{ 241 | "client_id": {"CLIENT-ID"}, 242 | "scope": {"repo gist"}, 243 | }, 244 | }, 245 | }, 246 | }, 247 | { 248 | name: "server error", 249 | args: args{ 250 | http: apiClient{ 251 | stubs: []apiStub{ 252 | { 253 | body: "

Something went wrong

", 254 | status: 502, 255 | contentType: "text/html", 256 | }, 257 | }, 258 | }, 259 | url: "https://github.com/oauth", 260 | clientID: "CLIENT-ID", 261 | scopes: []string{"repo", "gist"}, 262 | }, 263 | wantErr: "HTTP 502", 264 | posts: []postArgs{ 265 | { 266 | url: "https://github.com/oauth", 267 | params: url.Values{ 268 | "client_id": {"CLIENT-ID"}, 269 | "scope": {"repo gist"}, 270 | }, 271 | }, 272 | }, 273 | }, 274 | } 275 | for _, tt := range tests { 276 | t.Run(tt.name, func(t *testing.T) { 277 | got, err := RequestCode(&tt.args.http, tt.args.url, 278 | tt.args.clientID, tt.args.scopes, WithAudience(tt.args.audience)) 279 | if (err != nil) != (tt.wantErr != "") { 280 | t.Errorf("RequestCode() error = %v, wantErr %v", err, tt.wantErr) 281 | return 282 | } 283 | if tt.wantErr != "" && err.Error() != tt.wantErr { 284 | t.Errorf("error = %q, want %q", err.Error(), tt.wantErr) 285 | } 286 | if tt.args.http.postCount != 1 { 287 | t.Errorf("expected PostForm to happen 1 time; happened %d times", tt.args.http.postCount) 288 | } 289 | if !reflect.DeepEqual(got, tt.want) { 290 | t.Errorf("RequestCode() = %v, want %v", got, tt.want) 291 | } 292 | if !reflect.DeepEqual(tt.args.http.calls, tt.posts) { 293 | t.Errorf("PostForm() = %v, want %v", tt.args.http.calls, tt.posts) 294 | } 295 | }) 296 | } 297 | } 298 | 299 | func TestPollToken(t *testing.T) { 300 | makeFakePoller := func(maxWaits int) pollerFactory { 301 | return func(ctx context.Context, interval, expiresIn time.Duration) (context.Context, poller) { 302 | return ctx, &fakePoller{maxWaits: maxWaits} 303 | } 304 | } 305 | 306 | type args struct { 307 | http apiClient 308 | url string 309 | opts WaitOptions 310 | } 311 | tests := []struct { 312 | name string 313 | args args 314 | want *api.AccessToken 315 | wantErr string 316 | posts []postArgs 317 | slept time.Duration 318 | }{ 319 | { 320 | name: "success", 321 | args: args{ 322 | http: apiClient{ 323 | stubs: []apiStub{ 324 | { 325 | body:
"error=authorization_pending", 326 | status: 200, 327 | contentType: "application/x-www-form-urlencoded; charset=utf-8", 328 | }, 329 | { 330 | body: "access_token=123abc", 331 | status: 200, 332 | contentType: "application/x-www-form-urlencoded; charset=utf-8", 333 | }, 334 | }, 335 | }, 336 | url: "https://github.com/oauth", 337 | opts: WaitOptions{ 338 | ClientID: "CLIENT-ID", 339 | DeviceCode: &CodeResponse{ 340 | DeviceCode: "DEVIC", 341 | UserCode: "123-abc", 342 | VerificationURI: "http://verify.me", 343 | ExpiresIn: 99, 344 | Interval: 5, 345 | }, 346 | newPoller: makeFakePoller(2), 347 | }, 348 | }, 349 | want: &api.AccessToken{ 350 | Token: "123abc", 351 | }, 352 | posts: []postArgs{ 353 | { 354 | url: "https://github.com/oauth", 355 | params: url.Values{ 356 | "client_id": {"CLIENT-ID"}, 357 | "device_code": {"DEVIC"}, 358 | "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, 359 | }, 360 | }, 361 | { 362 | url: "https://github.com/oauth", 363 | params: url.Values{ 364 | "client_id": {"CLIENT-ID"}, 365 | "device_code": {"DEVIC"}, 366 | "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, 367 | }, 368 | }, 369 | }, 370 | }, 371 | { 372 | name: "with client secret and grant type", 373 | args: args{ 374 | http: apiClient{ 375 | stubs: []apiStub{ 376 | { 377 | body: "access_token=123abc", 378 | status: 200, 379 | contentType: "application/x-www-form-urlencoded; charset=utf-8", 380 | }, 381 | }, 382 | }, 383 | url: "https://github.com/oauth", 384 | opts: WaitOptions{ 385 | ClientID: "CLIENT-ID", 386 | ClientSecret: "SEKRIT", 387 | GrantType: "device_code", 388 | DeviceCode: &CodeResponse{ 389 | DeviceCode: "DEVIC", 390 | UserCode: "123-abc", 391 | VerificationURI: "http://verify.me", 392 | ExpiresIn: 99, 393 | Interval: 5, 394 | }, 395 | newPoller: makeFakePoller(1), 396 | }, 397 | }, 398 | want: &api.AccessToken{ 399 | Token: "123abc", 400 | }, 401 | posts: []postArgs{ 402 | { 403 | url: "https://github.com/oauth", 404 | params: 
url.Values{ 405 | "client_id": {"CLIENT-ID"}, 406 | "client_secret": {"SEKRIT"}, 407 | "device_code": {"DEVIC"}, 408 | "grant_type": {"device_code"}, 409 | }, 410 | }, 411 | }, 412 | }, 413 | { 414 | name: "timed out", 415 | args: args{ 416 | http: apiClient{ 417 | stubs: []apiStub{ 418 | { 419 | body: "error=authorization_pending", 420 | status: 200, 421 | contentType: "application/x-www-form-urlencoded; charset=utf-8", 422 | }, 423 | { 424 | body: "error=authorization_pending", 425 | status: 200, 426 | contentType: "application/x-www-form-urlencoded; charset=utf-8", 427 | }, 428 | }, 429 | }, 430 | url: "https://github.com/oauth", 431 | opts: WaitOptions{ 432 | ClientID: "CLIENT-ID", 433 | DeviceCode: &CodeResponse{ 434 | DeviceCode: "DEVIC", 435 | UserCode: "123-abc", 436 | VerificationURI: "http://verify.me", 437 | ExpiresIn: 14, 438 | Interval: 5, 439 | }, 440 | newPoller: makeFakePoller(2), 441 | }, 442 | }, 443 | wantErr: "context deadline exceeded", 444 | posts: []postArgs{ 445 | { 446 | url: "https://github.com/oauth", 447 | params: url.Values{ 448 | "client_id": {"CLIENT-ID"}, 449 | "device_code": {"DEVIC"}, 450 | "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, 451 | }, 452 | }, 453 | { 454 | url: "https://github.com/oauth", 455 | params: url.Values{ 456 | "client_id": {"CLIENT-ID"}, 457 | "device_code": {"DEVIC"}, 458 | "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, 459 | }, 460 | }, 461 | }, 462 | }, 463 | { 464 | name: "access denied", 465 | args: args{ 466 | http: apiClient{ 467 | stubs: []apiStub{ 468 | { 469 | body: "error=access_denied", 470 | status: 200, 471 | contentType: "application/x-www-form-urlencoded; charset=utf-8", 472 | }, 473 | }, 474 | }, 475 | url: "https://github.com/oauth", 476 | opts: WaitOptions{ 477 | ClientID: "CLIENT-ID", 478 | DeviceCode: &CodeResponse{ 479 | DeviceCode: "DEVIC", 480 | UserCode: "123-abc", 481 | VerificationURI: "http://verify.me", 482 | ExpiresIn: 99, 483 | Interval: 5, 484 | 
}, 485 | newPoller: makeFakePoller(1), 486 | }, 487 | }, 488 | wantErr: "access_denied", 489 | posts: []postArgs{ 490 | { 491 | url: "https://github.com/oauth", 492 | params: url.Values{ 493 | "client_id": {"CLIENT-ID"}, 494 | "device_code": {"DEVIC"}, 495 | "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, 496 | }, 497 | }, 498 | }, 499 | }, 500 | } 501 | for _, tt := range tests { 502 | t.Run(tt.name, func(t *testing.T) { 503 | got, err := Wait(context.Background(), &tt.args.http, tt.args.url, tt.args.opts) 504 | if (err != nil) != (tt.wantErr != "") { 505 | t.Errorf("PollToken() error = %v, wantErr %v", err, tt.wantErr) 506 | return 507 | } 508 | if tt.wantErr != "" && err.Error() != tt.wantErr { 509 | t.Errorf("PollToken error = %q, want %q", err.Error(), tt.wantErr) 510 | } 511 | if !reflect.DeepEqual(got, tt.want) { 512 | t.Errorf("PollToken() = %v, want %v", got, tt.want) 513 | } 514 | if !reflect.DeepEqual(tt.args.http.calls, tt.posts) { 515 | t.Errorf("PostForm() = %v, want %v", tt.args.http.calls, tt.posts) 516 | } 517 | }) 518 | } 519 | } 520 | 521 | type fakePoller struct { 522 | maxWaits int 523 | count int 524 | } 525 | 526 | func (p *fakePoller) Wait() error { 527 | if p.count == p.maxWaits { 528 | return errors.New("context deadline exceeded") 529 | } 530 | p.count++ 531 | return nil 532 | } 533 | 534 | func (p *fakePoller) Cancel() { 535 | } 536 | -------------------------------------------------------------------------------- /device/examples_test.go: -------------------------------------------------------------------------------- 1 | package device_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | "os" 8 | 9 | "github.com/cli/oauth/device" 10 | ) 11 | 12 | // This demonstrates how to perform OAuth Device Authorization Flow for GitHub.com. 
13 | // After RequestCode successfully completes, the client app should prompt the user to copy 14 | // the UserCode and to open VerificationURI in their web browser to enter the code. 15 | func ExampleRequestCode() { 16 | clientID := os.Getenv("OAUTH_CLIENT_ID") 17 | scopes := []string{"repo", "read:org"} 18 | httpClient := http.DefaultClient 19 | 20 | code, err := device.RequestCode(httpClient, "https://github.com/login/device/code", clientID, scopes) 21 | if err != nil { 22 | panic(err) 23 | } 24 | 25 | fmt.Printf("Copy code: %s\n", code.UserCode) 26 | fmt.Printf("then open: %s\n", code.VerificationURI) 27 | 28 | accessToken, err := device.Wait(context.TODO(), httpClient, "https://github.com/login/oauth/access_token", device.WaitOptions{ 29 | ClientID: clientID, 30 | DeviceCode: code, 31 | }) 32 | if err != nil { 33 | panic(err) 34 | } 35 | 36 | fmt.Printf("Access token: %s\n", accessToken.Token) 37 | } 38 | -------------------------------------------------------------------------------- /device/poller.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | import ( 4 | "context" 5 | "time" 6 | ) 7 | 8 | type poller interface { 9 | Wait() error 10 | Cancel() 11 | } 12 | 13 | type pollerFactory func(context.Context, time.Duration, time.Duration) (context.Context, poller) 14 | 15 | func newPoller(ctx context.Context, checkInteval, expiresIn time.Duration) (context.Context, poller) { 16 | c, cancel := context.WithTimeout(ctx, expiresIn) 17 | return c, &intervalPoller{ 18 | ctx: c, 19 | interval: checkInteval, 20 | cancelFunc: cancel, 21 | } 22 | } 23 | 24 | type intervalPoller struct { 25 | ctx context.Context 26 | interval time.Duration 27 | cancelFunc func() 28 | } 29 | 30 | func (p intervalPoller) Wait() error { 31 | t := time.NewTimer(p.interval) 32 | select { 33 | case <-p.ctx.Done(): 34 | t.Stop() 35 | return p.ctx.Err() 36 | case <-t.C: 37 | return nil 38 | } 39 | } 40 | 41 | func (p intervalPoller) 
Cancel() { 42 | p.cancelFunc() 43 | } 44 | -------------------------------------------------------------------------------- /examples_test.go: -------------------------------------------------------------------------------- 1 | package oauth_test 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "github.com/cli/oauth" 8 | ) 9 | 10 | // DetectFlow attempts to initiate OAuth Device flow with the server and falls back to OAuth Web 11 | // application flow if Device flow seems unsupported. This approach isn't strictly needed for 12 | // github.com, as its Device flow support is globally available, but it enables logging in to 13 | // self-hosted GitHub instances as well. 14 | func ExampleFlow_DetectFlow() { 15 | host, err := oauth.NewGitHubHost("https://github.com") 16 | if err != nil { 17 | panic(err) 18 | } 19 | flow := &oauth.Flow{ 20 | Host: host, 21 | ClientID: os.Getenv("OAUTH_CLIENT_ID"), 22 | ClientSecret: os.Getenv("OAUTH_CLIENT_SECRET"), // only applicable to web app flow 23 | CallbackURI: "http://127.0.0.1/callback", // only applicable to web app flow 24 | Scopes: []string{"repo", "read:org", "gist"}, 25 | } 26 | 27 | accessToken, err := flow.DetectFlow() 28 | if err != nil { 29 | panic(err) 30 | } 31 | 32 | fmt.Printf("Access token: %s\n", accessToken.Token) 33 | } 34 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/cli/oauth 2 | 3 | go 1.13 4 | 5 | require github.com/cli/browser v1.0.0 6 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/cli/browser v1.0.0 h1:RIleZgXrhdiCVgFBSjtWwkLPUCWyhhhN5k5HGSBt1js= 2 | github.com/cli/browser v1.0.0/go.mod h1:IEWkHYbLjkhtjwwWlwTHW2lGxeS5gezEQBMLTwDHf5Q= 3 | github.com/cli/safeexec v1.0.0 h1:0VngyaIyqACHdcMNWfo6+KdUYnqEr2Sg+bSP1pdF+dI= 4 | 
github.com/cli/safeexec v1.0.0/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q= 5 | -------------------------------------------------------------------------------- /oauth.go: -------------------------------------------------------------------------------- 1 | // Package oauth is a library for Go client applications that need to perform OAuth authorization 2 | // against a server, typically GitHub.com. 3 | package oauth 4 | 5 | import ( 6 | "errors" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "net/url" 11 | "strings" 12 | 13 | "github.com/cli/oauth/api" 14 | "github.com/cli/oauth/device" 15 | ) 16 | 17 | type httpClient interface { 18 | PostForm(string, url.Values) (*http.Response, error) 19 | } 20 | 21 | // Host defines the endpoints used to authorize against an OAuth server. 22 | type Host struct { 23 | DeviceCodeURL string 24 | AuthorizeURL string 25 | TokenURL string 26 | } 27 | 28 | // NewGitHubHost constructs a Host from the given URL to a GitHub instance. 29 | func NewGitHubHost(hostURL string) (*Host, error) { 30 | base, err := url.Parse(strings.TrimSpace(hostURL)) 31 | if err != nil { 32 | return nil, err 33 | } 34 | 35 | createURL := func(path string) string { 36 | u := *base // Copy base URL 37 | u.Path = path 38 | return u.String() 39 | } 40 | 41 | return &Host{ 42 | DeviceCodeURL: createURL("/login/device/code"), 43 | AuthorizeURL: createURL("/login/oauth/authorize"), 44 | TokenURL: createURL("/login/oauth/access_token"), 45 | }, nil 46 | } 47 | 48 | // GitHubHost constructs a Host from the given URL to a GitHub instance. 49 | // 50 | // Deprecated: `GitHubHost` can panic with a malformed `hostURL`. Use `NewGitHubHost` instead for graceful error handling. 
51 | func GitHubHost(hostURL string) *Host { 52 | u, _ := url.Parse(hostURL) 53 | 54 | return &Host{ 55 | DeviceCodeURL: fmt.Sprintf("%s://%s/login/device/code", u.Scheme, u.Host), 56 | AuthorizeURL: fmt.Sprintf("%s://%s/login/oauth/authorize", u.Scheme, u.Host), 57 | TokenURL: fmt.Sprintf("%s://%s/login/oauth/access_token", u.Scheme, u.Host), 58 | } 59 | } 60 | 61 | // Flow facilitates a single OAuth authorization flow. 62 | type Flow struct { 63 | // The hostname to authorize the app with. 64 | // 65 | // Deprecated: Use Host instead. 66 | Hostname string 67 | // Host configuration to authorize the app with. 68 | Host *Host 69 | // OAuth scopes to request from the user. 70 | Scopes []string 71 | // OAuth audience to request from the user. 72 | Audience string 73 | // OAuth application ID. 74 | ClientID string 75 | // OAuth application secret. Only applicable in web application flow. 76 | ClientSecret string 77 | // The localhost URI for web application flow callback, e.g. "http://127.0.0.1/callback". 78 | CallbackURI string 79 | 80 | // Display a one-time code to the user. Receives the code and the browser URL as arguments. Defaults to printing the 81 | // code to the user on Stdout with instructions to copy the code and to press Enter to continue in their browser. 82 | DisplayCode func(string, string) error 83 | // Open a web browser at a URL. Defaults to opening the default system browser. 84 | BrowseURL func(string) error 85 | // Render an HTML page to the user upon completion of web application flow. The default is to 86 | // render a simple message that informs the user they can close the browser tab and return to the app. 87 | WriteSuccessHTML func(io.Writer) 88 | 89 | // The HTTP client to use for API POST requests. Defaults to http.DefaultClient. 90 | HTTPClient httpClient 91 | // The stream to listen to keyboard input on. Defaults to os.Stdin. 92 | Stdin io.Reader 93 | // The stream to print UI messages to. Defaults to os.Stdout. 
94 | Stdout io.Writer 95 | } 96 | 97 | // DetectFlow tries to perform Device flow first and falls back to Web application flow. 98 | func (oa *Flow) DetectFlow() (*api.AccessToken, error) { 99 | accessToken, err := oa.DeviceFlow() 100 | if errors.Is(err, device.ErrUnsupported) { 101 | return oa.WebAppFlow() 102 | } 103 | return accessToken, err 104 | } 105 | -------------------------------------------------------------------------------- /oauth_device.go: -------------------------------------------------------------------------------- 1 | package oauth 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "os" 10 | 11 | "github.com/cli/browser" 12 | "github.com/cli/oauth/api" 13 | "github.com/cli/oauth/device" 14 | ) 15 | 16 | // DeviceFlow captures the full OAuth Device flow, including prompting the user to copy a one-time 17 | // code and opening their web browser, and returns an access token upon completion. 18 | func (oa *Flow) DeviceFlow() (*api.AccessToken, error) { 19 | httpClient := oa.HTTPClient 20 | if httpClient == nil { 21 | httpClient = http.DefaultClient 22 | } 23 | 24 | stdin := oa.Stdin 25 | if stdin == nil { 26 | stdin = os.Stdin 27 | } 28 | stdout := oa.Stdout 29 | if stdout == nil { 30 | stdout = os.Stdout 31 | } 32 | 33 | host := oa.Host 34 | if host == nil { 35 | parsedHost, err := NewGitHubHost("https://" + oa.Hostname) 36 | if err != nil { 37 | return nil, fmt.Errorf("error parsing the hostname '%s': %w", oa.Hostname, err) 38 | } 39 | host = parsedHost 40 | } 41 | 42 | code, err := device.RequestCode(httpClient, host.DeviceCodeURL, 43 | oa.ClientID, oa.Scopes, device.WithAudience(oa.Audience)) 44 | if err != nil { 45 | return nil, err 46 | } 47 | 48 | if oa.DisplayCode == nil { 49 | fmt.Fprintf(stdout, "First, copy your one-time code: %s\n", code.UserCode) 50 | fmt.Fprint(stdout, "Then press [Enter] to continue in the web browser... 
") 51 | _ = waitForEnter(stdin) 52 | } else { 53 | err := oa.DisplayCode(code.UserCode, code.VerificationURI) 54 | if err != nil { 55 | return nil, err 56 | } 57 | } 58 | 59 | browseURL := oa.BrowseURL 60 | if browseURL == nil { 61 | browseURL = browser.OpenURL 62 | } 63 | 64 | if err = browseURL(code.VerificationURI); err != nil { 65 | return nil, fmt.Errorf("error opening the web browser: %w", err) 66 | } 67 | 68 | return device.Wait(context.TODO(), httpClient, host.TokenURL, device.WaitOptions{ 69 | ClientID: oa.ClientID, 70 | DeviceCode: code, 71 | }) 72 | } 73 | 74 | func waitForEnter(r io.Reader) error { 75 | scanner := bufio.NewScanner(r) 76 | scanner.Scan() 77 | return scanner.Err() 78 | } 79 | -------------------------------------------------------------------------------- /oauth_webapp.go: -------------------------------------------------------------------------------- 1 | package oauth 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | 8 | "github.com/cli/browser" 9 | "github.com/cli/oauth/api" 10 | "github.com/cli/oauth/webapp" 11 | ) 12 | 13 | // WebAppFlow starts a local HTTP server, opens the web browser to initiate the OAuth Web application 14 | // flow, blocks until the user completes authorization and is redirected back, and returns the access token. 
15 | func (oa *Flow) WebAppFlow() (*api.AccessToken, error) { 16 | host := oa.Host 17 | 18 | if host == nil { 19 | parsedHost, err := NewGitHubHost("https://" + oa.Hostname) 20 | if err != nil { 21 | return nil, fmt.Errorf("error parsing the hostname '%s': %w", oa.Hostname, err) 22 | } 23 | host = parsedHost 24 | } 25 | 26 | flow, err := webapp.InitFlow() 27 | if err != nil { 28 | return nil, err 29 | } 30 | 31 | params := webapp.BrowserParams{ 32 | ClientID: oa.ClientID, 33 | RedirectURI: oa.CallbackURI, 34 | Scopes: oa.Scopes, 35 | Audience: oa.Audience, 36 | AllowSignup: true, 37 | } 38 | browserURL, err := flow.BrowserURL(host.AuthorizeURL, params) 39 | if err != nil { 40 | return nil, err 41 | } 42 | 43 | go func() { 44 | _ = flow.StartServer(oa.WriteSuccessHTML) 45 | }() 46 | 47 | browseURL := oa.BrowseURL 48 | if browseURL == nil { 49 | browseURL = browser.OpenURL 50 | } 51 | 52 | err = browseURL(browserURL) 53 | if err != nil { 54 | return nil, fmt.Errorf("error opening the web browser: %w", err) 55 | } 56 | 57 | httpClient := oa.HTTPClient 58 | if httpClient == nil { 59 | httpClient = http.DefaultClient 60 | } 61 | 62 | return flow.Wait(context.TODO(), httpClient, host.TokenURL, webapp.WaitOptions{ 63 | ClientSecret: oa.ClientSecret, 64 | }) 65 | } 66 | -------------------------------------------------------------------------------- /webapp/examples_test.go: -------------------------------------------------------------------------------- 1 | package webapp_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | "os" 8 | 9 | "github.com/cli/browser" 10 | "github.com/cli/oauth/webapp" 11 | ) 12 | 13 | // Initiate the OAuth App Authorization Flow for GitHub.com. 
14 | func ExampleInitFlow() { 15 | clientID := os.Getenv("OAUTH_CLIENT_ID") 16 | clientSecret := os.Getenv("OAUTH_CLIENT_SECRET") 17 | callbackURL := "http://127.0.0.1/callback" 18 | 19 | flow, err := webapp.InitFlow() 20 | if err != nil { 21 | panic(err) 22 | } 23 | 24 | params := webapp.BrowserParams{ 25 | ClientID: clientID, 26 | RedirectURI: callbackURL, 27 | Scopes: []string{"repo", "read:org"}, 28 | AllowSignup: true, 29 | } 30 | browserURL, err := flow.BrowserURL("https://github.com/login/oauth/authorize", params) 31 | if err != nil { 32 | panic(err) 33 | } 34 | 35 | // A localhost server on a random available port will receive the web redirect. 36 | go func() { 37 | _ = flow.StartServer(nil) 38 | }() 39 | 40 | // Note: the user's web browser must run on the same device as the running app. 41 | err = browser.OpenURL(browserURL) 42 | if err != nil { 43 | panic(err) 44 | } 45 | 46 | httpClient := http.DefaultClient 47 | accessToken, err := flow.Wait(context.TODO(), httpClient, "https://github.com/login/oauth/access_token", webapp.WaitOptions{ 48 | ClientSecret: clientSecret, 49 | }) 50 | if err != nil { 51 | panic(err) 52 | } 53 | 54 | fmt.Printf("Access token: %s\n", accessToken.Token) 55 | } 56 | -------------------------------------------------------------------------------- /webapp/local_server.go: -------------------------------------------------------------------------------- 1 | package webapp 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "net" 8 | "net/http" 9 | ) 10 | 11 | // CodeResponse represents the code received by the local server's callback handler. 12 | type CodeResponse struct { 13 | Code string 14 | State string 15 | } 16 | 17 | // bindLocalServer initializes a LocalServer that will listen on a randomly available TCP port. 
18 | func bindLocalServer() (*localServer, error) { 19 | listener, err := net.Listen("tcp4", "127.0.0.1:0") 20 | if err != nil { 21 | return nil, err 22 | } 23 | 24 | return &localServer{ 25 | listener: listener, 26 | resultChan: make(chan CodeResponse, 1), 27 | }, nil 28 | } 29 | 30 | type localServer struct { 31 | CallbackPath string 32 | WriteSuccessHTML func(w io.Writer) 33 | 34 | resultChan chan (CodeResponse) 35 | listener net.Listener 36 | } 37 | 38 | func (s *localServer) Port() int { 39 | return s.listener.Addr().(*net.TCPAddr).Port 40 | } 41 | 42 | func (s *localServer) Close() error { 43 | return s.listener.Close() 44 | } 45 | 46 | func (s *localServer) Serve() error { 47 | return http.Serve(s.listener, s) 48 | } 49 | 50 | func (s *localServer) WaitForCode(ctx context.Context) (CodeResponse, error) { 51 | select { 52 | case <-ctx.Done(): 53 | return CodeResponse{}, ctx.Err() 54 | case code := <-s.resultChan: 55 | return code, nil 56 | } 57 | } 58 | 59 | // ServeHTTP implements http.Handler. 60 | func (s *localServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { 61 | if s.CallbackPath != "" && r.URL.Path != s.CallbackPath { 62 | w.WriteHeader(404) 63 | return 64 | } 65 | defer func() { 66 | _ = s.Close() 67 | }() 68 | 69 | params := r.URL.Query() 70 | s.resultChan <- CodeResponse{ 71 | Code: params.Get("code"), 72 | State: params.Get("state"), 73 | } 74 | 75 | w.Header().Add("content-type", "text/html") 76 | if s.WriteSuccessHTML != nil { 77 | s.WriteSuccessHTML(w) 78 | } else { 79 | defaultSuccessHTML(w) 80 | } 81 | } 82 | 83 | func defaultSuccessHTML(w io.Writer) { 84 | fmt.Fprintf(w, "

You may now close this page and return to the client app.

") 85 | } 86 | -------------------------------------------------------------------------------- /webapp/local_server_test.go: -------------------------------------------------------------------------------- 1 | package webapp 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "net" 7 | "net/http" 8 | "testing" 9 | ) 10 | 11 | type fakeListener struct { 12 | closed bool 13 | addr *net.TCPAddr 14 | } 15 | 16 | func (l *fakeListener) Accept() (net.Conn, error) { 17 | return nil, errors.New("not implemented") 18 | } 19 | func (l *fakeListener) Close() error { 20 | l.closed = true 21 | return nil 22 | } 23 | func (l *fakeListener) Addr() net.Addr { 24 | return l.addr 25 | } 26 | 27 | type responseWriter struct { 28 | header http.Header 29 | written bytes.Buffer 30 | status int 31 | } 32 | 33 | func (w *responseWriter) Header() http.Header { 34 | if w.header == nil { 35 | w.header = make(http.Header) 36 | } 37 | return w.header 38 | } 39 | func (w *responseWriter) Write(b []byte) (int, error) { 40 | if w.status == 0 { 41 | w.status = 200 42 | } 43 | return w.written.Write(b) 44 | } 45 | func (w *responseWriter) WriteHeader(s int) { 46 | w.status = s 47 | } 48 | 49 | func Test_localServer_ServeHTTP(t *testing.T) { 50 | listener := &fakeListener{} 51 | s := &localServer{ 52 | CallbackPath: "/hello", 53 | resultChan: make(chan CodeResponse, 1), 54 | listener: listener, 55 | } 56 | 57 | w1 := &responseWriter{} 58 | w2 := &responseWriter{} 59 | 60 | serveChan := make(chan struct{}) 61 | go func() { 62 | req1, _ := http.NewRequest("GET", "http://127.0.0.1:12345/favicon.ico", nil) 63 | s.ServeHTTP(w1, req1) 64 | req2, _ := http.NewRequest("GET", "http://127.0.0.1:12345/hello?code=ABC-123&state=xy%2Fz", nil) 65 | s.ServeHTTP(w2, req2) 66 | serveChan <- struct{}{} 67 | }() 68 | 69 | res := <-s.resultChan 70 | if res.Code != "ABC-123" { 71 | t.Errorf("got code %q", res.Code) 72 | } 73 | if res.State != "xy/z" { 74 | t.Errorf("got state %q", res.State) 75 | } 76 | 77 | <-serveChan 78 | 
if w1.status != 404 { 79 | t.Errorf("status = %d", w2.status) 80 | } 81 | 82 | if w2.status != 200 { 83 | t.Errorf("status = %d", w2.status) 84 | } 85 | if w2.written.String() != "

You may now close this page and return to the client app.

" { 86 | t.Errorf("written: %q", w2.written.String()) 87 | } 88 | if w2.Header().Get("Content-Type") != "text/html" { 89 | t.Errorf("Content-Type: %v", w2.Header().Get("Content-Type")) 90 | } 91 | if !listener.closed { 92 | t.Error("expected listener to be closed") 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /webapp/webapp_flow.go: -------------------------------------------------------------------------------- 1 | // Package webapp implements the OAuth Web Application authorization flow for client applications by 2 | // starting a server at localhost to receive the web redirect after the user has authorized the application. 3 | package webapp 4 | 5 | import ( 6 | "context" 7 | "crypto/rand" 8 | "encoding/hex" 9 | "errors" 10 | "fmt" 11 | "io" 12 | "net/http" 13 | "net/url" 14 | "strings" 15 | 16 | "github.com/cli/oauth/api" 17 | ) 18 | 19 | type httpClient interface { 20 | PostForm(string, url.Values) (*http.Response, error) 21 | } 22 | 23 | // Flow holds the state for the steps of OAuth Web Application flow. 24 | type Flow struct { 25 | server *localServer 26 | clientID string 27 | state string 28 | } 29 | 30 | // InitFlow creates a new Flow instance by detecting a locally available port number. 31 | func InitFlow() (*Flow, error) { 32 | server, err := bindLocalServer() 33 | if err != nil { 34 | return nil, err 35 | } 36 | 37 | state, _ := randomString(20) 38 | 39 | return &Flow{ 40 | server: server, 41 | state: state, 42 | }, nil 43 | } 44 | 45 | // BrowserParams are GET query parameters for initiating the web flow. 46 | type BrowserParams struct { 47 | ClientID string 48 | RedirectURI string 49 | Scopes []string 50 | Audience string 51 | LoginHandle string 52 | AllowSignup bool 53 | } 54 | 55 | // BrowserURL appends GET query parameters to baseURL and returns the url that the user should 56 | // navigate to in their web browser. 
57 | func (flow *Flow) BrowserURL(baseURL string, params BrowserParams) (string, error) { 58 | ru, err := url.Parse(params.RedirectURI) 59 | if err != nil { 60 | return "", err 61 | } 62 | 63 | ru.Host = fmt.Sprintf("%s:%d", ru.Hostname(), flow.server.Port()) 64 | flow.server.CallbackPath = ru.Path 65 | flow.clientID = params.ClientID 66 | 67 | q := url.Values{} 68 | q.Set("client_id", params.ClientID) 69 | q.Set("redirect_uri", ru.String()) 70 | q.Set("scope", strings.Join(params.Scopes, " ")) 71 | q.Set("state", flow.state) 72 | 73 | if params.Audience != "" { 74 | q.Set("audience", params.Audience) 75 | } 76 | if params.LoginHandle != "" { 77 | q.Set("login", params.LoginHandle) 78 | } 79 | if !params.AllowSignup { 80 | q.Set("allow_signup", "false") 81 | } 82 | 83 | return fmt.Sprintf("%s?%s", baseURL, q.Encode()), nil 84 | } 85 | 86 | // StartServer starts the localhost server and blocks until it has received the web redirect. The 87 | // writeSuccess function can be used to render a HTML page to the user upon completion. 88 | func (flow *Flow) StartServer(writeSuccess func(io.Writer)) error { 89 | flow.server.WriteSuccessHTML = writeSuccess 90 | return flow.server.Serve() 91 | } 92 | 93 | // AccessToken blocks until the browser flow has completed and returns the access token. 94 | // 95 | // Deprecated: use Wait. 96 | func (flow *Flow) AccessToken(c httpClient, tokenURL, clientSecret string) (*api.AccessToken, error) { 97 | return flow.Wait(context.Background(), c, tokenURL, WaitOptions{ClientSecret: clientSecret}) 98 | } 99 | 100 | // WaitOptions specifies parameters to exchange the access token for. 101 | type WaitOptions struct { 102 | // ClientSecret is the app client secret value. 103 | ClientSecret string 104 | } 105 | 106 | // Wait blocks until the browser flow has completed and returns the access token. 
107 | func (flow *Flow) Wait(ctx context.Context, c httpClient, tokenURL string, opts WaitOptions) (*api.AccessToken, error) { 108 | code, err := flow.server.WaitForCode(ctx) 109 | if err != nil { 110 | return nil, err 111 | } 112 | if code.State != flow.state { 113 | return nil, errors.New("state mismatch") 114 | } 115 | 116 | resp, err := api.PostForm(c, tokenURL, 117 | url.Values{ 118 | "client_id": {flow.clientID}, 119 | "client_secret": {opts.ClientSecret}, 120 | "code": {code.Code}, 121 | "state": {flow.state}, 122 | }) 123 | if err != nil { 124 | return nil, err 125 | } 126 | 127 | return resp.AccessToken() 128 | } 129 | 130 | func randomString(length int) (string, error) { 131 | b := make([]byte, length/2) 132 | _, err := rand.Read(b) 133 | if err != nil { 134 | return "", err 135 | } 136 | return hex.EncodeToString(b), nil 137 | } 138 | -------------------------------------------------------------------------------- /webapp/webapp_flow_test.go: -------------------------------------------------------------------------------- 1 | package webapp 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "io/ioutil" 7 | "net" 8 | "net/http" 9 | "net/url" 10 | "testing" 11 | ) 12 | 13 | func TestFlow_BrowserURL(t *testing.T) { 14 | server := &localServer{ 15 | listener: &fakeListener{ 16 | addr: &net.TCPAddr{Port: 12345}, 17 | }, 18 | } 19 | 20 | type fields struct { 21 | server *localServer 22 | clientID string 23 | state string 24 | } 25 | type args struct { 26 | baseURL string 27 | params BrowserParams 28 | } 29 | tests := []struct { 30 | name string 31 | fields fields 32 | args args 33 | want string 34 | wantErr bool 35 | }{ 36 | { 37 | name: "happy path", 38 | fields: fields{ 39 | server: server, 40 | state: "xy/z", 41 | }, 42 | args: args{ 43 | baseURL: "https://github.com/authorize", 44 | params: BrowserParams{ 45 | ClientID: "CLIENT-ID", 46 | RedirectURI: "http://127.0.0.1/hello", 47 | Scopes: []string{"repo", "read:org"}, 48 | AllowSignup: true, 49 | }, 50 | }, 51 
| want: "https://github.com/authorize?client_id=CLIENT-ID&redirect_uri=http%3A%2F%2F127.0.0.1%3A12345%2Fhello&scope=repo+read%3Aorg&state=xy%2Fz", 52 | wantErr: false, 53 | }, 54 | { 55 | name: "happy path with audience", 56 | fields: fields{ 57 | server: server, 58 | state: "xy/z", 59 | }, 60 | args: args{ 61 | baseURL: "https://github.com/authorize", 62 | params: BrowserParams{ 63 | ClientID: "CLIENT-ID", 64 | RedirectURI: "http://127.0.0.1/hello", 65 | Scopes: []string{"repo", "read:org"}, 66 | AllowSignup: true, 67 | Audience: "https://api.github.com", 68 | }, 69 | }, 70 | want: "https://github.com/authorize?audience=https%3A%2F%2Fapi.github.com&client_id=CLIENT-ID&redirect_uri=http%3A%2F%2F127.0.0.1%3A12345%2Fhello&scope=repo+read%3Aorg&state=xy%2Fz", 71 | }, 72 | } 73 | for _, tt := range tests { 74 | t.Run(tt.name, func(t *testing.T) { 75 | flow := &Flow{ 76 | server: tt.fields.server, 77 | clientID: tt.fields.clientID, 78 | state: tt.fields.state, 79 | } 80 | got, err := flow.BrowserURL(tt.args.baseURL, tt.args.params) 81 | if (err != nil) != tt.wantErr { 82 | t.Errorf("Flow.BrowserURL() error = %v, wantErr %v", err, tt.wantErr) 83 | return 84 | } 85 | if got != tt.want { 86 | t.Errorf("Flow.BrowserURL() = %v, want %v", got, tt.want) 87 | } 88 | }) 89 | } 90 | } 91 | 92 | type apiStub struct { 93 | status int 94 | body string 95 | contentType string 96 | } 97 | 98 | type postArgs struct { 99 | url string 100 | params url.Values 101 | } 102 | 103 | type apiClient struct { 104 | stubs []apiStub 105 | calls []postArgs 106 | 107 | postCount int 108 | } 109 | 110 | func (c *apiClient) PostForm(u string, params url.Values) (*http.Response, error) { 111 | stub := c.stubs[c.postCount] 112 | c.calls = append(c.calls, postArgs{url: u, params: params}) 113 | c.postCount++ 114 | return &http.Response{ 115 | Body: ioutil.NopCloser(bytes.NewBufferString(stub.body)), 116 | Header: http.Header{ 117 | "Content-Type": {stub.contentType}, 118 | }, 119 | StatusCode: 
stub.status, 120 | }, nil 121 | } 122 | 123 | func TestFlow_AccessToken(t *testing.T) { 124 | server := &localServer{ 125 | listener: &fakeListener{ 126 | addr: &net.TCPAddr{Port: 12345}, 127 | }, 128 | resultChan: make(chan CodeResponse), 129 | } 130 | 131 | flow := Flow{ 132 | server: server, 133 | clientID: "CLIENT-ID", 134 | state: "xy/z", 135 | } 136 | 137 | client := &apiClient{ 138 | stubs: []apiStub{ 139 | { 140 | body: "access_token=ATOKEN&token_type=bearer&scope=repo+gist", 141 | status: 200, 142 | contentType: "application/x-www-form-urlencoded; charset=utf-8", 143 | }, 144 | }, 145 | } 146 | 147 | go func() { 148 | server.resultChan <- CodeResponse{ 149 | Code: "ABC-123", 150 | State: "xy/z", 151 | } 152 | }() 153 | 154 | token, err := flow.Wait(context.Background(), client, "https://github.com/access_token", WaitOptions{ClientSecret: "OAUTH-SEKRIT"}) 155 | if err != nil { 156 | t.Fatalf("AccessToken() error: %v", err) 157 | } 158 | 159 | if len(client.calls) != 1 { 160 | t.Fatalf("expected 1 HTTP POST, got %d", len(client.calls)) 161 | } 162 | apiPost := client.calls[0] 163 | if apiPost.url != "https://github.com/access_token" { 164 | t.Errorf("HTTP POST to %q", apiPost.url) 165 | } 166 | if params := apiPost.params.Encode(); params != "client_id=CLIENT-ID&client_secret=OAUTH-SEKRIT&code=ABC-123&state=xy%2Fz" { 167 | t.Errorf("HTTP POST params: %v", params) 168 | } 169 | 170 | if token.Token != "ATOKEN" { 171 | t.Errorf("Token = %q", token.Token) 172 | } 173 | } 174 | --------------------------------------------------------------------------------