├── .dockerignore ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE.md └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── .travis.yml ├── CHANGELOG.md ├── CONTRIBUTING.md ├── Dockerfile ├── Dockerfile.arm64 ├── Dockerfile.armv6 ├── Gopkg.lock ├── Gopkg.toml ├── LICENSE ├── Makefile ├── README.md ├── api ├── api.go └── api_test.go ├── configure ├── contrib ├── oauth2_proxy.cfg.example └── oauth2_proxy.service.example ├── cookie ├── cookies.go ├── cookies_test.go └── nonce.go ├── dist.sh ├── env_options.go ├── env_options_test.go ├── htpasswd.go ├── htpasswd_test.go ├── http.go ├── http_test.go ├── logging_handler.go ├── logging_handler_test.go ├── main.go ├── oauthproxy.go ├── oauthproxy_test.go ├── options.go ├── options_test.go ├── providers ├── azure.go ├── azure_test.go ├── facebook.go ├── github.go ├── github_test.go ├── gitlab.go ├── gitlab_test.go ├── google.go ├── google_test.go ├── internal_util.go ├── internal_util_test.go ├── linkedin.go ├── linkedin_test.go ├── logingov.go ├── logingov_test.go ├── oidc.go ├── provider_data.go ├── provider_default.go ├── provider_default_test.go ├── providers.go ├── session_state.go └── session_state_test.go ├── string_array.go ├── templates.go ├── templates_test.go ├── validator.go ├── validator_test.go ├── validator_watcher_copy_test.go ├── validator_watcher_test.go ├── version.go ├── watcher.go └── watcher_unsupported.go /.dockerignore: -------------------------------------------------------------------------------- 1 | Dockerfile.dev 2 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Default owner should be a Pusher cloud-team member unless overridden by later 2 | # rules in this file 3 | * @pusher/cloud-team 4 | 5 | # login.gov provider 6 | # Note: If @timothy-spencer terms out of his appointment, your best bet 7 | # for finding somebody who can test the oauth2_proxy would be to ask 
somebody 8 | # in the login.gov team (https://login.gov/developers/), the cloud.gov team 9 | # (https://cloud.gov/docs/help/), or the 18F org (https://18f.gsa.gov/contact/ 10 | # or the public devops channel at https://chat.18f.gov/). 11 | providers/logingov.go @timothy-spencer 12 | providers/logingov_test.go @timothy-spencer 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Expected Behavior 4 | 5 | 6 | 7 | 8 | ## Current Behavior 9 | 10 | 11 | 12 | 13 | ## Possible Solution 14 | 15 | 16 | 17 | 18 | ## Steps to Reproduce (for bugs) 19 | 20 | 21 | 22 | 23 | 1. 24 | 2. 25 | 3. 26 | 4. 27 | 28 | ## Context 29 | 30 | 31 | 32 | 33 | ## Your Environment 34 | 35 | 36 | 37 | - Version used: 38 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Description 4 | 5 | 6 | 7 | ## Motivation and Context 8 | 9 | 10 | 11 | 12 | ## How Has This Been Tested? 13 | 14 | 15 | 16 | 17 | 18 | ## Checklist: 19 | 20 | 21 | 22 | 23 | - [ ] My change requires a change to the documentation or CHANGELOG. 24 | - [ ] I have updated the documentation/CHANGELOG accordingly. 25 | - [ ] I have created a feature (non-master) branch for my PR. 
26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | oauth2_proxy 2 | vendor 3 | dist 4 | release 5 | .godeps 6 | *.exe 7 | .env 8 | 9 | # Go.gitignore 10 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 11 | *.o 12 | *.a 13 | *.so 14 | 15 | # Folders 16 | _obj 17 | _test 18 | 19 | # Architecture specific extensions/prefixes 20 | *.[568vq] 21 | [568vq].out 22 | 23 | *.cgo1.go 24 | *.cgo2.c 25 | _cgo_defun.c 26 | _cgo_gotypes.go 27 | _cgo_export.* 28 | 29 | _testmain.go 30 | 31 | # Editor swap/temp files 32 | .*.swp 33 | 34 | # Dockerfile.dev is ignored by both git and docker 35 | # for faster development cycle of docker build 36 | # cp Dockerfile Dockerfile.dev 37 | # vi Dockerfile.dev 38 | # docker build -f Dockerfile.dev . 39 | Dockerfile.dev 40 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 3 | - 1.11.x 4 | - 1.12.x 5 | install: 6 | # Fetch dependencies 7 | - wget -O dep https://github.com/golang/dep/releases/download/v0.5.0/dep-linux-amd64 8 | - chmod +x dep 9 | - mv dep $GOPATH/bin/dep 10 | script: 11 | - ./configure 12 | # Run tests 13 | - make test 14 | sudo: false 15 | notifications: 16 | email: false 17 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Vx.x.x (Pre-release) 2 | 3 | ## Changes since v3.2.0 4 | 5 | - [#111](https://github.com/pusher/oauth2_proxy/pull/111) Add option for telling where to find a login.gov JWT key file (@timothy-spencer) 6 | 7 | # v3.2.0 8 | 9 | ## Release highlights 10 | - Internal restructure of session state storage to use JSON rather than proprietary scheme 11 | - Added health check options 
for running on GCP behind a load balancer 12 | - Improved support for protecting websockets 13 | - Added provider for login.gov 14 | - Allow manual configuration of OIDC providers 15 | 16 | ## Important notes 17 | - Dockerfile user is now non-root, this may break your existing deployment 18 | - In the OIDC provider, when no email is returned, the ID Token subject will be used 19 | instead of returning an error 20 | - GitHub user emails must now be primary and verified before authenticating 21 | 22 | ## Changes since v3.1.0 23 | 24 | - [#96](https://github.com/bitly/oauth2_proxy/pull/96) Check if email is verified on GitHub (@caarlos0) 25 | - [#110](https://github.com/pusher/oauth2_proxy/pull/110) Added GCP healthcheck option (@timothy-spencer) 26 | - [#112](https://github.com/pusher/oauth2_proxy/pull/112) Improve websocket support (@gyson) 27 | - [#63](https://github.com/pusher/oauth2_proxy/pull/63) Use encoding/json for SessionState serialization (@yaegashi) 28 | - Use JSON to encode session state to be stored in browser cookies 29 | - Implement legacy decode function to support existing cookies generated by older versions 30 | - Add detailed table driven tests in session_state_test.go 31 | - [#120](https://github.com/pusher/oauth2_proxy/pull/120) Encrypting user/email from cookie (@costelmoraru) 32 | - [#55](https://github.com/pusher/oauth2_proxy/pull/55) Added login.gov provider (@timothy-spencer) 33 | - [#55](https://github.com/pusher/oauth2_proxy/pull/55) Added environment variables for all config options (@timothy-spencer) 34 | - [#70](https://github.com/pusher/oauth2_proxy/pull/70) Fix handling of splitted cookies (@einfachchr) 35 | - [#92](https://github.com/pusher/oauth2_proxy/pull/92) Merge websocket proxy feature from openshift/oauth-proxy (@butzist) 36 | - [#57](https://github.com/pusher/oauth2_proxy/pull/57) Fall back to using OIDC Subject instead of Email (@aigarius) 37 | - [#85](https://github.com/pusher/oauth2_proxy/pull/85) Use non-root user in 
docker images (@kskewes) 38 | - [#68](https://github.com/pusher/oauth2_proxy/pull/68) forward X-Auth-Access-Token header (@davidholsgrove) 39 | - [#41](https://github.com/pusher/oauth2_proxy/pull/41) Added option to manually specify OIDC endpoints instead of relying on discovery 40 | - [#83](https://github.com/pusher/oauth2_proxy/pull/83) Add `id_token` refresh to Google provider (@leki75) 41 | - [#10](https://github.com/pusher/oauth2_proxy/pull/10) fix redirect url param handling (@dt-rush) 42 | - [#122](https://github.com/pusher/oauth2_proxy/pull/122) Expose -cookie-path as configuration parameter (@costelmoraru) 43 | - [#124](https://github.com/pusher/oauth2_proxy/pull/124) Use Go 1.12 for testing and build environments (@syscll) 44 | 45 | # v3.1.0 46 | 47 | ## Release highlights 48 | 49 | - Introduction of ARM releases and general improvements to Docker builds 50 | - Improvements to OIDC provider allowing pass-through of ID Tokens 51 | - Multiple redirect domains can now be whitelisted 52 | - Streamed responses are now flushed periodically 53 | 54 | ## Important notes 55 | 56 | - If you have been using [#bitly/621](https://github.com/bitly/oauth2_proxy/pull/621) 57 | and have cookies larger than the 4kb limit, 58 | the cookie splitting pattern has changed and now uses `_` in place of `-` when 59 | indexing cookies. 60 | This will force users to reauthenticate the first time they use `v3.1.0`. 61 | - Streamed responses will now be flushed every 1 second by default. 62 | Previously, streamed responses were flushed only when the buffer was full. 63 | To retain the old behaviour, set `--flush-interval=0`. 64 | See [#23](https://github.com/pusher/oauth2_proxy/pull/23) for further details.
65 | 66 | ## Changes since v3.0.0 67 | 68 | - [#14](https://github.com/pusher/oauth2_proxy/pull/14) OIDC ID Token, Authorization Headers, Refreshing and Verification (@joelspeed) 69 | - Implement `pass-authorization-header` and `set-authorization-header` flags 70 | - Implement token refreshing in OIDC provider 71 | - Split cookies larger than 4k limit into multiple cookies 72 | - Implement token validation in OIDC provider 73 | - [#15](https://github.com/pusher/oauth2_proxy/pull/15) WhitelistDomains (@joelspeed) 74 | - Add `--whitelist-domain` flag to allow redirection to approved domains after OAuth flow 75 | - [#21](https://github.com/pusher/oauth2_proxy/pull/21) Docker Improvement (@yaegashi) 76 | - Move Docker base image from debian to alpine 77 | - Install ca-certificates in docker image 78 | - [#23](https://github.com/pusher/oauth2_proxy/pull/23) Flushed streaming responses 79 | - Long-running upstream responses will get flushed periodically (every 1 second by default) 80 | - [#24](https://github.com/pusher/oauth2_proxy/pull/24) Redirect fix (@agentgonzo) 81 | - After a successful login, you will be redirected to your original URL rather than / 82 | - [#35](https://github.com/pusher/oauth2_proxy/pull/35) arm and arm64 binary releases (@kskewes) 83 | - Add armv6 and arm64 to Makefile `release` target 84 | - [#37](https://github.com/pusher/oauth2_proxy/pull/37) cross build arm and arm64 docker images (@kskewes) 85 | 86 | # v3.0.0 87 | 88 | Adoption of OAuth2_Proxy by Pusher. 89 | Project was hard forked and tidied; however, no logical changes have occurred since 90 | v2.2 as released by Bitly.
91 | 92 | ## Changes since v2.2: 93 | 94 | - [#7](https://github.com/pusher/oauth2_proxy/pull/7) Migration to Pusher (@joelspeed) 95 | - Move automated build to debian base image 96 | - Add Makefile 97 | - Update CI to run `make test` 98 | - Update Dockerfile to use `make clean oauth2_proxy` 99 | - Update `VERSION` parameter to be set by `ldflags` from Git Status 100 | - Remove lint and test scripts 101 | - Remove Go v1.8.x from Travis CI testing 102 | - Add CODEOWNERS file 103 | - Add CONTRIBUTING guide 104 | - Add Issue and Pull Request templates 105 | - Add Dockerfile 106 | - Fix fsnotify import 107 | - Update README to reflect new repository ownership 108 | - Update CI scripts to separate linting and testing 109 | - Now using `gometalinter` for linting 110 | - Move Go import path from `github.com/bitly/oauth2_proxy` to `github.com/pusher/oauth2_proxy` 111 | - Repository forked on 27/11/18 112 | - README updated to include note that this repository is forked 113 | - CHANGELOG created to track changes to repository from original fork 114 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | To develop on this project, please fork the repo and clone into your `$GOPATH`. 4 | 5 | Dependencies are **not** checked in, so please download those separately. 6 | Download the dependencies using [`dep`](https://github.com/golang/dep). 7 | 8 | ```bash 9 | cd $GOPATH/src/github.com # Create this directory if it doesn't exist 10 | git clone git@github.com:YOUR_USERNAME/oauth2_proxy pusher/oauth2_proxy 11 | cd pusher/oauth2_proxy 12 | ./configure # Set up your environment variables 13 | make dep 14 | ``` 15 | 16 | ## Pull Requests and Issues 17 | 18 | We track bugs and issues using GitHub. 19 | 20 | If you find a bug, please open an Issue.
21 | 22 | If you want to fix a bug, please fork, create a feature branch, fix the bug and 23 | open a PR back to this repo. 24 | Please mention the open bug issue number within your PR if applicable. 25 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:1.12-stretch AS builder 2 | 3 | # Download tools 4 | RUN wget -O $GOPATH/bin/dep https://github.com/golang/dep/releases/download/v0.5.0/dep-linux-amd64 5 | RUN chmod +x $GOPATH/bin/dep 6 | 7 | # Copy sources 8 | WORKDIR $GOPATH/src/github.com/pusher/oauth2_proxy 9 | COPY . . 10 | 11 | # Fetch dependencies 12 | RUN dep ensure --vendor-only 13 | 14 | # Build binary and make sure there is at least an empty key file. 15 | # This is useful for GCP App Engine custom runtime builds, because 16 | # you cannot use multiline variables in their app.yaml, so you have to 17 | # build the key into the container and then tell it where it is 18 | # by setting OAUTH2_PROXY_JWT_KEY_FILE=/etc/ssl/private/jwt_signing_key.pem 19 | # in app.yaml instead. 
20 | RUN ./configure && make build && touch jwt_signing_key.pem 21 | 22 | # Copy binary to alpine 23 | FROM alpine:3.8 24 | COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt 25 | COPY --from=builder /go/src/github.com/pusher/oauth2_proxy/oauth2_proxy /bin/oauth2_proxy 26 | COPY --from=builder /go/src/github.com/pusher/oauth2_proxy/jwt_signing_key.pem /etc/ssl/private/jwt_signing_key.pem 27 | 28 | RUN addgroup -S -g 2000 oauth2proxy && adduser -S -u 2000 oauth2proxy -G oauth2proxy 29 | USER oauth2proxy 30 | 31 | ENTRYPOINT ["/bin/oauth2_proxy"] 32 | -------------------------------------------------------------------------------- /Dockerfile.arm64: -------------------------------------------------------------------------------- 1 | FROM golang:1.11-stretch AS builder 2 | 3 | # Download tools 4 | RUN wget -O $GOPATH/bin/dep https://github.com/golang/dep/releases/download/v0.5.0/dep-linux-amd64 5 | RUN chmod +x $GOPATH/bin/dep 6 | 7 | # Copy sources 8 | WORKDIR $GOPATH/src/github.com/pusher/oauth2_proxy 9 | COPY . . 10 | 11 | # Fetch dependencies 12 | RUN dep ensure --vendor-only 13 | 14 | # Build binary and make sure there is at least an empty key file. 15 | # This is useful for GCP App Engine custom runtime builds, because 16 | # you cannot use multiline variables in their app.yaml, so you have to 17 | # build the key into the container and then tell it where it is 18 | # by setting OAUTH2_PROXY_JWT_KEY_FILE=/etc/ssl/private/jwt_signing_key.pem 19 | # in app.yaml instead. 
20 | RUN ./configure && GOARCH=arm64 make build && touch jwt_signing_key.pem 21 | 22 | # Copy binary to alpine 23 | FROM arm64v8/alpine:3.8 24 | COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt 25 | COPY --from=builder /go/src/github.com/pusher/oauth2_proxy/oauth2_proxy /bin/oauth2_proxy 26 | COPY --from=builder /go/src/github.com/pusher/oauth2_proxy/jwt_signing_key.pem /etc/ssl/private/jwt_signing_key.pem 27 | 28 | RUN addgroup -S -g 2000 oauth2proxy && adduser -S -u 2000 oauth2proxy -G oauth2proxy 29 | USER oauth2proxy 30 | 31 | ENTRYPOINT ["/bin/oauth2_proxy"] 32 | -------------------------------------------------------------------------------- /Dockerfile.armv6: -------------------------------------------------------------------------------- 1 | FROM golang:1.11-stretch AS builder 2 | 3 | # Download tools 4 | RUN wget -O $GOPATH/bin/dep https://github.com/golang/dep/releases/download/v0.5.0/dep-linux-amd64 5 | RUN chmod +x $GOPATH/bin/dep 6 | 7 | # Copy sources 8 | WORKDIR $GOPATH/src/github.com/pusher/oauth2_proxy 9 | COPY . . 10 | 11 | # Fetch dependencies 12 | RUN dep ensure --vendor-only 13 | 14 | # Build binary and make sure there is at least an empty key file. 15 | # This is useful for GCP App Engine custom runtime builds, because 16 | # you cannot use multiline variables in their app.yaml, so you have to 17 | # build the key into the container and then tell it where it is 18 | # by setting OAUTH2_PROXY_JWT_KEY_FILE=/etc/ssl/private/jwt_signing_key.pem 19 | # in app.yaml instead. 
20 | RUN ./configure && GOARCH=arm GOARM=6 make build && touch jwt_signing_key.pem 21 | 22 | # Copy binary to alpine 23 | FROM arm32v6/alpine:3.8 24 | COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt 25 | COPY --from=builder /go/src/github.com/pusher/oauth2_proxy/oauth2_proxy /bin/oauth2_proxy 26 | COPY --from=builder /go/src/github.com/pusher/oauth2_proxy/jwt_signing_key.pem /etc/ssl/private/jwt_signing_key.pem 27 | 28 | RUN addgroup -S -g 2000 oauth2proxy && adduser -S -u 2000 oauth2proxy -G oauth2proxy 29 | USER oauth2proxy 30 | 31 | ENTRYPOINT ["/bin/oauth2_proxy"] 32 | -------------------------------------------------------------------------------- /Gopkg.lock: -------------------------------------------------------------------------------- 1 | # This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. 2 | 3 | 4 | [[projects]] 5 | digest = "1:b24249f5a5e6fbe1eddc94b25973172339ccabeadef4779274f3ed0167c18812" 6 | name = "cloud.google.com/go" 7 | packages = ["compute/metadata"] 8 | pruneopts = "" 9 | revision = "2d3a6656c17a60b0815b7e06ab0be04eacb6e613" 10 | version = "v0.16.0" 11 | 12 | [[projects]] 13 | digest = "1:289dd4d7abfb3ad2b5f728fbe9b1d5c1bf7d265a3eb9ef92869af1f7baba4c7a" 14 | name = "github.com/BurntSushi/toml" 15 | packages = ["."] 16 | pruneopts = "" 17 | revision = "b26d9c308763d68093482582cea63d69be07a0f0" 18 | version = "v0.3.0" 19 | 20 | [[projects]] 21 | digest = "1:512883404c2a99156e410e9880e3bb35ecccc0c07c1159eb204b5f3ef3c431b3" 22 | name = "github.com/bitly/go-simplejson" 23 | packages = ["."] 24 | pruneopts = "" 25 | revision = "aabad6e819789e569bd6aabf444c935aa9ba1e44" 26 | version = "v0.5.0" 27 | 28 | [[projects]] 29 | branch = "v2" 30 | digest = "1:e5a238f8fa890e529d7e493849bbae8988c9e70344e4630cc4f9a11b00516afb" 31 | name = "github.com/coreos/go-oidc" 32 | packages = ["."] 33 | pruneopts = "" 34 | revision = "77e7f2010a464ade7338597afe650dfcffbe2ca8" 35 | 36 | 
[[projects]] 37 | digest = "1:56c130d885a4aacae1dd9c7b71cfe39912c7ebc1ff7d2b46083c8812996dc43b" 38 | name = "github.com/davecgh/go-spew" 39 | packages = ["spew"] 40 | pruneopts = "" 41 | revision = "346938d642f2ec3594ed81d874461961cd0faa76" 42 | version = "v1.1.0" 43 | 44 | [[projects]] 45 | digest = "1:6098222470fe0172157ce9bbef5d2200df4edde17ee649c5d6e48330e4afa4c6" 46 | name = "github.com/dgrijalva/jwt-go" 47 | packages = ["."] 48 | pruneopts = "" 49 | revision = "06ea1031745cb8b3dab3f6a236daf2b0aa468b7e" 50 | version = "v3.2.0" 51 | 52 | [[projects]] 53 | branch = "master" 54 | digest = "1:3b760d3b93f994df8eb1d9ebfad17d3e9e37edcb7f7efaa15b427c0d7a64f4e4" 55 | name = "github.com/golang/protobuf" 56 | packages = ["proto"] 57 | pruneopts = "" 58 | revision = "1e59b77b52bf8e4b449a57e6f79f21226d571845" 59 | 60 | [[projects]] 61 | digest = "1:af67386ca553c04c6222f7b5b2f17bc97a5dfb3b81b706882c7fd8c72c30cf8f" 62 | name = "github.com/mbland/hmacauth" 63 | packages = ["."] 64 | pruneopts = "" 65 | revision = "107c17adcc5eccc9935cd67d9bc2feaf5255d2cb" 66 | version = "1.0.2" 67 | 68 | [[projects]] 69 | branch = "master" 70 | digest = "1:15c0562bca5d78ac087fb39c211071dc124e79fb18f8b7c3f8a0bc7ffcb2a38e" 71 | name = "github.com/mreiferson/go-options" 72 | packages = ["."] 73 | pruneopts = "" 74 | revision = "20ba7d382d05facb01e02eb777af0c5f229c5c95" 75 | 76 | [[projects]] 77 | digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" 78 | name = "github.com/pmezard/go-difflib" 79 | packages = ["difflib"] 80 | pruneopts = "" 81 | revision = "792786c7400a136282c1664665ae0a8db921c6c2" 82 | version = "v1.0.0" 83 | 84 | [[projects]] 85 | branch = "master" 86 | digest = "1:386e12afcfd8964907c92dffd106860c0dedd71dbefae14397b77b724a13343b" 87 | name = "github.com/pquerna/cachecontrol" 88 | packages = [ 89 | ".", 90 | "cacheobject", 91 | ] 92 | pruneopts = "" 93 | revision = "0dec1b30a0215bb68605dfc568e8855066c9202d" 94 | 95 | [[projects]] 96 | digest = 
"1:3926a4ec9a4ff1a072458451aa2d9b98acd059a45b38f7335d31e06c3d6a0159" 97 | name = "github.com/stretchr/testify" 98 | packages = [ 99 | "assert", 100 | "require", 101 | ] 102 | pruneopts = "" 103 | revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0" 104 | version = "v1.1.4" 105 | 106 | [[projects]] 107 | branch = "master" 108 | digest = "1:39630a0e2844fc4297c27caacb394a9fd342f869292284a62f856877adab65bc" 109 | name = "github.com/yhat/wsutil" 110 | packages = ["."] 111 | pruneopts = "" 112 | revision = "1d66fa95c997864ba4d8479f56609620fe542928" 113 | 114 | [[projects]] 115 | branch = "master" 116 | digest = "1:f6a006d27619a4d93bf9b66fe1999b8c8d1fa62bdc63af14f10fbe6fcaa2aa1a" 117 | name = "golang.org/x/crypto" 118 | packages = [ 119 | "bcrypt", 120 | "blowfish", 121 | "ed25519", 122 | "ed25519/internal/edwards25519", 123 | ] 124 | pruneopts = "" 125 | revision = "9f005a07e0d31d45e6656d241bb5c0f2efd4bc94" 126 | 127 | [[projects]] 128 | branch = "master" 129 | digest = "1:130b1bec86c62e121967ee0c69d9c263dc2d3ffe6c7c9a82aca4071c4d068861" 130 | name = "golang.org/x/net" 131 | packages = [ 132 | "context", 133 | "context/ctxhttp", 134 | "websocket", 135 | ] 136 | pruneopts = "" 137 | revision = "9dfe39835686865bff950a07b394c12a98ddc811" 138 | 139 | [[projects]] 140 | branch = "master" 141 | digest = "1:4a61176e8386727e4847b21a5a2625ce56b9c518bc543a28226503e701265db0" 142 | name = "golang.org/x/oauth2" 143 | packages = [ 144 | ".", 145 | "google", 146 | "internal", 147 | "jws", 148 | "jwt", 149 | ] 150 | pruneopts = "" 151 | revision = "9ff8ebcc8e241d46f52ecc5bff0e5a2f2dbef402" 152 | 153 | [[projects]] 154 | branch = "master" 155 | digest = "1:dc1fb726dbbe79c86369941eae1e3b431b8fc6f11dbd37f7899dc758a43cc3ed" 156 | name = "google.golang.org/api" 157 | packages = [ 158 | "admin/directory/v1", 159 | "gensupport", 160 | "googleapi", 161 | "googleapi/internal/uritemplates", 162 | ] 163 | pruneopts = "" 164 | revision = "8791354e7ab150705ede13637a18c1fcc16b62e8" 165 | 166 | 
[[projects]] 167 | digest = "1:934fb8966f303ede63aa405e2c8d7f0a427a05ea8df335dfdc1833dd4d40756f" 168 | name = "google.golang.org/appengine" 169 | packages = [ 170 | ".", 171 | "internal", 172 | "internal/app_identity", 173 | "internal/base", 174 | "internal/datastore", 175 | "internal/log", 176 | "internal/modules", 177 | "internal/remote_api", 178 | "internal/urlfetch", 179 | "urlfetch", 180 | ] 181 | pruneopts = "" 182 | revision = "150dc57a1b433e64154302bdc40b6bb8aefa313a" 183 | version = "v1.0.0" 184 | 185 | [[projects]] 186 | digest = "1:cb5b2a45a3dd41c01ff779c54ae4c8aab0271d6d3b3f734c8a8bd2c890299ef2" 187 | name = "gopkg.in/fsnotify/fsnotify.v1" 188 | packages = ["."] 189 | pruneopts = "" 190 | revision = "836bfd95fecc0f1511dd66bdbf2b5b61ab8b00b6" 191 | version = "v1.2.11" 192 | 193 | [[projects]] 194 | digest = "1:be4ed0a2b15944dd777a663681a39260ed05f9c4e213017ed2e2255622c8820c" 195 | name = "gopkg.in/square/go-jose.v2" 196 | packages = [ 197 | ".", 198 | "cipher", 199 | "json", 200 | ] 201 | pruneopts = "" 202 | revision = "f8f38de21b4dcd69d0413faf231983f5fd6634b1" 203 | version = "v2.1.3" 204 | 205 | [solve-meta] 206 | analyzer-name = "dep" 207 | analyzer-version = 1 208 | input-imports = [ 209 | "github.com/BurntSushi/toml", 210 | "github.com/bitly/go-simplejson", 211 | "github.com/coreos/go-oidc", 212 | "github.com/dgrijalva/jwt-go", 213 | "github.com/mbland/hmacauth", 214 | "github.com/mreiferson/go-options", 215 | "github.com/stretchr/testify/assert", 216 | "github.com/stretchr/testify/require", 217 | "github.com/yhat/wsutil", 218 | "golang.org/x/crypto/bcrypt", 219 | "golang.org/x/net/websocket", 220 | "golang.org/x/oauth2", 221 | "golang.org/x/oauth2/google", 222 | "google.golang.org/api/admin/directory/v1", 223 | "google.golang.org/api/googleapi", 224 | "gopkg.in/fsnotify/fsnotify.v1", 225 | ] 226 | solver-name = "gps-cdcl" 227 | solver-version = 1 228 | -------------------------------------------------------------------------------- /Gopkg.toml: 
-------------------------------------------------------------------------------- 1 | 2 | # Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md 3 | # for detailed Gopkg.toml documentation. 4 | # 5 | 6 | [[constraint]] 7 | name = "github.com/BurntSushi/toml" 8 | version = "~0.3.0" 9 | 10 | [[constraint]] 11 | name = "github.com/bitly/go-simplejson" 12 | version = "~0.5.0" 13 | 14 | [[constraint]] 15 | branch = "v2" 16 | name = "github.com/coreos/go-oidc" 17 | 18 | [[constraint]] 19 | branch = "master" 20 | name = "github.com/mreiferson/go-options" 21 | 22 | [[constraint]] 23 | name = "github.com/stretchr/testify" 24 | version = "~1.1.4" 25 | 26 | [[constraint]] 27 | branch = "master" 28 | name = "golang.org/x/oauth2" 29 | 30 | [[constraint]] 31 | branch = "master" 32 | name = "google.golang.org/api" 33 | 34 | [[constraint]] 35 | name = "gopkg.in/fsnotify/fsnotify.v1" 36 | version = "~1.2.0" 37 | 38 | [[constraint]] 39 | branch = "master" 40 | name = "golang.org/x/crypto" 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any person obtaining a copy 2 | of this software and associated documentation files (the "Software"), to deal 3 | in the Software without restriction, including without limitation the rights 4 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 5 | copies of the Software, and to permit persons to whom the Software is 6 | furnished to do so, subject to the following conditions: 7 | 8 | The above copyright notice and this permission notice shall be included in 9 | all copies or substantial portions of the Software. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 13 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 14 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 15 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 16 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 17 | THE SOFTWARE. 18 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | include .env 2 | BINARY := oauth2_proxy 3 | VERSION := $(shell git describe --always --dirty --tags 2>/dev/null || echo "undefined") 4 | .NOTPARALLEL: 5 | 6 | .PHONY: all 7 | all: dep lint $(BINARY) 8 | 9 | .PHONY: clean 10 | clean: 11 | rm -rf release 12 | rm -f $(BINARY) 13 | 14 | .PHONY: distclean 15 | distclean: clean 16 | rm -rf vendor 17 | 18 | BIN_DIR := $(GOPATH)/bin 19 | GOMETALINTER := $(BIN_DIR)/gometalinter 20 | 21 | $(GOMETALINTER): 22 | $(GO) get -u github.com/alecthomas/gometalinter 23 | gometalinter --install %> /dev/null 24 | 25 | .PHONY: lint 26 | lint: $(GOMETALINTER) 27 | $(GOMETALINTER) --vendor --disable-all \ 28 | --enable=vet \ 29 | --enable=vetshadow \ 30 | --enable=golint \ 31 | --enable=ineffassign \ 32 | --enable=goconst \ 33 | --enable=deadcode \ 34 | --enable=gofmt \ 35 | --enable=goimports \ 36 | --tests ./... 37 | 38 | .PHONY: dep 39 | dep: 40 | $(DEP) ensure --vendor-only 41 | 42 | .PHONY: build 43 | build: clean $(BINARY) 44 | 45 | $(BINARY): 46 | CGO_ENABLED=0 $(GO) build -a -installsuffix cgo -ldflags="-X main.VERSION=${VERSION}" -o $@ github.com/pusher/oauth2_proxy 47 | 48 | .PHONY: docker 49 | docker: 50 | docker build -f Dockerfile -t quay.io/pusher/oauth2_proxy:latest . 51 | 52 | .PHONY: docker-all 53 | docker-all: docker 54 | docker build -f Dockerfile -t quay.io/pusher/oauth2_proxy:latest-amd64 . 55 | docker build -f Dockerfile -t quay.io/pusher/oauth2_proxy:${VERSION} . 56 | docker build -f Dockerfile -t quay.io/pusher/oauth2_proxy:${VERSION}-amd64 . 
57 | docker build -f Dockerfile.arm64 -t quay.io/pusher/oauth2_proxy:latest-arm64 . 58 | docker build -f Dockerfile.arm64 -t quay.io/pusher/oauth2_proxy:${VERSION}-arm64 . 59 | docker build -f Dockerfile.armv6 -t quay.io/pusher/oauth2_proxy:latest-armv6 . 60 | docker build -f Dockerfile.armv6 -t quay.io/pusher/oauth2_proxy:${VERSION}-armv6 . 61 | 62 | .PHONY: docker-push 63 | docker-push: 64 | docker push quay.io/pusher/oauth2_proxy:latest 65 | 66 | .PHONY: docker-push-all 67 | docker-push-all: docker-push 68 | docker push quay.io/pusher/oauth2_proxy:latest-amd64 69 | docker push quay.io/pusher/oauth2_proxy:${VERSION} 70 | docker push quay.io/pusher/oauth2_proxy:${VERSION}-amd64 71 | docker push quay.io/pusher/oauth2_proxy:latest-arm64 72 | docker push quay.io/pusher/oauth2_proxy:${VERSION}-arm64 73 | docker push quay.io/pusher/oauth2_proxy:latest-armv6 74 | docker push quay.io/pusher/oauth2_proxy:${VERSION}-armv6 75 | 76 | .PHONY: test 77 | test: dep lint 78 | $(GO) test -v -race $(go list ./... 
| grep -v /vendor/) 79 | 80 | .PHONY: release 81 | release: lint test 82 | mkdir release 83 | GOOS=darwin GOARCH=amd64 go build -ldflags="-X main.VERSION=${VERSION}" -o release/$(BINARY)-darwin-amd64 github.com/pusher/oauth2_proxy 84 | GOOS=linux GOARCH=amd64 go build -ldflags="-X main.VERSION=${VERSION}" -o release/$(BINARY)-linux-amd64 github.com/pusher/oauth2_proxy 85 | GOOS=linux GOARCH=arm64 go build -ldflags="-X main.VERSION=${VERSION}" -o release/$(BINARY)-linux-arm64 github.com/pusher/oauth2_proxy 86 | GOOS=linux GOARCH=arm GOARM=6 go build -ldflags="-X main.VERSION=${VERSION}" -o release/$(BINARY)-linux-armv6 github.com/pusher/oauth2_proxy 87 | GOOS=windows GOARCH=amd64 go build -ldflags="-X main.VERSION=${VERSION}" -o release/$(BINARY)-windows-amd64 github.com/pusher/oauth2_proxy 88 | shasum -a 256 release/$(BINARY)-darwin-amd64 > release/$(BINARY)-darwin-amd64-sha256sum.txt 89 | shasum -a 256 release/$(BINARY)-linux-amd64 > release/$(BINARY)-linux-amd64-sha256sum.txt 90 | shasum -a 256 release/$(BINARY)-linux-arm64 > release/$(BINARY)-linux-arm64-sha256sum.txt 91 | shasum -a 256 release/$(BINARY)-linux-armv6 > release/$(BINARY)-linux-armv6-sha256sum.txt 92 | shasum -a 256 release/$(BINARY)-windows-amd64 > release/$(BINARY)-windows-amd64-sha256sum.txt 93 | tar -czvf release/$(BINARY)-$(VERSION).darwin-amd64.$(GO_VERSION).tar.gz release/$(BINARY)-darwin-amd64 94 | tar -czvf release/$(BINARY)-$(VERSION).linux-amd64.$(GO_VERSION).tar.gz release/$(BINARY)-linux-amd64 95 | tar -czvf release/$(BINARY)-$(VERSION).linux-arm64.$(GO_VERSION).tar.gz release/$(BINARY)-linux-arm64 96 | tar -czvf release/$(BINARY)-$(VERSION).linux-armv6.$(GO_VERSION).tar.gz release/$(BINARY)-linux-armv6 97 | tar -czvf release/$(BINARY)-$(VERSION).windows-amd64.$(GO_VERSION).tar.gz release/$(BINARY)-windows-amd64 98 | -------------------------------------------------------------------------------- /api/api.go: 
-------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io/ioutil" 7 | "log" 8 | "net/http" 9 | 10 | "github.com/bitly/go-simplejson" 11 | ) 12 | 13 | // Request parses the request body into a simplejson.Json object 14 | func Request(req *http.Request) (*simplejson.Json, error) { 15 | resp, err := http.DefaultClient.Do(req) 16 | if err != nil { 17 | log.Printf("%s %s %s", req.Method, req.URL, err) 18 | return nil, err 19 | } 20 | body, err := ioutil.ReadAll(resp.Body) 21 | resp.Body.Close() 22 | log.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body) 23 | if err != nil { 24 | return nil, err 25 | } 26 | if resp.StatusCode != 200 { 27 | return nil, fmt.Errorf("got %d %s", resp.StatusCode, body) 28 | } 29 | data, err := simplejson.NewJson(body) 30 | if err != nil { 31 | return nil, err 32 | } 33 | return data, nil 34 | } 35 | 36 | // RequestJSON parses the request body into the given interface 37 | func RequestJSON(req *http.Request, v interface{}) error { 38 | resp, err := http.DefaultClient.Do(req) 39 | if err != nil { 40 | log.Printf("%s %s %s", req.Method, req.URL, err) 41 | return err 42 | } 43 | body, err := ioutil.ReadAll(resp.Body) 44 | resp.Body.Close() 45 | log.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body) 46 | if err != nil { 47 | return err 48 | } 49 | if resp.StatusCode != 200 { 50 | return fmt.Errorf("got %d %s", resp.StatusCode, body) 51 | } 52 | return json.Unmarshal(body, v) 53 | } 54 | 55 | // RequestUnparsedResponse performs a GET and returns the raw response object 56 | func RequestUnparsedResponse(url string, header http.Header) (resp *http.Response, err error) { 57 | req, err := http.NewRequest("GET", url, nil) 58 | if err != nil { 59 | return nil, err 60 | } 61 | req.Header = header 62 | 63 | return http.DefaultClient.Do(req) 64 | } 65 | -------------------------------------------------------------------------------- 
/api/api_test.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "io/ioutil" 5 | "net/http" 6 | "net/http/httptest" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/bitly/go-simplejson" 11 | 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func testBackend(responseCode int, payload string) *httptest.Server { 16 | return httptest.NewServer(http.HandlerFunc( 17 | func(w http.ResponseWriter, r *http.Request) { 18 | w.WriteHeader(responseCode) 19 | w.Write([]byte(payload)) 20 | })) 21 | } 22 | 23 | func TestRequest(t *testing.T) { 24 | backend := testBackend(200, "{\"foo\": \"bar\"}") 25 | defer backend.Close() 26 | 27 | req, _ := http.NewRequest("GET", backend.URL, nil) 28 | response, err := Request(req) 29 | assert.Equal(t, nil, err) 30 | result, err := response.Get("foo").String() 31 | assert.Equal(t, nil, err) 32 | assert.Equal(t, "bar", result) 33 | } 34 | 35 | func TestRequestFailure(t *testing.T) { 36 | // Create a backend to generate a test URL, then close it to cause a 37 | // connection error. 
38 | backend := testBackend(200, "{\"foo\": \"bar\"}") 39 | backend.Close() 40 | 41 | req, err := http.NewRequest("GET", backend.URL, nil) 42 | assert.Equal(t, nil, err) 43 | resp, err := Request(req) 44 | assert.Equal(t, (*simplejson.Json)(nil), resp) 45 | assert.NotEqual(t, nil, err) 46 | if !strings.Contains(err.Error(), "refused") { 47 | t.Error("expected error when a connection fails: ", err) 48 | } 49 | } 50 | 51 | func TestHttpErrorCode(t *testing.T) { 52 | backend := testBackend(404, "{\"foo\": \"bar\"}") 53 | defer backend.Close() 54 | 55 | req, err := http.NewRequest("GET", backend.URL, nil) 56 | assert.Equal(t, nil, err) 57 | resp, err := Request(req) 58 | assert.Equal(t, (*simplejson.Json)(nil), resp) 59 | assert.NotEqual(t, nil, err) 60 | } 61 | 62 | func TestJsonParsingError(t *testing.T) { 63 | backend := testBackend(200, "not well-formed JSON") 64 | defer backend.Close() 65 | 66 | req, err := http.NewRequest("GET", backend.URL, nil) 67 | assert.Equal(t, nil, err) 68 | resp, err := Request(req) 69 | assert.Equal(t, (*simplejson.Json)(nil), resp) 70 | assert.NotEqual(t, nil, err) 71 | } 72 | 73 | // Parsing a URL practically never fails, so we won't cover that test case. 
74 | func TestRequestUnparsedResponseUsingAccessTokenParameter(t *testing.T) { 75 | backend := httptest.NewServer(http.HandlerFunc( 76 | func(w http.ResponseWriter, r *http.Request) { 77 | token := r.FormValue("access_token") 78 | if r.URL.Path == "/" && token == "my_token" { 79 | w.WriteHeader(200) 80 | w.Write([]byte("some payload")) 81 | } else { 82 | w.WriteHeader(403) 83 | } 84 | })) 85 | defer backend.Close() 86 | 87 | response, err := RequestUnparsedResponse( 88 | backend.URL+"?access_token=my_token", nil) 89 | assert.Equal(t, nil, err) 90 | assert.Equal(t, 200, response.StatusCode) 91 | body, err := ioutil.ReadAll(response.Body) 92 | assert.Equal(t, nil, err) 93 | response.Body.Close() 94 | assert.Equal(t, "some payload", string(body)) 95 | } 96 | 97 | func TestRequestUnparsedResponseUsingAccessTokenParameterFailedResponse(t *testing.T) { 98 | backend := testBackend(200, "some payload") 99 | // Close the backend now to force a request failure. 100 | backend.Close() 101 | 102 | response, err := RequestUnparsedResponse( 103 | backend.URL+"?access_token=my_token", nil) 104 | assert.NotEqual(t, nil, err) 105 | assert.Equal(t, (*http.Response)(nil), response) 106 | } 107 | 108 | func TestRequestUnparsedResponseUsingHeaders(t *testing.T) { 109 | backend := httptest.NewServer(http.HandlerFunc( 110 | func(w http.ResponseWriter, r *http.Request) { 111 | if r.URL.Path == "/" && r.Header["Auth"][0] == "my_token" { 112 | w.WriteHeader(200) 113 | w.Write([]byte("some payload")) 114 | } else { 115 | w.WriteHeader(403) 116 | } 117 | })) 118 | defer backend.Close() 119 | 120 | headers := make(http.Header) 121 | headers.Set("Auth", "my_token") 122 | response, err := RequestUnparsedResponse(backend.URL, headers) 123 | assert.Equal(t, nil, err) 124 | assert.Equal(t, 200, response.StatusCode) 125 | body, err := ioutil.ReadAll(response.Body) 126 | assert.Equal(t, nil, err) 127 | response.Body.Close() 128 | assert.Equal(t, "some payload", string(body)) 129 | } 130 | 
--------------------------------------------------------------------------------
/configure:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

RED='\033[0;31m'
GREEN='\033[0;32m'
BLUE='\033[0;34m'
NC='\033[0m'

declare -A tools=()
declare -A desired=()

for arg in "$@"; do
  case ${arg%%=*} in
    "--with-go")
      desired[go]="${arg##*=}"
      ;;
    "--with-dep")
      desired[dep]="${arg##*=}"
      ;;
    "--help")
      printf "${GREEN}$0${NC}\n"
      printf " available options:\n"
      printf " --with-dep=${BLUE}<path_to_dep_binary>${NC}\n"
      printf " --with-go=${BLUE}<path_to_go_binary>${NC}\n"
      exit 0
      ;;
    *)
      echo "Unknown option: $arg"
      exit 2
      ;;
  esac
done

# vercomp compares two dotted version strings field by field (base 10).
# Returns 0 if $1 == $2, 1 if $1 > $2, 2 if $1 < $2.
vercomp () {
  if [[ $1 == $2 ]]
  then
    return 0
  fi
  local IFS=.
  local i ver1=($1) ver2=($2)
  # fill empty fields in ver1 with zeros
  for ((i=${#ver1[@]}; i<${#ver2[@]}; i++))
  do
    ver1[i]=0
  done
  for ((i=0; i<${#ver1[@]}; i++))
  do
    if [[ -z ${ver2[i]} ]]
    then
      # fill empty fields in ver2 with zeros
      ver2[i]=0
    fi
    if ((10#${ver1[i]} > 10#${ver2[i]}))
    then
      return 1
    fi
    if ((10#${ver1[i]} < 10#${ver2[i]}))
    then
      return 2
    fi
  done
  return 0
}

# check_for locates tool $1, taken from --with-<tool> or else from $PATH, and
# records it in tools[$1]. Exits 1 if it is not an executable regular file.
check_for() {
  echo -n "Checking for $1... "
  if [ -n "${desired[$1]}" ]; then
    TOOL_PATH="${desired[$1]}"
  else
    TOOL_PATH=$(command -v "$1")
  fi
  if [ -x "$TOOL_PATH" ] && [ -f "$TOOL_PATH" ]; then
    printf "${GREEN}found${NC}\n"
    tools[$1]=$TOOL_PATH
  else
    printf "${RED}not found${NC}\n"
    cd - > /dev/null
    exit 1
  fi
}

# check_go_version requires go >= 1.11 and stores the third field of
# `go version` (e.g. "go1.11.5") in tools[go_version].
check_go_version() {
  echo -n "Checking go version... "
  # The patch component is optional: release builds such as "go1.11" or
  # "go1.12" print no ".N" suffix, and a pattern that requires one misses them.
  GO_VERSION=$(${tools[go]} version | ${tools[awk]} '{where = match($0, /[0-9]+\.[0-9]+(\.[0-9]+)?/); if (where != 0) print substr($0, RSTART, RLENGTH)}')
  if [ -z "$GO_VERSION" ]; then
    # No parsable version (e.g. a devel build): keep going, but say so.
    printf "${BLUE}unknown${NC} - skipping minimum version check\n"
  else
    vercomp "$GO_VERSION" 1.11
    case $? in
      0) ;&
      1)
        printf "${GREEN}"
        echo $GO_VERSION
        printf "${NC}"
        ;;
      2)
        printf "${RED}"
        echo "$GO_VERSION < 1.11"
        printf "${NC}"
        exit 1
        ;;
    esac
  fi
  VERSION=$(${tools[go]} version | ${tools[awk]} '{print $3}')
  tools["go_version"]="${VERSION}"
}

# NOTE(review): unused, and `awk` is invoked without a program, so this would
# fail if called; tools[docker] is also never populated.
check_docker_version() {
  echo -n "Checking docker version... "
  DOCKER_VERSION=$(${tools[docker]} version | ${tools[awk]})
}

# check_go_env ensures the selected go binary reports a GOPATH.
check_go_env() {
  echo -n "Checking \$GOPATH... "
  # Use the go binary chosen by check_for so --with-go is honoured.
  GOPATH="$(${tools[go]} env GOPATH)"
  if [ -z "$GOPATH" ]; then
    printf "${RED}invalid${NC} - GOPATH not set\n"
    exit 1
  fi
  printf "${GREEN}valid${NC} - $GOPATH\n"
}

# Run from the script's own directory; ${0%/*} breaks when $0 has no slash.
cd "$(dirname "$0")" || exit 1

# Remove a stale .env only if one exists.
if [ -f .env ]; then
  rm .env
fi

check_for make
check_for awk
check_for go
check_go_version
check_go_env
check_for dep

echo

cat <<- EOF > .env
MAKE := "${tools[make]}"
GO := "${tools[go]}"
GO_VERSION := ${tools[go_version]}
DEP := "${tools[dep]}"
EOF

echo "Environment configuration written to .env"

cd - > /dev/null
--------------------------------------------------------------------------------
/contrib/oauth2_proxy.cfg.example:
--------------------------------------------------------------------------------
## OAuth2 Proxy Config File
## https://github.com/bitly/oauth2_proxy

## <addr>:<port> to listen on for HTTP/HTTPS clients
# http_address = "127.0.0.1:4180"
# https_address = ":443"

## TLS Settings
# tls_cert_file = ""
# tls_key_file = ""

## the OAuth Redirect URL.
13 | # defaults to the "https://" + requested host header + "/oauth2/callback" 14 | # redirect_url = "https://internalapp.yourcompany.com/oauth2/callback" 15 | 16 | ## the http url(s) of the upstream endpoint. If multiple, routing is based on path 17 | # upstreams = [ 18 | # "http://127.0.0.1:8080/" 19 | # ] 20 | 21 | ## Log requests to stdout 22 | # request_logging = true 23 | 24 | ## pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream 25 | # pass_basic_auth = true 26 | # pass_user_headers = true 27 | ## pass the request Host Header to upstream 28 | ## when disabled the upstream Host is used as the Host Header 29 | # pass_host_header = true 30 | 31 | ## Email Domains to allow authentication for (this authorizes any email on this domain) 32 | ## for more granular authorization use `authenticated_emails_file` 33 | ## To authorize any email addresses use "*" 34 | # email_domains = [ 35 | # "yourcompany.com" 36 | # ] 37 | 38 | ## The OAuth Client ID, Secret 39 | # client_id = "123456.apps.googleusercontent.com" 40 | # client_secret = "" 41 | 42 | ## Pass OAuth Access token to upstream via "X-Forwarded-Access-Token" 43 | # pass_access_token = false 44 | 45 | ## Authenticated Email Addresses File (one email per line) 46 | # authenticated_emails_file = "" 47 | 48 | ## Htpasswd File (optional) 49 | ## Additionally authenticate against a htpasswd file. 
Entries must be created with "htpasswd -B" for bcrypt or "htpasswd -s" for SHA encryption 50 | ## enabling exposes a username/login sign-in form 51 | # htpasswd_file = "" 52 | 53 | ## Templates 54 | ## optional directory with custom sign_in.html and error.html 55 | # custom_templates_dir = "" 56 | 57 | ## skip SSL checking for HTTPS requests 58 | # ssl_insecure_skip_verify = false 59 | 60 | 61 | ## Cookie Settings 62 | ## Name - the cookie name 63 | ## Secret - the seed string for secure cookies; should be 16, 24, or 32 bytes 64 | ## for use with an AES cipher when cookie_refresh or pass_access_token 65 | ## is set 66 | ## Domain - (optional) cookie domain to force cookies to (e.g. .yourcompany.com) 67 | ## Expire - (duration) expire timeframe for cookie 68 | ## Refresh - (duration) refresh the cookie when duration has elapsed after cookie was initially set. 69 | ## Should be less than cookie_expire; set to 0 to disable. 70 | ## On refresh, OAuth token is re-validated. 71 | ## (e.g. 1h means tokens are refreshed on the first request made 1h or more after the cookie was set) 72 | ## Secure - secure cookies are only sent by the browser over an HTTPS connection (recommended) 73 | ## HttpOnly - httponly cookies are not readable by JavaScript (recommended) 74 | # cookie_name = "_oauth2_proxy" 75 | # cookie_secret = "" 76 | # cookie_domain = "" 77 | # cookie_expire = "168h" 78 | # cookie_refresh = "" 79 | # cookie_secure = true 80 | # cookie_httponly = true 81 | -------------------------------------------------------------------------------- /contrib/oauth2_proxy.service.example: -------------------------------------------------------------------------------- 1 | # Systemd service file for oauth2_proxy daemon 2 | # 3 | # Date: Feb 9, 2016 4 | # Author: Srdjan Grubor 5 | 6 | [Unit] 7 | Description=oauth2_proxy daemon service 8 | After=syslog.target network.target 9 | 10 | [Service] 11 | # www-data group and user need to be created before using these lines 12 | User=www-data 13 | Group=www-data 14 | 15 | 
ExecStart=/usr/local/bin/oauth2_proxy -config=/etc/oauth2_proxy.cfg 16 | ExecReload=/bin/kill -HUP $MAINPID 17 | 18 | KillMode=process 19 | Restart=always 20 | 21 | [Install] 22 | WantedBy=multi-user.target 23 | -------------------------------------------------------------------------------- /cookie/cookies.go: -------------------------------------------------------------------------------- 1 | package cookie 2 | 3 | import ( 4 | "crypto/aes" 5 | "crypto/cipher" 6 | "crypto/hmac" 7 | "crypto/rand" 8 | "crypto/sha1" 9 | "encoding/base64" 10 | "fmt" 11 | "io" 12 | "net/http" 13 | "strconv" 14 | "strings" 15 | "time" 16 | ) 17 | 18 | // cookies are stored in a 3 part (value + timestamp + signature) to enforce that the values are as originally set. 19 | // additionally, the 'value' is encrypted so it's opaque to the browser 20 | 21 | // Validate ensures a cookie is properly signed 22 | func Validate(cookie *http.Cookie, seed string, expiration time.Duration) (value string, t time.Time, ok bool) { 23 | // value, timestamp, sig 24 | parts := strings.Split(cookie.Value, "|") 25 | if len(parts) != 3 { 26 | return 27 | } 28 | sig := cookieSignature(seed, cookie.Name, parts[0], parts[1]) 29 | if checkHmac(parts[2], sig) { 30 | ts, err := strconv.Atoi(parts[1]) 31 | if err != nil { 32 | return 33 | } 34 | // The expiration timestamp set when the cookie was created 35 | // isn't sent back by the browser. Hence, we check whether the 36 | // creation timestamp stored in the cookie falls within the 37 | // window defined by (Now()-expiration, Now()]. 38 | t = time.Unix(int64(ts), 0) 39 | if t.After(time.Now().Add(expiration*-1)) && t.Before(time.Now().Add(time.Minute*5)) { 40 | // it's a valid cookie. 
now get the contents 41 | rawValue, err := base64.URLEncoding.DecodeString(parts[0]) 42 | if err == nil { 43 | value = string(rawValue) 44 | ok = true 45 | return 46 | } 47 | } 48 | } 49 | return 50 | } 51 | 52 | // SignedValue returns a cookie that is signed and can later be checked with Validate 53 | func SignedValue(seed string, key string, value string, now time.Time) string { 54 | encodedValue := base64.URLEncoding.EncodeToString([]byte(value)) 55 | timeStr := fmt.Sprintf("%d", now.Unix()) 56 | sig := cookieSignature(seed, key, encodedValue, timeStr) 57 | cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig) 58 | return cookieVal 59 | } 60 | 61 | func cookieSignature(args ...string) string { 62 | h := hmac.New(sha1.New, []byte(args[0])) 63 | for _, arg := range args[1:] { 64 | h.Write([]byte(arg)) 65 | } 66 | var b []byte 67 | b = h.Sum(b) 68 | return base64.URLEncoding.EncodeToString(b) 69 | } 70 | 71 | func checkHmac(input, expected string) bool { 72 | inputMAC, err1 := base64.URLEncoding.DecodeString(input) 73 | if err1 == nil { 74 | expectedMAC, err2 := base64.URLEncoding.DecodeString(expected) 75 | if err2 == nil { 76 | return hmac.Equal(inputMAC, expectedMAC) 77 | } 78 | } 79 | return false 80 | } 81 | 82 | // Cipher provides methods to encrypt and decrypt cookie values 83 | type Cipher struct { 84 | cipher.Block 85 | } 86 | 87 | // NewCipher returns a new aes Cipher for encrypting cookie values 88 | func NewCipher(secret []byte) (*Cipher, error) { 89 | c, err := aes.NewCipher(secret) 90 | if err != nil { 91 | return nil, err 92 | } 93 | return &Cipher{Block: c}, err 94 | } 95 | 96 | // Encrypt a value for use in a cookie 97 | func (c *Cipher) Encrypt(value string) (string, error) { 98 | ciphertext := make([]byte, aes.BlockSize+len(value)) 99 | iv := ciphertext[:aes.BlockSize] 100 | if _, err := io.ReadFull(rand.Reader, iv); err != nil { 101 | return "", fmt.Errorf("failed to create initialization vector %s", err) 102 | } 103 | 104 | stream := 
cipher.NewCFBEncrypter(c.Block, iv) 105 | stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(value)) 106 | return base64.StdEncoding.EncodeToString(ciphertext), nil 107 | } 108 | 109 | // Decrypt a value from a cookie to it's original string 110 | func (c *Cipher) Decrypt(s string) (string, error) { 111 | encrypted, err := base64.StdEncoding.DecodeString(s) 112 | if err != nil { 113 | return "", fmt.Errorf("failed to decrypt cookie value %s", err) 114 | } 115 | 116 | if len(encrypted) < aes.BlockSize { 117 | return "", fmt.Errorf("encrypted cookie value should be "+ 118 | "at least %d bytes, but is only %d bytes", 119 | aes.BlockSize, len(encrypted)) 120 | } 121 | 122 | iv := encrypted[:aes.BlockSize] 123 | encrypted = encrypted[aes.BlockSize:] 124 | stream := cipher.NewCFBDecrypter(c.Block, iv) 125 | stream.XORKeyStream(encrypted, encrypted) 126 | 127 | return string(encrypted), nil 128 | } 129 | -------------------------------------------------------------------------------- /cookie/cookies_test.go: -------------------------------------------------------------------------------- 1 | package cookie 2 | 3 | import ( 4 | "encoding/base64" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestEncodeAndDecodeAccessToken(t *testing.T) { 11 | const secret = "0123456789abcdefghijklmnopqrstuv" 12 | const token = "my access token" 13 | c, err := NewCipher([]byte(secret)) 14 | assert.Equal(t, nil, err) 15 | 16 | encoded, err := c.Encrypt(token) 17 | assert.Equal(t, nil, err) 18 | 19 | decoded, err := c.Decrypt(encoded) 20 | assert.Equal(t, nil, err) 21 | 22 | assert.NotEqual(t, token, encoded) 23 | assert.Equal(t, token, decoded) 24 | } 25 | 26 | func TestEncodeAndDecodeAccessTokenB64(t *testing.T) { 27 | const secretBase64 = "A3Xbr6fu6Al0HkgrP1ztjb-mYiwmxgNPP-XbNsz1WBk=" 28 | const token = "my access token" 29 | 30 | secret, err := base64.URLEncoding.DecodeString(secretBase64) 31 | assert.Equal(t, nil, err) 32 | c, err := 
NewCipher([]byte(secret)) 33 | assert.Equal(t, nil, err) 34 | 35 | encoded, err := c.Encrypt(token) 36 | assert.Equal(t, nil, err) 37 | 38 | decoded, err := c.Decrypt(encoded) 39 | assert.Equal(t, nil, err) 40 | 41 | assert.NotEqual(t, token, encoded) 42 | assert.Equal(t, token, decoded) 43 | } 44 | -------------------------------------------------------------------------------- /cookie/nonce.go: -------------------------------------------------------------------------------- 1 | package cookie 2 | 3 | import ( 4 | "crypto/rand" 5 | "fmt" 6 | ) 7 | 8 | // Nonce generates a random 16 byte string to be used as a nonce 9 | func Nonce() (nonce string, err error) { 10 | b := make([]byte, 16) 11 | _, err = rand.Read(b) 12 | if err != nil { 13 | return 14 | } 15 | nonce = fmt.Sprintf("%x", b) 16 | return 17 | } 18 | -------------------------------------------------------------------------------- /dist.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # build binary distributions for linux/amd64 and darwin/amd64 3 | set -e 4 | 5 | DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" 6 | echo "working dir $DIR" 7 | mkdir -p $DIR/dist 8 | dep ensure || exit 1 9 | 10 | os=$(go env GOOS) 11 | arch=$(go env GOARCH) 12 | version=$(cat $DIR/version.go | grep "const VERSION" | awk '{print $NF}' | sed 's/"//g') 13 | goversion=$(go version | awk '{print $3}') 14 | sha256sum=() 15 | 16 | echo "... running tests" 17 | ./test.sh 18 | 19 | for os in windows linux darwin; do 20 | echo "... 
building v$version for $os/$arch" 21 | EXT= 22 | if [ $os = windows ]; then 23 | EXT=".exe" 24 | fi 25 | BUILD=$(mktemp -d ${TMPDIR:-/tmp}/oauth2_proxy.XXXXXX) 26 | TARGET="oauth2_proxy-$version.$os-$arch.$goversion" 27 | FILENAME="oauth2_proxy-$version.$os-$arch$EXT" 28 | GOOS=$os GOARCH=$arch CGO_ENABLED=0 \ 29 | go build -ldflags="-s -w" -o $BUILD/$TARGET/$FILENAME || exit 1 30 | pushd $BUILD/$TARGET 31 | sha256sum+=("$(shasum -a 256 $FILENAME || exit 1)") 32 | cd .. && tar czvf $TARGET.tar.gz $TARGET 33 | mv $TARGET.tar.gz $DIR/dist 34 | popd 35 | done 36 | 37 | checksum_file="sha256sum.txt" 38 | cd $DIR/dist 39 | if [ -f $checksum_file ]; then 40 | rm $checksum_file 41 | fi 42 | touch $checksum_file 43 | for checksum in "${sha256sum[@]}"; do 44 | echo "$checksum" >> $checksum_file 45 | done 46 | -------------------------------------------------------------------------------- /env_options.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | "reflect" 6 | "strings" 7 | ) 8 | 9 | // EnvOptions holds program options loaded from the process environment 10 | type EnvOptions map[string]interface{} 11 | 12 | // LoadEnvForStruct loads environment variables for each field in an options 13 | // struct passed into it. 
14 | // 15 | // Fields in the options struct must have an `env` and `cfg` tag to be read 16 | // from the environment 17 | func (cfg EnvOptions) LoadEnvForStruct(options interface{}) { 18 | val := reflect.ValueOf(options).Elem() 19 | typ := val.Type() 20 | for i := 0; i < typ.NumField(); i++ { 21 | // pull out the struct tags: 22 | // flag - the name of the command line flag 23 | // deprecated - (optional) the name of the deprecated command line flag 24 | // cfg - (optional, defaults to underscored flag) the name of the config file option 25 | field := typ.Field(i) 26 | flagName := field.Tag.Get("flag") 27 | envName := field.Tag.Get("env") 28 | cfgName := field.Tag.Get("cfg") 29 | if cfgName == "" && flagName != "" { 30 | cfgName = strings.Replace(flagName, "-", "_", -1) 31 | } 32 | if envName == "" || cfgName == "" { 33 | // resolvable fields must have the `env` and `cfg` struct tag 34 | continue 35 | } 36 | v := os.Getenv(envName) 37 | if v != "" { 38 | cfg[cfgName] = v 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /env_options_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | type envTest struct { 11 | testField string `cfg:"target_field" env:"TEST_ENV_FIELD"` 12 | } 13 | 14 | func TestLoadEnvForStruct(t *testing.T) { 15 | 16 | cfg := make(EnvOptions) 17 | cfg.LoadEnvForStruct(&envTest{}) 18 | 19 | _, ok := cfg["target_field"] 20 | assert.Equal(t, ok, false) 21 | 22 | os.Setenv("TEST_ENV_FIELD", "1234abcd") 23 | cfg.LoadEnvForStruct(&envTest{}) 24 | v := cfg["target_field"] 25 | assert.Equal(t, v, "1234abcd") 26 | } 27 | -------------------------------------------------------------------------------- /htpasswd.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/sha1" 5 | 
"encoding/base64" 6 | "encoding/csv" 7 | "io" 8 | "log" 9 | "os" 10 | 11 | "golang.org/x/crypto/bcrypt" 12 | ) 13 | 14 | // Lookup passwords in a htpasswd file 15 | // Passwords must be generated with -B for bcrypt or -s for SHA1. 16 | 17 | // HtpasswdFile represents the structure of an htpasswd file 18 | type HtpasswdFile struct { 19 | Users map[string]string 20 | } 21 | 22 | // NewHtpasswdFromFile constructs an HtpasswdFile from the file at the path given 23 | func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) { 24 | r, err := os.Open(path) 25 | if err != nil { 26 | return nil, err 27 | } 28 | defer r.Close() 29 | return NewHtpasswd(r) 30 | } 31 | 32 | // NewHtpasswd consctructs an HtpasswdFile from an io.Reader (opened file) 33 | func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) { 34 | csvReader := csv.NewReader(file) 35 | csvReader.Comma = ':' 36 | csvReader.Comment = '#' 37 | csvReader.TrimLeadingSpace = true 38 | 39 | records, err := csvReader.ReadAll() 40 | if err != nil { 41 | return nil, err 42 | } 43 | h := &HtpasswdFile{Users: make(map[string]string)} 44 | for _, record := range records { 45 | h.Users[record[0]] = record[1] 46 | } 47 | return h, nil 48 | } 49 | 50 | // Validate checks a users password against the HtpasswdFile entries 51 | func (h *HtpasswdFile) Validate(user string, password string) bool { 52 | realPassword, exists := h.Users[user] 53 | if !exists { 54 | return false 55 | } 56 | 57 | shaPrefix := realPassword[:5] 58 | if shaPrefix == "{SHA}" { 59 | shaValue := realPassword[5:] 60 | d := sha1.New() 61 | d.Write([]byte(password)) 62 | return shaValue == base64.StdEncoding.EncodeToString(d.Sum(nil)) 63 | } 64 | 65 | bcryptPrefix := realPassword[:4] 66 | if bcryptPrefix == "$2a$" || bcryptPrefix == "$2b$" || bcryptPrefix == "$2x$" || bcryptPrefix == "$2y$" { 67 | return bcrypt.CompareHashAndPassword([]byte(realPassword), []byte(password)) == nil 68 | } 69 | 70 | log.Printf("Invalid htpasswd entry for %s. 
Must be a SHA or bcrypt entry.", user) 71 | return false 72 | } 73 | -------------------------------------------------------------------------------- /htpasswd_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "golang.org/x/crypto/bcrypt" 10 | ) 11 | 12 | func TestSHA(t *testing.T) { 13 | file := bytes.NewBuffer([]byte("testuser:{SHA}PaVBVZkYqAjCQCu6UBL2xgsnZhw=\n")) 14 | h, err := NewHtpasswd(file) 15 | assert.Equal(t, err, nil) 16 | 17 | valid := h.Validate("testuser", "asdf") 18 | assert.Equal(t, valid, true) 19 | } 20 | 21 | func TestBcrypt(t *testing.T) { 22 | hash1, err := bcrypt.GenerateFromPassword([]byte("password"), 1) 23 | assert.Equal(t, err, nil) 24 | hash2, err := bcrypt.GenerateFromPassword([]byte("top-secret"), 2) 25 | assert.Equal(t, err, nil) 26 | 27 | contents := fmt.Sprintf("testuser1:%s\ntestuser2:%s\n", hash1, hash2) 28 | file := bytes.NewBuffer([]byte(contents)) 29 | 30 | h, err := NewHtpasswd(file) 31 | assert.Equal(t, err, nil) 32 | 33 | valid := h.Validate("testuser1", "password") 34 | assert.Equal(t, valid, true) 35 | 36 | valid = h.Validate("testuser2", "top-secret") 37 | assert.Equal(t, valid, true) 38 | } 39 | -------------------------------------------------------------------------------- /http.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/tls" 5 | "log" 6 | "net" 7 | "net/http" 8 | "strings" 9 | "time" 10 | ) 11 | 12 | // Server represents an HTTP server 13 | type Server struct { 14 | Handler http.Handler 15 | Opts *Options 16 | } 17 | 18 | // ListenAndServe will serve traffic on HTTP or HTTPS depending on TLS options 19 | func (s *Server) ListenAndServe() { 20 | if s.Opts.TLSKeyFile != "" || s.Opts.TLSCertFile != "" { 21 | s.ServeHTTPS() 22 | } else { 23 | s.ServeHTTP() 24 | } 25 | } 26 | 27 | 
// Used with gcpHealthcheck() 28 | const userAgentHeader = "User-Agent" 29 | const googleHealthCheckUserAgent = "GoogleHC/1.0" 30 | const rootPath = "/" 31 | 32 | // gcpHealthcheck handles healthcheck queries from GCP. 33 | func gcpHealthcheck(h http.Handler) http.Handler { 34 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 35 | // Check for liveness and readiness: used for Google App Engine 36 | if r.URL.EscapedPath() == "/liveness_check" { 37 | w.WriteHeader(http.StatusOK) 38 | w.Write([]byte("OK")) 39 | return 40 | } 41 | if r.URL.EscapedPath() == "/readiness_check" { 42 | w.WriteHeader(http.StatusOK) 43 | w.Write([]byte("OK")) 44 | return 45 | } 46 | 47 | // Check for GKE ingress healthcheck: The ingress requires the root 48 | // path of the target to return a 200 (OK) to indicate the service's good health. This can be quite a challenging demand 49 | // depending on the application's path structure. This middleware filters out the requests from the health check by 50 | // 51 | // 1. checking that the request path is indeed the root path 52 | // 2. ensuring that the User-Agent is "GoogleHC/1.0", the health checker 53 | // 3. 
ensuring the request method is "GET" 54 | if r.URL.Path == rootPath && 55 | r.Header.Get(userAgentHeader) == googleHealthCheckUserAgent && 56 | r.Method == http.MethodGet { 57 | 58 | w.WriteHeader(http.StatusOK) 59 | return 60 | } 61 | 62 | h.ServeHTTP(w, r) 63 | }) 64 | } 65 | 66 | // ServeHTTP constructs a net.Listener and starts handling HTTP requests 67 | func (s *Server) ServeHTTP() { 68 | HTTPAddress := s.Opts.HTTPAddress 69 | var scheme string 70 | 71 | i := strings.Index(HTTPAddress, "://") 72 | if i > -1 { 73 | scheme = HTTPAddress[0:i] 74 | } 75 | 76 | var networkType string 77 | switch scheme { 78 | case "", "http": 79 | networkType = "tcp" 80 | default: 81 | networkType = scheme 82 | } 83 | 84 | slice := strings.SplitN(HTTPAddress, "//", 2) 85 | listenAddr := slice[len(slice)-1] 86 | 87 | listener, err := net.Listen(networkType, listenAddr) 88 | if err != nil { 89 | log.Fatalf("FATAL: listen (%s, %s) failed - %s", networkType, listenAddr, err) 90 | } 91 | log.Printf("HTTP: listening on %s", listenAddr) 92 | 93 | server := &http.Server{Handler: s.Handler} 94 | err = server.Serve(listener) 95 | if err != nil && !strings.Contains(err.Error(), "use of closed network connection") { 96 | log.Printf("ERROR: http.Serve() - %s", err) 97 | } 98 | 99 | log.Printf("HTTP: closing %s", listener.Addr()) 100 | } 101 | 102 | // ServeHTTPS constructs a net.Listener and starts handling HTTPS requests 103 | func (s *Server) ServeHTTPS() { 104 | addr := s.Opts.HTTPSAddress 105 | config := &tls.Config{ 106 | MinVersion: tls.VersionTLS12, 107 | MaxVersion: tls.VersionTLS12, 108 | } 109 | if config.NextProtos == nil { 110 | config.NextProtos = []string{"http/1.1"} 111 | } 112 | 113 | var err error 114 | config.Certificates = make([]tls.Certificate, 1) 115 | config.Certificates[0], err = tls.LoadX509KeyPair(s.Opts.TLSCertFile, s.Opts.TLSKeyFile) 116 | if err != nil { 117 | log.Fatalf("FATAL: loading tls config (%s, %s) failed - %s", s.Opts.TLSCertFile, s.Opts.TLSKeyFile, err) 
118 | } 119 | 120 | ln, err := net.Listen("tcp", addr) 121 | if err != nil { 122 | log.Fatalf("FATAL: listen (%s) failed - %s", addr, err) 123 | } 124 | log.Printf("HTTPS: listening on %s", ln.Addr()) 125 | 126 | tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config) 127 | srv := &http.Server{Handler: s.Handler} 128 | err = srv.Serve(tlsListener) 129 | 130 | if err != nil && !strings.Contains(err.Error(), "use of closed network connection") { 131 | log.Printf("ERROR: https.Serve() - %s", err) 132 | } 133 | 134 | log.Printf("HTTPS: closing %s", tlsListener.Addr()) 135 | } 136 | 137 | // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted 138 | // connections. It's used by ListenAndServe and ListenAndServeTLS so 139 | // dead TCP connections (e.g. closing laptop mid-download) eventually 140 | // go away. 141 | type tcpKeepAliveListener struct { 142 | *net.TCPListener 143 | } 144 | 145 | func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { 146 | tc, err := ln.AcceptTCP() 147 | if err != nil { 148 | return 149 | } 150 | tc.SetKeepAlive(true) 151 | tc.SetKeepAlivePeriod(3 * time.Minute) 152 | return tc, nil 153 | } 154 | -------------------------------------------------------------------------------- /http_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestGCPHealthcheckLiveness(t *testing.T) { 12 | handler := func(w http.ResponseWriter, req *http.Request) { 13 | w.Write([]byte("test")) 14 | } 15 | 16 | h := gcpHealthcheck(http.HandlerFunc(handler)) 17 | rw := httptest.NewRecorder() 18 | r, _ := http.NewRequest("GET", "/liveness_check", nil) 19 | r.RemoteAddr = "127.0.0.1" 20 | r.Host = "test-server" 21 | h.ServeHTTP(rw, r) 22 | 23 | assert.Equal(t, 200, rw.Code) 24 | assert.Equal(t, "OK", rw.Body.String()) 25 | } 26 | 27 | 
func TestGCPHealthcheckReadiness(t *testing.T) { 28 | handler := func(w http.ResponseWriter, req *http.Request) { 29 | w.Write([]byte("test")) 30 | } 31 | 32 | h := gcpHealthcheck(http.HandlerFunc(handler)) 33 | rw := httptest.NewRecorder() 34 | r, _ := http.NewRequest("GET", "/readiness_check", nil) 35 | r.RemoteAddr = "127.0.0.1" 36 | r.Host = "test-server" 37 | h.ServeHTTP(rw, r) 38 | 39 | assert.Equal(t, 200, rw.Code) 40 | assert.Equal(t, "OK", rw.Body.String()) 41 | } 42 | 43 | func TestGCPHealthcheckNotHealthcheck(t *testing.T) { 44 | handler := func(w http.ResponseWriter, req *http.Request) { 45 | w.Write([]byte("test")) 46 | } 47 | 48 | h := gcpHealthcheck(http.HandlerFunc(handler)) 49 | rw := httptest.NewRecorder() 50 | r, _ := http.NewRequest("GET", "/not_any_check", nil) 51 | r.RemoteAddr = "127.0.0.1" 52 | r.Host = "test-server" 53 | h.ServeHTTP(rw, r) 54 | 55 | assert.Equal(t, "test", rw.Body.String()) 56 | } 57 | 58 | func TestGCPHealthcheckIngress(t *testing.T) { 59 | handler := func(w http.ResponseWriter, req *http.Request) { 60 | w.Write([]byte("test")) 61 | } 62 | 63 | h := gcpHealthcheck(http.HandlerFunc(handler)) 64 | rw := httptest.NewRecorder() 65 | r, _ := http.NewRequest("GET", "/", nil) 66 | r.RemoteAddr = "127.0.0.1" 67 | r.Host = "test-server" 68 | r.Header.Set(userAgentHeader, googleHealthCheckUserAgent) 69 | h.ServeHTTP(rw, r) 70 | 71 | assert.Equal(t, 200, rw.Code) 72 | assert.Equal(t, "", rw.Body.String()) 73 | } 74 | 75 | func TestGCPHealthcheckNotIngress(t *testing.T) { 76 | handler := func(w http.ResponseWriter, req *http.Request) { 77 | w.Write([]byte("test")) 78 | } 79 | 80 | h := gcpHealthcheck(http.HandlerFunc(handler)) 81 | rw := httptest.NewRecorder() 82 | r, _ := http.NewRequest("GET", "/foo", nil) 83 | r.RemoteAddr = "127.0.0.1" 84 | r.Host = "test-server" 85 | r.Header.Set(userAgentHeader, googleHealthCheckUserAgent) 86 | h.ServeHTTP(rw, r) 87 | 88 | assert.Equal(t, "test", rw.Body.String()) 89 | } 90 | 91 | func 
TestGCPHealthcheckNotIngressPut(t *testing.T) { 92 | handler := func(w http.ResponseWriter, req *http.Request) { 93 | w.Write([]byte("test")) 94 | } 95 | 96 | h := gcpHealthcheck(http.HandlerFunc(handler)) 97 | rw := httptest.NewRecorder() 98 | r, _ := http.NewRequest("PUT", "/", nil) 99 | r.RemoteAddr = "127.0.0.1" 100 | r.Host = "test-server" 101 | r.Header.Set(userAgentHeader, googleHealthCheckUserAgent) 102 | h.ServeHTTP(rw, r) 103 | 104 | assert.Equal(t, "test", rw.Body.String()) 105 | } 106 | -------------------------------------------------------------------------------- /logging_handler.go: -------------------------------------------------------------------------------- 1 | // largely adapted from https://github.com/gorilla/handlers/blob/master/handlers.go 2 | // to add logging of request duration as last value (and drop referrer) 3 | 4 | package main 5 | 6 | import ( 7 | "bufio" 8 | "errors" 9 | "fmt" 10 | "io" 11 | "net" 12 | "net/http" 13 | "net/url" 14 | "text/template" 15 | "time" 16 | ) 17 | 18 | const ( 19 | defaultRequestLoggingFormat = "{{.Client}} - {{.Username}} [{{.Timestamp}}] {{.Host}} {{.RequestMethod}} {{.Upstream}} {{.RequestURI}} {{.Protocol}} {{.UserAgent}} {{.StatusCode}} {{.ResponseSize}} {{.RequestDuration}}" 20 | ) 21 | 22 | // responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status 23 | // code and body size 24 | type responseLogger struct { 25 | w http.ResponseWriter 26 | status int 27 | size int 28 | upstream string 29 | authInfo string 30 | } 31 | 32 | // Header returns the ResponseWriter's Header 33 | func (l *responseLogger) Header() http.Header { 34 | return l.w.Header() 35 | } 36 | 37 | // Support Websocket 38 | func (l *responseLogger) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { 39 | if hj, ok := l.w.(http.Hijacker); ok { 40 | return hj.Hijack() 41 | } 42 | return nil, nil, errors.New("http.Hijacker is not available on writer") 43 | } 44 | 45 | // ExtractGAPMetadata extracts and 
removes GAP headers from the ResponseWriter's 46 | // Header 47 | func (l *responseLogger) ExtractGAPMetadata() { 48 | upstream := l.w.Header().Get("GAP-Upstream-Address") 49 | if upstream != "" { 50 | l.upstream = upstream 51 | l.w.Header().Del("GAP-Upstream-Address") 52 | } 53 | authInfo := l.w.Header().Get("GAP-Auth") 54 | if authInfo != "" { 55 | l.authInfo = authInfo 56 | l.w.Header().Del("GAP-Auth") 57 | } 58 | } 59 | 60 | // Write writes the response using the ResponseWriter 61 | func (l *responseLogger) Write(b []byte) (int, error) { 62 | if l.status == 0 { 63 | // The status will be StatusOK if WriteHeader has not been called yet 64 | l.status = http.StatusOK 65 | } 66 | l.ExtractGAPMetadata() 67 | size, err := l.w.Write(b) 68 | l.size += size 69 | return size, err 70 | } 71 | 72 | // WriteHeader writes the status code for the Response 73 | func (l *responseLogger) WriteHeader(s int) { 74 | l.ExtractGAPMetadata() 75 | l.w.WriteHeader(s) 76 | l.status = s 77 | } 78 | 79 | // Status returns the response status code 80 | func (l *responseLogger) Status() int { 81 | return l.status 82 | } 83 | 84 | // Size returns teh response size 85 | func (l *responseLogger) Size() int { 86 | return l.size 87 | } 88 | 89 | func (l *responseLogger) Flush() { 90 | if flusher, ok := l.w.(http.Flusher); ok { 91 | flusher.Flush() 92 | } 93 | } 94 | 95 | // logMessageData is the container for all values that are available as variables in the request logging format. 96 | // All values are pre-formatted strings so it is easy to use them in the format string. 
97 | type logMessageData struct { 98 | Client, 99 | Host, 100 | Protocol, 101 | RequestDuration, 102 | RequestMethod, 103 | RequestURI, 104 | ResponseSize, 105 | StatusCode, 106 | Timestamp, 107 | Upstream, 108 | UserAgent, 109 | Username string 110 | } 111 | 112 | // loggingHandler is the http.Handler implementation for LoggingHandlerTo and its friends 113 | type loggingHandler struct { 114 | writer io.Writer 115 | handler http.Handler 116 | enabled bool 117 | logTemplate *template.Template 118 | } 119 | 120 | // LoggingHandler provides an http.Handler which logs requests to the HTTP server 121 | func LoggingHandler(out io.Writer, h http.Handler, v bool, requestLoggingTpl string) http.Handler { 122 | return loggingHandler{ 123 | writer: out, 124 | handler: h, 125 | enabled: v, 126 | logTemplate: template.Must(template.New("request-log").Parse(requestLoggingTpl)), 127 | } 128 | } 129 | 130 | func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { 131 | t := time.Now() 132 | url := *req.URL 133 | logger := &responseLogger{w: w} 134 | h.handler.ServeHTTP(logger, req) 135 | if !h.enabled { 136 | return 137 | } 138 | h.writeLogLine(logger.authInfo, logger.upstream, req, url, t, logger.Status(), logger.Size()) 139 | } 140 | 141 | // Log entry for req similar to Apache Common Log Format. 142 | // ts is the timestamp with which the entry should be logged. 143 | // status, size are used to provide the response HTTP status and size. 
144 | func (h loggingHandler) writeLogLine(username, upstream string, req *http.Request, url url.URL, ts time.Time, status int, size int) { 145 | if username == "" { 146 | username = "-" 147 | } 148 | if upstream == "" { 149 | upstream = "-" 150 | } 151 | if url.User != nil && username == "-" { 152 | if name := url.User.Username(); name != "" { 153 | username = name 154 | } 155 | } 156 | 157 | client := req.Header.Get("X-Real-IP") 158 | if client == "" { 159 | client = req.RemoteAddr 160 | } 161 | 162 | if c, _, err := net.SplitHostPort(client); err == nil { 163 | client = c 164 | } 165 | 166 | duration := float64(time.Now().Sub(ts)) / float64(time.Second) 167 | 168 | h.logTemplate.Execute(h.writer, logMessageData{ 169 | Client: client, 170 | Host: req.Host, 171 | Protocol: req.Proto, 172 | RequestDuration: fmt.Sprintf("%0.3f", duration), 173 | RequestMethod: req.Method, 174 | RequestURI: fmt.Sprintf("%q", url.RequestURI()), 175 | ResponseSize: fmt.Sprintf("%d", size), 176 | StatusCode: fmt.Sprintf("%d", status), 177 | Timestamp: ts.Format("02/Jan/2006:15:04:05 -0700"), 178 | Upstream: upstream, 179 | UserAgent: fmt.Sprintf("%q", req.UserAgent()), 180 | Username: username, 181 | }) 182 | 183 | h.writer.Write([]byte("\n")) 184 | } 185 | -------------------------------------------------------------------------------- /logging_handler_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "net/http" 7 | "net/http/httptest" 8 | "strings" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | func TestLoggingHandler_ServeHTTP(t *testing.T) { 14 | ts := time.Now() 15 | 16 | tests := []struct { 17 | Format, 18 | ExpectedLogMessage string 19 | }{ 20 | {defaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0", ts.Format("02/Jan/2006:15:04:05 -0700"))}, 21 | {"{{.RequestMethod}}", "GET\n"}, 22 | } 23 | 24 | for _, test := range tests { 
25 | buf := bytes.NewBuffer(nil) 26 | handler := func(w http.ResponseWriter, req *http.Request) { 27 | _, ok := w.(http.Hijacker) 28 | if !ok { 29 | t.Error("http.Hijacker is not available") 30 | } 31 | 32 | w.Write([]byte("test")) 33 | } 34 | 35 | h := LoggingHandler(buf, http.HandlerFunc(handler), true, test.Format) 36 | 37 | r, _ := http.NewRequest("GET", "/foo/bar", nil) 38 | r.RemoteAddr = "127.0.0.1" 39 | r.Host = "test-server" 40 | 41 | h.ServeHTTP(httptest.NewRecorder(), r) 42 | 43 | actual := buf.String() 44 | if !strings.Contains(actual, test.ExpectedLogMessage) { 45 | t.Errorf("Log message was\n%s\ninstead of matching \n%s", actual, test.ExpectedLogMessage) 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "log" 7 | "math/rand" 8 | "net/http" 9 | "os" 10 | "runtime" 11 | "strings" 12 | "time" 13 | 14 | "github.com/BurntSushi/toml" 15 | options "github.com/mreiferson/go-options" 16 | ) 17 | 18 | func main() { 19 | log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) 20 | flagSet := flag.NewFlagSet("oauth2_proxy", flag.ExitOnError) 21 | 22 | emailDomains := StringArray{} 23 | whitelistDomains := StringArray{} 24 | upstreams := StringArray{} 25 | skipAuthRegex := StringArray{} 26 | googleGroups := StringArray{} 27 | 28 | config := flagSet.String("config", "", "path to config file") 29 | showVersion := flagSet.Bool("version", false, "print version string") 30 | 31 | flagSet.String("http-address", "127.0.0.1:4180", "[http://]: or unix:// to listen on for HTTP clients") 32 | flagSet.String("https-address", ":443", ": to listen on for HTTPS clients") 33 | flagSet.String("tls-cert", "", "path to certificate file") 34 | flagSet.String("tls-key", "", "path to private key file") 35 | flagSet.String("redirect-url", "", "the OAuth Redirect URL. 
ie: \"https://internalapp.yourcompany.com/oauth2/callback\"") 36 | flagSet.Bool("set-xauthrequest", false, "set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode)") 37 | flagSet.Var(&upstreams, "upstream", "the http url(s) of the upstream endpoint or file:// paths for static files. Routing is based on the path") 38 | flagSet.Bool("pass-basic-auth", true, "pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream") 39 | flagSet.Bool("pass-user-headers", true, "pass X-Forwarded-User and X-Forwarded-Email information to upstream") 40 | flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header") 41 | flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header") 42 | flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream") 43 | flagSet.Bool("pass-authorization-header", false, "pass the Authorization Header to upstream") 44 | flagSet.Bool("set-authorization-header", false, "set Authorization response headers (useful in Nginx auth_request mode)") 45 | flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)") 46 | flagSet.Bool("skip-provider-button", false, "will skip sign-in-page to directly reach the next step: oauth/start") 47 | flagSet.Bool("skip-auth-preflight", false, "will skip authentication for OPTIONS requests") 48 | flagSet.Bool("ssl-insecure-skip-verify", false, "skip validation of certificates presented when using HTTPS") 49 | flagSet.Duration("flush-interval", time.Duration(1)*time.Second, "period between response flushing when streaming responses") 50 | 51 | flagSet.Var(&emailDomains, "email-domain", "authenticate emails with the specified domain (may be given multiple times). 
Use * to authenticate any email") 52 | flagSet.Var(&whitelistDomains, "whitelist-domain", "allowed domains for redirection after authentication. Prefix domain with a . to allow subdomains (eg .example.com)") 53 | flagSet.String("azure-tenant", "common", "go to a tenant-specific or common (tenant-independent) endpoint.") 54 | flagSet.String("github-org", "", "restrict logins to members of this organisation") 55 | flagSet.String("github-team", "", "restrict logins to members of this team") 56 | flagSet.Var(&googleGroups, "google-group", "restrict logins to members of this google group (may be given multiple times).") 57 | flagSet.String("google-admin-email", "", "the google admin to impersonate for api calls") 58 | flagSet.String("google-service-account-json", "", "the path to the service account json credentials") 59 | flagSet.String("client-id", "", "the OAuth Client ID: ie: \"123456.apps.googleusercontent.com\"") 60 | flagSet.String("client-secret", "", "the OAuth Client Secret") 61 | flagSet.String("authenticated-emails-file", "", "authenticate against emails via file (one per line)") 62 | flagSet.String("htpasswd-file", "", "additionally authenticate against a htpasswd file. Entries must be created with \"htpasswd -s\" for SHA encryption or \"htpasswd -B\" for bcrypt encryption") 63 | flagSet.Bool("display-htpasswd-form", true, "display username / password login form if an htpasswd file is provided") 64 | flagSet.String("custom-templates-dir", "", "path to custom html templates") 65 | flagSet.String("footer", "", "custom footer string. Use \"-\" to disable default footer.") 66 | flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. 
//sign_in)") 67 | flagSet.Bool("proxy-websockets", true, "enables WebSocket proxying") 68 | 69 | flagSet.String("cookie-name", "_oauth2_proxy", "the name of the cookie that the oauth_proxy creates") 70 | flagSet.String("cookie-secret", "", "the seed string for secure cookies (optionally base64 encoded)") 71 | flagSet.String("cookie-domain", "", "an optional cookie domain to force cookies to (ie: .yourcompany.com)*") 72 | flagSet.String("cookie-path", "/", "an optional cookie path to force cookies to (ie: /poc/)*") 73 | flagSet.Duration("cookie-expire", time.Duration(168)*time.Hour, "expire timeframe for cookie") 74 | flagSet.Duration("cookie-refresh", time.Duration(0), "refresh the cookie after this duration; 0 to disable") 75 | flagSet.Bool("cookie-secure", true, "set secure (HTTPS) cookie flag") 76 | flagSet.Bool("cookie-httponly", true, "set HttpOnly cookie flag") 77 | 78 | flagSet.Bool("request-logging", true, "Log requests to stdout") 79 | flagSet.String("request-logging-format", defaultRequestLoggingFormat, "Template for log lines") 80 | 81 | flagSet.String("provider", "google", "OAuth provider") 82 | flagSet.String("oidc-issuer-url", "", "OpenID Connect issuer URL (ie: https://accounts.google.com)") 83 | flagSet.Bool("skip-oidc-discovery", false, "Skip OIDC discovery and use manually supplied Endpoints") 84 | flagSet.String("oidc-jwks-url", "", "OpenID Connect JWKS URL (ie: https://www.googleapis.com/oauth2/v3/certs)") 85 | flagSet.String("login-url", "", "Authentication endpoint") 86 | flagSet.String("redeem-url", "", "Token redemption endpoint") 87 | flagSet.String("profile-url", "", "Profile access endpoint") 88 | flagSet.String("resource", "", "The resource that is protected (Azure AD only)") 89 | flagSet.String("validate-url", "", "Access token validation endpoint") 90 | flagSet.String("scope", "", "OAuth scope specification") 91 | flagSet.String("approval-prompt", "force", "OAuth approval_prompt") 92 | 93 | flagSet.String("signature-key", "", 
"GAP-Signature request signature key (algorithm:secretkey)") 94 | flagSet.String("acr-values", "http://idmanagement.gov/ns/assurance/loa/1", "acr values string: optional, used by login.gov") 95 | flagSet.String("jwt-key", "", "private key in PEM format used to sign JWT, so that you can say something like -jwt-key=\"${OAUTH2_PROXY_JWT_KEY}\": required by login.gov") 96 | flagSet.String("jwt-key-file", "", "path to the private key file in PEM format used to sign the JWT so that you can say something like -jwt-key-file=/etc/ssl/private/jwt_signing_key.pem: required by login.gov") 97 | flagSet.String("pubjwk-url", "", "JWK pubkey access endpoint: required by login.gov") 98 | flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints") 99 | 100 | flagSet.Parse(os.Args[1:]) 101 | 102 | if *showVersion { 103 | fmt.Printf("oauth2_proxy %s (built with %s)\n", VERSION, runtime.Version()) 104 | return 105 | } 106 | 107 | opts := NewOptions() 108 | 109 | cfg := make(EnvOptions) 110 | if *config != "" { 111 | _, err := toml.DecodeFile(*config, &cfg) 112 | if err != nil { 113 | log.Fatalf("ERROR: failed to load config file %s - %s", *config, err) 114 | } 115 | } 116 | cfg.LoadEnvForStruct(opts) 117 | options.Resolve(opts, flagSet, cfg) 118 | 119 | err := opts.Validate() 120 | if err != nil { 121 | log.Printf("%s", err) 122 | os.Exit(1) 123 | } 124 | validator := NewValidator(opts.EmailDomains, opts.AuthenticatedEmailsFile) 125 | oauthproxy := NewOAuthProxy(opts, validator) 126 | 127 | if len(opts.EmailDomains) != 0 && opts.AuthenticatedEmailsFile == "" { 128 | if len(opts.EmailDomains) > 1 { 129 | oauthproxy.SignInMessage = fmt.Sprintf("Authenticate using one of the following domains: %v", strings.Join(opts.EmailDomains, ", ")) 130 | } else if opts.EmailDomains[0] != "*" { 131 | oauthproxy.SignInMessage = fmt.Sprintf("Authenticate using %v", opts.EmailDomains[0]) 132 | } 133 | } 134 | 135 | if opts.HtpasswdFile != "" { 136 | log.Printf("using htpasswd file 
%s", opts.HtpasswdFile) 137 | oauthproxy.HtpasswdFile, err = NewHtpasswdFromFile(opts.HtpasswdFile) 138 | oauthproxy.DisplayHtpasswdForm = opts.DisplayHtpasswdForm 139 | if err != nil { 140 | log.Fatalf("FATAL: unable to open %s %s", opts.HtpasswdFile, err) 141 | } 142 | } 143 | 144 | rand.Seed(time.Now().UnixNano()) 145 | 146 | var handler http.Handler 147 | if opts.GCPHealthChecks { 148 | handler = gcpHealthcheck(LoggingHandler(os.Stdout, oauthproxy, opts.RequestLogging, opts.RequestLoggingFormat)) 149 | } else { 150 | handler = LoggingHandler(os.Stdout, oauthproxy, opts.RequestLogging, opts.RequestLoggingFormat) 151 | } 152 | s := &Server{ 153 | Handler: handler, 154 | Opts: opts, 155 | } 156 | s.ListenAndServe() 157 | } 158 | -------------------------------------------------------------------------------- /options_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto" 5 | "fmt" 6 | "net/url" 7 | "strings" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func testOptions() *Options { 15 | o := NewOptions() 16 | o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8080/") 17 | o.CookieSecret = "foobar" 18 | o.ClientID = "bazquux" 19 | o.ClientSecret = "xyzzyplugh" 20 | o.EmailDomains = []string{"*"} 21 | return o 22 | } 23 | 24 | func errorMsg(msgs []string) string { 25 | result := make([]string, 0) 26 | result = append(result, "Invalid configuration:") 27 | result = append(result, msgs...) 
28 | return strings.Join(result, "\n ") 29 | } 30 | 31 | func TestNewOptions(t *testing.T) { 32 | o := NewOptions() 33 | o.EmailDomains = []string{"*"} 34 | err := o.Validate() 35 | assert.NotEqual(t, nil, err) 36 | 37 | expected := errorMsg([]string{ 38 | "missing setting: cookie-secret", 39 | "missing setting: client-id", 40 | "missing setting: client-secret"}) 41 | assert.Equal(t, expected, err.Error()) 42 | } 43 | 44 | func TestGoogleGroupOptions(t *testing.T) { 45 | o := testOptions() 46 | o.GoogleGroups = []string{"googlegroup"} 47 | err := o.Validate() 48 | assert.NotEqual(t, nil, err) 49 | 50 | expected := errorMsg([]string{ 51 | "missing setting: google-admin-email", 52 | "missing setting: google-service-account-json"}) 53 | assert.Equal(t, expected, err.Error()) 54 | } 55 | 56 | func TestGoogleGroupInvalidFile(t *testing.T) { 57 | o := testOptions() 58 | o.GoogleGroups = []string{"test_group"} 59 | o.GoogleAdminEmail = "admin@example.com" 60 | o.GoogleServiceAccountJSON = "file_doesnt_exist.json" 61 | err := o.Validate() 62 | assert.NotEqual(t, nil, err) 63 | 64 | expected := errorMsg([]string{ 65 | "invalid Google credentials file: file_doesnt_exist.json", 66 | }) 67 | assert.Equal(t, expected, err.Error()) 68 | } 69 | 70 | func TestInitializedOptions(t *testing.T) { 71 | o := testOptions() 72 | assert.Equal(t, nil, o.Validate()) 73 | } 74 | 75 | // Note that it's not worth testing nonparseable URLs, since url.Parse() 76 | // seems to parse damn near anything. 
77 | func TestRedirectURL(t *testing.T) { 78 | o := testOptions() 79 | o.RedirectURL = "https://myhost.com/oauth2/callback" 80 | assert.Equal(t, nil, o.Validate()) 81 | expected := &url.URL{ 82 | Scheme: "https", Host: "myhost.com", Path: "/oauth2/callback"} 83 | assert.Equal(t, expected, o.redirectURL) 84 | } 85 | 86 | func TestProxyURLs(t *testing.T) { 87 | o := testOptions() 88 | o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8081") 89 | assert.Equal(t, nil, o.Validate()) 90 | expected := []*url.URL{ 91 | {Scheme: "http", Host: "127.0.0.1:8080", Path: "/"}, 92 | // note the '/' was added 93 | {Scheme: "http", Host: "127.0.0.1:8081", Path: "/"}, 94 | } 95 | assert.Equal(t, expected, o.proxyURLs) 96 | } 97 | 98 | func TestProxyURLsError(t *testing.T) { 99 | o := testOptions() 100 | o.Upstreams = append(o.Upstreams, "127.0.0.1:8081") 101 | err := o.Validate() 102 | assert.NotEqual(t, nil, err) 103 | 104 | expected := errorMsg([]string{ 105 | "error parsing upstream: parse 127.0.0.1:8081: " + 106 | "first path segment in URL cannot contain colon"}) 107 | assert.Equal(t, expected, err.Error()) 108 | } 109 | 110 | func TestCompiledRegex(t *testing.T) { 111 | o := testOptions() 112 | regexps := []string{"/foo/.*", "/ba[rz]/quux"} 113 | o.SkipAuthRegex = regexps 114 | assert.Equal(t, nil, o.Validate()) 115 | actual := make([]string, 0) 116 | for _, regex := range o.CompiledRegex { 117 | actual = append(actual, regex.String()) 118 | } 119 | assert.Equal(t, regexps, actual) 120 | } 121 | 122 | func TestCompiledRegexError(t *testing.T) { 123 | o := testOptions() 124 | o.SkipAuthRegex = []string{"(foobaz", "barquux)"} 125 | err := o.Validate() 126 | assert.NotEqual(t, nil, err) 127 | 128 | expected := errorMsg([]string{ 129 | "error compiling regex=\"(foobaz\" error parsing regexp: " + 130 | "missing closing ): `(foobaz`", 131 | "error compiling regex=\"barquux)\" error parsing regexp: " + 132 | "unexpected ): `barquux)`"}) 133 | assert.Equal(t, expected, err.Error()) 
134 | 135 | o.SkipAuthRegex = []string{"foobaz", "barquux)"} 136 | err = o.Validate() 137 | assert.NotEqual(t, nil, err) 138 | 139 | expected = errorMsg([]string{ 140 | "error compiling regex=\"barquux)\" error parsing regexp: " + 141 | "unexpected ): `barquux)`"}) 142 | assert.Equal(t, expected, err.Error()) 143 | } 144 | 145 | func TestDefaultProviderApiSettings(t *testing.T) { 146 | o := testOptions() 147 | assert.Equal(t, nil, o.Validate()) 148 | p := o.provider.Data() 149 | assert.Equal(t, "https://accounts.google.com/o/oauth2/auth?access_type=offline", 150 | p.LoginURL.String()) 151 | assert.Equal(t, "https://www.googleapis.com/oauth2/v3/token", 152 | p.RedeemURL.String()) 153 | assert.Equal(t, "", p.ProfileURL.String()) 154 | assert.Equal(t, "profile email", p.Scope) 155 | } 156 | 157 | func TestPassAccessTokenRequiresSpecificCookieSecretLengths(t *testing.T) { 158 | o := testOptions() 159 | assert.Equal(t, nil, o.Validate()) 160 | 161 | assert.Equal(t, false, o.PassAccessToken) 162 | o.PassAccessToken = true 163 | o.CookieSecret = "cookie of invalid length-" 164 | assert.NotEqual(t, nil, o.Validate()) 165 | 166 | o.PassAccessToken = false 167 | o.CookieRefresh = time.Duration(24) * time.Hour 168 | assert.NotEqual(t, nil, o.Validate()) 169 | 170 | o.CookieSecret = "16 bytes AES-128" 171 | assert.Equal(t, nil, o.Validate()) 172 | 173 | o.CookieSecret = "24 byte secret AES-192--" 174 | assert.Equal(t, nil, o.Validate()) 175 | 176 | o.CookieSecret = "32 byte secret for AES-256------" 177 | assert.Equal(t, nil, o.Validate()) 178 | } 179 | 180 | func TestCookieRefreshMustBeLessThanCookieExpire(t *testing.T) { 181 | o := testOptions() 182 | assert.Equal(t, nil, o.Validate()) 183 | 184 | o.CookieSecret = "0123456789abcdefabcd" 185 | o.CookieRefresh = o.CookieExpire 186 | assert.NotEqual(t, nil, o.Validate()) 187 | 188 | o.CookieRefresh -= time.Duration(1) 189 | assert.Equal(t, nil, o.Validate()) 190 | } 191 | 192 | func TestBase64CookieSecret(t *testing.T) { 193 | 
o := testOptions() 194 | assert.Equal(t, nil, o.Validate()) 195 | 196 | // 32 byte, base64 (urlsafe) encoded key 197 | o.CookieSecret = "yHBw2lh2Cvo6aI_jn_qMTr-pRAjtq0nzVgDJNb36jgQ=" 198 | assert.Equal(t, nil, o.Validate()) 199 | 200 | // 32 byte, base64 (urlsafe) encoded key, w/o padding 201 | o.CookieSecret = "yHBw2lh2Cvo6aI_jn_qMTr-pRAjtq0nzVgDJNb36jgQ" 202 | assert.Equal(t, nil, o.Validate()) 203 | 204 | // 24 byte, base64 (urlsafe) encoded key 205 | o.CookieSecret = "Kp33Gj-GQmYtz4zZUyUDdqQKx5_Hgkv3" 206 | assert.Equal(t, nil, o.Validate()) 207 | 208 | // 16 byte, base64 (urlsafe) encoded key 209 | o.CookieSecret = "LFEqZYvYUwKwzn0tEuTpLA==" 210 | assert.Equal(t, nil, o.Validate()) 211 | 212 | // 16 byte, base64 (urlsafe) encoded key, w/o padding 213 | o.CookieSecret = "LFEqZYvYUwKwzn0tEuTpLA" 214 | assert.Equal(t, nil, o.Validate()) 215 | } 216 | 217 | func TestValidateSignatureKey(t *testing.T) { 218 | o := testOptions() 219 | o.SignatureKey = "sha1:secret" 220 | assert.Equal(t, nil, o.Validate()) 221 | assert.Equal(t, o.signatureData.hash, crypto.SHA1) 222 | assert.Equal(t, o.signatureData.key, "secret") 223 | } 224 | 225 | func TestValidateSignatureKeyInvalidSpec(t *testing.T) { 226 | o := testOptions() 227 | o.SignatureKey = "invalid spec" 228 | err := o.Validate() 229 | assert.Equal(t, err.Error(), "Invalid configuration:\n"+ 230 | " invalid signature hash:key spec: "+o.SignatureKey) 231 | } 232 | 233 | func TestValidateSignatureKeyUnsupportedAlgorithm(t *testing.T) { 234 | o := testOptions() 235 | o.SignatureKey = "unsupported:default secret" 236 | err := o.Validate() 237 | assert.Equal(t, err.Error(), "Invalid configuration:\n"+ 238 | " unsupported signature hash algorithm: "+o.SignatureKey) 239 | } 240 | 241 | func TestValidateCookie(t *testing.T) { 242 | o := testOptions() 243 | o.CookieName = "_valid_cookie_name" 244 | assert.Equal(t, nil, o.Validate()) 245 | } 246 | 247 | func TestValidateCookieBadName(t *testing.T) { 248 | o := testOptions() 249 | 
o.CookieName = "_bad_cookie_name{}" 250 | err := o.Validate() 251 | assert.Equal(t, err.Error(), "Invalid configuration:\n"+ 252 | fmt.Sprintf(" invalid cookie name: %q", o.CookieName)) 253 | } 254 | 255 | func TestSkipOIDCDiscovery(t *testing.T) { 256 | o := testOptions() 257 | o.Provider = "oidc" 258 | o.OIDCIssuerURL = "https://login.microsoftonline.com/fabrikamb2c.onmicrosoft.com/v2.0/" 259 | o.SkipOIDCDiscovery = true 260 | 261 | err := o.Validate() 262 | assert.Equal(t, "Invalid configuration:\n"+ 263 | fmt.Sprintf(" missing setting: login-url\n missing setting: redeem-url\n missing setting: oidc-jwks-url"), err.Error()) 264 | 265 | o.LoginURL = "https://login.microsoftonline.com/fabrikamb2c.onmicrosoft.com/oauth2/v2.0/authorize?p=b2c_1_sign_in" 266 | o.RedeemURL = "https://login.microsoftonline.com/fabrikamb2c.onmicrosoft.com/oauth2/v2.0/token?p=b2c_1_sign_in" 267 | o.OIDCJwksURL = "https://login.microsoftonline.com/fabrikamb2c.onmicrosoft.com/discovery/v2.0/keys" 268 | 269 | assert.Equal(t, nil, o.Validate()) 270 | } 271 | 272 | func TestGCPHealthcheck(t *testing.T) { 273 | o := testOptions() 274 | o.GCPHealthChecks = true 275 | assert.Equal(t, nil, o.Validate()) 276 | } 277 | -------------------------------------------------------------------------------- /providers/azure.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "log" 7 | "net/http" 8 | "net/url" 9 | 10 | "github.com/bitly/go-simplejson" 11 | "github.com/pusher/oauth2_proxy/api" 12 | ) 13 | 14 | // AzureProvider represents an Azure based Identity Provider 15 | type AzureProvider struct { 16 | *ProviderData 17 | Tenant string 18 | } 19 | 20 | // NewAzureProvider initiates a new AzureProvider 21 | func NewAzureProvider(p *ProviderData) *AzureProvider { 22 | p.ProviderName = "Azure" 23 | 24 | if p.ProfileURL == nil || p.ProfileURL.String() == "" { 25 | p.ProfileURL = &url.URL{ 26 | Scheme: "https", 27 
| Host: "graph.windows.net", 28 | Path: "/me", 29 | RawQuery: "api-version=1.6", 30 | } 31 | } 32 | if p.ProtectedResource == nil || p.ProtectedResource.String() == "" { 33 | p.ProtectedResource = &url.URL{ 34 | Scheme: "https", 35 | Host: "graph.windows.net", 36 | } 37 | } 38 | if p.Scope == "" { 39 | p.Scope = "openid" 40 | } 41 | 42 | return &AzureProvider{ProviderData: p} 43 | } 44 | 45 | // Configure defaults the AzureProvider configuration options 46 | func (p *AzureProvider) Configure(tenant string) { 47 | p.Tenant = tenant 48 | if tenant == "" { 49 | p.Tenant = "common" 50 | } 51 | 52 | if p.LoginURL == nil || p.LoginURL.String() == "" { 53 | p.LoginURL = &url.URL{ 54 | Scheme: "https", 55 | Host: "login.microsoftonline.com", 56 | Path: "/" + p.Tenant + "/oauth2/authorize"} 57 | } 58 | if p.RedeemURL == nil || p.RedeemURL.String() == "" { 59 | p.RedeemURL = &url.URL{ 60 | Scheme: "https", 61 | Host: "login.microsoftonline.com", 62 | Path: "/" + p.Tenant + "/oauth2/token", 63 | } 64 | } 65 | } 66 | 67 | func getAzureHeader(accessToken string) http.Header { 68 | header := make(http.Header) 69 | header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) 70 | return header 71 | } 72 | 73 | func getEmailFromJSON(json *simplejson.Json) (string, error) { 74 | var email string 75 | var err error 76 | 77 | email, err = json.Get("mail").String() 78 | 79 | if err != nil || email == "" { 80 | otherMails, otherMailsErr := json.Get("otherMails").Array() 81 | if len(otherMails) > 0 { 82 | email = otherMails[0].(string) 83 | } 84 | err = otherMailsErr 85 | } 86 | 87 | return email, err 88 | } 89 | 90 | // GetEmailAddress returns the Account email address 91 | func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) { 92 | var email string 93 | var err error 94 | 95 | if s.AccessToken == "" { 96 | return "", errors.New("missing access token") 97 | } 98 | req, err := http.NewRequest("GET", p.ProfileURL.String(), nil) 99 | if err != nil { 100 | return 
"", err 101 | } 102 | req.Header = getAzureHeader(s.AccessToken) 103 | 104 | json, err := api.Request(req) 105 | 106 | if err != nil { 107 | return "", err 108 | } 109 | 110 | email, err = getEmailFromJSON(json) 111 | 112 | if err == nil && email != "" { 113 | return email, err 114 | } 115 | 116 | email, err = json.Get("userPrincipalName").String() 117 | 118 | if err != nil { 119 | log.Printf("failed making request %s", err) 120 | return "", err 121 | } 122 | 123 | if email == "" { 124 | log.Printf("failed to get email address") 125 | return "", err 126 | } 127 | 128 | return email, err 129 | } 130 | -------------------------------------------------------------------------------- /providers/azure_test.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "net/url" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func testAzureProvider(hostname string) *AzureProvider { 13 | p := NewAzureProvider( 14 | &ProviderData{ 15 | ProviderName: "", 16 | LoginURL: &url.URL{}, 17 | RedeemURL: &url.URL{}, 18 | ProfileURL: &url.URL{}, 19 | ValidateURL: &url.URL{}, 20 | ProtectedResource: &url.URL{}, 21 | Scope: ""}) 22 | if hostname != "" { 23 | updateURL(p.Data().LoginURL, hostname) 24 | updateURL(p.Data().RedeemURL, hostname) 25 | updateURL(p.Data().ProfileURL, hostname) 26 | updateURL(p.Data().ValidateURL, hostname) 27 | updateURL(p.Data().ProtectedResource, hostname) 28 | } 29 | return p 30 | } 31 | 32 | func TestAzureProviderDefaults(t *testing.T) { 33 | p := testAzureProvider("") 34 | assert.NotEqual(t, nil, p) 35 | p.Configure("") 36 | assert.Equal(t, "Azure", p.Data().ProviderName) 37 | assert.Equal(t, "common", p.Tenant) 38 | assert.Equal(t, "https://login.microsoftonline.com/common/oauth2/authorize", 39 | p.Data().LoginURL.String()) 40 | assert.Equal(t, "https://login.microsoftonline.com/common/oauth2/token", 41 | 
p.Data().RedeemURL.String()) 42 | assert.Equal(t, "https://graph.windows.net/me?api-version=1.6", 43 | p.Data().ProfileURL.String()) 44 | assert.Equal(t, "https://graph.windows.net", 45 | p.Data().ProtectedResource.String()) 46 | assert.Equal(t, "", 47 | p.Data().ValidateURL.String()) 48 | assert.Equal(t, "openid", p.Data().Scope) 49 | } 50 | 51 | func TestAzureProviderOverrides(t *testing.T) { 52 | p := NewAzureProvider( 53 | &ProviderData{ 54 | LoginURL: &url.URL{ 55 | Scheme: "https", 56 | Host: "example.com", 57 | Path: "/oauth/auth"}, 58 | RedeemURL: &url.URL{ 59 | Scheme: "https", 60 | Host: "example.com", 61 | Path: "/oauth/token"}, 62 | ProfileURL: &url.URL{ 63 | Scheme: "https", 64 | Host: "example.com", 65 | Path: "/oauth/profile"}, 66 | ValidateURL: &url.URL{ 67 | Scheme: "https", 68 | Host: "example.com", 69 | Path: "/oauth/tokeninfo"}, 70 | ProtectedResource: &url.URL{ 71 | Scheme: "https", 72 | Host: "example.com"}, 73 | Scope: "profile"}) 74 | assert.NotEqual(t, nil, p) 75 | assert.Equal(t, "Azure", p.Data().ProviderName) 76 | assert.Equal(t, "https://example.com/oauth/auth", 77 | p.Data().LoginURL.String()) 78 | assert.Equal(t, "https://example.com/oauth/token", 79 | p.Data().RedeemURL.String()) 80 | assert.Equal(t, "https://example.com/oauth/profile", 81 | p.Data().ProfileURL.String()) 82 | assert.Equal(t, "https://example.com/oauth/tokeninfo", 83 | p.Data().ValidateURL.String()) 84 | assert.Equal(t, "https://example.com", 85 | p.Data().ProtectedResource.String()) 86 | assert.Equal(t, "profile", p.Data().Scope) 87 | } 88 | 89 | func TestAzureSetTenant(t *testing.T) { 90 | p := testAzureProvider("") 91 | p.Configure("example") 92 | assert.Equal(t, "Azure", p.Data().ProviderName) 93 | assert.Equal(t, "example", p.Tenant) 94 | assert.Equal(t, "https://login.microsoftonline.com/example/oauth2/authorize", 95 | p.Data().LoginURL.String()) 96 | assert.Equal(t, "https://login.microsoftonline.com/example/oauth2/token", 97 | p.Data().RedeemURL.String()) 98 | 
assert.Equal(t, "https://graph.windows.net/me?api-version=1.6", 99 | p.Data().ProfileURL.String()) 100 | assert.Equal(t, "https://graph.windows.net", 101 | p.Data().ProtectedResource.String()) 102 | assert.Equal(t, "", 103 | p.Data().ValidateURL.String()) 104 | assert.Equal(t, "openid", p.Data().Scope) 105 | } 106 | 107 | func testAzureBackend(payload string) *httptest.Server { 108 | path := "/me" 109 | query := "api-version=1.6" 110 | 111 | return httptest.NewServer(http.HandlerFunc( 112 | func(w http.ResponseWriter, r *http.Request) { 113 | if r.URL.Path != path || r.URL.RawQuery != query { 114 | w.WriteHeader(404) 115 | } else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" { 116 | w.WriteHeader(403) 117 | } else { 118 | w.WriteHeader(200) 119 | w.Write([]byte(payload)) 120 | } 121 | })) 122 | } 123 | 124 | func TestAzureProviderGetEmailAddress(t *testing.T) { 125 | b := testAzureBackend(`{ "mail": "user@windows.net" }`) 126 | defer b.Close() 127 | 128 | bURL, _ := url.Parse(b.URL) 129 | p := testAzureProvider(bURL.Host) 130 | 131 | session := &SessionState{AccessToken: "imaginary_access_token"} 132 | email, err := p.GetEmailAddress(session) 133 | assert.Equal(t, nil, err) 134 | assert.Equal(t, "user@windows.net", email) 135 | } 136 | 137 | func TestAzureProviderGetEmailAddressMailNull(t *testing.T) { 138 | b := testAzureBackend(`{ "mail": null, "otherMails": ["user@windows.net", "altuser@windows.net"] }`) 139 | defer b.Close() 140 | 141 | bURL, _ := url.Parse(b.URL) 142 | p := testAzureProvider(bURL.Host) 143 | 144 | session := &SessionState{AccessToken: "imaginary_access_token"} 145 | email, err := p.GetEmailAddress(session) 146 | assert.Equal(t, nil, err) 147 | assert.Equal(t, "user@windows.net", email) 148 | } 149 | 150 | func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) { 151 | b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "user@windows.net" }`) 152 | defer b.Close() 153 | 154 | bURL, 
_ := url.Parse(b.URL) 155 | p := testAzureProvider(bURL.Host) 156 | 157 | session := &SessionState{AccessToken: "imaginary_access_token"} 158 | email, err := p.GetEmailAddress(session) 159 | assert.Equal(t, nil, err) 160 | assert.Equal(t, "user@windows.net", email) 161 | } 162 | 163 | func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) { 164 | b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": null }`) 165 | defer b.Close() 166 | 167 | bURL, _ := url.Parse(b.URL) 168 | p := testAzureProvider(bURL.Host) 169 | 170 | session := &SessionState{AccessToken: "imaginary_access_token"} 171 | email, err := p.GetEmailAddress(session) 172 | assert.Equal(t, "type assertion to string failed", err.Error()) 173 | assert.Equal(t, "", email) 174 | } 175 | 176 | func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) { 177 | b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "" }`) 178 | defer b.Close() 179 | 180 | bURL, _ := url.Parse(b.URL) 181 | p := testAzureProvider(bURL.Host) 182 | 183 | session := &SessionState{AccessToken: "imaginary_access_token"} 184 | email, err := p.GetEmailAddress(session) 185 | assert.Equal(t, nil, err) 186 | assert.Equal(t, "", email) 187 | } 188 | 189 | func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) { 190 | b := testAzureBackend(`{ "mail": null, "otherMails": "", "userPrincipalName": null }`) 191 | defer b.Close() 192 | 193 | bURL, _ := url.Parse(b.URL) 194 | p := testAzureProvider(bURL.Host) 195 | 196 | session := &SessionState{AccessToken: "imaginary_access_token"} 197 | email, err := p.GetEmailAddress(session) 198 | assert.Equal(t, "type assertion to string failed", err.Error()) 199 | assert.Equal(t, "", email) 200 | } 201 | -------------------------------------------------------------------------------- /providers/facebook.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 
3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | "net/url" 8 | 9 | "github.com/pusher/oauth2_proxy/api" 10 | ) 11 | 12 | // FacebookProvider represents an Facebook based Identity Provider 13 | type FacebookProvider struct { 14 | *ProviderData 15 | } 16 | 17 | // NewFacebookProvider initiates a new FacebookProvider 18 | func NewFacebookProvider(p *ProviderData) *FacebookProvider { 19 | p.ProviderName = "Facebook" 20 | if p.LoginURL.String() == "" { 21 | p.LoginURL = &url.URL{Scheme: "https", 22 | Host: "www.facebook.com", 23 | Path: "/v2.5/dialog/oauth", 24 | // ?granted_scopes=true 25 | } 26 | } 27 | if p.RedeemURL.String() == "" { 28 | p.RedeemURL = &url.URL{Scheme: "https", 29 | Host: "graph.facebook.com", 30 | Path: "/v2.5/oauth/access_token", 31 | } 32 | } 33 | if p.ProfileURL.String() == "" { 34 | p.ProfileURL = &url.URL{Scheme: "https", 35 | Host: "graph.facebook.com", 36 | Path: "/v2.5/me", 37 | } 38 | } 39 | if p.ValidateURL.String() == "" { 40 | p.ValidateURL = p.ProfileURL 41 | } 42 | if p.Scope == "" { 43 | p.Scope = "public_profile email" 44 | } 45 | return &FacebookProvider{ProviderData: p} 46 | } 47 | 48 | func getFacebookHeader(accessToken string) http.Header { 49 | header := make(http.Header) 50 | header.Set("Accept", "application/json") 51 | header.Set("x-li-format", "json") 52 | header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) 53 | return header 54 | } 55 | 56 | // GetEmailAddress returns the Account email address 57 | func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) { 58 | if s.AccessToken == "" { 59 | return "", errors.New("missing access token") 60 | } 61 | req, err := http.NewRequest("GET", p.ProfileURL.String()+"?fields=name,email", nil) 62 | if err != nil { 63 | return "", err 64 | } 65 | req.Header = getFacebookHeader(s.AccessToken) 66 | 67 | type result struct { 68 | Email string 69 | } 70 | var r result 71 | err = api.RequestJSON(req, &r) 72 | if err != nil { 73 | return "", err 74 | } 75 
| if r.Email == "" { 76 | return "", errors.New("no email") 77 | } 78 | return r.Email, nil 79 | } 80 | 81 | // ValidateSessionState validates the AccessToken 82 | func (p *FacebookProvider) ValidateSessionState(s *SessionState) bool { 83 | return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken)) 84 | } 85 | -------------------------------------------------------------------------------- /providers/github.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io/ioutil" 7 | "log" 8 | "net/http" 9 | "net/url" 10 | "path" 11 | "strconv" 12 | "strings" 13 | ) 14 | 15 | // GitHubProvider represents an GitHub based Identity Provider 16 | type GitHubProvider struct { 17 | *ProviderData 18 | Org string 19 | Team string 20 | } 21 | 22 | // NewGitHubProvider initiates a new GitHubProvider 23 | func NewGitHubProvider(p *ProviderData) *GitHubProvider { 24 | p.ProviderName = "GitHub" 25 | if p.LoginURL == nil || p.LoginURL.String() == "" { 26 | p.LoginURL = &url.URL{ 27 | Scheme: "https", 28 | Host: "github.com", 29 | Path: "/login/oauth/authorize", 30 | } 31 | } 32 | if p.RedeemURL == nil || p.RedeemURL.String() == "" { 33 | p.RedeemURL = &url.URL{ 34 | Scheme: "https", 35 | Host: "github.com", 36 | Path: "/login/oauth/access_token", 37 | } 38 | } 39 | // ValidationURL is the API Base URL 40 | if p.ValidateURL == nil || p.ValidateURL.String() == "" { 41 | p.ValidateURL = &url.URL{ 42 | Scheme: "https", 43 | Host: "api.github.com", 44 | Path: "/", 45 | } 46 | } 47 | if p.Scope == "" { 48 | p.Scope = "user:email" 49 | } 50 | return &GitHubProvider{ProviderData: p} 51 | } 52 | 53 | // SetOrgTeam adds GitHub org reading parameters to the OAuth2 scope 54 | func (p *GitHubProvider) SetOrgTeam(org, team string) { 55 | p.Org = org 56 | p.Team = team 57 | if org != "" || team != "" { 58 | p.Scope += " read:org" 59 | } 60 | } 61 | 62 | func (p *GitHubProvider) 
hasOrg(accessToken string) (bool, error) { 63 | // https://developer.github.com/v3/orgs/#list-your-organizations 64 | 65 | var orgs []struct { 66 | Login string `json:"login"` 67 | } 68 | 69 | type orgsPage []struct { 70 | Login string `json:"login"` 71 | } 72 | 73 | pn := 1 74 | for { 75 | params := url.Values{ 76 | "limit": {"200"}, 77 | "page": {strconv.Itoa(pn)}, 78 | } 79 | 80 | endpoint := &url.URL{ 81 | Scheme: p.ValidateURL.Scheme, 82 | Host: p.ValidateURL.Host, 83 | Path: path.Join(p.ValidateURL.Path, "/user/orgs"), 84 | RawQuery: params.Encode(), 85 | } 86 | req, _ := http.NewRequest("GET", endpoint.String(), nil) 87 | req.Header.Set("Accept", "application/vnd.github.v3+json") 88 | req.Header.Set("Authorization", fmt.Sprintf("token %s", accessToken)) 89 | resp, err := http.DefaultClient.Do(req) 90 | if err != nil { 91 | return false, err 92 | } 93 | 94 | body, err := ioutil.ReadAll(resp.Body) 95 | resp.Body.Close() 96 | if err != nil { 97 | return false, err 98 | } 99 | if resp.StatusCode != 200 { 100 | return false, fmt.Errorf( 101 | "got %d from %q %s", resp.StatusCode, endpoint.String(), body) 102 | } 103 | 104 | var op orgsPage 105 | if err := json.Unmarshal(body, &op); err != nil { 106 | return false, err 107 | } 108 | if len(op) == 0 { 109 | break 110 | } 111 | 112 | orgs = append(orgs, op...) 
113 | pn++ 114 | } 115 | 116 | var presentOrgs []string 117 | for _, org := range orgs { 118 | if p.Org == org.Login { 119 | log.Printf("Found Github Organization: %q", org.Login) 120 | return true, nil 121 | } 122 | presentOrgs = append(presentOrgs, org.Login) 123 | } 124 | 125 | log.Printf("Missing Organization:%q in %v", p.Org, presentOrgs) 126 | return false, nil 127 | } 128 | 129 | func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { 130 | // https://developer.github.com/v3/orgs/teams/#list-user-teams 131 | 132 | var teams []struct { 133 | Name string `json:"name"` 134 | Slug string `json:"slug"` 135 | Org struct { 136 | Login string `json:"login"` 137 | } `json:"organization"` 138 | } 139 | 140 | params := url.Values{ 141 | "limit": {"200"}, 142 | } 143 | 144 | endpoint := &url.URL{ 145 | Scheme: p.ValidateURL.Scheme, 146 | Host: p.ValidateURL.Host, 147 | Path: path.Join(p.ValidateURL.Path, "/user/teams"), 148 | RawQuery: params.Encode(), 149 | } 150 | req, _ := http.NewRequest("GET", endpoint.String(), nil) 151 | req.Header.Set("Accept", "application/vnd.github.v3+json") 152 | req.Header.Set("Authorization", fmt.Sprintf("token %s", accessToken)) 153 | resp, err := http.DefaultClient.Do(req) 154 | if err != nil { 155 | return false, err 156 | } 157 | 158 | body, err := ioutil.ReadAll(resp.Body) 159 | resp.Body.Close() 160 | if err != nil { 161 | return false, err 162 | } 163 | if resp.StatusCode != 200 { 164 | return false, fmt.Errorf( 165 | "got %d from %q %s", resp.StatusCode, endpoint.String(), body) 166 | } 167 | 168 | if err := json.Unmarshal(body, &teams); err != nil { 169 | return false, fmt.Errorf("%s unmarshaling %s", err, body) 170 | } 171 | 172 | var hasOrg bool 173 | presentOrgs := make(map[string]bool) 174 | var presentTeams []string 175 | for _, team := range teams { 176 | presentOrgs[team.Org.Login] = true 177 | if p.Org == team.Org.Login { 178 | hasOrg = true 179 | ts := strings.Split(p.Team, ",") 180 | for _, t := range 
ts { 181 | if t == team.Slug { 182 | log.Printf("Found Github Organization:%q Team:%q (Name:%q)", team.Org.Login, team.Slug, team.Name) 183 | return true, nil 184 | } 185 | } 186 | presentTeams = append(presentTeams, team.Slug) 187 | } 188 | } 189 | if hasOrg { 190 | log.Printf("Missing Team:%q from Org:%q in teams: %v", p.Team, p.Org, presentTeams) 191 | } else { 192 | var allOrgs []string 193 | for org := range presentOrgs { 194 | allOrgs = append(allOrgs, org) 195 | } 196 | log.Printf("Missing Organization:%q in %#v", p.Org, allOrgs) 197 | } 198 | return false, nil 199 | } 200 | 201 | // GetEmailAddress returns the Account email address 202 | func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { 203 | 204 | var emails []struct { 205 | Email string `json:"email"` 206 | Primary bool `json:"primary"` 207 | Verified bool `json:"verified"` 208 | } 209 | 210 | // if we require an Org or Team, check that first 211 | if p.Org != "" { 212 | if p.Team != "" { 213 | if ok, err := p.hasOrgAndTeam(s.AccessToken); err != nil || !ok { 214 | return "", err 215 | } 216 | } else { 217 | if ok, err := p.hasOrg(s.AccessToken); err != nil || !ok { 218 | return "", err 219 | } 220 | } 221 | } 222 | 223 | endpoint := &url.URL{ 224 | Scheme: p.ValidateURL.Scheme, 225 | Host: p.ValidateURL.Host, 226 | Path: path.Join(p.ValidateURL.Path, "/user/emails"), 227 | } 228 | req, _ := http.NewRequest("GET", endpoint.String(), nil) 229 | req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken)) 230 | resp, err := http.DefaultClient.Do(req) 231 | if err != nil { 232 | return "", err 233 | } 234 | body, err := ioutil.ReadAll(resp.Body) 235 | resp.Body.Close() 236 | if err != nil { 237 | return "", err 238 | } 239 | 240 | if resp.StatusCode != 200 { 241 | return "", fmt.Errorf("got %d from %q %s", 242 | resp.StatusCode, endpoint.String(), body) 243 | } 244 | 245 | log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) 246 | 247 | if err := 
json.Unmarshal(body, &emails); err != nil { 248 | return "", fmt.Errorf("%s unmarshaling %s", err, body) 249 | } 250 | 251 | for _, email := range emails { 252 | if email.Primary && email.Verified { 253 | return email.Email, nil 254 | } 255 | } 256 | 257 | return "", nil 258 | } 259 | 260 | // GetUserName returns the Account user name 261 | func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) { 262 | var user struct { 263 | Login string `json:"login"` 264 | Email string `json:"email"` 265 | } 266 | 267 | endpoint := &url.URL{ 268 | Scheme: p.ValidateURL.Scheme, 269 | Host: p.ValidateURL.Host, 270 | Path: path.Join(p.ValidateURL.Path, "/user"), 271 | } 272 | 273 | req, err := http.NewRequest("GET", endpoint.String(), nil) 274 | if err != nil { 275 | return "", fmt.Errorf("could not create new GET request: %v", err) 276 | } 277 | 278 | req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken)) 279 | resp, err := http.DefaultClient.Do(req) 280 | if err != nil { 281 | return "", err 282 | } 283 | 284 | body, err := ioutil.ReadAll(resp.Body) 285 | defer resp.Body.Close() 286 | if err != nil { 287 | return "", err 288 | } 289 | 290 | if resp.StatusCode != 200 { 291 | return "", fmt.Errorf("got %d from %q %s", 292 | resp.StatusCode, endpoint.String(), body) 293 | } 294 | 295 | log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) 296 | 297 | if err := json.Unmarshal(body, &user); err != nil { 298 | return "", fmt.Errorf("%s unmarshaling %s", err, body) 299 | } 300 | 301 | return user.Login, nil 302 | } 303 | -------------------------------------------------------------------------------- /providers/github_test.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "net/url" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func testGitHubProvider(hostname string) *GitHubProvider { 13 | p 
:= NewGitHubProvider( 14 | &ProviderData{ 15 | ProviderName: "", 16 | LoginURL: &url.URL{}, 17 | RedeemURL: &url.URL{}, 18 | ProfileURL: &url.URL{}, 19 | ValidateURL: &url.URL{}, 20 | Scope: ""}) 21 | if hostname != "" { 22 | updateURL(p.Data().LoginURL, hostname) 23 | updateURL(p.Data().RedeemURL, hostname) 24 | updateURL(p.Data().ProfileURL, hostname) 25 | updateURL(p.Data().ValidateURL, hostname) 26 | } 27 | return p 28 | } 29 | 30 | func testGitHubBackend(payload []string) *httptest.Server { 31 | pathToQueryMap := map[string][]string{ 32 | "/user": {""}, 33 | "/user/emails": {""}, 34 | "/user/orgs": {"limit=200&page=1", "limit=200&page=2", "limit=200&page=3"}, 35 | } 36 | 37 | return httptest.NewServer(http.HandlerFunc( 38 | func(w http.ResponseWriter, r *http.Request) { 39 | query, ok := pathToQueryMap[r.URL.Path] 40 | validQuery := false 41 | index := 0 42 | for i, q := range query { 43 | if q == r.URL.RawQuery { 44 | validQuery = true 45 | index = i 46 | } 47 | } 48 | if !ok { 49 | w.WriteHeader(404) 50 | } else if !validQuery { 51 | w.WriteHeader(404) 52 | } else { 53 | w.WriteHeader(200) 54 | w.Write([]byte(payload[index])) 55 | } 56 | })) 57 | } 58 | 59 | func TestGitHubProviderDefaults(t *testing.T) { 60 | p := testGitHubProvider("") 61 | assert.NotEqual(t, nil, p) 62 | assert.Equal(t, "GitHub", p.Data().ProviderName) 63 | assert.Equal(t, "https://github.com/login/oauth/authorize", 64 | p.Data().LoginURL.String()) 65 | assert.Equal(t, "https://github.com/login/oauth/access_token", 66 | p.Data().RedeemURL.String()) 67 | assert.Equal(t, "https://api.github.com/", 68 | p.Data().ValidateURL.String()) 69 | assert.Equal(t, "user:email", p.Data().Scope) 70 | } 71 | 72 | func TestGitHubProviderOverrides(t *testing.T) { 73 | p := NewGitHubProvider( 74 | &ProviderData{ 75 | LoginURL: &url.URL{ 76 | Scheme: "https", 77 | Host: "example.com", 78 | Path: "/login/oauth/authorize"}, 79 | RedeemURL: &url.URL{ 80 | Scheme: "https", 81 | Host: "example.com", 82 | Path: 
"/login/oauth/access_token"}, 83 | ValidateURL: &url.URL{ 84 | Scheme: "https", 85 | Host: "api.example.com", 86 | Path: "/"}, 87 | Scope: "profile"}) 88 | assert.NotEqual(t, nil, p) 89 | assert.Equal(t, "GitHub", p.Data().ProviderName) 90 | assert.Equal(t, "https://example.com/login/oauth/authorize", 91 | p.Data().LoginURL.String()) 92 | assert.Equal(t, "https://example.com/login/oauth/access_token", 93 | p.Data().RedeemURL.String()) 94 | assert.Equal(t, "https://api.example.com/", 95 | p.Data().ValidateURL.String()) 96 | assert.Equal(t, "profile", p.Data().Scope) 97 | } 98 | 99 | func TestGitHubProviderGetEmailAddress(t *testing.T) { 100 | b := testGitHubBackend([]string{`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`}) 101 | defer b.Close() 102 | 103 | bURL, _ := url.Parse(b.URL) 104 | p := testGitHubProvider(bURL.Host) 105 | 106 | session := &SessionState{AccessToken: "imaginary_access_token"} 107 | email, err := p.GetEmailAddress(session) 108 | assert.Equal(t, nil, err) 109 | assert.Equal(t, "michael.bland@gsa.gov", email) 110 | } 111 | 112 | func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) { 113 | b := testGitHubBackend([]string{`[ {"email": "michael.bland@gsa.gov", "verified": false, "primary": true} ]`}) 114 | defer b.Close() 115 | 116 | bURL, _ := url.Parse(b.URL) 117 | p := testGitHubProvider(bURL.Host) 118 | 119 | session := &SessionState{AccessToken: "imaginary_access_token"} 120 | email, err := p.GetEmailAddress(session) 121 | assert.Equal(t, nil, err) 122 | assert.Empty(t, "", email) 123 | } 124 | 125 | func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) { 126 | b := testGitHubBackend([]string{ 127 | `[ {"email": "michael.bland@gsa.gov", "primary": true, "verified": true, "login":"testorg"} ]`, 128 | `[ {"email": "michael.bland1@gsa.gov", "primary": true, "verified": true, "login":"testorg1"} ]`, 129 | `[ ]`, 130 | }) 131 | defer b.Close() 132 | 133 | bURL, _ := url.Parse(b.URL) 134 | p := 
testGitHubProvider(bURL.Host) 135 | p.Org = "testorg1" 136 | 137 | session := &SessionState{AccessToken: "imaginary_access_token"} 138 | email, err := p.GetEmailAddress(session) 139 | assert.Equal(t, nil, err) 140 | assert.Equal(t, "michael.bland@gsa.gov", email) 141 | } 142 | 143 | // Note that trying to trigger the "failed building request" case is not 144 | // practical, since the only way it can fail is if the URL fails to parse. 145 | func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) { 146 | b := testGitHubBackend([]string{"unused payload"}) 147 | defer b.Close() 148 | 149 | bURL, _ := url.Parse(b.URL) 150 | p := testGitHubProvider(bURL.Host) 151 | 152 | // We'll trigger a request failure by using an unexpected access 153 | // token. Alternatively, we could allow the parsing of the payload as 154 | // JSON to fail. 155 | session := &SessionState{AccessToken: "unexpected_access_token"} 156 | email, err := p.GetEmailAddress(session) 157 | assert.NotEqual(t, nil, err) 158 | assert.Equal(t, "", email) 159 | } 160 | 161 | func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { 162 | b := testGitHubBackend([]string{"{\"foo\": \"bar\"}"}) 163 | defer b.Close() 164 | 165 | bURL, _ := url.Parse(b.URL) 166 | p := testGitHubProvider(bURL.Host) 167 | 168 | session := &SessionState{AccessToken: "imaginary_access_token"} 169 | email, err := p.GetEmailAddress(session) 170 | assert.NotEqual(t, nil, err) 171 | assert.Equal(t, "", email) 172 | } 173 | 174 | func TestGitHubProviderGetUserName(t *testing.T) { 175 | b := testGitHubBackend([]string{`{"email": "michael.bland@gsa.gov", "login": "mbland"}`}) 176 | defer b.Close() 177 | 178 | bURL, _ := url.Parse(b.URL) 179 | p := testGitHubProvider(bURL.Host) 180 | 181 | session := &SessionState{AccessToken: "imaginary_access_token"} 182 | email, err := p.GetUserName(session) 183 | assert.Equal(t, nil, err) 184 | assert.Equal(t, "mbland", email) 185 | } 186 | 
-------------------------------------------------------------------------------- /providers/gitlab.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | "net/url" 7 | 8 | "github.com/pusher/oauth2_proxy/api" 9 | ) 10 | 11 | // GitLabProvider represents an GitLab based Identity Provider 12 | type GitLabProvider struct { 13 | *ProviderData 14 | } 15 | 16 | // NewGitLabProvider initiates a new GitLabProvider 17 | func NewGitLabProvider(p *ProviderData) *GitLabProvider { 18 | p.ProviderName = "GitLab" 19 | if p.LoginURL == nil || p.LoginURL.String() == "" { 20 | p.LoginURL = &url.URL{ 21 | Scheme: "https", 22 | Host: "gitlab.com", 23 | Path: "/oauth/authorize", 24 | } 25 | } 26 | if p.RedeemURL == nil || p.RedeemURL.String() == "" { 27 | p.RedeemURL = &url.URL{ 28 | Scheme: "https", 29 | Host: "gitlab.com", 30 | Path: "/oauth/token", 31 | } 32 | } 33 | if p.ValidateURL == nil || p.ValidateURL.String() == "" { 34 | p.ValidateURL = &url.URL{ 35 | Scheme: "https", 36 | Host: "gitlab.com", 37 | Path: "/api/v4/user", 38 | } 39 | } 40 | if p.Scope == "" { 41 | p.Scope = "read_user" 42 | } 43 | return &GitLabProvider{ProviderData: p} 44 | } 45 | 46 | // GetEmailAddress returns the Account email address 47 | func (p *GitLabProvider) GetEmailAddress(s *SessionState) (string, error) { 48 | 49 | req, err := http.NewRequest("GET", 50 | p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) 51 | if err != nil { 52 | log.Printf("failed building request %s", err) 53 | return "", err 54 | } 55 | json, err := api.Request(req) 56 | if err != nil { 57 | log.Printf("failed making request %s", err) 58 | return "", err 59 | } 60 | return json.Get("email").String() 61 | } 62 | -------------------------------------------------------------------------------- /providers/gitlab_test.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 
| import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "net/url" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func testGitLabProvider(hostname string) *GitLabProvider { 13 | p := NewGitLabProvider( 14 | &ProviderData{ 15 | ProviderName: "", 16 | LoginURL: &url.URL{}, 17 | RedeemURL: &url.URL{}, 18 | ProfileURL: &url.URL{}, 19 | ValidateURL: &url.URL{}, 20 | Scope: ""}) 21 | if hostname != "" { 22 | updateURL(p.Data().LoginURL, hostname) 23 | updateURL(p.Data().RedeemURL, hostname) 24 | updateURL(p.Data().ProfileURL, hostname) 25 | updateURL(p.Data().ValidateURL, hostname) 26 | } 27 | return p 28 | } 29 | 30 | func testGitLabBackend(payload string) *httptest.Server { 31 | path := "/api/v4/user" 32 | query := "access_token=imaginary_access_token" 33 | 34 | return httptest.NewServer(http.HandlerFunc( 35 | func(w http.ResponseWriter, r *http.Request) { 36 | if r.URL.Path != path || r.URL.RawQuery != query { 37 | w.WriteHeader(404) 38 | } else { 39 | w.WriteHeader(200) 40 | w.Write([]byte(payload)) 41 | } 42 | })) 43 | } 44 | 45 | func TestGitLabProviderDefaults(t *testing.T) { 46 | p := testGitLabProvider("") 47 | assert.NotEqual(t, nil, p) 48 | assert.Equal(t, "GitLab", p.Data().ProviderName) 49 | assert.Equal(t, "https://gitlab.com/oauth/authorize", 50 | p.Data().LoginURL.String()) 51 | assert.Equal(t, "https://gitlab.com/oauth/token", 52 | p.Data().RedeemURL.String()) 53 | assert.Equal(t, "https://gitlab.com/api/v4/user", 54 | p.Data().ValidateURL.String()) 55 | assert.Equal(t, "read_user", p.Data().Scope) 56 | } 57 | 58 | func TestGitLabProviderOverrides(t *testing.T) { 59 | p := NewGitLabProvider( 60 | &ProviderData{ 61 | LoginURL: &url.URL{ 62 | Scheme: "https", 63 | Host: "example.com", 64 | Path: "/oauth/auth"}, 65 | RedeemURL: &url.URL{ 66 | Scheme: "https", 67 | Host: "example.com", 68 | Path: "/oauth/token"}, 69 | ValidateURL: &url.URL{ 70 | Scheme: "https", 71 | Host: "example.com", 72 | Path: "/api/v4/user"}, 73 | Scope: 
"profile"}) 74 | assert.NotEqual(t, nil, p) 75 | assert.Equal(t, "GitLab", p.Data().ProviderName) 76 | assert.Equal(t, "https://example.com/oauth/auth", 77 | p.Data().LoginURL.String()) 78 | assert.Equal(t, "https://example.com/oauth/token", 79 | p.Data().RedeemURL.String()) 80 | assert.Equal(t, "https://example.com/api/v4/user", 81 | p.Data().ValidateURL.String()) 82 | assert.Equal(t, "profile", p.Data().Scope) 83 | } 84 | 85 | func TestGitLabProviderGetEmailAddress(t *testing.T) { 86 | b := testGitLabBackend("{\"email\": \"michael.bland@gsa.gov\"}") 87 | defer b.Close() 88 | 89 | bURL, _ := url.Parse(b.URL) 90 | p := testGitLabProvider(bURL.Host) 91 | 92 | session := &SessionState{AccessToken: "imaginary_access_token"} 93 | email, err := p.GetEmailAddress(session) 94 | assert.Equal(t, nil, err) 95 | assert.Equal(t, "michael.bland@gsa.gov", email) 96 | } 97 | 98 | // Note that trying to trigger the "failed building request" case is not 99 | // practical, since the only way it can fail is if the URL fails to parse. 100 | func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) { 101 | b := testGitLabBackend("unused payload") 102 | defer b.Close() 103 | 104 | bURL, _ := url.Parse(b.URL) 105 | p := testGitLabProvider(bURL.Host) 106 | 107 | // We'll trigger a request failure by using an unexpected access 108 | // token. Alternatively, we could allow the parsing of the payload as 109 | // JSON to fail. 
110 | session := &SessionState{AccessToken: "unexpected_access_token"} 111 | email, err := p.GetEmailAddress(session) 112 | assert.NotEqual(t, nil, err) 113 | assert.Equal(t, "", email) 114 | } 115 | 116 | func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { 117 | b := testGitLabBackend("{\"foo\": \"bar\"}") 118 | defer b.Close() 119 | 120 | bURL, _ := url.Parse(b.URL) 121 | p := testGitLabProvider(bURL.Host) 122 | 123 | session := &SessionState{AccessToken: "imaginary_access_token"} 124 | email, err := p.GetEmailAddress(session) 125 | assert.NotEqual(t, nil, err) 126 | assert.Equal(t, "", email) 127 | } 128 | -------------------------------------------------------------------------------- /providers/google.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "bytes" 5 | "encoding/base64" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "io/ioutil" 11 | "log" 12 | "net/http" 13 | "net/url" 14 | "strings" 15 | "time" 16 | 17 | "golang.org/x/oauth2" 18 | "golang.org/x/oauth2/google" 19 | admin "google.golang.org/api/admin/directory/v1" 20 | "google.golang.org/api/googleapi" 21 | ) 22 | 23 | // GoogleProvider represents an Google based Identity Provider 24 | type GoogleProvider struct { 25 | *ProviderData 26 | RedeemRefreshURL *url.URL 27 | // GroupValidator is a function that determines if the passed email is in 28 | // the configured Google group. 29 | GroupValidator func(string) bool 30 | } 31 | 32 | // NewGoogleProvider initiates a new GoogleProvider 33 | func NewGoogleProvider(p *ProviderData) *GoogleProvider { 34 | p.ProviderName = "Google" 35 | if p.LoginURL.String() == "" { 36 | p.LoginURL = &url.URL{Scheme: "https", 37 | Host: "accounts.google.com", 38 | Path: "/o/oauth2/auth", 39 | // to get a refresh token. 
see https://developers.google.com/identity/protocols/OAuth2WebServer#offline 40 | RawQuery: "access_type=offline", 41 | } 42 | } 43 | if p.RedeemURL.String() == "" { 44 | p.RedeemURL = &url.URL{Scheme: "https", 45 | Host: "www.googleapis.com", 46 | Path: "/oauth2/v3/token"} 47 | } 48 | if p.ValidateURL.String() == "" { 49 | p.ValidateURL = &url.URL{Scheme: "https", 50 | Host: "www.googleapis.com", 51 | Path: "/oauth2/v1/tokeninfo"} 52 | } 53 | if p.Scope == "" { 54 | p.Scope = "profile email" 55 | } 56 | 57 | return &GoogleProvider{ 58 | ProviderData: p, 59 | // Set a default GroupValidator to just always return valid (true), it will 60 | // be overwritten if we configured a Google group restriction. 61 | GroupValidator: func(email string) bool { 62 | return true 63 | }, 64 | } 65 | } 66 | 67 | func emailFromIDToken(idToken string) (string, error) { 68 | 69 | // id_token is a base64 encode ID token payload 70 | // https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo 71 | jwt := strings.Split(idToken, ".") 72 | jwtData := strings.TrimSuffix(jwt[1], "=") 73 | b, err := base64.RawURLEncoding.DecodeString(jwtData) 74 | if err != nil { 75 | return "", err 76 | } 77 | 78 | var email struct { 79 | Email string `json:"email"` 80 | EmailVerified bool `json:"email_verified"` 81 | } 82 | err = json.Unmarshal(b, &email) 83 | if err != nil { 84 | return "", err 85 | } 86 | if email.Email == "" { 87 | return "", errors.New("missing email") 88 | } 89 | if !email.EmailVerified { 90 | return "", fmt.Errorf("email %s not listed as verified", email.Email) 91 | } 92 | return email.Email, nil 93 | } 94 | 95 | // Redeem exchanges the OAuth2 authentication token for an ID token 96 | func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { 97 | if code == "" { 98 | err = errors.New("missing code") 99 | return 100 | } 101 | 102 | params := url.Values{} 103 | params.Add("redirect_uri", redirectURL) 104 | params.Add("client_id", p.ClientID) 
105 | params.Add("client_secret", p.ClientSecret) 106 | params.Add("code", code) 107 | params.Add("grant_type", "authorization_code") 108 | var req *http.Request 109 | req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) 110 | if err != nil { 111 | return 112 | } 113 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 114 | 115 | resp, err := http.DefaultClient.Do(req) 116 | if err != nil { 117 | return 118 | } 119 | var body []byte 120 | body, err = ioutil.ReadAll(resp.Body) 121 | resp.Body.Close() 122 | if err != nil { 123 | return 124 | } 125 | 126 | if resp.StatusCode != 200 { 127 | err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) 128 | return 129 | } 130 | 131 | var jsonResponse struct { 132 | AccessToken string `json:"access_token"` 133 | RefreshToken string `json:"refresh_token"` 134 | ExpiresIn int64 `json:"expires_in"` 135 | IDToken string `json:"id_token"` 136 | } 137 | err = json.Unmarshal(body, &jsonResponse) 138 | if err != nil { 139 | return 140 | } 141 | var email string 142 | email, err = emailFromIDToken(jsonResponse.IDToken) 143 | if err != nil { 144 | return 145 | } 146 | s = &SessionState{ 147 | AccessToken: jsonResponse.AccessToken, 148 | IDToken: jsonResponse.IDToken, 149 | ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), 150 | RefreshToken: jsonResponse.RefreshToken, 151 | Email: email, 152 | } 153 | return 154 | } 155 | 156 | // SetGroupRestriction configures the GoogleProvider to restrict access to the 157 | // specified group(s). AdminEmail has to be an administrative email on the domain that is 158 | // checked. CredentialsFile is the path to a json file containing a Google service 159 | // account credentials. 
160 | func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) { 161 | adminService := getAdminService(adminEmail, credentialsReader) 162 | p.GroupValidator = func(email string) bool { 163 | return userInGroup(adminService, groups, email) 164 | } 165 | } 166 | 167 | func getAdminService(adminEmail string, credentialsReader io.Reader) *admin.Service { 168 | data, err := ioutil.ReadAll(credentialsReader) 169 | if err != nil { 170 | log.Fatal("can't read Google credentials file:", err) 171 | } 172 | conf, err := google.JWTConfigFromJSON(data, admin.AdminDirectoryUserReadonlyScope, admin.AdminDirectoryGroupReadonlyScope) 173 | if err != nil { 174 | log.Fatal("can't load Google credentials file:", err) 175 | } 176 | conf.Subject = adminEmail 177 | 178 | client := conf.Client(oauth2.NoContext) 179 | adminService, err := admin.New(client) 180 | if err != nil { 181 | log.Fatal(err) 182 | } 183 | return adminService 184 | } 185 | 186 | func userInGroup(service *admin.Service, groups []string, email string) bool { 187 | user, err := fetchUser(service, email) 188 | if err != nil { 189 | log.Printf("error fetching user: %v", err) 190 | return false 191 | } 192 | id := user.Id 193 | custID := user.CustomerId 194 | 195 | for _, group := range groups { 196 | members, err := fetchGroupMembers(service, group) 197 | if err != nil { 198 | if err, ok := err.(*googleapi.Error); ok && err.Code == 404 { 199 | log.Printf("error fetching members for group %s: group does not exist", group) 200 | } else { 201 | log.Printf("error fetching group members: %v", err) 202 | return false 203 | } 204 | } 205 | 206 | for _, member := range members { 207 | switch member.Type { 208 | case "CUSTOMER": 209 | if member.Id == custID { 210 | return true 211 | } 212 | case "USER": 213 | if member.Id == id { 214 | return true 215 | } 216 | } 217 | } 218 | } 219 | return false 220 | } 221 | 222 | func fetchUser(service *admin.Service, email string) 
(*admin.User, error) { 223 | user, err := service.Users.Get(email).Do() 224 | return user, err 225 | } 226 | 227 | func fetchGroupMembers(service *admin.Service, group string) ([]*admin.Member, error) { 228 | members := []*admin.Member{} 229 | pageToken := "" 230 | for { 231 | req := service.Members.List(group) 232 | if pageToken != "" { 233 | req.PageToken(pageToken) 234 | } 235 | r, err := req.Do() 236 | if err != nil { 237 | return nil, err 238 | } 239 | for _, member := range r.Members { 240 | members = append(members, member) 241 | } 242 | if r.NextPageToken == "" { 243 | break 244 | } 245 | pageToken = r.NextPageToken 246 | } 247 | return members, nil 248 | } 249 | 250 | // ValidateGroup validates that the provided email exists in the configured Google 251 | // group(s). 252 | func (p *GoogleProvider) ValidateGroup(email string) bool { 253 | return p.GroupValidator(email) 254 | } 255 | 256 | // RefreshSessionIfNeeded checks if the session has expired and uses the 257 | // RefreshToken to fetch a new ID token if required 258 | func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { 259 | if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { 260 | return false, nil 261 | } 262 | 263 | newToken, newIDToken, duration, err := p.redeemRefreshToken(s.RefreshToken) 264 | if err != nil { 265 | return false, err 266 | } 267 | 268 | // re-check that the user is in the proper google group(s) 269 | if !p.ValidateGroup(s.Email) { 270 | return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) 271 | } 272 | 273 | origExpiration := s.ExpiresOn 274 | s.AccessToken = newToken 275 | s.IDToken = newIDToken 276 | s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second) 277 | log.Printf("refreshed access token %s (expired on %s)", s, origExpiration) 278 | return true, nil 279 | } 280 | 281 | func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, idToken string, expires time.Duration, err error) { 282 
| // https://developers.google.com/identity/protocols/OAuth2WebServer#refresh 283 | params := url.Values{} 284 | params.Add("client_id", p.ClientID) 285 | params.Add("client_secret", p.ClientSecret) 286 | params.Add("refresh_token", refreshToken) 287 | params.Add("grant_type", "refresh_token") 288 | var req *http.Request 289 | req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) 290 | if err != nil { 291 | return 292 | } 293 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 294 | 295 | resp, err := http.DefaultClient.Do(req) 296 | if err != nil { 297 | return 298 | } 299 | var body []byte 300 | body, err = ioutil.ReadAll(resp.Body) 301 | resp.Body.Close() 302 | if err != nil { 303 | return 304 | } 305 | 306 | if resp.StatusCode != 200 { 307 | err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) 308 | return 309 | } 310 | 311 | var data struct { 312 | AccessToken string `json:"access_token"` 313 | ExpiresIn int64 `json:"expires_in"` 314 | IDToken string `json:"id_token"` 315 | } 316 | err = json.Unmarshal(body, &data) 317 | if err != nil { 318 | return 319 | } 320 | token = data.AccessToken 321 | idToken = data.IDToken 322 | expires = time.Duration(data.ExpiresIn) * time.Second 323 | return 324 | } 325 | -------------------------------------------------------------------------------- /providers/google_test.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "encoding/base64" 5 | "encoding/json" 6 | "net/http" 7 | "net/http/httptest" 8 | "net/url" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func newRedeemServer(body []byte) (*url.URL, *httptest.Server) { 15 | s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 16 | rw.Write(body) 17 | })) 18 | u, _ := url.Parse(s.URL) 19 | return u, s 20 | } 21 | 22 | func newGoogleProvider() 
*GoogleProvider { 23 | return NewGoogleProvider( 24 | &ProviderData{ 25 | ProviderName: "", 26 | LoginURL: &url.URL{}, 27 | RedeemURL: &url.URL{}, 28 | ProfileURL: &url.URL{}, 29 | ValidateURL: &url.URL{}, 30 | Scope: ""}) 31 | } 32 | 33 | func TestGoogleProviderDefaults(t *testing.T) { 34 | p := newGoogleProvider() 35 | assert.NotEqual(t, nil, p) 36 | assert.Equal(t, "Google", p.Data().ProviderName) 37 | assert.Equal(t, "https://accounts.google.com/o/oauth2/auth?access_type=offline", 38 | p.Data().LoginURL.String()) 39 | assert.Equal(t, "https://www.googleapis.com/oauth2/v3/token", 40 | p.Data().RedeemURL.String()) 41 | assert.Equal(t, "https://www.googleapis.com/oauth2/v1/tokeninfo", 42 | p.Data().ValidateURL.String()) 43 | assert.Equal(t, "", p.Data().ProfileURL.String()) 44 | assert.Equal(t, "profile email", p.Data().Scope) 45 | } 46 | 47 | func TestGoogleProviderOverrides(t *testing.T) { 48 | p := NewGoogleProvider( 49 | &ProviderData{ 50 | LoginURL: &url.URL{ 51 | Scheme: "https", 52 | Host: "example.com", 53 | Path: "/oauth/auth"}, 54 | RedeemURL: &url.URL{ 55 | Scheme: "https", 56 | Host: "example.com", 57 | Path: "/oauth/token"}, 58 | ProfileURL: &url.URL{ 59 | Scheme: "https", 60 | Host: "example.com", 61 | Path: "/oauth/profile"}, 62 | ValidateURL: &url.URL{ 63 | Scheme: "https", 64 | Host: "example.com", 65 | Path: "/oauth/tokeninfo"}, 66 | Scope: "profile"}) 67 | assert.NotEqual(t, nil, p) 68 | assert.Equal(t, "Google", p.Data().ProviderName) 69 | assert.Equal(t, "https://example.com/oauth/auth", 70 | p.Data().LoginURL.String()) 71 | assert.Equal(t, "https://example.com/oauth/token", 72 | p.Data().RedeemURL.String()) 73 | assert.Equal(t, "https://example.com/oauth/profile", 74 | p.Data().ProfileURL.String()) 75 | assert.Equal(t, "https://example.com/oauth/tokeninfo", 76 | p.Data().ValidateURL.String()) 77 | assert.Equal(t, "profile", p.Data().Scope) 78 | } 79 | 80 | type redeemResponse struct { 81 | AccessToken string `json:"access_token"` 82 | 
RefreshToken string `json:"refresh_token"` 83 | ExpiresIn int64 `json:"expires_in"` 84 | IDToken string `json:"id_token"` 85 | } 86 | 87 | func TestGoogleProviderGetEmailAddress(t *testing.T) { 88 | p := newGoogleProvider() 89 | body, err := json.Marshal(redeemResponse{ 90 | AccessToken: "a1234", 91 | ExpiresIn: 10, 92 | RefreshToken: "refresh12345", 93 | IDToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov", "email_verified":true}`)), 94 | }) 95 | assert.Equal(t, nil, err) 96 | var server *httptest.Server 97 | p.RedeemURL, server = newRedeemServer(body) 98 | defer server.Close() 99 | 100 | session, err := p.Redeem("http://redirect/", "code1234") 101 | assert.Equal(t, nil, err) 102 | assert.NotEqual(t, session, nil) 103 | assert.Equal(t, "michael.bland@gsa.gov", session.Email) 104 | assert.Equal(t, "a1234", session.AccessToken) 105 | assert.Equal(t, "refresh12345", session.RefreshToken) 106 | } 107 | 108 | func TestGoogleProviderValidateGroup(t *testing.T) { 109 | p := newGoogleProvider() 110 | p.GroupValidator = func(email string) bool { 111 | return email == "michael.bland@gsa.gov" 112 | } 113 | assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov")) 114 | p.GroupValidator = func(email string) bool { 115 | return email != "michael.bland@gsa.gov" 116 | } 117 | assert.Equal(t, false, p.ValidateGroup("michael.bland@gsa.gov")) 118 | } 119 | 120 | func TestGoogleProviderWithoutValidateGroup(t *testing.T) { 121 | p := newGoogleProvider() 122 | assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov")) 123 | } 124 | 125 | // 126 | func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) { 127 | p := newGoogleProvider() 128 | body, err := json.Marshal(redeemResponse{ 129 | AccessToken: "a1234", 130 | IDToken: "ignored prefix." 
+ `{"email": "michael.bland@gsa.gov"}`, 131 | }) 132 | assert.Equal(t, nil, err) 133 | var server *httptest.Server 134 | p.RedeemURL, server = newRedeemServer(body) 135 | defer server.Close() 136 | 137 | session, err := p.Redeem("http://redirect/", "code1234") 138 | assert.NotEqual(t, nil, err) 139 | if session != nil { 140 | t.Errorf("expect nill session %#v", session) 141 | } 142 | } 143 | 144 | func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) { 145 | p := newGoogleProvider() 146 | 147 | body, err := json.Marshal(redeemResponse{ 148 | AccessToken: "a1234", 149 | IDToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)), 150 | }) 151 | assert.Equal(t, nil, err) 152 | var server *httptest.Server 153 | p.RedeemURL, server = newRedeemServer(body) 154 | defer server.Close() 155 | 156 | session, err := p.Redeem("http://redirect/", "code1234") 157 | assert.NotEqual(t, nil, err) 158 | if session != nil { 159 | t.Errorf("expect nill session %#v", session) 160 | } 161 | 162 | } 163 | 164 | func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { 165 | p := newGoogleProvider() 166 | body, err := json.Marshal(redeemResponse{ 167 | AccessToken: "a1234", 168 | IDToken: "ignored prefix." 
+ base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)), 169 | }) 170 | assert.Equal(t, nil, err) 171 | var server *httptest.Server 172 | p.RedeemURL, server = newRedeemServer(body) 173 | defer server.Close() 174 | 175 | session, err := p.Redeem("http://redirect/", "code1234") 176 | assert.NotEqual(t, nil, err) 177 | if session != nil { 178 | t.Errorf("expect nill session %#v", session) 179 | } 180 | 181 | } 182 | -------------------------------------------------------------------------------- /providers/internal_util.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "io/ioutil" 5 | "log" 6 | "net/http" 7 | "net/url" 8 | 9 | "github.com/pusher/oauth2_proxy/api" 10 | ) 11 | 12 | // stripToken is a helper function to obfuscate "access_token" 13 | // query parameters 14 | func stripToken(endpoint string) string { 15 | return stripParam("access_token", endpoint) 16 | } 17 | 18 | // stripParam generalizes the obfuscation of a particular 19 | // query parameter - typically 'access_token' or 'client_secret' 20 | // The parameter's second half is replaced by '...' and returned 21 | // as part of the encoded query parameters. 22 | // If the target parameter isn't found, the endpoint is returned 23 | // unmodified. 
24 | func stripParam(param, endpoint string) string { 25 | u, err := url.Parse(endpoint) 26 | if err != nil { 27 | log.Printf("error attempting to strip %s: %s", param, err) 28 | return endpoint 29 | } 30 | 31 | if u.RawQuery != "" { 32 | values, err := url.ParseQuery(u.RawQuery) 33 | if err != nil { 34 | log.Printf("error attempting to strip %s: %s", param, err) 35 | return u.String() 36 | } 37 | 38 | if val := values.Get(param); val != "" { 39 | values.Set(param, val[:(len(val)/2)]+"...") 40 | u.RawQuery = values.Encode() 41 | return u.String() 42 | } 43 | } 44 | 45 | return endpoint 46 | } 47 | 48 | // validateToken returns true if token is valid 49 | func validateToken(p Provider, accessToken string, header http.Header) bool { 50 | if accessToken == "" || p.Data().ValidateURL == nil { 51 | return false 52 | } 53 | endpoint := p.Data().ValidateURL.String() 54 | if len(header) == 0 { 55 | params := url.Values{"access_token": {accessToken}} 56 | endpoint = endpoint + "?" + params.Encode() 57 | } 58 | resp, err := api.RequestUnparsedResponse(endpoint, header) 59 | if err != nil { 60 | log.Printf("GET %s", stripToken(endpoint)) 61 | log.Printf("token validation request failed: %s", err) 62 | return false 63 | } 64 | 65 | body, _ := ioutil.ReadAll(resp.Body) 66 | resp.Body.Close() 67 | log.Printf("%d GET %s %s", resp.StatusCode, stripToken(endpoint), body) 68 | 69 | if resp.StatusCode == 200 { 70 | return true 71 | } 72 | log.Printf("token validation request failed: status %d - %s", resp.StatusCode, body) 73 | return false 74 | } 75 | -------------------------------------------------------------------------------- /providers/internal_util_test.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "net/http/httptest" 7 | "net/url" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func updateURL(url *url.URL, hostname string) { 14 | 
url.Scheme = "http" 15 | url.Host = hostname 16 | } 17 | 18 | type ValidateSessionStateTestProvider struct { 19 | *ProviderData 20 | } 21 | 22 | func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) { 23 | return "", errors.New("not implemented") 24 | } 25 | 26 | // Note that we're testing the internal validateToken() used to implement 27 | // several Provider's ValidateSessionState() implementations 28 | func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool { 29 | return false 30 | } 31 | 32 | type ValidateSessionStateTest struct { 33 | backend *httptest.Server 34 | responseCode int 35 | provider *ValidateSessionStateTestProvider 36 | header http.Header 37 | } 38 | 39 | func NewValidateSessionStateTest() *ValidateSessionStateTest { 40 | var vtTest ValidateSessionStateTest 41 | 42 | vtTest.backend = httptest.NewServer( 43 | http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 44 | if r.URL.Path != "/oauth/tokeninfo" { 45 | w.WriteHeader(500) 46 | w.Write([]byte("unknown URL")) 47 | } 48 | tokenParam := r.FormValue("access_token") 49 | if tokenParam == "" { 50 | missing := false 51 | receivedHeaders := r.Header 52 | for k := range vtTest.header { 53 | received := receivedHeaders.Get(k) 54 | expected := vtTest.header.Get(k) 55 | if received == "" || received != expected { 56 | missing = true 57 | } 58 | } 59 | if missing { 60 | w.WriteHeader(500) 61 | w.Write([]byte("no token param and missing or incorrect headers")) 62 | } 63 | } 64 | w.WriteHeader(vtTest.responseCode) 65 | w.Write([]byte("only code matters; contents disregarded")) 66 | 67 | })) 68 | backendURL, _ := url.Parse(vtTest.backend.URL) 69 | vtTest.provider = &ValidateSessionStateTestProvider{ 70 | ProviderData: &ProviderData{ 71 | ValidateURL: &url.URL{ 72 | Scheme: "http", 73 | Host: backendURL.Host, 74 | Path: "/oauth/tokeninfo", 75 | }, 76 | }, 77 | } 78 | vtTest.responseCode = 200 79 | return &vtTest 80 | } 81 | 82 | 
func (vtTest *ValidateSessionStateTest) Close() { 83 | vtTest.backend.Close() 84 | } 85 | 86 | func TestValidateSessionStateValidToken(t *testing.T) { 87 | vtTest := NewValidateSessionStateTest() 88 | defer vtTest.Close() 89 | assert.Equal(t, true, validateToken(vtTest.provider, "foobar", nil)) 90 | } 91 | 92 | func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) { 93 | vtTest := NewValidateSessionStateTest() 94 | defer vtTest.Close() 95 | vtTest.header = make(http.Header) 96 | vtTest.header.Set("Authorization", "Bearer foobar") 97 | assert.Equal(t, true, 98 | validateToken(vtTest.provider, "foobar", vtTest.header)) 99 | } 100 | 101 | func TestValidateSessionStateEmptyToken(t *testing.T) { 102 | vtTest := NewValidateSessionStateTest() 103 | defer vtTest.Close() 104 | assert.Equal(t, false, validateToken(vtTest.provider, "", nil)) 105 | } 106 | 107 | func TestValidateSessionStateEmptyValidateURL(t *testing.T) { 108 | vtTest := NewValidateSessionStateTest() 109 | defer vtTest.Close() 110 | vtTest.provider.Data().ValidateURL = nil 111 | assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) 112 | } 113 | 114 | func TestValidateSessionStateRequestNetworkFailure(t *testing.T) { 115 | vtTest := NewValidateSessionStateTest() 116 | // Close immediately to simulate a network failure 117 | vtTest.Close() 118 | assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) 119 | } 120 | 121 | func TestValidateSessionStateExpiredToken(t *testing.T) { 122 | vtTest := NewValidateSessionStateTest() 123 | defer vtTest.Close() 124 | vtTest.responseCode = 401 125 | assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) 126 | } 127 | 128 | func TestStripTokenNotPresent(t *testing.T) { 129 | test := "http://local.test/api/test?a=1&b=2" 130 | assert.Equal(t, test, stripToken(test)) 131 | } 132 | 133 | func TestStripToken(t *testing.T) { 134 | test := "http://local.test/api/test?access_token=deadbeef&b=1&c=2" 135 | expected := 
"http://local.test/api/test?access_token=dead...&b=1&c=2" 136 | assert.Equal(t, expected, stripToken(test)) 137 | } 138 | -------------------------------------------------------------------------------- /providers/linkedin.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | "net/url" 8 | 9 | "github.com/pusher/oauth2_proxy/api" 10 | ) 11 | 12 | // LinkedInProvider represents an LinkedIn based Identity Provider 13 | type LinkedInProvider struct { 14 | *ProviderData 15 | } 16 | 17 | // NewLinkedInProvider initiates a new LinkedInProvider 18 | func NewLinkedInProvider(p *ProviderData) *LinkedInProvider { 19 | p.ProviderName = "LinkedIn" 20 | if p.LoginURL.String() == "" { 21 | p.LoginURL = &url.URL{Scheme: "https", 22 | Host: "www.linkedin.com", 23 | Path: "/uas/oauth2/authorization"} 24 | } 25 | if p.RedeemURL.String() == "" { 26 | p.RedeemURL = &url.URL{Scheme: "https", 27 | Host: "www.linkedin.com", 28 | Path: "/uas/oauth2/accessToken"} 29 | } 30 | if p.ProfileURL.String() == "" { 31 | p.ProfileURL = &url.URL{Scheme: "https", 32 | Host: "www.linkedin.com", 33 | Path: "/v1/people/~/email-address"} 34 | } 35 | if p.ValidateURL.String() == "" { 36 | p.ValidateURL = p.ProfileURL 37 | } 38 | if p.Scope == "" { 39 | p.Scope = "r_emailaddress r_basicprofile" 40 | } 41 | return &LinkedInProvider{ProviderData: p} 42 | } 43 | 44 | func getLinkedInHeader(accessToken string) http.Header { 45 | header := make(http.Header) 46 | header.Set("Accept", "application/json") 47 | header.Set("x-li-format", "json") 48 | header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) 49 | return header 50 | } 51 | 52 | // GetEmailAddress returns the Account email address 53 | func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { 54 | if s.AccessToken == "" { 55 | return "", errors.New("missing access token") 56 | } 57 | req, err := 
http.NewRequest("GET", p.ProfileURL.String()+"?format=json", nil) 58 | if err != nil { 59 | return "", err 60 | } 61 | req.Header = getLinkedInHeader(s.AccessToken) 62 | 63 | json, err := api.Request(req) 64 | if err != nil { 65 | return "", err 66 | } 67 | 68 | email, err := json.String() 69 | if err != nil { 70 | return "", err 71 | } 72 | return email, nil 73 | } 74 | 75 | // ValidateSessionState validates the AccessToken 76 | func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool { 77 | return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken)) 78 | } 79 | -------------------------------------------------------------------------------- /providers/linkedin_test.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "net/url" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func testLinkedInProvider(hostname string) *LinkedInProvider { 13 | p := NewLinkedInProvider( 14 | &ProviderData{ 15 | ProviderName: "", 16 | LoginURL: &url.URL{}, 17 | RedeemURL: &url.URL{}, 18 | ProfileURL: &url.URL{}, 19 | ValidateURL: &url.URL{}, 20 | Scope: ""}) 21 | if hostname != "" { 22 | updateURL(p.Data().LoginURL, hostname) 23 | updateURL(p.Data().RedeemURL, hostname) 24 | updateURL(p.Data().ProfileURL, hostname) 25 | } 26 | return p 27 | } 28 | 29 | func testLinkedInBackend(payload string) *httptest.Server { 30 | path := "/v1/people/~/email-address" 31 | 32 | return httptest.NewServer(http.HandlerFunc( 33 | func(w http.ResponseWriter, r *http.Request) { 34 | if r.URL.Path != path { 35 | w.WriteHeader(404) 36 | } else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" { 37 | w.WriteHeader(403) 38 | } else { 39 | w.WriteHeader(200) 40 | w.Write([]byte(payload)) 41 | } 42 | })) 43 | } 44 | 45 | func TestLinkedInProviderDefaults(t *testing.T) { 46 | p := testLinkedInProvider("") 47 | assert.NotEqual(t, 
nil, p) 48 | assert.Equal(t, "LinkedIn", p.Data().ProviderName) 49 | assert.Equal(t, "https://www.linkedin.com/uas/oauth2/authorization", 50 | p.Data().LoginURL.String()) 51 | assert.Equal(t, "https://www.linkedin.com/uas/oauth2/accessToken", 52 | p.Data().RedeemURL.String()) 53 | assert.Equal(t, "https://www.linkedin.com/v1/people/~/email-address", 54 | p.Data().ProfileURL.String()) 55 | assert.Equal(t, "https://www.linkedin.com/v1/people/~/email-address", 56 | p.Data().ValidateURL.String()) 57 | assert.Equal(t, "r_emailaddress r_basicprofile", p.Data().Scope) 58 | } 59 | 60 | func TestLinkedInProviderOverrides(t *testing.T) { 61 | p := NewLinkedInProvider( 62 | &ProviderData{ 63 | LoginURL: &url.URL{ 64 | Scheme: "https", 65 | Host: "example.com", 66 | Path: "/oauth/auth"}, 67 | RedeemURL: &url.URL{ 68 | Scheme: "https", 69 | Host: "example.com", 70 | Path: "/oauth/token"}, 71 | ProfileURL: &url.URL{ 72 | Scheme: "https", 73 | Host: "example.com", 74 | Path: "/oauth/profile"}, 75 | ValidateURL: &url.URL{ 76 | Scheme: "https", 77 | Host: "example.com", 78 | Path: "/oauth/tokeninfo"}, 79 | Scope: "profile"}) 80 | assert.NotEqual(t, nil, p) 81 | assert.Equal(t, "LinkedIn", p.Data().ProviderName) 82 | assert.Equal(t, "https://example.com/oauth/auth", 83 | p.Data().LoginURL.String()) 84 | assert.Equal(t, "https://example.com/oauth/token", 85 | p.Data().RedeemURL.String()) 86 | assert.Equal(t, "https://example.com/oauth/profile", 87 | p.Data().ProfileURL.String()) 88 | assert.Equal(t, "https://example.com/oauth/tokeninfo", 89 | p.Data().ValidateURL.String()) 90 | assert.Equal(t, "profile", p.Data().Scope) 91 | } 92 | 93 | func TestLinkedInProviderGetEmailAddress(t *testing.T) { 94 | b := testLinkedInBackend(`"user@linkedin.com"`) 95 | defer b.Close() 96 | 97 | bURL, _ := url.Parse(b.URL) 98 | p := testLinkedInProvider(bURL.Host) 99 | 100 | session := &SessionState{AccessToken: "imaginary_access_token"} 101 | email, err := p.GetEmailAddress(session) 102 | 
assert.Equal(t, nil, err) 103 | assert.Equal(t, "user@linkedin.com", email) 104 | } 105 | 106 | func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) { 107 | b := testLinkedInBackend("unused payload") 108 | defer b.Close() 109 | 110 | bURL, _ := url.Parse(b.URL) 111 | p := testLinkedInProvider(bURL.Host) 112 | 113 | // We'll trigger a request failure by using an unexpected access 114 | // token. Alternatively, we could allow the parsing of the payload as 115 | // JSON to fail. 116 | session := &SessionState{AccessToken: "unexpected_access_token"} 117 | email, err := p.GetEmailAddress(session) 118 | assert.NotEqual(t, nil, err) 119 | assert.Equal(t, "", email) 120 | } 121 | 122 | func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { 123 | b := testLinkedInBackend("{\"foo\": \"bar\"}") 124 | defer b.Close() 125 | 126 | bURL, _ := url.Parse(b.URL) 127 | p := testLinkedInProvider(bURL.Host) 128 | 129 | session := &SessionState{AccessToken: "imaginary_access_token"} 130 | email, err := p.GetEmailAddress(session) 131 | assert.NotEqual(t, nil, err) 132 | assert.Equal(t, "", email) 133 | } 134 | -------------------------------------------------------------------------------- /providers/logingov.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rsa" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "io/ioutil" 10 | "math/rand" 11 | "net/http" 12 | "net/url" 13 | "time" 14 | 15 | "github.com/dgrijalva/jwt-go" 16 | "gopkg.in/square/go-jose.v2" 17 | ) 18 | 19 | // LoginGovProvider represents an OIDC based Identity Provider 20 | type LoginGovProvider struct { 21 | *ProviderData 22 | 23 | // TODO (@timothy-spencer): Ideally, the nonce would be in the session state, but the session state 24 | // is created only upon code redemption, not during the auth, when this must be supplied. 
25 | Nonce string 26 | AcrValues string 27 | JWTKey *rsa.PrivateKey 28 | PubJWKURL *url.URL 29 | } 30 | 31 | // For generating a nonce 32 | var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") 33 | 34 | func randSeq(n int) string { 35 | b := make([]rune, n) 36 | for i := range b { 37 | b[i] = letters[rand.Intn(len(letters))] 38 | } 39 | return string(b) 40 | } 41 | 42 | // NewLoginGovProvider initiates a new LoginGovProvider 43 | func NewLoginGovProvider(p *ProviderData) *LoginGovProvider { 44 | p.ProviderName = "login.gov" 45 | 46 | if p.LoginURL == nil || p.LoginURL.String() == "" { 47 | p.LoginURL = &url.URL{ 48 | Scheme: "https", 49 | Host: "secure.login.gov", 50 | Path: "/openid_connect/authorize", 51 | } 52 | } 53 | if p.RedeemURL == nil || p.RedeemURL.String() == "" { 54 | p.RedeemURL = &url.URL{ 55 | Scheme: "https", 56 | Host: "secure.login.gov", 57 | Path: "/api/openid_connect/token", 58 | } 59 | } 60 | if p.ProfileURL == nil || p.ProfileURL.String() == "" { 61 | p.ProfileURL = &url.URL{ 62 | Scheme: "https", 63 | Host: "secure.login.gov", 64 | Path: "/api/openid_connect/userinfo", 65 | } 66 | } 67 | if p.Scope == "" { 68 | p.Scope = "email openid" 69 | } 70 | 71 | return &LoginGovProvider{ 72 | ProviderData: p, 73 | Nonce: randSeq(32), 74 | } 75 | } 76 | 77 | type loginGovCustomClaims struct { 78 | Acr string `json:"acr"` 79 | Nonce string `json:"nonce"` 80 | Email string `json:"email"` 81 | EmailVerified bool `json:"email_verified"` 82 | GivenName string `json:"given_name"` 83 | FamilyName string `json:"family_name"` 84 | Birthdate string `json:"birthdate"` 85 | AtHash string `json:"at_hash"` 86 | CHash string `json:"c_hash"` 87 | jwt.StandardClaims 88 | } 89 | 90 | // checkNonce checks the nonce in the id_token 91 | func checkNonce(idToken string, p *LoginGovProvider) (err error) { 92 | token, err := jwt.ParseWithClaims(idToken, &loginGovCustomClaims{}, func(token *jwt.Token) (interface{}, error) { 93 | resp, myerr := 
http.Get(p.PubJWKURL.String()) 94 | if myerr != nil { 95 | return nil, myerr 96 | } 97 | if resp.StatusCode != 200 { 98 | myerr = fmt.Errorf("got %d from %q", resp.StatusCode, p.PubJWKURL.String()) 99 | return nil, myerr 100 | } 101 | body, myerr := ioutil.ReadAll(resp.Body) 102 | resp.Body.Close() 103 | if myerr != nil { 104 | return nil, myerr 105 | } 106 | 107 | var pubkeys jose.JSONWebKeySet 108 | myerr = json.Unmarshal(body, &pubkeys) 109 | if myerr != nil { 110 | return nil, myerr 111 | } 112 | pubkey := pubkeys.Keys[0] 113 | 114 | return pubkey.Key, nil 115 | }) 116 | if err != nil { 117 | return 118 | } 119 | 120 | claims := token.Claims.(*loginGovCustomClaims) 121 | if claims.Nonce != p.Nonce { 122 | err = fmt.Errorf("nonce validation failed") 123 | return 124 | } 125 | return 126 | } 127 | 128 | func emailFromUserInfo(accessToken string, userInfoEndpoint string) (email string, err error) { 129 | // query the user info endpoint for user attributes 130 | var req *http.Request 131 | req, err = http.NewRequest("GET", userInfoEndpoint, nil) 132 | if err != nil { 133 | return 134 | } 135 | req.Header.Set("Authorization", "Bearer "+accessToken) 136 | 137 | resp, err := http.DefaultClient.Do(req) 138 | if err != nil { 139 | return 140 | } 141 | var body []byte 142 | body, err = ioutil.ReadAll(resp.Body) 143 | resp.Body.Close() 144 | if err != nil { 145 | return 146 | } 147 | 148 | if resp.StatusCode != 200 { 149 | err = fmt.Errorf("got %d from %q %s", resp.StatusCode, userInfoEndpoint, body) 150 | return 151 | } 152 | 153 | // parse the user attributes from the data we got and make sure that 154 | // the email address has been validated. 
155 | var emailData struct { 156 | Email string `json:"email"` 157 | EmailVerified bool `json:"email_verified"` 158 | } 159 | err = json.Unmarshal(body, &emailData) 160 | if err != nil { 161 | return 162 | } 163 | if emailData.Email == "" { 164 | err = fmt.Errorf("missing email") 165 | return 166 | } 167 | email = emailData.Email 168 | if !emailData.EmailVerified { 169 | err = fmt.Errorf("email %s not listed as verified", email) 170 | return 171 | } 172 | return 173 | } 174 | 175 | // Redeem exchanges the OAuth2 authentication token for an ID token 176 | func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { 177 | if code == "" { 178 | err = errors.New("missing code") 179 | return 180 | } 181 | 182 | claims := &jwt.StandardClaims{ 183 | Issuer: p.ClientID, 184 | Subject: p.ClientID, 185 | Audience: p.RedeemURL.String(), 186 | ExpiresAt: int64(time.Now().Add(time.Duration(5 * time.Minute)).Unix()), 187 | Id: randSeq(32), 188 | } 189 | token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims) 190 | ss, err := token.SignedString(p.JWTKey) 191 | if err != nil { 192 | return 193 | } 194 | 195 | params := url.Values{} 196 | params.Add("client_assertion", ss) 197 | params.Add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") 198 | params.Add("code", code) 199 | params.Add("grant_type", "authorization_code") 200 | 201 | var req *http.Request 202 | req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) 203 | if err != nil { 204 | return 205 | } 206 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 207 | 208 | var resp *http.Response 209 | resp, err = http.DefaultClient.Do(req) 210 | if err != nil { 211 | return nil, err 212 | } 213 | var body []byte 214 | body, err = ioutil.ReadAll(resp.Body) 215 | resp.Body.Close() 216 | if err != nil { 217 | return 218 | } 219 | 220 | if resp.StatusCode != 200 { 221 | err = fmt.Errorf("got %d from %q 
%s", resp.StatusCode, p.RedeemURL.String(), body) 222 | return 223 | } 224 | 225 | // Get the token from the body that we got from the token endpoint. 226 | var jsonResponse struct { 227 | AccessToken string `json:"access_token"` 228 | IDToken string `json:"id_token"` 229 | TokenType string `json:"token_type"` 230 | ExpiresIn int64 `json:"expires_in"` 231 | } 232 | err = json.Unmarshal(body, &jsonResponse) 233 | if err != nil { 234 | return 235 | } 236 | 237 | // check nonce here 238 | err = checkNonce(jsonResponse.IDToken, p) 239 | if err != nil { 240 | return 241 | } 242 | 243 | // Get the email address 244 | var email string 245 | email, err = emailFromUserInfo(jsonResponse.AccessToken, p.ProfileURL.String()) 246 | if err != nil { 247 | return 248 | } 249 | 250 | // Store the data that we found in the session state 251 | s = &SessionState{ 252 | AccessToken: jsonResponse.AccessToken, 253 | IDToken: jsonResponse.IDToken, 254 | ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), 255 | Email: email, 256 | } 257 | return 258 | } 259 | 260 | // GetLoginURL overrides GetLoginURL to add login.gov parameters 261 | func (p *LoginGovProvider) GetLoginURL(redirectURI, state string) string { 262 | var a url.URL 263 | a = *p.LoginURL 264 | params, _ := url.ParseQuery(a.RawQuery) 265 | params.Set("redirect_uri", redirectURI) 266 | params.Set("approval_prompt", p.ApprovalPrompt) 267 | params.Add("scope", p.Scope) 268 | params.Set("client_id", p.ClientID) 269 | params.Set("response_type", "code") 270 | params.Add("state", state) 271 | params.Add("acr_values", p.AcrValues) 272 | params.Add("nonce", p.Nonce) 273 | a.RawQuery = params.Encode() 274 | return a.String() 275 | } 276 | -------------------------------------------------------------------------------- /providers/logingov_test.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "crypto" 5 | 
"crypto/rand" 6 | "crypto/rsa" 7 | "encoding/json" 8 | "net/http" 9 | "net/http/httptest" 10 | "net/url" 11 | "testing" 12 | "time" 13 | 14 | "github.com/dgrijalva/jwt-go" 15 | "github.com/stretchr/testify/assert" 16 | "gopkg.in/square/go-jose.v2" 17 | ) 18 | 19 | type MyKeyData struct { 20 | PubKey crypto.PublicKey 21 | PrivKey *rsa.PrivateKey 22 | PubJWK jose.JSONWebKey 23 | } 24 | 25 | func newLoginGovServer(body []byte) (*url.URL, *httptest.Server) { 26 | s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 27 | rw.Write(body) 28 | })) 29 | u, _ := url.Parse(s.URL) 30 | return u, s 31 | } 32 | 33 | func newLoginGovProvider() (l *LoginGovProvider, serverKey *MyKeyData, err error) { 34 | key, err := rsa.GenerateKey(rand.Reader, 2048) 35 | if err != nil { 36 | return 37 | } 38 | serverKey = &MyKeyData{ 39 | PubKey: key.Public(), 40 | PrivKey: key, 41 | PubJWK: jose.JSONWebKey{ 42 | Key: key.Public(), 43 | KeyID: "testkey", 44 | Algorithm: string(jose.RS256), 45 | Use: "sig", 46 | }, 47 | } 48 | 49 | privateKey, err := rsa.GenerateKey(rand.Reader, 2048) 50 | if err != nil { 51 | return 52 | } 53 | 54 | l = NewLoginGovProvider( 55 | &ProviderData{ 56 | ProviderName: "", 57 | LoginURL: &url.URL{}, 58 | RedeemURL: &url.URL{}, 59 | ProfileURL: &url.URL{}, 60 | ValidateURL: &url.URL{}, 61 | Scope: ""}) 62 | l.JWTKey = privateKey 63 | l.Nonce = "fakenonce" 64 | return 65 | } 66 | 67 | func TestLoginGovProviderDefaults(t *testing.T) { 68 | p, _, err := newLoginGovProvider() 69 | assert.NotEqual(t, nil, p) 70 | assert.NoError(t, err) 71 | assert.Equal(t, "login.gov", p.Data().ProviderName) 72 | assert.Equal(t, "https://secure.login.gov/openid_connect/authorize", 73 | p.Data().LoginURL.String()) 74 | assert.Equal(t, "https://secure.login.gov/api/openid_connect/token", 75 | p.Data().RedeemURL.String()) 76 | assert.Equal(t, "https://secure.login.gov/api/openid_connect/userinfo", 77 | p.Data().ProfileURL.String()) 78 | assert.Equal(t, "email 
openid", p.Data().Scope) 79 | } 80 | 81 | func TestLoginGovProviderOverrides(t *testing.T) { 82 | p := NewLoginGovProvider( 83 | &ProviderData{ 84 | LoginURL: &url.URL{ 85 | Scheme: "https", 86 | Host: "example.com", 87 | Path: "/oauth/auth"}, 88 | RedeemURL: &url.URL{ 89 | Scheme: "https", 90 | Host: "example.com", 91 | Path: "/oauth/token"}, 92 | ProfileURL: &url.URL{ 93 | Scheme: "https", 94 | Host: "example.com", 95 | Path: "/oauth/profile"}, 96 | Scope: "profile"}) 97 | assert.NotEqual(t, nil, p) 98 | assert.Equal(t, "login.gov", p.Data().ProviderName) 99 | assert.Equal(t, "https://example.com/oauth/auth", 100 | p.Data().LoginURL.String()) 101 | assert.Equal(t, "https://example.com/oauth/token", 102 | p.Data().RedeemURL.String()) 103 | assert.Equal(t, "https://example.com/oauth/profile", 104 | p.Data().ProfileURL.String()) 105 | assert.Equal(t, "profile", p.Data().Scope) 106 | } 107 | 108 | func TestLoginGovProviderSessionData(t *testing.T) { 109 | p, serverkey, err := newLoginGovProvider() 110 | assert.NotEqual(t, nil, p) 111 | assert.NoError(t, err) 112 | 113 | // Set up the redeem endpoint here 114 | type loginGovRedeemResponse struct { 115 | AccessToken string `json:"access_token"` 116 | TokenType string `json:"token_type"` 117 | ExpiresIn int64 `json:"expires_in"` 118 | IDToken string `json:"id_token"` 119 | } 120 | expiresIn := int64(60) 121 | type MyCustomClaims struct { 122 | Acr string `json:"acr"` 123 | Nonce string `json:"nonce"` 124 | Email string `json:"email"` 125 | EmailVerified bool `json:"email_verified"` 126 | GivenName string `json:"given_name"` 127 | FamilyName string `json:"family_name"` 128 | Birthdate string `json:"birthdate"` 129 | AtHash string `json:"at_hash"` 130 | CHash string `json:"c_hash"` 131 | jwt.StandardClaims 132 | } 133 | claims := MyCustomClaims{ 134 | "http://idmanagement.gov/ns/assurance/loa/1", 135 | "fakenonce", 136 | "timothy.spencer@gsa.gov", 137 | true, 138 | "", 139 | "", 140 | "", 141 | "", 142 | "", 143 | 
jwt.StandardClaims{ 144 | Audience: "Audience", 145 | ExpiresAt: time.Now().Unix() + expiresIn, 146 | Id: "foo", 147 | IssuedAt: time.Now().Unix(), 148 | Issuer: "https://idp.int.login.gov", 149 | NotBefore: time.Now().Unix() - 1, 150 | Subject: "b2d2d115-1d7e-4579-b9d6-f8e84f4f56ca", 151 | }, 152 | } 153 | idtoken := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) 154 | signedidtoken, err := idtoken.SignedString(serverkey.PrivKey) 155 | assert.NoError(t, err) 156 | body, err := json.Marshal(loginGovRedeemResponse{ 157 | AccessToken: "a1234", 158 | TokenType: "Bearer", 159 | ExpiresIn: expiresIn, 160 | IDToken: signedidtoken, 161 | }) 162 | assert.NoError(t, err) 163 | var server *httptest.Server 164 | p.RedeemURL, server = newLoginGovServer(body) 165 | defer server.Close() 166 | 167 | // Set up the user endpoint here 168 | type loginGovUserResponse struct { 169 | Email string `json:"email"` 170 | EmailVerified bool `json:"email_verified"` 171 | Subject string `json:"sub"` 172 | } 173 | userbody, err := json.Marshal(loginGovUserResponse{ 174 | Email: "timothy.spencer@gsa.gov", 175 | EmailVerified: true, 176 | Subject: "b2d2d115-1d7e-4579-b9d6-f8e84f4f56ca", 177 | }) 178 | assert.NoError(t, err) 179 | var userserver *httptest.Server 180 | p.ProfileURL, userserver = newLoginGovServer(userbody) 181 | defer userserver.Close() 182 | 183 | // Set up the PubJWKURL endpoint here used to verify the JWT 184 | var pubkeys jose.JSONWebKeySet 185 | pubkeys.Keys = append(pubkeys.Keys, serverkey.PubJWK) 186 | pubjwkbody, err := json.Marshal(pubkeys) 187 | assert.NoError(t, err) 188 | var pubjwkserver *httptest.Server 189 | p.PubJWKURL, pubjwkserver = newLoginGovServer(pubjwkbody) 190 | defer pubjwkserver.Close() 191 | 192 | session, err := p.Redeem("http://redirect/", "code1234") 193 | assert.NoError(t, err) 194 | assert.NotEqual(t, session, nil) 195 | assert.Equal(t, "timothy.spencer@gsa.gov", session.Email) 196 | assert.Equal(t, "a1234", session.AccessToken) 197 | 198 | // 
The test ought to run in under 2 seconds. If not, you may need to bump this up. 199 | assert.InDelta(t, session.ExpiresOn.Unix(), time.Now().Unix()+expiresIn, 2) 200 | } 201 | 202 | func TestLoginGovProviderBadNonce(t *testing.T) { 203 | p, serverkey, err := newLoginGovProvider() 204 | assert.NotEqual(t, nil, p) 205 | assert.NoError(t, err) 206 | 207 | // Set up the redeem endpoint here 208 | type loginGovRedeemResponse struct { 209 | AccessToken string `json:"access_token"` 210 | TokenType string `json:"token_type"` 211 | ExpiresIn int64 `json:"expires_in"` 212 | IDToken string `json:"id_token"` 213 | } 214 | expiresIn := int64(60) 215 | type MyCustomClaims struct { 216 | Acr string `json:"acr"` 217 | Nonce string `json:"nonce"` 218 | Email string `json:"email"` 219 | EmailVerified bool `json:"email_verified"` 220 | GivenName string `json:"given_name"` 221 | FamilyName string `json:"family_name"` 222 | Birthdate string `json:"birthdate"` 223 | AtHash string `json:"at_hash"` 224 | CHash string `json:"c_hash"` 225 | jwt.StandardClaims 226 | } 227 | claims := MyCustomClaims{ 228 | "http://idmanagement.gov/ns/assurance/loa/1", 229 | "badfakenonce", 230 | "timothy.spencer@gsa.gov", 231 | true, 232 | "", 233 | "", 234 | "", 235 | "", 236 | "", 237 | jwt.StandardClaims{ 238 | Audience: "Audience", 239 | ExpiresAt: time.Now().Unix() + expiresIn, 240 | Id: "foo", 241 | IssuedAt: time.Now().Unix(), 242 | Issuer: "https://idp.int.login.gov", 243 | NotBefore: time.Now().Unix() - 1, 244 | Subject: "b2d2d115-1d7e-4579-b9d6-f8e84f4f56ca", 245 | }, 246 | } 247 | idtoken := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) 248 | signedidtoken, err := idtoken.SignedString(serverkey.PrivKey) 249 | assert.NoError(t, err) 250 | body, err := json.Marshal(loginGovRedeemResponse{ 251 | AccessToken: "a1234", 252 | TokenType: "Bearer", 253 | ExpiresIn: expiresIn, 254 | IDToken: signedidtoken, 255 | }) 256 | assert.NoError(t, err) 257 | var server *httptest.Server 258 | p.RedeemURL, server 
= newLoginGovServer(body) 259 | defer server.Close() 260 | 261 | // Set up the user endpoint here 262 | type loginGovUserResponse struct { 263 | Email string `json:"email"` 264 | EmailVerified bool `json:"email_verified"` 265 | Subject string `json:"sub"` 266 | } 267 | userbody, err := json.Marshal(loginGovUserResponse{ 268 | Email: "timothy.spencer@gsa.gov", 269 | EmailVerified: true, 270 | Subject: "b2d2d115-1d7e-4579-b9d6-f8e84f4f56ca", 271 | }) 272 | assert.NoError(t, err) 273 | var userserver *httptest.Server 274 | p.ProfileURL, userserver = newLoginGovServer(userbody) 275 | defer userserver.Close() 276 | 277 | // Set up the PubJWKURL endpoint here used to verify the JWT 278 | var pubkeys jose.JSONWebKeySet 279 | pubkeys.Keys = append(pubkeys.Keys, serverkey.PubJWK) 280 | pubjwkbody, err := json.Marshal(pubkeys) 281 | assert.NoError(t, err) 282 | var pubjwkserver *httptest.Server 283 | p.PubJWKURL, pubjwkserver = newLoginGovServer(pubjwkbody) 284 | defer pubjwkserver.Close() 285 | 286 | _, err = p.Redeem("http://redirect/", "code1234") 287 | 288 | // The "badfakenonce" in the idtoken above should cause this to error out 289 | assert.Error(t, err) 290 | } 291 | -------------------------------------------------------------------------------- /providers/oidc.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "golang.org/x/oauth2" 9 | 10 | oidc "github.com/coreos/go-oidc" 11 | ) 12 | 13 | // OIDCProvider represents an OIDC based Identity Provider 14 | type OIDCProvider struct { 15 | *ProviderData 16 | 17 | Verifier *oidc.IDTokenVerifier 18 | } 19 | 20 | // NewOIDCProvider initiates a new OIDCProvider 21 | func NewOIDCProvider(p *ProviderData) *OIDCProvider { 22 | p.ProviderName = "OpenID Connect" 23 | return &OIDCProvider{ProviderData: p} 24 | } 25 | 26 | // Redeem exchanges the OAuth2 authentication token for an ID token 27 | func (p *OIDCProvider) 
Redeem(redirectURL, code string) (s *SessionState, err error) { 28 | ctx := context.Background() 29 | c := oauth2.Config{ 30 | ClientID: p.ClientID, 31 | ClientSecret: p.ClientSecret, 32 | Endpoint: oauth2.Endpoint{ 33 | TokenURL: p.RedeemURL.String(), 34 | }, 35 | RedirectURL: redirectURL, 36 | } 37 | token, err := c.Exchange(ctx, code) 38 | if err != nil { 39 | return nil, fmt.Errorf("token exchange: %v", err) 40 | } 41 | s, err = p.createSessionState(ctx, token) 42 | if err != nil { 43 | return nil, fmt.Errorf("unable to update session: %v", err) 44 | } 45 | return 46 | } 47 | 48 | // RefreshSessionIfNeeded checks if the session has expired and uses the 49 | // RefreshToken to fetch a new ID token if required 50 | func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { 51 | if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { 52 | return false, nil 53 | } 54 | 55 | origExpiration := s.ExpiresOn 56 | 57 | err := p.redeemRefreshToken(s) 58 | if err != nil { 59 | return false, fmt.Errorf("unable to redeem refresh token: %v", err) 60 | } 61 | 62 | fmt.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration) 63 | return true, nil 64 | } 65 | 66 | func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) { 67 | c := oauth2.Config{ 68 | ClientID: p.ClientID, 69 | ClientSecret: p.ClientSecret, 70 | Endpoint: oauth2.Endpoint{ 71 | TokenURL: p.RedeemURL.String(), 72 | }, 73 | } 74 | ctx := context.Background() 75 | t := &oauth2.Token{ 76 | RefreshToken: s.RefreshToken, 77 | Expiry: time.Now().Add(-time.Hour), 78 | } 79 | token, err := c.TokenSource(ctx, t).Token() 80 | if err != nil { 81 | return fmt.Errorf("failed to get token: %v", err) 82 | } 83 | newSession, err := p.createSessionState(ctx, token) 84 | if err != nil { 85 | return fmt.Errorf("unable to update session: %v", err) 86 | } 87 | s.AccessToken = newSession.AccessToken 88 | s.IDToken = newSession.IDToken 89 | s.RefreshToken = 
newSession.RefreshToken 90 | s.ExpiresOn = newSession.ExpiresOn 91 | s.Email = newSession.Email 92 | return 93 | } 94 | 95 | func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*SessionState, error) { 96 | rawIDToken, ok := token.Extra("id_token").(string) 97 | if !ok { 98 | return nil, fmt.Errorf("token response did not contain an id_token") 99 | } 100 | 101 | // Parse and verify ID Token payload. 102 | idToken, err := p.Verifier.Verify(ctx, rawIDToken) 103 | if err != nil { 104 | return nil, fmt.Errorf("could not verify id_token: %v", err) 105 | } 106 | 107 | // Extract custom claims. 108 | var claims struct { 109 | Subject string `json:"sub"` 110 | Email string `json:"email"` 111 | Verified *bool `json:"email_verified"` 112 | } 113 | if err := idToken.Claims(&claims); err != nil { 114 | return nil, fmt.Errorf("failed to parse id_token claims: %v", err) 115 | } 116 | 117 | if claims.Email == "" { 118 | // TODO: Try getting email from /userinfo before falling back to Subject 119 | claims.Email = claims.Subject 120 | } 121 | if claims.Verified != nil && !*claims.Verified { 122 | return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) 123 | } 124 | 125 | return &SessionState{ 126 | AccessToken: token.AccessToken, 127 | IDToken: rawIDToken, 128 | RefreshToken: token.RefreshToken, 129 | ExpiresOn: token.Expiry, 130 | Email: claims.Email, 131 | }, nil 132 | } 133 | 134 | // ValidateSessionState checks that the session's IDToken is still valid 135 | func (p *OIDCProvider) ValidateSessionState(s *SessionState) bool { 136 | ctx := context.Background() 137 | _, err := p.Verifier.Verify(ctx, s.IDToken) 138 | if err != nil { 139 | return false 140 | } 141 | 142 | return true 143 | } 144 | -------------------------------------------------------------------------------- /providers/provider_data.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | 
"net/url" 5 | ) 6 | 7 | // ProviderData contains information required to configure all implementations 8 | // of OAuth2 providers 9 | type ProviderData struct { 10 | ProviderName string 11 | ClientID string 12 | ClientSecret string 13 | LoginURL *url.URL 14 | RedeemURL *url.URL 15 | ProfileURL *url.URL 16 | ProtectedResource *url.URL 17 | ValidateURL *url.URL 18 | Scope string 19 | ApprovalPrompt string 20 | } 21 | 22 | // Data returns the ProviderData 23 | func (p *ProviderData) Data() *ProviderData { return p } 24 | -------------------------------------------------------------------------------- /providers/provider_default.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "io/ioutil" 9 | "net/http" 10 | "net/url" 11 | 12 | "github.com/pusher/oauth2_proxy/cookie" 13 | ) 14 | 15 | // Redeem provides a default implementation of the OAuth2 token redemption process 16 | func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) { 17 | if code == "" { 18 | err = errors.New("missing code") 19 | return 20 | } 21 | 22 | params := url.Values{} 23 | params.Add("redirect_uri", redirectURL) 24 | params.Add("client_id", p.ClientID) 25 | params.Add("client_secret", p.ClientSecret) 26 | params.Add("code", code) 27 | params.Add("grant_type", "authorization_code") 28 | if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { 29 | params.Add("resource", p.ProtectedResource.String()) 30 | } 31 | 32 | var req *http.Request 33 | req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) 34 | if err != nil { 35 | return 36 | } 37 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 38 | 39 | var resp *http.Response 40 | resp, err = http.DefaultClient.Do(req) 41 | if err != nil { 42 | return nil, err 43 | } 44 | var body []byte 45 | body, err = 
ioutil.ReadAll(resp.Body) 46 | resp.Body.Close() 47 | if err != nil { 48 | return 49 | } 50 | 51 | if resp.StatusCode != 200 { 52 | err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) 53 | return 54 | } 55 | 56 | // blindly try json and x-www-form-urlencoded; JSON that parses but has no access_token (e.g. an error object) falls through to the form parser instead of producing a session with an empty token 57 | var jsonResponse struct { 58 | AccessToken string `json:"access_token"` 59 | } 60 | err = json.Unmarshal(body, &jsonResponse) 61 | if err == nil && jsonResponse.AccessToken != "" { 62 | s = &SessionState{ 63 | AccessToken: jsonResponse.AccessToken, 64 | } 65 | return 66 | } 67 | 68 | var v url.Values 69 | v, err = url.ParseQuery(string(body)) 70 | if err != nil { 71 | return 72 | } 73 | if a := v.Get("access_token"); a != "" { 74 | s = &SessionState{AccessToken: a} 75 | } else { 76 | err = fmt.Errorf("no access token found %s", body) 77 | } 78 | return 79 | } 80 | 81 | // GetLoginURL with typical oauth parameters 82 | func (p *ProviderData) GetLoginURL(redirectURI, state string) string { 83 | var a url.URL 84 | a = *p.LoginURL 85 | params, _ := url.ParseQuery(a.RawQuery) 86 | params.Set("redirect_uri", redirectURI) 87 | params.Set("approval_prompt", p.ApprovalPrompt) 88 | params.Add("scope", p.Scope) 89 | params.Set("client_id", p.ClientID) 90 | params.Set("response_type", "code") 91 | params.Add("state", state) 92 | a.RawQuery = params.Encode() 93 | return a.String() 94 | } 95 | 96 | // CookieForSession serializes a session state for storage in a cookie 97 | func (p *ProviderData) CookieForSession(s *SessionState, c *cookie.Cipher) (string, error) { 98 | return s.EncodeSessionState(c) 99 | } 100 | 101 | // SessionFromCookie deserializes a session from a cookie value 102 | func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *SessionState, err error) { 103 | return DecodeSessionState(v, c) 104 | } 105 | 106 | // GetEmailAddress returns the Account email address 107 | func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) { 108 | return "", errors.New("not implemented") 109
| } 110 | 111 | // GetUserName returns the Account username 112 | func (p *ProviderData) GetUserName(s *SessionState) (string, error) { 113 | return "", errors.New("not implemented") 114 | } 115 | 116 | // ValidateGroup validates that the provided email exists in the configured provider 117 | // email group(s). 118 | func (p *ProviderData) ValidateGroup(email string) bool { 119 | return true 120 | } 121 | 122 | // ValidateSessionState validates the AccessToken 123 | func (p *ProviderData) ValidateSessionState(s *SessionState) bool { 124 | return validateToken(p, s.AccessToken, nil) 125 | } 126 | 127 | // RefreshSessionIfNeeded should refresh the user's session if required and 128 | // do nothing if a refresh is not required 129 | func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) { 130 | return false, nil 131 | } 132 | -------------------------------------------------------------------------------- /providers/provider_default_test.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestRefresh(t *testing.T) { 11 | p := &ProviderData{} 12 | refreshed, err := p.RefreshSessionIfNeeded(&SessionState{ 13 | ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute), 14 | }) 15 | assert.Equal(t, false, refreshed) 16 | assert.Equal(t, nil, err) 17 | } 18 | -------------------------------------------------------------------------------- /providers/providers.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "github.com/pusher/oauth2_proxy/cookie" 5 | ) 6 | 7 | // Provider represents an upstream identity provider implementation 8 | type Provider interface { 9 | Data() *ProviderData 10 | GetEmailAddress(*SessionState) (string, error) 11 | GetUserName(*SessionState) (string, error) 12 | Redeem(string, string) 
(*SessionState, error) 13 | ValidateGroup(string) bool 14 | ValidateSessionState(*SessionState) bool 15 | GetLoginURL(redirectURI, finalRedirect string) string 16 | RefreshSessionIfNeeded(*SessionState) (bool, error) 17 | SessionFromCookie(string, *cookie.Cipher) (*SessionState, error) 18 | CookieForSession(*SessionState, *cookie.Cipher) (string, error) 19 | } 20 | 21 | // New provides a new Provider based on the configured provider string 22 | func New(provider string, p *ProviderData) Provider { 23 | switch provider { 24 | case "linkedin": 25 | return NewLinkedInProvider(p) 26 | case "facebook": 27 | return NewFacebookProvider(p) 28 | case "github": 29 | return NewGitHubProvider(p) 30 | case "azure": 31 | return NewAzureProvider(p) 32 | case "gitlab": 33 | return NewGitLabProvider(p) 34 | case "oidc": 35 | return NewOIDCProvider(p) 36 | case "login.gov": 37 | return NewLoginGovProvider(p) 38 | default: 39 | return NewGoogleProvider(p) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /providers/session_state.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "strconv" 7 | "strings" 8 | "time" 9 | 10 | "github.com/pusher/oauth2_proxy/cookie" 11 | ) 12 | 13 | // SessionState is used to store information about the currently authenticated user session 14 | type SessionState struct { 15 | AccessToken string `json:",omitempty"` 16 | IDToken string `json:",omitempty"` 17 | ExpiresOn time.Time `json:"-"` 18 | RefreshToken string `json:",omitempty"` 19 | Email string `json:",omitempty"` 20 | User string `json:",omitempty"` 21 | } 22 | 23 | // SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value 24 | type SessionStateJSON struct { 25 | *SessionState 26 | ExpiresOn *time.Time `json:",omitempty"` 27 | } 28 | 29 | // IsExpired checks whether the session has expired 30 | func (s 
*SessionState) IsExpired() bool { 31 | if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) { 32 | return true 33 | } 34 | return false 35 | } 36 | 37 | // String constructs a summary of the session state 38 | func (s *SessionState) String() string { 39 | o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User) 40 | if s.AccessToken != "" { 41 | o += " token:true" 42 | } 43 | if s.IDToken != "" { 44 | o += " id_token:true" 45 | } 46 | if !s.ExpiresOn.IsZero() { 47 | o += fmt.Sprintf(" expires:%s", s.ExpiresOn) 48 | } 49 | if s.RefreshToken != "" { 50 | o += " refresh_token:true" 51 | } 52 | return o + "}" 53 | } 54 | 55 | // EncodeSessionState returns string representation of the current session 56 | func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { 57 | var ss SessionState 58 | if c == nil { 59 | // Store only Email and User when cipher is unavailable 60 | ss.Email = s.Email 61 | ss.User = s.User 62 | } else { 63 | ss = *s 64 | var err error 65 | if ss.Email != "" { 66 | ss.Email, err = c.Encrypt(ss.Email) 67 | if err != nil { 68 | return "", err 69 | } 70 | } 71 | if ss.User != "" { 72 | ss.User, err = c.Encrypt(ss.User) 73 | if err != nil { 74 | return "", err 75 | } 76 | } 77 | if ss.AccessToken != "" { 78 | ss.AccessToken, err = c.Encrypt(ss.AccessToken) 79 | if err != nil { 80 | return "", err 81 | } 82 | } 83 | if ss.IDToken != "" { 84 | ss.IDToken, err = c.Encrypt(ss.IDToken) 85 | if err != nil { 86 | return "", err 87 | } 88 | } 89 | if ss.RefreshToken != "" { 90 | ss.RefreshToken, err = c.Encrypt(ss.RefreshToken) 91 | if err != nil { 92 | return "", err 93 | } 94 | } 95 | } 96 | // Embed SessionState and ExpiresOn pointer into SessionStateJSON 97 | ssj := &SessionStateJSON{SessionState: &ss} 98 | if !ss.ExpiresOn.IsZero() { 99 | ssj.ExpiresOn = &ss.ExpiresOn 100 | } 101 | b, err := json.Marshal(ssj) 102 | return string(b), err 103 | } 104 | 105 | // legacyDecodeSessionStatePlain decodes older plain session state 
string 106 | func legacyDecodeSessionStatePlain(v string) (*SessionState, error) { 107 | chunks := strings.Split(v, " ") 108 | if len(chunks) != 2 { 109 | return nil, fmt.Errorf("invalid session state (legacy: expected 2 chunks for user/email got %d)", len(chunks)) 110 | } 111 | 112 | user := strings.TrimPrefix(chunks[1], "user:") 113 | email := strings.TrimPrefix(chunks[0], "email:") 114 | 115 | return &SessionState{User: user, Email: email}, nil 116 | } 117 | 118 | // legacyDecodeSessionState attempts to decode the session state string 119 | // generated by v3.1.0 or older 120 | func legacyDecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) { 121 | chunks := strings.Split(v, "|") 122 | 123 | if c == nil { 124 | if len(chunks) != 1 { 125 | return nil, fmt.Errorf("invalid session state (legacy: expected 1 chunk for plain got %d)", len(chunks)) 126 | } 127 | return legacyDecodeSessionStatePlain(chunks[0]) 128 | } 129 | 130 | if len(chunks) != 4 && len(chunks) != 5 { 131 | return nil, fmt.Errorf("invalid session state (legacy: expected 4 or 5 chunks for full got %d)", len(chunks)) 132 | } 133 | 134 | i := 0 135 | ss, err := legacyDecodeSessionStatePlain(chunks[i]) 136 | if err != nil { 137 | return nil, err 138 | } 139 | 140 | i++ 141 | ss.AccessToken = chunks[i] 142 | 143 | if len(chunks) == 5 { 144 | // SessionState with IDToken in v3.1.0 145 | i++ 146 | ss.IDToken = chunks[i] 147 | } 148 | 149 | i++ 150 | ts, err := strconv.Atoi(chunks[i]) 151 | if err != nil { 152 | return nil, fmt.Errorf("invalid session state (legacy: wrong expiration time: %s)", err) 153 | } 154 | ss.ExpiresOn = time.Unix(int64(ts), 0) 155 | 156 | i++ 157 | ss.RefreshToken = chunks[i] 158 | 159 | return ss, nil 160 | } 161 | 162 | // DecodeSessionState decodes the session cookie string into a SessionState 163 | func DecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) { 164 | var ssj SessionStateJSON 165 | var ss *SessionState 166 | err := 
json.Unmarshal([]byte(v), &ssj) 167 | if err == nil && ssj.SessionState != nil { 168 | // Extract SessionState and ExpiresOn value from SessionStateJSON 169 | ss = ssj.SessionState 170 | if ssj.ExpiresOn != nil { 171 | ss.ExpiresOn = *ssj.ExpiresOn 172 | } 173 | } else { 174 | // Try to decode a legacy string when json.Unmarshal failed or yielded no embedded SessionState 175 | ss, err = legacyDecodeSessionState(v, c) 176 | if err != nil { 177 | return nil, err 178 | } 179 | } 180 | if c == nil { 181 | // Load only Email and User when cipher is unavailable 182 | ss = &SessionState{ 183 | Email: ss.Email, 184 | User: ss.User, 185 | } 186 | } else { 187 | // Backward compatibility with using unencrypted Email: a decryption failure is tolerated and the stored value is kept as-is 188 | if ss.Email != "" { 189 | decryptedEmail, errEmail := c.Decrypt(ss.Email) 190 | if errEmail == nil { 191 | ss.Email = decryptedEmail 192 | } 193 | } 194 | // Backward compatibility with using unencrypted User: a decryption failure is tolerated and the stored value is kept as-is 195 | if ss.User != "" { 196 | decryptedUser, errUser := c.Decrypt(ss.User) 197 | if errUser == nil { 198 | ss.User = decryptedUser 199 | } 200 | } 201 | if ss.AccessToken != "" { 202 | ss.AccessToken, err = c.Decrypt(ss.AccessToken) 203 | if err != nil { 204 | return nil, err 205 | } 206 | } 207 | if ss.IDToken != "" { 208 | ss.IDToken, err = c.Decrypt(ss.IDToken) 209 | if err != nil { 210 | return nil, err 211 | } 212 | } 213 | if ss.RefreshToken != "" { 214 | ss.RefreshToken, err = c.Decrypt(ss.RefreshToken) 215 | if err != nil { 216 | return nil, err 217 | } 218 | } 219 | } 220 | if ss.User == "" { // default User to the part of Email before the first "@" 221 | ss.User = strings.Split(ss.Email, "@")[0] 222 | } 223 | return ss, nil 224 | } 225 | -------------------------------------------------------------------------------- /providers/session_state_test.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | 8 | "github.com/pusher/oauth2_proxy/cookie" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | const secret =
"0123456789abcdefghijklmnopqrstuv" 13 | const altSecret = "0000000000abcdefghijklmnopqrstuv" 14 | 15 | func TestSessionStateSerialization(t *testing.T) { 16 | c, err := cookie.NewCipher([]byte(secret)) 17 | assert.Equal(t, nil, err) 18 | c2, err := cookie.NewCipher([]byte(altSecret)) 19 | assert.Equal(t, nil, err) 20 | s := &SessionState{ 21 | Email: "user@domain.com", 22 | AccessToken: "token1234", 23 | IDToken: "rawtoken1234", 24 | ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), 25 | RefreshToken: "refresh4321", 26 | } 27 | encoded, err := s.EncodeSessionState(c) 28 | assert.Equal(t, nil, err) 29 | 30 | ss, err := DecodeSessionState(encoded, c) 31 | t.Logf("%#v", ss) 32 | assert.Equal(t, nil, err) 33 | assert.Equal(t, "user", ss.User) 34 | assert.Equal(t, s.Email, ss.Email) 35 | assert.Equal(t, s.AccessToken, ss.AccessToken) 36 | assert.Equal(t, s.IDToken, ss.IDToken) 37 | assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) 38 | assert.Equal(t, s.RefreshToken, ss.RefreshToken) 39 | 40 | // ensure a different cipher can't decode properly (ie: it gets gibberish) 41 | ss, err = DecodeSessionState(encoded, c2) 42 | t.Logf("%#v", ss) 43 | assert.Equal(t, nil, err) 44 | assert.NotEqual(t, "user", ss.User) 45 | assert.NotEqual(t, s.Email, ss.Email) 46 | assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) 47 | assert.NotEqual(t, s.AccessToken, ss.AccessToken) 48 | assert.NotEqual(t, s.IDToken, ss.IDToken) 49 | assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) 50 | } 51 | 52 | func TestSessionStateSerializationWithUser(t *testing.T) { 53 | c, err := cookie.NewCipher([]byte(secret)) 54 | assert.Equal(t, nil, err) 55 | c2, err := cookie.NewCipher([]byte(altSecret)) 56 | assert.Equal(t, nil, err) 57 | s := &SessionState{ 58 | User: "just-user", 59 | Email: "user@domain.com", 60 | AccessToken: "token1234", 61 | ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), 62 | RefreshToken: "refresh4321", 63 | } 64 | encoded, err := s.EncodeSessionState(c) 65
| assert.Equal(t, nil, err) 66 | 67 | ss, err := DecodeSessionState(encoded, c) 68 | t.Logf("%#v", ss) 69 | assert.Equal(t, nil, err) 70 | assert.Equal(t, s.User, ss.User) 71 | assert.Equal(t, s.Email, ss.Email) 72 | assert.Equal(t, s.AccessToken, ss.AccessToken) 73 | assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) 74 | assert.Equal(t, s.RefreshToken, ss.RefreshToken) 75 | 76 | // ensure a different cipher can't decode properly (ie: it gets gibberish) 77 | ss, err = DecodeSessionState(encoded, c2) 78 | t.Logf("%#v", ss) 79 | assert.Equal(t, nil, err) 80 | assert.NotEqual(t, s.User, ss.User) 81 | assert.NotEqual(t, s.Email, ss.Email) 82 | assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) 83 | assert.NotEqual(t, s.AccessToken, ss.AccessToken) 84 | assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) 85 | } 86 | 87 | func TestSessionStateSerializationNoCipher(t *testing.T) { 88 | s := &SessionState{ 89 | Email: "user@domain.com", 90 | AccessToken: "token1234", 91 | ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), 92 | RefreshToken: "refresh4321", 93 | } 94 | encoded, err := s.EncodeSessionState(nil) 95 | assert.Equal(t, nil, err) 96 | 97 | // only email should have been serialized 98 | ss, err := DecodeSessionState(encoded, nil) 99 | assert.Equal(t, nil, err) 100 | assert.Equal(t, "user", ss.User) 101 | assert.Equal(t, s.Email, ss.Email) 102 | assert.Equal(t, "", ss.AccessToken) 103 | assert.Equal(t, "", ss.RefreshToken) 104 | } 105 | 106 | func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { 107 | s := &SessionState{ 108 | User: "just-user", 109 | Email: "user@domain.com", 110 | AccessToken: "token1234", 111 | ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), 112 | RefreshToken: "refresh4321", 113 | } 114 | encoded, err := s.EncodeSessionState(nil) 115 | assert.Equal(t, nil, err) 116 | 117 | // only email and user should have been serialized 118 | ss, err := DecodeSessionState(encoded, nil) 119 | assert.Equal(t, nil, err) 120 | 
assert.Equal(t, s.User, ss.User) 121 | assert.Equal(t, s.Email, ss.Email) 122 | assert.Equal(t, "", ss.AccessToken) 123 | assert.Equal(t, "", ss.RefreshToken) 124 | } 125 | 126 | func TestExpired(t *testing.T) { 127 | s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} 128 | assert.Equal(t, true, s.IsExpired()) 129 | 130 | s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)} 131 | assert.Equal(t, false, s.IsExpired()) 132 | 133 | s = &SessionState{} 134 | assert.Equal(t, false, s.IsExpired()) 135 | } 136 | 137 | type testCase struct { 138 | SessionState 139 | Encoded string 140 | Cipher *cookie.Cipher 141 | Error bool 142 | } 143 | 144 | // TestEncodeSessionState tests EncodeSessionState with the test vector 145 | // 146 | // Currently only tests without cipher here because we have no way to mock 147 | // the random generator used in EncodeSessionState. 148 | func TestEncodeSessionState(t *testing.T) { 149 | e := time.Now().Add(time.Duration(1) * time.Hour) 150 | 151 | testCases := []testCase{ 152 | { 153 | SessionState: SessionState{ 154 | Email: "user@domain.com", 155 | User: "just-user", 156 | }, 157 | Encoded: `{"Email":"user@domain.com","User":"just-user"}`, 158 | }, 159 | { 160 | SessionState: SessionState{ 161 | Email: "user@domain.com", 162 | User: "just-user", 163 | AccessToken: "token1234", 164 | IDToken: "rawtoken1234", 165 | ExpiresOn: e, 166 | RefreshToken: "refresh4321", 167 | }, 168 | Encoded: `{"Email":"user@domain.com","User":"just-user"}`, 169 | }, 170 | } 171 | 172 | for i, tc := range testCases { 173 | encoded, err := tc.EncodeSessionState(tc.Cipher) 174 | t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) 175 | if tc.Error { 176 | assert.Error(t, err) 177 | assert.Empty(t, encoded) 178 | continue 179 | } 180 | assert.NoError(t, err) 181 | assert.JSONEq(t, tc.Encoded, encoded) 182 | } 183 | } 184 | 185 | // TestDecodeSessionState tests DecodeSessionState with
the test vector 186 | func TestDecodeSessionState(t *testing.T) { 187 | e := time.Now().Add(time.Duration(1) * time.Hour) 188 | eJSON, _ := e.MarshalJSON() 189 | eString := string(eJSON) 190 | eUnix := e.Unix() 191 | 192 | c, err := cookie.NewCipher([]byte(secret)) 193 | assert.NoError(t, err) 194 | 195 | testCases := []testCase{ 196 | { 197 | SessionState: SessionState{ 198 | Email: "user@domain.com", 199 | User: "just-user", 200 | }, 201 | Encoded: `{"Email":"user@domain.com","User":"just-user"}`, 202 | }, 203 | { 204 | SessionState: SessionState{ 205 | Email: "user@domain.com", 206 | User: "user", 207 | }, 208 | Encoded: `{"Email":"user@domain.com"}`, 209 | }, 210 | { 211 | SessionState: SessionState{ 212 | User: "just-user", 213 | }, 214 | Encoded: `{"User":"just-user"}`, 215 | }, 216 | { 217 | SessionState: SessionState{ 218 | Email: "user@domain.com", 219 | User: "just-user", 220 | }, 221 | Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), 222 | }, 223 | { 224 | SessionState: SessionState{ 225 | Email: "user@domain.com", 226 | User: "just-user", 227 | AccessToken: "token1234", 228 | IDToken: "rawtoken1234", 229 | ExpiresOn: e, 230 | RefreshToken: "refresh4321", 231 | }, 232 | Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), 233 | Cipher: c, 234 | }, 235 | { 236 | SessionState: SessionState{ 237 | Email: "user@domain.com", 238 | User: "just-user", 239 | }, 240 | Encoded: `{"Email":"EGTllJcOFC16b7LBYzLekaHAC5SMMSPdyUrg8hd25g==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw=="}`, 241 | Cipher: c, 242 | }, 
243 | { 244 | Encoded: `{"Email":"user@domain.com","User":"just-user","AccessToken":"X"}`, 245 | Cipher: c, 246 | Error: true, 247 | }, 248 | { 249 | Encoded: `{"Email":"user@domain.com","User":"just-user","IDToken":"XXXX"}`, 250 | Cipher: c, 251 | Error: true, 252 | }, 253 | { 254 | SessionState: SessionState{ 255 | User: "just-user", 256 | Email: "user@domain.com", 257 | }, 258 | Encoded: "email:user@domain.com user:just-user", 259 | }, 260 | { 261 | Encoded: "email:user@domain.com user:just-user||||", 262 | Error: true, 263 | }, 264 | { 265 | Encoded: "email:user@domain.com user:just-user", 266 | Cipher: c, 267 | Error: true, 268 | }, 269 | { 270 | Encoded: "email:user@domain.com user:just-user|||99999999999999999999|", 271 | Cipher: c, 272 | Error: true, 273 | }, 274 | { 275 | SessionState: SessionState{ 276 | Email: "user@domain.com", 277 | User: "just-user", 278 | AccessToken: "token1234", 279 | ExpiresOn: e, 280 | RefreshToken: "refresh4321", 281 | }, 282 | Encoded: fmt.Sprintf("email:user@domain.com user:just-user|I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==|%d|qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K", eUnix), 283 | Cipher: c, 284 | }, 285 | { 286 | SessionState: SessionState{ 287 | Email: "user@domain.com", 288 | User: "just-user", 289 | AccessToken: "token1234", 290 | IDToken: "rawtoken1234", 291 | ExpiresOn: e, 292 | RefreshToken: "refresh4321", 293 | }, 294 | Encoded: fmt.Sprintf("email:user@domain.com user:just-user|I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==|xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==|%d|qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K", eUnix), 295 | Cipher: c, 296 | }, 297 | } 298 | 299 | for i, tc := range testCases { 300 | ss, err := DecodeSessionState(tc.Encoded, tc.Cipher) 301 | t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, tc.Encoded, ss, err) 302 | if tc.Error { 303 | assert.Error(t, err) 304 | assert.Nil(t, ss) 305 | continue 306 | } 307 | assert.NoError(t, err) 308 | if assert.NotNil(t, ss) { 309 | assert.Equal(t, tc.User, ss.User) 310 | 
assert.Equal(t, tc.Email, ss.Email) 311 | assert.Equal(t, tc.AccessToken, ss.AccessToken) 312 | assert.Equal(t, tc.RefreshToken, ss.RefreshToken) 313 | assert.Equal(t, tc.IDToken, ss.IDToken) 314 | assert.Equal(t, tc.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) 315 | } 316 | } 317 | } 318 | -------------------------------------------------------------------------------- /string_array.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "strings" 5 | ) 6 | 7 | // StringArray is a string slice type whose Get/Set/String methods match the flag.Getter interface 8 | type StringArray []string 9 | 10 | // Get returns the slice of strings 11 | func (a *StringArray) Get() interface{} { 12 | return []string(*a) 13 | } 14 | 15 | // Set appends a string to the StringArray 16 | func (a *StringArray) Set(s string) error { 17 | *a = append(*a, s) 18 | return nil 19 | } 20 | 21 | // String joins elements of the StringArray into a single comma separated string 22 | func (a *StringArray) String() string { 23 | return strings.Join(*a, ",") 24 | } 25 | -------------------------------------------------------------------------------- /templates.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "html/template" 5 | "log" 6 | "path/filepath" 7 | ) 8 | 9 | func loadTemplates(dir string) *template.Template { 10 | if dir == "" { 11 | return getTemplates() 12 | } 13 | log.Printf("using custom template directory %q", dir) 14 | t, err := template.New("").ParseFiles(filepath.Join(dir, "sign_in.html"), filepath.Join(dir, "error.html")) 15 | if err != nil { 16 | log.Fatalf("failed parsing template %s", err) 17 | } 18 | return t 19 | } 20 | 21 | func getTemplates() *template.Template { 22 | t, err := template.New("foo").Parse(`{{define "sign_in.html"}} 23 | 24 | 25 | 26 | Sign In 27 | 28 | 110 | 111 | 112 | 121 | 122 | {{ if .CustomLogin }} 123 | 131 | {{ end }} 132 | 142 |
143 | {{ if eq .Footer "-" }} 144 | {{ else if eq .Footer ""}} 145 | Secured with OAuth2 Proxy version {{.Version}} 146 | {{ else }} 147 | {{.Footer}} 148 | {{ end }} 149 |
150 | 151 | 152 | {{end}}`) 153 | if err != nil { 154 | log.Fatalf("failed parsing template %s", err) 155 | } 156 | 157 | t, err = t.Parse(`{{define "error.html"}} 158 | 159 | 160 | 161 | {{.Title}} 162 | 163 | 164 | 165 |

{{.Title}}

166 |

{{.Message}}

167 |
168 |

Sign In

169 | 170 | {{end}}`) 171 | if err != nil { 172 | log.Fatalf("failed parsing template %s", err) 173 | } 174 | return t 175 | } 176 | -------------------------------------------------------------------------------- /templates_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestTemplatesCompile(t *testing.T) { 10 | templates := getTemplates() 11 | assert.NotNil(t, templates) 12 | } 13 | -------------------------------------------------------------------------------- /validator.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/csv" 5 | "fmt" 6 | "log" 7 | "os" 8 | "strings" 9 | "sync/atomic" 10 | "unsafe" 11 | ) 12 | 13 | // UserMap holds information from the authenticated emails file 14 | type UserMap struct { 15 | usersFile string 16 | m unsafe.Pointer 17 | } 18 | 19 | // NewUserMap parses the authenticated emails file into a new UserMap 20 | func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap { 21 | um := &UserMap{usersFile: usersFile} 22 | m := make(map[string]bool) 23 | atomic.StorePointer(&um.m, unsafe.Pointer(&m)) 24 | if usersFile != "" { 25 | log.Printf("using authenticated emails file %s", usersFile) 26 | WatchForUpdates(usersFile, done, func() { 27 | um.LoadAuthenticatedEmailsFile() 28 | onUpdate() 29 | }) 30 | um.LoadAuthenticatedEmailsFile() 31 | } 32 | return um 33 | } 34 | 35 | // IsValid checks if an email is allowed 36 | func (um *UserMap) IsValid(email string) (result bool) { 37 | m := *(*map[string]bool)(atomic.LoadPointer(&um.m)) 38 | _, result = m[email] 39 | return 40 | } 41 | 42 | // LoadAuthenticatedEmailsFile loads the authenticated emails file from disk 43 | // and parses the contents as CSV 44 | func (um *UserMap) LoadAuthenticatedEmailsFile() { 45 | r, err := 
os.Open(um.usersFile) 46 | if err != nil { 47 | log.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err) 48 | } 49 | defer r.Close() 50 | csvReader := csv.NewReader(r) 51 | csvReader.Comma = ',' 52 | csvReader.Comment = '#' 53 | csvReader.TrimLeadingSpace = true 54 | records, err := csvReader.ReadAll() 55 | if err != nil { 56 | log.Printf("error reading authenticated-emails-file=%q, %s", um.usersFile, err) 57 | return 58 | } 59 | updated := make(map[string]bool) 60 | for _, record := range records { 61 | address := strings.ToLower(strings.TrimSpace(record[0])) 62 | updated[address] = true 63 | } 64 | atomic.StorePointer(&um.m, unsafe.Pointer(&updated)) 65 | } 66 | 67 | func newValidatorImpl(domains []string, usersFile string, 68 | done <-chan bool, onUpdate func()) func(string) bool { 69 | validUsers := NewUserMap(usersFile, done, onUpdate) 70 | 71 | var allowAll bool 72 | // Build a separate suffix list so the caller's domains slice is never mutated. 73 | suffixes := make([]string, 0, len(domains)) 74 | for _, domain := range domains { 75 | if domain == "*" { 76 | allowAll = true 77 | continue 78 | } 79 | suffixes = append(suffixes, fmt.Sprintf("@%s", strings.ToLower(domain))) 80 | } 81 | 82 | validator := func(email string) (valid bool) { 83 | if email == "" { 84 | return 85 | } 86 | email = strings.ToLower(email) 87 | for _, suffix := range suffixes { 88 | valid = valid || strings.HasSuffix(email, suffix) 89 | } 90 | if !valid { 91 | valid = validUsers.IsValid(email) 92 | } 93 | if allowAll { 94 | valid = true 95 | } 96 | return valid 97 | } 98 | return validator 99 | } 100 | 101 | // NewValidator constructs a function to validate email addresses 102 | func NewValidator(domains []string, usersFile string) func(string) bool { 103 | return newValidatorImpl(domains, usersFile, nil, func() {}) 104 | } 105 | -------------------------------------------------------------------------------- /validator_test.go: 
-------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "io/ioutil" 5 | "os" 6 | "strings" 7 | "testing" 8 | ) 9 | 10 | type
ValidatorTest struct { 11 | authEmailFile *os.File 12 | done chan bool 13 | updateSeen bool 14 | } 15 | 16 | func NewValidatorTest(t *testing.T) *ValidatorTest { 17 | vt := &ValidatorTest{} 18 | var err error 19 | vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_") 20 | if err != nil { 21 | t.Fatal("failed to create temp file: " + err.Error()) 22 | } 23 | vt.done = make(chan bool, 1) 24 | return vt 25 | } 26 | 27 | func (vt *ValidatorTest) TearDown() { 28 | vt.done <- true 29 | os.Remove(vt.authEmailFile.Name()) 30 | } 31 | 32 | func (vt *ValidatorTest) NewValidator(domains []string, 33 | updated chan<- bool) func(string) bool { 34 | return newValidatorImpl(domains, vt.authEmailFile.Name(), 35 | vt.done, func() { 36 | if !vt.updateSeen { 37 | updated <- true 38 | vt.updateSeen = true 39 | } 40 | }) 41 | } 42 | 43 | // WriteEmails writes emails, one per line, to vt.authEmailFile and then closes it. 44 | func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) { 45 | defer vt.authEmailFile.Close() 46 | vt.authEmailFile.WriteString(strings.Join(emails, "\n")) 47 | if err := vt.authEmailFile.Close(); err != nil { 48 | t.Fatal("failed to close temp file " + 49 | vt.authEmailFile.Name() + ": " + err.Error()) 50 | } 51 | } 52 | 53 | func TestValidatorEmpty(t *testing.T) { 54 | vt := NewValidatorTest(t) 55 | defer vt.TearDown() 56 | 57 | vt.WriteEmails(t, []string(nil)) 58 | domains := []string(nil) 59 | validator := vt.NewValidator(domains, nil) 60 | 61 | if validator("foo.bar@example.com") { 62 | t.Error("nothing should validate when the email and " + 63 | "domain lists are empty") 64 | } 65 | } 66 | 67 | func TestValidatorSingleEmail(t *testing.T) { 68 | vt := NewValidatorTest(t) 69 | defer vt.TearDown() 70 | 71 | vt.WriteEmails(t, []string{"foo.bar@example.com"}) 72 | domains := []string(nil) 73 | validator := vt.NewValidator(domains, nil) 74 | 75 | if !validator("foo.bar@example.com") { 76 | t.Error("email should validate") 77 | } 78 | if validator("baz.quux@example.com") { 79 | 
t.Error("email from same domain but not in list " + 80 | "should not validate when domain list is empty") 81 | } 82 | } 83 | 84 | func TestValidatorSingleDomain(t *testing.T) { 85 | vt := NewValidatorTest(t) 86 | defer vt.TearDown() 87 | 88 | vt.WriteEmails(t, []string(nil)) 89 | domains := []string{"example.com"} 90 | validator := vt.NewValidator(domains, nil) 91 | 92 | if !validator("foo.bar@example.com") { 93 | t.Error("email should validate") 94 | } 95 | if !validator("baz.quux@example.com") { 96 | t.Error("email from same domain should validate") 97 | } 98 | } 99 | 100 | func TestValidatorMultipleEmailsMultipleDomains(t *testing.T) { 101 | vt := NewValidatorTest(t) 102 | defer vt.TearDown() 103 | 104 | vt.WriteEmails(t, []string{ 105 | "xyzzy@example.com", 106 | "plugh@example.com", 107 | }) 108 | domains := []string{"example0.com", "example1.com"} 109 | validator := vt.NewValidator(domains, nil) 110 | 111 | if !validator("foo.bar@example0.com") { 112 | t.Error("email from first domain should validate") 113 | } 114 | if !validator("baz.quux@example1.com") { 115 | t.Error("email from second domain should validate") 116 | } 117 | if !validator("xyzzy@example.com") { 118 | t.Error("first email in list should validate") 119 | } 120 | if !validator("plugh@example.com") { 121 | t.Error("second email in list should validate") 122 | } 123 | if validator("xyzzy.plugh@example.com") { 124 | t.Error("email not in list that matches no domains " + 125 | "should not validate") 126 | } 127 | } 128 | 129 | func TestValidatorComparisonsAreCaseInsensitive(t *testing.T) { 130 | vt := NewValidatorTest(t) 131 | defer vt.TearDown() 132 | 133 | vt.WriteEmails(t, []string{"Foo.Bar@Example.Com"}) 134 | domains := []string{"Frobozz.Com"} 135 | validator := vt.NewValidator(domains, nil) 136 | 137 | if !validator("foo.bar@example.com") { 138 | t.Error("loaded email addresses are not lower-cased") 139 | } 140 | if !validator("Foo.Bar@Example.Com") { 141 | t.Error("validated email addresses
are not lower-cased") 142 | } 143 | if !validator("foo.bar@frobozz.com") { 144 | t.Error("loaded domains are not lower-cased") 145 | } 146 | if !validator("foo.bar@Frobozz.Com") { 147 | t.Error("validated domains are not lower-cased") 148 | } 149 | } 150 | 151 | func TestValidatorIgnoreSpacesInAuthEmails(t *testing.T) { 152 | vt := NewValidatorTest(t) 153 | defer vt.TearDown() 154 | 155 | vt.WriteEmails(t, []string{" foo.bar@example.com "}) 156 | domains := []string(nil) 157 | validator := vt.NewValidator(domains, nil) 158 | 159 | if !validator("foo.bar@example.com") { 160 | t.Error("email should validate") 161 | } 162 | } 163 | 164 | func TestValidatorDoesNotMutateDomains(t *testing.T) { 165 | vt := NewValidatorTest(t) 166 | defer vt.TearDown() 167 | 168 | vt.WriteEmails(t, []string(nil)) 169 | domains := []string{"Example.Com", "*"} 170 | vt.NewValidator(domains, nil) 171 | 172 | if domains[0] != "Example.Com" || domains[1] != "*" { 173 | t.Errorf("caller's domain list was mutated: %v", domains) 174 | } 175 | } 176 | -------------------------------------------------------------------------------- /validator_watcher_copy_test.go: -------------------------------------------------------------------------------- 1 | // +build go1.3,!plan9,!solaris,!windows 2 | 3 | // Turns out you can't copy over an existing file on Windows. 
4 | 5 | package main 6 | 7 | import ( 8 | "io/ioutil" 9 | "os" 10 | "testing" 11 | ) 12 | 13 | func (vt *ValidatorTest) UpdateEmailFileViaCopyingOver( 14 | t *testing.T, emails []string) { 15 | origFile := vt.authEmailFile 16 | var err error 17 | vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_") 18 | if err != nil { 19 | t.Fatal("failed to create temp file for copy: " + err.Error()) 20 | } 21 | vt.WriteEmails(t, emails) 22 | err = os.Rename(vt.authEmailFile.Name(), origFile.Name()) 23 | if err != nil { 24 | t.Fatal("failed to copy over temp file: " + err.Error()) 25 | } 26 | vt.authEmailFile = origFile 27 | } 28 | 29 | func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) { 30 | vt := NewValidatorTest(t) 31 | defer vt.TearDown() 32 | 33 | vt.WriteEmails(t, []string{"xyzzy@example.com"}) 34 | domains := []string(nil) 35 | updated := make(chan bool) 36 | validator := vt.NewValidator(domains, updated) 37 | 38 | if !validator("xyzzy@example.com") { 39 | t.Error("email in list should validate") 40 | } 41 | 42 |
vt.UpdateEmailFileViaCopyingOver(t, []string{"plugh@example.com"}) 43 | <-updated 44 | 45 | if validator("xyzzy@example.com") { 46 | t.Error("email removed from list should not validate") 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /validator_watcher_test.go: -------------------------------------------------------------------------------- 1 | // +build go1.3,!plan9,!solaris 2 | 3 | package main 4 | 5 | import ( 6 | "io/ioutil" 7 | "os" 8 | "testing" 9 | ) 10 | 11 | func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) { 12 | var err error 13 | vt.authEmailFile, err = os.OpenFile( 14 | vt.authEmailFile.Name(), os.O_WRONLY|os.O_CREATE, 0600) 15 | if err != nil { 16 | t.Fatal("failed to re-open temp file for updates") 17 | } 18 | vt.WriteEmails(t, emails) 19 | } 20 | 21 | func (vt *ValidatorTest) UpdateEmailFileViaRenameAndReplace( 22 | t *testing.T, emails []string) { 23 | origFile := vt.authEmailFile 24 | var err error 25 | vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_") 26 | if err != nil { 27 | t.Fatal("failed to create temp file for rename and replace: " + 28 | err.Error()) 29 | } 30 | vt.WriteEmails(t, emails) 31 | 32 | movedName := origFile.Name() + "-moved" 33 | if err = os.Rename(origFile.Name(), movedName); err == nil { 34 | err = os.Rename(vt.authEmailFile.Name(), origFile.Name()) 35 | } 36 | if err != nil { 37 | t.Fatal("failed to rename and replace temp file: " + err.Error()) 38 | } 39 | vt.authEmailFile = origFile 40 | os.Remove(movedName) 41 | } 42 | 43 | func TestValidatorOverwriteEmailListDirectly(t *testing.T) { 44 | vt := NewValidatorTest(t) 45 | defer vt.TearDown() 46 | 47 | vt.WriteEmails(t, []string{ 48 | "xyzzy@example.com", 49 | "plugh@example.com", 50 | }) 51 | domains := []string(nil) 52 | updated := make(chan bool) 53 | validator := vt.NewValidator(domains, updated) 54 | 55 | if !validator("xyzzy@example.com") { 56 | t.Error("first email in list should validate") 57 | }
58 | if !validator("plugh@example.com") { 59 | t.Error("second email in list should validate") 60 | } 61 | if validator("xyzzy.plugh@example.com") { 62 | t.Error("email not in list that matches no domains " + 63 | "should not validate") 64 | } 65 | 66 | vt.UpdateEmailFile(t, []string{ 67 | "xyzzy.plugh@example.com", 68 | "plugh@example.com", 69 | }) 70 | <-updated 71 | 72 | if validator("xyzzy@example.com") { 73 | t.Error("email removed from list should not validate") 74 | } 75 | if !validator("plugh@example.com") { 76 | t.Error("email retained in list should validate") 77 | } 78 | if !validator("xyzzy.plugh@example.com") { 79 | t.Error("email added to list should validate") 80 | } 81 | } 82 | 83 | func TestValidatorOverwriteEmailListViaRenameAndReplace(t *testing.T) { 84 | vt := NewValidatorTest(t) 85 | defer vt.TearDown() 86 | 87 | vt.WriteEmails(t, []string{"xyzzy@example.com"}) 88 | domains := []string(nil) 89 | updated := make(chan bool, 1) 90 | validator := vt.NewValidator(domains, updated) 91 | 92 | if !validator("xyzzy@example.com") { 93 | t.Error("email in list should validate") 94 | } 95 | 96 | vt.UpdateEmailFileViaRenameAndReplace(t, []string{"plugh@example.com"}) 97 | <-updated 98 | 99 | if validator("xyzzy@example.com") { 100 | t.Error("email removed from list should not validate") 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /version.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // VERSION contains version information 4 | var VERSION = "undefined" 5 | -------------------------------------------------------------------------------- /watcher.go: -------------------------------------------------------------------------------- 1 | // +build go1.3,!plan9,!solaris 2 | 3 | package main 4 | 5 | import ( 6 | "log" 7 | "os" 8 | "path/filepath" 9 | "time" 10 | 11 | fsnotify "gopkg.in/fsnotify/fsnotify.v1" 12 | ) 13 | 14 | //
WaitForReplacement waits for a file to exist on disk and then starts a watch 15 | // for the file 16 | func WaitForReplacement(filename string, op fsnotify.Op, 17 | watcher *fsnotify.Watcher) { 18 | const sleepInterval = 50 * time.Millisecond 19 | 20 | // Avoid a race when fsnotify.Remove is preceded by fsnotify.Chmod. 21 | if op&fsnotify.Chmod != 0 { 22 | time.Sleep(sleepInterval) 23 | } 24 | for { 25 | if _, err := os.Stat(filename); err == nil { 26 | if err := watcher.Add(filename); err == nil { 27 | log.Printf("watching resumed for %s", filename) 28 | return 29 | } 30 | } 31 | time.Sleep(sleepInterval) 32 | } 33 | } 34 | 35 | // WatchForUpdates performs an action every time a file on disk is updated 36 | func WatchForUpdates(filename string, done <-chan bool, action func()) { 37 | filename = filepath.Clean(filename) 38 | watcher, err := fsnotify.NewWatcher() 39 | if err != nil { 40 | log.Fatal("failed to create watcher for ", filename, ": ", err) 41 | } 42 | go func() { 43 | defer watcher.Close() 44 | for { 45 | select { 46 | case <-done: 47 | log.Printf("Shutting down watcher for: %s", filename) 48 | return 49 | case event := <-watcher.Events: 50 | // On Arch Linux, it appears Chmod events precede Remove events, 51 | // which causes a race between action() and the coming Remove event. 52 | // If the Remove wins, the action() (which calls 53 | // UserMap.LoadAuthenticatedEmailsFile()) crashes when the file 54 | // can't be opened.
55 | if event.Op&(fsnotify.Remove|fsnotify.Rename|fsnotify.Chmod) != 0 { 56 | log.Printf("watching interrupted on event: %s", event) 57 | watcher.Remove(filename) 58 | WaitForReplacement(filename, event.Op, watcher) 59 | } 60 | log.Printf("reloading after event: %s", event) 61 | action() 62 | case err := <-watcher.Errors: 63 | log.Printf("error watching %s: %s", filename, err) 64 | } 65 | } 66 | }() 67 | if err = watcher.Add(filename); err != nil { 68 | log.Fatal("failed to add ", filename, " to watcher: ", err) 69 | } 70 | log.Printf("watching %s for updates", filename) 71 | } 72 | -------------------------------------------------------------------------------- /watcher_unsupported.go: -------------------------------------------------------------------------------- 1 | // +build !go1.3 plan9 solaris 2 | 3 | package main 4 | 5 | import ( 6 | "log" 7 | ) 8 | 9 | // WatchForUpdates is a stub for platforms without file-watching support: it logs a notice and never calls action. 10 | func WatchForUpdates(filename string, done <-chan bool, action func()) { 11 | log.Printf("file watching not implemented on this platform") 12 | go func() { <-done }() 13 | } 14 | --------------------------------------------------------------------------------