├── .gitignore ├── .travis.yml ├── Gopkg.lock ├── Gopkg.toml ├── LICENSE ├── README.md ├── api ├── api.go └── api_test.go ├── contrib ├── oauth2_proxy.cfg.example └── oauth2_proxy.service.example ├── cookie ├── cookies.go ├── cookies_test.go └── nonce.go ├── dist.sh ├── env_options.go ├── env_options_test.go ├── htpasswd.go ├── htpasswd_test.go ├── http.go ├── logging_handler.go ├── logging_handler_test.go ├── main.go ├── oauthproxy.go ├── oauthproxy_test.go ├── options.go ├── options_test.go ├── providers ├── azure.go ├── azure_test.go ├── facebook.go ├── github.go ├── github_test.go ├── gitlab.go ├── gitlab_test.go ├── google.go ├── google_test.go ├── internal_util.go ├── internal_util_test.go ├── linkedin.go ├── linkedin_test.go ├── oidc.go ├── provider_data.go ├── provider_default.go ├── provider_default_test.go ├── providers.go ├── session_state.go └── session_state_test.go ├── string_array.go ├── templates.go ├── templates_test.go ├── test.sh ├── validator.go ├── validator_test.go ├── validator_watcher_copy_test.go ├── validator_watcher_test.go ├── version.go ├── watcher.go └── watcher_unsupported.go /.gitignore: -------------------------------------------------------------------------------- 1 | oauth2_proxy 2 | vendor 3 | dist 4 | .godeps 5 | *.exe 6 | 7 | 8 | # Go.gitignore 9 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 10 | *.o 11 | *.a 12 | *.so 13 | 14 | # Folders 15 | _obj 16 | _test 17 | 18 | # Architecture specific extensions/prefixes 19 | *.[568vq] 20 | [568vq].out 21 | 22 | *.cgo1.go 23 | *.cgo2.c 24 | _cgo_defun.c 25 | _cgo_gotypes.go 26 | _cgo_export.* 27 | 28 | _testmain.go 29 | 30 | # Editor swap/temp files 31 | .*.swp 32 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 3 | - 1.8.x 4 | - 1.9.x 5 | script: 6 | - wget -O dep 
https://github.com/golang/dep/releases/download/v0.3.2/dep-linux-amd64 7 | - chmod +x dep 8 | - ./dep ensure 9 | - ./test.sh 10 | sudo: false 11 | notifications: 12 | email: false 13 | -------------------------------------------------------------------------------- /Gopkg.lock: -------------------------------------------------------------------------------- 1 | # This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. 2 | 3 | 4 | [[projects]] 5 | name = "cloud.google.com/go" 6 | packages = ["compute/metadata"] 7 | revision = "2d3a6656c17a60b0815b7e06ab0be04eacb6e613" 8 | version = "v0.16.0" 9 | 10 | [[projects]] 11 | name = "github.com/BurntSushi/toml" 12 | packages = ["."] 13 | revision = "b26d9c308763d68093482582cea63d69be07a0f0" 14 | version = "v0.3.0" 15 | 16 | [[projects]] 17 | name = "github.com/bitly/go-simplejson" 18 | packages = ["."] 19 | revision = "aabad6e819789e569bd6aabf444c935aa9ba1e44" 20 | version = "v0.5.0" 21 | 22 | [[projects]] 23 | branch = "v2" 24 | name = "github.com/coreos/go-oidc" 25 | packages = ["."] 26 | revision = "77e7f2010a464ade7338597afe650dfcffbe2ca8" 27 | 28 | [[projects]] 29 | name = "github.com/davecgh/go-spew" 30 | packages = ["spew"] 31 | revision = "346938d642f2ec3594ed81d874461961cd0faa76" 32 | version = "v1.1.0" 33 | 34 | [[projects]] 35 | branch = "master" 36 | name = "github.com/golang/protobuf" 37 | packages = ["proto"] 38 | revision = "1e59b77b52bf8e4b449a57e6f79f21226d571845" 39 | 40 | [[projects]] 41 | name = "github.com/mbland/hmacauth" 42 | packages = ["."] 43 | revision = "107c17adcc5eccc9935cd67d9bc2feaf5255d2cb" 44 | version = "1.0.2" 45 | 46 | [[projects]] 47 | branch = "master" 48 | name = "github.com/mreiferson/go-options" 49 | packages = ["."] 50 | revision = "77551d20752b54535462404ad9d877ebdb26e53d" 51 | 52 | [[projects]] 53 | name = "github.com/pmezard/go-difflib" 54 | packages = ["difflib"] 55 | revision = "792786c7400a136282c1664665ae0a8db921c6c2" 56 | version = "v1.0.0" 
57 | 58 | [[projects]] 59 | branch = "master" 60 | name = "github.com/pquerna/cachecontrol" 61 | packages = [ 62 | ".", 63 | "cacheobject" 64 | ] 65 | revision = "0dec1b30a0215bb68605dfc568e8855066c9202d" 66 | 67 | [[projects]] 68 | name = "github.com/stretchr/testify" 69 | packages = ["assert"] 70 | revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0" 71 | version = "v1.1.4" 72 | 73 | [[projects]] 74 | branch = "master" 75 | name = "golang.org/x/crypto" 76 | packages = [ 77 | "bcrypt", 78 | "blowfish", 79 | "ed25519", 80 | "ed25519/internal/edwards25519" 81 | ] 82 | revision = "9f005a07e0d31d45e6656d241bb5c0f2efd4bc94" 83 | 84 | [[projects]] 85 | branch = "master" 86 | name = "golang.org/x/net" 87 | packages = [ 88 | "context", 89 | "context/ctxhttp" 90 | ] 91 | revision = "9dfe39835686865bff950a07b394c12a98ddc811" 92 | 93 | [[projects]] 94 | branch = "master" 95 | name = "golang.org/x/oauth2" 96 | packages = [ 97 | ".", 98 | "google", 99 | "internal", 100 | "jws", 101 | "jwt" 102 | ] 103 | revision = "9ff8ebcc8e241d46f52ecc5bff0e5a2f2dbef402" 104 | 105 | [[projects]] 106 | branch = "master" 107 | name = "google.golang.org/api" 108 | packages = [ 109 | "admin/directory/v1", 110 | "gensupport", 111 | "googleapi", 112 | "googleapi/internal/uritemplates" 113 | ] 114 | revision = "8791354e7ab150705ede13637a18c1fcc16b62e8" 115 | 116 | [[projects]] 117 | name = "google.golang.org/appengine" 118 | packages = [ 119 | ".", 120 | "internal", 121 | "internal/app_identity", 122 | "internal/base", 123 | "internal/datastore", 124 | "internal/log", 125 | "internal/modules", 126 | "internal/remote_api", 127 | "internal/urlfetch", 128 | "urlfetch" 129 | ] 130 | revision = "150dc57a1b433e64154302bdc40b6bb8aefa313a" 131 | version = "v1.0.0" 132 | 133 | [[projects]] 134 | name = "gopkg.in/fsnotify.v1" 135 | packages = ["."] 136 | revision = "836bfd95fecc0f1511dd66bdbf2b5b61ab8b00b6" 137 | version = "v1.2.11" 138 | 139 | [[projects]] 140 | name = "gopkg.in/square/go-jose.v2" 141 | 
packages = [ 142 | ".", 143 | "cipher", 144 | "json" 145 | ] 146 | revision = "f8f38de21b4dcd69d0413faf231983f5fd6634b1" 147 | version = "v2.1.3" 148 | 149 | [solve-meta] 150 | analyzer-name = "dep" 151 | analyzer-version = 1 152 | inputs-digest = "b502c41a61115d14d6379be26b0300f65d173bdad852f0170d387ebf2d7ec173" 153 | solver-name = "gps-cdcl" 154 | solver-version = 1 155 | -------------------------------------------------------------------------------- /Gopkg.toml: -------------------------------------------------------------------------------- 1 | 2 | # Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md 3 | # for detailed Gopkg.toml documentation. 4 | # 5 | 6 | [[constraint]] 7 | name = "github.com/18F/hmacauth" 8 | version = "~1.0.1" 9 | 10 | [[constraint]] 11 | name = "github.com/BurntSushi/toml" 12 | version = "~0.3.0" 13 | 14 | [[constraint]] 15 | name = "github.com/bitly/go-simplejson" 16 | version = "~0.5.0" 17 | 18 | [[constraint]] 19 | branch = "v2" 20 | name = "github.com/coreos/go-oidc" 21 | 22 | [[constraint]] 23 | branch = "master" 24 | name = "github.com/mreiferson/go-options" 25 | 26 | [[constraint]] 27 | name = "github.com/stretchr/testify" 28 | version = "~1.1.4" 29 | 30 | [[constraint]] 31 | branch = "master" 32 | name = "golang.org/x/oauth2" 33 | 34 | [[constraint]] 35 | branch = "master" 36 | name = "google.golang.org/api" 37 | 38 | [[constraint]] 39 | name = "gopkg.in/fsnotify.v1" 40 | version = "~1.2.0" 41 | 42 | [[constraint]] 43 | branch = "master" 44 | name = "golang.org/x/crypto" 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any person obtaining a copy 2 | of this software and associated documentation files (the "Software"), to deal 3 | in the Software without restriction, including without limitation the rights 4 | to use, copy, modify, 
merge, publish, distribute, sublicense, and/or sell 5 | copies of the Software, and to permit persons to whom the Software is 6 | furnished to do so, subject to the following conditions: 7 | 8 | The above copyright notice and this permission notice shall be included in 9 | all copies or substantial portions of the Software. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 13 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 14 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 15 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 16 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 17 | THE SOFTWARE. 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | oauth2_proxy 2 | ================= 3 | 4 | A reverse proxy and static file server that provides authentication using Providers (Google, GitHub, and others) 5 | to validate accounts by email, domain or group. 6 | 7 | [![Build Status](https://secure.travis-ci.org/bitly/oauth2_proxy.svg?branch=master)](http://travis-ci.org/bitly/oauth2_proxy) 8 | 9 | 10 | ![Sign In Page](https://cloud.githubusercontent.com/assets/45028/4970624/7feb7dd8-6886-11e4-93e0-c9904af44ea8.png) 11 | 12 | **NOTICE**: This project was officially archived by Bitly at the end of September 2018. 13 | Bitly will no longer be accepting PRs or helping on issues. 14 | There has been a [discussion](https://github.com/bitly/oauth2_proxy/issues/628) 15 | to find a new home for the project which has led to the following notable forks: 16 | 17 | - [pomerium](https://github.com/pomerium/pomerium) an identity-access proxy, inspired by BeyondCorp. 
18 | - [buzzfeed/sso](https://github.com/buzzfeed/sso) a "double OAuth2" flow, where sso-auth is the OAuth2 provider for sso-proxy and Google is the OAuth2 provider for sso-auth. 19 | - [openshift/oauth_proxy](https://github.com/openshift/oauth-proxy) an OpenShift-specific version of this project. 20 | - [pusher/oauth2_proxy](https://github.com/pusher/oauth2_proxy) the official hard fork of this project. 21 | 22 | Please submit all future PRs and issues to [pusher/oauth2_proxy](https://github.com/pusher/oauth2_proxy). 23 | 24 | ## Architecture 25 | 26 | ![OAuth2 Proxy Architecture](https://cloud.githubusercontent.com/assets/45028/8027702/bd040b7a-0d6a-11e5-85b9-f8d953d04f39.png) 27 | 28 | ## Installation 29 | 30 | 1. Download [Prebuilt Binary](https://github.com/bitly/oauth2_proxy/releases) (current release is `v2.2`) or build with `$ go get github.com/bitly/oauth2_proxy` which will put the binary in `$GOPATH/bin` 31 | Prebuilt binaries can be validated by extracting the file and verifying it against the `sha256sum.txt` checksum file provided for each release starting with version `v2.3`. 32 | ``` 33 | sha256sum -c sha256sum.txt 2>&1 | grep OK 34 | oauth2_proxy-2.3.linux-amd64: OK 35 | ``` 36 | 2. Select a Provider and Register an OAuth Application with a Provider 37 | 3. Configure OAuth2 Proxy using config file, command line options, or environment variables 38 | 4. Configure SSL or Deploy behind an SSL endpoint (example provided for Nginx) 39 | 40 | ## OAuth Provider Configuration 41 | 42 | You will need to register an OAuth application with a Provider (Google, GitHub or another provider), and configure it with Redirect URI(s) for the domain you intend to run `oauth2_proxy` on.
43 | 44 | Valid providers are: 45 | 46 | * [Google](#google-auth-provider) *default* 47 | * [Azure](#azure-auth-provider) 48 | * [Facebook](#facebook-auth-provider) 49 | * [GitHub](#github-auth-provider) 50 | * [GitLab](#gitlab-auth-provider) 51 | * [LinkedIn](#linkedin-auth-provider) 52 | 53 | The provider can be selected using the `provider` configuration value. 54 | 55 | ### Google Auth Provider 56 | 57 | For Google, the registration steps are: 58 | 59 | 1. Create a new project: https://console.developers.google.com/project 60 | 2. Choose the new project from the top right project dropdown (only if another project is selected) 61 | 3. In the project Dashboard center pane, choose **"API Manager"** 62 | 4. In the left Nav pane, choose **"Credentials"** 63 | 5. In the center pane, choose **"OAuth consent screen"** tab. Fill in **"Product name shown to users"** and hit save. 64 | 6. In the center pane, choose **"Credentials"** tab. 65 | * Open the **"New credentials"** drop down 66 | * Choose **"OAuth client ID"** 67 | * Choose **"Web application"** 68 | * Application name is free-form; choose something appropriate 69 | * Authorized JavaScript origins is your domain, e.g. `https://internal.yourcompany.com` 70 | * Authorized redirect URIs is the location of oauth2/callback, e.g. `https://internal.yourcompany.com/oauth2/callback` 71 | * Choose **"Create"** 72 | 7. Take note of the **Client ID** and **Client Secret** 73 | 74 | It's recommended to refresh sessions on a short interval (1h) with the `cookie-refresh` setting, which validates that the account is still authorized. 75 | 76 | #### Restrict auth to specific Google groups on your domain. (optional) 77 | 78 | 1. Create a service account: https://developers.google.com/identity/protocols/OAuth2ServiceAccount and make sure to download the json file. 79 | 2. Make note of the Client ID for a future step. 80 | 3. Under "APIs & Auth", choose APIs. 81 | 4. Click on Admin SDK and then Enable API. 82 | 5.
Follow the steps on https://developers.google.com/admin-sdk/directory/v1/guides/delegation#delegate_domain-wide_authority_to_your_service_account and give the client id from step 2 the following oauth scopes: 83 | ``` 84 | https://www.googleapis.com/auth/admin.directory.group.readonly 85 | https://www.googleapis.com/auth/admin.directory.user.readonly 86 | ``` 87 | 6. Follow the steps on https://support.google.com/a/answer/60757 to enable Admin API access. 88 | 7. Create or choose an existing administrative email address on the Gmail domain to assign to the ```google-admin-email``` flag. This email will be impersonated by this client to make calls to the Admin SDK. See the note on the link from step 5 for the reason why. 89 | 8. Create or choose an existing email group and set that email to the ```google-group``` flag. You can pass multiple instances of this flag with different groups 90 | and the user will be checked against all the provided groups. 91 | 9. Lock down the permissions on the json file downloaded from step 1 so only oauth2_proxy is able to read the file and set the path to the file in the ```google-service-account-json``` flag. 92 | 10. Restart oauth2_proxy. 93 | 94 | Note: The user is checked against the group members list on initial authentication and every time the token is refreshed (about once an hour). 95 | 96 | ### Azure Auth Provider 97 | 98 | 1. [Add an application](https://azure.microsoft.com/en-us/documentation/articles/active-directory-integrating-applications/) to your Azure Active Directory tenant. 99 | 2. On the App properties page provide the correct Sign-On URL, i.e. `https://internal.yourcompany.com/oauth2/callback` 100 | 3. If applicable, take note of your `TenantID` and provide it via the `--azure-tenant=<YOUR TENANT ID>` command-line option. By default, the `common` tenant is used. 101 | 102 | The Azure AD auth provider uses `openid` as its default scope. It uses `https://graph.windows.net` as a default protected resource.
It calls `https://graph.windows.net/me` to get the email address of the user that logs in. 103 | 104 | 105 | ### Facebook Auth Provider 106 | 107 | 1. Create a new FB App from <https://developers.facebook.com/> 108 | 2. Under FB Login, set your Valid OAuth redirect URIs to `https://internal.yourcompany.com/oauth2/callback` 109 | 110 | ### GitHub Auth Provider 111 | 112 | 1. Create a new project: https://github.com/settings/developers 113 | 2. Under `Authorization callback URL` enter the correct URL, i.e. `https://internal.yourcompany.com/oauth2/callback` 114 | 115 | The GitHub auth provider supports two additional parameters to restrict authentication to Organization or Team level access. Restricting by org and team is normally accompanied with `--email-domain=*` 116 | 117 | -github-org="": restrict logins to members of this organisation 118 | -github-team="": restrict logins to members of any of these teams (slug), separated by a comma 119 | 120 | If you are using GitHub Enterprise, make sure you set the following to the appropriate URL: 121 | 122 | -login-url="http(s)://<enterprise github host>/login/oauth/authorize" 123 | -redeem-url="http(s)://<enterprise github host>/login/oauth/access_token" 124 | -validate-url="http(s)://<enterprise github host>/api/v3" 125 | 126 | ### GitLab Auth Provider 127 | 128 | Whether you are using GitLab.com or self-hosting GitLab, follow [these steps to add an application](http://doc.gitlab.com/ce/integration/oauth_provider.html) 129 | 130 | If you are using self-hosted GitLab, make sure you set the following to the appropriate URL: 131 | 132 | -login-url="<your gitlab url>/oauth/authorize" 133 | -redeem-url="<your gitlab url>/oauth/token" 134 | -validate-url="<your gitlab url>/api/v4/user" 135 | 136 | 137 | ### LinkedIn Auth Provider 138 | 139 | For LinkedIn, the registration steps are: 140 | 141 | 1. Create a new project: https://www.linkedin.com/secure/developer 142 | 2. In the OAuth User Agreement section: 143 | * In default scope, select r_basicprofile and r_emailaddress. 144 | * In "OAuth 2.0 Redirect URLs", enter `https://internal.yourcompany.com/oauth2/callback` 145 | 3.
Fill in the remaining required fields and Save. 146 | 4. Take note of the **Consumer Key / API Key** and **Consumer Secret / Secret Key** 147 | 148 | ### Microsoft Azure AD Provider 149 | 150 | To add an application to Microsoft Azure AD, follow [these steps to add an application](https://azure.microsoft.com/en-us/documentation/articles/active-directory-integrating-applications/). 151 | 152 | Take note of your `TenantId` if applicable for your situation. The `TenantId` can be used to override the default `common` authorization server with a tenant-specific server. 153 | 154 | ### OpenID Connect Provider 155 | 156 | OpenID Connect is a spec for OAuth 2.0 + identity that is implemented by many major providers and several open source projects. This provider was originally built against CoreOS Dex and we will use it as an example. 157 | 158 | 1. Launch a Dex instance using the [getting started guide](https://github.com/coreos/dex/blob/master/Documentation/getting-started.md). 159 | 2. Set up oauth2_proxy with the correct provider and using the default ports and callbacks. 160 | 3. Log in with the fixture used in the Dex guide and run oauth2_proxy with the following args: 161 | 162 | -provider oidc 163 | -client-id oauth2_proxy 164 | -client-secret proxy 165 | -redirect-url http://127.0.0.1:4180/oauth2/callback 166 | -oidc-issuer-url http://127.0.0.1:5556 167 | -cookie-secure=false 168 | -email-domain example.com 169 | 170 | ## Email Authentication 171 | 172 | To authorize by email domain use `--email-domain=yourcompany.com`. To authorize individual email addresses use `--authenticated-emails-file=/path/to/file` with one email per line. To authorize all email addresses use `--email-domain=*`. 173 | 174 | ## Configuration 175 | 176 | `oauth2_proxy` can be configured via [config file](#config-file), [command line options](#command-line-options) or [environment variables](#environment-variables).
177 | 178 | To generate a strong cookie secret use `python -c 'import os,base64; print base64.urlsafe_b64encode(os.urandom(16))'` 179 | 180 | ### Config File 181 | 182 | An example [oauth2_proxy.cfg](contrib/oauth2_proxy.cfg.example) config file is in the contrib directory. It can be used by specifying `-config=/etc/oauth2_proxy.cfg` 183 | 184 | ### Command Line Options 185 | 186 | ``` 187 | Usage of oauth2_proxy: 188 | -approval-prompt string: OAuth approval_prompt (default "force") 189 | -authenticated-emails-file string: authenticate against emails via file (one per line) 190 | -azure-tenant string: go to a tenant-specific or common (tenant-independent) endpoint. (default "common") 191 | -basic-auth-password string: the password to set when passing the HTTP Basic Auth header 192 | -client-id string: the OAuth Client ID: ie: "123456.apps.googleusercontent.com" 193 | -client-secret string: the OAuth Client Secret 194 | -config string: path to config file 195 | -cookie-domain string: an optional cookie domain to force cookies to (ie: .yourcompany.com) 196 | -cookie-expire duration: expire timeframe for cookie (default 168h0m0s) 197 | -cookie-httponly: set HttpOnly cookie flag (default true) 198 | -cookie-name string: the name of the cookie that the oauth_proxy creates (default "_oauth2_proxy") 199 | -cookie-refresh duration: refresh the cookie after this duration; 0 to disable 200 | -cookie-secret string: the seed string for secure cookies (optionally base64 encoded) 201 | -cookie-secure: set secure (HTTPS) cookie flag (default true) 202 | -custom-templates-dir string: path to custom html templates 203 | -display-htpasswd-form: display username / password login form if an htpasswd file is provided (default true) 204 | -email-domain value: authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email 205 | -footer string: custom footer string. Use "-" to disable default footer. 
206 | -github-org string: restrict logins to members of this organisation 207 | -github-team string: restrict logins to members of any of these teams (slug), separated by a comma 208 | -google-admin-email string: the google admin to impersonate for api calls 209 | -google-group value: restrict logins to members of this google group (may be given multiple times). 210 | -google-service-account-json string: the path to the service account json credentials 211 | -htpasswd-file string: additionally authenticate against a htpasswd file. Entries must be created with "htpasswd -s" for SHA encryption 212 | -http-address string: [http://]<addr>:<port> or unix://<path> to listen on for HTTP clients (default "127.0.0.1:4180") 213 | -https-address string: <addr>:<port> to listen on for HTTPS clients (default ":443") 214 | -login-url string: Authentication endpoint 215 | -pass-access-token: pass OAuth access_token to upstream via X-Forwarded-Access-Token header 216 | -pass-basic-auth: pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream (default true) 217 | -pass-host-header: pass the request Host Header to upstream (default true) 218 | -pass-user-headers: pass X-Forwarded-User and X-Forwarded-Email information to upstream (default true) 219 | -profile-url string: Profile access endpoint 220 | -provider string: OAuth provider (default "google") 221 | -proxy-prefix string: the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in) (default "/oauth2") 222 | -redeem-url string: Token redemption endpoint 223 | -redirect-url string: the OAuth Redirect URL.
ie: "https://internalapp.yourcompany.com/oauth2/callback" 224 | -request-logging: Log requests to stdout (default true) 225 | -request-logging-format: Template for request log lines (see "Logging Format" paragraph below) 226 | -resource string: The resource that is protected (Azure AD only) 227 | -scope string: OAuth scope specification 228 | -set-xauthrequest: set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode) 229 | -signature-key string: GAP-Signature request signature key (algorithm:secretkey) 230 | -skip-auth-preflight: will skip authentication for OPTIONS requests 231 | -skip-auth-regex value: bypass authentication for requests path's that match (may be given multiple times) 232 | -skip-provider-button: will skip sign-in-page to directly reach the next step: oauth/start 233 | -ssl-insecure-skip-verify: skip validation of certificates presented when using HTTPS 234 | -tls-cert string: path to certificate file 235 | -tls-key string: path to private key file 236 | -upstream value: the http url(s) of the upstream endpoint or file:// paths for static files. Routing is based on the path 237 | -validate-url string: Access token validation endpoint 238 | -version: print version string 239 | ``` 240 | 241 | See above for provider-specific options. 242 | 243 | ### Upstreams Configuration 244 | 245 | `oauth2_proxy` supports having multiple upstreams, and has the option to pass requests on to HTTP(S) servers or serve static files from the file system. HTTP and HTTPS upstreams are configured by providing a URL such as `http://127.0.0.1:8080/` for the upstream parameter, which will forward all authenticated requests to the upstream server. If you instead provide `http://127.0.0.1:8080/some/path/`, then only requests that start with `/some/path/` are forwarded to the upstream. 246 | 247 | Static file paths are configured as a file:// URL.
`file:///var/www/static/` will serve the files from that directory at `http://[oauth2_proxy url]/var/www/static/`, which may not be what you want. You can provide the path to where the files should be available by adding a fragment to the configured URL. The value of the fragment will then be used to specify which path the files are available at. For example, `file:///var/www/static/#/static/` will make `/var/www/static/` available at `http://[oauth2_proxy url]/static/`. 248 | 249 | Multiple upstreams can either be configured by supplying a comma-separated list to the `-upstream` parameter, supplying the parameter multiple times or providing a list in the [config file](#config-file). When multiple upstreams are used, routing to them will be based on the path they are set up with. 250 | 251 | ### Environment variables 252 | 253 | The following environment variables can be used in place of the corresponding command-line arguments: 254 | 255 | - `OAUTH2_PROXY_CLIENT_ID` 256 | - `OAUTH2_PROXY_CLIENT_SECRET` 257 | - `OAUTH2_PROXY_COOKIE_NAME` 258 | - `OAUTH2_PROXY_COOKIE_SECRET` 259 | - `OAUTH2_PROXY_COOKIE_DOMAIN` 260 | - `OAUTH2_PROXY_COOKIE_EXPIRE` 261 | - `OAUTH2_PROXY_COOKIE_REFRESH` 262 | - `OAUTH2_PROXY_SIGNATURE_KEY` 263 | 264 | ## SSL Configuration 265 | 266 | There are two recommended configurations. 267 | 268 | 1) Configure SSL Termination with OAuth2 Proxy by providing a `--tls-cert=/path/to/cert.pem` and `--tls-key=/path/to/cert.key`. 269 | 270 | The command line to run `oauth2_proxy` in this configuration would look like this: 271 | 272 | ```bash 273 | ./oauth2_proxy \ 274 | --email-domain="yourcompany.com" \ 275 | --upstream=http://127.0.0.1:8080/ \ 276 | --tls-cert=/path/to/cert.pem \ 277 | --tls-key=/path/to/cert.key \ 278 | --cookie-secret=... \ 279 | --cookie-secure=true \ 280 | --provider=... \ 281 | --client-id=... \ 282 | --client-secret=...
283 | ``` 284 | 285 | 286 | 2) Configure SSL Termination with [Nginx](http://nginx.org/) (example config below), Amazon ELB, Google Cloud Platform Load Balancing, or .... 287 | 288 | Because `oauth2_proxy` listens on `127.0.0.1:4180` by default, to listen on all interfaces (needed when using an 289 | external load balancer like Amazon ELB or Google Platform Load Balancing) use `--http-address="0.0.0.0:4180"` or 290 | `--http-address="http://:4180"`. 291 | 292 | Nginx will listen on port `443` and handle SSL connections while proxying to `oauth2_proxy` on port `4180`. 293 | `oauth2_proxy` will then authenticate requests for an upstream application. The external endpoint for this example 294 | would be `https://internal.yourcompany.com/`. 295 | 296 | An example Nginx config follows. Note the use of `Strict-Transport-Security` header to pin requests to SSL 297 | via [HSTS](http://en.wikipedia.org/wiki/HTTP_Strict_Transport_Security): 298 | 299 | ``` 300 | server { 301 | listen 443 default ssl; 302 | server_name internal.yourcompany.com; 303 | ssl_certificate /path/to/cert.pem; 304 | ssl_certificate_key /path/to/cert.key; 305 | add_header Strict-Transport-Security max-age=2592000; 306 | 307 | location / { 308 | proxy_pass http://127.0.0.1:4180; 309 | proxy_set_header Host $host; 310 | proxy_set_header X-Real-IP $remote_addr; 311 | proxy_set_header X-Scheme $scheme; 312 | proxy_connect_timeout 1; 313 | proxy_send_timeout 30; 314 | proxy_read_timeout 30; 315 | } 316 | } 317 | ``` 318 | 319 | The command line to run `oauth2_proxy` in this configuration would look like this: 320 | 321 | ```bash 322 | ./oauth2_proxy \ 323 | --email-domain="yourcompany.com" \ 324 | --upstream=http://127.0.0.1:8080/ \ 325 | --cookie-secret=... \ 326 | --cookie-secure=true \ 327 | --provider=... \ 328 | --client-id=... \ 329 | --client-secret=... 330 | ``` 331 | 332 | ## Endpoint Documentation 333 | 334 | OAuth2 Proxy responds directly to the following endpoints. 
All other endpoints will be proxied upstream when authenticated. The `/oauth2` prefix can be changed with the `--proxy-prefix` config variable. 335 | 336 | * /robots.txt - returns a 200 OK response that disallows all User-agents from all paths; see [robotstxt.org](http://www.robotstxt.org/) for more info 337 | * /ping - returns a 200 OK response 338 | * /oauth2/sign_in - the login page, which also doubles as a sign out page (it clears cookies) 339 | * /oauth2/start - a URL that will redirect to start the OAuth cycle 340 | * /oauth2/callback - the URL used at the end of the OAuth cycle. The oauth app will be configured with this as the callback url. 341 | * /oauth2/auth - only returns a 202 Accepted response or a 401 Unauthorized response; for use with the [Nginx `auth_request` directive](#configuring-for-use-with-the-nginx-auth_request-directive) 342 | 343 | ## Request signatures 344 | 345 | If `signature_key` is defined, proxied requests will be signed with the 346 | `GAP-Signature` header, which is a [Hash-based Message Authentication Code 347 | (HMAC)](https://en.wikipedia.org/wiki/Hash-based_message_authentication_code) 348 | of selected request information and the request body [see `SIGNATURE_HEADERS` 349 | in `oauthproxy.go`](./oauthproxy.go). 350 | 351 | `signature_key` must be of the form `algorithm:secretkey` (e.g. `signature_key = "sha1:secret0"`) 352 | 353 | For more information about HMAC request signature validation, read the 354 | following: 355 | 356 | * [Amazon Web Services: Signing and Authenticating REST 357 | Requests](https://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAuthentication.html) 358 | * [rc3.org: Using HMAC to authenticate Web service 359 | requests](http://rc3.org/2011/12/02/using-hmac-to-authenticate-web-service-requests/) 360 | 361 | ## Logging Format 362 | 363 | By default, OAuth2 Proxy logs requests to stdout in a format similar to Apache Combined Log.
364 | 365 | ``` 366 | <REMOTE_ADDRESS> - <user@domain.com> [19/Mar/2015:17:20:19 -0400] <HOST_HEADER> GET <UPSTREAM_HOST> "/path/" HTTP/1.1 "<USER_AGENT>" <RESPONSE_CODE> <RESPONSE_BYTES> <REQUEST_DURATION> 367 | ``` 368 | 369 | If you require a different format than that, you can configure it with the `-request-logging-format` flag. 370 | The default format is configured as follows: 371 | 372 | ``` 373 | {{.Client}} - {{.Username}} [{{.Timestamp}}] {{.Host}} {{.RequestMethod}} {{.Upstream}} {{.RequestURI}} {{.Protocol}} {{.UserAgent}} {{.StatusCode}} {{.ResponseSize}} {{.RequestDuration}} 374 | ``` 375 | 376 | [See `logMessageData` in `logging_handler.go`](./logging_handler.go) for all available variables. 377 | 378 | ## Adding a new Provider 379 | 380 | Follow the examples in the [`providers` package](providers/) to define a new 381 | `Provider` instance. Add a new `case` to 382 | [`providers.New()`](providers/providers.go) to allow `oauth2_proxy` to use the 383 | new `Provider`. 384 | 385 | ## Configuring for use with the Nginx `auth_request` directive 386 | 387 | The [Nginx `auth_request` directive](http://nginx.org/en/docs/http/ngx_http_auth_request_module.html) allows Nginx to authenticate requests via the oauth2_proxy's `/oauth2/auth` endpoint, which only returns a 202 Accepted response or a 401 Unauthorized response without proxying the request through.
For example: 388 | 389 | ```nginx 390 | server { 391 | listen 443 ssl; 392 | server_name ...; 393 | include ssl/ssl.conf; 394 | 395 | location /oauth2/ { 396 | proxy_pass http://127.0.0.1:4180; 397 | proxy_set_header Host $host; 398 | proxy_set_header X-Real-IP $remote_addr; 399 | proxy_set_header X-Scheme $scheme; 400 | proxy_set_header X-Auth-Request-Redirect $request_uri; 401 | } 402 | location = /oauth2/auth { 403 | proxy_pass http://127.0.0.1:4180; 404 | proxy_set_header Host $host; 405 | proxy_set_header X-Real-IP $remote_addr; 406 | proxy_set_header X-Scheme $scheme; 407 | # nginx auth_request includes headers but not body 408 | proxy_set_header Content-Length ""; 409 | proxy_pass_request_body off; 410 | } 411 | 412 | location / { 413 | auth_request /oauth2/auth; 414 | error_page 401 = /oauth2/sign_in; 415 | 416 | # pass information via X-User and X-Email headers to backend, 417 | # requires running with --set-xauthrequest flag 418 | auth_request_set $user $upstream_http_x_auth_request_user; 419 | auth_request_set $email $upstream_http_x_auth_request_email; 420 | proxy_set_header X-User $user; 421 | proxy_set_header X-Email $email; 422 | 423 | # if you enabled --cookie-refresh, this is needed for it to work with auth_request 424 | auth_request_set $auth_cookie $upstream_http_set_cookie; 425 | add_header Set-Cookie $auth_cookie; 426 | 427 | proxy_pass http://backend/; 428 | # or "root /path/to/site;" or "fastcgi_pass ..." 
etc 429 | } 430 | } 431 | ``` 432 | -------------------------------------------------------------------------------- /api/api.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io/ioutil" 7 | "log" 8 | "net/http" 9 | 10 | "github.com/bitly/go-simplejson" 11 | ) 12 | 13 | func Request(req *http.Request) (*simplejson.Json, error) { 14 | resp, err := http.DefaultClient.Do(req) 15 | if err != nil { 16 | log.Printf("%s %s %s", req.Method, req.URL, err) 17 | return nil, err 18 | } 19 | body, err := ioutil.ReadAll(resp.Body) 20 | resp.Body.Close() 21 | log.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body) 22 | if err != nil { 23 | return nil, err 24 | } 25 | if resp.StatusCode != 200 { 26 | return nil, fmt.Errorf("got %d %s", resp.StatusCode, body) 27 | } 28 | data, err := simplejson.NewJson(body) 29 | if err != nil { 30 | return nil, err 31 | } 32 | return data, nil 33 | } 34 | 35 | func RequestJson(req *http.Request, v interface{}) error { 36 | resp, err := http.DefaultClient.Do(req) 37 | if err != nil { 38 | log.Printf("%s %s %s", req.Method, req.URL, err) 39 | return err 40 | } 41 | body, err := ioutil.ReadAll(resp.Body) 42 | resp.Body.Close() 43 | log.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body) 44 | if err != nil { 45 | return err 46 | } 47 | if resp.StatusCode != 200 { 48 | return fmt.Errorf("got %d %s", resp.StatusCode, body) 49 | } 50 | return json.Unmarshal(body, v) 51 | } 52 | 53 | func RequestUnparsedResponse(url string, header http.Header) (resp *http.Response, err error) { 54 | req, err := http.NewRequest("GET", url, nil) 55 | if err != nil { 56 | return nil, err 57 | } 58 | req.Header = header 59 | 60 | return http.DefaultClient.Do(req) 61 | } 62 | -------------------------------------------------------------------------------- /api/api_test.go: 
-------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "github.com/bitly/go-simplejson" 5 | "io/ioutil" 6 | "net/http" 7 | "net/http/httptest" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func testBackend(response_code int, payload string) *httptest.Server { 15 | return httptest.NewServer(http.HandlerFunc( 16 | func(w http.ResponseWriter, r *http.Request) { 17 | w.WriteHeader(response_code) 18 | w.Write([]byte(payload)) 19 | })) 20 | } 21 | 22 | func TestRequest(t *testing.T) { 23 | backend := testBackend(200, "{\"foo\": \"bar\"}") 24 | defer backend.Close() 25 | 26 | req, _ := http.NewRequest("GET", backend.URL, nil) 27 | response, err := Request(req) 28 | assert.Equal(t, nil, err) 29 | result, err := response.Get("foo").String() 30 | assert.Equal(t, nil, err) 31 | assert.Equal(t, "bar", result) 32 | } 33 | 34 | func TestRequestFailure(t *testing.T) { 35 | // Create a backend to generate a test URL, then close it to cause a 36 | // connection error. 
37 | backend := testBackend(200, "{\"foo\": \"bar\"}") 38 | backend.Close() 39 | 40 | req, err := http.NewRequest("GET", backend.URL, nil) 41 | assert.Equal(t, nil, err) 42 | resp, err := Request(req) 43 | assert.Equal(t, (*simplejson.Json)(nil), resp) 44 | assert.NotEqual(t, nil, err) 45 | if !strings.Contains(err.Error(), "refused") { 46 | t.Error("expected error when a connection fails: ", err) 47 | } 48 | } 49 | 50 | func TestHttpErrorCode(t *testing.T) { 51 | backend := testBackend(404, "{\"foo\": \"bar\"}") 52 | defer backend.Close() 53 | 54 | req, err := http.NewRequest("GET", backend.URL, nil) 55 | assert.Equal(t, nil, err) 56 | resp, err := Request(req) 57 | assert.Equal(t, (*simplejson.Json)(nil), resp) 58 | assert.NotEqual(t, nil, err) 59 | } 60 | 61 | func TestJsonParsingError(t *testing.T) { 62 | backend := testBackend(200, "not well-formed JSON") 63 | defer backend.Close() 64 | 65 | req, err := http.NewRequest("GET", backend.URL, nil) 66 | assert.Equal(t, nil, err) 67 | resp, err := Request(req) 68 | assert.Equal(t, (*simplejson.Json)(nil), resp) 69 | assert.NotEqual(t, nil, err) 70 | } 71 | 72 | // Parsing a URL practically never fails, so we won't cover that test case. 
73 | func TestRequestUnparsedResponseUsingAccessTokenParameter(t *testing.T) { 74 | backend := httptest.NewServer(http.HandlerFunc( 75 | func(w http.ResponseWriter, r *http.Request) { 76 | token := r.FormValue("access_token") 77 | if r.URL.Path == "/" && token == "my_token" { 78 | w.WriteHeader(200) 79 | w.Write([]byte("some payload")) 80 | } else { 81 | w.WriteHeader(403) 82 | } 83 | })) 84 | defer backend.Close() 85 | 86 | response, err := RequestUnparsedResponse( 87 | backend.URL+"?access_token=my_token", nil) 88 | assert.Equal(t, nil, err) 89 | assert.Equal(t, 200, response.StatusCode) 90 | body, err := ioutil.ReadAll(response.Body) 91 | assert.Equal(t, nil, err) 92 | response.Body.Close() 93 | assert.Equal(t, "some payload", string(body)) 94 | } 95 | 96 | func TestRequestUnparsedResponseUsingAccessTokenParameterFailedResponse(t *testing.T) { 97 | backend := testBackend(200, "some payload") 98 | // Close the backend now to force a request failure. 99 | backend.Close() 100 | 101 | response, err := RequestUnparsedResponse( 102 | backend.URL+"?access_token=my_token", nil) 103 | assert.NotEqual(t, nil, err) 104 | assert.Equal(t, (*http.Response)(nil), response) 105 | } 106 | 107 | func TestRequestUnparsedResponseUsingHeaders(t *testing.T) { 108 | backend := httptest.NewServer(http.HandlerFunc( 109 | func(w http.ResponseWriter, r *http.Request) { 110 | if r.URL.Path == "/" && r.Header["Auth"][0] == "my_token" { 111 | w.WriteHeader(200) 112 | w.Write([]byte("some payload")) 113 | } else { 114 | w.WriteHeader(403) 115 | } 116 | })) 117 | defer backend.Close() 118 | 119 | headers := make(http.Header) 120 | headers.Set("Auth", "my_token") 121 | response, err := RequestUnparsedResponse(backend.URL, headers) 122 | assert.Equal(t, nil, err) 123 | assert.Equal(t, 200, response.StatusCode) 124 | body, err := ioutil.ReadAll(response.Body) 125 | assert.Equal(t, nil, err) 126 | response.Body.Close() 127 | assert.Equal(t, "some payload", string(body)) 128 | } 129 | 
-------------------------------------------------------------------------------- /contrib/oauth2_proxy.cfg.example: -------------------------------------------------------------------------------- 1 | ## OAuth2 Proxy Config File 2 | ## https://github.com/bitly/oauth2_proxy 3 | 4 | ## <addr>:<port> to listen on for HTTP/HTTPS clients 5 | # http_address = "127.0.0.1:4180" 6 | # https_address = ":443" 7 | 8 | ## TLS Settings 9 | # tls_cert_file = "" 10 | # tls_key_file = "" 11 | 12 | ## the OAuth Redirect URL. 13 | # defaults to the "https://" + requested host header + "/oauth2/callback" 14 | # redirect_url = "https://internalapp.yourcompany.com/oauth2/callback" 15 | 16 | ## the http url(s) of the upstream endpoint. If multiple, routing is based on path 17 | # upstreams = [ 18 | # "http://127.0.0.1:8080/" 19 | # ] 20 | 21 | ## Log requests to stdout 22 | # request_logging = true 23 | 24 | ## pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream 25 | # pass_basic_auth = true 26 | # pass_user_headers = true 27 | ## pass the request Host Header to upstream 28 | ## when disabled the upstream Host is used as the Host Header 29 | # pass_host_header = true 30 | 31 | ## Email Domains to allow authentication for (this authorizes any email on this domain) 32 | ## for more granular authorization use `authenticated_emails_file` 33 | ## To authorize any email addresses use "*" 34 | # email_domains = [ 35 | # "yourcompany.com" 36 | # ] 37 | 38 | ## The OAuth Client ID, Secret 39 | # client_id = "123456.apps.googleusercontent.com" 40 | # client_secret = "" 41 | 42 | ## Pass OAuth Access token to upstream via "X-Forwarded-Access-Token" 43 | # pass_access_token = false 44 | 45 | ## Authenticated Email Addresses File (one email per line) 46 | # authenticated_emails_file = "" 47 | 48 | ## Htpasswd File (optional) 49 | ## Additionally authenticate against a htpasswd file.
Entries must be created with "htpasswd -B" for bcrypt or "htpasswd -s" for SHA-1 hashing 50 | ## enabling exposes a username/login signin form 51 | # htpasswd_file = "" 52 | 53 | ## Templates 54 | ## optional directory with custom sign_in.html and error.html 55 | # custom_templates_dir = "" 56 | 57 | ## skip SSL checking for HTTPS requests 58 | # ssl_insecure_skip_verify = false 59 | 60 | 61 | ## Cookie Settings 62 | ## Name - the cookie name 63 | ## Secret - the seed string for secure cookies; should be 16, 24, or 32 bytes 64 | ## for use with an AES cipher when cookie_refresh or pass_access_token 65 | ## is set 66 | ## Domain - (optional) cookie domain to force cookies to (e.g. .yourcompany.com) 67 | ## Expire - (duration) expire timeframe for cookie 68 | ## Refresh - (duration) refresh the cookie when duration has elapsed after cookie was initially set. 69 | ## Should be less than cookie_expire; set to 0 to disable. 70 | ## On refresh, OAuth token is re-validated. 71 | ## (e.g. 1h means the token is refreshed on the first request made 1h or more after the cookie was set) 72 | ## Secure - secure cookies are only sent by the browser over an HTTPS connection (recommended) 73 | ## HttpOnly - httponly cookies are not readable by javascript (recommended) 74 | # cookie_name = "_oauth2_proxy" 75 | # cookie_secret = "" 76 | # cookie_domain = "" 77 | # cookie_expire = "168h" 78 | # cookie_refresh = "" 79 | # cookie_secure = true 80 | # cookie_httponly = true 81 | -------------------------------------------------------------------------------- /contrib/oauth2_proxy.service.example: -------------------------------------------------------------------------------- 1 | # Systemd service file for oauth2_proxy daemon 2 | # 3 | # Date: Feb 9, 2016 4 | # Author: Srdjan Grubor 5 | 6 | [Unit] 7 | Description=oauth2_proxy daemon service 8 | After=syslog.target network.target 9 | 10 | [Service] 11 | # www-data group and user need to be created before using these lines 12 | User=www-data 13 | Group=www-data 14 | 15 |
ExecStart=/usr/local/bin/oauth2_proxy -config=/etc/oauth2_proxy.cfg 16 | ExecReload=/bin/kill -HUP $MAINPID 17 | 18 | KillMode=process 19 | Restart=always 20 | 21 | [Install] 22 | WantedBy=multi-user.target 23 | -------------------------------------------------------------------------------- /cookie/cookies.go: -------------------------------------------------------------------------------- 1 | package cookie 2 | 3 | import ( 4 | "crypto/aes" 5 | "crypto/cipher" 6 | "crypto/hmac" 7 | "crypto/rand" 8 | "crypto/sha1" 9 | "encoding/base64" 10 | "fmt" 11 | "io" 12 | "net/http" 13 | "strconv" 14 | "strings" 15 | "time" 16 | ) 17 | 18 | // cookies are stored in a 3 part (value + timestamp + signature) to enforce that the values are as originally set. 19 | // additionally, the 'value' is encrypted so it's opaque to the browser 20 | 21 | // Validate ensures a cookie is properly signed 22 | func Validate(cookie *http.Cookie, seed string, expiration time.Duration) (value string, t time.Time, ok bool) { 23 | // value, timestamp, sig 24 | parts := strings.Split(cookie.Value, "|") 25 | if len(parts) != 3 { 26 | return 27 | } 28 | sig := cookieSignature(seed, cookie.Name, parts[0], parts[1]) 29 | if checkHmac(parts[2], sig) { 30 | ts, err := strconv.Atoi(parts[1]) 31 | if err != nil { 32 | return 33 | } 34 | // The expiration timestamp set when the cookie was created 35 | // isn't sent back by the browser. Hence, we check whether the 36 | // creation timestamp stored in the cookie falls within the 37 | // window defined by (Now()-expiration, Now()]. 38 | t = time.Unix(int64(ts), 0) 39 | if t.After(time.Now().Add(expiration*-1)) && t.Before(time.Now().Add(time.Minute*5)) { 40 | // it's a valid cookie. 
now get the contents 41 | rawValue, err := base64.URLEncoding.DecodeString(parts[0]) 42 | if err == nil { 43 | value = string(rawValue) 44 | ok = true 45 | return 46 | } 47 | } 48 | } 49 | return 50 | } 51 | 52 | // SignedValue returns a cookie that is signed and can later be checked with Validate 53 | func SignedValue(seed string, key string, value string, now time.Time) string { 54 | encodedValue := base64.URLEncoding.EncodeToString([]byte(value)) 55 | timeStr := fmt.Sprintf("%d", now.Unix()) 56 | sig := cookieSignature(seed, key, encodedValue, timeStr) 57 | cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig) 58 | return cookieVal 59 | } 60 | 61 | func cookieSignature(args ...string) string { 62 | h := hmac.New(sha1.New, []byte(args[0])) 63 | for _, arg := range args[1:] { 64 | h.Write([]byte(arg)) 65 | } 66 | var b []byte 67 | b = h.Sum(b) 68 | return base64.URLEncoding.EncodeToString(b) 69 | } 70 | 71 | func checkHmac(input, expected string) bool { 72 | inputMAC, err1 := base64.URLEncoding.DecodeString(input) 73 | if err1 == nil { 74 | expectedMAC, err2 := base64.URLEncoding.DecodeString(expected) 75 | if err2 == nil { 76 | return hmac.Equal(inputMAC, expectedMAC) 77 | } 78 | } 79 | return false 80 | } 81 | 82 | // Cipher provides methods to encrypt and decrypt cookie values 83 | type Cipher struct { 84 | cipher.Block 85 | } 86 | 87 | // NewCipher returns a new aes Cipher for encrypting cookie values 88 | func NewCipher(secret []byte) (*Cipher, error) { 89 | c, err := aes.NewCipher(secret) 90 | if err != nil { 91 | return nil, err 92 | } 93 | return &Cipher{Block: c}, err 94 | } 95 | 96 | // Encrypt a value for use in a cookie 97 | func (c *Cipher) Encrypt(value string) (string, error) { 98 | ciphertext := make([]byte, aes.BlockSize+len(value)) 99 | iv := ciphertext[:aes.BlockSize] 100 | if _, err := io.ReadFull(rand.Reader, iv); err != nil { 101 | return "", fmt.Errorf("failed to create initialization vector %s", err) 102 | } 103 | 104 | stream := 
cipher.NewCFBEncrypter(c.Block, iv) 105 | stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(value)) 106 | return base64.StdEncoding.EncodeToString(ciphertext), nil 107 | } 108 | 109 | // Decrypt a value from a cookie to it's original string 110 | func (c *Cipher) Decrypt(s string) (string, error) { 111 | encrypted, err := base64.StdEncoding.DecodeString(s) 112 | if err != nil { 113 | return "", fmt.Errorf("failed to decrypt cookie value %s", err) 114 | } 115 | 116 | if len(encrypted) < aes.BlockSize { 117 | return "", fmt.Errorf("encrypted cookie value should be "+ 118 | "at least %d bytes, but is only %d bytes", 119 | aes.BlockSize, len(encrypted)) 120 | } 121 | 122 | iv := encrypted[:aes.BlockSize] 123 | encrypted = encrypted[aes.BlockSize:] 124 | stream := cipher.NewCFBDecrypter(c.Block, iv) 125 | stream.XORKeyStream(encrypted, encrypted) 126 | 127 | return string(encrypted), nil 128 | } 129 | -------------------------------------------------------------------------------- /cookie/cookies_test.go: -------------------------------------------------------------------------------- 1 | package cookie 2 | 3 | import ( 4 | "encoding/base64" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestEncodeAndDecodeAccessToken(t *testing.T) { 11 | const secret = "0123456789abcdefghijklmnopqrstuv" 12 | const token = "my access token" 13 | c, err := NewCipher([]byte(secret)) 14 | assert.Equal(t, nil, err) 15 | 16 | encoded, err := c.Encrypt(token) 17 | assert.Equal(t, nil, err) 18 | 19 | decoded, err := c.Decrypt(encoded) 20 | assert.Equal(t, nil, err) 21 | 22 | assert.NotEqual(t, token, encoded) 23 | assert.Equal(t, token, decoded) 24 | } 25 | 26 | func TestEncodeAndDecodeAccessTokenB64(t *testing.T) { 27 | const secret_b64 = "A3Xbr6fu6Al0HkgrP1ztjb-mYiwmxgNPP-XbNsz1WBk=" 28 | const token = "my access token" 29 | 30 | secret, err := base64.URLEncoding.DecodeString(secret_b64) 31 | c, err := NewCipher([]byte(secret)) 32 | assert.Equal(t, nil, 
err) 33 | 34 | encoded, err := c.Encrypt(token) 35 | assert.Equal(t, nil, err) 36 | 37 | decoded, err := c.Decrypt(encoded) 38 | assert.Equal(t, nil, err) 39 | 40 | assert.NotEqual(t, token, encoded) 41 | assert.Equal(t, token, decoded) 42 | } 43 | -------------------------------------------------------------------------------- /cookie/nonce.go: -------------------------------------------------------------------------------- 1 | package cookie 2 | 3 | import ( 4 | "crypto/rand" 5 | "fmt" 6 | ) 7 | 8 | func Nonce() (nonce string, err error) { 9 | b := make([]byte, 16) 10 | _, err = rand.Read(b) 11 | if err != nil { 12 | return 13 | } 14 | nonce = fmt.Sprintf("%x", b) 15 | return 16 | } 17 | -------------------------------------------------------------------------------- /dist.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # build binary distributions for linux/amd64 and darwin/amd64 3 | set -e 4 | 5 | DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" 6 | echo "working dir $DIR" 7 | mkdir -p $DIR/dist 8 | dep ensure || exit 1 9 | 10 | os=$(go env GOOS) 11 | arch=$(go env GOARCH) 12 | version=$(cat $DIR/version.go | grep "const VERSION" | awk '{print $NF}' | sed 's/"//g') 13 | goversion=$(go version | awk '{print $3}') 14 | sha256sum=() 15 | 16 | echo "... running tests" 17 | ./test.sh 18 | 19 | for os in windows linux darwin; do 20 | echo "... building v$version for $os/$arch" 21 | EXT= 22 | if [ $os = windows ]; then 23 | EXT=".exe" 24 | fi 25 | BUILD=$(mktemp -d ${TMPDIR:-/tmp}/oauth2_proxy.XXXXXX) 26 | TARGET="oauth2_proxy-$version.$os-$arch.$goversion" 27 | FILENAME="oauth2_proxy-$version.$os-$arch$EXT" 28 | GOOS=$os GOARCH=$arch CGO_ENABLED=0 \ 29 | go build -ldflags="-s -w" -o $BUILD/$TARGET/$FILENAME || exit 1 30 | pushd $BUILD/$TARGET 31 | sha256sum+=("$(shasum -a 256 $FILENAME || exit 1)") 32 | cd .. 
&& tar czvf $TARGET.tar.gz $TARGET 33 | mv $TARGET.tar.gz $DIR/dist 34 | popd 35 | done 36 | 37 | checksum_file="sha256sum.txt" 38 | cd $DIR/dist 39 | if [ -f $checksum_file ]; then 40 | rm $checksum_file 41 | fi 42 | touch $checksum_file 43 | for checksum in "${sha256sum[@]}"; do 44 | echo "$checksum" >> $checksum_file 45 | done 46 | -------------------------------------------------------------------------------- /env_options.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | "reflect" 6 | "strings" 7 | ) 8 | 9 | type EnvOptions map[string]interface{} 10 | 11 | func (cfg EnvOptions) LoadEnvForStruct(options interface{}) { 12 | val := reflect.ValueOf(options).Elem() 13 | typ := val.Type() 14 | for i := 0; i < typ.NumField(); i++ { 15 | // pull out the struct tags: 16 | // flag - the name of the command line flag 17 | // deprecated - (optional) the name of the deprecated command line flag 18 | // cfg - (optional, defaults to underscored flag) the name of the config file option 19 | field := typ.Field(i) 20 | flagName := field.Tag.Get("flag") 21 | envName := field.Tag.Get("env") 22 | cfgName := field.Tag.Get("cfg") 23 | if cfgName == "" && flagName != "" { 24 | cfgName = strings.Replace(flagName, "-", "_", -1) 25 | } 26 | if envName == "" || cfgName == "" { 27 | // resolvable fields must have the `env` and `cfg` struct tag 28 | continue 29 | } 30 | v := os.Getenv(envName) 31 | if v != "" { 32 | cfg[cfgName] = v 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /env_options_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | type envTest struct { 11 | testField string `cfg:"target_field" env:"TEST_ENV_FIELD"` 12 | } 13 | 14 | func TestLoadEnvForStruct(t *testing.T) { 15 | 16 | cfg 
:= make(EnvOptions) 17 | cfg.LoadEnvForStruct(&envTest{}) 18 | 19 | _, ok := cfg["target_field"] 20 | assert.Equal(t, ok, false) 21 | 22 | os.Setenv("TEST_ENV_FIELD", "1234abcd") 23 | cfg.LoadEnvForStruct(&envTest{}) 24 | v := cfg["target_field"] 25 | assert.Equal(t, v, "1234abcd") 26 | } 27 | -------------------------------------------------------------------------------- /htpasswd.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/sha1" 5 | "encoding/base64" 6 | "encoding/csv" 7 | "io" 8 | "log" 9 | "os" 10 | 11 | "golang.org/x/crypto/bcrypt" 12 | ) 13 | 14 | // Lookup passwords in a htpasswd file 15 | // Passwords must be generated with -B for bcrypt or -s for SHA1. 16 | 17 | type HtpasswdFile struct { 18 | Users map[string]string 19 | } 20 | 21 | func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) { 22 | r, err := os.Open(path) 23 | if err != nil { 24 | return nil, err 25 | } 26 | defer r.Close() 27 | return NewHtpasswd(r) 28 | } 29 | 30 | func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) { 31 | csv_reader := csv.NewReader(file) 32 | csv_reader.Comma = ':' 33 | csv_reader.Comment = '#' 34 | csv_reader.TrimLeadingSpace = true 35 | 36 | records, err := csv_reader.ReadAll() 37 | if err != nil { 38 | return nil, err 39 | } 40 | h := &HtpasswdFile{Users: make(map[string]string)} 41 | for _, record := range records { 42 | h.Users[record[0]] = record[1] 43 | } 44 | return h, nil 45 | } 46 | 47 | func (h *HtpasswdFile) Validate(user string, password string) bool { 48 | realPassword, exists := h.Users[user] 49 | if !exists { 50 | return false 51 | } 52 | 53 | shaPrefix := realPassword[:5] 54 | if shaPrefix == "{SHA}" { 55 | shaValue := realPassword[5:] 56 | d := sha1.New() 57 | d.Write([]byte(password)) 58 | return shaValue == base64.StdEncoding.EncodeToString(d.Sum(nil)) 59 | } 60 | 61 | bcryptPrefix := realPassword[:4] 62 | if bcryptPrefix == "$2a$" || bcryptPrefix == "$2b$" 
|| bcryptPrefix == "$2x$" || bcryptPrefix == "$2y$" { 63 | return bcrypt.CompareHashAndPassword([]byte(realPassword), []byte(password)) == nil 64 | } 65 | 66 | log.Printf("Invalid htpasswd entry for %s. Must be a SHA or bcrypt entry.", user) 67 | return false 68 | } 69 | -------------------------------------------------------------------------------- /htpasswd_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "golang.org/x/crypto/bcrypt" 10 | ) 11 | 12 | func TestSHA(t *testing.T) { 13 | file := bytes.NewBuffer([]byte("testuser:{SHA}PaVBVZkYqAjCQCu6UBL2xgsnZhw=\n")) 14 | h, err := NewHtpasswd(file) 15 | assert.Equal(t, err, nil) 16 | 17 | valid := h.Validate("testuser", "asdf") 18 | assert.Equal(t, valid, true) 19 | } 20 | 21 | func TestBcrypt(t *testing.T) { 22 | hash1, err := bcrypt.GenerateFromPassword([]byte("password"), 1) 23 | hash2, err := bcrypt.GenerateFromPassword([]byte("top-secret"), 2) 24 | assert.Equal(t, err, nil) 25 | 26 | contents := fmt.Sprintf("testuser1:%s\ntestuser2:%s\n", hash1, hash2) 27 | file := bytes.NewBuffer([]byte(contents)) 28 | 29 | h, err := NewHtpasswd(file) 30 | assert.Equal(t, err, nil) 31 | 32 | valid := h.Validate("testuser1", "password") 33 | assert.Equal(t, valid, true) 34 | 35 | valid = h.Validate("testuser2", "top-secret") 36 | assert.Equal(t, valid, true) 37 | } 38 | -------------------------------------------------------------------------------- /http.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/tls" 5 | "log" 6 | "net" 7 | "net/http" 8 | "strings" 9 | "time" 10 | ) 11 | 12 | type Server struct { 13 | Handler http.Handler 14 | Opts *Options 15 | } 16 | 17 | func (s *Server) ListenAndServe() { 18 | if s.Opts.TLSKeyFile != "" || s.Opts.TLSCertFile != "" { 19 | s.ServeHTTPS() 20 | } 
else { 21 | s.ServeHTTP() 22 | } 23 | } 24 | 25 | func (s *Server) ServeHTTP() { 26 | httpAddress := s.Opts.HttpAddress 27 | scheme := "" 28 | 29 | i := strings.Index(httpAddress, "://") 30 | if i > -1 { 31 | scheme = httpAddress[0:i] 32 | } 33 | 34 | var networkType string 35 | switch scheme { 36 | case "", "http": 37 | networkType = "tcp" 38 | default: 39 | networkType = scheme 40 | } 41 | 42 | slice := strings.SplitN(httpAddress, "//", 2) 43 | listenAddr := slice[len(slice)-1] 44 | 45 | listener, err := net.Listen(networkType, listenAddr) 46 | if err != nil { 47 | log.Fatalf("FATAL: listen (%s, %s) failed - %s", networkType, listenAddr, err) 48 | } 49 | log.Printf("HTTP: listening on %s", listenAddr) 50 | 51 | server := &http.Server{Handler: s.Handler} 52 | err = server.Serve(listener) 53 | if err != nil && !strings.Contains(err.Error(), "use of closed network connection") { 54 | log.Printf("ERROR: http.Serve() - %s", err) 55 | } 56 | 57 | log.Printf("HTTP: closing %s", listener.Addr()) 58 | } 59 | 60 | func (s *Server) ServeHTTPS() { 61 | addr := s.Opts.HttpsAddress 62 | config := &tls.Config{ 63 | MinVersion: tls.VersionTLS12, 64 | MaxVersion: tls.VersionTLS12, 65 | } 66 | if config.NextProtos == nil { 67 | config.NextProtos = []string{"http/1.1"} 68 | } 69 | 70 | var err error 71 | config.Certificates = make([]tls.Certificate, 1) 72 | config.Certificates[0], err = tls.LoadX509KeyPair(s.Opts.TLSCertFile, s.Opts.TLSKeyFile) 73 | if err != nil { 74 | log.Fatalf("FATAL: loading tls config (%s, %s) failed - %s", s.Opts.TLSCertFile, s.Opts.TLSKeyFile, err) 75 | } 76 | 77 | ln, err := net.Listen("tcp", addr) 78 | if err != nil { 79 | log.Fatalf("FATAL: listen (%s) failed - %s", addr, err) 80 | } 81 | log.Printf("HTTPS: listening on %s", ln.Addr()) 82 | 83 | tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config) 84 | srv := &http.Server{Handler: s.Handler} 85 | err = srv.Serve(tlsListener) 86 | 87 | if err != nil && 
!strings.Contains(err.Error(), "use of closed network connection") { 88 | log.Printf("ERROR: https.Serve() - %s", err) 89 | } 90 | 91 | log.Printf("HTTPS: closing %s", tlsListener.Addr()) 92 | } 93 | 94 | // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted 95 | // connections. It's used by ListenAndServe and ListenAndServeTLS so 96 | // dead TCP connections (e.g. closing laptop mid-download) eventually 97 | // go away. 98 | type tcpKeepAliveListener struct { 99 | *net.TCPListener 100 | } 101 | 102 | func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { 103 | tc, err := ln.AcceptTCP() 104 | if err != nil { 105 | return 106 | } 107 | tc.SetKeepAlive(true) 108 | tc.SetKeepAlivePeriod(3 * time.Minute) 109 | return tc, nil 110 | } 111 | -------------------------------------------------------------------------------- /logging_handler.go: -------------------------------------------------------------------------------- 1 | // largely adapted from https://github.com/gorilla/handlers/blob/master/handlers.go 2 | // to add logging of request duration as last value (and drop referrer) 3 | 4 | package main 5 | 6 | import ( 7 | "fmt" 8 | "io" 9 | "net" 10 | "net/http" 11 | "net/url" 12 | "text/template" 13 | "time" 14 | ) 15 | 16 | const ( 17 | defaultRequestLoggingFormat = "{{.Client}} - {{.Username}} [{{.Timestamp}}] {{.Host}} {{.RequestMethod}} {{.Upstream}} {{.RequestURI}} {{.Protocol}} {{.UserAgent}} {{.StatusCode}} {{.ResponseSize}} {{.RequestDuration}}" 18 | ) 19 | 20 | // responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status 21 | // code and body size 22 | type responseLogger struct { 23 | w http.ResponseWriter 24 | status int 25 | size int 26 | upstream string 27 | authInfo string 28 | } 29 | 30 | func (l *responseLogger) Header() http.Header { 31 | return l.w.Header() 32 | } 33 | 34 | func (l *responseLogger) ExtractGAPMetadata() { 35 | upstream := l.w.Header().Get("GAP-Upstream-Address") 36 | if upstream != "" { 
37 | l.upstream = upstream 38 | l.w.Header().Del("GAP-Upstream-Address") 39 | } 40 | authInfo := l.w.Header().Get("GAP-Auth") 41 | if authInfo != "" { 42 | l.authInfo = authInfo 43 | l.w.Header().Del("GAP-Auth") 44 | } 45 | } 46 | 47 | func (l *responseLogger) Write(b []byte) (int, error) { 48 | if l.status == 0 { 49 | // The status will be StatusOK if WriteHeader has not been called yet 50 | l.status = http.StatusOK 51 | } 52 | l.ExtractGAPMetadata() 53 | size, err := l.w.Write(b) 54 | l.size += size 55 | return size, err 56 | } 57 | 58 | func (l *responseLogger) WriteHeader(s int) { 59 | l.ExtractGAPMetadata() 60 | l.w.WriteHeader(s) 61 | l.status = s 62 | } 63 | 64 | func (l *responseLogger) Status() int { 65 | return l.status 66 | } 67 | 68 | func (l *responseLogger) Size() int { 69 | return l.size 70 | } 71 | 72 | // logMessageData is the container for all values that are available as variables in the request logging format. 73 | // All values are pre-formatted strings so it is easy to use them in the format string. 
74 | type logMessageData struct { 75 | Client, 76 | Host, 77 | Protocol, 78 | RequestDuration, 79 | RequestMethod, 80 | RequestURI, 81 | ResponseSize, 82 | StatusCode, 83 | Timestamp, 84 | Upstream, 85 | UserAgent, 86 | Username string 87 | } 88 | 89 | // loggingHandler is the http.Handler implementation for LoggingHandlerTo and its friends 90 | type loggingHandler struct { 91 | writer io.Writer 92 | handler http.Handler 93 | enabled bool 94 | logTemplate *template.Template 95 | } 96 | 97 | func LoggingHandler(out io.Writer, h http.Handler, v bool, requestLoggingTpl string) http.Handler { 98 | return loggingHandler{ 99 | writer: out, 100 | handler: h, 101 | enabled: v, 102 | logTemplate: template.Must(template.New("request-log").Parse(requestLoggingTpl)), 103 | } 104 | } 105 | 106 | func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { 107 | t := time.Now() 108 | url := *req.URL 109 | logger := &responseLogger{w: w} 110 | h.handler.ServeHTTP(logger, req) 111 | if !h.enabled { 112 | return 113 | } 114 | h.writeLogLine(logger.authInfo, logger.upstream, req, url, t, logger.Status(), logger.Size()) 115 | } 116 | 117 | // Log entry for req similar to Apache Common Log Format. 118 | // ts is the timestamp with which the entry should be logged. 119 | // status, size are used to provide the response HTTP status and size. 
120 | func (h loggingHandler) writeLogLine(username, upstream string, req *http.Request, url url.URL, ts time.Time, status int, size int) { 121 | if username == "" { 122 | username = "-" 123 | } 124 | if upstream == "" { 125 | upstream = "-" 126 | } 127 | if url.User != nil && username == "-" { 128 | if name := url.User.Username(); name != "" { 129 | username = name 130 | } 131 | } 132 | 133 | client := req.Header.Get("X-Real-IP") 134 | if client == "" { 135 | client = req.RemoteAddr 136 | } 137 | 138 | if c, _, err := net.SplitHostPort(client); err == nil { 139 | client = c 140 | } 141 | 142 | duration := float64(time.Now().Sub(ts)) / float64(time.Second) 143 | 144 | h.logTemplate.Execute(h.writer, logMessageData{ 145 | Client: client, 146 | Host: req.Host, 147 | Protocol: req.Proto, 148 | RequestDuration: fmt.Sprintf("%0.3f", duration), 149 | RequestMethod: req.Method, 150 | RequestURI: fmt.Sprintf("%q", url.RequestURI()), 151 | ResponseSize: fmt.Sprintf("%d", size), 152 | StatusCode: fmt.Sprintf("%d", status), 153 | Timestamp: ts.Format("02/Jan/2006:15:04:05 -0700"), 154 | Upstream: upstream, 155 | UserAgent: fmt.Sprintf("%q", req.UserAgent()), 156 | Username: username, 157 | }) 158 | 159 | h.writer.Write([]byte("\n")) 160 | } 161 | -------------------------------------------------------------------------------- /logging_handler_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | func TestLoggingHandler_ServeHTTP(t *testing.T) { 13 | ts := time.Now() 14 | 15 | tests := []struct { 16 | Format, 17 | ExpectedLogMessage string 18 | }{ 19 | {defaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", ts.Format("02/Jan/2006:15:04:05 -0700"))}, 20 | {"{{.RequestMethod}}", "GET\n"}, 21 | } 22 | 23 | for _, test := range tests { 24 | buf 
:= bytes.NewBuffer(nil) 25 | handler := func(w http.ResponseWriter, req *http.Request) { 26 | w.Write([]byte("test")) 27 | } 28 | 29 | h := LoggingHandler(buf, http.HandlerFunc(handler), true, test.Format) 30 | 31 | r, _ := http.NewRequest("GET", "/foo/bar", nil) 32 | r.RemoteAddr = "127.0.0.1" 33 | r.Host = "test-server" 34 | 35 | h.ServeHTTP(httptest.NewRecorder(), r) 36 | 37 | actual := buf.String() 38 | if actual != test.ExpectedLogMessage { 39 | t.Errorf("Log message was\n%s\ninstead of expected \n%s", actual, test.ExpectedLogMessage) 40 | } 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "log" 7 | "os" 8 | "runtime" 9 | "strings" 10 | "time" 11 | 12 | "github.com/BurntSushi/toml" 13 | "github.com/mreiferson/go-options" 14 | ) 15 | 16 | func main() { 17 | log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) 18 | flagSet := flag.NewFlagSet("oauth2_proxy", flag.ExitOnError) 19 | 20 | emailDomains := StringArray{} 21 | upstreams := StringArray{} 22 | skipAuthRegex := StringArray{} 23 | googleGroups := StringArray{} 24 | 25 | config := flagSet.String("config", "", "path to config file") 26 | showVersion := flagSet.Bool("version", false, "print version string") 27 | 28 | flagSet.String("http-address", "127.0.0.1:4180", "[http://]: or unix:// to listen on for HTTP clients") 29 | flagSet.String("https-address", ":443", ": to listen on for HTTPS clients") 30 | flagSet.String("tls-cert", "", "path to certificate file") 31 | flagSet.String("tls-key", "", "path to private key file") 32 | flagSet.String("redirect-url", "", "the OAuth Redirect URL. 
ie: \"https://internalapp.yourcompany.com/oauth2/callback\"") 33 | flagSet.Bool("set-xauthrequest", false, "set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode)") 34 | flagSet.Var(&upstreams, "upstream", "the http url(s) of the upstream endpoint or file:// paths for static files. Routing is based on the path") 35 | flagSet.Bool("pass-basic-auth", true, "pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream") 36 | flagSet.Bool("pass-user-headers", true, "pass X-Forwarded-User and X-Forwarded-Email information to upstream") 37 | flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header") 38 | flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header") 39 | flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream") 40 | flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)") 41 | flagSet.Bool("skip-provider-button", false, "will skip sign-in-page to directly reach the next step: oauth/start") 42 | flagSet.Bool("skip-auth-preflight", false, "will skip authentication for OPTIONS requests") 43 | flagSet.Bool("ssl-insecure-skip-verify", false, "skip validation of certificates presented when using HTTPS") 44 | 45 | flagSet.Var(&emailDomains, "email-domain", "authenticate emails with the specified domain (may be given multiple times). 
Use * to authenticate any email") 46 | flagSet.String("azure-tenant", "common", "go to a tenant-specific or common (tenant-independent) endpoint.") 47 | flagSet.String("github-org", "", "restrict logins to members of this organisation") 48 | flagSet.String("github-team", "", "restrict logins to members of this team") 49 | flagSet.Var(&googleGroups, "google-group", "restrict logins to members of this google group (may be given multiple times).") 50 | flagSet.String("google-admin-email", "", "the google admin to impersonate for api calls") 51 | flagSet.String("google-service-account-json", "", "the path to the service account json credentials") 52 | flagSet.String("client-id", "", "the OAuth Client ID: ie: \"123456.apps.googleusercontent.com\"") 53 | flagSet.String("client-secret", "", "the OAuth Client Secret") 54 | flagSet.String("authenticated-emails-file", "", "authenticate against emails via file (one per line)") 55 | flagSet.String("htpasswd-file", "", "additionally authenticate against a htpasswd file. Entries must be created with \"htpasswd -s\" for SHA encryption or \"htpasswd -B\" for bcrypt encryption") 56 | flagSet.Bool("display-htpasswd-form", true, "display username / password login form if an htpasswd file is provided") 57 | flagSet.String("custom-templates-dir", "", "path to custom html templates") 58 | flagSet.String("footer", "", "custom footer string. Use \"-\" to disable default footer.") 59 | flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. 
//sign_in)") 60 | 61 | flagSet.String("cookie-name", "_oauth2_proxy", "the name of the cookie that the oauth_proxy creates") 62 | flagSet.String("cookie-secret", "", "the seed string for secure cookies (optionally base64 encoded)") 63 | flagSet.String("cookie-domain", "", "an optional cookie domain to force cookies to (ie: .yourcompany.com)*") 64 | flagSet.Duration("cookie-expire", time.Duration(168)*time.Hour, "expire timeframe for cookie") 65 | flagSet.Duration("cookie-refresh", time.Duration(0), "refresh the cookie after this duration; 0 to disable") 66 | flagSet.Bool("cookie-secure", true, "set secure (HTTPS) cookie flag") 67 | flagSet.Bool("cookie-httponly", true, "set HttpOnly cookie flag") 68 | 69 | flagSet.Bool("request-logging", true, "Log requests to stdout") 70 | flagSet.String("request-logging-format", defaultRequestLoggingFormat, "Template for log lines") 71 | 72 | flagSet.String("provider", "google", "OAuth provider") 73 | flagSet.String("oidc-issuer-url", "", "OpenID Connect issuer URL (ie: https://accounts.google.com)") 74 | flagSet.String("login-url", "", "Authentication endpoint") 75 | flagSet.String("redeem-url", "", "Token redemption endpoint") 76 | flagSet.String("profile-url", "", "Profile access endpoint") 77 | flagSet.String("resource", "", "The resource that is protected (Azure AD only)") 78 | flagSet.String("validate-url", "", "Access token validation endpoint") 79 | flagSet.String("scope", "", "OAuth scope specification") 80 | flagSet.String("approval-prompt", "force", "OAuth approval_prompt") 81 | 82 | flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)") 83 | 84 | flagSet.Parse(os.Args[1:]) 85 | 86 | if *showVersion { 87 | fmt.Printf("oauth2_proxy v%s (built with %s)\n", VERSION, runtime.Version()) 88 | return 89 | } 90 | 91 | opts := NewOptions() 92 | 93 | cfg := make(EnvOptions) 94 | if *config != "" { 95 | _, err := toml.DecodeFile(*config, &cfg) 96 | if err != nil { 97 | log.Fatalf("ERROR: 
failed to load config file %s - %s", *config, err) 98 | } 99 | } 100 | cfg.LoadEnvForStruct(opts) 101 | options.Resolve(opts, flagSet, cfg) 102 | 103 | err := opts.Validate() 104 | if err != nil { 105 | log.Printf("%s", err) 106 | os.Exit(1) 107 | } 108 | validator := NewValidator(opts.EmailDomains, opts.AuthenticatedEmailsFile) 109 | oauthproxy := NewOAuthProxy(opts, validator) 110 | 111 | if len(opts.EmailDomains) != 0 && opts.AuthenticatedEmailsFile == "" { 112 | if len(opts.EmailDomains) > 1 { 113 | oauthproxy.SignInMessage = fmt.Sprintf("Authenticate using one of the following domains: %v", strings.Join(opts.EmailDomains, ", ")) 114 | } else if opts.EmailDomains[0] != "*" { 115 | oauthproxy.SignInMessage = fmt.Sprintf("Authenticate using %v", opts.EmailDomains[0]) 116 | } 117 | } 118 | 119 | if opts.HtpasswdFile != "" { 120 | log.Printf("using htpasswd file %s", opts.HtpasswdFile) 121 | oauthproxy.HtpasswdFile, err = NewHtpasswdFromFile(opts.HtpasswdFile) 122 | oauthproxy.DisplayHtpasswdForm = opts.DisplayHtpasswdForm 123 | if err != nil { 124 | log.Fatalf("FATAL: unable to open %s %s", opts.HtpasswdFile, err) 125 | } 126 | } 127 | 128 | s := &Server{ 129 | Handler: LoggingHandler(os.Stdout, oauthproxy, opts.RequestLogging, opts.RequestLoggingFormat), 130 | Opts: opts, 131 | } 132 | s.ListenAndServe() 133 | } 134 | -------------------------------------------------------------------------------- /oauthproxy.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | b64 "encoding/base64" 5 | "errors" 6 | "fmt" 7 | "html/template" 8 | "log" 9 | "net" 10 | "net/http" 11 | "net/http/httputil" 12 | "net/url" 13 | "regexp" 14 | "strings" 15 | "time" 16 | 17 | "github.com/bitly/oauth2_proxy/cookie" 18 | "github.com/bitly/oauth2_proxy/providers" 19 | "github.com/mbland/hmacauth" 20 | ) 21 | 22 | const SignatureHeader = "GAP-Signature" 23 | 24 | var SignatureHeaders []string = []string{ 25 | 
"Content-Length", 26 | "Content-Md5", 27 | "Content-Type", 28 | "Date", 29 | "Authorization", 30 | "X-Forwarded-User", 31 | "X-Forwarded-Email", 32 | "X-Forwarded-Access-Token", 33 | "Cookie", 34 | "Gap-Auth", 35 | } 36 | 37 | type OAuthProxy struct { 38 | CookieSeed string 39 | CookieName string 40 | CSRFCookieName string 41 | CookieDomain string 42 | CookieSecure bool 43 | CookieHttpOnly bool 44 | CookieExpire time.Duration 45 | CookieRefresh time.Duration 46 | Validator func(string) bool 47 | 48 | RobotsPath string 49 | PingPath string 50 | SignInPath string 51 | SignOutPath string 52 | OAuthStartPath string 53 | OAuthCallbackPath string 54 | AuthOnlyPath string 55 | 56 | redirectURL *url.URL // the url to receive requests at 57 | provider providers.Provider 58 | ProxyPrefix string 59 | SignInMessage string 60 | HtpasswdFile *HtpasswdFile 61 | DisplayHtpasswdForm bool 62 | serveMux http.Handler 63 | SetXAuthRequest bool 64 | PassBasicAuth bool 65 | SkipProviderButton bool 66 | PassUserHeaders bool 67 | BasicAuthPassword string 68 | PassAccessToken bool 69 | CookieCipher *cookie.Cipher 70 | skipAuthRegex []string 71 | skipAuthPreflight bool 72 | compiledRegex []*regexp.Regexp 73 | templates *template.Template 74 | Footer string 75 | } 76 | 77 | type UpstreamProxy struct { 78 | upstream string 79 | handler http.Handler 80 | auth hmacauth.HmacAuth 81 | } 82 | 83 | func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { 84 | w.Header().Set("GAP-Upstream-Address", u.upstream) 85 | if u.auth != nil { 86 | r.Header.Set("GAP-Auth", w.Header().Get("GAP-Auth")) 87 | u.auth.SignRequest(r) 88 | } 89 | u.handler.ServeHTTP(w, r) 90 | } 91 | 92 | func NewReverseProxy(target *url.URL) (proxy *httputil.ReverseProxy) { 93 | return httputil.NewSingleHostReverseProxy(target) 94 | } 95 | func setProxyUpstreamHostHeader(proxy *httputil.ReverseProxy, target *url.URL) { 96 | director := proxy.Director 97 | proxy.Director = func(req *http.Request) { 98 | 
director(req) 99 | // use RequestURI so that we aren't unescaping encoded slashes in the request path 100 | req.Host = target.Host 101 | req.URL.Opaque = req.RequestURI 102 | req.URL.RawQuery = "" 103 | } 104 | } 105 | func setProxyDirector(proxy *httputil.ReverseProxy) { 106 | director := proxy.Director 107 | proxy.Director = func(req *http.Request) { 108 | director(req) 109 | // use RequestURI so that we aren't unescaping encoded slashes in the request path 110 | req.URL.Opaque = req.RequestURI 111 | req.URL.RawQuery = "" 112 | } 113 | } 114 | func NewFileServer(path string, filesystemPath string) (proxy http.Handler) { 115 | return http.StripPrefix(path, http.FileServer(http.Dir(filesystemPath))) 116 | } 117 | 118 | func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { 119 | serveMux := http.NewServeMux() 120 | var auth hmacauth.HmacAuth 121 | if sigData := opts.signatureData; sigData != nil { 122 | auth = hmacauth.NewHmacAuth(sigData.hash, []byte(sigData.key), 123 | SignatureHeader, SignatureHeaders) 124 | } 125 | for _, u := range opts.proxyURLs { 126 | path := u.Path 127 | switch u.Scheme { 128 | case "http", "https": 129 | u.Path = "" 130 | log.Printf("mapping path %q => upstream %q", path, u) 131 | proxy := NewReverseProxy(u) 132 | if !opts.PassHostHeader { 133 | setProxyUpstreamHostHeader(proxy, u) 134 | } else { 135 | setProxyDirector(proxy) 136 | } 137 | serveMux.Handle(path, 138 | &UpstreamProxy{u.Host, proxy, auth}) 139 | case "file": 140 | if u.Fragment != "" { 141 | path = u.Fragment 142 | } 143 | log.Printf("mapping path %q => file system %q", path, u.Path) 144 | proxy := NewFileServer(path, u.Path) 145 | serveMux.Handle(path, &UpstreamProxy{path, proxy, nil}) 146 | default: 147 | panic(fmt.Sprintf("unknown upstream protocol %s", u.Scheme)) 148 | } 149 | } 150 | for _, u := range opts.CompiledRegex { 151 | log.Printf("compiled skip-auth-regex => %q", u) 152 | } 153 | 154 | redirectURL := opts.redirectURL 155 | redirectURL.Path 
= fmt.Sprintf("%s/callback", opts.ProxyPrefix) 156 | 157 | log.Printf("OAuthProxy configured for %s Client ID: %s", opts.provider.Data().ProviderName, opts.ClientID) 158 | refresh := "disabled" 159 | if opts.CookieRefresh != time.Duration(0) { 160 | refresh = fmt.Sprintf("after %s", opts.CookieRefresh) 161 | } 162 | 163 | log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, opts.CookieDomain, refresh) 164 | 165 | var cipher *cookie.Cipher 166 | if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) { 167 | var err error 168 | cipher, err = cookie.NewCipher(secretBytes(opts.CookieSecret)) 169 | if err != nil { 170 | log.Fatal("cookie-secret error: ", err) 171 | } 172 | } 173 | 174 | return &OAuthProxy{ 175 | CookieName: opts.CookieName, 176 | CSRFCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "csrf"), 177 | CookieSeed: opts.CookieSecret, 178 | CookieDomain: opts.CookieDomain, 179 | CookieSecure: opts.CookieSecure, 180 | CookieHttpOnly: opts.CookieHttpOnly, 181 | CookieExpire: opts.CookieExpire, 182 | CookieRefresh: opts.CookieRefresh, 183 | Validator: validator, 184 | 185 | RobotsPath: "/robots.txt", 186 | PingPath: "/ping", 187 | SignInPath: fmt.Sprintf("%s/sign_in", opts.ProxyPrefix), 188 | SignOutPath: fmt.Sprintf("%s/sign_out", opts.ProxyPrefix), 189 | OAuthStartPath: fmt.Sprintf("%s/start", opts.ProxyPrefix), 190 | OAuthCallbackPath: fmt.Sprintf("%s/callback", opts.ProxyPrefix), 191 | AuthOnlyPath: fmt.Sprintf("%s/auth", opts.ProxyPrefix), 192 | 193 | ProxyPrefix: opts.ProxyPrefix, 194 | provider: opts.provider, 195 | serveMux: serveMux, 196 | redirectURL: redirectURL, 197 | skipAuthRegex: opts.SkipAuthRegex, 198 | skipAuthPreflight: opts.SkipAuthPreflight, 199 | compiledRegex: opts.CompiledRegex, 200 | SetXAuthRequest: opts.SetXAuthRequest, 201 | PassBasicAuth: opts.PassBasicAuth, 202 | PassUserHeaders: opts.PassUserHeaders, 
203 | BasicAuthPassword: opts.BasicAuthPassword, 204 | PassAccessToken: opts.PassAccessToken, 205 | SkipProviderButton: opts.SkipProviderButton, 206 | CookieCipher: cipher, 207 | templates: loadTemplates(opts.CustomTemplatesDir), 208 | Footer: opts.Footer, 209 | } 210 | } 211 | 212 | func (p *OAuthProxy) GetRedirectURI(host string) string { 213 | // default to the request Host if not set 214 | if p.redirectURL.Host != "" { 215 | return p.redirectURL.String() 216 | } 217 | var u url.URL 218 | u = *p.redirectURL 219 | if u.Scheme == "" { 220 | if p.CookieSecure { 221 | u.Scheme = "https" 222 | } else { 223 | u.Scheme = "http" 224 | } 225 | } 226 | u.Host = host 227 | return u.String() 228 | } 229 | 230 | func (p *OAuthProxy) displayCustomLoginForm() bool { 231 | return p.HtpasswdFile != nil && p.DisplayHtpasswdForm 232 | } 233 | 234 | func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) { 235 | if code == "" { 236 | return nil, errors.New("missing code") 237 | } 238 | redirectURI := p.GetRedirectURI(host) 239 | s, err = p.provider.Redeem(redirectURI, code) 240 | if err != nil { 241 | return 242 | } 243 | 244 | if s.Email == "" { 245 | s.Email, err = p.provider.GetEmailAddress(s) 246 | } 247 | 248 | if s.User == "" { 249 | s.User, err = p.provider.GetUserName(s) 250 | if err != nil && err.Error() == "not implemented" { 251 | err = nil 252 | } 253 | } 254 | return 255 | } 256 | 257 | func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { 258 | if value != "" { 259 | value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now) 260 | if len(value) > 4096 { 261 | // Cookies cannot be larger than 4kb 262 | log.Printf("WARNING - Cookie Size: %d bytes", len(value)) 263 | } 264 | } 265 | return p.makeCookie(req, p.CookieName, value, expiration, now) 266 | } 267 | 268 | func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, 
now time.Time) *http.Cookie { 269 | return p.makeCookie(req, p.CSRFCookieName, value, expiration, now) 270 | } 271 | 272 | func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { 273 | if p.CookieDomain != "" { 274 | domain := req.Host 275 | if h, _, err := net.SplitHostPort(domain); err == nil { 276 | domain = h 277 | } 278 | if !strings.HasSuffix(domain, p.CookieDomain) { 279 | log.Printf("Warning: request host is %q but using configured cookie domain of %q", domain, p.CookieDomain) 280 | } 281 | } 282 | 283 | return &http.Cookie{ 284 | Name: name, 285 | Value: value, 286 | Path: "/", 287 | Domain: p.CookieDomain, 288 | HttpOnly: p.CookieHttpOnly, 289 | Secure: p.CookieSecure, 290 | Expires: now.Add(expiration), 291 | } 292 | } 293 | 294 | func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) { 295 | http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now())) 296 | } 297 | 298 | func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, val string) { 299 | http.SetCookie(rw, p.MakeCSRFCookie(req, val, p.CookieExpire, time.Now())) 300 | } 301 | 302 | func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) { 303 | clr := p.MakeSessionCookie(req, "", time.Hour*-1, time.Now()) 304 | http.SetCookie(rw, clr) 305 | 306 | // ugly hack because default domain changed 307 | if p.CookieDomain == "" { 308 | clr2 := *clr 309 | clr2.Domain = req.Host 310 | http.SetCookie(rw, &clr2) 311 | } 312 | } 313 | 314 | func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { 315 | http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now())) 316 | } 317 | 318 | func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) { 319 | var age time.Duration 320 | c, err := req.Cookie(p.CookieName) 321 | if err != nil { 322 | // always 
http.ErrNoCookie 323 | return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName) 324 | } 325 | val, timestamp, ok := cookie.Validate(c, p.CookieSeed, p.CookieExpire) 326 | if !ok { 327 | return nil, age, errors.New("Cookie Signature not valid") 328 | } 329 | 330 | session, err := p.provider.SessionFromCookie(val, p.CookieCipher) 331 | if err != nil { 332 | return nil, age, err 333 | } 334 | 335 | age = time.Now().Truncate(time.Second).Sub(timestamp) 336 | return session, age, nil 337 | } 338 | 339 | func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error { 340 | value, err := p.provider.CookieForSession(s, p.CookieCipher) 341 | if err != nil { 342 | return err 343 | } 344 | p.SetSessionCookie(rw, req, value) 345 | return nil 346 | } 347 | 348 | func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) { 349 | rw.WriteHeader(http.StatusOK) 350 | fmt.Fprintf(rw, "User-agent: *\nDisallow: /") 351 | } 352 | 353 | func (p *OAuthProxy) PingPage(rw http.ResponseWriter) { 354 | rw.WriteHeader(http.StatusOK) 355 | fmt.Fprintf(rw, "OK") 356 | } 357 | 358 | func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, message string) { 359 | log.Printf("ErrorPage %d %s %s", code, title, message) 360 | rw.WriteHeader(code) 361 | t := struct { 362 | Title string 363 | Message string 364 | ProxyPrefix string 365 | }{ 366 | Title: fmt.Sprintf("%d %s", code, title), 367 | Message: message, 368 | ProxyPrefix: p.ProxyPrefix, 369 | } 370 | p.templates.ExecuteTemplate(rw, "error.html", t) 371 | } 372 | 373 | func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) { 374 | p.ClearSessionCookie(rw, req) 375 | rw.WriteHeader(code) 376 | 377 | redirect_url := req.URL.RequestURI() 378 | if req.Header.Get("X-Auth-Request-Redirect") != "" { 379 | redirect_url = req.Header.Get("X-Auth-Request-Redirect") 380 | } 381 | if redirect_url == p.SignInPath { 382 | redirect_url = "/" 383 | } 384 | 
385 | t := struct { 386 | ProviderName string 387 | SignInMessage string 388 | CustomLogin bool 389 | Redirect string 390 | Version string 391 | ProxyPrefix string 392 | Footer template.HTML 393 | }{ 394 | ProviderName: p.provider.Data().ProviderName, 395 | SignInMessage: p.SignInMessage, 396 | CustomLogin: p.displayCustomLoginForm(), 397 | Redirect: redirect_url, 398 | Version: VERSION, 399 | ProxyPrefix: p.ProxyPrefix, 400 | Footer: template.HTML(p.Footer), 401 | } 402 | p.templates.ExecuteTemplate(rw, "sign_in.html", t) 403 | } 404 | 405 | func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (string, bool) { 406 | if req.Method != "POST" || p.HtpasswdFile == nil { 407 | return "", false 408 | } 409 | user := req.FormValue("username") 410 | passwd := req.FormValue("password") 411 | if user == "" { 412 | return "", false 413 | } 414 | // check auth 415 | if p.HtpasswdFile.Validate(user, passwd) { 416 | log.Printf("authenticated %q via HtpasswdFile", user) 417 | return user, true 418 | } 419 | return "", false 420 | } 421 | 422 | func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) { 423 | err = req.ParseForm() 424 | if err != nil { 425 | return 426 | } 427 | 428 | redirect = req.Form.Get("rd") 429 | if redirect == "" || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") { 430 | redirect = "/" 431 | } 432 | 433 | return 434 | } 435 | 436 | func (p *OAuthProxy) IsWhitelistedRequest(req *http.Request) (ok bool) { 437 | isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS" 438 | return isPreflightRequestAllowed || p.IsWhitelistedPath(req.URL.Path) 439 | } 440 | 441 | func (p *OAuthProxy) IsWhitelistedPath(path string) (ok bool) { 442 | for _, u := range p.compiledRegex { 443 | ok = u.MatchString(path) 444 | if ok { 445 | return 446 | } 447 | } 448 | return 449 | } 450 | 451 | func getRemoteAddr(req *http.Request) (s string) { 452 | s = req.RemoteAddr 453 | if 
req.Header.Get("X-Real-IP") != "" { 454 | s += fmt.Sprintf(" (%q)", req.Header.Get("X-Real-IP")) 455 | } 456 | return 457 | } 458 | 459 | func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { 460 | switch path := req.URL.Path; { 461 | case path == p.RobotsPath: 462 | p.RobotsTxt(rw) 463 | case path == p.PingPath: 464 | p.PingPage(rw) 465 | case p.IsWhitelistedRequest(req): 466 | p.serveMux.ServeHTTP(rw, req) 467 | case path == p.SignInPath: 468 | p.SignIn(rw, req) 469 | case path == p.SignOutPath: 470 | p.SignOut(rw, req) 471 | case path == p.OAuthStartPath: 472 | p.OAuthStart(rw, req) 473 | case path == p.OAuthCallbackPath: 474 | p.OAuthCallback(rw, req) 475 | case path == p.AuthOnlyPath: 476 | p.AuthenticateOnly(rw, req) 477 | default: 478 | p.Proxy(rw, req) 479 | } 480 | } 481 | 482 | func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { 483 | redirect, err := p.GetRedirect(req) 484 | if err != nil { 485 | p.ErrorPage(rw, 500, "Internal Error", err.Error()) 486 | return 487 | } 488 | 489 | user, ok := p.ManualSignIn(rw, req) 490 | if ok { 491 | session := &providers.SessionState{User: user} 492 | p.SaveSession(rw, req, session) 493 | http.Redirect(rw, req, redirect, 302) 494 | } else { 495 | if p.SkipProviderButton { 496 | p.OAuthStart(rw, req) 497 | } else { 498 | p.SignInPage(rw, req, http.StatusOK) 499 | } 500 | } 501 | } 502 | 503 | func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { 504 | p.ClearSessionCookie(rw, req) 505 | http.Redirect(rw, req, "/", 302) 506 | } 507 | 508 | func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { 509 | nonce, err := cookie.Nonce() 510 | if err != nil { 511 | p.ErrorPage(rw, 500, "Internal Error", err.Error()) 512 | return 513 | } 514 | p.SetCSRFCookie(rw, req, nonce) 515 | redirect, err := p.GetRedirect(req) 516 | if err != nil { 517 | p.ErrorPage(rw, 500, "Internal Error", err.Error()) 518 | return 519 | } 520 | redirectURI := 
p.GetRedirectURI(req.Host) 521 | http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), 302) 522 | } 523 | 524 | func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { 525 | remoteAddr := getRemoteAddr(req) 526 | 527 | // finish the oauth cycle 528 | err := req.ParseForm() 529 | if err != nil { 530 | p.ErrorPage(rw, 500, "Internal Error", err.Error()) 531 | return 532 | } 533 | errorString := req.Form.Get("error") 534 | if errorString != "" { 535 | p.ErrorPage(rw, 403, "Permission Denied", errorString) 536 | return 537 | } 538 | 539 | session, err := p.redeemCode(req.Host, req.Form.Get("code")) 540 | if err != nil { 541 | log.Printf("%s error redeeming code %s", remoteAddr, err) 542 | p.ErrorPage(rw, 500, "Internal Error", "Internal Error") 543 | return 544 | } 545 | 546 | s := strings.SplitN(req.Form.Get("state"), ":", 2) 547 | if len(s) != 2 { 548 | p.ErrorPage(rw, 500, "Internal Error", "Invalid State") 549 | return 550 | } 551 | nonce := s[0] 552 | redirect := s[1] 553 | c, err := req.Cookie(p.CSRFCookieName) 554 | if err != nil { 555 | p.ErrorPage(rw, 403, "Permission Denied", err.Error()) 556 | return 557 | } 558 | p.ClearCSRFCookie(rw, req) 559 | if c.Value != nonce { 560 | log.Printf("%s csrf token mismatch, potential attack", remoteAddr) 561 | p.ErrorPage(rw, 403, "Permission Denied", "csrf failed") 562 | return 563 | } 564 | 565 | if !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") { 566 | redirect = "/" 567 | } 568 | 569 | // set cookie, or deny 570 | if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) { 571 | log.Printf("%s authentication complete %s", remoteAddr, session) 572 | err := p.SaveSession(rw, req, session) 573 | if err != nil { 574 | log.Printf("%s %s", remoteAddr, err) 575 | p.ErrorPage(rw, 500, "Internal Error", "Internal Error") 576 | return 577 | } 578 | http.Redirect(rw, req, redirect, 302) 579 | } else { 580 | log.Printf("%s 
Permission Denied: %q is unauthorized", remoteAddr, session.Email) 581 | p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account") 582 | } 583 | } 584 | 585 | func (p *OAuthProxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) { 586 | status := p.Authenticate(rw, req) 587 | if status == http.StatusAccepted { 588 | rw.WriteHeader(http.StatusAccepted) 589 | } else { 590 | http.Error(rw, "unauthorized request", http.StatusUnauthorized) 591 | } 592 | } 593 | 594 | func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { 595 | status := p.Authenticate(rw, req) 596 | if status == http.StatusInternalServerError { 597 | p.ErrorPage(rw, http.StatusInternalServerError, 598 | "Internal Error", "Internal Error") 599 | } else if status == http.StatusForbidden { 600 | if p.SkipProviderButton { 601 | p.OAuthStart(rw, req) 602 | } else { 603 | p.SignInPage(rw, req, http.StatusForbidden) 604 | } 605 | } else { 606 | p.serveMux.ServeHTTP(rw, req) 607 | } 608 | } 609 | 610 | func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int { 611 | var saveSession, clearSession, revalidated bool 612 | remoteAddr := getRemoteAddr(req) 613 | 614 | session, sessionAge, err := p.LoadCookiedSession(req) 615 | if err != nil { 616 | log.Printf("%s %s", remoteAddr, err) 617 | } 618 | if session != nil && sessionAge > p.CookieRefresh && p.CookieRefresh != time.Duration(0) { 619 | log.Printf("%s refreshing %s old session cookie for %s (refresh after %s)", remoteAddr, sessionAge, session, p.CookieRefresh) 620 | saveSession = true 621 | } 622 | 623 | if ok, err := p.provider.RefreshSessionIfNeeded(session); err != nil { 624 | log.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session) 625 | clearSession = true 626 | session = nil 627 | } else if ok { 628 | saveSession = true 629 | revalidated = true 630 | } 631 | 632 | if session != nil && session.IsExpired() { 633 | log.Printf("%s removing session. 
token expired %s", remoteAddr, session) 634 | session = nil 635 | saveSession = false 636 | clearSession = true 637 | } 638 | 639 | if saveSession && !revalidated && session != nil && session.AccessToken != "" { 640 | if !p.provider.ValidateSessionState(session) { 641 | log.Printf("%s removing session. error validating %s", remoteAddr, session) 642 | saveSession = false 643 | session = nil 644 | clearSession = true 645 | } 646 | } 647 | 648 | if session != nil && session.Email != "" && !p.Validator(session.Email) { 649 | log.Printf("%s Permission Denied: removing session %s", remoteAddr, session) 650 | session = nil 651 | saveSession = false 652 | clearSession = true 653 | } 654 | 655 | if saveSession && session != nil { 656 | err := p.SaveSession(rw, req, session) 657 | if err != nil { 658 | log.Printf("%s %s", remoteAddr, err) 659 | return http.StatusInternalServerError 660 | } 661 | } 662 | 663 | if clearSession { 664 | p.ClearSessionCookie(rw, req) 665 | } 666 | 667 | if session == nil { 668 | session, err = p.CheckBasicAuth(req) 669 | if err != nil { 670 | log.Printf("%s %s", remoteAddr, err) 671 | } 672 | } 673 | 674 | if session == nil { 675 | return http.StatusForbidden 676 | } 677 | 678 | // At this point, the user is authenticated. 
proxy normally 679 | if p.PassBasicAuth { 680 | req.SetBasicAuth(session.User, p.BasicAuthPassword) 681 | req.Header["X-Forwarded-User"] = []string{session.User} 682 | if session.Email != "" { 683 | req.Header["X-Forwarded-Email"] = []string{session.Email} 684 | } 685 | } 686 | if p.PassUserHeaders { 687 | req.Header["X-Forwarded-User"] = []string{session.User} 688 | if session.Email != "" { 689 | req.Header["X-Forwarded-Email"] = []string{session.Email} 690 | } 691 | } 692 | if p.SetXAuthRequest { 693 | rw.Header().Set("X-Auth-Request-User", session.User) 694 | if session.Email != "" { 695 | rw.Header().Set("X-Auth-Request-Email", session.Email) 696 | } 697 | } 698 | if p.PassAccessToken && session.AccessToken != "" { 699 | req.Header["X-Forwarded-Access-Token"] = []string{session.AccessToken} 700 | } 701 | if session.Email == "" { 702 | rw.Header().Set("GAP-Auth", session.User) 703 | } else { 704 | rw.Header().Set("GAP-Auth", session.Email) 705 | } 706 | return http.StatusAccepted 707 | } 708 | 709 | func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) { 710 | if p.HtpasswdFile == nil { 711 | return nil, nil 712 | } 713 | auth := req.Header.Get("Authorization") 714 | if auth == "" { 715 | return nil, nil 716 | } 717 | s := strings.SplitN(auth, " ", 2) 718 | if len(s) != 2 || s[0] != "Basic" { 719 | return nil, fmt.Errorf("invalid Authorization header %s", req.Header.Get("Authorization")) 720 | } 721 | b, err := b64.StdEncoding.DecodeString(s[1]) 722 | if err != nil { 723 | return nil, err 724 | } 725 | pair := strings.SplitN(string(b), ":", 2) 726 | if len(pair) != 2 { 727 | return nil, fmt.Errorf("invalid format %s", b) 728 | } 729 | if p.HtpasswdFile.Validate(pair[0], pair[1]) { 730 | log.Printf("authenticated %q via basic auth", pair[0]) 731 | return &providers.SessionState{User: pair[0]}, nil 732 | } 733 | return nil, fmt.Errorf("%s not in HtpasswdFile", pair[0]) 734 | } 735 | 
-------------------------------------------------------------------------------- /options.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "crypto" 6 | "crypto/tls" 7 | "encoding/base64" 8 | "fmt" 9 | "net/http" 10 | "net/url" 11 | "os" 12 | "regexp" 13 | "strings" 14 | "time" 15 | 16 | "github.com/bitly/oauth2_proxy/providers" 17 | oidc "github.com/coreos/go-oidc" 18 | "github.com/mbland/hmacauth" 19 | ) 20 | 21 | // Configuration Options that can be set by Command Line Flag, or Config File 22 | type Options struct { 23 | ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy-prefix"` 24 | HttpAddress string `flag:"http-address" cfg:"http_address"` 25 | HttpsAddress string `flag:"https-address" cfg:"https_address"` 26 | RedirectURL string `flag:"redirect-url" cfg:"redirect_url"` 27 | ClientID string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"` 28 | ClientSecret string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"` 29 | TLSCertFile string `flag:"tls-cert" cfg:"tls_cert_file"` 30 | TLSKeyFile string `flag:"tls-key" cfg:"tls_key_file"` 31 | 32 | AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` 33 | AzureTenant string `flag:"azure-tenant" cfg:"azure_tenant"` 34 | EmailDomains []string `flag:"email-domain" cfg:"email_domains"` 35 | GitHubOrg string `flag:"github-org" cfg:"github_org"` 36 | GitHubTeam string `flag:"github-team" cfg:"github_team"` 37 | GoogleGroups []string `flag:"google-group" cfg:"google_group"` 38 | GoogleAdminEmail string `flag:"google-admin-email" cfg:"google_admin_email"` 39 | GoogleServiceAccountJSON string `flag:"google-service-account-json" cfg:"google_service_account_json"` 40 | HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file"` 41 | DisplayHtpasswdForm bool `flag:"display-htpasswd-form" cfg:"display_htpasswd_form"` 42 | CustomTemplatesDir string 
`flag:"custom-templates-dir" cfg:"custom_templates_dir"` 43 | Footer string `flag:"footer" cfg:"footer"` 44 | 45 | CookieName string `flag:"cookie-name" cfg:"cookie_name" env:"OAUTH2_PROXY_COOKIE_NAME"` 46 | CookieSecret string `flag:"cookie-secret" cfg:"cookie_secret" env:"OAUTH2_PROXY_COOKIE_SECRET"` 47 | CookieDomain string `flag:"cookie-domain" cfg:"cookie_domain" env:"OAUTH2_PROXY_COOKIE_DOMAIN"` 48 | CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"` 49 | CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"` 50 | CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure"` 51 | CookieHttpOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly"` 52 | 53 | Upstreams []string `flag:"upstream" cfg:"upstreams"` 54 | SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` 55 | PassBasicAuth bool `flag:"pass-basic-auth" cfg:"pass_basic_auth"` 56 | BasicAuthPassword string `flag:"basic-auth-password" cfg:"basic_auth_password"` 57 | PassAccessToken bool `flag:"pass-access-token" cfg:"pass_access_token"` 58 | PassHostHeader bool `flag:"pass-host-header" cfg:"pass_host_header"` 59 | SkipProviderButton bool `flag:"skip-provider-button" cfg:"skip_provider_button"` 60 | PassUserHeaders bool `flag:"pass-user-headers" cfg:"pass_user_headers"` 61 | SSLInsecureSkipVerify bool `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify"` 62 | SetXAuthRequest bool `flag:"set-xauthrequest" cfg:"set_xauthrequest"` 63 | SkipAuthPreflight bool `flag:"skip-auth-preflight" cfg:"skip_auth_preflight"` 64 | 65 | // These options allow for other providers besides Google, with 66 | // potential overrides. 
67 | Provider string `flag:"provider" cfg:"provider"` 68 | OIDCIssuerURL string `flag:"oidc-issuer-url" cfg:"oidc_issuer_url"` 69 | LoginURL string `flag:"login-url" cfg:"login_url"` 70 | RedeemURL string `flag:"redeem-url" cfg:"redeem_url"` 71 | ProfileURL string `flag:"profile-url" cfg:"profile_url"` 72 | ProtectedResource string `flag:"resource" cfg:"resource"` 73 | ValidateURL string `flag:"validate-url" cfg:"validate_url"` 74 | Scope string `flag:"scope" cfg:"scope"` 75 | ApprovalPrompt string `flag:"approval-prompt" cfg:"approval_prompt"` 76 | 77 | RequestLogging bool `flag:"request-logging" cfg:"request_logging"` 78 | RequestLoggingFormat string `flag:"request-logging-format" cfg:"request_logging_format"` 79 | 80 | SignatureKey string `flag:"signature-key" cfg:"signature_key" env:"OAUTH2_PROXY_SIGNATURE_KEY"` 81 | 82 | // internal values that are set after config validation 83 | redirectURL *url.URL 84 | proxyURLs []*url.URL 85 | CompiledRegex []*regexp.Regexp 86 | provider providers.Provider 87 | signatureData *SignatureData 88 | oidcVerifier *oidc.IDTokenVerifier 89 | } 90 | 91 | type SignatureData struct { 92 | hash crypto.Hash 93 | key string 94 | } 95 | 96 | func NewOptions() *Options { 97 | return &Options{ 98 | ProxyPrefix: "/oauth2", 99 | HttpAddress: "127.0.0.1:4180", 100 | HttpsAddress: ":443", 101 | DisplayHtpasswdForm: true, 102 | CookieName: "_oauth2_proxy", 103 | CookieSecure: true, 104 | CookieHttpOnly: true, 105 | CookieExpire: time.Duration(168) * time.Hour, 106 | CookieRefresh: time.Duration(0), 107 | SetXAuthRequest: false, 108 | SkipAuthPreflight: false, 109 | PassBasicAuth: true, 110 | PassUserHeaders: true, 111 | PassAccessToken: false, 112 | PassHostHeader: true, 113 | ApprovalPrompt: "force", 114 | RequestLogging: true, 115 | RequestLoggingFormat: defaultRequestLoggingFormat, 116 | } 117 | } 118 | 119 | func parseURL(to_parse string, urltype string, msgs []string) (*url.URL, []string) { 120 | parsed, err := url.Parse(to_parse) 121 | 
if err != nil { 122 | return nil, append(msgs, fmt.Sprintf( 123 | "error parsing %s-url=%q %s", urltype, to_parse, err)) 124 | } 125 | return parsed, msgs 126 | } 127 | 128 | func (o *Options) Validate() error { 129 | if o.SSLInsecureSkipVerify { 130 | // TODO: Accept a certificate bundle. 131 | insecureTransport := &http.Transport{ 132 | TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, 133 | } 134 | http.DefaultClient = &http.Client{Transport: insecureTransport} 135 | } 136 | 137 | msgs := make([]string, 0) 138 | if o.CookieSecret == "" { 139 | msgs = append(msgs, "missing setting: cookie-secret") 140 | } 141 | if o.ClientID == "" { 142 | msgs = append(msgs, "missing setting: client-id") 143 | } 144 | if o.ClientSecret == "" { 145 | msgs = append(msgs, "missing setting: client-secret") 146 | } 147 | if o.AuthenticatedEmailsFile == "" && len(o.EmailDomains) == 0 && o.HtpasswdFile == "" { 148 | msgs = append(msgs, "missing setting for email validation: email-domain or authenticated-emails-file required."+ 149 | "\n use email-domain=* to authorize all email addresses") 150 | } 151 | 152 | if o.OIDCIssuerURL != "" { 153 | // Configure discoverable provider data. 
154 | provider, err := oidc.NewProvider(context.Background(), o.OIDCIssuerURL) 155 | if err != nil { 156 | return err 157 | } 158 | o.oidcVerifier = provider.Verifier(&oidc.Config{ 159 | ClientID: o.ClientID, 160 | }) 161 | o.LoginURL = provider.Endpoint().AuthURL 162 | o.RedeemURL = provider.Endpoint().TokenURL 163 | if o.Scope == "" { 164 | o.Scope = "openid email profile" 165 | } 166 | } 167 | 168 | o.redirectURL, msgs = parseURL(o.RedirectURL, "redirect", msgs) 169 | 170 | for _, u := range o.Upstreams { 171 | upstreamURL, err := url.Parse(u) 172 | if err != nil { 173 | msgs = append(msgs, fmt.Sprintf("error parsing upstream: %s", err)) 174 | } else { 175 | if upstreamURL.Path == "" { 176 | upstreamURL.Path = "/" 177 | } 178 | o.proxyURLs = append(o.proxyURLs, upstreamURL) 179 | } 180 | } 181 | 182 | for _, u := range o.SkipAuthRegex { 183 | CompiledRegex, err := regexp.Compile(u) 184 | if err != nil { 185 | msgs = append(msgs, fmt.Sprintf("error compiling regex=%q %s", u, err)) 186 | continue 187 | } 188 | o.CompiledRegex = append(o.CompiledRegex, CompiledRegex) 189 | } 190 | msgs = parseProviderInfo(o, msgs) 191 | 192 | if o.PassAccessToken || (o.CookieRefresh != time.Duration(0)) { 193 | valid_cookie_secret_size := false 194 | for _, i := range []int{16, 24, 32} { 195 | if len(secretBytes(o.CookieSecret)) == i { 196 | valid_cookie_secret_size = true 197 | } 198 | } 199 | var decoded bool 200 | if string(secretBytes(o.CookieSecret)) != o.CookieSecret { 201 | decoded = true 202 | } 203 | if valid_cookie_secret_size == false { 204 | var suffix string 205 | if decoded { 206 | suffix = fmt.Sprintf(" note: cookie secret was base64 decoded from %q", o.CookieSecret) 207 | } 208 | msgs = append(msgs, fmt.Sprintf( 209 | "cookie_secret must be 16, 24, or 32 bytes "+ 210 | "to create an AES cipher when "+ 211 | "pass_access_token == true or "+ 212 | "cookie_refresh != 0, but is %d bytes.%s", 213 | len(secretBytes(o.CookieSecret)), suffix)) 214 | } 215 | } 216 | 217 | if 
o.CookieRefresh >= o.CookieExpire { 218 | msgs = append(msgs, fmt.Sprintf( 219 | "cookie_refresh (%s) must be less than "+ 220 | "cookie_expire (%s)", 221 | o.CookieRefresh.String(), 222 | o.CookieExpire.String())) 223 | } 224 | 225 | if len(o.GoogleGroups) > 0 || o.GoogleAdminEmail != "" || o.GoogleServiceAccountJSON != "" { 226 | if len(o.GoogleGroups) < 1 { 227 | msgs = append(msgs, "missing setting: google-group") 228 | } 229 | if o.GoogleAdminEmail == "" { 230 | msgs = append(msgs, "missing setting: google-admin-email") 231 | } 232 | if o.GoogleServiceAccountJSON == "" { 233 | msgs = append(msgs, "missing setting: google-service-account-json") 234 | } 235 | } 236 | 237 | msgs = parseSignatureKey(o, msgs) 238 | msgs = validateCookieName(o, msgs) 239 | 240 | if len(msgs) != 0 { 241 | return fmt.Errorf("Invalid configuration:\n %s", 242 | strings.Join(msgs, "\n ")) 243 | } 244 | return nil 245 | } 246 | 247 | func parseProviderInfo(o *Options, msgs []string) []string { 248 | p := &providers.ProviderData{ 249 | Scope: o.Scope, 250 | ClientID: o.ClientID, 251 | ClientSecret: o.ClientSecret, 252 | ApprovalPrompt: o.ApprovalPrompt, 253 | } 254 | p.LoginURL, msgs = parseURL(o.LoginURL, "login", msgs) 255 | p.RedeemURL, msgs = parseURL(o.RedeemURL, "redeem", msgs) 256 | p.ProfileURL, msgs = parseURL(o.ProfileURL, "profile", msgs) 257 | p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) 258 | p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) 259 | 260 | o.provider = providers.New(o.Provider, p) 261 | switch p := o.provider.(type) { 262 | case *providers.AzureProvider: 263 | p.Configure(o.AzureTenant) 264 | case *providers.GitHubProvider: 265 | p.SetOrgTeam(o.GitHubOrg, o.GitHubTeam) 266 | case *providers.GoogleProvider: 267 | if o.GoogleServiceAccountJSON != "" { 268 | file, err := os.Open(o.GoogleServiceAccountJSON) 269 | if err != nil { 270 | msgs = append(msgs, "invalid Google credentials file: "+o.GoogleServiceAccountJSON) 271 
| } else { 272 | p.SetGroupRestriction(o.GoogleGroups, o.GoogleAdminEmail, file) 273 | } 274 | } 275 | case *providers.OIDCProvider: 276 | if o.oidcVerifier == nil { 277 | msgs = append(msgs, "oidc provider requires an oidc issuer URL") 278 | } else { 279 | p.Verifier = o.oidcVerifier 280 | } 281 | } 282 | return msgs 283 | } 284 | 285 | func parseSignatureKey(o *Options, msgs []string) []string { 286 | if o.SignatureKey == "" { 287 | return msgs 288 | } 289 | 290 | components := strings.Split(o.SignatureKey, ":") 291 | if len(components) != 2 { 292 | return append(msgs, "invalid signature hash:key spec: "+ 293 | o.SignatureKey) 294 | } 295 | 296 | algorithm, secretKey := components[0], components[1] 297 | if hash, err := hmacauth.DigestNameToCryptoHash(algorithm); err != nil { 298 | return append(msgs, "unsupported signature hash algorithm: "+ 299 | o.SignatureKey) 300 | } else { 301 | o.signatureData = &SignatureData{hash, secretKey} 302 | } 303 | return msgs 304 | } 305 | 306 | func validateCookieName(o *Options, msgs []string) []string { 307 | cookie := &http.Cookie{Name: o.CookieName} 308 | if cookie.String() == "" { 309 | return append(msgs, fmt.Sprintf("invalid cookie name: %q", o.CookieName)) 310 | } 311 | return msgs 312 | } 313 | 314 | func addPadding(secret string) string { 315 | padding := len(secret) % 4 316 | switch padding { 317 | case 1: 318 | return secret + "===" 319 | case 2: 320 | return secret + "==" 321 | case 3: 322 | return secret + "=" 323 | default: 324 | return secret 325 | } 326 | } 327 | 328 | // secretBytes attempts to base64 decode the secret, if that fails it treats the secret as binary 329 | func secretBytes(secret string) []byte { 330 | b, err := base64.URLEncoding.DecodeString(addPadding(secret)) 331 | if err == nil { 332 | return []byte(addPadding(string(b))) 333 | } 334 | return []byte(secret) 335 | } 336 | -------------------------------------------------------------------------------- /options_test.go: 
-------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto" 5 | "fmt" 6 | "net/url" 7 | "strings" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func testOptions() *Options { 15 | o := NewOptions() 16 | o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8080/") 17 | o.CookieSecret = "foobar" 18 | o.ClientID = "bazquux" 19 | o.ClientSecret = "xyzzyplugh" 20 | o.EmailDomains = []string{"*"} 21 | return o 22 | } 23 | 24 | func errorMsg(msgs []string) string { 25 | result := make([]string, 0) 26 | result = append(result, "Invalid configuration:") 27 | result = append(result, msgs...) 28 | return strings.Join(result, "\n ") 29 | } 30 | 31 | func TestNewOptions(t *testing.T) { 32 | o := NewOptions() 33 | o.EmailDomains = []string{"*"} 34 | err := o.Validate() 35 | assert.NotEqual(t, nil, err) 36 | 37 | expected := errorMsg([]string{ 38 | "missing setting: cookie-secret", 39 | "missing setting: client-id", 40 | "missing setting: client-secret"}) 41 | assert.Equal(t, expected, err.Error()) 42 | } 43 | 44 | func TestGoogleGroupOptions(t *testing.T) { 45 | o := testOptions() 46 | o.GoogleGroups = []string{"googlegroup"} 47 | err := o.Validate() 48 | assert.NotEqual(t, nil, err) 49 | 50 | expected := errorMsg([]string{ 51 | "missing setting: google-admin-email", 52 | "missing setting: google-service-account-json"}) 53 | assert.Equal(t, expected, err.Error()) 54 | } 55 | 56 | func TestGoogleGroupInvalidFile(t *testing.T) { 57 | o := testOptions() 58 | o.GoogleGroups = []string{"test_group"} 59 | o.GoogleAdminEmail = "admin@example.com" 60 | o.GoogleServiceAccountJSON = "file_doesnt_exist.json" 61 | err := o.Validate() 62 | assert.NotEqual(t, nil, err) 63 | 64 | expected := errorMsg([]string{ 65 | "invalid Google credentials file: file_doesnt_exist.json", 66 | }) 67 | assert.Equal(t, expected, err.Error()) 68 | } 69 | 70 | func TestInitializedOptions(t *testing.T) { 71 
| o := testOptions() 72 | assert.Equal(t, nil, o.Validate()) 73 | } 74 | 75 | // Note that it's not worth testing nonparseable URLs, since url.Parse() 76 | // seems to parse damn near anything. 77 | func TestRedirectURL(t *testing.T) { 78 | o := testOptions() 79 | o.RedirectURL = "https://myhost.com/oauth2/callback" 80 | assert.Equal(t, nil, o.Validate()) 81 | expected := &url.URL{ 82 | Scheme: "https", Host: "myhost.com", Path: "/oauth2/callback"} 83 | assert.Equal(t, expected, o.redirectURL) 84 | } 85 | 86 | func TestProxyURLs(t *testing.T) { 87 | o := testOptions() 88 | o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8081") 89 | assert.Equal(t, nil, o.Validate()) 90 | expected := []*url.URL{ 91 | &url.URL{Scheme: "http", Host: "127.0.0.1:8080", Path: "/"}, 92 | // note the '/' was added 93 | &url.URL{Scheme: "http", Host: "127.0.0.1:8081", Path: "/"}, 94 | } 95 | assert.Equal(t, expected, o.proxyURLs) 96 | } 97 | 98 | func TestProxyURLsError(t *testing.T) { 99 | o := testOptions() 100 | o.Upstreams = append(o.Upstreams, "127.0.0.1:8081") 101 | err := o.Validate() 102 | assert.NotEqual(t, nil, err) 103 | 104 | expected := errorMsg([]string{ 105 | "error parsing upstream: parse 127.0.0.1:8081: " + 106 | "first path segment in URL cannot contain colon"}) 107 | assert.Equal(t, expected, err.Error()) 108 | } 109 | 110 | func TestCompiledRegex(t *testing.T) { 111 | o := testOptions() 112 | regexps := []string{"/foo/.*", "/ba[rz]/quux"} 113 | o.SkipAuthRegex = regexps 114 | assert.Equal(t, nil, o.Validate()) 115 | actual := make([]string, 0) 116 | for _, regex := range o.CompiledRegex { 117 | actual = append(actual, regex.String()) 118 | } 119 | assert.Equal(t, regexps, actual) 120 | } 121 | 122 | func TestCompiledRegexError(t *testing.T) { 123 | o := testOptions() 124 | o.SkipAuthRegex = []string{"(foobaz", "barquux)"} 125 | err := o.Validate() 126 | assert.NotEqual(t, nil, err) 127 | 128 | expected := errorMsg([]string{ 129 | "error compiling regex=\"(foobaz\" 
error parsing regexp: " + 130 | "missing closing ): `(foobaz`", 131 | "error compiling regex=\"barquux)\" error parsing regexp: " + 132 | "unexpected ): `barquux)`"}) 133 | assert.Equal(t, expected, err.Error()) 134 | 135 | o.SkipAuthRegex = []string{"foobaz", "barquux)"} 136 | err = o.Validate() 137 | assert.NotEqual(t, nil, err) 138 | 139 | expected = errorMsg([]string{ 140 | "error compiling regex=\"barquux)\" error parsing regexp: " + 141 | "unexpected ): `barquux)`"}) 142 | assert.Equal(t, expected, err.Error()) 143 | } 144 | 145 | func TestDefaultProviderApiSettings(t *testing.T) { 146 | o := testOptions() 147 | assert.Equal(t, nil, o.Validate()) 148 | p := o.provider.Data() 149 | assert.Equal(t, "https://accounts.google.com/o/oauth2/auth?access_type=offline", 150 | p.LoginURL.String()) 151 | assert.Equal(t, "https://www.googleapis.com/oauth2/v3/token", 152 | p.RedeemURL.String()) 153 | assert.Equal(t, "", p.ProfileURL.String()) 154 | assert.Equal(t, "profile email", p.Scope) 155 | } 156 | 157 | func TestPassAccessTokenRequiresSpecificCookieSecretLengths(t *testing.T) { 158 | o := testOptions() 159 | assert.Equal(t, nil, o.Validate()) 160 | 161 | assert.Equal(t, false, o.PassAccessToken) 162 | o.PassAccessToken = true 163 | o.CookieSecret = "cookie of invalid length-" 164 | assert.NotEqual(t, nil, o.Validate()) 165 | 166 | o.PassAccessToken = false 167 | o.CookieRefresh = time.Duration(24) * time.Hour 168 | assert.NotEqual(t, nil, o.Validate()) 169 | 170 | o.CookieSecret = "16 bytes AES-128" 171 | assert.Equal(t, nil, o.Validate()) 172 | 173 | o.CookieSecret = "24 byte secret AES-192--" 174 | assert.Equal(t, nil, o.Validate()) 175 | 176 | o.CookieSecret = "32 byte secret for AES-256------" 177 | assert.Equal(t, nil, o.Validate()) 178 | } 179 | 180 | func TestCookieRefreshMustBeLessThanCookieExpire(t *testing.T) { 181 | o := testOptions() 182 | assert.Equal(t, nil, o.Validate()) 183 | 184 | o.CookieSecret = "0123456789abcdefabcd" 185 | o.CookieRefresh = 
o.CookieExpire 186 | assert.NotEqual(t, nil, o.Validate()) 187 | 188 | o.CookieRefresh -= time.Duration(1) 189 | assert.Equal(t, nil, o.Validate()) 190 | } 191 | 192 | func TestBase64CookieSecret(t *testing.T) { 193 | o := testOptions() 194 | assert.Equal(t, nil, o.Validate()) 195 | 196 | // 32 byte, base64 (urlsafe) encoded key 197 | o.CookieSecret = "yHBw2lh2Cvo6aI_jn_qMTr-pRAjtq0nzVgDJNb36jgQ=" 198 | assert.Equal(t, nil, o.Validate()) 199 | 200 | // 32 byte, base64 (urlsafe) encoded key, w/o padding 201 | o.CookieSecret = "yHBw2lh2Cvo6aI_jn_qMTr-pRAjtq0nzVgDJNb36jgQ" 202 | assert.Equal(t, nil, o.Validate()) 203 | 204 | // 24 byte, base64 (urlsafe) encoded key 205 | o.CookieSecret = "Kp33Gj-GQmYtz4zZUyUDdqQKx5_Hgkv3" 206 | assert.Equal(t, nil, o.Validate()) 207 | 208 | // 16 byte, base64 (urlsafe) encoded key 209 | o.CookieSecret = "LFEqZYvYUwKwzn0tEuTpLA==" 210 | assert.Equal(t, nil, o.Validate()) 211 | 212 | // 16 byte, base64 (urlsafe) encoded key, w/o padding 213 | o.CookieSecret = "LFEqZYvYUwKwzn0tEuTpLA" 214 | assert.Equal(t, nil, o.Validate()) 215 | } 216 | 217 | func TestValidateSignatureKey(t *testing.T) { 218 | o := testOptions() 219 | o.SignatureKey = "sha1:secret" 220 | assert.Equal(t, nil, o.Validate()) 221 | assert.Equal(t, o.signatureData.hash, crypto.SHA1) 222 | assert.Equal(t, o.signatureData.key, "secret") 223 | } 224 | 225 | func TestValidateSignatureKeyInvalidSpec(t *testing.T) { 226 | o := testOptions() 227 | o.SignatureKey = "invalid spec" 228 | err := o.Validate() 229 | assert.Equal(t, err.Error(), "Invalid configuration:\n"+ 230 | " invalid signature hash:key spec: "+o.SignatureKey) 231 | } 232 | 233 | func TestValidateSignatureKeyUnsupportedAlgorithm(t *testing.T) { 234 | o := testOptions() 235 | o.SignatureKey = "unsupported:default secret" 236 | err := o.Validate() 237 | assert.Equal(t, err.Error(), "Invalid configuration:\n"+ 238 | " unsupported signature hash algorithm: "+o.SignatureKey) 239 | } 240 | 241 | func TestValidateCookie(t 
*testing.T) { 242 | o := testOptions() 243 | o.CookieName = "_valid_cookie_name" 244 | assert.Equal(t, nil, o.Validate()) 245 | } 246 | 247 | func TestValidateCookieBadName(t *testing.T) { 248 | o := testOptions() 249 | o.CookieName = "_bad_cookie_name{}" 250 | err := o.Validate() 251 | assert.Equal(t, err.Error(), "Invalid configuration:\n"+ 252 | fmt.Sprintf(" invalid cookie name: %q", o.CookieName)) 253 | } 254 | -------------------------------------------------------------------------------- /providers/azure.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "github.com/bitly/go-simplejson" 7 | "github.com/bitly/oauth2_proxy/api" 8 | "log" 9 | "net/http" 10 | "net/url" 11 | ) 12 | 13 | type AzureProvider struct { 14 | *ProviderData 15 | Tenant string 16 | } 17 | 18 | func NewAzureProvider(p *ProviderData) *AzureProvider { 19 | p.ProviderName = "Azure" 20 | 21 | if p.ProfileURL == nil || p.ProfileURL.String() == "" { 22 | p.ProfileURL = &url.URL{ 23 | Scheme: "https", 24 | Host: "graph.windows.net", 25 | Path: "/me", 26 | RawQuery: "api-version=1.6", 27 | } 28 | } 29 | if p.ProtectedResource == nil || p.ProtectedResource.String() == "" { 30 | p.ProtectedResource = &url.URL{ 31 | Scheme: "https", 32 | Host: "graph.windows.net", 33 | } 34 | } 35 | if p.Scope == "" { 36 | p.Scope = "openid" 37 | } 38 | 39 | return &AzureProvider{ProviderData: p} 40 | } 41 | 42 | func (p *AzureProvider) Configure(tenant string) { 43 | p.Tenant = tenant 44 | if tenant == "" { 45 | p.Tenant = "common" 46 | } 47 | 48 | if p.LoginURL == nil || p.LoginURL.String() == "" { 49 | p.LoginURL = &url.URL{ 50 | Scheme: "https", 51 | Host: "login.microsoftonline.com", 52 | Path: "/" + p.Tenant + "/oauth2/authorize"} 53 | } 54 | if p.RedeemURL == nil || p.RedeemURL.String() == "" { 55 | p.RedeemURL = &url.URL{ 56 | Scheme: "https", 57 | Host: "login.microsoftonline.com", 58 | Path: "/" + p.Tenant + 
"/oauth2/token", 59 | } 60 | } 61 | } 62 | 63 | func getAzureHeader(access_token string) http.Header { 64 | header := make(http.Header) 65 | header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) 66 | return header 67 | } 68 | 69 | func getEmailFromJSON(json *simplejson.Json) (string, error) { 70 | var email string 71 | var err error 72 | 73 | email, err = json.Get("mail").String() 74 | 75 | if err != nil || email == "" { 76 | otherMails, otherMailsErr := json.Get("otherMails").Array() 77 | if len(otherMails) > 0 { 78 | email = otherMails[0].(string) 79 | } 80 | err = otherMailsErr 81 | } 82 | 83 | return email, err 84 | } 85 | 86 | func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) { 87 | var email string 88 | var err error 89 | 90 | if s.AccessToken == "" { 91 | return "", errors.New("missing access token") 92 | } 93 | req, err := http.NewRequest("GET", p.ProfileURL.String(), nil) 94 | if err != nil { 95 | return "", err 96 | } 97 | req.Header = getAzureHeader(s.AccessToken) 98 | 99 | json, err := api.Request(req) 100 | 101 | if err != nil { 102 | return "", err 103 | } 104 | 105 | email, err = getEmailFromJSON(json) 106 | 107 | if err == nil && email != "" { 108 | return email, err 109 | } 110 | 111 | email, err = json.Get("userPrincipalName").String() 112 | 113 | if err != nil { 114 | log.Printf("failed making request %s", err) 115 | return "", err 116 | } 117 | 118 | if email == "" { 119 | log.Printf("failed to get email address") 120 | return "", err 121 | } 122 | 123 | return email, err 124 | } 125 | -------------------------------------------------------------------------------- /providers/azure_test.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "net/url" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func testAzureProvider(hostname string) *AzureProvider { 13 | p := 
NewAzureProvider( 14 | &ProviderData{ 15 | ProviderName: "", 16 | LoginURL: &url.URL{}, 17 | RedeemURL: &url.URL{}, 18 | ProfileURL: &url.URL{}, 19 | ValidateURL: &url.URL{}, 20 | ProtectedResource: &url.URL{}, 21 | Scope: ""}) 22 | if hostname != "" { 23 | updateURL(p.Data().LoginURL, hostname) 24 | updateURL(p.Data().RedeemURL, hostname) 25 | updateURL(p.Data().ProfileURL, hostname) 26 | updateURL(p.Data().ValidateURL, hostname) 27 | updateURL(p.Data().ProtectedResource, hostname) 28 | } 29 | return p 30 | } 31 | 32 | func TestAzureProviderDefaults(t *testing.T) { 33 | p := testAzureProvider("") 34 | assert.NotEqual(t, nil, p) 35 | p.Configure("") 36 | assert.Equal(t, "Azure", p.Data().ProviderName) 37 | assert.Equal(t, "common", p.Tenant) 38 | assert.Equal(t, "https://login.microsoftonline.com/common/oauth2/authorize", 39 | p.Data().LoginURL.String()) 40 | assert.Equal(t, "https://login.microsoftonline.com/common/oauth2/token", 41 | p.Data().RedeemURL.String()) 42 | assert.Equal(t, "https://graph.windows.net/me?api-version=1.6", 43 | p.Data().ProfileURL.String()) 44 | assert.Equal(t, "https://graph.windows.net", 45 | p.Data().ProtectedResource.String()) 46 | assert.Equal(t, "", 47 | p.Data().ValidateURL.String()) 48 | assert.Equal(t, "openid", p.Data().Scope) 49 | } 50 | 51 | func TestAzureProviderOverrides(t *testing.T) { 52 | p := NewAzureProvider( 53 | &ProviderData{ 54 | LoginURL: &url.URL{ 55 | Scheme: "https", 56 | Host: "example.com", 57 | Path: "/oauth/auth"}, 58 | RedeemURL: &url.URL{ 59 | Scheme: "https", 60 | Host: "example.com", 61 | Path: "/oauth/token"}, 62 | ProfileURL: &url.URL{ 63 | Scheme: "https", 64 | Host: "example.com", 65 | Path: "/oauth/profile"}, 66 | ValidateURL: &url.URL{ 67 | Scheme: "https", 68 | Host: "example.com", 69 | Path: "/oauth/tokeninfo"}, 70 | ProtectedResource: &url.URL{ 71 | Scheme: "https", 72 | Host: "example.com"}, 73 | Scope: "profile"}) 74 | assert.NotEqual(t, nil, p) 75 | assert.Equal(t, "Azure", 
p.Data().ProviderName) 76 | assert.Equal(t, "https://example.com/oauth/auth", 77 | p.Data().LoginURL.String()) 78 | assert.Equal(t, "https://example.com/oauth/token", 79 | p.Data().RedeemURL.String()) 80 | assert.Equal(t, "https://example.com/oauth/profile", 81 | p.Data().ProfileURL.String()) 82 | assert.Equal(t, "https://example.com/oauth/tokeninfo", 83 | p.Data().ValidateURL.String()) 84 | assert.Equal(t, "https://example.com", 85 | p.Data().ProtectedResource.String()) 86 | assert.Equal(t, "profile", p.Data().Scope) 87 | } 88 | 89 | func TestAzureSetTenant(t *testing.T) { 90 | p := testAzureProvider("") 91 | p.Configure("example") 92 | assert.Equal(t, "Azure", p.Data().ProviderName) 93 | assert.Equal(t, "example", p.Tenant) 94 | assert.Equal(t, "https://login.microsoftonline.com/example/oauth2/authorize", 95 | p.Data().LoginURL.String()) 96 | assert.Equal(t, "https://login.microsoftonline.com/example/oauth2/token", 97 | p.Data().RedeemURL.String()) 98 | assert.Equal(t, "https://graph.windows.net/me?api-version=1.6", 99 | p.Data().ProfileURL.String()) 100 | assert.Equal(t, "https://graph.windows.net", 101 | p.Data().ProtectedResource.String()) 102 | assert.Equal(t, "", 103 | p.Data().ValidateURL.String()) 104 | assert.Equal(t, "openid", p.Data().Scope) 105 | } 106 | 107 | func testAzureBackend(payload string) *httptest.Server { 108 | path := "/me" 109 | query := "api-version=1.6" 110 | 111 | return httptest.NewServer(http.HandlerFunc( 112 | func(w http.ResponseWriter, r *http.Request) { 113 | url := r.URL 114 | if url.Path != path || url.RawQuery != query { 115 | w.WriteHeader(404) 116 | } else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" { 117 | w.WriteHeader(403) 118 | } else { 119 | w.WriteHeader(200) 120 | w.Write([]byte(payload)) 121 | } 122 | })) 123 | } 124 | 125 | func TestAzureProviderGetEmailAddress(t *testing.T) { 126 | b := testAzureBackend(`{ "mail": "user@windows.net" }`) 127 | defer b.Close() 128 | 129 | bURL, _ := 
url.Parse(b.URL) 130 | p := testAzureProvider(bURL.Host) 131 | 132 | session := &SessionState{AccessToken: "imaginary_access_token"} 133 | email, err := p.GetEmailAddress(session) 134 | assert.Equal(t, nil, err) 135 | assert.Equal(t, "user@windows.net", email) 136 | } 137 | 138 | func TestAzureProviderGetEmailAddressMailNull(t *testing.T) { 139 | b := testAzureBackend(`{ "mail": null, "otherMails": ["user@windows.net", "altuser@windows.net"] }`) 140 | defer b.Close() 141 | 142 | bURL, _ := url.Parse(b.URL) 143 | p := testAzureProvider(bURL.Host) 144 | 145 | session := &SessionState{AccessToken: "imaginary_access_token"} 146 | email, err := p.GetEmailAddress(session) 147 | assert.Equal(t, nil, err) 148 | assert.Equal(t, "user@windows.net", email) 149 | } 150 | 151 | func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) { 152 | b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "user@windows.net" }`) 153 | defer b.Close() 154 | 155 | bURL, _ := url.Parse(b.URL) 156 | p := testAzureProvider(bURL.Host) 157 | 158 | session := &SessionState{AccessToken: "imaginary_access_token"} 159 | email, err := p.GetEmailAddress(session) 160 | assert.Equal(t, nil, err) 161 | assert.Equal(t, "user@windows.net", email) 162 | } 163 | 164 | func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) { 165 | b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": null }`) 166 | defer b.Close() 167 | 168 | bURL, _ := url.Parse(b.URL) 169 | p := testAzureProvider(bURL.Host) 170 | 171 | session := &SessionState{AccessToken: "imaginary_access_token"} 172 | email, err := p.GetEmailAddress(session) 173 | assert.Equal(t, "type assertion to string failed", err.Error()) 174 | assert.Equal(t, "", email) 175 | } 176 | 177 | func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) { 178 | b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "" }`) 179 | defer b.Close() 180 | 181 | 
bURL, _ := url.Parse(b.URL) 182 | p := testAzureProvider(bURL.Host) 183 | 184 | session := &SessionState{AccessToken: "imaginary_access_token"} 185 | email, err := p.GetEmailAddress(session) 186 | assert.Equal(t, nil, err) 187 | assert.Equal(t, "", email) 188 | } 189 | 190 | func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) { 191 | b := testAzureBackend(`{ "mail": null, "otherMails": "", "userPrincipalName": null }`) 192 | defer b.Close() 193 | 194 | bURL, _ := url.Parse(b.URL) 195 | p := testAzureProvider(bURL.Host) 196 | 197 | session := &SessionState{AccessToken: "imaginary_access_token"} 198 | email, err := p.GetEmailAddress(session) 199 | assert.Equal(t, "type assertion to string failed", err.Error()) 200 | assert.Equal(t, "", email) 201 | } 202 | -------------------------------------------------------------------------------- /providers/facebook.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | "net/url" 8 | 9 | "github.com/bitly/oauth2_proxy/api" 10 | ) 11 | 12 | type FacebookProvider struct { 13 | *ProviderData 14 | } 15 | 16 | func NewFacebookProvider(p *ProviderData) *FacebookProvider { 17 | p.ProviderName = "Facebook" 18 | if p.LoginURL.String() == "" { 19 | p.LoginURL = &url.URL{Scheme: "https", 20 | Host: "www.facebook.com", 21 | Path: "/v2.5/dialog/oauth", 22 | // ?granted_scopes=true 23 | } 24 | } 25 | if p.RedeemURL.String() == "" { 26 | p.RedeemURL = &url.URL{Scheme: "https", 27 | Host: "graph.facebook.com", 28 | Path: "/v2.5/oauth/access_token", 29 | } 30 | } 31 | if p.ProfileURL.String() == "" { 32 | p.ProfileURL = &url.URL{Scheme: "https", 33 | Host: "graph.facebook.com", 34 | Path: "/v2.5/me", 35 | } 36 | } 37 | if p.ValidateURL.String() == "" { 38 | p.ValidateURL = p.ProfileURL 39 | } 40 | if p.Scope == "" { 41 | p.Scope = "public_profile email" 42 | } 43 | return &FacebookProvider{ProviderData: p} 44 | } 45 
| 46 | func getFacebookHeader(access_token string) http.Header { 47 | header := make(http.Header) 48 | header.Set("Accept", "application/json") 49 | header.Set("x-li-format", "json") 50 | header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) 51 | return header 52 | } 53 | 54 | func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) { 55 | if s.AccessToken == "" { 56 | return "", errors.New("missing access token") 57 | } 58 | req, err := http.NewRequest("GET", p.ProfileURL.String()+"?fields=name,email", nil) 59 | if err != nil { 60 | return "", err 61 | } 62 | req.Header = getFacebookHeader(s.AccessToken) 63 | 64 | type result struct { 65 | Email string 66 | } 67 | var r result 68 | err = api.RequestJson(req, &r) 69 | if err != nil { 70 | return "", err 71 | } 72 | if r.Email == "" { 73 | return "", errors.New("no email") 74 | } 75 | return r.Email, nil 76 | } 77 | 78 | func (p *FacebookProvider) ValidateSessionState(s *SessionState) bool { 79 | return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken)) 80 | } 81 | -------------------------------------------------------------------------------- /providers/github.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io/ioutil" 7 | "log" 8 | "net/http" 9 | "net/url" 10 | "path" 11 | "strconv" 12 | "strings" 13 | ) 14 | 15 | type GitHubProvider struct { 16 | *ProviderData 17 | Org string 18 | Team string 19 | } 20 | 21 | func NewGitHubProvider(p *ProviderData) *GitHubProvider { 22 | p.ProviderName = "GitHub" 23 | if p.LoginURL == nil || p.LoginURL.String() == "" { 24 | p.LoginURL = &url.URL{ 25 | Scheme: "https", 26 | Host: "github.com", 27 | Path: "/login/oauth/authorize", 28 | } 29 | } 30 | if p.RedeemURL == nil || p.RedeemURL.String() == "" { 31 | p.RedeemURL = &url.URL{ 32 | Scheme: "https", 33 | Host: "github.com", 34 | Path: "/login/oauth/access_token", 35 | } 36 
| } 37 | // ValidationURL is the API Base URL 38 | if p.ValidateURL == nil || p.ValidateURL.String() == "" { 39 | p.ValidateURL = &url.URL{ 40 | Scheme: "https", 41 | Host: "api.github.com", 42 | Path: "/", 43 | } 44 | } 45 | if p.Scope == "" { 46 | p.Scope = "user:email" 47 | } 48 | return &GitHubProvider{ProviderData: p} 49 | } 50 | func (p *GitHubProvider) SetOrgTeam(org, team string) { 51 | p.Org = org 52 | p.Team = team 53 | if org != "" || team != "" { 54 | p.Scope += " read:org" 55 | } 56 | } 57 | 58 | func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) { 59 | // https://developer.github.com/v3/orgs/#list-your-organizations 60 | 61 | var orgs []struct { 62 | Login string `json:"login"` 63 | } 64 | 65 | type orgsPage []struct { 66 | Login string `json:"login"` 67 | } 68 | 69 | pn := 1 70 | for { 71 | params := url.Values{ 72 | "limit": {"200"}, 73 | "page": {strconv.Itoa(pn)}, 74 | } 75 | 76 | endpoint := &url.URL{ 77 | Scheme: p.ValidateURL.Scheme, 78 | Host: p.ValidateURL.Host, 79 | Path: path.Join(p.ValidateURL.Path, "/user/orgs"), 80 | RawQuery: params.Encode(), 81 | } 82 | req, _ := http.NewRequest("GET", endpoint.String(), nil) 83 | req.Header.Set("Accept", "application/vnd.github.v3+json") 84 | req.Header.Set("Authorization", fmt.Sprintf("token %s", accessToken)) 85 | resp, err := http.DefaultClient.Do(req) 86 | if err != nil { 87 | return false, err 88 | } 89 | 90 | body, err := ioutil.ReadAll(resp.Body) 91 | resp.Body.Close() 92 | if err != nil { 93 | return false, err 94 | } 95 | if resp.StatusCode != 200 { 96 | return false, fmt.Errorf( 97 | "got %d from %q %s", resp.StatusCode, endpoint.String(), body) 98 | } 99 | 100 | var op orgsPage 101 | if err := json.Unmarshal(body, &op); err != nil { 102 | return false, err 103 | } 104 | if len(op) == 0 { 105 | break 106 | } 107 | 108 | orgs = append(orgs, op...) 
109 | pn += 1 110 | } 111 | 112 | var presentOrgs []string 113 | for _, org := range orgs { 114 | if p.Org == org.Login { 115 | log.Printf("Found Github Organization: %q", org.Login) 116 | return true, nil 117 | } 118 | presentOrgs = append(presentOrgs, org.Login) 119 | } 120 | 121 | log.Printf("Missing Organization:%q in %v", p.Org, presentOrgs) 122 | return false, nil 123 | } 124 | 125 | func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { 126 | // https://developer.github.com/v3/orgs/teams/#list-user-teams 127 | 128 | var teams []struct { 129 | Name string `json:"name"` 130 | Slug string `json:"slug"` 131 | Org struct { 132 | Login string `json:"login"` 133 | } `json:"organization"` 134 | } 135 | 136 | params := url.Values{ 137 | "limit": {"200"}, 138 | } 139 | 140 | endpoint := &url.URL{ 141 | Scheme: p.ValidateURL.Scheme, 142 | Host: p.ValidateURL.Host, 143 | Path: path.Join(p.ValidateURL.Path, "/user/teams"), 144 | RawQuery: params.Encode(), 145 | } 146 | req, _ := http.NewRequest("GET", endpoint.String(), nil) 147 | req.Header.Set("Accept", "application/vnd.github.v3+json") 148 | req.Header.Set("Authorization", fmt.Sprintf("token %s", accessToken)) 149 | resp, err := http.DefaultClient.Do(req) 150 | if err != nil { 151 | return false, err 152 | } 153 | 154 | body, err := ioutil.ReadAll(resp.Body) 155 | resp.Body.Close() 156 | if err != nil { 157 | return false, err 158 | } 159 | if resp.StatusCode != 200 { 160 | return false, fmt.Errorf( 161 | "got %d from %q %s", resp.StatusCode, endpoint.String(), body) 162 | } 163 | 164 | if err := json.Unmarshal(body, &teams); err != nil { 165 | return false, fmt.Errorf("%s unmarshaling %s", err, body) 166 | } 167 | 168 | var hasOrg bool 169 | presentOrgs := make(map[string]bool) 170 | var presentTeams []string 171 | for _, team := range teams { 172 | presentOrgs[team.Org.Login] = true 173 | if p.Org == team.Org.Login { 174 | hasOrg = true 175 | ts := strings.Split(p.Team, ",") 176 | for _, t := 
range ts { 177 | if t == team.Slug { 178 | log.Printf("Found Github Organization:%q Team:%q (Name:%q)", team.Org.Login, team.Slug, team.Name) 179 | return true, nil 180 | } 181 | } 182 | presentTeams = append(presentTeams, team.Slug) 183 | } 184 | } 185 | if hasOrg { 186 | log.Printf("Missing Team:%q from Org:%q in teams: %v", p.Team, p.Org, presentTeams) 187 | } else { 188 | var allOrgs []string 189 | for org, _ := range presentOrgs { 190 | allOrgs = append(allOrgs, org) 191 | } 192 | log.Printf("Missing Organization:%q in %#v", p.Org, allOrgs) 193 | } 194 | return false, nil 195 | } 196 | 197 | func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { 198 | 199 | var emails []struct { 200 | Email string `json:"email"` 201 | Primary bool `json:"primary"` 202 | } 203 | 204 | // if we require an Org or Team, check that first 205 | if p.Org != "" { 206 | if p.Team != "" { 207 | if ok, err := p.hasOrgAndTeam(s.AccessToken); err != nil || !ok { 208 | return "", err 209 | } 210 | } else { 211 | if ok, err := p.hasOrg(s.AccessToken); err != nil || !ok { 212 | return "", err 213 | } 214 | } 215 | } 216 | 217 | endpoint := &url.URL{ 218 | Scheme: p.ValidateURL.Scheme, 219 | Host: p.ValidateURL.Host, 220 | Path: path.Join(p.ValidateURL.Path, "/user/emails"), 221 | } 222 | req, _ := http.NewRequest("GET", endpoint.String(), nil) 223 | req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken)) 224 | resp, err := http.DefaultClient.Do(req) 225 | if err != nil { 226 | return "", err 227 | } 228 | body, err := ioutil.ReadAll(resp.Body) 229 | resp.Body.Close() 230 | if err != nil { 231 | return "", err 232 | } 233 | 234 | if resp.StatusCode != 200 { 235 | return "", fmt.Errorf("got %d from %q %s", 236 | resp.StatusCode, endpoint.String(), body) 237 | } 238 | 239 | log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) 240 | 241 | if err := json.Unmarshal(body, &emails); err != nil { 242 | return "", fmt.Errorf("%s unmarshaling 
%s", err, body) 243 | } 244 | 245 | for _, email := range emails { 246 | if email.Primary { 247 | return email.Email, nil 248 | } 249 | } 250 | 251 | return "", nil 252 | } 253 | 254 | func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) { 255 | var user struct { 256 | Login string `json:"login"` 257 | Email string `json:"email"` 258 | } 259 | 260 | endpoint := &url.URL{ 261 | Scheme: p.ValidateURL.Scheme, 262 | Host: p.ValidateURL.Host, 263 | Path: path.Join(p.ValidateURL.Path, "/user"), 264 | } 265 | 266 | req, err := http.NewRequest("GET", endpoint.String(), nil) 267 | if err != nil { 268 | return "", fmt.Errorf("could not create new GET request: %v", err) 269 | } 270 | 271 | req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken)) 272 | resp, err := http.DefaultClient.Do(req) 273 | if err != nil { 274 | return "", err 275 | } 276 | 277 | body, err := ioutil.ReadAll(resp.Body) 278 | defer resp.Body.Close() 279 | if err != nil { 280 | return "", err 281 | } 282 | 283 | if resp.StatusCode != 200 { 284 | return "", fmt.Errorf("got %d from %q %s", 285 | resp.StatusCode, endpoint.String(), body) 286 | } 287 | 288 | log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) 289 | 290 | if err := json.Unmarshal(body, &user); err != nil { 291 | return "", fmt.Errorf("%s unmarshaling %s", err, body) 292 | } 293 | 294 | return user.Login, nil 295 | } 296 | -------------------------------------------------------------------------------- /providers/github_test.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "net/url" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func testGitHubProvider(hostname string) *GitHubProvider { 13 | p := NewGitHubProvider( 14 | &ProviderData{ 15 | ProviderName: "", 16 | LoginURL: &url.URL{}, 17 | RedeemURL: &url.URL{}, 18 | ProfileURL: &url.URL{}, 19 | 
ValidateURL: &url.URL{}, 20 | Scope: ""}) 21 | if hostname != "" { 22 | updateURL(p.Data().LoginURL, hostname) 23 | updateURL(p.Data().RedeemURL, hostname) 24 | updateURL(p.Data().ProfileURL, hostname) 25 | updateURL(p.Data().ValidateURL, hostname) 26 | } 27 | return p 28 | } 29 | 30 | func testGitHubBackend(payload []string) *httptest.Server { 31 | pathToQueryMap := map[string][]string{ 32 | "/user": []string{""}, 33 | "/user/emails": []string{""}, 34 | "/user/orgs": []string{"limit=200&page=1", "limit=200&page=2", "limit=200&page=3"}, 35 | } 36 | 37 | return httptest.NewServer(http.HandlerFunc( 38 | func(w http.ResponseWriter, r *http.Request) { 39 | url := r.URL 40 | query, ok := pathToQueryMap[url.Path] 41 | validQuery := false 42 | index := 0 43 | for i, q := range query { 44 | if q == url.RawQuery { 45 | validQuery = true 46 | index = i 47 | } 48 | } 49 | if !ok { 50 | w.WriteHeader(404) 51 | } else if !validQuery { 52 | w.WriteHeader(404) 53 | } else { 54 | w.WriteHeader(200) 55 | w.Write([]byte(payload[index])) 56 | } 57 | })) 58 | } 59 | 60 | func TestGitHubProviderDefaults(t *testing.T) { 61 | p := testGitHubProvider("") 62 | assert.NotEqual(t, nil, p) 63 | assert.Equal(t, "GitHub", p.Data().ProviderName) 64 | assert.Equal(t, "https://github.com/login/oauth/authorize", 65 | p.Data().LoginURL.String()) 66 | assert.Equal(t, "https://github.com/login/oauth/access_token", 67 | p.Data().RedeemURL.String()) 68 | assert.Equal(t, "https://api.github.com/", 69 | p.Data().ValidateURL.String()) 70 | assert.Equal(t, "user:email", p.Data().Scope) 71 | } 72 | 73 | func TestGitHubProviderOverrides(t *testing.T) { 74 | p := NewGitHubProvider( 75 | &ProviderData{ 76 | LoginURL: &url.URL{ 77 | Scheme: "https", 78 | Host: "example.com", 79 | Path: "/login/oauth/authorize"}, 80 | RedeemURL: &url.URL{ 81 | Scheme: "https", 82 | Host: "example.com", 83 | Path: "/login/oauth/access_token"}, 84 | ValidateURL: &url.URL{ 85 | Scheme: "https", 86 | Host: "api.example.com", 87 | 
Path: "/"}, 88 | Scope: "profile"}) 89 | assert.NotEqual(t, nil, p) 90 | assert.Equal(t, "GitHub", p.Data().ProviderName) 91 | assert.Equal(t, "https://example.com/login/oauth/authorize", 92 | p.Data().LoginURL.String()) 93 | assert.Equal(t, "https://example.com/login/oauth/access_token", 94 | p.Data().RedeemURL.String()) 95 | assert.Equal(t, "https://api.example.com/", 96 | p.Data().ValidateURL.String()) 97 | assert.Equal(t, "profile", p.Data().Scope) 98 | } 99 | 100 | func TestGitHubProviderGetEmailAddress(t *testing.T) { 101 | b := testGitHubBackend([]string{`[ {"email": "michael.bland@gsa.gov", "primary": true} ]`}) 102 | defer b.Close() 103 | 104 | bURL, _ := url.Parse(b.URL) 105 | p := testGitHubProvider(bURL.Host) 106 | 107 | session := &SessionState{AccessToken: "imaginary_access_token"} 108 | email, err := p.GetEmailAddress(session) 109 | assert.Equal(t, nil, err) 110 | assert.Equal(t, "michael.bland@gsa.gov", email) 111 | } 112 | 113 | func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) { 114 | b := testGitHubBackend([]string{ 115 | `[ {"email": "michael.bland@gsa.gov", "primary": true, "login":"testorg"} ]`, 116 | `[ {"email": "michael.bland1@gsa.gov", "primary": true, "login":"testorg1"} ]`, 117 | `[ ]`, 118 | }) 119 | defer b.Close() 120 | 121 | bURL, _ := url.Parse(b.URL) 122 | p := testGitHubProvider(bURL.Host) 123 | p.Org = "testorg1" 124 | 125 | session := &SessionState{AccessToken: "imaginary_access_token"} 126 | email, err := p.GetEmailAddress(session) 127 | assert.Equal(t, nil, err) 128 | assert.Equal(t, "michael.bland@gsa.gov", email) 129 | } 130 | 131 | // Note that trying to trigger the "failed building request" case is not 132 | // practical, since the only way it can fail is if the URL fails to parse. 
133 | func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) { 134 | b := testGitHubBackend([]string{"unused payload"}) 135 | defer b.Close() 136 | 137 | bURL, _ := url.Parse(b.URL) 138 | p := testGitHubProvider(bURL.Host) 139 | 140 | // We'll trigger a request failure by using an unexpected access 141 | // token. Alternatively, we could allow the parsing of the payload as 142 | // JSON to fail. 143 | session := &SessionState{AccessToken: "unexpected_access_token"} 144 | email, err := p.GetEmailAddress(session) 145 | assert.NotEqual(t, nil, err) 146 | assert.Equal(t, "", email) 147 | } 148 | 149 | func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { 150 | b := testGitHubBackend([]string{"{\"foo\": \"bar\"}"}) 151 | defer b.Close() 152 | 153 | bURL, _ := url.Parse(b.URL) 154 | p := testGitHubProvider(bURL.Host) 155 | 156 | session := &SessionState{AccessToken: "imaginary_access_token"} 157 | email, err := p.GetEmailAddress(session) 158 | assert.NotEqual(t, nil, err) 159 | assert.Equal(t, "", email) 160 | } 161 | 162 | func TestGitHubProviderGetUserName(t *testing.T) { 163 | b := testGitHubBackend([]string{`{"email": "michael.bland@gsa.gov", "login": "mbland"}`}) 164 | defer b.Close() 165 | 166 | bURL, _ := url.Parse(b.URL) 167 | p := testGitHubProvider(bURL.Host) 168 | 169 | session := &SessionState{AccessToken: "imaginary_access_token"} 170 | email, err := p.GetUserName(session) 171 | assert.Equal(t, nil, err) 172 | assert.Equal(t, "mbland", email) 173 | } 174 | -------------------------------------------------------------------------------- /providers/gitlab.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | "net/url" 7 | 8 | "github.com/bitly/oauth2_proxy/api" 9 | ) 10 | 11 | type GitLabProvider struct { 12 | *ProviderData 13 | } 14 | 15 | func NewGitLabProvider(p *ProviderData) *GitLabProvider { 16 | p.ProviderName = "GitLab" 
17 | if p.LoginURL == nil || p.LoginURL.String() == "" { 18 | p.LoginURL = &url.URL{ 19 | Scheme: "https", 20 | Host: "gitlab.com", 21 | Path: "/oauth/authorize", 22 | } 23 | } 24 | if p.RedeemURL == nil || p.RedeemURL.String() == "" { 25 | p.RedeemURL = &url.URL{ 26 | Scheme: "https", 27 | Host: "gitlab.com", 28 | Path: "/oauth/token", 29 | } 30 | } 31 | if p.ValidateURL == nil || p.ValidateURL.String() == "" { 32 | p.ValidateURL = &url.URL{ 33 | Scheme: "https", 34 | Host: "gitlab.com", 35 | Path: "/api/v4/user", 36 | } 37 | } 38 | if p.Scope == "" { 39 | p.Scope = "read_user" 40 | } 41 | return &GitLabProvider{ProviderData: p} 42 | } 43 | 44 | func (p *GitLabProvider) GetEmailAddress(s *SessionState) (string, error) { 45 | 46 | req, err := http.NewRequest("GET", 47 | p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) 48 | if err != nil { 49 | log.Printf("failed building request %s", err) 50 | return "", err 51 | } 52 | json, err := api.Request(req) 53 | if err != nil { 54 | log.Printf("failed making request %s", err) 55 | return "", err 56 | } 57 | return json.Get("email").String() 58 | } 59 | -------------------------------------------------------------------------------- /providers/gitlab_test.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "net/url" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func testGitLabProvider(hostname string) *GitLabProvider { 13 | p := NewGitLabProvider( 14 | &ProviderData{ 15 | ProviderName: "", 16 | LoginURL: &url.URL{}, 17 | RedeemURL: &url.URL{}, 18 | ProfileURL: &url.URL{}, 19 | ValidateURL: &url.URL{}, 20 | Scope: ""}) 21 | if hostname != "" { 22 | updateURL(p.Data().LoginURL, hostname) 23 | updateURL(p.Data().RedeemURL, hostname) 24 | updateURL(p.Data().ProfileURL, hostname) 25 | updateURL(p.Data().ValidateURL, hostname) 26 | } 27 | return p 28 | } 29 | 30 | func 
testGitLabBackend(payload string) *httptest.Server { 31 | path := "/api/v4/user" 32 | query := "access_token=imaginary_access_token" 33 | 34 | return httptest.NewServer(http.HandlerFunc( 35 | func(w http.ResponseWriter, r *http.Request) { 36 | url := r.URL 37 | if url.Path != path || url.RawQuery != query { 38 | w.WriteHeader(404) 39 | } else { 40 | w.WriteHeader(200) 41 | w.Write([]byte(payload)) 42 | } 43 | })) 44 | } 45 | 46 | func TestGitLabProviderDefaults(t *testing.T) { 47 | p := testGitLabProvider("") 48 | assert.NotEqual(t, nil, p) 49 | assert.Equal(t, "GitLab", p.Data().ProviderName) 50 | assert.Equal(t, "https://gitlab.com/oauth/authorize", 51 | p.Data().LoginURL.String()) 52 | assert.Equal(t, "https://gitlab.com/oauth/token", 53 | p.Data().RedeemURL.String()) 54 | assert.Equal(t, "https://gitlab.com/api/v4/user", 55 | p.Data().ValidateURL.String()) 56 | assert.Equal(t, "read_user", p.Data().Scope) 57 | } 58 | 59 | func TestGitLabProviderOverrides(t *testing.T) { 60 | p := NewGitLabProvider( 61 | &ProviderData{ 62 | LoginURL: &url.URL{ 63 | Scheme: "https", 64 | Host: "example.com", 65 | Path: "/oauth/auth"}, 66 | RedeemURL: &url.URL{ 67 | Scheme: "https", 68 | Host: "example.com", 69 | Path: "/oauth/token"}, 70 | ValidateURL: &url.URL{ 71 | Scheme: "https", 72 | Host: "example.com", 73 | Path: "/api/v4/user"}, 74 | Scope: "profile"}) 75 | assert.NotEqual(t, nil, p) 76 | assert.Equal(t, "GitLab", p.Data().ProviderName) 77 | assert.Equal(t, "https://example.com/oauth/auth", 78 | p.Data().LoginURL.String()) 79 | assert.Equal(t, "https://example.com/oauth/token", 80 | p.Data().RedeemURL.String()) 81 | assert.Equal(t, "https://example.com/api/v4/user", 82 | p.Data().ValidateURL.String()) 83 | assert.Equal(t, "profile", p.Data().Scope) 84 | } 85 | 86 | func TestGitLabProviderGetEmailAddress(t *testing.T) { 87 | b := testGitLabBackend("{\"email\": \"michael.bland@gsa.gov\"}") 88 | defer b.Close() 89 | 90 | b_url, _ := url.Parse(b.URL) 91 | p := 
testGitLabProvider(b_url.Host) 92 | 93 | session := &SessionState{AccessToken: "imaginary_access_token"} 94 | email, err := p.GetEmailAddress(session) 95 | assert.Equal(t, nil, err) 96 | assert.Equal(t, "michael.bland@gsa.gov", email) 97 | } 98 | 99 | // Note that trying to trigger the "failed building request" case is not 100 | // practical, since the only way it can fail is if the URL fails to parse. 101 | func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) { 102 | b := testGitLabBackend("unused payload") 103 | defer b.Close() 104 | 105 | b_url, _ := url.Parse(b.URL) 106 | p := testGitLabProvider(b_url.Host) 107 | 108 | // We'll trigger a request failure by using an unexpected access 109 | // token. Alternatively, we could allow the parsing of the payload as 110 | // JSON to fail. 111 | session := &SessionState{AccessToken: "unexpected_access_token"} 112 | email, err := p.GetEmailAddress(session) 113 | assert.NotEqual(t, nil, err) 114 | assert.Equal(t, "", email) 115 | } 116 | 117 | func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { 118 | b := testGitLabBackend("{\"foo\": \"bar\"}") 119 | defer b.Close() 120 | 121 | b_url, _ := url.Parse(b.URL) 122 | p := testGitLabProvider(b_url.Host) 123 | 124 | session := &SessionState{AccessToken: "imaginary_access_token"} 125 | email, err := p.GetEmailAddress(session) 126 | assert.NotEqual(t, nil, err) 127 | assert.Equal(t, "", email) 128 | } 129 | -------------------------------------------------------------------------------- /providers/google.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "bytes" 5 | "encoding/base64" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "io/ioutil" 11 | "log" 12 | "net/http" 13 | "net/url" 14 | "strings" 15 | "time" 16 | 17 | "golang.org/x/oauth2" 18 | "golang.org/x/oauth2/google" 19 | "google.golang.org/api/admin/directory/v1" 20 | 
"google.golang.org/api/googleapi" 21 | ) 22 | 23 | type GoogleProvider struct { 24 | *ProviderData 25 | RedeemRefreshURL *url.URL 26 | // GroupValidator is a function that determines if the passed email is in 27 | // the configured Google group. 28 | GroupValidator func(string) bool 29 | } 30 | 31 | func NewGoogleProvider(p *ProviderData) *GoogleProvider { 32 | p.ProviderName = "Google" 33 | if p.LoginURL.String() == "" { 34 | p.LoginURL = &url.URL{Scheme: "https", 35 | Host: "accounts.google.com", 36 | Path: "/o/oauth2/auth", 37 | // to get a refresh token. see https://developers.google.com/identity/protocols/OAuth2WebServer#offline 38 | RawQuery: "access_type=offline", 39 | } 40 | } 41 | if p.RedeemURL.String() == "" { 42 | p.RedeemURL = &url.URL{Scheme: "https", 43 | Host: "www.googleapis.com", 44 | Path: "/oauth2/v3/token"} 45 | } 46 | if p.ValidateURL.String() == "" { 47 | p.ValidateURL = &url.URL{Scheme: "https", 48 | Host: "www.googleapis.com", 49 | Path: "/oauth2/v1/tokeninfo"} 50 | } 51 | if p.Scope == "" { 52 | p.Scope = "profile email" 53 | } 54 | 55 | return &GoogleProvider{ 56 | ProviderData: p, 57 | // Set a default GroupValidator to just always return valid (true), it will 58 | // be overwritten if we configured a Google group restriction. 
59 | GroupValidator: func(email string) bool { 60 | return true 61 | }, 62 | } 63 | } 64 | 65 | func emailFromIdToken(idToken string) (string, error) { 66 | 67 | // id_token is a base64 encode ID token payload 68 | // https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo 69 | jwt := strings.Split(idToken, ".") 70 | jwtData := strings.TrimSuffix(jwt[1], "=") 71 | b, err := base64.RawURLEncoding.DecodeString(jwtData) 72 | if err != nil { 73 | return "", err 74 | } 75 | 76 | var email struct { 77 | Email string `json:"email"` 78 | EmailVerified bool `json:"email_verified"` 79 | } 80 | err = json.Unmarshal(b, &email) 81 | if err != nil { 82 | return "", err 83 | } 84 | if email.Email == "" { 85 | return "", errors.New("missing email") 86 | } 87 | if !email.EmailVerified { 88 | return "", fmt.Errorf("email %s not listed as verified", email.Email) 89 | } 90 | return email.Email, nil 91 | } 92 | 93 | func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { 94 | if code == "" { 95 | err = errors.New("missing code") 96 | return 97 | } 98 | 99 | params := url.Values{} 100 | params.Add("redirect_uri", redirectURL) 101 | params.Add("client_id", p.ClientID) 102 | params.Add("client_secret", p.ClientSecret) 103 | params.Add("code", code) 104 | params.Add("grant_type", "authorization_code") 105 | var req *http.Request 106 | req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) 107 | if err != nil { 108 | return 109 | } 110 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 111 | 112 | resp, err := http.DefaultClient.Do(req) 113 | if err != nil { 114 | return 115 | } 116 | var body []byte 117 | body, err = ioutil.ReadAll(resp.Body) 118 | resp.Body.Close() 119 | if err != nil { 120 | return 121 | } 122 | 123 | if resp.StatusCode != 200 { 124 | err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) 125 | return 126 | } 127 | 128 | var 
jsonResponse struct { 129 | AccessToken string `json:"access_token"` 130 | RefreshToken string `json:"refresh_token"` 131 | ExpiresIn int64 `json:"expires_in"` 132 | IdToken string `json:"id_token"` 133 | } 134 | err = json.Unmarshal(body, &jsonResponse) 135 | if err != nil { 136 | return 137 | } 138 | var email string 139 | email, err = emailFromIdToken(jsonResponse.IdToken) 140 | if err != nil { 141 | return 142 | } 143 | s = &SessionState{ 144 | AccessToken: jsonResponse.AccessToken, 145 | ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), 146 | RefreshToken: jsonResponse.RefreshToken, 147 | Email: email, 148 | } 149 | return 150 | } 151 | 152 | // SetGroupRestriction configures the GoogleProvider to restrict access to the 153 | // specified group(s). AdminEmail has to be an administrative email on the domain that is 154 | // checked. CredentialsFile is the path to a json file containing a Google service 155 | // account credentials. 
156 | func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) { 157 | adminService := getAdminService(adminEmail, credentialsReader) 158 | p.GroupValidator = func(email string) bool { 159 | return userInGroup(adminService, groups, email) 160 | } 161 | } 162 | 163 | func getAdminService(adminEmail string, credentialsReader io.Reader) *admin.Service { 164 | data, err := ioutil.ReadAll(credentialsReader) 165 | if err != nil { 166 | log.Fatal("can't read Google credentials file:", err) 167 | } 168 | conf, err := google.JWTConfigFromJSON(data, admin.AdminDirectoryUserReadonlyScope, admin.AdminDirectoryGroupReadonlyScope) 169 | if err != nil { 170 | log.Fatal("can't load Google credentials file:", err) 171 | } 172 | conf.Subject = adminEmail 173 | 174 | client := conf.Client(oauth2.NoContext) 175 | adminService, err := admin.New(client) 176 | if err != nil { 177 | log.Fatal(err) 178 | } 179 | return adminService 180 | } 181 | 182 | func userInGroup(service *admin.Service, groups []string, email string) bool { 183 | user, err := fetchUser(service, email) 184 | if err != nil { 185 | log.Printf("error fetching user: %v", err) 186 | return false 187 | } 188 | id := user.Id 189 | custID := user.CustomerId 190 | 191 | for _, group := range groups { 192 | members, err := fetchGroupMembers(service, group) 193 | if err != nil { 194 | if err, ok := err.(*googleapi.Error); ok && err.Code == 404 { 195 | log.Printf("error fetching members for group %s: group does not exist", group) 196 | } else { 197 | log.Printf("error fetching group members: %v", err) 198 | return false 199 | } 200 | } 201 | 202 | for _, member := range members { 203 | switch member.Type { 204 | case "CUSTOMER": 205 | if member.Id == custID { 206 | return true 207 | } 208 | case "USER": 209 | if member.Id == id { 210 | return true 211 | } 212 | } 213 | } 214 | } 215 | return false 216 | } 217 | 218 | func fetchUser(service *admin.Service, email string) 
(*admin.User, error) { 219 | user, err := service.Users.Get(email).Do() 220 | return user, err 221 | } 222 | 223 | func fetchGroupMembers(service *admin.Service, group string) ([]*admin.Member, error) { 224 | members := []*admin.Member{} 225 | pageToken := "" 226 | for { 227 | req := service.Members.List(group) 228 | if pageToken != "" { 229 | req.PageToken(pageToken) 230 | } 231 | r, err := req.Do() 232 | if err != nil { 233 | return nil, err 234 | } 235 | for _, member := range r.Members { 236 | members = append(members, member) 237 | } 238 | if r.NextPageToken == "" { 239 | break 240 | } 241 | pageToken = r.NextPageToken 242 | } 243 | return members, nil 244 | } 245 | 246 | // ValidateGroup validates that the provided email exists in the configured Google 247 | // group(s). 248 | func (p *GoogleProvider) ValidateGroup(email string) bool { 249 | return p.GroupValidator(email) 250 | } 251 | 252 | func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { 253 | if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { 254 | return false, nil 255 | } 256 | 257 | newToken, duration, err := p.redeemRefreshToken(s.RefreshToken) 258 | if err != nil { 259 | return false, err 260 | } 261 | 262 | // re-check that the user is in the proper google group(s) 263 | if !p.ValidateGroup(s.Email) { 264 | return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) 265 | } 266 | 267 | origExpiration := s.ExpiresOn 268 | s.AccessToken = newToken 269 | s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second) 270 | log.Printf("refreshed access token %s (expired on %s)", s, origExpiration) 271 | return true, nil 272 | } 273 | 274 | func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, expires time.Duration, err error) { 275 | // https://developers.google.com/identity/protocols/OAuth2WebServer#refresh 276 | params := url.Values{} 277 | params.Add("client_id", p.ClientID) 278 | params.Add("client_secret", 
p.ClientSecret) 279 | params.Add("refresh_token", refreshToken) 280 | params.Add("grant_type", "refresh_token") 281 | var req *http.Request 282 | req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) 283 | if err != nil { 284 | return 285 | } 286 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 287 | 288 | resp, err := http.DefaultClient.Do(req) 289 | if err != nil { 290 | return 291 | } 292 | var body []byte 293 | body, err = ioutil.ReadAll(resp.Body) 294 | resp.Body.Close() 295 | if err != nil { 296 | return 297 | } 298 | 299 | if resp.StatusCode != 200 { 300 | err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) 301 | return 302 | } 303 | 304 | var data struct { 305 | AccessToken string `json:"access_token"` 306 | ExpiresIn int64 `json:"expires_in"` 307 | } 308 | err = json.Unmarshal(body, &data) 309 | if err != nil { 310 | return 311 | } 312 | token = data.AccessToken 313 | expires = time.Duration(data.ExpiresIn) * time.Second 314 | return 315 | } 316 | -------------------------------------------------------------------------------- /providers/google_test.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "encoding/base64" 5 | "encoding/json" 6 | "net/http" 7 | "net/http/httptest" 8 | "net/url" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func newRedeemServer(body []byte) (*url.URL, *httptest.Server) { 15 | s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 16 | rw.Write(body) 17 | })) 18 | u, _ := url.Parse(s.URL) 19 | return u, s 20 | } 21 | 22 | func newGoogleProvider() *GoogleProvider { 23 | return NewGoogleProvider( 24 | &ProviderData{ 25 | ProviderName: "", 26 | LoginURL: &url.URL{}, 27 | RedeemURL: &url.URL{}, 28 | ProfileURL: &url.URL{}, 29 | ValidateURL: &url.URL{}, 30 | Scope: ""}) 31 | } 32 | 33 | func 
TestGoogleProviderDefaults(t *testing.T) { 34 | p := newGoogleProvider() 35 | assert.NotEqual(t, nil, p) 36 | assert.Equal(t, "Google", p.Data().ProviderName) 37 | assert.Equal(t, "https://accounts.google.com/o/oauth2/auth?access_type=offline", 38 | p.Data().LoginURL.String()) 39 | assert.Equal(t, "https://www.googleapis.com/oauth2/v3/token", 40 | p.Data().RedeemURL.String()) 41 | assert.Equal(t, "https://www.googleapis.com/oauth2/v1/tokeninfo", 42 | p.Data().ValidateURL.String()) 43 | assert.Equal(t, "", p.Data().ProfileURL.String()) 44 | assert.Equal(t, "profile email", p.Data().Scope) 45 | } 46 | 47 | func TestGoogleProviderOverrides(t *testing.T) { 48 | p := NewGoogleProvider( 49 | &ProviderData{ 50 | LoginURL: &url.URL{ 51 | Scheme: "https", 52 | Host: "example.com", 53 | Path: "/oauth/auth"}, 54 | RedeemURL: &url.URL{ 55 | Scheme: "https", 56 | Host: "example.com", 57 | Path: "/oauth/token"}, 58 | ProfileURL: &url.URL{ 59 | Scheme: "https", 60 | Host: "example.com", 61 | Path: "/oauth/profile"}, 62 | ValidateURL: &url.URL{ 63 | Scheme: "https", 64 | Host: "example.com", 65 | Path: "/oauth/tokeninfo"}, 66 | Scope: "profile"}) 67 | assert.NotEqual(t, nil, p) 68 | assert.Equal(t, "Google", p.Data().ProviderName) 69 | assert.Equal(t, "https://example.com/oauth/auth", 70 | p.Data().LoginURL.String()) 71 | assert.Equal(t, "https://example.com/oauth/token", 72 | p.Data().RedeemURL.String()) 73 | assert.Equal(t, "https://example.com/oauth/profile", 74 | p.Data().ProfileURL.String()) 75 | assert.Equal(t, "https://example.com/oauth/tokeninfo", 76 | p.Data().ValidateURL.String()) 77 | assert.Equal(t, "profile", p.Data().Scope) 78 | } 79 | 80 | type redeemResponse struct { 81 | AccessToken string `json:"access_token"` 82 | RefreshToken string `json:"refresh_token"` 83 | ExpiresIn int64 `json:"expires_in"` 84 | IdToken string `json:"id_token"` 85 | } 86 | 87 | func TestGoogleProviderGetEmailAddress(t *testing.T) { 88 | p := newGoogleProvider() 89 | body, err := 
json.Marshal(redeemResponse{ 90 | AccessToken: "a1234", 91 | ExpiresIn: 10, 92 | RefreshToken: "refresh12345", 93 | IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov", "email_verified":true}`)), 94 | }) 95 | assert.Equal(t, nil, err) 96 | var server *httptest.Server 97 | p.RedeemURL, server = newRedeemServer(body) 98 | defer server.Close() 99 | 100 | session, err := p.Redeem("http://redirect/", "code1234") 101 | assert.Equal(t, nil, err) 102 | assert.NotEqual(t, session, nil) 103 | assert.Equal(t, "michael.bland@gsa.gov", session.Email) 104 | assert.Equal(t, "a1234", session.AccessToken) 105 | assert.Equal(t, "refresh12345", session.RefreshToken) 106 | } 107 | 108 | func TestGoogleProviderValidateGroup(t *testing.T) { 109 | p := newGoogleProvider() 110 | p.GroupValidator = func(email string) bool { 111 | return email == "michael.bland@gsa.gov" 112 | } 113 | assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov")) 114 | p.GroupValidator = func(email string) bool { 115 | return email != "michael.bland@gsa.gov" 116 | } 117 | assert.Equal(t, false, p.ValidateGroup("michael.bland@gsa.gov")) 118 | } 119 | 120 | func TestGoogleProviderWithoutValidateGroup(t *testing.T) { 121 | p := newGoogleProvider() 122 | assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov")) 123 | } 124 | 125 | // 126 | func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) { 127 | p := newGoogleProvider() 128 | body, err := json.Marshal(redeemResponse{ 129 | AccessToken: "a1234", 130 | IdToken: "ignored prefix." 
+ `{"email": "michael.bland@gsa.gov"}`, 131 | }) 132 | assert.Equal(t, nil, err) 133 | var server *httptest.Server 134 | p.RedeemURL, server = newRedeemServer(body) 135 | defer server.Close() 136 | 137 | session, err := p.Redeem("http://redirect/", "code1234") 138 | assert.NotEqual(t, nil, err) 139 | if session != nil { 140 | t.Errorf("expect nill session %#v", session) 141 | } 142 | } 143 | 144 | func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) { 145 | p := newGoogleProvider() 146 | 147 | body, err := json.Marshal(redeemResponse{ 148 | AccessToken: "a1234", 149 | IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)), 150 | }) 151 | assert.Equal(t, nil, err) 152 | var server *httptest.Server 153 | p.RedeemURL, server = newRedeemServer(body) 154 | defer server.Close() 155 | 156 | session, err := p.Redeem("http://redirect/", "code1234") 157 | assert.NotEqual(t, nil, err) 158 | if session != nil { 159 | t.Errorf("expect nill session %#v", session) 160 | } 161 | 162 | } 163 | 164 | func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { 165 | p := newGoogleProvider() 166 | body, err := json.Marshal(redeemResponse{ 167 | AccessToken: "a1234", 168 | IdToken: "ignored prefix." 
+ base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)), 169 | }) 170 | assert.Equal(t, nil, err) 171 | var server *httptest.Server 172 | p.RedeemURL, server = newRedeemServer(body) 173 | defer server.Close() 174 | 175 | session, err := p.Redeem("http://redirect/", "code1234") 176 | assert.NotEqual(t, nil, err) 177 | if session != nil { 178 | t.Errorf("expect nill session %#v", session) 179 | } 180 | 181 | } 182 | -------------------------------------------------------------------------------- /providers/internal_util.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "io/ioutil" 5 | "log" 6 | "net/http" 7 | "net/url" 8 | 9 | "github.com/bitly/oauth2_proxy/api" 10 | ) 11 | 12 | // stripToken is a helper function to obfuscate "access_token" 13 | // query parameters 14 | func stripToken(endpoint string) string { 15 | return stripParam("access_token", endpoint) 16 | } 17 | 18 | // stripParam generalizes the obfuscation of a particular 19 | // query parameter - typically 'access_token' or 'client_secret' 20 | // The parameter's second half is replaced by '...' and returned 21 | // as part of the encoded query parameters. 22 | // If the target parameter isn't found, the endpoint is returned 23 | // unmodified. 
24 | func stripParam(param, endpoint string) string { 25 | u, err := url.Parse(endpoint) 26 | if err != nil { 27 | log.Printf("error attempting to strip %s: %s", param, err) 28 | return endpoint 29 | } 30 | 31 | if u.RawQuery != "" { 32 | values, err := url.ParseQuery(u.RawQuery) 33 | if err != nil { 34 | log.Printf("error attempting to strip %s: %s", param, err) 35 | return u.String() 36 | } 37 | 38 | if val := values.Get(param); val != "" { 39 | values.Set(param, val[:(len(val)/2)]+"...") 40 | u.RawQuery = values.Encode() 41 | return u.String() 42 | } 43 | } 44 | 45 | return endpoint 46 | } 47 | 48 | // validateToken reports whether accessToken is accepted by the provider: 49 | // it returns true only when the request to ValidateURL answers HTTP 200. 50 | // With an empty header the token is sent as an "access_token" query 51 | // parameter merged into any query ValidateURL already carries; otherwise 52 | // the URL is used unchanged and header is passed through to the request. 53 | func validateToken(p Provider, accessToken string, header http.Header) bool { 54 | if accessToken == "" || p.Data().ValidateURL == nil { 55 | return false 56 | } 57 | // Work on a copy so the provider's configured URL is never mutated. 58 | validateURL := *p.Data().ValidateURL 59 | if len(header) == 0 { 60 | // Merge into the existing query rather than appending "?...", which 61 | // produced a malformed URL whenever ValidateURL already had a query. 62 | params := validateURL.Query() 63 | params.Set("access_token", accessToken) 64 | validateURL.RawQuery = params.Encode() 65 | } 66 | endpoint := validateURL.String() 67 | resp, err := api.RequestUnparsedResponse(endpoint, header) 68 | if err != nil { 69 | log.Printf("GET %s", stripToken(endpoint)) 70 | log.Printf("token validation request failed: %s", err) 71 | return false 72 | } 73 | 74 | body, err := ioutil.ReadAll(resp.Body) 75 | resp.Body.Close() 76 | if err != nil { 77 | // Only the status code decides validity; the body is just logged. 78 | log.Printf("token validation: error reading response body: %s", err) 79 | } 80 | log.Printf("%d GET %s %s", resp.StatusCode, stripToken(endpoint), body) 81 | 82 | if resp.StatusCode == http.StatusOK { 83 | return true 84 | } 85 | log.Printf("token validation request failed: status %d - %s", resp.StatusCode, body) 86 | return false 87 | } 88 | 89 | // updateURL rewrites u in place to use plain HTTP against hostname, 90 | // leaving the path and query untouched. 91 | func updateURL(u *url.URL, hostname string) { 92 | u.Scheme = "http" 93 | u.Host = hostname 94 | } 95 | -------------------------------------------------------------------------------- /providers/internal_util_test.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "net/http/httptest" 7 | "net/url" 8 | "testing" 9 | 10
| "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | type ValidateSessionStateTestProvider struct { 14 | *ProviderData 15 | } 16 | 17 | func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) { 18 | return "", errors.New("not implemented") 19 | } 20 | 21 | // Note that we're testing the internal validateToken() used to implement 22 | // several Provider's ValidateSessionState() implementations 23 | func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool { 24 | return false 25 | } 26 | 27 | type ValidateSessionStateTest struct { 28 | backend *httptest.Server 29 | response_code int 30 | provider *ValidateSessionStateTestProvider 31 | header http.Header 32 | } 33 | 34 | func NewValidateSessionStateTest() *ValidateSessionStateTest { 35 | var vt_test ValidateSessionStateTest 36 | 37 | vt_test.backend = httptest.NewServer( 38 | http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 39 | if r.URL.Path != "/oauth/tokeninfo" { 40 | w.WriteHeader(500) 41 | w.Write([]byte("unknown URL")) 42 | } 43 | token_param := r.FormValue("access_token") 44 | if token_param == "" { 45 | missing := false 46 | received_headers := r.Header 47 | for k, _ := range vt_test.header { 48 | received := received_headers.Get(k) 49 | expected := vt_test.header.Get(k) 50 | if received == "" || received != expected { 51 | missing = true 52 | } 53 | } 54 | if missing { 55 | w.WriteHeader(500) 56 | w.Write([]byte("no token param and missing or incorrect headers")) 57 | } 58 | } 59 | w.WriteHeader(vt_test.response_code) 60 | w.Write([]byte("only code matters; contents disregarded")) 61 | 62 | })) 63 | backend_url, _ := url.Parse(vt_test.backend.URL) 64 | vt_test.provider = &ValidateSessionStateTestProvider{ 65 | ProviderData: &ProviderData{ 66 | ValidateURL: &url.URL{ 67 | Scheme: "http", 68 | Host: backend_url.Host, 69 | Path: "/oauth/tokeninfo", 70 | }, 71 | }, 72 | } 73 | vt_test.response_code = 200 74 | return &vt_test 75 | 
} 76 | 77 | func (vt_test *ValidateSessionStateTest) Close() { 78 | vt_test.backend.Close() 79 | } 80 | 81 | func TestValidateSessionStateValidToken(t *testing.T) { 82 | vt_test := NewValidateSessionStateTest() 83 | defer vt_test.Close() 84 | assert.Equal(t, true, validateToken(vt_test.provider, "foobar", nil)) 85 | } 86 | 87 | func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) { 88 | vt_test := NewValidateSessionStateTest() 89 | defer vt_test.Close() 90 | vt_test.header = make(http.Header) 91 | vt_test.header.Set("Authorization", "Bearer foobar") 92 | assert.Equal(t, true, 93 | validateToken(vt_test.provider, "foobar", vt_test.header)) 94 | } 95 | 96 | func TestValidateSessionStateEmptyToken(t *testing.T) { 97 | vt_test := NewValidateSessionStateTest() 98 | defer vt_test.Close() 99 | assert.Equal(t, false, validateToken(vt_test.provider, "", nil)) 100 | } 101 | 102 | func TestValidateSessionStateEmptyValidateURL(t *testing.T) { 103 | vt_test := NewValidateSessionStateTest() 104 | defer vt_test.Close() 105 | vt_test.provider.Data().ValidateURL = nil 106 | assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) 107 | } 108 | 109 | func TestValidateSessionStateRequestNetworkFailure(t *testing.T) { 110 | vt_test := NewValidateSessionStateTest() 111 | // Close immediately to simulate a network failure 112 | vt_test.Close() 113 | assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) 114 | } 115 | 116 | func TestValidateSessionStateExpiredToken(t *testing.T) { 117 | vt_test := NewValidateSessionStateTest() 118 | defer vt_test.Close() 119 | vt_test.response_code = 401 120 | assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) 121 | } 122 | 123 | func TestStripTokenNotPresent(t *testing.T) { 124 | test := "http://local.test/api/test?a=1&b=2" 125 | assert.Equal(t, test, stripToken(test)) 126 | } 127 | 128 | func TestStripToken(t *testing.T) { 129 | test := 
"http://local.test/api/test?access_token=deadbeef&b=1&c=2" 130 | expected := "http://local.test/api/test?access_token=dead...&b=1&c=2" 131 | assert.Equal(t, expected, stripToken(test)) 132 | } 133 | -------------------------------------------------------------------------------- /providers/linkedin.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | "net/url" 8 | 9 | "github.com/bitly/oauth2_proxy/api" 10 | ) 11 | 12 | type LinkedInProvider struct { 13 | *ProviderData 14 | } 15 | 16 | func NewLinkedInProvider(p *ProviderData) *LinkedInProvider { 17 | p.ProviderName = "LinkedIn" 18 | if p.LoginURL.String() == "" { 19 | p.LoginURL = &url.URL{Scheme: "https", 20 | Host: "www.linkedin.com", 21 | Path: "/uas/oauth2/authorization"} 22 | } 23 | if p.RedeemURL.String() == "" { 24 | p.RedeemURL = &url.URL{Scheme: "https", 25 | Host: "www.linkedin.com", 26 | Path: "/uas/oauth2/accessToken"} 27 | } 28 | if p.ProfileURL.String() == "" { 29 | p.ProfileURL = &url.URL{Scheme: "https", 30 | Host: "www.linkedin.com", 31 | Path: "/v1/people/~/email-address"} 32 | } 33 | if p.ValidateURL.String() == "" { 34 | p.ValidateURL = p.ProfileURL 35 | } 36 | if p.Scope == "" { 37 | p.Scope = "r_emailaddress r_basicprofile" 38 | } 39 | return &LinkedInProvider{ProviderData: p} 40 | } 41 | 42 | func getLinkedInHeader(access_token string) http.Header { 43 | header := make(http.Header) 44 | header.Set("Accept", "application/json") 45 | header.Set("x-li-format", "json") 46 | header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) 47 | return header 48 | } 49 | 50 | func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { 51 | if s.AccessToken == "" { 52 | return "", errors.New("missing access token") 53 | } 54 | req, err := http.NewRequest("GET", p.ProfileURL.String()+"?format=json", nil) 55 | if err != nil { 56 | return "", err 57 | } 58 | req.Header = 
getLinkedInHeader(s.AccessToken) 59 | 60 | json, err := api.Request(req) 61 | if err != nil { 62 | return "", err 63 | } 64 | 65 | email, err := json.String() 66 | if err != nil { 67 | return "", err 68 | } 69 | return email, nil 70 | } 71 | 72 | func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool { 73 | return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken)) 74 | } 75 | -------------------------------------------------------------------------------- /providers/linkedin_test.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "net/url" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func testLinkedInProvider(hostname string) *LinkedInProvider { 13 | p := NewLinkedInProvider( 14 | &ProviderData{ 15 | ProviderName: "", 16 | LoginURL: &url.URL{}, 17 | RedeemURL: &url.URL{}, 18 | ProfileURL: &url.URL{}, 19 | ValidateURL: &url.URL{}, 20 | Scope: ""}) 21 | if hostname != "" { 22 | updateURL(p.Data().LoginURL, hostname) 23 | updateURL(p.Data().RedeemURL, hostname) 24 | updateURL(p.Data().ProfileURL, hostname) 25 | } 26 | return p 27 | } 28 | 29 | func testLinkedInBackend(payload string) *httptest.Server { 30 | path := "/v1/people/~/email-address" 31 | 32 | return httptest.NewServer(http.HandlerFunc( 33 | func(w http.ResponseWriter, r *http.Request) { 34 | url := r.URL 35 | if url.Path != path { 36 | w.WriteHeader(404) 37 | } else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" { 38 | w.WriteHeader(403) 39 | } else { 40 | w.WriteHeader(200) 41 | w.Write([]byte(payload)) 42 | } 43 | })) 44 | } 45 | 46 | func TestLinkedInProviderDefaults(t *testing.T) { 47 | p := testLinkedInProvider("") 48 | assert.NotEqual(t, nil, p) 49 | assert.Equal(t, "LinkedIn", p.Data().ProviderName) 50 | assert.Equal(t, "https://www.linkedin.com/uas/oauth2/authorization", 51 | p.Data().LoginURL.String()) 
52 | assert.Equal(t, "https://www.linkedin.com/uas/oauth2/accessToken", 53 | p.Data().RedeemURL.String()) 54 | assert.Equal(t, "https://www.linkedin.com/v1/people/~/email-address", 55 | p.Data().ProfileURL.String()) 56 | assert.Equal(t, "https://www.linkedin.com/v1/people/~/email-address", 57 | p.Data().ValidateURL.String()) 58 | assert.Equal(t, "r_emailaddress r_basicprofile", p.Data().Scope) 59 | } 60 | 61 | func TestLinkedInProviderOverrides(t *testing.T) { 62 | p := NewLinkedInProvider( 63 | &ProviderData{ 64 | LoginURL: &url.URL{ 65 | Scheme: "https", 66 | Host: "example.com", 67 | Path: "/oauth/auth"}, 68 | RedeemURL: &url.URL{ 69 | Scheme: "https", 70 | Host: "example.com", 71 | Path: "/oauth/token"}, 72 | ProfileURL: &url.URL{ 73 | Scheme: "https", 74 | Host: "example.com", 75 | Path: "/oauth/profile"}, 76 | ValidateURL: &url.URL{ 77 | Scheme: "https", 78 | Host: "example.com", 79 | Path: "/oauth/tokeninfo"}, 80 | Scope: "profile"}) 81 | assert.NotEqual(t, nil, p) 82 | assert.Equal(t, "LinkedIn", p.Data().ProviderName) 83 | assert.Equal(t, "https://example.com/oauth/auth", 84 | p.Data().LoginURL.String()) 85 | assert.Equal(t, "https://example.com/oauth/token", 86 | p.Data().RedeemURL.String()) 87 | assert.Equal(t, "https://example.com/oauth/profile", 88 | p.Data().ProfileURL.String()) 89 | assert.Equal(t, "https://example.com/oauth/tokeninfo", 90 | p.Data().ValidateURL.String()) 91 | assert.Equal(t, "profile", p.Data().Scope) 92 | } 93 | 94 | func TestLinkedInProviderGetEmailAddress(t *testing.T) { 95 | b := testLinkedInBackend(`"user@linkedin.com"`) 96 | defer b.Close() 97 | 98 | b_url, _ := url.Parse(b.URL) 99 | p := testLinkedInProvider(b_url.Host) 100 | 101 | session := &SessionState{AccessToken: "imaginary_access_token"} 102 | email, err := p.GetEmailAddress(session) 103 | assert.Equal(t, nil, err) 104 | assert.Equal(t, "user@linkedin.com", email) 105 | } 106 | 107 | func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) { 108 | b := 
testLinkedInBackend("unused payload") 109 | defer b.Close() 110 | 111 | b_url, _ := url.Parse(b.URL) 112 | p := testLinkedInProvider(b_url.Host) 113 | 114 | // We'll trigger a request failure by using an unexpected access 115 | // token. Alternatively, we could allow the parsing of the payload as 116 | // JSON to fail. 117 | session := &SessionState{AccessToken: "unexpected_access_token"} 118 | email, err := p.GetEmailAddress(session) 119 | assert.NotEqual(t, nil, err) 120 | assert.Equal(t, "", email) 121 | } 122 | 123 | func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { 124 | b := testLinkedInBackend("{\"foo\": \"bar\"}") 125 | defer b.Close() 126 | 127 | b_url, _ := url.Parse(b.URL) 128 | p := testLinkedInProvider(b_url.Host) 129 | 130 | session := &SessionState{AccessToken: "imaginary_access_token"} 131 | email, err := p.GetEmailAddress(session) 132 | assert.NotEqual(t, nil, err) 133 | assert.Equal(t, "", email) 134 | } 135 | -------------------------------------------------------------------------------- /providers/oidc.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "golang.org/x/oauth2" 9 | 10 | oidc "github.com/coreos/go-oidc" 11 | ) 12 | 13 | type OIDCProvider struct { 14 | *ProviderData 15 | 16 | Verifier *oidc.IDTokenVerifier 17 | } 18 | 19 | func NewOIDCProvider(p *ProviderData) *OIDCProvider { 20 | p.ProviderName = "OpenID Connect" 21 | return &OIDCProvider{ProviderData: p} 22 | } 23 | 24 | func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { 25 | ctx := context.Background() 26 | c := oauth2.Config{ 27 | ClientID: p.ClientID, 28 | ClientSecret: p.ClientSecret, 29 | Endpoint: oauth2.Endpoint{ 30 | TokenURL: p.RedeemURL.String(), 31 | }, 32 | RedirectURL: redirectURL, 33 | } 34 | token, err := c.Exchange(ctx, code) 35 | if err != nil { 36 | return nil, fmt.Errorf("token 
exchange: %v", err) 37 | } 38 | 39 | rawIDToken, ok := token.Extra("id_token").(string) 40 | if !ok { 41 | return nil, fmt.Errorf("token response did not contain an id_token") 42 | } 43 | 44 | // Parse and verify ID Token payload. 45 | idToken, err := p.Verifier.Verify(ctx, rawIDToken) 46 | if err != nil { 47 | return nil, fmt.Errorf("could not verify id_token: %v", err) 48 | } 49 | 50 | // Extract custom claims. 51 | var claims struct { 52 | Email string `json:"email"` 53 | Verified *bool `json:"email_verified"` 54 | } 55 | if err := idToken.Claims(&claims); err != nil { 56 | return nil, fmt.Errorf("failed to parse id_token claims: %v", err) 57 | } 58 | 59 | if claims.Email == "" { 60 | return nil, fmt.Errorf("id_token did not contain an email") 61 | } 62 | if claims.Verified != nil && !*claims.Verified { 63 | return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) 64 | } 65 | 66 | s = &SessionState{ 67 | AccessToken: token.AccessToken, 68 | RefreshToken: token.RefreshToken, 69 | ExpiresOn: token.Expiry, 70 | Email: claims.Email, 71 | } 72 | 73 | return 74 | } 75 | 76 | func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { 77 | if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { 78 | return false, nil 79 | } 80 | 81 | origExpiration := s.ExpiresOn 82 | s.ExpiresOn = time.Now().Add(time.Second).Truncate(time.Second) 83 | fmt.Printf("refreshed access token %s (expired on %s)\n", s, origExpiration) 84 | return false, nil 85 | } 86 | -------------------------------------------------------------------------------- /providers/provider_data.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "net/url" 5 | ) 6 | 7 | type ProviderData struct { 8 | ProviderName string 9 | ClientID string 10 | ClientSecret string 11 | LoginURL *url.URL 12 | RedeemURL *url.URL 13 | ProfileURL *url.URL 14 | ProtectedResource *url.URL 15 | ValidateURL 
*url.URL 16 | Scope string 17 | ApprovalPrompt string 18 | } 19 | 20 | func (p *ProviderData) Data() *ProviderData { return p } 21 | -------------------------------------------------------------------------------- /providers/provider_default.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "io/ioutil" 9 | "net/http" 10 | "net/url" 11 | 12 | "github.com/bitly/oauth2_proxy/cookie" 13 | ) 14 | 15 | func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) { 16 | if code == "" { 17 | err = errors.New("missing code") 18 | return 19 | } 20 | 21 | params := url.Values{} 22 | params.Add("redirect_uri", redirectURL) 23 | params.Add("client_id", p.ClientID) 24 | params.Add("client_secret", p.ClientSecret) 25 | params.Add("code", code) 26 | params.Add("grant_type", "authorization_code") 27 | if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { 28 | params.Add("resource", p.ProtectedResource.String()) 29 | } 30 | 31 | var req *http.Request 32 | req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) 33 | if err != nil { 34 | return 35 | } 36 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 37 | 38 | var resp *http.Response 39 | resp, err = http.DefaultClient.Do(req) 40 | if err != nil { 41 | return nil, err 42 | } 43 | var body []byte 44 | body, err = ioutil.ReadAll(resp.Body) 45 | resp.Body.Close() 46 | if err != nil { 47 | return 48 | } 49 | 50 | if resp.StatusCode != 200 { 51 | err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) 52 | return 53 | } 54 | 55 | // blindly try json and x-www-form-urlencoded 56 | var jsonResponse struct { 57 | AccessToken string `json:"access_token"` 58 | } 59 | err = json.Unmarshal(body, &jsonResponse) 60 | if err == nil { 61 | s = &SessionState{ 62 | AccessToken: 
jsonResponse.AccessToken, 63 | } 64 | return 65 | } 66 | 67 | var v url.Values 68 | v, err = url.ParseQuery(string(body)) 69 | if err != nil { 70 | return 71 | } 72 | if a := v.Get("access_token"); a != "" { 73 | s = &SessionState{AccessToken: a} 74 | } else { 75 | err = fmt.Errorf("no access token found %s", body) 76 | } 77 | return 78 | } 79 | 80 | // GetLoginURL with typical oauth parameters 81 | func (p *ProviderData) GetLoginURL(redirectURI, state string) string { 82 | var a url.URL 83 | a = *p.LoginURL 84 | params, _ := url.ParseQuery(a.RawQuery) 85 | params.Set("redirect_uri", redirectURI) 86 | params.Set("approval_prompt", p.ApprovalPrompt) 87 | params.Add("scope", p.Scope) 88 | params.Set("client_id", p.ClientID) 89 | params.Set("response_type", "code") 90 | params.Add("state", state) 91 | a.RawQuery = params.Encode() 92 | return a.String() 93 | } 94 | 95 | // CookieForSession serializes a session state for storage in a cookie 96 | func (p *ProviderData) CookieForSession(s *SessionState, c *cookie.Cipher) (string, error) { 97 | return s.EncodeSessionState(c) 98 | } 99 | 100 | // SessionFromCookie deserializes a session from a cookie value 101 | func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *SessionState, err error) { 102 | return DecodeSessionState(v, c) 103 | } 104 | 105 | func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) { 106 | return "", errors.New("not implemented") 107 | } 108 | 109 | // GetUserName returns the Account username 110 | func (p *ProviderData) GetUserName(s *SessionState) (string, error) { 111 | return "", errors.New("not implemented") 112 | } 113 | 114 | // ValidateGroup validates that the provided email exists in the configured provider 115 | // email group(s). 
116 | func (p *ProviderData) ValidateGroup(email string) bool { 117 | return true 118 | } 119 | 120 | func (p *ProviderData) ValidateSessionState(s *SessionState) bool { 121 | return validateToken(p, s.AccessToken, nil) 122 | } 123 | 124 | // RefreshSessionIfNeeded 125 | func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) { 126 | return false, nil 127 | } 128 | -------------------------------------------------------------------------------- /providers/provider_default_test.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestRefresh(t *testing.T) { 11 | p := &ProviderData{} 12 | refreshed, err := p.RefreshSessionIfNeeded(&SessionState{ 13 | ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute), 14 | }) 15 | assert.Equal(t, false, refreshed) 16 | assert.Equal(t, nil, err) 17 | } 18 | -------------------------------------------------------------------------------- /providers/providers.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "github.com/bitly/oauth2_proxy/cookie" 5 | ) 6 | 7 | type Provider interface { 8 | Data() *ProviderData 9 | GetEmailAddress(*SessionState) (string, error) 10 | GetUserName(*SessionState) (string, error) 11 | Redeem(string, string) (*SessionState, error) 12 | ValidateGroup(string) bool 13 | ValidateSessionState(*SessionState) bool 14 | GetLoginURL(redirectURI, finalRedirect string) string 15 | RefreshSessionIfNeeded(*SessionState) (bool, error) 16 | SessionFromCookie(string, *cookie.Cipher) (*SessionState, error) 17 | CookieForSession(*SessionState, *cookie.Cipher) (string, error) 18 | } 19 | 20 | func New(provider string, p *ProviderData) Provider { 21 | switch provider { 22 | case "linkedin": 23 | return NewLinkedInProvider(p) 24 | case "facebook": 25 | return 
NewFacebookProvider(p) 26 | case "github": 27 | return NewGitHubProvider(p) 28 | case "azure": 29 | return NewAzureProvider(p) 30 | case "gitlab": 31 | return NewGitLabProvider(p) 32 | case "oidc": 33 | return NewOIDCProvider(p) 34 | default: 35 | return NewGoogleProvider(p) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /providers/session_state.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "strings" 7 | "time" 8 | 9 | "github.com/bitly/oauth2_proxy/cookie" 10 | ) 11 | 12 | type SessionState struct { 13 | AccessToken string 14 | ExpiresOn time.Time 15 | RefreshToken string 16 | Email string 17 | User string 18 | } 19 | 20 | func (s *SessionState) IsExpired() bool { 21 | if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) { 22 | return true 23 | } 24 | return false 25 | } 26 | 27 | func (s *SessionState) String() string { 28 | o := fmt.Sprintf("Session{%s", s.accountInfo()) 29 | if s.AccessToken != "" { 30 | o += " token:true" 31 | } 32 | if !s.ExpiresOn.IsZero() { 33 | o += fmt.Sprintf(" expires:%s", s.ExpiresOn) 34 | } 35 | if s.RefreshToken != "" { 36 | o += " refresh_token:true" 37 | } 38 | return o + "}" 39 | } 40 | 41 | func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { 42 | if c == nil || s.AccessToken == "" { 43 | return s.accountInfo(), nil 44 | } 45 | return s.EncryptedString(c) 46 | } 47 | 48 | func (s *SessionState) accountInfo() string { 49 | return fmt.Sprintf("email:%s user:%s", s.Email, s.User) 50 | } 51 | 52 | func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { 53 | var err error 54 | if c == nil { 55 | panic("error. 
missing cipher") 56 | } 57 | a := s.AccessToken 58 | if a != "" { 59 | if a, err = c.Encrypt(a); err != nil { 60 | return "", err 61 | } 62 | } 63 | r := s.RefreshToken 64 | if r != "" { 65 | if r, err = c.Encrypt(r); err != nil { 66 | return "", err 67 | } 68 | } 69 | return fmt.Sprintf("%s|%s|%d|%s", s.accountInfo(), a, s.ExpiresOn.Unix(), r), nil 70 | } 71 | 72 | func decodeSessionStatePlain(v string) (s *SessionState, err error) { 73 | chunks := strings.Split(v, " ") 74 | if len(chunks) != 2 { 75 | return nil, fmt.Errorf("could not decode session state: expected 2 chunks got %d", len(chunks)) 76 | } 77 | 78 | email := strings.TrimPrefix(chunks[0], "email:") 79 | user := strings.TrimPrefix(chunks[1], "user:") 80 | if user == "" { 81 | user = strings.Split(email, "@")[0] 82 | } 83 | 84 | return &SessionState{User: user, Email: email}, nil 85 | } 86 | 87 | func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) { 88 | if c == nil { 89 | return decodeSessionStatePlain(v) 90 | } 91 | 92 | chunks := strings.Split(v, "|") 93 | if len(chunks) != 4 { 94 | err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks)) 95 | return 96 | } 97 | 98 | sessionState, err := decodeSessionStatePlain(chunks[0]) 99 | if err != nil { 100 | return nil, err 101 | } 102 | 103 | if chunks[1] != "" { 104 | if sessionState.AccessToken, err = c.Decrypt(chunks[1]); err != nil { 105 | return nil, err 106 | } 107 | } 108 | 109 | ts, _ := strconv.Atoi(chunks[2]) 110 | sessionState.ExpiresOn = time.Unix(int64(ts), 0) 111 | 112 | if chunks[3] != "" { 113 | if sessionState.RefreshToken, err = c.Decrypt(chunks[3]); err != nil { 114 | return nil, err 115 | } 116 | } 117 | 118 | return sessionState, nil 119 | } 120 | -------------------------------------------------------------------------------- /providers/session_state_test.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "fmt" 5 
| "strings" 6 | "testing" 7 | "time" 8 | 9 | "github.com/bitly/oauth2_proxy/cookie" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | const secret = "0123456789abcdefghijklmnopqrstuv" 14 | const altSecret = "0000000000abcdefghijklmnopqrstuv" 15 | 16 | func TestSessionStateSerialization(t *testing.T) { 17 | c, err := cookie.NewCipher([]byte(secret)) 18 | assert.Equal(t, nil, err) 19 | c2, err := cookie.NewCipher([]byte(altSecret)) 20 | assert.Equal(t, nil, err) 21 | s := &SessionState{ 22 | Email: "user@domain.com", 23 | AccessToken: "token1234", 24 | ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), 25 | RefreshToken: "refresh4321", 26 | } 27 | encoded, err := s.EncodeSessionState(c) 28 | assert.Equal(t, nil, err) 29 | assert.Equal(t, 3, strings.Count(encoded, "|")) 30 | 31 | ss, err := DecodeSessionState(encoded, c) 32 | t.Logf("%#v", ss) 33 | assert.Equal(t, nil, err) 34 | assert.Equal(t, "user", ss.User) 35 | assert.Equal(t, s.Email, ss.Email) 36 | assert.Equal(t, s.AccessToken, ss.AccessToken) 37 | assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) 38 | assert.Equal(t, s.RefreshToken, ss.RefreshToken) 39 | 40 | // ensure a different cipher can't decode properly (ie: it gets gibberish) 41 | ss, err = DecodeSessionState(encoded, c2) 42 | t.Logf("%#v", ss) 43 | assert.Equal(t, nil, err) 44 | assert.Equal(t, "user", ss.User) 45 | assert.Equal(t, s.Email, ss.Email) 46 | assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) 47 | assert.NotEqual(t, s.AccessToken, ss.AccessToken) 48 | assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) 49 | } 50 | 51 | func TestSessionStateSerializationWithUser(t *testing.T) { 52 | c, err := cookie.NewCipher([]byte(secret)) 53 | assert.Equal(t, nil, err) 54 | c2, err := cookie.NewCipher([]byte(altSecret)) 55 | assert.Equal(t, nil, err) 56 | s := &SessionState{ 57 | User: "just-user", 58 | Email: "user@domain.com", 59 | AccessToken: "token1234", 60 | ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), 61 | 
RefreshToken: "refresh4321", 62 | } 63 | encoded, err := s.EncodeSessionState(c) 64 | assert.Equal(t, nil, err) 65 | assert.Equal(t, 3, strings.Count(encoded, "|")) 66 | 67 | ss, err := DecodeSessionState(encoded, c) 68 | t.Logf("%#v", ss) 69 | assert.Equal(t, nil, err) 70 | assert.Equal(t, s.User, ss.User) 71 | assert.Equal(t, s.Email, ss.Email) 72 | assert.Equal(t, s.AccessToken, ss.AccessToken) 73 | assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) 74 | assert.Equal(t, s.RefreshToken, ss.RefreshToken) 75 | 76 | // ensure a different cipher can't decode properly (ie: it gets gibberish) 77 | ss, err = DecodeSessionState(encoded, c2) 78 | t.Logf("%#v", ss) 79 | assert.Equal(t, nil, err) 80 | assert.Equal(t, s.User, ss.User) 81 | assert.Equal(t, s.Email, ss.Email) 82 | assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) 83 | assert.NotEqual(t, s.AccessToken, ss.AccessToken) 84 | assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) 85 | } 86 | 87 | func TestSessionStateSerializationNoCipher(t *testing.T) { 88 | s := &SessionState{ 89 | Email: "user@domain.com", 90 | AccessToken: "token1234", 91 | ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), 92 | RefreshToken: "refresh4321", 93 | } 94 | encoded, err := s.EncodeSessionState(nil) 95 | assert.Equal(t, nil, err) 96 | expected := fmt.Sprintf("email:%s user:", s.Email) 97 | assert.Equal(t, expected, encoded) 98 | 99 | // only email should have been serialized 100 | ss, err := DecodeSessionState(encoded, nil) 101 | assert.Equal(t, nil, err) 102 | assert.Equal(t, "user", ss.User) 103 | assert.Equal(t, s.Email, ss.Email) 104 | assert.Equal(t, "", ss.AccessToken) 105 | assert.Equal(t, "", ss.RefreshToken) 106 | } 107 | 108 | func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { 109 | s := &SessionState{ 110 | User: "just-user", 111 | Email: "user@domain.com", 112 | AccessToken: "token1234", 113 | ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), 114 | RefreshToken: "refresh4321", 115 |
} 116 | encoded, err := s.EncodeSessionState(nil) 117 | assert.Equal(t, nil, err) 118 | expected := fmt.Sprintf("email:%s user:%s", s.Email, s.User) 119 | assert.Equal(t, expected, encoded) 120 | 121 | // email and user should have been serialized; the tokens are dropped 122 | ss, err := DecodeSessionState(encoded, nil) 123 | assert.Equal(t, nil, err) 124 | assert.Equal(t, s.User, ss.User) 125 | assert.Equal(t, s.Email, ss.Email) 126 | assert.Equal(t, "", ss.AccessToken) 127 | assert.Equal(t, "", ss.RefreshToken) 128 | } 129 | 130 | func TestSessionStateAccountInfo(t *testing.T) { 131 | s := &SessionState{ 132 | Email: "user@domain.com", 133 | User: "just-user", 134 | } 135 | expected := fmt.Sprintf("email:%v user:%v", s.Email, s.User) 136 | assert.Equal(t, expected, s.accountInfo()) 137 | 138 | s.Email = "" 139 | expected = fmt.Sprintf("email:%v user:%v", s.Email, s.User) 140 | assert.Equal(t, expected, s.accountInfo()) 141 | } 142 | 143 | func TestExpired(t *testing.T) { 144 | s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} 145 | assert.Equal(t, true, s.IsExpired()) 146 | 147 | s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)} 148 | assert.Equal(t, false, s.IsExpired()) 149 | 150 | s = &SessionState{} 151 | assert.Equal(t, false, s.IsExpired()) 152 | } 153 | -------------------------------------------------------------------------------- /string_array.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "strings" 5 | ) 6 | // StringArray is a list of strings whose Set and String methods match the flag.Value interface (presumably so a repeated flag can accumulate values). 7 | type StringArray []string 8 | // Set appends s to the array; it never returns an error. 9 | func (a *StringArray) Set(s string) error { 10 | *a = append(*a, s) 11 | return nil 12 | } 13 | // String returns the elements joined with commas. 14 | func (a *StringArray) String() string { 15 | return strings.Join(*a, ",") 16 | } 17 | -------------------------------------------------------------------------------- /templates.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 |
"html/template" 5 | "log" 6 | "path/filepath" 7 | ) 8 | // loadTemplates parses sign_in.html and error.html from dir (joined with OS-specific separators), or returns the built-in templates when dir is empty. A parse failure exits the process via log.Fatalf. 9 | func loadTemplates(dir string) *template.Template { 10 | if dir == "" { 11 | return getTemplates() 12 | } 13 | log.Printf("using custom template directory %q", dir) 14 | t, err := template.New("").ParseFiles(filepath.Join(dir, "sign_in.html"), filepath.Join(dir, "error.html")) 15 | if err != nil { 16 | log.Fatalf("failed parsing template %s", err) 17 | } 18 | return t 19 | } 20 | // getTemplates returns the built-in sign_in.html and error.html templates. A parse failure exits the process via log.Fatalf. 21 | func getTemplates() *template.Template { 22 | t, err := template.New("foo").Parse(`{{define "sign_in.html"}} 23 | 24 | 25 | 26 | Sign In 27 | 28 | 110 | 111 | 112 | 121 | 122 | {{ if .CustomLogin }} 123 | 131 | {{ end }} 132 | 142 |
143 | {{ if eq .Footer "-" }} 144 | {{ else if eq .Footer ""}} 145 | Secured with OAuth2 Proxy version {{.Version}} 146 | {{ else }} 147 | {{.Footer}} 148 | {{ end }} 149 |
150 | 151 | 152 | {{end}}`) 153 | if err != nil { 154 | log.Fatalf("failed parsing template %s", err) 155 | } 156 | 157 | t, err = t.Parse(`{{define "error.html"}} 158 | 159 | 160 | 161 | {{.Title}} 162 | 163 | 164 | 165 |

{{.Title}}

166 |

{{.Message}}

167 |
168 |

Sign In

169 | 170 | {{end}}`) 171 | if err != nil { 172 | log.Fatalf("failed parsing template %s", err) 173 | } 174 | return t 175 | } 176 | -------------------------------------------------------------------------------- /templates_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | // TestTemplatesCompile checks that the built-in templates parse into a non-nil template set. 9 | func TestTemplatesCompile(t *testing.T) { 10 | templates := getTemplates() 11 | assert.NotNil(t, templates) // NotEqual(templates, nil) cannot detect a typed nil pointer 12 | } 13 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | EXIT_CODE=0 3 | echo "gofmt" 4 | diff -u <(echo -n) <(gofmt -d $(find . -type f -name '*.go' -not -path "./vendor/*")) || EXIT_CODE=1 5 | for pkg in $(go list ./... | grep -v '/vendor/' ); do 6 | echo "testing $pkg" 7 | echo "go vet $pkg" 8 | go vet "$pkg" || EXIT_CODE=1 9 | echo "go test -v $pkg" 10 | go test -v -timeout 90s "$pkg" || EXIT_CODE=1 11 | echo "go test -v -race $pkg" 12 | GOMAXPROCS=4 go test -v -timeout 90s -race "$pkg" || EXIT_CODE=1 13 | done 14 | exit $EXIT_CODE -------------------------------------------------------------------------------- /validator.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/csv" 5 | "fmt" 6 | "log" 7 | "os" 8 | "strings" 9 | "sync/atomic" 10 | "unsafe" 11 | ) 12 | // UserMap is the set of allowed email addresses read from usersFile. m points to a map[string]bool that is swapped atomically on each load. 13 | type UserMap struct { 14 | usersFile string 15 | m unsafe.Pointer 16 | } 17 | // NewUserMap starts with an empty set; when usersFile is non-empty it loads the file and registers it with WatchForUpdates (passing done through) so each change reloads it and then calls onUpdate. 18 | func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap { 19 | um := &UserMap{usersFile: usersFile} 20 | m := make(map[string]bool) 21 | atomic.StorePointer(&um.m, unsafe.Pointer(&m)) 22 | if usersFile != "" { 23 | log.Printf("using authenticated emails file %s", usersFile) 24 | WatchForUpdates(usersFile, done, func() { 25 |
um.LoadAuthenticatedEmailsFile() 26 | onUpdate() 27 | }) 28 | um.LoadAuthenticatedEmailsFile() 29 | } 30 | return um 31 | } 32 | // IsValid reports whether email is in the current set. The lookup is exact; stored addresses are lower-cased, so email is expected to be lower-cased too. 33 | func (um *UserMap) IsValid(email string) (result bool) { 34 | m := *(*map[string]bool)(atomic.LoadPointer(&um.m)) 35 | _, result = m[email] 36 | return 37 | } 38 | // LoadAuthenticatedEmailsFile reads usersFile as CSV ('#' starts a comment, the first column is the address) and atomically replaces the set. It exits the process if the file cannot be opened; a CSV error is logged and the previous set is kept. 39 | func (um *UserMap) LoadAuthenticatedEmailsFile() { 40 | r, err := os.Open(um.usersFile) 41 | if err != nil { 42 | log.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err) 43 | } 44 | defer r.Close() 45 | csvReader := csv.NewReader(r) 46 | csvReader.Comma = ',' 47 | csvReader.Comment = '#' 48 | csvReader.TrimLeadingSpace = true 49 | records, err := csvReader.ReadAll() 50 | if err != nil { 51 | log.Printf("error reading authenticated-emails-file=%q, %s", um.usersFile, err) 52 | return 53 | } 54 | updated := make(map[string]bool) 55 | for _, record := range records { 56 | address := strings.ToLower(strings.TrimSpace(record[0])) 57 | updated[address] = true 58 | } 59 | atomic.StorePointer(&um.m, unsafe.Pointer(&updated)) 60 | } 61 | // newValidatorImpl returns a func that accepts a non-empty email when it ends in "@" plus one of domains, is listed in usersFile, or domains contains "*"; comparisons are case-insensitive. NOTE: it rewrites the caller's domains slice in place (each entry except "*" becomes "@" plus the lower-cased domain). 62 | func newValidatorImpl(domains []string, usersFile string, 63 | done <-chan bool, onUpdate func()) func(string) bool { 64 | validUsers := NewUserMap(usersFile, done, onUpdate) 65 | 66 | var allowAll bool 67 | for i, domain := range domains { 68 | if domain == "*" { 69 | allowAll = true 70 | continue 71 | } 72 | domains[i] = fmt.Sprintf("@%s", strings.ToLower(domain)) 73 | } 74 | 75 | validator := func(email string) (valid bool) { 76 | if email == "" { 77 | return 78 | } 79 | email = strings.ToLower(email) 80 | for _, domain := range domains { 81 | valid = valid || strings.HasSuffix(email, domain) 82 | } 83 | if !valid { 84 | valid = validUsers.IsValid(email) 85 | } 86 | if allowAll { 87 | valid = true 88 | } 89 | return valid 90 | } 91 | return validator 92 | } 93 | // NewValidator is newValidatorImpl with no shutdown channel and a no-op update callback. 94 | func NewValidator(domains []string, usersFile string) func(string) bool { 95 | return newValidatorImpl(domains, usersFile, nil, func() {}) 96 | } 97 |
-------------------------------------------------------------------------------- /validator_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "io/ioutil" 5 | "os" 6 | "strings" 7 | "testing" 8 | ) 9 | // ValidatorTest bundles a temporary authenticated-emails file with the done channel handed to the validator's file watcher. 10 | type ValidatorTest struct { 11 | auth_email_file *os.File 12 | done chan bool 13 | update_seen bool 14 | } 15 | // NewValidatorTest creates the temp emails file and a done channel with a one-slot buffer. 16 | func NewValidatorTest(t *testing.T) *ValidatorTest { 17 | vt := &ValidatorTest{} 18 | var err error 19 | vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") 20 | if err != nil { 21 | t.Fatal("failed to create temp file: " + err.Error()) 22 | } 23 | vt.done = make(chan bool, 1) 24 | return vt 25 | } 26 | // TearDown sends on done (the one-slot buffer lets the first send proceed without a receiver) and deletes the temp file. 27 | func (vt *ValidatorTest) TearDown() { 28 | vt.done <- true 29 | os.Remove(vt.auth_email_file.Name()) 30 | } 31 | // NewValidator builds a validator over the temp file; the first time the update callback runs it sends one value on updated (a nil updated would block that send forever). 32 | func (vt *ValidatorTest) NewValidator(domains []string, 33 | updated chan<- bool) func(string) bool { 34 | return newValidatorImpl(domains, vt.auth_email_file.Name(), 35 | vt.done, func() { 36 | if !vt.update_seen { 37 | updated <- true 38 | vt.update_seen = true 39 | } 40 | }) 41 | } 42 | 43 | // This will close vt.auth_email_file.
44 | func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) { 45 | _, werr := vt.auth_email_file.WriteString(strings.Join(emails, "\n")) 46 | cerr := vt.auth_email_file.Close() // always close, even if the write failed 47 | if werr != nil || cerr != nil { 48 | t.Fatalf("failed to write/close temp file %s: write: %v, close: %v", 49 | vt.auth_email_file.Name(), werr, cerr) 50 | } 51 | } 52 | 53 | func TestValidatorEmpty(t *testing.T) { 54 | vt := NewValidatorTest(t) 55 | defer vt.TearDown() 56 | 57 | vt.WriteEmails(t, []string(nil)) 58 | domains := []string(nil) 59 | validator := vt.NewValidator(domains, nil) 60 | 61 | if validator("foo.bar@example.com") { 62 | t.Error("nothing should validate when the email and " + 63 | "domain lists are empty") 64 | } 65 | } 66 | 67 | func TestValidatorSingleEmail(t *testing.T) { 68 | vt := NewValidatorTest(t) 69 | defer vt.TearDown() 70 | 71 | vt.WriteEmails(t, []string{"foo.bar@example.com"}) 72 | domains := []string(nil) 73 | validator := vt.NewValidator(domains, nil) 74 | 75 | if !validator("foo.bar@example.com") { 76 | t.Error("email should validate") 77 | } 78 | if validator("baz.quux@example.com") { 79 | t.Error("email from same domain but not in list " + 80 | "should not validate when domain list is empty") 81 | } 82 | } 83 | 84 | func TestValidatorSingleDomain(t *testing.T) { 85 | vt := NewValidatorTest(t) 86 | defer vt.TearDown() 87 | 88 | vt.WriteEmails(t, []string(nil)) 89 | domains := []string{"example.com"} 90 | validator := vt.NewValidator(domains, nil) 91 | 92 | if !validator("foo.bar@example.com") { 93 | t.Error("email should validate") 94 | } 95 | if !validator("baz.quux@example.com") { 96 | t.Error("email from same domain should validate") 97 | } 98 | } 99 | 100 | func TestValidatorMultipleEmailsMultipleDomains(t *testing.T) { 101 | vt := NewValidatorTest(t) 102 | defer vt.TearDown() 103 | 104 | vt.WriteEmails(t, []string{ 105 | "xyzzy@example.com", 106 | "plugh@example.com", 107 | }) 108 | domains := []string{"example0.com", "example1.com"} 109 |
validator := vt.NewValidator(domains, nil) 110 | 111 | if !validator("foo.bar@example0.com") { 112 | t.Error("email from first domain should validate") 113 | } 114 | if !validator("baz.quux@example1.com") { 115 | t.Error("email from second domain should validate") 116 | } 117 | if !validator("xyzzy@example.com") { 118 | t.Error("first email in list should validate") 119 | } 120 | if !validator("plugh@example.com") { 121 | t.Error("second email in list should validate") 122 | } 123 | if validator("xyzzy.plugh@example.com") { 124 | t.Error("email not in list that matches no domains " + 125 | "should not validate") 126 | } 127 | } 128 | 129 | func TestValidatorComparisonsAreCaseInsensitive(t *testing.T) { 130 | vt := NewValidatorTest(t) 131 | defer vt.TearDown() 132 | 133 | vt.WriteEmails(t, []string{"Foo.Bar@Example.Com"}) 134 | domains := []string{"Frobozz.Com"} 135 | validator := vt.NewValidator(domains, nil) 136 | 137 | if !validator("foo.bar@example.com") { 138 | t.Error("loaded email addresses are not lower-cased") 139 | } 140 | if !validator("Foo.Bar@Example.Com") { 141 | t.Error("validated email addresses are not lower-cased") 142 | } 143 | if !validator("foo.bar@frobozz.com") { 144 | t.Error("loaded domains are not lower-cased") 145 | } 146 | if !validator("foo.bar@Frobozz.Com") { 147 | t.Error("validated domains are not lower-cased") 148 | } 149 | } 150 | 151 | func TestValidatorIgnoreSpacesInAuthEmails(t *testing.T) { 152 | vt := NewValidatorTest(t) 153 | defer vt.TearDown() 154 | 155 | vt.WriteEmails(t, []string{" foo.bar@example.com "}) 156 | domains := []string(nil) 157 | validator := vt.NewValidator(domains, nil) 158 | 159 | if !validator("foo.bar@example.com") { 160 | t.Error("email should validate") 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /validator_watcher_copy_test.go: -------------------------------------------------------------------------------- 1 | // +build
go1.3,!plan9,!solaris,!windows 2 | 3 | // Turns out you can't copy over an existing file on Windows. 4 | 5 | package main 6 | 7 | import ( 8 | "io/ioutil" 9 | "os" 10 | "testing" 11 | ) 12 | // UpdateEmailFileViaCopyingOver writes emails to a new temp file, renames it over the watched file, and points vt back at the original path. 13 | func (vt *ValidatorTest) UpdateEmailFileViaCopyingOver( 14 | t *testing.T, emails []string) { 15 | orig_file := vt.auth_email_file 16 | var err error 17 | vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") 18 | if err != nil { 19 | t.Fatal("failed to create temp file for copy: " + err.Error()) 20 | } 21 | vt.WriteEmails(t, emails) 22 | err = os.Rename(vt.auth_email_file.Name(), orig_file.Name()) 23 | if err != nil { 24 | t.Fatal("failed to copy over temp file: " + err.Error()) 25 | } 26 | vt.auth_email_file = orig_file 27 | } 28 | 29 | func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) { 30 | vt := NewValidatorTest(t) 31 | defer vt.TearDown() 32 | 33 | vt.WriteEmails(t, []string{"xyzzy@example.com"}) 34 | domains := []string(nil) 35 | updated := make(chan bool) 36 | validator := vt.NewValidator(domains, updated) 37 | 38 | if !validator("xyzzy@example.com") { 39 | t.Error("email in list should validate") 40 | } 41 | 42 | vt.UpdateEmailFileViaCopyingOver(t, []string{"plugh@example.com"}) 43 | <-updated 44 | 45 | if validator("xyzzy@example.com") { 46 | t.Error("email removed from list should not validate") 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /validator_watcher_test.go: -------------------------------------------------------------------------------- 1 | // +build go1.3,!plan9,!solaris 2 | 3 | package main 4 | 5 | import ( 6 | "io/ioutil" 7 | "os" 8 | "testing" 9 | ) 10 | // UpdateEmailFile rewrites the watched file in place. NOTE(review): it opens without O_TRUNC, so content shorter than the old leaves stale trailing bytes; the caller below only writes longer content. 11 | func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) { 12 | var err error 13 | vt.auth_email_file, err = os.OpenFile( 14 | vt.auth_email_file.Name(), os.O_WRONLY|os.O_CREATE, 0600) 15 | if err != nil { 16 | t.Fatal("failed to re-open temp file for updates: " + err.Error()) 17 | } 18 | vt.WriteEmails(t,
emails) 19 | } 20 | // UpdateEmailFileViaRenameAndReplace moves the watched file aside, renames a freshly written temp file into its place, then deletes the moved-aside copy. 21 | func (vt *ValidatorTest) UpdateEmailFileViaRenameAndReplace( 22 | t *testing.T, emails []string) { 23 | orig_file := vt.auth_email_file 24 | var err error 25 | vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") 26 | if err != nil { 27 | t.Fatal("failed to create temp file for rename and replace: " + 28 | err.Error()) 29 | } 30 | vt.WriteEmails(t, emails) 31 | 32 | moved_name := orig_file.Name() + "-moved" 33 | if err = os.Rename(orig_file.Name(), moved_name); err == nil { // a failed move-aside must not silently degrade into a copy-over 34 | err = os.Rename(vt.auth_email_file.Name(), orig_file.Name()) 35 | } 36 | if err != nil { 37 | t.Fatal("failed to rename and replace temp file: " + err.Error()) 38 | } 39 | vt.auth_email_file = orig_file 40 | os.Remove(moved_name) 41 | } 42 | 43 | func TestValidatorOverwriteEmailListDirectly(t *testing.T) { 44 | vt := NewValidatorTest(t) 45 | defer vt.TearDown() 46 | 47 | vt.WriteEmails(t, []string{ 48 | "xyzzy@example.com", 49 | "plugh@example.com", 50 | }) 51 | domains := []string(nil) 52 | updated := make(chan bool) 53 | validator := vt.NewValidator(domains, updated) 54 | 55 | if !validator("xyzzy@example.com") { 56 | t.Error("first email in list should validate") 57 | } 58 | if !validator("plugh@example.com") { 59 | t.Error("second email in list should validate") 60 | } 61 | if validator("xyzzy.plugh@example.com") { 62 | t.Error("email not in list that matches no domains " + 63 | "should not validate") 64 | } 65 | 66 | vt.UpdateEmailFile(t, []string{ 67 | "xyzzy.plugh@example.com", 68 | "plugh@example.com", 69 | }) 70 | <-updated 71 | 72 | if validator("xyzzy@example.com") { 73 | t.Error("email removed from list should not validate") 74 | } 75 | if !validator("plugh@example.com") { 76 | t.Error("email retained in list should validate") 77 | } 78 | if !validator("xyzzy.plugh@example.com") { 79 | t.Error("email added to list should validate") 80 | } 81 | } 82 | 83 | func TestValidatorOverwriteEmailListViaRenameAndReplace(t *testing.T) { 84 | vt := NewValidatorTest(t) 85
| defer vt.TearDown() 86 | 87 | vt.WriteEmails(t, []string{"xyzzy@example.com"}) 88 | domains := []string(nil) 89 | updated := make(chan bool, 1) 90 | validator := vt.NewValidator(domains, updated) 91 | 92 | if !validator("xyzzy@example.com") { 93 | t.Error("email in list should validate") 94 | } 95 | 96 | vt.UpdateEmailFileViaRenameAndReplace(t, []string{"plugh@example.com"}) 97 | <-updated 98 | 99 | if validator("xyzzy@example.com") { 100 | t.Error("email removed from list should not validate") 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /version.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | const VERSION = "2.2.1-alpha" 4 | -------------------------------------------------------------------------------- /watcher.go: -------------------------------------------------------------------------------- 1 | // +build go1.3,!plan9,!solaris 2 | 3 | package main 4 | 5 | import ( 6 | "log" 7 | "os" 8 | "path/filepath" 9 | "time" 10 | 11 | "gopkg.in/fsnotify.v1" 12 | ) 13 | // WaitForReplacement blocks until filename exists again and can be re-added to watcher, polling every sleep_interval with no upper bound. When op includes Chmod it first sleeps one interval. 14 | func WaitForReplacement(filename string, op fsnotify.Op, 15 | watcher *fsnotify.Watcher) { 16 | const sleep_interval = 50 * time.Millisecond 17 | 18 | // Avoid a race when fsnotify.Remove is preceded by fsnotify.Chmod.
19 | if op&fsnotify.Chmod != 0 { 20 | time.Sleep(sleep_interval) 21 | } 22 | for { 23 | if _, err := os.Stat(filename); err == nil { 24 | if err := watcher.Add(filename); err == nil { 25 | log.Printf("watching resumed for %s", filename) 26 | return 27 | } 28 | } 29 | time.Sleep(sleep_interval) 30 | } 31 | } 32 | // WatchForUpdates starts a goroutine that calls action after each event on filename, first re-arming the watch via WaitForReplacement on Remove, Rename or Chmod; watch errors are logged. The goroutine exits, closing the watcher, when done yields a value or is closed. It exits the process if the watcher cannot be created or filename cannot be added. 33 | func WatchForUpdates(filename string, done <-chan bool, action func()) { 34 | filename = filepath.Clean(filename) 35 | watcher, err := fsnotify.NewWatcher() 36 | if err != nil { 37 | log.Fatal("failed to create watcher for ", filename, ": ", err) 38 | } 39 | go func() { 40 | defer watcher.Close() 41 | for { 42 | select { 43 | case <-done: 44 | log.Printf("Shutting down watcher for: %s", filename) 45 | return // a bare break would only leave the select, not the for loop 46 | case event := <-watcher.Events: 47 | // On Arch Linux, it appears Chmod events precede Remove events, 48 | // which causes a race between action() and the coming Remove event. 49 | // If the Remove wins, the action() (which calls 50 | // UserMap.LoadAuthenticatedEmailsFile()) crashes when the file 51 | // can't be opened.
52 | if event.Op&(fsnotify.Remove|fsnotify.Rename|fsnotify.Chmod) != 0 { 53 | log.Printf("watching interrupted on event: %s", event) 54 | watcher.Remove(filename) 55 | WaitForReplacement(filename, event.Op, watcher) 56 | } 57 | log.Printf("reloading after event: %s", event) 58 | action() 59 | case err := <-watcher.Errors: 60 | log.Printf("error watching %s: %s", filename, err) 61 | } 62 | } 63 | }() 64 | if err = watcher.Add(filename); err != nil { 65 | log.Fatal("failed to add ", filename, " to watcher: ", err) 66 | } 67 | log.Printf("watching %s for updates", filename) 68 | } 69 | -------------------------------------------------------------------------------- /watcher_unsupported.go: -------------------------------------------------------------------------------- 1 | // +build !go1.3 plan9 solaris 2 | 3 | package main 4 | 5 | import ( 6 | "log" 7 | ) 8 | // WatchForUpdates is a stub on this platform: it only logs and never calls action; its goroutine waits for one receive from done. 9 | func WatchForUpdates(filename string, done <-chan bool, action func()) { 10 | log.Printf("file watching not implemented on this platform") 11 | go func() { <-done }() 12 | } 13 | --------------------------------------------------------------------------------