├── .circleci └── config.yml ├── .codecov └── codecov.yml ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── release-drafter.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── Makefile ├── PULL_REQUEST_TEMPLATE.md ├── README.md ├── cmd ├── sso-auth │ └── main.go └── sso-proxy │ ├── generate-request-signature │ └── generate_request_signature.go │ └── main.go ├── docs ├── API.md ├── adr │ └── 2018-08-16-circuit-breaker.md ├── architecture │ └── circuit-breaker.md ├── diagrams │ └── sso_request_flow.png ├── generate_request_signature.md ├── google_provider_setup.md ├── img │ ├── choose-account.jpg │ ├── logo.png │ ├── okta │ │ ├── okta-app-settings.jpg │ │ ├── okta-auth-server-claims.jpg │ │ ├── okta-auth-server-scope.jpg │ │ └── okta-homepage-api.jpg │ ├── payload-screen.jpg │ ├── setup-admin_api.jpg │ ├── setup-api_client_access.jpg │ ├── setup-consent_screen.jpg │ ├── setup-create_credentials.jpg │ ├── setup-create_service_account.jpg │ ├── setup-credentials.jpg │ ├── setup-manage_resources.jpg │ ├── setup-notification.jpg │ ├── setup-security_control.jpg │ ├── start-auth.jpg │ └── start-script.jpg ├── okta_provider_setup.md ├── quickstart.md ├── sso_authenticator_config.md ├── sso_config.md └── sso_proxy_config.md ├── go.mod ├── go.sum ├── internal ├── auth │ ├── authenticator.go │ ├── authenticator_test.go │ ├── circuit │ │ ├── breaker.go │ │ └── breaker_test.go │ ├── conf.d │ │ └── gitkeep │ ├── configuration.go │ ├── configuration_test.go │ ├── error.go │ ├── http.go │ ├── logging_handler.go │ ├── logging_handler_test.go │ ├── metrics.go │ ├── metrics_test.go │ ├── middleware.go │ ├── middleware_test.go │ ├── mux.go │ ├── mux_test.go │ ├── options.go │ ├── options_test.go │ ├── providers │ │ ├── amazon_cognito.go │ │ ├── amazon_cognito_admin.go │ │ ├── amazon_cognito_mock_admin.go │ │ ├── amazon_cognito_test.go │ │ ├── google.go │ │ ├── google_admin.go │ │ ├── google_mock_admin.go │ │ ├── 
google_test.go │ │ ├── group_cache.go │ │ ├── group_cache_test.go │ │ ├── http_client.go │ │ ├── internal_util.go │ │ ├── internal_util_test.go │ │ ├── okta.go │ │ ├── okta_test.go │ │ ├── provider_data.go │ │ ├── provider_default.go │ │ ├── providers.go │ │ ├── singleflight_middleware.go │ │ └── test_provider.go │ ├── static │ │ └── sso.css │ ├── static_files.go │ ├── static_files_test.go │ ├── statik │ │ └── statik.go │ └── version.go ├── pkg │ ├── aead │ │ ├── aead.go │ │ ├── aead_test.go │ │ └── mock_cipher.go │ ├── groups │ │ ├── fillcache.go │ │ ├── fillcache_test.go │ │ ├── localcache.go │ │ ├── localcache_test.go │ │ └── mock_cache.go │ ├── hostmux │ │ ├── hostmux.go │ │ └── hostmux_test.go │ ├── httpserver │ │ ├── httpserver.go │ │ └── httpserver_test.go │ ├── logging │ │ └── logging.go │ ├── sessions │ │ ├── cookie_store.go │ │ ├── cookie_store_test.go │ │ ├── mock_store.go │ │ ├── session_state.go │ │ └── session_state_test.go │ ├── singleflight │ │ ├── singleflight.go │ │ └── singleflight_test.go │ ├── templates │ │ ├── mock_templates.go │ │ ├── templates.go │ │ └── templates_test.go │ ├── testutil │ │ └── testutil.go │ └── validators │ │ ├── email_address_validator.go │ │ ├── email_address_validator_test.go │ │ ├── email_domain_validator.go │ │ ├── email_domain_validator_test.go │ │ ├── email_group_validator.go │ │ ├── mock_validator.go │ │ └── validators.go └── proxy │ ├── collector │ └── collector.go │ ├── configuration.go │ ├── configuration_test.go │ ├── logging_handler.go │ ├── logging_handler_test.go │ ├── metrics.go │ ├── metrics_test.go │ ├── middleware.go │ ├── oauthproxy.go │ ├── oauthproxy_test.go │ ├── options.go │ ├── options_test.go │ ├── providers │ ├── http_client.go │ ├── provider_data.go │ ├── providers.go │ ├── singleflight_middleware.go │ ├── sso.go │ ├── sso_test.go │ └── test_provider.go │ ├── proxy.go │ ├── proxy_config.go │ ├── proxy_config_test.go │ ├── proxy_test.go │ ├── request_signer.go │ ├── request_signer_test.go │ ├── 
reverse_proxy.go │ ├── reverse_proxy_test.go │ ├── templates.go │ ├── templates_test.go │ ├── testdata │ ├── private_key.pem │ ├── public_key.pub │ └── upstream_configs.yml │ └── version.go ├── quickstart ├── docker-compose.yml ├── env.google.example ├── env.okta.example ├── kubernetes │ ├── demo-apps │ │ ├── hello-world-deployment.yml │ │ ├── hello-world-svc.yml │ │ ├── httpbin-deployment.yml │ │ └── httpbin-svc.yml │ ├── sso-auth-deployment.yml │ ├── sso-auth-svc.yml │ ├── sso-ingress.yml │ ├── sso-proxy-deployment.yml │ ├── sso-proxy-svc.yml │ └── upstream-configs-configmap.yml └── upstream_configs.yml ├── scripts ├── dist.sh └── test └── static └── sso.css /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | defaults: &defaults 2 | docker: 3 | - image: cimg/go:1.21 4 | working_directory: ~/go/src/github.com/buzzfeed/sso 5 | 6 | attach_workspace: &attach_workspace 7 | attach_workspace: 8 | at: ~/go/src/github.com/buzzfeed/sso 9 | 10 | version: 2 11 | jobs: 12 | build: 13 | <<: *defaults 14 | steps: 15 | - setup_remote_docker 16 | - checkout 17 | - run: 18 | name: Enable go modules 19 | command: | 20 | echo 'export GO111MODULE=on' >> $BASH_ENV 21 | - run: 22 | name: get tools 23 | command: make tools 24 | - run: 25 | name: run lint and tests for both services 26 | command: | 27 | scripts/test 28 | - run: 29 | name: build sso-auth 30 | command: | 31 | make dist/sso-auth 32 | - run: 33 | name: build sso-proxy 34 | command: | 35 | make dist/sso-proxy 36 | - persist_to_workspace: 37 | root: ~/go/src/github.com/buzzfeed/sso 38 | paths: 39 | - . 
40 | 41 | push-sso-dev-commit: 42 | <<: *defaults 43 | steps: 44 | - *attach_workspace 45 | - setup_remote_docker 46 | - run: 47 | name: push sso-dev commit tag 48 | command: | 49 | if [[ -n $DOCKER_USER ]]; then 50 | echo "$DOCKER_PASS" | docker login -u "$DOCKER_USER" --password-stdin 51 | make imagepush-commit 52 | fi 53 | 54 | push-sso-dev-latest: 55 | <<: *defaults 56 | steps: 57 | - *attach_workspace 58 | - setup_remote_docker 59 | - run: 60 | name: push sso-dev latest tag 61 | command: | 62 | if [[ -n $DOCKER_USER ]]; then 63 | echo "$DOCKER_PASS" | docker login -u "$DOCKER_USER" --password-stdin 64 | make imagepush-latest 65 | fi 66 | 67 | upload-codecov: 68 | <<: *defaults 69 | steps: 70 | - *attach_workspace 71 | - run: 72 | name: upload coverage file to codecov 73 | command: | 74 | bash <(curl -s https://codecov.io/bash) 75 | 76 | 77 | workflows: 78 | version: 2 79 | build-and-push: 80 | jobs: 81 | - build 82 | - push-sso-dev-commit: 83 | requires: 84 | - build 85 | - push-sso-dev-latest: 86 | requires: 87 | - build 88 | filters: 89 | branches: 90 | only: main 91 | - upload-codecov: 92 | requires: 93 | - build 94 | -------------------------------------------------------------------------------- /.codecov/codecov.yml: -------------------------------------------------------------------------------- 1 | # Codecov configuration file for this repo. 2 | # This file can be validated for syntax and layout issues using: 3 | # cat codecov.yml | curl --data-binary @- https://codecov.io/validate 4 | # as per Codecov's documentation.
5 | 6 | 7 | # https://docs.codecov.io/docs/coverage-configuration 8 | coverage: 9 | precision: 2 10 | round: down 11 | range: "50...80" 12 | 13 | # https://docs.codecov.io/docs/commit-status 14 | status: 15 | project: 16 | default: 17 | threshold: 2% 18 | base: auto 19 | patch: no 20 | changes: no 21 | 22 | parsers: 23 | gcov: 24 | branch_detection: 25 | conditional: yes 26 | loop: yes 27 | method: no 28 | macro: no 29 | 30 | # https://docs.codecov.io/docs/pull-request-comments 31 | comment: 32 | layout: "header, diff, files" 33 | behavior: default 34 | require_changes: no 35 | 36 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | 5 | --- 6 | 7 | **Describe the bug** 8 | A clear and concise description of what the bug is. 9 | 10 | **To Reproduce** 11 | Steps to reproduce the behavior: 12 | 1. Go to '...' 13 | 2. Click on '....' 14 | 3. Scroll down to '....' 15 | 4. See error 16 | 17 | **Expected behavior** 18 | A clear and concise description of what you expected to happen. 19 | 20 | **Screenshots** 21 | If applicable, add screenshots to help explain your problem. 22 | 23 | **Desktop (please complete the following information):** 24 | - OS: [e.g. iOS] 25 | - Browser [e.g. chrome, safari] 26 | - Version [e.g. 22] 27 | 28 | **Smartphone (please complete the following information):** 29 | - Device: [e.g. iPhone6] 30 | - OS: [e.g. iOS8.1] 31 | - Browser [e.g. stock browser, safari] 32 | - Version [e.g. 22] 33 | 34 | **Additional context** 35 | Add any other context about the problem here. 
36 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | 5 | --- 6 | 7 | **Is your feature request related to a problem? Please describe.** 8 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 9 | 10 | **Describe the solution you'd like** 11 | A clear and concise description of what you want to happen. 12 | 13 | **Describe alternatives you've considered** 14 | A clear and concise description of any alternative solutions or features you've considered. 15 | 16 | **Additional context** 17 | Add any other context or screenshots about the feature request here. 18 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name-template: 'v$NEXT_PATCH_VERSION' 2 | tag-template: 'v$NEXT_PATCH_VERSION' 3 | change-template: '*#$NUMBER - $TITLE' 4 | branches: main 5 | categories: 6 | - title: 'Features 🚀' 7 | labels: 8 | - 'feature' 9 | - 'enhancement' 10 | - title: 'Bug Fixes 🐛' 11 | labels: 12 | - 'bugfix' 13 | - 'bug' 14 | - title: 'Documentation 📖' 15 | label: 16 | - 'documentation' 17 | template: | 18 | ## What’s Changed 19 | 20 | $CHANGES 21 | 22 | Release Contributors: $CONTRIBUTORS 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # discourage people from committing env files, which are used for the 2 | # quickstart process and may contain secrets 3 | env 4 | 5 | # Don't check in the dist/ directory created by make 6 | dist/ 7 | 8 | # Compiled source # 9 | ################### 10 | *.com 11 | *.class 12 | *.dll 13 | *.exe 14 | 
*.o 15 | *.so 16 | 17 | # Packages # 18 | ############ 19 | # it's better to unpack these files and commit the raw source 20 | # git has its own built in compression methods 21 | *.7z 22 | *.dmg 23 | *.gz 24 | *.iso 25 | *.jar 26 | *.rar 27 | *.tar 28 | *.zip 29 | 30 | # Logs and databases # 31 | ###################### 32 | *.log 33 | *.sql 34 | *.sqlite 35 | 36 | # OS generated files # 37 | ###################### 38 | .DS_Store 39 | .DS_Store? 40 | ._* 41 | .Spotlight-V100 42 | .Trashes 43 | ehthumbs.db 44 | Thumbs.db 45 | 46 | # Byte-compiled / optimized / DLL files 47 | __pycache__/ 48 | *.py[cod] 49 | *$py.class 50 | 51 | # C extensions 52 | *.so 53 | 54 | # Test binary, build with `go test -c` 55 | *.test 56 | 57 | # Output of the go coverage tool, specifically when used with LiteIDE 58 | *.out 59 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. 
Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at sso@buzzfeed.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Thank you for your interest in contributing to BuzzFeed's SSO! 4 | 5 | ## Code of Conduct 6 | 7 | Help us keep SSO open and inclusive. Please read and follow our [Code of Conduct](CODE_OF_CONDUCT.md). 8 | 9 | ## Getting Started 10 | 11 | ### The Basics 12 | 13 | If you've never contributed to a project on Github before, take a look at [Rob Allen's beginner's guide to contributing to a Github project][begin-guide]. 14 | 15 | * Make sure you have a [GitHub account](https://github.com/signup/free) 16 | * Submit a ticket for your issue, assuming one does not already exist. 
In that issue: 17 | * Clearly describe the issue including steps to reproduce when it is a bug 18 | * Identify specific versions of the binaries and client libraries 19 | * Fork the repository on GitHub 20 | 21 | ### Making Changes 22 | 23 | * Create a branch from where you want to base your work 24 | * We typically name branches according to the following format: `helpful_name_` 25 | * Make commits of logical units 26 | * Make sure your [commit messages](https://chris.beams.io/posts/git-commit/) are in a clear and readable format, example: 27 | 28 | ``` 29 | sso: added structured logging 30 | 31 | * use Logrus to add structured logging 32 | * update logging documentation 33 | * ... 34 | ``` 35 | 36 | ### Quickstart 37 | 38 | If you want to get SSO up and running, please see our [quickstart guide](docs/quickstart.md). 39 | 40 | ### Running Tests 41 | 42 | All bug fixes and new features are expected to include tests, so it's helpful to know how to run them locally while you're developing! 43 | 44 | Here are the steps for running the tests for SSO: 45 | 46 | 1. [Install go][go] 47 | 2. Clone your fork of the repository. SSO uses Go modules, so the checkout does not need to live under your GOPATH: 48 | 49 | ```sh 50 | git clone git@github.com:YOUR_USERNAME/sso.git 51 | cd sso 52 | ``` 53 | 54 | 3. Download the module dependencies and install the tools used by the test script (`golint` and `statik`): 55 | 56 | ```sh 57 | go mod download 58 | make tools 59 | ``` 60 | 61 | 4.
Run the tests from the repository root: 62 | 63 | ```sh 64 | ./scripts/test 65 | ``` 66 | 67 | ### Submitting Changes 68 | 69 | * Push your changes to your branch in your fork of the repository 70 | * Submit a [pull request against BuzzFeed's repository](https://akrabat.com/the-beginners-guide-to-contributing-to-a-github-project/#step-3-create-the-pr) 71 | * Comment in the pull request when you're ready for the changes to be reviewed: `"ready for review"` 72 | 73 | ### Code Review 74 | 75 | Changes to `buzzfeed/sso` happen via GitHub pull requests, which is where at least one other engineer reviews and approves all code changes. Some tips for pull requests and code review: 76 | 77 | * Each pull request is a branch from `main` — there are no other long-lived branches 78 | * A single pull request is used for each set of changes — i.e., once a pull request has been opened, any follow-up commits to address code review and discussion should be pushed to that pull request instead of a new one 79 | * Before a pull request is merged, it must get a green check mark from our CI for passing tests 80 | * Before a pull request is merged, the author should take the opportunity to clean up and rewrite the commits on the branch (see [How to Write a Git Commit Message](https://chris.beams.io/posts/git-commit/)). 81 | * A maintainer will merge your changes once they're approved and ready!
82 | * **Most importantly: Practice [mindful communication][mindful-comms]!** 83 | 84 | [begin-guide]: https://akrabat.com/the-beginners-guide-to-contributing-to-a-github-project/ 85 | [mindful-comms]: https://kickstarter.engineering/a-guide-to-mindful-communication-in-code-reviews-48aab5282e5e 86 | [go]: https://golang.org/doc/install 87 | 88 | ### Merge Policy 89 | 90 | In order to allow SSO to be the most useful, while balancing feature needs with security concerns, we take an [Optimistic Merge](http://hintjens.com/blog:106) approach, which means: 91 | 92 | * We will do our best to merge every contribution that enables a requested feature or solves an issue that blocks folks using SSO. 93 | * As we find more efficient, secure, or otherwise stronger solutions, we will deprecate old features and replace them. 94 | 95 | Our strategy will be to announce deprecation as part of a stable release, and remove functionality as part of the next stable release. While this is a very quick cycle, we believe it is the best approach for keeping SSO up to date and preventing reliance on outdated featuresets. 96 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # ============================================================================= 2 | # build stage 3 | # 4 | # install golang dependencies & build binaries 5 | # ============================================================================= 6 | FROM golang:1.21 AS build 7 | 8 | ENV GOFLAGS='-ldflags=-s -ldflags=-w' 9 | ENV CGO_ENABLED=0 10 | ENV GO111MODULE=on 11 | 12 | WORKDIR /go/src/github.com/buzzfeed/sso 13 | 14 | COPY ./go.mod ./go.sum ./ 15 | RUN go mod download 16 | COPY . . 
17 | RUN cd cmd/sso-auth && go build -mod=readonly -o /bin/sso-auth 18 | RUN cd cmd/sso-proxy && go build -mod=readonly -o /bin/sso-proxy 19 | RUN cd cmd/sso-proxy/generate-request-signature && go build -mod=readonly -o /bin/sso-generate-request-signature 20 | 21 | # ============================================================================= 22 | # final stage 23 | # 24 | # add static assets and copy binaries from build stage 25 | # ============================================================================= 26 | FROM debian:bookworm-slim 27 | RUN apt-get update && apt-get install -y ca-certificates curl && rm -rf /var/lib/apt/lists/* \ 28 | && groupadd -r sso && useradd -r -g sso sso 29 | WORKDIR /sso 30 | COPY --from=build /bin/sso-* /bin/ 31 | USER sso 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any person obtaining a copy 2 | of this software and associated documentation files (the "Software"), to deal 3 | in the Software without restriction, including without limitation the rights 4 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 5 | copies of the Software, and to permit persons to whom the Software is 6 | furnished to do so, subject to the following conditions: 7 | 8 | The above copyright notice and this permission notice shall be included in 9 | all copies or substantial portions of the Software. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 13 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 14 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 15 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 16 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 17 | THE SOFTWARE. -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | version := "v3.1.1" 2 | 3 | commit := $(shell git rev-parse --short HEAD) 4 | 5 | 6 | build: dist/sso-auth dist/sso-proxy 7 | 8 | dist/sso-auth: 9 | mkdir -p dist 10 | go generate ./... 11 | go build -mod=readonly -o dist/sso-auth ./cmd/sso-auth 12 | 13 | dist/sso-proxy: 14 | mkdir -p dist 15 | go generate ./... 16 | go build -mod=readonly -o dist/sso-proxy ./cmd/sso-proxy 17 | 18 | tools: 19 | go install golang.org/x/lint/golint@latest 20 | go install github.com/rakyll/statik@latest 21 | 22 | test: 23 | ./scripts/test 24 | # Remove build output; -f keeps this from failing when dist/ does not exist. 25 | clean: 26 | rm -rf dist 27 | 28 | imagepush-commit: 29 | docker buildx create --use --platform=linux/arm64,linux/amd64 --name imagepush-commit 30 | docker buildx build -t buzzfeed/sso-dev:$(commit) . --platform linux/amd64,linux/arm64 --push 31 | docker buildx rm imagepush-commit 32 | 33 | imagepush-latest: 34 | docker buildx create --use --platform=linux/arm64,linux/amd64 --name imagepush-latest 35 | docker buildx build -t buzzfeed/sso-dev:latest . --platform linux/amd64,linux/arm64 --push 36 | docker buildx rm imagepush-latest 37 | 38 | releasepush: 39 | docker buildx create --use --platform=linux/arm64,linux/amd64 --name releasepush 40 | docker buildx build -t buzzfeed/sso:$(version) . --platform linux/amd64,linux/arm64 --push 41 | docker buildx build -t buzzfeed/sso:latest .
--platform linux/amd64,linux/arm64 --push 42 | docker buildx rm releasepush 43 | 44 | .PHONY: dist/sso-auth dist/sso-proxy tools 45 | -------------------------------------------------------------------------------- /PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Problem 2 | 3 | Please describe the problem this PR will solve. 4 | 5 | ## Solution 6 | 7 | What is the proposed solution? 8 | 9 | ## Notes 10 | 11 | Other pertinent information. Examples: a walkthrough of how the solution might work, why this solution is optimal compared to other possible solutions, or further TODOs beyond this PR. 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sso 2 | 3 | > See our launch [blog post](https://tech.buzzfeed.com/unleashing-the-a6a1a5da39d6) for more information! 4 | 5 | [![CircleCI](https://circleci.com/gh/buzzfeed/sso.svg?style=svg)](https://circleci.com/gh/buzzfeed/sso) 6 | [![MIT license](http://img.shields.io/badge/license-MIT-brightgreen.svg)](http://opensource.org/licenses/MIT) 7 | [![Docker Automated build](https://img.shields.io/docker/automated/buzzfeed/sso.svg)](https://hub.docker.com/r/buzzfeed/sso/) 8 | [![codecov.io](https://codecov.io/github/buzzfeed/sso/coverage.svg?branch=main)](https://codecov.io/github/buzzfeed/sso?branch=main) 9 | 10 | 11 | 12 | 13 | > Please take the [SSO Community Survey][sso_survey] to let us know how we're doing, and to help us plan our roadmap! 14 | 15 | ---- 16 | 17 | **sso** — lovingly known as *the S.S. Octopus* or *octoboi* — is the 18 | authentication and authorization system BuzzFeed developed to provide a secure, 19 | single sign-on experience for access to the many internal web apps used by our 20 | employees. 
21 | 22 | It uses Google (or another supported provider, such as Okta) as its authoritative 23 | OAuth2 provider, and authenticates users against a specific email domain. Further 24 | authorization based on group membership can be required on a per-upstream basis. 25 | 26 | The main idea behind **sso** is a "double OAuth2" flow, where `sso-auth` is the 27 | OAuth2 provider for `sso-proxy` and the upstream provider (e.g. Google) is the OAuth2 provider for `sso-auth`. 28 | 29 | [sso](https://github.com/buzzfeed/sso) is built on top of Bitly’s open source [oauth2_proxy](https://github.com/bitly/oauth2_proxy) 30 | 31 | In a nutshell: 32 | 33 | - If a user visits an `sso-proxy`-protected service (`foo.sso.example.com`) and does not have a session cookie, they are redirected to `sso-auth` (`sso-auth.example.com`). 34 | - If the user **does not** have a session cookie for `sso-auth`, 35 | they are prompted to log in via the usual Google OAuth2 flow, and then 36 | redirected back to `sso-proxy` where they will now be logged in (to 37 | `foo.sso.example.com`) 38 | - If the user *does* have a session cookie for `sso-auth` (e.g. they 39 | have already logged into `bar.sso.example.com`), they are 40 | transparently redirected back to `sso-proxy` where they will be logged in, 41 | without needing to go through the Google OAuth2 flow 42 | - `sso-proxy` transparently re-validates & refreshes the user's session with `sso-auth` 43 | 44 | ## Installation 45 | 46 | - [Prebuilt binary releases](https://github.com/buzzfeed/sso/releases) 47 | - [Docker][docker_hub] 48 | - From source: clone the repository and run `make tools build` (binaries are written to `dist/`) 49 | 50 | ## Quickstart 51 | 52 | Follow our [Quickstart guide](docs/quickstart.md) to spin up a local deployment 53 | of **sso** to get a feel for how it works! 54 | 55 | ## Code of Conduct 56 | 57 | Help us keep **sso** open and inclusive. Please read and follow our [Code of Conduct](CODE_OF_CONDUCT.md). 58 | 59 | ## Contributing 60 | 61 | Contributions to **sso** are welcome! Please follow our [contribution guideline](CONTRIBUTING.md).
62 | 63 | ### Issues 64 | 65 | Please file any issues you find in our [issue tracker](https://github.com/buzzfeed/sso/issues). 66 | 67 | ### Security Vulns 68 | 69 | If you come across any security vulnerabilities with the **sso** repo or software, please email security@buzzfeed.com. In your email, please request access to our [bug bounty program](https://hackerone.com/buzzfeed) so we can compensate you for any valid issues reported. 70 | 71 | ## Maintainers 72 | 73 | **sso** is actively maintained by the BuzzFeed Infrastructure teams. 74 | 75 | ## Notable forks 76 | 77 | - [pomerium](https://github.com/pomerium/pomerium) an identity-access proxy, inspired by BeyondCorp. 78 | 79 | [docker_hub]: https://hub.docker.com/r/buzzfeed/sso/ 80 | [sso_survey]: https://docs.google.com/forms/d/e/1FAIpQLSeRjf66ZSpMkSASMbYebx6QvECYRj9nUevOhUF2huw53sE6_g/viewform 81 | -------------------------------------------------------------------------------- /cmd/sso-auth/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "os" 7 | 8 | "github.com/buzzfeed/sso/internal/auth" 9 | "github.com/buzzfeed/sso/internal/pkg/httpserver" 10 | "github.com/buzzfeed/sso/internal/pkg/logging" 11 | ) 12 | 13 | func init() { 14 | logging.SetServiceName("sso-authenticator") 15 | } 16 | 17 | func main() { 18 | logger := logging.NewLogEntry() 19 | 20 | config, err := auth.LoadConfig() 21 | if err != nil { 22 | logger.Error(err, "error loading in config from env vars") 23 | os.Exit(1) 24 | } 25 | 26 | err = config.Validate() 27 | if err != nil { 28 | logger.Error(err, "error validating config") 29 | os.Exit(1) 30 | } 31 | 32 | sc := config.MetricsConfig.StatsdConfig 33 | statsdClient, err := auth.NewStatsdClient(sc.Host, sc.Port) 34 | if err != nil { 35 | logger.Error(err, "error creating statsd client") 36 | os.Exit(1) 37 | } 38 | 39 | authMux, err := auth.NewAuthenticatorMux(config, statsdClient) 40 | if err 
!= nil { 41 | logger.Error(err, "error creating new AuthenticatorMux") 42 | os.Exit(1) 43 | } 44 | defer authMux.Stop() 45 | 46 | // we leave the message field blank, which will inherit the stdlib timeout page which is sufficient 47 | // and better than other naive messages we would currently place here 48 | timeoutHandler := http.TimeoutHandler(authMux, config.ServerConfig.TimeoutConfig.Request, "") 49 | 50 | s := &http.Server{ 51 | Addr: fmt.Sprintf(":%d", config.ServerConfig.Port), 52 | ReadTimeout: config.ServerConfig.TimeoutConfig.Read, 53 | WriteTimeout: config.ServerConfig.TimeoutConfig.Write, 54 | Handler: auth.NewLoggingHandler(os.Stdout, timeoutHandler, config.LoggingConfig.Enable, statsdClient), 55 | } 56 | 57 | if err := httpserver.Run(s, config.ServerConfig.TimeoutConfig.Shutdown, logger); err != nil { 58 | logger.WithError(err).Fatal("error running server") 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /cmd/sso-proxy/generate-request-signature/generate_request_signature.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "net/http" 7 | "os" 8 | "strings" 9 | 10 | "github.com/buzzfeed/sso/internal/pkg/logging" 11 | "github.com/buzzfeed/sso/internal/proxy" 12 | ) 13 | 14 | // Name of the header used to transmit the signature computed for the request. 
15 | var signatureHeader = "Sso-Signature" 16 | var signingKeyHeader = "kid" 17 | 18 | // main builds an HTTP request from the -url, -method and -body flags plus any 19 | // positional "Name:Value" header arguments, signs it with a request signer 20 | // configured from env vars, and prints the request and the resulting signature. 21 | func main() { 22 | logger := logging.NewLogEntry() 23 | var requestSigner *proxy.RequestSigner 24 | 25 | config, err := proxy.LoadConfig() 26 | if err != nil { 27 | logger.Error(err, "error loading in config from env vars") 28 | os.Exit(1) 29 | } 30 | 31 | urlPtr := flag.String("url", "", "URL of request to sign") 32 | methodPtr := flag.String("method", "GET", "Method of request to sign") 33 | bodyPtr := flag.String("body", "", "Body of request to sign") 34 | 35 | err = config.Validate() 36 | if err != nil { 37 | logger.Error(err, "error validating config") 38 | os.Exit(1) 39 | } 40 | 41 | requestSigner, err = proxy.NewRequestSigner(config.RequestSignerConfig.Key) 42 | if err != nil { 43 | logger.Error(err, "error creating request signer") 44 | os.Exit(1) 45 | } 46 | 47 | flag.Parse() 48 | args := flag.Args() 49 | 50 | // Pass the *strings.Reader itself rather than dereferencing and copying it. 51 | requestBody := strings.NewReader(*bodyPtr) 52 | 53 | req, err := http.NewRequest(*methodPtr, *urlPtr, requestBody) 54 | if err != nil { 55 | logger.Error(err, "error creating request") 56 | os.Exit(1) 57 | } 58 | 59 | // Each positional argument is a header of the form "Name:Value". Split on the 60 | // first colon only so values containing colons (URLs, timestamps) stay intact, 61 | // and reject arguments without a colon instead of panicking on hKv[1]. 62 | for _, h := range args { 63 | hKv := strings.SplitN(h, ":", 2) 64 | if len(hKv) != 2 { 65 | logger.Error(fmt.Errorf("invalid header argument %q", h), "headers must be given as Name:Value") 66 | os.Exit(1) 67 | } 68 | req.Header.Set(strings.TrimSpace(hKv[0]), strings.TrimSpace(hKv[1])) 69 | } 70 | 71 | err = requestSigner.Sign(req) 72 | if err != nil { 73 | logger.Error(err, "error signing request") 74 | os.Exit(1) 75 | } 76 | 77 | signature := req.Header.Get(signatureHeader) 78 | keyHeader := req.Header.Get(signingKeyHeader) 79 | 80 | fmt.Printf("URL: %v\n", *urlPtr) 81 | fmt.Printf("method: %v\n", *methodPtr) 82 | fmt.Printf("body: %v\n", *bodyPtr) 83 | fmt.Printf("headers: %v\n", req.Header) 84 | fmt.Println("==================================================================") 85 | 86 | fmt.Printf("signature: %v\n", signature) 87 | fmt.Printf("keyHeader: %v\n", keyHeader) 88 | } 89 | -------------------------------------------------------------------------------- /cmd/sso-proxy/main.go:
-------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "os" 7 | "time" 8 | 9 | "github.com/buzzfeed/sso/internal/pkg/httpserver" 10 | "github.com/buzzfeed/sso/internal/pkg/logging" 11 | "github.com/buzzfeed/sso/internal/proxy" 12 | "github.com/buzzfeed/sso/internal/proxy/collector" 13 | ) 14 | 15 | func init() { 16 | logging.SetServiceName("sso-proxy") 17 | } 18 | 19 | func main() { 20 | logger := logging.NewLogEntry() 21 | 22 | config, err := proxy.LoadConfig() 23 | if err != nil { 24 | logger.Error(err, "error loading in config from env vars") 25 | os.Exit(1) 26 | } 27 | 28 | err = config.Validate() 29 | if err != nil { 30 | logger.Error(err, "error validating config") 31 | os.Exit(1) 32 | } 33 | 34 | sc := config.MetricsConfig.StatsdConfig 35 | statsdClient, err := proxy.NewStatsdClient(sc.Host, sc.Port) 36 | if err != nil { 37 | logger.Error(err, "error creating statsd client") 38 | os.Exit(1) 39 | } 40 | 41 | // we setup a runtime collector to emit stats 42 | go func() { 43 | c := collector.New(statsdClient, 30*time.Second) 44 | c.Run() 45 | }() 46 | 47 | err = proxy.SetUpstreamConfigs( 48 | &config.UpstreamConfigs, 49 | config.SessionConfig.CookieConfig, 50 | &config.ServerConfig, 51 | ) 52 | if err != nil { 53 | logger.Error(err, "error setting upstream configs") 54 | os.Exit(1) 55 | } 56 | 57 | ssoProxy, err := proxy.New(config, statsdClient) 58 | if err != nil { 59 | logger.Error(err, "error creating sso proxy") 60 | os.Exit(1) 61 | } 62 | 63 | loggingHandler := proxy.NewLoggingHandler(os.Stdout, 64 | ssoProxy, 65 | config.LoggingConfig, 66 | statsdClient, 67 | ) 68 | 69 | s := &http.Server{ 70 | Addr: fmt.Sprintf(":%d", config.ServerConfig.Port), 71 | ReadTimeout: config.ServerConfig.TimeoutConfig.Read, 72 | WriteTimeout: config.ServerConfig.TimeoutConfig.Write, 73 | Handler: loggingHandler, 74 | } 75 | 76 | if err := httpserver.Run(s, 
config.ServerConfig.TimeoutConfig.Shutdown, logger); err != nil { 77 | logger.WithError(err).Fatal("error running server") 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /docs/API.md: -------------------------------------------------------------------------------- 1 | # API 2 | 3 | ## User-facing endpoints 4 | 5 | ### GET /sign_in 6 | Serves the sign in button. 7 | 8 | |Query Parameters| | 9 | |:---|:---| 10 | |`client_id`|The unique client id for the SSO proxy| 11 | |`redirect_uri`|The redirect URI to return to when the authentication with Google is complete.| 12 | |`redirect_sig`|The signature for the redirect URI. See **Redirect Validation** below for a description on how the redirect URI is signed.| 13 | 14 | ### GET /sign_out 15 | Serves the sign out button. This page requires a valid redirect URI, so you can get here via the proxy: `{{service}}.sso.example.com/oauth2/sign_out` 16 | 17 | |Query Parameters| | 18 | |:---|:---| 19 | |`redirect_uri`|A redirect URI pointing to the proxy to return to when de-authentication is complete.| 20 | |`redirect_sig`|The signature for the redirect URI. See **Redirect Validation** below for a description on how the redirect URI is signed.| 21 | 22 | ### GET /static/... 23 | This serves static CSS and image files for the sign in, sign out, and error 24 | pages for both `sso_auth` and `sso_proxy`. 25 | 26 | ## Endpoints to authenticate a user 27 | 28 | ### POST /start 29 | This is the entrance to the OAuth flow, which is started via the HTML form 30 | `POST` from "the button". 31 | 32 | |Parameters| | 33 | |:---|:---| 34 | |`redirect_uri`|The redirect URI back to the _authenticator_. 
This URI contains query parameters, `redirect_uri` (AKA the nested redirect) and `redirect_sig`, which redirect back to the proxy and sign the nested redirect, respectively.| 35 | 36 | ### GET /oauth2/callback 37 | Once the user has authenticated with the provider, they are redirected to 38 | this endpoint, which sets the Authenticator Cookie and then redirects 39 | them back to Proxy. 40 | 41 | ### POST /sign_out 42 | Revokes the token with the provider 43 | 44 | |Parameters| | 45 | |:---|:---| 46 | |`redirect_uri`|A redirect URI pointing to the proxy to return to when de-authentication is complete.| 47 | |`redirect_sig`|The signature for the redirect URI. See **Redirect Validation** below for a description on how the redirect URI is signed.| 48 | 49 | ## Endpoints to function as an OAuth Provider 50 | 51 | ### GET /profile 52 | This method returns the Google group membership given an email address. The 53 | proxy uses this list of groups to determine whether a user has access to the 54 | upstream service. 55 | 56 | |Parameters| | 57 | |:---|:---| 58 | |`email`|The user's email| 59 | |`client_id`|The unique client id for the SSO proxy| 60 | |`X-Client-Secret` header|The proxy's client secret. This is a request being made directly from `sso_proxy` to `sso_auth`| 61 | 62 | |Response| | 63 | |:---|:---| 64 | |`email`|The user's email address| 65 | |`groups`|A list of all Google groups memberships for the user. `sso_proxy` is responsible for validating that the user is a member of the correct group for the upstream.| 66 | 67 | ### GET /validate 68 | Returns OK if an access token is valid. 69 | 70 | |Parameters| | 71 | |:---|:---| 72 | |`client_id`|The unique client id for the SSO proxy| 73 | |`X-Access-Token` header|The access token from Google| 74 | |`X-Client-Secret` header|The proxy's client secret. This is a request being made directly from `sso_proxy` to `sso_auth`| 75 | 76 | ### POST /redeem 77 | Redeem an access code for an access token and refresh token. 
78 | 79 | |Parameters| | 80 | |:---|:---| 81 | |`code`|An encrypted payload which contains the access token.| 82 | |`client_id`|The unique client id for the SSO proxy| 83 | |`client_secret`|The proxy's client secret. This is a request being made directly from `sso_proxy` to `sso_auth`| 84 | 85 | |Response| | 86 | |:---|:---| 87 | |`access_token`|The access token from Google| 88 | |`refresh_token`|The refresh token from Google| 89 | |`expires_in`|The expiration timestamp of `access_token`| 90 | |`email`|The user's email address| 91 | 92 | ### POST /refresh 93 | Refresh the access token using the refresh token. 94 | 95 | |Parameters| | 96 | |:---|:---| 97 | |`refresh_token`|The refresh token from Google.| 98 | |`client_id`|The unique client id for the SSO proxy| 99 | |`client_secret`|The proxy's client secret. This is a request being made directly from `sso_proxy` to `sso_auth`| 100 | 101 | |Response| | 102 | |:---|:---| 103 | |`access_token`|The access token from Google| 104 | |`expires_in`|The expiration timestamp of `access_token`| 105 | -------------------------------------------------------------------------------- /docs/adr/2018-08-16-circuit-breaker.md: -------------------------------------------------------------------------------- 1 | # Architecture Design Record: SS🐙 Authenticator Circuit Breaker 2 | 3 | > Gracefully handle provider-side faults and failures. 4 | 5 | __Status:__ Implemented 6 | 7 | #### Context 8 | SS🐙 fundamentally relies on a third-party provider service to authenticate requests (e.g. OAuth2 9 | via Google). The provider may experience transient faults for any number of reasons - perhaps the 10 | provider is experiencing a transient outage, or our traffic has been temporarily rate limited. We 11 | wanted to build SS🐙 to anticipate and handle the presence of faults during communication with 12 | authentication providers, no matter the underlying cause. 
13 | 14 | #### Decision 15 | We implemented a generic `circuit.Breaker` type, which implements the "Circuit Breaker" design 16 | pattern made popular by Michael T. Nygard's book, ["Release It!"]( 17 | https://pragprog.com/book/mnee/release-it). All requests to the provider service are issued through 18 | a stateful `circuit.Breaker` instance, which tracks the frequency of request failures. When the 19 | `Breaker` sees that enough requests are failing, it temporarily suspends all outgoing traffic to the 20 | provider (i.e. enters the `Open` state). After some time, the `Breaker` transitions to a `HalfOpen` 21 | state, in which a limited number of outbound requests are allowed. If failures persist, then the 22 | `Breaker` will once again suspend outbound traffic, re-enter the `Open` state, and typically will wait 23 | for a longer interval of time before trying again. If instead the `Breaker` observes that requests 24 | are consistently succeeding, then it will resume all outbound traffic (i.e. enter the `Closed` state). 25 | 26 | The SS🐙 [`docs`](/docs/) directory contains a [Circuit Breaker]( 27 | /docs/architecture/circuit-breaker.md) document with more details. The implementation can be found 28 | in [`breaker.go`](/internal/auth/circuit/breaker.go). 29 | 30 | #### Consequences 31 | SS🐙 now utilizes a robust and generic strategy for handling faults originating from the upstream 32 | authentication provider service.
33 | -------------------------------------------------------------------------------- /docs/diagrams/sso_request_flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/docs/diagrams/sso_request_flow.png -------------------------------------------------------------------------------- /docs/generate_request_signature.md: -------------------------------------------------------------------------------- 1 | ## sso_proxy/generate-request-signature 2 | `sso-generate-request-signature` is a command-line tool that is provided alongside `sso_proxy` to facilitate upstream testing of SSO's proxied request signatures, or of libraries that attempt to validate signatures. Note that running `sso-generate-request-signature` requires the same configuration to be in place as running `sso-proxy` would, because the tool loads and validates the full proxy configuration; in particular, a `REQUESTSIGNER_KEY` environment variable must be set. Assuming that configuration exists, the tool can be run with three optional flags, followed by as many request headers as desired in `name:value` format: 3 | 4 | - `-url`: the path of the URL of the request to calculate the signature on (default: "") 5 | - `-method`: the method of the request (default: "GET") (note: the method is not used in the signature calculation) 6 | - `-body`: for PUT or POST requests, a string representing the body of the request to calculate a signature on (default: "") 7 | 8 | ### examples 9 | `sso-generate-request-signature -url "/foo" -method "POST" -body "{}" x-header-1:bar x-header-2:baz` 10 | 11 | `sso-generate-request-signature -url "/bar"` 12 | -------------------------------------------------------------------------------- /docs/img/choose-account.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/docs/img/choose-account.jpg
-------------------------------------------------------------------------------- /docs/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/docs/img/logo.png -------------------------------------------------------------------------------- /docs/img/okta/okta-app-settings.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/docs/img/okta/okta-app-settings.jpg -------------------------------------------------------------------------------- /docs/img/okta/okta-auth-server-claims.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/docs/img/okta/okta-auth-server-claims.jpg -------------------------------------------------------------------------------- /docs/img/okta/okta-auth-server-scope.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/docs/img/okta/okta-auth-server-scope.jpg -------------------------------------------------------------------------------- /docs/img/okta/okta-homepage-api.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/docs/img/okta/okta-homepage-api.jpg -------------------------------------------------------------------------------- /docs/img/payload-screen.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/docs/img/payload-screen.jpg -------------------------------------------------------------------------------- 
/docs/img/setup-admin_api.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/docs/img/setup-admin_api.jpg -------------------------------------------------------------------------------- /docs/img/setup-api_client_access.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/docs/img/setup-api_client_access.jpg -------------------------------------------------------------------------------- /docs/img/setup-consent_screen.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/docs/img/setup-consent_screen.jpg -------------------------------------------------------------------------------- /docs/img/setup-create_credentials.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/docs/img/setup-create_credentials.jpg -------------------------------------------------------------------------------- /docs/img/setup-create_service_account.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/docs/img/setup-create_service_account.jpg -------------------------------------------------------------------------------- /docs/img/setup-credentials.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/docs/img/setup-credentials.jpg -------------------------------------------------------------------------------- /docs/img/setup-manage_resources.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/docs/img/setup-manage_resources.jpg -------------------------------------------------------------------------------- /docs/img/setup-notification.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/docs/img/setup-notification.jpg -------------------------------------------------------------------------------- /docs/img/setup-security_control.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/docs/img/setup-security_control.jpg -------------------------------------------------------------------------------- /docs/img/start-auth.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/docs/img/start-auth.jpg -------------------------------------------------------------------------------- /docs/img/start-script.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/docs/img/start-script.jpg -------------------------------------------------------------------------------- /docs/okta_provider_setup.md: -------------------------------------------------------------------------------- 1 | # Okta Provider Setup & Configuration 2 | 3 | In order for **sso** to use Okta as its OAuth provider, there are some manual steps that must 4 | be followed to create and configure the necessary credentials. 5 | 6 | ## 1. Create an Authorization Server 7 | Use a web browser to access your Okta account administrator console.
8 | 9 | Select the 'Security' dropdown and then 'API'. If you then don't see 'Authorization Servers', 10 | you will need to contact Okta Support to have this feature enabled. 11 | (If you are logged in to an Okta Developer console, this will instead be under the 'API' dropdown and then 'Authorization Servers'.) 12 | 13 | 14 | 15 | For the purpose of this example, we will use the default server that is provided; however, you can create a new one if preferred. 16 | 17 | ## 2. Configure your Authorization Server 18 | 19 | ### Server Scopes 20 | 21 | Select your Authorization server, and then 'Scopes'. The only additional scope you will need to add is a scope for groups. 22 | 23 | Click 'Add Scope' and fill in the fields with the below options: 24 | 25 | 26 | 27 | It is important that the **Name** is exactly `groups`. The other fields can be customised; however, we recommend the values above. 28 | 29 | ### Server Claims 30 | 31 | Select 'Claims'. You now need to add a claim for the groups scope that was just added. 32 | Click 'Add Claim' and fill in the fields with the below options: 33 | 34 | 35 | 36 | The regex for `Filter` can be changed to fit your organisational requirements. 37 | It determines whether all of the groups a user is a member of (among those assigned to the application you will create below) 38 | are returned, or only the subset of those groups that match the given regex. 39 | The above setting will return all groups the user is a member of. 40 | 41 | ## 3. Create a new Application 42 | Select the 'Applications' dropdown and then 'Applications'. Click 'Add Application' and then 'Create New App'. 43 | 44 | For 'Platform', choose `Web`, and for 'Sign on method' choose `OpenID Connect`. You will be asked to fill in some fields: 45 | - **Application Name**: Any appropriate name. 46 | - **Application Logo**: Optional - an image for your app. 47 | - **Login redirect URIs**: Add the URI of your `sso-auth` deployment, with the path suffix `/okta/callback`.
48 | For example, if `sso-auth` will be accessible at the domain `sso-auth.example.com`, then add the URI 49 | `https://sso-auth.example.com/okta/callback`. Multiple URIs can be added if required. 50 | - **Logout redirect URIs**: This can be left blank. 51 | 52 | **⚡️ Note**: If you're following the [Quickstart guide](https://github.com/buzzfeed/sso/blob/main/docs/quickstart.md), use `http://sso-auth.localtest.me/okta/callback` as the Authorized redirect URI. 53 | 54 | ## 4. Finish configuring your Application 55 | 56 | ### Configure remaining settings 57 | 58 | Once your application is created, you will be redirected to the Application's settings page. 59 | We want to allow `sso-auth` to request Refresh Tokens, so select 'Edit' under 'General Settings' 60 | and make sure `Refresh Token` is checked under 'Client acting on behalf of a user' 61 | 62 | 63 | 64 | ### Client Credentials 65 | 66 | At the bottom of this same page you will see a section called 'Client Credentials' containing a `Client ID` and a `Client secret`. 67 | Copy both of these values to a safe and secure location - these are required to configure your `sso` deployment. 68 | 69 | **⚡️Note:** If you're following the [Quickstart guide](https://github.com/buzzfeed/sso/blob/main/docs/quickstart.md), **stop here!** 70 | You'll add these credentials to `quickstart/env` as instructed in [the guide](https://github.com/buzzfeed/sso/blob/main/docs/quickstart.md). 71 | 72 | ### Assignments 73 | 74 | Lastly, any users that will be authenticating through SSO need to be assigned to this Application. They can be assigned via the 'People' assignment 75 | and/or the 'Groups' assignment within the Application settings itself. 76 | 77 | **⚡️ Note**: Only groups that are assigned to this app will be returned along with the `groups` scope/claim. 
78 | -------------------------------------------------------------------------------- /docs/sso_authenticator_config.md: -------------------------------------------------------------------------------- 1 | # Available Configuration Variables 2 | We currently use environment variables to read configuration options; below is a list of environment variables and 3 | their corresponding types that the sso authenticator will read. 4 | 5 | Defaults for the below settings can be found here: https://github.com/buzzfeed/sso/blob/main/internal/auth/configuration.go#L66-L117 6 | 7 | 8 | ## Session and Server configuration 9 | 10 | ### Session 11 | ``` 12 | SESSION_COOKIE_NAME - string - name associated with the session cookie 13 | SESSION_COOKIE_SECRET - string - seed string for secure cookies 14 | SESSION_COOKIE_DOMAIN - string - cookie domain to force cookies to (e.g. .yourcompany.com)* 15 | SESSION_KEY - string - seed string for secure auth codes 16 | SESSION_COOKIE_SECURE - bool - set secure (HTTPS) cookie flag 17 | SESSION_COOKIE_HTTPONLY - bool - set 'httponly' cookie flag 18 | SESSION_COOKIE_REFRESH - time.Duration - duration to refresh the cookie after 19 | SESSION_COOKIE_EXPIRE - time.Duration - duration that cookie is valid for 20 | SESSION_LIFETIME - time.Duration - the session TTL 21 | ``` 22 | 23 | 24 | ### Client 25 | 26 | ``` 27 | CLIENT_PROXY_ID - string - Client ID matching the SSO Proxy client ID 28 | CLIENT_PROXY_SECRET - string - Client secret matching the SSO Proxy client secret 29 | ``` 30 | 31 | 32 | ### Server 33 | ``` 34 | SERVER_SCHEME - string - scheme the server will use, e.g.
`https` 35 | SERVER_HOST - string - host header that's required on incoming requests 36 | SERVER_PORT - int - port the http server listens on 37 | SERVER_TIMEOUT_REQUEST - time.Duration - overall request timeout 38 | SERVER_TIMEOUT_WRITE - time.Duration - write request timeout 39 | SERVER_TIMEOUT_READ - time.Duration - read request timeout 40 | SERVER_TIMEOUT_SHUTDOWN - time.Duration - time to allow in-flight requests to complete before server shutdown 41 | ``` 42 | 43 | 44 | ### Authorization 45 | ``` 46 | AUTHORIZE_PROXY_DOMAINS - []string - only redirect to the specified proxy domains. 47 | AUTHORIZE_EMAIL_DOMAINS - []string - authenticate emails with the specified domains. Use `*` to authenticate any email 48 | AUTHORIZE_EMAIL_ADDRESSES - []string - authenticate emails with the specified email addresses. Use `*` to authenticate any email 49 | ``` 50 | 51 | ## Logging and Monitoring Configuration 52 | ### StatsD 53 | ``` 54 | METRICS_STATSD_PORT - int - port of the StatsD server that the statsd client sends metrics to 55 | METRICS_STATSD_HOST - string - hostname of the StatsD server that the statsd client sends metrics to 56 | ``` 57 | 58 | ### Logging 59 | ``` 60 | LOGGING_ENABLE - bool - enable request logging 61 | LOGGING_LEVEL - string - level at which to log, e.g. 'INFO' 62 | ``` 63 | 64 | ## Provider configuration 65 | 66 | `*` in the below variables acts as a logical identifier to group configuration variables together for any one provider. 67 | This should be changed to an identifier that makes sense for your use case. 68 | ``` 69 | PROVIDER_*_TYPE - string - determines the type of provider (supported options: google, okta) 70 | PROVIDER_*_SLUG - string - unique provider 'slug' that is used to separate and create routes to individual providers. 71 | PROVIDER_*_CLIENT_ID - string - OAuth Client ID 72 | PROVIDER_*_CLIENT_SECRET - string - OAuth Client secret 73 | PROVIDER_*_SCOPE - string - OAuth scopes the provider will use.
Defaults to the standard set of scopes pre-set in each individual provider's 74 | file, which this configuration variable overrides. 75 | ``` 76 | 77 | ### Google provider specific 78 | ``` 79 | PROVIDER_*_GOOGLE_CREDENTIALS - string - the path to the Google account's json credential file 80 | PROVIDER_*_GOOGLE_IMPERSONATE - string - the Google account to impersonate for API calls 81 | ``` 82 | 83 | ### Okta provider specific 84 | ``` 85 | PROVIDER_*_OKTA_URL - string - the URL for your Okta domain, e.g. `<your-okta-domain>.okta.com` 86 | PROVIDER_*_OKTA_SERVER - string - the authorisation server ID 87 | ``` 88 | 89 | ### Group refresh and caching 90 | ``` 91 | PROVIDER_*_GROUPCACHE_INTERVAL_REFRESH - time.Duration - cache TTL for the groups fillcache mechanism used to preemptively fill group caches 92 | PROVIDER_*_GROUPCACHE_INTERVAL_PROVIDER - time.Duration - cache TTL for the group cache provider used for on-demand group caching 93 | ``` 94 | -------------------------------------------------------------------------------- /docs/sso_proxy_config.md: -------------------------------------------------------------------------------- 1 | # Available Configuration Variables 2 | We currently use environment variables to read configuration options; below is a list of environment variables and 3 | their corresponding types that the sso proxy will read.
4 | 5 | Defaults for the below settings can be found here: https://github.com/buzzfeed/sso/blob/main/internal/proxy/configuration.go#L61-L105 6 | 7 | 8 | ## Session and Server configuration 9 | 10 | ### Session 11 | ``` 12 | SESSION_COOKIE_NAME - string - name associated with the session cookie 13 | SESSION_COOKIE_SECRET - string - seed string for secure cookies 14 | SESSION_COOKIE_DOMAIN - string - cookie domain to force cookies to (e.g. .yourcompany.com)* 15 | SESSION_COOKIE_SECURE - bool - set secure (HTTPS) cookie flag 16 | SESSION_COOKIE_HTTPONLY - bool - set 'httponly' cookie flag 17 | SESSION_TTL_LIFETIME - time.Duration - 'time-to-live' of a session lifetime 18 | SESSION_TTL_VALID - time.Duration - 'time-to-live' of a valid session 19 | SESSION_TTL_GRACEPERIOD - time.Duration - time period in which session data can be reused while the provider is unavailable. 20 | ``` 21 | 22 | ### Server 23 | ``` 24 | SERVER_PORT - int - port the http server listens on 25 | SERVER_TIMEOUT_SHUTDOWN - time.Duration - time to allow in-flight requests to complete before server shutdown 26 | SERVER_TIMEOUT_READ - time.Duration - read request timeout 27 | SERVER_TIMEOUT_WRITE - time.Duration - write request timeout 28 | ``` 29 | 30 | ### Client 31 | ``` 32 | CLIENT_ID - string - the OAuth Client ID, e.g. "123456.apps.googleusercontent.com" 33 | CLIENT_SECRET - string - the OAuth Client secret 34 | ``` 35 | 36 | ### Request Signer 37 | ``` 38 | REQUESTSIGNER_KEY - string - RSA private key used for digitally signing requests 39 | ``` 40 | 41 | ## Upstream and Provider Configuration 42 | 43 | ### Upstream 44 | For further upstream configuration, see https://github.com/buzzfeed/sso/blob/main/docs/sso_config.md.
45 | ``` 46 | UPSTREAM_DEFAULT_EMAIL_DOMAINS - []string - default setting for upstream `allowed_email_domains` variable 47 | UPSTREAM_DEFAULT_EMAIL_ADDRESSES - []string - default setting for upstream `allowed_email_addresses` variable 48 | UPSTREAM_DEFAULT_GROUPS - []string - default setting for upstream `allowed_groups` variable 49 | UPSTREAM_DEFAULT_TIMEOUT - time.Duration - default setting for upstream `timeout` variable 50 | UPSTREAM_DEFAULT_TCP_RESET_DEADLINE - time.Duration - default time period to wait for a response from an upstream 51 | UPSTREAM_DEFAULT_PROVIDER - string - default setting for the upstream `provider_slug` variable 52 | UPSTREAM_CONFIGS_FILE - string - path to the file containing upstream configurations 53 | UPSTREAM_SCHEME - string - the scheme used for upstreams (e.g. `https`) 54 | UPSTREAM_CLUSTER - string - the cluster in which this is running, used within upstream configuration 55 | ``` 56 | 57 | ### Provider 58 | ``` 59 | PROVIDER_TYPE - string - the 'type' of upstream provider to use (at this time, only a provider type of 'sso' is supported) 60 | PROVIDER_URL_EXTERNAL - string - the external URL for the upstream provider in this environment (e.g. "https://sso-auth.example.com") 61 | PROVIDER_URL_INTERNAL - string - the internal URL for the upstream provider in this environment (e.g.
"https://sso-auth-int.example.com") 62 | PROVIDER_SCOPE - string - OAuth `scope` sent with provider requests 63 | ``` 64 | 65 | ## Logging and Monitoring Configuration 66 | ### StatsD 67 | ``` 68 | METRICS_STATSD_PORT - int - port of the StatsD server that the statsd client sends metrics to 69 | METRICS_STATSD_HOST - string - hostname of the StatsD server that the statsd client sends metrics to 70 | ``` 71 | 72 | ### Logging 73 | ``` 74 | LOGGING_ENABLE - bool - enable request logging 75 | ``` 76 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/buzzfeed/sso 2 | 3 | require ( 4 | github.com/18F/hmacauth v0.0.0-20151013130326-9232a6386b73 5 | github.com/aws/aws-sdk-go v1.23.12 6 | github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3 7 | github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect 8 | github.com/datadog/datadog-go v0.0.0-20180822151419-281ae9f2d895 9 | github.com/gorilla/mux v1.7.2 10 | github.com/gorilla/websocket v1.4.0 11 | github.com/imdario/mergo v0.3.7 12 | github.com/mccutchen/go-httpbin v1.1.1 13 | github.com/micro/go-micro v1.5.0 14 | github.com/miscreant/miscreant.go v0.0.0-20200214223636-26d376326b75 15 | github.com/mitchellh/mapstructure v1.1.2 16 | github.com/rakyll/statik v0.1.7 17 | github.com/sirupsen/logrus v1.4.2 18 | golang.org/x/net v0.21.0 // indirect 19 | golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 20 | golang.org/x/sync v0.1.0 21 | golang.org/x/sys v0.19.0 // indirect 22 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898 23 | google.golang.org/api v0.5.0 24 | gopkg.in/yaml.v2 v2.2.2 25 | ) 26 | 27 | go 1.14 28 | -------------------------------------------------------------------------------- /internal/auth/conf.d/gitkeep: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/buzzfeed/sso/549155a64d6c5f8916ed909cfa4e340734056284/internal/auth/conf.d/gitkeep -------------------------------------------------------------------------------- /internal/auth/error.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | 8 | "github.com/buzzfeed/sso/internal/auth/providers" 9 | log "github.com/buzzfeed/sso/internal/pkg/logging" 10 | ) 11 | 12 | var ( 13 | // ErrUserNotAuthorized is an error for unauthorized users. 14 | ErrUserNotAuthorized = errors.New("user not authorized") 15 | ) 16 | 17 | // HTTPError stores the status code and a message for a given HTTP error. 18 | type HTTPError struct { 19 | Code int 20 | Message string 21 | } 22 | 23 | // Error fulfills the error interface, returning a string representation of the error. 24 | func (h HTTPError) Error() string { 25 | return fmt.Sprintf("%d %s: %s", h.Code, http.StatusText(h.Code), h.Message) 26 | } 27 | 28 | func codeForError(err error) int { 29 | switch err { 30 | case providers.ErrBadRequest: 31 | return 400 32 | case providers.ErrTokenRevoked: 33 | return 401 34 | case providers.ErrRateLimitExceeded: 35 | return 429 36 | case providers.ErrServiceUnavailable: 37 | return 503 38 | case ErrUserNotAuthorized: 39 | return 401 40 | } 41 | return 500 42 | } 43 | 44 | // ErrorResponse renders an error page for errors given a message and a status code. 
45 | func (p *Authenticator) ErrorResponse(rw http.ResponseWriter, req *http.Request, message string, code int) { 46 | logger := log.NewLogEntry() 47 | 48 | if req.Header.Get("Accept") == "application/json" { 49 | var response struct { 50 | Error string `json:"error"` 51 | } 52 | response.Error = message 53 | writeJSONResponse(rw, code, response) 54 | } else { 55 | title := http.StatusText(code) 56 | logger.WithHTTPStatus(code).WithPageTitle(title).WithPageMessage(message).Info( 57 | "error page") 58 | rw.WriteHeader(code) 59 | t := struct { 60 | Code int 61 | Title string 62 | Message string 63 | }{ 64 | Code: code, 65 | Title: title, 66 | Message: message, 67 | } 68 | p.templates.ExecuteTemplate(rw, "error.html", t) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /internal/auth/http.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "encoding/json" 5 | "io" 6 | "net/http" 7 | ) 8 | 9 | // writeJSONResponse is a helper that sets the application/json header and writes a response. 
10 | func writeJSONResponse(rw http.ResponseWriter, code int, response interface{}) { 11 | rw.Header().Set("Content-Type", "application/json") 12 | rw.WriteHeader(code) 13 | 14 | err := json.NewEncoder(rw).Encode(response) 15 | if err != nil { 16 | io.WriteString(rw, err.Error()) 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /internal/auth/logging_handler.go: -------------------------------------------------------------------------------- 1 | // largely adapted from https://github.com/gorilla/handlers/blob/master/handlers.go 2 | // to add logging of request duration as last value (and drop referrer) 3 | 4 | package auth 5 | 6 | import ( 7 | "io" 8 | "net/http" 9 | "net/url" 10 | "strings" 11 | "time" 12 | 13 | log "github.com/buzzfeed/sso/internal/pkg/logging" 14 | "github.com/datadog/datadog-go/statsd" 15 | ) 16 | 17 | // responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status 18 | // code and body size 19 | type responseLogger struct { 20 | w http.ResponseWriter 21 | status int 22 | size int 23 | proxyHost string 24 | authInfo string 25 | } 26 | 27 | func (l *responseLogger) Header() http.Header { 28 | return l.w.Header() 29 | } 30 | 31 | func (l *responseLogger) ExtractGAPMetadata() { 32 | authInfo := l.w.Header().Get("GAP-Auth") 33 | if authInfo != "" { 34 | l.authInfo = authInfo 35 | 36 | l.w.Header().Del("GAP-Auth") 37 | } 38 | } 39 | 40 | func (l *responseLogger) Write(b []byte) (int, error) { 41 | if l.status == 0 { 42 | // The status will be StatusOK if WriteHeader has not been called yet 43 | l.status = http.StatusOK 44 | } 45 | l.ExtractGAPMetadata() 46 | size, err := l.w.Write(b) 47 | l.size += size 48 | return size, err 49 | } 50 | 51 | func (l *responseLogger) WriteHeader(s int) { 52 | l.ExtractGAPMetadata() 53 | l.w.WriteHeader(s) 54 | l.status = s 55 | } 56 | 57 | func (l *responseLogger) Status() int { 58 | return l.status 59 | } 60 | 61 | func (l *responseLogger) 
Size() int { 62 | return l.size 63 | } 64 | 65 | // loggingHandler is the http.Handler implementation for LoggingHandlerTo and its friends 66 | type loggingHandler struct { 67 | writer io.Writer 68 | StatsdClient *statsd.Client 69 | handler http.Handler 70 | enabled bool 71 | } 72 | 73 | // NewLoggingHandler creates a new loggingHandler 74 | func NewLoggingHandler(out io.Writer, h http.Handler, v bool, StatsdClient *statsd.Client) http.Handler { 75 | return loggingHandler{ 76 | writer: out, 77 | handler: h, 78 | enabled: v, 79 | StatsdClient: StatsdClient, 80 | } 81 | } 82 | 83 | // getProxyHost attempts to get the proxy host from the redirect_uri parameter 84 | func getProxyHost(req *http.Request) string { 85 | err := req.ParseForm() 86 | if err != nil { 87 | return "" 88 | } 89 | redirect := req.Form.Get("redirect_uri") 90 | redirectURL, err := url.Parse(redirect) 91 | if err != nil { 92 | return "" 93 | } 94 | return redirectURL.Host 95 | } 96 | 97 | func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { 98 | t := time.Now() 99 | url := *req.URL 100 | logger := &responseLogger{w: w, proxyHost: getProxyHost(req)} 101 | h.handler.ServeHTTP(logger, req) 102 | if !h.enabled { 103 | return 104 | } 105 | requestDuration := time.Now().Sub(t) 106 | logRequest(logger.proxyHost, logger.authInfo, req, url, requestDuration, logger.Status(), h.StatsdClient) 107 | } 108 | 109 | // logRequests creates a log message from the request status, method, url, proxy host and duration of the request 110 | func logRequest(proxyHost, user string, req *http.Request, url url.URL, requestDuration time.Duration, status int, StatsdClient *statsd.Client) { 111 | // Convert duration to floating point milliseconds 112 | // https://github.com/golang/go/issues/5491#issuecomment-66079585 113 | durationMS := requestDuration.Seconds() * 1e3 114 | 115 | logger := log.NewLogEntry() 116 | logger.WithHTTPStatus(status).WithRequestMethod(req.Method).WithRequestURI( 117 | 
url.RequestURI()).WithProxyHost(proxyHost).WithUser(user).WithUserAgent( 118 | req.Header.Get("User-Agent")).WithRemoteAddress( 119 | getRemoteAddr(req)).WithRequestDurationMs( 120 | durationMS).WithAction(GetActionTag(req)).Info() 121 | logRequestMetrics(proxyHost, req, requestDuration, status, StatsdClient) 122 | } 123 | 124 | // getRemoteAddr returns the client IP address from a request. If present, the 125 | // X-Forwarded-For header is assumed to be set by a load balancer, and its 126 | // rightmost entry (the client IP that connected to the LB) is returned. 127 | func getRemoteAddr(req *http.Request) string { 128 | addr := req.RemoteAddr 129 | forwardedHeader := req.Header.Get("X-Forwarded-For") 130 | if forwardedHeader != "" { 131 | forwardedList := strings.Split(forwardedHeader, ",") 132 | forwardedAddr := strings.TrimSpace(forwardedList[len(forwardedList)-1]) 133 | if forwardedAddr != "" { 134 | addr = forwardedAddr 135 | } 136 | } 137 | return addr 138 | } 139 | -------------------------------------------------------------------------------- /internal/auth/logging_handler_test.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "net/http/httptest" 5 | "testing" 6 | ) 7 | 8 | func TestGetRemoteAddr(t *testing.T) { 9 | testCases := []struct { 10 | name string 11 | remoteAddr string 12 | forwardedHeader string 13 | expectedAddr string 14 | }{ 15 | { 16 | name: "RemoteAddr used when no X-Forwarded-For header is given", 17 | remoteAddr: "1.1.1.1", 18 | expectedAddr: "1.1.1.1", 19 | }, 20 | { 21 | name: "RemoteAddr used when no X-Forwarded-For header is only whitespace", 22 | remoteAddr: "1.1.1.1", 23 | forwardedHeader: " ", 24 | expectedAddr: "1.1.1.1", 25 | }, 26 | { 27 | name: "RemoteAddr used when no X-Forwarded-For header is only comma-separated whitespace", 28 | remoteAddr: "1.1.1.1", 29 | forwardedHeader: " , , ", 30 | expectedAddr: "1.1.1.1", 31 | }, 32 | { 33 | name: 
"X-Forwarded-For header is preferred to RemoteAddr", 34 | remoteAddr: "1.1.1.1", 35 | forwardedHeader: "9.9.9.9", 36 | expectedAddr: "9.9.9.9", 37 | }, 38 | { 39 | name: "rightmost entry in X-Forwarded-For header is used", 40 | remoteAddr: "1.1.1.1", 41 | forwardedHeader: "2.2.2.2, 3.3.3.3, 4.4.4.4.4, 5.5.5.5", 42 | expectedAddr: "5.5.5.5", 43 | }, 44 | { 45 | name: "RemoteAddr is used if rightmost entry in X-Forwarded-For header is empty", 46 | remoteAddr: "1.1.1.1", 47 | forwardedHeader: "2.2.2.2, 3.3.3.3, ", 48 | expectedAddr: "1.1.1.1", 49 | }, 50 | { 51 | name: "X-Forwaded-For header entries are stripped", 52 | remoteAddr: "1.1.1.1", 53 | forwardedHeader: " 2.2.2.2, 3.3.3.3, 4.4.4.4, 5.5.5.5 ", 54 | expectedAddr: "5.5.5.5", 55 | }, 56 | } 57 | 58 | for _, tc := range testCases { 59 | t.Run(tc.name, func(t *testing.T) { 60 | req := httptest.NewRequest("GET", "/", nil) 61 | req.RemoteAddr = tc.remoteAddr 62 | if tc.forwardedHeader != "" { 63 | req.Header.Set("X-Forwarded-For", tc.forwardedHeader) 64 | } 65 | 66 | addr := getRemoteAddr(req) 67 | if addr != tc.expectedAddr { 68 | t.Errorf("expected remote addr = %q, got %q", tc.expectedAddr, addr) 69 | } 70 | }) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /internal/auth/metrics.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "net/http" 7 | "strconv" 8 | "strings" 9 | "time" 10 | 11 | "github.com/datadog/datadog-go/statsd" 12 | ) 13 | 14 | func NewStatsdClient(host string, port int) (*statsd.Client, error) { 15 | client, err := statsd.New(net.JoinHostPort(host, strconv.Itoa(port))) 16 | if err != nil { 17 | return nil, err 18 | } 19 | client.Namespace = "sso_auth." 
20 | client.Tags = []string{ 21 | "service:sso_auth", 22 | } 23 | return client, nil 24 | } 25 | 26 | // GetActionTag returns the tag associated with a route 27 | func GetActionTag(req *http.Request) string { 28 | // only log metrics for these paths and actions 29 | pathToAction := map[string]string{ 30 | "/robots.txt": "robots", 31 | "/start": "start", 32 | "/sign_in": "sign_in", 33 | "/sign_out": "sign_out", 34 | "/callback": "callback", 35 | "/profile": "profile", 36 | "/validate": "validate", 37 | "/redeem": "redeem", 38 | "/refresh": "refresh", 39 | "/ping": "ping", 40 | } 41 | // get the action from the url path 42 | path := req.URL.Path 43 | splitPath := strings.Split(path, "/") 44 | pathBase := fmt.Sprintf("/%s", splitPath[len(splitPath)-1]) 45 | 46 | if action, ok := pathToAction[pathBase]; ok { 47 | return action 48 | } 49 | if strings.HasPrefix(path, "/static/") { 50 | return "static" 51 | } 52 | return "unknown" 53 | 54 | } 55 | 56 | // logMetrics logs all metrics surrounding a given request to the metricsWriter 57 | func logRequestMetrics(proxyHost string, req *http.Request, requestDuration time.Duration, status int, StatsdClient *statsd.Client) { 58 | 59 | if proxyHost == "" { 60 | proxyHost = "unknown" 61 | } 62 | 63 | tags := []string{ 64 | fmt.Sprintf("method:%s", req.Method), 65 | fmt.Sprintf("status_code:%d", status), 66 | fmt.Sprintf("status_category:%dxx", status/100), 67 | fmt.Sprintf("proxy_host:%s", proxyHost), 68 | fmt.Sprintf("action:%s", GetActionTag(req)), 69 | } 70 | 71 | // TODO: eventually make rates configurable 72 | StatsdClient.Timing("request.duration", requestDuration, tags, 1.0) 73 | 74 | } 75 | -------------------------------------------------------------------------------- /internal/auth/mux.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | 7 | "github.com/buzzfeed/sso/internal/pkg/hostmux" 8 | log 
"github.com/buzzfeed/sso/internal/pkg/logging" 9 | "github.com/buzzfeed/sso/internal/pkg/validators" 10 | 11 | "github.com/datadog/datadog-go/statsd" 12 | "github.com/gorilla/mux" 13 | ) 14 | 15 | type AuthenticatorMux struct { 16 | handler http.Handler 17 | authenticators []*Authenticator 18 | } 19 | 20 | func NewAuthenticatorMux(config Configuration, statsdClient *statsd.Client) (*AuthenticatorMux, error) { 21 | logger := log.NewLogEntry() 22 | v := []validators.Validator{} 23 | if len(config.AuthorizeConfig.EmailConfig.Addresses) != 0 { 24 | v = append(v, validators.NewEmailAddressValidator(config.AuthorizeConfig.EmailConfig.Addresses)) 25 | } else { 26 | v = append(v, validators.NewEmailDomainValidator(config.AuthorizeConfig.EmailConfig.Domains)) 27 | } 28 | 29 | authenticators := []*Authenticator{} 30 | idpMux := mux.NewRouter() 31 | idpMux.UseEncodedPath() 32 | 33 | for slug, providerConfig := range config.ProviderConfigs { 34 | idp, err := newProvider(providerConfig, config.SessionConfig) 35 | if err != nil { 36 | logger.Error(err, fmt.Sprintf("error creating provider.%s", slug)) 37 | return nil, err 38 | } 39 | 40 | idpSlug := idp.Data().ProviderSlug 41 | authenticator, err := NewAuthenticator(config, 42 | SetValidators(v), 43 | SetProvider(idp), 44 | SetCookieStore(config.SessionConfig, idpSlug), 45 | SetStatsdClient(statsdClient), 46 | SetRedirectURL(config.ServerConfig, idpSlug), 47 | ) 48 | if err != nil { 49 | logger.Error(err, "error creating new Authenticator") 50 | return nil, err 51 | } 52 | 53 | authenticators = append(authenticators, authenticator) 54 | 55 | // setup our mux with the idpslug as the first part of the path 56 | providerPath := fmt.Sprintf("/%s", idpSlug) 57 | idpMux.PathPrefix(providerPath).Handler(http.StripPrefix(providerPath, authenticator.ServeMux)) 58 | } 59 | 60 | // load static files 61 | fsHandler, err := loadFSHandler() 62 | if err != nil { 63 | logger.Fatal(err) 64 | } 65 | 
idpMux.PathPrefix("/static/").Handler(http.StripPrefix("/static/", fsHandler)) 66 | idpMux.HandleFunc("/robots.txt", RobotsTxt) 67 | 68 | hostRouter := hostmux.NewRouter() 69 | hostRouter.HandleStatic(config.ServerConfig.Host, idpMux) 70 | 71 | healthcheckHandler := setHealthCheck("/ping", hostRouter) 72 | 73 | return &AuthenticatorMux{ 74 | handler: healthcheckHandler, 75 | authenticators: authenticators, 76 | }, nil 77 | } 78 | 79 | func (a *AuthenticatorMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { 80 | a.handler.ServeHTTP(w, r) 81 | } 82 | 83 | func (a *AuthenticatorMux) Stop() { 84 | for _, authenticator := range a.authenticators { 85 | authenticator.Stop() 86 | } 87 | } 88 | 89 | func setHealthCheck(healthcheckPath string, next http.Handler) http.Handler { 90 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 91 | if r.URL.Path == healthcheckPath { 92 | w.WriteHeader(http.StatusOK) 93 | return 94 | } 95 | next.ServeHTTP(w, r) 96 | }) 97 | } 98 | 99 | // RobotsTxt handles the /robots.txt route 100 | func RobotsTxt(rw http.ResponseWriter, req *http.Request) { 101 | rw.WriteHeader(http.StatusOK) 102 | fmt.Fprintf(rw, "User-agent: *\nDisallow: /") 103 | } 104 | -------------------------------------------------------------------------------- /internal/auth/mux_test.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | ) 9 | 10 | func TestHostHeader(t *testing.T) { 11 | testCases := []struct { 12 | Name string 13 | Host string 14 | RequestHost string 15 | Path string 16 | ExpectedStatusCode int 17 | }{ 18 | { 19 | Name: "reject requests with an invalid hostname", 20 | Host: "example.com", 21 | RequestHost: "unknown.com", 22 | Path: "/static/sso.css", 23 | ExpectedStatusCode: http.StatusMisdirectedRequest, 24 | }, 25 | { 26 | Name: "allow requests to any hostname to /ping", 27 | Host: "example.com", 28 | 
RequestHost: "unknown.com", 29 | Path: "/ping", 30 | ExpectedStatusCode: http.StatusOK, 31 | }, 32 | { 33 | Name: "allow requests to specified hostname to /ping", 34 | Host: "example.com", 35 | RequestHost: "example.com", 36 | Path: "/ping", 37 | ExpectedStatusCode: http.StatusOK, 38 | }, 39 | { 40 | Name: "allow requests with a valid hostname", 41 | Host: "example.com", 42 | RequestHost: "example.com", 43 | Path: "/static/sso.css", 44 | ExpectedStatusCode: http.StatusOK, 45 | }, 46 | } 47 | for _, tc := range testCases { 48 | t.Run(tc.Name, func(t *testing.T) { 49 | config := testConfiguration(t) 50 | config.ServerConfig.Host = tc.Host 51 | authMux, err := NewAuthenticatorMux(config, nil) 52 | if err != nil { 53 | t.Fatalf("unexpected err creating auth mux: %v", err) 54 | } 55 | 56 | uri := fmt.Sprintf("http://%s%s", tc.RequestHost, tc.Path) 57 | 58 | rw := httptest.NewRecorder() 59 | req := httptest.NewRequest("GET", uri, nil) 60 | 61 | authMux.ServeHTTP(rw, req) 62 | if rw.Code != tc.ExpectedStatusCode { 63 | t.Errorf("got unexpected status code") 64 | t.Errorf("want %v", tc.ExpectedStatusCode) 65 | t.Errorf(" got %v", rw.Code) 66 | t.Errorf(" headers %v", rw) 67 | t.Errorf(" body: %q", rw.Body) 68 | } 69 | }) 70 | } 71 | } 72 | 73 | func TestRobotsTxt(t *testing.T) { 74 | config := testConfiguration(t) 75 | authMux, err := NewAuthenticatorMux(config, nil) 76 | if err != nil { 77 | t.Fatalf("unexpected err creating auth mux: %v", err) 78 | } 79 | 80 | rw := httptest.NewRecorder() 81 | req := httptest.NewRequest("GET", fmt.Sprintf("https://%s/robots.txt", config.ServerConfig.Host), nil) 82 | authMux.ServeHTTP(rw, req) 83 | 84 | if rw.Code != http.StatusOK { 85 | t.Errorf("expected status code %d, but got %d", http.StatusOK, rw.Code) 86 | } 87 | if rw.Body.String() != "User-agent: *\nDisallow: /" { 88 | t.Errorf("expected response body to be %s but was %s", "User-agent: *\nDisallow: /", rw.Body.String()) 89 | } 90 | } 91 | 
-------------------------------------------------------------------------------- /internal/auth/options_test.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/buzzfeed/sso/internal/pkg/testutil" 8 | ) 9 | 10 | func TestGoogleProviderApiSettings(t *testing.T) { 11 | provider, err := newProvider( 12 | ProviderConfig{ProviderType: "google"}, 13 | SessionConfig{SessionLifetimeTTL: 1 * time.Hour}, 14 | ) 15 | if err != nil { 16 | t.Fatalf("unexpected err generating google provider: %v", err) 17 | } 18 | p := provider.Data() 19 | testutil.Equal(t, "https://accounts.google.com/o/oauth2/v2/auth", 20 | p.SignInURL.String()) 21 | testutil.Equal(t, "https://www.googleapis.com/oauth2/v4/token", 22 | p.RedeemURL.String()) 23 | 24 | testutil.Equal(t, "", p.ProfileURL.String()) 25 | testutil.Equal(t, "profile email", p.Scope) 26 | } 27 | 28 | func TestGoogleGroupInvalidFile(t *testing.T) { 29 | _, err := newProvider( 30 | ProviderConfig{ 31 | ProviderType: "google", 32 | GoogleProviderConfig: GoogleProviderConfig{ 33 | Credentials: "file_doesnt_exist.json", 34 | }, 35 | }, 36 | SessionConfig{ 37 | SessionLifetimeTTL: 1 * time.Hour, 38 | }, 39 | ) 40 | testutil.NotEqual(t, nil, err) 41 | testutil.Equal(t, "could not read google credentials file", err.Error()) 42 | } 43 | 44 | func TestAmazonCognitoProviderSettings(t *testing.T) { 45 | testCases := []struct { 46 | Name string 47 | ProviderConfig ProviderConfig 48 | }{ 49 | { 50 | Name: "with input parameters, default scope", 51 | ProviderConfig: ProviderConfig{ 52 | ProviderType: "cognito", 53 | Scope: "", 54 | AmazonCognitoProviderConfig: AmazonCognitoProviderConfig{ 55 | OrgURL: "test.cognito.com", 56 | UserPoolID: "12345", 57 | Region: "test-cognito-region", 58 | Credentials: CognitoCredentials{ 59 | ID: "test-credentials-id", 60 | Secret: "test-credentials-secret", 61 | }, 62 | }, 63 | }, 64 | }, 65 | { 66 | Name: 
"with input parameters, override scope", 67 | ProviderConfig: ProviderConfig{ 68 | ProviderType: "cognito", 69 | Scope: "openid email", 70 | AmazonCognitoProviderConfig: AmazonCognitoProviderConfig{ 71 | OrgURL: "test.cognito.com", 72 | UserPoolID: "12345", 73 | Region: "test-cognito-region", 74 | Credentials: CognitoCredentials{ 75 | ID: "test-credentials-id", 76 | Secret: "test-credentials-secret", 77 | }, 78 | }, 79 | }, 80 | }, 81 | } 82 | 83 | for _, tc := range testCases { 84 | t.Run(tc.Name, func(t *testing.T) { 85 | provider, err := newProvider( 86 | tc.ProviderConfig, 87 | SessionConfig{SessionLifetimeTTL: 1 * time.Hour}, 88 | ) 89 | if err != nil { 90 | t.Fatalf("unexpected err generating cognito provider: %q", err) 91 | } 92 | 93 | p := provider.Data() 94 | testutil.Equal(t, "https://test.cognito.com/oauth2/authorize", p.SignInURL.String()) 95 | testutil.Equal(t, "https://test.cognito.com/oauth2/token", p.RedeemURL.String()) 96 | testutil.Equal(t, "https://test.cognito.com/oauth2/userInfo", p.ProfileURL.String()) 97 | testutil.Equal(t, "https://test.cognito.com/oauth2/userInfo", p.ValidateURL.String()) 98 | if tc.ProviderConfig.Scope != "" { 99 | testutil.Equal(t, tc.ProviderConfig.Scope, p.Scope) 100 | } else { 101 | testutil.Equal(t, "openid profile email aws.cognito.signin.user.admin", p.Scope) 102 | } 103 | }) 104 | } 105 | } 106 | 107 | func TestOktaProviderSettings(t *testing.T) { 108 | testCases := []struct { 109 | Name string 110 | ProviderConfig ProviderConfig 111 | }{ 112 | { 113 | Name: "with input parameters, default scope", 114 | ProviderConfig: ProviderConfig{ 115 | ProviderType: "okta", 116 | Scope: "", 117 | OktaProviderConfig: OktaProviderConfig{ 118 | OrgURL: "test.okta.com", 119 | ServerID: "12345", 120 | }, 121 | }, 122 | }, 123 | { 124 | Name: "with input parameters, override scope", 125 | ProviderConfig: ProviderConfig{ 126 | ProviderType: "okta", 127 | Scope: "openid email", 128 | OktaProviderConfig: OktaProviderConfig{ 129 | 
OrgURL: "test.okta.com", 130 | ServerID: "12345", 131 | }, 132 | }, 133 | }, 134 | } 135 | 136 | for _, tc := range testCases { 137 | t.Run(tc.Name, func(t *testing.T) { 138 | provider, err := newProvider( 139 | tc.ProviderConfig, 140 | SessionConfig{SessionLifetimeTTL: 1 * time.Hour}, 141 | ) 142 | if err != nil { 143 | t.Fatalf("unexpected err generating cognito provider: %q", err) 144 | } 145 | 146 | p := provider.Data() 147 | testutil.Equal(t, "https://test.okta.com/oauth2/12345/v1/authorize", p.SignInURL.String()) 148 | testutil.Equal(t, "https://test.okta.com/oauth2/12345/v1/token", p.RedeemURL.String()) 149 | testutil.Equal(t, "https://test.okta.com/oauth2/12345/v1/revoke", p.RevokeURL.String()) 150 | testutil.Equal(t, "https://test.okta.com/oauth2/12345/v1/userinfo", p.ProfileURL.String()) 151 | testutil.Equal(t, "https://test.okta.com/oauth2/12345/v1/introspect", p.ValidateURL.String()) 152 | if tc.ProviderConfig.Scope != "" { 153 | testutil.Equal(t, tc.ProviderConfig.Scope, p.Scope) 154 | } else { 155 | testutil.Equal(t, "openid profile email groups offline_access", p.Scope) 156 | } 157 | }) 158 | } 159 | } 160 | -------------------------------------------------------------------------------- /internal/auth/providers/amazon_cognito_mock_admin.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "github.com/buzzfeed/sso/internal/pkg/sessions" 5 | ) 6 | 7 | // MockAdminService is an implementation of the Amazon Cognito AdminService to be used for testing 8 | type MockCognitoAdminService struct { 9 | Members []string 10 | Groups []string 11 | MembersError error 12 | GroupsError error 13 | UserName string 14 | UserInfoError error 15 | GlobalSignOutError error 16 | } 17 | 18 | // ListMemberships mocks the ListMemebership function 19 | func (ms *MockCognitoAdminService) ListMemberships(string) ([]string, error) { 20 | return ms.Members, ms.MembersError 21 | } 22 | 23 | // 
CheckMemberships mocks the CheckMemberships function 24 | func (ms *MockCognitoAdminService) CheckMemberships(string) ([]string, error) { 25 | return ms.Groups, ms.GroupsError 26 | } 27 | 28 | // GlobalSignOut mocks the GlobalSignOut function 29 | func (ms *MockCognitoAdminService) GlobalSignOut(*sessions.SessionState) error { 30 | return ms.GlobalSignOutError 31 | } 32 | -------------------------------------------------------------------------------- /internal/auth/providers/google_mock_admin.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | // MockAdminService is an implementation of AdminService to be used for testing 4 | type MockAdminService struct { 5 | Members []string 6 | Groups []string 7 | MembersError error 8 | GroupsError error 9 | } 10 | 11 | // ListMemberships mocks the ListMemebership function 12 | func (ms *MockAdminService) ListMemberships(string, int) ([]string, error) { 13 | return ms.Members, ms.MembersError 14 | } 15 | 16 | // CheckMemberships mocks the CheckMemberships function 17 | func (ms *MockAdminService) CheckMemberships([]string, string) ([]string, error) { 18 | return ms.Groups, ms.GroupsError 19 | } 20 | -------------------------------------------------------------------------------- /internal/auth/providers/group_cache.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "sort" 5 | "strings" 6 | "time" 7 | 8 | "github.com/buzzfeed/sso/internal/pkg/groups" 9 | "github.com/buzzfeed/sso/internal/pkg/sessions" 10 | "github.com/datadog/datadog-go/statsd" 11 | ) 12 | 13 | var ( 14 | // This is a compile-time check to make sure our types correctly implement the interface: 15 | // https://medium.com/@matryer/golang-tip-compile-time-checks-to-ensure-your-type-satisfies-an-interface-c167afed3aae 16 | _ Provider = &GroupCache{} 17 | ) 18 | 19 | type Cache interface { 20 | Get(key groups.CacheKey) 
(groups.CacheEntry, bool) 21 | Set(key groups.CacheKey, val groups.CacheEntry) 22 | Purge(key groups.CacheKey) 23 | } 24 | 25 | // GroupCache is designed to act as a provider while wrapping subsequent provider's functions, 26 | // while also offering a caching mechanism (specifically used for group caching at the moment). 27 | type GroupCache struct { 28 | statsdClient *statsd.Client 29 | provider Provider 30 | cache Cache 31 | } 32 | 33 | // NewGroupCache returns a new GroupCache (which includes a LocalCache from the groups package) 34 | func NewGroupCache(provider Provider, ttl time.Duration, statsdClient *statsd.Client, tags []string) *GroupCache { 35 | return &GroupCache{ 36 | statsdClient: statsdClient, 37 | provider: provider, 38 | cache: groups.NewLocalCache(ttl, statsdClient, tags), 39 | } 40 | } 41 | 42 | // SetStatsdClient calls the provider's SetStatsdClient function. 43 | func (p *GroupCache) SetStatsdClient(statsdClient *statsd.Client) { 44 | p.statsdClient = statsdClient 45 | p.provider.SetStatsdClient(statsdClient) 46 | } 47 | 48 | // Data returns the provider Data 49 | func (p *GroupCache) Data() *ProviderData { 50 | return p.provider.Data() 51 | } 52 | 53 | // Redeem wraps the provider's Redeem function 54 | func (p *GroupCache) Redeem(redirectURL, code string) (*sessions.SessionState, error) { 55 | return p.provider.Redeem(redirectURL, code) 56 | } 57 | 58 | // ValidateSessionState wraps the provider's ValidateSessionState function. 59 | func (p *GroupCache) ValidateSessionState(s *sessions.SessionState) bool { 60 | return p.provider.ValidateSessionState(s) 61 | } 62 | 63 | // GetSignInURL wraps the provider's GetSignInURL function. 64 | func (p *GroupCache) GetSignInURL(redirectURI, finalRedirect string) string { 65 | return p.provider.GetSignInURL(redirectURI, finalRedirect) 66 | } 67 | 68 | // RefreshSessionIfNeeded wraps the provider's RefreshSessionIfNeeded function. 
69 | func (p *GroupCache) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { 70 | return p.provider.RefreshSessionIfNeeded(s) 71 | } 72 | 73 | // ValidateGroupMembership wraps the provider's ValidateGroupMembership around calls to check local cache for group membership information. 74 | func (p *GroupCache) ValidateGroupMembership(email string, allowedGroups []string, accessToken string) ([]string, error) { 75 | // Create a cache key and check to see if it's in the cache. If not, call the provider's 76 | // ValidateGroupMembership function and cache the result. 77 | sort.Strings(allowedGroups) 78 | key := groups.CacheKey{ 79 | Email: email, 80 | AllowedGroups: strings.Join(allowedGroups, ","), 81 | } 82 | 83 | val, ok := p.cache.Get(key) 84 | if ok { 85 | p.statsdClient.Incr("provider.groupcache", 86 | []string{ 87 | "action:ValidateGroupMembership", 88 | "cache:hit", 89 | }, 1.0) 90 | return val.ValidGroups, nil 91 | } 92 | 93 | // The key isn't in the cache, so pass the call on to the subsequent provider 94 | p.statsdClient.Incr("provider.groupcache", 95 | []string{ 96 | "action:ValidateGroupMembership", 97 | "cache:miss", 98 | }, 1.0) 99 | 100 | validGroups, err := p.provider.ValidateGroupMembership(email, allowedGroups, accessToken) 101 | if err != nil { 102 | return nil, err 103 | } 104 | 105 | entry := groups.CacheEntry{ 106 | ValidGroups: validGroups, 107 | } 108 | p.cache.Set(key, entry) 109 | return validGroups, nil 110 | } 111 | 112 | // Revoke wraps the provider's Revoke function. 113 | func (p *GroupCache) Revoke(s *sessions.SessionState) error { 114 | return p.provider.Revoke(s) 115 | } 116 | 117 | // RefreshAccessToken wraps the provider's RefreshAccessToken function. 118 | func (p *GroupCache) RefreshAccessToken(refreshToken string) (string, time.Duration, error) { 119 | return p.provider.RefreshAccessToken(refreshToken) 120 | } 121 | 122 | // Stop calls the providers stop function. 
123 | func (p *GroupCache) Stop() { 124 | p.provider.Stop() 125 | } 126 | -------------------------------------------------------------------------------- /internal/auth/providers/http_client.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "net" 5 | "net/http" 6 | "time" 7 | ) 8 | 9 | var httpClient = &http.Client{ 10 | Timeout: time.Second * 5, 11 | Transport: &http.Transport{ 12 | Proxy: http.ProxyFromEnvironment, 13 | Dial: (&net.Dialer{ 14 | Timeout: 2 * time.Second, 15 | }).Dial, 16 | TLSHandshakeTimeout: 2 * time.Second, 17 | }, 18 | } 19 | -------------------------------------------------------------------------------- /internal/auth/providers/internal_util.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "net/url" 5 | 6 | log "github.com/buzzfeed/sso/internal/pkg/logging" 7 | ) 8 | 9 | // stripToken is a helper function used to obfuscate "access_token" query parameters 10 | func stripToken(endpoint string) string { 11 | return stripParam("access_token", endpoint) 12 | } 13 | 14 | // stripParam generalizes the obfuscation of a particular 15 | // query parameter - typically 'access_token' or 'client_secret' 16 | // The parameter's second half is replaced by '...' and returned 17 | // as part of the encoded query parameters. 18 | // If the target parameter isn't found, the endpoint is returned unmodified. 
19 | func stripParam(param, endpoint string) string { 20 | logger := log.NewLogEntry() 21 | 22 | u, err := url.Parse(endpoint) 23 | if err != nil { 24 | logger.WithURLParam(param).Error( 25 | err, "error parsing endpoint while stripping param") 26 | return endpoint 27 | } 28 | 29 | if u.RawQuery != "" { 30 | values, err := url.ParseQuery(u.RawQuery) 31 | if err != nil { 32 | logger.WithURLParam(param).Error( 33 | err, "error parsing query string while stripping param") 34 | return u.String() 35 | } 36 | 37 | if val := values.Get(param); val != "" { 38 | values.Set(param, val[:(len(val)/2)]+"...") 39 | u.RawQuery = values.Encode() 40 | return u.String() 41 | } 42 | } 43 | 44 | return endpoint 45 | } 46 | -------------------------------------------------------------------------------- /internal/auth/providers/internal_util_test.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/buzzfeed/sso/internal/pkg/testutil" 7 | ) 8 | 9 | func TestStripTokenNotPresent(t *testing.T) { 10 | test := "http://local.test/api/test?a=1&b=2" 11 | testutil.Equal(t, test, stripToken(test)) 12 | } 13 | 14 | func TestStripToken(t *testing.T) { 15 | test := "http://local.test/api/test?access_token=deadbeef&b=1&c=2" 16 | expected := "http://local.test/api/test?access_token=dead...&b=1&c=2" 17 | testutil.Equal(t, expected, stripToken(test)) 18 | } 19 | -------------------------------------------------------------------------------- /internal/auth/providers/provider_data.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "net/url" 5 | "time" 6 | ) 7 | 8 | // ProviderData holds the fields associated with providers 9 | // necessary to implement the Provider interface. 
10 | type ProviderData struct { 11 | ProviderName string 12 | ProviderSlug string 13 | 14 | ClientID string 15 | ClientSecret string 16 | 17 | SignInURL *url.URL 18 | RedeemURL *url.URL 19 | RevokeURL *url.URL 20 | ValidateURL *url.URL 21 | ProfileURL *url.URL 22 | 23 | Scope string 24 | 25 | SessionLifetimeTTL time.Duration 26 | } 27 | 28 | // Data returns a ProviderData. 29 | func (p *ProviderData) Data() *ProviderData { return p } 30 | -------------------------------------------------------------------------------- /internal/auth/providers/providers.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "errors" 5 | "time" 6 | 7 | "github.com/buzzfeed/sso/internal/pkg/sessions" 8 | "github.com/datadog/datadog-go/statsd" 9 | ) 10 | 11 | var ( 12 | // ErrBadRequest represents 400 Bad Request errors 13 | ErrBadRequest = errors.New("BAD_REQUEST") 14 | 15 | // ErrTokenRevoked represents 400 Token Revoked errors 16 | ErrTokenRevoked = errors.New("TOKEN_REVOKED") 17 | 18 | // ErrRateLimitExceeded represents 429 Rate Limit Exceeded errors 19 | ErrRateLimitExceeded = errors.New("RATE_LIMIT_EXCEEDED") 20 | 21 | // ErrNotImplemented represents 501 Not Implemented errors 22 | ErrNotImplemented = errors.New("NOT_IMPLEMENTED") 23 | 24 | // ErrServiceUnavailable represents 503 Service Unavailable errors 25 | ErrServiceUnavailable = errors.New("SERVICE_UNAVAILABLE") 26 | ) 27 | 28 | const ( 29 | // GoogleProviderName identifies the Google provider 30 | GoogleProviderName = "google" 31 | // OktaProviderName identities the Okta provider 32 | OktaProviderName = "okta" 33 | // AmazonCognitoProviderName identities the Okta provider 34 | AmazonCognitoProviderName = "cognito" 35 | ) 36 | 37 | // Provider is an interface exposing functions necessary to authenticate with a given provider. 
38 | type Provider interface { 39 | SetStatsdClient(*statsd.Client) 40 | Data() *ProviderData 41 | Redeem(string, string) (*sessions.SessionState, error) 42 | ValidateSessionState(*sessions.SessionState) bool 43 | GetSignInURL(redirectURI, finalRedirect string) string 44 | RefreshSessionIfNeeded(*sessions.SessionState) (bool, error) 45 | ValidateGroupMembership(string, []string, string) ([]string, error) 46 | Revoke(*sessions.SessionState) error 47 | RefreshAccessToken(string) (string, time.Duration, error) 48 | Stop() 49 | } 50 | -------------------------------------------------------------------------------- /internal/auth/providers/test_provider.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "net/url" 5 | "time" 6 | 7 | "github.com/buzzfeed/sso/internal/pkg/sessions" 8 | "github.com/datadog/datadog-go/statsd" 9 | ) 10 | 11 | // TestProvider is a test implementation of the Provider interface. 12 | type TestProvider struct { 13 | *ProviderData 14 | 15 | ValidToken bool 16 | ValidGroup bool 17 | SignInURL string 18 | Refresh bool 19 | RefreshFunc func(string) (string, time.Duration, error) 20 | RefreshError error 21 | Session *sessions.SessionState 22 | RedeemError error 23 | RevokeError error 24 | Groups []string 25 | GroupsError error 26 | GroupsCall int 27 | } 28 | 29 | // NewTestProvider creates a new mock test provider. 30 | func NewTestProvider(_ *url.URL) *TestProvider { 31 | return &TestProvider{ 32 | ProviderData: &ProviderData{ 33 | ProviderName: "Test Provider", 34 | ProviderSlug: "test", 35 | Scope: "profile.email", 36 | }, 37 | } 38 | } 39 | 40 | // SetStatsdClient fulfills the Provider interface 41 | func (tp *TestProvider) SetStatsdClient(*statsd.Client) { 42 | return 43 | } 44 | 45 | // ValidateSessionState returns the mock provider's ValidToken field value. 
46 | func (tp *TestProvider) ValidateSessionState(*sessions.SessionState) bool { 47 | return tp.ValidToken 48 | } 49 | 50 | // GetSignInURL returns the mock provider's SignInURL field value. 51 | func (tp *TestProvider) GetSignInURL(redirectURI, finalRedirect string) string { 52 | return tp.SignInURL 53 | } 54 | 55 | // RefreshSessionIfNeeded returns the mock provider's Refresh value, or an error. 56 | func (tp *TestProvider) RefreshSessionIfNeeded(*sessions.SessionState) (bool, error) { 57 | return tp.Refresh, tp.RefreshError 58 | } 59 | 60 | // RefreshAccessToken returns the mock provider's refresh access token information 61 | func (tp *TestProvider) RefreshAccessToken(s string) (string, time.Duration, error) { 62 | return tp.RefreshFunc(s) 63 | } 64 | 65 | // Revoke returns nil 66 | func (tp *TestProvider) Revoke(*sessions.SessionState) error { 67 | return tp.RevokeError 68 | } 69 | 70 | // ValidateGroupMembership returns the mock provider's GroupsError if not nil, or the Groups field value. 71 | func (tp *TestProvider) ValidateGroupMembership(string, []string, string) ([]string, error) { 72 | return tp.Groups, tp.GroupsError 73 | } 74 | 75 | // Redeem returns the mock provider's Session and RedeemError field value. 
76 | func (tp *TestProvider) Redeem(redirectURI, code string) (*sessions.SessionState, error) { 77 | return tp.Session, tp.RedeemError 78 | 79 | } 80 | 81 | // Stop fulfills the Provider interface 82 | func (tp *TestProvider) Stop() { 83 | return 84 | } 85 | -------------------------------------------------------------------------------- /internal/auth/static/sso.css: -------------------------------------------------------------------------------- 1 | * { 2 | margin: 0; 3 | padding: 0; 4 | } 5 | body { 6 | font-family: "Helvetica Neue",Helvetica,Arial,sans-serif; 7 | font-size: 1em; 8 | line-height: 1.42857143; 9 | color: #333; 10 | background: #f0f0f0; 11 | } 12 | 13 | p { 14 | margin: 1.5em 0; 15 | } 16 | p:first-child { 17 | margin-top: 0; 18 | } 19 | p:last-child { 20 | margin-bottom: 0; 21 | } 22 | 23 | .container { 24 | max-width: 40em; 25 | display: block; 26 | margin: 10% auto; 27 | text-align: center; 28 | } 29 | 30 | .content, .message, button { 31 | border: 1px solid rgba(0,0,0,.125); 32 | border-bottom-width: 4px; 33 | border-radius: 4px; 34 | } 35 | 36 | .content, .message { 37 | background-color: #fff; 38 | padding: 2rem; 39 | margin: 1rem 0; 40 | } 41 | .error, .message { 42 | border-bottom-color: #c00; 43 | } 44 | .message { 45 | padding: 1.5rem 2rem 1.3rem; 46 | } 47 | 48 | header { 49 | border-bottom: 1px solid rgba(0,0,0,.075); 50 | margin: -2rem 0 2rem; 51 | padding: 2rem 0 1.8rem; 52 | } 53 | header h1 { 54 | font-size: 1.5em; 55 | font-weight: normal; 56 | } 57 | .error header { 58 | color: #c00; 59 | } 60 | .details { 61 | font-size: .85rem; 62 | color: #999; 63 | } 64 | 65 | button { 66 | color: #fff; 67 | background-color: #3B8686; 68 | cursor: pointer; 69 | font-size: 1.5rem; 70 | font-weight: bold; 71 | padding: 1rem 2.5rem; 72 | text-shadow: 0 3px 1px rgba(0,0,0,.2); 73 | outline: none; 74 | } 75 | button:active { 76 | border-top-width: 4px; 77 | border-bottom-width: 1px; 78 | text-shadow: none; 79 | } 80 | 81 | footer { 82 | font-size: 
0.75em; 83 | color: #999; 84 | text-align: right; 85 | margin: 1rem; 86 | } 87 | -------------------------------------------------------------------------------- /internal/auth/static_files.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "net/http" 5 | "os" 6 | 7 | "github.com/rakyll/statik/fs" 8 | 9 | // Statik makes assets available via a blank import 10 | _ "github.com/buzzfeed/sso/internal/auth/statik" 11 | ) 12 | 13 | // noDirectoryFilesystem is used to prevent an http.FileServer from providing directory listings 14 | type noDirectoryFS struct { 15 | fs http.FileSystem 16 | } 17 | 18 | func (fs noDirectoryFS) Open(name string) (http.File, error) { 19 | f, err := fs.fs.Open(name) 20 | 21 | if err != nil { 22 | return nil, err 23 | } 24 | 25 | stat, err := f.Stat() 26 | if err != nil { 27 | return nil, err 28 | } 29 | 30 | // prevent directory listings 31 | if stat.IsDir() { 32 | return nil, os.ErrNotExist 33 | } 34 | 35 | return f, nil 36 | } 37 | 38 | //go:generate $GOPATH/bin/statik -f -src=./static 39 | 40 | func loadFSHandler() (http.Handler, error) { 41 | statikFS, err := fs.New() 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | return http.FileServer(noDirectoryFS{statikFS}), nil 47 | } 48 | -------------------------------------------------------------------------------- /internal/auth/static_files_test.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "strings" 7 | "testing" 8 | ) 9 | 10 | func TestStaticFiles(t *testing.T) { 11 | config := testConfiguration(t) 12 | authMux, err := NewAuthenticatorMux(config, nil) 13 | if err != nil { 14 | t.Fatalf("unexpected error creating auth mux: %v", err) 15 | } 16 | 17 | testCases := []struct { 18 | name string 19 | uri string 20 | expectedStatus int 21 | expectedContent string 22 | }{ 23 | { 24 | name: "static css ok", 
25 | uri: "http://localhost/static/sso.css", 26 | expectedStatus: http.StatusOK, 27 | expectedContent: "body {", 28 | }, 29 | { 30 | name: "nonexistent file not found", 31 | uri: "http://localhost/static/missing.css", 32 | expectedStatus: http.StatusNotFound, 33 | }, 34 | { 35 | name: "no directory listing", 36 | uri: "http://localhost/static/", 37 | expectedStatus: http.StatusNotFound, 38 | }, 39 | { 40 | // this will result in a 301 -> /config.yml 41 | name: "no directory escape", 42 | uri: "http://localhost/static/../config.yml", 43 | expectedStatus: http.StatusMovedPermanently, 44 | }, 45 | { 46 | // this should NOT result in a 301, since escaped-paths are propagated as-is. 47 | name: "no directory escape", 48 | uri: "http://localhost/static/https:%2F%2Fexample.com", 49 | expectedStatus: http.StatusNotFound, 50 | }, 51 | } 52 | 53 | for _, tc := range testCases { 54 | t.Run(tc.name, func(t *testing.T) { 55 | rw := httptest.NewRecorder() 56 | req := httptest.NewRequest("GET", tc.uri, nil) 57 | 58 | authMux.ServeHTTP(rw, req) 59 | if rw.Code != tc.expectedStatus { 60 | t.Errorf("expected response %v, got %v\n%v", tc.expectedStatus, rw.Code, rw.HeaderMap) 61 | } 62 | 63 | if tc.expectedContent != "" && !strings.Contains(rw.Body.String(), tc.expectedContent) { 64 | t.Errorf("substring %q not found in response body:\n%s", tc.expectedContent, rw.Body.String()) 65 | } 66 | }) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /internal/auth/statik/statik.go: -------------------------------------------------------------------------------- 1 | // Code generated by statik. DO NOT EDIT. 2 | 3 | // Package statik contains static assets. 
4 | package statik 5 | 6 | import ( 7 | "github.com/rakyll/statik/fs" 8 | ) 9 | 10 | func init() { 11 | data := "PK\x03\x04\x14\x00\x08\x00\x08\x00$y\xfaN\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x07\x00 \x00sso.cssUT\x05\x00\x01\x95\x17;]tT\xc1n\x9c0\x10\xbd\xf3\x15\xa3D\x95\xdajA\xb0d\x93\x8d\xf7\xd4\x9ez\xea?\x18<\x80\x15\xe3A\xb6Iv[\xe5\xdf+\x1b\xc3\xc2j+N\xb6\xc7o\xde{\xf3\xccw\xf8\x9b\x00\xf4\xdc\xb4R3\xc8O \xc0\xc0\x85\x90\xba\x0d\xab\xcf\xa4\"q 5\x0di\x976\xbc\x97\xea\xc2\xe0\xe1\x17\xaawt\xb2\xe6\xf0\x1bG|\xd8-\xeb\xdd\x0f#\xb9\xdaY\xaemj\xd1\xc8\xe64\xdf\xb5\xf2\x0f2(\xb0\xf7;JjL;\x94m\xe7\x18\x14\xd9\xd3\xfexx)\x9eJ\x7fT\x93\"\xc3\xe0\xb1,\xc3\xb2\xe2\xf5[kh\xd4\x82\xc1c\x93\xfb\xcf\xf3J\x86\x0d\xf3\";`?1\x1eX#\x8dui\xddI%VE\xa9\xa3\x81\xcd%\x8a\xdf\xab\xa8\xc89\xeacQ\x92\xd5\xa4\x1d\x97\x1aM\xac9\xa7\x1fR\xb8\x8e\xc1S>\xa9\x10\xd2\x0e\x8a_\x18T\x8a\xea\xb7\xd3\x9aO\xfe\x05\xf8\xe8\xc8\xef9<\xbb\x94+\xd9j\x065j\x87\xe6\n\x8f\xda\xed \xeb\xd1Z\xde\xe2\x0e\xaa\xd19\xd2\xa1]EF\xa0aP\x0cg\xb0\xa4\xa4\x00\xd3V\xfck\xbe\xf3_V\xec\x0f\xdfNKU$\xbe\xd0\x1b\xce\xab3\xc3\x85\x1cm\xdc\xbd\xdbw\xea\xb7\xf8\x9c\xce\x13h\x9af\x13\x88\xbd\x99d/\"\xcd\xecy\x86\xc6\x90\xb9A\xbce7\xc3\xd6\xf9tg[\xbbt)\xb2\x83\xc7\xf5\xcd\xa0\xc8\xca\xd0\xf33I:\xe4\"Nb\x03\xfb?\x87\xf2\x97\xc9\xa1\x99l\x1a\x00\xf3E\xc4F\x15\xe4Pd\xc7\xd8*v\xea\x8ak\xeecv}\xc8\x96<\x7f\xc4\xf4j2=WW\x17`\xc5\xf4\x9a\xe6Y\xb4@\xc7\xa5\xb2\xf1t\x05\x9e\x1d\x0f\x91\xd9\xf5\xd2\xeb\xebk\xd0\xbe\x8a\xc5\xcdl\xeeL\xad\xfcy|>>\x87\x974\x1a\xeb\xb7\x06\x92S\xean\xd5\x98;r*Rb\xe3O\x98\xf2~)\x0ei\xb6\x1d\x17\xf4\xc1 \x87r8\x87\x01\xac\xad\xdf\x07\xe3it\xfe\x95{\x834\x86_IP\xc1x\xed\xe4;\xae\x07\xe9h\xb8\x1f\xddm\xac\x8b\xe9lC`\xc6N\x1a\"\x17M_i\xcc\xb3\x978\xb2\x8d\xa5\xdb7i\xbc\xf0\xdb\\{\xd0\x7f\x01\x00\x00\xff\xffPK\x07\x08\x0b~@~\x1a\x02\x00\x00\x1d\x05\x00\x00PK\x01\x02\x14\x03\x14\x00\x08\x00\x08\x00$y\xfaN\x0b~@~\x1a\x02\x00\x00\x1d\x05\x00\x00\x07\x00 
\x00\x00\x00\x00\x00\x00\x00\x00\x00\xa4\x81\x00\x00\x00\x00sso.cssUT\x05\x00\x01\x95\x17;]PK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x00>\x00\x00\x00X\x02\x00\x00\x00\x00" 12 | fs.Register(data) 13 | } 14 | -------------------------------------------------------------------------------- /internal/auth/version.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | // VERSION is the version of sso_auth 4 | const VERSION = "2.2.1-alpha" 5 | -------------------------------------------------------------------------------- /internal/pkg/aead/aead.go: -------------------------------------------------------------------------------- 1 | package aead 2 | 3 | import ( 4 | "bytes" 5 | "compress/gzip" 6 | "crypto/cipher" 7 | "encoding/base64" 8 | "encoding/json" 9 | "fmt" 10 | "io" 11 | "sync" 12 | 13 | miscreant "github.com/miscreant/miscreant.go" 14 | ) 15 | 16 | const miscreantNonceSize = 16 17 | 18 | var algorithmType = "AES-CMAC-SIV" 19 | 20 | // Cipher provides methods to encrypt and decrypt byte slices, and to marshal values to and unmarshal them from encrypted strings. 21 | type Cipher interface { 22 | Encrypt([]byte) ([]byte, error) 23 | Decrypt([]byte) ([]byte, error) 24 | Marshal(interface{}) (string, error) 25 | Unmarshal(string, interface{}) error 26 | } 27 | 28 | // MiscreantCipher provides methods to encrypt and decrypt values. 29 | // It wraps an AEAD, a cipher providing authenticated encryption with associated data.
30 | // For a description of the methodology, see https://en.wikipedia.org/wiki/Authenticated_encryption 31 | type MiscreantCipher struct { 32 | aead cipher.AEAD 33 | 34 | mux sync.Mutex 35 | } 36 | 37 | // NewMiscreantCipher returns a new AES Cipher for encrypting values 38 | func NewMiscreantCipher(secret []byte) (*MiscreantCipher, error) { 39 | aead, err := miscreant.NewAEAD(algorithmType, secret, miscreantNonceSize) 40 | if err != nil { 41 | return nil, err 42 | } 43 | return &MiscreantCipher{ 44 | aead: aead, 45 | }, nil 46 | } 47 | 48 | // GenerateKey wraps miscreant's GenerateKey function 49 | func GenerateKey() []byte { 50 | return miscreant.GenerateKey(32) 51 | } 52 | 53 | // Encrypt a value using AES-CMAC-SIV 54 | func (c *MiscreantCipher) Encrypt(plaintext []byte) (joined []byte, err error) { 55 | c.mux.Lock() 56 | defer c.mux.Unlock() 57 | 58 | defer func() { 59 | if r := recover(); r != nil { 60 | err = fmt.Errorf("miscreant error encrypting bytes: %v", r) 61 | } 62 | }() 63 | nonce := miscreant.GenerateNonce(c.aead) 64 | ciphertext := c.aead.Seal(nil, nonce, plaintext, nil) 65 | 66 | // we return the nonce as part of the returned value 67 | joined = append(ciphertext[:], nonce[:]...) 
68 | return joined, nil 69 | } 70 | 71 | // Decrypt a value using AES-CMAC-SIV 72 | func (c *MiscreantCipher) Decrypt(joined []byte) ([]byte, error) { 73 | c.mux.Lock() 74 | defer c.mux.Unlock() 75 | 76 | if len(joined) <= miscreantNonceSize { 77 | return nil, fmt.Errorf("invalid input size: %d", len(joined)) 78 | } 79 | // grab out the nonce 80 | pivot := len(joined) - miscreantNonceSize 81 | ciphertext := joined[:pivot] 82 | nonce := joined[pivot:] 83 | 84 | plaintext, err := c.aead.Open(nil, nonce, ciphertext, nil) 85 | if err != nil { 86 | return nil, err 87 | } 88 | 89 | return plaintext, nil 90 | } 91 | 92 | // Marshal marshals the interface state as JSON, encrypts the JSON using the cipher 93 | // and base64 encodes the binary value as a string and returns the result 94 | func (c *MiscreantCipher) Marshal(s interface{}) (string, error) { 95 | // encode json value 96 | plaintext, err := json.Marshal(s) 97 | if err != nil { 98 | return "", err 99 | } 100 | 101 | // gunzip the bytes 102 | var jsonBuffer bytes.Buffer 103 | w := gzip.NewWriter(&jsonBuffer) 104 | w.Write(plaintext) 105 | w.Close() 106 | 107 | // encrypt the JSON 108 | ciphertext, err := c.Encrypt(jsonBuffer.Bytes()) 109 | if err != nil { 110 | return "", err 111 | } 112 | 113 | // base64-encode the result 114 | encoded := base64.RawURLEncoding.EncodeToString(ciphertext) 115 | return encoded, nil 116 | } 117 | 118 | // Unmarshal takes the marshaled string, base64-decodes into a byte slice, decrypts the 119 | // byte slice the passed cipher, and unmarshals the resulting JSON into the struct pointer passed 120 | func (c *MiscreantCipher) Unmarshal(value string, s interface{}) error { 121 | // convert base64 string value to bytes 122 | ciphertext, err := base64.RawURLEncoding.DecodeString(value) 123 | if err != nil { 124 | return err 125 | } 126 | 127 | // decrypt the bytes 128 | plaintext, err := c.Decrypt(ciphertext) 129 | if err != nil { 130 | return err 131 | } 132 | 133 | // gzip the bytes 134 | 
var jsonBuffer bytes.Buffer 135 | r, err := gzip.NewReader(bytes.NewBuffer(plaintext)) 136 | if err != nil { 137 | return err 138 | } 139 | io.Copy(&jsonBuffer, r) 140 | 141 | // unmarshal bytes 142 | err = json.Unmarshal(jsonBuffer.Bytes(), s) 143 | if err != nil { 144 | return err 145 | } 146 | 147 | return nil 148 | } 149 | -------------------------------------------------------------------------------- /internal/pkg/aead/aead_test.go: -------------------------------------------------------------------------------- 1 | package aead 2 | 3 | import ( 4 | "crypto/rand" 5 | "crypto/sha1" 6 | "fmt" 7 | "reflect" 8 | "sync" 9 | "testing" 10 | ) 11 | 12 | func TestEncodeAndDecodeAccessToken(t *testing.T) { 13 | plaintext := []byte("my plain text value") 14 | 15 | key := GenerateKey() 16 | c, err := NewMiscreantCipher([]byte(key)) 17 | if err != nil { 18 | t.Fatalf("unexpected err: %v", err) 19 | } 20 | 21 | ciphertext, err := c.Encrypt(plaintext) 22 | if err != nil { 23 | t.Fatalf("unexpected err: %v", err) 24 | } 25 | 26 | if reflect.DeepEqual(plaintext, ciphertext) { 27 | t.Fatalf("plaintext is not encrypted plaintext:%v ciphertext:%x", plaintext, ciphertext) 28 | } 29 | 30 | got, err := c.Decrypt(ciphertext) 31 | if err != nil { 32 | t.Fatalf("unexpected err decrypting: %v", err) 33 | } 34 | 35 | if !reflect.DeepEqual(got, plaintext) { 36 | t.Logf(" got: %v", got) 37 | t.Logf("want: %v", plaintext) 38 | t.Fatal("got unexpected decrypted value") 39 | } 40 | } 41 | 42 | func TestMarshalAndUnmarshalStruct(t *testing.T) { 43 | key := GenerateKey() 44 | 45 | c, err := NewMiscreantCipher([]byte(key)) 46 | if err != nil { 47 | t.Fatalf("unexpected err: %v", err) 48 | } 49 | 50 | type TC struct { 51 | Field string `json:"field"` 52 | } 53 | 54 | tc := &TC{ 55 | Field: "my plain text value", 56 | } 57 | 58 | value1, err := c.Marshal(tc) 59 | if err != nil { 60 | t.Fatalf("unexpected err: %v", err) 61 | } 62 | 63 | value2, err := c.Marshal(tc) 64 | if err != nil { 65 | 
t.Fatalf("unexpected err: %v", err) 66 | } 67 | 68 | if value1 == value2 { 69 | t.Fatalf("expected marshaled values to not be equal %v != %v", value1, value2) 70 | } 71 | 72 | got1 := &TC{} 73 | err = c.Unmarshal(value1, got1) 74 | if err != nil { 75 | t.Fatalf("unexpected err unmarshalling struct: %v", err) 76 | } 77 | 78 | if !reflect.DeepEqual(got1, tc) { 79 | t.Logf("want: %#v", tc) 80 | t.Logf(" got: %#v", got1) 81 | t.Fatalf("expected structs to be equal") 82 | } 83 | 84 | got2 := &TC{} 85 | err = c.Unmarshal(value2, got2) 86 | if err != nil { 87 | t.Fatalf("unexpected err unmarshalling struct: %v", err) 88 | } 89 | 90 | if !reflect.DeepEqual(got1, got2) { 91 | t.Logf("got2: %#v", got2) 92 | t.Logf("got1: %#v", got1) 93 | t.Fatalf("expected structs to be equal") 94 | } 95 | } 96 | 97 | // TestCipherDataRace exercises a simple concurrency test for the MicscreantCipher. 98 | // In https://github.com/buzzfeed/sso/pull/75 we investigated why, on random occasion, 99 | // unmarshalling session states would fail, triggering users to get kicked out of 100 | // authenticated states. We narrowed our investigation to a data race we uncovered 101 | // from our misuse of the underlying miscreant library which makes no attempt 102 | // at thread-safety. 103 | // 104 | // In https://github.com/buzzfeed/sso/pull/77 we added this test to exercise the 105 | // data race condition and resolved said race by introducing a simple mutex. 
106 | func TestCipherDataRace(t *testing.T) { 107 | miscreantCipher, err := NewMiscreantCipher(GenerateKey()) 108 | if err != nil { 109 | t.Fatalf("unexpected generating cipher err: %v", err) 110 | } 111 | 112 | type TC struct { 113 | Field string `json:"field"` 114 | } 115 | 116 | // Create a channel to collect errors from goroutines 117 | errCh := make(chan error, 100) 118 | 119 | wg := &sync.WaitGroup{} 120 | for i := 0; i < 100; i++ { 121 | wg.Add(1) 122 | go func(c *MiscreantCipher, wg *sync.WaitGroup, errCh chan<- error) { 123 | defer wg.Done() 124 | b := make([]byte, 32) 125 | _, err := rand.Read(b) 126 | if err != nil { 127 | errCh <- fmt.Errorf("unexpected error reading random bytes: %v", err) 128 | return 129 | } 130 | 131 | sha := fmt.Sprintf("%x", sha1.New().Sum(b)) 132 | tc := &TC{ 133 | Field: sha, 134 | } 135 | 136 | value1, err := c.Marshal(tc) 137 | if err != nil { 138 | errCh <- fmt.Errorf("unexpected err: %v", err) 139 | return 140 | } 141 | 142 | value2, err := c.Marshal(tc) 143 | if err != nil { 144 | errCh <- fmt.Errorf("unexpected err: %v", err) 145 | return 146 | } 147 | 148 | if value1 == value2 { 149 | errCh <- fmt.Errorf("expected marshaled values to not be equal %v != %v", value1, value2) 150 | return 151 | } 152 | 153 | got1 := &TC{} 154 | err = c.Unmarshal(value1, got1) 155 | if err != nil { 156 | errCh <- fmt.Errorf("unexpected err unmarshalling struct: %v", err) 157 | return 158 | } 159 | 160 | if !reflect.DeepEqual(got1, tc) { 161 | errCh <- fmt.Errorf("expected structs to be equal: want %#v, got %#v", tc, got1) 162 | return 163 | } 164 | 165 | got2 := &TC{} 166 | err = c.Unmarshal(value2, got2) 167 | if err != nil { 168 | errCh <- fmt.Errorf("unexpected err unmarshalling struct: %v", err) 169 | return 170 | } 171 | 172 | if !reflect.DeepEqual(got1, got2) { 173 | errCh <- fmt.Errorf("expected structs to be equal: got1 %#v, got2 %#v", got1, got2) 174 | return 175 | } 176 | 177 | }(miscreantCipher, wg, errCh) 178 | } 179 | 180 | go 
func() { 181 | wg.Wait() 182 | close(errCh) // Close the channel after all goroutines have finished 183 | }() 184 | 185 | // Collect and handle errors from the channel 186 | for err := range errCh { 187 | t.Errorf("Test failed: %v", err) 188 | } 189 | } 190 | -------------------------------------------------------------------------------- /internal/pkg/aead/mock_cipher.go: -------------------------------------------------------------------------------- 1 | package aead 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | // MockCipher is a mock of the cipher interface 8 | type MockCipher struct { 9 | MarshalError error 10 | MarshalString string 11 | UnmarshalError error 12 | UnmarshalBytes []byte 13 | } 14 | 15 | // Encrypt returns an empty byte array and nil 16 | func (mc *MockCipher) Encrypt([]byte) ([]byte, error) { 17 | return []byte{}, nil 18 | } 19 | 20 | // Decrypt returns an empty byte array and nil 21 | func (mc *MockCipher) Decrypt([]byte) ([]byte, error) { 22 | return []byte{}, nil 23 | } 24 | 25 | // Marshal returns the marshal string and marsha error 26 | func (mc *MockCipher) Marshal(interface{}) (string, error) { 27 | return mc.MarshalString, mc.MarshalError 28 | } 29 | 30 | // Unmarshal unmarshals the unmarshal bytes to be set in s and returns the unmarshal error 31 | func (mc *MockCipher) Unmarshal(b string, s interface{}) error { 32 | json.Unmarshal(mc.UnmarshalBytes, s) 33 | return mc.UnmarshalError 34 | } 35 | -------------------------------------------------------------------------------- /internal/pkg/groups/fillcache.go: -------------------------------------------------------------------------------- 1 | package groups 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "math/rand" 7 | "sync" 8 | "time" 9 | 10 | log "github.com/buzzfeed/sso/internal/pkg/logging" 11 | "github.com/datadog/datadog-go/statsd" 12 | ) 13 | 14 | const ( 15 | defaultMaxJitter = 50 * time.Millisecond 16 | ) 17 | 18 | var ErrGroupNotFound = errors.New("Group not found while 
running fill func") 19 | 20 | // MemberSetCache represents a cache of members of a set 21 | type MemberSetCache interface { 22 | // Get returns a MemberSet from the cache 23 | Get(string) (MemberSet, bool) 24 | // Update updates the MemberSet of a given key and returns a boolean value indicating whether the value was updated or not. 25 | Update(string) bool 26 | // RefreshLoop starts an update refresh loop for a given key and returns a boolean value indicating whether a refresh loop has been started or not 27 | RefreshLoop(string) bool 28 | // Stop is a function to stop all goroutines that may have been spun up for the cache. 29 | Stop() 30 | } 31 | 32 | // FillFunc is a function that computes the value of a cache entry given a string key 33 | type FillFunc func(string) (MemberSet, error) 34 | 35 | // MemberSet represents a set of all the members for a given group key in the cache 36 | type MemberSet map[string]struct{} 37 | 38 | // FillCache is a cache whose entries are calculated and filled on-demand 39 | type FillCache struct { 40 | fillFunc FillFunc 41 | cache map[string]MemberSet 42 | inflight map[string]struct{} 43 | maxJitter time.Duration 44 | 45 | refreshTTL time.Duration 46 | refreshLoopGroups map[string]struct{} 47 | 48 | StatsdClient *statsd.Client 49 | 50 | mu sync.RWMutex 51 | stopCh chan struct{} 52 | } 53 | 54 | // NewFillCache creates a `FillCache` whose entries will be computed by the given `FillFunc` 55 | func NewFillCache(fillFunc FillFunc, refreshTTL time.Duration) *FillCache { 56 | return &FillCache{ 57 | fillFunc: fillFunc, 58 | cache: make(map[string]MemberSet), 59 | inflight: make(map[string]struct{}), 60 | refreshLoopGroups: make(map[string]struct{}), 61 | refreshTTL: refreshTTL, 62 | stopCh: make(chan struct{}), 63 | } 64 | } 65 | 66 | // Get returns the cache value for the given key, computing it as necessary 67 | func (c *FillCache) Get(group string) (MemberSet, bool) { 68 | c.mu.RLock() 69 | defer c.mu.RUnlock() 70 | val, found := 
c.cache[group] 71 | return val, found 72 | } 73 | 74 | // Update recomputes the value for the given key, unless another goroutine is 75 | // already computing the value, and returns a bool indicating whether the value 76 | // was updated 77 | func (c *FillCache) Update(group string) bool { 78 | logger := log.NewLogEntry() 79 | logger.WithUserGroup(group).Info("updating fill cache") 80 | 81 | c.mu.Lock() 82 | if _, waiting := c.inflight[group]; waiting { 83 | c.mu.Unlock() 84 | return false 85 | } 86 | 87 | c.inflight[group] = struct{}{} 88 | c.mu.Unlock() 89 | val, err := c.fillFunc(group) 90 | 91 | c.mu.Lock() 92 | defer c.mu.Unlock() 93 | delete(c.inflight, group) 94 | 95 | if err == nil { 96 | c.cache[group] = val 97 | return true 98 | } 99 | 100 | if err == ErrGroupNotFound { 101 | delete(c.cache, group) 102 | } 103 | 104 | c.StatsdClient.Incr("groups_cache.error", 105 | []string{ 106 | fmt.Sprintf("group:%s", group), 107 | fmt.Sprintf("error:%s", err), 108 | }, 1.0) 109 | logger.WithUserGroup(group).Error(err, "error updating fill cache") 110 | return false 111 | } 112 | 113 | // RefreshLoop runs in a separate goroutine for the key in the cache and 114 | // updates the cache value for that key every refreshTTL 115 | func (c *FillCache) RefreshLoop(group string) bool { 116 | maxJitter := defaultMaxJitter 117 | // Add jitter before starting each refresh loop, up to refreshPeriod by 118 | // default (tests control maxJitter to ensure they run deterministically) 119 | if c.maxJitter != 0 { 120 | maxJitter = c.maxJitter 121 | } 122 | time.Sleep(time.Duration(rand.Float64() * float64(maxJitter))) 123 | 124 | c.mu.Lock() 125 | 126 | // check the inflight and the cache so that we don't start if they're already populated 127 | if _, waiting := c.refreshLoopGroups[group]; waiting { 128 | c.mu.Unlock() 129 | return false 130 | } 131 | 132 | c.refreshLoopGroups[group] = struct{}{} 133 | c.mu.Unlock() 134 | 135 | ticker := time.NewTicker(c.refreshTTL) 136 | go func() { 137 
| logger := log.NewLogEntry() 138 | // cleanup if this goroutine exits 139 | defer func() { 140 | c.mu.Lock() 141 | delete(c.refreshLoopGroups, group) 142 | c.mu.Unlock() 143 | }() 144 | 145 | // we update the cache once before looping to ensure the cache is filled immediately, 146 | // instead of only after c.refreshTTL 147 | updated := c.Update(group) 148 | if !updated { 149 | logger.WithUserGroup(group).Info("cache was not updated") 150 | } 151 | 152 | for { 153 | select { 154 | case <-c.stopCh: 155 | return 156 | case <-ticker.C: 157 | updated = c.Update(group) 158 | if !updated { 159 | logger.WithUserGroup(group).Info("cache was not updated") 160 | } 161 | } 162 | } 163 | }() 164 | return true 165 | } 166 | 167 | // Stop halts all goroutines and waits for them to exit before returning 168 | func (c *FillCache) Stop() { 169 | close(c.stopCh) 170 | } 171 | -------------------------------------------------------------------------------- /internal/pkg/groups/fillcache_test.go: -------------------------------------------------------------------------------- 1 | package groups 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func testFillFunc(members MemberSet, fillError error) func(string) (MemberSet, error) { 10 | return func(string) (MemberSet, error) { 11 | return members, fillError 12 | } 13 | } 14 | 15 | func TestFillCacheUpdate(t *testing.T) { 16 | testCases := []struct { 17 | name string 18 | members MemberSet 19 | fillError error 20 | updated bool 21 | }{ 22 | { 23 | name: "successful update to empty cache", 24 | members: MemberSet{"a": {}, "b": {}, "c": {}}, 25 | updated: true, 26 | }, 27 | { 28 | name: "unsuccessful update to cache", 29 | fillError: fmt.Errorf("fill error"), 30 | updated: false, 31 | }, 32 | { 33 | name: "group removed if it can't be found", 34 | members: MemberSet{"a": {}, "b": {}}, 35 | fillError: ErrGroupNotFound, 36 | updated: false, 37 | }, 38 | } 39 | for _, tc := range testCases { 40 | t.Run(tc.name, func(t 
*testing.T) { 41 | fillCache := NewFillCache(testFillFunc(tc.members, tc.fillError), time.Hour) 42 | defer fillCache.Stop() 43 | cacheKey := "groupKeyA" 44 | 45 | // ("group removed if it can't be found") 46 | // In order to test a group is removed from the cache if it can't be found, we 47 | // fill the cache, check the key is present, then update the cache before checking 48 | // the key has been removed 49 | if tc.fillError == ErrGroupNotFound { 50 | temporaryCacheKey := "groupKeyB" 51 | // add two cache keys, one of which should be deleted, one should not 52 | fillCache.cache[cacheKey] = tc.members 53 | fillCache.cache[temporaryCacheKey] = tc.members 54 | 55 | if _, ok := fillCache.Get(temporaryCacheKey); !ok { 56 | t.Errorf("cache should contain %q, but it does not", temporaryCacheKey) 57 | } 58 | 59 | // this should error with `ErrGroupNotFound`, causing the key to be removed 60 | if ok := fillCache.Update(temporaryCacheKey); ok != tc.updated { 61 | t.Errorf("expected updated to be %v but was %v", tc.updated, ok) 62 | } 63 | 64 | if val, ok := fillCache.Get(temporaryCacheKey); ok { 65 | t.Errorf("expected group to not be present in cache, but found %q", val) 66 | } 67 | } 68 | 69 | ok := fillCache.Update(cacheKey) 70 | if ok != tc.updated { 71 | t.Errorf("expected updated to be %v but was %v", tc.updated, ok) 72 | } 73 | 74 | // as well as checking the key is actually in the cache when it should be, this also tests that 75 | // unrelated groups aren't deleted within the "group removed if it can't be found" test 76 | if tc.updated == true || tc.fillError == ErrGroupNotFound { 77 | _, ok = fillCache.Get(cacheKey) 78 | if ok != tc.updated { 79 | t.Errorf("cache should contain %q, but it does not", cacheKey) 80 | } 81 | } 82 | 83 | }) 84 | } 85 | } 86 | 87 | func TestRefreshLoop(t *testing.T) { 88 | testCases := []struct { 89 | name string 90 | refreshLoopGroups map[string]struct{} 91 | expectedStarted bool 92 | }{ 93 | { 94 | name: "empty inflight and empty 
cache starts a refresh loop", 95 | expectedStarted: true, 96 | }, 97 | { 98 | name: "inflight update does not start refresh loop", 99 | refreshLoopGroups: map[string]struct{}{"group1": {}}, 100 | expectedStarted: false, 101 | }, 102 | } 103 | for _, tc := range testCases { 104 | t.Run(tc.name, func(t *testing.T) { 105 | fillCache := NewFillCache(testFillFunc(MemberSet{"a": {}}, nil), time.Hour) 106 | if tc.refreshLoopGroups != nil { 107 | fillCache.refreshLoopGroups = tc.refreshLoopGroups 108 | } 109 | 110 | started := fillCache.RefreshLoop("group1") 111 | 112 | if tc.expectedStarted != started { 113 | t.Errorf("expected started to be %v but was %v", tc.expectedStarted, started) 114 | } 115 | 116 | if tc.expectedStarted { 117 | // wait briefly to allow the cache to be updated 118 | time.Sleep(50 * time.Millisecond) 119 | _, ok := fillCache.Get("group1") 120 | if !ok { 121 | t.Errorf("expected the group cache to be updated immediately") 122 | } 123 | } 124 | 125 | }) 126 | } 127 | 128 | } 129 | -------------------------------------------------------------------------------- /internal/pkg/groups/localcache.go: -------------------------------------------------------------------------------- 1 | package groups 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/datadog/datadog-go/statsd" 7 | "golang.org/x/sync/syncmap" 8 | ) 9 | 10 | // NewLocalCache returns a LocalCache instance 11 | func NewLocalCache( 12 | ttl time.Duration, 13 | statsdClient *statsd.Client, 14 | tags []string, 15 | ) *LocalCache { 16 | return &LocalCache{ 17 | ttl: ttl, 18 | localCacheData: &syncmap.Map{}, 19 | metrics: statsdClient, 20 | tags: tags, 21 | } 22 | } 23 | 24 | type LocalCache struct { 25 | // Cache configuration 26 | ttl time.Duration 27 | metrics *statsd.Client 28 | tags []string 29 | 30 | // Cache data 31 | localCacheData *syncmap.Map 32 | } 33 | 34 | // Cachekey defines the key used to store the data in the cache. 
35 | type CacheKey struct { 36 | Email string 37 | AllowedGroups string 38 | } 39 | 40 | // CacheEntry defines the data we want to store in the cache. 41 | type CacheEntry struct { 42 | ValidGroups []string 43 | } 44 | 45 | // get will attempt to retrieve an entry from the cache at the given key 46 | func (lc *LocalCache) get(key CacheKey) (CacheEntry, bool) { 47 | val, ok := lc.localCacheData.Load(key) 48 | if ok { 49 | return val.(CacheEntry), ok 50 | } 51 | 52 | return CacheEntry{}, false 53 | } 54 | 55 | // set will attempt to set an entry in the cache to a given key 56 | // for the prescribed TTL 57 | func (lc *LocalCache) set(key CacheKey, data CacheEntry) { 58 | lc.localCacheData.Store(key, data) 59 | 60 | // Spawn the TTL cleanup goroutine if a TTL is set 61 | if lc.ttl > 0 { 62 | go func(key CacheKey) { 63 | <-time.After(lc.ttl) 64 | lc.Purge(key) 65 | }(key) 66 | } 67 | } 68 | 69 | // Get retrieves a key from a local cache. If found, it will create and return an 70 | // 'Entry' using the returned values. 
If not found, it will return an empty 'Entry' 71 | func (lc *LocalCache) Get(key CacheKey) (CacheEntry, bool) { 72 | val, ok := lc.get(key) 73 | if ok { 74 | lc.metrics.Incr("localcache.hit", lc.tags, 1.0) 75 | return val, ok 76 | } 77 | 78 | lc.metrics.Incr("localcache.miss", lc.tags, 1.0) 79 | return CacheEntry{}, false 80 | } 81 | 82 | // Set will set an entry within the current cache 83 | func (lc *LocalCache) Set(key CacheKey, entry CacheEntry) { 84 | lc.metrics.Incr("localcache.set", lc.tags, 1.0) 85 | lc.set(key, entry) 86 | } 87 | 88 | // Purge will remove a set of keys from the local cache map 89 | func (lc *LocalCache) Purge(key CacheKey) { 90 | lc.metrics.Incr("localcache.purge", lc.tags, 1.0) 91 | lc.localCacheData.Delete(key) 92 | } 93 | -------------------------------------------------------------------------------- /internal/pkg/groups/localcache_test.go: -------------------------------------------------------------------------------- 1 | package groups 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | "time" 7 | 8 | //"github.com/cactus/go-statsd-client/statsd/statsdtest" 9 | "github.com/datadog/datadog-go/statsd" 10 | //"github.com/cactus/go-statsd-client/statsd" 11 | ) 12 | 13 | func TestNotAvailableAfterTTL(t *testing.T) { 14 | // Create a cache with a 10 millisecond TTL 15 | statsdClient, _ := statsd.New("127.0.0.1:8125") 16 | cache := NewLocalCache(time.Millisecond*10, statsdClient, []string{"test_case"}) 17 | 18 | // Create a cache Key and Entry and insert it into the cache 19 | cacheKey := CacheKey{ 20 | Email: "email@test.com", 21 | AllowedGroups: "testGroup", 22 | } 23 | cacheData := CacheEntry{ 24 | ValidGroups: []string{"testGroup"}, 25 | } 26 | cache.Set(cacheKey, cacheData) 27 | 28 | // Check the cached entry can be retrieved from the cache. 
29 | if data, _ := cache.Get(cacheKey); !reflect.DeepEqual(data, cacheData) { 30 | t.Logf(" expected data to be '%+v'", cacheData) 31 | t.Logf("actual data returned was '%+v'", data) 32 | t.Fatalf("unexpected data returned") 33 | } 34 | 35 | // If we wait 10ms (or lets say, 50 for good luck), it will have been removed 36 | time.Sleep(time.Millisecond * 50) 37 | 38 | if _, found := cache.get(cacheKey); found { 39 | t.Fatalf("expected key not to be have been found after the TTL expired") 40 | } 41 | } 42 | 43 | func TestNotAvailableAfterPurge(t *testing.T) { 44 | statsdClient, _ := statsd.New("127.0.0.1:8125") 45 | cache := NewLocalCache(time.Duration(10)*time.Second, statsdClient, []string{"test_case"}) 46 | 47 | // Create a cache Key and Entry and insert it into the cache 48 | cacheKey := CacheKey{ 49 | Email: "email@test.com", 50 | AllowedGroups: "testGroup", 51 | } 52 | cacheData := CacheEntry{ 53 | ValidGroups: []string{"testGroup"}, 54 | } 55 | cache.Set(cacheKey, cacheData) 56 | 57 | // Check the cached entry can be retrieved from the cache. 
58 | if data, _ := cache.Get(cacheKey); !reflect.DeepEqual(data, cacheData) { 59 | t.Logf(" expected data to be '%+v'", cacheData) 60 | t.Logf("actual data returned was '%+v'", data) 61 | t.Fatalf("unexpected data returned") 62 | } 63 | 64 | cache.Purge(cacheKey) 65 | 66 | // Purge should have removed the entry, despite being within the cache TTL 67 | if _, found := cache.get(cacheKey); found { 68 | t.Fatalf("expected key not to be have been found after purging") 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /internal/pkg/groups/mock_cache.go: -------------------------------------------------------------------------------- 1 | package groups 2 | 3 | // MockCache is a mock of MemberSetCache that can be used for testing purposes 4 | type MockCache struct { 5 | ListMembershipsFunc func(string) (MemberSet, bool) 6 | Exists bool 7 | Updated bool 8 | Refreshed bool 9 | } 10 | 11 | // Get returns the members and if the set exists 12 | func (mc *MockCache) Get(group string) (MemberSet, bool) { 13 | return mc.ListMembershipsFunc(group) 14 | } 15 | 16 | // Update updates the cache 17 | func (mc *MockCache) Update(string) bool { 18 | return mc.Updated 19 | } 20 | 21 | // RefreshLoop returns a boolean of if the refresh loop is refreshed 22 | func (mc *MockCache) RefreshLoop(string) bool { 23 | return mc.Refreshed 24 | } 25 | 26 | // Stop fufills the MemberSetCache interface 27 | func (mc *MockCache) Stop() { 28 | return 29 | } 30 | -------------------------------------------------------------------------------- /internal/pkg/hostmux/hostmux.go: -------------------------------------------------------------------------------- 1 | package hostmux 2 | 3 | import ( 4 | "net/http" 5 | "regexp" 6 | "sync" 7 | ) 8 | 9 | var ( 10 | // This is a compile-time check to make sure our types correctly implement the interface: 11 | // 
https://medium.com/@matryer/golang-tip-compile-time-checks-to-ensure-your-type-satisfies-an-interface-c167afed3aae 12 | _ Route = &StaticRoute{} 13 | _ Route = &RegexpRoute{} 14 | _ Route = &DefaultRoute{} 15 | ) 16 | 17 | // misdirected replies to the request with an HTTP 421 misdirected error. 18 | func misdirected(rw http.ResponseWriter, req *http.Request) { 19 | http.Error(rw, http.StatusText(http.StatusMisdirectedRequest), http.StatusMisdirectedRequest) 20 | } 21 | 22 | // Router is a generic host router that has support for static, regexp, and default routes 23 | type Router struct { 24 | mu sync.Mutex 25 | 26 | // StaticRoutes have the highest precedence, and match keys to http.Request.Host values exactly. 27 | // Added to a router by HandleStatic(). 28 | StaticRoutes map[string]*StaticRoute 29 | 30 | // RegexpRoutes match http.Request.Host based on regular expressions, if no exact match is found. 31 | // Added to a router by HandleRegexp(). 32 | RegexpRoutes []*RegexpRoute 33 | 34 | // DefaultRoute matches any remaining requests. 
35 | //Added by HandleDefault() 36 | DefaultRoute *DefaultRoute 37 | } 38 | 39 | type Route interface { 40 | Handler() http.Handler 41 | } 42 | 43 | // StaticRoute matches hosts staticly based on a simple string match 44 | type StaticRoute struct { 45 | host string 46 | handler http.Handler 47 | } 48 | 49 | // Handler returns a reference to the underlying handler 50 | func (sr *StaticRoute) Handler() http.Handler { 51 | return sr.handler 52 | } 53 | 54 | // RegexpRoute matches hosts based on a regexp 55 | type RegexpRoute struct { 56 | regexp *regexp.Regexp 57 | handler http.Handler 58 | } 59 | 60 | // Handler returns a reference to the underlying handler 61 | func (rr *RegexpRoute) Handler() http.Handler { 62 | return rr.handler 63 | } 64 | 65 | // DefaultRoute doesn't match a host but serves a default if none others match 66 | type DefaultRoute struct { 67 | handler http.Handler 68 | } 69 | 70 | // Handler returns a reference to the underlying handler 71 | func (dr *DefaultRoute) Handler() http.Handler { 72 | return dr.handler 73 | } 74 | 75 | // NewRouter is convenience constructor for Router 76 | func NewRouter() *Router { 77 | return &Router{ 78 | StaticRoutes: make(map[string]*StaticRoute), 79 | RegexpRoutes: make([]*RegexpRoute, 0), 80 | DefaultRoute: &DefaultRoute{ 81 | handler: http.HandlerFunc(misdirected), 82 | }, 83 | } 84 | } 85 | 86 | // ServeHTTP matches host to either static or regexp routes. 
If there is no match, 87 | // it serves the DefaultRoute 88 | func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) { 89 | r.Route(req).Handler().ServeHTTP(rw, req) 90 | } 91 | 92 | // Route returns the route used for the request, consulting first static, regexp, and 93 | // then the default route 94 | func (r *Router) Route(req *http.Request) Route { 95 | r.mu.Lock() 96 | defer r.mu.Unlock() 97 | 98 | sr, ok := r.StaticRoutes[req.Host] 99 | if ok { 100 | return sr 101 | } 102 | 103 | for _, rr := range r.RegexpRoutes { 104 | if rr.regexp.MatchString(req.Host) { 105 | return rr 106 | } 107 | } 108 | 109 | // Nothing matched, use default route 110 | return r.DefaultRoute 111 | } 112 | 113 | // HandleStatic registers the handler func for the given host string 114 | func (r *Router) HandleStatic(host string, handler http.Handler) { 115 | r.mu.Lock() 116 | r.StaticRoutes[host] = &StaticRoute{ 117 | host: host, 118 | handler: handler, 119 | } 120 | r.mu.Unlock() 121 | } 122 | 123 | // HandleRegexp registers the handler func for the given regexp 124 | func (r *Router) HandleRegexp(regexp *regexp.Regexp, handler http.Handler) { 125 | r.mu.Lock() 126 | r.RegexpRoutes = append(r.RegexpRoutes, &RegexpRoute{ 127 | regexp: regexp, 128 | handler: handler, 129 | }) 130 | r.mu.Unlock() 131 | } 132 | 133 | // HandleDefault registers the default handler if there are no matches 134 | func (r *Router) HandleDefault(handler http.Handler) { 135 | r.mu.Lock() 136 | r.DefaultRoute = &DefaultRoute{ 137 | handler: handler, 138 | } 139 | r.mu.Unlock() 140 | } 141 | -------------------------------------------------------------------------------- /internal/pkg/hostmux/hostmux_test.go: -------------------------------------------------------------------------------- 1 | package hostmux 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httptest" 7 | "regexp" 8 | "testing" 9 | ) 10 | 11 | func simpleHandler(code int, body string) http.Handler { 12 | return http.HandlerFunc( 
13 | func(rw http.ResponseWriter, req *http.Request) { 14 | rw.Header().Set("Content-Type", "text/plain; charset=utf-8") 15 | rw.Header().Set("X-Content-Type-Options", "nosniff") 16 | rw.WriteHeader(code) 17 | fmt.Fprintln(rw, body) 18 | }, 19 | ) 20 | } 21 | 22 | func TestHostMux(t *testing.T) { 23 | type expectedResponse struct { 24 | code int 25 | } 26 | testCases := []struct { 27 | name string 28 | url string 29 | statics []*StaticRoute 30 | regexps []*RegexpRoute 31 | expectedResponse expectedResponse 32 | }{ 33 | { 34 | name: "default constructor returns misdirected request", 35 | url: "http://example.com", 36 | expectedResponse: expectedResponse{ 37 | code: http.StatusMisdirectedRequest, 38 | }, 39 | }, 40 | { 41 | name: "static handler returns code expected", 42 | url: "http://example.com", 43 | statics: []*StaticRoute{ 44 | { 45 | host: "example.com", 46 | handler: simpleHandler(http.StatusAccepted, http.StatusText(http.StatusAccepted)), 47 | }, 48 | }, 49 | expectedResponse: expectedResponse{ 50 | code: http.StatusAccepted, 51 | }, 52 | }, 53 | { 54 | name: "regexp returns code expected", 55 | url: "http://example--match.com", 56 | statics: []*StaticRoute{ 57 | { 58 | host: "example.com", 59 | handler: simpleHandler(http.StatusAccepted, http.StatusText(http.StatusAccepted)), 60 | }, 61 | }, 62 | regexps: []*RegexpRoute{ 63 | { 64 | regexp: regexp.MustCompile(`example--.*\.com`), 65 | handler: simpleHandler(http.StatusCreated, http.StatusText(http.StatusCreated)), 66 | }, 67 | }, 68 | expectedResponse: expectedResponse{ 69 | code: http.StatusCreated, 70 | }, 71 | }, 72 | } 73 | for _, tc := range testCases { 74 | t.Run(tc.name, func(t *testing.T) { 75 | // setup router 76 | router := NewRouter() 77 | for _, s := range tc.statics { 78 | router.HandleStatic(s.host, s.handler) 79 | } 80 | for _, re := range tc.regexps { 81 | router.HandleRegexp(re.regexp, re.handler) 82 | } 83 | 84 | // prepare response writer / request 85 | rw := httptest.NewRecorder() 86 | 
req, err := http.NewRequest("GET", tc.url, nil) 87 | if err != nil { 88 | t.Fatalf("unexpected err %v", err) 89 | } 90 | 91 | // execute 92 | router.ServeHTTP(rw, req) 93 | resp := rw.Result() 94 | 95 | // verify worked as expected 96 | if resp.StatusCode != tc.expectedResponse.code { 97 | t.Errorf(" got: %v", resp.StatusCode) 98 | t.Errorf("want: %v", tc.expectedResponse.code) 99 | t.Fatalf("expected status codes to be equal") 100 | } 101 | }) 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /internal/pkg/httpserver/httpserver.go: -------------------------------------------------------------------------------- 1 | package httpserver 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "net/http" 7 | "os" 8 | "os/signal" 9 | "syscall" 10 | "time" 11 | 12 | "github.com/buzzfeed/sso/internal/pkg/logging" 13 | ) 14 | 15 | // OS signals that will initiate graceful shutdown of the http server. 16 | // 17 | // NOTE: defined in a variable so that they may be overridden by tests. 18 | var shutdownSignals = []os.Signal{ 19 | syscall.SIGINT, 20 | syscall.SIGTERM, 21 | } 22 | 23 | // Run runs an http server and ensures that it is shut down gracefully within 24 | // the given shutdown timeout, allowing all in-flight requests to complete. 25 | // 26 | // Returns an error if a) the server fails to listen on its port or b) the 27 | // shutdown timeout elapses before all in-flight requests are finished. 
28 | func Run(srv *http.Server, shutdownTimeout time.Duration, logger *logging.LogEntry) error { 29 | // Logic below copied from the stdlib http.Server ListenAndServe() method: 30 | // https://github.com/golang/go/blob/release-branch.go1.13/src/net/http/server.go#L2805-L2826 31 | addr := srv.Addr 32 | if addr == "" { 33 | addr = ":http" 34 | } 35 | ln, err := net.Listen("tcp", addr) 36 | if err != nil { 37 | return err 38 | } 39 | return runWithListener(ln, srv, shutdownTimeout, logger) 40 | } 41 | 42 | // runWithListener does the heavy lifting for Run() above, and is decoupled 43 | // only for testing purposes 44 | func runWithListener(ln net.Listener, srv *http.Server, shutdownTimeout time.Duration, logger *logging.LogEntry) error { 45 | var ( 46 | // shutdownCh triggers graceful shutdown on SIGINT or SIGTERM 47 | shutdownCh = make(chan os.Signal, 1) 48 | 49 | // exitCh will be closed when it is safe to exit, after graceful shutdown 50 | exitCh = make(chan struct{}) 51 | 52 | // shutdownErr allows any error from srv.Shutdown to propagate out up 53 | // from the goroutine 54 | shutdownErr error 55 | ) 56 | 57 | signal.Notify(shutdownCh, shutdownSignals...) 
58 | 59 | go func() { 60 | sig := <-shutdownCh 61 | logger.Info("shutdown started by signal: ", sig) 62 | signal.Stop(shutdownCh) 63 | 64 | logger.Info("waiting for server to shut down in ", shutdownTimeout) 65 | ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) 66 | defer cancel() 67 | 68 | shutdownErr = srv.Shutdown(ctx) 69 | close(exitCh) 70 | }() 71 | 72 | if serveErr := srv.Serve(ln); serveErr != nil && serveErr != http.ErrServerClosed { 73 | return serveErr 74 | } 75 | 76 | <-exitCh 77 | logger.Info("shutdown finished") 78 | 79 | return shutdownErr 80 | } 81 | -------------------------------------------------------------------------------- /internal/pkg/httpserver/httpserver_test.go: -------------------------------------------------------------------------------- 1 | package httpserver 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "net/http" 7 | "os" 8 | "sync" 9 | "syscall" 10 | "testing" 11 | "time" 12 | 13 | "github.com/buzzfeed/sso/internal/pkg/logging" 14 | ) 15 | 16 | func newLocalListener(t *testing.T) net.Listener { 17 | t.Helper() 18 | 19 | l, err := net.Listen("tcp", "127.0.0.1:0") 20 | if err != nil { 21 | t.Fatalf("failed to listen on a port: %v", err) 22 | } 23 | return l 24 | } 25 | 26 | func TestGracefulShutdown(t *testing.T) { 27 | proc, err := os.FindProcess(os.Getpid()) 28 | if err != nil { 29 | t.Fatal(err) 30 | } 31 | 32 | // override shutdown signals used by Run for testing purposes 33 | shutdownSignals = []os.Signal{syscall.SIGUSR1} 34 | 35 | logger := logging.NewLogEntry() 36 | 37 | testCases := map[string]struct { 38 | shutdownTimeout time.Duration 39 | requestDelay time.Duration 40 | expectShutdownErr bool 41 | expectRequestErr bool 42 | }{ 43 | "clean shutdown": { 44 | shutdownTimeout: 1 * time.Second, 45 | requestDelay: 250 * time.Millisecond, 46 | expectShutdownErr: false, 47 | expectRequestErr: false, 48 | }, 49 | "timeout elapsed": { 50 | shutdownTimeout: 50 * time.Millisecond, 51 | requestDelay: 250 * 
time.Millisecond, 52 | expectShutdownErr: true, 53 | 54 | // In real usage, we would expect the request to be aborted when 55 | // the server is shut down and its process exits before it can 56 | // finish responding. 57 | // 58 | // But because we're running the server within the test process, 59 | // which does not exit after shutdown, the goroutine handling the 60 | // long-running request does not seem to get canceled and the 61 | // request ends up completing successfully even after the server 62 | // has shut down. 63 | // 64 | // Properly testing this would require something like re-running 65 | // the test binary as a subprocess to which we can send SIGTERM, 66 | // but doing that would add a lot more complexity (e.g. having it 67 | // bind to a random available port and then running a separate 68 | // subprocess to figure out the port to which it is bound, all in a 69 | // cross-platform way). 70 | // 71 | // If we wanted to go that route, some examples of the general 72 | // approach can be seen here: 73 | // 74 | // - http://cs-guy.com/blog/2015/01/test-main/#toc_3 75 | // - https://talks.golang.org/2014/testing.slide#23 76 | expectRequestErr: false, 77 | }, 78 | } 79 | 80 | for name, tc := range testCases { 81 | t.Run(name, func(t *testing.T) { 82 | var ( 83 | ln = newLocalListener(t) 84 | addr = ln.Addr().String() 85 | url = fmt.Sprintf("http://%s", addr) 86 | ) 87 | 88 | srv := &http.Server{ 89 | Addr: addr, 90 | Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 91 | time.Sleep(tc.requestDelay) 92 | }), 93 | } 94 | 95 | var ( 96 | wg sync.WaitGroup 97 | shutdownErr error 98 | requestErr error 99 | ) 100 | 101 | // Run our server and wait for a stop signal 102 | wg.Add(1) 103 | go func() { 104 | defer wg.Done() 105 | shutdownErr = runWithListener(ln, srv, tc.shutdownTimeout, logger) 106 | }() 107 | 108 | // give the server time to start listening 109 | <-time.After(50 * time.Millisecond) 110 | 111 | // make a request 112 | 
wg.Add(1) 113 | go func() { 114 | defer wg.Done() 115 | _, requestErr = http.Get(url) 116 | }() 117 | 118 | // give the request some time to connect 119 | <-time.After(1 * time.Millisecond) 120 | 121 | // tell server to shut down gracefully 122 | proc.Signal(syscall.SIGUSR1) 123 | 124 | // wait for server to shut down and requests to complete 125 | wg.Wait() 126 | 127 | if tc.expectShutdownErr { 128 | if shutdownErr == nil { 129 | t.Fatalf("did not get expected shutdown error") 130 | } 131 | } else { 132 | if shutdownErr != nil { 133 | t.Fatalf("got unexpected shutdown error: %s", shutdownErr) 134 | } 135 | } 136 | 137 | if tc.expectRequestErr && requestErr == nil { 138 | if requestErr == nil { 139 | t.Fatalf("did not get expected request error") 140 | } 141 | } else { 142 | if requestErr != nil { 143 | t.Fatalf("got unexpected request error: %s", requestErr) 144 | } 145 | } 146 | }) 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /internal/pkg/sessions/mock_store.go: -------------------------------------------------------------------------------- 1 | package sessions 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | ) 7 | 8 | // MockCSRFStore is a mock implementation of the CSRF store interface 9 | type MockCSRFStore struct { 10 | ResponseCSRF string 11 | Cookie *http.Cookie 12 | GetError error 13 | } 14 | 15 | // SetCSRF sets the ResponseCSRF string to a val 16 | func (ms *MockCSRFStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) { 17 | ms.ResponseCSRF = val 18 | } 19 | 20 | // ClearCSRF clears the ResponseCSRF string 21 | func (ms *MockCSRFStore) ClearCSRF(http.ResponseWriter, *http.Request) { 22 | ms.ResponseCSRF = "" 23 | } 24 | 25 | // GetCSRF returns the cookie and error 26 | func (ms *MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) { 27 | return ms.Cookie, ms.GetError 28 | } 29 | 30 | // MockSessionStore is a mock implementation of the SessionStore interface 31 | 
type MockSessionStore struct { 32 | ResponseSession string 33 | Session *SessionState 34 | SaveError error 35 | LoadError error 36 | } 37 | 38 | // ClearSession clears the ResponseSession 39 | func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) { 40 | ms.ResponseSession = "" 41 | } 42 | 43 | // LoadSession returns the session and a error 44 | func (ms *MockSessionStore) LoadSession(*http.Request) (*SessionState, error) { 45 | if ms.Session == nil { 46 | ms.LoadError = http.ErrNoCookie 47 | } 48 | return ms.Session, ms.LoadError 49 | } 50 | 51 | // SaveSession returns a save error. 52 | func (ms *MockSessionStore) SaveSession(rw http.ResponseWriter, req *http.Request, s *SessionState) error { 53 | marshaled, _ := json.Marshal(s) 54 | ms.ResponseSession = string(marshaled) 55 | return ms.SaveError 56 | } 57 | -------------------------------------------------------------------------------- /internal/pkg/sessions/session_state.go: -------------------------------------------------------------------------------- 1 | package sessions 2 | 3 | import ( 4 | "errors" 5 | "time" 6 | 7 | "github.com/buzzfeed/sso/internal/pkg/aead" 8 | ) 9 | 10 | var ( 11 | // ErrLifetimeExpired is an error for the lifetime deadline expiring 12 | ErrLifetimeExpired = errors.New("user lifetime expired") 13 | ) 14 | 15 | // SessionState is our object that keeps track of a user's session state 16 | type SessionState struct { 17 | ProviderSlug string `json:"slug"` 18 | ProviderType string `json:"type"` 19 | 20 | AccessToken string `json:"access_token"` 21 | RefreshToken string `json:"refresh_token"` 22 | 23 | RefreshDeadline time.Time `json:"refresh_deadline"` 24 | LifetimeDeadline time.Time `json:"lifetime_deadline"` 25 | ValidDeadline time.Time `json:"valid_deadline"` 26 | GracePeriodStart time.Time `json:"grace_period_start"` 27 | 28 | Email string `json:"email"` 29 | User string `json:"user"` 30 | Groups []string `json:"groups"` 31 | AuthorizedUpstream string 
`json:"authorized_upstream"` 32 | } 33 | 34 | // LifetimePeriodExpired returns true if the lifetime has expired 35 | func (s *SessionState) LifetimePeriodExpired() bool { 36 | return isExpired(s.LifetimeDeadline) 37 | } 38 | 39 | // RefreshPeriodExpired returns true if the refresh period has expired 40 | func (s *SessionState) RefreshPeriodExpired() bool { 41 | return isExpired(s.RefreshDeadline) 42 | } 43 | 44 | // ValidationPeriodExpired returns true if the validation period has expired 45 | func (s *SessionState) ValidationPeriodExpired() bool { 46 | return isExpired(s.ValidDeadline) 47 | } 48 | 49 | func isExpired(t time.Time) bool { 50 | if t.Before(time.Now()) { 51 | return true 52 | } 53 | return false 54 | } 55 | 56 | // IsWithinGracePeriod returns true if the session is still within the grace period 57 | func (s *SessionState) IsWithinGracePeriod(gracePeriodTTL time.Duration) bool { 58 | if s.GracePeriodStart.IsZero() { 59 | s.GracePeriodStart = time.Now() 60 | } 61 | return s.GracePeriodStart.Add(gracePeriodTTL).After(time.Now()) 62 | } 63 | 64 | // MarshalSession marshals the session state as JSON, encrypts the JSON using the 65 | // given cipher, and base64-encodes the result 66 | func MarshalSession(s *SessionState, c aead.Cipher) (string, error) { 67 | return c.Marshal(s) 68 | } 69 | 70 | // UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the 71 | // byte slice using the passed cipher, and unmarshals the resulting JSON into a session state struct 72 | func UnmarshalSession(value string, c aead.Cipher) (*SessionState, error) { 73 | s := &SessionState{} 74 | err := c.Unmarshal(value, s) 75 | if err != nil { 76 | return nil, err 77 | } 78 | return s, nil 79 | } 80 | 81 | // ExtendDeadline returns the time extended by a given duration 82 | func ExtendDeadline(ttl time.Duration) time.Time { 83 | return time.Now().Add(ttl).Truncate(time.Second) 84 | } 85 | 
--------------------------------------------------------------------------------
/internal/pkg/sessions/session_state_test.go:
--------------------------------------------------------------------------------
1 | package sessions
2 |
3 | import (
4 | 	"reflect"
5 | 	"testing"
6 | 	"time"
7 |
8 | 	"github.com/buzzfeed/sso/internal/pkg/aead"
9 | )
10 |
11 | func TestSessionStateSerialization(t *testing.T) {
12 | 	secret := aead.GenerateKey()
13 | 	c, err := aead.NewMiscreantCipher([]byte(secret))
14 | 	if err != nil {
15 | 		t.Fatalf("expected to be able to create cipher: %v", err)
16 | 	}
17 |
18 | 	want := &SessionState{
19 | 		ProviderSlug: "slug",
20 | 		ProviderType: "sso",
21 |
22 | 		AccessToken:  "token1234",
23 | 		RefreshToken: "refresh4321",
24 |
25 | 		LifetimeDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
26 | 		RefreshDeadline:  time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
27 | 		ValidDeadline:    time.Now().Add(1 * time.Minute).Truncate(time.Second).UTC(),
28 |
29 | 		Email: "user@domain.com",
30 | 		User:  "user",
31 | 	}
32 |
33 | 	ciphertext, err := MarshalSession(want, c)
34 | 	if err != nil {
35 | 		t.Fatalf("expected to be able to encode session: %v", err)
36 | 	}
37 |
38 | 	got, err := UnmarshalSession(ciphertext, c)
39 | 	if err != nil {
40 | 		t.Fatalf("expected to be able to decode session: %v", err)
41 | 	}
42 |
43 | 	if !reflect.DeepEqual(want, got) {
44 | 		t.Logf("want: %#v", want)
45 | 		t.Logf(" got: %#v", got)
46 | 		t.Errorf("encoding and decoding session resulted in unexpected output")
47 | 	}
48 | }
49 |
50 | func TestSessionStateExpirations(t *testing.T) {
51 | 	session := &SessionState{
52 | 		AccessToken:  "token1234",
53 | 		RefreshToken: "refresh4321",
54 |
55 | 		LifetimeDeadline: time.Now().Add(-1 * time.Hour),
56 | 		RefreshDeadline:  time.Now().Add(-1 * time.Hour),
57 | 		ValidDeadline:    time.Now().Add(-1 * time.Minute),
58 | 		GracePeriodStart: time.Now().Add(-2 * time.Minute),
59 |
60 | 		Email: "user@domain.com",
61 | 		User:  "user",
62 | 	}
63 |
64 | 	if !session.LifetimePeriodExpired() {
65 | 		t.Errorf("expected lifetime period to be expired")
66 | 	}
67 |
68 | 	if !session.RefreshPeriodExpired() {
69 | 		t.Errorf("expected refresh period to be expired")
70 | 	}
71 |
72 | 	if !session.ValidationPeriodExpired() {
73 | 		t.Errorf("expected validation period to be expired")
74 | 	}
75 |
76 | 	// GracePeriodStart is 2 minutes in the past: a 1 minute grace period has
77 | 	// already elapsed, while a 3 minute one has not.
78 | 	if session.IsWithinGracePeriod(1 * time.Minute) {
79 | 		t.Errorf("expected session to be outside of grace period")
80 | 	}
81 |
82 | 	if !session.IsWithinGracePeriod(3 * time.Minute) {
83 | 		t.Errorf("expected session to be inside of grace period")
84 | 	}
85 | }
86 |
--------------------------------------------------------------------------------
/internal/pkg/singleflight/singleflight.go:
--------------------------------------------------------------------------------
1 | // Original Copyright 2013 The Go Authors. All rights reserved.
2 | //
3 | // Modified by BuzzFeed to return duplicate counts.
4 | //
5 | // Use of this source code is governed by a BSD-style
6 | // license that can be found in the LICENSE file.
7 |
8 | // Package singleflight provides a duplicate function call suppression mechanism.
9 | package singleflight
10 |
11 | import "sync"
12 |
13 | // call is an in-flight or completed singleflight.Do call
14 | type call struct {
15 | 	wg sync.WaitGroup
16 |
17 | 	// These fields are written once before the WaitGroup is done
18 | 	// and are only read after the WaitGroup is done.
19 | 	val interface{}
20 | 	err error
21 |
22 | 	// These fields are read and written with the singleflight
23 | 	// mutex held before the WaitGroup is done, and are read but
24 | 	// not written after the WaitGroup is done.
25 | 	dups int
26 | }
27 |
28 | // Group represents a class of work and forms a namespace in
29 | // which units of work can be executed with duplicate suppression.
30 | type Group struct {
31 | 	mu sync.Mutex       // protects m
32 | 	m  map[string]*call // lazily initialized
33 | }
34 |
35 | // Result holds the results of Do, so they can be passed
36 | // on a channel.
37 | //
38 | // NOTE(review): Count is declared as bool although Do reports an int count,
39 | // and nothing in this file uses Result. Check callers before changing the
40 | // field's type.
41 | type Result struct {
42 | 	Val   interface{}
43 | 	Err   error
44 | 	Count bool
45 | }
46 |
47 | // Do executes and returns the results of the given function, making
48 | // sure that only one execution is in-flight for a given key at a
49 | // time. If a duplicate comes in, the duplicate caller waits for the
50 | // original to complete and receives the same results.
51 | //
52 | // The returned count is the number of duplicate callers that joined this
53 | // execution and received its v and err. Duplicate callers always get a count
54 | // of 0, so only the originating caller can see a non-zero count.
55 | //
56 | // NOTE(review): if fn panics, c.wg.Done is never reached and the key is never
57 | // removed from g.m, so later callers for that key block forever.
58 | func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, count int, err error) {
59 | 	g.mu.Lock()
60 |
61 | 	if g.m == nil {
62 | 		g.m = make(map[string]*call)
63 | 	}
64 | 	if c, ok := g.m[key]; ok {
65 | 		c.dups++
66 | 		g.mu.Unlock()
67 | 		c.wg.Wait()
68 | 		return c.val, 0, c.err
69 | 	}
70 | 	c := new(call)
71 | 	c.wg.Add(1)
72 | 	g.m[key] = c
73 |
74 | 	g.mu.Unlock()
75 |
76 | 	c.val, c.err = fn()
77 | 	c.wg.Done()
78 |
79 | 	g.mu.Lock()
80 | 	delete(g.m, key)
81 | 	g.mu.Unlock()
82 |
83 | 	return c.val, c.dups, c.err
84 | }
85 |
--------------------------------------------------------------------------------
/internal/pkg/singleflight/singleflight_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2013 The Go Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
// NOTE(review): the dump lines below hold the rest of singleflight_test.go (from line 4),
// mock_templates.go, templates.go, templates_test.go and the first 30 lines of testutil.go.
// - MockTemplate.ExecuteTemplate discards the errors from json.Marshal and rw.Write.
// - HTMLTemplate.ExecuteTemplate discards the error returned by html/template's
//   ExecuteTemplate. The Template interface method has no return value, so rendering
//   failures are invisible to callers; consider at least logging them.
// - The raw-string templates in NewHTMLTemplate appear to have lost their HTML markup in
//   this dump (e.g. the "header.html" definition contains only blank lines). Restore
//   them from the upstream source instead of editing these strings here.
// - NewHTMLTemplate does not assign the template.Must result for "header.html" and
//   "error.html". This is presumably harmless because t.Parse adds the definitions
//   to t itself; confirm before relying on it.
4 | 5 | package singleflight 6 | 7 | import ( 8 | "errors" 9 | "fmt" 10 | "sync" 11 | "sync/atomic" 12 | "testing" 13 | "time" 14 | ) 15 | 16 | func TestDo(t *testing.T) { 17 | var g Group 18 | v, _, err := g.Do("key", func() (interface{}, error) { 19 | return "bar", nil 20 | }) 21 | if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { 22 | t.Errorf("Do = %v; want %v", got, want) 23 | } 24 | if err != nil { 25 | t.Errorf("Do error = %v", err) 26 | } 27 | } 28 | 29 | func TestDoErr(t *testing.T) { 30 | var g Group 31 | someErr := errors.New("Some error") 32 | v, _, err := g.Do("key", func() (interface{}, error) { 33 | return nil, someErr 34 | }) 35 | if err != someErr { 36 | t.Errorf("Do error = %v; want someErr %v", err, someErr) 37 | } 38 | if v != nil { 39 | t.Errorf("unexpected non-nil value %#v", v) 40 | } 41 | } 42 | 43 | func TestDoDupSuppress(t *testing.T) { 44 | var g Group 45 | var wg1, wg2 sync.WaitGroup 46 | c := make(chan string, 1) 47 | var calls int32 48 | fn := func() (interface{}, error) { 49 | if atomic.AddInt32(&calls, 1) == 1 { 50 | // First invocation. 51 | wg1.Done() 52 | } 53 | v := <-c 54 | c <- v // pump; make available for any future calls 55 | 56 | time.Sleep(10 * time.Millisecond) // let more goroutines enter Do 57 | 58 | return v, nil 59 | } 60 | 61 | const n = 10 62 | wg1.Add(1) 63 | for i := 0; i < n; i++ { 64 | wg1.Add(1) 65 | wg2.Add(1) 66 | go func() { 67 | defer wg2.Done() 68 | wg1.Done() 69 | v, _, err := g.Do("key", fn) 70 | if err != nil { 71 | t.Errorf("Do error: %v", err) 72 | return 73 | } 74 | if s, _ := v.(string); s != "bar" { 75 | t.Errorf("Do = %T %v; want %q", v, v, "bar") 76 | } 77 | }() 78 | } 79 | wg1.Wait() 80 | // At least one goroutine is in fn now and all of them have at 81 | // least reached the line before the Do.
82 | c <- "bar" 83 | wg2.Wait() 84 | if got := atomic.LoadInt32(&calls); got <= 0 || got >= n { 85 | t.Errorf("number of calls = %d; want over 0 and less than %d", got, n) 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /internal/pkg/templates/mock_templates.go: -------------------------------------------------------------------------------- 1 | package templates 2 | 3 | import ( 4 | "encoding/json" 5 | "io" 6 | ) 7 | 8 | // MockTemplate mocks the template interface wrapper 9 | type MockTemplate struct{} 10 | 11 | // ExecuteTemplate write the data to the io.Writer 12 | func (mt *MockTemplate) ExecuteTemplate(rw io.Writer, path string, data interface{}) { 13 | jsonBytes, _ := json.Marshal(data) 14 | rw.Write(jsonBytes) 15 | } 16 | -------------------------------------------------------------------------------- /internal/pkg/templates/templates.go: -------------------------------------------------------------------------------- 1 | package templates 2 | 3 | import ( 4 | "html/template" 5 | "io" 6 | ) 7 | 8 | // Template represents html templates 9 | type Template interface { 10 | ExecuteTemplate(io.Writer, string, interface{}) 11 | } 12 | 13 | // HTMLTemplate is a wrapper around html/template package 14 | type HTMLTemplate struct { 15 | templates *template.Template 16 | } 17 | 18 | // ExecuteTemplate wraps the html/template ExecuteTemplate function 19 | func (ht *HTMLTemplate) ExecuteTemplate(rw io.Writer, path string, data interface{}) { 20 | ht.templates.ExecuteTemplate(rw, path, data) 21 | } 22 | 23 | // NewHTMLTemplate returns a new HTMLTemplate 24 | func NewHTMLTemplate() *HTMLTemplate { 25 | t := template.New("foo") 26 | template.Must(t.Parse(`{{define "header.html"}} 27 | 28 | 29 | {{end}}`)) 30 | 31 | t = template.Must(t.Parse(`{{define "footer.html"}} 32 | Secured by SSO{{end}}`)) 33 | 34 | t = template.Must(t.Parse(`{{define "sign_in_message.html"}} 35 | {{if eq (len .EmailDomains) 1}} 36 | {{if eq (index
.EmailDomains 0) "@*"}} 37 |

You may sign in with any {{.ProviderName}} account.

38 | {{else}} 39 |

You may sign in with your {{index .EmailDomains 0}} {{.ProviderName}} account.

40 | {{end}} 41 | {{else if gt (len .EmailDomains) 1}} 42 |

43 | You may sign in with any of these {{.ProviderName}} accounts:
44 | {{range $i, $e := .EmailDomains}}{{if $i}}, {{end}}{{$e}}{{end}} 45 |

46 | {{end}} 47 | {{end}}`)) 48 | 49 | t = template.Must(t.Parse(`{{define "sign_in.html"}} 50 | 51 | 52 | 53 | Sign In 54 | {{template "header.html"}} 55 | 56 | 57 | 58 | 59 |
60 |
61 |
62 |

Sign in to {{.Destination}}

63 |
64 | 65 | {{template "sign_in_message.html" .}} 66 | 67 |
68 | 69 | 70 |
71 |
72 | 73 |
{{template "footer.html"}}
74 |
75 | 76 | 77 | {{end}}`)) 78 | 79 | template.Must(t.Parse(`{{define "error.html"}} 80 | 81 | 82 | 83 | Error 84 | {{template "header.html"}} 85 | 86 | 87 |
88 |
89 |
90 |

{{.Title}}

91 |
92 |

93 | {{.Message}}
94 | HTTP {{.Code}} 95 |

96 |
97 |
{{template "footer.html"}}
98 |
99 | 100 | {{end}}`)) 101 | 102 | t = template.Must(t.Parse(`{{define "sign_out.html"}} 103 | 104 | 105 | 106 | Sign Out 107 | {{template "header.html"}} 108 | 109 | 110 |
111 | {{ if .Message }} 112 |
{{.Message}}
113 | {{ end}} 114 |
115 |
116 |

Sign out of {{.Destination}}

117 |
118 | 119 |

You're currently signed in as {{.Email}}. This will also sign you out of other internal apps.

120 |
121 | 122 | 123 | 124 | 125 |
126 |
127 |
{{template "footer.html"}}
128 |
129 | 130 | 131 | {{end}}`)) 132 | return &HTMLTemplate{t} 133 | } 134 | -------------------------------------------------------------------------------- /internal/pkg/templates/templates_test.go: -------------------------------------------------------------------------------- 1 | package templates 2 | 3 | import ( 4 | "bytes" 5 | "strings" 6 | "testing" 7 | ) 8 | 9 | func TestTemplatesCompile(t *testing.T) { 10 | 11 | templates := NewHTMLTemplate() 12 | if templates.templates == nil { 13 | t.Error("getTemplates() returned nil") 14 | } 15 | } 16 | 17 | func TestSignInMessage(t *testing.T) { 18 | 19 | testCases := []struct { 20 | name string 21 | emailDomains []string 22 | expectedSubstrings []string 23 | }{ 24 | { 25 | name: "single email domain", 26 | emailDomains: []string{"example.com"}, 27 | expectedSubstrings: []string{"

You may sign in with your example.com Google account.

"}, 28 | }, 29 | { 30 | name: "single email that is a wildcard", 31 | emailDomains: []string{"@*"}, 32 | expectedSubstrings: []string{"

You may sign in with any Google account.

"}, 33 | }, 34 | { 35 | name: "multiple email domains", 36 | emailDomains: []string{"foo.com", "bar.com"}, 37 | expectedSubstrings: []string{ 38 | "You may sign in with any of these Google accounts:
", 39 | "foo.com, bar.com", 40 | }, 41 | }, 42 | // TODO: should we reject this configuration? We don't really expect 43 | // multiple email domains that also use wildcards, so we don't have 44 | // special handling for this case in the templates. 45 | { 46 | name: "multiple email domains including a wildcard", 47 | emailDomains: []string{"*", "foo.com", "bar.com"}, 48 | expectedSubstrings: []string{ 49 | "You may sign in with any of these Google accounts:
", 50 | "*, foo.com, bar.com", 51 | }, 52 | }, 53 | } 54 | 55 | templates := NewHTMLTemplate() 56 | 57 | for _, tc := range testCases { 58 | t.Run(tc.name, func(t *testing.T) { 59 | buf := &bytes.Buffer{} 60 | ctx := struct { 61 | ProviderName string 62 | EmailDomains []string 63 | }{ 64 | ProviderName: "Google", 65 | EmailDomains: tc.emailDomains, 66 | } 67 | templates.ExecuteTemplate(buf, "sign_in_message.html", ctx) 68 | result := buf.String() 69 | 70 | for _, substring := range tc.expectedSubstrings { 71 | if !strings.Contains(result, substring) { 72 | t.Errorf("substring %#v not found in rendered template:\n%s", substring, result) 73 | } 74 | } 75 | }) 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /internal/pkg/testutil/testutil.go: -------------------------------------------------------------------------------- 1 | package testutil 2 | 3 | // testing util functions copied from https://github.com/benbjohnson/testing 4 | import ( 5 | "fmt" 6 | "path/filepath" 7 | "reflect" 8 | "runtime" 9 | "testing" 10 | ) 11 | 12 | // Assert fails the test if the condition is false. 13 | func Assert(tb testing.TB, condition bool, msg string, v ...interface{}) { 14 | if !condition { 15 | _, file, line, _ := runtime.Caller(1) 16 | fmt.Printf("\033[31m%s:%d: "+msg+"\033[39m\n\n", append([]interface{}{filepath.Base(file), line}, v...)...) 17 | tb.FailNow() 18 | } 19 | } 20 | 21 | // Ok fails the test if an err is not nil. 22 | func Ok(tb testing.TB, err error) { 23 | if err != nil { 24 | _, file, line, _ := runtime.Caller(1) 25 | fmt.Printf("\033[31m%s:%d: unexpected error: %s\033[39m\n\n", filepath.Base(file), line, err.Error()) 26 | tb.FailNow() 27 | } 28 | } 29 | 30 | // Equal fails the test if exp is not equal to act.
31 | func Equal(tb testing.TB, exp, act interface{}) { 32 | if !reflect.DeepEqual(exp, act) { 33 | _, file, line, _ := runtime.Caller(1) 34 | fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act) 35 | tb.FailNow() 36 | } 37 | } 38 | 39 | // NotEqual fails the test if exp is equal to act. 40 | func NotEqual(tb testing.TB, exp, act interface{}) { 41 | if reflect.DeepEqual(exp, act) { 42 | _, file, line, _ := runtime.Caller(1) 43 | fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act) 44 | tb.FailNow() 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /internal/pkg/validators/email_address_validator.go: -------------------------------------------------------------------------------- 1 | package validators 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/buzzfeed/sso/internal/pkg/sessions" 9 | ) 10 | 11 | var ( 12 | _ Validator = EmailAddressValidator{} 13 | 14 | // These error message should be formatted in such a way that is appropriate 15 | // for display to the end user. 16 | ErrEmailAddressDenied = errors.New("Unauthorized Email Address") 17 | ) 18 | 19 | type EmailAddressValidator struct { 20 | AllowedEmails []string 21 | } 22 | 23 | // NewEmailAddressValidator takes in a list of email addresses and returns a Validator object. 24 | // The validator can be used to validate that the session.Email: 25 | // - is non-empty 26 | // - matches one of the originally passed in email addresses 27 | // (case insensitive) 28 | // - if the originally passed in list of emails consists only of "*", then all emails 29 | // are considered valid based on their domain. 30 | // 31 | // If valid, nil is returned in place of an error. 
32 | func NewEmailAddressValidator(allowedEmails []string) EmailAddressValidator {
33 | 	var emailAddresses []string
34 |
35 | 	for _, email := range allowedEmails {
36 | 		// NOTE(review): the fmt.Sprintf wrapper is redundant (strings.ToLower already
37 | 		// returns a string); dropping it also requires removing the "fmt" import.
38 | 		emailAddress := fmt.Sprintf("%s", strings.ToLower(email))
39 | 		emailAddresses = append(emailAddresses, emailAddress)
40 | 	}
41 |
42 | 	return EmailAddressValidator{
43 | 		AllowedEmails: emailAddresses,
44 | 	}
45 | }
46 |
47 | // Validate returns nil if session.Email is allowed by this validator. Otherwise
48 | // it returns ErrInvalidEmailAddress if session.Email is empty, or
49 | // ErrEmailAddressDenied if no addresses are allowed or the lower-cased email
50 | // does not exactly match one of AllowedEmails. An allow-list consisting solely
51 | // of "*" accepts every non-empty email.
52 | func (v EmailAddressValidator) Validate(session *sessions.SessionState) error {
53 | 	if session.Email == "" {
54 | 		return ErrInvalidEmailAddress
55 | 	}
56 |
57 | 	if len(v.AllowedEmails) == 0 {
58 | 		return ErrEmailAddressDenied
59 | 	}
60 |
61 | 	if len(v.AllowedEmails) == 1 && v.AllowedEmails[0] == "*" {
62 | 		return nil
63 | 	}
64 |
65 | 	return v.validate(session)
66 | }
67 |
68 | // validate compares the lower-cased session email against each allowed address.
69 | func (v EmailAddressValidator) validate(session *sessions.SessionState) error {
70 | 	email := strings.ToLower(session.Email)
71 | 	for _, emailItem := range v.AllowedEmails {
72 | 		if email == emailItem {
73 | 			return nil
74 | 		}
75 | 	}
76 | 	return ErrEmailAddressDenied
77 | }
78 |
--------------------------------------------------------------------------------
/internal/pkg/validators/email_address_validator_test.go:
--------------------------------------------------------------------------------
1 | package validators
2 |
3 | import (
4 | 	"testing"
5 |
6 | 	"github.com/buzzfeed/sso/internal/pkg/sessions"
7 | )
8 |
9 | func TestEmailAddressValidatorValidator(t *testing.T) {
10 | 	testCases := []struct {
11 | 		name          string
12 | 		AllowedEmails []string
13 | 		email         string
14 | 		expectedErr   error
15 | 		session       *sessions.SessionState
16 | 	}{
17 | 		{
18 | 			name:          "nothing should validate when address list is empty",
19 | 			AllowedEmails: []string(nil),
20 | 			session: &sessions.SessionState{
21 | 				Email: "foo@example.com",
22 | 			},
23 | 			expectedErr: ErrEmailAddressDenied,
24 | 		},
25 | 		{
26 | 			name:          "single address validation",
27 | 			AllowedEmails: []string{"foo@example.com"},
28 | 			session: &sessions.SessionState{
29 | 				Email: "foo@example.com",
30 | 			},
31 | 			expectedErr: nil,
32 | 		},
33 | 		{
34 | 			name:          "substring matches are rejected",
35 | 			AllowedEmails: []string{"foo@example.com"},
36 | 			session: &sessions.SessionState{
37 | 				Email: "foo@hackerexample.com",
38 | 			},
39 | 			expectedErr: ErrEmailAddressDenied,
40 | 		},
41 | 		{
42 | 			name:          "no subdomain rollup happens",
43 | 			AllowedEmails: []string{"foo@example.com"},
44 | 			session: &sessions.SessionState{
45 | 				Email: "foo@bar.example.com",
46 | 			},
47 | 			expectedErr: ErrEmailAddressDenied,
48 | 		},
49 | 		{
50 | 			name:          "multiple address validation still rejects other addresses",
51 | 			AllowedEmails: []string{"foo@abc.com", "foo@xyz.com"},
52 | 			session: &sessions.SessionState{
53 | 				Email: "foo@example.com",
54 | 			},
55 | 			expectedErr: ErrEmailAddressDenied,
56 | 		},
57 | 		{
58 | 			name:          "multiple address validation still accepts emails from either address",
59 | 			AllowedEmails: []string{"foo@abc.com", "foo@xyz.com"},
60 | 			session: &sessions.SessionState{
61 | 				Email: "foo@abc.com",
62 | 			},
63 | 			expectedErr: nil,
64 | 		},
65 | 		{
66 | 			name:          "multiple address validation accepts the second listed address",
67 | 			AllowedEmails: []string{"foo@abc.com", "bar@xyz.com"},
68 | 			session: &sessions.SessionState{
69 | 				Email: "bar@xyz.com",
70 | 			},
71 | 			expectedErr: nil,
72 | 		},
73 | 		{
74 | 			name:          "comparisons are case insensitive-1",
75 | 			AllowedEmails: []string{"Foo@Example.Com"},
76 | 			session: &sessions.SessionState{
77 | 				Email: "foo@example.com",
78 | 			},
79 | 			expectedErr: nil,
80 | 		},
81 | 		{
82 | 			name:          "comparisons are case insensitive-2",
83 | 			AllowedEmails: []string{"Foo@Example.Com"},
84 | 			session: &sessions.SessionState{
85 | 				Email: "foo@EXAMPLE.COM",
86 | 			},
87 | 			expectedErr: nil,
88 | 		},
89 | 		{
90 | 			name:          "comparisons are case insensitive-3",
91 | 			AllowedEmails: []string{"foo@example.com"},
92 | 			session: &sessions.SessionState{
93 | 				Email: "foo@ExAmPLE.CoM",
94 | 			},
95 | 			expectedErr: nil,
96 | 		},
97 | 		{
98 | 			name:          "single wildcard allows all-1",
99 | 			AllowedEmails: []string{"*"},
100 | 			session: &sessions.SessionState{
101 | 				Email: "foo@example.com",
102 | 			},
103 | 			expectedErr: nil,
104 | 		},
105 | 		{
106 | 			name:          "single wildcard allows all-2",
107 | 			AllowedEmails: []string{"*"},
108 | 			session: &sessions.SessionState{
109 | 				Email: "bar@gmail.com",
110 | 			},
111 | 			expectedErr: nil,
112 | 		},
113 | 		{
114 | 			name:          "wildcard is ignored if other addresses are included",
115 | 			AllowedEmails: []string{"foo@example.com", "*"},
116 | 			session: &sessions.SessionState{
117 | 				Email: "foo@gmail.com",
118 | 			},
119 | 			expectedErr: ErrEmailAddressDenied,
120 | 		},
121 | 		{
122 | 			name:          "empty email rejected",
123 | 			AllowedEmails: []string{"foo@example.com"},
124 | 			email:         "",
125 | 			session: &sessions.SessionState{
126 | 				Email: "",
127 | 			},
128 | 			expectedErr: ErrInvalidEmailAddress,
129 | 		},
130 | 		{
131 | 			name:          "wildcard still rejects empty emails",
132 | 			AllowedEmails: []string{"*"},
133 | 			email:         "",
134 | 			session: &sessions.SessionState{
135 | 				Email: "",
136 | 			},
137 | 			expectedErr: ErrInvalidEmailAddress,
138 | 		},
139 | 	}
140 |
141 | 	for _, tc := range testCases {
142 | 		t.Run(tc.name, func(t *testing.T) {
143 | 			emailValidator := NewEmailAddressValidator(tc.AllowedEmails)
144 | 			err := emailValidator.Validate(tc.session)
145 | 			if err != tc.expectedErr {
146 | 				t.Fatalf("expected %v, got %v", tc.expectedErr, err)
147 | 			}
148 | 		})
149 | 	}
150 | }
151 |
--------------------------------------------------------------------------------
/internal/pkg/validators/email_domain_validator.go:
--------------------------------------------------------------------------------
1 | package validators
2 |
3 | import (
4 | 	"errors"
5 | 	"strings"
6 |
7 | 	"github.com/buzzfeed/sso/internal/pkg/sessions"
8 | )
9 |
10 | var (
11 | 	_ Validator = EmailDomainValidator{}
12 |
13 | 	// These error messages should be formatted in such a way that is appropriate
14 | 	// for display to the end user.
15 | 	ErrEmailDomainDenied = errors.New("Unauthorized Email Domain")
16 | )
17 |
18 | // EmailDomainValidator holds "@"-prefixed, lower-cased domains (or "*") as built by NewEmailDomainValidator.
19 | type EmailDomainValidator struct {
20 | 	AllowedDomains []string
21 | }
22 |
23 | // NewEmailDomainValidator takes in a list of domains and returns a Validator object.
24 | // The validator can be used to validate that the session.Email:
25 | // - is non-empty
26 | // - the domain of the email address matches one of the originally passed in domains.
27 | //   (case insensitive)
28 | // - if the originally passed in list of domains consists only of "*", then all
29 | //   non-empty emails are considered valid.
30 | //
31 | // If valid, nil is returned in place of an error.
32 | func NewEmailDomainValidator(allowedDomains []string) EmailDomainValidator {
33 | 	emailDomains := make([]string, 0, len(allowedDomains))
34 |
35 | 	for _, domain := range allowedDomains {
36 | 		if domain == "*" {
37 | 			// The wildcard is kept verbatim; Validate special-cases a lone "*".
38 | 			emailDomains = append(emailDomains, domain)
39 | 		} else {
40 | 			// The "@" prefix anchors the suffix match in validate to the whole
41 | 			// domain, so "example.com" does not match "hackerexample.com".
42 | 			emailDomains = append(emailDomains, "@"+strings.ToLower(domain))
43 | 		}
44 | 	}
45 | 	return EmailDomainValidator{
46 | 		AllowedDomains: emailDomains,
47 | 	}
48 | }
49 |
50 | // Validate returns nil if session.Email is allowed, ErrInvalidEmailAddress if it
51 | // is empty, and ErrEmailDomainDenied otherwise. An allow-list consisting solely
52 | // of "*" accepts every non-empty email.
53 | func (v EmailDomainValidator) Validate(session *sessions.SessionState) error {
54 | 	if session.Email == "" {
55 | 		return ErrInvalidEmailAddress
56 | 	}
57 |
58 | 	if len(v.AllowedDomains) == 0 {
59 | 		return ErrEmailDomainDenied
60 | 	}
61 |
62 | 	if len(v.AllowedDomains) == 1 && v.AllowedDomains[0] == "*" {
63 | 		return nil
64 | 	}
65 |
66 | 	return v.validate(session)
67 | }
68 |
69 | // validate checks whether the lower-cased session email ends with one of the
70 | // allowed domain suffixes.
71 | func (v EmailDomainValidator) validate(session *sessions.SessionState) error {
72 | 	email := strings.ToLower(session.Email)
73 | 	for _, domain := range v.AllowedDomains {
74 | 		if strings.HasSuffix(email, domain) {
75 | 			return nil
76 | 		}
77 | 	}
78 | 	return ErrEmailDomainDenied
79 | }
80 |
--------------------------------------------------------------------------------
/internal/pkg/validators/email_domain_validator_test.go:
--------------------------------------------------------------------------------
1 | package validators
2 |
3 | import (
4 | 	"testing"
5 |
6 | 	"github.com/buzzfeed/sso/internal/pkg/sessions"
7 | )
8 |
9 | func TestEmailDomainValidatorValidator(t *testing.T) {
10 | 	testCases := []struct {
11 | 		name           string
12 | 		allowedDomains []string
13 | 		email          string
14 | 		expectedErr    error
15 | 		session        *sessions.SessionState
16 | 	}{
17 | 		{
18 | 			name:           "nothing should validate when domain list is empty",
19 | 			allowedDomains: []string(nil),
20 | 			session: &sessions.SessionState{
21 | 				Email: "foo@example.com",
22 | 			},
23 | 			expectedErr: ErrEmailDomainDenied,
24 | 		},
25 | 		{
26 | 			name:           "single domain validation",
27 | 			allowedDomains: []string{"example.com"},
28 | 			session: &sessions.SessionState{
29 | 				Email: "foo@example.com",
30 | 			},
31 | 			expectedErr: nil,
32 | 		},
33 | 		{
34 | 			name:           "substring matches are rejected",
35 | 			allowedDomains: []string{"example.com"},
36 | 			session: &sessions.SessionState{
37 | 				Email: "foo@hackerexample.com",
38 | 			},
39 | 			expectedErr: ErrEmailDomainDenied,
40 | 		},
41 | 		{
42 | 			name:           "no subdomain rollup happens",
43 | 			allowedDomains: []string{"example.com"},
44 | 			session: &sessions.SessionState{
45 | 				Email: "foo@bar.example.com",
46 | 			},
47 | 			expectedErr: ErrEmailDomainDenied,
48 | 		},
49 | 		{
50 | 			name:           "multiple domain validation still rejects other domains",
51 | 			allowedDomains: []string{"abc.com", "xyz.com"},
52 | 			session: &sessions.SessionState{
53 | 				Email: "foo@example.com",
54 | 			},
55 | 			expectedErr: ErrEmailDomainDenied,
56 | 		},
57 | 		{
58 | 			name:           "multiple domain validation still accepts emails from either domain",
59 | 			allowedDomains: []string{"abc.com", "xyz.com"},
60 | 			session: &sessions.SessionState{
61 | 				Email: "foo@abc.com",
62 | 			},
63 | 			expectedErr: nil,
64 | 		},
65 | 		{
66 | 			name:           "multiple domain validation accepts emails from the second domain",
67 | 			allowedDomains: []string{"abc.com", "xyz.com"},
68 | 			session: &sessions.SessionState{
69 | 				Email: "bar@xyz.com",
70 | 			},
71 | 			expectedErr: nil,
72 | 		},
73 | 		{
74 | 			name:           "comparisons are case insensitive-1",
75 | 			allowedDomains: []string{"Example.Com"},
76 | 			session: &sessions.SessionState{
77 | 				Email: "foo@example.com",
78 | 			},
79 | 			expectedErr: nil,
80 | 		},
81 | 		{
82 | 			name:           "comparisons are case insensitive-2",
83 | 			allowedDomains: []string{"Example.Com"},
84 | 			session: &sessions.SessionState{
85 | 				Email: "foo@EXAMPLE.COM",
86 | 			},
87 | 			expectedErr: nil,
88 | 		},
89 | 		{
90 | 			name:           "comparisons are case insensitive-3",
91 | 			allowedDomains: []string{"example.com"},
92 | 			session: &sessions.SessionState{
93 | 				Email: "foo@ExAmPLE.CoM",
94 | 			},
95 | 			expectedErr: nil,
96 | 		},
97 | 		{
98 | 			name:           "single wildcard allows all-1",
99 | 			allowedDomains: []string{"*"},
100 | 			session: &sessions.SessionState{
101 | 				Email: "foo@example.com",
102 | 			},
103 | 			expectedErr: nil,
104 | 		},
105 | 		{
106 | 			name:           "single wildcard allows all-2",
107 | 			allowedDomains: []string{"*"},
108 | 			session: &sessions.SessionState{
109 | 				Email: "bar@gmail.com",
110 | 			},
111 | 			expectedErr: nil,
112 | 		},
113 | 		{
114 | 			name:           "wildcard is ignored if other domains are included",
115 | 			allowedDomains: []string{"*", "example.com"},
116 | 			session: &sessions.SessionState{
117 | 				Email: "foo@gmal.com",
118 | 			},
119 | 			expectedErr: ErrEmailDomainDenied,
120 | 		},
121 | 		{
122 | 			name:           "empty email rejected",
123 | 			allowedDomains: []string{"example.com"},
124 | 			email:          "",
125 | 			session: &sessions.SessionState{
126 | 				Email: "",
127 | 			},
128 | 			expectedErr: ErrInvalidEmailAddress,
129 | 		},
130 | 		{
131 | 			name:           "wildcard still rejects empty emails",
132 | 			allowedDomains: []string{"*"},
133 | 			session: &sessions.SessionState{
134 | 				Email: "",
135 | 			},
136 | 			expectedErr: ErrInvalidEmailAddress,
137 | 		},
138 | 	}
139 |
140 | 	for _, tc := range testCases {
141 | 		t.Run(tc.name, func(t *testing.T) {
142 | 			emailValidator := NewEmailDomainValidator(tc.allowedDomains)
143 | 			err := emailValidator.Validate(tc.session)
144 | 			if err != tc.expectedErr {
145 | 				t.Fatalf("expected %v, got %v", tc.expectedErr, err)
146 | 			}
147 | 		})
148 | 	}
149 | }
150 |
--------------------------------------------------------------------------------
/internal/pkg/validators/email_group_validator.go:
--------------------------------------------------------------------------------
1 | package validators
2 |
3 | import (
4 | 	"errors"
5 |
6 | 	"github.com/buzzfeed/sso/internal/pkg/sessions"
7 | 	"github.com/buzzfeed/sso/internal/proxy/providers"
8 | )
9 |
10 | var (
11 | 	_ Validator = EmailGroupValidator{}
12 |
13 | 	// These error messages should be formatted in such a way that is appropriate
14 | 	// for display to the end user.
15 | 	ErrGroupMembership = errors.New("Invalid Group Membership")
16 | )
17 |
18 | type EmailGroupValidator struct {
19 | 	Provider      providers.Provider
20 | 	AllowedGroups []string
21 | }
22 |
23 | // NewEmailGroupValidator takes in a Provider object and a list of groups, and returns a Validator object.
24 | // The validator can be used to validate that the session.Email:
25 | // - according to the Provider that was passed in, is a member of one of the originally passed in groups.
26 | // If valid, nil is returned in place of an error.
27 | func NewEmailGroupValidator(provider providers.Provider, allowedGroups []string) EmailGroupValidator { 28 | return EmailGroupValidator{ 29 | Provider: provider, 30 | AllowedGroups: allowedGroups, 31 | } 32 | } 33 | 34 | func (v EmailGroupValidator) Validate(session *sessions.SessionState) error { 35 | err := v.validate(session) 36 | if err != nil { 37 | return err 38 | } 39 | return nil 40 | } 41 | 42 | func (v EmailGroupValidator) validate(session *sessions.SessionState) error { 43 | matchedGroups, valid, err := v.Provider.ValidateGroup(session.Email, v.AllowedGroups, session.AccessToken) 44 | if err != nil { 45 | return ErrValidationError 46 | } 47 | 48 | if valid { 49 | session.Groups = matchedGroups 50 | return nil 51 | } 52 | 53 | return ErrGroupMembership 54 | } 55 | -------------------------------------------------------------------------------- /internal/pkg/validators/mock_validator.go: -------------------------------------------------------------------------------- 1 | package validators 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/buzzfeed/sso/internal/pkg/sessions" 7 | ) 8 | 9 | var ( 10 | _ Validator = EmailAddressValidator{} 11 | ) 12 | 13 | type MockValidator struct { 14 | Result bool 15 | } 16 | 17 | func NewMockValidator(result bool) MockValidator { 18 | return MockValidator{ 19 | Result: result, 20 | } 21 | } 22 | 23 | func (v MockValidator) Validate(session *sessions.SessionState) error { 24 | if v.Result { 25 | return nil 26 | } 27 | 28 | return errors.New("MockValidator error") 29 | } 30 | -------------------------------------------------------------------------------- /internal/pkg/validators/validators.go: -------------------------------------------------------------------------------- 1 | package validators 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/buzzfeed/sso/internal/pkg/sessions" 7 | ) 8 | 9 | var ( 10 | // These error message should be formatted in such a way that is appropriate 11 | // for display to the end user. 
12 | ErrInvalidEmailAddress = errors.New("Invalid Email Address In Session State") 13 | ErrValidationError = errors.New("Error during validation") 14 | ) 15 | 16 | type Validator interface { 17 | Validate(*sessions.SessionState) error 18 | } 19 | 20 | // RunValidators runs each passed in validator and returns a slice of errors. If an 21 | // empty slice is returned, it can be assumed all passed in validators were successful. 22 | func RunValidators(validators []Validator, session *sessions.SessionState) []error { 23 | validatorErrors := make([]error, 0, len(validators)) 24 | 25 | for _, validator := range validators { 26 | err := validator.Validate(session) 27 | if err != nil { 28 | validatorErrors = append(validatorErrors, err) 29 | } 30 | } 31 | return validatorErrors 32 | } 33 | -------------------------------------------------------------------------------- /internal/proxy/collector/collector.go: -------------------------------------------------------------------------------- 1 | package collector 2 | 3 | import ( 4 | "runtime" 5 | "time" 6 | 7 | "github.com/datadog/datadog-go/statsd" 8 | ) 9 | 10 | // Collector ticks periodically and emits runtime stats to datadog 11 | type Collector struct { 12 | // interval represents the interval between ticks for stats collection 13 | interval time.Duration 14 | 15 | // done, when closed, is used to signal the closure of the runtime polling goroutine 16 | done chan struct{} 17 | 18 | // statsd client used to send metrics 19 | client *statsd.Client 20 | } 21 | 22 | // New creates a new collector that will periodically emit runtime statistics to datadog. 
23 | func New(client *statsd.Client, interval time.Duration) *Collector { 24 | return &Collector{ 25 | interval: interval, 26 | client: client, 27 | done: make(chan struct{}), 28 | } 29 | } 30 | 31 | // Run gathers statistics from package runtime and emits them to statsd via client 32 | func (c *Collector) Run() { 33 | tick := time.NewTicker(c.interval) 34 | defer tick.Stop() 35 | for { 36 | select { 37 | case <-c.done: 38 | return 39 | case <-tick.C: 40 | c.emitStats() 41 | } 42 | } 43 | } 44 | 45 | // Close signals the collector to close the polling goroutine, use for graceful shutdowns 46 | func (c *Collector) Close() { 47 | close(c.done) 48 | } 49 | 50 | func (c *Collector) emitStats() { 51 | c.emitCPUStats() 52 | c.emitMemStats() 53 | } 54 | 55 | func (c *Collector) emitCPUStats() { 56 | c.gauge("cpu.goroutines", uint64(runtime.NumGoroutine())) 57 | c.gauge("cpu.cgo_calls", uint64(runtime.NumCgoCall())) 58 | } 59 | 60 | func (c *Collector) emitMemStats() { 61 | m := &runtime.MemStats{} 62 | runtime.ReadMemStats(m) 63 | 64 | // General 65 | c.gauge("mem.alloc", m.Alloc) 66 | c.gauge("mem.total", m.TotalAlloc) 67 | c.gauge("mem.sys", m.Sys) 68 | c.gauge("mem.lookups", m.Lookups) 69 | c.gauge("mem.malloc", m.Mallocs) 70 | c.gauge("mem.frees", m.Frees) 71 | 72 | // Heap 73 | c.gauge("mem.heap.alloc", m.HeapAlloc) 74 | c.gauge("mem.heap.sys", m.HeapSys) 75 | c.gauge("mem.heap.idle", m.HeapIdle) 76 | c.gauge("mem.heap.inuse", m.HeapInuse) 77 | c.gauge("mem.heap.released", m.HeapReleased) 78 | c.gauge("mem.heap.objects", m.HeapObjects) 79 | 80 | // Stack 81 | c.gauge("mem.stack.inuse", m.StackInuse) 82 | c.gauge("mem.stack.sys", m.StackSys) 83 | c.gauge("mem.stack.mspan_inuse", m.MSpanInuse) 84 | c.gauge("mem.stack.mspan_sys", m.MSpanSys) 85 | c.gauge("mem.stack.mcache_inuse", m.MCacheInuse) 86 | c.gauge("mem.stack.mcache_sys", m.MCacheSys) 87 | 88 | // Garbage Collection 89 | c.gauge("mem.gc.sys", m.GCSys) 90 | c.gauge("mem.gc.next", m.NextGC) 91 | 
c.gauge("mem.gc.last", m.LastGC) 92 | c.gauge("mem.gc.pause_total", m.PauseTotalNs) 93 | c.gauge("mem.gc.pause", m.PauseNs[(m.NumGC+255)%256]) 94 | c.gauge("mem.gc.count", uint64(m.NumGC)) 95 | 96 | // Other 97 | c.gauge("mem.othersys", m.OtherSys) 98 | } 99 | 100 | func (c *Collector) gauge(key string, val uint64) { 101 | c.client.Gauge(key, float64(val), nil, 1.0) 102 | } 103 | -------------------------------------------------------------------------------- /internal/proxy/logging_handler.go: -------------------------------------------------------------------------------- 1 | // largely adapted from https://github.com/gorilla/handlers/blob/master/handlers.go 2 | // to add logging of request duration as last value (and drop referrer) 3 | 4 | package proxy 5 | 6 | import ( 7 | "bufio" 8 | "errors" 9 | "io" 10 | "net" 11 | "net/http" 12 | "net/url" 13 | "strings" 14 | "time" 15 | 16 | log "github.com/buzzfeed/sso/internal/pkg/logging" 17 | "github.com/datadog/datadog-go/statsd" 18 | ) 19 | 20 | // Used to stash the authenticated user in the response for access when logging requests. 
21 | const loggingUserHeader = "SSO-Authenticated-User" 22 | 23 | // responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status 24 | // code and body size 25 | type responseLogger struct { 26 | w http.ResponseWriter 27 | status int 28 | size int 29 | authInfo string 30 | } 31 | 32 | func (l *responseLogger) Header() http.Header { 33 | return l.w.Header() 34 | } 35 | 36 | func (l *responseLogger) extractUser() { 37 | authInfo := l.w.Header().Get(loggingUserHeader) 38 | if authInfo != "" { 39 | l.authInfo = authInfo 40 | l.w.Header().Del(loggingUserHeader) 41 | } 42 | } 43 | 44 | // Support Websockets 45 | func (l *responseLogger) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { 46 | if hij, ok := l.w.(http.Hijacker); ok { 47 | return hij.Hijack() 48 | } 49 | return nil, nil, errors.New("http.Hijacker is not available on writer") 50 | } 51 | 52 | func (l *responseLogger) Write(b []byte) (int, error) { 53 | if l.status == 0 { 54 | // The status will be StatusOK if WriteHeader has not been called yet 55 | l.status = http.StatusOK 56 | } 57 | l.extractUser() 58 | size, err := l.w.Write(b) 59 | l.size += size 60 | return size, err 61 | } 62 | 63 | func (l *responseLogger) WriteHeader(s int) { 64 | l.extractUser() 65 | l.w.WriteHeader(s) 66 | l.status = s 67 | } 68 | 69 | func (l *responseLogger) Status() int { 70 | return l.status 71 | } 72 | 73 | func (l *responseLogger) Size() int { 74 | return l.size 75 | } 76 | 77 | func (l *responseLogger) Flush() { 78 | f := l.w.(http.Flusher) 79 | f.Flush() 80 | } 81 | 82 | // loggingHandler is the http.Handler implementation for LoggingHandlerTo and its friends 83 | type loggingHandler struct { 84 | writer io.Writer 85 | handler http.Handler 86 | StatsdClient *statsd.Client 87 | enabled bool 88 | } 89 | 90 | // NewLoggingHandler returns a new loggingHandler that wraps a handler, statsd client, and writer. 
91 | func NewLoggingHandler(out io.Writer, h http.Handler, lc LoggingConfig, StatsdClient *statsd.Client) http.Handler { 92 | return loggingHandler{writer: out, 93 | handler: h, 94 | enabled: lc.Enable, 95 | StatsdClient: StatsdClient, 96 | } 97 | } 98 | 99 | func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { 100 | now := time.Now() 101 | url := *req.URL 102 | logger := &responseLogger{w: w} 103 | h.handler.ServeHTTP(logger, req) 104 | if !h.enabled { 105 | return 106 | } 107 | logRequest(logger.authInfo, req, url, now, logger.Status(), h.StatsdClient) 108 | } 109 | 110 | // logRequest logs information about a request 111 | func logRequest(username string, req *http.Request, url url.URL, ts time.Time, status int, StatsdClient *statsd.Client) { 112 | duration := time.Now().Sub(ts) 113 | 114 | // Convert duration to floating point milliseconds 115 | // https://github.com/golang/go/issues/5491#issuecomment-66079585 116 | durationMS := duration.Seconds() * 1e3 117 | 118 | uri := req.Host + url.RequestURI() 119 | 120 | logger := log.NewLogEntry() 121 | logger.WithHTTPStatus(status).WithRequestMethod(req.Method).WithRequestURI( 122 | uri).WithUserAgent(req.Header.Get("User-Agent")).WithRemoteAddress( 123 | getRemoteAddr(req)).WithRequestDurationMs(durationMS).WithUser( 124 | username).WithAction(GetActionTag(req)).Info() 125 | logRequestMetrics(req, duration, status, StatsdClient) 126 | } 127 | 128 | // getRemoteAddr returns the client IP address from a request. If present, the 129 | // X-Forwarded-For header is assumed to be set by a load balancer, and its 130 | // rightmost entry (the client IP that connected to the LB) is returned. 
131 | func getRemoteAddr(req *http.Request) string { 132 | addr := req.RemoteAddr 133 | forwardedHeader := req.Header.Get("X-Forwarded-For") 134 | if forwardedHeader != "" { 135 | forwardedList := strings.Split(forwardedHeader, ",") 136 | forwardedAddr := strings.TrimSpace(forwardedList[len(forwardedList)-1]) 137 | if forwardedAddr != "" { 138 | addr = forwardedAddr 139 | } 140 | } 141 | return addr 142 | } 143 | -------------------------------------------------------------------------------- /internal/proxy/logging_handler_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "net/http/httptest" 5 | "testing" 6 | ) 7 | 8 | func TestGetRemoteAddr(t *testing.T) { 9 | testCases := []struct { 10 | name string 11 | remoteAddr string 12 | forwardedHeader string 13 | expectedAddr string 14 | }{ 15 | { 16 | name: "RemoteAddr used when no X-Forwarded-For header is given", 17 | remoteAddr: "1.1.1.1", 18 | expectedAddr: "1.1.1.1", 19 | }, 20 | { 21 | name: "RemoteAddr used when no X-Forwarded-For header is only whitespace", 22 | remoteAddr: "1.1.1.1", 23 | forwardedHeader: " ", 24 | expectedAddr: "1.1.1.1", 25 | }, 26 | { 27 | name: "RemoteAddr used when no X-Forwarded-For header is only comma-separated whitespace", 28 | remoteAddr: "1.1.1.1", 29 | forwardedHeader: " , , ", 30 | expectedAddr: "1.1.1.1", 31 | }, 32 | { 33 | name: "X-Forwarded-For header is preferred to RemoteAddr", 34 | remoteAddr: "1.1.1.1", 35 | forwardedHeader: "9.9.9.9", 36 | expectedAddr: "9.9.9.9", 37 | }, 38 | { 39 | name: "rightmost entry in X-Forwarded-For header is used", 40 | remoteAddr: "1.1.1.1", 41 | forwardedHeader: "2.2.2.2, 3.3.3.3, 4.4.4.4.4, 5.5.5.5", 42 | expectedAddr: "5.5.5.5", 43 | }, 44 | { 45 | name: "RemoteAddr is used if rightmost entry in X-Forwarded-For header is empty", 46 | remoteAddr: "1.1.1.1", 47 | forwardedHeader: "2.2.2.2, 3.3.3.3, ", 48 | expectedAddr: "1.1.1.1", 49 | }, 50 | { 51 | name: 
"X-Forwaded-For header entries are stripped", 52 | remoteAddr: "1.1.1.1", 53 | forwardedHeader: " 2.2.2.2, 3.3.3.3, 4.4.4.4, 5.5.5.5 ", 54 | expectedAddr: "5.5.5.5", 55 | }, 56 | } 57 | 58 | for _, tc := range testCases { 59 | t.Run(tc.name, func(t *testing.T) { 60 | req := httptest.NewRequest("GET", "/", nil) 61 | req.RemoteAddr = tc.remoteAddr 62 | if tc.forwardedHeader != "" { 63 | req.Header.Set("X-Forwarded-For", tc.forwardedHeader) 64 | } 65 | 66 | addr := getRemoteAddr(req) 67 | if addr != tc.expectedAddr { 68 | t.Errorf("expected remote addr = %q, got %q", tc.expectedAddr, addr) 69 | } 70 | }) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /internal/proxy/metrics.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "net/http" 7 | "strconv" 8 | "time" 9 | 10 | "github.com/datadog/datadog-go/statsd" 11 | ) 12 | 13 | // newStatsdClient creates and returns a statsd client on a host and port that is namespaced to 'sso_proxy' 14 | func NewStatsdClient(host string, port int) (*statsd.Client, error) { 15 | client, err := statsd.New(net.JoinHostPort(host, strconv.Itoa(port))) 16 | if err != nil { 17 | return nil, err 18 | } 19 | client.Namespace = "sso_proxy." 20 | client.Tags = []string{ 21 | "service:sso_proxy", 22 | } 23 | return client, nil 24 | } 25 | 26 | // GetActionTag returns the action triggered by an http.Request . 
27 | func GetActionTag(req *http.Request) string { 28 | // only log metrics for these paths and actions 29 | pathToAction := map[string]string{ 30 | "/favicon.ico": "favicon", 31 | "/oauth2/sign_out": "sign_out", 32 | "/oauth2/callback": "callback", 33 | "/oauth2/auth": "auth", 34 | "/ping": "ping", 35 | "/robots.txt": "robots", 36 | } 37 | // get the action from the url path 38 | path := req.URL.Path 39 | if action, ok := pathToAction[path]; ok { 40 | return action 41 | } 42 | return "proxy" 43 | } 44 | 45 | // logMetrics logs all metrics surrounding a given request to the metricsWriter 46 | func logRequestMetrics(req *http.Request, requestDuration time.Duration, status int, StatsdClient *statsd.Client) { 47 | // Normalize proxyHost for a) invalid requests or b) LB health checks to 48 | // avoid polluting the proxy_host tag's value space 49 | proxyHost := req.Host 50 | if status == statusInvalidHost { 51 | proxyHost = "_unknown" 52 | } 53 | if req.URL.Path == "/ping" { 54 | proxyHost = "_healthcheck" 55 | } 56 | 57 | tags := []string{ 58 | fmt.Sprintf("method:%s", req.Method), 59 | fmt.Sprintf("status_code:%d", status), 60 | fmt.Sprintf("status_category:%dxx", status/100), 61 | fmt.Sprintf("action:%s", GetActionTag(req)), 62 | fmt.Sprintf("proxy_host:%s", proxyHost), 63 | } 64 | 65 | // TODO: eventually make rates configurable 66 | StatsdClient.Timing("request.duration", requestDuration, tags, 1.0) 67 | 68 | } 69 | -------------------------------------------------------------------------------- /internal/proxy/middleware.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | ) 7 | 8 | // With inspiration from https://github.com/unrolled/secure 9 | // 10 | // TODO: Add Content-Security-Report header? 
11 | var securityHeaders = map[string]string{ 12 | "X-Content-Type-Options": "nosniff", 13 | "X-Frame-Options": "SAMEORIGIN", 14 | "X-XSS-Protection": "1; mode=block", 15 | } 16 | 17 | // setHeaders ensures that every response includes some basic security headers. 18 | // 19 | // Note: the Strict-Transport-Security header is set by the requireHTTPS 20 | // middleware below, to avoid issues with development environments that must 21 | // allow plain HTTP. 22 | func setHeaders(h http.Handler, headers map[string]string) http.Handler { 23 | return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 24 | for key, val := range headers { 25 | rw.Header().Set(key, val) 26 | } 27 | h.ServeHTTP(rw, req) 28 | }) 29 | } 30 | 31 | func setSecurityHeaders(h http.Handler) http.Handler { 32 | return setHeaders(h, securityHeaders) 33 | } 34 | 35 | func (p *OAuthProxy) setResponseHeaderOverrides(upstreamConfig *UpstreamConfig, h http.Handler) http.Handler { 36 | return setHeaders(h, upstreamConfig.HeaderOverrides) 37 | } 38 | 39 | func requireHTTPS(h http.Handler) http.Handler { 40 | return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 41 | rw.Header().Set("Strict-Transport-Security", "max-age=31536000") 42 | if req.URL.Scheme != "https" && req.Header.Get("X-Forwarded-Proto") != "https" { 43 | dest := &url.URL{ 44 | Scheme: "https", 45 | Host: req.Host, 46 | Path: req.URL.Path, 47 | RawQuery: req.URL.RawQuery, 48 | } 49 | http.Redirect(rw, req, dest.String(), http.StatusMovedPermanently) 50 | return 51 | } 52 | h.ServeHTTP(rw, req) 53 | }) 54 | } 55 | 56 | func setHealthCheck(healthcheckPath string, next http.Handler) http.Handler { 57 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 58 | if r.URL.Path == healthcheckPath { 59 | w.WriteHeader(http.StatusOK) 60 | return 61 | } 62 | next.ServeHTTP(w, r) 63 | }) 64 | } 65 | -------------------------------------------------------------------------------- 
/internal/proxy/providers/http_client.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "net" 5 | "net/http" 6 | "time" 7 | ) 8 | 9 | var httpClient = &http.Client{ 10 | Timeout: time.Second * 5, 11 | Transport: &http.Transport{ 12 | Proxy: http.ProxyFromEnvironment, 13 | Dial: (&net.Dialer{ 14 | Timeout: 2 * time.Second, 15 | }).Dial, 16 | TLSHandshakeTimeout: 2 * time.Second, 17 | }, 18 | } 19 | -------------------------------------------------------------------------------- /internal/proxy/providers/provider_data.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "net/url" 5 | "time" 6 | ) 7 | 8 | // ProviderData holds the fields associated with providers 9 | // necessary to implement the Provider interface. 10 | type ProviderData struct { 11 | ProviderName string 12 | ProviderSlug string 13 | ProviderURL *url.URL 14 | ProviderURLInternal *url.URL 15 | ClientID string 16 | ClientSecret string 17 | SignInURL *url.URL 18 | SignOutURL *url.URL 19 | RedeemURL *url.URL 20 | RefreshURL *url.URL 21 | ProfileURL *url.URL 22 | ValidateURL *url.URL 23 | Scope string 24 | 25 | SessionValidTTL time.Duration 26 | SessionLifetimeTTL time.Duration 27 | GracePeriodTTL time.Duration 28 | } 29 | 30 | // Data returns the ProviderData struct 31 | func (p *ProviderData) Data() *ProviderData { return p } 32 | -------------------------------------------------------------------------------- /internal/proxy/providers/providers.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "net/url" 5 | 6 | "github.com/buzzfeed/sso/internal/pkg/sessions" 7 | "github.com/datadog/datadog-go/statsd" 8 | ) 9 | 10 | // Provider is an interface exposing functions necessary to authenticate with a given provider. 
11 | type Provider interface { 12 | Data() *ProviderData 13 | Redeem(string, string) (*sessions.SessionState, error) 14 | ValidateGroup(string, []string, string) ([]string, bool, error) 15 | UserGroups(string, []string, string) ([]string, error) 16 | ValidateSessionState(*sessions.SessionState, []string) bool 17 | GetSignInURL(redirectURL *url.URL, finalRedirect string) *url.URL 18 | GetSignOutURL(redirectURL *url.URL) *url.URL 19 | RefreshSession(*sessions.SessionState, []string) (bool, error) 20 | } 21 | 22 | // New returns a new sso Provider 23 | func New(provider string, p *ProviderData, sc *statsd.Client) Provider { 24 | return NewSSOProvider(p, sc) 25 | } 26 | -------------------------------------------------------------------------------- /internal/proxy/providers/test_provider.go: -------------------------------------------------------------------------------- 1 | package providers 2 | 3 | import ( 4 | "net/url" 5 | 6 | "github.com/buzzfeed/sso/internal/pkg/sessions" 7 | ) 8 | 9 | // TestProvider is a mock provider 10 | type TestProvider struct { 11 | RefreshSessionFunc func(*sessions.SessionState, []string) (bool, error) 12 | ValidateSessionFunc func(*sessions.SessionState, []string) bool 13 | RedeemFunc func(string, string) (*sessions.SessionState, error) 14 | UserGroupsFunc func(string, []string, string) ([]string, error) 15 | ValidateGroupsFunc func(string, []string, string) ([]string, bool, error) 16 | *ProviderData 17 | } 18 | 19 | // NewTestProvider returns a new TestProvider 20 | func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { 21 | return &TestProvider{ 22 | ProviderData: &ProviderData{ 23 | ProviderName: "Test Provider", 24 | SignInURL: &url.URL{ 25 | Scheme: "http", 26 | Host: providerURL.Host, 27 | Path: "/oauth/authorize", 28 | }, 29 | RedeemURL: &url.URL{ 30 | Scheme: "http", 31 | Host: providerURL.Host, 32 | Path: "/oauth/token", 33 | }, 34 | ProfileURL: &url.URL{ 35 | Scheme: "http", 36 | Host: 
providerURL.Host, 37 | Path: "/api/v1/profile", 38 | }, 39 | SignOutURL: &url.URL{ 40 | Scheme: "http", 41 | Host: providerURL.Host, 42 | Path: "/oauth/sign_out", 43 | }, 44 | Scope: "profile.email", 45 | }, 46 | } 47 | } 48 | 49 | // ValidateSessionState mocks the ValidateSessionState function 50 | func (tp *TestProvider) ValidateSessionState(s *sessions.SessionState, groups []string) bool { 51 | return tp.ValidateSessionFunc(s, groups) 52 | } 53 | 54 | // Redeem mocks the provider Redeem function 55 | func (tp *TestProvider) Redeem(redirectURL string, token string) (*sessions.SessionState, error) { 56 | return tp.RedeemFunc(redirectURL, token) 57 | } 58 | 59 | // RefreshSession mocks the RefreshSession function 60 | func (tp *TestProvider) RefreshSession(s *sessions.SessionState, g []string) (bool, error) { 61 | return tp.RefreshSessionFunc(s, g) 62 | } 63 | 64 | // UserGroups mocks the UserGroups function 65 | func (tp *TestProvider) UserGroups(email string, groups []string, accessToken string) ([]string, error) { 66 | return tp.UserGroupsFunc(email, groups, accessToken) 67 | } 68 | 69 | // ValidateGroup mocks the ValidateGroup function 70 | func (tp *TestProvider) ValidateGroup(email string, groups []string, accessToken string) ([]string, bool, error) { 71 | return tp.ValidateGroupsFunc(email, groups, accessToken) 72 | } 73 | 74 | // GetSignOutURL mocks GetSignOutURL function 75 | func (tp *TestProvider) GetSignOutURL(redirectURL *url.URL) *url.URL { 76 | return tp.Data().SignOutURL 77 | } 78 | 79 | // GetSignInURL mocks GetSignInURL 80 | func (tp *TestProvider) GetSignInURL(redirectURL *url.URL, state string) *url.URL { 81 | a := *tp.Data().SignInURL 82 | params, _ := url.ParseQuery(a.RawQuery) 83 | params.Add("state", state) 84 | a.RawQuery = params.Encode() 85 | return &a 86 | } 87 | -------------------------------------------------------------------------------- /internal/proxy/proxy.go: 
-------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | 7 | "github.com/buzzfeed/sso/internal/pkg/hostmux" 8 | "github.com/buzzfeed/sso/internal/pkg/validators" 9 | "github.com/datadog/datadog-go/statsd" 10 | ) 11 | 12 | type SSOProxy struct { 13 | http.Handler 14 | } 15 | 16 | func New(config Configuration, statsdClient *statsd.Client) (*SSOProxy, error) { 17 | optFuncs := []func(*OAuthProxy) error{} 18 | 19 | var requestSigner *RequestSigner 20 | var err error 21 | 22 | if config.RequestSignerConfig.Key != "" { 23 | requestSigner, err = NewRequestSigner(config.RequestSignerConfig.Key) 24 | if err != nil { 25 | return nil, err 26 | } 27 | optFuncs = append(optFuncs, SetRequestSigner(requestSigner)) 28 | } 29 | 30 | hostRouter := hostmux.NewRouter() 31 | for _, upstreamConfig := range config.UpstreamConfigs.upstreamConfigs { 32 | provider, err := newProvider( 33 | config.ClientConfig, 34 | config.ProviderConfig, 35 | config.SessionConfig, 36 | config.UpstreamConfigs, 37 | statsdClient, 38 | ) 39 | if err != nil { 40 | return nil, err 41 | } 42 | 43 | handler, err := NewUpstreamReverseProxy(upstreamConfig, requestSigner) 44 | if err != nil { 45 | return nil, err 46 | } 47 | 48 | v := []validators.Validator{} 49 | if len(upstreamConfig.AllowedEmailAddresses) != 0 { 50 | v = append(v, validators.NewEmailAddressValidator(upstreamConfig.AllowedEmailAddresses)) 51 | } 52 | 53 | if len(upstreamConfig.AllowedEmailDomains) != 0 { 54 | v = append(v, validators.NewEmailDomainValidator(upstreamConfig.AllowedEmailDomains)) 55 | } 56 | 57 | if len(upstreamConfig.AllowedGroups) != 0 { 58 | v = append(v, validators.NewEmailGroupValidator(provider, upstreamConfig.AllowedGroups)) 59 | } 60 | 61 | optFuncs = append(optFuncs, 62 | SetProvider(provider), 63 | SetCookieStore(config.SessionConfig.CookieConfig), 64 | SetUpstreamConfig(upstreamConfig), 65 | SetProxyHandler(handler), 66 | 
SetStatsdClient(statsdClient), 67 | SetValidators(v), 68 | ) 69 | 70 | oauthproxy, err := NewOAuthProxy(config.SessionConfig, optFuncs...) 71 | if err != nil { 72 | return nil, err 73 | } 74 | 75 | switch route := upstreamConfig.Route.(type) { 76 | case *SimpleRoute: 77 | hostRouter.HandleStatic(route.FromURL.Host, oauthproxy.Handler()) 78 | case *RewriteRoute: 79 | hostRouter.HandleRegexp(route.FromRegex, oauthproxy.Handler()) 80 | default: 81 | return nil, fmt.Errorf("unknown route type") 82 | } 83 | } 84 | 85 | healthcheckHandler := setHealthCheck("/ping", hostRouter) 86 | 87 | return &SSOProxy{ 88 | healthcheckHandler, 89 | }, nil 90 | } 91 | -------------------------------------------------------------------------------- /internal/proxy/proxy_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | ) 10 | 11 | var ( 12 | letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") 13 | ) 14 | 15 | func randSeq(n int) string { 16 | b := make([]rune, n) 17 | for i := range b { 18 | b[i] = letters[rand.Intn(len(letters))] 19 | } 20 | return string(b) 21 | } 22 | 23 | func TestPingHandler(t *testing.T) { 24 | config := DefaultProxyConfig() 25 | statsdClient, err := NewStatsdClient( 26 | config.MetricsConfig.StatsdConfig.Host, 27 | config.MetricsConfig.StatsdConfig.Port) 28 | if err != nil { 29 | t.Fatalf("unexpected error creating statsd client: %v", err) 30 | } 31 | 32 | sso, err := New(config, statsdClient) 33 | if err != nil { 34 | t.Fatalf("unexpected err starting sso: %v", err) 35 | } 36 | 37 | for i := 0; i < 100; i++ { 38 | uri := fmt.Sprintf("https://host-%s.local/ping", randSeq(i)) 39 | 40 | rw := httptest.NewRecorder() 41 | req := httptest.NewRequest("GET", uri, nil) 42 | 43 | sso.ServeHTTP(rw, req) 44 | if rw.Code != http.StatusOK { 45 | t.Errorf("want: %v", http.StatusOK) 46 | t.Errorf("have: 
%v", rw.Code) 47 | t.Fatalf("unexpected response code for ping") 48 | } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /internal/proxy/templates.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "html/template" 5 | ) 6 | 7 | func getTemplates() *template.Template { 8 | t := template.New("foo") 9 | t = template.Must(t.Parse(`{{define "error.html"}} 10 | 11 | 12 | 13 | Error 14 | 15 | 103 | 104 | 105 | 106 |
107 |
108 |
109 |

{{.Title}}

110 |
111 |

112 | {{.Message}}
113 | HTTP {{.Code}} 114 |

115 | {{if ne .Code 403 }} 116 |
117 | 118 |
119 | {{end}} 120 |
121 |
Secured by SSO
122 |
123 | 124 | {{end}}`)) 125 | return t 126 | } 127 | -------------------------------------------------------------------------------- /internal/proxy/templates_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/buzzfeed/sso/internal/pkg/testutil" 7 | ) 8 | 9 | func TestTemplatesCompile(t *testing.T) { 10 | templates := getTemplates() 11 | testutil.NotEqual(t, templates, nil) 12 | } 13 | -------------------------------------------------------------------------------- /internal/proxy/testdata/private_key.pem: -------------------------------------------------------------------------------- 1 | ### THIS KEY IS USED FOR TESTS, IT IS INTENTIONALLY MADE PUBLIC 2 | 3 | -----BEGIN PRIVATE KEY----- 4 | MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCy38IQCH8QyeNF 5 | s1zA0XuIyqnTcSfYZg0nPfB+K//pFy7tIOAwmR6th8NykrxFhEQDHKNCmLXt4j8V 6 | FDHQZtGjUBHRmAXZW8NOQ0EI1vc/Dpt09sU40JQlXZZeL+9/7iAxEfSE3TQr1k7P 7 | Xwxpjm9rsLSn7FoLnvXco0mc6+d2jjxf4cMgJIaQLKOd783KUQzLVEvBQJ05JnpI 8 | 2xMjS0q33ltMTMGF3QZQN9i4bZKgnItomKxTJbfxftO11FTNLB7og94sWmlThAY5 9 | /UMjZaWYJ1g89+WUJ+KpVYyJsHPBBkaQG+NYazcLDyIowpzJ1WVkInysshpTqwT+ 10 | UPV4at+jAgMBAAECggEAX8lxK5LRMJVcLlwRZHQJekRE0yS6WKi1jHkfywEW5qRy 11 | jatYQs4MXpLgN/+Z8IQWw6/XQXdznTLV4xzQXDBjPNhI4ntNTotUOBnNvsUW296f 12 | ou/uxzDy1FuchU2YLGLBPGXIEko+gOcfhu74P6J1yi5zX6UyxxxVvtR2PCEb7yDw 13 | m2881chwMblZ5Z8uyF++ajkK3/rqLk64w29+K4ZTDbTcCp5NtBYx2qSEU7yp12rc 14 | qscUGqxG00Abx+osI3cUn0kOq7356LeR1rfA15yZwOb+s28QYp2WPlVB2hOiYXQv 15 | +ttEOpt0x1QJhBAsFgwY173sD5w2MryRQb1RCwBvqQKBgQDeTdbRzxzAl83h/mAq 16 | 5I+pNEz57veAFVO+iby7TbZ/0w6q+QeT+bHF+TjGHiSlbtg3nd9NPrex2UjiN7ej 17 | +DrxhsSLsP1ZfwDNv6f1Ii1HluJclUFSUNU/LntBjqqCJ959lniNp1y5+ZQ/j2Rf 18 | +ZraVsHRB0itilFeAl5+n7CfxwKBgQDN/K+E1TCbp1inU60Lc9zeb8fqTEP6Mp36 19 | qQ0Dp+KMLPJ0xQSXFq9ILr4hTJlBqfmTkfmQUcQuwercZ3LNQPbsuIg96bPW73R1 20 | toXjokd6jUn5sJXCOE0RDumcJrL1VRf9RN1AmM4CgCc/adUMjws3pBc5R4An7UyU 21 | 
ouRQhN+5RQKBgFOVTrzqM3RSX22mWAAomb9T09FxQQueeTM91IFUMdcTwwMTyP6h 22 | Nm8qSmdrM/ojmBYpPKlteGHdQaMUse5rybXAJywiqs84ilPRyNPJOt8c4xVOZRYP 23 | IG62Ck/W1VNErEnqBn+0OpAOP+g6ANJ5JfkL/6mZJIFjbT58g4z2e9FHAoGBAM3f 24 | uBkd7lgTuLJ8Gh6xLVYQCJHuqZ49ytFE9qHpwK5zGdyFMSJE5OlS9mpXoXEUjkHk 25 | iraoUlidLbwdlIr6XBCaGmku07SFXTNtOoIZpjEhV4c762HTXYsoCWos733uD2zt 26 | z+iJEJVFOnTRtMK5kO+KjD+Oa9L8BCcmauTi+Ku1AoGAZBUzi95THA60hPXI0hm/ 27 | o0J5mfLkFPfhpUmDAMaEpv3bM4byA+IGXSZVc1IZO6cGoaeUHD2Yl1m9a5tv5rF+ 28 | FS9Ht+IgATvGojah+xxQy+kf6tRB9Hn4scyq+64AesXlDbWDEagomQ0hyV/JKSS6 29 | LQatvnCmBd9omRT2uwYUo+o= 30 | -----END PRIVATE KEY----- 31 | -------------------------------------------------------------------------------- /internal/proxy/testdata/public_key.pub: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PUBLIC KEY----- 2 | MIIBCgKCAQEAst/CEAh/EMnjRbNcwNF7iMqp03En2GYNJz3wfiv/6Rcu7SDgMJke 3 | rYfDcpK8RYREAxyjQpi17eI/FRQx0GbRo1AR0ZgF2VvDTkNBCNb3Pw6bdPbFONCU 4 | JV2WXi/vf+4gMRH0hN00K9ZOz18MaY5va7C0p+xaC5713KNJnOvndo48X+HDICSG 5 | kCyjne/NylEMy1RLwUCdOSZ6SNsTI0tKt95bTEzBhd0GUDfYuG2SoJyLaJisUyW3 6 | 8X7TtdRUzSwe6IPeLFppU4QGOf1DI2WlmCdYPPfllCfiqVWMibBzwQZGkBvjWGs3 7 | Cw8iKMKcydVlZCJ8rLIaU6sE/lD1eGrfowIDAQAB 8 | -----END RSA PUBLIC KEY----- 9 | -------------------------------------------------------------------------------- /internal/proxy/testdata/upstream_configs.yml: -------------------------------------------------------------------------------- 1 | - service: foo 2 | default: 3 | from: foo.{{cluster}}.{{root_domain}} 4 | to: foo-internal.{{cluster}}.{{root_domain}} 5 | options: 6 | allowed_groups: 7 | - dev 8 | dev: 9 | from: foo.{{root_domain}} 10 | 11 | -------------------------------------------------------------------------------- /internal/proxy/version.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | // VERSION is the version of sso proxy 4 | const VERSION = "2.2.1-alpha" 
5 | -------------------------------------------------------------------------------- /quickstart/env.google.example: -------------------------------------------------------------------------------- 1 | PROVIDER_GOOGLEQUICKSTART_CLIENT_ID=.apps.googleusercontent.com 2 | PROVIDER_GOOGLEQUICKSTART_CLIENT_SECRET= 3 | PROVIDER_GOOGLEQUICKSTART_TYPE=google 4 | PROVIDER_GOOGLEQUICKSTART_SLUG=google 5 | -------------------------------------------------------------------------------- /quickstart/env.okta.example: -------------------------------------------------------------------------------- 1 | PROVIDER_OKTAQUICKSTART_CLIENT_ID= 2 | PROVIDER_OKTAQUICKSTART_CLIENT_SECRET= 3 | PROVIDER_OKTAQUICKSTART_OKTA_URL= 4 | PROVIDER_OKTAQUICKSTART_TYPE=okta 5 | PROVIDER_OKTAQUICKSTART_SLUG=okta 6 | DEFAULT_PROVIDER_SLUG=okta 7 | -------------------------------------------------------------------------------- /quickstart/kubernetes/demo-apps/hello-world-deployment.yml: -------------------------------------------------------------------------------- 1 | apiVersion: extensions/v1beta1 2 | kind: Deployment 3 | metadata: 4 | name: hello-world 5 | labels: 6 | k8s-app: hello-world 7 | spec: 8 | replicas: 1 9 | template: 10 | metadata: 11 | labels: 12 | k8s-app: hello-world 13 | spec: 14 | containers: 15 | - image: tutum/hello-world:latest 16 | name: hello-world 17 | ports: 18 | - containerPort: 80 19 | -------------------------------------------------------------------------------- /quickstart/kubernetes/demo-apps/hello-world-svc.yml: -------------------------------------------------------------------------------- 1 | kind: Service 2 | apiVersion: v1 3 | metadata: 4 | labels: 5 | k8s-app: hello-world 6 | name: hello-world 7 | spec: 8 | ports: 9 | - port: 80 10 | selector: 11 | k8s-app: hello-world 12 | -------------------------------------------------------------------------------- /quickstart/kubernetes/demo-apps/httpbin-deployment.yml: 
-------------------------------------------------------------------------------- 1 | apiVersion: extensions/v1beta1 2 | kind: Deployment 3 | metadata: 4 | name: httpbin 5 | labels: 6 | k8s-app: httpbin 7 | namespace: test 8 | spec: 9 | replicas: 1 10 | template: 11 | metadata: 12 | labels: 13 | k8s-app: httpbin 14 | spec: 15 | containers: 16 | - image: mccutchen/go-httpbin:latest 17 | name: httpbin 18 | ports: 19 | - containerPort: 8080 20 | -------------------------------------------------------------------------------- /quickstart/kubernetes/demo-apps/httpbin-svc.yml: -------------------------------------------------------------------------------- 1 | kind: Service 2 | apiVersion: v1 3 | metadata: 4 | labels: 5 | k8s-app: httpbin 6 | name: httpbin 7 | spec: 8 | ports: 9 | - port: 80 10 | targetPort: 8080 11 | selector: 12 | k8s-app: httpbin 13 | -------------------------------------------------------------------------------- /quickstart/kubernetes/sso-auth-deployment.yml: -------------------------------------------------------------------------------- 1 | apiVersion: extensions/v1beta1 2 | kind: Deployment 3 | metadata: 4 | name: sso-auth 5 | labels: 6 | k8s-app: sso-auth 7 | namespace: sso 8 | spec: 9 | replicas: 1 10 | template: 11 | metadata: 12 | labels: 13 | k8s-app: sso-auth 14 | spec: 15 | containers: 16 | - image: buzzfeed/sso-dev:latest 17 | name: sso-auth 18 | command: ["/bin/sso-auth"] 19 | ports: 20 | - containerPort: 4180 21 | env: 22 | - name: GOOGLE_ADMIN_EMAIL 23 | valueFrom: 24 | secretKeyRef: 25 | name: google-admin-email 26 | key: email 27 | - name: GOOGLE_SERVICE_ACCOUNT_JSON 28 | value: /creds/service_account.json 29 | - name: SSO_EMAIL_DOMAIN 30 | value: 'mydomain.com' 31 | - name: HOST 32 | value: sso-auth.mydomain.com 33 | - name: REDIRECT_URL 34 | value: https://sso-auth.mydomain.com 35 | - name: PROXY_ROOT_DOMAIN 36 | value: mydomain.com 37 | - name: CLIENT_ID 38 | valueFrom: 39 | secretKeyRef: 40 | name: google-client-id 41 | key: 
client-id 42 | - name: CLIENT_SECRET 43 | valueFrom: 44 | secretKeyRef: 45 | name: google-client-secret 46 | key: client-secret 47 | - name: PROXY_CLIENT_ID 48 | valueFrom: 49 | secretKeyRef: 50 | name: proxy-client-id 51 | key: proxy-client-id 52 | - name: PROXY_CLIENT_SECRET 53 | valueFrom: 54 | secretKeyRef: 55 | name: proxy-client-secret 56 | key: proxy-client-secret 57 | - name: COOKIE_SECRET 58 | valueFrom: 59 | secretKeyRef: 60 | name: auth-cookie-secret 61 | key: auth-cookie-secret 62 | # STATSD_HOST and STATSD_PORT must be defined or the app wont launch, they dont need to be a real host / port 63 | - name: STATSD_HOST 64 | value: localhost 65 | - name: STATSD_PORT 66 | value: "11111" 67 | - name: COOKIE_SECURE 68 | value: "false" 69 | - name: CLUSTER 70 | value: dev 71 | - name: VIRTUAL_HOST 72 | value: sso-auth.mydomain.com 73 | readinessProbe: 74 | httpGet: 75 | path: /ping 76 | port: 4180 77 | scheme: HTTP 78 | livenessProbe: 79 | httpGet: 80 | path: /ping 81 | port: 4180 82 | scheme: HTTP 83 | initialDelaySeconds: 10 84 | timeoutSeconds: 1 85 | resources: 86 | limits: 87 | memory: "256Mi" 88 | cpu: "200m" 89 | volumeMounts: 90 | - name: google-service-account 91 | mountPath: "/creds" 92 | volumes: 93 | - name: google-service-account 94 | secret: 95 | secretName: google-service-account 96 | -------------------------------------------------------------------------------- /quickstart/kubernetes/sso-auth-svc.yml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: sso-auth 5 | namespace: sso 6 | labels: 7 | k8s-app: sso-auth 8 | spec: 9 | ports: 10 | - port: 80 11 | targetPort: 4180 12 | name: http 13 | selector: 14 | k8s-app: sso-auth 15 | 16 | -------------------------------------------------------------------------------- /quickstart/kubernetes/sso-ingress.yml: -------------------------------------------------------------------------------- 1 | apiVersion: 
extensions/v1beta1 2 | kind: Ingress 3 | metadata: 4 | name: sso 5 | namespace: sso 6 | spec: 7 | tls: 8 | - secretName: star-sso-mydomain-tls-secret 9 | hosts: 10 | - "*.sso.mydomain.com" 11 | - secretName: sso-auth-mydomain-tls-secret 12 | hosts: 13 | - "sso-auth.mydomain.com" 14 | rules: 15 | - host: "*.sso.mydomain.com" 16 | http: 17 | paths: 18 | - path: / 19 | backend: 20 | serviceName: sso-proxy 21 | servicePort: 80 22 | - host: "sso-auth.mydomain.com" 23 | http: 24 | paths: 25 | - path: / 26 | backend: 27 | serviceName: sso-auth 28 | servicePort: 80 29 | -------------------------------------------------------------------------------- /quickstart/kubernetes/sso-proxy-deployment.yml: -------------------------------------------------------------------------------- 1 | apiVersion: extensions/v1beta1 2 | kind: Deployment 3 | metadata: 4 | name: sso-proxy 5 | labels: 6 | k8s-app: sso-proxy 7 | namespace: sso 8 | spec: 9 | replicas: 1 10 | template: 11 | metadata: 12 | labels: 13 | k8s-app: sso-proxy 14 | spec: 15 | containers: 16 | - image: buzzfeed/sso-dev:latest 17 | name: sso-proxy 18 | command: ["/bin/sso-proxy"] 19 | ports: 20 | - containerPort: 4180 21 | env: 22 | - name: DEFAULT_ALLOWED_EMAIL_DOMAINS 23 | value: 'mydomain.com' 24 | - name: UPSTREAM_CONFIGS 25 | value: /sso/upstream_configs.yml 26 | - name: PROVIDER_URL 27 | value: https://sso-auth.mydomain.com 28 | - name: PROVIDER_URL_INTERNAL 29 | value: http://sso-auth.sso.svc.cluster.local 30 | - name: CLIENT_ID 31 | valueFrom: 32 | secretKeyRef: 33 | name: proxy-client-id 34 | key: proxy-client-id 35 | - name: CLIENT_SECRET 36 | valueFrom: 37 | secretKeyRef: 38 | name: proxy-client-secret 39 | key: proxy-client-secret 40 | - name: COOKIE_SECRET 41 | valueFrom: 42 | secretKeyRef: 43 | name: proxy-cookie-secret 44 | key: proxy-cookie-secret 45 | # STATSD_HOST and STATSD_PORT must be defined or the app wont launch, they dont need to be a real host / port, but they do need to be defined. 
46 | - name: STATSD_HOST 47 | value: localhost 48 | - name: STATSD_PORT 49 | value: "11111" 50 | - name: COOKIE_SECURE 51 | value: "false" 52 | - name: CLUSTER 53 | value: dev 54 | - name: VIRTUAL_HOST 55 | value: "*.sso.mydomain.com" 56 | readinessProbe: 57 | httpGet: 58 | path: /ping 59 | port: 4180 60 | scheme: HTTP 61 | livenessProbe: 62 | httpGet: 63 | path: /ping 64 | port: 4180 65 | scheme: HTTP 66 | initialDelaySeconds: 10 67 | timeoutSeconds: 1 68 | resources: 69 | limits: 70 | memory: "256Mi" 71 | cpu: "200m" 72 | volumeMounts: 73 | - name: upstream-configs 74 | mountPath: /sso 75 | volumes: 76 | - name: upstream-configs 77 | configMap: 78 | name: upstream-configs 79 | -------------------------------------------------------------------------------- /quickstart/kubernetes/sso-proxy-svc.yml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: sso-proxy 5 | namespace: sso 6 | labels: 7 | k8s-app: sso-proxy 8 | spec: 9 | ports: 10 | - port: 80 11 | targetPort: 4180 12 | name: http 13 | selector: 14 | k8s-app: sso-proxy 15 | 16 | -------------------------------------------------------------------------------- /quickstart/kubernetes/upstream-configs-configmap.yml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: ConfigMap 3 | metadata: 4 | name: upstream-configs 5 | namespace: sso 6 | data: 7 | upstream_configs.yml: |- 8 | - service: hello-world 9 | default: 10 | from: hello-world.sso.mydomain.com 11 | to: http://hello-world.default.svc.cluster.local 12 | -------------------------------------------------------------------------------- /quickstart/upstream_configs.yml: -------------------------------------------------------------------------------- 1 | - service: httpbin 2 | default: 3 | from: httpbin.sso.localtest.me 4 | to: http://httpbin:8080 5 | 6 | - service: hello-world 7 | default: 8 | from: 
hello-world.sso.localtest.me 9 | to: http://hello-world/ 10 | -------------------------------------------------------------------------------- /scripts/dist.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 1. commit to bump the version and update the changelog/readme 4 | # 2. tag that commit 5 | # 3. use dist.sh to produce tar.gz 6 | # 4. update the release metadata on github / upload the binaries there too 7 | 8 | set -e 9 | 10 | DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" 11 | rm -rf $DIR/dist 12 | mkdir -p $DIR/dist 13 | 14 | arch=$(go env GOARCH) 15 | version='3.1.1' 16 | goversion=$(go version | awk '{print $3}') 17 | 18 | echo "... building v$version for $linux/$arch" 19 | TARGET="sso-$version-linux-$arch-$goversion" 20 | GOOS=linux GOARCH=amd64\ 21 | make build 22 | sudo chown -R 0:0 dist 23 | echo "...tar-ing dist to $TARGET" 24 | cd dist && tar czvf ../$TARGET.tar.gz . && cd - 25 | -------------------------------------------------------------------------------- /scripts/test: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | cd $(dirname "$BASH_SOURCE[0]")/.. 5 | 6 | echo "running go fmt ..." 7 | res=$(go fmt ./...) 8 | if [ -n "$res" ]; then 9 | echo "$res" 10 | echo "gofmt failed..." 11 | exit 1 12 | fi 13 | 14 | # where necessary attempts to add/remove modules from go.mod and go.sum. 15 | echo "running go mod tidy ..." 16 | go mod tidy 17 | git diff --exit-code 18 | if [ "$?" -gt "0" ]; then 19 | exit 1 20 | fi 21 | 22 | # outputs any local suspect modules and, if found will exit 23 | # with a non-zero status. 24 | echo "running go mod verify ..." 25 | go mod verify 26 | 27 | echo "running golint ..." 28 | golint -set_exit_status cmd internal 29 | 30 | echo "running go vet ..." 31 | go vet ./... 32 | 33 | echo "running tests ..." 34 | cd internal 35 | go test -coverprofile=coverage.out -race ./... 
$@ 36 | -------------------------------------------------------------------------------- /static/sso.css: -------------------------------------------------------------------------------- 1 | * { 2 | margin: 0; 3 | padding: 0; 4 | } 5 | body { 6 | font-family: "Helvetica Neue",Helvetica,Arial,sans-serif; 7 | font-size: 1em; 8 | line-height: 1.42857143; 9 | color: #333; 10 | background: #f0f0f0; 11 | } 12 | 13 | p { 14 | margin: 1.5em 0; 15 | } 16 | p:first-child { 17 | margin-top: 0; 18 | } 19 | p:last-child { 20 | margin-bottom: 0; 21 | } 22 | 23 | .container { 24 | max-width: 40em; 25 | display: block; 26 | margin: 10% auto; 27 | text-align: center; 28 | } 29 | 30 | .content, .message, button { 31 | border: 1px solid rgba(0,0,0,.125); 32 | border-bottom-width: 4px; 33 | border-radius: 4px; 34 | } 35 | 36 | .content, .message { 37 | background-color: #fff; 38 | padding: 2rem; 39 | margin: 1rem 0; 40 | } 41 | .error, .message { 42 | border-bottom-color: #c00; 43 | } 44 | .message { 45 | padding: 1.5rem 2rem 1.3rem; 46 | } 47 | 48 | header { 49 | border-bottom: 1px solid rgba(0,0,0,.075); 50 | margin: -2rem 0 2rem; 51 | padding: 2rem 0 1.8rem; 52 | } 53 | header h1 { 54 | font-size: 1.5em; 55 | font-weight: normal; 56 | } 57 | .error header { 58 | color: #c00; 59 | } 60 | .details { 61 | font-size: .85rem; 62 | color: #999; 63 | } 64 | 65 | button { 66 | color: #fff; 67 | background-color: #3B8686; 68 | cursor: pointer; 69 | font-size: 1.5rem; 70 | font-weight: bold; 71 | padding: 1rem 2.5rem; 72 | text-shadow: 0 3px 1px rgba(0,0,0,.2); 73 | outline: none; 74 | } 75 | button:active { 76 | border-top-width: 4px; 77 | border-bottom-width: 1px; 78 | text-shadow: none; 79 | } 80 | 81 | footer { 82 | font-size: 0.75em; 83 | color: #999; 84 | text-align: right; 85 | margin: 1rem; 86 | } 87 | --------------------------------------------------------------------------------